1#![allow(unused_variables)]
12#![allow(unused_assignments)]
13#![allow(unused_mut)]
14
15use crate::error::{SparseError, SparseResult};
16use crate::sparray::SparseArray;
17use scirs2_core::ndarray::{Array1, ArrayView1};
18use scirs2_core::numeric::{Float, One, SparseElement};
19use std::fmt::Debug;
20
21fn sym_ortho<T: Float + SparseElement>(a: T, b: T) -> (T, T, T) {
28 let zero = T::sparse_zero();
29 let one = <T as One>::one();
30
31 if b == zero {
32 return (if a >= zero { one } else { -one }, zero, a.abs());
33 } else if a == zero {
34 return (zero, if b >= zero { one } else { -one }, b.abs());
35 } else if b.abs() > a.abs() {
36 let tau = a / b;
37 let s_sign = if b >= zero { one } else { -one };
38 let s = s_sign / (one + tau * tau).sqrt();
39 let c = s * tau;
40 let r = b / s;
41 (c, s, r)
42 } else {
43 let tau = b / a;
44 let c_sign = if a >= zero { one } else { -one };
45 let c = c_sign / (one + tau * tau).sqrt();
46 let s = c * tau;
47 let r = a / c;
48 (c, s, r)
49 }
50}
51
/// Configuration for [`lsmr`].
///
/// Use `LSMROptions::default()` together with struct-update syntax to
/// override individual settings, e.g.
/// `LSMROptions { calc_var: true, ..Default::default() }`.
#[derive(Debug, Clone, PartialEq)]
pub struct LSMROptions {
    /// Maximum number of LSMR iterations before giving up.
    pub max_iter: usize,
    /// Stopping tolerance used by the convergence tests in [`lsmr`].
    pub atol: f64,
    /// Second stopping tolerance used by the convergence tests in [`lsmr`].
    pub btol: f64,
    /// Limit on the estimated condition number of `A`; [`lsmr`] stops
    /// iterating once its estimate reaches this (positive) value.
    pub conlim: f64,
    /// When `true`, [`lsmr`] may populate [`LSMRResult::standard_errors`].
    pub calc_var: bool,
    /// When `true`, [`lsmr`] records the initial residual norm followed by one
    /// residual-norm estimate per iteration in [`LSMRResult::residual_history`].
    pub store_residual_history: bool,
    /// Currently not read by [`lsmr`].
    pub local_size: usize,
}
70
71impl Default for LSMROptions {
72 fn default() -> Self {
73 Self {
74 max_iter: 1000,
75 atol: 1e-8,
76 btol: 1e-8,
77 conlim: 1e8,
78 calc_var: false,
79 store_residual_history: true,
80 local_size: 0,
81 }
82 }
83}
84
85#[derive(Debug, Clone)]
87pub struct LSMRResult<T> {
88 pub x: Array1<T>,
90 pub iterations: usize,
92 pub residualnorm: T,
94 pub solution_norm: T,
96 pub condition_number: T,
98 pub converged: bool,
100 pub standard_errors: Option<Array1<T>>,
102 pub residual_history: Option<Vec<T>>,
104 pub convergence_reason: String,
106}
107
108#[allow(dead_code)]
144pub fn lsmr<T, S>(
145 matrix: &S,
146 b: &ArrayView1<T>,
147 x0: Option<&ArrayView1<T>>,
148 options: LSMROptions,
149) -> SparseResult<LSMRResult<T>>
150where
151 T: Float + SparseElement + Debug + Copy + 'static,
152 S: SparseArray<T>,
153{
154 let (m, n) = matrix.shape();
155
156 if b.len() != m {
157 return Err(SparseError::DimensionMismatch {
158 expected: m,
159 found: b.len(),
160 });
161 }
162
163 let mut x = match x0 {
165 Some(x0_val) => {
166 if x0_val.len() != n {
167 return Err(SparseError::DimensionMismatch {
168 expected: n,
169 found: x0_val.len(),
170 });
171 }
172 x0_val.to_owned()
173 }
174 None => Array1::zeros(n),
175 };
176
177 let ax = matrix_vector_multiply(matrix, &x.view())?;
179 let mut u = b - &ax;
180 let mut beta = l2_norm(&u.view());
181
182 let atol = T::from(options.atol).expect("Operation failed");
184 let btol = T::from(options.btol).expect("Operation failed");
185 let conlim = T::from(options.conlim).expect("Operation failed");
186
187 let mut residual_history = if options.store_residual_history {
188 Some(vec![beta])
189 } else {
190 None
191 };
192
193 if beta <= atol {
195 let solution_norm = l2_norm(&x.view());
196 return Ok(LSMRResult {
197 x,
198 iterations: 0,
199 residualnorm: beta,
200 solution_norm,
201 condition_number: T::sparse_one(),
202 converged: true,
203 standard_errors: None,
204 residual_history,
205 convergence_reason: "Already converged".to_string(),
206 });
207 }
208
209 if beta > T::sparse_zero() {
211 for i in 0..m {
212 u[i] = u[i] / beta;
213 }
214 }
215
216 let mut v = matrix_transpose_vector_multiply(matrix, &u.view())?;
218 let mut alpha = l2_norm(&v.view());
219
220 if alpha > T::sparse_zero() {
221 for i in 0..n {
222 v[i] = v[i] / alpha;
223 }
224 }
225
226 let one = T::sparse_one();
228 let zero = T::sparse_zero();
229
230 let mut alphabar = alpha;
231 let mut zetabar = alpha * beta;
232 let mut rho = one;
233 let mut rhobar = one;
234 let mut cbar = one;
235 let mut sbar = zero;
236
237 let mut h = v.clone();
238 let mut hbar: Array1<T> = Array1::zeros(n);
239
240 let mut anorm = zero;
242 let mut acond = zero;
243 let mut rnorm = beta;
244 let mut xnorm = zero;
245
246 let bnorm = beta;
247 let mut norm_a2 = alpha * alpha;
248 let mut maxrbar = zero;
249 let mut minrbar = T::from(1e100).expect("Operation failed");
250
251 let mut converged = false;
252 let mut convergence_reason = String::new();
253 let mut iter = 0;
254
255 for itn in 0..options.max_iter {
256 iter = itn + 1;
257
258 let av = matrix_vector_multiply(matrix, &v.view())?;
261 for i in 0..m {
262 u[i] = av[i] - alpha * u[i];
263 }
264 beta = l2_norm(&u.view());
265
266 if beta > zero {
267 for i in 0..m {
268 u[i] = u[i] / beta;
269 }
270
271 let atu = matrix_transpose_vector_multiply(matrix, &u.view())?;
273 for i in 0..n {
274 v[i] = atu[i] - beta * v[i];
275 }
276 alpha = l2_norm(&v.view());
277
278 if alpha > zero {
279 for i in 0..n {
280 v[i] = v[i] / alpha;
281 }
282 }
283 }
284
285 let rhoold = rho;
287 let (c, s, rho_new) = sym_ortho(alphabar, beta);
288 rho = rho_new;
289 let thetanew = s * alpha;
290 alphabar = c * alpha;
291
292 let rhobarold = rhobar;
294 let zetaold = zetabar;
295 let thetabar = sbar * rho;
296 let rhotemp = cbar * rho;
297 let (cbar_new, sbar_new, rhobar_new) = sym_ortho(rhotemp, thetanew);
298 cbar = cbar_new;
299 sbar = sbar_new;
300 rhobar = rhobar_new;
301 let zeta = cbar * zetabar;
302 zetabar = -sbar * zetabar;
303
304 for i in 0..n {
306 let hbar_old = hbar[i];
307 hbar[i] = h[i] - (thetabar * rho / (rhoold * rhobarold)) * hbar_old;
308 }
309 for i in 0..n {
310 x[i] = x[i] + (zeta / (rho * rhobar)) * hbar[i];
311 }
312 for i in 0..n {
313 h[i] = v[i] - (thetanew / rho) * h[i];
314 }
315
316 norm_a2 = norm_a2 + beta * beta;
318 anorm = norm_a2.sqrt();
319 norm_a2 = norm_a2 + alpha * alpha;
320
321 if c.abs() > zero {
323 maxrbar = maxrbar.max(rhobarold);
324 if itn > 1 {
325 minrbar = minrbar.min(rhobarold);
326 }
327 }
328 acond = maxrbar / minrbar;
329
330 let betadd = c * zetaold;
332 let betad = -(sbar * betadd);
333 let rhodold = rho;
334
335 let thetahat = sbar * rho;
337 let rhohat = cbar * rho;
338 let chat = rhohat / rhodold;
339 let shat = thetahat / rhodold;
340
341 rnorm = (rnorm * rnorm * shat * shat + betad * betad).sqrt();
342 xnorm = (xnorm * xnorm + (zeta / (rho * rhobar)) * (zeta / (rho * rhobar))).sqrt();
343
344 let arnorm = alpha * beta.abs() * c.abs() * s.abs();
345
346 if let Some(ref mut history) = residual_history {
347 history.push(rnorm);
348 }
349
350 let test1 = rnorm / (bnorm + anorm * xnorm + one);
353 let test2 = if rnorm > zero {
355 arnorm / (anorm * rnorm + one)
356 } else {
357 zero
358 };
359
360 if test1 <= atol || rnorm <= atol * bnorm {
361 converged = true;
362 convergence_reason = "Residual tolerance satisfied".to_string();
363 break;
364 }
365
366 if test2 <= btol {
367 converged = true;
368 convergence_reason = "Solution tolerance satisfied".to_string();
369 break;
370 }
371
372 if acond >= conlim {
373 converged = true;
374 convergence_reason = "Condition number limit reached".to_string();
375 break;
376 }
377 }
378
379 if !converged {
380 convergence_reason = "Maximum iterations reached".to_string();
381 }
382
383 let ax_final = matrix_vector_multiply(matrix, &x.view())?;
385 let final_residual = b - &ax_final;
386 let final_residualnorm = l2_norm(&final_residual.view());
387 let final_solution_norm = l2_norm(&x.view());
388
389 let condition_number = acond;
391
392 let standard_errors = if options.calc_var {
394 Some(compute_standard_errors(matrix, final_residualnorm, n)?)
395 } else {
396 None
397 };
398
399 Ok(LSMRResult {
400 x,
401 iterations: iter,
402 residualnorm: final_residualnorm,
403 solution_norm: final_solution_norm,
404 condition_number,
405 converged,
406 standard_errors,
407 residual_history,
408 convergence_reason,
409 })
410}
411
412#[allow(dead_code)]
414fn matrix_vector_multiply<T, S>(matrix: &S, x: &ArrayView1<T>) -> SparseResult<Array1<T>>
415where
416 T: Float + SparseElement + Debug + Copy + 'static,
417 S: SparseArray<T>,
418{
419 let (rows, cols) = matrix.shape();
420 if x.len() != cols {
421 return Err(SparseError::DimensionMismatch {
422 expected: cols,
423 found: x.len(),
424 });
425 }
426
427 let mut result = Array1::zeros(rows);
428 let (row_indices, col_indices, values) = matrix.find();
429
430 for (k, (&i, &j)) in row_indices.iter().zip(col_indices.iter()).enumerate() {
431 result[i] = result[i] + values[k] * x[j];
432 }
433
434 Ok(result)
435}
436
437#[allow(dead_code)]
439fn matrix_transpose_vector_multiply<T, S>(matrix: &S, x: &ArrayView1<T>) -> SparseResult<Array1<T>>
440where
441 T: Float + SparseElement + Debug + Copy + 'static,
442 S: SparseArray<T>,
443{
444 let (rows, cols) = matrix.shape();
445 if x.len() != rows {
446 return Err(SparseError::DimensionMismatch {
447 expected: rows,
448 found: x.len(),
449 });
450 }
451
452 let mut result = Array1::zeros(cols);
453 let (row_indices, col_indices, values) = matrix.find();
454
455 for (k, (&i, &j)) in row_indices.iter().zip(col_indices.iter()).enumerate() {
456 result[j] = result[j] + values[k] * x[i];
457 }
458
459 Ok(result)
460}
461
462#[allow(dead_code)]
464fn l2_norm<T>(x: &ArrayView1<T>) -> T
465where
466 T: Float + SparseElement + Debug + Copy,
467{
468 (x.iter()
469 .map(|&val| val * val)
470 .fold(T::sparse_zero(), |a, b| a + b))
471 .sqrt()
472}
473
474#[allow(dead_code)]
476fn compute_standard_errors<T, S>(matrix: &S, residualnorm: T, n: usize) -> SparseResult<Array1<T>>
477where
478 T: Float + SparseElement + Debug + Copy + 'static,
479 S: SparseArray<T>,
480{
481 let (m, _) = matrix.shape();
482
483 let variance = if m > n {
485 residualnorm * residualnorm / T::from(m - n).expect("Operation failed")
486 } else {
487 residualnorm * residualnorm
488 };
489
490 let std_err = variance.sqrt();
491 Ok(Array1::from_elem(n, std_err))
492}
493
#[cfg(test)]
mod tests {
    use super::*;
    use crate::csr_array::CsrArray;
    use approx::assert_relative_eq;

    /// A nonsingular, non-symmetric 3x3 system must be solved to a tiny residual.
    #[test]
    fn test_lsmr_square_system() {
        // A = [[2, -1, 0], [-1, 2, 0], [0, -1, 2]]
        let row_idx = vec![0, 0, 1, 1, 2, 2];
        let col_idx = vec![0, 1, 0, 1, 1, 2];
        let entries = vec![2.0, -1.0, -1.0, 2.0, -1.0, 2.0];
        let a = CsrArray::from_triplets(&row_idx, &col_idx, &entries, (3, 3), false)
            .expect("Operation failed");
        let rhs = Array1::from_vec(vec![1.0, 0.0, 1.0]);

        let out = lsmr(&a, &rhs.view(), None, LSMROptions::default()).expect("Operation failed");
        assert!(out.converged);

        let a_x = matrix_vector_multiply(&a, &out.x.view()).expect("Operation failed");
        let residual = &rhs - &a_x;
        assert!(l2_norm(&residual.view()) < 1e-6);
    }

    /// A 3x2 least-squares problem whose right-hand side is 0.5 times the second column.
    #[test]
    fn test_lsmr_overdetermined_system() {
        // A = [[1, 2], [3, 4], [5, 6]]
        let row_idx = vec![0, 0, 1, 1, 2, 2];
        let col_idx = vec![0, 1, 0, 1, 0, 1];
        let entries = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let a = CsrArray::from_triplets(&row_idx, &col_idx, &entries, (3, 2), false)
            .expect("Operation failed");
        let rhs = Array1::from_vec(vec![1.0, 2.0, 3.0]);

        let out = lsmr(&a, &rhs.view(), None, LSMROptions::default()).expect("Operation failed");

        assert!(out.converged);
        assert_eq!(out.x.len(), 2);
        assert!(out.residualnorm < 2.0);
    }

    /// A diagonal system has the closed-form solution `x_i = b_i / a_ii`.
    #[test]
    fn test_lsmr_diagonal_system() {
        // A = diag(2, 3, 4), b = [4, 9, 16]  =>  x = [2, 3, 4]
        let diag_idx = vec![0, 1, 2];
        let entries = vec![2.0, 3.0, 4.0];
        let a = CsrArray::from_triplets(&diag_idx, &diag_idx, &entries, (3, 3), false)
            .expect("Operation failed");
        let rhs = Array1::from_vec(vec![4.0, 9.0, 16.0]);

        let out = lsmr(&a, &rhs.view(), None, LSMROptions::default()).expect("Operation failed");
        assert!(out.converged);

        let expected = [2.0, 3.0, 4.0];
        for (i, want) in expected.iter().enumerate() {
            assert_relative_eq!(out.x[i], *want, epsilon = 1e-6);
        }
    }

    /// A good starting point on the identity system should converge quickly.
    #[test]
    fn test_lsmr_with_initial_guess() {
        // A = I, and the guess is off by one in every component.
        let diag_idx = vec![0, 1, 2];
        let ones = vec![1.0, 1.0, 1.0];
        let a = CsrArray::from_triplets(&diag_idx, &diag_idx, &ones, (3, 3), false)
            .expect("Operation failed");
        let rhs = Array1::from_vec(vec![5.0, 6.0, 7.0]);
        let guess = Array1::from_vec(vec![4.0, 5.0, 6.0]);

        let out = lsmr(&a, &rhs.view(), Some(&guess.view()), LSMROptions::default())
            .expect("Operation failed");

        assert!(out.converged);
        assert!(out.iterations <= 10);
    }

    /// Requesting `calc_var` must yield one error estimate per unknown.
    #[test]
    fn test_lsmr_standard_errors() {
        let diag_idx = vec![0, 1, 2];
        let ones = vec![1.0, 1.0, 1.0];
        let a = CsrArray::from_triplets(&diag_idx, &diag_idx, &ones, (3, 3), false)
            .expect("Operation failed");
        let rhs = Array1::from_vec(vec![1.0, 1.0, 1.0]);
        let opts = LSMROptions {
            calc_var: true,
            ..Default::default()
        };

        let out = lsmr(&a, &rhs.view(), None, opts).expect("Operation failed");

        assert!(out.converged);
        assert!(out.standard_errors.is_some());
        let std_errs = out.standard_errors.expect("Operation failed");
        assert_eq!(std_errs.len(), 3);
    }
}