use crate::error::AgentError;
use llama_cpp_v3::{LlamaBackend, LlamaContext, LlamaModel, LoadOptions};
use std::path::{Path, PathBuf};
use std::sync::{Arc, Condvar, Mutex};
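
/// Configuration for the native backend and the model to load.
///
/// `model_path` points at a GGUF file on disk, `n_gpu_layers` controls how many
/// layers are offloaded to the GPU (0 keeps everything on the CPU), and `n_ctx`
/// is the default context size used by [`InferenceEngine::create_context`]. The
/// `app_name`, `explicit_dll_path`, `dll_version`, and `cache_dir` fields are
/// passed through to the dynamic-library loader, while `chat_template` can carry
/// a Jinja-style template (see [`templates`]) for models that need one.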
#[derive(Debug, Clone)]
pub struct InferenceConfig {
pub backend: llama_cpp_v3::backend::Backend,
pub model_path: String,
pub n_gpu_layers: i32,
pub n_ctx: u32,
pub app_name: String,
pub explicit_dll_path: Option<PathBuf>,
pub dll_version: Option<String>,
pub cache_dir: Option<PathBuf>,
pub chat_template: Option<String>,
}
impl Default for InferenceConfig {
fn default() -> Self {
Self {
backend: llama_cpp_v3::backend::Backend::Cpu,
model_path: String::new(),
n_gpu_layers: 0,
n_ctx: 8192,
app_name: "llama-cpp-v3-agent-sdk".to_string(),
explicit_dll_path: None,
dll_version: None,
cache_dir: None,
chat_template: None,
}
}
}
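
/// Bundled Jinja-style chat templates for common model families.
///
/// These are intended for [`InferenceConfig::chat_template`] when a GGUF file
/// ships without a usable template. An illustrative sketch, with a placeholder
/// model path:
///
/// ```ignore
/// let config = InferenceConfig {
///     model_path: "/models/example.gguf".to_string(),
///     chat_template: Some(templates::CHATML.to_string()),
///     ..Default::default()
/// };
/// ```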
pub mod templates {
pub const LLAMA_3: &str = "{% set loop_messages = messages %}{% for message in loop_messages %}{% if loop.index0 == 0 %}{{ bos_token }}{% endif %}{{ '<|start_header_id|>' + message['role'] + '<|end_header_id|>\\n\\n' + message['content'] | trim + '<|eot_id|>' }}{% endfor %}{% if add_generation_prompt %}{{ '<|start_header_id|>assistant<|end_header_id|>\\n\\n' }}{% endif %}";
pub const CHATML: &str = "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\\n' + message['content'] + '<|im_end|>\\n'}}{% endfor %}{% if add_generation_prompt %}{{'<|im_start|>assistant\\n'}}{% endif %}";
pub const LLAMA_2: &str = "{% for message in messages %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'system' %}{{ '<<SYS>>\\n' + message['content'] + '\\n<</SYS>>\\n\\n' }}{% else %}{{ message['content'] }}{% endif %}{% endfor %}";
}
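
/// A loaded backend plus model, shared behind `Arc`s so clones are cheap.
///
/// An illustrative sketch (the path is a placeholder and the example is marked
/// `ignore` because it needs a real model file on disk):
///
/// ```ignore
/// let engine = InferenceEngine::load(InferenceConfig {
///     model_path: "/models/example.gguf".to_string(),
///     n_gpu_layers: 35,
///     ..Default::default()
/// })?;
/// let mut ctx = engine.create_context(None)?;
/// ```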
#[derive(Clone)]
pub struct InferenceEngine {
pub backend: Arc<LlamaBackend>,
pub model: Arc<LlamaModel>,
pub config: InferenceConfig,
}
impl InferenceEngine {
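/// Loads the native backend described by `config` and then the model itself.
///
/// Fails early with [`AgentError::Other`] if the model file does not exist, and
/// normalizes Windows-style `\` separators to `/` before handing the path to the
/// native loader.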
pub fn load(config: InferenceConfig) -> Result<Self, AgentError> {
let path = Path::new(&config.model_path);
if !path.exists() {
return Err(AgentError::Other(format!(
"Model file not found: {}",
config.model_path
)));
}
let options = LoadOptions {
backend: config.backend,
app_name: &config.app_name,
version: config.dll_version.as_deref(),
explicit_path: config.explicit_dll_path.as_deref(),
cache_dir: config.cache_dir.clone(),
};
let backend = LlamaBackend::load(options)?;
let mut model_params = LlamaModel::default_params(&backend);
model_params.n_gpu_layers = config.n_gpu_layers;
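// Normalize Windows-style separators so the native loader sees a consistent path.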
let path_str = config.model_path.replace('\\', "/");
let model = LlamaModel::load_from_file(&backend, &path_str, model_params)?;
Ok(Self {
backend: Arc::new(backend),
model: Arc::new(model),
config,
})
}
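/// Creates a fresh context for this model, using `n_ctx_override` if given and
/// falling back to `config.n_ctx` otherwise.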
pub fn create_context(&self, n_ctx_override: Option<u32>) -> Result<LlamaContext, AgentError> {
let mut ctx_params = LlamaContext::default_params(&self.model);
ctx_params.n_ctx = n_ctx_override.unwrap_or(self.config.n_ctx);
let ctx = LlamaContext::new(&self.model, ctx_params)?;
Ok(ctx)
}
pub fn model(&self) -> &LlamaModel {
&self.model
}
pub fn backend(&self) -> &LlamaBackend {
&self.backend
}
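/// Hands out the shared handle to the underlying native library, for callers that
/// need to drop down to the raw bindings.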
pub fn lib(&self) -> Arc<llama_cpp_sys_v3::LlamaLib> {
self.backend.lib.clone()
}
}
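
/// Gates concurrent inference behind a fixed number of permits and recycles
/// pre-built contexts through a pool.
///
/// An illustrative sketch, assuming `engine` was loaded as in the
/// [`InferenceEngine`] example:
///
/// ```ignore
/// let scheduler = InferenceScheduler::new(2);
/// scheduler.init_pool(&engine, None)?;
/// let mut permit = scheduler.acquire();
/// if let Some(ctx) = permit.context_mut() {
///     // run decoding against `ctx` here
/// }
/// // Dropping the permit clears the context's KV cache, returns it to the pool,
/// // and wakes one waiting thread.
/// ```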
pub struct InferenceScheduler {
state: Mutex<SchedulerState>,
cond: Condvar,
pool: Mutex<Vec<LlamaContext>>,
}
struct SchedulerState {
max_concurrent: usize,
active: usize,
}
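
/// RAII guard for one inference slot.
///
/// Holds an optional pooled context (`None` if [`InferenceScheduler::init_pool`]
/// was never called); dropping the permit releases the slot and returns the
/// context to the pool.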
pub struct InferencePermit<'a> {
scheduler: &'a InferenceScheduler,
context: Option<LlamaContext>,
}
impl<'a> InferencePermit<'a> {
pub fn context_mut(&mut self) -> Option<&mut LlamaContext> {
self.context.as_mut()
}
}
impl<'a> Drop for InferencePermit<'a> {
fn drop(&mut self) {
let ctx = self.context.take();
self.scheduler.release(ctx);
}
}
impl InferenceScheduler {
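/// Creates a scheduler that allows at most `max_concurrent` permits at once.
///
/// Panics if `max_concurrent` is zero.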
pub fn new(max_concurrent: usize) -> Self {
assert!(max_concurrent > 0, "max_concurrent must be at least 1");
Self {
state: Mutex::new(SchedulerState {
max_concurrent,
active: 0,
}),
cond: Condvar::new(),
pool: Mutex::new(Vec::with_capacity(max_concurrent)),
}
}
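/// Pre-builds one context per permit and parks them in the pool.
///
/// Intended to be called once, before the scheduler is shared; calling it again
/// appends further contexts rather than replacing the existing ones.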
pub fn init_pool(
&self,
engine: &InferenceEngine,
n_ctx: Option<u32>,
) -> Result<(), AgentError> {
// Read the permit limit before locking the pool: `acquire` takes the state lock
// and then the pool lock, so taking them in the opposite order here could deadlock.
let count = self.max_concurrent();
let mut pool = self.pool.lock().unwrap();
for _ in 0..count {
pool.push(engine.create_context(n_ctx)?);
}
Ok(())
}
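/// Blocks until a permit is free, then claims it along with a pooled context
/// (if any are available).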
pub fn acquire(&self) -> InferencePermit<'_> {
let mut state = self.state.lock().unwrap();
while state.active >= state.max_concurrent {
state = self.cond.wait(state).unwrap();
}
state.active += 1;
let context = self.pool.lock().unwrap().pop();
InferencePermit {
scheduler: self,
context,
}
}
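/// Claims a permit immediately if one is free, otherwise returns `None` without
/// blocking.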
pub fn try_acquire(&self) -> Option<InferencePermit<'_>> {
let mut state = self.state.lock().unwrap();
if state.active < state.max_concurrent {
state.active += 1;
let context = self.pool.lock().unwrap().pop();
Some(InferencePermit {
scheduler: self,
context,
})
} else {
None
}
}
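// Called from `InferencePermit::drop`: scrub the context, return it to the pool,
// free the slot, and wake one waiter.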
fn release(&self, context: Option<LlamaContext>) {
if let Some(mut ctx) = context {
ctx.kv_cache_clear();
self.pool.lock().unwrap().push(ctx);
}
let mut state = self.state.lock().unwrap();
state.active -= 1;
self.cond.notify_one();
}
pub fn active_count(&self) -> usize {
self.state.lock().unwrap().active
}
pub fn max_concurrent(&self) -> usize {
self.state.lock().unwrap().max_concurrent
}
}
#[cfg(test)]
mod tests {
use super::*;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::thread;
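// These tests exercise permit gating only; no model is loaded, so permits carry
// no pooled context.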
#[test]
fn test_scheduler_serialized() {
let scheduler = Arc::new(InferenceScheduler::new(1));
let counter = Arc::new(AtomicUsize::new(0));
let max_seen = Arc::new(AtomicUsize::new(0));
let mut handles = Vec::new();
for _ in 0..4 {
let sched = scheduler.clone();
let cnt = counter.clone();
let max = max_seen.clone();
handles.push(thread::spawn(move || {
let _permit = sched.acquire();
let current = cnt.fetch_add(1, Ordering::SeqCst) + 1;
max.fetch_max(current, Ordering::SeqCst);
thread::sleep(std::time::Duration::from_millis(10));
cnt.fetch_sub(1, Ordering::SeqCst);
}));
}
for h in handles {
h.join().unwrap();
}
assert_eq!(max_seen.load(Ordering::SeqCst), 1);
}
#[test]
fn test_scheduler_parallel() {
let scheduler = Arc::new(InferenceScheduler::new(3));
let counter = Arc::new(AtomicUsize::new(0));
let max_seen = Arc::new(AtomicUsize::new(0));
let mut handles = Vec::new();
for _ in 0..6 {
let sched = scheduler.clone();
let cnt = counter.clone();
let max = max_seen.clone();
handles.push(thread::spawn(move || {
let _permit = sched.acquire();
let current = cnt.fetch_add(1, Ordering::SeqCst) + 1;
max.fetch_max(current, Ordering::SeqCst);
thread::sleep(std::time::Duration::from_millis(50));
cnt.fetch_sub(1, Ordering::SeqCst);
}));
}
for h in handles {
h.join().unwrap();
}
assert!(max_seen.load(Ordering::SeqCst) > 1);
assert!(max_seen.load(Ordering::SeqCst) <= 3);
}
#[test]
fn test_try_acquire() {
let scheduler = InferenceScheduler::new(1);
let _permit = scheduler.acquire();
assert!(scheduler.try_acquire().is_none());
assert_eq!(scheduler.active_count(), 1);
}
}