unmtx_gpu/
lib.rs

1//
2// Copyright (c) 2025 Łukasz Szpakowski
3//
4// This Source Code Form is subject to the terms of the Mozilla Public
5// License, v. 2.0. If a copy of the MPL was not distributed with this
6// file, You can obtain one at https://mozilla.org/MPL/2.0/.
7//
//! Micro neural matrix library for GPU is a small library that operates on matrices.
9//!
//! This library uses the GPU through the following computing platforms:
11//!
12//! - OpenCL
13//! - CUDA
14//!
//! If this library uses CUDA, it can use the cuBLAS library for the multiplication of
//! matrices.
17//!
//! This library uses a frontend-backend architecture. The frontend of this library can use one
//! of two backends (OpenCL or CUDA). These backends allow GPUs to be used through the computing
//! platforms. The frontend and the backend can have many instances. This library provides a
//! high-level interface to operations on matrices through the frontend and the methods of the
//! [`Matrix`] structure.
23//!
24//! # Examples
25//!
26//! ```
27//! # use unmtx_gpu::*;
28//! let a = matrix![
29//!     [1.0, 2.0],
30//!     [3.0, 4.0]
31//! ];
32//! let x = matrix![
33//!     [5.0],
34//!     [6.0]
35//! ];
36//! let b = matrix![
37//!     [7.0],
38//!     [8.0]
39//! ];
40//! let c = a * x + b;
41//! assert_eq!(vec![1.0 * 5.0 + 2.0 * 6.0 + 7.0, 3.0 * 5.0 + 4.0 * 6.0 + 8.0], c.elems());
42//! ```
43use std::ops::Add;
44use std::ops::AddAssign;
45use std::ops::Sub;
46use std::ops::SubAssign;
47use std::ops::Mul;
48use std::ops::MulAssign;
49use std::ops::Div;
50use std::ops::DivAssign;
51use std::error;
52use std::fmt;
53use std::result;
54use std::sync::Arc;
55use std::sync::Mutex;
56use std::sync::MutexGuard;
57
58#[cfg(feature = "opencl")]
59pub mod opencl;
60#[cfg(feature = "cuda")]
61pub mod cuda;
62
/// A backend trait.
///
/// The backend provides a low-level interface to a computing platform (OpenCL or CUDA) for basic
/// operations and functions on matrices. The backend methods operate on backend arrays, which
/// refer to areas of the device memory. The backend is the low-level layer between a frontend and
/// the computing platform.
///
/// NOTE(review): the `n`, `m`, and `l` arguments of the methods presumably describe matrix
/// dimensions (numbers of rows and columns); their exact meaning is defined by the backend
/// implementations and isn't visible in this trait, so confirm it before relying on it.
pub trait Backend
{
    /// Returns the backend name.
    fn name(&self) -> &'static str;

    /// Returns `true` if the backend uses cuBLAS, otherwise `false`.
    fn has_cublas(&self) -> bool;

    /// Allocates a backend array.
    ///
    /// NOTE(review): `n` is presumably the number of elements of the backend array; confirm.
    ///
    /// # Safety
    ///
    /// NOTE(review): this method is presumably `unsafe` because the allocated device memory
    /// isn't initialized, unlike the memory allocated by
    /// [`alloc_and_store_zeros`](Self::alloc_and_store_zeros) and
    /// [`alloc_and_store`](Self::alloc_and_store); confirm the exact contract with the backend
    /// implementations.
    unsafe fn alloc(&self, n: usize) -> Result<BackendArray>;

    /// Allocates a backend array and stores zeros in the backend array.
    fn alloc_and_store_zeros(&self, n: usize) -> Result<BackendArray>;

    /// Allocates a backend array and stores the elements in the backend array.
    fn alloc_and_store(&self, elems: &[f32]) -> Result<BackendArray>;

    /// Loads elements from the backend array.
    fn load(&self, a: &BackendArray, elems: &mut [f32]) -> Result<()>;

    /// Stores elements in the backend array.
    fn store(&self, a: &BackendArray, elems: &[f32]) -> Result<()>;

    /// Copies the `a` backend array to the `b` backend array.
    fn copy(&self, a: &BackendArray, b: &BackendArray) -> Result<()>;

    /// Transposes the `a` matrix and then the result is in the `b` matrix
    /// (<math xmlns="http://www.w3.org/1998/Math/MathML"><mrow><mi mathvariant="bold">B</mi><mo>=</mo><msup><mi mathvariant="bold">A</mi><mi mathvariant="normal">T</mi></msup></mrow></math>).
    fn transpose_a(&self, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>;

    /// Adds the `a` matrix onto the `b` matrix and then the result is in the `c` matrix
    /// (<math xmlns="http://www.w3.org/1998/Math/MathML"><mrow><mi mathvariant="bold">C</mi><mo>=</mo><mi mathvariant="bold">A</mi><mo>+</mo><mi mathvariant="bold">B</mi></mrow></math>).
    fn add_a_b(&self, a: &BackendArray, b: &BackendArray, c: &BackendArray, n: usize, m: usize) -> Result<()>;

    /// Adds the transposed `a` matrix onto the `b` matrix and then the result is in the `c` matrix
    /// (<math xmlns="http://www.w3.org/1998/Math/MathML"><mrow><mi mathvariant="bold">C</mi><mo>=</mo><msup><mi mathvariant="bold">A</mi><mi mathvariant="normal">T</mi></msup><mo>+</mo><mi mathvariant="bold">B</mi></mrow></math>).
    fn add_at_b(&self, a: &BackendArray, b: &BackendArray, c: &BackendArray, n: usize, m: usize) -> Result<()>;

    /// Adds the `a` matrix onto the transposed `b` matrix and then the result is in the `c` matrix
    /// (<math xmlns="http://www.w3.org/1998/Math/MathML"><mrow><mi mathvariant="bold">C</mi><mo>=</mo><mi mathvariant="bold">A</mi><mo>+</mo><msup><mi mathvariant="bold">B</mi><mi mathvariant="normal">T</mi></msup></mrow></math>).
    fn add_a_bt(&self, a: &BackendArray, b: &BackendArray, c: &BackendArray, n: usize, m: usize) -> Result<()>;

    /// Adds the transposed `a` matrix onto the transposed `b` matrix and then the result is in the
    /// `c` matrix
    /// (<math xmlns="http://www.w3.org/1998/Math/MathML"><mrow><mi mathvariant="bold">C</mi><mo>=</mo><msup><mi mathvariant="bold">A</mi><mi mathvariant="normal">T</mi></msup><mo>+</mo><msup><mi mathvariant="bold">B</mi><mi mathvariant="normal">T</mi></msup></mrow></math>).
    fn add_at_bt(&self, a: &BackendArray, b: &BackendArray, c: &BackendArray, n: usize, m: usize) -> Result<()>;

    /// Subtracts the `b` matrix from the `a` matrix and then the result is in the `c` matrix
    /// (<math xmlns="http://www.w3.org/1998/Math/MathML"><mrow><mi mathvariant="bold">C</mi><mo>=</mo><mi mathvariant="bold">A</mi><mo>-</mo><mi mathvariant="bold">B</mi></mrow></math>).
    fn sub_a_b(&self, a: &BackendArray, b: &BackendArray, c: &BackendArray, n: usize, m: usize) -> Result<()>;

    /// Subtracts the `b` matrix from the transposed `a` matrix and then the result is in the `c`
    /// matrix
    /// (<math xmlns="http://www.w3.org/1998/Math/MathML"><mrow><mi mathvariant="bold">C</mi><mo>=</mo><msup><mi mathvariant="bold">A</mi><mi mathvariant="normal">T</mi></msup><mo>-</mo><mi mathvariant="bold">B</mi></mrow></math>).
    fn sub_at_b(&self, a: &BackendArray, b: &BackendArray, c: &BackendArray, n: usize, m: usize) -> Result<()>;

    /// Subtracts the transposed `b` matrix from the `a` matrix and then the result is in the `c`
    /// matrix
    /// (<math xmlns="http://www.w3.org/1998/Math/MathML"><mrow><mi mathvariant="bold">C</mi><mo>=</mo><mi mathvariant="bold">A</mi><mo>-</mo><msup><mi mathvariant="bold">B</mi><mi mathvariant="normal">T</mi></msup></mrow></math>).
    fn sub_a_bt(&self, a: &BackendArray, b: &BackendArray, c: &BackendArray, n: usize, m: usize) -> Result<()>;

    /// Subtracts the transposed `b` matrix from the transposed `a` matrix and then the result is
    /// in the `c` matrix
    /// (<math xmlns="http://www.w3.org/1998/Math/MathML"><mrow><mi mathvariant="bold">C</mi><mo>=</mo><msup><mi mathvariant="bold">A</mi><mi mathvariant="normal">T</mi></msup><mo>-</mo><msup><mi mathvariant="bold">B</mi><mi mathvariant="normal">T</mi></msup></mrow></math>).
    fn sub_at_bt(&self, a: &BackendArray, b: &BackendArray, c: &BackendArray, n: usize, m: usize) -> Result<()>;

    /// Multiplies the `a` matrix by the `b` matrix and then the result is in the `c` matrix
    /// (<math xmlns="http://www.w3.org/1998/Math/MathML"><mrow><mi mathvariant="bold">C</mi><mo>=</mo><mi mathvariant="bold">A</mi><mo>·</mo><mi mathvariant="bold">B</mi></mrow></math>).
    fn mul_a_b(&self, a: &BackendArray, b: &BackendArray, c: &BackendArray, n: usize, m: usize, l: usize) -> Result<()>;

    /// Multiplies the transposed `a` matrix by the `b` matrix and then the result is in the `c`
    /// matrix
    /// (<math xmlns="http://www.w3.org/1998/Math/MathML"><mrow><mi mathvariant="bold">C</mi><mo>=</mo><msup><mi mathvariant="bold">A</mi><mi mathvariant="normal">T</mi></msup><mo>·</mo><mi mathvariant="bold">B</mi></mrow></math>).
    fn mul_at_b(&self, a: &BackendArray, b: &BackendArray, c: &BackendArray, n: usize, m: usize, l: usize) -> Result<()>;

    /// Multiplies the `a` matrix by the transposed `b` matrix and then the result is in the `c`
    /// matrix
    /// (<math xmlns="http://www.w3.org/1998/Math/MathML"><mrow><mi mathvariant="bold">C</mi><mo>=</mo><mi mathvariant="bold">A</mi><mo>·</mo><msup><mi mathvariant="bold">B</mi><mi mathvariant="normal">T</mi></msup></mrow></math>).
    fn mul_a_bt(&self, a: &BackendArray, b: &BackendArray, c: &BackendArray, n: usize, m: usize, l: usize) -> Result<()>;

    /// Multiplies the transposed `a` matrix by the transposed `b` matrix and then the result is in
    /// the `c` matrix
    /// (<math xmlns="http://www.w3.org/1998/Math/MathML"><mrow><mi mathvariant="bold">C</mi><mo>=</mo><msup><mi mathvariant="bold">A</mi><mi mathvariant="normal">T</mi></msup><mo>·</mo><msup><mi mathvariant="bold">B</mi><mi mathvariant="normal">T</mi></msup></mrow></math>).
    fn mul_at_bt(&self, a: &BackendArray, b: &BackendArray, c: &BackendArray, n: usize, m: usize, l: usize) -> Result<()>;

    /// Multiplies the `a` matrix elements by the `b` matrix elements and then the result is in the
    /// `c` matrix
    /// (<math xmlns="http://www.w3.org/1998/Math/MathML"><mrow><msub><mi>c</mi><mi mathvariant="italic">ij</mi></msub><mo>=</mo><msub><mi>a</mi><mi mathvariant="italic">ij</mi></msub><mo>·</mo><msub><mi>b</mi><mi mathvariant="italic">ij</mi></msub></mrow></math>).
    fn mul_a_b_for_elems(&self, a: &BackendArray, b: &BackendArray, c: &BackendArray, n: usize, m: usize) -> Result<()>;

    /// Multiplies the transposed `a` matrix elements by the `b` matrix elements and then the
    /// result is in the `c` matrix
    /// (<math xmlns="http://www.w3.org/1998/Math/MathML"><mrow><msub><mi>c</mi><mi mathvariant="italic">ij</mi></msub><mo>=</mo><msub><mi>a</mi><mi mathvariant="italic">ji</mi></msub><mo>·</mo><msub><mi>b</mi><mi mathvariant="italic">ij</mi></msub></mrow></math>).
    fn mul_at_b_for_elems(&self, a: &BackendArray, b: &BackendArray, c: &BackendArray, n: usize, m: usize) -> Result<()>;

    /// Multiplies the `a` matrix elements by the transposed `b` matrix elements and then the
    /// result is in the `c` matrix
    /// (<math xmlns="http://www.w3.org/1998/Math/MathML"><mrow><msub><mi>c</mi><mi mathvariant="italic">ij</mi></msub><mo>=</mo><msub><mi>a</mi><mi mathvariant="italic">ij</mi></msub><mo>·</mo><msub><mi>b</mi><mi mathvariant="italic">ji</mi></msub></mrow></math>).
    fn mul_a_bt_for_elems(&self, a: &BackendArray, b: &BackendArray, c: &BackendArray, n: usize, m: usize) -> Result<()>;

    /// Multiplies the transposed `a` matrix elements by the transposed `b` matrix elements and
    /// then the result is in the `c` matrix
    /// (<math xmlns="http://www.w3.org/1998/Math/MathML"><mrow><msub><mi>c</mi><mi mathvariant="italic">ij</mi></msub><mo>=</mo><msub><mi>a</mi><mi mathvariant="italic">ji</mi></msub><mo>·</mo><msub><mi>b</mi><mi mathvariant="italic">ji</mi></msub></mrow></math>).
    fn mul_at_bt_for_elems(&self, a: &BackendArray, b: &BackendArray, c: &BackendArray, n: usize, m: usize) -> Result<()>;

    /// Divides the `a` matrix elements by the `b` matrix elements and then the result is in the
    /// `c` matrix
    /// (<math xmlns="http://www.w3.org/1998/Math/MathML"><mrow><msub><mi>c</mi><mi mathvariant="italic">ij</mi></msub><mo>=</mo><mfrac><msub><mi>a</mi><mi mathvariant="italic">ij</mi></msub><msub><mi>b</mi><mi mathvariant="italic">ij</mi></msub></mfrac></mrow></math>).
    fn div_a_b_for_elems(&self, a: &BackendArray, b: &BackendArray, c: &BackendArray, n: usize, m: usize) -> Result<()>;

    /// Divides the transposed `a` matrix elements by the `b` matrix elements and then the result
    /// is in the `c` matrix
    /// (<math xmlns="http://www.w3.org/1998/Math/MathML"><mrow><msub><mi>c</mi><mi mathvariant="italic">ij</mi></msub><mo>=</mo><mfrac><msub><mi>a</mi><mi mathvariant="italic">ji</mi></msub><msub><mi>b</mi><mi mathvariant="italic">ij</mi></msub></mfrac></mrow></math>).
    fn div_at_b_for_elems(&self, a: &BackendArray, b: &BackendArray, c: &BackendArray, n: usize, m: usize) -> Result<()>;

    /// Divides the `a` matrix elements by the transposed `b` matrix elements and then the result
    /// is in the `c` matrix
    /// (<math xmlns="http://www.w3.org/1998/Math/MathML"><mrow><msub><mi>c</mi><mi mathvariant="italic">ij</mi></msub><mo>=</mo><mfrac><msub><mi>a</mi><mi mathvariant="italic">ij</mi></msub><msub><mi>b</mi><mi mathvariant="italic">ji</mi></msub></mfrac></mrow></math>).
    fn div_a_bt_for_elems(&self, a: &BackendArray, b: &BackendArray, c: &BackendArray, n: usize, m: usize) -> Result<()>;

    /// Divides the transposed `a` matrix elements by the transposed `b` matrix elements and then
    /// the result is in the `c` matrix
    /// (<math xmlns="http://www.w3.org/1998/Math/MathML"><mrow><msub><mi>c</mi><mi mathvariant="italic">ij</mi></msub><mo>=</mo><mfrac><msub><mi>a</mi><mi mathvariant="italic">ji</mi></msub><msub><mi>b</mi><mi mathvariant="italic">ji</mi></msub></mfrac></mrow></math>).
    fn div_at_bt_for_elems(&self, a: &BackendArray, b: &BackendArray, c: &BackendArray, n: usize, m: usize) -> Result<()>;

    /// Adds the `a` matrix onto the `b` scalar and then the result is in the `c` matrix
    /// (<math xmlns="http://www.w3.org/1998/Math/MathML"><mrow><mi mathvariant="bold">C</mi><mo>=</mo><mi mathvariant="bold">A</mi><mo>+</mo><mi>b</mi></mrow></math>).
    fn add_a_b_for_scalar(&self, a: &BackendArray, b: f32, c: &BackendArray, n: usize, m: usize) -> Result<()>;

    /// Adds the transposed `a` matrix onto the `b` scalar and then the result is in the `c` matrix
    /// (<math xmlns="http://www.w3.org/1998/Math/MathML"><mrow><mi mathvariant="bold">C</mi><mo>=</mo><msup><mi mathvariant="bold">A</mi><mi mathvariant="normal">T</mi></msup><mo>+</mo><mi>b</mi></mrow></math>).
    fn add_at_b_for_scalar(&self, a: &BackendArray, b: f32, c: &BackendArray, n: usize, m: usize) -> Result<()>;

    /// Subtracts the `b` scalar from the `a` matrix and then the result is in the `c` matrix
    /// (<math xmlns="http://www.w3.org/1998/Math/MathML"><mrow><mi mathvariant="bold">C</mi><mo>=</mo><mi mathvariant="bold">A</mi><mo>-</mo><mi>b</mi></mrow></math>).
    fn sub_a_b_for_scalar(&self, a: &BackendArray, b: f32, c: &BackendArray, n: usize, m: usize) -> Result<()>;

    /// Subtracts the `b` scalar from the transposed `a` matrix and then the result is in the `c`
    /// matrix
    /// (<math xmlns="http://www.w3.org/1998/Math/MathML"><mrow><mi mathvariant="bold">C</mi><mo>=</mo><msup><mi mathvariant="bold">A</mi><mi mathvariant="normal">T</mi></msup><mo>-</mo><mi>b</mi></mrow></math>).
    fn sub_at_b_for_scalar(&self, a: &BackendArray, b: f32, c: &BackendArray, n: usize, m: usize) -> Result<()>;

    /// Subtracts the `a` matrix from the `b` scalar and then the result is in the `c` matrix
    /// (<math xmlns="http://www.w3.org/1998/Math/MathML"><mrow><mi mathvariant="bold">C</mi><mo>=</mo><mi>b</mi><mo>-</mo><mi mathvariant="bold">A</mi></mrow></math>).
    fn rsub_a_b_for_scalar(&self, a: &BackendArray, b: f32, c: &BackendArray, n: usize, m: usize) -> Result<()>;

    /// Subtracts the transposed `a` matrix from the `b` scalar and then the result is in the `c`
    /// matrix
    /// (<math xmlns="http://www.w3.org/1998/Math/MathML"><mrow><mi mathvariant="bold">C</mi><mo>=</mo><mi>b</mi><mo>-</mo><msup><mi mathvariant="bold">A</mi><mi mathvariant="normal">T</mi></msup></mrow></math>).
    fn rsub_at_b_for_scalar(&self, a: &BackendArray, b: f32, c: &BackendArray, n: usize, m: usize) -> Result<()>;

    /// Multiplies the `a` matrix by the `b` scalar and then the result is in the `c` matrix
    /// (<math xmlns="http://www.w3.org/1998/Math/MathML"><mrow><mi mathvariant="bold">C</mi><mo>=</mo><mi mathvariant="bold">A</mi><mo>·</mo><mi>b</mi></mrow></math>).
    fn mul_a_b_for_scalar(&self, a: &BackendArray, b: f32, c: &BackendArray, n: usize, m: usize) -> Result<()>;

    /// Multiplies the transposed `a` matrix by the `b` scalar and then the result is in the `c`
    /// matrix
    /// (<math xmlns="http://www.w3.org/1998/Math/MathML"><mrow><mi mathvariant="bold">C</mi><mo>=</mo><msup><mi mathvariant="bold">A</mi><mi mathvariant="normal">T</mi></msup><mo>·</mo><mi>b</mi></mrow></math>).
    fn mul_at_b_for_scalar(&self, a: &BackendArray, b: f32, c: &BackendArray, n: usize, m: usize) -> Result<()>;

    /// Divides the `a` matrix by the `b` scalar and then the result is in the `c` matrix
    /// (<math xmlns="http://www.w3.org/1998/Math/MathML"><mrow><mi mathvariant="bold">C</mi><mo>=</mo><mfrac><mi mathvariant="bold">A</mi><mi>b</mi></mfrac></mrow></math>).
    fn div_a_b_for_scalar(&self, a: &BackendArray, b: f32, c: &BackendArray, n: usize, m: usize) -> Result<()>;

    /// Divides the transposed `a` matrix by the `b` scalar and then the result is in the `c`
    /// matrix
    /// (<math xmlns="http://www.w3.org/1998/Math/MathML"><mrow><mi mathvariant="bold">C</mi><mo>=</mo><mfrac><msup><mi mathvariant="bold">A</mi><mi mathvariant="normal">T</mi></msup><mi>b</mi></mfrac></mrow></math>).
    fn div_at_b_for_scalar(&self, a: &BackendArray, b: f32, c: &BackendArray, n: usize, m: usize) -> Result<()>;

    /// Divides the `b` scalar by the `a` matrix elements and then the result is in the `c` matrix
    /// (<math xmlns="http://www.w3.org/1998/Math/MathML"><mrow><msub><mi>c</mi><mi mathvariant="italic">ij</mi></msub><mo>=</mo><mfrac><mi>b</mi><msub><mi>a</mi><mi mathvariant="italic">ij</mi></msub></mfrac></mrow></math>).
    fn rdiv_a_b_for_scalar(&self, a: &BackendArray, b: f32, c: &BackendArray, n: usize, m: usize) -> Result<()>;

    /// Divides the `b` scalar by the transposed `a` matrix elements and then the result is in the
    /// `c` matrix
    /// (<math xmlns="http://www.w3.org/1998/Math/MathML"><mrow><msub><mi>c</mi><mi mathvariant="italic">ij</mi></msub><mo>=</mo><mfrac><mi>b</mi><msub><mi>a</mi><mi mathvariant="italic">ji</mi></msub></mfrac></mrow></math>).
    fn rdiv_at_b_for_scalar(&self, a: &BackendArray, b: f32, c: &BackendArray, n: usize, m: usize) -> Result<()>;

    /// Calculates sigmoid function for the `a` matrix and then the result is in the `b` matrix
    /// (<math xmlns="http://www.w3.org/1998/Math/MathML"><mrow><mi mathvariant="bold">B</mi><mo>=</mo><mi>sigmoid</mi><mo fence="true">(</mo><mi mathvariant="bold">A</mi><mo fence="true">)</mo></mrow></math>).
    fn sigmoid_a(&self, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>;

    /// Calculates sigmoid function for the transposed `a` matrix and then the result is in the `b`
    /// matrix
    /// (<math xmlns="http://www.w3.org/1998/Math/MathML"><mrow><mi mathvariant="bold">B</mi><mo>=</mo><mi>sigmoid</mi><mo fence="true">(</mo><msup><mi mathvariant="bold">A</mi><mi mathvariant="normal">T</mi></msup><mo fence="true">)</mo></mrow></math>).
    fn sigmoid_at(&self, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>;

    /// Calculates hyperbolic tangent function for the `a` matrix and then the result is in the
    /// `b` matrix
    /// (<math xmlns="http://www.w3.org/1998/Math/MathML"><mrow><mi mathvariant="bold">B</mi><mo>=</mo><mi>tanh</mi><mo fence="true">(</mo><mi mathvariant="bold">A</mi><mo fence="true">)</mo></mrow></math>).
    fn tanh_a(&self, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>;

    /// Calculates hyperbolic tangent function for the transposed `a` matrix and then the result is
    /// in the `b` matrix
    /// (<math xmlns="http://www.w3.org/1998/Math/MathML"><mrow><mi mathvariant="bold">B</mi><mo>=</mo><mi>tanh</mi><mo fence="true">(</mo><msup><mi mathvariant="bold">A</mi><mi mathvariant="normal">T</mi></msup><mo fence="true">)</mo></mrow></math>).
    fn tanh_at(&self, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>;

    /// Calculates softmax function for the `a` matrix and then the result is in the `b` matrix
    /// (<math xmlns="http://www.w3.org/1998/Math/MathML"><mrow><mi mathvariant="bold">B</mi><mo>=</mo><mi>softmax</mi><mo fence="true">(</mo><mi mathvariant="bold">A</mi><mo fence="true">)</mo></mrow></math>).
    fn softmax_a(&self, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>;

    /// Calculates softmax function for the transposed `a` matrix and then the result is in the `b`
    /// matrix
    /// (<math xmlns="http://www.w3.org/1998/Math/MathML"><mrow><mi mathvariant="bold">B</mi><mo>=</mo><mi>softmax</mi><mo fence="true">(</mo><msup><mi mathvariant="bold">A</mi><mi mathvariant="normal">T</mi></msup><mo fence="true">)</mo></mrow></math>).
    fn softmax_at(&self, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>;

    /// Repeats the `a` vector as column and then the result is in the `b` matrix
    /// (<math xmlns="http://www.w3.org/1998/Math/MathML"><mrow><msub><mi>b</mi><mi mathvariant="italic">ij</mi></msub><mo>=</mo><msub><mi>a</mi><mi>i</mi></msub></mrow></math>).
    fn repeat_col_a(&self, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>;

    /// Repeats the `a` vector as row and then the result is in the `b` matrix
    /// (<math xmlns="http://www.w3.org/1998/Math/MathML"><mrow><msub><mi>b</mi><mi mathvariant="italic">ij</mi></msub><mo>=</mo><msub><mi>a</mi><mi>j</mi></msub></mrow></math>).
    fn repeat_row_a(&self, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>;
}
283
/// An error enumeration.
///
/// Fallible operations of this library report their failures with this type. The [`fmt::Display`]
/// implementation renders a short, lower-case message for each variant.
#[derive(Debug)]
pub enum Error
{
    /// A default backend can't be initialized.
    DefaultBackendInitialization,
    /// The sizes of matrices are mismatched for a matrix operation. The four fields are displayed
    /// as two `NxM` sizes.
    OpSize(usize, usize, usize, usize),
    /// The sizes of matrices are mismatched for a matrix multiplication. The six fields are
    /// displayed as three `NxM` sizes.
    MulSize(usize, usize, usize, usize, usize, usize),
    /// The sizes of matrices are mismatched for a matrix transposition. The four fields are
    /// displayed as two `NxM` sizes.
    TransposeSize(usize, usize, usize, usize),
    /// An argument matrix is transposed.
    ArgTransposition,
    /// A result matrix is transposed.
    ResTransposition,
    /// The number of matrix elements isn't equal to the number of elements.
    MatrixElemCount(usize, usize),
    /// A matrix isn't a vector.
    IsNotVector,
    /// A mutex can't be locked.
    Mutex,
    /// An OpenCL error.
    #[cfg(feature = "opencl")]
    OpenCl(opencl::ClError),
    /// A CUDA error.
    #[cfg(feature = "cuda")]
    Cuda(cuda::DriverError),
    /// A cuBLAS error.
    #[cfg(feature = "cuda")]
    Cublas(cuda::CublasError),
    /// No cuBLAS.
    #[cfg(feature = "cuda")]
    NoCublas,
    /// A compilation error with its message.
    Compilation(String),
    /// No platform.
    NoPlatform,
    /// No device.
    NoDevice,
    /// No kernel with the given name.
    NoKernel(String),
    /// The type of device information is invalid.
    InvalidDeviceInfoType,
    /// The number of backend array elements isn't equal to the number of elements.
    BackendArrayElemCount(usize, usize),
    /// The numbers of elements of two backend arrays aren't equal.
    TwoBackendArrayElemCounts(usize, usize),
    /// A backend array is invalid.
    InvalidBackendArray,
}

// The default methods of the standard error trait are sufficient; no error source is exposed.
impl error::Error for Error
{}

impl fmt::Display for Error
{
    /// Writes the human-readable message of the error to the formatter.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result
    {
        match self {
            Self::DefaultBackendInitialization => f.write_str("can't initialize default backend"),
            Self::OpSize(a_n, a_m, b_n, b_m) => {
                write!(f, "mismatched sizes of matrices ({}x{}, {}x{})", a_n, a_m, b_n, b_m)
            },
            Self::MulSize(a_n, a_m, b_n, b_m, c_n, c_m) => {
                write!(f, "mismatched sizes of matrices for multiplication ({}x{}, {}x{}, {}x{})", a_n, a_m, b_n, b_m, c_n, c_m)
            },
            Self::TransposeSize(a_n, a_m, b_n, b_m) => {
                write!(f, "mismatched sizes of matrices for transposition ({}x{}, {}x{})", a_n, a_m, b_n, b_m)
            },
            Self::ArgTransposition => f.write_str("argument matrix is transposed"),
            Self::ResTransposition => f.write_str("result matrix is transposed"),
            Self::MatrixElemCount(matrix_count, elem_count) => {
                write!(f, "number of matrix elements isn't equal to number of elements ({}, {})", matrix_count, elem_count)
            },
            Self::IsNotVector => f.write_str("matrix isn't vector"),
            Self::Mutex => f.write_str("can't lock mutex"),
            #[cfg(feature = "opencl")]
            Self::OpenCl(cl_err) => write!(f, "OpenCL error: {}", cl_err),
            #[cfg(feature = "cuda")]
            Self::Cuda(driver_err) => write!(f, "CUDA error: {}", driver_err),
            #[cfg(feature = "cuda")]
            Self::Cublas(cublas_err) => write!(f, "cuBLAS error: {}", cublas_err),
            #[cfg(feature = "cuda")]
            Self::NoCublas => f.write_str("no cuBLAS"),
            Self::Compilation(message) => write!(f, "{}", message),
            Self::NoPlatform => f.write_str("no platform"),
            Self::NoDevice => f.write_str("no device"),
            Self::NoKernel(kernel_name) => write!(f, "no kernel {}", kernel_name),
            Self::InvalidDeviceInfoType => f.write_str("invalid device info type"),
            Self::BackendArrayElemCount(array_count, elem_count) => {
                write!(f, "number of backend array elements isn't equal to number of elements ({}, {})", array_count, elem_count)
            },
            Self::TwoBackendArrayElemCounts(first_count, second_count) => {
                write!(f, "two numbers of elements of backend arrays aren't equal ({}, {})", first_count, second_count)
            },
            Self::InvalidBackendArray => f.write_str("invalid backend array"),
        }
    }
}
372
/// A result type.
///
/// This is the standard result type with the error type fixed to [`Error`].
pub type Result<T> = result::Result<T, Error>;
375
/// An enumeration of backend array.
///
/// This enumeration contains the reference to the area of the device memory for a computing
/// platform (OpenCL or CUDA).
///
/// Each variant is only compiled in when the corresponding feature (`opencl` or `cuda`) is
/// enabled, so this enumeration has no variants if neither feature is enabled.
#[derive(Debug)]
pub enum BackendArray
{
    /// A backend array for OpenCL.
    #[cfg(feature = "opencl")]
    OpenCl(opencl::ClBackendArray),
    /// A backend array for CUDA.
    #[cfg(feature = "cuda")]
    Cuda(cuda::CudaBackendArray),
}
390
// The default backend that is shared by the functions operating on the default backend. It is
// `None` until a backend is set.
//
// NOTE(review): this is presumably a `static mut` because `dyn Backend` has no `Send + Sync`
// bound, which a non-`mut` static of this type would require; confirm. The contents are only
// accessed through the mutex in the code visible here.
static mut DEFAULT_BACKEND: Mutex<Option<Arc<dyn Backend>>> = Mutex::new(None);
392
393fn mutex_lock<T>(mutex: &Mutex<T>) -> Result<MutexGuard<'_, T>>
394{
395    match mutex.lock() {
396        Ok(guard) => Ok(guard),
397        Err(_) => return Err(Error::Mutex),
398    }
399}
400
401/// Returns a default backend.
402pub fn get_default_backend() -> Result<Option<Arc<dyn Backend>>>
403{
404    unsafe {
405        let default_backend_g = mutex_lock(&DEFAULT_BACKEND)?;
406        Ok(default_backend_g.clone())
407    }
408}
409
410/// Sets a default backend.
411pub fn set_default_backend(backend: Arc<dyn Backend>) -> Result<()>
412{
413    unsafe {
414        let mut default_backend_g = mutex_lock(&DEFAULT_BACKEND)?;
415        *default_backend_g = Some(backend);
416    }
417    Ok(())
418}
419
420/// Unsets a default backend.
421pub fn unset_default_backend() -> Result<()>
422{
423    unsafe {
424        let mut default_backend_g = mutex_lock(&DEFAULT_BACKEND)?;
425        *default_backend_g = None;
426    }
427    Ok(())
428}
429
430/// Sets a default backend if the default backend is uninitialized and returns the default backend.
431///
432/// This method takes a closure that returns the backend and then the backend is set as the default
433/// backend if the default backend is uninitialized. The closure is only called if the backend is
434/// to be set.
435pub fn set_default_backend_for_uninitialized<F>(f: F) -> Result<Arc<dyn Backend>>
436    where F: FnOnce() -> Result<Arc<dyn Backend>>
437{
438    unsafe {
439        let mut default_backend_g = mutex_lock(&DEFAULT_BACKEND)?;
440        match &*default_backend_g {
441            Some(default_backend) => Ok(default_backend.clone()),
442            None => {
443                let backend = f()?;
444                *default_backend_g = Some(backend.clone());
445                Ok(backend)
446            },
447        }
448    }
449}
450
451/// Initializes a default backend if the backend is uninitialized and returns the default backend.
452pub fn initialize_default_backend_for_uninitialized() -> Result<Arc<dyn Backend>>
453{
454    #[cfg(feature = "opencl")]
455    let res = set_default_backend_for_uninitialized(|| Ok(Arc::new(opencl::ClBackend::new()?)));
456    #[cfg(all(not(feature = "opencl"), feature = "cuda"))]
457    let res = set_default_backend_for_uninitialized(|| Ok(Arc::new(cuda::CudaBackend::new()?)));
458    #[cfg(all(not(feature = "opencl"), not(feature = "cuda")))]
459    let res: Result<Arc<dyn Backend>> = Err(Error::DefaultBackendInitialization);
460    res
461}
462
463/// Finalizes a default backend.
464pub fn finalize_default_backend() -> Result<()>
465{ unset_default_backend() }
466
/// Creates a matrix from the arguments.
///
/// Each bracketed list of expressions is one matrix row. Trailing commas are accepted after the
/// elements and after the rows. The matrix is created by `Matrix::new_with_elem_vecs`.
///
/// # Examples
///
/// ```
/// # use unmtx_gpu::*;
/// let a = matrix![
///     [1.0, 2.0, 3.0],
///     [4.0, 5.0, 6.0]
/// ];
/// assert_eq!(2, a.row_count());
/// assert_eq!(3, a.col_count());
/// assert_eq!(false, a.is_transposed());
/// assert_eq!(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], a.elems());
/// ```
#[macro_export]
macro_rules! matrix {
    ($([$($value:expr),* $(,)*]),* $(,)*) => {
        $crate::Matrix::new_with_elem_vecs(&[$(vec![$($value),*]),*])
    };
}
488
/// A matrix structure.
///
/// The matrix elements are kept in a backend array that is shared by reference counting, so the
/// clone of the matrix and the matrix returned by [`transpose`](Self::transpose) refer to the
/// same backend array as the original matrix.
#[derive(Clone, Debug)]
pub struct Matrix
{
    // The number of matrix rows.
    row_count: usize,
    // The number of matrix columns.
    col_count: usize,
    // The transpose flag that is toggled by the `transpose` method.
    is_transposed: bool,
    // The shared backend array with the matrix elements.
    array: Arc<BackendArray>,
}
498
499impl Matrix
500{
501    /// Creates a matrix with the number of rows and the number of columns.
502    pub fn new(row_count: usize, col_count: usize) -> Self
503    {
504        let frontend = Frontend::new().unwrap();
505        frontend.create_matrix_and_set_zeros(row_count, col_count).unwrap()
506    }
507
508    /// Creates a matrix with the number of rows, the number of columns, and the elements.
509    pub fn new_with_elems(row_count: usize, col_count: usize, elems: &[f32]) -> Self
510    {
511        let frontend = Frontend::new().unwrap();
512        frontend.create_matrix_and_set_elems(row_count, col_count, elems).unwrap()
513    }
514
515    /// Creates a matrix with the vector of rows.
516    pub fn new_with_elem_vecs(elem_vecs: &[Vec<f32>]) -> Self
517    {
518        let frontend = Frontend::new().unwrap();
519        let col_count = match elem_vecs.first() {
520            Some(elems) => elems.len(),
521            None => 0,
522        };
523        for row in elem_vecs {
524            assert_eq!(col_count, row.len());
525        }
526        let row_count = elem_vecs.len();
527        let elems: Vec<f32> = elem_vecs.iter().flatten().map(|e| *e).collect();
528        frontend.create_matrix_and_set_elems(row_count, col_count, elems.as_slice()).unwrap()
529    }
530
531    /// Returns the number of matrix rows.
532    pub fn row_count(&self) -> usize
533    { self.row_count }
534    
535    /// Returns the number of matrix columns.
536    pub fn col_count(&self) -> usize
537    { self.col_count }
538
539    /// Returns `true` if the matrix is transposed, otherwise `false`.
540    ///
541    /// This method indeed returns the transpose flag of matrix that is changed by
542    /// [`transpose`](Self::transpose).
543    pub fn is_transposed(&self) -> bool
544    { self.is_transposed }
545    
546    /// Returns the matrix elements.
547    pub fn elems(&self) -> Vec<f32>
548    {
549        let frontend = Frontend::new().unwrap();
550        frontend.elems_and_transpose_flag(self).unwrap().0
551    }
552    
553    /// Creates a matrix copy. 
554    ///
555    /// This method indeed copies the matrix array to a new matrix array.
556    pub fn copy(&self) -> Self
557    {
558        let frontend = Frontend::new().unwrap();
559        let res = unsafe { frontend.create_matrix(self.row_count, self.col_count) }.unwrap();
560        frontend.copy(self, &res).unwrap();
561        res
562    }
563    
564    /// Transposes the matrix
565    /// (<math xmlns="http://www.w3.org/1998/Math/MathML"><mrow><msup><mi mathvariant="bold">A</mi><mi mathvariant="normal">T</mi></msup></mrow></math>).
566    ///
567    /// This method doesn't indeed transpose the matrix but changes the transpose flag and
568    /// exchanges the number of matrix rows with the number of matrix columns.
569    ///
570    /// # Examples
571    ///
572    /// ```
573    /// # use unmtx_gpu::*;
574    /// let a = matrix![
575    ///     [1.0, 2.0, 3.0],
576    ///     [4.0, 5.0, 6.0]
577    /// ];
578    /// let b = a.transpose();
579    /// assert_eq!(3, b.row_count());
580    /// assert_eq!(2, b.col_count());
581    /// assert_eq!(true, b.is_transposed());
582    /// assert_eq!(a.elems(), b.elems());
583    /// let c = b.transpose();
584    /// assert_eq!(2, c.row_count());
585    /// assert_eq!(3, c.col_count());
586    /// assert_eq!(false, c.is_transposed());
587    /// assert_eq!(a.elems(), c.elems());
588    /// ```
589    pub fn transpose(&self) -> Self
590    {
591        Matrix {
592            row_count: self.col_count,
593            col_count: self.row_count,
594            is_transposed: !self.is_transposed,
595            array: self.array.clone(),
596        }
597    }
598    
599    /// Indeed transposes the matrix
600    /// (<math xmlns="http://www.w3.org/1998/Math/MathML"><mrow><msup><mi mathvariant="bold">A</mi><mi mathvariant="normal">T</mi></msup></mrow></math>).
601    ///
602    /// This method indeed transposes the matrix without changing the transpose flag.
603    ///
604    /// # Examples
605    ///
606    /// ```
607    /// # use unmtx_gpu::*;
608    /// let a = matrix![
609    ///     [1.0, 2.0, 3.0],
610    ///     [4.0, 5.0, 6.0]
611    /// ];
612    /// let b = a.really_transpose();
613    /// assert_eq!(3, b.row_count());
614    /// assert_eq!(2, b.col_count());
615    /// assert_eq!(false, b.is_transposed());
616    /// assert_eq!(vec![1.0, 4.0, 2.0, 5.0, 3.0, 6.0], b.elems());
617    /// ```
618    pub fn really_transpose(&self) -> Self
619    {
620        let frontend = Frontend::new().unwrap();
621        let res = unsafe { frontend.create_matrix(self.col_count, self.row_count) }.unwrap();
622        frontend.really_transpose(self, &res).unwrap();
623        res
624    }
625    
626    /// Multiplies the matrix elements by the `b` matrix elements
627    /// (<math xmlns="http://www.w3.org/1998/Math/MathML"><mrow><msub><mi>a</mi><mi mathvariant="italic">ij</mi></msub><mo>·</mo><msub><mi>b</mi><mi mathvariant="italic">ij</mi></msub></mrow></math>).    
628    ///
629    /// # Examples
630    ///
631    /// ```
632    /// # use unmtx_gpu::*;
633    /// let a = matrix![
634    ///     [1.0, 2.0],
635    ///     [3.0, 4.0]
636    /// ];
637    /// let b = matrix![
638    ///     [5.0, 6.0],
639    ///     [7.0, 8.0]
640    /// ];
641    /// let c = a.mul_elems(&b);
642    /// assert_eq!(vec![1.0 * 5.0, 2.0 * 6.0, 3.0 * 7.0, 4.0 * 8.0], c.elems());
643    /// ```
644    pub fn mul_elems(&self, b: &Self) -> Self
645    {
646        let frontend = Frontend::new().unwrap();
647        let res = unsafe { frontend.create_matrix(self.row_count, self.col_count) }.unwrap();
648        frontend.mul_elems(self, b, &res).unwrap();
649        res
650    }
651
652    /// Divides the matrix elements by the `b` matrix elements
653    /// (<math xmlns="http://www.w3.org/1998/Math/MathML"><mrow><mfrac><msub><mi>a</mi><mi mathvariant="italic">ij</mi></msub><msub><mi>b</mi><mi mathvariant="italic">ij</mi></msub></mfrac></mrow></math>).
654    ///
655    /// # Examples
656    ///
657    /// ```
658    /// # use unmtx_gpu::*;
659    /// let a = matrix![
660    ///     [1.0, 2.0],
661    ///     [3.0, 4.0]
662    /// ];
663    /// let b = matrix![
664    ///     [5.0, 6.0],
665    ///     [7.0, 8.0]
666    /// ];
667    /// let c = a.div_elems(&b);
668    /// let elems = c.elems();
669    /// assert!((1.0 / 5.0 - elems[0]).abs() < 0.001);
670    /// assert!((2.0 / 6.0 - elems[1]).abs() < 0.001);
671    /// assert!((3.0 / 7.0 - elems[2]).abs() < 0.001);
672    /// assert!((4.0 / 8.0 - elems[3]).abs() < 0.001);
673    /// ```
674    pub fn div_elems(&self, b: &Self) -> Self
675    {
676        let frontend = Frontend::new().unwrap();
677        let res = unsafe { frontend.create_matrix(self.row_count, self.col_count) }.unwrap();
678        frontend.div_elems(self, b, &res).unwrap();
679        res
680    }
681
682    /// Subtracts the matrix from the scalar
683    /// (<math xmlns="http://www.w3.org/1998/Math/MathML"><mrow><mi>b</mi><mo>-</mo><mi mathvariant="bold">A</mi></mrow></math>).
684    ///
685    /// # Examples
686    ///
687    /// ```
688    /// # use unmtx_gpu::*;
689    /// let a = matrix![
690    ///     [1.0, 2.0],
691    ///     [3.0, 4.0]
692    /// ];
693    /// let b = a.rsub(10.5);
694    /// assert_eq!(vec![10.5 - 1.0, 10.5 - 2.0, 10.5 - 3.0, 10.5 - 4.0], b.elems());
695    /// ```
696    pub fn rsub(&self, b: f32) -> Self
697    {
698        let frontend = Frontend::new().unwrap();
699        let res = unsafe { frontend.create_matrix(self.row_count, self.col_count) }.unwrap();
700        frontend.rsub_for_scalar(self, b, &res).unwrap();
701        res
702    }
703
704    /// Divides the scalar by the matrix elements
705    /// (<math xmlns="http://www.w3.org/1998/Math/MathML"><mrow><mfrac><mi>b</mi><msub><mi>a</mi><mi mathvariant="italic">ij</mi></msub></mfrac></mrow></math>).
706    ///
707    /// # Examples
708    ///
709    /// ```
710    /// # use unmtx_gpu::*;
711    /// let a = matrix![
712    ///     [1.0, 2.0],
713    ///     [3.0, 4.0]
714    /// ];
715    /// let b = a.rdiv(10.5);
716    /// let elems = b.elems();
717    /// assert!((10.5 / 1.0 - elems[0]).abs() < 0.001);
718    /// assert!((10.5 / 2.0 - elems[1]).abs() < 0.001);
719    /// assert!((10.5 / 3.0 - elems[2]).abs() < 0.001);
720    /// assert!((10.5 / 4.0 - elems[3]).abs() < 0.001);
721    /// ```
722    pub fn rdiv(&self, b: f32) -> Self
723    {
724        let frontend = Frontend::new().unwrap();
725        let res = unsafe { frontend.create_matrix(self.row_count, self.col_count) }.unwrap();
726        frontend.rdiv_for_scalar(self, b, &res).unwrap();
727        res
728    }
729
730    /// Calculates sigmoid function for the matrix
731    /// (<math xmlns="http://www.w3.org/1998/Math/MathML"><mrow><mi>sigmoid</mi><mo fence="true">(</mo><mi mathvariant="bold">A</mi><mo fence="true">)</mo></mrow></math>).
732    ///
733    /// # Examples
734    ///
735    /// ```
736    /// # use unmtx_gpu::*;
737    /// let a = matrix![
738    ///     [1.0, 2.0],
739    ///     [3.0, 4.0]
740    /// ];
741    /// let b = a.sigmoid();
742    /// let elems = b.elems();
743    /// assert!((1.0 / (1.0 + (-1.0f32).exp()) - elems[0]).abs() < 0.001);
744    /// assert!((1.0 / (1.0 + (-2.0f32).exp()) - elems[1]).abs() < 0.001);
745    /// assert!((1.0 / (1.0 + (-3.0f32).exp()) - elems[2]).abs() < 0.001);
746    /// assert!((1.0 / (1.0 + (-4.0f32).exp()) - elems[3]).abs() < 0.001);
747    /// ```
748    pub fn sigmoid(&self) -> Self
749    {
750        let frontend = Frontend::new().unwrap();
751        let res = unsafe { frontend.create_matrix(self.row_count, self.col_count) }.unwrap();
752        frontend.sigmoid(self, &res).unwrap();
753        res
754    }
755
    /// Calculates hyperbolic tangent function for the matrix
757    /// (<math xmlns="http://www.w3.org/1998/Math/MathML"><mrow><mi>tanh</mi><mo fence="true">(</mo><mi mathvariant="bold">A</mi><mo fence="true">)</mo></mrow></math>).
758    ///
759    /// # Examples
760    ///
761    /// ```
762    /// # use unmtx_gpu::*;
763    /// let a = matrix![
764    ///     [1.0, 2.0],
765    ///     [3.0, 4.0]
766    /// ];
767    /// let b = a.tanh();
768    /// let elems = b.elems();
769    /// assert!((1.0f32.tanh() - elems[0]).abs() < 0.001);
770    /// assert!((2.0f32.tanh() - elems[1]).abs() < 0.001);
771    /// assert!((3.0f32.tanh() - elems[2]).abs() < 0.001);
772    /// assert!((4.0f32.tanh() - elems[3]).abs() < 0.001);
773    /// ```
774    pub fn tanh(&self) -> Self
775    {
776        let frontend = Frontend::new().unwrap();
777        let res = unsafe { frontend.create_matrix(self.row_count, self.col_count) }.unwrap();
778        frontend.tanh(self, &res).unwrap();
779        res
780    }
781
782    /// Calculates softmax function for the matrix
783    /// (<math xmlns="http://www.w3.org/1998/Math/MathML"><mrow><mi>softmax</mi><mo fence="true">(</mo><mi mathvariant="bold">A</mi><mo fence="true">)</mo></mrow></math>).
784    ///
785    /// # Examples
786    ///
787    /// ```
788    /// # use unmtx_gpu::*;
789    /// let a = matrix![
790    ///     [1.0, 2.0],
791    ///     [3.0, 4.0]
792    /// ];
793    /// let b = a.softmax();
794    /// let elems = b.elems();
795    /// let sum1 = 1.0f32.exp() + 3.0f32.exp();
796    /// let sum2 = 2.0f32.exp() + 4.0f32.exp();
797    /// assert!((1.0f32.exp() / sum1 - elems[0]).abs() < 0.001);
798    /// assert!((2.0f32.exp() / sum2 - elems[1]).abs() < 0.001);
799    /// assert!((3.0f32.exp() / sum1 - elems[2]).abs() < 0.001);
800    /// assert!((4.0f32.exp() / sum2 - elems[3]).abs() < 0.001);
801    /// ```
802    pub fn softmax(&self) -> Self
803    {
804        let frontend = Frontend::new().unwrap();
805        let res = unsafe { frontend.create_matrix(self.row_count, self.col_count) }.unwrap();
806        frontend.softmax(self, &res).unwrap();
807        res
808    }
809    
    /// Repeats the vector as columns or as rows.
811    ///
812    /// # Examples
813    ///
814    /// ```
815    /// # use unmtx_gpu::*;
816    /// let a = matrix![
817    ///     [1.0],
818    ///     [2.0]
819    /// ];
820    /// let b = a.repeat(3);
821    /// assert_eq!(vec![1.0, 1.0, 1.0, 2.0, 2.0, 2.0], b.elems());
822    /// let c = matrix![[1.0, 2.0, 3.0]];
823    /// let d = c.repeat(2);
824    /// assert_eq!(vec![1.0, 2.0, 3.0, 1.0, 2.0, 3.0], d.elems());
825    /// ```
826    pub fn repeat(&self, n: usize) -> Self
827    {
828        assert!(self.col_count == 1 || self.row_count == 1); 
829        let frontend = Frontend::new().unwrap();
830        let res = if self.col_count == 1 {
831            unsafe { frontend.create_matrix(self.row_count, n) }.unwrap()
832        } else {
833            unsafe { frontend.create_matrix(n, self.col_count) }.unwrap()
834        };
835        frontend.repeat(self, &res).unwrap();
836        res
837    }
838}
839
840impl Add for Matrix
841{
842    type Output = Self;
843    
844    fn add(self, rhs: Self) -> Self::Output
845    {
846        let frontend = Frontend::new().unwrap();
847        let res = unsafe { frontend.create_matrix(self.row_count, self.col_count) }.unwrap();
848        frontend.add(&self, &rhs, &res).unwrap();
849        res
850    }
851}
852
853impl Add<&Matrix> for Matrix
854{
855    type Output = Self;
856    
857    fn add(self, rhs: &Matrix) -> Self::Output
858    {
859        let frontend = Frontend::new().unwrap();
860        let res = unsafe { frontend.create_matrix(self.row_count, self.col_count) }.unwrap();
861        frontend.add(&self, rhs, &res).unwrap();
862        res
863    }
864}
865
866impl Add<f32> for Matrix
867{
868    type Output = Self;
869    
870    fn add(self, rhs: f32) -> Self::Output
871    {
872        let frontend = Frontend::new().unwrap();
873        let res = unsafe { frontend.create_matrix(self.row_count, self.col_count) }.unwrap();
874        frontend.add_for_scalar(&self, rhs, &res).unwrap();
875        res
876    }
877}
878
879impl Add<&f32> for Matrix
880{
881    type Output = Self;
882    
883    fn add(self, rhs: &f32) -> Self::Output
884    {
885        let frontend = Frontend::new().unwrap();
886        let res = unsafe { frontend.create_matrix(self.row_count, self.col_count) }.unwrap();
887        frontend.add_for_scalar(&self, *rhs, &res).unwrap();
888        res
889    }
890}
891
892impl Add<Matrix> for &Matrix
893{
894    type Output = Matrix;
895    
896    fn add(self, rhs: Matrix) -> Self::Output
897    {
898        let frontend = Frontend::new().unwrap();
899        let res = unsafe { frontend.create_matrix(self.row_count, self.col_count) }.unwrap();
900        frontend.add(self, &rhs, &res).unwrap();
901        res
902    }
903}
904
905impl Add<&Matrix> for &Matrix
906{
907    type Output = Matrix;
908    
909    fn add(self, rhs: &Matrix) -> Self::Output
910    {
911        let frontend = Frontend::new().unwrap();
912        let res = unsafe { frontend.create_matrix(self.row_count, self.col_count) }.unwrap();
913        frontend.add(self, rhs, &res).unwrap();
914        res
915    }
916}
917
918impl Add<f32> for &Matrix
919{
920    type Output = Matrix;
921    
922    fn add(self, rhs: f32) -> Self::Output
923    {
924        let frontend = Frontend::new().unwrap();
925        let res = unsafe { frontend.create_matrix(self.row_count, self.col_count) }.unwrap();
926        frontend.add_for_scalar(self, rhs, &res).unwrap();
927        res
928    }
929}
930
931impl Add<&f32> for &Matrix
932{
933    type Output = Matrix;
934    
935    fn add(self, rhs: &f32) -> Self::Output
936    {
937        let frontend = Frontend::new().unwrap();
938        let res = unsafe { frontend.create_matrix(self.row_count, self.col_count) }.unwrap();
939        frontend.add_for_scalar(self, *rhs, &res).unwrap();
940        res
941    }
942}
943
944impl AddAssign for Matrix
945{
946    fn add_assign(&mut self, rhs: Self)
947    {
948        let frontend = Frontend::new().unwrap();
949        frontend.add(self, &rhs, &self).unwrap();
950    }
951}
952
953impl AddAssign<&Matrix> for Matrix
954{
955    fn add_assign(&mut self, rhs: &Self)
956    {
957        let frontend = Frontend::new().unwrap();
958        frontend.add(&self, rhs, &self).unwrap();
959    }
960}
961
962impl AddAssign<f32> for Matrix
963{
964    fn add_assign(&mut self, rhs: f32)
965    {
966        let frontend = Frontend::new().unwrap();
967        frontend.add_for_scalar(&self, rhs, &self).unwrap();
968    }
969}
970
971impl AddAssign<&f32> for Matrix
972{
973    fn add_assign(&mut self, rhs: &f32)
974    {
975        let frontend = Frontend::new().unwrap();
976        frontend.add_for_scalar(&self, *rhs, &self).unwrap();
977    }
978}
979
980impl Sub for Matrix
981{
982    type Output = Self;
983    
984    fn sub(self, rhs: Self) -> Self::Output
985    {
986        let frontend = Frontend::new().unwrap();
987        let res = unsafe { frontend.create_matrix(self.row_count, self.col_count) }.unwrap();
988        frontend.sub(&self, &rhs, &res).unwrap();
989        res
990    }
991}
992
993impl Sub<&Matrix> for Matrix
994{
995    type Output = Self;
996    
997    fn sub(self, rhs: &Matrix) -> Self::Output
998    {
999        let frontend = Frontend::new().unwrap();
1000        let res = unsafe { frontend.create_matrix(self.row_count, self.col_count) }.unwrap();
1001        frontend.sub(&self, rhs, &res).unwrap();
1002        res
1003    }
1004}
1005
1006impl Sub<f32> for Matrix
1007{
1008    type Output = Self;
1009    
1010    fn sub(self, rhs: f32) -> Self::Output
1011    {
1012        let frontend = Frontend::new().unwrap();
1013        let res = unsafe { frontend.create_matrix(self.row_count, self.col_count) }.unwrap();
1014        frontend.sub_for_scalar(&self, rhs, &res).unwrap();
1015        res
1016    }
1017}
1018
1019impl Sub<&f32> for Matrix
1020{
1021    type Output = Self;
1022    
1023    fn sub(self, rhs: &f32) -> Self::Output
1024    {
1025        let frontend = Frontend::new().unwrap();
1026        let res = unsafe { frontend.create_matrix(self.row_count, self.col_count) }.unwrap();
1027        frontend.sub_for_scalar(&self, *rhs, &res).unwrap();
1028        res
1029    }
1030}
1031
1032impl Sub<Matrix> for &Matrix
1033{
1034    type Output = Matrix;
1035    
1036    fn sub(self, rhs: Matrix) -> Self::Output
1037    {
1038        let frontend = Frontend::new().unwrap();
1039        let res = unsafe { frontend.create_matrix(self.row_count, self.col_count) }.unwrap();
1040        frontend.sub(self, &rhs, &res).unwrap();
1041        res
1042    }
1043}
1044
1045impl Sub<&Matrix> for &Matrix
1046{
1047    type Output = Matrix;
1048    
1049    fn sub(self, rhs: &Matrix) -> Self::Output
1050    {
1051        let frontend = Frontend::new().unwrap();
1052        let res = unsafe { frontend.create_matrix(self.row_count, self.col_count) }.unwrap();
1053        frontend.sub(self, rhs, &res).unwrap();
1054        res
1055    }
1056}
1057
1058impl Sub<f32> for &Matrix
1059{
1060    type Output = Matrix;
1061    
1062    fn sub(self, rhs: f32) -> Self::Output
1063    {
1064        let frontend = Frontend::new().unwrap();
1065        let res = unsafe { frontend.create_matrix(self.row_count, self.col_count) }.unwrap();
1066        frontend.sub_for_scalar(self, rhs, &res).unwrap();
1067        res
1068    }
1069}
1070
1071impl Sub<&f32> for &Matrix
1072{
1073    type Output = Matrix;
1074    
1075    fn sub(self, rhs: &f32) -> Self::Output
1076    {
1077        let frontend = Frontend::new().unwrap();
1078        let res = unsafe { frontend.create_matrix(self.row_count, self.col_count) }.unwrap();
1079        frontend.sub_for_scalar(self, *rhs, &res).unwrap();
1080        res
1081    }
1082}
1083
1084impl SubAssign for Matrix
1085{
1086    fn sub_assign(&mut self, rhs: Self)
1087    {
1088        let frontend = Frontend::new().unwrap();
1089        frontend.sub(&self, &rhs, &self).unwrap();
1090    }
1091}
1092
1093impl SubAssign<&Matrix> for Matrix
1094{
1095    fn sub_assign(&mut self, rhs: &Self)
1096    {
1097        let frontend = Frontend::new().unwrap();
1098        frontend.sub(&self, rhs, &self).unwrap();
1099    }
1100}
1101
1102impl SubAssign<f32> for Matrix
1103{
1104    fn sub_assign(&mut self, rhs: f32)
1105    {
1106        let frontend = Frontend::new().unwrap();
1107        frontend.sub_for_scalar(&self, rhs, &self).unwrap();
1108    }
1109}
1110
1111impl SubAssign<&f32> for Matrix
1112{
1113    fn sub_assign(&mut self, rhs: &f32)
1114    {
1115        let frontend = Frontend::new().unwrap();
1116        frontend.sub_for_scalar(&self, *rhs, &self).unwrap();
1117    }
1118}
1119
1120impl Mul for Matrix
1121{
1122    type Output = Self;
1123    
1124    fn mul(self, rhs: Self) -> Self::Output
1125    {
1126        let frontend = Frontend::new().unwrap();
1127        let res = if frontend.backend().has_cublas() {
1128            frontend.create_matrix_and_set_zeros(self.row_count, rhs.col_count).unwrap()
1129        } else {
1130            unsafe { frontend.create_matrix(self.row_count, rhs.col_count) }.unwrap()
1131        };
1132        frontend.mul(&self, &rhs, &res).unwrap();
1133        res
1134    }
1135}
1136
1137impl Mul<&Matrix> for Matrix
1138{
1139    type Output = Self;
1140    
1141    fn mul(self, rhs: &Matrix) -> Self::Output
1142    {
1143        let frontend = Frontend::new().unwrap();
1144        let res = if frontend.backend().has_cublas() {
1145            frontend.create_matrix_and_set_zeros(self.row_count, rhs.col_count).unwrap()
1146        } else {
1147            unsafe { frontend.create_matrix(self.row_count, rhs.col_count) }.unwrap()
1148        };
1149        frontend.mul(&self, rhs, &res).unwrap();
1150        res
1151    }
1152}
1153
1154impl Mul<f32> for Matrix
1155{
1156    type Output = Self;
1157    
1158    fn mul(self, rhs: f32) -> Self::Output
1159    {
1160        let frontend = Frontend::new().unwrap();
1161        let res = unsafe { frontend.create_matrix(self.row_count, self.col_count) }.unwrap();
1162        frontend.mul_for_scalar(&self, rhs, &res).unwrap();
1163        res
1164    }
1165}
1166
1167impl Mul<&f32> for Matrix
1168{
1169    type Output = Self;
1170    
1171    fn mul(self, rhs: &f32) -> Self::Output
1172    {
1173        let frontend = Frontend::new().unwrap();
1174        let res = unsafe { frontend.create_matrix(self.row_count, self.col_count) }.unwrap();
1175        frontend.mul_for_scalar(&self, *rhs, &res).unwrap();
1176        res
1177    }
1178}
1179
1180impl Mul<Matrix> for &Matrix
1181{
1182    type Output = Matrix;
1183    
1184    fn mul(self, rhs: Matrix) -> Self::Output
1185    {
1186        let frontend = Frontend::new().unwrap();
1187        let res = if frontend.backend().has_cublas() {
1188            frontend.create_matrix_and_set_zeros(self.row_count, rhs.col_count).unwrap()
1189        } else {
1190            unsafe { frontend.create_matrix(self.row_count, rhs.col_count) }.unwrap()
1191        };
1192        frontend.mul(self, &rhs, &res).unwrap();
1193        res
1194    }
1195}
1196
1197impl Mul<&Matrix> for &Matrix
1198{
1199    type Output = Matrix;
1200    
1201    fn mul(self, rhs: &Matrix) -> Self::Output
1202    {
1203        let frontend = Frontend::new().unwrap();
1204        let res = if frontend.backend().has_cublas() {
1205            frontend.create_matrix_and_set_zeros(self.row_count, rhs.col_count).unwrap()
1206        } else {
1207            unsafe { frontend.create_matrix(self.row_count, rhs.col_count) }.unwrap()
1208        };
1209        frontend.mul(self, rhs, &res).unwrap();
1210        res
1211    }
1212}
1213
1214impl Mul<f32> for &Matrix
1215{
1216    type Output = Matrix;
1217    
1218    fn mul(self, rhs: f32) -> Self::Output
1219    {
1220        let frontend = Frontend::new().unwrap();
1221        let res = unsafe { frontend.create_matrix(self.row_count, self.col_count) }.unwrap();
1222        frontend.mul_for_scalar(self, rhs, &res).unwrap();
1223        res
1224    }
1225}
1226
1227impl Mul<&f32> for &Matrix
1228{
1229    type Output = Matrix;
1230    
1231    fn mul(self, rhs: &f32) -> Self::Output
1232    {
1233        let frontend = Frontend::new().unwrap();
1234        let res = unsafe { frontend.create_matrix(self.row_count, self.col_count) }.unwrap();
1235        frontend.mul_for_scalar(self, *rhs, &res).unwrap();
1236        res
1237    }
1238}
1239
1240impl MulAssign for Matrix
1241{
1242    fn mul_assign(&mut self, rhs: Self)
1243    {
1244        let frontend = Frontend::new().unwrap();
1245        let res = if frontend.backend().has_cublas() {
1246            frontend.create_matrix_and_set_zeros(self.row_count, rhs.col_count).unwrap()
1247        } else {
1248            unsafe { frontend.create_matrix(self.row_count, rhs.col_count) }.unwrap()
1249        };
1250        frontend.mul(&self, &rhs, &res).unwrap();
1251        *self = res;
1252    }
1253}
1254
1255impl MulAssign<&Matrix> for Matrix
1256{
1257    fn mul_assign(&mut self, rhs: &Self)
1258    {
1259        let frontend = Frontend::new().unwrap();
1260        let res = if frontend.backend().has_cublas() {
1261            frontend.create_matrix_and_set_zeros(self.row_count, rhs.col_count).unwrap()
1262        } else {
1263            unsafe { frontend.create_matrix(self.row_count, rhs.col_count) }.unwrap()
1264        };
1265        frontend.mul(&self, rhs, &res).unwrap();
1266        *self = res;
1267    }
1268}
1269
1270impl MulAssign<f32> for Matrix
1271{
1272    fn mul_assign(&mut self, rhs: f32)
1273    {
1274        let frontend = Frontend::new().unwrap();
1275        frontend.mul_for_scalar(&self, rhs, &self).unwrap();
1276    }
1277}
1278
1279impl MulAssign<&f32> for Matrix
1280{
1281    fn mul_assign(&mut self, rhs: &f32)
1282    {
1283        let frontend = Frontend::new().unwrap();
1284        frontend.mul_for_scalar(&self, *rhs, &self).unwrap();
1285    }
1286}
1287
1288impl Div<f32> for Matrix
1289{
1290    type Output = Self;
1291    
1292    fn div(self, rhs: f32) -> Self::Output
1293    {
1294        let frontend = Frontend::new().unwrap();
1295        let res = unsafe { frontend.create_matrix(self.row_count, self.col_count) }.unwrap();
1296        frontend.div_for_scalar(&self, rhs, &res).unwrap();
1297        res
1298    }
1299}
1300
1301impl Div<&f32> for Matrix
1302{
1303    type Output = Self;
1304    
1305    fn div(self, rhs: &f32) -> Self::Output
1306    {
1307        let frontend = Frontend::new().unwrap();
1308        let res = unsafe { frontend.create_matrix(self.row_count, self.col_count) }.unwrap();
1309        frontend.div_for_scalar(&self, *rhs, &res).unwrap();
1310        res
1311    }
1312}
1313
1314impl Div<f32> for &Matrix
1315{
1316    type Output = Matrix;
1317    
1318    fn div(self, rhs: f32) -> Self::Output
1319    {
1320        let frontend = Frontend::new().unwrap();
1321        let res = unsafe { frontend.create_matrix(self.row_count, self.col_count) }.unwrap();
1322        frontend.div_for_scalar(self, rhs, &res).unwrap();
1323        res
1324    }
1325}
1326
1327impl Div<&f32> for &Matrix
1328{
1329    type Output = Matrix;
1330    
1331    fn div(self, rhs: &f32) -> Self::Output
1332    {
1333        let frontend = Frontend::new().unwrap();
1334        let res = unsafe { frontend.create_matrix(self.row_count, self.col_count) }.unwrap();
1335        frontend.div_for_scalar(self, *rhs, &res).unwrap();
1336        res
1337    }
1338}
1339
1340impl DivAssign<f32> for Matrix
1341{
1342    fn div_assign(&mut self, rhs: f32)
1343    {
1344        initialize_default_backend_for_uninitialized().unwrap();
1345        let frontend = Frontend::new().unwrap();
1346        frontend.div_for_scalar(&self, rhs, &self).unwrap();
1347    }
1348}
1349
1350impl DivAssign<&f32> for Matrix
1351{
1352    fn div_assign(&mut self, rhs: &f32)
1353    {
1354        initialize_default_backend_for_uninitialized().unwrap();
1355        let frontend = Frontend::new().unwrap();
1356        frontend.div_for_scalar(&self, *rhs, &self).unwrap();
1357    }
1358}
1359
1360/// A frontend structure.
1361///
1362/// The frontend contains methods which operate on matrices or calculate functions for the
1363/// matrices. Backend methods are called by the frontend to operate the matrices. The frontend is
1364/// high-level layer that can be directly used by programmer or a [`Matrix`] structure.
1365pub struct Frontend
1366{
1367    backend: Arc<dyn Backend>,
1368}
1369
1370impl Frontend
1371{
1372    /// Creates a frontend with a default backend.
1373    ///
1374    /// This method also automatically initializes a default backend if the default backend is
1375    /// uninitialized.
1376    pub fn new() -> Result<Frontend>
1377    { Ok(Frontend { backend: initialize_default_backend_for_uninitialized()?, }) }
1378
    /// Creates a frontend with the backend.
1380    pub fn new_with_backend(backend: Arc<dyn Backend>) -> Frontend
1381    { Frontend { backend, } }
1382    
1383    /// Returns the backend.
1384    pub fn backend(&self) -> Arc<dyn Backend>
1385    { self.backend.clone() }
1386    
1387    /// Creates a matrix with unset elements.
1388    pub unsafe fn create_matrix(&self, row_count: usize, col_count: usize) -> Result<Matrix>
1389    {
1390        Ok(Matrix {
1391                row_count,
1392                col_count,
1393                is_transposed: false,
1394                array: Arc::new(self.backend.alloc(row_count * col_count)?),
1395        })
1396    }
1397
    /// Creates a matrix and sets the matrix elements to zeros.
1399    pub fn create_matrix_and_set_zeros(&self, row_count: usize, col_count: usize) -> Result<Matrix>
1400    {
1401        Ok(Matrix {
1402                row_count,
1403                col_count,
1404                is_transposed: false,
1405                array: Arc::new(self.backend.alloc_and_store_zeros(row_count * col_count)?),
1406        })
1407    }
1408
1409    /// Creates a matrix and sets the matrix elements.
1410    pub fn create_matrix_and_set_elems(&self, row_count: usize, col_count: usize, elems: &[f32]) -> Result<Matrix>
1411    {
1412        if row_count * col_count != elems.len() {
1413            return Err(Error::MatrixElemCount(row_count * col_count, elems.len())); 
1414        }
1415        Ok(Matrix {
1416                row_count,
1417                col_count,
1418                is_transposed: false,
1419                array: Arc::new(self.backend.alloc_and_store(elems)?),
1420        })
1421    }
1422
1423    /// Sets the matrix elements.
1424    pub fn set_elems(&self, a: &Matrix, elems: &[f32]) -> Result<()>
1425    {
1426        if a.row_count() * a.col_count() != elems.len() {
1427            return Err(Error::MatrixElemCount(a.row_count() * a.col_count(), elems.len())); 
1428        }
1429        self.backend.store(&*a.array, elems)
1430    }    
1431
1432    /// Copies the `a` matrix to the `b` matrix.
1433    ///
1434    /// This method indeed copies the `a` matrix array to the `b` matrix array.
1435    pub fn copy(&self, a: &Matrix, b: &Matrix) -> Result<()>
1436    {
1437        if a.row_count != b.row_count || a.col_count != b.col_count {
1438            return Err(Error::OpSize(a.row_count, a.col_count, b.row_count, b.col_count)); 
1439        }
1440        self.backend.copy(&*a.array, &*b.array)        
1441    }    
1442    
1443    /// Copies the matrix elements to the mutable slice and the transpose flag to the object that
1444    /// is referred by the reference.
1445    pub fn get_elems_and_transpose_flag(&self, a: &Matrix, elems: &mut [f32], is_transposed: &mut bool) -> Result<()>
1446    {
1447        if a.row_count * a.col_count != elems.len() {
1448            return Err(Error::MatrixElemCount(a.row_count * a.col_count, elems.len())); 
1449        }
1450        if !a.is_transposed {
1451            self.backend.load(&*a.array, elems)?;
1452        } else {
1453            self.backend.load(&*a.array, elems)?;
1454        }
1455        *is_transposed = true;
1456        Ok(())
1457    }
1458    
1459    /// Returns the elements and the transpose flag of matrix.
1460    pub fn elems_and_transpose_flag(&self, a: &Matrix) -> Result<(Vec<f32>, bool)>
1461    {
1462        let mut elems: Vec<f32> = vec![0.0; a.row_count * a.col_count];
1463        let mut is_transposed = false;
1464        self.get_elems_and_transpose_flag(a, elems.as_mut_slice(), &mut is_transposed)?;
1465        Ok((elems, is_transposed))
1466    }
1467    
1468    /// Adds the `a` matrix onto the `b` matrix and then the result is in the `c` matrix
1469    /// (<math xmlns="http://www.w3.org/1998/Math/MathML"><mrow><mi mathvariant="bold">C</mi><mo>=</mo><mi mathvariant="bold">A</mi><mo>+</mo><mi mathvariant="bold">B</mi></mrow></math>).
1470    ///
1471    /// # Examples
1472    ///
1473    /// ```
1474    /// # use unmtx_gpu::*;
1475    /// let a = matrix![
1476    ///     [1.0, 2.0],
1477    ///     [3.0, 4.0]
1478    /// ];
1479    /// let b = matrix![
1480    ///     [5.0, 6.0],
1481    ///     [7.0, 8.0]
1482    /// ];
1483    /// let c = Matrix::new(2, 2);
1484    /// let frontend = Frontend::new().unwrap();
1485    /// frontend.add(&a, &b, &c).unwrap();
1486    /// assert_eq!(vec![1.0 + 5.0, 2.0 + 6.0, 3.0 + 7.0, 4.0 + 8.0], c.elems());
1487    /// ```
1488    pub fn add(&self, a: &Matrix, b: &Matrix, c: &Matrix) -> Result<()>
1489    {
1490        if a.row_count != b.row_count || a.col_count != b.col_count {
1491            return Err(Error::OpSize(a.row_count, a.col_count, b.row_count, b.col_count)); 
1492        }
1493        if a.row_count != c.row_count || a.col_count != c.col_count {
1494            return Err(Error::OpSize(a.row_count, a.col_count, c.row_count, c.col_count)); 
1495        }
1496        if c.is_transposed {
1497            return Err(Error::ResTransposition);
1498        }
1499        match (a.is_transposed, b.is_transposed) {
1500            (false, false) => self.backend.add_a_b(&*a.array, &*b.array, &*c.array, a.row_count, a.col_count),
1501            (true, false) => self.backend.add_at_b(&*a.array, &*b.array, &*c.array, a.row_count, a.col_count),
1502            (false, true) => self.backend.add_a_bt(&*a.array, &*b.array, &*c.array, a.row_count, a.col_count),
1503            (true, true) => self.backend.add_at_bt(&*a.array, &*b.array, &*c.array, a.row_count, a.col_count),
1504        }
1505    }
1506
1507    /// Subtracts the `b` matrix from the `a` matrix and then the result is in the `c` matrix
1508    /// (<math xmlns="http://www.w3.org/1998/Math/MathML"><mrow><mi mathvariant="bold">C</mi><mo>=</mo><mi mathvariant="bold">A</mi><mo>-</mo><mi mathvariant="bold">B</mi></mrow></math>).
1509    ///
1510    /// # Examples
1511    ///
1512    /// ```
1513    /// # use unmtx_gpu::*;
1514    /// let a = matrix![
1515    ///     [1.0, 2.0],
1516    ///     [3.0, 4.0]
1517    /// ];
1518    /// let b = matrix![
1519    ///     [5.0, 6.0],
1520    ///     [7.0, 8.0]
1521    /// ];
1522    /// let c = Matrix::new(2, 2);
1523    /// let frontend = Frontend::new().unwrap();
1524    /// frontend.sub(&a, &b, &c).unwrap();
1525    /// assert_eq!(vec![1.0 - 5.0, 2.0 - 6.0, 3.0 - 7.0, 4.0 - 8.0], c.elems());
1526    /// ```
1527    pub fn sub(&self, a: &Matrix, b: &Matrix, c: &Matrix) -> Result<()>
1528    {
1529        if a.row_count != b.row_count || a.col_count != b.col_count {
1530            return Err(Error::OpSize(a.row_count, a.col_count, b.row_count, b.col_count)); 
1531        }
1532        if a.row_count != c.row_count || a.col_count != c.col_count {
1533            return Err(Error::OpSize(a.row_count, a.col_count, c.row_count, c.col_count)); 
1534        }
1535        if c.is_transposed {
1536            return Err(Error::ResTransposition);
1537        }
1538        match (a.is_transposed, b.is_transposed) {
1539            (false, false) => self.backend.sub_a_b(&*a.array, &*b.array, &*c.array, a.row_count, a.col_count),
1540            (true, false) => self.backend.sub_at_b(&*a.array, &*b.array, &*c.array, a.row_count, a.col_count),
1541            (false, true) => self.backend.sub_a_bt(&*a.array, &*b.array, &*c.array, a.row_count, a.col_count),
1542            (true, true) => self.backend.sub_at_bt(&*a.array, &*b.array, &*c.array, a.row_count, a.col_count),
1543        }
1544    }
1545
1546    /// Multiplies the `a` matrix by the `b` matrix and then the result is in the `c` matrix
1547    /// (<math xmlns="http://www.w3.org/1998/Math/MathML"><mrow><mi mathvariant="bold">C</mi><mo>=</mo><mi mathvariant="bold">A</mi><mo>·</mo><mi mathvariant="bold">B</mi></mrow></math>).
1548    ///
1549    /// # Examples
1550    ///
1551    /// ```
1552    /// # use unmtx_gpu::*;
1553    /// let a = matrix![
1554    ///     [1.0, 2.0, 3.0],
1555    ///     [4.0, 5.0, 6.0]
1556    /// ];
1557    /// let b = matrix![
1558    ///     [7.0,  8.0],
1559    ///     [9.0,  10.0],
1560    ///     [11.0, 12.0]
1561    /// ];
1562    /// let c = Matrix::new(2, 2);
1563    /// let frontend = Frontend::new().unwrap();
1564    /// frontend.mul(&a, &b, &c).unwrap();
1565    /// let c11: f32 = 1.0 * 7.0 + 2.0 * 9.0 + 3.0 * 11.0;
1566    /// let c12: f32 = 1.0 * 8.0 + 2.0 * 10.0 + 3.0 * 12.0;
1567    /// let c21: f32 = 4.0 * 7.0 + 5.0 * 9.0 + 6.0 * 11.0;
1568    /// let c22: f32 = 4.0 * 8.0 + 5.0 * 10.0 + 6.0 * 12.0;
1569    /// assert_eq!(vec![c11, c12, c21, c22], c.elems());
1570    /// ```
1571    pub fn mul(&self, a: &Matrix, b: &Matrix, c: &Matrix) -> Result<()>
1572    {
1573        if a.row_count != c.row_count {
1574            return Err(Error::MulSize(a.row_count, a.col_count, b.row_count, b.col_count, c.row_count, c.col_count)); 
1575        }
1576        if b.col_count != c.col_count {
1577            return Err(Error::MulSize(a.row_count, a.col_count, b.row_count, b.col_count, c.row_count, c.col_count)); 
1578        }
1579        if a.col_count != b.row_count {
1580            return Err(Error::MulSize(a.row_count, a.col_count, b.row_count, b.col_count, c.row_count, c.col_count)); 
1581        }
1582        if c.is_transposed {
1583            return Err(Error::ResTransposition);
1584        }
1585        match (a.is_transposed, b.is_transposed) {
1586            (false, false) => self.backend.mul_a_b(&*a.array, &*b.array, &*c.array, a.row_count, b.col_count, a.col_count),
1587            (true, false) => self.backend.mul_at_b(&*a.array, &*b.array, &*c.array, a.row_count, b.col_count, a.col_count),
1588            (false, true) => self.backend.mul_a_bt(&*a.array, &*b.array, &*c.array, a.row_count, b.col_count, a.col_count),
1589            (true, true) => self.backend.mul_at_bt(&*a.array, &*b.array, &*c.array, a.row_count, b.col_count, a.col_count),
1590        }
1591    }
1592
1593    /// Multiplies the `a` matrix elements by the `b` matrix elements and then the result is in the
1594    /// `c` matrix
1595    /// (<math xmlns="http://www.w3.org/1998/Math/MathML"><mrow><msub><mi>c</mi><mi mathvariant="italic">ij</mi></msub><mo>=</mo><msub><mi>a</mi><mi mathvariant="italic">ij</mi></msub><mo>·</mo><msub><mi>b</mi><mi mathvariant="italic">ij</mi></msub></mrow></math>).
1596    ///
1597    /// # Examples
1598    ///
1599    /// ```
1600    /// # use unmtx_gpu::*;
1601    /// let a = matrix![
1602    ///     [1.0, 2.0],
1603    ///     [3.0, 4.0]
1604    /// ];
1605    /// let b = matrix![
1606    ///     [5.0, 6.0],
1607    ///     [7.0, 8.0]
1608    /// ];
1609    /// let c = Matrix::new(2, 2);
1610    /// let frontend = Frontend::new().unwrap();
1611    /// frontend.mul_elems(&a, &b, &c).unwrap();
1612    /// assert_eq!(vec![1.0 * 5.0, 2.0 * 6.0, 3.0 * 7.0, 4.0 * 8.0], c.elems());
1613    /// ```
1614    pub fn mul_elems(&self, a: &Matrix, b: &Matrix, c: &Matrix) -> Result<()>
1615    {
1616        if a.row_count != b.row_count || a.col_count != b.col_count {
1617            return Err(Error::OpSize(a.row_count, a.col_count, b.row_count, b.col_count)); 
1618        }
1619        if a.row_count != c.row_count || a.col_count != c.col_count {
1620            return Err(Error::OpSize(a.row_count, a.col_count, c.row_count, c.col_count)); 
1621        }
1622        if c.is_transposed {
1623            return Err(Error::ResTransposition);
1624        }
1625        match (a.is_transposed, b.is_transposed) {
1626            (false, false) => self.backend.mul_a_b_for_elems(&*a.array, &*b.array, &*c.array, a.row_count, a.col_count),
1627            (true, false) => self.backend.mul_at_b_for_elems(&*a.array, &*b.array, &*c.array, a.row_count, a.col_count),
1628            (false, true) => self.backend.mul_a_bt_for_elems(&*a.array, &*b.array, &*c.array, a.row_count, a.col_count),
1629            (true, true) => self.backend.mul_at_bt_for_elems(&*a.array, &*b.array, &*c.array, a.row_count, a.col_count),
1630        }
1631    }
1632
1633    /// Divides the `a` matrix elements by the `b` matrix elements and then the result is in the `c`
1634    /// matrix
1635    /// (<math xmlns="http://www.w3.org/1998/Math/MathML"><mrow><msub><mi>c</mi><mi mathvariant="italic">ij</mi></msub><mo>=</mo><mfrac><msub><mi>a</mi><mi mathvariant="italic">ij</mi></msub><msub><mi>b</mi><mi mathvariant="italic">ij</mi></msub></mfrac></mrow></math>).
1636    ///
1637    /// # Examples
1638    ///
1639    /// ```
1640    /// # use unmtx_gpu::*;
1641    /// let a = matrix![
1642    ///     [1.0, 2.0],
1643    ///     [3.0, 4.0]
1644    /// ];
1645    /// let b = matrix![
1646    ///     [5.0, 6.0],
1647    ///     [7.0, 8.0]
1648    /// ];
1649    /// let c = Matrix::new(2, 2);
1650    /// let frontend = Frontend::new().unwrap();
1651    /// frontend.div_elems(&a, &b, &c).unwrap();
1652    /// let elems = c.elems();
1653    /// assert!((1.0 / 5.0 - elems[0]).abs() < 0.001);
1654    /// assert!((2.0 / 6.0 - elems[1]).abs() < 0.001);
1655    /// assert!((3.0 / 7.0 - elems[2]).abs() < 0.001);
1656    /// assert!((4.0 / 8.0 - elems[3]).abs() < 0.001);
1657    /// ```
1658    pub fn div_elems(&self, a: &Matrix, b: &Matrix, c: &Matrix) -> Result<()>
1659    {
1660        if a.row_count != b.row_count || a.col_count != b.col_count {
1661            return Err(Error::OpSize(a.row_count, a.col_count, b.row_count, b.col_count)); 
1662        }
1663        if a.row_count != c.row_count || a.col_count != c.col_count {
1664            return Err(Error::OpSize(a.row_count, a.col_count, c.row_count, c.col_count)); 
1665        }
1666        if c.is_transposed {
1667            return Err(Error::ResTransposition);
1668        }
1669        match (a.is_transposed, b.is_transposed) {
1670            (false, false) => self.backend.div_a_b_for_elems(&*a.array, &*b.array, &*c.array, a.row_count, a.col_count),
1671            (true, false) => self.backend.div_at_b_for_elems(&*a.array, &*b.array, &*c.array, a.row_count, a.col_count),
1672            (false, true) => self.backend.div_a_bt_for_elems(&*a.array, &*b.array, &*c.array, a.row_count, a.col_count),
1673            (true, true) => self.backend.div_at_bt_for_elems(&*a.array, &*b.array, &*c.array, a.row_count, a.col_count),
1674        }
1675    }
1676
1677    /// Adds the `a` matrix onto the `b` scalar and then the result is in the `c` matrix
1678    /// (<math xmlns="http://www.w3.org/1998/Math/MathML"><mrow><mi mathvariant="bold">C</mi><mo>=</mo><mi mathvariant="bold">A</mi><mo>+</mo><mi>b</mi></mrow></math>).
1679    ///
1680    /// # Examples
1681    ///
1682    /// ```
1683    /// # use unmtx_gpu::*;
1684    /// let a = matrix![
1685    ///     [1.0, 2.0],
1686    ///     [3.0, 4.0]
1687    /// ];
1688    /// let c = Matrix::new(2, 2);
1689    /// let frontend = Frontend::new().unwrap();
1690    /// frontend.add_for_scalar(&a, 10.5, &c).unwrap();
1691    /// assert_eq!(vec![1.0 + 10.5, 2.0 + 10.5, 3.0 + 10.5, 4.0 + 10.5], c.elems());
1692    /// ```
1693    pub fn add_for_scalar(&self, a: &Matrix, b: f32, c: &Matrix) -> Result<()>
1694    {
1695        if a.row_count != c.row_count || a.col_count != c.col_count {
1696            return Err(Error::OpSize(a.row_count, a.col_count, c.row_count, c.col_count)); 
1697        }
1698        if c.is_transposed {
1699            return Err(Error::ResTransposition);
1700        }
1701        if !a.is_transposed {
1702            self.backend.add_a_b_for_scalar(&*a.array, b, &*c.array, a.row_count, a.col_count)
1703        } else {
1704            self.backend.add_at_b_for_scalar(&*a.array, b, &*c.array, a.row_count, a.col_count)
1705        }
1706    }
1707
1708    /// Subtracts the `b` scalar from the `a` matrix and then the result is in the `c` matrix
1709    /// (<math xmlns="http://www.w3.org/1998/Math/MathML"><mrow><mi mathvariant="bold">C</mi><mo>=</mo><mi mathvariant="bold">A</mi><mo>-</mo><mi>b</mi></mrow></math>).
1710    ///
1711    /// # Examples
1712    ///
1713    /// ```
1714    /// # use unmtx_gpu::*;
1715    /// let a = matrix![
1716    ///     [1.0, 2.0],
1717    ///     [3.0, 4.0]
1718    /// ];
1719    /// let c = Matrix::new(2, 2);
1720    /// let frontend = Frontend::new().unwrap();
1721    /// frontend.sub_for_scalar(&a, 10.5, &c).unwrap();
1722    /// assert_eq!(vec![1.0 - 10.5, 2.0 - 10.5, 3.0 - 10.5, 4.0 - 10.5], c.elems());
1723    /// ```
1724    pub fn sub_for_scalar(&self, a: &Matrix, b: f32, c: &Matrix) -> Result<()>
1725    {
1726        if a.row_count != c.row_count || a.col_count != c.col_count {
1727            return Err(Error::OpSize(a.row_count, a.col_count, c.row_count, c.col_count)); 
1728        }
1729        if c.is_transposed {
1730            return Err(Error::ResTransposition);
1731        }
1732        if !a.is_transposed {
1733            self.backend.sub_a_b_for_scalar(&*a.array, b, &*c.array, a.row_count, a.col_count)
1734        } else {
1735            self.backend.sub_at_b_for_scalar(&*a.array, b, &*c.array, a.row_count, a.col_count)
1736        }
1737    }
1738
1739    /// Subtracts the `a` matrix from the `b` scalar and then the result is in the `c` matrix
1740    /// (<math xmlns="http://www.w3.org/1998/Math/MathML"><mrow><mi mathvariant="bold">C</mi><mo>=</mo><mi>b</mi><mo>-</mo><mi mathvariant="bold">A</mi></mrow></math>).
1741    ///
1742    /// # Examples
1743    ///
1744    /// ```
1745    /// # use unmtx_gpu::*;
1746    /// let a = matrix![
1747    ///     [1.0, 2.0],
1748    ///     [3.0, 4.0]
1749    /// ];
1750    /// let c = Matrix::new(2, 2);
1751    /// let frontend = Frontend::new().unwrap();
1752    /// frontend.rsub_for_scalar(&a, 10.5, &c).unwrap();
1753    /// assert_eq!(vec![10.5 - 1.0, 10.5 - 2.0, 10.5 - 3.0, 10.5 - 4.0], c.elems());
1754    /// ```
1755    pub fn rsub_for_scalar(&self, a: &Matrix, b: f32, c: &Matrix) -> Result<()>
1756    {
1757        if a.row_count != c.row_count || a.col_count != c.col_count {
1758            return Err(Error::OpSize(a.row_count, a.col_count, c.row_count, c.col_count)); 
1759        }
1760        if c.is_transposed {
1761            return Err(Error::ResTransposition);
1762        }
1763        if !a.is_transposed {
1764            self.backend.rsub_a_b_for_scalar(&*a.array, b, &*c.array, a.row_count, a.col_count)
1765        } else {
1766            self.backend.rsub_at_b_for_scalar(&*a.array, b, &*c.array, a.row_count, a.col_count)
1767        }
1768    }
1769    
1770    /// Multiplies the `a` matrix by the `b` scalar and then the result is in the `c` matrix
1771    /// (<math xmlns="http://www.w3.org/1998/Math/MathML"><mrow><mi mathvariant="bold">C</mi><mo>=</mo><mi mathvariant="bold">A</mi><mo>·</mo><mi>b</mi></mrow></math>).
1772    ///
1773    /// # Examples
1774    ///
1775    /// ```
1776    /// # use unmtx_gpu::*;
1777    /// let a = matrix![
1778    ///     [1.0, 2.0],
1779    ///     [3.0, 4.0]
1780    /// ];
1781    /// let c = Matrix::new(2, 2);
1782    /// let frontend = Frontend::new().unwrap();
1783    /// frontend.mul_for_scalar(&a, 10.5, &c).unwrap();
1784    /// assert_eq!(vec![1.0 * 10.5, 2.0 * 10.5, 3.0 * 10.5, 4.0 * 10.5], c.elems());
1785    /// ```
1786    pub fn mul_for_scalar(&self, a: &Matrix, b: f32, c: &Matrix) -> Result<()>
1787    {
1788        if a.row_count != c.row_count || a.col_count != c.col_count {
1789            return Err(Error::OpSize(a.row_count, a.col_count, c.row_count, c.col_count)); 
1790        }
1791        if c.is_transposed {
1792            return Err(Error::ResTransposition);
1793        }
1794        if !a.is_transposed {
1795            self.backend.mul_a_b_for_scalar(&*a.array, b, &*c.array, a.row_count, a.col_count)
1796        } else {
1797            self.backend.mul_at_b_for_scalar(&*a.array, b, &*c.array, a.row_count, a.col_count)
1798        }
1799    }
1800    
1801    /// Divides the `a` matrix by the `b` scalar and then the result is in the `c` matrix
1802    /// (<math xmlns="http://www.w3.org/1998/Math/MathML"><mrow><mi mathvariant="bold">C</mi><mo>=</mo><mfrac><mi mathvariant="bold">A</mi><mi>b</mi></mfrac></mrow></math>).
1803    ///
1804    /// # Examples
1805    ///
1806    /// ```
1807    /// # use unmtx_gpu::*;
1808    /// let a = matrix![
1809    ///     [1.0, 2.0],
1810    ///     [3.0, 4.0]
1811    /// ];
1812    /// let c = Matrix::new(2, 2);
1813    /// let frontend = Frontend::new().unwrap();
1814    /// frontend.div_for_scalar(&a, 10.5, &c).unwrap();
1815    /// let elems = c.elems();
1816    /// assert!((1.0 / 10.5 - elems[0]).abs() < 0.001);
1817    /// assert!((2.0 / 10.5 - elems[1]).abs() < 0.001);
1818    /// assert!((3.0 / 10.5 - elems[2]).abs() < 0.001);
1819    /// assert!((4.0 / 10.5 - elems[3]).abs() < 0.001);
1820    /// ```
1821    pub fn div_for_scalar(&self, a: &Matrix, b: f32, c: &Matrix) -> Result<()>
1822    {
1823        if a.row_count != c.row_count || a.col_count != c.col_count {
1824            return Err(Error::OpSize(a.row_count, a.col_count, c.row_count, c.col_count)); 
1825        }
1826        if c.is_transposed {
1827            return Err(Error::ResTransposition);
1828        }
1829        if !a.is_transposed {
1830            self.backend.div_a_b_for_scalar(&*a.array, b, &*c.array, a.row_count, a.col_count)
1831        } else {
1832            self.backend.div_at_b_for_scalar(&*a.array, b, &*c.array, a.row_count, a.col_count)
1833        }
1834    }
1835
1836    /// Divides the `b` scalar by the `a` matrix elements and then the result is in the `c` matrix
1837    /// (<math xmlns="http://www.w3.org/1998/Math/MathML"><mrow><msub><mi>c</mi><mi mathvariant="italic">ij</mi></msub><mo>=</mo><mfrac><mi>b</mi><msub><mi>a</mi><mi mathvariant="italic">ij</mi></msub></mfrac></mrow></math>).
1838    ///
1839    /// # Examples
1840    ///
1841    /// ```
1842    /// # use unmtx_gpu::*;
1843    /// let a = matrix![
1844    ///     [1.0, 2.0],
1845    ///     [3.0, 4.0]
1846    /// ];
1847    /// let c = Matrix::new(2, 2);
1848    /// let frontend = Frontend::new().unwrap();
1849    /// frontend.rdiv_for_scalar(&a, 10.5, &c).unwrap();
1850    /// let elems = c.elems();
1851    /// assert!((10.5 / 1.0- elems[0]).abs() < 0.001);
1852    /// assert!((10.5 / 2.0 - elems[1]).abs() < 0.001);
1853    /// assert!((10.5 / 3.0 - elems[2]).abs() < 0.001);
1854    /// assert!((10.5 / 4.0 - elems[3]).abs() < 0.001);
1855    /// ```
1856    pub fn rdiv_for_scalar(&self, a: &Matrix, b: f32, c: &Matrix) -> Result<()>
1857    {
1858        if a.row_count != c.row_count || a.col_count != c.col_count {
1859            return Err(Error::OpSize(a.row_count, a.col_count, c.row_count, c.col_count)); 
1860        }
1861        if c.is_transposed {
1862            return Err(Error::ResTransposition);
1863        }
1864        if !a.is_transposed {
1865            self.backend.rdiv_a_b_for_scalar(&*a.array, b, &*c.array, a.row_count, a.col_count)
1866        } else {
1867            self.backend.rdiv_at_b_for_scalar(&*a.array, b, &*c.array, a.row_count, a.col_count)
1868        }
1869    }
1870
1871    /// Calculates sigmoid function for the `a` matrix and then the result is in the `b` matrix
1872    /// (<math xmlns="http://www.w3.org/1998/Math/MathML"><mrow><mi mathvariant="bold">B</mi><mo>=</mo><mi>sigmoid</mi><mo fence="true">(</mo><mi mathvariant="bold">A</mi><mo fence="true">)</mo></mrow></math>).
1873    ///
1874    /// # Examples
1875    ///
1876    /// ```
1877    /// # use unmtx_gpu::*;
1878    /// let a = matrix![
1879    ///     [1.0, 2.0],
1880    ///     [3.0, 4.0]
1881    /// ];
1882    /// let b = Matrix::new(2, 2);
1883    /// let frontend = Frontend::new().unwrap();
1884    /// frontend.sigmoid(&a, &b).unwrap();
1885    /// let elems = b.elems();
1886    /// assert!((1.0 / (1.0 + (-1.0f32).exp()) - elems[0]).abs() < 0.001);
1887    /// assert!((1.0 / (1.0 + (-2.0f32).exp()) - elems[1]).abs() < 0.001);
1888    /// assert!((1.0 / (1.0 + (-3.0f32).exp()) - elems[2]).abs() < 0.001);
1889    /// assert!((1.0 / (1.0 + (-4.0f32).exp()) - elems[3]).abs() < 0.001);
1890    /// ```
1891    pub fn sigmoid(&self, a: &Matrix, b: &Matrix) -> Result<()>
1892    {
1893        if a.row_count != b.row_count || a.col_count != b.col_count {
1894            return Err(Error::OpSize(a.row_count, a.col_count, b.row_count, b.col_count)); 
1895        }
1896        if b.is_transposed {
1897            return Err(Error::ResTransposition);
1898        }
1899        if !a.is_transposed {
1900            self.backend.sigmoid_a(&*a.array, &*b.array, a.row_count, a.col_count)
1901        } else {
1902            self.backend.sigmoid_at(&*a.array, &*b.array, a.row_count, a.col_count)
1903        }
1904    }
1905
    /// Calculates hyperbolic tangent function for the `a` matrix and then the result is in the `b`
    /// matrix
1907    /// (<math xmlns="http://www.w3.org/1998/Math/MathML"><mrow><mi mathvariant="bold">B</mi><mo>=</mo><mi>tanh</mi><mo fence="true">(</mo><mi mathvariant="bold">A</mi><mo fence="true">)</mo></mrow></math>).
1908    ///
1909    /// # Examples
1910    ///
1911    /// ```
1912    /// # use unmtx_gpu::*;
1913    /// let a = matrix![
1914    ///     [1.0, 2.0],
1915    ///     [3.0, 4.0]
1916    /// ];
1917    /// let b = Matrix::new(2, 2);
1918    /// let frontend = Frontend::new().unwrap();
1919    /// frontend.tanh(&a, &b).unwrap();
1920    /// let elems = b.elems();
1921    /// assert!((1.0f32.tanh() - elems[0]).abs() < 0.001);
1922    /// assert!((2.0f32.tanh() - elems[1]).abs() < 0.001);
1923    /// assert!((3.0f32.tanh() - elems[2]).abs() < 0.001);
1924    /// assert!((4.0f32.tanh() - elems[3]).abs() < 0.001);
1925    /// ```
1926    pub fn tanh(&self, a: &Matrix, b: &Matrix) -> Result<()>
1927    {
1928        if a.row_count != b.row_count || a.col_count != b.col_count {
1929            return Err(Error::OpSize(a.row_count, a.col_count, b.row_count, b.col_count)); 
1930        }
1931        if b.is_transposed {
1932            return Err(Error::ResTransposition);
1933        }
1934        if !a.is_transposed {
1935            self.backend.tanh_a(&*a.array, &*b.array, a.row_count, a.col_count)
1936        } else {
1937            self.backend.tanh_at(&*a.array, &*b.array, a.row_count, a.col_count)
1938        }
1939    }    
1940
1941    /// Calculates softmax function for the `a` matrix and then the result is in the `b` matrix
1942    /// (<math xmlns="http://www.w3.org/1998/Math/MathML"><mrow><mi mathvariant="bold">B</mi><mo>=</mo><mi>softmax</mi><mo fence="true">(</mo><mi mathvariant="bold">A</mi><mo fence="true">)</mo></mrow></math>).
1943    ///
1944    /// # Examples
1945    ///
1946    /// ```
1947    /// # use unmtx_gpu::*;
1948    /// let a = matrix![
1949    ///     [1.0, 2.0],
1950    ///     [3.0, 4.0]
1951    /// ];
1952    /// let b = Matrix::new(2, 2);
1953    /// let frontend = Frontend::new().unwrap();
1954    /// frontend.softmax(&a, &b).unwrap();
1955    /// let elems = b.elems();
1956    /// let sum1 = 1.0f32.exp() + 3.0f32.exp();
1957    /// let sum2 = 2.0f32.exp() + 4.0f32.exp();
1958    /// assert!((1.0f32.exp() / sum1 - elems[0]).abs() < 0.001);
1959    /// assert!((2.0f32.exp() / sum2 - elems[1]).abs() < 0.001);
1960    /// assert!((3.0f32.exp() / sum1 - elems[2]).abs() < 0.001);
1961    /// assert!((4.0f32.exp() / sum2 - elems[3]).abs() < 0.001);
1962    /// ```
1963    pub fn softmax(&self, a: &Matrix, b: &Matrix) -> Result<()>
1964    {
1965        if a.row_count != b.row_count || a.col_count != b.col_count {
1966            return Err(Error::OpSize(a.row_count, a.col_count, b.row_count, b.col_count)); 
1967        }
1968        if b.is_transposed {
1969            return Err(Error::ResTransposition);
1970        }
1971        if !a.is_transposed {
1972            self.backend.softmax_a(&*a.array, &*b.array, a.row_count, a.col_count)
1973        } else {
1974            self.backend.softmax_at(&*a.array, &*b.array, a.row_count, a.col_count)
1975        }
1976    }    
1977    
    /// Really transposes the `a` matrix and then the result is in the `b` matrix
1979    /// (<math xmlns="http://www.w3.org/1998/Math/MathML"><mrow><mi mathvariant="bold">B</mi><mo>=</mo><msup><mi mathvariant="bold">A</mi><mi mathvariant="normal">T</mi></msup></mrow></math>).
1980    ///
    /// This method really transposes the `a` matrix without changing the transpose flag.
1982    ///
1983    /// # Examples
1984    ///
1985    /// ```
1986    /// # use unmtx_gpu::*;
1987    /// let a = matrix![
1988    ///     [1.0, 2.0, 3.0],
1989    ///     [4.0, 5.0, 6.0]
1990    /// ];
1991    /// let b = Matrix::new(3, 2);
1992    /// let frontend = Frontend::new().unwrap();
1993    /// frontend.really_transpose(&a, &b).unwrap();
1994    /// assert_eq!(vec![1.0, 4.0, 2.0, 5.0, 3.0, 6.0], b.elems());
1995    /// ```
1996    pub fn really_transpose(&self, a: &Matrix, b: &Matrix) -> Result<()>
1997    {
1998        if a.row_count != b.col_count || a.col_count != b.row_count {
1999            return Err(Error::TransposeSize(a.row_count, a.col_count, b.row_count, b.col_count)); 
2000        }
2001        if a.is_transposed {
2002            return Err(Error::ArgTransposition);
2003        }
2004        if b.is_transposed {
2005            return Err(Error::ResTransposition);
2006        }
2007        self.backend.transpose_a(&*a.array, &*b.array, a.col_count, a.row_count)
2008    }
2009
    /// Repeats the `a` vector as a column or a row and then the result is in the `b` matrix
2011    /// (<math xmlns="http://www.w3.org/1998/Math/MathML"><mrow><msub><mi>b</mi><mi mathvariant="italic">ij</mi></msub><mo>=</mo><msub><mi>a</mi><mi>i</mi></msub></mrow></math> or 
2012    /// <math xmlns="http://www.w3.org/1998/Math/MathML"><mrow><msub><mi>b</mi><mi mathvariant="italic">ij</mi></msub><mo>=</mo><msub><mi>a</mi><mi>j</mi></msub></mrow></math>).
2013    ///
2014    /// # Examples
2015    ///
2016    /// ```
2017    /// # use unmtx_gpu::*;
2018    /// let a = matrix![
2019    ///     [1.0],
2020    ///     [2.0]
2021    /// ];
2022    /// let b = Matrix::new(2, 3);
2023    /// let frontend = Frontend::new().unwrap();
2024    /// frontend.repeat(&a, &b).unwrap();
2025    /// assert_eq!(vec![1.0, 1.0, 1.0, 2.0, 2.0, 2.0], b.elems());
2026    /// let c = matrix![[1.0, 2.0, 3.0]];
2027    /// let d = Matrix::new(2, 3);
2028    /// let frontend = Frontend::new().unwrap();
2029    /// frontend.repeat(&c, &d).unwrap();
2030    /// assert_eq!(vec![1.0, 2.0, 3.0, 1.0, 2.0, 3.0], d.elems());
2031    /// ```
2032    pub fn repeat(&self, a: &Matrix, b: &Matrix) -> Result<()>
2033    {
2034        if b.is_transposed {
2035            return Err(Error::ResTransposition);
2036        }
2037        if a.col_count == 1 {
2038            if a.row_count != b.row_count {
2039                return Err(Error::OpSize(a.row_count, a.col_count, b.row_count, b.col_count));
2040            }
2041            self.backend.repeat_col_a(&*a.array, &*b.array, a.row_count, b.col_count)
2042        } else if a.row_count == 1 {
2043            if a.col_count != b.col_count {
2044                return Err(Error::OpSize(a.row_count, a.col_count, b.row_count, b.col_count));
2045            }
2046            self.backend.repeat_row_a(&*a.array, &*b.array, b.row_count, a.col_count)
2047        } else {
2048            Err(Error::IsNotVector)
2049        }
2050    }
2051}
2052
2053#[cfg(test)]
2054mod test_helpers;
2055#[cfg(all(test, not(feature = "test_only_backend")))]
2056mod tests;