#![allow(clippy::too_many_arguments)]
use anyhow::{anyhow, Result};
use async_trait::async_trait;
use super::{GenOptions, InferenceEngine, LoadedModel, ModelSpec};
#[cfg(feature = "llama")]
use std::sync::Mutex;
#[cfg(feature = "llama")]
use tracing::info;
/// Stateless entry point for loading models through llama.cpp.
///
/// The engine carries no data of its own; all per-model state lives in the
/// value returned by `InferenceEngine::load`.
#[derive(Default)]
pub struct LlamaEngine;

impl LlamaEngine {
    /// Creates an engine handle. Equivalent to `LlamaEngine::default()`.
    pub fn new() -> Self {
        LlamaEngine
    }
}
#[async_trait]
impl InferenceEngine for LlamaEngine {
    /// Loads the model described by `spec` and returns a handle that can generate text.
    ///
    /// With the `llama` feature enabled this initialises the llama.cpp backend,
    /// loads the model file at `spec.base_path`, creates a context of
    /// `spec.ctx_len` tokens (a value of 0 is passed on as "unset") and, when
    /// `spec.lora_path` is set, attaches that adapter at scale 1.0. Without the
    /// feature it ignores `spec` and returns a stub model.
    ///
    /// # Errors
    ///
    /// Returns an error when the LoRA path has a `.safetensors` extension, when
    /// `spec.ctx_len` does not fit in a `u32`, or when any llama.cpp step
    /// (backend init, model load, context creation, adapter attach) fails.
    async fn load(&self, spec: &ModelSpec) -> Result<Box<dyn LoadedModel>> {
        #[cfg(feature = "llama")]
        {
            use llama_cpp_2 as llama;
            use std::num::NonZeroU32;

            // Reject unsupported adapters before paying for the model load.
            if let Some(lora) = spec.lora_path.as_ref() {
                let is_safetensors = matches!(
                    lora.extension().and_then(|ext| ext.to_str()),
                    Some(ext) if ext.eq_ignore_ascii_case("safetensors")
                );
                if is_safetensors {
                    return Err(anyhow!(
                        "SafeTensors LoRA detected: {}. Please convert to GGUF format first.",
                        lora.display()
                    ));
                }
            }

            // A checked conversion instead of `as u32`, which would silently wrap.
            let n_ctx = u32::try_from(spec.ctx_len)
                .map_err(|_| anyhow!("ctx_len {} does not fit in a u32", spec.ctx_len))?;

            // Resolved once and only queried from the OS when the spec leaves it unset.
            let n_threads = spec.n_threads.unwrap_or_else(|| {
                std::thread::available_parallelism()
                    .ok()
                    .and_then(|n| i32::try_from(n.get()).ok())
                    .unwrap_or(4)
            });

            let be = llama::llama_backend::LlamaBackend::init()?;

            // The model lives in an `Arc` so that its address stays fixed when the
            // handle is moved into `LlamaLoaded`; the context keeps a reference to it.
            let model = std::sync::Arc::new(llama::model::LlamaModel::load_from_file(
                &be,
                &spec.base_path,
                &Default::default(),
            )?);

            let ctx_params = llama::context::params::LlamaContextParams::default()
                .with_n_ctx(NonZeroU32::new(n_ctx))
                .with_n_batch(2048)
                .with_n_ubatch(512)
                .with_n_threads(n_threads)
                .with_n_threads_batch(n_threads);
            let ctx_tmp = model.new_context(&be, ctx_params)?;

            if let Some(lora) = spec.lora_path.as_ref() {
                let mut adapter = model.lora_adapter_init(lora)?;
                ctx_tmp
                    .lora_adapter_set(&mut adapter, 1.0)
                    .map_err(|e| anyhow!("lora set: {e:?}"))?;
                info!(adapter=%lora.display(), "LoRA adapter attached");
            }

            // SAFETY: `ctx_tmp` borrows the `LlamaModel` inside `model`'s `Arc`
            // allocation. That allocation does not move when the `Arc` handle is
            // moved below, and it outlives the context because `LlamaLoaded`
            // declares `ctx` before `model` (fields drop in declaration order)
            // and never exposes either field. Erasing the lifetime is therefore
            // sound as long as those two invariants are kept.
            let ctx: llama::context::LlamaContext<'static> =
                unsafe { std::mem::transmute(ctx_tmp) };

            Ok(Box::new(LlamaLoaded {
                ctx: Mutex::new(ctx),
                model,
                _be: be,
            }))
        }
        #[cfg(not(feature = "llama"))]
        {
            let _ = spec;
            Ok(Box::new(LlamaFallback))
        }
    }
}

/// A loaded llama.cpp model together with the context used to run it.
///
/// Field order is load-bearing: Rust drops fields in declaration order, and
/// the context must be destroyed before the model it references, which in
/// turn must be destroyed before the backend is shut down.
#[cfg(feature = "llama")]
struct LlamaLoaded {
    /// Inference context; its `'static` lifetime is a lie upheld by `model`
    /// living in the same struct and being dropped afterwards.
    ctx: Mutex<llama_cpp_2::context::LlamaContext<'static>>,
    /// Heap-pinned model referenced by `ctx`. Must stay declared after `ctx`.
    model: std::sync::Arc<llama_cpp_2::model::LlamaModel>,
    /// Keeps the llama.cpp backend initialised for as long as the model is alive.
    _be: llama_cpp_2::llama_backend::LlamaBackend,
}
// SAFETY (NOTE(review): confirm against the thread-safety guarantees of the
// llama_cpp_2 version in use): the context is only reachable through the
// `Mutex` in `ctx`, so at most one thread touches it at a time, and `model`
// is only used through `&self` methods. These impls assert that this is
// enough for the value to be moved to another thread.
#[cfg(feature = "llama")]
unsafe impl Send for LlamaLoaded {}
// SAFETY: same reasoning as for `Send`; shared access to the context is
// serialised by the `Mutex` — presumably the model's `&self` methods are
// safe to call concurrently; verify.
#[cfg(feature = "llama")]
unsafe impl Sync for LlamaLoaded {}
#[cfg(feature = "llama")]
#[async_trait]
impl LoadedModel for LlamaLoaded {
    /// Generates up to `opts.max_tokens` tokens continuing `prompt`.
    ///
    /// The prompt is tokenised with a leading BOS token, decoded into a freshly
    /// cleared context, and tokens are then sampled one at a time until an
    /// end-of-generation token appears, `opts.max_tokens` is reached, or the
    /// context window is full. Each non-empty piece of decoded text is passed
    /// to `on_token` as soon as it is available and also appended to the
    /// returned string.
    ///
    /// NOTE(review): the whole generation runs synchronously while holding the
    /// context `Mutex`; there is no `.await` in the body, so callers should
    /// keep this off latency-sensitive executor threads.
    ///
    /// # Errors
    ///
    /// Returns an error when tokenisation yields no tokens, when the prompt
    /// does not fit in the context window, or when any llama.cpp call fails.
    async fn generate(
        &self,
        prompt: &str,
        opts: GenOptions,
        mut on_token: Option<Box<dyn FnMut(String) + Send>>,
    ) -> Result<String> {
        use llama_cpp_2::{
            llama_batch::LlamaBatch,
            model::{AddBos, Special},
            sampling::LlamaSampler,
        };

        /// Removes every decodable byte from `pending` and returns it as text.
        ///
        /// An incomplete UTF-8 sequence at the end is left in `pending` so the
        /// next token can complete it; bytes that can never be valid UTF-8 are
        /// replaced with U+FFFD.
        fn drain_decodable(pending: &mut Vec<u8>) -> String {
            let mut text = String::new();
            loop {
                match std::str::from_utf8(pending) {
                    Ok(valid) => {
                        text.push_str(valid);
                        pending.clear();
                        return text;
                    }
                    Err(err) => {
                        let valid_len = err.valid_up_to();
                        text.push_str(&String::from_utf8_lossy(&pending[..valid_len]));
                        match err.error_len() {
                            None => {
                                pending.drain(..valid_len);
                                return text;
                            }
                            Some(bad_len) => {
                                text.push(char::REPLACEMENT_CHARACTER);
                                pending.drain(..valid_len + bad_len);
                            }
                        }
                    }
                }
            }
        }

        // A poisoned lock only means an earlier call panicked mid-generation.
        // The cache is reset right below, so the context is safe to reuse.
        let mut ctx = self
            .ctx
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner);

        // Every call starts at position 0; stale entries from a previous call
        // must not remain in the KV cache.
        ctx.clear_kv_cache();

        let tokens = self.model.str_to_token(prompt, AddBos::Always)?;
        if tokens.is_empty() {
            return Err(anyhow!("prompt produced no tokens"));
        }
        let n_ctx = usize::try_from(ctx.n_ctx()).unwrap_or(usize::MAX);
        if tokens.len() > n_ctx {
            return Err(anyhow!(
                "prompt is {} tokens but the context window holds only {}",
                tokens.len(),
                n_ctx
            ));
        }

        // Feed the prompt in pieces no larger than the context's batch size;
        // only the very last prompt token needs logits.
        let chunk_len = usize::try_from(ctx.n_batch())
            .unwrap_or(usize::MAX)
            .max(1);
        let last_index = tokens.len() - 1;
        let mut batch = LlamaBatch::new(chunk_len.min(tokens.len()), 1);
        for (chunk_idx, chunk) in tokens.chunks(chunk_len).enumerate() {
            batch.clear();
            let base = chunk_idx * chunk_len;
            for (offset, &token) in chunk.iter().enumerate() {
                let index = base + offset;
                batch.add(token, i32::try_from(index)?, &[0], index == last_index)?;
            }
            ctx.decode(&mut batch)?;
        }

        // `penalties` takes (last_n, repeat, frequency, presence): the repeat
        // penalty belongs in the second slot. Passing 0.0 there makes
        // llama.cpp divide logits by zero.
        // NOTE(review): the chain ends in `greedy()`, so selection is
        // deterministic — confirm whether a `dist` sampler was intended.
        let mut sampler = LlamaSampler::chain_simple([
            LlamaSampler::temp(opts.temperature),
            LlamaSampler::top_p(opts.top_p, 1),
            LlamaSampler::top_k(opts.top_k),
            LlamaSampler::penalties(64, opts.repeat_penalty, 0.0, 0.0),
            LlamaSampler::greedy(),
        ])
        .with_tokens(tokens.iter().copied());

        let mut out = String::new();
        // Bytes of a character that has been split across tokens.
        let mut pending: Vec<u8> = Vec::new();
        let mut n_past = tokens.len();
        let mut step = LlamaBatch::new(1, 1);

        for _ in 0..opts.max_tokens {
            let token = sampler.sample(&ctx, -1);
            if self.model.is_eog_token(token) {
                break;
            }

            pending.extend(self.model.token_to_bytes(token, Special::Plaintext)?);
            let piece = drain_decodable(&mut pending);
            if !piece.is_empty() {
                out.push_str(&piece);
                if let Some(cb) = on_token.as_mut() {
                    cb(piece);
                }
            }

            // No room left to decode this token, so generation ends here.
            if n_past >= n_ctx {
                break;
            }
            step.clear();
            step.add(token, i32::try_from(n_past)?, &[0], true)?;
            ctx.decode(&mut step)?;
            n_past += 1;
        }

        // Whatever is still pending can no longer be completed.
        if !pending.is_empty() {
            let tail = String::from_utf8_lossy(&pending).into_owned();
            out.push_str(&tail);
            if let Some(cb) = on_token.as_mut() {
                cb(tail);
            }
        }

        Ok(out)
    }
}
/// Stand-in model returned by `load` when the crate is built without the
/// `llama` feature.
#[cfg(not(feature = "llama"))]
struct LlamaFallback;

#[cfg(not(feature = "llama"))]
#[async_trait]
impl LoadedModel for LlamaFallback {
    /// Reports that llama.cpp support is missing instead of generating text.
    ///
    /// The notice is sent once to `on_token` (if provided); the returned
    /// string contains the notice followed by the original `prompt`.
    /// Generation options are ignored.
    async fn generate(
        &self,
        prompt: &str,
        _opts: GenOptions,
        mut on_token: Option<Box<dyn FnMut(String) + Send>>,
    ) -> Result<String> {
        const NOTICE: &str =
            "Llama.cpp support not enabled. Build with --features llama for full functionality.";

        if let Some(notify) = &mut on_token {
            notify(NOTICE.to_string());
        }

        let reply = format!("[INFO] {} Input: {}", NOTICE, prompt);
        Ok(reply)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The engine is a unit struct and must stay free of per-instance state.
    #[test]
    fn test_llama_engine_initialization() {
        let engine = LlamaEngine::new();
        assert_eq!(std::mem::size_of_val(&engine), 0);
    }

    /// Exercises `load` end to end with a model path that does not exist.
    ///
    /// With the `llama` feature the load must fail; without it the engine
    /// returns the fallback model regardless of the spec.
    #[tokio::test]
    async fn test_model_loading_validation() {
        let engine = LlamaEngine::new();
        let spec = ModelSpec {
            name: "test".to_string(),
            base_path: "/nonexistent".into(),
            lora_path: None,
            template: Some("chatml".to_string()),
            ctx_len: 2048,
            n_threads: None,
        };

        let loaded = engine.load(&spec).await;

        if cfg!(feature = "llama") {
            assert!(
                loaded.is_err(),
                "loading a nonexistent model file must fail"
            );
        } else {
            assert!(
                loaded.is_ok(),
                "builds without the llama feature must return the fallback model"
            );
        }
    }

    /// A spec keeps the values it was constructed with.
    #[test]
    fn test_model_spec_validation() {
        let spec = ModelSpec {
            name: "valid".to_string(),
            base_path: "test.gguf".into(),
            lora_path: None,
            template: Some("chatml".to_string()),
            ctx_len: 4096,
            n_threads: Some(4),
        };
        assert_eq!(spec.name, "valid");
        assert_eq!(spec.ctx_len, 4096);
        assert!(spec.template.is_some());
    }
}