Skip to main content

ariadnetor_algorithms/dmrg/
dispatch.rs

1//! Chain-keyed dispatch for the 2-site DMRG sweep driver.
2//!
3//! [`DmrgOps`] is keyed on the [`Mps`] chain (sharing its [`MpsOps`]
4//! supertrait so the storage / layout taxa come from one place) and lets
5//! [`super::sweep::sweep_2site`] be written once over both the Dense and
6//! BlockSparse paths. Each implementation is a thin delegation to the
7//! existing storage-specific kernels, with no logic duplication. Sealed
8//! transitively through `MpsOps` (itself sealed), so it cannot be
9//! implemented downstream.
10//!
11//! The per-storage step result (`TwoSiteStepResult` /
12//! `TwoSiteStepResultBlockSparse`) diverges in its `s` field and cannot
13//! collapse to one chain-agnostic struct. Rather than expose it as a `pub`
14//! associated type, [`DmrgOps::full_step_k`] runs the whole 2-site step
15//! (solve, project diagnostics, absorb `S`) in one call, keeping the step
16//! result an impl-internal local and returning only public types.
17
18use ariadnetor_core::Scalar;
19use ariadnetor_linalg::{LinalgError, TruncSvdParams, diagonal_scale};
20use ariadnetor_mps::{Mpo, Mps, MpsOps};
21use ariadnetor_tensor::{
22    BlockSparseLayout, BlockSparseStorage, DenseLayout, DenseStorage, Host, Sector, Storage,
23    StorageFor, Tensor, TensorLayout,
24};
25
26use super::env::DmrgEnvs;
27use super::heff::dmrg_2site_step;
28use super::heff_block_sparse::dmrg_2site_step_block_sparse;
29use super::heff_error::DmrgHeffError;
30use super::solver::{DmrgScalar, LocalEigensolverParams};
31use super::sweep::SweepDirection;
32
33/// Post-absorb site tensors + new bond dimension returned by
34/// [`DmrgOps::full_step_k`], paired there with the diagnostic scalars.
35///
36/// The matching diagnostics (eigenvalue / residual / trunc-err / iters /
37/// converged) are typed on `T::Real` and returned alongside this struct
38/// rather than held on it; folding them in would force a third scalar
39/// parameter, so the struct stays keyed on `(St, L)` only.
40pub struct AbsorbedStep<St, L>
41where
42    St: Storage + StorageFor<L>,
43    L: TensorLayout,
44{
45    /// Post-absorb tensor to write into MPS site `i`.
46    pub site_i: Tensor<St, L>,
47    /// Post-absorb tensor to write into MPS site `i + 1`.
48    pub site_ip1: Tensor<St, L>,
49    /// Bond dimension of the new shared bond between sites `i` and
50    /// `i + 1`. For BlockSparse / U(1), summed over retained sectors.
51    pub bond_dim: usize,
52}
53
/// Scalar diagnostics read off a 2-site step, laid out positionally as
/// `(eigenvalue, residual, trunc_err, iters, converged)`.
type StepDiagnostics<Real> = (Real, Real, Real, usize, bool);
57
58/// Successful / failed outcome of [`DmrgOps::full_step_k`]: the absorbed
59/// site tensors plus diagnostics, or a [`FullStepError`].
60type FullStepOutput<St, L, R> = Result<(AbsorbedStep<St, L>, StepDiagnostics<R>), FullStepError>;
61
62/// Error raised by [`DmrgOps::full_step_k`], distinguishing the local
63/// eigensolve failure from the post-solve S-absorb failure so the sweep
64/// driver can keep its `Step` / `Scale` breadcrumbs.
65#[derive(Debug, thiserror::Error)]
66#[non_exhaustive]
67pub enum FullStepError {
68    /// The local effective-Hamiltonian solve (`dmrg_2site_step` /
69    /// `dmrg_2site_step_block_sparse`) failed.
70    #[error("DMRG local 2-site step failed")]
71    Heff(#[source] DmrgHeffError),
72    /// The post-solve S-absorb (`diagonal_scale`) failed.
73    #[error("DMRG S-absorb (diagonal scale) failed")]
74    Scale(#[source] LinalgError),
75}
76
77/// Chain-keyed dispatch trait for the 2-site DMRG sweep driver.
78///
79/// Keyed on the [`Mps`] chain with [`MpsOps<T>`] as supertrait (same
80/// `Self`), so the storage / layout taxa are the chain's own via
81/// [`MpsOps::Storage`] / [`MpsOps::Layout`]. The env subsystem shares that
82/// storage flavor; the sweep driver binds [`super::env::DmrgEnvOps`] on
83/// the matching [`DmrgEnvs`] chain at its call site, so the
84/// storage-coincidence the former layout-keyed super-bound expressed is
85/// now carried by the two chains sharing `(St, L)`.
86pub trait DmrgOps<T: Scalar>: MpsOps<T> {
87    /// Build local `H_eff`, drive the local eigensolver, project the
88    /// `T::Real` diagnostics, and absorb `S` — fused into one step
89    /// returning only public types. The order is preserved: solve →
90    /// project diagnostics → absorb, so the scalars are read before the
91    /// step result is consumed.
92    ///
93    /// Host-pinned: the explicit-backend scaling path runs on
94    /// [`Host::shared`], so the method takes no backend at the call site.
95    /// DMRG is host-pinned in the CPU-only Stage B scope.
96    fn full_step_k(
97        &self,
98        envs: &DmrgEnvs<<Self as MpsOps<T>>::Storage, <Self as MpsOps<T>>::Layout>,
99        mpo: &Mpo<<Self as MpsOps<T>>::Storage, <Self as MpsOps<T>>::Layout>,
100        site: usize,
101        eigensolver: &LocalEigensolverParams,
102        trunc: &TruncSvdParams,
103        direction: SweepDirection,
104    ) -> FullStepOutput<<Self as MpsOps<T>>::Storage, <Self as MpsOps<T>>::Layout, T::Real>;
105}
106
107// ---------------------------------------------------------------------------
108// Dense implementation
109// ---------------------------------------------------------------------------
110
111impl<T> DmrgOps<T> for Mps<DenseStorage<T>, DenseLayout>
112where
113    T: DmrgScalar,
114    T::Real: Scalar<Real = T::Real>,
115{
116    fn full_step_k(
117        &self,
118        envs: &DmrgEnvs<DenseStorage<T>, DenseLayout>,
119        mpo: &Mpo<DenseStorage<T>, DenseLayout>,
120        site: usize,
121        eigensolver: &LocalEigensolverParams,
122        trunc: &TruncSvdParams,
123        direction: SweepDirection,
124    ) -> FullStepOutput<DenseStorage<T>, DenseLayout, T::Real> {
125        let result = dmrg_2site_step(envs, self, mpo, site, eigensolver, trunc)
126            .map_err(FullStepError::Heff)?;
127        // Project diagnostics before the result is consumed by the absorb.
128        let diagnostics = (
129            result.eigenvalue,
130            result.residual,
131            result.trunc_err,
132            result.iters,
133            result.converged,
134        );
135        let backend = Host::shared();
136        let bond_dim = result.s.shape()[0];
137        let (site_i, site_ip1) = match direction {
138            SweepDirection::LeftToRight => {
139                let s_vt = diagonal_scale(backend.as_ref(), &result.vt, result.s.data_slice(), 0)
140                    .map_err(FullStepError::Scale)?;
141                (result.u, s_vt)
142            }
143            SweepDirection::RightToLeft => {
144                let u_s = diagonal_scale(backend.as_ref(), &result.u, result.s.data_slice(), 2)
145                    .map_err(FullStepError::Scale)?;
146                (u_s, result.vt)
147            }
148        };
149        Ok((
150            AbsorbedStep {
151                site_i,
152                site_ip1,
153                bond_dim,
154            },
155            diagnostics,
156        ))
157    }
158}
159
160// ---------------------------------------------------------------------------
161// BlockSparse implementation
162// ---------------------------------------------------------------------------
163
164impl<T, S> DmrgOps<T> for Mps<BlockSparseStorage<T>, BlockSparseLayout<S>>
165where
166    T: DmrgScalar,
167    T::Real: Scalar<Real = T::Real>,
168    S: Sector,
169{
170    fn full_step_k(
171        &self,
172        envs: &DmrgEnvs<BlockSparseStorage<T>, BlockSparseLayout<S>>,
173        mpo: &Mpo<BlockSparseStorage<T>, BlockSparseLayout<S>>,
174        site: usize,
175        eigensolver: &LocalEigensolverParams,
176        trunc: &TruncSvdParams,
177        direction: SweepDirection,
178    ) -> FullStepOutput<BlockSparseStorage<T>, BlockSparseLayout<S>, T::Real> {
179        let result = dmrg_2site_step_block_sparse(envs, self, mpo, site, eigensolver, trunc)
180            .map_err(FullStepError::Heff)?;
181        let diagnostics = (
182            result.eigenvalue,
183            result.residual,
184            result.trunc_err,
185            result.iters,
186            result.converged,
187        );
188        let backend = Host::shared();
189        // Total post-truncation singular values across sectors — the
190        // conventional U(1) MPS bond dimension.
191        let bond_dim: usize = result.s.values.iter().map(|(_, v)| v.len()).sum();
192        let (site_i, site_ip1) = match direction {
193            SweepDirection::LeftToRight => {
194                let s_vt = diagonal_scale(backend.as_ref(), &result.vt, &result.s, 0)
195                    .map_err(FullStepError::Scale)?;
196                (result.u, s_vt)
197            }
198            SweepDirection::RightToLeft => {
199                let u_s = diagonal_scale(backend.as_ref(), &result.u, &result.s, 2)
200                    .map_err(FullStepError::Scale)?;
201                (u_s, result.vt)
202            }
203        };
204        Ok((
205            AbsorbedStep {
206                site_i,
207                site_ip1,
208                bond_dim,
209            },
210            diagnostics,
211        ))
212    }
213}