ferrite/autograd/grad_fn/
blas.rs1use crate::tensor::*;
2use super::super::grad::*;
3
4
5#[derive(Debug)]
6pub struct MatMulGrad {
7 lhs: Tensor,
8 rhs: Tensor,
9 output: Tensor,
10 trans_a: bool,
11 trans_b: bool
12}
13
14impl MatMulGrad {
15 pub fn new(lhs: &Tensor, rhs: &Tensor, output: &Tensor, trans_a: bool, trans_b: bool,) -> Self {
16 MatMulGrad {
17 lhs: lhs.clone(),
18 rhs: rhs.clone(),
19 output: output.clone(),
20 trans_a: trans_a,
21 trans_b: trans_b,
22 }
23 }
24}
25
26impl GradientFunction for MatMulGrad {
27 fn backward(&self) {
28 let out_grad = self.output.grad().unwrap();
29 let out_grad = out_grad.borrow();
30
31 if let Some(lhs_grad) = &self.lhs.grad() {
37 let grad_for_lhs = if !self.trans_b {
39 out_grad.matmul(self.rhs.tensor(), false, true)
40 } else {
41 out_grad.matmul(self.rhs.tensor(), false, false)
42 };
43 lhs_grad.borrow_mut().add_tensor_assign(&grad_for_lhs);
44 }
45
46 if let Some(rhs_grad) = &self.rhs.grad() {
47 let grad_for_rhs = if !self.trans_b {
49 self.lhs.tensor().matmul(&out_grad, true, false)
50 } else {
51 out_grad.matmul(&self.lhs.tensor(), true, false)
52 };
53 rhs_grad.borrow_mut().add_tensor_assign(&grad_for_rhs);
54 }
55 }
56
57 fn prev(&self) -> Vec<&Tensor> {
58 vec![&self.lhs, &self.rhs]
59 }
60}