1use crate::dimension::Dimension;
4use crate::dtype::{DType, Element};
5
6use crate::array::arc::ArcArray;
7use crate::array::owned::Array;
8use crate::array::view::ArrayView;
9
10pub trait AsRawBuffer {
15 fn raw_ptr(&self) -> *const u8;
17
18 fn raw_shape(&self) -> &[usize];
20
21 fn raw_strides_bytes(&self) -> Vec<isize>;
23
24 fn raw_dtype(&self) -> DType;
26
27 fn is_c_contiguous(&self) -> bool;
29
30 fn is_f_contiguous(&self) -> bool;
32}
33
/// Implements [`AsRawBuffer`] for an array type that is generic over an
/// element type `T: Element` and a dimension type `D: Dimension`.
///
/// The first argument is the concrete type, spelled with the generic names
/// `T` and `D`. The optional second argument is a lifetime parameter the type
/// is generic over. A trailing comma is accepted but not required, so both the
/// historical and the plain spellings work:
///
/// ```ignore
/// impl_as_raw_buffer!(Array<T, D>,);
/// impl_as_raw_buffer!(Array<T, D>);
/// impl_as_raw_buffer!(ArrayView<'a, T, D>, 'a);
/// ```
///
/// The generated impl calls methods named `as_ptr`, `shape`, `strides`
/// (element strides) and `layout` on the target type, so those must be
/// available on it.
macro_rules! impl_as_raw_buffer {
    ($ty:ty $(, $lt:lifetime)? $(,)?) => {
        impl<$($lt,)? T: Element, D: Dimension> AsRawBuffer for $ty {
            fn raw_ptr(&self) -> *const u8 {
                // Pure reinterpretation of the typed element pointer; no
                // offset is applied.
                self.as_ptr() as *const u8
            }

            fn raw_shape(&self) -> &[usize] {
                self.shape()
            }

            fn raw_strides_bytes(&self) -> Vec<isize> {
                // Element strides times the element size yield byte strides.
                // `size_of` never exceeds `isize::MAX`, so the cast is lossless.
                let itemsize = std::mem::size_of::<T>() as isize;
                self.strides().iter().map(|&s| s * itemsize).collect()
            }

            fn raw_dtype(&self) -> DType {
                T::dtype()
            }

            fn is_c_contiguous(&self) -> bool {
                self.layout().is_c_contiguous()
            }

            fn is_f_contiguous(&self) -> bool {
                self.layout().is_f_contiguous()
            }
        }
    };
}
70
71impl_as_raw_buffer!(Array<T, D>,);
72impl_as_raw_buffer!(ArrayView<'a, T, D>, 'a);
73impl_as_raw_buffer!(ArcArray<T, D>,);
74
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dimension::Ix2;

    /// Owned arrays: shape, dtype, contiguity flags and byte strides.
    #[test]
    fn raw_buffer_array() {
        let arr = Array::<f64, Ix2>::from_vec(Ix2::new([2, 3]), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
            .unwrap();

        assert_eq!(arr.raw_shape(), &[2, 3]);
        assert_eq!(arr.raw_dtype(), DType::F64);
        assert!(arr.is_c_contiguous());
        // A 2x3 row-major array is not also column-major contiguous.
        assert!(!arr.is_f_contiguous());
        assert!(!arr.raw_ptr().is_null());

        // Row-major element strides [3, 1] scaled by size_of::<f64>() == 8.
        let strides = arr.raw_strides_bytes();
        assert_eq!(strides, vec![24, 8]);
    }

    /// Views: byte strides use the view's element size and the pointer
    /// matches the owning array.
    #[test]
    fn raw_buffer_view() {
        let arr = Array::<f32, Ix2>::from_vec(Ix2::new([2, 2]), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let v = arr.view();

        assert_eq!(v.raw_dtype(), DType::F32);
        assert_eq!(v.raw_shape(), &[2, 2]);
        assert!(v.is_c_contiguous());
        // Element strides [2, 1] scaled by size_of::<f32>() == 4.
        assert_eq!(v.raw_strides_bytes(), vec![8, 4]);
        // A view of the whole array addresses the same memory as its owner.
        assert_eq!(v.raw_ptr(), arr.raw_ptr());
    }

    /// Shared (`ArcArray`) arrays expose the same raw description.
    #[test]
    fn raw_buffer_arc() {
        let arr = Array::<i32, Ix2>::from_vec(Ix2::new([2, 2]), vec![1, 2, 3, 4]).unwrap();
        let arc = ArcArray::from_owned(arr);

        assert_eq!(arc.raw_dtype(), DType::I32);
        assert_eq!(arc.raw_shape(), &[2, 2]);
        assert!(arc.is_c_contiguous());
        // Element strides [2, 1] scaled by size_of::<i32>() == 4.
        assert_eq!(arc.raw_strides_bytes(), vec![8, 4]);
        assert!(!arc.raw_ptr().is_null());
    }
}