Skip to main content

ferrum_sampler/
json_mode.rs

1//! JSON mode logits processor.
2//!
//! Steers generation toward valid JSON by tracking a lightweight state machine
//! and biasing logits toward the syntax expected at each step.
5//!
6//! # Approach
7//!
8//! Rather than full grammar-guided generation (which requires tokenizer-level
9//! mapping), this processor uses a lightweight state machine that tracks
10//! whether we're inside a string, after a key, expecting a value, etc.
11//! It biases logits to favor JSON-structural tokens without fully preventing
12//! all invalid outputs.
13//!
14//! For a production-quality implementation, this would need:
15//! - Tokenizer integration to map token IDs to byte sequences
16//! - Full JSON grammar with recursive descent validation
17//! - Efficient bitset masking over the vocabulary
18//!
19//! This MVP provides the infrastructure and demonstrates the pattern.
20
21use ferrum_interfaces::sampler::{LogitsProcessor, ProcessorPriority, SamplingContext};
22use ferrum_types::Result;
23use parking_lot::Mutex;
24
/// Coarse position of the generator inside the JSON document.
///
/// The model is deliberately simplified: object keys and string values share
/// [`JsonState::InString`], and every closing quote leads to
/// [`JsonState::AfterKey`], whether the string was a key or a value.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum JsonState {
    /// Nothing emitted yet; a `{` or `[` is expected.
    Start,
    /// An object was just opened (or a `,` was seen); a `"` key or `}` is expected.
    ObjectStart,
    /// A string just closed. Intended for object keys (so `:` is expected),
    /// but value strings end here as well.
    AfterKey,
    /// A `:` was seen; a value is expected.
    AfterColon,
    /// A value was completed; a `,`, `}` or `]` is expected.
    AfterValue,
    /// Between the quotes of a string literal.
    InString,
    /// Inside an array, expecting a value or `]`.
    ArrayStart,
    /// The outermost `}` or `]` has been emitted; generation should stop.
    Done,
}
45
46/// JSON mode logits processor.
47///
48/// Biases logits to encourage valid JSON output by boosting structural tokens
49/// and penalizing tokens that would break JSON syntax at the current state.
50///
51/// Uses token ID heuristics (ASCII-range tokens for `{`, `}`, `"`, etc.)
52/// which works with most tokenizers where single-character punctuation maps
53/// to predictable token IDs.
54#[derive(Debug)]
55pub struct JsonModeProcessor {
56    state: Mutex<JsonState>,
57    /// Nesting depth — track `{`/`[` vs `}`/`]` balance.
58    depth: Mutex<i32>,
59    /// Bias to add to structural tokens (positive = encourage).
60    structural_bias: f32,
61    /// Penalty to apply to clearly invalid tokens (negative = discourage).
62    invalid_penalty: f32,
63}
64
65impl JsonModeProcessor {
66    pub fn new() -> Self {
67        Self {
68            state: Mutex::new(JsonState::Start),
69            depth: Mutex::new(0),
70            structural_bias: 5.0,
71            invalid_penalty: -10.0,
72        }
73    }
74
75    /// Reset state for a new generation.
76    pub fn reset(&self) {
77        *self.state.lock() = JsonState::Start;
78        *self.depth.lock() = 0;
79    }
80
81    /// Get current state (for testing).
82    pub fn current_state(&self) -> JsonState {
83        *self.state.lock()
84    }
85
86    /// Apply structural biases based on the generated text so far.
87    ///
88    /// Examines the last generated token's text to update state, then
89    /// biases logits for the next step.
90    pub fn apply_biases(&self, logits: &mut [f32], generated_text: &str) {
91        // Update state based on what was just generated
92        self.update_state(generated_text);
93
94        let state = *self.state.lock();
95        let depth = *self.depth.lock();
96        let vocab_size = logits.len();
97
98        // Apply biases based on current state.
99        // We use ASCII token IDs as heuristic — for production, this needs
100        // proper tokenizer integration.
101        match state {
102            JsonState::Start => {
103                // Boost `{` (0x7B = 123) and `[` (0x5B = 91)
104                self.bias_token(logits, 123, self.structural_bias);
105                self.bias_token(logits, 91, self.structural_bias);
106            }
107            JsonState::ObjectStart => {
108                // Boost `"` (0x22 = 34) for key start, or `}` (0x7D = 125) for empty
109                self.bias_token(logits, 34, self.structural_bias);
110                if depth <= 1 {
111                    self.bias_token(logits, 125, self.structural_bias * 0.5);
112                }
113            }
114            JsonState::AfterKey => {
115                // Boost `:` (0x3A = 58)
116                self.bias_token(logits, 58, self.structural_bias);
117            }
118            JsonState::AfterValue => {
119                // Boost `,` (0x2C = 44) or closing `}` / `]`
120                self.bias_token(logits, 44, self.structural_bias);
121                self.bias_token(logits, 125, self.structural_bias);
122                self.bias_token(logits, 93, self.structural_bias);
123            }
124            JsonState::Done => {
125                // Penalize everything except EOS — we're done
126                // Boost common EOS token positions
127                if vocab_size > 2 {
128                    // Many tokenizers use token 0, 1, or 2 as EOS
129                    self.bias_token(logits, 0, self.structural_bias);
130                    // Penalize content tokens to discourage continuing
131                    for i in 32..vocab_size.min(256) {
132                        logits[i] += self.invalid_penalty * 0.3;
133                    }
134                }
135            }
136            _ => {}
137        }
138    }
139
140    fn bias_token(&self, logits: &mut [f32], token_id: usize, bias: f32) {
141        if token_id < logits.len() {
142            logits[token_id] += bias;
143        }
144    }
145
146    /// Update internal state based on accumulated generated text.
147    fn update_state(&self, text: &str) {
148        let mut state = self.state.lock();
149        let mut depth = self.depth.lock();
150
151        for ch in text.chars() {
152            match (*state, ch) {
153                (JsonState::Start, '{') => {
154                    *state = JsonState::ObjectStart;
155                    *depth += 1;
156                }
157                (JsonState::Start, '[') => {
158                    *state = JsonState::ArrayStart;
159                    *depth += 1;
160                }
161                (JsonState::ObjectStart, '"') => {
162                    *state = JsonState::InString;
163                }
164                (JsonState::ObjectStart, '}') => {
165                    *depth -= 1;
166                    *state = if *depth <= 0 {
167                        JsonState::Done
168                    } else {
169                        JsonState::AfterValue
170                    };
171                }
172                (JsonState::InString, '"') => {
173                    // End of string — could be key or value
174                    *state = JsonState::AfterKey;
175                }
176                (JsonState::InString, '\\') => {
177                    // Escape — next char is part of string (simplified)
178                }
179                (JsonState::AfterKey, ':') => {
180                    *state = JsonState::AfterColon;
181                }
182                (JsonState::AfterColon, '"') => {
183                    *state = JsonState::InString;
184                }
185                (JsonState::AfterColon, '{') => {
186                    *state = JsonState::ObjectStart;
187                    *depth += 1;
188                }
189                (JsonState::AfterColon, '[') => {
190                    *state = JsonState::ArrayStart;
191                    *depth += 1;
192                }
193                (JsonState::AfterColon, _)
194                    if ch.is_ascii_digit() || ch == '-' || ch == 't' || ch == 'f' || ch == 'n' =>
195                {
196                    // Number, true, false, null — treat as value
197                    *state = JsonState::AfterValue;
198                }
199                (JsonState::AfterValue, ',') => {
200                    *state = JsonState::ObjectStart;
201                }
202                (JsonState::AfterValue, '}') => {
203                    *depth -= 1;
204                    *state = if *depth <= 0 {
205                        JsonState::Done
206                    } else {
207                        JsonState::AfterValue
208                    };
209                }
210                (JsonState::AfterValue, ']') => {
211                    *depth -= 1;
212                    *state = if *depth <= 0 {
213                        JsonState::Done
214                    } else {
215                        JsonState::AfterValue
216                    };
217                }
218                (JsonState::ArrayStart, ']') => {
219                    *depth -= 1;
220                    *state = if *depth <= 0 {
221                        JsonState::Done
222                    } else {
223                        JsonState::AfterValue
224                    };
225                }
226                (JsonState::ArrayStart, '"') => {
227                    *state = JsonState::InString;
228                }
229                (JsonState::ArrayStart, '{') => {
230                    *state = JsonState::ObjectStart;
231                    *depth += 1;
232                }
233                _ => {
234                    // Whitespace or unrecognized — stay in current state
235                }
236            }
237        }
238    }
239}
240
241impl Default for JsonModeProcessor {
242    fn default() -> Self {
243        Self::new()
244    }
245}
246
247impl LogitsProcessor for JsonModeProcessor {
248    fn process(&self, ctx: &mut SamplingContext) -> Result<()> {
249        // Build the generated text from previous tokens
250        // In a real implementation this would use the tokenizer to decode
251        // For now, use the previous_tokens as ASCII approximation
252        let generated: String = ctx
253            .previous_tokens
254            .iter()
255            .filter_map(|t| {
256                let v = t.get();
257                if v < 128 {
258                    Some(v as u8 as char)
259                } else {
260                    None
261                }
262            })
263            .collect();
264
265        self.apply_biases(ctx.logits, &generated);
266        Ok(())
267    }
268
269    fn name(&self) -> &str {
270        "json_mode"
271    }
272
273    fn priority(&self) -> ProcessorPriority {
274        // Run before other processors (temperature, top-k) so biases are
275        // applied to raw logits.
276        ProcessorPriority::High
277    }
278}
279
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a fresh processor and feeds it `text` in a single call.
    fn fed(text: &str) -> JsonModeProcessor {
        let processor = JsonModeProcessor::new();
        processor.update_state(text);
        processor
    }

    #[test]
    fn state_tracks_simple_json() {
        let processor = JsonModeProcessor::new();
        assert_eq!(processor.current_state(), JsonState::Start);

        let steps = [
            ("{", JsonState::ObjectStart),
            ("\"key\"", JsonState::AfterKey),
            (":", JsonState::AfterColon),
            // Value strings are not distinguished from keys: the closing quote
            // of a value also lands in `AfterKey`. Telling them apart would
            // require tracking whether the string started after a `:`.
            ("\"value\"", JsonState::AfterKey),
        ];
        for (chunk, expected) in steps {
            processor.update_state(chunk);
            assert_eq!(processor.current_state(), expected);
        }
    }

    #[test]
    fn state_tracks_nested_json() {
        assert_eq!(fed("{\"a\":{\"b\":1}}").current_state(), JsonState::Done);
    }

    #[test]
    fn state_done_after_closing_brace() {
        assert_eq!(fed("{}").current_state(), JsonState::Done);
    }

    #[test]
    fn bias_boosts_structural_tokens() {
        let processor = JsonModeProcessor::new();
        let mut logits = [0.0f32; 256];

        // From `Start`, both container openers (`{` = 123, `[` = 91) are boosted.
        processor.apply_biases(&mut logits, "");
        assert!(logits[usize::from(b'{')] > 0.0, "Should boost {{ token");
        assert!(logits[usize::from(b'[')] > 0.0, "Should boost [ token");
    }

    #[test]
    fn reset_clears_state() {
        let processor = fed("{\"a\":1}");
        assert_eq!(processor.current_state(), JsonState::Done);

        processor.reset();
        assert_eq!(processor.current_state(), JsonState::Start);
    }
}