// constraint_decoding_trie/vntk.rs
1// src/vntk.rs
2
3use rayon::prelude::*;
4use std::sync::atomic::{AtomicU32, Ordering};
5
6use crate::types::{TransitionMatrix, VntkOutput};
7
8// ──────────────────────────────────────────────────────────────────────────────
9// Public result type
10// ──────────────────────────────────────────────────────────────────────────────
11
/// Everything the VNTK produces for one decoding step, covering every beam.
///
/// All buffers are flat and row-major:
/// - slot `i * branch_size + j` of `tokens` / `next_nodes` / `valid`
///   describes the j-th child candidate of beam i;
/// - slot `i * vocab_size + tok` of `dense_masks` answers "is `tok` a
///   legal continuation of beam i?" in O(1).
#[derive(Debug, Clone)]
pub struct VntkResult {
    /// Candidate token IDs, shape [n × branch_size]; padding slots hold 0.
    pub tokens: Vec<u32>,
    /// Trie node reached by each candidate, shape [n × branch_size]; padding holds 0.
    pub next_nodes: Vec<u32>,
    /// Marks which slots hold a real child, shape [n × branch_size].
    pub valid: Vec<bool>,
    /// Per-beam dense vocabulary mask, shape [n × vocab_size].
    pub dense_masks: Vec<bool>,
    /// B_t: the padded branch factor used at this level.
    pub branch_size: usize,
}

impl VntkResult {
    /// Iterates over beam `i`'s real children as `(token, next_node)` pairs.
    #[inline]
    pub fn children_for(&self, i: usize) -> impl Iterator<Item = (u32, u32)> + '_ {
        let base = i * self.branch_size;
        (0..self.branch_size)
            .filter(move |&j| self.valid[base + j])
            .map(move |j| (self.tokens[base + j], self.next_nodes[base + j]))
    }

    /// Returns the dense mask slice for beam `i` (length = vocab_size).
    #[inline]
    pub fn mask_for(&self, i: usize, vocab_size: usize) -> &[bool] {
        let start = i * vocab_size;
        &self.dense_masks[start..start + vocab_size]
    }

    /// OR-reduces every per-beam dense mask into a single mask of length
    /// `vocab_size`. Used when all beams in a batch share one logit vector
    /// (single-query inference).
    pub fn global_mask(&self, vocab_size: usize) -> Vec<bool> {
        let mut merged = vec![false; vocab_size];
        // `chunks_exact` visits exactly len / vocab_size full beams,
        // ignoring any trailing partial chunk.
        for beam_mask in self.dense_masks.chunks_exact(vocab_size) {
            for (dst, &bit) in merged.iter_mut().zip(beam_mask) {
                *dst |= bit;
            }
        }
        merged
    }

    /// Packs beam `i`'s dense bool mask into `u64` words (same layout as
    /// `DenseMask::bits`) so it can be AND-ed cheaply with the model's
    /// top-k mask.
    pub fn packed_mask_for(&self, i: usize, vocab_size: usize) -> Vec<u64> {
        let mut packed = vec![0u64; vocab_size.div_ceil(64)];
        for (tok, &set) in self.mask_for(i, vocab_size).iter().enumerate() {
            if set {
                // tok >> 6 selects the word, tok & 63 the bit within it.
                packed[tok >> 6] |= 1u64 << (tok & 63);
            }
        }
        packed
    }
}
87
88// ──────────────────────────────────────────────────────────────────────────────
89// VNTK implementation
90// ──────────────────────────────────────────────────────────────────────────────
91
92impl TransitionMatrix {
93 /// **Vectorized Node Transition Kernel** — Algorithm 2 from the paper.
94 ///
95 /// For each of the `n = batch_size × beam_width` active beams, reads the
96 /// CSR row for that beam's current trie node and writes up to `B_t`
97 /// (token, next-node) pairs into pre-allocated output buffers.
98 ///
99 /// # Layout
100 /// All output buffers are flat and strided by `branch_size` (= B_t).
101 ///
102 /// # Parallelism
103 /// The per-beam inner loop is embarrassingly parallel and runs via Rayon.
104 /// Writes to disjoint buffer slices avoid any synchronisation overhead.
105 ///
106 /// # Arguments
107 /// - `current_nodes` — flat slice of length `n`, one node ID per beam
108 /// - `level` — current decoding step (0-indexed); selects `B_t`
109 ///
110 /// # Panics
111 /// Panics if `level >= sid_length` or any node ID ≥ `num_nodes`.
112 pub fn vntk(&self, current_nodes: &[u32], level: usize) -> VntkResult {
113 assert!(
114 level < self.sid_length as usize,
115 "level {level} out of range (sid_length={})",
116 self.sid_length
117 );
118
119 let b_t = self.max_branches[level] as usize;
120 let n = current_nodes.len();
121 let v = self.vocab_size as usize;
122
123 // Allocate output buffers up-front; rayon writes into disjoint slices.
124 let mut tokens: Vec<u32> = vec![0u32; n * b_t];
125 let mut next_nodes: Vec<u32> = vec![0u32; n * b_t];
126 let mut valid: Vec<bool> = vec![false; n * b_t];
127 let mut dense_masks: Vec<bool> = vec![false; n * v];
128
129 // Split each output buffer into n contiguous chunks, one per beam,
130 // then zip them together so each rayon task owns exactly its slice.
131 let tok_chunks: Vec<&mut [u32]> = tokens.chunks_mut(b_t).collect();
132 let next_chunks: Vec<&mut [u32]> = next_nodes.chunks_mut(b_t).collect();
133 let valid_chunks: Vec<&mut [bool]> = valid.chunks_mut(b_t).collect();
134 let mask_chunks: Vec<&mut [bool]> = dense_masks.chunks_mut(v).collect();
135
136 // Bundle into a single Vec of mutable tuple-slices for rayon.
137 tok_chunks
138 .into_par_iter()
139 .zip(next_chunks)
140 .zip(valid_chunks)
141 .zip(mask_chunks)
142 .zip(current_nodes.par_iter())
143 .for_each(|((((tok_s, next_s), valid_s), mask_s), &node)| {
144 debug_assert!(
145 node < self.num_nodes,
146 "node {node} ≥ num_nodes {}",
147 self.num_nodes
148 );
149
150 // ── Phase 1: CSR boundary lookup ─────────────────────────────
151 let row_start = self.row_pointers[node as usize] as usize;
152 let row_end = self.row_pointers[node as usize + 1] as usize;
153 let n_child = row_end - row_start;
154
155 // ── Phase 2: Speculative copy into padded B_t slots ──────────
156 // Slots beyond n_child remain zeroed (implicit padding).
157 let fill = n_child.min(b_t);
158 for j in 0..fill {
159 let entry = self.data[row_start + j];
160 tok_s[j] = entry[0];
161 next_s[j] = entry[1];
162 valid_s[j] = true;
163 }
164
165 // ── Phase 3: Scatter into dense vocab mask ───────────────────
166 // Only `fill` entries are real; token IDs are already sorted.
167 for j in 0..fill {
168 mask_s[self.data[row_start + j][0] as usize] = true;
169 }
170 });
171
172 VntkResult {
173 tokens,
174 next_nodes,
175 valid,
176 dense_masks,
177 branch_size: b_t,
178 }
179 }
180
181 /// Thin wrapper that converts a `VntkResult` into the simpler `VntkOutput`
182 /// expected by the test module (`next_nodes` flat vec + single bool mask).
183 ///
184 /// Only meaningful when `current_nodes` contains a single beam; for
185 /// multi-beam callers use `VntkResult` directly.
186 pub fn vntk_single(&self, node: u32, level: usize) -> VntkOutput {
187 let result = self.vntk(&[node], level);
188 VntkOutput {
189 next_nodes: result.children_for(0).map(|(_, n)| n).collect(),
190 mask: result.dense_masks[..self.vocab_size as usize].to_vec(),
191 }
192 }
193}
194
195// ──────────────────────────────────────────────────────────────────────────────
196// Standalone function form (matches the test module's call convention)
197// ──────────────────────────────────────────────────────────────────────────────
198
199/// Calls `TransitionMatrix::vntk` and returns a `VntkOutput` shaped for the
200/// test module:
201/// - `next_nodes`: flat list of valid next-node IDs across all beams
202/// - `mask`: OR-reduced dense bool mask of length `vocab_size`
203pub fn vntk(
204 current_nodes: &[u32],
205 matrix: &TransitionMatrix,
206 level: usize,
207 vocab_size: usize,
208) -> VntkOutput {
209 debug_assert_eq!(
210 vocab_size, matrix.vocab_size as usize,
211 "vocab_size mismatch"
212 );
213 let result = matrix.vntk(current_nodes, level);
214
215 // Collect all valid next-node IDs in beam × child order.
216 let next_nodes: Vec<u32> = (0..current_nodes.len())
217 .flat_map(|i| result.children_for(i).map(|(_, n)| n))
218 .collect();
219
220 let mask = result.global_mask(vocab_size);
221
222 VntkOutput { next_nodes, mask }
223}