oxibonsai_runtime/constrained_decoding/
sampler.rs1use super::error_trait::{NoConstraint, TokenConstraint};
7use super::json::JsonConstraint;
8use super::regex::RegexConstraint;
9use crate::constrained_decoding::ConstraintError;
10
11pub struct ConstrainedSampler {
20 inner: crate::sampling_advanced::SamplerChain,
21 constraint: Box<dyn TokenConstraint>,
22 generated: Vec<u32>,
23 vocab_size: usize,
24}
25
26impl ConstrainedSampler {
27 pub fn new(
29 sampler: crate::sampling_advanced::SamplerChain,
30 constraint: Box<dyn TokenConstraint>,
31 vocab_size: usize,
32 ) -> Self {
33 Self {
34 inner: sampler,
35 constraint,
36 generated: Vec::new(),
37 vocab_size,
38 }
39 }
40
41 pub fn sample(&mut self, logits: &mut Vec<f32>) -> u32 {
50 if let Some(mask) = self
52 .constraint
53 .allowed_tokens(&self.generated, self.vocab_size)
54 {
55 for (i, allowed) in mask.iter().enumerate() {
56 if i < logits.len() && !allowed {
57 logits[i] = -1e9;
58 }
59 }
60 }
61 let token = self.inner.sample(logits) as u32;
62 self.constraint.advance(token);
63 self.generated.push(token);
64 token
65 }
66
67 pub fn is_complete(&self) -> bool {
69 self.constraint.is_complete()
70 }
71
72 pub fn reset(&mut self) {
74 self.generated.clear();
75 self.constraint.reset();
76 }
77
78 pub fn generated_text_len(&self) -> usize {
80 self.generated.len()
81 }
82
83 pub fn constraint_name(&self) -> &str {
85 self.constraint.name()
86 }
87}
88
/// Convenience builder that produces a [`ConstrainedSampler`] wired to one of
/// the constraints available in this module.
pub struct ConstrainedSamplerBuilder {
    /// Vocabulary size passed on to the constructed sampler.
    vocab_size: usize,
    /// Seed handed to `SamplerChain::new` when building the sampler chain.
    seed: u64,
}
98
99impl ConstrainedSamplerBuilder {
100 pub fn new(vocab_size: usize, seed: u64) -> Self {
102 Self { vocab_size, seed }
103 }
104
105 fn default_chain(&self) -> crate::sampling_advanced::SamplerChain {
106 crate::sampling_advanced::SamplerChain::new(self.seed)
107 }
108
109 pub fn with_json_constraint(self) -> ConstrainedSampler {
111 ConstrainedSampler::new(
112 self.default_chain(),
113 Box::new(JsonConstraint::new()),
114 self.vocab_size,
115 )
116 }
117
118 pub fn with_regex_constraint(
120 self,
121 pattern: &str,
122 ) -> Result<ConstrainedSampler, ConstraintError> {
123 let constraint = RegexConstraint::new(pattern)?;
124 let chain = self.default_chain();
125 Ok(ConstrainedSampler::new(
126 chain,
127 Box::new(constraint),
128 self.vocab_size,
129 ))
130 }
131
132 pub fn unconstrained(self) -> ConstrainedSampler {
134 ConstrainedSampler::new(
135 self.default_chain(),
136 Box::new(NoConstraint),
137 self.vocab_size,
138 )
139 }
140}
141
#[cfg(test)]
mod tests {
    use super::*;

    /// Vocabulary size used by the JSON-based tests.
    const VOCAB: usize = 256;

    /// Returns `VOCAB` zeroed logits with a single large peak at the index
    /// equal to the code point of `ch`.
    fn logits_favoring(ch: char) -> Vec<f32> {
        let mut logits = vec![0.0_f32; VOCAB];
        logits[ch as usize] = 100.0;
        logits
    }

    /// Builds a greedy sampler guarded by a fresh `JsonConstraint`.
    fn greedy_json_sampler() -> ConstrainedSampler {
        ConstrainedSampler::new(
            crate::sampling_advanced::SamplerChain::greedy(),
            Box::new(JsonConstraint::new()),
            VOCAB,
        )
    }

    #[test]
    fn constrained_sampler_masks_logits() {
        /// Test constraint that permits only even token ids.
        struct AllowEvens;

        impl TokenConstraint for AllowEvens {
            fn allowed_tokens(&self, _prefix: &[u32], vocab_size: usize) -> Option<Vec<bool>> {
                let mask: Vec<bool> = (0..vocab_size).map(|id| id % 2 == 0).collect();
                Some(mask)
            }
            fn advance(&mut self, _token: u32) -> bool {
                true
            }
            fn is_complete(&self) -> bool {
                true
            }
            fn reset(&mut self) {}
            fn name(&self) -> &str {
                "AllowEvens"
            }
        }

        let greedy = crate::sampling_advanced::SamplerChain::greedy();
        let mut sampler = ConstrainedSampler::new(greedy, Box::new(AllowEvens), 4);

        // Index 1 holds the largest logit but is odd, so it must be masked
        // out; the best remaining even index is 0.
        let mut logits = vec![2.0_f32, 10.0, 1.0, 0.5];
        let picked = sampler.sample(&mut logits);
        assert_eq!(picked, 0);
    }

    #[test]
    fn constrained_sampler_greedy_json() {
        let mut sampler = greedy_json_sampler();
        assert!(!sampler.is_complete());

        let mut open = logits_favoring('{');
        sampler.sample(&mut open);

        let mut close = logits_favoring('}');
        sampler.sample(&mut close);

        assert!(sampler.is_complete());
        assert_eq!(sampler.generated_text_len(), 2);
    }

    #[test]
    fn constrained_sampler_reset() {
        let mut sampler = greedy_json_sampler();

        let mut open = logits_favoring('{');
        sampler.sample(&mut open);
        assert_eq!(sampler.generated_text_len(), 1);

        sampler.reset();
        assert_eq!(sampler.generated_text_len(), 0);
    }

    #[test]
    fn constrained_sampler_builder_json() {
        let builder = ConstrainedSamplerBuilder::new(256, 42);
        let sampler = builder.with_json_constraint();
        assert_eq!(sampler.constraint_name(), "JsonConstraint");
    }

    #[test]
    fn constrained_sampler_builder_unconstrained() {
        let builder = ConstrainedSamplerBuilder::new(256, 42);
        let sampler = builder.unconstrained();
        assert_eq!(sampler.constraint_name(), "NoConstraint");
        assert!(sampler.is_complete());
    }
}