// rig_llama_cpp/client.rs

use std::sync::Arc;
use std::sync::atomic::{AtomicBool, Ordering};
use std::thread;

use rig::client::CompletionClient;
use rig::completion::{
    CompletionError, CompletionModel, CompletionRequest, CompletionResponse, Usage,
};
use rig::streaming::StreamingCompletionResponse;
use tokio::sync::{mpsc, oneshot};
use tokio_stream::wrappers::UnboundedReceiverStream;

use crate::error::LoadError;
use crate::request::prepare_request;
use crate::types::{
    CheckpointParams, FitParams, InferenceCommand, InferenceParams, InferenceRequest,
    KvCacheParams, RawResponse, ReloadRequest, ResponseChannel, SamplingParams, StreamChunk,
};
use crate::worker::{WorkerInit, inference_worker};

/// Default context window used by [`ClientBuilder`] when `n_ctx` is not set.
const DEFAULT_N_CTX: u32 = 4096;

/// Capacity of the inference command channel. Bounded to apply backpressure
/// to misbehaving callers (a flood of requests can't grow the worker's queue
/// without limit). Eight is generous for a single-worker llama.cpp client —
/// generation is the bottleneck, not enqueueing — and leaves headroom for
/// `Reload` / `Shutdown` to slip in alongside in-flight `Request`s.
const COMMAND_CHANNEL_CAPACITY: usize = 8;

/// Builder for [`Client`].
///
/// Construct one with [`Client::builder`], then chain optional setters and
/// finish with [`ClientBuilder::build`]. Every field except `model_path`
/// has a sensible default, so the minimal usage is:
///
/// ```rust,no_run
/// # fn main() -> Result<(), Box<dyn std::error::Error>> {
/// let client = rig_llama_cpp::Client::builder("path/to/model.gguf").build()?;
/// # let _ = client;
/// # Ok(())
/// # }
/// ```
///
/// The builder shape is forward-compatible: new optional knobs can be added
/// without breaking existing call sites.
#[must_use]
pub struct ClientBuilder {
    model_path: String,
    #[cfg(feature = "mtmd")]
    mmproj_path: Option<String>,
    n_ctx: u32,
    sampling: SamplingParams,
    fit: FitParams,
    kv_cache: KvCacheParams,
    checkpoint: CheckpointParams,
}

impl ClientBuilder {
    fn new(model_path: impl Into<String>) -> Self {
        Self {
            model_path: model_path.into(),
            #[cfg(feature = "mtmd")]
            mmproj_path: None,
            n_ctx: DEFAULT_N_CTX,
            sampling: SamplingParams::default(),
            fit: FitParams::default(),
            kv_cache: KvCacheParams::default(),
            checkpoint: CheckpointParams::default(),
        }
    }

    /// Desired context window size in tokens. Defaults to `4096`.
    pub fn n_ctx(mut self, n_ctx: u32) -> Self {
        self.n_ctx = n_ctx;
        self
    }

    /// Token sampling parameters.
    pub fn sampling(mut self, sampling: SamplingParams) -> Self {
        self.sampling = sampling;
        self
    }

    /// Automatic-fit parameters (per-device memory margins, minimum context).
    pub fn fit(mut self, fit: FitParams) -> Self {
        self.fit = fit;
        self
    }

    /// KV cache data-type configuration. Defaults to F16 / F16.
    pub fn kv_cache(mut self, kv_cache: KvCacheParams) -> Self {
        self.kv_cache = kv_cache;
        self
    }

    /// In-memory state-checkpoint cache tunables (used by hybrid/recurrent
    /// models to preserve KV state across turns).
    pub fn checkpoints(mut self, checkpoint: CheckpointParams) -> Self {
        self.checkpoint = checkpoint;
        self
    }

    /// Path to a multimodal projector (`mmproj`) GGUF file. Setting this
    /// switches the resulting [`Client`] into vision mode. Only available
    /// when the `mtmd` feature is enabled.
    #[cfg(feature = "mtmd")]
    pub fn mmproj(mut self, mmproj_path: impl Into<String>) -> Self {
        self.mmproj_path = Some(mmproj_path.into());
        self
    }

    /// Spawn the inference worker thread, load the model, and return a
    /// ready-to-use [`Client`].
    ///
    /// # Errors
    ///
    /// Returns a [`LoadError`] if the backend fails to initialise, automatic
    /// fitting fails, the GGUF file cannot be loaded, or — when `mmproj` was
    /// set — the multimodal projector cannot be initialised.
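    ///
    /// A slightly fuller sketch chaining an optional knob (the path and
    /// context size are illustrative, not recommendations):
    ///
    /// ```rust,no_run
    /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
    /// let client = rig_llama_cpp::Client::builder("path/to/model.gguf")
    ///     .n_ctx(8192)
    ///     .build()?;
    /// # let _ = client;
    /// # Ok(())
    /// # }
    /// ```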
    pub fn build(self) -> Result<Client, LoadError> {
        #[cfg(feature = "mtmd")]
        let mmproj_path = self.mmproj_path;
        #[cfg(not(feature = "mtmd"))]
        let mmproj_path: Option<String> = None;

        Client::spawn(
            self.model_path,
            mmproj_path,
            self.n_ctx,
            self.sampling,
            self.fit,
            self.kv_cache,
            self.checkpoint,
        )
    }
}

/// The llama.cpp completion client.
///
/// `Client` loads a GGUF model on a dedicated inference thread and exposes it
/// through Rig's [`CompletionClient`] trait. Construct one with
/// [`Client::builder`], or — for backward-compatible positional construction —
/// [`Client::from_gguf`].
///
/// # Lifecycle
///
/// The worker thread owns the `LlamaModel`, `LlamaContext`, and (when the
/// `mtmd` feature is on) the multimodal projector. It only releases that
/// memory when it exits, which happens in two cases:
///
/// - On [`Client::reload`], the worker drops the old model and loads the new
///   one in place — the `Client` itself is **not** dropped, and the worker
///   thread is reused. The caller blocks until the reload completes.
/// - On [`Client::drop`], the worker thread is signalled and joined. See
///   [`impl Drop for Client`](Client#impl-Drop-for-Client) for the exact
///   semantics — including how a long in-flight generation is cancelled so
///   the dropping thread doesn't have to wait for it to finish naturally.
pub struct Client {
    request_tx: mpsc::Sender<InferenceCommand>,
    /// Shared shutdown flag. Set by [`Client::drop`] so the worker's prompt
    /// prefill and sampling loops short-circuit at their next polling point.
    /// Cloned into the worker via [`WorkerInit::cancel`].
    cancel: Arc<AtomicBool>,
    sampling_params: std::sync::RwLock<SamplingParams>,
    worker_handle: Option<thread::JoinHandle<()>>,
}

impl Client {
    /// Start a [`ClientBuilder`] for a GGUF model at `model_path`.
    pub fn builder(model_path: impl Into<String>) -> ClientBuilder {
        ClientBuilder::new(model_path)
    }

    /// Load a GGUF model with automatic GPU/CPU layer fitting and start the inference worker thread.
    ///
    /// llama.cpp will probe available device memory and determine the optimal layer
    /// distribution automatically.
    ///
    /// Prefer [`Client::builder`] for new code — this constructor is kept for
    /// backward compatibility with the positional 0.1.x API and forwards to
    /// the same internal spawn path as the builder.

    /// # Arguments
    ///
    /// * `model_path` — Path to a `.gguf` model file.
    /// * `n_ctx` — Desired context window size in tokens.
    /// * `sampling_params` — Sampling parameters for token generation.
    /// * `fit_params` — Configuration for the fitting algorithm.
    /// * `kv_cache_params` — KV cache data-type configuration (defaults to F16/F16).
    /// * `checkpoint_params` — Tunables for the in-memory state-checkpoint cache
    ///   used to preserve KV/recurrent state across chat turns for hybrid models.
    ///
    /// # Errors
    ///
    /// Returns a [`LoadError`] if the backend fails to initialise, automatic
    /// fitting fails, or the model cannot be loaded.
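    ///
    /// A positional-construction sketch relying on the parameter types'
    /// `Default` impls (the path is a placeholder):
    ///
    /// ```rust,no_run
    /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
    /// let client = rig_llama_cpp::Client::from_gguf(
    ///     "path/to/model.gguf",
    ///     4096,
    ///     Default::default(),
    ///     Default::default(),
    ///     Default::default(),
    ///     Default::default(),
    /// )?;
    /// # let _ = client;
    /// # Ok(())
    /// # }
    /// ```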
    pub fn from_gguf(
        model_path: impl Into<String>,
        n_ctx: u32,
        sampling_params: SamplingParams,
        fit_params: FitParams,
        kv_cache_params: KvCacheParams,
        checkpoint_params: CheckpointParams,
    ) -> Result<Self, LoadError> {
        Self::spawn(
            model_path.into(),
            None,
            n_ctx,
            sampling_params,
            fit_params,
            kv_cache_params,
            checkpoint_params,
        )
    }

    /// Load a GGUF vision model with a multimodal projector and automatic GPU/CPU layer fitting.
    ///
    /// This constructor enables multimodal (vision) inference. The `mmproj_path` should point
    /// to a GGUF multimodal projector file (mmproj) that corresponds to the vision model.
    ///
    /// Prefer [`Client::builder`] with [`ClientBuilder::mmproj`] for new code.
    ///
    /// # Arguments
    ///
    /// * `model_path` — Path to a `.gguf` vision model file.
    /// * `mmproj_path` — Path to the corresponding multimodal projector `.gguf` file.
    /// * `n_ctx` — Desired context window size in tokens.
    /// * `sampling_params` — Sampling parameters for token generation.
    /// * `fit_params` — Configuration for the fitting algorithm.
    /// * `kv_cache_params` — KV cache data-type configuration (defaults to F16/F16).
    /// * `checkpoint_params` — Tunables for the in-memory state-checkpoint cache
    ///   used to preserve KV/recurrent state across chat turns for hybrid models.
    ///
    /// # Errors
    ///
    /// Returns a [`LoadError`] if the backend fails to initialise, the model
    /// cannot be loaded, or the multimodal projector cannot be initialised.
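    ///
    /// The recommended builder route for the same setup (both paths are
    /// placeholders):
    ///
    /// ```rust,no_run
    /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
    /// let client = rig_llama_cpp::Client::builder("path/to/model.gguf")
    ///     .mmproj("path/to/mmproj.gguf")
    ///     .build()?;
    /// # let _ = client;
    /// # Ok(())
    /// # }
    /// ```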
    #[cfg(feature = "mtmd")]
    pub fn from_gguf_with_mmproj(
        model_path: impl Into<String>,
        mmproj_path: impl Into<String>,
        n_ctx: u32,
        sampling_params: SamplingParams,
        fit_params: FitParams,
        kv_cache_params: KvCacheParams,
        checkpoint_params: CheckpointParams,
    ) -> Result<Self, LoadError> {
        Self::spawn(
            model_path.into(),
            Some(mmproj_path.into()),
            n_ctx,
            sampling_params,
            fit_params,
            kv_cache_params,
            checkpoint_params,
        )
    }

    /// Shared spawn path used by the builder and by the positional constructors.
    /// `mmproj_path` is only consulted when the `mtmd` feature is enabled; with
    /// the feature off, callers always pass `None` and the worker thread
    /// silently ignores any value.
    fn spawn(
        model_path: String,
        mmproj_path: Option<String>,
        n_ctx: u32,
        sampling_params: SamplingParams,
        fit_params: FitParams,
        kv_cache_params: KvCacheParams,
        checkpoint_params: CheckpointParams,
    ) -> Result<Self, LoadError> {
        let (request_tx, mut request_rx) =
            mpsc::channel::<InferenceCommand>(COMMAND_CHANNEL_CAPACITY);
        let (init_tx, init_rx) = std::sync::mpsc::channel::<Result<(), LoadError>>();
        let cancel = Arc::new(AtomicBool::new(false));
        let worker_cancel = Arc::clone(&cancel);

        let worker_handle = thread::spawn(move || {
            let init = WorkerInit {
                model_path: &model_path,
                mmproj_path: mmproj_path.as_deref(),
                n_ctx,
                fit_params: &fit_params,
                kv_cache_params: &kv_cache_params,
                checkpoint_params,
                cancel: worker_cancel,
            };
            inference_worker(init, init_tx, &mut request_rx);
        });

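        // Block until the worker reports the load outcome. The outer `?`
        // covers a worker that died before replying (channel disconnected);
        // the inner `?` propagates the load result itself.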
        init_rx
            .recv()
            .map_err(|_| LoadError::WorkerInitDisconnected)??;

        Ok(Self {
            request_tx,
            cancel,
            sampling_params: std::sync::RwLock::new(sampling_params),
            worker_handle: Some(worker_handle),
        })
    }

    /// Reload the worker thread with a new model without destroying the backend.
    ///
    /// This swaps the model in-place on the existing inference thread, avoiding the
    /// `LlamaBackend` singleton re-initialisation race that occurs when dropping and
    /// recreating a `Client`.
    ///
    /// # Errors
    ///
    /// Returns [`LoadError::WorkerNotRunning`] if the inference worker is no
    /// longer accepting commands, or any of the load-stage variants if the
    /// new model fails to come up.
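    ///
    /// `reload` blocks, and the underlying `blocking_send` panics when called
    /// from inside an async execution context; from tokio, invoke it via
    /// `tokio::task::spawn_blocking`. A minimal sketch from synchronous code
    /// (path and parameters are illustrative):
    ///
    /// ```rust,no_run
    /// # fn reload_example(client: &rig_llama_cpp::Client) -> Result<(), Box<dyn std::error::Error>> {
    /// client.reload(
    ///     "path/to/other-model.gguf".into(),
    ///     None,
    ///     4096,
    ///     Default::default(),
    ///     Default::default(),
    ///     Default::default(),
    ///     Default::default(),
    /// )?;
    /// # Ok(())
    /// # }
    /// ```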
    // The positional signature is part of the 0.1.x public API. A future minor
    // release can introduce a `ReloadOptions`/`reload_builder` shape; until
    // then, the eight params (self + 7 fields) intentionally stay positional.
    #[allow(clippy::too_many_arguments)]
    pub fn reload(
        &self,
        model_path: String,
        mmproj_path: Option<String>,
        n_ctx: u32,
        sampling: SamplingParams,
        fit_params: FitParams,
        kv_cache_params: KvCacheParams,
        checkpoint_params: CheckpointParams,
    ) -> Result<(), LoadError> {
        let (result_tx, result_rx) = std::sync::mpsc::channel();
        // `blocking_send` is the right call here: `reload` is a sync API and
        // is documented to be invoked from a `spawn_blocking` task (or any
        // non-async thread) when used from a tokio context. Backpressure on a
        // full command queue is fine — reload is itself a blocking operation.
        self.request_tx
            .blocking_send(InferenceCommand::Reload(ReloadRequest {
                model_path,
                mmproj_path,
                n_ctx,
                fit_params,
                kv_cache_params,
                checkpoint_params,
                result_tx,
            }))
            .map_err(|_| LoadError::WorkerNotRunning)?;
        let result = result_rx
            .recv()
            .map_err(|_| LoadError::WorkerInitDisconnected)?;
        if result.is_ok() {
            // SamplingParams is `Copy` (just numeric scalars) — a poisoned
            // lock can't represent torn or invalid data, so recover the
            // guard rather than propagate a panic.
            let mut guard = self
                .sampling_params
                .write()
                .unwrap_or_else(|p| p.into_inner());
            *guard = sampling;
        }
        result
    }
}

impl Drop for Client {
    /// Tear down the worker thread synchronously.
    ///
    /// `Drop` blocks until the worker thread has fully exited and the
    /// `LlamaModel` / `LlamaContext` (and `LlamaBackend` device handles, plus
    /// the multimodal projector when the `mtmd` feature is on) are released.
    /// This is intentional: the caller almost always wants to allocate a
    /// replacement `Client` immediately after dropping this one, and a
    /// non-blocking drop would briefly hold 2× the model's RAM/VRAM and risk
    /// OOM. [`Client::reload`] reuses the same worker and avoids this whole
    /// path; prefer it over drop-and-recreate when you can.
    ///
    /// To keep the wait short even when a long generation is mid-flight,
    /// `Drop` flips the shared cancel flag before signalling shutdown. The
    /// worker polls the flag at every prompt-prefill chunk boundary and
    /// every sampled token, so an in-flight `Request` returns within a
    /// single decode step. The worst-case wait is therefore one decode step,
    /// not the rest of the generation.
    ///
    /// `try_send(Shutdown)` is best-effort: if the bounded command queue is
    /// full at this instant, the `Shutdown` command isn't enqueued — but the
    /// in-flight request still bails on the cancel flag, and the worker's
    /// per-iteration cancel check at the top of its command loop also exits
    /// the thread before pulling more queued commands.
    ///
    /// `Model` clones outliving the `Client` keep the channel sender count
    /// above zero; their `send` calls naturally fail with `SendError` once
    /// the receiver is dropped on worker exit, so they don't prevent
    /// shutdown.
    fn drop(&mut self) {
        self.cancel.store(true, Ordering::Relaxed);
        let _ = self.request_tx.try_send(InferenceCommand::Shutdown);

        if let Some(worker_handle) = self.worker_handle.take() {
            let _ = worker_handle.join();
        }
    }
}

impl CompletionClient for Client {
    type CompletionModel = Model;
}

/// A handle to a loaded model that implements Rig's [`CompletionModel`] trait.
///
/// Obtained via [`CompletionClient::agent`] on a [`Client`].
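///
/// A sketch of driving the model through Rig's agent API (the model name and
/// prompt are illustrative; `Prompt` is Rig's prompting trait):
///
/// ```rust,no_run
/// # async fn run() -> Result<(), Box<dyn std::error::Error>> {
/// use rig::client::CompletionClient;
/// use rig::completion::Prompt;
///
/// let client = rig_llama_cpp::Client::builder("path/to/model.gguf").build()?;
/// let agent = client.agent("local-model").build();
/// let reply = agent.prompt("Say hello.").await?;
/// # let _ = reply;
/// # Ok(())
/// # }
/// ```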
#[derive(Clone)]
pub struct Model {
    request_tx: mpsc::Sender<InferenceCommand>,
    sampling_params: SamplingParams,
    #[allow(dead_code)]
    model_id: String,
}

impl CompletionModel for Model {
    type Response = RawResponse;
    type StreamingResponse = StreamChunk;
    type Client = Client;

    fn make(client: &Client, model: impl Into<String>) -> Self {
        // See the matching `unwrap_or_else` in `reload`: SamplingParams is
        // `Copy`, so a poisoned lock still holds valid data.
        let sampling_params = *client
            .sampling_params
            .read()
            .unwrap_or_else(|p| p.into_inner());
        Self {
            request_tx: client.request_tx.clone(),
            sampling_params,
            model_id: model.into(),
        }
    }

    async fn completion(
        &self,
        request: CompletionRequest,
    ) -> Result<CompletionResponse<Self::Response>, CompletionError> {
        let prepared_request = prepare_request(&request).map_err(CompletionError::ProviderError)?;
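        // Defaults chosen by this crate when the request leaves these unset.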
        let max_tokens = request.max_tokens.unwrap_or(512) as u32;
        let temperature = request.temperature.unwrap_or(0.7) as f32;

        let (response_tx, response_rx) = oneshot::channel();

        self.request_tx
            .send(InferenceCommand::Request(InferenceRequest {
                params: InferenceParams {
                    prepared_request,
                    max_tokens,
                    temperature,
                    top_p: self.sampling_params.top_p,
                    top_k: self.sampling_params.top_k,
                    min_p: self.sampling_params.min_p,
                    presence_penalty: self.sampling_params.presence_penalty,
                    repetition_penalty: self.sampling_params.repetition_penalty,
                },
                response_channel: ResponseChannel::Completion(response_tx),
            }))
            .await
            .map_err(|_| CompletionError::ProviderError("Inference thread shut down".into()))?;

        let result = response_rx
            .await
            .map_err(|_| CompletionError::ProviderError("Response channel closed".into()))?
            .map_err(CompletionError::ProviderError)?;

        Ok(CompletionResponse {
            choice: result.choice,
            usage: Usage {
                input_tokens: result.prompt_tokens,
                output_tokens: result.completion_tokens,
                total_tokens: result.prompt_tokens + result.completion_tokens,
                cached_input_tokens: result.cached_input_tokens,
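                // Local inference has no separate cache-creation step to
                // account for, so this is always zero.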
                cache_creation_input_tokens: 0,
            },
            raw_response: RawResponse { text: result.text },
            message_id: None,
        })
    }

    async fn stream(
        &self,
        request: CompletionRequest,
    ) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
        let prepared_request = prepare_request(&request).map_err(CompletionError::ProviderError)?;
        let max_tokens = request.max_tokens.unwrap_or(512) as u32;
        let temperature = request.temperature.unwrap_or(0.7) as f32;

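        // Chunk volume is capped by `max_tokens`, so this unbounded channel
        // cannot grow without limit for a single request even if the
        // consumer stalls.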
        let (stream_tx, stream_rx) = mpsc::unbounded_channel();

        self.request_tx
            .send(InferenceCommand::Request(InferenceRequest {
                params: InferenceParams {
                    prepared_request,
                    max_tokens,
                    temperature,
                    top_p: self.sampling_params.top_p,
                    top_k: self.sampling_params.top_k,
                    min_p: self.sampling_params.min_p,
                    presence_penalty: self.sampling_params.presence_penalty,
                    repetition_penalty: self.sampling_params.repetition_penalty,
                },
                response_channel: ResponseChannel::Streaming(stream_tx),
            }))
            .await
            .map_err(|_| CompletionError::ProviderError("Inference thread shut down".into()))?;

        Ok(StreamingCompletionResponse::stream(Box::pin(
            UnboundedReceiverStream::new(stream_rx),
        )))
    }
}