ariadnetor_algorithms/dmrg/dispatch.rs
1//! Chain-keyed dispatch for the 2-site DMRG sweep driver.
2//!
3//! [`DmrgOps`] is keyed on the [`Mps`] chain (sharing its [`MpsOps`]
4//! supertrait so the storage / layout taxa come from one place) and lets
5//! [`super::sweep::sweep_2site`] be written once over both the Dense and
6//! BlockSparse paths. Each implementation is a thin delegation to the
7//! existing storage-specific kernels, with no logic duplication. Sealed
8//! transitively through `MpsOps` (itself sealed), so it cannot be
9//! implemented downstream.
10//!
11//! The per-storage step result (`TwoSiteStepResult` /
12//! `TwoSiteStepResultBlockSparse`) diverges in its `s` field and cannot
13//! collapse to one chain-agnostic struct. Rather than expose it as a `pub`
14//! associated type, [`DmrgOps::full_step_k`] runs the whole 2-site step
15//! (solve, project diagnostics, absorb `S`) in one call, keeping the step
16//! result an impl-internal local and returning only public types.
17
18use ariadnetor_core::Scalar;
19use ariadnetor_linalg::{LinalgError, TruncSvdParams, diagonal_scale};
20use ariadnetor_mps::{Mpo, Mps, MpsOps};
21use ariadnetor_tensor::{
22 BlockSparseLayout, BlockSparseStorage, DenseLayout, DenseStorage, Host, Sector, Storage,
23 StorageFor, Tensor, TensorLayout,
24};
25
26use super::heff::dmrg_2site_step;
27use super::heff_block_sparse::dmrg_2site_step_block_sparse;
28use super::heff_error::DmrgHeffError;
29use super::solver::{DmrgScalar, LocalEigensolverParams};
30use super::sweep::SweepDirection;
31use ariadnetor_mps::BraketEnvs;
32
33/// Post-absorb site tensors + new bond dimension returned by
34/// [`DmrgOps::full_step_k`], paired there with the diagnostic scalars.
35///
36/// The matching diagnostics (eigenvalue / residual / trunc-err / iters /
37/// converged) are typed on `T::Real` and returned alongside this struct
38/// rather than held on it; folding them in would force a third scalar
39/// parameter, so the struct stays keyed on `(St, L)` only.
40pub struct AbsorbedStep<St, L>
41where
42 St: Storage + StorageFor<L>,
43 L: TensorLayout,
44{
45 /// Post-absorb tensor to write into MPS site `i`.
46 pub site_i: Tensor<St, L>,
47 /// Post-absorb tensor to write into MPS site `i + 1`.
48 pub site_ip1: Tensor<St, L>,
49 /// Bond dimension of the new shared bond between sites `i` and
50 /// `i + 1`. For BlockSparse / U(1), summed over retained sectors.
51 pub bond_dim: usize,
52}
53
54/// Diagnostic scalars projected from a 2-site step, in the order
55/// `(eigenvalue, residual, trunc_err, iters, converged)`.
56type StepDiagnostics<R> = (R, R, R, usize, bool);
57
58/// Successful / failed outcome of [`DmrgOps::full_step_k`]: the absorbed
59/// site tensors plus diagnostics, or a [`FullStepError`].
60type FullStepOutput<St, L, R> = Result<(AbsorbedStep<St, L>, StepDiagnostics<R>), FullStepError>;
61
62/// Error raised by [`DmrgOps::full_step_k`], distinguishing the local
63/// eigensolve failure from the post-solve S-absorb failure so the sweep
64/// driver can keep its `Step` / `Scale` breadcrumbs.
65#[derive(Debug, thiserror::Error)]
66#[non_exhaustive]
67pub enum FullStepError {
68 /// The local effective-Hamiltonian solve (`dmrg_2site_step` /
69 /// `dmrg_2site_step_block_sparse`) failed.
70 #[error("DMRG local 2-site step failed")]
71 Heff(#[source] DmrgHeffError),
72 /// The post-solve S-absorb (`diagonal_scale`) failed.
73 #[error("DMRG S-absorb (diagonal scale) failed")]
74 Scale(#[source] LinalgError),
75}
76
77/// Chain-keyed dispatch trait for the 2-site DMRG sweep driver.
78///
79/// Keyed on the [`Mps`] chain with [`MpsOps<T>`] as supertrait (same
80/// `Self`), so the storage / layout taxa are the chain's own via
81/// [`MpsOps::Storage`] / [`MpsOps::Layout`]. The env subsystem shares that
82/// storage flavor; the sweep driver binds [`ariadnetor_mps::BraketEnvOps`] on
83/// the matching [`BraketEnvs`] chain at its call site, so the
84/// storage-coincidence the former layout-keyed super-bound expressed is
85/// now carried by the two chains sharing `(St, L)`.
86pub trait DmrgOps<T: Scalar>: MpsOps<T> {
87 /// Build local `H_eff`, drive the local eigensolver, project the
88 /// `T::Real` diagnostics, and absorb `S` — fused into one step
89 /// returning only public types. The order is preserved: solve →
90 /// project diagnostics → absorb, so the scalars are read before the
91 /// step result is consumed.
92 ///
93 /// Host-pinned: the explicit-backend scaling path runs on
94 /// [`Host::shared`], so the method takes no backend at the call site.
95 /// DMRG is host-pinned in the CPU-only Stage B scope.
96 fn full_step_k(
97 &self,
98 envs: &BraketEnvs<<Self as MpsOps<T>>::Storage, <Self as MpsOps<T>>::Layout>,
99 mpo: &Mpo<<Self as MpsOps<T>>::Storage, <Self as MpsOps<T>>::Layout>,
100 site: usize,
101 eigensolver: &LocalEigensolverParams,
102 trunc: &TruncSvdParams,
103 direction: SweepDirection,
104 ) -> FullStepOutput<<Self as MpsOps<T>>::Storage, <Self as MpsOps<T>>::Layout, T::Real>;
105}
106
107// ---------------------------------------------------------------------------
108// Dense implementation
109// ---------------------------------------------------------------------------
110
111impl<T> DmrgOps<T> for Mps<DenseStorage<T>, DenseLayout>
112where
113 T: DmrgScalar,
114 T::Real: Scalar<Real = T::Real>,
115{
116 fn full_step_k(
117 &self,
118 envs: &BraketEnvs<DenseStorage<T>, DenseLayout>,
119 mpo: &Mpo<DenseStorage<T>, DenseLayout>,
120 site: usize,
121 eigensolver: &LocalEigensolverParams,
122 trunc: &TruncSvdParams,
123 direction: SweepDirection,
124 ) -> FullStepOutput<DenseStorage<T>, DenseLayout, T::Real> {
125 let result = dmrg_2site_step(envs, self, mpo, site, eigensolver, trunc)
126 .map_err(FullStepError::Heff)?;
127 // Project diagnostics before the result is consumed by the absorb.
128 let diagnostics = (
129 result.eigenvalue,
130 result.residual,
131 result.trunc_err,
132 result.iters,
133 result.converged,
134 );
135 let backend = Host::shared();
136 let bond_dim = result.s.shape()[0];
137 let (site_i, site_ip1) = match direction {
138 SweepDirection::LeftToRight => {
139 let s_vt = diagonal_scale(backend.as_ref(), &result.vt, result.s.data_slice(), 0)
140 .map_err(FullStepError::Scale)?;
141 (result.u, s_vt)
142 }
143 SweepDirection::RightToLeft => {
144 let u_s = diagonal_scale(backend.as_ref(), &result.u, result.s.data_slice(), 2)
145 .map_err(FullStepError::Scale)?;
146 (u_s, result.vt)
147 }
148 };
149 Ok((
150 AbsorbedStep {
151 site_i,
152 site_ip1,
153 bond_dim,
154 },
155 diagnostics,
156 ))
157 }
158}
159
160// ---------------------------------------------------------------------------
161// BlockSparse implementation
162// ---------------------------------------------------------------------------
163
164impl<T, S> DmrgOps<T> for Mps<BlockSparseStorage<T>, BlockSparseLayout<S>>
165where
166 T: DmrgScalar,
167 T::Real: Scalar<Real = T::Real>,
168 S: Sector,
169{
170 fn full_step_k(
171 &self,
172 envs: &BraketEnvs<BlockSparseStorage<T>, BlockSparseLayout<S>>,
173 mpo: &Mpo<BlockSparseStorage<T>, BlockSparseLayout<S>>,
174 site: usize,
175 eigensolver: &LocalEigensolverParams,
176 trunc: &TruncSvdParams,
177 direction: SweepDirection,
178 ) -> FullStepOutput<BlockSparseStorage<T>, BlockSparseLayout<S>, T::Real> {
179 let result = dmrg_2site_step_block_sparse(envs, self, mpo, site, eigensolver, trunc)
180 .map_err(FullStepError::Heff)?;
181 let diagnostics = (
182 result.eigenvalue,
183 result.residual,
184 result.trunc_err,
185 result.iters,
186 result.converged,
187 );
188 let backend = Host::shared();
189 // Total post-truncation singular values across sectors — the
190 // conventional U(1) MPS bond dimension.
191 let bond_dim: usize = result.s.values.iter().map(|(_, v)| v.len()).sum();
192 let (site_i, site_ip1) = match direction {
193 SweepDirection::LeftToRight => {
194 let s_vt = diagonal_scale(backend.as_ref(), &result.vt, &result.s, 0)
195 .map_err(FullStepError::Scale)?;
196 (result.u, s_vt)
197 }
198 SweepDirection::RightToLeft => {
199 let u_s = diagonal_scale(backend.as_ref(), &result.u, &result.s, 2)
200 .map_err(FullStepError::Scale)?;
201 (u_s, result.vt)
202 }
203 };
204 Ok((
205 AbsorbedStep {
206 site_i,
207 site_ip1,
208 bond_dim,
209 },
210 diagnostics,
211 ))
212 }
213}