// constraint_decoding_trie/types.rs
use std::fmt;
/// Sparse trie transition table stored in CSR (compressed sparse row) form.
///
/// The outgoing edges of node `n` are `data[row_pointers[n]..row_pointers[n + 1]]`,
/// each edge being a `[token, next_node]` pair.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TransitionMatrix {
    /// CSR row offsets into `data`; must hold `num_nodes + 1` non-decreasing entries whose
    /// last value equals `data.len()`.
    pub row_pointers: Vec<u32>,

    /// Edge list as `[token, next_node]` pairs, grouped by source node.
    pub data: Vec<[u32; 2]>,

    /// One slot per SID position; [`TransitionMatrix::new`] allocates `sid_length` zeros.
    // NOTE(review): presumably the maximum out-degree observed at each depth — confirm with
    // the code that fills it in.
    pub max_branches: Vec<u32>,

    /// Number of trie nodes (CSR rows).
    pub num_nodes: u32,

    /// Vocabulary size; every edge token must be `< vocab_size`.
    pub vocab_size: u32,

    /// Length of each ID sequence (reported as `L` by `Display`).
    pub sid_length: u32,
}

impl TransitionMatrix {
    /// Creates a matrix with `num_nodes` rows and no edges (every node is a leaf).
    ///
    /// `row_pointers` is zero-filled and `max_branches` gets `sid_length` zeroed slots.
    pub fn new(num_nodes: u32, vocab_size: u32, sid_length: u32) -> Self {
        Self {
            row_pointers: vec![0u32; num_nodes as usize + 1],
            data: Vec::new(),
            max_branches: vec![0u32; sid_length as usize],
            num_nodes,
            vocab_size,
            sid_length,
        }
    }

    /// Returns the outgoing `[token, next_node]` edges of `node`.
    ///
    /// # Panics
    /// Panics if `node >= num_nodes`, or if `row_pointers` does not describe a valid
    /// range of `data`.
    #[inline]
    pub fn children(&self, node: u32) -> &[[u32; 2]] {
        assert!(
            node < self.num_nodes,
            "node {node} out of range (num_nodes={})",
            self.num_nodes
        );
        let start = self.row_pointers[node as usize] as usize;
        let end = self.row_pointers[node as usize + 1] as usize;
        &self.data[start..end]
    }

    /// Follows the edge labelled `token` out of `node`, returning the target node if such
    /// an edge exists.
    ///
    /// This is a linear scan over the node's edges; the first matching edge wins.
    ///
    /// # Panics
    /// Same conditions as [`TransitionMatrix::children`].
    #[inline]
    pub fn next_node(&self, node: u32, token: u32) -> Option<u32> {
        self.children(node)
            .iter()
            .find(|&&[t, _]| t == token)
            .map(|&[_, n]| n)
    }

    /// Returns `true` if `node` has no outgoing edges.
    ///
    /// # Panics
    /// Same conditions as [`TransitionMatrix::children`].
    #[inline]
    pub fn is_leaf(&self, node: u32) -> bool {
        self.children(node).is_empty()
    }

    /// Number of outgoing edges of `node`.
    ///
    /// # Panics
    /// Same conditions as [`TransitionMatrix::children`].
    #[inline]
    pub fn degree(&self, node: u32) -> u32 {
        self.children(node).len() as u32
    }

    /// Validates the structural CSR invariants.
    ///
    /// The checks run in this order:
    /// 1. `row_pointers` has `num_nodes + 1` entries.
    /// 2. Its last entry equals `data.len()`.
    /// 3. It is non-decreasing.
    /// 4. Every edge token is `< vocab_size`.
    /// 5. Every edge target is `< num_nodes`.
    ///
    /// # Errors
    /// Returns a human-readable description of the first violated invariant.
    pub fn check_invariants(&self) -> Result<(), String> {
        // Widened to u64 so that `num_nodes == u32::MAX` cannot overflow (debug panic /
        // release wrap-around) while building the error message.
        let expected_rows = u64::from(self.num_nodes) + 1;
        if self.row_pointers.len() as u64 != expected_rows {
            return Err(format!(
                "row_pointers length {} ≠ num_nodes+1 {expected_rows}",
                self.row_pointers.len()
            ));
        }

        let last = *self
            .row_pointers
            .last()
            .expect("row_pointers has num_nodes + 1 >= 1 entries after the length check")
            as usize;
        if last != self.data.len() {
            return Err(format!(
                "row_pointers tail {last} ≠ data.len() {}",
                self.data.len()
            ));
        }

        if let Some(w) = self.row_pointers.windows(2).find(|w| w[0] > w[1]) {
            return Err(format!("row_pointers not monotone: {} > {}", w[0], w[1]));
        }

        for &[tok, nxt] in &self.data {
            if tok >= self.vocab_size {
                return Err(format!("token {tok} ≥ vocab_size {}", self.vocab_size));
            }
            if nxt >= self.num_nodes {
                return Err(format!("next_node {nxt} ≥ num_nodes {}", self.num_nodes));
            }
        }
        Ok(())
    }
}

impl fmt::Display for TransitionMatrix {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(
            f,
            "TransitionMatrix(nodes={}, edges={}, |V|={}, L={})",
            self.num_nodes,
            self.data.len(),
            self.vocab_size,
            self.sid_length,
        )
    }
}

/// Dense lookup table over every token sequence of length `depth`.
///
/// A sequence `[t0, t1, …]` is flattened row-major (first token most significant) into an
/// index in `0..vocab_size^depth`. Bit `idx` of `bits` marks that sequence as valid, and
/// `states[idx]` holds the node id recorded for it by [`DenseMask::insert`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct DenseMask {
    /// Validity bitmap: flat index `idx` is bit `idx % 64` of word `idx / 64`.
    pub bits: Vec<u64>,

    /// Node id per flat index; only meaningful where the corresponding bit is set.
    pub states: Vec<u32>,

    /// Length of every indexed token sequence.
    pub depth: u32,

    /// Number of token ids; valid tokens are `0..vocab_size`.
    pub vocab_size: u32,
}

impl DenseMask {
    /// Creates an empty mask (no valid sequences) covering all `vocab_size^depth` sequences.
    ///
    /// # Panics
    /// Panics if `vocab_size^depth` does not fit in `usize`.
    pub fn new(vocab_size: u32, depth: u32) -> Self {
        // Checked so an oversized table fails loudly instead of wrapping to a tiny
        // allocation in release builds.
        let total = (vocab_size as usize)
            .checked_pow(depth)
            .unwrap_or_else(|| panic!("DenseMask size {vocab_size}^{depth} overflows usize"));
        let words = total.div_ceil(64);
        Self {
            bits: vec![0u64; words],
            states: vec![0u32; total],
            depth,
            vocab_size,
        }
    }

    /// Flattens `tokens` into its row-major index (first token most significant).
    ///
    /// The caller must pass exactly `depth` tokens, each `< vocab_size`. Otherwise the
    /// result may alias another sequence or fall outside the table. Both preconditions are
    /// checked only in debug builds.
    #[inline]
    pub fn flat_index(&self, tokens: &[u32]) -> usize {
        debug_assert_eq!(tokens.len(), self.depth as usize);
        debug_assert!(
            tokens.iter().all(|&t| t < self.vocab_size),
            "token out of range in {tokens:?} (vocab_size={})",
            self.vocab_size
        );
        tokens
            .iter()
            .fold(0usize, |acc, &t| acc * self.vocab_size as usize + t as usize)
    }

    /// Returns the flat index of `tokens` if it is a well-formed sequence, i.e. exactly
    /// `depth` tokens, all `< vocab_size`. Otherwise returns `None`.
    ///
    /// A length mismatch is still a debug assertion, matching [`DenseMask::flat_index`].
    #[inline]
    fn checked_index(&self, tokens: &[u32]) -> Option<usize> {
        debug_assert_eq!(tokens.len(), self.depth as usize);
        let well_formed = tokens.len() == self.depth as usize
            && tokens.iter().all(|&t| t < self.vocab_size);
        well_formed.then(|| self.flat_index(tokens))
    }

    /// Whether bit `idx` of the validity bitmap is set.
    #[inline]
    fn bit(&self, idx: usize) -> bool {
        (self.bits[idx / 64] >> (idx % 64)) & 1 == 1
    }

    /// Marks `tokens` as valid and records `node_id` for it, overwriting any previous id.
    ///
    /// # Panics
    /// Panics if `tokens` is not exactly `depth` tokens or contains a token `>= vocab_size`.
    /// Such input would otherwise silently mark an unrelated sequence as valid.
    pub fn insert(&mut self, tokens: &[u32], node_id: u32) {
        let Some(idx) = self.checked_index(tokens) else {
            panic!(
                "cannot insert {tokens:?}: expected {} tokens each < {}",
                self.depth, self.vocab_size
            );
        };
        self.bits[idx / 64] |= 1u64 << (idx % 64);
        self.states[idx] = node_id;
    }

    /// Returns `true` if `tokens` has been inserted.
    ///
    /// A sequence containing a token `>= vocab_size` is never valid and yields `false`,
    /// rather than aliasing some other sequence's entry.
    #[inline]
    pub fn contains(&self, tokens: &[u32]) -> bool {
        self.checked_index(tokens).is_some_and(|idx| self.bit(idx))
    }

    /// Convenience form of [`DenseMask::contains`] for two-token masks.
    ///
    /// The `depth == 2` requirement is a debug assertion.
    #[inline]
    pub fn get(&self, v1: u32, v2: u32) -> bool {
        debug_assert_eq!(self.depth, 2, "get(v1,v2) requires depth == 2");
        self.contains(&[v1, v2])
    }

    /// Returns the node id recorded for `tokens`, or `None` if it is not a valid sequence.
    #[inline]
    pub fn state_for(&self, tokens: &[u32]) -> Option<u32> {
        // Compute the index once and reuse it for both the bit test and the lookup.
        self.checked_index(tokens)
            .filter(|&idx| self.bit(idx))
            .map(|idx| self.states[idx])
    }

    /// Iterates over valid sequences in ascending flat-index order, yielding
    /// `(tokens, node_id)` pairs.
    ///
    /// All-zero bitmap words are skipped in a single step rather than bit by bit.
    pub fn iter_valid(&self) -> impl Iterator<Item = (Vec<u32>, u32)> + '_ {
        let d = self.depth as usize;
        let v = self.vocab_size as usize;
        let total = v.pow(self.depth);
        self.bits
            .iter()
            .enumerate()
            .filter(|&(_, &word)| word != 0)
            .flat_map(|(w, &word)| {
                (0..64usize)
                    .filter(move |&b| (word >> b) & 1 == 1)
                    .map(move |b| w * 64 + b)
            })
            // Ignore any padding bits past the last real index in the final word.
            .filter(move |&idx| idx < total)
            .map(move |idx| {
                // Undo the row-major flattening: last token is the least significant digit.
                let mut toks = vec![0u32; d];
                let mut rem = idx;
                for slot in toks.iter_mut().rev() {
                    *slot = (rem % v) as u32;
                    rem /= v;
                }
                (toks, self.states[idx])
            })
    }
}

impl fmt::Display for DenseMask {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Summed in usize: a u32 total could overflow for very large masks.
        let valid: usize = self.bits.iter().map(|w| w.count_ones() as usize).sum();
        write!(
            f,
            "DenseMask(depth={}, |V|={}, valid_prefixes={valid})",
            self.depth, self.vocab_size
        )
    }
}

258#[derive(Debug, Clone)]
265pub struct StaticIndex {
266 pub dense: DenseMask,
268
269 pub sparse: TransitionMatrix,
271
272 pub num_constraints: usize,
274}
275
276impl StaticIndex {
277 pub fn new(dense: DenseMask, sparse: TransitionMatrix, num_constraints: usize) -> Self {
278 Self {
279 dense,
280 sparse,
281 num_constraints,
282 }
283 }
284
285 pub fn check_invariants(&self) -> Result<(), String> {
287 self.sparse.check_invariants()
288 }
289}
290
291impl fmt::Display for StaticIndex {
292 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
293 write!(
294 f,
295 "StaticIndex(|C|={}, {}, {})",
296 self.num_constraints, self.dense, self.sparse
297 )
298 }
299}
300
/// Beam-search state for a batch.
///
/// All three fields are indexed `[batch][beam]`. `tokens[b][w]` is the token sequence
/// accumulated so far by beam `w` of batch item `b`.
#[derive(Debug, Clone)]
pub struct BeamState {
    /// Node id of each beam (starts at 0).
    pub nodes: Vec<Vec<u32>>,
    /// Score of each beam (starts at 0.0).
    pub scores: Vec<Vec<f64>>,
    /// Tokens accumulated so far by each beam (starts empty).
    pub tokens: Vec<Vec<Vec<u32>>>,
}

impl BeamState {
    /// Creates `batch_size × beam_width` beams, all at node 0 with score 0.0 and no tokens.
    pub fn new(batch_size: usize, beam_width: usize) -> Self {
        Self {
            nodes: vec![vec![0u32; beam_width]; batch_size],
            scores: vec![vec![0.0f64; beam_width]; batch_size],
            tokens: vec![vec![Vec::new(); beam_width]; batch_size],
        }
    }

    /// Number of batch items.
    pub fn batch_size(&self) -> usize {
        self.nodes.len()
    }

    /// Beams per batch item, read from the first batch row (0 for an empty batch).
    pub fn beam_width(&self) -> usize {
        self.nodes.first().map_or(0, Vec::len)
    }

    /// Number of tokens accumulated so far, read from the first beam (0 if there is none).
    pub fn step(&self) -> usize {
        self.tokens
            .first()
            .and_then(|b| b.first())
            .map_or(0, Vec::len)
    }

    /// Node ids flattened in row-major `[batch][beam]` order.
    pub fn flat_nodes(&self) -> Vec<u32> {
        self.nodes.concat()
    }

    /// Overwrites every beam from row-major flat buffers, where beam `w` of batch item `b`
    /// reads index `b * beam_width + w`.
    ///
    /// Extra trailing elements in the inputs are ignored.
    ///
    /// # Panics
    /// Panics if any input holds fewer than `batch_size * beam_width` elements. The check
    /// runs before anything is written, so a rejected call leaves the state untouched
    /// instead of half-updated.
    pub fn update_from_flat(
        &mut self,
        flat_nodes: &[u32],
        flat_scores: &[f64],
        flat_tokens: &[Vec<u32>],
    ) {
        let bw = self.beam_width();
        let needed = self.batch_size() * bw;
        assert!(
            flat_nodes.len() >= needed
                && flat_scores.len() >= needed
                && flat_tokens.len() >= needed,
            "update_from_flat needs {needed} entries per input, got nodes={}, scores={}, tokens={}",
            flat_nodes.len(),
            flat_scores.len(),
            flat_tokens.len()
        );

        for (b, row) in self.nodes.iter_mut().enumerate() {
            row.copy_from_slice(&flat_nodes[b * bw..(b + 1) * bw]);
        }
        for (b, row) in self.scores.iter_mut().enumerate() {
            row.copy_from_slice(&flat_scores[b * bw..(b + 1) * bw]);
        }
        for (b, row) in self.tokens.iter_mut().enumerate() {
            for (w, toks) in row.iter_mut().enumerate() {
                // `clone_from` reuses the existing buffer's allocation where possible.
                toks.clone_from(&flat_tokens[b * bw + w]);
            }
        }
    }

    /// Clones every sequence (across all batch items, in `[batch][beam]` order) whose
    /// length equals `sid_length`.
    pub fn completed(&self, sid_length: usize) -> Vec<Vec<u32>> {
        self.tokens
            .iter()
            .flat_map(|batch| batch.iter())
            .filter(|seq| seq.len() == sid_length)
            .cloned()
            .collect()
    }
}

impl fmt::Display for BeamState {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(
            f,
            "BeamState(batch={}, beams={}, step={})",
            self.batch_size(),
            self.beam_width(),
            self.step(),
        )
    }
}

/// Output bundle of `VntkOutput` producers: next-node ids plus a boolean mask.
// NOTE(review): how the two vectors relate (matching lengths, mask layout) is not visible
// in this file — confirm against the code that builds this value.
#[derive(Debug, Clone)]
pub struct VntkOutput {
    /// Next node ids.
    pub next_nodes: Vec<u32>,

    /// Boolean mask.
    pub mask: Vec<bool>,
}