Skip to main content

oxibonsai_runtime/constrained_decoding/
length.rs

//! [`LengthConstraint`] — enforce hard minimum and maximum generation lengths.
2
3use super::error_trait::TokenConstraint;
4
// ─────────────────────────────────────────────────────────────────────────────
// LengthConstraint — enforce hard minimum and maximum generation lengths
// ─────────────────────────────────────────────────────────────────────────────
8
/// A constraint that enforces hard minimum and maximum token-count limits.
///
/// - While `count < min_len`: if a `stop_token` is configured it is excluded from
///   the mask (cannot stop early); without a stop token no mask is imposed.
/// - While `count >= max_len`: if a `stop_token` is configured only that token is
///   allowed; otherwise an all-`false` mask is returned, signalling the caller to
///   halt generation externally.
/// - Between `min_len` and `max_len`: all tokens are allowed (`None`).
///
/// Completion is defined as either reaching `max_len` OR generating `min_len` or
/// more tokens followed by the `stop_token`.
///
/// # Example
/// ```rust
/// use oxibonsai_runtime::constrained_decoding::{LengthConstraint, TokenConstraint};
///
/// // Must generate at least 2 tokens, stop token is 1 (EOS), max 10.
/// let mut c = LengthConstraint::new(2, 10, Some(1));
///
/// // Before min_len: stop_token excluded.
/// let mask = c.allowed_tokens(&[], 4).unwrap();
/// assert!(!mask[1]); // stop token blocked
/// assert!(mask[0]);  // other tokens allowed
///
/// // Once min_len tokens are committed, generation is unconstrained.
/// c.advance(0);
/// c.advance(2);
/// assert!(c.allowed_tokens(&[0, 2], 4).is_none());
/// assert!(!c.is_complete());
///
/// // Committing the stop token completes the constraint.
/// c.advance(1);
/// assert!(c.is_complete());
/// ```
#[derive(Debug, Clone)]
pub struct LengthConstraint {
    /// Minimum number of tokens that must be generated before stopping.
    min_len: usize,
    /// Hard upper bound on generated token count.
    max_len: usize,
    /// Optional end-of-sequence token; treated specially for early-stop control.
    stop_token: Option<u32>,
    /// Number of tokens committed via `advance` so far (the stop token included).
    count: usize,
    /// True once the `stop_token` has been committed.
    stop_seen: bool,
}
44
45impl LengthConstraint {
46    /// Create a new `LengthConstraint`.
47    ///
48    /// `min_len` must be `<= max_len`.
49    pub fn new(min_len: usize, max_len: usize, stop_token: Option<u32>) -> Self {
50        Self {
51            min_len,
52            max_len,
53            stop_token,
54            count: 0,
55            stop_seen: false,
56        }
57    }
58
59    /// Current token count.
60    pub fn count(&self) -> usize {
61        self.count
62    }
63}
64
65impl TokenConstraint for LengthConstraint {
66    fn allowed_tokens(&self, _generated: &[u32], vocab_size: usize) -> Option<Vec<bool>> {
67        if self.count < self.min_len {
68            // Cannot stop early — exclude stop_token if one is configured.
69            if let Some(stop) = self.stop_token {
70                let mut mask = vec![true; vocab_size];
71                let stop_idx = stop as usize;
72                if stop_idx < vocab_size {
73                    mask[stop_idx] = false;
74                }
75                return Some(mask);
76            }
77            // No stop token: no restriction below min_len.
78            return None;
79        }
80
81        if self.count >= self.max_len {
82            // Must stop now.
83            if let Some(stop) = self.stop_token {
84                let mut mask = vec![false; vocab_size];
85                let stop_idx = stop as usize;
86                if stop_idx < vocab_size {
87                    mask[stop_idx] = true;
88                }
89                return Some(mask);
90            }
91            // No stop token: emit an all-false mask to force external termination.
92            return Some(vec![false; vocab_size]);
93        }
94
95        // Between min and max: unconstrained.
96        None
97    }
98
99    /// Commits `token`, updating `count` and `stop_seen`.  Always returns `true`.
100    fn advance(&mut self, token: u32) -> bool {
101        if let Some(stop) = self.stop_token {
102            if token == stop {
103                self.stop_seen = true;
104            }
105        }
106        self.count += 1;
107        true
108    }
109
110    /// Returns `true` when at least `min_len` tokens have been generated AND either
111    /// the `stop_token` was seen or `max_len` has been reached.
112    fn is_complete(&self) -> bool {
113        if self.count < self.min_len {
114            return false;
115        }
116        self.count >= self.max_len || self.stop_seen
117    }
118
119    /// Reset to initial state.
120    fn reset(&mut self) {
121        self.count = 0;
122        self.stop_seen = false;
123    }
124
125    fn name(&self) -> &str {
126        "LengthConstraint"
127    }
128}