1use std::fs::File;
2use std::io::Error as IoError;
3use std::io::Read;
4use std::io::Seek;
5use std::io::SeekFrom;
6use std::io::Write;
7use std::os::raw::c_char;
8
9use pyo3::exceptions::PyOSError;
10use pyo3::exceptions::PyTypeError;
11use pyo3::prelude::*;
12use pyo3::types::PyByteArray;
13use pyo3::types::PyBytes;
14use pyo3::types::PyInt;
15
/// Converts a Python exception raised by a file-like method into a
/// [`std::io::Error`].
///
/// * If the exception is an `OSError` whose `errno` attribute extracts as an
///   `i32`, the macro performs an early `return Err(from_raw_os_error(errno))`
///   from the enclosing function or closure, and the Python exception is dropped.
/// * Otherwise the exception is restored as the current Python exception and the
///   macro evaluates to `Err(ErrorKind::Other)` carrying `$msg`.
///
/// `$self` is accepted for backward compatibility but is not used. All paths are
/// fully qualified, and the pyo3 method trait is imported locally. This lets the
/// exported macro expand correctly even where `IoError`, `PyOSError` or the pyo3
/// prelude are not imported.
#[macro_export]
macro_rules! transmute_file_error {
    ($self:ident, $e:ident, $msg:expr, $py:expr) => {{
        // `getattr` / `extract` are trait methods; make them resolvable at any
        // call site, not only where the pyo3 prelude is glob-imported.
        #[allow(unused_imports)]
        use ::pyo3::types::PyAnyMethods as _;

        if $e.is_instance_of::<::pyo3::exceptions::PyOSError>($py) {
            if let Ok(code) = $e.value($py).getattr("errno") {
                if let Ok(n) = code.extract::<i32>() {
                    return Err(::std::io::Error::from_raw_os_error(n));
                }
            }
        }

        $e.restore($py);
        Err(::std::io::Error::new(::std::io::ErrorKind::Other, $msg))
    }};
}
38
39#[derive(Debug)]
43pub struct PyFileRead {
44 file: PyObject,
45 has_readinto: bool,
46}
47
48impl PyFileRead {
49 pub fn from_ref<'py>(file: &Bound<'py, PyAny>) -> PyResult<PyFileRead> {
50 let py = file.py();
51
52 let implementation = py
53 .import(pyo3::intern!(py, "sys"))?
54 .getattr(pyo3::intern!(py, "implementation"))?
55 .getattr(pyo3::intern!(py, "name"))?;
56
57 if file.hasattr(pyo3::intern!(py, "readinto"))?
58 && implementation.eq(pyo3::intern!(py, "cpython"))?
59 {
60 let b = PyByteArray::new(py, &[]);
61 if let Ok(res) = file.call_method1(pyo3::intern!(py, "readinto"), (b,)) {
62 if res.downcast::<PyInt>().is_ok() {
63 return Ok({
64 PyFileRead {
65 file: file.clone().unbind().into_any(),
66 has_readinto: true,
67 }
68 });
69 }
70 }
71 }
72
73 let res = file.call_method1(pyo3::intern!(py, "read"), (0,))?;
74 if res.downcast::<PyBytes>().is_ok() {
75 Ok(PyFileRead {
76 file: file.clone().unbind().into_any(),
77 has_readinto: false,
78 })
79 } else {
80 let ty = res.get_type().name()?.to_string();
81 Err(PyTypeError::new_err(format!(
82 "expected bytes, found {}",
83 ty
84 )))
85 }
86 }
87
88 fn read_read(&mut self, buf: &mut [u8]) -> Result<usize, IoError> {
89 Python::with_gil(|py| {
90 match self
91 .file
92 .call_method1(py, pyo3::intern!(py, "read"), (buf.len(),))
93 {
94 Ok(obj) => {
95 if let Ok(bytes) = obj.extract::<Bound<PyBytes>>(py) {
97 let b = bytes.as_bytes();
98 (&mut buf[..b.len()]).copy_from_slice(b);
99 Ok(b.len())
100 } else {
101 let ty = obj.bind(py).get_type().name()?.to_string();
102 let msg = format!("expected bytes, found {}", ty);
103 PyTypeError::new_err(msg).restore(py);
104 Err(IoError::new(
105 std::io::ErrorKind::Other,
106 "fh.read did not return bytes",
107 ))
108 }
109 }
110 Err(e) => {
111 transmute_file_error!(self, e, "read method failed", py)
112 }
113 }
114 })
115 }
116
117 fn read_readinto(&mut self, buf: &mut [u8]) -> Result<usize, IoError> {
118 Python::with_gil(|py| {
119 let memview = unsafe {
120 let m = pyo3::ffi::PyMemoryView_FromMemory(
121 buf.as_mut_ptr() as *mut c_char,
122 buf.len() as isize,
123 pyo3::ffi::PyBUF_WRITE,
124 );
125 PyObject::from_owned_ptr_or_err(py, m)?
126 };
127 match self
128 .file
129 .call_method1(py, pyo3::intern!(py, "readinto"), (memview,))
130 {
131 Ok(n) => match n.extract::<usize>(py) {
132 Ok(n) => Ok(n),
133 Err(_) => {
134 let ty = n.bind(py).get_type().name()?.to_string();
135 let msg = format!("expected int, found {}", ty);
136 PyTypeError::new_err(msg).restore(py);
137 Err(IoError::new(
138 std::io::ErrorKind::Other,
139 "fh.readinto did not return int",
140 ))
141 }
142 },
143 Err(e) => {
144 transmute_file_error!(self, e, "readinto method failed", py)
145 }
146 }
147 })
148 }
149}
150
151impl Read for PyFileRead {
152 fn read(&mut self, buf: &mut [u8]) -> Result<usize, IoError> {
153 if self.has_readinto {
154 self.read_readinto(buf)
155 } else {
156 self.read_read(buf)
157 }
158 }
159}
160
161impl Seek for PyFileRead {
162 fn seek(&mut self, seek: SeekFrom) -> Result<u64, IoError> {
163 let (offset, whence) = match seek {
164 SeekFrom::Start(n) => (n as i64, 0),
165 SeekFrom::Current(n) => (n, 1),
166 SeekFrom::End(n) => (n, 2),
167 };
168 Python::with_gil(
169 |py| match self.file.call_method1(py, "seek", (offset, whence)) {
170 Ok(obj) => {
171 if let Ok(n) = obj.extract::<u64>(py) {
172 Ok(n)
173 } else {
174 let ty = obj.bind(py).get_type().name()?.to_string();
175 let msg = format!("expected int, found {}", ty);
176 PyTypeError::new_err(msg).restore(py);
177 Err(IoError::new(
178 std::io::ErrorKind::Other,
179 "fh.seek did not return position",
180 ))
181 }
182 }
183 Err(e) => Err(IoError::new(std::io::ErrorKind::Unsupported, e.to_string())),
184 },
185 )
186 }
187}
188
189#[derive(Debug)]
193pub struct PyFileWrite {
194 file: PyObject,
195}
196
197impl PyFileWrite {
198 pub fn from_ref<'py>(file: &Bound<'py, PyAny>) -> PyResult<PyFileWrite> {
199 let py = file.py();
200 file.call_method1(pyo3::intern!(py, "write"), (PyBytes::new(py, b""),))?;
201 Ok(Self {
202 file: file.clone().unbind().into_any(),
203 })
204 }
205}
206
207impl Write for PyFileWrite {
208 fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
209 Python::with_gil(|py| {
210 let memview = unsafe {
212 let m = pyo3::ffi::PyMemoryView_FromMemory(
213 buf.as_ptr() as *mut c_char,
214 buf.len() as isize,
215 pyo3::ffi::PyBUF_READ,
216 );
217 PyObject::from_owned_ptr_or_err(py, m)?
218 };
219 match self
221 .file
222 .bind(py)
223 .call_method1(pyo3::intern!(py, "write"), (memview,))
224 {
225 Err(e) => {
226 transmute_file_error!(self, e, "write method failed", py)
227 }
228 Ok(obj) => {
229 if let Ok(n) = obj.extract::<usize>() {
230 Ok(n)
231 } else {
232 let ty = obj.get_type().name()?.to_string();
233 let msg = format!("expected int, found {}", ty);
234 PyTypeError::new_err(msg).restore(py);
235 Err(IoError::new(
236 std::io::ErrorKind::Other,
237 "readinto method did not return int",
238 ))
239 }
240 }
241 }
242 })
243 }
244
245 fn flush(&mut self) -> std::io::Result<()> {
246 Python::with_gil(
247 |py| match self.file.bind(py).call_method0(pyo3::intern!(py, "flush")) {
248 Ok(_) => Ok(()),
249 Err(e) => transmute_file_error!(self, e, "flush method failed", py),
250 },
251 )
252 }
253}
254
255pub enum PyFileReadWrapper {
258 PyFile(PyFileRead),
259 File(File),
260}
261
262impl Read for PyFileReadWrapper {
263 fn read(&mut self, buf: &mut [u8]) -> Result<usize, IoError> {
264 match self {
265 PyFileReadWrapper::PyFile(r) => r.read(buf),
266 PyFileReadWrapper::File(f) => f.read(buf),
267 }
268 }
269}
270
271impl Seek for PyFileReadWrapper {
272 fn seek(&mut self, seek: SeekFrom) -> Result<u64, IoError> {
273 match self {
274 PyFileReadWrapper::PyFile(r) => r.seek(seek),
275 PyFileReadWrapper::File(f) => f.seek(seek),
276 }
277 }
278}
279
280pub enum PyFileWriteWrapper {
283 PyFile(PyFileWrite),
284 File(File),
285}
286
287impl Write for PyFileWriteWrapper {
288 fn write(&mut self, buf: &[u8]) -> Result<usize, IoError> {
289 match self {
290 PyFileWriteWrapper::PyFile(f) => f.write(buf),
291 PyFileWriteWrapper::File(f) => f.write(buf),
292 }
293 }
294
295 fn flush(&mut self) -> Result<(), IoError> {
296 match self {
297 PyFileWriteWrapper::PyFile(f) => f.flush(),
298 PyFileWriteWrapper::File(f) => f.flush(),
299 }
300 }
301}