// ferrite/tensor/ops/blas.rs
use std::rc::Rc;

use crate::*;

/// Matrix-multiplication operations, named after BLAS (`BlasOps`).
///
/// Implemented in this file for the low-level `Storage` and for the
/// autograd-aware `Tensor`.
pub trait BlasOps {
    /// Returns the matrix product of `self` and `other` as a new value.
    ///
    /// `trans_a` / `trans_b` are presumably the BLAS-style flags that select
    /// whether `self` (A) and `other` (B) are used transposed in the product.
    /// NOTE(review): confirm against the backend implementations.
    fn matmul(&self, other: &Self, trans_a: bool, trans_b: bool) -> Self;
}

10impl BlasOps for Storage {
11 fn matmul(&self, other: &Self, trans_a: bool, trans_b: bool) -> Self {
12 match_storage!(binary self, matmul, other, trans_a, trans_b)
13 }
14}
15
16impl BlasOps for Tensor {
17 fn matmul(&self, other: &Self, trans_a: bool, trans_b: bool) -> Self {
18 let tensor = self.tensor().matmul(other.tensor(), trans_a, trans_b);
19
20 let requires_grad = *self.requires_grad() || *other.requires_grad();
21 let mut result = Tensor::new(tensor, self.device(), requires_grad);
22
23 if requires_grad {
24 result.set_grad_fn(Some(Rc::new(MatMulGrad::new(
25 self,
26 other,
27 &result,
28 trans_a,
29 trans_b
30 ))));
31 }
32
33 result
34 }
35}