1#![allow(clippy::too_many_arguments)]
4
5use core::ops::{Add, Mul};
6
7use crate::prelude_dev::*;
8
9pub trait DeviceMatMulAPI<TA, TB, TC, DA, DB, DC>
10where
11 DA: DimAPI,
12 DB: DimAPI,
13 DC: DimAPI,
14 TA: Mul<TB, Output = TC>,
15 TC: Mul<TC, Output = TC> + Add<TC, Output = TC>,
16 Self: DeviceAPI<TA> + DeviceAPI<TB> + DeviceAPI<TC>,
17{
18 fn matmul(
19 &self,
20 c: &mut <Self as DeviceRawAPI<TC>>::Raw,
21 lc: &Layout<DC>,
22 a: &<Self as DeviceRawAPI<TA>>::Raw,
23 la: &Layout<DA>,
24 b: &<Self as DeviceRawAPI<TB>>::Raw,
25 lb: &Layout<DB>,
26 alpha: TC,
27 beta: TC,
28 ) -> Result<()>;
29}
30
31pub trait DeviceGEMMAPI<TA, TB, TC>
32where
33 TA: Mul<TB, Output = TC>,
34 TC: Mul<TC, Output = TC> + Add<TC, Output = TC>,
35 Self: DeviceAPI<TA> + DeviceAPI<TB> + DeviceAPI<TC>,
36{
37 fn gemm(
38 &self,
39 c: &mut <Self as DeviceRawAPI<TC>>::Raw,
40 lc: &Layout<Ix2>,
41 a: &<Self as DeviceRawAPI<TA>>::Raw,
42 la: &Layout<Ix2>,
43 b: &<Self as DeviceRawAPI<TB>>::Raw,
44 lb: &Layout<Ix2>,
45 alpha: TC,
46 beta: TC,
47 ) -> Result<()>;
48}
49
50pub trait DeviceSYMMAPI<TA, TB, TC>
51where
52 TA: Mul<TB, Output = TC>,
53 TC: Mul<TC, Output = TC> + Add<TC, Output = TC>,
54 Self: DeviceAPI<TA> + DeviceAPI<TB> + DeviceAPI<TC>,
55{
56 fn symm(
57 &self,
58 c: &mut <Self as DeviceRawAPI<TC>>::Raw,
59 lc: &Layout<Ix2>,
60 a: &<Self as DeviceRawAPI<TA>>::Raw,
61 la: &Layout<Ix2>,
62 b: &<Self as DeviceRawAPI<TB>>::Raw,
63 lb: &Layout<Ix2>,
64 side: FlagSide,
65 uplo: FlagUpLo,
66 alpha: TC,
67 beta: TC,
68 ) -> Result<()>;
69}
70
71pub trait DeviceSYRKAPI<TA, TC>
72where
73 TA: Mul<TA, Output = TC>,
74 TC: Mul<TC, Output = TC> + Add<TC, Output = TC>,
75 Self: DeviceAPI<TA> + DeviceAPI<TC>,
76{
77 fn syrk(
78 &self,
79 c: &mut <Self as DeviceRawAPI<TC>>::Raw,
80 lc: &Layout<Ix2>,
81 a: &<Self as DeviceRawAPI<TA>>::Raw,
82 la: &Layout<Ix2>,
83 uplo: FlagUpLo,
84 alpha: TC,
85 beta: TC,
86 ) -> Result<()>;
87}
88
89pub trait DeviceHERKAPI<TA, TC>
90where
91 TA: Mul<TA, Output = TC>,
92 TC: Mul<TC, Output = TC> + Add<TC, Output = TC>,
93 Self: DeviceAPI<TA> + DeviceAPI<TC>,
94{
95 fn herk(
96 &self,
97 c: &mut <Self as DeviceRawAPI<TC>>::Raw,
98 lc: &Layout<Ix2>,
99 a: &<Self as DeviceRawAPI<TA>>::Raw,
100 la: &Layout<Ix2>,
101 uplo: FlagUpLo,
102 alpha: TC,
103 beta: TC,
104 ) -> Result<()>;
105}
106
107pub trait DeviceGEMVAPI<TA, TB, TC>
108where
109 TA: Mul<TB, Output = TC>,
110 TC: Mul<TC, Output = TC> + Add<TC, Output = TC>,
111 Self: DeviceAPI<TA> + DeviceAPI<TB> + DeviceAPI<TC>,
112{
113 fn gemv(
114 &self,
115 c: &mut <Self as DeviceRawAPI<TC>>::Raw,
116 lc: &Layout<Ix1>,
117 a: &<Self as DeviceRawAPI<TA>>::Raw,
118 la: &Layout<Ix2>,
119 b: &<Self as DeviceRawAPI<TB>>::Raw,
120 lb: &Layout<Ix1>,
121 alpha: TC,
122 beta: TC,
123 ) -> Result<()>;
124
125 fn gevm(
126 &self,
127 c: &mut <Self as DeviceRawAPI<TC>>::Raw,
128 lc: &Layout<Ix1>,
129 a: &<Self as DeviceRawAPI<TA>>::Raw,
130 la: &Layout<Ix1>,
131 b: &<Self as DeviceRawAPI<TB>>::Raw,
132 lb: &Layout<Ix2>,
133 alpha: TC,
134 beta: TC,
135 ) -> Result<()>;
136}
137
138pub trait DeviceInnerDotAPI<TA, TB, TC>
139where
140 TA: Mul<TB, Output = TC>,
141 TC: Mul<TC, Output = TC> + Add<TC, Output = TC>,
142 Self: DeviceAPI<TA> + DeviceAPI<TB> + DeviceAPI<TC>,
143{
144 fn inner_dot(
145 &self,
146 c: &mut <Self as DeviceRawAPI<TC>>::Raw,
147 lc: &Layout<Ix0>,
148 a: &<Self as DeviceRawAPI<TA>>::Raw,
149 la: &Layout<Ix1>,
150 b: &<Self as DeviceRawAPI<TB>>::Raw,
151 lb: &Layout<Ix1>,
152 alpha: TC,
153 beta: TC,
154 ) -> Result<()>;
155}