//! Trie-based STATIC index construction for constrained decoding
//! (src/transition.rs of crate `constraint_decoding_trie`).
1// src/transition.rs
2
3use std::collections::HashMap;
4use std::collections::VecDeque;
5
6use crate::types::{DenseMask, StaticIndex, TransitionMatrix};
7
8// ──────────────────────────────────────────────────────────────────────────────
9// Internal trie node
10// ──────────────────────────────────────────────────────────────────────────────
11
/// A single node of the constraint trie built by `build_trie`.
struct TrieNode {
    /// Child nodes keyed by the token that leads to them.
    children: HashMap<u32, Box<TrieNode>>,
    /// BFS-assigned integer ID.
    /// NOTE(review): in the code visible here this field is never written
    /// after construction — `enumerate_nodes` records IDs in a side map
    /// (`node_map`) instead of mutating nodes. Confirm whether the field is
    /// still needed or is a leftover.
    node_id: u32,
    /// Depth of this node in the trie (root = 0).
    level: u32,
    /// True iff at least one constraint sequence ends here.
    is_terminal: bool,
}
22
23impl TrieNode {
24    fn new(level: u32) -> Self {
25        Self {
26            children: HashMap::new(),
27            node_id: 0,
28            level,
29            is_terminal: false,
30        }
31    }
32}
33
34// ──────────────────────────────────────────────────────────────────────────────
35// Phase 1 — trie construction
36// ──────────────────────────────────────────────────────────────────────────────
37
38/// Inserts all constraint sequences into a fresh trie and returns the root.
39fn build_trie(constraints: &[Vec<u32>], vocab_size: u32, _sid_length: u32) -> Box<TrieNode> {
40    let mut root = Box::new(TrieNode::new(0));
41
42    for seq in constraints {
43        let mut cur: *mut TrieNode = root.as_mut();
44
45        for &token in seq {
46            debug_assert!(
47                token < vocab_size,
48                "token {token} out of vocabulary (|V|={vocab_size})"
49            );
50            // SAFETY: `cur` always points to a live node owned by `root`.
51            let node = unsafe { &mut *cur };
52            let level = node.level + 1;
53            let child = node
54                .children
55                .entry(token)
56                .or_insert_with(|| Box::new(TrieNode::new(level)));
57            cur = child.as_mut();
58        }
59
60        // Mark the terminal node.
61        unsafe { (*cur).is_terminal = true };
62    }
63
64    root
65}
66
67// ──────────────────────────────────────────────────────────────────────────────
68// Phase 2 — BFS enumeration
69// ──────────────────────────────────────────────────────────────────────────────
70
71/// BFS over the trie, assigning monotonically increasing integer IDs in
72/// level-order.  Returns:
73///
74/// - `node_map`    : raw pointer → integer ID (used by CSR builder)
75/// - `level_nodes` : `level_nodes[l]` = number of nodes at depth l
76///
77/// BFS order guarantees that for any node with ID i, all nodes on shallower
78/// levels have IDs < i, which matches the paper's Figure 1d layout.
79fn enumerate_nodes(root: &TrieNode) -> (HashMap<*const TrieNode, u32>, Vec<u32>) {
80    let mut node_map: HashMap<*const TrieNode, u32> = HashMap::new();
81    let mut level_counts: Vec<u32> = Vec::new();
82    let mut queue: VecDeque<*const TrieNode> = VecDeque::new();
83    let mut next_id: u32 = 0;
84
85    queue.push_back(root as *const _);
86
87    while let Some(ptr) = queue.pop_front() {
88        // SAFETY: all pointers in the queue are live nodes owned by `root`.
89        let node = unsafe { &*ptr };
90
91        // Extend level_counts if this is the first node we see at this depth.
92        while level_counts.len() <= node.level as usize {
93            level_counts.push(0);
94        }
95        level_counts[node.level as usize] += 1;
96
97        node_map.insert(ptr, next_id);
98        next_id += 1;
99
100        // Enqueue children sorted by token so the ID assignment is deterministic
101        // across runs (important for reproducible Parquet snapshots).
102        let mut children: Vec<(&u32, &Box<TrieNode>)> = node.children.iter().collect();
103        children.sort_by_key(|(tok, _)| *tok);
104        for (_, child) in children {
105            queue.push_back(child.as_ref() as *const _);
106        }
107    }
108
109    (node_map, level_counts)
110}
111
112// ──────────────────────────────────────────────────────────────────────────────
113// Phase 3 — max branch factor per level
114// ──────────────────────────────────────────────────────────────────────────────
115
116/// Returns a Vec of length `sid_length` where entry `l` is the maximum number
117/// of children any node at depth `l` has.  Used by VNTK to size its output
118/// buffer without dynamic allocation.
119fn compute_max_branches(root: &TrieNode, sid_length: u32) -> Vec<u32> {
120    let mut max_branches = vec![0u32; sid_length as usize];
121    let mut stack: Vec<*const TrieNode> = vec![root as *const _];
122
123    while let Some(ptr) = stack.pop() {
124        let node = unsafe { &*ptr };
125        if (node.level as usize) < max_branches.len() {
126            let deg = node.children.len() as u32;
127            if deg > max_branches[node.level as usize] {
128                max_branches[node.level as usize] = deg;
129            }
130        }
131        for child in node.children.values() {
132            stack.push(child.as_ref() as *const _);
133        }
134    }
135
136    max_branches
137}
138
139// ──────────────────────────────────────────────────────────────────────────────
140// Phase 4 — CSR construction
141// ──────────────────────────────────────────────────────────────────────────────
142
143fn build_csr(
144    root: &TrieNode,
145    node_map: &HashMap<*const TrieNode, u32>,
146    vocab_size: u32,
147    sid_length: u32,
148    max_branches: &[u32],
149) -> TransitionMatrix {
150    let num_nodes = node_map.len() as u32;
151
152    // Collect all nodes into a Vec indexed by their BFS ID so we can fill
153    // row_pointers in strictly ascending order.
154    let mut nodes_by_id: Vec<*const TrieNode> = vec![std::ptr::null(); num_nodes as usize];
155    {
156        let mut stack: Vec<*const TrieNode> = vec![root as *const _];
157        while let Some(ptr) = stack.pop() {
158            let id = node_map[&ptr] as usize;
159            nodes_by_id[id] = ptr;
160            let node = unsafe { &*ptr };
161            for child in node.children.values() {
162                stack.push(child.as_ref() as *const _);
163            }
164        }
165    }
166
167    let mut row_pointers = Vec::with_capacity(num_nodes as usize + 1);
168    let mut data: Vec<[u32; 2]> = Vec::new();
169
170    let mut offset = 0u32;
171    for ptr in &nodes_by_id {
172        row_pointers.push(offset);
173
174        let node = unsafe { &**ptr };
175
176        // Sort children by token ID for deterministic, cache-friendly access.
177        let mut children: Vec<(u32, u32)> = node
178            .children
179            .iter()
180            .map(|(&tok, child)| {
181                let next_id = node_map[&(child.as_ref() as *const _)];
182                (tok, next_id)
183            })
184            .collect();
185        children.sort_by_key(|&(tok, _)| tok);
186
187        for (tok, next_id) in children {
188            data.push([tok, next_id]);
189            offset += 1;
190        }
191    }
192    row_pointers.push(offset); // sentinel
193
194    TransitionMatrix {
195        row_pointers,
196        data,
197        max_branches: max_branches.to_vec(),
198        num_nodes,
199        vocab_size,
200        sid_length,
201    }
202}
203
204// ──────────────────────────────────────────────────────────────────────────────
205// Phase 5 — dense mask
206// ──────────────────────────────────────────────────────────────────────────────
207
208/// Populates the bit-packed dense mask for all prefixes of length `dense_depth`.
209///
210/// Walks each constraint once, extracting the first `dense_depth` tokens and
211/// looking up the destination trie node via `node_map`.
212fn build_dense_mask(
213    constraints: &[Vec<u32>],
214    root: &TrieNode,
215    vocab_size: u32,
216    dense_depth: u32,
217    node_map: &HashMap<*const TrieNode, u32>,
218) -> DenseMask {
219    let mut mask = DenseMask::new(vocab_size, dense_depth);
220
221    for seq in constraints {
222        // Walk the trie along the first `dense_depth` tokens.
223        let mut cur: *const TrieNode = root as *const _;
224
225        for (step, &token) in seq.iter().enumerate().take(dense_depth as usize) {
226            let node = unsafe { &*cur };
227            match node.children.get(&token) {
228                Some(child) => cur = child.as_ref() as *const _,
229                None => break, // should never happen for valid constraints
230            }
231
232            // After consuming exactly `dense_depth` tokens, record the node.
233            if step + 1 == dense_depth as usize {
234                let node_id = node_map[&cur];
235                let prefix = &seq[..dense_depth as usize];
236                mask.insert(prefix, node_id);
237            }
238        }
239    }
240
241    mask
242}
243
244// ──────────────────────────────────────────────────────────────────────────────
245// Public entry point
246// ──────────────────────────────────────────────────────────────────────────────
247
248/// Build a STATIC index from a set of Semantic ID sequences.
249///
250/// # Arguments
251/// - `constraints` — `|C|` sequences of token IDs, each of length `sid_length`
252/// - `vocab_size`  — vocabulary size |V|; all tokens must be in `[0, vocab_size)`
253/// - `sid_length`  — fixed length L of every SID
254/// - `dense_depth` — number of layers to cover with the dense mask (typically 2)
255///
256/// # Panics
257/// Panics in debug mode if any token ≥ `vocab_size` or any sequence length ≠
258/// `sid_length`.
259pub fn build_static_index(
260    constraints: &[Vec<u32>],
261    vocab_size: u32,
262    sid_length: u32,
263    dense_depth: u32,
264) -> StaticIndex {
265    debug_assert!(
266        dense_depth <= sid_length,
267        "dense_depth ({dense_depth}) must be ≤ sid_length ({sid_length})"
268    );
269    debug_assert!(
270        constraints.iter().all(|s| s.len() == sid_length as usize),
271        "every constraint must have exactly sid_length={sid_length} tokens"
272    );
273
274    // ── Phase 1 ──────────────────────────────────────────────────────────────
275    let trie = build_trie(constraints, vocab_size, sid_length);
276
277    // ── Phase 2 ──────────────────────────────────────────────────────────────
278    let (node_map, _level_counts) = enumerate_nodes(&trie);
279
280    // ── Phase 3 ──────────────────────────────────────────────────────────────
281    let max_branches = compute_max_branches(&trie, sid_length);
282
283    // ── Phase 4 ──────────────────────────────────────────────────────────────
284    let sparse = build_csr(&trie, &node_map, vocab_size, sid_length, &max_branches);
285
286    // ── Phase 5 ──────────────────────────────────────────────────────────────
287    let dense = build_dense_mask(constraints, &trie, vocab_size, dense_depth, &node_map);
288
289    #[cfg(debug_assertions)]
290    sparse
291        .check_invariants()
292        .expect("CSR invariants violated after construction");
293
294    StaticIndex {
295        dense,
296        sparse,
297        num_constraints: constraints.len(),
298    }
299}