Skip to main content

oxibonsai_runtime/constrained_decoding/
allow_list.rs

1//! [`AllowListConstraint`] — force output to be one of a finite set of sequences.
2
3use super::error_trait::TokenConstraint;
4
5// ─────────────────────────────────────────────────────────────────────────────
6// AllowListConstraint — force output to be one of a finite set of sequences
7// ─────────────────────────────────────────────────────────────────────────────
8
9/// A constraint that forces the generated token sequence to exactly match one
10/// of a finite set of allowed token-id sequences (e.g., multiple-choice answers).
11///
12/// At each step only the union of next tokens across all still-active candidates
13/// is permitted.  A candidate becomes inactive the moment any token in the prefix
14/// fails to match.  Generation is considered complete once the full token sequence
15/// of at least one candidate has been consumed.
16///
17/// # Example
18/// ```rust
19/// use oxibonsai_runtime::constrained_decoding::{AllowListConstraint, TokenConstraint};
20///
21/// // Two candidates: [10, 20] and [10, 30]
22/// let mut c = AllowListConstraint::new(vec![vec![10, 20], vec![10, 30]]);
23/// // First token: only 10 is allowed (shared prefix)
24/// let mask = c.allowed_tokens(&[], 50).unwrap();
25/// assert!(mask[10]);
26/// assert!(!mask[20]);
27/// assert!(!mask[30]);
28/// ```
29pub struct AllowListConstraint {
30    /// All allowed sequences.
31    candidates: Vec<Vec<u32>>,
32    /// Which candidates still match the current prefix.
33    active: Vec<bool>,
34    /// Number of tokens consumed so far.
35    position: usize,
36}
37
38impl AllowListConstraint {
39    /// Create a new `AllowListConstraint` from a list of allowed token sequences.
40    pub fn new(candidates: Vec<Vec<u32>>) -> Self {
41        let n = candidates.len();
42        Self {
43            candidates,
44            active: vec![true; n],
45            position: 0,
46        }
47    }
48
49    /// Returns the number of candidate sequences that are still active.
50    pub fn active_count(&self) -> usize {
51        self.active.iter().filter(|&&a| a).count()
52    }
53}
54
55impl TokenConstraint for AllowListConstraint {
56    /// Returns a bitmask of tokens that are valid next tokens across all still-active
57    /// candidates at the current position.  Always returns `Some` (never unconstrained).
58    fn allowed_tokens(&self, _generated: &[u32], vocab_size: usize) -> Option<Vec<bool>> {
59        let mut mask = vec![false; vocab_size];
60        for (i, active) in self.active.iter().enumerate() {
61            if !active {
62                continue;
63            }
64            let seq = &self.candidates[i];
65            if self.position < seq.len() {
66                let tok = seq[self.position] as usize;
67                if tok < vocab_size {
68                    mask[tok] = true;
69                }
70            }
71        }
72        Some(mask)
73    }
74
75    /// Commits `token` at the current position.
76    ///
77    /// Any candidate where `candidates[i][position] != token` (or the candidate is
78    /// already exhausted) is deactivated.  Returns `true` when at least one candidate
79    /// remains active **or** a candidate was just completed at this position.
80    fn advance(&mut self, token: u32) -> bool {
81        let mut just_completed = false;
82        for (i, active) in self.active.iter_mut().enumerate() {
83            if !*active {
84                continue;
85            }
86            let seq = &self.candidates[i];
87            if self.position >= seq.len() {
88                // Candidate was already completed; further tokens deactivate it.
89                *active = false;
90            } else if seq[self.position] == token {
91                // Token matches; check if this completes the candidate.
92                if self.position + 1 == seq.len() {
93                    just_completed = true;
94                }
95                // Keep active — will be filtered by is_complete / future advance calls.
96            } else {
97                *active = false;
98            }
99        }
100        self.position += 1;
101        // Return true if at least one candidate is still active or one just completed.
102        just_completed || self.active.iter().any(|&a| a)
103    }
104
105    /// Returns `true` when the consumed token sequence fully matches at least one
106    /// candidate, i.e. `position == candidates[i].len()` for some active `i`.
107    fn is_complete(&self) -> bool {
108        self.candidates
109            .iter()
110            .enumerate()
111            .any(|(i, seq)| self.active[i] && self.position == seq.len())
112    }
113
114    /// Reset to initial state (all candidates active, position zero).
115    fn reset(&mut self) {
116        self.active.fill(true);
117        self.position = 0;
118    }
119
120    fn name(&self) -> &str {
121        "AllowListConstraint"
122    }
123}