ariadnetor_algorithms/dmrg/sweep.rs
1//! 2-site DMRG sweep driver.
2//!
3//! Runs alternating L→R / R→L half-sweeps over a [`super::DmrgOps`]
4//! chain, dispatching the storage-specific local solve and
5//! S-absorb through the trait. Mutates the MPS site tensors and the
6//! [`BraketEnvs`] in place. Caller supplies a right-canonical (or
7//! `Mixed { center: 0 }`) MPS plus a freshly-built `BraketEnvs`; the
8//! driver does **not** auto-canonicalize because doing so would
9//! silently invalidate the caller-supplied envs.
10//!
11//! # Canonical-form precondition
12//!
13//! The local effective-Hamiltonian eigenvalue equation returns
14//! physical energy directly only when, with active block `(i, i+1)`,
15//! sites `(0..i)` are left-canonical and sites `(i+2..N-1)` are
//! right-canonical. The driver starts L→R from `i = 0`, so the
//! binding precondition is right-canonicality of `(2..N-1)`, met
//! exactly by `Right` and `Mixed { center: 0 }`. Every other form —
//! including `Unknown`, which cannot vouch for the precondition — is
//! rejected with `DmrgSweepError::MpsNotRightCanonical` rather than
//! silently re-canonicalized. This argument is storage-independent
//! (the local-block requirement does not depend on Dense vs
//! BlockSparse representation), so the gate applies uniformly
//! through the trait.
22//!
23//! # Convergence
24//!
25//! After each full L→R + R→L cycle we record the **normalized**
26//! post-truncation expectation `<psi|H|psi>.re() / <psi|psi>`.
27//! The truncated SVD keeps unrenormalized singular values, so without
28//! the `<psi|psi>` divisor the sweep energy drifts toward zero
29//! whenever truncation happens — the divisor strips that
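//! norm-artifact away. Convergence requires the energy delta to be
//! within `energy_tol`, every step's local eigensolver to have
//! converged, and `n_sweeps >= min_sweeps`. Because the delta is taken
//! against the previous cycle's energy, convergence can be declared no
//! earlier than the second cycle.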
33
34use ariadnetor_core::Scalar;
35use ariadnetor_linalg::LinalgError;
36use ariadnetor_mps::{CanonicalForm, Mpo, Mps, MpsOps, TensorChain, braket};
37use ariadnetor_tensor::{Host, OpsFor, Storage, StorageFor, TensorLayout};
38
39use crate::numeric::try_real_from_f64;
40
41use super::dispatch::{DmrgOps, FullStepError};
42use super::heff_error::DmrgHeffError;
43use super::solver::{LocalEigensolverParams, eigensolver_tol, validate_eigensolver_params};
44use ariadnetor_mps::{BraketEnvError, BraketEnvOps, BraketEnvs};
45
/// Which way a half-sweep walks along the chain.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SweepDirection {
    /// Ascending walk over `site = 0 ..= n_sites - 2`. The truncated SVD of
    /// each optimized 2-site block leaves site `i` left-isometric and pushes
    /// the `S·Vt` factor into site `i + 1`.
    LeftToRight,
    /// Descending walk from `site = n_sites - 2` down to `site = 0`. The
    /// truncated SVD leaves site `i + 1` right-isometric and pushes the
    /// `U·S` factor into site `i`.
    RightToLeft,
}
58
59/// Caller-supplied parameters for [`sweep_2site`] (the chain-generic
60/// 2-site DMRG sweep driver, dispatched over `Mps<St, L>: super::DmrgOps<T>`
61/// so the same params type covers both the Dense and BlockSparse /
62/// U(1) paths).
63///
64/// Stored as plain `f64` for `energy_tol`; the entry point converts
65/// it to `T::Real` via the same `NumCast::from(f64)` idiom as
66/// [`crate::krylov::LanczosParams::tol`]. This keeps the params
67/// type non-generic across the scalar type.
68#[derive(Debug, Clone)]
69pub struct DmrgSweepParams {
70 /// Maximum number of full L→R + R→L cycles. Must be `>= 1`;
71 /// `0` is rejected with [`DmrgSweepError::InvalidParams`].
72 pub max_sweeps: usize,
73 /// Minimum number of sweeps before the energy-delta convergence
74 /// test is honored. Pre-`min_sweeps` cycles always continue
75 /// regardless of energy delta. Must be `<= max_sweeps`.
76 pub min_sweeps: usize,
77 /// Energy-delta tolerance. After cycle `n >= min_sweeps`,
78 /// convergence requires `|E_n - E_{n-1}| <= energy_tol`. Must
79 /// be finite and non-negative.
80 pub energy_tol: f64,
81 /// Local-eigensolver selector + per-variant parameters, forwarded
82 /// to the per-step driver (Dense `dmrg_2site_step` or BlockSparse
83 /// `dmrg_2site_step_block_sparse`).
84 pub eigensolver: LocalEigensolverParams,
85 /// Truncated-SVD parameters, forwarded to the per-step driver
86 /// (Dense `dmrg_2site_step` or BlockSparse
87 /// `dmrg_2site_step_block_sparse`).
88 pub trunc: ariadnetor_linalg::TruncSvdParams,
89}
90
91/// Per-step diagnostics record.
92#[derive(Debug, Clone)]
93pub struct DmrgStepRecord<R> {
94 /// Index of the sweep cycle this step belongs to.
95 pub sweep: usize,
96 /// Direction of the half-sweep (`L→R` or `R→L`).
97 pub direction: SweepDirection,
98 /// Left site of the optimized two-site block.
99 pub site: usize,
100 /// Smallest eigenvalue of `H_eff` at this step (pre-truncation
101 /// local-block variational minimum). May lie below the
102 /// post-truncation sweep energy.
103 pub eigenvalue: R,
104 /// Local-eigensolver true residual `‖H v − λ v‖₂`.
105 pub residual: R,
106 /// Frobenius norm of singular values discarded by this step's
107 /// truncated SVD.
108 pub trunc_err: R,
109 /// New bond dimension between `site` and `site + 1` after the
110 /// truncated split.
111 pub bond_dim: usize,
112 /// Number of iterations the local eigensolver ran for this step.
113 /// For Lanczos this is the inner loop count; for ARPACK it is
114 /// the restart-iteration count returned in `iparam[2]`.
115 pub eigensolver_iters: usize,
116 /// `true` iff the local eigensolver succeeded — Lanczos by its
117 /// absolute true-residual test against `LanczosParams::tol`,
118 /// ARPACK by its relative-tol stopping criterion (i.e. `Ok`
119 /// return from `arpack_smallest`). The two arms intentionally
120 /// disagree on what they call "converged": Lanczos uses the
121 /// absolute residual; ARPACK uses `residual <= tol * |lambda|`.
122 /// See [`super::heff::TwoSiteStepResult::converged`] for the
123 /// upstream contract this field forwards from.
124 pub eigensolver_converged: bool,
125}
126
127/// Per-sweep diagnostics record (one full L→R + R→L cycle).
128#[derive(Debug, Clone)]
129pub struct DmrgSweepRecord<R> {
130 /// Index of this sweep cycle.
131 pub sweep: usize,
132 /// Normalized post-truncation `<psi|H|psi> / <psi|psi>` after
133 /// this cycle. The convergence metric.
134 pub sweep_energy: R,
135 /// `min(step.eigenvalue)` across this cycle. Diagnostic only —
136 /// reflects local-block variational minima, which can be lower
137 /// than `sweep_energy`.
138 pub min_step_eigenvalue: R,
139 /// Largest per-step truncation error in this cycle.
140 pub max_trunc_err: R,
141 /// Largest bond dimension reached in this cycle.
142 pub max_bond: usize,
143 /// `true` iff every step in this cycle's local-eigensolver pass
144 /// converged.
145 pub all_eigensolver_converged: bool,
146 /// Per-step diagnostic records for this cycle, in execution order.
147 pub steps: Vec<DmrgStepRecord<R>>,
148}
149
150/// Final result of the 2-site DMRG sweep driver [`sweep_2site`].
151#[derive(Debug, Clone)]
152pub struct DmrgResult<R> {
153 /// Last cycle's `sweep_energy`.
154 pub energy: R,
155 /// `true` iff the final cycle satisfied:
156 /// `n_sweeps >= min_sweeps`,
157 /// `|delta_E| <= energy_tol`,
158 /// and every step's local eigensolver converged.
159 pub converged: bool,
160 /// Number of sweep cycles executed.
161 pub n_sweeps: usize,
162 /// Per-cycle diagnostic records, in execution order.
163 pub sweeps: Vec<DmrgSweepRecord<R>>,
164}
165
166/// Errors raised by the 2-site DMRG sweep driver [`sweep_2site`].
167#[derive(Debug, thiserror::Error)]
168#[non_exhaustive]
169pub enum DmrgSweepError {
170 /// MPS, MPO, and `BraketEnvs` disagree on `n_sites`.
171 #[error("chain length mismatch: mps = {mps}, mpo = {mpo}, envs = {envs}")]
172 LengthMismatch {
173 /// `n_sites` reported by the MPS.
174 mps: usize,
175 /// `n_sites` reported by the MPO.
176 mpo: usize,
177 /// `n_sites` reported by the environments.
178 envs: usize,
179 },
180 /// `n_sites < 2`. 2-site sweeps require at least 2 sites.
181 #[error("2-site sweep requires n_sites >= 2, got {n_sites}")]
182 TooFewSites {
183 /// The (too-small) site count supplied.
184 n_sites: usize,
185 },
186 /// `DmrgSweepParams` failed entry-point validation. `detail`
187 /// names the constraint that fired.
188 #[error("invalid DmrgSweepParams: {detail}")]
189 InvalidParams {
190 /// Names the validation constraint that failed.
191 detail: &'static str,
192 },
193 /// MPS canonical form was not `Right` or `Mixed { center: 0 }`.
194 /// `Unknown` is also rejected — see the module-level docs for
195 /// the rationale.
196 #[error("MPS must be in Right or Mixed {{ center: 0 }} form before sweep, got {found:?}")]
197 MpsNotRightCanonical {
198 /// The canonical form actually found.
199 found: CanonicalForm,
200 },
201 /// The per-step driver (`dmrg_2site_step` or
202 /// `dmrg_2site_step_block_sparse`) returned an error. Source
203 /// preserved.
204 #[error("2-site DMRG step failed at sweep {sweep}, {direction:?}, site {site}")]
205 Step {
206 /// Sweep cycle where the failure occurred.
207 sweep: usize,
208 /// Half-sweep direction at the failure.
209 direction: SweepDirection,
210 /// Left site index at the failure.
211 site: usize,
212 /// The underlying per-step error.
213 #[source]
214 source: DmrgHeffError,
215 },
216 /// `BraketEnvs::advance_left/right` returned an error during a
217 /// post-step env update. Surfaced separately from `Step` so
218 /// the caller can distinguish "local solve failed" from "env
219 /// state became inconsistent". Defense-in-depth — under the
220 /// driver's own advance ordering, this branch should never
221 /// fire from the public API.
222 #[error("BraketEnvs advance failed at sweep {sweep}, {direction:?}, site {site}")]
223 Env {
224 /// Sweep cycle where the failure occurred.
225 sweep: usize,
226 /// Half-sweep direction at the failure.
227 direction: SweepDirection,
228 /// Left site index at the failure.
229 site: usize,
230 /// The underlying environment-advance error.
231 #[source]
232 source: BraketEnvError,
233 },
234 /// The post-step S-absorb (`ariadnetor_linalg::diagonal_scale`, which
235 /// dispatches over layout for both Dense and BlockSparse) failed.
236 /// Carries the same `(sweep, direction, site)`
237 /// breadcrumbs as `Step` / `Env` so the caller can pin down
238 /// where the failure occurred without having to walk the
239 /// `DmrgResult::sweeps` history manually.
240 #[error("S-absorb (diagonal scale) failed during sweep {sweep}, {direction:?}, site {site}")]
241 Scale {
242 /// Sweep cycle where the failure occurred.
243 sweep: usize,
244 /// Half-sweep direction at the failure.
245 direction: SweepDirection,
246 /// Left site index at the failure.
247 site: usize,
248 /// The underlying diagonal-scale (linalg) error.
249 #[source]
250 source: LinalgError,
251 },
252}
253
254/// Run alternating L→R / R→L sweeps until convergence or
255/// `max_sweeps` over a [`DmrgOps`] chain. Mutates `mps` and
256/// `envs` in place; the final MPS state is
257/// `CanonicalForm::Mixed { center: 0 }` (R→L runs last).
258///
259/// Generic over the `Mps<St, L>` chain (`Mps<St, L>: DmrgOps<T>`), so a
260/// single call site covers both the Dense and BlockSparse / U(1) paths.
261/// The trait dispatches the local solve and S-absorb to the
262/// storage-specific implementations.
263///
264/// The driver is host-pinned in this stage: the local solve, the
265/// S-absorb, and the post-sweep `braket` / `norm` all route their
266/// backend-dependent work through the [`Host`] substrate
267/// (`Host::shared()`), so callers supply host-resident MPS / MPO / env
268/// state. Generic non-host-backend DMRG is a separate, later track.
269///
270/// See the module-level rustdoc for the canonical-form contract on
271/// the input MPS and for the convergence criterion.
272pub fn sweep_2site<T, St, L>(
273 envs: &mut BraketEnvs<St, L>,
274 mps: &mut Mps<St, L>,
275 mpo: &Mpo<St, L>,
276 params: &DmrgSweepParams,
277) -> Result<DmrgResult<T::Real>, DmrgSweepError>
278where
279 T: Scalar,
280 T::Real: Scalar<Real = T::Real>,
281 St: Storage + StorageFor<L>,
282 L: TensorLayout,
283 Mps<St, L>: DmrgOps<T> + MpsOps<T, Storage = St, Layout = L>,
284 BraketEnvs<St, L>: BraketEnvOps<T, Storage = St, Layout = L>,
285 // Host-pinned: the host backend supplies every kernel, so it must declare
286 // capability for this chain's storage (satisfied by Dense / BlockSparse).
287 Host: OpsFor<St>,
288{
289 // ---- Length / size validation -------------------------------
290 let n_sites = envs.n_sites();
291 if mps.len() != n_sites || mpo.len() != n_sites {
292 return Err(DmrgSweepError::LengthMismatch {
293 mps: mps.len(),
294 mpo: mpo.len(),
295 envs: n_sites,
296 });
297 }
298 if n_sites < 2 {
299 return Err(DmrgSweepError::TooFewSites { n_sites });
300 }
301
302 // ---- Param validation ---------------------------------------
303 validate_params(params)?;
304 // Casts may fail when the real scalar type (`T::Real`) is `f32`
305 // and the user supplied a
306 // finite value outside f32 range (NumCast::from returns Some(inf),
307 // which try_real_from_f64 then maps to None). Surface that as
308 // `InvalidParams` so the public API stays fallible end-to-end
309 // instead of failing inside the local eigensolver (Lanczos's
310 // internal `try_real_from_f64` would panic; ARPACK's `tol_real`
311 // cast would also panic). The selected eigensolver's `tol` is
312 // gated here too so a borderline f32 tol does not slip past
313 // sweep-level validation only to abort the run from inside the
314 // local solve.
315 let energy_tol_real: T::Real =
316 try_real_from_f64::<T>(params.energy_tol).ok_or(DmrgSweepError::InvalidParams {
317 detail: "energy_tol is not representable in the storage's real scalar type",
318 })?;
319 if try_real_from_f64::<T>(eigensolver_tol(¶ms.eigensolver)).is_none() {
320 return Err(DmrgSweepError::InvalidParams {
321 detail: "eigensolver tol is not representable in the storage's real scalar type",
322 });
323 }
324
325 // ---- Canonical-form contract --------------------------------
326 match mps.canonical_form() {
327 CanonicalForm::Right => {}
328 CanonicalForm::Mixed { center: 0 } => {}
329 other => {
330 return Err(DmrgSweepError::MpsNotRightCanonical {
331 found: other.clone(),
332 });
333 }
334 }
335
336 // DMRG is host-pinned in the CPU-only Stage B scope; the whole sweep
337 // boundary runs on the host substrate.
338 let backend = Host::shared();
339 let mut sweeps: Vec<DmrgSweepRecord<T::Real>> = Vec::with_capacity(params.max_sweeps);
340 let mut last_energy: Option<T::Real> = None;
341 let mut converged = false;
342 let mut completed_sweeps = 0usize;
343
344 // ---- Main sweep loop ----------------------------------------
345 for sweep_idx in 0..params.max_sweeps {
346 let mut steps: Vec<DmrgStepRecord<T::Real>> = Vec::with_capacity(2 * (n_sites - 1));
347
348 // L→R half-sweep.
349 for site in 0..n_sites - 1 {
350 let record = run_step(
351 envs,
352 mps,
353 mpo,
354 site,
355 params,
356 sweep_idx,
357 SweepDirection::LeftToRight,
358 )?;
359 steps.push(record);
360
361 // Skip the trailing advance: site = n_sites - 2 is
362 // the last L→R step; the next R→L step at the same
363 // `site` consumes `left(site)` which is still valid.
364 if site < n_sites - 2 {
365 envs.advance_left::<T>(mps, mpo, mps, site)
366 .map_err(|source| DmrgSweepError::Env {
367 sweep: sweep_idx,
368 direction: SweepDirection::LeftToRight,
369 site,
370 source,
371 })?;
372 }
373 }
374
375 // R→L half-sweep.
376 for site in (0..n_sites - 1).rev() {
377 let record = run_step(
378 envs,
379 mps,
380 mpo,
381 site,
382 params,
383 sweep_idx,
384 SweepDirection::RightToLeft,
385 )?;
386 steps.push(record);
387
388 // Always advance, including at the trailing `site == 0`
389 // boundary. Skipping the boundary advance would leave
390 // `right[1]` stale-but-`Some` (it would still hold the
391 // pre-sweep `BraketEnvs::build` value, computed against
392 // the original MPS site 1) and `left[1]` stale-but-
393 // `Some` (it would still hold the L→R-time value, even
394 // though R→L has further mutated MPS site 0). Both are
395 // contract violations against
396 // `BraketEnvs`'s "stale = None" convention even though
397 // they do not affect the next sweep iteration's
398 // numerics, which overwrites `left[1]` via
399 // `advance_left(0)` before consumption.
400 envs.advance_right::<T>(mps, mpo, mps, site + 1)
401 .map_err(|source| DmrgSweepError::Env {
402 sweep: sweep_idx,
403 direction: SweepDirection::RightToLeft,
404 site,
405 source,
406 })?;
407 }
408
409 // R→L ends with the orthogonality center at site 0
410 // (S absorbed leftward at every step). storage_mut reset
411 // the form to Unknown along the way; re-pin it here so a
412 // caller breaking out mid-loop sees a coherent state.
413 mps.set_canonical_form(CanonicalForm::Mixed { center: 0 });
414
415 // ---- Post-sweep diagnostics -----------------------------
416 let bra_h_ket: T = braket(backend.as_ref(), mps, mpo, mps);
417 let nrm: T::Real = mps.norm(backend.as_ref());
418 let nrm_sq: T::Real = nrm * nrm;
419 // The Float bound on the storage's real scalar type guarantees division.
420 let sweep_energy: T::Real = bra_h_ket.re() / nrm_sq;
421
422 let max_bond = mps.max_bond_dim();
423 let mut min_eig = steps[0].eigenvalue;
424 let mut max_te = steps[0].trunc_err;
425 let mut all_ok = true;
426 for s in &steps {
427 if s.eigenvalue < min_eig {
428 min_eig = s.eigenvalue;
429 }
430 if s.trunc_err > max_te {
431 max_te = s.trunc_err;
432 }
433 if !s.eigensolver_converged {
434 all_ok = false;
435 }
436 }
437
438 sweeps.push(DmrgSweepRecord {
439 sweep: sweep_idx,
440 sweep_energy,
441 min_step_eigenvalue: min_eig,
442 max_trunc_err: max_te,
443 max_bond,
444 all_eigensolver_converged: all_ok,
445 steps,
446 });
447
448 completed_sweeps = sweep_idx + 1;
449
450 // ---- Convergence check ----------------------------------
451 if completed_sweeps >= params.min_sweeps
452 && let Some(prev) = last_energy
453 {
454 let abs_delta = (sweep_energy - prev).abs();
455 if abs_delta <= energy_tol_real && all_ok {
456 converged = true;
457 break;
458 }
459 }
460 last_energy = Some(sweep_energy);
461 }
462
463 let final_energy = sweeps
464 .last()
465 .map(|s| s.sweep_energy)
466 .expect("at least one sweep ran (max_sweeps >= 1 by validation)");
467
468 Ok(DmrgResult {
469 energy: final_energy,
470 converged,
471 n_sweeps: completed_sweeps,
472 sweeps,
473 })
474}
475
476/// Run a single 2-site step at `site` over a [`DmrgOps`] storage,
477/// then mutate the MPS site tensors at `site` and `site + 1`
478/// according to `direction`.
479///
480/// Returns the step's diagnostics record. The caller is responsible
481/// for advancing the env afterwards (or skipping the advance at the
482/// trailing boundary of a half-sweep).
483fn run_step<T, St, L>(
484 envs: &BraketEnvs<St, L>,
485 mps: &mut Mps<St, L>,
486 mpo: &Mpo<St, L>,
487 site: usize,
488 params: &DmrgSweepParams,
489 sweep_idx: usize,
490 direction: SweepDirection,
491) -> Result<DmrgStepRecord<T::Real>, DmrgSweepError>
492where
493 T: Scalar,
494 T::Real: Scalar<Real = T::Real>,
495 St: Storage + StorageFor<L>,
496 L: TensorLayout,
497 Mps<St, L>: DmrgOps<T> + MpsOps<T, Storage = St, Layout = L>,
498 BraketEnvs<St, L>: BraketEnvOps<T, Storage = St, Layout = L>,
499 // Host-pinned: the host backend supplies every kernel, so it must declare
500 // capability for this chain's storage (satisfied by Dense / BlockSparse).
501 Host: OpsFor<St>,
502{
503 // The fused step builds `H_eff`, solves, projects the diagnostics, and
504 // absorbs `S` in one host-pinned call; mapping `FullStepError` keeps the
505 // `Step` (local solve) vs `Scale` (S-absorb) breadcrumb distinction.
506 let (absorbed, (eigenvalue, residual, trunc_err, iters, converged)) = mps
507 .full_step_k(
508 envs,
509 mpo,
510 site,
511 ¶ms.eigensolver,
512 ¶ms.trunc,
513 direction,
514 )
515 .map_err(|source| match source {
516 FullStepError::Heff(source) => DmrgSweepError::Step {
517 sweep: sweep_idx,
518 direction,
519 site,
520 source,
521 },
522 FullStepError::Scale(source) => DmrgSweepError::Scale {
523 sweep: sweep_idx,
524 direction,
525 site,
526 source,
527 },
528 })?;
529
530 *mps.site_mut(site) = absorbed.site_i;
531 *mps.site_mut(site + 1) = absorbed.site_ip1;
532
533 Ok(DmrgStepRecord {
534 sweep: sweep_idx,
535 direction,
536 site,
537 eigenvalue,
538 residual,
539 trunc_err,
540 bond_dim: absorbed.bond_dim,
541 eigensolver_iters: iters,
542 eigensolver_converged: converged,
543 })
544}
545
546pub(super) fn validate_params(params: &DmrgSweepParams) -> Result<(), DmrgSweepError> {
547 if params.max_sweeps == 0 {
548 return Err(DmrgSweepError::InvalidParams {
549 detail: "max_sweeps must be >= 1",
550 });
551 }
552 if params.min_sweeps > params.max_sweeps {
553 return Err(DmrgSweepError::InvalidParams {
554 detail: "min_sweeps must be <= max_sweeps",
555 });
556 }
557 if !params.energy_tol.is_finite() {
558 return Err(DmrgSweepError::InvalidParams {
559 detail: "energy_tol must be finite",
560 });
561 }
562 if params.energy_tol < 0.0 {
563 return Err(DmrgSweepError::InvalidParams {
564 detail: "energy_tol must be non-negative",
565 });
566 }
567 validate_eigensolver_params(¶ms.eigensolver)
568 .map_err(|detail| DmrgSweepError::InvalidParams { detail })?;
569 if let Some(chi) = params.trunc.chi_max
570 && chi == 0
571 {
572 return Err(DmrgSweepError::InvalidParams {
573 detail: "trunc.chi_max must be > 0 if Some",
574 });
575 }
576 if let Some(te) = params.trunc.target_trunc_err {
577 if !te.is_finite() {
578 return Err(DmrgSweepError::InvalidParams {
579 detail: "trunc.target_trunc_err must be finite",
580 });
581 }
582 if te < 0.0 {
583 return Err(DmrgSweepError::InvalidParams {
584 detail: "trunc.target_trunc_err must be non-negative",
585 });
586 }
587 }
588 Ok(())
589}