Skip to main content

oxibonsai_runtime/constrained_decoding/
json.rs

1//! JSON-grammar [`TokenConstraint`] implementation.
2//!
3//! Hosts [`JsonParseState`] and the [`JsonConstraint`] state machine that
4//! restricts generation to syntactically valid JSON.
5
6use super::error_trait::TokenConstraint;
7
8// ─────────────────────────────────────────────────────────────────────────────
9// JsonConstraint
10// ─────────────────────────────────────────────────────────────────────────────
11
/// Internal parser state for `JsonConstraint`.
///
/// The state is deliberately coarse. Finer detail is tracked privately by
/// [`JsonConstraint`]: escape progress, position inside a number literal, the
/// keyword being spelled, and the stack of open containers.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum JsonParseState {
    /// Before any character has been emitted.
    Start,
    /// Inside a JSON object `{`, waiting for a key, `,` or `}`.
    InObject,
    /// Inside a string that is an object key (including any escape sequence in it).
    InObjectKey,
    /// After an object key, expecting `:`.
    AfterKey,
    /// After `:`, waiting for a value.
    InObjectValue,
    /// Inside a JSON array `[`, waiting for a value, `,` or `]`.
    InArray,
    /// After a value inside an array, waiting for `,` or `]`.
    ///
    /// Never produced by [`JsonConstraint`] itself; it is handled only for
    /// compatibility.
    InArrayValue,
    /// Inside a string value (object keys use `InObjectKey`).
    InString,
    /// Inside an escape sequence of a string value: after `\`, including the
    /// four hex digits of `\uXXXX`.
    InStringEscape,
    /// Inside a number literal.
    InNumber,
    /// Inside a boolean keyword (`true` / `false`).
    InBool,
    /// Inside `null`.
    InNull,
    /// Top-level value is complete.
    Complete,
    /// An error has been encountered.
    Error,
}

/// JSON insignificant whitespace (RFC 8259 §2). Other Unicode whitespace is not allowed.
const JSON_WHITESPACE: &[char] = &[' ', '\t', '\n', '\r'];

/// Characters that may begin a JSON value.
const VALUE_START: &[char] = &[
    '{', '[', '"', '-', '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 't', 'f', 'n',
];

/// Every character that can occur inside a number literal.
const NUMBER_CHARS: &[char] = &[
    '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '.', 'e', 'E', '+', '-',
];

/// Characters allowed directly after `\` inside a string.
const ESCAPE_CHARS: &[char] = &['"', '\\', '/', 'b', 'f', 'n', 'r', 't', 'u'];

/// Returns `true` for the four whitespace characters JSON permits between tokens.
fn is_json_whitespace(ch: char) -> bool {
    matches!(ch, ' ' | '\t' | '\n' | '\r')
}

/// Returns `chars` followed by the JSON whitespace characters.
fn with_whitespace(chars: &[char]) -> Vec<char> {
    chars.iter().chain(JSON_WHITESPACE).copied().collect()
}

/// Kind of container that is currently open.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum Container {
    Object,
    Array,
}

/// Progress through the body of a string (key or value).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum StringPhase {
    /// Ordinary characters; `"` closes the string.
    Plain,
    /// Directly after `\`.
    Escape,
    /// Inside `\uXXXX`; the payload is how many hex digits are still required.
    Hex(u8),
}

/// Position inside a number literal, following the RFC 8259 grammar
/// `-? (0 | [1-9][0-9]*) (\.[0-9]+)? ([eE][+-]?[0-9]+)?`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum NumberPhase {
    /// After a leading `-`; a digit must follow.
    Sign,
    /// After a leading `0`; no further integer digits are allowed.
    LeadingZero,
    /// Inside integer digits that started with `1`-`9`.
    Integer,
    /// After `.`; a digit must follow.
    Dot,
    /// Inside fraction digits.
    Fraction,
    /// After `e`/`E`; a sign or digit must follow.
    Exponent,
    /// After the exponent sign; a digit must follow.
    ExponentSign,
    /// Inside exponent digits.
    ExponentDigits,
}

impl NumberPhase {
    /// Phase reached by appending `ch`, or `None` if `ch` cannot extend the literal.
    fn next(self, ch: char) -> Option<Self> {
        match (self, ch) {
            (Self::Sign, '0') => Some(Self::LeadingZero),
            (Self::Sign, '1'..='9') | (Self::Integer, '0'..='9') => Some(Self::Integer),
            (Self::LeadingZero | Self::Integer, '.') => Some(Self::Dot),
            (Self::Dot | Self::Fraction, '0'..='9') => Some(Self::Fraction),
            (Self::LeadingZero | Self::Integer | Self::Fraction, 'e' | 'E') => {
                Some(Self::Exponent)
            }
            (Self::Exponent, '+' | '-') => Some(Self::ExponentSign),
            (Self::Exponent | Self::ExponentSign | Self::ExponentDigits, '0'..='9') => {
                Some(Self::ExponentDigits)
            }
            _ => None,
        }
    }

    /// Whether the characters read so far already form a complete number.
    fn is_terminal(self) -> bool {
        matches!(
            self,
            Self::LeadingZero | Self::Integer | Self::Fraction | Self::ExponentDigits
        )
    }
}

/// Constrains generation to syntactically valid JSON.
///
/// Tracks nesting depth and parse state character by character, following the
/// RFC 8259 grammar:
/// - no trailing commas;
/// - no unescaped control characters in strings;
/// - only the standard escapes, with exactly four hex digits after `\u`;
/// - numbers without leading zeros.
///
/// A top-level number has no terminator of its own. The parser therefore only
/// reaches [`JsonParseState::Complete`] for it once whitespace follows.
#[derive(Debug, Clone)]
pub struct JsonConstraint {
    state: JsonParseState,
    /// Open containers, innermost last; its length is the nesting depth.
    stack: Vec<Container>,
    /// A value (or key/value pair) was just completed inside the innermost
    /// container, so only `,` or its closing bracket may follow.
    expecting_comma_or_close: bool,
    /// A `,` was just consumed, so the closing bracket may not follow
    /// (JSON forbids trailing commas).
    after_comma: bool,
    /// Escape progress of the string currently being read.
    string_phase: StringPhase,
    /// Progress through the number literal currently being read.
    number_phase: NumberPhase,
    /// Keyword (`true` / `false` / `null`) currently being spelled.
    keyword: &'static str,
    /// Number of characters of `keyword` consumed so far.
    keyword_pos: usize,
}

impl JsonConstraint {
    /// Create a new `JsonConstraint` in its initial state.
    pub fn new() -> Self {
        Self {
            state: JsonParseState::Start,
            stack: Vec::new(),
            expecting_comma_or_close: false,
            after_comma: false,
            string_phase: StringPhase::Plain,
            // Placeholder; overwritten when a number literal starts.
            number_phase: NumberPhase::Integer,
            keyword: "",
            keyword_pos: 0,
        }
    }

    /// Current parse state.
    pub fn current_state(&self) -> &JsonParseState {
        &self.state
    }

    /// Current nesting depth (number of unclosed `{` / `[`).
    pub fn depth(&self) -> usize {
        self.stack.len()
    }

    /// Returns `true` if we are currently inside a string value.
    ///
    /// Object keys are reported as [`JsonParseState::InObjectKey`] instead.
    pub fn is_in_string(&self) -> bool {
        matches!(
            self.state,
            JsonParseState::InString | JsonParseState::InStringEscape
        )
    }

    /// Returns the set of ASCII characters that are valid as the *next* character
    /// given the current parse state.
    ///
    /// Feeding any returned character never moves the parser into
    /// [`JsonParseState::Error`], so the set can be used directly as a mask.
    pub fn valid_next_chars(&self) -> Vec<char> {
        match self.state {
            JsonParseState::Start | JsonParseState::InObjectValue => {
                with_whitespace(VALUE_START)
            }
            JsonParseState::InObject => {
                if self.expecting_comma_or_close {
                    with_whitespace(&[',', '}'])
                } else if self.after_comma {
                    with_whitespace(&['"'])
                } else {
                    with_whitespace(&['"', '}'])
                }
            }
            JsonParseState::InArray => {
                if self.expecting_comma_or_close {
                    with_whitespace(&[',', ']'])
                } else {
                    let mut chars = with_whitespace(VALUE_START);
                    // `[]` is valid; `[1,]` is not.
                    if !self.after_comma {
                        chars.push(']');
                    }
                    chars
                }
            }
            JsonParseState::InArrayValue => {
                match (self.expecting_comma_or_close, self.stack.last()) {
                    (false, _) => with_whitespace(VALUE_START),
                    (true, Some(Container::Object)) => with_whitespace(&[',', '}']),
                    (true, Some(Container::Array)) => with_whitespace(&[',', ']']),
                    (true, None) => JSON_WHITESPACE.to_vec(),
                }
            }
            JsonParseState::AfterKey => with_whitespace(&[':']),
            JsonParseState::InObjectKey
            | JsonParseState::InString
            | JsonParseState::InStringEscape => match self.string_phase {
                // Printable ASCII; `"` closes the string and `\` starts an escape.
                StringPhase::Plain => (' '..='~').collect(),
                StringPhase::Escape => ESCAPE_CHARS.to_vec(),
                StringPhase::Hex(_) => ('0'..='9').chain('a'..='f').chain('A'..='F').collect(),
            },
            JsonParseState::InNumber => {
                let mut chars: Vec<char> = NUMBER_CHARS
                    .iter()
                    .copied()
                    .filter(|&c| self.number_phase.next(c).is_some())
                    .collect();
                if self.number_phase.is_terminal() {
                    // The literal may end here; what follows is read by the enclosing context.
                    chars.extend_from_slice(JSON_WHITESPACE);
                    match self.stack.last() {
                        Some(Container::Object) => chars.extend_from_slice(&[',', '}']),
                        Some(Container::Array) => chars.extend_from_slice(&[',', ']']),
                        None => {}
                    }
                }
                chars
            }
            JsonParseState::InBool | JsonParseState::InNull => {
                // Only the next letter of the keyword being spelled.
                self.next_keyword_char().into_iter().collect()
            }
            JsonParseState::Complete => JSON_WHITESPACE.to_vec(),
            JsonParseState::Error => Vec::new(),
        }
    }

    /// Feed a single character through the state machine.
    fn feed_char(&mut self, ch: char) {
        match self.state {
            JsonParseState::Error => {}
            JsonParseState::Complete => {
                // Only trailing whitespace may follow a complete top-level value.
                if !is_json_whitespace(ch) {
                    self.state = JsonParseState::Error;
                }
            }
            JsonParseState::Start | JsonParseState::InObjectValue => {
                if !is_json_whitespace(ch) {
                    self.start_value(ch);
                }
            }
            JsonParseState::InObject => self.feed_object(ch),
            JsonParseState::InArray => self.feed_array(ch),
            JsonParseState::InArrayValue => {
                if is_json_whitespace(ch) {
                    return;
                }
                if !self.expecting_comma_or_close {
                    self.start_value(ch);
                    return;
                }
                // Re-dispatch through the enclosing container's own state.
                self.state = match self.stack.last() {
                    Some(Container::Object) => JsonParseState::InObject,
                    Some(Container::Array) => JsonParseState::InArray,
                    None => JsonParseState::Error,
                };
                self.feed_char(ch);
            }
            JsonParseState::AfterKey => {
                if ch == ':' {
                    self.state = JsonParseState::InObjectValue;
                    self.expecting_comma_or_close = false;
                } else if !is_json_whitespace(ch) {
                    self.state = JsonParseState::Error;
                }
            }
            JsonParseState::InObjectKey => self.feed_string(ch, true),
            JsonParseState::InString | JsonParseState::InStringEscape => {
                self.feed_string(ch, false)
            }
            JsonParseState::InNumber => {
                if let Some(next) = self.number_phase.next(ch) {
                    self.number_phase = next;
                } else if self.number_phase.is_terminal() {
                    // The literal ended; `ch` belongs to whatever follows the value.
                    self.finish_value();
                    self.feed_char(ch);
                } else {
                    self.state = JsonParseState::Error;
                }
            }
            JsonParseState::InBool | JsonParseState::InNull => {
                if self.next_keyword_char() == Some(ch) {
                    self.keyword_pos += 1;
                    if self.keyword_pos == self.keyword.len() {
                        self.finish_value();
                    }
                } else {
                    self.state = JsonParseState::Error;
                }
            }
        }
    }

    /// Handle `ch` directly inside an object, between members.
    fn feed_object(&mut self, ch: char) {
        if is_json_whitespace(ch) {
            return;
        }
        match ch {
            ',' if self.expecting_comma_or_close => {
                self.expecting_comma_or_close = false;
                self.after_comma = true;
            }
            // Closing right after `{` or after a member, but never after a comma.
            '}' if !self.after_comma => self.close_container(),
            '"' if !self.expecting_comma_or_close => {
                self.after_comma = false;
                self.string_phase = StringPhase::Plain;
                self.state = JsonParseState::InObjectKey;
            }
            _ => self.state = JsonParseState::Error,
        }
    }

    /// Handle `ch` directly inside an array, between elements.
    fn feed_array(&mut self, ch: char) {
        if is_json_whitespace(ch) {
            return;
        }
        if self.expecting_comma_or_close {
            match ch {
                ',' => {
                    self.expecting_comma_or_close = false;
                    self.after_comma = true;
                }
                ']' => self.close_container(),
                _ => self.state = JsonParseState::Error,
            }
        } else if ch == ']' && !self.after_comma {
            self.close_container();
        } else {
            // A `]` right after a comma reaches here and is rejected by `start_value`.
            self.start_value(ch);
        }
    }

    /// Handle `ch` inside a string body; `is_key` selects key vs value handling.
    ///
    /// Keys stay in `InObjectKey` throughout. Value strings switch between
    /// `InString` and `InStringEscape` to expose escape progress.
    fn feed_string(&mut self, ch: char, is_key: bool) {
        let next = match (self.string_phase, ch) {
            (StringPhase::Plain, '"') => {
                if is_key {
                    self.state = JsonParseState::AfterKey;
                } else {
                    self.finish_value();
                }
                return;
            }
            (StringPhase::Plain, '\\') => StringPhase::Escape,
            // Control characters must be escaped inside JSON strings.
            (StringPhase::Plain, c) if c < '\u{20}' => {
                self.state = JsonParseState::Error;
                return;
            }
            (StringPhase::Plain, _) => StringPhase::Plain,
            (StringPhase::Escape, 'u') => StringPhase::Hex(4),
            (StringPhase::Escape, c) if ESCAPE_CHARS.contains(&c) => StringPhase::Plain,
            (StringPhase::Hex(remaining), c) if c.is_ascii_hexdigit() => {
                if remaining > 1 {
                    StringPhase::Hex(remaining - 1)
                } else {
                    StringPhase::Plain
                }
            }
            _ => {
                self.state = JsonParseState::Error;
                return;
            }
        };
        self.string_phase = next;
        if !is_key {
            self.state = if next == StringPhase::Plain {
                JsonParseState::InString
            } else {
                JsonParseState::InStringEscape
            };
        }
    }

    /// Begin parsing a new JSON value starting with `ch`.
    fn start_value(&mut self, ch: char) {
        self.after_comma = false;
        match ch {
            '{' => self.open_container(Container::Object),
            '[' => self.open_container(Container::Array),
            '"' => {
                self.string_phase = StringPhase::Plain;
                self.state = JsonParseState::InString;
            }
            '-' => self.begin_number(NumberPhase::Sign),
            '0' => self.begin_number(NumberPhase::LeadingZero),
            '1'..='9' => self.begin_number(NumberPhase::Integer),
            't' => self.begin_keyword(JsonParseState::InBool, "true"),
            'f' => self.begin_keyword(JsonParseState::InBool, "false"),
            'n' => self.begin_keyword(JsonParseState::InNull, "null"),
            _ => self.state = JsonParseState::Error,
        }
    }

    /// Push `container` and enter its "waiting for first member" state.
    fn open_container(&mut self, container: Container) {
        self.stack.push(container);
        self.expecting_comma_or_close = false;
        self.after_comma = false;
        self.state = match container {
            Container::Object => JsonParseState::InObject,
            Container::Array => JsonParseState::InArray,
        };
    }

    /// Enter `InNumber` at the given phase (the first character is already consumed).
    fn begin_number(&mut self, phase: NumberPhase) {
        self.number_phase = phase;
        self.state = JsonParseState::InNumber;
    }

    /// Enter a keyword state; the caller has already matched its first character.
    fn begin_keyword(&mut self, state: JsonParseState, keyword: &'static str) {
        self.keyword = keyword;
        self.keyword_pos = 1;
        self.state = state;
    }

    /// The next character the current keyword requires, if any.
    fn next_keyword_char(&self) -> Option<char> {
        self.keyword.chars().nth(self.keyword_pos)
    }

    /// A value (scalar or container) has been completed.
    ///
    /// Returns to the enclosing container, or finishes the document at top level.
    fn finish_value(&mut self) {
        self.expecting_comma_or_close = true;
        self.after_comma = false;
        self.state = match self.stack.last() {
            Some(Container::Object) => JsonParseState::InObject,
            Some(Container::Array) => JsonParseState::InArray,
            None => JsonParseState::Complete,
        };
    }

    /// Close the innermost container, which then counts as a completed value of its parent.
    fn close_container(&mut self) {
        self.stack.pop();
        self.finish_value();
    }
}
493
494impl Default for JsonConstraint {
495    fn default() -> Self {
496        Self::new()
497    }
498}
499
500impl TokenConstraint for JsonConstraint {
501    fn allowed_tokens(&self, _generated: &[u32], vocab_size: usize) -> Option<Vec<bool>> {
502        if self.state == JsonParseState::Error {
503            return Some(vec![false; vocab_size]);
504        }
505        // Conservative: for each token id in [0, vocab_size) check if its first
506        // ASCII character (treating the id as codepoint) is in valid_next_chars.
507        let valid = self.valid_next_chars();
508        let mask: Vec<bool> = (0..vocab_size)
509            .map(|id| {
510                // Map token id to a char for a simplified single-char check.
511                let ch = char::from_u32(id as u32).unwrap_or('\u{FFFD}');
512                // Allow if valid_next_chars contains it, or if the token is non-ASCII
513                // (we can't tell without a vocab table — be conservative and allow).
514                ch as u32 > 127 || valid.contains(&ch)
515            })
516            .collect();
517        Some(mask)
518    }
519
520    fn advance(&mut self, token: u32) -> bool {
521        if self.state == JsonParseState::Error {
522            return false;
523        }
524        // Treat token id as a codepoint.
525        if let Some(ch) = char::from_u32(token) {
526            self.feed_char(ch);
527        }
528        self.state != JsonParseState::Error
529    }
530
531    fn is_complete(&self) -> bool {
532        self.state == JsonParseState::Complete
533    }
534
535    fn reset(&mut self) {
536        *self = Self::new();
537    }
538
539    fn name(&self) -> &str {
540        "JsonConstraint"
541    }
542}
543
#[cfg(test)]
mod tests {
    use super::*;

    /// Advances `jc` through every character of `text`, one codepoint per token.
    fn feed(jc: &mut JsonConstraint, text: &str) {
        for ch in text.chars() {
            jc.advance(ch as u32);
        }
    }

    #[test]
    fn json_constraint_initial_state() {
        let jc = JsonConstraint::new();
        assert_eq!(jc.depth(), 0);
        assert_eq!(jc.current_state(), &JsonParseState::Start);
    }

    #[test]
    fn json_constraint_valid_object_chars() {
        let valid = JsonConstraint::new().valid_next_chars();
        for expected in ['{', '[', '"'].iter() {
            assert!(valid.contains(expected), "missing {:?}", expected);
        }
    }

    #[test]
    fn json_constraint_tracks_depth() {
        let mut jc = JsonConstraint::new();
        feed(&mut jc, "{");
        assert_eq!(jc.depth(), 1);
        feed(&mut jc, "\"k\":{");
        assert_eq!(jc.depth(), 2);
        feed(&mut jc, "}");
        assert_eq!(jc.depth(), 1);
    }

    #[test]
    fn json_constraint_detects_completion() {
        let mut jc = JsonConstraint::new();
        assert!(!jc.is_complete());
        feed(&mut jc, "{}");
        assert!(jc.is_complete());
    }

    #[test]
    fn json_constraint_in_string_state() {
        let mut jc = JsonConstraint::new();
        feed(&mut jc, "\"");
        assert!(jc.is_in_string());
        feed(&mut jc, "\"");
        assert!(!jc.is_in_string());
    }
}