Skip to main content

rstsr_common/
flags.rs

1//! Flags for the crate.
2
3use crate::prelude_dev::*;
4use core::ffi::c_char;
5use rstsr_cblas_base::*;
6use serde::{Deserialize, Serialize};
7
8/* #region changeable default */
9
/// A type whose default value can be replaced at runtime.
///
/// [`ChangeableDefault::get_default`] reads the value currently registered as
/// the default, and [`ChangeableDefault::change_default`] overwrites it.
pub trait ChangeableDefault {
    /// Returns the value currently registered as the default.
    fn get_default() -> Self;

    /// Replaces the value returned by [`ChangeableDefault::get_default`].
    ///
    /// # Safety
    ///
    /// This function changes a static mutable variable.
    /// Selecting the default through a cargo feature is preferable to calling
    /// this function.
    unsafe fn change_default(val: Self);
}
18
/// Implements [`ChangeableDefault`] and [`Default`] for `$struct`.
///
/// - `$struct`: the type receiving both implementations. Its value is read
///   out of the static by value, so the type has to be `Copy`.
/// - `$val`: name of the `static mut` generated to hold the current default.
/// - `$default`: initial value stored in that static.
macro_rules! impl_changeable_default {
    ($struct:ty, $val:ident, $default:expr) => {
        // Slot that holds the value handed out as the default.
        static mut $val: $struct = $default;

        impl ChangeableDefault for $struct {
            fn get_default() -> Self {
                // SAFETY: by-value read of the static. The only writer is the
                // `unsafe` `change_default`, whose caller is responsible for
                // not racing with this read.
                unsafe { $val }
            }

            unsafe fn change_default(val: Self) {
                $val = val;
            }
        }

        impl Default for $struct
        where
            Self: ChangeableDefault,
        {
            fn default() -> Self {
                <$struct as ChangeableDefault>::get_default()
            }
        }
    };
}
43
44/* #endregion */
45
46/* #region FlagOrder */
47
48/// The order of the tensor.
49///
50/// # Default
51///
52/// Default order depends on cargo feature `f_prefer`.
53/// If `f_prefer` is set, then [`FlagOrder::F`] is applied as default;
54/// otherwise [`FlagOrder::C`] is applied as default.
55///
56/// # IMPORTANT NOTE
57///
58/// F-prefer is not a stable feature currently! We develop only in C-prefer
59/// currently.
60#[repr(u8)]
61#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
62pub enum FlagOrder {
63    /// row-major order.
64    #[serde(rename = "RowMajor")]
65    C = 101,
66    /// column-major order.
67    #[serde(rename = "ColMajor")]
68    F = 102,
69}
70
71#[allow(non_upper_case_globals)]
72impl FlagOrder {
73    pub const RowMajor: Self = FlagOrder::C;
74    pub const ColMajor: Self = FlagOrder::F;
75}
76
77#[allow(clippy::derivable_impls)]
78impl Default for FlagOrder {
79    fn default() -> Self {
80        if cfg!(feature = "col_major") {
81            return FlagOrder::F;
82        } else {
83            return FlagOrder::C;
84        }
85    }
86}
87
88/* #endregion */
89
90/* #region TensorIterOrder */
91
92/// The policy of the tensor iterator.
93#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
94pub enum TensorIterOrder {
95    /// Row-major order.
96    ///
97    /// - absolute safe for array iteration
98    #[serde(rename = "RowMajor")]
99    C,
100    /// Column-major order.
101    ///
102    /// - absolute safe for array iteration
103    #[serde(rename = "ColMajor")]
104    F,
105    /// Automatically choose row/col-major order.
106    ///
107    /// - try c/f-contig first (also see [`TensorIterOrder::B`]),
108    /// - try c/f-prefer second (also see [`TensorIterOrder::C`], [`TensorIterOrder::F`]),
109    /// - otherwise [`FlagOrder::default()`], which is defined by crate feature `f_prefer`.
110    ///
111    /// - safe for multi-array iteration like `get_iter(a, b)`
112    /// - not safe for cases like `a.iter().zip(b.iter())`
113    #[serde(rename = "Auto")]
114    A,
115    /// Greedy when possible (reorder layouts during iteration).
116    ///
117    /// - safe for multi-array iteration like `get_iter(a, b)`
118    /// - not safe for cases like `a.iter().zip(b.iter())`
119    /// - if it is used to create a new array, the stride of new array will be in K order
120    #[serde(rename = "Greedy")]
121    K,
122    /// Greedy when possible (reset dimension to 1 if axis is broadcasted).
123    ///
124    /// - not safe for multi-array iteration like `get_iter(a, b)`
125    /// - this is useful for inplace-assign broadcasted array.
126    #[serde(rename = "GreedyInplace")]
127    G,
128    /// Sequential buffer.
129    ///
130    /// - not safe for multi-array iteration like `get_iter(a, b)`
131    /// - this is useful for reshaping or all-contiguous cases.
132    #[serde(rename = "Sequential")]
133    B,
134}
135
136impl_changeable_default!(TensorIterOrder, DEFAULT_TENSOR_ITER_ORDER, TensorIterOrder::K);
137
138/* #endregion */
139
140/* #region TensorCopyPolicy */
141
/// Copy policies for tensors, expressed as `u8` constants.
pub mod TensorCopyPolicy {
    #![allow(non_snake_case)]

    // Workaround for stable Rust: enum values can not be used as const
    // generic parameters, so the policy is encoded as plain integers.

    /// Integer type that carries a copy policy.
    pub type FlagCopy = u8;

    /// Copy only when needed.
    pub const COPY_NEEDED: FlagCopy = 0;
    /// Always copy.
    pub const COPY_TRUE: FlagCopy = 1;
    /// Never copy; an error is emitted when a copy would be required.
    pub const COPY_FALSE: FlagCopy = 2;

    /// Policy applied by default, which is [`COPY_NEEDED`].
    pub const COPY_DEFAULT: FlagCopy = COPY_NEEDED;
}
160
161/* #endregion */
162
163/* #region blas-flags */
164
165#[repr(u8)]
166#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
167pub enum FlagTrans {
168    /// No transpose
169    #[serde(rename = "NoTrans")]
170    N = 111,
171    /// Transpose
172    #[serde(rename = "Trans")]
173    T = 112,
174    /// Conjugate transpose
175    #[serde(rename = "ConjTrans")]
176    C = 113,
177    // Conjuate only
178    #[serde(rename = "Conj")]
179    CN = 114,
180}
181
182#[repr(u8)]
183#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
184pub enum FlagSide {
185    /// Left side
186    #[serde(rename = "Left")]
187    L = 141,
188    /// Right side
189    #[serde(rename = "Right")]
190    R = 142,
191}
192
193#[repr(u8)]
194#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
195pub enum FlagUpLo {
196    /// Upper triangle
197    #[serde(rename = "Upper")]
198    U = 121,
199    /// Lower triangle
200    #[serde(rename = "Lower")]
201    L = 122,
202}
203
204#[repr(u8)]
205#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
206pub enum FlagDiag {
207    /// Non-unit diagonal
208    #[serde(rename = "NonUnit")]
209    N = 131,
210    /// Unit diagonal
211    #[serde(rename = "Unit")]
212    U = 132,
213}
214
215/* #endregion */
216
217/* #region symm-flags */
218
219#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
220pub enum FlagSymm {
221    /// Symmetric matrix
222    #[serde(rename = "Symmetric")]
223    Sy,
224    /// Hermitian matrix
225    #[serde(rename = "Hermitian")]
226    He,
227    /// Anti-symmetric matrix
228    #[serde(rename = "AntiSymmetric")]
229    Ay,
230    /// Anti-Hermitian matrix
231    #[serde(rename = "AntiHermitian")]
232    Ah,
233    /// Non-symmetric matrix
234    #[serde(rename = "NonSymmetric")]
235    N,
236}
237
238pub type TensorOrder = FlagOrder;
239pub type TensorDiag = FlagDiag;
240pub type TensorSide = FlagSide;
241pub type TensorUpLo = FlagUpLo;
242pub type TensorTrans = FlagTrans;
243pub type TensorSymm = FlagSymm;
244
245/* #endregion */
246
247/* #region flag alias */
248
249pub use FlagTrans::C as ConjTrans;
250pub use FlagTrans::N as NoTrans;
251pub use FlagTrans::T as Trans;
252
253pub use FlagSide::L as Left;
254pub use FlagSide::R as Right;
255
256pub use FlagUpLo::L as Lower;
257pub use FlagUpLo::U as Upper;
258
259pub use FlagDiag::N as NonUnit;
260pub use FlagDiag::U as Unit;
261
262pub use FlagOrder::C as RowMajor;
263pub use FlagOrder::F as ColMajor;
264
265/* #endregion */
266
267/* #region flag into */
268
269impl From<char> for FlagTrans {
270    fn from(val: char) -> Self {
271        match val {
272            'N' | 'n' => FlagTrans::N,
273            'T' | 't' => FlagTrans::T,
274            'C' | 'c' => FlagTrans::C,
275            _ => rstsr_invalid!(val).unwrap(),
276        }
277    }
278}
279
280impl From<FlagTrans> for char {
281    fn from(val: FlagTrans) -> Self {
282        match val {
283            FlagTrans::N => 'N',
284            FlagTrans::T => 'T',
285            FlagTrans::C => 'C',
286            _ => rstsr_invalid!(val).unwrap(),
287        }
288    }
289}
290
291impl From<FlagTrans> for c_char {
292    fn from(val: FlagTrans) -> Self {
293        match val {
294            FlagTrans::N => b'N' as c_char,
295            FlagTrans::T => b'T' as c_char,
296            FlagTrans::C => b'C' as c_char,
297            _ => rstsr_invalid!(val).unwrap(),
298        }
299    }
300}
301
302impl From<c_char> for FlagTrans {
303    fn from(val: c_char) -> Self {
304        match val as u8 {
305            b'N' => FlagTrans::N,
306            b'T' => FlagTrans::T,
307            b'C' => FlagTrans::C,
308            _ => rstsr_invalid!(val).unwrap(),
309        }
310    }
311}
312
313impl From<CBLAS_TRANSPOSE> for FlagTrans {
314    fn from(val: CBLAS_TRANSPOSE) -> Self {
315        match val {
316            CBLAS_TRANSPOSE::CblasNoTrans => FlagTrans::N,
317            CBLAS_TRANSPOSE::CblasTrans => FlagTrans::T,
318            CBLAS_TRANSPOSE::CblasConjTrans => FlagTrans::C,
319        }
320    }
321}
322
323impl From<FlagTrans> for CBLAS_TRANSPOSE {
324    fn from(val: FlagTrans) -> Self {
325        match val {
326            FlagTrans::N => CBLAS_TRANSPOSE::CblasNoTrans,
327            FlagTrans::T => CBLAS_TRANSPOSE::CblasTrans,
328            FlagTrans::C => CBLAS_TRANSPOSE::CblasConjTrans,
329            _ => rstsr_invalid!(val).unwrap(),
330        }
331    }
332}
333
334impl From<char> for FlagDiag {
335    fn from(val: char) -> Self {
336        match val {
337            'N' | 'n' => FlagDiag::N,
338            'U' | 'u' => FlagDiag::U,
339            _ => rstsr_invalid!(val).unwrap(),
340        }
341    }
342}
343
344impl From<FlagDiag> for char {
345    fn from(val: FlagDiag) -> Self {
346        match val {
347            FlagDiag::N => 'N',
348            FlagDiag::U => 'U',
349        }
350    }
351}
352
353impl From<FlagDiag> for c_char {
354    fn from(val: FlagDiag) -> Self {
355        match val {
356            FlagDiag::N => b'N' as c_char,
357            FlagDiag::U => b'U' as c_char,
358        }
359    }
360}
361
362impl From<c_char> for FlagDiag {
363    fn from(val: c_char) -> Self {
364        match val as u8 {
365            b'N' => FlagDiag::N,
366            b'U' => FlagDiag::U,
367            _ => rstsr_invalid!(val).unwrap(),
368        }
369    }
370}
371
372impl From<CBLAS_DIAG> for FlagDiag {
373    fn from(val: CBLAS_DIAG) -> Self {
374        match val {
375            CBLAS_DIAG::CblasNonUnit => FlagDiag::N,
376            CBLAS_DIAG::CblasUnit => FlagDiag::U,
377        }
378    }
379}
380
381impl From<FlagDiag> for CBLAS_DIAG {
382    fn from(val: FlagDiag) -> Self {
383        match val {
384            FlagDiag::N => CBLAS_DIAG::CblasNonUnit,
385            FlagDiag::U => CBLAS_DIAG::CblasUnit,
386        }
387    }
388}
389
390impl From<char> for FlagSide {
391    fn from(val: char) -> Self {
392        match val {
393            'L' | 'l' => FlagSide::L,
394            'R' | 'r' => FlagSide::R,
395            _ => rstsr_invalid!(val).unwrap(),
396        }
397    }
398}
399
400impl From<FlagSide> for char {
401    fn from(val: FlagSide) -> Self {
402        match val {
403            FlagSide::L => 'L',
404            FlagSide::R => 'R',
405        }
406    }
407}
408
409impl From<FlagSide> for c_char {
410    fn from(val: FlagSide) -> Self {
411        match val {
412            FlagSide::L => b'L' as c_char,
413            FlagSide::R => b'R' as c_char,
414        }
415    }
416}
417
418impl From<c_char> for FlagSide {
419    fn from(val: c_char) -> Self {
420        match val as u8 {
421            b'L' => FlagSide::L,
422            b'R' => FlagSide::R,
423            _ => rstsr_invalid!(val).unwrap(),
424        }
425    }
426}
427
428impl From<CBLAS_SIDE> for FlagSide {
429    fn from(val: CBLAS_SIDE) -> Self {
430        match val {
431            CBLAS_SIDE::CblasLeft => FlagSide::L,
432            CBLAS_SIDE::CblasRight => FlagSide::R,
433        }
434    }
435}
436
437impl From<FlagSide> for CBLAS_SIDE {
438    fn from(val: FlagSide) -> Self {
439        match val {
440            FlagSide::L => CBLAS_SIDE::CblasLeft,
441            FlagSide::R => CBLAS_SIDE::CblasRight,
442        }
443    }
444}
445
446impl From<char> for FlagUpLo {
447    fn from(val: char) -> Self {
448        match val {
449            'U' | 'u' => FlagUpLo::U,
450            'L' | 'l' => FlagUpLo::L,
451            _ => rstsr_invalid!(val).unwrap(),
452        }
453    }
454}
455
456impl From<FlagUpLo> for char {
457    fn from(val: FlagUpLo) -> Self {
458        match val {
459            FlagUpLo::U => 'U',
460            FlagUpLo::L => 'L',
461        }
462    }
463}
464
465impl From<FlagUpLo> for c_char {
466    fn from(val: FlagUpLo) -> Self {
467        match val {
468            FlagUpLo::U => b'U' as c_char,
469            FlagUpLo::L => b'L' as c_char,
470        }
471    }
472}
473
474impl From<c_char> for FlagUpLo {
475    fn from(val: c_char) -> Self {
476        match val as u8 {
477            b'U' => FlagUpLo::U,
478            b'L' => FlagUpLo::L,
479            _ => rstsr_invalid!(val).unwrap(),
480        }
481    }
482}
483
484impl From<CBLAS_UPLO> for FlagUpLo {
485    fn from(val: CBLAS_UPLO) -> Self {
486        match val {
487            CBLAS_UPLO::CblasUpper => FlagUpLo::U,
488            CBLAS_UPLO::CblasLower => FlagUpLo::L,
489        }
490    }
491}
492
493impl From<FlagUpLo> for CBLAS_UPLO {
494    fn from(val: FlagUpLo) -> Self {
495        match val {
496            FlagUpLo::U => CBLAS_UPLO::CblasUpper,
497            FlagUpLo::L => CBLAS_UPLO::CblasLower,
498        }
499    }
500}
501
502impl From<CBLAS_LAYOUT> for FlagOrder {
503    fn from(val: CBLAS_LAYOUT) -> Self {
504        match val {
505            CBLAS_LAYOUT::CblasRowMajor => FlagOrder::C,
506            CBLAS_LAYOUT::CblasColMajor => FlagOrder::F,
507        }
508    }
509}
510
511impl From<FlagOrder> for CBLAS_LAYOUT {
512    fn from(val: FlagOrder) -> Self {
513        match val {
514            FlagOrder::C => CBLAS_LAYOUT::CblasRowMajor,
515            FlagOrder::F => CBLAS_LAYOUT::CblasColMajor,
516        }
517    }
518}
519
520/* #endregion */
521
522/* #region flag flip */
523
524impl FlagOrder {
525    pub fn flip(&self) -> Self {
526        match self {
527            FlagOrder::C => FlagOrder::F,
528            FlagOrder::F => FlagOrder::C,
529        }
530    }
531}
532
533impl FlagTrans {
534    pub fn flip(&self, hermi: bool) -> Result<Self> {
535        match self {
536            FlagTrans::N => match hermi {
537                true => Ok(FlagTrans::C),
538                false => Ok(FlagTrans::T),
539            },
540            FlagTrans::T => Ok(FlagTrans::N),
541            FlagTrans::C => Ok(FlagTrans::N),
542            _ => rstsr_invalid!(self)?,
543        }
544    }
545}
546
547impl FlagSide {
548    pub fn flip(&self) -> Self {
549        match self {
550            FlagSide::L => FlagSide::R,
551            FlagSide::R => FlagSide::L,
552        }
553    }
554}
555
556impl FlagUpLo {
557    pub fn flip(&self) -> Self {
558        match self {
559            FlagUpLo::U => FlagUpLo::L,
560            FlagUpLo::L => FlagUpLo::U,
561        }
562    }
563}
564
565/* #endregion */