any_tensor/
lib.rs

1#![doc = include_str!("../README.md")]
2#![deny(warnings)]
3
4use digit_layout::DigitLayout;
5use ndarray_layout::{ArrayLayout, Endian::BigEndian};
6use std::{
7    borrow::Cow,
8    ops::{Deref, DerefMut},
9};
10
11pub extern crate digit_layout;
12pub extern crate ndarray_layout;
13
14// TODO: 现在 digit_layout 要求无符号类型中单个元素的宽度是 2 的幂且不小于 8,没有必要。
15
16/// 张量是一种数据的容器,代表在均质的数据上附加了数据类型、形状和数据布局的动态信息。
17///
18/// 作为一个容器,`Tensor<T, N>` 类似于 [`Option<T>`]、[`Result<T, _>`],可通过一系列方法变换其信息或数据:
19///
20/// - [`clone`](Tensor::clone):复制张量的信息,也复制张量的数据;
21/// - [`transform`](Tensor::transform):在不改变数据的情况下变换张量的布局;
22/// - [`map`](Tensor::map):替换张量的数据;
23/// - [`as_ref`](Tensor::as_ref):返回引用原始数据的新张量;
24/// - [`as_mut`](Tensor::as_mut):返回可变引用原始数据的新张量;
25/// - ...
26#[derive(Clone)]
27pub struct Tensor<T, const N: usize> {
28    /// 数据类型。
29    dt: DigitLayout,
30    /// 形状和布局。
31    layout: ArrayLayout<N>,
32    /// 数据成员。
33    item: T,
34}
35
36impl<const N: usize> Tensor<usize, N> {
37    /// 创建使用指定数据类型 `dt` 和形状 `shape` 的张量,张量的“数据”是连续存储其数据占用的字节数。
38    ///
39    /// 传入的 `shape` 应为张量中的数值的数量。
40    /// 在底层存储中,可能将多个数值捆绑为一个数据组。
41    /// 获取张量的形状时,将返回作为 N 维数组的形状,其连续维度除去了组的规模。
42    ///
43    /// 例如,对于将 32 个数字绑定为一组的数据类型,`shape` 为 `[7, 1024]` 时,产生的张量的形状是 `[7, 32]`。
44    ///
45    /// ```rust
46    /// # use tensor::Tensor;
47    /// // 定义一个数据类型,以 32 个 8 位无符号数为一组。
48    /// digit_layout::layout!(GROUP u(8); 32);
49    ///
50    /// let tensor = Tensor::new(GROUP, [7, 1024]);
51    /// assert_eq!(tensor.dt(), GROUP);
52    /// assert_eq!(tensor.shape(), [7, 32]);
53    /// assert_eq!(tensor.take(), 7 * 32 * 32);
54    /// ```
55    pub fn new(dt: DigitLayout, shape: [usize; N]) -> Self {
56        Self::from_dim_slice(dt, shape)
57    }
58
59    pub fn from_dim_slice(dt: DigitLayout, shape: impl AsRef<[usize]>) -> Self {
60        let shape = shape.as_ref();
61
62        let shape = match dt.group_size() {
63            1 => Cow::Borrowed(shape),
64            g => {
65                let mut shape = shape.to_vec();
66                let last = shape.last_mut().unwrap();
67                assert_eq!(*last % g, 0);
68                *last /= g;
69                Cow::Owned(shape)
70            }
71        };
72
73        let element_size = dt.nbytes();
74        let layout = ArrayLayout::new_contiguous(&shape, BigEndian, element_size);
75        let size = layout.num_elements() * element_size;
76        Self {
77            dt,
78            layout,
79            item: size,
80        }
81    }
82}
83
84impl<T, const N: usize> Tensor<T, N> {
85    pub const fn dt(&self) -> DigitLayout {
86        self.dt
87    }
88
89    pub const fn layout(&self) -> &ArrayLayout<N> {
90        &self.layout
91    }
92
93    pub fn shape(&self) -> &[usize] {
94        self.layout.shape()
95    }
96
97    pub fn strides(&self) -> &[isize] {
98        self.layout.strides()
99    }
100
101    pub fn offset(&self) -> isize {
102        self.layout.offset()
103    }
104
105    pub const fn get(&self) -> &T {
106        &self.item
107    }
108
109    pub fn get_mut(&mut self) -> &mut T {
110        &mut self.item
111    }
112
113    pub fn take(self) -> T {
114        self.item
115    }
116
117    pub const fn from_raw_parts(dt: DigitLayout, layout: ArrayLayout<N>, item: T) -> Self {
118        Self { dt, layout, item }
119    }
120
121    pub fn into_raw_parts(self) -> (DigitLayout, ArrayLayout<N>, T) {
122        let Self { dt, layout, item } = self;
123        (dt, layout, item)
124    }
125
126    pub fn use_info(&self) -> Tensor<usize, N> {
127        let dt = self.dt;
128        let element_size = dt.nbytes();
129        let layout = ArrayLayout::new_contiguous(self.layout.shape(), BigEndian, element_size);
130        let size = layout.num_elements() * element_size;
131        Tensor {
132            dt,
133            layout,
134            item: size,
135        }
136    }
137
138    pub fn is_contiguous(&self) -> bool {
139        match self.layout.merge_be(0, self.layout.ndim()) {
140            Some(layout) => {
141                let &[s] = layout.strides() else {
142                    unreachable!()
143                };
144                s == self.dt.nbytes() as isize
145            }
146            None => false,
147        }
148    }
149}
150
151impl<T, const N: usize> Tensor<T, N> {
152    pub fn as_ref(&self) -> Tensor<&T, N> {
153        Tensor {
154            dt: self.dt,
155            layout: self.layout.clone(),
156            item: &self.item,
157        }
158    }
159
160    pub fn as_mut(&mut self) -> Tensor<&mut T, N> {
161        Tensor {
162            dt: self.dt,
163            layout: self.layout.clone(),
164            item: &mut self.item,
165        }
166    }
167
168    pub fn transform(self, f: impl FnOnce(ArrayLayout<N>) -> ArrayLayout<N>) -> Self {
169        let Self { dt, layout, item } = self;
170        Self {
171            dt,
172            layout: f(layout),
173            item,
174        }
175    }
176
177    pub fn map<U>(self, f: impl FnOnce(T) -> U) -> Tensor<U, N> {
178        let Self { dt, layout, item } = self;
179        Tensor {
180            dt,
181            layout,
182            item: f(item),
183        }
184    }
185
186    pub fn replace<U>(self, u: U) -> (T, Tensor<U, N>) {
187        let Self { dt, layout, item } = self;
188        (
189            item,
190            Tensor {
191                dt,
192                layout,
193                item: u,
194            },
195        )
196    }
197}
198
199impl<T: Deref, const N: usize> Tensor<T, N> {
200    pub fn as_deref(&self) -> Tensor<&<T as Deref>::Target, N> {
201        Tensor {
202            dt: self.dt,
203            layout: self.layout.clone(),
204            item: self.item.deref(),
205        }
206    }
207}
208
209impl<T: DerefMut, const N: usize> Tensor<T, N> {
210    pub fn as_deref_mut(&mut self) -> Tensor<&mut <T as Deref>::Target, N> {
211        Tensor {
212            dt: self.dt,
213            layout: self.layout.clone(),
214            item: self.item.deref_mut(),
215        }
216    }
217}
218
#[test]
fn test_basic_functions() {
    digit_layout::layout!(GROUP u(8); 32);
    let original = Tensor::new(GROUP, [7, 1024]);
    assert_eq!(original.dt(), GROUP);

    // The tensor's layout must match a hand-built contiguous layout over 32-byte groups.
    let expected: ArrayLayout<2> = ArrayLayout::new_contiguous(&[7, 1024 / 32], BigEndian, 32);
    let actual = original.layout();
    assert_eq!(actual.shape(), expected.shape());
    assert_eq!(actual.strides(), expected.strides());
    assert_eq!(actual.offset(), expected.offset());

    assert_eq!(original.shape(), [7, 1024 / 32]);
    assert_eq!(original.strides(), &[1024, 32]);
    assert_eq!(original.offset(), 0);

    assert_eq!(*original.get(), 7 * 1024);

    // Mutating a clone must leave the original untouched.
    let mut copy = original.clone();
    *copy.get_mut() += 1;
    assert_eq!(*copy.get(), 7 * 1024 + 1);
    assert_eq!(original.take(), 7 * 1024)
}
241
#[test]
fn test_extra_functions() {
    digit_layout::layout!(GROUP u(8); 32);
    let layout: ArrayLayout<2> = ArrayLayout::new_contiguous(&[7, 1024], BigEndian, 32);
    let item = 7 * 1024;

    // Asserts that `actual` has the same shape, strides and offset as `layout`.
    let same_geometry = |actual: &ArrayLayout<2>| {
        assert_eq!(actual.shape(), layout.shape());
        assert_eq!(actual.strides(), layout.strides());
        assert_eq!(actual.offset(), layout.offset());
    };

    // Round trip through the raw parts.
    let assembled = Tensor::from_raw_parts(GROUP, layout.clone(), item);
    assert_eq!(assembled.dt(), GROUP);
    same_geometry(assembled.layout());
    assert_eq!(*assembled.get(), item);

    let (dt, parts_layout, parts_item) = assembled.into_raw_parts();
    assert_eq!(dt, GROUP);
    same_geometry(&parts_layout);
    assert_eq!(parts_item, item);

    // `use_info` describes the storage the tensor needs.
    let tensor = Tensor::from_raw_parts(GROUP, layout, item);
    let info = tensor.use_info();
    assert_eq!(info.dt(), GROUP);
    assert_eq!(info.shape(), tensor.shape());
    let expected_size = info.layout().num_elements() * GROUP.nbytes();
    assert_eq!(info.take(), expected_size);

    // A freshly built contiguous layout is contiguous; strides [1, 7] are not.
    assert!(tensor.is_contiguous());
    let scattered: ArrayLayout<2> = ArrayLayout::new(&[7, 1024], &[1, 7], 0);
    assert!(!Tensor::from_raw_parts(GROUP, scattered, item).is_contiguous())
}
281
#[test]
fn test_as_ref() {
    digit_layout::layout!(GROUP u(8); 32);
    let owner = Tensor::new(GROUP, [7, 1024]);

    // The borrowed view sees the same byte count as the owner.
    let view = owner.as_ref();
    assert_eq!(**view.get(), 7 * 1024)
}
290
#[test]
fn test_as_mut() {
    digit_layout::layout!(GROUP u(8); 32);
    let mut owner = Tensor::new(GROUP, [7, 1024]);

    // A write through the mutable view must be visible in the owner.
    *owner.as_mut().take() += 1;
    assert_eq!(*owner.get(), 7 * 1024 + 1)
}
300
#[test]
fn test_transform() {
    digit_layout::layout!(GROUP u(8); 32);
    let t1 = Tensor::new(GROUP, [7, 1024]);

    // The identity permutation must leave the layout untouched.
    let t1 = t1.transform(|layout| layout.transpose(&[0, 1]));
    assert_eq!(t1.shape(), [7, 32]);
    assert_eq!(t1.strides(), &[1024, 32]);
    assert_eq!(t1.offset(), 0);

    // Swapping the two axes must swap shape and strides, keep the offset,
    // and leave the data itself unchanged.
    let t2 = t1.transform(|layout| layout.transpose(&[1, 0]));
    assert_eq!(t2.shape(), [32, 7]);
    assert_eq!(t2.strides(), &[32, 1024]);
    assert_eq!(t2.offset(), 0);
    assert_eq!(*t2.get(), 7 * 1024)
}
315
#[test]
fn test_map() {
    digit_layout::layout!(GROUP u(8); 32);
    let sized = Tensor::new(GROUP, [7, 1024]);

    // Converting the payload type must preserve its value.
    let signed = sized.map(|bytes| bytes as isize);
    assert_eq!(*signed.get(), (7 * 1024) as isize)
}
328
#[test]
fn test_as_deref() {
    // Minimal wrapper that exposes its content through `Deref`.
    struct Wrapper<T>(T);
    impl<T> Deref for Wrapper<T> {
        type Target = T;
        fn deref(&self) -> &T {
            &self.0
        }
    }

    digit_layout::layout!(GROUP u(8); 32);
    let layout = ArrayLayout::<2>::new_contiguous(&[7, 1024], BigEndian, 32);
    let tensor = Tensor::from_raw_parts(GROUP, layout, Wrapper(42));

    let view = tensor.as_deref();
    assert_eq!(**view.get(), 42)
}
348
#[test]
fn test_as_deref_mut() {
    // Minimal wrapper that exposes its content through `Deref`/`DerefMut`.
    struct Wrapper<T>(T);
    impl<T> Deref for Wrapper<T> {
        type Target = T;
        fn deref(&self) -> &T {
            &self.0
        }
    }
    impl<T> DerefMut for Wrapper<T> {
        fn deref_mut(&mut self) -> &mut T {
            &mut self.0
        }
    }

    digit_layout::layout!(GROUP u(8); 32);
    let layout = ArrayLayout::<2>::new_contiguous(&[7, 1024], BigEndian, 32);
    let mut tensor = Tensor::from_raw_parts(GROUP, layout, Wrapper(42));

    // A write through the dereferenced view must reach the wrapped value.
    *tensor.as_deref_mut().take() += 1;
    assert_eq!(**tensor.get(), 43)
}