use crate::dimension::Dimension;
use crate::dtype::{DType, Element};
use crate::array::arc::ArcArray;
use crate::array::owned::Array;
use crate::array::view::ArrayView;
/// Type-erased summary of an array's raw memory layout, as produced by
/// [`AsRawBuffer::buffer_descriptor`].
///
/// NOTE(review): `#[repr(C)]` only pins field order and padding. `shape` (`&[usize]`)
/// and `strides_bytes` (`Box<[isize]>`) are Rust fat pointers (address + length) with
/// no stable C ABI, and `DType`'s representation is not visible from this file. Do not
/// pass this struct across an `extern "C"` boundary as-is. Hand C code
/// `shape.as_ptr()`, `strides_bytes.as_ptr()` and `ndim` instead.
///
/// All fields are public. `buffer_descriptor` fills them consistently (e.g.
/// `ndim == shape.len()`), but the type itself does not enforce those relations.
#[repr(C)]
#[derive(Debug, Clone)]
pub struct BufferDescriptor<'a> {
/// Pointer to the array's element data, viewed as bytes (the built-in impls use
/// `as_ptr()` cast to `u8`). Raw pointers are not lifetime-checked, so treat this as
/// valid only while the source array (borrowed for `'a`) is alive.
pub data: *const u8,
/// Number of dimensions; `buffer_descriptor` sets this to `shape.len()`.
pub ndim: usize,
/// Extent of each axis, borrowed from the source array.
pub shape: &'a [usize],
/// Per-axis strides in bytes (signed `isize`). Expected to have one entry per axis.
pub strides_bytes: Box<[isize]>,
/// Runtime element type of the buffer.
pub dtype: DType,
/// Size of one element in bytes; `buffer_descriptor` takes this from `dtype.size_of()`.
pub itemsize: usize,
/// Whether the buffer is C-contiguous (row-major).
pub c_contiguous: bool,
/// Whether the buffer is Fortran-contiguous (column-major).
pub f_contiguous: bool,
}
/// Read-only, type-erased access to the memory backing an n-dimensional array.
///
/// Implementors report the data pointer, shape, byte strides, element type and
/// contiguity flags. [`AsRawBuffer::buffer_descriptor`] bundles all of them into a
/// single [`BufferDescriptor`].
pub trait AsRawBuffer {
    /// Pointer to the element data, reinterpreted as bytes.
    fn raw_ptr(&self) -> *const u8;

    /// Extent of each axis.
    fn raw_shape(&self) -> &[usize];

    /// Per-axis strides measured in bytes.
    fn raw_strides_bytes(&self) -> Vec<isize>;

    /// Runtime element type.
    fn raw_dtype(&self) -> DType;

    /// Whether the buffer is C-contiguous (row-major).
    fn is_c_contiguous(&self) -> bool;

    /// Whether the buffer is Fortran-contiguous (column-major).
    fn is_f_contiguous(&self) -> bool;

    /// Gathers every raw-layout query into one [`BufferDescriptor`].
    ///
    /// `ndim` is the length of the shape and `itemsize` comes from the dtype's
    /// `size_of()`. The returned descriptor borrows `self` through its `shape` field.
    fn buffer_descriptor(&self) -> BufferDescriptor<'_> {
        // The queries run in the same order the fields were originally evaluated.
        let dtype = self.raw_dtype();
        let shape = self.raw_shape();
        let data = self.raw_ptr();
        let ndim = shape.len();
        let strides_bytes = self.raw_strides_bytes().into_boxed_slice();
        let itemsize = dtype.size_of();
        let c_contiguous = self.is_c_contiguous();
        let f_contiguous = self.is_f_contiguous();

        BufferDescriptor {
            data,
            ndim,
            shape,
            strides_bytes,
            dtype,
            itemsize,
            c_contiguous,
            f_contiguous,
        }
    }
}
/// Implements [`AsRawBuffer`] for an array container that is generic over
/// `T: Element` and `D: Dimension`.
///
/// Usage:
/// - `impl_as_raw_buffer!(Container<T, D>,)` for a container with no lifetime;
/// - `impl_as_raw_buffer!(View<'a, T, D>, 'a)` for a borrowing container. The trailing
///   lifetime is declared on the generated `impl`.
///
/// The container must provide `as_ptr`, `shape`, `strides` (element-unit strides)
/// and `layout` methods.
macro_rules! impl_as_raw_buffer {
    ($ty:ty, $($lt:lifetime)?) => {
        impl<$($lt,)? T, D> AsRawBuffer for $ty
        where
            T: Element,
            D: Dimension,
        {
            fn raw_dtype(&self) -> DType {
                T::dtype()
            }

            fn raw_ptr(&self) -> *const u8 {
                self.as_ptr().cast::<u8>()
            }

            fn raw_shape(&self) -> &[usize] {
                self.shape()
            }

            fn raw_strides_bytes(&self) -> Vec<isize> {
                // Element strides are scaled by the in-memory size of `T` to get bytes.
                let elem_bytes = std::mem::size_of::<T>() as isize;
                self.strides()
                    .iter()
                    .map(|&elem_stride| elem_stride * elem_bytes)
                    .collect::<Vec<isize>>()
            }

            fn is_c_contiguous(&self) -> bool {
                let layout = self.layout();
                layout.is_c_contiguous()
            }

            fn is_f_contiguous(&self) -> bool {
                let layout = self.layout();
                layout.is_f_contiguous()
            }
        }
    };
}
// Owned, shared and borrowed containers all expose the same raw layout API.
impl_as_raw_buffer! { Array<T, D>, }
impl_as_raw_buffer! { ArcArray<T, D>, }
impl_as_raw_buffer! { ArrayView<'a, T, D>, 'a }
#[cfg(test)]
mod tests {
    //! Checks that every `AsRawBuffer` impl reports layout consistent with a
    //! freshly built row-major array, and that `buffer_descriptor` aggregates it.
    use super::*;
    use crate::dimension::Ix2;

    #[test]
    fn raw_buffer_array() {
        let arr = Array::<f64, Ix2>::from_vec(Ix2::new([2, 3]), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
            .unwrap();
        assert_eq!(arr.raw_shape(), &[2, 3]);
        assert_eq!(arr.raw_dtype(), DType::F64);
        assert!(arr.is_c_contiguous());
        // A 2x3 row-major array is not also column-major contiguous.
        assert!(!arr.is_f_contiguous());
        assert!(!arr.raw_ptr().is_null());
        // Next row is 3 * 8 bytes away, next column 8 bytes.
        assert_eq!(arr.raw_strides_bytes(), vec![24, 8]);
    }

    #[test]
    fn raw_buffer_view() {
        let arr = Array::<f32, Ix2>::from_vec(Ix2::new([2, 2]), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let v = arr.view();
        assert_eq!(v.raw_dtype(), DType::F32);
        assert_eq!(v.raw_shape(), &[2, 2]);
        assert!(v.is_c_contiguous());
        // A full view aliases the owner's buffer and keeps its byte layout.
        assert_eq!(v.raw_ptr(), arr.raw_ptr());
        assert_eq!(v.raw_strides_bytes(), vec![8, 4]);
    }

    #[test]
    fn raw_buffer_arc() {
        let arr = Array::<i32, Ix2>::from_vec(Ix2::new([2, 2]), vec![1, 2, 3, 4]).unwrap();
        let arc = ArcArray::from_owned(arr);
        assert_eq!(arc.raw_dtype(), DType::I32);
        assert_eq!(arc.raw_shape(), &[2, 2]);
        // Converting to a shared array must not change the element layout.
        assert!(arc.is_c_contiguous());
        assert_eq!(arc.raw_strides_bytes(), vec![8, 4]);
    }

    #[test]
    fn buffer_descriptor_aggregates_layout() {
        let arr = Array::<f64, Ix2>::from_vec(Ix2::new([2, 3]), vec![1.0; 6]).unwrap();
        let d = arr.buffer_descriptor();
        assert_eq!(d.ndim, 2);
        assert_eq!(d.shape, &[2, 3]);
        assert_eq!(&*d.strides_bytes, &[24, 8]);
        assert_eq!(d.dtype, DType::F64);
        assert_eq!(d.itemsize, 8);
        assert!(d.c_contiguous);
        assert!(!d.f_contiguous);
        assert!(!d.data.is_null());
        assert_eq!(d.data, arr.raw_ptr());
        // Per-axis fields must agree with each other.
        assert_eq!(d.ndim, d.shape.len());
        assert_eq!(d.strides_bytes.len(), d.ndim);
    }

    /// Previously named `buffer_descriptor_repr_c`, but it never checked anything
    /// about `repr(C)`. What it actually verifies is the itemsize and layout reported
    /// for a 4-byte element type.
    #[test]
    fn buffer_descriptor_u32_layout() {
        let arr = Array::<u32, Ix2>::from_vec(Ix2::new([2, 2]), vec![0u32; 4]).unwrap();
        let d = arr.buffer_descriptor();
        assert_eq!(d.itemsize, 4);
        assert_eq!(d.dtype, DType::U32);
        assert_eq!(d.ndim, 2);
        assert_eq!(d.shape, &[2, 2]);
        assert_eq!(&*d.strides_bytes, &[8, 4]);
        assert!(d.c_contiguous);
        assert_eq!(d.data, arr.raw_ptr());
    }
}