Skip to main content

oxibonsai_runtime/constrained_decoding/
sampler.rs

1//! [`ConstrainedSampler`] and its builder.
2//!
3//! Wraps a [`crate::sampling_advanced::SamplerChain`] with a
4//! [`TokenConstraint`] and applies the mask to logits before sampling.
5
6use super::error_trait::{NoConstraint, TokenConstraint};
7use super::json::JsonConstraint;
8use super::regex::RegexConstraint;
9use crate::constrained_decoding::ConstraintError;
10
11// ─────────────────────────────────────────────────────────────────────────────
12// ConstrainedSampler
13// ─────────────────────────────────────────────────────────────────────────────
14
15/// Wraps a [`crate::sampling_advanced::SamplerChain`] with a [`TokenConstraint`].
16///
17/// Before each sampling step the logits for disallowed tokens are masked to
18/// `-1e9` so they are effectively excluded from the distribution.
19pub struct ConstrainedSampler {
20    inner: crate::sampling_advanced::SamplerChain,
21    constraint: Box<dyn TokenConstraint>,
22    generated: Vec<u32>,
23    vocab_size: usize,
24}
25
26impl ConstrainedSampler {
27    /// Create a new `ConstrainedSampler`.
28    pub fn new(
29        sampler: crate::sampling_advanced::SamplerChain,
30        constraint: Box<dyn TokenConstraint>,
31        vocab_size: usize,
32    ) -> Self {
33        Self {
34            inner: sampler,
35            constraint,
36            generated: Vec::new(),
37            vocab_size,
38        }
39    }
40
41    /// Sample the next token, masking logits for disallowed tokens first.
42    ///
43    /// Steps:
44    /// 1. Query the constraint for an allowed-token mask.
45    /// 2. Set `logits[i] = -1e9` for every `false` entry in the mask.
46    /// 3. Delegate to the inner sampler chain.
47    /// 4. Call `constraint.advance(token)`.
48    /// 5. Track the token in `self.generated`.
49    pub fn sample(&mut self, logits: &mut Vec<f32>) -> u32 {
50        // Apply constraint mask.
51        if let Some(mask) = self
52            .constraint
53            .allowed_tokens(&self.generated, self.vocab_size)
54        {
55            for (i, allowed) in mask.iter().enumerate() {
56                if i < logits.len() && !allowed {
57                    logits[i] = -1e9;
58                }
59            }
60        }
61        let token = self.inner.sample(logits) as u32;
62        self.constraint.advance(token);
63        self.generated.push(token);
64        token
65    }
66
67    /// Returns `true` if the constraint considers the current output complete.
68    pub fn is_complete(&self) -> bool {
69        self.constraint.is_complete()
70    }
71
72    /// Reset both the inner sampler state and the constraint.
73    pub fn reset(&mut self) {
74        self.generated.clear();
75        self.constraint.reset();
76    }
77
78    /// Number of tokens generated so far.
79    pub fn generated_text_len(&self) -> usize {
80        self.generated.len()
81    }
82
83    /// Human-readable name of the active constraint.
84    pub fn constraint_name(&self) -> &str {
85        self.constraint.name()
86    }
87}
88
89// ─────────────────────────────────────────────────────────────────────────────
90// ConstrainedSamplerBuilder
91// ─────────────────────────────────────────────────────────────────────────────
92
93/// Ergonomic builder for [`ConstrainedSampler`].
94pub struct ConstrainedSamplerBuilder {
95    vocab_size: usize,
96    seed: u64,
97}
98
99impl ConstrainedSamplerBuilder {
100    /// Create a new builder.
101    pub fn new(vocab_size: usize, seed: u64) -> Self {
102        Self { vocab_size, seed }
103    }
104
105    fn default_chain(&self) -> crate::sampling_advanced::SamplerChain {
106        crate::sampling_advanced::SamplerChain::new(self.seed)
107    }
108
109    /// Build a `ConstrainedSampler` with a `JsonConstraint`.
110    pub fn with_json_constraint(self) -> ConstrainedSampler {
111        ConstrainedSampler::new(
112            self.default_chain(),
113            Box::new(JsonConstraint::new()),
114            self.vocab_size,
115        )
116    }
117
118    /// Build a `ConstrainedSampler` with a `RegexConstraint`.
119    pub fn with_regex_constraint(
120        self,
121        pattern: &str,
122    ) -> Result<ConstrainedSampler, ConstraintError> {
123        let constraint = RegexConstraint::new(pattern)?;
124        let chain = self.default_chain();
125        Ok(ConstrainedSampler::new(
126            chain,
127            Box::new(constraint),
128            self.vocab_size,
129        ))
130    }
131
132    /// Build an unconstrained `ConstrainedSampler` (passthrough).
133    pub fn unconstrained(self) -> ConstrainedSampler {
134        ConstrainedSampler::new(
135            self.default_chain(),
136            Box::new(NoConstraint),
137            self.vocab_size,
138        )
139    }
140}
141
#[cfg(test)]
mod tests {
    use super::*;
    use crate::sampling_advanced::SamplerChain;

    /// Byte-level vocabulary size used by the JSON-oriented tests.
    const BYTE_VOCAB: usize = 256;

    /// Build a logit vector of `len` zeros with a single `100.0` spike at `token`.
    fn spiked_logits(len: usize, token: char) -> Vec<f32> {
        let mut logits = vec![0.0_f32; len];
        logits[token as usize] = 100.0;
        logits
    }

    /// Test-only constraint that permits even token ids and nothing else.
    struct EvenTokensOnly;

    impl TokenConstraint for EvenTokensOnly {
        fn allowed_tokens(&self, _history: &[u32], vocab_size: usize) -> Option<Vec<bool>> {
            Some((0..vocab_size).map(|id| id % 2 == 0).collect())
        }

        fn advance(&mut self, _token: u32) -> bool {
            true
        }

        fn is_complete(&self) -> bool {
            true
        }

        fn reset(&mut self) {}

        fn name(&self) -> &str {
            "AllowEvens"
        }
    }

    #[test]
    fn constrained_sampler_masks_logits() {
        let mut sampler =
            ConstrainedSampler::new(SamplerChain::greedy(), Box::new(EvenTokensOnly), 4);
        // Token 1 has the largest logit but is odd, so it gets masked out;
        // token 0 is then the best remaining allowed candidate.
        let mut logits = vec![2.0_f32, 10.0, 1.0, 0.5];
        assert_eq!(sampler.sample(&mut logits), 0);
    }

    #[test]
    fn constrained_sampler_greedy_json() {
        let mut sampler = ConstrainedSampler::new(
            SamplerChain::greedy(),
            Box::new(JsonConstraint::new()),
            BYTE_VOCAB,
        );
        assert!(!sampler.is_complete());

        // Emit the empty object "{}" one byte at a time.
        for byte in ['{', '}'] {
            let mut logits = spiked_logits(BYTE_VOCAB, byte);
            sampler.sample(&mut logits);
        }

        assert!(sampler.is_complete());
        assert_eq!(sampler.generated_text_len(), 2);
    }

    #[test]
    fn constrained_sampler_reset() {
        let mut sampler = ConstrainedSampler::new(
            SamplerChain::greedy(),
            Box::new(JsonConstraint::new()),
            BYTE_VOCAB,
        );
        let mut logits = spiked_logits(BYTE_VOCAB, '{');
        sampler.sample(&mut logits);
        assert_eq!(sampler.generated_text_len(), 1);

        sampler.reset();
        assert_eq!(sampler.generated_text_len(), 0);
    }

    #[test]
    fn constrained_sampler_builder_json() {
        let builder = ConstrainedSamplerBuilder::new(BYTE_VOCAB, 42);
        let sampler = builder.with_json_constraint();
        assert_eq!(sampler.constraint_name(), "JsonConstraint");
    }

    #[test]
    fn constrained_sampler_builder_unconstrained() {
        let builder = ConstrainedSamplerBuilder::new(BYTE_VOCAB, 42);
        let sampler = builder.unconstrained();
        assert_eq!(sampler.constraint_name(), "NoConstraint");
        assert!(sampler.is_complete());
    }
}