ferrite/autograd/grad_fn/
transform.rs1use crate::{reduce_grad, tensor::*};
2use super::super::grad::*;
3
4
5#[derive(Debug)]
6pub struct PermuteGrad {
7 input: Tensor,
8 output: Tensor,
9}
10
11
12impl PermuteGrad {
13 pub fn new(input: &Tensor, output: &Tensor) -> Self {
14 PermuteGrad {
15 input: input.clone(),
16 output: output.clone(),
17 }
18 }
19}
20
21impl GradientFunction for PermuteGrad {
22 fn backward(&self) {
23 let out_grad = self.output.grad().unwrap();
24 let out_grad = out_grad.borrow();
25
26 if let Some(input_grad) = &self.input.grad() {
28 let input_shape = self.input.tensor().shape();
31 let output_shape = self.output.tensor().shape();
32
33 let mut permutation: Vec<usize> = Vec::new();
35 for i in 0..input_shape.len() {
36 for j in 0..output_shape.len() {
37 if input_shape[i] == output_shape[j] {
38 permutation.push(j);
39 break;
40 }
41 }
42 }
43
44 let mut inverse_perm = vec![0; permutation.len()];
46 for (i, &p) in permutation.iter().enumerate() {
47 inverse_perm[p] = i;
48 }
49
50 let mut grad_tensor = out_grad.clone();
52 grad_tensor.permute(&inverse_perm);
53
54 input_grad.borrow_mut().add_tensor_assign(&grad_tensor);
56 }
57
58 }
59
60 fn prev(&self) -> Vec<&Tensor> {
61 vec![&self.input]
62 }
63}