Skip to main content

llama_cpp_bindings/context/
params.rs

1//! A safe wrapper around `llama_context_params`.
2use std::fmt::Debug;
3use std::num::NonZeroU32;
4
/// A rusty wrapper around `rope_scaling_type`.
///
/// The explicit discriminants match the integer values handled by the `From<i32>` /
/// `From<RopeScalingType> for i32` conversions below.
#[repr(i8)]
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum RopeScalingType {
    /// The scaling type is unspecified
    Unspecified = -1,
    /// No scaling
    None = 0,
    /// Linear scaling
    Linear = 1,
    /// Yarn scaling
    Yarn = 2,
}

/// Create a `RopeScalingType` from a `c_int` - returns [`RopeScalingType::Unspecified`] if
/// the value is not recognized.
impl From<i32> for RopeScalingType {
    fn from(value: i32) -> Self {
        match value {
            0 => Self::None,
            1 => Self::Linear,
            2 => Self::Yarn,
            // Everything else (including the explicit `-1`) collapses to `Unspecified`.
            _ => Self::Unspecified,
        }
    }
}

/// Create a `c_int` from a `RopeScalingType`.
impl From<RopeScalingType> for i32 {
    fn from(value: RopeScalingType) -> Self {
        // Exhaustive on purpose: adding a variant must force this mapping to be revisited.
        match value {
            RopeScalingType::None => 0,
            RopeScalingType::Linear => 1,
            RopeScalingType::Yarn => 2,
            RopeScalingType::Unspecified => -1,
        }
    }
}
43
/// A rusty wrapper around `LLAMA_POOLING_TYPE`.
///
/// Each variant's discriminant is the integer value used by the conversions below.
#[repr(i8)]
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum LlamaPoolingType {
    /// The pooling type is unspecified
    Unspecified = -1,
    /// No pooling
    None = 0,
    /// Mean pooling
    Mean = 1,
    /// CLS pooling
    Cls = 2,
    /// Last pooling
    Last = 3,
    /// Rank pooling
    Rank = 4,
}

/// Build a `LlamaPoolingType` from a raw `c_int`; any value that is not one of the known
/// pooling types yields `LlamaPoolingType::Unspecified`.
impl From<i32> for LlamaPoolingType {
    fn from(value: i32) -> Self {
        // Every variant except `Unspecified`, which doubles as the fallback below.
        const RECOGNIZED: [LlamaPoolingType; 5] = [
            LlamaPoolingType::None,
            LlamaPoolingType::Mean,
            LlamaPoolingType::Cls,
            LlamaPoolingType::Last,
            LlamaPoolingType::Rank,
        ];

        RECOGNIZED
            .iter()
            .copied()
            .find(|kind| *kind as i32 == value)
            .unwrap_or(Self::Unspecified)
    }
}

/// Turn a `LlamaPoolingType` back into its raw `c_int` value.
impl From<LlamaPoolingType> for i32 {
    fn from(value: LlamaPoolingType) -> Self {
        // The enum's explicit discriminants are exactly the raw values, so a cast suffices.
        value as i32
    }
}
90
/// A rusty wrapper around `LLAMA_ATTENTION_TYPE`.
///
/// Each variant's discriminant is the integer value used by the conversions below.
#[repr(i8)]
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum LlamaAttentionType {
    /// The attention type is unspecified
    Unspecified = -1,
    /// Causal attention
    Causal = 0,
    /// Non-causal attention
    NonCausal = 1,
}

/// Build a `LlamaAttentionType` from a raw `c_int`; unrecognized values become
/// `LlamaAttentionType::Unspecified`.
impl From<i32> for LlamaAttentionType {
    fn from(value: i32) -> Self {
        if value == 0 {
            Self::Causal
        } else if value == 1 {
            Self::NonCausal
        } else {
            Self::Unspecified
        }
    }
}

/// Turn a `LlamaAttentionType` back into its raw `c_int` value.
impl From<LlamaAttentionType> for i32 {
    fn from(value: LlamaAttentionType) -> Self {
        // The enum's explicit discriminants are exactly the raw values, so a cast suffices.
        value as i32
    }
}
122
/// A rusty wrapper around `ggml_type` for KV cache types.
///
/// Every named variant corresponds to the `GGML_TYPE_*` constant of the same name; values
/// without a named variant are carried through unchanged in [`KvCacheType::Unknown`].
/// Conversions in both directions are provided through the `From` impls in this module.
// The variants are intentionally left undocumented (see the `missing_docs` expectation
// below); documenting all of them would leave that `#[expect]` unfulfilled.
#[expect(
    non_camel_case_types,
    reason = "variant names mirror llama.cpp's `enum ggml_type` symbol names verbatim so they can \
              be matched 1:1 against the C ABI without a translation table"
)]
#[expect(
    missing_docs,
    reason = "each variant denotes a quantisation flavour whose semantics are defined upstream in \
              ggml; restating the upstream spec inline would risk drifting from the source of truth"
)]
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum KvCacheType {
    /// Represents an unknown or not-yet-mapped `ggml_type` and carries the raw value.
    /// When passed through FFI, the raw value is used as-is (if llama.cpp supports it,
    /// the runtime will operate with that type).
    /// This variant preserves API compatibility when new `ggml_type` values are
    /// introduced in the future.
    Unknown(llama_cpp_bindings_sys::ggml_type),
    F32,
    F16,
    Q4_0,
    Q4_1,
    Q5_0,
    Q5_1,
    Q8_0,
    Q8_1,
    Q2_K,
    Q3_K,
    Q4_K,
    Q5_K,
    Q6_K,
    Q8_K,
    IQ2_XXS,
    IQ2_XS,
    IQ3_XXS,
    IQ1_S,
    IQ4_NL,
    IQ3_S,
    IQ2_S,
    IQ4_XS,
    I8,
    I16,
    I32,
    I64,
    F64,
    IQ1_M,
    BF16,
    TQ1_0,
    TQ2_0,
    MXFP4,
}
175
176impl From<KvCacheType> for llama_cpp_bindings_sys::ggml_type {
177    fn from(value: KvCacheType) -> Self {
178        match value {
179            KvCacheType::Unknown(raw) => raw,
180            KvCacheType::F32 => llama_cpp_bindings_sys::GGML_TYPE_F32,
181            KvCacheType::F16 => llama_cpp_bindings_sys::GGML_TYPE_F16,
182            KvCacheType::Q4_0 => llama_cpp_bindings_sys::GGML_TYPE_Q4_0,
183            KvCacheType::Q4_1 => llama_cpp_bindings_sys::GGML_TYPE_Q4_1,
184            KvCacheType::Q5_0 => llama_cpp_bindings_sys::GGML_TYPE_Q5_0,
185            KvCacheType::Q5_1 => llama_cpp_bindings_sys::GGML_TYPE_Q5_1,
186            KvCacheType::Q8_0 => llama_cpp_bindings_sys::GGML_TYPE_Q8_0,
187            KvCacheType::Q8_1 => llama_cpp_bindings_sys::GGML_TYPE_Q8_1,
188            KvCacheType::Q2_K => llama_cpp_bindings_sys::GGML_TYPE_Q2_K,
189            KvCacheType::Q3_K => llama_cpp_bindings_sys::GGML_TYPE_Q3_K,
190            KvCacheType::Q4_K => llama_cpp_bindings_sys::GGML_TYPE_Q4_K,
191            KvCacheType::Q5_K => llama_cpp_bindings_sys::GGML_TYPE_Q5_K,
192            KvCacheType::Q6_K => llama_cpp_bindings_sys::GGML_TYPE_Q6_K,
193            KvCacheType::Q8_K => llama_cpp_bindings_sys::GGML_TYPE_Q8_K,
194            KvCacheType::IQ2_XXS => llama_cpp_bindings_sys::GGML_TYPE_IQ2_XXS,
195            KvCacheType::IQ2_XS => llama_cpp_bindings_sys::GGML_TYPE_IQ2_XS,
196            KvCacheType::IQ3_XXS => llama_cpp_bindings_sys::GGML_TYPE_IQ3_XXS,
197            KvCacheType::IQ1_S => llama_cpp_bindings_sys::GGML_TYPE_IQ1_S,
198            KvCacheType::IQ4_NL => llama_cpp_bindings_sys::GGML_TYPE_IQ4_NL,
199            KvCacheType::IQ3_S => llama_cpp_bindings_sys::GGML_TYPE_IQ3_S,
200            KvCacheType::IQ2_S => llama_cpp_bindings_sys::GGML_TYPE_IQ2_S,
201            KvCacheType::IQ4_XS => llama_cpp_bindings_sys::GGML_TYPE_IQ4_XS,
202            KvCacheType::I8 => llama_cpp_bindings_sys::GGML_TYPE_I8,
203            KvCacheType::I16 => llama_cpp_bindings_sys::GGML_TYPE_I16,
204            KvCacheType::I32 => llama_cpp_bindings_sys::GGML_TYPE_I32,
205            KvCacheType::I64 => llama_cpp_bindings_sys::GGML_TYPE_I64,
206            KvCacheType::F64 => llama_cpp_bindings_sys::GGML_TYPE_F64,
207            KvCacheType::IQ1_M => llama_cpp_bindings_sys::GGML_TYPE_IQ1_M,
208            KvCacheType::BF16 => llama_cpp_bindings_sys::GGML_TYPE_BF16,
209            KvCacheType::TQ1_0 => llama_cpp_bindings_sys::GGML_TYPE_TQ1_0,
210            KvCacheType::TQ2_0 => llama_cpp_bindings_sys::GGML_TYPE_TQ2_0,
211            KvCacheType::MXFP4 => llama_cpp_bindings_sys::GGML_TYPE_MXFP4,
212        }
213    }
214}
215
216impl From<llama_cpp_bindings_sys::ggml_type> for KvCacheType {
217    fn from(value: llama_cpp_bindings_sys::ggml_type) -> Self {
218        match value {
219            x if x == llama_cpp_bindings_sys::GGML_TYPE_F32 => Self::F32,
220            x if x == llama_cpp_bindings_sys::GGML_TYPE_F16 => Self::F16,
221            x if x == llama_cpp_bindings_sys::GGML_TYPE_Q4_0 => Self::Q4_0,
222            x if x == llama_cpp_bindings_sys::GGML_TYPE_Q4_1 => Self::Q4_1,
223            x if x == llama_cpp_bindings_sys::GGML_TYPE_Q5_0 => Self::Q5_0,
224            x if x == llama_cpp_bindings_sys::GGML_TYPE_Q5_1 => Self::Q5_1,
225            x if x == llama_cpp_bindings_sys::GGML_TYPE_Q8_0 => Self::Q8_0,
226            x if x == llama_cpp_bindings_sys::GGML_TYPE_Q8_1 => Self::Q8_1,
227            x if x == llama_cpp_bindings_sys::GGML_TYPE_Q2_K => Self::Q2_K,
228            x if x == llama_cpp_bindings_sys::GGML_TYPE_Q3_K => Self::Q3_K,
229            x if x == llama_cpp_bindings_sys::GGML_TYPE_Q4_K => Self::Q4_K,
230            x if x == llama_cpp_bindings_sys::GGML_TYPE_Q5_K => Self::Q5_K,
231            x if x == llama_cpp_bindings_sys::GGML_TYPE_Q6_K => Self::Q6_K,
232            x if x == llama_cpp_bindings_sys::GGML_TYPE_Q8_K => Self::Q8_K,
233            x if x == llama_cpp_bindings_sys::GGML_TYPE_IQ2_XXS => Self::IQ2_XXS,
234            x if x == llama_cpp_bindings_sys::GGML_TYPE_IQ2_XS => Self::IQ2_XS,
235            x if x == llama_cpp_bindings_sys::GGML_TYPE_IQ3_XXS => Self::IQ3_XXS,
236            x if x == llama_cpp_bindings_sys::GGML_TYPE_IQ1_S => Self::IQ1_S,
237            x if x == llama_cpp_bindings_sys::GGML_TYPE_IQ4_NL => Self::IQ4_NL,
238            x if x == llama_cpp_bindings_sys::GGML_TYPE_IQ3_S => Self::IQ3_S,
239            x if x == llama_cpp_bindings_sys::GGML_TYPE_IQ2_S => Self::IQ2_S,
240            x if x == llama_cpp_bindings_sys::GGML_TYPE_IQ4_XS => Self::IQ4_XS,
241            x if x == llama_cpp_bindings_sys::GGML_TYPE_I8 => Self::I8,
242            x if x == llama_cpp_bindings_sys::GGML_TYPE_I16 => Self::I16,
243            x if x == llama_cpp_bindings_sys::GGML_TYPE_I32 => Self::I32,
244            x if x == llama_cpp_bindings_sys::GGML_TYPE_I64 => Self::I64,
245            x if x == llama_cpp_bindings_sys::GGML_TYPE_F64 => Self::F64,
246            x if x == llama_cpp_bindings_sys::GGML_TYPE_IQ1_M => Self::IQ1_M,
247            x if x == llama_cpp_bindings_sys::GGML_TYPE_BF16 => Self::BF16,
248            x if x == llama_cpp_bindings_sys::GGML_TYPE_TQ1_0 => Self::TQ1_0,
249            x if x == llama_cpp_bindings_sys::GGML_TYPE_TQ2_0 => Self::TQ2_0,
250            x if x == llama_cpp_bindings_sys::GGML_TYPE_MXFP4 => Self::MXFP4,
251            _ => Self::Unknown(value),
252        }
253    }
254}
255
/// A safe wrapper around `llama_context_params`.
///
/// Generally this should be created with [`Default::default()`] and then modified with `with_*` methods.
///
/// # Examples
///
/// ```rust
/// # use std::num::NonZeroU32;
/// use llama_cpp_bindings::context::params::LlamaContextParams;
///
/// let ctx_params = LlamaContextParams::default()
///     .with_n_ctx(NonZeroU32::new(2048));
///
/// assert_eq!(ctx_params.n_ctx(), NonZeroU32::new(2048));
/// ```
#[derive(Debug, Clone)]
#[expect(
    missing_docs,
    reason = "field meanings mirror llama.cpp's `llama_context_params` C struct; restating each \
              one inline would risk drift from the upstream spec — the doc-comment on the struct \
              points at the canonical reference"
)]
#[expect(
    clippy::module_name_repetitions,
    reason = "`LlamaContextParams` is the canonical Rust name in the public API; renaming it to \
              `Params` would force `params::Params` at every call site"
)]
pub struct LlamaContextParams {
    // Left without a doc comment on purpose: documenting it would leave the `missing_docs`
    // expectation above unfulfilled.
    pub context_params: llama_cpp_bindings_sys::llama_context_params,
}
286
// NOTE(review): the justification below appears to be out of date. `with_cb_eval_user_data`
// stores a caller-supplied `*mut c_void`, and `context_params` is a `pub` field, so the raw
// pointers inside the wrapped struct can be set and read from safe code. Confirm that sharing
// those pointers across threads is actually sound, or narrow the API, before relying on these
// impls.
/// SAFETY: we do not currently allow setting or reading the pointers that cause this to not be automatically send or sync.
unsafe impl Send for LlamaContextParams {}
// SAFETY: relies on the same argument (and carries the same caveat) as the `Send` impl.
unsafe impl Sync for LlamaContextParams {}
290
291impl LlamaContextParams {
292    /// Set the side of the context
293    ///
294    /// # Examples
295    ///
296    /// ```rust
297    /// # use std::num::NonZeroU32;
298    /// use llama_cpp_bindings::context::params::LlamaContextParams;
299    /// let params = LlamaContextParams::default();
300    /// let params = params.with_n_ctx(NonZeroU32::new(2048));
301    /// assert_eq!(params.n_ctx(), NonZeroU32::new(2048));
302    /// ```
303    #[must_use]
304    pub fn with_n_ctx(mut self, n_ctx: Option<NonZeroU32>) -> Self {
305        self.context_params.n_ctx = n_ctx.map_or(0, NonZeroU32::get);
306        self
307    }
308
309    /// Get the size of the context.
310    ///
311    /// [`None`] if the context size is specified by the model and not the context.
312    ///
313    /// # Examples
314    ///
315    /// ```rust
316    /// let params = llama_cpp_bindings::context::params::LlamaContextParams::default();
317    /// assert_eq!(params.n_ctx(), std::num::NonZeroU32::new(512));
318    #[must_use]
319    pub const fn n_ctx(&self) -> Option<NonZeroU32> {
320        NonZeroU32::new(self.context_params.n_ctx)
321    }
322
323    /// Set the `n_batch`
324    ///
325    /// # Examples
326    ///
327    /// ```rust
328    /// # use std::num::NonZeroU32;
329    /// use llama_cpp_bindings::context::params::LlamaContextParams;
330    /// let params = LlamaContextParams::default()
331    ///     .with_n_batch(2048);
332    /// assert_eq!(params.n_batch(), 2048);
333    /// ```
334    #[must_use]
335    pub const fn with_n_batch(mut self, n_batch: u32) -> Self {
336        self.context_params.n_batch = n_batch;
337        self
338    }
339
340    /// Get the `n_batch`
341    ///
342    /// # Examples
343    ///
344    /// ```rust
345    /// use llama_cpp_bindings::context::params::LlamaContextParams;
346    /// let params = LlamaContextParams::default();
347    /// assert_eq!(params.n_batch(), 2048);
348    /// ```
349    #[must_use]
350    pub const fn n_batch(&self) -> u32 {
351        self.context_params.n_batch
352    }
353
354    /// Set the `n_ubatch`
355    ///
356    /// # Examples
357    ///
358    /// ```rust
359    /// # use std::num::NonZeroU32;
360    /// use llama_cpp_bindings::context::params::LlamaContextParams;
361    /// let params = LlamaContextParams::default()
362    ///     .with_n_ubatch(512);
363    /// assert_eq!(params.n_ubatch(), 512);
364    /// ```
365    #[must_use]
366    pub const fn with_n_ubatch(mut self, n_ubatch: u32) -> Self {
367        self.context_params.n_ubatch = n_ubatch;
368        self
369    }
370
371    /// Get the `n_ubatch`
372    ///
373    /// # Examples
374    ///
375    /// ```rust
376    /// use llama_cpp_bindings::context::params::LlamaContextParams;
377    /// let params = LlamaContextParams::default();
378    /// assert_eq!(params.n_ubatch(), 512);
379    /// ```
380    #[must_use]
381    pub const fn n_ubatch(&self) -> u32 {
382        self.context_params.n_ubatch
383    }
384
385    /// Set the flash attention policy using llama.cpp enum
386    #[must_use]
387    pub const fn with_flash_attention_policy(
388        mut self,
389        policy: llama_cpp_bindings_sys::llama_flash_attn_type,
390    ) -> Self {
391        self.context_params.flash_attn_type = policy;
392        self
393    }
394
395    /// Get the flash attention policy
396    #[must_use]
397    pub const fn flash_attention_policy(&self) -> llama_cpp_bindings_sys::llama_flash_attn_type {
398        self.context_params.flash_attn_type
399    }
400
401    /// Set the `offload_kqv` parameter to control offloading KV cache & KQV ops to GPU
402    ///
403    /// # Examples
404    ///
405    /// ```rust
406    /// use llama_cpp_bindings::context::params::LlamaContextParams;
407    /// let params = LlamaContextParams::default()
408    ///     .with_offload_kqv(false);
409    /// assert_eq!(params.offload_kqv(), false);
410    /// ```
411    #[must_use]
412    pub const fn with_offload_kqv(mut self, enabled: bool) -> Self {
413        self.context_params.offload_kqv = enabled;
414        self
415    }
416
417    /// Get the `offload_kqv` parameter
418    ///
419    /// # Examples
420    ///
421    /// ```rust
422    /// use llama_cpp_bindings::context::params::LlamaContextParams;
423    /// let params = LlamaContextParams::default();
424    /// assert_eq!(params.offload_kqv(), true);
425    /// ```
426    #[must_use]
427    pub const fn offload_kqv(&self) -> bool {
428        self.context_params.offload_kqv
429    }
430
431    /// Set the type of rope scaling.
432    ///
433    /// # Examples
434    ///
435    /// ```rust
436    /// use llama_cpp_bindings::context::params::{LlamaContextParams, RopeScalingType};
437    /// let params = LlamaContextParams::default()
438    ///     .with_rope_scaling_type(RopeScalingType::Linear);
439    /// assert_eq!(params.rope_scaling_type(), RopeScalingType::Linear);
440    /// ```
441    #[must_use]
442    pub fn with_rope_scaling_type(mut self, rope_scaling_type: RopeScalingType) -> Self {
443        self.context_params.rope_scaling_type = i32::from(rope_scaling_type);
444        self
445    }
446
447    /// Get the type of rope scaling.
448    ///
449    /// # Examples
450    ///
451    /// ```rust
452    /// let params = llama_cpp_bindings::context::params::LlamaContextParams::default();
453    /// assert_eq!(params.rope_scaling_type(), llama_cpp_bindings::context::params::RopeScalingType::Unspecified);
454    /// ```
455    #[must_use]
456    pub fn rope_scaling_type(&self) -> RopeScalingType {
457        RopeScalingType::from(self.context_params.rope_scaling_type)
458    }
459
460    /// Set the rope frequency base.
461    ///
462    /// # Examples
463    ///
464    /// ```rust
465    /// use llama_cpp_bindings::context::params::LlamaContextParams;
466    /// let params = LlamaContextParams::default()
467    ///    .with_rope_freq_base(0.5);
468    /// assert_eq!(params.rope_freq_base(), 0.5);
469    /// ```
470    #[must_use]
471    pub const fn with_rope_freq_base(mut self, rope_freq_base: f32) -> Self {
472        self.context_params.rope_freq_base = rope_freq_base;
473        self
474    }
475
476    /// Get the rope frequency base.
477    ///
478    /// # Examples
479    ///
480    /// ```rust
481    /// let params = llama_cpp_bindings::context::params::LlamaContextParams::default();
482    /// assert_eq!(params.rope_freq_base(), 0.0);
483    /// ```
484    #[must_use]
485    pub const fn rope_freq_base(&self) -> f32 {
486        self.context_params.rope_freq_base
487    }
488
489    /// Set the rope frequency scale.
490    ///
491    /// # Examples
492    ///
493    /// ```rust
494    /// use llama_cpp_bindings::context::params::LlamaContextParams;
495    /// let params = LlamaContextParams::default()
496    ///   .with_rope_freq_scale(0.5);
497    /// assert_eq!(params.rope_freq_scale(), 0.5);
498    /// ```
499    #[must_use]
500    pub const fn with_rope_freq_scale(mut self, rope_freq_scale: f32) -> Self {
501        self.context_params.rope_freq_scale = rope_freq_scale;
502        self
503    }
504
505    /// Get the rope frequency scale.
506    ///
507    /// # Examples
508    ///
509    /// ```rust
510    /// let params = llama_cpp_bindings::context::params::LlamaContextParams::default();
511    /// assert_eq!(params.rope_freq_scale(), 0.0);
512    /// ```
513    #[must_use]
514    pub const fn rope_freq_scale(&self) -> f32 {
515        self.context_params.rope_freq_scale
516    }
517
518    /// Get the number of threads.
519    ///
520    /// # Examples
521    ///
522    /// ```rust
523    /// let params = llama_cpp_bindings::context::params::LlamaContextParams::default();
524    /// assert_eq!(params.n_threads(), 4);
525    /// ```
526    #[must_use]
527    pub const fn n_threads(&self) -> i32 {
528        self.context_params.n_threads
529    }
530
531    /// Get the number of threads allocated for batches.
532    ///
533    /// # Examples
534    ///
535    /// ```rust
536    /// let params = llama_cpp_bindings::context::params::LlamaContextParams::default();
537    /// assert_eq!(params.n_threads_batch(), 4);
538    /// ```
539    #[must_use]
540    pub const fn n_threads_batch(&self) -> i32 {
541        self.context_params.n_threads_batch
542    }
543
544    /// Set the number of threads.
545    ///
546    /// # Examples
547    ///
548    /// ```rust
549    /// use llama_cpp_bindings::context::params::LlamaContextParams;
550    /// let params = LlamaContextParams::default()
551    ///    .with_n_threads(8);
552    /// assert_eq!(params.n_threads(), 8);
553    /// ```
554    #[must_use]
555    pub const fn with_n_threads(mut self, n_threads: i32) -> Self {
556        self.context_params.n_threads = n_threads;
557        self
558    }
559
560    /// Set the number of threads allocated for batches.
561    ///
562    /// # Examples
563    ///
564    /// ```rust
565    /// use llama_cpp_bindings::context::params::LlamaContextParams;
566    /// let params = LlamaContextParams::default()
567    ///    .with_n_threads_batch(8);
568    /// assert_eq!(params.n_threads_batch(), 8);
569    /// ```
570    #[must_use]
571    pub const fn with_n_threads_batch(mut self, n_threads: i32) -> Self {
572        self.context_params.n_threads_batch = n_threads;
573        self
574    }
575
576    /// Check whether embeddings are enabled
577    ///
578    /// # Examples
579    ///
580    /// ```rust
581    /// let params = llama_cpp_bindings::context::params::LlamaContextParams::default();
582    /// assert!(!params.embeddings());
583    /// ```
584    #[must_use]
585    pub const fn embeddings(&self) -> bool {
586        self.context_params.embeddings
587    }
588
589    /// Enable the use of embeddings
590    ///
591    /// # Examples
592    ///
593    /// ```rust
594    /// use llama_cpp_bindings::context::params::LlamaContextParams;
595    /// let params = LlamaContextParams::default()
596    ///    .with_embeddings(true);
597    /// assert!(params.embeddings());
598    /// ```
599    #[must_use]
600    pub const fn with_embeddings(mut self, embedding: bool) -> Self {
601        self.context_params.embeddings = embedding;
602        self
603    }
604
605    /// Set the evaluation callback.
606    ///
607    /// # Examples
608    ///
609    /// ```no_run
610    /// extern "C" fn cb_eval_fn(
611    ///     t: *mut llama_cpp_bindings_sys::ggml_tensor,
612    ///     ask: bool,
613    ///     user_data: *mut std::ffi::c_void,
614    /// ) -> bool {
615    ///     false
616    /// }
617    ///
618    /// use llama_cpp_bindings::context::params::LlamaContextParams;
619    /// let params = LlamaContextParams::default().with_cb_eval(Some(cb_eval_fn));
620    /// ```
621    #[must_use]
622    pub fn with_cb_eval(
623        mut self,
624        cb_eval: llama_cpp_bindings_sys::ggml_backend_sched_eval_callback,
625    ) -> Self {
626        self.context_params.cb_eval = cb_eval;
627        self
628    }
629
630    /// Set the evaluation callback user data.
631    ///
632    /// # Examples
633    ///
634    /// ```no_run
635    /// use llama_cpp_bindings::context::params::LlamaContextParams;
636    /// let params = LlamaContextParams::default();
637    /// let user_data = std::ptr::null_mut();
638    /// let params = params.with_cb_eval_user_data(user_data);
639    /// ```
640    #[must_use]
641    pub const fn with_cb_eval_user_data(
642        mut self,
643        cb_eval_user_data: *mut std::ffi::c_void,
644    ) -> Self {
645        self.context_params.cb_eval_user_data = cb_eval_user_data;
646        self
647    }
648
649    /// Set the type of pooling.
650    ///
651    /// # Examples
652    ///
653    /// ```rust
654    /// use llama_cpp_bindings::context::params::{LlamaContextParams, LlamaPoolingType};
655    /// let params = LlamaContextParams::default()
656    ///     .with_pooling_type(LlamaPoolingType::Last);
657    /// assert_eq!(params.pooling_type(), LlamaPoolingType::Last);
658    /// ```
659    #[must_use]
660    pub fn with_pooling_type(mut self, pooling_type: LlamaPoolingType) -> Self {
661        self.context_params.pooling_type = i32::from(pooling_type);
662        self
663    }
664
665    /// Get the type of pooling.
666    ///
667    /// # Examples
668    ///
669    /// ```rust
670    /// let params = llama_cpp_bindings::context::params::LlamaContextParams::default();
671    /// assert_eq!(params.pooling_type(), llama_cpp_bindings::context::params::LlamaPoolingType::Unspecified);
672    /// ```
673    #[must_use]
674    pub fn pooling_type(&self) -> LlamaPoolingType {
675        LlamaPoolingType::from(self.context_params.pooling_type)
676    }
677
678    /// Set whether to use full sliding window attention
679    ///
680    /// # Examples
681    ///
682    /// ```rust
683    /// use llama_cpp_bindings::context::params::LlamaContextParams;
684    /// let params = LlamaContextParams::default()
685    ///     .with_swa_full(false);
686    /// assert_eq!(params.swa_full(), false);
687    /// ```
688    #[must_use]
689    pub const fn with_swa_full(mut self, enabled: bool) -> Self {
690        self.context_params.swa_full = enabled;
691        self
692    }
693
694    /// Get whether full sliding window attention is enabled
695    ///
696    /// # Examples
697    ///
698    /// ```rust
699    /// use llama_cpp_bindings::context::params::LlamaContextParams;
700    /// let params = LlamaContextParams::default();
701    /// assert_eq!(params.swa_full(), true);
702    /// ```
703    #[must_use]
704    pub const fn swa_full(&self) -> bool {
705        self.context_params.swa_full
706    }
707
708    /// Set the max number of sequences (i.e. distinct states for recurrent models)
709    ///
710    /// # Examples
711    ///
712    /// ```rust
713    /// use llama_cpp_bindings::context::params::LlamaContextParams;
714    /// let params = LlamaContextParams::default()
715    ///     .with_n_seq_max(64);
716    /// assert_eq!(params.n_seq_max(), 64);
717    /// ```
718    #[must_use]
719    pub const fn with_n_seq_max(mut self, n_seq_max: u32) -> Self {
720        self.context_params.n_seq_max = n_seq_max;
721        self
722    }
723
724    /// Get the max number of sequences (i.e. distinct states for recurrent models)
725    ///
726    /// # Examples
727    ///
728    /// ```rust
729    /// use llama_cpp_bindings::context::params::LlamaContextParams;
730    /// let params = LlamaContextParams::default();
731    /// assert_eq!(params.n_seq_max(), 1);
732    /// ```
733    #[must_use]
734    pub const fn n_seq_max(&self) -> u32 {
735        self.context_params.n_seq_max
736    }
737    /// Set the KV cache data type for K
738    /// use `llama_cpp_bindings::context::params::{LlamaContextParams`, `KvCacheType`};
739    /// let params = `LlamaContextParams::default().with_type_k(KvCacheType::Q4_0)`;
740    /// `assert_eq!(params.type_k()`, `KvCacheType::Q4_0`);
741    /// ```
742    #[must_use]
743    pub fn with_type_k(mut self, type_k: KvCacheType) -> Self {
744        self.context_params.type_k = type_k.into();
745        self
746    }
747
748    /// Get the KV cache data type for K
749    ///
750    /// # Examples
751    ///
752    /// ```rust
753    /// let params = llama_cpp_bindings::context::params::LlamaContextParams::default();
754    /// let _ = params.type_k();
755    /// ```
756    #[must_use]
757    pub fn type_k(&self) -> KvCacheType {
758        KvCacheType::from(self.context_params.type_k)
759    }
760
761    /// Set the KV cache data type for V
762    ///
763    /// # Examples
764    ///
765    /// ```rust
766    /// use llama_cpp_bindings::context::params::{LlamaContextParams, KvCacheType};
767    /// let params = LlamaContextParams::default().with_type_v(KvCacheType::Q4_1);
768    /// assert_eq!(params.type_v(), KvCacheType::Q4_1);
769    /// ```
770    #[must_use]
771    pub fn with_type_v(mut self, type_v: KvCacheType) -> Self {
772        self.context_params.type_v = type_v.into();
773        self
774    }
775
776    /// Get the KV cache data type for V
777    ///
778    /// # Examples
779    ///
780    /// ```rust
781    /// let params = llama_cpp_bindings::context::params::LlamaContextParams::default();
782    /// let _ = params.type_v();
783    /// ```
784    #[must_use]
785    pub fn type_v(&self) -> KvCacheType {
786        KvCacheType::from(self.context_params.type_v)
787    }
788
789    /// Set the attention type
790    ///
791    /// # Examples
792    ///
793    /// ```rust
794    /// use llama_cpp_bindings::context::params::{LlamaContextParams, LlamaAttentionType};
795    /// let params = LlamaContextParams::default()
796    ///     .with_attention_type(LlamaAttentionType::NonCausal);
797    /// assert_eq!(params.attention_type(), LlamaAttentionType::NonCausal);
798    /// ```
799    #[must_use]
800    pub fn with_attention_type(mut self, attention_type: LlamaAttentionType) -> Self {
801        self.context_params.attention_type = i32::from(attention_type);
802        self
803    }
804
805    /// Get the attention type
806    ///
807    /// # Examples
808    ///
809    /// ```rust
810    /// let params = llama_cpp_bindings::context::params::LlamaContextParams::default();
811    /// assert_eq!(params.attention_type(), llama_cpp_bindings::context::params::LlamaAttentionType::Unspecified);
812    /// ```
813    #[must_use]
814    pub fn attention_type(&self) -> LlamaAttentionType {
815        LlamaAttentionType::from(self.context_params.attention_type)
816    }
817
818    /// Set the `YaRN` extrapolation factor
819    ///
820    /// # Examples
821    ///
822    /// ```rust
823    /// use llama_cpp_bindings::context::params::LlamaContextParams;
824    /// let params = LlamaContextParams::default()
825    ///     .with_yarn_ext_factor(1.0);
826    /// assert!((params.yarn_ext_factor() - 1.0).abs() < f32::EPSILON);
827    /// ```
828    #[must_use]
829    pub const fn with_yarn_ext_factor(mut self, yarn_ext_factor: f32) -> Self {
830        self.context_params.yarn_ext_factor = yarn_ext_factor;
831        self
832    }
833
834    /// Get the `YaRN` extrapolation factor
835    #[must_use]
836    pub const fn yarn_ext_factor(&self) -> f32 {
837        self.context_params.yarn_ext_factor
838    }
839
840    /// Set the `YaRN` attention factor
841    ///
842    /// # Examples
843    ///
844    /// ```rust
845    /// use llama_cpp_bindings::context::params::LlamaContextParams;
846    /// let params = LlamaContextParams::default()
847    ///     .with_yarn_attn_factor(2.0);
848    /// assert!((params.yarn_attn_factor() - 2.0).abs() < f32::EPSILON);
849    /// ```
850    #[must_use]
851    pub const fn with_yarn_attn_factor(mut self, yarn_attn_factor: f32) -> Self {
852        self.context_params.yarn_attn_factor = yarn_attn_factor;
853        self
854    }
855
856    /// Get the `YaRN` attention factor
857    #[must_use]
858    pub const fn yarn_attn_factor(&self) -> f32 {
859        self.context_params.yarn_attn_factor
860    }
861
862    /// Set the `YaRN` low correction dim
863    ///
864    /// # Examples
865    ///
866    /// ```rust
867    /// use llama_cpp_bindings::context::params::LlamaContextParams;
868    /// let params = LlamaContextParams::default()
869    ///     .with_yarn_beta_fast(32.0);
870    /// assert!((params.yarn_beta_fast() - 32.0).abs() < f32::EPSILON);
871    /// ```
872    #[must_use]
873    pub const fn with_yarn_beta_fast(mut self, yarn_beta_fast: f32) -> Self {
874        self.context_params.yarn_beta_fast = yarn_beta_fast;
875        self
876    }
877
878    /// Get the `YaRN` low correction dim
879    #[must_use]
880    pub const fn yarn_beta_fast(&self) -> f32 {
881        self.context_params.yarn_beta_fast
882    }
883
884    /// Set the `YaRN` high correction dim
885    ///
886    /// # Examples
887    ///
888    /// ```rust
889    /// use llama_cpp_bindings::context::params::LlamaContextParams;
890    /// let params = LlamaContextParams::default()
891    ///     .with_yarn_beta_slow(1.0);
892    /// assert!((params.yarn_beta_slow() - 1.0).abs() < f32::EPSILON);
893    /// ```
894    #[must_use]
895    pub const fn with_yarn_beta_slow(mut self, yarn_beta_slow: f32) -> Self {
896        self.context_params.yarn_beta_slow = yarn_beta_slow;
897        self
898    }
899
900    /// Get the `YaRN` high correction dim
901    #[must_use]
902    pub const fn yarn_beta_slow(&self) -> f32 {
903        self.context_params.yarn_beta_slow
904    }
905
906    /// Set the `YaRN` original context size
907    ///
908    /// # Examples
909    ///
910    /// ```rust
911    /// use llama_cpp_bindings::context::params::LlamaContextParams;
912    /// let params = LlamaContextParams::default()
913    ///     .with_yarn_orig_ctx(4096);
914    /// assert_eq!(params.yarn_orig_ctx(), 4096);
915    /// ```
916    #[must_use]
917    pub const fn with_yarn_orig_ctx(mut self, yarn_orig_ctx: u32) -> Self {
918        self.context_params.yarn_orig_ctx = yarn_orig_ctx;
919        self
920    }
921
922    /// Get the `YaRN` original context size
923    #[must_use]
924    pub const fn yarn_orig_ctx(&self) -> u32 {
925        self.context_params.yarn_orig_ctx
926    }
927
928    /// Set the KV cache defragmentation threshold
929    ///
930    /// # Examples
931    ///
932    /// ```rust
933    /// use llama_cpp_bindings::context::params::LlamaContextParams;
934    /// let params = LlamaContextParams::default()
935    ///     .with_defrag_thold(0.1);
936    /// assert!((params.defrag_thold() - 0.1).abs() < f32::EPSILON);
937    /// ```
938    #[must_use]
939    pub const fn with_defrag_thold(mut self, defrag_thold: f32) -> Self {
940        self.context_params.defrag_thold = defrag_thold;
941        self
942    }
943
944    /// Get the KV cache defragmentation threshold
945    #[must_use]
946    pub const fn defrag_thold(&self) -> f32 {
947        self.context_params.defrag_thold
948    }
949
950    /// Set whether performance timings are disabled
951    ///
952    /// # Examples
953    ///
954    /// ```rust
955    /// use llama_cpp_bindings::context::params::LlamaContextParams;
956    /// let params = LlamaContextParams::default()
957    ///     .with_no_perf(true);
958    /// assert!(params.no_perf());
959    /// ```
960    #[must_use]
961    pub const fn with_no_perf(mut self, no_perf: bool) -> Self {
962        self.context_params.no_perf = no_perf;
963        self
964    }
965
966    /// Get whether performance timings are disabled
967    #[must_use]
968    pub const fn no_perf(&self) -> bool {
969        self.context_params.no_perf
970    }
971
972    /// Set whether to offload ops to GPU
973    ///
974    /// # Examples
975    ///
976    /// ```rust
977    /// use llama_cpp_bindings::context::params::LlamaContextParams;
978    /// let params = LlamaContextParams::default()
979    ///     .with_op_offload(false);
980    /// assert!(!params.op_offload());
981    /// ```
982    #[must_use]
983    pub const fn with_op_offload(mut self, op_offload: bool) -> Self {
984        self.context_params.op_offload = op_offload;
985        self
986    }
987
988    /// Get whether ops are offloaded to GPU
989    #[must_use]
990    pub const fn op_offload(&self) -> bool {
991        self.context_params.op_offload
992    }
993
994    /// Set whether to use a unified KV cache buffer across input sequences
995    ///
996    /// # Examples
997    ///
998    /// ```rust
999    /// use llama_cpp_bindings::context::params::LlamaContextParams;
1000    /// let params = LlamaContextParams::default()
1001    ///     .with_kv_unified(true);
1002    /// assert!(params.kv_unified());
1003    /// ```
1004    #[must_use]
1005    pub const fn with_kv_unified(mut self, kv_unified: bool) -> Self {
1006        self.context_params.kv_unified = kv_unified;
1007        self
1008    }
1009
1010    /// Get whether a unified KV cache buffer is used across input sequences
1011    #[must_use]
1012    pub const fn kv_unified(&self) -> bool {
1013        self.context_params.kv_unified
1014    }
1015}
1016
1017/// Default parameters for `LlamaContext`. (as defined in llama.cpp by `llama_context_default_params`)
1018/// ```
1019/// # use std::num::NonZeroU32;
1020/// use llama_cpp_bindings::context::params::{LlamaContextParams, RopeScalingType};
1021/// let params = LlamaContextParams::default();
1022/// assert_eq!(params.n_ctx(), NonZeroU32::new(512), "n_ctx should be 512");
1023/// assert_eq!(params.rope_scaling_type(), RopeScalingType::Unspecified);
1024/// ```
1025impl Default for LlamaContextParams {
1026    fn default() -> Self {
1027        let context_params = unsafe { llama_cpp_bindings_sys::llama_context_default_params() };
1028        Self { context_params }
1029    }
1030}
1031
#[cfg(test)]
mod tests {
    use std::ffi::c_void;
    use std::num::NonZeroU32;
    use std::ptr;

    use super::{
        KvCacheType, LlamaAttentionType, LlamaContextParams, LlamaPoolingType, RopeScalingType,
    };

    /// Fresh params from `Default`; the starting point for every builder test below.
    fn defaults() -> LlamaContextParams {
        LlamaContextParams::default()
    }

    // ---- conversions from unrecognised raw values ----

    #[test]
    fn rope_scaling_type_unknown_defaults_to_unspecified() {
        for unrecognised in [99, -100] {
            assert_eq!(
                RopeScalingType::from(unrecognised),
                RopeScalingType::Unspecified
            );
        }
    }

    #[test]
    fn pooling_type_unknown_defaults_to_unspecified() {
        for unrecognised in [99, -50] {
            assert_eq!(
                LlamaPoolingType::from(unrecognised),
                LlamaPoolingType::Unspecified
            );
        }
    }

    #[test]
    fn kv_cache_type_unknown_preserves_raw_value() {
        let raw_in: llama_cpp_bindings_sys::ggml_type = 99999;

        let wrapped = KvCacheType::from(raw_in);
        assert_eq!(wrapped, KvCacheType::Unknown(99999));

        let raw_out: llama_cpp_bindings_sys::ggml_type = wrapped.into();
        assert_eq!(raw_out, 99999);
    }

    // ---- defaults and simple builder setters ----

    #[test]
    fn default_params_have_expected_values() {
        let fresh = defaults();

        assert_eq!(fresh.n_ctx(), NonZeroU32::new(512));
        assert_eq!(fresh.n_batch(), 2048);
        assert_eq!(fresh.n_ubatch(), 512);
        assert_eq!(fresh.rope_scaling_type(), RopeScalingType::Unspecified);
        assert_eq!(fresh.pooling_type(), LlamaPoolingType::Unspecified);
    }

    #[test]
    fn with_n_ctx_sets_value() {
        let requested = NonZeroU32::new(2048);

        assert_eq!(defaults().with_n_ctx(requested).n_ctx(), requested);
    }

    #[test]
    fn with_n_ctx_none_sets_zero() {
        assert_eq!(defaults().with_n_ctx(None).n_ctx(), None);
    }

    #[test]
    fn with_n_batch_sets_value() {
        assert_eq!(defaults().with_n_batch(4096).n_batch(), 4096);
    }

    #[test]
    fn with_n_ubatch_sets_value() {
        assert_eq!(defaults().with_n_ubatch(1024).n_ubatch(), 1024);
    }

    #[test]
    fn with_n_seq_max_sets_value() {
        assert_eq!(defaults().with_n_seq_max(64).n_seq_max(), 64);
    }

    #[test]
    fn with_embeddings_enables() {
        assert!(defaults().with_embeddings(true).embeddings());
    }

    #[test]
    fn with_embeddings_disables() {
        assert!(!defaults().with_embeddings(false).embeddings());
    }

    #[test]
    fn with_offload_kqv_disables() {
        assert!(!defaults().with_offload_kqv(false).offload_kqv());
    }

    #[test]
    fn with_offload_kqv_enables() {
        assert!(defaults().with_offload_kqv(true).offload_kqv());
    }

    #[test]
    fn with_swa_full_disables() {
        assert!(!defaults().with_swa_full(false).swa_full());
    }

    #[test]
    fn with_swa_full_enables() {
        assert!(defaults().with_swa_full(true).swa_full());
    }

    // ---- rope settings ----

    #[test]
    fn with_rope_scaling_type_linear() {
        let scaling = defaults()
            .with_rope_scaling_type(RopeScalingType::Linear)
            .rope_scaling_type();

        assert_eq!(scaling, RopeScalingType::Linear);
    }

    #[test]
    fn with_rope_scaling_type_yarn() {
        let scaling = defaults()
            .with_rope_scaling_type(RopeScalingType::Yarn)
            .rope_scaling_type();

        assert_eq!(scaling, RopeScalingType::Yarn);
    }

    #[test]
    fn with_rope_scaling_type_none() {
        let scaling = defaults()
            .with_rope_scaling_type(RopeScalingType::None)
            .rope_scaling_type();

        assert_eq!(scaling, RopeScalingType::None);
    }

    #[test]
    fn with_rope_freq_base_sets_value() {
        let freq_base = defaults().with_rope_freq_base(10000.0).rope_freq_base();

        assert!((freq_base - 10000.0).abs() < f32::EPSILON);
    }

    #[test]
    fn with_rope_freq_scale_sets_value() {
        let freq_scale = defaults().with_rope_freq_scale(0.5).rope_freq_scale();

        assert!((freq_scale - 0.5).abs() < f32::EPSILON);
    }

    // ---- threading ----

    #[test]
    fn with_n_threads_sets_value() {
        assert_eq!(defaults().with_n_threads(16).n_threads(), 16);
    }

    #[test]
    fn with_n_threads_batch_sets_value() {
        assert_eq!(defaults().with_n_threads_batch(16).n_threads_batch(), 16);
    }

    // ---- pooling ----

    #[test]
    fn with_pooling_type_mean() {
        let pooling = defaults()
            .with_pooling_type(LlamaPoolingType::Mean)
            .pooling_type();

        assert_eq!(pooling, LlamaPoolingType::Mean);
    }

    #[test]
    fn with_pooling_type_cls() {
        let pooling = defaults()
            .with_pooling_type(LlamaPoolingType::Cls)
            .pooling_type();

        assert_eq!(pooling, LlamaPoolingType::Cls);
    }

    #[test]
    fn with_pooling_type_last() {
        let pooling = defaults()
            .with_pooling_type(LlamaPoolingType::Last)
            .pooling_type();

        assert_eq!(pooling, LlamaPoolingType::Last);
    }

    #[test]
    fn with_pooling_type_rank() {
        let pooling = defaults()
            .with_pooling_type(LlamaPoolingType::Rank)
            .pooling_type();

        assert_eq!(pooling, LlamaPoolingType::Rank);
    }

    #[test]
    fn with_pooling_type_none() {
        let pooling = defaults()
            .with_pooling_type(LlamaPoolingType::None)
            .pooling_type();

        assert_eq!(pooling, LlamaPoolingType::None);
    }

    // ---- KV cache types and flash attention ----

    #[test]
    fn with_type_k_sets_value() {
        assert_eq!(
            defaults().with_type_k(KvCacheType::Q4_0).type_k(),
            KvCacheType::Q4_0
        );
    }

    #[test]
    fn with_type_v_sets_value() {
        assert_eq!(
            defaults().with_type_v(KvCacheType::Q4_1).type_v(),
            KvCacheType::Q4_1
        );
    }

    #[test]
    fn with_flash_attention_policy_sets_value() {
        let policy = defaults()
            .with_flash_attention_policy(llama_cpp_bindings_sys::LLAMA_FLASH_ATTN_TYPE_ENABLED)
            .flash_attention_policy();

        assert_eq!(policy, llama_cpp_bindings_sys::LLAMA_FLASH_ATTN_TYPE_ENABLED);
    }

    #[test]
    fn builder_chaining_preserves_all_values() {
        let configured = defaults()
            .with_n_ctx(NonZeroU32::new(1024))
            .with_n_batch(4096)
            .with_n_ubatch(256)
            .with_n_threads(8)
            .with_n_threads_batch(12)
            .with_embeddings(true)
            .with_offload_kqv(false)
            .with_rope_scaling_type(RopeScalingType::Yarn)
            .with_rope_freq_base(5000.0)
            .with_rope_freq_scale(0.25);

        // Every setting applied above must still be readable after the whole chain.
        assert_eq!(configured.n_ctx(), NonZeroU32::new(1024));
        assert_eq!(configured.n_batch(), 4096);
        assert_eq!(configured.n_ubatch(), 256);
        assert_eq!(configured.n_threads(), 8);
        assert_eq!(configured.n_threads_batch(), 12);
        assert!(configured.embeddings());
        assert!(!configured.offload_kqv());
        assert_eq!(configured.rope_scaling_type(), RopeScalingType::Yarn);
        assert!((configured.rope_freq_base() - 5000.0).abs() < f32::EPSILON);
        assert!((configured.rope_freq_scale() - 0.25).abs() < f32::EPSILON);
    }

    // ---- raw <-> enum round trips ----

    #[test]
    fn rope_scaling_type_roundtrip_all_variants() {
        let cases = [
            (-1, RopeScalingType::Unspecified),
            (0, RopeScalingType::None),
            (1, RopeScalingType::Linear),
            (2, RopeScalingType::Yarn),
        ];

        for (raw, variant) in cases {
            let decoded = RopeScalingType::from(raw);
            assert_eq!(decoded, variant);

            let encoded: i32 = decoded.into();
            assert_eq!(encoded, raw);
        }
    }

    #[test]
    fn pooling_type_roundtrip_all_variants() {
        let cases = [
            (-1, LlamaPoolingType::Unspecified),
            (0, LlamaPoolingType::None),
            (1, LlamaPoolingType::Mean),
            (2, LlamaPoolingType::Cls),
            (3, LlamaPoolingType::Last),
            (4, LlamaPoolingType::Rank),
        ];

        for (raw, variant) in cases {
            let decoded = LlamaPoolingType::from(raw);
            assert_eq!(decoded, variant);

            let encoded: i32 = decoded.into();
            assert_eq!(encoded, raw);
        }
    }

    #[test]
    fn kv_cache_type_all_known_variants_roundtrip() {
        let known = [
            KvCacheType::F32,
            KvCacheType::F16,
            KvCacheType::Q4_0,
            KvCacheType::Q4_1,
            KvCacheType::Q5_0,
            KvCacheType::Q5_1,
            KvCacheType::Q8_0,
            KvCacheType::Q8_1,
            KvCacheType::Q2_K,
            KvCacheType::Q3_K,
            KvCacheType::Q4_K,
            KvCacheType::Q5_K,
            KvCacheType::Q6_K,
            KvCacheType::Q8_K,
            KvCacheType::IQ2_XXS,
            KvCacheType::IQ2_XS,
            KvCacheType::IQ3_XXS,
            KvCacheType::IQ1_S,
            KvCacheType::IQ4_NL,
            KvCacheType::IQ3_S,
            KvCacheType::IQ2_S,
            KvCacheType::IQ4_XS,
            KvCacheType::I8,
            KvCacheType::I16,
            KvCacheType::I32,
            KvCacheType::I64,
            KvCacheType::F64,
            KvCacheType::IQ1_M,
            KvCacheType::BF16,
            KvCacheType::TQ1_0,
            KvCacheType::TQ2_0,
            KvCacheType::MXFP4,
        ];

        for cache_type in known {
            let raw: llama_cpp_bindings_sys::ggml_type = cache_type.into();

            assert_eq!(KvCacheType::from(raw), cache_type);
        }
    }

    // ---- eval callback plumbing ----

    #[test]
    fn with_cb_eval_sets_callback() {
        extern "C" fn test_cb_eval(
            _tensor: *mut llama_cpp_bindings_sys::ggml_tensor,
            _ask: bool,
            _user_data: *mut c_void,
        ) -> bool {
            false
        }

        // Invoke the stub once directly (presumably so its body is exercised; nothing else in
        // this test calls it).
        let stub_result = test_cb_eval(ptr::null_mut(), false, ptr::null_mut());
        assert!(!stub_result);

        let configured = defaults().with_cb_eval(Some(test_cb_eval));
        assert!(configured.context_params.cb_eval.is_some());
    }

    #[test]
    fn with_cb_eval_user_data_sets_pointer() {
        let mut payload: i32 = 42;
        let payload_ptr = (&raw mut payload).cast::<c_void>();

        let configured = defaults().with_cb_eval_user_data(payload_ptr);

        assert_eq!(configured.context_params.cb_eval_user_data, payload_ptr);
    }

    #[test]
    fn with_flash_attention_policy_disabled() {
        let policy = defaults()
            .with_flash_attention_policy(llama_cpp_bindings_sys::LLAMA_FLASH_ATTN_TYPE_DISABLED)
            .flash_attention_policy();

        assert_eq!(policy, llama_cpp_bindings_sys::LLAMA_FLASH_ATTN_TYPE_DISABLED);
    }

    // ---- attention type ----

    #[test]
    fn attention_type_unknown_defaults_to_unspecified() {
        for unrecognised in [99, -50] {
            assert_eq!(
                LlamaAttentionType::from(unrecognised),
                LlamaAttentionType::Unspecified
            );
        }
    }

    #[test]
    fn attention_type_roundtrip_all_variants() {
        let cases = [
            (-1, LlamaAttentionType::Unspecified),
            (0, LlamaAttentionType::Causal),
            (1, LlamaAttentionType::NonCausal),
        ];

        for (raw, variant) in cases {
            let decoded = LlamaAttentionType::from(raw);
            assert_eq!(decoded, variant);

            let encoded: i32 = decoded.into();
            assert_eq!(encoded, raw);
        }
    }

    #[test]
    fn with_attention_type_causal() {
        let attention = defaults()
            .with_attention_type(LlamaAttentionType::Causal)
            .attention_type();

        assert_eq!(attention, LlamaAttentionType::Causal);
    }

    #[test]
    fn with_attention_type_non_causal() {
        let attention = defaults()
            .with_attention_type(LlamaAttentionType::NonCausal)
            .attention_type();

        assert_eq!(attention, LlamaAttentionType::NonCausal);
    }

    // ---- YaRN parameters ----

    #[test]
    fn with_yarn_ext_factor_sets_value() {
        let ext_factor = defaults().with_yarn_ext_factor(1.5).yarn_ext_factor();

        assert!((ext_factor - 1.5).abs() < f32::EPSILON);
    }

    #[test]
    fn with_yarn_attn_factor_sets_value() {
        let attn_factor = defaults().with_yarn_attn_factor(2.0).yarn_attn_factor();

        assert!((attn_factor - 2.0).abs() < f32::EPSILON);
    }

    #[test]
    fn with_yarn_beta_fast_sets_value() {
        let beta_fast = defaults().with_yarn_beta_fast(32.0).yarn_beta_fast();

        assert!((beta_fast - 32.0).abs() < f32::EPSILON);
    }

    #[test]
    fn with_yarn_beta_slow_sets_value() {
        let beta_slow = defaults().with_yarn_beta_slow(1.0).yarn_beta_slow();

        assert!((beta_slow - 1.0).abs() < f32::EPSILON);
    }

    #[test]
    fn with_yarn_orig_ctx_sets_value() {
        assert_eq!(defaults().with_yarn_orig_ctx(4096).yarn_orig_ctx(), 4096);
    }

    // ---- remaining scalar and boolean switches ----

    #[test]
    fn with_defrag_thold_sets_value() {
        let threshold = defaults().with_defrag_thold(0.1).defrag_thold();

        assert!((threshold - 0.1).abs() < f32::EPSILON);
    }

    #[test]
    fn with_no_perf_enables() {
        assert!(defaults().with_no_perf(true).no_perf());
    }

    #[test]
    fn with_no_perf_disables() {
        assert!(!defaults().with_no_perf(false).no_perf());
    }

    #[test]
    fn with_op_offload_enables() {
        assert!(defaults().with_op_offload(true).op_offload());
    }

    #[test]
    fn with_op_offload_disables() {
        assert!(!defaults().with_op_offload(false).op_offload());
    }

    #[test]
    fn with_kv_unified_enables() {
        assert!(defaults().with_kv_unified(true).kv_unified());
    }

    #[test]
    fn with_kv_unified_disables() {
        assert!(!defaults().with_kv_unified(false).kv_unified());
    }
}