prav_core/decoder/builder.rs
1//! Builder pattern for [`DecodingState`] construction.
2//!
3//! This module provides a type-safe way to construct decoders without
4//! manually calculating the `STRIDE_Y` const generic.
5//!
6//! # Motivation
7//!
8//! The `DecodingState` struct requires a `STRIDE_Y` const generic that must
9//! equal `max(width, height, depth).next_power_of_two()`. Getting this wrong
10//! causes a runtime panic. The builder pattern eliminates this error-prone
11//! manual calculation.
12//!
13//! # Example
14//!
15//! ```ignore
//! use prav_core::{Arena, DecoderBuilder, SquareGrid, required_buffer_size};
//!
//! let size = required_buffer_size(32, 32, 1);
//! let mut buffer = vec![0u8; size];
20//! let mut arena = Arena::new(&mut buffer);
21//!
22//! // Builder automatically selects correct STRIDE_Y
23//! let mut decoder = DecoderBuilder::<SquareGrid>::new()
24//! .dimensions(32, 32)
25//! .build(&mut arena)
26//! .unwrap();
27//!
28//! let syndromes = [0u64; 16];
29//! decoder.load_dense_syndromes(&syndromes);
30//! ```
31
32use crate::arena::Arena;
33use crate::decoder::state::DecodingState;
34use crate::decoder::types::EdgeCorrection;
35use crate::decoder::growth::ClusterGrowth;
36use crate::topology::Topology;
37use core::marker::PhantomData;
38
39/// Builder for constructing [`DecodingState`] instances.
40///
41/// The builder pattern eliminates the need to manually calculate `STRIDE_Y`,
42/// preventing the common pitfall of mismatched const generics.
43///
44/// # Type Parameter
45///
46/// * `T` - The topology type (e.g., [`SquareGrid`](crate::SquareGrid)).
47///
48/// # Example
49///
50/// ```ignore
51/// use prav_core::{Arena, DecoderBuilder, SquareGrid, required_buffer_size};
52///
53/// let size = required_buffer_size(64, 64, 1);
54/// let mut buffer = vec![0u8; size];
55/// let mut arena = Arena::new(&mut buffer);
56///
57/// let decoder = DecoderBuilder::<SquareGrid>::new()
58/// .dimensions(64, 64)
59/// .build(&mut arena)
60/// .expect("Failed to build decoder");
61/// ```
62pub struct DecoderBuilder<T: Topology> {
63 width: usize,
64 height: usize,
65 depth: usize,
66 _marker: PhantomData<T>,
67}
68
69impl<T: Topology> DecoderBuilder<T> {
70 /// Creates a new decoder builder with default dimensions.
71 ///
72 /// You must call [`dimensions`](Self::dimensions) or
73 /// [`dimensions_3d`](Self::dimensions_3d) before [`build`](Self::build).
74 #[must_use]
75 pub const fn new() -> Self {
76 Self {
77 width: 0,
78 height: 0,
79 depth: 1,
80 _marker: PhantomData,
81 }
82 }
83
84 /// Sets the grid dimensions for a 2D code.
85 ///
86 /// # Arguments
87 ///
88 /// * `width` - Grid width in nodes.
89 /// * `height` - Grid height in nodes.
90 #[must_use]
91 pub const fn dimensions(mut self, width: usize, height: usize) -> Self {
92 self.width = width;
93 self.height = height;
94 self.depth = 1;
95 self
96 }
97
98 /// Sets the grid dimensions for a 3D code.
99 ///
100 /// # Arguments
101 ///
102 /// * `width` - Grid width in nodes.
103 /// * `height` - Grid height in nodes.
104 /// * `depth` - Grid depth in nodes.
105 #[must_use]
106 pub const fn dimensions_3d(mut self, width: usize, height: usize, depth: usize) -> Self {
107 self.width = width;
108 self.height = height;
109 self.depth = depth;
110 self
111 }
112
113 /// Calculates the required `STRIDE_Y` for the configured dimensions.
114 ///
115 /// This is the value that would need to be specified as the const generic
116 /// when using [`DecodingState`] directly.
117 #[must_use]
118 pub const fn stride_y(&self) -> usize {
119 let is_3d = self.depth > 1;
120 let max_dim = const_max(self.width, const_max(self.height, if is_3d { self.depth } else { 1 }));
121 max_dim.next_power_of_two()
122 }
123
124 /// Builds the decoder with the appropriate `STRIDE_Y`.
125 ///
126 /// This method uses a dispatch table to select the correct const generic
127 /// at runtime, then constructs the decoder.
128 ///
129 /// # Errors
130 ///
131 /// Returns an error if:
132 /// - Dimensions are not set (width or height is 0).
133 /// - The grid is too large (max dimension > 512).
134 ///
135 /// # Example
136 ///
137 /// ```ignore
138 /// let decoder = DecoderBuilder::<SquareGrid>::new()
139 /// .dimensions(32, 32)
140 /// .build(&mut arena)?;
141 /// ```
142 pub fn build<'a>(self, arena: &mut Arena<'a>) -> Result<DynDecoder<'a, T>, &'static str> {
143 if self.width == 0 || self.height == 0 {
144 return Err("Dimensions not set: call dimensions() or dimensions_3d() first");
145 }
146
147 let stride = self.stride_y();
148
149 match stride {
150 1 => Ok(DynDecoder::S1(DecodingState::<T, 1>::new(
151 arena, self.width, self.height, self.depth
152 ))),
153 2 => Ok(DynDecoder::S2(DecodingState::<T, 2>::new(
154 arena, self.width, self.height, self.depth
155 ))),
156 4 => Ok(DynDecoder::S4(DecodingState::<T, 4>::new(
157 arena, self.width, self.height, self.depth
158 ))),
159 8 => Ok(DynDecoder::S8(DecodingState::<T, 8>::new(
160 arena, self.width, self.height, self.depth
161 ))),
162 16 => Ok(DynDecoder::S16(DecodingState::<T, 16>::new(
163 arena, self.width, self.height, self.depth
164 ))),
165 32 => Ok(DynDecoder::S32(DecodingState::<T, 32>::new(
166 arena, self.width, self.height, self.depth
167 ))),
168 64 => Ok(DynDecoder::S64(DecodingState::<T, 64>::new(
169 arena, self.width, self.height, self.depth
170 ))),
171 128 => Ok(DynDecoder::S128(DecodingState::<T, 128>::new(
172 arena, self.width, self.height, self.depth
173 ))),
174 256 => Ok(DynDecoder::S256(DecodingState::<T, 256>::new(
175 arena, self.width, self.height, self.depth
176 ))),
177 512 => Ok(DynDecoder::S512(DecodingState::<T, 512>::new(
178 arena, self.width, self.height, self.depth
179 ))),
180 _ => Err("Grid too large: max dimension exceeds 512"),
181 }
182 }
183}
184
185impl<T: Topology> Default for DecoderBuilder<T> {
186 fn default() -> Self {
187 Self::new()
188 }
189}
190
/// Returns the larger of `a` and `b`.
///
/// `Ord::max` is a trait method and cannot be called from a `const fn` on
/// stable Rust. This helper keeps the stride computation usable in `const`
/// contexts.
const fn const_max(a: usize, b: usize) -> usize {
    if b > a {
        b
    } else {
        a
    }
}
195
196/// Dynamic decoder wrapper that hides the `STRIDE_Y` const generic.
197///
198/// This enum provides a unified interface regardless of the underlying
199/// stride, at the cost of a small dispatch overhead per method call.
200///
201/// # Performance Note
202///
203/// For maximum performance in tight loops, prefer using [`DecodingState`]
204/// directly with the correct const generic. The dynamic dispatch overhead
205/// is typically negligible for most use cases.
206///
207/// # Example
208///
209/// ```ignore
210/// let mut decoder = DecoderBuilder::<SquareGrid>::new()
211/// .dimensions(32, 32)
212/// .build(&mut arena)?;
213///
214/// // Use unified interface regardless of stride
215/// decoder.load_dense_syndromes(&syndromes);
216/// decoder.grow_clusters();
217/// let count = decoder.peel_forest(&mut corrections);
218/// decoder.reset_for_next_cycle();
219/// ```
220pub enum DynDecoder<'a, T: Topology> {
221 /// Stride 1 (1x1 grids).
222 S1(DecodingState<'a, T, 1>),
223 /// Stride 2 (up to 2x2 grids).
224 S2(DecodingState<'a, T, 2>),
225 /// Stride 4 (up to 4x4 grids).
226 S4(DecodingState<'a, T, 4>),
227 /// Stride 8 (up to 8x8 grids).
228 S8(DecodingState<'a, T, 8>),
229 /// Stride 16 (up to 16x16 grids).
230 S16(DecodingState<'a, T, 16>),
231 /// Stride 32 (up to 32x32 grids).
232 S32(DecodingState<'a, T, 32>),
233 /// Stride 64 (up to 64x64 grids).
234 S64(DecodingState<'a, T, 64>),
235 /// Stride 128 (up to 128x128 grids).
236 S128(DecodingState<'a, T, 128>),
237 /// Stride 256 (up to 256x256 grids).
238 S256(DecodingState<'a, T, 256>),
239 /// Stride 512 (up to 512x512 grids).
240 S512(DecodingState<'a, T, 512>),
241}
242
/// Forwards a method call to whichever `DecodingState` a [`DynDecoder`] wraps.
///
/// `dispatch!(target, method, arg1, arg2, ...)` expands to a `match` over every
/// stride variant. Each arm calls `inner.method(arg1, arg2, ...)` on the
/// wrapped state, and the value of the chosen arm is the macro's value.
macro_rules! dispatch {
    ($target:expr, $name:ident $(, $param:expr)*) => {
        match $target {
            DynDecoder::S1(inner) => inner.$name($($param),*),
            DynDecoder::S2(inner) => inner.$name($($param),*),
            DynDecoder::S4(inner) => inner.$name($($param),*),
            DynDecoder::S8(inner) => inner.$name($($param),*),
            DynDecoder::S16(inner) => inner.$name($($param),*),
            DynDecoder::S32(inner) => inner.$name($($param),*),
            DynDecoder::S64(inner) => inner.$name($($param),*),
            DynDecoder::S128(inner) => inner.$name($($param),*),
            DynDecoder::S256(inner) => inner.$name($($param),*),
            DynDecoder::S512(inner) => inner.$name($($param),*),
        }
    };
}
260
261impl<'a, T: Topology> DynDecoder<'a, T> {
262 /// Loads syndrome measurements from a dense bitarray.
263 ///
264 /// Each `u64` in the slice represents 64 consecutive nodes, where bit `i`
265 /// being set indicates a syndrome at node `(block_index * 64 + i)`.
266 ///
267 /// # Arguments
268 ///
269 /// * `syndromes` - Dense syndrome bitarray with one `u64` per 64-node block.
270 #[inline]
271 pub fn load_dense_syndromes(&mut self, syndromes: &[u64]) {
272 dispatch!(self, load_dense_syndromes, syndromes);
273 }
274
275 /// Performs full cluster growth until convergence.
276 ///
277 /// This iteratively expands syndrome clusters until all defects are paired
278 /// or reach boundaries.
279 #[inline]
280 pub fn grow_clusters(&mut self) {
281 dispatch!(self, grow_clusters);
282 }
283
284 /// Performs a single growth iteration.
285 ///
286 /// Returns `true` if more iterations are needed, `false` if converged.
287 #[inline]
288 pub fn grow_iteration(&mut self) -> bool {
289 dispatch!(self, grow_iteration)
290 }
291
292 /// Extracts corrections by peeling the cluster forest.
293 ///
294 /// This traces paths from defects and accumulates edge corrections.
295 ///
296 /// # Arguments
297 ///
298 /// * `corrections` - Output buffer for edge corrections.
299 ///
300 /// # Returns
301 ///
302 /// The number of corrections written to the buffer.
303 #[inline]
304 pub fn peel_forest(&mut self, corrections: &mut [EdgeCorrection]) -> usize {
305 dispatch!(self, peel_forest, corrections)
306 }
307
308 /// Performs full decode cycle (grow + peel).
309 ///
310 /// This is equivalent to calling [`grow_clusters`](Self::grow_clusters)
311 /// followed by [`peel_forest`](Self::peel_forest).
312 ///
313 /// # Arguments
314 ///
315 /// * `corrections` - Output buffer for edge corrections.
316 ///
317 /// # Returns
318 ///
319 /// The number of corrections written to the buffer.
320 #[inline]
321 pub fn decode(&mut self, corrections: &mut [EdgeCorrection]) -> usize {
322 dispatch!(self, decode, corrections)
323 }
324
325 /// Resets state for the next decoding cycle (sparse reset).
326 ///
327 /// This efficiently resets only the blocks that were modified during
328 /// the previous decoding cycle.
329 #[inline]
330 pub fn reset_for_next_cycle(&mut self) {
331 dispatch!(self, sparse_reset);
332 }
333
334 /// Fully resets all decoder state.
335 ///
336 /// This performs a complete reset of all internal data structures.
337 /// For repeated decoding, prefer [`reset_for_next_cycle`](Self::reset_for_next_cycle).
338 #[inline]
339 pub fn full_reset(&mut self) {
340 dispatch!(self, initialize_internal);
341 }
342
343 /// Returns the grid width.
344 #[inline]
345 #[must_use]
346 pub fn width(&self) -> usize {
347 match self {
348 DynDecoder::S1(d) => d.width,
349 DynDecoder::S2(d) => d.width,
350 DynDecoder::S4(d) => d.width,
351 DynDecoder::S8(d) => d.width,
352 DynDecoder::S16(d) => d.width,
353 DynDecoder::S32(d) => d.width,
354 DynDecoder::S64(d) => d.width,
355 DynDecoder::S128(d) => d.width,
356 DynDecoder::S256(d) => d.width,
357 DynDecoder::S512(d) => d.width,
358 }
359 }
360
361 /// Returns the grid height.
362 #[inline]
363 #[must_use]
364 pub fn height(&self) -> usize {
365 match self {
366 DynDecoder::S1(d) => d.height,
367 DynDecoder::S2(d) => d.height,
368 DynDecoder::S4(d) => d.height,
369 DynDecoder::S8(d) => d.height,
370 DynDecoder::S16(d) => d.height,
371 DynDecoder::S32(d) => d.height,
372 DynDecoder::S64(d) => d.height,
373 DynDecoder::S128(d) => d.height,
374 DynDecoder::S256(d) => d.height,
375 DynDecoder::S512(d) => d.height,
376 }
377 }
378
379 /// Returns the stride Y value.
380 #[inline]
381 #[must_use]
382 pub fn stride_y(&self) -> usize {
383 match self {
384 DynDecoder::S1(d) => d.stride_y,
385 DynDecoder::S2(d) => d.stride_y,
386 DynDecoder::S4(d) => d.stride_y,
387 DynDecoder::S8(d) => d.stride_y,
388 DynDecoder::S16(d) => d.stride_y,
389 DynDecoder::S32(d) => d.stride_y,
390 DynDecoder::S64(d) => d.stride_y,
391 DynDecoder::S128(d) => d.stride_y,
392 DynDecoder::S256(d) => d.stride_y,
393 DynDecoder::S512(d) => d.stride_y,
394 }
395 }
396}