use super::error_trait::TokenConstraint;
/// Constraint that forces generation to reproduce a fixed token sequence, one token per step.
#[derive(Debug, Clone)]
pub struct SequenceConstraint {
    /// Token ids that must be produced, in order.
    target: Vec<u32>,
    /// Index into `target` of the next expected token; `target.len()` or more means done.
    position: usize,
    /// Set when an advanced token did not match the expected target token.
    failed: bool,
}
impl SequenceConstraint {
    /// Builds a constraint that will require `target` to be emitted token by token.
    ///
    /// The cursor begins at the first target element and no failure is recorded.
    pub fn new(target: Vec<u32>) -> Self {
        let (position, failed) = (0, false);
        Self { target, failed, position }
    }

    /// Tells whether a token passed to `advance` differed from the expected target token.
    ///
    /// `advance` sets this flag on a mismatch; `reset` clears it.
    pub fn is_failed(&self) -> bool {
        let Self { failed, .. } = self;
        *failed
    }
}
impl TokenConstraint for SequenceConstraint {
    /// Returns a mask permitting only the next token of the target sequence.
    ///
    /// Returns `None` once every target token has been consumed, i.e. this
    /// constraint no longer restricts the choice. The `_generated` history is not
    /// consulted; progress is tracked solely by the internal cursor.
    ///
    /// NOTE(review): if the expected token id is `>= vocab_size`, the returned mask
    /// is entirely `false` (no token allowed) — confirm callers handle an empty mask.
    fn allowed_tokens(&self, _generated: &[u32], vocab_size: usize) -> Option<Vec<bool>> {
        // `?` yields `None` when the cursor is at or past the end of the target.
        let &expected = self.target.get(self.position)?;
        let mut mask = vec![false; vocab_size];
        // Checked write: an out-of-vocab target id leaves the mask all `false`.
        if let Some(slot) = mask.get_mut(expected as usize) {
            *slot = true;
        }
        Some(mask)
    }

    /// Records `token` as the next generated token.
    ///
    /// While target tokens remain, the cursor moves forward one step; a token that
    /// differs from the expected one sets the failure flag and returns `false`.
    /// After a mismatch, later tokens are still compared against the following
    /// target positions. Once the target is exhausted every token is accepted and
    /// the cursor stays pinned at `target.len()` instead of growing without bound.
    fn advance(&mut self, token: u32) -> bool {
        match self.target.get(self.position) {
            None => true,
            Some(&expected) => {
                self.position += 1;
                if token == expected {
                    true
                } else {
                    self.failed = true;
                    false
                }
            }
        }
    }

    /// Returns `true` once every target position has been consumed.
    ///
    /// This does not consider the failure flag: a sequence whose last token
    /// mismatched is still reported complete. Check `is_failed` to know whether
    /// the generated tokens actually matched the target.
    fn is_complete(&self) -> bool {
        self.position >= self.target.len()
    }

    /// Rewinds the cursor to the start of the target and clears the failure flag.
    fn reset(&mut self) {
        self.position = 0;
        self.failed = false;
    }

    /// Returns the fixed identifier `"SequenceConstraint"`.
    fn name(&self) -> &str {
        "SequenceConstraint"
    }
}