1use crate::dimension::Dimension;
4use crate::dtype::{DType, Element};
5
6use crate::array::arc::ArcArray;
7use crate::array::owned::Array;
8use crate::array::view::ArrayView;
9
/// Snapshot of an array's raw memory layout, as returned by
/// `AsRawBuffer::buffer_descriptor`.
///
/// `shape` is borrowed from the source array for `'a`, which keeps that array
/// borrowed while the descriptor is alive. A copy of `data` taken out of the
/// descriptor carries no such guarantee.
///
/// NOTE(review): `&'a [usize]` and `Box<[isize]>` are fat pointers (pointer +
/// length) with no defined C ABI. `#[repr(C)]` fixes the field order, but this
/// struct is not safe to hand across an FFI boundary as-is. Confirm whether
/// C interop is intended before relying on that.
#[repr(C)]
#[derive(Debug, Clone)]
pub struct BufferDescriptor<'a> {
    /// Pointer to the first element, reinterpreted as bytes.
    pub data: *const u8,
    /// Number of axes; always equal to `shape.len()`.
    pub ndim: usize,
    /// Extent of each axis, borrowed from the source array.
    pub shape: &'a [usize],
    /// Per-axis stride measured in bytes, one entry per axis.
    pub strides_bytes: Box<[isize]>,
    /// Runtime element type.
    pub dtype: DType,
    /// Size of a single element in bytes, as reported by `dtype.size_of()`.
    pub itemsize: usize,
    /// Whether the data is contiguous in row-major (C) order.
    pub c_contiguous: bool,
    /// Whether the data is contiguous in column-major (Fortran) order.
    pub f_contiguous: bool,
}
44
45pub trait AsRawBuffer {
50 fn raw_ptr(&self) -> *const u8;
52
53 fn raw_shape(&self) -> &[usize];
55
56 fn raw_strides_bytes(&self) -> Vec<isize>;
58
59 fn raw_dtype(&self) -> DType;
61
62 fn is_c_contiguous(&self) -> bool;
64
65 fn is_f_contiguous(&self) -> bool;
67
68 fn buffer_descriptor(&self) -> BufferDescriptor<'_> {
72 let dtype = self.raw_dtype();
73 let shape = self.raw_shape();
74 BufferDescriptor {
75 data: self.raw_ptr(),
76 ndim: shape.len(),
77 shape,
78 strides_bytes: self.raw_strides_bytes().into_boxed_slice(),
79 dtype,
80 itemsize: dtype.size_of(),
81 c_contiguous: self.is_c_contiguous(),
82 f_contiguous: self.is_f_contiguous(),
83 }
84 }
85}
86
/// Generates an [`AsRawBuffer`] impl for an array container type.
///
/// Use `impl_as_raw_buffer!(Type<T, D>,)` for types without a lifetime, or
/// `impl_as_raw_buffer!(Type<'a, T, D>, 'a)` for borrowing types. The optional
/// lifetime after the comma is prepended to the impl's generic parameters.
///
/// The generated impl calls the container's `as_ptr`, `shape`, `strides` and
/// `layout` methods. `strides` is reported in elements and is rescaled to bytes
/// using the in-memory size of `T`.
macro_rules! impl_as_raw_buffer {
    ($ty:ty, $($lt:lifetime)?) => {
        impl<$($lt,)? T: Element, D: Dimension> AsRawBuffer for $ty {
            fn raw_ptr(&self) -> *const u8 {
                let first = self.as_ptr();
                first.cast::<u8>()
            }

            fn raw_shape(&self) -> &[usize] {
                self.shape()
            }

            fn raw_strides_bytes(&self) -> Vec<isize> {
                // Scale element strides by the Rust memory size of one `T`.
                let elem_bytes = std::mem::size_of::<T>() as isize;
                self.strides()
                    .iter()
                    .copied()
                    .map(|stride| stride * elem_bytes)
                    .collect()
            }

            fn raw_dtype(&self) -> DType {
                T::dtype()
            }

            fn is_c_contiguous(&self) -> bool {
                let layout = self.layout();
                layout.is_c_contiguous()
            }

            fn is_f_contiguous(&self) -> bool {
                let layout = self.layout();
                layout.is_f_contiguous()
            }
        }
    };
}
123
124impl_as_raw_buffer!(Array<T, D>,);
125impl_as_raw_buffer!(ArrayView<'a, T, D>, 'a);
126impl_as_raw_buffer!(ArcArray<T, D>,);
127
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dimension::Ix2;

    #[test]
    fn raw_buffer_array() {
        let arr = Array::<f64, Ix2>::from_vec(Ix2::new([2, 3]), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
            .unwrap();

        assert_eq!(arr.raw_shape(), &[2, 3]);
        assert_eq!(arr.raw_dtype(), DType::F64);
        assert!(arr.is_c_contiguous());
        assert!(!arr.raw_ptr().is_null());

        // Element strides [3, 1] scaled by size_of::<f64>() == 8.
        let strides = arr.raw_strides_bytes();
        assert_eq!(strides, vec![24, 8]);
    }

    #[test]
    fn raw_buffer_view() {
        let arr = Array::<f32, Ix2>::from_vec(Ix2::new([2, 2]), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let v = arr.view();

        assert_eq!(v.raw_dtype(), DType::F32);
        assert_eq!(v.raw_shape(), &[2, 2]);
        assert!(v.is_c_contiguous());

        // A full view shares the owner's buffer and byte strides.
        assert_eq!(v.raw_ptr(), arr.raw_ptr());
        assert_eq!(v.raw_strides_bytes(), vec![8, 4]);

        // The descriptor must agree with the individual trait queries.
        let d = v.buffer_descriptor();
        assert_eq!(&*d.strides_bytes, v.raw_strides_bytes().as_slice());
        assert_eq!(d.data, v.raw_ptr());
    }

    #[test]
    fn raw_buffer_arc() {
        let arr = Array::<i32, Ix2>::from_vec(Ix2::new([2, 2]), vec![1, 2, 3, 4]).unwrap();
        let arc = ArcArray::from_owned(arr);

        assert_eq!(arc.raw_dtype(), DType::I32);
        assert_eq!(arc.raw_shape(), &[2, 2]);
        // Element strides [2, 1] scaled by size_of::<i32>() == 4.
        assert_eq!(arc.raw_strides_bytes(), vec![8, 4]);
    }

    #[test]
    fn buffer_descriptor_aggregates_layout() {
        let arr = Array::<f64, Ix2>::from_vec(Ix2::new([2, 3]), vec![1.0; 6]).unwrap();

        let d = arr.buffer_descriptor();
        assert_eq!(d.ndim, 2);
        assert_eq!(d.shape, &[2, 3]);
        assert_eq!(&*d.strides_bytes, &[24, 8]);
        assert_eq!(d.dtype, DType::F64);
        assert_eq!(d.itemsize, 8);
        assert!(d.c_contiguous);
        assert!(!d.f_contiguous);
        assert!(!d.data.is_null());
        // The descriptor points at the same buffer the trait reports.
        assert_eq!(d.data, arr.raw_ptr());
    }

    /// Checks the descriptor for a 4-byte element type. This says nothing about
    /// the struct's in-memory (`repr(C)`) layout, so it is named accordingly.
    #[test]
    fn buffer_descriptor_u32_itemsize_and_strides() {
        let arr = Array::<u32, Ix2>::from_vec(Ix2::new([2, 2]), vec![0u32; 4]).unwrap();

        let d = arr.buffer_descriptor();
        assert_eq!(d.itemsize, 4);
        assert_eq!(d.dtype, DType::U32);
        assert_eq!(d.shape, &[2, 2]);
        assert_eq!(d.ndim, 2);
        assert_eq!(&*d.strides_bytes, &[8, 4]);
        assert!(d.c_contiguous);
    }
}