1use crate::array::TensorError;
2use crate::{Shape, Strides};
3#[cfg(feature = "serde")]
4use serde::{Deserialize, Serialize};
5use std::fmt::Debug;
6
/// An owned, dense N-dimensional tensor.
///
/// Elements live in a flat `Vec<T>`; `shape` holds the extent of each
/// dimension and `strides` the per-dimension element step, computed with
/// `Strides::from(&shape)` by every constructor and by `reshape` in this file.
///
/// Invariant upheld by the constructors (`new`, `try_new`, `from_elem`,
/// `zeros`) and by `reshape`: `data.len()` equals the product of the
/// dimensions in `shape`.
///
/// NOTE(review): the derived `Default` pairs empty `data` with
/// `Shape::default()`; whether that shape's size is 0 depends on `Shape` —
/// confirm the invariant holds for default tensors. With the `serde` feature,
/// `Deserialize` is derived and does not re-check the invariant either.
#[derive(Default)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct Tensor<T> {
    /// Flat element storage.
    pub(super) data: Vec<T>,
    /// Extent of each dimension.
    pub(super) shape: Shape,
    /// Per-dimension element step used to index into `data`.
    pub(super) strides: Strides,
}
16
17impl<T> Tensor<T> {
18 pub fn new(data: Vec<T>, shape: impl Into<Shape>) -> Self {
19 let shape = shape.into();
20 let strides = Strides::from(&shape);
21
22 let expected = shape.try_size().unwrap_or_else(|| {
23 panic!(
24 "Tensor::new: shape size overflow for dims={:?}",
25 shape.as_slice()
26 )
27 });
28
29 assert!(
30 data.len() == expected,
31 "Tensor::new: data.len()={} does not match shape product {}",
32 data.len(),
33 expected
34 );
35
36 Self {
37 data,
38 shape,
39 strides,
40 }
41 }
42
43 pub fn try_new(data: Vec<T>, shape: impl Into<Shape>) -> Result<Self, TensorError> {
44 let shape = shape.into();
45 let strides = Strides::from(&shape);
46
47 let expected = shape.try_size().ok_or_else(|| TensorError::ShapeOverflow {
48 dims: shape.as_slice().to_vec(),
49 })?;
50
51 if data.len() != expected {
52 return Err(TensorError::LenMismatch {
53 len: data.len(),
54 expected,
55 });
56 }
57
58 Ok(Self {
59 data,
60 shape,
61 strides,
62 })
63 }
64
65 #[inline]
78 pub fn rank(&self) -> usize {
79 self.shape.dimensions()
80 }
81
82 #[inline]
93 pub fn dims(&self) -> &[usize] {
94 self.shape.as_slice()
95 }
96
97 #[inline]
108 pub fn shape(&self) -> &Shape {
109 &self.shape
110 }
111
112 #[inline]
125 pub fn strides(&self) -> &Strides {
126 &self.strides
127 }
128
129 #[inline]
139 pub fn data(&self) -> &[T] {
140 &self.data
141 }
142
143 #[inline]
155 pub fn data_mut(&mut self) -> &mut [T] {
156 &mut self.data
157 }
158
159 #[inline]
170 pub fn is_empty(&self) -> bool {
171 self.data.is_empty()
172 }
173
174 pub fn len(&self) -> usize {
175 self.data.len()
176 }
177
178 pub fn clear(&mut self) {
179 self.data.clear();
180 }
181
182 #[inline]
184 pub fn as_ptr(&self) -> *const T {
185 self.data.as_ptr()
186 }
187
188 #[inline]
189 pub fn as_mut_ptr(&mut self) -> *mut T {
190 self.data.as_mut_ptr()
191 }
192
193 #[inline]
194 pub fn as_raw_parts(&self) -> (*const T, usize) {
195 (self.data.as_ptr(), self.data.len())
196 }
197
198 #[inline]
199 pub fn as_raw_parts_mut(&mut self) -> (*mut T, usize) {
200 (self.data.as_mut_ptr(), self.data.len())
201 }
202
203 #[inline]
205 pub fn iter(&self) -> std::slice::Iter<'_, T> {
206 self.data.iter()
207 }
208
209 #[inline]
210 pub fn iter_mut(&mut self) -> std::slice::IterMut<'_, T> {
211 self.data.iter_mut()
212 }
213
214 #[inline]
227 pub fn reshape(&mut self, new_shape: impl Into<Shape>) {
228 let new_shape = new_shape.into();
229 let expected = new_shape.try_size().unwrap_or_else(|| {
230 panic!(
231 "Tensor::reshape: shape size overflow for dims={:?}",
232 new_shape.as_slice()
233 )
234 });
235
236 assert!(
237 expected == self.data.len(),
238 "Tensor::reshape: new shape product {} != data.len() {}",
239 expected,
240 self.data.len()
241 );
242
243 self.shape = new_shape.clone();
244 self.strides = Strides::from(&new_shape);
245 }
246}
247
248impl<T: Clone> Tensor<T> {
249 pub fn from_elem(shape: impl Into<Shape>, value: T) -> Self {
250 let shape = shape.into();
251 let n = shape.try_size().unwrap_or_else(|| {
252 panic!(
253 "Tensor::from_elem: shape size overflow for dims={:?}",
254 shape.as_slice()
255 )
256 });
257
258 let data = vec![value; n];
259 Self::new(data, shape)
260 }
261}
262
263impl<T: Default + Clone> Tensor<T> {
264 pub fn zeros(shape: impl Into<Shape>) -> Self {
265 Self::from_elem(shape, T::default())
266 }
267}
268
269impl<T: Debug> Debug for Tensor<T> {
270 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
271 writeln!(f, "Tensor(shape={:?}, data=", self.shape.dimensions())?;
272
273 fn fmt_recursive<T: std::fmt::Debug>(
274 f: &mut std::fmt::Formatter<'_>,
275 data: &[T],
276 shape: &[usize],
277 strides: &[usize],
278 offset: usize,
279 depth: usize,
280 ) -> std::fmt::Result {
281 let indent = " ".repeat(depth);
282
283 if shape.len() == 1 {
284 write!(f, "{}[", indent)?;
286 for i in 0..shape[0] {
287 if i > 0 {
288 write!(f, ", ")?;
289 }
290 write!(f, "{:?}", data[offset + i * strides[0]])?;
291 }
292 write!(f, "]")?;
293 } else {
294 write!(f, "{}[", indent)?;
296 for i in 0..shape[0] {
297 if i > 0 {
298 writeln!(f, ",")?;
299 } else {
300 writeln!(f)?;
301 }
302
303 fmt_recursive(
304 f,
305 data,
306 &shape[1..],
307 &strides[1..],
308 offset + i * strides[0],
309 depth + 1,
310 )?;
311 }
312 writeln!(f)?;
313 write!(f, "{}]", indent)?;
314 }
315
316 Ok(())
317 }
318
319 fmt_recursive(
320 f,
321 &self.data,
322 (0..self.shape.dimensions())
323 .map(|i| self.shape.dim_at(i))
324 .collect::<Vec<usize>>()
325 .as_slice(),
326 (0..self.shape.dimensions())
327 .map(|i| self.strides.stride_at(i))
328 .collect::<Vec<usize>>()
329 .as_slice(),
330 0,
331 0,
332 )?;
333
334 write!(f, ")")
335 }
336}
337
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_tensor_basic() {
        let tensor = Tensor::new(vec![1, 2, 3, 4, 5, 6], (2, 3));

        assert_eq!(tensor.rank(), 2);
        assert_eq!(tensor.dims(), &[2, 3]);
        assert_eq!(tensor.shape().as_slice(), &[2, 3]);
        assert_eq!(tensor.strides().as_slice(), &[3, 1]);
        assert_eq!(tensor.data(), &[1, 2, 3, 4, 5, 6]);
    }

    #[test]
    #[should_panic(expected = "does not match shape product")]
    fn test_new_panics_on_len_mismatch() {
        let _ = Tensor::new(vec![1, 2, 3], (2, 2));
    }

    #[test]
    fn test_tensor_from_elem() {
        let tensor = Tensor::from_elem((2, 2), 42);

        assert_eq!(tensor.rank(), 2);
        assert_eq!(tensor.dims(), &[2, 2]);
        assert_eq!(tensor.shape().as_slice(), &[2, 2]);
        assert_eq!(tensor.strides().as_slice(), &[2, 1]);
        assert_eq!(tensor.data(), &[42, 42, 42, 42]);
    }

    #[test]
    fn test_try_new_ok() {
        let t = Tensor::try_new(vec![1, 2, 3, 4, 5, 6], (3, 2)).unwrap();

        assert_eq!(t.dims(), &[3, 2]);
        assert_eq!(t.strides().as_slice(), &[2, 1]);
        assert_eq!(t.data(), &[1, 2, 3, 4, 5, 6]);
    }

    #[test]
    fn test_try_new_len_mismatch_err() {
        let err = Tensor::try_new(vec![1, 2, 3], (2, 2)).unwrap_err();
        match err {
            TensorError::LenMismatch { len, expected } => {
                assert_eq!(len, 3);
                assert_eq!(expected, 4);
            }
            other => panic!("expected LenMismatch, got: {:?}", other),
        }
    }

    #[test]
    fn test_reshape_updates_shape_and_strides() {
        let mut t = Tensor::new((0..6).collect::<Vec<i32>>(), (2, 3));

        t.reshape((3, 2));

        assert_eq!(t.dims(), &[3, 2]);
        assert_eq!(t.strides().as_slice(), &[2, 1]);
        assert_eq!(t.data(), &[0, 1, 2, 3, 4, 5]);
    }

    #[test]
    #[should_panic(expected = "Tensor::reshape: new shape product")]
    fn test_reshape_panics_on_mismatched_size() {
        let mut t = Tensor::new(vec![0; 6], (2, 3));
        t.reshape((2, 2));
    }

    #[test]
    fn test_from_elem_fills_correctly() {
        let t = Tensor::from_elem((2, 3), 7u32);
        assert_eq!(t.data(), &[7, 7, 7, 7, 7, 7]);
        assert_eq!(t.strides().as_slice(), &[3, 1]);
    }

    #[test]
    fn test_zeros_works_for_numeric() {
        let t = Tensor::<i32>::zeros((2, 2, 2));
        assert_eq!(t.data(), &[0; 8]);
        assert_eq!(t.strides().as_slice(), &[4, 2, 1]);
    }

    #[test]
    fn test_len_and_is_empty() {
        let t = Tensor::new(vec![1, 2, 3, 4, 5, 6], (2, 3));
        assert_eq!(t.len(), 6);
        assert!(!t.is_empty());
    }

    #[test]
    fn test_iter_mut_updates_in_place() {
        let mut t = Tensor::new(vec![1, 2, 3, 4], (2, 2));
        t.iter_mut().for_each(|x| *x *= 10);
        assert_eq!(t.iter().copied().collect::<Vec<_>>(), vec![10, 20, 30, 40]);
    }

    #[test]
    fn test_as_raw_parts_consistency() {
        let t = Tensor::new(vec![10, 11, 12, 13], (2, 2));
        let (ptr, len) = t.as_raw_parts();
        assert_eq!(len, 4);
        assert_eq!(ptr, t.as_ptr());
    }
}