rstsr_common/
flags.rs

1//! Flags for the crate.
2
3use crate::prelude_dev::*;
4use core::ffi::c_char;
5use rstsr_cblas_base::*;
6
7/* #region changeable default */
8
/// A type whose [`Default`] value can be replaced at runtime.
///
/// Implementations are generated by `impl_changeable_default!`, which keeps
/// the value currently in effect in a `static mut` variable.
pub trait ChangeableDefault {
    /// Overwrites the stored default with `val`.
    ///
    /// # Safety
    ///
    /// This function changes static mutable variable.
    /// It is better applying cargo feature instead of using this function.
    unsafe fn change_default(val: Self);

    /// Returns the default value currently stored.
    fn get_default() -> Self;
}

/// Implements [`ChangeableDefault`] and [`Default`] for a type.
///
/// - `$struct`: the type receiving both implementations; it is copied out of
///   the static on every read, so it has to be `Copy`.
/// - `$val`: name of the `static mut` holding the current default.
/// - `$default`: initial value of that static.
macro_rules! impl_changeable_default {
    ($struct:ty, $val:ident, $default:expr) => {
        static mut $val: $struct = $default;

        impl ChangeableDefault for $struct {
            unsafe fn change_default(val: Self) {
                // SAFETY: delegated to the caller of this `unsafe fn`; the
                // write is a plain assignment with no synchronization.
                unsafe {
                    $val = val;
                }
            }

            fn get_default() -> Self {
                // Reads the static by value; no reference to it is created.
                unsafe { $val }
            }
        }

        impl Default for $struct {
            fn default() -> Self {
                <Self as ChangeableDefault>::get_default()
            }
        }
    };
}
42
43/* #endregion */
44
45/* #region FlagOrder */
46
/// The order of the tensor.
///
/// # Default
///
/// Default order depends on cargo feature `col_major`.
/// If `col_major` is enabled, then [`FlagOrder::F`] is applied as default;
/// otherwise [`FlagOrder::C`] is applied as default.
///
/// # IMPORTANT NOTE
///
/// F-prefer is not a stable feature currently! We develop only in C-prefer
/// currently.
#[repr(u8)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum FlagOrder {
    /// row-major order.
    C = 101,
    /// column-major order.
    F = 102,
}

#[allow(non_upper_case_globals)]
impl FlagOrder {
    /// Alias of [`FlagOrder::C`].
    pub const RowMajor: Self = FlagOrder::C;
    /// Alias of [`FlagOrder::F`].
    pub const ColMajor: Self = FlagOrder::F;
}

// Not derivable: the default variant is selected by a cargo feature.
#[allow(clippy::derivable_impls)]
impl Default for FlagOrder {
    /// Returns [`FlagOrder::F`] when cargo feature `col_major` is enabled,
    /// [`FlagOrder::C`] otherwise.
    fn default() -> Self {
        if cfg!(feature = "col_major") {
            FlagOrder::F
        } else {
            FlagOrder::C
        }
    }
}
84
85/* #endregion */
86
87/* #region TensorIterOrder */
88
/// The policy of the tensor iterator.
///
/// The [`Default`] value of this type is provided through
/// `impl_changeable_default!` (initially [`TensorIterOrder::K`]).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TensorIterOrder {
    /// Row-major order.
    ///
    /// - absolutely safe for array iteration
    C,
    /// Column-major order.
    ///
    /// - absolutely safe for array iteration
    F,
    /// Automatically choose row/col-major order.
    ///
    /// - try c/f-contig first (also see [`TensorIterOrder::B`]),
    /// - try c/f-prefer second (also see [`TensorIterOrder::C`], [`TensorIterOrder::F`]),
    /// - otherwise [`FlagOrder::default()`], which is defined by crate feature `col_major`.
    ///
    /// - safe for multi-array iteration like `get_iter(a, b)`
    /// - not safe for cases like `a.iter().zip(b.iter())`
    A,
    /// Greedy when possible (reorder layouts during iteration).
    ///
    /// - safe for multi-array iteration like `get_iter(a, b)`
    /// - not safe for cases like `a.iter().zip(b.iter())`
    /// - if it is used to create a new array, the stride of new array will be in K order
    K,
    /// Greedy when possible (reset dimension to 1 if axis is broadcasted).
    ///
    /// - not safe for multi-array iteration like `get_iter(a, b)`
    /// - this is useful for inplace-assign broadcasted array.
    G,
    /// Sequential buffer.
    ///
    /// - not safe for multi-array iteration like `get_iter(a, b)`
    /// - this is useful for reshaping or all-contiguous cases.
    B,
}

// Stores the current default in `DEFAULT_TENSOR_ITER_ORDER`, starting at `K`.
impl_changeable_default!(TensorIterOrder, DEFAULT_TENSOR_ITER_ORDER, TensorIterOrder::K);
128
129/* #endregion */
130
131/* #region TensorCopyPolicy */
132
/// The policy of copying tensor.
///
/// The policies are plain `u8` constants rather than an enum.
// this is a workaround in stable rust
// when const enum can not be used as generic parameters
#[allow(non_snake_case)]
pub mod TensorCopyPolicy {
    /// Integer type carrying a copy policy.
    pub type FlagCopy = u8;

    /// Copy when needed
    pub const COPY_NEEDED: FlagCopy = 0;
    /// Force copy
    pub const COPY_TRUE: FlagCopy = 1;
    /// Force not copy; and when copy is required, it will emit error
    pub const COPY_FALSE: FlagCopy = 2;

    /// Policy used by default; same value as [`COPY_NEEDED`].
    pub const COPY_DEFAULT: FlagCopy = COPY_NEEDED;
}
151
152/* #endregion */
153
154/* #region blas-flags */
155
/// Transposition flag of a matrix operand.
///
/// Only `N`, `T` and `C` have single-character and `CBLAS_TRANSPOSE`
/// counterparts in this file; converting `CN` to those types is rejected.
#[repr(u8)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum FlagTrans {
    /// No transpose
    N = 111,
    /// Transpose
    T = 112,
    /// Conjugate transpose
    C = 113,
    /// Conjugate only
    CN = 114,
}
168
/// Side flag (left or right), convertible to and from `CBLAS_SIDE`.
#[repr(u8)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum FlagSide {
    /// Left side
    L = 141,
    /// Right side
    R = 142,
}
177
/// Triangle flag (upper or lower), convertible to and from `CBLAS_UPLO`.
#[repr(u8)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum FlagUpLo {
    /// Upper triangle
    U = 121,
    /// Lower triangle
    L = 122,
}
186
/// Diagonal flag (unit or non-unit), convertible to and from `CBLAS_DIAG`.
#[repr(u8)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum FlagDiag {
    /// Non-unit diagonal
    N = 131,
    /// Unit diagonal
    U = 132,
}
195
196/* #endregion */
197
198/* #region symm-flags */
199
/// Symmetry class of a matrix.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum FlagSymm {
    /// Symmetric matrix
    Sy,
    /// Hermitian matrix
    He,
    /// Anti-symmetric matrix
    Ay,
    /// Anti-Hermitian matrix
    Ah,
    /// Non-symmetric matrix
    N,
}
213
// `Tensor*` spellings of the `Flag*` enums; each alias names the same type.

/// Alias of [`FlagOrder`].
pub type TensorOrder = FlagOrder;
/// Alias of [`FlagDiag`].
pub type TensorDiag = FlagDiag;
/// Alias of [`FlagSide`].
pub type TensorSide = FlagSide;
/// Alias of [`FlagUpLo`].
pub type TensorUpLo = FlagUpLo;
/// Alias of [`FlagTrans`].
pub type TensorTrans = FlagTrans;
/// Alias of [`FlagSymm`].
pub type TensorSymm = FlagSymm;
220
221/* #endregion */
222
223/* #region flag alias */
224
// Re-export individual variants under descriptive names so that they can be
// written without the enum prefix.

// Transposition (`FlagTrans::CN` has no alias here).
pub use FlagTrans::C as ConjTrans;
pub use FlagTrans::N as NoTrans;
pub use FlagTrans::T as Trans;

// Side.
pub use FlagSide::L as Left;
pub use FlagSide::R as Right;

// Triangle.
pub use FlagUpLo::L as Lower;
pub use FlagUpLo::U as Upper;

// Diagonal.
pub use FlagDiag::N as NonUnit;
pub use FlagDiag::U as Unit;

// Memory order.
pub use FlagOrder::C as RowMajor;
pub use FlagOrder::F as ColMajor;
240
241/* #endregion */
242
243/* #region flag into */
244
245impl From<char> for FlagTrans {
246    fn from(val: char) -> Self {
247        match val {
248            'N' | 'n' => FlagTrans::N,
249            'T' | 't' => FlagTrans::T,
250            'C' | 'c' => FlagTrans::C,
251            _ => rstsr_invalid!(val).unwrap(),
252        }
253    }
254}
255
256impl From<FlagTrans> for char {
257    fn from(val: FlagTrans) -> Self {
258        match val {
259            FlagTrans::N => 'N',
260            FlagTrans::T => 'T',
261            FlagTrans::C => 'C',
262            _ => rstsr_invalid!(val).unwrap(),
263        }
264    }
265}
266
267impl From<FlagTrans> for c_char {
268    fn from(val: FlagTrans) -> Self {
269        match val {
270            FlagTrans::N => b'N' as c_char,
271            FlagTrans::T => b'T' as c_char,
272            FlagTrans::C => b'C' as c_char,
273            _ => rstsr_invalid!(val).unwrap(),
274        }
275    }
276}
277
278impl From<c_char> for FlagTrans {
279    fn from(val: c_char) -> Self {
280        match val as u8 {
281            b'N' => FlagTrans::N,
282            b'T' => FlagTrans::T,
283            b'C' => FlagTrans::C,
284            _ => rstsr_invalid!(val).unwrap(),
285        }
286    }
287}
288
289impl From<CBLAS_TRANSPOSE> for FlagTrans {
290    fn from(val: CBLAS_TRANSPOSE) -> Self {
291        match val {
292            CBLAS_TRANSPOSE::CblasNoTrans => FlagTrans::N,
293            CBLAS_TRANSPOSE::CblasTrans => FlagTrans::T,
294            CBLAS_TRANSPOSE::CblasConjTrans => FlagTrans::C,
295        }
296    }
297}
298
299impl From<FlagTrans> for CBLAS_TRANSPOSE {
300    fn from(val: FlagTrans) -> Self {
301        match val {
302            FlagTrans::N => CBLAS_TRANSPOSE::CblasNoTrans,
303            FlagTrans::T => CBLAS_TRANSPOSE::CblasTrans,
304            FlagTrans::C => CBLAS_TRANSPOSE::CblasConjTrans,
305            _ => rstsr_invalid!(val).unwrap(),
306        }
307    }
308}
309
310impl From<char> for FlagDiag {
311    fn from(val: char) -> Self {
312        match val {
313            'N' | 'n' => FlagDiag::N,
314            'U' | 'u' => FlagDiag::U,
315            _ => rstsr_invalid!(val).unwrap(),
316        }
317    }
318}
319
320impl From<FlagDiag> for char {
321    fn from(val: FlagDiag) -> Self {
322        match val {
323            FlagDiag::N => 'N',
324            FlagDiag::U => 'U',
325        }
326    }
327}
328
329impl From<FlagDiag> for c_char {
330    fn from(val: FlagDiag) -> Self {
331        match val {
332            FlagDiag::N => b'N' as c_char,
333            FlagDiag::U => b'U' as c_char,
334        }
335    }
336}
337
338impl From<c_char> for FlagDiag {
339    fn from(val: c_char) -> Self {
340        match val as u8 {
341            b'N' => FlagDiag::N,
342            b'U' => FlagDiag::U,
343            _ => rstsr_invalid!(val).unwrap(),
344        }
345    }
346}
347
348impl From<CBLAS_DIAG> for FlagDiag {
349    fn from(val: CBLAS_DIAG) -> Self {
350        match val {
351            CBLAS_DIAG::CblasNonUnit => FlagDiag::N,
352            CBLAS_DIAG::CblasUnit => FlagDiag::U,
353        }
354    }
355}
356
357impl From<FlagDiag> for CBLAS_DIAG {
358    fn from(val: FlagDiag) -> Self {
359        match val {
360            FlagDiag::N => CBLAS_DIAG::CblasNonUnit,
361            FlagDiag::U => CBLAS_DIAG::CblasUnit,
362        }
363    }
364}
365
366impl From<char> for FlagSide {
367    fn from(val: char) -> Self {
368        match val {
369            'L' | 'l' => FlagSide::L,
370            'R' | 'r' => FlagSide::R,
371            _ => rstsr_invalid!(val).unwrap(),
372        }
373    }
374}
375
376impl From<FlagSide> for char {
377    fn from(val: FlagSide) -> Self {
378        match val {
379            FlagSide::L => 'L',
380            FlagSide::R => 'R',
381        }
382    }
383}
384
385impl From<FlagSide> for c_char {
386    fn from(val: FlagSide) -> Self {
387        match val {
388            FlagSide::L => b'L' as c_char,
389            FlagSide::R => b'R' as c_char,
390        }
391    }
392}
393
394impl From<c_char> for FlagSide {
395    fn from(val: c_char) -> Self {
396        match val as u8 {
397            b'L' => FlagSide::L,
398            b'R' => FlagSide::R,
399            _ => rstsr_invalid!(val).unwrap(),
400        }
401    }
402}
403
404impl From<CBLAS_SIDE> for FlagSide {
405    fn from(val: CBLAS_SIDE) -> Self {
406        match val {
407            CBLAS_SIDE::CblasLeft => FlagSide::L,
408            CBLAS_SIDE::CblasRight => FlagSide::R,
409        }
410    }
411}
412
413impl From<FlagSide> for CBLAS_SIDE {
414    fn from(val: FlagSide) -> Self {
415        match val {
416            FlagSide::L => CBLAS_SIDE::CblasLeft,
417            FlagSide::R => CBLAS_SIDE::CblasRight,
418        }
419    }
420}
421
422impl From<char> for FlagUpLo {
423    fn from(val: char) -> Self {
424        match val {
425            'U' | 'u' => FlagUpLo::U,
426            'L' | 'l' => FlagUpLo::L,
427            _ => rstsr_invalid!(val).unwrap(),
428        }
429    }
430}
431
432impl From<FlagUpLo> for char {
433    fn from(val: FlagUpLo) -> Self {
434        match val {
435            FlagUpLo::U => 'U',
436            FlagUpLo::L => 'L',
437        }
438    }
439}
440
441impl From<FlagUpLo> for c_char {
442    fn from(val: FlagUpLo) -> Self {
443        match val {
444            FlagUpLo::U => b'U' as c_char,
445            FlagUpLo::L => b'L' as c_char,
446        }
447    }
448}
449
450impl From<c_char> for FlagUpLo {
451    fn from(val: c_char) -> Self {
452        match val as u8 {
453            b'U' => FlagUpLo::U,
454            b'L' => FlagUpLo::L,
455            _ => rstsr_invalid!(val).unwrap(),
456        }
457    }
458}
459
460impl From<CBLAS_UPLO> for FlagUpLo {
461    fn from(val: CBLAS_UPLO) -> Self {
462        match val {
463            CBLAS_UPLO::CblasUpper => FlagUpLo::U,
464            CBLAS_UPLO::CblasLower => FlagUpLo::L,
465        }
466    }
467}
468
469impl From<FlagUpLo> for CBLAS_UPLO {
470    fn from(val: FlagUpLo) -> Self {
471        match val {
472            FlagUpLo::U => CBLAS_UPLO::CblasUpper,
473            FlagUpLo::L => CBLAS_UPLO::CblasLower,
474        }
475    }
476}
477
478impl From<CBLAS_LAYOUT> for FlagOrder {
479    fn from(val: CBLAS_LAYOUT) -> Self {
480        match val {
481            CBLAS_LAYOUT::CblasRowMajor => FlagOrder::C,
482            CBLAS_LAYOUT::CblasColMajor => FlagOrder::F,
483        }
484    }
485}
486
487impl From<FlagOrder> for CBLAS_LAYOUT {
488    fn from(val: FlagOrder) -> Self {
489        match val {
490            FlagOrder::C => CBLAS_LAYOUT::CblasRowMajor,
491            FlagOrder::F => CBLAS_LAYOUT::CblasColMajor,
492        }
493    }
494}
495
496/* #endregion */
497
498/* #region flag flip */
499
500impl FlagOrder {
501    pub fn flip(&self) -> Self {
502        match self {
503            FlagOrder::C => FlagOrder::F,
504            FlagOrder::F => FlagOrder::C,
505        }
506    }
507}
508
509impl FlagTrans {
510    pub fn flip(&self, hermi: bool) -> Result<Self> {
511        match self {
512            FlagTrans::N => match hermi {
513                true => Ok(FlagTrans::C),
514                false => Ok(FlagTrans::T),
515            },
516            FlagTrans::T => Ok(FlagTrans::N),
517            FlagTrans::C => Ok(FlagTrans::N),
518            _ => rstsr_invalid!(self)?,
519        }
520    }
521}
522
523impl FlagSide {
524    pub fn flip(&self) -> Result<Self> {
525        match self {
526            FlagSide::L => Ok(FlagSide::R),
527            FlagSide::R => Ok(FlagSide::L),
528        }
529    }
530}
531
532impl FlagUpLo {
533    pub fn flip(&self) -> Result<Self> {
534        match self {
535            FlagUpLo::U => Ok(FlagUpLo::L),
536            FlagUpLo::L => Ok(FlagUpLo::U),
537        }
538    }
539}
540
541/* #endregion */