Skip to main content

ariadnetor_tensor/block_sparse/
layout.rs

1//! `BlockSparseLayout<S>`: interpretation half of the block-sparse tensor split.
2//!
3//! Carries block metadata (allowed-block enumeration with offsets and
4//! sizes), per-leg sector indices, flux, logical shape, and memory
5//! order. Data lives on
6//! [`BlockSparseStorage<T>`](crate::BlockSparseStorage); the wrapper
7//! [`BlockSparseTensorData<T, S>`](crate::BlockSparseTensorData) joins
8//! the two with a length-consistency check.
9
10use std::collections::HashMap;
11
12use ariadnetor_core::backend::MemoryOrder;
13
14use super::{BlockCoord, BlockMeta, Direction, QNIndex};
15use crate::{Sector, TensorLayout};
16
17/// Interpretation half of the block-sparse tensor split.
18///
19/// Holds the allowed-block enumeration (sorted by coordinate, packed
20/// offsets), the per-leg sector indices, the conserved flux, the
21/// cached logical shape, and the memory order the paired
22/// [`BlockSparseStorage`](crate::BlockSparseStorage) is laid out in.
23///
24/// Construction goes through [`new`](Self::new), which enumerates
25/// flux-allowed blocks and produces a packed layout. Layout-internal
26/// invariants (sector conservation, coord uniqueness, no-gap
27/// packing) hold by construction; the storage-layout boundary check
28/// happens in
29/// [`TensorData::new`](crate::TensorData::new).
30#[derive(Clone)]
31pub struct BlockSparseLayout<S: Sector> {
32    blocks: Vec<BlockMeta>,
33    block_index: HashMap<BlockCoord, usize>,
34    indices: Vec<QNIndex<S>>,
35    flux: S,
36    shape: Vec<usize>,
37    order: MemoryOrder,
38    /// Cached sum of allowed block sizes; equals expected
39    /// [`BlockSparseStorage::flat_len`](crate::BlockSparseStorage::flat_len).
40    storage_extent: usize,
41}
42
43impl<S: Sector> BlockSparseLayout<S> {
44    /// Construct a layout by enumerating flux-allowed blocks.
45    ///
46    /// The resulting layout has blocks sorted by coordinate
47    /// (lexicographic) with packed offsets (no gaps or overlaps),
48    /// every block satisfying the flux-conservation law, and a
49    /// cached `storage_extent` equal to the sum of allowed block
50    /// sizes.
51    pub fn new(indices: Vec<QNIndex<S>>, flux: S, order: MemoryOrder) -> Self {
52        let (blocks, block_index, shape, storage_extent) =
53            enumerate_allowed_blocks(&indices, &flux);
54        Self {
55            blocks,
56            block_index,
57            indices,
58            flux,
59            shape,
60            order,
61            storage_extent,
62        }
63    }
64
65    /// Conserved flux (total quantum number).
66    pub fn flux(&self) -> &S {
67        &self.flux
68    }
69
70    /// Per-leg QN indices.
71    pub fn indices(&self) -> &[QNIndex<S>] {
72        &self.indices
73    }
74
75    /// Number of stored (non-zero) blocks.
76    pub fn num_blocks(&self) -> usize {
77        self.blocks.len()
78    }
79
80    /// Block metadata (sorted by coordinate).
81    pub fn block_metas(&self) -> &[BlockMeta] {
82        &self.blocks
83    }
84
85    /// Block-coordinate → blocks index lookup.
86    pub(crate) fn block_index(&self) -> &HashMap<BlockCoord, usize> {
87        &self.block_index
88    }
89
90    /// Logical shape (total dimension per leg).
91    pub fn shape(&self) -> &[usize] {
92        &self.shape
93    }
94
95    /// Rank (number of legs).
96    pub fn rank(&self) -> usize {
97        self.indices.len()
98    }
99
100    /// Memory order the paired storage is laid out in.
101    pub fn order(&self) -> MemoryOrder {
102        self.order
103    }
104
105    /// Shape of a specific block, or `None` if the coordinate is out of range.
106    pub fn block_shape(&self, coord: &BlockCoord) -> Option<Vec<usize>> {
107        if coord.0.len() != self.indices.len() {
108            return None;
109        }
110        let mut shape = Vec::with_capacity(coord.0.len());
111        for (axis, &block_idx) in coord.0.iter().enumerate() {
112            if block_idx >= self.indices[axis].num_blocks() {
113                return None;
114            }
115            shape.push(self.indices[axis].block_dim(block_idx));
116        }
117        Some(shape)
118    }
119
120    /// Hermitian-adjoint layout: flip every QNIndex direction (Out↔In)
121    /// and dual the flux.
122    ///
123    /// The allowed-block set is preserved: each block's flux
124    /// contribution becomes `dual(direction.apply(sector))`, whose sum
125    /// equals the dualed flux exactly when the original sum equalled
126    /// the original flux (abelian dual is a group homomorphism).
127    /// `blocks`, `block_index`, `shape`, `order`, and `storage_extent`
128    /// are reused as-is.
129    pub(crate) fn dagger_layout(&self) -> Self {
130        let flipped_indices: Vec<QNIndex<S>> = self
131            .indices
132            .iter()
133            .map(|idx| {
134                let new_dir = match idx.direction() {
135                    Direction::Out => Direction::In,
136                    Direction::In => Direction::Out,
137                };
138                QNIndex::new(idx.blocks().to_vec(), new_dir)
139            })
140            .collect();
141        Self {
142            blocks: self.blocks.clone(),
143            block_index: self.block_index.clone(),
144            indices: flipped_indices,
145            flux: self.flux.dual(),
146            shape: self.shape.clone(),
147            order: self.order,
148            storage_extent: self.storage_extent,
149        }
150    }
151
152    /// Check whether a block coordinate satisfies the flux conservation law.
153    pub fn is_allowed_block(&self, coord: &BlockCoord) -> bool {
154        if coord.0.len() != self.indices.len() {
155            return false;
156        }
157        let mut fused = S::identity();
158        for (axis, &block_idx) in coord.0.iter().enumerate() {
159            let idx = &self.indices[axis];
160            if block_idx >= idx.num_blocks() {
161                return false;
162            }
163            let sector = idx.sector(block_idx);
164            let directed = idx.direction().apply(sector);
165            fused = fused.fuse(&directed);
166        }
167        fused == self.flux
168    }
169}
170
171impl<S: Sector> TensorLayout for BlockSparseLayout<S> {
172    fn shape(&self) -> &[usize] {
173        &self.shape
174    }
175
176    fn storage_extent(&self) -> usize {
177        self.storage_extent
178    }
179}
180
181/// Enumerate flux-allowed blocks for given indices and flux.
182///
183/// Returns `(blocks, block_index, shape, total_size)`. Blocks are
184/// emitted in lexicographic coordinate order with packed offsets.
185fn enumerate_allowed_blocks<S: Sector>(
186    indices: &[QNIndex<S>],
187    flux: &S,
188) -> (
189    Vec<BlockMeta>,
190    HashMap<BlockCoord, usize>,
191    Vec<usize>,
192    usize,
193) {
194    let shape: Vec<usize> = indices.iter().map(|idx| idx.total_dim()).collect();
195    let rank = indices.len();
196    let num_blocks_per_leg: Vec<usize> = indices.iter().map(|idx| idx.num_blocks()).collect();
197
198    let mut blocks = Vec::new();
199    let mut total_size = 0usize;
200
201    if rank == 0 || num_blocks_per_leg.iter().all(|&n| n > 0) {
202        let mut current = vec![0usize; rank];
203        loop {
204            let mut fused = S::identity();
205            for (axis, &bi) in current.iter().enumerate() {
206                let sector = indices[axis].sector(bi);
207                let directed = indices[axis].direction().apply(sector);
208                fused = fused.fuse(&directed);
209            }
210
211            if fused == *flux {
212                let size: usize = current
213                    .iter()
214                    .enumerate()
215                    .map(|(axis, &bi)| indices[axis].block_dim(bi))
216                    .product();
217                blocks.push(BlockMeta {
218                    coord: BlockCoord(current.clone()),
219                    offset: total_size,
220                    size,
221                });
222                total_size += size;
223            }
224
225            let mut carry = true;
226            for axis in (0..rank).rev() {
227                current[axis] += 1;
228                if current[axis] < num_blocks_per_leg[axis] {
229                    carry = false;
230                    break;
231                }
232                current[axis] = 0;
233            }
234            if carry {
235                break;
236            }
237        }
238    }
239
240    let mut block_index = HashMap::with_capacity(blocks.len());
241    for (i, meta) in blocks.iter().enumerate() {
242        block_index.insert(meta.coord.clone(), i);
243    }
244
245    (blocks, block_index, shape, total_size)
246}