scirs2_sparse/linalg/
gcrot.rs

1//! Generalized Conjugate Residual with Orthogonalization and Truncation (GCROT-m) method
2//!
3//! GCROT-m is a Krylov subspace method for solving sparse linear systems.
4//! It maintains a set of search directions and performs orthogonalization
5//! to improve stability and convergence.
6
7use crate::error::{SparseError, SparseResult};
8use crate::sparray::SparseArray;
9use scirs2_core::ndarray::{Array1, Array2, ArrayView1};
10use scirs2_core::numeric::Float;
11use scirs2_core::SparseElement;
12use std::fmt::Debug;
13
14/// Type alias for GCROT inner iteration result
15type GCROTInnerResult<T> = SparseResult<(Array1<T>, Option<Array1<T>>, Option<Array1<T>>, bool)>;
16
/// Tuning parameters for the `gcrot` solver.
#[derive(Clone, Debug)]
pub struct GCROTOptions {
    /// Upper bound on the number of outer iterations.
    pub max_iter: usize,
    /// Relative tolerance: the solve is considered converged once
    /// `||b - A x|| <= tol * ||b||`.
    pub tol: f64,
    /// Largest number of search-direction pairs retained (the `m` of GCROT-m).
    pub truncation_size: usize,
    /// Whether the residual norm of the start and of every iteration is recorded.
    pub store_residual_history: bool,
}

impl Default for GCROTOptions {
    /// 1000 iterations, `tol = 1e-6`, truncation size 20, residual history enabled.
    fn default() -> Self {
        let (max_iter, truncation_size) = (1000, 20);
        GCROTOptions {
            tol: 1e-6,
            truncation_size,
            max_iter,
            store_residual_history: true,
        }
    }
}
40
41/// Result from GCROT solver
42#[derive(Debug, Clone)]
43pub struct GCROTResult<T> {
44    /// Solution vector
45    pub x: Array1<T>,
46    /// Number of iterations performed
47    pub iterations: usize,
48    /// Final residual norm
49    pub residual_norm: T,
50    /// Whether the solver converged
51    pub converged: bool,
52    /// Residual history (if requested)
53    pub residual_history: Option<Vec<T>>,
54}
55
56/// Generalized Conjugate Residual with Orthogonalization and Truncation method
57///
58/// Solves the linear system A * x = b using the GCROT-m method.
59/// This method builds and maintains a truncated Krylov subspace to
60/// accelerate convergence for challenging linear systems.
61///
62/// # Arguments
63///
64/// * `matrix` - The coefficient matrix A
65/// * `b` - The right-hand side vector
66/// * `x0` - Initial guess (optional)
67/// * `options` - Solver options
68///
69/// # Returns
70///
71/// A `GCROTResult` containing the solution and convergence information
72///
73/// # Example
74///
75/// ```rust
76/// use scirs2_sparse::csr_array::CsrArray;
77/// use scirs2_sparse::linalg::{gcrot, GCROTOptions};
78/// use scirs2_core::ndarray::Array1;
79///
80/// // Create a simple matrix
81/// let rows = vec![0, 0, 1, 1, 2, 2];
82/// let cols = vec![0, 1, 0, 1, 1, 2];
83/// let data = vec![2.0, -1.0, -1.0, 2.0, -1.0, 2.0];
84/// let matrix = CsrArray::from_triplets(&rows, &cols, &data, (3, 3), false).unwrap();
85///
86/// // Right-hand side
87/// let b = Array1::from_vec(vec![1.0, 0.0, 1.0]);
88///
89/// // Solve using GCROT
90/// let result = gcrot(&matrix, &b.view(), None, GCROTOptions::default()).unwrap();
91/// ```
92#[allow(dead_code)]
93pub fn gcrot<T, S>(
94    matrix: &S,
95    b: &ArrayView1<T>,
96    x0: Option<&ArrayView1<T>>,
97    options: GCROTOptions,
98) -> SparseResult<GCROTResult<T>>
99where
100    T: Float + SparseElement + Debug + Copy + 'static,
101    S: SparseArray<T>,
102{
103    let n = b.len();
104    let (rows, cols) = matrix.shape();
105
106    if rows != cols || rows != n {
107        return Err(SparseError::DimensionMismatch {
108            expected: n,
109            found: rows,
110        });
111    }
112
113    // Initialize solution vector
114    let mut x = match x0 {
115        Some(x0_val) => x0_val.to_owned(),
116        None => Array1::zeros(n),
117    };
118
119    // Compute initial residual: r0 = b - A * x0
120    let ax = matrix_vector_multiply(matrix, &x.view())?;
121    let mut r = b - &ax;
122
123    // Check if already converged
124    let initial_residual_norm = l2_norm(&r.view());
125    let b_norm = l2_norm(b);
126    let tolerance = T::from(options.tol).unwrap() * b_norm;
127
128    if initial_residual_norm <= tolerance {
129        return Ok(GCROTResult {
130            x,
131            iterations: 0,
132            residual_norm: initial_residual_norm,
133            converged: true,
134            residual_history: if options.store_residual_history {
135                Some(vec![initial_residual_norm])
136            } else {
137                None
138            },
139        });
140    }
141
142    let m = options.truncation_size;
143
144    // Storage for the truncated space
145    let mut c_vectors = Array2::zeros((n, 0)); // C_k matrix
146    let mut u_vectors = Array2::zeros((n, 0)); // U_k matrix (A * C_k)
147
148    let mut residual_history = if options.store_residual_history {
149        Some(vec![initial_residual_norm])
150    } else {
151        None
152    };
153
154    let mut converged = false;
155    let mut iter = 0;
156
157    for k in 0..options.max_iter {
158        iter = k + 1;
159
160        // GCROT inner iteration (flexible GMRES with truncation)
161        let (delta_x, new_c, new_u, inner_converged) = gcrot_inner_iteration(
162            matrix,
163            &r.view(),
164            &c_vectors.view(),
165            &u_vectors.view(),
166            tolerance,
167        )?;
168
169        // Update solution
170        x = &x + &delta_x;
171
172        // Update residual
173        let ax = matrix_vector_multiply(matrix, &x.view())?;
174        r = b - &ax;
175        let residual_norm = l2_norm(&r.view());
176
177        if let Some(ref mut history) = residual_history {
178            history.push(residual_norm);
179        }
180
181        // Check convergence
182        if residual_norm <= tolerance || inner_converged {
183            converged = true;
184            break;
185        }
186
187        // Update truncated space
188        if let (Some(c), Some(u)) = (new_c, new_u) {
189            if c_vectors.ncols() >= m {
190                // Truncate the space by removing the oldest vector
191                let mut new_c_vectors = Array2::zeros((n, m));
192                let mut new_u_vectors = Array2::zeros((n, m));
193
194                // Keep the m-1 most recent vectors and add the new one
195                for j in 1..c_vectors.ncols() {
196                    for i in 0..n {
197                        new_c_vectors[[i, j - 1]] = c_vectors[[i, j]];
198                        new_u_vectors[[i, j - 1]] = u_vectors[[i, j]];
199                    }
200                }
201
202                // Add new vectors
203                for i in 0..n {
204                    new_c_vectors[[i, m - 1]] = c[i];
205                    new_u_vectors[[i, m - 1]] = u[i];
206                }
207
208                c_vectors = new_c_vectors;
209                u_vectors = new_u_vectors;
210            } else {
211                // Simply append the new vectors
212                let old_cols = c_vectors.ncols();
213                let mut new_c_vectors = Array2::zeros((n, old_cols + 1));
214                let mut new_u_vectors = Array2::zeros((n, old_cols + 1));
215
216                // Copy old vectors
217                for j in 0..old_cols {
218                    for i in 0..n {
219                        new_c_vectors[[i, j]] = c_vectors[[i, j]];
220                        new_u_vectors[[i, j]] = u_vectors[[i, j]];
221                    }
222                }
223
224                // Add new vectors
225                for i in 0..n {
226                    new_c_vectors[[i, old_cols]] = c[i];
227                    new_u_vectors[[i, old_cols]] = u[i];
228                }
229
230                c_vectors = new_c_vectors;
231                u_vectors = new_u_vectors;
232            }
233        }
234    }
235
236    // Compute final residual norm
237    let ax_final = matrix_vector_multiply(matrix, &x.view())?;
238    let final_residual = b - &ax_final;
239    let final_residual_norm = l2_norm(&final_residual.view());
240
241    Ok(GCROTResult {
242        x,
243        iterations: iter,
244        residual_norm: final_residual_norm,
245        converged,
246        residual_history,
247    })
248}
249
250/// Inner GCROT iteration (flexible GMRES step)
251#[allow(dead_code)]
252fn gcrot_inner_iteration<T, S>(
253    matrix: &S,
254    r: &ArrayView1<T>,
255    c_vectors: &scirs2_core::ndarray::ArrayView2<T>,
256    u_vectors: &scirs2_core::ndarray::ArrayView2<T>,
257    tolerance: T,
258) -> GCROTInnerResult<T>
259where
260    T: Float + SparseElement + Debug + Copy + 'static,
261    S: SparseArray<T>,
262{
263    let n = r.len();
264    let k = c_vectors.ncols(); // Number of _vectors in truncated space
265
266    // Start with the current residual
267    let mut v = r.to_owned();
268    let beta = l2_norm(&v.view());
269
270    if beta <= tolerance {
271        return Ok((Array1::zeros(n), None, None, true));
272    }
273
274    // Normalize v
275    for i in 0..n {
276        v[i] = v[i] / beta;
277    }
278
279    // Orthogonalize against the truncated space
280    for j in 0..k {
281        let mut proj = T::sparse_zero();
282        for i in 0..n {
283            proj = proj + u_vectors[[i, j]] * v[i];
284        }
285
286        for i in 0..n {
287            v[i] = v[i] - proj * c_vectors[[i, j]];
288        }
289    }
290
291    // Renormalize
292    let v_norm = l2_norm(&v.view());
293    if v_norm > T::from(1e-12).unwrap() {
294        for i in 0..n {
295            v[i] = v[i] / v_norm;
296        }
297    }
298
299    // Compute A * v
300    let av = matrix_vector_multiply(matrix, &v.view())?;
301
302    // Simple update: delta_x = (beta / ||A*v||^2) * (A*v . r) * v
303    let av_norm_sq = dot_product(&av.view(), &av.view());
304    let av_r_dot = dot_product(&av.view(), r);
305
306    if av_norm_sq > T::from(1e-12).unwrap() {
307        let alpha = av_r_dot / av_norm_sq;
308        let mut delta_x = Array1::zeros(n);
309
310        for i in 0..n {
311            delta_x[i] = alpha * v[i];
312        }
313
314        Ok((delta_x, Some(v), Some(av), false))
315    } else {
316        Ok((Array1::zeros(n), None, None, true))
317    }
318}
319
320/// Helper function for matrix-vector multiplication
321#[allow(dead_code)]
322fn matrix_vector_multiply<T, S>(matrix: &S, x: &ArrayView1<T>) -> SparseResult<Array1<T>>
323where
324    T: Float + SparseElement + Debug + Copy + 'static,
325    S: SparseArray<T>,
326{
327    let (rows, cols) = matrix.shape();
328    if x.len() != cols {
329        return Err(SparseError::DimensionMismatch {
330            expected: cols,
331            found: x.len(),
332        });
333    }
334
335    let mut result = Array1::zeros(rows);
336    let (row_indices, col_indices, values) = matrix.find();
337
338    for (k, (&i, &j)) in row_indices.iter().zip(col_indices.iter()).enumerate() {
339        result[i] = result[i] + values[k] * x[j];
340    }
341
342    Ok(result)
343}
344
345/// Compute L2 norm of a vector
346#[allow(dead_code)]
347fn l2_norm<T>(x: &ArrayView1<T>) -> T
348where
349    T: Float + Debug + Copy + SparseElement,
350{
351    (x.iter()
352        .map(|&val| val * val)
353        .fold(T::sparse_zero(), |a, b| a + b))
354    .sqrt()
355}
356
357/// Compute dot product of two vectors
358#[allow(dead_code)]
359fn dot_product<T>(x: &ArrayView1<T>, y: &ArrayView1<T>) -> T
360where
361    T: Float + Debug + Copy + SparseElement,
362{
363    x.iter()
364        .zip(y.iter())
365        .map(|(&xi, &yi)| xi * yi)
366        .fold(T::sparse_zero(), |a, b| a + b)
367}
368
#[cfg(test)]
mod tests {
    use super::*;
    use crate::csr_array::CsrArray;

    /// Returns `||b - A x||_2` for a computed solution `x`.
    fn residual_norm<S: SparseArray<f64>>(a: &S, b: &Array1<f64>, x: &Array1<f64>) -> f64 {
        let r = b - &matrix_vector_multiply(a, &x.view()).unwrap();
        l2_norm(&r.view())
    }

    #[test]
    fn test_gcrot_simple_system() {
        // 3x3 system [[2, -1, 0], [-1, 2, 0], [0, -1, 2]]
        let (rows, cols) = (vec![0, 0, 1, 1, 2, 2], vec![0, 1, 0, 1, 1, 2]);
        let data = vec![2.0, -1.0, -1.0, 2.0, -1.0, 2.0];
        let matrix = CsrArray::from_triplets(&rows, &cols, &data, (3, 3), false).unwrap();
        let b = Array1::from_vec(vec![1.0, 0.0, 1.0]);

        let result = gcrot(&matrix, &b.view(), None, GCROTOptions::default()).unwrap();

        assert!(result.converged);
        assert!(residual_norm(&matrix, &b, &result.x) < 1e-6);
    }

    #[test]
    fn test_gcrot_diagonal_system() {
        // Scaled identity 5*I: should converge immediately
        let (rows, cols) = (vec![0, 1, 2], vec![0, 1, 2]);
        let data = vec![5.0, 5.0, 5.0];
        let matrix = CsrArray::from_triplets(&rows, &cols, &data, (3, 3), false).unwrap();
        let b = Array1::from_vec(vec![5.0, 10.0, 15.0]);

        let result = gcrot(&matrix, &b.view(), None, GCROTOptions::default()).unwrap();

        assert!(result.converged);
        assert!(residual_norm(&matrix, &b, &result.x) < 1e-6);
    }

    #[test]
    fn test_gcrot_truncation() {
        // Same 3x3 system, but only two direction pairs may be retained
        let (rows, cols) = (vec![0, 0, 1, 1, 2, 2], vec![0, 1, 0, 1, 1, 2]);
        let data = vec![2.0, -1.0, -1.0, 2.0, -1.0, 2.0];
        let matrix = CsrArray::from_triplets(&rows, &cols, &data, (3, 3), false).unwrap();
        let b = Array1::from_vec(vec![1.0, 0.0, 1.0]);
        let options = GCROTOptions {
            truncation_size: 2,
            ..Default::default()
        };

        let result = gcrot(&matrix, &b.view(), None, options).unwrap();

        // Truncation must not prevent convergence
        assert!(result.converged);
    }
}