Skip to main content

constraint_decoding_trie/
vntk.rs

1// src/vntk.rs
2
3use rayon::prelude::*;
4use std::sync::atomic::{AtomicU32, Ordering};
5
6use crate::types::{TransitionMatrix, VntkOutput};
7
8// ──────────────────────────────────────────────────────────────────────────────
9// Public result type
10// ──────────────────────────────────────────────────────────────────────────────
11
/// Output of a single VNTK call covering all beams at one decoding step.
///
/// Every buffer is flat and row-major with one row per beam:
/// - `tokens[i * branch_size + j]`       — j-th token candidate for beam `i`
/// - `next_nodes[i * branch_size + j]`   — trie node reached by that token
/// - `valid[i * branch_size + j]`        — whether slot `j` is a real child
/// - `dense_masks[i * vocab_size + tok]` — O(1) membership test for beam `i`
#[derive(Debug, Clone)]
pub struct VntkResult {
    /// Token IDs: shape [n × branch_size], invalid slots hold 0.
    pub tokens: Vec<u32>,
    /// Next-node IDs: shape [n × branch_size], invalid slots hold 0.
    pub next_nodes: Vec<u32>,
    /// Validity flags: shape [n × branch_size].
    pub valid: Vec<bool>,
    /// Dense boolean mask: shape [n × vocab_size].
    pub dense_masks: Vec<bool>,
    /// B_t: the padded branch-factor used at this level.
    pub branch_size: usize,
}

impl VntkResult {
    /// Returns the valid `(token, next_node)` pairs for beam `i`, in slot order.
    ///
    /// The iterator is lazy: slots are read as it is advanced.
    ///
    /// # Panics
    /// Advancing the iterator panics if beam `i` lies outside the buffers.
    #[inline]
    pub fn children_for(&self, i: usize) -> impl Iterator<Item = (u32, u32)> + '_ {
        let base = i * self.branch_size;
        (0..self.branch_size).filter_map(move |j| {
            let slot = base + j;
            self.valid[slot]
                .then(|| (self.tokens[slot], self.next_nodes[slot]))
        })
    }

    /// Returns the dense mask slice for beam `i` (length = `vocab_size`).
    ///
    /// # Panics
    /// Panics if the requested row extends past the end of `dense_masks`.
    #[inline]
    pub fn mask_for(&self, i: usize, vocab_size: usize) -> &[bool] {
        let base = i * vocab_size;
        &self.dense_masks[base..base + vocab_size]
    }

    /// Collapses all per-beam dense masks into a single OR-reduced mask of
    /// length `vocab_size`.  Used when all beams in a batch share one logit
    /// vector (single-query inference).
    ///
    /// Only complete rows of `vocab_size` flags are folded in; a trailing
    /// partial row is ignored.  A `vocab_size` of 0 yields an empty mask.
    pub fn global_mask(&self, vocab_size: usize) -> Vec<bool> {
        // Guard first: both the row count and `chunks_exact` are undefined
        // (panic) for a zero-width row.
        if vocab_size == 0 {
            return Vec::new();
        }

        let mut merged = vec![false; vocab_size];
        for row in self.dense_masks.chunks_exact(vocab_size) {
            for (acc, &flag) in merged.iter_mut().zip(row) {
                *acc |= flag;
            }
        }
        merged
    }

    /// Converts the dense bool mask for beam `i` into a packed `Vec<u64>`
    /// (same layout as `DenseMask::bits`) for cheap bitwise AND with
    /// the model's top-k mask.
    ///
    /// Flag `idx` lands in word `idx / 64` at bit `idx % 64`; the result holds
    /// `ceil(vocab_size / 64)` words, with unused high bits of the last word
    /// left at 0.
    ///
    /// # Panics
    /// Panics under the same conditions as [`VntkResult::mask_for`].
    pub fn packed_mask_for(&self, i: usize, vocab_size: usize) -> Vec<u64> {
        self.mask_for(i, vocab_size)
            .chunks(64)
            .map(|word_flags| {
                word_flags
                    .iter()
                    .enumerate()
                    .fold(0u64, |word, (bit, &set)| word | (u64::from(set) << bit))
            })
            .collect()
    }
}
87
88// ──────────────────────────────────────────────────────────────────────────────
89// VNTK implementation
90// ──────────────────────────────────────────────────────────────────────────────
91
92impl TransitionMatrix {
93    /// **Vectorized Node Transition Kernel** — Algorithm 2 from the paper.
94    ///
95    /// For each of the `n = batch_size × beam_width` active beams, reads the
96    /// CSR row for that beam's current trie node and writes up to `B_t`
97    /// (token, next-node) pairs into pre-allocated output buffers.
98    ///
99    /// # Layout
100    /// All output buffers are flat and strided by `branch_size` (= B_t).
101    ///
102    /// # Parallelism
103    /// The per-beam inner loop is embarrassingly parallel and runs via Rayon.
104    /// Writes to disjoint buffer slices avoid any synchronisation overhead.
105    ///
106    /// # Arguments
107    /// - `current_nodes` — flat slice of length `n`, one node ID per beam
108    /// - `level`         — current decoding step (0-indexed); selects `B_t`
109    ///
110    /// # Panics
111    /// Panics if `level >= sid_length` or any node ID ≥ `num_nodes`.
112    pub fn vntk(&self, current_nodes: &[u32], level: usize) -> VntkResult {
113        assert!(
114            level < self.sid_length as usize,
115            "level {level} out of range (sid_length={})",
116            self.sid_length
117        );
118
119        let b_t = self.max_branches[level] as usize;
120        let n = current_nodes.len();
121        let v = self.vocab_size as usize;
122
123        // Allocate output buffers up-front; rayon writes into disjoint slices.
124        let mut tokens: Vec<u32> = vec![0u32; n * b_t];
125        let mut next_nodes: Vec<u32> = vec![0u32; n * b_t];
126        let mut valid: Vec<bool> = vec![false; n * b_t];
127        let mut dense_masks: Vec<bool> = vec![false; n * v];
128
129        // Split each output buffer into n contiguous chunks, one per beam,
130        // then zip them together so each rayon task owns exactly its slice.
131        let tok_chunks: Vec<&mut [u32]> = tokens.chunks_mut(b_t).collect();
132        let next_chunks: Vec<&mut [u32]> = next_nodes.chunks_mut(b_t).collect();
133        let valid_chunks: Vec<&mut [bool]> = valid.chunks_mut(b_t).collect();
134        let mask_chunks: Vec<&mut [bool]> = dense_masks.chunks_mut(v).collect();
135
136        // Bundle into a single Vec of mutable tuple-slices for rayon.
137        tok_chunks
138            .into_par_iter()
139            .zip(next_chunks)
140            .zip(valid_chunks)
141            .zip(mask_chunks)
142            .zip(current_nodes.par_iter())
143            .for_each(|((((tok_s, next_s), valid_s), mask_s), &node)| {
144                debug_assert!(
145                    node < self.num_nodes,
146                    "node {node} ≥ num_nodes {}",
147                    self.num_nodes
148                );
149
150                // ── Phase 1: CSR boundary lookup ─────────────────────────────
151                let row_start = self.row_pointers[node as usize] as usize;
152                let row_end = self.row_pointers[node as usize + 1] as usize;
153                let n_child = row_end - row_start;
154
155                // ── Phase 2: Speculative copy into padded B_t slots ──────────
156                // Slots beyond n_child remain zeroed (implicit padding).
157                let fill = n_child.min(b_t);
158                for j in 0..fill {
159                    let entry = self.data[row_start + j];
160                    tok_s[j] = entry[0];
161                    next_s[j] = entry[1];
162                    valid_s[j] = true;
163                }
164
165                // ── Phase 3: Scatter into dense vocab mask ───────────────────
166                // Only `fill` entries are real; token IDs are already sorted.
167                for j in 0..fill {
168                    mask_s[self.data[row_start + j][0] as usize] = true;
169                }
170            });
171
172        VntkResult {
173            tokens,
174            next_nodes,
175            valid,
176            dense_masks,
177            branch_size: b_t,
178        }
179    }
180
181    /// Thin wrapper that converts a `VntkResult` into the simpler `VntkOutput`
182    /// expected by the test module (`next_nodes` flat vec + single bool mask).
183    ///
184    /// Only meaningful when `current_nodes` contains a single beam; for
185    /// multi-beam callers use `VntkResult` directly.
186    pub fn vntk_single(&self, node: u32, level: usize) -> VntkOutput {
187        let result = self.vntk(&[node], level);
188        VntkOutput {
189            next_nodes: result.children_for(0).map(|(_, n)| n).collect(),
190            mask: result.dense_masks[..self.vocab_size as usize].to_vec(),
191        }
192    }
193}
194
195// ──────────────────────────────────────────────────────────────────────────────
196// Standalone function form (matches the test module's call convention)
197// ──────────────────────────────────────────────────────────────────────────────
198
199/// Calls `TransitionMatrix::vntk` and returns a `VntkOutput` shaped for the
200/// test module:
201/// - `next_nodes`: flat list of valid next-node IDs across all beams
202/// - `mask`:       OR-reduced dense bool mask of length `vocab_size`
203pub fn vntk(
204    current_nodes: &[u32],
205    matrix: &TransitionMatrix,
206    level: usize,
207    vocab_size: usize,
208) -> VntkOutput {
209    debug_assert_eq!(
210        vocab_size, matrix.vocab_size as usize,
211        "vocab_size mismatch"
212    );
213    let result = matrix.vntk(current_nodes, level);
214
215    // Collect all valid next-node IDs in beam × child order.
216    let next_nodes: Vec<u32> = (0..current_nodes.len())
217        .flat_map(|i| result.children_for(i).map(|(_, n)| n))
218        .collect();
219
220    let mask = result.global_mask(vocab_size);
221
222    VntkOutput { next_nodes, mask }
223}