llama_cpp_2/context/
params.rs

1//! A safe wrapper around `llama_context_params`.
2use std::fmt::Debug;
3use std::num::NonZeroU32;
4
/// A rusty wrapper around `rope_scaling_type`.
///
/// The discriminants mirror the raw integer values stored in
/// `llama_context_params::rope_scaling_type`; the `From` conversions below map between the two.
#[repr(i8)]
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum RopeScalingType {
    /// The scaling type is unspecified
    Unspecified = -1,
    /// No scaling
    None = 0,
    /// Linear scaling
    Linear = 1,
    /// Yarn scaling
    Yarn = 2,
}

/// Create a `RopeScalingType` from a `c_int` - returns [`RopeScalingType::Unspecified`] if
/// the value is not recognized.
///
/// Note that this is lossy for unrecognized values: they all collapse to `Unspecified`, so
/// converting the result back with `i32::from` yields `-1` rather than the original value.
impl From<i32> for RopeScalingType {
    fn from(value: i32) -> Self {
        match value {
            0 => Self::None,
            1 => Self::Linear,
            2 => Self::Yarn,
            _ => Self::Unspecified,
        }
    }
}

/// Create a `c_int` from a `RopeScalingType`.
///
/// This is the exact inverse of `From<i32>` for every recognized value (`-1..=2`).
impl From<RopeScalingType> for i32 {
    fn from(value: RopeScalingType) -> Self {
        match value {
            RopeScalingType::None => 0,
            RopeScalingType::Linear => 1,
            RopeScalingType::Yarn => 2,
            RopeScalingType::Unspecified => -1,
        }
    }
}
43
/// Safe Rust mirror of llama.cpp's `LLAMA_POOLING_TYPE` values.
///
/// Each discriminant is exactly the raw integer stored in `llama_context_params::pooling_type`.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
#[repr(i8)]
pub enum LlamaPoolingType {
    /// No pooling type has been specified
    Unspecified = -1,
    /// Pooling is disabled
    None = 0,
    /// Mean pooling
    Mean = 1,
    /// CLS pooling
    Cls = 2,
    /// Last pooling
    Last = 3,
    /// Rank pooling
    Rank = 4,
}

/// Convert a raw `c_int` pooling value into a [`LlamaPoolingType`].
///
/// Any value outside `0..=4` is reported as [`LlamaPoolingType::Unspecified`].
impl From<i32> for LlamaPoolingType {
    fn from(raw: i32) -> Self {
        match raw {
            4 => LlamaPoolingType::Rank,
            3 => LlamaPoolingType::Last,
            2 => LlamaPoolingType::Cls,
            1 => LlamaPoolingType::Mean,
            0 => LlamaPoolingType::None,
            _ => LlamaPoolingType::Unspecified,
        }
    }
}

/// Convert a [`LlamaPoolingType`] back into its raw `c_int` value.
impl From<LlamaPoolingType> for i32 {
    fn from(kind: LlamaPoolingType) -> Self {
        // `#[repr(i8)]` with explicit discriminants makes this cast produce exactly the C value
        // (-1 for `Unspecified`, 0..=4 otherwise); widening i8 -> i32 is lossless.
        i32::from(kind as i8)
    }
}
90
91/// A rusty wrapper around `ggml_type` for KV cache types.
92#[allow(non_camel_case_types, missing_docs)]
93#[derive(Copy, Clone, Debug, PartialEq, Eq)]
94pub enum KvCacheType {
95    /// Represents an unknown or not-yet-mapped `ggml_type` and carries the raw value.
96    /// When passed through FFI, the raw value is used as-is (if llama.cpp supports it,
97    /// the runtime will operate with that type).
98    /// This variant preserves API compatibility when new `ggml_type` values are
99    /// introduced in the future.
100    Unknown(llama_cpp_sys_2::ggml_type),
101    F32,
102    F16,
103    Q4_0,
104    Q4_1,
105    Q5_0,
106    Q5_1,
107    Q8_0,
108    Q8_1,
109    Q2_K,
110    Q3_K,
111    Q4_K,
112    Q5_K,
113    Q6_K,
114    Q8_K,
115    IQ2_XXS,
116    IQ2_XS,
117    IQ3_XXS,
118    IQ1_S,
119    IQ4_NL,
120    IQ3_S,
121    IQ2_S,
122    IQ4_XS,
123    I8,
124    I16,
125    I32,
126    I64,
127    F64,
128    IQ1_M,
129    BF16,
130    TQ1_0,
131    TQ2_0,
132    MXFP4,
133}
134
135impl From<KvCacheType> for llama_cpp_sys_2::ggml_type {
136    fn from(value: KvCacheType) -> Self {
137        match value {
138            KvCacheType::Unknown(raw) => raw,
139            KvCacheType::F32 => llama_cpp_sys_2::GGML_TYPE_F32,
140            KvCacheType::F16 => llama_cpp_sys_2::GGML_TYPE_F16,
141            KvCacheType::Q4_0 => llama_cpp_sys_2::GGML_TYPE_Q4_0,
142            KvCacheType::Q4_1 => llama_cpp_sys_2::GGML_TYPE_Q4_1,
143            KvCacheType::Q5_0 => llama_cpp_sys_2::GGML_TYPE_Q5_0,
144            KvCacheType::Q5_1 => llama_cpp_sys_2::GGML_TYPE_Q5_1,
145            KvCacheType::Q8_0 => llama_cpp_sys_2::GGML_TYPE_Q8_0,
146            KvCacheType::Q8_1 => llama_cpp_sys_2::GGML_TYPE_Q8_1,
147            KvCacheType::Q2_K => llama_cpp_sys_2::GGML_TYPE_Q2_K,
148            KvCacheType::Q3_K => llama_cpp_sys_2::GGML_TYPE_Q3_K,
149            KvCacheType::Q4_K => llama_cpp_sys_2::GGML_TYPE_Q4_K,
150            KvCacheType::Q5_K => llama_cpp_sys_2::GGML_TYPE_Q5_K,
151            KvCacheType::Q6_K => llama_cpp_sys_2::GGML_TYPE_Q6_K,
152            KvCacheType::Q8_K => llama_cpp_sys_2::GGML_TYPE_Q8_K,
153            KvCacheType::IQ2_XXS => llama_cpp_sys_2::GGML_TYPE_IQ2_XXS,
154            KvCacheType::IQ2_XS => llama_cpp_sys_2::GGML_TYPE_IQ2_XS,
155            KvCacheType::IQ3_XXS => llama_cpp_sys_2::GGML_TYPE_IQ3_XXS,
156            KvCacheType::IQ1_S => llama_cpp_sys_2::GGML_TYPE_IQ1_S,
157            KvCacheType::IQ4_NL => llama_cpp_sys_2::GGML_TYPE_IQ4_NL,
158            KvCacheType::IQ3_S => llama_cpp_sys_2::GGML_TYPE_IQ3_S,
159            KvCacheType::IQ2_S => llama_cpp_sys_2::GGML_TYPE_IQ2_S,
160            KvCacheType::IQ4_XS => llama_cpp_sys_2::GGML_TYPE_IQ4_XS,
161            KvCacheType::I8 => llama_cpp_sys_2::GGML_TYPE_I8,
162            KvCacheType::I16 => llama_cpp_sys_2::GGML_TYPE_I16,
163            KvCacheType::I32 => llama_cpp_sys_2::GGML_TYPE_I32,
164            KvCacheType::I64 => llama_cpp_sys_2::GGML_TYPE_I64,
165            KvCacheType::F64 => llama_cpp_sys_2::GGML_TYPE_F64,
166            KvCacheType::IQ1_M => llama_cpp_sys_2::GGML_TYPE_IQ1_M,
167            KvCacheType::BF16 => llama_cpp_sys_2::GGML_TYPE_BF16,
168            KvCacheType::TQ1_0 => llama_cpp_sys_2::GGML_TYPE_TQ1_0,
169            KvCacheType::TQ2_0 => llama_cpp_sys_2::GGML_TYPE_TQ2_0,
170            KvCacheType::MXFP4 => llama_cpp_sys_2::GGML_TYPE_MXFP4,
171        }
172    }
173}
174
175impl From<llama_cpp_sys_2::ggml_type> for KvCacheType {
176    fn from(value: llama_cpp_sys_2::ggml_type) -> Self {
177        match value {
178            x if x == llama_cpp_sys_2::GGML_TYPE_F32 => KvCacheType::F32,
179            x if x == llama_cpp_sys_2::GGML_TYPE_F16 => KvCacheType::F16,
180            x if x == llama_cpp_sys_2::GGML_TYPE_Q4_0 => KvCacheType::Q4_0,
181            x if x == llama_cpp_sys_2::GGML_TYPE_Q4_1 => KvCacheType::Q4_1,
182            x if x == llama_cpp_sys_2::GGML_TYPE_Q5_0 => KvCacheType::Q5_0,
183            x if x == llama_cpp_sys_2::GGML_TYPE_Q5_1 => KvCacheType::Q5_1,
184            x if x == llama_cpp_sys_2::GGML_TYPE_Q8_0 => KvCacheType::Q8_0,
185            x if x == llama_cpp_sys_2::GGML_TYPE_Q8_1 => KvCacheType::Q8_1,
186            x if x == llama_cpp_sys_2::GGML_TYPE_Q2_K => KvCacheType::Q2_K,
187            x if x == llama_cpp_sys_2::GGML_TYPE_Q3_K => KvCacheType::Q3_K,
188            x if x == llama_cpp_sys_2::GGML_TYPE_Q4_K => KvCacheType::Q4_K,
189            x if x == llama_cpp_sys_2::GGML_TYPE_Q5_K => KvCacheType::Q5_K,
190            x if x == llama_cpp_sys_2::GGML_TYPE_Q6_K => KvCacheType::Q6_K,
191            x if x == llama_cpp_sys_2::GGML_TYPE_Q8_K => KvCacheType::Q8_K,
192            x if x == llama_cpp_sys_2::GGML_TYPE_IQ2_XXS => KvCacheType::IQ2_XXS,
193            x if x == llama_cpp_sys_2::GGML_TYPE_IQ2_XS => KvCacheType::IQ2_XS,
194            x if x == llama_cpp_sys_2::GGML_TYPE_IQ3_XXS => KvCacheType::IQ3_XXS,
195            x if x == llama_cpp_sys_2::GGML_TYPE_IQ1_S => KvCacheType::IQ1_S,
196            x if x == llama_cpp_sys_2::GGML_TYPE_IQ4_NL => KvCacheType::IQ4_NL,
197            x if x == llama_cpp_sys_2::GGML_TYPE_IQ3_S => KvCacheType::IQ3_S,
198            x if x == llama_cpp_sys_2::GGML_TYPE_IQ2_S => KvCacheType::IQ2_S,
199            x if x == llama_cpp_sys_2::GGML_TYPE_IQ4_XS => KvCacheType::IQ4_XS,
200            x if x == llama_cpp_sys_2::GGML_TYPE_I8 => KvCacheType::I8,
201            x if x == llama_cpp_sys_2::GGML_TYPE_I16 => KvCacheType::I16,
202            x if x == llama_cpp_sys_2::GGML_TYPE_I32 => KvCacheType::I32,
203            x if x == llama_cpp_sys_2::GGML_TYPE_I64 => KvCacheType::I64,
204            x if x == llama_cpp_sys_2::GGML_TYPE_F64 => KvCacheType::F64,
205            x if x == llama_cpp_sys_2::GGML_TYPE_IQ1_M => KvCacheType::IQ1_M,
206            x if x == llama_cpp_sys_2::GGML_TYPE_BF16 => KvCacheType::BF16,
207            x if x == llama_cpp_sys_2::GGML_TYPE_TQ1_0 => KvCacheType::TQ1_0,
208            x if x == llama_cpp_sys_2::GGML_TYPE_TQ2_0 => KvCacheType::TQ2_0,
209            x if x == llama_cpp_sys_2::GGML_TYPE_MXFP4 => KvCacheType::MXFP4,
210            _ => KvCacheType::Unknown(value),
211        }
212    }
213}
214
215/// A safe wrapper around `llama_context_params`.
216///
217/// Generally this should be created with [`Default::default()`] and then modified with `with_*` methods.
218///
219/// # Examples
220///
221/// ```rust
222/// # use std::num::NonZeroU32;
223/// use llama_cpp_2::context::params::LlamaContextParams;
224///
225///let ctx_params = LlamaContextParams::default()
226///    .with_n_ctx(NonZeroU32::new(2048));
227///
228/// assert_eq!(ctx_params.n_ctx(), NonZeroU32::new(2048));
229/// ```
230#[derive(Debug, Clone)]
231#[allow(
232    missing_docs,
233    clippy::struct_excessive_bools,
234    clippy::module_name_repetitions
235)]
236pub struct LlamaContextParams {
237    pub(crate) context_params: llama_cpp_sys_2::llama_context_params,
238}
239
240/// SAFETY: we do not currently allow setting or reading the pointers that cause this to not be automatically send or sync.
241unsafe impl Send for LlamaContextParams {}
242unsafe impl Sync for LlamaContextParams {}
243
244impl LlamaContextParams {
245    /// Set the side of the context
246    ///
247    /// # Examples
248    ///
249    /// ```rust
250    /// # use std::num::NonZeroU32;
251    /// use llama_cpp_2::context::params::LlamaContextParams;
252    /// let params = LlamaContextParams::default();
253    /// let params = params.with_n_ctx(NonZeroU32::new(2048));
254    /// assert_eq!(params.n_ctx(), NonZeroU32::new(2048));
255    /// ```
256    #[must_use]
257    pub fn with_n_ctx(mut self, n_ctx: Option<NonZeroU32>) -> Self {
258        self.context_params.n_ctx = n_ctx.map_or(0, std::num::NonZeroU32::get);
259        self
260    }
261
262    /// Get the size of the context.
263    ///
264    /// [`None`] if the context size is specified by the model and not the context.
265    ///
266    /// # Examples
267    ///
268    /// ```rust
269    /// let params = llama_cpp_2::context::params::LlamaContextParams::default();
270    /// assert_eq!(params.n_ctx(), std::num::NonZeroU32::new(512));
271    #[must_use]
272    pub fn n_ctx(&self) -> Option<NonZeroU32> {
273        NonZeroU32::new(self.context_params.n_ctx)
274    }
275
276    /// Set the `n_batch`
277    ///
278    /// # Examples
279    ///
280    /// ```rust
281    /// # use std::num::NonZeroU32;
282    /// use llama_cpp_2::context::params::LlamaContextParams;
283    /// let params = LlamaContextParams::default()
284    ///     .with_n_batch(2048);
285    /// assert_eq!(params.n_batch(), 2048);
286    /// ```
287    #[must_use]
288    pub fn with_n_batch(mut self, n_batch: u32) -> Self {
289        self.context_params.n_batch = n_batch;
290        self
291    }
292
293    /// Get the `n_batch`
294    ///
295    /// # Examples
296    ///
297    /// ```rust
298    /// use llama_cpp_2::context::params::LlamaContextParams;
299    /// let params = LlamaContextParams::default();
300    /// assert_eq!(params.n_batch(), 2048);
301    /// ```
302    #[must_use]
303    pub fn n_batch(&self) -> u32 {
304        self.context_params.n_batch
305    }
306
307    /// Set the `n_ubatch`
308    ///
309    /// # Examples
310    ///
311    /// ```rust
312    /// # use std::num::NonZeroU32;
313    /// use llama_cpp_2::context::params::LlamaContextParams;
314    /// let params = LlamaContextParams::default()
315    ///     .with_n_ubatch(512);
316    /// assert_eq!(params.n_ubatch(), 512);
317    /// ```
318    #[must_use]
319    pub fn with_n_ubatch(mut self, n_ubatch: u32) -> Self {
320        self.context_params.n_ubatch = n_ubatch;
321        self
322    }
323
324    /// Get the `n_ubatch`
325    ///
326    /// # Examples
327    ///
328    /// ```rust
329    /// use llama_cpp_2::context::params::LlamaContextParams;
330    /// let params = LlamaContextParams::default();
331    /// assert_eq!(params.n_ubatch(), 512);
332    /// ```
333    #[must_use]
334    pub fn n_ubatch(&self) -> u32 {
335        self.context_params.n_ubatch
336    }
337
338    /// Set the `flash_attention` parameter
339    ///
340    /// # Examples
341    ///
342    /// ```rust
343    /// use llama_cpp_2::context::params::LlamaContextParams;
344    /// let params = LlamaContextParams::default()
345    ///     .with_flash_attention(true);
346    /// assert_eq!(params.flash_attention(), true);
347    /// ```
348    #[must_use]
349    pub fn with_flash_attention(mut self, enabled: bool) -> Self {
350        self.context_params.flash_attn = enabled;
351        self
352    }
353
354    /// Get the `flash_attention` parameter
355    ///
356    /// # Examples
357    ///
358    /// ```rust
359    /// use llama_cpp_2::context::params::LlamaContextParams;
360    /// let params = LlamaContextParams::default();
361    /// assert_eq!(params.flash_attention(), false);
362    /// ```
363    #[must_use]
364    pub fn flash_attention(&self) -> bool {
365        self.context_params.flash_attn
366    }
367
368    /// Set the `offload_kqv` parameter to control offloading KV cache & KQV ops to GPU
369    ///
370    /// # Examples
371    ///
372    /// ```rust
373    /// use llama_cpp_2::context::params::LlamaContextParams;
374    /// let params = LlamaContextParams::default()
375    ///     .with_offload_kqv(false);
376    /// assert_eq!(params.offload_kqv(), false);
377    /// ```
378    #[must_use]
379    pub fn with_offload_kqv(mut self, enabled: bool) -> Self {
380        self.context_params.offload_kqv = enabled;
381        self
382    }
383
384    /// Get the `offload_kqv` parameter
385    ///
386    /// # Examples
387    ///
388    /// ```rust
389    /// use llama_cpp_2::context::params::LlamaContextParams;
390    /// let params = LlamaContextParams::default();
391    /// assert_eq!(params.offload_kqv(), true);
392    /// ```
393    #[must_use]
394    pub fn offload_kqv(&self) -> bool {
395        self.context_params.offload_kqv
396    }
397
398    /// Set the type of rope scaling.
399    ///
400    /// # Examples
401    ///
402    /// ```rust
403    /// use llama_cpp_2::context::params::{LlamaContextParams, RopeScalingType};
404    /// let params = LlamaContextParams::default()
405    ///     .with_rope_scaling_type(RopeScalingType::Linear);
406    /// assert_eq!(params.rope_scaling_type(), RopeScalingType::Linear);
407    /// ```
408    #[must_use]
409    pub fn with_rope_scaling_type(mut self, rope_scaling_type: RopeScalingType) -> Self {
410        self.context_params.rope_scaling_type = i32::from(rope_scaling_type);
411        self
412    }
413
414    /// Get the type of rope scaling.
415    ///
416    /// # Examples
417    ///
418    /// ```rust
419    /// let params = llama_cpp_2::context::params::LlamaContextParams::default();
420    /// assert_eq!(params.rope_scaling_type(), llama_cpp_2::context::params::RopeScalingType::Unspecified);
421    /// ```
422    #[must_use]
423    pub fn rope_scaling_type(&self) -> RopeScalingType {
424        RopeScalingType::from(self.context_params.rope_scaling_type)
425    }
426
427    /// Set the rope frequency base.
428    ///
429    /// # Examples
430    ///
431    /// ```rust
432    /// use llama_cpp_2::context::params::LlamaContextParams;
433    /// let params = LlamaContextParams::default()
434    ///    .with_rope_freq_base(0.5);
435    /// assert_eq!(params.rope_freq_base(), 0.5);
436    /// ```
437    #[must_use]
438    pub fn with_rope_freq_base(mut self, rope_freq_base: f32) -> Self {
439        self.context_params.rope_freq_base = rope_freq_base;
440        self
441    }
442
443    /// Get the rope frequency base.
444    ///
445    /// # Examples
446    ///
447    /// ```rust
448    /// let params = llama_cpp_2::context::params::LlamaContextParams::default();
449    /// assert_eq!(params.rope_freq_base(), 0.0);
450    /// ```
451    #[must_use]
452    pub fn rope_freq_base(&self) -> f32 {
453        self.context_params.rope_freq_base
454    }
455
456    /// Set the rope frequency scale.
457    ///
458    /// # Examples
459    ///
460    /// ```rust
461    /// use llama_cpp_2::context::params::LlamaContextParams;
462    /// let params = LlamaContextParams::default()
463    ///   .with_rope_freq_scale(0.5);
464    /// assert_eq!(params.rope_freq_scale(), 0.5);
465    /// ```
466    #[must_use]
467    pub fn with_rope_freq_scale(mut self, rope_freq_scale: f32) -> Self {
468        self.context_params.rope_freq_scale = rope_freq_scale;
469        self
470    }
471
472    /// Get the rope frequency scale.
473    ///
474    /// # Examples
475    ///
476    /// ```rust
477    /// let params = llama_cpp_2::context::params::LlamaContextParams::default();
478    /// assert_eq!(params.rope_freq_scale(), 0.0);
479    /// ```
480    #[must_use]
481    pub fn rope_freq_scale(&self) -> f32 {
482        self.context_params.rope_freq_scale
483    }
484
485    /// Get the number of threads.
486    ///
487    /// # Examples
488    ///
489    /// ```rust
490    /// let params = llama_cpp_2::context::params::LlamaContextParams::default();
491    /// assert_eq!(params.n_threads(), 4);
492    /// ```
493    #[must_use]
494    pub fn n_threads(&self) -> i32 {
495        self.context_params.n_threads
496    }
497
498    /// Get the number of threads allocated for batches.
499    ///
500    /// # Examples
501    ///
502    /// ```rust
503    /// let params = llama_cpp_2::context::params::LlamaContextParams::default();
504    /// assert_eq!(params.n_threads_batch(), 4);
505    /// ```
506    #[must_use]
507    pub fn n_threads_batch(&self) -> i32 {
508        self.context_params.n_threads_batch
509    }
510
511    /// Set the number of threads.
512    ///
513    /// # Examples
514    ///
515    /// ```rust
516    /// use llama_cpp_2::context::params::LlamaContextParams;
517    /// let params = LlamaContextParams::default()
518    ///    .with_n_threads(8);
519    /// assert_eq!(params.n_threads(), 8);
520    /// ```
521    #[must_use]
522    pub fn with_n_threads(mut self, n_threads: i32) -> Self {
523        self.context_params.n_threads = n_threads;
524        self
525    }
526
527    /// Set the number of threads allocated for batches.
528    ///
529    /// # Examples
530    ///
531    /// ```rust
532    /// use llama_cpp_2::context::params::LlamaContextParams;
533    /// let params = LlamaContextParams::default()
534    ///    .with_n_threads_batch(8);
535    /// assert_eq!(params.n_threads_batch(), 8);
536    /// ```
537    #[must_use]
538    pub fn with_n_threads_batch(mut self, n_threads: i32) -> Self {
539        self.context_params.n_threads_batch = n_threads;
540        self
541    }
542
543    /// Check whether embeddings are enabled
544    ///
545    /// # Examples
546    ///
547    /// ```rust
548    /// let params = llama_cpp_2::context::params::LlamaContextParams::default();
549    /// assert!(!params.embeddings());
550    /// ```
551    #[must_use]
552    pub fn embeddings(&self) -> bool {
553        self.context_params.embeddings
554    }
555
556    /// Enable the use of embeddings
557    ///
558    /// # Examples
559    ///
560    /// ```rust
561    /// use llama_cpp_2::context::params::LlamaContextParams;
562    /// let params = LlamaContextParams::default()
563    ///    .with_embeddings(true);
564    /// assert!(params.embeddings());
565    /// ```
566    #[must_use]
567    pub fn with_embeddings(mut self, embedding: bool) -> Self {
568        self.context_params.embeddings = embedding;
569        self
570    }
571
572    /// Set the evaluation callback.
573    ///
574    /// # Examples
575    ///
576    /// ```no_run
577    /// extern "C" fn cb_eval_fn(
578    ///     t: *mut llama_cpp_sys_2::ggml_tensor,
579    ///     ask: bool,
580    ///     user_data: *mut std::ffi::c_void,
581    /// ) -> bool {
582    ///     false
583    /// }
584    ///
585    /// use llama_cpp_2::context::params::LlamaContextParams;
586    /// let params = LlamaContextParams::default().with_cb_eval(Some(cb_eval_fn));
587    /// ```
588    #[must_use]
589    pub fn with_cb_eval(
590        mut self,
591        cb_eval: llama_cpp_sys_2::ggml_backend_sched_eval_callback,
592    ) -> Self {
593        self.context_params.cb_eval = cb_eval;
594        self
595    }
596
597    /// Set the evaluation callback user data.
598    ///
599    /// # Examples
600    ///
601    /// ```no_run
602    /// use llama_cpp_2::context::params::LlamaContextParams;
603    /// let params = LlamaContextParams::default();
604    /// let user_data = std::ptr::null_mut();
605    /// let params = params.with_cb_eval_user_data(user_data);
606    /// ```
607    #[must_use]
608    pub fn with_cb_eval_user_data(mut self, cb_eval_user_data: *mut std::ffi::c_void) -> Self {
609        self.context_params.cb_eval_user_data = cb_eval_user_data;
610        self
611    }
612
613    /// Set the type of pooling.
614    ///
615    /// # Examples
616    ///
617    /// ```rust
618    /// use llama_cpp_2::context::params::{LlamaContextParams, LlamaPoolingType};
619    /// let params = LlamaContextParams::default()
620    ///     .with_pooling_type(LlamaPoolingType::Last);
621    /// assert_eq!(params.pooling_type(), LlamaPoolingType::Last);
622    /// ```
623    #[must_use]
624    pub fn with_pooling_type(mut self, pooling_type: LlamaPoolingType) -> Self {
625        self.context_params.pooling_type = i32::from(pooling_type);
626        self
627    }
628
629    /// Get the type of pooling.
630    ///
631    /// # Examples
632    ///
633    /// ```rust
634    /// let params = llama_cpp_2::context::params::LlamaContextParams::default();
635    /// assert_eq!(params.pooling_type(), llama_cpp_2::context::params::LlamaPoolingType::Unspecified);
636    /// ```
637    #[must_use]
638    pub fn pooling_type(&self) -> LlamaPoolingType {
639        LlamaPoolingType::from(self.context_params.pooling_type)
640    }
641
642    /// Set whether to use full sliding window attention
643    ///
644    /// # Examples
645    ///
646    /// ```rust
647    /// use llama_cpp_2::context::params::LlamaContextParams;
648    /// let params = LlamaContextParams::default()
649    ///     .with_swa_full(false);
650    /// assert_eq!(params.swa_full(), false);
651    /// ```
652    #[must_use]
653    pub fn with_swa_full(mut self, enabled: bool) -> Self {
654        self.context_params.swa_full = enabled;
655        self
656    }
657
658    /// Get whether full sliding window attention is enabled
659    ///
660    /// # Examples
661    ///
662    /// ```rust
663    /// use llama_cpp_2::context::params::LlamaContextParams;
664    /// let params = LlamaContextParams::default();
665    /// assert_eq!(params.swa_full(), true);
666    /// ```
667    #[must_use]
668    pub fn swa_full(&self) -> bool {
669        self.context_params.swa_full
670    }
671
672    /// Set the max number of sequences (i.e. distinct states for recurrent models)
673    ///
674    /// # Examples
675    ///
676    /// ```rust
677    /// use llama_cpp_2::context::params::LlamaContextParams;
678    /// let params = LlamaContextParams::default()
679    ///     .with_n_seq_max(64);
680    /// assert_eq!(params.n_seq_max(), 64);
681    /// ```
682    #[must_use]
683    pub fn with_n_seq_max(mut self, n_seq_max: u32) -> Self {
684        self.context_params.n_seq_max = n_seq_max;
685        self
686    }
687
688    /// Get the max number of sequences (i.e. distinct states for recurrent models)
689    ///
690    /// # Examples
691    ///
692    /// ```rust
693    /// use llama_cpp_2::context::params::LlamaContextParams;
694    /// let params = LlamaContextParams::default();
695    /// assert_eq!(params.n_seq_max(), 1);
696    /// ```
697    #[must_use]
698    pub fn n_seq_max(&self) -> u32 {
699        self.context_params.n_seq_max
700    }
701    /// Set the KV cache data type for K
702    /// use llama_cpp_2::context::params::{LlamaContextParams, KvCacheType};
703    /// let params = LlamaContextParams::default().with_type_k(KvCacheType::Q4_0);
704    /// assert_eq!(params.type_k(), KvCacheType::Q4_0);
705    /// ```
706    #[must_use]
707    pub fn with_type_k(mut self, type_k: KvCacheType) -> Self {
708        self.context_params.type_k = type_k.into();
709        self
710    }
711
712    /// Get the KV cache data type for K
713    ///
714    /// # Examples
715    ///
716    /// ```rust
717    /// let params = llama_cpp_2::context::params::LlamaContextParams::default();
718    /// let _ = params.type_k();
719    /// ```
720    #[must_use]
721    pub fn type_k(&self) -> KvCacheType {
722        KvCacheType::from(self.context_params.type_k)
723    }
724
725    /// Set the KV cache data type for V
726    ///
727    /// # Examples
728    ///
729    /// ```rust
730    /// use llama_cpp_2::context::params::{LlamaContextParams, KvCacheType};
731    /// let params = LlamaContextParams::default().with_type_v(KvCacheType::Q4_1);
732    /// assert_eq!(params.type_v(), KvCacheType::Q4_1);
733    /// ```
734    #[must_use]
735    pub fn with_type_v(mut self, type_v: KvCacheType) -> Self {
736        self.context_params.type_v = type_v.into();
737        self
738    }
739
740    /// Get the KV cache data type for V
741    ///
742    /// # Examples
743    ///
744    /// ```rust
745    /// let params = llama_cpp_2::context::params::LlamaContextParams::default();
746    /// let _ = params.type_v();
747    /// ```
748    #[must_use]
749    pub fn type_v(&self) -> KvCacheType {
750        KvCacheType::from(self.context_params.type_v)
751    }
752}
753
754/// Default parameters for `LlamaContext`. (as defined in llama.cpp by `llama_context_default_params`)
755/// ```
756/// # use std::num::NonZeroU32;
757/// use llama_cpp_2::context::params::{LlamaContextParams, RopeScalingType};
758/// let params = LlamaContextParams::default();
759/// assert_eq!(params.n_ctx(), NonZeroU32::new(512), "n_ctx should be 512");
760/// assert_eq!(params.rope_scaling_type(), RopeScalingType::Unspecified);
761/// ```
762impl Default for LlamaContextParams {
763    fn default() -> Self {
764        let context_params = unsafe { llama_cpp_sys_2::llama_context_default_params() };
765        Self { context_params }
766    }
767}