Skip to main content

burn_named_tensor/
matmul.rs

1use crate::{Dim, NamedDims, NamedTensor};
2use burn_tensor::Tensor;
3use burn_tensor::backend::Backend;
4
/// Matrix multiplication, parameterized by the right-hand operand and the result type.
///
/// `Rhs` is the type of the right-hand operand and `Out` the type of the product.
/// Implementing this trait for a particular combination of named-dimension tuples declares
/// which products are well-typed. A mismatch between dimension *names* is therefore rejected
/// at compile time, because no implementation exists for it.
pub trait Matmul<Rhs, Out> {
    /// Computes the matrix product `self × rhs`, consuming both operands.
    fn matmul(self, rhs: Rhs) -> Out;
}
8
9impl<B: Backend, const D: usize, ND> NamedTensor<B, ND>
10where
11    ND: NamedDims<B, Tensor = Tensor<B, D>>,
12{
13    /// Applies the matrix multiplication operation.
14    ///
15    /// `C = AB`
16    ///
17    /// # Panics
18    ///
19    /// If the two tensors dont' have a compatible shape.
20    pub fn matmul<NamedDimsRhs, NamedDimsOut>(
21        self,
22        rhs: NamedTensor<B, NamedDimsRhs>,
23    ) -> NamedTensor<B, NamedDimsOut>
24    where
25        NamedDimsRhs: NamedDims<B, Tensor = Tensor<B, D>>,
26        NamedDimsOut: NamedDims<B, Tensor = Tensor<B, D>>,
27        Self: Matmul<NamedTensor<B, NamedDimsRhs>, NamedTensor<B, NamedDimsOut>>,
28    {
29        Matmul::matmul(self, rhs)
30    }
31}
32
33impl<B: Backend, X: Dim, Y: Dim, Z: Dim> Matmul<NamedTensor<B, (Y, Z)>, NamedTensor<B, (X, Z)>>
34    for NamedTensor<B, (X, Y)>
35{
36    fn matmul(self, rhs: NamedTensor<B, (Y, Z)>) -> NamedTensor<B, (X, Z)> {
37        NamedTensor::from_tensor(self.tensor.matmul(rhs.tensor))
38    }
39}
40
41impl<B: Backend, Batch: Dim, X: Dim, Y: Dim, Z: Dim>
42    Matmul<NamedTensor<B, (Batch, Y, Z)>, NamedTensor<B, (Batch, X, Z)>>
43    for NamedTensor<B, (Batch, X, Y)>
44{
45    fn matmul(self, rhs: NamedTensor<B, (Batch, Y, Z)>) -> NamedTensor<B, (Batch, X, Z)> {
46        NamedTensor::from_tensor(self.tensor.matmul(rhs.tensor))
47    }
48}
49
50impl<B: Backend, Batch1: Dim, Batch2: Dim, X: Dim, Y: Dim, Z: Dim>
51    Matmul<NamedTensor<B, (Batch1, Batch2, Y, Z)>, NamedTensor<B, (Batch1, Batch2, X, Z)>>
52    for NamedTensor<B, (Batch1, Batch2, X, Y)>
53{
54    fn matmul(
55        self,
56        rhs: NamedTensor<B, (Batch1, Batch2, Y, Z)>,
57    ) -> NamedTensor<B, (Batch1, Batch2, X, Z)> {
58        NamedTensor::from_tensor(self.tensor.matmul(rhs.tensor))
59    }
60}