use ariadnetor_core::backend::MemoryOrder;
use rand::SeedableRng;
use crate::block_sparse::*;
use crate::sector::{U1Sector, Z2Sector};
use crate::test_fixtures::{legs, out_in_legs, square_legs};
/// A zero tensor built from `square_legs` with identity flux stores exactly
/// the two diagonal blocks, both filled with zeros, and nothing off-diagonal.
#[test]
fn zeros_u1_identity_flux() {
    let sectors = vec![(U1Sector(0), 2), (U1Sector(1), 3)];
    let tensor = BlockSparseTensorData::<f64, U1Sector>::zeros(
        square_legs(sectors),
        U1Sector(0),
        MemoryOrder::RowMajor,
    );

    assert_eq!(tensor.num_blocks(), 2);
    // 13 = 2 * 2 + 3 * 3, the two diagonal blocks.
    assert_eq!(tensor.stored_len(), 13);
    assert_eq!(tensor.shape(), &[5, 5]);

    // Diagonal blocks exist and hold only zeros.
    for diagonal in [vec![0, 0], vec![1, 1]] {
        let data = tensor.block_data(&BlockCoord(diagonal)).unwrap();
        assert!(data.iter().all(|&v| v == 0.0));
    }

    // Off-diagonal blocks are not stored at all.
    for off_diagonal in [vec![0, 1], vec![1, 0]] {
        assert!(tensor.block_data(&BlockCoord(off_diagonal)).is_none());
    }
}
/// With flux `U1Sector(1)` only a single block, at coordinate `[1, 0]`, is
/// allocated for the given pair of legs.
#[test]
fn zeros_u1_nonzero_flux() {
    let first_leg = vec![(U1Sector(0), 2), (U1Sector(1), 3)];
    let second_leg = vec![(U1Sector(0), 4)];
    let tensor = BlockSparseTensorData::<f64, U1Sector>::zeros(
        out_in_legs(first_leg, second_leg),
        U1Sector(1),
        MemoryOrder::RowMajor,
    );

    assert_eq!(tensor.num_blocks(), 1);
    // 12 = 3 * 4: the dimension-3 sector paired with the dimension-4 sector.
    assert_eq!(tensor.stored_len(), 12);
    assert!(tensor.block_data(&BlockCoord(vec![1, 0])).is_some());
}
/// Z2-graded zero tensor with trivial flux: two blocks and 23 stored elements.
#[test]
fn zeros_z2() {
    let first_leg = vec![(Z2Sector::new(0), 2), (Z2Sector::new(1), 3)];
    let second_leg = vec![(Z2Sector::new(0), 4), (Z2Sector::new(1), 5)];
    let tensor = BlockSparseTensorData::<f64, Z2Sector>::zeros(
        out_in_legs(first_leg, second_leg),
        Z2Sector::new(0),
        MemoryOrder::RowMajor,
    );

    assert_eq!(tensor.num_blocks(), 2);
    // 23 = 2 * 4 + 3 * 5, i.e. presumably the two matching-parity sector pairs.
    assert_eq!(tensor.stored_len(), 23);
}
/// Rank-3 zero tensor (two `Out` legs, one `In` leg) with identity flux:
/// two blocks and 15 stored elements.
#[test]
fn zeros_rank3() {
    let leg_a = (vec![(U1Sector(0), 2), (U1Sector(1), 1)], Direction::Out);
    let leg_b = (vec![(U1Sector(0), 3)], Direction::Out);
    let leg_c = (vec![(U1Sector(0), 2), (U1Sector(1), 1)], Direction::In);

    let tensor = BlockSparseTensorData::<f64, U1Sector>::zeros(
        legs([leg_a, leg_b, leg_c]),
        U1Sector(0),
        MemoryOrder::RowMajor,
    );

    assert_eq!(tensor.num_blocks(), 2);
    // 15 = 2 * 3 * 2 + 1 * 3 * 1.
    assert_eq!(tensor.stored_len(), 15);
}
/// A rank-0 tensor with identity flux holds one block containing a single
/// zero element, addressed by the empty block coordinate.
#[test]
fn zeros_rank0_identity_flux() {
    let scalar = BlockSparseTensorData::<f64, U1Sector>::zeros(
        Vec::new(),
        U1Sector(0),
        MemoryOrder::RowMajor,
    );

    assert_eq!(scalar.rank(), 0);
    assert_eq!(scalar.shape(), &[] as &[usize]);
    assert_eq!(scalar.num_blocks(), 1);
    assert_eq!(scalar.stored_len(), 1);

    let empty_coord = BlockCoord(Vec::new());
    let data = scalar.block_data(&empty_coord).unwrap();
    assert_eq!(data, &[0.0]);
}
/// A rank-0 tensor with non-identity flux has no blocks and stores nothing.
#[test]
fn zeros_rank0_nonidentity_flux() {
    let scalar = BlockSparseTensorData::<f64, U1Sector>::zeros(
        Vec::new(),
        U1Sector(1),
        MemoryOrder::RowMajor,
    );

    assert_eq!(scalar.rank(), 0);
    assert_eq!(scalar.num_blocks(), 0);
    assert_eq!(scalar.stored_len(), 0);
}
/// When no block is allowed for the requested flux, the tensor stores no
/// data but still reports its full shape.
#[test]
fn zeros_no_allowed_blocks() {
    let first_leg = (vec![(U1Sector(0), 2)], Direction::Out);
    let second_leg = (vec![(U1Sector(0), 3)], Direction::Out);

    let tensor = BlockSparseTensorData::<f64, U1Sector>::zeros(
        legs([first_leg, second_leg]),
        U1Sector(1),
        MemoryOrder::RowMajor,
    );

    assert_eq!(tensor.num_blocks(), 0);
    assert_eq!(tensor.stored_len(), 0);
    assert_eq!(tensor.shape(), &[2, 3]);
}
/// Checks the block metadata of a two-block tensor: coordinates, sizes, and
/// offsets (the second block starts right after the first block's 4 elements).
#[test]
fn zeros_block_layout() {
    let sectors = vec![(U1Sector(0), 2), (U1Sector(1), 3)];
    let tensor = BlockSparseTensorData::<f64, U1Sector>::zeros(
        square_legs(sectors),
        U1Sector(0),
        MemoryOrder::RowMajor,
    );
    assert_eq!(tensor.num_blocks(), 2);

    let metas = tensor.block_metas();
    let (first, second) = (&metas[0], &metas[1]);

    // First block: the 2x2 sector pair at the start of storage.
    assert_eq!(first.coord, BlockCoord(vec![0, 0]));
    assert_eq!(first.size, 4);
    assert_eq!(first.offset, 0);

    // Second block: the 3x3 sector pair, placed directly after the first.
    assert_eq!(second.coord, BlockCoord(vec![1, 1]));
    assert_eq!(second.size, 9);
    assert_eq!(second.offset, 4);
}
/// Writing through `block_data_mut` changes the targeted block only; the
/// other block keeps its zeros.
#[test]
fn block_data_mut_fills_block() {
    let mut tensor = BlockSparseTensorData::<f64, U1Sector>::zeros(
        square_legs(vec![(U1Sector(0), 2), (U1Sector(1), 3)]),
        U1Sector(0),
        MemoryOrder::RowMajor,
    );
    let written = BlockCoord(vec![0, 0]);
    let untouched = BlockCoord(vec![1, 1]);

    // Fill the block with 1.0, 2.0, ... in storage order.
    tensor
        .block_data_mut(&written)
        .unwrap()
        .iter_mut()
        .enumerate()
        .for_each(|(pos, slot)| *slot = (pos + 1) as f64);

    assert_eq!(tensor.block_data(&written).unwrap(), &[1.0, 2.0, 3.0, 4.0]);
    assert!(tensor
        .block_data(&untouched)
        .unwrap()
        .iter()
        .all(|&v| v == 0.0));
}
/// `block_data_mut` yields `None` for a coordinate that has no stored block.
#[test]
fn block_data_mut_nonexistent_returns_none() {
    let sectors = vec![(U1Sector(0), 2), (U1Sector(1), 3)];
    let mut tensor = BlockSparseTensorData::<f64, U1Sector>::zeros(
        square_legs(sectors),
        U1Sector(0),
        MemoryOrder::RowMajor,
    );

    let missing = BlockCoord(vec![0, 1]);
    assert!(tensor.block_data_mut(&missing).is_none());
}
/// Mutating a tensor after it has been cloned must not be visible through
/// the clone.
#[test]
fn block_data_mut_cow_semantics() {
    let coord = BlockCoord(vec![0, 0]);
    let mut original = BlockSparseTensorData::<f64, U1Sector>::zeros(
        square_legs(vec![(U1Sector(0), 2)]),
        U1Sector(0),
        MemoryOrder::RowMajor,
    );
    let snapshot = original.clone();
    assert_eq!(snapshot.block_data(&coord).unwrap()[0], 0.0);

    // Write through the original only.
    original.block_data_mut(&coord).unwrap()[0] = 42.0;

    assert_eq!(original.block_data(&coord).unwrap()[0], 42.0);
    // The clone still sees the value from before the write.
    assert_eq!(snapshot.block_data(&coord).unwrap()[0], 0.0);
}
/// A random tensor must have the same structure as a zero tensor built from
/// the same legs and flux.
///
/// Beyond the aggregate quantities (shape, block count, stored length, flux,
/// number of legs), every block allocated by `zeros` must also exist in the
/// random tensor with the same length. Without that per-block check, a
/// different block set with the same count and total size would go unnoticed.
#[test]
fn random_matches_zeros_structure() {
    let mut rng = rand::rngs::StdRng::seed_from_u64(42);
    let indices = out_in_legs(
        vec![(U1Sector(0), 2), (U1Sector(1), 3), (U1Sector(2), 4)],
        vec![(U1Sector(0), 5), (U1Sector(1), 2)],
    );
    let zeros = BlockSparseTensorData::<f64, U1Sector>::zeros(
        indices.clone(),
        U1Sector(1),
        MemoryOrder::RowMajor,
    );
    let rand_bs = BlockSparseTensorData::<f64, U1Sector>::random(
        indices,
        U1Sector(1),
        MemoryOrder::RowMajor,
        &mut rng,
    );

    assert_eq!(rand_bs.shape(), zeros.shape());
    assert_eq!(rand_bs.num_blocks(), zeros.num_blocks());
    assert_eq!(rand_bs.stored_len(), zeros.stored_len());
    assert_eq!(rand_bs.flux(), zeros.flux());
    assert_eq!(rand_bs.indices().len(), zeros.indices().len());

    // Order-independent per-block comparison: together with the equal block
    // counts above, this pins down that both tensors hold the same block set.
    for meta in zeros.block_metas() {
        let zero_block = zeros.block_data(&meta.coord).unwrap();
        let rand_block = rand_bs
            .block_data(&meta.coord)
            .expect("random tensor is missing a block that zeros allocates");
        assert_eq!(rand_block.len(), zero_block.len());
    }
}
/// Two generators seeded identically must produce identical random tensors.
///
/// The block-by-block comparison is guarded by two extra checks: the first
/// tensor must have at least one block (otherwise the loop would pass
/// vacuously), and both tensors must report the same number of blocks
/// (otherwise extra blocks in the second tensor would go unnoticed).
#[test]
fn random_reproducible() {
    let indices = square_legs(vec![(U1Sector(0), 2), (U1Sector(1), 3)]);

    let mut rng1 = rand::rngs::StdRng::seed_from_u64(123);
    let bs1 = BlockSparseTensorData::<f64, U1Sector>::random(
        indices.clone(),
        U1Sector(0),
        MemoryOrder::RowMajor,
        &mut rng1,
    );

    let mut rng2 = rand::rngs::StdRng::seed_from_u64(123);
    let bs2 = BlockSparseTensorData::<f64, U1Sector>::random(
        indices,
        U1Sector(0),
        MemoryOrder::RowMajor,
        &mut rng2,
    );

    assert!(
        bs1.num_blocks() > 0,
        "reproducibility check needs at least one block to compare"
    );
    assert_eq!(bs1.num_blocks(), bs2.num_blocks());

    for meta in bs1.block_metas() {
        let d1 = bs1.block_data(&meta.coord).unwrap();
        let d2 = bs2.block_data(&meta.coord).unwrap();
        assert_eq!(d1, d2);
    }
}
/// A randomly filled tensor must contain at least one non-zero element in
/// one of its blocks.
#[test]
fn random_data_is_nonzero() {
    let mut rng = rand::rngs::StdRng::seed_from_u64(7);
    let sectors = vec![(U1Sector(0), 4), (U1Sector(1), 4)];
    let tensor = BlockSparseTensorData::<f64, U1Sector>::random(
        square_legs(sectors),
        U1Sector(0),
        MemoryOrder::RowMajor,
        &mut rng,
    );

    // "Some element is non-zero" expressed as "not every element is zero".
    let everything_zero = tensor.block_metas().iter().all(|meta| {
        let data = tensor.block_data(&meta.coord).unwrap();
        data.iter().all(|&v| v == 0.0)
    });
    assert!(!everything_zero);
}
/// `is_allowed_block` accepts the diagonal coordinates and rejects the
/// off-diagonal ones for a tensor with identity flux built from `square_legs`.
#[test]
fn is_allowed_block_matches_flux_conservation() {
    let sectors = vec![(U1Sector(0), 2), (U1Sector(1), 3)];
    let tensor = BlockSparseTensorData::<f64, U1Sector>::zeros(
        square_legs(sectors),
        U1Sector(0),
        MemoryOrder::RowMajor,
    );

    for allowed in [vec![0, 0], vec![1, 1]] {
        assert!(tensor.is_allowed_block(&BlockCoord(allowed)));
    }
    for forbidden in [vec![0, 1], vec![1, 0]] {
        assert!(!tensor.is_allowed_block(&BlockCoord(forbidden)));
    }
}