unmtx_gpu/lib.rs
1//
2// Copyright (c) 2025 Łukasz Szpakowski
3//
4// This Source Code Form is subject to the terms of the Mozilla Public
5// License, v. 2.0. If a copy of the MPL was not distributed with this
6// file, You can obtain one at https://mozilla.org/MPL/2.0/.
7//
//! Micro neural matrix library for GPU is a small library that operates on matrices.
9//!
//! This library uses GPUs through the following computing platforms:
11//!
12//! - OpenCL
13//! - CUDA
14//!
//! If this library uses CUDA, this library can use the cuBLAS library for multiplication of
//! matrices.
17//!
//! A frontend-backend architecture is used by this library. The frontend of this library can use
//! one of two backends (OpenCL or CUDA). These backends allow GPUs to be used through the
//! computing platforms. The frontend and the backend can have many instances. This library
//! provides a high-level interface to operations on matrices through the frontend and the methods
//! of the [`Matrix`] structure.
23//!
24//! # Examples
25//!
26//! ```
27//! # use unmtx_gpu::*;
28//! let a = matrix![
29//! [1.0, 2.0],
30//! [3.0, 4.0]
31//! ];
32//! let x = matrix![
33//! [5.0],
34//! [6.0]
35//! ];
36//! let b = matrix![
37//! [7.0],
38//! [8.0]
39//! ];
40//! let c = a * x + b;
41//! assert_eq!(vec![1.0 * 5.0 + 2.0 * 6.0 + 7.0, 3.0 * 5.0 + 4.0 * 6.0 + 8.0], c.elems());
42//! ```
43use std::ops::Add;
44use std::ops::AddAssign;
45use std::ops::Sub;
46use std::ops::SubAssign;
47use std::ops::Mul;
48use std::ops::MulAssign;
49use std::ops::Div;
50use std::ops::DivAssign;
51use std::error;
52use std::fmt;
53use std::result;
54use std::sync::Arc;
55use std::sync::Mutex;
56use std::sync::MutexGuard;
57
58#[cfg(feature = "opencl")]
59pub mod opencl;
60#[cfg(feature = "cuda")]
61pub mod cuda;
62
/// A backend trait.
///
/// The backend provides a low-level interface to a computing platform (OpenCL or CUDA) for basic
/// operations and functions on matrices. The backend methods operate on backend arrays which
/// refer to areas of the device memory. The backend is a low-level layer between a frontend and
/// a computing platform.
///
/// NOTE(review): the `n`, `m`, and `l` parameters of the matrix methods are presumably matrix
/// dimensions (numbers of rows and columns); their exact meaning isn't stated in this trait and
/// should be confirmed against the backend implementations.
pub trait Backend
{
    /// Returns the backend name.
    fn name(&self) -> &'static str;

    /// Returns `true` if the backend uses cuBLAS, otherwise `false`.
    fn has_cublas(&self) -> bool;

    /// Allocates a backend array.
    ///
    /// NOTE(review): this method is presumably `unsafe` because the elements of the allocated
    /// backend array aren't initialized (compare with
    /// [`alloc_and_store_zeros`](Self::alloc_and_store_zeros)) -- confirm against the backend
    /// implementations.
    unsafe fn alloc(&self, n: usize) -> Result<BackendArray>;

    /// Allocates a backend array and stores zeros in the backend array.
    fn alloc_and_store_zeros(&self, n: usize) -> Result<BackendArray>;

    /// Allocates a backend array and stores the elements in the backend array.
    fn alloc_and_store(&self, elems: &[f32]) -> Result<BackendArray>;

    /// Loads elements from the backend array.
    fn load(&self, a: &BackendArray, elems: &mut [f32]) -> Result<()>;

    /// Stores elements in the backend array.
    fn store(&self, a: &BackendArray, elems: &[f32]) -> Result<()>;

    /// Copies the `a` backend array to the `b` backend array.
    fn copy(&self, a: &BackendArray, b: &BackendArray) -> Result<()>;

    /// Transposes the `a` matrix and then the result is in the `b` matrix
    /// (<math xmlns="http://www.w3.org/1998/Math/MathML"><mrow><mi mathvariant="bold">B</mi><mo>=</mo><msup><mi mathvariant="bold">A</mi><mi mathvariant="normal">T</mi></msup></mrow></math>).
    fn transpose_a(&self, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>;

    /// Adds the `a` matrix onto the `b` matrix and then the result is in the `c` matrix
    /// (<math xmlns="http://www.w3.org/1998/Math/MathML"><mrow><mi mathvariant="bold">C</mi><mo>=</mo><mi mathvariant="bold">A</mi><mo>+</mo><mi mathvariant="bold">B</mi></mrow></math>).
    fn add_a_b(&self, a: &BackendArray, b: &BackendArray, c: &BackendArray, n: usize, m: usize) -> Result<()>;

    /// Adds the transposed `a` matrix onto the `b` matrix and then the result is in the `c` matrix
    /// (<math xmlns="http://www.w3.org/1998/Math/MathML"><mrow><mi mathvariant="bold">C</mi><mo>=</mo><msup><mi mathvariant="bold">A</mi><mi mathvariant="normal">T</mi></msup><mo>+</mo><mi mathvariant="bold">B</mi></mrow></math>).
    fn add_at_b(&self, a: &BackendArray, b: &BackendArray, c: &BackendArray, n: usize, m: usize) -> Result<()>;

    /// Adds the `a` matrix onto the transposed `b` matrix and then the result is in the `c` matrix
    /// (<math xmlns="http://www.w3.org/1998/Math/MathML"><mrow><mi mathvariant="bold">C</mi><mo>=</mo><mi mathvariant="bold">A</mi><mo>+</mo><msup><mi mathvariant="bold">B</mi><mi mathvariant="normal">T</mi></msup></mrow></math>).
    fn add_a_bt(&self, a: &BackendArray, b: &BackendArray, c: &BackendArray, n: usize, m: usize) -> Result<()>;

    /// Adds the transposed `a` matrix onto the transposed `b` matrix and then the result is in the
    /// `c` matrix
    /// (<math xmlns="http://www.w3.org/1998/Math/MathML"><mrow><mi mathvariant="bold">C</mi><mo>=</mo><msup><mi mathvariant="bold">A</mi><mi mathvariant="normal">T</mi></msup><mo>+</mo><msup><mi mathvariant="bold">B</mi><mi mathvariant="normal">T</mi></msup></mrow></math>).
    fn add_at_bt(&self, a: &BackendArray, b: &BackendArray, c: &BackendArray, n: usize, m: usize) -> Result<()>;

    /// Subtracts the `b` matrix from the `a` matrix and then the result is in the `c` matrix
    /// (<math xmlns="http://www.w3.org/1998/Math/MathML"><mrow><mi mathvariant="bold">C</mi><mo>=</mo><mi mathvariant="bold">A</mi><mo>-</mo><mi mathvariant="bold">B</mi></mrow></math>).
    fn sub_a_b(&self, a: &BackendArray, b: &BackendArray, c: &BackendArray, n: usize, m: usize) -> Result<()>;

    /// Subtracts the `b` matrix from the transposed `a` matrix and then the result is in the `c`
    /// matrix
    /// (<math xmlns="http://www.w3.org/1998/Math/MathML"><mrow><mi mathvariant="bold">C</mi><mo>=</mo><msup><mi mathvariant="bold">A</mi><mi mathvariant="normal">T</mi></msup><mo>-</mo><mi mathvariant="bold">B</mi></mrow></math>).
    fn sub_at_b(&self, a: &BackendArray, b: &BackendArray, c: &BackendArray, n: usize, m: usize) -> Result<()>;

    /// Subtracts the transposed `b` matrix from the `a` matrix and then the result is in the `c`
    /// matrix
    /// (<math xmlns="http://www.w3.org/1998/Math/MathML"><mrow><mi mathvariant="bold">C</mi><mo>=</mo><mi mathvariant="bold">A</mi><mo>-</mo><msup><mi mathvariant="bold">B</mi><mi mathvariant="normal">T</mi></msup></mrow></math>).
    fn sub_a_bt(&self, a: &BackendArray, b: &BackendArray, c: &BackendArray, n: usize, m: usize) -> Result<()>;

    /// Subtracts the transposed `b` matrix from the transposed `a` matrix and then the result is
    /// in the `c` matrix
    /// (<math xmlns="http://www.w3.org/1998/Math/MathML"><mrow><mi mathvariant="bold">C</mi><mo>=</mo><msup><mi mathvariant="bold">A</mi><mi mathvariant="normal">T</mi></msup><mo>-</mo><msup><mi mathvariant="bold">B</mi><mi mathvariant="normal">T</mi></msup></mrow></math>).
    fn sub_at_bt(&self, a: &BackendArray, b: &BackendArray, c: &BackendArray, n: usize, m: usize) -> Result<()>;

    /// Multiplies the `a` matrix by the `b` matrix and then the result is in the `c` matrix
    /// (<math xmlns="http://www.w3.org/1998/Math/MathML"><mrow><mi mathvariant="bold">C</mi><mo>=</mo><mi mathvariant="bold">A</mi><mo>·</mo><mi mathvariant="bold">B</mi></mrow></math>).
    fn mul_a_b(&self, a: &BackendArray, b: &BackendArray, c: &BackendArray, n: usize, m: usize, l: usize) -> Result<()>;

    /// Multiplies the transposed `a` matrix by the `b` matrix and then the result is in the `c`
    /// matrix
    /// (<math xmlns="http://www.w3.org/1998/Math/MathML"><mrow><mi mathvariant="bold">C</mi><mo>=</mo><msup><mi mathvariant="bold">A</mi><mi mathvariant="normal">T</mi></msup><mo>·</mo><mi mathvariant="bold">B</mi></mrow></math>).
    fn mul_at_b(&self, a: &BackendArray, b: &BackendArray, c: &BackendArray, n: usize, m: usize, l: usize) -> Result<()>;

    /// Multiplies the `a` matrix by the transposed `b` matrix and then the result is in the `c`
    /// matrix
    /// (<math xmlns="http://www.w3.org/1998/Math/MathML"><mrow><mi mathvariant="bold">C</mi><mo>=</mo><mi mathvariant="bold">A</mi><mo>·</mo><msup><mi mathvariant="bold">B</mi><mi mathvariant="normal">T</mi></msup></mrow></math>).
    fn mul_a_bt(&self, a: &BackendArray, b: &BackendArray, c: &BackendArray, n: usize, m: usize, l: usize) -> Result<()>;

    /// Multiplies the transposed `a` matrix by the transposed `b` matrix and then the result is in
    /// the `c` matrix
    /// (<math xmlns="http://www.w3.org/1998/Math/MathML"><mrow><mi mathvariant="bold">C</mi><mo>=</mo><msup><mi mathvariant="bold">A</mi><mi mathvariant="normal">T</mi></msup><mo>·</mo><msup><mi mathvariant="bold">B</mi><mi mathvariant="normal">T</mi></msup></mrow></math>).
    fn mul_at_bt(&self, a: &BackendArray, b: &BackendArray, c: &BackendArray, n: usize, m: usize, l: usize) -> Result<()>;

    /// Multiplies the `a` matrix elements by the `b` matrix elements and then the result is in the
    /// `c` matrix
    /// (<math xmlns="http://www.w3.org/1998/Math/MathML"><mrow><msub><mi>c</mi><mrow><mi>i</mi><mi>j</mi></mrow></msub><mo>=</mo><msub><mi>a</mi><mrow><mi>i</mi><mi>j</mi></mrow></msub><mo>·</mo><msub><mi>b</mi><mrow><mi>i</mi><mi>j</mi></mrow></msub></mrow></math>).
    fn mul_a_b_for_elems(&self, a: &BackendArray, b: &BackendArray, c: &BackendArray, n: usize, m: usize) -> Result<()>;

    /// Multiplies the transposed `a` matrix elements by the `b` matrix elements and then the
    /// result is in the `c` matrix
    /// (<math xmlns="http://www.w3.org/1998/Math/MathML"><mrow><msub><mi>c</mi><mrow><mi>i</mi><mi>j</mi></mrow></msub><mo>=</mo><msub><mi>a</mi><mi mathvariant="italic">ji</mi></msub><mo>·</mo><msub><mi>b</mi><mrow><mi>i</mi><mi>j</mi></mrow></msub></mrow></math>).
    fn mul_at_b_for_elems(&self, a: &BackendArray, b: &BackendArray, c: &BackendArray, n: usize, m: usize) -> Result<()>;

    /// Multiplies the `a` matrix elements by the transposed `b` matrix elements and then the
    /// result is in the `c` matrix
    /// (<math xmlns="http://www.w3.org/1998/Math/MathML"><mrow><msub><mi>c</mi><mrow><mi>i</mi><mi>j</mi></mrow></msub><mo>=</mo><msub><mi>a</mi><mrow><mi>i</mi><mi>j</mi></mrow></msub><mo>·</mo><msub><mi>b</mi><mi mathvariant="italic">ji</mi></msub></mrow></math>).
    fn mul_a_bt_for_elems(&self, a: &BackendArray, b: &BackendArray, c: &BackendArray, n: usize, m: usize) -> Result<()>;

    /// Multiplies the transposed `a` matrix elements by the transposed `b` matrix elements and
    /// then the result is in the `c` matrix
    /// (<math xmlns="http://www.w3.org/1998/Math/MathML"><mrow><msub><mi>c</mi><mrow><mi>i</mi><mi>j</mi></mrow></msub><mo>=</mo><msub><mi>a</mi><mi mathvariant="italic">ji</mi></msub><mo>·</mo><msub><mi>b</mi><mi mathvariant="italic">ji</mi></msub></mrow></math>).
    fn mul_at_bt_for_elems(&self, a: &BackendArray, b: &BackendArray, c: &BackendArray, n: usize, m: usize) -> Result<()>;

    /// Divides the `a` matrix elements by the `b` matrix elements and then the result is in the
    /// `c` matrix
    /// (<math xmlns="http://www.w3.org/1998/Math/MathML"><mrow><msub><mi>c</mi><mrow><mi>i</mi><mi>j</mi></mrow></msub><mo>=</mo><mfrac><msub><mi>a</mi><mrow><mi>i</mi><mi>j</mi></mrow></msub><msub><mi>b</mi><mrow><mi>i</mi><mi>j</mi></mrow></msub></mfrac></mrow></math>).
    fn div_a_b_for_elems(&self, a: &BackendArray, b: &BackendArray, c: &BackendArray, n: usize, m: usize) -> Result<()>;

    /// Divides the transposed `a` matrix elements by the `b` matrix elements and then the result
    /// is in the `c` matrix
    /// (<math xmlns="http://www.w3.org/1998/Math/MathML"><mrow><msub><mi>c</mi><mrow><mi>i</mi><mi>j</mi></mrow></msub><mo>=</mo><mfrac><msub><mi>a</mi><mi mathvariant="italic">ji</mi></msub><msub><mi>b</mi><mrow><mi>i</mi><mi>j</mi></mrow></msub></mfrac></mrow></math>).
    fn div_at_b_for_elems(&self, a: &BackendArray, b: &BackendArray, c: &BackendArray, n: usize, m: usize) -> Result<()>;

    /// Divides the `a` matrix elements by the transposed `b` matrix elements and then the result
    /// is in the `c` matrix
    /// (<math xmlns="http://www.w3.org/1998/Math/MathML"><mrow><msub><mi>c</mi><mrow><mi>i</mi><mi>j</mi></mrow></msub><mo>=</mo><mfrac><msub><mi>a</mi><mrow><mi>i</mi><mi>j</mi></mrow></msub><msub><mi>b</mi><mi mathvariant="italic">ji</mi></msub></mfrac></mrow></math>).
    fn div_a_bt_for_elems(&self, a: &BackendArray, b: &BackendArray, c: &BackendArray, n: usize, m: usize) -> Result<()>;

    /// Divides the transposed `a` matrix elements by the transposed `b` matrix elements and then
    /// the result is in the `c` matrix
    /// (<math xmlns="http://www.w3.org/1998/Math/MathML"><mrow><msub><mi>c</mi><mrow><mi>i</mi><mi>j</mi></mrow></msub><mo>=</mo><mfrac><msub><mi>a</mi><mi mathvariant="italic">ji</mi></msub><msub><mi>b</mi><mi mathvariant="italic">ji</mi></msub></mfrac></mrow></math>).
    fn div_at_bt_for_elems(&self, a: &BackendArray, b: &BackendArray, c: &BackendArray, n: usize, m: usize) -> Result<()>;

    /// Adds the `a` matrix onto the `b` scalar and then the result is in the `c` matrix
    /// (<math xmlns="http://www.w3.org/1998/Math/MathML"><mrow><mi mathvariant="bold">C</mi><mo>=</mo><mi mathvariant="bold">A</mi><mo>+</mo><mi>b</mi></mrow></math>).
    fn add_a_b_for_scalar(&self, a: &BackendArray, b: f32, c: &BackendArray, n: usize, m: usize) -> Result<()>;

    /// Adds the transposed `a` matrix onto the `b` scalar and then the result is in the `c` matrix
    /// (<math xmlns="http://www.w3.org/1998/Math/MathML"><mrow><mi mathvariant="bold">C</mi><mo>=</mo><msup><mi mathvariant="bold">A</mi><mi mathvariant="normal">T</mi></msup><mo>+</mo><mi>b</mi></mrow></math>).
    fn add_at_b_for_scalar(&self, a: &BackendArray, b: f32, c: &BackendArray, n: usize, m: usize) -> Result<()>;

    /// Subtracts the `b` scalar from the `a` matrix and then the result is in the `c` matrix
    /// (<math xmlns="http://www.w3.org/1998/Math/MathML"><mrow><mi mathvariant="bold">C</mi><mo>=</mo><mi mathvariant="bold">A</mi><mo>-</mo><mi>b</mi></mrow></math>).
    fn sub_a_b_for_scalar(&self, a: &BackendArray, b: f32, c: &BackendArray, n: usize, m: usize) -> Result<()>;

    /// Subtracts the `b` scalar from the transposed `a` matrix and then the result is in the `c`
    /// matrix
    /// (<math xmlns="http://www.w3.org/1998/Math/MathML"><mrow><mi mathvariant="bold">C</mi><mo>=</mo><msup><mi mathvariant="bold">A</mi><mi mathvariant="normal">T</mi></msup><mo>-</mo><mi>b</mi></mrow></math>).
    fn sub_at_b_for_scalar(&self, a: &BackendArray, b: f32, c: &BackendArray, n: usize, m: usize) -> Result<()>;

    /// Subtracts the `a` matrix from the `b` scalar and then the result is in the `c` matrix
    /// (<math xmlns="http://www.w3.org/1998/Math/MathML"><mrow><mi mathvariant="bold">C</mi><mo>=</mo><mi>b</mi><mo>-</mo><mi mathvariant="bold">A</mi></mrow></math>).
    fn rsub_a_b_for_scalar(&self, a: &BackendArray, b: f32, c: &BackendArray, n: usize, m: usize) -> Result<()>;

    /// Subtracts the transposed `a` matrix from the `b` scalar and then the result is in the `c`
    /// matrix
    /// (<math xmlns="http://www.w3.org/1998/Math/MathML"><mrow><mi mathvariant="bold">C</mi><mo>=</mo><mi>b</mi><mo>-</mo><msup><mi mathvariant="bold">A</mi><mi mathvariant="normal">T</mi></msup></mrow></math>).
    fn rsub_at_b_for_scalar(&self, a: &BackendArray, b: f32, c: &BackendArray, n: usize, m: usize) -> Result<()>;

    /// Multiplies the `a` matrix by the `b` scalar and then the result is in the `c` matrix
    /// (<math xmlns="http://www.w3.org/1998/Math/MathML"><mrow><mi mathvariant="bold">C</mi><mo>=</mo><mi mathvariant="bold">A</mi><mo>·</mo><mi>b</mi></mrow></math>).
    fn mul_a_b_for_scalar(&self, a: &BackendArray, b: f32, c: &BackendArray, n: usize, m: usize) -> Result<()>;

    /// Multiplies the transposed `a` matrix by the `b` scalar and then the result is in the `c`
    /// matrix
    /// (<math xmlns="http://www.w3.org/1998/Math/MathML"><mrow><mi mathvariant="bold">C</mi><mo>=</mo><msup><mi mathvariant="bold">A</mi><mi mathvariant="normal">T</mi></msup><mo>·</mo><mi>b</mi></mrow></math>).
    fn mul_at_b_for_scalar(&self, a: &BackendArray, b: f32, c: &BackendArray, n: usize, m: usize) -> Result<()>;

    /// Divides the `a` matrix by the `b` scalar and then the result is in the `c` matrix
    /// (<math xmlns="http://www.w3.org/1998/Math/MathML"><mrow><mi mathvariant="bold">C</mi><mo>=</mo><mfrac><mi mathvariant="bold">A</mi><mi>b</mi></mfrac></mrow></math>).
    fn div_a_b_for_scalar(&self, a: &BackendArray, b: f32, c: &BackendArray, n: usize, m: usize) -> Result<()>;

    /// Divides the transposed `a` matrix by the `b` scalar and then the result is in the `c`
    /// matrix
    /// (<math xmlns="http://www.w3.org/1998/Math/MathML"><mrow><mi mathvariant="bold">C</mi><mo>=</mo><mfrac><msup><mi mathvariant="bold">A</mi><mi mathvariant="normal">T</mi></msup><mi>b</mi></mfrac></mrow></math>).
    fn div_at_b_for_scalar(&self, a: &BackendArray, b: f32, c: &BackendArray, n: usize, m: usize) -> Result<()>;

    /// Divides the `b` scalar by the `a` matrix elements and then the result is in the `c` matrix
    /// (<math xmlns="http://www.w3.org/1998/Math/MathML"><mrow><msub><mi>c</mi><mrow><mi>i</mi><mi>j</mi></mrow></msub><mo>=</mo><mfrac><mi>b</mi><msub><mi>a</mi><mrow><mi>i</mi><mi>j</mi></mrow></msub></mfrac></mrow></math>).
    fn rdiv_a_b_for_scalar(&self, a: &BackendArray, b: f32, c: &BackendArray, n: usize, m: usize) -> Result<()>;

    /// Divides the `b` scalar by the transposed `a` matrix elements and then the result is in the
    /// `c` matrix
    /// (<math xmlns="http://www.w3.org/1998/Math/MathML"><mrow><msub><mi>c</mi><mrow><mi>i</mi><mi>j</mi></mrow></msub><mo>=</mo><mfrac><mi>b</mi><msub><mi>a</mi><mi mathvariant="italic">ji</mi></msub></mfrac></mrow></math>).
    fn rdiv_at_b_for_scalar(&self, a: &BackendArray, b: f32, c: &BackendArray, n: usize, m: usize) -> Result<()>;

    /// Calculates sigmoid function for the `a` matrix and then the result is in the `b` matrix
    /// (<math xmlns="http://www.w3.org/1998/Math/MathML"><mrow><mi mathvariant="bold">B</mi><mo>=</mo><mi>sigmoid</mi><mo fence="true">(</mo><mi mathvariant="bold">A</mi><mo fence="true">)</mo></mrow></math>).
    fn sigmoid_a(&self, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>;

    /// Calculates sigmoid function for the transposed `a` matrix and then the result is in the `b`
    /// matrix
    /// (<math xmlns="http://www.w3.org/1998/Math/MathML"><mrow><mi mathvariant="bold">B</mi><mo>=</mo><mi>sigmoid</mi><mo fence="true">(</mo><msup><mi mathvariant="bold">A</mi><mi mathvariant="normal">T</mi></msup><mo fence="true">)</mo></mrow></math>).
    fn sigmoid_at(&self, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>;

    /// Calculates hyperbolic tangent function for the `a` matrix and then the result is in the
    /// `b` matrix
    /// (<math xmlns="http://www.w3.org/1998/Math/MathML"><mrow><mi mathvariant="bold">B</mi><mo>=</mo><mi>tanh</mi><mo fence="true">(</mo><mi mathvariant="bold">A</mi><mo fence="true">)</mo></mrow></math>).
    fn tanh_a(&self, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>;

    /// Calculates hyperbolic tangent function for the transposed `a` matrix and then the result is
    /// in the `b` matrix
    /// (<math xmlns="http://www.w3.org/1998/Math/MathML"><mrow><mi mathvariant="bold">B</mi><mo>=</mo><mi>tanh</mi><mo fence="true">(</mo><msup><mi mathvariant="bold">A</mi><mi mathvariant="normal">T</mi></msup><mo fence="true">)</mo></mrow></math>).
    fn tanh_at(&self, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>;

    /// Calculates swish function for the `a` matrix and then the result is in the `b` matrix
    /// (<math xmlns="http://www.w3.org/1998/Math/MathML"><mrow><mi mathvariant="bold">B</mi><mo>=</mo><mi>swish</mi><mo fence="true">(</mo><mi mathvariant="bold">A</mi><mo fence="true">)</mo></mrow></math>).
    fn swish_a(&self, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>;

    /// Calculates swish function for the transposed `a` matrix and then the result is in the `b`
    /// matrix
    /// (<math xmlns="http://www.w3.org/1998/Math/MathML"><mrow><mi mathvariant="bold">B</mi><mo>=</mo><mi>swish</mi><mo fence="true">(</mo><msup><mi mathvariant="bold">A</mi><mi mathvariant="normal">T</mi></msup><mo fence="true">)</mo></mrow></math>).
    fn swish_at(&self, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>;

    /// Calculates softmax function for the `a` matrix and then the result is in the `b` matrix
    /// (<math xmlns="http://www.w3.org/1998/Math/MathML"><mrow><mi mathvariant="bold">B</mi><mo>=</mo><mi>softmax</mi><mo fence="true">(</mo><mi mathvariant="bold">A</mi><mo fence="true">)</mo></mrow></math>).
    fn softmax_a(&self, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>;

    /// Calculates softmax function for the transposed `a` matrix and then the result is in the `b`
    /// matrix
    /// (<math xmlns="http://www.w3.org/1998/Math/MathML"><mrow><mi mathvariant="bold">B</mi><mo>=</mo><mi>softmax</mi><mo fence="true">(</mo><msup><mi mathvariant="bold">A</mi><mi mathvariant="normal">T</mi></msup><mo fence="true">)</mo></mrow></math>).
    fn softmax_at(&self, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>;

    /// Repeats the `a` vector as column and then the result is in the `b` matrix
    /// (<math xmlns="http://www.w3.org/1998/Math/MathML"><mrow><msub><mi>b</mi><mrow><mi>i</mi><mi>j</mi></mrow></msub><mo>=</mo><msub><mi>a</mi><mi>i</mi></msub></mrow></math>).
    fn repeat_col_a(&self, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>;

    /// Repeats the `a` vector as row and then the result is in the `b` matrix
    /// (<math xmlns="http://www.w3.org/1998/Math/MathML"><mrow><msub><mi>b</mi><mrow><mi>i</mi><mi>j</mi></mrow></msub><mo>=</mo><msub><mi>a</mi><mi>j</mi></msub></mrow></math>).
    fn repeat_row_a(&self, a: &BackendArray, b: &BackendArray, n: usize, m: usize) -> Result<()>;
}
292
/// An error enumeration.
///
/// Each variant is rendered as a human-readable message by the [`fmt::Display`] implementation
/// of this type.
#[derive(Debug)]
pub enum Error
{
    /// A default backend can't be initialized.
    DefaultBackendInitialization,
    /// Mismatched sizes of matrices for a matrix operation.
    ///
    /// The four fields are displayed as the sizes of two matrices (`{}x{}, {}x{}`).
    OpSize(usize, usize, usize, usize),
    /// Mismatched sizes of matrices for a matrix multiplication.
    ///
    /// The six fields are displayed as the sizes of three matrices (`{}x{}, {}x{}, {}x{}`).
    MulSize(usize, usize, usize, usize, usize, usize),
    /// Mismatched sizes of matrices for a matrix transposition.
    ///
    /// The four fields are displayed as the sizes of two matrices (`{}x{}, {}x{}`).
    TransposeSize(usize, usize, usize, usize),
    /// An argument matrix is transposed.
    ArgTransposition,
    /// A result matrix is transposed.
    ResTransposition,
    /// A number of matrix elements isn't equal to a number of elements.
    ///
    /// The two fields are the two compared numbers (presumably the number of matrix elements
    /// first -- verify against the code that creates this error).
    MatrixElemCount(usize, usize),
    /// A matrix isn't a vector.
    IsNotVector,
    /// A mutex can't be locked.
    Mutex,
    /// An OpenCL error.
    #[cfg(feature = "opencl")]
    OpenCl(opencl::ClError),
    /// A CUDA error.
    #[cfg(feature = "cuda")]
    Cuda(cuda::DriverError),
    /// A cuBLAS error.
    #[cfg(feature = "cuda")]
    Cublas(cuda::CublasError),
    /// There is no cuBLAS.
    #[cfg(feature = "cuda")]
    NoCublas,
    /// A compilation error with the message that is displayed as is.
    Compilation(String),
    /// There is no platform.
    NoPlatform,
    /// There is no device.
    NoDevice,
    /// There is no kernel with the name.
    NoKernel(String),
    /// A type of device information is invalid.
    InvalidDeviceInfoType,
    /// A number of backend array elements isn't equal to a number of elements.
    ///
    /// The two fields are the two compared numbers (presumably the number of backend array
    /// elements first -- verify against the code that creates this error).
    BackendArrayElemCount(usize, usize),
    /// Two numbers of elements of backend arrays aren't equal.
    TwoBackendArrayElemCounts(usize, usize),
    /// A backend array is invalid.
    InvalidBackendArray,
}
344
// The default methods of `std::error::Error` are used; the `Debug` and `Display`
// implementations that this trait requires come from the derive on `Error` and from the
// `fmt::Display` implementation in this file.
impl error::Error for Error
{}
347
348impl fmt::Display for Error
349{
350 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result
351 {
352 match self {
353 Error::DefaultBackendInitialization => write!(f, "can't initialize default backend"),
354 Error::OpSize(n1, m1, n2, m2) => write!(f, "mismatched sizes of matrices ({}x{}, {}x{})", n1, m1, n2, m2),
355 Error::MulSize(n1, m1, n2, m2, n3, m3) => write!(f, "mismatched sizes of matrices for multiplication ({}x{}, {}x{}, {}x{})", n1, m1, n2, m2, n3, m3),
356 Error::TransposeSize(n1, m1, n2, m2) => write!(f, "mismatched sizes of matrices for transposition ({}x{}, {}x{})", n1, m1, n2, m2),
357 Error::ArgTransposition => write!(f, "argument matrix is transposed"),
358 Error::ResTransposition => write!(f, "result matrix is transposed"),
359 Error::MatrixElemCount(n1, n2) => write!(f, "number of matrix elements isn't equal to number of elements ({}, {})", n1, n2),
360 Error::IsNotVector => write!(f, "matrix isn't vector"),
361 Error::Mutex => write!(f, "can't lock mutex"),
362 #[cfg(feature = "opencl")]
363 Error::OpenCl(err) => write!(f, "OpenCL error: {}", err),
364 #[cfg(feature = "cuda")]
365 Error::Cuda(err) => write!(f, "CUDA error: {}", err),
366 #[cfg(feature = "cuda")]
367 Error::Cublas(err) => write!(f, "cuBLAS error: {}", err),
368 #[cfg(feature = "cuda")]
369 Error::NoCublas => write!(f, "no cuBLAS"),
370 Error::Compilation(msg) => write!(f, "{}", msg),
371 Error::NoPlatform => write!(f, "no platform"),
372 Error::NoDevice => write!(f, "no device"),
373 Error::NoKernel(name) => write!(f, "no kernel {}", name),
374 Error::InvalidDeviceInfoType => write!(f, "invalid device info type"),
375 Error::BackendArrayElemCount(n1, n2) => write!(f, "number of backend array elements isn't equal to number of elements ({}, {})", n1, n2),
376 Error::TwoBackendArrayElemCounts(n1, n2) => write!(f, "two numbers of elements of backend arrays aren't equal ({}, {})", n1, n2),
377 Error::InvalidBackendArray => write!(f, "invalid backend array"),
378 }
379 }
380}
381
/// A result type.
///
/// This type is the standard result type with the [`Error`] enumeration as the error type.
pub type Result<T> = result::Result<T, Error>;
384
/// An enumeration of backend array.
///
/// This enumeration contains the reference to the area of the device memory for a computing
/// platform (OpenCL or CUDA). Each variant is only available if the feature of its computing
/// platform is enabled.
#[derive(Debug)]
pub enum BackendArray
{
    /// A backend array for OpenCL.
    #[cfg(feature = "opencl")]
    OpenCl(opencl::ClBackendArray),
    /// A backend array for CUDA.
    #[cfg(feature = "cuda")]
    Cuda(cuda::CudaBackendArray),
}
399
// The default backend that is shared by the whole process; `None` until a backend is set.
//
// NOTE(review): this is a `static mut`, so each access needs an `unsafe` block even though the
// value is protected by the mutex. Presumably `static mut` is used because `dyn Backend` has no
// `Send + Sync` bound, which a plain `static` would require -- confirm before changing.
static mut DEFAULT_BACKEND: Mutex<Option<Arc<dyn Backend>>> = Mutex::new(None);
401
402fn mutex_lock<T>(mutex: &Mutex<T>) -> Result<MutexGuard<'_, T>>
403{
404 match mutex.lock() {
405 Ok(guard) => Ok(guard),
406 Err(_) => return Err(Error::Mutex),
407 }
408}
409
410/// Returns a default backend.
411pub fn get_default_backend() -> Result<Option<Arc<dyn Backend>>>
412{
413 unsafe {
414 let default_backend_g = mutex_lock(&DEFAULT_BACKEND)?;
415 Ok(default_backend_g.clone())
416 }
417}
418
419/// Sets a default backend.
420pub fn set_default_backend(backend: Arc<dyn Backend>) -> Result<()>
421{
422 unsafe {
423 let mut default_backend_g = mutex_lock(&DEFAULT_BACKEND)?;
424 *default_backend_g = Some(backend);
425 }
426 Ok(())
427}
428
429/// Unsets a default backend.
430pub fn unset_default_backend() -> Result<()>
431{
432 unsafe {
433 let mut default_backend_g = mutex_lock(&DEFAULT_BACKEND)?;
434 *default_backend_g = None;
435 }
436 Ok(())
437}
438
439/// Sets a default backend if the default backend is uninitialized and returns the default backend.
440///
441/// This method takes a closure that returns the backend and then the backend is set as the default
442/// backend if the default backend is uninitialized. The closure is only called if the backend is
443/// to be set.
444pub fn set_default_backend_for_uninitialized<F>(f: F) -> Result<Arc<dyn Backend>>
445 where F: FnOnce() -> Result<Arc<dyn Backend>>
446{
447 unsafe {
448 let mut default_backend_g = mutex_lock(&DEFAULT_BACKEND)?;
449 match &*default_backend_g {
450 Some(default_backend) => Ok(default_backend.clone()),
451 None => {
452 let backend = f()?;
453 *default_backend_g = Some(backend.clone());
454 Ok(backend)
455 },
456 }
457 }
458}
459
460/// Initializes a default backend if the backend is uninitialized and returns the default backend.
461pub fn initialize_default_backend_for_uninitialized() -> Result<Arc<dyn Backend>>
462{
463 #[cfg(feature = "opencl")]
464 let res = set_default_backend_for_uninitialized(|| Ok(Arc::new(opencl::ClBackend::new()?)));
465 #[cfg(all(not(feature = "opencl"), feature = "cuda"))]
466 let res = set_default_backend_for_uninitialized(|| Ok(Arc::new(cuda::CudaBackend::new()?)));
467 #[cfg(all(not(feature = "opencl"), not(feature = "cuda")))]
468 let res: Result<Arc<dyn Backend>> = Err(Error::DefaultBackendInitialization);
469 res
470}
471
472/// Finalizes a default backend.
473pub fn finalize_default_backend() -> Result<()>
474{ unset_default_backend() }
475
/// Creates a matrix from the rows that are written as bracketed lists of elements.
///
/// Each bracketed list is one matrix row. The rows are passed to
/// [`Matrix::new_with_elem_vecs`] which creates the matrix. Trailing commas are accepted after
/// the last element of a row and after the last row.
///
/// # Examples
///
/// ```
/// # use unmtx_gpu::*;
/// let a = matrix![
///     [1.0, 2.0, 3.0],
///     [4.0, 5.0, 6.0]
/// ];
/// assert_eq!(2, a.row_count());
/// assert_eq!(3, a.col_count());
/// assert_eq!(false, a.is_transposed());
/// assert_eq!(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], a.elems());
/// ```
#[macro_export]
macro_rules! matrix {
    ($([$($value:expr),* $(,)*]),* $(,)*) => {
        $crate::Matrix::new_with_elem_vecs(&[$(vec![$($value),*]),*])
    };
}
497
/// A matrix structure.
///
/// The matrix refers to a shared backend array, so a clone of the matrix refers to the same
/// backend array as the matrix.
#[derive(Clone, Debug)]
pub struct Matrix
{
    // The number of matrix rows.
    row_count: usize,
    // The number of matrix columns.
    col_count: usize,
    // The transpose flag of matrix that is returned by the `is_transposed` method.
    is_transposed: bool,
    // The shared reference to the backend array of matrix.
    array: Arc<BackendArray>,
}
507
508impl Matrix
509{
510 /// Creates a matrix with the number of rows and the number of columns.
511 pub fn new(row_count: usize, col_count: usize) -> Self
512 {
513 let frontend = Frontend::new().unwrap();
514 frontend.create_matrix_and_set_zeros(row_count, col_count).unwrap()
515 }
516
517 /// Creates a matrix with the number of rows, the number of columns, and the elements.
518 pub fn new_with_elems(row_count: usize, col_count: usize, elems: &[f32]) -> Self
519 {
520 let frontend = Frontend::new().unwrap();
521 frontend.create_matrix_and_set_elems(row_count, col_count, elems).unwrap()
522 }
523
524 /// Creates a matrix with the vector of rows.
525 pub fn new_with_elem_vecs(elem_vecs: &[Vec<f32>]) -> Self
526 {
527 let frontend = Frontend::new().unwrap();
528 let col_count = match elem_vecs.first() {
529 Some(elems) => elems.len(),
530 None => 0,
531 };
532 for row in elem_vecs {
533 assert_eq!(col_count, row.len());
534 }
535 let row_count = elem_vecs.len();
536 let elems: Vec<f32> = elem_vecs.iter().flatten().map(|e| *e).collect();
537 frontend.create_matrix_and_set_elems(row_count, col_count, elems.as_slice()).unwrap()
538 }
539
540 /// Returns the number of matrix rows.
541 pub fn row_count(&self) -> usize
542 { self.row_count }
543
544 /// Returns the number of matrix columns.
545 pub fn col_count(&self) -> usize
546 { self.col_count }
547
548 /// Returns `true` if the matrix is transposed, otherwise `false`.
549 ///
550 /// This method indeed returns the transpose flag of matrix that is changed by
551 /// [`transpose`](Self::transpose).
552 pub fn is_transposed(&self) -> bool
553 { self.is_transposed }
554
555 /// Returns the matrix elements.
556 pub fn elems(&self) -> Vec<f32>
557 {
558 let frontend = Frontend::new().unwrap();
559 frontend.elems_and_transpose_flag(self).unwrap().0
560 }
561
562 /// Creates a matrix copy.
563 ///
564 /// This method indeed copies the matrix array to a new matrix array.
565 pub fn copy(&self) -> Self
566 {
567 let frontend = Frontend::new().unwrap();
568 let res = unsafe { frontend.create_matrix(self.row_count, self.col_count) }.unwrap();
569 frontend.copy(self, &res).unwrap();
570 res
571 }
572
573 /// Transposes the matrix
574 /// (<math xmlns="http://www.w3.org/1998/Math/MathML"><mrow><msup><mi mathvariant="bold">A</mi><mi mathvariant="normal">T</mi></msup></mrow></math>).
575 ///
/// This method doesn't actually transpose the matrix array; it only flips the transpose flag
/// and exchanges the number of matrix rows with the number of matrix columns.
578 ///
579 /// # Examples
580 ///
581 /// ```
582 /// # use unmtx_gpu::*;
583 /// let a = matrix![
584 /// [1.0, 2.0, 3.0],
585 /// [4.0, 5.0, 6.0]
586 /// ];
587 /// let b = a.transpose();
588 /// assert_eq!(3, b.row_count());
589 /// assert_eq!(2, b.col_count());
590 /// assert_eq!(true, b.is_transposed());
591 /// assert_eq!(a.elems(), b.elems());
592 /// let c = b.transpose();
593 /// assert_eq!(2, c.row_count());
594 /// assert_eq!(3, c.col_count());
595 /// assert_eq!(false, c.is_transposed());
596 /// assert_eq!(a.elems(), c.elems());
597 /// ```
598 pub fn transpose(&self) -> Self
599 {
600 Matrix {
601 row_count: self.col_count,
602 col_count: self.row_count,
603 is_transposed: !self.is_transposed,
604 array: self.array.clone(),
605 }
606 }
607
608 /// See [`transpose`](Self::transpose).
609 pub fn t(&self) -> Self
610 { self.transpose() }
611
612 /// Indeed transposes the matrix
613 /// (<math xmlns="http://www.w3.org/1998/Math/MathML"><mrow><msup><mi mathvariant="bold">A</mi><mi mathvariant="normal">T</mi></msup></mrow></math>).
614 ///
615 /// This method indeed transposes the matrix without changing the transpose flag.
616 ///
617 /// # Examples
618 ///
619 /// ```
620 /// # use unmtx_gpu::*;
621 /// let a = matrix![
622 /// [1.0, 2.0, 3.0],
623 /// [4.0, 5.0, 6.0]
624 /// ];
625 /// let b = a.really_transpose();
626 /// assert_eq!(3, b.row_count());
627 /// assert_eq!(2, b.col_count());
628 /// assert_eq!(false, b.is_transposed());
629 /// assert_eq!(vec![1.0, 4.0, 2.0, 5.0, 3.0, 6.0], b.elems());
630 /// ```
631 pub fn really_transpose(&self) -> Self
632 {
633 let frontend = Frontend::new().unwrap();
634 let res = unsafe { frontend.create_matrix(self.col_count, self.row_count) }.unwrap();
635 frontend.really_transpose(self, &res).unwrap();
636 res
637 }
638
639 /// See [`really_transpose`](Self::really_transpose).
640 pub fn rt(&self) -> Self
641 { self.really_transpose() }
642
643 /// Multiplies the matrix elements by the `b` matrix elements
644 /// (<math xmlns="http://www.w3.org/1998/Math/MathML"><mrow><msub><mi>a</mi><mrow><mi>i</mi><mi>j</mi></mrow></msub><mo>·</mo><msub><mi>b</mi><mrow><mi>i</mi><mi>j</mi></mrow></msub></mrow></math>).
645 ///
646 /// # Examples
647 ///
648 /// ```
649 /// # use unmtx_gpu::*;
650 /// let a = matrix![
651 /// [1.0, 2.0],
652 /// [3.0, 4.0]
653 /// ];
654 /// let b = matrix![
655 /// [5.0, 6.0],
656 /// [7.0, 8.0]
657 /// ];
658 /// let c = a.mul_elems(&b);
659 /// assert_eq!(vec![1.0 * 5.0, 2.0 * 6.0, 3.0 * 7.0, 4.0 * 8.0], c.elems());
660 /// ```
661 pub fn mul_elems(&self, b: &Self) -> Self
662 {
663 let frontend = Frontend::new().unwrap();
664 let res = unsafe { frontend.create_matrix(self.row_count, self.col_count) }.unwrap();
665 frontend.mul_elems(self, b, &res).unwrap();
666 res
667 }
668
669 /// Divides the matrix elements by the `b` matrix elements
670 /// (<math xmlns="http://www.w3.org/1998/Math/MathML"><mrow><mfrac><msub><mi>a</mi><mrow><mi>i</mi><mi>j</mi></mrow></msub><msub><mi>b</mi><mrow><mi>i</mi><mi>j</mi></mrow></msub></mfrac></mrow></math>).
671 ///
672 /// # Examples
673 ///
674 /// ```
675 /// # use unmtx_gpu::*;
676 /// let a = matrix![
677 /// [1.0, 2.0],
678 /// [3.0, 4.0]
679 /// ];
680 /// let b = matrix![
681 /// [5.0, 6.0],
682 /// [7.0, 8.0]
683 /// ];
684 /// let c = a.div_elems(&b);
685 /// let elems = c.elems();
686 /// assert!((1.0 / 5.0 - elems[0]).abs() < 0.001);
687 /// assert!((2.0 / 6.0 - elems[1]).abs() < 0.001);
688 /// assert!((3.0 / 7.0 - elems[2]).abs() < 0.001);
689 /// assert!((4.0 / 8.0 - elems[3]).abs() < 0.001);
690 /// ```
691 pub fn div_elems(&self, b: &Self) -> Self
692 {
693 let frontend = Frontend::new().unwrap();
694 let res = unsafe { frontend.create_matrix(self.row_count, self.col_count) }.unwrap();
695 frontend.div_elems(self, b, &res).unwrap();
696 res
697 }
698
699 /// Subtracts the matrix from the scalar
700 /// (<math xmlns="http://www.w3.org/1998/Math/MathML"><mrow><mi>b</mi><mo>-</mo><mi mathvariant="bold">A</mi></mrow></math>).
701 ///
702 /// # Examples
703 ///
704 /// ```
705 /// # use unmtx_gpu::*;
706 /// let a = matrix![
707 /// [1.0, 2.0],
708 /// [3.0, 4.0]
709 /// ];
710 /// let b = a.rsub(10.5);
711 /// assert_eq!(vec![10.5 - 1.0, 10.5 - 2.0, 10.5 - 3.0, 10.5 - 4.0], b.elems());
712 /// ```
713 pub fn rsub(&self, b: f32) -> Self
714 {
715 let frontend = Frontend::new().unwrap();
716 let res = unsafe { frontend.create_matrix(self.row_count, self.col_count) }.unwrap();
717 frontend.rsub_for_scalar(self, b, &res).unwrap();
718 res
719 }
720
721 /// Divides the scalar by the matrix elements
722 /// (<math xmlns="http://www.w3.org/1998/Math/MathML"><mrow><mfrac><mi>b</mi><msub><mi>a</mi><mrow><mi>i</mi><mi>j</mi></mrow></msub></mfrac></mrow></math>).
723 ///
724 /// # Examples
725 ///
726 /// ```
727 /// # use unmtx_gpu::*;
728 /// let a = matrix![
729 /// [1.0, 2.0],
730 /// [3.0, 4.0]
731 /// ];
732 /// let b = a.rdiv(10.5);
733 /// let elems = b.elems();
734 /// assert!((10.5 / 1.0 - elems[0]).abs() < 0.001);
735 /// assert!((10.5 / 2.0 - elems[1]).abs() < 0.001);
736 /// assert!((10.5 / 3.0 - elems[2]).abs() < 0.001);
737 /// assert!((10.5 / 4.0 - elems[3]).abs() < 0.001);
738 /// ```
739 pub fn rdiv(&self, b: f32) -> Self
740 {
741 let frontend = Frontend::new().unwrap();
742 let res = unsafe { frontend.create_matrix(self.row_count, self.col_count) }.unwrap();
743 frontend.rdiv_for_scalar(self, b, &res).unwrap();
744 res
745 }
746
747 /// Calculates sigmoid function for the matrix
748 /// (<math xmlns="http://www.w3.org/1998/Math/MathML"><mrow><mi>sigmoid</mi><mo fence="true">(</mo><mi mathvariant="bold">A</mi><mo fence="true">)</mo></mrow></math>).
749 ///
750 /// # Examples
751 ///
752 /// ```
753 /// # use unmtx_gpu::*;
754 /// let a = matrix![
755 /// [1.0, 2.0],
756 /// [3.0, 4.0]
757 /// ];
758 /// let b = a.sigmoid();
759 /// let elems = b.elems();
760 /// assert!((1.0 / (1.0 + (-1.0f32).exp()) - elems[0]).abs() < 0.001);
761 /// assert!((1.0 / (1.0 + (-2.0f32).exp()) - elems[1]).abs() < 0.001);
762 /// assert!((1.0 / (1.0 + (-3.0f32).exp()) - elems[2]).abs() < 0.001);
763 /// assert!((1.0 / (1.0 + (-4.0f32).exp()) - elems[3]).abs() < 0.001);
764 /// ```
765 pub fn sigmoid(&self) -> Self
766 {
767 let frontend = Frontend::new().unwrap();
768 let res = unsafe { frontend.create_matrix(self.row_count, self.col_count) }.unwrap();
769 frontend.sigmoid(self, &res).unwrap();
770 res
771 }
772
/// Calculates hyperbolic tangent function for the matrix
774 /// (<math xmlns="http://www.w3.org/1998/Math/MathML"><mrow><mi>tanh</mi><mo fence="true">(</mo><mi mathvariant="bold">A</mi><mo fence="true">)</mo></mrow></math>).
775 ///
776 /// # Examples
777 ///
778 /// ```
779 /// # use unmtx_gpu::*;
780 /// let a = matrix![
781 /// [1.0, 2.0],
782 /// [3.0, 4.0]
783 /// ];
784 /// let b = a.tanh();
785 /// let elems = b.elems();
786 /// assert!((1.0f32.tanh() - elems[0]).abs() < 0.001);
787 /// assert!((2.0f32.tanh() - elems[1]).abs() < 0.001);
788 /// assert!((3.0f32.tanh() - elems[2]).abs() < 0.001);
789 /// assert!((4.0f32.tanh() - elems[3]).abs() < 0.001);
790 /// ```
791 pub fn tanh(&self) -> Self
792 {
793 let frontend = Frontend::new().unwrap();
794 let res = unsafe { frontend.create_matrix(self.row_count, self.col_count) }.unwrap();
795 frontend.tanh(self, &res).unwrap();
796 res
797 }
798
799 /// Calculates swish function for the matrix
800 /// (<math xmlns="http://www.w3.org/1998/Math/MathML"><mrow><mi>swish</mi><mo fence="true">(</mo><mi mathvariant="bold">A</mi><mo fence="true">)</mo></mrow></math>).
801 ///
802 /// # Examples
803 ///
804 /// ```
805 /// # use unmtx_gpu::*;
806 /// let a = matrix![
807 /// [1.0, 2.0],
808 /// [3.0, 4.0]
809 /// ];
810 /// let b = a.swish();
811 /// let elems = b.elems();
812 /// assert!((1.0 / (1.0 + (-1.0f32).exp()) - elems[0]).abs() < 0.001);
813 /// assert!((2.0 / (1.0 + (-2.0f32).exp()) - elems[1]).abs() < 0.001);
814 /// assert!((3.0 / (1.0 + (-3.0f32).exp()) - elems[2]).abs() < 0.001);
815 /// assert!((4.0 / (1.0 + (-4.0f32).exp()) - elems[3]).abs() < 0.001);
816 /// ```
817 pub fn swish(&self) -> Self
818 {
819 let frontend = Frontend::new().unwrap();
820 let res = unsafe { frontend.create_matrix(self.row_count, self.col_count) }.unwrap();
821 frontend.swish(self, &res).unwrap();
822 res
823 }
824
825 /// Calculates softmax function for the matrix
826 /// (<math xmlns="http://www.w3.org/1998/Math/MathML"><mrow><mi>softmax</mi><mo fence="true">(</mo><mi mathvariant="bold">A</mi><mo fence="true">)</mo></mrow></math>).
827 ///
828 /// # Examples
829 ///
830 /// ```
831 /// # use unmtx_gpu::*;
832 /// let a = matrix![
833 /// [1.0, 2.0],
834 /// [3.0, 4.0]
835 /// ];
836 /// let b = a.softmax();
837 /// let elems = b.elems();
838 /// let sum1 = 1.0f32.exp() + 3.0f32.exp();
839 /// let sum2 = 2.0f32.exp() + 4.0f32.exp();
840 /// assert!((1.0f32.exp() / sum1 - elems[0]).abs() < 0.001);
841 /// assert!((2.0f32.exp() / sum2 - elems[1]).abs() < 0.001);
842 /// assert!((3.0f32.exp() / sum1 - elems[2]).abs() < 0.001);
843 /// assert!((4.0f32.exp() / sum2 - elems[3]).abs() < 0.001);
844 /// ```
845 pub fn softmax(&self) -> Self
846 {
847 let frontend = Frontend::new().unwrap();
848 let res = unsafe { frontend.create_matrix(self.row_count, self.col_count) }.unwrap();
849 frontend.softmax(self, &res).unwrap();
850 res
851 }
852
/// Repeats the vector as a column or as a row.
854 ///
855 /// # Examples
856 ///
857 /// ```
858 /// # use unmtx_gpu::*;
859 /// let a = matrix![
860 /// [1.0],
861 /// [2.0]
862 /// ];
863 /// let b = a.repeat(3);
864 /// assert_eq!(vec![1.0, 1.0, 1.0, 2.0, 2.0, 2.0], b.elems());
865 /// let c = matrix![[1.0, 2.0, 3.0]];
866 /// let d = c.repeat(2);
867 /// assert_eq!(vec![1.0, 2.0, 3.0, 1.0, 2.0, 3.0], d.elems());
868 /// ```
869 pub fn repeat(&self, n: usize) -> Self
870 {
871 assert!(self.col_count == 1 || self.row_count == 1);
872 let frontend = Frontend::new().unwrap();
873 let res = if self.col_count == 1 {
874 unsafe { frontend.create_matrix(self.row_count, n) }.unwrap()
875 } else {
876 unsafe { frontend.create_matrix(n, self.col_count) }.unwrap()
877 };
878 frontend.repeat(self, &res).unwrap();
879 res
880 }
881}
882
883impl Add for Matrix
884{
885 type Output = Self;
886
887 fn add(self, rhs: Self) -> Self::Output
888 {
889 let frontend = Frontend::new().unwrap();
890 let res = unsafe { frontend.create_matrix(self.row_count, self.col_count) }.unwrap();
891 frontend.add(&self, &rhs, &res).unwrap();
892 res
893 }
894}
895
896impl Add<&Matrix> for Matrix
897{
898 type Output = Self;
899
900 fn add(self, rhs: &Matrix) -> Self::Output
901 {
902 let frontend = Frontend::new().unwrap();
903 let res = unsafe { frontend.create_matrix(self.row_count, self.col_count) }.unwrap();
904 frontend.add(&self, rhs, &res).unwrap();
905 res
906 }
907}
908
909impl Add<f32> for Matrix
910{
911 type Output = Self;
912
913 fn add(self, rhs: f32) -> Self::Output
914 {
915 let frontend = Frontend::new().unwrap();
916 let res = unsafe { frontend.create_matrix(self.row_count, self.col_count) }.unwrap();
917 frontend.add_for_scalar(&self, rhs, &res).unwrap();
918 res
919 }
920}
921
922impl Add<&f32> for Matrix
923{
924 type Output = Self;
925
926 fn add(self, rhs: &f32) -> Self::Output
927 {
928 let frontend = Frontend::new().unwrap();
929 let res = unsafe { frontend.create_matrix(self.row_count, self.col_count) }.unwrap();
930 frontend.add_for_scalar(&self, *rhs, &res).unwrap();
931 res
932 }
933}
934
935impl Add<Matrix> for &Matrix
936{
937 type Output = Matrix;
938
939 fn add(self, rhs: Matrix) -> Self::Output
940 {
941 let frontend = Frontend::new().unwrap();
942 let res = unsafe { frontend.create_matrix(self.row_count, self.col_count) }.unwrap();
943 frontend.add(self, &rhs, &res).unwrap();
944 res
945 }
946}
947
948impl Add<&Matrix> for &Matrix
949{
950 type Output = Matrix;
951
952 fn add(self, rhs: &Matrix) -> Self::Output
953 {
954 let frontend = Frontend::new().unwrap();
955 let res = unsafe { frontend.create_matrix(self.row_count, self.col_count) }.unwrap();
956 frontend.add(self, rhs, &res).unwrap();
957 res
958 }
959}
960
961impl Add<f32> for &Matrix
962{
963 type Output = Matrix;
964
965 fn add(self, rhs: f32) -> Self::Output
966 {
967 let frontend = Frontend::new().unwrap();
968 let res = unsafe { frontend.create_matrix(self.row_count, self.col_count) }.unwrap();
969 frontend.add_for_scalar(self, rhs, &res).unwrap();
970 res
971 }
972}
973
974impl Add<&f32> for &Matrix
975{
976 type Output = Matrix;
977
978 fn add(self, rhs: &f32) -> Self::Output
979 {
980 let frontend = Frontend::new().unwrap();
981 let res = unsafe { frontend.create_matrix(self.row_count, self.col_count) }.unwrap();
982 frontend.add_for_scalar(self, *rhs, &res).unwrap();
983 res
984 }
985}
986
987impl AddAssign for Matrix
988{
989 fn add_assign(&mut self, rhs: Self)
990 {
991 let frontend = Frontend::new().unwrap();
992 frontend.add(self, &rhs, &self).unwrap();
993 }
994}
995
996impl AddAssign<&Matrix> for Matrix
997{
998 fn add_assign(&mut self, rhs: &Self)
999 {
1000 let frontend = Frontend::new().unwrap();
1001 frontend.add(&self, rhs, &self).unwrap();
1002 }
1003}
1004
1005impl AddAssign<f32> for Matrix
1006{
1007 fn add_assign(&mut self, rhs: f32)
1008 {
1009 let frontend = Frontend::new().unwrap();
1010 frontend.add_for_scalar(&self, rhs, &self).unwrap();
1011 }
1012}
1013
1014impl AddAssign<&f32> for Matrix
1015{
1016 fn add_assign(&mut self, rhs: &f32)
1017 {
1018 let frontend = Frontend::new().unwrap();
1019 frontend.add_for_scalar(&self, *rhs, &self).unwrap();
1020 }
1021}
1022
1023impl Sub for Matrix
1024{
1025 type Output = Self;
1026
1027 fn sub(self, rhs: Self) -> Self::Output
1028 {
1029 let frontend = Frontend::new().unwrap();
1030 let res = unsafe { frontend.create_matrix(self.row_count, self.col_count) }.unwrap();
1031 frontend.sub(&self, &rhs, &res).unwrap();
1032 res
1033 }
1034}
1035
1036impl Sub<&Matrix> for Matrix
1037{
1038 type Output = Self;
1039
1040 fn sub(self, rhs: &Matrix) -> Self::Output
1041 {
1042 let frontend = Frontend::new().unwrap();
1043 let res = unsafe { frontend.create_matrix(self.row_count, self.col_count) }.unwrap();
1044 frontend.sub(&self, rhs, &res).unwrap();
1045 res
1046 }
1047}
1048
1049impl Sub<f32> for Matrix
1050{
1051 type Output = Self;
1052
1053 fn sub(self, rhs: f32) -> Self::Output
1054 {
1055 let frontend = Frontend::new().unwrap();
1056 let res = unsafe { frontend.create_matrix(self.row_count, self.col_count) }.unwrap();
1057 frontend.sub_for_scalar(&self, rhs, &res).unwrap();
1058 res
1059 }
1060}
1061
1062impl Sub<&f32> for Matrix
1063{
1064 type Output = Self;
1065
1066 fn sub(self, rhs: &f32) -> Self::Output
1067 {
1068 let frontend = Frontend::new().unwrap();
1069 let res = unsafe { frontend.create_matrix(self.row_count, self.col_count) }.unwrap();
1070 frontend.sub_for_scalar(&self, *rhs, &res).unwrap();
1071 res
1072 }
1073}
1074
1075impl Sub<Matrix> for &Matrix
1076{
1077 type Output = Matrix;
1078
1079 fn sub(self, rhs: Matrix) -> Self::Output
1080 {
1081 let frontend = Frontend::new().unwrap();
1082 let res = unsafe { frontend.create_matrix(self.row_count, self.col_count) }.unwrap();
1083 frontend.sub(self, &rhs, &res).unwrap();
1084 res
1085 }
1086}
1087
1088impl Sub<&Matrix> for &Matrix
1089{
1090 type Output = Matrix;
1091
1092 fn sub(self, rhs: &Matrix) -> Self::Output
1093 {
1094 let frontend = Frontend::new().unwrap();
1095 let res = unsafe { frontend.create_matrix(self.row_count, self.col_count) }.unwrap();
1096 frontend.sub(self, rhs, &res).unwrap();
1097 res
1098 }
1099}
1100
1101impl Sub<f32> for &Matrix
1102{
1103 type Output = Matrix;
1104
1105 fn sub(self, rhs: f32) -> Self::Output
1106 {
1107 let frontend = Frontend::new().unwrap();
1108 let res = unsafe { frontend.create_matrix(self.row_count, self.col_count) }.unwrap();
1109 frontend.sub_for_scalar(self, rhs, &res).unwrap();
1110 res
1111 }
1112}
1113
1114impl Sub<&f32> for &Matrix
1115{
1116 type Output = Matrix;
1117
1118 fn sub(self, rhs: &f32) -> Self::Output
1119 {
1120 let frontend = Frontend::new().unwrap();
1121 let res = unsafe { frontend.create_matrix(self.row_count, self.col_count) }.unwrap();
1122 frontend.sub_for_scalar(self, *rhs, &res).unwrap();
1123 res
1124 }
1125}
1126
1127impl SubAssign for Matrix
1128{
1129 fn sub_assign(&mut self, rhs: Self)
1130 {
1131 let frontend = Frontend::new().unwrap();
1132 frontend.sub(&self, &rhs, &self).unwrap();
1133 }
1134}
1135
1136impl SubAssign<&Matrix> for Matrix
1137{
1138 fn sub_assign(&mut self, rhs: &Self)
1139 {
1140 let frontend = Frontend::new().unwrap();
1141 frontend.sub(&self, rhs, &self).unwrap();
1142 }
1143}
1144
1145impl SubAssign<f32> for Matrix
1146{
1147 fn sub_assign(&mut self, rhs: f32)
1148 {
1149 let frontend = Frontend::new().unwrap();
1150 frontend.sub_for_scalar(&self, rhs, &self).unwrap();
1151 }
1152}
1153
1154impl SubAssign<&f32> for Matrix
1155{
1156 fn sub_assign(&mut self, rhs: &f32)
1157 {
1158 let frontend = Frontend::new().unwrap();
1159 frontend.sub_for_scalar(&self, *rhs, &self).unwrap();
1160 }
1161}
1162
1163impl Mul for Matrix
1164{
1165 type Output = Self;
1166
1167 fn mul(self, rhs: Self) -> Self::Output
1168 {
1169 let frontend = Frontend::new().unwrap();
1170 let res = if frontend.backend().has_cublas() {
1171 frontend.create_matrix_and_set_zeros(self.row_count, rhs.col_count).unwrap()
1172 } else {
1173 unsafe { frontend.create_matrix(self.row_count, rhs.col_count) }.unwrap()
1174 };
1175 frontend.mul(&self, &rhs, &res).unwrap();
1176 res
1177 }
1178}
1179
1180impl Mul<&Matrix> for Matrix
1181{
1182 type Output = Self;
1183
1184 fn mul(self, rhs: &Matrix) -> Self::Output
1185 {
1186 let frontend = Frontend::new().unwrap();
1187 let res = if frontend.backend().has_cublas() {
1188 frontend.create_matrix_and_set_zeros(self.row_count, rhs.col_count).unwrap()
1189 } else {
1190 unsafe { frontend.create_matrix(self.row_count, rhs.col_count) }.unwrap()
1191 };
1192 frontend.mul(&self, rhs, &res).unwrap();
1193 res
1194 }
1195}
1196
1197impl Mul<f32> for Matrix
1198{
1199 type Output = Self;
1200
1201 fn mul(self, rhs: f32) -> Self::Output
1202 {
1203 let frontend = Frontend::new().unwrap();
1204 let res = unsafe { frontend.create_matrix(self.row_count, self.col_count) }.unwrap();
1205 frontend.mul_for_scalar(&self, rhs, &res).unwrap();
1206 res
1207 }
1208}
1209
1210impl Mul<&f32> for Matrix
1211{
1212 type Output = Self;
1213
1214 fn mul(self, rhs: &f32) -> Self::Output
1215 {
1216 let frontend = Frontend::new().unwrap();
1217 let res = unsafe { frontend.create_matrix(self.row_count, self.col_count) }.unwrap();
1218 frontend.mul_for_scalar(&self, *rhs, &res).unwrap();
1219 res
1220 }
1221}
1222
1223impl Mul<Matrix> for &Matrix
1224{
1225 type Output = Matrix;
1226
1227 fn mul(self, rhs: Matrix) -> Self::Output
1228 {
1229 let frontend = Frontend::new().unwrap();
1230 let res = if frontend.backend().has_cublas() {
1231 frontend.create_matrix_and_set_zeros(self.row_count, rhs.col_count).unwrap()
1232 } else {
1233 unsafe { frontend.create_matrix(self.row_count, rhs.col_count) }.unwrap()
1234 };
1235 frontend.mul(self, &rhs, &res).unwrap();
1236 res
1237 }
1238}
1239
1240impl Mul<&Matrix> for &Matrix
1241{
1242 type Output = Matrix;
1243
1244 fn mul(self, rhs: &Matrix) -> Self::Output
1245 {
1246 let frontend = Frontend::new().unwrap();
1247 let res = if frontend.backend().has_cublas() {
1248 frontend.create_matrix_and_set_zeros(self.row_count, rhs.col_count).unwrap()
1249 } else {
1250 unsafe { frontend.create_matrix(self.row_count, rhs.col_count) }.unwrap()
1251 };
1252 frontend.mul(self, rhs, &res).unwrap();
1253 res
1254 }
1255}
1256
1257impl Mul<f32> for &Matrix
1258{
1259 type Output = Matrix;
1260
1261 fn mul(self, rhs: f32) -> Self::Output
1262 {
1263 let frontend = Frontend::new().unwrap();
1264 let res = unsafe { frontend.create_matrix(self.row_count, self.col_count) }.unwrap();
1265 frontend.mul_for_scalar(self, rhs, &res).unwrap();
1266 res
1267 }
1268}
1269
1270impl Mul<&f32> for &Matrix
1271{
1272 type Output = Matrix;
1273
1274 fn mul(self, rhs: &f32) -> Self::Output
1275 {
1276 let frontend = Frontend::new().unwrap();
1277 let res = unsafe { frontend.create_matrix(self.row_count, self.col_count) }.unwrap();
1278 frontend.mul_for_scalar(self, *rhs, &res).unwrap();
1279 res
1280 }
1281}
1282
1283impl MulAssign for Matrix
1284{
1285 fn mul_assign(&mut self, rhs: Self)
1286 {
1287 let frontend = Frontend::new().unwrap();
1288 let res = if frontend.backend().has_cublas() {
1289 frontend.create_matrix_and_set_zeros(self.row_count, rhs.col_count).unwrap()
1290 } else {
1291 unsafe { frontend.create_matrix(self.row_count, rhs.col_count) }.unwrap()
1292 };
1293 frontend.mul(&self, &rhs, &res).unwrap();
1294 *self = res;
1295 }
1296}
1297
1298impl MulAssign<&Matrix> for Matrix
1299{
1300 fn mul_assign(&mut self, rhs: &Self)
1301 {
1302 let frontend = Frontend::new().unwrap();
1303 let res = if frontend.backend().has_cublas() {
1304 frontend.create_matrix_and_set_zeros(self.row_count, rhs.col_count).unwrap()
1305 } else {
1306 unsafe { frontend.create_matrix(self.row_count, rhs.col_count) }.unwrap()
1307 };
1308 frontend.mul(&self, rhs, &res).unwrap();
1309 *self = res;
1310 }
1311}
1312
1313impl MulAssign<f32> for Matrix
1314{
1315 fn mul_assign(&mut self, rhs: f32)
1316 {
1317 let frontend = Frontend::new().unwrap();
1318 frontend.mul_for_scalar(&self, rhs, &self).unwrap();
1319 }
1320}
1321
1322impl MulAssign<&f32> for Matrix
1323{
1324 fn mul_assign(&mut self, rhs: &f32)
1325 {
1326 let frontend = Frontend::new().unwrap();
1327 frontend.mul_for_scalar(&self, *rhs, &self).unwrap();
1328 }
1329}
1330
1331impl Div<f32> for Matrix
1332{
1333 type Output = Self;
1334
1335 fn div(self, rhs: f32) -> Self::Output
1336 {
1337 let frontend = Frontend::new().unwrap();
1338 let res = unsafe { frontend.create_matrix(self.row_count, self.col_count) }.unwrap();
1339 frontend.div_for_scalar(&self, rhs, &res).unwrap();
1340 res
1341 }
1342}
1343
1344impl Div<&f32> for Matrix
1345{
1346 type Output = Self;
1347
1348 fn div(self, rhs: &f32) -> Self::Output
1349 {
1350 let frontend = Frontend::new().unwrap();
1351 let res = unsafe { frontend.create_matrix(self.row_count, self.col_count) }.unwrap();
1352 frontend.div_for_scalar(&self, *rhs, &res).unwrap();
1353 res
1354 }
1355}
1356
1357impl Div<f32> for &Matrix
1358{
1359 type Output = Matrix;
1360
1361 fn div(self, rhs: f32) -> Self::Output
1362 {
1363 let frontend = Frontend::new().unwrap();
1364 let res = unsafe { frontend.create_matrix(self.row_count, self.col_count) }.unwrap();
1365 frontend.div_for_scalar(self, rhs, &res).unwrap();
1366 res
1367 }
1368}
1369
1370impl Div<&f32> for &Matrix
1371{
1372 type Output = Matrix;
1373
1374 fn div(self, rhs: &f32) -> Self::Output
1375 {
1376 let frontend = Frontend::new().unwrap();
1377 let res = unsafe { frontend.create_matrix(self.row_count, self.col_count) }.unwrap();
1378 frontend.div_for_scalar(self, *rhs, &res).unwrap();
1379 res
1380 }
1381}
1382
1383impl DivAssign<f32> for Matrix
1384{
1385 fn div_assign(&mut self, rhs: f32)
1386 {
1387 initialize_default_backend_for_uninitialized().unwrap();
1388 let frontend = Frontend::new().unwrap();
1389 frontend.div_for_scalar(&self, rhs, &self).unwrap();
1390 }
1391}
1392
1393impl DivAssign<&f32> for Matrix
1394{
1395 fn div_assign(&mut self, rhs: &f32)
1396 {
1397 initialize_default_backend_for_uninitialized().unwrap();
1398 let frontend = Frontend::new().unwrap();
1399 frontend.div_for_scalar(&self, *rhs, &self).unwrap();
1400 }
1401}
1402
/// A frontend structure.
///
/// The frontend contains methods which operate on matrices or calculate functions for the
/// matrices. Backend methods are called by the frontend to operate on the matrices. The frontend
/// is a high-level layer that can be used directly by a programmer or by the [`Matrix`] structure.
pub struct Frontend
{
    // The backend that is called by the frontend methods.
    backend: Arc<dyn Backend>,
}
1412
1413impl Frontend
1414{
1415 /// Creates a frontend with a default backend.
1416 ///
1417 /// This method also automatically initializes a default backend if the default backend is
1418 /// uninitialized.
1419 pub fn new() -> Result<Frontend>
1420 { Ok(Frontend { backend: initialize_default_backend_for_uninitialized()?, }) }
1421
1422 /// Creates a frotend with the backend.
1423 pub fn new_with_backend(backend: Arc<dyn Backend>) -> Frontend
1424 { Frontend { backend, } }
1425
1426 /// Returns the backend.
1427 pub fn backend(&self) -> Arc<dyn Backend>
1428 { self.backend.clone() }
1429
1430 /// Creates a matrix with unset elements.
1431 pub unsafe fn create_matrix(&self, row_count: usize, col_count: usize) -> Result<Matrix>
1432 {
1433 Ok(Matrix {
1434 row_count,
1435 col_count,
1436 is_transposed: false,
1437 array: Arc::new(self.backend.alloc(row_count * col_count)?),
1438 })
1439 }
1440
1441 /// Creates a matrix and sets the matrix elements on zeros.
1442 pub fn create_matrix_and_set_zeros(&self, row_count: usize, col_count: usize) -> Result<Matrix>
1443 {
1444 Ok(Matrix {
1445 row_count,
1446 col_count,
1447 is_transposed: false,
1448 array: Arc::new(self.backend.alloc_and_store_zeros(row_count * col_count)?),
1449 })
1450 }
1451
1452 /// Creates a matrix and sets the matrix elements.
1453 pub fn create_matrix_and_set_elems(&self, row_count: usize, col_count: usize, elems: &[f32]) -> Result<Matrix>
1454 {
1455 if row_count * col_count != elems.len() {
1456 return Err(Error::MatrixElemCount(row_count * col_count, elems.len()));
1457 }
1458 Ok(Matrix {
1459 row_count,
1460 col_count,
1461 is_transposed: false,
1462 array: Arc::new(self.backend.alloc_and_store(elems)?),
1463 })
1464 }
1465
1466 /// Sets the matrix elements.
1467 pub fn set_elems(&self, a: &Matrix, elems: &[f32]) -> Result<()>
1468 {
1469 if a.row_count() * a.col_count() != elems.len() {
1470 return Err(Error::MatrixElemCount(a.row_count() * a.col_count(), elems.len()));
1471 }
1472 self.backend.store(&*a.array, elems)
1473 }
1474
1475 /// Copies the `a` matrix to the `b` matrix.
1476 ///
1477 /// This method indeed copies the `a` matrix array to the `b` matrix array.
1478 pub fn copy(&self, a: &Matrix, b: &Matrix) -> Result<()>
1479 {
1480 if a.row_count != b.row_count || a.col_count != b.col_count {
1481 return Err(Error::OpSize(a.row_count, a.col_count, b.row_count, b.col_count));
1482 }
1483 self.backend.copy(&*a.array, &*b.array)
1484 }
1485
1486 /// Copies the matrix elements to the mutable slice and the transpose flag to the object that
1487 /// is referred by the reference.
1488 pub fn get_elems_and_transpose_flag(&self, a: &Matrix, elems: &mut [f32], is_transposed: &mut bool) -> Result<()>
1489 {
1490 if a.row_count * a.col_count != elems.len() {
1491 return Err(Error::MatrixElemCount(a.row_count * a.col_count, elems.len()));
1492 }
1493 if !a.is_transposed {
1494 self.backend.load(&*a.array, elems)?;
1495 } else {
1496 self.backend.load(&*a.array, elems)?;
1497 }
1498 *is_transposed = true;
1499 Ok(())
1500 }
1501
1502 /// Returns the elements and the transpose flag of matrix.
1503 pub fn elems_and_transpose_flag(&self, a: &Matrix) -> Result<(Vec<f32>, bool)>
1504 {
1505 let mut elems: Vec<f32> = vec![0.0; a.row_count * a.col_count];
1506 let mut is_transposed = false;
1507 self.get_elems_and_transpose_flag(a, elems.as_mut_slice(), &mut is_transposed)?;
1508 Ok((elems, is_transposed))
1509 }
1510
1511 /// Adds the `a` matrix onto the `b` matrix and then the result is in the `c` matrix
1512 /// (<math xmlns="http://www.w3.org/1998/Math/MathML"><mrow><mi mathvariant="bold">C</mi><mo>=</mo><mi mathvariant="bold">A</mi><mo>+</mo><mi mathvariant="bold">B</mi></mrow></math>).
1513 ///
1514 /// # Examples
1515 ///
1516 /// ```
1517 /// # use unmtx_gpu::*;
1518 /// let a = matrix![
1519 /// [1.0, 2.0],
1520 /// [3.0, 4.0]
1521 /// ];
1522 /// let b = matrix![
1523 /// [5.0, 6.0],
1524 /// [7.0, 8.0]
1525 /// ];
1526 /// let c = Matrix::new(2, 2);
1527 /// let frontend = Frontend::new().unwrap();
1528 /// frontend.add(&a, &b, &c).unwrap();
1529 /// assert_eq!(vec![1.0 + 5.0, 2.0 + 6.0, 3.0 + 7.0, 4.0 + 8.0], c.elems());
1530 /// ```
1531 pub fn add(&self, a: &Matrix, b: &Matrix, c: &Matrix) -> Result<()>
1532 {
1533 if a.row_count != b.row_count || a.col_count != b.col_count {
1534 return Err(Error::OpSize(a.row_count, a.col_count, b.row_count, b.col_count));
1535 }
1536 if a.row_count != c.row_count || a.col_count != c.col_count {
1537 return Err(Error::OpSize(a.row_count, a.col_count, c.row_count, c.col_count));
1538 }
1539 if c.is_transposed {
1540 return Err(Error::ResTransposition);
1541 }
1542 match (a.is_transposed, b.is_transposed) {
1543 (false, false) => self.backend.add_a_b(&*a.array, &*b.array, &*c.array, a.row_count, a.col_count),
1544 (true, false) => self.backend.add_at_b(&*a.array, &*b.array, &*c.array, a.row_count, a.col_count),
1545 (false, true) => self.backend.add_a_bt(&*a.array, &*b.array, &*c.array, a.row_count, a.col_count),
1546 (true, true) => self.backend.add_at_bt(&*a.array, &*b.array, &*c.array, a.row_count, a.col_count),
1547 }
1548 }
1549
1550 /// Subtracts the `b` matrix from the `a` matrix and then the result is in the `c` matrix
1551 /// (<math xmlns="http://www.w3.org/1998/Math/MathML"><mrow><mi mathvariant="bold">C</mi><mo>=</mo><mi mathvariant="bold">A</mi><mo>-</mo><mi mathvariant="bold">B</mi></mrow></math>).
1552 ///
1553 /// # Examples
1554 ///
1555 /// ```
1556 /// # use unmtx_gpu::*;
1557 /// let a = matrix![
1558 /// [1.0, 2.0],
1559 /// [3.0, 4.0]
1560 /// ];
1561 /// let b = matrix![
1562 /// [5.0, 6.0],
1563 /// [7.0, 8.0]
1564 /// ];
1565 /// let c = Matrix::new(2, 2);
1566 /// let frontend = Frontend::new().unwrap();
1567 /// frontend.sub(&a, &b, &c).unwrap();
1568 /// assert_eq!(vec![1.0 - 5.0, 2.0 - 6.0, 3.0 - 7.0, 4.0 - 8.0], c.elems());
1569 /// ```
1570 pub fn sub(&self, a: &Matrix, b: &Matrix, c: &Matrix) -> Result<()>
1571 {
1572 if a.row_count != b.row_count || a.col_count != b.col_count {
1573 return Err(Error::OpSize(a.row_count, a.col_count, b.row_count, b.col_count));
1574 }
1575 if a.row_count != c.row_count || a.col_count != c.col_count {
1576 return Err(Error::OpSize(a.row_count, a.col_count, c.row_count, c.col_count));
1577 }
1578 if c.is_transposed {
1579 return Err(Error::ResTransposition);
1580 }
1581 match (a.is_transposed, b.is_transposed) {
1582 (false, false) => self.backend.sub_a_b(&*a.array, &*b.array, &*c.array, a.row_count, a.col_count),
1583 (true, false) => self.backend.sub_at_b(&*a.array, &*b.array, &*c.array, a.row_count, a.col_count),
1584 (false, true) => self.backend.sub_a_bt(&*a.array, &*b.array, &*c.array, a.row_count, a.col_count),
1585 (true, true) => self.backend.sub_at_bt(&*a.array, &*b.array, &*c.array, a.row_count, a.col_count),
1586 }
1587 }
1588
1589 /// Multiplies the `a` matrix by the `b` matrix and then the result is in the `c` matrix
1590 /// (<math xmlns="http://www.w3.org/1998/Math/MathML"><mrow><mi mathvariant="bold">C</mi><mo>=</mo><mi mathvariant="bold">A</mi><mo>·</mo><mi mathvariant="bold">B</mi></mrow></math>).
1591 ///
1592 /// # Examples
1593 ///
1594 /// ```
1595 /// # use unmtx_gpu::*;
1596 /// let a = matrix![
1597 /// [1.0, 2.0, 3.0],
1598 /// [4.0, 5.0, 6.0]
1599 /// ];
1600 /// let b = matrix![
1601 /// [7.0, 8.0],
1602 /// [9.0, 10.0],
1603 /// [11.0, 12.0]
1604 /// ];
1605 /// let c = Matrix::new(2, 2);
1606 /// let frontend = Frontend::new().unwrap();
1607 /// frontend.mul(&a, &b, &c).unwrap();
1608 /// let c11: f32 = 1.0 * 7.0 + 2.0 * 9.0 + 3.0 * 11.0;
1609 /// let c12: f32 = 1.0 * 8.0 + 2.0 * 10.0 + 3.0 * 12.0;
1610 /// let c21: f32 = 4.0 * 7.0 + 5.0 * 9.0 + 6.0 * 11.0;
1611 /// let c22: f32 = 4.0 * 8.0 + 5.0 * 10.0 + 6.0 * 12.0;
1612 /// assert_eq!(vec![c11, c12, c21, c22], c.elems());
1613 /// ```
1614 pub fn mul(&self, a: &Matrix, b: &Matrix, c: &Matrix) -> Result<()>
1615 {
1616 if a.row_count != c.row_count {
1617 return Err(Error::MulSize(a.row_count, a.col_count, b.row_count, b.col_count, c.row_count, c.col_count));
1618 }
1619 if b.col_count != c.col_count {
1620 return Err(Error::MulSize(a.row_count, a.col_count, b.row_count, b.col_count, c.row_count, c.col_count));
1621 }
1622 if a.col_count != b.row_count {
1623 return Err(Error::MulSize(a.row_count, a.col_count, b.row_count, b.col_count, c.row_count, c.col_count));
1624 }
1625 if c.is_transposed {
1626 return Err(Error::ResTransposition);
1627 }
1628 match (a.is_transposed, b.is_transposed) {
1629 (false, false) => self.backend.mul_a_b(&*a.array, &*b.array, &*c.array, a.row_count, b.col_count, a.col_count),
1630 (true, false) => self.backend.mul_at_b(&*a.array, &*b.array, &*c.array, a.row_count, b.col_count, a.col_count),
1631 (false, true) => self.backend.mul_a_bt(&*a.array, &*b.array, &*c.array, a.row_count, b.col_count, a.col_count),
1632 (true, true) => self.backend.mul_at_bt(&*a.array, &*b.array, &*c.array, a.row_count, b.col_count, a.col_count),
1633 }
1634 }
1635
1636 /// Multiplies the `a` matrix elements by the `b` matrix elements and then the result is in the
1637 /// `c` matrix
1638 /// (<math xmlns="http://www.w3.org/1998/Math/MathML"><mrow><msub><mi>c</mi><mrow><mi>i</mi><mi>j</mi></mrow></msub><mo>=</mo><msub><mi>a</mi><mrow><mi>i</mi><mi>j</mi></mrow></msub><mo>·</mo><msub><mi>b</mi><mrow><mi>i</mi><mi>j</mi></mrow></msub></mrow></math>).
1639 ///
1640 /// # Examples
1641 ///
1642 /// ```
1643 /// # use unmtx_gpu::*;
1644 /// let a = matrix![
1645 /// [1.0, 2.0],
1646 /// [3.0, 4.0]
1647 /// ];
1648 /// let b = matrix![
1649 /// [5.0, 6.0],
1650 /// [7.0, 8.0]
1651 /// ];
1652 /// let c = Matrix::new(2, 2);
1653 /// let frontend = Frontend::new().unwrap();
1654 /// frontend.mul_elems(&a, &b, &c).unwrap();
1655 /// assert_eq!(vec![1.0 * 5.0, 2.0 * 6.0, 3.0 * 7.0, 4.0 * 8.0], c.elems());
1656 /// ```
1657 pub fn mul_elems(&self, a: &Matrix, b: &Matrix, c: &Matrix) -> Result<()>
1658 {
1659 if a.row_count != b.row_count || a.col_count != b.col_count {
1660 return Err(Error::OpSize(a.row_count, a.col_count, b.row_count, b.col_count));
1661 }
1662 if a.row_count != c.row_count || a.col_count != c.col_count {
1663 return Err(Error::OpSize(a.row_count, a.col_count, c.row_count, c.col_count));
1664 }
1665 if c.is_transposed {
1666 return Err(Error::ResTransposition);
1667 }
1668 match (a.is_transposed, b.is_transposed) {
1669 (false, false) => self.backend.mul_a_b_for_elems(&*a.array, &*b.array, &*c.array, a.row_count, a.col_count),
1670 (true, false) => self.backend.mul_at_b_for_elems(&*a.array, &*b.array, &*c.array, a.row_count, a.col_count),
1671 (false, true) => self.backend.mul_a_bt_for_elems(&*a.array, &*b.array, &*c.array, a.row_count, a.col_count),
1672 (true, true) => self.backend.mul_at_bt_for_elems(&*a.array, &*b.array, &*c.array, a.row_count, a.col_count),
1673 }
1674 }
1675
1676 /// Divides the `a` matrix elements by the `b` matrix elements and then the result is in the `c`
1677 /// matrix
1678 /// (<math xmlns="http://www.w3.org/1998/Math/MathML"><mrow><msub><mi>c</mi><mrow><mi>i</mi><mi>j</mi></mrow></msub><mo>=</mo><mfrac><msub><mi>a</mi><mrow><mi>i</mi><mi>j</mi></mrow></msub><msub><mi>b</mi><mrow><mi>i</mi><mi>j</mi></mrow></msub></mfrac></mrow></math>).
1679 ///
1680 /// # Examples
1681 ///
1682 /// ```
1683 /// # use unmtx_gpu::*;
1684 /// let a = matrix![
1685 /// [1.0, 2.0],
1686 /// [3.0, 4.0]
1687 /// ];
1688 /// let b = matrix![
1689 /// [5.0, 6.0],
1690 /// [7.0, 8.0]
1691 /// ];
1692 /// let c = Matrix::new(2, 2);
1693 /// let frontend = Frontend::new().unwrap();
1694 /// frontend.div_elems(&a, &b, &c).unwrap();
1695 /// let elems = c.elems();
1696 /// assert!((1.0 / 5.0 - elems[0]).abs() < 0.001);
1697 /// assert!((2.0 / 6.0 - elems[1]).abs() < 0.001);
1698 /// assert!((3.0 / 7.0 - elems[2]).abs() < 0.001);
1699 /// assert!((4.0 / 8.0 - elems[3]).abs() < 0.001);
1700 /// ```
1701 pub fn div_elems(&self, a: &Matrix, b: &Matrix, c: &Matrix) -> Result<()>
1702 {
1703 if a.row_count != b.row_count || a.col_count != b.col_count {
1704 return Err(Error::OpSize(a.row_count, a.col_count, b.row_count, b.col_count));
1705 }
1706 if a.row_count != c.row_count || a.col_count != c.col_count {
1707 return Err(Error::OpSize(a.row_count, a.col_count, c.row_count, c.col_count));
1708 }
1709 if c.is_transposed {
1710 return Err(Error::ResTransposition);
1711 }
1712 match (a.is_transposed, b.is_transposed) {
1713 (false, false) => self.backend.div_a_b_for_elems(&*a.array, &*b.array, &*c.array, a.row_count, a.col_count),
1714 (true, false) => self.backend.div_at_b_for_elems(&*a.array, &*b.array, &*c.array, a.row_count, a.col_count),
1715 (false, true) => self.backend.div_a_bt_for_elems(&*a.array, &*b.array, &*c.array, a.row_count, a.col_count),
1716 (true, true) => self.backend.div_at_bt_for_elems(&*a.array, &*b.array, &*c.array, a.row_count, a.col_count),
1717 }
1718 }
1719
1720 /// Adds the `a` matrix onto the `b` scalar and then the result is in the `c` matrix
1721 /// (<math xmlns="http://www.w3.org/1998/Math/MathML"><mrow><mi mathvariant="bold">C</mi><mo>=</mo><mi mathvariant="bold">A</mi><mo>+</mo><mi>b</mi></mrow></math>).
1722 ///
1723 /// # Examples
1724 ///
1725 /// ```
1726 /// # use unmtx_gpu::*;
1727 /// let a = matrix![
1728 /// [1.0, 2.0],
1729 /// [3.0, 4.0]
1730 /// ];
1731 /// let c = Matrix::new(2, 2);
1732 /// let frontend = Frontend::new().unwrap();
1733 /// frontend.add_for_scalar(&a, 10.5, &c).unwrap();
1734 /// assert_eq!(vec![1.0 + 10.5, 2.0 + 10.5, 3.0 + 10.5, 4.0 + 10.5], c.elems());
1735 /// ```
1736 pub fn add_for_scalar(&self, a: &Matrix, b: f32, c: &Matrix) -> Result<()>
1737 {
1738 if a.row_count != c.row_count || a.col_count != c.col_count {
1739 return Err(Error::OpSize(a.row_count, a.col_count, c.row_count, c.col_count));
1740 }
1741 if c.is_transposed {
1742 return Err(Error::ResTransposition);
1743 }
1744 if !a.is_transposed {
1745 self.backend.add_a_b_for_scalar(&*a.array, b, &*c.array, a.row_count, a.col_count)
1746 } else {
1747 self.backend.add_at_b_for_scalar(&*a.array, b, &*c.array, a.row_count, a.col_count)
1748 }
1749 }
1750
1751 /// Subtracts the `b` scalar from the `a` matrix and then the result is in the `c` matrix
1752 /// (<math xmlns="http://www.w3.org/1998/Math/MathML"><mrow><mi mathvariant="bold">C</mi><mo>=</mo><mi mathvariant="bold">A</mi><mo>-</mo><mi>b</mi></mrow></math>).
1753 ///
1754 /// # Examples
1755 ///
1756 /// ```
1757 /// # use unmtx_gpu::*;
1758 /// let a = matrix![
1759 /// [1.0, 2.0],
1760 /// [3.0, 4.0]
1761 /// ];
1762 /// let c = Matrix::new(2, 2);
1763 /// let frontend = Frontend::new().unwrap();
1764 /// frontend.sub_for_scalar(&a, 10.5, &c).unwrap();
1765 /// assert_eq!(vec![1.0 - 10.5, 2.0 - 10.5, 3.0 - 10.5, 4.0 - 10.5], c.elems());
1766 /// ```
1767 pub fn sub_for_scalar(&self, a: &Matrix, b: f32, c: &Matrix) -> Result<()>
1768 {
1769 if a.row_count != c.row_count || a.col_count != c.col_count {
1770 return Err(Error::OpSize(a.row_count, a.col_count, c.row_count, c.col_count));
1771 }
1772 if c.is_transposed {
1773 return Err(Error::ResTransposition);
1774 }
1775 if !a.is_transposed {
1776 self.backend.sub_a_b_for_scalar(&*a.array, b, &*c.array, a.row_count, a.col_count)
1777 } else {
1778 self.backend.sub_at_b_for_scalar(&*a.array, b, &*c.array, a.row_count, a.col_count)
1779 }
1780 }
1781
1782 /// Subtracts the `a` matrix from the `b` scalar and then the result is in the `c` matrix
1783 /// (<math xmlns="http://www.w3.org/1998/Math/MathML"><mrow><mi mathvariant="bold">C</mi><mo>=</mo><mi>b</mi><mo>-</mo><mi mathvariant="bold">A</mi></mrow></math>).
1784 ///
1785 /// # Examples
1786 ///
1787 /// ```
1788 /// # use unmtx_gpu::*;
1789 /// let a = matrix![
1790 /// [1.0, 2.0],
1791 /// [3.0, 4.0]
1792 /// ];
1793 /// let c = Matrix::new(2, 2);
1794 /// let frontend = Frontend::new().unwrap();
1795 /// frontend.rsub_for_scalar(&a, 10.5, &c).unwrap();
1796 /// assert_eq!(vec![10.5 - 1.0, 10.5 - 2.0, 10.5 - 3.0, 10.5 - 4.0], c.elems());
1797 /// ```
1798 pub fn rsub_for_scalar(&self, a: &Matrix, b: f32, c: &Matrix) -> Result<()>
1799 {
1800 if a.row_count != c.row_count || a.col_count != c.col_count {
1801 return Err(Error::OpSize(a.row_count, a.col_count, c.row_count, c.col_count));
1802 }
1803 if c.is_transposed {
1804 return Err(Error::ResTransposition);
1805 }
1806 if !a.is_transposed {
1807 self.backend.rsub_a_b_for_scalar(&*a.array, b, &*c.array, a.row_count, a.col_count)
1808 } else {
1809 self.backend.rsub_at_b_for_scalar(&*a.array, b, &*c.array, a.row_count, a.col_count)
1810 }
1811 }
1812
1813 /// Multiplies the `a` matrix by the `b` scalar and then the result is in the `c` matrix
1814 /// (<math xmlns="http://www.w3.org/1998/Math/MathML"><mrow><mi mathvariant="bold">C</mi><mo>=</mo><mi mathvariant="bold">A</mi><mo>·</mo><mi>b</mi></mrow></math>).
1815 ///
1816 /// # Examples
1817 ///
1818 /// ```
1819 /// # use unmtx_gpu::*;
1820 /// let a = matrix![
1821 /// [1.0, 2.0],
1822 /// [3.0, 4.0]
1823 /// ];
1824 /// let c = Matrix::new(2, 2);
1825 /// let frontend = Frontend::new().unwrap();
1826 /// frontend.mul_for_scalar(&a, 10.5, &c).unwrap();
1827 /// assert_eq!(vec![1.0 * 10.5, 2.0 * 10.5, 3.0 * 10.5, 4.0 * 10.5], c.elems());
1828 /// ```
1829 pub fn mul_for_scalar(&self, a: &Matrix, b: f32, c: &Matrix) -> Result<()>
1830 {
1831 if a.row_count != c.row_count || a.col_count != c.col_count {
1832 return Err(Error::OpSize(a.row_count, a.col_count, c.row_count, c.col_count));
1833 }
1834 if c.is_transposed {
1835 return Err(Error::ResTransposition);
1836 }
1837 if !a.is_transposed {
1838 self.backend.mul_a_b_for_scalar(&*a.array, b, &*c.array, a.row_count, a.col_count)
1839 } else {
1840 self.backend.mul_at_b_for_scalar(&*a.array, b, &*c.array, a.row_count, a.col_count)
1841 }
1842 }
1843
1844 /// Divides the `a` matrix by the `b` scalar and then the result is in the `c` matrix
1845 /// (<math xmlns="http://www.w3.org/1998/Math/MathML"><mrow><mi mathvariant="bold">C</mi><mo>=</mo><mfrac><mi mathvariant="bold">A</mi><mi>b</mi></mfrac></mrow></math>).
1846 ///
1847 /// # Examples
1848 ///
1849 /// ```
1850 /// # use unmtx_gpu::*;
1851 /// let a = matrix![
1852 /// [1.0, 2.0],
1853 /// [3.0, 4.0]
1854 /// ];
1855 /// let c = Matrix::new(2, 2);
1856 /// let frontend = Frontend::new().unwrap();
1857 /// frontend.div_for_scalar(&a, 10.5, &c).unwrap();
1858 /// let elems = c.elems();
1859 /// assert!((1.0 / 10.5 - elems[0]).abs() < 0.001);
1860 /// assert!((2.0 / 10.5 - elems[1]).abs() < 0.001);
1861 /// assert!((3.0 / 10.5 - elems[2]).abs() < 0.001);
1862 /// assert!((4.0 / 10.5 - elems[3]).abs() < 0.001);
1863 /// ```
1864 pub fn div_for_scalar(&self, a: &Matrix, b: f32, c: &Matrix) -> Result<()>
1865 {
1866 if a.row_count != c.row_count || a.col_count != c.col_count {
1867 return Err(Error::OpSize(a.row_count, a.col_count, c.row_count, c.col_count));
1868 }
1869 if c.is_transposed {
1870 return Err(Error::ResTransposition);
1871 }
1872 if !a.is_transposed {
1873 self.backend.div_a_b_for_scalar(&*a.array, b, &*c.array, a.row_count, a.col_count)
1874 } else {
1875 self.backend.div_at_b_for_scalar(&*a.array, b, &*c.array, a.row_count, a.col_count)
1876 }
1877 }
1878
1879 /// Divides the `b` scalar by the `a` matrix elements and then the result is in the `c` matrix
1880 /// (<math xmlns="http://www.w3.org/1998/Math/MathML"><mrow><msub><mi>c</mi><mrow><mi>i</mi><mi>j</mi></mrow></msub><mo>=</mo><mfrac><mi>b</mi><msub><mi>a</mi><mrow><mi>i</mi><mi>j</mi></mrow></msub></mfrac></mrow></math>).
1881 ///
1882 /// # Examples
1883 ///
1884 /// ```
1885 /// # use unmtx_gpu::*;
1886 /// let a = matrix![
1887 /// [1.0, 2.0],
1888 /// [3.0, 4.0]
1889 /// ];
1890 /// let c = Matrix::new(2, 2);
1891 /// let frontend = Frontend::new().unwrap();
1892 /// frontend.rdiv_for_scalar(&a, 10.5, &c).unwrap();
1893 /// let elems = c.elems();
1894 /// assert!((10.5 / 1.0- elems[0]).abs() < 0.001);
1895 /// assert!((10.5 / 2.0 - elems[1]).abs() < 0.001);
1896 /// assert!((10.5 / 3.0 - elems[2]).abs() < 0.001);
1897 /// assert!((10.5 / 4.0 - elems[3]).abs() < 0.001);
1898 /// ```
1899 pub fn rdiv_for_scalar(&self, a: &Matrix, b: f32, c: &Matrix) -> Result<()>
1900 {
1901 if a.row_count != c.row_count || a.col_count != c.col_count {
1902 return Err(Error::OpSize(a.row_count, a.col_count, c.row_count, c.col_count));
1903 }
1904 if c.is_transposed {
1905 return Err(Error::ResTransposition);
1906 }
1907 if !a.is_transposed {
1908 self.backend.rdiv_a_b_for_scalar(&*a.array, b, &*c.array, a.row_count, a.col_count)
1909 } else {
1910 self.backend.rdiv_at_b_for_scalar(&*a.array, b, &*c.array, a.row_count, a.col_count)
1911 }
1912 }
1913
1914 /// Calculates sigmoid function for the `a` matrix and then the result is in the `b` matrix
1915 /// (<math xmlns="http://www.w3.org/1998/Math/MathML"><mrow><mi mathvariant="bold">B</mi><mo>=</mo><mi>sigmoid</mi><mo fence="true">(</mo><mi mathvariant="bold">A</mi><mo fence="true">)</mo></mrow></math>).
1916 ///
1917 /// # Examples
1918 ///
1919 /// ```
1920 /// # use unmtx_gpu::*;
1921 /// let a = matrix![
1922 /// [1.0, 2.0],
1923 /// [3.0, 4.0]
1924 /// ];
1925 /// let b = Matrix::new(2, 2);
1926 /// let frontend = Frontend::new().unwrap();
1927 /// frontend.sigmoid(&a, &b).unwrap();
1928 /// let elems = b.elems();
1929 /// assert!((1.0 / (1.0 + (-1.0f32).exp()) - elems[0]).abs() < 0.001);
1930 /// assert!((1.0 / (1.0 + (-2.0f32).exp()) - elems[1]).abs() < 0.001);
1931 /// assert!((1.0 / (1.0 + (-3.0f32).exp()) - elems[2]).abs() < 0.001);
1932 /// assert!((1.0 / (1.0 + (-4.0f32).exp()) - elems[3]).abs() < 0.001);
1933 /// ```
1934 pub fn sigmoid(&self, a: &Matrix, b: &Matrix) -> Result<()>
1935 {
1936 if a.row_count != b.row_count || a.col_count != b.col_count {
1937 return Err(Error::OpSize(a.row_count, a.col_count, b.row_count, b.col_count));
1938 }
1939 if b.is_transposed {
1940 return Err(Error::ResTransposition);
1941 }
1942 if !a.is_transposed {
1943 self.backend.sigmoid_a(&*a.array, &*b.array, a.row_count, a.col_count)
1944 } else {
1945 self.backend.sigmoid_at(&*a.array, &*b.array, a.row_count, a.col_count)
1946 }
1947 }
1948
1949 /// Calculates hyperbolic tangent function for the `a` matrix and then the result is in the `b`
1950 /// (<math xmlns="http://www.w3.org/1998/Math/MathML"><mrow><mi mathvariant="bold">B</mi><mo>=</mo><mi>tanh</mi><mo fence="true">(</mo><mi mathvariant="bold">A</mi><mo fence="true">)</mo></mrow></math>).
1951 ///
1952 /// # Examples
1953 ///
1954 /// ```
1955 /// # use unmtx_gpu::*;
1956 /// let a = matrix![
1957 /// [1.0, 2.0],
1958 /// [3.0, 4.0]
1959 /// ];
1960 /// let b = Matrix::new(2, 2);
1961 /// let frontend = Frontend::new().unwrap();
1962 /// frontend.tanh(&a, &b).unwrap();
1963 /// let elems = b.elems();
1964 /// assert!((1.0f32.tanh() - elems[0]).abs() < 0.001);
1965 /// assert!((2.0f32.tanh() - elems[1]).abs() < 0.001);
1966 /// assert!((3.0f32.tanh() - elems[2]).abs() < 0.001);
1967 /// assert!((4.0f32.tanh() - elems[3]).abs() < 0.001);
1968 /// ```
1969 pub fn tanh(&self, a: &Matrix, b: &Matrix) -> Result<()>
1970 {
1971 if a.row_count != b.row_count || a.col_count != b.col_count {
1972 return Err(Error::OpSize(a.row_count, a.col_count, b.row_count, b.col_count));
1973 }
1974 if b.is_transposed {
1975 return Err(Error::ResTransposition);
1976 }
1977 if !a.is_transposed {
1978 self.backend.tanh_a(&*a.array, &*b.array, a.row_count, a.col_count)
1979 } else {
1980 self.backend.tanh_at(&*a.array, &*b.array, a.row_count, a.col_count)
1981 }
1982 }
1983
1984 /// Calculates swish function for the `a` matrix and then the result is in the `b` matrix
1985 /// (<math xmlns="http://www.w3.org/1998/Math/MathML"><mrow><mi mathvariant="bold">B</mi><mo>=</mo><mi>swish</mi><mo fence="true">(</mo><mi mathvariant="bold">A</mi><mo fence="true">)</mo></mrow></math>).
1986 ///
1987 /// # Examples
1988 ///
1989 /// ```
1990 /// # use unmtx_gpu::*;
1991 /// let a = matrix![
1992 /// [1.0, 2.0],
1993 /// [3.0, 4.0]
1994 /// ];
1995 /// let b = Matrix::new(2, 2);
1996 /// let frontend = Frontend::new().unwrap();
1997 /// frontend.swish(&a, &b).unwrap();
1998 /// let elems = b.elems();
1999 /// assert!((1.0 / (1.0 + (-1.0f32).exp()) - elems[0]).abs() < 0.001);
2000 /// assert!((2.0 / (1.0 + (-2.0f32).exp()) - elems[1]).abs() < 0.001);
2001 /// assert!((3.0 / (1.0 + (-3.0f32).exp()) - elems[2]).abs() < 0.001);
2002 /// assert!((4.0 / (1.0 + (-4.0f32).exp()) - elems[3]).abs() < 0.001);
2003 /// ```
2004 pub fn swish(&self, a: &Matrix, b: &Matrix) -> Result<()>
2005 {
2006 if a.row_count != b.row_count || a.col_count != b.col_count {
2007 return Err(Error::OpSize(a.row_count, a.col_count, b.row_count, b.col_count));
2008 }
2009 if b.is_transposed {
2010 return Err(Error::ResTransposition);
2011 }
2012 if !a.is_transposed {
2013 self.backend.swish_a(&*a.array, &*b.array, a.row_count, a.col_count)
2014 } else {
2015 self.backend.swish_at(&*a.array, &*b.array, a.row_count, a.col_count)
2016 }
2017 }
2018
2019 /// Calculates softmax function for the `a` matrix and then the result is in the `b` matrix
2020 /// (<math xmlns="http://www.w3.org/1998/Math/MathML"><mrow><mi mathvariant="bold">B</mi><mo>=</mo><mi>softmax</mi><mo fence="true">(</mo><mi mathvariant="bold">A</mi><mo fence="true">)</mo></mrow></math>).
2021 ///
2022 /// # Examples
2023 ///
2024 /// ```
2025 /// # use unmtx_gpu::*;
2026 /// let a = matrix![
2027 /// [1.0, 2.0],
2028 /// [3.0, 4.0]
2029 /// ];
2030 /// let b = Matrix::new(2, 2);
2031 /// let frontend = Frontend::new().unwrap();
2032 /// frontend.softmax(&a, &b).unwrap();
2033 /// let elems = b.elems();
2034 /// let sum1 = 1.0f32.exp() + 3.0f32.exp();
2035 /// let sum2 = 2.0f32.exp() + 4.0f32.exp();
2036 /// assert!((1.0f32.exp() / sum1 - elems[0]).abs() < 0.001);
2037 /// assert!((2.0f32.exp() / sum2 - elems[1]).abs() < 0.001);
2038 /// assert!((3.0f32.exp() / sum1 - elems[2]).abs() < 0.001);
2039 /// assert!((4.0f32.exp() / sum2 - elems[3]).abs() < 0.001);
2040 /// ```
2041 pub fn softmax(&self, a: &Matrix, b: &Matrix) -> Result<()>
2042 {
2043 if a.row_count != b.row_count || a.col_count != b.col_count {
2044 return Err(Error::OpSize(a.row_count, a.col_count, b.row_count, b.col_count));
2045 }
2046 if b.is_transposed {
2047 return Err(Error::ResTransposition);
2048 }
2049 if !a.is_transposed {
2050 self.backend.softmax_a(&*a.array, &*b.array, a.row_count, a.col_count)
2051 } else {
2052 self.backend.softmax_at(&*a.array, &*b.array, a.row_count, a.col_count)
2053 }
2054 }
2055
2056 /// Indeed transposes the `a` matrix and then the result is in the `b` matrix
2057 /// (<math xmlns="http://www.w3.org/1998/Math/MathML"><mrow><mi mathvariant="bold">B</mi><mo>=</mo><msup><mi mathvariant="bold">A</mi><mi mathvariant="normal">T</mi></msup></mrow></math>).
2058 ///
2059 /// This method indeed transposes the `a` matrix without changing the transpose flag.
2060 ///
2061 /// # Examples
2062 ///
2063 /// ```
2064 /// # use unmtx_gpu::*;
2065 /// let a = matrix![
2066 /// [1.0, 2.0, 3.0],
2067 /// [4.0, 5.0, 6.0]
2068 /// ];
2069 /// let b = Matrix::new(3, 2);
2070 /// let frontend = Frontend::new().unwrap();
2071 /// frontend.really_transpose(&a, &b).unwrap();
2072 /// assert_eq!(vec![1.0, 4.0, 2.0, 5.0, 3.0, 6.0], b.elems());
2073 /// ```
2074 pub fn really_transpose(&self, a: &Matrix, b: &Matrix) -> Result<()>
2075 {
2076 if a.row_count != b.col_count || a.col_count != b.row_count {
2077 return Err(Error::TransposeSize(a.row_count, a.col_count, b.row_count, b.col_count));
2078 }
2079 if a.is_transposed {
2080 return Err(Error::ArgTransposition);
2081 }
2082 if b.is_transposed {
2083 return Err(Error::ResTransposition);
2084 }
2085 self.backend.transpose_a(&*a.array, &*b.array, a.col_count, a.row_count)
2086 }
2087
2088 /// Repeats the `a` vector as column or a row
2089 /// (<math xmlns="http://www.w3.org/1998/Math/MathML"><mrow><msub><mi>b</mi><mrow><mi>i</mi><mi>j</mi></mrow></msub><mo>=</mo><msub><mi>a</mi><mi>i</mi></msub></mrow></math> or
2090 /// <math xmlns="http://www.w3.org/1998/Math/MathML"><mrow><msub><mi>b</mi><mrow><mi>i</mi><mi>j</mi></mrow></msub><mo>=</mo><msub><mi>a</mi><mi>j</mi></msub></mrow></math>).
2091 ///
2092 /// # Examples
2093 ///
2094 /// ```
2095 /// # use unmtx_gpu::*;
2096 /// let a = matrix![
2097 /// [1.0],
2098 /// [2.0]
2099 /// ];
2100 /// let b = Matrix::new(2, 3);
2101 /// let frontend = Frontend::new().unwrap();
2102 /// frontend.repeat(&a, &b).unwrap();
2103 /// assert_eq!(vec![1.0, 1.0, 1.0, 2.0, 2.0, 2.0], b.elems());
2104 /// let c = matrix![[1.0, 2.0, 3.0]];
2105 /// let d = Matrix::new(2, 3);
2106 /// let frontend = Frontend::new().unwrap();
2107 /// frontend.repeat(&c, &d).unwrap();
2108 /// assert_eq!(vec![1.0, 2.0, 3.0, 1.0, 2.0, 3.0], d.elems());
2109 /// ```
2110 pub fn repeat(&self, a: &Matrix, b: &Matrix) -> Result<()>
2111 {
2112 if b.is_transposed {
2113 return Err(Error::ResTransposition);
2114 }
2115 if a.col_count == 1 {
2116 if a.row_count != b.row_count {
2117 return Err(Error::OpSize(a.row_count, a.col_count, b.row_count, b.col_count));
2118 }
2119 self.backend.repeat_col_a(&*a.array, &*b.array, a.row_count, b.col_count)
2120 } else if a.row_count == 1 {
2121 if a.col_count != b.col_count {
2122 return Err(Error::OpSize(a.row_count, a.col_count, b.row_count, b.col_count));
2123 }
2124 self.backend.repeat_row_a(&*a.array, &*b.array, b.row_count, a.col_count)
2125 } else {
2126 Err(Error::IsNotVector)
2127 }
2128 }
2129}
2130
// Compiled only for test builds; presumably holds helpers shared by the tests.
#[cfg(test)]
mod test_helpers;
// Compiled only for test builds in which the `test_only_backend` feature is disabled.
#[cfg(all(test, not(feature = "test_only_backend")))]
mod tests;