1use crate::{Dim, NamedDims, NamedTensor};
2use burn_tensor::Tensor;
3use burn_tensor::backend::Backend;
4
/// Matrix multiplication where the result type is decided by the implementation.
///
/// `Rhs` is the type of the right-hand operand and `Out` is the type produced.
/// For named tensors, each implementation spells out which dimension layouts
/// may be multiplied together and which layout the result carries. A pair of
/// operand layouts with no implementation is therefore rejected at compile time.
pub trait Matmul<Rhs, Out> {
    /// Multiplies `self` (left-hand operand) by `rhs` (right-hand operand).
    ///
    /// Both operands are taken by value.
    fn matmul(self, rhs: Rhs) -> Out;
}
8
9impl<B: Backend, const D: usize, ND> NamedTensor<B, ND>
10where
11 ND: NamedDims<B, Tensor = Tensor<B, D>>,
12{
13 pub fn matmul<NamedDimsRhs, NamedDimsOut>(
21 self,
22 rhs: NamedTensor<B, NamedDimsRhs>,
23 ) -> NamedTensor<B, NamedDimsOut>
24 where
25 NamedDimsRhs: NamedDims<B, Tensor = Tensor<B, D>>,
26 NamedDimsOut: NamedDims<B, Tensor = Tensor<B, D>>,
27 Self: Matmul<NamedTensor<B, NamedDimsRhs>, NamedTensor<B, NamedDimsOut>>,
28 {
29 Matmul::matmul(self, rhs)
30 }
31}
32
33impl<B: Backend, X: Dim, Y: Dim, Z: Dim> Matmul<NamedTensor<B, (Y, Z)>, NamedTensor<B, (X, Z)>>
34 for NamedTensor<B, (X, Y)>
35{
36 fn matmul(self, rhs: NamedTensor<B, (Y, Z)>) -> NamedTensor<B, (X, Z)> {
37 NamedTensor::from_tensor(self.tensor.matmul(rhs.tensor))
38 }
39}
40
41impl<B: Backend, Batch: Dim, X: Dim, Y: Dim, Z: Dim>
42 Matmul<NamedTensor<B, (Batch, Y, Z)>, NamedTensor<B, (Batch, X, Z)>>
43 for NamedTensor<B, (Batch, X, Y)>
44{
45 fn matmul(self, rhs: NamedTensor<B, (Batch, Y, Z)>) -> NamedTensor<B, (Batch, X, Z)> {
46 NamedTensor::from_tensor(self.tensor.matmul(rhs.tensor))
47 }
48}
49
50impl<B: Backend, Batch1: Dim, Batch2: Dim, X: Dim, Y: Dim, Z: Dim>
51 Matmul<NamedTensor<B, (Batch1, Batch2, Y, Z)>, NamedTensor<B, (Batch1, Batch2, X, Z)>>
52 for NamedTensor<B, (Batch1, Batch2, X, Y)>
53{
54 fn matmul(
55 self,
56 rhs: NamedTensor<B, (Batch1, Batch2, Y, Z)>,
57 ) -> NamedTensor<B, (Batch1, Batch2, X, Z)> {
58 NamedTensor::from_tensor(self.tensor.matmul(rhs.tensor))
59 }
60}