use super::error_trait::TokenConstraint;
/// Constraint that bounds how many tokens are generated, optionally using a
/// dedicated stop token to end the sequence.
///
/// All mutable progress lives in `count` and `stop_seen`; the bounds and the
/// stop token are fixed at construction.
#[derive(Debug, Clone)]
pub struct LengthConstraint {
    /// While fewer than this many tokens have been counted, the stop token is
    /// masked out and the constraint does not report completion.
    min_len: usize,
    /// Once this many tokens have been counted, only the stop token is left in
    /// the mask (or nothing at all when there is no stop token).
    max_len: usize,
    /// Optional token id that terminates the sequence when it is fed to `advance`.
    stop_token: Option<u32>,
    /// Number of tokens fed to `advance` since construction or the last reset.
    count: usize,
    /// Whether `advance` has received `stop_token` since construction or the last reset.
    stop_seen: bool,
}
impl LengthConstraint {
    /// Builds a constraint in its initial state: no tokens counted yet and the
    /// stop token (if any) not yet observed.
    ///
    /// `min_len` and `max_len` are stored exactly as given; no check is made
    /// that `min_len <= max_len`.
    pub fn new(min_len: usize, max_len: usize, stop_token: Option<u32>) -> Self {
        let (count, stop_seen) = (0, false);
        Self {
            stop_token,
            max_len,
            min_len,
            stop_seen,
            count,
        }
    }

    /// Returns how many tokens have been counted so far.
    pub fn count(&self) -> usize {
        let Self { count, .. } = self;
        *count
    }
}
impl TokenConstraint for LengthConstraint {
    /// Returns a per-token mask of length `vocab_size`, or `None` when this
    /// constraint imposes no restriction at the current length.
    ///
    /// - Below `min_len`: every entry is `true` except the stop token's. With no
    ///   stop token configured, `None` is returned.
    /// - At or beyond `max_len`: only the stop token's entry is `true`. With no
    ///   stop token configured, every entry is `false`.
    /// - Otherwise: `None`.
    ///
    /// A stop token id `>= vocab_size` simply has no entry in the mask. The
    /// `_generated` history is not consulted; only the internal counter is used.
    fn allowed_tokens(&self, _generated: &[u32], vocab_size: usize) -> Option<Vec<bool>> {
        let below_min = self.count < self.min_len;
        let at_or_past_max = self.count >= self.max_len;

        // The `below_min` check takes priority over the `max_len` check, matching
        // the order in which the two bounds are applied.
        match (below_min, at_or_past_max, self.stop_token) {
            (true, _, None) => None,
            (true, _, Some(stop)) => {
                let stop_idx = stop as usize;
                Some((0..vocab_size).map(|idx| idx != stop_idx).collect())
            }
            (false, true, Some(stop)) => {
                let stop_idx = stop as usize;
                Some((0..vocab_size).map(|idx| idx == stop_idx).collect())
            }
            (false, true, None) => Some(vec![false; vocab_size]),
            (false, false, _) => None,
        }
    }

    /// Counts `token` and records whether it was the configured stop token.
    ///
    /// Always returns `true`. A token is never rejected here, even if the current
    /// mask would have excluded it.
    fn advance(&mut self, token: u32) -> bool {
        self.stop_seen |= self.stop_token == Some(token);
        self.count += 1;
        true
    }

    /// Reports completion once at least `min_len` tokens have been counted and
    /// either the stop token has been seen or `max_len` has been reached.
    fn is_complete(&self) -> bool {
        let reached_min = self.count >= self.min_len;
        reached_min && (self.stop_seen || self.count >= self.max_len)
    }

    /// Clears progress (token count and stop flag). The configured bounds and
    /// stop token are kept.
    fn reset(&mut self) {
        self.stop_seen = false;
        self.count = 0;
    }

    fn name(&self) -> &str {
        "LengthConstraint"
    }
}