// hpt/backends/cpu/tensor_external/gemm.rs

1use std::borrow::{Borrow, BorrowMut};
2
3use crate::Tensor;
4use hpt_allocator::traits::{Allocator, AllocatorOutputRetrive};
5use hpt_allocator::Cpu;
6use hpt_common::error::base::TensorError;
7use hpt_traits::ops::binary::Gemm;
8use hpt_traits::tensor::CommonBounds;
9use hpt_types::{into_scalar::Cast, type_promote::NormalOut};
10
11type GemmOutput<A, B, const DEVICE: usize, A2> =
12    Tensor<<A as NormalOut<B>>::Output, Cpu, DEVICE, A2>;
13
14impl<A, B, A2, const DEVICE: usize> Gemm<Tensor<B, Cpu, DEVICE, A2>> for Tensor<A, Cpu, DEVICE, A2>
15where
16    A: CommonBounds + NormalOut<B> + Cast<<A as NormalOut<B>>::Output>,
17    B: CommonBounds + Cast<<A as NormalOut<B>>::Output>,
18    <A as NormalOut<B>>::Output: CommonBounds,
19    A2: Allocator,
20    A2::Output: AllocatorOutputRetrive,
21{
22    type Output = GemmOutput<A, B, DEVICE, A2>;
23
24    type OutputMeta = <A as NormalOut<B>>::Output;
25
26    type InplaceOutput = GemmOutput<A, B, DEVICE, A2>;
27
28    fn gemm(
29        &self,
30        rhs: Tensor<B, Cpu, DEVICE, A2>,
31        alpha: Self::OutputMeta,
32        beta: Self::OutputMeta,
33        conj_dst: bool,
34        conj_lhs: bool,
35        conj_rhs: bool,
36    ) -> Result<Self::Output, TensorError> {
37        Ok(self
38            .inner
39            .gemm(
40                rhs.inner.as_ref(),
41                alpha,
42                beta,
43                conj_dst,
44                conj_lhs,
45                conj_rhs,
46            )?
47            .into())
48    }
49    fn gemm_<U>(
50        &self,
51        rhs: Tensor<B, Cpu, DEVICE, A2>,
52        alpha: Self::OutputMeta,
53        beta: Self::OutputMeta,
54        conj_dst: bool,
55        conj_lhs: bool,
56        conj_rhs: bool,
57        out: U,
58    ) -> Result<Self::Output, TensorError>
59    where
60        U: Borrow<Self::InplaceOutput> + BorrowMut<Self::InplaceOutput>,
61    {
62        let out = out.borrow().inner.as_ref().clone();
63        Ok(self
64            .inner
65            .gemm_(
66                rhs.inner.as_ref(),
67                alpha,
68                beta,
69                conj_dst,
70                conj_lhs,
71                conj_rhs,
72                out,
73            )?
74            .into())
75    }
76}
77
78impl<A, B, A2, const DEVICE: usize> Gemm<&Tensor<B, Cpu, DEVICE, A2>> for Tensor<A, Cpu, DEVICE, A2>
79where
80    A: CommonBounds + NormalOut<B> + Cast<<A as NormalOut<B>>::Output>,
81    B: CommonBounds + Cast<<A as NormalOut<B>>::Output>,
82    <A as NormalOut<B>>::Output: CommonBounds,
83    A2: Allocator,
84    A2::Output: AllocatorOutputRetrive,
85{
86    type Output = GemmOutput<A, B, DEVICE, A2>;
87
88    type OutputMeta = <A as NormalOut<B>>::Output;
89
90    type InplaceOutput = GemmOutput<A, B, DEVICE, A2>;
91
92    fn gemm(
93        &self,
94        rhs: &Tensor<B, Cpu, DEVICE, A2>,
95        alpha: Self::OutputMeta,
96        beta: Self::OutputMeta,
97        conj_dst: bool,
98        conj_lhs: bool,
99        conj_rhs: bool,
100    ) -> Result<Self::Output, TensorError> {
101        Ok(self
102            .inner
103            .gemm(
104                rhs.inner.as_ref(),
105                alpha,
106                beta,
107                conj_dst,
108                conj_lhs,
109                conj_rhs,
110            )?
111            .into())
112    }
113
114    fn gemm_<U>(
115        &self,
116        rhs: &Tensor<B, Cpu, DEVICE, A2>,
117        alpha: Self::OutputMeta,
118        beta: Self::OutputMeta,
119        conj_dst: bool,
120        conj_lhs: bool,
121        conj_rhs: bool,
122        out: U,
123    ) -> Result<Self::Output, TensorError>
124    where
125        U: Borrow<Self::InplaceOutput> + BorrowMut<Self::InplaceOutput>,
126    {
127        let out = out.borrow().inner.as_ref().clone();
128        Ok(self
129            .inner
130            .gemm_(
131                rhs.inner.as_ref(),
132                alpha,
133                beta,
134                conj_dst,
135                conj_lhs,
136                conj_rhs,
137                out,
138            )?
139            .into())
140    }
141}