Skip to main content

llama_cpp_4/context/params/
mod.rs

1//! A safe wrapper around `llama_context_params`.
2//!
3//! Use [`LlamaContextParams`] to configure context size, batching, KV layout,
4//! `RoPE` / `YaRN` scaling, flash attention, per-sequence samplers, and pairing
5//! with another context (`ctx_other`).
6mod advanced;
7mod types;
8
9pub use types::*;
10
11use std::num::NonZeroU32;
12
13use thiserror::Error;
14
15use crate::sampling::LlamaSampler;
16
17/// Error returned when [`LlamaContextParams::try_clone`] cannot duplicate state.
18#[derive(Debug, Error, PartialEq, Eq)]
19pub enum ParamsCloneError {
20    /// Per-sequence sampler chains cannot be duplicated.
21    #[error("cannot clone params that own per-sequence sampler chains")]
22    SamplerChains,
23}
24
25/// Builder for [`llama_context_params`](llama_cpp_sys_4::llama_context_params).
26///
27/// Construct with [`Default::default()`], chain `with_*` setters, then pass the
28/// value to [`crate::model::LlamaModel::new_context`]. Getter methods mirror
29/// the fields that exist on the underlying C struct.
30///
31/// # Sampler ownership
32///
33/// [`Self::with_sampler_seq_configs`] stores owned [`LlamaSampler`] chains inside
34/// this struct until the context is created. [`Clone`] clears sampler configs
35/// because the underlying chains cannot be duplicated safely.
36///
37/// # Examples
38///
39/// ```rust
40/// # use std::num::NonZeroU32;
41/// use llama_cpp_4::context::params::LlamaContextParams;
42///
43/// let ctx_params = LlamaContextParams::default()
44///     .with_n_ctx(NonZeroU32::new(2048));
45///
46/// assert_eq!(ctx_params.n_ctx(), NonZeroU32::new(2048));
47/// ```
48#[derive(Debug)]
49#[allow(
50    missing_docs,
51    clippy::struct_excessive_bools,
52    clippy::module_name_repetitions
53)]
54pub struct LlamaContextParams {
55    pub(crate) context_params: llama_cpp_sys_4::llama_context_params,
56    /// When `true`, the `TurboQuant` attention rotation (PR #21038) will be
57    /// disabled for any context created from these params.
58    pub(crate) attn_rot_disabled: bool,
59    /// Keeps sampler chains alive while `context_params.samplers` points at them.
60    owned_samplers: Vec<LlamaSampler>,
61    sampler_configs: Vec<llama_cpp_sys_4::llama_sampler_seq_config>,
62}
63
64/// SAFETY: we do not currently allow setting or reading the pointers that cause this to not be automatically send or sync.
65unsafe impl Send for LlamaContextParams {}
66unsafe impl Sync for LlamaContextParams {}
67
68impl LlamaContextParams {
69    /// Set the side of the context
70    ///
71    /// # Examples
72    ///
73    /// ```rust
74    /// # use std::num::NonZeroU32;
75    /// use llama_cpp_4::context::params::LlamaContextParams;
76    /// let params = LlamaContextParams::default();
77    /// let params = params.with_n_ctx(NonZeroU32::new(2048));
78    /// assert_eq!(params.n_ctx(), NonZeroU32::new(2048));
79    /// ```
80    #[must_use]
81    pub fn with_n_ctx(mut self, n_ctx: Option<NonZeroU32>) -> Self {
82        self.context_params.n_ctx = n_ctx.map_or(0, std::num::NonZeroU32::get);
83        self
84    }
85
86    /// Get the size of the context.
87    ///
88    /// [`None`] if the context size is specified by the model and not the context.
89    ///
90    /// # Examples
91    ///
92    /// ```rust
93    /// let params = llama_cpp_4::context::params::LlamaContextParams::default();
94    /// assert_eq!(params.n_ctx(), std::num::NonZeroU32::new(512));
95    #[must_use]
96    pub fn n_ctx(&self) -> Option<NonZeroU32> {
97        NonZeroU32::new(self.context_params.n_ctx)
98    }
99
100    /// Set the maximum number of independent sequence states in the context.
101    ///
102    /// This maps to llama.cpp's `llama_context_params.n_seq_max` and must match
103    /// the highest sequence id used by batched decoding.
104    ///
105    /// # Examples
106    ///
107    /// ```rust
108    /// use llama_cpp_4::context::params::LlamaContextParams;
109    /// let params = LlamaContextParams::default()
110    ///     .with_n_seq_max(16);
111    /// assert_eq!(params.n_seq_max(), 16);
112    /// ```
113    #[must_use]
114    pub fn with_n_seq_max(mut self, n_seq_max: u32) -> Self {
115        self.context_params.n_seq_max = n_seq_max.max(1);
116        self
117    }
118
119    /// Get the configured maximum number of independent sequence states.
120    #[must_use]
121    pub fn n_seq_max(&self) -> u32 {
122        self.context_params.n_seq_max
123    }
124
125    /// Set the `n_batch`
126    ///
127    /// # Examples
128    ///
129    /// ```rust
130    /// # use std::num::NonZeroU32;
131    /// use llama_cpp_4::context::params::LlamaContextParams;
132    /// let params = LlamaContextParams::default()
133    ///     .with_n_batch(2048);
134    /// assert_eq!(params.n_batch(), 2048);
135    /// ```
136    #[must_use]
137    pub fn with_n_batch(mut self, n_batch: u32) -> Self {
138        self.context_params.n_batch = n_batch;
139        self
140    }
141
142    /// Get the `n_batch`
143    ///
144    /// # Examples
145    ///
146    /// ```rust
147    /// use llama_cpp_4::context::params::LlamaContextParams;
148    /// let params = LlamaContextParams::default();
149    /// assert_eq!(params.n_batch(), 2048);
150    /// ```
151    #[must_use]
152    pub fn n_batch(&self) -> u32 {
153        self.context_params.n_batch
154    }
155
156    /// Set the `n_ubatch`
157    ///
158    /// # Examples
159    ///
160    /// ```rust
161    /// # use std::num::NonZeroU32;
162    /// use llama_cpp_4::context::params::LlamaContextParams;
163    /// let params = LlamaContextParams::default()
164    ///     .with_n_ubatch(512);
165    /// assert_eq!(params.n_ubatch(), 512);
166    /// ```
167    #[must_use]
168    pub fn with_n_ubatch(mut self, n_ubatch: u32) -> Self {
169        self.context_params.n_ubatch = n_ubatch;
170        self
171    }
172
173    /// Get the `n_ubatch`
174    ///
175    /// # Examples
176    ///
177    /// ```rust
178    /// use llama_cpp_4::context::params::LlamaContextParams;
179    /// let params = LlamaContextParams::default();
180    /// assert_eq!(params.n_ubatch(), 512);
181    /// ```
182    #[must_use]
183    pub fn n_ubatch(&self) -> u32 {
184        self.context_params.n_ubatch
185    }
186
187    /// Set the context type (e.g. [`LlamaContextType::Mtp`] for the draft context in
188    /// [`crate::mtp::MtpSession`]).
189    #[must_use]
190    pub fn with_ctx_type(mut self, ctx_type: LlamaContextType) -> Self {
191        self.context_params.ctx_type = ctx_type.into();
192        self
193    }
194
195    /// Get the configured context type.
196    #[must_use]
197    pub fn ctx_type(&self) -> LlamaContextType {
198        self.context_params.ctx_type.into()
199    }
200
201    /// Set the number of recurrent-state snapshots per sequence (MTP rollback).
202    ///
203    /// Must be `>=` [`MtpSessionConfig::n_draft_max`](crate::mtp::MtpSessionConfig::n_draft_max)
204    /// on the draft context. See [`crate::mtp`].
205    #[must_use]
206    pub fn with_n_rs_seq(mut self, n_rs_seq: u32) -> Self {
207        self.context_params.n_rs_seq = n_rs_seq;
208        self
209    }
210
211    /// Get the number of recurrent-state snapshots per sequence used for MTP rollback.
212    #[must_use]
213    pub fn n_rs_seq(&self) -> u32 {
214        self.context_params.n_rs_seq
215    }
216
217    /// Set the `offload_kqv` parameter to control offloading KV cache & KQV ops to GPU
218    ///
219    /// # Examples
220    ///
221    /// ```rust
222    /// use llama_cpp_4::context::params::LlamaContextParams;
223    /// let params = LlamaContextParams::default()
224    ///     .with_offload_kqv(false);
225    /// assert_eq!(params.offload_kqv(), false);
226    /// ```
227    #[must_use]
228    pub fn with_offload_kqv(mut self, enabled: bool) -> Self {
229        self.context_params.offload_kqv = enabled;
230        self
231    }
232
233    /// Get the `offload_kqv` parameter
234    ///
235    /// # Examples
236    ///
237    /// ```rust
238    /// use llama_cpp_4::context::params::LlamaContextParams;
239    /// let params = LlamaContextParams::default();
240    /// assert_eq!(params.offload_kqv(), true);
241    /// ```
242    #[must_use]
243    pub fn offload_kqv(&self) -> bool {
244        self.context_params.offload_kqv
245    }
246
247    /// Set the type of rope scaling.
248    ///
249    /// # Examples
250    ///
251    /// ```rust
252    /// use llama_cpp_4::context::params::{LlamaContextParams, RopeScalingType};
253    /// let params = LlamaContextParams::default()
254    ///     .with_rope_scaling_type(RopeScalingType::Linear);
255    /// assert_eq!(params.rope_scaling_type(), RopeScalingType::Linear);
256    /// ```
257    #[must_use]
258    pub fn with_rope_scaling_type(mut self, rope_scaling_type: RopeScalingType) -> Self {
259        self.context_params.rope_scaling_type = i32::from(rope_scaling_type);
260        self
261    }
262
263    /// Get the type of rope scaling.
264    ///
265    /// # Examples
266    ///
267    /// ```rust
268    /// let params = llama_cpp_4::context::params::LlamaContextParams::default();
269    /// assert_eq!(params.rope_scaling_type(), llama_cpp_4::context::params::RopeScalingType::Unspecified);
270    /// ```
271    #[must_use]
272    pub fn rope_scaling_type(&self) -> RopeScalingType {
273        RopeScalingType::from(self.context_params.rope_scaling_type)
274    }
275
276    /// Set the rope frequency base.
277    ///
278    /// # Examples
279    ///
280    /// ```rust
281    /// use llama_cpp_4::context::params::LlamaContextParams;
282    /// let params = LlamaContextParams::default()
283    ///    .with_rope_freq_base(0.5);
284    /// assert_eq!(params.rope_freq_base(), 0.5);
285    /// ```
286    #[must_use]
287    pub fn with_rope_freq_base(mut self, rope_freq_base: f32) -> Self {
288        self.context_params.rope_freq_base = rope_freq_base;
289        self
290    }
291
292    /// Get the rope frequency base.
293    ///
294    /// # Examples
295    ///
296    /// ```rust
297    /// let params = llama_cpp_4::context::params::LlamaContextParams::default();
298    /// assert_eq!(params.rope_freq_base(), 0.0);
299    /// ```
300    #[must_use]
301    pub fn rope_freq_base(&self) -> f32 {
302        self.context_params.rope_freq_base
303    }
304
305    /// Set the rope frequency scale.
306    ///
307    /// # Examples
308    ///
309    /// ```rust
310    /// use llama_cpp_4::context::params::LlamaContextParams;
311    /// let params = LlamaContextParams::default()
312    ///   .with_rope_freq_scale(0.5);
313    /// assert_eq!(params.rope_freq_scale(), 0.5);
314    /// ```
315    #[must_use]
316    pub fn with_rope_freq_scale(mut self, rope_freq_scale: f32) -> Self {
317        self.context_params.rope_freq_scale = rope_freq_scale;
318        self
319    }
320
321    /// Get the rope frequency scale.
322    ///
323    /// # Examples
324    ///
325    /// ```rust
326    /// let params = llama_cpp_4::context::params::LlamaContextParams::default();
327    /// assert_eq!(params.rope_freq_scale(), 0.0);
328    /// ```
329    #[must_use]
330    pub fn rope_freq_scale(&self) -> f32 {
331        self.context_params.rope_freq_scale
332    }
333
334    /// Get the number of threads.
335    ///
336    /// # Examples
337    ///
338    /// ```rust
339    /// let params = llama_cpp_4::context::params::LlamaContextParams::default();
340    /// assert_eq!(params.n_threads(), 4);
341    /// ```
342    #[must_use]
343    pub fn n_threads(&self) -> i32 {
344        self.context_params.n_threads
345    }
346
347    /// Get the number of threads allocated for batches.
348    ///
349    /// # Examples
350    ///
351    /// ```rust
352    /// let params = llama_cpp_4::context::params::LlamaContextParams::default();
353    /// assert_eq!(params.n_threads_batch(), 4);
354    /// ```
355    #[must_use]
356    pub fn n_threads_batch(&self) -> i32 {
357        self.context_params.n_threads_batch
358    }
359
360    /// Set the number of threads.
361    ///
362    /// # Examples
363    ///
364    /// ```rust
365    /// use llama_cpp_4::context::params::LlamaContextParams;
366    /// let params = LlamaContextParams::default()
367    ///    .with_n_threads(8);
368    /// assert_eq!(params.n_threads(), 8);
369    /// ```
370    #[must_use]
371    pub fn with_n_threads(mut self, n_threads: i32) -> Self {
372        self.context_params.n_threads = n_threads;
373        self
374    }
375
376    /// Set the number of threads allocated for batches.
377    ///
378    /// # Examples
379    ///
380    /// ```rust
381    /// use llama_cpp_4::context::params::LlamaContextParams;
382    /// let params = LlamaContextParams::default()
383    ///    .with_n_threads_batch(8);
384    /// assert_eq!(params.n_threads_batch(), 8);
385    /// ```
386    #[must_use]
387    pub fn with_n_threads_batch(mut self, n_threads: i32) -> Self {
388        self.context_params.n_threads_batch = n_threads;
389        self
390    }
391
392    /// Check whether embeddings are enabled
393    ///
394    /// # Examples
395    ///
396    /// ```rust
397    /// let params = llama_cpp_4::context::params::LlamaContextParams::default();
398    /// assert!(!params.embeddings());
399    /// ```
400    #[must_use]
401    pub fn embeddings(&self) -> bool {
402        self.context_params.embeddings
403    }
404
405    /// Enable the use of embeddings
406    ///
407    /// # Examples
408    ///
409    /// ```rust
410    /// use llama_cpp_4::context::params::LlamaContextParams;
411    /// let params = LlamaContextParams::default()
412    ///    .with_embeddings(true);
413    /// assert!(params.embeddings());
414    /// ```
415    #[must_use]
416    pub fn with_embeddings(mut self, embedding: bool) -> Self {
417        self.context_params.embeddings = embedding;
418        self
419    }
420
421    /// Set the evaluation callback.
422    ///
423    /// # Examples
424    ///
425    /// ```no_run
426    /// extern "C" fn cb_eval_fn(
427    ///     t: *mut llama_cpp_sys_4::ggml_tensor,
428    ///     ask: bool,
429    ///     user_data: *mut std::ffi::c_void,
430    /// ) -> bool {
431    ///     false
432    /// }
433    ///
434    /// use llama_cpp_4::context::params::LlamaContextParams;
435    /// let params = LlamaContextParams::default().with_cb_eval(Some(cb_eval_fn));
436    /// ```
437    #[must_use]
438    pub fn with_cb_eval(
439        mut self,
440        cb_eval: llama_cpp_sys_4::ggml_backend_sched_eval_callback,
441    ) -> Self {
442        self.context_params.cb_eval = cb_eval;
443        self
444    }
445
446    /// Set the evaluation callback user data.
447    ///
448    /// # Examples
449    ///
450    /// ```no_run
451    /// use llama_cpp_4::context::params::LlamaContextParams;
452    /// let params = LlamaContextParams::default();
453    /// let user_data = std::ptr::null_mut();
454    /// let params = params.with_cb_eval_user_data(user_data);
455    /// ```
456    #[must_use]
457    pub fn with_cb_eval_user_data(mut self, cb_eval_user_data: *mut std::ffi::c_void) -> Self {
458        self.context_params.cb_eval_user_data = cb_eval_user_data;
459        self
460    }
461
462    /// Attach a [`TensorCapture`](super::tensor_capture::TensorCapture) to
463    /// intercept intermediate tensor outputs during [`crate::LlamaContext::decode`].
464    ///
465    /// Sets `cb_eval` to copy tensors matching the capture filter (layer outputs,
466    /// named nodes, prefix, or all). After `decode()`, read results from the
467    /// capture — see [`crate::TensorCapture`] and [`crate::context::tensor_capture`].
468    ///
469    /// The capture must outlive the context. Call [`TensorCapture::clear`](crate::TensorCapture::clear) before
470    /// reusing it on another batch.
471    ///
472    /// # Example
473    ///
474    /// ```no_run
475    /// use llama_cpp_4::prelude::*;
476    ///
477    /// fn main() {
478    ///     let backend = LlamaBackend::init().unwrap();
479    ///     let model = LlamaModel::load_from_file(
480    ///         &backend,
481    ///         "model.gguf",
482    ///         &LlamaModelParams::default(),
483    ///     )
484    ///     .unwrap();
485    ///
486    ///     let mut capture = TensorCapture::for_layers(&[13, 20, 27]);
487    ///     let ctx_params = LlamaContextParams::default().with_tensor_capture(&mut capture);
488    ///     let _ctx = model.new_context(&backend, ctx_params).unwrap();
489    /// }
490    /// ```
491    #[must_use]
492    pub fn with_tensor_capture(self, capture: &mut super::tensor_capture::TensorCapture) -> Self {
493        self.with_cb_eval(Some(super::tensor_capture::tensor_capture_callback))
494            .with_cb_eval_user_data(
495                std::ptr::from_mut::<super::tensor_capture::TensorCapture>(capture)
496                    .cast::<std::ffi::c_void>(),
497            )
498    }
499
500    /// Set the storage type for the **K** (key) KV cache tensors.
501    ///
502    /// The default is `GgmlType::F16`.  Quantized types like `GgmlType::Q5_0`
503    /// or `GgmlType::Q4_0` reduce VRAM usage significantly; combining them with
504    /// `TurboQuant` attention rotation (the default) keeps quality high.
505    ///
506    /// # Examples
507    ///
508    /// ```rust
509    /// use llama_cpp_4::context::params::LlamaContextParams;
510    /// use llama_cpp_4::quantize::GgmlType;
511    /// let params = LlamaContextParams::default()
512    ///     .with_cache_type_k(GgmlType::Q5_0);
513    /// ```
514    #[must_use]
515    pub fn with_cache_type_k(mut self, ty: crate::quantize::GgmlType) -> Self {
516        self.context_params.type_k = ty as llama_cpp_sys_4::ggml_type;
517        self
518    }
519
520    /// Get the K-cache storage type.
521    #[must_use]
522    pub fn cache_type_k(&self) -> llama_cpp_sys_4::ggml_type {
523        self.context_params.type_k
524    }
525
526    /// Set the storage type for the **V** (value) KV cache tensors.
527    ///
528    /// See [`with_cache_type_k`](Self::with_cache_type_k) for details.
529    ///
530    /// # Examples
531    ///
532    /// ```rust
533    /// use llama_cpp_4::context::params::LlamaContextParams;
534    /// use llama_cpp_4::quantize::GgmlType;
535    /// let params = LlamaContextParams::default()
536    ///     .with_cache_type_v(GgmlType::Q5_0);
537    /// ```
538    #[must_use]
539    pub fn with_cache_type_v(mut self, ty: crate::quantize::GgmlType) -> Self {
540        self.context_params.type_v = ty as llama_cpp_sys_4::ggml_type;
541        self
542    }
543
544    /// Get the V-cache storage type.
545    #[must_use]
546    pub fn cache_type_v(&self) -> llama_cpp_sys_4::ggml_type {
547        self.context_params.type_v
548    }
549
550    /// Control the `TurboQuant` attention-rotation feature (llama.cpp PR #21038).
551    ///
552    /// By default, llama.cpp applies a Hadamard rotation to Q/K/V tensors
553    /// before writing them into the KV cache.  This significantly improves
554    /// quantized KV-cache quality at near-zero overhead, and is enabled
555    /// automatically for models whose head dimension is a power of two.
556    ///
557    /// Set `disabled = true` to opt out (equivalent to `LLAMA_ATTN_ROT_DISABLE=1`).
558    /// The env-var is applied just before the context is created and restored
559    /// afterwards, so this is safe to call from a single thread.
560    ///
561    /// # Examples
562    ///
563    /// ```rust
564    /// use llama_cpp_4::context::params::LlamaContextParams;
565    /// // Disable rotation for this context only:
566    /// let params = LlamaContextParams::default().with_attn_rot_disabled(true);
567    /// assert!(params.attn_rot_disabled());
568    /// ```
569    #[must_use]
570    pub fn with_attn_rot_disabled(mut self, disabled: bool) -> Self {
571        self.attn_rot_disabled = disabled;
572        self
573    }
574
575    /// Returns `true` if `TurboQuant` attention rotation is disabled for this context.
576    ///
577    /// ```rust
578    /// let params = llama_cpp_4::context::params::LlamaContextParams::default();
579    /// assert!(!params.attn_rot_disabled());
580    /// ```
581    #[must_use]
582    pub fn attn_rot_disabled(&self) -> bool {
583        self.attn_rot_disabled
584    }
585
586    /// Set the type of pooling.
587    ///
588    /// # Examples
589    ///
590    /// ```rust
591    /// use llama_cpp_4::context::params::{LlamaContextParams, LlamaPoolingType};
592    /// let params = LlamaContextParams::default()
593    ///     .with_pooling_type(LlamaPoolingType::Last);
594    /// assert_eq!(params.pooling_type(), LlamaPoolingType::Last);
595    /// ```
596    #[must_use]
597    pub fn with_pooling_type(mut self, pooling_type: LlamaPoolingType) -> Self {
598        self.context_params.pooling_type = i32::from(pooling_type);
599        self
600    }
601
602    /// Get the type of pooling.
603    ///
604    /// # Examples
605    ///
606    /// ```rust
607    /// let params = llama_cpp_4::context::params::LlamaContextParams::default();
608    /// assert_eq!(params.pooling_type(), llama_cpp_4::context::params::LlamaPoolingType::Unspecified);
609    /// ```
610    #[must_use]
611    pub fn pooling_type(&self) -> LlamaPoolingType {
612        LlamaPoolingType::from(self.context_params.pooling_type)
613    }
614
615    /// Clone these params, failing when sampler chains are attached.
616    ///
617    /// Prefer this over [`Clone::clone`] when you need to detect dropped sampler
618    /// configuration.
619    ///
620    /// # Errors
621    ///
622    /// Returns [`ParamsCloneError::SamplerChains`] when per-sequence sampler
623    /// chains are attached and cannot be duplicated.
624    pub fn try_clone(&self) -> Result<Self, ParamsCloneError> {
625        if self.sampler_configs.is_empty() {
626            Ok(self.clone())
627        } else {
628            Err(ParamsCloneError::SamplerChains)
629        }
630    }
631}
632
633/// Default parameters for `LlamaContext`. (as defined in llama.cpp by `llama_context_default_params`)
634/// ```
635/// # use std::num::NonZeroU32;
636/// use llama_cpp_4::context::params::{LlamaContextParams, RopeScalingType};
637/// let params = LlamaContextParams::default();
638/// assert_eq!(params.n_ctx(), NonZeroU32::new(512), "n_ctx should be 512");
639/// assert_eq!(params.rope_scaling_type(), RopeScalingType::Unspecified);
640/// ```
641impl Default for LlamaContextParams {
642    fn default() -> Self {
643        let context_params = unsafe { llama_cpp_sys_4::llama_context_default_params() };
644        Self {
645            context_params,
646            attn_rot_disabled: false,
647            owned_samplers: Vec::new(),
648            sampler_configs: Vec::new(),
649        }
650    }
651}
652
653/// Duplicate context params for reuse.
654///
655/// Sampler chains attached via [`LlamaContextParams::with_sampler_seq_configs`]
656/// are **not** cloned — the copy clears `samplers` / `n_samplers` because the
657/// underlying C chains cannot be duplicated safely.
658impl Clone for LlamaContextParams {
659    fn clone(&self) -> Self {
660        let mut context_params = self.context_params;
661        // Sampler chains cannot be duplicated here; cloned params omit them.
662        context_params.samplers = std::ptr::null_mut();
663        context_params.n_samplers = 0;
664        Self {
665            context_params,
666            attn_rot_disabled: self.attn_rot_disabled,
667            owned_samplers: Vec::new(),
668            sampler_configs: Vec::new(),
669        }
670    }
671}