1use ariadnetor_core::Scalar;
28use ariadnetor_linalg::{TruncSvdParams, contract, trunc_svd};
29use ariadnetor_mps::{Mpo, Mps, TensorChain};
30use ariadnetor_tensor::{DenseTensor, Host};
31
32#[cfg(feature = "arpack")]
33use crate::krylov::arpack_smallest;
34use crate::krylov::{LinearOp, lanczos_smallest};
35
36use super::heff_error::DmrgHeffError;
37use super::solver::{
38 DmrgScalar, LocalEigensolverParams, eigensolver_tol, validate_eigensolver_params,
39};
40use ariadnetor_mps::BraketEnvs;
41
/// Matrix-free two-site effective Hamiltonian for the window `[i, i + 1]`.
///
/// Holds borrowed references to the left environment, the two MPO site tensors
/// and the right environment. It is handed to the local eigensolver as a
/// [`LinearOp`]: every `apply` call contracts a flattened two-site vector with
/// these four tensors, and no dense matrix is ever formed.
#[derive(Debug, Clone)]
pub(crate) struct EffectiveHamiltonian2Site<'a, T: Scalar> {
    /// Left environment, rank 3; `new` debug-asserts shape `[chi_l, W_l, chi_l]`,
    /// where `W_l == w_i.shape()[0]`.
    left: &'a DenseTensor<T>,
    /// MPO tensor at site `i`, rank 4; debug-asserted as `[W_l, d_i, d_i, W_mid]`.
    w_i: &'a DenseTensor<T>,
    /// MPO tensor at site `i + 1`, rank 4; debug-asserted as `[W_mid, d_ip1, d_ip1, W_r]`.
    w_ip1: &'a DenseTensor<T>,
    /// Right environment, rank 3; debug-asserted shape `[chi_r, W_r, chi_r]`,
    /// where `W_r == w_ip1.shape()[3]`.
    right: &'a DenseTensor<T>,
    /// Left bond dimension of the two-site tensor.
    chi_l: usize,
    /// Physical dimension at site `i`.
    d_i: usize,
    /// Physical dimension at site `i + 1`.
    d_ip1: usize,
    /// Right bond dimension of the two-site tensor.
    chi_r: usize,
}
56
57impl<'a, T: Scalar> EffectiveHamiltonian2Site<'a, T> {
58 #[allow(clippy::too_many_arguments)]
61 pub(crate) fn new(
62 left: &'a DenseTensor<T>,
63 w_i: &'a DenseTensor<T>,
64 w_ip1: &'a DenseTensor<T>,
65 right: &'a DenseTensor<T>,
66 chi_l: usize,
67 d_i: usize,
68 d_ip1: usize,
69 chi_r: usize,
70 ) -> Self {
71 debug_assert_eq!(left.shape().len(), 3, "left.rank == 3");
72 debug_assert_eq!(right.shape().len(), 3, "right.rank == 3");
73 debug_assert_eq!(w_i.shape().len(), 4, "W[i].rank == 4");
74 debug_assert_eq!(w_ip1.shape().len(), 4, "W[i+1].rank == 4");
75 debug_assert!(
76 chi_l > 0 && d_i > 0 && d_ip1 > 0 && chi_r > 0,
77 "heff: dims must be > 0"
78 );
79 debug_assert_eq!(
80 left.shape(),
81 &[chi_l, w_i.shape()[0], chi_l],
82 "left env shape"
83 );
84 debug_assert_eq!(
85 right.shape(),
86 &[chi_r, w_ip1.shape()[3], chi_r],
87 "right env shape"
88 );
89 debug_assert_eq!(w_i.shape()[1], d_i, "W[i] d_ket / d_i");
90 debug_assert_eq!(w_i.shape()[2], d_i, "W[i] d_bra / d_i");
91 debug_assert_eq!(w_ip1.shape()[1], d_ip1, "W[i+1] d_ket / d_ip1");
92 debug_assert_eq!(w_ip1.shape()[2], d_ip1, "W[i+1] d_bra / d_ip1");
93 debug_assert_eq!(w_i.shape()[3], w_ip1.shape()[0], "W bond W_mid agreement");
94 Self {
95 left,
96 w_i,
97 w_ip1,
98 right,
99 chi_l,
100 d_i,
101 d_ip1,
102 chi_r,
103 }
104 }
105
106 pub(crate) fn dim(&self) -> usize {
108 self.chi_l * self.d_i * self.d_ip1 * self.chi_r
109 }
110}
111
112impl<'a, T: Scalar> LinearOp<T> for EffectiveHamiltonian2Site<'a, T> {
113 fn apply(&self, v: &DenseTensor<T>) -> DenseTensor<T> {
114 let backend = Host::shared();
119 let psi = v.reshape(vec![self.chi_l, self.d_i, self.d_ip1, self.chi_r]);
120
121 let tmp1 = contract(backend.as_ref(), self.left, &psi, "abc,cijf->abijf")
122 .expect("heff matvec step 1: shape pre-validated");
123 let tmp2 = contract(backend.as_ref(), &tmp1, self.w_i, "abijf,bism->asmjf")
124 .expect("heff matvec step 2: shape pre-validated");
125 let tmp3 = contract(backend.as_ref(), &tmp2, self.w_ip1, "asmjf,mjtg->astgf")
126 .expect("heff matvec step 3: shape pre-validated");
127 let out = contract(backend.as_ref(), &tmp3, self.right, "astgf,hgf->asth")
128 .expect("heff matvec step 4: shape pre-validated");
129
130 out.reshape(vec![self.dim()])
131 }
132}
133
/// Outcome of one two-site DMRG update, as built by `dmrg_2site_step`.
///
/// Carries the local eigensolver's report plus the truncated-SVD factors of the
/// optimised two-site tensor. `dmrg_2site_step` only borrows the MPS immutably,
/// so these factors have not been written back into it.
#[derive(Debug, Clone)]
pub struct TwoSiteStepResult<T: Scalar> {
    /// Eigenvalue returned by the local eigensolver (`lanczos_smallest`, or
    /// `arpack_smallest` when the `arpack` feature is enabled).
    pub eigenvalue: T::Real,
    /// Residual reported by the eigensolver for `eigenvalue`.
    pub residual: T::Real,
    /// Iteration count reported by the eigensolver.
    pub iters: usize,
    /// Convergence flag reported by the Lanczos solver. It is always `true`
    /// when the ARPACK solver was used.
    pub converged: bool,
    /// Left SVD factor, with its first leg split into `[chi_l, d_i]`. Presumably
    /// the shape is `[chi_l, d_i, chi_new]`; confirm against the `trunc_svd`
    /// output layout.
    pub u: DenseTensor<T>,
    /// Singular values returned by the truncated SVD (real-valued).
    pub s: DenseTensor<T::Real>,
    /// Right SVD factor, with its second leg split into `[d_ip1, chi_r]`.
    /// Presumably the shape is `[chi_new, d_ip1, chi_r]`; confirm against
    /// `trunc_svd`.
    pub vt: DenseTensor<T>,
    /// Truncation error reported by `trunc_svd`.
    pub trunc_err: T::Real,
}
156
157pub(crate) fn dmrg_2site_step<T>(
159 envs: &BraketEnvs<ariadnetor_tensor::DenseStorage<T>, ariadnetor_tensor::DenseLayout>,
160 mps: &Mps<ariadnetor_tensor::DenseStorage<T>, ariadnetor_tensor::DenseLayout>,
161 mpo: &Mpo<ariadnetor_tensor::DenseStorage<T>, ariadnetor_tensor::DenseLayout>,
162 site: usize,
163 eigensolver: &LocalEigensolverParams,
164 trunc: &TruncSvdParams,
165) -> Result<TwoSiteStepResult<T>, DmrgHeffError>
166where
167 T: DmrgScalar,
168 T::Real: Scalar<Real = T::Real>,
169{
170 let n_sites = envs.n_sites();
171 if mps.len() != n_sites || mpo.len() != n_sites {
172 return Err(DmrgHeffError::LengthMismatch {
173 mps: mps.len(),
174 mpo: mpo.len(),
175 envs: n_sites,
176 });
177 }
178 if site >= n_sites.saturating_sub(1) {
179 return Err(DmrgHeffError::InvalidSite { site, n_sites });
180 }
181
182 validate_eigensolver_params(eigensolver)
183 .map_err(|detail| DmrgHeffError::InvalidEigensolverParams { detail })?;
184 if crate::numeric::try_real_from_f64::<T>(eigensolver_tol(eigensolver)).is_none() {
185 return Err(DmrgHeffError::InvalidEigensolverParams {
186 detail: "tol is not representable in T::Real",
187 });
188 }
189
190 let left = envs.left(site).ok_or(DmrgHeffError::StaleEnv {
191 side: "left",
192 index: site,
193 })?;
194 let right = envs.right(site + 2).ok_or(DmrgHeffError::StaleEnv {
195 side: "right",
196 index: site + 2,
197 })?;
198 let w_i = mpo.site(site);
199 let w_ip1 = mpo.site(site + 1);
200 let mps_i = mps.site(site);
201 let mps_ip1 = mps.site(site + 1);
202
203 let check_eq =
204 |expected: usize, actual: usize, field: &'static str| -> Result<(), DmrgHeffError> {
205 if expected == actual {
206 Ok(())
207 } else {
208 Err(DmrgHeffError::ShapeMismatch {
209 site,
210 field,
211 expected,
212 actual,
213 })
214 }
215 };
216 let check_at_least =
217 |min: usize, actual: usize, field: &'static str| -> Result<(), DmrgHeffError> {
218 if actual >= min {
219 Ok(())
220 } else {
221 Err(DmrgHeffError::ShapeMismatch {
222 site,
223 field,
224 expected: min,
225 actual,
226 })
227 }
228 };
229 check_eq(3, left.shape().len(), "left.rank")?;
230 check_eq(3, right.shape().len(), "right.rank")?;
231 check_eq(4, w_i.shape().len(), "W[i].rank")?;
232 check_eq(4, w_ip1.shape().len(), "W[i+1].rank")?;
233 check_eq(3, mps_i.shape().len(), "MPS[i].rank")?;
234 check_eq(3, mps_ip1.shape().len(), "MPS[i+1].rank")?;
235
236 let chi_l = left.shape()[0];
237 let chi_r = right.shape()[0];
238 let d_i = mps_i.shape()[1];
239 let d_ip1 = mps_ip1.shape()[1];
240
241 check_at_least(1, chi_l, "chi_l (left bond)")?;
242 check_at_least(1, chi_r, "chi_r (right bond)")?;
243 check_at_least(1, d_i, "d_i (MPS[i] physical)")?;
244 check_at_least(1, d_ip1, "d_ip1 (MPS[i+1] physical)")?;
245 check_at_least(1, w_i.shape()[0], "W[i].W_l")?;
246 check_at_least(1, w_i.shape()[3], "W[i].W_r (= W_mid)")?;
247 check_at_least(1, w_ip1.shape()[3], "W[i+1].W_r")?;
248
249 check_eq(
250 left.shape()[2],
251 mps_i.shape()[0],
252 "left.bot_ket vs MPS[i].left_bond",
253 )?;
254 check_eq(
255 left.shape()[2],
256 chi_l,
257 "left.bot_ket vs left.top_bra (bra=ket)",
258 )?;
259 check_eq(
260 right.shape()[2],
261 mps_ip1.shape()[2],
262 "right.bot_ket vs MPS[i+1].right_bond",
263 )?;
264 check_eq(
265 right.shape()[2],
266 chi_r,
267 "right.bot_ket vs right.top_bra (bra=ket)",
268 )?;
269 check_eq(left.shape()[1], w_i.shape()[0], "left.W_bond vs W[i].W_l")?;
270 check_eq(
271 right.shape()[1],
272 w_ip1.shape()[3],
273 "right.W_bond vs W[i+1].W_r",
274 )?;
275 check_eq(
276 w_i.shape()[3],
277 w_ip1.shape()[0],
278 "W[i].W_r vs W[i+1].W_l (W_mid)",
279 )?;
280 check_eq(w_i.shape()[1], d_i, "W[i].d_ket vs MPS[i].physical")?;
281 check_eq(w_i.shape()[2], d_i, "W[i].d_bra vs MPS[i].physical")?;
282 check_eq(w_ip1.shape()[1], d_ip1, "W[i+1].d_ket vs MPS[i+1].physical")?;
283 check_eq(w_ip1.shape()[2], d_ip1, "W[i+1].d_bra vs MPS[i+1].physical")?;
284
285 let heff = EffectiveHamiltonian2Site::new(left, w_i, w_ip1, right, chi_l, d_i, d_ip1, chi_r);
286
287 let dim = heff.dim();
288 let (eigenvalue, eigenvector, iters, converged, residual) = match eigensolver {
291 LocalEigensolverParams::Lanczos(p) => {
292 let lan = lanczos_smallest::<T, _>(&heff, dim, p)?;
293 (
294 lan.eigenvalue,
295 lan.eigenvector,
296 lan.iters,
297 lan.converged,
298 lan.residual,
299 )
300 }
301 #[cfg(feature = "arpack")]
302 LocalEigensolverParams::Arpack(p) => {
303 let res = arpack_smallest::<T, _>(&heff, dim, p)?;
304 (
305 res.eigenvalue,
306 res.eigenvector,
307 res.iters,
308 true,
309 res.residual,
310 )
311 }
312 };
313
314 let psi_4d = eigenvector.reshape(vec![chi_l, d_i, d_ip1, chi_r]);
315 let (u_2d, s, vt_2d, trunc_err) = trunc_svd(Host::shared().as_ref(), &psi_4d, 2, trunc)?;
316
317 let chi_new = u_2d.shape()[1];
318 debug_assert_eq!(vt_2d.shape()[0], chi_new, "U/Vt new bond dim agreement");
319
320 let u = u_2d.split_leg(0, &[chi_l, d_i]);
324 let vt = vt_2d.split_leg(1, &[d_ip1, chi_r]);
325
326 Ok(TwoSiteStepResult {
327 eigenvalue,
328 residual,
329 iters,
330 converged,
331 u,
332 s,
333 vt,
334 trunc_err,
335 })
336}
337
338#[cfg(test)]
339mod tests;