Skip to main content

rustpython_vm/protocol/
buffer.rs

1//! Buffer protocol
2//! <https://docs.python.org/3/c-api/buffer.html>
3
4use crate::{
5    Py, PyObject, PyObjectRef, PyPayload, PyRef, PyResult, TryFromBorrowedObject, VirtualMachine,
6    common::{
7        borrow::{BorrowedValue, BorrowedValueMut},
8        lock::{MapImmutable, PyMutex, PyMutexGuard},
9    },
10    object::PyObjectPayload,
11    sliceable::SequenceIndexOp,
12};
13use alloc::borrow::Cow;
14use core::{fmt::Debug, ops::Range};
15use itertools::Itertools;
16
/// Function table backing a [`PyBuffer`].
///
/// Each buffer-exporting type supplies a `'static` instance of this table;
/// `PyBuffer` dispatches through plain `fn` pointers (rather than a trait
/// object) so the table can live in a `static` item.
pub struct BufferMethods {
    /// Borrow the exporter's raw backing bytes immutably.
    pub obj_bytes: fn(&PyBuffer) -> BorrowedValue<'_, [u8]>,
    /// Borrow the exporter's raw backing bytes mutably.
    pub obj_bytes_mut: fn(&PyBuffer) -> BorrowedValueMut<'_, [u8]>,
    /// Hook invoked when a view is released (called from `PyBuffer`'s `Drop`).
    pub release: fn(&PyBuffer),
    /// Hook invoked when a new view is created (called from `PyBuffer::new`).
    pub retain: fn(&PyBuffer),
}
23
24impl Debug for BufferMethods {
25    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
26        f.debug_struct("BufferMethods")
27            .field("obj_bytes", &(self.obj_bytes as usize))
28            .field("obj_bytes_mut", &(self.obj_bytes_mut as usize))
29            .field("release", &(self.release as usize))
30            .field("retain", &(self.retain as usize))
31            .finish()
32    }
33}
34
/// A view into the data of a buffer-exporting Python object, analogous to
/// CPython's `Py_buffer` (buffer protocol).
#[derive(Debug, Clone, Traverse)]
pub struct PyBuffer {
    /// The exporting object; keeps the underlying storage alive.
    pub obj: PyObjectRef,
    /// Layout metadata (length, shape, strides, format) describing the view.
    /// Contains no object references, so the GC traversal skips it.
    #[pytraverse(skip)]
    pub desc: BufferDescriptor,
    /// Dispatch table supplied by the exporter; a plain fn-pointer table,
    /// so it is likewise skipped by GC traversal.
    #[pytraverse(skip)]
    methods: &'static BufferMethods,
}
43
impl PyBuffer {
    /// Build a buffer view over `obj` and immediately `retain` it so the
    /// exporter can account for the outstanding export.
    pub fn new(obj: PyObjectRef, desc: BufferDescriptor, methods: &'static BufferMethods) -> Self {
        let zelf = Self {
            obj,
            // validate() asserts descriptor invariants in debug builds only.
            desc: desc.validate(),
            methods,
        };
        zelf.retain();
        zelf
    }

    /// The whole buffer as one immutable slice, if the layout is contiguous;
    /// `None` otherwise.
    pub fn as_contiguous(&self) -> Option<BorrowedValue<'_, [u8]>> {
        self.desc
            .is_contiguous()
            .then(|| unsafe { self.contiguous_unchecked() })
    }

    /// The whole buffer as one mutable slice, if the layout is contiguous and
    /// the view is writable; `None` otherwise.
    pub fn as_contiguous_mut(&self) -> Option<BorrowedValueMut<'_, [u8]>> {
        (!self.desc.readonly && self.desc.is_contiguous())
            .then(|| unsafe { self.contiguous_mut_unchecked() })
    }

    /// Wrap an owned byte vector in a writable one-dimensional buffer backed
    /// by a [`VecBuffer`] payload.
    pub fn from_byte_vector(bytes: Vec<u8>, vm: &VirtualMachine) -> Self {
        let bytes_len = bytes.len();
        Self::new(
            PyPayload::into_pyobject(VecBuffer::from(bytes), vm),
            BufferDescriptor::simple(bytes_len, true),
            &VEC_BUFFER_METHODS,
        )
    }

    /// # Safety
    /// assume the buffer is contiguous
    pub unsafe fn contiguous_unchecked(&self) -> BorrowedValue<'_, [u8]> {
        self.obj_bytes()
    }

    /// # Safety
    /// assume the buffer is contiguous and writable
    pub unsafe fn contiguous_mut_unchecked(&self) -> BorrowedValueMut<'_, [u8]> {
        self.obj_bytes_mut()
    }

    /// Append the buffer's logical contents to `buf`, following strides and
    /// suboffsets segment by segment when the layout is not contiguous.
    pub fn append_to(&self, buf: &mut Vec<u8>) {
        if let Some(bytes) = self.as_contiguous() {
            buf.extend_from_slice(&bytes);
        } else {
            let bytes = &*self.obj_bytes();
            self.desc.for_each_segment(true, |range| {
                buf.extend_from_slice(&bytes[range.start as usize..range.end as usize])
            });
        }
    }

    /// Run `f` over the buffer's contents as a single slice, borrowing the
    /// backing storage directly when contiguous and copying into a temporary
    /// `Vec` only when it is not.
    pub fn contiguous_or_collect<R, F: FnOnce(&[u8]) -> R>(&self, f: F) -> R {
        // Both locals must outlive `v`, hence the deferred initialization.
        let borrowed;
        let mut collected;
        let v = if let Some(bytes) = self.as_contiguous() {
            borrowed = bytes;
            &*borrowed
        } else {
            collected = vec![];
            self.append_to(&mut collected);
            &collected
        };
        f(v)
    }

    /// Downcast the exporting object to its concrete payload type.
    /// No runtime check is performed; callers must know the payload matches.
    pub fn obj_as<T: PyObjectPayload>(&self) -> &Py<T> {
        unsafe { self.obj.downcast_unchecked_ref() }
    }

    /// Borrow the exporter's raw backing bytes (layout described by `desc`).
    pub fn obj_bytes(&self) -> BorrowedValue<'_, [u8]> {
        (self.methods.obj_bytes)(self)
    }

    /// Borrow the exporter's raw backing bytes mutably.
    pub fn obj_bytes_mut(&self) -> BorrowedValueMut<'_, [u8]> {
        (self.methods.obj_bytes_mut)(self)
    }

    /// Notify the exporter that this view is going away.
    pub fn release(&self) {
        (self.methods.release)(self)
    }

    /// Notify the exporter that a (new) view exists.
    pub fn retain(&self) {
        (self.methods.retain)(self)
    }

    // drop PyBuffer without calling release
    // after this function, the owner should use forget()
    // or wrap PyBuffer in the ManuallyDrop to prevent drop()
    pub(crate) unsafe fn drop_without_release(&mut self) {
        // SAFETY: requirements forwarded from caller
        unsafe {
            core::ptr::drop_in_place(&mut self.obj);
            core::ptr::drop_in_place(&mut self.desc);
        }
    }
}
143
144impl<'a> TryFromBorrowedObject<'a> for PyBuffer {
145    fn try_from_borrowed_object(vm: &VirtualMachine, obj: &'a PyObject) -> PyResult<Self> {
146        let cls = obj.class();
147        if let Some(f) = cls.slots.as_buffer {
148            return f(obj, vm);
149        }
150        Err(vm.new_type_error(format!(
151            "a bytes-like object is required, not '{}'",
152            cls.name()
153        )))
154    }
155}
156
impl Drop for PyBuffer {
    fn drop(&mut self) {
        // Balance the `retain` done in `PyBuffer::new` so the exporter can
        // release per-export bookkeeping (e.g. allow resizing again).
        self.release();
    }
}
162
/// Layout metadata for a [`PyBuffer`], analogous to the shape / strides /
/// suboffsets fields of CPython's `Py_buffer`.
#[derive(Debug, Clone)]
pub struct BufferDescriptor {
    /// product(shape) * itemsize.
    /// The logical byte length of the view; note this is not necessarily the
    /// length returned by obj_bytes(), even when the layout is contiguous.
    pub len: usize,
    /// Whether writes through this view are forbidden.
    pub readonly: bool,
    /// Size in bytes of a single item.
    pub itemsize: usize,
    /// struct-module style format string for one item (e.g. "B" for u8).
    pub format: Cow<'static, str>,
    /// (shape, stride, suboffset) for each dimension
    pub dim_desc: Vec<(usize, isize, isize)>,
    // TODO: flags
}
175
impl BufferDescriptor {
    /// One-dimensional byte buffer: itemsize 1, format "B", stride 1.
    pub fn simple(bytes_len: usize, readonly: bool) -> Self {
        Self {
            len: bytes_len,
            readonly,
            itemsize: 1,
            format: Cow::Borrowed("B"),
            dim_desc: vec![(bytes_len, 1, 0)],
        }
    }

    /// One-dimensional buffer of `bytes_len / itemsize` items with the given
    /// format string. `bytes_len` is expected to be a multiple of `itemsize`
    /// (otherwise the debug-build `validate` assertion will fire).
    pub fn format(
        bytes_len: usize,
        readonly: bool,
        itemsize: usize,
        format: Cow<'static, str>,
    ) -> Self {
        Self {
            len: bytes_len,
            readonly,
            itemsize,
            format,
            dim_desc: vec![(bytes_len / itemsize, itemsize as isize, 0)],
        }
    }

    /// Debug-build sanity checks of the descriptor invariants; the
    /// release-build twin below is a no-op.
    #[cfg(debug_assertions)]
    pub fn validate(self) -> Self {
        // ndim=0 is valid for scalar types (e.g., ctypes Structure)
        if self.ndim() == 0 {
            // Empty structures (len=0) can have itemsize=0
            if self.len > 0 {
                assert!(self.itemsize != 0);
            }
            // A scalar occupies exactly one item.
            assert!(self.itemsize == self.len);
        } else {
            let mut shape_product = 1;
            let has_zero_dim = self.dim_desc.iter().any(|(s, _, _)| *s == 0);
            for (shape, stride, suboffset) in self.dim_desc.iter().cloned() {
                shape_product *= shape;
                // Negative suboffsets are not supported by this descriptor.
                assert!(suboffset >= 0);
                // For empty arrays (any dimension is 0), strides can be 0
                if !has_zero_dim {
                    assert!(stride != 0);
                }
            }
            // Total logical size must agree with the declared byte length.
            assert!(shape_product * self.itemsize == self.len);
        }
        self
    }

    /// Release-build `validate`: invariants are trusted, nothing is checked.
    #[cfg(not(debug_assertions))]
    pub fn validate(self) -> Self {
        self
    }

    /// Number of dimensions (0 for scalars).
    pub fn ndim(&self) -> usize {
        self.dim_desc.len()
    }

    /// True when items are laid out back-to-back in C (row-major) order, so
    /// the whole buffer can be exposed as one slice.
    pub fn is_contiguous(&self) -> bool {
        if self.len == 0 {
            return true;
        }
        // Walk dimensions innermost-first: each stride must equal the byte
        // size of everything nested inside it. Dimensions of extent <= 1
        // place no constraint on their stride.
        let mut sd = self.itemsize;
        for (shape, stride, _) in self.dim_desc.iter().cloned().rev() {
            if shape > 1 && stride != sd as isize {
                return false;
            }
            sd *= shape;
        }
        true
    }

    /// Byte offset of the item at `indices`, applying stride and suboffset
    /// per dimension.
    /// this function do not check the bound
    /// panic if indices.len() != ndim
    pub fn fast_position(&self, indices: &[usize]) -> isize {
        let mut pos = 0;
        for (i, (_, stride, suboffset)) in indices
            .iter()
            .cloned()
            .zip_eq(self.dim_desc.iter().cloned())
        {
            pos += i as isize * stride + suboffset;
        }
        pos
    }

    /// Bounds-checked variant of `fast_position` accepting Python-style
    /// (possibly negative) indices; raises `IndexError` when out of range.
    /// panic if indices.len() != ndim
    pub fn position(&self, indices: &[isize], vm: &VirtualMachine) -> PyResult<isize> {
        let mut pos = 0;
        for (i, (shape, stride, suboffset)) in indices
            .iter()
            .cloned()
            .zip_eq(self.dim_desc.iter().cloned())
        {
            // wrapped_at normalizes negative indices and rejects out-of-range.
            let i = i.wrapped_at(shape).ok_or_else(|| {
                vm.new_index_error(format!("index out of bounds on dimension {i}"))
            })?;
            pos += i as isize * stride + suboffset;
        }
        Ok(pos)
    }

    /// Invoke `f` with each contiguous byte range of the buffer, in logical
    /// order. When `try_contiguous` is set and the innermost dimension is
    /// packed, whole innermost rows are reported as single ranges.
    pub fn for_each_segment<F>(&self, try_contiguous: bool, mut f: F)
    where
        F: FnMut(Range<isize>),
    {
        // Scalar (ndim == 0): the single item starts at offset 0.
        if self.ndim() == 0 {
            f(0..self.itemsize as isize);
            return;
        }
        if try_contiguous && self.is_last_dim_contiguous() {
            self._for_each_segment::<_, true>(0, 0, &mut f);
        } else {
            self._for_each_segment::<_, false>(0, 0, &mut f);
        }
    }

    /// Recursive worker for `for_each_segment`; `CONTIGUOUS` selects a fast
    /// path that emits one range per innermost row instead of per item.
    fn _for_each_segment<F, const CONTIGUOUS: bool>(&self, mut index: isize, dim: usize, f: &mut F)
    where
        F: FnMut(Range<isize>),
    {
        let (shape, stride, suboffset) = self.dim_desc[dim];
        if dim + 1 == self.ndim() {
            if CONTIGUOUS {
                // Packed row: suboffset is known to be 0 here (see
                // is_last_dim_contiguous), so the row is one range.
                f(index..index + (shape * self.itemsize) as isize);
            } else {
                for _ in 0..shape {
                    let pos = index + suboffset;
                    f(pos..pos + self.itemsize as isize);
                    index += stride;
                }
            }
            return;
        }
        // Outer dimension: recurse into each sub-array.
        for _ in 0..shape {
            self._for_each_segment::<F, CONTIGUOUS>(index + suboffset, dim + 1, f);
            index += stride;
        }
    }

    /// zip two BufferDescriptor with the same shape
    ///
    /// Calls `f` with matching byte ranges from `self` and `other`. A `true`
    /// return from `f` stops iteration of the current innermost run; outer
    /// dimensions still continue. The contiguous fast path is selected from
    /// `self`'s innermost layout only.
    pub fn zip_eq<F>(&self, other: &Self, try_contiguous: bool, mut f: F)
    where
        F: FnMut(Range<isize>, Range<isize>) -> bool,
    {
        // Scalars: one item each, both at offset 0.
        if self.ndim() == 0 {
            f(0..self.itemsize as isize, 0..other.itemsize as isize);
            return;
        }
        if try_contiguous && self.is_last_dim_contiguous() {
            self._zip_eq::<_, true>(other, 0, 0, 0, &mut f);
        } else {
            self._zip_eq::<_, false>(other, 0, 0, 0, &mut f);
        }
    }

    /// Recursive worker for `zip_eq`; walks both descriptors in lockstep,
    /// assuming identical shapes (debug-asserted per dimension).
    fn _zip_eq<F, const CONTIGUOUS: bool>(
        &self,
        other: &Self,
        mut a_index: isize,
        mut b_index: isize,
        dim: usize,
        f: &mut F,
    ) where
        F: FnMut(Range<isize>, Range<isize>) -> bool,
    {
        let (shape, a_stride, a_suboffset) = self.dim_desc[dim];
        let (_b_shape, b_stride, b_suboffset) = other.dim_desc[dim];
        debug_assert_eq!(shape, _b_shape);
        if dim + 1 == self.ndim() {
            if CONTIGUOUS {
                if f(
                    a_index..a_index + (shape * self.itemsize) as isize,
                    b_index..b_index + (shape * other.itemsize) as isize,
                ) {
                    return;
                }
            } else {
                for _ in 0..shape {
                    let a_pos = a_index + a_suboffset;
                    let b_pos = b_index + b_suboffset;
                    if f(
                        a_pos..a_pos + self.itemsize as isize,
                        b_pos..b_pos + other.itemsize as isize,
                    ) {
                        return;
                    }
                    a_index += a_stride;
                    b_index += b_stride;
                }
            }
            return;
        }

        for _ in 0..shape {
            self._zip_eq::<F, CONTIGUOUS>(
                other,
                a_index + a_suboffset,
                b_index + b_suboffset,
                dim + 1,
                f,
            );
            a_index += a_stride;
            b_index += b_stride;
        }
    }

    /// True when the innermost dimension is packed: no suboffset and a
    /// stride equal to `itemsize`.
    fn is_last_dim_contiguous(&self) -> bool {
        let (_, stride, suboffset) = self.dim_desc[self.ndim() - 1];
        suboffset == 0 && stride == self.itemsize as isize
    }

    /// True when any dimension has extent 0, i.e. the buffer holds no items.
    pub fn is_zero_in_shape(&self) -> bool {
        self.dim_desc.iter().any(|(shape, _, _)| *shape == 0)
    }

    // TODO: support column-major order
}
396
397pub trait BufferResizeGuard {
398    type Resizable<'a>: 'a
399    where
400        Self: 'a;
401    fn try_resizable_opt(&self) -> Option<Self::Resizable<'_>>;
402    fn try_resizable(&self, vm: &VirtualMachine) -> PyResult<Self::Resizable<'_>> {
403        self.try_resizable_opt().ok_or_else(|| {
404            vm.new_buffer_error("Existing exports of data: object cannot be re-sized")
405        })
406    }
407}
408
/// Buffer exporter that owns its bytes directly in a `Vec<u8>`; used by
/// `PyBuffer::from_byte_vector` and `PyRef<VecBuffer>::into_pybuffer`.
#[pyclass(module = false, name = "vec_buffer")]
#[derive(Debug, PyPayload)]
pub struct VecBuffer {
    // Mutex-guarded so obj_bytes/obj_bytes_mut can hand out mapped guards.
    data: PyMutex<Vec<u8>>,
}
414
415#[pyclass(flags(BASETYPE, DISALLOW_INSTANTIATION))]
416impl VecBuffer {
417    pub fn take(&self) -> Vec<u8> {
418        core::mem::take(&mut self.data.lock())
419    }
420}
421
422impl From<Vec<u8>> for VecBuffer {
423    fn from(data: Vec<u8>) -> Self {
424        Self {
425            data: PyMutex::new(data),
426        }
427    }
428}
429
430impl PyRef<VecBuffer> {
431    pub fn into_pybuffer(self, readonly: bool) -> PyBuffer {
432        let len = self.data.lock().len();
433        PyBuffer::new(
434            self.into(),
435            BufferDescriptor::simple(len, readonly),
436            &VEC_BUFFER_METHODS,
437        )
438    }
439
440    pub fn into_pybuffer_with_descriptor(self, desc: BufferDescriptor) -> PyBuffer {
441        PyBuffer::new(self.into(), desc, &VEC_BUFFER_METHODS)
442    }
443}
444
445static VEC_BUFFER_METHODS: BufferMethods = BufferMethods {
446    obj_bytes: |buffer| {
447        PyMutexGuard::map_immutable(buffer.obj_as::<VecBuffer>().data.lock(), |x| x.as_slice())
448            .into()
449    },
450    obj_bytes_mut: |buffer| {
451        PyMutexGuard::map(buffer.obj_as::<VecBuffer>().data.lock(), |x| {
452            x.as_mut_slice()
453        })
454        .into()
455    },
456    release: |_| {},
457    retain: |_| {},
458};