hpt/backends/cpu/tensor_external/
matmul.rs

1use std::{
2    borrow::{Borrow, BorrowMut},
3    cell::RefCell,
4    rc::Rc,
5};
6
7use hpt_common::error::base::TensorError;
8use hpt_traits::{
9    ops::{
10        binary::{Matmul, MatmulPost},
11        shape_manipulate::ShapeManipulate,
12    },
13    tensor::CommonBounds,
14};
15
16use crate::{
17    backends::cpu::{
18        kernels::matmul::microkernel_trait::MatmulMicroKernel, utils::diff::diff_utils::handle_grad,
19    },
20    tensor::{DiffTensor, Tensor},
21};
22use hpt_allocator::{
23    traits::{Allocator, AllocatorOutputRetrive},
24    Cpu,
25};
26
27impl<T, const DEVICE: usize, Al> Matmul<Tensor<T, Cpu, DEVICE, Al>> for Tensor<T, Cpu, DEVICE, Al>
28where
29    T: CommonBounds + MatmulMicroKernel,
30    Al: Allocator,
31    Al::Output: AllocatorOutputRetrive,
32{
33    type Output = Tensor<T, Cpu, DEVICE, Al>;
34
35    type OutputMeta = T;
36
37    type InplaceOutput = Tensor<T, Cpu, DEVICE, Al>;
38
39    fn matmul(
40        &self,
41        rhs: Tensor<T, Cpu, DEVICE, Al>,
42    ) -> std::result::Result<Self::Output, TensorError> {
43        Ok(self.inner.matmul(rhs.inner.as_ref())?.into())
44    }
45    fn matmul_<U>(
46        &self,
47        rhs: Tensor<T, Cpu, DEVICE, Al>,
48        out: U,
49    ) -> std::result::Result<Self::Output, TensorError>
50    where
51        U: Borrow<Self::InplaceOutput> + BorrowMut<Self::InplaceOutput>,
52    {
53        let out = out.borrow().inner.as_ref().clone();
54        Ok(self.inner.matmul_(rhs.inner.as_ref(), out)?.into())
55    }
56}
57
58impl<T, const DEVICE: usize, Al> Matmul<&Tensor<T, Cpu, DEVICE, Al>> for Tensor<T, Cpu, DEVICE, Al>
59where
60    T: CommonBounds + MatmulMicroKernel,
61    Al: Allocator,
62    Al::Output: AllocatorOutputRetrive,
63{
64    type Output = Tensor<T, Cpu, DEVICE, Al>;
65
66    type OutputMeta = T;
67
68    type InplaceOutput = Tensor<T, Cpu, DEVICE, Al>;
69
70    fn matmul(
71        &self,
72        rhs: &Tensor<T, Cpu, DEVICE, Al>,
73    ) -> std::result::Result<Self::Output, TensorError> {
74        Ok(self.inner.matmul(rhs.inner.as_ref())?.into())
75    }
76
77    fn matmul_<U>(
78        &self,
79        rhs: &Tensor<T, Cpu, DEVICE, Al>,
80        out: U,
81    ) -> std::result::Result<Self::Output, TensorError>
82    where
83        U: Borrow<Self::InplaceOutput> + BorrowMut<Self::InplaceOutput>,
84    {
85        let out = out.borrow().inner.as_ref().clone();
86        Ok(self.inner.matmul_(rhs.inner.as_ref(), out)?.into())
87    }
88}
89
90impl<T, const DEVICE: usize, Al> Matmul<DiffTensor<T, Cpu, DEVICE, Al>>
91    for DiffTensor<T, Cpu, DEVICE, Al>
92where
93    T: CommonBounds + MatmulMicroKernel,
94    Al: Allocator + 'static + Send + Sync,
95    Al::Output: AllocatorOutputRetrive,
96{
97    type Output = DiffTensor<T, Cpu, DEVICE, Al>;
98
99    type OutputMeta = T;
100
101    type InplaceOutput = Tensor<T, Cpu, DEVICE, Al>;
102
103    fn matmul(
104        &self,
105        rhs: DiffTensor<T, Cpu, DEVICE, Al>,
106    ) -> std::result::Result<Self::Output, TensorError> {
107        let res = self.inner.matmul(&rhs.inner)?;
108        let mut lhs = self.clone();
109        let mut rhs = rhs.clone();
110        Ok(DiffTensor {
111            inner: res,
112            grad: Rc::new(RefCell::new(None)),
113            out_degree: Rc::new(RefCell::new(0)),
114            backward: Rc::new(RefCell::new(move |grad: Tensor<T, Cpu, DEVICE, Al>| {
115                let grad_a = grad.matmul(rhs.inner.t()?)?;
116                let grad_b = lhs.inner.t()?.matmul(grad)?;
117                handle_grad(&mut lhs, grad_a, &[])?;
118                handle_grad(&mut rhs, grad_b, &[])?;
119                Ok(false)
120            })),
121        })
122    }
123    fn matmul_<U>(
124        &self,
125        rhs: DiffTensor<T, Cpu, DEVICE, Al>,
126        out: U,
127    ) -> std::result::Result<Self::InplaceOutput, TensorError>
128    where
129        U: Borrow<Self::InplaceOutput> + BorrowMut<Self::InplaceOutput>,
130    {
131        self.inner.matmul_(&rhs.inner, out)
132    }
133}
134
135impl<T, A, const DEVICE: usize> MatmulPost<Tensor<T, Cpu, DEVICE, A>> for Tensor<T, Cpu, DEVICE, A>
136where
137    T: CommonBounds + MatmulMicroKernel,
138    A: Allocator,
139    A::Output: AllocatorOutputRetrive,
140{
141    type Output = Tensor<T, Cpu, DEVICE, A>;
142
143    type OutputMeta = T;
144
145    type InplaceOutput = Tensor<T, Cpu, DEVICE, A>;
146
147    fn matmul_post(
148        &self,
149        rhs: Tensor<T, Cpu, DEVICE, A>,
150        post_op: fn(T) -> T,
151        post_op_vec: fn(T::Vec) -> T::Vec,
152    ) -> std::result::Result<Self::Output, TensorError> {
153        Ok(self
154            .inner
155            .matmul_post(rhs.inner.as_ref(), post_op, post_op_vec)?
156            .into())
157    }
158
159    fn matmul_post_<U>(
160        &self,
161        rhs: Tensor<T, Cpu, DEVICE, A>,
162        post_op: fn(T) -> T,
163        post_op_vec: fn(T::Vec) -> T::Vec,
164        mut out: U,
165    ) -> std::result::Result<Self::InplaceOutput, TensorError>
166    where
167        U: BorrowMut<Self::InplaceOutput> + BorrowMut<Self::InplaceOutput>,
168    {
169        Ok(self
170            .inner
171            .matmul_post_(
172                rhs.inner.as_ref(),
173                post_op,
174                post_op_vec,
175                out.borrow_mut().inner.as_ref().clone(),
176            )?
177            .into())
178    }
179}
180
181impl<T, A, const DEVICE: usize> MatmulPost<&Tensor<T, Cpu, DEVICE, A>> for Tensor<T, Cpu, DEVICE, A>
182where
183    T: CommonBounds + MatmulMicroKernel,
184    A: Allocator,
185    A::Output: AllocatorOutputRetrive,
186{
187    type Output = Tensor<T, Cpu, DEVICE, A>;
188
189    type OutputMeta = T;
190
191    type InplaceOutput = Tensor<T, Cpu, DEVICE, A>;
192
193    fn matmul_post(
194        &self,
195        rhs: &Tensor<T, Cpu, DEVICE, A>,
196        post_op: fn(T) -> T,
197        post_op_vec: fn(T::Vec) -> T::Vec,
198    ) -> std::result::Result<Self::Output, TensorError> {
199        Ok(self
200            .inner
201            .matmul_post(rhs.inner.as_ref(), post_op, post_op_vec)?
202            .into())
203    }
204
205    fn matmul_post_<U>(
206        &self,
207        rhs: &Tensor<T, Cpu, DEVICE, A>,
208        post_op: fn(T) -> T,
209        post_op_vec: fn(T::Vec) -> T::Vec,
210        mut out: U,
211    ) -> std::result::Result<Self::InplaceOutput, TensorError>
212    where
213        U: BorrowMut<Self::InplaceOutput> + BorrowMut<Self::InplaceOutput>,
214    {
215        Ok(self
216            .inner
217            .matmul_post_(
218                rhs.inner.as_ref(),
219                post_op,
220                post_op_vec,
221                out.borrow_mut().inner.as_ref().clone(),
222            )?
223            .into())
224    }
225}