1use crate::dimension::Dimension;
4use crate::dtype::{DType, Element};
5
6use crate::array::arc::ArcArray;
7use crate::array::owned::Array;
8use crate::array::view::ArrayView;
9
10pub trait AsRawBuffer {
15 fn raw_ptr(&self) -> *const u8;
17
18 fn raw_shape(&self) -> &[usize];
20
21 fn raw_strides_bytes(&self) -> Vec<isize>;
23
24 fn raw_dtype(&self) -> DType;
26
27 fn is_c_contiguous(&self) -> bool;
29
30 fn is_f_contiguous(&self) -> bool;
32}
33
34impl<T: Element, D: Dimension> AsRawBuffer for Array<T, D> {
35 fn raw_ptr(&self) -> *const u8 {
36 self.as_ptr() as *const u8
37 }
38
39 fn raw_shape(&self) -> &[usize] {
40 self.shape()
41 }
42
43 fn raw_strides_bytes(&self) -> Vec<isize> {
44 let itemsize = std::mem::size_of::<T>() as isize;
45 self.strides().iter().map(|&s| s * itemsize).collect()
46 }
47
48 fn raw_dtype(&self) -> DType {
49 T::dtype()
50 }
51
52 fn is_c_contiguous(&self) -> bool {
53 self.layout().is_c_contiguous()
54 }
55
56 fn is_f_contiguous(&self) -> bool {
57 self.layout().is_f_contiguous()
58 }
59}
60
61impl<T: Element, D: Dimension> AsRawBuffer for ArrayView<'_, T, D> {
62 fn raw_ptr(&self) -> *const u8 {
63 self.as_ptr() as *const u8
64 }
65
66 fn raw_shape(&self) -> &[usize] {
67 self.shape()
68 }
69
70 fn raw_strides_bytes(&self) -> Vec<isize> {
71 let itemsize = std::mem::size_of::<T>() as isize;
72 self.strides().iter().map(|&s| s * itemsize).collect()
73 }
74
75 fn raw_dtype(&self) -> DType {
76 T::dtype()
77 }
78
79 fn is_c_contiguous(&self) -> bool {
80 self.layout().is_c_contiguous()
81 }
82
83 fn is_f_contiguous(&self) -> bool {
84 self.layout().is_f_contiguous()
85 }
86}
87
88impl<T: Element, D: Dimension> AsRawBuffer for ArcArray<T, D> {
89 fn raw_ptr(&self) -> *const u8 {
90 self.as_ptr() as *const u8
91 }
92
93 fn raw_shape(&self) -> &[usize] {
94 self.shape()
95 }
96
97 fn raw_strides_bytes(&self) -> Vec<isize> {
98 let itemsize = std::mem::size_of::<T>() as isize;
99 self.strides().iter().map(|&s| s * itemsize).collect()
100 }
101
102 fn raw_dtype(&self) -> DType {
103 T::dtype()
104 }
105
106 fn is_c_contiguous(&self) -> bool {
107 self.layout().is_c_contiguous()
108 }
109
110 fn is_f_contiguous(&self) -> bool {
111 self.layout().is_f_contiguous()
112 }
113}
114
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dimension::Ix2;

    /// Owned, C-ordered `f64` array: shape, dtype, layout flags, pointer, and
    /// byte strides.
    #[test]
    fn raw_buffer_array() {
        let arr = Array::<f64, Ix2>::from_vec(Ix2::new([2, 3]), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
            .unwrap();

        assert_eq!(arr.raw_shape(), &[2, 3]);
        assert_eq!(arr.raw_dtype(), DType::F64);
        assert!(arr.is_c_contiguous());
        // A 2x3 row-major buffer cannot simultaneously be column-major.
        assert!(!arr.is_f_contiguous());
        assert!(!arr.raw_ptr().is_null());

        // Row step = 3 elements * 8 bytes; column step = 1 element * 8 bytes.
        assert_eq!(arr.raw_strides_bytes(), vec![24, 8]);
    }

    /// A full view must describe exactly the same memory as its owner.
    #[test]
    fn raw_buffer_view() {
        let arr = Array::<f32, Ix2>::from_vec(Ix2::new([2, 2]), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let v = arr.view();

        assert_eq!(v.raw_dtype(), DType::F32);
        assert_eq!(v.raw_shape(), &[2, 2]);
        assert!(v.is_c_contiguous());
        // Same base address as the owning array.
        assert_eq!(v.raw_ptr(), arr.raw_ptr());
        // Each impl carries its own element->byte conversion, so check it here too:
        // row step = 2 elements * 4 bytes; column step = 4 bytes.
        assert_eq!(v.raw_strides_bytes(), vec![8, 4]);
    }

    /// Converting to a shared array must preserve dtype, geometry, and layout.
    #[test]
    fn raw_buffer_arc() {
        let arr = Array::<i32, Ix2>::from_vec(Ix2::new([2, 2]), vec![1, 2, 3, 4]).unwrap();
        let arc = ArcArray::from_owned(arr);

        assert_eq!(arc.raw_dtype(), DType::I32);
        assert_eq!(arc.raw_shape(), &[2, 2]);
        assert!(arc.is_c_contiguous());
        assert!(!arc.raw_ptr().is_null());
        // Row step = 2 elements * 4 bytes; column step = 4 bytes.
        assert_eq!(arc.raw_strides_bytes(), vec![8, 4]);
    }
}