// trueno_tensor/tensor.rs
//! Dense N-dimensional tensor.

use crate::error::TensorError;

5/// Dense tensor with arbitrary dimensions.
6///
7/// Row-major (C-order) storage. Shape `[d0, d1, ..., dk]` means
8/// element `[i0, i1, ..., ik]` is at offset `i0*stride[0] + i1*stride[1] + ...`.
9#[derive(Debug, Clone)]
10pub struct Tensor {
11    shape: Vec<usize>,
12    strides: Vec<usize>,
13    data: Vec<f32>,
14}
15
16impl Tensor {
17    /// Create a new tensor with the given shape and data.
18    ///
19    /// # Errors
20    ///
21    /// Returns error if data length doesn't match shape product.
22    pub fn new(shape: Vec<usize>, data: Vec<f32>) -> Result<Self, TensorError> {
23        let product: usize = shape.iter().product();
24        if data.len() != product {
25            return Err(TensorError::DataLengthMismatch {
26                len: data.len(),
27                shape: shape.clone(),
28                product,
29            });
30        }
31        let strides = compute_strides(&shape);
32        Ok(Self {
33            shape,
34            strides,
35            data,
36        })
37    }
38
39    /// Create a zero tensor with the given shape.
40    pub fn zeros(shape: Vec<usize>) -> Self {
41        let product: usize = shape.iter().product();
42        let strides = compute_strides(&shape);
43        Self {
44            shape,
45            strides,
46            data: vec![0.0; product],
47        }
48    }
49
50    /// Tensor shape.
51    pub fn shape(&self) -> &[usize] {
52        &self.shape
53    }
54
55    /// Number of dimensions (rank).
56    pub fn ndim(&self) -> usize {
57        self.shape.len()
58    }
59
60    /// Total number of elements.
61    pub fn len(&self) -> usize {
62        self.data.len()
63    }
64
65    /// Whether the tensor is empty.
66    pub fn is_empty(&self) -> bool {
67        self.data.is_empty()
68    }
69
70    /// Raw data slice.
71    pub fn data(&self) -> &[f32] {
72        &self.data
73    }
74
75    /// Mutable data slice.
76    pub fn data_mut(&mut self) -> &mut [f32] {
77        &mut self.data
78    }
79
80    /// Get element at multi-index.
81    pub fn get(&self, indices: &[usize]) -> f32 {
82        let offset = self.offset(indices);
83        self.data[offset]
84    }
85
86    /// Set element at multi-index.
87    pub fn set(&mut self, indices: &[usize], value: f32) {
88        let offset = self.offset(indices);
89        self.data[offset] = value;
90    }
91
92    /// Compute linear offset from multi-index.
93    fn offset(&self, indices: &[usize]) -> usize {
94        indices
95            .iter()
96            .zip(self.strides.iter())
97            .map(|(&i, &s)| i * s)
98            .sum()
99    }
100
101    /// Reshape tensor (must have same total elements).
102    ///
103    /// # Errors
104    ///
105    /// Returns error if new shape has different total elements.
106    pub fn reshape(&self, new_shape: Vec<usize>) -> Result<Self, TensorError> {
107        let new_product: usize = new_shape.iter().product();
108        if new_product != self.data.len() {
109            return Err(TensorError::ShapeMismatch {
110                expected: new_shape,
111                got: self.shape.clone(),
112            });
113        }
114        Self::new(new_shape, self.data.clone())
115    }
116
117    /// Transpose: permute dimensions according to `perm`.
118    pub fn transpose(&self, perm: &[usize]) -> Self {
119        let ndim = self.ndim();
120        let mut new_shape = vec![0usize; ndim];
121        for (i, &p) in perm.iter().enumerate() {
122            new_shape[i] = self.shape[p];
123        }
124        let new_product: usize = new_shape.iter().product();
125        let mut new_data = vec![0.0f32; new_product];
126        let new_strides = compute_strides(&new_shape);
127
128        // Iterate over all elements
129        let mut old_indices = vec![0usize; ndim];
130        for flat in 0..self.data.len() {
131            // Convert flat index to multi-index (old)
132            let mut rem = flat;
133            for d in 0..ndim {
134                old_indices[d] = rem / self.strides[d];
135                rem %= self.strides[d];
136            }
137
138            // Permute indices
139            let new_offset: usize = perm
140                .iter()
141                .enumerate()
142                .map(|(new_d, &old_d)| old_indices[old_d] * new_strides[new_d])
143                .sum();
144
145            new_data[new_offset] = self.data[flat];
146        }
147
148        Self {
149            shape: new_shape,
150            strides: new_strides,
151            data: new_data,
152        }
153    }
154}
155
156/// Compute row-major strides for a shape.
157fn compute_strides(shape: &[usize]) -> Vec<usize> {
158    let ndim = shape.len();
159    if ndim == 0 {
160        return vec![];
161    }
162    let mut strides = vec![1usize; ndim];
163    for i in (0..ndim - 1).rev() {
164        strides[i] = strides[i + 1] * shape[i + 1];
165    }
166    strides
167}