llama_cpp_bindings/context/
params.rs1use std::fmt::Debug;
2use std::num::NonZeroU32;
3
4pub use crate::context::kv_cache_type::KvCacheType;
5pub use crate::context::llama_attention_type::LlamaAttentionType;
6pub use crate::context::llama_pooling_type::LlamaPoolingType;
7pub use crate::context::rope_scaling_type::RopeScalingType;
8
9#[derive(Debug, Clone)]
10#[expect(
11 missing_docs,
12 reason = "field meanings mirror llama.cpp's `llama_context_params` C struct; restating each \
13 one inline would risk drift from the upstream spec — the doc-comment on the struct \
14 points at the canonical reference"
15)]
16#[expect(
17 clippy::module_name_repetitions,
18 reason = "`LlamaContextParams` is the canonical Rust name in the public API; renaming it to \
19 `Params` would force `params::Params` at every call site"
20)]
21pub struct LlamaContextParams {
22 pub context_params: llama_cpp_bindings_sys::llama_context_params,
23}
24
25unsafe impl Send for LlamaContextParams {}
26unsafe impl Sync for LlamaContextParams {}
27
28impl LlamaContextParams {
29 #[must_use]
30 pub fn with_n_ctx(mut self, n_ctx: Option<NonZeroU32>) -> Self {
31 self.context_params.n_ctx = n_ctx.map_or(0, NonZeroU32::get);
32 self
33 }
34
35 #[must_use]
36 pub const fn n_ctx(&self) -> Option<NonZeroU32> {
37 NonZeroU32::new(self.context_params.n_ctx)
38 }
39
40 #[must_use]
41 pub const fn with_n_batch(mut self, n_batch: u32) -> Self {
42 self.context_params.n_batch = n_batch;
43 self
44 }
45
46 #[must_use]
47 pub const fn n_batch(&self) -> u32 {
48 self.context_params.n_batch
49 }
50
51 #[must_use]
52 pub const fn with_n_ubatch(mut self, n_ubatch: u32) -> Self {
53 self.context_params.n_ubatch = n_ubatch;
54 self
55 }
56
57 #[must_use]
58 pub const fn n_ubatch(&self) -> u32 {
59 self.context_params.n_ubatch
60 }
61
62 #[must_use]
63 pub const fn with_flash_attention_policy(
64 mut self,
65 policy: llama_cpp_bindings_sys::llama_flash_attn_type,
66 ) -> Self {
67 self.context_params.flash_attn_type = policy;
68 self
69 }
70
71 #[must_use]
72 pub const fn flash_attention_policy(&self) -> llama_cpp_bindings_sys::llama_flash_attn_type {
73 self.context_params.flash_attn_type
74 }
75
76 #[must_use]
77 pub const fn with_offload_kqv(mut self, enabled: bool) -> Self {
78 self.context_params.offload_kqv = enabled;
79 self
80 }
81
82 #[must_use]
83 pub const fn offload_kqv(&self) -> bool {
84 self.context_params.offload_kqv
85 }
86
87 #[must_use]
88 pub fn with_rope_scaling_type(mut self, rope_scaling_type: RopeScalingType) -> Self {
89 self.context_params.rope_scaling_type = i32::from(rope_scaling_type);
90 self
91 }
92
93 #[must_use]
94 pub fn rope_scaling_type(&self) -> RopeScalingType {
95 RopeScalingType::from(self.context_params.rope_scaling_type)
96 }
97
98 #[must_use]
99 pub const fn with_rope_freq_base(mut self, rope_freq_base: f32) -> Self {
100 self.context_params.rope_freq_base = rope_freq_base;
101 self
102 }
103
104 #[must_use]
105 pub const fn rope_freq_base(&self) -> f32 {
106 self.context_params.rope_freq_base
107 }
108
109 #[must_use]
110 pub const fn with_rope_freq_scale(mut self, rope_freq_scale: f32) -> Self {
111 self.context_params.rope_freq_scale = rope_freq_scale;
112 self
113 }
114
115 #[must_use]
116 pub const fn rope_freq_scale(&self) -> f32 {
117 self.context_params.rope_freq_scale
118 }
119
120 #[must_use]
121 pub const fn n_threads(&self) -> i32 {
122 self.context_params.n_threads
123 }
124
125 #[must_use]
126 pub const fn n_threads_batch(&self) -> i32 {
127 self.context_params.n_threads_batch
128 }
129
130 #[must_use]
131 pub const fn with_n_threads(mut self, n_threads: i32) -> Self {
132 self.context_params.n_threads = n_threads;
133 self
134 }
135
136 #[must_use]
137 pub const fn with_n_threads_batch(mut self, n_threads: i32) -> Self {
138 self.context_params.n_threads_batch = n_threads;
139 self
140 }
141
142 #[must_use]
143 pub const fn embeddings(&self) -> bool {
144 self.context_params.embeddings
145 }
146
147 #[must_use]
148 pub const fn with_embeddings(mut self, embedding: bool) -> Self {
149 self.context_params.embeddings = embedding;
150 self
151 }
152
153 #[must_use]
154 pub fn with_cb_eval(
155 mut self,
156 cb_eval: llama_cpp_bindings_sys::ggml_backend_sched_eval_callback,
157 ) -> Self {
158 self.context_params.cb_eval = cb_eval;
159 self
160 }
161
162 #[must_use]
163 pub const fn with_cb_eval_user_data(
164 mut self,
165 cb_eval_user_data: *mut std::ffi::c_void,
166 ) -> Self {
167 self.context_params.cb_eval_user_data = cb_eval_user_data;
168 self
169 }
170
171 #[must_use]
172 pub fn with_pooling_type(mut self, pooling_type: LlamaPoolingType) -> Self {
173 self.context_params.pooling_type = i32::from(pooling_type);
174 self
175 }
176
177 #[must_use]
178 pub fn pooling_type(&self) -> LlamaPoolingType {
179 LlamaPoolingType::from(self.context_params.pooling_type)
180 }
181
182 #[must_use]
183 pub const fn with_swa_full(mut self, enabled: bool) -> Self {
184 self.context_params.swa_full = enabled;
185 self
186 }
187
188 #[must_use]
189 pub const fn swa_full(&self) -> bool {
190 self.context_params.swa_full
191 }
192
193 #[must_use]
194 pub const fn with_n_seq_max(mut self, n_seq_max: u32) -> Self {
195 self.context_params.n_seq_max = n_seq_max;
196 self
197 }
198
199 #[must_use]
200 pub const fn n_seq_max(&self) -> u32 {
201 self.context_params.n_seq_max
202 }
203 #[must_use]
204 pub fn with_type_k(mut self, type_k: KvCacheType) -> Self {
205 self.context_params.type_k = type_k.into();
206 self
207 }
208
209 #[must_use]
210 pub fn type_k(&self) -> KvCacheType {
211 KvCacheType::from(self.context_params.type_k)
212 }
213
214 #[must_use]
215 pub fn with_type_v(mut self, type_v: KvCacheType) -> Self {
216 self.context_params.type_v = type_v.into();
217 self
218 }
219
220 #[must_use]
221 pub fn type_v(&self) -> KvCacheType {
222 KvCacheType::from(self.context_params.type_v)
223 }
224
225 #[must_use]
226 pub fn with_attention_type(mut self, attention_type: LlamaAttentionType) -> Self {
227 self.context_params.attention_type = i32::from(attention_type);
228 self
229 }
230
231 #[must_use]
232 pub fn attention_type(&self) -> LlamaAttentionType {
233 LlamaAttentionType::from(self.context_params.attention_type)
234 }
235
236 #[must_use]
237 pub const fn with_yarn_ext_factor(mut self, yarn_ext_factor: f32) -> Self {
238 self.context_params.yarn_ext_factor = yarn_ext_factor;
239 self
240 }
241
242 #[must_use]
243 pub const fn yarn_ext_factor(&self) -> f32 {
244 self.context_params.yarn_ext_factor
245 }
246
247 #[must_use]
248 pub const fn with_yarn_attn_factor(mut self, yarn_attn_factor: f32) -> Self {
249 self.context_params.yarn_attn_factor = yarn_attn_factor;
250 self
251 }
252
253 #[must_use]
254 pub const fn yarn_attn_factor(&self) -> f32 {
255 self.context_params.yarn_attn_factor
256 }
257
258 #[must_use]
259 pub const fn with_yarn_beta_fast(mut self, yarn_beta_fast: f32) -> Self {
260 self.context_params.yarn_beta_fast = yarn_beta_fast;
261 self
262 }
263
264 #[must_use]
265 pub const fn yarn_beta_fast(&self) -> f32 {
266 self.context_params.yarn_beta_fast
267 }
268
269 #[must_use]
270 pub const fn with_yarn_beta_slow(mut self, yarn_beta_slow: f32) -> Self {
271 self.context_params.yarn_beta_slow = yarn_beta_slow;
272 self
273 }
274
275 #[must_use]
276 pub const fn yarn_beta_slow(&self) -> f32 {
277 self.context_params.yarn_beta_slow
278 }
279
280 #[must_use]
281 pub const fn with_yarn_orig_ctx(mut self, yarn_orig_ctx: u32) -> Self {
282 self.context_params.yarn_orig_ctx = yarn_orig_ctx;
283 self
284 }
285
286 #[must_use]
287 pub const fn yarn_orig_ctx(&self) -> u32 {
288 self.context_params.yarn_orig_ctx
289 }
290
291 #[must_use]
292 pub const fn with_defrag_thold(mut self, defrag_thold: f32) -> Self {
293 self.context_params.defrag_thold = defrag_thold;
294 self
295 }
296
297 #[must_use]
298 pub const fn defrag_thold(&self) -> f32 {
299 self.context_params.defrag_thold
300 }
301
302 #[must_use]
303 pub const fn with_no_perf(mut self, no_perf: bool) -> Self {
304 self.context_params.no_perf = no_perf;
305 self
306 }
307
308 #[must_use]
309 pub const fn no_perf(&self) -> bool {
310 self.context_params.no_perf
311 }
312
313 #[must_use]
314 pub const fn with_op_offload(mut self, op_offload: bool) -> Self {
315 self.context_params.op_offload = op_offload;
316 self
317 }
318
319 #[must_use]
320 pub const fn op_offload(&self) -> bool {
321 self.context_params.op_offload
322 }
323
324 #[must_use]
325 pub const fn with_kv_unified(mut self, kv_unified: bool) -> Self {
326 self.context_params.kv_unified = kv_unified;
327 self
328 }
329
330 #[must_use]
331 pub const fn kv_unified(&self) -> bool {
332 self.context_params.kv_unified
333 }
334}
335
336impl Default for LlamaContextParams {
337 fn default() -> Self {
338 let context_params = unsafe { llama_cpp_bindings_sys::llama_context_default_params() };
339 Self { context_params }
340 }
341}
342
#[cfg(test)]
mod tests {
    use std::num::NonZeroU32;

    use super::{
        KvCacheType, LlamaAttentionType, LlamaContextParams, LlamaPoolingType, RopeScalingType,
    };

    #[test]
    fn default_params_have_expected_values() {
        let params = LlamaContextParams::default();

        assert_eq!(params.n_ctx(), NonZeroU32::new(512));
        assert_eq!(params.n_batch(), 2048);
        assert_eq!(params.n_ubatch(), 512);
        assert_eq!(params.rope_scaling_type(), RopeScalingType::Unspecified);
        assert_eq!(params.pooling_type(), LlamaPoolingType::Unspecified);
    }

    #[test]
    fn with_n_ctx_sets_value() {
        let params = LlamaContextParams::default().with_n_ctx(NonZeroU32::new(2048));

        assert_eq!(params.n_ctx(), NonZeroU32::new(2048));
    }

    #[test]
    fn with_n_ctx_none_sets_zero() {
        let params = LlamaContextParams::default().with_n_ctx(None);

        assert_eq!(params.n_ctx(), None);
    }

    #[test]
    fn with_n_batch_sets_value() {
        let params = LlamaContextParams::default().with_n_batch(4096);

        assert_eq!(params.n_batch(), 4096);
    }

    #[test]
    fn with_n_ubatch_sets_value() {
        let params = LlamaContextParams::default().with_n_ubatch(1024);

        assert_eq!(params.n_ubatch(), 1024);
    }

    #[test]
    fn with_n_seq_max_sets_value() {
        let params = LlamaContextParams::default().with_n_seq_max(64);

        assert_eq!(params.n_seq_max(), 64);
    }

    #[test]
    fn with_embeddings_enables() {
        let params = LlamaContextParams::default().with_embeddings(true);

        assert!(params.embeddings());
    }

    #[test]
    fn with_embeddings_disables() {
        let params = LlamaContextParams::default().with_embeddings(false);

        assert!(!params.embeddings());
    }

    #[test]
    fn with_offload_kqv_disables() {
        let params = LlamaContextParams::default().with_offload_kqv(false);

        assert!(!params.offload_kqv());
    }

    #[test]
    fn with_offload_kqv_enables() {
        let params = LlamaContextParams::default().with_offload_kqv(true);

        assert!(params.offload_kqv());
    }

    #[test]
    fn with_swa_full_disables() {
        let params = LlamaContextParams::default().with_swa_full(false);

        assert!(!params.swa_full());
    }

    #[test]
    fn with_swa_full_enables() {
        let params = LlamaContextParams::default().with_swa_full(true);

        assert!(params.swa_full());
    }

    #[test]
    fn with_rope_scaling_type_linear() {
        let params = LlamaContextParams::default().with_rope_scaling_type(RopeScalingType::Linear);

        assert_eq!(params.rope_scaling_type(), RopeScalingType::Linear);
    }

    #[test]
    fn with_rope_scaling_type_yarn() {
        let params = LlamaContextParams::default().with_rope_scaling_type(RopeScalingType::Yarn);

        assert_eq!(params.rope_scaling_type(), RopeScalingType::Yarn);
    }

    #[test]
    fn with_rope_scaling_type_none() {
        let params = LlamaContextParams::default().with_rope_scaling_type(RopeScalingType::None);

        assert_eq!(params.rope_scaling_type(), RopeScalingType::None);
    }

    #[test]
    fn with_rope_freq_base_sets_value() {
        let params = LlamaContextParams::default().with_rope_freq_base(10000.0);

        assert!((params.rope_freq_base() - 10000.0).abs() < f32::EPSILON);
    }

    #[test]
    fn with_rope_freq_scale_sets_value() {
        let params = LlamaContextParams::default().with_rope_freq_scale(0.5);

        assert!((params.rope_freq_scale() - 0.5).abs() < f32::EPSILON);
    }

    #[test]
    fn with_n_threads_sets_value() {
        let params = LlamaContextParams::default().with_n_threads(16);

        assert_eq!(params.n_threads(), 16);
    }

    #[test]
    fn with_n_threads_batch_sets_value() {
        let params = LlamaContextParams::default().with_n_threads_batch(16);

        assert_eq!(params.n_threads_batch(), 16);
    }

    #[test]
    fn with_pooling_type_mean() {
        let params = LlamaContextParams::default().with_pooling_type(LlamaPoolingType::Mean);

        assert_eq!(params.pooling_type(), LlamaPoolingType::Mean);
    }

    #[test]
    fn with_pooling_type_cls() {
        let params = LlamaContextParams::default().with_pooling_type(LlamaPoolingType::Cls);

        assert_eq!(params.pooling_type(), LlamaPoolingType::Cls);
    }

    #[test]
    fn with_pooling_type_last() {
        let params = LlamaContextParams::default().with_pooling_type(LlamaPoolingType::Last);

        assert_eq!(params.pooling_type(), LlamaPoolingType::Last);
    }

    #[test]
    fn with_pooling_type_rank() {
        let params = LlamaContextParams::default().with_pooling_type(LlamaPoolingType::Rank);

        assert_eq!(params.pooling_type(), LlamaPoolingType::Rank);
    }

    #[test]
    fn with_pooling_type_none() {
        let params = LlamaContextParams::default().with_pooling_type(LlamaPoolingType::None);

        assert_eq!(params.pooling_type(), LlamaPoolingType::None);
    }

    #[test]
    fn with_type_k_sets_value() {
        let params = LlamaContextParams::default().with_type_k(KvCacheType::Q4_0);

        assert_eq!(params.type_k(), KvCacheType::Q4_0);
    }

    #[test]
    fn with_type_v_sets_value() {
        let params = LlamaContextParams::default().with_type_v(KvCacheType::Q4_1);

        assert_eq!(params.type_v(), KvCacheType::Q4_1);
    }

    #[test]
    fn with_flash_attention_policy_sets_value() {
        let params = LlamaContextParams::default()
            .with_flash_attention_policy(llama_cpp_bindings_sys::LLAMA_FLASH_ATTN_TYPE_ENABLED);

        assert_eq!(
            params.flash_attention_policy(),
            llama_cpp_bindings_sys::LLAMA_FLASH_ATTN_TYPE_ENABLED
        );
    }

    #[test]
    fn builder_chaining_preserves_all_values() {
        let params = LlamaContextParams::default()
            .with_n_ctx(NonZeroU32::new(1024))
            .with_n_batch(4096)
            .with_n_ubatch(256)
            .with_n_threads(8)
            .with_n_threads_batch(12)
            .with_embeddings(true)
            .with_offload_kqv(false)
            .with_rope_scaling_type(RopeScalingType::Yarn)
            .with_rope_freq_base(5000.0)
            .with_rope_freq_scale(0.25);

        assert_eq!(params.n_ctx(), NonZeroU32::new(1024));
        assert_eq!(params.n_batch(), 4096);
        assert_eq!(params.n_ubatch(), 256);
        assert_eq!(params.n_threads(), 8);
        assert_eq!(params.n_threads_batch(), 12);
        assert!(params.embeddings());
        assert!(!params.offload_kqv());
        assert_eq!(params.rope_scaling_type(), RopeScalingType::Yarn);
        assert!((params.rope_freq_base() - 5000.0).abs() < f32::EPSILON);
        assert!((params.rope_freq_scale() - 0.25).abs() < f32::EPSILON);
    }

    #[test]
    fn with_cb_eval_sets_callback() {
        extern "C" fn test_cb_eval(
            _tensor: *mut llama_cpp_bindings_sys::ggml_tensor,
            _ask: bool,
            _user_data: *mut std::ffi::c_void,
        ) -> bool {
            false
        }

        // No context is created in this test, so llama.cpp never invokes the callback; call it
        // once directly so its body is exercised too.
        let result = test_cb_eval(std::ptr::null_mut(), false, std::ptr::null_mut());

        assert!(!result);

        let params = LlamaContextParams::default().with_cb_eval(Some(test_cb_eval));

        assert!(params.context_params.cb_eval.is_some());
    }

    #[test]
    fn with_cb_eval_user_data_sets_pointer() {
        let mut value: i32 = 42;
        let user_data = (&raw mut value).cast::<std::ffi::c_void>();
        let params = LlamaContextParams::default().with_cb_eval_user_data(user_data);

        assert_eq!(params.context_params.cb_eval_user_data, user_data);
    }

    #[test]
    fn with_flash_attention_policy_disabled() {
        let params = LlamaContextParams::default()
            .with_flash_attention_policy(llama_cpp_bindings_sys::LLAMA_FLASH_ATTN_TYPE_DISABLED);

        assert_eq!(
            params.flash_attention_policy(),
            llama_cpp_bindings_sys::LLAMA_FLASH_ATTN_TYPE_DISABLED
        );
    }

    #[test]
    fn with_attention_type_causal() {
        let params = LlamaContextParams::default().with_attention_type(LlamaAttentionType::Causal);

        assert_eq!(params.attention_type(), LlamaAttentionType::Causal);
    }

    #[test]
    fn with_attention_type_non_causal() {
        let params =
            LlamaContextParams::default().with_attention_type(LlamaAttentionType::NonCausal);

        assert_eq!(params.attention_type(), LlamaAttentionType::NonCausal);
    }

    #[test]
    fn with_yarn_ext_factor_sets_value() {
        let params = LlamaContextParams::default().with_yarn_ext_factor(1.5);

        assert!((params.yarn_ext_factor() - 1.5).abs() < f32::EPSILON);
    }

    #[test]
    fn with_yarn_attn_factor_sets_value() {
        let params = LlamaContextParams::default().with_yarn_attn_factor(2.0);

        assert!((params.yarn_attn_factor() - 2.0).abs() < f32::EPSILON);
    }

    #[test]
    fn with_yarn_beta_fast_sets_value() {
        // Derive the target from the default so the assertion cannot pass vacuously when a
        // hard-coded literal happens to equal llama.cpp's default for this field.
        let defaults = LlamaContextParams::default();
        let expected = defaults.yarn_beta_fast() + 8.0;
        let params = defaults.with_yarn_beta_fast(expected);

        assert!((params.yarn_beta_fast() - expected).abs() < f32::EPSILON);
    }

    #[test]
    fn with_yarn_beta_slow_sets_value() {
        // Same rationale as `with_yarn_beta_fast_sets_value`: never test with the default value.
        let defaults = LlamaContextParams::default();
        let expected = defaults.yarn_beta_slow() + 2.0;
        let params = defaults.with_yarn_beta_slow(expected);

        assert!((params.yarn_beta_slow() - expected).abs() < f32::EPSILON);
    }

    #[test]
    fn with_yarn_orig_ctx_sets_value() {
        let params = LlamaContextParams::default().with_yarn_orig_ctx(4096);

        assert_eq!(params.yarn_orig_ctx(), 4096);
    }

    #[test]
    fn with_defrag_thold_sets_value() {
        let params = LlamaContextParams::default().with_defrag_thold(0.1);

        assert!((params.defrag_thold() - 0.1).abs() < f32::EPSILON);
    }

    #[test]
    fn with_no_perf_enables() {
        let params = LlamaContextParams::default().with_no_perf(true);

        assert!(params.no_perf());
    }

    #[test]
    fn with_no_perf_disables() {
        let params = LlamaContextParams::default().with_no_perf(false);

        assert!(!params.no_perf());
    }

    #[test]
    fn with_op_offload_enables() {
        let params = LlamaContextParams::default().with_op_offload(true);

        assert!(params.op_offload());
    }

    #[test]
    fn with_op_offload_disables() {
        let params = LlamaContextParams::default().with_op_offload(false);

        assert!(!params.op_offload());
    }

    #[test]
    fn with_kv_unified_enables() {
        let params = LlamaContextParams::default().with_kv_unified(true);

        assert!(params.kv_unified());
    }

    #[test]
    fn with_kv_unified_disables() {
        let params = LlamaContextParams::default().with_kv_unified(false);

        assert!(!params.kv_unified());
    }
}