use std::env;
use std::fs::File;
use std::io::{Read, Write};
use std::path::{Path, PathBuf};
/// Entry point for the `flint` CLI.
///
/// Dispatches on the first argument (`use`, `remove`, `list`). A failing
/// subcommand has its message written to stderr and the process exits with
/// status 1. An unrecognised subcommand prints the usage text and also exits
/// with status 1. Running with no subcommand prints the usage text and exits
/// successfully.
fn main() {
    let mut args = env::args();
    // Skip argv[0] (the program path).
    let _ = args.next();
    let command = match args.next() {
        Some(cmd) => cmd,
        None => {
            print_usage();
            return;
        }
    };
    let result = match command.as_str() {
        "use" => handle_use(args),
        "remove" => handle_remove(args),
        "list" => handle_list(),
        _ => {
            print_usage();
            std::process::exit(1);
        }
    };
    if let Err(msg) = result {
        // Diagnostics belong on stderr so they don't pollute piped stdout.
        eprintln!("{msg}");
        std::process::exit(1);
    }
}
fn handle_use(mut args: env::Args) -> Result<(), String> {
let model_name = match args.next() {
Some(name) => name,
None => {
return Err("Usage: flint use <owner/model-name>".to_string());
}
};
if !model_name.contains('/') {
return Err("invalid model name. use format: owner/model-name".to_string());
}
let filename = match fetch_gguf_filename(&model_name) {
Ok(name) => name,
Err(msg) => {
return Err(msg);
}
};
println!("found: {}", filename);
println!("downloading...");
let home = std::env::var("HOME").unwrap_or_else(|_| ".".to_string());
let models_dir = PathBuf::from(home).join(".flint").join("models");
std::fs::create_dir_all(&models_dir)
.map_err(|e| format!("failed to create models dir: {}", e))?;
let save_path = models_dir.join(&filename);
if save_path.exists() {
println!("{} already downloaded", filename);
return Ok(());
}
if let Err(err) = download_model(&model_name, &filename, &save_path) {
return Err(format!("download failed: {}", err));
}
println!("ready. use it in your code:");
println!(" let ai = LocalAI::new('microsoft/phi-3-mini');");
println!(" let response = ai.chat('your message here')?;");
Ok(())
}
fn download_model(model_name: &str, filename: &str, save_path: &Path) -> Result<(), String> {
let url = format!(
"https://huggingface.co/{}/resolve/main/{}",
model_name, filename
);
let client = reqwest::blocking::Client::new();
let mut response = client.get(url).send().map_err(|e| e.to_string())?;
if !response.status().is_success() {
return Err(format!("http {}", response.status()));
}
let total = response
.content_length()
.ok_or_else(|| "missing content length".to_string())?;
let mut file = File::create(save_path).map_err(|e| e.to_string())?;
let mut buffer = [0u8; 8192];
let mut downloaded: u64 = 0;
let mut next_report: u64 = 256 * 1024 * 1024;
loop {
let bytes_read = response.read(&mut buffer).map_err(|e| e.to_string())?;
if bytes_read == 0 {
break;
}
file.write_all(&buffer[..bytes_read])
.map_err(|e| e.to_string())?;
downloaded += bytes_read as u64;
if downloaded >= next_report {
println!(
"progress: {} / {}",
format_size(downloaded),
format_size(total)
);
next_report += 256 * 1024 * 1024;
}
}
println!("done. {} downloaded", format_size(total));
Ok(())
}
/// Queries the Hugging Face model API for `model_name`.
///
/// Returns the repo-relative path (`rfilename`) of the first file in the
/// `siblings` list whose name ends in `.gguf` (case-insensitive).
///
/// # Errors
/// Returns a message in these cases:
/// - huggingface.co cannot be reached;
/// - the API answers with a non-success status;
/// - the response is not JSON with a `siblings` array;
/// - no `.gguf` file is listed.
fn fetch_gguf_filename(model_name: &str) -> Result<String, String> {
    let url = format!("https://huggingface.co/api/models/{}", model_name);
    let client = reqwest::blocking::Client::new();
    // A transport failure is not the same as "model missing"; report it as such.
    let response = client
        .get(url)
        .send()
        .map_err(|e| format!("failed to reach huggingface.co: {}", e))?;
    if !response.status().is_success() {
        return Err(format!(
            "could not find a .gguf file in {}\nmake sure the model exists on huggingface.co",
            model_name
        ));
    }
    let body: serde_json::Value = response
        .json::<serde_json::Value>()
        .map_err(|e| format!("failed to parse api response: {}", e))?;
    let siblings = body
        .get("siblings")
        .and_then(serde_json::Value::as_array)
        .ok_or_else(|| "no files found in model repo".to_string())?;
    siblings
        .iter()
        .filter_map(|item| item.get("rfilename").and_then(serde_json::Value::as_str))
        .find(|name| name.to_ascii_lowercase().ends_with(".gguf"))
        .map(String::from)
        .ok_or_else(|| "no .gguf file found in this repo".to_string())
}
/// Renders a byte count for display.
///
/// Values of at least 1 GiB are shown with one decimal place in "GB". Smaller
/// values are shown as whole (rounded) mebibytes in "MB".
fn format_size(bytes: u64) -> String {
    const MIB: f64 = (1u64 << 20) as f64;
    const GIB: f64 = (1u64 << 30) as f64;
    let value = bytes as f64;
    if value < GIB {
        return format!("{:.0} MB", value / MIB);
    }
    format!("{:.1} GB", value / GIB)
}
/// Writes the `flint` command summary to stdout, one entry per line.
fn print_usage() {
    const USAGE_LINES: [&str; 5] = [
        "Usage:",
        " flint use <owner/model-name> download a model",
        " flint remove <owner/model-name> remove a model",
        " flint remove --all remove all models",
        " flint list list downloaded models",
    ];
    for line in USAGE_LINES.iter() {
        println!("{}", line);
    }
}
/// Returns the directory where downloaded models are stored:
/// `<home>/.flint/models`.
///
/// `<home>` is `$HOME`, then `%USERPROFILE%` (Windows sets no `HOME` by
/// default), then the current directory. `var_os` is used so that a home path
/// which is not valid UTF-8 is still honoured instead of silently falling
/// back to `.`.
fn models_dir() -> PathBuf {
    let home = std::env::var_os("HOME")
        .or_else(|| std::env::var_os("USERPROFILE"))
        .unwrap_or_else(|| ".".into());
    PathBuf::from(home).join(".flint").join("models")
}
fn handle_remove(mut args: env::Args) -> Result<(), String> {
let model_name = args.next().unwrap_or_default();
if model_name == "--all" {
let home = std::env::var("HOME").unwrap_or_else(|_| ".".to_string());
let models_dir = std::path::PathBuf::from(home).join(".flint").join("models");
if models_dir.exists() {
std::fs::remove_dir_all(&models_dir).unwrap_or(());
println!("removed all models from ~/.flint/models/");
} else {
println!("no models found");
}
return Ok(());
}
if model_name.is_empty() {
println!("Usage: flint remove <owner/model-name>");
println!(" flint remove --all");
return Ok(());
}
if !model_name.contains('/') {
println!("invalid model name. use format: owner/model-name");
return Ok(());
}
let last_part = model_name.rsplit('/').next().unwrap_or(&model_name);
let last_lower = last_part.to_lowercase();
let core = last_lower
.trim_end_matches("-gguf")
.trim_end_matches("_gguf");
let words: Vec<&str> = core
.split(|c: char| c == '-' || c == '_' || c == '.')
.filter(|w| w.len() > 1)
.collect();
let home = std::env::var("HOME").unwrap_or_else(|_| ".".to_string());
let models_dir = std::path::PathBuf::from(home).join(".flint").join("models");
if let Ok(entries) = std::fs::read_dir(&models_dir) {
for entry in entries.flatten() {
let fname = entry.file_name().to_string_lossy().to_lowercase();
if fname.ends_with(".gguf") {
let matches = words.iter().all(|w| fname.contains(w));
if matches {
std::fs::remove_file(entry.path()).unwrap_or(());
println!("removed: {}", entry.file_name().to_string_lossy());
return Ok(());
}
}
}
}
println!("model not found: {}", model_name);
Ok(())
}
/// Handles `flint list`.
///
/// Prints every file with a `gguf` extension (matched case-insensitively) in
/// the models directory, with its size in MB, sorted by file name. A file
/// whose metadata cannot be read is listed as 0 MB.
///
/// # Errors
/// Returns a message if the models directory exists but cannot be read.
fn handle_list() -> Result<(), String> {
    let dir = models_dir();
    if !dir.exists() {
        println!("no models downloaded. run: flint use <model>");
        return Ok(());
    }
    // An unreadable directory is an error, not an empty model list.
    let entries = std::fs::read_dir(&dir)
        .map_err(|e| format!("failed to read {}: {}", dir.display(), e))?;
    let mut items: Vec<(String, u64)> = entries
        .flatten()
        .filter(|entry| {
            entry
                .path()
                .extension()
                .and_then(|ext| ext.to_str())
                .map(|ext| ext.eq_ignore_ascii_case("gguf"))
                .unwrap_or(false)
        })
        .map(|entry| {
            let size = entry.metadata().map(|m| m.len()).unwrap_or(0);
            (entry.file_name().to_string_lossy().into_owned(), size)
        })
        .collect();
    if items.is_empty() {
        println!("no models downloaded. run: flint use <model>");
        return Ok(());
    }
    // read_dir order is platform-dependent; sort for stable output.
    items.sort();
    println!("downloaded models (~/.flint/models/):");
    for (name, size) in items {
        println!(" {:<40} {}", name, format_size_mb(size));
    }
    Ok(())
}
/// Formats a byte count as whole (rounded) mebibytes, e.g. `"512 MB"`.
fn format_size_mb(bytes: u64) -> String {
    const BYTES_PER_MIB: f64 = 1_048_576.0;
    let mebibytes = bytes as f64 / BYTES_PER_MIB;
    format!("{:.0} MB", mebibytes)
}