ferrite/tensor/ops/
blas.rs

1use std::rc::Rc;
2
3use crate::*;
4
/// BLAS-style operations shared by the types in this module.
///
/// This file implements the trait for `Storage` and for `Tensor`.
pub trait BlasOps {
  /// Combines `self` (left operand) with `other` (right operand) and returns
  /// the result as a new value of the implementing type.
  ///
  /// NOTE(review): `trans_a` / `trans_b` presumably ask for the left / right
  /// operand to be treated as transposed, as in BLAS GEMM; the exact meaning
  /// is decided by each implementor, so confirm against the backend.
  fn matmul(
    &self,
    other: &Self,
    trans_a: bool,
    trans_b: bool,
  ) -> Self;
}
8
9
10impl BlasOps for Storage {
11  fn matmul(&self, other: &Self, trans_a: bool, trans_b: bool) -> Self {
12    match_storage!(binary self, matmul, other, trans_a, trans_b)
13  }
14}
15
16impl BlasOps for Tensor {
17  fn matmul(&self, other: &Self, trans_a: bool, trans_b: bool) -> Self {
18    let tensor = self.tensor().matmul(other.tensor(), trans_a, trans_b);
19    
20    let requires_grad = *self.requires_grad() || *other.requires_grad();
21    let mut result = Tensor::new(tensor, self.device(), requires_grad);
22    
23    if requires_grad {
24      result.set_grad_fn(Some(Rc::new(MatMulGrad::new(
25        self,
26        other,
27        &result,
28        trans_a,
29        trans_b
30      ))));
31    }
32    
33    result
34  }
35}