pyo3_arrow/
buffer.rs

1//! Support for Python buffer protocol
2
3use std::ffi::CStr;
4use std::os::raw;
5use std::os::raw::c_int;
6use std::ptr::NonNull;
7use std::sync::Arc;
8
9use arrow_array::builder::BooleanBuilder;
10use arrow_array::{
11    ArrayRef, FixedSizeListArray, Float32Array, Float64Array, Int16Array, Int32Array, Int64Array,
12    Int8Array, UInt16Array, UInt32Array, UInt64Array, UInt8Array,
13};
14use arrow_buffer::{Buffer, ScalarBuffer};
15use arrow_schema::Field;
16use pyo3::buffer::{ElementType, PyBuffer};
17use pyo3::exceptions::PyValueError;
18use pyo3::ffi;
19use pyo3::prelude::*;
20use pyo3::types::PyBytes;
21
22use crate::error::{PyArrowError, PyArrowResult};
23use crate::PyArray;
24
25/// A wrapper around an Arrow [Buffer].
26///
27/// This implements both import and export via the Python buffer protocol.
28///
29/// ### Buffer import
30///
31/// This can be very useful as a general way to support ingest of a Python buffer protocol object.
32/// The underlying Arrow [Buffer] manages the external memory, automatically calling the Python
33/// buffer's release callback when the Arrow [Buffer] reference count reaches 0.
34///
35/// This does not need to be used with Arrow at all! This can be used with any API where you want
36/// to handle both Python-provided and Rust-provided buffers. [`PyArrowBuffer`] implements
37/// `AsRef<[u8]>`.
38///
39/// ### Buffer export
40///
41/// The Python buffer protocol is implemented on this buffer to enable zero-copy data transfer of
42/// the core buffer into Python. This allows for zero-copy data sharing with numpy via
43/// `numpy.frombuffer`.
44#[pyclass(module = "arro3.core._core", name = "Buffer", subclass, frozen)]
45pub struct PyArrowBuffer(Buffer);
46
47impl AsRef<Buffer> for PyArrowBuffer {
48    fn as_ref(&self) -> &Buffer {
49        &self.0
50    }
51}
52
53impl AsRef<[u8]> for PyArrowBuffer {
54    fn as_ref(&self) -> &[u8] {
55        self.0.as_ref()
56    }
57}
58
59impl PyArrowBuffer {
60    /// Construct a new [PyArrowBuffer]
61    pub fn new(buffer: Buffer) -> Self {
62        Self(buffer)
63    }
64
65    /// Consume and return the [Buffer]
66    pub fn into_inner(self) -> Buffer {
67        self.0
68    }
69}
70
71#[pymethods]
72impl PyArrowBuffer {
73    /// new
74    #[new]
75    fn py_new(buf: PyArrowBuffer) -> Self {
76        buf
77    }
78
79    fn to_bytes<'py>(&'py self, py: Python<'py>) -> Bound<'py, PyBytes> {
80        PyBytes::new(py, &self.0)
81    }
82
83    fn __len__(&self) -> usize {
84        self.0.len()
85    }
86
87    /// This is taken from opendal:
88    /// https://github.com/apache/opendal/blob/d001321b0f9834bc1e2e7d463bcfdc3683e968c9/bindings/python/src/utils.rs#L51-L72
89    unsafe fn __getbuffer__(
90        slf: PyRef<Self>,
91        view: *mut ffi::Py_buffer,
92        flags: c_int,
93    ) -> PyResult<()> {
94        let bytes = slf.0.as_slice();
95        let ret = ffi::PyBuffer_FillInfo(
96            view,
97            slf.as_ptr() as *mut _,
98            bytes.as_ptr() as *mut _,
99            bytes.len().try_into().unwrap(),
100            1, // read only
101            flags,
102        );
103        if ret == -1 {
104            return Err(PyErr::fetch(slf.py()));
105        }
106        Ok(())
107    }
108
109    unsafe fn __releasebuffer__(&self, _view: *mut ffi::Py_buffer) {}
110}
111
112impl<'py> FromPyObject<'_, 'py> for PyArrowBuffer {
113    type Error = PyErr;
114
115    fn extract(obj: Borrowed<'_, 'py, PyAny>) -> Result<Self, Self::Error> {
116        let buffer = obj.extract::<AnyBufferProtocol>()?;
117        if !matches!(buffer, AnyBufferProtocol::UInt8(_)) {
118            return Err(PyValueError::new_err("Expected u8 buffer protocol object"));
119        }
120
121        Ok(Self(buffer.into_arrow_buffer()?))
122    }
123}
124
125/// An enum over buffer protocol input types.
126#[allow(missing_docs)]
127#[derive(Debug)]
128pub enum AnyBufferProtocol {
129    UInt8(PyBuffer<u8>),
130    UInt16(PyBuffer<u16>),
131    UInt32(PyBuffer<u32>),
132    UInt64(PyBuffer<u64>),
133    Int8(PyBuffer<i8>),
134    Int16(PyBuffer<i16>),
135    Int32(PyBuffer<i32>),
136    Int64(PyBuffer<i64>),
137    Float32(PyBuffer<f32>),
138    Float64(PyBuffer<f64>),
139}
140
141impl<'py> FromPyObject<'_, 'py> for AnyBufferProtocol {
142    type Error = PyErr;
143
144    fn extract(obj: Borrowed<'_, 'py, PyAny>) -> Result<Self, Self::Error> {
145        if let Ok(buf) = obj.extract() {
146            Ok(Self::UInt8(buf))
147        } else if let Ok(buf) = obj.extract() {
148            Ok(Self::UInt16(buf))
149        } else if let Ok(buf) = obj.extract() {
150            Ok(Self::UInt32(buf))
151        } else if let Ok(buf) = obj.extract() {
152            Ok(Self::UInt64(buf))
153        } else if let Ok(buf) = obj.extract() {
154            Ok(Self::Int8(buf))
155        } else if let Ok(buf) = obj.extract() {
156            Ok(Self::Int16(buf))
157        } else if let Ok(buf) = obj.extract() {
158            Ok(Self::Int32(buf))
159        } else if let Ok(buf) = obj.extract() {
160            Ok(Self::Int64(buf))
161        } else if let Ok(buf) = obj.extract() {
162            Ok(Self::Float32(buf))
163        } else if let Ok(buf) = obj.extract() {
164            Ok(Self::Float64(buf))
165        } else {
166            Err(PyValueError::new_err("Not a buffer protocol object"))
167        }
168    }
169}
170
171impl AnyBufferProtocol {
172    fn buf_ptr(&self) -> PyResult<*mut raw::c_void> {
173        let out = match self {
174            Self::UInt8(buf) => buf.buf_ptr(),
175            Self::UInt16(buf) => buf.buf_ptr(),
176            Self::UInt32(buf) => buf.buf_ptr(),
177            Self::UInt64(buf) => buf.buf_ptr(),
178            Self::Int8(buf) => buf.buf_ptr(),
179            Self::Int16(buf) => buf.buf_ptr(),
180            Self::Int32(buf) => buf.buf_ptr(),
181            Self::Int64(buf) => buf.buf_ptr(),
182            Self::Float32(buf) => buf.buf_ptr(),
183            Self::Float64(buf) => buf.buf_ptr(),
184        };
185        Ok(out)
186    }
187
188    #[allow(dead_code)]
189    fn dimensions(&self) -> PyResult<usize> {
190        let out = match self {
191            Self::UInt8(buf) => buf.dimensions(),
192            Self::UInt16(buf) => buf.dimensions(),
193            Self::UInt32(buf) => buf.dimensions(),
194            Self::UInt64(buf) => buf.dimensions(),
195            Self::Int8(buf) => buf.dimensions(),
196            Self::Int16(buf) => buf.dimensions(),
197            Self::Int32(buf) => buf.dimensions(),
198            Self::Int64(buf) => buf.dimensions(),
199            Self::Float32(buf) => buf.dimensions(),
200            Self::Float64(buf) => buf.dimensions(),
201        };
202        Ok(out)
203    }
204
205    fn format(&self) -> PyResult<&CStr> {
206        let out = match self {
207            Self::UInt8(buf) => buf.format(),
208            Self::UInt16(buf) => buf.format(),
209            Self::UInt32(buf) => buf.format(),
210            Self::UInt64(buf) => buf.format(),
211            Self::Int8(buf) => buf.format(),
212            Self::Int16(buf) => buf.format(),
213            Self::Int32(buf) => buf.format(),
214            Self::Int64(buf) => buf.format(),
215            Self::Float32(buf) => buf.format(),
216            Self::Float64(buf) => buf.format(),
217        };
218        Ok(out)
219    }
220
221    /// Consume this and convert to an Arrow [`ArrayRef`].
222    ///
223    /// For almost all buffer protocol objects this is zero-copy. Only boolean-typed buffers need
224    /// to be copied, because boolean Python buffers are one _byte_ per element, while Arrow
225    /// buffers are one _bit_ per element. All numeric buffers are zero-copy compatible.
226    ///
227    /// This uses [`Buffer::from_custom_allocation`][], which creates Arrow buffers from existing
228    /// memory regions. The [`Buffer`] tracks ownership of the [`PyBuffer`] memory via reference
229    /// counting. The [`PyBuffer`]'s release callback will be called when the Arrow [`Buffer`] sees
230    /// that the `PyBuffer`'s reference count
231    /// reaches zero.
232    ///
233    /// ## Safety
234    ///
235    /// - This assumes that the Python buffer is immutable. Immutability is not guaranteed by the
236    ///   Python buffer protocol, so the end user must uphold this. Mutating a Python buffer could
237    ///   lead to undefined behavior.
238    // Note: in the future, maybe you should check item alignment as well?
239    // https://github.com/PyO3/pyo3/blob/ce18f79d71f4d3eac54f55f7633cf08d2f57b64e/src/buffer.rs#L217-L221
240    pub fn into_arrow_array(self) -> PyArrowResult<ArrayRef> {
241        self.validate_buffer()?;
242
243        let shape = self.shape()?.to_vec();
244
245        // Handle multi dimensional arrays by wrapping in FixedSizeLists
246        if shape.len() == 1 {
247            self.into_arrow_values()
248        } else {
249            assert!(shape.len() > 1, "shape cannot be 0");
250
251            let mut values = self.into_arrow_values()?;
252
253            for size in shape[1..].iter().rev() {
254                let field = Arc::new(Field::new("item", values.data_type().clone(), false));
255                let x = FixedSizeListArray::new(field, (*size).try_into().unwrap(), values, None);
256                values = Arc::new(x);
257            }
258
259            Ok(values)
260        }
261    }
262
263    /// Convert the raw buffer to an [ArrayRef].
264    ///
265    /// In `into_arrow_array` the values will be wrapped in FixedSizeLists if needed for multi
266    /// dimensional input.
267    fn into_arrow_values(self) -> PyArrowResult<ArrayRef> {
268        let len = self.item_count()?;
269        let len_bytes = self.len_bytes()?;
270        let ptr = NonNull::new(self.buf_ptr()? as _)
271            .ok_or(PyValueError::new_err("Expected buffer ptr to be non null"))?;
272        let element_type = ElementType::from_format(self.format()?);
273
274        // TODO: couldn't get this macro to work with error
275        // cannot find value `buf` in this scope
276        //
277        // macro_rules! impl_array {
278        //     ($array_type:ty) => {
279        //         let owner = Arc::new(buf);
280        //         let buffer = unsafe { Buffer::from_custom_allocation(ptr, len_bytes, owner) };
281        //         Ok(Arc::new(PrimitiveArray::<$array_type>::new(
282        //             ScalarBuffer::new(buffer, 0, len),
283        //             None,
284        //         )))
285        //     };
286        // }
287
288        match self {
289            Self::UInt8(buf) => match element_type {
290                ElementType::Bool => {
291                    let slice = NonNull::slice_from_raw_parts(ptr, len);
292                    let slice = unsafe { slice.as_ref() };
293                    let mut builder = BooleanBuilder::with_capacity(len);
294                    for val in slice {
295                        builder.append_value(*val > 0);
296                    }
297                    Ok(Arc::new(builder.finish()))
298                }
299                ElementType::UnsignedInteger { bytes } => {
300                    if bytes != 1 {
301                        return Err(PyValueError::new_err(format!(
302                            "Expected 1 byte element type, got {}",
303                            bytes
304                        ))
305                        .into());
306                    }
307
308                    let owner = Arc::new(buf);
309                    let buffer = unsafe { Buffer::from_custom_allocation(ptr, len_bytes, owner) };
310                    Ok(Arc::new(UInt8Array::new(
311                        ScalarBuffer::new(buffer, 0, len),
312                        None,
313                    )))
314                }
315                _ => Err(PyValueError::new_err(format!(
316                    "Unexpected element type {:?}",
317                    element_type
318                ))
319                .into()),
320            },
321            Self::UInt16(buf) => {
322                let owner = Arc::new(buf);
323                let buffer = unsafe { Buffer::from_custom_allocation(ptr, len_bytes, owner) };
324                Ok(Arc::new(UInt16Array::new(
325                    ScalarBuffer::new(buffer, 0, len),
326                    None,
327                )))
328            }
329            Self::UInt32(buf) => {
330                let owner = Arc::new(buf);
331                let buffer = unsafe { Buffer::from_custom_allocation(ptr, len_bytes, owner) };
332                Ok(Arc::new(UInt32Array::new(
333                    ScalarBuffer::new(buffer, 0, len),
334                    None,
335                )))
336            }
337            Self::UInt64(buf) => {
338                let owner = Arc::new(buf);
339                let buffer = unsafe { Buffer::from_custom_allocation(ptr, len_bytes, owner) };
340                Ok(Arc::new(UInt64Array::new(
341                    ScalarBuffer::new(buffer, 0, len),
342                    None,
343                )))
344            }
345
346            Self::Int8(buf) => {
347                let owner = Arc::new(buf);
348                let buffer = unsafe { Buffer::from_custom_allocation(ptr, len_bytes, owner) };
349                Ok(Arc::new(Int8Array::new(
350                    ScalarBuffer::new(buffer, 0, len),
351                    None,
352                )))
353            }
354            Self::Int16(buf) => {
355                let owner = Arc::new(buf);
356                let buffer = unsafe { Buffer::from_custom_allocation(ptr, len_bytes, owner) };
357                Ok(Arc::new(Int16Array::new(
358                    ScalarBuffer::new(buffer, 0, len),
359                    None,
360                )))
361            }
362            Self::Int32(buf) => {
363                let owner = Arc::new(buf);
364                let buffer = unsafe { Buffer::from_custom_allocation(ptr, len_bytes, owner) };
365                Ok(Arc::new(Int32Array::new(
366                    ScalarBuffer::new(buffer, 0, len),
367                    None,
368                )))
369            }
370            Self::Int64(buf) => {
371                let owner = Arc::new(buf);
372                let buffer = unsafe { Buffer::from_custom_allocation(ptr, len_bytes, owner) };
373                Ok(Arc::new(Int64Array::new(
374                    ScalarBuffer::new(buffer, 0, len),
375                    None,
376                )))
377            }
378            Self::Float32(buf) => {
379                let owner = Arc::new(buf);
380                let buffer = unsafe { Buffer::from_custom_allocation(ptr, len_bytes, owner) };
381                Ok(Arc::new(Float32Array::new(
382                    ScalarBuffer::new(buffer, 0, len),
383                    None,
384                )))
385            }
386            Self::Float64(buf) => {
387                let owner = Arc::new(buf);
388                let buffer = unsafe { Buffer::from_custom_allocation(ptr, len_bytes, owner) };
389                Ok(Arc::new(Float64Array::new(
390                    ScalarBuffer::new(buffer, 0, len),
391                    None,
392                )))
393            }
394        }
395    }
396
397    /// Consume this buffer protocol object and convert to an Arrow [Buffer].
398    pub fn into_arrow_buffer(self) -> PyArrowResult<Buffer> {
399        let len_bytes = self.len_bytes()?;
400        let ptr = NonNull::new(self.buf_ptr()? as _)
401            .ok_or(PyValueError::new_err("Expected buffer ptr to be non null"))?;
402
403        let buffer = match self {
404            Self::UInt8(buf) => {
405                let owner = Arc::new(buf);
406                unsafe { Buffer::from_custom_allocation(ptr, len_bytes, owner) }
407            }
408            Self::UInt16(buf) => {
409                let owner = Arc::new(buf);
410                unsafe { Buffer::from_custom_allocation(ptr, len_bytes, owner) }
411            }
412            Self::UInt32(buf) => {
413                let owner = Arc::new(buf);
414                unsafe { Buffer::from_custom_allocation(ptr, len_bytes, owner) }
415            }
416            Self::UInt64(buf) => {
417                let owner = Arc::new(buf);
418                unsafe { Buffer::from_custom_allocation(ptr, len_bytes, owner) }
419            }
420            Self::Int8(buf) => {
421                let owner = Arc::new(buf);
422                unsafe { Buffer::from_custom_allocation(ptr, len_bytes, owner) }
423            }
424            Self::Int16(buf) => {
425                let owner = Arc::new(buf);
426                unsafe { Buffer::from_custom_allocation(ptr, len_bytes, owner) }
427            }
428            Self::Int32(buf) => {
429                let owner = Arc::new(buf);
430                unsafe { Buffer::from_custom_allocation(ptr, len_bytes, owner) }
431            }
432            Self::Int64(buf) => {
433                let owner = Arc::new(buf);
434                unsafe { Buffer::from_custom_allocation(ptr, len_bytes, owner) }
435            }
436            Self::Float32(buf) => {
437                let owner = Arc::new(buf);
438                unsafe { Buffer::from_custom_allocation(ptr, len_bytes, owner) }
439            }
440            Self::Float64(buf) => {
441                let owner = Arc::new(buf);
442                unsafe { Buffer::from_custom_allocation(ptr, len_bytes, owner) }
443            }
444        };
445        Ok(buffer)
446    }
447
448    fn item_count(&self) -> PyResult<usize> {
449        let out = match self {
450            Self::UInt8(buf) => buf.item_count(),
451            Self::UInt16(buf) => buf.item_count(),
452            Self::UInt32(buf) => buf.item_count(),
453            Self::UInt64(buf) => buf.item_count(),
454            Self::Int8(buf) => buf.item_count(),
455            Self::Int16(buf) => buf.item_count(),
456            Self::Int32(buf) => buf.item_count(),
457            Self::Int64(buf) => buf.item_count(),
458            Self::Float32(buf) => buf.item_count(),
459            Self::Float64(buf) => buf.item_count(),
460        };
461        Ok(out)
462    }
463
464    fn is_c_contiguous(&self) -> PyResult<bool> {
465        let out = match self {
466            Self::UInt8(buf) => buf.is_c_contiguous(),
467            Self::UInt16(buf) => buf.is_c_contiguous(),
468            Self::UInt32(buf) => buf.is_c_contiguous(),
469            Self::UInt64(buf) => buf.is_c_contiguous(),
470            Self::Int8(buf) => buf.is_c_contiguous(),
471            Self::Int16(buf) => buf.is_c_contiguous(),
472            Self::Int32(buf) => buf.is_c_contiguous(),
473            Self::Int64(buf) => buf.is_c_contiguous(),
474            Self::Float32(buf) => buf.is_c_contiguous(),
475            Self::Float64(buf) => buf.is_c_contiguous(),
476        };
477        Ok(out)
478    }
479
480    fn len_bytes(&self) -> PyResult<usize> {
481        let out = match self {
482            Self::UInt8(buf) => buf.len_bytes(),
483            Self::UInt16(buf) => buf.len_bytes(),
484            Self::UInt32(buf) => buf.len_bytes(),
485            Self::UInt64(buf) => buf.len_bytes(),
486            Self::Int8(buf) => buf.len_bytes(),
487            Self::Int16(buf) => buf.len_bytes(),
488            Self::Int32(buf) => buf.len_bytes(),
489            Self::Int64(buf) => buf.len_bytes(),
490            Self::Float32(buf) => buf.len_bytes(),
491            Self::Float64(buf) => buf.len_bytes(),
492        };
493        Ok(out)
494    }
495
496    fn shape(&self) -> PyResult<&[usize]> {
497        let out = match self {
498            Self::UInt8(buf) => buf.shape(),
499            Self::UInt16(buf) => buf.shape(),
500            Self::UInt32(buf) => buf.shape(),
501            Self::UInt64(buf) => buf.shape(),
502            Self::Int8(buf) => buf.shape(),
503            Self::Int16(buf) => buf.shape(),
504            Self::Int32(buf) => buf.shape(),
505            Self::Int64(buf) => buf.shape(),
506            Self::Float32(buf) => buf.shape(),
507            Self::Float64(buf) => buf.shape(),
508        };
509        Ok(out)
510    }
511
512    fn validate_buffer(&self) -> PyArrowResult<()> {
513        if !self.is_c_contiguous()? {
514            return Err(PyValueError::new_err("Buffer is not C contiguous").into());
515        }
516
517        if self.shape()?.contains(&0) {
518            return Err(
519                PyValueError::new_err("0-length dimension not currently supported.").into(),
520            );
521        }
522
523        // Note: since we already checked for C-contiguous, we don't need to check for strides to
524        // be contiguous.
525
526        Ok(())
527    }
528}
529
530impl TryFrom<AnyBufferProtocol> for PyArray {
531    type Error = PyArrowError;
532
533    fn try_from(value: AnyBufferProtocol) -> Result<Self, Self::Error> {
534        let array = value.into_arrow_array()?;
535        Ok(Self::from_array_ref(array))
536    }
537}