oxiblas_blas/
lib.rs

1//! `OxiBLAS` BLAS - Pure Rust BLAS implementation.
2//!
3//! This crate provides BLAS (Basic Linear Algebra Subprograms) operations
4//! implemented in pure Rust with SIMD optimization.
5//!
6//! # BLAS Levels
7//!
8//! - **Level 1**: Vector-vector operations (dot, axpy, nrm2, etc.)
9//! - **Level 2**: Matrix-vector operations (gemv, trmv, etc.)
10//! - **Level 3**: Matrix-matrix operations (gemm, trmm, etc.)
11//!
12//! # Example
13//!
14//! ```
15//! use oxiblas_blas::level3::gemm;
16//! use oxiblas_matrix::Mat;
17//!
18//! // Create matrices
19//! let a: Mat<f64> = Mat::filled(100, 50, 1.0);
20//! let b: Mat<f64> = Mat::filled(50, 80, 2.0);
21//! let mut c: Mat<f64> = Mat::zeros(100, 80);
22//!
23//! // GEMM: C = A * B
24//! gemm(1.0, a.as_ref(), b.as_ref(), 0.0, c.as_mut());
25//! ```
26
27#![warn(missing_docs)]
28#![warn(clippy::all)]
29// Stylistic choices for BLAS library
30#![allow(clippy::module_name_repetitions)]
31#![allow(clippy::similar_names)]
32#![allow(clippy::too_many_lines)]
33#![allow(clippy::too_many_arguments)]
34#![allow(clippy::many_single_char_names)]
35#![allow(clippy::cast_possible_truncation)]
36#![allow(clippy::cast_sign_loss)]
37// BLAS uses Self vs typename interchangeably
38#![allow(clippy::use_self)]
39// Constants defined close to usage is clearer for BLAS
40#![allow(clippy::items_after_statements)]
41// Technical terms (OxiBLAS, GEMM, etc.) don't need backticks
42#![allow(clippy::doc_markdown)]
43// BLAS functions have well-known semantics
44#![allow(clippy::missing_errors_doc)]
45#![allow(clippy::missing_panics_doc)]
46// Not critical for performance library
47#![allow(clippy::must_use_candidate)]
48#![allow(clippy::return_self_not_must_use)]
49// Sometimes makes match arms more explicit
50#![allow(clippy::match_same_arms)]
51// Explicit lifetimes can be clearer
52#![allow(clippy::needless_lifetimes)]
53// Index-based loops common in BLAS
54#![allow(clippy::needless_range_loop)]
55// Not all functions need const
56#![allow(clippy::missing_const_for_fn)]
57// API consistency with Result/Option
58#![allow(clippy::unnecessary_wraps)]
59// Raw pointer casting common in SIMD code
60#![allow(clippy::ptr_as_ptr)]
61// SIMD code uses transmute
62#![allow(clippy::transmute_ptr_to_ref)]
63// Casting in BLAS is intentional
64#![allow(clippy::cast_possible_wrap)]
65#![allow(clippy::cast_precision_loss)]
66// CBLAS extern functions have well-known semantics
67#![allow(clippy::missing_safety_doc)]
68// Manual assign clearer for BLAS code
69#![allow(clippy::assign_op_pattern)]
70// Transmute in SIMD code is intentional
71#![allow(clippy::transmute_undefined_repr)]
72#![allow(clippy::missing_transmute_annotations)]
73// SIMD kernels benefit from inline(always)
74#![allow(clippy::inline_always)]
75// Some refs are cfg-gated
76#![allow(clippy::needless_pass_by_ref_mut)]
77// Small types passed by value intentionally
78#![allow(clippy::trivially_copy_pass_by_ref)]
79#![allow(clippy::needless_pass_by_value)]
80// Sometimes makes conditional code clearer
81#![allow(clippy::if_same_then_else)]
82#![allow(clippy::branches_sharing_code)]
83// Strict float comparison sometimes needed
84#![allow(clippy::float_cmp)]
85// Older Rust compatible code
86#![allow(clippy::manual_div_ceil)]
87// Manual copy for specific layouts
88#![allow(clippy::manual_memcpy)]
89// Unused variable patterns are intentional
90#![allow(clippy::no_effect_underscore_binding)]
91
92pub mod accuracy;
93pub mod cblas;
94pub mod complex_interleaved;
95pub mod level1;
96pub mod level2;
97pub mod level3;
98pub mod ndtensor;
99pub mod tensor;
100
101/// Prelude module for convenient imports.
102pub mod prelude {
103    pub use crate::level1::{asum, axpy, copy, dot, iamax, iamin, nrm2, nrm2_sq, scal, swap};
104    pub use crate::level2::{
105        DiagKind,
106        GemvTrans,
107        HerError,
108        HerUplo,
109        SyrError,
110        SyrUplo,
111        TriangularMode,
112        TriangularSide,
113        TrmvError,
114        TrmvOp,
115        TrmvUplo,
116        gemv,
117        gemv_simple,
118        ger,
119        gerc,
120        her,
121        her_new,
122        // Symmetric/Hermitian rank-1 updates
123        syr,
124        syr_new,
125        // Triangular matrix-vector multiply
126        trmv,
127        trmv_alloc,
128        trsv,
129        trsv_in_place,
130    };
131    pub use crate::level3::{
132        Diag,
133        GemmBlocking,
134        GemmKernel,
135        Her2kError,
136        HerkError,
137        Side,
138        Syr2kError,
139        SyrkError,
140        Trans,
141        TrmmDiag,
142        TrmmError,
143        TrmmSide,
144        TrmmTrans,
145        TrmmUplo,
146        Uplo,
147        gemm,
148        gemm_with_par,
149        her2k,
150        her2k_new,
151        herk,
152        herk_new,
153        syr2k,
154        syr2k_new,
155        syrk,
156        syrk_new,
157        // Triangular matrix-matrix multiply
158        trmm,
159        trmm_in_place,
160        trsm,
161    };
162    pub use crate::ndtensor::{NdTensor, NdTensorError, Order};
163}
164
#[cfg(test)]
mod tests {
    use super::*;
    use oxiblas_matrix::Mat;

    /// Multiplies two fixed 3x3 matrices and checks *every* entry of the
    /// product against values computed by hand.
    #[test]
    fn test_gemm_correctness() {
        // Test with known values
        let a: Mat<f64> = Mat::from_rows(&[&[1.0, 2.0, 3.0], &[4.0, 5.0, 6.0], &[7.0, 8.0, 9.0]]);

        let b: Mat<f64> = Mat::from_rows(&[&[9.0, 8.0, 7.0], &[6.0, 5.0, 4.0], &[3.0, 2.0, 1.0]]);

        let mut c: Mat<f64> = Mat::zeros(3, 3);

        level3::gemm(1.0, a.as_ref(), b.as_ref(), 0.0, c.as_mut());

        // Full expected product, row by row:
        // C[0,0] = 1*9 + 2*6 + 3*3 = 30    C[0,1] = 1*8 + 2*5 + 3*2 = 24    C[0,2] = 1*7 + 2*4 + 3*1 = 18
        // C[1,0] = 4*9 + 5*6 + 6*3 = 84    C[1,1] = 4*8 + 5*5 + 6*2 = 69    C[1,2] = 4*7 + 5*4 + 6*1 = 54
        // C[2,0] = 7*9 + 8*6 + 9*3 = 138   C[2,1] = 7*8 + 8*5 + 9*2 = 114   C[2,2] = 7*7 + 8*4 + 9*1 = 90
        let expected: [[f64; 3]; 3] = [
            [30.0, 24.0, 18.0],
            [84.0, 69.0, 54.0],
            [138.0, 114.0, 90.0],
        ];

        // Checking all nine entries (not just a sample) catches faults that
        // only affect later rows or columns of the output.
        for (i, row) in expected.iter().enumerate() {
            for (j, &want) in row.iter().enumerate() {
                assert!(
                    (c[(i, j)] - want).abs() < 1e-10,
                    "c[{},{}] = {}, expected {}",
                    i,
                    j,
                    c[(i, j)],
                    want
                );
            }
        }
    }

    /// Multiplies a 10x20 matrix of ones by a 20x15 matrix of ones; every
    /// entry of the 10x15 result must equal the inner dimension (20).
    #[test]
    fn test_gemm_non_square() {
        let a: Mat<f64> = Mat::filled(10, 20, 1.0);
        let b: Mat<f64> = Mat::filled(20, 15, 1.0);
        let mut c: Mat<f64> = Mat::zeros(10, 15);

        level3::gemm(1.0, a.as_ref(), b.as_ref(), 0.0, c.as_mut());

        // Each element should be 20 (sum of 20 ones)
        for i in 0..10 {
            for j in 0..15 {
                assert!(
                    (c[(i, j)] - 20.0).abs() < 1e-10,
                    "c[{},{}] = {}, expected 20",
                    i,
                    j,
                    c[(i, j)]
                );
            }
        }
    }

    /// Multiplies two 128x128 matrices of ones; every entry of the result
    /// must equal `n`.
    #[test]
    fn test_gemm_large() {
        let n = 128;
        let a: Mat<f64> = Mat::filled(n, n, 1.0);
        let b: Mat<f64> = Mat::filled(n, n, 1.0);
        let mut c: Mat<f64> = Mat::zeros(n, n);

        level3::gemm(1.0, a.as_ref(), b.as_ref(), 0.0, c.as_mut());

        // Each element should be n
        for i in 0..n {
            for j in 0..n {
                assert!(
                    (c[(i, j)] - n as f64).abs() < 1e-8,
                    "c[{},{}] = {}, expected {}",
                    i,
                    j,
                    c[(i, j)],
                    n
                );
            }
        }
    }
}
233}