1use crate::plugin::*;
4use crate::transpose::*;
5use coaster::backend::Backend;
6use coaster::frameworks::native::Native;
7use coaster::tensor::{ITensorDesc, SharedTensor};
8use rblas;
9use rblas::math::bandmat::BandMat;
10use rblas::math::mat::Mat;
11use rblas::matrix::Matrix;
12
/// Acquires `$x` for reading on `$slf.device()` and calls `as_slice::<$t>()`
/// on the returned memory.
///
/// The expansion is a single expression. The `?` it contains returns early
/// from the enclosing function when `read` fails, so the macro can only be
/// used inside a function whose error type is compatible.
macro_rules! read {
    ($x:ident, $t:ident, $slf:ident) => {
        // No trailing semicolon: every call site uses this macro in
        // expression position, where a trailing `;` in the expansion trips
        // the `semicolon_in_expressions_from_macros` lint.
        $x.read($slf.device())?.as_slice::<$t>()
    };
}
18
/// Acquires `$x` for reading and writing on `$slf.device()` and calls
/// `as_mut_slice::<$t>()` on the returned memory.
///
/// The expansion is a single expression. The `?` it contains returns early
/// from the enclosing function when `read_write` fails.
macro_rules! read_write {
    ($x:ident, $t: ident, $slf:ident) => {
        // No trailing semicolon: every call site uses this macro in
        // expression position, where a trailing `;` in the expansion trips
        // the `semicolon_in_expressions_from_macros` lint.
        $x.read_write($slf.device())?.as_mut_slice::<$t>()
    };
}
24
/// Acquires `$x` for writing on `$slf.device()` and calls
/// `as_mut_slice::<$t>()` on the returned memory.
///
/// The expansion is a single expression. The `?` it contains returns early
/// from the enclosing function when `write_only` fails.
macro_rules! write_only {
    ($x:ident, $t: ident, $slf:ident) => {
        // No trailing semicolon: every call site uses this macro in
        // expression position, where a trailing `;` in the expansion trips
        // the `semicolon_in_expressions_from_macros` lint.
        $x.write_only($slf.device())?.as_mut_slice::<$t>()
    };
}
30
/// Generates the `asum` method for scalar type `$scalar`.
///
/// The generated method stores the value returned by `rblas::Asum::asum`
/// for `x` in the first element of `result`.
macro_rules! iblas_asum_for_native {
    ($scalar:ident) => {
        fn asum(
            &self,
            x: &SharedTensor<$scalar>,
            result: &mut SharedTensor<$scalar>,
        ) -> Result<(), ::coaster::error::Error> {
            // The output slot is acquired before the input is read.
            let output = write_only!(result, $scalar, self);
            let input = read!(x, $scalar, self);
            output[0] = rblas::Asum::asum(input);
            Ok(())
        }
    };
}
44
/// Generates the `axpy` method for scalar type `$scalar`.
///
/// The generated method calls `rblas::Axpy::axpy` with the first element of
/// `a` as the scalar, `x` as the input and `y` as the read-write operand.
macro_rules! iblas_axpy_for_native {
    ($scalar:ident) => {
        fn axpy(
            &self,
            a: &SharedTensor<$scalar>,
            x: &SharedTensor<$scalar>,
            y: &mut SharedTensor<$scalar>,
        ) -> Result<(), ::coaster::error::Error> {
            // Operands are acquired in argument order: scalar, input, output.
            let factor = &read!(a, $scalar, self)[0];
            let input = read!(x, $scalar, self);
            let accumulator = read_write!(y, $scalar, self);
            rblas::Axpy::axpy(factor, input, accumulator);
            Ok(())
        }
    };
}
62
/// Generates the `copy` method for scalar type `$scalar`.
///
/// The generated method calls `rblas::Copy::copy` with `x` as the source and
/// `y`, acquired write-only, as the destination.
macro_rules! iblas_copy_for_native {
    ($scalar:ident) => {
        fn copy(
            &self,
            x: &SharedTensor<$scalar>,
            y: &mut SharedTensor<$scalar>,
        ) -> Result<(), ::coaster::error::Error> {
            let source = read!(x, $scalar, self);
            let destination = write_only!(y, $scalar, self);
            rblas::Copy::copy(source, destination);
            Ok(())
        }
    };
}
75
/// Generates the `dot` method for scalar type `$scalar`.
///
/// The generated method stores the value returned by `rblas::Dot::dot` for
/// `x` and `y` in the first element of `result`.
macro_rules! iblas_dot_for_native {
    ($scalar:ident) => {
        fn dot(
            &self,
            x: &SharedTensor<$scalar>,
            y: &SharedTensor<$scalar>,
            result: &mut SharedTensor<$scalar>,
        ) -> Result<(), ::coaster::error::Error> {
            // The output slot is acquired before either input is read.
            let output = write_only!(result, $scalar, self);
            let lhs = read!(x, $scalar, self);
            let rhs = read!(y, $scalar, self);
            output[0] = rblas::Dot::dot(lhs, rhs);
            Ok(())
        }
    };
}
90
/// Generates the `nrm2` method for scalar type `$scalar`.
///
/// The generated method stores the value returned by `rblas::Nrm2::nrm2`
/// for `x` in the first element of `result`.
macro_rules! iblas_nrm2_for_native {
    ($scalar:ident) => {
        fn nrm2(
            &self,
            x: &SharedTensor<$scalar>,
            result: &mut SharedTensor<$scalar>,
        ) -> Result<(), ::coaster::error::Error> {
            // The output slot is acquired before the input is read.
            let output = write_only!(result, $scalar, self);
            let input = read!(x, $scalar, self);
            output[0] = rblas::Nrm2::nrm2(input);
            Ok(())
        }
    };
}
104
/// Generates the `scal` method for scalar type `$scalar`.
///
/// The generated method calls `rblas::Scal::scal` with the first element of
/// `a` as the scalar and `x` as the read-write operand.
macro_rules! iblas_scal_for_native {
    ($scalar:ident) => {
        fn scal(
            &self,
            a: &SharedTensor<$scalar>,
            x: &mut SharedTensor<$scalar>,
        ) -> Result<(), ::coaster::error::Error> {
            let factor = &read!(a, $scalar, self)[0];
            let values = read_write!(x, $scalar, self);
            rblas::Scal::scal(factor, values);
            Ok(())
        }
    };
}
117
/// Generates the `swap` method for scalar type `$scalar`.
///
/// The generated method acquires both tensors for reading and writing and
/// passes them to `rblas::Swap::swap`.
macro_rules! iblas_swap_for_native {
    ($scalar:ident) => {
        fn swap(
            &self,
            x: &mut SharedTensor<$scalar>,
            y: &mut SharedTensor<$scalar>,
        ) -> Result<(), ::coaster::error::Error> {
            let first = read_write!(x, $scalar, self);
            let second = read_write!(y, $scalar, self);
            rblas::Swap::swap(first, second);
            Ok(())
        }
    };
}
130
/// Generates the `gbmv` method for scalar type `$scalar`.
///
/// The generated method converts `a` into a `Mat` with `as_matrix`, wraps it
/// with `BandMat::from_matrix` using the first elements of `kl` and `ku`,
/// and calls `rblas::Gbmv::gbmv` with `c` as the read-write output.
macro_rules! iblas_gbmv_for_native {
    ($scalar: ident) => {
        fn gbmv(
            &self,
            alpha: &SharedTensor<$scalar>,
            at: Transpose,
            a: &SharedTensor<$scalar>,
            kl: &SharedTensor<u32>,
            ku: &SharedTensor<u32>,
            x: &SharedTensor<$scalar>,
            beta: &SharedTensor<$scalar>,
            c: &mut SharedTensor<$scalar>,
        ) -> Result<(), ::coaster::error::Error> {
            // Acquire the bulk operands first.
            let a_values = read!(a, $scalar, self);
            let x_values = read!(x, $scalar, self);
            let c_values = read_write!(c, $scalar, self);

            // The band widths are single values stored in tensors.
            let kl_value: u32 = read!(kl, u32, self)[0];
            let ku_value: u32 = read!(ku, u32, self)[0];

            let dense = as_matrix(a_values, a.desc().dims());
            let banded = BandMat::from_matrix(dense, kl_value, ku_value);

            // Remaining arguments are evaluated in the order the call
            // receives them.
            let a_op = at.to_rblas();
            let alpha_value = &read!(alpha, $scalar, self)[0];
            let beta_value = &read!(beta, $scalar, self)[0];
            rblas::Gbmv::gbmv(a_op, alpha_value, &banded, x_values, beta_value, c_values);

            Ok(())
        }
    };
}
168
/// Generates the `gemm` method for scalar type `$scalar`.
///
/// The generated method copies `a`, `b` and `c` into `Mat` values with
/// `as_matrix`, calls `rblas::Gemm::gemm` on them, and copies the resulting
/// matrix back into `c` with `read_from_matrix`.
macro_rules! iblas_gemm_for_native {
    ($scalar:ident) => {
        fn gemm(
            &self,
            alpha: &SharedTensor<$scalar>,
            at: Transpose,
            a: &SharedTensor<$scalar>,
            bt: Transpose,
            b: &SharedTensor<$scalar>,
            beta: &SharedTensor<$scalar>,
            c: &mut SharedTensor<$scalar>,
        ) -> Result<(), ::coaster::error::Error> {
            // The descriptor is cloned before `c` is mutably borrowed below.
            let c_desc = c.desc().clone();

            let a_values = read!(a, $scalar, self);
            let b_values = read!(b, $scalar, self);
            // NOTE(review): `c` is acquired write-only although its current
            // contents are copied into the matrix handed to `gemm` together
            // with `beta` — confirm this is intended.
            let c_values = write_only!(c, $scalar, self);

            let lhs = as_matrix(a_values, a.desc().dims());
            let rhs = as_matrix(b_values, b.desc().dims());
            let mut product = as_matrix(c_values, &c_desc);

            // Remaining arguments are evaluated in the order the call
            // receives them.
            let alpha_value = &read!(alpha, $scalar, self)[0];
            let lhs_op = at.to_rblas();
            let rhs_op = bt.to_rblas();
            let beta_value = &read!(beta, $scalar, self)[0];
            rblas::Gemm::gemm(
                alpha_value,
                lhs_op,
                &lhs,
                rhs_op,
                &rhs,
                beta_value,
                &mut product,
            );

            read_from_matrix(&product, c_values);
            Ok(())
        }
    };
}
204
/// Implements `IBlas` and each individual BLAS operation trait for backend
/// type `$backend` over scalar type `$scalar`.
///
/// Every method body comes from the matching `iblas_*_for_native!` macro.
macro_rules! impl_iblas_for {
    ($scalar:ident, $backend:ty) => {
        impl IBlas<$scalar> for $backend {}

        impl Asum<$scalar> for $backend {
            iblas_asum_for_native!($scalar);
        }

        impl Axpy<$scalar> for $backend {
            iblas_axpy_for_native!($scalar);
        }

        impl Copy<$scalar> for $backend {
            iblas_copy_for_native!($scalar);
        }

        impl Dot<$scalar> for $backend {
            iblas_dot_for_native!($scalar);
        }

        impl Nrm2<$scalar> for $backend {
            iblas_nrm2_for_native!($scalar);
        }

        impl Scal<$scalar> for $backend {
            iblas_scal_for_native!($scalar);
        }

        impl Swap<$scalar> for $backend {
            iblas_swap_for_native!($scalar);
        }

        impl Gbmv<$scalar> for $backend {
            iblas_gbmv_for_native!($scalar);
        }

        impl Gemm<$scalar> for $backend {
            iblas_gemm_for_native!($scalar);
        }
    };
}
252
// Instantiate the BLAS trait implementations for `Backend<Native>` over both
// `f32` and `f64`.
impl_iblas_for!(f32, Backend<Native>);
impl_iblas_for!(f64, Backend<Native>);
255
256fn as_matrix<T: Clone + ::std::fmt::Debug>(slice: &[T], dims: &[usize]) -> Mat<T> {
258 let n = dims[0];
259 let m = dims.iter().skip(1).product();
260 let mut mat: Mat<T> = Mat::new(n, m);
261 for i in 0..n {
262 for j in 0..m {
263 let index = m * i + j;
264 unsafe {
265 *mat.as_mut_ptr().add(index) = slice[index].clone();
266 }
267 }
268 }
269
270 mat
271}
272
273fn read_from_matrix<T: Clone>(mat: &Mat<T>, slice: &mut [T]) {
274 let n = mat.rows();
275 let m = mat.cols();
276 for i in 0..n {
277 for j in 0..m {
278 let index = m * i + j;
279 slice[index] = mat[i][j].clone();
280 }
281 }
282}
283
#[cfg(test)]
mod test {
    use super::as_matrix;
    use coaster::backend::{Backend, IBackend};
    use coaster::frameworks::native::flatbox::FlatBox;
    use coaster::frameworks::Native;
    use coaster::tensor::SharedTensor;

    /// Returns the default native backend, panicking if it cannot be built.
    fn get_native_backend() -> Backend<Native> {
        Backend::<Native>::default().unwrap()
    }

    /// Copies `data` element by element into the start of `mem`, viewed as a
    /// slice of `T`.
    pub fn write_to_memory<T: Copy>(mem: &mut FlatBox, data: &[T]) {
        let target = mem.as_mut_slice::<T>();
        for (position, &value) in data.iter().enumerate() {
            target[position] = value;
        }
    }

    #[test]
    fn it_converts_correctly_to_and_from_matrix() {
        let backend = get_native_backend();
        let values = [2f32, 5f32, 2f32, 5f32, 2f32, 5f32];

        let mut a = SharedTensor::<f32>::new(&vec![3, 2]);
        write_to_memory(a.write_only(backend.device()).unwrap(), &values);

        let flat = a.read(backend.device()).unwrap().as_slice::<f32>();
        let matrix = as_matrix(flat, &[3, 2]);

        // Each of the three rows must hold the pair (2, 5).
        for row in 0..3 {
            assert_eq!(matrix[row][0], 2f32);
            assert_eq!(matrix[row][1], 5f32);
        }
    }
}