1use std::{
2 borrow::{Borrow, BorrowMut},
3 cell::RefCell,
4 rc::Rc,
5};
6
7use hpt_common::error::base::TensorError;
8use hpt_traits::{
9 ops::{
10 binary::{Matmul, MatmulPost},
11 shape_manipulate::ShapeManipulate,
12 },
13 tensor::CommonBounds,
14};
15
16use crate::{
17 backends::cpu::{
18 kernels::matmul::microkernel_trait::MatmulMicroKernel, utils::diff::diff_utils::handle_grad,
19 },
20 tensor::{DiffTensor, Tensor},
21};
22use hpt_allocator::{
23 traits::{Allocator, AllocatorOutputRetrive},
24 Cpu,
25};
26
27impl<T, const DEVICE: usize, Al> Matmul<Tensor<T, Cpu, DEVICE, Al>> for Tensor<T, Cpu, DEVICE, Al>
28where
29 T: CommonBounds + MatmulMicroKernel,
30 Al: Allocator,
31 Al::Output: AllocatorOutputRetrive,
32{
33 type Output = Tensor<T, Cpu, DEVICE, Al>;
34
35 type OutputMeta = T;
36
37 type InplaceOutput = Tensor<T, Cpu, DEVICE, Al>;
38
39 fn matmul(
40 &self,
41 rhs: Tensor<T, Cpu, DEVICE, Al>,
42 ) -> std::result::Result<Self::Output, TensorError> {
43 Ok(self.inner.matmul(rhs.inner.as_ref())?.into())
44 }
45 fn matmul_<U>(
46 &self,
47 rhs: Tensor<T, Cpu, DEVICE, Al>,
48 out: U,
49 ) -> std::result::Result<Self::Output, TensorError>
50 where
51 U: Borrow<Self::InplaceOutput> + BorrowMut<Self::InplaceOutput>,
52 {
53 let out = out.borrow().inner.as_ref().clone();
54 Ok(self.inner.matmul_(rhs.inner.as_ref(), out)?.into())
55 }
56}
57
58impl<T, const DEVICE: usize, Al> Matmul<&Tensor<T, Cpu, DEVICE, Al>> for Tensor<T, Cpu, DEVICE, Al>
59where
60 T: CommonBounds + MatmulMicroKernel,
61 Al: Allocator,
62 Al::Output: AllocatorOutputRetrive,
63{
64 type Output = Tensor<T, Cpu, DEVICE, Al>;
65
66 type OutputMeta = T;
67
68 type InplaceOutput = Tensor<T, Cpu, DEVICE, Al>;
69
70 fn matmul(
71 &self,
72 rhs: &Tensor<T, Cpu, DEVICE, Al>,
73 ) -> std::result::Result<Self::Output, TensorError> {
74 Ok(self.inner.matmul(rhs.inner.as_ref())?.into())
75 }
76
77 fn matmul_<U>(
78 &self,
79 rhs: &Tensor<T, Cpu, DEVICE, Al>,
80 out: U,
81 ) -> std::result::Result<Self::Output, TensorError>
82 where
83 U: Borrow<Self::InplaceOutput> + BorrowMut<Self::InplaceOutput>,
84 {
85 let out = out.borrow().inner.as_ref().clone();
86 Ok(self.inner.matmul_(rhs.inner.as_ref(), out)?.into())
87 }
88}
89
90impl<T, const DEVICE: usize, Al> Matmul<DiffTensor<T, Cpu, DEVICE, Al>>
91 for DiffTensor<T, Cpu, DEVICE, Al>
92where
93 T: CommonBounds + MatmulMicroKernel,
94 Al: Allocator + 'static + Send + Sync,
95 Al::Output: AllocatorOutputRetrive,
96{
97 type Output = DiffTensor<T, Cpu, DEVICE, Al>;
98
99 type OutputMeta = T;
100
101 type InplaceOutput = Tensor<T, Cpu, DEVICE, Al>;
102
103 fn matmul(
104 &self,
105 rhs: DiffTensor<T, Cpu, DEVICE, Al>,
106 ) -> std::result::Result<Self::Output, TensorError> {
107 let res = self.inner.matmul(&rhs.inner)?;
108 let mut lhs = self.clone();
109 let mut rhs = rhs.clone();
110 Ok(DiffTensor {
111 inner: res,
112 grad: Rc::new(RefCell::new(None)),
113 out_degree: Rc::new(RefCell::new(0)),
114 backward: Rc::new(RefCell::new(move |grad: Tensor<T, Cpu, DEVICE, Al>| {
115 let grad_a = grad.matmul(rhs.inner.t()?)?;
116 let grad_b = lhs.inner.t()?.matmul(grad)?;
117 handle_grad(&mut lhs, grad_a, &[])?;
118 handle_grad(&mut rhs, grad_b, &[])?;
119 Ok(false)
120 })),
121 })
122 }
123 fn matmul_<U>(
124 &self,
125 rhs: DiffTensor<T, Cpu, DEVICE, Al>,
126 out: U,
127 ) -> std::result::Result<Self::InplaceOutput, TensorError>
128 where
129 U: Borrow<Self::InplaceOutput> + BorrowMut<Self::InplaceOutput>,
130 {
131 self.inner.matmul_(&rhs.inner, out)
132 }
133}
134
135impl<T, A, const DEVICE: usize> MatmulPost<Tensor<T, Cpu, DEVICE, A>> for Tensor<T, Cpu, DEVICE, A>
136where
137 T: CommonBounds + MatmulMicroKernel,
138 A: Allocator,
139 A::Output: AllocatorOutputRetrive,
140{
141 type Output = Tensor<T, Cpu, DEVICE, A>;
142
143 type OutputMeta = T;
144
145 type InplaceOutput = Tensor<T, Cpu, DEVICE, A>;
146
147 fn matmul_post(
148 &self,
149 rhs: Tensor<T, Cpu, DEVICE, A>,
150 post_op: fn(T) -> T,
151 post_op_vec: fn(T::Vec) -> T::Vec,
152 ) -> std::result::Result<Self::Output, TensorError> {
153 Ok(self
154 .inner
155 .matmul_post(rhs.inner.as_ref(), post_op, post_op_vec)?
156 .into())
157 }
158
159 fn matmul_post_<U>(
160 &self,
161 rhs: Tensor<T, Cpu, DEVICE, A>,
162 post_op: fn(T) -> T,
163 post_op_vec: fn(T::Vec) -> T::Vec,
164 mut out: U,
165 ) -> std::result::Result<Self::InplaceOutput, TensorError>
166 where
167 U: BorrowMut<Self::InplaceOutput> + BorrowMut<Self::InplaceOutput>,
168 {
169 Ok(self
170 .inner
171 .matmul_post_(
172 rhs.inner.as_ref(),
173 post_op,
174 post_op_vec,
175 out.borrow_mut().inner.as_ref().clone(),
176 )?
177 .into())
178 }
179}
180
181impl<T, A, const DEVICE: usize> MatmulPost<&Tensor<T, Cpu, DEVICE, A>> for Tensor<T, Cpu, DEVICE, A>
182where
183 T: CommonBounds + MatmulMicroKernel,
184 A: Allocator,
185 A::Output: AllocatorOutputRetrive,
186{
187 type Output = Tensor<T, Cpu, DEVICE, A>;
188
189 type OutputMeta = T;
190
191 type InplaceOutput = Tensor<T, Cpu, DEVICE, A>;
192
193 fn matmul_post(
194 &self,
195 rhs: &Tensor<T, Cpu, DEVICE, A>,
196 post_op: fn(T) -> T,
197 post_op_vec: fn(T::Vec) -> T::Vec,
198 ) -> std::result::Result<Self::Output, TensorError> {
199 Ok(self
200 .inner
201 .matmul_post(rhs.inner.as_ref(), post_op, post_op_vec)?
202 .into())
203 }
204
205 fn matmul_post_<U>(
206 &self,
207 rhs: &Tensor<T, Cpu, DEVICE, A>,
208 post_op: fn(T) -> T,
209 post_op_vec: fn(T::Vec) -> T::Vec,
210 mut out: U,
211 ) -> std::result::Result<Self::InplaceOutput, TensorError>
212 where
213 U: BorrowMut<Self::InplaceOutput> + BorrowMut<Self::InplaceOutput>,
214 {
215 Ok(self
216 .inner
217 .matmul_post_(
218 rhs.inner.as_ref(),
219 post_op,
220 post_op_vec,
221 out.borrow_mut().inner.as_ref().clone(),
222 )?
223 .into())
224 }
225}