Skip to main content

rig_llama_cpp/
types.rs

1use rig_core::completion::{CompletionError, GetTokenUsage, Usage};
2use rig_core::message::AssistantContent;
3use rig_core::one_or_many::OneOrMany;
4use rig_core::streaming::RawStreamingChoice;
5use serde::{Deserialize, Serialize};
6use tokio::sync::{mpsc, oneshot};
7
8/// Raw completion response returned by the model.
9///
10/// Marked `#[non_exhaustive]` because new fields may be added in future
11/// minor releases.
12#[derive(Clone, Debug, Serialize, Deserialize)]
13#[non_exhaustive]
14pub struct RawResponse {
15    /// The full generated text.
16    pub text: String,
17}
18
19/// A single chunk emitted during streaming inference.
20///
21/// The final chunk in a stream includes token usage counts. Marked
22/// `#[non_exhaustive]` because new fields may be added in future minor
23/// releases.
24#[derive(Clone, Debug, Serialize, Deserialize)]
25#[non_exhaustive]
26pub struct StreamChunk {
27    /// The text fragment for this chunk.
28    pub text: String,
29    /// Number of prompt tokens (only set on the final chunk).
30    #[serde(skip_serializing_if = "Option::is_none")]
31    pub prompt_tokens: Option<u64>,
32    /// Number of completion tokens (only set on the final chunk).
33    #[serde(skip_serializing_if = "Option::is_none")]
34    pub completion_tokens: Option<u64>,
35    /// Number of prompt tokens that were served from the persistent KV-cache prefix
36    /// (only set on the final chunk).
37    #[serde(skip_serializing_if = "Option::is_none")]
38    pub cached_input_tokens: Option<u64>,
39}
40
41impl GetTokenUsage for StreamChunk {
42    fn token_usage(&self) -> Usage {
43        let Some((input, output)) = self.prompt_tokens.zip(self.completion_tokens) else {
44            return Usage::new();
45        };
46        Usage {
47            input_tokens: input,
48            output_tokens: output,
49            total_tokens: input + output,
50            cached_input_tokens: self.cached_input_tokens.unwrap_or(0),
51            cache_creation_input_tokens: 0,
52            tool_use_prompt_tokens: 0,
53            reasoning_tokens: 0,
54        }
55    }
56}
57
58pub(crate) type StreamSender =
59    mpsc::UnboundedSender<Result<RawStreamingChoice<StreamChunk>, CompletionError>>;
60
61pub(crate) enum ResponseChannel {
62    Completion(oneshot::Sender<Result<InferenceResult, String>>),
63    Streaming(StreamSender),
64}
65
66pub(crate) enum InferenceCommand {
67    Request(InferenceRequest),
68    Reload(ReloadRequest),
69    Shutdown,
70}
71
72pub(crate) struct ReloadRequest {
73    pub model_path: String,
74    pub mmproj_path: Option<String>,
75    pub n_ctx: u32,
76    pub fit_params: FitParams,
77    pub kv_cache_params: KvCacheParams,
78    pub checkpoint_params: CheckpointParams,
79    pub result_tx: std::sync::mpsc::Sender<Result<(), crate::error::LoadError>>,
80}
81
82pub(crate) struct InferenceRequest {
83    pub params: InferenceParams,
84    pub response_channel: ResponseChannel,
85}
86
87pub(crate) struct InferenceParams {
88    pub prepared_request: PreparedRequest,
89    pub max_tokens: u32,
90    pub temperature: f32,
91    pub top_p: f32,
92    pub top_k: i32,
93    pub min_p: f32,
94    pub presence_penalty: f32,
95    pub repetition_penalty: f32,
96}
97
98pub(crate) struct InferenceResult {
99    pub text: String,
100    pub choice: OneOrMany<AssistantContent>,
101    pub prompt_tokens: u64,
102    pub completion_tokens: u64,
103    /// Tokens of the prompt that were already present in the persistent KV cache
104    /// (i.e. the longest common prefix shared with the previous request).
105    pub cached_input_tokens: u64,
106}
107
/// A chat request reduced to the serialized pieces handed to inference.
pub(crate) struct PreparedRequest {
    /// Chat messages, serialized as JSON.
    pub messages_json: String,
    /// Tool definitions serialized as JSON, when the request has any.
    pub tools_json: Option<String>,
    /// Tool-choice setting, when the request specifies one.
    pub tool_choice: Option<String>,
    /// JSON schema, when the request specifies one.
    pub json_schema: Option<String>,
    /// Thinking toggle parsed from the request's `additional_params`
    /// (`{ "thinking": bool }`).
    ///
    /// The flag is currently advisory. `llama-cpp-2` 0.1.147 removed the
    /// `chat_template_kwargs` plumbing that the old oaicompat path used to
    /// forward it to the jinja engine. Thinking-enabled is the template
    /// default and keeps working; thinking-disabled can no longer be
    /// enforced through the template and is surfaced only via the model's
    /// defaults.
    // Kept on the struct although nothing reads it while the flag is advisory.
    #[allow(dead_code)]
    pub enable_thinking: bool,
    /// Images extracted from the chat history (`mtmd` feature only).
    #[cfg(feature = "mtmd")]
    pub images: Vec<PreparedImage>,
}
124
/// An image pulled out of the chat history, paired with its precomputed
/// FNV-1a hash.
///
/// The hash is handed to the underlying `MtmdBitmap` through `set_id`, which
/// lets `MtmdInputChunk::id()` return it again for the prefix-cache diff.
#[cfg(feature = "mtmd")]
#[derive(Debug, Clone)]
pub(crate) struct PreparedImage {
    /// Image data as bytes.
    pub bytes: Vec<u8>,
    /// Precomputed FNV-1a hash for this image.
    pub hash: u64,
}
134
/// Result of building a prompt.
pub(crate) struct PromptBuildResult {
    /// The built prompt text.
    pub prompt: String,
}
138
/// Sampling parameters that control token generation.
///
/// Marked `#[non_exhaustive]` so future sampling knobs can be added without
/// a breaking release. Start from [`SamplingParams::default`] and chain
/// `with_*` setters:
///
/// ```
/// let params = rig_llama_cpp::SamplingParams::default()
///     .with_top_k(40)
///     .with_presence_penalty(1.5);
/// ```
///
/// Implements [`PartialEq`] so two configurations can be compared directly.
/// There is no `Eq` because the fields are floating point.
#[derive(Clone, Copy, Debug, PartialEq)]
#[non_exhaustive]
pub struct SamplingParams {
    /// Nucleus sampling threshold (default: `0.95`).
    pub top_p: f32,
    /// Top-k sampling parameter (default: `40`).
    pub top_k: i32,
    /// Minimum probability threshold (default: `0.0`).
    pub min_p: f32,
    /// Penalty for token presence (default: `0.0`).
    pub presence_penalty: f32,
    /// Penalty for token repetition (default: `1.0`).
    pub repetition_penalty: f32,
}

impl Default for SamplingParams {
    /// Returns the defaults documented on each field.
    fn default() -> Self {
        Self {
            top_p: 0.95,
            top_k: 40,
            min_p: 0.0,
            presence_penalty: 0.0,
            repetition_penalty: 1.0,
        }
    }
}

impl SamplingParams {
    /// Set the nucleus sampling threshold.
    #[must_use]
    pub fn with_top_p(mut self, top_p: f32) -> Self {
        self.top_p = top_p;
        self
    }

    /// Set the top-k sampling parameter.
    #[must_use]
    pub fn with_top_k(mut self, top_k: i32) -> Self {
        self.top_k = top_k;
        self
    }

    /// Set the minimum probability threshold.
    #[must_use]
    pub fn with_min_p(mut self, min_p: f32) -> Self {
        self.min_p = min_p;
        self
    }

    /// Set the presence penalty.
    #[must_use]
    pub fn with_presence_penalty(mut self, presence_penalty: f32) -> Self {
        self.presence_penalty = presence_penalty;
        self
    }

    /// Set the repetition penalty.
    #[must_use]
    pub fn with_repetition_penalty(mut self, repetition_penalty: f32) -> Self {
        self.repetition_penalty = repetition_penalty;
        self
    }
}
213
/// Configuration for automatic GPU/CPU layer fitting.
///
/// Passed to [`crate::Client::builder`] (or [`crate::Client::from_gguf`]) so
/// llama.cpp can probe available device memory and pick the optimal number
/// of layers to offload to GPU automatically, instead of requiring a manual
/// `n_gpu_layers` value.
///
/// Marked `#[non_exhaustive]`; build via `Default::default()` and chain the
/// `with_*` setters. Implements [`PartialEq`] / [`Eq`] so two configurations
/// can be compared directly.
#[derive(Clone, Debug, PartialEq, Eq)]
#[non_exhaustive]
pub struct FitParams {
    /// Memory margin per device in bytes. If `None`, defaults to 1 GiB per device.
    pub margins: Option<Vec<usize>>,
    /// Minimum context size to preserve during fitting (default: `4096`).
    pub n_ctx_min: u32,
}

impl Default for FitParams {
    /// Returns the defaults documented on each field: no explicit margins
    /// and a minimum context size of `4096`.
    fn default() -> Self {
        Self {
            margins: None,
            n_ctx_min: 4096,
        }
    }
}

impl FitParams {
    /// Override the per-device memory margin in bytes.
    ///
    /// Passing `None` restores the default margin behaviour described on
    /// [`FitParams::margins`].
    #[must_use]
    pub fn with_margins(mut self, margins: Option<Vec<usize>>) -> Self {
        self.margins = margins;
        self
    }

    /// Override the minimum context size to preserve during fitting.
    #[must_use]
    pub fn with_n_ctx_min(mut self, n_ctx_min: u32) -> Self {
        self.n_ctx_min = n_ctx_min;
        self
    }
}
256
/// Tunable parameters for the in-memory state-checkpoint cache used to
/// preserve KV/recurrent state across chat turns for hybrid models.
///
/// Hybrid architectures (Qwen 3.5, Jamba, etc.) interleave Mamba-style
/// recurrent layers with transformer layers. The recurrent state can't be
/// rolled back to an arbitrary earlier position, so a partial KV trim
/// fails whenever the next prompt diverges deep into the conversation.
/// To work around this, we periodically snapshot the partial seq state
/// (recurrent + SWA, via `LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY`) during
/// prompt prefill and restore the closest snapshot when the next prompt
/// arrives. Mirrors the mechanism used by upstream `llama-server`.
///
/// For non-hybrid models (Qwen 2.5, Llama 3, Gemma, ...) checkpoints are
/// created but never used because the cheaper partial-trim path
/// succeeds.
///
/// Marked `#[non_exhaustive]`; build via `Default::default()` and chain the
/// `with_*` setters. Implements [`PartialEq`] / [`Eq`] so two configurations
/// can be compared directly.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
#[non_exhaustive]
pub struct CheckpointParams {
    /// Maximum number of checkpoints retained per persistent context.
    /// `0` disables checkpointing entirely. Each checkpoint is a few MB
    /// for typical hybrid models.
    pub max_checkpoints: u32,
    /// Approximate spacing between checkpoints during prompt prefill, in
    /// tokens. The last `4..=4 + n_ubatch` tokens always get a
    /// checkpoint regardless. `<= 0` means "only checkpoint near the end
    /// of the prompt".
    pub every_n_tokens: i32,
    /// Don't checkpoint the very start of a prompt — saves space for
    /// no benefit because we'd have to re-decode that prefix anyway if
    /// it's the entire reuse window.
    pub min_tokens: u32,
    /// Don't take two checkpoints closer than this many tokens apart.
    pub min_gap: u32,
}

impl Default for CheckpointParams {
    /// Returns 8 checkpoints, one roughly every 8192 tokens, with a
    /// 64-token minimum prompt length and a 64-token minimum gap.
    fn default() -> Self {
        Self {
            // llama-server uses 32; cap lower because each checkpoint is
            // a few MB and we'd rather not balloon RSS.
            max_checkpoints: 8,
            every_n_tokens: 8192,
            min_tokens: 64,
            min_gap: 64,
        }
    }
}

impl CheckpointParams {
    /// Override the maximum number of checkpoints retained per context.
    #[must_use]
    pub fn with_max_checkpoints(mut self, max_checkpoints: u32) -> Self {
        self.max_checkpoints = max_checkpoints;
        self
    }

    /// Override the approximate spacing between checkpoints (in tokens).
    #[must_use]
    pub fn with_every_n_tokens(mut self, every_n_tokens: i32) -> Self {
        self.every_n_tokens = every_n_tokens;
        self
    }

    /// Override the minimum prompt length before checkpoints are taken.
    #[must_use]
    pub fn with_min_tokens(mut self, min_tokens: u32) -> Self {
        self.min_tokens = min_tokens;
        self
    }

    /// Override the minimum spacing between two consecutive checkpoints.
    #[must_use]
    pub fn with_min_gap(mut self, min_gap: u32) -> Self {
        self.min_gap = min_gap;
        self
    }
}
337
/// Element type for entries of the attention KV cache.
///
/// Covers the subset of `ggml_type` values that `llama.cpp` accepts for KV
/// cache elements. `F16`, the default, keeps full attention quality.
/// Quantized types (for example `Q8_0` at roughly half the size or `Q4_0`
/// at roughly a quarter) give up a little accuracy for a large VRAM saving
/// at long `n_ctx`.
///
/// The enum is a crate-local stand-in for
/// `llama_cpp_2::context::params::KvCacheType`, so that upgrading
/// `llama-cpp-2` does not by itself force a breaking `rig-llama-cpp`
/// release. It is `#[non_exhaustive]`: whenever llama.cpp introduces a new
/// `ggml_type`, the matching variant is added in a minor (`0.1.x`) release.
// Variant names contain underscores (`Q4_0`, `IQ2_XXS`, ...), which the
// camel-case lint would otherwise flag.
#[allow(non_camel_case_types)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[non_exhaustive]
pub enum KvCacheType {
    /// Single-precision IEEE 754 float.
    F32,
    /// Half-precision IEEE 754 float; llama.cpp's default for both K and V.
    F16,
    /// 16-bit brain floating point, common on newer NVIDIA / AMD GPUs.
    BF16,
    /// Double-precision IEEE 754 float.
    F64,
    /// Block quantization, 4 bits, type 0.
    Q4_0,
    /// Block quantization, 4 bits, type 1.
    Q4_1,
    /// Block quantization, 5 bits, type 0.
    Q5_0,
    /// Block quantization, 5 bits, type 1.
    Q5_1,
    /// Block quantization, 8 bits, type 0.
    Q8_0,
    /// Block quantization, 8 bits, type 1.
    Q8_1,
    /// K-quant, 2 bits.
    Q2_K,
    /// K-quant, 3 bits.
    Q3_K,
    /// K-quant, 4 bits.
    Q4_K,
    /// K-quant, 5 bits.
    Q5_K,
    /// K-quant, 6 bits.
    Q6_K,
    /// K-quant, 8 bits.
    Q8_K,
    /// Importance-weighted, 2 bits, extra-extra-small.
    IQ2_XXS,
    /// Importance-weighted, 2 bits, extra-small.
    IQ2_XS,
    /// Importance-weighted, 2 bits, small.
    IQ2_S,
    /// Importance-weighted, 3 bits, extra-extra-small.
    IQ3_XXS,
    /// Importance-weighted, 3 bits, small.
    IQ3_S,
    /// Importance-weighted, 1 bit, small.
    IQ1_S,
    /// Importance-weighted, 1 bit, medium.
    IQ1_M,
    /// Importance-weighted, 4 bits, extra-small.
    IQ4_XS,
    /// Importance-weighted, 4 bits, non-linear.
    IQ4_NL,
    /// 8-bit signed integer.
    I8,
    /// 16-bit signed integer.
    I16,
    /// 32-bit signed integer.
    I32,
    /// 64-bit signed integer.
    I64,
    /// Ternary, 1 bit, type 0.
    TQ1_0,
    /// Ternary, 2 bits, type 0.
    TQ2_0,
    /// Microscaling FP4.
    MXFP4,
}
418
419impl From<KvCacheType> for llama_cpp_2::context::params::KvCacheType {
420    fn from(value: KvCacheType) -> Self {
421        use llama_cpp_2::context::params::KvCacheType as Upstream;
422        match value {
423            KvCacheType::F32 => Upstream::F32,
424            KvCacheType::F16 => Upstream::F16,
425            KvCacheType::BF16 => Upstream::BF16,
426            KvCacheType::F64 => Upstream::F64,
427            KvCacheType::Q4_0 => Upstream::Q4_0,
428            KvCacheType::Q4_1 => Upstream::Q4_1,
429            KvCacheType::Q5_0 => Upstream::Q5_0,
430            KvCacheType::Q5_1 => Upstream::Q5_1,
431            KvCacheType::Q8_0 => Upstream::Q8_0,
432            KvCacheType::Q8_1 => Upstream::Q8_1,
433            KvCacheType::Q2_K => Upstream::Q2_K,
434            KvCacheType::Q3_K => Upstream::Q3_K,
435            KvCacheType::Q4_K => Upstream::Q4_K,
436            KvCacheType::Q5_K => Upstream::Q5_K,
437            KvCacheType::Q6_K => Upstream::Q6_K,
438            KvCacheType::Q8_K => Upstream::Q8_K,
439            KvCacheType::IQ2_XXS => Upstream::IQ2_XXS,
440            KvCacheType::IQ2_XS => Upstream::IQ2_XS,
441            KvCacheType::IQ2_S => Upstream::IQ2_S,
442            KvCacheType::IQ3_XXS => Upstream::IQ3_XXS,
443            KvCacheType::IQ3_S => Upstream::IQ3_S,
444            KvCacheType::IQ1_S => Upstream::IQ1_S,
445            KvCacheType::IQ1_M => Upstream::IQ1_M,
446            KvCacheType::IQ4_XS => Upstream::IQ4_XS,
447            KvCacheType::IQ4_NL => Upstream::IQ4_NL,
448            KvCacheType::I8 => Upstream::I8,
449            KvCacheType::I16 => Upstream::I16,
450            KvCacheType::I32 => Upstream::I32,
451            KvCacheType::I64 => Upstream::I64,
452            KvCacheType::TQ1_0 => Upstream::TQ1_0,
453            KvCacheType::TQ2_0 => Upstream::TQ2_0,
454            KvCacheType::MXFP4 => Upstream::MXFP4,
455        }
456    }
457}
458
459/// KV cache quantization configuration.
460///
461/// Controls the data type used for the attention K and V caches. llama.cpp defaults
462/// both to `F16` (`GGML_TYPE_F16`), which is what `KvCacheParams::default()` preserves.
463/// Quantizing the KV cache (e.g. `Q8_0` → ~½ size, `Q4_0` → ~¼ size) trades a small
464/// amount of accuracy for a large reduction in VRAM usage, which is often the dominant
465/// cost at long `n_ctx`.
466///
467/// Marked `#[non_exhaustive]`; build via `Default::default()` and chain the
468/// `with_*` setters:
469///
470/// ```
471/// use rig_llama_cpp::{KvCacheParams, KvCacheType};
472///
473/// let kv = KvCacheParams::default()
474///     .with_type_k(KvCacheType::Q8_0)
475///     .with_type_v(KvCacheType::Q8_0);
476/// ```
477#[derive(Clone, Copy, Debug)]
478#[non_exhaustive]
479pub struct KvCacheParams {
480    /// Data type for the K cache (default: [`KvCacheType::F16`]).
481    pub type_k: KvCacheType,
482    /// Data type for the V cache (default: [`KvCacheType::F16`]).
483    pub type_v: KvCacheType,
484}
485
486impl Default for KvCacheParams {
487    fn default() -> Self {
488        Self {
489            type_k: KvCacheType::F16,
490            type_v: KvCacheType::F16,
491        }
492    }
493}
494
495impl KvCacheParams {
496    /// Override the K cache data type.
497    #[must_use]
498    pub fn with_type_k(mut self, type_k: KvCacheType) -> Self {
499        self.type_k = type_k;
500        self
501    }
502
503    /// Override the V cache data type.
504    #[must_use]
505    pub fn with_type_v(mut self, type_v: KvCacheType) -> Self {
506        self.type_v = type_v;
507        self
508    }
509}