oxibonsai_runtime/constrained_decoding/allow_list.rs
1//! [`AllowListConstraint`] — force output to be one of a finite set of sequences.
2
3use super::error_trait::TokenConstraint;
4
5// ─────────────────────────────────────────────────────────────────────────────
6// AllowListConstraint — force output to be one of a finite set of sequences
7// ─────────────────────────────────────────────────────────────────────────────
8
/// A constraint that forces the generated token sequence to exactly match one
/// of a finite set of allowed token-id sequences (e.g., multiple-choice answers).
///
/// At each step, only the union of next tokens across all still-active candidates
/// is permitted. A candidate becomes inactive as soon as any committed token fails
/// to match it, or when a token is committed after the candidate was already fully
/// consumed. Generation is considered complete once the full token sequence of at
/// least one still-active candidate has been consumed.
///
/// # Example
/// ```rust
/// use oxibonsai_runtime::constrained_decoding::{AllowListConstraint, TokenConstraint};
///
/// // Two candidates: [10, 20] and [10, 30]
/// let mut c = AllowListConstraint::new(vec![vec![10, 20], vec![10, 30]]);
///
/// // First token: only 10 is allowed (shared prefix)
/// let mask = c.allowed_tokens(&[], 50).unwrap();
/// assert!(mask[10]);
/// assert!(!mask[20]);
/// assert!(!mask[30]);
///
/// // After committing 10, both 20 and 30 are legal continuations.
/// assert!(c.advance(10));
/// let mask = c.allowed_tokens(&[10], 50).unwrap();
/// assert!(mask[20] && mask[30]);
/// assert!(!c.is_complete());
///
/// // Committing 20 finishes the first candidate and rules out the second.
/// assert!(c.advance(20));
/// assert!(c.is_complete());
/// assert_eq!(c.active_count(), 1);
/// ```
#[derive(Debug, Clone)]
pub struct AllowListConstraint {
    /// All allowed sequences.
    candidates: Vec<Vec<u32>>,
    /// Which candidates still match the current prefix.
    ///
    /// Invariant: `active.len() == candidates.len()`; `active[i]` tracks `candidates[i]`.
    active: Vec<bool>,
    /// Number of tokens consumed so far.
    position: usize,
}
37
38impl AllowListConstraint {
39 /// Create a new `AllowListConstraint` from a list of allowed token sequences.
40 pub fn new(candidates: Vec<Vec<u32>>) -> Self {
41 let n = candidates.len();
42 Self {
43 candidates,
44 active: vec![true; n],
45 position: 0,
46 }
47 }
48
49 /// Returns the number of candidate sequences that are still active.
50 pub fn active_count(&self) -> usize {
51 self.active.iter().filter(|&&a| a).count()
52 }
53}
54
55impl TokenConstraint for AllowListConstraint {
56 /// Returns a bitmask of tokens that are valid next tokens across all still-active
57 /// candidates at the current position. Always returns `Some` (never unconstrained).
58 fn allowed_tokens(&self, _generated: &[u32], vocab_size: usize) -> Option<Vec<bool>> {
59 let mut mask = vec![false; vocab_size];
60 for (i, active) in self.active.iter().enumerate() {
61 if !active {
62 continue;
63 }
64 let seq = &self.candidates[i];
65 if self.position < seq.len() {
66 let tok = seq[self.position] as usize;
67 if tok < vocab_size {
68 mask[tok] = true;
69 }
70 }
71 }
72 Some(mask)
73 }
74
75 /// Commits `token` at the current position.
76 ///
77 /// Any candidate where `candidates[i][position] != token` (or the candidate is
78 /// already exhausted) is deactivated. Returns `true` when at least one candidate
79 /// remains active **or** a candidate was just completed at this position.
80 fn advance(&mut self, token: u32) -> bool {
81 let mut just_completed = false;
82 for (i, active) in self.active.iter_mut().enumerate() {
83 if !*active {
84 continue;
85 }
86 let seq = &self.candidates[i];
87 if self.position >= seq.len() {
88 // Candidate was already completed; further tokens deactivate it.
89 *active = false;
90 } else if seq[self.position] == token {
91 // Token matches; check if this completes the candidate.
92 if self.position + 1 == seq.len() {
93 just_completed = true;
94 }
95 // Keep active — will be filtered by is_complete / future advance calls.
96 } else {
97 *active = false;
98 }
99 }
100 self.position += 1;
101 // Return true if at least one candidate is still active or one just completed.
102 just_completed || self.active.iter().any(|&a| a)
103 }
104
105 /// Returns `true` when the consumed token sequence fully matches at least one
106 /// candidate, i.e. `position == candidates[i].len()` for some active `i`.
107 fn is_complete(&self) -> bool {
108 self.candidates
109 .iter()
110 .enumerate()
111 .any(|(i, seq)| self.active[i] && self.position == seq.len())
112 }
113
114 /// Reset to initial state (all candidates active, position zero).
115 fn reset(&mut self) {
116 self.active.fill(true);
117 self.position = 0;
118 }
119
120 fn name(&self) -> &str {
121 "AllowListConstraint"
122 }
123}