use ariadnetor_core::Scalar;
use ariadnetor_linalg::LinalgError;
use ariadnetor_mps::{CanonicalForm, Mpo, Mps, MpsOps, TensorChain, braket};
use ariadnetor_tensor::{Host, OpsFor, Storage, StorageFor, TensorLayout};
use crate::numeric::try_real_from_f64;
use super::dispatch::{DmrgOps, FullStepError};
use super::env::{DmrgEnvError, DmrgEnvOps, DmrgEnvs};
use super::heff_error::DmrgHeffError;
use super::solver::{LocalEigensolverParams, eigensolver_tol, validate_eigensolver_params};
/// Direction in which a half-sweep walks over the bonds of the chain.
///
/// Each sweep of `sweep_2site` runs a `LeftToRight` pass (increasing bond index)
/// followed by a `RightToLeft` pass (decreasing bond index).
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum SweepDirection {
    /// Bonds `(site, site + 1)` are visited with `site` increasing.
    LeftToRight,
    /// Bonds `(site, site + 1)` are visited with `site` decreasing.
    RightToLeft,
}
/// User-facing knobs controlling a 2-site DMRG run.
///
/// All fields are checked by `validate_params` before any sweep starts.
#[derive(Clone, Debug)]
pub struct DmrgSweepParams {
    /// Maximum number of full (left-to-right + right-to-left) sweeps; must be `>= 1`.
    pub max_sweeps: usize,
    /// Convergence is never declared before this many sweeps have completed;
    /// must be `<= max_sweeps`.
    pub min_sweeps: usize,
    /// Absolute tolerance on the change of the sweep energy between consecutive
    /// sweeps; must be finite and non-negative.
    pub energy_tol: f64,
    /// Parameters forwarded to the local eigensolver at every 2-site step.
    pub eigensolver: LocalEigensolverParams,
    /// Truncation parameters forwarded to every 2-site step. `chi_max`, if set, must be
    /// `> 0`; `target_trunc_err`, if set, must be finite and non-negative.
    pub trunc: ariadnetor_linalg::TruncSvdParams,
}
/// Diagnostics for a single 2-site step on bond `(site, site + 1)`.
///
/// `R` is the real scalar type of the run.
#[derive(Clone, Debug)]
pub struct DmrgStepRecord<R> {
    /// Zero-based index of the sweep this step belongs to.
    pub sweep: usize,
    /// Half-sweep direction in which this step ran.
    pub direction: SweepDirection,
    /// Left site of the optimized bond.
    pub site: usize,
    /// Eigenvalue reported by the local step.
    pub eigenvalue: R,
    /// Eigensolver residual reported by the local step.
    pub residual: R,
    /// Truncation error reported by the local step.
    pub trunc_err: R,
    /// Bond dimension reported by the local step after truncation.
    pub bond_dim: usize,
    /// Number of iterations the local eigensolver used.
    pub eigensolver_iters: usize,
    /// Whether the local eigensolver reported convergence.
    pub eigensolver_converged: bool,
}
/// Aggregated diagnostics for one full sweep (left-to-right followed by right-to-left).
#[derive(Clone, Debug)]
pub struct DmrgSweepRecord<R> {
    /// Zero-based sweep index.
    pub sweep: usize,
    /// `Re(<mps|mpo|mps>)` divided by the squared MPS norm, evaluated after the sweep.
    pub sweep_energy: R,
    /// Smallest step eigenvalue observed during the sweep.
    pub min_step_eigenvalue: R,
    /// Largest step truncation error observed during the sweep.
    pub max_trunc_err: R,
    /// Maximum bond dimension of the MPS after the sweep.
    pub max_bond: usize,
    /// `true` iff every local eigensolve in this sweep reported convergence.
    pub all_eigensolver_converged: bool,
    /// Per-step records in execution order (left-to-right pass, then right-to-left pass).
    pub steps: Vec<DmrgStepRecord<R>>,
}
/// Outcome of a complete `sweep_2site` run.
#[derive(Clone, Debug)]
pub struct DmrgResult<R> {
    /// Sweep energy of the last completed sweep.
    pub energy: R,
    /// Whether the energy-change / eigensolver convergence criterion was met.
    pub converged: bool,
    /// Number of sweeps actually performed.
    pub n_sweeps: usize,
    /// One record per completed sweep, in order.
    pub sweeps: Vec<DmrgSweepRecord<R>>,
}
#[derive(Debug, thiserror::Error)]
#[non_exhaustive]
pub enum DmrgSweepError {
#[error("chain length mismatch: mps = {mps}, mpo = {mpo}, envs = {envs}")]
LengthMismatch {
mps: usize,
mpo: usize,
envs: usize,
},
#[error("2-site sweep requires n_sites >= 2, got {n_sites}")]
TooFewSites {
n_sites: usize,
},
#[error("invalid DmrgSweepParams: {detail}")]
InvalidParams {
detail: &'static str,
},
#[error("MPS must be in Right or Mixed {{ center: 0 }} form before sweep, got {found:?}")]
MpsNotRightCanonical {
found: CanonicalForm,
},
#[error("2-site DMRG step failed at sweep {sweep}, {direction:?}, site {site}")]
Step {
sweep: usize,
direction: SweepDirection,
site: usize,
#[source]
source: DmrgHeffError,
},
#[error("DmrgEnvs advance failed at sweep {sweep}, {direction:?}, site {site}")]
Env {
sweep: usize,
direction: SweepDirection,
site: usize,
#[source]
source: DmrgEnvError,
},
#[error("S-absorb (diagonal scale) failed during sweep {sweep}, {direction:?}, site {site}")]
Scale {
sweep: usize,
direction: SweepDirection,
site: usize,
#[source]
source: LinalgError,
},
}
/// Runs a 2-site DMRG optimization as repeated left-to-right / right-to-left sweeps.
///
/// Every sweep first visits the bonds `(site, site + 1)` for `site = 0, ..., n_sites - 2`
/// (left to right). It then visits the same bonds in reverse order (right to left), so the
/// turning bond `(n_sites - 2, n_sites - 1)` is optimized twice per sweep. Environments are
/// advanced between steps. After each sweep the MPS is marked `Mixed { center: 0 }`, and the
/// sweep energy is `Re(<mps|mpo|mps>)` divided by the squared MPS norm.
///
/// The run stops early with `converged = true` when all of the following hold:
/// - at least `params.min_sweeps` sweeps have completed;
/// - a previous sweep energy exists;
/// - the absolute energy change from it is `<= params.energy_tol`;
/// - every local eigensolve of the current sweep reported convergence.
///
/// Otherwise exactly `params.max_sweeps` sweeps are run.
///
/// # Preconditions
///
/// `envs`, `mps` and `mpo` must describe the same number of sites (at least 2). `mps` must be
/// in `Right` or `Mixed { center: 0 }` canonical form.
///
/// # Errors
///
/// - [`DmrgSweepError::LengthMismatch`] / [`DmrgSweepError::TooFewSites`] for bad chain sizes.
/// - [`DmrgSweepError::InvalidParams`] if `params` fails validation or a tolerance cannot be
///   represented in `T::Real`.
/// - [`DmrgSweepError::MpsNotRightCanonical`] for an unsupported input canonical form.
/// - [`DmrgSweepError::Step`], [`DmrgSweepError::Scale`], [`DmrgSweepError::Env`] when a local
///   step or an environment update fails. Steps that already ran have written their tensors
///   into `mps`, so `mps` and `envs` may then be partially updated.
pub fn sweep_2site<T, St, L>(
    envs: &mut DmrgEnvs<St, L>,
    mps: &mut Mps<St, L>,
    mpo: &Mpo<St, L>,
    params: &DmrgSweepParams,
) -> Result<DmrgResult<T::Real>, DmrgSweepError>
where
    T: Scalar,
    T::Real: Scalar<Real = T::Real>,
    St: Storage + StorageFor<L>,
    L: TensorLayout,
    Mps<St, L>: DmrgOps<T> + MpsOps<T, Storage = St, Layout = L>,
    DmrgEnvs<St, L>: DmrgEnvOps<T, Storage = St, Layout = L>,
    Host: OpsFor<St>,
{
    // Upper bound on how many sweep records are reserved up front. `max_sweeps` is
    // caller-controlled and may be huge (e.g. `usize::MAX` meaning "until converged");
    // reserving that much would abort on allocation before any work is done.
    const SWEEP_RECORD_PREALLOC_CAP: usize = 64;

    let n_sites = envs.n_sites();
    if mps.len() != n_sites || mpo.len() != n_sites {
        return Err(DmrgSweepError::LengthMismatch {
            mps: mps.len(),
            mpo: mpo.len(),
            envs: n_sites,
        });
    }
    if n_sites < 2 {
        return Err(DmrgSweepError::TooFewSites { n_sites });
    }
    validate_params(params)?;
    let energy_tol_real: T::Real =
        try_real_from_f64::<T>(params.energy_tol).ok_or(DmrgSweepError::InvalidParams {
            detail: "energy_tol is not representable in the storage's real scalar type",
        })?;
    if try_real_from_f64::<T>(eigensolver_tol(&params.eigensolver)).is_none() {
        return Err(DmrgSweepError::InvalidParams {
            detail: "eigensolver tol is not representable in the storage's real scalar type",
        });
    }
    match mps.canonical_form() {
        CanonicalForm::Right | CanonicalForm::Mixed { center: 0 } => {}
        other => {
            return Err(DmrgSweepError::MpsNotRightCanonical {
                found: other.clone(),
            });
        }
    }

    let backend = Host::shared();
    let mut sweeps: Vec<DmrgSweepRecord<T::Real>> =
        Vec::with_capacity(params.max_sweeps.min(SWEEP_RECORD_PREALLOC_CAP));
    let mut last_energy: Option<T::Real> = None;
    let mut converged = false;
    let mut completed_sweeps = 0usize;

    for sweep_idx in 0..params.max_sweeps {
        let mut steps: Vec<DmrgStepRecord<T::Real>> = Vec::with_capacity(2 * (n_sites - 1));

        // Left-to-right half-sweep. The left environment is not advanced past the last bond
        // because the right-to-left pass restarts on that same bond.
        for site in 0..n_sites - 1 {
            let record = run_step(
                envs,
                mps,
                mpo,
                site,
                params,
                sweep_idx,
                SweepDirection::LeftToRight,
            )?;
            steps.push(record);
            if site < n_sites - 2 {
                envs.advance_left::<T>(mps, mpo, site)
                    .map_err(|source| DmrgSweepError::Env {
                        sweep: sweep_idx,
                        direction: SweepDirection::LeftToRight,
                        site,
                        source,
                    })?;
            }
        }

        // Right-to-left half-sweep back to bond (0, 1).
        for site in (0..n_sites - 1).rev() {
            let record = run_step(
                envs,
                mps,
                mpo,
                site,
                params,
                sweep_idx,
                SweepDirection::RightToLeft,
            )?;
            steps.push(record);
            envs.advance_right::<T>(mps, mpo, site + 1)
                .map_err(|source| DmrgSweepError::Env {
                    sweep: sweep_idx,
                    direction: SweepDirection::RightToLeft,
                    site,
                    source,
                })?;
        }

        mps.set_canonical_form(CanonicalForm::Mixed { center: 0 });
        let bra_h_ket: T = braket(backend.as_ref(), mps, mpo, mps);
        let nrm: T::Real = mps.norm(backend.as_ref());
        let nrm_sq: T::Real = nrm * nrm;
        let sweep_energy: T::Real = bra_h_ket.re() / nrm_sq;
        let max_bond = mps.max_bond_dim();

        // `steps` is non-empty: n_sites >= 2 guarantees at least one bond per half-sweep.
        let first = &steps[0];
        let (min_eig, max_te, all_ok) = steps.iter().fold(
            (first.eigenvalue, first.trunc_err, true),
            |(min_eig, max_te, all_ok), s| {
                (
                    if s.eigenvalue < min_eig { s.eigenvalue } else { min_eig },
                    if s.trunc_err > max_te { s.trunc_err } else { max_te },
                    all_ok && s.eigensolver_converged,
                )
            },
        );

        sweeps.push(DmrgSweepRecord {
            sweep: sweep_idx,
            sweep_energy,
            min_step_eigenvalue: min_eig,
            max_trunc_err: max_te,
            max_bond,
            all_eigensolver_converged: all_ok,
            steps,
        });
        completed_sweeps = sweep_idx + 1;

        if completed_sweeps >= params.min_sweeps
            && let Some(prev) = last_energy
        {
            let abs_delta = (sweep_energy - prev).abs();
            if abs_delta <= energy_tol_real && all_ok {
                converged = true;
                break;
            }
        }
        last_energy = Some(sweep_energy);
    }

    let final_energy = sweeps
        .last()
        .map(|s| s.sweep_energy)
        .expect("at least one sweep ran (max_sweeps >= 1 by validation)");
    Ok(DmrgResult {
        energy: final_energy,
        converged,
        n_sweeps: completed_sweeps,
        sweeps,
    })
}
/// Performs one 2-site step on bond `(site, site + 1)` and stores the result in `mps`.
///
/// Calls `full_step_k` on `mps` with the current environments, the MPO, and the eigensolver
/// and truncation parameters from `params`. It then overwrites sites `site` and `site + 1` of
/// `mps` with the returned tensors. `envs` is only read here; advancing the environments is
/// the caller's job.
///
/// Returns a [`DmrgStepRecord`] holding the eigenvalue, residual, truncation error, bond
/// dimension and eigensolver statistics reported by the step.
///
/// # Errors
///
/// Each error is tagged with `sweep_idx`, `direction` and `site`:
/// - `FullStepError::Heff` is reported as [`DmrgSweepError::Step`];
/// - `FullStepError::Scale` is reported as [`DmrgSweepError::Scale`].
///
/// The writes to `mps` in this function happen only after `full_step_k` succeeds.
fn run_step<T, St, L>(
    envs: &DmrgEnvs<St, L>,
    mps: &mut Mps<St, L>,
    mpo: &Mpo<St, L>,
    site: usize,
    params: &DmrgSweepParams,
    sweep_idx: usize,
    direction: SweepDirection,
) -> Result<DmrgStepRecord<T::Real>, DmrgSweepError>
where
    T: Scalar,
    T::Real: Scalar<Real = T::Real>,
    St: Storage + StorageFor<L>,
    L: TensorLayout,
    Mps<St, L>: DmrgOps<T> + MpsOps<T, Storage = St, Layout = L>,
    DmrgEnvs<St, L>: DmrgEnvOps<T, Storage = St, Layout = L>,
    Host: OpsFor<St>,
{
    let (absorbed, (eigenvalue, residual, trunc_err, iters, converged)) = mps
        .full_step_k(
            envs,
            mpo,
            site,
            &params.eigensolver,
            &params.trunc,
            direction,
        )
        .map_err(|source| match source {
            FullStepError::Heff(source) => DmrgSweepError::Step {
                sweep: sweep_idx,
                direction,
                site,
                source,
            },
            FullStepError::Scale(source) => DmrgSweepError::Scale {
                sweep: sweep_idx,
                direction,
                site,
                source,
            },
        })?;
    *mps.site_mut(site) = absorbed.site_i;
    *mps.site_mut(site + 1) = absorbed.site_ip1;
    Ok(DmrgStepRecord {
        sweep: sweep_idx,
        direction,
        site,
        eigenvalue,
        residual,
        trunc_err,
        bond_dim: absorbed.bond_dim,
        eigensolver_iters: iters,
        eigensolver_converged: converged,
    })
}
/// Checks that `params` describes a runnable sweep configuration.
///
/// # Errors
///
/// Returns [`DmrgSweepError::InvalidParams`] for the first violated rule:
/// - `max_sweeps == 0`;
/// - `min_sweeps > max_sweeps`;
/// - `energy_tol` not finite, or negative;
/// - any rule rejected by `validate_eigensolver_params` (its detail is passed through);
/// - `trunc.chi_max == Some(0)`;
/// - `trunc.target_trunc_err` set but not finite, or negative.
///
/// Whether the tolerances are representable in the run's real scalar type is checked
/// separately by the caller.
pub(super) fn validate_params(params: &DmrgSweepParams) -> Result<(), DmrgSweepError> {
    if params.max_sweeps == 0 {
        return Err(DmrgSweepError::InvalidParams {
            detail: "max_sweeps must be >= 1",
        });
    }
    if params.min_sweeps > params.max_sweeps {
        return Err(DmrgSweepError::InvalidParams {
            detail: "min_sweeps must be <= max_sweeps",
        });
    }
    // Finiteness is checked first so NaN gets the "finite" message rather than slipping
    // past the `< 0.0` comparison (which is false for NaN).
    if !params.energy_tol.is_finite() {
        return Err(DmrgSweepError::InvalidParams {
            detail: "energy_tol must be finite",
        });
    }
    if params.energy_tol < 0.0 {
        return Err(DmrgSweepError::InvalidParams {
            detail: "energy_tol must be non-negative",
        });
    }
    validate_eigensolver_params(&params.eigensolver)
        .map_err(|detail| DmrgSweepError::InvalidParams { detail })?;
    if matches!(params.trunc.chi_max, Some(0)) {
        return Err(DmrgSweepError::InvalidParams {
            detail: "trunc.chi_max must be > 0 if Some",
        });
    }
    if let Some(te) = params.trunc.target_trunc_err {
        if !te.is_finite() {
            return Err(DmrgSweepError::InvalidParams {
                detail: "trunc.target_trunc_err must be finite",
            });
        }
        if te < 0.0 {
            return Err(DmrgSweepError::InvalidParams {
                detail: "trunc.target_trunc_err must be non-negative",
            });
        }
    }
    Ok(())
}