#![allow(clippy::too_many_arguments)]
use anyhow::Result;
use async_trait::async_trait;
use super::{GenOptions, InferenceEngine, LoadedModel, ModelSpec};
#[cfg(feature = "llama")]
use std::sync::Mutex;
use tracing::info;
#[derive(Default)]
pub struct LlamaEngine {
gpu_backend: GpuBackend,
moe_config: MoeConfig,
}
#[derive(Debug, Clone, Default)]
struct MoeConfig {
cpu_moe_all: bool,
n_cpu_moe: Option<usize>,
}
#[derive(Debug, Clone, Default)]
enum GpuBackend {
#[default]
Cpu,
#[cfg(feature = "llama-cuda")]
Cuda,
#[cfg(feature = "llama-vulkan")]
Vulkan,
#[cfg(feature = "llama-opencl")]
OpenCL,
}
impl GpuBackend {
fn from_string(s: &str) -> Self {
match s.to_lowercase().as_str() {
"auto" => Self::detect_best(),
"cpu" => GpuBackend::Cpu,
#[cfg(feature = "llama-cuda")]
"cuda" => GpuBackend::Cuda,
#[cfg(feature = "llama-vulkan")]
"vulkan" => GpuBackend::Vulkan,
#[cfg(feature = "llama-opencl")]
"opencl" => GpuBackend::OpenCL,
#[cfg(not(feature = "llama-cuda"))]
"cuda" => {
tracing::warn!("CUDA backend requested but not enabled at compile time. Rebuild with --features llama-cuda");
Self::detect_best()
}
#[cfg(not(feature = "llama-vulkan"))]
"vulkan" => {
tracing::warn!("Vulkan backend requested but not enabled at compile time. Rebuild with --features llama-vulkan");
Self::detect_best()
}
#[cfg(not(feature = "llama-opencl"))]
"opencl" => {
tracing::warn!("OpenCL backend requested but not enabled at compile time. Rebuild with --features llama-opencl");
Self::detect_best()
}
_ => {
tracing::warn!("Unknown GPU backend '{}', using auto-detect", s);
Self::detect_best()
}
}
}
fn detect_best() -> Self {
#[cfg(feature = "llama-cuda")]
{
if Self::is_cuda_available() {
info!("CUDA GPU detected, using CUDA backend");
return GpuBackend::Cuda;
}
}
#[cfg(feature = "llama-vulkan")]
{
if Self::is_vulkan_available() {
info!("Vulkan GPU detected, using Vulkan backend");
return GpuBackend::Vulkan;
}
}
#[cfg(feature = "llama-opencl")]
{
if Self::is_opencl_available() {
info!("OpenCL GPU detected, using OpenCL backend");
return GpuBackend::OpenCL;
}
}
info!("No GPU acceleration available, using CPU backend");
GpuBackend::Cpu
}
#[cfg(feature = "llama-cuda")]
fn is_cuda_available() -> bool {
std::process::Command::new("nvidia-smi")
.output()
.map(|output| output.status.success())
.unwrap_or(false)
}
#[cfg(feature = "llama-vulkan")]
fn is_vulkan_available() -> bool {
if std::process::Command::new("vulkaninfo")
.arg("--summary")
.output()
.map(|output| output.status.success())
.unwrap_or(false)
{
return true;
}
true
}
#[cfg(feature = "llama-opencl")]
fn is_opencl_available() -> bool {
if std::process::Command::new("clinfo")
.output()
.map(|output| output.status.success())
.unwrap_or(false)
{
return true;
}
#[cfg(target_os = "windows")]
{
if std::process::Command::new("where")
.arg("OpenCL.dll")
.output()
.map(|output| output.status.success())
.unwrap_or(false)
{
return true;
}
}
true
}
fn gpu_layers(&self) -> u32 {
999 }
}
impl LlamaEngine {
pub fn new() -> Self {
Self {
gpu_backend: GpuBackend::detect_best(),
moe_config: MoeConfig::default(),
}
}
pub fn new_with_backend(backend_str: Option<&str>) -> Self {
let gpu_backend = backend_str
.map(GpuBackend::from_string)
.unwrap_or_else(GpuBackend::detect_best);
info!("GPU backend configured: {:?}", gpu_backend);
Self {
gpu_backend,
moe_config: MoeConfig::default(),
}
}
pub fn with_moe_config(mut self, cpu_moe_all: bool, n_cpu_moe: Option<usize>) -> Self {
self.moe_config = MoeConfig {
cpu_moe_all,
n_cpu_moe,
};
self
}
pub fn get_backend_info(&self) -> String {
match self.gpu_backend {
GpuBackend::Cpu => "CPU".to_string(),
#[cfg(feature = "llama-cuda")]
GpuBackend::Cuda => "CUDA".to_string(),
#[cfg(feature = "llama-vulkan")]
GpuBackend::Vulkan => "Vulkan".to_string(),
#[cfg(feature = "llama-opencl")]
GpuBackend::OpenCL => "OpenCL".to_string(),
}
}
}
#[async_trait]
impl InferenceEngine for LlamaEngine {
async fn load(&self, spec: &ModelSpec) -> Result<Box<dyn LoadedModel>> {
#[cfg(feature = "llama")]
{
use anyhow::anyhow;
use llama_cpp_2 as llama;
use std::num::NonZeroU32;
let be = llama::llama_backend::LlamaBackend::init()?;
let n_gpu_layers = self.gpu_backend.gpu_layers();
info!(
"Loading model with {} GPU layers ({:?} backend)",
n_gpu_layers, self.gpu_backend
);
let mut model_params =
llama::model::params::LlamaModelParams::default().with_n_gpu_layers(n_gpu_layers);
if let Some(n) = self.moe_config.n_cpu_moe {
info!("MoE: Offloading first {} expert layers to CPU", n);
model_params = model_params.with_n_cpu_moe(n);
} else if self.moe_config.cpu_moe_all {
info!("MoE: Offloading ALL expert tensors to CPU");
model_params = model_params.with_cpu_moe_all();
}
let model =
llama::model::LlamaModel::load_from_file(&be, &spec.base_path, &model_params)?;
let ctx_params = llama::context::params::LlamaContextParams::default()
.with_n_ctx(NonZeroU32::new(spec.ctx_len as u32))
.with_n_batch(2048)
.with_n_ubatch(512)
.with_n_threads(
spec.n_threads.unwrap_or(
std::thread::available_parallelism()
.map(|n| n.get() as i32)
.unwrap_or(4),
),
)
.with_n_threads_batch(
spec.n_threads.unwrap_or(
std::thread::available_parallelism()
.map(|n| n.get() as i32)
.unwrap_or(4),
),
);
let ctx_tmp = model.new_context(&be, ctx_params)?;
if let Some(ref lora) = spec.lora_path {
let lora_path = if lora.extension().and_then(|s| s.to_str()) == Some("safetensors")
{
return Err(anyhow!(
"SafeTensors LoRA detected: {}. Please convert to GGUF format first.",
lora.display()
));
} else {
lora.clone()
};
let mut adapter = model.lora_adapter_init(&lora_path)?;
ctx_tmp
.lora_adapter_set(&mut adapter, 1.0)
.map_err(|e| anyhow!("lora set: {e:?}"))?;
info!(adapter=%lora_path.display(), "LoRA adapter attached");
}
let ctx: llama::context::LlamaContext<'static> =
unsafe { std::mem::transmute(ctx_tmp) };
Ok(Box::new(LlamaLoaded {
_be: be,
model,
ctx: Mutex::new(ctx),
}))
}
#[cfg(not(feature = "llama"))]
{
let _ = spec; Ok(Box::new(LlamaFallback))
}
}
}
#[cfg(feature = "llama")]
struct LlamaLoaded {
_be: llama_cpp_2::llama_backend::LlamaBackend,
model: llama_cpp_2::model::LlamaModel,
ctx: Mutex<llama_cpp_2::context::LlamaContext<'static>>,
}
#[cfg(feature = "llama")]
unsafe impl Send for LlamaLoaded {}
#[cfg(feature = "llama")]
unsafe impl Sync for LlamaLoaded {}
#[cfg(feature = "llama")]
#[async_trait]
impl LoadedModel for LlamaLoaded {
async fn generate(
&self,
prompt: &str,
opts: GenOptions,
mut on_token: Option<Box<dyn FnMut(String) + Send>>,
) -> Result<String> {
use llama_cpp_2::{
llama_batch::LlamaBatch,
model::{AddBos, Special},
sampling::LlamaSampler,
};
let mut ctx = self
.ctx
.lock()
.map_err(|e| anyhow::anyhow!("Failed to lock context: {}", e))?;
let tokens = self.model.str_to_token(prompt, AddBos::Always)?;
let mut batch = LlamaBatch::new(tokens.len(), 1);
for (i, &token) in tokens.iter().enumerate() {
let logits = i == tokens.len() - 1;
batch.add(token, i as i32, &[0], logits)?;
}
ctx.decode(&mut batch)?;
let mut sampler = LlamaSampler::chain_simple([
LlamaSampler::temp(opts.temperature),
LlamaSampler::top_p(opts.top_p, 1),
LlamaSampler::top_k(opts.top_k),
LlamaSampler::penalties(64, 0.0, 0.0, opts.repeat_penalty),
LlamaSampler::greedy(),
])
.with_tokens(tokens.iter().copied());
let mut out = String::new();
let mut all_tokens = tokens;
for _ in 0..opts.max_tokens {
let token = sampler.sample(&ctx, -1);
if self.model.is_eog_token(token) {
break;
}
let piece = self.model.token_to_str(token, Special::Plaintext)?;
let start = out.len();
out.push_str(&piece);
if let Some(cb) = on_token.as_mut() {
cb(out[start..].to_string());
}
let mut step = LlamaBatch::new(1, 1);
step.add(token, all_tokens.len() as i32, &[0], true)?;
ctx.decode(&mut step)?;
all_tokens.push(token);
}
Ok(out)
}
}
#[cfg(not(feature = "llama"))]
struct LlamaFallback;
#[cfg(not(feature = "llama"))]
#[async_trait]
impl LoadedModel for LlamaFallback {
async fn generate(
&self,
prompt: &str,
_opts: GenOptions,
mut on_token: Option<Box<dyn FnMut(String) + Send>>,
) -> Result<String> {
let fallback_msg =
"Llama.cpp support not enabled. Build with --features llama for full functionality.";
if let Some(cb) = on_token.as_mut() {
cb(fallback_msg.to_string());
}
Ok(format!("[INFO] {} Input: {}", fallback_msg, prompt))
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_llama_engine_initialization() {
let engine = LlamaEngine::new();
let _ = engine; }
#[tokio::test]
async fn test_model_loading_validation() {
let _engine = LlamaEngine::new();
let _spec = ModelSpec {
name: "test".to_string(),
base_path: "/nonexistent".into(),
lora_path: None,
template: Some("chatml".to_string()),
ctx_len: 2048,
n_threads: None,
};
}
#[test]
fn test_model_spec_validation() {
let spec = ModelSpec {
name: "valid".to_string(),
base_path: "test.gguf".into(),
lora_path: None,
template: Some("chatml".to_string()),
ctx_len: 4096,
n_threads: Some(4),
};
assert_eq!(spec.name, "valid");
assert_eq!(spec.ctx_len, 4096);
assert!(spec.template.is_some());
}
}