blas_array2/util/
blas_traits.rs1use crate::util::*;
2use ndarray::Dimension;
3use num_complex::*;
4use num_traits::*;
5
/// Complex scalar with `f32` real and imaginary parts (alias for `Complex<f32>`).
// The alias is deliberately lowercase, which is why the naming lint is
// silenced; presumably this mirrors the primitive names `f32`/`f64`.
#[allow(non_camel_case_types)]
pub type c32 = Complex<f32>;
/// Complex scalar with `f64` real and imaginary parts (alias for `Complex<f64>`).
// The alias is deliberately lowercase, which is why the naming lint is
// silenced; presumably this mirrors the primitive names `f32`/`f64`.
#[allow(non_camel_case_types)]
pub type c64 = Complex<f64>;
10
11pub trait BLASFloat: Num + Copy {
13 type RealFloat: BLASFloat;
14 fn is_complex() -> bool;
15 fn conj(x: Self) -> Self;
16 fn from_real(x: Self::RealFloat) -> Self;
17}
18
19impl BLASFloat for f32 {
20 type RealFloat = f32;
21 #[inline]
22 fn is_complex() -> bool {
23 false
24 }
25 #[inline]
26 fn conj(x: Self) -> Self {
27 x
28 }
29 #[inline]
30 fn from_real(x: Self::RealFloat) -> Self {
31 x
32 }
33}
34
35impl BLASFloat for f64 {
36 type RealFloat = f64;
37 #[inline]
38 fn is_complex() -> bool {
39 false
40 }
41 #[inline]
42 fn conj(x: Self) -> Self {
43 x
44 }
45 #[inline]
46 fn from_real(x: Self::RealFloat) -> Self {
47 x
48 }
49}
50
51impl BLASFloat for c32 {
52 type RealFloat = f32;
53 #[inline]
54 fn is_complex() -> bool {
55 true
56 }
57 #[inline]
58 fn conj(x: Self) -> Self {
59 x.conj()
60 }
61 #[inline]
62 fn from_real(x: Self::RealFloat) -> Self {
63 c32::new(x, 0.0)
64 }
65}
66
67impl BLASFloat for c64 {
68 type RealFloat = f64;
69 #[inline]
70 fn is_complex() -> bool {
71 true
72 }
73 #[inline]
74 fn conj(x: Self) -> Self {
75 x.conj()
76 }
77 #[inline]
78 fn from_real(x: Self::RealFloat) -> Self {
79 c64::new(x, 0.0)
80 }
81}
82
83pub trait BLASDriver<'c, F, D>
85where
86 D: Dimension,
87{
88 fn run_blas(self) -> Result<ArrayOut<'c, F, D>, BLASError>;
89}
90
91pub trait BLASBuilder_<'c, F, D>
93where
94 D: Dimension,
95{
96 fn driver(self) -> Result<impl BLASDriver<'c, F, D>, BLASError>;
97}
98
99pub trait BLASBuilder<'c, F, D>
100where
101 D: Dimension,
102{
103 fn run(self) -> Result<ArrayOut<'c, F, D>, BLASError>;
104}
105
/// Unit tests pinning the [`BLASFloat`] contract for every scalar type
/// implemented in this file.
#[cfg(test)]
mod tests {
    use super::*;

    /// `f32` is real: not complex, and both `conj` and `from_real` are identities.
    #[test]
    fn test_f32_blasfloat() {
        let x = 3.0_f32;
        // Boolean results are asserted directly instead of being compared with
        // a literal (Clippy `bool_assert_comparison`).
        assert!(!<f32 as BLASFloat>::is_complex());
        assert_eq!(<f32 as BLASFloat>::conj(x), x);
        assert_eq!(<f32 as BLASFloat>::from_real(x), x);
    }

    /// `f64` is real: not complex, and both `conj` and `from_real` are identities.
    #[test]
    fn test_f64_blasfloat() {
        let x = 3.0_f64;
        assert!(!<f64 as BLASFloat>::is_complex());
        assert_eq!(<f64 as BLASFloat>::conj(x), x);
        assert_eq!(<f64 as BLASFloat>::from_real(x), x);
    }

    /// `c32` is complex: `conj` matches `Complex::conj`, and `from_real` yields
    /// a zero imaginary part.
    #[test]
    fn test_c32_blasfloat() {
        let x = Complex::new(3.0_f32, 4.0_f32);
        assert!(<c32 as BLASFloat>::is_complex());
        assert_eq!(<c32 as BLASFloat>::conj(x), x.conj());
        assert_eq!(<c32 as BLASFloat>::from_real(3.0_f32), Complex::new(3.0_f32, 0.0_f32));
    }

    /// `c64` is complex: `conj` matches `Complex::conj`, and `from_real` yields
    /// a zero imaginary part.
    #[test]
    fn test_c64_blasfloat() {
        let x = Complex::new(3.0_f64, 4.0_f64);
        assert!(<c64 as BLASFloat>::is_complex());
        assert_eq!(<c64 as BLASFloat>::conj(x), x.conj());
        assert_eq!(<c64 as BLASFloat>::from_real(3.0_f64), Complex::new(3.0_f64, 0.0_f64));
    }
}