1#![doc = include_str!("../README.md")]
2#![deny(warnings)]
3
4use digit_layout::DigitLayout;
5use ndarray_layout::{ArrayLayout, Endian::BigEndian};
6use std::{
7 borrow::Cow,
8 ops::{Deref, DerefMut},
9};
10
11pub extern crate digit_layout;
12pub extern crate ndarray_layout;
13
14#[derive(Clone)]
27pub struct Tensor<T, const N: usize> {
28 dt: DigitLayout,
30 layout: ArrayLayout<N>,
32 item: T,
34}
35
36impl<const N: usize> Tensor<usize, N> {
37 pub fn new(dt: DigitLayout, shape: [usize; N]) -> Self {
56 Self::from_dim_slice(dt, shape)
57 }
58
59 pub fn from_dim_slice(dt: DigitLayout, shape: impl AsRef<[usize]>) -> Self {
60 let shape = shape.as_ref();
61
62 let shape = match dt.group_size() {
63 1 => Cow::Borrowed(shape),
64 g => {
65 let mut shape = shape.to_vec();
66 let last = shape.last_mut().unwrap();
67 assert_eq!(*last % g, 0);
68 *last /= g;
69 Cow::Owned(shape)
70 }
71 };
72
73 let element_size = dt.nbytes();
74 let layout = ArrayLayout::new_contiguous(&shape, BigEndian, element_size);
75 let size = layout.num_elements() * element_size;
76 Self {
77 dt,
78 layout,
79 item: size,
80 }
81 }
82}
83
84impl<T, const N: usize> Tensor<T, N> {
85 pub const fn dt(&self) -> DigitLayout {
86 self.dt
87 }
88
89 pub const fn layout(&self) -> &ArrayLayout<N> {
90 &self.layout
91 }
92
93 pub fn shape(&self) -> &[usize] {
94 self.layout.shape()
95 }
96
97 pub fn strides(&self) -> &[isize] {
98 self.layout.strides()
99 }
100
101 pub fn offset(&self) -> isize {
102 self.layout.offset()
103 }
104
105 pub const fn get(&self) -> &T {
106 &self.item
107 }
108
109 pub fn get_mut(&mut self) -> &mut T {
110 &mut self.item
111 }
112
113 pub fn take(self) -> T {
114 self.item
115 }
116
117 pub const fn from_raw_parts(dt: DigitLayout, layout: ArrayLayout<N>, item: T) -> Self {
118 Self { dt, layout, item }
119 }
120
121 pub fn into_raw_parts(self) -> (DigitLayout, ArrayLayout<N>, T) {
122 let Self { dt, layout, item } = self;
123 (dt, layout, item)
124 }
125
126 pub fn use_info(&self) -> Tensor<usize, N> {
127 let dt = self.dt;
128 let element_size = dt.nbytes();
129 let layout = ArrayLayout::new_contiguous(self.layout.shape(), BigEndian, element_size);
130 let size = layout.num_elements() * element_size;
131 Tensor {
132 dt,
133 layout,
134 item: size,
135 }
136 }
137
138 pub fn is_contiguous(&self) -> bool {
139 match self.layout.merge_be(0, self.layout.ndim()) {
140 Some(layout) => {
141 let &[s] = layout.strides() else {
142 unreachable!()
143 };
144 s == self.dt.nbytes() as isize
145 }
146 None => false,
147 }
148 }
149}
150
151impl<T, const N: usize> Tensor<T, N> {
152 pub fn as_ref(&self) -> Tensor<&T, N> {
153 Tensor {
154 dt: self.dt,
155 layout: self.layout.clone(),
156 item: &self.item,
157 }
158 }
159
160 pub fn as_mut(&mut self) -> Tensor<&mut T, N> {
161 Tensor {
162 dt: self.dt,
163 layout: self.layout.clone(),
164 item: &mut self.item,
165 }
166 }
167
168 pub fn transform(self, f: impl FnOnce(ArrayLayout<N>) -> ArrayLayout<N>) -> Self {
169 let Self { dt, layout, item } = self;
170 Self {
171 dt,
172 layout: f(layout),
173 item,
174 }
175 }
176
177 pub fn map<U>(self, f: impl FnOnce(T) -> U) -> Tensor<U, N> {
178 let Self { dt, layout, item } = self;
179 Tensor {
180 dt,
181 layout,
182 item: f(item),
183 }
184 }
185
186 pub fn replace<U>(self, u: U) -> (T, Tensor<U, N>) {
187 let Self { dt, layout, item } = self;
188 (
189 item,
190 Tensor {
191 dt,
192 layout,
193 item: u,
194 },
195 )
196 }
197}
198
199impl<T: Deref, const N: usize> Tensor<T, N> {
200 pub fn as_deref(&self) -> Tensor<&<T as Deref>::Target, N> {
201 Tensor {
202 dt: self.dt,
203 layout: self.layout.clone(),
204 item: self.item.deref(),
205 }
206 }
207}
208
209impl<T: DerefMut, const N: usize> Tensor<T, N> {
210 pub fn as_deref_mut(&mut self) -> Tensor<&mut <T as Deref>::Target, N> {
211 Tensor {
212 dt: self.dt,
213 layout: self.layout.clone(),
214 item: self.item.deref_mut(),
215 }
216 }
217}
218
#[test]
fn test_basic_functions() {
    digit_layout::layout!(GROUP u(8); 32);
    let tensor = Tensor::new(GROUP, [7, 1024]);
    assert_eq!(tensor.dt(), GROUP);

    // The grouped type shrinks the last axis from 1024 scalars to 32 elements.
    let expected: ArrayLayout<2> = ArrayLayout::new_contiguous(&[7, 1024 / 32], BigEndian, 32);
    let actual = tensor.layout();
    assert_eq!(actual.shape(), expected.shape());
    assert_eq!(actual.strides(), expected.strides());
    assert_eq!(actual.offset(), expected.offset());

    assert_eq!(tensor.shape(), [7, 1024 / 32]);
    assert_eq!(tensor.strides(), &[1024, 32]);
    assert_eq!(tensor.offset(), 0);

    // The payload is the total size in bytes; clones are independent.
    assert_eq!(*tensor.get(), 7 * 1024);
    let mut copy = tensor.clone();
    *copy.get_mut() += 1;
    assert_eq!(*copy.get(), 7 * 1024 + 1);
    assert_eq!(tensor.take(), 7 * 1024);
}
241
#[test]
fn test_extra_functions() {
    digit_layout::layout!(GROUP u(8); 32);
    let payload = 7 * 1024;
    let reference: ArrayLayout<2> = ArrayLayout::new_contiguous(&[7, 1024], BigEndian, 32);

    // Round trip through from_raw_parts / into_raw_parts.
    let tensor = Tensor::from_raw_parts(GROUP, reference.clone(), payload);
    assert_eq!(tensor.dt(), GROUP);
    assert_eq!(tensor.layout().shape(), reference.shape());
    assert_eq!(tensor.layout().strides(), reference.strides());
    assert_eq!(tensor.layout().offset(), reference.offset());
    assert_eq!(*tensor.get(), payload);

    let (dt, layout, item) = tensor.into_raw_parts();
    assert_eq!(dt, GROUP);
    assert_eq!(layout.shape(), reference.shape());
    assert_eq!(layout.strides(), reference.strides());
    assert_eq!(layout.offset(), reference.offset());
    assert_eq!(item, payload);

    // use_info rebuilds a contiguous layout and reports the size in bytes.
    let tensor = Tensor::from_raw_parts(GROUP, reference, payload);
    let info = tensor.use_info();
    assert_eq!(info.dt(), GROUP);
    assert_eq!(info.shape(), tensor.shape());
    let expected_size = info.layout().num_elements() * GROUP.nbytes();
    assert_eq!(info.take(), expected_size);

    assert!(tensor.is_contiguous());

    let strided: ArrayLayout<2> = ArrayLayout::new(&[7, 1024], &[1, 7], 0);
    assert!(!Tensor::from_raw_parts(GROUP, strided, payload).is_contiguous());
}
281
#[test]
fn test_as_ref() {
    digit_layout::layout!(GROUP u(8); 32);
    let owner = Tensor::new(GROUP, [7, 1024]);

    // The view's payload points at the owner's size (7 * 1024 bytes).
    let view = owner.as_ref();
    assert_eq!(*view.item, 7 * 1024);
}
290
#[test]
fn test_as_mut() {
    digit_layout::layout!(GROUP u(8); 32);
    let mut owner = Tensor::new(GROUP, [7, 1024]);

    {
        let mut view = owner.as_mut();
        **view.get_mut() += 1;
    }
    // The write through the mutable view is visible in the owning tensor.
    assert_eq!(*owner.get(), 7 * 1024 + 1);
}
300
#[test]
fn test_transform() {
    digit_layout::layout!(GROUP u(8); 32);
    let t1 = Tensor::new(GROUP, [7, 1024]);

    // Swap the two axes. The previous `[0, 1]` permutation was the identity,
    // which left the layout untouched and did not exercise `transform`.
    fn trans(layout: ArrayLayout<2>) -> ArrayLayout<2> {
        layout.transpose(&[1, 0])
    }

    let t2 = t1.transform(trans);
    assert_eq!(t2.shape(), [32, 7]);
    assert_eq!(t2.strides(), &[32, 1024]);
    assert_eq!(t2.offset(), 0);
    // `transform` must leave the data type and payload untouched.
    assert_eq!(t2.dt(), GROUP);
    assert_eq!(*t2.get(), 7 * 1024);
}
315
#[test]
fn test_map() {
    digit_layout::layout!(GROUP u(8); 32);

    // Re-type the size payload from usize to isize.
    let mapped = Tensor::new(GROUP, [7, 1024]).map(|n: usize| n as isize);
    assert_eq!(*mapped.get(), 7 * 1024_isize);
}
328
#[test]
fn test_as_deref() {
    /// Minimal smart pointer used to exercise `as_deref`.
    struct Wrapper<T>(T);

    impl<T> Deref for Wrapper<T> {
        type Target = T;
        fn deref(&self) -> &T {
            &self.0
        }
    }

    digit_layout::layout!(GROUP u(8); 32);
    let tensor = Tensor::from_raw_parts(
        GROUP,
        ArrayLayout::<2>::new_contiguous(&[7, 1024], BigEndian, 32),
        Wrapper(42),
    );
    assert_eq!(*tensor.as_deref().item, 42);
}
348
#[test]
fn test_as_deref_mut() {
    /// Minimal mutable smart pointer used to exercise `as_deref_mut`.
    struct Wrapper<T>(T);

    impl<T> Deref for Wrapper<T> {
        type Target = T;
        fn deref(&self) -> &T {
            &self.0
        }
    }

    impl<T> DerefMut for Wrapper<T> {
        fn deref_mut(&mut self) -> &mut T {
            &mut self.0
        }
    }

    digit_layout::layout!(GROUP u(8); 32);
    let mut tensor = Tensor::from_raw_parts(
        GROUP,
        ArrayLayout::<2>::new_contiguous(&[7, 1024], BigEndian, 32),
        Wrapper(42),
    );
    // Mutate through the dereferenced view; the wrapped value must change.
    **tensor.as_deref_mut().get_mut() += 1;
    assert_eq!(**tensor.get(), 43);
}