llama_cpp_v3_agent_sdk/inference.rs

//! Shared inference engine — load a model once, share it across agents.
//!
//! This module follows the same pattern as `vnai::ai::TextGeneration`:
//! the heavy resources (`LlamaBackend`, `LlamaModel`) are wrapped in `Arc`
//! so they can be cloned cheaply and shared between multiple `Agent` instances.
//! Each agent creates its own `LlamaContext` (KV cache), so agents don't
//! interfere with each other.
//!
//! ## Concurrency
//!
//! - **Without a scheduler**: Agents run fully in parallel (safe, but GPU-heavy).
//! - **With `InferenceScheduler`**: A semaphore limits how many agents can
//!   run inference concurrently. `max_concurrent = 1` serializes all inference
//!   (see the sketch below).
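//!
//! A minimal sketch of how these pieces fit together (an `Agent` built via
//! `AgentBuilder` would normally drive the `acquire()` step itself):
//!
//! ```no_run
//! use llama_cpp_v3_agent_sdk::inference::{InferenceEngine, InferenceConfig};
//! use llama_cpp_v3_agent_sdk::InferenceScheduler;
//! use std::sync::Arc;
//!
//! // Load the heavy resources (DLL + model weights) exactly once.
//! let engine = Arc::new(InferenceEngine::load(InferenceConfig {
//!     model_path: "model.gguf".into(),
//!     ..Default::default()
//! }).expect("Failed to load model"));
//!
//! // One slot: all inference is serialized across agents.
//! let scheduler = Arc::new(InferenceScheduler::new(1));
//! scheduler.init_pool(&engine, None).expect("Failed to pre-create contexts");
//!
//! // Each inference run acquires a permit (and a pooled context) first.
//! let mut permit = scheduler.acquire();
//! let _ctx = permit.context_mut();
//! ```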

use crate::error::AgentError;
use llama_cpp_v3::{LlamaBackend, LlamaContext, LlamaModel, LoadOptions};
use std::path::{Path, PathBuf};
use std::sync::{Arc, Condvar, Mutex};

/// Configuration for loading a model.
#[derive(Debug, Clone)]
pub struct InferenceConfig {
    /// Compute backend.
    pub backend: llama_cpp_v3::backend::Backend,
    /// Path to the GGUF model file.
    pub model_path: String,
    /// Number of layers to offload to GPU (-1 = all).
    pub n_gpu_layers: i32,
    /// Context window size in tokens (default for contexts created from this engine).
    pub n_ctx: u32,
    /// Application name (used for DLL cache directory).
    pub app_name: String,
    /// Explicit DLL path (skips auto-download).
    pub explicit_dll_path: Option<PathBuf>,
    /// DLL version tag to download.
    pub dll_version: Option<String>,
    /// DLL cache directory.
    pub cache_dir: Option<PathBuf>,
    /// Optional chat template (Jinja). If not provided, uses model metadata.
    pub chat_template: Option<String>,
}

impl Default for InferenceConfig {
    fn default() -> Self {
        Self {
            backend: llama_cpp_v3::backend::Backend::Cpu,
            model_path: String::new(),
            n_gpu_layers: 0,
            n_ctx: 8192,
            app_name: "llama-cpp-v3-agent-sdk".to_string(),
            explicit_dll_path: None,
            dll_version: None,
            cache_dir: None,
            chat_template: None,
        }
    }
}

/// Common chat templates for models that lack them.
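///
/// For example, a model whose GGUF metadata lacks a template can be given one
/// explicitly via `InferenceConfig::chat_template` (a sketch; any of the
/// constants below can be substituted):
///
/// ```
/// use llama_cpp_v3_agent_sdk::inference::{templates, InferenceConfig};
///
/// let config = InferenceConfig {
///     model_path: "model.gguf".into(),
///     chat_template: Some(templates::CHATML.to_string()),
///     ..Default::default()
/// };
/// ```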
pub mod templates {
    /// Llama-3 / Llama-3.1 chat template.
    pub const LLAMA_3: &str = "{% set loop_messages = messages %}{% for message in loop_messages %}{% if loop.index0 == 0 %}{{ bos_token }}{% endif %}{{ '<|start_header_id|>' + message['role'] + '<|end_header_id|>\\n\\n' + message['content'] | trim + '<|eot_id|>' }}{% endfor %}{% if add_generation_prompt %}{{ '<|start_header_id|>assistant<|end_header_id|>\\n\\n' }}{% endif %}";

    /// ChatML template (used by Qwen, Yi, Hermes, etc.).
    pub const CHATML: &str = "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\\n' + message['content'] + '<|im_end|>\\n'}}{% endfor %}{% if add_generation_prompt %}{{'<|im_start|>assistant\\n'}}{% endif %}";

    /// Llama-2 template.
    pub const LLAMA_2: &str = "{% for message in messages %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'system' %}{{ '<<SYS>>\\n' + message['content'] + '\\n<</SYS>>\\n\\n' }}{% else %}{{ message['content'] }}{% endif %}{% endfor %}";
}

/// Shared inference engine that holds the backend + model in `Arc`s.
///
/// This is the resource-heavy object: it loads the DLL and the model weights.
/// Multiple agents can share the same `InferenceEngine` — each agent just
/// creates its own lightweight `LlamaContext`.
///
/// # Example
/// ```no_run
/// use llama_cpp_v3_agent_sdk::inference::{InferenceEngine, InferenceConfig};
/// use llama_cpp_v3::backend::Backend;
/// use std::sync::Arc;
///
/// let engine = Arc::new(InferenceEngine::load(InferenceConfig {
///     backend: Backend::Vulkan,
///     model_path: "model.gguf".into(),
///     n_gpu_layers: 99,
///     ..Default::default()
/// }).expect("Failed to load model"));
///
/// // Share with multiple agents:
/// let agent1 = llama_cpp_v3_agent_sdk::AgentBuilder::new()
///     .engine(engine.clone())
///     .build().unwrap();
/// let agent2 = llama_cpp_v3_agent_sdk::AgentBuilder::new()
///     .engine(engine.clone())
///     .system_prompt("You are a different agent.")
///     .build().unwrap();
/// ```
#[derive(Clone)]
pub struct InferenceEngine {
    pub backend: Arc<LlamaBackend>,
    pub model: Arc<LlamaModel>,
    pub config: InferenceConfig,
}

impl InferenceEngine {
    /// Load a model from the given configuration.
    ///
    /// This performs the expensive operations (DLL loading, model weight loading)
    /// exactly once. The returned engine can be wrapped in `Arc` and shared.
    pub fn load(config: InferenceConfig) -> Result<Self, AgentError> {
        let path = Path::new(&config.model_path);
        if !path.exists() {
            return Err(AgentError::Other(format!(
                "Model file not found: {}",
                config.model_path
            )));
        }

        let options = LoadOptions {
            backend: config.backend,
            app_name: &config.app_name,
            version: config.dll_version.as_deref(),
            explicit_path: config.explicit_dll_path.as_deref(),
            cache_dir: config.cache_dir.clone(),
        };

        let backend = LlamaBackend::load(options)?;

        let mut model_params = LlamaModel::default_params(&backend);
        model_params.n_gpu_layers = config.n_gpu_layers;

        let path_str = config.model_path.replace('\\', "/");
        let model = LlamaModel::load_from_file(&backend, &path_str, model_params)?;

        Ok(Self {
            backend: Arc::new(backend),
            model: Arc::new(model),
            config,
        })
    }

    /// Create a new `LlamaContext` from this engine.
    ///
    /// Each agent should have its own context (it holds the KV cache).
    /// The `n_ctx` override lets callers use a different context size
    /// than the engine default.
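    ///
    /// A minimal sketch, reusing the loading pattern from the type-level example:
    ///
    /// ```no_run
    /// use llama_cpp_v3_agent_sdk::inference::{InferenceEngine, InferenceConfig};
    ///
    /// let engine = InferenceEngine::load(InferenceConfig {
    ///     model_path: "model.gguf".into(),
    ///     ..Default::default()
    /// }).expect("Failed to load model");
    ///
    /// // A 4096-token context for one agent; pass `None` to use the engine default.
    /// let ctx = engine.create_context(Some(4096)).expect("Failed to create context");
    /// ```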
    pub fn create_context(&self, n_ctx_override: Option<u32>) -> Result<LlamaContext, AgentError> {
        let mut ctx_params = LlamaContext::default_params(&self.model);
        ctx_params.n_ctx = n_ctx_override.unwrap_or(self.config.n_ctx);
        let ctx = LlamaContext::new(&self.model, ctx_params)?;
        Ok(ctx)
    }

    /// Access the raw `LlamaModel`.
    pub fn model(&self) -> &LlamaModel {
        &self.model
    }

    /// Access the raw `LlamaBackend`.
    pub fn backend(&self) -> &LlamaBackend {
        &self.backend
    }

    /// Get the `Arc<LlamaLib>` for creating samplers and batches.
    pub fn lib(&self) -> Arc<llama_cpp_sys_v3::LlamaLib> {
        self.backend.lib.clone()
    }
}

// ─────────────────────────────────────────────────────────────────────────────
// Inference Scheduler
// ─────────────────────────────────────────────────────────────────────────────

/// Controls how many agents can perform inference at the same time.
///
/// This is a simple counting semaphore: agents call `acquire()` before running
/// their inference loop, and the slot is released automatically when the returned
/// `InferencePermit` is dropped. If `max_concurrent` slots are already in use,
/// `acquire()` blocks until one is freed.
///
/// # Why?
///
/// Each agent has its own `LlamaContext` (KV cache) which is independent and
/// thread-safe. But all contexts share the same GPU for compute. Running too
/// many inferences in parallel can:
/// - Exhaust GPU VRAM (multiple KV caches)
/// - Thrash the GPU scheduler (context switches)
/// - Cause OOM errors on smaller GPUs
///
/// A scheduler with `max_concurrent = 1` serializes all inference (like the
/// worker-thread pattern in `vnai::ai`), while higher values allow controlled
/// parallelism.
///
/// # Example
/// ```
/// use llama_cpp_v3_agent_sdk::InferenceScheduler;
/// use std::sync::Arc;
///
/// // Allow at most 2 agents to infer concurrently:
/// let scheduler = Arc::new(InferenceScheduler::new(2));
///
/// // Use with AgentBuilder:
/// // AgentBuilder::new()
/// //     .engine(engine.clone())
/// //     .scheduler(scheduler.clone())
/// //     .build()?;
/// ```
pub struct InferenceScheduler {
    state: Mutex<SchedulerState>,
    cond: Condvar,
    pool: Mutex<Vec<LlamaContext>>,
}

struct SchedulerState {
    max_concurrent: usize,
    active: usize,
}

/// RAII guard — releases the scheduler slot and returns the context to the pool on drop.
pub struct InferencePermit<'a> {
    scheduler: &'a InferenceScheduler,
    context: Option<LlamaContext>,
}

impl<'a> InferencePermit<'a> {
    /// Access the leased context, or `None` if no pooled context was available
    /// when the permit was acquired.
    pub fn context_mut(&mut self) -> Option<&mut LlamaContext> {
        self.context.as_mut()
    }
}

impl<'a> Drop for InferencePermit<'a> {
    fn drop(&mut self) {
        let ctx = self.context.take();
        self.scheduler.release(ctx);
    }
}

impl InferenceScheduler {
    /// Create a new scheduler with the given concurrency limit.
    ///
    /// - `max_concurrent = 1` → fully serialized (one agent at a time)
    /// - `max_concurrent = N` → up to N agents run inference in parallel
    pub fn new(max_concurrent: usize) -> Self {
        assert!(max_concurrent > 0, "max_concurrent must be at least 1");
        Self {
            state: Mutex::new(SchedulerState {
                max_concurrent,
                active: 0,
            }),
            cond: Condvar::new(),
            pool: Mutex::new(Vec::with_capacity(max_concurrent)),
        }
    }

    /// Pre-initialize the context pool with the given engine.
    /// This avoids lazy allocation during the first inference runs.
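    ///
    /// A minimal sketch:
    ///
    /// ```no_run
    /// use llama_cpp_v3_agent_sdk::inference::{InferenceEngine, InferenceConfig};
    /// use llama_cpp_v3_agent_sdk::InferenceScheduler;
    ///
    /// let engine = InferenceEngine::load(InferenceConfig {
    ///     model_path: "model.gguf".into(),
    ///     ..Default::default()
    /// }).expect("Failed to load model");
    ///
    /// let scheduler = InferenceScheduler::new(2);
    /// // One context per slot, created up front with the engine's default `n_ctx`.
    /// scheduler.init_pool(&engine, None).expect("Failed to create contexts");
    /// ```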
    pub fn init_pool(
        &self,
        engine: &InferenceEngine,
        n_ctx: Option<u32>,
    ) -> Result<(), AgentError> {
        let mut pool = self.pool.lock().unwrap();
        let count = self.max_concurrent();
        for _ in 0..count {
            pool.push(engine.create_context(n_ctx)?);
        }
        Ok(())
    }

    /// Acquire a permit and a context from the pool. Blocks if all slots are in use.
    ///
    /// Returns an RAII guard that automatically releases the slot on drop. The
    /// permit's context is `None` if the pool had no free context (for example,
    /// if `init_pool` was never called).
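    ///
    /// A minimal sketch of the acquire/release cycle:
    ///
    /// ```
    /// use llama_cpp_v3_agent_sdk::InferenceScheduler;
    ///
    /// let scheduler = InferenceScheduler::new(1);
    /// let mut permit = scheduler.acquire();
    /// if let Some(_ctx) = permit.context_mut() {
    ///     // Run the inference loop with the leased context.
    /// }
    /// // Dropping the permit frees the slot and returns the context to the pool.
    /// drop(permit);
    /// ```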
    pub fn acquire(&self) -> InferencePermit<'_> {
        let mut state = self.state.lock().unwrap();
        while state.active >= state.max_concurrent {
            state = self.cond.wait(state).unwrap();
        }
        state.active += 1;

        let context = self.pool.lock().unwrap().pop();
        InferencePermit {
            scheduler: self,
            context,
        }
    }

    /// Try to acquire a permit without blocking.
    ///
    /// Returns `None` if all slots are in use.
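    ///
    /// A minimal sketch of a non-blocking attempt:
    ///
    /// ```
    /// use llama_cpp_v3_agent_sdk::InferenceScheduler;
    ///
    /// let scheduler = InferenceScheduler::new(1);
    /// if let Some(mut permit) = scheduler.try_acquire() {
    ///     let _ctx = permit.context_mut();
    ///     // Run inference on this thread.
    /// } else {
    ///     // All slots are busy: back off or queue the request.
    /// }
    /// ```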
    pub fn try_acquire(&self) -> Option<InferencePermit<'_>> {
        let mut state = self.state.lock().unwrap();
        if state.active < state.max_concurrent {
            state.active += 1;
            let context = self.pool.lock().unwrap().pop();
            Some(InferencePermit {
                scheduler: self,
                context,
            })
        } else {
            None
        }
    }

    /// Release a slot and return the context to the pool.
    fn release(&self, context: Option<LlamaContext>) {
        if let Some(mut ctx) = context {
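            // Reset the KV cache so the next permit holder starts from a clean context.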
            ctx.kv_cache_clear();
            self.pool.lock().unwrap().push(ctx);
        }

        let mut state = self.state.lock().unwrap();
        state.active -= 1;
        self.cond.notify_one();
    }

    /// Number of currently active inferences.
    pub fn active_count(&self) -> usize {
        self.state.lock().unwrap().active
    }

    /// Maximum allowed concurrent inferences.
    pub fn max_concurrent(&self) -> usize {
        self.state.lock().unwrap().max_concurrent
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::thread;

    #[test]
    fn test_scheduler_serialized() {
        let scheduler = Arc::new(InferenceScheduler::new(1));
        let counter = Arc::new(AtomicUsize::new(0));
        let max_seen = Arc::new(AtomicUsize::new(0));

        let mut handles = Vec::new();

        for _ in 0..4 {
            let sched = scheduler.clone();
            let cnt = counter.clone();
            let max = max_seen.clone();

            handles.push(thread::spawn(move || {
                let _permit = sched.acquire();
                let current = cnt.fetch_add(1, Ordering::SeqCst) + 1;
                // Record the max concurrency we observe
                max.fetch_max(current, Ordering::SeqCst);
                thread::sleep(std::time::Duration::from_millis(10));
                cnt.fetch_sub(1, Ordering::SeqCst);
            }));
        }

        for h in handles {
            h.join().unwrap();
        }

        // With max_concurrent=1, we should never have seen > 1 active
        assert_eq!(max_seen.load(Ordering::SeqCst), 1);
    }

    #[test]
    fn test_scheduler_parallel() {
        let scheduler = Arc::new(InferenceScheduler::new(3));
        let counter = Arc::new(AtomicUsize::new(0));
        let max_seen = Arc::new(AtomicUsize::new(0));

        let mut handles = Vec::new();

        for _ in 0..6 {
            let sched = scheduler.clone();
            let cnt = counter.clone();
            let max = max_seen.clone();

            handles.push(thread::spawn(move || {
                let _permit = sched.acquire();
                let current = cnt.fetch_add(1, Ordering::SeqCst) + 1;
                max.fetch_max(current, Ordering::SeqCst);
                thread::sleep(std::time::Duration::from_millis(50));
                cnt.fetch_sub(1, Ordering::SeqCst);
            }));
        }

        for h in handles {
            h.join().unwrap();
        }

        // With max_concurrent=3, should have seen > 1 active (likely 3)
        assert!(max_seen.load(Ordering::SeqCst) > 1);
        // And never more than 3
        assert!(max_seen.load(Ordering::SeqCst) <= 3);
    }

    #[test]
    fn test_try_acquire() {
        let scheduler = InferenceScheduler::new(1);
        let _permit = scheduler.acquire();
        assert!(scheduler.try_acquire().is_none());
        assert_eq!(scheduler.active_count(), 1);
    }
}