use ferrum_interfaces::sampler::{LogitsProcessor, ProcessorPriority, SamplingContext};
use ferrum_types::Result;
use parking_lot::Mutex;
/// Position of the JSON scanner inside the text generated so far.
///
/// Each variant determines which structural tokens the processor boosts next.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum JsonState {
    /// Nothing has been opened yet; `{` and `[` are boosted.
    Start,
    /// Inside an object where a key string (or the closing `}`) may come next.
    ObjectStart,
    /// Expecting the `:` that separates a key from its value.
    AfterKey,
    /// Expecting a value after `:`.
    AfterColon,
    /// A value has been completed; expecting `,` or a closing bracket.
    AfterValue,
    /// Inside a string literal.
    InString,
    /// Inside an array where an element (or the closing `]`) may come next.
    ArrayStart,
    /// The outermost container has been closed.
    Done,
}
#[derive(Debug)]
pub struct JsonModeProcessor {
state: Mutex<JsonState>,
depth: Mutex<i32>,
structural_bias: f32,
invalid_penalty: f32,
}
impl JsonModeProcessor {
pub fn new() -> Self {
Self {
state: Mutex::new(JsonState::Start),
depth: Mutex::new(0),
structural_bias: 5.0,
invalid_penalty: -10.0,
}
}
pub fn reset(&self) {
*self.state.lock() = JsonState::Start;
*self.depth.lock() = 0;
}
pub fn current_state(&self) -> JsonState {
*self.state.lock()
}
pub fn apply_biases(&self, logits: &mut [f32], generated_text: &str) {
self.update_state(generated_text);
let state = *self.state.lock();
let depth = *self.depth.lock();
let vocab_size = logits.len();
match state {
JsonState::Start => {
self.bias_token(logits, 123, self.structural_bias);
self.bias_token(logits, 91, self.structural_bias);
}
JsonState::ObjectStart => {
self.bias_token(logits, 34, self.structural_bias);
if depth <= 1 {
self.bias_token(logits, 125, self.structural_bias * 0.5);
}
}
JsonState::AfterKey => {
self.bias_token(logits, 58, self.structural_bias);
}
JsonState::AfterValue => {
self.bias_token(logits, 44, self.structural_bias);
self.bias_token(logits, 125, self.structural_bias);
self.bias_token(logits, 93, self.structural_bias);
}
JsonState::Done => {
if vocab_size > 2 {
self.bias_token(logits, 0, self.structural_bias);
for i in 32..vocab_size.min(256) {
logits[i] += self.invalid_penalty * 0.3;
}
}
}
_ => {}
}
}
fn bias_token(&self, logits: &mut [f32], token_id: usize, bias: f32) {
if token_id < logits.len() {
logits[token_id] += bias;
}
}
fn update_state(&self, text: &str) {
let mut state = self.state.lock();
let mut depth = self.depth.lock();
for ch in text.chars() {
match (*state, ch) {
(JsonState::Start, '{') => {
*state = JsonState::ObjectStart;
*depth += 1;
}
(JsonState::Start, '[') => {
*state = JsonState::ArrayStart;
*depth += 1;
}
(JsonState::ObjectStart, '"') => {
*state = JsonState::InString;
}
(JsonState::ObjectStart, '}') => {
*depth -= 1;
*state = if *depth <= 0 {
JsonState::Done
} else {
JsonState::AfterValue
};
}
(JsonState::InString, '"') => {
*state = JsonState::AfterKey;
}
(JsonState::InString, '\\') => {
}
(JsonState::AfterKey, ':') => {
*state = JsonState::AfterColon;
}
(JsonState::AfterColon, '"') => {
*state = JsonState::InString;
}
(JsonState::AfterColon, '{') => {
*state = JsonState::ObjectStart;
*depth += 1;
}
(JsonState::AfterColon, '[') => {
*state = JsonState::ArrayStart;
*depth += 1;
}
(JsonState::AfterColon, _)
if ch.is_ascii_digit() || ch == '-' || ch == 't' || ch == 'f' || ch == 'n' =>
{
*state = JsonState::AfterValue;
}
(JsonState::AfterValue, ',') => {
*state = JsonState::ObjectStart;
}
(JsonState::AfterValue, '}') => {
*depth -= 1;
*state = if *depth <= 0 {
JsonState::Done
} else {
JsonState::AfterValue
};
}
(JsonState::AfterValue, ']') => {
*depth -= 1;
*state = if *depth <= 0 {
JsonState::Done
} else {
JsonState::AfterValue
};
}
(JsonState::ArrayStart, ']') => {
*depth -= 1;
*state = if *depth <= 0 {
JsonState::Done
} else {
JsonState::AfterValue
};
}
(JsonState::ArrayStart, '"') => {
*state = JsonState::InString;
}
(JsonState::ArrayStart, '{') => {
*state = JsonState::ObjectStart;
*depth += 1;
}
_ => {
}
}
}
}
}
impl Default for JsonModeProcessor {
fn default() -> Self {
Self::new()
}
}
impl LogitsProcessor for JsonModeProcessor {
fn process(&self, ctx: &mut SamplingContext) -> Result<()> {
let generated: String = ctx
.previous_tokens
.iter()
.filter_map(|t| {
let v = t.get();
if v < 128 {
Some(v as u8 as char)
} else {
None
}
})
.collect();
self.apply_biases(ctx.logits, &generated);
Ok(())
}
fn name(&self) -> &str {
"json_mode"
}
fn priority(&self) -> ProcessorPriority {
ProcessorPriority::High
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn state_tracks_simple_json() {
let proc = JsonModeProcessor::new();
assert_eq!(proc.current_state(), JsonState::Start);
proc.update_state("{");
assert_eq!(proc.current_state(), JsonState::ObjectStart);
proc.update_state("\"key\"");
assert_eq!(proc.current_state(), JsonState::AfterKey);
proc.update_state(":");
assert_eq!(proc.current_state(), JsonState::AfterColon);
proc.update_state("\"value\"");
assert_eq!(proc.current_state(), JsonState::AfterKey);
}
#[test]
fn state_tracks_nested_json() {
let proc = JsonModeProcessor::new();
proc.update_state("{\"a\":{\"b\":1}}");
assert_eq!(proc.current_state(), JsonState::Done);
}
#[test]
fn state_done_after_closing_brace() {
let proc = JsonModeProcessor::new();
proc.update_state("{}");
assert_eq!(proc.current_state(), JsonState::Done);
}
#[test]
fn bias_boosts_structural_tokens() {
let proc = JsonModeProcessor::new();
let mut logits = vec![0.0f32; 256];
proc.apply_biases(&mut logits, "");
assert!(logits[123] > 0.0, "Should boost {{ token");
assert!(logits[91] > 0.0, "Should boost [ token");
}
#[test]
fn reset_clears_state() {
let proc = JsonModeProcessor::new();
proc.update_state("{\"a\":1}");
assert_eq!(proc.current_state(), JsonState::Done);
proc.reset();
assert_eq!(proc.current_state(), JsonState::Start);
}
}