1use super::dtype::DType;
4use super::error::TensorError;
5use super::storage::TensorStorage;
6
/// Returns the row-major (C-order) strides for `shape`, measured in elements.
///
/// The innermost axis always has stride 1, and every earlier axis's stride is
/// the product of all dimensions that follow it. An empty shape yields an empty
/// vector.
pub fn compute_strides(shape: &[usize]) -> Vec<usize> {
    let mut strides = vec![1usize; shape.len()];
    let mut running = 1usize;
    // Pair each stride slot (from the second-to-last axis backwards) with the
    // dimension just after it, so slot i receives shape[i + 1] * ... * shape[n - 1].
    // The outermost dimension shape[0] is never multiplied in.
    for (slot, &dim) in strides.iter_mut().rev().skip(1).zip(shape.iter().rev()) {
        running *= dim;
        *slot = running;
    }
    strides
}
18
/// An n-dimensional array of `dtype` elements backed by a byte buffer.
///
/// Every constructor in this file derives `strides` from `shape` via
/// `compute_strides`, so tensors built here are laid out row-major.
#[derive(Debug, Clone)]
pub struct Tensor {
    /// Backing bytes. When built with a nonzero offset this may hold more bytes
    /// than the tensor itself covers.
    storage: TensorStorage,
    /// Size of each dimension; the element count is the product of these.
    shape: Vec<usize>,
    /// Per-dimension step between consecutive indices, counted in elements.
    strides: Vec<usize>,
    /// Element type; determines how many bytes the tensor's data occupies.
    dtype: DType,
    /// Byte offset of the tensor's first byte within `storage`.
    offset: usize,
    /// Optional human-readable label.
    name: Option<String>,
}
30
31impl Tensor {
32 pub fn new(data: Vec<u8>, shape: Vec<usize>, dtype: DType) -> Result<Self, TensorError> {
34 let numel: usize = shape.iter().product();
35 let expected_size = dtype.size_for_elements(numel);
36
37 if data.len() != expected_size {
38 return Err(TensorError::SizeMismatch {
39 expected: expected_size,
40 got: data.len(),
41 });
42 }
43
44 let strides = compute_strides(&shape);
45
46 Ok(Self {
47 storage: TensorStorage::owned(data),
48 shape,
49 strides,
50 dtype,
51 offset: 0,
52 name: None,
53 })
54 }
55
56 pub unsafe fn from_storage(
62 storage: TensorStorage,
63 shape: Vec<usize>,
64 dtype: DType,
65 offset: usize,
66 ) -> Result<Self, TensorError> {
67 let numel: usize = shape.iter().product();
68 let required_size = dtype.size_for_elements(numel);
69
70 if offset + required_size > storage.len() {
71 return Err(TensorError::SizeMismatch {
72 expected: offset + required_size,
73 got: storage.len(),
74 });
75 }
76
77 let strides = compute_strides(&shape);
78
79 Ok(Self {
80 storage,
81 shape,
82 strides,
83 dtype,
84 offset,
85 name: None,
86 })
87 }
88
89 pub fn zeros(shape: Vec<usize>, dtype: DType) -> Self {
91 let numel: usize = shape.iter().product();
92 let size = dtype.size_for_elements(numel);
93 let data = vec![0u8; size];
94 let strides = compute_strides(&shape);
95
96 Self {
97 storage: TensorStorage::owned(data),
98 shape,
99 strides,
100 dtype,
101 offset: 0,
102 name: None,
103 }
104 }
105
106 pub fn from_f32(data: &[f32], shape: Vec<usize>) -> Result<Self, TensorError> {
108 let numel: usize = shape.iter().product();
109
110 if data.len() != numel {
111 return Err(TensorError::ShapeMismatch {
112 expected: numel,
113 got: data.len(),
114 });
115 }
116
117 let bytes: Vec<u8> = data.iter().flat_map(|f| f.to_le_bytes()).collect();
118
119 Self::new(bytes, shape, DType::F32)
120 }
121
122 pub fn shape(&self) -> &[usize] {
124 &self.shape
125 }
126
127 pub fn ndim(&self) -> usize {
129 self.shape.len()
130 }
131
132 pub fn numel(&self) -> usize {
134 self.shape.iter().product()
135 }
136
137 pub fn dtype(&self) -> DType {
139 self.dtype
140 }
141
142 pub fn strides(&self) -> &[usize] {
144 &self.strides
145 }
146
147 pub fn name(&self) -> Option<&str> {
149 self.name.as_deref()
150 }
151
152 pub fn set_name(&mut self, name: impl Into<String>) {
154 self.name = Some(name.into());
155 }
156
157 pub fn with_name(mut self, name: impl Into<String>) -> Self {
159 self.name = Some(name.into());
160 self
161 }
162
163 pub fn data(&self) -> &[u8] {
165 let size = self.dtype.size_for_elements(self.numel());
166 &self.storage.as_bytes()[self.offset..self.offset + size]
167 }
168
169 pub fn data_mut(&mut self) -> Option<&mut [u8]> {
171 let size = self.dtype.size_for_elements(self.numel());
172 let offset = self.offset;
173 self.storage
174 .as_bytes_mut()
175 .map(|bytes| &mut bytes[offset..offset + size])
176 }
177
178 pub fn as_f32(&self) -> Result<&[f32], TensorError> {
180 if self.dtype != DType::F32 {
181 return Err(TensorError::InvalidDType);
182 }
183 if !self.is_contiguous() {
184 return Err(TensorError::NotContiguous);
185 }
186
187 let data = self.data();
188 let f32_slice =
190 unsafe { std::slice::from_raw_parts(data.as_ptr() as *const f32, self.numel()) };
191 Ok(f32_slice)
192 }
193
194 pub fn as_f32_mut(&mut self) -> Result<&mut [f32], TensorError> {
196 if self.dtype != DType::F32 {
197 return Err(TensorError::InvalidDType);
198 }
199 if !self.is_contiguous() {
200 return Err(TensorError::NotContiguous);
201 }
202
203 let numel = self.numel();
204 let data = self.data_mut().ok_or(TensorError::NotContiguous)?;
205 let f32_slice =
207 unsafe { std::slice::from_raw_parts_mut(data.as_mut_ptr() as *mut f32, numel) };
208 Ok(f32_slice)
209 }
210
211 pub fn is_contiguous(&self) -> bool {
213 if self.shape.is_empty() {
214 return true;
215 }
216
217 let expected_strides = compute_strides(&self.shape);
218 self.strides == expected_strides
219 }
220
221 pub fn contiguous(&self) -> Result<Self, TensorError> {
223 if self.is_contiguous() {
224 return Ok(self.clone());
225 }
226
227 if self.dtype.is_quantized() {
230 return Err(TensorError::NotContiguous);
231 }
232
233 let new_storage = self.storage.to_owned();
235 let new_strides = compute_strides(&self.shape);
236
237 Ok(Self {
238 storage: new_storage,
239 shape: self.shape.clone(),
240 strides: new_strides,
241 dtype: self.dtype,
242 offset: self.offset,
243 name: self.name.clone(),
244 })
245 }
246
247 pub fn reshape(&self, new_shape: Vec<usize>) -> Result<Self, TensorError> {
252 let old_numel: usize = self.shape.iter().product();
253 let new_numel: usize = new_shape.iter().product();
254
255 if old_numel != new_numel {
256 return Err(TensorError::ShapeMismatch {
257 expected: old_numel,
258 got: new_numel,
259 });
260 }
261
262 if !self.is_contiguous() {
263 return Err(TensorError::NotContiguous);
264 }
265
266 let new_strides = compute_strides(&new_shape);
267
268 Ok(Self {
270 storage: self.storage.to_owned(),
271 shape: new_shape,
272 strides: new_strides,
273 dtype: self.dtype,
274 offset: 0, name: self.name.clone(),
276 })
277 }
278}
279
#[cfg(test)]
mod tests {
    use super::*;

    /// Little-endian encoding of `values`, matching what `Tensor::from_f32` stores.
    fn f32_bytes(values: &[f32]) -> Vec<u8> {
        values.iter().flat_map(|v| v.to_le_bytes()).collect()
    }

    #[test]
    fn test_compute_strides() {
        assert_eq!(compute_strides(&[]), Vec::<usize>::new());

        assert_eq!(compute_strides(&[5]), vec![1]);

        assert_eq!(compute_strides(&[3, 4]), vec![4, 1]);

        assert_eq!(compute_strides(&[2, 3, 4]), vec![12, 4, 1]);
    }

    #[test]
    fn test_tensor_zeros() {
        let t = Tensor::zeros(vec![2, 3], DType::F32);
        assert_eq!(t.shape(), &[2, 3]);
        assert_eq!(t.ndim(), 2);
        assert_eq!(t.numel(), 6);
        assert_eq!(t.dtype(), DType::F32);
        assert_eq!(t.strides(), &[3, 1]);
        assert!(t.is_contiguous());
    }

    #[test]
    fn test_tensor_scalar() {
        // A 0-d tensor holds exactly one element and has no strides.
        let t = Tensor::zeros(vec![], DType::F32);
        assert_eq!(t.ndim(), 0);
        assert_eq!(t.numel(), 1);
        assert!(t.strides().is_empty());
        assert!(t.is_contiguous());
        assert_eq!(t.data().len(), 4);
    }

    #[test]
    fn test_tensor_from_f32() {
        let data = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
        let t = Tensor::from_f32(&data, vec![2, 3]).unwrap();

        assert_eq!(t.shape(), &[2, 3]);
        assert_eq!(t.numel(), 6);

        let f32_data = t.as_f32().unwrap();
        assert_eq!(f32_data, &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
    }

    #[test]
    fn test_tensor_from_f32_shape_mismatch() {
        let data = vec![1.0f32, 2.0, 3.0];
        let result = Tensor::from_f32(&data, vec![2, 3]);
        assert!(result.is_err());

        match result {
            Err(TensorError::ShapeMismatch { expected, got }) => {
                assert_eq!(expected, 6);
                assert_eq!(got, 3);
            }
            _ => panic!("Expected ShapeMismatch error"),
        }
    }

    #[test]
    fn test_tensor_reshape() {
        let data = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
        let t = Tensor::from_f32(&data, vec![2, 3]).unwrap();

        let reshaped = t.reshape(vec![3, 2]).unwrap();
        assert_eq!(reshaped.shape(), &[3, 2]);
        assert_eq!(reshaped.strides(), &[2, 1]);

        let reshaped_1d = t.reshape(vec![6]).unwrap();
        assert_eq!(reshaped_1d.shape(), &[6]);
        assert_eq!(reshaped_1d.strides(), &[1]);
    }

    #[test]
    fn test_tensor_reshape_preserves_data() {
        let data = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
        let t = Tensor::from_f32(&data, vec![2, 3]).unwrap();

        let reshaped = t.reshape(vec![3, 2]).unwrap();
        assert_eq!(reshaped.data(), t.data());
        assert_eq!(reshaped.as_f32().unwrap(), data.as_slice());
    }

    #[test]
    fn test_tensor_reshape_invalid() {
        let data = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
        let t = Tensor::from_f32(&data, vec![2, 3]).unwrap();

        let result = t.reshape(vec![2, 4]);
        assert!(result.is_err());
    }

    #[test]
    fn test_tensor_as_f32_mut() {
        let data = vec![1.0f32, 2.0, 3.0, 4.0];
        let mut t = Tensor::from_f32(&data, vec![2, 2]).unwrap();

        {
            let f32_data = t.as_f32_mut().unwrap();
            f32_data[0] = 10.0;
            f32_data[3] = 40.0;
        }

        let f32_data = t.as_f32().unwrap();
        assert_eq!(f32_data, &[10.0, 2.0, 3.0, 40.0]);
    }

    #[test]
    fn test_tensor_quantized_zeros() {
        let t = Tensor::zeros(vec![32], DType::Q4_0);
        assert_eq!(t.shape(), &[32]);
        assert_eq!(t.numel(), 32);
        assert_eq!(t.dtype(), DType::Q4_0);
        assert_eq!(t.data().len(), 18);
    }

    #[test]
    fn test_tensor_is_contiguous() {
        let t = Tensor::zeros(vec![2, 3, 4], DType::F32);
        assert!(t.is_contiguous());
    }

    #[test]
    fn test_tensor_contiguous_on_contiguous_is_copy() {
        let t = Tensor::from_f32(&[1.0, 2.0, 3.0, 4.0], vec![2, 2]).unwrap();
        let c = t.contiguous().unwrap();
        assert_eq!(c.shape(), t.shape());
        assert_eq!(c.strides(), t.strides());
        assert_eq!(c.data(), t.data());
    }

    #[test]
    fn test_tensor_new_size_mismatch() {
        let data = vec![0u8; 20];
        let result = Tensor::new(data, vec![2, 3], DType::F32);
        assert!(result.is_err());

        match result {
            Err(TensorError::SizeMismatch { expected, got }) => {
                assert_eq!(expected, 24);
                assert_eq!(got, 20);
            }
            _ => panic!("Expected SizeMismatch error"),
        }
    }

    #[test]
    fn test_tensor_from_storage_with_offset() {
        // Four padding bytes precede the tensor's data; `data()` must skip them.
        let mut bytes = vec![0xAAu8; 4];
        bytes.extend(f32_bytes(&[1.0, 2.0]));
        let t = unsafe { Tensor::from_storage(TensorStorage::owned(bytes), vec![2], DType::F32, 4) }
            .unwrap();

        assert_eq!(t.numel(), 2);
        assert_eq!(t.data(), f32_bytes(&[1.0, 2.0]).as_slice());
    }

    #[test]
    fn test_tensor_from_storage_out_of_bounds() {
        let storage = TensorStorage::owned(vec![0u8; 8]);
        let result = unsafe { Tensor::from_storage(storage, vec![4], DType::F32, 0) };
        assert!(matches!(
            result,
            Err(TensorError::SizeMismatch { expected: 16, .. })
        ));
    }

    #[test]
    fn test_tensor_names() {
        let mut t = Tensor::zeros(vec![1], DType::F32);
        assert_eq!(t.name(), None);

        t.set_name("weight");
        assert_eq!(t.name(), Some("weight"));

        let renamed = t.with_name("bias");
        assert_eq!(renamed.name(), Some("bias"));
    }

    #[test]
    fn test_tensor_as_f32_wrong_dtype() {
        let t = Tensor::zeros(vec![4], DType::F16);
        let result = t.as_f32();
        assert!(matches!(result, Err(TensorError::InvalidDType)));
    }
}