1use {
8 crate::{conversion::pyobject_to_pathbuf, decode_source},
9 anyhow::{anyhow, Result},
10 pyo3::{
11 buffer::PyBuffer,
12 exceptions::{PyImportError, PyValueError},
13 ffi as pyffi,
14 prelude::*,
15 types::{PyBytes, PyDict, PyType},
16 PyNativeType, PyTraverseError, PyVisit,
17 },
18 std::{
19 collections::HashMap,
20 io::{BufReader, Cursor, Read, Seek},
21 path::{Path, PathBuf},
22 },
23 zip::read::ZipArchive,
24};
25
/// Files within a zip archive that back a single Python module.
///
/// At least one of `source_path` / `bytecode_path` is set for values produced by
/// [`ZipIndex::find_python_module`]. Paths are archive member paths relative to the
/// index prefix.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct ZipPythonModule {
    /// Whether the module is a package (backed by an `__init__` file).
    pub is_package: bool,

    /// Member path of the module's `.py` source file, if present.
    pub source_path: Option<PathBuf>,

    /// Member path of the module's `.pyc` bytecode file, if present.
    pub bytecode_path: Option<PathBuf>,
}
37
38pub struct ZipIndex<R: Read + Seek> {
44 archive: ZipArchive<R>,
46 prefix: Option<PathBuf>,
47 members: HashMap<PathBuf, usize>,
48}
49
50impl<R: Read + Seek> ZipIndex<R> {
51 pub fn new(reader: R, prefix: Option<&Path>) -> Result<Self> {
56 let prefix = prefix.map(|p| p.to_path_buf());
57
58 let mut archive = ZipArchive::new(reader)?;
59
60 let mut members = HashMap::with_capacity(archive.len());
61
62 for index in 0..archive.len() {
63 let zf = archive.by_index_raw(index)?;
64
65 if let Some(name) = zf.enclosed_name() {
67 let name = if let Some(prefix) = &prefix {
70 if !name.starts_with(prefix) {
71 continue;
72 }
73
74 name.strip_prefix(prefix)?
75 } else {
76 name
77 };
78
79 members.insert(name.to_path_buf(), index);
80 }
81 }
82
83 Ok(Self {
84 archive,
85 prefix,
86 members,
87 })
88 }
89
90 pub fn find_python_module(&mut self, full_name: &str) -> Option<ZipPythonModule> {
94 let mut common_path = self.prefix.clone().unwrap_or_default();
95 common_path.extend(full_name.split('.'));
96
97 let package_py_path = common_path.join("__init__").with_extension("py");
98 let package_pyc_path = common_path.join("__init__").with_extension("pyc");
99
100 let non_package_py_path = common_path.with_extension("py");
101 let non_package_pyc_path = common_path.with_extension("pyc");
102
103 let mut is_package = false;
104 let mut source_path = None;
105 let mut bytecode_path = None;
106
107 if self.members.contains_key(&package_py_path) {
108 is_package = true;
109 source_path = Some(package_py_path);
110 }
111
112 if self.members.contains_key(&package_pyc_path) {
113 is_package = true;
114 bytecode_path = Some(package_pyc_path);
115 }
116
117 if is_package {
118 return Some(ZipPythonModule {
119 is_package,
120 source_path,
121 bytecode_path,
122 });
123 }
124
125 if self.members.contains_key(&non_package_py_path) {
126 source_path = Some(non_package_py_path);
127 }
128 if self.members.contains_key(&non_package_pyc_path) {
129 bytecode_path = Some(non_package_pyc_path);
130 }
131
132 if source_path.is_some() || bytecode_path.is_some() {
133 Some(ZipPythonModule {
134 is_package,
135 source_path,
136 bytecode_path,
137 })
138 } else {
139 None
140 }
141 }
142
143 pub fn resolve_path_content(&mut self, path: &Path) -> Result<Vec<u8>> {
147 let index = self
148 .members
149 .get(path)
150 .ok_or_else(|| anyhow!("path {} not present in archive", path.display()))?;
151
152 let mut zf = self.archive.by_index(*index)?;
153
154 let mut buffer = Vec::<u8>::with_capacity(zf.size() as _);
155 zf.read_to_end(&mut buffer)?;
156
157 Ok(buffer)
158 }
159}
160
/// A seekable byte source that can be moved across threads.
///
/// This trait exists so heterogeneous readers can be stored as
/// `Box<dyn SeekableReader>`. The `Send` bound is presumably required because the
/// reader ends up inside a `#[pyclass]`.
pub trait SeekableReader: Send + Read + Seek {}

// Buffered file on disk.
impl SeekableReader for BufReader<std::fs::File> {}

// Borrowed in-memory bytes.
impl<'a> SeekableReader for Cursor<&'a [u8]> {}

// Owned in-memory bytes.
impl SeekableReader for Cursor<Vec<u8>> {}
166
167#[pyclass(module = "oxidized_importer")]
175pub struct OxidizedZipFinder {
176 backing_pyobject: Option<Py<PyAny>>,
180
181 index: ZipIndex<Box<dyn SeekableReader>>,
187
188 zip_path: PathBuf,
195
196 module_spec_type: Py<PyAny>,
198
199 io_module: Py<PyModule>,
201
202 marshal_loads: Py<PyAny>,
204
205 builtins_compile: Py<PyAny>,
207
208 builtins_exec: Py<PyAny>,
210}
211
212impl OxidizedZipFinder {
213 pub fn new_from_data(
215 py: Python,
216 zip_path: PathBuf,
217 data: Vec<u8>,
218 prefix: Option<&Path>,
219 ) -> PyResult<Self> {
220 let reader: Box<dyn SeekableReader> = Box::new(Cursor::new(data));
221
222 let index = ZipIndex::new(reader, prefix)
223 .map_err(|e| PyValueError::new_err(format!("error indexing zip data: {}", e)))?;
224
225 Self::new_internal(py, index, zip_path, None)
226 }
227
228 pub fn new_from_pyobject(
230 py: Python,
231 zip_path: PathBuf,
232 source: &PyAny,
233 prefix: Option<&Path>,
234 ) -> PyResult<Self> {
235 let buffer = PyBuffer::<u8>::get(source)?;
236
237 let data = unsafe {
238 std::slice::from_raw_parts::<u8>(buffer.buf_ptr() as *const _, buffer.len_bytes())
239 };
240
241 let reader: Box<dyn SeekableReader> = Box::new(Cursor::new(data));
242
243 let index = ZipIndex::new(reader, prefix)
244 .map_err(|e| PyValueError::new_err(format!("error indexing zip data: {}", e)))?;
245
246 Self::new_internal(py, index, zip_path, Some(source.into_py(py)))
247 }
248
249 pub fn new_from_reader(
253 py: Python,
254 zip_path: PathBuf,
255 reader: Box<dyn SeekableReader>,
256
257 prefix: Option<&Path>,
258 ) -> PyResult<Self> {
259 let index = ZipIndex::new(reader, prefix)
260 .map_err(|e| PyValueError::new_err(format!("error indexing zip data: {}", e)))?;
261
262 Self::new_internal(py, index, zip_path, None)
263 }
264
265 fn new_internal(
266 py: Python,
267 index: ZipIndex<Box<dyn SeekableReader>>,
268 zip_path: PathBuf,
269 backing_pyobject: Option<Py<PyAny>>,
270 ) -> PyResult<Self> {
271 let importlib_bootstrap = py.import("_frozen_importlib")?;
272 let module_spec_type = importlib_bootstrap.getattr("ModuleSpec")?.into_py(py);
273 let io_module = py.import("_io")?.into_py(py);
274 let marshal_module = py.import("marshal")?;
275 let marshal_loads = marshal_module.getattr("loads")?.into_py(py);
276 let builtins_module = py.import("builtins")?;
277 let builtins_compile = builtins_module.getattr("compile")?.into_py(py);
278 let builtins_exec = builtins_module.getattr("exec")?.into_py(py);
279
280 Ok(Self {
281 backing_pyobject,
282 index,
283 zip_path,
284 module_spec_type,
285 io_module,
286 marshal_loads,
287 builtins_compile,
288 builtins_exec,
289 })
290 }
291
292 fn resolve_python_module(
293 slf: &mut PyRefMut<Self>,
294 full_name: &str,
295 ) -> PyResult<ZipPythonModule> {
296 if let Some(module) = slf.index.find_python_module(full_name) {
297 Ok(module)
298 } else {
299 Err(PyImportError::new_err((
300 "module not found in zip archive",
301 full_name.to_string(),
302 )))
303 }
304 }
305}
306
307#[pymethods]
308impl OxidizedZipFinder {
309 fn __traverse__(&self, visit: PyVisit) -> Result<(), PyTraverseError> {
310 if let Some(o) = &self.backing_pyobject {
311 visit.call(o)?;
312 }
313
314 visit.call(&self.module_spec_type)?;
315 visit.call(&self.io_module)?;
316 visit.call(&self.marshal_loads)?;
317 visit.call(&self.builtins_compile)?;
318 visit.call(&self.builtins_exec)?;
319
320 Ok(())
321 }
322
323 #[classmethod]
324 #[allow(unused)]
325 fn from_path(cls: &PyType, py: Python, path: &PyAny) -> PyResult<Self> {
326 let path = pyobject_to_pathbuf(py, path)?;
327
328 let f = Box::new(BufReader::new(std::fs::File::open(&path).map_err(|e| {
329 PyValueError::new_err(format!("failed to open path {}: {}", path.display(), e))
330 })?));
331
332 Self::new_from_reader(py, path, f, None)
333 }
334
335 #[classmethod]
336 #[args(path = "None")]
337 #[allow(unused)]
338 fn from_zip_data(
339 cls: &PyType,
340 py: Python,
341 source: &PyAny,
342 path: Option<&PyAny>,
343 ) -> PyResult<Self> {
344 let path = if let Some(o) = path {
345 o
346 } else {
347 let sys_module = py.import("sys")?;
348 sys_module.getattr("executable")?
349 };
350
351 let zip_path = pyobject_to_pathbuf(py, path)?;
352
353 Self::new_from_pyobject(py, zip_path, source, None)
354 }
355
356 #[args(target = "None")]
358 #[allow(unused)]
359 fn find_spec<'p>(
360 slf: &'p PyCell<Self>,
361 fullname: String,
362 path: &PyAny,
363 target: Option<&PyAny>,
364 ) -> PyResult<&'p PyAny> {
365 let py = slf.py();
368 let mut importer = slf.try_borrow_mut()?;
369
370 let module = if let Some(module) = importer.index.find_python_module(&fullname) {
371 module
372 } else {
373 return Ok(py.None().into_ref(py));
374 };
375
376 let module_spec_type = importer.module_spec_type.clone_ref(py);
377
378 let kwargs = PyDict::new(py);
379 kwargs.set_item("is_package", module.is_package)?;
380
381 let mut origin = importer.zip_path.clone();
383 if let Some(prefix) = &importer.index.prefix {
384 origin = origin.join(prefix);
385 }
386
387 if let Some(path) = module.source_path {
388 origin = origin.join(path);
389 } else if let Some(path) = module.bytecode_path {
390 origin = origin.join(path);
391 }
392
393 kwargs.set_item("origin", (&origin).into_py(py))?;
394
395 let spec = module_spec_type
396 .call(py, (&fullname, slf), Some(kwargs))?
397 .into_ref(py);
398
399 spec.setattr("has_location", true)?;
400 spec.setattr("cached", py.None())?;
401
402 if module.is_package {
405 let parent = origin.parent().ok_or_else(|| {
406 PyValueError::new_err(
407 "unable to determine dirname(origin); this should never happen",
408 )
409 })?;
410
411 let locations = vec![parent.into_py(py)];
412 spec.setattr("submodule_search_locations", locations)?;
413 }
414
415 Ok(spec)
416 }
417
418 #[allow(unused)]
419 #[args(path = "None")]
420 fn find_module<'p>(
421 slf: &'p PyCell<Self>,
422 fullname: String,
423 path: Option<&PyAny>,
424 ) -> PyResult<&'p PyAny> {
425 let find_spec = slf.getattr("find_spec")?;
428 let spec = find_spec.call((fullname, path), None)?;
429
430 if spec.is_none() {
431 Ok(slf.py().None().into_ref(slf.py()))
432 } else {
433 spec.getattr("loader")
434 }
435 }
436
437 fn invalidate_caches(&self) -> PyResult<()> {
438 Ok(())
439 }
440
441 #[allow(unused)]
446 fn create_module(&self, py: Python, spec: &PyAny) -> PyResult<Py<PyAny>> {
447 Ok(py.None())
449 }
450
451 fn exec_module(slf: &PyCell<Self>, module: &PyAny) -> PyResult<Py<PyAny>> {
452 let py = slf.py();
453
454 let name = module.getattr("__name__")?;
455 let full_name = name.extract::<String>()?;
456 let dict = module.getattr("__dict__")?;
457
458 let code = Self::get_code(slf, &full_name)?;
459
460 let importer = slf.try_borrow()?;
461 let builtins_exec = importer.builtins_exec.clone_ref(py);
464 std::mem::drop(importer);
465
466 builtins_exec.call(py, (code, dict), None)
467 }
468
469 fn get_code(slf: &PyCell<Self>, fullname: &str) -> PyResult<Py<PyAny>> {
474 let py = slf.py();
475 let mut importer = slf.try_borrow_mut()?;
476
477 let module: ZipPythonModule = Self::resolve_python_module(&mut importer, fullname)?;
478
479 if let Some(path) = module.bytecode_path {
480 let bytecode_data = importer.index.resolve_path_content(&path).map_err(|e| {
481 PyImportError::new_err((
482 format!("error reading module bytecode from zip: {}", e),
483 fullname.to_string(),
484 ))
485 })?;
486
487 let marshal_loads = importer.marshal_loads.clone_ref(py);
489 std::mem::drop(importer);
490
491 let bytecode = &bytecode_data[16..];
494 let ptr = unsafe {
495 pyffi::PyMemoryView_FromMemory(
496 bytecode.as_ptr() as _,
497 bytecode.len() as _,
498 pyffi::PyBUF_READ,
499 )
500 };
501
502 let bytecode_obj = if ptr.is_null() {
503 return Err(PyImportError::new_err((
504 "error coercing bytecode to memoryview",
505 fullname.to_string(),
506 )));
507 } else {
508 unsafe { PyObject::from_owned_ptr(py, ptr) }
509 };
510
511 marshal_loads.call1(py, (bytecode_obj,))
512 } else if let Some(path) = module.source_path {
513 let source_bytes: Vec<u8> =
514 importer.index.resolve_path_content(&path).map_err(|e| {
515 PyImportError::new_err((
516 format!("error reading module source from zip: {}", e),
517 fullname.to_string(),
518 ))
519 })?;
520
521 let builtins_compile = importer.builtins_compile.clone_ref(py);
523 std::mem::drop(importer);
524
525 let source_bytes = PyBytes::new(py, &source_bytes);
526
527 let crlf = PyBytes::new(py, b"\r\n");
528 let lf = PyBytes::new(py, b"\n");
529 let cr = PyBytes::new(py, b"\r");
530
531 let source_bytes = source_bytes.call_method("replace", (crlf, lf), None)?;
532 let source_bytes = source_bytes.call_method("replace", (cr, lf), None)?;
533
534 let kwargs = PyDict::new(py);
535 kwargs.set_item("dont_inherit", true)?;
536
537 builtins_compile.call(py, (source_bytes, path, "exec"), Some(kwargs))
538 } else {
539 Err(PyImportError::new_err((
540 "unable to resolve bytecode for module",
541 fullname.to_string(),
542 )))
543 }
544 }
545
546 fn get_source(slf: &PyCell<Self>, fullname: &str) -> PyResult<Py<PyAny>> {
547 let py = slf.py();
548 let mut importer = slf.try_borrow_mut()?;
549
550 let module = Self::resolve_python_module(&mut importer, fullname)?;
551
552 let source_bytes = if let Some(source_path) = module.source_path {
553 importer
554 .index
555 .resolve_path_content(&source_path)
556 .map_err(|e| {
557 PyImportError::new_err((
558 format!("error reading module source from zip: {}", e),
559 fullname.to_string(),
560 ))
561 })?
562 } else {
563 return Ok(py.None());
564 };
565
566 let source_bytes = PyBytes::new(py, &source_bytes);
567
568 let source = decode_source(py, importer.io_module.as_ref(py), source_bytes)?;
569
570 Ok(source.into_py(py))
571 }
572
573 fn is_package(slf: &PyCell<Self>, fullname: &str) -> PyResult<bool> {
574 let mut importer = slf.try_borrow_mut()?;
575
576 let module = Self::resolve_python_module(&mut importer, fullname)?;
577 Ok(module.is_package)
578 }
579
580 }