Skip to main content

llama_cpp_4/context/
params.rs

1//! A safe wrapper around `llama_context_params`.
2use std::fmt::Debug;
3use std::num::NonZeroU32;
4
/// A rusty wrapper around `rope_scaling_type`.
///
/// Each variant's discriminant is the integer it converts to via
/// [`i32::from`], and [`RopeScalingType::from`] maps that integer back to the
/// same variant.
#[repr(i8)]
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum RopeScalingType {
    /// The scaling type is unspecified
    Unspecified = -1,
    /// No scaling
    None = 0,
    /// Linear scaling
    Linear = 1,
    /// Yarn scaling
    Yarn = 2,
}

/// Create a `RopeScalingType` from a `c_int` - returns [`RopeScalingType::Unspecified`] if
/// the value is not recognized.
///
/// Only `0`, `1` and `2` are recognized; every other value (including `-1`
/// itself) yields [`RopeScalingType::Unspecified`].
// NOTE(review): upstream llama.cpp may define further scaling types (e.g. a
// value of 3); those currently collapse to `Unspecified` here — confirm
// whether a dedicated variant is wanted.
impl From<i32> for RopeScalingType {
    fn from(value: i32) -> Self {
        match value {
            0 => Self::None,
            1 => Self::Linear,
            2 => Self::Yarn,
            _ => Self::Unspecified,
        }
    }
}

/// Create a `c_int` from a `RopeScalingType`.
///
/// The match is intentionally exhaustive (no `_` arm) so that adding a variant
/// forces this conversion to be updated.
impl From<RopeScalingType> for i32 {
    fn from(value: RopeScalingType) -> Self {
        match value {
            RopeScalingType::None => 0,
            RopeScalingType::Linear => 1,
            RopeScalingType::Yarn => 2,
            RopeScalingType::Unspecified => -1,
        }
    }
}
43
/// A rusty wrapper around `LLAMA_POOLING_TYPE`.
///
/// Each variant's discriminant is the integer it converts to via
/// [`i32::from`], and [`LlamaPoolingType::from`] maps that integer back to the
/// same variant.
#[repr(i8)]
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum LlamaPoolingType {
    /// The pooling type is unspecified
    Unspecified = -1,
    /// No pooling
    None = 0,
    /// Mean pooling
    Mean = 1,
    /// CLS pooling
    Cls = 2,
    /// Last pooling
    Last = 3,
}

/// Create a `LlamaPoolingType` from a `c_int` - returns [`LlamaPoolingType::Unspecified`] if
/// the value is not recognized.
///
/// Only `0` through `3` are recognized; every other value (including `-1`
/// itself) yields [`LlamaPoolingType::Unspecified`].
impl From<i32> for LlamaPoolingType {
    fn from(value: i32) -> Self {
        match value {
            0 => Self::None,
            1 => Self::Mean,
            2 => Self::Cls,
            3 => Self::Last,
            _ => Self::Unspecified,
        }
    }
}

/// Create a `c_int` from a `LlamaPoolingType`.
///
/// The match is intentionally exhaustive (no `_` arm) so that adding a variant
/// forces this conversion to be updated.
impl From<LlamaPoolingType> for i32 {
    fn from(value: LlamaPoolingType) -> Self {
        match value {
            LlamaPoolingType::None => 0,
            LlamaPoolingType::Mean => 1,
            LlamaPoolingType::Cls => 2,
            LlamaPoolingType::Last => 3,
            LlamaPoolingType::Unspecified => -1,
        }
    }
}
86
87/// A safe wrapper around `llama_context_params`.
88///
89/// Generally this should be created with [`Default::default()`] and then modified with `with_*` methods.
90///
91/// # Examples
92///
93/// ```rust
94/// # use std::num::NonZeroU32;
95/// use llama_cpp_4::context::params::LlamaContextParams;
96///
97/// let ctx_params = LlamaContextParams::default()
98///     .with_n_ctx(NonZeroU32::new(2048));
99///
100/// assert_eq!(ctx_params.n_ctx(), NonZeroU32::new(2048));
101/// ```
102#[derive(Debug, Clone)]
103#[allow(
104    missing_docs,
105    clippy::struct_excessive_bools,
106    clippy::module_name_repetitions
107)]
108pub struct LlamaContextParams {
109    pub(crate) context_params: llama_cpp_sys_4::llama_context_params,
110    /// When `true`, the TurboQuant attention rotation (PR #21038) will be
111    /// disabled for any context created from these params.
112    pub(crate) attn_rot_disabled: bool,
113}
114
115/// SAFETY: we do not currently allow setting or reading the pointers that cause this to not be automatically send or sync.
116unsafe impl Send for LlamaContextParams {}
117unsafe impl Sync for LlamaContextParams {}
118
119impl LlamaContextParams {
120    /// Set the side of the context
121    ///
122    /// # Examples
123    ///
124    /// ```rust
125    /// # use std::num::NonZeroU32;
126    /// use llama_cpp_4::context::params::LlamaContextParams;
127    /// let params = LlamaContextParams::default();
128    /// let params = params.with_n_ctx(NonZeroU32::new(2048));
129    /// assert_eq!(params.n_ctx(), NonZeroU32::new(2048));
130    /// ```
131    #[must_use]
132    pub fn with_n_ctx(mut self, n_ctx: Option<NonZeroU32>) -> Self {
133        self.context_params.n_ctx = n_ctx.map_or(0, std::num::NonZeroU32::get);
134        self
135    }
136
137    /// Get the size of the context.
138    ///
139    /// [`None`] if the context size is specified by the model and not the context.
140    ///
141    /// # Examples
142    ///
143    /// ```rust
144    /// let params = llama_cpp_4::context::params::LlamaContextParams::default();
145    /// assert_eq!(params.n_ctx(), std::num::NonZeroU32::new(512));
146    #[must_use]
147    pub fn n_ctx(&self) -> Option<NonZeroU32> {
148        NonZeroU32::new(self.context_params.n_ctx)
149    }
150
151    /// Set the `n_batch`
152    ///
153    /// # Examples
154    ///
155    /// ```rust
156    /// # use std::num::NonZeroU32;
157    /// use llama_cpp_4::context::params::LlamaContextParams;
158    /// let params = LlamaContextParams::default()
159    ///     .with_n_batch(2048);
160    /// assert_eq!(params.n_batch(), 2048);
161    /// ```
162    #[must_use]
163    pub fn with_n_batch(mut self, n_batch: u32) -> Self {
164        self.context_params.n_batch = n_batch;
165        self
166    }
167
168    /// Get the `n_batch`
169    ///
170    /// # Examples
171    ///
172    /// ```rust
173    /// use llama_cpp_4::context::params::LlamaContextParams;
174    /// let params = LlamaContextParams::default();
175    /// assert_eq!(params.n_batch(), 2048);
176    /// ```
177    #[must_use]
178    pub fn n_batch(&self) -> u32 {
179        self.context_params.n_batch
180    }
181
182    /// Set the `n_ubatch`
183    ///
184    /// # Examples
185    ///
186    /// ```rust
187    /// # use std::num::NonZeroU32;
188    /// use llama_cpp_4::context::params::LlamaContextParams;
189    /// let params = LlamaContextParams::default()
190    ///     .with_n_ubatch(512);
191    /// assert_eq!(params.n_ubatch(), 512);
192    /// ```
193    #[must_use]
194    pub fn with_n_ubatch(mut self, n_ubatch: u32) -> Self {
195        self.context_params.n_ubatch = n_ubatch;
196        self
197    }
198
199    /// Get the `n_ubatch`
200    ///
201    /// # Examples
202    ///
203    /// ```rust
204    /// use llama_cpp_4::context::params::LlamaContextParams;
205    /// let params = LlamaContextParams::default();
206    /// assert_eq!(params.n_ubatch(), 512);
207    /// ```
208    #[must_use]
209    pub fn n_ubatch(&self) -> u32 {
210        self.context_params.n_ubatch
211    }
212
213    /// Set the `flash_attention` parameter
214    ///
215    /// # Examples
216    ///
217    /// ```rust
218    /// use llama_cpp_4::context::params::LlamaContextParams;
219    /// let params = LlamaContextParams::default()
220    ///     .with_flash_attention(true);
221    /// assert_eq!(params.flash_attention(), true);
222    /// ```
223    #[must_use]
224    pub fn with_flash_attention(mut self, enabled: bool) -> Self {
225        self.context_params.flash_attn_type = if enabled {
226            llama_cpp_sys_4::LLAMA_FLASH_ATTN_TYPE_ENABLED
227        } else {
228            llama_cpp_sys_4::LLAMA_FLASH_ATTN_TYPE_DISABLED
229        };
230        self
231    }
232
233    /// Get the `flash_attention` parameter
234    ///
235    /// # Examples
236    ///
237    /// ```rust
238    /// use llama_cpp_4::context::params::LlamaContextParams;
239    /// let params = LlamaContextParams::default();
240    /// assert_eq!(params.flash_attention(), false);
241    /// ```
242    #[must_use]
243    pub fn flash_attention(&self) -> bool {
244        self.context_params.flash_attn_type == llama_cpp_sys_4::LLAMA_FLASH_ATTN_TYPE_ENABLED
245    }
246
247    /// Set the `offload_kqv` parameter to control offloading KV cache & KQV ops to GPU
248    ///
249    /// # Examples
250    ///
251    /// ```rust
252    /// use llama_cpp_4::context::params::LlamaContextParams;
253    /// let params = LlamaContextParams::default()
254    ///     .with_offload_kqv(false);
255    /// assert_eq!(params.offload_kqv(), false);
256    /// ```
257    #[must_use]
258    pub fn with_offload_kqv(mut self, enabled: bool) -> Self {
259        self.context_params.offload_kqv = enabled;
260        self
261    }
262
263    /// Get the `offload_kqv` parameter
264    ///
265    /// # Examples
266    ///
267    /// ```rust
268    /// use llama_cpp_4::context::params::LlamaContextParams;
269    /// let params = LlamaContextParams::default();
270    /// assert_eq!(params.offload_kqv(), true);
271    /// ```
272    #[must_use]
273    pub fn offload_kqv(&self) -> bool {
274        self.context_params.offload_kqv
275    }
276
277    /// Set the type of rope scaling.
278    ///
279    /// # Examples
280    ///
281    /// ```rust
282    /// use llama_cpp_4::context::params::{LlamaContextParams, RopeScalingType};
283    /// let params = LlamaContextParams::default()
284    ///     .with_rope_scaling_type(RopeScalingType::Linear);
285    /// assert_eq!(params.rope_scaling_type(), RopeScalingType::Linear);
286    /// ```
287    #[must_use]
288    pub fn with_rope_scaling_type(mut self, rope_scaling_type: RopeScalingType) -> Self {
289        self.context_params.rope_scaling_type = i32::from(rope_scaling_type);
290        self
291    }
292
293    /// Get the type of rope scaling.
294    ///
295    /// # Examples
296    ///
297    /// ```rust
298    /// let params = llama_cpp_4::context::params::LlamaContextParams::default();
299    /// assert_eq!(params.rope_scaling_type(), llama_cpp_4::context::params::RopeScalingType::Unspecified);
300    /// ```
301    #[must_use]
302    pub fn rope_scaling_type(&self) -> RopeScalingType {
303        RopeScalingType::from(self.context_params.rope_scaling_type)
304    }
305
306    /// Set the rope frequency base.
307    ///
308    /// # Examples
309    ///
310    /// ```rust
311    /// use llama_cpp_4::context::params::LlamaContextParams;
312    /// let params = LlamaContextParams::default()
313    ///    .with_rope_freq_base(0.5);
314    /// assert_eq!(params.rope_freq_base(), 0.5);
315    /// ```
316    #[must_use]
317    pub fn with_rope_freq_base(mut self, rope_freq_base: f32) -> Self {
318        self.context_params.rope_freq_base = rope_freq_base;
319        self
320    }
321
322    /// Get the rope frequency base.
323    ///
324    /// # Examples
325    ///
326    /// ```rust
327    /// let params = llama_cpp_4::context::params::LlamaContextParams::default();
328    /// assert_eq!(params.rope_freq_base(), 0.0);
329    /// ```
330    #[must_use]
331    pub fn rope_freq_base(&self) -> f32 {
332        self.context_params.rope_freq_base
333    }
334
335    /// Set the rope frequency scale.
336    ///
337    /// # Examples
338    ///
339    /// ```rust
340    /// use llama_cpp_4::context::params::LlamaContextParams;
341    /// let params = LlamaContextParams::default()
342    ///   .with_rope_freq_scale(0.5);
343    /// assert_eq!(params.rope_freq_scale(), 0.5);
344    /// ```
345    #[must_use]
346    pub fn with_rope_freq_scale(mut self, rope_freq_scale: f32) -> Self {
347        self.context_params.rope_freq_scale = rope_freq_scale;
348        self
349    }
350
351    /// Get the rope frequency scale.
352    ///
353    /// # Examples
354    ///
355    /// ```rust
356    /// let params = llama_cpp_4::context::params::LlamaContextParams::default();
357    /// assert_eq!(params.rope_freq_scale(), 0.0);
358    /// ```
359    #[must_use]
360    pub fn rope_freq_scale(&self) -> f32 {
361        self.context_params.rope_freq_scale
362    }
363
364    /// Get the number of threads.
365    ///
366    /// # Examples
367    ///
368    /// ```rust
369    /// let params = llama_cpp_4::context::params::LlamaContextParams::default();
370    /// assert_eq!(params.n_threads(), 4);
371    /// ```
372    #[must_use]
373    pub fn n_threads(&self) -> i32 {
374        self.context_params.n_threads
375    }
376
377    /// Get the number of threads allocated for batches.
378    ///
379    /// # Examples
380    ///
381    /// ```rust
382    /// let params = llama_cpp_4::context::params::LlamaContextParams::default();
383    /// assert_eq!(params.n_threads_batch(), 4);
384    /// ```
385    #[must_use]
386    pub fn n_threads_batch(&self) -> i32 {
387        self.context_params.n_threads_batch
388    }
389
390    /// Set the number of threads.
391    ///
392    /// # Examples
393    ///
394    /// ```rust
395    /// use llama_cpp_4::context::params::LlamaContextParams;
396    /// let params = LlamaContextParams::default()
397    ///    .with_n_threads(8);
398    /// assert_eq!(params.n_threads(), 8);
399    /// ```
400    #[must_use]
401    pub fn with_n_threads(mut self, n_threads: i32) -> Self {
402        self.context_params.n_threads = n_threads;
403        self
404    }
405
406    /// Set the number of threads allocated for batches.
407    ///
408    /// # Examples
409    ///
410    /// ```rust
411    /// use llama_cpp_4::context::params::LlamaContextParams;
412    /// let params = LlamaContextParams::default()
413    ///    .with_n_threads_batch(8);
414    /// assert_eq!(params.n_threads_batch(), 8);
415    /// ```
416    #[must_use]
417    pub fn with_n_threads_batch(mut self, n_threads: i32) -> Self {
418        self.context_params.n_threads_batch = n_threads;
419        self
420    }
421
422    /// Check whether embeddings are enabled
423    ///
424    /// # Examples
425    ///
426    /// ```rust
427    /// let params = llama_cpp_4::context::params::LlamaContextParams::default();
428    /// assert!(!params.embeddings());
429    /// ```
430    #[must_use]
431    pub fn embeddings(&self) -> bool {
432        self.context_params.embeddings
433    }
434
435    /// Enable the use of embeddings
436    ///
437    /// # Examples
438    ///
439    /// ```rust
440    /// use llama_cpp_4::context::params::LlamaContextParams;
441    /// let params = LlamaContextParams::default()
442    ///    .with_embeddings(true);
443    /// assert!(params.embeddings());
444    /// ```
445    #[must_use]
446    pub fn with_embeddings(mut self, embedding: bool) -> Self {
447        self.context_params.embeddings = embedding;
448        self
449    }
450
451    /// Set the evaluation callback.
452    ///
453    /// # Examples
454    ///
455    /// ```no_run
456    /// extern "C" fn cb_eval_fn(
457    ///     t: *mut llama_cpp_sys_4::ggml_tensor,
458    ///     ask: bool,
459    ///     user_data: *mut std::ffi::c_void,
460    /// ) -> bool {
461    ///     false
462    /// }
463    ///
464    /// use llama_cpp_4::context::params::LlamaContextParams;
465    /// let params = LlamaContextParams::default().with_cb_eval(Some(cb_eval_fn));
466    /// ```
467    #[must_use]
468    pub fn with_cb_eval(
469        mut self,
470        cb_eval: llama_cpp_sys_4::ggml_backend_sched_eval_callback,
471    ) -> Self {
472        self.context_params.cb_eval = cb_eval;
473        self
474    }
475
476    /// Set the evaluation callback user data.
477    ///
478    /// # Examples
479    ///
480    /// ```no_run
481    /// use llama_cpp_4::context::params::LlamaContextParams;
482    /// let params = LlamaContextParams::default();
483    /// let user_data = std::ptr::null_mut();
484    /// let params = params.with_cb_eval_user_data(user_data);
485    /// ```
486    #[must_use]
487    pub fn with_cb_eval_user_data(mut self, cb_eval_user_data: *mut std::ffi::c_void) -> Self {
488        self.context_params.cb_eval_user_data = cb_eval_user_data;
489        self
490    }
491
492    /// Attach a [`TensorCapture`](super::tensor_capture::TensorCapture) to
493    /// intercept intermediate tensor outputs during `decode()`.
494    ///
495    /// This sets up the `cb_eval` callback to capture tensors matching the
496    /// capture's filter (e.g. specific layer outputs). After `decode()` the
497    /// captured data can be read from the `TensorCapture`.
498    ///
499    /// # Example
500    ///
501    /// ```rust,ignore
502    /// use llama_cpp_4::context::params::LlamaContextParams;
503    /// use llama_cpp_4::context::tensor_capture::TensorCapture;
504    ///
505    /// let mut capture = TensorCapture::for_layers(&[13, 20, 27]);
506    /// let ctx_params = LlamaContextParams::default()
507    ///     .with_embeddings(true)
508    ///     .with_tensor_capture(&mut capture);
509    /// ```
510    #[must_use]
511    pub fn with_tensor_capture(
512        self,
513        capture: &mut super::tensor_capture::TensorCapture,
514    ) -> Self {
515        self.with_cb_eval(Some(super::tensor_capture::tensor_capture_callback))
516            .with_cb_eval_user_data(
517                capture as *mut super::tensor_capture::TensorCapture as *mut std::ffi::c_void,
518            )
519    }
520
521    /// Set the storage type for the **K** (key) KV cache tensors.
522    ///
523    /// The default is `GgmlType::F16`.  Quantized types like `GgmlType::Q5_0`
524    /// or `GgmlType::Q4_0` reduce VRAM usage significantly; combining them with
525    /// TurboQuant attention rotation (the default) keeps quality high.
526    ///
527    /// # Examples
528    ///
529    /// ```rust
530    /// use llama_cpp_4::context::params::LlamaContextParams;
531    /// use llama_cpp_4::quantize::GgmlType;
532    /// let params = LlamaContextParams::default()
533    ///     .with_cache_type_k(GgmlType::Q5_0);
534    /// ```
535    #[must_use]
536    pub fn with_cache_type_k(mut self, ty: crate::quantize::GgmlType) -> Self {
537        self.context_params.type_k = ty as llama_cpp_sys_4::ggml_type;
538        self
539    }
540
541    /// Get the K-cache storage type.
542    #[must_use]
543    pub fn cache_type_k(&self) -> llama_cpp_sys_4::ggml_type {
544        self.context_params.type_k
545    }
546
547    /// Set the storage type for the **V** (value) KV cache tensors.
548    ///
549    /// See [`with_cache_type_k`](Self::with_cache_type_k) for details.
550    ///
551    /// # Examples
552    ///
553    /// ```rust
554    /// use llama_cpp_4::context::params::LlamaContextParams;
555    /// use llama_cpp_4::quantize::GgmlType;
556    /// let params = LlamaContextParams::default()
557    ///     .with_cache_type_v(GgmlType::Q5_0);
558    /// ```
559    #[must_use]
560    pub fn with_cache_type_v(mut self, ty: crate::quantize::GgmlType) -> Self {
561        self.context_params.type_v = ty as llama_cpp_sys_4::ggml_type;
562        self
563    }
564
565    /// Get the V-cache storage type.
566    #[must_use]
567    pub fn cache_type_v(&self) -> llama_cpp_sys_4::ggml_type {
568        self.context_params.type_v
569    }
570
571    /// Control the TurboQuant attention-rotation feature (llama.cpp PR #21038).
572    ///
573    /// By default, llama.cpp applies a Hadamard rotation to Q/K/V tensors
574    /// before writing them into the KV cache.  This significantly improves
575    /// quantized KV-cache quality at near-zero overhead, and is enabled
576    /// automatically for models whose head dimension is a power of two.
577    ///
578    /// Set `disabled = true` to opt out (equivalent to `LLAMA_ATTN_ROT_DISABLE=1`).
579    /// The env-var is applied just before the context is created and restored
580    /// afterwards, so this is safe to call from a single thread.
581    ///
582    /// # Examples
583    ///
584    /// ```rust
585    /// use llama_cpp_4::context::params::LlamaContextParams;
586    /// // Disable rotation for this context only:
587    /// let params = LlamaContextParams::default().with_attn_rot_disabled(true);
588    /// assert!(params.attn_rot_disabled());
589    /// ```
590    #[must_use]
591    pub fn with_attn_rot_disabled(mut self, disabled: bool) -> Self {
592        self.attn_rot_disabled = disabled;
593        self
594    }
595
596    /// Returns `true` if TurboQuant attention rotation is disabled for this context.
597    ///
598    /// ```rust
599    /// let params = llama_cpp_4::context::params::LlamaContextParams::default();
600    /// assert!(!params.attn_rot_disabled());
601    /// ```
602    #[must_use]
603    pub fn attn_rot_disabled(&self) -> bool {
604        self.attn_rot_disabled
605    }
606
607    /// Set the type of pooling.
608    ///
609    /// # Examples
610    ///
611    /// ```rust
612    /// use llama_cpp_4::context::params::{LlamaContextParams, LlamaPoolingType};
613    /// let params = LlamaContextParams::default()
614    ///     .with_pooling_type(LlamaPoolingType::Last);
615    /// assert_eq!(params.pooling_type(), LlamaPoolingType::Last);
616    /// ```
617    #[must_use]
618    pub fn with_pooling_type(mut self, pooling_type: LlamaPoolingType) -> Self {
619        self.context_params.pooling_type = i32::from(pooling_type);
620        self
621    }
622
623    /// Get the type of pooling.
624    ///
625    /// # Examples
626    ///
627    /// ```rust
628    /// let params = llama_cpp_4::context::params::LlamaContextParams::default();
629    /// assert_eq!(params.pooling_type(), llama_cpp_4::context::params::LlamaPoolingType::Unspecified);
630    /// ```
631    #[must_use]
632    pub fn pooling_type(&self) -> LlamaPoolingType {
633        LlamaPoolingType::from(self.context_params.pooling_type)
634    }
635}
636
637/// Default parameters for `LlamaContext`. (as defined in llama.cpp by `llama_context_default_params`)
638/// ```
639/// # use std::num::NonZeroU32;
640/// use llama_cpp_4::context::params::{LlamaContextParams, RopeScalingType};
641/// let params = LlamaContextParams::default();
642/// assert_eq!(params.n_ctx(), NonZeroU32::new(512), "n_ctx should be 512");
643/// assert_eq!(params.rope_scaling_type(), RopeScalingType::Unspecified);
644/// ```
645impl Default for LlamaContextParams {
646    fn default() -> Self {
647        let context_params = unsafe { llama_cpp_sys_4::llama_context_default_params() };
648        Self {
649            context_params,
650            attn_rot_disabled: false,
651        }
652    }
653}