rig-llama-cpp 0.1.1

Rig completion provider for local GGUF models via llama.cpp, with streaming, tool calling, reasoning, and multimodal (mtmd) support.
use std::sync::Arc;
use std::sync::atomic::{AtomicBool, Ordering};

use rig::completion::CompletionError;
use rig::streaming::RawStreamingChoice;
use tokio::sync::mpsc;

use crate::checkpoint::{
    PersistentCtx, ensure_persistent_ctx, maybe_create_checkpoint, restore_or_clear,
};
use crate::error::LoadError;
#[cfg(feature = "mtmd")]
use crate::image::run_image_inference;
use crate::loader::{WorkerModel, fit_and_load_model};
use crate::prompt::build_prompt;
use crate::sampling::sample_tokens;
use crate::slot::{SlotEntry, get_common_prefix};
use crate::types::{
    CheckpointParams, FitParams, InferenceCommand, InferenceParams, InferenceResult, KvCacheParams,
    ResponseChannel, StreamChunk, StreamSender,
};

/// Error returned by inference paths when [`Client::drop`] (or a future
/// per-request cancel hook) flips the shared cancel flag mid-decode. Surfaced
/// to the response channel so the caller's await wakes up promptly with a
/// well-known message instead of a hang.
pub(crate) const CANCEL_ERR: &str = "inference cancelled";

enum LoopOutcome {
    Reload(crate::types::ReloadRequest),
    Shutdown,
}

/// Borrowed view of the worker's loaded-model environment, threaded through
/// the inference call chain to keep individual function signatures small.
///
/// The lifetimes:
///
/// - `'m` is the lifetime of the loaded `LlamaModel` / `LlamaBackend` —
///   shared with [`PersistentCtx`] so the persistent KV cache borrows from
///   the same model.
/// - `'a` borrows the per-request reads (`kv_cache`, mtmd context) for the
///   duration of a single inference call.
pub(crate) struct RunCtx<'a, 'm> {
    pub(crate) backend: &'m llama_cpp_2::llama_backend::LlamaBackend,
    pub(crate) model: &'m llama_cpp_2::model::LlamaModel,
    pub(crate) n_ctx: u32,
    pub(crate) kv_cache: &'a KvCacheParams,
    pub(crate) checkpoint_params: CheckpointParams,
    #[cfg(feature = "mtmd")]
    pub(crate) mtmd_ctx: Option<&'a llama_cpp_2::mtmd::MtmdContext>,
    /// Shared shutdown signal. Polled at chunk boundaries during prompt
    /// prefill and per-token in the sampler loop so [`Client::drop`] returns
    /// promptly even when a long generation is in flight.
    pub(crate) cancel: &'a AtomicBool,
}
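
// Illustrative sketch (not part of the crate): cancellation is purely cooperative.
// Whatever holds a clone of the `Arc<AtomicBool>` flips the flag, and the worker
// polls it between bounded units of work (prompt-chunk boundaries and individual
// sampled tokens in the real code below). The loop body here is hypothetical.
#[allow(dead_code)]
fn cancellable_loop(cancel: &AtomicBool, units_of_work: usize) -> Result<(), String> {
    for _ in 0..units_of_work {
        if cancel.load(Ordering::Relaxed) {
            // Surface the well-known message so the caller's await wakes up
            // promptly with something recognizable instead of hanging.
            return Err(CANCEL_ERR.to_string());
        }
        // ... perform one bounded unit of work here ...
    }
    Ok(())
}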

/// Bringup parameters for [`inference_worker`]. Folded into a struct so the
/// worker entry point stays under clippy's `too_many_arguments` threshold.
pub(crate) struct WorkerInit<'a> {
    pub(crate) model_path: &'a str,
    pub(crate) mmproj_path: Option<&'a str>,
    pub(crate) n_ctx: u32,
    pub(crate) fit_params: &'a FitParams,
    pub(crate) kv_cache_params: &'a KvCacheParams,
    pub(crate) checkpoint_params: CheckpointParams,
    /// Owned clone of the [`Client`]'s cancel flag. The worker re-borrows it
    /// for each inference call so [`Client::drop`] can short-circuit the
    /// sampling loop without dropping the channel sender.
    pub(crate) cancel: Arc<AtomicBool>,
}

/// Inner request loop that owns the persistent context.
///
/// Returns when the channel closes, a reload is requested, or shutdown is requested.
/// On Reload the caller drops `wm` and reloads the model — the persistent context
/// (which borrows `&wm.model`) is dropped automatically when this function returns.
fn handle_until_reload<'m>(
    backend: &'m llama_cpp_2::llama_backend::LlamaBackend,
    wm: &'m WorkerModel,
    checkpoint_params: CheckpointParams,
    cancel: &AtomicBool,
    rx: &mut mpsc::Receiver<InferenceCommand>,
) -> LoopOutcome {
    let mut persistent: Option<PersistentCtx<'m>> = None;

    while let Some(command) = rx.blocking_recv() {
        match command {
            InferenceCommand::Request(req) => {
                let crate::types::InferenceRequest {
                    params,
                    response_channel,
                } = req;

                let ctx = RunCtx {
                    backend,
                    model: &wm.model,
                    n_ctx: wm.n_ctx,
                    kv_cache: &wm.kv_cache,
                    checkpoint_params,
                    #[cfg(feature = "mtmd")]
                    mtmd_ctx: wm.mtmd_ctx.as_ref(),
                    cancel,
                };

                match response_channel {
                    ResponseChannel::Completion(tx) => {
                        let result = run_inference(&ctx, &mut persistent, &params, None);
                        let _ = tx.send(result);
                    }
                    ResponseChannel::Streaming(stream_tx) => {
                        let result =
                            run_inference(&ctx, &mut persistent, &params, Some(&stream_tx));
                        match result {
                            Ok(result) => {
                                let _ = stream_tx.send(Ok(RawStreamingChoice::FinalResponse(
                                    StreamChunk {
                                        text: result.text,
                                        prompt_tokens: Some(result.prompt_tokens),
                                        completion_tokens: Some(result.completion_tokens),
                                        cached_input_tokens: Some(result.cached_input_tokens),
                                    },
                                )));
                            }
                            Err(e) => {
                                let _ = stream_tx.send(Err(CompletionError::ProviderError(e)));
                            }
                        }
                    }
                }
            }
            InferenceCommand::Reload(reload) => return LoopOutcome::Reload(reload),
            InferenceCommand::Shutdown => return LoopOutcome::Shutdown,
        }

        // Drop sets `cancel` and makes a best-effort attempt to send a Shutdown.
        // If the channel was full at that moment, the Shutdown never made it in;
        // but the request we just finished short-circuited via `cancel`, so check
        // it here and exit before pulling the next queued command.
        if cancel.load(Ordering::Relaxed) {
            return LoopOutcome::Shutdown;
        }
    }
    LoopOutcome::Shutdown
}

pub(crate) fn inference_worker(
    init: WorkerInit<'_>,
    init_tx: std::sync::mpsc::Sender<Result<(), LoadError>>,
    rx: &mut mpsc::Receiver<InferenceCommand>,
) {
    let backend = match crate::shared_backend() {
        Ok(b) => b,
        Err(e) => {
            let _ = init_tx.send(Err(LoadError::BackendInit(e)));
            return;
        }
    };
    let logs_enabled = crate::llama_logs_enabled();

    let mut wm = match fit_and_load_model(
        backend,
        init.model_path,
        init.mmproj_path,
        init.n_ctx,
        init.fit_params,
        init.kv_cache_params,
        logs_enabled,
    ) {
        Ok(wm) => wm,
        Err(e) => {
            let _ = init_tx.send(Err(e));
            return;
        }
    };

    // Signal successful initialization
    let _ = init_tx.send(Ok(()));

    let mut checkpoint_params = init.checkpoint_params;
    let cancel = init.cancel;

    while let LoopOutcome::Reload(reload) =
        handle_until_reload(backend, &wm, checkpoint_params, &cancel, rx)
    {
        // The persistent context (held inside handle_until_reload) has
        // already been dropped by the time we get here, so it is safe
        // to drop and replace `wm`.
        drop(wm);

        let result = fit_and_load_model(
            backend,
            &reload.model_path,
            reload.mmproj_path.as_deref(),
            reload.n_ctx,
            &reload.fit_params,
            &reload.kv_cache_params,
            logs_enabled,
        );

        match result {
            Ok(new_wm) => {
                wm = new_wm;
                checkpoint_params = reload.checkpoint_params;
                let _ = reload.result_tx.send(Ok(()));
            }
            Err(e) => {
                let _ = reload.result_tx.send(Err(e));
                return;
            }
        }
    }
}

/// Top-level inference dispatch: text-only by default, multimodal when the
/// request carries images and an mtmd context is available.
fn run_inference<'m>(
    ctx: &RunCtx<'_, 'm>,
    persistent: &mut Option<PersistentCtx<'m>>,
    req: &InferenceParams,
    stream_tx: Option<&StreamSender>,
) -> Result<InferenceResult, String> {
    #[cfg(feature = "mtmd")]
    {
        let has_images = !req.prepared_request.images.is_empty();
        if has_images && ctx.mtmd_ctx.is_some() {
            return run_image_inference(ctx, persistent, req, stream_tx);
        }
    }
    run_text_inference(ctx, persistent, req, stream_tx)
}

/// Text-only inference with persistent-context + prefix-cache reuse.
///
/// On each call we tokenize the new prompt, find the longest common prefix with
/// the tokens currently committed in the KV cache, trim everything after that
/// prefix, and decode only the suffix. If prefix-cache reuse fails (which can
/// happen e.g. on memory implementations that don't support arbitrary partial
/// trims), we invalidate the persistent slot and retry once with a fresh
/// context — so the user's request still succeeds at the cost of a full decode.
fn run_text_inference<'m>(
    ctx: &RunCtx<'_, 'm>,
    persistent: &mut Option<PersistentCtx<'m>>,
    req: &InferenceParams,
    stream_tx: Option<&StreamSender>,
) -> Result<InferenceResult, String> {
    use llama_cpp_2::model::AddBos;

    let prompt_build = build_prompt(ctx.model, &req.prepared_request)?;
    let prompt = prompt_build.prompt.as_str();

    let new_tokens = ctx
        .model
        .str_to_token(prompt, AddBos::Always)
        .map_err(|e| format!("Tokenization failed: {e}"))?;
    let prompt_len = new_tokens.len();

    if prompt_len == 0 {
        return Err("Empty prompt after tokenization".to_string());
    }
    if prompt_len > ctx.n_ctx as usize {
        return Err(format!(
            "Prompt {prompt_len} tokens exceeds n_ctx {}",
            ctx.n_ctx
        ));
    }

    // Build the candidate as all-Text entries for the diff. Image entries
    // from a previous mtmd turn (if any) compare unequal to text tokens,
    // which is exactly what we want — divergence at the first image position.
    let new_entries: Vec<SlotEntry> = new_tokens.iter().map(|t| SlotEntry::Text(*t)).collect();
    let cached = {
        let p = ensure_persistent_ctx(ctx.backend, ctx.model, ctx.n_ctx, ctx.kv_cache, persistent)?;
        get_common_prefix(&p.last_entries, &new_entries)
    };

    // Phase 1: prompt decode (with prefix-cache reuse). This phase is safe to
    // retry on failure because no output has been streamed yet. The helper
    // gracefully handles trim-unsupported memories (recurrent/hybrid) by
    // restoring the closest checkpoint or fully clearing the cache.
    let phase1 = {
        let p = ensure_persistent_ctx(ctx.backend, ctx.model, ctx.n_ctx, ctx.kv_cache, persistent)?;
        prepare_prompt_decode(
            p,
            &new_tokens,
            cached,
            prompt_len,
            ctx.checkpoint_params,
            ctx.cancel,
        )
    };

    let (mut batch, effective_cached) = match phase1 {
        Ok(out) => out,
        Err(e) if cached > 0 => {
            // Some other phase-1 failure mode. Drop persistent, rebuild fresh,
            // and retry from scratch. Safe because no output has streamed yet.
            log::warn!(
                "prefix-cache decode failed (cached={cached}, prompt_len={prompt_len}): {e}. \
                 Falling back to fresh-context decode."
            );
            *persistent = None;
            let retry = {
                let p = ensure_persistent_ctx(
                    ctx.backend,
                    ctx.model,
                    ctx.n_ctx,
                    ctx.kv_cache,
                    persistent,
                )?;
                prepare_prompt_decode(
                    p,
                    &new_tokens,
                    0,
                    prompt_len,
                    ctx.checkpoint_params,
                    ctx.cancel,
                )
            };
            match retry {
                Ok(out) => out,
                Err(e) => {
                    *persistent = None;
                    return Err(e);
                }
            }
        }
        Err(e) => {
            *persistent = None;
            return Err(e);
        }
    };

    // Phase 2: commit the prompt to last_entries and sample. From this point on
    // we may have streamed tokens to the consumer, so any failure invalidates
    // the persistent slot but cannot be retried.
    let p = ensure_persistent_ctx(ctx.backend, ctx.model, ctx.n_ctx, ctx.kv_cache, persistent)?;
    p.last_entries = new_entries;
    let prompt_tokens = prompt_len as u64;
    let cached_tokens = effective_cached as u64;

    let result = sample_tokens(
        ctx.model,
        &mut p.ctx,
        &mut batch,
        &prompt_build,
        req,
        stream_tx,
        prompt_tokens,
        cached_tokens,
        &mut p.last_entries,
        ctx.cancel,
    );

    if result.is_err() {
        *persistent = None;
    }
    result
}
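
// Illustrative sketch (not part of the crate): the "longest common prefix" that
// `get_common_prefix` computes for the reuse above amounts to counting leading
// positions where the cached entries and the new entries compare equal. `T`
// stands in for `SlotEntry`; everything up to that count stays in the KV cache
// and only the suffix after it needs to be re-decoded.
#[allow(dead_code)]
fn common_prefix_len<T: PartialEq>(cached_entries: &[T], new_entries: &[T]) -> usize {
    cached_entries
        .iter()
        .zip(new_entries.iter())
        .take_while(|(a, b)| a == b)
        .count()
}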

/// Decode the prompt suffix into the persistent context's KV cache and return
/// a batch ready for sampling, plus the count of tokens that were actually
/// served from the cache (which may be less than the LCP if a rollback wasn't
/// possible). This is "phase 1" — safe to retry on failure because no output
/// has been streamed to the consumer yet.
///
/// For models whose memory rejects partial trims (recurrent/hybrid), we
/// attempt to restore from the closest in-memory state checkpoint before
/// falling back to a full clear.
fn prepare_prompt_decode<'b>(
    p: &mut PersistentCtx<'_>,
    new_tokens: &[llama_cpp_2::token::LlamaToken],
    cached: usize,
    prompt_len: usize,
    checkpoint_params: CheckpointParams,
    cancel: &AtomicBool,
) -> Result<(llama_cpp_2::llama_batch::LlamaBatch<'b>, usize), String> {
    use llama_cpp_2::llama_batch::LlamaBatch;

    log::debug!(
        "prefix-cache: prompt_len={prompt_len} last_entries.len={} cached={cached} trim_unsupported={} checkpoints={}",
        p.last_entries.len(),
        p.trim_unsupported,
        p.checkpoint_count(),
    );

    let mut effective_cached = cached;

    if cached < p.last_entries.len() {
        // Need to roll back the cache to position `cached`.
        if p.trim_unsupported {
            // Already known: trim refused before. Try checkpoint restore.
            effective_cached = restore_or_clear(p, cached);
        } else {
            let removed = p
                .ctx
                .clear_kv_cache_seq(Some(0), Some(cached as u32), None)
                .map_err(|e| format!("KV cache trim failed: {e:?}"))?;
            if removed {
                // Trim worked. Drop checkpoints whose pos_max >= cached because
                // the state they captured is now invalid (positions ahead of
                // the trim boundary).
                p.retain_checkpoints_below(cached);
            } else {
                // First time this model rejects a partial trim. Mark it and
                // try the checkpoint path.
                log::info!(
                    "partial KV-cache trim not supported by this model \
                     (likely recurrent/hybrid). Routing rollbacks through checkpoint restore."
                );
                p.trim_unsupported = true;
                effective_cached = restore_or_clear(p, cached);
            }
        }
    } else {
        // No rollback needed (extension only or full match). Drop checkpoints
        // whose pos_max would land past where we're now operating.
        p.retain_checkpoints_below(cached.max(1));
    }

    let prompt_batch_limit = p.ctx.n_batch().max(1) as usize;
    let mut batch = LlamaBatch::new(prompt_batch_limit, 1);

    if effective_cached < prompt_len {
        // Decode the new suffix.
        let suffix = &new_tokens[effective_cached..];
        for (chunk_index, chunk) in suffix.chunks(prompt_batch_limit).enumerate() {
            // Bail at chunk boundaries so a long prompt prefill (potentially
            // tens of seconds for a 100k-token prompt on a slow backend)
            // doesn't keep `Client::drop` waiting.
            if cancel.load(Ordering::Relaxed) {
                return Err(CANCEL_ERR.to_string());
            }
            batch.clear();
            for (offset, token) in chunk.iter().copied().enumerate() {
                let abs = effective_cached + chunk_index * prompt_batch_limit + offset;
                let is_last_prompt_token = abs + 1 == prompt_len;
                batch
                    .add(token, abs as i32, &[0], is_last_prompt_token)
                    .map_err(|e| format!("Batch add failed: {e}"))?;
            }
            if batch.n_tokens() == 0 {
                return Err(format!(
                    "BUG: empty prompt batch at chunk {chunk_index} (suffix.len={}, prompt_batch_limit={})",
                    suffix.len(),
                    prompt_batch_limit,
                ));
            }
            p.ctx
                .decode(&mut batch)
                .map_err(|e| format!("Prompt decode failed: {e}"))?;

            let n_tokens_decoded =
                effective_cached + chunk_index * prompt_batch_limit + chunk.len();
            maybe_create_checkpoint(p, checkpoint_params, n_tokens_decoded, prompt_len);
        }
    } else {
        // Whole prompt already cached. Roll back the last position by one and
        // re-decode it so the sampler has a fresh `logits=true` slot to read.
        // Only reachable when trim is supported (otherwise effective_cached
        // would have been reset to 0 above).
        let removed = p
            .ctx
            .clear_kv_cache_seq(Some(0), Some((prompt_len - 1) as u32), None)
            .map_err(|e| format!("KV cache trim failed: {e:?}"))?;
        if !removed {
            return Err(format!(
                "KV cache trim (rollback) returned false at pos {}",
                prompt_len - 1
            ));
        }
        batch.clear();
        batch
            .add(
                new_tokens[prompt_len - 1],
                (prompt_len - 1) as i32,
                &[0],
                true,
            )
            .map_err(|e| format!("Batch add failed: {e}"))?;
        p.ctx
            .decode(&mut batch)
            .map_err(|e| format!("Prompt decode failed: {e}"))?;
    }

    Ok((batch, effective_cached))
}
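
// Illustrative sketch (not part of the crate): the rollback decision above, in
// miniature. When the memory supports partial trims, the cache can be cut back to
// exactly `cached` positions. Otherwise the best available option is the closest
// checkpoint at or below `cached`, falling back to a full clear. `checkpoints` is
// a hypothetical list of snapshot positions standing in for the real checkpoint
// set managed by `restore_or_clear`.
#[allow(dead_code)]
fn rollback_target(cached: usize, trim_supported: bool, checkpoints: &[usize]) -> usize {
    if trim_supported {
        // A partial trim keeps positions [0, cached) and drops everything after.
        cached
    } else {
        // Restore the highest checkpoint that does not overshoot `cached`; with
        // no usable checkpoint, clear the whole cache (position 0).
        checkpoints
            .iter()
            .copied()
            .filter(|&pos| pos <= cached)
            .max()
            .unwrap_or(0)
    }
}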