use anyhow::Result;
use llama_cpp_rs::{
LlamaContext, LlamaContextParams, LlamaModel, LlamaModelParams,
SessionParams, TokenData, TokenDataArray,
};
use serde::{Deserialize, Serialize};
use std::path::PathBuf;
use std::sync::Arc;
use tauri::{AppHandle, Manager};
use tokio::sync::Mutex;
use tracing::{error, info};
/// Model identifier passed as the `model_name` of every progress event in this file.
const MODEL_ID: &str = "TinyLlama/TinyLlama-1.1B-Chat-v1.0";
/// Repository identifier used to build both the download URL and the name of the
/// local cache directory (with `/` replaced by `--`).
const GGUF_MODEL_ID: &str = "TheBloke/TinyLlama-1.1B-Chat-v1.0-GGUF";
/// File name of the GGUF weights, used in the download URL and inside the cache directory.
const MODEL_FILE: &str = "tinyllama-1.1b-chat-v1.0.Q4_K_M.gguf";
/// Payload emitted on the `model-loading-progress` event while a model is being
/// downloaded and loaded.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelLoadProgress {
    /// Progress value; the callers in this file pass values from 0 to 100.
    pub progress: u32,
    /// Human-readable description of the current step.
    pub status: String,
    /// Name of the model the event refers to.
    pub model_name: String,
}
/// Owns the loaded model and its inference context, and reports loading
/// progress through the application handle.
pub struct AIManager {
    /// Handle used to emit progress events.
    app_handle: AppHandle,
    /// Loaded model; `None` until `load_model` has completed successfully.
    model: Option<Arc<Mutex<LlamaModel>>>,
    /// Inference context created from the model; `None` until `load_model` has
    /// completed successfully.
    context: Option<Arc<Mutex<LlamaContext>>>,
    /// Location of the model file on disk; set by `load_model` once the file has
    /// been located or downloaded (before the model itself is loaded).
    model_path: Option<PathBuf>,
}
impl AIManager {
/// Creates a manager bound to `app_handle` with no model, context, or model
/// path set yet.
///
/// This constructor performs no fallible work and always returns `Ok`.
pub fn new(app_handle: AppHandle) -> Result<Self> {
    let manager = Self {
        model_path: None,
        context: None,
        model: None,
        app_handle,
    };
    Ok(manager)
}
/// Locates or downloads the model file, loads it, and creates an inference
/// context, storing both on `self`.
///
/// Progress is reported through `emit_progress` at 0, 80 and 100 (the download
/// step reports the values in between).
///
/// # Errors
///
/// Returns an error if the download fails, if the model path is not valid
/// UTF-8, or if loading the model or creating the context fails.
pub async fn load_model(&mut self) -> Result<()> {
    self.emit_progress(0, "Starting TinyLlama initialization...", MODEL_ID).await;
    let model_path = self.download_model().await?;
    self.model_path = Some(model_path.clone());
    self.emit_progress(80, "Loading TinyLlama into memory...", MODEL_ID).await;

    // The loader takes a `&str`; a cache directory containing non-UTF-8 bytes
    // must surface as an error instead of panicking the application.
    let model_path_str = model_path.to_str().ok_or_else(|| {
        anyhow::anyhow!("Model path is not valid UTF-8: {}", model_path.display())
    })?;

    let model_params = LlamaModelParams {
        n_gpu_layers: 32,
        main_gpu: 0,
        tensor_split: None,
        vocab_only: false,
        use_mmap: true,
        use_mlock: false,
    };
    let model = LlamaModel::load_from_file(model_path_str, model_params)
        .map_err(|e| anyhow::anyhow!("Failed to load model: {:?}", e))?;

    // NOTE(review): the zeroed rope/yarn values presumably defer to the
    // binding's or model's defaults — confirm against the binding.
    let ctx_params = LlamaContextParams {
        n_ctx: 2048,
        n_batch: 512,
        n_threads: 4,
        n_threads_batch: 4,
        rope_scaling_type: 0,
        rope_freq_base: 0.0,
        rope_freq_scale: 0.0,
        yarn_ext_factor: 0.0,
        yarn_attn_factor: 0.0,
        yarn_beta_fast: 0.0,
        yarn_beta_slow: 0.0,
        yarn_orig_ctx: 0,
        cb_eval: None,
        cb_eval_user_data: std::ptr::null_mut(),
        type_k: 0,
        type_v: 0,
        logits_all: false,
        embeddings: false,
        offload_kqv: true,
    };
    let context = model
        .new_context(ctx_params)
        .map_err(|e| anyhow::anyhow!("Failed to create context: {:?}", e))?;

    // Only publish the model and context once both have been created.
    self.model = Some(Arc::new(Mutex::new(model)));
    self.context = Some(Arc::new(Mutex::new(context)));
    self.emit_progress(100, "TinyLlama fully loaded and ready!", MODEL_ID).await;
    info!("TinyLlama model loaded successfully with Metal acceleration");
    Ok(())
}
/// Generates a reply to `prompt`, sampling at most `max_tokens` tokens.
///
/// The prompt is wrapped in the `<|system|>` / `<|user|>` / `<|assistant|>`
/// template below, the context is cleared, and tokens are sampled with
/// temperature 0.7 and top-p 0.9 until the end-of-sequence token, a textual
/// stop sequence, or the token limit is reached. The returned text is trimmed
/// and has any stop sequence and template markers removed.
///
/// # Errors
///
/// Returns an error if the model or context has not been loaded, or if
/// tokenization or evaluation fails.
pub async fn generate_response(&self, prompt: &str, max_tokens: usize) -> Result<String> {
    // Textual markers that end generation. They are cut from the output so a
    // reply never ends with a dangling "User:" / "Human:" turn header.
    const STOP_SEQUENCES: [&str; 3] = ["</s>", "\n\nUser:", "\n\nHuman:"];

    let model = self.model.as_ref()
        .ok_or_else(|| anyhow::anyhow!("Model not loaded"))?;
    let context = self.context.as_ref()
        .ok_or_else(|| anyhow::anyhow!("Context not loaded"))?;
    let formatted_prompt = format!(
        "<|system|>\nYou are a helpful AI assistant named Tektra.</s>\n<|user|>\n{}</s>\n<|assistant|>\n",
        prompt
    );
    info!("Generating response for prompt: {}", prompt);

    // Lock order is always model first, then context.
    let mut model_guard = model.lock().await;
    let mut context_guard = context.lock().await;

    let tokens = model_guard.tokenize(&formatted_prompt, true)
        .map_err(|e| anyhow::anyhow!("Tokenization failed: {:?}", e))?;
    context_guard.clear();
    context_guard.eval(&model_guard, &tokens, 1)
        .map_err(|e| anyhow::anyhow!("Prompt evaluation failed: {:?}", e))?;

    // The end-of-sequence token does not change during generation.
    let eos_token = model_guard.token_eos();
    let mut response_text = String::new();
    for _ in 0..max_tokens {
        let candidates = context_guard.candidates();
        let mut candidates_array = TokenDataArray::from_iter(
            candidates.iter().enumerate().map(|(id, &logit)| TokenData {
                id: id as i32,
                logit,
                p: 0.0,
            }),
        );
        candidates_array.sample_temperature(0.7);
        candidates_array.sample_top_p(0.9, 1);
        let token_id = candidates_array.sample_token(&mut rand::thread_rng());
        if token_id == eos_token {
            break;
        }

        // Best effort: a token that cannot be converted contributes no text.
        let token_str = model_guard.token_to_str(token_id)
            .unwrap_or_else(|_| String::new());
        response_text.push_str(&token_str);
        context_guard.eval(&model_guard, &[token_id], 1)
            .map_err(|e| anyhow::anyhow!("Token evaluation failed: {:?}", e))?;

        let matched_stop = STOP_SEQUENCES
            .iter()
            .copied()
            .find(|stop| response_text.ends_with(*stop));
        if let Some(stop) = matched_stop {
            // Drop the stop sequence itself before finishing.
            let kept = response_text.len() - stop.len();
            response_text.truncate(kept);
            break;
        }
    }

    // Remove any template markers that appeared mid-text, then trim.
    let response = response_text
        .replace("</s>", "")
        .replace("<|assistant|>", "")
        .trim()
        .to_string();
    info!("Generated response: {}", response);
    Ok(response)
}
/// Returns the path of the cached model file, downloading it first if needed.
///
/// A cached file larger than 100_000_000 bytes is reused as-is; a smaller one
/// is treated as incomplete and deleted. The download is streamed to a `.tmp`
/// file next to the final path and renamed into place only after it has been
/// fully written, so a partial download is never mistaken for a model.
///
/// Progress is reported from 10 to 70; during the download an event is
/// emitted only when the integer progress value changes.
///
/// # Errors
///
/// Returns an error if the cache directory cannot be determined or created,
/// if the HTTP request fails or returns a non-success status, if writing the
/// file fails, or if fewer/more bytes arrive than the server announced.
async fn download_model(&self) -> Result<PathBuf> {
    use futures::StreamExt;
    use tokio::io::AsyncWriteExt;

    self.emit_progress(10, "Checking for TinyLlama model...", MODEL_ID).await;
    let cache_dir = dirs::cache_dir()
        .ok_or_else(|| anyhow::anyhow!("Failed to get cache directory"))?
        .join("huggingface")
        .join("hub")
        .join(GGUF_MODEL_ID.replace('/', "--"));
    // Async filesystem calls keep the executor thread free.
    tokio::fs::create_dir_all(&cache_dir).await?;

    let model_path = cache_dir.join(MODEL_FILE);
    if let Ok(metadata) = tokio::fs::metadata(&model_path).await {
        if metadata.len() > 100_000_000 {
            self.emit_progress(70, "Found cached TinyLlama model", MODEL_ID).await;
            return Ok(model_path);
        }
        // Too small to be the real model: discard it (best effort) and re-download.
        let _ = tokio::fs::remove_file(&model_path).await;
    }

    self.emit_progress(20, "Downloading TinyLlama Q4 model (669MB)...", MODEL_ID).await;
    let url = format!(
        "https://huggingface.co/{}/resolve/main/{}",
        GGUF_MODEL_ID, MODEL_FILE
    );
    // NOTE(review): per the reqwest documentation this timeout spans the whole
    // request including streaming the body, so a download slower than 10
    // minutes fails — confirm this is intended.
    let client = reqwest::Client::builder()
        .timeout(std::time::Duration::from_secs(600))
        .user_agent("Tektra-AI-Assistant/0.1.0")
        .build()?;
    let response = client.get(&url).send().await?;
    if !response.status().is_success() {
        return Err(anyhow::anyhow!(
            "Failed to download model: HTTP {}",
            response.status()
        ));
    }

    // 0 means the server did not announce a length.
    let total_size = response.content_length().unwrap_or(0);
    let temp_path = model_path.with_extension("tmp");

    // The streaming step is grouped so that any failure inside it can be
    // followed by removing the temporary file.
    let download = async {
        let mut file = tokio::fs::File::create(&temp_path).await?;
        let mut stream = response.bytes_stream();
        let mut downloaded = 0u64;
        let mut last_progress = 0u32;
        while let Some(chunk) = stream.next().await {
            let chunk = chunk?;
            file.write_all(&chunk).await?;
            downloaded += chunk.len() as u64;
            if total_size > 0 {
                let fraction = (downloaded as f64 / total_size as f64).min(1.0);
                let progress = 20 + (fraction * 50.0) as u32;
                // Emit once per percentage step rather than once per chunk.
                if progress != last_progress {
                    last_progress = progress;
                    self.emit_progress(
                        progress,
                        &format!(
                            "Downloading TinyLlama ({} / {})",
                            bytesize::ByteSize(downloaded),
                            bytesize::ByteSize(total_size)
                        ),
                        MODEL_ID,
                    ).await;
                }
            }
        }
        file.flush().await?;
        if total_size > 0 && downloaded != total_size {
            anyhow::bail!(
                "Incomplete download: received {} of {} bytes",
                downloaded,
                total_size
            );
        }
        Ok::<(), anyhow::Error>(())
    };

    let outcome = download.await;
    if let Err(e) = outcome {
        // Best effort: do not leave a partial file behind.
        let _ = tokio::fs::remove_file(&temp_path).await;
        return Err(e);
    }

    tokio::fs::rename(&temp_path, &model_path).await?;
    Ok(model_path)
}
async fn emit_progress(&self, progress: u32, status: &str, model_name: &str) {
let progress_data = ModelLoadProgress {
progress,
status: status.to_string(),
model_name: model_name.to_string(),
};
if let Err(e) = self.app_handle.emit_all("model-loading-progress", &progress_data) {
error!("Failed to emit progress: {}", e);
}
}
/// Reports whether both the model and its context are currently set.
pub fn is_loaded(&self) -> bool {
    matches!((&self.model, &self.context), (Some(_), Some(_)))
}
}