use ariadnetor_core::Scalar;
use ariadnetor_mps::{Mpo, Mps, TensorChain};
use ariadnetor_tensor::{
BlockSparseLayout, BlockSparseStorage, BlockSparseTensor, QNIndex, Sector,
};
use super::super::env::DmrgEnvs;
use super::super::heff_error::DmrgHeffError;
use super::super::solver::{LocalEigensolverParams, eigensolver_tol, validate_eigensolver_params};
/// Borrowed views of every tensor the two-site effective Hamiltonian on
/// sites `site` / `site + 1` consumes, produced by `validate_inputs` after
/// all rank, dimension, QN-direction, sector, and flux checks passed.
pub(super) struct ValidatedInputs<'a, T: Scalar, S: Sector> {
    /// Left environment `envs.left(site)`; legs `[top_bra, W_bond, bot_ket]`.
    pub left: &'a BlockSparseTensor<T, S>,
    /// Right environment `envs.right(site + 2)`; legs `[top_bra, W_bond, bot_ket]`.
    pub right: &'a BlockSparseTensor<T, S>,
    /// MPO tensor at `site`; legs `[W_l, ket, bra, W_r]`.
    pub w_i: &'a BlockSparseTensor<T, S>,
    /// MPO tensor at `site + 1`; legs `[W_l, ket, bra, W_r]`.
    pub w_ip1: &'a BlockSparseTensor<T, S>,
    /// MPS tensor at `site`; legs `[left_bond, phys, right_bond]`.
    pub mps_i: &'a BlockSparseTensor<T, S>,
    /// MPS tensor at `site + 1`; legs `[left_bond, phys, right_bond]`.
    pub mps_ip1: &'a BlockSparseTensor<T, S>,
}
/// Validates every input of the two-site effective Hamiltonian acting on
/// sites `site` and `site + 1`, and returns borrowed views of the six
/// tensors it consumes.
///
/// Leg layout assumed by the checks (positions in `shape()` / `indices()`):
/// - environments `left` (from `envs.left(site)`) and `right` (from
///   `envs.right(site + 2)`): `[top_bra, W_bond, bot_ket]`;
/// - MPO tensors `W[i]`, `W[i+1]`: `[W_l, ket, bra, W_r]`;
/// - MPS tensors `MPS[i]`, `MPS[i+1]`: `[left_bond, phys, right_bond]`.
///
/// # Errors
///
/// The first failing check is reported, in this order:
/// - `DmrgHeffError::LengthMismatch` if `mps`, `mpo`, and `envs` disagree
///   on the number of sites;
/// - `DmrgHeffError::InvalidSite` unless `site + 1 < n_sites`;
/// - `DmrgHeffError::InvalidEigensolverParams` if `validate_eigensolver_params`
///   rejects `eigensolver`, or its tolerance is not representable in `T::Real`;
/// - `DmrgHeffError::StaleEnv` if either environment lookup returns `None`;
/// - `DmrgHeffError::ShapeMismatch` for a wrong rank, an empty leg, or a
///   total-dimension mismatch between legs that are contracted together;
/// - `DmrgHeffError::QnMismatch` for incompatible leg directions, differing
///   sector lists, or a non-identity flux on an environment or MPO tensor.
pub(super) fn validate_inputs<'a, T, S>(
    envs: &'a DmrgEnvs<BlockSparseStorage<T>, BlockSparseLayout<S>>,
    mps: &'a Mps<BlockSparseStorage<T>, BlockSparseLayout<S>>,
    mpo: &'a Mpo<BlockSparseStorage<T>, BlockSparseLayout<S>>,
    site: usize,
    eigensolver: &LocalEigensolverParams,
) -> Result<ValidatedInputs<'a, T, S>, DmrgHeffError>
where
    T: Scalar,
    S: Sector,
{
    let n_sites = envs.n_sites();
    if mps.len() != n_sites || mpo.len() != n_sites {
        return Err(DmrgHeffError::LengthMismatch {
            mps: mps.len(),
            mpo: mpo.len(),
            envs: n_sites,
        });
    }
    // Two-site update: both `site` and `site + 1` must be valid sites.
    // `saturating_sub` makes an empty chain reject every `site`.
    if site >= n_sites.saturating_sub(1) {
        return Err(DmrgHeffError::InvalidSite { site, n_sites });
    }
    validate_eigensolver_params(eigensolver)
        .map_err(|detail| DmrgHeffError::InvalidEigensolverParams { detail })?;
    if crate::numeric::try_real_from_f64::<T>(eigensolver_tol(eigensolver)).is_none() {
        return Err(DmrgHeffError::InvalidEigensolverParams {
            detail: "tol is not representable in T::Real",
        });
    }
    let left = envs.left(site).ok_or(DmrgHeffError::StaleEnv {
        side: "left",
        index: site,
    })?;
    let right = envs.right(site + 2).ok_or(DmrgHeffError::StaleEnv {
        side: "right",
        index: site + 2,
    })?;
    let inputs = ValidatedInputs {
        left,
        right,
        w_i: mpo.site(site),
        w_ip1: mpo.site(site + 1),
        mps_i: mps.site(site),
        mps_ip1: mps.site(site + 1),
    };
    check_shapes(site, &inputs)?;
    check_qn_structure(site, &inputs)?;
    check_identity_fluxes(site, &inputs)?;
    Ok(inputs)
}

/// Checks ranks, non-empty legs, and total-dimension agreement of every pair
/// of legs the two-site effective Hamiltonian contracts together.
fn check_shapes<T: Scalar, S: Sector>(
    site: usize,
    v: &ValidatedInputs<'_, T, S>,
) -> Result<(), DmrgHeffError> {
    // Ranks go first: every `shape()[k]` lookup below relies on them.
    for (expected, actual, field) in [
        (3, v.left.rank(), "left.rank"),
        (3, v.right.rank(), "right.rank"),
        (4, v.w_i.rank(), "W[i].rank"),
        (4, v.w_ip1.rank(), "W[i+1].rank"),
        (3, v.mps_i.rank(), "MPS[i].rank"),
        (3, v.mps_ip1.rank(), "MPS[i+1].rank"),
    ] {
        check_dim_eq(site, expected, actual, field)?;
    }

    let left = v.left.shape();
    let right = v.right.shape();
    let w_i = v.w_i.shape();
    let w_ip1 = v.w_ip1.shape();
    let mps_i = v.mps_i.shape();
    let mps_ip1 = v.mps_ip1.shape();

    // These legs must be non-empty; the equalities below then carry the
    // bound over to every remaining leg.
    for (actual, field) in [
        (left[0], "left.top_bra (chi_l) total_dim"),
        (right[0], "right.top_bra (chi_r) total_dim"),
        (mps_i[1], "MPS[i].physical total_dim"),
        (mps_ip1[1], "MPS[i+1].physical total_dim"),
        (w_i[0], "W[i].W_l total_dim"),
        (w_i[3], "W[i].W_r total_dim"),
        (w_ip1[3], "W[i+1].W_r total_dim"),
        (mps_i[2], "MPS[i].right_bond total_dim"),
    ] {
        check_dim_at_least(site, 1, actual, field)?;
    }

    // `expected` comes first in each tuple, matching the error's fields.
    for (expected, actual, field) in [
        (left[2], mps_i[0], "left.bot_ket vs MPS[i].left_bond total_dim"),
        (left[2], left[0], "left.bot_ket vs left.top_bra total_dim"),
        (right[2], mps_ip1[2], "right.bot_ket vs MPS[i+1].right_bond total_dim"),
        (right[2], right[0], "right.bot_ket vs right.top_bra total_dim"),
        (left[1], w_i[0], "left.W_bond vs W[i].W_l total_dim"),
        (right[1], w_ip1[3], "right.W_bond vs W[i+1].W_r total_dim"),
        (w_i[3], w_ip1[0], "W[i].W_r vs W[i+1].W_l total_dim"),
        // The MPS bond between the two sites is contracted away when they
        // are merged into the four-axis two-site state, just like the MPO
        // bond on the previous line, so it must agree as well.
        (
            mps_i[2],
            mps_ip1[0],
            "MPS[i].right_bond vs MPS[i+1].left_bond total_dim",
        ),
        (w_i[1], mps_i[1], "W[i].d_ket vs MPS[i] phys total_dim"),
        (w_i[2], mps_i[1], "W[i].d_bra vs MPS[i] phys total_dim"),
        (w_ip1[1], mps_ip1[1], "W[i+1].d_ket vs MPS[i+1] phys total_dim"),
        (w_ip1[2], mps_ip1[1], "W[i+1].d_bra vs MPS[i+1] phys total_dim"),
    ] {
        check_dim_eq(site, expected, actual, field)?;
    }
    Ok(())
}

/// Checks QN-index directions and sector lists on every leg pair that is
/// contracted or matched against the output.
fn check_qn_structure<T: Scalar, S: Sector>(
    site: usize,
    v: &ValidatedInputs<'_, T, S>,
) -> Result<(), DmrgHeffError> {
    let left = v.left.indices();
    let right = v.right.indices();
    let w_i = v.w_i.indices();
    let w_ip1 = v.w_ip1.indices();
    let mps_i = v.mps_i.indices();
    let mps_ip1 = v.mps_ip1.indices();

    // Last element: `true` = directions must be opposite (contracted legs),
    // `false` = directions must be equal.
    for (field, lhs, rhs, opposite) in [
        (
            "left.bot_ket vs psi.axis 0 (MPS[i].left_bond)",
            &left[2],
            &mps_i[0],
            true,
        ),
        ("left.W_bond vs W[i].W_l", &left[1], &w_i[0], true),
        (
            "psi.axis 1 (MPS[i].phys) vs W[i].ket",
            &mps_i[1],
            &w_i[1],
            true,
        ),
        (
            "psi.axis 2 (MPS[i+1].phys) vs W[i+1].ket",
            &mps_ip1[1],
            &w_ip1[1],
            true,
        ),
        ("W[i].W_r vs W[i+1].W_l", &w_i[3], &w_ip1[0], true),
        (
            "MPS[i].right_bond vs MPS[i+1].left_bond",
            &mps_i[2],
            &mps_ip1[0],
            true,
        ),
        (
            "psi.axis 3 (MPS[i+1].right_bond) vs right.bot_ket",
            &mps_ip1[2],
            &right[2],
            true,
        ),
        ("right.W_bond vs W[i+1].W_r", &right[1], &w_ip1[3], true),
        (
            "left.top_bra vs psi.axis 0 (MPS[i].left_bond)",
            &left[0],
            &mps_i[0],
            false,
        ),
        (
            "right.top_bra vs psi.axis 3 (MPS[i+1].right_bond)",
            &right[0],
            &mps_ip1[2],
            false,
        ),
        ("W[i].bra vs W[i].ket", &w_i[2], &w_i[1], true),
        ("W[i+1].bra vs W[i+1].ket", &w_ip1[2], &w_ip1[1], true),
    ] {
        check_qn_pair(site, field, lhs, rhs, opposite)?;
    }

    // Each MPO bra leg must share the direction of the matching MPS physical leg.
    if w_i[2].direction() != mps_i[1].direction() {
        return Err(DmrgHeffError::QnMismatch {
            site,
            field: "W[i].bra direction vs MPS[i] phys direction",
            detail: format!(
                "W[i].bra direction = {:?}, MPS[i].phys direction = {:?} (must be equal)",
                w_i[2].direction(),
                mps_i[1].direction()
            ),
        });
    }
    if w_ip1[2].direction() != mps_ip1[1].direction() {
        return Err(DmrgHeffError::QnMismatch {
            site,
            field: "W[i+1].bra direction vs MPS[i+1] phys direction",
            detail: format!(
                "W[i+1].bra direction = {:?}, MPS[i+1].phys direction = {:?} (must be equal)",
                w_ip1[2].direction(),
                mps_ip1[1].direction()
            ),
        });
    }
    Ok(())
}

/// Requires both environments and both MPO tensors to carry identity flux.
fn check_identity_fluxes<T: Scalar, S: Sector>(
    site: usize,
    v: &ValidatedInputs<'_, T, S>,
) -> Result<(), DmrgHeffError> {
    for (field, flux) in [
        ("left.flux", v.left.flux()),
        ("right.flux", v.right.flux()),
        ("W[i].flux", v.w_i.flux()),
        ("W[i+1].flux", v.w_ip1.flux()),
    ] {
        if !is_identity_flux(flux) {
            return Err(DmrgHeffError::QnMismatch {
                site,
                field,
                detail: format!("{} = {:?} (must be identity)", field, flux),
            });
        }
    }
    Ok(())
}

/// Returns `ShapeMismatch` for `field` unless `actual == expected`.
fn check_dim_eq(
    site: usize,
    expected: usize,
    actual: usize,
    field: &'static str,
) -> Result<(), DmrgHeffError> {
    if expected == actual {
        Ok(())
    } else {
        Err(DmrgHeffError::ShapeMismatch {
            site,
            field,
            expected,
            actual,
        })
    }
}

/// Returns `ShapeMismatch` for `field` (reporting `min` as `expected`) unless
/// `actual >= min`.
fn check_dim_at_least(
    site: usize,
    min: usize,
    actual: usize,
    field: &'static str,
) -> Result<(), DmrgHeffError> {
    if actual >= min {
        Ok(())
    } else {
        Err(DmrgHeffError::ShapeMismatch {
            site,
            field,
            expected: min,
            actual,
        })
    }
}
/// Checks that two QN indices meeting on one leg are compatible.
///
/// The directions must differ when `opposite_direction` is `true` and match
/// when it is `false`. The sector lists returned by `blocks()` must be
/// identical.
///
/// # Errors
///
/// Returns `DmrgHeffError::QnMismatch` tagged with `site` and `field`. The
/// direction test runs first, so a direction failure hides a sector failure.
fn check_qn_pair<S: Sector>(
    site: usize,
    field: &'static str,
    lhs: &QNIndex<S>,
    rhs: &QNIndex<S>,
    opposite_direction: bool,
) -> Result<(), DmrgHeffError> {
    let lhs_dir = lhs.direction();
    let rhs_dir = rhs.direction();
    // Equal directions are a failure exactly when opposite ones were required.
    let same_dir = lhs_dir == rhs_dir;
    if same_dir == opposite_direction {
        let requirement = match opposite_direction {
            true => "must be opposite",
            false => "must be equal",
        };
        return Err(DmrgHeffError::QnMismatch {
            site,
            field,
            detail: format!(
                "directions {:?} vs {:?} ({})",
                lhs_dir, rhs_dir, requirement
            ),
        });
    }

    let (lhs_blocks, rhs_blocks) = (lhs.blocks(), rhs.blocks());
    if lhs_blocks == rhs_blocks {
        return Ok(());
    }
    Err(DmrgHeffError::QnMismatch {
        site,
        field,
        detail: format!(
            "sector lists {:?} vs {:?} (must match by sector + per-sector dim)",
            lhs_blocks, rhs_blocks
        ),
    })
}
/// Returns `true` when `flux` equals the sector produced by `S::identity()`.
fn is_identity_flux<S: Sector>(flux: &S) -> bool {
    let identity = S::identity();
    *flux == identity
}