// ferrite/autograd/grad_fn/blas.rs

1use crate::tensor::*;
2use super::super::grad::*;
3
4
/// Backward (gradient) node for a matrix multiplication
/// `output = op_a(lhs) × op_b(rhs)`, where `op_a`/`op_b` transpose their
/// operand when the corresponding `trans_*` flag is set.
///
/// Stores clones of the participating tensors so `backward` can read the
/// output's accumulated gradient and accumulate into `lhs` / `rhs`.
#[derive(Debug)]
pub struct MatMulGrad {
  lhs: Tensor,    // left operand as passed to the forward matmul (pre-transpose)
  rhs: Tensor,    // right operand as passed to the forward matmul (pre-transpose)
  output: Tensor, // forward result; its .grad() seeds the backward pass
  trans_a: bool,  // was lhs transposed in the forward pass?
                  // NOTE(review): currently unread by backward() — see that impl
  trans_b: bool   // was rhs transposed in the forward pass?
}
13
14impl MatMulGrad {
15  pub fn new(lhs: &Tensor, rhs: &Tensor, output: &Tensor, trans_a: bool, trans_b: bool,) -> Self {
16    MatMulGrad {
17      lhs: lhs.clone(),
18      rhs: rhs.clone(),
19      output: output.clone(),
20      trans_a: trans_a,
21      trans_b: trans_b,
22    }
23  }
24}
25
26impl GradientFunction for MatMulGrad {
27  fn backward(&self) {
28    let out_grad = self.output.grad().unwrap();
29    let out_grad = out_grad.borrow();
30
31    // Case: C = t × w^T  (your test case)
32    // t: (2,3), w: (2,3), C: (2,2)
33    // dL/dt = dL/dC × w^T  - note: no transpose here since we already have w^T
34    // dL/dw = (dL/dC × t)^T
35
36    if let Some(lhs_grad) = &self.lhs.grad() {
37      // For input t: dL/dt = dL/dC × w^T
38      let grad_for_lhs = if !self.trans_b {
39        out_grad.matmul(self.rhs.tensor(), false, true)
40      } else {
41        out_grad.matmul(self.rhs.tensor(), false, false)
42      };
43      lhs_grad.borrow_mut().add_tensor_assign(&grad_for_lhs);
44    }
45
46    if let Some(rhs_grad) = &self.rhs.grad() {
47      // For weight w: dL/dw = (dL/dC × t)^T
48      let grad_for_rhs = if !self.trans_b {
49        self.lhs.tensor().matmul(&out_grad, true, false)
50      } else {
51        out_grad.matmul(&self.lhs.tensor(), true, false)
52      };
53      rhs_grad.borrow_mut().add_tensor_assign(&grad_for_rhs);
54    }
55  } 
56
57  fn prev(&self) -> Vec<&Tensor> {
58    vec![&self.lhs, &self.rhs]
59  }
60}