//! oxide-rs — AI inference library and CLI in Rust (llama.cpp style).
//!
//! Tokenizer wrapper with incremental (streaming) decoding and an on-disk
//! cache keyed by a hash of the source GGUF model file.
use std::fs::{self, File};
use std::io::Read;
use std::path::PathBuf;

use anyhow::Result;
use sha2::{Digest, Sha256};
use shimmytok::Tokenizer as ShimmyTokenizer;

const CACHE_DIR: &str = ".cache/oxide";

/// Wrapper around [`ShimmyTokenizer`] that adds incremental (streaming)
/// decoding support and caches the EOS token id at load time.
pub struct TokenizerWrapper {
    // Underlying GGUF-backed tokenizer implementation.
    inner: ShimmyTokenizer,
    // End-of-sequence token id, read once when the tokenizer is loaded.
    eos_token_id: u32,
    // Tokens accumulated so far for incremental decoding (`decode_next`);
    // reset via `clear_cache`.
    pending_tokens: Vec<u32>,
}

/// Derives a cache-file path for `model_path` under `$HOME/.cache/oxide`.
///
/// The cache key is a SHA-256 over the first 64 KiB of the model file plus
/// its total size, so distinct models map to distinct cache entries without
/// hashing multi-gigabyte files in full.
///
/// # Errors
/// Returns an error if the model file cannot be opened/read or the cache
/// directory cannot be created.
fn get_cache_path(model_path: &PathBuf) -> Result<PathBuf> {
    let mut hasher = Sha256::new();

    let mut file = File::open(model_path)?;
    // Hash only the bytes actually read. The previous code called `read` into
    // a fixed 64 KiB buffer and ignored the returned count, so a short read
    // (or a file smaller than 64 KiB) hashed trailing zero padding and made
    // the key depend on OS read chunking.
    let mut header = Vec::with_capacity(65536);
    file.by_ref().take(65536).read_to_end(&mut header)?;
    hasher.update(&header);

    // Mix in the total file size so two models sharing a 64 KiB prefix still
    // get distinct cache entries.
    let file_size = file.metadata()?.len();
    hasher.update(file_size.to_le_bytes());

    let hash = format!("{:x}", hasher.finalize());

    // Fall back to the current directory when $HOME is unset.
    let home = std::env::var("HOME").unwrap_or_else(|_| ".".to_string());
    let cache_dir = PathBuf::from(home).join(CACHE_DIR);
    fs::create_dir_all(&cache_dir)?;

    Ok(cache_dir.join(format!("{}.tokenizer", hash)))
}

impl TokenizerWrapper {
    pub fn from_gguf(path: &PathBuf) -> Result<Self> {
        let cache_path = get_cache_path(path)?;

        let inner = if cache_path.exists() {
            tracing::info!("Loading tokenizer from cache: {:?}", cache_path);
            ShimmyTokenizer::from_gguf_file(&cache_path)
                .map_err(|e| {
                    tracing::warn!("Cache corrupted, reloading from GGUF: {}", e);
                    e
                })
                .or_else(|_| ShimmyTokenizer::from_gguf_file(path))?
        } else {
            tracing::info!("Loading tokenizer from GGUF (first time)...");
            let tokenizer = ShimmyTokenizer::from_gguf_file(path)
                .map_err(|e| anyhow::anyhow!("Failed to load tokenizer: {}", e))?;

            if let Err(e) = fs::copy(path, &cache_path) {
                tracing::warn!("Failed to cache tokenizer: {}", e);
            } else {
                tracing::info!("Tokenizer cached to {:?}", cache_path);
            }

            tokenizer
        };

        let eos_token_id = inner.eos_token();
        tracing::info!("Loaded tokenizer, EOS={}", eos_token_id);

        Ok(Self {
            inner,
            eos_token_id,
            pending_tokens: Vec::new(),
        })
    }

    pub fn from_file(path: &PathBuf) -> Result<Self> {
        let inner = ShimmyTokenizer::from_gguf_file(path)
            .map_err(|e| anyhow::anyhow!("Failed to load tokenizer: {}", e))?;

        let eos_token_id = inner.eos_token();

        tracing::info!("Loaded tokenizer from file, EOS={}", eos_token_id);

        Ok(Self {
            inner,
            eos_token_id,
            pending_tokens: Vec::new(),
        })
    }

    pub fn encode(&self, text: &str) -> Result<Vec<u32>> {
        self.inner
            .encode(text, true)
            .map_err(|e| anyhow::anyhow!("Encode failed: {}", e))
    }

    pub fn decode(&self, tokens: &[u32]) -> Result<String> {
        self.inner
            .decode(tokens, false)
            .map_err(|e| anyhow::anyhow!("Decode failed: {}", e))
    }

    pub fn eos_token_id(&self) -> u32 {
        self.eos_token_id
    }

    pub fn clear_cache(&mut self) {
        self.pending_tokens.clear();
    }

    pub fn decode_next(&mut self, token: u32) -> Result<Option<String>> {
        self.pending_tokens.push(token);

        let decoded = self.decode(&self.pending_tokens)?;

        let prev_decoded: String = {
            let prev_tokens = &self.pending_tokens[..self.pending_tokens.len() - 1];
            if prev_tokens.is_empty() {
                String::new()
            } else {
                self.decode(prev_tokens)?
            }
        };

        Ok(get_new_text(&decoded, &prev_decoded))
    }

    pub fn decode_rest(&mut self) -> Result<Option<String>> {
        if self.pending_tokens.is_empty() {
            return Ok(None);
        }

        let decoded = self.decode(&self.pending_tokens)?;

        let prev_decoded: String = {
            if self.pending_tokens.len() > 1 {
                let prev_tokens = &self.pending_tokens[..self.pending_tokens.len() - 1];
                self.decode(prev_tokens)?
            } else {
                String::new()
            }
        };

        Ok(get_new_text(&decoded, &prev_decoded))
    }
}

/// Returns the suffix of `full` beyond the first `prev.chars().count()`
/// characters, or `None` if `full` is not longer (in characters) than `prev`.
///
/// Comparison is by character count rather than prefix equality, so a token
/// that rewrites earlier characters still yields only the trailing chars.
fn get_new_text(full: &str, prev: &str) -> Option<String> {
    let prev_len = prev.chars().count();
    // `char_indices().nth(prev_len)` yields the byte offset of the first
    // "new" character, letting us slice directly instead of collecting both
    // strings into Vec<char> on every streamed token.
    full.char_indices()
        .nth(prev_len)
        .map(|(byte_idx, _)| full[byte_idx..].to_string())
}