flint-ai 0.1.0

A lightweight embedded AI runtime for every device
Documentation
use std::env;
use std::fs::File;
use std::io::{Read, Write};
use std::path::{Path, PathBuf};


/// Entry point for the `flint` CLI.
///
/// Dispatches on the first argument (`use`, `remove`, `list`, `help`). A
/// subcommand error is printed to stderr and the process exits with status 1.
/// An unknown subcommand prints the usage text and exits with status 2. With
/// no subcommand the usage text is printed and the process exits normally.
fn main() {
    let mut args = env::args();
    // Skip the program name.
    let _ = args.next();

    let command = match args.next() {
        Some(cmd) => cmd,
        None => {
            print_usage();
            return;
        }
    };

    let result = match command.as_str() {
        "use" => handle_use(args),
        "remove" => handle_remove(args),
        "list" => handle_list(),
        "help" | "-h" | "--help" => {
            print_usage();
            Ok(())
        }
        _ => {
            eprintln!("unknown command: {command}");
            print_usage();
            std::process::exit(2);
        }
    };

    if let Err(msg) = result {
        // Diagnostics belong on stderr so stdout stays clean for scripting.
        eprintln!("{msg}");
        std::process::exit(1);
    }
}

fn handle_use(mut args: env::Args) -> Result<(), String> {
    let model_name = match args.next() {
        Some(name) => name,
        None => {
            return Err("Usage: flint use <owner/model-name>".to_string());
        }
    };

    if !model_name.contains('/') {
        return Err("invalid model name. use format: owner/model-name".to_string());
    }

    let filename = match fetch_gguf_filename(&model_name) {
        Ok(name) => name,
        Err(msg) => {
            return Err(msg);
        }
    };

    println!("found: {}", filename);
    println!("downloading...");

    let home = std::env::var("HOME").unwrap_or_else(|_| ".".to_string());
    let models_dir = PathBuf::from(home).join(".flint").join("models");
    std::fs::create_dir_all(&models_dir)
        .map_err(|e| format!("failed to create models dir: {}", e))?;
    let save_path = models_dir.join(&filename);

    if save_path.exists() {
        println!("{} already downloaded", filename);
        return Ok(());
    }

    if let Err(err) = download_model(&model_name, &filename, &save_path) {
        return Err(format!("download failed: {}", err));
    }

    println!("ready. use it in your code:");
    println!("     let ai = LocalAI::new('microsoft/phi-3-mini');");
    println!("     let response = ai.chat('your message here')?;");
    Ok(())
}

fn download_model(model_name: &str, filename: &str, save_path: &Path) -> Result<(), String> {
    let url = format!(
        "https://huggingface.co/{}/resolve/main/{}",
        model_name, filename
    );

    let client = reqwest::blocking::Client::new();
    let mut response = client.get(url).send().map_err(|e| e.to_string())?;
    if !response.status().is_success() {
        return Err(format!("http {}", response.status()));
    }

    let total = response
        .content_length()
        .ok_or_else(|| "missing content length".to_string())?;

    let mut file = File::create(save_path).map_err(|e| e.to_string())?;
    let mut buffer = [0u8; 8192];
    let mut downloaded: u64 = 0;
    let mut next_report: u64 = 256 * 1024 * 1024;
    loop {
        let bytes_read = response.read(&mut buffer).map_err(|e| e.to_string())?;
        if bytes_read == 0 {
            break;
        }
        file.write_all(&buffer[..bytes_read])
            .map_err(|e| e.to_string())?;
        downloaded += bytes_read as u64;
        if downloaded >= next_report {
            println!(
                "progress: {} / {}",
                format_size(downloaded),
                format_size(total)
            );
            next_report += 256 * 1024 * 1024;
        }
    }

    println!("done. {} downloaded", format_size(total));
    Ok(())
}

/// Queries the Hugging Face model API for `model_name`.
///
/// Returns the repo-relative path of the first file in the response's
/// `siblings` array whose name ends in `.gguf` (case-insensitive). The first
/// match in API order is taken; no particular quantization is preferred.
/// NOTE(review): repos that ship split GGUF shards would only yield the first
/// shard here — confirm whether that case matters.
///
/// # Errors
/// Returns a message when:
/// - the request cannot be sent;
/// - the repo does not exist (HTTP 404);
/// - the API returns another non-success status;
/// - the response is not JSON or has no `siblings` array;
/// - no `.gguf` file is listed.
fn fetch_gguf_filename(model_name: &str) -> Result<String, String> {
    let url = format!("https://huggingface.co/api/models/{}", model_name);
    let client = reqwest::blocking::Client::new();
    let response = client
        .get(url)
        .send()
        .map_err(|e| format!("failed to reach huggingface.co: {}", e))?;

    let status = response.status();
    if status.as_u16() == 404 {
        return Err(format!(
            "model {} not found\nmake sure the model exists on huggingface.co",
            model_name
        ));
    }
    if !status.is_success() {
        // e.g. 401/403 for gated or private repos — report it instead of
        // pretending the repo has no .gguf file.
        return Err(format!(
            "huggingface api returned {} for {}",
            status, model_name
        ));
    }

    let body: serde_json::Value = response
        .json::<serde_json::Value>()
        .map_err(|e| format!("failed to parse api response: {}", e))?;

    let siblings = body
        .get("siblings")
        .and_then(|v: &serde_json::Value| v.as_array())
        .ok_or_else(|| "no files found in model repo".to_string())?;

    siblings
        .iter()
        .filter_map(|item| {
            item.get("rfilename")
                .and_then(|v: &serde_json::Value| v.as_str())
        })
        .find(|name| name.to_ascii_lowercase().ends_with(".gguf"))
        .map(str::to_string)
        .ok_or_else(|| "no .gguf file found in this repo".to_string())
}

/// Renders a byte count for humans using 1024-based units.
///
/// Values of at least 1 GiB are shown as gigabytes with one decimal place
/// (`"4.2 GB"`). Smaller values are shown as whole megabytes (`"512 MB"`).
fn format_size(bytes: u64) -> String {
    const MIB: f64 = 1_048_576.0;
    const GIB: f64 = 1_073_741_824.0;

    match bytes as f64 {
        value if value >= GIB => format!("{:.1} GB", value / GIB),
        value => format!("{:.0} MB", value / MIB),
    }
}

/// Writes the CLI's usage summary to stdout, one subcommand per line.
fn print_usage() {
    const USAGE_LINES: [&str; 5] = [
        "Usage:",
        "  flint use <owner/model-name>     download a model",
        "  flint remove <owner/model-name>  remove a model",
        "  flint remove --all               remove all models",
        "  flint list                       list downloaded models",
    ];
    for line in USAGE_LINES.iter() {
        println!("{line}");
    }
}

/// Returns the directory where downloaded models are stored,
/// `<home>/.flint/models`.
///
/// `<home>` is resolved as follows (empty values are skipped):
/// 1. the `HOME` environment variable;
/// 2. otherwise `USERPROFILE`, the usual variable on Windows, where `HOME` is
///    normally unset;
/// 3. otherwise the current directory (`.`).
///
/// Non-UTF-8 values are accepted as-is.
fn models_dir() -> PathBuf {
    let home = ["HOME", "USERPROFILE"]
        .iter()
        .filter_map(|key| std::env::var_os(key))
        .find(|value| !value.is_empty())
        .map(PathBuf::from)
        .unwrap_or_else(|| PathBuf::from("."));
    home.join(".flint").join("models")
}

fn handle_remove(mut args: env::Args) -> Result<(), String> {
    let model_name = args.next().unwrap_or_default();

    if model_name == "--all" {
        let home = std::env::var("HOME").unwrap_or_else(|_| ".".to_string());
        let models_dir = std::path::PathBuf::from(home).join(".flint").join("models");
        if models_dir.exists() {
            std::fs::remove_dir_all(&models_dir).unwrap_or(());
            println!("removed all models from ~/.flint/models/");
        } else {
            println!("no models found");
        }
        return Ok(());
    }

    if model_name.is_empty() {
        println!("Usage: flint remove <owner/model-name>");
        println!("       flint remove --all");
        return Ok(());
    }

    if !model_name.contains('/') {
        println!("invalid model name. use format: owner/model-name");
        return Ok(());
    }

    let last_part = model_name.rsplit('/').next().unwrap_or(&model_name);
    let last_lower = last_part.to_lowercase();
    let core = last_lower
        .trim_end_matches("-gguf")
        .trim_end_matches("_gguf");
    let words: Vec<&str> = core
        .split(|c: char| c == '-' || c == '_' || c == '.')
        .filter(|w| w.len() > 1)
        .collect();

    let home = std::env::var("HOME").unwrap_or_else(|_| ".".to_string());
    let models_dir = std::path::PathBuf::from(home).join(".flint").join("models");

    if let Ok(entries) = std::fs::read_dir(&models_dir) {
        for entry in entries.flatten() {
            let fname = entry.file_name().to_string_lossy().to_lowercase();
            if fname.ends_with(".gguf") {
                let matches = words.iter().all(|w| fname.contains(w));
                if matches {
                    std::fs::remove_file(entry.path()).unwrap_or(());
                    println!("removed: {}", entry.file_name().to_string_lossy());
                    return Ok(());
                }
            }
        }
    }

    println!("model not found: {}", model_name);
    Ok(())
}

/// Handles `flint list`.
///
/// Prints every file in the models directory whose extension is `gguf`
/// (case-insensitive), with its size, sorted by file name. A hint is printed
/// instead when the directory is missing or holds no such files.
///
/// # Errors
/// Returns a message if the models directory exists but cannot be read.
fn handle_list() -> Result<(), String> {
    const EMPTY_HINT: &str = "no models downloaded. run: flint use <model>";

    let dir = models_dir();
    if !dir.exists() {
        println!("{}", EMPTY_HINT);
        return Ok(());
    }

    let entries = std::fs::read_dir(&dir)
        .map_err(|e| format!("failed to read {}: {}", dir.display(), e))?;

    let mut items: Vec<(String, u64)> = entries
        .flatten()
        .filter(|entry| {
            entry
                .path()
                .extension()
                .and_then(|ext| ext.to_str())
                .map(|ext| ext.eq_ignore_ascii_case("gguf"))
                .unwrap_or(false)
        })
        .map(|entry| {
            // A file whose metadata cannot be read is still listed, as 0 MB.
            let size = entry.metadata().map(|m| m.len()).unwrap_or(0);
            (entry.file_name().to_string_lossy().into_owned(), size)
        })
        .collect();

    if items.is_empty() {
        println!("{}", EMPTY_HINT);
        return Ok(());
    }

    // `read_dir` order is platform-dependent; sort for stable output.
    items.sort_by(|a, b| a.0.cmp(&b.0));

    println!("downloaded models (~/.flint/models/):");
    for (name, size) in &items {
        println!("  {:<40} {}", name, format_size_mb(*size));
    }
    Ok(())
}

/// Renders a byte count as whole binary megabytes (1 MB = 1,048,576 bytes),
/// e.g. `"4096 MB"`.
fn format_size_mb(bytes: u64) -> String {
    let megabytes = bytes as f64 / 1_048_576.0;
    format!("{megabytes:.0} MB")
}