pyo3_bytes/
bytes.rs

1//! Support for Python buffer protocol
2
3use std::fmt::Write;
4use std::os::raw::c_int;
5use std::ptr::NonNull;
6
7use bytes::{Bytes, BytesMut};
8use pyo3::buffer::PyBuffer;
9use pyo3::exceptions::{PyIndexError, PyValueError};
10use pyo3::prelude::*;
11use pyo3::types::{PyDict, PySlice, PyTuple};
12use pyo3::{ffi, IntoPyObjectExt};
13
14/// A wrapper around a [`bytes::Bytes`][].
15///
16/// This implements both import and export via the Python buffer protocol.
17///
18/// ### Buffer protocol import
19///
20/// This can be very useful as a general way to support ingest of a Python buffer protocol object.
21///
22/// The underlying [Bytes] manages the external memory, automatically calling the Python
23/// buffer's release callback when the internal reference count reaches 0.
24///
25/// Note that converting this [`Bytes`] into a [BytesMut][::bytes::BytesMut] will always create a
26/// deep copy of the buffer into newly allocated memory, since this `Bytes` is constructed from an
27/// owner.
28///
29/// ### Buffer protocol export
30///
31/// PyBytes implements the Python buffer protocol to enable Python to access the underlying `Bytes`
32/// data view without copies. In Python, this `PyBytes` object can be passed to Python `bytes` or
33/// `memoryview` constructors, `numpy.frombuffer`, or any other function that supports buffer
34/// protocol input.
35#[pyclass(name = "Bytes", subclass, frozen, sequence, weakref)]
36#[derive(Hash, PartialEq, PartialOrd, Eq, Ord)]
37pub struct PyBytes(Bytes);
38
39impl AsRef<Bytes> for PyBytes {
40    fn as_ref(&self) -> &Bytes {
41        &self.0
42    }
43}
44
45impl AsRef<[u8]> for PyBytes {
46    fn as_ref(&self) -> &[u8] {
47        self.0.as_ref()
48    }
49}
50
51impl PyBytes {
52    /// Construct a new [PyBytes]
53    pub fn new(buffer: Bytes) -> Self {
54        Self(buffer)
55    }
56
57    /// Consume and return the [Bytes]
58    pub fn into_inner(self) -> Bytes {
59        self.0
60    }
61
62    /// Access the underlying buffer as a byte slice
63    pub fn as_slice(&self) -> &[u8] {
64        self.as_ref()
65    }
66
67    /// Slice the underlying buffer using a Python slice object
68    ///
69    /// This should behave the same as Python's byte slicing:
70    ///     - `ValueError` if step is zero
71    ///     - Negative indices a-ok
72    ///     - If start/stop are out of bounds, they are clipped to the bounds of the buffer
73    ///     - If start > stop, the slice is empty
74    ///
75    /// This is NOT exposed to Python under the `#[pymethods]` impl
76    fn slice(&self, slice: &Bound<'_, PySlice>) -> PyResult<PyBytes> {
77        let bytes_length = self.0.len() as isize;
78        let (start, stop, step) = {
79            let slice_indices = slice.indices(bytes_length)?;
80            (slice_indices.start, slice_indices.stop, slice_indices.step)
81        };
82
83        let new_capacity = if (step > 0 && stop > start) || (step < 0 && stop < start) {
84            (((stop - start).abs() + step.abs() - 1) / step.abs()) as usize
85        } else {
86            0
87        };
88
89        if new_capacity == 0 {
90            return Ok(PyBytes(Bytes::new()));
91        }
92        if step == 1 {
93            // if start < 0  and stop > len and step == 1 just copy?
94            if start < 0 && stop >= bytes_length {
95                let out = self.0.slice(..);
96                let py_bytes = PyBytes(out);
97                return Ok(py_bytes);
98            }
99
100            if start >= 0 && stop <= bytes_length && start < stop {
101                let out = self.0.slice(start as usize..stop as usize);
102                let py_bytes = PyBytes(out);
103                return Ok(py_bytes);
104            }
105            // fall through to the general case here...
106        }
107        if step > 0 {
108            // forward
109            let mut new_buf = BytesMut::with_capacity(new_capacity);
110            new_buf.extend(
111                (start..stop)
112                    .step_by(step as usize)
113                    .map(|i| self.0[i as usize]),
114            );
115            Ok(PyBytes(new_buf.freeze()))
116        } else {
117            // backward
118            let mut new_buf = BytesMut::with_capacity(new_capacity);
119            new_buf.extend(
120                (stop + 1..=start)
121                    .rev()
122                    .step_by((-step) as usize)
123                    .map(|i| self.0[i as usize]),
124            );
125            Ok(PyBytes(new_buf.freeze()))
126        }
127    }
128}
129
130impl From<PyBytes> for Bytes {
131    fn from(value: PyBytes) -> Self {
132        value.0
133    }
134}
135
136impl From<Vec<u8>> for PyBytes {
137    fn from(value: Vec<u8>) -> Self {
138        PyBytes(value.into())
139    }
140}
141
142impl From<Bytes> for PyBytes {
143    fn from(value: Bytes) -> Self {
144        PyBytes(value)
145    }
146}
147
148impl From<BytesMut> for PyBytes {
149    fn from(value: BytesMut) -> Self {
150        PyBytes(value.into())
151    }
152}
153
154#[pymethods]
155impl PyBytes {
156    // By setting the argument to PyBytes, this means that any buffer-protocol object is supported
157    // here, since it will use the FromPyObject impl.
158    #[new]
159    #[pyo3(signature = (buf = PyBytes(Bytes::new())), text_signature = "(buf = b'')")]
160    fn py_new(buf: PyBytes) -> Self {
161        buf
162    }
163
164    fn __getnewargs_ex__(&self, py: Python) -> PyResult<PyObject> {
165        let py_bytes = self.to_bytes(py);
166        let args = PyTuple::new(py, vec![py_bytes])?.into_py_any(py)?;
167        let kwargs = PyDict::new(py);
168        PyTuple::new(py, [args, kwargs.into_py_any(py)?])?.into_py_any(py)
169    }
170
171    /// The number of bytes in this Bytes
172    fn __len__(&self) -> usize {
173        self.0.len()
174    }
175
176    fn __repr__(&self) -> String {
177        format!("{self:?}")
178    }
179
180    fn __add__(&self, other: PyBytes) -> PyBytes {
181        let total_length = self.0.len() + other.0.len();
182        let mut new_buffer = BytesMut::with_capacity(total_length);
183        new_buffer.extend_from_slice(&self.0);
184        new_buffer.extend_from_slice(&other.0);
185        new_buffer.into()
186    }
187
188    fn __contains__(&self, item: PyBytes) -> bool {
189        self.0
190            .windows(item.0.len())
191            .any(|window| window == item.as_slice())
192    }
193
194    fn __eq__(&self, other: PyBytes) -> bool {
195        self.0.as_ref() == other.0.as_ref()
196    }
197
198    fn __getitem__<'py>(
199        &self,
200        py: Python<'py>,
201        key: BytesGetItemKey<'py>,
202    ) -> PyResult<Bound<'py, PyAny>> {
203        match key {
204            BytesGetItemKey::Int(mut index) => {
205                if index < 0 {
206                    index += self.0.len() as isize;
207                }
208                if index < 0 {
209                    return Err(PyIndexError::new_err("Index out of range"));
210                }
211                self.0
212                    .get(index as usize)
213                    .ok_or(PyIndexError::new_err("Index out of range"))?
214                    .into_bound_py_any(py)
215            }
216            BytesGetItemKey::Slice(slice) => {
217                let s = self.slice(&slice)?;
218                s.into_bound_py_any(py)
219            }
220        }
221    }
222
223    fn __mul__(&self, value: usize) -> PyBytes {
224        let mut out_buf = BytesMut::with_capacity(self.0.len() * value);
225        (0..value).for_each(|_| out_buf.extend_from_slice(self.0.as_ref()));
226        out_buf.into()
227    }
228
229    /// This is taken from opendal:
230    /// https://github.com/apache/opendal/blob/d001321b0f9834bc1e2e7d463bcfdc3683e968c9/bindings/python/src/utils.rs#L51-L72
231    #[allow(unsafe_code)]
232    unsafe fn __getbuffer__(
233        slf: PyRef<Self>,
234        view: *mut ffi::Py_buffer,
235        flags: c_int,
236    ) -> PyResult<()> {
237        let bytes = slf.0.as_ref();
238        let ret = ffi::PyBuffer_FillInfo(
239            view,
240            slf.as_ptr() as *mut _,
241            bytes.as_ptr() as *mut _,
242            bytes.len().try_into().unwrap(),
243            1, // read only
244            flags,
245        );
246        if ret == -1 {
247            return Err(PyErr::fetch(slf.py()));
248        }
249        Ok(())
250    }
251
252    // Comment from david hewitt on discord:
253    // > I think normally `__getbuffer__` takes a pointer to the owning Python object, so you
254    // > don't need to treat the allocation as owned separately. It should be good enough to keep
255    // > the allocation owned by the object.
256    // https://discord.com/channels/1209263839632424990/1324816949464666194/1328299411427557397
257    #[allow(unsafe_code)]
258    unsafe fn __releasebuffer__(&self, _view: *mut ffi::Py_buffer) {}
259
260    /// If the binary data starts with the prefix string, return bytes[len(prefix):]. Otherwise,
261    /// return a copy of the original binary data:
262    #[pyo3(signature = (prefix, /))]
263    fn removeprefix(&self, prefix: PyBytes) -> PyBytes {
264        if self.0.starts_with(prefix.as_ref()) {
265            self.0.slice(prefix.0.len()..).into()
266        } else {
267            self.0.clone().into()
268        }
269    }
270
271    /// If the binary data ends with the suffix string and that suffix is not empty, return
272    /// `bytes[:-len(suffix)]`. Otherwise, return the original binary data.
273    #[pyo3(signature = (suffix, /))]
274    fn removesuffix(&self, suffix: PyBytes) -> PyBytes {
275        if self.0.ends_with(suffix.as_ref()) {
276            self.0.slice(0..self.0.len() - suffix.0.len()).into()
277        } else {
278            self.0.clone().into()
279        }
280    }
281
282    /// Return True if all bytes in the sequence are alphabetical ASCII characters or ASCII decimal
283    /// digits and the sequence is not empty, False otherwise. Alphabetic ASCII characters are
284    /// those byte values in the sequence b'abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ'.
285    /// ASCII decimal digits are those byte values in the sequence b'0123456789'.
286    fn isalnum(&self) -> bool {
287        if self.0.is_empty() {
288            return false;
289        }
290
291        for c in self.0.as_ref() {
292            if !c.is_ascii_alphanumeric() {
293                return false;
294            }
295        }
296        true
297    }
298
299    /// Return True if all bytes in the sequence are alphabetic ASCII characters and the sequence
300    /// is not empty, False otherwise. Alphabetic ASCII characters are those byte values in the
301    /// sequence b'abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ'.
302    fn isalpha(&self) -> bool {
303        if self.0.is_empty() {
304            return false;
305        }
306
307        for c in self.0.as_ref() {
308            if !c.is_ascii_alphabetic() {
309                return false;
310            }
311        }
312        true
313    }
314
315    /// Return True if the sequence is empty or all bytes in the sequence are ASCII, False
316    /// otherwise. ASCII bytes are in the range 0-0x7F.
317    fn isascii(&self) -> bool {
318        for c in self.0.as_ref() {
319            if !c.is_ascii() {
320                return false;
321            }
322        }
323        true
324    }
325
326    /// Return True if all bytes in the sequence are ASCII decimal digits and the sequence is not
327    /// empty, False otherwise. ASCII decimal digits are those byte values in the sequence
328    /// b'0123456789'.
329    fn isdigit(&self) -> bool {
330        if self.0.is_empty() {
331            return false;
332        }
333
334        for c in self.0.as_ref() {
335            if !c.is_ascii_digit() {
336                return false;
337            }
338        }
339        true
340    }
341
342    /// Return True if there is at least one lowercase ASCII character in the sequence and no
343    /// uppercase ASCII characters, False otherwise.
344    fn islower(&self) -> bool {
345        let mut has_lower = false;
346        for c in self.0.as_ref() {
347            if c.is_ascii_uppercase() {
348                return false;
349            }
350            if !has_lower && c.is_ascii_lowercase() {
351                has_lower = true;
352            }
353        }
354
355        has_lower
356    }
357
358    /// Return True if all bytes in the sequence are ASCII whitespace and the sequence is not
359    /// empty, False otherwise. ASCII whitespace characters are those byte values in the sequence
360    /// b' \t\n\r\x0b\f' (space, tab, newline, carriage return, vertical tab, form feed).
361    fn isspace(&self) -> bool {
362        if self.0.is_empty() {
363            return false;
364        }
365
366        for c in self.0.as_ref() {
367            // Also check for vertical tab
368            if !(c.is_ascii_whitespace() || *c == b'\x0b') {
369                return false;
370            }
371        }
372        true
373    }
374
375    /// Return True if there is at least one uppercase alphabetic ASCII character in the sequence
376    /// and no lowercase ASCII characters, False otherwise.
377    fn isupper(&self) -> bool {
378        let mut has_upper = false;
379        for c in self.0.as_ref() {
380            if c.is_ascii_lowercase() {
381                return false;
382            }
383            if !has_upper && c.is_ascii_uppercase() {
384                has_upper = true;
385            }
386        }
387
388        has_upper
389    }
390
391    /// Return a copy of the sequence with all the uppercase ASCII characters converted to their
392    /// corresponding lowercase counterpart.
393    fn lower(&self) -> PyBytes {
394        self.0.to_ascii_lowercase().into()
395    }
396
397    /// Return a copy of the sequence with all the lowercase ASCII characters converted to their
398    /// corresponding uppercase counterpart.
399    fn upper(&self) -> PyBytes {
400        self.0.to_ascii_uppercase().into()
401    }
402
403    /// Copy this buffer's contents to a Python `bytes` object
404    fn to_bytes<'py>(&'py self, py: Python<'py>) -> Bound<'py, pyo3::types::PyBytes> {
405        pyo3::types::PyBytes::new(py, &self.0)
406    }
407}
408
409impl<'py> FromPyObject<'py> for PyBytes {
410    fn extract_bound(ob: &Bound<'py, PyAny>) -> PyResult<Self> {
411        let buffer = ob.extract::<PyBytesWrapper>()?;
412        let bytes = Bytes::from_owner(buffer);
413        Ok(Self(bytes))
414    }
415}
416
/// A wrapper around a PyBuffer that applies a custom destructor that checks if the Python
/// interpreter is still initialized before freeing the buffer memory.
///
/// The buffer is held in an `Option` so that the `Drop` impl can move it out and then either
/// drop it or deliberately leak it.
///
/// This also implements AsRef<[u8]> because that is required for Bytes::from_owner
#[derive(Debug)]
struct PyBytesWrapper(Option<PyBuffer<u8>>);
423
424impl Drop for PyBytesWrapper {
425    #[allow(unsafe_code)]
426    fn drop(&mut self) {
427        // Only call the underlying Drop of PyBuffer if the Python interpreter is still
428        // initialized. Sometimes the Drop can attempt to happen after the Python interpreter was
429        // already finalized.
430        // https://github.com/kylebarron/arro3/issues/230
431        let is_initialized = unsafe { ffi::Py_IsInitialized() };
432        if let Some(val) = self.0.take() {
433            if is_initialized == 0 {
434                std::mem::forget(val);
435            } else {
436                drop(val);
437            }
438        }
439    }
440}
441
442impl AsRef<[u8]> for PyBytesWrapper {
443    #[allow(unsafe_code)]
444    fn as_ref(&self) -> &[u8] {
445        let buffer = self.0.as_ref().expect("Buffer already disposed");
446        let len = buffer.item_count();
447
448        let ptr = NonNull::new(buffer.buf_ptr() as _).expect("Expected buffer ptr to be non null");
449
450        // Safety:
451        //
452        // This requires that the data will not be mutated from Python. Sadly, the buffer protocol
453        // does not uphold this invariant always for us, and the Python user must take care not to
454        // mutate the provided buffer.
455        unsafe { std::slice::from_raw_parts(ptr.as_ptr() as *const u8, len) }
456    }
457}
458
459fn validate_buffer(buf: &PyBuffer<u8>) -> PyResult<()> {
460    if !buf.is_c_contiguous() {
461        return Err(PyValueError::new_err("Buffer is not C contiguous"));
462    }
463
464    if buf.strides().iter().any(|s| *s != 1) {
465        return Err(PyValueError::new_err(format!(
466            "strides other than 1 not supported, got: {:?} ",
467            buf.strides()
468        )));
469    }
470
471    Ok(())
472}
473
474impl<'py> FromPyObject<'py> for PyBytesWrapper {
475    fn extract_bound(ob: &Bound<'py, PyAny>) -> PyResult<Self> {
476        let buffer = ob.extract::<PyBuffer<u8>>()?;
477        validate_buffer(&buffer)?;
478        Ok(Self(Some(buffer)))
479    }
480}
481
482/// This is _mostly_ the same as the upstream [`bytes::Bytes` Debug
483/// impl](https://github.com/tokio-rs/bytes/blob/71824b095c4150b3af0776ac158795c00ff9d53f/src/fmt/debug.rs#L6-L37),
484/// however we don't use it because that impl doesn't look how the python bytes repr looks; this
485/// isn't exactly the same either, as the python repr will switch between `'` and `"` based on the
486/// presence of the other in the string, but it's close enough AND we don't have to do a full scan
487/// of the bytes to check for that.
488impl std::fmt::Debug for PyBytes {
489    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
490        f.write_str("Bytes(b\"")?;
491        for &byte in self.0.as_ref() {
492            match byte {
493                // https://doc.rust-lang.org/reference/tokens.html#byte-escapes
494                b'\\' => f.write_str(r"\\")?,
495                b'"' => f.write_str("\\\"")?,
496                b'\n' => f.write_str(r"\n")?,
497                b'\r' => f.write_str(r"\r")?,
498                b'\t' => f.write_str(r"\t")?,
499                // printable ASCII
500                0x20..=0x7E => f.write_char(byte as char)?,
501                _ => write!(f, "\\x{byte:02x}")?,
502            }
503        }
504        f.write_str("\")")?;
505        Ok(())
506    }
507}
508
/// A key for the `__getitem__` method of `PyBytes` - int/slice
///
/// Extraction from a Python object is provided by the `FromPyObject` derive.
// NOTE(review): pyo3's enum derive is documented to try variants in declaration order, so
// keep `Int` ahead of `Slice` unless that is re-checked.
#[derive(FromPyObject)]
enum BytesGetItemKey<'py> {
    /// An integer index
    Int(isize),
    /// A python slice
    Slice(Bound<'py, PySlice>),
}