faer/linalg/evd/schur/
mod.rs1use super::*;
2
3pub(crate) mod complex_schur;
4pub(crate) mod real_schur;
5
6#[derive(Clone, Copy, Debug)]
7pub struct SchurParams {
8 pub recommended_shift_count: fn(matrix_dimension: usize, active_block_dimension: usize) -> usize,
10 pub recommended_deflation_window: fn(matrix_dimension: usize, active_block_dimension: usize) -> usize,
12 pub blocking_threshold: usize,
14 pub nibble_threshold: usize,
17
18 #[doc(hidden)]
19 pub non_exhaustive: NonExhaustive,
20}
21
22impl<T: ComplexField> Auto<T> for SchurParams {
23 fn auto() -> Self {
24 Self {
25 recommended_shift_count: default_recommended_shift_count,
26 recommended_deflation_window: default_recommended_deflation_window,
27 blocking_threshold: 75,
28 nibble_threshold: 50,
29 non_exhaustive: NonExhaustive(()),
30 }
31 }
32}
33
34pub fn multishift_qr_scratch<T: ComplexField>(n: usize, nh: usize, want_t: bool, want_z: bool, parallelism: Par, params: SchurParams) -> StackReq {
35 let nsr = (params.recommended_shift_count)(n, nh);
36
37 let _ = want_t;
38 let _ = want_z;
39
40 if n <= 3 {
41 return StackReq::EMPTY;
42 }
43
44 let nw_max = (n - 3) / 3;
45
46 StackReq::any_of(&[
47 hessenberg::hessenberg_in_place_scratch::<T>(nw_max, 1, parallelism, Default::default()),
48 linalg::householder::apply_block_householder_sequence_transpose_on_the_left_in_place_scratch::<T>(nw_max, nw_max, nw_max),
49 linalg::householder::apply_block_householder_sequence_on_the_right_in_place_scratch::<T>(nw_max, nw_max, nw_max),
50 temp_mat_scratch::<T>(3, nsr),
51 ])
52}
53
/// Default value of `SchurParams::recommended_shift_count`.
///
/// The shift count is a step function of the full matrix dimension only;
/// the active-block dimension is part of the callback signature but ignored.
fn default_recommended_shift_count(dim: usize, _active_block_dim: usize) -> usize {
    match dim {
        0..=29 => 2,
        30..=59 => 4,
        60..=149 => 12,
        150..=589 => 32,
        590..=2999 => 64,
        3000..=5999 => 128,
        _ => 256,
    }
}
72
73fn default_recommended_deflation_window(dim: usize, _active_block_dim: usize) -> usize {
74 let n = dim;
75 if n < 30 {
76 2
77 } else if n < 60 {
78 4
79 } else if n < 150 {
80 10
81 } else if n < 590 {
82 #[cfg(feature = "std")]
83 {
84 (n as f64 / (n as f64).log2()) as usize
85 }
86 #[cfg(not(feature = "std"))]
87 {
88 libm::log2(n as f64 / (n as f64)) as usize
89 }
90 } else if n < 3000 {
91 96
92 } else if n < 6000 {
93 192
94 } else {
95 384
96 }
97}