// wax-core 0.1.0
//
// Core inference engine for wax, a small Candle-based local LLM runner.
// Documentation
use crate::{Result, WaxError};

/// Incremental decoder that turns a stream of token ids into printable text.
///
/// Tokens are fed one at a time through `next_token`; text that has not been
/// released yet can be retrieved at the end with `decode_rest`.
pub struct TokenOutputStream {
    /// Tokenizer used to decode token ids back into text.
    tokenizer: tokenizers::Tokenizer,
    /// Every token id received so far, in arrival order.
    tokens: Vec<u32>,
    /// Start of the most recently emitted chunk; decoding always begins here so
    /// the new text is decoded together with that chunk as context.
    prev_index: usize,
    /// End (exclusive) of the most recently emitted chunk; tokens from this
    /// index onward have been received but not yet emitted.
    current_index: usize,
}

impl TokenOutputStream {
    /// Creates an empty stream that decodes with `tokenizer`.
    pub fn new(tokenizer: tokenizers::Tokenizer) -> Self {
        Self {
            tokenizer,
            tokens: Vec::new(),
            prev_index: 0,
            current_index: 0,
        }
    }

    /// Appends `token` and returns the text that became printable, if any.
    ///
    /// All tokens from `prev_index` onward are decoded together and the part of
    /// the result that extends past the previously emitted chunk is returned.
    /// Text is only released when it ends in an alphanumeric character;
    /// otherwise the token is kept and `Ok(None)` is returned, leaving it for a
    /// later call or for [`Self::decode_rest`].
    ///
    /// `Ok(None)` is also returned (instead of panicking) when the length of the
    /// previously emitted chunk is not a character boundary of the new text.
    ///
    /// # Errors
    ///
    /// Returns the tokenizer's decode error, converted with `WaxError::tokenizer`.
    pub fn next_token(&mut self, token: u32) -> Result<Option<String>> {
        // Decode the old chunk before pushing so a failure here leaves the
        // stream untouched.
        let prev_text = self.prev_text()?;

        self.tokens.push(token);
        let text = self.decode(&self.tokens[self.prev_index..])?;

        if !text.chars().last().is_some_and(char::is_alphanumeric) {
            return Ok(None);
        }
        let Some(delta) = Self::text_after(&text, prev_text.len()) else {
            return Ok(None);
        };

        // The chunk that was just decoded becomes the context for the next call.
        self.prev_index = self.current_index;
        self.current_index = self.tokens.len();
        Ok(Some(delta.to_string()))
    }

    /// Returns the text of the tokens that `next_token` has not emitted yet.
    ///
    /// Unlike `next_token` this applies no alphanumeric check and does not
    /// modify the stream. Returns `Ok(None)` when there is no additional text,
    /// or when the length of the previously emitted chunk is not a character
    /// boundary of the decoded text.
    ///
    /// # Errors
    ///
    /// Returns the tokenizer's decode error, converted with `WaxError::tokenizer`.
    pub fn decode_rest(&self) -> Result<Option<String>> {
        let prev_text = self.prev_text()?;
        let text = self.decode(&self.tokens[self.prev_index..])?;
        Ok(Self::text_after(&text, prev_text.len()).map(str::to_string))
    }

    /// Decodes the most recently emitted chunk, `tokens[prev_index..current_index]`.
    fn prev_text(&self) -> Result<String> {
        if self.tokens.is_empty() {
            Ok(String::new())
        } else {
            self.decode(&self.tokens[self.prev_index..self.current_index])
        }
    }

    /// Returns the non-empty remainder of `text` starting at byte `offset`.
    ///
    /// Yields `None` when `offset` is at or past the end of `text`, or when it
    /// does not fall on a character boundary (where `split_at` would panic).
    fn text_after(text: &str, offset: usize) -> Option<&str> {
        text.get(offset..).filter(|rest| !rest.is_empty())
    }

    /// Decodes `tokens` with the wrapped tokenizer, mapping failures into the
    /// crate error type.
    fn decode(&self, tokens: &[u32]) -> Result<String> {
        // The `true` argument is the tokenizers `skip_special_tokens` flag.
        self.tokenizer
            .decode(tokens, true)
            .map_err(WaxError::tokenizer)
    }
}

#[cfg(test)]
mod tests {
    use ahash::AHashMap;
    use tokenizers::{models::wordlevel::WordLevel, Tokenizer};

    use super::TokenOutputStream;

    /// Builds a word-level tokenizer whose vocabulary maps "Hello", "world",
    /// "!" and the unknown token "[UNK]" to the ids 0 through 3.
    fn word_level_tokenizer() -> Tokenizer {
        let vocab: AHashMap<_, _> = [("Hello", 0), ("world", 1), ("!", 2), ("[UNK]", 3)]
            .iter()
            .map(|&(word, id)| (word.to_string(), id))
            .collect();

        let builder = WordLevel::builder()
            .vocab(vocab)
            .unk_token("[UNK]".to_string());
        Tokenizer::new(builder.build().unwrap())
    }

    /// Words are streamed as they arrive, while trailing punctuation is held
    /// back until the final flush.
    #[test]
    fn streams_alphanumeric_tokens_and_flushes_punctuation_at_end() {
        let mut stream = TokenOutputStream::new(word_level_tokenizer());

        // (token id, text expected from `next_token` for that id)
        let steps = [(0, Some("Hello")), (1, Some(" world")), (2, None)];
        for (id, expected) in steps {
            let emitted = stream.next_token(id).unwrap();
            assert_eq!(emitted.as_deref(), expected);
        }

        let rest = stream.decode_rest().unwrap();
        assert_eq!(rest.as_deref(), Some(" !"));
    }
}