rig_llama_cpp/client.rs
use std::sync::Arc;
use std::sync::atomic::{AtomicBool, Ordering};
use std::thread;

use rig::client::CompletionClient;
use rig::completion::{
    CompletionError, CompletionModel, CompletionRequest, CompletionResponse, Usage,
};
use rig::streaming::StreamingCompletionResponse;
use tokio::sync::{mpsc, oneshot};
use tokio_stream::wrappers::UnboundedReceiverStream;

use crate::error::LoadError;
use crate::request::prepare_request;
use crate::types::{
    CheckpointParams, FitParams, InferenceCommand, InferenceParams, InferenceRequest,
    KvCacheParams, RawResponse, ReloadRequest, ResponseChannel, SamplingParams, StreamChunk,
};
use crate::worker::{WorkerInit, inference_worker};

/// Default context window used by [`ClientBuilder`] when `n_ctx` is not set.
const DEFAULT_N_CTX: u32 = 4096;

/// Capacity of the inference command channel. Bounded to apply backpressure
/// to misbehaving callers (a flood of requests can't grow the worker's queue
/// without limit). Eight is generous for a single-worker llama.cpp client —
/// generation is the bottleneck, not enqueueing — and leaves headroom for
/// `Reload` / `Shutdown` to slip in alongside in-flight `Request`s.
const COMMAND_CHANNEL_CAPACITY: usize = 8;

/// Builder for [`Client`].
///
/// Construct one with [`Client::builder`], then chain optional setters and
/// finish with [`ClientBuilder::build`]. Every field except `model_path`
/// has a sensible default, so the minimal usage is:
///
/// ```rust,no_run
/// # fn main() -> Result<(), Box<dyn std::error::Error>> {
/// let client = rig_llama_cpp::Client::builder("path/to/model.gguf").build()?;
/// # let _ = client;
/// # Ok(())
/// # }
/// ```
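///
/// A fuller sketch that chains the optional setters. The path is a
/// placeholder, and this assumes the parameter types (`SamplingParams`,
/// `KvCacheParams`) are re-exported at the crate root; adjust the imports if
/// they live elsewhere.
///
/// ```rust,ignore
/// use rig_llama_cpp::{Client, KvCacheParams, SamplingParams};
///
/// let client = Client::builder("path/to/model.gguf")
///     .n_ctx(8192)
///     .sampling(SamplingParams::default())
///     .kv_cache(KvCacheParams::default())
///     .build()?;
/// ```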
///
/// The builder shape is forward-compatible: new optional knobs can be added
/// without breaking existing call sites.
#[must_use]
pub struct ClientBuilder {
    model_path: String,
    #[cfg(feature = "mtmd")]
    mmproj_path: Option<String>,
    n_ctx: u32,
    sampling: SamplingParams,
    fit: FitParams,
    kv_cache: KvCacheParams,
    checkpoint: CheckpointParams,
}

impl ClientBuilder {
    fn new(model_path: impl Into<String>) -> Self {
        Self {
            model_path: model_path.into(),
            #[cfg(feature = "mtmd")]
            mmproj_path: None,
            n_ctx: DEFAULT_N_CTX,
            sampling: SamplingParams::default(),
            fit: FitParams::default(),
            kv_cache: KvCacheParams::default(),
            checkpoint: CheckpointParams::default(),
        }
    }

    /// Desired context window size in tokens. Defaults to `4096`.
    pub fn n_ctx(mut self, n_ctx: u32) -> Self {
        self.n_ctx = n_ctx;
        self
    }

    /// Token sampling parameters.
    pub fn sampling(mut self, sampling: SamplingParams) -> Self {
        self.sampling = sampling;
        self
    }

    /// Automatic-fit parameters (per-device memory margins, minimum context).
    pub fn fit(mut self, fit: FitParams) -> Self {
        self.fit = fit;
        self
    }

    /// KV cache data-type configuration. Defaults to F16 / F16.
    pub fn kv_cache(mut self, kv_cache: KvCacheParams) -> Self {
        self.kv_cache = kv_cache;
        self
    }

    /// In-memory state-checkpoint cache tunables (used by hybrid/recurrent
    /// models to preserve KV state across turns).
    pub fn checkpoints(mut self, checkpoint: CheckpointParams) -> Self {
        self.checkpoint = checkpoint;
        self
    }

    /// Path to a multimodal projector (`mmproj`) GGUF file. Setting this
    /// switches the resulting [`Client`] into vision mode. Only available
    /// when the `mtmd` feature is enabled.
    #[cfg(feature = "mtmd")]
    pub fn mmproj(mut self, mmproj_path: impl Into<String>) -> Self {
        self.mmproj_path = Some(mmproj_path.into());
        self
    }

    /// Spawn the inference worker thread, load the model, and return a
    /// ready-to-use [`Client`].
    ///
    /// # Errors
    ///
    /// Returns a [`LoadError`] if the backend fails to initialise, automatic
    /// fitting fails, the GGUF file cannot be loaded, or — when `mmproj` was
    /// set — the multimodal projector cannot be initialised.
    pub fn build(self) -> Result<Client, LoadError> {
        #[cfg(feature = "mtmd")]
        let mmproj_path = self.mmproj_path;
        #[cfg(not(feature = "mtmd"))]
        let mmproj_path: Option<String> = None;

        Client::spawn(
            self.model_path,
            mmproj_path,
            self.n_ctx,
            self.sampling,
            self.fit,
            self.kv_cache,
            self.checkpoint,
        )
    }
}

/// The llama.cpp completion client.
///
/// `Client` loads a GGUF model on a dedicated inference thread and exposes it
/// through Rig's [`CompletionClient`] trait. Construct one with
/// [`Client::builder`], or — for backward-compatible positional construction —
/// [`Client::from_gguf`].
///
/// # Lifecycle
///
/// The worker thread owns the `LlamaModel`, `LlamaContext`, and (when the
/// `mtmd` feature is on) the multimodal projector. It only releases that
/// memory when it exits, which happens in two cases:
///
/// - On [`Client::reload`], the worker drops the old model and loads the new
///   one in place — the `Client` itself is **not** dropped, and the worker
///   thread is reused. The caller blocks on the reload result.
/// - On [`Client::drop`], the worker thread is signalled and joined. See
///   [`impl Drop for Client`](Client#impl-Drop-for-Client) for the exact
///   semantics — including how a long in-flight generation is cancelled so
///   the dropping thread doesn't have to wait for it to finish naturally.
pub struct Client {
    request_tx: mpsc::Sender<InferenceCommand>,
    /// Shared shutdown flag. Set by [`Client::drop`] so the worker's prompt
    /// prefill and sampling loops short-circuit at their next polling point.
    /// Cloned into the worker via [`WorkerInit::cancel`].
    cancel: Arc<AtomicBool>,
    sampling_params: std::sync::RwLock<SamplingParams>,
    worker_handle: Option<thread::JoinHandle<()>>,
}

impl Client {
    /// Start a [`ClientBuilder`] for a GGUF model at `model_path`.
    pub fn builder(model_path: impl Into<String>) -> ClientBuilder {
        ClientBuilder::new(model_path)
    }

    /// Load a GGUF model with automatic GPU/CPU layer fitting and start the inference worker thread.
    ///
    /// llama.cpp will probe available device memory and determine the optimal layer
    /// distribution automatically.
    ///
    /// Prefer [`Client::builder`] for new code — this constructor is kept for
    /// backward compatibility with the positional 0.1.x API and forwards
    /// directly to the builder.
    ///
    /// # Arguments
    ///
    /// * `model_path` — Path to a `.gguf` model file.
    /// * `n_ctx` — Desired context window size in tokens.
    /// * `sampling_params` — Sampling parameters for token generation.
    /// * `fit_params` — Configuration for the fitting algorithm.
    /// * `kv_cache_params` — KV cache data-type configuration (defaults to F16/F16).
    /// * `checkpoint_params` — Tunables for the in-memory state-checkpoint cache
    ///   used to preserve KV/recurrent state across chat turns for hybrid models.
    ///
    /// # Errors
    ///
    /// Returns a [`LoadError`] if the backend fails to initialise, automatic
    /// fitting fails, or the model cannot be loaded.
    pub fn from_gguf(
        model_path: impl Into<String>,
        n_ctx: u32,
        sampling_params: SamplingParams,
        fit_params: FitParams,
        kv_cache_params: KvCacheParams,
        checkpoint_params: CheckpointParams,
    ) -> Result<Self, LoadError> {
        Self::spawn(
            model_path.into(),
            None,
            n_ctx,
            sampling_params,
            fit_params,
            kv_cache_params,
            checkpoint_params,
        )
    }

    /// Load a GGUF vision model with a multimodal projector and automatic GPU/CPU layer fitting.
    ///
    /// This constructor enables multimodal (vision) inference. The `mmproj_path` should point
    /// to a GGUF multimodal projector file (mmproj) that corresponds to the vision model.
    ///
    /// Prefer [`Client::builder`] with [`ClientBuilder::mmproj`] for new code.
    ///
    /// # Arguments
    ///
    /// * `model_path` — Path to a `.gguf` vision model file.
    /// * `mmproj_path` — Path to the corresponding multimodal projector `.gguf` file.
    /// * `n_ctx` — Desired context window size in tokens.
    /// * `sampling_params` — Sampling parameters for token generation.
    /// * `fit_params` — Configuration for the fitting algorithm.
    /// * `kv_cache_params` — KV cache data-type configuration (defaults to F16/F16).
    /// * `checkpoint_params` — Tunables for the in-memory state-checkpoint cache
    ///   used to preserve KV/recurrent state across chat turns for hybrid models.
    ///
    /// # Errors
    ///
    /// Returns a [`LoadError`] if the backend fails to initialise, the model
    /// cannot be loaded, or the multimodal projector cannot be initialised.
    #[cfg(feature = "mtmd")]
    pub fn from_gguf_with_mmproj(
        model_path: impl Into<String>,
        mmproj_path: impl Into<String>,
        n_ctx: u32,
        sampling_params: SamplingParams,
        fit_params: FitParams,
        kv_cache_params: KvCacheParams,
        checkpoint_params: CheckpointParams,
    ) -> Result<Self, LoadError> {
        Self::spawn(
            model_path.into(),
            Some(mmproj_path.into()),
            n_ctx,
            sampling_params,
            fit_params,
            kv_cache_params,
            checkpoint_params,
        )
    }

    /// Shared spawn path used by the builder and by the positional constructors.
    /// `mmproj_path` is only consulted when the `mtmd` feature is enabled; with
    /// the feature off, callers always pass `None` and the worker thread
    /// silently ignores any value.
    fn spawn(
        model_path: String,
        mmproj_path: Option<String>,
        n_ctx: u32,
        sampling_params: SamplingParams,
        fit_params: FitParams,
        kv_cache_params: KvCacheParams,
        checkpoint_params: CheckpointParams,
    ) -> Result<Self, LoadError> {
        let (request_tx, mut request_rx) =
            mpsc::channel::<InferenceCommand>(COMMAND_CHANNEL_CAPACITY);
        let (init_tx, init_rx) = std::sync::mpsc::channel::<Result<(), LoadError>>();
        let cancel = Arc::new(AtomicBool::new(false));
        let worker_cancel = Arc::clone(&cancel);

        let worker_handle = thread::spawn(move || {
            let init = WorkerInit {
                model_path: &model_path,
                mmproj_path: mmproj_path.as_deref(),
                n_ctx,
                fit_params: &fit_params,
                kv_cache_params: &kv_cache_params,
                checkpoint_params,
                cancel: worker_cancel,
            };
            inference_worker(init, init_tx, &mut request_rx);
        });

        init_rx
            .recv()
            .map_err(|_| LoadError::WorkerInitDisconnected)??;

        Ok(Self {
            request_tx,
            cancel,
            sampling_params: std::sync::RwLock::new(sampling_params),
            worker_handle: Some(worker_handle),
        })
    }

    /// Reload the worker thread with a new model without destroying the backend.
    ///
    /// This swaps the model in-place on the existing inference thread, avoiding the
    /// `LlamaBackend` singleton re-initialization race that occurs when dropping and
    /// recreating a `Client`.
    ///
    /// # Errors
    ///
    /// Returns [`LoadError::WorkerNotRunning`] if the inference worker is no
    /// longer accepting commands, or any of the load-stage variants if the
    /// new model fails to come up.
    // The positional signature is part of the 0.1.x public API. A future minor
    // release can introduce a `ReloadOptions`/`reload_builder` shape; until
    // then, the eight params (self + 7 fields) intentionally stay positional.
    #[allow(clippy::too_many_arguments)]
    pub fn reload(
        &self,
        model_path: String,
        mmproj_path: Option<String>,
        n_ctx: u32,
        sampling: SamplingParams,
        fit_params: FitParams,
        kv_cache_params: KvCacheParams,
        checkpoint_params: CheckpointParams,
    ) -> Result<(), LoadError> {
        let (result_tx, result_rx) = std::sync::mpsc::channel();
        // `blocking_send` is the right call here: `reload` is a sync API and
        // is documented to be invoked from a `spawn_blocking` task (or any
        // non-async thread) when used from a tokio context. Backpressure on a
        // full command queue is fine — reload is itself a blocking operation.
        self.request_tx
            .blocking_send(InferenceCommand::Reload(ReloadRequest {
                model_path,
                mmproj_path,
                n_ctx,
                fit_params,
                kv_cache_params,
                checkpoint_params,
                result_tx,
            }))
            .map_err(|_| LoadError::WorkerNotRunning)?;
        let result = result_rx
            .recv()
            .map_err(|_| LoadError::WorkerInitDisconnected)?;
        if result.is_ok() {
            // SamplingParams is `Copy` (just numeric scalars) — a poisoned
            // lock can't represent torn or invalid data, so recover the
            // guard rather than propagate a panic.
            let mut guard = self
                .sampling_params
                .write()
                .unwrap_or_else(|p| p.into_inner());
            *guard = sampling;
        }
        result
    }
}

impl Drop for Client {
    /// Tear down the worker thread synchronously.
    ///
    /// `Drop` blocks until the worker thread has fully exited and the
    /// `LlamaModel` / `LlamaContext` (and `LlamaBackend` device handles, plus
    /// the multimodal projector when the `mtmd` feature is on) are released.
    /// This is intentional: the caller almost always wants to allocate a
    /// replacement `Client` immediately after dropping this one, and a
    /// non-blocking drop would briefly hold 2× the model's RAM/VRAM and risk
    /// OOM. [`Client::reload`] reuses the same worker and avoids this whole
    /// path; prefer it over drop-and-recreate when you can.
    ///
    /// To keep the wait short even when a long generation is mid-flight,
    /// `Drop` flips the shared cancel flag before signalling shutdown. The
    /// worker polls the flag at every prompt-prefill chunk boundary and
    /// every sampled token, so an in-flight `Request` returns within a
    /// single decode step. The pessimal wait is therefore one decode step,
    /// not the rest of the generation.
    ///
    /// `try_send(Shutdown)` is best-effort: if the bounded command queue is
    /// full at this instant, the `Shutdown` command isn't enqueued — but the
    /// in-flight request still bails on the cancel flag, and the worker's
    /// per-iteration cancel check at the top of its command loop also exits
    /// the thread before pulling more queued commands.
    ///
    /// `Model` clones outliving the `Client` keep the channel sender count
    /// above zero; their `send` calls naturally fail with `SendError` once
    /// the receiver is dropped on worker exit, so they don't prevent
    /// shutdown.
    fn drop(&mut self) {
        self.cancel.store(true, Ordering::Relaxed);
        let _ = self.request_tx.try_send(InferenceCommand::Shutdown);

        if let Some(worker_handle) = self.worker_handle.take() {
            let _ = worker_handle.join();
        }
    }
}

impl CompletionClient for Client {
    type CompletionModel = Model;
}

/// A handle to a loaded model that implements Rig's [`CompletionModel`] trait.
///
/// Obtained via [`CompletionClient::agent`] on a [`Client`].
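///
/// A usage sketch, assuming Rig's usual agent flow (`agent` comes from
/// [`CompletionClient`]; the model name is informational for a local GGUF):
///
/// ```rust,ignore
/// use rig::client::CompletionClient;
/// use rig::completion::Prompt;
///
/// let client = rig_llama_cpp::Client::builder("path/to/model.gguf").build()?;
/// let agent = client.agent("local-gguf").build();
/// let reply = agent.prompt("Say hello.").await?;
/// ```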
#[derive(Clone)]
pub struct Model {
    request_tx: mpsc::Sender<InferenceCommand>,
    sampling_params: SamplingParams,
    #[allow(dead_code)]
    model_id: String,
}

impl CompletionModel for Model {
    type Response = RawResponse;
    type StreamingResponse = StreamChunk;
    type Client = Client;

    fn make(client: &Client, model: impl Into<String>) -> Self {
        // See the matching `unwrap_or_else` in `reload`: SamplingParams is
        // `Copy`, so a poisoned lock still holds valid data.
        let sampling_params = *client
            .sampling_params
            .read()
            .unwrap_or_else(|p| p.into_inner());
        Self {
            request_tx: client.request_tx.clone(),
            sampling_params,
            model_id: model.into(),
        }
    }

    async fn completion(
        &self,
        request: CompletionRequest,
    ) -> Result<CompletionResponse<Self::Response>, CompletionError> {
        let prepared_request = prepare_request(&request).map_err(CompletionError::ProviderError)?;
        let max_tokens = request.max_tokens.unwrap_or(512) as u32;
        let temperature = request.temperature.unwrap_or(0.7) as f32;

        let (response_tx, response_rx) = oneshot::channel();

        self.request_tx
            .send(InferenceCommand::Request(InferenceRequest {
                params: InferenceParams {
                    prepared_request,
                    max_tokens,
                    temperature,
                    top_p: self.sampling_params.top_p,
                    top_k: self.sampling_params.top_k,
                    min_p: self.sampling_params.min_p,
                    presence_penalty: self.sampling_params.presence_penalty,
                    repetition_penalty: self.sampling_params.repetition_penalty,
                },
                response_channel: ResponseChannel::Completion(response_tx),
            }))
            .await
            .map_err(|_| CompletionError::ProviderError("Inference thread shut down".into()))?;

        let result = response_rx
            .await
            .map_err(|_| CompletionError::ProviderError("Response channel closed".into()))?
            .map_err(CompletionError::ProviderError)?;

        Ok(CompletionResponse {
            choice: result.choice,
            usage: Usage {
                input_tokens: result.prompt_tokens,
                output_tokens: result.completion_tokens,
                total_tokens: result.prompt_tokens + result.completion_tokens,
                cached_input_tokens: result.cached_input_tokens,
                cache_creation_input_tokens: 0,
            },
            raw_response: RawResponse { text: result.text },
            message_id: None,
        })
    }

    async fn stream(
        &self,
        request: CompletionRequest,
    ) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
        let prepared_request = prepare_request(&request).map_err(CompletionError::ProviderError)?;
        let max_tokens = request.max_tokens.unwrap_or(512) as u32;
        let temperature = request.temperature.unwrap_or(0.7) as f32;

        let (stream_tx, stream_rx) = mpsc::unbounded_channel();

        self.request_tx
            .send(InferenceCommand::Request(InferenceRequest {
                params: InferenceParams {
                    prepared_request,
                    max_tokens,
                    temperature,
                    top_p: self.sampling_params.top_p,
                    top_k: self.sampling_params.top_k,
                    min_p: self.sampling_params.min_p,
                    presence_penalty: self.sampling_params.presence_penalty,
                    repetition_penalty: self.sampling_params.repetition_penalty,
                },
                response_channel: ResponseChannel::Streaming(stream_tx),
            }))
            .await
            .map_err(|_| CompletionError::ProviderError("Inference thread shut down".into()))?;

        Ok(StreamingCompletionResponse::stream(Box::pin(
            UnboundedReceiverStream::new(stream_rx),
        )))
    }
}