1use crate::error::TensorError;
4
/// An owned, dense n-dimensional array of `f32` values.
///
/// All elements live in a single flat buffer. `strides[axis]` is the distance,
/// in elements of that buffer, between neighbouring indices along `axis`.
#[derive(Debug, Clone)]
pub struct Tensor {
    /// Extent of each axis; empty for a scalar (0-d) tensor.
    shape: Vec<usize>,
    /// Per-axis step through `data`, one entry per axis of `shape`.
    strides: Vec<usize>,
    /// Flat element storage.
    data: Vec<f32>,
}
15
16impl Tensor {
17 pub fn new(shape: Vec<usize>, data: Vec<f32>) -> Result<Self, TensorError> {
23 let product: usize = shape.iter().product();
24 if data.len() != product {
25 return Err(TensorError::DataLengthMismatch {
26 len: data.len(),
27 shape: shape.clone(),
28 product,
29 });
30 }
31 let strides = compute_strides(&shape);
32 Ok(Self {
33 shape,
34 strides,
35 data,
36 })
37 }
38
39 pub fn zeros(shape: Vec<usize>) -> Self {
41 let product: usize = shape.iter().product();
42 let strides = compute_strides(&shape);
43 Self {
44 shape,
45 strides,
46 data: vec![0.0; product],
47 }
48 }
49
50 pub fn shape(&self) -> &[usize] {
52 &self.shape
53 }
54
55 pub fn ndim(&self) -> usize {
57 self.shape.len()
58 }
59
60 pub fn len(&self) -> usize {
62 self.data.len()
63 }
64
65 pub fn is_empty(&self) -> bool {
67 self.data.is_empty()
68 }
69
70 pub fn data(&self) -> &[f32] {
72 &self.data
73 }
74
75 pub fn data_mut(&mut self) -> &mut [f32] {
77 &mut self.data
78 }
79
80 pub fn get(&self, indices: &[usize]) -> f32 {
82 let offset = self.offset(indices);
83 self.data[offset]
84 }
85
86 pub fn set(&mut self, indices: &[usize], value: f32) {
88 let offset = self.offset(indices);
89 self.data[offset] = value;
90 }
91
92 fn offset(&self, indices: &[usize]) -> usize {
94 indices
95 .iter()
96 .zip(self.strides.iter())
97 .map(|(&i, &s)| i * s)
98 .sum()
99 }
100
101 pub fn reshape(&self, new_shape: Vec<usize>) -> Result<Self, TensorError> {
107 let new_product: usize = new_shape.iter().product();
108 if new_product != self.data.len() {
109 return Err(TensorError::ShapeMismatch {
110 expected: new_shape,
111 got: self.shape.clone(),
112 });
113 }
114 Self::new(new_shape, self.data.clone())
115 }
116
117 pub fn transpose(&self, perm: &[usize]) -> Self {
119 let ndim = self.ndim();
120 let mut new_shape = vec![0usize; ndim];
121 for (i, &p) in perm.iter().enumerate() {
122 new_shape[i] = self.shape[p];
123 }
124 let new_product: usize = new_shape.iter().product();
125 let mut new_data = vec![0.0f32; new_product];
126 let new_strides = compute_strides(&new_shape);
127
128 let mut old_indices = vec![0usize; ndim];
130 for flat in 0..self.data.len() {
131 let mut rem = flat;
133 for d in 0..ndim {
134 old_indices[d] = rem / self.strides[d];
135 rem %= self.strides[d];
136 }
137
138 let new_offset: usize = perm
140 .iter()
141 .enumerate()
142 .map(|(new_d, &old_d)| old_indices[old_d] * new_strides[new_d])
143 .sum();
144
145 new_data[new_offset] = self.data[flat];
146 }
147
148 Self {
149 shape: new_shape,
150 strides: new_strides,
151 data: new_data,
152 }
153 }
154}
155
/// Computes contiguous row-major (C-order) strides for `shape`.
///
/// The last axis gets stride 1, and each earlier axis's stride is the product
/// of every extent to its right, so `[2, 3, 4]` yields `[12, 4, 1]`. An empty
/// shape yields an empty stride vector. A zero extent makes every stride to its
/// left 0; that only arises for shapes describing zero elements.
fn compute_strides(shape: &[usize]) -> Vec<usize> {
    let mut strides = vec![1usize; shape.len()];
    // Walk right-to-left: the stride of `axis - 1` is the stride of `axis`
    // times the extent of `axis`. The extent of axis 0 is never multiplied in.
    for axis in (1..shape.len()).rev() {
        strides[axis - 1] = strides[axis] * shape[axis];
    }
    strides
}