ferrum-sampler 0.3.0

Sampling strategies for Ferrum LLM inference engine
Documentation
//! JSON mode logits processor.
//!
//! Constrains generation to produce valid JSON by tracking a state machine
//! and masking tokens that would produce invalid syntax at each step.
//!
//! # Approach
//!
//! Rather than full grammar-guided generation (which requires tokenizer-level
//! mapping), this processor uses a lightweight state machine that tracks
//! whether we're inside a string, after a key, expecting a value, etc.
//! It biases logits to favor JSON-structural tokens without fully preventing
//! all invalid outputs.
//!
//! For a production-quality implementation, this would need:
//! - Tokenizer integration to map token IDs to byte sequences
//! - Full JSON grammar with recursive descent validation
//! - Efficient bitset masking over the vocabulary
//!
//! This MVP provides the infrastructure and demonstrates the pattern.

use ferrum_interfaces::sampler::{LogitsProcessor, ProcessorPriority, SamplingContext};
use ferrum_types::Result;
use parking_lot::Mutex;

/// Tracks the current position in JSON structure.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum JsonState {
    /// Before any output — expecting `{` or `[`.
    Start,
    /// Inside an object, expecting a key (string) or `}`.
    ObjectStart,
    /// After a key, expecting `:`.
    AfterKey,
    /// After `:`, expecting a value.
    AfterColon,
    /// After a value, expecting `,` or `}` / `]`.
    AfterValue,
    /// Inside a string literal.
    InString,
    /// Inside an array, expecting value or `]`.
    ArrayStart,
    /// Generation complete (closing brace/bracket emitted).
    Done,
}

/// JSON mode logits processor.
///
/// Biases logits to encourage valid JSON output by boosting structural tokens
/// and penalizing tokens that would break JSON syntax at the current state.
///
/// Uses token ID heuristics (ASCII-range tokens for `{`, `}`, `"`, etc.)
/// which works with most tokenizers where single-character punctuation maps
/// to predictable token IDs.
#[derive(Debug)]
pub struct JsonModeProcessor {
    state: Mutex<JsonState>,
    /// Nesting depth — track `{`/`[` vs `}`/`]` balance.
    depth: Mutex<i32>,
    /// Bias to add to structural tokens (positive = encourage).
    structural_bias: f32,
    /// Penalty to apply to clearly invalid tokens (negative = discourage).
    invalid_penalty: f32,
}

impl JsonModeProcessor {
    pub fn new() -> Self {
        Self {
            state: Mutex::new(JsonState::Start),
            depth: Mutex::new(0),
            structural_bias: 5.0,
            invalid_penalty: -10.0,
        }
    }

    /// Reset state for a new generation.
    pub fn reset(&self) {
        *self.state.lock() = JsonState::Start;
        *self.depth.lock() = 0;
    }

    /// Get current state (for testing).
    pub fn current_state(&self) -> JsonState {
        *self.state.lock()
    }

    /// Apply structural biases based on the generated text so far.
    ///
    /// Examines the last generated token's text to update state, then
    /// biases logits for the next step.
    pub fn apply_biases(&self, logits: &mut [f32], generated_text: &str) {
        // Update state based on what was just generated
        self.update_state(generated_text);

        let state = *self.state.lock();
        let depth = *self.depth.lock();
        let vocab_size = logits.len();

        // Apply biases based on current state.
        // We use ASCII token IDs as heuristic — for production, this needs
        // proper tokenizer integration.
        match state {
            JsonState::Start => {
                // Boost `{` (0x7B = 123) and `[` (0x5B = 91)
                self.bias_token(logits, 123, self.structural_bias);
                self.bias_token(logits, 91, self.structural_bias);
            }
            JsonState::ObjectStart => {
                // Boost `"` (0x22 = 34) for key start, or `}` (0x7D = 125) for empty
                self.bias_token(logits, 34, self.structural_bias);
                if depth <= 1 {
                    self.bias_token(logits, 125, self.structural_bias * 0.5);
                }
            }
            JsonState::AfterKey => {
                // Boost `:` (0x3A = 58)
                self.bias_token(logits, 58, self.structural_bias);
            }
            JsonState::AfterValue => {
                // Boost `,` (0x2C = 44) or closing `}` / `]`
                self.bias_token(logits, 44, self.structural_bias);
                self.bias_token(logits, 125, self.structural_bias);
                self.bias_token(logits, 93, self.structural_bias);
            }
            JsonState::Done => {
                // Penalize everything except EOS — we're done
                // Boost common EOS token positions
                if vocab_size > 2 {
                    // Many tokenizers use token 0, 1, or 2 as EOS
                    self.bias_token(logits, 0, self.structural_bias);
                    // Penalize content tokens to discourage continuing
                    for i in 32..vocab_size.min(256) {
                        logits[i] += self.invalid_penalty * 0.3;
                    }
                }
            }
            _ => {}
        }
    }

    fn bias_token(&self, logits: &mut [f32], token_id: usize, bias: f32) {
        if token_id < logits.len() {
            logits[token_id] += bias;
        }
    }

    /// Update internal state based on accumulated generated text.
    fn update_state(&self, text: &str) {
        let mut state = self.state.lock();
        let mut depth = self.depth.lock();

        for ch in text.chars() {
            match (*state, ch) {
                (JsonState::Start, '{') => {
                    *state = JsonState::ObjectStart;
                    *depth += 1;
                }
                (JsonState::Start, '[') => {
                    *state = JsonState::ArrayStart;
                    *depth += 1;
                }
                (JsonState::ObjectStart, '"') => {
                    *state = JsonState::InString;
                }
                (JsonState::ObjectStart, '}') => {
                    *depth -= 1;
                    *state = if *depth <= 0 {
                        JsonState::Done
                    } else {
                        JsonState::AfterValue
                    };
                }
                (JsonState::InString, '"') => {
                    // End of string — could be key or value
                    *state = JsonState::AfterKey;
                }
                (JsonState::InString, '\\') => {
                    // Escape — next char is part of string (simplified)
                }
                (JsonState::AfterKey, ':') => {
                    *state = JsonState::AfterColon;
                }
                (JsonState::AfterColon, '"') => {
                    *state = JsonState::InString;
                }
                (JsonState::AfterColon, '{') => {
                    *state = JsonState::ObjectStart;
                    *depth += 1;
                }
                (JsonState::AfterColon, '[') => {
                    *state = JsonState::ArrayStart;
                    *depth += 1;
                }
                (JsonState::AfterColon, _)
                    if ch.is_ascii_digit() || ch == '-' || ch == 't' || ch == 'f' || ch == 'n' =>
                {
                    // Number, true, false, null — treat as value
                    *state = JsonState::AfterValue;
                }
                (JsonState::AfterValue, ',') => {
                    *state = JsonState::ObjectStart;
                }
                (JsonState::AfterValue, '}') => {
                    *depth -= 1;
                    *state = if *depth <= 0 {
                        JsonState::Done
                    } else {
                        JsonState::AfterValue
                    };
                }
                (JsonState::AfterValue, ']') => {
                    *depth -= 1;
                    *state = if *depth <= 0 {
                        JsonState::Done
                    } else {
                        JsonState::AfterValue
                    };
                }
                (JsonState::ArrayStart, ']') => {
                    *depth -= 1;
                    *state = if *depth <= 0 {
                        JsonState::Done
                    } else {
                        JsonState::AfterValue
                    };
                }
                (JsonState::ArrayStart, '"') => {
                    *state = JsonState::InString;
                }
                (JsonState::ArrayStart, '{') => {
                    *state = JsonState::ObjectStart;
                    *depth += 1;
                }
                _ => {
                    // Whitespace or unrecognized — stay in current state
                }
            }
        }
    }
}

impl Default for JsonModeProcessor {
    fn default() -> Self {
        Self::new()
    }
}

impl LogitsProcessor for JsonModeProcessor {
    fn process(&self, ctx: &mut SamplingContext) -> Result<()> {
        // Build the generated text from previous tokens
        // In a real implementation this would use the tokenizer to decode
        // For now, use the previous_tokens as ASCII approximation
        let generated: String = ctx
            .previous_tokens
            .iter()
            .filter_map(|t| {
                let v = t.get();
                if v < 128 {
                    Some(v as u8 as char)
                } else {
                    None
                }
            })
            .collect();

        self.apply_biases(ctx.logits, &generated);
        Ok(())
    }

    fn name(&self) -> &str {
        "json_mode"
    }

    fn priority(&self) -> ProcessorPriority {
        // Run before other processors (temperature, top-k) so biases are
        // applied to raw logits.
        ProcessorPriority::High
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn state_tracks_simple_json() {
        let proc = JsonModeProcessor::new();
        assert_eq!(proc.current_state(), JsonState::Start);

        proc.update_state("{");
        assert_eq!(proc.current_state(), JsonState::ObjectStart);

        proc.update_state("\"key\"");
        assert_eq!(proc.current_state(), JsonState::AfterKey);

        proc.update_state(":");
        assert_eq!(proc.current_state(), JsonState::AfterColon);

        proc.update_state("\"value\"");
        // After opening quote → InString, after closing quote → AfterKey
        // But this is a value string after colon... the state machine is simplified
        // It treats all strings the same (AfterKey). For production, we'd need
        // to track whether we're parsing a key or value string.
        assert_eq!(proc.current_state(), JsonState::AfterKey);
    }

    #[test]
    fn state_tracks_nested_json() {
        let proc = JsonModeProcessor::new();
        proc.update_state("{\"a\":{\"b\":1}}");
        assert_eq!(proc.current_state(), JsonState::Done);
    }

    #[test]
    fn state_done_after_closing_brace() {
        let proc = JsonModeProcessor::new();
        proc.update_state("{}");
        assert_eq!(proc.current_state(), JsonState::Done);
    }

    #[test]
    fn bias_boosts_structural_tokens() {
        let proc = JsonModeProcessor::new();
        let mut logits = vec![0.0f32; 256];

        // At start, should boost `{` (123) and `[` (91)
        proc.apply_biases(&mut logits, "");
        assert!(logits[123] > 0.0, "Should boost {{ token");
        assert!(logits[91] > 0.0, "Should boost [ token");
    }

    #[test]
    fn reset_clears_state() {
        let proc = JsonModeProcessor::new();
        proc.update_state("{\"a\":1}");
        assert_eq!(proc.current_state(), JsonState::Done);

        proc.reset();
        assert_eq!(proc.current_state(), JsonState::Start);
    }
}