rstsr_core/storage/matmul.rs

1//! Matrix, vector multiplication and related operations.
2
3#![allow(clippy::too_many_arguments)]
4
5use core::ops::{Add, Mul};
6
7use crate::prelude_dev::*;
8
9pub trait DeviceMatMulAPI<TA, TB, TC, DA, DB, DC>
10where
11    DA: DimAPI,
12    DB: DimAPI,
13    DC: DimAPI,
14    TA: Mul<TB, Output = TC>,
15    TC: Mul<TC, Output = TC> + Add<TC, Output = TC>,
16    Self: DeviceAPI<TA> + DeviceAPI<TB> + DeviceAPI<TC>,
17{
18    fn matmul(
19        &self,
20        c: &mut <Self as DeviceRawAPI<TC>>::Raw,
21        lc: &Layout<DC>,
22        a: &<Self as DeviceRawAPI<TA>>::Raw,
23        la: &Layout<DA>,
24        b: &<Self as DeviceRawAPI<TB>>::Raw,
25        lb: &Layout<DB>,
26        alpha: TC,
27        beta: TC,
28    ) -> Result<()>;
29}
30
31pub trait DeviceGEMMAPI<TA, TB, TC>
32where
33    TA: Mul<TB, Output = TC>,
34    TC: Mul<TC, Output = TC> + Add<TC, Output = TC>,
35    Self: DeviceAPI<TA> + DeviceAPI<TB> + DeviceAPI<TC>,
36{
37    fn gemm(
38        &self,
39        c: &mut <Self as DeviceRawAPI<TC>>::Raw,
40        lc: &Layout<Ix2>,
41        a: &<Self as DeviceRawAPI<TA>>::Raw,
42        la: &Layout<Ix2>,
43        b: &<Self as DeviceRawAPI<TB>>::Raw,
44        lb: &Layout<Ix2>,
45        alpha: TC,
46        beta: TC,
47    ) -> Result<()>;
48}
49
50pub trait DeviceSYMMAPI<TA, TB, TC>
51where
52    TA: Mul<TB, Output = TC>,
53    TC: Mul<TC, Output = TC> + Add<TC, Output = TC>,
54    Self: DeviceAPI<TA> + DeviceAPI<TB> + DeviceAPI<TC>,
55{
56    fn symm(
57        &self,
58        c: &mut <Self as DeviceRawAPI<TC>>::Raw,
59        lc: &Layout<Ix2>,
60        a: &<Self as DeviceRawAPI<TA>>::Raw,
61        la: &Layout<Ix2>,
62        b: &<Self as DeviceRawAPI<TB>>::Raw,
63        lb: &Layout<Ix2>,
64        side: FlagSide,
65        uplo: FlagUpLo,
66        alpha: TC,
67        beta: TC,
68    ) -> Result<()>;
69}
70
71pub trait DeviceSYRKAPI<TA, TC>
72where
73    TA: Mul<TA, Output = TC>,
74    TC: Mul<TC, Output = TC> + Add<TC, Output = TC>,
75    Self: DeviceAPI<TA> + DeviceAPI<TC>,
76{
77    fn syrk(
78        &self,
79        c: &mut <Self as DeviceRawAPI<TC>>::Raw,
80        lc: &Layout<Ix2>,
81        a: &<Self as DeviceRawAPI<TA>>::Raw,
82        la: &Layout<Ix2>,
83        uplo: FlagUpLo,
84        alpha: TC,
85        beta: TC,
86    ) -> Result<()>;
87}
88
89pub trait DeviceHERKAPI<TA, TC>
90where
91    TA: Mul<TA, Output = TC>,
92    TC: Mul<TC, Output = TC> + Add<TC, Output = TC>,
93    Self: DeviceAPI<TA> + DeviceAPI<TC>,
94{
95    fn herk(
96        &self,
97        c: &mut <Self as DeviceRawAPI<TC>>::Raw,
98        lc: &Layout<Ix2>,
99        a: &<Self as DeviceRawAPI<TA>>::Raw,
100        la: &Layout<Ix2>,
101        uplo: FlagUpLo,
102        alpha: TC,
103        beta: TC,
104    ) -> Result<()>;
105}
106
107pub trait DeviceGEMVAPI<TA, TB, TC>
108where
109    TA: Mul<TB, Output = TC>,
110    TC: Mul<TC, Output = TC> + Add<TC, Output = TC>,
111    Self: DeviceAPI<TA> + DeviceAPI<TB> + DeviceAPI<TC>,
112{
113    fn gemv(
114        &self,
115        c: &mut <Self as DeviceRawAPI<TC>>::Raw,
116        lc: &Layout<Ix1>,
117        a: &<Self as DeviceRawAPI<TA>>::Raw,
118        la: &Layout<Ix2>,
119        b: &<Self as DeviceRawAPI<TB>>::Raw,
120        lb: &Layout<Ix1>,
121        alpha: TC,
122        beta: TC,
123    ) -> Result<()>;
124
125    fn gevm(
126        &self,
127        c: &mut <Self as DeviceRawAPI<TC>>::Raw,
128        lc: &Layout<Ix1>,
129        a: &<Self as DeviceRawAPI<TA>>::Raw,
130        la: &Layout<Ix1>,
131        b: &<Self as DeviceRawAPI<TB>>::Raw,
132        lb: &Layout<Ix2>,
133        alpha: TC,
134        beta: TC,
135    ) -> Result<()>;
136}
137
138pub trait DeviceInnerDotAPI<TA, TB, TC>
139where
140    TA: Mul<TB, Output = TC>,
141    TC: Mul<TC, Output = TC> + Add<TC, Output = TC>,
142    Self: DeviceAPI<TA> + DeviceAPI<TB> + DeviceAPI<TC>,
143{
144    fn inner_dot(
145        &self,
146        c: &mut <Self as DeviceRawAPI<TC>>::Raw,
147        lc: &Layout<Ix0>,
148        a: &<Self as DeviceRawAPI<TA>>::Raw,
149        la: &Layout<Ix1>,
150        b: &<Self as DeviceRawAPI<TB>>::Raw,
151        lb: &Layout<Ix1>,
152        alpha: TC,
153        beta: TC,
154    ) -> Result<()>;
155}