Skip to main content

llama_cpp_bindings/context/
params.rs

1use std::fmt::Debug;
2use std::num::NonZeroU32;
3
4pub use crate::context::kv_cache_type::KvCacheType;
5pub use crate::context::llama_attention_type::LlamaAttentionType;
6pub use crate::context::llama_pooling_type::LlamaPoolingType;
7pub use crate::context::rope_scaling_type::RopeScalingType;
8
9#[derive(Debug, Clone)]
10#[expect(
11    missing_docs,
12    reason = "field meanings mirror llama.cpp's `llama_context_params` C struct; restating each \
13              one inline would risk drift from the upstream spec — the doc-comment on the struct \
14              points at the canonical reference"
15)]
16#[expect(
17    clippy::module_name_repetitions,
18    reason = "`LlamaContextParams` is the canonical Rust name in the public API; renaming it to \
19              `Params` would force `params::Params` at every call site"
20)]
21pub struct LlamaContextParams {
22    pub context_params: llama_cpp_bindings_sys::llama_context_params,
23}
24
25unsafe impl Send for LlamaContextParams {}
26unsafe impl Sync for LlamaContextParams {}
27
28impl LlamaContextParams {
29    #[must_use]
30    pub fn with_n_ctx(mut self, n_ctx: Option<NonZeroU32>) -> Self {
31        self.context_params.n_ctx = n_ctx.map_or(0, NonZeroU32::get);
32        self
33    }
34
35    #[must_use]
36    pub const fn n_ctx(&self) -> Option<NonZeroU32> {
37        NonZeroU32::new(self.context_params.n_ctx)
38    }
39
40    #[must_use]
41    pub const fn with_n_batch(mut self, n_batch: u32) -> Self {
42        self.context_params.n_batch = n_batch;
43        self
44    }
45
46    #[must_use]
47    pub const fn n_batch(&self) -> u32 {
48        self.context_params.n_batch
49    }
50
51    #[must_use]
52    pub const fn with_n_ubatch(mut self, n_ubatch: u32) -> Self {
53        self.context_params.n_ubatch = n_ubatch;
54        self
55    }
56
57    #[must_use]
58    pub const fn n_ubatch(&self) -> u32 {
59        self.context_params.n_ubatch
60    }
61
62    #[must_use]
63    pub const fn with_flash_attention_policy(
64        mut self,
65        policy: llama_cpp_bindings_sys::llama_flash_attn_type,
66    ) -> Self {
67        self.context_params.flash_attn_type = policy;
68        self
69    }
70
71    #[must_use]
72    pub const fn flash_attention_policy(&self) -> llama_cpp_bindings_sys::llama_flash_attn_type {
73        self.context_params.flash_attn_type
74    }
75
76    #[must_use]
77    pub const fn with_offload_kqv(mut self, enabled: bool) -> Self {
78        self.context_params.offload_kqv = enabled;
79        self
80    }
81
82    #[must_use]
83    pub const fn offload_kqv(&self) -> bool {
84        self.context_params.offload_kqv
85    }
86
87    #[must_use]
88    pub fn with_rope_scaling_type(mut self, rope_scaling_type: RopeScalingType) -> Self {
89        self.context_params.rope_scaling_type = i32::from(rope_scaling_type);
90        self
91    }
92
93    #[must_use]
94    pub fn rope_scaling_type(&self) -> RopeScalingType {
95        RopeScalingType::from(self.context_params.rope_scaling_type)
96    }
97
98    #[must_use]
99    pub const fn with_rope_freq_base(mut self, rope_freq_base: f32) -> Self {
100        self.context_params.rope_freq_base = rope_freq_base;
101        self
102    }
103
104    #[must_use]
105    pub const fn rope_freq_base(&self) -> f32 {
106        self.context_params.rope_freq_base
107    }
108
109    #[must_use]
110    pub const fn with_rope_freq_scale(mut self, rope_freq_scale: f32) -> Self {
111        self.context_params.rope_freq_scale = rope_freq_scale;
112        self
113    }
114
115    #[must_use]
116    pub const fn rope_freq_scale(&self) -> f32 {
117        self.context_params.rope_freq_scale
118    }
119
120    #[must_use]
121    pub const fn n_threads(&self) -> i32 {
122        self.context_params.n_threads
123    }
124
125    #[must_use]
126    pub const fn n_threads_batch(&self) -> i32 {
127        self.context_params.n_threads_batch
128    }
129
130    #[must_use]
131    pub const fn with_n_threads(mut self, n_threads: i32) -> Self {
132        self.context_params.n_threads = n_threads;
133        self
134    }
135
136    #[must_use]
137    pub const fn with_n_threads_batch(mut self, n_threads: i32) -> Self {
138        self.context_params.n_threads_batch = n_threads;
139        self
140    }
141
142    #[must_use]
143    pub const fn embeddings(&self) -> bool {
144        self.context_params.embeddings
145    }
146
147    #[must_use]
148    pub const fn with_embeddings(mut self, embedding: bool) -> Self {
149        self.context_params.embeddings = embedding;
150        self
151    }
152
153    #[must_use]
154    pub fn with_cb_eval(
155        mut self,
156        cb_eval: llama_cpp_bindings_sys::ggml_backend_sched_eval_callback,
157    ) -> Self {
158        self.context_params.cb_eval = cb_eval;
159        self
160    }
161
162    #[must_use]
163    pub const fn with_cb_eval_user_data(
164        mut self,
165        cb_eval_user_data: *mut std::ffi::c_void,
166    ) -> Self {
167        self.context_params.cb_eval_user_data = cb_eval_user_data;
168        self
169    }
170
171    #[must_use]
172    pub fn with_pooling_type(mut self, pooling_type: LlamaPoolingType) -> Self {
173        self.context_params.pooling_type = i32::from(pooling_type);
174        self
175    }
176
177    #[must_use]
178    pub fn pooling_type(&self) -> LlamaPoolingType {
179        LlamaPoolingType::from(self.context_params.pooling_type)
180    }
181
182    #[must_use]
183    pub const fn with_swa_full(mut self, enabled: bool) -> Self {
184        self.context_params.swa_full = enabled;
185        self
186    }
187
188    #[must_use]
189    pub const fn swa_full(&self) -> bool {
190        self.context_params.swa_full
191    }
192
193    #[must_use]
194    pub const fn with_n_seq_max(mut self, n_seq_max: u32) -> Self {
195        self.context_params.n_seq_max = n_seq_max;
196        self
197    }
198
199    #[must_use]
200    pub const fn n_seq_max(&self) -> u32 {
201        self.context_params.n_seq_max
202    }
203    #[must_use]
204    pub fn with_type_k(mut self, type_k: KvCacheType) -> Self {
205        self.context_params.type_k = type_k.into();
206        self
207    }
208
209    #[must_use]
210    pub fn type_k(&self) -> KvCacheType {
211        KvCacheType::from(self.context_params.type_k)
212    }
213
214    #[must_use]
215    pub fn with_type_v(mut self, type_v: KvCacheType) -> Self {
216        self.context_params.type_v = type_v.into();
217        self
218    }
219
220    #[must_use]
221    pub fn type_v(&self) -> KvCacheType {
222        KvCacheType::from(self.context_params.type_v)
223    }
224
225    #[must_use]
226    pub fn with_attention_type(mut self, attention_type: LlamaAttentionType) -> Self {
227        self.context_params.attention_type = i32::from(attention_type);
228        self
229    }
230
231    #[must_use]
232    pub fn attention_type(&self) -> LlamaAttentionType {
233        LlamaAttentionType::from(self.context_params.attention_type)
234    }
235
236    #[must_use]
237    pub const fn with_yarn_ext_factor(mut self, yarn_ext_factor: f32) -> Self {
238        self.context_params.yarn_ext_factor = yarn_ext_factor;
239        self
240    }
241
242    #[must_use]
243    pub const fn yarn_ext_factor(&self) -> f32 {
244        self.context_params.yarn_ext_factor
245    }
246
247    #[must_use]
248    pub const fn with_yarn_attn_factor(mut self, yarn_attn_factor: f32) -> Self {
249        self.context_params.yarn_attn_factor = yarn_attn_factor;
250        self
251    }
252
253    #[must_use]
254    pub const fn yarn_attn_factor(&self) -> f32 {
255        self.context_params.yarn_attn_factor
256    }
257
258    #[must_use]
259    pub const fn with_yarn_beta_fast(mut self, yarn_beta_fast: f32) -> Self {
260        self.context_params.yarn_beta_fast = yarn_beta_fast;
261        self
262    }
263
264    #[must_use]
265    pub const fn yarn_beta_fast(&self) -> f32 {
266        self.context_params.yarn_beta_fast
267    }
268
269    #[must_use]
270    pub const fn with_yarn_beta_slow(mut self, yarn_beta_slow: f32) -> Self {
271        self.context_params.yarn_beta_slow = yarn_beta_slow;
272        self
273    }
274
275    #[must_use]
276    pub const fn yarn_beta_slow(&self) -> f32 {
277        self.context_params.yarn_beta_slow
278    }
279
280    #[must_use]
281    pub const fn with_yarn_orig_ctx(mut self, yarn_orig_ctx: u32) -> Self {
282        self.context_params.yarn_orig_ctx = yarn_orig_ctx;
283        self
284    }
285
286    #[must_use]
287    pub const fn yarn_orig_ctx(&self) -> u32 {
288        self.context_params.yarn_orig_ctx
289    }
290
291    #[must_use]
292    pub const fn with_defrag_thold(mut self, defrag_thold: f32) -> Self {
293        self.context_params.defrag_thold = defrag_thold;
294        self
295    }
296
297    #[must_use]
298    pub const fn defrag_thold(&self) -> f32 {
299        self.context_params.defrag_thold
300    }
301
302    #[must_use]
303    pub const fn with_no_perf(mut self, no_perf: bool) -> Self {
304        self.context_params.no_perf = no_perf;
305        self
306    }
307
308    #[must_use]
309    pub const fn no_perf(&self) -> bool {
310        self.context_params.no_perf
311    }
312
313    #[must_use]
314    pub const fn with_op_offload(mut self, op_offload: bool) -> Self {
315        self.context_params.op_offload = op_offload;
316        self
317    }
318
319    #[must_use]
320    pub const fn op_offload(&self) -> bool {
321        self.context_params.op_offload
322    }
323
324    #[must_use]
325    pub const fn with_kv_unified(mut self, kv_unified: bool) -> Self {
326        self.context_params.kv_unified = kv_unified;
327        self
328    }
329
330    #[must_use]
331    pub const fn kv_unified(&self) -> bool {
332        self.context_params.kv_unified
333    }
334}
335
336impl Default for LlamaContextParams {
337    fn default() -> Self {
338        let context_params = unsafe { llama_cpp_bindings_sys::llama_context_default_params() };
339        Self { context_params }
340    }
341}
342
#[cfg(test)]
mod tests {
    //! Round-trip tests: every `with_*` builder must be observable through its matching
    //! getter. The callback fields have no getter, so those tests read the raw struct.

    use std::num::NonZeroU32;

    use super::{
        KvCacheType, LlamaAttentionType, LlamaContextParams, LlamaPoolingType, RopeScalingType,
    };

    // ---- defaults and chaining ------------------------------------------------------------

    #[test]
    fn default_params_have_expected_values() {
        let params = LlamaContextParams::default();

        assert_eq!(params.n_ctx(), NonZeroU32::new(512));
        assert_eq!(params.n_batch(), 2048);
        assert_eq!(params.n_ubatch(), 512);
        assert_eq!(params.rope_scaling_type(), RopeScalingType::Unspecified);
        assert_eq!(params.pooling_type(), LlamaPoolingType::Unspecified);
    }

    #[test]
    fn builder_chaining_preserves_all_values() {
        let params = LlamaContextParams::default()
            .with_n_ctx(NonZeroU32::new(1024))
            .with_n_batch(4096)
            .with_n_ubatch(256)
            .with_n_threads(8)
            .with_n_threads_batch(12)
            .with_embeddings(true)
            .with_offload_kqv(false)
            .with_rope_scaling_type(RopeScalingType::Yarn)
            .with_rope_freq_base(5000.0)
            .with_rope_freq_scale(0.25);

        assert_eq!(params.n_ctx(), NonZeroU32::new(1024));
        assert_eq!(params.n_batch(), 4096);
        assert_eq!(params.n_ubatch(), 256);
        assert_eq!(params.n_threads(), 8);
        assert_eq!(params.n_threads_batch(), 12);
        assert!(params.embeddings());
        assert!(!params.offload_kqv());
        assert_eq!(params.rope_scaling_type(), RopeScalingType::Yarn);
        assert!((params.rope_freq_base() - 5000.0).abs() < f32::EPSILON);
        assert!((params.rope_freq_scale() - 0.25).abs() < f32::EPSILON);
    }

    // ---- sizes ----------------------------------------------------------------------------

    #[test]
    fn with_n_ctx_sets_value() {
        let params = LlamaContextParams::default().with_n_ctx(NonZeroU32::new(2048));

        assert_eq!(params.n_ctx(), NonZeroU32::new(2048));
    }

    #[test]
    fn with_n_ctx_none_sets_zero() {
        let params = LlamaContextParams::default().with_n_ctx(None);

        assert_eq!(params.n_ctx(), None);
    }

    #[test]
    fn with_n_batch_sets_value() {
        let params = LlamaContextParams::default().with_n_batch(4096);

        assert_eq!(params.n_batch(), 4096);
    }

    #[test]
    fn with_n_ubatch_sets_value() {
        let params = LlamaContextParams::default().with_n_ubatch(1024);

        assert_eq!(params.n_ubatch(), 1024);
    }

    #[test]
    fn with_n_seq_max_sets_value() {
        let params = LlamaContextParams::default().with_n_seq_max(64);

        assert_eq!(params.n_seq_max(), 64);
    }

    #[test]
    fn with_yarn_orig_ctx_sets_value() {
        let params = LlamaContextParams::default().with_yarn_orig_ctx(4096);

        assert_eq!(params.yarn_orig_ctx(), 4096);
    }

    // ---- threads --------------------------------------------------------------------------

    #[test]
    fn with_n_threads_sets_value() {
        let params = LlamaContextParams::default().with_n_threads(16);

        assert_eq!(params.n_threads(), 16);
    }

    #[test]
    fn with_n_threads_batch_sets_value() {
        let params = LlamaContextParams::default().with_n_threads_batch(16);

        assert_eq!(params.n_threads_batch(), 16);
    }

    // ---- boolean flags --------------------------------------------------------------------

    #[test]
    fn with_embeddings_enables() {
        let params = LlamaContextParams::default().with_embeddings(true);

        assert!(params.embeddings());
    }

    #[test]
    fn with_embeddings_disables() {
        let params = LlamaContextParams::default().with_embeddings(false);

        assert!(!params.embeddings());
    }

    #[test]
    fn with_offload_kqv_enables() {
        let params = LlamaContextParams::default().with_offload_kqv(true);

        assert!(params.offload_kqv());
    }

    #[test]
    fn with_offload_kqv_disables() {
        let params = LlamaContextParams::default().with_offload_kqv(false);

        assert!(!params.offload_kqv());
    }

    #[test]
    fn with_swa_full_enables() {
        let params = LlamaContextParams::default().with_swa_full(true);

        assert!(params.swa_full());
    }

    #[test]
    fn with_swa_full_disables() {
        let params = LlamaContextParams::default().with_swa_full(false);

        assert!(!params.swa_full());
    }

    #[test]
    fn with_no_perf_enables() {
        let params = LlamaContextParams::default().with_no_perf(true);

        assert!(params.no_perf());
    }

    #[test]
    fn with_no_perf_disables() {
        let params = LlamaContextParams::default().with_no_perf(false);

        assert!(!params.no_perf());
    }

    #[test]
    fn with_op_offload_enables() {
        let params = LlamaContextParams::default().with_op_offload(true);

        assert!(params.op_offload());
    }

    #[test]
    fn with_op_offload_disables() {
        let params = LlamaContextParams::default().with_op_offload(false);

        assert!(!params.op_offload());
    }

    #[test]
    fn with_kv_unified_enables() {
        let params = LlamaContextParams::default().with_kv_unified(true);

        assert!(params.kv_unified());
    }

    #[test]
    fn with_kv_unified_disables() {
        let params = LlamaContextParams::default().with_kv_unified(false);

        assert!(!params.kv_unified());
    }

    // ---- floating-point tuning knobs ------------------------------------------------------

    #[test]
    fn with_rope_freq_base_sets_value() {
        let params = LlamaContextParams::default().with_rope_freq_base(10000.0);

        assert!((params.rope_freq_base() - 10000.0).abs() < f32::EPSILON);
    }

    #[test]
    fn with_rope_freq_scale_sets_value() {
        let params = LlamaContextParams::default().with_rope_freq_scale(0.5);

        assert!((params.rope_freq_scale() - 0.5).abs() < f32::EPSILON);
    }

    #[test]
    fn with_yarn_ext_factor_sets_value() {
        let params = LlamaContextParams::default().with_yarn_ext_factor(1.5);

        assert!((params.yarn_ext_factor() - 1.5).abs() < f32::EPSILON);
    }

    #[test]
    fn with_yarn_attn_factor_sets_value() {
        let params = LlamaContextParams::default().with_yarn_attn_factor(2.0);

        assert!((params.yarn_attn_factor() - 2.0).abs() < f32::EPSILON);
    }

    #[test]
    fn with_yarn_beta_fast_sets_value() {
        let params = LlamaContextParams::default().with_yarn_beta_fast(32.0);

        assert!((params.yarn_beta_fast() - 32.0).abs() < f32::EPSILON);
    }

    #[test]
    fn with_yarn_beta_slow_sets_value() {
        let params = LlamaContextParams::default().with_yarn_beta_slow(1.0);

        assert!((params.yarn_beta_slow() - 1.0).abs() < f32::EPSILON);
    }

    #[test]
    fn with_defrag_thold_sets_value() {
        let params = LlamaContextParams::default().with_defrag_thold(0.1);

        assert!((params.defrag_thold() - 0.1).abs() < f32::EPSILON);
    }

    // ---- enum-backed fields ---------------------------------------------------------------

    #[test]
    fn with_rope_scaling_type_linear() {
        let params = LlamaContextParams::default().with_rope_scaling_type(RopeScalingType::Linear);

        assert_eq!(params.rope_scaling_type(), RopeScalingType::Linear);
    }

    #[test]
    fn with_rope_scaling_type_yarn() {
        let params = LlamaContextParams::default().with_rope_scaling_type(RopeScalingType::Yarn);

        assert_eq!(params.rope_scaling_type(), RopeScalingType::Yarn);
    }

    #[test]
    fn with_rope_scaling_type_none() {
        let params = LlamaContextParams::default().with_rope_scaling_type(RopeScalingType::None);

        assert_eq!(params.rope_scaling_type(), RopeScalingType::None);
    }

    #[test]
    fn with_pooling_type_mean() {
        let params = LlamaContextParams::default().with_pooling_type(LlamaPoolingType::Mean);

        assert_eq!(params.pooling_type(), LlamaPoolingType::Mean);
    }

    #[test]
    fn with_pooling_type_cls() {
        let params = LlamaContextParams::default().with_pooling_type(LlamaPoolingType::Cls);

        assert_eq!(params.pooling_type(), LlamaPoolingType::Cls);
    }

    #[test]
    fn with_pooling_type_last() {
        let params = LlamaContextParams::default().with_pooling_type(LlamaPoolingType::Last);

        assert_eq!(params.pooling_type(), LlamaPoolingType::Last);
    }

    #[test]
    fn with_pooling_type_rank() {
        let params = LlamaContextParams::default().with_pooling_type(LlamaPoolingType::Rank);

        assert_eq!(params.pooling_type(), LlamaPoolingType::Rank);
    }

    #[test]
    fn with_pooling_type_none() {
        let params = LlamaContextParams::default().with_pooling_type(LlamaPoolingType::None);

        assert_eq!(params.pooling_type(), LlamaPoolingType::None);
    }

    #[test]
    fn with_attention_type_causal() {
        let params = LlamaContextParams::default().with_attention_type(LlamaAttentionType::Causal);

        assert_eq!(params.attention_type(), LlamaAttentionType::Causal);
    }

    #[test]
    fn with_attention_type_non_causal() {
        let params =
            LlamaContextParams::default().with_attention_type(LlamaAttentionType::NonCausal);

        assert_eq!(params.attention_type(), LlamaAttentionType::NonCausal);
    }

    #[test]
    fn with_type_k_sets_value() {
        let params = LlamaContextParams::default().with_type_k(KvCacheType::Q4_0);

        assert_eq!(params.type_k(), KvCacheType::Q4_0);
    }

    #[test]
    fn with_type_v_sets_value() {
        let params = LlamaContextParams::default().with_type_v(KvCacheType::Q4_1);

        assert_eq!(params.type_v(), KvCacheType::Q4_1);
    }

    #[test]
    fn with_flash_attention_policy_sets_value() {
        let enabled = llama_cpp_bindings_sys::LLAMA_FLASH_ATTN_TYPE_ENABLED;
        let params = LlamaContextParams::default().with_flash_attention_policy(enabled);

        assert_eq!(params.flash_attention_policy(), enabled);
    }

    #[test]
    fn with_flash_attention_policy_disabled() {
        let disabled = llama_cpp_bindings_sys::LLAMA_FLASH_ATTN_TYPE_DISABLED;
        let params = LlamaContextParams::default().with_flash_attention_policy(disabled);

        assert_eq!(params.flash_attention_policy(), disabled);
    }

    // ---- evaluation callback --------------------------------------------------------------

    #[test]
    fn with_cb_eval_sets_callback() {
        extern "C" fn test_cb_eval(
            _tensor: *mut llama_cpp_bindings_sys::ggml_tensor,
            _ask: bool,
            _user_data: *mut std::ffi::c_void,
        ) -> bool {
            false
        }

        // Call the callback once directly so its body is exercised, not just its address.
        let result = test_cb_eval(std::ptr::null_mut(), false, std::ptr::null_mut());

        assert!(!result);

        let params = LlamaContextParams::default().with_cb_eval(Some(test_cb_eval));

        assert!(params.context_params.cb_eval.is_some());
    }

    #[test]
    fn with_cb_eval_user_data_sets_pointer() {
        let mut value: i32 = 42;
        let user_data = (&raw mut value).cast::<std::ffi::c_void>();
        let params = LlamaContextParams::default().with_cb_eval_user_data(user_data);

        assert_eq!(params.context_params.cb_eval_user_data, user_data);
    }
}