llama_cpp_v3_agent_sdk/inference.rs

use crate::error::AgentError;
use llama_cpp_v3::{LlamaBackend, LlamaContext, LlamaModel, LoadOptions};
use std::path::{Path, PathBuf};
use std::sync::{Arc, Condvar, Mutex};

/// Configuration for loading a model and creating inference contexts.
#[derive(Debug, Clone)]
pub struct InferenceConfig {
    /// Backend used to run inference.
    pub backend: llama_cpp_v3::backend::Backend,
    /// Filesystem path to the model file.
    pub model_path: String,
    /// Number of layers to offload to the GPU (0 = CPU only).
    pub n_gpu_layers: i32,
    /// Context window size, in tokens.
    pub n_ctx: u32,
    /// Application name passed to the backend loader.
    pub app_name: String,
    /// Explicit path to the backend library, if it should not be resolved automatically.
    pub explicit_dll_path: Option<PathBuf>,
    /// Requested backend library version, if any.
    pub dll_version: Option<String>,
    /// Cache directory passed to the backend loader, if any.
    pub cache_dir: Option<PathBuf>,
    /// Optional chat template override (see [`templates`]).
    pub chat_template: Option<String>,
}

impl Default for InferenceConfig {
    fn default() -> Self {
        Self {
            backend: llama_cpp_v3::backend::Backend::Cpu,
            model_path: String::new(),
            n_gpu_layers: 0,
            n_ctx: 8192,
            app_name: "llama-cpp-v3-agent-sdk".to_string(),
            explicit_dll_path: None,
            dll_version: None,
            cache_dir: None,
            chat_template: None,
        }
    }
}
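
// Illustrative sketch (not part of the original API surface): overriding a few
// defaults with struct-update syntax. The model path below is hypothetical.
#[allow(dead_code)]
fn example_config() -> InferenceConfig {
    InferenceConfig {
        model_path: "models/example-model.gguf".to_string(), // hypothetical path
        n_gpu_layers: 32, // offload 32 layers to the GPU
        n_ctx: 4096,
        ..Default::default()
    }
}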

/// Built-in fallback chat templates for common model families.
pub mod templates {
    /// Llama 3 instruct-style chat template.
    pub const LLAMA_3: &str = "{% set loop_messages = messages %}{% for message in loop_messages %}{% if loop.index0 == 0 %}{{ bos_token }}{% endif %}{{ '<|start_header_id|>' + message['role'] + '<|end_header_id|>\\n\\n' + message['content'] | trim + '<|eot_id|>' }}{% endfor %}{% if add_generation_prompt %}{{ '<|start_header_id|>assistant<|end_header_id|>\\n\\n' }}{% endif %}";

    /// ChatML chat template.
    pub const CHATML: &str = "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\\n' + message['content'] + '<|im_end|>\\n'}}{% endfor %}{% if add_generation_prompt %}{{'<|im_start|>assistant\\n'}}{% endif %}";

    /// Llama 2 chat template.
    pub const LLAMA_2: &str = "{% for message in messages %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'system' %}{{ '<<SYS>>\\n' + message['content'] + '\\n<</SYS>>\\n\\n' }}{% else %}{{ message['content'] }}{% endif %}{% endfor %}";
}
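
// Illustrative sketch (assumes downstream prompt formatting reads
// `InferenceConfig::chat_template`): falling back to ChatML for models whose
// metadata does not ship a usable template.
#[allow(dead_code)]
fn example_chatml_config() -> InferenceConfig {
    InferenceConfig {
        chat_template: Some(templates::CHATML.to_string()),
        ..Default::default()
    }
}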

/// A loaded backend and model, shareable across threads via `Arc` handles.
#[derive(Clone)]
pub struct InferenceEngine {
    pub backend: Arc<LlamaBackend>,
    pub model: Arc<LlamaModel>,
    pub config: InferenceConfig,
}

impl InferenceEngine {
    /// Loads the backend library, then the model described by `config`.
    pub fn load(config: InferenceConfig) -> Result<Self, AgentError> {
        let path = Path::new(&config.model_path);
        if !path.exists() {
            return Err(AgentError::Other(format!(
                "Model file not found: {}",
                config.model_path
            )));
        }

        let options = LoadOptions {
            backend: config.backend,
            app_name: &config.app_name,
            version: config.dll_version.as_deref(),
            explicit_path: config.explicit_dll_path.as_deref(),
            cache_dir: config.cache_dir.clone(),
        };

        let backend = LlamaBackend::load(options)?;

        let mut model_params = LlamaModel::default_params(&backend);
        model_params.n_gpu_layers = config.n_gpu_layers;

        // Normalize Windows-style separators before handing the path to the loader.
        let path_str = config.model_path.replace('\\', "/");
        let model = LlamaModel::load_from_file(&backend, &path_str, model_params)?;

        Ok(Self {
            backend: Arc::new(backend),
            model: Arc::new(model),
            config,
        })
    }

    /// Creates a fresh inference context, optionally overriding the configured context size.
    pub fn create_context(&self, n_ctx_override: Option<u32>) -> Result<LlamaContext, AgentError> {
        let mut ctx_params = LlamaContext::default_params(&self.model);
        ctx_params.n_ctx = n_ctx_override.unwrap_or(self.config.n_ctx);
        let ctx = LlamaContext::new(&self.model, ctx_params)?;
        Ok(ctx)
    }

    /// Returns a reference to the loaded model.
    pub fn model(&self) -> &LlamaModel {
        &self.model
    }

    /// Returns a reference to the loaded backend.
    pub fn backend(&self) -> &LlamaBackend {
        &self.backend
    }

    /// Returns a shared handle to the underlying native library.
    pub fn lib(&self) -> Arc<llama_cpp_sys_v3::LlamaLib> {
        self.backend.lib.clone()
    }
}
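
// Illustrative sketch (the model path is hypothetical): loading an engine and
// creating a context with the configured window size.
#[allow(dead_code)]
fn example_load_engine() -> Result<(), AgentError> {
    let engine = InferenceEngine::load(InferenceConfig {
        model_path: "models/example-model.gguf".to_string(), // hypothetical path
        ..Default::default()
    })?;
    let _ctx = engine.create_context(None)?; // falls back to config.n_ctx (8192 by default)
    Ok(())
}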

/// Bounds the number of in-flight inference calls and pools reusable contexts.
pub struct InferenceScheduler {
    state: Mutex<SchedulerState>,
    cond: Condvar,
    pool: Mutex<Vec<LlamaContext>>,
}

struct SchedulerState {
    max_concurrent: usize,
    active: usize,
}

/// RAII guard for one inference slot; its pooled context (if any) is returned on drop.
pub struct InferencePermit<'a> {
    scheduler: &'a InferenceScheduler,
    context: Option<LlamaContext>,
}

impl<'a> InferencePermit<'a> {
    /// Mutable access to the pooled context, if one was available when the permit was issued.
    pub fn context_mut(&mut self) -> Option<&mut LlamaContext> {
        self.context.as_mut()
    }
}

impl<'a> Drop for InferencePermit<'a> {
    fn drop(&mut self) {
        let ctx = self.context.take();
        self.scheduler.release(ctx);
    }
}

impl InferenceScheduler {
    /// Creates a scheduler that allows at most `max_concurrent` simultaneous permits.
    pub fn new(max_concurrent: usize) -> Self {
        assert!(max_concurrent > 0, "max_concurrent must be at least 1");
        Self {
            state: Mutex::new(SchedulerState {
                max_concurrent,
                active: 0,
            }),
            cond: Condvar::new(),
            pool: Mutex::new(Vec::with_capacity(max_concurrent)),
        }
    }

    /// Pre-creates one context per concurrency slot and places them in the pool.
    pub fn init_pool(
        &self,
        engine: &InferenceEngine,
        n_ctx: Option<u32>,
    ) -> Result<(), AgentError> {
        // Read the slot count before locking the pool so the lock order stays
        // consistent with `acquire` (state before pool).
        let count = self.max_concurrent();
        let mut pool = self.pool.lock().unwrap();
        for _ in 0..count {
            pool.push(engine.create_context(n_ctx)?);
        }
        Ok(())
    }

    /// Blocks until a slot is free, then returns a permit carrying a pooled context if one is available.
    pub fn acquire(&self) -> InferencePermit<'_> {
        let mut state = self.state.lock().unwrap();
        while state.active >= state.max_concurrent {
            state = self.cond.wait(state).unwrap();
        }
        state.active += 1;

        let context = self.pool.lock().unwrap().pop();
        InferencePermit {
            scheduler: self,
            context,
        }
    }

    /// Returns a permit immediately if a slot is free, or `None` otherwise.
    pub fn try_acquire(&self) -> Option<InferencePermit<'_>> {
        let mut state = self.state.lock().unwrap();
        if state.active < state.max_concurrent {
            state.active += 1;
            let context = self.pool.lock().unwrap().pop();
            Some(InferencePermit {
                scheduler: self,
                context,
            })
        } else {
            None
        }
    }

    /// Returns a context to the pool (clearing its KV cache) and frees the slot.
    fn release(&self, context: Option<LlamaContext>) {
        if let Some(mut ctx) = context {
            ctx.kv_cache_clear();
            self.pool.lock().unwrap().push(ctx);
        }

        let mut state = self.state.lock().unwrap();
        state.active -= 1;
        self.cond.notify_one();
    }

    /// Number of permits currently outstanding.
    pub fn active_count(&self) -> usize {
        self.state.lock().unwrap().active
    }

    /// Maximum number of simultaneous permits.
    pub fn max_concurrent(&self) -> usize {
        self.state.lock().unwrap().max_concurrent
    }
}
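
// Illustrative sketch (assumes an `InferenceEngine` has already been loaded):
// bounding concurrency with a pre-warmed context pool.
#[allow(dead_code)]
fn example_scheduler(engine: &InferenceEngine) -> Result<(), AgentError> {
    let scheduler = InferenceScheduler::new(2);
    scheduler.init_pool(engine, None)?; // one context per slot

    let mut permit = scheduler.acquire(); // blocks until a slot is free
    if let Some(_ctx) = permit.context_mut() {
        // Run decoding/sampling against the pooled context here.
    }
    // Dropping the permit clears the context's KV cache and releases the slot.
    drop(permit);
    Ok(())
}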

#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::thread;

    #[test]
    fn test_scheduler_serialized() {
        let scheduler = Arc::new(InferenceScheduler::new(1));
        let counter = Arc::new(AtomicUsize::new(0));
        let max_seen = Arc::new(AtomicUsize::new(0));

        let mut handles = Vec::new();

        for _ in 0..4 {
            let sched = scheduler.clone();
            let cnt = counter.clone();
            let max = max_seen.clone();

            handles.push(thread::spawn(move || {
                let _permit = sched.acquire();
                let current = cnt.fetch_add(1, Ordering::SeqCst) + 1;
                max.fetch_max(current, Ordering::SeqCst);
                thread::sleep(std::time::Duration::from_millis(10));
                cnt.fetch_sub(1, Ordering::SeqCst);
            }));
        }

        for h in handles {
            h.join().unwrap();
        }

        // With max_concurrent = 1, the critical sections must never overlap.
        assert_eq!(max_seen.load(Ordering::SeqCst), 1);
    }

    #[test]
    fn test_scheduler_parallel() {
        let scheduler = Arc::new(InferenceScheduler::new(3));
        let counter = Arc::new(AtomicUsize::new(0));
        let max_seen = Arc::new(AtomicUsize::new(0));

        let mut handles = Vec::new();

        for _ in 0..6 {
            let sched = scheduler.clone();
            let cnt = counter.clone();
            let max = max_seen.clone();

            handles.push(thread::spawn(move || {
                let _permit = sched.acquire();
                let current = cnt.fetch_add(1, Ordering::SeqCst) + 1;
                max.fetch_max(current, Ordering::SeqCst);
                thread::sleep(std::time::Duration::from_millis(50));
                cnt.fetch_sub(1, Ordering::SeqCst);
            }));
        }

        for h in handles {
            h.join().unwrap();
        }

        // Some overlap should be observed, but never more than the limit of 3.
        assert!(max_seen.load(Ordering::SeqCst) > 1);
        assert!(max_seen.load(Ordering::SeqCst) <= 3);
    }

    #[test]
    fn test_try_acquire() {
        let scheduler = InferenceScheduler::new(1);
        let _permit = scheduler.acquire();
        assert!(scheduler.try_acquire().is_none());
        assert_eq!(scheduler.active_count(), 1);
    }
}