1use crate::error::SolverError;
14
/// Result of [`svd`]: the factorisation `A = U * diag(sigma) * Vᵀ` of an
/// `m x n` matrix `A`, with `diag(sigma)` padded to `m x n`.
///
/// All matrices are stored row-major.
#[derive(Debug, Clone, PartialEq)]
pub struct SvdResult {
    /// Left singular vectors: an `m x m` matrix whose column `j` pairs with `sigma[j]`.
    pub u: Vec<f32>,
    /// The `min(m, n)` singular values, in non-increasing order.
    pub sigma: Vec<f32>,
    /// Transposed right singular vectors: an `n x n` matrix whose row `j` pairs with `sigma[j]`.
    pub vt: Vec<f32>,
    /// Number of rows of the decomposed matrix.
    pub m: usize,
    /// Number of columns of the decomposed matrix.
    pub n: usize,
}
29
30#[allow(clippy::cast_precision_loss)]
41pub fn svd(a: &[f32], m: usize, n: usize) -> Result<SvdResult, SolverError> {
42 if a.len() != m * n {
43 return Err(SolverError::SvdDimensionMismatch { m, n });
44 }
45
46 let min_mn = m.min(n);
47 let mut work = a.to_vec();
48
49 let mut v = vec![0.0f32; n * n];
51 for i in 0..n {
52 v[i * n + i] = 1.0;
53 }
54
55 let tol = f32::EPSILON * (m as f32).sqrt();
57 jacobi_sweeps(&mut work, &mut v, m, n, tol);
58
59 let sigma = extract_singular_values(&work, m, n, min_mn);
61 let u = compute_u_matrix(&work, &sigma, m, n, min_mn);
62
63 Ok(assemble_sorted_result(&u, &v, &sigma, m, n, min_mn))
65}
66
67fn jacobi_sweeps(work: &mut [f32], v: &mut [f32], m: usize, n: usize, tol: f32) {
69 let max_sweeps = 100;
70 for _sweep in 0..max_sweeps {
71 let mut converged = true;
72 for p in 0..n {
73 for q in (p + 1)..n {
74 if apply_jacobi_rotation(work, v, m, n, p, q, tol) {
75 converged = false;
76 }
77 }
78 }
79 if converged {
80 break;
81 }
82 }
83}
84
85fn apply_jacobi_rotation(
88 work: &mut [f32],
89 v: &mut [f32],
90 m: usize,
91 n: usize,
92 p: usize,
93 q: usize,
94 tol: f32,
95) -> bool {
96 let (app, apq, aqq) = gram_elements(work, m, n, p, q);
97
98 let scale = (app * aqq).sqrt();
99 if apq.abs() < f64::from(tol) * scale || scale < f64::EPSILON {
100 return false;
101 }
102
103 let (c, s) = jacobi_rotation_angle(app, apq, aqq);
104 rotate_columns(work, m, n, p, q, c, s);
105 rotate_columns(v, n, n, p, q, c, s);
106 true
107}
108
109fn gram_elements(work: &[f32], m: usize, n: usize, p: usize, q: usize) -> (f64, f64, f64) {
111 let mut app = 0.0f64;
112 let mut apq = 0.0f64;
113 let mut aqq = 0.0f64;
114 for i in 0..m {
115 let wp = f64::from(work[i * n + p]);
116 let wq = f64::from(work[i * n + q]);
117 app += wp * wp;
118 apq += wp * wq;
119 aqq += wq * wq;
120 }
121 (app, apq, aqq)
122}
123
/// Returns `(c, s)`, the cosine and sine of the Jacobi rotation that makes two
/// columns orthogonal. The inputs are their Gram entries `app = <a_p, a_p>`,
/// `apq = <a_p, a_q>` and `aqq = <a_q, a_q>`.
///
/// `t = tan(theta)` is the smaller-magnitude root of `t^2 + 2*tau*t - 1 = 0`,
/// where `tau = (aqq - app) / (2*apq)`, so `|t| <= 1`. `apq` is expected to be
/// non-zero; otherwise `tau` is infinite or NaN.
fn jacobi_rotation_angle(app: f64, apq: f64, aqq: f64) -> (f64, f64) {
    let tau = (aqq - app) / (2.0 * apq);
    // The reciprocal form avoids cancellation. `tau == -0.0` satisfies
    // `tau >= 0.0`, so it takes the positive sign.
    let sign = if tau >= 0.0 { 1.0 } else { -1.0 };
    let t = sign / (tau.abs() + (1.0 + tau * tau).sqrt());
    let cos = 1.0 / (1.0 + t * t).sqrt();
    (cos, t * cos)
}
136
/// Replaces columns `p` and `q` of the row-major `rows x cols` matrix `mat`
/// with `(c * col_p - s * col_q, s * col_p + c * col_q)`. The arithmetic is
/// done in `f64` and the result is stored back as `f32`.
fn rotate_columns(mat: &mut [f32], rows: usize, cols: usize, p: usize, q: usize, c: f64, s: f64) {
    for row_start in (0..rows).map(|r| r * cols) {
        let (ip, iq) = (row_start + p, row_start + q);
        let (x, y) = (f64::from(mat[ip]), f64::from(mat[iq]));
        mat[ip] = (c * x - s * y) as f32;
        mat[iq] = (s * x + c * y) as f32;
    }
}
146
/// Returns the Euclidean norms of the first `min_mn` columns of the row-major
/// `m x n` matrix `work`. Each norm is accumulated in `f64` and rounded to
/// `f32`.
fn extract_singular_values(work: &[f32], m: usize, n: usize, min_mn: usize) -> Vec<f32> {
    (0..min_mn)
        .map(|col| {
            let sum_sq = (0..m)
                .map(|row| f64::from(work[row * n + col]))
                .fold(0.0f64, |acc, x| acc + x * x);
            sum_sq.sqrt() as f32
        })
        .collect()
}
160
/// Builds the left singular vectors `U` as a row-major `m x m` matrix.
///
/// For every `j < min_mn` whose singular value is numerically non-zero,
/// column `j` of `U` is column `j` of the orthogonalised `work` matrix
/// (row-major `m x n`) divided by `sigma[j]`. "Numerically non-zero" is
/// judged relative to the largest singular value:
/// `sigma[j] > eps * max(m, n) * sigma_max`. A relative test means a
/// uniformly small matrix is not mistaken for a zero matrix.
///
/// Every other column is filled by orthonormal completion. That covers
/// columns with negligible singular values and, when `m > min_mn`, the extra
/// `m - min_mn` columns. This keeps `U` orthogonal for rank-deficient and
/// tall inputs.
#[allow(clippy::cast_precision_loss)]
fn compute_u_matrix(work: &[f32], sigma: &[f32], m: usize, n: usize, min_mn: usize) -> Vec<f32> {
    let mut u = vec![0.0f32; m * m];
    let mut filled = vec![false; m];

    let sigma_max = sigma[..min_mn].iter().copied().fold(0.0f32, f32::max);
    let cutoff = f32::EPSILON * (m.max(n) as f32) * sigma_max;

    for j in 0..min_mn {
        if sigma[j] > cutoff {
            let inv_sigma = 1.0 / sigma[j];
            for i in 0..m {
                u[i * m + j] = work[i * n + j] * inv_sigma;
            }
            filled[j] = true;
        }
    }

    complete_orthonormal_columns(&mut u, m, &mut filled);
    u
}

/// Fills every column `j` of the row-major `m x m` matrix `u` that has
/// `filled[j] == false`, so that the columns of `u` form an orthonormal set.
/// Each column it fills is marked in `filled`.
///
/// The standard basis vectors `e_0, e_1, ...` are tried in order, each at most
/// once. A candidate is projected off the columns already present, using two
/// Gram–Schmidt passes for numerical stability. It is accepted when its
/// squared residual norm exceeds `1 / (2m)`.
///
/// In exact arithmetic this always succeeds: the squared residuals of all `m`
/// basis vectors sum to the number of missing columns, so a missing column
/// cannot survive a full scan. If rounding exhausts the candidates anyway,
/// the remaining columns are left as zeros.
#[allow(clippy::cast_precision_loss)]
fn complete_orthonormal_columns(u: &mut [f32], m: usize, filled: &mut [bool]) {
    if filled.iter().all(|&f| f) {
        return;
    }
    let accept_sq = 0.5 / (m as f64);
    let mut candidates = 0..m;
    let mut residual = vec![0.0f64; m];

    for j in 0..m {
        if filled[j] {
            continue;
        }
        for k in candidates.by_ref() {
            residual.fill(0.0);
            residual[k] = 1.0;
            for _pass in 0..2 {
                for c in (0..m).filter(|&c| filled[c]) {
                    let dot: f64 = (0..m)
                        .map(|i| f64::from(u[i * m + c]) * residual[i])
                        .sum();
                    for (i, r) in residual.iter_mut().enumerate() {
                        *r -= dot * f64::from(u[i * m + c]);
                    }
                }
            }
            let norm_sq: f64 = residual.iter().map(|r| r * r).sum();
            if norm_sq > accept_sq {
                let inv_norm = 1.0 / norm_sq.sqrt();
                for (i, r) in residual.iter().enumerate() {
                    u[i * m + j] = (r * inv_norm) as f32;
                }
                filled[j] = true;
                break;
            }
        }
    }
}
177
178#[allow(clippy::cast_precision_loss)]
180fn assemble_sorted_result(
181 u: &[f32],
182 v: &[f32],
183 sigma: &[f32],
184 m: usize,
185 n: usize,
186 min_mn: usize,
187) -> SvdResult {
188 let mut indices: Vec<usize> = (0..min_mn).collect();
189 indices.sort_by(|&a, &b| {
190 sigma[b]
191 .partial_cmp(&sigma[a])
192 .unwrap_or(std::cmp::Ordering::Equal)
193 });
194
195 let mut sigma_sorted = vec![0.0f32; min_mn];
196 let mut u_sorted = vec![0.0f32; m * m];
197 let mut v_sorted = vec![0.0f32; n * n];
198
199 for (new_j, &old_j) in indices.iter().enumerate() {
200 sigma_sorted[new_j] = sigma[old_j];
201 for i in 0..m {
202 u_sorted[i * m + new_j] = u[i * m + old_j];
203 }
204 for i in 0..n {
205 v_sorted[i * n + new_j] = v[i * n + old_j];
206 }
207 }
208
209 for j in min_mn..m {
210 u_sorted[j * m + j] = 1.0;
211 }
212 for j in min_mn..n {
213 v_sorted[j * n + j] = 1.0;
214 }
215
216 let mut vt = vec![0.0f32; n * n];
218 for i in 0..n {
219 for j in 0..n {
220 vt[i * n + j] = v_sorted[j * n + i];
221 }
222 }
223
224 SvdResult {
225 u: u_sorted,
226 sigma: sigma_sorted,
227 vt,
228 m,
229 n,
230 }
231}