Skip to main content

llama_cpp_bindings/context/
params.rs

1//! A safe wrapper around `llama_context_params`.
2use std::fmt::Debug;
3use std::num::NonZeroU32;
4
5pub use crate::context::kv_cache_type::KvCacheType;
6pub use crate::context::llama_attention_type::LlamaAttentionType;
7pub use crate::context::llama_pooling_type::LlamaPoolingType;
8pub use crate::context::rope_scaling_type::RopeScalingType;
9
10/// A safe wrapper around `llama_context_params`.
11///
12/// Generally this should be created with [`Default::default()`] and then modified with `with_*` methods.
13///
14/// # Examples
15///
16/// ```rust
17/// # use std::num::NonZeroU32;
18/// use llama_cpp_bindings::context::params::LlamaContextParams;
19///
20///let ctx_params = LlamaContextParams::default()
21///    .with_n_ctx(NonZeroU32::new(2048));
22///
23/// assert_eq!(ctx_params.n_ctx(), NonZeroU32::new(2048));
24/// ```
25#[derive(Debug, Clone)]
26#[expect(
27    missing_docs,
28    reason = "field meanings mirror llama.cpp's `llama_context_params` C struct; restating each \
29              one inline would risk drift from the upstream spec — the doc-comment on the struct \
30              points at the canonical reference"
31)]
32#[expect(
33    clippy::module_name_repetitions,
34    reason = "`LlamaContextParams` is the canonical Rust name in the public API; renaming it to \
35              `Params` would force `params::Params` at every call site"
36)]
37pub struct LlamaContextParams {
38    pub context_params: llama_cpp_bindings_sys::llama_context_params,
39}
40
41/// SAFETY: we do not currently allow setting or reading the pointers that cause this to not be automatically send or sync.
42unsafe impl Send for LlamaContextParams {}
43unsafe impl Sync for LlamaContextParams {}
44
45impl LlamaContextParams {
46    /// Set the side of the context
47    ///
48    /// # Examples
49    ///
50    /// ```rust
51    /// # use std::num::NonZeroU32;
52    /// use llama_cpp_bindings::context::params::LlamaContextParams;
53    /// let params = LlamaContextParams::default();
54    /// let params = params.with_n_ctx(NonZeroU32::new(2048));
55    /// assert_eq!(params.n_ctx(), NonZeroU32::new(2048));
56    /// ```
57    #[must_use]
58    pub fn with_n_ctx(mut self, n_ctx: Option<NonZeroU32>) -> Self {
59        self.context_params.n_ctx = n_ctx.map_or(0, NonZeroU32::get);
60        self
61    }
62
63    /// Get the size of the context.
64    ///
65    /// [`None`] if the context size is specified by the model and not the context.
66    ///
67    /// # Examples
68    ///
69    /// ```rust
70    /// let params = llama_cpp_bindings::context::params::LlamaContextParams::default();
71    /// assert_eq!(params.n_ctx(), std::num::NonZeroU32::new(512));
72    #[must_use]
73    pub const fn n_ctx(&self) -> Option<NonZeroU32> {
74        NonZeroU32::new(self.context_params.n_ctx)
75    }
76
77    /// Set the `n_batch`
78    ///
79    /// # Examples
80    ///
81    /// ```rust
82    /// # use std::num::NonZeroU32;
83    /// use llama_cpp_bindings::context::params::LlamaContextParams;
84    /// let params = LlamaContextParams::default()
85    ///     .with_n_batch(2048);
86    /// assert_eq!(params.n_batch(), 2048);
87    /// ```
88    #[must_use]
89    pub const fn with_n_batch(mut self, n_batch: u32) -> Self {
90        self.context_params.n_batch = n_batch;
91        self
92    }
93
94    /// Get the `n_batch`
95    ///
96    /// # Examples
97    ///
98    /// ```rust
99    /// use llama_cpp_bindings::context::params::LlamaContextParams;
100    /// let params = LlamaContextParams::default();
101    /// assert_eq!(params.n_batch(), 2048);
102    /// ```
103    #[must_use]
104    pub const fn n_batch(&self) -> u32 {
105        self.context_params.n_batch
106    }
107
108    /// Set the `n_ubatch`
109    ///
110    /// # Examples
111    ///
112    /// ```rust
113    /// # use std::num::NonZeroU32;
114    /// use llama_cpp_bindings::context::params::LlamaContextParams;
115    /// let params = LlamaContextParams::default()
116    ///     .with_n_ubatch(512);
117    /// assert_eq!(params.n_ubatch(), 512);
118    /// ```
119    #[must_use]
120    pub const fn with_n_ubatch(mut self, n_ubatch: u32) -> Self {
121        self.context_params.n_ubatch = n_ubatch;
122        self
123    }
124
125    /// Get the `n_ubatch`
126    ///
127    /// # Examples
128    ///
129    /// ```rust
130    /// use llama_cpp_bindings::context::params::LlamaContextParams;
131    /// let params = LlamaContextParams::default();
132    /// assert_eq!(params.n_ubatch(), 512);
133    /// ```
134    #[must_use]
135    pub const fn n_ubatch(&self) -> u32 {
136        self.context_params.n_ubatch
137    }
138
139    /// Set the flash attention policy using llama.cpp enum
140    #[must_use]
141    pub const fn with_flash_attention_policy(
142        mut self,
143        policy: llama_cpp_bindings_sys::llama_flash_attn_type,
144    ) -> Self {
145        self.context_params.flash_attn_type = policy;
146        self
147    }
148
149    /// Get the flash attention policy
150    #[must_use]
151    pub const fn flash_attention_policy(&self) -> llama_cpp_bindings_sys::llama_flash_attn_type {
152        self.context_params.flash_attn_type
153    }
154
155    /// Set the `offload_kqv` parameter to control offloading KV cache & KQV ops to GPU
156    ///
157    /// # Examples
158    ///
159    /// ```rust
160    /// use llama_cpp_bindings::context::params::LlamaContextParams;
161    /// let params = LlamaContextParams::default()
162    ///     .with_offload_kqv(false);
163    /// assert_eq!(params.offload_kqv(), false);
164    /// ```
165    #[must_use]
166    pub const fn with_offload_kqv(mut self, enabled: bool) -> Self {
167        self.context_params.offload_kqv = enabled;
168        self
169    }
170
171    /// Get the `offload_kqv` parameter
172    ///
173    /// # Examples
174    ///
175    /// ```rust
176    /// use llama_cpp_bindings::context::params::LlamaContextParams;
177    /// let params = LlamaContextParams::default();
178    /// assert_eq!(params.offload_kqv(), true);
179    /// ```
180    #[must_use]
181    pub const fn offload_kqv(&self) -> bool {
182        self.context_params.offload_kqv
183    }
184
185    /// Set the type of rope scaling.
186    ///
187    /// # Examples
188    ///
189    /// ```rust
190    /// use llama_cpp_bindings::context::params::{LlamaContextParams, RopeScalingType};
191    /// let params = LlamaContextParams::default()
192    ///     .with_rope_scaling_type(RopeScalingType::Linear);
193    /// assert_eq!(params.rope_scaling_type(), RopeScalingType::Linear);
194    /// ```
195    #[must_use]
196    pub fn with_rope_scaling_type(mut self, rope_scaling_type: RopeScalingType) -> Self {
197        self.context_params.rope_scaling_type = i32::from(rope_scaling_type);
198        self
199    }
200
201    /// Get the type of rope scaling.
202    ///
203    /// # Examples
204    ///
205    /// ```rust
206    /// let params = llama_cpp_bindings::context::params::LlamaContextParams::default();
207    /// assert_eq!(params.rope_scaling_type(), llama_cpp_bindings::context::params::RopeScalingType::Unspecified);
208    /// ```
209    #[must_use]
210    pub fn rope_scaling_type(&self) -> RopeScalingType {
211        RopeScalingType::from(self.context_params.rope_scaling_type)
212    }
213
214    /// Set the rope frequency base.
215    ///
216    /// # Examples
217    ///
218    /// ```rust
219    /// use llama_cpp_bindings::context::params::LlamaContextParams;
220    /// let params = LlamaContextParams::default()
221    ///    .with_rope_freq_base(0.5);
222    /// assert_eq!(params.rope_freq_base(), 0.5);
223    /// ```
224    #[must_use]
225    pub const fn with_rope_freq_base(mut self, rope_freq_base: f32) -> Self {
226        self.context_params.rope_freq_base = rope_freq_base;
227        self
228    }
229
230    /// Get the rope frequency base.
231    ///
232    /// # Examples
233    ///
234    /// ```rust
235    /// let params = llama_cpp_bindings::context::params::LlamaContextParams::default();
236    /// assert_eq!(params.rope_freq_base(), 0.0);
237    /// ```
238    #[must_use]
239    pub const fn rope_freq_base(&self) -> f32 {
240        self.context_params.rope_freq_base
241    }
242
243    /// Set the rope frequency scale.
244    ///
245    /// # Examples
246    ///
247    /// ```rust
248    /// use llama_cpp_bindings::context::params::LlamaContextParams;
249    /// let params = LlamaContextParams::default()
250    ///   .with_rope_freq_scale(0.5);
251    /// assert_eq!(params.rope_freq_scale(), 0.5);
252    /// ```
253    #[must_use]
254    pub const fn with_rope_freq_scale(mut self, rope_freq_scale: f32) -> Self {
255        self.context_params.rope_freq_scale = rope_freq_scale;
256        self
257    }
258
259    /// Get the rope frequency scale.
260    ///
261    /// # Examples
262    ///
263    /// ```rust
264    /// let params = llama_cpp_bindings::context::params::LlamaContextParams::default();
265    /// assert_eq!(params.rope_freq_scale(), 0.0);
266    /// ```
267    #[must_use]
268    pub const fn rope_freq_scale(&self) -> f32 {
269        self.context_params.rope_freq_scale
270    }
271
272    /// Get the number of threads.
273    ///
274    /// # Examples
275    ///
276    /// ```rust
277    /// let params = llama_cpp_bindings::context::params::LlamaContextParams::default();
278    /// assert_eq!(params.n_threads(), 4);
279    /// ```
280    #[must_use]
281    pub const fn n_threads(&self) -> i32 {
282        self.context_params.n_threads
283    }
284
285    /// Get the number of threads allocated for batches.
286    ///
287    /// # Examples
288    ///
289    /// ```rust
290    /// let params = llama_cpp_bindings::context::params::LlamaContextParams::default();
291    /// assert_eq!(params.n_threads_batch(), 4);
292    /// ```
293    #[must_use]
294    pub const fn n_threads_batch(&self) -> i32 {
295        self.context_params.n_threads_batch
296    }
297
298    /// Set the number of threads.
299    ///
300    /// # Examples
301    ///
302    /// ```rust
303    /// use llama_cpp_bindings::context::params::LlamaContextParams;
304    /// let params = LlamaContextParams::default()
305    ///    .with_n_threads(8);
306    /// assert_eq!(params.n_threads(), 8);
307    /// ```
308    #[must_use]
309    pub const fn with_n_threads(mut self, n_threads: i32) -> Self {
310        self.context_params.n_threads = n_threads;
311        self
312    }
313
314    /// Set the number of threads allocated for batches.
315    ///
316    /// # Examples
317    ///
318    /// ```rust
319    /// use llama_cpp_bindings::context::params::LlamaContextParams;
320    /// let params = LlamaContextParams::default()
321    ///    .with_n_threads_batch(8);
322    /// assert_eq!(params.n_threads_batch(), 8);
323    /// ```
324    #[must_use]
325    pub const fn with_n_threads_batch(mut self, n_threads: i32) -> Self {
326        self.context_params.n_threads_batch = n_threads;
327        self
328    }
329
330    /// Check whether embeddings are enabled
331    ///
332    /// # Examples
333    ///
334    /// ```rust
335    /// let params = llama_cpp_bindings::context::params::LlamaContextParams::default();
336    /// assert!(!params.embeddings());
337    /// ```
338    #[must_use]
339    pub const fn embeddings(&self) -> bool {
340        self.context_params.embeddings
341    }
342
343    /// Enable the use of embeddings
344    ///
345    /// # Examples
346    ///
347    /// ```rust
348    /// use llama_cpp_bindings::context::params::LlamaContextParams;
349    /// let params = LlamaContextParams::default()
350    ///    .with_embeddings(true);
351    /// assert!(params.embeddings());
352    /// ```
353    #[must_use]
354    pub const fn with_embeddings(mut self, embedding: bool) -> Self {
355        self.context_params.embeddings = embedding;
356        self
357    }
358
359    /// Set the evaluation callback.
360    ///
361    /// # Examples
362    ///
363    /// ```no_run
364    /// extern "C" fn cb_eval_fn(
365    ///     t: *mut llama_cpp_bindings_sys::ggml_tensor,
366    ///     ask: bool,
367    ///     user_data: *mut std::ffi::c_void,
368    /// ) -> bool {
369    ///     false
370    /// }
371    ///
372    /// use llama_cpp_bindings::context::params::LlamaContextParams;
373    /// let params = LlamaContextParams::default().with_cb_eval(Some(cb_eval_fn));
374    /// ```
375    #[must_use]
376    pub fn with_cb_eval(
377        mut self,
378        cb_eval: llama_cpp_bindings_sys::ggml_backend_sched_eval_callback,
379    ) -> Self {
380        self.context_params.cb_eval = cb_eval;
381        self
382    }
383
384    /// Set the evaluation callback user data.
385    ///
386    /// # Examples
387    ///
388    /// ```no_run
389    /// use llama_cpp_bindings::context::params::LlamaContextParams;
390    /// let params = LlamaContextParams::default();
391    /// let user_data = std::ptr::null_mut();
392    /// let params = params.with_cb_eval_user_data(user_data);
393    /// ```
394    #[must_use]
395    pub const fn with_cb_eval_user_data(
396        mut self,
397        cb_eval_user_data: *mut std::ffi::c_void,
398    ) -> Self {
399        self.context_params.cb_eval_user_data = cb_eval_user_data;
400        self
401    }
402
403    /// Set the type of pooling.
404    ///
405    /// # Examples
406    ///
407    /// ```rust
408    /// use llama_cpp_bindings::context::params::{LlamaContextParams, LlamaPoolingType};
409    /// let params = LlamaContextParams::default()
410    ///     .with_pooling_type(LlamaPoolingType::Last);
411    /// assert_eq!(params.pooling_type(), LlamaPoolingType::Last);
412    /// ```
413    #[must_use]
414    pub fn with_pooling_type(mut self, pooling_type: LlamaPoolingType) -> Self {
415        self.context_params.pooling_type = i32::from(pooling_type);
416        self
417    }
418
419    /// Get the type of pooling.
420    ///
421    /// # Examples
422    ///
423    /// ```rust
424    /// let params = llama_cpp_bindings::context::params::LlamaContextParams::default();
425    /// assert_eq!(params.pooling_type(), llama_cpp_bindings::context::params::LlamaPoolingType::Unspecified);
426    /// ```
427    #[must_use]
428    pub fn pooling_type(&self) -> LlamaPoolingType {
429        LlamaPoolingType::from(self.context_params.pooling_type)
430    }
431
432    /// Set whether to use full sliding window attention
433    ///
434    /// # Examples
435    ///
436    /// ```rust
437    /// use llama_cpp_bindings::context::params::LlamaContextParams;
438    /// let params = LlamaContextParams::default()
439    ///     .with_swa_full(false);
440    /// assert_eq!(params.swa_full(), false);
441    /// ```
442    #[must_use]
443    pub const fn with_swa_full(mut self, enabled: bool) -> Self {
444        self.context_params.swa_full = enabled;
445        self
446    }
447
448    /// Get whether full sliding window attention is enabled
449    ///
450    /// # Examples
451    ///
452    /// ```rust
453    /// use llama_cpp_bindings::context::params::LlamaContextParams;
454    /// let params = LlamaContextParams::default();
455    /// assert_eq!(params.swa_full(), true);
456    /// ```
457    #[must_use]
458    pub const fn swa_full(&self) -> bool {
459        self.context_params.swa_full
460    }
461
462    /// Set the max number of sequences (i.e. distinct states for recurrent models)
463    ///
464    /// # Examples
465    ///
466    /// ```rust
467    /// use llama_cpp_bindings::context::params::LlamaContextParams;
468    /// let params = LlamaContextParams::default()
469    ///     .with_n_seq_max(64);
470    /// assert_eq!(params.n_seq_max(), 64);
471    /// ```
472    #[must_use]
473    pub const fn with_n_seq_max(mut self, n_seq_max: u32) -> Self {
474        self.context_params.n_seq_max = n_seq_max;
475        self
476    }
477
478    /// Get the max number of sequences (i.e. distinct states for recurrent models)
479    ///
480    /// # Examples
481    ///
482    /// ```rust
483    /// use llama_cpp_bindings::context::params::LlamaContextParams;
484    /// let params = LlamaContextParams::default();
485    /// assert_eq!(params.n_seq_max(), 1);
486    /// ```
487    #[must_use]
488    pub const fn n_seq_max(&self) -> u32 {
489        self.context_params.n_seq_max
490    }
491    /// Set the KV cache data type for K
492    /// use `llama_cpp_bindings::context::params::{LlamaContextParams`, `KvCacheType`};
493    /// let params = `LlamaContextParams::default().with_type_k(KvCacheType::Q4_0)`;
494    /// `assert_eq!(params.type_k()`, `KvCacheType::Q4_0`);
495    /// ```
496    #[must_use]
497    pub fn with_type_k(mut self, type_k: KvCacheType) -> Self {
498        self.context_params.type_k = type_k.into();
499        self
500    }
501
502    /// Get the KV cache data type for K
503    ///
504    /// # Examples
505    ///
506    /// ```rust
507    /// let params = llama_cpp_bindings::context::params::LlamaContextParams::default();
508    /// let _ = params.type_k();
509    /// ```
510    #[must_use]
511    pub fn type_k(&self) -> KvCacheType {
512        KvCacheType::from(self.context_params.type_k)
513    }
514
515    /// Set the KV cache data type for V
516    ///
517    /// # Examples
518    ///
519    /// ```rust
520    /// use llama_cpp_bindings::context::params::{LlamaContextParams, KvCacheType};
521    /// let params = LlamaContextParams::default().with_type_v(KvCacheType::Q4_1);
522    /// assert_eq!(params.type_v(), KvCacheType::Q4_1);
523    /// ```
524    #[must_use]
525    pub fn with_type_v(mut self, type_v: KvCacheType) -> Self {
526        self.context_params.type_v = type_v.into();
527        self
528    }
529
530    /// Get the KV cache data type for V
531    ///
532    /// # Examples
533    ///
534    /// ```rust
535    /// let params = llama_cpp_bindings::context::params::LlamaContextParams::default();
536    /// let _ = params.type_v();
537    /// ```
538    #[must_use]
539    pub fn type_v(&self) -> KvCacheType {
540        KvCacheType::from(self.context_params.type_v)
541    }
542
543    /// Set the attention type
544    ///
545    /// # Examples
546    ///
547    /// ```rust
548    /// use llama_cpp_bindings::context::params::{LlamaContextParams, LlamaAttentionType};
549    /// let params = LlamaContextParams::default()
550    ///     .with_attention_type(LlamaAttentionType::NonCausal);
551    /// assert_eq!(params.attention_type(), LlamaAttentionType::NonCausal);
552    /// ```
553    #[must_use]
554    pub fn with_attention_type(mut self, attention_type: LlamaAttentionType) -> Self {
555        self.context_params.attention_type = i32::from(attention_type);
556        self
557    }
558
559    /// Get the attention type
560    ///
561    /// # Examples
562    ///
563    /// ```rust
564    /// let params = llama_cpp_bindings::context::params::LlamaContextParams::default();
565    /// assert_eq!(params.attention_type(), llama_cpp_bindings::context::params::LlamaAttentionType::Unspecified);
566    /// ```
567    #[must_use]
568    pub fn attention_type(&self) -> LlamaAttentionType {
569        LlamaAttentionType::from(self.context_params.attention_type)
570    }
571
572    /// Set the `YaRN` extrapolation factor
573    ///
574    /// # Examples
575    ///
576    /// ```rust
577    /// use llama_cpp_bindings::context::params::LlamaContextParams;
578    /// let params = LlamaContextParams::default()
579    ///     .with_yarn_ext_factor(1.0);
580    /// assert!((params.yarn_ext_factor() - 1.0).abs() < f32::EPSILON);
581    /// ```
582    #[must_use]
583    pub const fn with_yarn_ext_factor(mut self, yarn_ext_factor: f32) -> Self {
584        self.context_params.yarn_ext_factor = yarn_ext_factor;
585        self
586    }
587
588    /// Get the `YaRN` extrapolation factor
589    #[must_use]
590    pub const fn yarn_ext_factor(&self) -> f32 {
591        self.context_params.yarn_ext_factor
592    }
593
594    /// Set the `YaRN` attention factor
595    ///
596    /// # Examples
597    ///
598    /// ```rust
599    /// use llama_cpp_bindings::context::params::LlamaContextParams;
600    /// let params = LlamaContextParams::default()
601    ///     .with_yarn_attn_factor(2.0);
602    /// assert!((params.yarn_attn_factor() - 2.0).abs() < f32::EPSILON);
603    /// ```
604    #[must_use]
605    pub const fn with_yarn_attn_factor(mut self, yarn_attn_factor: f32) -> Self {
606        self.context_params.yarn_attn_factor = yarn_attn_factor;
607        self
608    }
609
610    /// Get the `YaRN` attention factor
611    #[must_use]
612    pub const fn yarn_attn_factor(&self) -> f32 {
613        self.context_params.yarn_attn_factor
614    }
615
616    /// Set the `YaRN` low correction dim
617    ///
618    /// # Examples
619    ///
620    /// ```rust
621    /// use llama_cpp_bindings::context::params::LlamaContextParams;
622    /// let params = LlamaContextParams::default()
623    ///     .with_yarn_beta_fast(32.0);
624    /// assert!((params.yarn_beta_fast() - 32.0).abs() < f32::EPSILON);
625    /// ```
626    #[must_use]
627    pub const fn with_yarn_beta_fast(mut self, yarn_beta_fast: f32) -> Self {
628        self.context_params.yarn_beta_fast = yarn_beta_fast;
629        self
630    }
631
632    /// Get the `YaRN` low correction dim
633    #[must_use]
634    pub const fn yarn_beta_fast(&self) -> f32 {
635        self.context_params.yarn_beta_fast
636    }
637
638    /// Set the `YaRN` high correction dim
639    ///
640    /// # Examples
641    ///
642    /// ```rust
643    /// use llama_cpp_bindings::context::params::LlamaContextParams;
644    /// let params = LlamaContextParams::default()
645    ///     .with_yarn_beta_slow(1.0);
646    /// assert!((params.yarn_beta_slow() - 1.0).abs() < f32::EPSILON);
647    /// ```
648    #[must_use]
649    pub const fn with_yarn_beta_slow(mut self, yarn_beta_slow: f32) -> Self {
650        self.context_params.yarn_beta_slow = yarn_beta_slow;
651        self
652    }
653
654    /// Get the `YaRN` high correction dim
655    #[must_use]
656    pub const fn yarn_beta_slow(&self) -> f32 {
657        self.context_params.yarn_beta_slow
658    }
659
660    /// Set the `YaRN` original context size
661    ///
662    /// # Examples
663    ///
664    /// ```rust
665    /// use llama_cpp_bindings::context::params::LlamaContextParams;
666    /// let params = LlamaContextParams::default()
667    ///     .with_yarn_orig_ctx(4096);
668    /// assert_eq!(params.yarn_orig_ctx(), 4096);
669    /// ```
670    #[must_use]
671    pub const fn with_yarn_orig_ctx(mut self, yarn_orig_ctx: u32) -> Self {
672        self.context_params.yarn_orig_ctx = yarn_orig_ctx;
673        self
674    }
675
676    /// Get the `YaRN` original context size
677    #[must_use]
678    pub const fn yarn_orig_ctx(&self) -> u32 {
679        self.context_params.yarn_orig_ctx
680    }
681
682    /// Set the KV cache defragmentation threshold
683    ///
684    /// # Examples
685    ///
686    /// ```rust
687    /// use llama_cpp_bindings::context::params::LlamaContextParams;
688    /// let params = LlamaContextParams::default()
689    ///     .with_defrag_thold(0.1);
690    /// assert!((params.defrag_thold() - 0.1).abs() < f32::EPSILON);
691    /// ```
692    #[must_use]
693    pub const fn with_defrag_thold(mut self, defrag_thold: f32) -> Self {
694        self.context_params.defrag_thold = defrag_thold;
695        self
696    }
697
698    /// Get the KV cache defragmentation threshold
699    #[must_use]
700    pub const fn defrag_thold(&self) -> f32 {
701        self.context_params.defrag_thold
702    }
703
704    /// Set whether performance timings are disabled
705    ///
706    /// # Examples
707    ///
708    /// ```rust
709    /// use llama_cpp_bindings::context::params::LlamaContextParams;
710    /// let params = LlamaContextParams::default()
711    ///     .with_no_perf(true);
712    /// assert!(params.no_perf());
713    /// ```
714    #[must_use]
715    pub const fn with_no_perf(mut self, no_perf: bool) -> Self {
716        self.context_params.no_perf = no_perf;
717        self
718    }
719
720    /// Get whether performance timings are disabled
721    #[must_use]
722    pub const fn no_perf(&self) -> bool {
723        self.context_params.no_perf
724    }
725
726    /// Set whether to offload ops to GPU
727    ///
728    /// # Examples
729    ///
730    /// ```rust
731    /// use llama_cpp_bindings::context::params::LlamaContextParams;
732    /// let params = LlamaContextParams::default()
733    ///     .with_op_offload(false);
734    /// assert!(!params.op_offload());
735    /// ```
736    #[must_use]
737    pub const fn with_op_offload(mut self, op_offload: bool) -> Self {
738        self.context_params.op_offload = op_offload;
739        self
740    }
741
742    /// Get whether ops are offloaded to GPU
743    #[must_use]
744    pub const fn op_offload(&self) -> bool {
745        self.context_params.op_offload
746    }
747
748    /// Set whether to use a unified KV cache buffer across input sequences
749    ///
750    /// # Examples
751    ///
752    /// ```rust
753    /// use llama_cpp_bindings::context::params::LlamaContextParams;
754    /// let params = LlamaContextParams::default()
755    ///     .with_kv_unified(true);
756    /// assert!(params.kv_unified());
757    /// ```
758    #[must_use]
759    pub const fn with_kv_unified(mut self, kv_unified: bool) -> Self {
760        self.context_params.kv_unified = kv_unified;
761        self
762    }
763
764    /// Get whether a unified KV cache buffer is used across input sequences
765    #[must_use]
766    pub const fn kv_unified(&self) -> bool {
767        self.context_params.kv_unified
768    }
769}
770
771/// Default parameters for `LlamaContext`. (as defined in llama.cpp by `llama_context_default_params`)
772/// ```
773/// # use std::num::NonZeroU32;
774/// use llama_cpp_bindings::context::params::{LlamaContextParams, RopeScalingType};
775/// let params = LlamaContextParams::default();
776/// assert_eq!(params.n_ctx(), NonZeroU32::new(512), "n_ctx should be 512");
777/// assert_eq!(params.rope_scaling_type(), RopeScalingType::Unspecified);
778/// ```
779impl Default for LlamaContextParams {
780    fn default() -> Self {
781        let context_params = unsafe { llama_cpp_bindings_sys::llama_context_default_params() };
782        Self { context_params }
783    }
784}
785
786#[cfg(test)]
787mod tests {
788    use super::{KvCacheType, LlamaAttentionType, LlamaPoolingType, RopeScalingType};
789
790    #[test]
791    fn default_params_have_expected_values() {
792        let params = super::LlamaContextParams::default();
793
794        assert_eq!(params.n_ctx(), std::num::NonZeroU32::new(512));
795        assert_eq!(params.n_batch(), 2048);
796        assert_eq!(params.n_ubatch(), 512);
797        assert_eq!(params.rope_scaling_type(), RopeScalingType::Unspecified);
798        assert_eq!(params.pooling_type(), LlamaPoolingType::Unspecified);
799    }
800
801    #[test]
802    fn with_n_ctx_sets_value() {
803        let params =
804            super::LlamaContextParams::default().with_n_ctx(std::num::NonZeroU32::new(2048));
805
806        assert_eq!(params.n_ctx(), std::num::NonZeroU32::new(2048));
807    }
808
809    #[test]
810    fn with_n_ctx_none_sets_zero() {
811        let params = super::LlamaContextParams::default().with_n_ctx(None);
812
813        assert_eq!(params.n_ctx(), None);
814    }
815
816    #[test]
817    fn with_n_batch_sets_value() {
818        let params = super::LlamaContextParams::default().with_n_batch(4096);
819
820        assert_eq!(params.n_batch(), 4096);
821    }
822
823    #[test]
824    fn with_n_ubatch_sets_value() {
825        let params = super::LlamaContextParams::default().with_n_ubatch(1024);
826
827        assert_eq!(params.n_ubatch(), 1024);
828    }
829
830    #[test]
831    fn with_n_seq_max_sets_value() {
832        let params = super::LlamaContextParams::default().with_n_seq_max(64);
833
834        assert_eq!(params.n_seq_max(), 64);
835    }
836
837    #[test]
838    fn with_embeddings_enables() {
839        let params = super::LlamaContextParams::default().with_embeddings(true);
840
841        assert!(params.embeddings());
842    }
843
844    #[test]
845    fn with_embeddings_disables() {
846        let params = super::LlamaContextParams::default().with_embeddings(false);
847
848        assert!(!params.embeddings());
849    }
850
851    #[test]
852    fn with_offload_kqv_disables() {
853        let params = super::LlamaContextParams::default().with_offload_kqv(false);
854
855        assert!(!params.offload_kqv());
856    }
857
858    #[test]
859    fn with_offload_kqv_enables() {
860        let params = super::LlamaContextParams::default().with_offload_kqv(true);
861
862        assert!(params.offload_kqv());
863    }
864
865    #[test]
866    fn with_swa_full_disables() {
867        let params = super::LlamaContextParams::default().with_swa_full(false);
868
869        assert!(!params.swa_full());
870    }
871
872    #[test]
873    fn with_swa_full_enables() {
874        let params = super::LlamaContextParams::default().with_swa_full(true);
875
876        assert!(params.swa_full());
877    }
878
879    #[test]
880    fn with_rope_scaling_type_linear() {
881        let params =
882            super::LlamaContextParams::default().with_rope_scaling_type(RopeScalingType::Linear);
883
884        assert_eq!(params.rope_scaling_type(), RopeScalingType::Linear);
885    }
886
887    #[test]
888    fn with_rope_scaling_type_yarn() {
889        let params =
890            super::LlamaContextParams::default().with_rope_scaling_type(RopeScalingType::Yarn);
891
892        assert_eq!(params.rope_scaling_type(), RopeScalingType::Yarn);
893    }
894
895    #[test]
896    fn with_rope_scaling_type_none() {
897        let params =
898            super::LlamaContextParams::default().with_rope_scaling_type(RopeScalingType::None);
899
900        assert_eq!(params.rope_scaling_type(), RopeScalingType::None);
901    }
902
903    #[test]
904    fn with_rope_freq_base_sets_value() {
905        let params = super::LlamaContextParams::default().with_rope_freq_base(10000.0);
906
907        assert!((params.rope_freq_base() - 10000.0).abs() < f32::EPSILON);
908    }
909
910    #[test]
911    fn with_rope_freq_scale_sets_value() {
912        let params = super::LlamaContextParams::default().with_rope_freq_scale(0.5);
913
914        assert!((params.rope_freq_scale() - 0.5).abs() < f32::EPSILON);
915    }
916
917    #[test]
918    fn with_n_threads_sets_value() {
919        let params = super::LlamaContextParams::default().with_n_threads(16);
920
921        assert_eq!(params.n_threads(), 16);
922    }
923
924    #[test]
925    fn with_n_threads_batch_sets_value() {
926        let params = super::LlamaContextParams::default().with_n_threads_batch(16);
927
928        assert_eq!(params.n_threads_batch(), 16);
929    }
930
931    #[test]
932    fn with_pooling_type_mean() {
933        let params = super::LlamaContextParams::default().with_pooling_type(LlamaPoolingType::Mean);
934
935        assert_eq!(params.pooling_type(), LlamaPoolingType::Mean);
936    }
937
938    #[test]
939    fn with_pooling_type_cls() {
940        let params = super::LlamaContextParams::default().with_pooling_type(LlamaPoolingType::Cls);
941
942        assert_eq!(params.pooling_type(), LlamaPoolingType::Cls);
943    }
944
945    #[test]
946    fn with_pooling_type_last() {
947        let params = super::LlamaContextParams::default().with_pooling_type(LlamaPoolingType::Last);
948
949        assert_eq!(params.pooling_type(), LlamaPoolingType::Last);
950    }
951
952    #[test]
953    fn with_pooling_type_rank() {
954        let params = super::LlamaContextParams::default().with_pooling_type(LlamaPoolingType::Rank);
955
956        assert_eq!(params.pooling_type(), LlamaPoolingType::Rank);
957    }
958
959    #[test]
960    fn with_pooling_type_none() {
961        let params = super::LlamaContextParams::default().with_pooling_type(LlamaPoolingType::None);
962
963        assert_eq!(params.pooling_type(), LlamaPoolingType::None);
964    }
965
966    #[test]
967    fn with_type_k_sets_value() {
968        let params = super::LlamaContextParams::default().with_type_k(KvCacheType::Q4_0);
969
970        assert_eq!(params.type_k(), KvCacheType::Q4_0);
971    }
972
973    #[test]
974    fn with_type_v_sets_value() {
975        let params = super::LlamaContextParams::default().with_type_v(KvCacheType::Q4_1);
976
977        assert_eq!(params.type_v(), KvCacheType::Q4_1);
978    }
979
980    #[test]
981    fn with_flash_attention_policy_sets_value() {
982        let params = super::LlamaContextParams::default()
983            .with_flash_attention_policy(llama_cpp_bindings_sys::LLAMA_FLASH_ATTN_TYPE_ENABLED);
984
985        assert_eq!(
986            params.flash_attention_policy(),
987            llama_cpp_bindings_sys::LLAMA_FLASH_ATTN_TYPE_ENABLED
988        );
989    }
990
991    #[test]
992    fn builder_chaining_preserves_all_values() {
993        let params = super::LlamaContextParams::default()
994            .with_n_ctx(std::num::NonZeroU32::new(1024))
995            .with_n_batch(4096)
996            .with_n_ubatch(256)
997            .with_n_threads(8)
998            .with_n_threads_batch(12)
999            .with_embeddings(true)
1000            .with_offload_kqv(false)
1001            .with_rope_scaling_type(RopeScalingType::Yarn)
1002            .with_rope_freq_base(5000.0)
1003            .with_rope_freq_scale(0.25);
1004
1005        assert_eq!(params.n_ctx(), std::num::NonZeroU32::new(1024));
1006        assert_eq!(params.n_batch(), 4096);
1007        assert_eq!(params.n_ubatch(), 256);
1008        assert_eq!(params.n_threads(), 8);
1009        assert_eq!(params.n_threads_batch(), 12);
1010        assert!(params.embeddings());
1011        assert!(!params.offload_kqv());
1012        assert_eq!(params.rope_scaling_type(), RopeScalingType::Yarn);
1013        assert!((params.rope_freq_base() - 5000.0).abs() < f32::EPSILON);
1014        assert!((params.rope_freq_scale() - 0.25).abs() < f32::EPSILON);
1015    }
1016
1017    #[test]
1018    fn with_cb_eval_sets_callback() {
1019        extern "C" fn test_cb_eval(
1020            _tensor: *mut llama_cpp_bindings_sys::ggml_tensor,
1021            _ask: bool,
1022            _user_data: *mut std::ffi::c_void,
1023        ) -> bool {
1024            false
1025        }
1026
1027        let result = test_cb_eval(std::ptr::null_mut(), false, std::ptr::null_mut());
1028
1029        assert!(!result);
1030
1031        let params = super::LlamaContextParams::default().with_cb_eval(Some(test_cb_eval));
1032
1033        assert!(params.context_params.cb_eval.is_some());
1034    }
1035
1036    #[test]
1037    fn with_cb_eval_user_data_sets_pointer() {
1038        let mut value: i32 = 42;
1039        let user_data = (&raw mut value).cast::<std::ffi::c_void>();
1040        let params = super::LlamaContextParams::default().with_cb_eval_user_data(user_data);
1041
1042        assert_eq!(params.context_params.cb_eval_user_data, user_data);
1043    }
1044
1045    #[test]
1046    fn with_flash_attention_policy_disabled() {
1047        let params = super::LlamaContextParams::default()
1048            .with_flash_attention_policy(llama_cpp_bindings_sys::LLAMA_FLASH_ATTN_TYPE_DISABLED);
1049
1050        assert_eq!(
1051            params.flash_attention_policy(),
1052            llama_cpp_bindings_sys::LLAMA_FLASH_ATTN_TYPE_DISABLED
1053        );
1054    }
1055
1056    #[test]
1057    fn with_attention_type_causal() {
1058        let params =
1059            super::LlamaContextParams::default().with_attention_type(LlamaAttentionType::Causal);
1060
1061        assert_eq!(params.attention_type(), LlamaAttentionType::Causal);
1062    }
1063
1064    #[test]
1065    fn with_attention_type_non_causal() {
1066        let params =
1067            super::LlamaContextParams::default().with_attention_type(LlamaAttentionType::NonCausal);
1068
1069        assert_eq!(params.attention_type(), LlamaAttentionType::NonCausal);
1070    }
1071
1072    #[test]
1073    fn with_yarn_ext_factor_sets_value() {
1074        let params = super::LlamaContextParams::default().with_yarn_ext_factor(1.5);
1075
1076        assert!((params.yarn_ext_factor() - 1.5).abs() < f32::EPSILON);
1077    }
1078
1079    #[test]
1080    fn with_yarn_attn_factor_sets_value() {
1081        let params = super::LlamaContextParams::default().with_yarn_attn_factor(2.0);
1082
1083        assert!((params.yarn_attn_factor() - 2.0).abs() < f32::EPSILON);
1084    }
1085
1086    #[test]
1087    fn with_yarn_beta_fast_sets_value() {
1088        let params = super::LlamaContextParams::default().with_yarn_beta_fast(32.0);
1089
1090        assert!((params.yarn_beta_fast() - 32.0).abs() < f32::EPSILON);
1091    }
1092
1093    #[test]
1094    fn with_yarn_beta_slow_sets_value() {
1095        let params = super::LlamaContextParams::default().with_yarn_beta_slow(1.0);
1096
1097        assert!((params.yarn_beta_slow() - 1.0).abs() < f32::EPSILON);
1098    }
1099
1100    #[test]
1101    fn with_yarn_orig_ctx_sets_value() {
1102        let params = super::LlamaContextParams::default().with_yarn_orig_ctx(4096);
1103
1104        assert_eq!(params.yarn_orig_ctx(), 4096);
1105    }
1106
1107    #[test]
1108    fn with_defrag_thold_sets_value() {
1109        let params = super::LlamaContextParams::default().with_defrag_thold(0.1);
1110
1111        assert!((params.defrag_thold() - 0.1).abs() < f32::EPSILON);
1112    }
1113
1114    #[test]
1115    fn with_no_perf_enables() {
1116        let params = super::LlamaContextParams::default().with_no_perf(true);
1117
1118        assert!(params.no_perf());
1119    }
1120
1121    #[test]
1122    fn with_no_perf_disables() {
1123        let params = super::LlamaContextParams::default().with_no_perf(false);
1124
1125        assert!(!params.no_perf());
1126    }
1127
1128    #[test]
1129    fn with_op_offload_enables() {
1130        let params = super::LlamaContextParams::default().with_op_offload(true);
1131
1132        assert!(params.op_offload());
1133    }
1134
1135    #[test]
1136    fn with_op_offload_disables() {
1137        let params = super::LlamaContextParams::default().with_op_offload(false);
1138
1139        assert!(!params.op_offload());
1140    }
1141
1142    #[test]
1143    fn with_kv_unified_enables() {
1144        let params = super::LlamaContextParams::default().with_kv_unified(true);
1145
1146        assert!(params.kv_unified());
1147    }
1148
1149    #[test]
1150    fn with_kv_unified_disables() {
1151        let params = super::LlamaContextParams::default().with_kv_unified(false);
1152
1153        assert!(!params.kv_unified());
1154    }
1155}