Skip to main content

llama_cpp_2/context/params/
get_set.rs

1use std::num::NonZeroU32;
2
3use super::{
4    KvCacheType, LlamaAttentionType, LlamaContextParams, LlamaPoolingType, RopeScalingType,
5};
6
7impl LlamaContextParams {
8    /// Set the size of the context
9    ///
10    /// # Examples
11    ///
12    /// ```rust
13    /// # use std::num::NonZeroU32;
14    /// # use llama_cpp_2::context::params::LlamaContextParams;
15    /// let params = LlamaContextParams::default();
16    /// let params = params.with_n_ctx(NonZeroU32::new(2048));
17    /// assert_eq!(params.n_ctx(), NonZeroU32::new(2048));
18    /// ```
19    #[must_use]
20    pub fn with_n_ctx(mut self, n_ctx: Option<NonZeroU32>) -> Self {
21        self.context_params.n_ctx = n_ctx.map_or(0, std::num::NonZeroU32::get);
22        self
23    }
24
25    /// Get the size of the context.
26    ///
27    /// [`None`] if the context size is specified by the model and not the context.
28    ///
29    /// # Examples
30    ///
31    /// ```rust
32    /// # use llama_cpp_2::context::params::LlamaContextParams;
33    /// let params = LlamaContextParams::default();
34    /// assert_eq!(params.n_ctx(), std::num::NonZeroU32::new(512));
35    /// ```
36    #[must_use]
37    pub fn n_ctx(&self) -> Option<NonZeroU32> {
38        NonZeroU32::new(self.context_params.n_ctx)
39    }
40
41    /// Set the `n_batch`
42    ///
43    /// # Examples
44    ///
45    /// ```rust
46    /// # use llama_cpp_2::context::params::LlamaContextParams;
47    /// let params = LlamaContextParams::default()
48    ///     .with_n_batch(2048);
49    /// assert_eq!(params.n_batch(), 2048);
50    /// ```
51    #[must_use]
52    pub fn with_n_batch(mut self, n_batch: u32) -> Self {
53        self.context_params.n_batch = n_batch;
54        self
55    }
56
57    /// Get the `n_batch`
58    ///
59    /// # Examples
60    ///
61    /// ```rust
62    /// # use llama_cpp_2::context::params::LlamaContextParams;
63    /// let params = LlamaContextParams::default();
64    /// assert_eq!(params.n_batch(), 2048);
65    /// ```
66    #[must_use]
67    pub fn n_batch(&self) -> u32 {
68        self.context_params.n_batch
69    }
70
71    /// Set the `n_ubatch`
72    ///
73    /// # Examples
74    ///
75    /// ```rust
76    /// # use llama_cpp_2::context::params::LlamaContextParams;
77    /// let params = LlamaContextParams::default()
78    ///     .with_n_ubatch(512);
79    /// assert_eq!(params.n_ubatch(), 512);
80    /// ```
81    #[must_use]
82    pub fn with_n_ubatch(mut self, n_ubatch: u32) -> Self {
83        self.context_params.n_ubatch = n_ubatch;
84        self
85    }
86
87    /// Get the `n_ubatch`
88    ///
89    /// # Examples
90    ///
91    /// ```rust
92    /// # use llama_cpp_2::context::params::LlamaContextParams;
93    /// let params = LlamaContextParams::default();
94    /// assert_eq!(params.n_ubatch(), 512);
95    /// ```
96    #[must_use]
97    pub fn n_ubatch(&self) -> u32 {
98        self.context_params.n_ubatch
99    }
100
101    /// Set the max number of sequences (i.e. distinct states for recurrent models)
102    ///
103    /// # Examples
104    ///
105    /// ```rust
106    /// # use llama_cpp_2::context::params::LlamaContextParams;
107    /// let params = LlamaContextParams::default()
108    ///     .with_n_seq_max(64);
109    /// assert_eq!(params.n_seq_max(), 64);
110    /// ```
111    #[must_use]
112    pub fn with_n_seq_max(mut self, n_seq_max: u32) -> Self {
113        self.context_params.n_seq_max = n_seq_max;
114        self
115    }
116
117    /// Get the max number of sequences (i.e. distinct states for recurrent models)
118    ///
119    /// # Examples
120    ///
121    /// ```rust
122    /// # use llama_cpp_2::context::params::LlamaContextParams;
123    /// let params = LlamaContextParams::default();
124    /// assert_eq!(params.n_seq_max(), 1);
125    /// ```
126    #[must_use]
127    pub fn n_seq_max(&self) -> u32 {
128        self.context_params.n_seq_max
129    }
130
131    /// Set the number of threads
132    ///
133    /// # Examples
134    ///
135    /// ```rust
136    /// # use llama_cpp_2::context::params::LlamaContextParams;
137    /// let params = LlamaContextParams::default()
138    ///    .with_n_threads(8);
139    /// assert_eq!(params.n_threads(), 8);
140    /// ```
141    #[must_use]
142    pub fn with_n_threads(mut self, n_threads: i32) -> Self {
143        self.context_params.n_threads = n_threads;
144        self
145    }
146
147    /// Get the number of threads
148    ///
149    /// # Examples
150    ///
151    /// ```rust
152    /// # use llama_cpp_2::context::params::LlamaContextParams;
153    /// let params = LlamaContextParams::default();
154    /// assert_eq!(params.n_threads(), 4);
155    /// ```
156    #[must_use]
157    pub fn n_threads(&self) -> i32 {
158        self.context_params.n_threads
159    }
160
161    /// Set the number of threads allocated for batches
162    ///
163    /// # Examples
164    ///
165    /// ```rust
166    /// # use llama_cpp_2::context::params::LlamaContextParams;
167    /// let params = LlamaContextParams::default()
168    ///    .with_n_threads_batch(8);
169    /// assert_eq!(params.n_threads_batch(), 8);
170    /// ```
171    #[must_use]
172    pub fn with_n_threads_batch(mut self, n_threads: i32) -> Self {
173        self.context_params.n_threads_batch = n_threads;
174        self
175    }
176
177    /// Get the number of threads allocated for batches
178    ///
179    /// # Examples
180    ///
181    /// ```rust
182    /// # use llama_cpp_2::context::params::LlamaContextParams;
183    /// let params = LlamaContextParams::default();
184    /// assert_eq!(params.n_threads_batch(), 4);
185    /// ```
186    #[must_use]
187    pub fn n_threads_batch(&self) -> i32 {
188        self.context_params.n_threads_batch
189    }
190
191    /// Set the type of rope scaling
192    ///
193    /// # Examples
194    ///
195    /// ```rust
196    /// # use llama_cpp_2::context::params::{LlamaContextParams, RopeScalingType};
197    /// let params = LlamaContextParams::default()
198    ///     .with_rope_scaling_type(RopeScalingType::Linear);
199    /// assert_eq!(params.rope_scaling_type(), RopeScalingType::Linear);
200    /// ```
201    #[must_use]
202    pub fn with_rope_scaling_type(mut self, rope_scaling_type: RopeScalingType) -> Self {
203        self.context_params.rope_scaling_type = i32::from(rope_scaling_type);
204        self
205    }
206
207    /// Get the type of rope scaling
208    ///
209    /// # Examples
210    ///
211    /// ```rust
212    /// # use llama_cpp_2::context::params::{LlamaContextParams, RopeScalingType};
213    /// let params = LlamaContextParams::default();
214    /// assert_eq!(params.rope_scaling_type(), RopeScalingType::Unspecified);
215    /// ```
216    #[must_use]
217    pub fn rope_scaling_type(&self) -> RopeScalingType {
218        RopeScalingType::from(self.context_params.rope_scaling_type)
219    }
220
221    /// Set the type of pooling
222    ///
223    /// # Examples
224    ///
225    /// ```rust
226    /// # use llama_cpp_2::context::params::{LlamaContextParams, LlamaPoolingType};
227    /// let params = LlamaContextParams::default()
228    ///     .with_pooling_type(LlamaPoolingType::Last);
229    /// assert_eq!(params.pooling_type(), LlamaPoolingType::Last);
230    /// ```
231    #[must_use]
232    pub fn with_pooling_type(mut self, pooling_type: LlamaPoolingType) -> Self {
233        self.context_params.pooling_type = i32::from(pooling_type);
234        self
235    }
236
237    /// Get the type of pooling
238    ///
239    /// # Examples
240    ///
241    /// ```rust
242    /// # use llama_cpp_2::context::params::{LlamaContextParams, LlamaPoolingType};
243    /// let params = LlamaContextParams::default();
244    /// assert_eq!(params.pooling_type(), LlamaPoolingType::Unspecified);
245    /// ```
246    #[must_use]
247    pub fn pooling_type(&self) -> LlamaPoolingType {
248        LlamaPoolingType::from(self.context_params.pooling_type)
249    }
250
251    /// Set the attention type for embeddings
252    ///
253    /// # Examples
254    ///
255    /// ```rust
256    /// # use llama_cpp_2::context::params::{LlamaContextParams, LlamaAttentionType};
257    /// let params = LlamaContextParams::default()
258    ///     .with_attention_type(LlamaAttentionType::Causal);
259    /// assert_eq!(params.attention_type(), LlamaAttentionType::Causal);
260    /// ```
261    #[must_use]
262    pub fn with_attention_type(mut self, attention_type: LlamaAttentionType) -> Self {
263        self.context_params.attention_type = i32::from(attention_type);
264        self
265    }
266
267    /// Get the attention type for embeddings
268    ///
269    /// # Examples
270    ///
271    /// ```rust
272    /// # use llama_cpp_2::context::params::{LlamaContextParams, LlamaAttentionType};
273    /// let params = LlamaContextParams::default();
274    /// assert_eq!(params.attention_type(), LlamaAttentionType::Unspecified);
275    /// ```
276    #[must_use]
277    pub fn attention_type(&self) -> LlamaAttentionType {
278        LlamaAttentionType::from(self.context_params.attention_type)
279    }
280
281    /// Set the flash attention policy using llama.cpp enum
282    #[must_use]
283    pub fn with_flash_attention_policy(
284        mut self,
285        policy: llama_cpp_sys_2::llama_flash_attn_type,
286    ) -> Self {
287        self.context_params.flash_attn_type = policy;
288        self
289    }
290
291    /// Get the flash attention policy
292    #[must_use]
293    pub fn flash_attention_policy(&self) -> llama_cpp_sys_2::llama_flash_attn_type {
294        self.context_params.flash_attn_type
295    }
296
297    /// Set the rope frequency base
298    ///
299    /// # Examples
300    ///
301    /// ```rust
302    /// # use llama_cpp_2::context::params::LlamaContextParams;
303    /// let params = LlamaContextParams::default()
304    ///    .with_rope_freq_base(0.5);
305    /// assert_eq!(params.rope_freq_base(), 0.5);
306    /// ```
307    #[must_use]
308    pub fn with_rope_freq_base(mut self, rope_freq_base: f32) -> Self {
309        self.context_params.rope_freq_base = rope_freq_base;
310        self
311    }
312
313    /// Get the rope frequency base
314    ///
315    /// # Examples
316    ///
317    /// ```rust
318    /// # use llama_cpp_2::context::params::LlamaContextParams;
319    /// let params = LlamaContextParams::default();
320    /// assert_eq!(params.rope_freq_base(), 0.0);
321    /// ```
322    #[must_use]
323    pub fn rope_freq_base(&self) -> f32 {
324        self.context_params.rope_freq_base
325    }
326
327    /// Set the rope frequency scale
328    ///
329    /// # Examples
330    ///
331    /// ```rust
332    /// # use llama_cpp_2::context::params::LlamaContextParams;
333    /// let params = LlamaContextParams::default()
334    ///   .with_rope_freq_scale(0.5);
335    /// assert_eq!(params.rope_freq_scale(), 0.5);
336    /// ```
337    #[must_use]
338    pub fn with_rope_freq_scale(mut self, rope_freq_scale: f32) -> Self {
339        self.context_params.rope_freq_scale = rope_freq_scale;
340        self
341    }
342
343    /// Get the rope frequency scale
344    ///
345    /// # Examples
346    ///
347    /// ```rust
348    /// # use llama_cpp_2::context::params::LlamaContextParams;
349    /// let params = LlamaContextParams::default();
350    /// assert_eq!(params.rope_freq_scale(), 0.0);
351    /// ```
352    #[must_use]
353    pub fn rope_freq_scale(&self) -> f32 {
354        self.context_params.rope_freq_scale
355    }
356
357    /// Set the YaRN extrapolation mix factor
358    ///
359    /// # Examples
360    ///
361    /// ```rust
362    /// # use llama_cpp_2::context::params::LlamaContextParams;
363    /// let params = LlamaContextParams::default().with_yarn_ext_factor(1.0);
364    /// assert_eq!(params.yarn_ext_factor(), 1.0);
365    /// ```
366    #[must_use]
367    pub fn with_yarn_ext_factor(mut self, yarn_ext_factor: f32) -> Self {
368        self.context_params.yarn_ext_factor = yarn_ext_factor;
369        self
370    }
371
372    /// Get the YaRN extrapolation mix factor
373    #[must_use]
374    pub fn yarn_ext_factor(&self) -> f32 {
375        self.context_params.yarn_ext_factor
376    }
377
378    /// Set the YaRN magnitude scaling factor
379    ///
380    /// # Examples
381    ///
382    /// ```rust
383    /// # use llama_cpp_2::context::params::LlamaContextParams;
384    /// let params = LlamaContextParams::default().with_yarn_attn_factor(2.0);
385    /// assert_eq!(params.yarn_attn_factor(), 2.0);
386    /// ```
387    #[must_use]
388    pub fn with_yarn_attn_factor(mut self, yarn_attn_factor: f32) -> Self {
389        self.context_params.yarn_attn_factor = yarn_attn_factor;
390        self
391    }
392
393    /// Get the YaRN magnitude scaling factor
394    #[must_use]
395    pub fn yarn_attn_factor(&self) -> f32 {
396        self.context_params.yarn_attn_factor
397    }
398
399    /// Set the YaRN low correction dim
400    ///
401    /// # Examples
402    ///
403    /// ```rust
404    /// # use llama_cpp_2::context::params::LlamaContextParams;
405    /// let params = LlamaContextParams::default().with_yarn_beta_fast(16.0);
406    /// assert_eq!(params.yarn_beta_fast(), 16.0);
407    /// ```
408    #[must_use]
409    pub fn with_yarn_beta_fast(mut self, yarn_beta_fast: f32) -> Self {
410        self.context_params.yarn_beta_fast = yarn_beta_fast;
411        self
412    }
413
414    /// Get the YaRN low correction dim
415    #[must_use]
416    pub fn yarn_beta_fast(&self) -> f32 {
417        self.context_params.yarn_beta_fast
418    }
419
420    /// Set the YaRN high correction dim
421    ///
422    /// # Examples
423    ///
424    /// ```rust
425    /// # use llama_cpp_2::context::params::LlamaContextParams;
426    /// let params = LlamaContextParams::default().with_yarn_beta_slow(2.0);
427    /// assert_eq!(params.yarn_beta_slow(), 2.0);
428    /// ```
429    #[must_use]
430    pub fn with_yarn_beta_slow(mut self, yarn_beta_slow: f32) -> Self {
431        self.context_params.yarn_beta_slow = yarn_beta_slow;
432        self
433    }
434
435    /// Get the YaRN high correction dim
436    #[must_use]
437    pub fn yarn_beta_slow(&self) -> f32 {
438        self.context_params.yarn_beta_slow
439    }
440
441    /// Set the YaRN original context size
442    ///
443    /// # Examples
444    ///
445    /// ```rust
446    /// # use llama_cpp_2::context::params::LlamaContextParams;
447    /// let params = LlamaContextParams::default().with_yarn_orig_ctx(4096);
448    /// assert_eq!(params.yarn_orig_ctx(), 4096);
449    /// ```
450    #[must_use]
451    pub fn with_yarn_orig_ctx(mut self, yarn_orig_ctx: u32) -> Self {
452        self.context_params.yarn_orig_ctx = yarn_orig_ctx;
453        self
454    }
455
456    /// Get the YaRN original context size
457    #[must_use]
458    pub fn yarn_orig_ctx(&self) -> u32 {
459        self.context_params.yarn_orig_ctx
460    }
461
462    /// Set the KV cache defragmentation threshold
463    ///
464    /// # Examples
465    ///
466    /// ```rust
467    /// # use llama_cpp_2::context::params::LlamaContextParams;
468    /// let params = LlamaContextParams::default().with_defrag_thold(0.1);
469    /// assert_eq!(params.defrag_thold(), 0.1);
470    /// ```
471    #[must_use]
472    pub fn with_defrag_thold(mut self, defrag_thold: f32) -> Self {
473        self.context_params.defrag_thold = defrag_thold;
474        self
475    }
476
477    /// Get the KV cache defragmentation threshold
478    #[must_use]
479    pub fn defrag_thold(&self) -> f32 {
480        self.context_params.defrag_thold
481    }
482
483    /// Set the KV cache data type for K
484    ///
485    /// # Examples
486    ///
487    /// ```rust
488    /// # use llama_cpp_2::context::params::{LlamaContextParams, KvCacheType};
489    /// let params = LlamaContextParams::default().with_type_k(KvCacheType::Q4_0);
490    /// assert_eq!(params.type_k(), KvCacheType::Q4_0);
491    /// ```
492    #[must_use]
493    pub fn with_type_k(mut self, type_k: KvCacheType) -> Self {
494        self.context_params.type_k = type_k.into();
495        self
496    }
497
498    /// Get the KV cache data type for K
499    ///
500    /// # Examples
501    ///
502    /// ```rust
503    /// # use llama_cpp_2::context::params::LlamaContextParams;
504    /// let params = LlamaContextParams::default();
505    /// let _ = params.type_k();
506    /// ```
507    #[must_use]
508    pub fn type_k(&self) -> KvCacheType {
509        KvCacheType::from(self.context_params.type_k)
510    }
511
512    /// Set the KV cache data type for V
513    ///
514    /// # Examples
515    ///
516    /// ```rust
517    /// # use llama_cpp_2::context::params::{LlamaContextParams, KvCacheType};
518    /// let params = LlamaContextParams::default().with_type_v(KvCacheType::Q4_1);
519    /// assert_eq!(params.type_v(), KvCacheType::Q4_1);
520    /// ```
521    #[must_use]
522    pub fn with_type_v(mut self, type_v: KvCacheType) -> Self {
523        self.context_params.type_v = type_v.into();
524        self
525    }
526
527    /// Get the KV cache data type for V
528    ///
529    /// # Examples
530    ///
531    /// ```rust
532    /// # use llama_cpp_2::context::params::LlamaContextParams;
533    /// let params = LlamaContextParams::default();
534    /// let _ = params.type_v();
535    /// ```
536    #[must_use]
537    pub fn type_v(&self) -> KvCacheType {
538        KvCacheType::from(self.context_params.type_v)
539    }
540
541    /// Set whether embeddings are enabled
542    ///
543    /// # Examples
544    ///
545    /// ```rust
546    /// # use llama_cpp_2::context::params::LlamaContextParams;
547    /// let params = LlamaContextParams::default()
548    ///    .with_embeddings(true);
549    /// assert!(params.embeddings());
550    /// ```
551    #[must_use]
552    pub fn with_embeddings(mut self, embedding: bool) -> Self {
553        self.context_params.embeddings = embedding;
554        self
555    }
556
557    /// Get whether embeddings are enabled
558    ///
559    /// # Examples
560    ///
561    /// ```rust
562    /// # use llama_cpp_2::context::params::LlamaContextParams;
563    /// let params = LlamaContextParams::default();
564    /// assert!(!params.embeddings());
565    /// ```
566    #[must_use]
567    pub fn embeddings(&self) -> bool {
568        self.context_params.embeddings
569    }
570
571    /// Set whether to offload KQV ops to GPU
572    ///
573    /// # Examples
574    ///
575    /// ```rust
576    /// # use llama_cpp_2::context::params::LlamaContextParams;
577    /// let params = LlamaContextParams::default()
578    ///     .with_offload_kqv(false);
579    /// assert_eq!(params.offload_kqv(), false);
580    /// ```
581    #[must_use]
582    pub fn with_offload_kqv(mut self, enabled: bool) -> Self {
583        self.context_params.offload_kqv = enabled;
584        self
585    }
586
587    /// Get whether KQV ops are offloaded to GPU
588    ///
589    /// # Examples
590    ///
591    /// ```rust
592    /// # use llama_cpp_2::context::params::LlamaContextParams;
593    /// let params = LlamaContextParams::default();
594    /// assert_eq!(params.offload_kqv(), true);
595    /// ```
596    #[must_use]
597    pub fn offload_kqv(&self) -> bool {
598        self.context_params.offload_kqv
599    }
600
601    /// Set whether to disable performance timings
602    ///
603    /// # Examples
604    ///
605    /// ```rust
606    /// # use llama_cpp_2::context::params::LlamaContextParams;
607    /// let params = LlamaContextParams::default().with_no_perf(true);
608    /// assert!(params.no_perf());
609    /// ```
610    #[must_use]
611    pub fn with_no_perf(mut self, no_perf: bool) -> Self {
612        self.context_params.no_perf = no_perf;
613        self
614    }
615
616    /// Get whether performance timings are disabled
617    #[must_use]
618    pub fn no_perf(&self) -> bool {
619        self.context_params.no_perf
620    }
621
622    /// Set whether to offload ops to GPU
623    ///
624    /// # Examples
625    ///
626    /// ```rust
627    /// # use llama_cpp_2::context::params::LlamaContextParams;
628    /// let params = LlamaContextParams::default().with_op_offload(false);
629    /// assert_eq!(params.op_offload(), false);
630    /// ```
631    #[must_use]
632    pub fn with_op_offload(mut self, op_offload: bool) -> Self {
633        self.context_params.op_offload = op_offload;
634        self
635    }
636
637    /// Get whether ops are offloaded to GPU
638    #[must_use]
639    pub fn op_offload(&self) -> bool {
640        self.context_params.op_offload
641    }
642
643    /// Set whether to use full sliding window attention
644    ///
645    /// # Examples
646    ///
647    /// ```rust
648    /// # use llama_cpp_2::context::params::LlamaContextParams;
649    /// let params = LlamaContextParams::default()
650    ///     .with_swa_full(false);
651    /// assert_eq!(params.swa_full(), false);
652    /// ```
653    #[must_use]
654    pub fn with_swa_full(mut self, enabled: bool) -> Self {
655        self.context_params.swa_full = enabled;
656        self
657    }
658
659    /// Get whether full sliding window attention is enabled
660    ///
661    /// # Examples
662    ///
663    /// ```rust
664    /// # use llama_cpp_2::context::params::LlamaContextParams;
665    /// let params = LlamaContextParams::default();
666    /// assert_eq!(params.swa_full(), true);
667    /// ```
668    #[must_use]
669    pub fn swa_full(&self) -> bool {
670        self.context_params.swa_full
671    }
672
673    /// Set whether to use a unified KV cache buffer across input sequences
674    ///
675    /// # Examples
676    ///
677    /// ```rust
678    /// # use llama_cpp_2::context::params::LlamaContextParams;
679    /// let params = LlamaContextParams::default().with_kv_unified(true);
680    /// assert!(params.kv_unified());
681    /// ```
682    #[must_use]
683    pub fn with_kv_unified(mut self, kv_unified: bool) -> Self {
684        self.context_params.kv_unified = kv_unified;
685        self
686    }
687
688    /// Get whether a unified KV cache buffer is used across input sequences
689    ///
690    /// # Examples
691    ///
692    /// ```rust
693    /// # use llama_cpp_2::context::params::LlamaContextParams;
694    /// let params = LlamaContextParams::default();
695    /// let _ = params.kv_unified();
696    /// ```
697    #[must_use]
698    pub fn kv_unified(&self) -> bool {
699        self.context_params.kv_unified
700    }
701}