1use std::fmt::Debug;
3use std::num::NonZeroU32;
4
/// RoPE (rotary position embedding) scaling strategy, mirroring the
/// `llama_rope_scaling_type` values used by llama.cpp.
#[repr(i8)]
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum RopeScalingType {
    /// No explicit choice; llama.cpp selects its own default.
    Unspecified = -1,
    /// RoPE scaling disabled.
    None = 0,
    /// Linear position interpolation.
    Linear = 1,
    /// YaRN scaling.
    Yarn = 2,
}

impl From<i32> for RopeScalingType {
    /// Decodes a raw llama.cpp value; anything unrecognized collapses to
    /// [`RopeScalingType::Unspecified`].
    fn from(value: i32) -> Self {
        match value {
            2 => Self::Yarn,
            1 => Self::Linear,
            0 => Self::None,
            _ => Self::Unspecified,
        }
    }
}

impl From<RopeScalingType> for i32 {
    /// Encodes back to the raw llama.cpp value. The `#[repr(i8)]`
    /// discriminants are exactly the wire values, so a numeric cast suffices.
    fn from(value: RopeScalingType) -> Self {
        value as i32
    }
}
43
/// Pooling strategy for embedding outputs, mirroring the
/// `llama_pooling_type` values used by llama.cpp.
#[repr(i8)]
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum LlamaPoolingType {
    /// No explicit choice; llama.cpp selects its own default.
    Unspecified = -1,
    /// No pooling.
    None = 0,
    /// Mean pooling.
    Mean = 1,
    /// CLS-token pooling.
    Cls = 2,
    /// Last-token pooling.
    Last = 3,
    /// Rank pooling.
    Rank = 4,
}

impl From<i32> for LlamaPoolingType {
    /// Decodes a raw llama.cpp value; anything unrecognized collapses to
    /// [`LlamaPoolingType::Unspecified`].
    fn from(value: i32) -> Self {
        match value {
            4 => Self::Rank,
            3 => Self::Last,
            2 => Self::Cls,
            1 => Self::Mean,
            0 => Self::None,
            _ => Self::Unspecified,
        }
    }
}

impl From<LlamaPoolingType> for i32 {
    /// Encodes back to the raw llama.cpp value. The `#[repr(i8)]`
    /// discriminants are exactly the wire values, so a numeric cast suffices.
    fn from(value: LlamaPoolingType) -> Self {
        value as i32
    }
}
90
/// Attention masking mode, mirroring the `llama_attention_type` values used
/// by llama.cpp.
#[repr(i8)]
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum LlamaAttentionType {
    /// No explicit choice; llama.cpp selects its own default.
    Unspecified = -1,
    /// Causal (autoregressive) attention.
    Causal = 0,
    /// Non-causal (bidirectional) attention.
    NonCausal = 1,
}

impl From<i32> for LlamaAttentionType {
    /// Decodes a raw llama.cpp value; anything unrecognized collapses to
    /// [`LlamaAttentionType::Unspecified`].
    fn from(value: i32) -> Self {
        match value {
            1 => Self::NonCausal,
            0 => Self::Causal,
            _ => Self::Unspecified,
        }
    }
}

impl From<LlamaAttentionType> for i32 {
    /// Encodes back to the raw llama.cpp value. The `#[repr(i8)]`
    /// discriminants are exactly the wire values, so a numeric cast suffices.
    fn from(value: LlamaAttentionType) -> Self {
        value as i32
    }
}
122
/// Tensor data type for the KV cache (the `type_k` / `type_v` context
/// parameters), mirroring ggml's `ggml_type` enumeration.
///
/// Variant names follow ggml's own `GGML_TYPE_*` naming, hence the
/// `non_camel_case_types` allow.
#[allow(non_camel_case_types, missing_docs)]
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum KvCacheType {
    /// A raw `ggml_type` value this wrapper does not recognize. The value is
    /// preserved so conversions round-trip losslessly.
    Unknown(llama_cpp_bindings_sys::ggml_type),
    F32,
    F16,
    Q4_0,
    Q4_1,
    Q5_0,
    Q5_1,
    Q8_0,
    Q8_1,
    Q2_K,
    Q3_K,
    Q4_K,
    Q5_K,
    Q6_K,
    Q8_K,
    IQ2_XXS,
    IQ2_XS,
    IQ3_XXS,
    IQ1_S,
    IQ4_NL,
    IQ3_S,
    IQ2_S,
    IQ4_XS,
    I8,
    I16,
    I32,
    I64,
    F64,
    IQ1_M,
    BF16,
    TQ1_0,
    TQ2_0,
    MXFP4,
}
166
impl From<KvCacheType> for llama_cpp_bindings_sys::ggml_type {
    /// Converts to the raw `ggml_type` value expected by the llama.cpp FFI.
    ///
    /// `Unknown` passes its captured raw value straight through, so even
    /// unrecognized types survive a round trip unchanged.
    fn from(value: KvCacheType) -> Self {
        match value {
            // Preserve whatever raw value was captured at decode time.
            KvCacheType::Unknown(raw) => raw,
            KvCacheType::F32 => llama_cpp_bindings_sys::GGML_TYPE_F32,
            KvCacheType::F16 => llama_cpp_bindings_sys::GGML_TYPE_F16,
            KvCacheType::Q4_0 => llama_cpp_bindings_sys::GGML_TYPE_Q4_0,
            KvCacheType::Q4_1 => llama_cpp_bindings_sys::GGML_TYPE_Q4_1,
            KvCacheType::Q5_0 => llama_cpp_bindings_sys::GGML_TYPE_Q5_0,
            KvCacheType::Q5_1 => llama_cpp_bindings_sys::GGML_TYPE_Q5_1,
            KvCacheType::Q8_0 => llama_cpp_bindings_sys::GGML_TYPE_Q8_0,
            KvCacheType::Q8_1 => llama_cpp_bindings_sys::GGML_TYPE_Q8_1,
            KvCacheType::Q2_K => llama_cpp_bindings_sys::GGML_TYPE_Q2_K,
            KvCacheType::Q3_K => llama_cpp_bindings_sys::GGML_TYPE_Q3_K,
            KvCacheType::Q4_K => llama_cpp_bindings_sys::GGML_TYPE_Q4_K,
            KvCacheType::Q5_K => llama_cpp_bindings_sys::GGML_TYPE_Q5_K,
            KvCacheType::Q6_K => llama_cpp_bindings_sys::GGML_TYPE_Q6_K,
            KvCacheType::Q8_K => llama_cpp_bindings_sys::GGML_TYPE_Q8_K,
            KvCacheType::IQ2_XXS => llama_cpp_bindings_sys::GGML_TYPE_IQ2_XXS,
            KvCacheType::IQ2_XS => llama_cpp_bindings_sys::GGML_TYPE_IQ2_XS,
            KvCacheType::IQ3_XXS => llama_cpp_bindings_sys::GGML_TYPE_IQ3_XXS,
            KvCacheType::IQ1_S => llama_cpp_bindings_sys::GGML_TYPE_IQ1_S,
            KvCacheType::IQ4_NL => llama_cpp_bindings_sys::GGML_TYPE_IQ4_NL,
            KvCacheType::IQ3_S => llama_cpp_bindings_sys::GGML_TYPE_IQ3_S,
            KvCacheType::IQ2_S => llama_cpp_bindings_sys::GGML_TYPE_IQ2_S,
            KvCacheType::IQ4_XS => llama_cpp_bindings_sys::GGML_TYPE_IQ4_XS,
            KvCacheType::I8 => llama_cpp_bindings_sys::GGML_TYPE_I8,
            KvCacheType::I16 => llama_cpp_bindings_sys::GGML_TYPE_I16,
            KvCacheType::I32 => llama_cpp_bindings_sys::GGML_TYPE_I32,
            KvCacheType::I64 => llama_cpp_bindings_sys::GGML_TYPE_I64,
            KvCacheType::F64 => llama_cpp_bindings_sys::GGML_TYPE_F64,
            KvCacheType::IQ1_M => llama_cpp_bindings_sys::GGML_TYPE_IQ1_M,
            KvCacheType::BF16 => llama_cpp_bindings_sys::GGML_TYPE_BF16,
            KvCacheType::TQ1_0 => llama_cpp_bindings_sys::GGML_TYPE_TQ1_0,
            KvCacheType::TQ2_0 => llama_cpp_bindings_sys::GGML_TYPE_TQ2_0,
            KvCacheType::MXFP4 => llama_cpp_bindings_sys::GGML_TYPE_MXFP4,
        }
    }
}
206
impl From<llama_cpp_bindings_sys::ggml_type> for KvCacheType {
    /// Decodes a raw `ggml_type` from the FFI layer.
    ///
    /// Values matching no known `GGML_TYPE_*` constant are preserved in
    /// [`KvCacheType::Unknown`] rather than dropped, so newer ggml types
    /// added upstream still round-trip.
    fn from(value: llama_cpp_bindings_sys::ggml_type) -> Self {
        // Equality guards (rather than constant patterns) keep this match
        // valid regardless of how bindgen happens to represent the
        // `GGML_TYPE_*` constants — NOTE(review): presumably why this style
        // was chosen; confirm against the generated bindings.
        match value {
            x if x == llama_cpp_bindings_sys::GGML_TYPE_F32 => Self::F32,
            x if x == llama_cpp_bindings_sys::GGML_TYPE_F16 => Self::F16,
            x if x == llama_cpp_bindings_sys::GGML_TYPE_Q4_0 => Self::Q4_0,
            x if x == llama_cpp_bindings_sys::GGML_TYPE_Q4_1 => Self::Q4_1,
            x if x == llama_cpp_bindings_sys::GGML_TYPE_Q5_0 => Self::Q5_0,
            x if x == llama_cpp_bindings_sys::GGML_TYPE_Q5_1 => Self::Q5_1,
            x if x == llama_cpp_bindings_sys::GGML_TYPE_Q8_0 => Self::Q8_0,
            x if x == llama_cpp_bindings_sys::GGML_TYPE_Q8_1 => Self::Q8_1,
            x if x == llama_cpp_bindings_sys::GGML_TYPE_Q2_K => Self::Q2_K,
            x if x == llama_cpp_bindings_sys::GGML_TYPE_Q3_K => Self::Q3_K,
            x if x == llama_cpp_bindings_sys::GGML_TYPE_Q4_K => Self::Q4_K,
            x if x == llama_cpp_bindings_sys::GGML_TYPE_Q5_K => Self::Q5_K,
            x if x == llama_cpp_bindings_sys::GGML_TYPE_Q6_K => Self::Q6_K,
            x if x == llama_cpp_bindings_sys::GGML_TYPE_Q8_K => Self::Q8_K,
            x if x == llama_cpp_bindings_sys::GGML_TYPE_IQ2_XXS => Self::IQ2_XXS,
            x if x == llama_cpp_bindings_sys::GGML_TYPE_IQ2_XS => Self::IQ2_XS,
            x if x == llama_cpp_bindings_sys::GGML_TYPE_IQ3_XXS => Self::IQ3_XXS,
            x if x == llama_cpp_bindings_sys::GGML_TYPE_IQ1_S => Self::IQ1_S,
            x if x == llama_cpp_bindings_sys::GGML_TYPE_IQ4_NL => Self::IQ4_NL,
            x if x == llama_cpp_bindings_sys::GGML_TYPE_IQ3_S => Self::IQ3_S,
            x if x == llama_cpp_bindings_sys::GGML_TYPE_IQ2_S => Self::IQ2_S,
            x if x == llama_cpp_bindings_sys::GGML_TYPE_IQ4_XS => Self::IQ4_XS,
            x if x == llama_cpp_bindings_sys::GGML_TYPE_I8 => Self::I8,
            x if x == llama_cpp_bindings_sys::GGML_TYPE_I16 => Self::I16,
            x if x == llama_cpp_bindings_sys::GGML_TYPE_I32 => Self::I32,
            x if x == llama_cpp_bindings_sys::GGML_TYPE_I64 => Self::I64,
            x if x == llama_cpp_bindings_sys::GGML_TYPE_F64 => Self::F64,
            x if x == llama_cpp_bindings_sys::GGML_TYPE_IQ1_M => Self::IQ1_M,
            x if x == llama_cpp_bindings_sys::GGML_TYPE_BF16 => Self::BF16,
            x if x == llama_cpp_bindings_sys::GGML_TYPE_TQ1_0 => Self::TQ1_0,
            x if x == llama_cpp_bindings_sys::GGML_TYPE_TQ2_0 => Self::TQ2_0,
            x if x == llama_cpp_bindings_sys::GGML_TYPE_MXFP4 => Self::MXFP4,
            _ => Self::Unknown(value),
        }
    }
}
246
/// Builder-style wrapper around llama.cpp's raw `llama_context_params`
/// struct, configured via chained `with_*` methods.
#[derive(Debug, Clone)]
#[allow(
    missing_docs,
    clippy::struct_excessive_bools,
    clippy::module_name_repetitions
)]
pub struct LlamaContextParams {
    /// The raw FFI parameter struct. Public, so callers can reach fields
    /// this wrapper does not yet expose.
    pub context_params: llama_cpp_bindings_sys::llama_context_params,
}
271
// SAFETY: `llama_context_params` is a plain-data FFI struct. It does,
// however, carry raw pointers (`cb_eval_user_data`, `cb_eval`), so these
// impls are sound only if whatever those pointers reference is itself safe
// to access from other threads — NOTE(review): this is a caller obligation;
// confirm it is documented at the call sites that set the callback.
unsafe impl Send for LlamaContextParams {}
unsafe impl Sync for LlamaContextParams {}
275
276impl LlamaContextParams {
277 #[must_use]
289 pub fn with_n_ctx(mut self, n_ctx: Option<NonZeroU32>) -> Self {
290 self.context_params.n_ctx = n_ctx.map_or(0, NonZeroU32::get);
291 self
292 }
293
294 #[must_use]
304 pub const fn n_ctx(&self) -> Option<NonZeroU32> {
305 NonZeroU32::new(self.context_params.n_ctx)
306 }
307
308 #[must_use]
320 pub const fn with_n_batch(mut self, n_batch: u32) -> Self {
321 self.context_params.n_batch = n_batch;
322 self
323 }
324
325 #[must_use]
335 pub const fn n_batch(&self) -> u32 {
336 self.context_params.n_batch
337 }
338
339 #[must_use]
351 pub const fn with_n_ubatch(mut self, n_ubatch: u32) -> Self {
352 self.context_params.n_ubatch = n_ubatch;
353 self
354 }
355
356 #[must_use]
366 pub const fn n_ubatch(&self) -> u32 {
367 self.context_params.n_ubatch
368 }
369
370 #[must_use]
372 pub const fn with_flash_attention_policy(
373 mut self,
374 policy: llama_cpp_bindings_sys::llama_flash_attn_type,
375 ) -> Self {
376 self.context_params.flash_attn_type = policy;
377 self
378 }
379
380 #[must_use]
382 pub const fn flash_attention_policy(&self) -> llama_cpp_bindings_sys::llama_flash_attn_type {
383 self.context_params.flash_attn_type
384 }
385
386 #[must_use]
397 pub const fn with_offload_kqv(mut self, enabled: bool) -> Self {
398 self.context_params.offload_kqv = enabled;
399 self
400 }
401
402 #[must_use]
412 pub const fn offload_kqv(&self) -> bool {
413 self.context_params.offload_kqv
414 }
415
416 #[must_use]
427 pub fn with_rope_scaling_type(mut self, rope_scaling_type: RopeScalingType) -> Self {
428 self.context_params.rope_scaling_type = i32::from(rope_scaling_type);
429 self
430 }
431
432 #[must_use]
441 pub fn rope_scaling_type(&self) -> RopeScalingType {
442 RopeScalingType::from(self.context_params.rope_scaling_type)
443 }
444
445 #[must_use]
456 pub const fn with_rope_freq_base(mut self, rope_freq_base: f32) -> Self {
457 self.context_params.rope_freq_base = rope_freq_base;
458 self
459 }
460
461 #[must_use]
470 pub const fn rope_freq_base(&self) -> f32 {
471 self.context_params.rope_freq_base
472 }
473
474 #[must_use]
485 pub const fn with_rope_freq_scale(mut self, rope_freq_scale: f32) -> Self {
486 self.context_params.rope_freq_scale = rope_freq_scale;
487 self
488 }
489
490 #[must_use]
499 pub const fn rope_freq_scale(&self) -> f32 {
500 self.context_params.rope_freq_scale
501 }
502
503 #[must_use]
512 pub const fn n_threads(&self) -> i32 {
513 self.context_params.n_threads
514 }
515
516 #[must_use]
525 pub const fn n_threads_batch(&self) -> i32 {
526 self.context_params.n_threads_batch
527 }
528
529 #[must_use]
540 pub const fn with_n_threads(mut self, n_threads: i32) -> Self {
541 self.context_params.n_threads = n_threads;
542 self
543 }
544
545 #[must_use]
556 pub const fn with_n_threads_batch(mut self, n_threads: i32) -> Self {
557 self.context_params.n_threads_batch = n_threads;
558 self
559 }
560
561 #[must_use]
570 pub const fn embeddings(&self) -> bool {
571 self.context_params.embeddings
572 }
573
574 #[must_use]
585 pub const fn with_embeddings(mut self, embedding: bool) -> Self {
586 self.context_params.embeddings = embedding;
587 self
588 }
589
590 #[must_use]
607 pub fn with_cb_eval(
608 mut self,
609 cb_eval: llama_cpp_bindings_sys::ggml_backend_sched_eval_callback,
610 ) -> Self {
611 self.context_params.cb_eval = cb_eval;
612 self
613 }
614
615 #[must_use]
626 pub const fn with_cb_eval_user_data(
627 mut self,
628 cb_eval_user_data: *mut std::ffi::c_void,
629 ) -> Self {
630 self.context_params.cb_eval_user_data = cb_eval_user_data;
631 self
632 }
633
634 #[must_use]
645 pub fn with_pooling_type(mut self, pooling_type: LlamaPoolingType) -> Self {
646 self.context_params.pooling_type = i32::from(pooling_type);
647 self
648 }
649
650 #[must_use]
659 pub fn pooling_type(&self) -> LlamaPoolingType {
660 LlamaPoolingType::from(self.context_params.pooling_type)
661 }
662
663 #[must_use]
674 pub const fn with_swa_full(mut self, enabled: bool) -> Self {
675 self.context_params.swa_full = enabled;
676 self
677 }
678
679 #[must_use]
689 pub const fn swa_full(&self) -> bool {
690 self.context_params.swa_full
691 }
692
693 #[must_use]
704 pub const fn with_n_seq_max(mut self, n_seq_max: u32) -> Self {
705 self.context_params.n_seq_max = n_seq_max;
706 self
707 }
708
709 #[must_use]
719 pub const fn n_seq_max(&self) -> u32 {
720 self.context_params.n_seq_max
721 }
722 #[must_use]
728 pub fn with_type_k(mut self, type_k: KvCacheType) -> Self {
729 self.context_params.type_k = type_k.into();
730 self
731 }
732
733 #[must_use]
742 pub fn type_k(&self) -> KvCacheType {
743 KvCacheType::from(self.context_params.type_k)
744 }
745
746 #[must_use]
756 pub fn with_type_v(mut self, type_v: KvCacheType) -> Self {
757 self.context_params.type_v = type_v.into();
758 self
759 }
760
761 #[must_use]
770 pub fn type_v(&self) -> KvCacheType {
771 KvCacheType::from(self.context_params.type_v)
772 }
773
774 #[must_use]
785 pub fn with_attention_type(mut self, attention_type: LlamaAttentionType) -> Self {
786 self.context_params.attention_type = i32::from(attention_type);
787 self
788 }
789
790 #[must_use]
799 pub fn attention_type(&self) -> LlamaAttentionType {
800 LlamaAttentionType::from(self.context_params.attention_type)
801 }
802
803 #[must_use]
814 pub const fn with_yarn_ext_factor(mut self, yarn_ext_factor: f32) -> Self {
815 self.context_params.yarn_ext_factor = yarn_ext_factor;
816 self
817 }
818
819 #[must_use]
821 pub const fn yarn_ext_factor(&self) -> f32 {
822 self.context_params.yarn_ext_factor
823 }
824
825 #[must_use]
836 pub const fn with_yarn_attn_factor(mut self, yarn_attn_factor: f32) -> Self {
837 self.context_params.yarn_attn_factor = yarn_attn_factor;
838 self
839 }
840
841 #[must_use]
843 pub const fn yarn_attn_factor(&self) -> f32 {
844 self.context_params.yarn_attn_factor
845 }
846
847 #[must_use]
858 pub const fn with_yarn_beta_fast(mut self, yarn_beta_fast: f32) -> Self {
859 self.context_params.yarn_beta_fast = yarn_beta_fast;
860 self
861 }
862
863 #[must_use]
865 pub const fn yarn_beta_fast(&self) -> f32 {
866 self.context_params.yarn_beta_fast
867 }
868
869 #[must_use]
880 pub const fn with_yarn_beta_slow(mut self, yarn_beta_slow: f32) -> Self {
881 self.context_params.yarn_beta_slow = yarn_beta_slow;
882 self
883 }
884
885 #[must_use]
887 pub const fn yarn_beta_slow(&self) -> f32 {
888 self.context_params.yarn_beta_slow
889 }
890
891 #[must_use]
902 pub const fn with_yarn_orig_ctx(mut self, yarn_orig_ctx: u32) -> Self {
903 self.context_params.yarn_orig_ctx = yarn_orig_ctx;
904 self
905 }
906
907 #[must_use]
909 pub const fn yarn_orig_ctx(&self) -> u32 {
910 self.context_params.yarn_orig_ctx
911 }
912
913 #[must_use]
924 pub const fn with_defrag_thold(mut self, defrag_thold: f32) -> Self {
925 self.context_params.defrag_thold = defrag_thold;
926 self
927 }
928
929 #[must_use]
931 pub const fn defrag_thold(&self) -> f32 {
932 self.context_params.defrag_thold
933 }
934
935 #[must_use]
946 pub const fn with_no_perf(mut self, no_perf: bool) -> Self {
947 self.context_params.no_perf = no_perf;
948 self
949 }
950
951 #[must_use]
953 pub const fn no_perf(&self) -> bool {
954 self.context_params.no_perf
955 }
956
957 #[must_use]
968 pub const fn with_op_offload(mut self, op_offload: bool) -> Self {
969 self.context_params.op_offload = op_offload;
970 self
971 }
972
973 #[must_use]
975 pub const fn op_offload(&self) -> bool {
976 self.context_params.op_offload
977 }
978
979 #[must_use]
990 pub const fn with_kv_unified(mut self, kv_unified: bool) -> Self {
991 self.context_params.kv_unified = kv_unified;
992 self
993 }
994
995 #[must_use]
997 pub const fn kv_unified(&self) -> bool {
998 self.context_params.kv_unified
999 }
1000}
1001
impl Default for LlamaContextParams {
    /// Builds the parameters from llama.cpp's own defaults.
    fn default() -> Self {
        // SAFETY: zero-argument FFI call that returns the parameter struct by
        // value; soundness rests on the upstream `llama_context_default_params`
        // contract (no preconditions) — NOTE(review): relies on llama.cpp.
        let context_params = unsafe { llama_cpp_bindings_sys::llama_context_default_params() };
        Self { context_params }
    }
}
1016
// NOTE: tests constructing `LlamaContextParams` go through `Default`, which
// calls `llama_context_default_params` over FFI, so they need the native
// llama.cpp library linked in.
#[cfg(test)]
mod tests {
    use super::{KvCacheType, LlamaAttentionType, LlamaPoolingType, RopeScalingType};

    // --- enum decode fallbacks and round trips (pure Rust, no FFI) ---

    #[test]
    fn rope_scaling_type_unknown_defaults_to_unspecified() {
        assert_eq!(RopeScalingType::from(99), RopeScalingType::Unspecified);
        assert_eq!(RopeScalingType::from(-100), RopeScalingType::Unspecified);
    }

    #[test]
    fn pooling_type_unknown_defaults_to_unspecified() {
        assert_eq!(LlamaPoolingType::from(99), LlamaPoolingType::Unspecified);
        assert_eq!(LlamaPoolingType::from(-50), LlamaPoolingType::Unspecified);
    }

    #[test]
    fn kv_cache_type_unknown_preserves_raw_value() {
        let unknown_raw: llama_cpp_bindings_sys::ggml_type = 99999;
        let cache_type = KvCacheType::from(unknown_raw);

        assert_eq!(cache_type, KvCacheType::Unknown(99999));

        let back: llama_cpp_bindings_sys::ggml_type = cache_type.into();

        assert_eq!(back, 99999);
    }

    // --- builder getters/setters (FFI-backed defaults) ---

    #[test]
    fn default_params_have_expected_values() {
        let params = super::LlamaContextParams::default();

        // These pin llama.cpp's current upstream defaults.
        assert_eq!(params.n_ctx(), std::num::NonZeroU32::new(512));
        assert_eq!(params.n_batch(), 2048);
        assert_eq!(params.n_ubatch(), 512);
        assert_eq!(params.rope_scaling_type(), RopeScalingType::Unspecified);
        assert_eq!(params.pooling_type(), LlamaPoolingType::Unspecified);
    }

    #[test]
    fn with_n_ctx_sets_value() {
        let params =
            super::LlamaContextParams::default().with_n_ctx(std::num::NonZeroU32::new(2048));

        assert_eq!(params.n_ctx(), std::num::NonZeroU32::new(2048));
    }

    #[test]
    fn with_n_ctx_none_sets_zero() {
        let params = super::LlamaContextParams::default().with_n_ctx(None);

        assert_eq!(params.n_ctx(), None);
    }

    #[test]
    fn with_n_batch_sets_value() {
        let params = super::LlamaContextParams::default().with_n_batch(4096);

        assert_eq!(params.n_batch(), 4096);
    }

    #[test]
    fn with_n_ubatch_sets_value() {
        let params = super::LlamaContextParams::default().with_n_ubatch(1024);

        assert_eq!(params.n_ubatch(), 1024);
    }

    #[test]
    fn with_n_seq_max_sets_value() {
        let params = super::LlamaContextParams::default().with_n_seq_max(64);

        assert_eq!(params.n_seq_max(), 64);
    }

    #[test]
    fn with_embeddings_enables() {
        let params = super::LlamaContextParams::default().with_embeddings(true);

        assert!(params.embeddings());
    }

    #[test]
    fn with_embeddings_disables() {
        let params = super::LlamaContextParams::default().with_embeddings(false);

        assert!(!params.embeddings());
    }

    #[test]
    fn with_offload_kqv_disables() {
        let params = super::LlamaContextParams::default().with_offload_kqv(false);

        assert!(!params.offload_kqv());
    }

    #[test]
    fn with_offload_kqv_enables() {
        let params = super::LlamaContextParams::default().with_offload_kqv(true);

        assert!(params.offload_kqv());
    }

    #[test]
    fn with_swa_full_disables() {
        let params = super::LlamaContextParams::default().with_swa_full(false);

        assert!(!params.swa_full());
    }

    #[test]
    fn with_swa_full_enables() {
        let params = super::LlamaContextParams::default().with_swa_full(true);

        assert!(params.swa_full());
    }

    #[test]
    fn with_rope_scaling_type_linear() {
        let params =
            super::LlamaContextParams::default().with_rope_scaling_type(RopeScalingType::Linear);

        assert_eq!(params.rope_scaling_type(), RopeScalingType::Linear);
    }

    #[test]
    fn with_rope_scaling_type_yarn() {
        let params =
            super::LlamaContextParams::default().with_rope_scaling_type(RopeScalingType::Yarn);

        assert_eq!(params.rope_scaling_type(), RopeScalingType::Yarn);
    }

    #[test]
    fn with_rope_scaling_type_none() {
        let params =
            super::LlamaContextParams::default().with_rope_scaling_type(RopeScalingType::None);

        assert_eq!(params.rope_scaling_type(), RopeScalingType::None);
    }

    #[test]
    fn with_rope_freq_base_sets_value() {
        let params = super::LlamaContextParams::default().with_rope_freq_base(10000.0);

        assert!((params.rope_freq_base() - 10000.0).abs() < f32::EPSILON);
    }

    #[test]
    fn with_rope_freq_scale_sets_value() {
        let params = super::LlamaContextParams::default().with_rope_freq_scale(0.5);

        assert!((params.rope_freq_scale() - 0.5).abs() < f32::EPSILON);
    }

    #[test]
    fn with_n_threads_sets_value() {
        let params = super::LlamaContextParams::default().with_n_threads(16);

        assert_eq!(params.n_threads(), 16);
    }

    #[test]
    fn with_n_threads_batch_sets_value() {
        let params = super::LlamaContextParams::default().with_n_threads_batch(16);

        assert_eq!(params.n_threads_batch(), 16);
    }

    #[test]
    fn with_pooling_type_mean() {
        let params = super::LlamaContextParams::default().with_pooling_type(LlamaPoolingType::Mean);

        assert_eq!(params.pooling_type(), LlamaPoolingType::Mean);
    }

    #[test]
    fn with_pooling_type_cls() {
        let params = super::LlamaContextParams::default().with_pooling_type(LlamaPoolingType::Cls);

        assert_eq!(params.pooling_type(), LlamaPoolingType::Cls);
    }

    #[test]
    fn with_pooling_type_last() {
        let params = super::LlamaContextParams::default().with_pooling_type(LlamaPoolingType::Last);

        assert_eq!(params.pooling_type(), LlamaPoolingType::Last);
    }

    #[test]
    fn with_pooling_type_rank() {
        let params = super::LlamaContextParams::default().with_pooling_type(LlamaPoolingType::Rank);

        assert_eq!(params.pooling_type(), LlamaPoolingType::Rank);
    }

    #[test]
    fn with_pooling_type_none() {
        let params = super::LlamaContextParams::default().with_pooling_type(LlamaPoolingType::None);

        assert_eq!(params.pooling_type(), LlamaPoolingType::None);
    }

    #[test]
    fn with_type_k_sets_value() {
        let params = super::LlamaContextParams::default().with_type_k(KvCacheType::Q4_0);

        assert_eq!(params.type_k(), KvCacheType::Q4_0);
    }

    #[test]
    fn with_type_v_sets_value() {
        let params = super::LlamaContextParams::default().with_type_v(KvCacheType::Q4_1);

        assert_eq!(params.type_v(), KvCacheType::Q4_1);
    }

    #[test]
    fn with_flash_attention_policy_sets_value() {
        let params = super::LlamaContextParams::default()
            .with_flash_attention_policy(llama_cpp_bindings_sys::LLAMA_FLASH_ATTN_TYPE_ENABLED);

        assert_eq!(
            params.flash_attention_policy(),
            llama_cpp_bindings_sys::LLAMA_FLASH_ATTN_TYPE_ENABLED
        );
    }

    #[test]
    fn builder_chaining_preserves_all_values() {
        let params = super::LlamaContextParams::default()
            .with_n_ctx(std::num::NonZeroU32::new(1024))
            .with_n_batch(4096)
            .with_n_ubatch(256)
            .with_n_threads(8)
            .with_n_threads_batch(12)
            .with_embeddings(true)
            .with_offload_kqv(false)
            .with_rope_scaling_type(RopeScalingType::Yarn)
            .with_rope_freq_base(5000.0)
            .with_rope_freq_scale(0.25);

        assert_eq!(params.n_ctx(), std::num::NonZeroU32::new(1024));
        assert_eq!(params.n_batch(), 4096);
        assert_eq!(params.n_ubatch(), 256);
        assert_eq!(params.n_threads(), 8);
        assert_eq!(params.n_threads_batch(), 12);
        assert!(params.embeddings());
        assert!(!params.offload_kqv());
        assert_eq!(params.rope_scaling_type(), RopeScalingType::Yarn);
        assert!((params.rope_freq_base() - 5000.0).abs() < f32::EPSILON);
        assert!((params.rope_freq_scale() - 0.25).abs() < f32::EPSILON);
    }

    #[test]
    fn rope_scaling_type_roundtrip_all_variants() {
        for (raw, expected) in [
            (-1, RopeScalingType::Unspecified),
            (0, RopeScalingType::None),
            (1, RopeScalingType::Linear),
            (2, RopeScalingType::Yarn),
        ] {
            let from_raw = RopeScalingType::from(raw);
            assert_eq!(from_raw, expected);

            let back_to_raw: i32 = from_raw.into();
            assert_eq!(back_to_raw, raw);
        }
    }

    #[test]
    fn pooling_type_roundtrip_all_variants() {
        for (raw, expected) in [
            (-1, LlamaPoolingType::Unspecified),
            (0, LlamaPoolingType::None),
            (1, LlamaPoolingType::Mean),
            (2, LlamaPoolingType::Cls),
            (3, LlamaPoolingType::Last),
            (4, LlamaPoolingType::Rank),
        ] {
            let from_raw = LlamaPoolingType::from(raw);
            assert_eq!(from_raw, expected);

            let back_to_raw: i32 = from_raw.into();
            assert_eq!(back_to_raw, raw);
        }
    }

    #[test]
    fn kv_cache_type_all_known_variants_roundtrip() {
        let all_variants = [
            KvCacheType::F32,
            KvCacheType::F16,
            KvCacheType::Q4_0,
            KvCacheType::Q4_1,
            KvCacheType::Q5_0,
            KvCacheType::Q5_1,
            KvCacheType::Q8_0,
            KvCacheType::Q8_1,
            KvCacheType::Q2_K,
            KvCacheType::Q3_K,
            KvCacheType::Q4_K,
            KvCacheType::Q5_K,
            KvCacheType::Q6_K,
            KvCacheType::Q8_K,
            KvCacheType::IQ2_XXS,
            KvCacheType::IQ2_XS,
            KvCacheType::IQ3_XXS,
            KvCacheType::IQ1_S,
            KvCacheType::IQ4_NL,
            KvCacheType::IQ3_S,
            KvCacheType::IQ2_S,
            KvCacheType::IQ4_XS,
            KvCacheType::I8,
            KvCacheType::I16,
            KvCacheType::I32,
            KvCacheType::I64,
            KvCacheType::F64,
            KvCacheType::IQ1_M,
            KvCacheType::BF16,
            KvCacheType::TQ1_0,
            KvCacheType::TQ2_0,
            KvCacheType::MXFP4,
        ];

        for variant in all_variants {
            let ggml_type: llama_cpp_bindings_sys::ggml_type = variant.into();
            let back = KvCacheType::from(ggml_type);

            assert_eq!(back, variant);
        }
    }

    // --- callback plumbing ---

    #[test]
    fn with_cb_eval_sets_callback() {
        extern "C" fn test_cb_eval(
            _tensor: *mut llama_cpp_bindings_sys::ggml_tensor,
            _ask: bool,
            _user_data: *mut std::ffi::c_void,
        ) -> bool {
            false
        }

        let result = test_cb_eval(std::ptr::null_mut(), false, std::ptr::null_mut());

        assert!(!result);

        let params = super::LlamaContextParams::default().with_cb_eval(Some(test_cb_eval));

        assert!(params.context_params.cb_eval.is_some());
    }

    #[test]
    fn with_cb_eval_user_data_sets_pointer() {
        let mut value: i32 = 42;
        let user_data = (&raw mut value).cast::<std::ffi::c_void>();
        let params = super::LlamaContextParams::default().with_cb_eval_user_data(user_data);

        assert_eq!(params.context_params.cb_eval_user_data, user_data);
    }

    #[test]
    fn with_flash_attention_policy_disabled() {
        let params = super::LlamaContextParams::default()
            .with_flash_attention_policy(llama_cpp_bindings_sys::LLAMA_FLASH_ATTN_TYPE_DISABLED);

        assert_eq!(
            params.flash_attention_policy(),
            llama_cpp_bindings_sys::LLAMA_FLASH_ATTN_TYPE_DISABLED
        );
    }

    #[test]
    fn attention_type_unknown_defaults_to_unspecified() {
        assert_eq!(
            LlamaAttentionType::from(99),
            LlamaAttentionType::Unspecified
        );
        assert_eq!(
            LlamaAttentionType::from(-50),
            LlamaAttentionType::Unspecified
        );
    }

    #[test]
    fn attention_type_roundtrip_all_variants() {
        for (raw, expected) in [
            (-1, LlamaAttentionType::Unspecified),
            (0, LlamaAttentionType::Causal),
            (1, LlamaAttentionType::NonCausal),
        ] {
            let from_raw = LlamaAttentionType::from(raw);
            assert_eq!(from_raw, expected);

            let back_to_raw: i32 = from_raw.into();
            assert_eq!(back_to_raw, raw);
        }
    }

    #[test]
    fn with_attention_type_causal() {
        let params =
            super::LlamaContextParams::default().with_attention_type(LlamaAttentionType::Causal);

        assert_eq!(params.attention_type(), LlamaAttentionType::Causal);
    }

    #[test]
    fn with_attention_type_non_causal() {
        let params =
            super::LlamaContextParams::default().with_attention_type(LlamaAttentionType::NonCausal);

        assert_eq!(params.attention_type(), LlamaAttentionType::NonCausal);
    }

    // --- YaRN / misc scalar fields ---

    #[test]
    fn with_yarn_ext_factor_sets_value() {
        let params = super::LlamaContextParams::default().with_yarn_ext_factor(1.5);

        assert!((params.yarn_ext_factor() - 1.5).abs() < f32::EPSILON);
    }

    #[test]
    fn with_yarn_attn_factor_sets_value() {
        let params = super::LlamaContextParams::default().with_yarn_attn_factor(2.0);

        assert!((params.yarn_attn_factor() - 2.0).abs() < f32::EPSILON);
    }

    #[test]
    fn with_yarn_beta_fast_sets_value() {
        let params = super::LlamaContextParams::default().with_yarn_beta_fast(32.0);

        assert!((params.yarn_beta_fast() - 32.0).abs() < f32::EPSILON);
    }

    #[test]
    fn with_yarn_beta_slow_sets_value() {
        let params = super::LlamaContextParams::default().with_yarn_beta_slow(1.0);

        assert!((params.yarn_beta_slow() - 1.0).abs() < f32::EPSILON);
    }

    #[test]
    fn with_yarn_orig_ctx_sets_value() {
        let params = super::LlamaContextParams::default().with_yarn_orig_ctx(4096);

        assert_eq!(params.yarn_orig_ctx(), 4096);
    }

    #[test]
    fn with_defrag_thold_sets_value() {
        let params = super::LlamaContextParams::default().with_defrag_thold(0.1);

        assert!((params.defrag_thold() - 0.1).abs() < f32::EPSILON);
    }

    #[test]
    fn with_no_perf_enables() {
        let params = super::LlamaContextParams::default().with_no_perf(true);

        assert!(params.no_perf());
    }

    #[test]
    fn with_no_perf_disables() {
        let params = super::LlamaContextParams::default().with_no_perf(false);

        assert!(!params.no_perf());
    }

    #[test]
    fn with_op_offload_enables() {
        let params = super::LlamaContextParams::default().with_op_offload(true);

        assert!(params.op_offload());
    }

    #[test]
    fn with_op_offload_disables() {
        let params = super::LlamaContextParams::default().with_op_offload(false);

        assert!(!params.op_offload());
    }

    #[test]
    fn with_kv_unified_enables() {
        let params = super::LlamaContextParams::default().with_kv_unified(true);

        assert!(params.kv_unified());
    }

    #[test]
    fn with_kv_unified_disables() {
        let params = super::LlamaContextParams::default().with_kv_unified(false);

        assert!(!params.kv_unified());
    }
}