Skip to main content

wls_alloc/
setup.rs

1use nalgebra::{allocator::Allocator, Const, DefaultAllocator, DimMin, DimName};
2
3use crate::types::{MatA, VecN, MIN_DIAG_CLAMP};
4
5#[allow(clippy::needless_range_loop)] // 2D symmetric matrix access a2[i][j]
6fn gamma_estimator<const NV: usize>(a2: &[[f32; NV]; NV], cond_target: f32) -> (f32, f32) {
7    let mut max_sig: f32 = 0.0;
8    for i in 0..NV {
9        let mut r: f32 = 0.0;
10        for j in 0..NV {
11            if j != i {
12                r += libm::fabsf(a2[i][j]);
13            }
14        }
15        let disk = a2[i][i] + r;
16        if max_sig < disk {
17            max_sig = disk;
18        }
19    }
20    (libm::sqrtf(max_sig / cond_target), max_sig)
21}
22
23/// Convert WLS control allocation to a least-squares problem `min ||Au - b||`.
24///
25/// `wu` is **normalized in-place** by its minimum value (matching the C code).
26/// Returns `(A, gamma)`.
27#[allow(clippy::needless_range_loop)] // symmetric matrix fill uses a2[i][j] and a2[j][i]
28pub fn setup_a<const NU: usize, const NV: usize, const NC: usize>(
29    b_mat: &MatA<NV, NU>,
30    wv: &VecN<NV>,
31    wu: &mut VecN<NU>,
32    theta: f32,
33    cond_bound: f32,
34) -> (MatA<NC, NU>, f32)
35where
36    Const<NC>: DimName + DimMin<Const<NU>, Output = Const<NU>>,
37    Const<NU>: DimName,
38    Const<NV>: DimName,
39    DefaultAllocator: Allocator<Const<NC>, Const<NU>>
40        + Allocator<Const<NC>, Const<NC>>
41        + Allocator<Const<NU>, Const<NU>>
42        + Allocator<Const<NC>>
43        + Allocator<Const<NU>>
44        + Allocator<Const<NV>>,
45{
46    debug_assert_eq!(NC, NU + NV);
47
48    // Compute A2[i][j] — symmetric NV×NV Gershgorin scratch
49    let mut a2 = [[0.0f32; NV]; NV];
50    for i in 0..NV {
51        for j in i..NV {
52            let mut sum = 0.0f32;
53            for k in 0..NU {
54                sum += b_mat[(i, k)] * b_mat[(j, k)];
55            }
56            a2[i][j] = sum * wv[i] * wv[i];
57            if i != j {
58                a2[j][i] = a2[i][j];
59            }
60        }
61    }
62
63    // Normalise Wu
64    let mut min_diag: f32 = f32::INFINITY;
65    let mut max_diag: f32 = 0.0;
66    for i in 0..NU {
67        if wu[i] < min_diag {
68            min_diag = wu[i];
69        }
70        if wu[i] > max_diag {
71            max_diag = wu[i];
72        }
73    }
74    if min_diag < MIN_DIAG_CLAMP {
75        min_diag = MIN_DIAG_CLAMP;
76    }
77    let inv = 1.0 / min_diag;
78    for i in 0..NU {
79        wu[i] *= inv;
80    }
81    max_diag *= inv;
82
83    // Compute gamma
84    let gamma = if cond_bound > 0.0 {
85        let (ge, ms) = gamma_estimator(&a2, cond_bound);
86        let gt = libm::sqrtf(ms) * theta / max_diag;
87        if ge > gt {
88            ge
89        } else {
90            gt
91        }
92    } else {
93        let (_, ms) = gamma_estimator(&a2, 1.0);
94        libm::sqrtf(ms) * theta / max_diag
95    };
96
97    // Build A via nalgebra
98    let mut a: MatA<NC, NU> = MatA::zeros();
99    for j in 0..NU {
100        for i in 0..NV {
101            a[(i, j)] = wv[i] * b_mat[(i, j)];
102        }
103        a[(NV + j, j)] = gamma * wu[j];
104    }
105
106    (a, gamma)
107}
108
109/// Compute the right-hand side `b` for the LS problem.
110pub fn setup_b<const NU: usize, const NV: usize, const NC: usize>(
111    v: &VecN<NV>,
112    ud: &VecN<NU>,
113    wv: &VecN<NV>,
114    wu_norm: &VecN<NU>,
115    gamma: f32,
116) -> VecN<NC>
117where
118    Const<NC>: DimName + DimMin<Const<NU>, Output = Const<NU>>,
119    Const<NU>: DimName,
120    Const<NV>: DimName,
121    DefaultAllocator: Allocator<Const<NC>, Const<NU>>
122        + Allocator<Const<NC>, Const<NC>>
123        + Allocator<Const<NU>, Const<NU>>
124        + Allocator<Const<NC>>
125        + Allocator<Const<NU>>
126        + Allocator<Const<NV>>,
127{
128    debug_assert_eq!(NC, NU + NV);
129    let mut b: VecN<NC> = VecN::zeros();
130    for i in 0..NV {
131        b[i] = wv[i] * v[i];
132    }
133    for i in 0..NU {
134        b[NV + i] = gamma * wu_norm[i] * ud[i];
135    }
136    b
137}