Skip to main content

oxibonsai_runtime/constrained_decoding/
error_trait.rs

//! Core trait and error types for grammar-constrained decoding.
//!
//! This sub-module hosts:
//! - [`ConstraintError`]: errors that can arise when building or running a token constraint
//! - [`TokenConstraint`]: the trait implemented by every concrete constraint
//! - [`NoConstraint`]: a passthrough constraint that allows all tokens
7
8// ─────────────────────────────────────────────────────────────────────────────
9// Error type
10// ─────────────────────────────────────────────────────────────────────────────
11
12/// Errors that can arise when building or running a token constraint.
13#[derive(Debug, thiserror::Error)]
14pub enum ConstraintError {
15    /// The supplied regex pattern was syntactically invalid.
16    #[error("Invalid regex pattern: {0}")]
17    InvalidPattern(String),
18
19    /// The supplied JSON schema was invalid (reserved for future schema-based constraints).
20    #[error("Invalid JSON schema: {0}")]
21    InvalidSchema(String),
22
23    /// The constraint was violated at a specific token.
24    #[error("Constraint violated at token {token}: {reason}")]
25    Violated { token: u32, reason: String },
26}
// ─────────────────────────────────────────────────────────────────────────────
// Core trait
// ─────────────────────────────────────────────────────────────────────────────

/// Restricts which tokens may be emitted at each step of decoding.
///
/// An implementor keeps track of how much of its constrained structure has
/// been produced so far. It is asked which tokens may come next and is told
/// about every token once that token has been committed.
///
/// The `Send + Sync` supertraits allow constraints (including
/// `Box<dyn TokenConstraint>`) to be moved and shared across threads.
pub trait TokenConstraint: Send + Sync {
    /// Short, human-readable identifier used in debug output and logs.
    fn name(&self) -> &str;

    /// Computes which tokens may come next, given the history in `generated`.
    ///
    /// `vocab_size` is the total number of tokens in the vocabulary.
    ///
    /// Returns `None` when the constraint imposes no restriction (every token
    /// is allowed). Otherwise it returns a boolean mask in which `true` marks
    /// an allowed token. The mask holds one `bool` per token rather than packed
    /// bits. Presumably it has `vocab_size` entries indexed by token id; confirm
    /// this against the implementors.
    fn allowed_tokens(&self, generated: &[u32], vocab_size: usize) -> Option<Vec<bool>>;

    /// Records that `token` has been committed to the output.
    ///
    /// Returns `false` when the constraint is now violated, meaning generation
    /// should stop.
    fn advance(&mut self, token: u32) -> bool;

    /// Reports whether the current state is a valid terminal state.
    fn is_complete(&self) -> bool;

    /// Puts the constraint back into its initial state.
    fn reset(&mut self);
}
57
58// ─────────────────────────────────────────────────────────────────────────────
59// NoConstraint — passthrough
60// ─────────────────────────────────────────────────────────────────────────────
61
62/// A passthrough constraint that places no restriction on the vocabulary.
63pub struct NoConstraint;
64
65impl TokenConstraint for NoConstraint {
66    fn allowed_tokens(&self, _generated: &[u32], _vocab_size: usize) -> Option<Vec<bool>> {
67        None
68    }
69
70    fn advance(&mut self, _token: u32) -> bool {
71        true
72    }
73
74    fn is_complete(&self) -> bool {
75        true
76    }
77
78    fn reset(&mut self) {}
79
80    fn name(&self) -> &str {
81        "NoConstraint"
82    }
83}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn no_constraint_allows_all() {
        let nc = NoConstraint;
        assert!(nc.allowed_tokens(&[], 10).is_none());
        // No mask regardless of history or vocabulary size.
        assert!(nc.allowed_tokens(&[1, 2, 3], 0).is_none());
    }

    #[test]
    fn no_constraint_never_violated_and_always_complete() {
        let mut nc = NoConstraint;
        assert!(nc.is_complete());
        for token in [0, 7, u32::MAX] {
            assert!(nc.advance(token));
            assert!(nc.is_complete());
        }
        nc.reset();
        assert!(nc.is_complete());
        assert_eq!(nc.name(), "NoConstraint");
    }

    #[test]
    fn no_constraint_usable_as_trait_object() {
        let mut boxed: Box<dyn TokenConstraint> = Box::new(NoConstraint);
        assert!(boxed.allowed_tokens(&[], 4).is_none());
        assert!(boxed.advance(3));
        assert_eq!(boxed.name(), "NoConstraint");
    }

    #[test]
    fn constraint_error_display() {
        // Pin the exact messages; substring checks would accept malformed output.
        let e = ConstraintError::InvalidPattern("bad".to_string());
        assert_eq!(e.to_string(), "Invalid regex pattern: bad");

        let e1 = ConstraintError::InvalidSchema("missing type".to_string());
        assert_eq!(e1.to_string(), "Invalid JSON schema: missing type");

        let e2 = ConstraintError::Violated {
            token: 5,
            reason: "oops".to_string(),
        };
        assert_eq!(e2.to_string(), "Constraint violated at token 5: oops");
    }

    #[test]
    fn constraint_error_is_std_error() {
        let e = ConstraintError::InvalidPattern("x".to_string());
        let as_dyn: &dyn std::error::Error = &e;
        assert!(as_dyn.source().is_none());
    }
}