Skip to main content

constraint_decoding_trie/
dense_mask.rs

1// src/dense_mask.rs
2
3use crate::types::DenseMask;
4use rayon::prelude::*;
5
6// ──────────────────────────────────────────────────────────────────────────────
7// Construction
8// ──────────────────────────────────────────────────────────────────────────────
9
10impl DenseMask {
11    // ------------------------------------------------------------------
12    // Bulk construction from a constraint set
13    // ------------------------------------------------------------------
14
15    /// Build a `DenseMask` directly from a full constraint set.
16    ///
17    /// This is the canonical constructor used by `build_static_index`.
18    /// It is equivalent to calling `DenseMask::new` followed by repeated
19    /// `insert` calls, but avoids recomputing `flat_index` twice per entry.
20    ///
21    /// # Arguments
22    /// - `constraints` — every sequence must have length ≥ `depth`
23    /// - `vocab_size`  — |V|
24    /// - `depth`       — number of dense layers d (typically 2)
25    /// - `node_ids`    — parallel slice: `node_ids[i]` is the trie node reached
26    ///                   after the first `depth` tokens of `constraints[i]`.
27    ///                   Pass an all-zeros slice when node IDs are not yet known
28    ///                   and will be back-filled by `transition.rs`.
29    pub fn from_constraints(
30        constraints: &[Vec<u32>],
31        vocab_size: u32,
32        depth: u32,
33        node_ids: &[u32],
34    ) -> Self {
35        debug_assert_eq!(
36            constraints.len(),
37            node_ids.len(),
38            "constraints and node_ids must have equal length"
39        );
40
41        let mut mask = DenseMask::new(vocab_size, depth);
42
43        for (seq, &nid) in constraints.iter().zip(node_ids.iter()) {
44            debug_assert!(
45                seq.len() >= depth as usize,
46                "sequence too short for dense depth {depth}"
47            );
48            mask.insert(&seq[..depth as usize], nid);
49        }
50
51        mask
52    }
53
54    // ------------------------------------------------------------------
55    // Prefix validity — O(word) scan over packed bits
56    // ------------------------------------------------------------------
57
58    /// Returns `true` if **any** full-depth prefix starts with `first_token`.
59    ///
60    /// Operates on packed `u64` words without deserialising individual bits.
61    /// Used by `decoder.rs` at step 0 to expose the valid first-token set.
62    ///
63    /// # Complexity
64    /// O(|V|^(depth-1) / 64)  ≈  O(1) for small depth and typical |V|.
65    pub fn first_token_valid(&self, first_token: u32) -> bool {
66        let (base, end) = self.token_block_range(first_token, 0);
67        self.any_bit_set_in(base, end)
68    }
69
70    /// Returns `true` if `partial` (length < `depth`) can be extended to a
71    /// valid full-depth prefix in the constraint set.
72    ///
73    /// # Panics (debug)
74    /// Panics if `partial.len() >= depth`.
75    pub fn partial_prefix_has_extension(&self, partial: &[u32]) -> bool {
76        debug_assert!(
77            partial.len() < self.depth as usize,
78            "partial prefix length {} must be < depth {}",
79            partial.len(),
80            self.depth
81        );
82        let flat_base: usize = partial.iter().fold(0usize, |acc, &t| {
83            acc * self.vocab_size as usize + t as usize
84        });
85        let stride = (self.vocab_size as usize).pow((self.depth as usize - partial.len()) as u32);
86        let base = flat_base * stride;
87        let end = base + stride;
88        self.any_bit_set_in(base, end)
89    }
90
91    // ------------------------------------------------------------------
92    // Bit-parallel intersection
93    // ------------------------------------------------------------------
94
95    /// Returns a new `DenseMask` that is the intersection of `self` and `other`.
96    ///
97    /// Two masks can be intersected to find the set of prefixes that appear in
98    /// **both** constraint sets — useful for multi-constraint filtering.
99    ///
100    /// # Panics
101    /// Panics if `self` and `other` have different `vocab_size` or `depth`.
102    pub fn intersect(&self, other: &DenseMask) -> DenseMask {
103        assert_eq!(
104            self.vocab_size, other.vocab_size,
105            "vocab_size mismatch in intersect"
106        );
107        assert_eq!(self.depth, other.depth, "depth mismatch in intersect");
108
109        let bits: Vec<u64> = self
110            .bits
111            .iter()
112            .zip(other.bits.iter())
113            .map(|(&a, &b)| a & b)
114            .collect();
115
116        // Zero out states entries whose bit was cleared by the intersection.
117        let total = (self.vocab_size as usize).pow(self.depth);
118        let mut states = vec![0u32; total];
119        for idx in 0..total {
120            if (bits[idx / 64] >> (idx % 64)) & 1 == 1 {
121                states[idx] = self.states[idx];
122            }
123        }
124
125        DenseMask {
126            bits,
127            states,
128            depth: self.depth,
129            vocab_size: self.vocab_size,
130        }
131    }
132
133    /// Returns a new `DenseMask` that is the union of `self` and `other`.
134    ///
135    /// Used when merging two separately-built index shards.
136    /// Where both masks have a valid entry, `self`'s node ID takes precedence.
137    ///
138    /// # Panics
139    /// Panics if `vocab_size` or `depth` differ.
140    pub fn union(&self, other: &DenseMask) -> DenseMask {
141        assert_eq!(self.vocab_size, other.vocab_size);
142        assert_eq!(self.depth, other.depth);
143
144        let bits: Vec<u64> = self
145            .bits
146            .iter()
147            .zip(other.bits.iter())
148            .map(|(&a, &b)| a | b)
149            .collect();
150
151        let total = (self.vocab_size as usize).pow(self.depth);
152        let mut states = other.states.clone(); // start with other's node IDs
153        for idx in 0..total {
154            // Self takes priority where self has the bit set.
155            if (self.bits[idx / 64] >> (idx % 64)) & 1 == 1 {
156                states[idx] = self.states[idx];
157            }
158        }
159
160        DenseMask {
161            bits,
162            states,
163            depth: self.depth,
164            vocab_size: self.vocab_size,
165        }
166    }
167
168    // ------------------------------------------------------------------
169    // Packed-bit mask extraction (for logit gating)
170    // ------------------------------------------------------------------
171
172    /// Returns the first-token marginal as a packed `Vec<u64>` of length
173    /// `ceil(vocab_size / 64)`.
174    ///
175    /// Bit `t` is set iff token `t` is a valid first token in the constraint
176    /// set.  This vec can be ANDed directly with the model's top-k bitmask.
177    pub fn first_token_packed_mask(&self) -> Vec<u64> {
178        let v = self.vocab_size as usize;
179        let words = v.div_ceil(64);
180        let mut out = vec![0u64; words];
181        for tok in 0..v as u32 {
182            if self.first_token_valid(tok) {
183                let idx = tok as usize;
184                out[idx / 64] |= 1u64 << (idx % 64);
185            }
186        }
187        out
188    }
189
190    /// Returns the second-token marginal **given** that `first_token` was chosen,
191    /// packed as a `Vec<u64>` of length `ceil(vocab_size / 64)`.
192    ///
193    /// Only defined for `depth >= 2`.
194    ///
195    /// # Panics (debug)
196    /// Panics if `depth < 2`.
197    pub fn second_token_packed_mask(&self, first_token: u32) -> Vec<u64> {
198        debug_assert!(
199            self.depth >= 2,
200            "second_token_packed_mask requires depth >= 2"
201        );
202        let v = self.vocab_size as usize;
203        let words = v.div_ceil(64);
204        let mut out = vec![0u64; words];
205        for tok2 in 0..v as u32 {
206            if self.get(first_token, tok2) {
207                let idx = tok2 as usize;
208                out[idx / 64] |= 1u64 << (idx % 64);
209            }
210        }
211        out
212    }
213
214    // ------------------------------------------------------------------
215    // Count helpers
216    // ------------------------------------------------------------------
217
218    /// Returns the total number of valid prefixes stored in the mask.
219    pub fn count_valid(&self) -> u64 {
220        self.bits.iter().map(|w| w.count_ones() as u64).sum()
221    }
222
223    /// Returns the number of distinct valid first tokens.
224    pub fn count_valid_first_tokens(&self) -> u32 {
225        (0..self.vocab_size)
226            .filter(|&t| self.first_token_valid(t))
227            .count() as u32
228    }
229
230    // ------------------------------------------------------------------
231    // Serialisation helpers (used by persistence tests)
232    // ------------------------------------------------------------------
233
234    /// Serialises the mask into a flat byte buffer.
235    ///
236    /// Layout (little-endian):
237    /// ```text
238    /// [u32 vocab_size][u32 depth]
239    /// [u32 bits_len][u64 * bits_len]
240    /// [u32 states_len][u32 * states_len]
241    /// ```
242    pub fn to_bytes(&self) -> Vec<u8> {
243        let mut out = Vec::new();
244        out.extend_from_slice(&self.vocab_size.to_le_bytes());
245        out.extend_from_slice(&self.depth.to_le_bytes());
246        out.extend_from_slice(&(self.bits.len() as u32).to_le_bytes());
247        for &w in &self.bits {
248            out.extend_from_slice(&w.to_le_bytes());
249        }
250        out.extend_from_slice(&(self.states.len() as u32).to_le_bytes());
251        for &s in &self.states {
252            out.extend_from_slice(&s.to_le_bytes());
253        }
254        out
255    }
256
257    /// Deserialises a `DenseMask` from the byte layout produced by `to_bytes`.
258    ///
259    /// Returns `None` if the buffer is malformed.
260    pub fn from_bytes(buf: &[u8]) -> Option<Self> {
261        let mut cur = 0usize;
262
263        let read_u32 = |buf: &[u8], pos: &mut usize| -> Option<u32> {
264            let bytes = buf.get(*pos..*pos + 4)?;
265            *pos += 4;
266            Some(u32::from_le_bytes(bytes.try_into().ok()?))
267        };
268        let read_u64 = |buf: &[u8], pos: &mut usize| -> Option<u64> {
269            let bytes = buf.get(*pos..*pos + 8)?;
270            *pos += 8;
271            Some(u64::from_le_bytes(bytes.try_into().ok()?))
272        };
273
274        let vocab_size = read_u32(buf, &mut cur)?;
275        let depth = read_u32(buf, &mut cur)?;
276        let bits_len = read_u32(buf, &mut cur)? as usize;
277
278        let mut bits = Vec::with_capacity(bits_len);
279        for _ in 0..bits_len {
280            bits.push(read_u64(buf, &mut cur)?);
281        }
282
283        let states_len = read_u32(buf, &mut cur)? as usize;
284        let mut states = Vec::with_capacity(states_len);
285        for _ in 0..states_len {
286            states.push(read_u32(buf, &mut cur)?);
287        }
288
289        Some(DenseMask {
290            bits,
291            states,
292            depth,
293            vocab_size,
294        })
295    }
296
297    // ------------------------------------------------------------------
298    // Internal helpers
299    // ------------------------------------------------------------------
300
301    /// Returns the `[base, end)` range of flat indices covered by the block
302    /// rooted at `token` appearing at position `pos` in the prefix.
303    fn token_block_range(&self, token: u32, pos: usize) -> (usize, usize) {
304        let v = self.vocab_size as usize;
305        let d = self.depth as usize;
306        let stride = v.pow((d - pos - 1) as u32);
307        let base = token as usize * stride;
308        (base, base + stride)
309    }
310
311    /// Returns `true` if any bit in flat-index range `[base, end)` is set.
312    // In dense_mask.rs
313    /// Returns `true` if any bit in flat-index range `[base, end)` is set.
314    ///
315    /// This implementation correctly handles ranges that span multiple 64-bit
316    /// words as well as ranges contained within a single word.
317    #[inline]
318    fn any_bit_set_in(&self, base: usize, end: usize) -> bool {
319        if base >= end {
320            return false;
321        }
322
323        let w_start = base / 64;
324        let w_end = (end - 1) / 64; // index of the last word touched
325
326        // Safety bounds check
327        if w_start >= self.bits.len() {
328            return false;
329        }
330        let actual_w_end = w_end.min(self.bits.len() - 1);
331
332        for w_idx in w_start..=actual_w_end {
333            let mut val = self.bits[w_idx];
334
335            // 1. Mask out bits BEFORE the range in the start word
336            if w_idx == w_start {
337                let shift = base % 64;
338                val &= !0u64 << shift;
339            }
340
341            // 2. Mask out bits AFTER the range in the end word
342            // This is applied independently so it works even if w_start == w_end.
343            if w_idx == w_end {
344                let limit = end % 64;
345                if limit != 0 {
346                    // mask has bits 0..limit-1 set
347                    let mask = (1u64 << limit) - 1;
348                    val &= mask;
349                }
350            }
351
352            if val != 0 {
353                return true;
354            }
355        }
356        false
357    }
358}
359
360// ──────────────────────────────────────────────────────────────────────────────
361// Parallel bulk validation  (used by transition.rs integration tests)
362// ──────────────────────────────────────────────────────────────────────────────
363
364/// Validates a batch of token prefixes against the mask in parallel.
365///
366/// Returns a `Vec<bool>` of length `prefixes.len()` where `true` means the
367/// prefix is present in the constraint set.
368pub fn validate_prefixes(mask: &DenseMask, prefixes: &[Vec<u32>]) -> Vec<bool> {
369    prefixes.par_iter().map(|p| mask.contains(p)).collect()
370}
371
372/// Converts a `DenseMask` into a flat `Vec<u64>` token-level marginal mask
373/// for the given prefix position and preceding token sequence.
374///
375/// Returns a packed bitmask of length `ceil(vocab_size / 64)` whose bit `t`
376/// is set iff appending token `t` to `prefix_so_far` yields a valid (partial
377/// or complete) prefix in the mask.
378pub fn marginal_mask_at(mask: &DenseMask, prefix_so_far: &[u32]) -> Vec<u64> {
379    let len = prefix_so_far.len();
380    let v = mask.vocab_size as usize;
381    let depth = mask.depth as usize;
382    let words = v.div_ceil(64);
383
384    assert!(
385        len < depth,
386        "prefix_so_far length {len} must be < depth {depth}"
387    );
388
389    let mut out = vec![0u64; words];
390    for tok in 0..v as u32 {
391        let mut candidate = prefix_so_far.to_vec();
392        candidate.push(tok);
393        let valid = if candidate.len() == depth {
394            mask.contains(&candidate)
395        } else {
396            mask.partial_prefix_has_extension(&candidate)
397        };
398        if valid {
399            out[tok as usize / 64] |= 1u64 << (tok as usize % 64);
400        }
401    }
402    out
403}