// sparse_ir/gemm.rs

//! Matrix multiplication utilities with a pluggable BLAS backend
//!
//! This module provides thin wrappers around matrix multiplication operations,
//! with support for runtime selection of BLAS implementations.
//!
//! # Design
//! - **Default**: Pure Rust Faer backend (no external dependencies)
//! - **Optional**: External BLAS via function pointer injection
//! - **Thread-safe**: Global dispatcher protected by RwLock
//!
//! # Example
//! ```ignore
//! use sparse_ir::gemm::{matmul_par, set_blas_backend};
//!
//! // Use the default Faer backend (global dispatcher)
//! let c = matmul_par(&a, &b, None);
//!
//! // Or inject a custom BLAS (e.g. from the C API)
//! unsafe {
//!     set_blas_backend(my_dgemm_ptr, my_zgemm_ptr);
//! }
//! let c = matmul_par(&a, &b, None); // Now uses the custom BLAS
//! ```

use mdarray::{DSlice, DTensor, DView, DViewMut, Layout};
use once_cell::sync::Lazy;
use std::sync::{Arc, RwLock};

#[cfg(feature = "system-blas")]
use blas_sys::dgemm_;

//==============================================================================
// BLAS Function Pointer Types
//==============================================================================

/// BLAS dgemm function pointer type (LP64: 32-bit integers)
///
/// Signature matches Fortran BLAS dgemm:
/// ```c
/// void dgemm_(char *transa, char *transb, int *m, int *n, int *k,
///             double *alpha, double *a, int *lda, double *b, int *ldb,
///             double *beta, double *c, int *ldc);
/// ```
/// Note: all parameters are passed by reference (pointers).
/// Transpose options: 'N' (no transpose), 'T' (transpose), 'C' (conjugate transpose).
pub type DgemmFnPtr = unsafe extern "C" fn(
    transa: *const libc::c_char,
    transb: *const libc::c_char,
    m: *const libc::c_int,
    n: *const libc::c_int,
    k: *const libc::c_int,
    alpha: *const libc::c_double,
    a: *const libc::c_double,
    lda: *const libc::c_int,
    b: *const libc::c_double,
    ldb: *const libc::c_int,
    beta: *const libc::c_double,
    c: *mut libc::c_double,
    ldc: *const libc::c_int,
);

/// BLAS zgemm function pointer type (LP64: 32-bit integers)
///
/// Signature matches Fortran BLAS zgemm:
/// ```c
/// void zgemm_(char *transa, char *transb, int *m, int *n, int *k,
///             void *alpha, void *a, int *lda, void *b, int *ldb,
///             void *beta, void *c, int *ldc);
/// ```
/// Note: all parameters are passed by reference (pointers).
/// Complex numbers are passed as void* (typically complex<double>*).
/// Transpose options: 'N' (no transpose), 'T' (transpose), 'C' (conjugate transpose).
pub type ZgemmFnPtr = unsafe extern "C" fn(
    transa: *const libc::c_char,
    transb: *const libc::c_char,
    m: *const libc::c_int,
    n: *const libc::c_int,
    k: *const libc::c_int,
    alpha: *const num_complex::Complex<f64>,
    a: *const num_complex::Complex<f64>,
    lda: *const libc::c_int,
    b: *const num_complex::Complex<f64>,
    ldb: *const libc::c_int,
    beta: *const num_complex::Complex<f64>,
    c: *mut num_complex::Complex<f64>,
    ldc: *const libc::c_int,
);

// When using system BLAS via `blas-sys`, we need a small wrapper to adapt
// `blas_sys::zgemm_` (which uses `c_double_complex = [f64; 2]`) to the
// `ZgemmFnPtr` signature that takes `num_complex::Complex<f64>`.
#[cfg(feature = "system-blas")]
unsafe extern "C" fn zgemm_wrapper(
    transa: *const libc::c_char,
    transb: *const libc::c_char,
    m: *const libc::c_int,
    n: *const libc::c_int,
    k: *const libc::c_int,
    alpha: *const num_complex::Complex<f64>,
    a: *const num_complex::Complex<f64>,
    lda: *const libc::c_int,
    b: *const num_complex::Complex<f64>,
    ldb: *const libc::c_int,
    beta: *const num_complex::Complex<f64>,
    c: *mut num_complex::Complex<f64>,
    ldc: *const libc::c_int,
) {
    // Safety: `blas_sys::c_double_complex` is defined as `[f64; 2]` and is
    // layout-compatible with `num_complex::Complex<f64>` in memory, so we can
    // cast between the two pointer types here.
    unsafe {
        blas_sys::zgemm_(
            transa,
            transb,
            m,
            n,
            k,
            alpha as *const _ as *const blas_sys::c_double_complex,
            a as *const _ as *const blas_sys::c_double_complex,
            lda,
            b as *const _ as *const blas_sys::c_double_complex,
            ldb,
            beta as *const _ as *const blas_sys::c_double_complex,
            c as *mut _ as *mut blas_sys::c_double_complex,
            ldc,
        );
    }
}

/// BLAS dgemm function pointer type (ILP64: 64-bit integers)
///
/// Signature matches Fortran BLAS dgemm (ILP64):
/// ```c
/// void dgemm_(char *transa, char *transb, long long *m, long long *n, long long *k,
///             double *alpha, double *a, long long *lda, double *b, long long *ldb,
///             double *beta, double *c, long long *ldc);
/// ```
pub type Dgemm64FnPtr = unsafe extern "C" fn(
    transa: *const libc::c_char,
    transb: *const libc::c_char,
    m: *const i64,
    n: *const i64,
    k: *const i64,
    alpha: *const libc::c_double,
    a: *const libc::c_double,
    lda: *const i64,
    b: *const libc::c_double,
    ldb: *const i64,
    beta: *const libc::c_double,
    c: *mut libc::c_double,
    ldc: *const i64,
);

/// BLAS zgemm function pointer type (ILP64: 64-bit integers)
///
/// Signature matches Fortran BLAS zgemm (ILP64):
/// ```c
/// void zgemm_(char *transa, char *transb, long long *m, long long *n, long long *k,
///             void *alpha, void *a, long long *lda, void *b, long long *ldb,
///             void *beta, void *c, long long *ldc);
/// ```
pub type Zgemm64FnPtr = unsafe extern "C" fn(
    transa: *const libc::c_char,
    transb: *const libc::c_char,
    m: *const i64,
    n: *const i64,
    k: *const i64,
    alpha: *const num_complex::Complex<f64>,
    a: *const num_complex::Complex<f64>,
    lda: *const i64,
    b: *const num_complex::Complex<f64>,
    ldb: *const i64,
    beta: *const num_complex::Complex<f64>,
    c: *mut num_complex::Complex<f64>,
    ldc: *const i64,
);

//==============================================================================
// GemmBackend Trait
//==============================================================================

/// GEMM backend trait for runtime dispatch
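///
/// The sketch below shows the shape of a custom backend. It is a hedged,
/// illustrative naive triple-loop implementation, not an optimized GEMM;
/// real backends should delegate to Faer or an external BLAS.
///
/// ```ignore
/// struct NaiveBackend;
///
/// impl GemmBackend for NaiveBackend {
///     unsafe fn dgemm(&self, m: usize, n: usize, k: usize,
///                     a: *const f64, b: *const f64, c: *mut f64) {
///         // Row-major indexing: A[i][l] = a[i*k + l], B[l][j] = b[l*n + j],
///         // C[i][j] = c[i*n + j]
///         for i in 0..m {
///             for j in 0..n {
///                 let mut acc = 0.0;
///                 for l in 0..k {
///                     acc += *a.add(i * k + l) * *b.add(l * n + j);
///                 }
///                 *c.add(i * n + j) = acc;
///             }
///         }
///     }
///
///     unsafe fn zgemm(&self, m: usize, n: usize, k: usize,
///                     a: *const num_complex::Complex<f64>,
///                     b: *const num_complex::Complex<f64>,
///                     c: *mut num_complex::Complex<f64>) {
///         // Same loop structure as dgemm (elided)
///         unimplemented!()
///     }
///
///     fn name(&self) -> &'static str { "Naive (example)" }
/// }
/// ```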
pub trait GemmBackend: Send + Sync {
    /// Matrix multiplication: C = A * B (f64)
    ///
    /// # Arguments
    /// * `m`, `n`, `k` - Matrix dimensions: (M x K) * (K x N) = (M x N)
    /// * `a` - Pointer to matrix A (row-major, M x K)
    /// * `b` - Pointer to matrix B (row-major, K x N)
    /// * `c` - Pointer to output matrix C (row-major, M x N)
    ///
    /// Leading dimensions are derived internally from the row-major to
    /// column-major conversion.
    unsafe fn dgemm(&self, m: usize, n: usize, k: usize, a: *const f64, b: *const f64, c: *mut f64);

    /// Matrix multiplication: C = A * B (Complex<f64>)
    ///
    /// # Arguments
    /// * `m`, `n`, `k` - Matrix dimensions: (M x K) * (K x N) = (M x N)
    /// * `a` - Pointer to matrix A (row-major, M x K)
    /// * `b` - Pointer to matrix B (row-major, K x N)
    /// * `c` - Pointer to output matrix C (row-major, M x N)
    ///
    /// Leading dimensions are derived internally from the row-major to
    /// column-major conversion.
    unsafe fn zgemm(
        &self,
        m: usize,
        n: usize,
        k: usize,
        a: *const num_complex::Complex<f64>,
        b: *const num_complex::Complex<f64>,
        c: *mut num_complex::Complex<f64>,
    );

    /// Returns true if this backend uses 64-bit integers (ILP64)
    fn is_ilp64(&self) -> bool {
        false
    }

    /// Returns backend name for debugging
    fn name(&self) -> &'static str;
}

//==============================================================================
// Faer Backend (Default, Pure Rust, Zero-Copy)
//==============================================================================

/// Default Faer backend (Pure Rust, no external dependencies)
///
/// This implementation uses faer's native API with raw pointer views,
/// achieving zero-copy matrix multiplication without intermediate allocations.
struct FaerBackend;

impl GemmBackend for FaerBackend {
    unsafe fn dgemm(
        &self,
        m: usize,
        n: usize,
        k: usize,
        a: *const f64,
        b: *const f64,
        c: *mut f64,
    ) {
        use faer::linalg::matmul::matmul;
        use faer::mat::{MatMut, MatRef};
        use faer::{Accum, Par};

        // Create views directly from raw pointers (zero-copy).
        // Row-major layout: row_stride = number of columns, col_stride = 1.
        let lhs = unsafe { MatRef::from_raw_parts(a, m, k, k as isize, 1) };
        let rhs = unsafe { MatRef::from_raw_parts(b, k, n, n as isize, 1) };
        let mut dst = unsafe { MatMut::from_raw_parts_mut(c, m, n, n as isize, 1) };

        // In-place matrix multiplication (no intermediate allocations)
        matmul(&mut dst, Accum::Replace, &lhs, &rhs, 1.0, Par::Seq);
    }

    unsafe fn zgemm(
        &self,
        m: usize,
        n: usize,
        k: usize,
        a: *const num_complex::Complex<f64>,
        b: *const num_complex::Complex<f64>,
        c: *mut num_complex::Complex<f64>,
    ) {
        use faer::linalg::matmul::matmul;
        use faer::mat::{MatMut, MatRef};
        use faer::{Accum, Par};

        // Create views directly from raw pointers (zero-copy).
        // Row-major layout: row_stride = number of columns, col_stride = 1.
        let lhs = unsafe { MatRef::from_raw_parts(a, m, k, k as isize, 1) };
        let rhs = unsafe { MatRef::from_raw_parts(b, k, n, n as isize, 1) };
        let mut dst = unsafe { MatMut::from_raw_parts_mut(c, m, n, n as isize, 1) };

        // In-place matrix multiplication (no intermediate allocations)
        matmul(
            &mut dst,
            Accum::Replace,
            &lhs,
            &rhs,
            num_complex::Complex::new(1.0, 0.0),
            Par::Seq,
        );
    }

    fn name(&self) -> &'static str {
        "Faer (Pure Rust)"
    }
}

//==============================================================================
// External BLAS Backends (LP64 and ILP64)
//==============================================================================

/// Conversion rules for row-major data to column-major BLAS:
///
/// **Goal**: Compute C = A * B where:
///   - A is m×k (row-major)
///   - B is k×n (row-major)
///   - C is m×n (row-major)
///
/// **Row-major to column-major interpretation**:
///   - Row-major A (m×k) appears as A^T (k×m) in column-major → call this At
///   - Row-major B (k×n) appears as B^T (n×k) in column-major → call this Bt
///   - Row-major C (m×n) appears as C^T (n×m) in column-major → call this Ct
///   - To compute C = A * B, we need: C^T = (A * B)^T = B^T * A^T
///   - So: Ct = Bt * At
///
/// **BLAS call transformation**:
///   - Original: C = A * B (row-major world)
///   - BLAS call: Ct = Bt * At (column-major world)
///   - transa = 'N' (Bt already has the transposed layout; no transpose needed)
///   - transb = 'N' (At already has the transposed layout; no transpose needed)
///   - Call: dgemm('N', 'N', n, m, k, alpha, B, lda, A, ldb, beta, C, ldc)
///
/// **Dimension conversions**:
///   - m_blas = n (Ct rows = Bt rows)
///   - n_blas = m (Ct cols = At cols)
///   - k_blas = k (common dimension)
///   - lda = n (leading dimension of Bt: n×k in column-major, lda = n)
///   - ldb = k (leading dimension of At: k×m in column-major, ldb = k)
///   - ldc = n (leading dimension of Ct: n×m in column-major, ldc = n)
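///
/// **Worked sketch** (hedged; `dgemm_` stands for any LP64 Fortran dgemm):
/// computing C (2×2) = A (2×3) * B (3×2), all row-major, the swapped call is:
///
/// ```ignore
/// let (m, n, k) = (2i32, 2i32, 3i32);
/// let trans = b'N' as libc::c_char;
/// let (alpha, beta) = (1.0f64, 0.0f64);
/// unsafe {
///     dgemm_(&trans, &trans,
///            &n, &m, &k,           // m_blas = n, n_blas = m, k_blas = k
///            &alpha,
///            b.as_ptr(), &n,       // B passed first (Bt), lda = n
///            a.as_ptr(), &k,       // A passed second (At), ldb = k
///            &beta,
///            c.as_mut_ptr(), &n);  // ldc = n
/// }
/// // c now holds C = A * B in row-major order.
/// ```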

/// External BLAS backend (LP64: 32-bit integers)
pub struct ExternalBlasBackend {
    dgemm: DgemmFnPtr,
    zgemm: ZgemmFnPtr,
}

impl ExternalBlasBackend {
    pub fn new(dgemm: DgemmFnPtr, zgemm: ZgemmFnPtr) -> Self {
        Self { dgemm, zgemm }
    }
}

impl GemmBackend for ExternalBlasBackend {
    unsafe fn dgemm(
        &self,
        m: usize,
        n: usize,
        k: usize,
        a: *const f64,
        b: *const f64,
        c: *mut f64,
    ) {
        // Validate that the dimensions fit in i32
        assert!(
            m <= i32::MAX as usize,
            "Matrix dimension m too large for LP64 BLAS"
        );
        assert!(
            n <= i32::MAX as usize,
            "Matrix dimension n too large for LP64 BLAS"
        );
        assert!(
            k <= i32::MAX as usize,
            "Matrix dimension k too large for LP64 BLAS"
        );

        // Fortran BLAS requires all parameters passed by reference.
        // Apply the row-major to column-major conversion (see rules above).
        let transa = b'N' as libc::c_char; // Bt already has the transposed layout
        let transb = b'N' as libc::c_char; // At already has the transposed layout
        let m_i32 = n as i32; // m_blas = n (Ct rows = Bt rows)
        let n_i32 = m as i32; // n_blas = m (Ct cols = At cols)
        let k_i32 = k as i32; // k_blas = k (common dimension)
        let alpha = 1.0f64;
        let lda = n as i32; // lda = n (leading dimension of Bt: n×k in column-major)
        let ldb = k as i32; // ldb = k (leading dimension of At: k×m in column-major)
        let beta = 0.0f64;
        // For row-major C (m×n) viewed as column-major Ct (n×m):
        // the leading dimension in column-major is the stride between rows.
        // In row-major, the stride between rows = number of columns = n,
        // so ldc = n (the number of columns of the original row-major matrix).
        let ldc_i32 = n as i32; // ldc = n (leading dimension of Ct: n×m in column-major)

        unsafe {
            (self.dgemm)(
                &transa, &transb, &m_i32, &n_i32, &k_i32, &alpha, b, // B first (Bt)
                &lda, a, // A second (At)
                &ldb, &beta, c, &ldc_i32,
            );
        }
    }

    unsafe fn zgemm(
        &self,
        m: usize,
        n: usize,
        k: usize,
        a: *const num_complex::Complex<f64>,
        b: *const num_complex::Complex<f64>,
        c: *mut num_complex::Complex<f64>,
    ) {
        assert!(
            m <= i32::MAX as usize,
            "Matrix dimension m too large for LP64 BLAS"
        );
        assert!(
            n <= i32::MAX as usize,
            "Matrix dimension n too large for LP64 BLAS"
        );
        assert!(
            k <= i32::MAX as usize,
            "Matrix dimension k too large for LP64 BLAS"
        );

        // Fortran BLAS requires all parameters passed by reference.
        // Apply the row-major to column-major conversion (see rules above).
        let transa = b'N' as libc::c_char; // Bt already has the transposed layout
        let transb = b'N' as libc::c_char; // At already has the transposed layout
        let m_i32 = n as i32; // m_blas = n (Ct rows = Bt rows)
        let n_i32 = m as i32; // n_blas = m (Ct cols = At cols)
        let k_i32 = k as i32; // k_blas = k (common dimension)
        let alpha = num_complex::Complex::new(1.0, 0.0);
        let lda = n as i32; // lda = n (leading dimension of Bt: n×k in column-major)
        let ldb = k as i32; // ldb = k (leading dimension of At: k×m in column-major)
        let beta = num_complex::Complex::new(0.0, 0.0);
        // For row-major C (m×n) viewed as column-major Ct (n×m):
        // the leading dimension in column-major is the stride between rows = n.
        let ldc_i32 = n as i32; // ldc = n (leading dimension of Ct: n×m in column-major)

        unsafe {
            (self.zgemm)(
                &transa,
                &transb,
                &m_i32,
                &n_i32,
                &k_i32,
                &alpha,
                b as *const _, // B first (Bt)
                &lda,
                a as *const _, // A second (At)
                &ldb,
                &beta,
                c as *mut _,
                &ldc_i32,
            );
        }
    }

    fn name(&self) -> &'static str {
        "External BLAS (LP64)"
    }
}

/// External BLAS backend (ILP64: 64-bit integers)
pub struct ExternalBlas64Backend {
    dgemm64: Dgemm64FnPtr,
    zgemm64: Zgemm64FnPtr,
}

impl ExternalBlas64Backend {
    pub fn new(dgemm64: Dgemm64FnPtr, zgemm64: Zgemm64FnPtr) -> Self {
        Self { dgemm64, zgemm64 }
    }
}

impl GemmBackend for ExternalBlas64Backend {
    unsafe fn dgemm(
        &self,
        m: usize,
        n: usize,
        k: usize,
        a: *const f64,
        b: *const f64,
        c: *mut f64,
    ) {
        // Fortran BLAS requires all parameters passed by reference.
        // Apply the row-major to column-major conversion (see rules above).
        let transa = b'N' as libc::c_char; // Bt already has the transposed layout
        let transb = b'N' as libc::c_char; // At already has the transposed layout
        let m_i64 = n as i64; // m_blas = n (Ct rows = Bt rows)
        let n_i64 = m as i64; // n_blas = m (Ct cols = At cols)
        let k_i64 = k as i64; // k_blas = k (common dimension)
        let alpha = 1.0f64;
        let lda = n as i64; // lda = n (leading dimension of Bt: n×k in column-major)
        let ldb = k as i64; // ldb = k (leading dimension of At: k×m in column-major)
        let beta = 0.0f64;
        // For row-major C (m×n) viewed as column-major Ct (n×m):
        // the leading dimension in column-major is the stride between rows = n.
        let ldc_i64 = n as i64; // ldc = n (leading dimension of Ct: n×m in column-major)

        unsafe {
            (self.dgemm64)(
                &transa, &transb, &m_i64, &n_i64, &k_i64, &alpha, b, // B first (Bt)
                &lda, a, // A second (At)
                &ldb, &beta, c, &ldc_i64,
            );
        }
    }

    unsafe fn zgemm(
        &self,
        m: usize,
        n: usize,
        k: usize,
        a: *const num_complex::Complex<f64>,
        b: *const num_complex::Complex<f64>,
        c: *mut num_complex::Complex<f64>,
    ) {
        // Fortran BLAS requires all parameters passed by reference.
        // Apply the row-major to column-major conversion (see rules above).
        let transa = b'N' as libc::c_char; // Bt already has the transposed layout
        let transb = b'N' as libc::c_char; // At already has the transposed layout
        let m_i64 = n as i64; // m_blas = n (Ct rows = Bt rows)
        let n_i64 = m as i64; // n_blas = m (Ct cols = At cols)
        let k_i64 = k as i64; // k_blas = k (common dimension)
        let alpha = num_complex::Complex::new(1.0, 0.0);
        let lda = n as i64; // lda = n (leading dimension of Bt: n×k in column-major)
        let ldb = k as i64; // ldb = k (leading dimension of At: k×m in column-major)
        let beta = num_complex::Complex::new(0.0, 0.0);
        // For row-major C (m×n) viewed as column-major Ct (n×m):
        // the leading dimension in column-major is the stride between rows = n.
        let ldc_i64 = n as i64; // ldc = n (leading dimension of Ct: n×m in column-major)

        unsafe {
            (self.zgemm64)(
                &transa,
                &transb,
                &m_i64,
                &n_i64,
                &k_i64,
                &alpha,
                b as *const _, // B first (Bt)
                &lda,
                a as *const _, // A second (At)
                &ldb,
                &beta,
                c as *mut _,
                &ldc_i64,
            );
        }
    }

    fn is_ilp64(&self) -> bool {
        true
    }

    fn name(&self) -> &'static str {
        "External BLAS (ILP64)"
    }
}

//==============================================================================
// Backend Handle
//==============================================================================

/// Thread-safe handle to a GEMM backend
///
/// This type wraps an `Arc<dyn GemmBackend>` to allow sharing a backend
/// across multiple function calls without global state.
///
/// # Example
/// ```ignore
/// use sparse_ir::gemm::GemmBackendHandle;
///
/// let backend = GemmBackendHandle::default();
/// let result = matmul_par(&a, &b, Some(&backend));
/// ```
#[derive(Clone)]
pub struct GemmBackendHandle {
    inner: Arc<dyn GemmBackend>,
}

impl GemmBackendHandle {
    /// Create a new backend handle from a boxed backend
    pub fn new(backend: Box<dyn GemmBackend>) -> Self {
        Self {
            inner: Arc::from(backend),
        }
    }

    /// Create a default backend handle (Faer backend)
    pub fn default() -> Self {
        Self {
            inner: Arc::new(FaerBackend),
        }
    }

    /// Get a reference to the inner backend
    pub(crate) fn as_ref(&self) -> &dyn GemmBackend {
        self.inner.as_ref()
    }
}

//==============================================================================
// Global Dispatcher (for backward compatibility)
//==============================================================================

/// Global BLAS dispatcher (thread-safe)
///
/// This is kept for backward compatibility when `None` is passed as the backend.
/// New code should use `GemmBackendHandle` explicitly.
static BLAS_DISPATCHER: Lazy<RwLock<Box<dyn GemmBackend>>> = Lazy::new(|| {
    #[cfg(feature = "system-blas")]
    {
        // Use system BLAS (LP64) by default via `blas-sys`.
        let backend = ExternalBlasBackend::new(dgemm_ as DgemmFnPtr, zgemm_wrapper as ZgemmFnPtr);
        RwLock::new(Box::new(backend) as Box<dyn GemmBackend>)
    }
    #[cfg(not(feature = "system-blas"))]
    {
        // Default to the pure Rust Faer backend.
        RwLock::new(Box::new(FaerBackend) as Box<dyn GemmBackend>)
    }
});

/// Set BLAS backend (LP64: 32-bit integers)
///
/// # Safety
/// - Function pointers must be valid and thread-safe
/// - Must remain valid for the lifetime of the program
/// - Must follow the Fortran BLAS calling convention
///
/// # Example
/// ```ignore
/// unsafe {
///     set_blas_backend(dgemm_ as _, zgemm_ as _);
/// }
/// ```
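///
/// A hedged sketch of obtaining the pointers at runtime (assumes the
/// `libloading` crate and an LP64 OpenBLAS shared library; the library
/// name and symbol names are illustrative):
/// ```ignore
/// use libloading::{Library, Symbol};
///
/// let lib = unsafe { Library::new("libopenblas.so") }?;
/// let dgemm: Symbol<DgemmFnPtr> = unsafe { lib.get(b"dgemm_\0") }?;
/// let zgemm: Symbol<ZgemmFnPtr> = unsafe { lib.get(b"zgemm_\0") }?;
/// unsafe { set_blas_backend(*dgemm, *zgemm) };
/// // Keep `lib` alive for as long as the backend is installed.
/// std::mem::forget(lib);
/// ```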
pub unsafe fn set_blas_backend(dgemm: DgemmFnPtr, zgemm: ZgemmFnPtr) {
    let backend = ExternalBlasBackend::new(dgemm, zgemm);
    let mut dispatcher = BLAS_DISPATCHER.write().unwrap();
    *dispatcher = Box::new(backend);
}

/// Set ILP64 BLAS backend (64-bit integers)
///
/// # Safety
/// - Function pointers must be valid, thread-safe, and use 64-bit integers
/// - Must remain valid for the lifetime of the program
/// - Must follow the Fortran BLAS calling convention with the ILP64 interface
///
/// # Example
/// ```ignore
/// unsafe {
///     set_ilp64_backend(dgemm_ as _, zgemm_ as _);
/// }
/// ```
pub unsafe fn set_ilp64_backend(dgemm64: Dgemm64FnPtr, zgemm64: Zgemm64FnPtr) {
    let backend = ExternalBlas64Backend::new(dgemm64, zgemm64);
    let mut dispatcher = BLAS_DISPATCHER.write().unwrap();
    *dispatcher = Box::new(backend);
}

/// Clear BLAS backend (reset to the default Faer backend)
///
/// This function resets the GEMM dispatcher to the default pure Rust Faer backend.
pub fn clear_blas_backend() {
    let mut dispatcher = BLAS_DISPATCHER.write().unwrap();
    *dispatcher = Box::new(FaerBackend);
}

/// Get current BLAS backend information
///
/// Returns:
/// - `(backend_name, is_external, is_ilp64)`
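///
/// # Example
/// ```ignore
/// let (name, is_external, is_ilp64) = get_backend_info();
/// println!("GEMM backend: {name} (external: {is_external}, ILP64: {is_ilp64})");
/// ```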
pub fn get_backend_info() -> (&'static str, bool, bool) {
    let dispatcher = BLAS_DISPATCHER.read().unwrap();
    let name = dispatcher.name();
    let is_external = !name.contains("Faer");
    let is_ilp64 = dispatcher.is_ilp64();
    (name, is_external, is_ilp64)
}

//==============================================================================
// Public API
//==============================================================================

/// Parallel matrix multiplication: C = A * B
///
/// Dispatches to the provided backend, or to the global dispatcher if `None`.
///
/// # Arguments
/// * `a` - Left matrix (M x K)
/// * `b` - Right matrix (K x N)
/// * `backend` - Optional backend handle. If `None`, uses the global dispatcher (for backward compatibility)
///
/// # Returns
/// Result matrix (M x N)
///
/// # Panics
/// Panics if matrix dimensions are incompatible (A.cols != B.rows)
///
/// # Example
/// ```ignore
/// use mdarray::tensor;
/// use sparse_ir::gemm::{matmul_par, GemmBackendHandle};
///
/// let a = tensor![[1.0, 2.0], [3.0, 4.0]];
/// let b = tensor![[5.0, 6.0], [7.0, 8.0]];
/// let backend = GemmBackendHandle::default();
/// let c = matmul_par(&a, &b, Some(&backend));
/// // c = [[19.0, 22.0], [43.0, 50.0]]
/// ```
pub fn matmul_par<T>(
    a: &DTensor<T, 2>,
    b: &DTensor<T, 2>,
    backend: Option<&GemmBackendHandle>,
) -> DTensor<T, 2>
where
    T: num_complex::ComplexFloat + faer_traits::ComplexField + num_traits::One + Copy + 'static,
{
    let (m, k) = *a.shape();
    let (k2, n) = *b.shape();

    // Validate dimensions
    assert_eq!(
        k, k2,
        "Matrix dimension mismatch: A.cols ({}) != B.rows ({})",
        k, k2
    );

    // Allocate the m x n result tensor and delegate to the overwrite variant,
    // which dispatches to the selected backend without intermediate copies.
    let mut result = DTensor::<T, 2>::from_elem([m, n], T::zero());
    matmul_par_overwrite(a, b, &mut result, backend);
    result
}

/// Parallel matrix multiplication accepting DView (assumes contiguous memory)
///
/// This function accepts `DView` instead of `DTensor`, allowing views of arrays
/// to be used directly without copying. The views must have a contiguous memory layout.
///
/// # Arguments
/// * `a` - Left matrix view (M x K) - must be contiguous
/// * `b` - Right matrix view (K x N) - must be contiguous
/// * `backend` - Optional backend handle. If `None`, uses the global dispatcher
///
/// # Panics
/// Panics if:
/// - Matrix dimensions are incompatible (A.cols != B.rows)
/// - Views are not contiguous in memory
///
/// # Example
/// ```ignore
/// use mdarray::{tensor, DView};
///
/// let a = tensor![[1.0, 2.0], [3.0, 4.0]];
/// let b = tensor![[5.0, 6.0], [7.0, 8.0]];
/// let a_view: DView<'_, f64, 2> = a.view(.., ..);
/// let b_view: DView<'_, f64, 2> = b.view(.., ..);
/// let c = matmul_par_view(&a_view, &b_view, None);
/// ```
pub fn matmul_par_view<T>(
    a: &DView<'_, T, 2>,
    b: &DView<'_, T, 2>,
    backend: Option<&GemmBackendHandle>,
) -> DTensor<T, 2>
where
    T: num_complex::ComplexFloat + faer_traits::ComplexField + num_traits::One + Copy + 'static,
{
    // Check that the views are contiguous (required for BLAS operations)
    assert!(
        a.is_contiguous(),
        "Matrix A view must be contiguous in memory"
    );
    assert!(
        b.is_contiguous(),
        "Matrix B view must be contiguous in memory"
    );

    let (m, k) = *a.shape();
    let (k2, n) = *b.shape();

    // Validate dimensions
    assert_eq!(
        k, k2,
        "Matrix dimension mismatch: A.cols ({}) != B.rows ({})",
        k, k2
    );

    // Create the result tensor
    let mut result = DTensor::<T, 2>::from_elem([m, n], T::zero());
    matmul_par_overwrite_view(a, b, &mut result, backend);
    result
}

/// Parallel matrix multiplication with overwrite accepting DView (assumes contiguous memory)
///
/// This function writes the result directly into the provided buffer `c`,
/// accepting `DView` inputs. The views must have a contiguous memory layout.
///
/// # Arguments
/// * `a` - Left matrix view (M x K) - must be contiguous
/// * `b` - Right matrix view (K x N) - must be contiguous
/// * `c` - Output matrix (M x N) - will be overwritten with the result
/// * `backend` - Optional backend handle. If `None`, uses the global dispatcher
///
/// # Panics
/// Panics if:
/// - Matrix dimensions are incompatible
/// - Views are not contiguous in memory
pub fn matmul_par_overwrite_view<T>(
    a: &DView<'_, T, 2>,
    b: &DView<'_, T, 2>,
    c: &mut DTensor<T, 2>,
    backend: Option<&GemmBackendHandle>,
) where
    T: num_complex::ComplexFloat + faer_traits::ComplexField + num_traits::One + Copy + 'static,
{
    // Check that the views are contiguous (required for BLAS operations)
    assert!(
        a.is_contiguous(),
        "Matrix A view must be contiguous in memory"
    );
    assert!(
        b.is_contiguous(),
        "Matrix B view must be contiguous in memory"
    );

    let (m, k) = *a.shape();
    let (k2, n) = *b.shape();
    let (mc, nc) = *c.shape();

    // Validate dimensions
    assert_eq!(
        k, k2,
        "Matrix dimension mismatch: A.cols ({}) != B.rows ({})",
        k, k2
    );
    assert_eq!(
        m, mc,
        "Output matrix dimension mismatch: C.rows ({}) != A.rows ({})",
        mc, m
    );
    assert_eq!(
        n, nc,
        "Output matrix dimension mismatch: C.cols ({}) != B.cols ({})",
        nc, n
    );

    // Type dispatch: f64 or Complex<f64>
    if std::any::TypeId::of::<T>() == std::any::TypeId::of::<f64>() {
        // f64 case: get pointers from the views (contiguous memory assumed)
        let a_ptr = a.as_ptr() as *const f64;
        let b_ptr = b.as_ptr() as *const f64;
        let c_ptr = c.as_mut_ptr() as *mut f64;

        // Use the provided handle or fall back to the global dispatcher
        match backend {
            Some(handle) => unsafe {
                handle.as_ref().dgemm(m, n, k, a_ptr, b_ptr, c_ptr);
            },
            None => {
                let dispatcher = BLAS_DISPATCHER.read().unwrap();
                unsafe {
                    dispatcher.dgemm(m, n, k, a_ptr, b_ptr, c_ptr);
                }
            }
        }
    } else if std::any::TypeId::of::<T>() == std::any::TypeId::of::<num_complex::Complex<f64>>() {
        // Complex<f64> case
        let a_ptr = a.as_ptr() as *const num_complex::Complex<f64>;
        let b_ptr = b.as_ptr() as *const num_complex::Complex<f64>;
        let c_ptr = c.as_mut_ptr() as *mut num_complex::Complex<f64>;

        match backend {
            Some(handle) => unsafe {
                handle.as_ref().zgemm(m, n, k, a_ptr, b_ptr, c_ptr);
            },
            None => {
                let dispatcher = BLAS_DISPATCHER.read().unwrap();
                unsafe {
                    dispatcher.zgemm(m, n, k, a_ptr, b_ptr, c_ptr);
                }
            }
        }
    } else {
        // Fallback to Faer for unsupported types.
        // Convert the views to DTensors for Faer (this copies, but only for unsupported types).
        let a_tensor = DTensor::<T, 2>::from_fn(*a.shape(), |idx| a[idx]);
        let b_tensor = DTensor::<T, 2>::from_fn(*b.shape(), |idx| b[idx]);
        use mdarray_linalg::matmul::MatMulBuilder;
        use mdarray_linalg::prelude::MatMul;
        use mdarray_linalg_faer::Faer;

        Faer.matmul(&a_tensor, &b_tensor).parallelize().overwrite(c);
    }
}

/// Parallel matrix multiplication with overwrite to a mutable view: C = A * B
///
/// This function writes the result directly into the provided mutable view `c`,
/// allowing zero-copy writes to pre-allocated buffers (e.g., C pointers via FFI).
///
/// # Arguments
/// * `a` - Left matrix view (M x K), must be contiguous
/// * `b` - Right matrix view (K x N), must be contiguous
/// * `c` - Output mutable view (M x N), must be contiguous - will be overwritten
/// * `backend` - Optional backend handle. If `None`, uses the global dispatcher
///
/// # Panics
/// Panics if:
/// - Matrix dimensions are incompatible
/// - Views are not contiguous in memory
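///
/// # Example
/// A hedged sketch (assumes `view`/`view_mut` yield contiguous whole-array
/// views, as in this module's tests):
/// ```ignore
/// use mdarray::tensor;
///
/// let a = tensor![[1.0, 2.0], [3.0, 4.0]];
/// let b = tensor![[5.0, 6.0], [7.0, 8.0]];
/// let mut c = DTensor::<f64, 2>::from_elem([2, 2], 0.0);
/// matmul_par_to_viewmut(&a.view(.., ..), &b.view(.., ..), &mut c.view_mut(.., ..), None);
/// // c = [[19.0, 22.0], [43.0, 50.0]]
/// ```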
pub fn matmul_par_to_viewmut<T>(
    a: &DView<'_, T, 2>,
    b: &DView<'_, T, 2>,
    c: &mut DViewMut<'_, T, 2>,
    backend: Option<&GemmBackendHandle>,
) where
    T: num_complex::ComplexFloat + faer_traits::ComplexField + num_traits::One + Copy + 'static,
{
    // Check that the views are contiguous (required for BLAS operations)
    assert!(
        a.is_contiguous(),
        "Matrix A view must be contiguous in memory"
    );
    assert!(
        b.is_contiguous(),
        "Matrix B view must be contiguous in memory"
    );
    assert!(
        c.is_contiguous(),
        "Matrix C view must be contiguous in memory"
    );

    let (m, k) = *a.shape();
    let (k2, n) = *b.shape();
    let (mc, nc) = *c.shape();

    // Validate dimensions
    assert_eq!(
        k, k2,
        "Matrix dimension mismatch: A.cols ({}) != B.rows ({})",
        k, k2
    );
    assert_eq!(
        m, mc,
        "Output matrix dimension mismatch: C.rows ({}) != A.rows ({})",
        mc, m
    );
    assert_eq!(
        n, nc,
        "Output matrix dimension mismatch: C.cols ({}) != B.cols ({})",
        nc, n
    );

    // Type dispatch: f64 or Complex<f64>
    if std::any::TypeId::of::<T>() == std::any::TypeId::of::<f64>() {
        // f64 case
        let a_ptr = a.as_ptr() as *const f64;
        let b_ptr = b.as_ptr() as *const f64;
        let c_ptr = c.as_mut_ptr() as *mut f64;

        match backend {
            Some(handle) => unsafe {
                handle.as_ref().dgemm(m, n, k, a_ptr, b_ptr, c_ptr);
            },
            None => {
                let dispatcher = BLAS_DISPATCHER.read().unwrap();
                unsafe {
                    dispatcher.dgemm(m, n, k, a_ptr, b_ptr, c_ptr);
                }
            }
        }
    } else if std::any::TypeId::of::<T>() == std::any::TypeId::of::<num_complex::Complex<f64>>() {
        // Complex<f64> case
        let a_ptr = a.as_ptr() as *const num_complex::Complex<f64>;
        let b_ptr = b.as_ptr() as *const num_complex::Complex<f64>;
        let c_ptr = c.as_mut_ptr() as *mut num_complex::Complex<f64>;

        match backend {
            Some(handle) => unsafe {
                handle.as_ref().zgemm(m, n, k, a_ptr, b_ptr, c_ptr);
            },
            None => {
                let dispatcher = BLAS_DISPATCHER.read().unwrap();
                unsafe {
                    dispatcher.zgemm(m, n, k, a_ptr, b_ptr, c_ptr);
                }
            }
        }
    } else {
        // Fallback: convert to DTensors (this copies)
        let a_tensor = DTensor::<T, 2>::from_fn(*a.shape(), |idx| a[idx]);
        let b_tensor = DTensor::<T, 2>::from_fn(*b.shape(), |idx| b[idx]);
        let mut c_tensor = DTensor::<T, 2>::from_fn(*c.shape(), |_| T::zero());
        use mdarray_linalg::matmul::MatMulBuilder;
        use mdarray_linalg::prelude::MatMul;
        use mdarray_linalg_faer::Faer;

        Faer.matmul(&a_tensor, &b_tensor)
            .parallelize()
            .overwrite(&mut c_tensor);

        // Copy the result back into the view
        for i in 0..mc {
            for j in 0..nc {
                c[[i, j]] = c_tensor[[i, j]];
            }
        }
    }
}

/// Parallel matrix multiplication with overwrite: C = A * B (writes to an existing buffer)
///
/// This function writes the result directly into the provided buffer `c`,
/// avoiding memory allocation. This is more memory-efficient for repeated operations.
///
/// # Arguments
/// * `a` - Left matrix (M x K)
/// * `b` - Right matrix (K x N)
/// * `c` - Output matrix (M x N) - will be overwritten with the result
/// * `backend` - Optional backend handle. If `None`, uses the global dispatcher (for backward compatibility)
///
/// # Panics
/// Panics if matrix dimensions are incompatible (A.cols != B.rows or C.shape != [M, N])
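///
/// # Example
/// A minimal sketch of buffer reuse across calls:
/// ```ignore
/// use mdarray::tensor;
///
/// let a: DTensor<f64, 2> = tensor![[1.0, 2.0], [3.0, 4.0]];
/// let b: DTensor<f64, 2> = tensor![[5.0, 6.0], [7.0, 8.0]];
/// let mut c = DTensor::<f64, 2>::from_elem([2, 2], 0.0);
/// matmul_par_overwrite(&a, &b, &mut c, None); // no fresh allocation per call
/// ```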
pub fn matmul_par_overwrite<T, Lc: Layout>(
    a: &DTensor<T, 2>,
    b: &DTensor<T, 2>,
    c: &mut DSlice<T, 2, Lc>,
    backend: Option<&GemmBackendHandle>,
) where
    T: num_complex::ComplexFloat + faer_traits::ComplexField + num_traits::One + Copy + 'static,
{
    let (m, k) = *a.shape();
    let (k2, n) = *b.shape();
    let (mc, nc) = *c.shape();

    // Validate dimensions
    assert_eq!(
        k, k2,
        "Matrix dimension mismatch: A.cols ({}) != B.rows ({})",
        k, k2
    );
    assert_eq!(
        m, mc,
        "Output matrix dimension mismatch: C.rows ({}) != A.rows ({})",
        mc, m
    );
    assert_eq!(
        n, nc,
        "Output matrix dimension mismatch: C.cols ({}) != B.cols ({})",
        nc, n
    );

    // Type dispatch: f64 or Complex<f64>
    if std::any::TypeId::of::<T>() == std::any::TypeId::of::<f64>() {
        // f64 case: get pointers directly from the DTensors (row-major order)
        let a_ptr = a.as_ptr() as *const f64;
        let b_ptr = b.as_ptr() as *const f64;
        let c_ptr = c.as_mut_ptr() as *mut f64;

        // Use the provided handle or fall back to the global dispatcher
        match backend {
            Some(handle) => {
                // Call the backend directly with pointers (no temporary buffer needed).
                // Leading dimensions are derived internally by the backend.
                unsafe {
                    handle.as_ref().dgemm(m, n, k, a_ptr, b_ptr, c_ptr);
                }
            }
            None => {
                // Backward compatibility: use the global dispatcher
                let dispatcher = BLAS_DISPATCHER.read().unwrap();
                unsafe {
                    dispatcher.dgemm(m, n, k, a_ptr, b_ptr, c_ptr);
                }
            }
        }
    } else if std::any::TypeId::of::<T>() == std::any::TypeId::of::<num_complex::Complex<f64>>() {
        // Complex<f64> case: get pointers directly from the DTensors (row-major order)
        let a_ptr = a.as_ptr() as *const num_complex::Complex<f64>;
        let b_ptr = b.as_ptr() as *const num_complex::Complex<f64>;
        let c_ptr = c.as_mut_ptr() as *mut num_complex::Complex<f64>;

        // Use the provided handle or fall back to the global dispatcher
        match backend {
            Some(handle) => {
                // Call the backend directly with pointers (no temporary buffer needed).
                // Leading dimensions are derived internally by the backend.
                unsafe {
                    handle.as_ref().zgemm(m, n, k, a_ptr, b_ptr, c_ptr);
                }
            }
            None => {
                // Backward compatibility: use the global dispatcher
                let dispatcher = BLAS_DISPATCHER.read().unwrap();
                unsafe {
                    dispatcher.zgemm(m, n, k, a_ptr, b_ptr, c_ptr);
                }
            }
        }
    } else {
        // Fallback to Faer for unsupported types
        use mdarray_linalg::matmul::MatMulBuilder;
        use mdarray_linalg::prelude::MatMul;
        use mdarray_linalg_faer::Faer;

        Faer.matmul(a, b).parallelize().overwrite(c);
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use mdarray::DView;

    #[test]
    #[cfg(not(feature = "system-blas"))]
    fn test_default_backend_is_faer() {
        let (name, is_external, is_ilp64) = get_backend_info();
        assert_eq!(name, "Faer (Pure Rust)");
        assert!(!is_external);
        assert!(!is_ilp64);
    }

    #[test]
    fn test_matmul_par_view() {
        // Test with f64
        let a = DTensor::<f64, 2>::from([[1.0, 2.0], [3.0, 4.0]]);
        let b = DTensor::<f64, 2>::from([[5.0, 6.0], [7.0, 8.0]]);
        let a_view: DView<'_, f64, 2> = a.view(.., ..);
        let b_view: DView<'_, f64, 2> = b.view(.., ..);

        let c_view = matmul_par_view(&a_view, &b_view, None);
        let c_expected = matmul_par(&a, &b, None);

        // Results should be identical
        assert_eq!(c_view.shape(), c_expected.shape());
        for i in 0..c_view.shape().0 {
            for j in 0..c_view.shape().1 {
                assert!((c_view[[i, j]] - c_expected[[i, j]]).abs() < 1e-10);
            }
        }
    }

    #[test]
    fn test_matmul_par_overwrite_view() {
        // Test with Complex<f64>
        use num_complex::Complex;
        let a = DTensor::<Complex<f64>, 2>::from_fn([2, 2], |idx| {
            Complex::new((idx[0] * 2 + idx[1]) as f64, 0.0)
        });
        let b = DTensor::<Complex<f64>, 2>::from_fn([2, 2], |idx| {
            Complex::new((idx[0] * 2 + idx[1] + 10) as f64, 0.0)
        });
        let a_view: DView<'_, Complex<f64>, 2> = a.view(.., ..);
        let b_view: DView<'_, Complex<f64>, 2> = b.view(.., ..);

        let mut c = DTensor::<Complex<f64>, 2>::from_elem([2, 2], Complex::new(0.0, 0.0));
        matmul_par_overwrite_view(&a_view, &b_view, &mut c, None);

        let c_expected = matmul_par(&a, &b, None);

        // Results should be identical
        assert_eq!(c.shape(), c_expected.shape());
        for i in 0..c.shape().0 {
            for j in 0..c.shape().1 {
                assert!((c[[i, j]] - c_expected[[i, j]]).norm() < 1e-10);
            }
        }
    }

    #[test]
    fn test_clear_backend() {
        // Should not panic
        clear_blas_backend();
        let (name, _, _) = get_backend_info();
        assert_eq!(name, "Faer (Pure Rust)");
    }

    #[test]
    fn test_matmul_f64() {
        let a_data = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let b_data = [7.0, 8.0, 9.0, 10.0, 11.0, 12.0];

        let a = DTensor::<f64, 2>::from_fn([2, 3], |idx| a_data[idx[0] * 3 + idx[1]]);
        let b = DTensor::<f64, 2>::from_fn([3, 2], |idx| b_data[idx[0] * 2 + idx[1]]);
        let c = matmul_par(&a, &b, None);

        assert_eq!(*c.shape(), (2, 2));
        // First row: [1*7 + 2*9 + 3*11, 1*8 + 2*10 + 3*12] = [58, 64]
        // Second row: [4*7 + 5*9 + 6*11, 4*8 + 5*10 + 6*12] = [139, 154]
        assert!((c[[0, 0]] - 58.0).abs() < 1e-10);
        assert!((c[[0, 1]] - 64.0).abs() < 1e-10);
        assert!((c[[1, 0]] - 139.0).abs() < 1e-10);
        assert!((c[[1, 1]] - 154.0).abs() < 1e-10);
    }

    #[test]
    fn test_matmul_par_basic() {
        use mdarray::tensor;
        let a: DTensor<f64, 2> = tensor![[1.0, 2.0], [3.0, 4.0]];
        let b: DTensor<f64, 2> = tensor![[5.0, 6.0], [7.0, 8.0]];
        let c = matmul_par(&a, &b, None);

        // Expected: [[1*5 + 2*7, 1*6 + 2*8], [3*5 + 4*7, 3*6 + 4*8]]
        //         = [[19, 22], [43, 50]]
        assert!((c[[0, 0]] - 19.0).abs() < 1e-10);
        assert!((c[[0, 1]] - 22.0).abs() < 1e-10);
        assert!((c[[1, 0]] - 43.0).abs() < 1e-10);
        assert!((c[[1, 1]] - 50.0).abs() < 1e-10);
    }

    #[test]
    fn test_matmul_par_non_square() {
        use mdarray::tensor;
        let a: DTensor<f64, 2> = tensor![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]; // 2x3
        let b: DTensor<f64, 2> = tensor![[7.0], [8.0], [9.0]]; // 3x1
        let c = matmul_par(&a, &b, None);

        // Expected: [[1*7 + 2*8 + 3*9], [4*7 + 5*8 + 6*9]]
        //         = [[50], [122]]
        assert!((c[[0, 0]] - 50.0).abs() < 1e-10);
        assert!((c[[1, 0]] - 122.0).abs() < 1e-10);
    }
}