// llama_cpp_bindings/context/params.rs
1//! A safe wrapper around `llama_context_params`.
2use std::fmt::Debug;
3use std::num::NonZeroU32;
4
/// A rusty wrapper around `rope_scaling_type`.
#[repr(i8)]
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum RopeScalingType {
    /// The scaling type is unspecified
    Unspecified = -1,
    /// No scaling
    None = 0,
    /// Linear scaling
    Linear = 1,
    /// Yarn scaling
    Yarn = 2,
}

/// Create a `RopeScalingType` from a `c_int` - returns `RopeScalingType::Unspecified` if
/// the value is not recognized.
impl From<i32> for RopeScalingType {
    fn from(value: i32) -> Self {
        match value {
            0 => Self::None,
            1 => Self::Linear,
            2 => Self::Yarn,
            // Unknown C values (including -1 itself) map to `Unspecified`.
            _ => Self::Unspecified,
        }
    }
}

/// Create a `c_int` from a `RopeScalingType`.
impl From<RopeScalingType> for i32 {
    fn from(value: RopeScalingType) -> Self {
        match value {
            RopeScalingType::None => 0,
            RopeScalingType::Linear => 1,
            RopeScalingType::Yarn => 2,
            RopeScalingType::Unspecified => -1,
        }
    }
}
43
/// A rusty wrapper around `LLAMA_POOLING_TYPE`.
#[repr(i8)]
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum LlamaPoolingType {
    /// The pooling type is unspecified
    Unspecified = -1,
    /// No pooling
    None = 0,
    /// Mean pooling
    Mean = 1,
    /// CLS pooling
    Cls = 2,
    /// Last pooling
    Last = 3,
    /// Rank pooling
    Rank = 4,
}

/// Create a `LlamaPoolingType` from a `c_int` - returns `LlamaPoolingType::Unspecified` if
/// the value is not recognized.
impl From<i32> for LlamaPoolingType {
    fn from(value: i32) -> Self {
        // Every value outside the known C range collapses to `Unspecified`.
        match value {
            4 => Self::Rank,
            3 => Self::Last,
            2 => Self::Cls,
            1 => Self::Mean,
            0 => Self::None,
            _ => Self::Unspecified,
        }
    }
}

/// Create a `c_int` from a `LlamaPoolingType`.
impl From<LlamaPoolingType> for i32 {
    fn from(value: LlamaPoolingType) -> Self {
        // The enum is `#[repr(i8)]` with explicit discriminants equal to the C
        // values, so a numeric cast is equivalent to an exhaustive match.
        value as i32
    }
}
90
/// A rusty wrapper around `LLAMA_ATTENTION_TYPE`.
#[repr(i8)]
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum LlamaAttentionType {
    /// The attention type is unspecified
    Unspecified = -1,
    /// Causal attention
    Causal = 0,
    /// Non-causal attention
    NonCausal = 1,
}

/// Create a `LlamaAttentionType` from a `c_int` - returns `LlamaAttentionType::Unspecified`
/// if the value is not recognized.
impl From<i32> for LlamaAttentionType {
    fn from(value: i32) -> Self {
        match value {
            0 => Self::Causal,
            1 => Self::NonCausal,
            // Unknown C values (including -1 itself) map to `Unspecified`.
            _ => Self::Unspecified,
        }
    }
}

/// Create a `c_int` from a `LlamaAttentionType`.
impl From<LlamaAttentionType> for i32 {
    fn from(value: LlamaAttentionType) -> Self {
        match value {
            LlamaAttentionType::Causal => 0,
            LlamaAttentionType::NonCausal => 1,
            LlamaAttentionType::Unspecified => -1,
        }
    }
}
122
/// A rusty wrapper around `ggml_type` for KV cache types.
#[allow(non_camel_case_types, missing_docs)]
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum KvCacheType {
    /// Represents an unknown or not-yet-mapped `ggml_type` and carries the raw value.
    /// When passed through FFI, the raw value is used as-is (if llama.cpp supports it,
    /// the runtime will operate with that type).
    /// This variant preserves API compatibility when new `ggml_type` values are
    /// introduced in the future.
    Unknown(llama_cpp_bindings_sys::ggml_type),
    // Float formats.
    F32,
    F16,
    // Block-quantized formats.
    Q4_0,
    Q4_1,
    Q5_0,
    Q5_1,
    Q8_0,
    Q8_1,
    // K-quant formats.
    Q2_K,
    Q3_K,
    Q4_K,
    Q5_K,
    Q6_K,
    Q8_K,
    // IQ (importance-matrix) quant formats.
    IQ2_XXS,
    IQ2_XS,
    IQ3_XXS,
    IQ1_S,
    IQ4_NL,
    IQ3_S,
    IQ2_S,
    IQ4_XS,
    // Plain integer formats.
    I8,
    I16,
    I32,
    I64,
    // Remaining float / ternary / microscaling formats.
    F64,
    IQ1_M,
    BF16,
    TQ1_0,
    TQ2_0,
    MXFP4,
}
166
/// Convert a `KvCacheType` into the raw `ggml_type` value expected by llama.cpp.
impl From<KvCacheType> for llama_cpp_bindings_sys::ggml_type {
    fn from(value: KvCacheType) -> Self {
        match value {
            // `Unknown` forwards its raw payload unchanged so unmapped ggml
            // types can still cross the FFI boundary.
            KvCacheType::Unknown(raw) => raw,
            KvCacheType::F32 => llama_cpp_bindings_sys::GGML_TYPE_F32,
            KvCacheType::F16 => llama_cpp_bindings_sys::GGML_TYPE_F16,
            KvCacheType::Q4_0 => llama_cpp_bindings_sys::GGML_TYPE_Q4_0,
            KvCacheType::Q4_1 => llama_cpp_bindings_sys::GGML_TYPE_Q4_1,
            KvCacheType::Q5_0 => llama_cpp_bindings_sys::GGML_TYPE_Q5_0,
            KvCacheType::Q5_1 => llama_cpp_bindings_sys::GGML_TYPE_Q5_1,
            KvCacheType::Q8_0 => llama_cpp_bindings_sys::GGML_TYPE_Q8_0,
            KvCacheType::Q8_1 => llama_cpp_bindings_sys::GGML_TYPE_Q8_1,
            KvCacheType::Q2_K => llama_cpp_bindings_sys::GGML_TYPE_Q2_K,
            KvCacheType::Q3_K => llama_cpp_bindings_sys::GGML_TYPE_Q3_K,
            KvCacheType::Q4_K => llama_cpp_bindings_sys::GGML_TYPE_Q4_K,
            KvCacheType::Q5_K => llama_cpp_bindings_sys::GGML_TYPE_Q5_K,
            KvCacheType::Q6_K => llama_cpp_bindings_sys::GGML_TYPE_Q6_K,
            KvCacheType::Q8_K => llama_cpp_bindings_sys::GGML_TYPE_Q8_K,
            KvCacheType::IQ2_XXS => llama_cpp_bindings_sys::GGML_TYPE_IQ2_XXS,
            KvCacheType::IQ2_XS => llama_cpp_bindings_sys::GGML_TYPE_IQ2_XS,
            KvCacheType::IQ3_XXS => llama_cpp_bindings_sys::GGML_TYPE_IQ3_XXS,
            KvCacheType::IQ1_S => llama_cpp_bindings_sys::GGML_TYPE_IQ1_S,
            KvCacheType::IQ4_NL => llama_cpp_bindings_sys::GGML_TYPE_IQ4_NL,
            KvCacheType::IQ3_S => llama_cpp_bindings_sys::GGML_TYPE_IQ3_S,
            KvCacheType::IQ2_S => llama_cpp_bindings_sys::GGML_TYPE_IQ2_S,
            KvCacheType::IQ4_XS => llama_cpp_bindings_sys::GGML_TYPE_IQ4_XS,
            KvCacheType::I8 => llama_cpp_bindings_sys::GGML_TYPE_I8,
            KvCacheType::I16 => llama_cpp_bindings_sys::GGML_TYPE_I16,
            KvCacheType::I32 => llama_cpp_bindings_sys::GGML_TYPE_I32,
            KvCacheType::I64 => llama_cpp_bindings_sys::GGML_TYPE_I64,
            KvCacheType::F64 => llama_cpp_bindings_sys::GGML_TYPE_F64,
            KvCacheType::IQ1_M => llama_cpp_bindings_sys::GGML_TYPE_IQ1_M,
            KvCacheType::BF16 => llama_cpp_bindings_sys::GGML_TYPE_BF16,
            KvCacheType::TQ1_0 => llama_cpp_bindings_sys::GGML_TYPE_TQ1_0,
            KvCacheType::TQ2_0 => llama_cpp_bindings_sys::GGML_TYPE_TQ2_0,
            KvCacheType::MXFP4 => llama_cpp_bindings_sys::GGML_TYPE_MXFP4,
        }
    }
}
206
/// Convert a raw `ggml_type` into a `KvCacheType`, falling back to
/// `KvCacheType::Unknown` for values with no dedicated variant.
impl From<llama_cpp_bindings_sys::ggml_type> for KvCacheType {
    fn from(value: llama_cpp_bindings_sys::ggml_type) -> Self {
        // NOTE(review): guard patterns (`x if x == CONST`) are used here,
        // presumably because the GGML_TYPE_* bindings are consts/associated
        // values rather than unit variants usable as match patterns — confirm
        // against the generated sys crate.
        match value {
            x if x == llama_cpp_bindings_sys::GGML_TYPE_F32 => Self::F32,
            x if x == llama_cpp_bindings_sys::GGML_TYPE_F16 => Self::F16,
            x if x == llama_cpp_bindings_sys::GGML_TYPE_Q4_0 => Self::Q4_0,
            x if x == llama_cpp_bindings_sys::GGML_TYPE_Q4_1 => Self::Q4_1,
            x if x == llama_cpp_bindings_sys::GGML_TYPE_Q5_0 => Self::Q5_0,
            x if x == llama_cpp_bindings_sys::GGML_TYPE_Q5_1 => Self::Q5_1,
            x if x == llama_cpp_bindings_sys::GGML_TYPE_Q8_0 => Self::Q8_0,
            x if x == llama_cpp_bindings_sys::GGML_TYPE_Q8_1 => Self::Q8_1,
            x if x == llama_cpp_bindings_sys::GGML_TYPE_Q2_K => Self::Q2_K,
            x if x == llama_cpp_bindings_sys::GGML_TYPE_Q3_K => Self::Q3_K,
            x if x == llama_cpp_bindings_sys::GGML_TYPE_Q4_K => Self::Q4_K,
            x if x == llama_cpp_bindings_sys::GGML_TYPE_Q5_K => Self::Q5_K,
            x if x == llama_cpp_bindings_sys::GGML_TYPE_Q6_K => Self::Q6_K,
            x if x == llama_cpp_bindings_sys::GGML_TYPE_Q8_K => Self::Q8_K,
            x if x == llama_cpp_bindings_sys::GGML_TYPE_IQ2_XXS => Self::IQ2_XXS,
            x if x == llama_cpp_bindings_sys::GGML_TYPE_IQ2_XS => Self::IQ2_XS,
            x if x == llama_cpp_bindings_sys::GGML_TYPE_IQ3_XXS => Self::IQ3_XXS,
            x if x == llama_cpp_bindings_sys::GGML_TYPE_IQ1_S => Self::IQ1_S,
            x if x == llama_cpp_bindings_sys::GGML_TYPE_IQ4_NL => Self::IQ4_NL,
            x if x == llama_cpp_bindings_sys::GGML_TYPE_IQ3_S => Self::IQ3_S,
            x if x == llama_cpp_bindings_sys::GGML_TYPE_IQ2_S => Self::IQ2_S,
            x if x == llama_cpp_bindings_sys::GGML_TYPE_IQ4_XS => Self::IQ4_XS,
            x if x == llama_cpp_bindings_sys::GGML_TYPE_I8 => Self::I8,
            x if x == llama_cpp_bindings_sys::GGML_TYPE_I16 => Self::I16,
            x if x == llama_cpp_bindings_sys::GGML_TYPE_I32 => Self::I32,
            x if x == llama_cpp_bindings_sys::GGML_TYPE_I64 => Self::I64,
            x if x == llama_cpp_bindings_sys::GGML_TYPE_F64 => Self::F64,
            x if x == llama_cpp_bindings_sys::GGML_TYPE_IQ1_M => Self::IQ1_M,
            x if x == llama_cpp_bindings_sys::GGML_TYPE_BF16 => Self::BF16,
            x if x == llama_cpp_bindings_sys::GGML_TYPE_TQ1_0 => Self::TQ1_0,
            x if x == llama_cpp_bindings_sys::GGML_TYPE_TQ2_0 => Self::TQ2_0,
            x if x == llama_cpp_bindings_sys::GGML_TYPE_MXFP4 => Self::MXFP4,
            // Preserve unrecognized values instead of failing, so the mapping
            // round-trips even for ggml types added after this enum was written.
            _ => Self::Unknown(value),
        }
    }
}
246
/// A safe wrapper around `llama_context_params`.
///
/// Generally this should be created with [`Default::default()`] and then modified with `with_*` methods.
///
/// # Examples
///
/// ```rust
/// # use std::num::NonZeroU32;
/// use llama_cpp_bindings::context::params::LlamaContextParams;
///
/// let ctx_params = LlamaContextParams::default()
///     .with_n_ctx(NonZeroU32::new(2048));
///
/// assert_eq!(ctx_params.n_ctx(), NonZeroU32::new(2048));
/// ```
#[derive(Debug, Clone)]
#[allow(
    missing_docs,
    clippy::struct_excessive_bools,
    clippy::module_name_repetitions
)]
pub struct LlamaContextParams {
    // The raw FFI parameter struct. Public, so callers can reach fields that
    // have no dedicated accessor yet.
    pub context_params: llama_cpp_bindings_sys::llama_context_params,
}
271
/// SAFETY: we do not currently allow setting or reading the pointers that cause this to not be automatically send or sync.
// NOTE(review): `with_cb_eval_user_data` does store a raw `*mut c_void` in this
// struct — confirm these blanket impls remain sound for callers that set a
// non-thread-safe user-data pointer.
unsafe impl Send for LlamaContextParams {}
unsafe impl Sync for LlamaContextParams {}
275
276impl LlamaContextParams {
    /// Set the size of the context
    ///
    /// # Examples
    ///
    /// ```rust
    /// # use std::num::NonZeroU32;
    /// use llama_cpp_bindings::context::params::LlamaContextParams;
    /// let params = LlamaContextParams::default();
    /// let params = params.with_n_ctx(NonZeroU32::new(2048));
    /// assert_eq!(params.n_ctx(), NonZeroU32::new(2048));
    /// ```
    #[must_use]
    pub fn with_n_ctx(mut self, n_ctx: Option<NonZeroU32>) -> Self {
        // `None` is encoded as the C-side sentinel 0 ("use the model's context size").
        self.context_params.n_ctx = n_ctx.map_or(0, NonZeroU32::get);
        self
    }

    /// Get the size of the context.
    ///
    /// [`None`] if the context size is specified by the model and not the context.
    ///
    /// # Examples
    ///
    /// ```rust
    /// let params = llama_cpp_bindings::context::params::LlamaContextParams::default();
    /// assert_eq!(params.n_ctx(), std::num::NonZeroU32::new(512));
    /// ```
    #[must_use]
    pub const fn n_ctx(&self) -> Option<NonZeroU32> {
        // `NonZeroU32::new` maps the 0 sentinel back to `None`.
        NonZeroU32::new(self.context_params.n_ctx)
    }
307
308    /// Set the `n_batch`
309    ///
310    /// # Examples
311    ///
312    /// ```rust
313    /// # use std::num::NonZeroU32;
314    /// use llama_cpp_bindings::context::params::LlamaContextParams;
315    /// let params = LlamaContextParams::default()
316    ///     .with_n_batch(2048);
317    /// assert_eq!(params.n_batch(), 2048);
318    /// ```
319    #[must_use]
320    pub const fn with_n_batch(mut self, n_batch: u32) -> Self {
321        self.context_params.n_batch = n_batch;
322        self
323    }
324
325    /// Get the `n_batch`
326    ///
327    /// # Examples
328    ///
329    /// ```rust
330    /// use llama_cpp_bindings::context::params::LlamaContextParams;
331    /// let params = LlamaContextParams::default();
332    /// assert_eq!(params.n_batch(), 2048);
333    /// ```
334    #[must_use]
335    pub const fn n_batch(&self) -> u32 {
336        self.context_params.n_batch
337    }
338
339    /// Set the `n_ubatch`
340    ///
341    /// # Examples
342    ///
343    /// ```rust
344    /// # use std::num::NonZeroU32;
345    /// use llama_cpp_bindings::context::params::LlamaContextParams;
346    /// let params = LlamaContextParams::default()
347    ///     .with_n_ubatch(512);
348    /// assert_eq!(params.n_ubatch(), 512);
349    /// ```
350    #[must_use]
351    pub const fn with_n_ubatch(mut self, n_ubatch: u32) -> Self {
352        self.context_params.n_ubatch = n_ubatch;
353        self
354    }
355
356    /// Get the `n_ubatch`
357    ///
358    /// # Examples
359    ///
360    /// ```rust
361    /// use llama_cpp_bindings::context::params::LlamaContextParams;
362    /// let params = LlamaContextParams::default();
363    /// assert_eq!(params.n_ubatch(), 512);
364    /// ```
365    #[must_use]
366    pub const fn n_ubatch(&self) -> u32 {
367        self.context_params.n_ubatch
368    }
369
370    /// Set the flash attention policy using llama.cpp enum
371    #[must_use]
372    pub const fn with_flash_attention_policy(
373        mut self,
374        policy: llama_cpp_bindings_sys::llama_flash_attn_type,
375    ) -> Self {
376        self.context_params.flash_attn_type = policy;
377        self
378    }
379
380    /// Get the flash attention policy
381    #[must_use]
382    pub const fn flash_attention_policy(&self) -> llama_cpp_bindings_sys::llama_flash_attn_type {
383        self.context_params.flash_attn_type
384    }
385
386    /// Set the `offload_kqv` parameter to control offloading KV cache & KQV ops to GPU
387    ///
388    /// # Examples
389    ///
390    /// ```rust
391    /// use llama_cpp_bindings::context::params::LlamaContextParams;
392    /// let params = LlamaContextParams::default()
393    ///     .with_offload_kqv(false);
394    /// assert_eq!(params.offload_kqv(), false);
395    /// ```
396    #[must_use]
397    pub const fn with_offload_kqv(mut self, enabled: bool) -> Self {
398        self.context_params.offload_kqv = enabled;
399        self
400    }
401
402    /// Get the `offload_kqv` parameter
403    ///
404    /// # Examples
405    ///
406    /// ```rust
407    /// use llama_cpp_bindings::context::params::LlamaContextParams;
408    /// let params = LlamaContextParams::default();
409    /// assert_eq!(params.offload_kqv(), true);
410    /// ```
411    #[must_use]
412    pub const fn offload_kqv(&self) -> bool {
413        self.context_params.offload_kqv
414    }
415
416    /// Set the type of rope scaling.
417    ///
418    /// # Examples
419    ///
420    /// ```rust
421    /// use llama_cpp_bindings::context::params::{LlamaContextParams, RopeScalingType};
422    /// let params = LlamaContextParams::default()
423    ///     .with_rope_scaling_type(RopeScalingType::Linear);
424    /// assert_eq!(params.rope_scaling_type(), RopeScalingType::Linear);
425    /// ```
426    #[must_use]
427    pub fn with_rope_scaling_type(mut self, rope_scaling_type: RopeScalingType) -> Self {
428        self.context_params.rope_scaling_type = i32::from(rope_scaling_type);
429        self
430    }
431
432    /// Get the type of rope scaling.
433    ///
434    /// # Examples
435    ///
436    /// ```rust
437    /// let params = llama_cpp_bindings::context::params::LlamaContextParams::default();
438    /// assert_eq!(params.rope_scaling_type(), llama_cpp_bindings::context::params::RopeScalingType::Unspecified);
439    /// ```
440    #[must_use]
441    pub fn rope_scaling_type(&self) -> RopeScalingType {
442        RopeScalingType::from(self.context_params.rope_scaling_type)
443    }
444
445    /// Set the rope frequency base.
446    ///
447    /// # Examples
448    ///
449    /// ```rust
450    /// use llama_cpp_bindings::context::params::LlamaContextParams;
451    /// let params = LlamaContextParams::default()
452    ///    .with_rope_freq_base(0.5);
453    /// assert_eq!(params.rope_freq_base(), 0.5);
454    /// ```
455    #[must_use]
456    pub const fn with_rope_freq_base(mut self, rope_freq_base: f32) -> Self {
457        self.context_params.rope_freq_base = rope_freq_base;
458        self
459    }
460
461    /// Get the rope frequency base.
462    ///
463    /// # Examples
464    ///
465    /// ```rust
466    /// let params = llama_cpp_bindings::context::params::LlamaContextParams::default();
467    /// assert_eq!(params.rope_freq_base(), 0.0);
468    /// ```
469    #[must_use]
470    pub const fn rope_freq_base(&self) -> f32 {
471        self.context_params.rope_freq_base
472    }
473
474    /// Set the rope frequency scale.
475    ///
476    /// # Examples
477    ///
478    /// ```rust
479    /// use llama_cpp_bindings::context::params::LlamaContextParams;
480    /// let params = LlamaContextParams::default()
481    ///   .with_rope_freq_scale(0.5);
482    /// assert_eq!(params.rope_freq_scale(), 0.5);
483    /// ```
484    #[must_use]
485    pub const fn with_rope_freq_scale(mut self, rope_freq_scale: f32) -> Self {
486        self.context_params.rope_freq_scale = rope_freq_scale;
487        self
488    }
489
490    /// Get the rope frequency scale.
491    ///
492    /// # Examples
493    ///
494    /// ```rust
495    /// let params = llama_cpp_bindings::context::params::LlamaContextParams::default();
496    /// assert_eq!(params.rope_freq_scale(), 0.0);
497    /// ```
498    #[must_use]
499    pub const fn rope_freq_scale(&self) -> f32 {
500        self.context_params.rope_freq_scale
501    }
502
503    /// Get the number of threads.
504    ///
505    /// # Examples
506    ///
507    /// ```rust
508    /// let params = llama_cpp_bindings::context::params::LlamaContextParams::default();
509    /// assert_eq!(params.n_threads(), 4);
510    /// ```
511    #[must_use]
512    pub const fn n_threads(&self) -> i32 {
513        self.context_params.n_threads
514    }
515
516    /// Get the number of threads allocated for batches.
517    ///
518    /// # Examples
519    ///
520    /// ```rust
521    /// let params = llama_cpp_bindings::context::params::LlamaContextParams::default();
522    /// assert_eq!(params.n_threads_batch(), 4);
523    /// ```
524    #[must_use]
525    pub const fn n_threads_batch(&self) -> i32 {
526        self.context_params.n_threads_batch
527    }
528
529    /// Set the number of threads.
530    ///
531    /// # Examples
532    ///
533    /// ```rust
534    /// use llama_cpp_bindings::context::params::LlamaContextParams;
535    /// let params = LlamaContextParams::default()
536    ///    .with_n_threads(8);
537    /// assert_eq!(params.n_threads(), 8);
538    /// ```
539    #[must_use]
540    pub const fn with_n_threads(mut self, n_threads: i32) -> Self {
541        self.context_params.n_threads = n_threads;
542        self
543    }
544
545    /// Set the number of threads allocated for batches.
546    ///
547    /// # Examples
548    ///
549    /// ```rust
550    /// use llama_cpp_bindings::context::params::LlamaContextParams;
551    /// let params = LlamaContextParams::default()
552    ///    .with_n_threads_batch(8);
553    /// assert_eq!(params.n_threads_batch(), 8);
554    /// ```
555    #[must_use]
556    pub const fn with_n_threads_batch(mut self, n_threads: i32) -> Self {
557        self.context_params.n_threads_batch = n_threads;
558        self
559    }
560
561    /// Check whether embeddings are enabled
562    ///
563    /// # Examples
564    ///
565    /// ```rust
566    /// let params = llama_cpp_bindings::context::params::LlamaContextParams::default();
567    /// assert!(!params.embeddings());
568    /// ```
569    #[must_use]
570    pub const fn embeddings(&self) -> bool {
571        self.context_params.embeddings
572    }
573
574    /// Enable the use of embeddings
575    ///
576    /// # Examples
577    ///
578    /// ```rust
579    /// use llama_cpp_bindings::context::params::LlamaContextParams;
580    /// let params = LlamaContextParams::default()
581    ///    .with_embeddings(true);
582    /// assert!(params.embeddings());
583    /// ```
584    #[must_use]
585    pub const fn with_embeddings(mut self, embedding: bool) -> Self {
586        self.context_params.embeddings = embedding;
587        self
588    }
589
    /// Set the evaluation callback.
    ///
    /// # Examples
    ///
    /// ```no_run
    /// extern "C" fn cb_eval_fn(
    ///     t: *mut llama_cpp_bindings_sys::ggml_tensor,
    ///     ask: bool,
    ///     user_data: *mut std::ffi::c_void,
    /// ) -> bool {
    ///     false
    /// }
    ///
    /// use llama_cpp_bindings::context::params::LlamaContextParams;
    /// let params = LlamaContextParams::default().with_cb_eval(Some(cb_eval_fn));
    /// ```
    #[must_use]
    pub fn with_cb_eval(
        mut self,
        cb_eval: llama_cpp_bindings_sys::ggml_backend_sched_eval_callback,
    ) -> Self {
        self.context_params.cb_eval = cb_eval;
        self
    }

    /// Set the evaluation callback user data.
    ///
    /// # Examples
    ///
    /// ```no_run
    /// use llama_cpp_bindings::context::params::LlamaContextParams;
    /// let params = LlamaContextParams::default();
    /// let user_data = std::ptr::null_mut();
    /// let params = params.with_cb_eval_user_data(user_data);
    /// ```
    #[must_use]
    pub const fn with_cb_eval_user_data(
        mut self,
        cb_eval_user_data: *mut std::ffi::c_void,
    ) -> Self {
        // NOTE(review): this stores a raw pointer that is later handed to the C
        // callback; the caller must keep the pointee alive and thread-safe for
        // the lifetime of the context — confirm against the Send/Sync impls above.
        self.context_params.cb_eval_user_data = cb_eval_user_data;
        self
    }
633
634    /// Set the type of pooling.
635    ///
636    /// # Examples
637    ///
638    /// ```rust
639    /// use llama_cpp_bindings::context::params::{LlamaContextParams, LlamaPoolingType};
640    /// let params = LlamaContextParams::default()
641    ///     .with_pooling_type(LlamaPoolingType::Last);
642    /// assert_eq!(params.pooling_type(), LlamaPoolingType::Last);
643    /// ```
644    #[must_use]
645    pub fn with_pooling_type(mut self, pooling_type: LlamaPoolingType) -> Self {
646        self.context_params.pooling_type = i32::from(pooling_type);
647        self
648    }
649
650    /// Get the type of pooling.
651    ///
652    /// # Examples
653    ///
654    /// ```rust
655    /// let params = llama_cpp_bindings::context::params::LlamaContextParams::default();
656    /// assert_eq!(params.pooling_type(), llama_cpp_bindings::context::params::LlamaPoolingType::Unspecified);
657    /// ```
658    #[must_use]
659    pub fn pooling_type(&self) -> LlamaPoolingType {
660        LlamaPoolingType::from(self.context_params.pooling_type)
661    }
662
663    /// Set whether to use full sliding window attention
664    ///
665    /// # Examples
666    ///
667    /// ```rust
668    /// use llama_cpp_bindings::context::params::LlamaContextParams;
669    /// let params = LlamaContextParams::default()
670    ///     .with_swa_full(false);
671    /// assert_eq!(params.swa_full(), false);
672    /// ```
673    #[must_use]
674    pub const fn with_swa_full(mut self, enabled: bool) -> Self {
675        self.context_params.swa_full = enabled;
676        self
677    }
678
679    /// Get whether full sliding window attention is enabled
680    ///
681    /// # Examples
682    ///
683    /// ```rust
684    /// use llama_cpp_bindings::context::params::LlamaContextParams;
685    /// let params = LlamaContextParams::default();
686    /// assert_eq!(params.swa_full(), true);
687    /// ```
688    #[must_use]
689    pub const fn swa_full(&self) -> bool {
690        self.context_params.swa_full
691    }
692
693    /// Set the max number of sequences (i.e. distinct states for recurrent models)
694    ///
695    /// # Examples
696    ///
697    /// ```rust
698    /// use llama_cpp_bindings::context::params::LlamaContextParams;
699    /// let params = LlamaContextParams::default()
700    ///     .with_n_seq_max(64);
701    /// assert_eq!(params.n_seq_max(), 64);
702    /// ```
703    #[must_use]
704    pub const fn with_n_seq_max(mut self, n_seq_max: u32) -> Self {
705        self.context_params.n_seq_max = n_seq_max;
706        self
707    }
708
709    /// Get the max number of sequences (i.e. distinct states for recurrent models)
710    ///
711    /// # Examples
712    ///
713    /// ```rust
714    /// use llama_cpp_bindings::context::params::LlamaContextParams;
715    /// let params = LlamaContextParams::default();
716    /// assert_eq!(params.n_seq_max(), 1);
717    /// ```
718    #[must_use]
719    pub const fn n_seq_max(&self) -> u32 {
720        self.context_params.n_seq_max
721    }
    /// Set the KV cache data type for K
    ///
    /// # Examples
    ///
    /// ```rust
    /// use llama_cpp_bindings::context::params::{LlamaContextParams, KvCacheType};
    /// let params = LlamaContextParams::default().with_type_k(KvCacheType::Q4_0);
    /// assert_eq!(params.type_k(), KvCacheType::Q4_0);
    /// ```
    #[must_use]
    pub fn with_type_k(mut self, type_k: KvCacheType) -> Self {
        self.context_params.type_k = type_k.into();
        self
    }

    /// Get the KV cache data type for K
    ///
    /// # Examples
    ///
    /// ```rust
    /// let params = llama_cpp_bindings::context::params::LlamaContextParams::default();
    /// let _ = params.type_k();
    /// ```
    #[must_use]
    pub fn type_k(&self) -> KvCacheType {
        KvCacheType::from(self.context_params.type_k)
    }
745
746    /// Set the KV cache data type for V
747    ///
748    /// # Examples
749    ///
750    /// ```rust
751    /// use llama_cpp_bindings::context::params::{LlamaContextParams, KvCacheType};
752    /// let params = LlamaContextParams::default().with_type_v(KvCacheType::Q4_1);
753    /// assert_eq!(params.type_v(), KvCacheType::Q4_1);
754    /// ```
755    #[must_use]
756    pub fn with_type_v(mut self, type_v: KvCacheType) -> Self {
757        self.context_params.type_v = type_v.into();
758        self
759    }
760
761    /// Get the KV cache data type for V
762    ///
763    /// # Examples
764    ///
765    /// ```rust
766    /// let params = llama_cpp_bindings::context::params::LlamaContextParams::default();
767    /// let _ = params.type_v();
768    /// ```
769    #[must_use]
770    pub fn type_v(&self) -> KvCacheType {
771        KvCacheType::from(self.context_params.type_v)
772    }
773
774    /// Set the attention type
775    ///
776    /// # Examples
777    ///
778    /// ```rust
779    /// use llama_cpp_bindings::context::params::{LlamaContextParams, LlamaAttentionType};
780    /// let params = LlamaContextParams::default()
781    ///     .with_attention_type(LlamaAttentionType::NonCausal);
782    /// assert_eq!(params.attention_type(), LlamaAttentionType::NonCausal);
783    /// ```
784    #[must_use]
785    pub fn with_attention_type(mut self, attention_type: LlamaAttentionType) -> Self {
786        self.context_params.attention_type = i32::from(attention_type);
787        self
788    }
789
790    /// Get the attention type
791    ///
792    /// # Examples
793    ///
794    /// ```rust
795    /// let params = llama_cpp_bindings::context::params::LlamaContextParams::default();
796    /// assert_eq!(params.attention_type(), llama_cpp_bindings::context::params::LlamaAttentionType::Unspecified);
797    /// ```
798    #[must_use]
799    pub fn attention_type(&self) -> LlamaAttentionType {
800        LlamaAttentionType::from(self.context_params.attention_type)
801    }
802
803    /// Set the `YaRN` extrapolation factor
804    ///
805    /// # Examples
806    ///
807    /// ```rust
808    /// use llama_cpp_bindings::context::params::LlamaContextParams;
809    /// let params = LlamaContextParams::default()
810    ///     .with_yarn_ext_factor(1.0);
811    /// assert!((params.yarn_ext_factor() - 1.0).abs() < f32::EPSILON);
812    /// ```
813    #[must_use]
814    pub const fn with_yarn_ext_factor(mut self, yarn_ext_factor: f32) -> Self {
815        self.context_params.yarn_ext_factor = yarn_ext_factor;
816        self
817    }
818
819    /// Get the `YaRN` extrapolation factor
820    #[must_use]
821    pub const fn yarn_ext_factor(&self) -> f32 {
822        self.context_params.yarn_ext_factor
823    }
824
825    /// Set the `YaRN` attention factor
826    ///
827    /// # Examples
828    ///
829    /// ```rust
830    /// use llama_cpp_bindings::context::params::LlamaContextParams;
831    /// let params = LlamaContextParams::default()
832    ///     .with_yarn_attn_factor(2.0);
833    /// assert!((params.yarn_attn_factor() - 2.0).abs() < f32::EPSILON);
834    /// ```
835    #[must_use]
836    pub const fn with_yarn_attn_factor(mut self, yarn_attn_factor: f32) -> Self {
837        self.context_params.yarn_attn_factor = yarn_attn_factor;
838        self
839    }
840
841    /// Get the `YaRN` attention factor
842    #[must_use]
843    pub const fn yarn_attn_factor(&self) -> f32 {
844        self.context_params.yarn_attn_factor
845    }
846
847    /// Set the `YaRN` low correction dim
848    ///
849    /// # Examples
850    ///
851    /// ```rust
852    /// use llama_cpp_bindings::context::params::LlamaContextParams;
853    /// let params = LlamaContextParams::default()
854    ///     .with_yarn_beta_fast(32.0);
855    /// assert!((params.yarn_beta_fast() - 32.0).abs() < f32::EPSILON);
856    /// ```
857    #[must_use]
858    pub const fn with_yarn_beta_fast(mut self, yarn_beta_fast: f32) -> Self {
859        self.context_params.yarn_beta_fast = yarn_beta_fast;
860        self
861    }
862
863    /// Get the `YaRN` low correction dim
864    #[must_use]
865    pub const fn yarn_beta_fast(&self) -> f32 {
866        self.context_params.yarn_beta_fast
867    }
868
869    /// Set the `YaRN` high correction dim
870    ///
871    /// # Examples
872    ///
873    /// ```rust
874    /// use llama_cpp_bindings::context::params::LlamaContextParams;
875    /// let params = LlamaContextParams::default()
876    ///     .with_yarn_beta_slow(1.0);
877    /// assert!((params.yarn_beta_slow() - 1.0).abs() < f32::EPSILON);
878    /// ```
879    #[must_use]
880    pub const fn with_yarn_beta_slow(mut self, yarn_beta_slow: f32) -> Self {
881        self.context_params.yarn_beta_slow = yarn_beta_slow;
882        self
883    }
884
885    /// Get the `YaRN` high correction dim
886    #[must_use]
887    pub const fn yarn_beta_slow(&self) -> f32 {
888        self.context_params.yarn_beta_slow
889    }
890
891    /// Set the `YaRN` original context size
892    ///
893    /// # Examples
894    ///
895    /// ```rust
896    /// use llama_cpp_bindings::context::params::LlamaContextParams;
897    /// let params = LlamaContextParams::default()
898    ///     .with_yarn_orig_ctx(4096);
899    /// assert_eq!(params.yarn_orig_ctx(), 4096);
900    /// ```
901    #[must_use]
902    pub const fn with_yarn_orig_ctx(mut self, yarn_orig_ctx: u32) -> Self {
903        self.context_params.yarn_orig_ctx = yarn_orig_ctx;
904        self
905    }
906
907    /// Get the `YaRN` original context size
908    #[must_use]
909    pub const fn yarn_orig_ctx(&self) -> u32 {
910        self.context_params.yarn_orig_ctx
911    }
912
913    /// Set the KV cache defragmentation threshold
914    ///
915    /// # Examples
916    ///
917    /// ```rust
918    /// use llama_cpp_bindings::context::params::LlamaContextParams;
919    /// let params = LlamaContextParams::default()
920    ///     .with_defrag_thold(0.1);
921    /// assert!((params.defrag_thold() - 0.1).abs() < f32::EPSILON);
922    /// ```
923    #[must_use]
924    pub const fn with_defrag_thold(mut self, defrag_thold: f32) -> Self {
925        self.context_params.defrag_thold = defrag_thold;
926        self
927    }
928
929    /// Get the KV cache defragmentation threshold
930    #[must_use]
931    pub const fn defrag_thold(&self) -> f32 {
932        self.context_params.defrag_thold
933    }
934
935    /// Set whether performance timings are disabled
936    ///
937    /// # Examples
938    ///
939    /// ```rust
940    /// use llama_cpp_bindings::context::params::LlamaContextParams;
941    /// let params = LlamaContextParams::default()
942    ///     .with_no_perf(true);
943    /// assert!(params.no_perf());
944    /// ```
945    #[must_use]
946    pub const fn with_no_perf(mut self, no_perf: bool) -> Self {
947        self.context_params.no_perf = no_perf;
948        self
949    }
950
951    /// Get whether performance timings are disabled
952    #[must_use]
953    pub const fn no_perf(&self) -> bool {
954        self.context_params.no_perf
955    }
956
957    /// Set whether to offload ops to GPU
958    ///
959    /// # Examples
960    ///
961    /// ```rust
962    /// use llama_cpp_bindings::context::params::LlamaContextParams;
963    /// let params = LlamaContextParams::default()
964    ///     .with_op_offload(false);
965    /// assert!(!params.op_offload());
966    /// ```
967    #[must_use]
968    pub const fn with_op_offload(mut self, op_offload: bool) -> Self {
969        self.context_params.op_offload = op_offload;
970        self
971    }
972
973    /// Get whether ops are offloaded to GPU
974    #[must_use]
975    pub const fn op_offload(&self) -> bool {
976        self.context_params.op_offload
977    }
978
979    /// Set whether to use a unified KV cache buffer across input sequences
980    ///
981    /// # Examples
982    ///
983    /// ```rust
984    /// use llama_cpp_bindings::context::params::LlamaContextParams;
985    /// let params = LlamaContextParams::default()
986    ///     .with_kv_unified(true);
987    /// assert!(params.kv_unified());
988    /// ```
989    #[must_use]
990    pub const fn with_kv_unified(mut self, kv_unified: bool) -> Self {
991        self.context_params.kv_unified = kv_unified;
992        self
993    }
994
995    /// Get whether a unified KV cache buffer is used across input sequences
996    #[must_use]
997    pub const fn kv_unified(&self) -> bool {
998        self.context_params.kv_unified
999    }
1000}
1001
/// Default parameters for `LlamaContext`. (as defined in llama.cpp by `llama_context_default_params`)
/// ```
/// # use std::num::NonZeroU32;
/// use llama_cpp_bindings::context::params::{LlamaContextParams, RopeScalingType};
/// let params = LlamaContextParams::default();
/// assert_eq!(params.n_ctx(), NonZeroU32::new(512), "n_ctx should be 512");
/// assert_eq!(params.rope_scaling_type(), RopeScalingType::Unspecified);
/// ```
impl Default for LlamaContextParams {
    fn default() -> Self {
        // SAFETY: `llama_context_default_params` takes no arguments and
        // returns the parameter struct by value — assumed to have no
        // preconditions (per the llama.cpp API; confirm against the pinned
        // llama.h if the binding is updated).
        let context_params = unsafe { llama_cpp_bindings_sys::llama_context_default_params() };
        Self { context_params }
    }
}
1016
#[cfg(test)]
mod tests {
    // Bring `LlamaContextParams` into scope alongside the other `super` items
    // so test bodies don't have to fully qualify it (`super::LlamaContextParams`)
    // at every use site.
    use super::{
        KvCacheType, LlamaAttentionType, LlamaContextParams, LlamaPoolingType, RopeScalingType,
    };

    #[test]
    fn rope_scaling_type_unknown_defaults_to_unspecified() {
        assert_eq!(RopeScalingType::from(99), RopeScalingType::Unspecified);
        assert_eq!(RopeScalingType::from(-100), RopeScalingType::Unspecified);
    }

    #[test]
    fn pooling_type_unknown_defaults_to_unspecified() {
        assert_eq!(LlamaPoolingType::from(99), LlamaPoolingType::Unspecified);
        assert_eq!(LlamaPoolingType::from(-50), LlamaPoolingType::Unspecified);
    }

    #[test]
    fn kv_cache_type_unknown_preserves_raw_value() {
        let unknown_raw: llama_cpp_bindings_sys::ggml_type = 99999;
        let cache_type = KvCacheType::from(unknown_raw);

        assert_eq!(cache_type, KvCacheType::Unknown(99999));

        let back: llama_cpp_bindings_sys::ggml_type = cache_type.into();

        assert_eq!(back, 99999);
    }

    #[test]
    fn default_params_have_expected_values() {
        let params = LlamaContextParams::default();

        assert_eq!(params.n_ctx(), std::num::NonZeroU32::new(512));
        assert_eq!(params.n_batch(), 2048);
        assert_eq!(params.n_ubatch(), 512);
        assert_eq!(params.rope_scaling_type(), RopeScalingType::Unspecified);
        assert_eq!(params.pooling_type(), LlamaPoolingType::Unspecified);
    }

    #[test]
    fn with_n_ctx_sets_value() {
        let params = LlamaContextParams::default().with_n_ctx(std::num::NonZeroU32::new(2048));

        assert_eq!(params.n_ctx(), std::num::NonZeroU32::new(2048));
    }

    #[test]
    fn with_n_ctx_none_sets_zero() {
        let params = LlamaContextParams::default().with_n_ctx(None);

        assert_eq!(params.n_ctx(), None);
    }

    #[test]
    fn with_n_batch_sets_value() {
        let params = LlamaContextParams::default().with_n_batch(4096);

        assert_eq!(params.n_batch(), 4096);
    }

    #[test]
    fn with_n_ubatch_sets_value() {
        let params = LlamaContextParams::default().with_n_ubatch(1024);

        assert_eq!(params.n_ubatch(), 1024);
    }

    #[test]
    fn with_n_seq_max_sets_value() {
        let params = LlamaContextParams::default().with_n_seq_max(64);

        assert_eq!(params.n_seq_max(), 64);
    }

    #[test]
    fn with_embeddings_enables() {
        let params = LlamaContextParams::default().with_embeddings(true);

        assert!(params.embeddings());
    }

    #[test]
    fn with_embeddings_disables() {
        let params = LlamaContextParams::default().with_embeddings(false);

        assert!(!params.embeddings());
    }

    #[test]
    fn with_offload_kqv_disables() {
        let params = LlamaContextParams::default().with_offload_kqv(false);

        assert!(!params.offload_kqv());
    }

    #[test]
    fn with_offload_kqv_enables() {
        let params = LlamaContextParams::default().with_offload_kqv(true);

        assert!(params.offload_kqv());
    }

    #[test]
    fn with_swa_full_disables() {
        let params = LlamaContextParams::default().with_swa_full(false);

        assert!(!params.swa_full());
    }

    #[test]
    fn with_swa_full_enables() {
        let params = LlamaContextParams::default().with_swa_full(true);

        assert!(params.swa_full());
    }

    #[test]
    fn with_rope_scaling_type_linear() {
        let params = LlamaContextParams::default().with_rope_scaling_type(RopeScalingType::Linear);

        assert_eq!(params.rope_scaling_type(), RopeScalingType::Linear);
    }

    #[test]
    fn with_rope_scaling_type_yarn() {
        let params = LlamaContextParams::default().with_rope_scaling_type(RopeScalingType::Yarn);

        assert_eq!(params.rope_scaling_type(), RopeScalingType::Yarn);
    }

    #[test]
    fn with_rope_scaling_type_none() {
        let params = LlamaContextParams::default().with_rope_scaling_type(RopeScalingType::None);

        assert_eq!(params.rope_scaling_type(), RopeScalingType::None);
    }

    #[test]
    fn with_rope_freq_base_sets_value() {
        let params = LlamaContextParams::default().with_rope_freq_base(10000.0);

        assert!((params.rope_freq_base() - 10000.0).abs() < f32::EPSILON);
    }

    #[test]
    fn with_rope_freq_scale_sets_value() {
        let params = LlamaContextParams::default().with_rope_freq_scale(0.5);

        assert!((params.rope_freq_scale() - 0.5).abs() < f32::EPSILON);
    }

    #[test]
    fn with_n_threads_sets_value() {
        let params = LlamaContextParams::default().with_n_threads(16);

        assert_eq!(params.n_threads(), 16);
    }

    #[test]
    fn with_n_threads_batch_sets_value() {
        let params = LlamaContextParams::default().with_n_threads_batch(16);

        assert_eq!(params.n_threads_batch(), 16);
    }

    #[test]
    fn with_pooling_type_mean() {
        let params = LlamaContextParams::default().with_pooling_type(LlamaPoolingType::Mean);

        assert_eq!(params.pooling_type(), LlamaPoolingType::Mean);
    }

    #[test]
    fn with_pooling_type_cls() {
        let params = LlamaContextParams::default().with_pooling_type(LlamaPoolingType::Cls);

        assert_eq!(params.pooling_type(), LlamaPoolingType::Cls);
    }

    #[test]
    fn with_pooling_type_last() {
        let params = LlamaContextParams::default().with_pooling_type(LlamaPoolingType::Last);

        assert_eq!(params.pooling_type(), LlamaPoolingType::Last);
    }

    #[test]
    fn with_pooling_type_rank() {
        let params = LlamaContextParams::default().with_pooling_type(LlamaPoolingType::Rank);

        assert_eq!(params.pooling_type(), LlamaPoolingType::Rank);
    }

    #[test]
    fn with_pooling_type_none() {
        let params = LlamaContextParams::default().with_pooling_type(LlamaPoolingType::None);

        assert_eq!(params.pooling_type(), LlamaPoolingType::None);
    }

    #[test]
    fn with_type_k_sets_value() {
        let params = LlamaContextParams::default().with_type_k(KvCacheType::Q4_0);

        assert_eq!(params.type_k(), KvCacheType::Q4_0);
    }

    #[test]
    fn with_type_v_sets_value() {
        let params = LlamaContextParams::default().with_type_v(KvCacheType::Q4_1);

        assert_eq!(params.type_v(), KvCacheType::Q4_1);
    }

    #[test]
    fn with_flash_attention_policy_sets_value() {
        let params = LlamaContextParams::default()
            .with_flash_attention_policy(llama_cpp_bindings_sys::LLAMA_FLASH_ATTN_TYPE_ENABLED);

        assert_eq!(
            params.flash_attention_policy(),
            llama_cpp_bindings_sys::LLAMA_FLASH_ATTN_TYPE_ENABLED
        );
    }

    #[test]
    fn builder_chaining_preserves_all_values() {
        let params = LlamaContextParams::default()
            .with_n_ctx(std::num::NonZeroU32::new(1024))
            .with_n_batch(4096)
            .with_n_ubatch(256)
            .with_n_threads(8)
            .with_n_threads_batch(12)
            .with_embeddings(true)
            .with_offload_kqv(false)
            .with_rope_scaling_type(RopeScalingType::Yarn)
            .with_rope_freq_base(5000.0)
            .with_rope_freq_scale(0.25);

        assert_eq!(params.n_ctx(), std::num::NonZeroU32::new(1024));
        assert_eq!(params.n_batch(), 4096);
        assert_eq!(params.n_ubatch(), 256);
        assert_eq!(params.n_threads(), 8);
        assert_eq!(params.n_threads_batch(), 12);
        assert!(params.embeddings());
        assert!(!params.offload_kqv());
        assert_eq!(params.rope_scaling_type(), RopeScalingType::Yarn);
        assert!((params.rope_freq_base() - 5000.0).abs() < f32::EPSILON);
        assert!((params.rope_freq_scale() - 0.25).abs() < f32::EPSILON);
    }

    #[test]
    fn rope_scaling_type_roundtrip_all_variants() {
        for (raw, expected) in [
            (-1, RopeScalingType::Unspecified),
            (0, RopeScalingType::None),
            (1, RopeScalingType::Linear),
            (2, RopeScalingType::Yarn),
        ] {
            let from_raw = RopeScalingType::from(raw);
            assert_eq!(from_raw, expected);

            let back_to_raw: i32 = from_raw.into();
            assert_eq!(back_to_raw, raw);
        }
    }

    #[test]
    fn pooling_type_roundtrip_all_variants() {
        for (raw, expected) in [
            (-1, LlamaPoolingType::Unspecified),
            (0, LlamaPoolingType::None),
            (1, LlamaPoolingType::Mean),
            (2, LlamaPoolingType::Cls),
            (3, LlamaPoolingType::Last),
            (4, LlamaPoolingType::Rank),
        ] {
            let from_raw = LlamaPoolingType::from(raw);
            assert_eq!(from_raw, expected);

            let back_to_raw: i32 = from_raw.into();
            assert_eq!(back_to_raw, raw);
        }
    }

    #[test]
    fn kv_cache_type_all_known_variants_roundtrip() {
        let all_variants = [
            KvCacheType::F32,
            KvCacheType::F16,
            KvCacheType::Q4_0,
            KvCacheType::Q4_1,
            KvCacheType::Q5_0,
            KvCacheType::Q5_1,
            KvCacheType::Q8_0,
            KvCacheType::Q8_1,
            KvCacheType::Q2_K,
            KvCacheType::Q3_K,
            KvCacheType::Q4_K,
            KvCacheType::Q5_K,
            KvCacheType::Q6_K,
            KvCacheType::Q8_K,
            KvCacheType::IQ2_XXS,
            KvCacheType::IQ2_XS,
            KvCacheType::IQ3_XXS,
            KvCacheType::IQ1_S,
            KvCacheType::IQ4_NL,
            KvCacheType::IQ3_S,
            KvCacheType::IQ2_S,
            KvCacheType::IQ4_XS,
            KvCacheType::I8,
            KvCacheType::I16,
            KvCacheType::I32,
            KvCacheType::I64,
            KvCacheType::F64,
            KvCacheType::IQ1_M,
            KvCacheType::BF16,
            KvCacheType::TQ1_0,
            KvCacheType::TQ2_0,
            KvCacheType::MXFP4,
        ];

        for variant in all_variants {
            let ggml_type: llama_cpp_bindings_sys::ggml_type = variant.into();
            let back = KvCacheType::from(ggml_type);

            assert_eq!(back, variant);
        }
    }

    #[test]
    fn with_cb_eval_sets_callback() {
        extern "C" fn test_cb_eval(
            _tensor: *mut llama_cpp_bindings_sys::ggml_tensor,
            _ask: bool,
            _user_data: *mut std::ffi::c_void,
        ) -> bool {
            false
        }

        let result = test_cb_eval(std::ptr::null_mut(), false, std::ptr::null_mut());

        assert!(!result);

        let params = LlamaContextParams::default().with_cb_eval(Some(test_cb_eval));

        assert!(params.context_params.cb_eval.is_some());
    }

    #[test]
    fn with_cb_eval_user_data_sets_pointer() {
        let mut value: i32 = 42;
        let user_data = (&raw mut value).cast::<std::ffi::c_void>();
        let params = LlamaContextParams::default().with_cb_eval_user_data(user_data);

        assert_eq!(params.context_params.cb_eval_user_data, user_data);
    }

    #[test]
    fn with_flash_attention_policy_disabled() {
        let params = LlamaContextParams::default()
            .with_flash_attention_policy(llama_cpp_bindings_sys::LLAMA_FLASH_ATTN_TYPE_DISABLED);

        assert_eq!(
            params.flash_attention_policy(),
            llama_cpp_bindings_sys::LLAMA_FLASH_ATTN_TYPE_DISABLED
        );
    }

    #[test]
    fn attention_type_unknown_defaults_to_unspecified() {
        assert_eq!(
            LlamaAttentionType::from(99),
            LlamaAttentionType::Unspecified
        );
        assert_eq!(
            LlamaAttentionType::from(-50),
            LlamaAttentionType::Unspecified
        );
    }

    #[test]
    fn attention_type_roundtrip_all_variants() {
        for (raw, expected) in [
            (-1, LlamaAttentionType::Unspecified),
            (0, LlamaAttentionType::Causal),
            (1, LlamaAttentionType::NonCausal),
        ] {
            let from_raw = LlamaAttentionType::from(raw);
            assert_eq!(from_raw, expected);

            let back_to_raw: i32 = from_raw.into();
            assert_eq!(back_to_raw, raw);
        }
    }

    #[test]
    fn with_attention_type_causal() {
        let params = LlamaContextParams::default().with_attention_type(LlamaAttentionType::Causal);

        assert_eq!(params.attention_type(), LlamaAttentionType::Causal);
    }

    #[test]
    fn with_attention_type_non_causal() {
        let params =
            LlamaContextParams::default().with_attention_type(LlamaAttentionType::NonCausal);

        assert_eq!(params.attention_type(), LlamaAttentionType::NonCausal);
    }

    #[test]
    fn with_yarn_ext_factor_sets_value() {
        let params = LlamaContextParams::default().with_yarn_ext_factor(1.5);

        assert!((params.yarn_ext_factor() - 1.5).abs() < f32::EPSILON);
    }

    #[test]
    fn with_yarn_attn_factor_sets_value() {
        let params = LlamaContextParams::default().with_yarn_attn_factor(2.0);

        assert!((params.yarn_attn_factor() - 2.0).abs() < f32::EPSILON);
    }

    #[test]
    fn with_yarn_beta_fast_sets_value() {
        let params = LlamaContextParams::default().with_yarn_beta_fast(32.0);

        assert!((params.yarn_beta_fast() - 32.0).abs() < f32::EPSILON);
    }

    #[test]
    fn with_yarn_beta_slow_sets_value() {
        let params = LlamaContextParams::default().with_yarn_beta_slow(1.0);

        assert!((params.yarn_beta_slow() - 1.0).abs() < f32::EPSILON);
    }

    #[test]
    fn with_yarn_orig_ctx_sets_value() {
        let params = LlamaContextParams::default().with_yarn_orig_ctx(4096);

        assert_eq!(params.yarn_orig_ctx(), 4096);
    }

    #[test]
    fn with_defrag_thold_sets_value() {
        let params = LlamaContextParams::default().with_defrag_thold(0.1);

        assert!((params.defrag_thold() - 0.1).abs() < f32::EPSILON);
    }

    #[test]
    fn with_no_perf_enables() {
        let params = LlamaContextParams::default().with_no_perf(true);

        assert!(params.no_perf());
    }

    #[test]
    fn with_no_perf_disables() {
        let params = LlamaContextParams::default().with_no_perf(false);

        assert!(!params.no_perf());
    }

    #[test]
    fn with_op_offload_enables() {
        let params = LlamaContextParams::default().with_op_offload(true);

        assert!(params.op_offload());
    }

    #[test]
    fn with_op_offload_disables() {
        let params = LlamaContextParams::default().with_op_offload(false);

        assert!(!params.op_offload());
    }

    #[test]
    fn with_kv_unified_enables() {
        let params = LlamaContextParams::default().with_kv_unified(true);

        assert!(params.kv_unified());
    }

    #[test]
    fn with_kv_unified_disables() {
        let params = LlamaContextParams::default().with_kv_unified(false);

        assert!(!params.kv_unified());
    }
}