#[derive(Debug, thiserror::Error)]
pub enum ConstraintError {
#[error("Invalid regex pattern: {0}")]
InvalidPattern(String),
#[error("Invalid JSON schema: {0}")]
InvalidSchema(String),
#[error("Constraint violated at token {token}: {reason}")]
Violated { token: u32, reason: String },
}
/// A constraint that can restrict which tokens may be chosen next.
///
/// The `Send + Sync` supertraits allow implementors to be moved to and shared
/// between threads.
pub trait TokenConstraint: Send + Sync {
    /// Returns a mask of allowed tokens, or `None` when no restriction applies.
    ///
    /// NOTE(review): presumably `generated` holds the token ids emitted so far,
    /// and the returned `Vec` has length `vocab_size`, with `true` marking an
    /// allowed token id. Confirm against the implementors.
    fn allowed_tokens(&self, generated: &[u32], vocab_size: usize) -> Option<Vec<bool>>;
    /// Feeds `token` into the constraint's state.
    ///
    /// NOTE(review): the returned `bool` presumably reports whether the token
    /// was accepted. Confirm the exact contract.
    fn advance(&mut self, token: u32) -> bool;
    /// Reports whether the constraint considers itself complete.
    ///
    /// NOTE(review): confirm whether "complete" means "may stop" or "must stop".
    fn is_complete(&self) -> bool;
    /// Resets the constraint's internal state (presumably to its initial state).
    fn reset(&mut self);
    /// Returns a human-readable name for the constraint.
    fn name(&self) -> &str;
}
/// A [`TokenConstraint`] that imposes no restrictions.
///
/// It never produces a token mask, returns `true` from every `advance` call,
/// and always reports itself as complete. It is stateless, so it derives the
/// common value traits (including `Default` and `Copy`).
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct NoConstraint;

impl TokenConstraint for NoConstraint {
    /// Always `None`: no mask is produced, regardless of history or vocabulary size.
    fn allowed_tokens(&self, _generated: &[u32], _vocab_size: usize) -> Option<Vec<bool>> {
        None
    }

    /// Always returns `true`; there is no state to update.
    fn advance(&mut self, _token: u32) -> bool {
        true
    }

    /// Always `true`: there is nothing left to satisfy.
    fn is_complete(&self) -> bool {
        true
    }

    /// No-op, because the constraint holds no state.
    fn reset(&mut self) {}

    /// Returns the fixed name `"NoConstraint"`.
    fn name(&self) -> &str {
        "NoConstraint"
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn no_constraint_allows_all() {
        let nc = NoConstraint;
        assert!(nc.allowed_tokens(&[], 10).is_none());
        // Neither prior output nor vocabulary size should change the answer.
        assert!(nc.allowed_tokens(&[1, 2, 3], 0).is_none());
    }

    #[test]
    fn no_constraint_is_stateless() {
        let mut nc = NoConstraint;
        assert!(nc.is_complete());
        assert!(nc.advance(0));
        assert!(nc.advance(u32::MAX));
        assert!(nc.is_complete());
        nc.reset();
        assert!(nc.is_complete());
        assert_eq!(nc.name(), "NoConstraint");
    }

    #[test]
    fn token_constraint_is_object_safe() {
        // Constraints are likely stored behind trait objects; make sure that compiles.
        let boxed: Box<dyn TokenConstraint> = Box::new(NoConstraint);
        assert_eq!(boxed.name(), "NoConstraint");
        assert!(boxed.allowed_tokens(&[], 4).is_none());
    }

    #[test]
    fn constraint_error_display() {
        // Pin the full messages so format-string regressions are caught.
        assert_eq!(
            ConstraintError::InvalidPattern("bad".to_string()).to_string(),
            "Invalid regex pattern: bad"
        );
        assert_eq!(
            ConstraintError::InvalidSchema("missing type".to_string()).to_string(),
            "Invalid JSON schema: missing type"
        );
        let violated = ConstraintError::Violated {
            token: 5,
            reason: "oops".to_string(),
        };
        assert_eq!(
            violated.to_string(),
            "Constraint violated at token 5: oops"
        );
    }
}