1use std::rc::Rc;
2
3use crate::{match_storage, match_storage_assign, DeviceStorage, PermuteGrad, Storage, Tensor};
4
5
/// Shape manipulation, broadcasting and closure-driven arithmetic shared by the
/// storage layer and the user-facing tensor type.
///
/// Methods taking `&mut self` (the `*_assign` family and the in-place shape
/// methods) modify the receiver; the remaining methods return a new value.
pub trait TransformOps {
    /// Returns a new value built by applying the unary function `op`
    /// (presumably to every element; the exact semantics are backend-defined).
    fn apply<F: Fn(f32) -> f32>(&self, op: F) -> Self;

    /// In-place counterpart of [`TransformOps::apply`].
    fn apply_assign<F: Fn(f32) -> f32>(&mut self, op: F);

    /// Combines `self` with `other` through the binary function `op` and returns
    /// the result as a new value.
    fn elementwise_op<F: Fn(f32, f32) -> f32>(&self, other: &Self, op: F) -> Self;

    /// Combines `self` with a single `scalar` through `op` and returns the result.
    ///
    /// NOTE(review): the argument order handed to `op` is decided by the backend
    /// implementation; confirm before relying on it for non-commutative ops.
    fn scalar_op<F: Fn(f32, f32) -> f32>(&self, scalar: f32, op: F) -> Self;

    /// In-place counterpart of [`TransformOps::elementwise_op`].
    fn elementwise_op_assign<F: Fn(f32, f32) -> f32>(&mut self, other: &Self, op: F);

    /// In-place counterpart of [`TransformOps::scalar_op`].
    fn scalar_op_assign<F: Fn(f32, f32) -> f32>(&mut self, scalar: f32, op: F);

    /// Sums over the dimensions selected by `dims`.
    ///
    /// NOTE(review): `dims` is presumably a per-axis mask (one `bool` per
    /// dimension); confirm against the backend implementation.
    fn sum_dim(&self, dims: &[bool]) -> Self;

    /// Replaces the shape of `self` with `new_shape` in place.
    fn reshape(&mut self, new_shape: Vec<usize>);

    /// Reorders the axes of `self` in place according to `dims`.
    fn permute(&mut self, dims: &[usize]);

    /// Returns a transposed copy of `self`.
    fn transpose(&self) -> Self;

    /// Collapses `self` into a single dimension in place.
    fn flatten(&mut self);

    /// Removes size-1 dimensions from `self` in place.
    fn squeeze(&mut self);

    /// Inserts a size-1 dimension at index `dim` in place.
    fn unsqueeze(&mut self, dim: usize);

    /// Returns `self` broadcast to `new_shape`.
    fn broadcast(&self, new_shape: &[usize]) -> Self;

    /// Computes the shape obtained by broadcasting `self` against `target_shape`.
    fn compute_broadcast_shape(&self, target_shape: &[usize]) -> Vec<usize>;

    /// Computes the strides to use when viewing `self` with `broadcast_shape`.
    fn compute_broadcast_strides(&self, broadcast_shape: &[usize]) -> Vec<usize>;

    /// Returns the shape of `self` left-padded with 1s up to `target_rank` dims.
    fn pad_shape(&self, target_rank: usize) -> Vec<usize>;

    /// Broadcasts `a` and `b` to a common shape and returns both results.
    fn broadcast_tensors(a: &Self, b: &Self) -> (Self, Self)
    where
        Self: Sized;
}
46
/// Dispatches a method call on a `Storage` value to its device backend and wraps
/// the backend's return value back into the matching `Storage` variant.
///
/// Forms:
/// * `unary <storage>, <method>, <args>...` calls `<method>(<args>...)` on the
///   backend held by `<storage>`.
/// * `binary <lhs>, <method>, <rhs>, <args>...` calls
///   `<method>(<rhs backend>, <args>...)` on the backend held by `<lhs>`.
///
/// Only the `Cpu` variant is handled. Any other device, or a pair of operands
/// that are not both `Cpu`, panics via `unimplemented!`.
macro_rules! match_storage {
    (unary $storage:expr, $method:ident $(, $extra:expr)*) => {
        match $storage {
            Storage::Cpu(backend) => Storage::Cpu(backend.$method($($extra),*)),
            _ => unimplemented!("Device not supported"),
        }
    };

    (binary $lhs:expr, $method:ident, $rhs:expr $(, $extra:expr)*) => {
        match ($lhs, $rhs) {
            (Storage::Cpu(lhs_backend), Storage::Cpu(rhs_backend)) => {
                Storage::Cpu(lhs_backend.$method(rhs_backend $(, $extra)*))
            }
            _ => unimplemented!("Cross-device operations not supported"),
        }
    };
}
66
67
68impl TransformOps for Storage {
69 fn apply<F>(&self, op: F) -> Self
70 where
71 F: Fn(f32) -> f32 {
72 match_storage!(unary self, apply, op)
73 }
74
75 fn apply_assign<F>(&mut self, op: F)
76 where
77 F: Fn(f32) -> f32 {
78 match_storage_assign!(unary self, apply_assign, op)
79 }
80
81 fn elementwise_op<F>(&self, other: &Self, op: F) -> Self
82 where
83 F: Fn(f32, f32) -> f32 {
84 match_storage!(binary self, elementwise_op, other, op)
85 }
86
87 fn scalar_op<F>(&self, scalar: f32, op: F) -> Self
88 where
89 F: Fn(f32, f32) -> f32 {
90 match_storage!(unary self, scalar_op, scalar, op)
91 }
92
93 fn elementwise_op_assign<F>(&mut self, other: &Self, op: F)
94 where
95 F: Fn(f32, f32) -> f32 {
96 match_storage_assign!(binary self, elementwise_op_assign, other, op)
97 }
98
99 fn scalar_op_assign<F>(&mut self, scalar: f32, op: F)
100 where
101 F: Fn(f32, f32) -> f32 {
102 match_storage_assign!(unary self, scalar_op_assign, scalar, op)
103 }
104
105 fn sum_dim(&self, dims: &[bool]) -> Self {
106 match_storage!(unary self, sum_dim, dims)
107 }
108
109 fn reshape(&mut self, new_shape: Vec<usize>) {
110 match_storage_assign!(unary self, reshape, new_shape)
111 }
112
113 fn permute(&mut self, dims: &[usize]) {
114 match_storage_assign!(unary self, permute, dims)
115 }
116
117 fn transpose(&self) -> Self {
118 match_storage!(unary self, transpose)
119 }
120
121 fn flatten(&mut self) {
122 match_storage_assign!(unary self, flatten)
123 }
124
125 fn squeeze(&mut self) {
126 match_storage_assign!(unary self, squeeze)
127 }
128
129 fn unsqueeze(&mut self, dim: usize) {
130 match_storage_assign!(unary self, unsqueeze, dim)
131 }
132
133 fn broadcast(&self, new_shape: &[usize]) -> Self {
134 match_storage!(unary self, broadcast, new_shape)
135 }
136
137 fn compute_broadcast_shape(&self, target_shape: &[usize]) -> Vec<usize> {
138 let self_rank = self.shape().len();
139 let target_rank = target_shape.len();
140 let max_rank = std::cmp::max(self_rank, target_rank);
141
142 let self_padded = self.pad_shape(max_rank);
144 let mut result_shape = Vec::with_capacity(max_rank);
145
146 for i in 0..max_rank {
148 let self_dim = self_padded[i];
149 let target_dim = if i >= max_rank - target_rank {
150 target_shape[i - (max_rank - target_rank)]
151 } else {
152 1
153 };
154
155 if self_dim == target_dim {
156 result_shape.push(self_dim);
157 } else if self_dim == 1 {
158 result_shape.push(target_dim);
159 } else if target_dim == 1 {
160 result_shape.push(self_dim);
161 } else {
162 panic!(
163 "Incompatible broadcast dimensions: {} and {}",
164 self_dim, target_dim
165 )
166 }
167 }
168
169 result_shape
170 }
171
172 fn compute_broadcast_strides(&self, broadcast_shape: &[usize]) -> Vec<usize> {
173 let self_rank = self.shape().len();
174 let broadcast_rank = broadcast_shape.len();
175 let rank_diff = broadcast_rank - self_rank;
176
177 let mut new_strides = vec![0; broadcast_rank];
178
179 for i in 0..self_rank {
181 let broadcast_idx = i + rank_diff;
182 if broadcast_shape[broadcast_idx] == self.shape()[i] {
183 new_strides[broadcast_idx] = self.stride()[i];
184 } else if self.shape()[i] == 1 {
185 new_strides[broadcast_idx] = 0; } else {
187 panic!("Invalid broadcast shape")
188 }
189 }
190
191 new_strides
192 }
193
194 fn pad_shape(&self, target_rank: usize) -> Vec<usize> {
195 let mut padded = vec![1; target_rank];
196 let rank_diff = target_rank - self.shape().len();
197 padded[rank_diff..].copy_from_slice(self.shape());
198 padded
199 }
200
201 fn broadcast_tensors(a: &Self, b: &Self) -> (Self, Self) where Self: Sized {
202 let broadcast_shape = a.compute_broadcast_shape(b.shape());
204
205 let broadcast_a = a.broadcast(&broadcast_shape);
207 let broadcast_b = b.broadcast(&broadcast_shape);
208
209 (broadcast_a, broadcast_b)
210 }
211}
212
213
214impl TransformOps for Tensor {
215 fn apply<F>(&self, op: F) -> Self
216 where
217 F: Fn(f32) -> f32 {
218 todo!()
219 }
220
221 fn apply_assign<F>(&mut self, op: F)
222 where
223 F: Fn(f32) -> f32 {
224 todo!()
225 }
226
227 fn elementwise_op<F>(&self, other: &Self, op: F) -> Self
228 where
229 F: Fn(f32, f32) -> f32 {
230 todo!()
231 }
232
233 fn scalar_op<F>(&self, scalar: f32, op: F) -> Self
234 where
235 F: Fn(f32, f32) -> f32 {
236 todo!()
237 }
238
239 fn elementwise_op_assign<F>(&mut self, other: &Self, op: F)
240 where
241 F: Fn(f32, f32) -> f32 {
242 todo!()
243 }
244
245 fn scalar_op_assign<F>(&mut self, scalar: f32, op: F)
246 where
247 F: Fn(f32, f32) -> f32 {
248 todo!()
249 }
250
251 fn sum_dim(&self, dims: &[bool]) -> Self {
252 todo!()
253 }
254
255 fn reshape(&mut self, new_shape: Vec<usize>) {
256 self.tensor_mut().set_shape(new_shape);
257 }
258
259 fn transpose(&self) -> Self {
260 let tensor = self.tensor().transpose();
263 let requires_grad = *self.requires_grad();
264 let mut result = Tensor::new(tensor, self.device(), requires_grad);
265
266 if requires_grad {
267 result.set_grad_fn(Some(Rc::new(PermuteGrad::new(self, &result))));
268 }
269
270 result
271 }
272
273 fn permute(&mut self, dims: &[usize]) {
274 self.tensor_mut().permute(dims);
275 }
276
277 fn flatten(&mut self) {
278 self.tensor_mut().flatten();
279 }
280
281 fn squeeze(&mut self) {
282 self.tensor_mut().squeeze();
283 }
284
285 fn unsqueeze(&mut self, dim: usize) {
286 self.tensor_mut().unsqueeze(dim);
287 }
288
289 fn broadcast(&self, new_shape: &[usize]) -> Self {
290 let tensor = self.tensor().broadcast(new_shape);
291
292 let requires_grad = *self.requires_grad();
294 let mut result = Tensor::new(tensor, self.device(), requires_grad);
295
296 if requires_grad {
299 result.set_grad_fn(self.grad_fn());
300 }
301
302 result
303 }
304
305 fn compute_broadcast_shape(&self, target_shape: &[usize]) -> Vec<usize> {
306 self.tensor().compute_broadcast_shape(target_shape)
307 }
308
309 fn compute_broadcast_strides(&self, broadcast_shape: &[usize]) -> Vec<usize> {
310 self.tensor().compute_broadcast_strides(broadcast_shape)
311 }
312
313 fn pad_shape(&self, target_rank: usize) -> Vec<usize> {
314 self.tensor().pad_shape(target_rank)
315 }
316
317 fn broadcast_tensors(a: &Self, b: &Self) -> (Self, Self) where Self: Sized {
318 todo!()
319 }
320}