1use crate::prelude_dev::*;
4use core::ffi::c_char;
5use rstsr_cblas_base::*;
6use serde::{Deserialize, Serialize};
7
/// A type whose default value is stored globally and can be replaced at runtime.
///
/// Implemented in this file through `impl_changeable_default!`, which also makes
/// `Default::default()` return whatever `get_default` currently yields.
pub trait ChangeableDefault {
    /// Replaces the value that subsequent calls to `get_default` return.
    ///
    /// # Safety
    ///
    /// The `impl_changeable_default!` implementations write a `static mut`.
    /// NOTE(review): callers presumably must guarantee that no other thread
    /// reads or writes the default concurrently — confirm the intended contract.
    unsafe fn change_default(val: Self);

    /// Returns the current default value.
    fn get_default() -> Self;
}
18
/// Implements [`ChangeableDefault`] and [`Default`] for `$struct`.
///
/// The current default lives in a `static mut` named `$val`, initialised to
/// `$default`. `Default::default()` simply forwards to
/// `ChangeableDefault::get_default`, so changing the stored value changes what
/// `default()` returns from then on. `$struct` must be `Copy`, because the static
/// is read by value.
macro_rules! impl_changeable_default {
    ($struct:ty, $val:ident, $default:expr) => {
        static mut $val: $struct = $default;

        impl ChangeableDefault for $struct {
            unsafe fn change_default(val: Self) {
                // SAFETY: the caller upholds `change_default`'s contract: no
                // concurrent access to `$val` while it is overwritten. The
                // explicit block keeps `unsafe_op_in_unsafe_fn` quiet.
                unsafe { $val = val };
            }

            fn get_default() -> Self {
                // SAFETY: `$val` is copied out, never borrowed. Soundness relies
                // on callers of the unsafe `change_default` not racing this read.
                unsafe { $val }
            }
        }

        impl Default for $struct {
            fn default() -> Self {
                <$struct as ChangeableDefault>::get_default()
            }
        }
    };
}
43
/// Memory layout order: row-major (`C`) or column-major (`F`).
///
/// Converts to and from `CBLAS_LAYOUT`. The discriminants 101/102 presumably
/// mirror `CblasRowMajor`/`CblasColMajor` in the reference `cblas.h`.
#[repr(u8)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum FlagOrder {
    /// Row-major layout; serialized as `"RowMajor"`.
    #[serde(rename = "RowMajor")]
    C = 101,
    /// Column-major layout; serialized as `"ColMajor"`.
    #[serde(rename = "ColMajor")]
    F = 102,
}
70
71#[allow(non_upper_case_globals)]
72impl FlagOrder {
73 pub const RowMajor: Self = FlagOrder::C;
74 pub const ColMajor: Self = FlagOrder::F;
75}
76
77#[allow(clippy::derivable_impls)]
78impl Default for FlagOrder {
79 fn default() -> Self {
80 if cfg!(feature = "col_major") {
81 return FlagOrder::F;
82 } else {
83 return FlagOrder::C;
84 }
85 }
86}
87
/// Element-iteration order for tensors.
///
/// The per-variant notes are inferred from the serialized names only. Confirm
/// them against the iterator implementations.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum TensorIterOrder {
    /// Serialized as `"RowMajor"`; presumably row-major traversal.
    #[serde(rename = "RowMajor")]
    C,
    /// Serialized as `"ColMajor"`; presumably column-major traversal.
    #[serde(rename = "ColMajor")]
    F,
    /// Serialized as `"Auto"`; presumably picks an order automatically.
    #[serde(rename = "Auto")]
    A,
    /// Serialized as `"Greedy"`. This is the initial process-wide default.
    #[serde(rename = "Greedy")]
    K,
    /// Serialized as `"GreedyInplace"`.
    #[serde(rename = "GreedyInplace")]
    G,
    /// Serialized as `"Sequential"`.
    #[serde(rename = "Sequential")]
    B,
}

// Backs `TensorIterOrder::default()` with a global that starts as `K` ("Greedy")
// and can be replaced at runtime via `ChangeableDefault::change_default`.
impl_changeable_default!(TensorIterOrder, DEFAULT_TENSOR_ITER_ORDER, TensorIterOrder::K);
137
/// Copy-policy flags, represented as plain `u8` constants.
///
/// NOTE(review): the meanings are inferred from the names only. Presumably
/// `COPY_NEEDED` copies only when required, `COPY_TRUE` always copies, and
/// `COPY_FALSE` never copies; confirm against the callers.
pub mod TensorCopyPolicy {
    // The module name is CamelCase, which `non_snake_case` would flag.
    #![allow(non_snake_case)]

    /// Integer type shared by every copy-policy flag.
    pub type FlagCopy = u8;

    /// Value `0`: copy only when needed (presumably).
    pub const COPY_NEEDED: FlagCopy = 0;
    /// Value `1`: always copy (presumably).
    pub const COPY_TRUE: FlagCopy = 1;
    /// Value `2`: never copy (presumably).
    pub const COPY_FALSE: FlagCopy = 2;

    /// The default policy, equal to [`COPY_NEEDED`].
    pub const COPY_DEFAULT: FlagCopy = COPY_NEEDED;
}
160
/// Transpose/conjugate operation applied to a matrix operand.
///
/// Discriminants 111–113 presumably mirror the reference `cblas.h` values.
/// `CN` has no `CBLAS_TRANSPOSE`, `char` or `c_char` counterpart in this file;
/// converting it hits the invalid-value branch.
#[repr(u8)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum FlagTrans {
    /// No transpose; serialized as `"NoTrans"`.
    #[serde(rename = "NoTrans")]
    N = 111,
    /// Transpose; serialized as `"Trans"`.
    #[serde(rename = "Trans")]
    T = 112,
    /// Conjugate transpose; serialized as `"ConjTrans"`.
    #[serde(rename = "ConjTrans")]
    C = 113,
    /// Conjugate without transpose; serialized as `"Conj"`.
    #[serde(rename = "Conj")]
    CN = 114,
}
181
/// Side on which a matrix operand is applied (maps to `CBLAS_SIDE`).
#[repr(u8)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum FlagSide {
    /// Left side; serialized as `"Left"`.
    #[serde(rename = "Left")]
    L = 141,
    /// Right side; serialized as `"Right"`.
    #[serde(rename = "Right")]
    R = 142,
}
192
/// Upper or lower triangle selector (maps to `CBLAS_UPLO`).
#[repr(u8)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum FlagUpLo {
    /// Upper triangle; serialized as `"Upper"`.
    #[serde(rename = "Upper")]
    U = 121,
    /// Lower triangle; serialized as `"Lower"`.
    #[serde(rename = "Lower")]
    L = 122,
}
203
/// Whether a triangular matrix has a unit diagonal (maps to `CBLAS_DIAG`).
#[repr(u8)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum FlagDiag {
    /// Non-unit diagonal; serialized as `"NonUnit"`.
    #[serde(rename = "NonUnit")]
    N = 131,
    /// Unit diagonal; serialized as `"Unit"`.
    #[serde(rename = "Unit")]
    U = 132,
}
214
/// Symmetry class of a matrix. It has no CBLAS or character conversions here.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum FlagSymm {
    /// Symmetric; serialized as `"Symmetric"`.
    #[serde(rename = "Symmetric")]
    Sy,
    /// Hermitian; serialized as `"Hermitian"`.
    #[serde(rename = "Hermitian")]
    He,
    /// Anti-symmetric; serialized as `"AntiSymmetric"`.
    #[serde(rename = "AntiSymmetric")]
    Ay,
    /// Anti-Hermitian; serialized as `"AntiHermitian"`.
    #[serde(rename = "AntiHermitian")]
    Ah,
    /// No symmetry; serialized as `"NonSymmetric"`.
    #[serde(rename = "NonSymmetric")]
    N,
}
237
/// Alias of [`FlagOrder`] under the `Tensor*` naming scheme.
pub type TensorOrder = FlagOrder;
/// Alias of [`FlagDiag`] under the `Tensor*` naming scheme.
pub type TensorDiag = FlagDiag;
/// Alias of [`FlagSide`] under the `Tensor*` naming scheme.
pub type TensorSide = FlagSide;
/// Alias of [`FlagUpLo`] under the `Tensor*` naming scheme.
pub type TensorUpLo = FlagUpLo;
/// Alias of [`FlagTrans`] under the `Tensor*` naming scheme.
pub type TensorTrans = FlagTrans;
/// Alias of [`FlagSymm`] under the `Tensor*` naming scheme.
pub type TensorSymm = FlagSymm;
244
245pub use FlagTrans::C as ConjTrans;
250pub use FlagTrans::N as NoTrans;
251pub use FlagTrans::T as Trans;
252
253pub use FlagSide::L as Left;
254pub use FlagSide::R as Right;
255
256pub use FlagUpLo::L as Lower;
257pub use FlagUpLo::U as Upper;
258
259pub use FlagDiag::N as NonUnit;
260pub use FlagDiag::U as Unit;
261
262pub use FlagOrder::C as RowMajor;
263pub use FlagOrder::F as ColMajor;
264
265impl From<char> for FlagTrans {
270 fn from(val: char) -> Self {
271 match val {
272 'N' | 'n' => FlagTrans::N,
273 'T' | 't' => FlagTrans::T,
274 'C' | 'c' => FlagTrans::C,
275 _ => rstsr_invalid!(val).unwrap(),
276 }
277 }
278}
279
280impl From<FlagTrans> for char {
281 fn from(val: FlagTrans) -> Self {
282 match val {
283 FlagTrans::N => 'N',
284 FlagTrans::T => 'T',
285 FlagTrans::C => 'C',
286 _ => rstsr_invalid!(val).unwrap(),
287 }
288 }
289}
290
291impl From<FlagTrans> for c_char {
292 fn from(val: FlagTrans) -> Self {
293 match val {
294 FlagTrans::N => b'N' as c_char,
295 FlagTrans::T => b'T' as c_char,
296 FlagTrans::C => b'C' as c_char,
297 _ => rstsr_invalid!(val).unwrap(),
298 }
299 }
300}
301
302impl From<c_char> for FlagTrans {
303 fn from(val: c_char) -> Self {
304 match val as u8 {
305 b'N' => FlagTrans::N,
306 b'T' => FlagTrans::T,
307 b'C' => FlagTrans::C,
308 _ => rstsr_invalid!(val).unwrap(),
309 }
310 }
311}
312
313impl From<CBLAS_TRANSPOSE> for FlagTrans {
314 fn from(val: CBLAS_TRANSPOSE) -> Self {
315 match val {
316 CBLAS_TRANSPOSE::CblasNoTrans => FlagTrans::N,
317 CBLAS_TRANSPOSE::CblasTrans => FlagTrans::T,
318 CBLAS_TRANSPOSE::CblasConjTrans => FlagTrans::C,
319 }
320 }
321}
322
323impl From<FlagTrans> for CBLAS_TRANSPOSE {
324 fn from(val: FlagTrans) -> Self {
325 match val {
326 FlagTrans::N => CBLAS_TRANSPOSE::CblasNoTrans,
327 FlagTrans::T => CBLAS_TRANSPOSE::CblasTrans,
328 FlagTrans::C => CBLAS_TRANSPOSE::CblasConjTrans,
329 _ => rstsr_invalid!(val).unwrap(),
330 }
331 }
332}
333
334impl From<char> for FlagDiag {
335 fn from(val: char) -> Self {
336 match val {
337 'N' | 'n' => FlagDiag::N,
338 'U' | 'u' => FlagDiag::U,
339 _ => rstsr_invalid!(val).unwrap(),
340 }
341 }
342}
343
344impl From<FlagDiag> for char {
345 fn from(val: FlagDiag) -> Self {
346 match val {
347 FlagDiag::N => 'N',
348 FlagDiag::U => 'U',
349 }
350 }
351}
352
353impl From<FlagDiag> for c_char {
354 fn from(val: FlagDiag) -> Self {
355 match val {
356 FlagDiag::N => b'N' as c_char,
357 FlagDiag::U => b'U' as c_char,
358 }
359 }
360}
361
362impl From<c_char> for FlagDiag {
363 fn from(val: c_char) -> Self {
364 match val as u8 {
365 b'N' => FlagDiag::N,
366 b'U' => FlagDiag::U,
367 _ => rstsr_invalid!(val).unwrap(),
368 }
369 }
370}
371
372impl From<CBLAS_DIAG> for FlagDiag {
373 fn from(val: CBLAS_DIAG) -> Self {
374 match val {
375 CBLAS_DIAG::CblasNonUnit => FlagDiag::N,
376 CBLAS_DIAG::CblasUnit => FlagDiag::U,
377 }
378 }
379}
380
381impl From<FlagDiag> for CBLAS_DIAG {
382 fn from(val: FlagDiag) -> Self {
383 match val {
384 FlagDiag::N => CBLAS_DIAG::CblasNonUnit,
385 FlagDiag::U => CBLAS_DIAG::CblasUnit,
386 }
387 }
388}
389
390impl From<char> for FlagSide {
391 fn from(val: char) -> Self {
392 match val {
393 'L' | 'l' => FlagSide::L,
394 'R' | 'r' => FlagSide::R,
395 _ => rstsr_invalid!(val).unwrap(),
396 }
397 }
398}
399
400impl From<FlagSide> for char {
401 fn from(val: FlagSide) -> Self {
402 match val {
403 FlagSide::L => 'L',
404 FlagSide::R => 'R',
405 }
406 }
407}
408
409impl From<FlagSide> for c_char {
410 fn from(val: FlagSide) -> Self {
411 match val {
412 FlagSide::L => b'L' as c_char,
413 FlagSide::R => b'R' as c_char,
414 }
415 }
416}
417
418impl From<c_char> for FlagSide {
419 fn from(val: c_char) -> Self {
420 match val as u8 {
421 b'L' => FlagSide::L,
422 b'R' => FlagSide::R,
423 _ => rstsr_invalid!(val).unwrap(),
424 }
425 }
426}
427
428impl From<CBLAS_SIDE> for FlagSide {
429 fn from(val: CBLAS_SIDE) -> Self {
430 match val {
431 CBLAS_SIDE::CblasLeft => FlagSide::L,
432 CBLAS_SIDE::CblasRight => FlagSide::R,
433 }
434 }
435}
436
437impl From<FlagSide> for CBLAS_SIDE {
438 fn from(val: FlagSide) -> Self {
439 match val {
440 FlagSide::L => CBLAS_SIDE::CblasLeft,
441 FlagSide::R => CBLAS_SIDE::CblasRight,
442 }
443 }
444}
445
446impl From<char> for FlagUpLo {
447 fn from(val: char) -> Self {
448 match val {
449 'U' | 'u' => FlagUpLo::U,
450 'L' | 'l' => FlagUpLo::L,
451 _ => rstsr_invalid!(val).unwrap(),
452 }
453 }
454}
455
456impl From<FlagUpLo> for char {
457 fn from(val: FlagUpLo) -> Self {
458 match val {
459 FlagUpLo::U => 'U',
460 FlagUpLo::L => 'L',
461 }
462 }
463}
464
465impl From<FlagUpLo> for c_char {
466 fn from(val: FlagUpLo) -> Self {
467 match val {
468 FlagUpLo::U => b'U' as c_char,
469 FlagUpLo::L => b'L' as c_char,
470 }
471 }
472}
473
474impl From<c_char> for FlagUpLo {
475 fn from(val: c_char) -> Self {
476 match val as u8 {
477 b'U' => FlagUpLo::U,
478 b'L' => FlagUpLo::L,
479 _ => rstsr_invalid!(val).unwrap(),
480 }
481 }
482}
483
484impl From<CBLAS_UPLO> for FlagUpLo {
485 fn from(val: CBLAS_UPLO) -> Self {
486 match val {
487 CBLAS_UPLO::CblasUpper => FlagUpLo::U,
488 CBLAS_UPLO::CblasLower => FlagUpLo::L,
489 }
490 }
491}
492
493impl From<FlagUpLo> for CBLAS_UPLO {
494 fn from(val: FlagUpLo) -> Self {
495 match val {
496 FlagUpLo::U => CBLAS_UPLO::CblasUpper,
497 FlagUpLo::L => CBLAS_UPLO::CblasLower,
498 }
499 }
500}
501
502impl From<CBLAS_LAYOUT> for FlagOrder {
503 fn from(val: CBLAS_LAYOUT) -> Self {
504 match val {
505 CBLAS_LAYOUT::CblasRowMajor => FlagOrder::C,
506 CBLAS_LAYOUT::CblasColMajor => FlagOrder::F,
507 }
508 }
509}
510
511impl From<FlagOrder> for CBLAS_LAYOUT {
512 fn from(val: FlagOrder) -> Self {
513 match val {
514 FlagOrder::C => CBLAS_LAYOUT::CblasRowMajor,
515 FlagOrder::F => CBLAS_LAYOUT::CblasColMajor,
516 }
517 }
518}
519
520impl FlagOrder {
525 pub fn flip(&self) -> Self {
526 match self {
527 FlagOrder::C => FlagOrder::F,
528 FlagOrder::F => FlagOrder::C,
529 }
530 }
531}
532
533impl FlagTrans {
534 pub fn flip(&self, hermi: bool) -> Result<Self> {
535 match self {
536 FlagTrans::N => match hermi {
537 true => Ok(FlagTrans::C),
538 false => Ok(FlagTrans::T),
539 },
540 FlagTrans::T => Ok(FlagTrans::N),
541 FlagTrans::C => Ok(FlagTrans::N),
542 _ => rstsr_invalid!(self)?,
543 }
544 }
545}
546
547impl FlagSide {
548 pub fn flip(&self) -> Self {
549 match self {
550 FlagSide::L => FlagSide::R,
551 FlagSide::R => FlagSide::L,
552 }
553 }
554}
555
556impl FlagUpLo {
557 pub fn flip(&self) -> Self {
558 match self {
559 FlagUpLo::U => FlagUpLo::L,
560 FlagUpLo::L => FlagUpLo::U,
561 }
562 }
563}
564
565