1use custos::{impl_stack, Device, Dim2, GenericBlas, MainMemory, MayDim2, Shape, CPU};
2
3#[cfg(feature = "stack")]
4use custos::Stack;
5
6#[cfg(feature = "opencl")]
7use custos::CDatatype;
8
9#[cfg(feature = "opencl")]
10use crate::cl_gemm;
11#[cfg(feature = "opencl")]
12use custos::OpenCL;
13
14use crate::Matrix;
15
16impl<'a, T, D: Device, LS: Shape> Matrix<'a, T, D, LS> {
48 #[cfg_attr(feature = "cpu", doc = "```")]
51 #[cfg_attr(not(feature = "cpu"), doc = "```ignore")]
52 #[inline]
66 pub fn gemm<RS: Shape, OS: Shape>(&self, rhs: &Matrix<'a, T, D, RS>) -> Matrix<'a, T, D, OS>
67 where
68 D: Gemm<T, LS, RS, OS, D>,
69 {
70 self.device().gemm(self, rhs)
71 }
72}
73
74#[cfg_attr(feature = "cpu", doc = "```")]
116#[cfg_attr(not(feature = "cpu"), doc = "```ignore")]
117pub trait Gemm<T, LS: Shape = (), RS: Shape = (), OS: Shape = (), D: Device = Self>:
130 Device
131{
132 fn gemm(&self, lhs: &Matrix<T, D, LS>, rhs: &Matrix<T, D, RS>) -> Matrix<T, Self, OS>;
133}
134
135#[cfg(feature = "blas")]
138#[cfg(not(feature = "matrixmultiply"))]
139#[impl_stack]
140impl<T, D, LS, RS, OS> Gemm<T, LS, RS, OS, D> for CPU
141where
142 T: GenericBlas + Default + Copy,
143 D: MainMemory,
144 LS: Shape,
145 RS: Shape,
146 OS: Shape,
147{
148 #[inline]
149 fn gemm(&self, lhs: &Matrix<T, D, LS>, rhs: &Matrix<T, D, RS>) -> Matrix<T, Self, OS> {
150 let (m, k) = lhs.dims();
151 let n = rhs.cols();
152
153 debug_assert!(k == rhs.rows());
154
155 let mut out = self.retrieve(m * n, (lhs.node.idx, rhs.node.idx));
156 T::gemm(m, n, k, lhs, rhs, &mut out);
157 (out, m, n).into()
158 }
159}
160
161#[cfg(feature = "matrixmultiply")]
162#[cfg(not(feature = "blas"))]
163#[impl_stack]
164impl<T, D, LS, RS, OS> Gemm<T, LS, RS, OS, D> for CPU
165where
166 T: crate::matrix_multiply::MatrixMultiply + Default + Copy,
167 D: MainMemory,
168 LS: Shape,
169 RS: Shape,
170 OS: Shape,
171{
172 #[inline]
173 fn gemm(&self, lhs: &Matrix<T, D, LS>, rhs: &Matrix<T, D, RS>) -> Matrix<T, Self, OS> {
174 let (m, k) = lhs.dims();
175 let n = rhs.cols();
176
177 debug_assert!(k == rhs.rows());
178
179 let mut out = self.retrieve(m * n, (lhs.node.idx, rhs.node.idx));
180 T::gemm(m, k, n, lhs, k, 1, rhs, n, 1, &mut out, n, 1);
181 (out, m, n).into()
182 }
183}
184
185#[cfg(not(feature = "matrixmultiply"))]
186#[cfg(not(feature = "blas"))]
187#[impl_stack]
188impl<T, D, LS, RS, OS> Gemm<T, LS, RS, OS, D> for CPU
189where
190 T: Default + Copy + core::ops::Mul<Output = T> + core::ops::AddAssign,
191 D: MainMemory,
192 LS: Shape,
193 RS: Shape,
194 OS: Shape,
195{
196 #[inline]
197 fn gemm(&self, lhs: &Matrix<T, D, LS>, rhs: &Matrix<T, D, RS>) -> Matrix<T, Self, OS> {
198 let (m, k) = lhs.dims();
199 let n = rhs.cols();
200
201 debug_assert!(k == rhs.rows());
202
203 let mut out = self.retrieve(m * n, (lhs.node.idx, rhs.node.idx));
204 crate::raw_ops::naive_gemm(m, k, n, lhs, rhs, &mut out);
205 (out, m, n).into()
206 }
207}
208
209#[cfg(feature = "opencl")]
210impl<T: CDatatype> Gemm<T> for OpenCL {
211 fn gemm(&self, lhs: &Matrix<T, Self>, rhs: &Matrix<T, Self>) -> Matrix<T, Self> {
212 assert!(lhs.dims().1 == rhs.dims().0);
213 let buf = cl_gemm(self, rhs.cols(), rhs.rows(), lhs.rows(), rhs, lhs).unwrap();
215 (buf, lhs.rows(), rhs.cols()).into()
216 }
217}
218
219#[cfg(feature = "cuda")]
220impl<T: GenericBlas> Gemm<T> for custos::CUDA {
221 fn gemm(
222 &self,
223 lhs: &Matrix<T, custos::CUDA>,
224 rhs: &Matrix<T, custos::CUDA>,
225 ) -> Matrix<T, custos::CUDA> {
226 use custos::CacheBuf;
227 assert!(
228 lhs.cols() == rhs.rows(),
229 "wrong dims for matrix multiplication"
230 );
231 let out = self.cached(lhs.rows() * rhs.cols());
232 T::cugemm(
233 self.handle(),
234 lhs.rows(),
235 rhs.cols(),
236 lhs.cols(),
237 lhs.as_buf().ptr.ptr,
238 rhs.as_buf().ptr.ptr,
239 out.ptr.ptr,
240 )
241 .unwrap();
242 (out, lhs.rows(), rhs.cols()).into()
243 }
244}
245
246#[cfg(test)]
247mod tests {
248
249 #[cfg(feature = "stack")]
250 #[test]
251 fn test_stack_impl() {
252 use custos::{Buffer, Dim1, Dim2, Stack};
253
254 use crate::Matrix;
255
256 let data = Buffer::from((Stack, &[3., 1., 5.]));
257 let lhs = Matrix { data, dims: (1, 3) };
258
259 }
264}