1use std::fmt;
4
5use nabled_core::scalar::NabledReal;
6use nabled_linalg::lu;
7use ndarray::{Array1, Array2};
8use num_complex::Complex64;
9
/// Floor applied to every caller-supplied tolerance: the solvers in this file use
/// `max(config.tolerance, DEFAULT_TOLERANCE)` as their effective threshold.
const DEFAULT_TOLERANCE: f64 = 1.0e-12;
11
/// Stopping criteria shared by the iterative solvers in this file.
///
/// `T` is the real scalar type of the tolerance and defaults to `f64`.
/// `PartialEq` is derived so configurations can be compared directly, matching
/// the comparison support already offered by [`IterativeError`].
#[derive(Debug, Clone, PartialEq)]
pub struct IterativeConfig<T = f64> {
    /// Absolute residual-norm threshold below which a solve counts as converged.
    pub tolerance: T,
    /// Upper bound on solver iterations (for GMRES: on the Krylov basis size).
    pub max_iterations: usize,
}
20
21impl IterativeConfig<f64> {
22 #[must_use]
24 pub const fn default_f64() -> Self { Self { tolerance: 1e-10, max_iterations: 1000 } }
25}
26
/// `Default` for double-precision configurations; forwards to
/// [`IterativeConfig::default_f64`].
impl Default for IterativeConfig<f64> {
    fn default() -> Self { Self::default_f64() }
}
30
31impl IterativeConfig<f32> {
32 #[must_use]
34 pub const fn default_f32() -> Self { Self { tolerance: 1e-6, max_iterations: 1000 } }
35}
36
/// `Default` for single-precision configurations; forwards to
/// [`IterativeConfig::default_f32`].
impl Default for IterativeConfig<f32> {
    fn default() -> Self { Self::default_f32() }
}
40
/// Failure modes reported by the iterative solvers in this file.
///
/// The enum is field-less, so it derives `Eq` and `Hash` alongside `PartialEq`;
/// this lets callers use it as a set/map key (for example to tally failures).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum IterativeError {
    /// The matrix or the right-hand side has no elements.
    EmptyMatrix,
    /// The matrix is not square, or its size differs from the right-hand side length.
    DimensionMismatch,
    /// The iteration budget ran out before the residual reached the tolerance.
    MaxIterationsExceeded,
    /// A conjugate-gradient search direction had curvature at or below the tolerance.
    NotPositiveDefinite,
    /// The algorithm could not continue (for example an inner linear solve failed).
    Breakdown,
}
55
56impl fmt::Display for IterativeError {
57 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
58 match self {
59 IterativeError::EmptyMatrix => write!(f, "Matrix is empty"),
60 IterativeError::DimensionMismatch => write!(f, "Dimension mismatch"),
61 IterativeError::MaxIterationsExceeded => write!(f, "Maximum iterations exceeded"),
62 IterativeError::NotPositiveDefinite => write!(f, "Matrix is not positive definite"),
63 IterativeError::Breakdown => write!(f, "Algorithm breakdown"),
64 }
65 }
66}
67
// Empty impl: `IterativeError` wraps no underlying error, so the trait's provided
// methods are used unchanged.
impl std::error::Error for IterativeError {}
69
70fn default_tolerance<T: NabledReal>() -> T {
71 T::from_f64(DEFAULT_TOLERANCE).unwrap_or_else(T::epsilon)
72}
73
74fn vector_norm<T: NabledReal>(vector: &Array1<T>) -> T {
75 vector
76 .iter()
77 .map(|value| *value * *value)
78 .fold(T::zero(), |acc, value| acc + value)
79 .sqrt()
80}
81
82fn vector_norm_complex(vector: &Array1<Complex64>) -> f64 {
83 vector.iter().map(Complex64::norm_sqr).sum::<f64>().sqrt()
84}
85
/// Scalar bound used by `solve_linear` and `gmres_impl` when the `lapack-provider`
/// feature is enabled: a [`NabledReal`] that also supports `-=` and implements
/// `ndarray_linalg::Lapack`.
#[cfg(feature = "lapack-provider")]
trait IterativeLinearScalar: NabledReal + std::ops::SubAssign + ndarray_linalg::Lapack {}
88
// Blanket impl: every type meeting the supertrait bounds is an `IterativeLinearScalar`.
#[cfg(feature = "lapack-provider")]
impl<T: NabledReal + std::ops::SubAssign + ndarray_linalg::Lapack> IterativeLinearScalar for T {}
94
/// Scalar bound used by `solve_linear` and `gmres_impl` when the `lapack-provider`
/// feature is disabled: a [`NabledReal`] that also supports `-=`.
#[cfg(not(feature = "lapack-provider"))]
trait IterativeLinearScalar: NabledReal + std::ops::SubAssign {}
97
98#[cfg(not(feature = "lapack-provider"))]
99impl<T> IterativeLinearScalar for T where T: NabledReal + std::ops::SubAssign {}
100
101pub fn conjugate_gradient<T>(
106 matrix_a: &Array2<T>,
107 matrix_b: &Array1<T>,
108 config: &IterativeConfig<T>,
109) -> Result<Array1<T>, IterativeError>
110where
111 T: NabledReal,
112{
113 if matrix_a.is_empty() || matrix_b.is_empty() {
114 return Err(IterativeError::EmptyMatrix);
115 }
116 if matrix_a.nrows() != matrix_a.ncols() || matrix_a.nrows() != matrix_b.len() {
117 return Err(IterativeError::DimensionMismatch);
118 }
119
120 let n = matrix_b.len();
121 let mut x = Array1::<T>::zeros(n);
122 let mut r = matrix_b.clone();
123 let mut p = r.clone();
124 let mut rs_old = r.dot(&r);
125
126 let tolerance = config.tolerance.max(default_tolerance::<T>());
127 if rs_old.sqrt() <= tolerance {
128 return Ok(x);
129 }
130
131 for _ in 0..config.max_iterations {
132 let ap = matrix_a.dot(&p);
133 let curvature = p.dot(&ap);
134 if curvature <= tolerance {
135 return Err(IterativeError::NotPositiveDefinite);
136 }
137
138 let alpha = rs_old / curvature;
139 x = &x + &p.mapv(|value| alpha * value);
140 r = &r - &ap.mapv(|value| alpha * value);
141
142 let rs_new = r.dot(&r);
143 if rs_new.sqrt() <= tolerance {
144 return Ok(x);
145 }
146
147 let beta = rs_new / rs_old;
148 p = &r + &p.mapv(|value| beta * value);
149 rs_old = rs_new;
150 }
151
152 Err(IterativeError::MaxIterationsExceeded)
153}
154
155pub fn conjugate_gradient_complex(
160 matrix_a: &Array2<Complex64>,
161 matrix_b: &Array1<Complex64>,
162 config: &IterativeConfig<f64>,
163) -> Result<Array1<Complex64>, IterativeError> {
164 if matrix_a.is_empty() || matrix_b.is_empty() {
165 return Err(IterativeError::EmptyMatrix);
166 }
167 if matrix_a.nrows() != matrix_a.ncols() || matrix_a.nrows() != matrix_b.len() {
168 return Err(IterativeError::DimensionMismatch);
169 }
170
171 let n = matrix_b.len();
172 let mut x = Array1::<Complex64>::zeros(n);
173 let mut r = matrix_b.clone();
174 let mut p = r.clone();
175 let mut rs_old = r.iter().zip(r.iter()).map(|(lhs, rhs)| lhs.conj() * rhs).sum::<Complex64>();
176 let tolerance = config.tolerance.max(DEFAULT_TOLERANCE);
177
178 if rs_old.re.max(0.0).sqrt() <= tolerance {
179 return Ok(x);
180 }
181
182 for _ in 0..config.max_iterations {
183 let ap = matrix_a.dot(&p);
184 let curvature =
185 p.iter().zip(ap.iter()).map(|(lhs, rhs)| lhs.conj() * rhs).sum::<Complex64>();
186 if curvature.re <= tolerance || curvature.norm() <= tolerance {
187 return Err(IterativeError::NotPositiveDefinite);
188 }
189
190 let alpha = rs_old / curvature;
191 x = &x + &(alpha * &p);
192 r = &r - &(alpha * &ap);
193
194 let rs_new = r.iter().zip(r.iter()).map(|(lhs, rhs)| lhs.conj() * rhs).sum::<Complex64>();
195 if rs_new.re.max(0.0).sqrt() <= tolerance {
196 return Ok(x);
197 }
198
199 if rs_old.norm() <= tolerance {
200 return Err(IterativeError::Breakdown);
201 }
202 let beta = rs_new / rs_old;
203 p = &r + &(beta * &p);
204 rs_old = rs_new;
205 }
206
207 Err(IterativeError::MaxIterationsExceeded)
208}
209
210fn solve_linear<T>(matrix: &Array2<T>, rhs: &Array1<T>) -> Result<Array1<T>, IterativeError>
211where
212 T: IterativeLinearScalar,
213{
214 lu::solve(matrix, rhs).map_err(|_| IterativeError::Breakdown)
215}
216
/// Solves `A x = b` with a single, non-restarted GMRES cycle (build used when the
/// `lapack-provider` feature is enabled).
///
/// This is the public entry point; all work is delegated to `gmres_impl`. The
/// Krylov basis holds at most `min(n, max(config.max_iterations, 1))` vectors and
/// the effective tolerance is `max(config.tolerance, DEFAULT_TOLERANCE)`.
///
/// # Errors
///
/// * [`IterativeError::EmptyMatrix`] if `matrix_a` or `matrix_b` has no elements.
/// * [`IterativeError::DimensionMismatch`] if `matrix_a` is not square or its size
///   differs from `matrix_b.len()`.
/// * [`IterativeError::Breakdown`] if the projected least-squares system cannot be
///   solved.
/// * [`IterativeError::MaxIterationsExceeded`] if the residual norm of the computed
///   solution is still above the tolerance.
#[allow(clippy::many_single_char_names)]
#[cfg(feature = "lapack-provider")]
pub fn gmres<T>(
    matrix_a: &Array2<T>,
    matrix_b: &Array1<T>,
    config: &IterativeConfig<T>,
) -> Result<Array1<T>, IterativeError>
where
    T: NabledReal + std::ops::SubAssign + ndarray_linalg::Lapack,
{
    gmres_impl(matrix_a, matrix_b, config)
}
233
/// Solves `A x = b` with a single, non-restarted GMRES cycle (build used when the
/// `lapack-provider` feature is disabled).
///
/// This is the public entry point; all work is delegated to `gmres_impl`. The
/// Krylov basis holds at most `min(n, max(config.max_iterations, 1))` vectors and
/// the effective tolerance is `max(config.tolerance, DEFAULT_TOLERANCE)`.
///
/// # Errors
///
/// * [`IterativeError::EmptyMatrix`] if `matrix_a` or `matrix_b` has no elements.
/// * [`IterativeError::DimensionMismatch`] if `matrix_a` is not square or its size
///   differs from `matrix_b.len()`.
/// * [`IterativeError::Breakdown`] if the projected least-squares system cannot be
///   solved.
/// * [`IterativeError::MaxIterationsExceeded`] if the residual norm of the computed
///   solution is still above the tolerance.
#[allow(clippy::many_single_char_names)]
#[cfg(not(feature = "lapack-provider"))]
pub fn gmres<T>(
    matrix_a: &Array2<T>,
    matrix_b: &Array1<T>,
    config: &IterativeConfig<T>,
) -> Result<Array1<T>, IterativeError>
where
    T: NabledReal + std::ops::SubAssign,
{
    gmres_impl(matrix_a, matrix_b, config)
}
250
251#[allow(clippy::many_single_char_names)]
252fn gmres_impl<T>(
253 matrix_a: &Array2<T>,
254 matrix_b: &Array1<T>,
255 config: &IterativeConfig<T>,
256) -> Result<Array1<T>, IterativeError>
257where
258 T: IterativeLinearScalar,
259{
260 if matrix_a.is_empty() || matrix_b.is_empty() {
261 return Err(IterativeError::EmptyMatrix);
262 }
263 if matrix_a.nrows() != matrix_a.ncols() || matrix_a.nrows() != matrix_b.len() {
264 return Err(IterativeError::DimensionMismatch);
265 }
266
267 let n = matrix_b.len();
268 let m = n.min(config.max_iterations.max(1));
269 let mut basis = Array2::<T>::zeros((n, m + 1));
270 let mut hessenberg = Array2::<T>::zeros((m + 1, m));
271
272 let beta = vector_norm(matrix_b);
273 let tolerance = config.tolerance.max(default_tolerance::<T>());
274 if beta <= tolerance {
275 return Ok(Array1::<T>::zeros(n));
276 }
277
278 for row in 0..n {
279 basis[[row, 0]] = matrix_b[row] / beta;
280 }
281
282 let mut effective_m = m;
283 for j in 0..m {
284 let mut w = matrix_a.dot(&basis.column(j));
285
286 for i in 0..=j {
287 let vi = basis.column(i);
288 let hij = vi.dot(&w);
289 hessenberg[[i, j]] = hij;
290 for row in 0..n {
291 w[row] -= hij * basis[[row, i]];
292 }
293 }
294
295 let norm_w = vector_norm(&w);
296 hessenberg[[j + 1, j]] = norm_w;
297 if norm_w <= tolerance {
298 effective_m = j + 1;
299 break;
300 }
301 for row in 0..n {
302 basis[[row, j + 1]] = w[row] / norm_w;
303 }
304 }
305
306 let h = hessenberg.slice(ndarray::s![..(effective_m + 1), ..effective_m]);
307 let ht = h.t();
308 let normal_matrix = ht.dot(&h);
309
310 let mut rhs_ls = Array1::<T>::zeros(effective_m + 1);
311 rhs_ls[0] = beta;
312 let normal_rhs = ht.dot(&rhs_ls);
313
314 let y = solve_linear(&normal_matrix, &normal_rhs)?;
315 let x = basis.slice(ndarray::s![.., ..effective_m]).dot(&y);
316
317 let residual = matrix_b - &matrix_a.dot(&x);
318 if vector_norm(&residual) <= tolerance {
319 Ok(x)
320 } else {
321 Err(IterativeError::MaxIterationsExceeded)
322 }
323}
324
325#[allow(clippy::many_single_char_names)]
330pub fn gmres_complex(
331 matrix_a: &Array2<Complex64>,
332 matrix_b: &Array1<Complex64>,
333 config: &IterativeConfig<f64>,
334) -> Result<Array1<Complex64>, IterativeError> {
335 if matrix_a.is_empty() || matrix_b.is_empty() {
336 return Err(IterativeError::EmptyMatrix);
337 }
338 if matrix_a.nrows() != matrix_a.ncols() || matrix_a.nrows() != matrix_b.len() {
339 return Err(IterativeError::DimensionMismatch);
340 }
341
342 let n = matrix_b.len();
343 let m = n.min(config.max_iterations.max(1));
344 let mut basis = Array2::<Complex64>::zeros((n, m + 1));
345 let mut hessenberg = Array2::<Complex64>::zeros((m + 1, m));
346 let tolerance = config.tolerance.max(DEFAULT_TOLERANCE);
347
348 let beta = vector_norm_complex(matrix_b);
349 if beta <= tolerance {
350 return Ok(Array1::<Complex64>::zeros(n));
351 }
352
353 for row in 0..n {
354 basis[[row, 0]] = matrix_b[row] / beta;
355 }
356
357 let mut effective_m = m;
358 for j in 0..m {
359 let mut w = matrix_a.dot(&basis.column(j));
360
361 for i in 0..=j {
362 let vi = basis.column(i);
363 let hij = vi.iter().zip(w.iter()).map(|(lhs, rhs)| lhs.conj() * rhs).sum::<Complex64>();
364 hessenberg[[i, j]] = hij;
365 for row in 0..n {
366 w[row] -= hij * basis[[row, i]];
367 }
368 }
369
370 let norm_w = vector_norm_complex(&w);
371 hessenberg[[j + 1, j]] = Complex64::new(norm_w, 0.0);
372 if norm_w <= tolerance {
373 effective_m = j + 1;
374 break;
375 }
376 for row in 0..n {
377 basis[[row, j + 1]] = w[row] / norm_w;
378 }
379 }
380
381 let h = hessenberg.slice(ndarray::s![..(effective_m + 1), ..effective_m]);
382 let h_conj_t = h.mapv(|value| value.conj()).reversed_axes();
383 let normal_matrix = h_conj_t.dot(&h);
384
385 let mut rhs_ls = Array1::<Complex64>::zeros(effective_m + 1);
386 rhs_ls[0] = Complex64::new(beta, 0.0);
387 let normal_rhs = h_conj_t.dot(&rhs_ls);
388
389 let y =
390 lu::solve_complex(&normal_matrix, &normal_rhs).map_err(|_| IterativeError::Breakdown)?;
391 let x = basis.slice(ndarray::s![.., ..effective_m]).dot(&y);
392
393 let residual = matrix_b - &matrix_a.dot(&x);
394 if vector_norm_complex(&residual) <= tolerance {
395 Ok(x)
396 } else {
397 Err(IterativeError::MaxIterationsExceeded)
398 }
399}
400
#[cfg(test)]
mod tests {
    use ndarray::{Array1, Array2};
    use num_complex::Complex64;

    use super::*;

    /// CG reproduces the right-hand side of a 2x2 symmetric positive definite system.
    #[test]
    fn cg_solves_spd_system() {
        let rhs = Array1::from_vec(vec![1.0_f64, 2.0]);
        let matrix = Array2::from_shape_vec((2, 2), vec![4.0_f64, 1.0, 1.0, 3.0]).unwrap();
        let config = IterativeConfig::<f64>::default();

        let solution = conjugate_gradient(&matrix, &rhs, &config).unwrap();

        for (actual, expected) in matrix.dot(&solution).iter().zip(rhs.iter()) {
            assert!((actual - expected).abs() < 1e-8);
        }
    }

    /// GMRES reproduces the right-hand side of a small dense real system.
    #[test]
    fn gmres_solves_small_system() {
        let rhs = Array1::from_vec(vec![9.0_f64, 8.0]);
        let matrix = Array2::from_shape_vec((2, 2), vec![3.0_f64, 1.0, 1.0, 2.0]).unwrap();
        let config = IterativeConfig::<f64>::default();

        let solution = gmres(&matrix, &rhs, &config).unwrap();

        for (actual, expected) in matrix.dot(&solution).iter().zip(rhs.iter()) {
            assert!((actual - expected).abs() < 1e-8);
        }
    }

    /// Both real solvers also work in single precision.
    #[test]
    fn real_f32_solvers_work() {
        let rhs = Array1::from_vec(vec![1.0_f32, 2.0]);
        let matrix = Array2::from_shape_vec((2, 2), vec![4.0_f32, 1.0, 1.0, 3.0]).unwrap();
        let config = IterativeConfig::<f32>::default();

        let cg = conjugate_gradient(&matrix, &rhs, &config).unwrap();
        let gm = gmres(&matrix, &rhs, &config).unwrap();

        for solution in [cg, gm] {
            for (actual, expected) in matrix.dot(&solution).iter().zip(rhs.iter()) {
                assert!((actual - expected).abs() < 1e-4);
            }
        }
    }

    /// A right-hand side longer than the matrix dimension is rejected by CG.
    #[test]
    fn cg_rejects_dimension_mismatch() {
        let rhs = Array1::from_vec(vec![1.0_f64, 2.0, 3.0]);
        let matrix = Array2::<f64>::eye(2);

        let outcome = conjugate_gradient(&matrix, &rhs, &IterativeConfig::<f64>::default());

        assert_eq!(outcome.unwrap_err(), IterativeError::DimensionMismatch);
    }

    /// Empty inputs are rejected by GMRES.
    #[test]
    fn gmres_rejects_empty_input() {
        let rhs = Array1::<f64>::zeros(0);
        let matrix = Array2::<f64>::zeros((0, 0));

        let outcome = gmres(&matrix, &rhs, &IterativeConfig::<f64>::default());

        assert_eq!(outcome.unwrap_err(), IterativeError::EmptyMatrix);
    }

    /// A zero right-hand side yields the zero solution.
    #[test]
    fn cg_returns_zero_for_zero_rhs() {
        let rhs = Array1::from_vec(vec![0.0_f64, 0.0]);
        let matrix = Array2::<f64>::eye(2);

        let solution =
            conjugate_gradient(&matrix, &rhs, &IterativeConfig::<f64>::default()).unwrap();

        for value in solution.iter() {
            assert!(value.abs() < 1e-12_f64);
        }
    }

    /// Complex CG reproduces the right-hand side of a Hermitian positive definite system.
    #[test]
    fn cg_complex_solves_hermitian_spd_system() {
        let entries = vec![
            Complex64::new(4.0, 0.0),
            Complex64::new(1.0, 1.0),
            Complex64::new(1.0, -1.0),
            Complex64::new(3.0, 0.0),
        ];
        let matrix = Array2::from_shape_vec((2, 2), entries).unwrap();
        let rhs = Array1::from_vec(vec![Complex64::new(1.0, 0.5), Complex64::new(2.0, -1.0)]);

        let solution =
            conjugate_gradient_complex(&matrix, &rhs, &IterativeConfig::default()).unwrap();

        for (actual, expected) in matrix.dot(&solution).iter().zip(rhs.iter()) {
            assert!((actual - expected).norm() < 1e-7);
        }
    }

    /// Complex GMRES reproduces the right-hand side of a small non-Hermitian system.
    #[test]
    fn gmres_complex_solves_small_system() {
        let entries = vec![
            Complex64::new(3.0, 1.0),
            Complex64::new(1.0, -0.5),
            Complex64::new(0.5, 1.0),
            Complex64::new(2.0, -1.0),
        ];
        let matrix = Array2::from_shape_vec((2, 2), entries).unwrap();
        let rhs = Array1::from_vec(vec![Complex64::new(1.0, 2.0), Complex64::new(3.0, -1.0)]);

        let solution = gmres_complex(&matrix, &rhs, &IterativeConfig::default()).unwrap();

        for (actual, expected) in matrix.dot(&solution).iter().zip(rhs.iter()) {
            assert!((actual - expected).norm() < 1e-7);
        }
    }

    /// A right-hand side shorter than the matrix dimension is rejected by complex CG.
    #[test]
    fn cg_complex_rejects_dimension_mismatch() {
        let entries = vec![
            Complex64::new(1.0, 0.0),
            Complex64::new(0.0, 0.0),
            Complex64::new(0.0, 0.0),
            Complex64::new(1.0, 0.0),
        ];
        let matrix = Array2::from_shape_vec((2, 2), entries).unwrap();
        let rhs = Array1::from_vec(vec![Complex64::new(1.0, 0.0)]);

        let outcome = conjugate_gradient_complex(&matrix, &rhs, &IterativeConfig::default());

        assert_eq!(outcome.unwrap_err(), IterativeError::DimensionMismatch);
    }

    /// Empty inputs are rejected by complex GMRES.
    #[test]
    fn gmres_complex_rejects_empty_input() {
        let rhs = Array1::<Complex64>::zeros(0);
        let matrix = Array2::<Complex64>::zeros((0, 0));

        let outcome = gmres_complex(&matrix, &rhs, &IterativeConfig::default());

        assert_eq!(outcome.unwrap_err(), IterativeError::EmptyMatrix);
    }
}