Skip to main content

oxibonsai_runtime/constrained_decoding/
sequence.rs

1//! [`SequenceConstraint`] — force output to follow a specific token sequence exactly.
2
3use super::error_trait::TokenConstraint;
4
5// ─────────────────────────────────────────────────────────────────────────────
6// SequenceConstraint — force output to follow a specific token sequence exactly
7// ─────────────────────────────────────────────────────────────────────────────
8
/// A constraint that forces the generated output to reproduce a specific,
/// pre-determined token sequence.
///
/// While `position < target.len()`, only `target[position]` is allowed.
/// Once the target has been fully reproduced (`position >= target.len()`) the
/// constraint is satisfied and all tokens become allowed again (`allowed_tokens`
/// returns `None`).
///
/// Committing a token that differs from the expected one does not stop the
/// constraint: the mismatch is recorded (see [`SequenceConstraint::is_failed`])
/// and the cursor still moves on to the next target position.
///
/// # Example
/// ```rust
/// use oxibonsai_runtime::constrained_decoding::{SequenceConstraint, TokenConstraint};
///
/// let mut c = SequenceConstraint::new(vec![5, 6, 7]);
/// let mask = c.allowed_tokens(&[], 10).unwrap();
/// assert!(mask[5]);
/// assert!(!mask[6]);
/// assert!(c.advance(5));
/// assert!(!c.is_failed());
/// ```
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SequenceConstraint {
    /// The token sequence that must be reproduced.
    target: Vec<u32>,
    /// Index of the next expected token in `target`; at or past `target.len()`
    /// once the whole sequence has been consumed.
    position: usize,
    /// Set to `true` if a mismatched token was ever committed; cleared by `reset`.
    failed: bool,
}
34
35impl SequenceConstraint {
36    /// Create a new `SequenceConstraint` for the given target sequence.
37    pub fn new(target: Vec<u32>) -> Self {
38        Self {
39            target,
40            position: 0,
41            failed: false,
42        }
43    }
44
45    /// Whether the constraint has been violated (a wrong token was committed).
46    pub fn is_failed(&self) -> bool {
47        self.failed
48    }
49}
50
51impl TokenConstraint for SequenceConstraint {
52    /// Returns a bitmask allowing only the next expected token, or `None` once the
53    /// full sequence has been reproduced.
54    fn allowed_tokens(&self, _generated: &[u32], vocab_size: usize) -> Option<Vec<bool>> {
55        if self.position >= self.target.len() {
56            // Sequence fully consumed — no further restriction.
57            return None;
58        }
59        let mut mask = vec![false; vocab_size];
60        let next = self.target[self.position] as usize;
61        if next < vocab_size {
62            mask[next] = true;
63        }
64        Some(mask)
65    }
66
67    /// Commits `token`.  Returns `false` (and sets the failed flag) if `token`
68    /// does not match the expected token at the current position.
69    fn advance(&mut self, token: u32) -> bool {
70        if self.position < self.target.len() && token != self.target[self.position] {
71            self.failed = true;
72            self.position += 1;
73            return false;
74        }
75        self.position += 1;
76        true
77    }
78
79    /// Returns `true` once all tokens in the target sequence have been consumed.
80    fn is_complete(&self) -> bool {
81        self.position >= self.target.len()
82    }
83
84    /// Reset to initial state.
85    fn reset(&mut self) {
86        self.position = 0;
87        self.failed = false;
88    }
89
90    fn name(&self) -> &str {
91        "SequenceConstraint"
92    }
93}