llama_cpp_2/context/
params.rs

1//! A safe wrapper around `llama_context_params`.
2use std::fmt::Debug;
3use std::num::NonZeroU32;
4
/// A rusty wrapper around `rope_scaling_type`.
///
/// Each discriminant equals the raw `i32` value produced and accepted by the
/// `From` conversions below.
#[repr(i8)]
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum RopeScalingType {
    /// The scaling type is unspecified
    Unspecified = -1,
    /// No scaling
    None = 0,
    /// Linear scaling
    Linear = 1,
    /// Yarn scaling
    Yarn = 2,
}

/// Create a `RopeScalingType` from a `c_int`.
///
/// Returns [`RopeScalingType::Unspecified`] if the value is not recognized.
/// Any value without a matching variant is collapsed to `Unspecified`, so this
/// conversion does not round-trip for such values. NOTE(review): newer
/// llama.cpp versions may define additional scaling types that end up here.
impl From<i32> for RopeScalingType {
    fn from(value: i32) -> Self {
        match value {
            0 => Self::None,
            1 => Self::Linear,
            2 => Self::Yarn,
            _ => Self::Unspecified,
        }
    }
}

/// Create a `c_int` from a `RopeScalingType`.
///
/// `Unspecified` maps to `-1`; the other variants map to their discriminants.
impl From<RopeScalingType> for i32 {
    fn from(value: RopeScalingType) -> Self {
        match value {
            RopeScalingType::None => 0,
            RopeScalingType::Linear => 1,
            RopeScalingType::Yarn => 2,
            RopeScalingType::Unspecified => -1,
        }
    }
}
43
/// A rusty wrapper around `LLAMA_POOLING_TYPE`.
///
/// The explicit discriminants are exactly the raw `i32` values exchanged
/// through the `From` conversions below.
#[repr(i8)]
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum LlamaPoolingType {
    /// Pooling type not specified
    Unspecified = -1,
    /// No pooling
    None = 0,
    /// Mean pooling
    Mean = 1,
    /// CLS pooling
    Cls = 2,
    /// Last pooling
    Last = 3,
    /// Rank pooling
    Rank = 4,
}

/// Create a `LlamaPoolingType` from a `c_int`.
///
/// Any value that no concrete variant maps to yields `LlamaPoolingType::Unspecified`.
impl From<i32> for LlamaPoolingType {
    fn from(value: i32) -> Self {
        // Search the concrete variants for the one whose raw value matches.
        // `Unspecified` is intentionally absent: it is the fallback.
        [Self::None, Self::Mean, Self::Cls, Self::Last, Self::Rank]
            .iter()
            .copied()
            .find(|candidate| i32::from(*candidate) == value)
            .unwrap_or(Self::Unspecified)
    }
}

/// Create a `c_int` from a `LlamaPoolingType`.
impl From<LlamaPoolingType> for i32 {
    fn from(value: LlamaPoolingType) -> Self {
        // The enum is `#[repr(i8)]` with explicit discriminants, so the cast is
        // exact. Widening through `i32::from` keeps the sign (`Unspecified` -> -1).
        i32::from(value as i8)
    }
}
90
91/// A rusty wrapper around `ggml_type` for KV cache types.
92#[allow(non_camel_case_types, missing_docs)]
93#[derive(Copy, Clone, Debug, PartialEq, Eq)]
94pub enum KvCacheType {
95    /// Represents an unknown or not-yet-mapped `ggml_type` and carries the raw value.
96    /// When passed through FFI, the raw value is used as-is (if llama.cpp supports it,
97    /// the runtime will operate with that type).
98    /// This variant preserves API compatibility when new `ggml_type` values are
99    /// introduced in the future.
100    Unknown(llama_cpp_sys_2::ggml_type),
101    F32,
102    F16,
103    Q4_0,
104    Q4_1,
105    Q5_0,
106    Q5_1,
107    Q8_0,
108    Q8_1,
109    Q2_K,
110    Q3_K,
111    Q4_K,
112    Q5_K,
113    Q6_K,
114    Q8_K,
115    IQ2_XXS,
116    IQ2_XS,
117    IQ3_XXS,
118    IQ1_S,
119    IQ4_NL,
120    IQ3_S,
121    IQ2_S,
122    IQ4_XS,
123    I8,
124    I16,
125    I32,
126    I64,
127    F64,
128    IQ1_M,
129    BF16,
130    TQ1_0,
131    TQ2_0,
132    MXFP4,
133}
134
135impl From<KvCacheType> for llama_cpp_sys_2::ggml_type {
136    fn from(value: KvCacheType) -> Self {
137        match value {
138            KvCacheType::Unknown(raw) => raw,
139            KvCacheType::F32 => llama_cpp_sys_2::GGML_TYPE_F32,
140            KvCacheType::F16 => llama_cpp_sys_2::GGML_TYPE_F16,
141            KvCacheType::Q4_0 => llama_cpp_sys_2::GGML_TYPE_Q4_0,
142            KvCacheType::Q4_1 => llama_cpp_sys_2::GGML_TYPE_Q4_1,
143            KvCacheType::Q5_0 => llama_cpp_sys_2::GGML_TYPE_Q5_0,
144            KvCacheType::Q5_1 => llama_cpp_sys_2::GGML_TYPE_Q5_1,
145            KvCacheType::Q8_0 => llama_cpp_sys_2::GGML_TYPE_Q8_0,
146            KvCacheType::Q8_1 => llama_cpp_sys_2::GGML_TYPE_Q8_1,
147            KvCacheType::Q2_K => llama_cpp_sys_2::GGML_TYPE_Q2_K,
148            KvCacheType::Q3_K => llama_cpp_sys_2::GGML_TYPE_Q3_K,
149            KvCacheType::Q4_K => llama_cpp_sys_2::GGML_TYPE_Q4_K,
150            KvCacheType::Q5_K => llama_cpp_sys_2::GGML_TYPE_Q5_K,
151            KvCacheType::Q6_K => llama_cpp_sys_2::GGML_TYPE_Q6_K,
152            KvCacheType::Q8_K => llama_cpp_sys_2::GGML_TYPE_Q8_K,
153            KvCacheType::IQ2_XXS => llama_cpp_sys_2::GGML_TYPE_IQ2_XXS,
154            KvCacheType::IQ2_XS => llama_cpp_sys_2::GGML_TYPE_IQ2_XS,
155            KvCacheType::IQ3_XXS => llama_cpp_sys_2::GGML_TYPE_IQ3_XXS,
156            KvCacheType::IQ1_S => llama_cpp_sys_2::GGML_TYPE_IQ1_S,
157            KvCacheType::IQ4_NL => llama_cpp_sys_2::GGML_TYPE_IQ4_NL,
158            KvCacheType::IQ3_S => llama_cpp_sys_2::GGML_TYPE_IQ3_S,
159            KvCacheType::IQ2_S => llama_cpp_sys_2::GGML_TYPE_IQ2_S,
160            KvCacheType::IQ4_XS => llama_cpp_sys_2::GGML_TYPE_IQ4_XS,
161            KvCacheType::I8 => llama_cpp_sys_2::GGML_TYPE_I8,
162            KvCacheType::I16 => llama_cpp_sys_2::GGML_TYPE_I16,
163            KvCacheType::I32 => llama_cpp_sys_2::GGML_TYPE_I32,
164            KvCacheType::I64 => llama_cpp_sys_2::GGML_TYPE_I64,
165            KvCacheType::F64 => llama_cpp_sys_2::GGML_TYPE_F64,
166            KvCacheType::IQ1_M => llama_cpp_sys_2::GGML_TYPE_IQ1_M,
167            KvCacheType::BF16 => llama_cpp_sys_2::GGML_TYPE_BF16,
168            KvCacheType::TQ1_0 => llama_cpp_sys_2::GGML_TYPE_TQ1_0,
169            KvCacheType::TQ2_0 => llama_cpp_sys_2::GGML_TYPE_TQ2_0,
170            KvCacheType::MXFP4 => llama_cpp_sys_2::GGML_TYPE_MXFP4,
171        }
172    }
173}
174
175impl From<llama_cpp_sys_2::ggml_type> for KvCacheType {
176    fn from(value: llama_cpp_sys_2::ggml_type) -> Self {
177        match value {
178            x if x == llama_cpp_sys_2::GGML_TYPE_F32 => KvCacheType::F32,
179            x if x == llama_cpp_sys_2::GGML_TYPE_F16 => KvCacheType::F16,
180            x if x == llama_cpp_sys_2::GGML_TYPE_Q4_0 => KvCacheType::Q4_0,
181            x if x == llama_cpp_sys_2::GGML_TYPE_Q4_1 => KvCacheType::Q4_1,
182            x if x == llama_cpp_sys_2::GGML_TYPE_Q5_0 => KvCacheType::Q5_0,
183            x if x == llama_cpp_sys_2::GGML_TYPE_Q5_1 => KvCacheType::Q5_1,
184            x if x == llama_cpp_sys_2::GGML_TYPE_Q8_0 => KvCacheType::Q8_0,
185            x if x == llama_cpp_sys_2::GGML_TYPE_Q8_1 => KvCacheType::Q8_1,
186            x if x == llama_cpp_sys_2::GGML_TYPE_Q2_K => KvCacheType::Q2_K,
187            x if x == llama_cpp_sys_2::GGML_TYPE_Q3_K => KvCacheType::Q3_K,
188            x if x == llama_cpp_sys_2::GGML_TYPE_Q4_K => KvCacheType::Q4_K,
189            x if x == llama_cpp_sys_2::GGML_TYPE_Q5_K => KvCacheType::Q5_K,
190            x if x == llama_cpp_sys_2::GGML_TYPE_Q6_K => KvCacheType::Q6_K,
191            x if x == llama_cpp_sys_2::GGML_TYPE_Q8_K => KvCacheType::Q8_K,
192            x if x == llama_cpp_sys_2::GGML_TYPE_IQ2_XXS => KvCacheType::IQ2_XXS,
193            x if x == llama_cpp_sys_2::GGML_TYPE_IQ2_XS => KvCacheType::IQ2_XS,
194            x if x == llama_cpp_sys_2::GGML_TYPE_IQ3_XXS => KvCacheType::IQ3_XXS,
195            x if x == llama_cpp_sys_2::GGML_TYPE_IQ1_S => KvCacheType::IQ1_S,
196            x if x == llama_cpp_sys_2::GGML_TYPE_IQ4_NL => KvCacheType::IQ4_NL,
197            x if x == llama_cpp_sys_2::GGML_TYPE_IQ3_S => KvCacheType::IQ3_S,
198            x if x == llama_cpp_sys_2::GGML_TYPE_IQ2_S => KvCacheType::IQ2_S,
199            x if x == llama_cpp_sys_2::GGML_TYPE_IQ4_XS => KvCacheType::IQ4_XS,
200            x if x == llama_cpp_sys_2::GGML_TYPE_I8 => KvCacheType::I8,
201            x if x == llama_cpp_sys_2::GGML_TYPE_I16 => KvCacheType::I16,
202            x if x == llama_cpp_sys_2::GGML_TYPE_I32 => KvCacheType::I32,
203            x if x == llama_cpp_sys_2::GGML_TYPE_I64 => KvCacheType::I64,
204            x if x == llama_cpp_sys_2::GGML_TYPE_F64 => KvCacheType::F64,
205            x if x == llama_cpp_sys_2::GGML_TYPE_IQ1_M => KvCacheType::IQ1_M,
206            x if x == llama_cpp_sys_2::GGML_TYPE_BF16 => KvCacheType::BF16,
207            x if x == llama_cpp_sys_2::GGML_TYPE_TQ1_0 => KvCacheType::TQ1_0,
208            x if x == llama_cpp_sys_2::GGML_TYPE_TQ2_0 => KvCacheType::TQ2_0,
209            x if x == llama_cpp_sys_2::GGML_TYPE_MXFP4 => KvCacheType::MXFP4,
210            _ => KvCacheType::Unknown(value),
211        }
212    }
213}
214
215/// A safe wrapper around `llama_context_params`.
216///
217/// Generally this should be created with [`Default::default()`] and then modified with `with_*` methods.
218///
219/// # Examples
220///
221/// ```rust
222/// # use std::num::NonZeroU32;
223/// use llama_cpp_2::context::params::LlamaContextParams;
224///
225///let ctx_params = LlamaContextParams::default()
226///    .with_n_ctx(NonZeroU32::new(2048));
227///
228/// assert_eq!(ctx_params.n_ctx(), NonZeroU32::new(2048));
229/// ```
230#[derive(Debug, Clone)]
231#[allow(
232    missing_docs,
233    clippy::struct_excessive_bools,
234    clippy::module_name_repetitions
235)]
236pub struct LlamaContextParams {
237    pub(crate) context_params: llama_cpp_sys_2::llama_context_params,
238}
239
240/// SAFETY: we do not currently allow setting or reading the pointers that cause this to not be automatically send or sync.
241unsafe impl Send for LlamaContextParams {}
242unsafe impl Sync for LlamaContextParams {}
243
244impl LlamaContextParams {
245    /// Set the side of the context
246    ///
247    /// # Examples
248    ///
249    /// ```rust
250    /// # use std::num::NonZeroU32;
251    /// use llama_cpp_2::context::params::LlamaContextParams;
252    /// let params = LlamaContextParams::default();
253    /// let params = params.with_n_ctx(NonZeroU32::new(2048));
254    /// assert_eq!(params.n_ctx(), NonZeroU32::new(2048));
255    /// ```
256    #[must_use]
257    pub fn with_n_ctx(mut self, n_ctx: Option<NonZeroU32>) -> Self {
258        self.context_params.n_ctx = n_ctx.map_or(0, std::num::NonZeroU32::get);
259        self
260    }
261
262    /// Get the size of the context.
263    ///
264    /// [`None`] if the context size is specified by the model and not the context.
265    ///
266    /// # Examples
267    ///
268    /// ```rust
269    /// let params = llama_cpp_2::context::params::LlamaContextParams::default();
270    /// assert_eq!(params.n_ctx(), std::num::NonZeroU32::new(512));
271    #[must_use]
272    pub fn n_ctx(&self) -> Option<NonZeroU32> {
273        NonZeroU32::new(self.context_params.n_ctx)
274    }
275
276    /// Set the `n_batch`
277    ///
278    /// # Examples
279    ///
280    /// ```rust
281    /// # use std::num::NonZeroU32;
282    /// use llama_cpp_2::context::params::LlamaContextParams;
283    /// let params = LlamaContextParams::default()
284    ///     .with_n_batch(2048);
285    /// assert_eq!(params.n_batch(), 2048);
286    /// ```
287    #[must_use]
288    pub fn with_n_batch(mut self, n_batch: u32) -> Self {
289        self.context_params.n_batch = n_batch;
290        self
291    }
292
293    /// Get the `n_batch`
294    ///
295    /// # Examples
296    ///
297    /// ```rust
298    /// use llama_cpp_2::context::params::LlamaContextParams;
299    /// let params = LlamaContextParams::default();
300    /// assert_eq!(params.n_batch(), 2048);
301    /// ```
302    #[must_use]
303    pub fn n_batch(&self) -> u32 {
304        self.context_params.n_batch
305    }
306
307    /// Set the `n_ubatch`
308    ///
309    /// # Examples
310    ///
311    /// ```rust
312    /// # use std::num::NonZeroU32;
313    /// use llama_cpp_2::context::params::LlamaContextParams;
314    /// let params = LlamaContextParams::default()
315    ///     .with_n_ubatch(512);
316    /// assert_eq!(params.n_ubatch(), 512);
317    /// ```
318    #[must_use]
319    pub fn with_n_ubatch(mut self, n_ubatch: u32) -> Self {
320        self.context_params.n_ubatch = n_ubatch;
321        self
322    }
323
324    /// Get the `n_ubatch`
325    ///
326    /// # Examples
327    ///
328    /// ```rust
329    /// use llama_cpp_2::context::params::LlamaContextParams;
330    /// let params = LlamaContextParams::default();
331    /// assert_eq!(params.n_ubatch(), 512);
332    /// ```
333    #[must_use]
334    pub fn n_ubatch(&self) -> u32 {
335        self.context_params.n_ubatch
336    }
337
338    /// Set the flash attention policy using llama.cpp enum
339    #[must_use]
340    pub fn with_flash_attention_policy(
341        mut self,
342        policy: llama_cpp_sys_2::llama_flash_attn_type,
343    ) -> Self {
344        self.context_params.flash_attn_type = policy;
345        self
346    }
347
348    /// Get the flash attention policy
349    #[must_use]
350    pub fn flash_attention_policy(&self) -> llama_cpp_sys_2::llama_flash_attn_type {
351        self.context_params.flash_attn_type
352    }
353
354    /// Set the `offload_kqv` parameter to control offloading KV cache & KQV ops to GPU
355    ///
356    /// # Examples
357    ///
358    /// ```rust
359    /// use llama_cpp_2::context::params::LlamaContextParams;
360    /// let params = LlamaContextParams::default()
361    ///     .with_offload_kqv(false);
362    /// assert_eq!(params.offload_kqv(), false);
363    /// ```
364    #[must_use]
365    pub fn with_offload_kqv(mut self, enabled: bool) -> Self {
366        self.context_params.offload_kqv = enabled;
367        self
368    }
369
370    /// Get the `offload_kqv` parameter
371    ///
372    /// # Examples
373    ///
374    /// ```rust
375    /// use llama_cpp_2::context::params::LlamaContextParams;
376    /// let params = LlamaContextParams::default();
377    /// assert_eq!(params.offload_kqv(), true);
378    /// ```
379    #[must_use]
380    pub fn offload_kqv(&self) -> bool {
381        self.context_params.offload_kqv
382    }
383
384    /// Set the type of rope scaling.
385    ///
386    /// # Examples
387    ///
388    /// ```rust
389    /// use llama_cpp_2::context::params::{LlamaContextParams, RopeScalingType};
390    /// let params = LlamaContextParams::default()
391    ///     .with_rope_scaling_type(RopeScalingType::Linear);
392    /// assert_eq!(params.rope_scaling_type(), RopeScalingType::Linear);
393    /// ```
394    #[must_use]
395    pub fn with_rope_scaling_type(mut self, rope_scaling_type: RopeScalingType) -> Self {
396        self.context_params.rope_scaling_type = i32::from(rope_scaling_type);
397        self
398    }
399
400    /// Get the type of rope scaling.
401    ///
402    /// # Examples
403    ///
404    /// ```rust
405    /// let params = llama_cpp_2::context::params::LlamaContextParams::default();
406    /// assert_eq!(params.rope_scaling_type(), llama_cpp_2::context::params::RopeScalingType::Unspecified);
407    /// ```
408    #[must_use]
409    pub fn rope_scaling_type(&self) -> RopeScalingType {
410        RopeScalingType::from(self.context_params.rope_scaling_type)
411    }
412
413    /// Set the rope frequency base.
414    ///
415    /// # Examples
416    ///
417    /// ```rust
418    /// use llama_cpp_2::context::params::LlamaContextParams;
419    /// let params = LlamaContextParams::default()
420    ///    .with_rope_freq_base(0.5);
421    /// assert_eq!(params.rope_freq_base(), 0.5);
422    /// ```
423    #[must_use]
424    pub fn with_rope_freq_base(mut self, rope_freq_base: f32) -> Self {
425        self.context_params.rope_freq_base = rope_freq_base;
426        self
427    }
428
429    /// Get the rope frequency base.
430    ///
431    /// # Examples
432    ///
433    /// ```rust
434    /// let params = llama_cpp_2::context::params::LlamaContextParams::default();
435    /// assert_eq!(params.rope_freq_base(), 0.0);
436    /// ```
437    #[must_use]
438    pub fn rope_freq_base(&self) -> f32 {
439        self.context_params.rope_freq_base
440    }
441
442    /// Set the rope frequency scale.
443    ///
444    /// # Examples
445    ///
446    /// ```rust
447    /// use llama_cpp_2::context::params::LlamaContextParams;
448    /// let params = LlamaContextParams::default()
449    ///   .with_rope_freq_scale(0.5);
450    /// assert_eq!(params.rope_freq_scale(), 0.5);
451    /// ```
452    #[must_use]
453    pub fn with_rope_freq_scale(mut self, rope_freq_scale: f32) -> Self {
454        self.context_params.rope_freq_scale = rope_freq_scale;
455        self
456    }
457
458    /// Get the rope frequency scale.
459    ///
460    /// # Examples
461    ///
462    /// ```rust
463    /// let params = llama_cpp_2::context::params::LlamaContextParams::default();
464    /// assert_eq!(params.rope_freq_scale(), 0.0);
465    /// ```
466    #[must_use]
467    pub fn rope_freq_scale(&self) -> f32 {
468        self.context_params.rope_freq_scale
469    }
470
471    /// Get the number of threads.
472    ///
473    /// # Examples
474    ///
475    /// ```rust
476    /// let params = llama_cpp_2::context::params::LlamaContextParams::default();
477    /// assert_eq!(params.n_threads(), 4);
478    /// ```
479    #[must_use]
480    pub fn n_threads(&self) -> i32 {
481        self.context_params.n_threads
482    }
483
484    /// Get the number of threads allocated for batches.
485    ///
486    /// # Examples
487    ///
488    /// ```rust
489    /// let params = llama_cpp_2::context::params::LlamaContextParams::default();
490    /// assert_eq!(params.n_threads_batch(), 4);
491    /// ```
492    #[must_use]
493    pub fn n_threads_batch(&self) -> i32 {
494        self.context_params.n_threads_batch
495    }
496
497    /// Set the number of threads.
498    ///
499    /// # Examples
500    ///
501    /// ```rust
502    /// use llama_cpp_2::context::params::LlamaContextParams;
503    /// let params = LlamaContextParams::default()
504    ///    .with_n_threads(8);
505    /// assert_eq!(params.n_threads(), 8);
506    /// ```
507    #[must_use]
508    pub fn with_n_threads(mut self, n_threads: i32) -> Self {
509        self.context_params.n_threads = n_threads;
510        self
511    }
512
513    /// Set the number of threads allocated for batches.
514    ///
515    /// # Examples
516    ///
517    /// ```rust
518    /// use llama_cpp_2::context::params::LlamaContextParams;
519    /// let params = LlamaContextParams::default()
520    ///    .with_n_threads_batch(8);
521    /// assert_eq!(params.n_threads_batch(), 8);
522    /// ```
523    #[must_use]
524    pub fn with_n_threads_batch(mut self, n_threads: i32) -> Self {
525        self.context_params.n_threads_batch = n_threads;
526        self
527    }
528
529    /// Check whether embeddings are enabled
530    ///
531    /// # Examples
532    ///
533    /// ```rust
534    /// let params = llama_cpp_2::context::params::LlamaContextParams::default();
535    /// assert!(!params.embeddings());
536    /// ```
537    #[must_use]
538    pub fn embeddings(&self) -> bool {
539        self.context_params.embeddings
540    }
541
542    /// Enable the use of embeddings
543    ///
544    /// # Examples
545    ///
546    /// ```rust
547    /// use llama_cpp_2::context::params::LlamaContextParams;
548    /// let params = LlamaContextParams::default()
549    ///    .with_embeddings(true);
550    /// assert!(params.embeddings());
551    /// ```
552    #[must_use]
553    pub fn with_embeddings(mut self, embedding: bool) -> Self {
554        self.context_params.embeddings = embedding;
555        self
556    }
557
558    /// Set the evaluation callback.
559    ///
560    /// # Examples
561    ///
562    /// ```no_run
563    /// extern "C" fn cb_eval_fn(
564    ///     t: *mut llama_cpp_sys_2::ggml_tensor,
565    ///     ask: bool,
566    ///     user_data: *mut std::ffi::c_void,
567    /// ) -> bool {
568    ///     false
569    /// }
570    ///
571    /// use llama_cpp_2::context::params::LlamaContextParams;
572    /// let params = LlamaContextParams::default().with_cb_eval(Some(cb_eval_fn));
573    /// ```
574    #[must_use]
575    pub fn with_cb_eval(
576        mut self,
577        cb_eval: llama_cpp_sys_2::ggml_backend_sched_eval_callback,
578    ) -> Self {
579        self.context_params.cb_eval = cb_eval;
580        self
581    }
582
583    /// Set the evaluation callback user data.
584    ///
585    /// # Examples
586    ///
587    /// ```no_run
588    /// use llama_cpp_2::context::params::LlamaContextParams;
589    /// let params = LlamaContextParams::default();
590    /// let user_data = std::ptr::null_mut();
591    /// let params = params.with_cb_eval_user_data(user_data);
592    /// ```
593    #[must_use]
594    pub fn with_cb_eval_user_data(mut self, cb_eval_user_data: *mut std::ffi::c_void) -> Self {
595        self.context_params.cb_eval_user_data = cb_eval_user_data;
596        self
597    }
598
599    /// Set the type of pooling.
600    ///
601    /// # Examples
602    ///
603    /// ```rust
604    /// use llama_cpp_2::context::params::{LlamaContextParams, LlamaPoolingType};
605    /// let params = LlamaContextParams::default()
606    ///     .with_pooling_type(LlamaPoolingType::Last);
607    /// assert_eq!(params.pooling_type(), LlamaPoolingType::Last);
608    /// ```
609    #[must_use]
610    pub fn with_pooling_type(mut self, pooling_type: LlamaPoolingType) -> Self {
611        self.context_params.pooling_type = i32::from(pooling_type);
612        self
613    }
614
615    /// Get the type of pooling.
616    ///
617    /// # Examples
618    ///
619    /// ```rust
620    /// let params = llama_cpp_2::context::params::LlamaContextParams::default();
621    /// assert_eq!(params.pooling_type(), llama_cpp_2::context::params::LlamaPoolingType::Unspecified);
622    /// ```
623    #[must_use]
624    pub fn pooling_type(&self) -> LlamaPoolingType {
625        LlamaPoolingType::from(self.context_params.pooling_type)
626    }
627
628    /// Set whether to use full sliding window attention
629    ///
630    /// # Examples
631    ///
632    /// ```rust
633    /// use llama_cpp_2::context::params::LlamaContextParams;
634    /// let params = LlamaContextParams::default()
635    ///     .with_swa_full(false);
636    /// assert_eq!(params.swa_full(), false);
637    /// ```
638    #[must_use]
639    pub fn with_swa_full(mut self, enabled: bool) -> Self {
640        self.context_params.swa_full = enabled;
641        self
642    }
643
644    /// Get whether full sliding window attention is enabled
645    ///
646    /// # Examples
647    ///
648    /// ```rust
649    /// use llama_cpp_2::context::params::LlamaContextParams;
650    /// let params = LlamaContextParams::default();
651    /// assert_eq!(params.swa_full(), true);
652    /// ```
653    #[must_use]
654    pub fn swa_full(&self) -> bool {
655        self.context_params.swa_full
656    }
657
658    /// Set the max number of sequences (i.e. distinct states for recurrent models)
659    ///
660    /// # Examples
661    ///
662    /// ```rust
663    /// use llama_cpp_2::context::params::LlamaContextParams;
664    /// let params = LlamaContextParams::default()
665    ///     .with_n_seq_max(64);
666    /// assert_eq!(params.n_seq_max(), 64);
667    /// ```
668    #[must_use]
669    pub fn with_n_seq_max(mut self, n_seq_max: u32) -> Self {
670        self.context_params.n_seq_max = n_seq_max;
671        self
672    }
673
674    /// Get the max number of sequences (i.e. distinct states for recurrent models)
675    ///
676    /// # Examples
677    ///
678    /// ```rust
679    /// use llama_cpp_2::context::params::LlamaContextParams;
680    /// let params = LlamaContextParams::default();
681    /// assert_eq!(params.n_seq_max(), 1);
682    /// ```
683    #[must_use]
684    pub fn n_seq_max(&self) -> u32 {
685        self.context_params.n_seq_max
686    }
687    /// Set the KV cache data type for K
688    /// use llama_cpp_2::context::params::{LlamaContextParams, KvCacheType};
689    /// let params = LlamaContextParams::default().with_type_k(KvCacheType::Q4_0);
690    /// assert_eq!(params.type_k(), KvCacheType::Q4_0);
691    /// ```
692    #[must_use]
693    pub fn with_type_k(mut self, type_k: KvCacheType) -> Self {
694        self.context_params.type_k = type_k.into();
695        self
696    }
697
698    /// Get the KV cache data type for K
699    ///
700    /// # Examples
701    ///
702    /// ```rust
703    /// let params = llama_cpp_2::context::params::LlamaContextParams::default();
704    /// let _ = params.type_k();
705    /// ```
706    #[must_use]
707    pub fn type_k(&self) -> KvCacheType {
708        KvCacheType::from(self.context_params.type_k)
709    }
710
711    /// Set the KV cache data type for V
712    ///
713    /// # Examples
714    ///
715    /// ```rust
716    /// use llama_cpp_2::context::params::{LlamaContextParams, KvCacheType};
717    /// let params = LlamaContextParams::default().with_type_v(KvCacheType::Q4_1);
718    /// assert_eq!(params.type_v(), KvCacheType::Q4_1);
719    /// ```
720    #[must_use]
721    pub fn with_type_v(mut self, type_v: KvCacheType) -> Self {
722        self.context_params.type_v = type_v.into();
723        self
724    }
725
726    /// Get the KV cache data type for V
727    ///
728    /// # Examples
729    ///
730    /// ```rust
731    /// let params = llama_cpp_2::context::params::LlamaContextParams::default();
732    /// let _ = params.type_v();
733    /// ```
734    #[must_use]
735    pub fn type_v(&self) -> KvCacheType {
736        KvCacheType::from(self.context_params.type_v)
737    }
738}
739
740/// Default parameters for `LlamaContext`. (as defined in llama.cpp by `llama_context_default_params`)
741/// ```
742/// # use std::num::NonZeroU32;
743/// use llama_cpp_2::context::params::{LlamaContextParams, RopeScalingType};
744/// let params = LlamaContextParams::default();
745/// assert_eq!(params.n_ctx(), NonZeroU32::new(512), "n_ctx should be 512");
746/// assert_eq!(params.rope_scaling_type(), RopeScalingType::Unspecified);
747/// ```
748impl Default for LlamaContextParams {
749    fn default() -> Self {
750        let context_params = unsafe { llama_cpp_sys_2::llama_context_default_params() };
751        Self { context_params }
752    }
753}