Skip to main content

llama_cpp_4/context/
params.rs

1//! A safe wrapper around `llama_context_params`.
2use std::fmt::Debug;
3use std::num::NonZeroU32;
4
/// A rusty wrapper around `rope_scaling_type`.
#[repr(i8)]
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum RopeScalingType {
    /// The scaling type is unspecified
    Unspecified = -1,
    /// No scaling
    None = 0,
    /// Linear scaling
    Linear = 1,
    /// Yarn scaling
    Yarn = 2,
}

/// Create a `RopeScalingType` from a `c_int` - returns `RopeScalingType::Unspecified` if
/// the value is not recognized.
impl From<i32> for RopeScalingType {
    fn from(value: i32) -> Self {
        match value {
            0 => Self::None,
            1 => Self::Linear,
            2 => Self::Yarn,
            // Any unknown value (including -1 itself) maps to the
            // "unspecified" sentinel rather than failing.
            _ => Self::Unspecified,
        }
    }
}

/// Create a `c_int` from a `RopeScalingType`.
impl From<RopeScalingType> for i32 {
    fn from(value: RopeScalingType) -> Self {
        match value {
            RopeScalingType::None => 0,
            RopeScalingType::Linear => 1,
            RopeScalingType::Yarn => 2,
            RopeScalingType::Unspecified => -1,
        }
    }
}
43
/// A rusty wrapper around `LLAMA_POOLING_TYPE`.
#[repr(i8)]
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum LlamaPoolingType {
    /// The pooling type is unspecified
    Unspecified = -1,
    /// No pooling
    None = 0,
    /// Mean pooling
    Mean = 1,
    /// CLS pooling
    Cls = 2,
    /// Last pooling
    Last = 3,
}

/// Create a `LlamaPoolingType` from a `c_int` - returns `LlamaPoolingType::Unspecified` if
/// the value is not recognized.
impl From<i32> for LlamaPoolingType {
    fn from(value: i32) -> Self {
        match value {
            0 => LlamaPoolingType::None,
            1 => LlamaPoolingType::Mean,
            2 => LlamaPoolingType::Cls,
            3 => LlamaPoolingType::Last,
            _ => LlamaPoolingType::Unspecified,
        }
    }
}

/// Create a `c_int` from a `LlamaPoolingType`.
impl From<LlamaPoolingType> for i32 {
    fn from(value: LlamaPoolingType) -> Self {
        // The enum is fieldless with explicit discriminants, so a plain
        // numeric cast yields exactly the values the C API expects
        // (the `repr(i8)` -1 sign-extends to -1 as an `i32`).
        value as i32
    }
}
86
/// A safe wrapper around `llama_context_params`.
///
/// Generally this should be created with [`Default::default()`] and then modified with `with_*` methods.
///
/// # Examples
///
/// ```rust
/// # use std::num::NonZeroU32;
/// use llama_cpp_4::context::params::LlamaContextParams;
///
/// let ctx_params = LlamaContextParams::default()
///     .with_n_ctx(NonZeroU32::new(2048))
///     .with_n_batch(1024);
///
/// assert_eq!(ctx_params.n_ctx(), NonZeroU32::new(2048));
/// assert_eq!(ctx_params.n_batch(), 1024);
/// ```
#[derive(Debug, Clone)]
#[allow(
    missing_docs,
    clippy::struct_excessive_bools,
    clippy::module_name_repetitions
)]
pub struct LlamaContextParams {
    // Raw FFI struct; mutated in place by the `with_*` builder methods below.
    pub(crate) context_params: llama_cpp_sys_4::llama_context_params,
    /// When `true`, the TurboQuant attention rotation (PR #21038) will be
    /// disabled for any context created from these params.
    pub(crate) attn_rot_disabled: bool,
}

/// SAFETY: we do not currently allow setting or reading the pointers that cause this to not be automatically send or sync.
/// (The wrapper exposes no API to install callbacks/user-data that would be unsound to share across threads,
/// other than the raw-pointer setters which place the burden on the caller.)
unsafe impl Send for LlamaContextParams {}
unsafe impl Sync for LlamaContextParams {}
120
impl LlamaContextParams {
    /// Set the size of the context
    ///
    /// Passing [`None`] stores `0`, which tells llama.cpp to take the context
    /// size from the model instead.
    ///
    /// # Examples
    ///
    /// ```rust
    /// # use std::num::NonZeroU32;
    /// use llama_cpp_4::context::params::LlamaContextParams;
    /// let params = LlamaContextParams::default();
    /// let params = params.with_n_ctx(NonZeroU32::new(2048));
    /// assert_eq!(params.n_ctx(), NonZeroU32::new(2048));
    /// ```
    #[must_use]
    pub fn with_n_ctx(mut self, n_ctx: Option<NonZeroU32>) -> Self {
        self.context_params.n_ctx = n_ctx.map_or(0, std::num::NonZeroU32::get);
        self
    }

    /// Get the size of the context.
    ///
    /// [`None`] if the context size is specified by the model and not the context.
    ///
    /// # Examples
    ///
    /// ```rust
    /// let params = llama_cpp_4::context::params::LlamaContextParams::default();
    /// assert_eq!(params.n_ctx(), std::num::NonZeroU32::new(512));
    /// ```
    #[must_use]
    pub fn n_ctx(&self) -> Option<NonZeroU32> {
        NonZeroU32::new(self.context_params.n_ctx)
    }

    /// Set the `n_batch` (logical maximum batch size submitted to `decode`)
    ///
    /// # Examples
    ///
    /// ```rust
    /// # use std::num::NonZeroU32;
    /// use llama_cpp_4::context::params::LlamaContextParams;
    /// let params = LlamaContextParams::default()
    ///     .with_n_batch(2048);
    /// assert_eq!(params.n_batch(), 2048);
    /// ```
    #[must_use]
    pub fn with_n_batch(mut self, n_batch: u32) -> Self {
        self.context_params.n_batch = n_batch;
        self
    }

    /// Get the `n_batch`
    ///
    /// # Examples
    ///
    /// ```rust
    /// use llama_cpp_4::context::params::LlamaContextParams;
    /// let params = LlamaContextParams::default();
    /// assert_eq!(params.n_batch(), 2048);
    /// ```
    #[must_use]
    pub fn n_batch(&self) -> u32 {
        self.context_params.n_batch
    }

    /// Set the `n_ubatch` (physical micro-batch size)
    ///
    /// # Examples
    ///
    /// ```rust
    /// # use std::num::NonZeroU32;
    /// use llama_cpp_4::context::params::LlamaContextParams;
    /// let params = LlamaContextParams::default()
    ///     .with_n_ubatch(512);
    /// assert_eq!(params.n_ubatch(), 512);
    /// ```
    #[must_use]
    pub fn with_n_ubatch(mut self, n_ubatch: u32) -> Self {
        self.context_params.n_ubatch = n_ubatch;
        self
    }

    /// Get the `n_ubatch`
    ///
    /// # Examples
    ///
    /// ```rust
    /// use llama_cpp_4::context::params::LlamaContextParams;
    /// let params = LlamaContextParams::default();
    /// assert_eq!(params.n_ubatch(), 512);
    /// ```
    #[must_use]
    pub fn n_ubatch(&self) -> u32 {
        self.context_params.n_ubatch
    }

    /// Set the `flash_attention` parameter
    ///
    /// # Examples
    ///
    /// ```rust
    /// use llama_cpp_4::context::params::LlamaContextParams;
    /// let params = LlamaContextParams::default()
    ///     .with_flash_attention(true);
    /// assert_eq!(params.flash_attention(), true);
    /// ```
    #[must_use]
    pub fn with_flash_attention(mut self, enabled: bool) -> Self {
        // The sys layer models this as a tri-state (auto/enabled/disabled);
        // this bool setter only exposes the explicit enabled/disabled states.
        self.context_params.flash_attn_type = if enabled {
            llama_cpp_sys_4::LLAMA_FLASH_ATTN_TYPE_ENABLED
        } else {
            llama_cpp_sys_4::LLAMA_FLASH_ATTN_TYPE_DISABLED
        };
        self
    }

    /// Get the `flash_attention` parameter
    ///
    /// Returns `true` only for the explicit "enabled" state; any other
    /// flash-attention type (disabled or auto) reads as `false`.
    ///
    /// # Examples
    ///
    /// ```rust
    /// use llama_cpp_4::context::params::LlamaContextParams;
    /// let params = LlamaContextParams::default();
    /// assert_eq!(params.flash_attention(), false);
    /// ```
    #[must_use]
    pub fn flash_attention(&self) -> bool {
        self.context_params.flash_attn_type == llama_cpp_sys_4::LLAMA_FLASH_ATTN_TYPE_ENABLED
    }

    /// Set the `offload_kqv` parameter to control offloading KV cache & KQV ops to GPU
    ///
    /// # Examples
    ///
    /// ```rust
    /// use llama_cpp_4::context::params::LlamaContextParams;
    /// let params = LlamaContextParams::default()
    ///     .with_offload_kqv(false);
    /// assert_eq!(params.offload_kqv(), false);
    /// ```
    #[must_use]
    pub fn with_offload_kqv(mut self, enabled: bool) -> Self {
        self.context_params.offload_kqv = enabled;
        self
    }

    /// Get the `offload_kqv` parameter
    ///
    /// # Examples
    ///
    /// ```rust
    /// use llama_cpp_4::context::params::LlamaContextParams;
    /// let params = LlamaContextParams::default();
    /// assert_eq!(params.offload_kqv(), true);
    /// ```
    #[must_use]
    pub fn offload_kqv(&self) -> bool {
        self.context_params.offload_kqv
    }

    /// Set the type of rope scaling.
    ///
    /// # Examples
    ///
    /// ```rust
    /// use llama_cpp_4::context::params::{LlamaContextParams, RopeScalingType};
    /// let params = LlamaContextParams::default()
    ///     .with_rope_scaling_type(RopeScalingType::Linear);
    /// assert_eq!(params.rope_scaling_type(), RopeScalingType::Linear);
    /// ```
    #[must_use]
    pub fn with_rope_scaling_type(mut self, rope_scaling_type: RopeScalingType) -> Self {
        self.context_params.rope_scaling_type = i32::from(rope_scaling_type);
        self
    }

    /// Get the type of rope scaling.
    ///
    /// # Examples
    ///
    /// ```rust
    /// let params = llama_cpp_4::context::params::LlamaContextParams::default();
    /// assert_eq!(params.rope_scaling_type(), llama_cpp_4::context::params::RopeScalingType::Unspecified);
    /// ```
    #[must_use]
    pub fn rope_scaling_type(&self) -> RopeScalingType {
        RopeScalingType::from(self.context_params.rope_scaling_type)
    }

    /// Set the rope frequency base.
    ///
    /// A value of `0.0` (the default) means "use the model's own value".
    ///
    /// # Examples
    ///
    /// ```rust
    /// use llama_cpp_4::context::params::LlamaContextParams;
    /// let params = LlamaContextParams::default()
    ///    .with_rope_freq_base(0.5);
    /// assert_eq!(params.rope_freq_base(), 0.5);
    /// ```
    #[must_use]
    pub fn with_rope_freq_base(mut self, rope_freq_base: f32) -> Self {
        self.context_params.rope_freq_base = rope_freq_base;
        self
    }

    /// Get the rope frequency base.
    ///
    /// # Examples
    ///
    /// ```rust
    /// let params = llama_cpp_4::context::params::LlamaContextParams::default();
    /// assert_eq!(params.rope_freq_base(), 0.0);
    /// ```
    #[must_use]
    pub fn rope_freq_base(&self) -> f32 {
        self.context_params.rope_freq_base
    }

    /// Set the rope frequency scale.
    ///
    /// A value of `0.0` (the default) means "use the model's own value".
    ///
    /// # Examples
    ///
    /// ```rust
    /// use llama_cpp_4::context::params::LlamaContextParams;
    /// let params = LlamaContextParams::default()
    ///   .with_rope_freq_scale(0.5);
    /// assert_eq!(params.rope_freq_scale(), 0.5);
    /// ```
    #[must_use]
    pub fn with_rope_freq_scale(mut self, rope_freq_scale: f32) -> Self {
        self.context_params.rope_freq_scale = rope_freq_scale;
        self
    }

    /// Get the rope frequency scale.
    ///
    /// # Examples
    ///
    /// ```rust
    /// let params = llama_cpp_4::context::params::LlamaContextParams::default();
    /// assert_eq!(params.rope_freq_scale(), 0.0);
    /// ```
    #[must_use]
    pub fn rope_freq_scale(&self) -> f32 {
        self.context_params.rope_freq_scale
    }

    /// Get the number of threads.
    ///
    /// # Examples
    ///
    /// ```rust
    /// let params = llama_cpp_4::context::params::LlamaContextParams::default();
    /// assert_eq!(params.n_threads(), 4);
    /// ```
    #[must_use]
    pub fn n_threads(&self) -> i32 {
        self.context_params.n_threads
    }

    /// Get the number of threads allocated for batches.
    ///
    /// # Examples
    ///
    /// ```rust
    /// let params = llama_cpp_4::context::params::LlamaContextParams::default();
    /// assert_eq!(params.n_threads_batch(), 4);
    /// ```
    #[must_use]
    pub fn n_threads_batch(&self) -> i32 {
        self.context_params.n_threads_batch
    }

    /// Set the number of threads.
    ///
    /// # Examples
    ///
    /// ```rust
    /// use llama_cpp_4::context::params::LlamaContextParams;
    /// let params = LlamaContextParams::default()
    ///    .with_n_threads(8);
    /// assert_eq!(params.n_threads(), 8);
    /// ```
    #[must_use]
    pub fn with_n_threads(mut self, n_threads: i32) -> Self {
        self.context_params.n_threads = n_threads;
        self
    }

    /// Set the number of threads allocated for batches.
    ///
    /// # Examples
    ///
    /// ```rust
    /// use llama_cpp_4::context::params::LlamaContextParams;
    /// let params = LlamaContextParams::default()
    ///    .with_n_threads_batch(8);
    /// assert_eq!(params.n_threads_batch(), 8);
    /// ```
    #[must_use]
    pub fn with_n_threads_batch(mut self, n_threads: i32) -> Self {
        self.context_params.n_threads_batch = n_threads;
        self
    }

    /// Check whether embeddings are enabled
    ///
    /// # Examples
    ///
    /// ```rust
    /// let params = llama_cpp_4::context::params::LlamaContextParams::default();
    /// assert!(!params.embeddings());
    /// ```
    #[must_use]
    pub fn embeddings(&self) -> bool {
        self.context_params.embeddings
    }

    /// Enable the use of embeddings
    ///
    /// # Examples
    ///
    /// ```rust
    /// use llama_cpp_4::context::params::LlamaContextParams;
    /// let params = LlamaContextParams::default()
    ///    .with_embeddings(true);
    /// assert!(params.embeddings());
    /// ```
    #[must_use]
    pub fn with_embeddings(mut self, embedding: bool) -> Self {
        self.context_params.embeddings = embedding;
        self
    }

    /// Set the evaluation callback.
    ///
    /// The callback (and its user data, see
    /// [`with_cb_eval_user_data`](Self::with_cb_eval_user_data)) is invoked by
    /// the ggml backend scheduler during graph evaluation.
    ///
    /// # Examples
    ///
    /// ```no_run
    /// extern "C" fn cb_eval_fn(
    ///     t: *mut llama_cpp_sys_4::ggml_tensor,
    ///     ask: bool,
    ///     user_data: *mut std::ffi::c_void,
    /// ) -> bool {
    ///     false
    /// }
    ///
    /// use llama_cpp_4::context::params::LlamaContextParams;
    /// let params = LlamaContextParams::default().with_cb_eval(Some(cb_eval_fn));
    /// ```
    #[must_use]
    pub fn with_cb_eval(
        mut self,
        cb_eval: llama_cpp_sys_4::ggml_backend_sched_eval_callback,
    ) -> Self {
        self.context_params.cb_eval = cb_eval;
        self
    }

    /// Set the evaluation callback user data.
    ///
    /// The pointer is passed verbatim to the `cb_eval` callback; the caller is
    /// responsible for keeping whatever it points to alive for as long as the
    /// context may invoke the callback.
    ///
    /// # Examples
    ///
    /// ```no_run
    /// use llama_cpp_4::context::params::LlamaContextParams;
    /// let params = LlamaContextParams::default();
    /// let user_data = std::ptr::null_mut();
    /// let params = params.with_cb_eval_user_data(user_data);
    /// ```
    #[must_use]
    pub fn with_cb_eval_user_data(mut self, cb_eval_user_data: *mut std::ffi::c_void) -> Self {
        self.context_params.cb_eval_user_data = cb_eval_user_data;
        self
    }

    /// Attach a [`TensorCapture`](super::tensor_capture::TensorCapture) to
    /// intercept intermediate tensor outputs during `decode()`.
    ///
    /// This sets up the `cb_eval` callback to capture tensors matching the
    /// capture's filter (e.g. specific layer outputs). After `decode()` the
    /// captured data can be read from the `TensorCapture`.
    ///
    /// NOTE(review): this stores a raw pointer to `capture` in the params; the
    /// caller must keep the `TensorCapture` alive (and not move it) while any
    /// context built from these params can run `decode()`.
    ///
    /// # Example
    ///
    /// ```rust,ignore
    /// use llama_cpp_4::context::params::LlamaContextParams;
    /// use llama_cpp_4::context::tensor_capture::TensorCapture;
    ///
    /// let mut capture = TensorCapture::for_layers(&[13, 20, 27]);
    /// let ctx_params = LlamaContextParams::default()
    ///     .with_embeddings(true)
    ///     .with_tensor_capture(&mut capture);
    /// ```
    #[must_use]
    pub fn with_tensor_capture(
        self,
        capture: &mut super::tensor_capture::TensorCapture,
    ) -> Self {
        self.with_cb_eval(Some(super::tensor_capture::tensor_capture_callback))
            .with_cb_eval_user_data(
                capture as *mut super::tensor_capture::TensorCapture as *mut std::ffi::c_void,
            )
    }

    /// Set the storage type for the **K** (key) KV cache tensors.
    ///
    /// The default is `GgmlType::F16`.  Quantized types like `GgmlType::Q5_0`
    /// or `GgmlType::Q4_0` reduce VRAM usage significantly; combining them with
    /// TurboQuant attention rotation (the default) keeps quality high.
    ///
    /// # Examples
    ///
    /// ```rust
    /// use llama_cpp_4::context::params::LlamaContextParams;
    /// use llama_cpp_4::quantize::GgmlType;
    /// let params = LlamaContextParams::default()
    ///     .with_cache_type_k(GgmlType::Q5_0);
    /// ```
    #[must_use]
    pub fn with_cache_type_k(mut self, ty: crate::quantize::GgmlType) -> Self {
        self.context_params.type_k = ty as llama_cpp_sys_4::ggml_type;
        self
    }

    /// Get the K-cache storage type.
    #[must_use]
    pub fn cache_type_k(&self) -> llama_cpp_sys_4::ggml_type {
        self.context_params.type_k
    }

    /// Set the storage type for the **V** (value) KV cache tensors.
    ///
    /// See [`with_cache_type_k`](Self::with_cache_type_k) for details.
    ///
    /// # Examples
    ///
    /// ```rust
    /// use llama_cpp_4::context::params::LlamaContextParams;
    /// use llama_cpp_4::quantize::GgmlType;
    /// let params = LlamaContextParams::default()
    ///     .with_cache_type_v(GgmlType::Q5_0);
    /// ```
    #[must_use]
    pub fn with_cache_type_v(mut self, ty: crate::quantize::GgmlType) -> Self {
        self.context_params.type_v = ty as llama_cpp_sys_4::ggml_type;
        self
    }

    /// Get the V-cache storage type.
    #[must_use]
    pub fn cache_type_v(&self) -> llama_cpp_sys_4::ggml_type {
        self.context_params.type_v
    }

    /// Control the TurboQuant attention-rotation feature (llama.cpp PR #21038).
    ///
    /// By default, llama.cpp applies a Hadamard rotation to Q/K/V tensors
    /// before writing them into the KV cache.  This significantly improves
    /// quantized KV-cache quality at near-zero overhead, and is enabled
    /// automatically for models whose head dimension is a power of two.
    ///
    /// Set `disabled = true` to opt out (equivalent to `LLAMA_ATTN_ROT_DISABLE=1`).
    /// The env-var is applied just before the context is created and restored
    /// afterwards, so this is safe to call from a single thread.
    ///
    /// # Examples
    ///
    /// ```rust
    /// use llama_cpp_4::context::params::LlamaContextParams;
    /// // Disable rotation for this context only:
    /// let params = LlamaContextParams::default().with_attn_rot_disabled(true);
    /// assert!(params.attn_rot_disabled());
    /// ```
    #[must_use]
    pub fn with_attn_rot_disabled(mut self, disabled: bool) -> Self {
        // Stored on the wrapper (not the FFI struct); consumed at context
        // creation time.
        self.attn_rot_disabled = disabled;
        self
    }

    /// Returns `true` if TurboQuant attention rotation is disabled for this context.
    ///
    /// ```rust
    /// let params = llama_cpp_4::context::params::LlamaContextParams::default();
    /// assert!(!params.attn_rot_disabled());
    /// ```
    #[must_use]
    pub fn attn_rot_disabled(&self) -> bool {
        self.attn_rot_disabled
    }

    /// Set the type of pooling.
    ///
    /// # Examples
    ///
    /// ```rust
    /// use llama_cpp_4::context::params::{LlamaContextParams, LlamaPoolingType};
    /// let params = LlamaContextParams::default()
    ///     .with_pooling_type(LlamaPoolingType::Last);
    /// assert_eq!(params.pooling_type(), LlamaPoolingType::Last);
    /// ```
    #[must_use]
    pub fn with_pooling_type(mut self, pooling_type: LlamaPoolingType) -> Self {
        self.context_params.pooling_type = i32::from(pooling_type);
        self
    }

    /// Get the type of pooling.
    ///
    /// # Examples
    ///
    /// ```rust
    /// let params = llama_cpp_4::context::params::LlamaContextParams::default();
    /// assert_eq!(params.pooling_type(), llama_cpp_4::context::params::LlamaPoolingType::Unspecified);
    /// ```
    #[must_use]
    pub fn pooling_type(&self) -> LlamaPoolingType {
        LlamaPoolingType::from(self.context_params.pooling_type)
    }
}
638
639/// Default parameters for `LlamaContext`. (as defined in llama.cpp by `llama_context_default_params`)
640/// ```
641/// # use std::num::NonZeroU32;
642/// use llama_cpp_4::context::params::{LlamaContextParams, RopeScalingType};
643/// let params = LlamaContextParams::default();
644/// assert_eq!(params.n_ctx(), NonZeroU32::new(512), "n_ctx should be 512");
645/// assert_eq!(params.rope_scaling_type(), RopeScalingType::Unspecified);
646/// ```
647impl Default for LlamaContextParams {
648    fn default() -> Self {
649        let context_params = unsafe { llama_cpp_sys_4::llama_context_default_params() };
650        Self {
651            context_params,
652            attn_rot_disabled: false,
653        }
654    }
655}