prav_core/decoder/
builder.rs

1//! Builder pattern for [`DecodingState`] construction.
2//!
3//! This module provides a type-safe way to construct decoders without
4//! manually calculating the `STRIDE_Y` const generic.
5//!
6//! # Motivation
7//!
8//! The `DecodingState` struct requires a `STRIDE_Y` const generic that must
9//! equal `max(width, height, depth).next_power_of_two()`. Getting this wrong
10//! causes a runtime panic. The builder pattern eliminates this error-prone
11//! manual calculation.
12//!
13//! # Example
14//!
15//! ```ignore
16//! use prav_core::{Arena, DecoderBuilder, SquareGrid, EdgeCorrection, required_buffer_size};
17//!
18//! let size = required_buffer_size(32, 32, 1);
//! let mut buffer = vec![0u8; size];
20//! let mut arena = Arena::new(&mut buffer);
21//!
22//! // Builder automatically selects correct STRIDE_Y
23//! let mut decoder = DecoderBuilder::<SquareGrid>::new()
24//!     .dimensions(32, 32)
25//!     .build(&mut arena)
26//!     .unwrap();
27//!
28//! let syndromes = [0u64; 16];
29//! decoder.load_dense_syndromes(&syndromes);
30//! ```
31
32use crate::arena::Arena;
33use crate::decoder::state::DecodingState;
34use crate::decoder::types::EdgeCorrection;
35use crate::decoder::growth::ClusterGrowth;
36use crate::topology::Topology;
37use core::marker::PhantomData;
38
39/// Builder for constructing [`DecodingState`] instances.
40///
41/// The builder pattern eliminates the need to manually calculate `STRIDE_Y`,
42/// preventing the common pitfall of mismatched const generics.
43///
44/// # Type Parameter
45///
46/// * `T` - The topology type (e.g., [`SquareGrid`](crate::SquareGrid)).
47///
48/// # Example
49///
50/// ```ignore
51/// use prav_core::{Arena, DecoderBuilder, SquareGrid, required_buffer_size};
52///
53/// let size = required_buffer_size(64, 64, 1);
54/// let mut buffer = vec![0u8; size];
55/// let mut arena = Arena::new(&mut buffer);
56///
57/// let decoder = DecoderBuilder::<SquareGrid>::new()
58///     .dimensions(64, 64)
59///     .build(&mut arena)
60///     .expect("Failed to build decoder");
61/// ```
62pub struct DecoderBuilder<T: Topology> {
63    width: usize,
64    height: usize,
65    depth: usize,
66    _marker: PhantomData<T>,
67}
68
69impl<T: Topology> DecoderBuilder<T> {
70    /// Creates a new decoder builder with default dimensions.
71    ///
72    /// You must call [`dimensions`](Self::dimensions) or
73    /// [`dimensions_3d`](Self::dimensions_3d) before [`build`](Self::build).
74    #[must_use]
75    pub const fn new() -> Self {
76        Self {
77            width: 0,
78            height: 0,
79            depth: 1,
80            _marker: PhantomData,
81        }
82    }
83
84    /// Sets the grid dimensions for a 2D code.
85    ///
86    /// # Arguments
87    ///
88    /// * `width` - Grid width in nodes.
89    /// * `height` - Grid height in nodes.
90    #[must_use]
91    pub const fn dimensions(mut self, width: usize, height: usize) -> Self {
92        self.width = width;
93        self.height = height;
94        self.depth = 1;
95        self
96    }
97
98    /// Sets the grid dimensions for a 3D code.
99    ///
100    /// # Arguments
101    ///
102    /// * `width` - Grid width in nodes.
103    /// * `height` - Grid height in nodes.
104    /// * `depth` - Grid depth in nodes.
105    #[must_use]
106    pub const fn dimensions_3d(mut self, width: usize, height: usize, depth: usize) -> Self {
107        self.width = width;
108        self.height = height;
109        self.depth = depth;
110        self
111    }
112
113    /// Calculates the required `STRIDE_Y` for the configured dimensions.
114    ///
115    /// This is the value that would need to be specified as the const generic
116    /// when using [`DecodingState`] directly.
117    #[must_use]
118    pub const fn stride_y(&self) -> usize {
119        let is_3d = self.depth > 1;
120        let max_dim = const_max(self.width, const_max(self.height, if is_3d { self.depth } else { 1 }));
121        max_dim.next_power_of_two()
122    }
123
124    /// Builds the decoder with the appropriate `STRIDE_Y`.
125    ///
126    /// This method uses a dispatch table to select the correct const generic
127    /// at runtime, then constructs the decoder.
128    ///
129    /// # Errors
130    ///
131    /// Returns an error if:
132    /// - Dimensions are not set (width or height is 0).
133    /// - The grid is too large (max dimension > 512).
134    ///
135    /// # Example
136    ///
137    /// ```ignore
138    /// let decoder = DecoderBuilder::<SquareGrid>::new()
139    ///     .dimensions(32, 32)
140    ///     .build(&mut arena)?;
141    /// ```
142    pub fn build<'a>(self, arena: &mut Arena<'a>) -> Result<DynDecoder<'a, T>, &'static str> {
143        if self.width == 0 || self.height == 0 {
144            return Err("Dimensions not set: call dimensions() or dimensions_3d() first");
145        }
146
147        let stride = self.stride_y();
148
149        match stride {
150            1 => Ok(DynDecoder::S1(DecodingState::<T, 1>::new(
151                arena, self.width, self.height, self.depth
152            ))),
153            2 => Ok(DynDecoder::S2(DecodingState::<T, 2>::new(
154                arena, self.width, self.height, self.depth
155            ))),
156            4 => Ok(DynDecoder::S4(DecodingState::<T, 4>::new(
157                arena, self.width, self.height, self.depth
158            ))),
159            8 => Ok(DynDecoder::S8(DecodingState::<T, 8>::new(
160                arena, self.width, self.height, self.depth
161            ))),
162            16 => Ok(DynDecoder::S16(DecodingState::<T, 16>::new(
163                arena, self.width, self.height, self.depth
164            ))),
165            32 => Ok(DynDecoder::S32(DecodingState::<T, 32>::new(
166                arena, self.width, self.height, self.depth
167            ))),
168            64 => Ok(DynDecoder::S64(DecodingState::<T, 64>::new(
169                arena, self.width, self.height, self.depth
170            ))),
171            128 => Ok(DynDecoder::S128(DecodingState::<T, 128>::new(
172                arena, self.width, self.height, self.depth
173            ))),
174            256 => Ok(DynDecoder::S256(DecodingState::<T, 256>::new(
175                arena, self.width, self.height, self.depth
176            ))),
177            512 => Ok(DynDecoder::S512(DecodingState::<T, 512>::new(
178                arena, self.width, self.height, self.depth
179            ))),
180            _ => Err("Grid too large: max dimension exceeds 512"),
181        }
182    }
183}
184
185impl<T: Topology> Default for DecoderBuilder<T> {
186    fn default() -> Self {
187        Self::new()
188    }
189}
190
/// Returns the larger of `a` and `b`.
///
/// Needed because `core::cmp::max` / `Ord::max` cannot be called inside a
/// `const fn` on stable Rust.
const fn const_max(a: usize, b: usize) -> usize {
    if b > a {
        b
    } else {
        a
    }
}
195
196/// Dynamic decoder wrapper that hides the `STRIDE_Y` const generic.
197///
198/// This enum provides a unified interface regardless of the underlying
199/// stride, at the cost of a small dispatch overhead per method call.
200///
201/// # Performance Note
202///
203/// For maximum performance in tight loops, prefer using [`DecodingState`]
204/// directly with the correct const generic. The dynamic dispatch overhead
205/// is typically negligible for most use cases.
206///
207/// # Example
208///
209/// ```ignore
210/// let mut decoder = DecoderBuilder::<SquareGrid>::new()
211///     .dimensions(32, 32)
212///     .build(&mut arena)?;
213///
214/// // Use unified interface regardless of stride
215/// decoder.load_dense_syndromes(&syndromes);
216/// decoder.grow_clusters();
217/// let count = decoder.peel_forest(&mut corrections);
218/// decoder.reset_for_next_cycle();
219/// ```
220pub enum DynDecoder<'a, T: Topology> {
221    /// Stride 1 (1x1 grids).
222    S1(DecodingState<'a, T, 1>),
223    /// Stride 2 (up to 2x2 grids).
224    S2(DecodingState<'a, T, 2>),
225    /// Stride 4 (up to 4x4 grids).
226    S4(DecodingState<'a, T, 4>),
227    /// Stride 8 (up to 8x8 grids).
228    S8(DecodingState<'a, T, 8>),
229    /// Stride 16 (up to 16x16 grids).
230    S16(DecodingState<'a, T, 16>),
231    /// Stride 32 (up to 32x32 grids).
232    S32(DecodingState<'a, T, 32>),
233    /// Stride 64 (up to 64x64 grids).
234    S64(DecodingState<'a, T, 64>),
235    /// Stride 128 (up to 128x128 grids).
236    S128(DecodingState<'a, T, 128>),
237    /// Stride 256 (up to 256x256 grids).
238    S256(DecodingState<'a, T, 256>),
239    /// Stride 512 (up to 512x512 grids).
240    S512(DecodingState<'a, T, 512>),
241}
242
/// Forwards a method call to whichever [`DecodingState`] the [`DynDecoder`]
/// currently wraps.
///
/// Usage: `dispatch!(receiver, method_name, arg1, arg2, ...)`. Every arm binds
/// the wrapped state (by reference when `receiver` is a borrow) and calls
/// `method_name` on it with the given arguments; the whole `match` evaluates
/// to that call's result.
macro_rules! dispatch {
    ($target:expr, $func:ident $(, $param:expr)*) => {
        match $target {
            DynDecoder::S1(inner) => inner.$func($($param),*),
            DynDecoder::S2(inner) => inner.$func($($param),*),
            DynDecoder::S4(inner) => inner.$func($($param),*),
            DynDecoder::S8(inner) => inner.$func($($param),*),
            DynDecoder::S16(inner) => inner.$func($($param),*),
            DynDecoder::S32(inner) => inner.$func($($param),*),
            DynDecoder::S64(inner) => inner.$func($($param),*),
            DynDecoder::S128(inner) => inner.$func($($param),*),
            DynDecoder::S256(inner) => inner.$func($($param),*),
            DynDecoder::S512(inner) => inner.$func($($param),*),
        }
    };
}
260
261impl<'a, T: Topology> DynDecoder<'a, T> {
262    /// Loads syndrome measurements from a dense bitarray.
263    ///
264    /// Each `u64` in the slice represents 64 consecutive nodes, where bit `i`
265    /// being set indicates a syndrome at node `(block_index * 64 + i)`.
266    ///
267    /// # Arguments
268    ///
269    /// * `syndromes` - Dense syndrome bitarray with one `u64` per 64-node block.
270    #[inline]
271    pub fn load_dense_syndromes(&mut self, syndromes: &[u64]) {
272        dispatch!(self, load_dense_syndromes, syndromes);
273    }
274
275    /// Performs full cluster growth until convergence.
276    ///
277    /// This iteratively expands syndrome clusters until all defects are paired
278    /// or reach boundaries.
279    #[inline]
280    pub fn grow_clusters(&mut self) {
281        dispatch!(self, grow_clusters);
282    }
283
284    /// Performs a single growth iteration.
285    ///
286    /// Returns `true` if more iterations are needed, `false` if converged.
287    #[inline]
288    pub fn grow_iteration(&mut self) -> bool {
289        dispatch!(self, grow_iteration)
290    }
291
292    /// Extracts corrections by peeling the cluster forest.
293    ///
294    /// This traces paths from defects and accumulates edge corrections.
295    ///
296    /// # Arguments
297    ///
298    /// * `corrections` - Output buffer for edge corrections.
299    ///
300    /// # Returns
301    ///
302    /// The number of corrections written to the buffer.
303    #[inline]
304    pub fn peel_forest(&mut self, corrections: &mut [EdgeCorrection]) -> usize {
305        dispatch!(self, peel_forest, corrections)
306    }
307
308    /// Performs full decode cycle (grow + peel).
309    ///
310    /// This is equivalent to calling [`grow_clusters`](Self::grow_clusters)
311    /// followed by [`peel_forest`](Self::peel_forest).
312    ///
313    /// # Arguments
314    ///
315    /// * `corrections` - Output buffer for edge corrections.
316    ///
317    /// # Returns
318    ///
319    /// The number of corrections written to the buffer.
320    #[inline]
321    pub fn decode(&mut self, corrections: &mut [EdgeCorrection]) -> usize {
322        dispatch!(self, decode, corrections)
323    }
324
325    /// Resets state for the next decoding cycle (sparse reset).
326    ///
327    /// This efficiently resets only the blocks that were modified during
328    /// the previous decoding cycle.
329    #[inline]
330    pub fn reset_for_next_cycle(&mut self) {
331        dispatch!(self, sparse_reset);
332    }
333
334    /// Fully resets all decoder state.
335    ///
336    /// This performs a complete reset of all internal data structures.
337    /// For repeated decoding, prefer [`reset_for_next_cycle`](Self::reset_for_next_cycle).
338    #[inline]
339    pub fn full_reset(&mut self) {
340        dispatch!(self, initialize_internal);
341    }
342
343    /// Returns the grid width.
344    #[inline]
345    #[must_use]
346    pub fn width(&self) -> usize {
347        match self {
348            DynDecoder::S1(d) => d.width,
349            DynDecoder::S2(d) => d.width,
350            DynDecoder::S4(d) => d.width,
351            DynDecoder::S8(d) => d.width,
352            DynDecoder::S16(d) => d.width,
353            DynDecoder::S32(d) => d.width,
354            DynDecoder::S64(d) => d.width,
355            DynDecoder::S128(d) => d.width,
356            DynDecoder::S256(d) => d.width,
357            DynDecoder::S512(d) => d.width,
358        }
359    }
360
361    /// Returns the grid height.
362    #[inline]
363    #[must_use]
364    pub fn height(&self) -> usize {
365        match self {
366            DynDecoder::S1(d) => d.height,
367            DynDecoder::S2(d) => d.height,
368            DynDecoder::S4(d) => d.height,
369            DynDecoder::S8(d) => d.height,
370            DynDecoder::S16(d) => d.height,
371            DynDecoder::S32(d) => d.height,
372            DynDecoder::S64(d) => d.height,
373            DynDecoder::S128(d) => d.height,
374            DynDecoder::S256(d) => d.height,
375            DynDecoder::S512(d) => d.height,
376        }
377    }
378
379    /// Returns the stride Y value.
380    #[inline]
381    #[must_use]
382    pub fn stride_y(&self) -> usize {
383        match self {
384            DynDecoder::S1(d) => d.stride_y,
385            DynDecoder::S2(d) => d.stride_y,
386            DynDecoder::S4(d) => d.stride_y,
387            DynDecoder::S8(d) => d.stride_y,
388            DynDecoder::S16(d) => d.stride_y,
389            DynDecoder::S32(d) => d.stride_y,
390            DynDecoder::S64(d) => d.stride_y,
391            DynDecoder::S128(d) => d.stride_y,
392            DynDecoder::S256(d) => d.stride_y,
393            DynDecoder::S512(d) => d.stride_y,
394        }
395    }
396}