Skip to main content

radiate_utils/array/
tensor.rs

1use crate::array::TensorError;
2use crate::{Shape, Strides};
3#[cfg(feature = "serde")]
4use serde::{Deserialize, Serialize};
5use std::fmt::Debug;
6
/// Row-major tensor structure. The data is stored in a contiguous vector,
/// and the shape and strides are used to interpret the data.
///
/// `new`, `try_new`, and `reshape` check that `data.len()` equals the product
/// of the shape's dimensions and recompute `strides` from `shape` via
/// `Strides::from(&shape)`. Note that `#[derive(Default)]`, deserialization,
/// and direct `pub(super)` field access do not go through those checks.
#[derive(Default)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct Tensor<T> {
    /// Flat element storage, laid out in row-major order.
    pub(super) data: Vec<T>,
    /// Size of each dimension; its length is the tensor's rank.
    pub(super) shape: Shape,
    /// Per-dimension step (in elements) into `data`, derived from `shape`.
    pub(super) strides: Strides,
}
16
17impl<T> Tensor<T> {
18    pub fn new(data: Vec<T>, shape: impl Into<Shape>) -> Self {
19        let shape = shape.into();
20        let strides = Strides::from(&shape);
21
22        let expected = shape.try_size().unwrap_or_else(|| {
23            panic!(
24                "Tensor::new: shape size overflow for dims={:?}",
25                shape.as_slice()
26            )
27        });
28
29        assert!(
30            data.len() == expected,
31            "Tensor::new: data.len()={} does not match shape product {}",
32            data.len(),
33            expected
34        );
35
36        Self {
37            data,
38            shape,
39            strides,
40        }
41    }
42
43    pub fn try_new(data: Vec<T>, shape: impl Into<Shape>) -> Result<Self, TensorError> {
44        let shape = shape.into();
45        let strides = Strides::from(&shape);
46
47        let expected = shape.try_size().ok_or_else(|| TensorError::ShapeOverflow {
48            dims: shape.as_slice().to_vec(),
49        })?;
50
51        if data.len() != expected {
52            return Err(TensorError::LenMismatch {
53                len: data.len(),
54                expected,
55            });
56        }
57
58        Ok(Self {
59            data,
60            shape,
61            strides,
62        })
63    }
64
65    /// The rank (number of dimensions) of the tensor.
66    ///
67    /// For example, a matrix has rank 2, a vector has rank 1, and a scalar has rank 0.
68    ///
69    /// ```rust
70    /// use radiate_utils::Tensor;
71    ///
72    /// let two = Tensor::new(vec![1, 2, 3, 4], (2, 2));
73    /// let three = Tensor::new(vec![0; 24], (2, 3, 4));
74    /// assert_eq!(two.rank(), 2);
75    /// assert_eq!(three.rank(), 3);
76    /// ```
77    #[inline]
78    pub fn rank(&self) -> usize {
79        self.shape.dimensions()
80    }
81
82    /// The dimensions of the tensor. This is essentially a shortcut
83    /// for `tensor.shape.as_slice()`. Array of length equal to
84    /// the tensor's rank, where each entry is the size of that dimension.
85    ///
86    /// ```rust
87    /// use radiate_utils::Tensor;
88    ///
89    /// let tensor = Tensor::new(vec![1, 2, 3, 4, 5, 6], (2, 3));
90    /// assert_eq!(tensor.dims(), &[2, 3]);
91    /// ```
92    #[inline]
93    pub fn dims(&self) -> &[usize] {
94        self.shape.as_slice()
95    }
96
97    /// The shape of the tensor. This describes the size of each dimension.
98    /// For example, a tensor with shape `[2, 3]` has 2 rows and 3 columns -
99    /// essentially a 2x3 matrix.
100    ///
101    /// ```rust
102    /// use radiate_utils::Tensor;
103    ///
104    /// let tensor = Tensor::new(vec![1, 2, 3, 4, 5, 6], (2, 3));
105    /// assert_eq!(tensor.shape().as_slice(), &[2, 3]);
106    /// ```
107    #[inline]
108    pub fn shape(&self) -> &Shape {
109        &self.shape
110    }
111
112    /// The strides of the tensor. Strides indicate how many elements
113    /// to skip in the underlying data vector to move to the next element
114    /// along each dimension. For a row-major tensor, the last dimension
115    /// has a stride of 1, the second-to-last dimension has a stride equal
116    /// to the size of the last dimension, and so on.
117    ///
118    /// ```rust
119    /// use radiate_utils::Tensor;
120    ///
121    /// let tensor = Tensor::new(vec![1, 2, 3, 4, 5, 6], (2, 3));
122    /// assert_eq!(tensor.strides().as_slice(), &[3, 1]);
123    /// ```
124    #[inline]
125    pub fn strides(&self) -> &Strides {
126        &self.strides
127    }
128
129    /// The underlying data of the tensor as a flat slice.
130    /// This data is stored in row-major order.
131    ///
132    /// ```rust
133    /// use radiate_utils::Tensor;
134    ///
135    /// let tensor = Tensor::new(vec![1, 2, 3, 4, 5, 6], (2, 3));
136    /// assert_eq!(tensor.data(), &[1, 2, 3, 4, 5, 6]);
137    /// ```
138    #[inline]
139    pub fn data(&self) -> &[T] {
140        &self.data
141    }
142
143    /// The underlying data of the tensor as a mutable flat slice.
144    /// This data is stored in row-major order.
145    ///
146    /// ```rust
147    /// use radiate_utils::Tensor;
148    ///
149    /// let mut tensor = Tensor::new(vec![1, 2, 3, 4, 5, 6], (2, 3));
150    /// let data_mut = tensor.data_mut();
151    /// data_mut[0] = 10;
152    /// assert_eq!(tensor.data(), &[10, 2, 3, 4, 5, 6]);
153    /// ```
154    #[inline]
155    pub fn data_mut(&mut self) -> &mut [T] {
156        &mut self.data
157    }
158
159    /// Figure out if the tensor has no elements.
160    ///
161    /// ```rust
162    /// use radiate_utils::Tensor;
163    ///
164    /// let empty = Tensor::<i32>::new(vec![], (0, 3));
165    /// let non_empty = Tensor::new(vec![1, 2, 3], (1, 3));
166    /// assert!(empty.is_empty());
167    /// assert!(!non_empty.is_empty());
168    /// ```
169    #[inline]
170    pub fn is_empty(&self) -> bool {
171        self.data.is_empty()
172    }
173
174    /// --- raw pointers ---
175    #[inline]
176    pub fn as_ptr(&self) -> *const T {
177        self.data.as_ptr()
178    }
179
180    #[inline]
181    pub fn as_mut_ptr(&mut self) -> *mut T {
182        self.data.as_mut_ptr()
183    }
184
185    #[inline]
186    pub fn as_raw_parts(&self) -> (*const T, usize) {
187        (self.data.as_ptr(), self.data.len())
188    }
189
190    #[inline]
191    pub fn as_raw_parts_mut(&mut self) -> (*mut T, usize) {
192        (self.data.as_mut_ptr(), self.data.len())
193    }
194
195    /// --- iterators ---
196    #[inline]
197    pub fn iter(&self) -> std::slice::Iter<'_, T> {
198        self.data.iter()
199    }
200
201    #[inline]
202    pub fn iter_mut(&mut self) -> std::slice::IterMut<'_, T> {
203        self.data.iter_mut()
204    }
205
206    /// Reshape without changing the underlying data.
207    /// Panics if the new shape has a different total element count.
208    ///
209    /// ```rust
210    /// use radiate_utils::Tensor;
211    ///
212    /// let mut tensor = Tensor::new(vec![0, 1, 2, 3, 4, 5], (2, 3));
213    /// tensor.reshape((3, 2));
214    /// assert_eq!(tensor.shape().as_slice(), &[3, 2]);
215    /// assert_eq!(tensor.strides().as_slice(), &[2, 1]); // row-major
216    /// assert_eq!(tensor.data(), &[0, 1, 2, 3, 4, 5]);
217    /// ```
218    #[inline]
219    pub fn reshape(&mut self, new_shape: impl Into<Shape>) {
220        let new_shape = new_shape.into();
221        let expected = new_shape.try_size().unwrap_or_else(|| {
222            panic!(
223                "Tensor::reshape: shape size overflow for dims={:?}",
224                new_shape.as_slice()
225            )
226        });
227
228        assert!(
229            expected == self.data.len(),
230            "Tensor::reshape: new shape product {} != data.len() {}",
231            expected,
232            self.data.len()
233        );
234
235        self.shape = new_shape.clone();
236        self.strides = Strides::from(&new_shape);
237    }
238}
239
240impl<T: Clone> Tensor<T> {
241    pub fn from_elem(shape: impl Into<Shape>, value: T) -> Self {
242        let shape = shape.into();
243        let n = shape.try_size().unwrap_or_else(|| {
244            panic!(
245                "Tensor::from_elem: shape size overflow for dims={:?}",
246                shape.as_slice()
247            )
248        });
249
250        let data = vec![value; n];
251        Self::new(data, shape)
252    }
253}
254
255impl<T: Default + Clone> Tensor<T> {
256    pub fn zeros(shape: impl Into<Shape>) -> Self {
257        Self::from_elem(shape, T::default())
258    }
259}
260
261impl<T: Debug> Debug for Tensor<T> {
262    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
263        writeln!(f, "Tensor(shape={:?}, data=", self.shape.dimensions())?;
264
265        fn fmt_recursive<T: std::fmt::Debug>(
266            f: &mut std::fmt::Formatter<'_>,
267            data: &[T],
268            shape: &[usize],
269            strides: &[usize],
270            offset: usize,
271            depth: usize,
272        ) -> std::fmt::Result {
273            let indent = " ".repeat(depth);
274
275            if shape.len() == 1 {
276                // Vector / leaf
277                write!(f, "{}[", indent)?;
278                for i in 0..shape[0] {
279                    if i > 0 {
280                        write!(f, ", ")?;
281                    }
282                    write!(f, "{:?}", data[offset + i * strides[0]])?;
283                }
284                write!(f, "]")?;
285            } else {
286                // Higher rank
287                write!(f, "{}[", indent)?;
288                for i in 0..shape[0] {
289                    if i > 0 {
290                        writeln!(f, ",")?;
291                    } else {
292                        writeln!(f)?;
293                    }
294
295                    fmt_recursive(
296                        f,
297                        data,
298                        &shape[1..],
299                        &strides[1..],
300                        offset + i * strides[0],
301                        depth + 1,
302                    )?;
303                }
304                writeln!(f)?;
305                write!(f, "{}]", indent)?;
306            }
307
308            Ok(())
309        }
310
311        fmt_recursive(
312            f,
313            &self.data,
314            (0..self.shape.dimensions())
315                .map(|i| self.shape.dim_at(i))
316                .collect::<Vec<usize>>()
317                .as_slice(),
318            (0..self.shape.dimensions())
319                .map(|i| self.strides.stride_at(i))
320                .collect::<Vec<usize>>()
321                .as_slice(),
322            0,
323            0,
324        )?;
325
326        write!(f, ")")
327    }
328}
329
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_tensor_basic() {
        let t = Tensor::new(vec![1, 2, 3, 4, 5, 6], (2, 3));

        assert_eq!(t.rank(), 2);
        assert_eq!(t.dims(), &[2, 3]);
        assert_eq!(t.shape().as_slice(), &[2, 3]);
        // Row-major: the leading stride equals the trailing extent.
        assert_eq!(t.strides().as_slice(), &[3, 1]);
        assert_eq!(t.data(), &[1, 2, 3, 4, 5, 6]);
    }

    #[test]
    fn test_tensor_from_elem() {
        let filled = Tensor::from_elem((2, 2), 42);

        assert_eq!(filled.rank(), 2);
        assert_eq!(filled.dims(), &[2, 2]);
        assert_eq!(filled.shape().as_slice(), &[2, 2]);
        assert_eq!(filled.strides().as_slice(), &[2, 1]);
        assert_eq!(filled.data(), &[42; 4]);
    }

    #[test]
    fn test_try_new_len_mismatch_err() {
        // Three elements cannot fill a 2x2 shape.
        let error = Tensor::try_new(vec![1, 2, 3], (2, 2)).unwrap_err();

        if let TensorError::LenMismatch { len, expected } = &error {
            assert_eq!((*len, *expected), (3, 4));
        } else {
            panic!("expected LenMismatch, got: {:?}", error);
        }
    }

    #[test]
    fn test_reshape_updates_shape_and_strides() {
        let data: Vec<i32> = (0..6).collect();
        let mut t = Tensor::new(data, (2, 3));

        t.reshape((3, 2));

        assert_eq!(t.dims(), &[3, 2]);
        assert_eq!(t.strides().as_slice(), &[2, 1]); // row-major
        // Reshaping never reorders the flat buffer.
        assert_eq!(t.data(), &[0, 1, 2, 3, 4, 5]);
    }

    #[test]
    #[should_panic]
    fn test_reshape_panics_on_mismatched_size() {
        let mut six = Tensor::new(vec![0; 6], (2, 3));
        // 2 * 2 = 4 elements, but the tensor holds 6.
        six.reshape((2, 2));
    }

    #[test]
    fn test_from_elem_fills_correctly() {
        let sevens = Tensor::from_elem((2, 3), 7u32);

        assert_eq!(sevens.data(), &[7u32; 6]);
        assert_eq!(sevens.strides().as_slice(), &[3, 1]);
    }

    #[test]
    fn test_zeros_works_for_numeric() {
        let cube = Tensor::<i32>::zeros((2, 2, 2));

        assert_eq!(cube.data(), &[0; 8]);
        assert_eq!(cube.strides().as_slice(), &[4, 2, 1]);
    }

    #[test]
    fn test_as_raw_parts_consistency() {
        let t = Tensor::new(vec![10, 11, 12, 13], (2, 2));

        // Compares pointer identity and length only; nothing is dereferenced.
        assert_eq!(t.as_raw_parts(), (t.as_ptr(), 4));
    }
}
410}