Skip to main content

radiate_utils/array/
tensor.rs

1use crate::array::TensorError;
2use crate::{Shape, Strides};
3#[cfg(feature = "serde")]
4use serde::{Deserialize, Serialize};
5use std::fmt::Debug;
6
/// Row-major tensor structure. The data is stored in a contiguous vector,
/// and the shape and strides are used to interpret the data.
///
/// [`Tensor::new`], [`Tensor::try_new`] and [`Tensor::reshape`] keep
/// `data.len()` equal to the product of the shape's dimensions and derive
/// `strides` from `shape`.
///
/// NOTE(review): with the `serde` feature enabled, the derived `Deserialize`
/// fills the fields directly and does not re-check that invariant.
#[derive(Default)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct Tensor<T> {
    /// Elements in row-major order.
    pub(super) data: Vec<T>,
    /// Size of each dimension.
    pub(super) shape: Shape,
    /// Per-dimension element step into `data`, computed from `shape`.
    pub(super) strides: Strides,
}
16
17impl<T> Tensor<T> {
18    pub fn new(data: Vec<T>, shape: impl Into<Shape>) -> Self {
19        let shape = shape.into();
20        let strides = Strides::from(&shape);
21
22        let expected = shape.try_size().unwrap_or_else(|| {
23            panic!(
24                "Tensor::new: shape size overflow for dims={:?}",
25                shape.as_slice()
26            )
27        });
28
29        assert!(
30            data.len() == expected,
31            "Tensor::new: data.len()={} does not match shape product {}",
32            data.len(),
33            expected
34        );
35
36        Self {
37            data,
38            shape,
39            strides,
40        }
41    }
42
43    pub fn try_new(data: Vec<T>, shape: impl Into<Shape>) -> Result<Self, TensorError> {
44        let shape = shape.into();
45        let strides = Strides::from(&shape);
46
47        let expected = shape.try_size().ok_or_else(|| TensorError::ShapeOverflow {
48            dims: shape.as_slice().to_vec(),
49        })?;
50
51        if data.len() != expected {
52            return Err(TensorError::LenMismatch {
53                len: data.len(),
54                expected,
55            });
56        }
57
58        Ok(Self {
59            data,
60            shape,
61            strides,
62        })
63    }
64
65    /// The rank (number of dimensions) of the tensor.
66    ///
67    /// For example, a matrix has rank 2, a vector has rank 1, and a scalar has rank 0.
68    ///
69    /// ```rust
70    /// use radiate_utils::Tensor;
71    ///
72    /// let two = Tensor::new(vec![1, 2, 3, 4], (2, 2));
73    /// let three = Tensor::new(vec![0; 24], (2, 3, 4));
74    /// assert_eq!(two.rank(), 2);
75    /// assert_eq!(three.rank(), 3);
76    /// ```
77    #[inline]
78    pub fn rank(&self) -> usize {
79        self.shape.dimensions()
80    }
81
82    /// The dimensions of the tensor. This is essentially a shortcut
83    /// for `tensor.shape.as_slice()`. Array of length equal to
84    /// the tensor's rank, where each entry is the size of that dimension.
85    ///
86    /// ```rust
87    /// use radiate_utils::Tensor;
88    ///
89    /// let tensor = Tensor::new(vec![1, 2, 3, 4, 5, 6], (2, 3));
90    /// assert_eq!(tensor.dims(), &[2, 3]);
91    /// ```
92    #[inline]
93    pub fn dims(&self) -> &[usize] {
94        self.shape.as_slice()
95    }
96
97    /// The shape of the tensor. This describes the size of each dimension.
98    /// For example, a tensor with shape `[2, 3]` has 2 rows and 3 columns -
99    /// essentially a 2x3 matrix.
100    ///
101    /// ```rust
102    /// use radiate_utils::Tensor;
103    ///
104    /// let tensor = Tensor::new(vec![1, 2, 3, 4, 5, 6], (2, 3));
105    /// assert_eq!(tensor.shape().as_slice(), &[2, 3]);
106    /// ```
107    #[inline]
108    pub fn shape(&self) -> &Shape {
109        &self.shape
110    }
111
112    /// The strides of the tensor. Strides indicate how many elements
113    /// to skip in the underlying data vector to move to the next element
114    /// along each dimension. For a row-major tensor, the last dimension
115    /// has a stride of 1, the second-to-last dimension has a stride equal
116    /// to the size of the last dimension, and so on.
117    ///
118    /// ```rust
119    /// use radiate_utils::Tensor;
120    ///
121    /// let tensor = Tensor::new(vec![1, 2, 3, 4, 5, 6], (2, 3));
122    /// assert_eq!(tensor.strides().as_slice(), &[3, 1]);
123    /// ```
124    #[inline]
125    pub fn strides(&self) -> &Strides {
126        &self.strides
127    }
128
129    /// The underlying data of the tensor as a flat slice.
130    /// This data is stored in row-major order.
131    ///
132    /// ```rust
133    /// use radiate_utils::Tensor;
134    ///
135    /// let tensor = Tensor::new(vec![1, 2, 3, 4, 5, 6], (2, 3));
136    /// assert_eq!(tensor.data(), &[1, 2, 3, 4, 5, 6]);
137    /// ```
138    #[inline]
139    pub fn data(&self) -> &[T] {
140        &self.data
141    }
142
143    /// The underlying data of the tensor as a mutable flat slice.
144    /// This data is stored in row-major order.
145    ///
146    /// ```rust
147    /// use radiate_utils::Tensor;
148    ///
149    /// let mut tensor = Tensor::new(vec![1, 2, 3, 4, 5, 6], (2, 3));
150    /// let data_mut = tensor.data_mut();
151    /// data_mut[0] = 10;
152    /// assert_eq!(tensor.data(), &[10, 2, 3, 4, 5, 6]);
153    /// ```
154    #[inline]
155    pub fn data_mut(&mut self) -> &mut [T] {
156        &mut self.data
157    }
158
159    /// Figure out if the tensor has no elements.
160    ///
161    /// ```rust
162    /// use radiate_utils::Tensor;
163    ///
164    /// let empty = Tensor::<i32>::new(vec![], (0, 3));
165    /// let non_empty = Tensor::new(vec![1, 2, 3], (1, 3));
166    /// assert!(empty.is_empty());
167    /// assert!(!non_empty.is_empty());
168    /// ```
169    #[inline]
170    pub fn is_empty(&self) -> bool {
171        self.data.is_empty()
172    }
173
174    pub fn len(&self) -> usize {
175        self.data.len()
176    }
177
178    pub fn clear(&mut self) {
179        self.data.clear();
180    }
181
182    /// --- raw pointers ---
183    #[inline]
184    pub fn as_ptr(&self) -> *const T {
185        self.data.as_ptr()
186    }
187
188    #[inline]
189    pub fn as_mut_ptr(&mut self) -> *mut T {
190        self.data.as_mut_ptr()
191    }
192
193    #[inline]
194    pub fn as_raw_parts(&self) -> (*const T, usize) {
195        (self.data.as_ptr(), self.data.len())
196    }
197
198    #[inline]
199    pub fn as_raw_parts_mut(&mut self) -> (*mut T, usize) {
200        (self.data.as_mut_ptr(), self.data.len())
201    }
202
203    /// --- iterators ---
204    #[inline]
205    pub fn iter(&self) -> std::slice::Iter<'_, T> {
206        self.data.iter()
207    }
208
209    #[inline]
210    pub fn iter_mut(&mut self) -> std::slice::IterMut<'_, T> {
211        self.data.iter_mut()
212    }
213
214    /// Reshape without changing the underlying data.
215    /// Panics if the new shape has a different total element count.
216    ///
217    /// ```rust
218    /// use radiate_utils::Tensor;
219    ///
220    /// let mut tensor = Tensor::new(vec![0, 1, 2, 3, 4, 5], (2, 3));
221    /// tensor.reshape((3, 2));
222    /// assert_eq!(tensor.shape().as_slice(), &[3, 2]);
223    /// assert_eq!(tensor.strides().as_slice(), &[2, 1]); // row-major
224    /// assert_eq!(tensor.data(), &[0, 1, 2, 3, 4, 5]);
225    /// ```
226    #[inline]
227    pub fn reshape(&mut self, new_shape: impl Into<Shape>) {
228        let new_shape = new_shape.into();
229        let expected = new_shape.try_size().unwrap_or_else(|| {
230            panic!(
231                "Tensor::reshape: shape size overflow for dims={:?}",
232                new_shape.as_slice()
233            )
234        });
235
236        assert!(
237            expected == self.data.len(),
238            "Tensor::reshape: new shape product {} != data.len() {}",
239            expected,
240            self.data.len()
241        );
242
243        self.shape = new_shape.clone();
244        self.strides = Strides::from(&new_shape);
245    }
246}
247
248impl<T: Clone> Tensor<T> {
249    pub fn from_elem(shape: impl Into<Shape>, value: T) -> Self {
250        let shape = shape.into();
251        let n = shape.try_size().unwrap_or_else(|| {
252            panic!(
253                "Tensor::from_elem: shape size overflow for dims={:?}",
254                shape.as_slice()
255            )
256        });
257
258        let data = vec![value; n];
259        Self::new(data, shape)
260    }
261}
262
263impl<T: Default + Clone> Tensor<T> {
264    pub fn zeros(shape: impl Into<Shape>) -> Self {
265        Self::from_elem(shape, T::default())
266    }
267}
268
269impl<T: Debug> Debug for Tensor<T> {
270    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
271        writeln!(f, "Tensor(shape={:?}, data=", self.shape.dimensions())?;
272
273        fn fmt_recursive<T: std::fmt::Debug>(
274            f: &mut std::fmt::Formatter<'_>,
275            data: &[T],
276            shape: &[usize],
277            strides: &[usize],
278            offset: usize,
279            depth: usize,
280        ) -> std::fmt::Result {
281            let indent = " ".repeat(depth);
282
283            if shape.len() == 1 {
284                // Vector / leaf
285                write!(f, "{}[", indent)?;
286                for i in 0..shape[0] {
287                    if i > 0 {
288                        write!(f, ", ")?;
289                    }
290                    write!(f, "{:?}", data[offset + i * strides[0]])?;
291                }
292                write!(f, "]")?;
293            } else {
294                // Higher rank
295                write!(f, "{}[", indent)?;
296                for i in 0..shape[0] {
297                    if i > 0 {
298                        writeln!(f, ",")?;
299                    } else {
300                        writeln!(f)?;
301                    }
302
303                    fmt_recursive(
304                        f,
305                        data,
306                        &shape[1..],
307                        &strides[1..],
308                        offset + i * strides[0],
309                        depth + 1,
310                    )?;
311                }
312                writeln!(f)?;
313                write!(f, "{}]", indent)?;
314            }
315
316            Ok(())
317        }
318
319        fmt_recursive(
320            f,
321            &self.data,
322            (0..self.shape.dimensions())
323                .map(|i| self.shape.dim_at(i))
324                .collect::<Vec<usize>>()
325                .as_slice(),
326            (0..self.shape.dimensions())
327                .map(|i| self.strides.stride_at(i))
328                .collect::<Vec<usize>>()
329                .as_slice(),
330            0,
331            0,
332        )?;
333
334        write!(f, ")")
335    }
336}
337
#[cfg(test)]
mod tests {

    use super::*;

    #[test]
    fn test_tensor_basic() {
        let tensor = Tensor::new(vec![1, 2, 3, 4, 5, 6], (2, 3));

        assert_eq!(tensor.rank(), 2);
        assert_eq!(tensor.dims(), &[2, 3]);
        assert_eq!(tensor.shape().as_slice(), &[2, 3]);
        assert_eq!(tensor.strides().as_slice(), &[3, 1]);
        assert_eq!(tensor.data(), &[1, 2, 3, 4, 5, 6]);
    }

    #[test]
    fn test_tensor_from_elem() {
        let tensor = Tensor::from_elem((2, 2), 42);

        assert_eq!(tensor.rank(), 2);
        assert_eq!(tensor.dims(), &[2, 2]);
        assert_eq!(tensor.shape().as_slice(), &[2, 2]);
        assert_eq!(tensor.strides().as_slice(), &[2, 1]);
        assert_eq!(tensor.data(), &[42, 42, 42, 42]);
    }

    #[test]
    fn test_try_new_ok() {
        let t = Tensor::try_new(vec![1, 2, 3, 4], (2, 2))
            .expect("data length matches the shape product");

        assert_eq!(t.dims(), &[2, 2]);
        assert_eq!(t.strides().as_slice(), &[2, 1]);
        assert_eq!(t.len(), 4);
    }

    #[test]
    fn test_try_new_len_mismatch_err() {
        let err = Tensor::try_new(vec![1, 2, 3], (2, 2)).unwrap_err();
        match err {
            TensorError::LenMismatch { len, expected } => {
                assert_eq!(len, 3);
                assert_eq!(expected, 4);
            }
            other => panic!("expected LenMismatch, got: {:?}", other),
        }
    }

    #[test]
    #[should_panic(expected = "does not match shape product")]
    fn test_new_panics_on_len_mismatch() {
        let _ = Tensor::new(vec![1, 2, 3], (2, 2));
    }

    #[test]
    fn test_empty_tensor() {
        let t = Tensor::<i32>::new(vec![], (0, 3));

        assert!(t.is_empty());
        assert_eq!(t.len(), 0);
        assert_eq!(t.dims(), &[0, 3]);
    }

    #[test]
    fn test_reshape_updates_shape_and_strides() {
        let mut t = Tensor::new((0..6).collect::<Vec<i32>>(), (2, 3));

        // reshape to (3, 2)
        t.reshape((3, 2));

        assert_eq!(t.dims(), &[3, 2]);
        assert_eq!(t.strides().as_slice(), &[2, 1]); // row-major
        assert_eq!(t.data(), &[0, 1, 2, 3, 4, 5]);
    }

    #[test]
    #[should_panic]
    fn test_reshape_panics_on_mismatched_size() {
        let mut t = Tensor::new(vec![0; 6], (2, 3));
        t.reshape((2, 2)); // product 4 != 6
    }

    #[test]
    fn test_from_elem_fills_correctly() {
        let t = Tensor::from_elem((2, 3), 7u32);
        assert_eq!(t.data(), &[7, 7, 7, 7, 7, 7]);
        assert_eq!(t.strides().as_slice(), &[3, 1]);
    }

    #[test]
    fn test_zeros_works_for_numeric() {
        let t = Tensor::<i32>::zeros((2, 2, 2));
        assert_eq!(t.data(), &[0; 8]);
        assert_eq!(t.strides().as_slice(), &[4, 2, 1]);
    }

    #[test]
    fn test_iter_mut_and_data_mut_write_through() {
        let mut t = Tensor::new(vec![1, 2, 3, 4], (2, 2));

        t.iter_mut().for_each(|x| *x *= 10);
        t.data_mut()[3] = 7;

        assert_eq!(t.iter().copied().collect::<Vec<_>>(), vec![10, 20, 30, 7]);
        // Mutation must not disturb the shape.
        assert_eq!(t.dims(), &[2, 2]);
    }

    #[test]
    fn test_as_raw_parts_consistency() {
        let t = Tensor::new(vec![10, 11, 12, 13], (2, 2));
        let (ptr, len) = t.as_raw_parts();
        assert_eq!(len, 4);
        // pointer identity check (safe as long as we don't deref past len)
        assert_eq!(ptr, t.as_ptr());
    }
}