// ferrite/tensor/ops/transform.rs

1use std::rc::Rc;
2
3use crate::{match_storage, match_storage_assign, DeviceStorage, PermuteGrad, Storage, Tensor};
4
5
/// Elementwise, reduction, and shape-manipulation operations implemented by
/// both raw storage backends and autograd-tracked tensors.
pub trait TransformOps {
  /// Returns a new value with `op` applied to every element.
  fn apply<F>(&self, op: F) -> Self
  where
    F: Fn(f32) -> f32;

  /// Applies `op` to every element in place.
  fn apply_assign<F>(&mut self, op: F)
  where
    F: Fn(f32) -> f32;

  /// Returns the elementwise combination `op(self[i], other[i])`.
  fn elementwise_op<F>(&self, other: &Self, op: F) -> Self
  where
    F: Fn(f32, f32) -> f32;

  /// Returns a new value where each element is `op(element, scalar)`.
  fn scalar_op<F>(&self, scalar: f32, op: F) -> Self
  where
    F: Fn(f32, f32) -> f32;

  /// In-place elementwise combination with `other` via `op`.
  fn elementwise_op_assign<F>(&mut self, other: &Self, op: F)
  where
    F: Fn(f32, f32) -> f32;

  /// In-place `op(element, scalar)` for every element.
  fn scalar_op_assign<F>(&mut self, scalar: f32, op: F)
  where
    F: Fn(f32, f32) -> f32;

  /// Reduction by summation; `dims` carries one flag per dimension
  /// (presumably `true` marks a dimension to reduce — confirm in backend).
  fn sum_dim(&self, dims: &[bool]) -> Self;
  /// Replaces the shape with `new_shape`.
  fn reshape(&mut self, new_shape: Vec<usize>);
  /// Reorders dimensions according to the permutation `dims`.
  fn permute(&mut self, dims: &[usize]);
  /// Returns a transposed copy of `self`.
  fn transpose(&self) -> Self;
  /// Collapses all dimensions into one.
  fn flatten(&mut self);
  /// Removes size-1 dimensions.
  fn squeeze(&mut self);
  /// Inserts a size-1 dimension at position `dim`.
  fn unsqueeze(&mut self, dim: usize);

  /// Returns `self` expanded to `new_shape` using broadcasting rules.
  fn broadcast(&self, new_shape: &[usize]) -> Self;
  /// Computes the shape resulting from broadcasting `self` against
  /// `target_shape` (right-aligned, NumPy-style).
  fn compute_broadcast_shape(&self, target_shape: &[usize]) -> Vec<usize>;
  /// Computes strides for viewing `self` as `broadcast_shape`
  /// (stride 0 for dimensions expanded from size 1).
  fn compute_broadcast_strides(&self, broadcast_shape: &[usize]) -> Vec<usize>;
  /// Left-pads the shape with 1s up to `target_rank`.
  fn pad_shape(&self, target_rank: usize) -> Vec<usize>;
  /// Broadcasts `a` and `b` to their common shape and returns both.
  fn broadcast_tensors(a: &Self, b: &Self) -> (Self, Self)
  where
    Self: Sized;
}
46
// Dispatches a `Storage`-level method call to the concrete device backend.
// NOTE(review): `match_storage` is also imported from the crate root above
// (`use crate::{match_storage, ...}`); confirm this local definition does not
// shadow or conflict with that crate-level export.
macro_rules! match_storage {
  // Binary operation with two storage arguments.
  // Both operands must live on the same device; mixed-device pairs hit the
  // catch-all arm and abort via `unimplemented!`.
  (binary $self:expr, $method:ident, $other:expr $(, $args:expr)*) => {
    match ($self, $other) {
      (Storage::Cpu(cpu_self), Storage::Cpu(cpu_other)) => {
        Storage::Cpu(cpu_self.$method(cpu_other $(, $args)*))
      }
      _ => unimplemented!("Cross-device operations not supported"),
    }
  };

  // Unary operation with single storage argument.
  // Extra `$args` are forwarded to the backend method unchanged.
  (unary $self:expr, $method:ident $(, $args:expr)*) => {
    match $self {
      Storage::Cpu(cpu) => Storage::Cpu(cpu.$method($($args),*)),
      _ => unimplemented!("Device not supported"),
    }
  };
}
66
67
68impl TransformOps for Storage {
69  fn apply<F>(&self, op: F) -> Self
70  where
71    F: Fn(f32) -> f32 {
72    match_storage!(unary self, apply, op)
73  }
74
75  fn apply_assign<F>(&mut self, op: F)
76  where
77    F: Fn(f32) -> f32 {
78    match_storage_assign!(unary self, apply_assign, op)
79  }
80
81  fn elementwise_op<F>(&self, other: &Self, op: F) -> Self
82  where
83  F: Fn(f32, f32) -> f32 {
84    match_storage!(binary self, elementwise_op, other, op)
85  }
86
87  fn scalar_op<F>(&self, scalar: f32, op: F) -> Self
88  where
89  F: Fn(f32, f32) -> f32 {
90    match_storage!(unary self, scalar_op, scalar, op)
91  }
92
93  fn elementwise_op_assign<F>(&mut self, other: &Self, op: F)
94  where
95  F: Fn(f32, f32) -> f32 {
96    match_storage_assign!(binary self, elementwise_op_assign, other, op)
97  }
98
99  fn scalar_op_assign<F>(&mut self, scalar: f32, op: F)
100  where
101  F: Fn(f32, f32) -> f32 {
102    match_storage_assign!(unary self, scalar_op_assign, scalar, op)
103  }
104
105  fn sum_dim(&self, dims: &[bool]) -> Self {
106    match_storage!(unary self, sum_dim, dims)
107  }
108
109  fn reshape(&mut self, new_shape: Vec<usize>) {
110    match_storage_assign!(unary self, reshape, new_shape)
111  }
112
113  fn permute(&mut self, dims: &[usize]) {
114    match_storage_assign!(unary self, permute, dims)
115  }
116
117  fn transpose(&self) -> Self {
118    match_storage!(unary self, transpose)
119  }
120
121  fn flatten(&mut self) {
122    match_storage_assign!(unary self, flatten)
123  }
124
125  fn squeeze(&mut self) {
126    match_storage_assign!(unary self, squeeze)
127  }
128
129  fn unsqueeze(&mut self, dim: usize) {
130    match_storage_assign!(unary self, unsqueeze, dim)
131  }
132
133  fn broadcast(&self, new_shape: &[usize]) -> Self {
134    match_storage!(unary self, broadcast, new_shape)
135  }
136
137  fn compute_broadcast_shape(&self, target_shape: &[usize]) -> Vec<usize> {
138    let self_rank = self.shape().len();
139    let target_rank = target_shape.len();
140    let max_rank = std::cmp::max(self_rank, target_rank);
141    
142    // Pad shapes with 1s to match maximum rank
143    let self_padded = self.pad_shape(max_rank);
144    let mut result_shape = Vec::with_capacity(max_rank);
145
146    // Compare dimensions from right to left
147    for i in 0..max_rank {
148      let self_dim = self_padded[i];
149      let target_dim = if i >= max_rank - target_rank {
150          target_shape[i - (max_rank - target_rank)]
151        } else {
152          1
153        };
154
155      if self_dim == target_dim {
156        result_shape.push(self_dim);
157      } else if self_dim == 1 {
158        result_shape.push(target_dim);
159      } else if target_dim == 1 {
160        result_shape.push(self_dim);
161      } else {
162        panic!(
163          "Incompatible broadcast dimensions: {} and {}",
164          self_dim, target_dim
165        )
166      }
167    }
168
169    result_shape
170  }
171
172  fn compute_broadcast_strides(&self, broadcast_shape: &[usize]) -> Vec<usize> {
173    let self_rank = self.shape().len();
174    let broadcast_rank = broadcast_shape.len();
175    let rank_diff = broadcast_rank - self_rank;
176    
177    let mut new_strides = vec![0; broadcast_rank];
178    
179    // Fill in strides for dimensions that match or are broadcasted
180    for i in 0..self_rank {
181      let broadcast_idx = i + rank_diff;
182      if broadcast_shape[broadcast_idx] == self.shape()[i] {
183        new_strides[broadcast_idx] = self.stride()[i];
184      } else if self.shape()[i] == 1 {
185        new_strides[broadcast_idx] = 0;  // Stride of 0 for broadcasted dimensions
186      } else {
187        panic!("Invalid broadcast shape")
188      }
189    }
190
191    new_strides
192  }
193
194  fn pad_shape(&self, target_rank: usize) -> Vec<usize> {
195    let mut padded = vec![1; target_rank];
196    let rank_diff = target_rank - self.shape().len();
197    padded[rank_diff..].copy_from_slice(self.shape());
198    padded
199  }
200
201  fn broadcast_tensors(a: &Self, b: &Self) -> (Self, Self) where Self: Sized {
202    // Use a's compute_broadcast_shape to get the final shape
203    let broadcast_shape = a.compute_broadcast_shape(b.shape());
204
205    // Broadcast both tensors to the new shape
206    let broadcast_a = a.broadcast(&broadcast_shape);
207    let broadcast_b = b.broadcast(&broadcast_shape);
208    
209    (broadcast_a, broadcast_b)
210  }
211}
212
213
impl TransformOps for Tensor {
  // Autograd-aware elementwise apply — not yet implemented.
  fn apply<F>(&self, op: F) -> Self
  where
    F: Fn(f32) -> f32 {
    todo!()
  }

  // In-place apply — not yet implemented.
  fn apply_assign<F>(&mut self, op: F)
  where
    F: Fn(f32) -> f32 {
    todo!()
  }

  // Autograd-aware elementwise binary op — not yet implemented.
  fn elementwise_op<F>(&self, other: &Self, op: F) -> Self
  where
  F: Fn(f32, f32) -> f32 {
    todo!()
  }

  // Autograd-aware scalar op — not yet implemented.
  fn scalar_op<F>(&self, scalar: f32, op: F) -> Self
  where
    F: Fn(f32, f32) -> f32 {
    todo!()
  }

  // In-place elementwise binary op — not yet implemented.
  fn elementwise_op_assign<F>(&mut self, other: &Self, op: F)
  where
    F: Fn(f32, f32) -> f32 {
    todo!()
  }

  // In-place scalar op — not yet implemented.
  fn scalar_op_assign<F>(&mut self, scalar: f32, op: F)
  where
    F: Fn(f32, f32) -> f32 {
    todo!()
  }

  // Summation reduction — not yet implemented.
  fn sum_dim(&self, dims: &[bool]) -> Self {
    todo!()
  }

  // Overwrites the shape of the underlying tensor in place.
  // NOTE(review): no element-count validation is visible here — presumably
  // `set_shape` checks that the product of dims is unchanged; confirm.
  fn reshape(&mut self, new_shape: Vec<usize>) {
    self.tensor_mut().set_shape(new_shape);
  }

  // Returns a transposed view of this tensor. When gradients are tracked,
  // the result is wired into the graph via a `PermuteGrad` node.
  fn transpose(&self) -> Self {
    // Transpose by swapping dimensions & strides

    let tensor = self.tensor().transpose();
    let requires_grad = *self.requires_grad();
    let mut result = Tensor::new(tensor, self.device(), requires_grad);
    
    if requires_grad {
      result.set_grad_fn(Some(Rc::new(PermuteGrad::new(self, &result))));
    }
    
    result
  }

  // Permutes dimensions of the underlying tensor in place.
  // NOTE(review): unlike `transpose`, no grad node is recorded here — the
  // in-place mutators below bypass autograd tracking entirely; confirm
  // that is intentional.
  fn permute(&mut self, dims: &[usize]) {
    self.tensor_mut().permute(dims);
  }

  // Collapses all dimensions into one, in place (no grad tracking).
  fn flatten(&mut self) {
    self.tensor_mut().flatten();
  }

  // Removes size-1 dimensions, in place (no grad tracking).
  fn squeeze(&mut self) {
    self.tensor_mut().squeeze();
  } 

  // Inserts a size-1 dimension at `dim`, in place (no grad tracking).
  fn unsqueeze(&mut self, dim: usize) {
    self.tensor_mut().unsqueeze(dim);
  }

  // Expands this tensor to `new_shape` using broadcasting rules.
  fn broadcast(&self, new_shape: &[usize]) -> Self {
    let tensor = self.tensor().broadcast(new_shape);
    
    // When broadcasting, we need to maintain the gradient tracking
    let requires_grad = *self.requires_grad();
    let mut result = Tensor::new(tensor, self.device(), requires_grad);
    
    // If original tensor requires gradient, the broadcasted tensor
    // should have the same gradient function
    // NOTE(review): copying `self.grad_fn()` makes the result share the
    // source's grad node instead of installing a broadcast-aware one (cf.
    // `transpose`, which creates a fresh `PermuteGrad`). Verify gradients
    // reduce correctly back to the pre-broadcast shape.
    if requires_grad {
      result.set_grad_fn(self.grad_fn());
    }
    
    result
  }

  // Shape-only helpers below delegate straight to the underlying tensor;
  // they allocate no data and touch no gradient state.
  fn compute_broadcast_shape(&self, target_shape: &[usize]) -> Vec<usize> {
    self.tensor().compute_broadcast_shape(target_shape)
  }

  fn compute_broadcast_strides(&self, broadcast_shape: &[usize]) -> Vec<usize> {
    self.tensor().compute_broadcast_strides(broadcast_shape)
  }

  fn pad_shape(&self, target_rank: usize) -> Vec<usize> {
    self.tensor().pad_shape(target_rank)
  }

  // Pairwise broadcast at the autograd level — not yet implemented.
  fn broadcast_tensors(a: &Self, b: &Self) -> (Self, Self) where Self: Sized {
    todo!()
  }
}