//! flint-ai 0.1.0
//!
//! A lightweight embedded AI runtime for every device.
use std::num::NonZeroU32;
use std::path::{Path, PathBuf};

use encoding_rs::UTF_8;
use llama_cpp_2::context::params::LlamaContextParams;
use llama_cpp_2::llama_backend::LlamaBackend;
use llama_cpp_2::llama_batch::LlamaBatch;
use llama_cpp_2::model::params::LlamaModelParams;
use llama_cpp_2::model::{AddBos, LlamaModel};
use llama_cpp_2::sampling::LlamaSampler;

/// Handle to a local GGUF model file.
///
/// The struct only stores where the model file is expected to live; it does
/// not keep the model loaded in memory.
#[derive(Debug, Clone)]
pub struct LocalAI {
    /// Path of the `.gguf` file this handle refers to. The file is not
    /// guaranteed to exist.
    model_path: String,
}

impl LocalAI {
    /// Context window size, in tokens, requested for every chat session.
    const CONTEXT_TOKENS: u32 = 2048;
    /// `CONTEXT_TOKENS` expressed as a token position bound (2048 fits in `i32`).
    const CONTEXT_END: i32 = Self::CONTEXT_TOKENS as i32;
    /// Capacity of the decode batch; longer prompts are fed in chunks of this size.
    const BATCH_TOKENS: usize = 512;
    /// Upper bound on the number of tokens generated for a single reply.
    const MAX_NEW_TOKENS: usize = 512;

    /// Resolves `model_name` to a `.gguf` file under `$HOME/.flint/models`.
    ///
    /// Only the part of `model_name` after the last `/` is used. Resolution
    /// order:
    ///
    /// 1. `<models dir>/<name>.gguf`, if that file exists.
    /// 2. Otherwise, a `.gguf` file in the models directory whose lower-cased
    ///    file name contains every word of the lower-cased model name. Words
    ///    are separated by `-`, `_` or `.`, a trailing `-gguf`/`_gguf` is
    ///    ignored, and one-character words are skipped. When several files
    ///    match, the smallest path wins so the result is stable across runs.
    /// 3. Otherwise, the (non-existent) exact path from step 1.
    ///
    /// If `HOME` is not set, the current directory is used as the base. The
    /// models directory is created on a best-effort basis.
    pub fn new(model_name: &str) -> Self {
        let last_part = model_name.rsplit('/').next().unwrap_or(model_name);

        let home = std::env::var("HOME").unwrap_or_else(|_| ".".to_string());
        let models_dir = PathBuf::from(home).join(".flint").join("models");
        // Best effort: a failure here surfaces later as "model not found".
        let _ = std::fs::create_dir_all(&models_dir);

        let exact = models_dir.join(format!("{}.gguf", last_part));
        if exact.exists() {
            return Self::at(&exact);
        }

        let wanted = last_part.to_lowercase();
        let words = Self::name_words(&wanted);
        // Without any usable word, `all` below would be vacuously true and an
        // arbitrary model file would be picked, so skip the scan entirely.
        if !words.is_empty() {
            if let Ok(entries) = std::fs::read_dir(&models_dir) {
                let best = entries
                    .flatten()
                    .filter(|entry| {
                        let fname = entry.file_name().to_string_lossy().to_lowercase();
                        fname.ends_with(".gguf") && words.iter().all(|w| fname.contains(*w))
                    })
                    .map(|entry| entry.path())
                    // Directory iteration order is platform-dependent; pick a
                    // deterministic winner.
                    .min();
                if let Some(path) = best {
                    return Self::at(&path);
                }
            }
        }

        Self::at(&exact)
    }

    /// Builds a handle that points at `path`.
    fn at(path: &Path) -> Self {
        Self {
            model_path: path.to_string_lossy().into_owned(),
        }
    }

    /// Splits an already lower-cased model name into the words used for fuzzy
    /// file matching.
    fn name_words(name: &str) -> Vec<&str> {
        name.trim_end_matches("-gguf")
            .trim_end_matches("_gguf")
            .split(|c: char| matches!(c, '-' | '_' | '.'))
            .filter(|w| w.len() > 1)
            .collect()
    }

    /// Returns `true` if the resolved model file currently exists.
    pub fn is_available(&self) -> bool {
        Path::new(&self.model_path).exists()
    }

    /// Sends a single user `message` to the model and returns the generated
    /// reply.
    ///
    /// The message is wrapped in a fixed `<|system|>` / `<|user|>` /
    /// `<|assistant|>` prompt template. Generation uses a greedy sampler and
    /// stops at an end-of-generation token, after `MAX_NEW_TOKENS` tokens, or
    /// when the context window is full, whichever comes first.
    ///
    /// The backend, model and context are created on every call.
    ///
    /// # Errors
    ///
    /// Returns a human-readable message if the model file is missing, if the
    /// backend, model or context cannot be created, if the prompt cannot be
    /// tokenized or does not fit in the context window, or if decoding fails.
    pub fn chat(&self, message: &str) -> Result<String, String> {
        // NOTE: these modify the process-wide environment on every call.
        std::env::set_var("GGML_LOG_LEVEL", "error");
        std::env::set_var("LLAMA_LOG_LEVEL", "error");

        if !self.is_available() {
            return Err("Model not found. Run: flint use <model_name>".to_string());
        }

        let backend =
            LlamaBackend::init().map_err(|e| format!("failed to init backend: {}", e))?;
        let model = LlamaModel::load_from_file(
            &backend,
            Path::new(&self.model_path),
            &LlamaModelParams::default(),
        )
        .map_err(|e| format!("model load failed: {}", e))?;

        let ctx_params =
            LlamaContextParams::default().with_n_ctx(NonZeroU32::new(Self::CONTEXT_TOKENS));

        let mut ctx = model
            .new_context(&backend, ctx_params)
            .map_err(|e| format!("failed to create context: {}", e))?;

        let prompt = format!(
            "<|system|>\nYou are a helpful assistant.</s>\n<|user|>\n{}</s>\n<|assistant|>\n",
            message
        );

        let tokens = model
            .str_to_token(&prompt, AddBos::Always)
            .map_err(|e| format!("tokenization failed: {}", e))?;

        if tokens.is_empty() {
            return Err("tokenization produced no tokens".to_string());
        }
        // Widening cast: the context size is a small constant. At least one
        // position must stay free for the first generated token.
        if tokens.len() >= Self::CONTEXT_TOKENS as usize {
            return Err(format!(
                "prompt is too long: {} tokens, but the context holds {}",
                tokens.len(),
                Self::CONTEXT_TOKENS
            ));
        }

        // Feed the prompt in batch-sized chunks so prompts longer than the
        // batch capacity are accepted. Logits are requested only for the very
        // last prompt token, which is the one the first sample is drawn from.
        let mut batch = LlamaBatch::new(Self::BATCH_TOKENS, 1);
        let mut n_cur: i32 = 0;
        let mut chunks = tokens.chunks(Self::BATCH_TOKENS).peekable();
        while let Some(chunk) = chunks.next() {
            let final_chunk = chunks.peek().is_none();
            batch.clear();
            for (i, token) in chunk.iter().enumerate() {
                let wants_logits = final_chunk && i + 1 == chunk.len();
                batch
                    .add(*token, n_cur, &[0], wants_logits)
                    .map_err(|e| format!("batch add failed: {}", e))?;
                n_cur += 1;
            }
            ctx.decode(&mut batch)
                .map_err(|e| format!("prompt decode failed: {}", e))?;
        }

        let mut sampler = LlamaSampler::greedy();
        let mut output = String::new();
        let mut decoder = UTF_8.new_decoder();

        for _ in 0..Self::MAX_NEW_TOKENS {
            // The last entry of the most recently decoded batch carries logits.
            let token = sampler.sample(&ctx, batch.n_tokens() - 1);
            sampler.accept(token);

            if model.is_eog_token(token) {
                break;
            }

            let piece = model
                .token_to_piece(token, &mut decoder, false, None)
                .map_err(|e| format!("token decode failed: {}", e))?;
            output.push_str(&piece);

            // No position left for another token: return what was generated
            // instead of failing the whole reply.
            if n_cur >= Self::CONTEXT_END {
                break;
            }

            batch.clear();
            batch
                .add(token, n_cur, &[0], true)
                .map_err(|e| format!("batch add failed: {}", e))?;

            n_cur += 1;
            ctx.decode(&mut batch)
                .map_err(|e| format!("decode failed: {}", e))?;
        }

        Ok(output)
    }

    /// Returns the resolved path of the model file (which may not exist).
    pub fn model_path(&self) -> &str {
        &self.model_path
    }

    /// Checks that the model file is present.
    ///
    /// Despite its name, this method does not fetch anything: it returns
    /// `Ok(())` when the file already exists.
    ///
    /// # Errors
    ///
    /// Returns a message telling the user how to obtain the model when the
    /// file is missing.
    pub fn download(&self) -> Result<(), String> {
        if self.is_available() {
            return Ok(());
        }
        Err("Model not downloaded. Run: flint use <model_name>".to_string())
    }
}