nafcodec_py/
pyfile.rs

1use std::fs::File;
2use std::io::Error as IoError;
3use std::io::Read;
4use std::io::Seek;
5use std::io::SeekFrom;
6use std::io::Write;
7use std::os::raw::c_char;
8
9use pyo3::exceptions::PyOSError;
10use pyo3::exceptions::PyTypeError;
11use pyo3::prelude::*;
12use pyo3::types::PyByteArray;
13use pyo3::types::PyBytes;
14use pyo3::types::PyInt;
15
16// ---------------------------------------------------------------------------
17
/// Convert a Python exception raised by a file method into a Rust I/O error.
///
/// Arguments are `(self, error, message, py)`; the first one is accepted for
/// call-site symmetry but is not used by the expansion.
///
/// When the exception is an `OSError` carrying an integer `errno`, the macro
/// **returns early** from the enclosing function or closure with the
/// equivalent `std::io::Error`. Otherwise the Python exception is restored
/// as the current Python error and the macro evaluates to a generic
/// `ErrorKind::Other` error carrying `message`.
///
/// `IoError` and `PyOSError` must be in scope at the call site.
#[macro_export]
macro_rules! transmute_file_error {
    ($self:ident, $e:ident, $msg:expr, $py:expr) => {{
        // Only an `OSError` is expected to carry a meaningful `errno`.
        if $e.is_instance_of::<PyOSError>($py) {
            let errno = $e
                .value($py)
                .getattr("errno")
                .and_then(|code| code.extract::<i32>());
            if let Ok(n) = errno {
                return Err(IoError::from_raw_os_error(n));
            }
        }

        // No usable errno: hand the exception back to the interpreter and
        // report a generic failure on the Rust side.
        $e.restore($py);
        Err(IoError::new(std::io::ErrorKind::Other, $msg))
    }};
}
38
39// ---------------------------------------------------------------------------
40
41/// A wrapper around a readable Python file borrowed within a GIL lifetime.
42#[derive(Debug)]
43pub struct PyFileRead {
44    file: PyObject,
45    has_readinto: bool,
46}
47
48impl PyFileRead {
49    pub fn from_ref<'py>(file: &Bound<'py, PyAny>) -> PyResult<PyFileRead> {
50        let py = file.py();
51
52        let implementation = py
53            .import(pyo3::intern!(py, "sys"))?
54            .getattr(pyo3::intern!(py, "implementation"))?
55            .getattr(pyo3::intern!(py, "name"))?;
56
57        if file.hasattr(pyo3::intern!(py, "readinto"))?
58            && implementation.eq(pyo3::intern!(py, "cpython"))?
59        {
60            let b = PyByteArray::new(py, &[]);
61            if let Ok(res) = file.call_method1(pyo3::intern!(py, "readinto"), (b,)) {
62                if res.downcast::<PyInt>().is_ok() {
63                    return Ok({
64                        PyFileRead {
65                            file: file.clone().unbind().into_any(),
66                            has_readinto: true,
67                        }
68                    });
69                }
70            }
71        }
72
73        let res = file.call_method1(pyo3::intern!(py, "read"), (0,))?;
74        if res.downcast::<PyBytes>().is_ok() {
75            Ok(PyFileRead {
76                file: file.clone().unbind().into_any(),
77                has_readinto: false,
78            })
79        } else {
80            let ty = res.get_type().name()?.to_string();
81            Err(PyTypeError::new_err(format!(
82                "expected bytes, found {}",
83                ty
84            )))
85        }
86    }
87
88    fn read_read(&mut self, buf: &mut [u8]) -> Result<usize, IoError> {
89        Python::with_gil(|py| {
90            match self
91                .file
92                .call_method1(py, pyo3::intern!(py, "read"), (buf.len(),))
93            {
94                Ok(obj) => {
95                    // Check `fh.read` returned bytes, else raise a `TypeError`.
96                    if let Ok(bytes) = obj.extract::<Bound<PyBytes>>(py) {
97                        let b = bytes.as_bytes();
98                        (&mut buf[..b.len()]).copy_from_slice(b);
99                        Ok(b.len())
100                    } else {
101                        let ty = obj.bind(py).get_type().name()?.to_string();
102                        let msg = format!("expected bytes, found {}", ty);
103                        PyTypeError::new_err(msg).restore(py);
104                        Err(IoError::new(
105                            std::io::ErrorKind::Other,
106                            "fh.read did not return bytes",
107                        ))
108                    }
109                }
110                Err(e) => {
111                    transmute_file_error!(self, e, "read method failed", py)
112                }
113            }
114        })
115    }
116
117    fn read_readinto(&mut self, buf: &mut [u8]) -> Result<usize, IoError> {
118        Python::with_gil(|py| {
119            let memview = unsafe {
120                let m = pyo3::ffi::PyMemoryView_FromMemory(
121                    buf.as_mut_ptr() as *mut c_char,
122                    buf.len() as isize,
123                    pyo3::ffi::PyBUF_WRITE,
124                );
125                PyObject::from_owned_ptr_or_err(py, m)?
126            };
127            match self
128                .file
129                .call_method1(py, pyo3::intern!(py, "readinto"), (memview,))
130            {
131                Ok(n) => match n.extract::<usize>(py) {
132                    Ok(n) => Ok(n),
133                    Err(_) => {
134                        let ty = n.bind(py).get_type().name()?.to_string();
135                        let msg = format!("expected int, found {}", ty);
136                        PyTypeError::new_err(msg).restore(py);
137                        Err(IoError::new(
138                            std::io::ErrorKind::Other,
139                            "fh.readinto did not return int",
140                        ))
141                    }
142                },
143                Err(e) => {
144                    transmute_file_error!(self, e, "readinto method failed", py)
145                }
146            }
147        })
148    }
149}
150
151impl Read for PyFileRead {
152    fn read(&mut self, buf: &mut [u8]) -> Result<usize, IoError> {
153        if self.has_readinto {
154            self.read_readinto(buf)
155        } else {
156            self.read_read(buf)
157        }
158    }
159}
160
161impl Seek for PyFileRead {
162    fn seek(&mut self, seek: SeekFrom) -> Result<u64, IoError> {
163        let (offset, whence) = match seek {
164            SeekFrom::Start(n) => (n as i64, 0),
165            SeekFrom::Current(n) => (n, 1),
166            SeekFrom::End(n) => (n, 2),
167        };
168        Python::with_gil(
169            |py| match self.file.call_method1(py, "seek", (offset, whence)) {
170                Ok(obj) => {
171                    if let Ok(n) = obj.extract::<u64>(py) {
172                        Ok(n)
173                    } else {
174                        let ty = obj.bind(py).get_type().name()?.to_string();
175                        let msg = format!("expected int, found {}", ty);
176                        PyTypeError::new_err(msg).restore(py);
177                        Err(IoError::new(
178                            std::io::ErrorKind::Other,
179                            "fh.seek did not return position",
180                        ))
181                    }
182                }
183                Err(e) => Err(IoError::new(std::io::ErrorKind::Unsupported, e.to_string())),
184            },
185        )
186    }
187}
188
189// ---------------------------------------------------------------------------
190
191/// A wrapper around a readable Python file borrowed within a GIL lifetime.
192#[derive(Debug)]
193pub struct PyFileWrite {
194    file: PyObject,
195}
196
197impl PyFileWrite {
198    pub fn from_ref<'py>(file: &Bound<'py, PyAny>) -> PyResult<PyFileWrite> {
199        let py = file.py();
200        file.call_method1(pyo3::intern!(py, "write"), (PyBytes::new(py, b""),))?;
201        Ok(Self {
202            file: file.clone().unbind().into_any(),
203        })
204    }
205}
206
207impl Write for PyFileWrite {
208    fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
209        Python::with_gil(|py| {
210            // prepare a `memoryview` to expose the buffer
211            let memview = unsafe {
212                let m = pyo3::ffi::PyMemoryView_FromMemory(
213                    buf.as_ptr() as *mut c_char,
214                    buf.len() as isize,
215                    pyo3::ffi::PyBUF_READ,
216                );
217                PyObject::from_owned_ptr_or_err(py, m)?
218            };
219            // write the buffer contents to the file
220            match self
221                .file
222                .bind(py)
223                .call_method1(pyo3::intern!(py, "write"), (memview,))
224            {
225                Err(e) => {
226                    transmute_file_error!(self, e, "write method failed", py)
227                }
228                Ok(obj) => {
229                    if let Ok(n) = obj.extract::<usize>() {
230                        Ok(n)
231                    } else {
232                        let ty = obj.get_type().name()?.to_string();
233                        let msg = format!("expected int, found {}", ty);
234                        PyTypeError::new_err(msg).restore(py);
235                        Err(IoError::new(
236                            std::io::ErrorKind::Other,
237                            "readinto method did not return int",
238                        ))
239                    }
240                }
241            }
242        })
243    }
244
245    fn flush(&mut self) -> std::io::Result<()> {
246        Python::with_gil(
247            |py| match self.file.bind(py).call_method0(pyo3::intern!(py, "flush")) {
248                Ok(_) => Ok(()),
249                Err(e) => transmute_file_error!(self, e, "flush method failed", py),
250            },
251        )
252    }
253}
254
255// ---------------------------------------------------------------------------
256
257pub enum PyFileReadWrapper {
258    PyFile(PyFileRead),
259    File(File),
260}
261
262impl Read for PyFileReadWrapper {
263    fn read(&mut self, buf: &mut [u8]) -> Result<usize, IoError> {
264        match self {
265            PyFileReadWrapper::PyFile(r) => r.read(buf),
266            PyFileReadWrapper::File(f) => f.read(buf),
267        }
268    }
269}
270
271impl Seek for PyFileReadWrapper {
272    fn seek(&mut self, seek: SeekFrom) -> Result<u64, IoError> {
273        match self {
274            PyFileReadWrapper::PyFile(r) => r.seek(seek),
275            PyFileReadWrapper::File(f) => f.seek(seek),
276        }
277    }
278}
279
280// ---------------------------------------------------------------------------
281
282pub enum PyFileWriteWrapper {
283    PyFile(PyFileWrite),
284    File(File),
285}
286
287impl Write for PyFileWriteWrapper {
288    fn write(&mut self, buf: &[u8]) -> Result<usize, IoError> {
289        match self {
290            PyFileWriteWrapper::PyFile(f) => f.write(buf),
291            PyFileWriteWrapper::File(f) => f.write(buf),
292        }
293    }
294
295    fn flush(&mut self) -> Result<(), IoError> {
296        match self {
297            PyFileWriteWrapper::PyFile(f) => f.flush(),
298            PyFileWriteWrapper::File(f) => f.flush(),
299        }
300    }
301}