rig_llama_cpp/types.rs

use std::collections::HashMap;

use rig::completion::{CompletionError, GetTokenUsage, Usage};
use rig::message::AssistantContent;
use rig::one_or_many::OneOrMany;
use rig::streaming::{RawStreamingChoice, RawStreamingToolCall};
use serde::{Deserialize, Serialize};
use tokio::sync::{mpsc, oneshot};

/// Raw completion response returned by the model.
///
/// Marked `#[non_exhaustive]` because new fields may be added in future
/// minor releases.
#[derive(Clone, Debug, Serialize, Deserialize)]
#[non_exhaustive]
pub struct RawResponse {
    /// The full generated text.
    pub text: String,
}

/// A single chunk emitted during streaming inference.
///
/// The final chunk in a stream includes token usage counts. Marked
/// `#[non_exhaustive]` because new fields may be added in future minor
/// releases.
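///
/// `#[non_exhaustive]` prevents struct-literal construction outside the
/// crate, but a chunk can still be built through `Deserialize`. A sketch of
/// recovering [`rig::completion::Usage`] from a final chunk (assumes
/// `serde_json` is available to doc tests):
///
/// ```
/// use rig::completion::GetTokenUsage;
///
/// // Missing `Option` fields deserialize to `None` by default.
/// let chunk: rig_llama_cpp::StreamChunk = serde_json::from_str(
///     r#"{"text":"","prompt_tokens":128,"completion_tokens":42}"#,
/// ).unwrap();
/// let usage = chunk.token_usage().unwrap();
/// assert_eq!(usage.total_tokens, 170);
/// ```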
#[derive(Clone, Debug, Serialize, Deserialize)]
#[non_exhaustive]
pub struct StreamChunk {
    /// The text fragment for this chunk.
    pub text: String,
    /// Number of prompt tokens (only set on the final chunk).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub prompt_tokens: Option<u64>,
    /// Number of completion tokens (only set on the final chunk).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub completion_tokens: Option<u64>,
    /// Number of prompt tokens that were served from the persistent KV-cache prefix
    /// (only set on the final chunk).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub cached_input_tokens: Option<u64>,
}

impl GetTokenUsage for StreamChunk {
    fn token_usage(&self) -> Option<Usage> {
        let (input, output) = self.prompt_tokens.zip(self.completion_tokens)?;
        Some(Usage {
            input_tokens: input,
            output_tokens: output,
            total_tokens: input + output,
            cached_input_tokens: self.cached_input_tokens.unwrap_or(0),
            cache_creation_input_tokens: 0,
        })
    }
}

pub(crate) type StreamSender =
    mpsc::UnboundedSender<Result<RawStreamingChoice<StreamChunk>, CompletionError>>;

pub(crate) enum ResponseChannel {
    Completion(oneshot::Sender<Result<InferenceResult, String>>),
    Streaming(StreamSender),
}

pub(crate) enum InferenceCommand {
    Request(InferenceRequest),
    Reload(ReloadRequest),
    Shutdown,
}

pub(crate) struct ReloadRequest {
    pub model_path: String,
    pub mmproj_path: Option<String>,
    pub n_ctx: u32,
    pub fit_params: FitParams,
    pub kv_cache_params: KvCacheParams,
    pub checkpoint_params: CheckpointParams,
    pub result_tx: std::sync::mpsc::Sender<Result<(), crate::error::LoadError>>,
}

pub(crate) struct InferenceRequest {
    pub params: InferenceParams,
    pub response_channel: ResponseChannel,
}

pub(crate) struct InferenceParams {
    pub prepared_request: PreparedRequest,
    pub max_tokens: u32,
    pub temperature: f32,
    pub top_p: f32,
    pub top_k: i32,
    pub min_p: f32,
    pub presence_penalty: f32,
    pub repetition_penalty: f32,
}

pub(crate) struct InferenceResult {
    pub text: String,
    pub choice: OneOrMany<AssistantContent>,
    pub prompt_tokens: u64,
    pub completion_tokens: u64,
    /// Tokens of the prompt that were already present in the persistent KV cache
    /// (i.e. the longest common prefix shared with the previous request).
    pub cached_input_tokens: u64,
}

pub(crate) struct PreparedRequest {
    pub messages_json: String,
    pub tools_json: Option<String>,
    pub tool_choice: Option<String>,
    pub json_schema: Option<String>,
    pub enable_thinking: bool,
    #[cfg(feature = "mtmd")]
    pub images: Vec<PreparedImage>,
}

/// One image extracted from the chat history with its FNV-1a hash precomputed.
/// The hash is propagated into the underlying `MtmdBitmap` via `set_id` so
/// that `MtmdInputChunk::id()` round-trips it for the prefix-cache diff.
#[cfg(feature = "mtmd")]
#[derive(Clone, Debug)]
pub(crate) struct PreparedImage {
    pub bytes: Vec<u8>,
    pub hash: u64,
}

pub(crate) struct PromptBuildResult {
    pub prompt: String,
    pub template_result: Option<llama_cpp_2::model::ChatTemplateResult>,
}

/// Sampling parameters that control token generation.
///
/// Marked `#[non_exhaustive]` so future sampling knobs can be added without
/// a breaking release. Start from [`SamplingParams::default`] and chain
/// `with_*` setters:
///
/// ```
/// let params = rig_llama_cpp::SamplingParams::default()
///     .with_top_k(40)
///     .with_presence_penalty(1.5);
/// ```
#[derive(Clone, Copy, Debug)]
#[non_exhaustive]
pub struct SamplingParams {
    /// Nucleus sampling threshold (default: `0.95`).
    pub top_p: f32,
    /// Top-k sampling parameter (default: `40`).
    pub top_k: i32,
    /// Minimum probability threshold (default: `0.0`).
    pub min_p: f32,
    /// Penalty for token presence (default: `0.0`).
    pub presence_penalty: f32,
    /// Penalty for token repetition (default: `1.0`).
    pub repetition_penalty: f32,
}

impl Default for SamplingParams {
    fn default() -> Self {
        Self {
            top_p: 0.95,
            top_k: 40,
            min_p: 0.0,
            presence_penalty: 0.0,
            repetition_penalty: 1.0,
        }
    }
}

impl SamplingParams {
    /// Set the nucleus sampling threshold.
    #[must_use]
    pub fn with_top_p(mut self, top_p: f32) -> Self {
        self.top_p = top_p;
        self
    }

    /// Set the top-k sampling parameter.
    #[must_use]
    pub fn with_top_k(mut self, top_k: i32) -> Self {
        self.top_k = top_k;
        self
    }

    /// Set the minimum probability threshold.
    #[must_use]
    pub fn with_min_p(mut self, min_p: f32) -> Self {
        self.min_p = min_p;
        self
    }

    /// Set the presence penalty.
    #[must_use]
    pub fn with_presence_penalty(mut self, presence_penalty: f32) -> Self {
        self.presence_penalty = presence_penalty;
        self
    }

    /// Set the repetition penalty.
    #[must_use]
    pub fn with_repetition_penalty(mut self, repetition_penalty: f32) -> Self {
        self.repetition_penalty = repetition_penalty;
        self
    }
}

/// Configuration for automatic GPU/CPU layer fitting.
///
/// Passed to [`crate::Client::builder`] (or [`crate::Client::from_gguf`]) so
/// llama.cpp can probe available device memory and pick the optimal number
/// of layers to offload to GPU automatically, instead of requiring a manual
/// `n_gpu_layers` value.
///
/// Marked `#[non_exhaustive]`; build via `Default::default()` and chain the
/// `with_*` setters:
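///
/// ```
/// // Illustrative values, not recommendations: reserve a 2 GiB margin on
/// // the single visible device and keep at least 8K tokens of context.
/// let fit = rig_llama_cpp::FitParams::default()
///     .with_margins(Some(vec![2 * 1024 * 1024 * 1024]))
///     .with_n_ctx_min(8192);
/// ```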
#[derive(Clone, Debug)]
#[non_exhaustive]
pub struct FitParams {
    /// Memory margin per device in bytes. If `None`, defaults to 1 GiB per device.
    pub margins: Option<Vec<usize>>,
    /// Minimum context size to preserve during fitting (default: `4096`).
    pub n_ctx_min: u32,
}

impl Default for FitParams {
    fn default() -> Self {
        Self {
            margins: None,
            n_ctx_min: 4096,
        }
    }
}

impl FitParams {
    /// Override the per-device memory margin in bytes.
    #[must_use]
    pub fn with_margins(mut self, margins: Option<Vec<usize>>) -> Self {
        self.margins = margins;
        self
    }

    /// Override the minimum context size to preserve during fitting.
    #[must_use]
    pub fn with_n_ctx_min(mut self, n_ctx_min: u32) -> Self {
        self.n_ctx_min = n_ctx_min;
        self
    }
}

/// Tunable parameters for the in-memory state-checkpoint cache used to
/// preserve KV/recurrent state across chat turns for hybrid models.
///
/// Hybrid architectures (Qwen 3.5, Jamba, etc.) interleave Mamba-style
/// recurrent layers with transformer layers. The recurrent state can't be
/// rolled back to an arbitrary earlier position, so a partial KV trim
/// fails whenever the next prompt diverges deep into the conversation.
/// To work around this, we periodically snapshot the partial seq state
/// (recurrent + SWA, via `LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY`) during
/// prompt prefill and restore the closest snapshot when the next prompt
/// arrives. Mirrors the mechanism used by upstream `llama-server`.
///
/// For non-hybrid models (Qwen 2.5, Llama 3, Gemma, ...) checkpoints are
/// created but never used because the cheaper partial-trim path
/// succeeds.
///
/// Marked `#[non_exhaustive]`; build via `Default::default()` and chain the
/// `with_*` setters:
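///
/// ```
/// // Illustrative values: retain more snapshots, taken more densely.
/// let ckpt = rig_llama_cpp::CheckpointParams::default()
///     .with_max_checkpoints(16)
///     .with_every_n_tokens(4096);
/// ```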
#[derive(Clone, Copy, Debug)]
#[non_exhaustive]
pub struct CheckpointParams {
    /// Maximum number of checkpoints retained per persistent context.
    /// `0` disables checkpointing entirely. Each checkpoint is a few MB
    /// for typical hybrid models.
    pub max_checkpoints: u32,
    /// Approximate spacing between checkpoints during prompt prefill, in
    /// tokens. The last `4..=4 + n_ubatch` tokens always get a
    /// checkpoint regardless. `<= 0` means "only checkpoint near the end
    /// of the prompt".
    pub every_n_tokens: i32,
    /// Don't checkpoint the very start of a prompt: such a checkpoint
    /// costs space for no benefit, since we'd have to re-decode that
    /// prefix anyway if it's the entire reuse window.
    pub min_tokens: u32,
    /// Don't take two checkpoints closer than this many tokens apart.
    pub min_gap: u32,
}

impl Default for CheckpointParams {
    fn default() -> Self {
        Self {
            // llama-server uses 32; cap lower because each checkpoint is
            // a few MB and we'd rather not balloon RSS.
            max_checkpoints: 8,
            every_n_tokens: 8192,
            min_tokens: 64,
            min_gap: 64,
        }
    }
}

impl CheckpointParams {
    /// Override the maximum number of checkpoints retained per context.
    #[must_use]
    pub fn with_max_checkpoints(mut self, max_checkpoints: u32) -> Self {
        self.max_checkpoints = max_checkpoints;
        self
    }

    /// Override the approximate spacing between checkpoints (in tokens).
    #[must_use]
    pub fn with_every_n_tokens(mut self, every_n_tokens: i32) -> Self {
        self.every_n_tokens = every_n_tokens;
        self
    }

    /// Override the minimum prompt length before checkpoints are taken.
    #[must_use]
    pub fn with_min_tokens(mut self, min_tokens: u32) -> Self {
        self.min_tokens = min_tokens;
        self
    }

    /// Override the minimum spacing between two consecutive checkpoints.
    #[must_use]
    pub fn with_min_gap(mut self, min_gap: u32) -> Self {
        self.min_gap = min_gap;
        self
    }
}

/// Data type used for an entry in the attention KV cache.
///
/// Mirrors the subset of `ggml_type` values that `llama.cpp` accepts as KV
/// cache element types. The `F16` default preserves full attention quality;
/// quantizing (e.g. `Q8_0` ≈ ½ size, `Q4_0` ≈ ¼ size) trades a small amount
/// of accuracy for a large VRAM reduction at long `n_ctx`.
///
/// This is a local shim around `llama_cpp_2::context::params::KvCacheType`
/// so a future `llama-cpp-2` update doesn't force a breaking release of
/// `rig-llama-cpp`. Marked `#[non_exhaustive]`: when llama.cpp adds a new
/// `ggml_type`, we add a corresponding variant in a minor (`0.1.x`) release.
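///
/// Converting into the upstream type is a plain `From` (a sketch; the
/// `llama_cpp_2` dependency is assumed visible to doc tests):
///
/// ```
/// use rig_llama_cpp::KvCacheType;
///
/// let _upstream: llama_cpp_2::context::params::KvCacheType = KvCacheType::Q8_0.into();
/// ```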
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
#[allow(non_camel_case_types)]
#[non_exhaustive]
pub enum KvCacheType {
    /// IEEE 754 single precision.
    F32,
    /// IEEE 754 half precision (llama.cpp's default for both K and V).
    F16,
    /// Brain floating-point 16, common on newer NVIDIA / AMD GPUs.
    BF16,
    /// IEEE 754 double precision.
    F64,
    /// 4-bit block quantization, type 0.
    Q4_0,
    /// 4-bit block quantization, type 1.
    Q4_1,
    /// 5-bit block quantization, type 0.
    Q5_0,
    /// 5-bit block quantization, type 1.
    Q5_1,
    /// 8-bit block quantization, type 0.
    Q8_0,
    /// 8-bit block quantization, type 1.
    Q8_1,
    /// 2-bit K-quant.
    Q2_K,
    /// 3-bit K-quant.
    Q3_K,
    /// 4-bit K-quant.
    Q4_K,
    /// 5-bit K-quant.
    Q5_K,
    /// 6-bit K-quant.
    Q6_K,
    /// 8-bit K-quant.
    Q8_K,
    /// Importance-weighted 2-bit, extra-extra-small.
    IQ2_XXS,
    /// Importance-weighted 2-bit, extra-small.
    IQ2_XS,
    /// Importance-weighted 2-bit, small.
    IQ2_S,
    /// Importance-weighted 3-bit, extra-extra-small.
    IQ3_XXS,
    /// Importance-weighted 3-bit, small.
    IQ3_S,
    /// Importance-weighted 1-bit, small.
    IQ1_S,
    /// Importance-weighted 1-bit, medium.
    IQ1_M,
    /// Importance-weighted 4-bit, extra-small.
    IQ4_XS,
    /// Importance-weighted 4-bit, non-linear.
    IQ4_NL,
    /// Signed 8-bit integer.
    I8,
    /// Signed 16-bit integer.
    I16,
    /// Signed 32-bit integer.
    I32,
    /// Signed 64-bit integer.
    I64,
    /// Ternary 1-bit, type 0.
    TQ1_0,
    /// Ternary 2-bit, type 0.
    TQ2_0,
    /// Microscaling FP4.
    MXFP4,
}

impl From<KvCacheType> for llama_cpp_2::context::params::KvCacheType {
    fn from(value: KvCacheType) -> Self {
        use llama_cpp_2::context::params::KvCacheType as Upstream;
        match value {
            KvCacheType::F32 => Upstream::F32,
            KvCacheType::F16 => Upstream::F16,
            KvCacheType::BF16 => Upstream::BF16,
            KvCacheType::F64 => Upstream::F64,
            KvCacheType::Q4_0 => Upstream::Q4_0,
            KvCacheType::Q4_1 => Upstream::Q4_1,
            KvCacheType::Q5_0 => Upstream::Q5_0,
            KvCacheType::Q5_1 => Upstream::Q5_1,
            KvCacheType::Q8_0 => Upstream::Q8_0,
            KvCacheType::Q8_1 => Upstream::Q8_1,
            KvCacheType::Q2_K => Upstream::Q2_K,
            KvCacheType::Q3_K => Upstream::Q3_K,
            KvCacheType::Q4_K => Upstream::Q4_K,
            KvCacheType::Q5_K => Upstream::Q5_K,
            KvCacheType::Q6_K => Upstream::Q6_K,
            KvCacheType::Q8_K => Upstream::Q8_K,
            KvCacheType::IQ2_XXS => Upstream::IQ2_XXS,
            KvCacheType::IQ2_XS => Upstream::IQ2_XS,
            KvCacheType::IQ2_S => Upstream::IQ2_S,
            KvCacheType::IQ3_XXS => Upstream::IQ3_XXS,
            KvCacheType::IQ3_S => Upstream::IQ3_S,
            KvCacheType::IQ1_S => Upstream::IQ1_S,
            KvCacheType::IQ1_M => Upstream::IQ1_M,
            KvCacheType::IQ4_XS => Upstream::IQ4_XS,
            KvCacheType::IQ4_NL => Upstream::IQ4_NL,
            KvCacheType::I8 => Upstream::I8,
            KvCacheType::I16 => Upstream::I16,
            KvCacheType::I32 => Upstream::I32,
            KvCacheType::I64 => Upstream::I64,
            KvCacheType::TQ1_0 => Upstream::TQ1_0,
            KvCacheType::TQ2_0 => Upstream::TQ2_0,
            KvCacheType::MXFP4 => Upstream::MXFP4,
        }
    }
}

/// KV cache quantization configuration.
///
/// Controls the data type used for the attention K and V caches. llama.cpp defaults
/// both to `F16` (`GGML_TYPE_F16`), which is what `KvCacheParams::default()` preserves.
/// Quantizing the KV cache (e.g. `Q8_0` → ~½ size, `Q4_0` → ~¼ size) trades a small
/// amount of accuracy for a large reduction in VRAM usage, which is often the dominant
/// cost at long `n_ctx`.
///
/// Marked `#[non_exhaustive]`; build via `Default::default()` and chain the
/// `with_*` setters:
///
/// ```
/// use rig_llama_cpp::{KvCacheParams, KvCacheType};
///
/// let kv = KvCacheParams::default()
///     .with_type_k(KvCacheType::Q8_0)
///     .with_type_v(KvCacheType::Q8_0);
/// ```
#[derive(Clone, Copy, Debug)]
#[non_exhaustive]
pub struct KvCacheParams {
    /// Data type for the K cache (default: [`KvCacheType::F16`]).
    pub type_k: KvCacheType,
    /// Data type for the V cache (default: [`KvCacheType::F16`]).
    pub type_v: KvCacheType,
}

impl Default for KvCacheParams {
    fn default() -> Self {
        Self {
            type_k: KvCacheType::F16,
            type_v: KvCacheType::F16,
        }
    }
}

impl KvCacheParams {
    /// Override the K cache data type.
    #[must_use]
    pub fn with_type_k(mut self, type_k: KvCacheType) -> Self {
        self.type_k = type_k;
        self
    }

    /// Override the V cache data type.
    #[must_use]
    pub fn with_type_v(mut self, type_v: KvCacheType) -> Self {
        self.type_v = type_v;
        self
    }
}

/// Result of building a sampler chain: the chain itself plus whether grammar is active.
///
/// When grammar is present, `llama_sampler_sample()` already calls `accept()` internally
/// and we must NOT call it again (double-accept corrupts grammar state). When grammar is
/// absent, we call `accept()` explicitly after `sample()` to preserve the legacy
/// double-accept behavior that the base samplers were tuned around.
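///
/// A sketch of the intended call site (not compiled: `ctx`, `idx`, and the
/// surrounding decode loop are assumptions, not part of this module):
///
/// ```ignore
/// let token = chain.sampler.sample(&ctx, idx);
/// if !chain.has_grammar {
///     // No grammar in the chain: re-accept to keep the legacy
///     // double-accept behavior described above.
///     chain.sampler.accept(token);
/// }
/// ```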
pub(crate) struct SamplerChain {
    pub sampler: llama_cpp_2::sampling::LlamaSampler,
    pub has_grammar: bool,
}

pub(crate) struct StreamDeltaState {
    pub tool_calls: HashMap<u64, RawStreamingToolCall>,
}
516}