llama_cpp_4/context/params/mod.rs
1//! A safe wrapper around `llama_context_params`.
2//!
3//! Use [`LlamaContextParams`] to configure context size, batching, KV layout,
4//! `RoPE` / `YaRN` scaling, flash attention, per-sequence samplers, and pairing
5//! with another context (`ctx_other`).
6mod advanced;
7mod types;
8
9pub use types::*;
10
11use std::num::NonZeroU32;
12
13use thiserror::Error;
14
15use crate::sampling::LlamaSampler;
16
17/// Error returned when [`LlamaContextParams::try_clone`] cannot duplicate state.
18#[derive(Debug, Error, PartialEq, Eq)]
19pub enum ParamsCloneError {
20 /// Per-sequence sampler chains cannot be duplicated.
21 #[error("cannot clone params that own per-sequence sampler chains")]
22 SamplerChains,
23}
24
25/// Builder for [`llama_context_params`](llama_cpp_sys_4::llama_context_params).
26///
27/// Construct with [`Default::default()`], chain `with_*` setters, then pass the
28/// value to [`crate::model::LlamaModel::new_context`]. Getter methods mirror
29/// the fields that exist on the underlying C struct.
30///
31/// # Sampler ownership
32///
33/// [`Self::with_sampler_seq_configs`] stores owned [`LlamaSampler`] chains inside
34/// this struct until the context is created. [`Clone`] clears sampler configs
35/// because the underlying chains cannot be duplicated safely.
36///
37/// # Examples
38///
39/// ```rust
40/// # use std::num::NonZeroU32;
41/// use llama_cpp_4::context::params::LlamaContextParams;
42///
43/// let ctx_params = LlamaContextParams::default()
44/// .with_n_ctx(NonZeroU32::new(2048));
45///
46/// assert_eq!(ctx_params.n_ctx(), NonZeroU32::new(2048));
47/// ```
48#[derive(Debug)]
49#[allow(
50 missing_docs,
51 clippy::struct_excessive_bools,
52 clippy::module_name_repetitions
53)]
54pub struct LlamaContextParams {
55 pub(crate) context_params: llama_cpp_sys_4::llama_context_params,
56 /// When `true`, the `TurboQuant` attention rotation (PR #21038) will be
57 /// disabled for any context created from these params.
58 pub(crate) attn_rot_disabled: bool,
59 /// Keeps sampler chains alive while `context_params.samplers` points at them.
60 owned_samplers: Vec<LlamaSampler>,
61 sampler_configs: Vec<llama_cpp_sys_4::llama_sampler_seq_config>,
62}
63
64/// SAFETY: we do not currently allow setting or reading the pointers that cause this to not be automatically send or sync.
65unsafe impl Send for LlamaContextParams {}
66unsafe impl Sync for LlamaContextParams {}
67
68impl LlamaContextParams {
69 /// Set the side of the context
70 ///
71 /// # Examples
72 ///
73 /// ```rust
74 /// # use std::num::NonZeroU32;
75 /// use llama_cpp_4::context::params::LlamaContextParams;
76 /// let params = LlamaContextParams::default();
77 /// let params = params.with_n_ctx(NonZeroU32::new(2048));
78 /// assert_eq!(params.n_ctx(), NonZeroU32::new(2048));
79 /// ```
80 #[must_use]
81 pub fn with_n_ctx(mut self, n_ctx: Option<NonZeroU32>) -> Self {
82 self.context_params.n_ctx = n_ctx.map_or(0, std::num::NonZeroU32::get);
83 self
84 }
85
86 /// Get the size of the context.
87 ///
88 /// [`None`] if the context size is specified by the model and not the context.
89 ///
90 /// # Examples
91 ///
92 /// ```rust
93 /// let params = llama_cpp_4::context::params::LlamaContextParams::default();
94 /// assert_eq!(params.n_ctx(), std::num::NonZeroU32::new(512));
95 #[must_use]
96 pub fn n_ctx(&self) -> Option<NonZeroU32> {
97 NonZeroU32::new(self.context_params.n_ctx)
98 }
99
100 /// Set the maximum number of independent sequence states in the context.
101 ///
102 /// This maps to llama.cpp's `llama_context_params.n_seq_max` and must match
103 /// the highest sequence id used by batched decoding.
104 ///
105 /// # Examples
106 ///
107 /// ```rust
108 /// use llama_cpp_4::context::params::LlamaContextParams;
109 /// let params = LlamaContextParams::default()
110 /// .with_n_seq_max(16);
111 /// assert_eq!(params.n_seq_max(), 16);
112 /// ```
113 #[must_use]
114 pub fn with_n_seq_max(mut self, n_seq_max: u32) -> Self {
115 self.context_params.n_seq_max = n_seq_max.max(1);
116 self
117 }
118
119 /// Get the configured maximum number of independent sequence states.
120 #[must_use]
121 pub fn n_seq_max(&self) -> u32 {
122 self.context_params.n_seq_max
123 }
124
125 /// Set the `n_batch`
126 ///
127 /// # Examples
128 ///
129 /// ```rust
130 /// # use std::num::NonZeroU32;
131 /// use llama_cpp_4::context::params::LlamaContextParams;
132 /// let params = LlamaContextParams::default()
133 /// .with_n_batch(2048);
134 /// assert_eq!(params.n_batch(), 2048);
135 /// ```
136 #[must_use]
137 pub fn with_n_batch(mut self, n_batch: u32) -> Self {
138 self.context_params.n_batch = n_batch;
139 self
140 }
141
142 /// Get the `n_batch`
143 ///
144 /// # Examples
145 ///
146 /// ```rust
147 /// use llama_cpp_4::context::params::LlamaContextParams;
148 /// let params = LlamaContextParams::default();
149 /// assert_eq!(params.n_batch(), 2048);
150 /// ```
151 #[must_use]
152 pub fn n_batch(&self) -> u32 {
153 self.context_params.n_batch
154 }
155
156 /// Set the `n_ubatch`
157 ///
158 /// # Examples
159 ///
160 /// ```rust
161 /// # use std::num::NonZeroU32;
162 /// use llama_cpp_4::context::params::LlamaContextParams;
163 /// let params = LlamaContextParams::default()
164 /// .with_n_ubatch(512);
165 /// assert_eq!(params.n_ubatch(), 512);
166 /// ```
167 #[must_use]
168 pub fn with_n_ubatch(mut self, n_ubatch: u32) -> Self {
169 self.context_params.n_ubatch = n_ubatch;
170 self
171 }
172
173 /// Get the `n_ubatch`
174 ///
175 /// # Examples
176 ///
177 /// ```rust
178 /// use llama_cpp_4::context::params::LlamaContextParams;
179 /// let params = LlamaContextParams::default();
180 /// assert_eq!(params.n_ubatch(), 512);
181 /// ```
182 #[must_use]
183 pub fn n_ubatch(&self) -> u32 {
184 self.context_params.n_ubatch
185 }
186
187 /// Set the context type (e.g. [`LlamaContextType::Mtp`] for the draft context in
188 /// [`crate::mtp::MtpSession`]).
189 #[must_use]
190 pub fn with_ctx_type(mut self, ctx_type: LlamaContextType) -> Self {
191 self.context_params.ctx_type = ctx_type.into();
192 self
193 }
194
195 /// Get the configured context type.
196 #[must_use]
197 pub fn ctx_type(&self) -> LlamaContextType {
198 self.context_params.ctx_type.into()
199 }
200
201 /// Set the number of recurrent-state snapshots per sequence (MTP rollback).
202 ///
203 /// Must be `>=` [`MtpSessionConfig::n_draft_max`](crate::mtp::MtpSessionConfig::n_draft_max)
204 /// on the draft context. See [`crate::mtp`].
205 #[must_use]
206 pub fn with_n_rs_seq(mut self, n_rs_seq: u32) -> Self {
207 self.context_params.n_rs_seq = n_rs_seq;
208 self
209 }
210
211 /// Get the number of recurrent-state snapshots per sequence used for MTP rollback.
212 #[must_use]
213 pub fn n_rs_seq(&self) -> u32 {
214 self.context_params.n_rs_seq
215 }
216
217 /// Set the `offload_kqv` parameter to control offloading KV cache & KQV ops to GPU
218 ///
219 /// # Examples
220 ///
221 /// ```rust
222 /// use llama_cpp_4::context::params::LlamaContextParams;
223 /// let params = LlamaContextParams::default()
224 /// .with_offload_kqv(false);
225 /// assert_eq!(params.offload_kqv(), false);
226 /// ```
227 #[must_use]
228 pub fn with_offload_kqv(mut self, enabled: bool) -> Self {
229 self.context_params.offload_kqv = enabled;
230 self
231 }
232
233 /// Get the `offload_kqv` parameter
234 ///
235 /// # Examples
236 ///
237 /// ```rust
238 /// use llama_cpp_4::context::params::LlamaContextParams;
239 /// let params = LlamaContextParams::default();
240 /// assert_eq!(params.offload_kqv(), true);
241 /// ```
242 #[must_use]
243 pub fn offload_kqv(&self) -> bool {
244 self.context_params.offload_kqv
245 }
246
247 /// Set the type of rope scaling.
248 ///
249 /// # Examples
250 ///
251 /// ```rust
252 /// use llama_cpp_4::context::params::{LlamaContextParams, RopeScalingType};
253 /// let params = LlamaContextParams::default()
254 /// .with_rope_scaling_type(RopeScalingType::Linear);
255 /// assert_eq!(params.rope_scaling_type(), RopeScalingType::Linear);
256 /// ```
257 #[must_use]
258 pub fn with_rope_scaling_type(mut self, rope_scaling_type: RopeScalingType) -> Self {
259 self.context_params.rope_scaling_type = i32::from(rope_scaling_type);
260 self
261 }
262
263 /// Get the type of rope scaling.
264 ///
265 /// # Examples
266 ///
267 /// ```rust
268 /// let params = llama_cpp_4::context::params::LlamaContextParams::default();
269 /// assert_eq!(params.rope_scaling_type(), llama_cpp_4::context::params::RopeScalingType::Unspecified);
270 /// ```
271 #[must_use]
272 pub fn rope_scaling_type(&self) -> RopeScalingType {
273 RopeScalingType::from(self.context_params.rope_scaling_type)
274 }
275
276 /// Set the rope frequency base.
277 ///
278 /// # Examples
279 ///
280 /// ```rust
281 /// use llama_cpp_4::context::params::LlamaContextParams;
282 /// let params = LlamaContextParams::default()
283 /// .with_rope_freq_base(0.5);
284 /// assert_eq!(params.rope_freq_base(), 0.5);
285 /// ```
286 #[must_use]
287 pub fn with_rope_freq_base(mut self, rope_freq_base: f32) -> Self {
288 self.context_params.rope_freq_base = rope_freq_base;
289 self
290 }
291
292 /// Get the rope frequency base.
293 ///
294 /// # Examples
295 ///
296 /// ```rust
297 /// let params = llama_cpp_4::context::params::LlamaContextParams::default();
298 /// assert_eq!(params.rope_freq_base(), 0.0);
299 /// ```
300 #[must_use]
301 pub fn rope_freq_base(&self) -> f32 {
302 self.context_params.rope_freq_base
303 }
304
305 /// Set the rope frequency scale.
306 ///
307 /// # Examples
308 ///
309 /// ```rust
310 /// use llama_cpp_4::context::params::LlamaContextParams;
311 /// let params = LlamaContextParams::default()
312 /// .with_rope_freq_scale(0.5);
313 /// assert_eq!(params.rope_freq_scale(), 0.5);
314 /// ```
315 #[must_use]
316 pub fn with_rope_freq_scale(mut self, rope_freq_scale: f32) -> Self {
317 self.context_params.rope_freq_scale = rope_freq_scale;
318 self
319 }
320
321 /// Get the rope frequency scale.
322 ///
323 /// # Examples
324 ///
325 /// ```rust
326 /// let params = llama_cpp_4::context::params::LlamaContextParams::default();
327 /// assert_eq!(params.rope_freq_scale(), 0.0);
328 /// ```
329 #[must_use]
330 pub fn rope_freq_scale(&self) -> f32 {
331 self.context_params.rope_freq_scale
332 }
333
334 /// Get the number of threads.
335 ///
336 /// # Examples
337 ///
338 /// ```rust
339 /// let params = llama_cpp_4::context::params::LlamaContextParams::default();
340 /// assert_eq!(params.n_threads(), 4);
341 /// ```
342 #[must_use]
343 pub fn n_threads(&self) -> i32 {
344 self.context_params.n_threads
345 }
346
347 /// Get the number of threads allocated for batches.
348 ///
349 /// # Examples
350 ///
351 /// ```rust
352 /// let params = llama_cpp_4::context::params::LlamaContextParams::default();
353 /// assert_eq!(params.n_threads_batch(), 4);
354 /// ```
355 #[must_use]
356 pub fn n_threads_batch(&self) -> i32 {
357 self.context_params.n_threads_batch
358 }
359
360 /// Set the number of threads.
361 ///
362 /// # Examples
363 ///
364 /// ```rust
365 /// use llama_cpp_4::context::params::LlamaContextParams;
366 /// let params = LlamaContextParams::default()
367 /// .with_n_threads(8);
368 /// assert_eq!(params.n_threads(), 8);
369 /// ```
370 #[must_use]
371 pub fn with_n_threads(mut self, n_threads: i32) -> Self {
372 self.context_params.n_threads = n_threads;
373 self
374 }
375
376 /// Set the number of threads allocated for batches.
377 ///
378 /// # Examples
379 ///
380 /// ```rust
381 /// use llama_cpp_4::context::params::LlamaContextParams;
382 /// let params = LlamaContextParams::default()
383 /// .with_n_threads_batch(8);
384 /// assert_eq!(params.n_threads_batch(), 8);
385 /// ```
386 #[must_use]
387 pub fn with_n_threads_batch(mut self, n_threads: i32) -> Self {
388 self.context_params.n_threads_batch = n_threads;
389 self
390 }
391
392 /// Check whether embeddings are enabled
393 ///
394 /// # Examples
395 ///
396 /// ```rust
397 /// let params = llama_cpp_4::context::params::LlamaContextParams::default();
398 /// assert!(!params.embeddings());
399 /// ```
400 #[must_use]
401 pub fn embeddings(&self) -> bool {
402 self.context_params.embeddings
403 }
404
405 /// Enable the use of embeddings
406 ///
407 /// # Examples
408 ///
409 /// ```rust
410 /// use llama_cpp_4::context::params::LlamaContextParams;
411 /// let params = LlamaContextParams::default()
412 /// .with_embeddings(true);
413 /// assert!(params.embeddings());
414 /// ```
415 #[must_use]
416 pub fn with_embeddings(mut self, embedding: bool) -> Self {
417 self.context_params.embeddings = embedding;
418 self
419 }
420
421 /// Set the evaluation callback.
422 ///
423 /// # Examples
424 ///
425 /// ```no_run
426 /// extern "C" fn cb_eval_fn(
427 /// t: *mut llama_cpp_sys_4::ggml_tensor,
428 /// ask: bool,
429 /// user_data: *mut std::ffi::c_void,
430 /// ) -> bool {
431 /// false
432 /// }
433 ///
434 /// use llama_cpp_4::context::params::LlamaContextParams;
435 /// let params = LlamaContextParams::default().with_cb_eval(Some(cb_eval_fn));
436 /// ```
437 #[must_use]
438 pub fn with_cb_eval(
439 mut self,
440 cb_eval: llama_cpp_sys_4::ggml_backend_sched_eval_callback,
441 ) -> Self {
442 self.context_params.cb_eval = cb_eval;
443 self
444 }
445
446 /// Set the evaluation callback user data.
447 ///
448 /// # Examples
449 ///
450 /// ```no_run
451 /// use llama_cpp_4::context::params::LlamaContextParams;
452 /// let params = LlamaContextParams::default();
453 /// let user_data = std::ptr::null_mut();
454 /// let params = params.with_cb_eval_user_data(user_data);
455 /// ```
456 #[must_use]
457 pub fn with_cb_eval_user_data(mut self, cb_eval_user_data: *mut std::ffi::c_void) -> Self {
458 self.context_params.cb_eval_user_data = cb_eval_user_data;
459 self
460 }
461
462 /// Attach a [`TensorCapture`](super::tensor_capture::TensorCapture) to
463 /// intercept intermediate tensor outputs during [`crate::LlamaContext::decode`].
464 ///
465 /// Sets `cb_eval` to copy tensors matching the capture filter (layer outputs,
466 /// named nodes, prefix, or all). After `decode()`, read results from the
467 /// capture — see [`crate::TensorCapture`] and [`crate::context::tensor_capture`].
468 ///
469 /// The capture must outlive the context. Call [`TensorCapture::clear`](crate::TensorCapture::clear) before
470 /// reusing it on another batch.
471 ///
472 /// # Example
473 ///
474 /// ```no_run
475 /// use llama_cpp_4::prelude::*;
476 ///
477 /// fn main() {
478 /// let backend = LlamaBackend::init().unwrap();
479 /// let model = LlamaModel::load_from_file(
480 /// &backend,
481 /// "model.gguf",
482 /// &LlamaModelParams::default(),
483 /// )
484 /// .unwrap();
485 ///
486 /// let mut capture = TensorCapture::for_layers(&[13, 20, 27]);
487 /// let ctx_params = LlamaContextParams::default().with_tensor_capture(&mut capture);
488 /// let _ctx = model.new_context(&backend, ctx_params).unwrap();
489 /// }
490 /// ```
491 #[must_use]
492 pub fn with_tensor_capture(self, capture: &mut super::tensor_capture::TensorCapture) -> Self {
493 self.with_cb_eval(Some(super::tensor_capture::tensor_capture_callback))
494 .with_cb_eval_user_data(
495 std::ptr::from_mut::<super::tensor_capture::TensorCapture>(capture)
496 .cast::<std::ffi::c_void>(),
497 )
498 }
499
500 /// Set the storage type for the **K** (key) KV cache tensors.
501 ///
502 /// The default is `GgmlType::F16`. Quantized types like `GgmlType::Q5_0`
503 /// or `GgmlType::Q4_0` reduce VRAM usage significantly; combining them with
504 /// `TurboQuant` attention rotation (the default) keeps quality high.
505 ///
506 /// # Examples
507 ///
508 /// ```rust
509 /// use llama_cpp_4::context::params::LlamaContextParams;
510 /// use llama_cpp_4::quantize::GgmlType;
511 /// let params = LlamaContextParams::default()
512 /// .with_cache_type_k(GgmlType::Q5_0);
513 /// ```
514 #[must_use]
515 pub fn with_cache_type_k(mut self, ty: crate::quantize::GgmlType) -> Self {
516 self.context_params.type_k = ty as llama_cpp_sys_4::ggml_type;
517 self
518 }
519
520 /// Get the K-cache storage type.
521 #[must_use]
522 pub fn cache_type_k(&self) -> llama_cpp_sys_4::ggml_type {
523 self.context_params.type_k
524 }
525
526 /// Set the storage type for the **V** (value) KV cache tensors.
527 ///
528 /// See [`with_cache_type_k`](Self::with_cache_type_k) for details.
529 ///
530 /// # Examples
531 ///
532 /// ```rust
533 /// use llama_cpp_4::context::params::LlamaContextParams;
534 /// use llama_cpp_4::quantize::GgmlType;
535 /// let params = LlamaContextParams::default()
536 /// .with_cache_type_v(GgmlType::Q5_0);
537 /// ```
538 #[must_use]
539 pub fn with_cache_type_v(mut self, ty: crate::quantize::GgmlType) -> Self {
540 self.context_params.type_v = ty as llama_cpp_sys_4::ggml_type;
541 self
542 }
543
544 /// Get the V-cache storage type.
545 #[must_use]
546 pub fn cache_type_v(&self) -> llama_cpp_sys_4::ggml_type {
547 self.context_params.type_v
548 }
549
550 /// Control the `TurboQuant` attention-rotation feature (llama.cpp PR #21038).
551 ///
552 /// By default, llama.cpp applies a Hadamard rotation to Q/K/V tensors
553 /// before writing them into the KV cache. This significantly improves
554 /// quantized KV-cache quality at near-zero overhead, and is enabled
555 /// automatically for models whose head dimension is a power of two.
556 ///
557 /// Set `disabled = true` to opt out (equivalent to `LLAMA_ATTN_ROT_DISABLE=1`).
558 /// The env-var is applied just before the context is created and restored
559 /// afterwards, so this is safe to call from a single thread.
560 ///
561 /// # Examples
562 ///
563 /// ```rust
564 /// use llama_cpp_4::context::params::LlamaContextParams;
565 /// // Disable rotation for this context only:
566 /// let params = LlamaContextParams::default().with_attn_rot_disabled(true);
567 /// assert!(params.attn_rot_disabled());
568 /// ```
569 #[must_use]
570 pub fn with_attn_rot_disabled(mut self, disabled: bool) -> Self {
571 self.attn_rot_disabled = disabled;
572 self
573 }
574
575 /// Returns `true` if `TurboQuant` attention rotation is disabled for this context.
576 ///
577 /// ```rust
578 /// let params = llama_cpp_4::context::params::LlamaContextParams::default();
579 /// assert!(!params.attn_rot_disabled());
580 /// ```
581 #[must_use]
582 pub fn attn_rot_disabled(&self) -> bool {
583 self.attn_rot_disabled
584 }
585
586 /// Set the type of pooling.
587 ///
588 /// # Examples
589 ///
590 /// ```rust
591 /// use llama_cpp_4::context::params::{LlamaContextParams, LlamaPoolingType};
592 /// let params = LlamaContextParams::default()
593 /// .with_pooling_type(LlamaPoolingType::Last);
594 /// assert_eq!(params.pooling_type(), LlamaPoolingType::Last);
595 /// ```
596 #[must_use]
597 pub fn with_pooling_type(mut self, pooling_type: LlamaPoolingType) -> Self {
598 self.context_params.pooling_type = i32::from(pooling_type);
599 self
600 }
601
602 /// Get the type of pooling.
603 ///
604 /// # Examples
605 ///
606 /// ```rust
607 /// let params = llama_cpp_4::context::params::LlamaContextParams::default();
608 /// assert_eq!(params.pooling_type(), llama_cpp_4::context::params::LlamaPoolingType::Unspecified);
609 /// ```
610 #[must_use]
611 pub fn pooling_type(&self) -> LlamaPoolingType {
612 LlamaPoolingType::from(self.context_params.pooling_type)
613 }
614
615 /// Clone these params, failing when sampler chains are attached.
616 ///
617 /// Prefer this over [`Clone::clone`] when you need to detect dropped sampler
618 /// configuration.
619 ///
620 /// # Errors
621 ///
622 /// Returns [`ParamsCloneError::SamplerChains`] when per-sequence sampler
623 /// chains are attached and cannot be duplicated.
624 pub fn try_clone(&self) -> Result<Self, ParamsCloneError> {
625 if self.sampler_configs.is_empty() {
626 Ok(self.clone())
627 } else {
628 Err(ParamsCloneError::SamplerChains)
629 }
630 }
631}
632
633/// Default parameters for `LlamaContext`. (as defined in llama.cpp by `llama_context_default_params`)
634/// ```
635/// # use std::num::NonZeroU32;
636/// use llama_cpp_4::context::params::{LlamaContextParams, RopeScalingType};
637/// let params = LlamaContextParams::default();
638/// assert_eq!(params.n_ctx(), NonZeroU32::new(512), "n_ctx should be 512");
639/// assert_eq!(params.rope_scaling_type(), RopeScalingType::Unspecified);
640/// ```
641impl Default for LlamaContextParams {
642 fn default() -> Self {
643 let context_params = unsafe { llama_cpp_sys_4::llama_context_default_params() };
644 Self {
645 context_params,
646 attn_rot_disabled: false,
647 owned_samplers: Vec::new(),
648 sampler_configs: Vec::new(),
649 }
650 }
651}
652
653/// Duplicate context params for reuse.
654///
655/// Sampler chains attached via [`LlamaContextParams::with_sampler_seq_configs`]
656/// are **not** cloned — the copy clears `samplers` / `n_samplers` because the
657/// underlying C chains cannot be duplicated safely.
658impl Clone for LlamaContextParams {
659 fn clone(&self) -> Self {
660 let mut context_params = self.context_params;
661 // Sampler chains cannot be duplicated here; cloned params omit them.
662 context_params.samplers = std::ptr::null_mut();
663 context_params.n_samplers = 0;
664 Self {
665 context_params,
666 attn_rot_disabled: self.attn_rot_disabled,
667 owned_samplers: Vec::new(),
668 sampler_configs: Vec::new(),
669 }
670 }
671}