// ferrite/tensor/device/cpu/kernels/transform.rs
use crate::*;
3impl TransformOps for CpuStorage {
4 fn apply_assign<F>(&mut self, op: F)
5 where
6 F: Fn(f32) -> f32,
7 {
8 let data = self.data().read().unwrap().iter()
9 .map(|a| op(*a))
10 .collect();
11
12 self.set_data(data);
13 }
14
15 fn elementwise_op_assign<F>(&mut self, other: &Self, op: F)
16 where
17 F: Fn(f32, f32) -> f32,
18 {
19 let total_elements = self.shape().iter().product();
20 let mut result = vec![0.0; total_elements];
21
22 let self_binding = self.data();
24 let self_data = self_binding.read().unwrap();
25 let other_binding = other.data();
26 let other_data = other_binding.read().unwrap();
27
28 let rank = self.shape().len();
30 let shape = self.shape();
31 let self_strides = self.stride();
32 let other_strides = other.stride();
33
34 let mut chunk_size = 1;
36 let mut contiguous_dims = 0;
37 for dim in (0..rank).rev() {
38 if self_strides[dim] == chunk_size && other_strides[dim] == chunk_size {
39 chunk_size *= shape[dim];
40 contiguous_dims += 1;
41 } else {
42 break;
43 }
44 }
45
46 let outer_dims = rank - contiguous_dims;
47 let mut indices = vec![0; outer_dims];
48
49 let chunks = total_elements / chunk_size;
51 for chunk_idx in 0..chunks {
52 let mut self_base_idx = 0;
54 let mut other_base_idx = 0;
55
56 for (dim, &idx) in indices.iter().enumerate() {
57 self_base_idx += idx * self_strides[dim];
58 other_base_idx += idx * other_strides[dim];
59 }
60
61 let result_start = chunk_idx * chunk_size;
63 for i in 0..chunk_size {
64 let self_val = self_data[self_base_idx + i];
65 let other_val = other_data[other_base_idx + i];
66 result[result_start + i] = op(self_val, other_val);
67 }
68
69 for dim in (0..outer_dims).rev() {
71 indices[dim] += 1;
72 if indices[dim] < shape[dim] {
73 break;
74 }
75 indices[dim] = 0;
76 }
77 }
78
79 self.set_data(result);
80 }
81
82 fn reshape(&mut self, new_shape: Vec<usize>) {
83 self.set_shape(new_shape);
84 }
85
86 fn scalar_op_assign<F>(&mut self, scalar: f32, op: F)
87 where
88 F: Fn(f32, f32) -> f32,
89 {
90 let data = self.data().read().unwrap().iter()
91 .map(|a| op(*a, scalar))
92 .collect();
93
94 self.set_data(data);
95 }
96
97 fn permute(&mut self, dims: &[usize]) {
98 let self_shape = self.shape();
99 let shape = dims.iter().map(|&i| self_shape[i]).collect();
100
101 let self_stride = self.stride();
102 let stride = dims.iter().map(|&i| self_stride[i]).collect();
103
104 self.set_shape(shape);
105 self.set_stride(stride);
106 }
107
108 fn flatten(&mut self) {
109 let shape: Vec<usize> = vec![self.shape().iter().product()];
110 let stride = vec![1];
111
112 self.set_shape(shape);
113 self.set_stride(stride);
114 }
115
116 fn squeeze(&mut self) {
117 let shape: Vec<usize> = self.shape().to_owned().iter().filter(|&&x| x != 1).cloned().collect();
119 let stride = Self::compute_strides(&shape);
120
121 self.set_shape(shape);
122 self.set_stride(stride);
123 }
124
125 fn unsqueeze(&mut self, dim: usize) {
126 let mut shape: Vec<usize> = self.shape().to_owned();
127 shape.insert(dim, 1);
128 let stride = Self::compute_strides(&shape);
129
130 self.set_shape(shape);
131 self.set_stride(stride);
132 }
133
134 fn apply<F>(&self, op: F) -> Self
135 where
136 F: Fn(f32) -> f32,
137 {
138 let data = self.data().read().unwrap().iter()
139 .map(|a| op(*a))
140 .collect();
141
142 Self::new(data, self.shape().clone())
143 }
144
145 fn elementwise_op<F>(&self, other: &Self, op: F) -> Self
146 where
147 F: Fn(f32, f32) -> f32,
148 {
149 let total_elements = self.shape().iter().product();
150 let mut result = vec![0.0; total_elements];
151
152 let self_binding = self.data();
154 let self_data = self_binding.read().unwrap();
155 let other_binding = other.data();
156 let other_data = other_binding.read().unwrap();
157
158 let rank = self.shape().len();
160 let shape = self.shape();
161 let self_strides = self.stride();
162 let other_strides = other.stride();
163
164 let mut chunk_size = 1;
166 let mut contiguous_dims = 0;
167 for dim in (0..rank).rev() {
168 if self_strides[dim] == chunk_size && other_strides[dim] == chunk_size {
169 chunk_size *= shape[dim];
170 contiguous_dims += 1;
171 } else {
172 break;
173 }
174 }
175
176 let outer_dims = rank - contiguous_dims;
177 let mut indices = vec![0; outer_dims];
178
179 let chunks = total_elements / chunk_size;
181 for chunk_idx in 0..chunks {
182 let mut self_base_idx = 0;
184 let mut other_base_idx = 0;
185
186 for (dim, &idx) in indices.iter().enumerate() {
187 self_base_idx += idx * self_strides[dim];
188 other_base_idx += idx * other_strides[dim];
189 }
190
191 let result_start = chunk_idx * chunk_size;
193 for i in 0..chunk_size {
194 let self_val = self_data[self_base_idx + i];
195 let other_val = other_data[other_base_idx + i];
196 result[result_start + i] = op(self_val, other_val);
197 }
198
199 for dim in (0..outer_dims).rev() {
201 indices[dim] += 1;
202 if indices[dim] < shape[dim] {
203 break;
204 }
205 indices[dim] = 0;
206 }
207 }
208
209 Self::new(result, self.shape().clone())
210 }
211
212 fn scalar_op<F>(&self, scalar: f32, op: F) -> Self
213 where
214 F: Fn(f32, f32) -> f32,
215 {
216 let data = self.data().read().unwrap().iter()
217 .map(|a| op(*a, scalar))
218 .collect();
219
220 Self::new(data, self.shape().clone())
221 }
222
223 fn sum_dim(&self, dims: &[bool]) -> Self {
224 if self.shape().len() == 1 && self.shape()[0] == 1 {
226 return self.clone();
227 }
228
229 let mut new_shape: Vec<usize> = self.shape().iter()
231 .zip(dims.iter().chain(std::iter::repeat(&false)))
232 .filter_map(|(&dim, &should_sum)| if !should_sum { Some(dim) } else { None })
233 .collect();
234
235 if new_shape.is_empty() {
237 let sum: f32 = self.data().read().unwrap().iter().sum();
238 return Self::new(vec![sum], vec![1]);
239 }
240
241 if new_shape.is_empty() {
243 new_shape.push(1);
244 }
245
246 let mut result = vec![0.0; new_shape.iter().product()];
247
248 let mut sum = 0.0;
250 let binding = self.data();
251 let data = binding.read().unwrap();
252 for i in 0..data.len() {
253 sum += data[i];
254 }
255 result[0] = sum;
256
257 Self::new(result, new_shape)
258 }
259
260 fn transpose(&self) -> Self {
261 if self.shape().len() != 2 { panic!("Must be 2-D Tensor (Matrix)"); }
263
264 let mut shape = self.shape().to_owned();
265 shape.reverse();
266
267 let mut stride = self.stride().to_owned();
268 stride.reverse();
269
270 Self::create(self.data(), shape, stride)
271 }
272
273 fn broadcast(&self, new_shape: &[usize]) -> Self {
274 let broadcast_shape = self.compute_broadcast_shape(new_shape);
276
277 let broadcast_strides = self.compute_broadcast_strides(&broadcast_shape);
279
280 Self::create(self.data(), broadcast_shape, broadcast_strides)
281 }
282
283 fn compute_broadcast_shape(&self, target_shape: &[usize]) -> Vec<usize> {
285 let self_rank = self.shape().len();
286 let target_rank = target_shape.len();
287 let max_rank = std::cmp::max(self_rank, target_rank);
288
289 let self_padded = self.pad_shape(max_rank);
291 let mut result_shape = Vec::with_capacity(max_rank);
292
293 for i in 0..max_rank {
295 let self_dim = self_padded[i];
296 let target_dim = if i >= max_rank - target_rank {
297 target_shape[i - (max_rank - target_rank)]
298 } else {
299 1
300 };
301
302 if self_dim == target_dim {
303 result_shape.push(self_dim);
304 } else if self_dim == 1 {
305 result_shape.push(target_dim);
306 } else if target_dim == 1 {
307 result_shape.push(self_dim);
308 } else {
309 panic!(
310 "Incompatible broadcast dimensions: {} and {}",
311 self_dim, target_dim
312 )
313 }
314 }
315
316 result_shape
317 }
318
319 fn compute_broadcast_strides(&self, broadcast_shape: &[usize]) -> Vec<usize> {
321 let self_rank = self.shape().len();
322 let broadcast_rank = broadcast_shape.len();
323 let rank_diff = broadcast_rank - self_rank;
324
325 let mut new_strides = vec![0; broadcast_rank];
326
327 for i in 0..self_rank {
329 let broadcast_idx = i + rank_diff;
330 if broadcast_shape[broadcast_idx] == self.shape()[i] {
331 new_strides[broadcast_idx] = self.stride()[i];
332 } else if self.shape()[i] == 1 {
333 new_strides[broadcast_idx] = 0; } else {
335 panic!("Invalid broadcast shape")
336 }
337 }
338
339 new_strides
340 }
341
342 fn pad_shape(&self, target_rank: usize) -> Vec<usize> {
344 let mut padded = vec![1; target_rank];
345 let rank_diff = target_rank - self.shape().len();
346 padded[rank_diff..].copy_from_slice(self.shape());
347 padded
348 }
349
350 fn broadcast_tensors(a: &Self, b: &Self) -> (Self, Self) {
351 let broadcast_shape = a.compute_broadcast_shape(b.shape());
352 let broadcast_a = a.broadcast(&broadcast_shape);
353 let broadcast_b = b.broadcast(&broadcast_shape);
354 (broadcast_a, broadcast_b)
355 }
356}