Skip to main content

llama_rs/tensor/
core.rs

1//! Tensor struct implementation
2
3use super::dtype::DType;
4use super::error::TensorError;
5use super::storage::TensorStorage;
6
/// Compute row-major (C-order) strides, measured in elements, for `shape`.
///
/// The innermost (last) dimension gets stride 1; every earlier dimension's
/// stride is the product of the sizes of all dimensions after it. An empty
/// shape produces an empty stride vector.
pub fn compute_strides(shape: &[usize]) -> Vec<usize> {
    let mut out = vec![1usize; shape.len()];
    // Walk axes from innermost to outermost, skipping axis 0 because its size
    // never contributes to any stride.
    for (axis, &extent) in shape.iter().enumerate().skip(1).rev() {
        out[axis - 1] = out[axis] * extent;
    }
    out
}
18
19/// A multi-dimensional tensor with typed storage
20#[derive(Debug, Clone)]
21pub struct Tensor {
22    storage: TensorStorage,
23    shape: Vec<usize>,
24    strides: Vec<usize>,
25    dtype: DType,
26    offset: usize,
27    /// Optional name for GPU weight lookup
28    name: Option<String>,
29}
30
31impl Tensor {
32    /// Create a new tensor from raw data bytes with the given shape and dtype
33    pub fn new(data: Vec<u8>, shape: Vec<usize>, dtype: DType) -> Result<Self, TensorError> {
34        let numel: usize = shape.iter().product();
35        let expected_size = dtype.size_for_elements(numel);
36
37        if data.len() != expected_size {
38            return Err(TensorError::SizeMismatch {
39                expected: expected_size,
40                got: data.len(),
41            });
42        }
43
44        let strides = compute_strides(&shape);
45
46        Ok(Self {
47            storage: TensorStorage::owned(data),
48            shape,
49            strides,
50            dtype,
51            offset: 0,
52            name: None,
53        })
54    }
55
56    /// Create a tensor from existing storage
57    ///
58    /// # Safety
59    /// The storage must contain valid data for the given shape and dtype.
60    /// The offset + size must not exceed the storage length.
61    pub unsafe fn from_storage(
62        storage: TensorStorage,
63        shape: Vec<usize>,
64        dtype: DType,
65        offset: usize,
66    ) -> Result<Self, TensorError> {
67        let numel: usize = shape.iter().product();
68        let required_size = dtype.size_for_elements(numel);
69
70        if offset + required_size > storage.len() {
71            return Err(TensorError::SizeMismatch {
72                expected: offset + required_size,
73                got: storage.len(),
74            });
75        }
76
77        let strides = compute_strides(&shape);
78
79        Ok(Self {
80            storage,
81            shape,
82            strides,
83            dtype,
84            offset,
85            name: None,
86        })
87    }
88
89    /// Create a tensor filled with zeros
90    pub fn zeros(shape: Vec<usize>, dtype: DType) -> Self {
91        let numel: usize = shape.iter().product();
92        let size = dtype.size_for_elements(numel);
93        let data = vec![0u8; size];
94        let strides = compute_strides(&shape);
95
96        Self {
97            storage: TensorStorage::owned(data),
98            shape,
99            strides,
100            dtype,
101            offset: 0,
102            name: None,
103        }
104    }
105
106    /// Create a tensor from f32 data
107    pub fn from_f32(data: &[f32], shape: Vec<usize>) -> Result<Self, TensorError> {
108        let numel: usize = shape.iter().product();
109
110        if data.len() != numel {
111            return Err(TensorError::ShapeMismatch {
112                expected: numel,
113                got: data.len(),
114            });
115        }
116
117        let bytes: Vec<u8> = data.iter().flat_map(|f| f.to_le_bytes()).collect();
118
119        Self::new(bytes, shape, DType::F32)
120    }
121
122    /// Get the shape of the tensor
123    pub fn shape(&self) -> &[usize] {
124        &self.shape
125    }
126
127    /// Get the number of dimensions
128    pub fn ndim(&self) -> usize {
129        self.shape.len()
130    }
131
132    /// Get the total number of elements
133    pub fn numel(&self) -> usize {
134        self.shape.iter().product()
135    }
136
137    /// Get the data type
138    pub fn dtype(&self) -> DType {
139        self.dtype
140    }
141
142    /// Get the strides
143    pub fn strides(&self) -> &[usize] {
144        &self.strides
145    }
146
147    /// Get the tensor name (for GPU weight lookup)
148    pub fn name(&self) -> Option<&str> {
149        self.name.as_deref()
150    }
151
152    /// Set the tensor name (for GPU weight lookup)
153    pub fn set_name(&mut self, name: impl Into<String>) {
154        self.name = Some(name.into());
155    }
156
157    /// Create a named tensor (builder pattern)
158    pub fn with_name(mut self, name: impl Into<String>) -> Self {
159        self.name = Some(name.into());
160        self
161    }
162
163    /// Get the raw byte data
164    pub fn data(&self) -> &[u8] {
165        let size = self.dtype.size_for_elements(self.numel());
166        &self.storage.as_bytes()[self.offset..self.offset + size]
167    }
168
169    /// Get mutable access to the raw byte data
170    pub fn data_mut(&mut self) -> Option<&mut [u8]> {
171        let size = self.dtype.size_for_elements(self.numel());
172        let offset = self.offset;
173        self.storage
174            .as_bytes_mut()
175            .map(|bytes| &mut bytes[offset..offset + size])
176    }
177
178    /// Get the data as f32 slice (only valid for F32 dtype)
179    pub fn as_f32(&self) -> Result<&[f32], TensorError> {
180        if self.dtype != DType::F32 {
181            return Err(TensorError::InvalidDType);
182        }
183        if !self.is_contiguous() {
184            return Err(TensorError::NotContiguous);
185        }
186
187        let data = self.data();
188        // SAFETY: We verified dtype is F32 and data is contiguous
189        let f32_slice =
190            unsafe { std::slice::from_raw_parts(data.as_ptr() as *const f32, self.numel()) };
191        Ok(f32_slice)
192    }
193
194    /// Get mutable access to data as f32 slice (only valid for F32 dtype)
195    pub fn as_f32_mut(&mut self) -> Result<&mut [f32], TensorError> {
196        if self.dtype != DType::F32 {
197            return Err(TensorError::InvalidDType);
198        }
199        if !self.is_contiguous() {
200            return Err(TensorError::NotContiguous);
201        }
202
203        let numel = self.numel();
204        let data = self.data_mut().ok_or(TensorError::NotContiguous)?;
205        // SAFETY: We verified dtype is F32 and data is contiguous
206        let f32_slice =
207            unsafe { std::slice::from_raw_parts_mut(data.as_mut_ptr() as *mut f32, numel) };
208        Ok(f32_slice)
209    }
210
211    /// Check if the tensor is contiguous in memory
212    pub fn is_contiguous(&self) -> bool {
213        if self.shape.is_empty() {
214            return true;
215        }
216
217        let expected_strides = compute_strides(&self.shape);
218        self.strides == expected_strides
219    }
220
221    /// Return a contiguous copy of this tensor if not already contiguous
222    pub fn contiguous(&self) -> Result<Self, TensorError> {
223        if self.is_contiguous() {
224            return Ok(self.clone());
225        }
226
227        // For non-contiguous tensors, we need to copy data
228        // This is only supported for non-quantized types
229        if self.dtype.is_quantized() {
230            return Err(TensorError::NotContiguous);
231        }
232
233        // Create a new contiguous tensor with the same data
234        let new_storage = self.storage.to_owned();
235        let new_strides = compute_strides(&self.shape);
236
237        Ok(Self {
238            storage: new_storage,
239            shape: self.shape.clone(),
240            strides: new_strides,
241            dtype: self.dtype,
242            offset: self.offset,
243            name: self.name.clone(),
244        })
245    }
246
247    /// Reshape the tensor to a new shape
248    ///
249    /// Returns a new tensor with the same data but different shape.
250    /// The new tensor has its own copy of the data to allow mutation.
251    pub fn reshape(&self, new_shape: Vec<usize>) -> Result<Self, TensorError> {
252        let old_numel: usize = self.shape.iter().product();
253        let new_numel: usize = new_shape.iter().product();
254
255        if old_numel != new_numel {
256            return Err(TensorError::ShapeMismatch {
257                expected: old_numel,
258                got: new_numel,
259            });
260        }
261
262        if !self.is_contiguous() {
263            return Err(TensorError::NotContiguous);
264        }
265
266        let new_strides = compute_strides(&new_shape);
267
268        // Create owned copy to allow mutation
269        Ok(Self {
270            storage: self.storage.to_owned(),
271            shape: new_shape,
272            strides: new_strides,
273            dtype: self.dtype,
274            offset: 0, // Reset offset since we copied the data
275            name: self.name.clone(),
276        })
277    }
278}
279
#[cfg(test)]
mod tests {
    use super::*;

    /// Build an F32 tensor for tests, panicking if construction fails.
    fn f32_tensor(values: &[f32], shape: &[usize]) -> Tensor {
        Tensor::from_f32(values, shape.to_vec()).unwrap()
    }

    #[test]
    fn test_compute_strides() {
        let cases: [(&[usize], &[usize]); 4] = [
            (&[], &[]),                // empty shape
            (&[5], &[1]),              // 1D
            (&[3, 4], &[4, 1]),        // 2D (row-major)
            (&[2, 3, 4], &[12, 4, 1]), // 3D
        ];
        for &(shape, expected) in cases.iter() {
            assert_eq!(compute_strides(shape), expected.to_vec());
        }
    }

    #[test]
    fn test_tensor_zeros() {
        let zeros = Tensor::zeros(vec![2, 3], DType::F32);
        assert_eq!(zeros.dtype(), DType::F32);
        assert_eq!((zeros.ndim(), zeros.numel()), (2, 6));
        assert_eq!(zeros.shape(), &[2, 3]);
        assert_eq!(zeros.strides(), &[3, 1]);
        assert!(zeros.is_contiguous());
    }

    #[test]
    fn test_tensor_from_f32() {
        let values = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
        let tensor = f32_tensor(&values, &[2, 3]);

        assert_eq!(tensor.shape(), &[2, 3]);
        assert_eq!(tensor.numel(), 6);
        assert_eq!(tensor.as_f32().unwrap(), &values[..]);
    }

    #[test]
    fn test_tensor_from_f32_shape_mismatch() {
        let outcome = Tensor::from_f32(&[1.0f32, 2.0, 3.0], vec![2, 3]);
        assert!(outcome.is_err());

        if let Err(TensorError::ShapeMismatch { expected, got }) = outcome {
            assert_eq!((expected, got), (6, 3));
        } else {
            panic!("Expected ShapeMismatch error");
        }
    }

    #[test]
    fn test_tensor_reshape() {
        let base = f32_tensor(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3]);

        let expectations: Vec<(Vec<usize>, Vec<usize>)> =
            vec![(vec![3, 2], vec![2, 1]), (vec![6], vec![1])];
        for (target, strides) in expectations {
            let view = base.reshape(target.clone()).unwrap();
            assert_eq!(view.shape(), &target[..]);
            assert_eq!(view.strides(), &strides[..]);
        }
    }

    #[test]
    fn test_tensor_reshape_invalid() {
        let base = f32_tensor(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3]);
        assert!(base.reshape(vec![2, 4]).is_err());
    }

    #[test]
    fn test_tensor_as_f32_mut() {
        let mut tensor = f32_tensor(&[1.0, 2.0, 3.0, 4.0], &[2, 2]);

        let slots = tensor.as_f32_mut().unwrap();
        slots[0] = 10.0;
        slots[3] = 40.0;

        assert_eq!(tensor.as_f32().unwrap(), &[10.0, 2.0, 3.0, 40.0]);
    }

    #[test]
    fn test_tensor_quantized_zeros() {
        let q = Tensor::zeros(vec![32], DType::Q4_0);
        assert_eq!(q.dtype(), DType::Q4_0);
        assert_eq!(q.shape(), &[32]);
        assert_eq!(q.numel(), 32);
        // Q4_0 packs each 32-element block into 18 bytes
        assert_eq!(q.data().len(), 18);
    }

    #[test]
    fn test_tensor_is_contiguous() {
        assert!(Tensor::zeros(vec![2, 3, 4], DType::F32).is_contiguous());
    }

    #[test]
    fn test_tensor_new_size_mismatch() {
        // Six F32 elements need 24 bytes; only 20 are supplied
        let outcome = Tensor::new(vec![0u8; 20], vec![2, 3], DType::F32);
        assert!(outcome.is_err());

        if let Err(TensorError::SizeMismatch { expected, got }) = outcome {
            assert_eq!((expected, got), (24, 20));
        } else {
            panic!("Expected SizeMismatch error");
        }
    }

    #[test]
    fn test_tensor_as_f32_wrong_dtype() {
        let half = Tensor::zeros(vec![4], DType::F16);
        assert!(matches!(half.as_f32(), Err(TensorError::InvalidDType)));
    }
}