use std::collections::HashMap;
use ariadnetor_core::backend::MemoryOrder;
use super::{BlockCoord, BlockMeta, Direction, QNIndex};
use crate::{Sector, TensorLayout};
/// Layout of a block-sparse tensor whose blocks are selected by a sector
/// (quantum-number) constraint.
///
/// A block is identified by a [`BlockCoord`] holding one block number per
/// index. The blocks stored here are those for which fusing the
/// direction-applied sector of every index yields `flux`.
#[derive(Clone)]
pub struct BlockSparseLayout<S: Sector> {
/// Metadata (coordinate, offset, size) of every stored block.
blocks: Vec<BlockMeta>,
/// Maps a block coordinate to the position of its entry in `blocks`.
block_index: HashMap<BlockCoord, usize>,
/// One index description per tensor axis; its length is the tensor rank.
indices: Vec<QNIndex<S>>,
/// Sector that the fused, direction-applied sectors of a block must equal.
flux: S,
/// Per-axis extent, taken from each index's `total_dim()`.
shape: Vec<usize>,
/// Memory order supplied at construction.
/// NOTE(review): not consulted when computing block offsets in this file;
/// presumably describes the element order inside each block — confirm.
order: MemoryOrder,
/// Sum of the sizes of all stored blocks.
storage_extent: usize,
}
impl<S: Sector> BlockSparseLayout<S> {
    /// Builds a layout from its indices, target flux and memory order.
    ///
    /// The block list, coordinate lookup table, shape and storage extent are
    /// all computed by `enumerate_allowed_blocks` from `indices` and `flux`.
    pub fn new(indices: Vec<QNIndex<S>>, flux: S, order: MemoryOrder) -> Self {
        let (blocks, block_index, shape, storage_extent) =
            enumerate_allowed_blocks(&indices, &flux);

        Self {
            blocks,
            block_index,
            indices,
            flux,
            shape,
            order,
            storage_extent,
        }
    }

    /// Returns the flux sector of this layout.
    pub fn flux(&self) -> &S {
        &self.flux
    }

    /// Returns the per-axis index descriptions.
    pub fn indices(&self) -> &[QNIndex<S>] {
        self.indices.as_slice()
    }

    /// Returns how many blocks are stored.
    pub fn num_blocks(&self) -> usize {
        self.blocks.len()
    }

    /// Returns the metadata of every stored block.
    pub fn block_metas(&self) -> &[BlockMeta] {
        self.blocks.as_slice()
    }

    /// Returns the table mapping a block coordinate to its position in
    /// [`Self::block_metas`].
    pub(crate) fn block_index(&self) -> &HashMap<BlockCoord, usize> {
        &self.block_index
    }

    /// Returns the per-axis extents of the tensor.
    pub fn shape(&self) -> &[usize] {
        self.shape.as_slice()
    }

    /// Returns the number of indices (axes).
    pub fn rank(&self) -> usize {
        self.indices.len()
    }

    /// Returns the memory order given at construction.
    pub fn order(&self) -> MemoryOrder {
        self.order
    }

    /// Returns the per-axis dimensions of the block at `coord`.
    ///
    /// Yields `None` when `coord` does not have exactly one entry per index
    /// or when any entry is not a valid block number for its index. Whether
    /// the block satisfies the flux constraint is not checked here.
    pub fn block_shape(&self, coord: &BlockCoord) -> Option<Vec<usize>> {
        let picks = &coord.0;
        if picks.len() != self.indices.len() {
            return None;
        }

        // Collecting into `Option<Vec<_>>` stops at the first out-of-range entry.
        picks
            .iter()
            .zip(&self.indices)
            .map(|(&block_idx, leg)| {
                (block_idx < leg.num_blocks()).then(|| leg.block_dim(block_idx))
            })
            .collect()
    }

    /// Returns a copy of this layout with every index direction swapped and
    /// the flux replaced by its `dual()`.
    ///
    /// The block list, coordinate table, shape, memory order and storage
    /// extent are copied unchanged.
    pub(crate) fn dagger_layout(&self) -> Self {
        let mut reversed: Vec<QNIndex<S>> = Vec::with_capacity(self.indices.len());
        for leg in &self.indices {
            let opposite = match leg.direction() {
                Direction::In => Direction::Out,
                Direction::Out => Direction::In,
            };
            reversed.push(QNIndex::new(leg.blocks().to_vec(), opposite));
        }

        Self {
            blocks: self.blocks.clone(),
            block_index: self.block_index.clone(),
            indices: reversed,
            flux: self.flux.dual(),
            shape: self.shape.clone(),
            order: self.order,
            storage_extent: self.storage_extent,
        }
    }

    /// Tells whether the block at `coord` satisfies the flux constraint.
    ///
    /// Returns `false` when `coord` does not have one entry per index or when
    /// any entry is out of range for its index. Otherwise the direction-applied
    /// sectors of the selected blocks are fused, starting from `S::identity()`,
    /// and the result is compared with the layout's flux.
    pub fn is_allowed_block(&self, coord: &BlockCoord) -> bool {
        let picks = &coord.0;
        if picks.len() != self.indices.len() {
            return false;
        }

        let total = picks
            .iter()
            .zip(&self.indices)
            .try_fold(S::identity(), |acc, (&block_idx, leg)| {
                if block_idx >= leg.num_blocks() {
                    return None;
                }
                let charge = leg.direction().apply(leg.sector(block_idx));
                Some(acc.fuse(&charge))
            });

        match total {
            Some(fused) => fused == self.flux,
            None => false,
        }
    }
}
impl<S: Sector> TensorLayout for BlockSparseLayout<S> {
    /// Returns the per-axis extents of the tensor.
    fn shape(&self) -> &[usize] {
        self.shape.as_slice()
    }

    /// Returns the summed size of all stored blocks.
    fn storage_extent(&self) -> usize {
        self.storage_extent
    }
}
/// Enumerates every block coordinate whose fused, direction-applied sectors
/// equal `flux`.
///
/// Returns, in order:
/// - the metadata of each matching block, with `offset` set to the running sum
///   of the sizes of the blocks listed before it,
/// - a map from block coordinate to its position in that list,
/// - the per-axis extents (`total_dim()` of each index),
/// - the summed size of all matching blocks.
///
/// Coordinates are visited with the last axis varying fastest. If any index
/// has zero blocks, no block is produced. With no indices at all, the single
/// empty coordinate is tested against `flux`.
fn enumerate_allowed_blocks<S: Sector>(
    indices: &[QNIndex<S>],
    flux: &S,
) -> (
    Vec<BlockMeta>,
    HashMap<BlockCoord, usize>,
    Vec<usize>,
    usize,
) {
    /// Advances `cursor` to the next coordinate, last axis fastest.
    /// Returns `false` once every axis has wrapped back to zero.
    fn advance(cursor: &mut [usize], limits: &[usize]) -> bool {
        for (digit, &limit) in cursor.iter_mut().zip(limits).rev() {
            *digit += 1;
            if *digit < limit {
                return true;
            }
            *digit = 0;
        }
        false
    }

    let rank = indices.len();
    let shape: Vec<usize> = indices.iter().map(|leg| leg.total_dim()).collect();
    let limits: Vec<usize> = indices.iter().map(|leg| leg.num_blocks()).collect();

    let mut blocks = Vec::new();
    let mut next_offset = 0usize;

    // An index without blocks leaves nothing to enumerate.
    let has_empty_leg = limits.iter().any(|&count| count == 0);
    if !has_empty_leg {
        let mut cursor = vec![0usize; rank];
        loop {
            let total = cursor
                .iter()
                .zip(indices)
                .fold(S::identity(), |acc, (&block_idx, leg)| {
                    let charge = leg.direction().apply(leg.sector(block_idx));
                    acc.fuse(&charge)
                });

            if total == *flux {
                let size: usize = cursor
                    .iter()
                    .zip(indices)
                    .map(|(&block_idx, leg)| leg.block_dim(block_idx))
                    .product();
                blocks.push(BlockMeta {
                    coord: BlockCoord(cursor.clone()),
                    offset: next_offset,
                    size,
                });
                next_offset += size;
            }

            if !advance(&mut cursor, &limits) {
                break;
            }
        }
    }

    let block_index: HashMap<BlockCoord, usize> = blocks
        .iter()
        .enumerate()
        .map(|(position, meta)| (meta.coord.clone(), position))
        .collect();

    (blocks, block_index, shape, next_offset)
}