// constraint_decoding_trie — src/decoder.rs
use rayon::prelude::*;

use crate::types::{BeamState, StaticIndex, VntkOutput};
use crate::vntk::VntkResult;

// ──────────────────────────────────────────────────────────────────────────────
// Top-level decoder struct
// ──────────────────────────────────────────────────────────────────────────────

12pub struct ConstrainedDecoder {
13    pub index: StaticIndex,
14    pub beam_width: usize, // M
15    pub batch_size: usize, // B
16}
17
18impl ConstrainedDecoder {
19    pub fn new(index: StaticIndex, beam_width: usize, batch_size: usize) -> Self {
20        Self {
21            index,
22            beam_width,
23            batch_size,
24        }
25    }
26
27    // ──────────────────────────────────────────────────────────────────────────
28    // Public: single decoding step  (Algorithm 1, one iteration)
29    // ──────────────────────────────────────────────────────────────────────────
30
31    /// Execute one constrained decoding step.
32    ///
33    /// # Arguments
34    /// - `logits` — raw model outputs, shape \[B × M × |V|\]
35    /// - `state`  — mutable beam state (nodes, scores, partial tokens)
36    /// - `step`   — 0-indexed decoding step t
37    pub fn step(
38        &self,
39        logits: &[Vec<Vec<f64>>], // [B][M][|V|]
40        state: &mut BeamState,
41        step: usize,
42    ) {
43        let vocab = self.index.sparse.vocab_size as usize;
44        let b = self.batch_size;
45        let m = self.beam_width;
46
47        debug_assert_eq!(logits.len(), b);
48        debug_assert!(logits.iter().all(|q| q.len() == m));
49        debug_assert!(logits.iter().all(|q| q.iter().all(|bm| bm.len() == vocab)));
50
51        // ── Phase 1: LogSoftmax ───────────────────────────────────────────────
52        let log_probs = log_softmax_3d(logits);
53
54        // ── Phase 2: Constraint masking ───────────────────────────────────────
55        // Returns:
56        //   masks      : [B][M][|V|]  — true = token is valid
57        //   next_nodes : [B][M][B_t]  — trie nodes after each valid token slot
58        let (masks, next_nodes) = if step < self.index.dense.depth as usize {
59            self.dense_lookup(state, step)
60        } else {
61            self.sparse_lookup(state, step)
62        };
63
64        // ── Phase 3: Apply mask → NEG_INF for invalid tokens ─────────────────
65        let masked = apply_mask(&log_probs, &masks);
66
67        // ── Phase 4: Beam search selection ───────────────────────────────────
68        // new_tokens  : [B][M]      — chosen token per surviving beam
69        // new_scores  : [B][M]      — updated cumulative log-prob
70        // src_beams   : [B][M]      — which old beam each new beam came from
71        let (new_tokens, new_scores, src_beams) = beam_search(&masked, &state.scores, m);
72
73        // ── Phase 5: State gather ─────────────────────────────────────────────
74        self.gather_state(
75            state,
76            &new_tokens,
77            &new_scores,
78            &src_beams,
79            &next_nodes,
80            step,
81        );
82    }
83
84    // ──────────────────────────────────────────────────────────────────────────
85    // Public: full decoding loop  (Algorithm 1 complete)
86    // ──────────────────────────────────────────────────────────────────────────
87
88    /// Run the full constrained beam-search loop for `sid_length` steps.
89    ///
90    /// `logit_fn` is called once per step; it receives the current `BeamState`
91    /// and must return logits of shape `[B × M × |V|]`.
92    ///
93    /// Returns the top-`beam_width` decoded SIDs for every query in the batch.
94    pub fn decode<F>(&self, logit_fn: F, sid_length: usize) -> Vec<Vec<Vec<u32>>>
95    // [B][M][L]
96    where
97        F: Fn(&BeamState, usize) -> Vec<Vec<Vec<f64>>>,
98    {
99        let mut state = BeamState::new(self.batch_size, self.beam_width);
100
101        for step in 0..sid_length {
102            let logits = logit_fn(&state, step);
103            self.step(&logits, &mut state, step);
104        }
105
106        // Return the token sequences accumulated in state.
107        state.tokens.clone()
108    }
109
110    // ──────────────────────────────────────────────────────────────────────────
111    // Phase 2a: dense lookup  (steps 0 .. dense_depth−1)
112    // ──────────────────────────────────────────────────────────────────────────
113
114    /// For steps covered by the bit-packed dense mask, look up validity in O(1)
115    /// per token without touching the CSR matrix.
116    ///
117    /// Returns `(masks, next_nodes)` shaped `[B][M][|V|]` and `[B][M][1]`
118    /// respectively (one "next node" per beam; the trie node reached after the
119    /// chosen token is resolved lazily in `gather_state` from the dense mask's
120    /// `states` array).
121    pub fn dense_lookup(
122        &self,
123        state: &BeamState,
124        step: usize,
125    ) -> (Vec<Vec<Vec<bool>>>, Vec<Vec<Vec<u32>>>) {
126        let vocab = self.index.sparse.vocab_size as usize;
127        let depth = self.index.dense.depth as usize;
128        let b = self.batch_size;
129        let m = self.beam_width;
130
131        debug_assert!(step < depth, "dense_lookup called outside dense range");
132        debug_assert!(depth >= 1);
133
134        // masks[b][m][v] = token validity at this step for each beam
135        let mut masks: Vec<Vec<Vec<bool>>> = vec![vec![vec![false; vocab]; m]; b];
136
137        // next_nodes is not used for dense steps in our gather logic; keep shape stable.
138        let next_nodes: Vec<Vec<Vec<u32>>> = vec![vec![vec![0u32; 1]; m]; b];
139
140        for bi in 0..b {
141            for mi in 0..m {
142                let prev = &state.tokens[bi][mi];
143                debug_assert_eq!(prev.len(), step);
144
145                if step == 0 {
146                    // Step 0: allow tokens that start at least one valid dense prefix.
147                    for tok in 0..vocab as u32 {
148                        if self.index.dense.first_token_valid(tok) {
149                            masks[bi][mi][tok as usize] = true;
150                        }
151                    }
152                    continue;
153                }
154
155                // step >= 1: extend the prefix by one candidate token and test
156                for tok in 0..vocab as u32 {
157                    let mut candidate = prev.clone();
158                    candidate.push(tok);
159
160                    let valid = if candidate.len() == depth {
161                        // Boundary case: full dense prefix, must be exact membership
162                        self.index.dense.contains(&candidate)
163                    } else {
164                        // Proper partial prefix
165                        self.index.dense.partial_prefix_has_extension(&candidate)
166                    };
167
168                    if valid {
169                        masks[bi][mi][tok as usize] = true;
170                    }
171                }
172            }
173        }
174
175        (masks, next_nodes)
176    }
177
178    /// Returns true if *any* full-depth dense entry starts with `tok`.
179    #[inline]
180    pub fn dense_first_token_valid(&self, tok: u32) -> bool {
181        self.index.dense.first_token_valid(tok)
182    }
183
184    /// Returns true if `partial_prefix` (length < depth) can be extended to a
185    /// valid full-depth prefix.
186    fn dense_prefix_has_extension(&self, partial_prefix: &[u32]) -> bool {
187        let vocab = self.index.sparse.vocab_size as usize;
188        let depth = self.index.dense.depth as usize;
189        let len = partial_prefix.len();
190        debug_assert!(len < depth);
191
192        // Flat index of the first entry in the block covered by partial_prefix.
193        let block_start: usize = partial_prefix
194            .iter()
195            .fold(0usize, |acc, &t| acc * vocab + t as usize);
196        let stride = vocab.pow((depth - len) as u32);
197        let base = block_start * stride;
198        let end = base + stride;
199        let ws = base / 64;
200        let we = end.div_ceil(64).min(self.index.dense.bits.len());
201        self.index.dense.bits[ws..we].iter().any(|&w| w != 0)
202    }
203
204    // ──────────────────────────────────────────────────────────────────────────
205    // Phase 2b: sparse lookup  (steps dense_depth .. L−1)
206    // ──────────────────────────────────────────────────────────────────────────
207
208    /// For deeper steps, call VNTK on the CSR transition matrix.
209    ///
210    /// Returns `(masks, next_nodes)` shaped `[B][M][|V|]` and `[B][M][B_t]`.
211    pub fn sparse_lookup(
212        &self,
213        state: &BeamState,
214        step: usize,
215    ) -> (Vec<Vec<Vec<bool>>>, Vec<Vec<Vec<u32>>>) {
216        let vocab = self.index.sparse.vocab_size as usize;
217        let b = self.batch_size;
218        let m = self.beam_width;
219        let b_t = self.index.sparse.max_branches[step] as usize;
220
221        // Flatten [B][M] nodes into a single slice for a single VNTK call.
222        let flat_nodes: Vec<u32> = state.nodes.iter().flatten().copied().collect();
223
224        let result = self.index.sparse.vntk(&flat_nodes, step);
225
226        // Reshape VntkResult back to [B][M][…]
227        let mut masks: Vec<Vec<Vec<bool>>> = vec![vec![vec![false; vocab]; m]; b];
228        let mut next_nodes: Vec<Vec<Vec<u32>>> = vec![vec![vec![0u32; b_t]; m]; b];
229
230        for bi in 0..b {
231            for mi in 0..m {
232                let flat_i = bi * m + mi;
233
234                // Dense mask slice → masks[bi][mi][*]
235                let mask_slice = result.mask_for(flat_i, vocab);
236                masks[bi][mi].copy_from_slice(mask_slice);
237
238                // Next-node slots → next_nodes[bi][mi][*]
239                let base = flat_i * b_t;
240                next_nodes[bi][mi].copy_from_slice(&result.next_nodes[base..base + b_t]);
241            }
242        }
243
244        (masks, next_nodes)
245    }
246
247    // ──────────────────────────────────────────────────────────────────────────
248    // Phase 5: state gather
249    // ──────────────────────────────────────────────────────────────────────────
250
251    /// Applies the beam-search selection to the live `BeamState`.
252    ///
253    /// For each surviving beam in each batch entry:
254    /// 1. Copy the partial token sequence from the *source* beam.
255    /// 2. Append the newly chosen token.
256    /// 3. Advance the trie node pointer using `next_nodes`.
257    /// Applies the beam-search selection to the live `BeamState`.
258    ///
259    /// This implementation handles the transition from dense "prefix-only"
260    /// tracking to sparse "trie-node" tracking once the prefix length
261    /// matches `index.dense.depth`.
262    fn gather_state(
263        &self,
264        state: &mut BeamState,
265        new_tokens: &[Vec<u32>],      // [B][M]
266        new_scores: &[Vec<f64>],      // [B][M]
267        src_beams: &[Vec<usize>],     // [B][M] — source beam for each new beam
268        next_nodes: &[Vec<Vec<u32>>], // [B][M][B_t] — from VNTK
269        step: usize,
270    ) {
271        let b = self.batch_size;
272        let m = self.beam_width;
273        let depth = self.index.dense.depth as usize;
274
275        // Snapshot current state to avoid reading partially updated sequences.
276        let old_tokens: Vec<Vec<Vec<u32>>> = state.tokens.clone();
277        let old_nodes: Vec<Vec<u32>> = state.nodes.clone();
278
279        for bi in 0..b {
280            for mi in 0..m {
281                let src_idx = src_beams[bi][mi];
282                let chosen_token = new_tokens[bi][mi];
283
284                // 1. Update cumulative score
285                state.scores[bi][mi] = new_scores[bi][mi];
286
287                // 2. Extend the sequence (copy-on-write from source beam)
288                let mut seq = old_tokens[bi][src_idx].clone();
289                seq.push(chosen_token);
290                state.tokens[bi][mi] = seq;
291
292                // 3. Advance the trie node
293                // step 0 creates a 1-token prefix; step (depth-1) creates a depth-token prefix.
294                let current_len = step + 1;
295
296                state.nodes[bi][mi] = if current_len < depth {
297                    // Phase A: Still in dense marginalization territory.
298                    // We don't have enough tokens to look up a specific trie node yet.
299                    0
300                } else if current_len == depth {
301                    // Phase B: Boundary reached.
302                    // Use the bit-packed DenseMask to find the trie node starting the sparse layer.
303                    let prefix = &state.tokens[bi][mi];
304                    self.index.dense.state_for(prefix).unwrap_or_else(|| {
305                        debug_assert!(false, "Prefix {:?} missing in dense mask", prefix);
306                        0
307                    })
308                } else {
309                    // Phase C: Deep sparse layer traversal using VNTK.
310                    self.resolve_next_node(
311                        old_nodes[bi][src_idx],
312                        chosen_token,
313                        &next_nodes[bi][src_idx],
314                        step,
315                    )
316                };
317            }
318        }
319    }
320
321    /// Resolves the next trie node for a beam that chose `token` at `step`,
322    /// given the pre-computed `next_node_slots` from VNTK.
323    ///
324    /// VNTK returns slots sorted by token ID, so we binary-search rather than
325    /// doing a linear scan or a second CSR lookup.
326    pub fn resolve_next_node(
327        &self,
328        current_node: u32,
329        token: u32,
330        next_node_slots: &[u32], // length B_t, parallel to sorted children
331        step: usize,
332    ) -> u32 {
333        // Children are sorted by token ID; binary-search for `token`.
334        let children = self.index.sparse.children(current_node);
335        match children.binary_search_by_key(&token, |&[t, _]| t) {
336            Ok(pos) if pos < next_node_slots.len() => next_node_slots[pos],
337            // Fallback: direct CSR lookup (should not happen in correct usage).
338            Ok(pos) => children[pos][1],
339            Err(_) => {
340                debug_assert!(
341                    false,
342                    "token {token} not found in children of node {current_node}"
343                );
344                0
345            }
346        }
347    }
348}
349
// ──────────────────────────────────────────────────────────────────────────────
// Pure helper functions
// ──────────────────────────────────────────────────────────────────────────────
353
354/// Numerically stable log-softmax over the last axis.
355/// Input / output shape: `[B][M][|V|]`.
356pub fn log_softmax_3d(logits: &[Vec<Vec<f64>>]) -> Vec<Vec<Vec<f64>>> {
357    logits
358        .par_iter()
359        .map(|query| query.iter().map(|beam| log_softmax_1d(beam)).collect())
360        .collect()
361}
362
/// Numerically stable log-softmax over a single 1-D slice.
///
/// Shifts by the row maximum before exponentiating so that `exp` never
/// overflows; returns an empty vector for empty input.
pub fn log_softmax_1d(x: &[f64]) -> Vec<f64> {
    let max = x.iter().copied().fold(f64::NEG_INFINITY, f64::max);

    // Sum of shifted exponentials, accumulated in slice order.
    let mut sum = 0.0;
    for &v in x {
        sum += (v - max).exp();
    }
    let log_sum_exp = sum.ln();

    x.iter().map(|&v| v - max - log_sum_exp).collect()
}
369
370/// Applies a boolean constraint mask to log-probabilities.
371/// Invalid tokens (mask == false) are set to `f64::NEG_INFINITY`.
372/// Input / output shape: `[B][M][|V|]`.
373pub fn apply_mask(log_probs: &[Vec<Vec<f64>>], masks: &[Vec<Vec<bool>>]) -> Vec<Vec<Vec<f64>>> {
374    log_probs
375        .par_iter()
376        .zip(masks.par_iter())
377        .map(|(q_lp, q_mask)| {
378            q_lp.iter()
379                .zip(q_mask.iter())
380                .map(|(beam_lp, beam_mask)| {
381                    beam_lp
382                        .iter()
383                        .zip(beam_mask.iter())
384                        .map(|(&lp, &valid)| if valid { lp } else { f64::NEG_INFINITY })
385                        .collect()
386                })
387                .collect()
388        })
389        .collect()
390}
391
392/// Beam search selection over masked log-probabilities.
393///
394/// Scores are accumulated as `parent_score + log_prob(token)`.
395///
396/// Returns `(new_tokens, new_scores, src_beams)`, all shaped `[B][M]`.
397pub fn beam_search(
398    masked_log_probs: &[Vec<Vec<f64>>], // [B][M][|V|]
399    parent_scores: &[Vec<f64>],         // [B][M]
400    beam_width: usize,
401) -> (Vec<Vec<u32>>, Vec<Vec<f64>>, Vec<Vec<usize>>) {
402    let b = masked_log_probs.len();
403
404    // Process each query in the batch independently and in parallel.
405    let results: Vec<_> = (0..b)
406        .into_par_iter()
407        .map(|bi| {
408            let lp = &masked_log_probs[bi]; // [M][|V|]
409            let par = &parent_scores[bi]; // [M]
410            let vocab = lp[0].len();
411            let m = lp.len();
412
413            // Enumerate all (beam, token) candidates and score them.
414            let mut candidates: Vec<(f64, usize, u32)> = // (score, src_beam, token)
415                (0..m)
416                    .flat_map(|mi| {
417                        (0..vocab).filter_map(move |v| {
418                            let lp_val = lp[mi][v];
419                            if lp_val.is_finite() {
420                                Some((par[mi] + lp_val, mi, v as u32))
421                            } else {
422                                None
423                            }
424                        })
425                    })
426                    .collect();
427
428            // Partial-sort: keep top `beam_width` by descending score.
429            candidates.sort_unstable_by(|a, b| b.0.partial_cmp(&a.0).unwrap());
430            candidates.truncate(beam_width);
431
432            // Separate into parallel vecs.
433            let new_scores: Vec<f64> = candidates.iter().map(|c| c.0).collect();
434            let src_beams: Vec<usize> = candidates.iter().map(|c| c.1).collect();
435            let new_tokens: Vec<u32> = candidates.iter().map(|c| c.2).collect();
436
437            (new_tokens, new_scores, src_beams)
438        })
439        .collect();
440
441    let new_tokens: Vec<Vec<u32>> = results.iter().map(|r| r.0.clone()).collect();
442    let new_scores: Vec<Vec<f64>> = results.iter().map(|r| r.1.clone()).collect();
443    let src_beams: Vec<Vec<usize>> = results.iter().map(|r| r.2.clone()).collect();
444
445    (new_tokens, new_scores, src_beams)
446}
447
// ──────────────────────────────────────────────────────────────────────────────
// Public convenience: full decode from flat uniform logits (used by tests)
// ──────────────────────────────────────────────────────────────────────────────
451
452/// Runs the full decode loop using a *static* flat logit vector (same logits
453/// repeated for every batch entry, beam, and step).  Useful for unit tests
454/// where the model is not available.
455pub fn constrained_beam_decode(
456    index: &StaticIndex,
457    flat_logits: &[f32], // length = vocab_size
458    sid_length: usize,
459    beam_width: usize,
460) -> Vec<Vec<u32>> {
461    let vocab = index.sparse.vocab_size as usize;
462    let logits_f64: Vec<f64> = flat_logits.iter().map(|&v| v as f64).collect();
463    // Shape: [1][beam_width][vocab_size]
464    let logits_3d = vec![vec![logits_f64; beam_width]];
465
466    let decoder = ConstrainedDecoder::new(index.clone(), beam_width, 1);
467    let sequences = decoder.decode(|_state, _step| logits_3d.clone(), sid_length);
468
469    sequences.into_iter().next().unwrap_or_default()
470}