// constraint_decoding_trie/transition.rs
use std::collections::HashMap;
use std::collections::VecDeque;

use crate::types::{DenseMask, StaticIndex, TransitionMatrix};

/// One node of the prefix trie built from the constraint sequences.
struct TrieNode {
    /// Outgoing edges, keyed by the token that labels the edge.
    children: HashMap<u32, Box<TrieNode>>,
    /// NOTE(review): `new` sets this to 0 and no code in this file writes or reads it
    /// afterwards (ids are assigned externally by `enumerate_nodes`); confirm it is still needed.
    node_id: u32,
    /// Depth of the node; the root is built at 0 and each edge adds 1.
    level: u32,
    /// `true` when some constraint sequence ends exactly at this node.
    is_terminal: bool,
}

impl TrieNode {
    /// Creates a childless, non-terminal node located at depth `level`.
    fn new(level: u32) -> Self {
        let children = HashMap::new();
        TrieNode {
            level,
            children,
            is_terminal: false,
            node_id: 0,
        }
    }
}

34fn build_trie(constraints: &[Vec<u32>], vocab_size: u32, _sid_length: u32) -> Box<TrieNode> {
40 let mut root = Box::new(TrieNode::new(0));
41
42 for seq in constraints {
43 let mut cur: *mut TrieNode = root.as_mut();
44
45 for &token in seq {
46 debug_assert!(
47 token < vocab_size,
48 "token {token} out of vocabulary (|V|={vocab_size})"
49 );
50 let node = unsafe { &mut *cur };
52 let level = node.level + 1;
53 let child = node
54 .children
55 .entry(token)
56 .or_insert_with(|| Box::new(TrieNode::new(level)));
57 cur = child.as_mut();
58 }
59
60 unsafe { (*cur).is_terminal = true };
62 }
63
64 root
65}
66
67fn enumerate_nodes(root: &TrieNode) -> (HashMap<*const TrieNode, u32>, Vec<u32>) {
80 let mut node_map: HashMap<*const TrieNode, u32> = HashMap::new();
81 let mut level_counts: Vec<u32> = Vec::new();
82 let mut queue: VecDeque<*const TrieNode> = VecDeque::new();
83 let mut next_id: u32 = 0;
84
85 queue.push_back(root as *const _);
86
87 while let Some(ptr) = queue.pop_front() {
88 let node = unsafe { &*ptr };
90
91 while level_counts.len() <= node.level as usize {
93 level_counts.push(0);
94 }
95 level_counts[node.level as usize] += 1;
96
97 node_map.insert(ptr, next_id);
98 next_id += 1;
99
100 let mut children: Vec<(&u32, &Box<TrieNode>)> = node.children.iter().collect();
103 children.sort_by_key(|(tok, _)| *tok);
104 for (_, child) in children {
105 queue.push_back(child.as_ref() as *const _);
106 }
107 }
108
109 (node_map, level_counts)
110}
111
112fn compute_max_branches(root: &TrieNode, sid_length: u32) -> Vec<u32> {
120 let mut max_branches = vec![0u32; sid_length as usize];
121 let mut stack: Vec<*const TrieNode> = vec![root as *const _];
122
123 while let Some(ptr) = stack.pop() {
124 let node = unsafe { &*ptr };
125 if (node.level as usize) < max_branches.len() {
126 let deg = node.children.len() as u32;
127 if deg > max_branches[node.level as usize] {
128 max_branches[node.level as usize] = deg;
129 }
130 }
131 for child in node.children.values() {
132 stack.push(child.as_ref() as *const _);
133 }
134 }
135
136 max_branches
137}
138
139fn build_csr(
144 root: &TrieNode,
145 node_map: &HashMap<*const TrieNode, u32>,
146 vocab_size: u32,
147 sid_length: u32,
148 max_branches: &[u32],
149) -> TransitionMatrix {
150 let num_nodes = node_map.len() as u32;
151
152 let mut nodes_by_id: Vec<*const TrieNode> = vec![std::ptr::null(); num_nodes as usize];
155 {
156 let mut stack: Vec<*const TrieNode> = vec![root as *const _];
157 while let Some(ptr) = stack.pop() {
158 let id = node_map[&ptr] as usize;
159 nodes_by_id[id] = ptr;
160 let node = unsafe { &*ptr };
161 for child in node.children.values() {
162 stack.push(child.as_ref() as *const _);
163 }
164 }
165 }
166
167 let mut row_pointers = Vec::with_capacity(num_nodes as usize + 1);
168 let mut data: Vec<[u32; 2]> = Vec::new();
169
170 let mut offset = 0u32;
171 for ptr in &nodes_by_id {
172 row_pointers.push(offset);
173
174 let node = unsafe { &**ptr };
175
176 let mut children: Vec<(u32, u32)> = node
178 .children
179 .iter()
180 .map(|(&tok, child)| {
181 let next_id = node_map[&(child.as_ref() as *const _)];
182 (tok, next_id)
183 })
184 .collect();
185 children.sort_by_key(|&(tok, _)| tok);
186
187 for (tok, next_id) in children {
188 data.push([tok, next_id]);
189 offset += 1;
190 }
191 }
192 row_pointers.push(offset); TransitionMatrix {
195 row_pointers,
196 data,
197 max_branches: max_branches.to_vec(),
198 num_nodes,
199 vocab_size,
200 sid_length,
201 }
202}
203
204fn build_dense_mask(
213 constraints: &[Vec<u32>],
214 root: &TrieNode,
215 vocab_size: u32,
216 dense_depth: u32,
217 node_map: &HashMap<*const TrieNode, u32>,
218) -> DenseMask {
219 let mut mask = DenseMask::new(vocab_size, dense_depth);
220
221 for seq in constraints {
222 let mut cur: *const TrieNode = root as *const _;
224
225 for (step, &token) in seq.iter().enumerate().take(dense_depth as usize) {
226 let node = unsafe { &*cur };
227 match node.children.get(&token) {
228 Some(child) => cur = child.as_ref() as *const _,
229 None => break, }
231
232 if step + 1 == dense_depth as usize {
234 let node_id = node_map[&cur];
235 let prefix = &seq[..dense_depth as usize];
236 mask.insert(prefix, node_id);
237 }
238 }
239 }
240
241 mask
242}
243
244pub fn build_static_index(
260 constraints: &[Vec<u32>],
261 vocab_size: u32,
262 sid_length: u32,
263 dense_depth: u32,
264) -> StaticIndex {
265 debug_assert!(
266 dense_depth <= sid_length,
267 "dense_depth ({dense_depth}) must be ≤ sid_length ({sid_length})"
268 );
269 debug_assert!(
270 constraints.iter().all(|s| s.len() == sid_length as usize),
271 "every constraint must have exactly sid_length={sid_length} tokens"
272 );
273
274 let trie = build_trie(constraints, vocab_size, sid_length);
276
277 let (node_map, _level_counts) = enumerate_nodes(&trie);
279
280 let max_branches = compute_max_branches(&trie, sid_length);
282
283 let sparse = build_csr(&trie, &node_map, vocab_size, sid_length, &max_branches);
285
286 let dense = build_dense_mask(constraints, &trie, vocab_size, dense_depth, &node_map);
288
289 #[cfg(debug_assertions)]
290 sparse
291 .check_invariants()
292 .expect("CSR invariants violated after construction");
293
294 StaticIndex {
295 dense,
296 sparse,
297 num_constraints: constraints.len(),
298 }
299}