use llama_cpp_2::{
context::params::LlamaContextParams,
llama_backend::LlamaBackend,
llama_batch::LlamaBatch,
model::{params::LlamaModelParams, AddBos, LlamaModel},
sampling::LlamaSampler,
};
use std::path::PathBuf;
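/// Runs a local LLM through llama.cpp (via the `llama_cpp_2` bindings).
/// The backend and model weights are loaded lazily on first use and cached
/// under `./.cache/models`.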
pub struct Local {
backend: Option<LlamaBackend>,
model: Option<LlamaModel>,
model_path: PathBuf,
initialized: bool,
temperature: f32,
}
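/// Errors that can occur while preparing the local model.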
#[derive(Debug)]
pub enum LocalError {
Load(String),
Download(String),
}
impl std::fmt::Display for LocalError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
LocalError::Load(msg) => write!(f, "Model load error: {msg}"),
LocalError::Download(msg) => write!(f, "Download error: {msg}"),
}
}
}
impl std::error::Error for LocalError {}
impl Default for Local {
fn default() -> Self {
Self::new()
}
}
impl Local {
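/// Creates an uninitialized runner pointing at the default cache path for the
/// bundled GGUF model. No backend or weights are loaded until `initialize`.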
pub fn new() -> Self {
let cache_dir = std::env::current_dir()
.unwrap_or_else(|_| PathBuf::from("."))
.join(".cache")
.join("models");
let model_path = cache_dir.join("Qwen3.5-2B-Q4_K_M.gguf");
Self {
backend: None,
model: None,
model_path,
initialized: false,
temperature: 0.6,
}
}
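/// Builder-style override for the sampling temperature; `0.0` selects greedy decoding.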
pub fn with_temperature(mut self, temperature: f32) -> Self {
self.temperature = temperature;
self
}
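/// Downloads the model file if it is missing, initializes the llama.cpp backend,
/// and loads the GGUF weights. Called automatically by `generate` when needed.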
pub async fn initialize(&mut self) -> Result<(), LocalError> {
self.ensure_model_exists().await?;
let backend = LlamaBackend::init()
.map_err(|e| LocalError::Load(format!("Failed to initialize backend: {e:?}")))?;
self.backend = Some(backend);
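// Offload every layer to the GPU when a GPU feature is compiled in; otherwise stay on CPU.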
#[cfg(any(feature = "cuda", feature = "metal", feature = "vulkan"))]
let gpu_layers = u32::MAX;
#[cfg(not(any(feature = "cuda", feature = "metal", feature = "vulkan")))]
let gpu_layers = 0u32;
if gpu_layers > 0 {
println!("GPU backend enabled: offloading all model layers to the GPU");
}
let model_params = LlamaModelParams::default().with_n_gpu_layers(gpu_layers);
let model = LlamaModel::load_from_file(
self.backend.as_ref().unwrap(),
&self.model_path,
&model_params,
)
.map_err(|e| LocalError::Load(format!("Failed to load Qwen 3.5 model: {e:?}")))?;
self.model = Some(model);
self.initialized = true;
Ok(())
}
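/// Uses the cached model file when present; otherwise downloads it first.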
async fn ensure_model_exists(&self) -> Result<(), LocalError> {
if self.model_path.exists() {
return Ok(());
}
println!("Preparing AI model for document analysis...");
self.download_model().await?;
Ok(())
}
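/// Streams the GGUF file from Hugging Face into the local cache, printing coarse progress.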
async fn download_model(&self) -> Result<(), LocalError> {
use futures_util::StreamExt;
use tokio::io::AsyncWriteExt;
if let Some(parent) = self.model_path.parent() {
std::fs::create_dir_all(parent)
.map_err(|e| LocalError::Download(format!("Failed to create cache dir: {e}")))?;
}
let url =
"https://huggingface.co/unsloth/Qwen3.5-2B-GGUF/resolve/main/Qwen3.5-2B-Q4_K_M.gguf";
let client = reqwest::Client::new();
let response = client
.get(url)
.send()
.await
.map_err(|e| LocalError::Download(format!("HTTP request failed: {e}")))?;
if !response.status().is_success() {
return Err(LocalError::Download(format!(
"HTTP error: {}",
response.status()
)));
}
let total_size = response.content_length().unwrap_or(0);
// Stream into a temporary file first so an interrupted download never leaves a
// partial file that a later run would mistake for a complete model.
let tmp_path = self.model_path.with_extension("gguf.part");
let mut file = tokio::fs::File::create(&tmp_path)
.await
.map_err(|e| LocalError::Download(format!("Failed to create file: {e}")))?;
let mut downloaded = 0u64;
let mut stream = response.bytes_stream();
while let Some(chunk) = stream.next().await {
let chunk =
chunk.map_err(|e| LocalError::Download(format!("Download chunk error: {e}")))?;
file.write_all(&chunk)
.await
.map_err(|e| LocalError::Download(format!("Write error: {e}")))?;
downloaded += chunk.len() as u64;
if total_size > 0 {
let progress = (downloaded as f64 / total_size as f64) * 100.0;
print!("\rDownloading AI model: {progress:.0}%");
use std::io::{self, Write};
// Best-effort flush so the in-place progress line updates; ignore a broken stdout.
let _ = io::stdout().flush();
}
}
file.flush()
.await
.map_err(|e| LocalError::Download(format!("Flush error: {e}")))?;
tokio::fs::rename(&tmp_path, &self.model_path)
.await
.map_err(|e| LocalError::Download(format!("Failed to move model into place: {e}")))?;
println!("\nAI model ready for document analysis.");
Ok(())
}
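/// Runs one chat-style completion and returns the assistant's reply with the
/// `<think>...</think>` reasoning block stripped. Failures are returned as plain
/// strings so callers always get displayable text.
///
/// A minimal usage sketch (assuming a Tokio runtime and enough memory for the model):
///
/// ```ignore
/// let mut local = Local::new().with_temperature(0.2);
/// let reply = local.generate("Summarize this resume: ...").await;
/// println!("{reply}");
/// ```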
pub async fn generate(&mut self, prompt: &str) -> String {
if !self.initialized {
if let Err(e) = self.initialize().await {
return format!("Initialization error: {e}");
}
}
let model = match &self.model {
Some(model) => model,
None => {
return "Model not loaded.".to_string();
}
};
let backend = match &self.backend {
Some(backend) => backend,
None => {
return "Backend not initialized.".to_string();
}
};
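// A fresh context per call; n_ctx bounds the prompt plus everything generated below.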
let context_params =
LlamaContextParams::default().with_n_ctx(std::num::NonZeroU32::new(81920));
let mut context = match model.new_context(backend, context_params) {
Ok(ctx) => ctx,
Err(e) => {
return format!("Failed to create context: {e:?}");
}
};
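// Qwen-style ChatML prompt; the trailing "<think>\n" opens the model's thinking block,
// which extract_thinking_and_response strips from the final answer.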
let formatted_prompt = format!(
"<|im_start|>system\nYou are a helpful AI assistant specialized in analyzing resumes and CVs.<|im_end|>\n<|im_start|>user\n{prompt}<|im_end|>\n<|im_start|>assistant\n<think>\n"
);
let tokens = match model.str_to_token(&formatted_prompt, AddBos::Always) {
Ok(tokens) => tokens,
Err(e) => {
return format!("Tokenization error: {e:?}");
}
};
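// Size the batch to the tokenized prompt so long documents are not rejected at an
// arbitrary limit; the same batch is reused one token at a time while generating.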
let mut batch = LlamaBatch::new(tokens.len().max(1), 1);
let last_index = tokens.len() as i32 - 1;
for (i, token) in (0_i32..).zip(tokens.into_iter()) {
let is_last = i == last_index;
if let Err(e) = batch.add(token, i, &[0], is_last) {
return format!("Batch add error: {e:?}");
}
}
if let Err(e) = context.decode(&mut batch) {
return format!("Decode error: {e:?}");
}
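// Prompt decoded; now generate the reply one token at a time until EOS or an
// end-of-turn marker appears.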
let mut response = String::new();
let mut n_cur = batch.n_tokens();
// Bound generation by the remaining context window rather than an arbitrary constant.
let max_tokens = context.n_ctx() as i32 - n_cur;
let mut decoder = encoding_rs::UTF_8.new_decoder();
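// Greedy decoding at temperature zero; otherwise temperature sampling with a fixed
// seed so repeated runs are reproducible.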
let mut sampler = if self.temperature <= 0.0 {
LlamaSampler::greedy()
} else {
LlamaSampler::chain_simple([
LlamaSampler::temp(self.temperature),
LlamaSampler::dist(1234),
])
};
for _ in 0..max_tokens {
let new_token = sampler.sample(&context, batch.n_tokens() - 1);
sampler.accept(new_token);
if new_token == model.token_eos() {
break;
}
if let Ok(token_str) = model.token_to_piece(new_token, &mut decoder, true, None) {
response.push_str(&token_str);
if response.contains("<|im_end|>") || response.contains("</s>") {
break;
}
}
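// Feed the sampled token back in as a single-token batch at position n_cur.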
batch.clear();
if let Err(e) = batch.add(new_token, n_cur, &[0], true) {
return format!("Batch error: {e:?}");
}
n_cur += 1;
if let Err(e) = context.decode(&mut batch) {
return format!("Decode error: {e:?}");
}
}
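// Keep only the text before the first end-of-turn marker, then strip the thinking block.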
let cleaned = response
.split("<|im_end|>")
.next()
.unwrap_or(&response)
.trim()
.to_string();
self.extract_thinking_and_response(&cleaned)
}
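/// Removes the model's `<think>...</think>` reasoning block, returning only the
/// text that follows it.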
fn extract_thinking_and_response(&self, text: &str) -> String {
// The prompt already opens a `<think>` block, so the reasoning ends at the first
// `</think>`; everything after it is the actual answer.
if let Some(think_end) = text.find("</think>") {
return text[think_end + "</think>".len()..].trim().to_string();
}
// No closing tag: the model skipped or truncated its thinking, so return the text as-is.
text.trim().to_string()
}
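/// Returns true once the backend and model are loaded and the weights file still exists on disk.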
pub fn is_ready(&self) -> bool {
self.initialized
&& self.backend.is_some()
&& self.model.is_some()
&& self.model_path.exists()
}
}