Skip to main content

ariadnetor_algorithms/dmrg/
sweep.rs

1//! 2-site DMRG sweep driver.
2//!
3//! Runs alternating L→R / R→L half-sweeps over a [`super::DmrgOps`]
4//! chain, dispatching the storage-specific local solve and
5//! S-absorb through the trait. Mutates the MPS site tensors and the
6//! [`BraketEnvs`] in place. Caller supplies a right-canonical (or
7//! `Mixed { center: 0 }`) MPS plus a freshly-built `BraketEnvs`; the
8//! driver does **not** auto-canonicalize because doing so would
9//! silently invalidate the caller-supplied envs.
10//!
11//! # Canonical-form precondition
12//!
//! The local effective-Hamiltonian eigenvalue equation returns
//! physical energy directly only when, with active block `(i, i+1)`,
//! sites `0..i` (end-exclusive, i.e. every site left of the block) are
//! left-canonical and sites `i+2..=N-1` (every site right of the block)
//! are right-canonical. The driver starts L→R from `i = 0`, so the
//! binding precondition is right-canonicality of sites `2..=N-1`, met
//! exactly by `Right` and `Mixed { center: 0 }`. This argument is
//! storage-independent (the local-block requirement does not depend
//! on Dense vs BlockSparse representation), so the gate applies
//! uniformly through the trait.
22//!
23//! # Convergence
24//!
25//! After each full L→R + R→L cycle we record the **normalized**
26//! post-truncation expectation `<psi|H|psi>.re() / <psi|psi>`.
27//! The truncated SVD keeps unrenormalized singular values, so without
28//! the `<psi|psi>` divisor the sweep energy drifts toward zero
29//! whenever truncation happens — the divisor strips that
30//! norm-artifact away. Convergence requires energy delta within
31//! `energy_tol`, every step's local eigensolver converged, and
32//! `n_sweeps >= min_sweeps`.
33
34use ariadnetor_core::Scalar;
35use ariadnetor_linalg::LinalgError;
36use ariadnetor_mps::{CanonicalForm, Mpo, Mps, MpsOps, TensorChain, braket};
37use ariadnetor_tensor::{Host, OpsFor, Storage, StorageFor, TensorLayout};
38
39use crate::numeric::try_real_from_f64;
40
41use super::dispatch::{DmrgOps, FullStepError};
42use super::heff_error::DmrgHeffError;
43use super::solver::{LocalEigensolverParams, eigensolver_tol, validate_eigensolver_params};
44use ariadnetor_mps::{BraketEnvError, BraketEnvOps, BraketEnvs};
45
/// Direction of a half-sweep.
///
/// Stored on every [`DmrgStepRecord`] and carried as a breadcrumb by the
/// `Step`, `Env`, and `Scale` variants of [`DmrgSweepError`]. One sweep
/// cycle of [`sweep_2site`] runs a `LeftToRight` half-sweep followed by a
/// `RightToLeft` half-sweep.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SweepDirection {
    /// Steps from `site = 0` to `site = n_sites - 2`. Each step's
    /// SVD splits the optimized 2-site block into a left-isometric
    /// site `i` and an `S·Vt`-carrying site `i + 1`.
    LeftToRight,
    /// Steps from `site = n_sites - 2` down to `site = 0`. Each
    /// step's SVD splits into a `U·S`-carrying site `i` and a
    /// right-isometric site `i + 1`.
    RightToLeft,
}
58
/// Caller-supplied parameters for [`sweep_2site`] (the chain-generic
/// 2-site DMRG sweep driver, dispatched over `Mps<St, L>: super::DmrgOps<T>`
/// so the same params type covers both the Dense and BlockSparse /
/// U(1) paths).
///
/// Stored as plain `f64` for `energy_tol`; the entry point converts
/// it to `T::Real` via the same `NumCast::from(f64)` idiom as
/// [`crate::krylov::LanczosParams::tol`]. This keeps the params
/// type non-generic across the scalar type.
///
/// Every constraint listed on the fields below is checked at the top of
/// [`sweep_2site`], before any sweep work starts; a violation is reported
/// as [`DmrgSweepError::InvalidParams`].
#[derive(Debug, Clone)]
pub struct DmrgSweepParams {
    /// Maximum number of full L→R + R→L cycles. Must be `>= 1`;
    /// `0` is rejected with [`DmrgSweepError::InvalidParams`].
    pub max_sweeps: usize,
    /// Minimum number of sweeps before the energy-delta convergence
    /// test is honored. Pre-`min_sweeps` cycles always continue
    /// regardless of energy delta. Must be `<= max_sweeps`.
    ///
    /// The delta test also needs the previous cycle's energy, so the
    /// driver never reports convergence after the very first cycle even
    /// when `min_sweeps` is `0` or `1`.
    pub min_sweeps: usize,
    /// Energy-delta tolerance. After cycle `n >= min_sweeps`,
    /// convergence requires `|E_n - E_{n-1}| <= energy_tol`. Must
    /// be finite and non-negative, and must be representable in the
    /// chain's real scalar type (`T::Real`); otherwise the driver
    /// returns [`DmrgSweepError::InvalidParams`].
    pub energy_tol: f64,
    /// Local-eigensolver selector + per-variant parameters, forwarded
    /// to the per-step driver (Dense `dmrg_2site_step` or BlockSparse
    /// `dmrg_2site_step_block_sparse`). The selected solver's `tol` is
    /// subject to the same `T::Real` representability check as
    /// `energy_tol`.
    pub eigensolver: LocalEigensolverParams,
    /// Truncated-SVD parameters, forwarded to the per-step driver
    /// (Dense `dmrg_2site_step` or BlockSparse
    /// `dmrg_2site_step_block_sparse`). Validation rejects
    /// `chi_max == Some(0)` and a `target_trunc_err` that is non-finite
    /// or negative.
    pub trunc: ariadnetor_linalg::TruncSvdParams,
}
90
/// Per-step diagnostics record.
///
/// One record is produced for each 2-site local solve. `R` is the real
/// scalar type of the chain (`T::Real` in [`sweep_2site`]).
#[derive(Debug, Clone)]
pub struct DmrgStepRecord<R> {
    /// Zero-based index of the sweep cycle this step belongs to.
    pub sweep: usize,
    /// Direction of the half-sweep (`L→R` or `R→L`).
    pub direction: SweepDirection,
    /// Left site of the optimized two-site block, i.e. the block is
    /// `(site, site + 1)`.
    pub site: usize,
    /// Smallest eigenvalue of `H_eff` at this step (pre-truncation
    /// local-block variational minimum). May lie below the
    /// post-truncation sweep energy.
    pub eigenvalue: R,
    /// Local-eigensolver true residual `‖H v − λ v‖₂`.
    pub residual: R,
    /// Frobenius norm of singular values discarded by this step's
    /// truncated SVD.
    pub trunc_err: R,
    /// New bond dimension between `site` and `site + 1` after the
    /// truncated split.
    pub bond_dim: usize,
    /// Number of iterations the local eigensolver ran for this step.
    /// For Lanczos this is the inner loop count; for ARPACK it is
    /// the restart-iteration count returned in `iparam[2]`.
    pub eigensolver_iters: usize,
    /// `true` iff the local eigensolver succeeded — Lanczos by its
    /// absolute true-residual test against `LanczosParams::tol`,
    /// ARPACK by its relative-tol stopping criterion (i.e. `Ok`
    /// return from `arpack_smallest`). The two arms intentionally
    /// disagree on what they call "converged": Lanczos uses the
    /// absolute residual; ARPACK uses `residual <= tol * |lambda|`.
    /// See [`super::heff::TwoSiteStepResult::converged`] for the
    /// upstream contract this field forwards from.
    ///
    /// A `false` here does not abort the sweep; it only prevents the
    /// enclosing cycle from counting as converged.
    pub eigensolver_converged: bool,
}
126
/// Per-sweep diagnostics record (one full L→R + R→L cycle).
///
/// `R` is the real scalar type of the chain (`T::Real` in [`sweep_2site`]).
#[derive(Debug, Clone)]
pub struct DmrgSweepRecord<R> {
    /// Zero-based index of this sweep cycle.
    pub sweep: usize,
    /// Normalized post-truncation `<psi|H|psi> / <psi|psi>` after
    /// this cycle. The convergence metric.
    pub sweep_energy: R,
    /// `min(step.eigenvalue)` across this cycle. Diagnostic only —
    /// reflects local-block variational minima, which can be lower
    /// than `sweep_energy`.
    pub min_step_eigenvalue: R,
    /// Largest per-step truncation error in this cycle.
    pub max_trunc_err: R,
    /// Maximum bond dimension of the MPS as reported by
    /// `max_bond_dim()` once the cycle has finished. This is a snapshot
    /// of the final state, not a running maximum over the per-step
    /// `bond_dim` values in `steps`.
    pub max_bond: usize,
    /// `true` iff every step in this cycle's local-eigensolver pass
    /// converged.
    pub all_eigensolver_converged: bool,
    /// Per-step diagnostic records for this cycle, in execution order:
    /// the L→R half-sweep steps first, then the R→L half-sweep steps.
    pub steps: Vec<DmrgStepRecord<R>>,
}
149
/// Final result of the 2-site DMRG sweep driver [`sweep_2site`].
///
/// Only produced on success; when the driver returns `Err`, no partial
/// history is handed back.
#[derive(Debug, Clone)]
pub struct DmrgResult<R> {
    /// Last cycle's `sweep_energy`.
    pub energy: R,
    /// `true` iff the final cycle satisfied:
    /// `n_sweeps >= min_sweeps`,
    /// `|delta_E| <= energy_tol`,
    /// and every step's local eigensolver converged.
    ///
    /// `delta_E` compares against the previous cycle, so this is always
    /// `false` when only one cycle ran. `false` also covers the case
    /// where the loop stopped because `max_sweeps` was exhausted.
    pub converged: bool,
    /// Number of sweep cycles executed. Equals `sweeps.len()`.
    pub n_sweeps: usize,
    /// Per-cycle diagnostic records, in execution order.
    pub sweeps: Vec<DmrgSweepRecord<R>>,
}
165
/// Errors raised by the 2-site DMRG sweep driver [`sweep_2site`].
///
/// `LengthMismatch`, `TooFewSites`, `InvalidParams`, and
/// `MpsNotRightCanonical` come from entry-point validation, before any
/// tensor is modified. `Step`, `Env`, and `Scale` are raised from inside
/// the sweep loop and carry `(sweep, direction, site)` breadcrumbs.
///
/// The enum is `#[non_exhaustive]`, so matches outside this crate need a
/// wildcard arm.
#[derive(Debug, thiserror::Error)]
#[non_exhaustive]
pub enum DmrgSweepError {
    /// MPS, MPO, and `BraketEnvs` disagree on `n_sites`.
    #[error("chain length mismatch: mps = {mps}, mpo = {mpo}, envs = {envs}")]
    LengthMismatch {
        /// `n_sites` reported by the MPS.
        mps: usize,
        /// `n_sites` reported by the MPO.
        mpo: usize,
        /// `n_sites` reported by the environments.
        envs: usize,
    },
    /// `n_sites < 2`. 2-site sweeps require at least 2 sites.
    #[error("2-site sweep requires n_sites >= 2, got {n_sites}")]
    TooFewSites {
        /// The (too-small) site count supplied.
        n_sites: usize,
    },
    /// `DmrgSweepParams` failed entry-point validation. `detail`
    /// names the constraint that fired.
    #[error("invalid DmrgSweepParams: {detail}")]
    InvalidParams {
        /// Names the validation constraint that failed.
        detail: &'static str,
    },
    /// MPS canonical form was not `Right` or `Mixed { center: 0 }`.
    /// `Unknown` is also rejected — see the module-level docs for
    /// the rationale.
    #[error("MPS must be in Right or Mixed {{ center: 0 }} form before sweep, got {found:?}")]
    MpsNotRightCanonical {
        /// The canonical form actually found.
        found: CanonicalForm,
    },
    /// The per-step driver (`dmrg_2site_step` or
    /// `dmrg_2site_step_block_sparse`) returned an error. Source
    /// preserved.
    #[error("2-site DMRG step failed at sweep {sweep}, {direction:?}, site {site}")]
    Step {
        /// Zero-based sweep cycle where the failure occurred.
        sweep: usize,
        /// Half-sweep direction at the failure.
        direction: SweepDirection,
        /// Left site index at the failure.
        site: usize,
        /// The underlying per-step error.
        #[source]
        source: DmrgHeffError,
    },
    /// `BraketEnvs::advance_left/right` returned an error during a
    /// post-step env update. Surfaced separately from `Step` so
    /// the caller can distinguish "local solve failed" from "env
    /// state became inconsistent". Defense-in-depth — under the
    /// driver's own advance ordering, this branch should never
    /// fire from the public API.
    #[error("BraketEnvs advance failed at sweep {sweep}, {direction:?}, site {site}")]
    Env {
        /// Zero-based sweep cycle where the failure occurred.
        sweep: usize,
        /// Half-sweep direction at the failure.
        direction: SweepDirection,
        /// Left site index of the step whose follow-up advance failed.
        site: usize,
        /// The underlying environment-advance error.
        #[source]
        source: BraketEnvError,
    },
    /// The post-step S-absorb (`ariadnetor_linalg::diagonal_scale`, which
    /// dispatches over layout for both Dense and BlockSparse) failed.
    /// Carries the same `(sweep, direction, site)` breadcrumbs as
    /// `Step` / `Env`. An `Err` return from the driver carries no
    /// `DmrgResult`, so these fields are the only record of where the
    /// failure occurred.
    #[error("S-absorb (diagonal scale) failed during sweep {sweep}, {direction:?}, site {site}")]
    Scale {
        /// Zero-based sweep cycle where the failure occurred.
        sweep: usize,
        /// Half-sweep direction at the failure.
        direction: SweepDirection,
        /// Left site index at the failure.
        site: usize,
        /// The underlying diagonal-scale (linalg) error.
        #[source]
        source: LinalgError,
    },
}
253
254/// Run alternating L→R / R→L sweeps until convergence or
255/// `max_sweeps` over a [`DmrgOps`] chain. Mutates `mps` and
256/// `envs` in place; the final MPS state is
257/// `CanonicalForm::Mixed { center: 0 }` (R→L runs last).
258///
259/// Generic over the `Mps<St, L>` chain (`Mps<St, L>: DmrgOps<T>`), so a
260/// single call site covers both the Dense and BlockSparse / U(1) paths.
261/// The trait dispatches the local solve and S-absorb to the
262/// storage-specific implementations.
263///
264/// The driver is host-pinned in this stage: the local solve, the
265/// S-absorb, and the post-sweep `braket` / `norm` all route their
266/// backend-dependent work through the [`Host`] substrate
267/// (`Host::shared()`), so callers supply host-resident MPS / MPO / env
268/// state. Generic non-host-backend DMRG is a separate, later track.
269///
270/// See the module-level rustdoc for the canonical-form contract on
271/// the input MPS and for the convergence criterion.
272pub fn sweep_2site<T, St, L>(
273    envs: &mut BraketEnvs<St, L>,
274    mps: &mut Mps<St, L>,
275    mpo: &Mpo<St, L>,
276    params: &DmrgSweepParams,
277) -> Result<DmrgResult<T::Real>, DmrgSweepError>
278where
279    T: Scalar,
280    T::Real: Scalar<Real = T::Real>,
281    St: Storage + StorageFor<L>,
282    L: TensorLayout,
283    Mps<St, L>: DmrgOps<T> + MpsOps<T, Storage = St, Layout = L>,
284    BraketEnvs<St, L>: BraketEnvOps<T, Storage = St, Layout = L>,
285    // Host-pinned: the host backend supplies every kernel, so it must declare
286    // capability for this chain's storage (satisfied by Dense / BlockSparse).
287    Host: OpsFor<St>,
288{
289    // ---- Length / size validation -------------------------------
290    let n_sites = envs.n_sites();
291    if mps.len() != n_sites || mpo.len() != n_sites {
292        return Err(DmrgSweepError::LengthMismatch {
293            mps: mps.len(),
294            mpo: mpo.len(),
295            envs: n_sites,
296        });
297    }
298    if n_sites < 2 {
299        return Err(DmrgSweepError::TooFewSites { n_sites });
300    }
301
302    // ---- Param validation ---------------------------------------
303    validate_params(params)?;
304    // Casts may fail when the real scalar type (`T::Real`) is `f32`
305    // and the user supplied a
306    // finite value outside f32 range (NumCast::from returns Some(inf),
307    // which try_real_from_f64 then maps to None). Surface that as
308    // `InvalidParams` so the public API stays fallible end-to-end
309    // instead of failing inside the local eigensolver (Lanczos's
310    // internal `try_real_from_f64` would panic; ARPACK's `tol_real`
311    // cast would also panic). The selected eigensolver's `tol` is
312    // gated here too so a borderline f32 tol does not slip past
313    // sweep-level validation only to abort the run from inside the
314    // local solve.
315    let energy_tol_real: T::Real =
316        try_real_from_f64::<T>(params.energy_tol).ok_or(DmrgSweepError::InvalidParams {
317            detail: "energy_tol is not representable in the storage's real scalar type",
318        })?;
319    if try_real_from_f64::<T>(eigensolver_tol(&params.eigensolver)).is_none() {
320        return Err(DmrgSweepError::InvalidParams {
321            detail: "eigensolver tol is not representable in the storage's real scalar type",
322        });
323    }
324
325    // ---- Canonical-form contract --------------------------------
326    match mps.canonical_form() {
327        CanonicalForm::Right => {}
328        CanonicalForm::Mixed { center: 0 } => {}
329        other => {
330            return Err(DmrgSweepError::MpsNotRightCanonical {
331                found: other.clone(),
332            });
333        }
334    }
335
336    // DMRG is host-pinned in the CPU-only Stage B scope; the whole sweep
337    // boundary runs on the host substrate.
338    let backend = Host::shared();
339    let mut sweeps: Vec<DmrgSweepRecord<T::Real>> = Vec::with_capacity(params.max_sweeps);
340    let mut last_energy: Option<T::Real> = None;
341    let mut converged = false;
342    let mut completed_sweeps = 0usize;
343
344    // ---- Main sweep loop ----------------------------------------
345    for sweep_idx in 0..params.max_sweeps {
346        let mut steps: Vec<DmrgStepRecord<T::Real>> = Vec::with_capacity(2 * (n_sites - 1));
347
348        // L→R half-sweep.
349        for site in 0..n_sites - 1 {
350            let record = run_step(
351                envs,
352                mps,
353                mpo,
354                site,
355                params,
356                sweep_idx,
357                SweepDirection::LeftToRight,
358            )?;
359            steps.push(record);
360
361            // Skip the trailing advance: site = n_sites - 2 is
362            // the last L→R step; the next R→L step at the same
363            // `site` consumes `left(site)` which is still valid.
364            if site < n_sites - 2 {
365                envs.advance_left::<T>(mps, mpo, mps, site)
366                    .map_err(|source| DmrgSweepError::Env {
367                        sweep: sweep_idx,
368                        direction: SweepDirection::LeftToRight,
369                        site,
370                        source,
371                    })?;
372            }
373        }
374
375        // R→L half-sweep.
376        for site in (0..n_sites - 1).rev() {
377            let record = run_step(
378                envs,
379                mps,
380                mpo,
381                site,
382                params,
383                sweep_idx,
384                SweepDirection::RightToLeft,
385            )?;
386            steps.push(record);
387
388            // Always advance, including at the trailing `site == 0`
389            // boundary. Skipping the boundary advance would leave
390            // `right[1]` stale-but-`Some` (it would still hold the
391            // pre-sweep `BraketEnvs::build` value, computed against
392            // the original MPS site 1) and `left[1]` stale-but-
393            // `Some` (it would still hold the L→R-time value, even
394            // though R→L has further mutated MPS site 0). Both are
395            // contract violations against
396            // `BraketEnvs`'s "stale = None" convention even though
397            // they do not affect the next sweep iteration's
398            // numerics, which overwrites `left[1]` via
399            // `advance_left(0)` before consumption.
400            envs.advance_right::<T>(mps, mpo, mps, site + 1)
401                .map_err(|source| DmrgSweepError::Env {
402                    sweep: sweep_idx,
403                    direction: SweepDirection::RightToLeft,
404                    site,
405                    source,
406                })?;
407        }
408
409        // R→L ends with the orthogonality center at site 0
410        // (S absorbed leftward at every step). storage_mut reset
411        // the form to Unknown along the way; re-pin it here so a
412        // caller breaking out mid-loop sees a coherent state.
413        mps.set_canonical_form(CanonicalForm::Mixed { center: 0 });
414
415        // ---- Post-sweep diagnostics -----------------------------
416        let bra_h_ket: T = braket(backend.as_ref(), mps, mpo, mps);
417        let nrm: T::Real = mps.norm(backend.as_ref());
418        let nrm_sq: T::Real = nrm * nrm;
419        // The Float bound on the storage's real scalar type guarantees division.
420        let sweep_energy: T::Real = bra_h_ket.re() / nrm_sq;
421
422        let max_bond = mps.max_bond_dim();
423        let mut min_eig = steps[0].eigenvalue;
424        let mut max_te = steps[0].trunc_err;
425        let mut all_ok = true;
426        for s in &steps {
427            if s.eigenvalue < min_eig {
428                min_eig = s.eigenvalue;
429            }
430            if s.trunc_err > max_te {
431                max_te = s.trunc_err;
432            }
433            if !s.eigensolver_converged {
434                all_ok = false;
435            }
436        }
437
438        sweeps.push(DmrgSweepRecord {
439            sweep: sweep_idx,
440            sweep_energy,
441            min_step_eigenvalue: min_eig,
442            max_trunc_err: max_te,
443            max_bond,
444            all_eigensolver_converged: all_ok,
445            steps,
446        });
447
448        completed_sweeps = sweep_idx + 1;
449
450        // ---- Convergence check ----------------------------------
451        if completed_sweeps >= params.min_sweeps
452            && let Some(prev) = last_energy
453        {
454            let abs_delta = (sweep_energy - prev).abs();
455            if abs_delta <= energy_tol_real && all_ok {
456                converged = true;
457                break;
458            }
459        }
460        last_energy = Some(sweep_energy);
461    }
462
463    let final_energy = sweeps
464        .last()
465        .map(|s| s.sweep_energy)
466        .expect("at least one sweep ran (max_sweeps >= 1 by validation)");
467
468    Ok(DmrgResult {
469        energy: final_energy,
470        converged,
471        n_sweeps: completed_sweeps,
472        sweeps,
473    })
474}
475
476/// Run a single 2-site step at `site` over a [`DmrgOps`] storage,
477/// then mutate the MPS site tensors at `site` and `site + 1`
478/// according to `direction`.
479///
480/// Returns the step's diagnostics record. The caller is responsible
481/// for advancing the env afterwards (or skipping the advance at the
482/// trailing boundary of a half-sweep).
483fn run_step<T, St, L>(
484    envs: &BraketEnvs<St, L>,
485    mps: &mut Mps<St, L>,
486    mpo: &Mpo<St, L>,
487    site: usize,
488    params: &DmrgSweepParams,
489    sweep_idx: usize,
490    direction: SweepDirection,
491) -> Result<DmrgStepRecord<T::Real>, DmrgSweepError>
492where
493    T: Scalar,
494    T::Real: Scalar<Real = T::Real>,
495    St: Storage + StorageFor<L>,
496    L: TensorLayout,
497    Mps<St, L>: DmrgOps<T> + MpsOps<T, Storage = St, Layout = L>,
498    BraketEnvs<St, L>: BraketEnvOps<T, Storage = St, Layout = L>,
499    // Host-pinned: the host backend supplies every kernel, so it must declare
500    // capability for this chain's storage (satisfied by Dense / BlockSparse).
501    Host: OpsFor<St>,
502{
503    // The fused step builds `H_eff`, solves, projects the diagnostics, and
504    // absorbs `S` in one host-pinned call; mapping `FullStepError` keeps the
505    // `Step` (local solve) vs `Scale` (S-absorb) breadcrumb distinction.
506    let (absorbed, (eigenvalue, residual, trunc_err, iters, converged)) = mps
507        .full_step_k(
508            envs,
509            mpo,
510            site,
511            &params.eigensolver,
512            &params.trunc,
513            direction,
514        )
515        .map_err(|source| match source {
516            FullStepError::Heff(source) => DmrgSweepError::Step {
517                sweep: sweep_idx,
518                direction,
519                site,
520                source,
521            },
522            FullStepError::Scale(source) => DmrgSweepError::Scale {
523                sweep: sweep_idx,
524                direction,
525                site,
526                source,
527            },
528        })?;
529
530    *mps.site_mut(site) = absorbed.site_i;
531    *mps.site_mut(site + 1) = absorbed.site_ip1;
532
533    Ok(DmrgStepRecord {
534        sweep: sweep_idx,
535        direction,
536        site,
537        eigenvalue,
538        residual,
539        trunc_err,
540        bond_dim: absorbed.bond_dim,
541        eigensolver_iters: iters,
542        eigensolver_converged: converged,
543    })
544}
545
546pub(super) fn validate_params(params: &DmrgSweepParams) -> Result<(), DmrgSweepError> {
547    if params.max_sweeps == 0 {
548        return Err(DmrgSweepError::InvalidParams {
549            detail: "max_sweeps must be >= 1",
550        });
551    }
552    if params.min_sweeps > params.max_sweeps {
553        return Err(DmrgSweepError::InvalidParams {
554            detail: "min_sweeps must be <= max_sweeps",
555        });
556    }
557    if !params.energy_tol.is_finite() {
558        return Err(DmrgSweepError::InvalidParams {
559            detail: "energy_tol must be finite",
560        });
561    }
562    if params.energy_tol < 0.0 {
563        return Err(DmrgSweepError::InvalidParams {
564            detail: "energy_tol must be non-negative",
565        });
566    }
567    validate_eigensolver_params(&params.eigensolver)
568        .map_err(|detail| DmrgSweepError::InvalidParams { detail })?;
569    if let Some(chi) = params.trunc.chi_max
570        && chi == 0
571    {
572        return Err(DmrgSweepError::InvalidParams {
573            detail: "trunc.chi_max must be > 0 if Some",
574        });
575    }
576    if let Some(te) = params.trunc.target_trunc_err {
577        if !te.is_finite() {
578            return Err(DmrgSweepError::InvalidParams {
579                detail: "trunc.target_trunc_err must be finite",
580            });
581        }
582        if te < 0.0 {
583            return Err(DmrgSweepError::InvalidParams {
584                detail: "trunc.target_trunc_err must be non-negative",
585            });
586        }
587    }
588    Ok(())
589}