oxibonsai_runtime/constrained_decoding/sequence.rs
1//! [`SequenceConstraint`] — force output to follow a specific token sequence exactly.
2
3use super::error_trait::TokenConstraint;
4
5// ─────────────────────────────────────────────────────────────────────────────
6// SequenceConstraint — force output to follow a specific token sequence exactly
7// ─────────────────────────────────────────────────────────────────────────────
8
/// A constraint that forces the generated output to reproduce a specific,
/// pre-determined token sequence.
///
/// While `position < target.len()`, only `target[position]` is allowed.
/// Once the target has been fully reproduced (`position >= target.len()`) the
/// constraint is satisfied and all tokens become allowed again (returns `None`).
///
/// Committing a token that does not match the expected one does not stall the
/// constraint: the mismatch is recorded (see [`Self::is_failed`]) and the next
/// target token becomes the expected one.
///
/// # Example
/// ```rust
/// use oxibonsai_runtime::constrained_decoding::{SequenceConstraint, TokenConstraint};
///
/// let mut c = SequenceConstraint::new(vec![5, 6, 7]);
/// let mask = c.allowed_tokens(&[], 10).unwrap();
/// assert!(mask[5]);
/// assert!(!mask[6]);
/// assert!(c.advance(5));
/// assert!(!c.is_failed());
/// ```
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SequenceConstraint {
    /// The token sequence that must be reproduced.
    target: Vec<u32>,
    /// Number of tokens consumed (next expected index into `target`).
    position: usize,
    /// Set to `true` if a mismatched token was ever committed.
    failed: bool,
}
34
35impl SequenceConstraint {
36 /// Create a new `SequenceConstraint` for the given target sequence.
37 pub fn new(target: Vec<u32>) -> Self {
38 Self {
39 target,
40 position: 0,
41 failed: false,
42 }
43 }
44
45 /// Whether the constraint has been violated (a wrong token was committed).
46 pub fn is_failed(&self) -> bool {
47 self.failed
48 }
49}
50
51impl TokenConstraint for SequenceConstraint {
52 /// Returns a bitmask allowing only the next expected token, or `None` once the
53 /// full sequence has been reproduced.
54 fn allowed_tokens(&self, _generated: &[u32], vocab_size: usize) -> Option<Vec<bool>> {
55 if self.position >= self.target.len() {
56 // Sequence fully consumed — no further restriction.
57 return None;
58 }
59 let mut mask = vec![false; vocab_size];
60 let next = self.target[self.position] as usize;
61 if next < vocab_size {
62 mask[next] = true;
63 }
64 Some(mask)
65 }
66
67 /// Commits `token`. Returns `false` (and sets the failed flag) if `token`
68 /// does not match the expected token at the current position.
69 fn advance(&mut self, token: u32) -> bool {
70 if self.position < self.target.len() && token != self.target[self.position] {
71 self.failed = true;
72 self.position += 1;
73 return false;
74 }
75 self.position += 1;
76 true
77 }
78
79 /// Returns `true` once all tokens in the target sequence have been consumed.
80 fn is_complete(&self) -> bool {
81 self.position >= self.target.len()
82 }
83
84 /// Reset to initial state.
85 fn reset(&mut self) {
86 self.position = 0;
87 self.failed = false;
88 }
89
90 fn name(&self) -> &str {
91 "SequenceConstraint"
92 }
93}