// scirs2_sparse/lobpcg.rs
//! LOBPCG (Locally Optimal Block Preconditioned Conjugate Gradient) eigensolver
//!
//! Solves the generalized eigenvalue problem `A x = lambda B x` for the smallest
//! or largest eigenvalues of large sparse matrices. The method is matrix-free
//! (only requires matrix-vector products) and supports preconditioning and
//! locking of converged eigenvectors.
//!
//! # Algorithm
//!
//! LOBPCG iterates on a block of vectors simultaneously, applying the Rayleigh-Ritz
//! procedure to the trial subspace `span(X, W, P)` where:
//! - `X` — current eigenvector approximations
//! - `W` — preconditioned residuals
//! - `P` — conjugate direction from the previous iteration
//!
//! # References
//!
//! - Knyazev, A.V. (2001). "Toward the optimal preconditioned eigensolver: LOBPCG".
//!   SIAM J. Sci. Comput. 23(2), 517-541.

21use crate::csr::CsrMatrix;
22use crate::error::{SparseError, SparseResult};
23use crate::iterative_solvers::Preconditioner;
24use scirs2_core::ndarray::{Array1, Array2};
25use scirs2_core::numeric::{Float, NumAssign, SparseElement};
26use std::fmt::Debug;
27use std::iter::Sum;
28
// ---------------------------------------------------------------------------
// Configuration
// ---------------------------------------------------------------------------

/// Which end of the spectrum the solver should target.
///
/// Passed via [`LobpcgConfig::target`]; selects which Ritz values are kept
/// from each Rayleigh-Ritz solve.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum EigenTarget {
    /// Smallest algebraic eigenvalues (default for LOBPCG).
    #[default]
    Smallest,
    /// Largest algebraic eigenvalues.
    Largest,
}
42
/// Configuration for the LOBPCG eigensolver.
#[derive(Debug, Clone)]
pub struct LobpcgConfig {
    /// Number of eigenvalues to compute (block size). Must be in `[1, n]`.
    pub block_size: usize,
    /// Maximum number of outer iterations.
    pub max_iter: usize,
    /// Convergence tolerance: an eigenpair is considered converged when the
    /// 2-norm of its residual `A x - lambda B x` falls below this value.
    pub tol: f64,
    /// Whether to compute smallest or largest eigenvalues.
    pub target: EigenTarget,
    /// Whether to lock (freeze) converged eigenvectors early so later
    /// iterations work only on the remaining active block.
    pub locking: bool,
    /// Whether to print convergence information.
    /// NOTE(review): this flag is not currently consulted by the solver in
    /// this file — confirm whether progress printing is still intended.
    pub verbose: bool,
}
59
60impl Default for LobpcgConfig {
61    fn default() -> Self {
62        Self {
63            block_size: 1,
64            max_iter: 500,
65            tol: 1e-8,
66            target: EigenTarget::Smallest,
67            locking: true,
68            verbose: false,
69        }
70    }
71}
72
/// Result of a LOBPCG computation.
#[derive(Debug, Clone)]
pub struct LobpcgResult<F> {
    /// Computed eigenvalues, in ascending algebraic order (the order produced
    /// by the internal Rayleigh-Ritz solve).
    pub eigenvalues: Array1<F>,
    /// Corresponding eigenvectors stored column-wise (column `j` pairs with
    /// `eigenvalues[j]`).
    pub eigenvectors: Array2<F>,
    /// Number of iterations performed.
    pub iterations: usize,
    /// Final residual norms `||A x - lambda B x||_2` for each eigenpair.
    pub residual_norms: Vec<F>,
    /// Whether all requested eigenpairs converged within tolerance.
    pub converged: bool,
    /// Number of converged eigenpairs.
    pub n_converged: usize,
}
89
// ---------------------------------------------------------------------------
// Dense linear algebra helpers (Pure Rust, no LAPACK)
// ---------------------------------------------------------------------------

94/// Compute y = A * x  for CSR matrix A and dense vector x.
95fn csr_matvec<F>(a: &CsrMatrix<F>, x: &[F]) -> SparseResult<Vec<F>>
96where
97    F: Float + NumAssign + Sum + SparseElement + 'static,
98{
99    let (m, n) = a.shape();
100    if x.len() != n {
101        return Err(SparseError::DimensionMismatch {
102            expected: n,
103            found: x.len(),
104        });
105    }
106    let mut y = vec![F::sparse_zero(); m];
107    for i in 0..m {
108        let range = a.row_range(i);
109        let cols = &a.indices[range.clone()];
110        let vals = &a.data[range];
111        let mut acc = F::sparse_zero();
112        for (idx, &col) in cols.iter().enumerate() {
113            acc += vals[idx] * x[col];
114        }
115        y[i] = acc;
116    }
117    Ok(y)
118}
119
120/// Inner product of two slices.
121#[inline]
122fn dot<F: Float + Sum>(a: &[F], b: &[F]) -> F {
123    a.iter().zip(b.iter()).map(|(&ai, &bi)| ai * bi).sum()
124}
125
126/// 2-norm of a slice.
127#[inline]
128fn norm2<F: Float + Sum>(v: &[F]) -> F {
129    dot(v, v).sqrt()
130}
131
132/// Normalise a vector in-place; returns its original norm.
133fn normalise<F: Float + Sum + SparseElement>(v: &mut [F]) -> F {
134    let nrm = norm2(v);
135    if nrm > F::epsilon() {
136        let inv = F::sparse_one() / nrm;
137        for vi in v.iter_mut() {
138            *vi = *vi * inv;
139        }
140    }
141    nrm
142}
143
144/// Classical Gram-Schmidt orthogonalisation of column `col` of `mat` against
145/// all previous columns. `mat` is stored column-major: mat[col * n .. (col+1)*n].
146fn gram_schmidt_column<F: Float + Sum + SparseElement>(mat: &mut [F], n: usize, col: usize) {
147    for j in 0..col {
148        let c = dot(&mat[j * n..(j + 1) * n], &mat[col * n..(col + 1) * n]);
149        for i in 0..n {
150            mat[col * n + i] = mat[col * n + i] - c * mat[j * n + i];
151        }
152    }
153    normalise(&mut mat[col * n..(col + 1) * n]);
154}
155
156/// Multiply a CSR matrix A by a dense column-major block V (n x k),
157/// producing AV (m x k) column-major.
158fn csr_matmul_block<F>(a: &CsrMatrix<F>, v: &[F], n: usize, k: usize) -> SparseResult<Vec<F>>
159where
160    F: Float + NumAssign + Sum + SparseElement + 'static,
161{
162    let m = a.rows();
163    let mut av = vec![F::sparse_zero(); m * k];
164    for col in 0..k {
165        let col_vec = &v[col * n..(col + 1) * n];
166        let result = csr_matvec(a, col_vec)?;
167        for i in 0..m {
168            av[col * m + i] = result[i];
169        }
170    }
171    Ok(av)
172}
173
174/// Compute the Gram matrix G = V^T W where V is (n x p) and W is (n x q),
175/// both in column-major layout. Result is (p x q) in row-major.
176fn gram_matrix<F: Float + Sum + SparseElement>(
177    v: &[F],
178    w: &[F],
179    n: usize,
180    p: usize,
181    q: usize,
182) -> Vec<F> {
183    let mut g = vec![F::sparse_zero(); p * q];
184    for i in 0..p {
185        for j in 0..q {
186            g[i * q + j] = dot(&v[i * n..(i + 1) * n], &w[j * n..(j + 1) * n]);
187        }
188    }
189    g
190}
191
/// Solve the small dense symmetric eigenvalue problem using the Jacobi method.
/// `a` is (k x k) row-major. Returns eigenvalues (ascending algebraic order)
/// and eigenvectors (column-major k x k, column `j` pairs with eigenvalue `j`).
///
/// Runs cyclic Jacobi sweeps: each sweep visits every off-diagonal pair
/// `(p, q)` and applies a plane rotation chosen to annihilate `a[p][q]`.
/// Sweeps stop once the largest off-diagonal magnitude drops below a scaled
/// machine-epsilon tolerance, or after `max_sweeps` sweeps.
fn dense_symmetric_eig<F>(a: &[F], k: usize) -> SparseResult<(Vec<F>, Vec<F>)>
where
    F: Float + NumAssign + Sum + SparseElement + Debug + 'static,
{
    let max_sweeps = 100;
    // Off-diagonal entries below this threshold are treated as already zero.
    let tol = F::epsilon() * F::from(100.0).unwrap_or(F::sparse_one());

    // Work on a copy
    let mut mat = a.to_vec();
    // Eigenvectors as identity
    let mut vecs = vec![F::sparse_zero(); k * k];
    for i in 0..k {
        vecs[i * k + i] = F::sparse_one();
    }

    for _sweep in 0..max_sweeps {
        // Find max off-diagonal (convergence measure for the sweep loop)
        let mut max_off = F::sparse_zero();
        for i in 0..k {
            for j in (i + 1)..k {
                let val = mat[i * k + j].abs();
                if val > max_off {
                    max_off = val;
                }
            }
        }
        if max_off < tol {
            break;
        }

        for p in 0..k {
            for q in (p + 1)..k {
                let apq = mat[p * k + q];
                if apq.abs() < tol {
                    continue;
                }
                let diff = mat[q * k + q] - mat[p * k + p];
                // Rotation angle: pi/4 for (numerically) equal diagonal
                // entries; otherwise the numerically stable classical formula
                // with tau = cot(2*theta) = (a_qq - a_pp) / (2 a_pq).
                let theta = if diff.abs() < F::epsilon() {
                    F::from(std::f64::consts::FRAC_PI_4).unwrap_or(F::sparse_one())
                } else {
                    let tau = diff / (apq + apq);
                    // t = sign(tau) / (|tau| + sqrt(1 + tau^2))
                    let sign_tau = if tau >= F::sparse_zero() {
                        F::sparse_one()
                    } else {
                        -F::sparse_one()
                    };
                    let t = sign_tau / (tau.abs() + (F::sparse_one() + tau * tau).sqrt());
                    t.atan()
                };

                let (sin_t, cos_t) = (theta.sin(), theta.cos());

                // Rotate rows/cols p, q in mat; symmetry is restored
                // explicitly by mirroring the updated entries.
                for r in 0..k {
                    if r == p || r == q {
                        continue;
                    }
                    let arp = mat[r * k + p];
                    let arq = mat[r * k + q];
                    mat[r * k + p] = cos_t * arp - sin_t * arq;
                    mat[r * k + q] = sin_t * arp + cos_t * arq;
                    mat[p * k + r] = mat[r * k + p];
                    mat[q * k + r] = mat[r * k + q];
                }

                // Update the 2x2 (p, q) diagonal block. The angle was chosen
                // to annihilate the off-diagonal entry, so it is set to zero
                // exactly rather than recomputed.
                let app = mat[p * k + p];
                let aqq = mat[q * k + q];
                let apq_old = mat[p * k + q];
                mat[p * k + p] = cos_t * cos_t * app
                    - F::from(2.0).unwrap_or(F::sparse_one()) * sin_t * cos_t * apq_old
                    + sin_t * sin_t * aqq;
                mat[q * k + q] = sin_t * sin_t * app
                    + F::from(2.0).unwrap_or(F::sparse_one()) * sin_t * cos_t * apq_old
                    + cos_t * cos_t * aqq;
                mat[p * k + q] = F::sparse_zero();
                mat[q * k + p] = F::sparse_zero();

                // Rotate eigenvectors: columns p and q of the column-major
                // accumulator `vecs` (V <- V * J).
                for r in 0..k {
                    let vp = vecs[p * k + r];
                    let vq = vecs[q * k + r];
                    vecs[p * k + r] = cos_t * vp - sin_t * vq;
                    vecs[q * k + r] = sin_t * vp + cos_t * vq;
                }
            }
        }
    }

    // Diagonal now holds the (unsorted) eigenvalues.
    let mut eigenvalues: Vec<F> = (0..k).map(|i| mat[i * k + i]).collect();

    // Sort eigenvalues ascending and permute eigenvector columns to match.
    let mut perm: Vec<usize> = (0..k).collect();
    perm.sort_by(|&a_idx, &b_idx| {
        eigenvalues[a_idx]
            .partial_cmp(&eigenvalues[b_idx])
            .unwrap_or(std::cmp::Ordering::Equal)
    });

    let sorted_vals: Vec<F> = perm.iter().map(|&i| eigenvalues[i]).collect();
    let mut sorted_vecs = vec![F::sparse_zero(); k * k];
    for (new_col, &old_col) in perm.iter().enumerate() {
        for r in 0..k {
            sorted_vecs[new_col * k + r] = vecs[old_col * k + r];
        }
    }

    eigenvalues = sorted_vals;
    Ok((eigenvalues, sorted_vecs))
}
304
/// Solve the small dense generalised symmetric eigenproblem
///   S z = lambda M z
/// where both S, M are (k x k) row-major and M is SPD.
/// Returns eigenvalues and eigenvectors (column-major k x k).
///
/// Reduction: factor `M = L L^T` (Cholesky), form `S' = L^{-1} S L^{-T}`,
/// solve the standard symmetric problem for `S'`, then back-transform the
/// eigenvectors via `x = L^{-T} z`. If `M` is not numerically positive
/// definite the routine falls back to the standard problem (treating M ~ I).
fn dense_generalised_eig<F>(s: &[F], m: &[F], k: usize) -> SparseResult<(Vec<F>, Vec<F>)>
where
    F: Float + NumAssign + Sum + SparseElement + Debug + 'static,
{
    // Cholesky factorisation of M = L L^T (lower triangle stored in l_mat)
    let mut l_mat = vec![F::sparse_zero(); k * k];
    for i in 0..k {
        for j in 0..=i {
            let mut sum = m[i * k + j];
            for kk in 0..j {
                sum -= l_mat[i * k + kk] * l_mat[j * k + kk];
            }
            if i == j {
                if sum <= F::sparse_zero() {
                    // Non-positive pivot: M is not SPD in floating point.
                    // Fall back to standard eigensolve (M ~ I)
                    return dense_symmetric_eig(s, k);
                }
                l_mat[i * k + j] = sum.sqrt();
            } else {
                let l_jj = l_mat[j * k + j];
                if l_jj.abs() < F::epsilon() {
                    return dense_symmetric_eig(s, k);
                }
                l_mat[i * k + j] = sum / l_jj;
            }
        }
    }

    // Compute L^{-1} by forward substitution, one column at a time.
    let mut l_inv = vec![F::sparse_zero(); k * k];
    for i in 0..k {
        l_inv[i * k + i] = F::sparse_one() / l_mat[i * k + i];
        for j in (i + 1)..k {
            let mut sum = F::sparse_zero();
            for kk in i..j {
                sum += l_mat[j * k + kk] * l_inv[kk * k + i];
            }
            l_inv[j * k + i] = -sum / l_mat[j * k + j];
        }
    }

    // Compute S' = L^{-1} S L^{-T}
    // First: T = L^{-1} S
    let mut temp = vec![F::sparse_zero(); k * k];
    for i in 0..k {
        for j in 0..k {
            let mut val = F::sparse_zero();
            for kk in 0..k {
                val += l_inv[i * k + kk] * s[kk * k + j];
            }
            temp[i * k + j] = val;
        }
    }
    // S' = T * L^{-T}
    let mut s_prime = vec![F::sparse_zero(); k * k];
    for i in 0..k {
        for j in 0..k {
            let mut val = F::sparse_zero();
            for kk in 0..k {
                val += temp[i * k + kk] * l_inv[j * k + kk]; // L^{-T}[kk,j] = L^{-1}[j,kk]
            }
            s_prime[i * k + j] = val;
        }
    }

    // Standard eigenproblem on S' (same eigenvalues as the generalised one)
    let (eigenvalues, z_vecs) = dense_symmetric_eig(&s_prime, k)?;

    // Back-transform: x = L^{-T} z, i.e. x_i = sum_kk L^{-1}[kk][i] * z[kk]
    let mut eigenvectors = vec![F::sparse_zero(); k * k];
    for col in 0..k {
        for i in 0..k {
            let mut val = F::sparse_zero();
            for kk in 0..k {
                val += l_inv[kk * k + i] * z_vecs[col * k + kk];
            }
            eigenvectors[col * k + i] = val;
        }
    }

    Ok((eigenvalues, eigenvectors))
}
391
// ---------------------------------------------------------------------------
// LOBPCG solver
// ---------------------------------------------------------------------------

/// Run the LOBPCG eigensolver for the standard eigenvalue problem `A x = lambda x`.
///
/// Thin wrapper over [`lobpcg_generalised`] with `B = None` (identity).
///
/// # Arguments
///
/// * `a` - Sparse matrix (CSR format); must be square
/// * `config` - Solver configuration
/// * `precond` - Optional preconditioner
/// * `initial_vectors` - Optional initial guesses (n x block_size column-major)
///
/// # Returns
///
/// A `LobpcgResult` containing eigenvalues, eigenvectors, iteration count, and
/// residual norms.
///
/// # Errors
///
/// Propagates the validation errors of [`lobpcg_generalised`]: non-square
/// matrix, invalid block size, or mis-shaped initial vectors.
pub fn lobpcg<F>(
    a: &CsrMatrix<F>,
    config: &LobpcgConfig,
    precond: Option<&dyn Preconditioner<F>>,
    initial_vectors: Option<&Array2<F>>,
) -> SparseResult<LobpcgResult<F>>
where
    F: Float + NumAssign + Sum + SparseElement + Debug + 'static,
{
    lobpcg_generalised(a, None, config, precond, initial_vectors)
}
420
/// Run the LOBPCG eigensolver for the generalised eigenvalue problem
/// `A x = lambda B x`.
///
/// When `b` is `None`, it reduces to the standard problem (`B = I`).
///
/// Each iteration computes residuals `R = A X - lambda B X`, preconditions
/// them (`W = T^{-1} R`), B-orthogonalises `W`, and solves a Rayleigh-Ritz
/// problem on the trial subspace `[X, W, P]` (`P` = previous search
/// directions). Converged eigenpairs can be locked out of the active block.
///
/// # Arguments
///
/// * `a` - Left-hand-side sparse matrix (CSR format); must be square
/// * `b` - Optional right-hand-side SPD sparse matrix (CSR format)
/// * `config` - Solver configuration
/// * `precond` - Optional preconditioner
/// * `initial_vectors` - Optional initial guesses (n x block_size column-major)
///
/// # Errors
///
/// Returns `ValueError` for a non-square `a` or an out-of-range block size,
/// and `ShapeMismatch` when `b` or `initial_vectors` have the wrong shape.
pub fn lobpcg_generalised<F>(
    a: &CsrMatrix<F>,
    b: Option<&CsrMatrix<F>>,
    config: &LobpcgConfig,
    precond: Option<&dyn Preconditioner<F>>,
    initial_vectors: Option<&Array2<F>>,
) -> SparseResult<LobpcgResult<F>>
where
    F: Float + NumAssign + Sum + SparseElement + Debug + 'static,
{
    // NOTE(review): config.verbose is not consulted anywhere in this
    // function — confirm whether progress printing was intended.

    // ---- Validate inputs ----
    let (m, n_cols) = a.shape();
    if m != n_cols {
        return Err(SparseError::ValueError(
            "LOBPCG requires a square matrix".to_string(),
        ));
    }
    let n = m;
    let k = config.block_size;
    if k == 0 || k > n {
        return Err(SparseError::ValueError(format!(
            "block_size must be in [1, {n}], got {k}"
        )));
    }

    if let Some(b_mat) = b {
        let (bm, bn) = b_mat.shape();
        if bm != n || bn != n {
            return Err(SparseError::ShapeMismatch {
                expected: (n, n),
                found: (bm, bn),
            });
        }
    }

    let tol = F::from(config.tol)
        .ok_or_else(|| SparseError::ValueError("Failed to convert tolerance".to_string()))?;

    // ---- Initialise X (column-major: n * k) ----
    let mut x_buf = vec![F::sparse_zero(); n * k];
    if let Some(init) = initial_vectors {
        let (ir, ic) = (init.nrows(), init.ncols());
        if ir != n || ic != k {
            return Err(SparseError::ShapeMismatch {
                expected: (n, k),
                found: (ir, ic),
            });
        }
        for col in 0..k {
            for row in 0..n {
                x_buf[col * n + row] = init[[row, col]];
            }
        }
    } else {
        // Deterministic initial vectors: coordinate vectors + small perturbation
        for j in 0..k {
            if j < n {
                x_buf[j * n + j] = F::sparse_one();
            }
            // Add small deterministic perturbation for robustness
            // (keeps the initial block away from exact invariant subspaces).
            for i in 0..n {
                let val = F::from((i + j + 1) as f64 / (n + k) as f64).unwrap_or(F::sparse_zero());
                x_buf[j * n + i] += val * F::from(0.01).unwrap_or(F::sparse_zero());
            }
        }
    }

    // B-orthogonalise X
    b_orthonormalise(&mut x_buf, b, n, k)?;

    // AX = A * X
    let mut ax_buf = csr_matmul_block(a, &x_buf, n, k)?;
    // BX = B * X  (or X if B is None)
    let mut bx_buf = match b {
        Some(b_mat) => csr_matmul_block(b_mat, &x_buf, n, k)?,
        None => x_buf.clone(),
    };

    // Rayleigh quotients (initial eigenvalue estimates); X is B-orthonormal
    // here, so lambda_j = x_j^T A x_j.
    let mut lambdas = vec![F::sparse_zero(); k];
    for j in 0..k {
        lambdas[j] = dot(&x_buf[j * n..(j + 1) * n], &ax_buf[j * n..(j + 1) * n]);
    }

    // P buffer (conjugate directions) - initialised to zero, indicating first iter
    let mut p_buf = vec![F::sparse_zero(); n * k];
    let mut ap_buf = vec![F::sparse_zero(); n * k];
    let mut bp_buf = vec![F::sparse_zero(); n * k];
    let mut have_p = false;

    // Columns 0..locked_count are frozen (converged) and excluded from the
    // active block below.
    let mut locked_count = 0usize;
    let mut residual_norms = vec![F::sparse_zero(); k];
    let mut converged_flags = vec![false; k];
    let mut iter_count = 0usize;

    for iteration in 0..config.max_iter {
        iter_count = iteration + 1;

        // ---- Compute residuals R = AX - lambda * BX ----
        let active_start = locked_count;
        let active_k = k - active_start;
        if active_k == 0 {
            break;
        }

        let mut r_buf = vec![F::sparse_zero(); n * active_k];
        for j in 0..active_k {
            let gj = j + active_start; // global column index
            let lam = lambdas[gj];
            for i in 0..n {
                r_buf[j * n + i] = ax_buf[gj * n + i] - lam * bx_buf[gj * n + i];
            }
            residual_norms[gj] = norm2(&r_buf[j * n..(j + 1) * n]);
        }

        // ---- Check convergence ----
        let mut all_converged = true;
        for j in 0..active_k {
            let gj = j + active_start;
            if residual_norms[gj] < tol {
                converged_flags[gj] = true;
            } else {
                all_converged = false;
            }
        }

        if all_converged {
            break;
        }

        // ---- Locking ----
        // Only a contiguous prefix of converged columns is locked so the
        // active block stays contiguous in the buffers.
        if config.locking && active_k > 1 {
            let mut newly_locked = 0usize;
            for j in 0..active_k {
                let gj = j + active_start;
                if converged_flags[gj] && gj == locked_count + newly_locked {
                    newly_locked += 1;
                }
            }
            if newly_locked > 0 {
                locked_count += newly_locked;
                have_p = false; // reset conjugate directions after locking
                continue;
            }
        }

        // ---- Apply preconditioner: W = T^{-1} R ----
        let mut w_buf = vec![F::sparse_zero(); n * active_k];
        for j in 0..active_k {
            let r_col = &r_buf[j * n..(j + 1) * n];
            match precond {
                Some(pc) => {
                    let r_arr = Array1::from_vec(r_col.to_vec());
                    let w_arr = pc.apply(&r_arr)?;
                    for i in 0..n {
                        w_buf[j * n + i] = w_arr[i];
                    }
                }
                None => {
                    w_buf[j * n..(j + 1) * n].copy_from_slice(r_col);
                }
            }
        }

        // B-orthogonalise W against locked + X (all k columns of x_buf)
        b_orthogonalise_against(&mut w_buf, &x_buf, b, n, active_k, k)?;
        // Orthonormalise W internally
        for j in 0..active_k {
            // Orthogonalise against previous W columns
            for prev in 0..j {
                let c = b_inner_product(
                    &w_buf[prev * n..(prev + 1) * n],
                    &w_buf[j * n..(j + 1) * n],
                    b,
                    n,
                )?;
                for i in 0..n {
                    w_buf[j * n + i] = w_buf[j * n + i] - c * w_buf[prev * n + i];
                }
            }
            normalise(&mut w_buf[j * n..(j + 1) * n]);
        }

        // AW = A * W
        let aw_buf = csr_matmul_block(a, &w_buf, n, active_k)?;
        // BW = B * W
        let bw_buf = match b {
            Some(b_mat) => csr_matmul_block(b_mat, &w_buf, n, active_k)?,
            None => w_buf.clone(),
        };

        // ---- Build the Rayleigh-Ritz problem on [X_active, W, P] ----
        let subspace_dim = if have_p {
            active_k + active_k + active_k // X_active + W + P
        } else {
            active_k + active_k // X_active + W (first iteration / after locking)
        };

        // Concatenate subspace vectors and A-products
        let mut s_vecs = vec![F::sparse_zero(); n * subspace_dim];
        let mut as_vecs = vec![F::sparse_zero(); n * subspace_dim];
        let mut bs_vecs = vec![F::sparse_zero(); n * subspace_dim];

        // Copy X_active
        for j in 0..active_k {
            let gj = j + active_start;
            s_vecs[j * n..(j + 1) * n].copy_from_slice(&x_buf[gj * n..(gj + 1) * n]);
            as_vecs[j * n..(j + 1) * n].copy_from_slice(&ax_buf[gj * n..(gj + 1) * n]);
            bs_vecs[j * n..(j + 1) * n].copy_from_slice(&bx_buf[gj * n..(gj + 1) * n]);
        }
        // Copy W
        let w_off = active_k;
        for j in 0..active_k {
            s_vecs[(w_off + j) * n..(w_off + j + 1) * n]
                .copy_from_slice(&w_buf[j * n..(j + 1) * n]);
            as_vecs[(w_off + j) * n..(w_off + j + 1) * n]
                .copy_from_slice(&aw_buf[j * n..(j + 1) * n]);
            bs_vecs[(w_off + j) * n..(w_off + j + 1) * n]
                .copy_from_slice(&bw_buf[j * n..(j + 1) * n]);
        }
        // Copy P (if available)
        if have_p {
            let p_off = active_k + active_k;
            for j in 0..active_k {
                let gj = j + active_start;
                s_vecs[(p_off + j) * n..(p_off + j + 1) * n]
                    .copy_from_slice(&p_buf[gj * n..(gj + 1) * n]);
                as_vecs[(p_off + j) * n..(p_off + j + 1) * n]
                    .copy_from_slice(&ap_buf[gj * n..(gj + 1) * n]);
                bs_vecs[(p_off + j) * n..(p_off + j + 1) * n]
                    .copy_from_slice(&bp_buf[gj * n..(gj + 1) * n]);
            }
        }

        // Gram matrices: S_gram = S^T A S,  M_gram = S^T B S
        let s_gram = gram_matrix(&s_vecs, &as_vecs, n, subspace_dim, subspace_dim);
        let m_gram = gram_matrix(&s_vecs, &bs_vecs, n, subspace_dim, subspace_dim);

        // Solve small generalised eigenproblem
        let (small_evals, small_evecs) = dense_generalised_eig(&s_gram, &m_gram, subspace_dim)?;

        // Select eigenvalues based on target. small_evals is ascending, so
        // the first active_k entries are the smallest and the last active_k
        // the largest.
        let selected_indices: Vec<usize> = match config.target {
            EigenTarget::Smallest => (0..active_k).collect(),
            EigenTarget::Largest => {
                let start = subspace_dim.saturating_sub(active_k);
                (start..subspace_dim).collect()
            }
        };

        // ---- Update X, AX, BX using Ritz vectors ----
        // Also compute P = new_X - old_X (the new conjugate directions)
        let old_x_active: Vec<F> = (0..active_k)
            .flat_map(|j| {
                let gj = j + active_start;
                x_buf[gj * n..(gj + 1) * n].to_vec()
            })
            .collect();

        for (sel_idx, &eig_col) in selected_indices.iter().enumerate() {
            let gj = sel_idx + active_start;
            lambdas[gj] = small_evals[eig_col];

            // Compute the new X column: x_new = S * z
            // (and A x_new, B x_new from the cached products — no extra
            // sparse mat-vecs needed here)
            let z_col = &small_evecs[eig_col * subspace_dim..(eig_col + 1) * subspace_dim];
            for i in 0..n {
                let mut xval = F::sparse_zero();
                let mut axval = F::sparse_zero();
                let mut bxval = F::sparse_zero();
                for (s_idx, &zc) in z_col.iter().enumerate() {
                    xval += zc * s_vecs[s_idx * n + i];
                    axval += zc * as_vecs[s_idx * n + i];
                    bxval += zc * bs_vecs[s_idx * n + i];
                }
                // P = x_new - x_old
                p_buf[gj * n + i] = xval - old_x_active[sel_idx * n + i];
                x_buf[gj * n + i] = xval;
                ax_buf[gj * n + i] = axval;
                bx_buf[gj * n + i] = bxval;
            }

            // AP = A * P,  BP = B * P
            let p_col = &p_buf[gj * n..(gj + 1) * n];
            let ap_col = csr_matvec(a, p_col)?;
            for i in 0..n {
                ap_buf[gj * n + i] = ap_col[i];
            }
            match b {
                Some(b_mat) => {
                    let bp_col = csr_matvec(b_mat, p_col)?;
                    for i in 0..n {
                        bp_buf[gj * n + i] = bp_col[i];
                    }
                }
                None => {
                    bp_buf[gj * n..(gj + 1) * n].copy_from_slice(p_col);
                }
            }
        }

        have_p = true;
    }

    // ---- Assemble final result ----
    let n_converged = converged_flags.iter().filter(|&&f| f).count();
    let all_converged = n_converged == k;

    let mut eigenvalues = Array1::zeros(k);
    let mut eigenvectors = Array2::zeros((n, k));
    for j in 0..k {
        eigenvalues[j] = lambdas[j];
        for i in 0..n {
            eigenvectors[[i, j]] = x_buf[j * n + i];
        }
    }

    Ok(LobpcgResult {
        eigenvalues,
        eigenvectors,
        iterations: iter_count,
        residual_norms,
        converged: all_converged,
        n_converged,
    })
}
757
// ---------------------------------------------------------------------------
// B-orthogonalisation helpers
// ---------------------------------------------------------------------------

762/// Compute <u, v>_B = u^T B v  (or u^T v if B is None).
763fn b_inner_product<F>(u: &[F], v: &[F], b: Option<&CsrMatrix<F>>, n: usize) -> SparseResult<F>
764where
765    F: Float + NumAssign + Sum + SparseElement + 'static,
766{
767    match b {
768        Some(b_mat) => {
769            let bv = csr_matvec(b_mat, v)?;
770            Ok(dot(u, &bv))
771        }
772        None => Ok(dot(u, v)),
773    }
774}
775
776/// B-orthonormalise columns of `mat` (column-major, n x k) using modified Gram-Schmidt.
777fn b_orthonormalise<F>(
778    mat: &mut [F],
779    b: Option<&CsrMatrix<F>>,
780    n: usize,
781    k: usize,
782) -> SparseResult<()>
783where
784    F: Float + NumAssign + Sum + SparseElement + 'static,
785{
786    for j in 0..k {
787        // Orthogonalise against all previous columns
788        for prev in 0..j {
789            let c = b_inner_product(
790                &mat[prev * n..(prev + 1) * n],
791                &mat[j * n..(j + 1) * n],
792                b,
793                n,
794            )?;
795            for i in 0..n {
796                mat[j * n + i] -= c * mat[prev * n + i];
797            }
798        }
799        // B-normalise
800        let bnorm = b_inner_product(&mat[j * n..(j + 1) * n], &mat[j * n..(j + 1) * n], b, n)?;
801        if bnorm > F::epsilon() {
802            let inv = F::sparse_one() / bnorm.sqrt();
803            for i in 0..n {
804                mat[j * n + i] *= inv;
805            }
806        }
807    }
808    Ok(())
809}
810
811/// Orthogonalise columns of `w` (n x wk) against columns of `q` (n x qk) w.r.t. B.
812fn b_orthogonalise_against<F>(
813    w: &mut [F],
814    q: &[F],
815    b: Option<&CsrMatrix<F>>,
816    n: usize,
817    wk: usize,
818    qk: usize,
819) -> SparseResult<()>
820where
821    F: Float + NumAssign + Sum + SparseElement + 'static,
822{
823    for j in 0..wk {
824        for qi in 0..qk {
825            let c = b_inner_product(&q[qi * n..(qi + 1) * n], &w[j * n..(j + 1) * n], b, n)?;
826            for i in 0..n {
827                w[j * n + i] -= c * q[qi * n + i];
828            }
829        }
830    }
831    Ok(())
832}
833
// ---------------------------------------------------------------------------
// Tests
// ---------------------------------------------------------------------------

838#[cfg(test)]
839mod tests {
840    use super::*;
841
842    /// Build a small SPD tridiagonal matrix: 2 on diagonal, -1 off-diagonal.
843    fn build_tridiag_spd(n: usize) -> CsrMatrix<f64> {
844        let mut rows = Vec::new();
845        let mut cols = Vec::new();
846        let mut data = Vec::new();
847        for i in 0..n {
848            if i > 0 {
849                rows.push(i);
850                cols.push(i - 1);
851                data.push(-1.0);
852            }
853            rows.push(i);
854            cols.push(i);
855            data.push(2.0);
856            if i + 1 < n {
857                rows.push(i);
858                cols.push(i + 1);
859                data.push(-1.0);
860            }
861        }
862        CsrMatrix::new(data, rows, cols, (n, n)).expect("valid matrix")
863    }
864
865    /// Build a diagonal SPD matrix.
866    fn build_diag_matrix(diag: &[f64]) -> CsrMatrix<f64> {
867        let n = diag.len();
868        let rows: Vec<usize> = (0..n).collect();
869        let cols: Vec<usize> = (0..n).collect();
870        CsrMatrix::new(diag.to_vec(), rows, cols, (n, n)).expect("valid matrix")
871    }
872
    #[test]
    fn test_lobpcg_smallest_eigenvalue_tridiag() {
        // Smallest eigenpair of the 1D discrete Laplacian (tridiag [-1, 2, -1]),
        // which has a known closed-form spectrum.
        let n = 20;
        let a = build_tridiag_spd(n);
        let config = LobpcgConfig {
            block_size: 2,
            max_iter: 200,
            tol: 1e-6,
            target: EigenTarget::Smallest,
            locking: true,
            verbose: false,
        };
        let result = lobpcg(&a, &config, None, None).expect("lobpcg should succeed");
        // The smallest eigenvalue of the 1D Laplacian is ~ 4 sin^2(pi/(2(n+1)))
        let lambda_min_exact = 4.0
            * (std::f64::consts::PI / (2.0 * (n as f64 + 1.0)))
                .sin()
                .powi(2);
        let computed = result.eigenvalues[0];
        assert!(
            (computed - lambda_min_exact).abs() < 1e-4,
            "Expected smallest eigenvalue ~{lambda_min_exact}, got {computed}"
        );
    }
897
898    #[test]
899    fn test_lobpcg_largest_eigenvalue_tridiag() {
900        let n = 20;
901        let a = build_tridiag_spd(n);
902        let config = LobpcgConfig {
903            block_size: 1,
904            max_iter: 200,
905            tol: 1e-6,
906            target: EigenTarget::Largest,
907            locking: false,
908            verbose: false,
909        };
910        let result = lobpcg(&a, &config, None, None).expect("lobpcg should succeed");
911        // The largest eigenvalue of the 1D Laplacian is ~ 4 cos^2(pi/(2(n+1)))
912        let lambda_max_exact = 4.0
913            * (std::f64::consts::PI * n as f64 / (2.0 * (n as f64 + 1.0)))
914                .sin()
915                .powi(2);
916        let computed = result.eigenvalues[0];
917        assert!(
918            (computed - lambda_max_exact).abs() < 1e-3,
919            "Expected largest eigenvalue ~{lambda_max_exact}, got {computed}"
920        );
921    }
922
923    #[test]
924    fn test_lobpcg_diagonal_matrix() {
925        let diag = vec![1.0, 3.0, 5.0, 7.0, 9.0];
926        let a = build_diag_matrix(&diag);
927        let config = LobpcgConfig {
928            block_size: 2,
929            max_iter: 100,
930            tol: 1e-10,
931            target: EigenTarget::Smallest,
932            locking: true,
933            verbose: false,
934        };
935        let result = lobpcg(&a, &config, None, None).expect("lobpcg should succeed");
936        // Smallest two eigenvalues should be 1.0 and 3.0
937        let mut eigs: Vec<f64> = result.eigenvalues.to_vec();
938        eigs.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
939        assert!(
940            (eigs[0] - 1.0).abs() < 1e-6,
941            "Expected 1.0, got {}",
942            eigs[0]
943        );
944        assert!(
945            (eigs[1] - 3.0).abs() < 1e-6,
946            "Expected 3.0, got {}",
947            eigs[1]
948        );
949    }
950
951    #[test]
952    fn test_lobpcg_generalised_with_identity_b() {
953        let n = 10;
954        let a = build_tridiag_spd(n);
955        let diag_ones = vec![1.0; n];
956        let b = build_diag_matrix(&diag_ones);
957        let config = LobpcgConfig {
958            block_size: 1,
959            max_iter: 200,
960            tol: 1e-6,
961            target: EigenTarget::Smallest,
962            ..Default::default()
963        };
964        let result =
965            lobpcg_generalised(&a, Some(&b), &config, None, None).expect("generalised lobpcg");
966        let lambda_min_exact = 4.0
967            * (std::f64::consts::PI / (2.0 * (n as f64 + 1.0)))
968                .sin()
969                .powi(2);
970        assert!(
971            (result.eigenvalues[0] - lambda_min_exact).abs() < 1e-4,
972            "Expected ~{lambda_min_exact}, got {}",
973            result.eigenvalues[0]
974        );
975    }
976
977    #[test]
978    fn test_lobpcg_generalised_nontrivial_b() {
979        // A = diag(2,4,6), B = diag(1,2,3)
980        // Generalized eigenvalues: 2/1=2, 4/2=2, 6/3=2 => all eigenvalues = 2
981        let a = build_diag_matrix(&[2.0, 4.0, 6.0]);
982        let b = build_diag_matrix(&[1.0, 2.0, 3.0]);
983        let config = LobpcgConfig {
984            block_size: 1,
985            max_iter: 100,
986            tol: 1e-8,
987            target: EigenTarget::Smallest,
988            ..Default::default()
989        };
990        let result =
991            lobpcg_generalised(&a, Some(&b), &config, None, None).expect("generalised lobpcg");
992        assert!(
993            (result.eigenvalues[0] - 2.0).abs() < 1e-4,
994            "Expected 2.0, got {}",
995            result.eigenvalues[0]
996        );
997    }
998
999    #[test]
1000    fn test_lobpcg_with_preconditioner() {
1001        let n = 15;
1002        let a = build_tridiag_spd(n);
1003        let precond = JacobiPreconditioner::new(&a).expect("Jacobi precond");
1004        let config = LobpcgConfig {
1005            block_size: 1,
1006            max_iter: 200,
1007            tol: 1e-8,
1008            target: EigenTarget::Smallest,
1009            ..Default::default()
1010        };
1011        let result = lobpcg(&a, &config, Some(&precond), None).expect("lobpcg with precond");
1012        let lambda_min = 4.0
1013            * (std::f64::consts::PI / (2.0 * (n as f64 + 1.0)))
1014                .sin()
1015                .powi(2);
1016        assert!(
1017            (result.eigenvalues[0] - lambda_min).abs() < 1e-4,
1018            "Expected ~{lambda_min}, got {}",
1019            result.eigenvalues[0]
1020        );
1021    }
1022
1023    #[test]
1024    fn test_lobpcg_with_initial_vectors() {
1025        let n = 10;
1026        let a = build_tridiag_spd(n);
1027        // Create a reasonable initial guess
1028        let mut init = Array2::zeros((n, 1));
1029        for i in 0..n {
1030            init[[i, 0]] = ((i + 1) as f64 * std::f64::consts::PI / (n as f64 + 1.0)).sin();
1031        }
1032        let config = LobpcgConfig {
1033            block_size: 1,
1034            max_iter: 100,
1035            tol: 1e-8,
1036            target: EigenTarget::Smallest,
1037            ..Default::default()
1038        };
1039        let result = lobpcg(&a, &config, None, Some(&init)).expect("lobpcg with initial vectors");
1040        let lambda_min = 4.0
1041            * (std::f64::consts::PI / (2.0 * (n as f64 + 1.0)))
1042                .sin()
1043                .powi(2);
1044        assert!(
1045            (result.eigenvalues[0] - lambda_min).abs() < 1e-4,
1046            "Expected ~{lambda_min}, got {}",
1047            result.eigenvalues[0]
1048        );
1049    }
1050
1051    #[test]
1052    fn test_lobpcg_convergence_info() {
1053        let a = build_diag_matrix(&[1.0, 2.0, 3.0, 4.0, 5.0]);
1054        let config = LobpcgConfig {
1055            block_size: 1,
1056            max_iter: 100,
1057            tol: 1e-10,
1058            target: EigenTarget::Smallest,
1059            ..Default::default()
1060        };
1061        let result = lobpcg(&a, &config, None, None).expect("lobpcg");
1062        assert!(
1063            result.converged,
1064            "should converge on simple diagonal matrix"
1065        );
1066        assert!(result.iterations > 0);
1067        assert_eq!(result.n_converged, 1);
1068    }
1069
1070    #[test]
1071    fn test_lobpcg_error_on_non_square() {
1072        let rows = vec![0, 1];
1073        let cols = vec![0, 1];
1074        let data = vec![1.0, 2.0];
1075        let a = CsrMatrix::new(data, rows, cols, (2, 3)).expect("valid rect matrix");
1076        let config = LobpcgConfig::default();
1077        let result = lobpcg(&a, &config, None, None);
1078        assert!(result.is_err());
1079    }
1080
1081    #[test]
1082    fn test_lobpcg_multiple_eigenvalues() {
1083        let n = 30;
1084        let a = build_tridiag_spd(n);
1085        let config = LobpcgConfig {
1086            block_size: 3,
1087            max_iter: 300,
1088            tol: 1e-5,
1089            target: EigenTarget::Smallest,
1090            locking: true,
1091            verbose: false,
1092        };
1093        let result = lobpcg(&a, &config, None, None).expect("lobpcg");
1094        // First 3 eigenvalues
1095        for j in 0..3 {
1096            let exact = 4.0
1097                * (std::f64::consts::PI * (j + 1) as f64 / (2.0 * (n as f64 + 1.0)))
1098                    .sin()
1099                    .powi(2);
1100            let computed = result.eigenvalues[j];
1101            assert!(
1102                (computed - exact).abs() < 0.05,
1103                "Eigenvalue {j}: expected ~{exact}, got {computed}"
1104            );
1105        }
1106    }
1107
1108    #[test]
1109    fn test_dense_symmetric_eig_basic() {
1110        // 2x2 symmetric: [[2,1],[1,2]] => eigenvalues 1 and 3
1111        let a = vec![2.0, 1.0, 1.0, 2.0];
1112        let (vals, _vecs) = dense_symmetric_eig(&a, 2).expect("dense eig");
1113        assert!((vals[0] - 1.0).abs() < 1e-10);
1114        assert!((vals[1] - 3.0).abs() < 1e-10);
1115    }
1116
1117    #[test]
1118    fn test_dense_generalised_eig_identity_b() {
1119        let a = vec![3.0, 1.0, 1.0, 3.0];
1120        let b = vec![1.0, 0.0, 0.0, 1.0];
1121        let (vals, _) = dense_generalised_eig(&a, &b, 2).expect("dense gen eig");
1122        assert!((vals[0] - 2.0).abs() < 1e-10);
1123        assert!((vals[1] - 4.0).abs() < 1e-10);
1124    }
1125
1126    // Helper: Jacobi preconditioner for tests
1127    use crate::iterative_solvers::JacobiPreconditioner;
1128}