custos_math/ops/
gemm.rs

1use custos::{impl_stack, Device, Dim2, GenericBlas, MainMemory, MayDim2, Shape, CPU};
2
3#[cfg(feature = "stack")]
4use custos::Stack;
5
6#[cfg(feature = "opencl")]
7use custos::CDatatype;
8
9#[cfg(feature = "opencl")]
10use crate::cl_gemm;
11#[cfg(feature = "opencl")]
12use custos::OpenCL;
13
14use crate::Matrix;
15
16/*pub trait GemmMat<'a,
17    T,
18    D: Gemm<T, LS, RS, OS, D>,
19    LS: MayDim2<M, K>,
20    RS: MayDim2<K, N>,
21    OS: MayDim2<M, N>,
22    const M: usize = 0,
23    const K: usize = 0,
24    const N: usize = 0,
25>
26{
27    fn gemm(&self, rhs: &Matrix<'a, T, D, RS>) -> Matrix<T, D, OS>;
28}
29
30impl<
31        'a,
32        T,
33        D: Gemm<T, LS, RS, OS, D>,
34        LS: MayDim2<M, K>,
35        RS: MayDim2<K, N>,
36        OS: MayDim2<M, N>,
37        const M: usize,
38        const K: usize,
39        const N: usize,
40    > GemmMat<'a, T, D, LS, RS, OS, M, K, N> for Matrix<'a, T, D, LS>
41{
42    fn gemm(&self, rhs: &Matrix<'a, T, D, RS>) -> Matrix<T, D, OS> {
43        self.device().gemm(self, rhs)
44    }
45}*/
46
47impl<'a, T, D: Device, LS: Shape> Matrix<'a, T, D, LS> {
48    /// Matrix multiplication. Uses current global device.
49    /// # Example
50    #[cfg_attr(feature = "cpu", doc = "```")]
51    #[cfg_attr(not(feature = "cpu"), doc = "```ignore")]
52    /// use custos::CPU;
53    /// use custos_math::Matrix;
54    ///
55    /// let device = CPU::new();
56    ///
57    /// let a = Matrix::from((&device, (2, 3), [1., 2., 3., 4., 5., 6.,]));
58    /// let b = Matrix::from((&device, (3, 2), [6., 5., 4., 3., 2., 1.,]));
59    ///
60    /// let c = a.gemm(&b);
61    /// println!("c: {c:?}");
62    ///
63    /// assert_eq!(c.read(), vec![20., 14., 56., 41.,]);
64    /// ```
65    #[inline]
66    pub fn gemm<RS: Shape, OS: Shape>(&self, rhs: &Matrix<'a, T, D, RS>) -> Matrix<'a, T, D, OS>
67    where
68        D: Gemm<T, LS, RS, OS, D>,
69    {
70        self.device().gemm(self, rhs)
71    }
72}
73
74/*impl<'a, T, D: Device> Matrix<'a, T, D> {
75    #[inline]
76    pub fn gemm(&self, rhs: &Matrix<'a, T, D>) -> Matrix<'a, T, D>
77    where
78        D: Gemm<T, (), (), (), D>,
79    {
80        self.device().gemm(self, rhs)
81    }
82}*/
83
84/*impl<'a, T, D: Device, const M: usize, const K: usize> Matrix<'a, T, D, Dim2<M, K>> {
85    /// Matrix multiplication. Uses current global device.
86    /// # Example
87    /// ```
88    /// use custos::CPU;
89    /// use custos_math::Matrix;
90    ///
91    /// let device = CPU::new();
92    ///
93    /// let a = Matrix::from((&device, (2, 3), [1., 2., 3., 4., 5., 6.,]));
94    /// let b = Matrix::from((&device, (3, 2), [6., 5., 4., 3., 2., 1.,]));
95    ///
96    /// let c = a.gemm(&b);
97    /// println!("c: {c:?}");
98    ///
99    /// assert_eq!(c.read(), vec![20., 14., 56., 41.,]);
100    /// ```
101    #[inline]
102    pub fn gemm<const N: usize>(
103        &self,
104        rhs: &Matrix<'a, T, D, Dim2<K, N>>,
105    ) -> Matrix<'a, T, D, Dim2<M, N>>
106    where
107        D: Gemm<T, Dim2<M, K>, Dim2<K, N>, Dim2<M, N>, D>,
108    {
109        self.device().gemm(self, rhs)
110    }
111}*/
112
113/// Matrix multiplication. Uses provided device.
114/// # Example
115#[cfg_attr(feature = "cpu", doc = "```")]
116#[cfg_attr(not(feature = "cpu"), doc = "```ignore")]
117/// use custos::{CPU, Read};
118/// use custos_math::{Matrix, Gemm};
119///
120/// let device = CPU::new();
121///
122/// let a = Matrix::from((&device, (2, 3), [1., 2., 3., 4., 5., 6.,]));
123/// let b = Matrix::from((&device, (3, 2), [6., 5., 4., 3., 2., 1.,]));
124///
125/// let c: Matrix = device.gemm(&a, &b);
126///
127/// assert_eq!(c.read(), vec![20., 14., 56., 41.,]);
128/// ```
129pub trait Gemm<T, LS: Shape = (), RS: Shape = (), OS: Shape = (), D: Device = Self>:
130    Device
131{
132    fn gemm(&self, lhs: &Matrix<T, D, LS>, rhs: &Matrix<T, D, RS>) -> Matrix<T, Self, OS>;
133}
134
135// #[cfg(not(feature = "no-std"))]
136// #[cfg(feature = "cpu")]
// CPU matrix multiplication backed by `GenericBlas::gemm`.
// Selected when the `blas` feature is enabled and `matrixmultiply` is not.
#[cfg(feature = "blas")]
#[cfg(not(feature = "matrixmultiply"))]
#[impl_stack]
impl<T, D, LS, RS, OS> Gemm<T, LS, RS, OS, D> for CPU
where
    T: GenericBlas + Default + Copy,
    D: MainMemory,
    LS: Shape,
    RS: Shape,
    OS: Shape,
{
    // Returns the `m x n` product of `lhs` (`m x k`) and `rhs` (`k x n`).
    //
    // Panics if the column count of `lhs` differs from the row count of `rhs`.
    #[inline]
    fn gemm(&self, lhs: &Matrix<T, D, LS>, rhs: &Matrix<T, D, RS>) -> Matrix<T, Self, OS> {
        let (m, k) = lhs.dims();
        let n = rhs.cols();

        // Checked in release builds as well (matching the OpenCL and CUDA
        // implementations): the kernel below is driven purely by `m`, `n`
        // and `k`, so mismatched operands must never reach it.
        assert!(k == rhs.rows(), "wrong dims for matrix multiplication");

        let mut out = self.retrieve(m * n, (lhs.node.idx, rhs.node.idx));
        T::gemm(m, n, k, lhs, rhs, &mut out);
        (out, m, n).into()
    }
}
160
// CPU matrix multiplication backed by the crate's `MatrixMultiply` trait.
// Selected when the `matrixmultiply` feature is enabled and `blas` is not.
#[cfg(feature = "matrixmultiply")]
#[cfg(not(feature = "blas"))]
#[impl_stack]
impl<T, D, LS, RS, OS> Gemm<T, LS, RS, OS, D> for CPU
where
    T: crate::matrix_multiply::MatrixMultiply + Default + Copy,
    D: MainMemory,
    LS: Shape,
    RS: Shape,
    OS: Shape,
{
    // Returns the `m x n` product of `lhs` (`m x k`) and `rhs` (`k x n`).
    //
    // Panics if the column count of `lhs` differs from the row count of `rhs`.
    #[inline]
    fn gemm(&self, lhs: &Matrix<T, D, LS>, rhs: &Matrix<T, D, RS>) -> Matrix<T, Self, OS> {
        let (m, k) = lhs.dims();
        let n = rhs.cols();

        // Checked in release builds as well (matching the OpenCL and CUDA
        // implementations): the kernel below is driven purely by `m`, `k`
        // and `n`, so mismatched operands must never reach it.
        assert!(k == rhs.rows(), "wrong dims for matrix multiplication");

        let mut out = self.retrieve(m * n, (lhs.node.idx, rhs.node.idx));
        // Each operand is followed by two stride arguments: its column count
        // and 1 (`k, 1` for lhs, `n, 1` for rhs and out).
        T::gemm(m, k, n, lhs, k, 1, rhs, n, 1, &mut out, n, 1);
        (out, m, n).into()
    }
}
184
185#[cfg(not(feature = "matrixmultiply"))]
186#[cfg(not(feature = "blas"))]
187#[impl_stack]
188impl<T, D, LS, RS, OS> Gemm<T, LS, RS, OS, D> for CPU
189where
190    T: Default + Copy + core::ops::Mul<Output = T> + core::ops::AddAssign,
191    D: MainMemory,
192    LS: Shape,
193    RS: Shape,
194    OS: Shape,
195{
196    #[inline]
197    fn gemm(&self, lhs: &Matrix<T, D, LS>, rhs: &Matrix<T, D, RS>) -> Matrix<T, Self, OS> {
198        let (m, k) = lhs.dims();
199        let n = rhs.cols();
200
201        debug_assert!(k == rhs.rows());
202
203        let mut out = self.retrieve(m * n, (lhs.node.idx, rhs.node.idx));
204        crate::raw_ops::naive_gemm(m, k, n, lhs, rhs, &mut out);
205        (out, m, n).into()
206    }
207}
208
#[cfg(feature = "opencl")]
impl<T> Gemm<T> for OpenCL
where
    T: CDatatype,
{
    /// Multiplies `lhs` by `rhs` through `cl_gemm`.
    ///
    /// # Panics
    /// Panics if the second dimension of `lhs` differs from the first
    /// dimension of `rhs`, or if `cl_gemm` fails (its result is unwrapped).
    fn gemm(&self, lhs: &Matrix<T, Self>, rhs: &Matrix<T, Self>) -> Matrix<T, Self> {
        assert!(lhs.dims().1 == rhs.dims().0);

        let (out_rows, out_cols) = (lhs.rows(), rhs.cols());

        // NOTE(review): the operands are handed over in swapped order
        // (`rhs` first) together with the sizes (n, k, m) - presumably
        // `cl_gemm` works on the transposed problem; confirm against its
        // definition before changing this call.
        let buf = cl_gemm(self, out_cols, rhs.rows(), out_rows, rhs, lhs).unwrap();

        (buf, out_rows, out_cols).into()
    }
}
218
#[cfg(feature = "cuda")]
impl<T> Gemm<T> for custos::CUDA
where
    T: GenericBlas,
{
    /// Multiplies `lhs` by `rhs` through `GenericBlas::cugemm`, writing into
    /// a buffer obtained from `self.cached(..)`.
    ///
    /// # Panics
    /// Panics if `lhs.cols() != rhs.rows()`, or if `T::cugemm` fails (its
    /// result is unwrapped).
    fn gemm(
        &self,
        lhs: &Matrix<T, custos::CUDA>,
        rhs: &Matrix<T, custos::CUDA>,
    ) -> Matrix<T, custos::CUDA> {
        use custos::CacheBuf;

        // Output is `m x n`; `k` is the shared inner dimension.
        let (m, k, n) = (lhs.rows(), lhs.cols(), rhs.cols());
        assert!(k == rhs.rows(), "wrong dims for matrix multiplication");

        let out = self.cached(m * n);
        T::cugemm(
            self.handle(),
            m,
            n,
            k,
            lhs.as_buf().ptr.ptr,
            rhs.as_buf().ptr.ptr,
            out.ptr.ptr,
        )
        .unwrap();

        (out, m, n).into()
    }
}
245
#[cfg(test)]
mod tests {

    /// Builds a `1 x 3` matrix on the `Stack` device.
    ///
    /// The actual `gemm` call is still disabled (see the commented section
    /// at the end), so this currently only checks that construction compiles
    /// and runs.
    #[cfg(feature = "stack")]
    #[test]
    fn test_stack_impl() {
        use crate::Matrix;
        use custos::{Buffer, Dim1, Dim2, Stack};

        let lhs = Matrix {
            data: Buffer::from((Stack, &[3., 1., 5.])),
            dims: (1, 3),
        };

        // Disabled remainder of the test, kept for when it can be enabled:
        /*let data = Buffer::<_, _, Dim2<3, 1>>::from((Stack, &[3., 1., 5.]));
        let rhs = Matrix { data, dims: (3, 1) };

        let out: Matrix<f64, Stack, Dim1<1>> = lhs.gemm(&rhs);*/
    }
}