constraint_decoding_trie/dense_mask.rs
1// src/dense_mask.rs
2
3use crate::types::DenseMask;
4use rayon::prelude::*;
5
6// ──────────────────────────────────────────────────────────────────────────────
7// Construction
8// ──────────────────────────────────────────────────────────────────────────────
9
10impl DenseMask {
11 // ------------------------------------------------------------------
12 // Bulk construction from a constraint set
13 // ------------------------------------------------------------------
14
15 /// Build a `DenseMask` directly from a full constraint set.
16 ///
17 /// This is the canonical constructor used by `build_static_index`.
18 /// It is equivalent to calling `DenseMask::new` followed by repeated
19 /// `insert` calls, but avoids recomputing `flat_index` twice per entry.
20 ///
21 /// # Arguments
22 /// - `constraints` — every sequence must have length ≥ `depth`
23 /// - `vocab_size` — |V|
24 /// - `depth` — number of dense layers d (typically 2)
25 /// - `node_ids` — parallel slice: `node_ids[i]` is the trie node reached
26 /// after the first `depth` tokens of `constraints[i]`.
27 /// Pass an all-zeros slice when node IDs are not yet known
28 /// and will be back-filled by `transition.rs`.
29 pub fn from_constraints(
30 constraints: &[Vec<u32>],
31 vocab_size: u32,
32 depth: u32,
33 node_ids: &[u32],
34 ) -> Self {
35 debug_assert_eq!(
36 constraints.len(),
37 node_ids.len(),
38 "constraints and node_ids must have equal length"
39 );
40
41 let mut mask = DenseMask::new(vocab_size, depth);
42
43 for (seq, &nid) in constraints.iter().zip(node_ids.iter()) {
44 debug_assert!(
45 seq.len() >= depth as usize,
46 "sequence too short for dense depth {depth}"
47 );
48 mask.insert(&seq[..depth as usize], nid);
49 }
50
51 mask
52 }
53
54 // ------------------------------------------------------------------
55 // Prefix validity — O(word) scan over packed bits
56 // ------------------------------------------------------------------
57
58 /// Returns `true` if **any** full-depth prefix starts with `first_token`.
59 ///
60 /// Operates on packed `u64` words without deserialising individual bits.
61 /// Used by `decoder.rs` at step 0 to expose the valid first-token set.
62 ///
63 /// # Complexity
64 /// O(|V|^(depth-1) / 64) ≈ O(1) for small depth and typical |V|.
65 pub fn first_token_valid(&self, first_token: u32) -> bool {
66 let (base, end) = self.token_block_range(first_token, 0);
67 self.any_bit_set_in(base, end)
68 }
69
70 /// Returns `true` if `partial` (length < `depth`) can be extended to a
71 /// valid full-depth prefix in the constraint set.
72 ///
73 /// # Panics (debug)
74 /// Panics if `partial.len() >= depth`.
75 pub fn partial_prefix_has_extension(&self, partial: &[u32]) -> bool {
76 debug_assert!(
77 partial.len() < self.depth as usize,
78 "partial prefix length {} must be < depth {}",
79 partial.len(),
80 self.depth
81 );
82 let flat_base: usize = partial.iter().fold(0usize, |acc, &t| {
83 acc * self.vocab_size as usize + t as usize
84 });
85 let stride = (self.vocab_size as usize).pow((self.depth as usize - partial.len()) as u32);
86 let base = flat_base * stride;
87 let end = base + stride;
88 self.any_bit_set_in(base, end)
89 }
90
91 // ------------------------------------------------------------------
92 // Bit-parallel intersection
93 // ------------------------------------------------------------------
94
95 /// Returns a new `DenseMask` that is the intersection of `self` and `other`.
96 ///
97 /// Two masks can be intersected to find the set of prefixes that appear in
98 /// **both** constraint sets — useful for multi-constraint filtering.
99 ///
100 /// # Panics
101 /// Panics if `self` and `other` have different `vocab_size` or `depth`.
102 pub fn intersect(&self, other: &DenseMask) -> DenseMask {
103 assert_eq!(
104 self.vocab_size, other.vocab_size,
105 "vocab_size mismatch in intersect"
106 );
107 assert_eq!(self.depth, other.depth, "depth mismatch in intersect");
108
109 let bits: Vec<u64> = self
110 .bits
111 .iter()
112 .zip(other.bits.iter())
113 .map(|(&a, &b)| a & b)
114 .collect();
115
116 // Zero out states entries whose bit was cleared by the intersection.
117 let total = (self.vocab_size as usize).pow(self.depth);
118 let mut states = vec![0u32; total];
119 for idx in 0..total {
120 if (bits[idx / 64] >> (idx % 64)) & 1 == 1 {
121 states[idx] = self.states[idx];
122 }
123 }
124
125 DenseMask {
126 bits,
127 states,
128 depth: self.depth,
129 vocab_size: self.vocab_size,
130 }
131 }
132
133 /// Returns a new `DenseMask` that is the union of `self` and `other`.
134 ///
135 /// Used when merging two separately-built index shards.
136 /// Where both masks have a valid entry, `self`'s node ID takes precedence.
137 ///
138 /// # Panics
139 /// Panics if `vocab_size` or `depth` differ.
140 pub fn union(&self, other: &DenseMask) -> DenseMask {
141 assert_eq!(self.vocab_size, other.vocab_size);
142 assert_eq!(self.depth, other.depth);
143
144 let bits: Vec<u64> = self
145 .bits
146 .iter()
147 .zip(other.bits.iter())
148 .map(|(&a, &b)| a | b)
149 .collect();
150
151 let total = (self.vocab_size as usize).pow(self.depth);
152 let mut states = other.states.clone(); // start with other's node IDs
153 for idx in 0..total {
154 // Self takes priority where self has the bit set.
155 if (self.bits[idx / 64] >> (idx % 64)) & 1 == 1 {
156 states[idx] = self.states[idx];
157 }
158 }
159
160 DenseMask {
161 bits,
162 states,
163 depth: self.depth,
164 vocab_size: self.vocab_size,
165 }
166 }
167
168 // ------------------------------------------------------------------
169 // Packed-bit mask extraction (for logit gating)
170 // ------------------------------------------------------------------
171
172 /// Returns the first-token marginal as a packed `Vec<u64>` of length
173 /// `ceil(vocab_size / 64)`.
174 ///
175 /// Bit `t` is set iff token `t` is a valid first token in the constraint
176 /// set. This vec can be ANDed directly with the model's top-k bitmask.
177 pub fn first_token_packed_mask(&self) -> Vec<u64> {
178 let v = self.vocab_size as usize;
179 let words = v.div_ceil(64);
180 let mut out = vec![0u64; words];
181 for tok in 0..v as u32 {
182 if self.first_token_valid(tok) {
183 let idx = tok as usize;
184 out[idx / 64] |= 1u64 << (idx % 64);
185 }
186 }
187 out
188 }
189
190 /// Returns the second-token marginal **given** that `first_token` was chosen,
191 /// packed as a `Vec<u64>` of length `ceil(vocab_size / 64)`.
192 ///
193 /// Only defined for `depth >= 2`.
194 ///
195 /// # Panics (debug)
196 /// Panics if `depth < 2`.
197 pub fn second_token_packed_mask(&self, first_token: u32) -> Vec<u64> {
198 debug_assert!(
199 self.depth >= 2,
200 "second_token_packed_mask requires depth >= 2"
201 );
202 let v = self.vocab_size as usize;
203 let words = v.div_ceil(64);
204 let mut out = vec![0u64; words];
205 for tok2 in 0..v as u32 {
206 if self.get(first_token, tok2) {
207 let idx = tok2 as usize;
208 out[idx / 64] |= 1u64 << (idx % 64);
209 }
210 }
211 out
212 }
213
214 // ------------------------------------------------------------------
215 // Count helpers
216 // ------------------------------------------------------------------
217
218 /// Returns the total number of valid prefixes stored in the mask.
219 pub fn count_valid(&self) -> u64 {
220 self.bits.iter().map(|w| w.count_ones() as u64).sum()
221 }
222
223 /// Returns the number of distinct valid first tokens.
224 pub fn count_valid_first_tokens(&self) -> u32 {
225 (0..self.vocab_size)
226 .filter(|&t| self.first_token_valid(t))
227 .count() as u32
228 }
229
230 // ------------------------------------------------------------------
231 // Serialisation helpers (used by persistence tests)
232 // ------------------------------------------------------------------
233
234 /// Serialises the mask into a flat byte buffer.
235 ///
236 /// Layout (little-endian):
237 /// ```text
238 /// [u32 vocab_size][u32 depth]
239 /// [u32 bits_len][u64 * bits_len]
240 /// [u32 states_len][u32 * states_len]
241 /// ```
242 pub fn to_bytes(&self) -> Vec<u8> {
243 let mut out = Vec::new();
244 out.extend_from_slice(&self.vocab_size.to_le_bytes());
245 out.extend_from_slice(&self.depth.to_le_bytes());
246 out.extend_from_slice(&(self.bits.len() as u32).to_le_bytes());
247 for &w in &self.bits {
248 out.extend_from_slice(&w.to_le_bytes());
249 }
250 out.extend_from_slice(&(self.states.len() as u32).to_le_bytes());
251 for &s in &self.states {
252 out.extend_from_slice(&s.to_le_bytes());
253 }
254 out
255 }
256
257 /// Deserialises a `DenseMask` from the byte layout produced by `to_bytes`.
258 ///
259 /// Returns `None` if the buffer is malformed.
260 pub fn from_bytes(buf: &[u8]) -> Option<Self> {
261 let mut cur = 0usize;
262
263 let read_u32 = |buf: &[u8], pos: &mut usize| -> Option<u32> {
264 let bytes = buf.get(*pos..*pos + 4)?;
265 *pos += 4;
266 Some(u32::from_le_bytes(bytes.try_into().ok()?))
267 };
268 let read_u64 = |buf: &[u8], pos: &mut usize| -> Option<u64> {
269 let bytes = buf.get(*pos..*pos + 8)?;
270 *pos += 8;
271 Some(u64::from_le_bytes(bytes.try_into().ok()?))
272 };
273
274 let vocab_size = read_u32(buf, &mut cur)?;
275 let depth = read_u32(buf, &mut cur)?;
276 let bits_len = read_u32(buf, &mut cur)? as usize;
277
278 let mut bits = Vec::with_capacity(bits_len);
279 for _ in 0..bits_len {
280 bits.push(read_u64(buf, &mut cur)?);
281 }
282
283 let states_len = read_u32(buf, &mut cur)? as usize;
284 let mut states = Vec::with_capacity(states_len);
285 for _ in 0..states_len {
286 states.push(read_u32(buf, &mut cur)?);
287 }
288
289 Some(DenseMask {
290 bits,
291 states,
292 depth,
293 vocab_size,
294 })
295 }
296
297 // ------------------------------------------------------------------
298 // Internal helpers
299 // ------------------------------------------------------------------
300
301 /// Returns the `[base, end)` range of flat indices covered by the block
302 /// rooted at `token` appearing at position `pos` in the prefix.
303 fn token_block_range(&self, token: u32, pos: usize) -> (usize, usize) {
304 let v = self.vocab_size as usize;
305 let d = self.depth as usize;
306 let stride = v.pow((d - pos - 1) as u32);
307 let base = token as usize * stride;
308 (base, base + stride)
309 }
310
311 /// Returns `true` if any bit in flat-index range `[base, end)` is set.
312 // In dense_mask.rs
313 /// Returns `true` if any bit in flat-index range `[base, end)` is set.
314 ///
315 /// This implementation correctly handles ranges that span multiple 64-bit
316 /// words as well as ranges contained within a single word.
317 #[inline]
318 fn any_bit_set_in(&self, base: usize, end: usize) -> bool {
319 if base >= end {
320 return false;
321 }
322
323 let w_start = base / 64;
324 let w_end = (end - 1) / 64; // index of the last word touched
325
326 // Safety bounds check
327 if w_start >= self.bits.len() {
328 return false;
329 }
330 let actual_w_end = w_end.min(self.bits.len() - 1);
331
332 for w_idx in w_start..=actual_w_end {
333 let mut val = self.bits[w_idx];
334
335 // 1. Mask out bits BEFORE the range in the start word
336 if w_idx == w_start {
337 let shift = base % 64;
338 val &= !0u64 << shift;
339 }
340
341 // 2. Mask out bits AFTER the range in the end word
342 // This is applied independently so it works even if w_start == w_end.
343 if w_idx == w_end {
344 let limit = end % 64;
345 if limit != 0 {
346 // mask has bits 0..limit-1 set
347 let mask = (1u64 << limit) - 1;
348 val &= mask;
349 }
350 }
351
352 if val != 0 {
353 return true;
354 }
355 }
356 false
357 }
358}
359
360// ──────────────────────────────────────────────────────────────────────────────
361// Parallel bulk validation (used by transition.rs integration tests)
362// ──────────────────────────────────────────────────────────────────────────────
363
364/// Validates a batch of token prefixes against the mask in parallel.
365///
366/// Returns a `Vec<bool>` of length `prefixes.len()` where `true` means the
367/// prefix is present in the constraint set.
368pub fn validate_prefixes(mask: &DenseMask, prefixes: &[Vec<u32>]) -> Vec<bool> {
369 prefixes.par_iter().map(|p| mask.contains(p)).collect()
370}
371
372/// Converts a `DenseMask` into a flat `Vec<u64>` token-level marginal mask
373/// for the given prefix position and preceding token sequence.
374///
375/// Returns a packed bitmask of length `ceil(vocab_size / 64)` whose bit `t`
376/// is set iff appending token `t` to `prefix_so_far` yields a valid (partial
377/// or complete) prefix in the mask.
378pub fn marginal_mask_at(mask: &DenseMask, prefix_so_far: &[u32]) -> Vec<u64> {
379 let len = prefix_so_far.len();
380 let v = mask.vocab_size as usize;
381 let depth = mask.depth as usize;
382 let words = v.div_ceil(64);
383
384 assert!(
385 len < depth,
386 "prefix_so_far length {len} must be < depth {depth}"
387 );
388
389 let mut out = vec![0u64; words];
390 for tok in 0..v as u32 {
391 let mut candidate = prefix_so_far.to_vec();
392 candidate.push(tok);
393 let valid = if candidate.len() == depth {
394 mask.contains(&candidate)
395 } else {
396 mask.partial_prefix_has_extension(&candidate)
397 };
398 if valid {
399 out[tok as usize / 64] |= 1u64 << (tok as usize % 64);
400 }
401 }
402 out
403}