oxibonsai_runtime/constrained_decoding/
json.rs1use super::error_trait::TokenConstraint;
7
/// Position of a [`JsonConstraint`] within the JSON grammar.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum JsonParseState {
    /// Nothing consumed yet; leading whitespace is skipped.
    Start,
    /// Inside an object, between members (see `expecting_comma_or_close`).
    InObject,
    /// Inside an object key string.
    InObjectKey,
    /// After a key's closing quote, waiting for `:`.
    AfterKey,
    /// After `:`, waiting for the member's value.
    InObjectValue,
    /// Inside an array, between elements.
    InArray,
    /// Never entered by this parser; handled exactly like `InArray`.
    InArrayValue,
    /// Inside a string value.
    InString,
    /// Directly after a backslash inside a string value or an object key.
    InStringEscape,
    /// Accumulating the characters of a number.
    InNumber,
    /// Accumulating `true` or `false`.
    InBool,
    /// Accumulating `null`.
    InNull,
    /// A complete top-level value was consumed; only whitespace may follow.
    Complete,
    /// The input violated the JSON grammar; no further input is accepted.
    Error,
}

/// Incremental, character-at-a-time JSON validator used to constrain decoding.
///
/// [`JsonConstraint::valid_next_chars`] lists the ASCII characters the parser accepts next.
/// Once a whole top-level value has been read the state becomes [`JsonParseState::Complete`];
/// any grammar violation moves it to [`JsonParseState::Error`], where it remains.
///
/// A top-level number has no closing delimiter, so it only becomes `Complete` once a
/// following whitespace character is consumed.
pub struct JsonConstraint {
    /// Current position in the grammar.
    state: JsonParseState,
    /// Number of currently open objects and arrays.
    depth: usize,
    /// Characters fed before the parser reached `Complete` or `Error`.
    /// NOTE(review): not read anywhere in this file.
    buffer: String,
    /// Inside a container: a member just ended, so only `,` or the closing bracket may follow.
    expecting_comma_or_close: bool,
    /// Partial number / `true` / `false` / `null` literal being accumulated.
    keyword_buf: String,
    /// Open containers, innermost last: `'o'` for objects, `'a'` for arrays.
    context_stack: Vec<char>,
    /// A `,` was just consumed, so the container may not be closed before the next member.
    after_comma: bool,
    /// The string being scanned is an object key (escapes must return to `InObjectKey`).
    in_key: bool,
    /// Hex digits still required by a pending `\uXXXX` escape.
    unicode_digits_left: u8,
}

impl JsonConstraint {
    /// Insignificant whitespace permitted between JSON tokens (RFC 8259 §2).
    const WHITESPACE: [char; 4] = [' ', '\t', '\n', '\r'];
    /// Every character that can begin a JSON value.
    const VALUE_START: [char; 17] = [
        '{', '[', '"', '-', '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 't', 'f', 'n',
    ];
    /// Characters that may appear inside a number literal.
    const NUMBER_CHARS: [char; 15] = [
        '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '.', 'e', 'E', '+', '-',
    ];
    /// Characters allowed directly after a backslash in a string.
    const ESCAPE_CHARS: [char; 9] = ['"', '\\', '/', 'b', 'f', 'n', 'r', 't', 'u'];
    /// Literals reachable from `InBool`.
    const BOOL_LITERALS: [&'static str; 2] = ["true", "false"];
    /// Literal reachable from `InNull`.
    const NULL_LITERAL: [&'static str; 1] = ["null"];

    /// Creates a parser positioned before the first character of a JSON document.
    pub fn new() -> Self {
        Self {
            state: JsonParseState::Start,
            depth: 0,
            buffer: String::new(),
            expecting_comma_or_close: false,
            keyword_buf: String::new(),
            context_stack: Vec::new(),
            after_comma: false,
            in_key: false,
            unicode_digits_left: 0,
        }
    }

    /// Returns the current grammar state.
    pub fn current_state(&self) -> &JsonParseState {
        &self.state
    }

    /// Returns the number of currently open objects and arrays.
    pub fn depth(&self) -> usize {
        self.depth
    }

    /// Returns `true` inside a string value, or directly after a backslash in either a
    /// string value or an object key. The body of an object key reports `false`.
    pub fn is_in_string(&self) -> bool {
        matches!(
            self.state,
            JsonParseState::InString | JsonParseState::InStringEscape
        )
    }

    /// Returns the ASCII characters the parser will accept next.
    ///
    /// Inside string bodies the parser additionally accepts any non-ASCII character, which
    /// cannot be enumerated here. The result is empty in [`JsonParseState::Error`].
    pub fn valid_next_chars(&self) -> Vec<char> {
        match self.state {
            JsonParseState::Start | JsonParseState::InObjectValue => {
                Self::with_whitespace(&Self::VALUE_START)
            }
            JsonParseState::InObject => {
                if self.expecting_comma_or_close {
                    Self::with_whitespace(&[',', '}'])
                } else if self.after_comma {
                    // `{"a":1,}` is not JSON: another key must follow the comma.
                    Self::with_whitespace(&['"'])
                } else {
                    Self::with_whitespace(&['"', '}'])
                }
            }
            JsonParseState::InArray | JsonParseState::InArrayValue => {
                if self.expecting_comma_or_close {
                    Self::with_whitespace(&[',', ']'])
                } else {
                    let mut chars = Self::with_whitespace(&Self::VALUE_START);
                    // An empty array may be closed immediately, but not right after a comma.
                    if !self.after_comma {
                        chars.push(']');
                    }
                    chars
                }
            }
            JsonParseState::AfterKey => Self::with_whitespace(&[':']),
            JsonParseState::InObjectKey | JsonParseState::InString => {
                if self.unicode_digits_left > 0 {
                    ('0'..='9').chain('a'..='f').chain('A'..='F').collect()
                } else {
                    // Printable ASCII, which already contains the closing `"` and the `\`.
                    (' '..='~').collect()
                }
            }
            JsonParseState::InStringEscape => Self::ESCAPE_CHARS.to_vec(),
            JsonParseState::InNumber => self.number_continuations(),
            JsonParseState::InBool => self.keyword_continuations(&Self::BOOL_LITERALS),
            JsonParseState::InNull => self.keyword_continuations(&Self::NULL_LITERAL),
            JsonParseState::Complete => Self::WHITESPACE.to_vec(),
            JsonParseState::Error => Vec::new(),
        }
    }

    /// Appends the JSON whitespace characters to `chars`.
    fn with_whitespace(chars: &[char]) -> Vec<char> {
        chars.iter().chain(Self::WHITESPACE.iter()).copied().collect()
    }

    /// Returns `true` for the whitespace characters JSON allows between tokens.
    fn is_json_whitespace(ch: char) -> bool {
        Self::WHITESPACE.contains(&ch)
    }

    /// Characters that may follow the number accumulated so far in `keyword_buf`.
    fn number_continuations(&self) -> Vec<char> {
        let mut chars: Vec<char> = Self::NUMBER_CHARS
            .iter()
            .copied()
            .filter(|&c| Self::scan_number(&format!("{}{}", self.keyword_buf, c)).is_some())
            .collect();
        if Self::scan_number(&self.keyword_buf) == Some(true) {
            // The number may end here; the terminator is consumed by the enclosing context.
            chars.extend_from_slice(&Self::WHITESPACE);
            match self.context_stack.last().copied() {
                Some('o') => chars.extend_from_slice(&[',', '}']),
                Some('a') => chars.extend_from_slice(&[',', ']']),
                _ => {}
            }
        }
        chars
    }

    /// Next characters of whichever `literals` extend the partial literal in `keyword_buf`.
    fn keyword_continuations(&self, literals: &[&str]) -> Vec<char> {
        literals
            .iter()
            .filter_map(|lit| lit.strip_prefix(self.keyword_buf.as_str()))
            .filter_map(|rest| rest.chars().next())
            .collect()
    }

    /// Runs the JSON number grammar (RFC 8259 §6) over `text`.
    ///
    /// Returns `None` if `text` cannot be extended into a valid number, `Some(true)` if it
    /// already is one, and `Some(false)` if it is a proper prefix of one.
    fn scan_number(text: &str) -> Option<bool> {
        #[derive(Clone, Copy)]
        enum N {
            Begin,
            Minus,
            Zero,
            Int,
            Dot,
            Frac,
            Exp,
            ExpSign,
            ExpDigits,
        }
        let mut st = N::Begin;
        for c in text.chars() {
            st = match (st, c) {
                (N::Begin, '-') => N::Minus,
                (N::Begin | N::Minus, '0') => N::Zero,
                (N::Begin | N::Minus, '1'..='9') => N::Int,
                (N::Int, '0'..='9') => N::Int,
                (N::Zero | N::Int, '.') => N::Dot,
                (N::Dot | N::Frac, '0'..='9') => N::Frac,
                (N::Zero | N::Int | N::Frac, 'e' | 'E') => N::Exp,
                (N::Exp, '+' | '-') => N::ExpSign,
                (N::Exp | N::ExpSign | N::ExpDigits, '0'..='9') => N::ExpDigits,
                _ => return None,
            };
        }
        Some(matches!(st, N::Zero | N::Int | N::Frac | N::ExpDigits))
    }

    /// Consumes one character, updating the parse state.
    ///
    /// Characters fed after `Complete`/`Error` are not recorded in `buffer`; a
    /// non-whitespace character after `Complete` moves the parser to `Error`.
    fn feed_char(&mut self, ch: char) {
        if !matches!(self.state, JsonParseState::Complete | JsonParseState::Error) {
            self.buffer.push(ch);
        }
        self.step(ch);
    }

    /// State-transition core of `feed_char`. It re-enters itself once when a number is
    /// ended by a character that belongs to the enclosing context.
    fn step(&mut self, ch: char) {
        match self.state {
            JsonParseState::Error => {}
            JsonParseState::Complete => {
                if !Self::is_json_whitespace(ch) {
                    self.state = JsonParseState::Error;
                }
            }
            JsonParseState::Start | JsonParseState::InObjectValue => {
                if !Self::is_json_whitespace(ch) {
                    self.start_value(ch);
                }
            }
            JsonParseState::AfterKey => {
                if ch == ':' {
                    self.state = JsonParseState::InObjectValue;
                    self.expecting_comma_or_close = false;
                } else if !Self::is_json_whitespace(ch) {
                    self.state = JsonParseState::Error;
                }
            }
            JsonParseState::InObject => self.step_object(ch),
            JsonParseState::InArray | JsonParseState::InArrayValue => self.step_array(ch),
            JsonParseState::InObjectKey | JsonParseState::InString => self.step_string(ch),
            JsonParseState::InStringEscape => self.step_escape(ch),
            JsonParseState::InNumber => self.step_number(ch),
            JsonParseState::InBool => self.step_keyword(ch, &Self::BOOL_LITERALS),
            JsonParseState::InNull => self.step_keyword(ch, &Self::NULL_LITERAL),
        }
    }

    /// Handles a character between object members.
    fn step_object(&mut self, ch: char) {
        if Self::is_json_whitespace(ch) {
            return;
        }
        match (self.expecting_comma_or_close, ch) {
            (true, ',') => {
                self.expecting_comma_or_close = false;
                self.after_comma = true;
            }
            (true, '}') => self.close_context(),
            (false, '"') => {
                self.after_comma = false;
                self.in_key = true;
                self.state = JsonParseState::InObjectKey;
            }
            // `{}` is fine, `{"a":1,}` is not.
            (false, '}') if !self.after_comma => self.close_context(),
            _ => self.state = JsonParseState::Error,
        }
    }

    /// Handles a character between array elements.
    fn step_array(&mut self, ch: char) {
        if Self::is_json_whitespace(ch) {
            return;
        }
        match (self.expecting_comma_or_close, ch) {
            (true, ',') => {
                self.expecting_comma_or_close = false;
                self.after_comma = true;
            }
            (true, ']') => self.close_context(),
            (true, _) => self.state = JsonParseState::Error,
            // `[]` is fine, `[1,]` is not.
            (false, ']') if !self.after_comma => self.close_context(),
            (false, _) => self.start_value(ch),
        }
    }

    /// Handles a character inside a string value or object key.
    fn step_string(&mut self, ch: char) {
        if self.unicode_digits_left > 0 {
            if ch.is_ascii_hexdigit() {
                self.unicode_digits_left -= 1;
            } else {
                self.state = JsonParseState::Error;
            }
            return;
        }
        match ch {
            '"' if self.in_key => {
                self.in_key = false;
                self.state = JsonParseState::AfterKey;
            }
            '"' => self.finish_value(),
            '\\' => self.state = JsonParseState::InStringEscape,
            // RFC 8259 §7: control characters must be escaped inside strings.
            c if u32::from(c) < 0x20 => self.state = JsonParseState::Error,
            _ => {}
        }
    }

    /// Validates the character after a backslash and returns to the enclosing string.
    fn step_escape(&mut self, ch: char) {
        self.state = if self.in_key {
            JsonParseState::InObjectKey
        } else {
            JsonParseState::InString
        };
        match ch {
            '"' | '\\' | '/' | 'b' | 'f' | 'n' | 'r' | 't' => {}
            'u' => self.unicode_digits_left = 4,
            _ => self.state = JsonParseState::Error,
        }
    }

    /// Extends the current number, or ends it and re-dispatches the terminator.
    fn step_number(&mut self, ch: char) {
        if Self::NUMBER_CHARS.contains(&ch) {
            self.keyword_buf.push(ch);
            if Self::scan_number(&self.keyword_buf).is_none() {
                self.state = JsonParseState::Error;
            }
        } else if Self::scan_number(&self.keyword_buf) == Some(true) {
            self.keyword_buf.clear();
            self.finish_value();
            // The terminator (`,`, `}`, `]`, whitespace...) belongs to the enclosing context.
            self.step(ch);
        } else {
            self.state = JsonParseState::Error;
        }
    }

    /// Extends a `true`/`false`/`null` literal, finishing it once fully spelled out.
    fn step_keyword(&mut self, ch: char, literals: &[&str]) {
        self.keyword_buf.push(ch);
        if literals.iter().any(|lit| self.keyword_buf == *lit) {
            self.keyword_buf.clear();
            self.finish_value();
        } else if !literals
            .iter()
            .any(|lit| lit.starts_with(self.keyword_buf.as_str()))
        {
            self.state = JsonParseState::Error;
        }
    }

    /// Begins a new value whose first character is `ch`.
    fn start_value(&mut self, ch: char) {
        // Any value start satisfies a preceding comma.
        self.after_comma = false;
        match ch {
            '{' => self.open_container('o', JsonParseState::InObject),
            '[' => self.open_container('a', JsonParseState::InArray),
            '"' => {
                self.in_key = false;
                self.state = JsonParseState::InString;
            }
            '-' | '0'..='9' => self.begin_literal(JsonParseState::InNumber, ch),
            't' | 'f' => self.begin_literal(JsonParseState::InBool, ch),
            'n' => self.begin_literal(JsonParseState::InNull, ch),
            _ => self.state = JsonParseState::Error,
        }
    }

    /// Pushes a container marker and enters its "between members" state.
    fn open_container(&mut self, marker: char, state: JsonParseState) {
        self.depth += 1;
        self.context_stack.push(marker);
        self.state = state;
        self.expecting_comma_or_close = false;
    }

    /// Enters a literal-accumulating state, seeding `keyword_buf` with `first`.
    fn begin_literal(&mut self, state: JsonParseState, first: char) {
        self.state = state;
        self.keyword_buf.clear();
        self.keyword_buf.push(first);
    }

    /// Marks the current value as finished and returns to the enclosing context.
    fn finish_value(&mut self) {
        self.expecting_comma_or_close = true;
        self.state = match self.context_stack.last().copied() {
            Some('o') => JsonParseState::InObject,
            Some('a') => JsonParseState::InArray,
            None => JsonParseState::Complete,
            Some(_) => JsonParseState::Error,
        };
    }

    /// Closes the innermost container; the container itself then counts as a finished value.
    fn close_context(&mut self) {
        if let Some('o' | 'a') = self.context_stack.pop() {
            self.depth = self.depth.saturating_sub(1);
        }
        self.after_comma = false;
        self.finish_value();
    }
}
493
494impl Default for JsonConstraint {
495 fn default() -> Self {
496 Self::new()
497 }
498}
499
500impl TokenConstraint for JsonConstraint {
501 fn allowed_tokens(&self, _generated: &[u32], vocab_size: usize) -> Option<Vec<bool>> {
502 if self.state == JsonParseState::Error {
503 return Some(vec![false; vocab_size]);
504 }
505 let valid = self.valid_next_chars();
508 let mask: Vec<bool> = (0..vocab_size)
509 .map(|id| {
510 let ch = char::from_u32(id as u32).unwrap_or('\u{FFFD}');
512 ch as u32 > 127 || valid.contains(&ch)
515 })
516 .collect();
517 Some(mask)
518 }
519
520 fn advance(&mut self, token: u32) -> bool {
521 if self.state == JsonParseState::Error {
522 return false;
523 }
524 if let Some(ch) = char::from_u32(token) {
526 self.feed_char(ch);
527 }
528 self.state != JsonParseState::Error
529 }
530
531 fn is_complete(&self) -> bool {
532 self.state == JsonParseState::Complete
533 }
534
535 fn reset(&mut self) {
536 *self = Self::new();
537 }
538
539 fn name(&self) -> &str {
540 "JsonConstraint"
541 }
542}
543
#[cfg(test)]
mod tests {
    use super::*;

    /// Pushes every character of `text` through `TokenConstraint::advance`. The per-step
    /// result is ignored; the assertions inspect the resulting state instead.
    fn feed(jc: &mut JsonConstraint, text: &str) {
        text.chars().for_each(|ch| {
            jc.advance(u32::from(ch));
        });
    }

    #[test]
    fn json_constraint_initial_state() {
        let fresh = JsonConstraint::new();
        assert_eq!(fresh.current_state(), &JsonParseState::Start);
        assert_eq!(fresh.depth(), 0);
    }

    #[test]
    fn json_constraint_valid_object_chars() {
        let allowed = JsonConstraint::new().valid_next_chars();
        for expected in &['{', '[', '"'] {
            assert!(allowed.contains(expected));
        }
    }

    #[test]
    fn json_constraint_tracks_depth() {
        let mut jc = JsonConstraint::new();
        feed(&mut jc, "{");
        assert_eq!(jc.depth(), 1);
        feed(&mut jc, "\"k\":{");
        assert_eq!(jc.depth(), 2);
        feed(&mut jc, "}");
        assert_eq!(jc.depth(), 1);
    }

    #[test]
    fn json_constraint_detects_completion() {
        let mut jc = JsonConstraint::new();
        assert!(!jc.is_complete());
        feed(&mut jc, "{}");
        assert!(jc.is_complete());
    }

    #[test]
    fn json_constraint_in_string_state() {
        let mut jc = JsonConstraint::new();
        feed(&mut jc, "\"");
        assert!(jc.is_in_string());
        feed(&mut jc, "\"");
        assert!(!jc.is_in_string());
    }
}