rstsr_common/
flags.rs

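//! Flags for the crate: memory order, tensor iteration order, copy policy, and BLAS/LAPACK-style flags (transpose, side, triangle, diagonal, symmetry) together with their character and CBLAS conversions.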
2
3use crate::prelude_dev::*;
4use core::ffi::c_char;
5use rstsr_lapack_ffi::cblas::{CBLAS_DIAG, CBLAS_LAYOUT, CBLAS_SIDE, CBLAS_TRANSPOSE, CBLAS_UPLO};
6
7/* #region changeable default */
8
/// A type whose [`Default`] value can be replaced at runtime.
///
/// Implementations generated by `impl_changeable_default!` keep the current
/// default in a `static mut`, and also implement [`Default`] by delegating to
/// [`ChangeableDefault::get_default`].
pub trait ChangeableDefault {
    /// Replaces the current default value with `val`.
    ///
    /// # Safety
    ///
    /// This function changes static mutable variable.
    /// It is better applying cargo feature instead of using this function.
    ///
    /// The caller must ensure that no other thread reads or writes the same
    /// default while this call runs. That includes reads through
    /// [`ChangeableDefault::get_default`] or [`Default::default`]. For a
    /// `static mut`-backed implementation, concurrent access is a data race
    /// and therefore undefined behavior.
    unsafe fn change_default(val: Self);

    /// Returns the current default value.
    fn get_default() -> Self;
}
17
/// Implements [`ChangeableDefault`] and [`Default`] for `$struct`.
///
/// The current default lives in a `static mut` named `$val`, initialised to
/// `$default`. `$struct` must be `Copy`, because `get_default` returns the
/// stored value by copying it out of the static.
macro_rules! impl_changeable_default {
    ($struct:ty, $val:ident, $default:expr) => {
        static mut $val: $struct = $default;

        impl ChangeableDefault for $struct {
            unsafe fn change_default(val: Self) {
                // SAFETY: the caller of this `unsafe fn` guarantees that no
                // other access to `$val` happens concurrently.
                unsafe {
                    $val = val;
                }
            }

            fn get_default() -> Self {
                // SAFETY: by-value read of the static (no reference is taken).
                // It can only race with `change_default`, whose caller must
                // rule that out.
                unsafe { $val }
            }
        }

        impl Default for $struct
        where
            Self: ChangeableDefault,
        {
            fn default() -> Self {
                <$struct>::get_default()
            }
        }
    };
}
42
43/* #endregion */
44
45/* #region FlagOrder */
46
/// The order of the tensor.
///
/// # Default
///
/// The default order depends on the cargo feature `col_major`.
/// If `col_major` is enabled, [`FlagOrder::F`] is the default;
/// otherwise [`FlagOrder::C`] is the default.
///
/// # IMPORTANT NOTE
///
/// A column-major (F-prefer) default is not a stable feature currently! We
/// develop only with the row-major (C-prefer) default currently.
#[repr(u8)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum FlagOrder {
    /// Row-major order.
    C = 101,
    /// Column-major order.
    F = 102,
}
67
68#[allow(non_upper_case_globals)]
69impl FlagOrder {
70    pub const RowMajor: Self = FlagOrder::C;
71    pub const ColMajor: Self = FlagOrder::F;
72}
73
74#[allow(clippy::derivable_impls)]
75impl Default for FlagOrder {
76    fn default() -> Self {
77        if cfg!(feature = "col_major") {
78            return FlagOrder::F;
79        } else {
80            return FlagOrder::C;
81        }
82    }
83}
84
85/* #endregion */
86
87/* #region TensorIterOrder */
88
/// The policy of the tensor iterator.
///
/// The default value is [`TensorIterOrder::K`]. It can be replaced at runtime
/// through `ChangeableDefault::change_default`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TensorIterOrder {
    /// Row-major order.
    ///
    /// - absolutely safe for array iteration
    C,
    /// Column-major order.
    ///
    /// - absolutely safe for array iteration
    F,
    /// Automatically choose row/col-major order.
    ///
    /// - try c/f-contig first (also see [`TensorIterOrder::B`]),
    /// - try c/f-prefer second (also see [`TensorIterOrder::C`],
    ///   [`TensorIterOrder::F`]),
    /// - otherwise [`FlagOrder::default()`], which is controlled by the crate
    ///   feature `col_major`.
    ///
    /// - safe for multi-array iteration like `get_iter(a, b)`
    /// - not safe for cases like `a.iter().zip(b.iter())`
    A,
    /// Greedy when possible (reorder layouts during iteration).
    ///
    /// - safe for multi-array iteration like `get_iter(a, b)`
    /// - not safe for cases like `a.iter().zip(b.iter())`
    /// - if it is used to create a new array, the stride of the new array will
    ///   be in K order
    K,
    /// Greedy when possible (reset dimension to 1 if axis is broadcasted).
    ///
    /// - not safe for multi-array iteration like `get_iter(a, b)`
    /// - this is useful for inplace-assigning a broadcasted array.
    G,
    /// Sequential buffer.
    ///
    /// - not safe for multi-array iteration like `get_iter(a, b)`
    /// - this is useful for reshaping or all-contiguous cases.
    B,
}
129
130impl_changeable_default!(TensorIterOrder, DEFAULT_TENSOR_ITER_ORDER, TensorIterOrder::K);
131
132/* #endregion */
133
134/* #region TensorCopyPolicy */
135
/// Flags that decide whether a tensor operation copies its data.
///
/// These are plain `u8` constants rather than an enum: on stable Rust an enum
/// cannot be used as a const generic parameter.
#[allow(non_snake_case)]
pub mod TensorCopyPolicy {
    /// Integer representation of a copy policy.
    pub type FlagCopy = u8;

    /// Copy only when a copy is required.
    pub const COPY_NEEDED: FlagCopy = 0;
    /// Always copy.
    pub const COPY_TRUE: FlagCopy = 1;
    /// Never copy; if a copy would be required, an error is emitted instead.
    pub const COPY_FALSE: FlagCopy = 2;

    /// Policy used when none is given explicitly.
    pub const COPY_DEFAULT: FlagCopy = COPY_NEEDED;
}
154
155/* #endregion */
156
157/* #region blas-flags */
158
/// Transpose flag for BLAS/LAPACK-style matrix operations.
#[repr(u8)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum FlagTrans {
    /// No transpose
    N = 111,
    /// Transpose
    T = 112,
    /// Conjugate transpose
    C = 113,
    /// Conjugate only (no transpose)
    ///
    /// The `char`, `c_char` and `CBLAS_TRANSPOSE` conversions in this module,
    /// and [`FlagTrans::flip`], reject this variant.
    CN = 114,
}
171
/// Which side a matrix operand is applied from.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
#[repr(u8)]
pub enum FlagSide {
    /// Operand applied from the left.
    L = 141,
    /// Operand applied from the right.
    R = 142,
}
180
/// Which triangle of a matrix is referenced.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
#[repr(u8)]
pub enum FlagUpLo {
    /// The upper triangle.
    U = 121,
    /// The lower triangle.
    L = 122,
}
189
/// Whether a triangular matrix has an implicit unit diagonal.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
#[repr(u8)]
pub enum FlagDiag {
    /// The diagonal is stored explicitly (non-unit).
    N = 131,
    /// The diagonal is the unit diagonal.
    U = 132,
}
198
199/* #endregion */
200
201/* #region symm-flags */
202
/// Symmetry property of a matrix.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum FlagSymm {
    /// Symmetric.
    Sy,
    /// Hermitian.
    He,
    /// Anti-symmetric.
    Ay,
    /// Anti-Hermitian.
    Ah,
    /// No symmetry assumed.
    N,
}
216
217pub type TensorOrder = FlagOrder;
218pub type TensorDiag = FlagDiag;
219pub type TensorSide = FlagSide;
220pub type TensorUpLo = FlagUpLo;
221pub type TensorTrans = FlagTrans;
222pub type TensorSymm = FlagSymm;
223
224/* #endregion */
225
226/* #region flag alias */
227
228pub use FlagTrans::C as ConjTrans;
229pub use FlagTrans::N as NoTrans;
230pub use FlagTrans::T as Trans;
231
232pub use FlagSide::L as Left;
233pub use FlagSide::R as Right;
234
235pub use FlagUpLo::L as Lower;
236pub use FlagUpLo::U as Upper;
237
238pub use FlagDiag::N as NonUnit;
239pub use FlagDiag::U as Unit;
240
241pub use FlagOrder::C as RowMajor;
242pub use FlagOrder::F as ColMajor;
243
244/* #endregion */
245
246/* #region flag into */
247
248impl From<char> for FlagTrans {
249    fn from(val: char) -> Self {
250        match val {
251            'N' | 'n' => FlagTrans::N,
252            'T' | 't' => FlagTrans::T,
253            'C' | 'c' => FlagTrans::C,
254            _ => rstsr_invalid!(val).unwrap(),
255        }
256    }
257}
258
259impl From<FlagTrans> for char {
260    fn from(val: FlagTrans) -> Self {
261        match val {
262            FlagTrans::N => 'N',
263            FlagTrans::T => 'T',
264            FlagTrans::C => 'C',
265            _ => rstsr_invalid!(val).unwrap(),
266        }
267    }
268}
269
270impl From<FlagTrans> for c_char {
271    fn from(val: FlagTrans) -> Self {
272        match val {
273            FlagTrans::N => b'N' as c_char,
274            FlagTrans::T => b'T' as c_char,
275            FlagTrans::C => b'C' as c_char,
276            _ => rstsr_invalid!(val).unwrap(),
277        }
278    }
279}
280
281impl From<c_char> for FlagTrans {
282    fn from(val: c_char) -> Self {
283        match val as u8 {
284            b'N' => FlagTrans::N,
285            b'T' => FlagTrans::T,
286            b'C' => FlagTrans::C,
287            _ => rstsr_invalid!(val).unwrap(),
288        }
289    }
290}
291
292impl From<CBLAS_TRANSPOSE> for FlagTrans {
293    fn from(val: CBLAS_TRANSPOSE) -> Self {
294        match val {
295            CBLAS_TRANSPOSE::CblasNoTrans => FlagTrans::N,
296            CBLAS_TRANSPOSE::CblasTrans => FlagTrans::T,
297            CBLAS_TRANSPOSE::CblasConjTrans => FlagTrans::C,
298        }
299    }
300}
301
302impl From<FlagTrans> for CBLAS_TRANSPOSE {
303    fn from(val: FlagTrans) -> Self {
304        match val {
305            FlagTrans::N => CBLAS_TRANSPOSE::CblasNoTrans,
306            FlagTrans::T => CBLAS_TRANSPOSE::CblasTrans,
307            FlagTrans::C => CBLAS_TRANSPOSE::CblasConjTrans,
308            _ => rstsr_invalid!(val).unwrap(),
309        }
310    }
311}
312
313impl From<char> for FlagDiag {
314    fn from(val: char) -> Self {
315        match val {
316            'N' | 'n' => FlagDiag::N,
317            'U' | 'u' => FlagDiag::U,
318            _ => rstsr_invalid!(val).unwrap(),
319        }
320    }
321}
322
323impl From<FlagDiag> for char {
324    fn from(val: FlagDiag) -> Self {
325        match val {
326            FlagDiag::N => 'N',
327            FlagDiag::U => 'U',
328        }
329    }
330}
331
332impl From<FlagDiag> for c_char {
333    fn from(val: FlagDiag) -> Self {
334        match val {
335            FlagDiag::N => b'N' as c_char,
336            FlagDiag::U => b'U' as c_char,
337        }
338    }
339}
340
341impl From<c_char> for FlagDiag {
342    fn from(val: c_char) -> Self {
343        match val as u8 {
344            b'N' => FlagDiag::N,
345            b'U' => FlagDiag::U,
346            _ => rstsr_invalid!(val).unwrap(),
347        }
348    }
349}
350
351impl From<CBLAS_DIAG> for FlagDiag {
352    fn from(val: CBLAS_DIAG) -> Self {
353        match val {
354            CBLAS_DIAG::CblasNonUnit => FlagDiag::N,
355            CBLAS_DIAG::CblasUnit => FlagDiag::U,
356        }
357    }
358}
359
360impl From<FlagDiag> for CBLAS_DIAG {
361    fn from(val: FlagDiag) -> Self {
362        match val {
363            FlagDiag::N => CBLAS_DIAG::CblasNonUnit,
364            FlagDiag::U => CBLAS_DIAG::CblasUnit,
365        }
366    }
367}
368
369impl From<char> for FlagSide {
370    fn from(val: char) -> Self {
371        match val {
372            'L' | 'l' => FlagSide::L,
373            'R' | 'r' => FlagSide::R,
374            _ => rstsr_invalid!(val).unwrap(),
375        }
376    }
377}
378
379impl From<FlagSide> for char {
380    fn from(val: FlagSide) -> Self {
381        match val {
382            FlagSide::L => 'L',
383            FlagSide::R => 'R',
384        }
385    }
386}
387
388impl From<FlagSide> for c_char {
389    fn from(val: FlagSide) -> Self {
390        match val {
391            FlagSide::L => b'L' as c_char,
392            FlagSide::R => b'R' as c_char,
393        }
394    }
395}
396
397impl From<c_char> for FlagSide {
398    fn from(val: c_char) -> Self {
399        match val as u8 {
400            b'L' => FlagSide::L,
401            b'R' => FlagSide::R,
402            _ => rstsr_invalid!(val).unwrap(),
403        }
404    }
405}
406
407impl From<CBLAS_SIDE> for FlagSide {
408    fn from(val: CBLAS_SIDE) -> Self {
409        match val {
410            CBLAS_SIDE::CblasLeft => FlagSide::L,
411            CBLAS_SIDE::CblasRight => FlagSide::R,
412        }
413    }
414}
415
416impl From<FlagSide> for CBLAS_SIDE {
417    fn from(val: FlagSide) -> Self {
418        match val {
419            FlagSide::L => CBLAS_SIDE::CblasLeft,
420            FlagSide::R => CBLAS_SIDE::CblasRight,
421        }
422    }
423}
424
425impl From<char> for FlagUpLo {
426    fn from(val: char) -> Self {
427        match val {
428            'U' | 'u' => FlagUpLo::U,
429            'L' | 'l' => FlagUpLo::L,
430            _ => rstsr_invalid!(val).unwrap(),
431        }
432    }
433}
434
435impl From<FlagUpLo> for char {
436    fn from(val: FlagUpLo) -> Self {
437        match val {
438            FlagUpLo::U => 'U',
439            FlagUpLo::L => 'L',
440        }
441    }
442}
443
444impl From<FlagUpLo> for c_char {
445    fn from(val: FlagUpLo) -> Self {
446        match val {
447            FlagUpLo::U => b'U' as c_char,
448            FlagUpLo::L => b'L' as c_char,
449        }
450    }
451}
452
453impl From<c_char> for FlagUpLo {
454    fn from(val: c_char) -> Self {
455        match val as u8 {
456            b'U' => FlagUpLo::U,
457            b'L' => FlagUpLo::L,
458            _ => rstsr_invalid!(val).unwrap(),
459        }
460    }
461}
462
463impl From<CBLAS_UPLO> for FlagUpLo {
464    fn from(val: CBLAS_UPLO) -> Self {
465        match val {
466            CBLAS_UPLO::CblasUpper => FlagUpLo::U,
467            CBLAS_UPLO::CblasLower => FlagUpLo::L,
468        }
469    }
470}
471
472impl From<FlagUpLo> for CBLAS_UPLO {
473    fn from(val: FlagUpLo) -> Self {
474        match val {
475            FlagUpLo::U => CBLAS_UPLO::CblasUpper,
476            FlagUpLo::L => CBLAS_UPLO::CblasLower,
477        }
478    }
479}
480
481impl From<CBLAS_LAYOUT> for FlagOrder {
482    fn from(val: CBLAS_LAYOUT) -> Self {
483        match val {
484            CBLAS_LAYOUT::CblasRowMajor => FlagOrder::C,
485            CBLAS_LAYOUT::CblasColMajor => FlagOrder::F,
486        }
487    }
488}
489
490impl From<FlagOrder> for CBLAS_LAYOUT {
491    fn from(val: FlagOrder) -> Self {
492        match val {
493            FlagOrder::C => CBLAS_LAYOUT::CblasRowMajor,
494            FlagOrder::F => CBLAS_LAYOUT::CblasColMajor,
495        }
496    }
497}
498
499/* #endregion */
500
501/* #region flag flip */
502
503impl FlagOrder {
504    pub fn flip(&self) -> Self {
505        match self {
506            FlagOrder::C => FlagOrder::F,
507            FlagOrder::F => FlagOrder::C,
508        }
509    }
510}
511
512impl FlagTrans {
513    pub fn flip(&self, hermi: bool) -> Result<Self> {
514        match self {
515            FlagTrans::N => match hermi {
516                true => Ok(FlagTrans::C),
517                false => Ok(FlagTrans::T),
518            },
519            FlagTrans::T => Ok(FlagTrans::N),
520            FlagTrans::C => Ok(FlagTrans::N),
521            _ => rstsr_invalid!(self)?,
522        }
523    }
524}
525
526impl FlagSide {
527    pub fn flip(&self) -> Result<Self> {
528        match self {
529            FlagSide::L => Ok(FlagSide::R),
530            FlagSide::R => Ok(FlagSide::L),
531        }
532    }
533}
534
535impl FlagUpLo {
536    pub fn flip(&self) -> Result<Self> {
537        match self {
538            FlagUpLo::U => Ok(FlagUpLo::L),
539            FlagUpLo::L => Ok(FlagUpLo::U),
540        }
541    }
542}
543
544/* #endregion */