1use crate::{
5 Py, PyObject, PyObjectRef, PyPayload, PyRef, PyResult, TryFromBorrowedObject, VirtualMachine,
6 common::{
7 borrow::{BorrowedValue, BorrowedValueMut},
8 lock::{MapImmutable, PyMutex, PyMutexGuard},
9 },
10 object::PyObjectPayload,
11 sliceable::SequenceIndexOp,
12};
13use alloc::borrow::Cow;
14use core::{fmt::Debug, ops::Range};
15use itertools::Itertools;
16
/// Vtable of exporter-provided hooks backing a [`PyBuffer`].
///
/// Each hook receives the `PyBuffer` itself and typically downcasts
/// `buffer.obj` (via `PyBuffer::obj_as`) to the exporting payload type.
pub struct BufferMethods {
    // Borrow the exported bytes immutably.
    pub obj_bytes: fn(&PyBuffer) -> BorrowedValue<'_, [u8]>,
    // Borrow the exported bytes mutably.
    pub obj_bytes_mut: fn(&PyBuffer) -> BorrowedValueMut<'_, [u8]>,
    // Called when a buffer view goes away (see `Drop for PyBuffer`).
    pub release: fn(&PyBuffer),
    // Called when a buffer view is created (see `PyBuffer::new`).
    pub retain: fn(&PyBuffer),
}
23
24impl Debug for BufferMethods {
25 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
26 f.debug_struct("BufferMethods")
27 .field("obj_bytes", &(self.obj_bytes as usize))
28 .field("obj_bytes_mut", &(self.obj_bytes_mut as usize))
29 .field("release", &(self.release as usize))
30 .field("retain", &(self.retain as usize))
31 .finish()
32 }
33}
34
/// A borrowed view into another object's byte storage, analogous to
/// CPython's `Py_buffer`.
///
/// Creating one calls the exporter's `retain` hook; dropping it calls
/// `release` (see `impl Drop`).
#[derive(Debug, Clone, Traverse)]
pub struct PyBuffer {
    // Strong reference to the exporting object; keeps the bytes alive.
    pub obj: PyObjectRef,
    // Shape/stride/format metadata for the exported memory.
    #[pytraverse(skip)]
    pub desc: BufferDescriptor,
    // Exporter-specific hooks for borrowing bytes and retain/release.
    #[pytraverse(skip)]
    methods: &'static BufferMethods,
}
43
impl PyBuffer {
    /// Create a buffer view over `obj`, validating `desc` (debug builds only)
    /// and invoking the exporter's `retain` hook before returning.
    pub fn new(obj: PyObjectRef, desc: BufferDescriptor, methods: &'static BufferMethods) -> Self {
        let zelf = Self {
            obj,
            desc: desc.validate(),
            methods,
        };
        // Paired with the `release` call performed in `Drop`.
        zelf.retain();
        zelf
    }

    /// Borrow the whole buffer as a single byte slice, if the layout is
    /// contiguous; `None` otherwise.
    pub fn as_contiguous(&self) -> Option<BorrowedValue<'_, [u8]>> {
        self.desc
            .is_contiguous()
            .then(|| unsafe { self.contiguous_unchecked() })
    }

    /// Mutable variant of `as_contiguous`; additionally requires the buffer
    /// to be writable (`!readonly`).
    pub fn as_contiguous_mut(&self) -> Option<BorrowedValueMut<'_, [u8]>> {
        (!self.desc.readonly && self.desc.is_contiguous())
            .then(|| unsafe { self.contiguous_mut_unchecked() })
    }

    /// Build a 1-D byte buffer that owns `bytes` (wrapped in a `VecBuffer`
    /// python object). Note: `simple(_, true)` marks it readonly.
    pub fn from_byte_vector(bytes: Vec<u8>, vm: &VirtualMachine) -> Self {
        let bytes_len = bytes.len();
        Self::new(
            PyPayload::into_pyobject(VecBuffer::from(bytes), vm),
            BufferDescriptor::simple(bytes_len, true),
            &VEC_BUFFER_METHODS,
        )
    }

    /// # Safety
    /// Caller must ensure the buffer layout is contiguous (e.g. by checking
    /// `self.desc.is_contiguous()` first), so the raw byte slice is a valid
    /// linear view of the logical buffer contents.
    pub unsafe fn contiguous_unchecked(&self) -> BorrowedValue<'_, [u8]> {
        self.obj_bytes()
    }

    /// # Safety
    /// Same contiguity requirement as `contiguous_unchecked`; caller should
    /// also ensure the buffer is writable (`!self.desc.readonly`).
    pub unsafe fn contiguous_mut_unchecked(&self) -> BorrowedValueMut<'_, [u8]> {
        self.obj_bytes_mut()
    }

    /// Append the buffer's logical contents to `buf`: a single memcpy when
    /// contiguous, otherwise segment-by-segment in logical index order.
    pub fn append_to(&self, buf: &mut Vec<u8>) {
        if let Some(bytes) = self.as_contiguous() {
            buf.extend_from_slice(&bytes);
        } else {
            let bytes = &*self.obj_bytes();
            self.desc.for_each_segment(true, |range| {
                buf.extend_from_slice(&bytes[range.start as usize..range.end as usize])
            });
        }
    }

    /// Run `f` over the buffer's logical bytes as one `&[u8]`, borrowing
    /// directly when contiguous and collecting into a temporary `Vec`
    /// otherwise.
    pub fn contiguous_or_collect<R, F: FnOnce(&[u8]) -> R>(&self, f: F) -> R {
        let borrowed;
        let mut collected;
        let v = if let Some(bytes) = self.as_contiguous() {
            borrowed = bytes;
            &*borrowed
        } else {
            collected = vec![];
            self.append_to(&mut collected);
            &collected
        };
        f(v)
    }

    /// View the exporting object as payload type `T`.
    /// Uses an unchecked downcast: callers (typically the exporter's own
    /// `BufferMethods` hooks) must know `obj`'s payload really is `T`.
    pub fn obj_as<T: PyObjectPayload>(&self) -> &Py<T> {
        unsafe { self.obj.downcast_unchecked_ref() }
    }

    /// Borrow the raw exported bytes via the exporter's vtable.
    pub fn obj_bytes(&self) -> BorrowedValue<'_, [u8]> {
        (self.methods.obj_bytes)(self)
    }

    /// Mutably borrow the raw exported bytes via the exporter's vtable.
    pub fn obj_bytes_mut(&self) -> BorrowedValueMut<'_, [u8]> {
        (self.methods.obj_bytes_mut)(self)
    }

    /// Invoke the exporter's `release` hook (normally done by `Drop`).
    pub fn release(&self) {
        (self.methods.release)(self)
    }

    /// Invoke the exporter's `retain` hook (normally done by `new`).
    pub fn retain(&self) {
        (self.methods.retain)(self)
    }

    /// Drop the fields in place WITHOUT running the exporter's `release`
    /// hook (bypasses `Drop for PyBuffer`).
    ///
    /// # Safety
    /// After this call both fields are dropped; `self` must not be used or
    /// dropped normally afterwards, or a double-drop results.
    pub(crate) unsafe fn drop_without_release(&mut self) {
        unsafe {
            core::ptr::drop_in_place(&mut self.obj);
            core::ptr::drop_in_place(&mut self.desc);
        }
    }
}
143
144impl<'a> TryFromBorrowedObject<'a> for PyBuffer {
145 fn try_from_borrowed_object(vm: &VirtualMachine, obj: &'a PyObject) -> PyResult<Self> {
146 let cls = obj.class();
147 if let Some(f) = cls.slots.as_buffer {
148 return f(obj, vm);
149 }
150 Err(vm.new_type_error(format!(
151 "a bytes-like object is required, not '{}'",
152 cls.name()
153 )))
154 }
155}
156
157impl Drop for PyBuffer {
158 fn drop(&mut self) {
159 self.release();
160 }
161}
162
/// Layout metadata for an exported buffer, mirroring the fields of
/// CPython's `Py_buffer`.
#[derive(Debug, Clone)]
pub struct BufferDescriptor {
    // Total buffer size in bytes.
    pub len: usize,
    // Whether writes through this buffer are forbidden.
    pub readonly: bool,
    // Size in bytes of a single item.
    pub itemsize: usize,
    // Item format string (e.g. "B" for raw bytes, as in `simple`).
    pub format: Cow<'static, str>,
    // One `(shape, stride, suboffset)` triple per dimension;
    // stride and suboffset are byte offsets.
    pub dim_desc: Vec<(usize, isize, isize)>,
    }
175
impl BufferDescriptor {
    /// Descriptor for a 1-D buffer of plain bytes: format `"B"`, itemsize 1,
    /// unit stride, no suboffset.
    pub fn simple(bytes_len: usize, readonly: bool) -> Self {
        Self {
            len: bytes_len,
            readonly,
            itemsize: 1,
            format: Cow::Borrowed("B"),
            dim_desc: vec![(bytes_len, 1, 0)],
        }
    }

    /// Descriptor for a 1-D buffer of `bytes_len / itemsize` items with the
    /// given format string. `bytes_len` is expected to be a multiple of
    /// `itemsize` (checked by `validate` in debug builds).
    pub fn format(
        bytes_len: usize,
        readonly: bool,
        itemsize: usize,
        format: Cow<'static, str>,
    ) -> Self {
        Self {
            len: bytes_len,
            readonly,
            itemsize,
            format,
            dim_desc: vec![(bytes_len / itemsize, itemsize as isize, 0)],
        }
    }

    /// Debug-build internal-consistency assertions; returns `self` unchanged.
    #[cfg(debug_assertions)]
    pub fn validate(self) -> Self {
        if self.ndim() == 0 {
            // Zero-dimensional: the buffer is exactly one item.
            if self.len > 0 {
                assert!(self.itemsize != 0);
            }
            assert!(self.itemsize == self.len);
        } else {
            let mut shape_product = 1;
            let has_zero_dim = self.dim_desc.iter().any(|(s, _, _)| *s == 0);
            for (shape, stride, suboffset) in self.dim_desc.iter().cloned() {
                shape_product *= shape;
                assert!(suboffset >= 0);
                // A zero stride is only tolerated when some dimension is empty.
                if !has_zero_dim {
                    assert!(stride != 0);
                }
            }
            // Item count times item size must account for every byte.
            assert!(shape_product * self.itemsize == self.len);
        }
        self
    }

    /// Release-build counterpart of `validate`: a no-op.
    #[cfg(not(debug_assertions))]
    pub fn validate(self) -> Self {
        self
    }

    /// Number of dimensions (one `dim_desc` entry per dimension).
    pub fn ndim(&self) -> usize {
        self.dim_desc.len()
    }

    /// Whether the layout is C-contiguous: walking dimensions innermost-first,
    /// each stride must equal the byte span of the dimensions inside it.
    /// Dimensions of extent <= 1 are ignored, and empty buffers are trivially
    /// contiguous. (Suboffsets are not inspected here.)
    pub fn is_contiguous(&self) -> bool {
        if self.len == 0 {
            return true;
        }
        let mut sd = self.itemsize;
        for (shape, stride, _) in self.dim_desc.iter().cloned().rev() {
            if shape > 1 && stride != sd as isize {
                return false;
            }
            sd *= shape;
        }
        true
    }

    /// Byte offset of the item at `indices`, which must already be
    /// non-negative and in bounds; one index per dimension
    /// (itertools' `zip_eq` panics on a length mismatch).
    pub fn fast_position(&self, indices: &[usize]) -> isize {
        let mut pos = 0;
        for (i, (_, stride, suboffset)) in indices
            .iter()
            .cloned()
            .zip_eq(self.dim_desc.iter().cloned())
        {
            pos += i as isize * stride + suboffset;
        }
        pos
    }

    /// Byte offset of the item at `indices`, wrapping each (possibly
    /// negative) index into its dimension via `SequenceIndexOp::wrapped_at`
    /// and raising IndexError when out of bounds.
    pub fn position(&self, indices: &[isize], vm: &VirtualMachine) -> PyResult<isize> {
        let mut pos = 0;
        for (i, (shape, stride, suboffset)) in indices
            .iter()
            .cloned()
            .zip_eq(self.dim_desc.iter().cloned())
        {
            let i = i.wrapped_at(shape).ok_or_else(|| {
                vm.new_index_error(format!("index out of bounds on dimension {i}"))
            })?;
            pos += i as isize * stride + suboffset;
        }
        Ok(pos)
    }

    /// Call `f` with the byte range of every item, in logical order.
    /// With `try_contiguous` set and a contiguous innermost dimension,
    /// whole rows are emitted as single ranges instead of per-item ranges.
    pub fn for_each_segment<F>(&self, try_contiguous: bool, mut f: F)
    where
        F: FnMut(Range<isize>),
    {
        if self.ndim() == 0 {
            // Zero-dimensional buffer: exactly one item at offset 0.
            f(0..self.itemsize as isize);
            return;
        }
        if try_contiguous && self.is_last_dim_contiguous() {
            self._for_each_segment::<_, true>(0, 0, &mut f);
        } else {
            self._for_each_segment::<_, false>(0, 0, &mut f);
        }
    }

    /// Recursive worker for `for_each_segment`; `index` is the accumulated
    /// byte offset, `dim` the current dimension. With `CONTIGUOUS`, the
    /// innermost dimension is emitted as one `shape * itemsize` range.
    fn _for_each_segment<F, const CONTIGUOUS: bool>(&self, mut index: isize, dim: usize, f: &mut F)
    where
        F: FnMut(Range<isize>),
    {
        let (shape, stride, suboffset) = self.dim_desc[dim];
        if dim + 1 == self.ndim() {
            if CONTIGUOUS {
                f(index..index + (shape * self.itemsize) as isize);
            } else {
                for _ in 0..shape {
                    let pos = index + suboffset;
                    f(pos..pos + self.itemsize as isize);
                    index += stride;
                }
            }
            return;
        }
        for _ in 0..shape {
            self._for_each_segment::<F, CONTIGUOUS>(index + suboffset, dim + 1, f);
            index += stride;
        }
    }

    /// Walk `self` and `other` in lockstep — they must have identical shapes
    /// (debug-asserted per dimension) — passing matching byte ranges to `f`.
    /// `f` returning `true` stops the traversal early.
    pub fn zip_eq<F>(&self, other: &Self, try_contiguous: bool, mut f: F)
    where
        F: FnMut(Range<isize>, Range<isize>) -> bool,
    {
        if self.ndim() == 0 {
            f(0..self.itemsize as isize, 0..other.itemsize as isize);
            return;
        }
        if try_contiguous && self.is_last_dim_contiguous() {
            self._zip_eq::<_, true>(other, 0, 0, 0, &mut f);
        } else {
            self._zip_eq::<_, false>(other, 0, 0, 0, &mut f);
        }
    }

    /// Recursive worker for `zip_eq`, tracking a separate byte offset for
    /// each descriptor since their strides/suboffsets may differ.
    fn _zip_eq<F, const CONTIGUOUS: bool>(
        &self,
        other: &Self,
        mut a_index: isize,
        mut b_index: isize,
        dim: usize,
        f: &mut F,
    ) where
        F: FnMut(Range<isize>, Range<isize>) -> bool,
    {
        let (shape, a_stride, a_suboffset) = self.dim_desc[dim];
        let (_b_shape, b_stride, b_suboffset) = other.dim_desc[dim];
        debug_assert_eq!(shape, _b_shape);
        if dim + 1 == self.ndim() {
            if CONTIGUOUS {
                // Emit the whole innermost row of each buffer at once.
                if f(
                    a_index..a_index + (shape * self.itemsize) as isize,
                    b_index..b_index + (shape * other.itemsize) as isize,
                ) {
                    return;
                }
            } else {
                for _ in 0..shape {
                    let a_pos = a_index + a_suboffset;
                    let b_pos = b_index + b_suboffset;
                    if f(
                        a_pos..a_pos + self.itemsize as isize,
                        b_pos..b_pos + other.itemsize as isize,
                    ) {
                        return;
                    }
                    a_index += a_stride;
                    b_index += b_stride;
                }
            }
            return;
        }

        for _ in 0..shape {
            self._zip_eq::<F, CONTIGUOUS>(
                other,
                a_index + a_suboffset,
                b_index + b_suboffset,
                dim + 1,
                f,
            );
            a_index += a_stride;
            b_index += b_stride;
        }
    }

    /// Whether the innermost dimension is packed (no suboffset, stride equal
    /// to itemsize). Indexes `ndim() - 1`, so callers must ensure
    /// `ndim() != 0` first (both call sites do).
    fn is_last_dim_contiguous(&self) -> bool {
        let (_, stride, suboffset) = self.dim_desc[self.ndim() - 1];
        suboffset == 0 && stride == self.itemsize as isize
    }

    /// True when any dimension has zero extent, i.e. the buffer contains no
    /// items at all.
    pub fn is_zero_in_shape(&self) -> bool {
        self.dim_desc.iter().any(|(shape, _, _)| *shape == 0)
    }

    }
396
/// Implemented by buffer exporters whose storage can be resized, guarding
/// resizes against live buffer exports.
pub trait BufferResizeGuard {
    // Guard value that holds whatever access is needed to perform the resize.
    type Resizable<'a>: 'a
    where
        Self: 'a;
    /// Return a resize guard, or `None` while buffer exports are
    /// outstanding.
    fn try_resizable_opt(&self) -> Option<Self::Resizable<'_>>;
    /// Like `try_resizable_opt`, but raises a BufferError (matching
    /// CPython's message for resizing an exported object) on failure.
    fn try_resizable(&self, vm: &VirtualMachine) -> PyResult<Self::Resizable<'_>> {
        self.try_resizable_opt().ok_or_else(|| {
            vm.new_buffer_error("Existing exports of data: object cannot be re-sized")
        })
    }
}
408
/// Internal python object that owns a `Vec<u8>` so it can back a
/// `PyBuffer` (see `VEC_BUFFER_METHODS`); not instantiable from Python.
#[pyclass(module = false, name = "vec_buffer")]
#[derive(Debug, PyPayload)]
pub struct VecBuffer {
    // The owned bytes, behind a mutex for interior mutability.
    data: PyMutex<Vec<u8>>,
}
414
415#[pyclass(flags(BASETYPE, DISALLOW_INSTANTIATION))]
416impl VecBuffer {
417 pub fn take(&self) -> Vec<u8> {
418 core::mem::take(&mut self.data.lock())
419 }
420}
421
422impl From<Vec<u8>> for VecBuffer {
423 fn from(data: Vec<u8>) -> Self {
424 Self {
425 data: PyMutex::new(data),
426 }
427 }
428}
429
430impl PyRef<VecBuffer> {
431 pub fn into_pybuffer(self, readonly: bool) -> PyBuffer {
432 let len = self.data.lock().len();
433 PyBuffer::new(
434 self.into(),
435 BufferDescriptor::simple(len, readonly),
436 &VEC_BUFFER_METHODS,
437 )
438 }
439
440 pub fn into_pybuffer_with_descriptor(self, desc: BufferDescriptor) -> PyBuffer {
441 PyBuffer::new(self.into(), desc, &VEC_BUFFER_METHODS)
442 }
443}
444
/// Buffer vtable for `VecBuffer`: byte access locks the inner mutex and
/// projects the guard down to the slice; retain/release are no-ops because
/// the `PyBuffer`'s strong `obj` reference already keeps the vector alive.
static VEC_BUFFER_METHODS: BufferMethods = BufferMethods {
    obj_bytes: |buffer| {
        // Map the mutex guard to the byte slice it protects.
        PyMutexGuard::map_immutable(buffer.obj_as::<VecBuffer>().data.lock(), |x| x.as_slice())
            .into()
    },
    obj_bytes_mut: |buffer| {
        PyMutexGuard::map(buffer.obj_as::<VecBuffer>().data.lock(), |x| {
            x.as_mut_slice()
        })
        .into()
    },
    // No exporter-side bookkeeping needed.
    release: |_| {},
    retain: |_| {},
};