use std::future::Future;
use std::path::PathBuf;
use std::pin::Pin;
use std::sync::{Arc, Mutex};
use llama_cpp_2::context::params::LlamaContextParams;
use llama_cpp_2::context::LlamaContext;
use llama_cpp_2::llama_backend::LlamaBackend;
use llama_cpp_2::llama_batch::LlamaBatch;
use llama_cpp_2::model::params::LlamaModelParams;
use llama_cpp_2::model::{AddBos, LlamaModel, Special};
use llama_cpp_2::sampling::LlamaSampler;
use crate::decider::{DecisionResponse, LlmDecider, LlmError, WorkerDecisionRequest};
use crate::prompt_builder::PromptBuilder;
use crate::response_parser;
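/// Configuration for running a GGUF model in-process through llama.cpp.
///
/// `model_path` is either a Hugging Face repo id (with `gguf_file` naming the
/// file to fetch) or a path to a local `.gguf` file (with `gguf_file` left as
/// `None`).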
#[derive(Debug, Clone)]
pub struct LlamaCppStandaloneConfig {
pub model_path: String,
pub gguf_file: Option<String>,
pub n_ctx: u32,
pub n_batch: u32,
pub n_gpu_layers: u32,
pub max_tokens: usize,
pub temperature: f32,
pub top_p: f32,
pub n_threads: Option<u32>,
}
impl Default for LlamaCppStandaloneConfig {
fn default() -> Self {
Self {
model_path: String::new(),
gguf_file: None,
n_ctx: 4096,
n_batch: 512,
n_gpu_layers: 0,
max_tokens: 256,
temperature: 0.7,
top_p: 0.9,
n_threads: None,
}
}
}
impl LlamaCppStandaloneConfig {
pub fn from_hf(repo_id: impl Into<String>, gguf_file: impl Into<String>) -> Self {
Self {
model_path: repo_id.into(),
gguf_file: Some(gguf_file.into()),
..Default::default()
}
}
pub fn from_local(path: impl Into<String>) -> Self {
Self {
model_path: path.into(),
gguf_file: None,
..Default::default()
}
}
pub fn with_gpu_layers(mut self, n_layers: u32) -> Self {
self.n_gpu_layers = n_layers;
self
}
pub fn with_context_size(mut self, n_ctx: u32) -> Self {
self.n_ctx = n_ctx;
self
}
pub fn with_max_tokens(mut self, max_tokens: usize) -> Self {
self.max_tokens = max_tokens;
self
}
pub fn with_temperature(mut self, temperature: f32) -> Self {
self.temperature = temperature;
self
}
pub fn with_top_p(mut self, top_p: f32) -> Self {
self.top_p = top_p;
self
}
pub fn lfm2_1b() -> Self {
Self::from_hf(
"LiquidAI/LFM2.5-1.2B-Instruct-GGUF",
"LFM2.5-1.2B-Instruct-Q4_K_M.gguf",
)
}
pub fn lfm2_1b_q8() -> Self {
Self::from_hf(
"LiquidAI/LFM2.5-1.2B-Instruct-GGUF",
"LFM2.5-1.2B-Instruct-Q8_0.gguf",
)
}
pub fn qwen_0_5b() -> Self {
Self::from_hf(
"Qwen/Qwen2.5-0.5B-Instruct-GGUF",
"qwen2.5-0.5b-instruct-q4_k_m.gguf",
)
}
pub fn qwen_1_5b() -> Self {
Self::from_hf(
"Qwen/Qwen2.5-1.5B-Instruct-GGUF",
"qwen2.5-1.5b-instruct-q4_k_m.gguf",
)
}
pub fn phi3_mini() -> Self {
Self::from_hf(
"microsoft/Phi-3-mini-4k-instruct-gguf",
"Phi-3-mini-4k-instruct-q4.gguf",
)
}
pub fn display_name(&self) -> &str {
&self.model_path
}
}
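/// An [`LlmDecider`] backed by an in-process llama.cpp model.
///
/// The inference context lives behind a mutex, so concurrent calls are
/// serialized; clones are cheap and share the same loaded model.
///
/// Illustrative usage (not compiled here; assumes the GGUF file is already
/// cached locally or can be downloaded from Hugging Face):
///
/// ```ignore
/// let config = LlamaCppStandaloneConfig::lfm2_1b().with_max_tokens(128);
/// let decider = LlamaCppStandaloneDecider::new(config)?;
/// let reply = decider.call_raw("Reply with OK").await?;
/// ```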
#[derive(Clone)]
pub struct LlamaCppStandaloneDecider {
inner: Arc<LlamaCppInner>,
prompt_builder: PromptBuilder,
}
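/// Owns the backend, model, and `'static`-transmuted context together so they
/// share one lifetime. Field order matters: `context` is declared before
/// `model` so it is dropped first.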
struct LlamaCppInner {
context: Mutex<Option<LlamaContext<'static>>>,
#[allow(dead_code)]
backend: LlamaBackend,
model: LlamaModel,
config: LlamaCppStandaloneConfig,
}
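// SAFETY: `LlamaContext` wraps raw llama.cpp pointers and is not `Send`/`Sync`
// on its own. All access goes through the `Mutex`, and the model and backend
// the context points into live in the same struct, so sharing `LlamaCppInner`
// across threads is sound as long as decoding stays behind the lock.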
unsafe impl Send for LlamaCppInner {}
unsafe impl Sync for LlamaCppInner {}
impl LlamaCppStandaloneDecider {
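    /// Initializes the llama.cpp backend, loads the model (downloading it from
    /// Hugging Face if necessary), and creates a single inference context.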
pub fn new(config: LlamaCppStandaloneConfig) -> Result<Self, LlmError> {
let backend = LlamaBackend::init()
.map_err(|e| LlmError::permanent(format!("Failed to init llama backend: {}", e)))?;
let model_path = Self::resolve_model_path(&config)?;
let model_params = LlamaModelParams::default().with_n_gpu_layers(config.n_gpu_layers);
let model = LlamaModel::load_from_file(&backend, &model_path, &model_params)
.map_err(|e| LlmError::permanent(format!("Failed to load model: {}", e)))?;
let ctx_params = LlamaContextParams::default()
.with_n_ctx(std::num::NonZeroU32::new(config.n_ctx))
.with_n_batch(config.n_batch);
let context = model
.new_context(&backend, ctx_params)
.map_err(|e| LlmError::permanent(format!("Failed to create context: {}", e)))?;
let context: LlamaContext<'static> = unsafe { std::mem::transmute(context) };
let inner = LlamaCppInner {
context: Mutex::new(Some(context)),
backend,
model,
config,
};
Ok(Self {
inner: Arc::new(inner),
prompt_builder: PromptBuilder::new(),
})
}
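    /// Resolves the configured model to an on-disk GGUF path.
    ///
    /// Hugging Face configs are checked against the local hub cache first and
    /// downloaded only if missing; local configs just validate that the file
    /// exists.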
fn resolve_model_path(config: &LlamaCppStandaloneConfig) -> Result<PathBuf, LlmError> {
if let Some(ref gguf_file) = config.gguf_file {
let cache_dir = dirs::cache_dir()
.unwrap_or_else(|| PathBuf::from("."))
.join("huggingface")
.join("hub");
let model_dir_name = format!("models--{}", config.model_path.replace('/', "--"));
let model_dir = cache_dir.join(&model_dir_name);
let snapshots_dir = model_dir.join("snapshots");
if snapshots_dir.exists() {
if let Ok(entries) = std::fs::read_dir(&snapshots_dir) {
for entry in entries.flatten() {
let snapshot_path = entry.path().join(gguf_file);
if snapshot_path.exists() {
tracing::info!("Using cached model: {:?}", snapshot_path);
return Ok(snapshot_path);
}
}
}
}
tracing::info!(
"Model not in cache, downloading from HuggingFace: {}",
config.model_path
);
let api = hf_hub::api::sync::Api::new()
.map_err(|e| LlmError::permanent(format!("Failed to create HF API: {}", e)))?;
let repo = api.model(config.model_path.clone());
let path = repo
.get(gguf_file)
.map_err(|e| LlmError::permanent(format!("Failed to download model: {}", e)))?;
Ok(path)
} else {
let path = PathBuf::from(&config.model_path);
if !path.exists() {
return Err(LlmError::permanent(format!(
"Model file not found: {}",
config.model_path
)));
}
Ok(path)
}
}
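    /// Blocking prompt prefill and token-by-token generation on the given
    /// context. The KV cache is cleared first, so each call is an independent
    /// completion.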
fn generate_sync(
inner: &LlamaCppInner,
context: &mut LlamaContext,
prompt: &str,
) -> Result<String, LlmError> {
context.clear_kv_cache();
let tokens = inner
.model
.str_to_token(prompt, AddBos::Always)
.map_err(|e| LlmError::permanent(format!("Tokenization error: {}", e)))?;
tracing::debug!(
tokens_len = tokens.len(),
prompt_len = prompt.len(),
"Tokenized prompt"
);
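        // Prefill: feed the whole prompt as one batch; only the final position
        // needs logits for sampling.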
let mut batch = LlamaBatch::new(inner.config.n_batch as usize, 1);
for (i, token) in tokens.iter().enumerate() {
let is_last = i == tokens.len() - 1;
batch
.add(*token, i as i32, &[0], is_last)
.map_err(|e| LlmError::permanent(format!("Batch add error: {}", e)))?;
}
context
.decode(&mut batch)
.map_err(|e| LlmError::permanent(format!("Decode error: {}", e)))?;
tracing::debug!("Prompt decoded, starting generation");
let mut sampler = LlamaSampler::chain_simple([
LlamaSampler::temp(inner.config.temperature),
LlamaSampler::top_p(inner.config.top_p, 1),
LlamaSampler::dist(42),
]);
let mut output_tokens = Vec::new();
let mut n_cur = tokens.len();
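        // Autoregressive loop: sample a token from the latest logits, append it,
        // feed it back as a single-token batch, and decode again.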
for i in 0..inner.config.max_tokens {
let new_token = sampler.sample(context, -1);
if inner.model.is_eog_token(new_token) {
                tracing::debug!(iteration = i, "End-of-generation token reached");
break;
}
output_tokens.push(new_token);
if i < 5 {
if let Ok(piece) =
inner
.model
.token_to_str_with_size(new_token, 256, Special::Tokenize)
{
tracing::trace!(token_idx = i, piece = ?piece, "Generated token");
}
}
batch.clear();
batch
.add(new_token, n_cur as i32, &[0], true)
.map_err(|e| LlmError::permanent(format!("Batch add error: {}", e)))?;
n_cur += 1;
context
.decode(&mut batch)
.map_err(|e| LlmError::permanent(format!("Decode error: {}", e)))?;
}
let mut output = String::new();
for token in &output_tokens {
let piece = inner
.model
.token_to_str_with_size(*token, 256, Special::Tokenize)
.map_err(|e| LlmError::permanent(format!("Detokenization error: {}", e)))?;
output.push_str(&piece);
}
Ok(output)
}
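    /// Wraps the prompt in a generic `<|user|>`/`<|assistant|>` chat format and
    /// runs generation on a blocking thread so the async runtime is not stalled.
    /// Note that this does not apply the model's own GGUF chat template, if any.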
async fn call_llm(&self, prompt: &str) -> Result<String, LlmError> {
let inner = Arc::clone(&self.inner);
let formatted_prompt = format!("<|user|>\n{}\n<|assistant|>\n", prompt);
tokio::task::spawn_blocking(move || {
let mut guard = inner.context.lock().unwrap();
let context = guard.as_mut().expect("Context not initialized");
Self::generate_sync(&inner, context, &formatted_prompt)
})
.await
.map_err(|e| LlmError::permanent(format!("spawn_blocking failed: {}", e)))?
}
}
impl LlmDecider for LlamaCppStandaloneDecider {
fn decide(
&self,
request: WorkerDecisionRequest,
) -> Pin<Box<dyn Future<Output = Result<DecisionResponse, LlmError>> + Send + '_>> {
Box::pin(async move {
let prompt = self.prompt_builder.build(&request.context);
let raw_response = self.call_llm(&prompt).await?;
let candidate_names = response_parser::candidate_names(&request.context.candidates);
            match response_parser::parse_response(&raw_response, &candidate_names) {
                Ok(mut d) => {
                    d.prompt = Some(prompt);
                    d.raw_response = Some(raw_response);
                    Ok(d)
                }
                Err(e) => {
                    tracing::warn!(error = %e, "Failed to parse LLM response");
                    // Truncate on a character boundary; byte-slicing at 500 could
                    // panic on multi-byte UTF-8 output.
                    let preview: String = raw_response.chars().take(500).collect();
                    tracing::debug!(raw_preview = %preview, "Raw response preview");
                    Err(e)
                }
            }
})
}
fn call_raw(
&self,
prompt: &str,
) -> Pin<Box<dyn Future<Output = Result<String, LlmError>> + Send + '_>> {
let prompt = prompt.to_string();
Box::pin(async move { self.call_llm(&prompt).await })
}
fn model_name(&self) -> &str {
self.inner.config.display_name()
}
fn is_healthy(&self) -> Pin<Box<dyn Future<Output = bool> + Send + '_>> {
Box::pin(async { true })
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_config_from_hf() {
let config = LlamaCppStandaloneConfig::from_hf(
"LiquidAI/LFM2.5-1.2B-Instruct-GGUF",
"LFM2.5-1.2B-Instruct-Q4_K_M.gguf",
);
assert_eq!(config.model_path, "LiquidAI/LFM2.5-1.2B-Instruct-GGUF");
assert_eq!(
config.gguf_file,
Some("LFM2.5-1.2B-Instruct-Q4_K_M.gguf".to_string())
);
}
#[test]
fn test_config_from_local() {
let config = LlamaCppStandaloneConfig::from_local("/path/to/model.gguf");
assert_eq!(config.model_path, "/path/to/model.gguf");
assert!(config.gguf_file.is_none());
}
#[test]
fn test_config_presets() {
let lfm2 = LlamaCppStandaloneConfig::lfm2_1b();
assert!(lfm2.model_path.contains("LFM2.5"));
assert!(lfm2.gguf_file.as_ref().unwrap().contains("Q4_K_M"));
let lfm2_q8 = LlamaCppStandaloneConfig::lfm2_1b_q8();
assert!(lfm2_q8.gguf_file.as_ref().unwrap().contains("Q8_0"));
let qwen = LlamaCppStandaloneConfig::qwen_0_5b();
assert!(qwen.model_path.contains("Qwen"));
}
#[test]
fn test_config_builder() {
let config = LlamaCppStandaloneConfig::lfm2_1b()
.with_gpu_layers(32)
.with_max_tokens(512)
.with_temperature(0.5)
.with_top_p(0.95)
.with_context_size(8192);
assert_eq!(config.n_gpu_layers, 32);
assert_eq!(config.max_tokens, 512);
assert!((config.temperature - 0.5).abs() < f32::EPSILON);
assert!((config.top_p - 0.95).abs() < f32::EPSILON);
assert_eq!(config.n_ctx, 8192);
}
#[test]
fn test_format_chat_prompt() {
let user_prompt = "Select an action from: Read, Write, Grep";
let formatted = format!("<|user|>\n{}\n<|assistant|>\n", user_prompt);
assert!(formatted.starts_with("<|user|>"));
assert!(formatted.contains(user_prompt));
assert!(formatted.ends_with("<|assistant|>\n"));
}
}