rstsr_common/
flags.rs

1//! Flags for the crate.
2
3use crate::prelude_dev::*;
4use core::ffi::c_char;
5use rstsr_lapack_ffi::cblas::{CBLAS_DIAG, CBLAS_LAYOUT, CBLAS_SIDE, CBLAS_TRANSPOSE, CBLAS_UPLO};
6
7/* #region changeable default */
8
/// A type whose [`Default`] value can be replaced at runtime.
///
/// Within this module, implementations are generated by
/// `impl_changeable_default!`, which stores the current value in a
/// `static mut` and implements [`Default`] in terms of
/// [`ChangeableDefault::get_default`].
pub trait ChangeableDefault {
    /// Replaces the value subsequently returned by
    /// [`ChangeableDefault::get_default`].
    ///
    /// # Safety
    ///
    /// Implementations may write to a `static mut` (the macro-generated ones
    /// do). The caller must ensure no other thread accesses that default
    /// concurrently. That means no concurrent `change_default`, `get_default`,
    /// or `Default::default` call for the same type; otherwise the access is a
    /// data race and undefined behaviour.
    ///
    /// Selecting the default through a cargo feature is preferable to calling
    /// this function.
    unsafe fn change_default(val: Self);

    /// Returns the current default value.
    fn get_default() -> Self;
}
17
/// Implements `ChangeableDefault` and [`Default`] for `$struct`.
///
/// The current default lives in a `static mut` named `$val`, initialised to
/// `$default` (which must therefore be a constant expression). The generated
/// `Default::default` returns the current value of that static, so it reflects
/// any earlier `change_default` call.
///
/// `get_default` returns the static by value, so `$struct` must be `Copy`.
macro_rules! impl_changeable_default {
    ($struct:ty, $val:ident, $default:expr) => {
        static mut $val: $struct = $default;

        impl ChangeableDefault for $struct {
            unsafe fn change_default(val: Self) {
                // SAFETY: the caller guarantees (per the trait's `# Safety`
                // contract) that no other access to this static runs
                // concurrently with this write.
                unsafe {
                    $val = val;
                }
            }

            fn get_default() -> Self {
                // SAFETY: a by-value copy of the static; no reference to it is
                // created. This is sound as long as no `change_default` runs
                // concurrently, which `change_default`'s callers must ensure.
                unsafe { $val }
            }
        }

        impl Default for $struct {
            fn default() -> Self {
                <$struct>::get_default()
            }
        }
    };
}
42
43/* #endregion */
44
45/* #region FlagOrder */
46
/// Memory order (layout) of a tensor.
///
/// # Default
///
/// [`FlagOrder::default()`] returns [`FlagOrder::F`] when the cargo feature
/// `col_major` is enabled, and [`FlagOrder::C`] otherwise. The choice is made at
/// compile time.
///
/// # IMPORTANT NOTE
///
/// Column-major (F) as the default order is not a stable feature yet;
/// development currently targets row-major (C) order only.
#[repr(u8)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum FlagOrder {
    /// Row-major order.
    C = 101,
    /// Column-major order.
    F = 102,
}

// The alias constants use CamelCase names (like enum variants) rather than
// SCREAMING_CASE, hence the lint allowance.
#[allow(non_upper_case_globals)]
impl FlagOrder {
    /// Alias of [`FlagOrder::C`].
    pub const RowMajor: Self = FlagOrder::C;
    /// Alias of [`FlagOrder::F`].
    pub const ColMajor: Self = FlagOrder::F;
}
73
74#[allow(clippy::derivable_impls)]
75impl Default for FlagOrder {
76    fn default() -> Self {
77        if cfg!(feature = "col_major") {
78            return FlagOrder::F;
79        } else {
80            return FlagOrder::C;
81        }
82    }
83}
84
85/* #endregion */
86
87/* #region TensorIterOrder */
88
/// The policy of the tensor iterator.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TensorIterOrder {
    /// Row-major order.
    ///
    /// - always safe for array iteration
    C,
    /// Column-major order.
    ///
    /// - always safe for array iteration
    F,
    /// Automatically choose row/col-major order.
    ///
    /// - try c/f-contiguous first (also see [`TensorIterOrder::B`]),
    /// - try c/f-prefer second (also see [`TensorIterOrder::C`],
    ///   [`TensorIterOrder::F`]),
    /// - otherwise fall back to [`FlagOrder::default()`], which is
    ///   [`FlagOrder::F`] when the cargo feature `col_major` is enabled and
    ///   [`FlagOrder::C`] otherwise.
    ///
    /// - safe for multi-array iteration like `get_iter(a, b)`
    /// - not safe for cases like `a.iter().zip(b.iter())`
    A,
    /// Greedy when possible (reorder layouts during iteration).
    ///
    /// - safe for multi-array iteration like `get_iter(a, b)`
    /// - not safe for cases like `a.iter().zip(b.iter())`
    /// - if it is used to create a new array, the strides of the new array
    ///   will be in K order
    K,
    /// Greedy when possible (reset dimension to 1 if axis is broadcasted).
    ///
    /// - not safe for multi-array iteration like `get_iter(a, b)`
    /// - useful for in-place assignment to a broadcasted array
    G,
    /// Sequential buffer.
    ///
    /// - not safe for multi-array iteration like `get_iter(a, b)`
    /// - useful for reshaping or all-contiguous cases
    B,
}
129
130impl_changeable_default!(TensorIterOrder, DEFAULT_TENSOR_ITER_ORDER, TensorIterOrder::K);
131
132/* #endregion */
133
134/* #region TensorCopyPolicy */
135
/// Copy policies for tensor operations, encoded as plain `u8` constants.
///
/// A module of constants is used instead of an `enum` as a stable-Rust
/// workaround: enum values cannot be used as const generic parameters.
#[allow(non_snake_case)]
pub mod TensorCopyPolicy {
    /// Integer type carrying a copy policy.
    pub type FlagCopy = u8;

    /// Copy only when needed.
    pub const COPY_NEEDED: FlagCopy = 0;
    /// The default policy; identical to [`COPY_NEEDED`].
    pub const COPY_DEFAULT: FlagCopy = COPY_NEEDED;
    /// Always copy.
    pub const COPY_TRUE: FlagCopy = 1;
    /// Never copy; an error is emitted when a copy would be required.
    pub const COPY_FALSE: FlagCopy = 2;
}
154
155/* #endregion */
156
157/* #region blas-flags */
158
/// Transposition flag for BLAS/LAPACK-style routines.
///
/// [`FlagTrans::Undefined`] is the [`Default`]; it represents "not set".
#[repr(u8)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum FlagTrans {
    /// Unset / unspecified (the default).
    #[default]
    Undefined = 0,
    /// No transpose
    N = 111,
    /// Transpose
    T = 112,
    /// Conjugate transpose
    C = 113,
    /// Conjugate only (no transpose)
    CN = 114,
}

/// Which side a matrix operand is applied from.
///
/// [`FlagSide::Undefined`] is the [`Default`]; it represents "not set".
#[repr(u8)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum FlagSide {
    /// Unset / unspecified (the default).
    #[default]
    Undefined = 0,
    /// Left side
    L = 141,
    /// Right side
    R = 142,
}

/// Which triangle of a matrix is referenced.
///
/// [`FlagUpLo::Undefined`] is the [`Default`]; it represents "not set".
#[repr(u8)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum FlagUpLo {
    /// Unset / unspecified (the default).
    #[default]
    Undefined = 0,
    /// Upper triangle
    U = 121,
    /// Lower triangle
    L = 122,
}

/// Whether a triangular matrix has an implicit unit diagonal.
///
/// [`FlagDiag::Undefined`] is the [`Default`]; it represents "not set".
#[repr(u8)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum FlagDiag {
    /// Unset / unspecified (the default).
    #[default]
    Undefined = 0,
    /// Non-unit diagonal
    N = 131,
    /// Unit diagonal
    U = 132,
}
206
207/* #endregion */
208
209/* #region symm-flags */
210
/// Symmetry class of a (square) matrix.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum FlagSymm {
    /// Symmetric matrix (A = Aᵀ)
    Sy,
    /// Hermitian matrix (A = Aᴴ)
    He,
    /// Anti-symmetric matrix (A = -Aᵀ)
    Ay,
    /// Anti-Hermitian matrix (A = -Aᴴ)
    Ah,
    /// Non-symmetric matrix (no symmetry assumed)
    N,
}
224
225pub type TensorOrder = FlagOrder;
226pub type TensorDiag = FlagDiag;
227pub type TensorSide = FlagSide;
228pub type TensorUpLo = FlagUpLo;
229pub type TensorTrans = FlagTrans;
230pub type TensorSymm = FlagSymm;
231
232/* #endregion */
233
234/* #region flag alias */
235
236pub use FlagTrans::C as ConjTrans;
237pub use FlagTrans::N as NoTrans;
238pub use FlagTrans::T as Trans;
239
240pub use FlagSide::L as Left;
241pub use FlagSide::R as Right;
242
243pub use FlagUpLo::L as Lower;
244pub use FlagUpLo::U as Upper;
245
246pub use FlagDiag::N as NonUnit;
247pub use FlagDiag::U as Unit;
248
249pub use FlagOrder::C as RowMajor;
250pub use FlagOrder::F as ColMajor;
251
252/* #endregion */
253
254/* #region flag into */
255
256impl From<char> for FlagTrans {
257    fn from(val: char) -> Self {
258        match val {
259            'N' | 'n' => FlagTrans::N,
260            'T' | 't' => FlagTrans::T,
261            'C' | 'c' => FlagTrans::C,
262            _ => rstsr_invalid!(val).unwrap(),
263        }
264    }
265}
266
267impl From<FlagTrans> for char {
268    fn from(val: FlagTrans) -> Self {
269        match val {
270            FlagTrans::N => 'N',
271            FlagTrans::T => 'T',
272            FlagTrans::C => 'C',
273            _ => rstsr_invalid!(val).unwrap(),
274        }
275    }
276}
277
278impl From<FlagTrans> for c_char {
279    fn from(val: FlagTrans) -> Self {
280        match val {
281            FlagTrans::N => b'N' as c_char,
282            FlagTrans::T => b'T' as c_char,
283            FlagTrans::C => b'C' as c_char,
284            _ => rstsr_invalid!(val).unwrap(),
285        }
286    }
287}
288
289impl From<c_char> for FlagTrans {
290    fn from(val: c_char) -> Self {
291        match val as u8 {
292            b'N' => FlagTrans::N,
293            b'T' => FlagTrans::T,
294            b'C' => FlagTrans::C,
295            _ => rstsr_invalid!(val).unwrap(),
296        }
297    }
298}
299
300impl From<CBLAS_TRANSPOSE> for FlagTrans {
301    fn from(val: CBLAS_TRANSPOSE) -> Self {
302        match val {
303            CBLAS_TRANSPOSE::CblasNoTrans => FlagTrans::N,
304            CBLAS_TRANSPOSE::CblasTrans => FlagTrans::T,
305            CBLAS_TRANSPOSE::CblasConjTrans => FlagTrans::C,
306        }
307    }
308}
309
310impl From<FlagTrans> for CBLAS_TRANSPOSE {
311    fn from(val: FlagTrans) -> Self {
312        match val {
313            FlagTrans::N => CBLAS_TRANSPOSE::CblasNoTrans,
314            FlagTrans::T => CBLAS_TRANSPOSE::CblasTrans,
315            FlagTrans::C => CBLAS_TRANSPOSE::CblasConjTrans,
316            _ => rstsr_invalid!(val).unwrap(),
317        }
318    }
319}
320
321impl From<char> for FlagDiag {
322    fn from(val: char) -> Self {
323        match val {
324            'N' | 'n' => FlagDiag::N,
325            'U' | 'u' => FlagDiag::U,
326            _ => rstsr_invalid!(val).unwrap(),
327        }
328    }
329}
330
331impl From<FlagDiag> for char {
332    fn from(val: FlagDiag) -> Self {
333        match val {
334            FlagDiag::N => 'N',
335            FlagDiag::U => 'U',
336            _ => rstsr_invalid!(val).unwrap(),
337        }
338    }
339}
340
341impl From<FlagDiag> for c_char {
342    fn from(val: FlagDiag) -> Self {
343        match val {
344            FlagDiag::N => b'N' as c_char,
345            FlagDiag::U => b'U' as c_char,
346            _ => rstsr_invalid!(val).unwrap(),
347        }
348    }
349}
350
351impl From<c_char> for FlagDiag {
352    fn from(val: c_char) -> Self {
353        match val as u8 {
354            b'N' => FlagDiag::N,
355            b'U' => FlagDiag::U,
356            _ => rstsr_invalid!(val).unwrap(),
357        }
358    }
359}
360
361impl From<CBLAS_DIAG> for FlagDiag {
362    fn from(val: CBLAS_DIAG) -> Self {
363        match val {
364            CBLAS_DIAG::CblasNonUnit => FlagDiag::N,
365            CBLAS_DIAG::CblasUnit => FlagDiag::U,
366        }
367    }
368}
369
370impl From<FlagDiag> for CBLAS_DIAG {
371    fn from(val: FlagDiag) -> Self {
372        match val {
373            FlagDiag::N => CBLAS_DIAG::CblasNonUnit,
374            FlagDiag::U => CBLAS_DIAG::CblasUnit,
375            _ => rstsr_invalid!(val).unwrap(),
376        }
377    }
378}
379
380impl From<char> for FlagSide {
381    fn from(val: char) -> Self {
382        match val {
383            'L' | 'l' => FlagSide::L,
384            'R' | 'r' => FlagSide::R,
385            _ => rstsr_invalid!(val).unwrap(),
386        }
387    }
388}
389
390impl From<FlagSide> for char {
391    fn from(val: FlagSide) -> Self {
392        match val {
393            FlagSide::L => 'L',
394            FlagSide::R => 'R',
395            _ => rstsr_invalid!(val).unwrap(),
396        }
397    }
398}
399
400impl From<FlagSide> for c_char {
401    fn from(val: FlagSide) -> Self {
402        match val {
403            FlagSide::L => b'L' as c_char,
404            FlagSide::R => b'R' as c_char,
405            _ => rstsr_invalid!(val).unwrap(),
406        }
407    }
408}
409
410impl From<c_char> for FlagSide {
411    fn from(val: c_char) -> Self {
412        match val as u8 {
413            b'L' => FlagSide::L,
414            b'R' => FlagSide::R,
415            _ => rstsr_invalid!(val).unwrap(),
416        }
417    }
418}
419
420impl From<CBLAS_SIDE> for FlagSide {
421    fn from(val: CBLAS_SIDE) -> Self {
422        match val {
423            CBLAS_SIDE::CblasLeft => FlagSide::L,
424            CBLAS_SIDE::CblasRight => FlagSide::R,
425        }
426    }
427}
428
429impl From<FlagSide> for CBLAS_SIDE {
430    fn from(val: FlagSide) -> Self {
431        match val {
432            FlagSide::L => CBLAS_SIDE::CblasLeft,
433            FlagSide::R => CBLAS_SIDE::CblasRight,
434            _ => rstsr_invalid!(val).unwrap(),
435        }
436    }
437}
438
439impl From<char> for FlagUpLo {
440    fn from(val: char) -> Self {
441        match val {
442            'U' | 'u' => FlagUpLo::U,
443            'L' | 'l' => FlagUpLo::L,
444            _ => rstsr_invalid!(val).unwrap(),
445        }
446    }
447}
448
449impl From<FlagUpLo> for char {
450    fn from(val: FlagUpLo) -> Self {
451        match val {
452            FlagUpLo::U => 'U',
453            FlagUpLo::L => 'L',
454            _ => rstsr_invalid!(val).unwrap(),
455        }
456    }
457}
458
459impl From<FlagUpLo> for c_char {
460    fn from(val: FlagUpLo) -> Self {
461        match val {
462            FlagUpLo::U => b'U' as c_char,
463            FlagUpLo::L => b'L' as c_char,
464            _ => rstsr_invalid!(val).unwrap(),
465        }
466    }
467}
468
469impl From<c_char> for FlagUpLo {
470    fn from(val: c_char) -> Self {
471        match val as u8 {
472            b'U' => FlagUpLo::U,
473            b'L' => FlagUpLo::L,
474            _ => rstsr_invalid!(val).unwrap(),
475        }
476    }
477}
478
479impl From<CBLAS_UPLO> for FlagUpLo {
480    fn from(val: CBLAS_UPLO) -> Self {
481        match val {
482            CBLAS_UPLO::CblasUpper => FlagUpLo::U,
483            CBLAS_UPLO::CblasLower => FlagUpLo::L,
484        }
485    }
486}
487
488impl From<FlagUpLo> for CBLAS_UPLO {
489    fn from(val: FlagUpLo) -> Self {
490        match val {
491            FlagUpLo::U => CBLAS_UPLO::CblasUpper,
492            FlagUpLo::L => CBLAS_UPLO::CblasLower,
493            _ => rstsr_invalid!(val).unwrap(),
494        }
495    }
496}
497
498impl From<CBLAS_LAYOUT> for FlagOrder {
499    fn from(val: CBLAS_LAYOUT) -> Self {
500        match val {
501            CBLAS_LAYOUT::CblasRowMajor => FlagOrder::C,
502            CBLAS_LAYOUT::CblasColMajor => FlagOrder::F,
503        }
504    }
505}
506
507impl From<FlagOrder> for CBLAS_LAYOUT {
508    fn from(val: FlagOrder) -> Self {
509        match val {
510            FlagOrder::C => CBLAS_LAYOUT::CblasRowMajor,
511            FlagOrder::F => CBLAS_LAYOUT::CblasColMajor,
512        }
513    }
514}
515
516/* #endregion */
517
518/* #region flag flip */
519
520impl FlagOrder {
521    pub fn flip(&self) -> Self {
522        match self {
523            FlagOrder::C => FlagOrder::F,
524            FlagOrder::F => FlagOrder::C,
525        }
526    }
527}
528
529impl FlagTrans {
530    pub fn flip(&self, hermi: bool) -> Result<Self> {
531        match self {
532            FlagTrans::N => match hermi {
533                true => Ok(FlagTrans::C),
534                false => Ok(FlagTrans::T),
535            },
536            FlagTrans::T => Ok(FlagTrans::N),
537            FlagTrans::C => Ok(FlagTrans::N),
538            _ => rstsr_invalid!(self)?,
539        }
540    }
541}
542
543impl FlagSide {
544    pub fn flip(&self) -> Result<Self> {
545        match self {
546            FlagSide::L => Ok(FlagSide::R),
547            FlagSide::R => Ok(FlagSide::L),
548            _ => rstsr_invalid!(self)?,
549        }
550    }
551}
552
553impl FlagUpLo {
554    pub fn flip(&self) -> Result<Self> {
555        match self {
556            FlagUpLo::U => Ok(FlagUpLo::L),
557            FlagUpLo::L => Ok(FlagUpLo::U),
558            _ => rstsr_invalid!(self)?,
559        }
560    }
561}
562
563/* #endregion */