// blas_array2/util/blas_traits.rs

1use crate::util::*;
2use ndarray::Dimension;
3use num_complex::*;
4use num_traits::*;
5
6#[allow(non_camel_case_types)]
7pub type c32 = Complex<f32>;
8#[allow(non_camel_case_types)]
9pub type c64 = Complex<f64>;
10
11/// Trait for defining real part float types
12pub trait BLASFloat: Num + Copy {
13    type RealFloat: BLASFloat;
14    fn is_complex() -> bool;
15    fn conj(x: Self) -> Self;
16    fn from_real(x: Self::RealFloat) -> Self;
17}
18
19impl BLASFloat for f32 {
20    type RealFloat = f32;
21    #[inline]
22    fn is_complex() -> bool {
23        false
24    }
25    #[inline]
26    fn conj(x: Self) -> Self {
27        x
28    }
29    #[inline]
30    fn from_real(x: Self::RealFloat) -> Self {
31        x
32    }
33}
34
35impl BLASFloat for f64 {
36    type RealFloat = f64;
37    #[inline]
38    fn is_complex() -> bool {
39        false
40    }
41    #[inline]
42    fn conj(x: Self) -> Self {
43        x
44    }
45    #[inline]
46    fn from_real(x: Self::RealFloat) -> Self {
47        x
48    }
49}
50
51impl BLASFloat for c32 {
52    type RealFloat = f32;
53    #[inline]
54    fn is_complex() -> bool {
55        true
56    }
57    #[inline]
58    fn conj(x: Self) -> Self {
59        x.conj()
60    }
61    #[inline]
62    fn from_real(x: Self::RealFloat) -> Self {
63        c32::new(x, 0.0)
64    }
65}
66
67impl BLASFloat for c64 {
68    type RealFloat = f64;
69    #[inline]
70    fn is_complex() -> bool {
71        true
72    }
73    #[inline]
74    fn conj(x: Self) -> Self {
75        x.conj()
76    }
77    #[inline]
78    fn from_real(x: Self::RealFloat) -> Self {
79        c64::new(x, 0.0)
80    }
81}
82
83/// Trait for BLAS drivers
84pub trait BLASDriver<'c, F, D>
85where
86    D: Dimension,
87{
88    fn run_blas(self) -> Result<ArrayOut<'c, F, D>, BLASError>;
89}
90
91/// Trait for BLAS builder prototypes
92pub trait BLASBuilder_<'c, F, D>
93where
94    D: Dimension,
95{
96    fn driver(self) -> Result<impl BLASDriver<'c, F, D>, BLASError>;
97}
98
99pub trait BLASBuilder<'c, F, D>
100where
101    D: Dimension,
102{
103    fn run(self) -> Result<ArrayOut<'c, F, D>, BLASError>;
104}
105
// The following tests were written with assistance from DeepSeek.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_f32_blasfloat() {
        let x = 3.0_f32;
        // `RealFloat` must be `f32` itself; this line fails to compile otherwise.
        let _: <f32 as BLASFloat>::RealFloat = x;
        assert!(!<f32 as BLASFloat>::is_complex());
        // Conjugation is the identity for reals, including negative values.
        assert_eq!(<f32 as BLASFloat>::conj(x), x);
        assert_eq!(<f32 as BLASFloat>::conj(-x), -x);
        assert_eq!(<f32 as BLASFloat>::from_real(x), x);
    }

    #[test]
    fn test_f64_blasfloat() {
        let x = 3.0_f64;
        // `RealFloat` must be `f64` itself; this line fails to compile otherwise.
        let _: <f64 as BLASFloat>::RealFloat = x;
        assert!(!<f64 as BLASFloat>::is_complex());
        // Conjugation is the identity for reals, including negative values.
        assert_eq!(<f64 as BLASFloat>::conj(x), x);
        assert_eq!(<f64 as BLASFloat>::conj(-x), -x);
        assert_eq!(<f64 as BLASFloat>::from_real(x), x);
    }

    #[test]
    fn test_c32_blasfloat() {
        let x = c32::new(3.0, 4.0);
        // `RealFloat` must be `f32`; this line fails to compile otherwise.
        let _: <c32 as BLASFloat>::RealFloat = 3.0_f32;
        assert!(<c32 as BLASFloat>::is_complex());
        // Compare against an explicit value: checking against `x.conj()` would
        // only show that the wrapper delegates, not that the result is right.
        assert_eq!(<c32 as BLASFloat>::conj(x), c32::new(3.0, -4.0));
        // Conjugating twice must round-trip.
        assert_eq!(<c32 as BLASFloat>::conj(<c32 as BLASFloat>::conj(x)), x);
        assert_eq!(<c32 as BLASFloat>::from_real(3.0_f32), c32::new(3.0, 0.0));
    }

    #[test]
    fn test_c64_blasfloat() {
        let x = c64::new(3.0, 4.0);
        // `RealFloat` must be `f64`; this line fails to compile otherwise.
        let _: <c64 as BLASFloat>::RealFloat = 3.0_f64;
        assert!(<c64 as BLASFloat>::is_complex());
        // Compare against an explicit value: checking against `x.conj()` would
        // only show that the wrapper delegates, not that the result is right.
        assert_eq!(<c64 as BLASFloat>::conj(x), c64::new(3.0, -4.0));
        // Conjugating twice must round-trip.
        assert_eq!(<c64 as BLASFloat>::conj(<c64 as BLASFloat>::conj(x)), x);
        assert_eq!(<c64 as BLASFloat>::from_real(3.0_f64), c64::new(3.0, 0.0));
    }
}