// scirs2_sparse/linalg/lsmr.rs

//! Least Squares Minimal Residual (LSMR) method for sparse linear systems
//!
//! LSMR is an iterative algorithm for solving large sparse least-squares problems
//! and sparse systems of linear equations. It is closely related to LSQR, but
//! unlike LSQR it makes `||A^T r_k||` decrease monotonically, which often makes
//! early termination safer on ill-conditioned problems.
//!
//! The implementation follows the SciPy reference implementation
//! (`scipy.sparse.linalg.lsmr`), which is based on:
//! D. C.-L. Fong and M. A. Saunders (2011), "LSMR: An iterative algorithm
//! for sparse least-squares problems", SIAM J. Sci. Comput., 33(5), 2950-2971.
10
11#![allow(unused_variables)]
12#![allow(unused_assignments)]
13#![allow(unused_mut)]
14
15use crate::error::{SparseError, SparseResult};
16use crate::sparray::SparseArray;
17use scirs2_core::ndarray::{Array1, ArrayView1};
18use scirs2_core::numeric::{Float, One, SparseElement};
19use std::fmt::Debug;
20
21/// Stable implementation of Givens rotation (sym_ortho)
22///
23/// Computes (c, s, r) such that [ c  s] [a] = [r]
24///                               [-s  c] [b]   [0]
25///
26/// Uses the stable formulation to avoid overflow/underflow.
27fn sym_ortho<T: Float + SparseElement>(a: T, b: T) -> (T, T, T) {
28    let zero = T::sparse_zero();
29    let one = <T as One>::one();
30
31    if b == zero {
32        return (if a >= zero { one } else { -one }, zero, a.abs());
33    } else if a == zero {
34        return (zero, if b >= zero { one } else { -one }, b.abs());
35    } else if b.abs() > a.abs() {
36        let tau = a / b;
37        let s_sign = if b >= zero { one } else { -one };
38        let s = s_sign / (one + tau * tau).sqrt();
39        let c = s * tau;
40        let r = b / s;
41        (c, s, r)
42    } else {
43        let tau = b / a;
44        let c_sign = if a >= zero { one } else { -one };
45        let c = c_sign / (one + tau * tau).sqrt();
46        let s = c * tau;
47        let r = a / c;
48        (c, s, r)
49    }
50}
51
/// Options for the LSMR solver.
///
/// Start from `LSMROptions::default()` and override individual settings with
/// struct-update syntax, e.g. `LSMROptions { max_iter: 50, ..Default::default() }`.
#[derive(Debug, Clone)]
pub struct LSMROptions {
    /// Maximum number of LSMR iterations (default `1000`).
    pub max_iter: usize,
    /// Stopping tolerance (default `1e-8`), converted to the solver's scalar
    /// type. See [`lsmr`] for how `atol` and `btol` enter the stopping tests.
    pub atol: f64,
    /// Stopping tolerance (default `1e-8`), converted to the solver's scalar
    /// type. See [`lsmr`] for how `atol` and `btol` enter the stopping tests.
    pub btol: f64,
    /// Limit on the solver's running condition-number estimate (default `1e8`).
    /// Iteration stops once the estimate reaches this value.
    pub conlim: f64,
    /// Whether to fill `LSMRResult::standard_errors` (default `false`).
    pub calc_var: bool,
    /// Whether to record residual-norm estimates in
    /// `LSMRResult::residual_history` (default `true`).
    pub store_residual_history: bool,
    /// Reserved for local reorthogonalization (default `0`).
    /// Currently not read by [`lsmr`]; setting it has no effect.
    pub local_size: usize,
}

impl Default for LSMROptions {
    /// Defaults: 1000 iterations, `atol = btol = 1e-8`, `conlim = 1e8`,
    /// no standard errors, residual history recorded, `local_size = 0`.
    fn default() -> Self {
        LSMROptions {
            max_iter: 1000,
            atol: 1e-8,
            btol: 1e-8,
            conlim: 1e8,
            calc_var: false,
            store_residual_history: true,
            local_size: 0,
        }
    }
}
84
85/// Result from LSMR solver
86#[derive(Debug, Clone)]
87pub struct LSMRResult<T> {
88    /// Solution vector
89    pub x: Array1<T>,
90    /// Number of iterations performed
91    pub iterations: usize,
92    /// Final residual norm ||Ax - b||
93    pub residualnorm: T,
94    /// Final solution norm ||x||
95    pub solution_norm: T,
96    /// Condition number estimate
97    pub condition_number: T,
98    /// Whether the solver converged
99    pub converged: bool,
100    /// Standard errors (if requested)
101    pub standard_errors: Option<Array1<T>>,
102    /// Residual history (if requested)
103    pub residual_history: Option<Vec<T>>,
104    /// Convergence reason
105    pub convergence_reason: String,
106}
107
108/// LSMR algorithm for sparse least squares problems
109///
110/// Solves the least squares problem min ||Ax - b||_2 or the linear system Ax = b.
111/// The method is based on the Golub-Kahan bidiagonalization process.
112///
113/// # Arguments
114///
115/// * `matrix` - The coefficient matrix A (m x n)
116/// * `b` - The right-hand side vector (length m)
117/// * `x0` - Initial guess (optional, length n)
118/// * `options` - Solver options
119///
120/// # Returns
121///
122/// An `LSMRResult` containing the solution and convergence information
123///
124/// # Example
125///
126/// ```rust
127/// use scirs2_sparse::csr_array::CsrArray;
128/// use scirs2_sparse::linalg::{lsmr, LSMROptions};
129/// use scirs2_core::ndarray::Array1;
130///
131/// // Create an overdetermined system
132/// let rows = vec![0, 0, 1, 1, 2, 2];
133/// let cols = vec![0, 1, 0, 1, 0, 1];
134/// let data = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
135/// let matrix = CsrArray::from_triplets(&rows, &cols, &data, (3, 2), false).expect("Operation failed");
136///
137/// // Right-hand side
138/// let b = Array1::from_vec(vec![1.0, 2.0, 3.0]);
139///
140/// // Solve using LSMR
141/// let result = lsmr(&matrix, &b.view(), None, LSMROptions::default()).expect("Operation failed");
142/// ```
143#[allow(dead_code)]
144pub fn lsmr<T, S>(
145    matrix: &S,
146    b: &ArrayView1<T>,
147    x0: Option<&ArrayView1<T>>,
148    options: LSMROptions,
149) -> SparseResult<LSMRResult<T>>
150where
151    T: Float + SparseElement + Debug + Copy + 'static,
152    S: SparseArray<T>,
153{
154    let (m, n) = matrix.shape();
155
156    if b.len() != m {
157        return Err(SparseError::DimensionMismatch {
158            expected: m,
159            found: b.len(),
160        });
161    }
162
163    // Initialize solution vector
164    let mut x = match x0 {
165        Some(x0_val) => {
166            if x0_val.len() != n {
167                return Err(SparseError::DimensionMismatch {
168                    expected: n,
169                    found: x0_val.len(),
170                });
171            }
172            x0_val.to_owned()
173        }
174        None => Array1::zeros(n),
175    };
176
177    // Compute initial residual
178    let ax = matrix_vector_multiply(matrix, &x.view())?;
179    let mut u = b - &ax;
180    let mut beta = l2_norm(&u.view());
181
182    // Tolerances
183    let atol = T::from(options.atol).expect("Operation failed");
184    let btol = T::from(options.btol).expect("Operation failed");
185    let conlim = T::from(options.conlim).expect("Operation failed");
186
187    let mut residual_history = if options.store_residual_history {
188        Some(vec![beta])
189    } else {
190        None
191    };
192
193    // Check for immediate convergence
194    if beta <= atol {
195        let solution_norm = l2_norm(&x.view());
196        return Ok(LSMRResult {
197            x,
198            iterations: 0,
199            residualnorm: beta,
200            solution_norm,
201            condition_number: T::sparse_one(),
202            converged: true,
203            standard_errors: None,
204            residual_history,
205            convergence_reason: "Already converged".to_string(),
206        });
207    }
208
209    // Normalize u
210    if beta > T::sparse_zero() {
211        for i in 0..m {
212            u[i] = u[i] / beta;
213        }
214    }
215
216    // Initialize bidiagonalization
217    let mut v = matrix_transpose_vector_multiply(matrix, &u.view())?;
218    let mut alpha = l2_norm(&v.view());
219
220    if alpha > T::sparse_zero() {
221        for i in 0..n {
222            v[i] = v[i] / alpha;
223        }
224    }
225
226    // Initialize LSMR-specific variables (following SciPy reference)
227    let one = T::sparse_one();
228    let zero = T::sparse_zero();
229
230    let mut alphabar = alpha;
231    let mut zetabar = alpha * beta;
232    let mut rho = one;
233    let mut rhobar = one;
234    let mut cbar = one;
235    let mut sbar = zero;
236
237    let mut h = v.clone();
238    let mut hbar: Array1<T> = Array1::zeros(n);
239
240    // For norm estimation
241    let mut anorm = zero;
242    let mut acond = zero;
243    let mut rnorm = beta;
244    let mut xnorm = zero;
245
246    let bnorm = beta;
247    let mut norm_a2 = alpha * alpha;
248    let mut maxrbar = zero;
249    let mut minrbar = T::from(1e100).expect("Operation failed");
250
251    let mut converged = false;
252    let mut convergence_reason = String::new();
253    let mut iter = 0;
254
255    for itn in 0..options.max_iter {
256        iter = itn + 1;
257
258        // Perform the next step of the bidiagonalization.
259        // Golub-Kahan bidiagonalization: u = A*v - alpha*u
260        let av = matrix_vector_multiply(matrix, &v.view())?;
261        for i in 0..m {
262            u[i] = av[i] - alpha * u[i];
263        }
264        beta = l2_norm(&u.view());
265
266        if beta > zero {
267            for i in 0..m {
268                u[i] = u[i] / beta;
269            }
270
271            // v = A'*u - beta*v
272            let atu = matrix_transpose_vector_multiply(matrix, &u.view())?;
273            for i in 0..n {
274                v[i] = atu[i] - beta * v[i];
275            }
276            alpha = l2_norm(&v.view());
277
278            if alpha > zero {
279                for i in 0..n {
280                    v[i] = v[i] / alpha;
281                }
282            }
283        }
284
285        // Construct rotation Q_{i,2i+1} (plane rotation to eliminate beta)
286        let rhoold = rho;
287        let (c, s, rho_new) = sym_ortho(alphabar, beta);
288        rho = rho_new;
289        let thetanew = s * alpha;
290        alphabar = c * alpha;
291
292        // Construct rotation Qbar_{i,2i+1} (plane rotation for LSMR)
293        let rhobarold = rhobar;
294        let zetaold = zetabar;
295        let thetabar = sbar * rho;
296        let rhotemp = cbar * rho;
297        let (cbar_new, sbar_new, rhobar_new) = sym_ortho(rhotemp, thetanew);
298        cbar = cbar_new;
299        sbar = sbar_new;
300        rhobar = rhobar_new;
301        let zeta = cbar * zetabar;
302        zetabar = -sbar * zetabar;
303
304        // Update h, hbar, x
305        for i in 0..n {
306            let hbar_old = hbar[i];
307            hbar[i] = h[i] - (thetabar * rho / (rhoold * rhobarold)) * hbar_old;
308        }
309        for i in 0..n {
310            x[i] = x[i] + (zeta / (rho * rhobar)) * hbar[i];
311        }
312        for i in 0..n {
313            h[i] = v[i] - (thetanew / rho) * h[i];
314        }
315
316        // Estimate norms
317        norm_a2 = norm_a2 + beta * beta;
318        anorm = norm_a2.sqrt();
319        norm_a2 = norm_a2 + alpha * alpha;
320
321        // Update estimates
322        if c.abs() > zero {
323            maxrbar = maxrbar.max(rhobarold);
324            if itn > 1 {
325                minrbar = minrbar.min(rhobarold);
326            }
327        }
328        acond = maxrbar / minrbar;
329
330        // Compute norm estimates
331        let betadd = c * zetaold;
332        let betad = -(sbar * betadd);
333        let rhodold = rho;
334
335        // Use the recurrence for ||r_k||
336        let thetahat = sbar * rho;
337        let rhohat = cbar * rho;
338        let chat = rhohat / rhodold;
339        let shat = thetahat / rhodold;
340
341        rnorm = (rnorm * rnorm * shat * shat + betad * betad).sqrt();
342        xnorm = (xnorm * xnorm + (zeta / (rho * rhobar)) * (zeta / (rho * rhobar))).sqrt();
343
344        let arnorm = alpha * beta.abs() * c.abs() * s.abs();
345
346        if let Some(ref mut history) = residual_history {
347            history.push(rnorm);
348        }
349
350        // Check stopping criteria
351        // Condition 1: ||Ax - b|| / ||b|| small enough
352        let test1 = rnorm / (bnorm + anorm * xnorm + one);
353        // Condition 2: ||A'r|| / (||A|| ||r||) small enough
354        let test2 = if rnorm > zero {
355            arnorm / (anorm * rnorm + one)
356        } else {
357            zero
358        };
359
360        if test1 <= atol || rnorm <= atol * bnorm {
361            converged = true;
362            convergence_reason = "Residual tolerance satisfied".to_string();
363            break;
364        }
365
366        if test2 <= btol {
367            converged = true;
368            convergence_reason = "Solution tolerance satisfied".to_string();
369            break;
370        }
371
372        if acond >= conlim {
373            converged = true;
374            convergence_reason = "Condition number limit reached".to_string();
375            break;
376        }
377    }
378
379    if !converged {
380        convergence_reason = "Maximum iterations reached".to_string();
381    }
382
383    // Compute final metrics
384    let ax_final = matrix_vector_multiply(matrix, &x.view())?;
385    let final_residual = b - &ax_final;
386    let final_residualnorm = l2_norm(&final_residual.view());
387    let final_solution_norm = l2_norm(&x.view());
388
389    // Condition number estimate
390    let condition_number = acond;
391
392    // Compute standard errors if requested
393    let standard_errors = if options.calc_var {
394        Some(compute_standard_errors(matrix, final_residualnorm, n)?)
395    } else {
396        None
397    };
398
399    Ok(LSMRResult {
400        x,
401        iterations: iter,
402        residualnorm: final_residualnorm,
403        solution_norm: final_solution_norm,
404        condition_number,
405        converged,
406        standard_errors,
407        residual_history,
408        convergence_reason,
409    })
410}
411
412/// Helper function for matrix-vector multiplication
413#[allow(dead_code)]
414fn matrix_vector_multiply<T, S>(matrix: &S, x: &ArrayView1<T>) -> SparseResult<Array1<T>>
415where
416    T: Float + SparseElement + Debug + Copy + 'static,
417    S: SparseArray<T>,
418{
419    let (rows, cols) = matrix.shape();
420    if x.len() != cols {
421        return Err(SparseError::DimensionMismatch {
422            expected: cols,
423            found: x.len(),
424        });
425    }
426
427    let mut result = Array1::zeros(rows);
428    let (row_indices, col_indices, values) = matrix.find();
429
430    for (k, (&i, &j)) in row_indices.iter().zip(col_indices.iter()).enumerate() {
431        result[i] = result[i] + values[k] * x[j];
432    }
433
434    Ok(result)
435}
436
437/// Helper function for matrix transpose-vector multiplication
438#[allow(dead_code)]
439fn matrix_transpose_vector_multiply<T, S>(matrix: &S, x: &ArrayView1<T>) -> SparseResult<Array1<T>>
440where
441    T: Float + SparseElement + Debug + Copy + 'static,
442    S: SparseArray<T>,
443{
444    let (rows, cols) = matrix.shape();
445    if x.len() != rows {
446        return Err(SparseError::DimensionMismatch {
447            expected: rows,
448            found: x.len(),
449        });
450    }
451
452    let mut result = Array1::zeros(cols);
453    let (row_indices, col_indices, values) = matrix.find();
454
455    for (k, (&i, &j)) in row_indices.iter().zip(col_indices.iter()).enumerate() {
456        result[j] = result[j] + values[k] * x[i];
457    }
458
459    Ok(result)
460}
461
462/// Compute L2 norm of a vector
463#[allow(dead_code)]
464fn l2_norm<T>(x: &ArrayView1<T>) -> T
465where
466    T: Float + SparseElement + Debug + Copy,
467{
468    (x.iter()
469        .map(|&val| val * val)
470        .fold(T::sparse_zero(), |a, b| a + b))
471    .sqrt()
472}
473
474/// Compute standard errors (simplified implementation)
475#[allow(dead_code)]
476fn compute_standard_errors<T, S>(matrix: &S, residualnorm: T, n: usize) -> SparseResult<Array1<T>>
477where
478    T: Float + SparseElement + Debug + Copy + 'static,
479    S: SparseArray<T>,
480{
481    let (m, _) = matrix.shape();
482
483    // Simplified standard error computation
484    let variance = if m > n {
485        residualnorm * residualnorm / T::from(m - n).expect("Operation failed")
486    } else {
487        residualnorm * residualnorm
488    };
489
490    let std_err = variance.sqrt();
491    Ok(Array1::from_elem(n, std_err))
492}
493
#[cfg(test)]
mod tests {
    use super::*;
    use crate::csr_array::CsrArray;
    use approx::assert_relative_eq;

    /// Nonsingular 3x3 system: the true residual of the LSMR solution must be ~0.
    #[test]
    fn test_lsmr_square_system() {
        let (ri, ci) = (vec![0, 0, 1, 1, 2, 2], vec![0, 1, 0, 1, 1, 2]);
        let vals = vec![2.0, -1.0, -1.0, 2.0, -1.0, 2.0];
        let a =
            CsrArray::from_triplets(&ri, &ci, &vals, (3, 3), false).expect("Operation failed");
        let rhs = Array1::from_vec(vec![1.0, 0.0, 1.0]);

        let sol = lsmr(&a, &rhs.view(), None, LSMROptions::default()).expect("Operation failed");
        assert!(sol.converged);

        // Recompute ||b - A x|| independently of the solver's own report.
        let ax = matrix_vector_multiply(&a, &sol.x.view()).expect("Operation failed");
        let r = &rhs - &ax;
        assert!(l2_norm(&r.view()) < 1e-6);
    }

    /// 3x2 system; b is in the range of A (x = [0, 0.5]), the bound is deliberately loose.
    #[test]
    fn test_lsmr_overdetermined_system() {
        let (ri, ci) = (vec![0, 0, 1, 1, 2, 2], vec![0, 1, 0, 1, 0, 1]);
        let vals = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let a =
            CsrArray::from_triplets(&ri, &ci, &vals, (3, 2), false).expect("Operation failed");
        let rhs = Array1::from_vec(vec![1.0, 2.0, 3.0]);

        let sol = lsmr(&a, &rhs.view(), None, LSMROptions::default()).expect("Operation failed");

        assert!(sol.converged);
        assert_eq!(sol.x.len(), 2);
        assert!(sol.residualnorm < 2.0);
    }

    /// Diagonal system: the solution is b_i / d_i = [2, 3, 4].
    #[test]
    fn test_lsmr_diagonal_system() {
        let (ri, ci) = (vec![0, 1, 2], vec![0, 1, 2]);
        let vals = vec![2.0, 3.0, 4.0];
        let a =
            CsrArray::from_triplets(&ri, &ci, &vals, (3, 3), false).expect("Operation failed");
        let rhs = Array1::from_vec(vec![4.0, 9.0, 16.0]);

        let sol = lsmr(&a, &rhs.view(), None, LSMROptions::default()).expect("Operation failed");
        assert!(sol.converged);

        for (k, expected) in [2.0, 3.0, 4.0].iter().enumerate() {
            assert_relative_eq!(sol.x[k], *expected, epsilon = 1e-6);
        }
    }

    /// Identity system with x0 one unit away from the exact solution [5, 6, 7].
    #[test]
    fn test_lsmr_with_initial_guess() {
        let (ri, ci) = (vec![0, 1, 2], vec![0, 1, 2]);
        let vals = vec![1.0, 1.0, 1.0];
        let a =
            CsrArray::from_triplets(&ri, &ci, &vals, (3, 3), false).expect("Operation failed");
        let rhs = Array1::from_vec(vec![5.0, 6.0, 7.0]);
        let guess = Array1::from_vec(vec![4.0, 5.0, 6.0]);

        let sol = lsmr(&a, &rhs.view(), Some(&guess.view()), LSMROptions::default())
            .expect("Operation failed");

        assert!(sol.converged);
        assert!(sol.iterations <= 10);
    }

    /// A = I (exact fit); only the presence and length of the error vector are checked.
    #[test]
    fn test_lsmr_standard_errors() {
        let (ri, ci) = (vec![0, 1, 2], vec![0, 1, 2]);
        let vals = vec![1.0, 1.0, 1.0];
        let a =
            CsrArray::from_triplets(&ri, &ci, &vals, (3, 3), false).expect("Operation failed");
        let rhs = Array1::from_vec(vec![1.0, 1.0, 1.0]);

        let opts = LSMROptions {
            calc_var: true,
            ..Default::default()
        };
        let sol = lsmr(&a, &rhs.view(), None, opts).expect("Operation failed");

        assert!(sol.converged);
        assert!(sol.standard_errors.is_some());
        let errs = sol.standard_errors.expect("Operation failed");
        assert_eq!(errs.len(), 3);
    }
}