// ferrite/autograd/grad_fn/transform.rs

use crate::{reduce_grad, tensor::*};
use super::super::grad::*;
3
4
/// Backward node for a `permute` operation.
///
/// Holds clones of the forward pass's input and output tensors so that
/// `backward` can route the output gradient back to the input's layout.
#[derive(Debug)]
pub struct PermuteGrad {
  input: Tensor,   // tensor the permute was applied to (gradient destination)
  output: Tensor,  // permuted result (gradient source)
}
10
11
12impl PermuteGrad {
13  pub fn new(input: &Tensor, output: &Tensor) -> Self {
14    PermuteGrad {
15      input: input.clone(),
16      output: output.clone(),
17    }
18  }
19}
20
21impl GradientFunction for PermuteGrad {
22  fn backward(&self) {
23    let out_grad = self.output.grad().unwrap();
24    let out_grad = out_grad.borrow();
25
26    // Get input gradient if it exists (it should since we're backpropagating)
27    if let Some(input_grad) = &self.input.grad() {
28      // Determine the permutation that was applied
29      // Compare input and output shapes/strides
30      let input_shape = self.input.tensor().shape();
31      let output_shape = self.output.tensor().shape();
32      
33      // Find the permutation by matching dimensions
34      let mut permutation: Vec<usize> = Vec::new();
35      for i in 0..input_shape.len() {
36        for j in 0..output_shape.len() {
37          if input_shape[i] == output_shape[j] {
38            permutation.push(j);
39            break;
40          }
41        }
42      }
43      
44      // Create inverse permutation array
45      let mut inverse_perm = vec![0; permutation.len()];
46      for (i, &p) in permutation.iter().enumerate() {
47        inverse_perm[p] = i;
48      }
49      
50      // Apply inverse permutation to gradient
51      let mut grad_tensor = out_grad.clone();
52      grad_tensor.permute(&inverse_perm);
53      
54      // Accumulate the gradient
55      input_grad.borrow_mut().add_tensor_assign(&grad_tensor);
56    }
57  
58  }
59
60  fn prev(&self) -> Vec<&Tensor> {
61    vec![&self.input]
62  }
63}