llama_cpp_4/context/params.rs

//! A safe wrapper around `llama_context_params`.
use std::fmt::Debug;
use std::num::NonZeroU32;

/// A rusty wrapper around `rope_scaling_type`.
#[repr(i8)]
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum RopeScalingType {
    /// The scaling type is unspecified
    Unspecified = -1,
    /// No scaling
    None = 0,
    /// Linear scaling
    Linear = 1,
    /// Yarn scaling
    Yarn = 2,
}

/// Create a `RopeScalingType` from a `c_int` - returns `RopeScalingType::Unspecified` if
/// the value is not recognized.
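///
/// # Examples
///
/// A quick sketch of the fallback behaviour:
///
/// ```rust
/// use llama_cpp_4::context::params::RopeScalingType;
/// assert_eq!(RopeScalingType::from(2), RopeScalingType::Yarn);
/// // Unknown values fall back to `Unspecified`.
/// assert_eq!(RopeScalingType::from(42), RopeScalingType::Unspecified);
/// ```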
impl From<i32> for RopeScalingType {
    fn from(value: i32) -> Self {
        match value {
            0 => Self::None,
            1 => Self::Linear,
            2 => Self::Yarn,
            _ => Self::Unspecified,
        }
    }
}

/// Create a `c_int` from a `RopeScalingType`.
impl From<RopeScalingType> for i32 {
    fn from(value: RopeScalingType) -> Self {
        match value {
            RopeScalingType::None => 0,
            RopeScalingType::Linear => 1,
            RopeScalingType::Yarn => 2,
            RopeScalingType::Unspecified => -1,
        }
    }
}

/// A rusty wrapper around `LLAMA_POOLING_TYPE`.
#[repr(i8)]
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum LlamaPoolingType {
    /// The pooling type is unspecified
    Unspecified = -1,
    /// No pooling
    None = 0,
    /// Mean pooling
    Mean = 1,
    /// CLS pooling
    Cls = 2,
    /// Last pooling
    Last = 3,
}

/// Create a `LlamaPoolingType` from a `c_int` - returns `LlamaPoolingType::Unspecified` if
/// the value is not recognized.
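///
/// # Examples
///
/// A quick sketch, mirroring the rope-scaling conversion above:
///
/// ```rust
/// use llama_cpp_4::context::params::LlamaPoolingType;
/// assert_eq!(LlamaPoolingType::from(3), LlamaPoolingType::Last);
/// // Unknown values fall back to `Unspecified`.
/// assert_eq!(LlamaPoolingType::from(42), LlamaPoolingType::Unspecified);
/// ```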
impl From<i32> for LlamaPoolingType {
    fn from(value: i32) -> Self {
        match value {
            0 => Self::None,
            1 => Self::Mean,
            2 => Self::Cls,
            3 => Self::Last,
            _ => Self::Unspecified,
        }
    }
}

/// Create a `c_int` from a `LlamaPoolingType`.
impl From<LlamaPoolingType> for i32 {
    fn from(value: LlamaPoolingType) -> Self {
        match value {
            LlamaPoolingType::None => 0,
            LlamaPoolingType::Mean => 1,
            LlamaPoolingType::Cls => 2,
            LlamaPoolingType::Last => 3,
            LlamaPoolingType::Unspecified => -1,
        }
    }
}

/// A safe wrapper around `llama_context_params`.
///
/// Generally this should be created with [`Default::default()`] and then modified with `with_*` methods.
///
/// # Examples
///
/// ```rust
/// # use std::num::NonZeroU32;
/// use llama_cpp_4::context::params::LlamaContextParams;
///
/// let ctx_params = LlamaContextParams::default()
///     .with_n_ctx(NonZeroU32::new(2048));
///
/// assert_eq!(ctx_params.n_ctx(), NonZeroU32::new(2048));
/// ```
#[derive(Debug, Clone)]
#[allow(
    missing_docs,
    clippy::struct_excessive_bools,
    clippy::module_name_repetitions
)]
pub struct LlamaContextParams {
    pub(crate) context_params: llama_cpp_sys_4::llama_context_params,
    /// When `true`, the `TurboQuant` attention rotation (PR #21038) will be
    /// disabled for any context created from these params.
    pub(crate) attn_rot_disabled: bool,
}

/// SAFETY: we do not currently allow setting or reading the pointers that would
/// otherwise prevent this type from being automatically `Send` or `Sync`.
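///
/// A small sketch demonstrating `Send` (the params are moved to another thread):
///
/// ```rust
/// use llama_cpp_4::context::params::LlamaContextParams;
/// let params = LlamaContextParams::default();
/// std::thread::spawn(move || {
///     // Read a field from the moved params on the other thread.
///     assert_eq!(params.n_batch(), 2048);
/// })
/// .join()
/// .unwrap();
/// ```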
unsafe impl Send for LlamaContextParams {}
unsafe impl Sync for LlamaContextParams {}

impl LlamaContextParams {
    /// Set the size of the context.
    ///
    /// # Examples
    ///
    /// ```rust
    /// # use std::num::NonZeroU32;
    /// use llama_cpp_4::context::params::LlamaContextParams;
    /// let params = LlamaContextParams::default();
    /// let params = params.with_n_ctx(NonZeroU32::new(2048));
    /// assert_eq!(params.n_ctx(), NonZeroU32::new(2048));
    /// ```
    #[must_use]
    pub fn with_n_ctx(mut self, n_ctx: Option<NonZeroU32>) -> Self {
        self.context_params.n_ctx = n_ctx.map_or(0, std::num::NonZeroU32::get);
        self
    }

    /// Get the size of the context.
    ///
    /// [`None`] if the context size is specified by the model and not the context.
    ///
    /// # Examples
    ///
    /// ```rust
    /// let params = llama_cpp_4::context::params::LlamaContextParams::default();
    /// assert_eq!(params.n_ctx(), std::num::NonZeroU32::new(512));
    /// ```
    #[must_use]
    pub fn n_ctx(&self) -> Option<NonZeroU32> {
        NonZeroU32::new(self.context_params.n_ctx)
    }

    /// Set the `n_batch`
    ///
    /// # Examples
    ///
    /// ```rust
    /// # use std::num::NonZeroU32;
    /// use llama_cpp_4::context::params::LlamaContextParams;
    /// let params = LlamaContextParams::default()
    ///     .with_n_batch(2048);
    /// assert_eq!(params.n_batch(), 2048);
    /// ```
    #[must_use]
    pub fn with_n_batch(mut self, n_batch: u32) -> Self {
        self.context_params.n_batch = n_batch;
        self
    }

    /// Get the `n_batch`
    ///
    /// # Examples
    ///
    /// ```rust
    /// use llama_cpp_4::context::params::LlamaContextParams;
    /// let params = LlamaContextParams::default();
    /// assert_eq!(params.n_batch(), 2048);
    /// ```
    #[must_use]
    pub fn n_batch(&self) -> u32 {
        self.context_params.n_batch
    }

    /// Set the `n_ubatch`
    ///
    /// # Examples
    ///
    /// ```rust
    /// # use std::num::NonZeroU32;
    /// use llama_cpp_4::context::params::LlamaContextParams;
    /// let params = LlamaContextParams::default()
    ///     .with_n_ubatch(512);
    /// assert_eq!(params.n_ubatch(), 512);
    /// ```
    #[must_use]
    pub fn with_n_ubatch(mut self, n_ubatch: u32) -> Self {
        self.context_params.n_ubatch = n_ubatch;
        self
    }

    /// Get the `n_ubatch`
    ///
    /// # Examples
    ///
    /// ```rust
    /// use llama_cpp_4::context::params::LlamaContextParams;
    /// let params = LlamaContextParams::default();
    /// assert_eq!(params.n_ubatch(), 512);
    /// ```
    #[must_use]
    pub fn n_ubatch(&self) -> u32 {
        self.context_params.n_ubatch
    }

    /// Set the number of recurrent-state snapshots per sequence used for MTP rollback.
    ///
    /// This is only available when built with the `mtp` feature.
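    ///
    /// # Examples
    ///
    /// A minimal sketch; gated behind the `mtp` feature, so not compiled here:
    ///
    /// ```rust,ignore
    /// use llama_cpp_4::context::params::LlamaContextParams;
    /// let params = LlamaContextParams::default().with_n_rs_seq(4);
    /// assert_eq!(params.n_rs_seq(), 4);
    /// ```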
    #[cfg(feature = "mtp")]
    #[must_use]
    pub fn with_n_rs_seq(mut self, n_rs_seq: u32) -> Self {
        self.context_params.n_rs_seq = n_rs_seq;
        self
    }

    /// Get the number of recurrent-state snapshots per sequence used for MTP rollback.
    ///
    /// This is only available when built with the `mtp` feature.
    #[cfg(feature = "mtp")]
    #[must_use]
    pub fn n_rs_seq(&self) -> u32 {
        self.context_params.n_rs_seq
    }

    /// Set the `flash_attention` parameter
    ///
    /// # Examples
    ///
    /// ```rust
    /// use llama_cpp_4::context::params::LlamaContextParams;
    /// let params = LlamaContextParams::default()
    ///     .with_flash_attention(true);
    /// assert_eq!(params.flash_attention(), true);
    /// ```
    #[must_use]
    pub fn with_flash_attention(mut self, enabled: bool) -> Self {
        self.context_params.flash_attn_type = if enabled {
            llama_cpp_sys_4::LLAMA_FLASH_ATTN_TYPE_ENABLED
        } else {
            llama_cpp_sys_4::LLAMA_FLASH_ATTN_TYPE_DISABLED
        };
        self
    }

    /// Get the `flash_attention` parameter
    ///
    /// # Examples
    ///
    /// ```rust
    /// use llama_cpp_4::context::params::LlamaContextParams;
    /// let params = LlamaContextParams::default();
    /// assert_eq!(params.flash_attention(), false);
    /// ```
    #[must_use]
    pub fn flash_attention(&self) -> bool {
        self.context_params.flash_attn_type == llama_cpp_sys_4::LLAMA_FLASH_ATTN_TYPE_ENABLED
    }

    /// Set the `offload_kqv` parameter to control offloading KV cache & KQV ops to GPU
    ///
    /// # Examples
    ///
    /// ```rust
    /// use llama_cpp_4::context::params::LlamaContextParams;
    /// let params = LlamaContextParams::default()
    ///     .with_offload_kqv(false);
    /// assert_eq!(params.offload_kqv(), false);
    /// ```
    #[must_use]
    pub fn with_offload_kqv(mut self, enabled: bool) -> Self {
        self.context_params.offload_kqv = enabled;
        self
    }

    /// Get the `offload_kqv` parameter
    ///
    /// # Examples
    ///
    /// ```rust
    /// use llama_cpp_4::context::params::LlamaContextParams;
    /// let params = LlamaContextParams::default();
    /// assert_eq!(params.offload_kqv(), true);
    /// ```
    #[must_use]
    pub fn offload_kqv(&self) -> bool {
        self.context_params.offload_kqv
    }

    /// Set the type of rope scaling.
    ///
    /// # Examples
    ///
    /// ```rust
    /// use llama_cpp_4::context::params::{LlamaContextParams, RopeScalingType};
    /// let params = LlamaContextParams::default()
    ///     .with_rope_scaling_type(RopeScalingType::Linear);
    /// assert_eq!(params.rope_scaling_type(), RopeScalingType::Linear);
    /// ```
    #[must_use]
    pub fn with_rope_scaling_type(mut self, rope_scaling_type: RopeScalingType) -> Self {
        self.context_params.rope_scaling_type = i32::from(rope_scaling_type);
        self
    }

    /// Get the type of rope scaling.
    ///
    /// # Examples
    ///
    /// ```rust
    /// let params = llama_cpp_4::context::params::LlamaContextParams::default();
    /// assert_eq!(params.rope_scaling_type(), llama_cpp_4::context::params::RopeScalingType::Unspecified);
    /// ```
    #[must_use]
    pub fn rope_scaling_type(&self) -> RopeScalingType {
        RopeScalingType::from(self.context_params.rope_scaling_type)
    }

    /// Set the rope frequency base.
    ///
    /// # Examples
    ///
    /// ```rust
    /// use llama_cpp_4::context::params::LlamaContextParams;
    /// let params = LlamaContextParams::default()
    ///     .with_rope_freq_base(0.5);
    /// assert_eq!(params.rope_freq_base(), 0.5);
    /// ```
    #[must_use]
    pub fn with_rope_freq_base(mut self, rope_freq_base: f32) -> Self {
        self.context_params.rope_freq_base = rope_freq_base;
        self
    }

    /// Get the rope frequency base.
    ///
    /// # Examples
    ///
    /// ```rust
    /// let params = llama_cpp_4::context::params::LlamaContextParams::default();
    /// assert_eq!(params.rope_freq_base(), 0.0);
    /// ```
    #[must_use]
    pub fn rope_freq_base(&self) -> f32 {
        self.context_params.rope_freq_base
    }

    /// Set the rope frequency scale.
    ///
    /// # Examples
    ///
    /// ```rust
    /// use llama_cpp_4::context::params::LlamaContextParams;
    /// let params = LlamaContextParams::default()
    ///     .with_rope_freq_scale(0.5);
    /// assert_eq!(params.rope_freq_scale(), 0.5);
    /// ```
    #[must_use]
    pub fn with_rope_freq_scale(mut self, rope_freq_scale: f32) -> Self {
        self.context_params.rope_freq_scale = rope_freq_scale;
        self
    }

    /// Get the rope frequency scale.
    ///
    /// # Examples
    ///
    /// ```rust
    /// let params = llama_cpp_4::context::params::LlamaContextParams::default();
    /// assert_eq!(params.rope_freq_scale(), 0.0);
    /// ```
    #[must_use]
    pub fn rope_freq_scale(&self) -> f32 {
        self.context_params.rope_freq_scale
    }

    /// Get the number of threads.
    ///
    /// # Examples
    ///
    /// ```rust
    /// let params = llama_cpp_4::context::params::LlamaContextParams::default();
    /// assert_eq!(params.n_threads(), 4);
    /// ```
    #[must_use]
    pub fn n_threads(&self) -> i32 {
        self.context_params.n_threads
    }

    /// Get the number of threads allocated for batches.
    ///
    /// # Examples
    ///
    /// ```rust
    /// let params = llama_cpp_4::context::params::LlamaContextParams::default();
    /// assert_eq!(params.n_threads_batch(), 4);
    /// ```
    #[must_use]
    pub fn n_threads_batch(&self) -> i32 {
        self.context_params.n_threads_batch
    }

408
409    /// Set the number of threads.
410    ///
411    /// # Examples
412    ///
413    /// ```rust
414    /// use llama_cpp_4::context::params::LlamaContextParams;
415    /// let params = LlamaContextParams::default()
416    ///    .with_n_threads(8);
417    /// assert_eq!(params.n_threads(), 8);
418    /// ```
419    #[must_use]
420    pub fn with_n_threads(mut self, n_threads: i32) -> Self {
421        self.context_params.n_threads = n_threads;
422        self
423    }
424
425    /// Set the number of threads allocated for batches.
426    ///
427    /// # Examples
428    ///
429    /// ```rust
430    /// use llama_cpp_4::context::params::LlamaContextParams;
431    /// let params = LlamaContextParams::default()
432    ///    .with_n_threads_batch(8);
433    /// assert_eq!(params.n_threads_batch(), 8);
434    /// ```
435    #[must_use]
436    pub fn with_n_threads_batch(mut self, n_threads: i32) -> Self {
437        self.context_params.n_threads_batch = n_threads;
438        self
439    }
440
441    /// Check whether embeddings are enabled
442    ///
443    /// # Examples
444    ///
445    /// ```rust
446    /// let params = llama_cpp_4::context::params::LlamaContextParams::default();
447    /// assert!(!params.embeddings());
448    /// ```
449    #[must_use]
450    pub fn embeddings(&self) -> bool {
451        self.context_params.embeddings
452    }
453
    /// Enable the use of embeddings
    ///
    /// # Examples
    ///
    /// ```rust
    /// use llama_cpp_4::context::params::LlamaContextParams;
    /// let params = LlamaContextParams::default()
    ///     .with_embeddings(true);
    /// assert!(params.embeddings());
    /// ```
    #[must_use]
    pub fn with_embeddings(mut self, embedding: bool) -> Self {
        self.context_params.embeddings = embedding;
        self
    }

    /// Set the evaluation callback.
    ///
    /// # Examples
    ///
    /// ```no_run
    /// extern "C" fn cb_eval_fn(
    ///     t: *mut llama_cpp_sys_4::ggml_tensor,
    ///     ask: bool,
    ///     user_data: *mut std::ffi::c_void,
    /// ) -> bool {
    ///     false
    /// }
    ///
    /// use llama_cpp_4::context::params::LlamaContextParams;
    /// let params = LlamaContextParams::default().with_cb_eval(Some(cb_eval_fn));
    /// ```
    #[must_use]
    pub fn with_cb_eval(
        mut self,
        cb_eval: llama_cpp_sys_4::ggml_backend_sched_eval_callback,
    ) -> Self {
        self.context_params.cb_eval = cb_eval;
        self
    }

    /// Set the evaluation callback user data.
    ///
    /// # Examples
    ///
    /// ```no_run
    /// use llama_cpp_4::context::params::LlamaContextParams;
    /// let params = LlamaContextParams::default();
    /// let user_data = std::ptr::null_mut();
    /// let params = params.with_cb_eval_user_data(user_data);
    /// ```
    #[must_use]
    pub fn with_cb_eval_user_data(mut self, cb_eval_user_data: *mut std::ffi::c_void) -> Self {
        self.context_params.cb_eval_user_data = cb_eval_user_data;
        self
    }

510
511    /// Attach a [`TensorCapture`](super::tensor_capture::TensorCapture) to
512    /// intercept intermediate tensor outputs during `decode()`.
513    ///
514    /// This sets up the `cb_eval` callback to capture tensors matching the
515    /// capture's filter (e.g. specific layer outputs). After `decode()` the
516    /// captured data can be read from the `TensorCapture`.
517    ///
518    /// # Example
519    ///
520    /// ```rust,ignore
521    /// use llama_cpp_4::context::params::LlamaContextParams;
522    /// use llama_cpp_4::context::tensor_capture::TensorCapture;
523    ///
524    /// let mut capture = TensorCapture::for_layers(&[13, 20, 27]);
525    /// let ctx_params = LlamaContextParams::default()
526    ///     .with_embeddings(true)
527    ///     .with_tensor_capture(&mut capture);
528    /// ```
529    #[must_use]
530    pub fn with_tensor_capture(self, capture: &mut super::tensor_capture::TensorCapture) -> Self {
531        self.with_cb_eval(Some(super::tensor_capture::tensor_capture_callback))
532            .with_cb_eval_user_data(
533                std::ptr::from_mut::<super::tensor_capture::TensorCapture>(capture)
534                    .cast::<std::ffi::c_void>(),
535            )
536    }
537
    /// Set the storage type for the **K** (key) KV cache tensors.
    ///
    /// The default is `GgmlType::F16`. Quantized types like `GgmlType::Q5_0`
    /// or `GgmlType::Q4_0` reduce VRAM usage significantly; combining them with
    /// `TurboQuant` attention rotation (the default) keeps quality high.
    ///
    /// # Examples
    ///
    /// ```rust
    /// use llama_cpp_4::context::params::LlamaContextParams;
    /// use llama_cpp_4::quantize::GgmlType;
    /// let params = LlamaContextParams::default()
    ///     .with_cache_type_k(GgmlType::Q5_0);
    /// ```
    #[must_use]
    pub fn with_cache_type_k(mut self, ty: crate::quantize::GgmlType) -> Self {
        self.context_params.type_k = ty as llama_cpp_sys_4::ggml_type;
        self
    }

    /// Get the K-cache storage type.
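    ///
    /// # Examples
    ///
    /// A small sketch: the getter returns the raw `ggml_type` that
    /// `with_cache_type_k` stored.
    ///
    /// ```rust
    /// use llama_cpp_4::context::params::LlamaContextParams;
    /// use llama_cpp_4::quantize::GgmlType;
    /// let params = LlamaContextParams::default()
    ///     .with_cache_type_k(GgmlType::Q5_0);
    /// assert_eq!(params.cache_type_k(), GgmlType::Q5_0 as llama_cpp_sys_4::ggml_type);
    /// ```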
    #[must_use]
    pub fn cache_type_k(&self) -> llama_cpp_sys_4::ggml_type {
        self.context_params.type_k
    }

    /// Set the storage type for the **V** (value) KV cache tensors.
    ///
    /// See [`with_cache_type_k`](Self::with_cache_type_k) for details.
    ///
    /// # Examples
    ///
    /// ```rust
    /// use llama_cpp_4::context::params::LlamaContextParams;
    /// use llama_cpp_4::quantize::GgmlType;
    /// let params = LlamaContextParams::default()
    ///     .with_cache_type_v(GgmlType::Q5_0);
    /// ```
    #[must_use]
    pub fn with_cache_type_v(mut self, ty: crate::quantize::GgmlType) -> Self {
        self.context_params.type_v = ty as llama_cpp_sys_4::ggml_type;
        self
    }

    /// Get the V-cache storage type.
    #[must_use]
    pub fn cache_type_v(&self) -> llama_cpp_sys_4::ggml_type {
        self.context_params.type_v
    }

    /// Control the `TurboQuant` attention-rotation feature (llama.cpp PR #21038).
    ///
    /// By default, llama.cpp applies a Hadamard rotation to Q/K/V tensors
    /// before writing them into the KV cache. This significantly improves
    /// quantized KV-cache quality at near-zero overhead, and is enabled
    /// automatically for models whose head dimension is a power of two.
    ///
    /// Set `disabled = true` to opt out (equivalent to `LLAMA_ATTN_ROT_DISABLE=1`).
    /// The env var is set just before the context is created and restored
    /// afterwards, so this is only safe when contexts are created from a single thread.
    ///
    /// # Examples
    ///
    /// ```rust
    /// use llama_cpp_4::context::params::LlamaContextParams;
    /// // Disable rotation for this context only:
    /// let params = LlamaContextParams::default().with_attn_rot_disabled(true);
    /// assert!(params.attn_rot_disabled());
    /// ```
    #[must_use]
    pub fn with_attn_rot_disabled(mut self, disabled: bool) -> Self {
        self.attn_rot_disabled = disabled;
        self
    }

    /// Returns `true` if `TurboQuant` attention rotation is disabled for this context.
    ///
    /// ```rust
    /// let params = llama_cpp_4::context::params::LlamaContextParams::default();
    /// assert!(!params.attn_rot_disabled());
    /// ```
    #[must_use]
    pub fn attn_rot_disabled(&self) -> bool {
        self.attn_rot_disabled
    }

    /// Set the type of pooling.
    ///
    /// # Examples
    ///
    /// ```rust
    /// use llama_cpp_4::context::params::{LlamaContextParams, LlamaPoolingType};
    /// let params = LlamaContextParams::default()
    ///     .with_pooling_type(LlamaPoolingType::Last);
    /// assert_eq!(params.pooling_type(), LlamaPoolingType::Last);
    /// ```
    #[must_use]
    pub fn with_pooling_type(mut self, pooling_type: LlamaPoolingType) -> Self {
        self.context_params.pooling_type = i32::from(pooling_type);
        self
    }

    /// Get the type of pooling.
    ///
    /// # Examples
    ///
    /// ```rust
    /// let params = llama_cpp_4::context::params::LlamaContextParams::default();
    /// assert_eq!(params.pooling_type(), llama_cpp_4::context::params::LlamaPoolingType::Unspecified);
    /// ```
    #[must_use]
    pub fn pooling_type(&self) -> LlamaPoolingType {
        LlamaPoolingType::from(self.context_params.pooling_type)
    }
}

/// Default parameters for `LlamaContext`, as defined in llama.cpp by `llama_context_default_params`.
/// ```
/// # use std::num::NonZeroU32;
/// use llama_cpp_4::context::params::{LlamaContextParams, RopeScalingType};
/// let params = LlamaContextParams::default();
/// assert_eq!(params.n_ctx(), NonZeroU32::new(512), "n_ctx should be 512");
/// assert_eq!(params.rope_scaling_type(), RopeScalingType::Unspecified);
/// ```
impl Default for LlamaContextParams {
    fn default() -> Self {
        let context_params = unsafe { llama_cpp_sys_4::llama_context_default_params() };
        Self {
            context_params,
            attn_rot_disabled: false,
        }
    }
}