oxibonsai_runtime/constrained_decoding/length.rs
1//! [`LengthConstraint`] — enforce hard minimum and maximum generation lengths.
2
3use super::error_trait::TokenConstraint;
4
5// ─────────────────────────────────────────────────────────────────────────────
6// LengthConstraint — enforce hard minimum and maximum generation lengths
7// ─────────────────────────────────────────────────────────────────────────────
8
/// A constraint that enforces hard minimum and maximum token-count limits.
///
/// - While `count < min_len`: if a `stop_token` is configured it is excluded from
///   the mask (cannot stop early); without a stop token no mask is imposed (`None`).
/// - While `count >= max_len` (and `count >= min_len`): if a `stop_token` is
///   configured only that token is allowed; otherwise an all-`false` mask is
///   returned, signalling the caller to halt generation externally. A stop token
///   whose id is outside the vocabulary also yields an all-`false` mask.
/// - Between `min_len` and `max_len`: all tokens are allowed (`None`).
///
/// Completion is defined as either reaching `max_len` OR generating `min_len` or
/// more tokens followed by the `stop_token`.
///
/// `min_len` is expected to be `<= max_len` but is not validated; if it is larger,
/// the `min_len` rule takes precedence until `min_len` tokens have been committed.
///
/// # Example
/// ```rust
/// use oxibonsai_runtime::constrained_decoding::{LengthConstraint, TokenConstraint};
///
/// // Must generate at least 2 tokens, stop token is 1 (EOS), max 10.
/// let c = LengthConstraint::new(2, 10, Some(1));
/// // Before min_len: stop_token excluded
/// let mask = c.allowed_tokens(&[], 4).unwrap();
/// assert!(!mask[1]); // stop token blocked
/// assert!(mask[0]); // other tokens allowed
/// ```
#[derive(Debug, Clone)]
pub struct LengthConstraint {
    /// Minimum number of tokens that must be generated before stopping.
    min_len: usize,
    /// Hard upper bound on generated token count.
    max_len: usize,
    /// Optional end-of-sequence token; treated specially for early-stop control.
    stop_token: Option<u32>,
    /// Number of tokens committed via `advance` so far.
    count: usize,
    /// Whether `advance` has recorded a committed `stop_token`.
    stop_seen: bool,
}
44
45impl LengthConstraint {
46 /// Create a new `LengthConstraint`.
47 ///
48 /// `min_len` must be `<= max_len`.
49 pub fn new(min_len: usize, max_len: usize, stop_token: Option<u32>) -> Self {
50 Self {
51 min_len,
52 max_len,
53 stop_token,
54 count: 0,
55 stop_seen: false,
56 }
57 }
58
59 /// Current token count.
60 pub fn count(&self) -> usize {
61 self.count
62 }
63}
64
65impl TokenConstraint for LengthConstraint {
66 fn allowed_tokens(&self, _generated: &[u32], vocab_size: usize) -> Option<Vec<bool>> {
67 if self.count < self.min_len {
68 // Cannot stop early — exclude stop_token if one is configured.
69 if let Some(stop) = self.stop_token {
70 let mut mask = vec![true; vocab_size];
71 let stop_idx = stop as usize;
72 if stop_idx < vocab_size {
73 mask[stop_idx] = false;
74 }
75 return Some(mask);
76 }
77 // No stop token: no restriction below min_len.
78 return None;
79 }
80
81 if self.count >= self.max_len {
82 // Must stop now.
83 if let Some(stop) = self.stop_token {
84 let mut mask = vec![false; vocab_size];
85 let stop_idx = stop as usize;
86 if stop_idx < vocab_size {
87 mask[stop_idx] = true;
88 }
89 return Some(mask);
90 }
91 // No stop token: emit an all-false mask to force external termination.
92 return Some(vec![false; vocab_size]);
93 }
94
95 // Between min and max: unconstrained.
96 None
97 }
98
99 /// Commits `token`, updating `count` and `stop_seen`. Always returns `true`.
100 fn advance(&mut self, token: u32) -> bool {
101 if let Some(stop) = self.stop_token {
102 if token == stop {
103 self.stop_seen = true;
104 }
105 }
106 self.count += 1;
107 true
108 }
109
110 /// Returns `true` when at least `min_len` tokens have been generated AND either
111 /// the `stop_token` was seen or `max_len` has been reached.
112 fn is_complete(&self) -> bool {
113 if self.count < self.min_len {
114 return false;
115 }
116 self.count >= self.max_len || self.stop_seen
117 }
118
119 /// Reset to initial state.
120 fn reset(&mut self) {
121 self.count = 0;
122 self.stop_seen = false;
123 }
124
125 fn name(&self) -> &str {
126 "LengthConstraint"
127 }
128}