1use std::fmt::Debug;
3use std::num::NonZeroU32;
4
/// Typed view of the raw `rope_scaling_type` integer stored in `llama_context_params`.
///
/// Each variant's discriminant is the raw value it stands for. Raw values that match no
/// variant are folded into [`RopeScalingType::Unspecified`] by the `From<i32>` conversion.
#[derive(Debug, Clone, Copy, Eq, PartialEq)]
#[repr(i8)]
pub enum RopeScalingType {
    /// Raw value `-1`; also the catch-all for unrecognised raw values.
    Unspecified = -1,
    /// Raw value `0`.
    None = 0,
    /// Raw value `1`.
    Linear = 1,
    /// Raw value `2`.
    Yarn = 2,
}
18
19impl From<i32> for RopeScalingType {
22 fn from(value: i32) -> Self {
23 match value {
24 0 => Self::None,
25 1 => Self::Linear,
26 2 => Self::Yarn,
27 _ => Self::Unspecified,
28 }
29 }
30}
31
32impl From<RopeScalingType> for i32 {
34 fn from(value: RopeScalingType) -> Self {
35 match value {
36 RopeScalingType::None => 0,
37 RopeScalingType::Linear => 1,
38 RopeScalingType::Yarn => 2,
39 RopeScalingType::Unspecified => -1,
40 }
41 }
42}
43
/// Typed view of the raw `pooling_type` integer stored in `llama_context_params`.
///
/// Each variant's discriminant is the raw value it stands for. Raw values that match no
/// variant are folded into [`LlamaPoolingType::Unspecified`] by the `From<i32>` conversion.
#[derive(Debug, Clone, Copy, Eq, PartialEq)]
#[repr(i8)]
pub enum LlamaPoolingType {
    /// Raw value `-1`; also the catch-all for unrecognised raw values.
    Unspecified = -1,
    /// Raw value `0`.
    None = 0,
    /// Raw value `1`.
    Mean = 1,
    /// Raw value `2`.
    Cls = 2,
    /// Raw value `3`.
    Last = 3,
    /// Raw value `4`.
    Rank = 4,
}
61
62impl From<i32> for LlamaPoolingType {
65 fn from(value: i32) -> Self {
66 match value {
67 0 => Self::None,
68 1 => Self::Mean,
69 2 => Self::Cls,
70 3 => Self::Last,
71 4 => Self::Rank,
72 _ => Self::Unspecified,
73 }
74 }
75}
76
77impl From<LlamaPoolingType> for i32 {
79 fn from(value: LlamaPoolingType) -> Self {
80 match value {
81 LlamaPoolingType::None => 0,
82 LlamaPoolingType::Mean => 1,
83 LlamaPoolingType::Cls => 2,
84 LlamaPoolingType::Last => 3,
85 LlamaPoolingType::Rank => 4,
86 LlamaPoolingType::Unspecified => -1,
87 }
88 }
89}
90
/// Typed view of the raw `attention_type` integer stored in `llama_context_params`.
///
/// Each variant's discriminant is the raw value it stands for. Raw values that match no
/// variant are folded into [`LlamaAttentionType::Unspecified`] by the `From<i32>` conversion.
#[derive(Debug, Clone, Copy, Eq, PartialEq)]
#[repr(i8)]
pub enum LlamaAttentionType {
    /// Raw value `-1`; also the catch-all for unrecognised raw values.
    Unspecified = -1,
    /// Raw value `0`.
    Causal = 0,
    /// Raw value `1`.
    NonCausal = 1,
}
102
103impl From<i32> for LlamaAttentionType {
104 fn from(value: i32) -> Self {
105 match value {
106 0 => Self::Causal,
107 1 => Self::NonCausal,
108 _ => Self::Unspecified,
109 }
110 }
111}
112
113impl From<LlamaAttentionType> for i32 {
114 fn from(value: LlamaAttentionType) -> Self {
115 match value {
116 LlamaAttentionType::Causal => 0,
117 LlamaAttentionType::NonCausal => 1,
118 LlamaAttentionType::Unspecified => -1,
119 }
120 }
121}
122
123#[expect(
125 non_camel_case_types,
126 reason = "variant names mirror llama.cpp's `enum ggml_type` symbol names verbatim so they can \
127 be matched 1:1 against the C ABI without a translation table"
128)]
129#[expect(
130 missing_docs,
131 reason = "each variant denotes a quantisation flavour whose semantics are defined upstream in \
132 ggml; restating the upstream spec inline would risk drifting from the source of truth"
133)]
134#[derive(Copy, Clone, Debug, PartialEq, Eq)]
135pub enum KvCacheType {
136 Unknown(llama_cpp_bindings_sys::ggml_type),
142 F32,
143 F16,
144 Q4_0,
145 Q4_1,
146 Q5_0,
147 Q5_1,
148 Q8_0,
149 Q8_1,
150 Q2_K,
151 Q3_K,
152 Q4_K,
153 Q5_K,
154 Q6_K,
155 Q8_K,
156 IQ2_XXS,
157 IQ2_XS,
158 IQ3_XXS,
159 IQ1_S,
160 IQ4_NL,
161 IQ3_S,
162 IQ2_S,
163 IQ4_XS,
164 I8,
165 I16,
166 I32,
167 I64,
168 F64,
169 IQ1_M,
170 BF16,
171 TQ1_0,
172 TQ2_0,
173 MXFP4,
174}
175
176impl From<KvCacheType> for llama_cpp_bindings_sys::ggml_type {
177 fn from(value: KvCacheType) -> Self {
178 match value {
179 KvCacheType::Unknown(raw) => raw,
180 KvCacheType::F32 => llama_cpp_bindings_sys::GGML_TYPE_F32,
181 KvCacheType::F16 => llama_cpp_bindings_sys::GGML_TYPE_F16,
182 KvCacheType::Q4_0 => llama_cpp_bindings_sys::GGML_TYPE_Q4_0,
183 KvCacheType::Q4_1 => llama_cpp_bindings_sys::GGML_TYPE_Q4_1,
184 KvCacheType::Q5_0 => llama_cpp_bindings_sys::GGML_TYPE_Q5_0,
185 KvCacheType::Q5_1 => llama_cpp_bindings_sys::GGML_TYPE_Q5_1,
186 KvCacheType::Q8_0 => llama_cpp_bindings_sys::GGML_TYPE_Q8_0,
187 KvCacheType::Q8_1 => llama_cpp_bindings_sys::GGML_TYPE_Q8_1,
188 KvCacheType::Q2_K => llama_cpp_bindings_sys::GGML_TYPE_Q2_K,
189 KvCacheType::Q3_K => llama_cpp_bindings_sys::GGML_TYPE_Q3_K,
190 KvCacheType::Q4_K => llama_cpp_bindings_sys::GGML_TYPE_Q4_K,
191 KvCacheType::Q5_K => llama_cpp_bindings_sys::GGML_TYPE_Q5_K,
192 KvCacheType::Q6_K => llama_cpp_bindings_sys::GGML_TYPE_Q6_K,
193 KvCacheType::Q8_K => llama_cpp_bindings_sys::GGML_TYPE_Q8_K,
194 KvCacheType::IQ2_XXS => llama_cpp_bindings_sys::GGML_TYPE_IQ2_XXS,
195 KvCacheType::IQ2_XS => llama_cpp_bindings_sys::GGML_TYPE_IQ2_XS,
196 KvCacheType::IQ3_XXS => llama_cpp_bindings_sys::GGML_TYPE_IQ3_XXS,
197 KvCacheType::IQ1_S => llama_cpp_bindings_sys::GGML_TYPE_IQ1_S,
198 KvCacheType::IQ4_NL => llama_cpp_bindings_sys::GGML_TYPE_IQ4_NL,
199 KvCacheType::IQ3_S => llama_cpp_bindings_sys::GGML_TYPE_IQ3_S,
200 KvCacheType::IQ2_S => llama_cpp_bindings_sys::GGML_TYPE_IQ2_S,
201 KvCacheType::IQ4_XS => llama_cpp_bindings_sys::GGML_TYPE_IQ4_XS,
202 KvCacheType::I8 => llama_cpp_bindings_sys::GGML_TYPE_I8,
203 KvCacheType::I16 => llama_cpp_bindings_sys::GGML_TYPE_I16,
204 KvCacheType::I32 => llama_cpp_bindings_sys::GGML_TYPE_I32,
205 KvCacheType::I64 => llama_cpp_bindings_sys::GGML_TYPE_I64,
206 KvCacheType::F64 => llama_cpp_bindings_sys::GGML_TYPE_F64,
207 KvCacheType::IQ1_M => llama_cpp_bindings_sys::GGML_TYPE_IQ1_M,
208 KvCacheType::BF16 => llama_cpp_bindings_sys::GGML_TYPE_BF16,
209 KvCacheType::TQ1_0 => llama_cpp_bindings_sys::GGML_TYPE_TQ1_0,
210 KvCacheType::TQ2_0 => llama_cpp_bindings_sys::GGML_TYPE_TQ2_0,
211 KvCacheType::MXFP4 => llama_cpp_bindings_sys::GGML_TYPE_MXFP4,
212 }
213 }
214}
215
216impl From<llama_cpp_bindings_sys::ggml_type> for KvCacheType {
217 fn from(value: llama_cpp_bindings_sys::ggml_type) -> Self {
218 match value {
219 x if x == llama_cpp_bindings_sys::GGML_TYPE_F32 => Self::F32,
220 x if x == llama_cpp_bindings_sys::GGML_TYPE_F16 => Self::F16,
221 x if x == llama_cpp_bindings_sys::GGML_TYPE_Q4_0 => Self::Q4_0,
222 x if x == llama_cpp_bindings_sys::GGML_TYPE_Q4_1 => Self::Q4_1,
223 x if x == llama_cpp_bindings_sys::GGML_TYPE_Q5_0 => Self::Q5_0,
224 x if x == llama_cpp_bindings_sys::GGML_TYPE_Q5_1 => Self::Q5_1,
225 x if x == llama_cpp_bindings_sys::GGML_TYPE_Q8_0 => Self::Q8_0,
226 x if x == llama_cpp_bindings_sys::GGML_TYPE_Q8_1 => Self::Q8_1,
227 x if x == llama_cpp_bindings_sys::GGML_TYPE_Q2_K => Self::Q2_K,
228 x if x == llama_cpp_bindings_sys::GGML_TYPE_Q3_K => Self::Q3_K,
229 x if x == llama_cpp_bindings_sys::GGML_TYPE_Q4_K => Self::Q4_K,
230 x if x == llama_cpp_bindings_sys::GGML_TYPE_Q5_K => Self::Q5_K,
231 x if x == llama_cpp_bindings_sys::GGML_TYPE_Q6_K => Self::Q6_K,
232 x if x == llama_cpp_bindings_sys::GGML_TYPE_Q8_K => Self::Q8_K,
233 x if x == llama_cpp_bindings_sys::GGML_TYPE_IQ2_XXS => Self::IQ2_XXS,
234 x if x == llama_cpp_bindings_sys::GGML_TYPE_IQ2_XS => Self::IQ2_XS,
235 x if x == llama_cpp_bindings_sys::GGML_TYPE_IQ3_XXS => Self::IQ3_XXS,
236 x if x == llama_cpp_bindings_sys::GGML_TYPE_IQ1_S => Self::IQ1_S,
237 x if x == llama_cpp_bindings_sys::GGML_TYPE_IQ4_NL => Self::IQ4_NL,
238 x if x == llama_cpp_bindings_sys::GGML_TYPE_IQ3_S => Self::IQ3_S,
239 x if x == llama_cpp_bindings_sys::GGML_TYPE_IQ2_S => Self::IQ2_S,
240 x if x == llama_cpp_bindings_sys::GGML_TYPE_IQ4_XS => Self::IQ4_XS,
241 x if x == llama_cpp_bindings_sys::GGML_TYPE_I8 => Self::I8,
242 x if x == llama_cpp_bindings_sys::GGML_TYPE_I16 => Self::I16,
243 x if x == llama_cpp_bindings_sys::GGML_TYPE_I32 => Self::I32,
244 x if x == llama_cpp_bindings_sys::GGML_TYPE_I64 => Self::I64,
245 x if x == llama_cpp_bindings_sys::GGML_TYPE_F64 => Self::F64,
246 x if x == llama_cpp_bindings_sys::GGML_TYPE_IQ1_M => Self::IQ1_M,
247 x if x == llama_cpp_bindings_sys::GGML_TYPE_BF16 => Self::BF16,
248 x if x == llama_cpp_bindings_sys::GGML_TYPE_TQ1_0 => Self::TQ1_0,
249 x if x == llama_cpp_bindings_sys::GGML_TYPE_TQ2_0 => Self::TQ2_0,
250 x if x == llama_cpp_bindings_sys::GGML_TYPE_MXFP4 => Self::MXFP4,
251 _ => Self::Unknown(value),
252 }
253 }
254}
255
/// Owned wrapper around llama.cpp's `llama_context_params` C struct.
///
/// The inherent methods on this type provide chainable `with_*` setters and matching getters
/// for individual fields, and the `Default` impl starts from llama.cpp's own defaults. The
/// meaning of each underlying field is defined by `struct llama_context_params` in llama.cpp's
/// `llama.h`, which is the canonical reference.
#[derive(Debug, Clone)]
#[expect(
    missing_docs,
    reason = "field meanings mirror llama.cpp's `llama_context_params` C struct; restating each \
              one inline would risk drift from the upstream spec — the doc-comment on the struct \
              points at the canonical reference"
)]
#[expect(
    clippy::module_name_repetitions,
    reason = "`LlamaContextParams` is the canonical Rust name in the public API; renaming it to \
              `Params` would force `params::Params` at every call site"
)]
pub struct LlamaContextParams {
    // The raw C struct, exposed directly. Deliberately left without a `///` doc comment so the
    // `missing_docs` expectation above stays fulfilled.
    pub context_params: llama_cpp_bindings_sys::llama_context_params,
}
286
// SAFETY: NOTE(review) — the wrapped `llama_context_params` holds raw pointers (for example
// `cb_eval_user_data`, a `*mut c_void`), so `Send` is not derived automatically and is asserted
// by hand here. Nothing in this file dereferences those pointers; soundness presumably relies
// on callers only installing callbacks / user-data pointers that may be used from another
// thread — confirm.
unsafe impl Send for LlamaContextParams {}
// SAFETY: NOTE(review) — same reasoning as for `Send`: shared references only allow the plain
// field reads performed by the getters; whether the pointees may be shared across threads is
// the responsibility of whoever installed the pointers — confirm.
unsafe impl Sync for LlamaContextParams {}
290
291impl LlamaContextParams {
292 #[must_use]
304 pub fn with_n_ctx(mut self, n_ctx: Option<NonZeroU32>) -> Self {
305 self.context_params.n_ctx = n_ctx.map_or(0, NonZeroU32::get);
306 self
307 }
308
309 #[must_use]
319 pub const fn n_ctx(&self) -> Option<NonZeroU32> {
320 NonZeroU32::new(self.context_params.n_ctx)
321 }
322
323 #[must_use]
335 pub const fn with_n_batch(mut self, n_batch: u32) -> Self {
336 self.context_params.n_batch = n_batch;
337 self
338 }
339
340 #[must_use]
350 pub const fn n_batch(&self) -> u32 {
351 self.context_params.n_batch
352 }
353
354 #[must_use]
366 pub const fn with_n_ubatch(mut self, n_ubatch: u32) -> Self {
367 self.context_params.n_ubatch = n_ubatch;
368 self
369 }
370
371 #[must_use]
381 pub const fn n_ubatch(&self) -> u32 {
382 self.context_params.n_ubatch
383 }
384
385 #[must_use]
387 pub const fn with_flash_attention_policy(
388 mut self,
389 policy: llama_cpp_bindings_sys::llama_flash_attn_type,
390 ) -> Self {
391 self.context_params.flash_attn_type = policy;
392 self
393 }
394
395 #[must_use]
397 pub const fn flash_attention_policy(&self) -> llama_cpp_bindings_sys::llama_flash_attn_type {
398 self.context_params.flash_attn_type
399 }
400
401 #[must_use]
412 pub const fn with_offload_kqv(mut self, enabled: bool) -> Self {
413 self.context_params.offload_kqv = enabled;
414 self
415 }
416
417 #[must_use]
427 pub const fn offload_kqv(&self) -> bool {
428 self.context_params.offload_kqv
429 }
430
431 #[must_use]
442 pub fn with_rope_scaling_type(mut self, rope_scaling_type: RopeScalingType) -> Self {
443 self.context_params.rope_scaling_type = i32::from(rope_scaling_type);
444 self
445 }
446
447 #[must_use]
456 pub fn rope_scaling_type(&self) -> RopeScalingType {
457 RopeScalingType::from(self.context_params.rope_scaling_type)
458 }
459
460 #[must_use]
471 pub const fn with_rope_freq_base(mut self, rope_freq_base: f32) -> Self {
472 self.context_params.rope_freq_base = rope_freq_base;
473 self
474 }
475
476 #[must_use]
485 pub const fn rope_freq_base(&self) -> f32 {
486 self.context_params.rope_freq_base
487 }
488
489 #[must_use]
500 pub const fn with_rope_freq_scale(mut self, rope_freq_scale: f32) -> Self {
501 self.context_params.rope_freq_scale = rope_freq_scale;
502 self
503 }
504
505 #[must_use]
514 pub const fn rope_freq_scale(&self) -> f32 {
515 self.context_params.rope_freq_scale
516 }
517
518 #[must_use]
527 pub const fn n_threads(&self) -> i32 {
528 self.context_params.n_threads
529 }
530
531 #[must_use]
540 pub const fn n_threads_batch(&self) -> i32 {
541 self.context_params.n_threads_batch
542 }
543
544 #[must_use]
555 pub const fn with_n_threads(mut self, n_threads: i32) -> Self {
556 self.context_params.n_threads = n_threads;
557 self
558 }
559
560 #[must_use]
571 pub const fn with_n_threads_batch(mut self, n_threads: i32) -> Self {
572 self.context_params.n_threads_batch = n_threads;
573 self
574 }
575
576 #[must_use]
585 pub const fn embeddings(&self) -> bool {
586 self.context_params.embeddings
587 }
588
589 #[must_use]
600 pub const fn with_embeddings(mut self, embedding: bool) -> Self {
601 self.context_params.embeddings = embedding;
602 self
603 }
604
605 #[must_use]
622 pub fn with_cb_eval(
623 mut self,
624 cb_eval: llama_cpp_bindings_sys::ggml_backend_sched_eval_callback,
625 ) -> Self {
626 self.context_params.cb_eval = cb_eval;
627 self
628 }
629
630 #[must_use]
641 pub const fn with_cb_eval_user_data(
642 mut self,
643 cb_eval_user_data: *mut std::ffi::c_void,
644 ) -> Self {
645 self.context_params.cb_eval_user_data = cb_eval_user_data;
646 self
647 }
648
649 #[must_use]
660 pub fn with_pooling_type(mut self, pooling_type: LlamaPoolingType) -> Self {
661 self.context_params.pooling_type = i32::from(pooling_type);
662 self
663 }
664
665 #[must_use]
674 pub fn pooling_type(&self) -> LlamaPoolingType {
675 LlamaPoolingType::from(self.context_params.pooling_type)
676 }
677
678 #[must_use]
689 pub const fn with_swa_full(mut self, enabled: bool) -> Self {
690 self.context_params.swa_full = enabled;
691 self
692 }
693
694 #[must_use]
704 pub const fn swa_full(&self) -> bool {
705 self.context_params.swa_full
706 }
707
708 #[must_use]
719 pub const fn with_n_seq_max(mut self, n_seq_max: u32) -> Self {
720 self.context_params.n_seq_max = n_seq_max;
721 self
722 }
723
724 #[must_use]
734 pub const fn n_seq_max(&self) -> u32 {
735 self.context_params.n_seq_max
736 }
737 #[must_use]
743 pub fn with_type_k(mut self, type_k: KvCacheType) -> Self {
744 self.context_params.type_k = type_k.into();
745 self
746 }
747
748 #[must_use]
757 pub fn type_k(&self) -> KvCacheType {
758 KvCacheType::from(self.context_params.type_k)
759 }
760
761 #[must_use]
771 pub fn with_type_v(mut self, type_v: KvCacheType) -> Self {
772 self.context_params.type_v = type_v.into();
773 self
774 }
775
776 #[must_use]
785 pub fn type_v(&self) -> KvCacheType {
786 KvCacheType::from(self.context_params.type_v)
787 }
788
789 #[must_use]
800 pub fn with_attention_type(mut self, attention_type: LlamaAttentionType) -> Self {
801 self.context_params.attention_type = i32::from(attention_type);
802 self
803 }
804
805 #[must_use]
814 pub fn attention_type(&self) -> LlamaAttentionType {
815 LlamaAttentionType::from(self.context_params.attention_type)
816 }
817
818 #[must_use]
829 pub const fn with_yarn_ext_factor(mut self, yarn_ext_factor: f32) -> Self {
830 self.context_params.yarn_ext_factor = yarn_ext_factor;
831 self
832 }
833
834 #[must_use]
836 pub const fn yarn_ext_factor(&self) -> f32 {
837 self.context_params.yarn_ext_factor
838 }
839
840 #[must_use]
851 pub const fn with_yarn_attn_factor(mut self, yarn_attn_factor: f32) -> Self {
852 self.context_params.yarn_attn_factor = yarn_attn_factor;
853 self
854 }
855
856 #[must_use]
858 pub const fn yarn_attn_factor(&self) -> f32 {
859 self.context_params.yarn_attn_factor
860 }
861
862 #[must_use]
873 pub const fn with_yarn_beta_fast(mut self, yarn_beta_fast: f32) -> Self {
874 self.context_params.yarn_beta_fast = yarn_beta_fast;
875 self
876 }
877
878 #[must_use]
880 pub const fn yarn_beta_fast(&self) -> f32 {
881 self.context_params.yarn_beta_fast
882 }
883
884 #[must_use]
895 pub const fn with_yarn_beta_slow(mut self, yarn_beta_slow: f32) -> Self {
896 self.context_params.yarn_beta_slow = yarn_beta_slow;
897 self
898 }
899
900 #[must_use]
902 pub const fn yarn_beta_slow(&self) -> f32 {
903 self.context_params.yarn_beta_slow
904 }
905
906 #[must_use]
917 pub const fn with_yarn_orig_ctx(mut self, yarn_orig_ctx: u32) -> Self {
918 self.context_params.yarn_orig_ctx = yarn_orig_ctx;
919 self
920 }
921
922 #[must_use]
924 pub const fn yarn_orig_ctx(&self) -> u32 {
925 self.context_params.yarn_orig_ctx
926 }
927
928 #[must_use]
939 pub const fn with_defrag_thold(mut self, defrag_thold: f32) -> Self {
940 self.context_params.defrag_thold = defrag_thold;
941 self
942 }
943
944 #[must_use]
946 pub const fn defrag_thold(&self) -> f32 {
947 self.context_params.defrag_thold
948 }
949
950 #[must_use]
961 pub const fn with_no_perf(mut self, no_perf: bool) -> Self {
962 self.context_params.no_perf = no_perf;
963 self
964 }
965
966 #[must_use]
968 pub const fn no_perf(&self) -> bool {
969 self.context_params.no_perf
970 }
971
972 #[must_use]
983 pub const fn with_op_offload(mut self, op_offload: bool) -> Self {
984 self.context_params.op_offload = op_offload;
985 self
986 }
987
988 #[must_use]
990 pub const fn op_offload(&self) -> bool {
991 self.context_params.op_offload
992 }
993
994 #[must_use]
1005 pub const fn with_kv_unified(mut self, kv_unified: bool) -> Self {
1006 self.context_params.kv_unified = kv_unified;
1007 self
1008 }
1009
1010 #[must_use]
1012 pub const fn kv_unified(&self) -> bool {
1013 self.context_params.kv_unified
1014 }
1015}
1016
1017impl Default for LlamaContextParams {
1026 fn default() -> Self {
1027 let context_params = unsafe { llama_cpp_bindings_sys::llama_context_default_params() };
1028 Self { context_params }
1029 }
1030}
1031
#[cfg(test)]
mod tests {
    // Unit tests for the enum conversions and for every setter/getter pair on
    // `LlamaContextParams`. Each builder test starts from `LlamaContextParams::default()`.
    use std::num::NonZeroU32;

    use super::{
        KvCacheType, LlamaAttentionType, LlamaContextParams, LlamaPoolingType, RopeScalingType,
    };

    #[test]
    fn rope_scaling_type_unknown_defaults_to_unspecified() {
        assert_eq!(RopeScalingType::from(99), RopeScalingType::Unspecified);
        assert_eq!(RopeScalingType::from(-100), RopeScalingType::Unspecified);
    }

    #[test]
    fn pooling_type_unknown_defaults_to_unspecified() {
        assert_eq!(LlamaPoolingType::from(99), LlamaPoolingType::Unspecified);
        assert_eq!(LlamaPoolingType::from(-50), LlamaPoolingType::Unspecified);
    }

    #[test]
    fn kv_cache_type_unknown_preserves_raw_value() {
        let raw: llama_cpp_bindings_sys::ggml_type = 99999;

        let wrapped = KvCacheType::from(raw);
        assert_eq!(wrapped, KvCacheType::Unknown(99999));

        let unwrapped: llama_cpp_bindings_sys::ggml_type = wrapped.into();
        assert_eq!(unwrapped, 99999);
    }

    #[test]
    fn default_params_have_expected_values() {
        let defaults = LlamaContextParams::default();

        assert_eq!(defaults.n_ctx(), NonZeroU32::new(512));
        assert_eq!(defaults.n_batch(), 2048);
        assert_eq!(defaults.n_ubatch(), 512);
        assert_eq!(defaults.rope_scaling_type(), RopeScalingType::Unspecified);
        assert_eq!(defaults.pooling_type(), LlamaPoolingType::Unspecified);
    }

    #[test]
    fn with_n_ctx_sets_value() {
        let configured = LlamaContextParams::default().with_n_ctx(NonZeroU32::new(2048));

        assert_eq!(configured.n_ctx(), NonZeroU32::new(2048));
    }

    #[test]
    fn with_n_ctx_none_sets_zero() {
        let configured = LlamaContextParams::default().with_n_ctx(None);

        assert_eq!(configured.n_ctx(), None);
    }

    #[test]
    fn with_n_batch_sets_value() {
        let configured = LlamaContextParams::default().with_n_batch(4096);

        assert_eq!(configured.n_batch(), 4096);
    }

    #[test]
    fn with_n_ubatch_sets_value() {
        let configured = LlamaContextParams::default().with_n_ubatch(1024);

        assert_eq!(configured.n_ubatch(), 1024);
    }

    #[test]
    fn with_n_seq_max_sets_value() {
        let configured = LlamaContextParams::default().with_n_seq_max(64);

        assert_eq!(configured.n_seq_max(), 64);
    }

    #[test]
    fn with_embeddings_enables() {
        let configured = LlamaContextParams::default().with_embeddings(true);

        assert!(configured.embeddings());
    }

    #[test]
    fn with_embeddings_disables() {
        let configured = LlamaContextParams::default().with_embeddings(false);

        assert!(!configured.embeddings());
    }

    #[test]
    fn with_offload_kqv_disables() {
        let configured = LlamaContextParams::default().with_offload_kqv(false);

        assert!(!configured.offload_kqv());
    }

    #[test]
    fn with_offload_kqv_enables() {
        let configured = LlamaContextParams::default().with_offload_kqv(true);

        assert!(configured.offload_kqv());
    }

    #[test]
    fn with_swa_full_disables() {
        let configured = LlamaContextParams::default().with_swa_full(false);

        assert!(!configured.swa_full());
    }

    #[test]
    fn with_swa_full_enables() {
        let configured = LlamaContextParams::default().with_swa_full(true);

        assert!(configured.swa_full());
    }

    #[test]
    fn with_rope_scaling_type_linear() {
        let configured =
            LlamaContextParams::default().with_rope_scaling_type(RopeScalingType::Linear);

        assert_eq!(configured.rope_scaling_type(), RopeScalingType::Linear);
    }

    #[test]
    fn with_rope_scaling_type_yarn() {
        let configured =
            LlamaContextParams::default().with_rope_scaling_type(RopeScalingType::Yarn);

        assert_eq!(configured.rope_scaling_type(), RopeScalingType::Yarn);
    }

    #[test]
    fn with_rope_scaling_type_none() {
        let configured =
            LlamaContextParams::default().with_rope_scaling_type(RopeScalingType::None);

        assert_eq!(configured.rope_scaling_type(), RopeScalingType::None);
    }

    #[test]
    fn with_rope_freq_base_sets_value() {
        let configured = LlamaContextParams::default().with_rope_freq_base(10000.0);

        assert!((configured.rope_freq_base() - 10000.0).abs() < f32::EPSILON);
    }

    #[test]
    fn with_rope_freq_scale_sets_value() {
        let configured = LlamaContextParams::default().with_rope_freq_scale(0.5);

        assert!((configured.rope_freq_scale() - 0.5).abs() < f32::EPSILON);
    }

    #[test]
    fn with_n_threads_sets_value() {
        let configured = LlamaContextParams::default().with_n_threads(16);

        assert_eq!(configured.n_threads(), 16);
    }

    #[test]
    fn with_n_threads_batch_sets_value() {
        let configured = LlamaContextParams::default().with_n_threads_batch(16);

        assert_eq!(configured.n_threads_batch(), 16);
    }

    #[test]
    fn with_pooling_type_mean() {
        let configured = LlamaContextParams::default().with_pooling_type(LlamaPoolingType::Mean);

        assert_eq!(configured.pooling_type(), LlamaPoolingType::Mean);
    }

    #[test]
    fn with_pooling_type_cls() {
        let configured = LlamaContextParams::default().with_pooling_type(LlamaPoolingType::Cls);

        assert_eq!(configured.pooling_type(), LlamaPoolingType::Cls);
    }

    #[test]
    fn with_pooling_type_last() {
        let configured = LlamaContextParams::default().with_pooling_type(LlamaPoolingType::Last);

        assert_eq!(configured.pooling_type(), LlamaPoolingType::Last);
    }

    #[test]
    fn with_pooling_type_rank() {
        let configured = LlamaContextParams::default().with_pooling_type(LlamaPoolingType::Rank);

        assert_eq!(configured.pooling_type(), LlamaPoolingType::Rank);
    }

    #[test]
    fn with_pooling_type_none() {
        let configured = LlamaContextParams::default().with_pooling_type(LlamaPoolingType::None);

        assert_eq!(configured.pooling_type(), LlamaPoolingType::None);
    }

    #[test]
    fn with_type_k_sets_value() {
        let configured = LlamaContextParams::default().with_type_k(KvCacheType::Q4_0);

        assert_eq!(configured.type_k(), KvCacheType::Q4_0);
    }

    #[test]
    fn with_type_v_sets_value() {
        let configured = LlamaContextParams::default().with_type_v(KvCacheType::Q4_1);

        assert_eq!(configured.type_v(), KvCacheType::Q4_1);
    }

    #[test]
    fn with_flash_attention_policy_sets_value() {
        let configured = LlamaContextParams::default()
            .with_flash_attention_policy(llama_cpp_bindings_sys::LLAMA_FLASH_ATTN_TYPE_ENABLED);

        assert_eq!(
            configured.flash_attention_policy(),
            llama_cpp_bindings_sys::LLAMA_FLASH_ATTN_TYPE_ENABLED
        );
    }

    #[test]
    fn builder_chaining_preserves_all_values() {
        let configured = LlamaContextParams::default()
            .with_n_ctx(NonZeroU32::new(1024))
            .with_n_batch(4096)
            .with_n_ubatch(256)
            .with_n_threads(8)
            .with_n_threads_batch(12)
            .with_embeddings(true)
            .with_offload_kqv(false)
            .with_rope_scaling_type(RopeScalingType::Yarn)
            .with_rope_freq_base(5000.0)
            .with_rope_freq_scale(0.25);

        assert_eq!(configured.n_ctx(), NonZeroU32::new(1024));
        assert_eq!(configured.n_batch(), 4096);
        assert_eq!(configured.n_ubatch(), 256);
        assert_eq!(configured.n_threads(), 8);
        assert_eq!(configured.n_threads_batch(), 12);
        assert!(configured.embeddings());
        assert!(!configured.offload_kqv());
        assert_eq!(configured.rope_scaling_type(), RopeScalingType::Yarn);
        assert!((configured.rope_freq_base() - 5000.0).abs() < f32::EPSILON);
        assert!((configured.rope_freq_scale() - 0.25).abs() < f32::EPSILON);
    }

    #[test]
    fn rope_scaling_type_roundtrip_all_variants() {
        let table = [
            (-1, RopeScalingType::Unspecified),
            (0, RopeScalingType::None),
            (1, RopeScalingType::Linear),
            (2, RopeScalingType::Yarn),
        ];

        for (raw_value, variant) in table {
            let decoded = RopeScalingType::from(raw_value);
            assert_eq!(decoded, variant);
            assert_eq!(i32::from(decoded), raw_value);
        }
    }

    #[test]
    fn pooling_type_roundtrip_all_variants() {
        let table = [
            (-1, LlamaPoolingType::Unspecified),
            (0, LlamaPoolingType::None),
            (1, LlamaPoolingType::Mean),
            (2, LlamaPoolingType::Cls),
            (3, LlamaPoolingType::Last),
            (4, LlamaPoolingType::Rank),
        ];

        for (raw_value, variant) in table {
            let decoded = LlamaPoolingType::from(raw_value);
            assert_eq!(decoded, variant);
            assert_eq!(i32::from(decoded), raw_value);
        }
    }

    #[test]
    fn kv_cache_type_all_known_variants_roundtrip() {
        let known = [
            KvCacheType::F32,
            KvCacheType::F16,
            KvCacheType::Q4_0,
            KvCacheType::Q4_1,
            KvCacheType::Q5_0,
            KvCacheType::Q5_1,
            KvCacheType::Q8_0,
            KvCacheType::Q8_1,
            KvCacheType::Q2_K,
            KvCacheType::Q3_K,
            KvCacheType::Q4_K,
            KvCacheType::Q5_K,
            KvCacheType::Q6_K,
            KvCacheType::Q8_K,
            KvCacheType::IQ2_XXS,
            KvCacheType::IQ2_XS,
            KvCacheType::IQ3_XXS,
            KvCacheType::IQ1_S,
            KvCacheType::IQ4_NL,
            KvCacheType::IQ3_S,
            KvCacheType::IQ2_S,
            KvCacheType::IQ4_XS,
            KvCacheType::I8,
            KvCacheType::I16,
            KvCacheType::I32,
            KvCacheType::I64,
            KvCacheType::F64,
            KvCacheType::IQ1_M,
            KvCacheType::BF16,
            KvCacheType::TQ1_0,
            KvCacheType::TQ2_0,
            KvCacheType::MXFP4,
        ];

        for kind in known {
            let lowered: llama_cpp_bindings_sys::ggml_type = kind.into();

            assert_eq!(KvCacheType::from(lowered), kind);
        }
    }

    #[test]
    fn with_cb_eval_sets_callback() {
        extern "C" fn test_cb_eval(
            _tensor: *mut llama_cpp_bindings_sys::ggml_tensor,
            _ask: bool,
            _user_data: *mut std::ffi::c_void,
        ) -> bool {
            false
        }

        // Call the callback once directly so its body is exercised as well.
        let verdict = test_cb_eval(std::ptr::null_mut(), false, std::ptr::null_mut());
        assert!(!verdict);

        let configured = LlamaContextParams::default().with_cb_eval(Some(test_cb_eval));
        assert!(configured.context_params.cb_eval.is_some());
    }

    #[test]
    fn with_cb_eval_user_data_sets_pointer() {
        let mut payload: i32 = 42;
        let payload_ptr = (&raw mut payload).cast::<std::ffi::c_void>();

        let configured = LlamaContextParams::default().with_cb_eval_user_data(payload_ptr);

        assert_eq!(configured.context_params.cb_eval_user_data, payload_ptr);
    }

    #[test]
    fn with_flash_attention_policy_disabled() {
        let configured = LlamaContextParams::default()
            .with_flash_attention_policy(llama_cpp_bindings_sys::LLAMA_FLASH_ATTN_TYPE_DISABLED);

        assert_eq!(
            configured.flash_attention_policy(),
            llama_cpp_bindings_sys::LLAMA_FLASH_ATTN_TYPE_DISABLED
        );
    }

    #[test]
    fn attention_type_unknown_defaults_to_unspecified() {
        assert_eq!(
            LlamaAttentionType::from(99),
            LlamaAttentionType::Unspecified
        );
        assert_eq!(
            LlamaAttentionType::from(-50),
            LlamaAttentionType::Unspecified
        );
    }

    #[test]
    fn attention_type_roundtrip_all_variants() {
        let table = [
            (-1, LlamaAttentionType::Unspecified),
            (0, LlamaAttentionType::Causal),
            (1, LlamaAttentionType::NonCausal),
        ];

        for (raw_value, variant) in table {
            let decoded = LlamaAttentionType::from(raw_value);
            assert_eq!(decoded, variant);
            assert_eq!(i32::from(decoded), raw_value);
        }
    }

    #[test]
    fn with_attention_type_causal() {
        let configured =
            LlamaContextParams::default().with_attention_type(LlamaAttentionType::Causal);

        assert_eq!(configured.attention_type(), LlamaAttentionType::Causal);
    }

    #[test]
    fn with_attention_type_non_causal() {
        let configured =
            LlamaContextParams::default().with_attention_type(LlamaAttentionType::NonCausal);

        assert_eq!(configured.attention_type(), LlamaAttentionType::NonCausal);
    }

    #[test]
    fn with_yarn_ext_factor_sets_value() {
        let configured = LlamaContextParams::default().with_yarn_ext_factor(1.5);

        assert!((configured.yarn_ext_factor() - 1.5).abs() < f32::EPSILON);
    }

    #[test]
    fn with_yarn_attn_factor_sets_value() {
        let configured = LlamaContextParams::default().with_yarn_attn_factor(2.0);

        assert!((configured.yarn_attn_factor() - 2.0).abs() < f32::EPSILON);
    }

    #[test]
    fn with_yarn_beta_fast_sets_value() {
        let configured = LlamaContextParams::default().with_yarn_beta_fast(32.0);

        assert!((configured.yarn_beta_fast() - 32.0).abs() < f32::EPSILON);
    }

    #[test]
    fn with_yarn_beta_slow_sets_value() {
        let configured = LlamaContextParams::default().with_yarn_beta_slow(1.0);

        assert!((configured.yarn_beta_slow() - 1.0).abs() < f32::EPSILON);
    }

    #[test]
    fn with_yarn_orig_ctx_sets_value() {
        let configured = LlamaContextParams::default().with_yarn_orig_ctx(4096);

        assert_eq!(configured.yarn_orig_ctx(), 4096);
    }

    #[test]
    fn with_defrag_thold_sets_value() {
        let configured = LlamaContextParams::default().with_defrag_thold(0.1);

        assert!((configured.defrag_thold() - 0.1).abs() < f32::EPSILON);
    }

    #[test]
    fn with_no_perf_enables() {
        let configured = LlamaContextParams::default().with_no_perf(true);

        assert!(configured.no_perf());
    }

    #[test]
    fn with_no_perf_disables() {
        let configured = LlamaContextParams::default().with_no_perf(false);

        assert!(!configured.no_perf());
    }

    #[test]
    fn with_op_offload_enables() {
        let configured = LlamaContextParams::default().with_op_offload(true);

        assert!(configured.op_offload());
    }

    #[test]
    fn with_op_offload_disables() {
        let configured = LlamaContextParams::default().with_op_offload(false);

        assert!(!configured.op_offload());
    }

    #[test]
    fn with_kv_unified_enables() {
        let configured = LlamaContextParams::default().with_kv_unified(true);

        assert!(configured.kv_unified());
    }

    #[test]
    fn with_kv_unified_disables() {
        let configured = LlamaContextParams::default().with_kv_unified(false);

        assert!(!configured.kv_unified());
    }
}