1use crate::traits::SimdError;
7
8#[cfg(feature = "no-std")]
9use alloc::{
10 collections::BTreeMap as HashMap,
11 format,
12 string::{String, ToString},
13 sync::Arc,
14 vec,
15 vec::Vec,
16};
17#[cfg(not(feature = "no-std"))]
18use std::collections::HashMap;
19#[cfg(not(feature = "no-std"))]
20use std::string::ToString;
21#[cfg(not(feature = "no-std"))]
22use std::sync::Arc;
23
24#[cfg(feature = "no-std")]
25use spin::Mutex;
26#[cfg(not(feature = "no-std"))]
27use std::sync::Mutex;
28
/// Result type returned by every external-library adapter operation in this
/// module; failures are reported as [`SimdError`] values.
pub type ExternalResult<T> = Result<T, SimdError>;
31
32pub trait ExternalLibrary: Send + Sync {
34 fn name(&self) -> &str;
36
37 fn version(&self) -> &str;
39
40 fn is_available(&self) -> bool;
42
43 fn initialize(&mut self) -> ExternalResult<()>;
45
46 fn supported_operations(&self) -> Vec<String>;
48
49 fn supports_operation(&self, operation: &str) -> bool {
51 self.supported_operations().contains(&operation.to_string())
52 }
53}
54
/// BLAS-style level-1/2/3 kernels on `f32` data.
///
/// The adapters in this module treat matrices as dense row-major slices
/// (element `(i, j)` of an `r × c` matrix lives at index `i * c + j`) and
/// return `SimdError::InvalidInput` when slice lengths disagree with the
/// stated dimensions. NOTE(review): the trait itself does not enforce this
/// convention on other implementors.
pub trait ExternalBlas: ExternalLibrary {
    /// Dot product `Σ x[i] * y[i]`; adapters here require equal lengths.
    fn dot(&self, x: &[f32], y: &[f32]) -> ExternalResult<f32>;

    /// In-place scaling `x := alpha * x`.
    fn scal(&self, alpha: f32, x: &mut [f32]) -> ExternalResult<()>;

    /// `y := alpha * x + y`; adapters here require equal lengths.
    fn axpy(&self, alpha: f32, x: &[f32], y: &mut [f32]) -> ExternalResult<()>;

    /// Matrix-vector product `y := alpha * A * x + beta * y`, where `a` holds
    /// an `m × n` matrix, `x` has length `n` and `y` has length `m`.
    // The argument list mirrors BLAS `gemv`, which exceeds clippy's default limit.
    #[allow(clippy::too_many_arguments)]
    fn gemv(
        &self,
        alpha: f32,
        a: &[f32],
        m: usize,
        n: usize,
        x: &[f32],
        beta: f32,
        y: &mut [f32],
    ) -> ExternalResult<()>;

    /// Matrix-matrix product `C := alpha * A * B + beta * C`, where `a` is
    /// `m × k`, `b` is `k × n` and `c` is `m × n`.
    // The argument list mirrors BLAS `gemm`, which exceeds clippy's default limit.
    #[allow(clippy::too_many_arguments)]
    fn gemm(
        &self,
        alpha: f32,
        a: &[f32],
        m: usize,
        k: usize,
        b: &[f32],
        n: usize,
        beta: f32,
        c: &mut [f32],
    ) -> ExternalResult<()>;
}
93
/// LAPACK-style dense factorisations on `f32` matrices stored in flat slices.
///
/// NOTE(review): no implementor is visible in this module. The storage order
/// (row- vs column-major) and the exact meaning of each returned vector are
/// not fixed by this trait; confirm against the implementations.
pub trait ExternalLapack: ExternalLibrary {
    /// LU factorisation of the `m × n` matrix in `a`. `a` is passed mutably, so
    /// implementations may overwrite it with the factors.
    /// Presumably returns the pivot indices (LAPACK `getrf` style) — TODO confirm.
    fn lu_decomposition(&self, a: &mut [f32], m: usize, n: usize) -> ExternalResult<Vec<i32>>;

    /// QR factorisation of the `m × n` matrix in `a` (may be overwritten).
    /// Presumably returns the Householder scalar factors (`tau`) — TODO confirm.
    fn qr_decomposition(&self, a: &mut [f32], m: usize, n: usize) -> ExternalResult<Vec<f32>>;

    /// Singular value decomposition of the `m × n` matrix in `a`.
    /// Presumably returns `(U, S, Vᵀ)` — TODO confirm ordering and layout.
    fn svd(
        &self,
        a: &mut [f32],
        m: usize,
        n: usize,
    ) -> ExternalResult<(Vec<f32>, Vec<f32>, Vec<f32>)>;

    /// Eigenvalues of the `n × n` matrix in `a`.
    /// NOTE(review): the real-valued result suggests real/symmetric input is
    /// assumed — confirm.
    fn eigenvalues(&self, a: &mut [f32], n: usize) -> ExternalResult<Vec<f32>>;
}
113
/// FFT backend interface on `f32` data.
///
/// NOTE(review): the packing of real-FFT output and the normalisation used by
/// the inverse transforms are not specified here; confirm against implementors.
pub trait ExternalFft: ExternalLibrary {
    /// Forward FFT of real-valued input.
    fn rfft(&self, input: &[f32]) -> ExternalResult<Vec<f32>>;

    /// Inverse counterpart of [`Self::rfft`].
    fn irfft(&self, input: &[f32]) -> ExternalResult<Vec<f32>>;

    /// Forward complex FFT with split real/imaginary input; presumably returns
    /// `(real, imag)` — TODO confirm.
    fn cfft(&self, real: &[f32], imag: &[f32]) -> ExternalResult<(Vec<f32>, Vec<f32>)>;

    /// Inverse complex FFT with split real/imaginary input; presumably returns
    /// `(real, imag)` — TODO confirm.
    fn icfft(&self, real: &[f32], imag: &[f32]) -> ExternalResult<(Vec<f32>, Vec<f32>)>;
}
128
/// Adapter exposing Intel MKL through the BLAS adapter interface.
///
/// The only state it carries is whether initialization has succeeded.
#[derive(Debug, Clone, Default)]
pub struct MklAdapter {
    /// Set once `initialize` succeeds; the derived default is `false`.
    initialized: bool,
}

impl MklAdapter {
    /// Creates an adapter that has not been initialized yet
    /// (identical to `MklAdapter::default()`).
    pub fn new() -> Self {
        Self::default()
    }
}
146
147impl ExternalLibrary for MklAdapter {
148 fn name(&self) -> &str {
149 "Intel MKL"
150 }
151
152 fn version(&self) -> &str {
153 "2024.0"
154 }
155
156 fn is_available(&self) -> bool {
157 false }
160
161 fn initialize(&mut self) -> ExternalResult<()> {
162 if !self.is_available() {
163 return Err(SimdError::ExternalLibraryError(
164 "MKL not available".to_string(),
165 ));
166 }
167 self.initialized = true;
168 Ok(())
169 }
170
171 fn supported_operations(&self) -> Vec<String> {
172 vec![
173 "dot".to_string(),
174 "scal".to_string(),
175 "axpy".to_string(),
176 "gemv".to_string(),
177 "gemm".to_string(),
178 "lu".to_string(),
179 "qr".to_string(),
180 "svd".to_string(),
181 "fft".to_string(),
182 ]
183 }
184}
185
186impl ExternalBlas for MklAdapter {
187 fn dot(&self, x: &[f32], y: &[f32]) -> ExternalResult<f32> {
188 if !self.initialized {
189 return Err(SimdError::ExternalLibraryError(
190 "MKL not initialized".to_string(),
191 ));
192 }
193
194 if x.len() != y.len() {
195 return Err(SimdError::InvalidInput(
196 "Vector lengths must match".to_string(),
197 ));
198 }
199
200 Ok(x.iter().zip(y.iter()).map(|(a, b)| a * b).sum())
202 }
203
204 fn scal(&self, alpha: f32, x: &mut [f32]) -> ExternalResult<()> {
205 if !self.initialized {
206 return Err(SimdError::ExternalLibraryError(
207 "MKL not initialized".to_string(),
208 ));
209 }
210
211 x.iter_mut().for_each(|v| *v *= alpha);
213 Ok(())
214 }
215
216 fn axpy(&self, alpha: f32, x: &[f32], y: &mut [f32]) -> ExternalResult<()> {
217 if !self.initialized {
218 return Err(SimdError::ExternalLibraryError(
219 "MKL not initialized".to_string(),
220 ));
221 }
222
223 if x.len() != y.len() {
224 return Err(SimdError::InvalidInput(
225 "Vector lengths must match".to_string(),
226 ));
227 }
228
229 for (yi, &xi) in y.iter_mut().zip(x.iter()) {
231 *yi += alpha * xi;
232 }
233 Ok(())
234 }
235
236 fn gemv(
237 &self,
238 alpha: f32,
239 a: &[f32],
240 m: usize,
241 n: usize,
242 x: &[f32],
243 beta: f32,
244 y: &mut [f32],
245 ) -> ExternalResult<()> {
246 if !self.initialized {
247 return Err(SimdError::ExternalLibraryError(
248 "MKL not initialized".to_string(),
249 ));
250 }
251
252 if a.len() != m * n || x.len() != n || y.len() != m {
253 return Err(SimdError::InvalidInput(
254 "Matrix/vector dimension mismatch".to_string(),
255 ));
256 }
257
258 for i in 0..m {
260 let mut sum = 0.0;
261 for j in 0..n {
262 sum += a[i * n + j] * x[j];
263 }
264 y[i] = alpha * sum + beta * y[i];
265 }
266 Ok(())
267 }
268
269 fn gemm(
270 &self,
271 alpha: f32,
272 a: &[f32],
273 m: usize,
274 k: usize,
275 b: &[f32],
276 n: usize,
277 beta: f32,
278 c: &mut [f32],
279 ) -> ExternalResult<()> {
280 if !self.initialized {
281 return Err(SimdError::ExternalLibraryError(
282 "MKL not initialized".to_string(),
283 ));
284 }
285
286 if a.len() != m * k || b.len() != k * n || c.len() != m * n {
287 return Err(SimdError::InvalidInput(
288 "Matrix dimension mismatch".to_string(),
289 ));
290 }
291
292 for i in 0..m {
294 for j in 0..n {
295 let mut sum = 0.0;
296 for l in 0..k {
297 sum += a[i * k + l] * b[l * n + j];
298 }
299 c[i * n + j] = alpha * sum + beta * c[i * n + j];
300 }
301 }
302 Ok(())
303 }
304}
305
/// Adapter presenting the crate's own vector/matrix kernels under the
/// "OpenBLAS" name.
///
/// The only state it carries is whether initialization has happened.
#[derive(Debug, Clone, Default)]
pub struct OpenBlasAdapter {
    /// Set by `initialize`; the derived default is `false`.
    initialized: bool,
}

impl OpenBlasAdapter {
    /// Creates an adapter that has not been initialized yet
    /// (identical to `OpenBlasAdapter::default()`).
    pub fn new() -> Self {
        Self::default()
    }
}
323
324impl ExternalLibrary for OpenBlasAdapter {
325 fn name(&self) -> &str {
326 "OpenBLAS"
327 }
328
329 fn version(&self) -> &str {
330 "0.3.21"
331 }
332
333 fn is_available(&self) -> bool {
334 true
336 }
337
338 fn initialize(&mut self) -> ExternalResult<()> {
339 self.initialized = true;
340 Ok(())
341 }
342
343 fn supported_operations(&self) -> Vec<String> {
344 vec![
345 "dot".to_string(),
346 "scal".to_string(),
347 "axpy".to_string(),
348 "gemv".to_string(),
349 "gemm".to_string(),
350 ]
351 }
352}
353
354impl ExternalBlas for OpenBlasAdapter {
355 fn dot(&self, x: &[f32], y: &[f32]) -> ExternalResult<f32> {
356 if !self.initialized {
357 return Err(SimdError::ExternalLibraryError(
358 "OpenBLAS not initialized".to_string(),
359 ));
360 }
361
362 if x.len() != y.len() {
363 return Err(SimdError::InvalidInput(
364 "Vector lengths must match".to_string(),
365 ));
366 }
367
368 Ok(crate::vector::dot_product(x, y))
370 }
371
372 fn scal(&self, alpha: f32, x: &mut [f32]) -> ExternalResult<()> {
373 if !self.initialized {
374 return Err(SimdError::ExternalLibraryError(
375 "OpenBLAS not initialized".to_string(),
376 ));
377 }
378
379 crate::vector::scale(x, alpha);
381 Ok(())
382 }
383
384 fn axpy(&self, alpha: f32, x: &[f32], y: &mut [f32]) -> ExternalResult<()> {
385 if !self.initialized {
386 return Err(SimdError::ExternalLibraryError(
387 "OpenBLAS not initialized".to_string(),
388 ));
389 }
390
391 if x.len() != y.len() {
392 return Err(SimdError::InvalidInput(
393 "Vector lengths must match".to_string(),
394 ));
395 }
396
397 for (yi, &xi) in y.iter_mut().zip(x.iter()) {
399 *yi += alpha * xi;
400 }
401 Ok(())
402 }
403
404 fn gemv(
405 &self,
406 alpha: f32,
407 a: &[f32],
408 m: usize,
409 n: usize,
410 x: &[f32],
411 beta: f32,
412 y: &mut [f32],
413 ) -> ExternalResult<()> {
414 if !self.initialized {
415 return Err(SimdError::ExternalLibraryError(
416 "OpenBLAS not initialized".to_string(),
417 ));
418 }
419
420 if a.len() != m * n || x.len() != n || y.len() != m {
421 return Err(SimdError::InvalidInput(
422 "Matrix/vector dimension mismatch".to_string(),
423 ));
424 }
425
426 use scirs2_autograd::ndarray::{Array1, Array2};
428 let a_matrix = Array2::from_shape_vec((m, n), a.to_vec())
429 .map_err(|_| SimdError::ExternalLibraryError("Invalid matrix shape".to_string()))?;
430 let x_vector = Array1::from_vec(x.to_vec());
431 let result = crate::matrix::matrix_vector_multiply_f32(&a_matrix, &x_vector);
432
433 for (yi, &ri) in y.iter_mut().zip(result.iter()) {
434 *yi = alpha * ri + beta * (*yi);
435 }
436 Ok(())
437 }
438
439 fn gemm(
440 &self,
441 alpha: f32,
442 a: &[f32],
443 m: usize,
444 k: usize,
445 b: &[f32],
446 n: usize,
447 beta: f32,
448 c: &mut [f32],
449 ) -> ExternalResult<()> {
450 if !self.initialized {
451 return Err(SimdError::ExternalLibraryError(
452 "OpenBLAS not initialized".to_string(),
453 ));
454 }
455
456 if a.len() != m * k || b.len() != k * n || c.len() != m * n {
457 return Err(SimdError::InvalidInput(
458 "Matrix dimension mismatch".to_string(),
459 ));
460 }
461
462 use scirs2_autograd::ndarray::Array2;
464 let a_matrix = Array2::from_shape_vec((m, k), a.to_vec())
465 .map_err(|_| SimdError::ExternalLibraryError("Invalid matrix A shape".to_string()))?;
466 let b_matrix = Array2::from_shape_vec((k, n), b.to_vec())
467 .map_err(|_| SimdError::ExternalLibraryError("Invalid matrix B shape".to_string()))?;
468 let result = crate::matrix::matrix_multiply_f32_simd(&a_matrix, &b_matrix);
469
470 for (ci, &ri) in c.iter_mut().zip(result.iter()) {
471 *ci = alpha * ri + beta * (*ci);
472 }
473 Ok(())
474 }
475}
476
/// Collection of registered external backends, plus the currently preferred
/// library for each operation family (`"blas"`, `"lapack"`, `"fft"`).
///
/// `HashMap` here is `std::collections::HashMap`, or `BTreeMap` under the
/// `no-std` feature (see the imports at the top of the file).
pub struct ExternalLibraryRegistry {
    /// BLAS adapters keyed by their `name()`.
    blas_libraries: HashMap<String, Arc<Mutex<dyn ExternalBlas>>>,
    /// LAPACK adapters keyed by their `name()`.
    lapack_libraries: HashMap<String, Arc<Mutex<dyn ExternalLapack>>>,
    /// FFT adapters keyed by their `name()`.
    fft_libraries: HashMap<String, Arc<Mutex<dyn ExternalFft>>>,
    /// Operation family (`"blas"`, `"lapack"`, `"fft"`) -> preferred library name.
    preferences: HashMap<String, String>,
}
488
489impl ExternalLibraryRegistry {
490 pub fn new() -> Self {
492 Self {
493 blas_libraries: HashMap::new(),
494 lapack_libraries: HashMap::new(),
495 fft_libraries: HashMap::new(),
496 preferences: HashMap::new(),
497 }
498 }
499
500 pub fn register_blas<T: ExternalBlas + 'static>(&mut self, library: T) {
502 let name = library.name().to_string();
503 self.blas_libraries
504 .insert(name.clone(), Arc::new(Mutex::new(library)));
505
506 if !self.preferences.contains_key("blas") {
508 self.preferences.insert("blas".to_string(), name);
509 }
510 }
511
512 pub fn register_lapack<T: ExternalLapack + 'static>(&mut self, library: T) {
514 let name = library.name().to_string();
515 self.lapack_libraries
516 .insert(name.clone(), Arc::new(Mutex::new(library)));
517
518 if !self.preferences.contains_key("lapack") {
520 self.preferences.insert("lapack".to_string(), name);
521 }
522 }
523
524 pub fn register_fft<T: ExternalFft + 'static>(&mut self, library: T) {
526 let name = library.name().to_string();
527 self.fft_libraries
528 .insert(name.clone(), Arc::new(Mutex::new(library)));
529
530 if !self.preferences.contains_key("fft") {
532 self.preferences.insert("fft".to_string(), name);
533 }
534 }
535
536 pub fn set_preference(
538 &mut self,
539 operation_type: &str,
540 library_name: &str,
541 ) -> ExternalResult<()> {
542 match operation_type {
543 "blas" => {
544 if !self.blas_libraries.contains_key(library_name) {
545 return Err(SimdError::ExternalLibraryError(format!(
546 "BLAS library '{}' not registered",
547 library_name
548 )));
549 }
550 }
551 "lapack" => {
552 if !self.lapack_libraries.contains_key(library_name) {
553 return Err(SimdError::ExternalLibraryError(format!(
554 "LAPACK library '{}' not registered",
555 library_name
556 )));
557 }
558 }
559 "fft" => {
560 if !self.fft_libraries.contains_key(library_name) {
561 return Err(SimdError::ExternalLibraryError(format!(
562 "FFT library '{}' not registered",
563 library_name
564 )));
565 }
566 }
567 _ => {
568 return Err(SimdError::InvalidInput(format!(
569 "Unknown operation type: {}",
570 operation_type
571 )));
572 }
573 }
574
575 self.preferences
576 .insert(operation_type.to_string(), library_name.to_string());
577 Ok(())
578 }
579
580 pub fn get_blas(&self) -> Option<Arc<Mutex<dyn ExternalBlas>>> {
582 self.preferences
583 .get("blas")
584 .and_then(|name| self.blas_libraries.get(name))
585 .cloned()
586 }
587
588 pub fn get_lapack(&self) -> Option<Arc<Mutex<dyn ExternalLapack>>> {
590 self.preferences
591 .get("lapack")
592 .and_then(|name| self.lapack_libraries.get(name))
593 .cloned()
594 }
595
596 pub fn get_fft(&self) -> Option<Arc<Mutex<dyn ExternalFft>>> {
598 self.preferences
599 .get("fft")
600 .and_then(|name| self.fft_libraries.get(name))
601 .cloned()
602 }
603
604 pub fn list_libraries(&self) -> Vec<String> {
606 let mut libraries = Vec::new();
607 libraries.extend(self.blas_libraries.keys().cloned());
608 libraries.extend(self.lapack_libraries.keys().cloned());
609 libraries.extend(self.fft_libraries.keys().cloned());
610 libraries.sort();
611 libraries.dedup();
612 libraries
613 }
614
615 pub fn check_availability(&self) -> HashMap<String, bool> {
617 let mut availability = HashMap::new();
618
619 #[cfg(not(feature = "no-std"))]
620 {
621 for (name, library) in &self.blas_libraries {
622 availability.insert(
623 name.clone(),
624 library
625 .lock()
626 .expect("lock should not be poisoned")
627 .is_available(),
628 );
629 }
630
631 for (name, library) in &self.lapack_libraries {
632 availability.insert(
633 name.clone(),
634 library
635 .lock()
636 .expect("lock should not be poisoned")
637 .is_available(),
638 );
639 }
640
641 for (name, library) in &self.fft_libraries {
642 availability.insert(
643 name.clone(),
644 library
645 .lock()
646 .expect("lock should not be poisoned")
647 .is_available(),
648 );
649 }
650 }
651
652 #[cfg(feature = "no-std")]
653 {
654 for (name, library) in &self.blas_libraries {
655 availability.insert(name.clone(), library.lock().is_available());
656 }
657
658 for (name, library) in &self.lapack_libraries {
659 availability.insert(name.clone(), library.lock().is_available());
660 }
661
662 for (name, library) in &self.fft_libraries {
663 availability.insert(name.clone(), library.lock().is_available());
664 }
665 }
666
667 availability
668 }
669}
670
671impl Default for ExternalLibraryRegistry {
672 fn default() -> Self {
673 let mut registry = Self::new();
674
675 registry.register_blas(OpenBlasAdapter::new());
677 registry.register_blas(MklAdapter::new());
678
679 registry
680 }
681}
682
683#[cfg(not(feature = "no-std"))]
685static EXTERNAL_REGISTRY: once_cell::sync::Lazy<std::sync::Mutex<ExternalLibraryRegistry>> =
686 once_cell::sync::Lazy::new(|| std::sync::Mutex::new(ExternalLibraryRegistry::default()));
687
688#[cfg(feature = "no-std")]
689static EXTERNAL_REGISTRY: spin::Once<spin::Mutex<ExternalLibraryRegistry>> = spin::Once::new();
690
691#[cfg(not(feature = "no-std"))]
693pub fn get_registry() -> &'static std::sync::Mutex<ExternalLibraryRegistry> {
694 &EXTERNAL_REGISTRY
695}
696
697#[cfg(feature = "no-std")]
698pub fn get_registry() -> &'static spin::Mutex<ExternalLibraryRegistry> {
699 EXTERNAL_REGISTRY.call_once(|| spin::Mutex::new(ExternalLibraryRegistry::default()))
700}
701
702pub fn external_dot(x: &[f32], y: &[f32]) -> ExternalResult<f32> {
704 #[cfg(not(feature = "no-std"))]
705 {
706 if let Some(blas) = get_registry()
707 .lock()
708 .expect("lock should not be poisoned")
709 .get_blas()
710 {
711 match blas
713 .lock()
714 .expect("matrix dimensions should be compatible for dot product")
715 .dot(x, y)
716 {
717 Ok(result) => Ok(result),
718 Err(_) => {
719 Ok(crate::vector::dot_product(x, y))
721 }
722 }
723 } else {
724 Ok(crate::vector::dot_product(x, y))
725 }
726 }
727 #[cfg(feature = "no-std")]
728 {
729 if let Some(blas) = get_registry().lock().get_blas() {
730 match blas.lock().dot(x, y) {
732 Ok(result) => Ok(result),
733 Err(_) => {
734 Ok(crate::vector::dot_product(x, y))
736 }
737 }
738 } else {
739 Ok(crate::vector::dot_product(x, y))
740 }
741 }
742}
743
744pub fn external_gemv(
746 alpha: f32,
747 a: &[f32],
748 m: usize,
749 n: usize,
750 x: &[f32],
751 beta: f32,
752 y: &mut [f32],
753) -> ExternalResult<()> {
754 #[cfg(not(feature = "no-std"))]
755 {
756 if let Some(blas) = get_registry()
757 .lock()
758 .expect("lock should not be poisoned")
759 .get_blas()
760 {
761 match blas
763 .lock()
764 .expect("lock should not be poisoned")
765 .gemv(alpha, a, m, n, x, beta, y)
766 {
767 Ok(()) => Ok(()),
768 Err(_) => {
769 use scirs2_autograd::ndarray::{Array1, Array2};
771 let a_matrix = Array2::from_shape_vec((m, n), a.to_vec()).map_err(|_| {
772 SimdError::ExternalLibraryError("Invalid matrix shape".to_string())
773 })?;
774 let x_vector = Array1::from_vec(x.to_vec());
775 let result = crate::matrix::matrix_vector_multiply_f32(&a_matrix, &x_vector);
776
777 for (yi, &ri) in y.iter_mut().zip(result.iter()) {
778 *yi = alpha * ri + beta * (*yi);
779 }
780 Ok(())
781 }
782 }
783 } else {
784 use scirs2_autograd::ndarray::{Array1, Array2};
786 let a_matrix = Array2::from_shape_vec((m, n), a.to_vec())
787 .map_err(|_| SimdError::ExternalLibraryError("Invalid matrix shape".to_string()))?;
788 let x_vector = Array1::from_vec(x.to_vec());
789 let result = crate::matrix::matrix_vector_multiply_f32(&a_matrix, &x_vector);
790
791 for (yi, &ri) in y.iter_mut().zip(result.iter()) {
792 *yi = alpha * ri + beta * (*yi);
793 }
794 Ok(())
795 }
796 }
797 #[cfg(feature = "no-std")]
798 {
799 if let Some(blas) = get_registry().lock().get_blas() {
800 match blas.lock().gemv(alpha, a, m, n, x, beta, y) {
802 Ok(()) => Ok(()),
803 Err(_) => {
804 use scirs2_autograd::ndarray::{Array1, Array2};
806 let a_matrix = Array2::from_shape_vec((m, n), a.to_vec()).map_err(|_| {
807 SimdError::ExternalLibraryError("Invalid matrix shape".to_string())
808 })?;
809 let x_vector = Array1::from_vec(x.to_vec());
810 let result = crate::matrix::matrix_vector_multiply_f32(&a_matrix, &x_vector);
811
812 for (yi, &ri) in y.iter_mut().zip(result.iter()) {
813 *yi = alpha * ri + beta * (*yi);
814 }
815 Ok(())
816 }
817 }
818 } else {
819 use scirs2_autograd::ndarray::{Array1, Array2};
821 let a_matrix = Array2::from_shape_vec((m, n), a.to_vec())
822 .map_err(|_| SimdError::ExternalLibraryError("Invalid matrix shape".to_string()))?;
823 let x_vector = Array1::from_vec(x.to_vec());
824 let result = crate::matrix::matrix_vector_multiply_f32(&a_matrix, &x_vector);
825
826 for (yi, &ri) in y.iter_mut().zip(result.iter()) {
827 *yi = alpha * ri + beta * (*yi);
828 }
829 Ok(())
830 }
831 }
832}
833
#[cfg(all(test, not(feature = "no-std")))]
mod tests {
    use super::*;

    /// Builds an OpenBLAS adapter that has already been initialized.
    fn ready_openblas() -> OpenBlasAdapter {
        let mut openblas = OpenBlasAdapter::new();
        openblas.initialize().expect("operation should succeed");
        openblas
    }

    #[test]
    fn test_mkl_adapter_creation() {
        let mkl = MklAdapter::new();
        assert_eq!(mkl.name(), "Intel MKL");
        assert_eq!(mkl.version(), "2024.0");
        // MKL is never reported as available in this build.
        assert!(!mkl.is_available());
    }

    #[test]
    fn test_openblas_adapter_creation() {
        let openblas = OpenBlasAdapter::new();
        assert_eq!(openblas.name(), "OpenBLAS");
        assert_eq!(openblas.version(), "0.3.21");
        assert!(openblas.is_available());
    }

    #[test]
    fn test_openblas_initialization() {
        assert!(OpenBlasAdapter::new().initialize().is_ok());
    }

    #[test]
    fn test_openblas_dot_product() {
        let openblas = ready_openblas();
        let lhs = vec![1.0, 2.0, 3.0];
        let rhs = vec![4.0, 5.0, 6.0];

        let product = openblas
            .dot(&lhs, &rhs)
            .expect("matrix dimensions should be compatible for dot product");

        // 1*4 + 2*5 + 3*6
        assert_eq!(product, 32.0);
    }

    #[test]
    fn test_openblas_scal() {
        let openblas = ready_openblas();
        let mut values = vec![1.0, 2.0, 3.0];

        openblas
            .scal(2.0, &mut values)
            .expect("operation should succeed");

        assert_eq!(values, vec![2.0, 4.0, 6.0]);
    }

    #[test]
    fn test_registry_blas_registration() {
        let mut registry = ExternalLibraryRegistry::new();
        registry.register_blas(OpenBlasAdapter::new());

        assert!(registry.get_blas().is_some());
        assert_eq!(registry.list_libraries(), vec!["OpenBLAS"]);
    }

    #[test]
    fn test_registry_preferences() {
        let mut registry = ExternalLibraryRegistry::new();
        registry.register_blas(OpenBlasAdapter::new());
        registry.register_blas(MklAdapter::new());

        registry
            .set_preference("blas", "Intel MKL")
            .expect("operation should succeed");
        assert!(registry.set_preference("blas", "Unknown").is_err());
    }

    #[test]
    fn test_registry_availability_check() {
        let availability = ExternalLibraryRegistry::default().check_availability();

        assert_eq!(availability.get("OpenBLAS"), Some(&true));
        assert_eq!(availability.get("Intel MKL"), Some(&false));
    }

    #[test]
    fn test_external_dot_fallback() {
        let lhs = vec![1.0, 2.0, 3.0];
        let rhs = vec![4.0, 5.0, 6.0];

        let product = external_dot(&lhs, &rhs).expect("operation should succeed");

        assert_eq!(product, 32.0);
    }

    #[test]
    fn test_invalid_dimensions() {
        let openblas = ready_openblas();
        let short = vec![1.0, 2.0];
        let long = vec![3.0, 4.0, 5.0];

        assert!(openblas.dot(&short, &long).is_err());
    }

    #[test]
    fn test_uninitialized_adapter() {
        let openblas = OpenBlasAdapter::new();
        let lhs = vec![1.0, 2.0, 3.0];
        let rhs = vec![4.0, 5.0, 6.0];

        assert!(openblas.dot(&lhs, &rhs).is_err());
    }
}
961}