use super::error_trait::TokenConstraint;
/// Constrains generation so that the output must spell out one of a fixed
/// list of candidate token-id sequences.
///
/// Candidates are matched in lock-step: each accepted token is compared with
/// the token at the same offset in every still-viable candidate, and
/// candidates that disagree are dropped until the constraint is reset.
#[derive(Debug, Clone)]
pub struct AllowListConstraint {
    /// The permitted token-id sequences, in the order they were supplied.
    candidates: Vec<Vec<u32>>,
    /// `active[i]` stays `true` while the tokens consumed so far are a prefix
    /// of `candidates[i]`; always the same length as `candidates`.
    active: Vec<bool>,
    /// Number of tokens consumed since construction or the last reset.
    position: usize,
}

impl AllowListConstraint {
    /// Creates a constraint over `candidates` with every candidate viable and
    /// no tokens consumed yet.
    pub fn new(candidates: Vec<Vec<u32>>) -> Self {
        // Every candidate starts out viable; one flag per candidate keeps the
        // two vectors index-aligned.
        let active = vec![true; candidates.len()];
        Self {
            candidates,
            active,
            position: 0,
        }
    }

    /// Returns how many candidates are still marked viable.
    pub fn active_count(&self) -> usize {
        self.active.iter().filter(|&&a| a).count()
    }
}
// Lock-step matching: `position` is the offset compared in every candidate,
// and `active` records which candidates still agree with everything consumed.
impl TokenConstraint for AllowListConstraint {
    /// Builds a mask of length `vocab_size` that permits exactly the token at
    /// the current offset of every still-viable candidate.
    ///
    /// Candidates that have no token at this offset contribute nothing, and
    /// ids that fall outside `0..vocab_size` are skipped. Always `Some`.
    fn allowed_tokens(&self, _generated: &[u32], vocab_size: usize) -> Option<Vec<bool>> {
        let pos = self.position;
        let mut allowed = vec![false; vocab_size];
        let next_ids = self
            .candidates
            .iter()
            .zip(&self.active)
            .filter(|(_, live)| **live)
            .filter_map(|(cand, _)| cand.get(pos))
            .map(|&id| id as usize);
        for id in next_ids {
            // `get_mut` quietly drops ids beyond the vocabulary.
            if let Some(slot) = allowed.get_mut(id) {
                *slot = true;
            }
        }
        Some(allowed)
    }

    /// Consumes `token`, retiring every viable candidate that lacks a token at
    /// the current offset or holds a different one there, then moves the
    /// offset forward by one.
    ///
    /// Returns `true` if at least one candidate is still viable afterwards.
    fn advance(&mut self, token: u32) -> bool {
        let pos = self.position;
        for (cand, live) in self.candidates.iter().zip(self.active.iter_mut()) {
            // `get` yields `None` for a candidate that is already exhausted,
            // which can never equal `Some(&token)`, so it is retired too.
            if *live && cand.get(pos) != Some(&token) {
                *live = false;
            }
        }
        self.position = pos + 1;
        self.active.iter().any(|&live| live)
    }

    /// Reports whether some viable candidate has been matched in full, i.e.
    /// its length equals the number of tokens consumed.
    fn is_complete(&self) -> bool {
        let pos = self.position;
        self.active
            .iter()
            .zip(&self.candidates)
            .any(|(&live, cand)| live && cand.len() == pos)
    }

    /// Marks every candidate viable again and rewinds to offset zero.
    fn reset(&mut self) {
        self.position = 0;
        self.active.iter_mut().for_each(|live| *live = true);
    }

    /// Fixed identifier for this constraint type.
    fn name(&self) -> &str {
        "AllowListConstraint"
    }
}