pyo3_filelike/
lib.rs

1/// This crate provides a wrapper for file-like objects in Python that implement the `io`
2/// protocol, and allows them to be used as `Read`, `Write`, and `Seek` traits in Rust.
3///
4/// Objects need to implement the `io` protocol. For the `Read` trait, the object must have a
5/// `read` method that takes a single argument, the number of bytes to read, and returns a `bytes`
6/// object. For the `Write` trait, the object must have a `write` method that takes a single
7/// argument, a `bytes` object. For the `Seek` trait, the object must have a `seek` method that
8/// takes two arguments, the offset and whence, and returns the new position.
9///
10/// The `mode` attribute is checked to ensure that the file is opened in binary mode.
11/// If the `mode` attribute is not present, the file is assumed to be opened in binary mode.
12///
13/// The `AsFd` trait is implemented for Unix-like systems, allowing the file to be used with
14/// functions that take a file descriptor. The `fileno` method is called to get the file
15/// descriptor.
16///
17/// # Example
18///
19/// ```rust
20/// use pyo3::prelude::*;
21/// use std::io::{Read, Write};
22///
23/// pyo3::Python::with_gil(|py| -> PyResult<()> {
24///    let io = py.import("io")?;
25///    let file = io.call_method1("BytesIO", (&b"hello"[..], ))?;
26///    let mut file = pyo3_filelike::PyBinaryFile::from(file);
27///    let mut buf = [0u8; 5];
28///    file.read_exact(&mut buf)?;
29///    assert_eq!(&buf, b"hello");
30///    Ok(())
/// }).unwrap();
/// ```
32
33use pyo3::prelude::*;
34use pyo3::exceptions::{PyValueError, PyAttributeError};
35use std::io::{Read, Write, Seek};
36#[cfg(any(unix, target_os = "wasi"))]
37use std::os::fd::{AsFd, BorrowedFd, RawFd};
38
39/// Rust wrapper for a Python file-like object that implements the `io` protocol.
40#[derive(Debug)]
41pub struct PyBinaryFile(pyo3::PyObject);
42
43impl ToPyObject for PyBinaryFile {
44    fn to_object(&self, py: Python) -> PyObject {
45        self.0.clone_ref(py)
46    }
47}
48
49impl Clone for PyBinaryFile{
50    fn clone(&self) -> Self {
51        Python::with_gil(|py| {
52            PyBinaryFile::from(self.0.clone_ref(py))
53        })
54    }
55}
56
57impl Read for PyBinaryFile {
58    fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
59        Python::with_gil(|py| {
60            let bytes = self.0.call_method1(py, "read", (buf.len(), ))?;
61            let bytes = bytes.extract::<&[u8]>(py)?;
62            let len = std::cmp::min(buf.len(), bytes.len());
63            buf[..len].copy_from_slice(&bytes[..len]);
64            Ok(len)
65        })
66    }
67}
68
69impl Write for PyBinaryFile {
70    fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
71        Python::with_gil(|py| {
72            let bytes = pyo3::types::PyBytes::new_bound(py, buf);
73            self.0.call_method1(py, "write", (bytes, ))?;
74            Ok(buf.len())
75        })
76    }
77
78    fn flush(&mut self) -> std::io::Result<()> {
79        Python::with_gil(|py| {
80            self.0.call_method0(py, "flush")?;
81            Ok(())
82        })
83    }
84}
85
86impl Seek for PyBinaryFile {
87    fn seek(&mut self, pos: std::io::SeekFrom) -> std::io::Result<u64> {
88        Python::with_gil(|py| {
89            let (whence, offset) = match pos {
90                std::io::SeekFrom::Start(offset) => (0, offset as i64),
91                std::io::SeekFrom::End(offset) => (2, offset),
92                std::io::SeekFrom::Current(offset) => (1, offset),
93            };
94            let pos = self.0.call_method1(py, "seek", (offset, whence))?;
95            let pos = pos.extract::<u64>(py)?;
96            Ok(pos)
97        })
98    }
99}
100
101#[cfg(any(unix, target_os = "wasi"))]
102impl AsFd for PyBinaryFile {
103    fn as_fd(&self) -> BorrowedFd<'_> {
104        Python::with_gil(|py| {
105            let fd = self.0.call_method0(py, "fileno")?;
106            let fd = fd.extract::<RawFd>(py)?;
107            Ok::<BorrowedFd<'_>, PyErr>(unsafe { BorrowedFd::borrow_raw(fd) })
108        }).unwrap()
109    }
110}
111
112impl PyBinaryFile {
113    fn new(file: PyObject) -> PyResult<Self> {
114        let o = PyBinaryFile(file);
115        o.check_mode('b')?;
116        Ok(o)
117    }
118
119    fn check_mode(&self, expected_mode: char) -> PyResult<()> {
120        Python::with_gil(|py| {
121            match self.0.getattr(py, "mode") {
122                Ok(mode) => {
123                    if mode.extract::<&str>(py)?.contains(expected_mode) {
124                        return Ok(());
125                    }
126                    Err(PyValueError::new_err(format!(
127                        "file must be opened in {} mode",
128                        expected_mode
129                    )))
130                }
131                Err(e) if e.is_instance_of::<PyAttributeError>(py) => {
132                    // Assume binary mode if mode attribute is not present
133                    Ok(())
134                }
135                Err(e) => return Err(e),
136            }
137        })
138    }
139}
140
141impl From<pyo3::PyObject> for PyBinaryFile {
142    fn from(obj: pyo3::PyObject) -> Self {
143        PyBinaryFile::new(obj).unwrap()
144    }
145}
146
147impl From<Bound<'_, PyAny>> for PyBinaryFile {
148    fn from(obj: Bound<'_, PyAny>) -> Self {
149        PyBinaryFile::new(obj.into()).unwrap()
150    }
151}
152
153/// Rust wrapper for a Python text file-like object that implements the `io` protocol.
154///
155/// This wrapper is similar to `PyBinaryFile`, but it assumes that the file is text and
156/// returns the byte representation of the text.
157///
158/// The Python file-like object must have a `read` method that returns either a
159/// `bytes` object or a `str` object.
160///
161/// Seek operations are not supported on text files, since the equivalent Python
162/// file-like objects seek by characters, not bytes.
163#[derive(Debug)]
164pub struct PyTextFile {
165    inner: PyObject,
166    buffer: Vec<u8>
167}
168
169impl PyTextFile {
170    /// Create a new `PyTextFile` from a Python file-like object and an encoding.
171    pub fn new(file: PyObject) -> PyResult<Self> {
172        Ok(PyTextFile{ inner: file, buffer: Vec::new() })
173    }
174}
175
176impl Read for PyTextFile {
177    fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
178        Python::with_gil(|py| {
179            if self.buffer.len() >= buf.len() {
180                buf.copy_from_slice(&self.buffer[..buf.len()]);
181                self.buffer.drain(..buf.len());
182                return Ok(buf.len());
183            }
184            let text = self.inner.call_method1(py, "read", (buf.len() - self.buffer.len(), ))?;
185            let extra = if let Ok(t) = text.extract::<&str>(py) {
186                t.as_bytes()
187            } else {
188                text.extract::<&[u8]>(py)?
189            };
190
191            self.buffer.extend_from_slice(extra);
192
193            let len = std::cmp::min(self.buffer.len(), buf.len());
194            buf[..len].copy_from_slice(&self.buffer[..len]);
195            self.buffer.drain(..len);
196            Ok(len)
197        })
198    }
199}
200
201impl From<pyo3::PyObject> for PyTextFile {
202    fn from(obj: pyo3::PyObject) -> Self {
203        PyTextFile::new(obj).unwrap()
204    }
205}
206
207impl From<Bound<'_, PyAny>> for PyTextFile {
208    fn from(obj: Bound<'_, PyAny>) -> Self {
209        PyTextFile::new(obj.into()).unwrap()
210    }
211}
212
213impl ToPyObject for PyTextFile {
214    fn to_object(&self, py: Python) -> PyObject {
215        self.inner.clone_ref(py)
216    }
217}
218
219impl Clone for PyTextFile{
220    fn clone(&self) -> Self {
221        Python::with_gil(|py| {
222            PyTextFile::from(self.inner.clone_ref(py))
223        })
224    }
225}
226
227
228
229
#[cfg(test)]
mod tests {
    //! Tests driving the wrappers against Python's in-memory `io.BytesIO`
    //! and `io.StringIO` objects.

    use super::*;

    /// Builds an `io.BytesIO` seeded with `initial` and wraps it.
    fn bytes_io(py: Python<'_>, initial: &[u8]) -> PyResult<PyBinaryFile> {
        let io = py.import_bound("io")?;
        let file = io.call_method1("BytesIO", (initial,))?;
        Ok(PyBinaryFile::from(file))
    }

    /// Builds an `io.StringIO` seeded with `initial` and wraps it.
    fn string_io(py: Python<'_>, initial: &str) -> PyResult<PyTextFile> {
        let io = py.import_bound("io")?;
        let file = io.call_method1("StringIO", (initial,))?;
        Ok(PyTextFile::from(file))
    }

    #[test]
    fn test_read() {
        Python::with_gil(|py| -> PyResult<()> {
            let mut file = bytes_io(py, b"hello")?;
            let mut buf = [0u8; 5];
            file.read_exact(&mut buf)?;
            assert_eq!(&buf, b"hello");
            Ok(())
        })
        .unwrap();
    }

    #[test]
    fn test_read_notexact() {
        Python::with_gil(|py| -> PyResult<()> {
            let mut file = bytes_io(py, b"hello")?;
            // The buffer is larger than the stream: a short read is expected.
            let mut buf = [0u8; 10];
            let n = file.read(&mut buf)?;
            assert_eq!(n, 5);
            assert_eq!(&buf[..n], b"hello");
            Ok(())
        })
        .unwrap();
    }

    #[test]
    fn test_read_eof() {
        Python::with_gil(|py| -> PyResult<()> {
            let mut file = bytes_io(py, b"hello")?;
            // One byte more than the stream holds.
            let mut buf = [0u8; 6];
            let err = file.read_exact(&mut buf).unwrap_err();
            assert_eq!(err.kind(), std::io::ErrorKind::UnexpectedEof);
            Ok(())
        })
        .unwrap();
    }

    #[test]
    fn test_read_to_end() {
        Python::with_gil(|py| -> PyResult<()> {
            let mut file = bytes_io(py, b"hello")?;
            let mut buf = Vec::new();
            file.read_to_end(&mut buf)?;
            assert_eq!(&buf, b"hello");
            Ok(())
        })
        .unwrap();
    }

    #[test]
    fn test_write() {
        Python::with_gil(|py| -> PyResult<()> {
            let mut file = bytes_io(py, b"")?;
            file.write_all(b"hello ")?;
            file.write_all(b"world")?;
            let contents = file.0.call_method0(py, "getvalue")?;
            assert_eq!(contents.extract::<&[u8]>(py)?, b"hello world");
            Ok(())
        })
        .unwrap();
    }

    #[test]
    fn test_seek() {
        Python::with_gil(|py| -> PyResult<()> {
            let mut file = bytes_io(py, b"hello")?;
            file.seek(std::io::SeekFrom::Start(1))?;
            let mut buf = [0u8; 4];
            file.read_exact(&mut buf)?;
            assert_eq!(&buf, b"ello");
            Ok(())
        })
        .unwrap();
    }

    #[test]
    fn test_flush() {
        Python::with_gil(|py| -> PyResult<()> {
            let mut file = bytes_io(py, b"")?;
            file.write_all(b"hello")?;
            file.flush()?;
            let contents = file.0.call_method0(py, "getvalue")?;
            assert_eq!(contents.extract::<&[u8]>(py)?, b"hello");
            Ok(())
        })
        .unwrap();
    }

    #[test]
    fn test_read_text() {
        Python::with_gil(|py| -> PyResult<()> {
            let mut file = string_io(py, "hello world")?;
            let mut buf = [0u8; 5];
            file.read_exact(&mut buf)?;
            assert_eq!(&buf, b"hello");
            file.read_exact(&mut buf)?;
            assert_eq!(&buf, b" worl");
            let mut buf = Vec::new();
            file.read_to_end(&mut buf).unwrap();
            assert_eq!(&buf, b"d");
            Ok(())
        })
        .unwrap();
    }

    #[test]
    fn test_read_text_unicode() {
        // read halfway through a unicode character
        let mut file =
            Python::with_gil(|py| string_io(py, "hello \u{1f600} world")).unwrap();

        let mut buf = [0u8; 7];
        file.read_exact(&mut buf).unwrap();
        assert_eq!(&buf, b"hello \xf0");

        // The remaining bytes of the character come out of the internal buffer.
        let mut buf = [0u8; 1];
        file.read_exact(&mut buf).unwrap();
        assert_eq!(&buf, b"\x9f");
    }
}