sparse_ir/gemm.rs
//! Matrix multiplication utilities with pluggable BLAS backend
//!
//! This module provides thin wrappers around matrix multiplication operations,
//! with support for runtime selection of BLAS implementations.
//!
//! # Design
//! - **Default**: Pure Rust Faer backend (no external dependencies)
//! - **Optional**: External BLAS via function pointer injection
//! - **Thread-safe**: Global dispatcher protected by RwLock
//!
//! # Example
//! ```ignore
//! use sparse_ir::gemm::{matmul_par, set_blas_backend};
//!
//! // Use default Faer backend (global dispatcher)
//! let c = matmul_par(&a, &b, None);
//!
//! // Or inject custom BLAS (from C-API)
//! unsafe {
//!     set_blas_backend(my_dgemm_ptr, my_zgemm_ptr);
//! }
//! let c = matmul_par(&a, &b, None); // Now uses custom BLAS
//! ```

use mdarray::{DSlice, DTensor, Layout};
use once_cell::sync::Lazy;
use std::sync::{Arc, RwLock};

#[cfg(feature = "system-blas")]
use blas_sys::dgemm_;

//==============================================================================
// BLAS Function Pointer Types
//==============================================================================

/// BLAS dgemm function pointer type (LP64: 32-bit integers)
///
/// Signature matches Fortran BLAS dgemm:
/// ```c
/// void dgemm_(char *transa, char *transb, int *m, int *n, int *k,
///             double *alpha, double *a, int *lda, double *b, int *ldb,
///             double *beta, double *c, int *ldc);
/// ```
/// Note: All parameters are passed by reference (pointers).
/// Transpose options: 'N' (no transpose), 'T' (transpose), 'C' (conjugate transpose).
pub type DgemmFnPtr = unsafe extern "C" fn(
    transa: *const libc::c_char,
    transb: *const libc::c_char,
    m: *const libc::c_int,
    n: *const libc::c_int,
    k: *const libc::c_int,
    alpha: *const libc::c_double,
    a: *const libc::c_double,
    lda: *const libc::c_int,
    b: *const libc::c_double,
    ldb: *const libc::c_int,
    beta: *const libc::c_double,
    c: *mut libc::c_double,
    ldc: *const libc::c_int,
);

/// BLAS zgemm function pointer type (LP64: 32-bit integers)
///
/// Signature matches Fortran BLAS zgemm:
/// ```c
/// void zgemm_(char *transa, char *transb, int *m, int *n, int *k,
///             void *alpha, void *a, int *lda, void *b, int *ldb,
///             void *beta, void *c, int *ldc);
/// ```
/// Note: All parameters are passed by reference (pointers).
/// Complex numbers are passed as void* (typically complex<double>*).
/// Transpose options: 'N' (no transpose), 'T' (transpose), 'C' (conjugate transpose).
pub type ZgemmFnPtr = unsafe extern "C" fn(
    transa: *const libc::c_char,
    transb: *const libc::c_char,
    m: *const libc::c_int,
    n: *const libc::c_int,
    k: *const libc::c_int,
    alpha: *const num_complex::Complex<f64>,
    a: *const num_complex::Complex<f64>,
    lda: *const libc::c_int,
    b: *const num_complex::Complex<f64>,
    ldb: *const libc::c_int,
    beta: *const num_complex::Complex<f64>,
    c: *mut num_complex::Complex<f64>,
    ldc: *const libc::c_int,
);

// When using system BLAS via `blas-sys`, we need a small wrapper to adapt
// `blas_sys::zgemm_` (which uses `c_double_complex = [f64; 2]`) to the
// `ZgemmFnPtr` signature that takes `num_complex::Complex<f64>`.
#[cfg(feature = "system-blas")]
unsafe extern "C" fn zgemm_wrapper(
    transa: *const libc::c_char,
    transb: *const libc::c_char,
    m: *const libc::c_int,
    n: *const libc::c_int,
    k: *const libc::c_int,
    alpha: *const num_complex::Complex<f64>,
    a: *const num_complex::Complex<f64>,
    lda: *const libc::c_int,
    b: *const num_complex::Complex<f64>,
    ldb: *const libc::c_int,
    beta: *const num_complex::Complex<f64>,
    c: *mut num_complex::Complex<f64>,
    ldc: *const libc::c_int,
) {
    // Safety: `blas_sys::c_double_complex` is defined as `[f64; 2]` and is
    // layout-compatible with `num_complex::Complex<f64>` in memory, so we can
    // cast between the two pointer types here.
    unsafe {
        blas_sys::zgemm_(
            transa,
            transb,
            m,
            n,
            k,
            alpha as *const _ as *const blas_sys::c_double_complex,
            a as *const _ as *const blas_sys::c_double_complex,
            lda,
            b as *const _ as *const blas_sys::c_double_complex,
            ldb,
            beta as *const _ as *const blas_sys::c_double_complex,
            c as *mut _ as *mut blas_sys::c_double_complex,
            ldc,
        );
    }
}

/// BLAS dgemm function pointer type (ILP64: 64-bit integers)
///
/// Signature matches Fortran BLAS dgemm (ILP64):
/// ```c
/// void dgemm_(char *transa, char *transb, long long *m, long long *n, long long *k,
///             double *alpha, double *a, long long *lda, double *b, long long *ldb,
///             double *beta, double *c, long long *ldc);
/// ```
pub type Dgemm64FnPtr = unsafe extern "C" fn(
    transa: *const libc::c_char,
    transb: *const libc::c_char,
    m: *const i64,
    n: *const i64,
    k: *const i64,
    alpha: *const libc::c_double,
    a: *const libc::c_double,
    lda: *const i64,
    b: *const libc::c_double,
    ldb: *const i64,
    beta: *const libc::c_double,
    c: *mut libc::c_double,
    ldc: *const i64,
);

/// BLAS zgemm function pointer type (ILP64: 64-bit integers)
///
/// Signature matches Fortran BLAS zgemm (ILP64):
/// ```c
/// void zgemm_(char *transa, char *transb, long long *m, long long *n, long long *k,
///             void *alpha, void *a, long long *lda, void *b, long long *ldb,
///             void *beta, void *c, long long *ldc);
/// ```
pub type Zgemm64FnPtr = unsafe extern "C" fn(
    transa: *const libc::c_char,
    transb: *const libc::c_char,
    m: *const i64,
    n: *const i64,
    k: *const i64,
    alpha: *const num_complex::Complex<f64>,
    a: *const num_complex::Complex<f64>,
    lda: *const i64,
    b: *const num_complex::Complex<f64>,
    ldb: *const i64,
    beta: *const num_complex::Complex<f64>,
    c: *mut num_complex::Complex<f64>,
    ldc: *const i64,
);

//==============================================================================
// Fortran BLAS Constants
//==============================================================================

// Fortran BLAS transpose characters

//==============================================================================
// GemmBackend Trait
//==============================================================================

/// GEMM backend trait for runtime dispatch
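///
/// The example below is a minimal, illustrative sketch of a custom backend
/// (naive triple loops; not part of this crate). A real backend would forward
/// to an optimized GEMM instead.
///
/// ```ignore
/// struct NaiveBackend;
///
/// impl GemmBackend for NaiveBackend {
///     unsafe fn dgemm(&self, m: usize, n: usize, k: usize,
///                     a: *const f64, b: *const f64, c: *mut f64) {
///         // Row-major indexing: a[i*k + l], b[l*n + j], c[i*n + j]
///         unsafe {
///             for i in 0..m {
///                 for j in 0..n {
///                     let mut acc = 0.0;
///                     for l in 0..k {
///                         acc += *a.add(i * k + l) * *b.add(l * n + j);
///                     }
///                     *c.add(i * n + j) = acc;
///                 }
///             }
///         }
///     }
///
///     unsafe fn zgemm(&self, m: usize, n: usize, k: usize,
///                     a: *const num_complex::Complex<f64>,
///                     b: *const num_complex::Complex<f64>,
///                     c: *mut num_complex::Complex<f64>) {
///         // Same loop structure as dgemm, with complex accumulation.
///         unsafe {
///             for i in 0..m {
///                 for j in 0..n {
///                     let mut acc = num_complex::Complex::new(0.0, 0.0);
///                     for l in 0..k {
///                         acc += *a.add(i * k + l) * *b.add(l * n + j);
///                     }
///                     *c.add(i * n + j) = acc;
///                 }
///             }
///         }
///     }
///
///     fn name(&self) -> &'static str {
///         "Naive (example)"
///     }
/// }
/// ```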
pub trait GemmBackend: Send + Sync {
    /// Matrix multiplication: C = A * B (f64)
    ///
    /// # Arguments
    /// * `m`, `n`, `k` - Matrix dimensions (M x K) * (K x N) = (M x N)
    /// * `a` - Pointer to matrix A (row-major, M x K)
    /// * `b` - Pointer to matrix B (row-major, K x N)
    /// * `c` - Pointer to output matrix C (row-major, M x N)
    ///
    /// Note: Leading dimensions are calculated internally based on the
    /// row-major to column-major conversion.
    unsafe fn dgemm(&self, m: usize, n: usize, k: usize, a: *const f64, b: *const f64, c: *mut f64);

    /// Matrix multiplication: C = A * B (Complex<f64>)
    ///
    /// # Arguments
    /// * `m`, `n`, `k` - Matrix dimensions (M x K) * (K x N) = (M x N)
    /// * `a` - Pointer to matrix A (row-major, M x K)
    /// * `b` - Pointer to matrix B (row-major, K x N)
    /// * `c` - Pointer to output matrix C (row-major, M x N)
    ///
    /// Note: Leading dimensions are calculated internally based on the
    /// row-major to column-major conversion.
    unsafe fn zgemm(
        &self,
        m: usize,
        n: usize,
        k: usize,
        a: *const num_complex::Complex<f64>,
        b: *const num_complex::Complex<f64>,
        c: *mut num_complex::Complex<f64>,
    );

    /// Returns true if this backend uses 64-bit integers (ILP64)
    fn is_ilp64(&self) -> bool {
        false
    }

    /// Returns backend name for debugging
    fn name(&self) -> &'static str;
}

//==============================================================================
// Faer Backend (Default, Pure Rust)
//==============================================================================

/// Default Faer backend (Pure Rust, no external dependencies)
struct FaerBackend;

impl GemmBackend for FaerBackend {
    unsafe fn dgemm(
        &self,
        m: usize,
        n: usize,
        k: usize,
        a: *const f64,
        b: *const f64,
        c: *mut f64,
    ) {
        use mdarray_linalg::matmul::MatMulBuilder;
        use mdarray_linalg::prelude::MatMul;
        use mdarray_linalg_faer::Faer;

        // Create tensors from pointers (row-major order)
        let a_slice = unsafe { std::slice::from_raw_parts(a, m * k) };
        let b_slice = unsafe { std::slice::from_raw_parts(b, k * n) };
        let a_tensor = DTensor::<f64, 2>::from_fn([m, k], |idx| a_slice[idx[0] * k + idx[1]]);
        let b_tensor = DTensor::<f64, 2>::from_fn([k, n], |idx| b_slice[idx[0] * n + idx[1]]);

        // Perform matrix multiplication
        let c_tensor = Faer.matmul(&*a_tensor, &*b_tensor).parallelize().eval();

        // Copy result back to output pointer (row-major order)
        // For row-major, ldc = n (number of columns)
        let ldc = n;
        let c_slice = unsafe { std::slice::from_raw_parts_mut(c, m * ldc) };
        for i in 0..m {
            for j in 0..n {
                c_slice[i * ldc + j] = c_tensor[[i, j]];
            }
        }
    }

    unsafe fn zgemm(
        &self,
        m: usize,
        n: usize,
        k: usize,
        a: *const num_complex::Complex<f64>,
        b: *const num_complex::Complex<f64>,
        c: *mut num_complex::Complex<f64>,
    ) {
        use mdarray_linalg::matmul::MatMulBuilder;
        use mdarray_linalg::prelude::MatMul;
        use mdarray_linalg_faer::Faer;

        // Create tensors from pointers (row-major order)
        let a_slice = unsafe { std::slice::from_raw_parts(a, m * k) };
        let b_slice = unsafe { std::slice::from_raw_parts(b, k * n) };
        let a_tensor = DTensor::<num_complex::Complex<f64>, 2>::from_fn([m, k], |idx| {
            a_slice[idx[0] * k + idx[1]]
        });
        let b_tensor = DTensor::<num_complex::Complex<f64>, 2>::from_fn([k, n], |idx| {
            b_slice[idx[0] * n + idx[1]]
        });

        // Perform matrix multiplication
        let c_tensor = Faer.matmul(&*a_tensor, &*b_tensor).parallelize().eval();

        // Copy result back to output pointer (row-major order)
        // For row-major, ldc = n (number of columns)
        let ldc = n;
        let c_slice = unsafe { std::slice::from_raw_parts_mut(c, m * ldc) };
        for i in 0..m {
            for j in 0..n {
                c_slice[i * ldc + j] = c_tensor[[i, j]];
            }
        }
    }

    fn name(&self) -> &'static str {
        "Faer (Pure Rust)"
    }
}

//==============================================================================
// External BLAS Backends (LP64 and ILP64)
//==============================================================================

/// Conversion rules for row-major data to column-major BLAS:
///
/// **Goal**: Compute C = A * B where:
/// - A is m×k (row-major)
/// - B is k×n (row-major)
/// - C is m×n (row-major)
///
/// **Row-major to column-major interpretation**:
/// - Row-major A (m×k) appears as A^T (k×m) in column-major → call this At
/// - Row-major B (k×n) appears as B^T (n×k) in column-major → call this Bt
/// - Row-major C (m×n) appears as C^T (n×m) in column-major → call this Ct
/// - To compute C = A * B, we need: C^T = (A * B)^T = B^T * A^T
/// - So: Ct = Bt * At
///
/// **BLAS call transformation**:
/// - Original: C = A * B (row-major world)
/// - BLAS call: Ct = Bt * At (column-major world)
/// - transa = 'N' (Bt is already transposed-looking, no transpose needed)
/// - transb = 'N' (At is already transposed-looking, no transpose needed)
/// - Call: dgemm('N', 'N', n, m, k, alpha, B, lda, A, ldb, beta, C, ldc)
///
/// **Dimension conversions**:
/// - m_blas = n (Ct rows = Bt rows)
/// - n_blas = m (Ct cols = At cols)
/// - k_blas = k (common dimension)
/// - lda = n (leading dimension of Bt: n×k in column-major, lda = n)
/// - ldb = k (leading dimension of At: k×m in column-major, ldb = k)
/// - ldc = n (leading dimension of Ct: n×m in column-major, ldc = n)
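///
/// **Worked example** (illustrative sizes, not taken from the code below): for
/// row-major A (2×3), B (3×4), C (2×4), i.e. m = 2, n = 4, k = 3, the rules
/// above yield the LP64 call
///
/// ```text
/// dgemm('N', 'N', m_blas = 4, n_blas = 2, k_blas = 3,
///       alpha, B, lda = 4,
///              A, ldb = 3,
///       beta,  C, ldc = 4)
/// ```
///
/// i.e. B and A are swapped in the argument list, and every leading dimension
/// equals the column count of the corresponding row-major matrix.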

/// External BLAS backend (LP64: 32-bit integers)
pub struct ExternalBlasBackend {
    dgemm: DgemmFnPtr,
    zgemm: ZgemmFnPtr,
}

impl ExternalBlasBackend {
    pub fn new(dgemm: DgemmFnPtr, zgemm: ZgemmFnPtr) -> Self {
        Self { dgemm, zgemm }
    }
}

impl GemmBackend for ExternalBlasBackend {
    unsafe fn dgemm(
        &self,
        m: usize,
        n: usize,
        k: usize,
        a: *const f64,
        b: *const f64,
        c: *mut f64,
    ) {
        // Validate dimensions fit in i32
        assert!(
            m <= i32::MAX as usize,
            "Matrix dimension m too large for LP64 BLAS"
        );
        assert!(
            n <= i32::MAX as usize,
            "Matrix dimension n too large for LP64 BLAS"
        );
        assert!(
            k <= i32::MAX as usize,
            "Matrix dimension k too large for LP64 BLAS"
        );

        // Fortran BLAS requires all parameters passed by reference
        // Apply row-major to column-major conversion (see conversion rules above)
        let transa = b'N' as libc::c_char; // Bt is already transposed-looking
        let transb = b'N' as libc::c_char; // At is already transposed-looking
        let m_i32 = n as i32; // m_blas = n (Ct rows = Bt rows)
        let n_i32 = m as i32; // n_blas = m (Ct cols = At cols)
        let k_i32 = k as i32; // k_blas = k (common dimension)
        let alpha = 1.0f64;
        let lda = n as i32; // lda = n (leading dimension of Bt: n×k in column-major)
        let ldb = k as i32; // ldb = k (leading dimension of At: k×m in column-major)
        let beta = 0.0f64;
        // For row-major C (m×n) viewed as column-major Ct (n×m):
        // Leading dimension in column-major is the stride between rows
        // In row-major, stride between rows = number of columns = n
        // So ldc = n (the number of columns in the original row-major matrix)
        let ldc_i32 = n as i32; // ldc = n (leading dimension of Ct: n×m in column-major)

        unsafe {
            (self.dgemm)(
                &transa, &transb, &m_i32, &n_i32, &k_i32, &alpha, b, // B first (Bt)
                &lda, a, // A second (At)
                &ldb, &beta, c, &ldc_i32,
            );
        }
    }

    unsafe fn zgemm(
        &self,
        m: usize,
        n: usize,
        k: usize,
        a: *const num_complex::Complex<f64>,
        b: *const num_complex::Complex<f64>,
        c: *mut num_complex::Complex<f64>,
    ) {
        assert!(
            m <= i32::MAX as usize,
            "Matrix dimension m too large for LP64 BLAS"
        );
        assert!(
            n <= i32::MAX as usize,
            "Matrix dimension n too large for LP64 BLAS"
        );
        assert!(
            k <= i32::MAX as usize,
            "Matrix dimension k too large for LP64 BLAS"
        );

        // Fortran BLAS requires all parameters passed by reference
        // Apply row-major to column-major conversion (see conversion rules above)
        let transa = b'N' as libc::c_char; // Bt is already transposed-looking
        let transb = b'N' as libc::c_char; // At is already transposed-looking
        let m_i32 = n as i32; // m_blas = n (Ct rows = Bt rows)
        let n_i32 = m as i32; // n_blas = m (Ct cols = At cols)
        let k_i32 = k as i32; // k_blas = k (common dimension)
        let alpha = num_complex::Complex::new(1.0, 0.0);
        let lda = n as i32; // lda = n (leading dimension of Bt: n×k in column-major)
        let ldb = k as i32; // ldb = k (leading dimension of At: k×m in column-major)
        let beta = num_complex::Complex::new(0.0, 0.0);
        // For row-major C (m×n) viewed as column-major Ct (n×m):
        // Leading dimension in column-major is the stride between rows = n
        let ldc_i32 = n as i32; // ldc = n (leading dimension of Ct: n×m in column-major)

        unsafe {
            (self.zgemm)(
                &transa,
                &transb,
                &m_i32,
                &n_i32,
                &k_i32,
                &alpha,
                b as *const _, // B first (Bt)
                &lda,
                a as *const _, // A second (At)
                &ldb,
                &beta,
                c as *mut _,
                &ldc_i32,
            );
        }
    }

    fn name(&self) -> &'static str {
        "External BLAS (LP64)"
    }
}

/// External BLAS backend (ILP64: 64-bit integers)
pub struct ExternalBlas64Backend {
    dgemm64: Dgemm64FnPtr,
    zgemm64: Zgemm64FnPtr,
}

impl ExternalBlas64Backend {
    pub fn new(dgemm64: Dgemm64FnPtr, zgemm64: Zgemm64FnPtr) -> Self {
        Self { dgemm64, zgemm64 }
    }
}

impl GemmBackend for ExternalBlas64Backend {
    unsafe fn dgemm(
        &self,
        m: usize,
        n: usize,
        k: usize,
        a: *const f64,
        b: *const f64,
        c: *mut f64,
    ) {
        // Fortran BLAS requires all parameters passed by reference
        // Apply row-major to column-major conversion (see conversion rules above)
        let transa = b'N' as libc::c_char; // Bt is already transposed-looking
        let transb = b'N' as libc::c_char; // At is already transposed-looking
        let m_i64 = n as i64; // m_blas = n (Ct rows = Bt rows)
        let n_i64 = m as i64; // n_blas = m (Ct cols = At cols)
        let k_i64 = k as i64; // k_blas = k (common dimension)
        let alpha = 1.0f64;
        let lda = n as i64; // lda = n (leading dimension of Bt: n×k in column-major)
        let ldb = k as i64; // ldb = k (leading dimension of At: k×m in column-major)
        let beta = 0.0f64;
        // For row-major C (m×n) viewed as column-major Ct (n×m):
        // Leading dimension in column-major is the stride between rows = n
        let ldc_i64 = n as i64; // ldc = n (leading dimension of Ct: n×m in column-major)

        unsafe {
            (self.dgemm64)(
                &transa, &transb, &m_i64, &n_i64, &k_i64, &alpha, b, // B first (Bt)
                &lda, a, // A second (At)
                &ldb, &beta, c, &ldc_i64,
            );
        }
    }

    unsafe fn zgemm(
        &self,
        m: usize,
        n: usize,
        k: usize,
        a: *const num_complex::Complex<f64>,
        b: *const num_complex::Complex<f64>,
        c: *mut num_complex::Complex<f64>,
    ) {
        // Fortran BLAS requires all parameters passed by reference
        // Apply row-major to column-major conversion (see conversion rules above)
        let transa = b'N' as libc::c_char; // Bt is already transposed-looking
        let transb = b'N' as libc::c_char; // At is already transposed-looking
        let m_i64 = n as i64; // m_blas = n (Ct rows = Bt rows)
        let n_i64 = m as i64; // n_blas = m (Ct cols = At cols)
        let k_i64 = k as i64; // k_blas = k (common dimension)
        let alpha = num_complex::Complex::new(1.0, 0.0);
        let lda = n as i64; // lda = n (leading dimension of Bt: n×k in column-major)
        let ldb = k as i64; // ldb = k (leading dimension of At: k×m in column-major)
        let beta = num_complex::Complex::new(0.0, 0.0);
        // For row-major C (m×n) viewed as column-major Ct (n×m):
        // Leading dimension in column-major is the stride between rows = n
        let ldc_i64 = n as i64; // ldc = n (leading dimension of Ct: n×m in column-major)

        unsafe {
            (self.zgemm64)(
                &transa,
                &transb,
                &m_i64,
                &n_i64,
                &k_i64,
                &alpha,
                b as *const _, // B first (Bt)
                &lda,
                a as *const _, // A second (At)
                &ldb,
                &beta,
                c as *mut _,
                &ldc_i64,
            );
        }
    }

    fn is_ilp64(&self) -> bool {
        true
    }

    fn name(&self) -> &'static str {
        "External BLAS (ILP64)"
    }
}

//==============================================================================
// Backend Handle
//==============================================================================

/// Thread-safe handle to a GEMM backend
///
/// This type wraps an `Arc<dyn GemmBackend>` to allow sharing a backend
/// across multiple function calls without global state.
///
/// # Example
/// ```ignore
/// use sparse_ir::gemm::GemmBackendHandle;
///
/// let backend = GemmBackendHandle::default();
/// let result = matmul_par(&a, &b, Some(&backend));
/// ```
#[derive(Clone)]
pub struct GemmBackendHandle {
    inner: Arc<dyn GemmBackend>,
}

impl GemmBackendHandle {
    /// Create a new backend handle from a boxed backend
    pub fn new(backend: Box<dyn GemmBackend>) -> Self {
        Self {
            inner: Arc::from(backend),
        }
    }

    /// Create a default backend handle (Faer backend)
    pub fn default() -> Self {
        Self {
            inner: Arc::new(FaerBackend),
        }
    }

    /// Get a reference to the inner backend
    pub(crate) fn as_ref(&self) -> &dyn GemmBackend {
        self.inner.as_ref()
    }
}

//==============================================================================
// Global Dispatcher (for backward compatibility)
//==============================================================================

/// Global BLAS dispatcher (thread-safe)
///
/// This is kept for backward compatibility when `None` is passed as backend.
/// New code should use `GemmBackendHandle` explicitly.
static BLAS_DISPATCHER: Lazy<RwLock<Box<dyn GemmBackend>>> = Lazy::new(|| {
    #[cfg(feature = "system-blas")]
    {
        // Use system BLAS (LP64) by default via `blas-sys`.
        let backend =
            ExternalBlasBackend::new(dgemm_ as DgemmFnPtr, zgemm_wrapper as ZgemmFnPtr);
        RwLock::new(Box::new(backend) as Box<dyn GemmBackend>)
    }
    #[cfg(not(feature = "system-blas"))]
    {
        // Default to the pure Rust Faer backend.
        RwLock::new(Box::new(FaerBackend) as Box<dyn GemmBackend>)
    }
});

/// Set BLAS backend (LP64: 32-bit integers)
///
/// # Safety
/// - Function pointers must be valid and thread-safe
/// - Must remain valid for the lifetime of the program
/// - Must follow Fortran BLAS calling convention
///
/// # Example
/// ```ignore
/// unsafe {
///     set_blas_backend(dgemm_ as _, zgemm_ as _);
/// }
/// ```
pub unsafe fn set_blas_backend(dgemm: DgemmFnPtr, zgemm: ZgemmFnPtr) {
    let backend = ExternalBlasBackend { dgemm, zgemm };
    let mut dispatcher = BLAS_DISPATCHER.write().unwrap();
    *dispatcher = Box::new(backend);
}

/// Set ILP64 BLAS backend (64-bit integers)
///
/// # Safety
/// - Function pointers must be valid, thread-safe, and use 64-bit integers
/// - Must remain valid for the lifetime of the program
/// - Must follow Fortran BLAS calling convention with ILP64 interface
///
/// # Example
/// ```ignore
/// unsafe {
///     set_ilp64_backend(dgemm_ as _, zgemm_ as _);
/// }
/// ```
pub unsafe fn set_ilp64_backend(dgemm64: Dgemm64FnPtr, zgemm64: Zgemm64FnPtr) {
    let backend = ExternalBlas64Backend { dgemm64, zgemm64 };
    let mut dispatcher = BLAS_DISPATCHER.write().unwrap();
    *dispatcher = Box::new(backend);
}

/// Clear BLAS backend (reset to default Faer)
///
/// This function resets the GEMM dispatcher to use the default Pure Rust Faer backend.
pub fn clear_blas_backend() {
    let mut dispatcher = BLAS_DISPATCHER.write().unwrap();
    *dispatcher = Box::new(FaerBackend);
}

/// Get current BLAS backend information
///
/// Returns:
/// - `(backend_name, is_external, is_ilp64)`
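///
/// # Example
/// A small diagnostic sketch:
/// ```ignore
/// let (name, is_external, is_ilp64) = get_backend_info();
/// println!("GEMM backend: {name} (external: {is_external}, ILP64: {is_ilp64})");
/// ```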
pub fn get_backend_info() -> (&'static str, bool, bool) {
    let dispatcher = BLAS_DISPATCHER.read().unwrap();
    let name = dispatcher.name();
    let is_external = !name.contains("Faer");
    let is_ilp64 = dispatcher.is_ilp64();
    (name, is_external, is_ilp64)
}

//==============================================================================
// Public API
//==============================================================================

/// Parallel matrix multiplication: C = A * B
///
/// Dispatches to the provided backend, or the global dispatcher if `None`.
///
/// # Arguments
/// * `a` - Left matrix (M x K)
/// * `b` - Right matrix (K x N)
/// * `backend` - Optional backend handle. If `None`, uses global dispatcher (for backward compatibility)
///
/// # Returns
/// Result matrix (M x N)
///
/// # Panics
/// Panics if matrix dimensions are incompatible (A.cols != B.rows)
///
/// # Example
/// ```ignore
/// use mdarray::tensor;
/// use sparse_ir::gemm::{matmul_par, GemmBackendHandle};
///
/// let a = tensor![[1.0, 2.0], [3.0, 4.0]];
/// let b = tensor![[5.0, 6.0], [7.0, 8.0]];
/// let backend = GemmBackendHandle::default();
/// let c = matmul_par(&a, &b, Some(&backend));
/// // c = [[19.0, 22.0], [43.0, 50.0]]
/// ```
pub fn matmul_par<T>(
    a: &DTensor<T, 2>,
    b: &DTensor<T, 2>,
    backend: Option<&GemmBackendHandle>,
) -> DTensor<T, 2>
where
    T: num_complex::ComplexFloat + faer_traits::ComplexField + num_traits::One + Copy + 'static,
{
    let (m, k) = *a.shape();
    let (k2, n) = *b.shape();

    // Validate dimensions
    assert_eq!(
        k, k2,
        "Matrix dimension mismatch: A.cols ({}) != B.rows ({})",
        k, k2
    );

    // Allocate the m x n result tensor, then delegate to matmul_par_overwrite,
    // which dispatches to the selected backend without further copies.
    let mut result = DTensor::<T, 2>::from_elem([m, n], T::zero().into());
    matmul_par_overwrite(a, b, &mut result, backend);
    result
}

/// Parallel matrix multiplication with overwrite: C = A * B (writes to existing buffer)
///
/// This function writes the result directly into the provided buffer `c`,
/// avoiding memory allocation. This is more memory-efficient for repeated operations.
///
/// # Arguments
/// * `a` - Left matrix (M x K)
/// * `b` - Right matrix (K x N)
/// * `c` - Output matrix (M x N) - will be overwritten with result
/// * `backend` - Optional backend handle. If `None`, uses global dispatcher (for backward compatibility)
///
/// # Panics
/// Panics if matrix dimensions are incompatible (A.cols != B.rows or C.shape != [M, N])
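///
/// # Example
/// A minimal usage sketch (sizes chosen for illustration):
/// ```ignore
/// use mdarray::{tensor, DTensor};
/// use sparse_ir::gemm::matmul_par_overwrite;
///
/// let a: DTensor<f64, 2> = tensor![[1.0, 2.0], [3.0, 4.0]];
/// let b: DTensor<f64, 2> = tensor![[5.0, 6.0], [7.0, 8.0]];
/// // Reuse a preallocated 2x2 buffer instead of allocating a new result tensor.
/// let mut c = DTensor::<f64, 2>::from_elem([2, 2], 0.0);
/// matmul_par_overwrite(&a, &b, &mut c, None);
/// // c == [[19.0, 22.0], [43.0, 50.0]]
/// ```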
pub fn matmul_par_overwrite<T, Lc: Layout>(
    a: &DTensor<T, 2>,
    b: &DTensor<T, 2>,
    c: &mut DSlice<T, 2, Lc>,
    backend: Option<&GemmBackendHandle>,
) where
    T: num_complex::ComplexFloat + faer_traits::ComplexField + num_traits::One + Copy + 'static,
{
    let (m, k) = *a.shape();
    let (k2, n) = *b.shape();
    let (mc, nc) = *c.shape();

    // Validate dimensions
    assert_eq!(
        k, k2,
        "Matrix dimension mismatch: A.cols ({}) != B.rows ({})",
        k, k2
    );
    assert_eq!(
        m, mc,
        "Output matrix dimension mismatch: C.rows ({}) != A.rows ({})",
        mc, m
    );
    assert_eq!(
        n, nc,
        "Output matrix dimension mismatch: C.cols ({}) != B.cols ({})",
        nc, n
    );

    // Type dispatch: f64 or Complex<f64>
    if std::any::TypeId::of::<T>() == std::any::TypeId::of::<f64>() {
        // f64 case
        // Get pointers directly from DTensors (row-major order)
        let a_ptr = a.as_ptr() as *const f64;
        let b_ptr = b.as_ptr() as *const f64;
        let c_ptr = c.as_mut_ptr() as *mut f64;

        // Get backend: use provided handle or fall back to global dispatcher
        match backend {
            Some(handle) => {
                // Call backend directly with pointers (no temporary buffer needed)
                // Leading dimension is calculated internally in the backend
                unsafe {
                    handle.as_ref().dgemm(m, n, k, a_ptr, b_ptr, c_ptr);
                }
            }
            None => {
                // Backward compatibility: use global dispatcher
                let dispatcher = BLAS_DISPATCHER.read().unwrap();
                unsafe {
                    dispatcher.dgemm(m, n, k, a_ptr, b_ptr, c_ptr);
                }
            }
        }
    } else if std::any::TypeId::of::<T>() == std::any::TypeId::of::<num_complex::Complex<f64>>() {
        // Complex<f64> case
        // Get pointers directly from DTensors (row-major order)
        let a_ptr = a.as_ptr() as *const num_complex::Complex<f64>;
        let b_ptr = b.as_ptr() as *const num_complex::Complex<f64>;
        let c_ptr = c.as_mut_ptr() as *mut num_complex::Complex<f64>;

        // Get backend: use provided handle or fall back to global dispatcher
        match backend {
            Some(handle) => {
                // Call backend directly with pointers (no temporary buffer needed)
                // Leading dimension is calculated internally in the backend
                unsafe {
                    handle.as_ref().zgemm(m, n, k, a_ptr, b_ptr, c_ptr);
                }
            }
            None => {
                // Backward compatibility: use global dispatcher
                let dispatcher = BLAS_DISPATCHER.read().unwrap();
                unsafe {
                    dispatcher.zgemm(m, n, k, a_ptr, b_ptr, c_ptr);
                }
            }
        }
    } else {
        // Fallback to Faer for unsupported types
        use mdarray_linalg::matmul::MatMulBuilder;
        use mdarray_linalg::prelude::MatMul;
        use mdarray_linalg_faer::Faer;

        Faer.matmul(a, b).parallelize().overwrite(c);
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_default_backend_is_faer() {
        let (name, is_external, is_ilp64) = get_backend_info();
        assert_eq!(name, "Faer (Pure Rust)");
        assert!(!is_external);
        assert!(!is_ilp64);
    }

    #[test]
    fn test_clear_backend() {
        // Should not panic
        clear_blas_backend();
        let (name, _, _) = get_backend_info();
        assert_eq!(name, "Faer (Pure Rust)");
    }

    #[test]
    fn test_matmul_f64() {
        let a_data = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let b_data = [7.0, 8.0, 9.0, 10.0, 11.0, 12.0];

        let a = DTensor::<f64, 2>::from_fn([2, 3], |idx| a_data[idx[0] * 3 + idx[1]]);
        let b = DTensor::<f64, 2>::from_fn([3, 2], |idx| b_data[idx[0] * 2 + idx[1]]);
        let c = matmul_par(&a, &b, None);

        assert_eq!(*c.shape(), (2, 2));
        // First row: [1*7+2*9+3*11, 1*8+2*10+3*12] = [58, 64]
        // Second row: [4*7+5*9+6*11, 4*8+5*10+6*12] = [139, 154]
        assert!((c[[0, 0]] - 58.0).abs() < 1e-10);
        assert!((c[[0, 1]] - 64.0).abs() < 1e-10);
        assert!((c[[1, 0]] - 139.0).abs() < 1e-10);
        assert!((c[[1, 1]] - 154.0).abs() < 1e-10);
    }

    #[test]
    fn test_matmul_par_basic() {
        use mdarray::tensor;
        let a: DTensor<f64, 2> = tensor![[1.0, 2.0], [3.0, 4.0]];
        let b: DTensor<f64, 2> = tensor![[5.0, 6.0], [7.0, 8.0]];
        let c = matmul_par(&a, &b, None);

        // Expected: [[1*5+2*7, 1*6+2*8], [3*5+4*7, 3*6+4*8]]
        //         = [[19, 22], [43, 50]]
        assert!((c[[0, 0]] - 19.0).abs() < 1e-10);
        assert!((c[[0, 1]] - 22.0).abs() < 1e-10);
        assert!((c[[1, 0]] - 43.0).abs() < 1e-10);
        assert!((c[[1, 1]] - 50.0).abs() < 1e-10);
    }

    #[test]
    fn test_matmul_par_non_square() {
        use mdarray::tensor;
        let a: DTensor<f64, 2> = tensor![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]; // 2x3
        let b: DTensor<f64, 2> = tensor![[7.0], [8.0], [9.0]]; // 3x1
        let c = matmul_par(&a, &b, None);

        // Expected: [[1*7+2*8+3*9], [4*7+5*8+6*9]]
        //         = [[50], [122]]
        assert!((c[[0, 0]] - 50.0).abs() < 1e-10);
        assert!((c[[1, 0]] - 122.0).abs() < 1e-10);
    }
}
908}