// llama_cpp_2/context/params/get_set.rs
1use std::num::NonZeroU32;
2
3use super::{
4 KvCacheType, LlamaAttentionType, LlamaContextParams, LlamaPoolingType, RopeScalingType,
5};
6
7impl LlamaContextParams {
8 /// Set the size of the context
9 ///
10 /// # Examples
11 ///
12 /// ```rust
13 /// # use std::num::NonZeroU32;
14 /// # use llama_cpp_2::context::params::LlamaContextParams;
15 /// let params = LlamaContextParams::default();
16 /// let params = params.with_n_ctx(NonZeroU32::new(2048));
17 /// assert_eq!(params.n_ctx(), NonZeroU32::new(2048));
18 /// ```
19 #[must_use]
20 pub fn with_n_ctx(mut self, n_ctx: Option<NonZeroU32>) -> Self {
21 self.context_params.n_ctx = n_ctx.map_or(0, std::num::NonZeroU32::get);
22 self
23 }
24
25 /// Get the size of the context.
26 ///
27 /// [`None`] if the context size is specified by the model and not the context.
28 ///
29 /// # Examples
30 ///
31 /// ```rust
32 /// # use llama_cpp_2::context::params::LlamaContextParams;
33 /// let params = LlamaContextParams::default();
34 /// assert_eq!(params.n_ctx(), std::num::NonZeroU32::new(512));
35 /// ```
36 #[must_use]
37 pub fn n_ctx(&self) -> Option<NonZeroU32> {
38 NonZeroU32::new(self.context_params.n_ctx)
39 }
40
41 /// Set the `n_batch`
42 ///
43 /// # Examples
44 ///
45 /// ```rust
46 /// # use llama_cpp_2::context::params::LlamaContextParams;
47 /// let params = LlamaContextParams::default()
48 /// .with_n_batch(2048);
49 /// assert_eq!(params.n_batch(), 2048);
50 /// ```
51 #[must_use]
52 pub fn with_n_batch(mut self, n_batch: u32) -> Self {
53 self.context_params.n_batch = n_batch;
54 self
55 }
56
57 /// Get the `n_batch`
58 ///
59 /// # Examples
60 ///
61 /// ```rust
62 /// # use llama_cpp_2::context::params::LlamaContextParams;
63 /// let params = LlamaContextParams::default();
64 /// assert_eq!(params.n_batch(), 2048);
65 /// ```
66 #[must_use]
67 pub fn n_batch(&self) -> u32 {
68 self.context_params.n_batch
69 }
70
71 /// Set the `n_ubatch`
72 ///
73 /// # Examples
74 ///
75 /// ```rust
76 /// # use llama_cpp_2::context::params::LlamaContextParams;
77 /// let params = LlamaContextParams::default()
78 /// .with_n_ubatch(512);
79 /// assert_eq!(params.n_ubatch(), 512);
80 /// ```
81 #[must_use]
82 pub fn with_n_ubatch(mut self, n_ubatch: u32) -> Self {
83 self.context_params.n_ubatch = n_ubatch;
84 self
85 }
86
87 /// Get the `n_ubatch`
88 ///
89 /// # Examples
90 ///
91 /// ```rust
92 /// # use llama_cpp_2::context::params::LlamaContextParams;
93 /// let params = LlamaContextParams::default();
94 /// assert_eq!(params.n_ubatch(), 512);
95 /// ```
96 #[must_use]
97 pub fn n_ubatch(&self) -> u32 {
98 self.context_params.n_ubatch
99 }
100
101 /// Set the max number of sequences (i.e. distinct states for recurrent models)
102 ///
103 /// # Examples
104 ///
105 /// ```rust
106 /// # use llama_cpp_2::context::params::LlamaContextParams;
107 /// let params = LlamaContextParams::default()
108 /// .with_n_seq_max(64);
109 /// assert_eq!(params.n_seq_max(), 64);
110 /// ```
111 #[must_use]
112 pub fn with_n_seq_max(mut self, n_seq_max: u32) -> Self {
113 self.context_params.n_seq_max = n_seq_max;
114 self
115 }
116
117 /// Get the max number of sequences (i.e. distinct states for recurrent models)
118 ///
119 /// # Examples
120 ///
121 /// ```rust
122 /// # use llama_cpp_2::context::params::LlamaContextParams;
123 /// let params = LlamaContextParams::default();
124 /// assert_eq!(params.n_seq_max(), 1);
125 /// ```
126 #[must_use]
127 pub fn n_seq_max(&self) -> u32 {
128 self.context_params.n_seq_max
129 }
130
131 /// Set the number of threads
132 ///
133 /// # Examples
134 ///
135 /// ```rust
136 /// # use llama_cpp_2::context::params::LlamaContextParams;
137 /// let params = LlamaContextParams::default()
138 /// .with_n_threads(8);
139 /// assert_eq!(params.n_threads(), 8);
140 /// ```
141 #[must_use]
142 pub fn with_n_threads(mut self, n_threads: i32) -> Self {
143 self.context_params.n_threads = n_threads;
144 self
145 }
146
147 /// Get the number of threads
148 ///
149 /// # Examples
150 ///
151 /// ```rust
152 /// # use llama_cpp_2::context::params::LlamaContextParams;
153 /// let params = LlamaContextParams::default();
154 /// assert_eq!(params.n_threads(), 4);
155 /// ```
156 #[must_use]
157 pub fn n_threads(&self) -> i32 {
158 self.context_params.n_threads
159 }
160
161 /// Set the number of threads allocated for batches
162 ///
163 /// # Examples
164 ///
165 /// ```rust
166 /// # use llama_cpp_2::context::params::LlamaContextParams;
167 /// let params = LlamaContextParams::default()
168 /// .with_n_threads_batch(8);
169 /// assert_eq!(params.n_threads_batch(), 8);
170 /// ```
171 #[must_use]
172 pub fn with_n_threads_batch(mut self, n_threads: i32) -> Self {
173 self.context_params.n_threads_batch = n_threads;
174 self
175 }
176
177 /// Get the number of threads allocated for batches
178 ///
179 /// # Examples
180 ///
181 /// ```rust
182 /// # use llama_cpp_2::context::params::LlamaContextParams;
183 /// let params = LlamaContextParams::default();
184 /// assert_eq!(params.n_threads_batch(), 4);
185 /// ```
186 #[must_use]
187 pub fn n_threads_batch(&self) -> i32 {
188 self.context_params.n_threads_batch
189 }
190
191 /// Set the type of rope scaling
192 ///
193 /// # Examples
194 ///
195 /// ```rust
196 /// # use llama_cpp_2::context::params::{LlamaContextParams, RopeScalingType};
197 /// let params = LlamaContextParams::default()
198 /// .with_rope_scaling_type(RopeScalingType::Linear);
199 /// assert_eq!(params.rope_scaling_type(), RopeScalingType::Linear);
200 /// ```
201 #[must_use]
202 pub fn with_rope_scaling_type(mut self, rope_scaling_type: RopeScalingType) -> Self {
203 self.context_params.rope_scaling_type = i32::from(rope_scaling_type);
204 self
205 }
206
207 /// Get the type of rope scaling
208 ///
209 /// # Examples
210 ///
211 /// ```rust
212 /// # use llama_cpp_2::context::params::{LlamaContextParams, RopeScalingType};
213 /// let params = LlamaContextParams::default();
214 /// assert_eq!(params.rope_scaling_type(), RopeScalingType::Unspecified);
215 /// ```
216 #[must_use]
217 pub fn rope_scaling_type(&self) -> RopeScalingType {
218 RopeScalingType::from(self.context_params.rope_scaling_type)
219 }
220
221 /// Set the type of pooling
222 ///
223 /// # Examples
224 ///
225 /// ```rust
226 /// # use llama_cpp_2::context::params::{LlamaContextParams, LlamaPoolingType};
227 /// let params = LlamaContextParams::default()
228 /// .with_pooling_type(LlamaPoolingType::Last);
229 /// assert_eq!(params.pooling_type(), LlamaPoolingType::Last);
230 /// ```
231 #[must_use]
232 pub fn with_pooling_type(mut self, pooling_type: LlamaPoolingType) -> Self {
233 self.context_params.pooling_type = i32::from(pooling_type);
234 self
235 }
236
237 /// Get the type of pooling
238 ///
239 /// # Examples
240 ///
241 /// ```rust
242 /// # use llama_cpp_2::context::params::{LlamaContextParams, LlamaPoolingType};
243 /// let params = LlamaContextParams::default();
244 /// assert_eq!(params.pooling_type(), LlamaPoolingType::Unspecified);
245 /// ```
246 #[must_use]
247 pub fn pooling_type(&self) -> LlamaPoolingType {
248 LlamaPoolingType::from(self.context_params.pooling_type)
249 }
250
251 /// Set the attention type for embeddings
252 ///
253 /// # Examples
254 ///
255 /// ```rust
256 /// # use llama_cpp_2::context::params::{LlamaContextParams, LlamaAttentionType};
257 /// let params = LlamaContextParams::default()
258 /// .with_attention_type(LlamaAttentionType::Causal);
259 /// assert_eq!(params.attention_type(), LlamaAttentionType::Causal);
260 /// ```
261 #[must_use]
262 pub fn with_attention_type(mut self, attention_type: LlamaAttentionType) -> Self {
263 self.context_params.attention_type = i32::from(attention_type);
264 self
265 }
266
267 /// Get the attention type for embeddings
268 ///
269 /// # Examples
270 ///
271 /// ```rust
272 /// # use llama_cpp_2::context::params::{LlamaContextParams, LlamaAttentionType};
273 /// let params = LlamaContextParams::default();
274 /// assert_eq!(params.attention_type(), LlamaAttentionType::Unspecified);
275 /// ```
276 #[must_use]
277 pub fn attention_type(&self) -> LlamaAttentionType {
278 LlamaAttentionType::from(self.context_params.attention_type)
279 }
280
281 /// Set the flash attention policy using llama.cpp enum
282 #[must_use]
283 pub fn with_flash_attention_policy(
284 mut self,
285 policy: llama_cpp_sys_2::llama_flash_attn_type,
286 ) -> Self {
287 self.context_params.flash_attn_type = policy;
288 self
289 }
290
291 /// Get the flash attention policy
292 #[must_use]
293 pub fn flash_attention_policy(&self) -> llama_cpp_sys_2::llama_flash_attn_type {
294 self.context_params.flash_attn_type
295 }
296
297 /// Set the rope frequency base
298 ///
299 /// # Examples
300 ///
301 /// ```rust
302 /// # use llama_cpp_2::context::params::LlamaContextParams;
303 /// let params = LlamaContextParams::default()
304 /// .with_rope_freq_base(0.5);
305 /// assert_eq!(params.rope_freq_base(), 0.5);
306 /// ```
307 #[must_use]
308 pub fn with_rope_freq_base(mut self, rope_freq_base: f32) -> Self {
309 self.context_params.rope_freq_base = rope_freq_base;
310 self
311 }
312
313 /// Get the rope frequency base
314 ///
315 /// # Examples
316 ///
317 /// ```rust
318 /// # use llama_cpp_2::context::params::LlamaContextParams;
319 /// let params = LlamaContextParams::default();
320 /// assert_eq!(params.rope_freq_base(), 0.0);
321 /// ```
322 #[must_use]
323 pub fn rope_freq_base(&self) -> f32 {
324 self.context_params.rope_freq_base
325 }
326
327 /// Set the rope frequency scale
328 ///
329 /// # Examples
330 ///
331 /// ```rust
332 /// # use llama_cpp_2::context::params::LlamaContextParams;
333 /// let params = LlamaContextParams::default()
334 /// .with_rope_freq_scale(0.5);
335 /// assert_eq!(params.rope_freq_scale(), 0.5);
336 /// ```
337 #[must_use]
338 pub fn with_rope_freq_scale(mut self, rope_freq_scale: f32) -> Self {
339 self.context_params.rope_freq_scale = rope_freq_scale;
340 self
341 }
342
343 /// Get the rope frequency scale
344 ///
345 /// # Examples
346 ///
347 /// ```rust
348 /// # use llama_cpp_2::context::params::LlamaContextParams;
349 /// let params = LlamaContextParams::default();
350 /// assert_eq!(params.rope_freq_scale(), 0.0);
351 /// ```
352 #[must_use]
353 pub fn rope_freq_scale(&self) -> f32 {
354 self.context_params.rope_freq_scale
355 }
356
357 /// Set the YaRN extrapolation mix factor
358 ///
359 /// # Examples
360 ///
361 /// ```rust
362 /// # use llama_cpp_2::context::params::LlamaContextParams;
363 /// let params = LlamaContextParams::default().with_yarn_ext_factor(1.0);
364 /// assert_eq!(params.yarn_ext_factor(), 1.0);
365 /// ```
366 #[must_use]
367 pub fn with_yarn_ext_factor(mut self, yarn_ext_factor: f32) -> Self {
368 self.context_params.yarn_ext_factor = yarn_ext_factor;
369 self
370 }
371
372 /// Get the YaRN extrapolation mix factor
373 #[must_use]
374 pub fn yarn_ext_factor(&self) -> f32 {
375 self.context_params.yarn_ext_factor
376 }
377
378 /// Set the YaRN magnitude scaling factor
379 ///
380 /// # Examples
381 ///
382 /// ```rust
383 /// # use llama_cpp_2::context::params::LlamaContextParams;
384 /// let params = LlamaContextParams::default().with_yarn_attn_factor(2.0);
385 /// assert_eq!(params.yarn_attn_factor(), 2.0);
386 /// ```
387 #[must_use]
388 pub fn with_yarn_attn_factor(mut self, yarn_attn_factor: f32) -> Self {
389 self.context_params.yarn_attn_factor = yarn_attn_factor;
390 self
391 }
392
393 /// Get the YaRN magnitude scaling factor
394 #[must_use]
395 pub fn yarn_attn_factor(&self) -> f32 {
396 self.context_params.yarn_attn_factor
397 }
398
399 /// Set the YaRN low correction dim
400 ///
401 /// # Examples
402 ///
403 /// ```rust
404 /// # use llama_cpp_2::context::params::LlamaContextParams;
405 /// let params = LlamaContextParams::default().with_yarn_beta_fast(16.0);
406 /// assert_eq!(params.yarn_beta_fast(), 16.0);
407 /// ```
408 #[must_use]
409 pub fn with_yarn_beta_fast(mut self, yarn_beta_fast: f32) -> Self {
410 self.context_params.yarn_beta_fast = yarn_beta_fast;
411 self
412 }
413
414 /// Get the YaRN low correction dim
415 #[must_use]
416 pub fn yarn_beta_fast(&self) -> f32 {
417 self.context_params.yarn_beta_fast
418 }
419
420 /// Set the YaRN high correction dim
421 ///
422 /// # Examples
423 ///
424 /// ```rust
425 /// # use llama_cpp_2::context::params::LlamaContextParams;
426 /// let params = LlamaContextParams::default().with_yarn_beta_slow(2.0);
427 /// assert_eq!(params.yarn_beta_slow(), 2.0);
428 /// ```
429 #[must_use]
430 pub fn with_yarn_beta_slow(mut self, yarn_beta_slow: f32) -> Self {
431 self.context_params.yarn_beta_slow = yarn_beta_slow;
432 self
433 }
434
435 /// Get the YaRN high correction dim
436 #[must_use]
437 pub fn yarn_beta_slow(&self) -> f32 {
438 self.context_params.yarn_beta_slow
439 }
440
441 /// Set the YaRN original context size
442 ///
443 /// # Examples
444 ///
445 /// ```rust
446 /// # use llama_cpp_2::context::params::LlamaContextParams;
447 /// let params = LlamaContextParams::default().with_yarn_orig_ctx(4096);
448 /// assert_eq!(params.yarn_orig_ctx(), 4096);
449 /// ```
450 #[must_use]
451 pub fn with_yarn_orig_ctx(mut self, yarn_orig_ctx: u32) -> Self {
452 self.context_params.yarn_orig_ctx = yarn_orig_ctx;
453 self
454 }
455
456 /// Get the YaRN original context size
457 #[must_use]
458 pub fn yarn_orig_ctx(&self) -> u32 {
459 self.context_params.yarn_orig_ctx
460 }
461
462 /// Set the KV cache defragmentation threshold
463 ///
464 /// # Examples
465 ///
466 /// ```rust
467 /// # use llama_cpp_2::context::params::LlamaContextParams;
468 /// let params = LlamaContextParams::default().with_defrag_thold(0.1);
469 /// assert_eq!(params.defrag_thold(), 0.1);
470 /// ```
471 #[must_use]
472 pub fn with_defrag_thold(mut self, defrag_thold: f32) -> Self {
473 self.context_params.defrag_thold = defrag_thold;
474 self
475 }
476
477 /// Get the KV cache defragmentation threshold
478 #[must_use]
479 pub fn defrag_thold(&self) -> f32 {
480 self.context_params.defrag_thold
481 }
482
483 /// Set the KV cache data type for K
484 ///
485 /// # Examples
486 ///
487 /// ```rust
488 /// # use llama_cpp_2::context::params::{LlamaContextParams, KvCacheType};
489 /// let params = LlamaContextParams::default().with_type_k(KvCacheType::Q4_0);
490 /// assert_eq!(params.type_k(), KvCacheType::Q4_0);
491 /// ```
492 #[must_use]
493 pub fn with_type_k(mut self, type_k: KvCacheType) -> Self {
494 self.context_params.type_k = type_k.into();
495 self
496 }
497
498 /// Get the KV cache data type for K
499 ///
500 /// # Examples
501 ///
502 /// ```rust
503 /// # use llama_cpp_2::context::params::LlamaContextParams;
504 /// let params = LlamaContextParams::default();
505 /// let _ = params.type_k();
506 /// ```
507 #[must_use]
508 pub fn type_k(&self) -> KvCacheType {
509 KvCacheType::from(self.context_params.type_k)
510 }
511
512 /// Set the KV cache data type for V
513 ///
514 /// # Examples
515 ///
516 /// ```rust
517 /// # use llama_cpp_2::context::params::{LlamaContextParams, KvCacheType};
518 /// let params = LlamaContextParams::default().with_type_v(KvCacheType::Q4_1);
519 /// assert_eq!(params.type_v(), KvCacheType::Q4_1);
520 /// ```
521 #[must_use]
522 pub fn with_type_v(mut self, type_v: KvCacheType) -> Self {
523 self.context_params.type_v = type_v.into();
524 self
525 }
526
527 /// Get the KV cache data type for V
528 ///
529 /// # Examples
530 ///
531 /// ```rust
532 /// # use llama_cpp_2::context::params::LlamaContextParams;
533 /// let params = LlamaContextParams::default();
534 /// let _ = params.type_v();
535 /// ```
536 #[must_use]
537 pub fn type_v(&self) -> KvCacheType {
538 KvCacheType::from(self.context_params.type_v)
539 }
540
541 /// Set whether embeddings are enabled
542 ///
543 /// # Examples
544 ///
545 /// ```rust
546 /// # use llama_cpp_2::context::params::LlamaContextParams;
547 /// let params = LlamaContextParams::default()
548 /// .with_embeddings(true);
549 /// assert!(params.embeddings());
550 /// ```
551 #[must_use]
552 pub fn with_embeddings(mut self, embedding: bool) -> Self {
553 self.context_params.embeddings = embedding;
554 self
555 }
556
557 /// Get whether embeddings are enabled
558 ///
559 /// # Examples
560 ///
561 /// ```rust
562 /// # use llama_cpp_2::context::params::LlamaContextParams;
563 /// let params = LlamaContextParams::default();
564 /// assert!(!params.embeddings());
565 /// ```
566 #[must_use]
567 pub fn embeddings(&self) -> bool {
568 self.context_params.embeddings
569 }
570
571 /// Set whether to offload KQV ops to GPU
572 ///
573 /// # Examples
574 ///
575 /// ```rust
576 /// # use llama_cpp_2::context::params::LlamaContextParams;
577 /// let params = LlamaContextParams::default()
578 /// .with_offload_kqv(false);
579 /// assert_eq!(params.offload_kqv(), false);
580 /// ```
581 #[must_use]
582 pub fn with_offload_kqv(mut self, enabled: bool) -> Self {
583 self.context_params.offload_kqv = enabled;
584 self
585 }
586
587 /// Get whether KQV ops are offloaded to GPU
588 ///
589 /// # Examples
590 ///
591 /// ```rust
592 /// # use llama_cpp_2::context::params::LlamaContextParams;
593 /// let params = LlamaContextParams::default();
594 /// assert_eq!(params.offload_kqv(), true);
595 /// ```
596 #[must_use]
597 pub fn offload_kqv(&self) -> bool {
598 self.context_params.offload_kqv
599 }
600
601 /// Set whether to disable performance timings
602 ///
603 /// # Examples
604 ///
605 /// ```rust
606 /// # use llama_cpp_2::context::params::LlamaContextParams;
607 /// let params = LlamaContextParams::default().with_no_perf(true);
608 /// assert!(params.no_perf());
609 /// ```
610 #[must_use]
611 pub fn with_no_perf(mut self, no_perf: bool) -> Self {
612 self.context_params.no_perf = no_perf;
613 self
614 }
615
616 /// Get whether performance timings are disabled
617 #[must_use]
618 pub fn no_perf(&self) -> bool {
619 self.context_params.no_perf
620 }
621
622 /// Set whether to offload ops to GPU
623 ///
624 /// # Examples
625 ///
626 /// ```rust
627 /// # use llama_cpp_2::context::params::LlamaContextParams;
628 /// let params = LlamaContextParams::default().with_op_offload(false);
629 /// assert_eq!(params.op_offload(), false);
630 /// ```
631 #[must_use]
632 pub fn with_op_offload(mut self, op_offload: bool) -> Self {
633 self.context_params.op_offload = op_offload;
634 self
635 }
636
637 /// Get whether ops are offloaded to GPU
638 #[must_use]
639 pub fn op_offload(&self) -> bool {
640 self.context_params.op_offload
641 }
642
643 /// Set whether to use full sliding window attention
644 ///
645 /// # Examples
646 ///
647 /// ```rust
648 /// # use llama_cpp_2::context::params::LlamaContextParams;
649 /// let params = LlamaContextParams::default()
650 /// .with_swa_full(false);
651 /// assert_eq!(params.swa_full(), false);
652 /// ```
653 #[must_use]
654 pub fn with_swa_full(mut self, enabled: bool) -> Self {
655 self.context_params.swa_full = enabled;
656 self
657 }
658
659 /// Get whether full sliding window attention is enabled
660 ///
661 /// # Examples
662 ///
663 /// ```rust
664 /// # use llama_cpp_2::context::params::LlamaContextParams;
665 /// let params = LlamaContextParams::default();
666 /// assert_eq!(params.swa_full(), true);
667 /// ```
668 #[must_use]
669 pub fn swa_full(&self) -> bool {
670 self.context_params.swa_full
671 }
672
673 /// Set whether to use a unified KV cache buffer across input sequences
674 ///
675 /// # Examples
676 ///
677 /// ```rust
678 /// # use llama_cpp_2::context::params::LlamaContextParams;
679 /// let params = LlamaContextParams::default().with_kv_unified(true);
680 /// assert!(params.kv_unified());
681 /// ```
682 #[must_use]
683 pub fn with_kv_unified(mut self, kv_unified: bool) -> Self {
684 self.context_params.kv_unified = kv_unified;
685 self
686 }
687
688 /// Get whether a unified KV cache buffer is used across input sequences
689 ///
690 /// # Examples
691 ///
692 /// ```rust
693 /// # use llama_cpp_2::context::params::LlamaContextParams;
694 /// let params = LlamaContextParams::default();
695 /// let _ = params.kv_unified();
696 /// ```
697 #[must_use]
698 pub fn kv_unified(&self) -> bool {
699 self.context_params.kv_unified
700 }
701}