oxibonsai_runtime/constrained_decoding/error_trait.rs
1//! Core trait and error types for grammar-constrained decoding.
2//!
3//! This sub-module hosts:
4//! - [`ConstraintError`]: errors that can arise when building or running a token constraint
5//! - [`TokenConstraint`]: the trait implemented by every concrete constraint
6//! - [`NoConstraint`]: a passthrough constraint that allows all tokens
7
8// ─────────────────────────────────────────────────────────────────────────────
9// Error type
10// ─────────────────────────────────────────────────────────────────────────────
11
12/// Errors that can arise when building or running a token constraint.
13#[derive(Debug, thiserror::Error)]
14pub enum ConstraintError {
15 /// The supplied regex pattern was syntactically invalid.
16 #[error("Invalid regex pattern: {0}")]
17 InvalidPattern(String),
18
19 /// The supplied JSON schema was invalid (reserved for future schema-based constraints).
20 #[error("Invalid JSON schema: {0}")]
21 InvalidSchema(String),
22
23 /// The constraint was violated at a specific token.
24 #[error("Constraint violated at token {token}: {reason}")]
25 Violated { token: u32, reason: String },
26}
27
28// ─────────────────────────────────────────────────────────────────────────────
29// Core trait
30// ─────────────────────────────────────────────────────────────────────────────
31
/// Restricts the set of tokens that may be emitted at each decoding step.
///
/// An implementation carries its own progress state: it tracks how much of
/// the constrained structure the generation has already covered. The trait
/// requires `Send + Sync` of every implementor.
pub trait TokenConstraint: Send + Sync {
    /// Computes which tokens may come next, given everything generated so far.
    ///
    /// * `generated` - the tokens produced up to this point.
    /// * `vocab_size` - the total size of the vocabulary.
    ///
    /// Yields `Some(mask)`, a boolean mask of permitted next tokens, or `None`
    /// when every token is permitted (i.e. no constraint is active).
    fn allowed_tokens(&self, generated: &[u32], vocab_size: usize) -> Option<Vec<bool>>;

    /// Notifies the constraint that `token` has been committed.
    ///
    /// A `false` result means the constraint is now violated and generation
    /// should stop; `true` means decoding may continue.
    fn advance(&mut self, token: u32) -> bool;

    /// Reports whether the current state is a valid terminal state.
    fn is_complete(&self) -> bool;

    /// Puts the constraint back into its initial state.
    fn reset(&mut self);

    /// A human-readable label intended for debugging and log output.
    fn name(&self) -> &str;
}
57
58// ─────────────────────────────────────────────────────────────────────────────
59// NoConstraint — passthrough
60// ─────────────────────────────────────────────────────────────────────────────
61
/// A passthrough constraint that places no restriction on the vocabulary.
///
/// This is a stateless unit struct, so it derives the common marker and
/// utility traits: `Debug` (so it can be logged and embedded in other
/// `Debug` types), `Clone`/`Copy` (it has no fields to duplicate), and
/// `Default` (so it can be built generically via `Default::default()`).
#[derive(Debug, Clone, Copy, Default)]
pub struct NoConstraint;
64
65impl TokenConstraint for NoConstraint {
66 fn allowed_tokens(&self, _generated: &[u32], _vocab_size: usize) -> Option<Vec<bool>> {
67 None
68 }
69
70 fn advance(&mut self, _token: u32) -> bool {
71 true
72 }
73
74 fn is_complete(&self) -> bool {
75 true
76 }
77
78 fn reset(&mut self) {}
79
80 fn name(&self) -> &str {
81 "NoConstraint"
82 }
83}
84
#[cfg(test)]
mod tests {
    use super::*;

    /// The passthrough constraint reports "no mask" (`None`) when asked for
    /// the allowed tokens.
    #[test]
    fn no_constraint_allows_all() {
        let passthrough = NoConstraint;
        let mask = passthrough.allowed_tokens(&[], 10);
        assert!(mask.is_none());
    }

    /// Rendered error messages include the payload of each variant.
    #[test]
    fn constraint_error_display() {
        let pattern_msg = ConstraintError::InvalidPattern(String::from("bad")).to_string();
        assert!(pattern_msg.contains("bad"));

        let violation = ConstraintError::Violated {
            token: 5,
            reason: String::from("oops"),
        };
        let violation_msg = violation.to_string();
        assert!(violation_msg.contains("5"));
        assert!(violation_msg.contains("oops"));
    }
}