hpt/backends/cpu/tensor_external/
gemm.rs1use std::borrow::{Borrow, BorrowMut};
2
3use crate::Tensor;
4use hpt_allocator::traits::{Allocator, AllocatorOutputRetrive};
5use hpt_allocator::Cpu;
6use hpt_common::error::base::TensorError;
7use hpt_traits::ops::binary::Gemm;
8use hpt_traits::tensor::CommonBounds;
9use hpt_types::{into_scalar::Cast, type_promote::NormalOut};
10
11type GemmOutput<A, B, const DEVICE: usize, A2> =
12 Tensor<<A as NormalOut<B>>::Output, Cpu, DEVICE, A2>;
13
14impl<A, B, A2, const DEVICE: usize> Gemm<Tensor<B, Cpu, DEVICE, A2>> for Tensor<A, Cpu, DEVICE, A2>
15where
16 A: CommonBounds + NormalOut<B> + Cast<<A as NormalOut<B>>::Output>,
17 B: CommonBounds + Cast<<A as NormalOut<B>>::Output>,
18 <A as NormalOut<B>>::Output: CommonBounds,
19 A2: Allocator,
20 A2::Output: AllocatorOutputRetrive,
21{
22 type Output = GemmOutput<A, B, DEVICE, A2>;
23
24 type OutputMeta = <A as NormalOut<B>>::Output;
25
26 type InplaceOutput = GemmOutput<A, B, DEVICE, A2>;
27
28 fn gemm(
29 &self,
30 rhs: Tensor<B, Cpu, DEVICE, A2>,
31 alpha: Self::OutputMeta,
32 beta: Self::OutputMeta,
33 conj_dst: bool,
34 conj_lhs: bool,
35 conj_rhs: bool,
36 ) -> Result<Self::Output, TensorError> {
37 Ok(self
38 .inner
39 .gemm(
40 rhs.inner.as_ref(),
41 alpha,
42 beta,
43 conj_dst,
44 conj_lhs,
45 conj_rhs,
46 )?
47 .into())
48 }
49 fn gemm_<U>(
50 &self,
51 rhs: Tensor<B, Cpu, DEVICE, A2>,
52 alpha: Self::OutputMeta,
53 beta: Self::OutputMeta,
54 conj_dst: bool,
55 conj_lhs: bool,
56 conj_rhs: bool,
57 out: U,
58 ) -> Result<Self::Output, TensorError>
59 where
60 U: Borrow<Self::InplaceOutput> + BorrowMut<Self::InplaceOutput>,
61 {
62 let out = out.borrow().inner.as_ref().clone();
63 Ok(self
64 .inner
65 .gemm_(
66 rhs.inner.as_ref(),
67 alpha,
68 beta,
69 conj_dst,
70 conj_lhs,
71 conj_rhs,
72 out,
73 )?
74 .into())
75 }
76}
77
78impl<A, B, A2, const DEVICE: usize> Gemm<&Tensor<B, Cpu, DEVICE, A2>> for Tensor<A, Cpu, DEVICE, A2>
79where
80 A: CommonBounds + NormalOut<B> + Cast<<A as NormalOut<B>>::Output>,
81 B: CommonBounds + Cast<<A as NormalOut<B>>::Output>,
82 <A as NormalOut<B>>::Output: CommonBounds,
83 A2: Allocator,
84 A2::Output: AllocatorOutputRetrive,
85{
86 type Output = GemmOutput<A, B, DEVICE, A2>;
87
88 type OutputMeta = <A as NormalOut<B>>::Output;
89
90 type InplaceOutput = GemmOutput<A, B, DEVICE, A2>;
91
92 fn gemm(
93 &self,
94 rhs: &Tensor<B, Cpu, DEVICE, A2>,
95 alpha: Self::OutputMeta,
96 beta: Self::OutputMeta,
97 conj_dst: bool,
98 conj_lhs: bool,
99 conj_rhs: bool,
100 ) -> Result<Self::Output, TensorError> {
101 Ok(self
102 .inner
103 .gemm(
104 rhs.inner.as_ref(),
105 alpha,
106 beta,
107 conj_dst,
108 conj_lhs,
109 conj_rhs,
110 )?
111 .into())
112 }
113
114 fn gemm_<U>(
115 &self,
116 rhs: &Tensor<B, Cpu, DEVICE, A2>,
117 alpha: Self::OutputMeta,
118 beta: Self::OutputMeta,
119 conj_dst: bool,
120 conj_lhs: bool,
121 conj_rhs: bool,
122 out: U,
123 ) -> Result<Self::Output, TensorError>
124 where
125 U: Borrow<Self::InplaceOutput> + BorrowMut<Self::InplaceOutput>,
126 {
127 let out = out.borrow().inner.as_ref().clone();
128 Ok(self
129 .inner
130 .gemm_(
131 rhs.inner.as_ref(),
132 alpha,
133 beta,
134 conj_dst,
135 conj_lhs,
136 conj_rhs,
137 out,
138 )?
139 .into())
140 }
141}