kyro 0.1.1

A high-performance ML inference engine
#![allow(dead_code)]

use candle_core::{Result, Tensor};
use std::collections::HashSet;

#[derive(Default)]
#[allow(dead_code)]
pub enum GrammarConstraint {
    #[default]
    None,
    Json,
    Regex(String),
}

#[derive(Default)]
#[allow(dead_code)]
pub struct GrammarState {
    pub stack: Vec<String>,
    pub current_text: String,
}

#[allow(dead_code)]
pub struct GrammarLogitsProcessor {
    pub constraint: GrammarConstraint,
    pub state: GrammarState,
}

impl GrammarLogitsProcessor {
    pub fn new(constraint: GrammarConstraint) -> Self {
        Self {
            constraint,
            state: GrammarState::default(),
        }
    }

    pub fn apply_grammar_mask(&mut self, logits: &Tensor, vocab_size: usize) -> Result<Tensor> {
        let valid_tokens = self.get_valid_tokens(vocab_size);
        let mut mask_data = vec![f32::NEG_INFINITY; vocab_size];
        for &token_id in &valid_tokens {
            mask_data[token_id] = 0.0;
        }
        let mask = Tensor::from_slice(&mask_data, (vocab_size,), logits.device())?;
        logits.broadcast_add(&mask)
    }

    fn get_valid_tokens(&self, vocab_size: usize) -> HashSet<usize> {
        match &self.constraint {
            GrammarConstraint::None => (0..vocab_size).collect(),
            GrammarConstraint::Json => {
                let valid = vec![0, 10, 32, 34, 91, 93, 123, 125];
                valid.into_iter().collect()
            }
            GrammarConstraint::Regex(_) => (0..vocab_size).collect(),
        }
    }

    pub fn advance(&mut self, token_text: &str) {
        self.state.current_text.push_str(token_text);
    }
}