// constraint_decoding_trie/decoder.rs
1// src/decoder.rs
2
3use rayon::prelude::*;
4
5use crate::types::{BeamState, StaticIndex, VntkOutput};
6use crate::vntk::VntkResult;
7
8// ──────────────────────────────────────────────────────────────────────────────
9// Top-level decoder struct
10// ──────────────────────────────────────────────────────────────────────────────
11
/// Constrained beam-search decoder driving a two-tier (dense prefix mask +
/// sparse CSR trie) static index.
pub struct ConstrainedDecoder {
    /// The static index: `index.dense` (bit-packed prefix mask) covers the
    /// first `dense.depth` steps, `index.sparse` (CSR trie) the rest.
    pub index: StaticIndex,
    /// Number of beams kept per query.
    pub beam_width: usize, // M
    /// Number of queries decoded together.
    pub batch_size: usize, // B
}
17
18impl ConstrainedDecoder {
19 pub fn new(index: StaticIndex, beam_width: usize, batch_size: usize) -> Self {
20 Self {
21 index,
22 beam_width,
23 batch_size,
24 }
25 }
26
27 // ──────────────────────────────────────────────────────────────────────────
28 // Public: single decoding step (Algorithm 1, one iteration)
29 // ──────────────────────────────────────────────────────────────────────────
30
31 /// Execute one constrained decoding step.
32 ///
33 /// # Arguments
34 /// - `logits` — raw model outputs, shape \[B × M × |V|\]
35 /// - `state` — mutable beam state (nodes, scores, partial tokens)
36 /// - `step` — 0-indexed decoding step t
37 pub fn step(
38 &self,
39 logits: &[Vec<Vec<f64>>], // [B][M][|V|]
40 state: &mut BeamState,
41 step: usize,
42 ) {
43 let vocab = self.index.sparse.vocab_size as usize;
44 let b = self.batch_size;
45 let m = self.beam_width;
46
47 debug_assert_eq!(logits.len(), b);
48 debug_assert!(logits.iter().all(|q| q.len() == m));
49 debug_assert!(logits.iter().all(|q| q.iter().all(|bm| bm.len() == vocab)));
50
51 // ── Phase 1: LogSoftmax ───────────────────────────────────────────────
52 let log_probs = log_softmax_3d(logits);
53
54 // ── Phase 2: Constraint masking ───────────────────────────────────────
55 // Returns:
56 // masks : [B][M][|V|] — true = token is valid
57 // next_nodes : [B][M][B_t] — trie nodes after each valid token slot
58 let (masks, next_nodes) = if step < self.index.dense.depth as usize {
59 self.dense_lookup(state, step)
60 } else {
61 self.sparse_lookup(state, step)
62 };
63
64 // ── Phase 3: Apply mask → NEG_INF for invalid tokens ─────────────────
65 let masked = apply_mask(&log_probs, &masks);
66
67 // ── Phase 4: Beam search selection ───────────────────────────────────
68 // new_tokens : [B][M] — chosen token per surviving beam
69 // new_scores : [B][M] — updated cumulative log-prob
70 // src_beams : [B][M] — which old beam each new beam came from
71 let (new_tokens, new_scores, src_beams) = beam_search(&masked, &state.scores, m);
72
73 // ── Phase 5: State gather ─────────────────────────────────────────────
74 self.gather_state(
75 state,
76 &new_tokens,
77 &new_scores,
78 &src_beams,
79 &next_nodes,
80 step,
81 );
82 }
83
84 // ──────────────────────────────────────────────────────────────────────────
85 // Public: full decoding loop (Algorithm 1 complete)
86 // ──────────────────────────────────────────────────────────────────────────
87
88 /// Run the full constrained beam-search loop for `sid_length` steps.
89 ///
90 /// `logit_fn` is called once per step; it receives the current `BeamState`
91 /// and must return logits of shape `[B × M × |V|]`.
92 ///
93 /// Returns the top-`beam_width` decoded SIDs for every query in the batch.
94 pub fn decode<F>(&self, logit_fn: F, sid_length: usize) -> Vec<Vec<Vec<u32>>>
95 // [B][M][L]
96 where
97 F: Fn(&BeamState, usize) -> Vec<Vec<Vec<f64>>>,
98 {
99 let mut state = BeamState::new(self.batch_size, self.beam_width);
100
101 for step in 0..sid_length {
102 let logits = logit_fn(&state, step);
103 self.step(&logits, &mut state, step);
104 }
105
106 // Return the token sequences accumulated in state.
107 state.tokens.clone()
108 }
109
110 // ──────────────────────────────────────────────────────────────────────────
111 // Phase 2a: dense lookup (steps 0 .. dense_depth−1)
112 // ──────────────────────────────────────────────────────────────────────────
113
114 /// For steps covered by the bit-packed dense mask, look up validity in O(1)
115 /// per token without touching the CSR matrix.
116 ///
117 /// Returns `(masks, next_nodes)` shaped `[B][M][|V|]` and `[B][M][1]`
118 /// respectively (one "next node" per beam; the trie node reached after the
119 /// chosen token is resolved lazily in `gather_state` from the dense mask's
120 /// `states` array).
121 pub fn dense_lookup(
122 &self,
123 state: &BeamState,
124 step: usize,
125 ) -> (Vec<Vec<Vec<bool>>>, Vec<Vec<Vec<u32>>>) {
126 let vocab = self.index.sparse.vocab_size as usize;
127 let depth = self.index.dense.depth as usize;
128 let b = self.batch_size;
129 let m = self.beam_width;
130
131 debug_assert!(step < depth, "dense_lookup called outside dense range");
132 debug_assert!(depth >= 1);
133
134 // masks[b][m][v] = token validity at this step for each beam
135 let mut masks: Vec<Vec<Vec<bool>>> = vec![vec![vec![false; vocab]; m]; b];
136
137 // next_nodes is not used for dense steps in our gather logic; keep shape stable.
138 let next_nodes: Vec<Vec<Vec<u32>>> = vec![vec![vec![0u32; 1]; m]; b];
139
140 for bi in 0..b {
141 for mi in 0..m {
142 let prev = &state.tokens[bi][mi];
143 debug_assert_eq!(prev.len(), step);
144
145 if step == 0 {
146 // Step 0: allow tokens that start at least one valid dense prefix.
147 for tok in 0..vocab as u32 {
148 if self.index.dense.first_token_valid(tok) {
149 masks[bi][mi][tok as usize] = true;
150 }
151 }
152 continue;
153 }
154
155 // step >= 1: extend the prefix by one candidate token and test
156 for tok in 0..vocab as u32 {
157 let mut candidate = prev.clone();
158 candidate.push(tok);
159
160 let valid = if candidate.len() == depth {
161 // Boundary case: full dense prefix, must be exact membership
162 self.index.dense.contains(&candidate)
163 } else {
164 // Proper partial prefix
165 self.index.dense.partial_prefix_has_extension(&candidate)
166 };
167
168 if valid {
169 masks[bi][mi][tok as usize] = true;
170 }
171 }
172 }
173 }
174
175 (masks, next_nodes)
176 }
177
178 /// Returns true if *any* full-depth dense entry starts with `tok`.
179 #[inline]
180 pub fn dense_first_token_valid(&self, tok: u32) -> bool {
181 self.index.dense.first_token_valid(tok)
182 }
183
184 /// Returns true if `partial_prefix` (length < depth) can be extended to a
185 /// valid full-depth prefix.
186 fn dense_prefix_has_extension(&self, partial_prefix: &[u32]) -> bool {
187 let vocab = self.index.sparse.vocab_size as usize;
188 let depth = self.index.dense.depth as usize;
189 let len = partial_prefix.len();
190 debug_assert!(len < depth);
191
192 // Flat index of the first entry in the block covered by partial_prefix.
193 let block_start: usize = partial_prefix
194 .iter()
195 .fold(0usize, |acc, &t| acc * vocab + t as usize);
196 let stride = vocab.pow((depth - len) as u32);
197 let base = block_start * stride;
198 let end = base + stride;
199 let ws = base / 64;
200 let we = end.div_ceil(64).min(self.index.dense.bits.len());
201 self.index.dense.bits[ws..we].iter().any(|&w| w != 0)
202 }
203
204 // ──────────────────────────────────────────────────────────────────────────
205 // Phase 2b: sparse lookup (steps dense_depth .. L−1)
206 // ──────────────────────────────────────────────────────────────────────────
207
208 /// For deeper steps, call VNTK on the CSR transition matrix.
209 ///
210 /// Returns `(masks, next_nodes)` shaped `[B][M][|V|]` and `[B][M][B_t]`.
211 pub fn sparse_lookup(
212 &self,
213 state: &BeamState,
214 step: usize,
215 ) -> (Vec<Vec<Vec<bool>>>, Vec<Vec<Vec<u32>>>) {
216 let vocab = self.index.sparse.vocab_size as usize;
217 let b = self.batch_size;
218 let m = self.beam_width;
219 let b_t = self.index.sparse.max_branches[step] as usize;
220
221 // Flatten [B][M] nodes into a single slice for a single VNTK call.
222 let flat_nodes: Vec<u32> = state.nodes.iter().flatten().copied().collect();
223
224 let result = self.index.sparse.vntk(&flat_nodes, step);
225
226 // Reshape VntkResult back to [B][M][…]
227 let mut masks: Vec<Vec<Vec<bool>>> = vec![vec![vec![false; vocab]; m]; b];
228 let mut next_nodes: Vec<Vec<Vec<u32>>> = vec![vec![vec![0u32; b_t]; m]; b];
229
230 for bi in 0..b {
231 for mi in 0..m {
232 let flat_i = bi * m + mi;
233
234 // Dense mask slice → masks[bi][mi][*]
235 let mask_slice = result.mask_for(flat_i, vocab);
236 masks[bi][mi].copy_from_slice(mask_slice);
237
238 // Next-node slots → next_nodes[bi][mi][*]
239 let base = flat_i * b_t;
240 next_nodes[bi][mi].copy_from_slice(&result.next_nodes[base..base + b_t]);
241 }
242 }
243
244 (masks, next_nodes)
245 }
246
247 // ──────────────────────────────────────────────────────────────────────────
248 // Phase 5: state gather
249 // ──────────────────────────────────────────────────────────────────────────
250
251 /// Applies the beam-search selection to the live `BeamState`.
252 ///
253 /// For each surviving beam in each batch entry:
254 /// 1. Copy the partial token sequence from the *source* beam.
255 /// 2. Append the newly chosen token.
256 /// 3. Advance the trie node pointer using `next_nodes`.
257 /// Applies the beam-search selection to the live `BeamState`.
258 ///
259 /// This implementation handles the transition from dense "prefix-only"
260 /// tracking to sparse "trie-node" tracking once the prefix length
261 /// matches `index.dense.depth`.
262 fn gather_state(
263 &self,
264 state: &mut BeamState,
265 new_tokens: &[Vec<u32>], // [B][M]
266 new_scores: &[Vec<f64>], // [B][M]
267 src_beams: &[Vec<usize>], // [B][M] — source beam for each new beam
268 next_nodes: &[Vec<Vec<u32>>], // [B][M][B_t] — from VNTK
269 step: usize,
270 ) {
271 let b = self.batch_size;
272 let m = self.beam_width;
273 let depth = self.index.dense.depth as usize;
274
275 // Snapshot current state to avoid reading partially updated sequences.
276 let old_tokens: Vec<Vec<Vec<u32>>> = state.tokens.clone();
277 let old_nodes: Vec<Vec<u32>> = state.nodes.clone();
278
279 for bi in 0..b {
280 for mi in 0..m {
281 let src_idx = src_beams[bi][mi];
282 let chosen_token = new_tokens[bi][mi];
283
284 // 1. Update cumulative score
285 state.scores[bi][mi] = new_scores[bi][mi];
286
287 // 2. Extend the sequence (copy-on-write from source beam)
288 let mut seq = old_tokens[bi][src_idx].clone();
289 seq.push(chosen_token);
290 state.tokens[bi][mi] = seq;
291
292 // 3. Advance the trie node
293 // step 0 creates a 1-token prefix; step (depth-1) creates a depth-token prefix.
294 let current_len = step + 1;
295
296 state.nodes[bi][mi] = if current_len < depth {
297 // Phase A: Still in dense marginalization territory.
298 // We don't have enough tokens to look up a specific trie node yet.
299 0
300 } else if current_len == depth {
301 // Phase B: Boundary reached.
302 // Use the bit-packed DenseMask to find the trie node starting the sparse layer.
303 let prefix = &state.tokens[bi][mi];
304 self.index.dense.state_for(prefix).unwrap_or_else(|| {
305 debug_assert!(false, "Prefix {:?} missing in dense mask", prefix);
306 0
307 })
308 } else {
309 // Phase C: Deep sparse layer traversal using VNTK.
310 self.resolve_next_node(
311 old_nodes[bi][src_idx],
312 chosen_token,
313 &next_nodes[bi][src_idx],
314 step,
315 )
316 };
317 }
318 }
319 }
320
321 /// Resolves the next trie node for a beam that chose `token` at `step`,
322 /// given the pre-computed `next_node_slots` from VNTK.
323 ///
324 /// VNTK returns slots sorted by token ID, so we binary-search rather than
325 /// doing a linear scan or a second CSR lookup.
326 pub fn resolve_next_node(
327 &self,
328 current_node: u32,
329 token: u32,
330 next_node_slots: &[u32], // length B_t, parallel to sorted children
331 step: usize,
332 ) -> u32 {
333 // Children are sorted by token ID; binary-search for `token`.
334 let children = self.index.sparse.children(current_node);
335 match children.binary_search_by_key(&token, |&[t, _]| t) {
336 Ok(pos) if pos < next_node_slots.len() => next_node_slots[pos],
337 // Fallback: direct CSR lookup (should not happen in correct usage).
338 Ok(pos) => children[pos][1],
339 Err(_) => {
340 debug_assert!(
341 false,
342 "token {token} not found in children of node {current_node}"
343 );
344 0
345 }
346 }
347 }
348}
349
350// ──────────────────────────────────────────────────────────────────────────────
351// Pure helper functions
352// ──────────────────────────────────────────────────────────────────────────────
353
354/// Numerically stable log-softmax over the last axis.
355/// Input / output shape: `[B][M][|V|]`.
356pub fn log_softmax_3d(logits: &[Vec<Vec<f64>>]) -> Vec<Vec<Vec<f64>>> {
357 logits
358 .par_iter()
359 .map(|query| query.iter().map(|beam| log_softmax_1d(beam)).collect())
360 .collect()
361}
362
/// Numerically stable log-softmax over a single 1-D slice.
///
/// The running maximum is subtracted before exponentiating so `exp` never
/// overflows; the result is `x[i] - max - ln(sum(exp(x - max)))`.
pub fn log_softmax_1d(x: &[f64]) -> Vec<f64> {
    let mut peak = f64::NEG_INFINITY;
    for &v in x {
        if v > peak {
            peak = v;
        }
    }

    let total: f64 = x.iter().map(|&v| (v - peak).exp()).sum();
    let lse = total.ln();

    let mut out = Vec::with_capacity(x.len());
    for &v in x {
        out.push(v - peak - lse);
    }
    out
}
369
370/// Applies a boolean constraint mask to log-probabilities.
371/// Invalid tokens (mask == false) are set to `f64::NEG_INFINITY`.
372/// Input / output shape: `[B][M][|V|]`.
373pub fn apply_mask(log_probs: &[Vec<Vec<f64>>], masks: &[Vec<Vec<bool>>]) -> Vec<Vec<Vec<f64>>> {
374 log_probs
375 .par_iter()
376 .zip(masks.par_iter())
377 .map(|(q_lp, q_mask)| {
378 q_lp.iter()
379 .zip(q_mask.iter())
380 .map(|(beam_lp, beam_mask)| {
381 beam_lp
382 .iter()
383 .zip(beam_mask.iter())
384 .map(|(&lp, &valid)| if valid { lp } else { f64::NEG_INFINITY })
385 .collect()
386 })
387 .collect()
388 })
389 .collect()
390}
391
392/// Beam search selection over masked log-probabilities.
393///
394/// Scores are accumulated as `parent_score + log_prob(token)`.
395///
396/// Returns `(new_tokens, new_scores, src_beams)`, all shaped `[B][M]`.
397pub fn beam_search(
398 masked_log_probs: &[Vec<Vec<f64>>], // [B][M][|V|]
399 parent_scores: &[Vec<f64>], // [B][M]
400 beam_width: usize,
401) -> (Vec<Vec<u32>>, Vec<Vec<f64>>, Vec<Vec<usize>>) {
402 let b = masked_log_probs.len();
403
404 // Process each query in the batch independently and in parallel.
405 let results: Vec<_> = (0..b)
406 .into_par_iter()
407 .map(|bi| {
408 let lp = &masked_log_probs[bi]; // [M][|V|]
409 let par = &parent_scores[bi]; // [M]
410 let vocab = lp[0].len();
411 let m = lp.len();
412
413 // Enumerate all (beam, token) candidates and score them.
414 let mut candidates: Vec<(f64, usize, u32)> = // (score, src_beam, token)
415 (0..m)
416 .flat_map(|mi| {
417 (0..vocab).filter_map(move |v| {
418 let lp_val = lp[mi][v];
419 if lp_val.is_finite() {
420 Some((par[mi] + lp_val, mi, v as u32))
421 } else {
422 None
423 }
424 })
425 })
426 .collect();
427
428 // Partial-sort: keep top `beam_width` by descending score.
429 candidates.sort_unstable_by(|a, b| b.0.partial_cmp(&a.0).unwrap());
430 candidates.truncate(beam_width);
431
432 // Separate into parallel vecs.
433 let new_scores: Vec<f64> = candidates.iter().map(|c| c.0).collect();
434 let src_beams: Vec<usize> = candidates.iter().map(|c| c.1).collect();
435 let new_tokens: Vec<u32> = candidates.iter().map(|c| c.2).collect();
436
437 (new_tokens, new_scores, src_beams)
438 })
439 .collect();
440
441 let new_tokens: Vec<Vec<u32>> = results.iter().map(|r| r.0.clone()).collect();
442 let new_scores: Vec<Vec<f64>> = results.iter().map(|r| r.1.clone()).collect();
443 let src_beams: Vec<Vec<usize>> = results.iter().map(|r| r.2.clone()).collect();
444
445 (new_tokens, new_scores, src_beams)
446}
447
448// ──────────────────────────────────────────────────────────────────────────────
449// Public convenience: full decode from flat uniform logits (used by tests)
450// ──────────────────────────────────────────────────────────────────────────────
451
452/// Runs the full decode loop using a *static* flat logit vector (same logits
453/// repeated for every batch entry, beam, and step). Useful for unit tests
454/// where the model is not available.
455pub fn constrained_beam_decode(
456 index: &StaticIndex,
457 flat_logits: &[f32], // length = vocab_size
458 sid_length: usize,
459 beam_width: usize,
460) -> Vec<Vec<u32>> {
461 let vocab = index.sparse.vocab_size as usize;
462 let logits_f64: Vec<f64> = flat_logits.iter().map(|&v| v as f64).collect();
463 // Shape: [1][beam_width][vocab_size]
464 let logits_3d = vec![vec![logits_f64; beam_width]];
465
466 let decoder = ConstrainedDecoder::new(index.clone(), beam_width, 1);
467 let sequences = decoder.decode(|_state, _step| logits_3d.clone(), sid_length);
468
469 sequences.into_iter().next().unwrap_or_default()
470}