use std::num::NonZeroU32;
use std::path::{Path, PathBuf};
use encoding_rs::UTF_8;
use llama_cpp_2::context::params::LlamaContextParams;
use llama_cpp_2::llama_backend::LlamaBackend;
use llama_cpp_2::llama_batch::LlamaBatch;
use llama_cpp_2::model::params::LlamaModelParams;
use llama_cpp_2::model::{AddBos, LlamaModel};
use llama_cpp_2::sampling::LlamaSampler;
/// Handle to a local GGUF model file resolved under `$HOME/.flint/models`.
///
/// Construct with [`LocalAI::new`], check [`LocalAI::is_available`], then
/// call [`LocalAI::chat`].
#[derive(Debug, Clone)]
pub struct LocalAI {
    /// Resolved path to the `.gguf` file. The file may not exist yet; the
    /// constructor falls back to the expected location when nothing matches.
    model_path: String,
}
impl LocalAI {
    /// Context window requested from llama.cpp, in tokens.
    const CONTEXT_SIZE: u32 = 2048;
    /// Upper bound on the number of tokens generated for a single reply.
    const MAX_NEW_TOKENS: usize = 512;
    /// Maximum number of tokens submitted to one `decode` call (must be > 0).
    const BATCH_CAPACITY: usize = 512;

    /// Resolves `model_name` to a `.gguf` file under `$HOME/.flint/models`.
    ///
    /// Only the text after the last `/` is used, so `org/Model-GGUF` and
    /// `Model-GGUF` resolve identically. Resolution order:
    ///
    /// 1. `<models_dir>/<name>.gguf`, if it exists;
    /// 2. otherwise a `.gguf` file whose lowercased file name contains every
    ///    word of the lowercased name. Words come from stripping a trailing
    ///    `-gguf`/`_gguf` and splitting on `-`, `_` and `.`. Words of one byte
    ///    are ignored. When several files match, the lexicographically smallest
    ///    path is chosen so the result is deterministic;
    /// 3. otherwise the path from step 1, which does not exist yet. Use
    ///    [`LocalAI::is_available`] to check before calling [`LocalAI::chat`].
    ///
    /// The models directory is created if missing. This is best effort and
    /// failures are ignored.
    pub fn new(model_name: &str) -> Self {
        let last_part = model_name.rsplit('/').next().unwrap_or(model_name);
        let models_dir = Self::models_dir();
        // Best effort: if this fails, lookups below simply find nothing.
        let _ = std::fs::create_dir_all(&models_dir);

        let exact = models_dir.join(format!("{}.gguf", last_part));
        let resolved = if exact.exists() {
            exact
        } else {
            Self::find_fuzzy_match(&models_dir, last_part).unwrap_or(exact)
        };
        Self::from_path(&resolved)
    }

    /// Returns `$HOME/.flint/models`, or `./.flint/models` when `HOME` is unset.
    ///
    /// `var_os` is used so that a non-UTF-8 `HOME` is still honoured.
    fn models_dir() -> PathBuf {
        std::env::var_os("HOME")
            .map(PathBuf::from)
            .unwrap_or_else(|| PathBuf::from("."))
            .join(".flint")
            .join("models")
    }

    /// Wraps a resolved path, replacing non-UTF-8 components lossily.
    fn from_path(path: &Path) -> Self {
        Self {
            model_path: path.to_string_lossy().into_owned(),
        }
    }

    /// Finds a `.gguf` file in `models_dir` whose name contains every word of
    /// `name` (see [`LocalAI::new`] for the word rules).
    ///
    /// Returns `None` when the directory cannot be read, when nothing matches,
    /// or when `name` yields no usable words. Without that last check, an empty
    /// word list would match any file.
    fn find_fuzzy_match(models_dir: &Path, name: &str) -> Option<PathBuf> {
        let name_lower = name.to_lowercase();
        let core = name_lower
            .trim_end_matches("-gguf")
            .trim_end_matches("_gguf");
        let words: Vec<&str> = core
            .split(|c: char| c == '-' || c == '_' || c == '.')
            .filter(|w| w.len() > 1)
            .collect();
        if words.is_empty() {
            return None;
        }

        std::fs::read_dir(models_dir)
            .ok()?
            .flatten()
            .filter(|entry| {
                let fname = entry.file_name().to_string_lossy().to_lowercase();
                fname.ends_with(".gguf") && words.iter().all(|w| fname.contains(w))
            })
            .map(|entry| entry.path())
            // `read_dir` order is unspecified; pick a stable winner.
            .min()
    }

    /// Returns `true` if something exists at the resolved model path.
    pub fn is_available(&self) -> bool {
        Path::new(&self.model_path).exists()
    }

    /// Runs a single-turn chat completion against the local model.
    ///
    /// Every call initialises a fresh backend, loads the model and creates a
    /// new context. `message` is wrapped in a fixed
    /// `<|system|>`/`<|user|>`/`<|assistant|>` template. The reply is decoded
    /// greedily until one of these happens first:
    ///
    /// - an end-of-generation token is sampled;
    /// - `MAX_NEW_TOKENS` tokens have been produced;
    /// - the context window is full.
    ///
    /// # Errors
    ///
    /// Returns a human-readable message in any of these cases:
    ///
    /// - the model file is missing;
    /// - the backend, model or context cannot be initialised;
    /// - the prompt does not fit in the context window;
    /// - tokenization or decoding fails.
    pub fn chat(&self, message: &str) -> Result<String, String> {
        // NOTE(review): meant to quiet native logging. Whether llama.cpp reads
        // these variables is not visible here, so confirm. The change is
        // process-global.
        std::env::set_var("GGML_LOG_LEVEL", "error");
        std::env::set_var("LLAMA_LOG_LEVEL", "error");
        if !self.is_available() {
            return Err("Model not found. Run: flint use <model_name>".to_string());
        }

        let backend =
            LlamaBackend::init().map_err(|e| format!("failed to init backend: {}", e))?;
        let model = LlamaModel::load_from_file(
            &backend,
            Path::new(&self.model_path),
            &LlamaModelParams::default(),
        )
        .map_err(|e| format!("model load failed: {}", e))?;
        let ctx_params =
            LlamaContextParams::default().with_n_ctx(NonZeroU32::new(Self::CONTEXT_SIZE));
        let mut ctx = model
            .new_context(&backend, ctx_params)
            .map_err(|e| format!("failed to create context: {}", e))?;

        let prompt = Self::format_prompt(message);
        let tokens = model
            .str_to_token(&prompt, AddBos::Always)
            .map_err(|e| format!("tokenization failed: {}", e))?;
        if tokens.is_empty() {
            return Err("tokenization produced no tokens".to_string());
        }
        let context_size = Self::CONTEXT_SIZE as usize;
        // Leave room for at least one generated token.
        if tokens.len() >= context_size {
            return Err(format!(
                "prompt too long: {} tokens, context holds {}",
                tokens.len(),
                context_size
            ));
        }

        // Feed the prompt in fixed-size chunks, so prompts longer than one
        // batch still work. Only the final prompt token requests logits.
        let mut batch = LlamaBatch::new(Self::BATCH_CAPACITY, 1);
        let last_pos = tokens.len() - 1;
        for (chunk_index, chunk) in tokens.chunks(Self::BATCH_CAPACITY).enumerate() {
            batch.clear();
            let offset = chunk_index * Self::BATCH_CAPACITY;
            for (i, token) in chunk.iter().enumerate() {
                let pos = offset + i;
                // pos < CONTEXT_SIZE (checked above), so the cast cannot truncate.
                batch
                    .add(*token, pos as i32, &[0], pos == last_pos)
                    .map_err(|e| format!("batch add failed: {}", e))?;
            }
            ctx.decode(&mut batch)
                .map_err(|e| format!("prompt decode failed: {}", e))?;
        }

        let mut sampler = LlamaSampler::greedy();
        let mut output = String::new();
        // Position of the next token. It is below CONTEXT_SIZE, so the cast
        // cannot truncate.
        let mut n_cur = tokens.len() as i32;
        // One streaming decoder is reused across tokens, presumably so that
        // multi-byte characters split across tokens decode correctly.
        let mut decoder = UTF_8.new_decoder();
        // Each decoded token occupies one more context slot. Stop before
        // overflowing the window instead of failing mid-reply.
        let max_new_tokens = Self::MAX_NEW_TOKENS.min(context_size - tokens.len());
        for _ in 0..max_new_tokens {
            // The last entry in `batch` is the only one that requested logits.
            let token = sampler.sample(&ctx, batch.n_tokens() - 1);
            sampler.accept(token);
            if model.is_eog_token(token) {
                break;
            }
            let piece = model
                .token_to_piece(token, &mut decoder, false, None)
                .map_err(|e| format!("token decode failed: {}", e))?;
            output.push_str(&piece);

            batch.clear();
            batch
                .add(token, n_cur, &[0], true)
                .map_err(|e| format!("batch add failed: {}", e))?;
            n_cur += 1;
            ctx.decode(&mut batch)
                .map_err(|e| format!("decode failed: {}", e))?;
        }
        Ok(output)
    }

    /// Wraps `message` in the fixed chat template used by [`LocalAI::chat`].
    ///
    /// NOTE(review): the template is hard-coded. Models trained on a different
    /// chat format may respond poorly, so confirm it matches the model in use.
    fn format_prompt(message: &str) -> String {
        format!(
            "<|system|>\nYou are a helpful assistant.</s>\n<|user|>\n{}</s>\n<|assistant|>\n",
            message
        )
    }

    /// Returns the resolved model path. The file may not exist.
    pub fn model_path(&self) -> &str {
        &self.model_path
    }

    /// Checks that the model is present. It does not download anything.
    ///
    /// # Errors
    ///
    /// Returns an instruction message when the model file is not available.
    pub fn download(&self) -> Result<(), String> {
        if self.is_available() {
            return Ok(());
        }
        Err("Model not downloaded. Run: flint use <model_name>".to_string())
    }
}