Skip to main content

pyo3_bytes/
bytes.rs

1//! Support for Python buffer protocol
2
3use std::fmt::Write;
4use std::os::raw::c_int;
5use std::ptr::NonNull;
6
7use bytes::{Bytes, BytesMut};
8use pyo3::buffer::PyBuffer;
9use pyo3::exceptions::{PyIndexError, PyValueError};
10use pyo3::prelude::*;
11use pyo3::types::{PyDict, PySlice, PyTuple};
12use pyo3::{ffi, IntoPyObjectExt};
13
14/// A wrapper around a [`bytes::Bytes`][].
15///
16/// This implements both import and export via the Python buffer protocol.
17///
18/// ### Buffer protocol import
19///
20/// This can be very useful as a general way to support ingest of a Python buffer protocol object.
21///
22/// The underlying [Bytes] manages the external memory, automatically calling the Python
23/// buffer's release callback when the internal reference count reaches 0.
24///
25/// Note that converting this [`Bytes`] into a [BytesMut][::bytes::BytesMut] will always create a
26/// deep copy of the buffer into newly allocated memory, since this `Bytes` is constructed from an
27/// owner.
28///
29/// ### Buffer protocol export
30///
31/// PyBytes implements the Python buffer protocol to enable Python to access the underlying `Bytes`
32/// data view without copies. In Python, this `PyBytes` object can be passed to Python `bytes` or
33/// `memoryview` constructors, `numpy.frombuffer`, or any other function that supports buffer
34/// protocol input.
35#[pyclass(
36    name = "Bytes",
37    subclass,
38    frozen,
39    sequence,
40    weakref,
41    skip_from_py_object
42)]
43#[derive(Clone, Default, Hash, PartialEq, PartialOrd, Eq, Ord)]
44pub struct PyBytes(Bytes);
45
46impl AsRef<Bytes> for PyBytes {
47    fn as_ref(&self) -> &Bytes {
48        &self.0
49    }
50}
51
52impl AsRef<[u8]> for PyBytes {
53    fn as_ref(&self) -> &[u8] {
54        self.0.as_ref()
55    }
56}
57
58impl PyBytes {
59    /// Construct a new [PyBytes]
60    pub fn new(buffer: Bytes) -> Self {
61        Self(buffer)
62    }
63
64    /// Consume and return the [Bytes]
65    pub fn into_inner(self) -> Bytes {
66        self.0
67    }
68
69    /// Access the underlying buffer as a byte slice
70    pub fn as_slice(&self) -> &[u8] {
71        self.as_ref()
72    }
73
74    /// Slice the underlying buffer using a Python slice object
75    ///
76    /// This should behave the same as Python's byte slicing:
77    ///     - `ValueError` if step is zero
78    ///     - Negative indices a-ok
79    ///     - If start/stop are out of bounds, they are clipped to the bounds of the buffer
80    ///     - If start > stop, the slice is empty
81    ///
82    /// This is NOT exposed to Python under the `#[pymethods]` impl
83    fn slice(&self, slice: &Bound<'_, PySlice>) -> PyResult<PyBytes> {
84        let bytes_length = self.0.len() as isize;
85        let (start, stop, step) = {
86            let slice_indices = slice.indices(bytes_length)?;
87            (slice_indices.start, slice_indices.stop, slice_indices.step)
88        };
89
90        let new_capacity = if (step > 0 && stop > start) || (step < 0 && stop < start) {
91            (((stop - start).abs() + step.abs() - 1) / step.abs()) as usize
92        } else {
93            0
94        };
95
96        if new_capacity == 0 {
97            return Ok(PyBytes(Bytes::new()));
98        }
99        if step == 1 {
100            // if start < 0  and stop > len and step == 1 just copy?
101            if start < 0 && stop >= bytes_length {
102                let out = self.0.slice(..);
103                let py_bytes = PyBytes(out);
104                return Ok(py_bytes);
105            }
106
107            if start >= 0 && stop <= bytes_length && start < stop {
108                let out = self.0.slice(start as usize..stop as usize);
109                let py_bytes = PyBytes(out);
110                return Ok(py_bytes);
111            }
112            // fall through to the general case here...
113        }
114        if step > 0 {
115            // forward
116            let mut new_buf = BytesMut::with_capacity(new_capacity);
117            new_buf.extend(
118                (start..stop)
119                    .step_by(step as usize)
120                    .map(|i| self.0[i as usize]),
121            );
122            Ok(PyBytes(new_buf.freeze()))
123        } else {
124            // backward
125            let mut new_buf = BytesMut::with_capacity(new_capacity);
126            new_buf.extend(
127                (stop + 1..=start)
128                    .rev()
129                    .step_by((-step) as usize)
130                    .map(|i| self.0[i as usize]),
131            );
132            Ok(PyBytes(new_buf.freeze()))
133        }
134    }
135}
136
137impl From<PyBytes> for Bytes {
138    fn from(value: PyBytes) -> Self {
139        value.0
140    }
141}
142
143impl From<Vec<u8>> for PyBytes {
144    fn from(value: Vec<u8>) -> Self {
145        PyBytes(value.into())
146    }
147}
148
149impl From<Bytes> for PyBytes {
150    fn from(value: Bytes) -> Self {
151        PyBytes(value)
152    }
153}
154
155impl From<BytesMut> for PyBytes {
156    fn from(value: BytesMut) -> Self {
157        PyBytes(value.into())
158    }
159}
160
161#[pymethods]
162impl PyBytes {
163    // By setting the argument to PyBytes, this means that any buffer-protocol object is supported
164    // here, since it will use the FromPyObject impl.
165    #[new]
166    #[pyo3(signature = (buf = PyBytes(Bytes::new())), text_signature = "(buf = b'')")]
167    fn py_new(buf: PyBytes) -> Self {
168        buf
169    }
170
171    fn __getnewargs_ex__<'py>(&'py self, py: Python<'py>) -> PyResult<Bound<'py, PyTuple>> {
172        let py_bytes = self.to_bytes(py);
173        let args = PyTuple::new(py, vec![py_bytes])?.into_bound_py_any(py)?;
174        let kwargs = PyDict::new(py).into_bound_py_any(py)?;
175        PyTuple::new(py, [args, kwargs])
176    }
177
178    /// The number of bytes in this Bytes
179    fn __len__(&self) -> usize {
180        self.0.len()
181    }
182
183    fn __repr__(&self) -> String {
184        format!("{self:?}")
185    }
186
187    fn __add__(&self, other: PyBytes) -> PyBytes {
188        let total_length = self.0.len() + other.0.len();
189        let mut new_buffer = BytesMut::with_capacity(total_length);
190        new_buffer.extend_from_slice(&self.0);
191        new_buffer.extend_from_slice(&other.0);
192        new_buffer.into()
193    }
194
195    fn __contains__(&self, item: PyBytes) -> bool {
196        self.0
197            .windows(item.0.len())
198            .any(|window| window == item.as_slice())
199    }
200
201    fn __eq__(&self, other: PyBytes) -> bool {
202        self.0.as_ref() == other.0.as_ref()
203    }
204
205    fn __getitem__<'py>(
206        &self,
207        py: Python<'py>,
208        key: BytesGetItemKey<'py>,
209    ) -> PyResult<Bound<'py, PyAny>> {
210        match key {
211            BytesGetItemKey::Int(mut index) => {
212                if index < 0 {
213                    index += self.0.len() as isize;
214                }
215                if index < 0 {
216                    return Err(PyIndexError::new_err("Index out of range"));
217                }
218                self.0
219                    .get(index as usize)
220                    .ok_or(PyIndexError::new_err("Index out of range"))?
221                    .into_bound_py_any(py)
222            }
223            BytesGetItemKey::Slice(slice) => {
224                let s = self.slice(&slice)?;
225                s.into_bound_py_any(py)
226            }
227        }
228    }
229
230    fn __mul__(&self, value: usize) -> PyBytes {
231        let mut out_buf = BytesMut::with_capacity(self.0.len() * value);
232        (0..value).for_each(|_| out_buf.extend_from_slice(self.0.as_ref()));
233        out_buf.into()
234    }
235
236    /// This is taken from opendal:
237    /// https://github.com/apache/opendal/blob/d001321b0f9834bc1e2e7d463bcfdc3683e968c9/bindings/python/src/utils.rs#L51-L72
238    #[allow(unsafe_code)]
239    unsafe fn __getbuffer__(
240        slf: PyRef<Self>,
241        view: *mut ffi::Py_buffer,
242        flags: c_int,
243    ) -> PyResult<()> {
244        let bytes = slf.0.as_ref();
245        let ret = ffi::PyBuffer_FillInfo(
246            view,
247            slf.as_ptr() as *mut _,
248            bytes.as_ptr() as *mut _,
249            bytes.len().try_into().unwrap(),
250            1, // read only
251            flags,
252        );
253        if ret == -1 {
254            return Err(PyErr::fetch(slf.py()));
255        }
256        Ok(())
257    }
258
259    // Comment from david hewitt on discord:
260    // > I think normally `__getbuffer__` takes a pointer to the owning Python object, so you
261    // > don't need to treat the allocation as owned separately. It should be good enough to keep
262    // > the allocation owned by the object.
263    // https://discord.com/channels/1209263839632424990/1324816949464666194/1328299411427557397
264    #[allow(unsafe_code)]
265    unsafe fn __releasebuffer__(&self, _view: *mut ffi::Py_buffer) {}
266
267    /// If the binary data starts with the prefix string, return bytes[len(prefix):]. Otherwise,
268    /// return a copy of the original binary data:
269    #[pyo3(signature = (prefix, /))]
270    fn removeprefix(&self, prefix: PyBytes) -> PyBytes {
271        if self.0.starts_with(prefix.as_ref()) {
272            self.0.slice(prefix.0.len()..).into()
273        } else {
274            self.0.clone().into()
275        }
276    }
277
278    /// If the binary data ends with the suffix string and that suffix is not empty, return
279    /// `bytes[:-len(suffix)]`. Otherwise, return the original binary data.
280    #[pyo3(signature = (suffix, /))]
281    fn removesuffix(&self, suffix: PyBytes) -> PyBytes {
282        if self.0.ends_with(suffix.as_ref()) {
283            self.0.slice(0..self.0.len() - suffix.0.len()).into()
284        } else {
285            self.0.clone().into()
286        }
287    }
288
289    /// Return True if all bytes in the sequence are alphabetical ASCII characters or ASCII decimal
290    /// digits and the sequence is not empty, False otherwise. Alphabetic ASCII characters are
291    /// those byte values in the sequence b'abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ'.
292    /// ASCII decimal digits are those byte values in the sequence b'0123456789'.
293    fn isalnum(&self) -> bool {
294        if self.0.is_empty() {
295            return false;
296        }
297
298        for c in self.0.as_ref() {
299            if !c.is_ascii_alphanumeric() {
300                return false;
301            }
302        }
303        true
304    }
305
306    /// Return True if all bytes in the sequence are alphabetic ASCII characters and the sequence
307    /// is not empty, False otherwise. Alphabetic ASCII characters are those byte values in the
308    /// sequence b'abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ'.
309    fn isalpha(&self) -> bool {
310        if self.0.is_empty() {
311            return false;
312        }
313
314        for c in self.0.as_ref() {
315            if !c.is_ascii_alphabetic() {
316                return false;
317            }
318        }
319        true
320    }
321
322    /// Return True if the sequence is empty or all bytes in the sequence are ASCII, False
323    /// otherwise. ASCII bytes are in the range 0-0x7F.
324    fn isascii(&self) -> bool {
325        for c in self.0.as_ref() {
326            if !c.is_ascii() {
327                return false;
328            }
329        }
330        true
331    }
332
333    /// Return True if all bytes in the sequence are ASCII decimal digits and the sequence is not
334    /// empty, False otherwise. ASCII decimal digits are those byte values in the sequence
335    /// b'0123456789'.
336    fn isdigit(&self) -> bool {
337        if self.0.is_empty() {
338            return false;
339        }
340
341        for c in self.0.as_ref() {
342            if !c.is_ascii_digit() {
343                return false;
344            }
345        }
346        true
347    }
348
349    /// Return True if there is at least one lowercase ASCII character in the sequence and no
350    /// uppercase ASCII characters, False otherwise.
351    fn islower(&self) -> bool {
352        let mut has_lower = false;
353        for c in self.0.as_ref() {
354            if c.is_ascii_uppercase() {
355                return false;
356            }
357            if !has_lower && c.is_ascii_lowercase() {
358                has_lower = true;
359            }
360        }
361
362        has_lower
363    }
364
365    /// Return True if all bytes in the sequence are ASCII whitespace and the sequence is not
366    /// empty, False otherwise. ASCII whitespace characters are those byte values in the sequence
367    /// b' \t\n\r\x0b\f' (space, tab, newline, carriage return, vertical tab, form feed).
368    fn isspace(&self) -> bool {
369        if self.0.is_empty() {
370            return false;
371        }
372
373        for c in self.0.as_ref() {
374            // Also check for vertical tab
375            if !(c.is_ascii_whitespace() || *c == b'\x0b') {
376                return false;
377            }
378        }
379        true
380    }
381
382    /// Return True if there is at least one uppercase alphabetic ASCII character in the sequence
383    /// and no lowercase ASCII characters, False otherwise.
384    fn isupper(&self) -> bool {
385        let mut has_upper = false;
386        for c in self.0.as_ref() {
387            if c.is_ascii_lowercase() {
388                return false;
389            }
390            if !has_upper && c.is_ascii_uppercase() {
391                has_upper = true;
392            }
393        }
394
395        has_upper
396    }
397
398    /// Return a copy of the sequence with all the uppercase ASCII characters converted to their
399    /// corresponding lowercase counterpart.
400    fn lower(&self) -> PyBytes {
401        self.0.to_ascii_lowercase().into()
402    }
403
404    /// Return a copy of the sequence with all the lowercase ASCII characters converted to their
405    /// corresponding uppercase counterpart.
406    fn upper(&self) -> PyBytes {
407        self.0.to_ascii_uppercase().into()
408    }
409
410    /// Copy this buffer's contents to a Python `bytes` object
411    fn to_bytes<'py>(&'py self, py: Python<'py>) -> Bound<'py, pyo3::types::PyBytes> {
412        pyo3::types::PyBytes::new(py, &self.0)
413    }
414}
415
416impl<'py> FromPyObject<'_, 'py> for PyBytes {
417    type Error = PyErr;
418
419    fn extract(obj: Borrowed<'_, 'py, PyAny>) -> Result<Self, Self::Error> {
420        let buffer = obj.extract::<PyBytesWrapper>()?;
421        let bytes = Bytes::from_owner(buffer);
422        Ok(Self(bytes))
423    }
424}
425
426/// A wrapper around a PyBuffer that applies a custom destructor that checks if the Python
427/// interpreter is still initialized before freeing the buffer memory.
428///
429/// This also implements AsRef<[u8]> because that is required for Bytes::from_owner
430#[derive(Debug)]
431struct PyBytesWrapper(PyBuffer<u8>);
432
433impl AsRef<[u8]> for PyBytesWrapper {
434    #[allow(unsafe_code)]
435    fn as_ref(&self) -> &[u8] {
436        let len = self.0.item_count();
437
438        let ptr = NonNull::new(self.0.buf_ptr() as _).expect("Expected buffer ptr to be non null");
439
440        // Safety:
441        //
442        // This requires that the data will not be mutated from Python. Sadly, the buffer protocol
443        // does not uphold this invariant always for us, and the Python user must take care not to
444        // mutate the provided buffer.
445        unsafe { std::slice::from_raw_parts(ptr.as_ptr() as *const u8, len) }
446    }
447}
448
449fn validate_buffer(buf: &PyBuffer<u8>) -> PyResult<()> {
450    if !buf.is_c_contiguous() {
451        return Err(PyValueError::new_err("Buffer is not C contiguous"));
452    }
453
454    if buf.strides().iter().any(|s| *s != 1) {
455        return Err(PyValueError::new_err(format!(
456            "strides other than 1 not supported, got: {:?} ",
457            buf.strides()
458        )));
459    }
460
461    Ok(())
462}
463
464impl<'py> FromPyObject<'_, 'py> for PyBytesWrapper {
465    type Error = PyErr;
466
467    fn extract(obj: Borrowed<'_, 'py, PyAny>) -> Result<Self, Self::Error> {
468        let buffer = obj.extract::<PyBuffer<u8>>()?;
469        validate_buffer(&buffer)?;
470        Ok(Self(buffer))
471    }
472}
473
474/// This is _mostly_ the same as the upstream [`bytes::Bytes` Debug
475/// impl](https://github.com/tokio-rs/bytes/blob/71824b095c4150b3af0776ac158795c00ff9d53f/src/fmt/debug.rs#L6-L37),
476/// however we don't use it because that impl doesn't look how the python bytes repr looks; this
477/// isn't exactly the same either, as the python repr will switch between `'` and `"` based on the
478/// presence of the other in the string, but it's close enough AND we don't have to do a full scan
479/// of the bytes to check for that.
480impl std::fmt::Debug for PyBytes {
481    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
482        f.write_str("Bytes(b\"")?;
483        for &byte in self.0.as_ref() {
484            match byte {
485                // https://doc.rust-lang.org/reference/tokens.html#byte-escapes
486                b'\\' => f.write_str(r"\\")?,
487                b'"' => f.write_str("\\\"")?,
488                b'\n' => f.write_str(r"\n")?,
489                b'\r' => f.write_str(r"\r")?,
490                b'\t' => f.write_str(r"\t")?,
491                // printable ASCII
492                0x20..=0x7E => f.write_char(byte as char)?,
493                _ => write!(f, "\\x{byte:02x}")?,
494            }
495        }
496        f.write_str("\")")?;
497        Ok(())
498    }
499}
500
501/// A key for the `__getitem__` method of `PyBytes` - int/slice
502#[derive(FromPyObject)]
503enum BytesGetItemKey<'py> {
504    /// An integer index
505    Int(isize),
506    /// A python slice
507    Slice(Bound<'py, PySlice>),
508}