Skip to main content

sklears_simd/
external_integration.rs

1//! External SIMD library integration framework
2//!
3//! This module provides integration capabilities with external high-performance
4//! SIMD libraries such as Intel MKL, OpenBLAS, BLIS, and others.
5
6use crate::traits::SimdError;
7
8#[cfg(feature = "no-std")]
9use alloc::{
10    collections::BTreeMap as HashMap,
11    format,
12    string::{String, ToString},
13    sync::Arc,
14    vec,
15    vec::Vec,
16};
17#[cfg(not(feature = "no-std"))]
18use std::collections::HashMap;
19#[cfg(not(feature = "no-std"))]
20use std::string::ToString;
21#[cfg(not(feature = "no-std"))]
22use std::sync::Arc;
23
24#[cfg(feature = "no-std")]
25use spin::Mutex;
26#[cfg(not(feature = "no-std"))]
27use std::sync::Mutex;
28
/// Result type returned by every external-library operation in this module.
///
/// Failures are reported through the crate's [`SimdError`] type (imported from
/// `crate::traits`).
pub type ExternalResult<T> = Result<T, SimdError>;
31
32/// Trait for external library adapters
33pub trait ExternalLibrary: Send + Sync {
34    /// Get the library name
35    fn name(&self) -> &str;
36
37    /// Get the library version
38    fn version(&self) -> &str;
39
40    /// Check if the library is available on the system
41    fn is_available(&self) -> bool;
42
43    /// Initialize the library (if needed)
44    fn initialize(&mut self) -> ExternalResult<()>;
45
46    /// Get supported operations
47    fn supported_operations(&self) -> Vec<String>;
48
49    /// Check if a specific operation is supported
50    fn supports_operation(&self, operation: &str) -> bool {
51        self.supported_operations().contains(&operation.to_string())
52    }
53}
54
/// BLAS-style `f32` operations provided by an external library.
///
/// The adapters in this module (`MklAdapter`, `OpenBlasAdapter`) return an error when
/// called before `initialize` succeeds or when slice lengths disagree with the supplied
/// dimensions, and they treat matrices as dense row-major arrays.
pub trait ExternalBlas: ExternalLibrary {
    /// Vector dot product (SDOT): the sum of `x[i] * y[i]`.
    fn dot(&self, x: &[f32], y: &[f32]) -> ExternalResult<f32>;

    /// Vector-scalar multiplication (SSCAL): scales `x` in place by `alpha`.
    fn scal(&self, alpha: f32, x: &mut [f32]) -> ExternalResult<()>;

    /// Vector addition (SAXPY): y = alpha * x + y
    fn axpy(&self, alpha: f32, x: &[f32], y: &mut [f32]) -> ExternalResult<()>;

    /// Matrix-vector multiplication (SGEMV): `y = alpha * A * x + beta * y`.
    ///
    /// The adapters in this module read `a` as an `m x n` row-major matrix
    /// (`a[i * n + j]`) and require `x.len() == n` and `y.len() == m`.
    #[allow(clippy::too_many_arguments)] // BLAS SGEMV signature
    fn gemv(
        &self,
        alpha: f32,
        a: &[f32],
        m: usize,
        n: usize,
        x: &[f32],
        beta: f32,
        y: &mut [f32],
    ) -> ExternalResult<()>;

    /// Matrix-matrix multiplication (SGEMM): `C = alpha * A * B + beta * C`.
    ///
    /// The adapters in this module read `a` as `m x k`, `b` as `k x n` and `c` as `m x n`,
    /// all row-major.
    #[allow(clippy::too_many_arguments)] // BLAS SGEMM signature
    fn gemm(
        &self,
        alpha: f32,
        a: &[f32],
        m: usize,
        k: usize,
        b: &[f32],
        n: usize,
        beta: f32,
        c: &mut [f32],
    ) -> ExternalResult<()>;
}
93
/// LAPACK-style `f32` factorizations provided by an external library.
///
/// NOTE(review): no type in this module implements this trait, so the exact output
/// conventions below are not pinned down by any visible code — confirm with implementors.
pub trait ExternalLapack: ExternalLibrary {
    /// LU decomposition of the `m x n` matrix `a`, which may be overwritten.
    ///
    /// Presumably returns the pivot indices (LAPACK `ipiv`-style) — TODO confirm.
    fn lu_decomposition(&self, a: &mut [f32], m: usize, n: usize) -> ExternalResult<Vec<i32>>;

    /// QR decomposition of the `m x n` matrix `a`, which may be overwritten.
    ///
    /// Presumably returns the Householder scalar factors (LAPACK `tau`-style) — TODO confirm.
    fn qr_decomposition(&self, a: &mut [f32], m: usize, n: usize) -> ExternalResult<Vec<f32>>;

    /// Singular Value Decomposition of the `m x n` matrix `a`.
    ///
    /// Presumably returns `(U, S, Vt)` as flat vectors — TODO confirm ordering and layout.
    fn svd(
        &self,
        a: &mut [f32],
        m: usize,
        n: usize,
    ) -> ExternalResult<(Vec<f32>, Vec<f32>, Vec<f32>)>;

    /// Eigenvalue decomposition of the `n x n` matrix `a`.
    ///
    /// Returns only `f32` values; how complex eigenvalues are handled is unspecified here.
    fn eigenvalues(&self, a: &mut [f32], n: usize) -> ExternalResult<Vec<f32>>;
}
113
/// FFT operations on `f32` data provided by an external library.
///
/// NOTE(review): no type in this module implements this trait; the packed layout of the
/// real-transform outputs below is an assumption to confirm with implementors.
pub trait ExternalFft: ExternalLibrary {
    /// Real-to-complex FFT.
    ///
    /// The complex result is returned as a flat `Vec<f32>` (presumably interleaved
    /// real/imaginary pairs — TODO confirm).
    fn rfft(&self, input: &[f32]) -> ExternalResult<Vec<f32>>;

    /// Complex-to-real inverse FFT; `input` uses the same packing as [`rfft`](Self::rfft)'s output.
    fn irfft(&self, input: &[f32]) -> ExternalResult<Vec<f32>>;

    /// Complex-to-complex FFT with split real and imaginary parts; returns `(real, imag)`.
    fn cfft(&self, real: &[f32], imag: &[f32]) -> ExternalResult<(Vec<f32>, Vec<f32>)>;

    /// Complex-to-complex inverse FFT with split real and imaginary parts; returns `(real, imag)`.
    fn icfft(&self, real: &[f32], imag: &[f32]) -> ExternalResult<(Vec<f32>, Vec<f32>)>;
}
128
/// Mock adapter standing in for Intel MKL (demonstration only).
///
/// Its `is_available` always answers `false`, so `initialize` never succeeds and every
/// BLAS call reports "not initialized".
#[derive(Debug, Clone, Default)]
pub struct MklAdapter {
    /// Set once `initialize` succeeds; each BLAS method checks it before doing work.
    initialized: bool,
}

impl MklAdapter {
    /// Creates an adapter in the uninitialized state.
    pub fn new() -> Self {
        Self::default()
    }
}
146
147impl ExternalLibrary for MklAdapter {
148    fn name(&self) -> &str {
149        "Intel MKL"
150    }
151
152    fn version(&self) -> &str {
153        "2024.0"
154    }
155
156    fn is_available(&self) -> bool {
157        // In a real implementation, this would check for MKL libraries
158        false // Mock: not available in test environment
159    }
160
161    fn initialize(&mut self) -> ExternalResult<()> {
162        if !self.is_available() {
163            return Err(SimdError::ExternalLibraryError(
164                "MKL not available".to_string(),
165            ));
166        }
167        self.initialized = true;
168        Ok(())
169    }
170
171    fn supported_operations(&self) -> Vec<String> {
172        vec![
173            "dot".to_string(),
174            "scal".to_string(),
175            "axpy".to_string(),
176            "gemv".to_string(),
177            "gemm".to_string(),
178            "lu".to_string(),
179            "qr".to_string(),
180            "svd".to_string(),
181            "fft".to_string(),
182        ]
183    }
184}
185
186impl ExternalBlas for MklAdapter {
187    fn dot(&self, x: &[f32], y: &[f32]) -> ExternalResult<f32> {
188        if !self.initialized {
189            return Err(SimdError::ExternalLibraryError(
190                "MKL not initialized".to_string(),
191            ));
192        }
193
194        if x.len() != y.len() {
195            return Err(SimdError::InvalidInput(
196                "Vector lengths must match".to_string(),
197            ));
198        }
199
200        // Mock implementation - in reality would call cblas_sdot
201        Ok(x.iter().zip(y.iter()).map(|(a, b)| a * b).sum())
202    }
203
204    fn scal(&self, alpha: f32, x: &mut [f32]) -> ExternalResult<()> {
205        if !self.initialized {
206            return Err(SimdError::ExternalLibraryError(
207                "MKL not initialized".to_string(),
208            ));
209        }
210
211        // Mock implementation - in reality would call cblas_sscal
212        x.iter_mut().for_each(|v| *v *= alpha);
213        Ok(())
214    }
215
216    fn axpy(&self, alpha: f32, x: &[f32], y: &mut [f32]) -> ExternalResult<()> {
217        if !self.initialized {
218            return Err(SimdError::ExternalLibraryError(
219                "MKL not initialized".to_string(),
220            ));
221        }
222
223        if x.len() != y.len() {
224            return Err(SimdError::InvalidInput(
225                "Vector lengths must match".to_string(),
226            ));
227        }
228
229        // Mock implementation - in reality would call cblas_saxpy
230        for (yi, &xi) in y.iter_mut().zip(x.iter()) {
231            *yi += alpha * xi;
232        }
233        Ok(())
234    }
235
236    fn gemv(
237        &self,
238        alpha: f32,
239        a: &[f32],
240        m: usize,
241        n: usize,
242        x: &[f32],
243        beta: f32,
244        y: &mut [f32],
245    ) -> ExternalResult<()> {
246        if !self.initialized {
247            return Err(SimdError::ExternalLibraryError(
248                "MKL not initialized".to_string(),
249            ));
250        }
251
252        if a.len() != m * n || x.len() != n || y.len() != m {
253            return Err(SimdError::InvalidInput(
254                "Matrix/vector dimension mismatch".to_string(),
255            ));
256        }
257
258        // Mock implementation - in reality would call cblas_sgemv
259        for i in 0..m {
260            let mut sum = 0.0;
261            for j in 0..n {
262                sum += a[i * n + j] * x[j];
263            }
264            y[i] = alpha * sum + beta * y[i];
265        }
266        Ok(())
267    }
268
269    fn gemm(
270        &self,
271        alpha: f32,
272        a: &[f32],
273        m: usize,
274        k: usize,
275        b: &[f32],
276        n: usize,
277        beta: f32,
278        c: &mut [f32],
279    ) -> ExternalResult<()> {
280        if !self.initialized {
281            return Err(SimdError::ExternalLibraryError(
282                "MKL not initialized".to_string(),
283            ));
284        }
285
286        if a.len() != m * k || b.len() != k * n || c.len() != m * n {
287            return Err(SimdError::InvalidInput(
288                "Matrix dimension mismatch".to_string(),
289            ));
290        }
291
292        // Mock implementation - in reality would call cblas_sgemm
293        for i in 0..m {
294            for j in 0..n {
295                let mut sum = 0.0;
296                for l in 0..k {
297                    sum += a[i * k + l] * b[l * n + j];
298                }
299                c[i * n + j] = alpha * sum + beta * c[i * n + j];
300            }
301        }
302        Ok(())
303    }
304}
305
/// Mock adapter standing in for OpenBLAS (demonstration only).
///
/// It reports itself as available, and most of its BLAS methods delegate to this
/// crate's own vector/matrix kernels.
#[derive(Debug, Clone, Default)]
pub struct OpenBlasAdapter {
    /// Set by `initialize`; each BLAS method checks it before doing work.
    initialized: bool,
}

impl OpenBlasAdapter {
    /// Creates an adapter in the uninitialized state.
    pub fn new() -> Self {
        Self::default()
    }
}
323
324impl ExternalLibrary for OpenBlasAdapter {
325    fn name(&self) -> &str {
326        "OpenBLAS"
327    }
328
329    fn version(&self) -> &str {
330        "0.3.21"
331    }
332
333    fn is_available(&self) -> bool {
334        // Mock: assume available for demonstration
335        true
336    }
337
338    fn initialize(&mut self) -> ExternalResult<()> {
339        self.initialized = true;
340        Ok(())
341    }
342
343    fn supported_operations(&self) -> Vec<String> {
344        vec![
345            "dot".to_string(),
346            "scal".to_string(),
347            "axpy".to_string(),
348            "gemv".to_string(),
349            "gemm".to_string(),
350        ]
351    }
352}
353
354impl ExternalBlas for OpenBlasAdapter {
355    fn dot(&self, x: &[f32], y: &[f32]) -> ExternalResult<f32> {
356        if !self.initialized {
357            return Err(SimdError::ExternalLibraryError(
358                "OpenBLAS not initialized".to_string(),
359            ));
360        }
361
362        if x.len() != y.len() {
363            return Err(SimdError::InvalidInput(
364                "Vector lengths must match".to_string(),
365            ));
366        }
367
368        // Mock implementation using our internal SIMD dot product
369        Ok(crate::vector::dot_product(x, y))
370    }
371
372    fn scal(&self, alpha: f32, x: &mut [f32]) -> ExternalResult<()> {
373        if !self.initialized {
374            return Err(SimdError::ExternalLibraryError(
375                "OpenBLAS not initialized".to_string(),
376            ));
377        }
378
379        // Mock implementation using our internal SIMD scale
380        crate::vector::scale(x, alpha);
381        Ok(())
382    }
383
384    fn axpy(&self, alpha: f32, x: &[f32], y: &mut [f32]) -> ExternalResult<()> {
385        if !self.initialized {
386            return Err(SimdError::ExternalLibraryError(
387                "OpenBLAS not initialized".to_string(),
388            ));
389        }
390
391        if x.len() != y.len() {
392            return Err(SimdError::InvalidInput(
393                "Vector lengths must match".to_string(),
394            ));
395        }
396
397        // Mock implementation
398        for (yi, &xi) in y.iter_mut().zip(x.iter()) {
399            *yi += alpha * xi;
400        }
401        Ok(())
402    }
403
404    fn gemv(
405        &self,
406        alpha: f32,
407        a: &[f32],
408        m: usize,
409        n: usize,
410        x: &[f32],
411        beta: f32,
412        y: &mut [f32],
413    ) -> ExternalResult<()> {
414        if !self.initialized {
415            return Err(SimdError::ExternalLibraryError(
416                "OpenBLAS not initialized".to_string(),
417            ));
418        }
419
420        if a.len() != m * n || x.len() != n || y.len() != m {
421            return Err(SimdError::InvalidInput(
422                "Matrix/vector dimension mismatch".to_string(),
423            ));
424        }
425
426        // Mock implementation using our internal matrix operations
427        use scirs2_autograd::ndarray::{Array1, Array2};
428        let a_matrix = Array2::from_shape_vec((m, n), a.to_vec())
429            .map_err(|_| SimdError::ExternalLibraryError("Invalid matrix shape".to_string()))?;
430        let x_vector = Array1::from_vec(x.to_vec());
431        let result = crate::matrix::matrix_vector_multiply_f32(&a_matrix, &x_vector);
432
433        for (yi, &ri) in y.iter_mut().zip(result.iter()) {
434            *yi = alpha * ri + beta * (*yi);
435        }
436        Ok(())
437    }
438
439    fn gemm(
440        &self,
441        alpha: f32,
442        a: &[f32],
443        m: usize,
444        k: usize,
445        b: &[f32],
446        n: usize,
447        beta: f32,
448        c: &mut [f32],
449    ) -> ExternalResult<()> {
450        if !self.initialized {
451            return Err(SimdError::ExternalLibraryError(
452                "OpenBLAS not initialized".to_string(),
453            ));
454        }
455
456        if a.len() != m * k || b.len() != k * n || c.len() != m * n {
457            return Err(SimdError::InvalidInput(
458                "Matrix dimension mismatch".to_string(),
459            ));
460        }
461
462        // Mock implementation using our internal matrix operations
463        use scirs2_autograd::ndarray::Array2;
464        let a_matrix = Array2::from_shape_vec((m, k), a.to_vec())
465            .map_err(|_| SimdError::ExternalLibraryError("Invalid matrix A shape".to_string()))?;
466        let b_matrix = Array2::from_shape_vec((k, n), b.to_vec())
467            .map_err(|_| SimdError::ExternalLibraryError("Invalid matrix B shape".to_string()))?;
468        let result = crate::matrix::matrix_multiply_f32_simd(&a_matrix, &b_matrix);
469
470        for (ci, &ri) in c.iter_mut().zip(result.iter()) {
471            *ci = alpha * ri + beta * (*ci);
472        }
473        Ok(())
474    }
475}
476
477/// External library registry and management
478pub struct ExternalLibraryRegistry {
479    /// Registered BLAS libraries
480    blas_libraries: HashMap<String, Arc<Mutex<dyn ExternalBlas>>>,
481    /// Registered LAPACK libraries
482    lapack_libraries: HashMap<String, Arc<Mutex<dyn ExternalLapack>>>,
483    /// Registered FFT libraries
484    fft_libraries: HashMap<String, Arc<Mutex<dyn ExternalFft>>>,
485    /// Preferred library for each operation type
486    preferences: HashMap<String, String>,
487}
488
489impl ExternalLibraryRegistry {
490    /// Create a new registry
491    pub fn new() -> Self {
492        Self {
493            blas_libraries: HashMap::new(),
494            lapack_libraries: HashMap::new(),
495            fft_libraries: HashMap::new(),
496            preferences: HashMap::new(),
497        }
498    }
499
500    /// Register a BLAS library
501    pub fn register_blas<T: ExternalBlas + 'static>(&mut self, library: T) {
502        let name = library.name().to_string();
503        self.blas_libraries
504            .insert(name.clone(), Arc::new(Mutex::new(library)));
505
506        // Set as default if none set
507        if !self.preferences.contains_key("blas") {
508            self.preferences.insert("blas".to_string(), name);
509        }
510    }
511
512    /// Register a LAPACK library
513    pub fn register_lapack<T: ExternalLapack + 'static>(&mut self, library: T) {
514        let name = library.name().to_string();
515        self.lapack_libraries
516            .insert(name.clone(), Arc::new(Mutex::new(library)));
517
518        // Set as default if none set
519        if !self.preferences.contains_key("lapack") {
520            self.preferences.insert("lapack".to_string(), name);
521        }
522    }
523
524    /// Register an FFT library
525    pub fn register_fft<T: ExternalFft + 'static>(&mut self, library: T) {
526        let name = library.name().to_string();
527        self.fft_libraries
528            .insert(name.clone(), Arc::new(Mutex::new(library)));
529
530        // Set as default if none set
531        if !self.preferences.contains_key("fft") {
532            self.preferences.insert("fft".to_string(), name);
533        }
534    }
535
536    /// Set preferred library for operation type
537    pub fn set_preference(
538        &mut self,
539        operation_type: &str,
540        library_name: &str,
541    ) -> ExternalResult<()> {
542        match operation_type {
543            "blas" => {
544                if !self.blas_libraries.contains_key(library_name) {
545                    return Err(SimdError::ExternalLibraryError(format!(
546                        "BLAS library '{}' not registered",
547                        library_name
548                    )));
549                }
550            }
551            "lapack" => {
552                if !self.lapack_libraries.contains_key(library_name) {
553                    return Err(SimdError::ExternalLibraryError(format!(
554                        "LAPACK library '{}' not registered",
555                        library_name
556                    )));
557                }
558            }
559            "fft" => {
560                if !self.fft_libraries.contains_key(library_name) {
561                    return Err(SimdError::ExternalLibraryError(format!(
562                        "FFT library '{}' not registered",
563                        library_name
564                    )));
565                }
566            }
567            _ => {
568                return Err(SimdError::InvalidInput(format!(
569                    "Unknown operation type: {}",
570                    operation_type
571                )));
572            }
573        }
574
575        self.preferences
576            .insert(operation_type.to_string(), library_name.to_string());
577        Ok(())
578    }
579
580    /// Get preferred BLAS library
581    pub fn get_blas(&self) -> Option<Arc<Mutex<dyn ExternalBlas>>> {
582        self.preferences
583            .get("blas")
584            .and_then(|name| self.blas_libraries.get(name))
585            .cloned()
586    }
587
588    /// Get preferred LAPACK library
589    pub fn get_lapack(&self) -> Option<Arc<Mutex<dyn ExternalLapack>>> {
590        self.preferences
591            .get("lapack")
592            .and_then(|name| self.lapack_libraries.get(name))
593            .cloned()
594    }
595
596    /// Get preferred FFT library
597    pub fn get_fft(&self) -> Option<Arc<Mutex<dyn ExternalFft>>> {
598        self.preferences
599            .get("fft")
600            .and_then(|name| self.fft_libraries.get(name))
601            .cloned()
602    }
603
604    /// List all registered libraries
605    pub fn list_libraries(&self) -> Vec<String> {
606        let mut libraries = Vec::new();
607        libraries.extend(self.blas_libraries.keys().cloned());
608        libraries.extend(self.lapack_libraries.keys().cloned());
609        libraries.extend(self.fft_libraries.keys().cloned());
610        libraries.sort();
611        libraries.dedup();
612        libraries
613    }
614
615    /// Check library availability
616    pub fn check_availability(&self) -> HashMap<String, bool> {
617        let mut availability = HashMap::new();
618
619        #[cfg(not(feature = "no-std"))]
620        {
621            for (name, library) in &self.blas_libraries {
622                availability.insert(
623                    name.clone(),
624                    library
625                        .lock()
626                        .expect("lock should not be poisoned")
627                        .is_available(),
628                );
629            }
630
631            for (name, library) in &self.lapack_libraries {
632                availability.insert(
633                    name.clone(),
634                    library
635                        .lock()
636                        .expect("lock should not be poisoned")
637                        .is_available(),
638                );
639            }
640
641            for (name, library) in &self.fft_libraries {
642                availability.insert(
643                    name.clone(),
644                    library
645                        .lock()
646                        .expect("lock should not be poisoned")
647                        .is_available(),
648                );
649            }
650        }
651
652        #[cfg(feature = "no-std")]
653        {
654            for (name, library) in &self.blas_libraries {
655                availability.insert(name.clone(), library.lock().is_available());
656            }
657
658            for (name, library) in &self.lapack_libraries {
659                availability.insert(name.clone(), library.lock().is_available());
660            }
661
662            for (name, library) in &self.fft_libraries {
663                availability.insert(name.clone(), library.lock().is_available());
664            }
665        }
666
667        availability
668    }
669}
670
671impl Default for ExternalLibraryRegistry {
672    fn default() -> Self {
673        let mut registry = Self::new();
674
675        // Register default adapters
676        registry.register_blas(OpenBlasAdapter::new());
677        registry.register_blas(MklAdapter::new());
678
679        registry
680    }
681}
682
683/// Global external library registry
684#[cfg(not(feature = "no-std"))]
685static EXTERNAL_REGISTRY: once_cell::sync::Lazy<std::sync::Mutex<ExternalLibraryRegistry>> =
686    once_cell::sync::Lazy::new(|| std::sync::Mutex::new(ExternalLibraryRegistry::default()));
687
688#[cfg(feature = "no-std")]
689static EXTERNAL_REGISTRY: spin::Once<spin::Mutex<ExternalLibraryRegistry>> = spin::Once::new();
690
691/// Get the global external library registry
692#[cfg(not(feature = "no-std"))]
693pub fn get_registry() -> &'static std::sync::Mutex<ExternalLibraryRegistry> {
694    &EXTERNAL_REGISTRY
695}
696
697#[cfg(feature = "no-std")]
698pub fn get_registry() -> &'static spin::Mutex<ExternalLibraryRegistry> {
699    EXTERNAL_REGISTRY.call_once(|| spin::Mutex::new(ExternalLibraryRegistry::default()))
700}
701
702/// Perform dot product using external BLAS if available, fallback to internal
703pub fn external_dot(x: &[f32], y: &[f32]) -> ExternalResult<f32> {
704    #[cfg(not(feature = "no-std"))]
705    {
706        if let Some(blas) = get_registry()
707            .lock()
708            .expect("lock should not be poisoned")
709            .get_blas()
710        {
711            // Try to use external BLAS, but fallback to internal if it fails
712            match blas
713                .lock()
714                .expect("matrix dimensions should be compatible for dot product")
715                .dot(x, y)
716            {
717                Ok(result) => Ok(result),
718                Err(_) => {
719                    // Fall back to internal implementation if external library fails
720                    Ok(crate::vector::dot_product(x, y))
721                }
722            }
723        } else {
724            Ok(crate::vector::dot_product(x, y))
725        }
726    }
727    #[cfg(feature = "no-std")]
728    {
729        if let Some(blas) = get_registry().lock().get_blas() {
730            // Try to use external BLAS, but fallback to internal if it fails
731            match blas.lock().dot(x, y) {
732                Ok(result) => Ok(result),
733                Err(_) => {
734                    // Fall back to internal implementation if external library fails
735                    Ok(crate::vector::dot_product(x, y))
736                }
737            }
738        } else {
739            Ok(crate::vector::dot_product(x, y))
740        }
741    }
742}
743
744/// Perform matrix-vector multiplication using external BLAS if available
745pub fn external_gemv(
746    alpha: f32,
747    a: &[f32],
748    m: usize,
749    n: usize,
750    x: &[f32],
751    beta: f32,
752    y: &mut [f32],
753) -> ExternalResult<()> {
754    #[cfg(not(feature = "no-std"))]
755    {
756        if let Some(blas) = get_registry()
757            .lock()
758            .expect("lock should not be poisoned")
759            .get_blas()
760        {
761            // Try to use external BLAS, but fallback to internal if it fails
762            match blas
763                .lock()
764                .expect("lock should not be poisoned")
765                .gemv(alpha, a, m, n, x, beta, y)
766            {
767                Ok(()) => Ok(()),
768                Err(_) => {
769                    // Fall back to internal implementation if external library fails
770                    use scirs2_autograd::ndarray::{Array1, Array2};
771                    let a_matrix = Array2::from_shape_vec((m, n), a.to_vec()).map_err(|_| {
772                        SimdError::ExternalLibraryError("Invalid matrix shape".to_string())
773                    })?;
774                    let x_vector = Array1::from_vec(x.to_vec());
775                    let result = crate::matrix::matrix_vector_multiply_f32(&a_matrix, &x_vector);
776
777                    for (yi, &ri) in y.iter_mut().zip(result.iter()) {
778                        *yi = alpha * ri + beta * (*yi);
779                    }
780                    Ok(())
781                }
782            }
783        } else {
784            // Fallback to internal implementation
785            use scirs2_autograd::ndarray::{Array1, Array2};
786            let a_matrix = Array2::from_shape_vec((m, n), a.to_vec())
787                .map_err(|_| SimdError::ExternalLibraryError("Invalid matrix shape".to_string()))?;
788            let x_vector = Array1::from_vec(x.to_vec());
789            let result = crate::matrix::matrix_vector_multiply_f32(&a_matrix, &x_vector);
790
791            for (yi, &ri) in y.iter_mut().zip(result.iter()) {
792                *yi = alpha * ri + beta * (*yi);
793            }
794            Ok(())
795        }
796    }
797    #[cfg(feature = "no-std")]
798    {
799        if let Some(blas) = get_registry().lock().get_blas() {
800            // Try to use external BLAS, but fallback to internal if it fails
801            match blas.lock().gemv(alpha, a, m, n, x, beta, y) {
802                Ok(()) => Ok(()),
803                Err(_) => {
804                    // Fall back to internal implementation if external library fails
805                    use scirs2_autograd::ndarray::{Array1, Array2};
806                    let a_matrix = Array2::from_shape_vec((m, n), a.to_vec()).map_err(|_| {
807                        SimdError::ExternalLibraryError("Invalid matrix shape".to_string())
808                    })?;
809                    let x_vector = Array1::from_vec(x.to_vec());
810                    let result = crate::matrix::matrix_vector_multiply_f32(&a_matrix, &x_vector);
811
812                    for (yi, &ri) in y.iter_mut().zip(result.iter()) {
813                        *yi = alpha * ri + beta * (*yi);
814                    }
815                    Ok(())
816                }
817            }
818        } else {
819            // Fallback to internal implementation
820            use scirs2_autograd::ndarray::{Array1, Array2};
821            let a_matrix = Array2::from_shape_vec((m, n), a.to_vec())
822                .map_err(|_| SimdError::ExternalLibraryError("Invalid matrix shape".to_string()))?;
823            let x_vector = Array1::from_vec(x.to_vec());
824            let result = crate::matrix::matrix_vector_multiply_f32(&a_matrix, &x_vector);
825
826            for (yi, &ri) in y.iter_mut().zip(result.iter()) {
827                *yi = alpha * ri + beta * (*yi);
828            }
829            Ok(())
830        }
831    }
832}
833
#[cfg(all(test, not(feature = "no-std")))]
mod tests {
    use super::*;

    /// Name of the library currently preferred for BLAS in `registry`.
    fn preferred_blas_name(registry: &ExternalLibraryRegistry) -> String {
        let blas = registry
            .get_blas()
            .expect("a BLAS library should be registered");
        let guard = blas.lock().expect("lock should not be poisoned");
        guard.name().to_string()
    }

    #[test]
    fn test_mkl_adapter_creation() {
        let adapter = MklAdapter::new();
        assert_eq!(adapter.name(), "Intel MKL");
        assert_eq!(adapter.version(), "2024.0");
        assert!(!adapter.is_available()); // Mock: not available
    }

    #[test]
    fn test_mkl_initialize_fails_when_unavailable() {
        let mut adapter = MklAdapter::new();
        assert!(adapter.initialize().is_err());
        // Still uninitialized, so BLAS calls are rejected.
        assert!(adapter.dot(&[1.0], &[1.0]).is_err());
    }

    #[test]
    fn test_openblas_adapter_creation() {
        let adapter = OpenBlasAdapter::new();
        assert_eq!(adapter.name(), "OpenBLAS");
        assert_eq!(adapter.version(), "0.3.21");
        assert!(adapter.is_available()); // Mock: available
    }

    #[test]
    fn test_openblas_initialization() {
        let mut adapter = OpenBlasAdapter::new();
        assert!(adapter.initialize().is_ok());
    }

    #[test]
    fn test_openblas_dot_product() {
        let mut adapter = OpenBlasAdapter::new();
        adapter.initialize().expect("operation should succeed");

        let x = vec![1.0, 2.0, 3.0];
        let y = vec![4.0, 5.0, 6.0];
        let result = adapter
            .dot(&x, &y)
            .expect("vectors of equal length should produce a dot product");

        assert_eq!(result, 32.0); // 1*4 + 2*5 + 3*6
    }

    #[test]
    fn test_openblas_scal() {
        let mut adapter = OpenBlasAdapter::new();
        adapter.initialize().expect("operation should succeed");

        let mut x = vec![1.0, 2.0, 3.0];
        adapter.scal(2.0, &mut x).expect("operation should succeed");

        assert_eq!(x, vec![2.0, 4.0, 6.0]);
    }

    #[test]
    fn test_openblas_axpy() {
        let mut adapter = OpenBlasAdapter::new();
        adapter.initialize().expect("operation should succeed");

        let x = [1.0_f32, 2.0, 3.0];
        let mut y = [1.0_f32, 1.0, 1.0];
        adapter.axpy(2.0, &x, &mut y).expect("operation should succeed");

        assert_eq!(y, [3.0, 5.0, 7.0]);
    }

    #[test]
    fn test_openblas_gemv() {
        let mut adapter = OpenBlasAdapter::new();
        adapter.initialize().expect("operation should succeed");

        // A = [[1, 2], [3, 4]] (row-major), x = [1, 1] => A*x = [3, 7]
        let a = [1.0_f32, 2.0, 3.0, 4.0];
        let x = [1.0_f32, 1.0];
        let mut y = [2.0_f32, 4.0];
        adapter
            .gemv(2.0, &a, 2, 2, &x, 0.5, &mut y)
            .expect("dimensions are consistent");

        // 2*3 + 0.5*2 = 7, 2*7 + 0.5*4 = 16
        assert_eq!(y, [7.0, 16.0]);
    }

    #[test]
    fn test_openblas_gemv_dimension_mismatch() {
        let mut adapter = OpenBlasAdapter::new();
        adapter.initialize().expect("operation should succeed");

        let a = [1.0_f32, 2.0, 3.0]; // not 2x2
        let x = [1.0_f32, 1.0];
        let mut y = [0.0_f32, 0.0];

        assert!(adapter.gemv(1.0, &a, 2, 2, &x, 0.0, &mut y).is_err());
    }

    #[test]
    fn test_registry_blas_registration() {
        let mut registry = ExternalLibraryRegistry::new();
        let adapter = OpenBlasAdapter::new();

        registry.register_blas(adapter);

        assert!(registry.get_blas().is_some());
        assert_eq!(registry.list_libraries(), vec!["OpenBLAS"]);
    }

    #[test]
    fn test_registry_preferences() {
        let mut registry = ExternalLibraryRegistry::new();
        let adapter1 = OpenBlasAdapter::new();
        let adapter2 = MklAdapter::new();

        registry.register_blas(adapter1);
        registry.register_blas(adapter2);

        // The first registered adapter becomes the default preference.
        assert_eq!(preferred_blas_name(&registry), "OpenBLAS");

        // Set preference to MKL
        registry
            .set_preference("blas", "Intel MKL")
            .expect("operation should succeed");
        assert_eq!(preferred_blas_name(&registry), "Intel MKL");

        // Should fail for unknown library or operation type, leaving the preference intact
        assert!(registry.set_preference("blas", "Unknown").is_err());
        assert!(registry.set_preference("unknown", "OpenBLAS").is_err());
        assert_eq!(preferred_blas_name(&registry), "Intel MKL");
    }

    #[test]
    fn test_registry_availability_check() {
        let registry = ExternalLibraryRegistry::default();
        let availability = registry.check_availability();

        // OpenBLAS should be available (mock), MKL should not
        assert_eq!(availability.get("OpenBLAS"), Some(&true));
        assert_eq!(availability.get("Intel MKL"), Some(&false));
    }

    #[test]
    fn test_external_dot_fallback() {
        // Routed through the global registry; whichever path runs (adapter or internal
        // fallback) must produce the exact dot product.
        let x = vec![1.0, 2.0, 3.0];
        let y = vec![4.0, 5.0, 6.0];
        let result = external_dot(&x, &y).expect("operation should succeed");

        assert_eq!(result, 32.0);
    }

    #[test]
    fn test_invalid_dimensions() {
        let mut adapter = OpenBlasAdapter::new();
        adapter.initialize().expect("operation should succeed");

        let x = vec![1.0, 2.0];
        let y = vec![3.0, 4.0, 5.0];

        assert!(adapter.dot(&x, &y).is_err());
    }

    #[test]
    fn test_uninitialized_adapter() {
        let adapter = OpenBlasAdapter::new();
        let x = vec![1.0, 2.0, 3.0];
        let y = vec![4.0, 5.0, 6.0];

        assert!(adapter.dot(&x, &y).is_err());
    }
}