1use std::ffi::CStr;
4use std::os::raw;
5use std::os::raw::c_int;
6use std::ptr::NonNull;
7use std::sync::Arc;
8
9use arrow_array::builder::BooleanBuilder;
10use arrow_array::{
11 ArrayRef, FixedSizeListArray, Float32Array, Float64Array, Int16Array, Int32Array, Int64Array,
12 Int8Array, UInt16Array, UInt32Array, UInt64Array, UInt8Array,
13};
14use arrow_buffer::{Buffer, ScalarBuffer};
15use arrow_schema::Field;
16use pyo3::buffer::{ElementType, PyBuffer};
17use pyo3::exceptions::PyValueError;
18use pyo3::ffi;
19use pyo3::prelude::*;
20use pyo3::types::PyBytes;
21
22use crate::error::{PyArrowError, PyArrowResult};
23use crate::PyArray;
24
25#[pyclass(module = "arro3.core._core", name = "Buffer", subclass, frozen)]
45pub struct PyArrowBuffer(Buffer);
46
47impl AsRef<Buffer> for PyArrowBuffer {
48 fn as_ref(&self) -> &Buffer {
49 &self.0
50 }
51}
52
53impl AsRef<[u8]> for PyArrowBuffer {
54 fn as_ref(&self) -> &[u8] {
55 self.0.as_ref()
56 }
57}
58
59impl PyArrowBuffer {
60 pub fn new(buffer: Buffer) -> Self {
62 Self(buffer)
63 }
64
65 pub fn into_inner(self) -> Buffer {
67 self.0
68 }
69}
70
71#[pymethods]
72impl PyArrowBuffer {
73 #[new]
75 fn py_new(buf: PyArrowBuffer) -> Self {
76 buf
77 }
78
79 fn to_bytes<'py>(&'py self, py: Python<'py>) -> Bound<'py, PyBytes> {
80 PyBytes::new(py, &self.0)
81 }
82
83 fn __len__(&self) -> usize {
84 self.0.len()
85 }
86
87 unsafe fn __getbuffer__(
90 slf: PyRef<Self>,
91 view: *mut ffi::Py_buffer,
92 flags: c_int,
93 ) -> PyResult<()> {
94 let bytes = slf.0.as_slice();
95 let ret = ffi::PyBuffer_FillInfo(
96 view,
97 slf.as_ptr() as *mut _,
98 bytes.as_ptr() as *mut _,
99 bytes.len().try_into().unwrap(),
100 1, flags,
102 );
103 if ret == -1 {
104 return Err(PyErr::fetch(slf.py()));
105 }
106 Ok(())
107 }
108
109 unsafe fn __releasebuffer__(&self, _view: *mut ffi::Py_buffer) {}
110}
111
112impl<'py> FromPyObject<'_, 'py> for PyArrowBuffer {
113 type Error = PyErr;
114
115 fn extract(obj: Borrowed<'_, 'py, PyAny>) -> Result<Self, Self::Error> {
116 let buffer = obj.extract::<AnyBufferProtocol>()?;
117 if !matches!(buffer, AnyBufferProtocol::UInt8(_)) {
118 return Err(PyValueError::new_err("Expected u8 buffer protocol object"));
119 }
120
121 Ok(Self(buffer.into_arrow_buffer()?))
122 }
123}
124
125#[allow(missing_docs)]
127#[derive(Debug)]
128pub enum AnyBufferProtocol {
129 UInt8(PyBuffer<u8>),
130 UInt16(PyBuffer<u16>),
131 UInt32(PyBuffer<u32>),
132 UInt64(PyBuffer<u64>),
133 Int8(PyBuffer<i8>),
134 Int16(PyBuffer<i16>),
135 Int32(PyBuffer<i32>),
136 Int64(PyBuffer<i64>),
137 Float32(PyBuffer<f32>),
138 Float64(PyBuffer<f64>),
139}
140
141impl<'py> FromPyObject<'_, 'py> for AnyBufferProtocol {
142 type Error = PyErr;
143
144 fn extract(obj: Borrowed<'_, 'py, PyAny>) -> Result<Self, Self::Error> {
145 if let Ok(buf) = obj.extract() {
146 Ok(Self::UInt8(buf))
147 } else if let Ok(buf) = obj.extract() {
148 Ok(Self::UInt16(buf))
149 } else if let Ok(buf) = obj.extract() {
150 Ok(Self::UInt32(buf))
151 } else if let Ok(buf) = obj.extract() {
152 Ok(Self::UInt64(buf))
153 } else if let Ok(buf) = obj.extract() {
154 Ok(Self::Int8(buf))
155 } else if let Ok(buf) = obj.extract() {
156 Ok(Self::Int16(buf))
157 } else if let Ok(buf) = obj.extract() {
158 Ok(Self::Int32(buf))
159 } else if let Ok(buf) = obj.extract() {
160 Ok(Self::Int64(buf))
161 } else if let Ok(buf) = obj.extract() {
162 Ok(Self::Float32(buf))
163 } else if let Ok(buf) = obj.extract() {
164 Ok(Self::Float64(buf))
165 } else {
166 Err(PyValueError::new_err("Not a buffer protocol object"))
167 }
168 }
169}
170
171impl AnyBufferProtocol {
172 fn buf_ptr(&self) -> PyResult<*mut raw::c_void> {
173 let out = match self {
174 Self::UInt8(buf) => buf.buf_ptr(),
175 Self::UInt16(buf) => buf.buf_ptr(),
176 Self::UInt32(buf) => buf.buf_ptr(),
177 Self::UInt64(buf) => buf.buf_ptr(),
178 Self::Int8(buf) => buf.buf_ptr(),
179 Self::Int16(buf) => buf.buf_ptr(),
180 Self::Int32(buf) => buf.buf_ptr(),
181 Self::Int64(buf) => buf.buf_ptr(),
182 Self::Float32(buf) => buf.buf_ptr(),
183 Self::Float64(buf) => buf.buf_ptr(),
184 };
185 Ok(out)
186 }
187
188 #[allow(dead_code)]
189 fn dimensions(&self) -> PyResult<usize> {
190 let out = match self {
191 Self::UInt8(buf) => buf.dimensions(),
192 Self::UInt16(buf) => buf.dimensions(),
193 Self::UInt32(buf) => buf.dimensions(),
194 Self::UInt64(buf) => buf.dimensions(),
195 Self::Int8(buf) => buf.dimensions(),
196 Self::Int16(buf) => buf.dimensions(),
197 Self::Int32(buf) => buf.dimensions(),
198 Self::Int64(buf) => buf.dimensions(),
199 Self::Float32(buf) => buf.dimensions(),
200 Self::Float64(buf) => buf.dimensions(),
201 };
202 Ok(out)
203 }
204
205 fn format(&self) -> PyResult<&CStr> {
206 let out = match self {
207 Self::UInt8(buf) => buf.format(),
208 Self::UInt16(buf) => buf.format(),
209 Self::UInt32(buf) => buf.format(),
210 Self::UInt64(buf) => buf.format(),
211 Self::Int8(buf) => buf.format(),
212 Self::Int16(buf) => buf.format(),
213 Self::Int32(buf) => buf.format(),
214 Self::Int64(buf) => buf.format(),
215 Self::Float32(buf) => buf.format(),
216 Self::Float64(buf) => buf.format(),
217 };
218 Ok(out)
219 }
220
221 pub fn into_arrow_array(self) -> PyArrowResult<ArrayRef> {
241 self.validate_buffer()?;
242
243 let shape = self.shape()?.to_vec();
244
245 if shape.len() == 1 {
247 self.into_arrow_values()
248 } else {
249 assert!(shape.len() > 1, "shape cannot be 0");
250
251 let mut values = self.into_arrow_values()?;
252
253 for size in shape[1..].iter().rev() {
254 let field = Arc::new(Field::new("item", values.data_type().clone(), false));
255 let x = FixedSizeListArray::new(field, (*size).try_into().unwrap(), values, None);
256 values = Arc::new(x);
257 }
258
259 Ok(values)
260 }
261 }
262
263 fn into_arrow_values(self) -> PyArrowResult<ArrayRef> {
268 let len = self.item_count()?;
269 let len_bytes = self.len_bytes()?;
270 let ptr = NonNull::new(self.buf_ptr()? as _)
271 .ok_or(PyValueError::new_err("Expected buffer ptr to be non null"))?;
272 let element_type = ElementType::from_format(self.format()?);
273
274 match self {
289 Self::UInt8(buf) => match element_type {
290 ElementType::Bool => {
291 let slice = NonNull::slice_from_raw_parts(ptr, len);
292 let slice = unsafe { slice.as_ref() };
293 let mut builder = BooleanBuilder::with_capacity(len);
294 for val in slice {
295 builder.append_value(*val > 0);
296 }
297 Ok(Arc::new(builder.finish()))
298 }
299 ElementType::UnsignedInteger { bytes } => {
300 if bytes != 1 {
301 return Err(PyValueError::new_err(format!(
302 "Expected 1 byte element type, got {}",
303 bytes
304 ))
305 .into());
306 }
307
308 let owner = Arc::new(buf);
309 let buffer = unsafe { Buffer::from_custom_allocation(ptr, len_bytes, owner) };
310 Ok(Arc::new(UInt8Array::new(
311 ScalarBuffer::new(buffer, 0, len),
312 None,
313 )))
314 }
315 _ => Err(PyValueError::new_err(format!(
316 "Unexpected element type {:?}",
317 element_type
318 ))
319 .into()),
320 },
321 Self::UInt16(buf) => {
322 let owner = Arc::new(buf);
323 let buffer = unsafe { Buffer::from_custom_allocation(ptr, len_bytes, owner) };
324 Ok(Arc::new(UInt16Array::new(
325 ScalarBuffer::new(buffer, 0, len),
326 None,
327 )))
328 }
329 Self::UInt32(buf) => {
330 let owner = Arc::new(buf);
331 let buffer = unsafe { Buffer::from_custom_allocation(ptr, len_bytes, owner) };
332 Ok(Arc::new(UInt32Array::new(
333 ScalarBuffer::new(buffer, 0, len),
334 None,
335 )))
336 }
337 Self::UInt64(buf) => {
338 let owner = Arc::new(buf);
339 let buffer = unsafe { Buffer::from_custom_allocation(ptr, len_bytes, owner) };
340 Ok(Arc::new(UInt64Array::new(
341 ScalarBuffer::new(buffer, 0, len),
342 None,
343 )))
344 }
345
346 Self::Int8(buf) => {
347 let owner = Arc::new(buf);
348 let buffer = unsafe { Buffer::from_custom_allocation(ptr, len_bytes, owner) };
349 Ok(Arc::new(Int8Array::new(
350 ScalarBuffer::new(buffer, 0, len),
351 None,
352 )))
353 }
354 Self::Int16(buf) => {
355 let owner = Arc::new(buf);
356 let buffer = unsafe { Buffer::from_custom_allocation(ptr, len_bytes, owner) };
357 Ok(Arc::new(Int16Array::new(
358 ScalarBuffer::new(buffer, 0, len),
359 None,
360 )))
361 }
362 Self::Int32(buf) => {
363 let owner = Arc::new(buf);
364 let buffer = unsafe { Buffer::from_custom_allocation(ptr, len_bytes, owner) };
365 Ok(Arc::new(Int32Array::new(
366 ScalarBuffer::new(buffer, 0, len),
367 None,
368 )))
369 }
370 Self::Int64(buf) => {
371 let owner = Arc::new(buf);
372 let buffer = unsafe { Buffer::from_custom_allocation(ptr, len_bytes, owner) };
373 Ok(Arc::new(Int64Array::new(
374 ScalarBuffer::new(buffer, 0, len),
375 None,
376 )))
377 }
378 Self::Float32(buf) => {
379 let owner = Arc::new(buf);
380 let buffer = unsafe { Buffer::from_custom_allocation(ptr, len_bytes, owner) };
381 Ok(Arc::new(Float32Array::new(
382 ScalarBuffer::new(buffer, 0, len),
383 None,
384 )))
385 }
386 Self::Float64(buf) => {
387 let owner = Arc::new(buf);
388 let buffer = unsafe { Buffer::from_custom_allocation(ptr, len_bytes, owner) };
389 Ok(Arc::new(Float64Array::new(
390 ScalarBuffer::new(buffer, 0, len),
391 None,
392 )))
393 }
394 }
395 }
396
397 pub fn into_arrow_buffer(self) -> PyArrowResult<Buffer> {
399 let len_bytes = self.len_bytes()?;
400 let ptr = NonNull::new(self.buf_ptr()? as _)
401 .ok_or(PyValueError::new_err("Expected buffer ptr to be non null"))?;
402
403 let buffer = match self {
404 Self::UInt8(buf) => {
405 let owner = Arc::new(buf);
406 unsafe { Buffer::from_custom_allocation(ptr, len_bytes, owner) }
407 }
408 Self::UInt16(buf) => {
409 let owner = Arc::new(buf);
410 unsafe { Buffer::from_custom_allocation(ptr, len_bytes, owner) }
411 }
412 Self::UInt32(buf) => {
413 let owner = Arc::new(buf);
414 unsafe { Buffer::from_custom_allocation(ptr, len_bytes, owner) }
415 }
416 Self::UInt64(buf) => {
417 let owner = Arc::new(buf);
418 unsafe { Buffer::from_custom_allocation(ptr, len_bytes, owner) }
419 }
420 Self::Int8(buf) => {
421 let owner = Arc::new(buf);
422 unsafe { Buffer::from_custom_allocation(ptr, len_bytes, owner) }
423 }
424 Self::Int16(buf) => {
425 let owner = Arc::new(buf);
426 unsafe { Buffer::from_custom_allocation(ptr, len_bytes, owner) }
427 }
428 Self::Int32(buf) => {
429 let owner = Arc::new(buf);
430 unsafe { Buffer::from_custom_allocation(ptr, len_bytes, owner) }
431 }
432 Self::Int64(buf) => {
433 let owner = Arc::new(buf);
434 unsafe { Buffer::from_custom_allocation(ptr, len_bytes, owner) }
435 }
436 Self::Float32(buf) => {
437 let owner = Arc::new(buf);
438 unsafe { Buffer::from_custom_allocation(ptr, len_bytes, owner) }
439 }
440 Self::Float64(buf) => {
441 let owner = Arc::new(buf);
442 unsafe { Buffer::from_custom_allocation(ptr, len_bytes, owner) }
443 }
444 };
445 Ok(buffer)
446 }
447
448 fn item_count(&self) -> PyResult<usize> {
449 let out = match self {
450 Self::UInt8(buf) => buf.item_count(),
451 Self::UInt16(buf) => buf.item_count(),
452 Self::UInt32(buf) => buf.item_count(),
453 Self::UInt64(buf) => buf.item_count(),
454 Self::Int8(buf) => buf.item_count(),
455 Self::Int16(buf) => buf.item_count(),
456 Self::Int32(buf) => buf.item_count(),
457 Self::Int64(buf) => buf.item_count(),
458 Self::Float32(buf) => buf.item_count(),
459 Self::Float64(buf) => buf.item_count(),
460 };
461 Ok(out)
462 }
463
464 fn is_c_contiguous(&self) -> PyResult<bool> {
465 let out = match self {
466 Self::UInt8(buf) => buf.is_c_contiguous(),
467 Self::UInt16(buf) => buf.is_c_contiguous(),
468 Self::UInt32(buf) => buf.is_c_contiguous(),
469 Self::UInt64(buf) => buf.is_c_contiguous(),
470 Self::Int8(buf) => buf.is_c_contiguous(),
471 Self::Int16(buf) => buf.is_c_contiguous(),
472 Self::Int32(buf) => buf.is_c_contiguous(),
473 Self::Int64(buf) => buf.is_c_contiguous(),
474 Self::Float32(buf) => buf.is_c_contiguous(),
475 Self::Float64(buf) => buf.is_c_contiguous(),
476 };
477 Ok(out)
478 }
479
480 fn len_bytes(&self) -> PyResult<usize> {
481 let out = match self {
482 Self::UInt8(buf) => buf.len_bytes(),
483 Self::UInt16(buf) => buf.len_bytes(),
484 Self::UInt32(buf) => buf.len_bytes(),
485 Self::UInt64(buf) => buf.len_bytes(),
486 Self::Int8(buf) => buf.len_bytes(),
487 Self::Int16(buf) => buf.len_bytes(),
488 Self::Int32(buf) => buf.len_bytes(),
489 Self::Int64(buf) => buf.len_bytes(),
490 Self::Float32(buf) => buf.len_bytes(),
491 Self::Float64(buf) => buf.len_bytes(),
492 };
493 Ok(out)
494 }
495
496 fn shape(&self) -> PyResult<&[usize]> {
497 let out = match self {
498 Self::UInt8(buf) => buf.shape(),
499 Self::UInt16(buf) => buf.shape(),
500 Self::UInt32(buf) => buf.shape(),
501 Self::UInt64(buf) => buf.shape(),
502 Self::Int8(buf) => buf.shape(),
503 Self::Int16(buf) => buf.shape(),
504 Self::Int32(buf) => buf.shape(),
505 Self::Int64(buf) => buf.shape(),
506 Self::Float32(buf) => buf.shape(),
507 Self::Float64(buf) => buf.shape(),
508 };
509 Ok(out)
510 }
511
512 fn validate_buffer(&self) -> PyArrowResult<()> {
513 if !self.is_c_contiguous()? {
514 return Err(PyValueError::new_err("Buffer is not C contiguous").into());
515 }
516
517 if self.shape()?.contains(&0) {
518 return Err(
519 PyValueError::new_err("0-length dimension not currently supported.").into(),
520 );
521 }
522
523 Ok(())
527 }
528}
529
530impl TryFrom<AnyBufferProtocol> for PyArray {
531 type Error = PyArrowError;
532
533 fn try_from(value: AnyBufferProtocol) -> Result<Self, Self::Error> {
534 let array = value.into_arrow_array()?;
535 Ok(Self::from_array_ref(array))
536 }
537}