1use std::ffi::CStr;
4use std::os::raw;
5use std::os::raw::c_int;
6use std::ptr::NonNull;
7use std::sync::Arc;
8
9use arrow_array::builder::BooleanBuilder;
10use arrow_array::{
11 ArrayRef, FixedSizeListArray, Float32Array, Float64Array, Int16Array, Int32Array, Int64Array,
12 Int8Array, UInt16Array, UInt32Array, UInt64Array, UInt8Array,
13};
14use arrow_buffer::{Buffer, ScalarBuffer};
15use arrow_schema::Field;
16use pyo3::buffer::{Element, ElementType, PyBuffer};
17use pyo3::exceptions::PyValueError;
18use pyo3::ffi;
19use pyo3::prelude::*;
20use pyo3::types::PyBytes;
21
22use crate::error::{PyArrowError, PyArrowResult};
23use crate::PyArray;
24
/// An Arrow [`Buffer`] exposed to Python as the `Buffer` class of `arro3.core._core`.
///
/// The class exports its bytes through the Python buffer protocol as a read-only view of the
/// Arrow memory itself (no copy), so consumers such as `memoryview` can read it directly.
#[pyclass(module = "arro3.core._core", name = "Buffer", subclass, frozen)]
pub struct PyArrowBuffer(Buffer);
46
47impl AsRef<Buffer> for PyArrowBuffer {
48 fn as_ref(&self) -> &Buffer {
49 &self.0
50 }
51}
52
53impl AsRef<[u8]> for PyArrowBuffer {
54 fn as_ref(&self) -> &[u8] {
55 self.0.as_ref()
56 }
57}
58
59impl PyArrowBuffer {
60 pub fn new(buffer: Buffer) -> Self {
62 Self(buffer)
63 }
64
65 pub fn into_inner(self) -> Buffer {
67 self.0
68 }
69}
70
71#[pymethods]
72impl PyArrowBuffer {
73 #[new]
75 fn py_new(buf: PyArrowBuffer) -> Self {
76 buf
77 }
78
79 fn to_bytes<'py>(&'py self, py: Python<'py>) -> Bound<'py, PyBytes> {
80 PyBytes::new(py, &self.0)
81 }
82
83 fn __len__(&self) -> usize {
84 self.0.len()
85 }
86
87 unsafe fn __getbuffer__(
90 slf: PyRef<Self>,
91 view: *mut ffi::Py_buffer,
92 flags: c_int,
93 ) -> PyResult<()> {
94 let bytes = slf.0.as_slice();
95 let ret = ffi::PyBuffer_FillInfo(
96 view,
97 slf.as_ptr() as *mut _,
98 bytes.as_ptr() as *mut _,
99 bytes.len().try_into().unwrap(),
100 1, flags,
102 );
103 if ret == -1 {
104 return Err(PyErr::fetch(slf.py()));
105 }
106 Ok(())
107 }
108
109 unsafe fn __releasebuffer__(&self, _view: *mut ffi::Py_buffer) {}
110}
111
112impl<'py> FromPyObject<'py> for PyArrowBuffer {
113 fn extract_bound(ob: &Bound<'py, PyAny>) -> PyResult<Self> {
114 let buffer = ob.extract::<AnyBufferProtocol>()?;
115 if !matches!(buffer, AnyBufferProtocol::UInt8(_)) {
116 return Err(PyValueError::new_err("Expected u8 buffer protocol object"));
117 }
118
119 Ok(Self(buffer.into_arrow_buffer()?))
120 }
121}
122
/// Owner of a pyo3 `PyBuffer<T>` view, used as the keep-alive `owner` when Python memory is
/// wrapped zero-copy in an Arrow `Buffer`.
///
/// The `Option` lets `Drop` take the view out and decide whether to release or leak it; once
/// it has been taken, `inner()` reports "Buffer already disposed".
#[derive(Debug)]
pub struct PyBufferWrapper<T: Element>(Option<PyBuffer<T>>);
127
128impl<T: Element> PyBufferWrapper<T> {
129 fn inner(&self) -> PyResult<&PyBuffer<T>> {
130 self.0
131 .as_ref()
132 .ok_or(PyValueError::new_err("Buffer already disposed"))
133 }
134}
135
136impl<T: Element> Drop for PyBufferWrapper<T> {
137 fn drop(&mut self) {
138 let is_initialized = unsafe { ffi::Py_IsInitialized() };
143 if let Some(val) = self.0.take() {
144 if is_initialized == 0 {
145 std::mem::forget(val);
146 } else {
147 std::mem::drop(val);
148 }
149 }
150 }
151}
152
/// A Python buffer-protocol export, tagged with the Rust element type it was extracted as.
///
/// Each variant owns its view through a `PyBufferWrapper`, so the exporter's memory stays
/// valid for as long as the variant, or an Arrow buffer the wrapper is moved into, is alive.
#[allow(missing_docs)]
#[derive(Debug)]
pub enum AnyBufferProtocol {
    UInt8(PyBufferWrapper<u8>),
    UInt16(PyBufferWrapper<u16>),
    UInt32(PyBufferWrapper<u32>),
    UInt64(PyBufferWrapper<u64>),
    Int8(PyBufferWrapper<i8>),
    Int16(PyBufferWrapper<i16>),
    Int32(PyBufferWrapper<i32>),
    Int64(PyBufferWrapper<i64>),
    Float32(PyBufferWrapper<f32>),
    Float64(PyBufferWrapper<f64>),
}
168
169impl<'py> FromPyObject<'py> for AnyBufferProtocol {
170 fn extract_bound(ob: &Bound<'py, PyAny>) -> PyResult<Self> {
171 if let Ok(buf) = ob.extract::<PyBuffer<u8>>() {
172 Ok(Self::UInt8(PyBufferWrapper(Some(buf))))
173 } else if let Ok(buf) = ob.extract::<PyBuffer<u16>>() {
174 Ok(Self::UInt16(PyBufferWrapper(Some(buf))))
175 } else if let Ok(buf) = ob.extract::<PyBuffer<u32>>() {
176 Ok(Self::UInt32(PyBufferWrapper(Some(buf))))
177 } else if let Ok(buf) = ob.extract::<PyBuffer<u64>>() {
178 Ok(Self::UInt64(PyBufferWrapper(Some(buf))))
179 } else if let Ok(buf) = ob.extract::<PyBuffer<i8>>() {
180 Ok(Self::Int8(PyBufferWrapper(Some(buf))))
181 } else if let Ok(buf) = ob.extract::<PyBuffer<i16>>() {
182 Ok(Self::Int16(PyBufferWrapper(Some(buf))))
183 } else if let Ok(buf) = ob.extract::<PyBuffer<i32>>() {
184 Ok(Self::Int32(PyBufferWrapper(Some(buf))))
185 } else if let Ok(buf) = ob.extract::<PyBuffer<i64>>() {
186 Ok(Self::Int64(PyBufferWrapper(Some(buf))))
187 } else if let Ok(buf) = ob.extract::<PyBuffer<f32>>() {
188 Ok(Self::Float32(PyBufferWrapper(Some(buf))))
189 } else if let Ok(buf) = ob.extract::<PyBuffer<f64>>() {
190 Ok(Self::Float64(PyBufferWrapper(Some(buf))))
191 } else {
192 Err(PyValueError::new_err("Not a buffer protocol object"))
193 }
194 }
195}
196
197impl AnyBufferProtocol {
198 fn buf_ptr(&self) -> PyResult<*mut raw::c_void> {
199 let out = match self {
200 Self::UInt8(buf) => buf.inner()?.buf_ptr(),
201 Self::UInt16(buf) => buf.inner()?.buf_ptr(),
202 Self::UInt32(buf) => buf.inner()?.buf_ptr(),
203 Self::UInt64(buf) => buf.inner()?.buf_ptr(),
204 Self::Int8(buf) => buf.inner()?.buf_ptr(),
205 Self::Int16(buf) => buf.inner()?.buf_ptr(),
206 Self::Int32(buf) => buf.inner()?.buf_ptr(),
207 Self::Int64(buf) => buf.inner()?.buf_ptr(),
208 Self::Float32(buf) => buf.inner()?.buf_ptr(),
209 Self::Float64(buf) => buf.inner()?.buf_ptr(),
210 };
211 Ok(out)
212 }
213
214 #[allow(dead_code)]
215 fn dimensions(&self) -> PyResult<usize> {
216 let out = match self {
217 Self::UInt8(buf) => buf.inner()?.dimensions(),
218 Self::UInt16(buf) => buf.inner()?.dimensions(),
219 Self::UInt32(buf) => buf.inner()?.dimensions(),
220 Self::UInt64(buf) => buf.inner()?.dimensions(),
221 Self::Int8(buf) => buf.inner()?.dimensions(),
222 Self::Int16(buf) => buf.inner()?.dimensions(),
223 Self::Int32(buf) => buf.inner()?.dimensions(),
224 Self::Int64(buf) => buf.inner()?.dimensions(),
225 Self::Float32(buf) => buf.inner()?.dimensions(),
226 Self::Float64(buf) => buf.inner()?.dimensions(),
227 };
228 Ok(out)
229 }
230
231 fn format(&self) -> PyResult<&CStr> {
232 let out = match self {
233 Self::UInt8(buf) => buf.inner()?.format(),
234 Self::UInt16(buf) => buf.inner()?.format(),
235 Self::UInt32(buf) => buf.inner()?.format(),
236 Self::UInt64(buf) => buf.inner()?.format(),
237 Self::Int8(buf) => buf.inner()?.format(),
238 Self::Int16(buf) => buf.inner()?.format(),
239 Self::Int32(buf) => buf.inner()?.format(),
240 Self::Int64(buf) => buf.inner()?.format(),
241 Self::Float32(buf) => buf.inner()?.format(),
242 Self::Float64(buf) => buf.inner()?.format(),
243 };
244 Ok(out)
245 }
246
247 pub fn into_arrow_array(self) -> PyArrowResult<ArrayRef> {
267 self.validate_buffer()?;
268
269 let shape = self.shape()?.to_vec();
270
271 if shape.len() == 1 {
273 self.into_arrow_values()
274 } else {
275 assert!(shape.len() > 1, "shape cannot be 0");
276
277 let mut values = self.into_arrow_values()?;
278
279 for size in shape[1..].iter().rev() {
280 let field = Arc::new(Field::new("item", values.data_type().clone(), false));
281 let x = FixedSizeListArray::new(field, (*size).try_into().unwrap(), values, None);
282 values = Arc::new(x);
283 }
284
285 Ok(values)
286 }
287 }
288
289 fn into_arrow_values(self) -> PyArrowResult<ArrayRef> {
294 let len = self.item_count()?;
295 let len_bytes = self.len_bytes()?;
296 let ptr = NonNull::new(self.buf_ptr()? as _)
297 .ok_or(PyValueError::new_err("Expected buffer ptr to be non null"))?;
298 let element_type = ElementType::from_format(self.format()?);
299
300 match self {
315 Self::UInt8(buf) => match element_type {
316 ElementType::Bool => {
317 let slice = NonNull::slice_from_raw_parts(ptr, len);
318 let slice = unsafe { slice.as_ref() };
319 let mut builder = BooleanBuilder::with_capacity(len);
320 for val in slice {
321 builder.append_value(*val > 0);
322 }
323 Ok(Arc::new(builder.finish()))
324 }
325 ElementType::UnsignedInteger { bytes } => {
326 if bytes != 1 {
327 return Err(PyValueError::new_err(format!(
328 "Expected 1 byte element type, got {}",
329 bytes
330 ))
331 .into());
332 }
333
334 let owner = Arc::new(buf);
335 let buffer = unsafe { Buffer::from_custom_allocation(ptr, len_bytes, owner) };
336 Ok(Arc::new(UInt8Array::new(
337 ScalarBuffer::new(buffer, 0, len),
338 None,
339 )))
340 }
341 _ => Err(PyValueError::new_err(format!(
342 "Unexpected element type {:?}",
343 element_type
344 ))
345 .into()),
346 },
347 Self::UInt16(buf) => {
348 let owner = Arc::new(buf);
349 let buffer = unsafe { Buffer::from_custom_allocation(ptr, len_bytes, owner) };
350 Ok(Arc::new(UInt16Array::new(
351 ScalarBuffer::new(buffer, 0, len),
352 None,
353 )))
354 }
355 Self::UInt32(buf) => {
356 let owner = Arc::new(buf);
357 let buffer = unsafe { Buffer::from_custom_allocation(ptr, len_bytes, owner) };
358 Ok(Arc::new(UInt32Array::new(
359 ScalarBuffer::new(buffer, 0, len),
360 None,
361 )))
362 }
363 Self::UInt64(buf) => {
364 let owner = Arc::new(buf);
365 let buffer = unsafe { Buffer::from_custom_allocation(ptr, len_bytes, owner) };
366 Ok(Arc::new(UInt64Array::new(
367 ScalarBuffer::new(buffer, 0, len),
368 None,
369 )))
370 }
371
372 Self::Int8(buf) => {
373 let owner = Arc::new(buf);
374 let buffer = unsafe { Buffer::from_custom_allocation(ptr, len_bytes, owner) };
375 Ok(Arc::new(Int8Array::new(
376 ScalarBuffer::new(buffer, 0, len),
377 None,
378 )))
379 }
380 Self::Int16(buf) => {
381 let owner = Arc::new(buf);
382 let buffer = unsafe { Buffer::from_custom_allocation(ptr, len_bytes, owner) };
383 Ok(Arc::new(Int16Array::new(
384 ScalarBuffer::new(buffer, 0, len),
385 None,
386 )))
387 }
388 Self::Int32(buf) => {
389 let owner = Arc::new(buf);
390 let buffer = unsafe { Buffer::from_custom_allocation(ptr, len_bytes, owner) };
391 Ok(Arc::new(Int32Array::new(
392 ScalarBuffer::new(buffer, 0, len),
393 None,
394 )))
395 }
396 Self::Int64(buf) => {
397 let owner = Arc::new(buf);
398 let buffer = unsafe { Buffer::from_custom_allocation(ptr, len_bytes, owner) };
399 Ok(Arc::new(Int64Array::new(
400 ScalarBuffer::new(buffer, 0, len),
401 None,
402 )))
403 }
404 Self::Float32(buf) => {
405 let owner = Arc::new(buf);
406 let buffer = unsafe { Buffer::from_custom_allocation(ptr, len_bytes, owner) };
407 Ok(Arc::new(Float32Array::new(
408 ScalarBuffer::new(buffer, 0, len),
409 None,
410 )))
411 }
412 Self::Float64(buf) => {
413 let owner = Arc::new(buf);
414 let buffer = unsafe { Buffer::from_custom_allocation(ptr, len_bytes, owner) };
415 Ok(Arc::new(Float64Array::new(
416 ScalarBuffer::new(buffer, 0, len),
417 None,
418 )))
419 }
420 }
421 }
422
423 pub fn into_arrow_buffer(self) -> PyArrowResult<Buffer> {
425 let len_bytes = self.len_bytes()?;
426 let ptr = NonNull::new(self.buf_ptr()? as _)
427 .ok_or(PyValueError::new_err("Expected buffer ptr to be non null"))?;
428
429 let buffer = match self {
430 Self::UInt8(buf) => {
431 let owner = Arc::new(buf);
432 unsafe { Buffer::from_custom_allocation(ptr, len_bytes, owner) }
433 }
434 Self::UInt16(buf) => {
435 let owner = Arc::new(buf);
436 unsafe { Buffer::from_custom_allocation(ptr, len_bytes, owner) }
437 }
438 Self::UInt32(buf) => {
439 let owner = Arc::new(buf);
440 unsafe { Buffer::from_custom_allocation(ptr, len_bytes, owner) }
441 }
442 Self::UInt64(buf) => {
443 let owner = Arc::new(buf);
444 unsafe { Buffer::from_custom_allocation(ptr, len_bytes, owner) }
445 }
446 Self::Int8(buf) => {
447 let owner = Arc::new(buf);
448 unsafe { Buffer::from_custom_allocation(ptr, len_bytes, owner) }
449 }
450 Self::Int16(buf) => {
451 let owner = Arc::new(buf);
452 unsafe { Buffer::from_custom_allocation(ptr, len_bytes, owner) }
453 }
454 Self::Int32(buf) => {
455 let owner = Arc::new(buf);
456 unsafe { Buffer::from_custom_allocation(ptr, len_bytes, owner) }
457 }
458 Self::Int64(buf) => {
459 let owner = Arc::new(buf);
460 unsafe { Buffer::from_custom_allocation(ptr, len_bytes, owner) }
461 }
462 Self::Float32(buf) => {
463 let owner = Arc::new(buf);
464 unsafe { Buffer::from_custom_allocation(ptr, len_bytes, owner) }
465 }
466 Self::Float64(buf) => {
467 let owner = Arc::new(buf);
468 unsafe { Buffer::from_custom_allocation(ptr, len_bytes, owner) }
469 }
470 };
471 Ok(buffer)
472 }
473
474 fn item_count(&self) -> PyResult<usize> {
475 let out = match self {
476 Self::UInt8(buf) => buf.inner()?.item_count(),
477 Self::UInt16(buf) => buf.inner()?.item_count(),
478 Self::UInt32(buf) => buf.inner()?.item_count(),
479 Self::UInt64(buf) => buf.inner()?.item_count(),
480 Self::Int8(buf) => buf.inner()?.item_count(),
481 Self::Int16(buf) => buf.inner()?.item_count(),
482 Self::Int32(buf) => buf.inner()?.item_count(),
483 Self::Int64(buf) => buf.inner()?.item_count(),
484 Self::Float32(buf) => buf.inner()?.item_count(),
485 Self::Float64(buf) => buf.inner()?.item_count(),
486 };
487 Ok(out)
488 }
489
490 fn is_c_contiguous(&self) -> PyResult<bool> {
491 let out = match self {
492 Self::UInt8(buf) => buf.inner()?.is_c_contiguous(),
493 Self::UInt16(buf) => buf.inner()?.is_c_contiguous(),
494 Self::UInt32(buf) => buf.inner()?.is_c_contiguous(),
495 Self::UInt64(buf) => buf.inner()?.is_c_contiguous(),
496 Self::Int8(buf) => buf.inner()?.is_c_contiguous(),
497 Self::Int16(buf) => buf.inner()?.is_c_contiguous(),
498 Self::Int32(buf) => buf.inner()?.is_c_contiguous(),
499 Self::Int64(buf) => buf.inner()?.is_c_contiguous(),
500 Self::Float32(buf) => buf.inner()?.is_c_contiguous(),
501 Self::Float64(buf) => buf.inner()?.is_c_contiguous(),
502 };
503 Ok(out)
504 }
505
506 fn len_bytes(&self) -> PyResult<usize> {
507 let out = match self {
508 Self::UInt8(buf) => buf.inner()?.len_bytes(),
509 Self::UInt16(buf) => buf.inner()?.len_bytes(),
510 Self::UInt32(buf) => buf.inner()?.len_bytes(),
511 Self::UInt64(buf) => buf.inner()?.len_bytes(),
512 Self::Int8(buf) => buf.inner()?.len_bytes(),
513 Self::Int16(buf) => buf.inner()?.len_bytes(),
514 Self::Int32(buf) => buf.inner()?.len_bytes(),
515 Self::Int64(buf) => buf.inner()?.len_bytes(),
516 Self::Float32(buf) => buf.inner()?.len_bytes(),
517 Self::Float64(buf) => buf.inner()?.len_bytes(),
518 };
519 Ok(out)
520 }
521
522 fn shape(&self) -> PyResult<&[usize]> {
523 let out = match self {
524 Self::UInt8(buf) => buf.inner()?.shape(),
525 Self::UInt16(buf) => buf.inner()?.shape(),
526 Self::UInt32(buf) => buf.inner()?.shape(),
527 Self::UInt64(buf) => buf.inner()?.shape(),
528 Self::Int8(buf) => buf.inner()?.shape(),
529 Self::Int16(buf) => buf.inner()?.shape(),
530 Self::Int32(buf) => buf.inner()?.shape(),
531 Self::Int64(buf) => buf.inner()?.shape(),
532 Self::Float32(buf) => buf.inner()?.shape(),
533 Self::Float64(buf) => buf.inner()?.shape(),
534 };
535 Ok(out)
536 }
537
538 fn strides(&self) -> PyResult<&[isize]> {
539 let out = match self {
540 Self::UInt8(buf) => buf.inner()?.strides(),
541 Self::UInt16(buf) => buf.inner()?.strides(),
542 Self::UInt32(buf) => buf.inner()?.strides(),
543 Self::UInt64(buf) => buf.inner()?.strides(),
544 Self::Int8(buf) => buf.inner()?.strides(),
545 Self::Int16(buf) => buf.inner()?.strides(),
546 Self::Int32(buf) => buf.inner()?.strides(),
547 Self::Int64(buf) => buf.inner()?.strides(),
548 Self::Float32(buf) => buf.inner()?.strides(),
549 Self::Float64(buf) => buf.inner()?.strides(),
550 };
551 Ok(out)
552 }
553
554 fn validate_buffer(&self) -> PyArrowResult<()> {
555 if !self.is_c_contiguous()? {
556 return Err(PyValueError::new_err("Buffer is not C contiguous").into());
557 }
558
559 if self.shape()?.contains(&0) {
560 return Err(
561 PyValueError::new_err("0-length dimension not currently supported.").into(),
562 );
563 }
564
565 if self.strides()?.iter().any(|s| *s != 1) {
566 return Err(PyValueError::new_err(format!(
567 "strides other than 1 not supported, got: {:?} ",
568 self.strides()
569 ))
570 .into());
571 }
572
573 Ok(())
574 }
575}
576
577impl TryFrom<AnyBufferProtocol> for PyArray {
578 type Error = PyArrowError;
579
580 fn try_from(value: AnyBufferProtocol) -> Result<Self, Self::Error> {
581 let array = value.into_arrow_array()?;
582 Ok(Self::from_array_ref(array))
583 }
584}