ariadnetor_tensor/block_sparse/
layout.rs1use std::collections::HashMap;
11
12use ariadnetor_core::backend::MemoryOrder;
13
14use super::{BlockCoord, BlockMeta, Direction, QNIndex};
15use crate::{Sector, TensorLayout};
16
17#[derive(Clone)]
31pub struct BlockSparseLayout<S: Sector> {
32 blocks: Vec<BlockMeta>,
33 block_index: HashMap<BlockCoord, usize>,
34 indices: Vec<QNIndex<S>>,
35 flux: S,
36 shape: Vec<usize>,
37 order: MemoryOrder,
38 storage_extent: usize,
41}
42
43impl<S: Sector> BlockSparseLayout<S> {
44 pub fn new(indices: Vec<QNIndex<S>>, flux: S, order: MemoryOrder) -> Self {
52 let (blocks, block_index, shape, storage_extent) =
53 enumerate_allowed_blocks(&indices, &flux);
54 Self {
55 blocks,
56 block_index,
57 indices,
58 flux,
59 shape,
60 order,
61 storage_extent,
62 }
63 }
64
65 pub fn flux(&self) -> &S {
67 &self.flux
68 }
69
70 pub fn indices(&self) -> &[QNIndex<S>] {
72 &self.indices
73 }
74
75 pub fn num_blocks(&self) -> usize {
77 self.blocks.len()
78 }
79
80 pub fn block_metas(&self) -> &[BlockMeta] {
82 &self.blocks
83 }
84
85 pub(crate) fn block_index(&self) -> &HashMap<BlockCoord, usize> {
87 &self.block_index
88 }
89
90 pub fn shape(&self) -> &[usize] {
92 &self.shape
93 }
94
95 pub fn rank(&self) -> usize {
97 self.indices.len()
98 }
99
100 pub fn order(&self) -> MemoryOrder {
102 self.order
103 }
104
105 pub fn block_shape(&self, coord: &BlockCoord) -> Option<Vec<usize>> {
107 if coord.0.len() != self.indices.len() {
108 return None;
109 }
110 let mut shape = Vec::with_capacity(coord.0.len());
111 for (axis, &block_idx) in coord.0.iter().enumerate() {
112 if block_idx >= self.indices[axis].num_blocks() {
113 return None;
114 }
115 shape.push(self.indices[axis].block_dim(block_idx));
116 }
117 Some(shape)
118 }
119
120 pub(crate) fn dagger_layout(&self) -> Self {
130 let flipped_indices: Vec<QNIndex<S>> = self
131 .indices
132 .iter()
133 .map(|idx| {
134 let new_dir = match idx.direction() {
135 Direction::Out => Direction::In,
136 Direction::In => Direction::Out,
137 };
138 QNIndex::new(idx.blocks().to_vec(), new_dir)
139 })
140 .collect();
141 Self {
142 blocks: self.blocks.clone(),
143 block_index: self.block_index.clone(),
144 indices: flipped_indices,
145 flux: self.flux.dual(),
146 shape: self.shape.clone(),
147 order: self.order,
148 storage_extent: self.storage_extent,
149 }
150 }
151
152 pub fn is_allowed_block(&self, coord: &BlockCoord) -> bool {
154 if coord.0.len() != self.indices.len() {
155 return false;
156 }
157 let mut fused = S::identity();
158 for (axis, &block_idx) in coord.0.iter().enumerate() {
159 let idx = &self.indices[axis];
160 if block_idx >= idx.num_blocks() {
161 return false;
162 }
163 let sector = idx.sector(block_idx);
164 let directed = idx.direction().apply(sector);
165 fused = fused.fuse(&directed);
166 }
167 fused == self.flux
168 }
169}
170
171impl<S: Sector> TensorLayout for BlockSparseLayout<S> {
172 fn shape(&self) -> &[usize] {
173 &self.shape
174 }
175
176 fn storage_extent(&self) -> usize {
177 self.storage_extent
178 }
179}
180
181fn enumerate_allowed_blocks<S: Sector>(
186 indices: &[QNIndex<S>],
187 flux: &S,
188) -> (
189 Vec<BlockMeta>,
190 HashMap<BlockCoord, usize>,
191 Vec<usize>,
192 usize,
193) {
194 let shape: Vec<usize> = indices.iter().map(|idx| idx.total_dim()).collect();
195 let rank = indices.len();
196 let num_blocks_per_leg: Vec<usize> = indices.iter().map(|idx| idx.num_blocks()).collect();
197
198 let mut blocks = Vec::new();
199 let mut total_size = 0usize;
200
201 if rank == 0 || num_blocks_per_leg.iter().all(|&n| n > 0) {
202 let mut current = vec![0usize; rank];
203 loop {
204 let mut fused = S::identity();
205 for (axis, &bi) in current.iter().enumerate() {
206 let sector = indices[axis].sector(bi);
207 let directed = indices[axis].direction().apply(sector);
208 fused = fused.fuse(&directed);
209 }
210
211 if fused == *flux {
212 let size: usize = current
213 .iter()
214 .enumerate()
215 .map(|(axis, &bi)| indices[axis].block_dim(bi))
216 .product();
217 blocks.push(BlockMeta {
218 coord: BlockCoord(current.clone()),
219 offset: total_size,
220 size,
221 });
222 total_size += size;
223 }
224
225 let mut carry = true;
226 for axis in (0..rank).rev() {
227 current[axis] += 1;
228 if current[axis] < num_blocks_per_leg[axis] {
229 carry = false;
230 break;
231 }
232 current[axis] = 0;
233 }
234 if carry {
235 break;
236 }
237 }
238 }
239
240 let mut block_index = HashMap::with_capacity(blocks.len());
241 for (i, meta) in blocks.iter().enumerate() {
242 block_index.insert(meta.coord.clone(), i);
243 }
244
245 (blocks, block_index, shape, total_size)
246}