use ariadnetor_core::Scalar;
use ariadnetor_linalg::{LinalgError, permute_block_sparse_with_backend, tensordot};
use ariadnetor_tensor::{
BlockCoord, BlockSparseLayout, BlockSparseStorage, BlockSparseTensor, Direction, Host, QNIndex,
Sector,
};
use super::env::{DmrgEnvError, DmrgEnvOps, DmrgEnvs};
/// Returns the opposite leg direction: `In` becomes `Out` and `Out` becomes `In`.
fn flip(d: Direction) -> Direction {
    match d {
        Direction::In => Direction::Out,
        Direction::Out => Direction::In,
    }
}
/// Exchanges axes 0 and 2 of a rank-3 block-sparse tensor, leaving axis 1 in place.
///
/// # Panics
/// Panics if the host-backend permutation fails. The permutation is a valid
/// rank-3 reordering, so that should not happen for rank-3 input. Rank is only
/// checked by a `debug_assert_eq!`.
fn swap_axes_0_and_2<T, S>(t: &BlockSparseTensor<T, S>) -> BlockSparseTensor<T, S>
where
    T: Scalar,
    S: Sector,
{
    debug_assert_eq!(t.rank(), 3, "swap_axes_0_and_2 expects rank-3");
    let host = Host::shared();
    // For three axes, reversing the order is exactly the (0 2) transposition.
    let permuted = permute_block_sparse_with_backend(host.as_ref(), t, &[2, 1, 0]);
    permuted.expect("valid rank-3 permutation on a same-handle intermediate")
}
/// Verifies that `idx` consists of exactly one block, and that this block has dimension 1.
///
/// # Errors
/// Returns `DmrgEnvError::MalformedEdgeBond` tagged with `leg` when the index
/// has zero or several blocks, or when its single block is not one-dimensional.
fn check_dim1_single_sector<S: Sector>(
    idx: &QNIndex<S>,
    leg: &'static str,
) -> Result<(), DmrgEnvError> {
    // The short-circuit ensures `block_dim(0)` is only queried when block 0 exists.
    let single_unit_block = idx.num_blocks() == 1 && idx.block_dim(0) == 1;
    if !single_unit_block {
        return Err(DmrgEnvError::MalformedEdgeBond { leg });
    }
    Ok(())
}
/// Builds a rank-3 boundary environment whose only stored entry is a one.
///
/// Leg layout, in order:
/// 0. the blocks of `mps_edge`, with `mps_edge`'s direction;
/// 1. the blocks of `mpo_edge`, with the flipped direction;
/// 2. the blocks of `mps_edge`, with the flipped direction.
///
/// The tensor is created with total charge `S::identity()`, and the entry at
/// position 0 of block `(0, 0, 0)` is set to `T::one()`.
///
/// # Errors
/// `DmrgEnvError::MalformedEdgeBond` is returned in three cases:
/// - tagged `mps_leg_name` when `mps_edge` is not a single dimension-1 block;
/// - tagged `mpo_leg_name` when `mpo_edge` is not a single dimension-1 block;
/// - tagged `mpo_leg_name` when block `(0, 0, 0)` is absent from the zero tensor.
fn build_boundary<T, S>(
    mps_edge: &QNIndex<S>,
    mpo_edge: &QNIndex<S>,
    mps_leg_name: &'static str,
    mpo_leg_name: &'static str,
) -> Result<BlockSparseTensor<T, S>, DmrgEnvError>
where
    T: Scalar,
    S: Sector,
{
    check_dim1_single_sector(mps_edge, mps_leg_name)?;
    check_dim1_single_sector(mpo_edge, mpo_leg_name)?;

    let legs = vec![
        QNIndex::new(mps_edge.blocks().to_vec(), mps_edge.direction()),
        QNIndex::new(mpo_edge.blocks().to_vec(), flip(mpo_edge.direction())),
        QNIndex::new(mps_edge.blocks().to_vec(), flip(mps_edge.direction())),
    ];
    let mut env = BlockSparseTensor::<T, S>::zeros(legs, S::identity());

    // Every leg was checked to hold exactly one block, so (0, 0, 0) is the only
    // candidate. It can still be missing from the zero tensor (presumably when
    // the leg charges do not fuse to `S::identity()`). That case is reported
    // against the MPO leg.
    let origin = BlockCoord(vec![0, 0, 0]);
    if let Some(slot) = env.block_data_mut(&origin) {
        slot[0] = T::one();
    } else {
        return Err(DmrgEnvError::MalformedEdgeBond { leg: mpo_leg_name });
    }
    Ok(env)
}
impl<T, S> DmrgEnvOps<T> for DmrgEnvs<BlockSparseStorage<T>, BlockSparseLayout<S>>
where
T: Scalar,
S: Sector,
{
type Layout = BlockSparseLayout<S>;
type Storage = BlockSparseStorage<T>;
fn trivial_left_boundary(
mps_left_edge: &BlockSparseTensor<T, S>,
mpo_left_edge: &BlockSparseTensor<T, S>,
) -> Result<BlockSparseTensor<T, S>, DmrgEnvError> {
build_boundary(
&mps_left_edge.indices()[0],
&mpo_left_edge.indices()[0],
"mps_left",
"mpo_left",
)
}
fn trivial_right_boundary(
mps_right_edge: &BlockSparseTensor<T, S>,
mpo_right_edge: &BlockSparseTensor<T, S>,
) -> Result<BlockSparseTensor<T, S>, DmrgEnvError> {
build_boundary(
&mps_right_edge.indices()[2],
&mpo_right_edge.indices()[3],
"mps_right",
"mpo_right",
)
}
fn extend_left_step(
env: &BlockSparseTensor<T, S>,
site: &BlockSparseTensor<T, S>,
mpo_site: &BlockSparseTensor<T, S>,
) -> Result<BlockSparseTensor<T, S>, LinalgError> {
let backend = Host::shared();
let bra = site.dagger();
let t1 = tensordot(backend.as_ref(), env, &bra, &[0], &[0])?;
let t2 = tensordot(backend.as_ref(), &t1, mpo_site, &[0, 2], &[0, 2])?;
let env_new = tensordot(backend.as_ref(), &t2, site, &[0, 2], &[0, 1])?;
Ok(env_new)
}
fn extend_right_step(
env: &BlockSparseTensor<T, S>,
site: &BlockSparseTensor<T, S>,
mpo_site: &BlockSparseTensor<T, S>,
) -> Result<BlockSparseTensor<T, S>, LinalgError> {
let backend = Host::shared();
let bra = site.dagger();
let t1 = tensordot(backend.as_ref(), env, site, &[2], &[2])?;
let t2 = tensordot(backend.as_ref(), &t1, mpo_site, &[1, 3], &[3, 1])?;
let env_raw = tensordot(backend.as_ref(), &t2, &bra, &[0, 3], &[2, 1])?;
Ok(swap_axes_0_and_2(&env_raw))
}
}