llama_cpp_bindings/context/
params.rs1use std::fmt::Debug;
3use std::num::NonZeroU32;
4
5pub use crate::context::kv_cache_type::KvCacheType;
6pub use crate::context::llama_attention_type::LlamaAttentionType;
7pub use crate::context::llama_pooling_type::LlamaPoolingType;
8pub use crate::context::rope_scaling_type::RopeScalingType;
9
10#[derive(Debug, Clone)]
26#[expect(
27 missing_docs,
28 reason = "field meanings mirror llama.cpp's `llama_context_params` C struct; restating each \
29 one inline would risk drift from the upstream spec — the doc-comment on the struct \
30 points at the canonical reference"
31)]
32#[expect(
33 clippy::module_name_repetitions,
34 reason = "`LlamaContextParams` is the canonical Rust name in the public API; renaming it to \
35 `Params` would force `params::Params` at every call site"
36)]
37pub struct LlamaContextParams {
38 pub context_params: llama_cpp_bindings_sys::llama_context_params,
39}
40
41unsafe impl Send for LlamaContextParams {}
43unsafe impl Sync for LlamaContextParams {}
44
45impl LlamaContextParams {
46 #[must_use]
58 pub fn with_n_ctx(mut self, n_ctx: Option<NonZeroU32>) -> Self {
59 self.context_params.n_ctx = n_ctx.map_or(0, NonZeroU32::get);
60 self
61 }
62
63 #[must_use]
73 pub const fn n_ctx(&self) -> Option<NonZeroU32> {
74 NonZeroU32::new(self.context_params.n_ctx)
75 }
76
77 #[must_use]
89 pub const fn with_n_batch(mut self, n_batch: u32) -> Self {
90 self.context_params.n_batch = n_batch;
91 self
92 }
93
94 #[must_use]
104 pub const fn n_batch(&self) -> u32 {
105 self.context_params.n_batch
106 }
107
108 #[must_use]
120 pub const fn with_n_ubatch(mut self, n_ubatch: u32) -> Self {
121 self.context_params.n_ubatch = n_ubatch;
122 self
123 }
124
125 #[must_use]
135 pub const fn n_ubatch(&self) -> u32 {
136 self.context_params.n_ubatch
137 }
138
139 #[must_use]
141 pub const fn with_flash_attention_policy(
142 mut self,
143 policy: llama_cpp_bindings_sys::llama_flash_attn_type,
144 ) -> Self {
145 self.context_params.flash_attn_type = policy;
146 self
147 }
148
149 #[must_use]
151 pub const fn flash_attention_policy(&self) -> llama_cpp_bindings_sys::llama_flash_attn_type {
152 self.context_params.flash_attn_type
153 }
154
155 #[must_use]
166 pub const fn with_offload_kqv(mut self, enabled: bool) -> Self {
167 self.context_params.offload_kqv = enabled;
168 self
169 }
170
171 #[must_use]
181 pub const fn offload_kqv(&self) -> bool {
182 self.context_params.offload_kqv
183 }
184
185 #[must_use]
196 pub fn with_rope_scaling_type(mut self, rope_scaling_type: RopeScalingType) -> Self {
197 self.context_params.rope_scaling_type = i32::from(rope_scaling_type);
198 self
199 }
200
201 #[must_use]
210 pub fn rope_scaling_type(&self) -> RopeScalingType {
211 RopeScalingType::from(self.context_params.rope_scaling_type)
212 }
213
214 #[must_use]
225 pub const fn with_rope_freq_base(mut self, rope_freq_base: f32) -> Self {
226 self.context_params.rope_freq_base = rope_freq_base;
227 self
228 }
229
230 #[must_use]
239 pub const fn rope_freq_base(&self) -> f32 {
240 self.context_params.rope_freq_base
241 }
242
243 #[must_use]
254 pub const fn with_rope_freq_scale(mut self, rope_freq_scale: f32) -> Self {
255 self.context_params.rope_freq_scale = rope_freq_scale;
256 self
257 }
258
259 #[must_use]
268 pub const fn rope_freq_scale(&self) -> f32 {
269 self.context_params.rope_freq_scale
270 }
271
272 #[must_use]
281 pub const fn n_threads(&self) -> i32 {
282 self.context_params.n_threads
283 }
284
285 #[must_use]
294 pub const fn n_threads_batch(&self) -> i32 {
295 self.context_params.n_threads_batch
296 }
297
298 #[must_use]
309 pub const fn with_n_threads(mut self, n_threads: i32) -> Self {
310 self.context_params.n_threads = n_threads;
311 self
312 }
313
314 #[must_use]
325 pub const fn with_n_threads_batch(mut self, n_threads: i32) -> Self {
326 self.context_params.n_threads_batch = n_threads;
327 self
328 }
329
330 #[must_use]
339 pub const fn embeddings(&self) -> bool {
340 self.context_params.embeddings
341 }
342
343 #[must_use]
354 pub const fn with_embeddings(mut self, embedding: bool) -> Self {
355 self.context_params.embeddings = embedding;
356 self
357 }
358
359 #[must_use]
376 pub fn with_cb_eval(
377 mut self,
378 cb_eval: llama_cpp_bindings_sys::ggml_backend_sched_eval_callback,
379 ) -> Self {
380 self.context_params.cb_eval = cb_eval;
381 self
382 }
383
384 #[must_use]
395 pub const fn with_cb_eval_user_data(
396 mut self,
397 cb_eval_user_data: *mut std::ffi::c_void,
398 ) -> Self {
399 self.context_params.cb_eval_user_data = cb_eval_user_data;
400 self
401 }
402
403 #[must_use]
414 pub fn with_pooling_type(mut self, pooling_type: LlamaPoolingType) -> Self {
415 self.context_params.pooling_type = i32::from(pooling_type);
416 self
417 }
418
419 #[must_use]
428 pub fn pooling_type(&self) -> LlamaPoolingType {
429 LlamaPoolingType::from(self.context_params.pooling_type)
430 }
431
432 #[must_use]
443 pub const fn with_swa_full(mut self, enabled: bool) -> Self {
444 self.context_params.swa_full = enabled;
445 self
446 }
447
448 #[must_use]
458 pub const fn swa_full(&self) -> bool {
459 self.context_params.swa_full
460 }
461
462 #[must_use]
473 pub const fn with_n_seq_max(mut self, n_seq_max: u32) -> Self {
474 self.context_params.n_seq_max = n_seq_max;
475 self
476 }
477
478 #[must_use]
488 pub const fn n_seq_max(&self) -> u32 {
489 self.context_params.n_seq_max
490 }
491 #[must_use]
497 pub fn with_type_k(mut self, type_k: KvCacheType) -> Self {
498 self.context_params.type_k = type_k.into();
499 self
500 }
501
502 #[must_use]
511 pub fn type_k(&self) -> KvCacheType {
512 KvCacheType::from(self.context_params.type_k)
513 }
514
515 #[must_use]
525 pub fn with_type_v(mut self, type_v: KvCacheType) -> Self {
526 self.context_params.type_v = type_v.into();
527 self
528 }
529
530 #[must_use]
539 pub fn type_v(&self) -> KvCacheType {
540 KvCacheType::from(self.context_params.type_v)
541 }
542
543 #[must_use]
554 pub fn with_attention_type(mut self, attention_type: LlamaAttentionType) -> Self {
555 self.context_params.attention_type = i32::from(attention_type);
556 self
557 }
558
559 #[must_use]
568 pub fn attention_type(&self) -> LlamaAttentionType {
569 LlamaAttentionType::from(self.context_params.attention_type)
570 }
571
572 #[must_use]
583 pub const fn with_yarn_ext_factor(mut self, yarn_ext_factor: f32) -> Self {
584 self.context_params.yarn_ext_factor = yarn_ext_factor;
585 self
586 }
587
588 #[must_use]
590 pub const fn yarn_ext_factor(&self) -> f32 {
591 self.context_params.yarn_ext_factor
592 }
593
594 #[must_use]
605 pub const fn with_yarn_attn_factor(mut self, yarn_attn_factor: f32) -> Self {
606 self.context_params.yarn_attn_factor = yarn_attn_factor;
607 self
608 }
609
610 #[must_use]
612 pub const fn yarn_attn_factor(&self) -> f32 {
613 self.context_params.yarn_attn_factor
614 }
615
616 #[must_use]
627 pub const fn with_yarn_beta_fast(mut self, yarn_beta_fast: f32) -> Self {
628 self.context_params.yarn_beta_fast = yarn_beta_fast;
629 self
630 }
631
632 #[must_use]
634 pub const fn yarn_beta_fast(&self) -> f32 {
635 self.context_params.yarn_beta_fast
636 }
637
638 #[must_use]
649 pub const fn with_yarn_beta_slow(mut self, yarn_beta_slow: f32) -> Self {
650 self.context_params.yarn_beta_slow = yarn_beta_slow;
651 self
652 }
653
654 #[must_use]
656 pub const fn yarn_beta_slow(&self) -> f32 {
657 self.context_params.yarn_beta_slow
658 }
659
660 #[must_use]
671 pub const fn with_yarn_orig_ctx(mut self, yarn_orig_ctx: u32) -> Self {
672 self.context_params.yarn_orig_ctx = yarn_orig_ctx;
673 self
674 }
675
676 #[must_use]
678 pub const fn yarn_orig_ctx(&self) -> u32 {
679 self.context_params.yarn_orig_ctx
680 }
681
682 #[must_use]
693 pub const fn with_defrag_thold(mut self, defrag_thold: f32) -> Self {
694 self.context_params.defrag_thold = defrag_thold;
695 self
696 }
697
698 #[must_use]
700 pub const fn defrag_thold(&self) -> f32 {
701 self.context_params.defrag_thold
702 }
703
704 #[must_use]
715 pub const fn with_no_perf(mut self, no_perf: bool) -> Self {
716 self.context_params.no_perf = no_perf;
717 self
718 }
719
720 #[must_use]
722 pub const fn no_perf(&self) -> bool {
723 self.context_params.no_perf
724 }
725
726 #[must_use]
737 pub const fn with_op_offload(mut self, op_offload: bool) -> Self {
738 self.context_params.op_offload = op_offload;
739 self
740 }
741
742 #[must_use]
744 pub const fn op_offload(&self) -> bool {
745 self.context_params.op_offload
746 }
747
748 #[must_use]
759 pub const fn with_kv_unified(mut self, kv_unified: bool) -> Self {
760 self.context_params.kv_unified = kv_unified;
761 self
762 }
763
764 #[must_use]
766 pub const fn kv_unified(&self) -> bool {
767 self.context_params.kv_unified
768 }
769}
770
771impl Default for LlamaContextParams {
780 fn default() -> Self {
781 let context_params = unsafe { llama_cpp_bindings_sys::llama_context_default_params() };
782 Self { context_params }
783 }
784}
785
#[cfg(test)]
mod tests {
    use std::num::NonZeroU32;

    use super::{
        KvCacheType, LlamaAttentionType, LlamaContextParams, LlamaPoolingType, RopeScalingType,
    };

    /// Shared check for the `f32` round-trip tests: the two values must differ by less than
    /// `f32::EPSILON`.
    fn nearly_equal(actual: f32, expected: f32) -> bool {
        (actual - expected).abs() < f32::EPSILON
    }

    #[test]
    fn default_params_have_expected_values() {
        let defaults = LlamaContextParams::default();

        assert_eq!(defaults.n_ctx(), NonZeroU32::new(512));
        assert_eq!(defaults.n_batch(), 2048);
        assert_eq!(defaults.n_ubatch(), 512);
        assert_eq!(defaults.rope_scaling_type(), RopeScalingType::Unspecified);
        assert_eq!(defaults.pooling_type(), LlamaPoolingType::Unspecified);
    }

    #[test]
    fn with_n_ctx_sets_value() {
        let requested = NonZeroU32::new(2048);
        let configured = LlamaContextParams::default().with_n_ctx(requested);

        assert_eq!(configured.n_ctx(), NonZeroU32::new(2048));
    }

    #[test]
    fn with_n_ctx_none_sets_zero() {
        let configured = LlamaContextParams::default().with_n_ctx(None);

        assert_eq!(configured.n_ctx(), None);
    }

    #[test]
    fn with_n_batch_sets_value() {
        let configured = LlamaContextParams::default().with_n_batch(4096);

        assert_eq!(configured.n_batch(), 4096);
    }

    #[test]
    fn with_n_ubatch_sets_value() {
        let configured = LlamaContextParams::default().with_n_ubatch(1024);

        assert_eq!(configured.n_ubatch(), 1024);
    }

    #[test]
    fn with_n_seq_max_sets_value() {
        let configured = LlamaContextParams::default().with_n_seq_max(64);

        assert_eq!(configured.n_seq_max(), 64);
    }

    #[test]
    fn with_embeddings_enables() {
        let configured = LlamaContextParams::default().with_embeddings(true);

        assert!(configured.embeddings());
    }

    #[test]
    fn with_embeddings_disables() {
        let configured = LlamaContextParams::default().with_embeddings(false);

        assert!(!configured.embeddings());
    }

    #[test]
    fn with_offload_kqv_disables() {
        let configured = LlamaContextParams::default().with_offload_kqv(false);

        assert!(!configured.offload_kqv());
    }

    #[test]
    fn with_offload_kqv_enables() {
        let configured = LlamaContextParams::default().with_offload_kqv(true);

        assert!(configured.offload_kqv());
    }

    #[test]
    fn with_swa_full_disables() {
        let configured = LlamaContextParams::default().with_swa_full(false);

        assert!(!configured.swa_full());
    }

    #[test]
    fn with_swa_full_enables() {
        let configured = LlamaContextParams::default().with_swa_full(true);

        assert!(configured.swa_full());
    }

    #[test]
    fn with_rope_scaling_type_linear() {
        let configured =
            LlamaContextParams::default().with_rope_scaling_type(RopeScalingType::Linear);

        assert_eq!(configured.rope_scaling_type(), RopeScalingType::Linear);
    }

    #[test]
    fn with_rope_scaling_type_yarn() {
        let configured =
            LlamaContextParams::default().with_rope_scaling_type(RopeScalingType::Yarn);

        assert_eq!(configured.rope_scaling_type(), RopeScalingType::Yarn);
    }

    #[test]
    fn with_rope_scaling_type_none() {
        let configured =
            LlamaContextParams::default().with_rope_scaling_type(RopeScalingType::None);

        assert_eq!(configured.rope_scaling_type(), RopeScalingType::None);
    }

    #[test]
    fn with_rope_freq_base_sets_value() {
        let configured = LlamaContextParams::default().with_rope_freq_base(10000.0);

        assert!(nearly_equal(configured.rope_freq_base(), 10000.0));
    }

    #[test]
    fn with_rope_freq_scale_sets_value() {
        let configured = LlamaContextParams::default().with_rope_freq_scale(0.5);

        assert!(nearly_equal(configured.rope_freq_scale(), 0.5));
    }

    #[test]
    fn with_n_threads_sets_value() {
        let configured = LlamaContextParams::default().with_n_threads(16);

        assert_eq!(configured.n_threads(), 16);
    }

    #[test]
    fn with_n_threads_batch_sets_value() {
        let configured = LlamaContextParams::default().with_n_threads_batch(16);

        assert_eq!(configured.n_threads_batch(), 16);
    }

    #[test]
    fn with_pooling_type_mean() {
        let configured = LlamaContextParams::default().with_pooling_type(LlamaPoolingType::Mean);

        assert_eq!(configured.pooling_type(), LlamaPoolingType::Mean);
    }

    #[test]
    fn with_pooling_type_cls() {
        let configured = LlamaContextParams::default().with_pooling_type(LlamaPoolingType::Cls);

        assert_eq!(configured.pooling_type(), LlamaPoolingType::Cls);
    }

    #[test]
    fn with_pooling_type_last() {
        let configured = LlamaContextParams::default().with_pooling_type(LlamaPoolingType::Last);

        assert_eq!(configured.pooling_type(), LlamaPoolingType::Last);
    }

    #[test]
    fn with_pooling_type_rank() {
        let configured = LlamaContextParams::default().with_pooling_type(LlamaPoolingType::Rank);

        assert_eq!(configured.pooling_type(), LlamaPoolingType::Rank);
    }

    #[test]
    fn with_pooling_type_none() {
        let configured = LlamaContextParams::default().with_pooling_type(LlamaPoolingType::None);

        assert_eq!(configured.pooling_type(), LlamaPoolingType::None);
    }

    #[test]
    fn with_type_k_sets_value() {
        let configured = LlamaContextParams::default().with_type_k(KvCacheType::Q4_0);

        assert_eq!(configured.type_k(), KvCacheType::Q4_0);
    }

    #[test]
    fn with_type_v_sets_value() {
        let configured = LlamaContextParams::default().with_type_v(KvCacheType::Q4_1);

        assert_eq!(configured.type_v(), KvCacheType::Q4_1);
    }

    #[test]
    fn with_flash_attention_policy_sets_value() {
        let policy = llama_cpp_bindings_sys::LLAMA_FLASH_ATTN_TYPE_ENABLED;
        let configured = LlamaContextParams::default().with_flash_attention_policy(policy);

        assert_eq!(
            configured.flash_attention_policy(),
            llama_cpp_bindings_sys::LLAMA_FLASH_ATTN_TYPE_ENABLED
        );
    }

    #[test]
    fn builder_chaining_preserves_all_values() {
        let configured = LlamaContextParams::default()
            .with_n_ctx(NonZeroU32::new(1024))
            .with_n_batch(4096)
            .with_n_ubatch(256)
            .with_n_threads(8)
            .with_n_threads_batch(12)
            .with_embeddings(true)
            .with_offload_kqv(false)
            .with_rope_scaling_type(RopeScalingType::Yarn)
            .with_rope_freq_base(5000.0)
            .with_rope_freq_scale(0.25);

        assert_eq!(configured.n_ctx(), NonZeroU32::new(1024));
        assert_eq!(configured.n_batch(), 4096);
        assert_eq!(configured.n_ubatch(), 256);
        assert_eq!(configured.n_threads(), 8);
        assert_eq!(configured.n_threads_batch(), 12);
        assert!(configured.embeddings());
        assert!(!configured.offload_kqv());
        assert_eq!(configured.rope_scaling_type(), RopeScalingType::Yarn);
        assert!(nearly_equal(configured.rope_freq_base(), 5000.0));
        assert!(nearly_equal(configured.rope_freq_scale(), 0.25));
    }

    #[test]
    fn with_cb_eval_sets_callback() {
        extern "C" fn test_cb_eval(
            _tensor: *mut llama_cpp_bindings_sys::ggml_tensor,
            _ask: bool,
            _user_data: *mut std::ffi::c_void,
        ) -> bool {
            false
        }

        // Call the callback once directly so its body is exercised as well.
        let direct_result = test_cb_eval(std::ptr::null_mut(), false, std::ptr::null_mut());

        assert!(!direct_result);

        let configured = LlamaContextParams::default().with_cb_eval(Some(test_cb_eval));

        assert!(configured.context_params.cb_eval.is_some());
    }

    #[test]
    fn with_cb_eval_user_data_sets_pointer() {
        let mut value: i32 = 42;
        let user_data = (&raw mut value).cast::<std::ffi::c_void>();
        let configured = LlamaContextParams::default().with_cb_eval_user_data(user_data);

        assert_eq!(configured.context_params.cb_eval_user_data, user_data);
    }

    #[test]
    fn with_flash_attention_policy_disabled() {
        let policy = llama_cpp_bindings_sys::LLAMA_FLASH_ATTN_TYPE_DISABLED;
        let configured = LlamaContextParams::default().with_flash_attention_policy(policy);

        assert_eq!(
            configured.flash_attention_policy(),
            llama_cpp_bindings_sys::LLAMA_FLASH_ATTN_TYPE_DISABLED
        );
    }

    #[test]
    fn with_attention_type_causal() {
        let configured =
            LlamaContextParams::default().with_attention_type(LlamaAttentionType::Causal);

        assert_eq!(configured.attention_type(), LlamaAttentionType::Causal);
    }

    #[test]
    fn with_attention_type_non_causal() {
        let configured =
            LlamaContextParams::default().with_attention_type(LlamaAttentionType::NonCausal);

        assert_eq!(configured.attention_type(), LlamaAttentionType::NonCausal);
    }

    #[test]
    fn with_yarn_ext_factor_sets_value() {
        let configured = LlamaContextParams::default().with_yarn_ext_factor(1.5);

        assert!(nearly_equal(configured.yarn_ext_factor(), 1.5));
    }

    #[test]
    fn with_yarn_attn_factor_sets_value() {
        let configured = LlamaContextParams::default().with_yarn_attn_factor(2.0);

        assert!(nearly_equal(configured.yarn_attn_factor(), 2.0));
    }

    #[test]
    fn with_yarn_beta_fast_sets_value() {
        let configured = LlamaContextParams::default().with_yarn_beta_fast(32.0);

        assert!(nearly_equal(configured.yarn_beta_fast(), 32.0));
    }

    #[test]
    fn with_yarn_beta_slow_sets_value() {
        let configured = LlamaContextParams::default().with_yarn_beta_slow(1.0);

        assert!(nearly_equal(configured.yarn_beta_slow(), 1.0));
    }

    #[test]
    fn with_yarn_orig_ctx_sets_value() {
        let configured = LlamaContextParams::default().with_yarn_orig_ctx(4096);

        assert_eq!(configured.yarn_orig_ctx(), 4096);
    }

    #[test]
    fn with_defrag_thold_sets_value() {
        let configured = LlamaContextParams::default().with_defrag_thold(0.1);

        assert!(nearly_equal(configured.defrag_thold(), 0.1));
    }

    #[test]
    fn with_no_perf_enables() {
        let configured = LlamaContextParams::default().with_no_perf(true);

        assert!(configured.no_perf());
    }

    #[test]
    fn with_no_perf_disables() {
        let configured = LlamaContextParams::default().with_no_perf(false);

        assert!(!configured.no_perf());
    }

    #[test]
    fn with_op_offload_enables() {
        let configured = LlamaContextParams::default().with_op_offload(true);

        assert!(configured.op_offload());
    }

    #[test]
    fn with_op_offload_disables() {
        let configured = LlamaContextParams::default().with_op_offload(false);

        assert!(!configured.op_offload());
    }

    #[test]
    fn with_kv_unified_enables() {
        let configured = LlamaContextParams::default().with_kv_unified(true);

        assert!(configured.kv_unified());
    }

    #[test]
    fn with_kv_unified_disables() {
        let configured = LlamaContextParams::default().with_kv_unified(false);

        assert!(!configured.kv_unified());
    }
}