1use crate::array::TensorError;
2use crate::{Shape, Strides};
3#[cfg(feature = "serde")]
4use serde::{Deserialize, Serialize};
5use std::fmt::Debug;
6
7#[derive(Default)]
10#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
11pub struct Tensor<T> {
12 pub(super) data: Vec<T>,
13 pub(super) shape: Shape,
14 pub(super) strides: Strides,
15}
16
17impl<T> Tensor<T> {
18 pub fn new(data: Vec<T>, shape: impl Into<Shape>) -> Self {
19 let shape = shape.into();
20 let strides = Strides::from(&shape);
21
22 let expected = shape.try_size().unwrap_or_else(|| {
23 panic!(
24 "Tensor::new: shape size overflow for dims={:?}",
25 shape.as_slice()
26 )
27 });
28
29 assert!(
30 data.len() == expected,
31 "Tensor::new: data.len()={} does not match shape product {}",
32 data.len(),
33 expected
34 );
35
36 Self {
37 data,
38 shape,
39 strides,
40 }
41 }
42
43 pub fn try_new(data: Vec<T>, shape: impl Into<Shape>) -> Result<Self, TensorError> {
44 let shape = shape.into();
45 let strides = Strides::from(&shape);
46
47 let expected = shape.try_size().ok_or_else(|| TensorError::ShapeOverflow {
48 dims: shape.as_slice().to_vec(),
49 })?;
50
51 if data.len() != expected {
52 return Err(TensorError::LenMismatch {
53 len: data.len(),
54 expected,
55 });
56 }
57
58 Ok(Self {
59 data,
60 shape,
61 strides,
62 })
63 }
64
65 #[inline]
78 pub fn rank(&self) -> usize {
79 self.shape.dimensions()
80 }
81
82 #[inline]
93 pub fn dims(&self) -> &[usize] {
94 self.shape.as_slice()
95 }
96
97 #[inline]
108 pub fn shape(&self) -> &Shape {
109 &self.shape
110 }
111
112 #[inline]
125 pub fn strides(&self) -> &Strides {
126 &self.strides
127 }
128
129 #[inline]
139 pub fn data(&self) -> &[T] {
140 &self.data
141 }
142
143 #[inline]
155 pub fn data_mut(&mut self) -> &mut [T] {
156 &mut self.data
157 }
158
159 #[inline]
170 pub fn is_empty(&self) -> bool {
171 self.data.is_empty()
172 }
173
174 #[inline]
176 pub fn as_ptr(&self) -> *const T {
177 self.data.as_ptr()
178 }
179
180 #[inline]
181 pub fn as_mut_ptr(&mut self) -> *mut T {
182 self.data.as_mut_ptr()
183 }
184
185 #[inline]
186 pub fn as_raw_parts(&self) -> (*const T, usize) {
187 (self.data.as_ptr(), self.data.len())
188 }
189
190 #[inline]
191 pub fn as_raw_parts_mut(&mut self) -> (*mut T, usize) {
192 (self.data.as_mut_ptr(), self.data.len())
193 }
194
195 #[inline]
197 pub fn iter(&self) -> std::slice::Iter<'_, T> {
198 self.data.iter()
199 }
200
201 #[inline]
202 pub fn iter_mut(&mut self) -> std::slice::IterMut<'_, T> {
203 self.data.iter_mut()
204 }
205
206 #[inline]
219 pub fn reshape(&mut self, new_shape: impl Into<Shape>) {
220 let new_shape = new_shape.into();
221 let expected = new_shape.try_size().unwrap_or_else(|| {
222 panic!(
223 "Tensor::reshape: shape size overflow for dims={:?}",
224 new_shape.as_slice()
225 )
226 });
227
228 assert!(
229 expected == self.data.len(),
230 "Tensor::reshape: new shape product {} != data.len() {}",
231 expected,
232 self.data.len()
233 );
234
235 self.shape = new_shape.clone();
236 self.strides = Strides::from(&new_shape);
237 }
238}
239
240impl<T: Clone> Tensor<T> {
241 pub fn from_elem(shape: impl Into<Shape>, value: T) -> Self {
242 let shape = shape.into();
243 let n = shape.try_size().unwrap_or_else(|| {
244 panic!(
245 "Tensor::from_elem: shape size overflow for dims={:?}",
246 shape.as_slice()
247 )
248 });
249
250 let data = vec![value; n];
251 Self::new(data, shape)
252 }
253}
254
255impl<T: Default + Clone> Tensor<T> {
256 pub fn zeros(shape: impl Into<Shape>) -> Self {
257 Self::from_elem(shape, T::default())
258 }
259}
260
261impl<T: Debug> Debug for Tensor<T> {
262 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
263 writeln!(f, "Tensor(shape={:?}, data=", self.shape.dimensions())?;
264
265 fn fmt_recursive<T: std::fmt::Debug>(
266 f: &mut std::fmt::Formatter<'_>,
267 data: &[T],
268 shape: &[usize],
269 strides: &[usize],
270 offset: usize,
271 depth: usize,
272 ) -> std::fmt::Result {
273 let indent = " ".repeat(depth);
274
275 if shape.len() == 1 {
276 write!(f, "{}[", indent)?;
278 for i in 0..shape[0] {
279 if i > 0 {
280 write!(f, ", ")?;
281 }
282 write!(f, "{:?}", data[offset + i * strides[0]])?;
283 }
284 write!(f, "]")?;
285 } else {
286 write!(f, "{}[", indent)?;
288 for i in 0..shape[0] {
289 if i > 0 {
290 writeln!(f, ",")?;
291 } else {
292 writeln!(f)?;
293 }
294
295 fmt_recursive(
296 f,
297 data,
298 &shape[1..],
299 &strides[1..],
300 offset + i * strides[0],
301 depth + 1,
302 )?;
303 }
304 writeln!(f)?;
305 write!(f, "{}]", indent)?;
306 }
307
308 Ok(())
309 }
310
311 fmt_recursive(
312 f,
313 &self.data,
314 (0..self.shape.dimensions())
315 .map(|i| self.shape.dim_at(i))
316 .collect::<Vec<usize>>()
317 .as_slice(),
318 (0..self.shape.dimensions())
319 .map(|i| self.strides.stride_at(i))
320 .collect::<Vec<usize>>()
321 .as_slice(),
322 0,
323 0,
324 )?;
325
326 write!(f, ")")
327 }
328}
329
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_tensor_basic() {
        let t = Tensor::new(vec![1, 2, 3, 4, 5, 6], (2, 3));

        assert_eq!(t.rank(), 2);
        assert_eq!(t.dims(), &[2, 3]);
        assert_eq!(t.shape().as_slice(), &[2, 3]);
        assert_eq!(t.strides().as_slice(), &[3, 1]);
        assert_eq!(t.data(), &[1, 2, 3, 4, 5, 6]);
    }

    #[test]
    fn test_tensor_from_elem() {
        let filled = Tensor::from_elem((2, 2), 42);

        assert_eq!(filled.rank(), 2);
        assert_eq!(filled.dims(), &[2, 2]);
        assert_eq!(filled.shape().as_slice(), &[2, 2]);
        assert_eq!(filled.strides().as_slice(), &[2, 1]);
        assert_eq!(filled.data(), &[42, 42, 42, 42]);
    }

    #[test]
    fn test_try_new_len_mismatch_err() {
        let failure = Tensor::try_new(vec![1, 2, 3], (2, 2)).unwrap_err();

        match failure {
            TensorError::LenMismatch { len, expected } => {
                assert_eq!(len, 3);
                assert_eq!(expected, 4);
            }
            other => panic!("expected LenMismatch, got: {:?}", other),
        }
    }

    #[test]
    fn test_reshape_updates_shape_and_strides() {
        let values: Vec<i32> = (0..6).collect();
        let mut reshaped = Tensor::new(values, (2, 3));

        reshaped.reshape((3, 2));

        assert_eq!(reshaped.dims(), &[3, 2]);
        assert_eq!(reshaped.strides().as_slice(), &[2, 1]);
        assert_eq!(reshaped.data(), &[0, 1, 2, 3, 4, 5]);
    }

    #[test]
    #[should_panic]
    fn test_reshape_panics_on_mismatched_size() {
        let mut six = Tensor::new(vec![0; 6], (2, 3));
        six.reshape((2, 2));
    }

    #[test]
    fn test_from_elem_fills_correctly() {
        let sevens = Tensor::from_elem((2, 3), 7u32);

        assert_eq!(sevens.data(), &[7, 7, 7, 7, 7, 7]);
        assert_eq!(sevens.strides().as_slice(), &[3, 1]);
    }

    #[test]
    fn test_zeros_works_for_numeric() {
        let cube = Tensor::<i32>::zeros((2, 2, 2));

        assert_eq!(cube.data(), &[0; 8]);
        assert_eq!(cube.strides().as_slice(), &[4, 2, 1]);
    }

    #[test]
    fn test_as_raw_parts_consistency() {
        let square = Tensor::new(vec![10, 11, 12, 13], (2, 2));
        let (raw_ptr, raw_len) = square.as_raw_parts();

        assert_eq!(raw_len, 4);
        assert_eq!(raw_ptr, square.as_ptr());
    }
}