faer/linalg/evd/schur/
mod.rs

1use super::*;
2
3pub(crate) mod complex_schur;
4pub(crate) mod real_schur;
5
6#[derive(Clone, Copy, Debug)]
7pub struct SchurParams {
8	/// function that returns the number of shifts to use for a given matrix size
9	pub recommended_shift_count: fn(matrix_dimension: usize, active_block_dimension: usize) -> usize,
10	/// function that returns the deflation window to use for a given matrix size
11	pub recommended_deflation_window: fn(matrix_dimension: usize, active_block_dimension: usize) -> usize,
12	/// threshold to switch between blocked and unblocked code
13	pub blocking_threshold: usize,
14	/// threshold of percent of aggressive-early-deflation window that must converge to skip a
15	/// sweep
16	pub nibble_threshold: usize,
17
18	#[doc(hidden)]
19	pub non_exhaustive: NonExhaustive,
20}
21
22impl<T: ComplexField> Auto<T> for SchurParams {
23	fn auto() -> Self {
24		Self {
25			recommended_shift_count: default_recommended_shift_count,
26			recommended_deflation_window: default_recommended_deflation_window,
27			blocking_threshold: 75,
28			nibble_threshold: 50,
29			non_exhaustive: NonExhaustive(()),
30		}
31	}
32}
33
34pub fn multishift_qr_scratch<T: ComplexField>(n: usize, nh: usize, want_t: bool, want_z: bool, parallelism: Par, params: SchurParams) -> StackReq {
35	let nsr = (params.recommended_shift_count)(n, nh);
36
37	let _ = want_t;
38	let _ = want_z;
39
40	if n <= 3 {
41		return StackReq::EMPTY;
42	}
43
44	let nw_max = (n - 3) / 3;
45
46	StackReq::any_of(&[
47		hessenberg::hessenberg_in_place_scratch::<T>(nw_max, 1, parallelism, Default::default()),
48		linalg::householder::apply_block_householder_sequence_transpose_on_the_left_in_place_scratch::<T>(nw_max, nw_max, nw_max),
49		linalg::householder::apply_block_householder_sequence_on_the_right_in_place_scratch::<T>(nw_max, nw_max, nw_max),
50		temp_mat_scratch::<T>(3, nsr),
51	])
52}
53
/// Default heuristic for the number of shifts used by the multishift QR iteration.
///
/// Only the full matrix dimension `dim` is consulted; the active-block dimension is ignored.
/// The result is a step function of `dim`, growing from 2 for tiny matrices to 256 for
/// matrices of dimension 6000 and above.
fn default_recommended_shift_count(dim: usize, _active_block_dim: usize) -> usize {
	match dim {
		0..=29 => 2,
		30..=59 => 4,
		60..=149 => 12,
		150..=589 => 32,
		590..=2999 => 64,
		3000..=5999 => 128,
		_ => 256,
	}
}
72
73fn default_recommended_deflation_window(dim: usize, _active_block_dim: usize) -> usize {
74	let n = dim;
75	if n < 30 {
76		2
77	} else if n < 60 {
78		4
79	} else if n < 150 {
80		10
81	} else if n < 590 {
82		#[cfg(feature = "std")]
83		{
84			(n as f64 / (n as f64).log2()) as usize
85		}
86		#[cfg(not(feature = "std"))]
87		{
88			libm::log2(n as f64 / (n as f64)) as usize
89		}
90	} else if n < 3000 {
91		96
92	} else if n < 6000 {
93		192
94	} else {
95		384
96	}
97}