// src/types.rs (crate: constraint_decoding_trie)

3use std::fmt;

// ──────────────────────────────────────────────────────────────────────────────
// TransitionMatrix
// ──────────────────────────────────────────────────────────────────────────────

/// Stacked CSR transition matrix: interleaves [col_idx, next_state] for
/// coalesced reads. Row i corresponds to trie node i; values are
/// (token_id, next_node_id) pairs.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TransitionMatrix {
    /// CSR row pointers: length = num_nodes + 1, with `row_pointers[0] == 0`.
    /// `row_pointers[i]..row_pointers[i+1]` is the slice in `data` for node i.
    pub row_pointers: Vec<u32>,

    /// Interleaved pairs: [(token_id, next_node_id), …]
    pub data: Vec<[u32; 2]>,

    /// Max observed branch factor at each trie depth: length = sid_length.
    pub max_branches: Vec<u32>,

    /// Total number of trie nodes (states).
    pub num_nodes: u32,

    /// Vocabulary size |V|.
    pub vocab_size: u32,

    /// Semantic ID length L.
    pub sid_length: u32,
}

impl TransitionMatrix {
    /// Construct an empty matrix for `num_nodes` nodes.
    ///
    /// Every row starts empty (`row_pointers` is all zeros, `data` is empty)
    /// and `max_branches` holds `sid_length` zeros.
    pub fn new(num_nodes: u32, vocab_size: u32, sid_length: u32) -> Self {
        Self {
            row_pointers: vec![0u32; num_nodes as usize + 1],
            data: Vec::new(),
            max_branches: vec![0u32; sid_length as usize],
            num_nodes,
            vocab_size,
            sid_length,
        }
    }

    /// Returns the children slice for `node` as `&[[token_id, next_node]; _]`.
    ///
    /// # Panics
    /// Panics if `node >= num_nodes`, or if `row_pointers` does not describe a
    /// valid range of `data` (see [`Self::check_invariants`]).
    #[inline]
    pub fn children(&self, node: u32) -> &[[u32; 2]] {
        assert!(
            node < self.num_nodes,
            "node {node} out of range (num_nodes={})",
            self.num_nodes
        );
        let start = self.row_pointers[node as usize] as usize;
        let end = self.row_pointers[node as usize + 1] as usize;
        &self.data[start..end]
    }

    /// Looks up the next node reached from `node` by emitting `token`.
    /// Returns `None` if the transition does not exist (invalid / masked).
    ///
    /// This is a linear scan over the node's children; it does not assume the
    /// children are sorted by token id.
    ///
    /// # Panics
    /// Panics if `node >= num_nodes`.
    #[inline]
    pub fn next_node(&self, node: u32, token: u32) -> Option<u32> {
        self.children(node)
            .iter()
            .find(|&&[t, _]| t == token)
            .map(|&[_, n]| n)
    }

    /// Returns `true` if `node` has no outgoing transitions (i.e. is a leaf).
    ///
    /// # Panics
    /// Panics if `node >= num_nodes`.
    #[inline]
    pub fn is_leaf(&self, node: u32) -> bool {
        self.children(node).is_empty()
    }

    /// Number of outgoing transitions (branches) from `node`.
    ///
    /// # Panics
    /// Panics if `node >= num_nodes`.
    #[inline]
    pub fn degree(&self, node: u32) -> u32 {
        self.children(node).len() as u32
    }

    /// Validates internal invariants; useful inside `debug_assert!`.
    ///
    /// Checks, in order:
    /// 1. `row_pointers.len() == num_nodes + 1`;
    /// 2. `row_pointers[0] == 0` (standard CSR: no orphaned leading edges);
    /// 3. the last row pointer equals `data.len()`;
    /// 4. `row_pointers` is non-decreasing;
    /// 5. every token id is `< vocab_size` and every next-node id is `< num_nodes`.
    ///
    /// Together, 2–4 guarantee that [`Self::children`] never slices out of bounds.
    ///
    /// # Errors
    /// Returns a human-readable description of the first violated invariant.
    pub fn check_invariants(&self) -> Result<(), String> {
        // Computed in usize so `num_nodes == u32::MAX` cannot overflow.
        let expected_len = self.num_nodes as usize + 1;
        if self.row_pointers.len() != expected_len {
            return Err(format!(
                "row_pointers length {} ≠ num_nodes+1 {expected_len}",
                self.row_pointers.len()
            ));
        }
        // `expected_len >= 1`, so both the head and the tail exist.
        let first = self.row_pointers[0];
        if first != 0 {
            return Err(format!("row_pointers head {first} ≠ 0"));
        }
        let last = self.row_pointers[expected_len - 1] as usize;
        if last != self.data.len() {
            return Err(format!(
                "row_pointers tail {last} ≠ data.len() {}",
                self.data.len()
            ));
        }
        // Rows must be non-decreasing
        for w in self.row_pointers.windows(2) {
            if w[0] > w[1] {
                return Err(format!("row_pointers not monotone: {} > {}", w[0], w[1]));
            }
        }
        // Token ids must be in [0, vocab_size) and next-node ids in [0, num_nodes)
        for &[tok, nxt] in &self.data {
            if tok >= self.vocab_size {
                return Err(format!("token {tok} ≥ vocab_size {}", self.vocab_size));
            }
            if nxt >= self.num_nodes {
                return Err(format!("next_node {nxt} ≥ num_nodes {}", self.num_nodes));
            }
        }
        Ok(())
    }
}

impl fmt::Display for TransitionMatrix {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(
            f,
            "TransitionMatrix(nodes={}, edges={}, |V|={}, L={})",
            self.num_nodes,
            self.data.len(),
            self.vocab_size,
            self.sid_length,
        )
    }
}

// ──────────────────────────────────────────────────────────────────────────────
// DenseMask
// ──────────────────────────────────────────────────────────────────────────────

/// Bit-packed dense mask for the first `depth` trie layers.
///
/// For `depth = 2` and `vocab_size = V` this is a V × V bit matrix stored as
/// packed `u64` words (row-major).  A set bit at linear index
/// `i * vocab_size + j` means the 2-token prefix `[i, j]` exists in C.
///
/// `states[i * vocab_size + j]` is the trie node reached after emitting
/// `[i, j]`; 0 is used as a sentinel for *invalid* entries (the root is
/// never a valid post-prefix destination in practice).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct DenseMask {
    /// Packed u64 bits.  Length = ceil(vocab_size^depth / 64).
    pub bits: Vec<u64>,

    /// Flat node-ID lookup: length = vocab_size ^ depth.
    /// Entry is 0 (sentinel) when the corresponding prefix is absent.
    pub states: Vec<u32>,

    /// Number of dense layers `d`.
    pub depth: u32,

    /// Vocabulary size |V|.
    pub vocab_size: u32,
}

impl DenseMask {
    /// Allocate a zeroed mask for `vocab_size^depth` entries.
    ///
    /// # Panics
    /// Panics if `vocab_size^depth` does not fit in `usize`. An unchecked power
    /// would wrap in release builds and produce an undersized mask whose
    /// indices alias each other.
    pub fn new(vocab_size: u32, depth: u32) -> Self {
        let total = (vocab_size as usize)
            .checked_pow(depth)
            .unwrap_or_else(|| {
                panic!("DenseMask: vocab_size^depth ({vocab_size}^{depth}) overflows usize")
            });
        let words = total.div_ceil(64);
        Self {
            bits: vec![0u64; words],
            states: vec![0u32; total],
            depth,
            vocab_size,
        }
    }

    /// Converts a token sequence of length `depth` to a flat (row-major) index.
    ///
    /// No range check is applied to individual tokens: a token `>= vocab_size`
    /// yields the index of a *different* prefix. Lookups should go through
    /// [`Self::contains`] / [`Self::state_for`], which reject such tokens.
    ///
    /// # Panics
    /// Panics in debug builds if `tokens.len() != depth`.
    #[inline]
    pub fn flat_index(&self, tokens: &[u32]) -> usize {
        debug_assert_eq!(tokens.len(), self.depth as usize);
        tokens.iter().fold(0usize, |acc, &t| {
            acc * self.vocab_size as usize + t as usize
        })
    }

    /// Range-checked flat index: `None` when `tokens` has the wrong length,
    /// contains a token `>= vocab_size`, or the index would overflow `usize`.
    #[inline]
    fn checked_index(&self, tokens: &[u32]) -> Option<usize> {
        if tokens.len() != self.depth as usize {
            return None;
        }
        let v = self.vocab_size as usize;
        tokens.iter().try_fold(0usize, |acc, &t| {
            let t = t as usize;
            if t >= v {
                return None;
            }
            acc.checked_mul(v)?.checked_add(t)
        })
    }

    /// Reads the validity bit at flat index `idx`; indices past `bits` read as unset.
    #[inline]
    fn bit_at(&self, idx: usize) -> bool {
        self.bits
            .get(idx / 64)
            .is_some_and(|&word| (word >> (idx % 64)) & 1 == 1)
    }

    /// Sets the bit and stores the destination `node_id` for prefix `tokens`.
    ///
    /// # Panics
    /// Panics if `tokens` is not a valid prefix: wrong length or any token
    /// `>= vocab_size`. Such a prefix would otherwise silently overwrite the
    /// entry of an unrelated prefix.
    pub fn insert(&mut self, tokens: &[u32], node_id: u32) {
        debug_assert_eq!(tokens.len(), self.depth as usize);
        let idx = self.checked_index(tokens).unwrap_or_else(|| {
            panic!(
                "prefix {tokens:?} is not a valid depth-{} prefix over vocab_size {}",
                self.depth, self.vocab_size
            )
        });
        self.bits[idx / 64] |= 1u64 << (idx % 64);
        self.states[idx] = node_id;
    }

    /// Returns `true` if the prefix encoded by `tokens` is marked valid.
    ///
    /// Prefixes containing a token `>= vocab_size` are reported as absent
    /// rather than aliasing onto another prefix.
    ///
    /// # Panics
    /// Panics in debug builds if `tokens.len() != depth`.
    #[inline]
    pub fn contains(&self, tokens: &[u32]) -> bool {
        debug_assert_eq!(tokens.len(), self.depth as usize);
        self.checked_index(tokens)
            .is_some_and(|idx| self.bit_at(idx))
    }

    /// Two-token shorthand used pervasively in tests and the decoder.
    /// Equivalent to `contains(&[v1, v2])` when `depth == 2`.
    #[inline]
    pub fn get(&self, v1: u32, v2: u32) -> bool {
        debug_assert_eq!(self.depth, 2, "get(v1,v2) requires depth == 2");
        self.contains(&[v1, v2])
    }

    /// Returns the trie node reached after the valid prefix `tokens`,
    /// or `None` if the prefix is absent (including out-of-vocabulary tokens).
    ///
    /// # Panics
    /// Panics in debug builds if `tokens.len() != depth`.
    #[inline]
    pub fn state_for(&self, tokens: &[u32]) -> Option<u32> {
        debug_assert_eq!(tokens.len(), self.depth as usize);
        let idx = self.checked_index(tokens)?;
        if self.bit_at(idx) {
            self.states.get(idx).copied()
        } else {
            None
        }
    }

    /// Iterates over all valid prefixes as `(tokens, node_id)`, in ascending
    /// flat-index order.
    ///
    /// All-zero words are skipped wholesale, so sparse masks are cheap to
    /// enumerate. Bits beyond the last real entry are ignored.
    pub fn iter_valid(&self) -> impl Iterator<Item = (Vec<u32>, u32)> + '_ {
        let d = self.depth as usize;
        let v = self.vocab_size as usize;
        let total = self.states.len();
        self.bits
            .iter()
            .enumerate()
            .filter(|&(_, &word)| word != 0)
            .flat_map(|(w, &word)| {
                (0..64usize)
                    .filter(move |&b| (word >> b) & 1 == 1)
                    .map(move |b| w * 64 + b)
            })
            // Indices are produced in ascending order, so stop at the first padding bit.
            .take_while(move |&idx| idx < total)
            .map(move |idx| {
                // Decode the flat index back into its token sequence (last token = lowest digit).
                let mut rem = idx;
                let mut toks = vec![0u32; d];
                for slot in toks.iter_mut().rev() {
                    *slot = (rem % v) as u32;
                    rem /= v;
                }
                (toks, self.states[idx])
            })
    }
}

impl fmt::Display for DenseMask {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let valid = self.bits.iter().map(|w| w.count_ones()).sum::<u32>();
        write!(
            f,
            "DenseMask(depth={}, |V|={}, valid_prefixes={valid})",
            self.depth, self.vocab_size
        )
    }
}

// ──────────────────────────────────────────────────────────────────────────────
// StaticIndex
// ──────────────────────────────────────────────────────────────────────────────

262/// Combined STATIC index: dense mask for the first `d` layers,
263/// sparse CSR matrix for all deeper layers.
264#[derive(Debug, Clone)]
265pub struct StaticIndex {
266    /// Bit-packed dense mask covering the first `d` trie levels.
267    pub dense: DenseMask,
268
269    /// CSR transition matrix for levels `d..L`.
270    pub sparse: TransitionMatrix,
271
272    /// Total number of constraints |C|.
273    pub num_constraints: usize,
274}
275
276impl StaticIndex {
277    pub fn new(dense: DenseMask, sparse: TransitionMatrix, num_constraints: usize) -> Self {
278        Self {
279            dense,
280            sparse,
281            num_constraints,
282        }
283    }
284
285    /// Quick sanity check delegating to both sub-structures.
286    pub fn check_invariants(&self) -> Result<(), String> {
287        self.sparse.check_invariants()
288    }
289}
290
291impl fmt::Display for StaticIndex {
292    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
293        write!(
294            f,
295            "StaticIndex(|C|={}, {}, {})",
296            self.num_constraints, self.dense, self.sparse
297        )
298    }
299}

// ──────────────────────────────────────────────────────────────────────────────
// BeamState
// ──────────────────────────────────────────────────────────────────────────────

/// Live decoding state for a single batch of beam searches.
///
/// Shapes:
/// - `nodes`  : `[batch_size][beam_width]`  — current trie node per beam
/// - `scores` : `[batch_size][beam_width]`  — accumulated log-probability
/// - `tokens` : `[batch_size][beam_width][step]` — partial SID decoded so far
#[derive(Debug, Clone)]
pub struct BeamState {
    /// Current trie node per beam, `[batch_size][beam_width]`.
    pub nodes: Vec<Vec<u32>>,
    /// Accumulated log-probability per beam, `[batch_size][beam_width]`.
    pub scores: Vec<Vec<f64>>,
    /// Tokens emitted so far per beam, `[batch_size][beam_width][step]`.
    pub tokens: Vec<Vec<Vec<u32>>>,
}

impl BeamState {
    /// Creates a blank state for `batch_size` queries, each with `beam_width`
    /// beams, all starting at the trie root (node 0) with log-prob 0.0.
    pub fn new(batch_size: usize, beam_width: usize) -> Self {
        Self {
            nodes: vec![vec![0u32; beam_width]; batch_size],
            scores: vec![vec![0.0f64; beam_width]; batch_size],
            tokens: vec![vec![Vec::new(); beam_width]; batch_size],
        }
    }

    /// Number of queries in the batch.
    pub fn batch_size(&self) -> usize {
        self.nodes.len()
    }

    /// Number of beams per query, read from the first batch row (0 for an empty batch).
    pub fn beam_width(&self) -> usize {
        self.nodes.first().map_or(0, Vec::len)
    }

    /// Returns the current decoding step (number of tokens emitted so far).
    /// Assumes all beams are at the same step; only beam 0 of batch 0 is inspected.
    pub fn step(&self) -> usize {
        self.tokens
            .first()
            .and_then(|b| b.first())
            .map_or(0, Vec::len)
    }

    /// Flattens `nodes` into a single `Vec<u32>` of length
    /// `batch_size * beam_width` for bulk VNTK calls.
    pub fn flat_nodes(&self) -> Vec<u32> {
        self.nodes
            .iter()
            .flat_map(|row| row.iter().copied())
            .collect()
    }

    /// Replaces the state from a flat representation produced by the decoder.
    ///
    /// `flat_nodes`, `flat_scores` and `flat_tokens` are laid out row-major as
    /// `[batch][beam]` and must each have length `batch_size * beam_width`.
    ///
    /// # Panics
    /// Panics if any slice length differs from `batch_size * beam_width`.
    /// Without this check, oversized inputs would be silently truncated.
    pub fn update_from_flat(
        &mut self,
        flat_nodes: &[u32],
        flat_scores: &[f64],
        flat_tokens: &[Vec<u32>],
    ) {
        let bw = self.beam_width();
        let expected = self.batch_size() * bw;
        assert_eq!(
            flat_nodes.len(),
            expected,
            "flat_nodes length must equal batch_size * beam_width"
        );
        assert_eq!(
            flat_scores.len(),
            expected,
            "flat_scores length must equal batch_size * beam_width"
        );
        assert_eq!(
            flat_tokens.len(),
            expected,
            "flat_tokens length must equal batch_size * beam_width"
        );

        for (b, row) in self.nodes.iter_mut().enumerate() {
            row.copy_from_slice(&flat_nodes[b * bw..(b + 1) * bw]);
        }
        for (b, row) in self.scores.iter_mut().enumerate() {
            row.copy_from_slice(&flat_scores[b * bw..(b + 1) * bw]);
        }
        for (b, row) in self.tokens.iter_mut().enumerate() {
            for (w, toks) in row.iter_mut().enumerate() {
                // `clone_from` reuses each beam's existing buffer instead of reallocating.
                toks.clone_from(&flat_tokens[b * bw + w]);
            }
        }
    }

    /// Returns completed sequences (those whose `tokens` length == `sid_length`),
    /// batch-major then beam order.
    pub fn completed(&self, sid_length: usize) -> Vec<Vec<u32>> {
        self.tokens
            .iter()
            .flat_map(|batch| batch.iter())
            .filter(|seq| seq.len() == sid_length)
            .cloned()
            .collect()
    }
}

impl fmt::Display for BeamState {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(
            f,
            "BeamState(batch={}, beams={}, step={})",
            self.batch_size(),
            self.beam_width(),
            self.step(),
        )
    }
}

// ──────────────────────────────────────────────────────────────────────────────
// VntkOutput  (consumed by decoder.rs and the test module)
// ──────────────────────────────────────────────────────────────────────────────

/// Output of a single VNTK call for one decoding step.
///
/// Derives `PartialEq`/`Eq` (like `TransitionMatrix` and `DenseMask`) so that
/// decoder outputs can be compared directly with `assert_eq!`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct VntkOutput {
    /// Next trie-node IDs, one per valid (beam, child) slot.
    /// Length = `beam_width * max_branches_at_level` (padded with sentinel 0).
    pub next_nodes: Vec<u32>,

    /// Dense boolean mask over the vocabulary: `mask[t]` is `true` iff token
    /// `t` is a valid next token for *at least one* active beam.
    /// Length = `vocab_size`.
    pub mask: Vec<bool>,
}
410    /// Dense boolean mask over the vocabulary: `mask[t]` is `true` iff token
411    /// `t` is a valid next token for *at least one* active beam.
412    /// Length = `vocab_size`.
413    pub mask: Vec<bool>,
414}