rig_llama_cpp/types.rs
1use rig_core::completion::{CompletionError, GetTokenUsage, Usage};
2use rig_core::message::AssistantContent;
3use rig_core::one_or_many::OneOrMany;
4use rig_core::streaming::RawStreamingChoice;
5use serde::{Deserialize, Serialize};
6use tokio::sync::{mpsc, oneshot};
7
8/// Raw completion response returned by the model.
9///
10/// Marked `#[non_exhaustive]` because new fields may be added in future
11/// minor releases.
12#[derive(Clone, Debug, Serialize, Deserialize)]
13#[non_exhaustive]
14pub struct RawResponse {
15 /// The full generated text.
16 pub text: String,
17}
18
19/// A single chunk emitted during streaming inference.
20///
21/// The final chunk in a stream includes token usage counts. Marked
22/// `#[non_exhaustive]` because new fields may be added in future minor
23/// releases.
24#[derive(Clone, Debug, Serialize, Deserialize)]
25#[non_exhaustive]
26pub struct StreamChunk {
27 /// The text fragment for this chunk.
28 pub text: String,
29 /// Number of prompt tokens (only set on the final chunk).
30 #[serde(skip_serializing_if = "Option::is_none")]
31 pub prompt_tokens: Option<u64>,
32 /// Number of completion tokens (only set on the final chunk).
33 #[serde(skip_serializing_if = "Option::is_none")]
34 pub completion_tokens: Option<u64>,
35 /// Number of prompt tokens that were served from the persistent KV-cache prefix
36 /// (only set on the final chunk).
37 #[serde(skip_serializing_if = "Option::is_none")]
38 pub cached_input_tokens: Option<u64>,
39}
40
41impl GetTokenUsage for StreamChunk {
42 fn token_usage(&self) -> Option<Usage> {
43 let (input, output) = self.prompt_tokens.zip(self.completion_tokens)?;
44 Some(Usage {
45 input_tokens: input,
46 output_tokens: output,
47 total_tokens: input + output,
48 cached_input_tokens: self.cached_input_tokens.unwrap_or(0),
49 cache_creation_input_tokens: 0,
50 tool_use_prompt_tokens: 0,
51 reasoning_tokens: 0,
52 })
53 }
54}
55
56pub(crate) type StreamSender =
57 mpsc::UnboundedSender<Result<RawStreamingChoice<StreamChunk>, CompletionError>>;
58
59pub(crate) enum ResponseChannel {
60 Completion(oneshot::Sender<Result<InferenceResult, String>>),
61 Streaming(StreamSender),
62}
63
64pub(crate) enum InferenceCommand {
65 Request(InferenceRequest),
66 Reload(ReloadRequest),
67 Shutdown,
68}
69
70pub(crate) struct ReloadRequest {
71 pub model_path: String,
72 pub mmproj_path: Option<String>,
73 pub n_ctx: u32,
74 pub fit_params: FitParams,
75 pub kv_cache_params: KvCacheParams,
76 pub checkpoint_params: CheckpointParams,
77 pub result_tx: std::sync::mpsc::Sender<Result<(), crate::error::LoadError>>,
78}
79
80pub(crate) struct InferenceRequest {
81 pub params: InferenceParams,
82 pub response_channel: ResponseChannel,
83}
84
85pub(crate) struct InferenceParams {
86 pub prepared_request: PreparedRequest,
87 pub max_tokens: u32,
88 pub temperature: f32,
89 pub top_p: f32,
90 pub top_k: i32,
91 pub min_p: f32,
92 pub presence_penalty: f32,
93 pub repetition_penalty: f32,
94}
95
96pub(crate) struct InferenceResult {
97 pub text: String,
98 pub choice: OneOrMany<AssistantContent>,
99 pub prompt_tokens: u64,
100 pub completion_tokens: u64,
101 /// Tokens of the prompt that were already present in the persistent KV cache
102 /// (i.e. the longest common prefix shared with the previous request).
103 pub cached_input_tokens: u64,
104}
105
/// Request data pre-serialized into the string forms handed to inference.
pub(crate) struct PreparedRequest {
    /// Chat messages encoded as JSON.
    pub messages_json: String,
    /// Tool definitions encoded as JSON, when any were supplied.
    pub tools_json: Option<String>,
    /// Tool-choice setting, when one was supplied.
    pub tool_choice: Option<String>,
    /// JSON schema constraining the output, when one was supplied.
    pub json_schema: Option<String>,
    /// Taken from the request's `additional_params` (`{ "thinking": bool }`).
    ///
    /// The flag is advisory for now. `llama-cpp-2` 0.1.147 removed the
    /// `chat_template_kwargs` plumbing through which the old oaicompat path
    /// forwarded it to the jinja engine. Thinking-enabled is the template
    /// default and keeps working; thinking-disabled can no longer be enforced
    /// via the template and is only reflected through the model's defaults.
    /// That is why the field is currently unread and carries
    /// `#[allow(dead_code)]`.
    #[allow(dead_code)]
    pub enable_thinking: bool,
    /// Images extracted for multimodal input (only with the `mtmd` feature).
    #[cfg(feature = "mtmd")]
    pub images: Vec<PreparedImage>,
}
122
/// A single image pulled out of the chat history, paired with its
/// precomputed FNV-1a hash.
///
/// The hash is handed to the underlying `MtmdBitmap` through `set_id`, so
/// `MtmdInputChunk::id()` returns it again and the prefix-cache diff can use
/// it.
#[cfg(feature = "mtmd")]
#[derive(Debug, Clone)]
pub(crate) struct PreparedImage {
    /// Image bytes.
    pub bytes: Vec<u8>,
    /// FNV-1a hash of the image.
    pub hash: u64,
}
132
/// Result of building a prompt.
pub(crate) struct PromptBuildResult {
    /// The prompt text that was built.
    pub prompt: String,
}
136
/// Knobs that steer how tokens are sampled during generation.
///
/// The struct is `#[non_exhaustive]`, which lets new sampling options appear
/// in later releases without a breaking change. Begin with
/// [`SamplingParams::default`] and adjust individual values through the
/// chainable `with_*` methods:
///
/// ```
/// let params = rig_llama_cpp::SamplingParams::default()
///     .with_top_k(40)
///     .with_presence_penalty(1.5);
/// ```
#[derive(Debug, Clone, Copy)]
#[non_exhaustive]
pub struct SamplingParams {
    /// Threshold for nucleus sampling (default: `0.95`).
    pub top_p: f32,
    /// Parameter for top-k sampling (default: `40`).
    pub top_k: i32,
    /// Minimum probability threshold (default: `0.0`).
    pub min_p: f32,
    /// Token presence penalty (default: `0.0`).
    pub presence_penalty: f32,
    /// Token repetition penalty (default: `1.0`).
    pub repetition_penalty: f32,
}

impl Default for SamplingParams {
    /// Returns the defaults listed on each field.
    fn default() -> Self {
        SamplingParams {
            top_p: 0.95,
            top_k: 40,
            min_p: 0.0,
            presence_penalty: 0.0,
            repetition_penalty: 1.0,
        }
    }
}

impl SamplingParams {
    /// Returns a copy with the nucleus sampling threshold replaced.
    #[must_use]
    pub fn with_top_p(self, top_p: f32) -> Self {
        Self { top_p, ..self }
    }

    /// Returns a copy with the top-k sampling parameter replaced.
    #[must_use]
    pub fn with_top_k(self, top_k: i32) -> Self {
        Self { top_k, ..self }
    }

    /// Returns a copy with the minimum probability threshold replaced.
    #[must_use]
    pub fn with_min_p(self, min_p: f32) -> Self {
        Self { min_p, ..self }
    }

    /// Returns a copy with the presence penalty replaced.
    #[must_use]
    pub fn with_presence_penalty(self, presence_penalty: f32) -> Self {
        Self {
            presence_penalty,
            ..self
        }
    }

    /// Returns a copy with the repetition penalty replaced.
    #[must_use]
    pub fn with_repetition_penalty(self, repetition_penalty: f32) -> Self {
        Self {
            repetition_penalty,
            ..self
        }
    }
}
211
/// Settings for automatic GPU/CPU layer fitting.
///
/// Hand this to [`crate::Client::builder`] (or [`crate::Client::from_gguf`])
/// to let llama.cpp probe the available device memory and choose how many
/// layers to offload to the GPU on its own, rather than requiring a manual
/// `n_gpu_layers` value.
///
/// The struct is `#[non_exhaustive]`; create it with `Default::default()` and
/// refine it through the chainable `with_*` methods.
#[derive(Debug, Clone)]
#[non_exhaustive]
pub struct FitParams {
    /// Per-device memory margin in bytes. `None` means a default of 1 GiB per
    /// device.
    pub margins: Option<Vec<usize>>,
    /// Smallest context size that fitting must preserve (default: `4096`).
    pub n_ctx_min: u32,
}

impl Default for FitParams {
    /// Returns the defaults: no explicit margins and `n_ctx_min` of `4096`.
    fn default() -> Self {
        FitParams {
            margins: None,
            n_ctx_min: 4096,
        }
    }
}

impl FitParams {
    /// Replaces the per-device memory margin (in bytes).
    #[must_use]
    pub fn with_margins(self, margins: Option<Vec<usize>>) -> Self {
        Self { margins, ..self }
    }

    /// Replaces the minimum context size that fitting must preserve.
    #[must_use]
    pub fn with_n_ctx_min(self, n_ctx_min: u32) -> Self {
        Self { n_ctx_min, ..self }
    }
}
254
/// Settings for the in-memory state-checkpoint cache that keeps
/// KV/recurrent state alive across chat turns for hybrid models.
///
/// Hybrid architectures (Qwen 3.5, Jamba, etc.) mix Mamba-style recurrent
/// layers with transformer layers. Recurrent state cannot be rewound to an
/// arbitrary earlier position, so a partial KV trim fails whenever the next
/// prompt diverges deep inside the conversation. As a workaround, the
/// partial seq state (recurrent + SWA, via
/// `LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY`) is snapshotted periodically during
/// prompt prefill, and the closest snapshot is restored when the next prompt
/// arrives. This mirrors the mechanism in upstream `llama-server`.
///
/// With non-hybrid models (Qwen 2.5, Llama 3, Gemma, ...) checkpoints are
/// still created but never consulted, because the cheaper partial-trim path
/// succeeds.
///
/// The struct is `#[non_exhaustive]`; create it with `Default::default()` and
/// refine it through the chainable `with_*` methods.
#[derive(Debug, Clone, Copy)]
#[non_exhaustive]
pub struct CheckpointParams {
    /// Upper bound on checkpoints kept per persistent context. `0` turns
    /// checkpointing off completely. A checkpoint is a few MB for typical
    /// hybrid models.
    pub max_checkpoints: u32,
    /// Approximate distance, in tokens, between checkpoints taken during
    /// prompt prefill. The last `4..=4 + n_ubatch` tokens always receive a
    /// checkpoint regardless. A value `<= 0` means "only checkpoint near the
    /// end of the prompt".
    pub every_n_tokens: i32,
    /// Skip checkpointing the very start of a prompt: it costs space for no
    /// benefit, because that prefix would have to be re-decoded anyway if it
    /// were the entire reuse window.
    pub min_tokens: u32,
    /// Never take two checkpoints fewer than this many tokens apart.
    pub min_gap: u32,
}

impl Default for CheckpointParams {
    /// Returns the defaults: 8 checkpoints, spaced about 8192 tokens apart,
    /// with `min_tokens` and `min_gap` both at 64.
    fn default() -> Self {
        CheckpointParams {
            // llama-server keeps 32. The cap here is lower because every
            // checkpoint is a few MB and RSS should not balloon.
            max_checkpoints: 8,
            every_n_tokens: 8192,
            min_tokens: 64,
            min_gap: 64,
        }
    }
}

impl CheckpointParams {
    /// Replaces the maximum number of checkpoints kept per context.
    #[must_use]
    pub fn with_max_checkpoints(self, max_checkpoints: u32) -> Self {
        Self {
            max_checkpoints,
            ..self
        }
    }

    /// Replaces the approximate checkpoint spacing (in tokens).
    #[must_use]
    pub fn with_every_n_tokens(self, every_n_tokens: i32) -> Self {
        Self {
            every_n_tokens,
            ..self
        }
    }

    /// Replaces the minimum prompt length required before checkpoints are
    /// taken.
    #[must_use]
    pub fn with_min_tokens(self, min_tokens: u32) -> Self {
        Self { min_tokens, ..self }
    }

    /// Replaces the minimum distance between two consecutive checkpoints.
    #[must_use]
    pub fn with_min_gap(self, min_gap: u32) -> Self {
        Self { min_gap, ..self }
    }
}
335
/// Element data type for the attention KV cache.
///
/// Covers the subset of `ggml_type` values that `llama.cpp` accepts for KV
/// cache elements. The `F16` default keeps full attention quality;
/// quantized types (e.g. `Q8_0` ≈ ½ size, `Q4_0` ≈ ¼ size) give up a little
/// accuracy in exchange for a large VRAM saving at long `n_ctx`.
///
/// The enum is a crate-local shim over
/// `llama_cpp_2::context::params::KvCacheType`, so that a future
/// `llama-cpp-2` update does not force a breaking release of
/// `rig-llama-cpp`. It is `#[non_exhaustive]`: whenever llama.cpp gains a new
/// `ggml_type`, the matching variant is added here in a minor (`0.1.x`)
/// release.
// Variant names are spelled like the upstream `ggml_type` names (`Q4_0`,
// `IQ2_XXS`, ...), hence the lint exemption below.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[allow(non_camel_case_types)]
#[non_exhaustive]
pub enum KvCacheType {
    /// Single-precision IEEE 754 float.
    F32,
    /// Half-precision IEEE 754 float (the llama.cpp default for both K and V).
    F16,
    /// Brain floating-point 16, common on newer NVIDIA / AMD GPUs.
    BF16,
    /// Double-precision IEEE 754 float.
    F64,
    /// Block quantization, 4-bit, type 0.
    Q4_0,
    /// Block quantization, 4-bit, type 1.
    Q4_1,
    /// Block quantization, 5-bit, type 0.
    Q5_0,
    /// Block quantization, 5-bit, type 1.
    Q5_1,
    /// Block quantization, 8-bit, type 0.
    Q8_0,
    /// Block quantization, 8-bit, type 1.
    Q8_1,
    /// K-quant, 2-bit.
    Q2_K,
    /// K-quant, 3-bit.
    Q3_K,
    /// K-quant, 4-bit.
    Q4_K,
    /// K-quant, 5-bit.
    Q5_K,
    /// K-quant, 6-bit.
    Q6_K,
    /// K-quant, 8-bit.
    Q8_K,
    /// Importance-weighted, 2-bit, extra-extra-small.
    IQ2_XXS,
    /// Importance-weighted, 2-bit, extra-small.
    IQ2_XS,
    /// Importance-weighted, 2-bit, small.
    IQ2_S,
    /// Importance-weighted, 3-bit, extra-extra-small.
    IQ3_XXS,
    /// Importance-weighted, 3-bit, small.
    IQ3_S,
    /// Importance-weighted, 1-bit, small.
    IQ1_S,
    /// Importance-weighted, 1-bit, medium.
    IQ1_M,
    /// Importance-weighted, 4-bit, extra-small.
    IQ4_XS,
    /// Importance-weighted, 4-bit, non-linear.
    IQ4_NL,
    /// 8-bit signed integer.
    I8,
    /// 16-bit signed integer.
    I16,
    /// 32-bit signed integer.
    I32,
    /// 64-bit signed integer.
    I64,
    /// Ternary, 1-bit, type 0.
    TQ1_0,
    /// Ternary, 2-bit, type 0.
    TQ2_0,
    /// Microscaling FP4.
    MXFP4,
}
416
417impl From<KvCacheType> for llama_cpp_2::context::params::KvCacheType {
418 fn from(value: KvCacheType) -> Self {
419 use llama_cpp_2::context::params::KvCacheType as Upstream;
420 match value {
421 KvCacheType::F32 => Upstream::F32,
422 KvCacheType::F16 => Upstream::F16,
423 KvCacheType::BF16 => Upstream::BF16,
424 KvCacheType::F64 => Upstream::F64,
425 KvCacheType::Q4_0 => Upstream::Q4_0,
426 KvCacheType::Q4_1 => Upstream::Q4_1,
427 KvCacheType::Q5_0 => Upstream::Q5_0,
428 KvCacheType::Q5_1 => Upstream::Q5_1,
429 KvCacheType::Q8_0 => Upstream::Q8_0,
430 KvCacheType::Q8_1 => Upstream::Q8_1,
431 KvCacheType::Q2_K => Upstream::Q2_K,
432 KvCacheType::Q3_K => Upstream::Q3_K,
433 KvCacheType::Q4_K => Upstream::Q4_K,
434 KvCacheType::Q5_K => Upstream::Q5_K,
435 KvCacheType::Q6_K => Upstream::Q6_K,
436 KvCacheType::Q8_K => Upstream::Q8_K,
437 KvCacheType::IQ2_XXS => Upstream::IQ2_XXS,
438 KvCacheType::IQ2_XS => Upstream::IQ2_XS,
439 KvCacheType::IQ2_S => Upstream::IQ2_S,
440 KvCacheType::IQ3_XXS => Upstream::IQ3_XXS,
441 KvCacheType::IQ3_S => Upstream::IQ3_S,
442 KvCacheType::IQ1_S => Upstream::IQ1_S,
443 KvCacheType::IQ1_M => Upstream::IQ1_M,
444 KvCacheType::IQ4_XS => Upstream::IQ4_XS,
445 KvCacheType::IQ4_NL => Upstream::IQ4_NL,
446 KvCacheType::I8 => Upstream::I8,
447 KvCacheType::I16 => Upstream::I16,
448 KvCacheType::I32 => Upstream::I32,
449 KvCacheType::I64 => Upstream::I64,
450 KvCacheType::TQ1_0 => Upstream::TQ1_0,
451 KvCacheType::TQ2_0 => Upstream::TQ2_0,
452 KvCacheType::MXFP4 => Upstream::MXFP4,
453 }
454 }
455}
456
457/// KV cache quantization configuration.
458///
459/// Controls the data type used for the attention K and V caches. llama.cpp defaults
460/// both to `F16` (`GGML_TYPE_F16`), which is what `KvCacheParams::default()` preserves.
461/// Quantizing the KV cache (e.g. `Q8_0` → ~½ size, `Q4_0` → ~¼ size) trades a small
462/// amount of accuracy for a large reduction in VRAM usage, which is often the dominant
463/// cost at long `n_ctx`.
464///
465/// Marked `#[non_exhaustive]`; build via `Default::default()` and chain the
466/// `with_*` setters:
467///
468/// ```
469/// use rig_llama_cpp::{KvCacheParams, KvCacheType};
470///
471/// let kv = KvCacheParams::default()
472/// .with_type_k(KvCacheType::Q8_0)
473/// .with_type_v(KvCacheType::Q8_0);
474/// ```
475#[derive(Clone, Copy, Debug)]
476#[non_exhaustive]
477pub struct KvCacheParams {
478 /// Data type for the K cache (default: [`KvCacheType::F16`]).
479 pub type_k: KvCacheType,
480 /// Data type for the V cache (default: [`KvCacheType::F16`]).
481 pub type_v: KvCacheType,
482}
483
484impl Default for KvCacheParams {
485 fn default() -> Self {
486 Self {
487 type_k: KvCacheType::F16,
488 type_v: KvCacheType::F16,
489 }
490 }
491}
492
493impl KvCacheParams {
494 /// Override the K cache data type.
495 #[must_use]
496 pub fn with_type_k(mut self, type_k: KvCacheType) -> Self {
497 self.type_k = type_k;
498 self
499 }
500
501 /// Override the V cache data type.
502 #[must_use]
503 pub fn with_type_v(mut self, type_v: KvCacheType) -> Self {
504 self.type_v = type_v;
505 self
506 }
507}