llama_cpp_2/context/params.rs

//! A safe wrapper around `llama_context_params`.
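//!
//! # Examples
//!
//! A typical builder-style setup (a minimal sketch; the values are illustrative):
//!
//! ```rust
//! use std::num::NonZeroU32;
//! use llama_cpp_2::context::params::LlamaContextParams;
//!
//! let params = LlamaContextParams::default()
//!     .with_n_ctx(NonZeroU32::new(2048))
//!     .with_n_batch(512)
//!     .with_embeddings(true);
//! assert_eq!(params.n_ctx(), NonZeroU32::new(2048));
//! assert!(params.embeddings());
//! ```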
use std::fmt::Debug;
use std::num::NonZeroU32;

/// A rusty wrapper around `rope_scaling_type`.
#[repr(i8)]
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum RopeScalingType {
    /// The scaling type is unspecified
    Unspecified = -1,
    /// No scaling
    None = 0,
    /// Linear scaling
    Linear = 1,
    /// Yarn scaling
    Yarn = 2,
}

/// Create a `RopeScalingType` from a `c_int` - returns `RopeScalingType::Unspecified` if
/// the value is not recognized.
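///
/// # Examples
///
/// A quick illustration of the fallback behavior (mirrors the match below):
///
/// ```rust
/// use llama_cpp_2::context::params::RopeScalingType;
/// assert_eq!(RopeScalingType::from(1), RopeScalingType::Linear);
/// // Unrecognized values fall back to `Unspecified`.
/// assert_eq!(RopeScalingType::from(42), RopeScalingType::Unspecified);
/// ```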
impl From<i32> for RopeScalingType {
    fn from(value: i32) -> Self {
        match value {
            0 => Self::None,
            1 => Self::Linear,
            2 => Self::Yarn,
            _ => Self::Unspecified,
        }
    }
}

/// Create a `c_int` from a `RopeScalingType`.
impl From<RopeScalingType> for i32 {
    fn from(value: RopeScalingType) -> Self {
        match value {
            RopeScalingType::None => 0,
            RopeScalingType::Linear => 1,
            RopeScalingType::Yarn => 2,
            RopeScalingType::Unspecified => -1,
        }
    }
}

/// A rusty wrapper around `LLAMA_POOLING_TYPE`.
#[repr(i8)]
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum LlamaPoolingType {
    /// The pooling type is unspecified
    Unspecified = -1,
    /// No pooling
    None = 0,
    /// Mean pooling
    Mean = 1,
    /// CLS pooling
    Cls = 2,
    /// Last pooling
    Last = 3,
    /// Rank pooling
    Rank = 4,
}

/// Create a `LlamaPoolingType` from a `c_int` - returns `LlamaPoolingType::Unspecified` if
/// the value is not recognized.
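///
/// # Examples
///
/// A round-trip sketch through the raw `i32` representation:
///
/// ```rust
/// use llama_cpp_2::context::params::LlamaPoolingType;
/// let raw = i32::from(LlamaPoolingType::Mean);
/// assert_eq!(LlamaPoolingType::from(raw), LlamaPoolingType::Mean);
/// // Unrecognized values fall back to `Unspecified`.
/// assert_eq!(LlamaPoolingType::from(99), LlamaPoolingType::Unspecified);
/// ```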
impl From<i32> for LlamaPoolingType {
    fn from(value: i32) -> Self {
        match value {
            0 => Self::None,
            1 => Self::Mean,
            2 => Self::Cls,
            3 => Self::Last,
            4 => Self::Rank,
            _ => Self::Unspecified,
        }
    }
}

/// Create a `c_int` from a `LlamaPoolingType`.
impl From<LlamaPoolingType> for i32 {
    fn from(value: LlamaPoolingType) -> Self {
        match value {
            LlamaPoolingType::None => 0,
            LlamaPoolingType::Mean => 1,
            LlamaPoolingType::Cls => 2,
            LlamaPoolingType::Last => 3,
            LlamaPoolingType::Rank => 4,
            LlamaPoolingType::Unspecified => -1,
        }
    }
}

/// A safe wrapper around `llama_context_params`.
///
/// Generally this should be created with [`Default::default()`] and then modified with `with_*` methods.
///
/// # Examples
///
/// ```rust
/// # use std::num::NonZeroU32;
/// use llama_cpp_2::context::params::LlamaContextParams;
///
/// let ctx_params = LlamaContextParams::default()
///     .with_n_ctx(NonZeroU32::new(2048));
///
/// assert_eq!(ctx_params.n_ctx(), NonZeroU32::new(2048));
/// ```
#[derive(Debug, Clone)]
#[allow(
    missing_docs,
    clippy::struct_excessive_bools,
    clippy::module_name_repetitions
)]
pub struct LlamaContextParams {
    pub(crate) context_params: llama_cpp_sys_2::llama_context_params,
}

/// SAFETY: we do not currently allow setting or reading the pointers that would cause this type to not be automatically `Send` or `Sync`.
unsafe impl Send for LlamaContextParams {}
unsafe impl Sync for LlamaContextParams {}

impl LlamaContextParams {
    /// Set the size of the context.
    ///
    /// # Examples
    ///
    /// ```rust
    /// # use std::num::NonZeroU32;
    /// use llama_cpp_2::context::params::LlamaContextParams;
    /// let params = LlamaContextParams::default();
    /// let params = params.with_n_ctx(NonZeroU32::new(2048));
    /// assert_eq!(params.n_ctx(), NonZeroU32::new(2048));
    /// ```
    #[must_use]
    pub fn with_n_ctx(mut self, n_ctx: Option<NonZeroU32>) -> Self {
        self.context_params.n_ctx = n_ctx.map_or(0, std::num::NonZeroU32::get);
        self
    }

    /// Get the size of the context.
    ///
    /// [`None`] if the context size is specified by the model and not the context.
    ///
    /// # Examples
    ///
    /// ```rust
    /// let params = llama_cpp_2::context::params::LlamaContextParams::default();
    /// assert_eq!(params.n_ctx(), std::num::NonZeroU32::new(512));
    /// ```
    #[must_use]
    pub fn n_ctx(&self) -> Option<NonZeroU32> {
        NonZeroU32::new(self.context_params.n_ctx)
    }

    /// Set the `n_batch`
    ///
    /// # Examples
    ///
    /// ```rust
    /// # use std::num::NonZeroU32;
    /// use llama_cpp_2::context::params::LlamaContextParams;
    /// let params = LlamaContextParams::default()
    ///     .with_n_batch(2048);
    /// assert_eq!(params.n_batch(), 2048);
    /// ```
    #[must_use]
    pub fn with_n_batch(mut self, n_batch: u32) -> Self {
        self.context_params.n_batch = n_batch;
        self
    }

    /// Get the `n_batch`
    ///
    /// # Examples
    ///
    /// ```rust
    /// use llama_cpp_2::context::params::LlamaContextParams;
    /// let params = LlamaContextParams::default();
    /// assert_eq!(params.n_batch(), 2048);
    /// ```
    #[must_use]
    pub fn n_batch(&self) -> u32 {
        self.context_params.n_batch
    }

    /// Set the `n_ubatch`
    ///
    /// # Examples
    ///
    /// ```rust
    /// # use std::num::NonZeroU32;
    /// use llama_cpp_2::context::params::LlamaContextParams;
    /// let params = LlamaContextParams::default()
    ///     .with_n_ubatch(512);
    /// assert_eq!(params.n_ubatch(), 512);
    /// ```
    #[must_use]
    pub fn with_n_ubatch(mut self, n_ubatch: u32) -> Self {
        self.context_params.n_ubatch = n_ubatch;
        self
    }

    /// Get the `n_ubatch`
    ///
    /// # Examples
    ///
    /// ```rust
    /// use llama_cpp_2::context::params::LlamaContextParams;
    /// let params = LlamaContextParams::default();
    /// assert_eq!(params.n_ubatch(), 512);
    /// ```
    #[must_use]
    pub fn n_ubatch(&self) -> u32 {
        self.context_params.n_ubatch
    }

    /// Set the `flash_attention` parameter
    ///
    /// # Examples
    ///
    /// ```rust
    /// use llama_cpp_2::context::params::LlamaContextParams;
    /// let params = LlamaContextParams::default()
    ///     .with_flash_attention(true);
    /// assert_eq!(params.flash_attention(), true);
    /// ```
    #[must_use]
    pub fn with_flash_attention(mut self, enabled: bool) -> Self {
        self.context_params.flash_attn = enabled;
        self
    }

    /// Get the `flash_attention` parameter
    ///
    /// # Examples
    ///
    /// ```rust
    /// use llama_cpp_2::context::params::LlamaContextParams;
    /// let params = LlamaContextParams::default();
    /// assert_eq!(params.flash_attention(), false);
    /// ```
    #[must_use]
    pub fn flash_attention(&self) -> bool {
        self.context_params.flash_attn
    }

    /// Set the `offload_kqv` parameter to control offloading KV cache & KQV ops to GPU
    ///
    /// # Examples
    ///
    /// ```rust
    /// use llama_cpp_2::context::params::LlamaContextParams;
    /// let params = LlamaContextParams::default()
    ///     .with_offload_kqv(false);
    /// assert_eq!(params.offload_kqv(), false);
    /// ```
    #[must_use]
    pub fn with_offload_kqv(mut self, enabled: bool) -> Self {
        self.context_params.offload_kqv = enabled;
        self
    }

    /// Get the `offload_kqv` parameter
    ///
    /// # Examples
    ///
    /// ```rust
    /// use llama_cpp_2::context::params::LlamaContextParams;
    /// let params = LlamaContextParams::default();
    /// assert_eq!(params.offload_kqv(), true);
    /// ```
    #[must_use]
    pub fn offload_kqv(&self) -> bool {
        self.context_params.offload_kqv
    }

    /// Set the type of rope scaling.
    ///
    /// # Examples
    ///
    /// ```rust
    /// use llama_cpp_2::context::params::{LlamaContextParams, RopeScalingType};
    /// let params = LlamaContextParams::default()
    ///     .with_rope_scaling_type(RopeScalingType::Linear);
    /// assert_eq!(params.rope_scaling_type(), RopeScalingType::Linear);
    /// ```
    #[must_use]
    pub fn with_rope_scaling_type(mut self, rope_scaling_type: RopeScalingType) -> Self {
        self.context_params.rope_scaling_type = i32::from(rope_scaling_type);
        self
    }

    /// Get the type of rope scaling.
    ///
    /// # Examples
    ///
    /// ```rust
    /// let params = llama_cpp_2::context::params::LlamaContextParams::default();
    /// assert_eq!(params.rope_scaling_type(), llama_cpp_2::context::params::RopeScalingType::Unspecified);
    /// ```
    #[must_use]
    pub fn rope_scaling_type(&self) -> RopeScalingType {
        RopeScalingType::from(self.context_params.rope_scaling_type)
    }

    /// Set the rope frequency base.
    ///
    /// # Examples
    ///
    /// ```rust
    /// use llama_cpp_2::context::params::LlamaContextParams;
    /// let params = LlamaContextParams::default()
    ///     .with_rope_freq_base(0.5);
    /// assert_eq!(params.rope_freq_base(), 0.5);
    /// ```
    #[must_use]
    pub fn with_rope_freq_base(mut self, rope_freq_base: f32) -> Self {
        self.context_params.rope_freq_base = rope_freq_base;
        self
    }

    /// Get the rope frequency base.
    ///
    /// # Examples
    ///
    /// ```rust
    /// let params = llama_cpp_2::context::params::LlamaContextParams::default();
    /// assert_eq!(params.rope_freq_base(), 0.0);
    /// ```
    #[must_use]
    pub fn rope_freq_base(&self) -> f32 {
        self.context_params.rope_freq_base
    }

    /// Set the rope frequency scale.
    ///
    /// # Examples
    ///
    /// ```rust
    /// use llama_cpp_2::context::params::LlamaContextParams;
    /// let params = LlamaContextParams::default()
    ///     .with_rope_freq_scale(0.5);
    /// assert_eq!(params.rope_freq_scale(), 0.5);
    /// ```
    #[must_use]
    pub fn with_rope_freq_scale(mut self, rope_freq_scale: f32) -> Self {
        self.context_params.rope_freq_scale = rope_freq_scale;
        self
    }

    /// Get the rope frequency scale.
    ///
    /// # Examples
    ///
    /// ```rust
    /// let params = llama_cpp_2::context::params::LlamaContextParams::default();
    /// assert_eq!(params.rope_freq_scale(), 0.0);
    /// ```
    #[must_use]
    pub fn rope_freq_scale(&self) -> f32 {
        self.context_params.rope_freq_scale
    }

    /// Get the number of threads.
    ///
    /// # Examples
    ///
    /// ```rust
    /// let params = llama_cpp_2::context::params::LlamaContextParams::default();
    /// assert_eq!(params.n_threads(), 4);
    /// ```
    #[must_use]
    pub fn n_threads(&self) -> i32 {
        self.context_params.n_threads
    }

    /// Get the number of threads allocated for batches.
    ///
    /// # Examples
    ///
    /// ```rust
    /// let params = llama_cpp_2::context::params::LlamaContextParams::default();
    /// assert_eq!(params.n_threads_batch(), 4);
    /// ```
    #[must_use]
    pub fn n_threads_batch(&self) -> i32 {
        self.context_params.n_threads_batch
    }

    /// Set the number of threads.
    ///
    /// # Examples
    ///
    /// ```rust
    /// use llama_cpp_2::context::params::LlamaContextParams;
    /// let params = LlamaContextParams::default()
    ///     .with_n_threads(8);
    /// assert_eq!(params.n_threads(), 8);
    /// ```
    #[must_use]
    pub fn with_n_threads(mut self, n_threads: i32) -> Self {
        self.context_params.n_threads = n_threads;
        self
    }

    /// Set the number of threads allocated for batches.
    ///
    /// # Examples
    ///
    /// ```rust
    /// use llama_cpp_2::context::params::LlamaContextParams;
    /// let params = LlamaContextParams::default()
    ///     .with_n_threads_batch(8);
    /// assert_eq!(params.n_threads_batch(), 8);
    /// ```
    #[must_use]
    pub fn with_n_threads_batch(mut self, n_threads: i32) -> Self {
        self.context_params.n_threads_batch = n_threads;
        self
    }

    /// Check whether embeddings are enabled
    ///
    /// # Examples
    ///
    /// ```rust
    /// let params = llama_cpp_2::context::params::LlamaContextParams::default();
    /// assert!(!params.embeddings());
    /// ```
    #[must_use]
    pub fn embeddings(&self) -> bool {
        self.context_params.embeddings
    }

    /// Enable the use of embeddings
    ///
    /// # Examples
    ///
    /// ```rust
    /// use llama_cpp_2::context::params::LlamaContextParams;
    /// let params = LlamaContextParams::default()
    ///     .with_embeddings(true);
    /// assert!(params.embeddings());
    /// ```
    #[must_use]
    pub fn with_embeddings(mut self, embedding: bool) -> Self {
        self.context_params.embeddings = embedding;
        self
    }

    /// Set the evaluation callback.
    ///
    /// # Examples
    ///
    /// ```no_run
    /// extern "C" fn cb_eval_fn(
    ///     t: *mut llama_cpp_sys_2::ggml_tensor,
    ///     ask: bool,
    ///     user_data: *mut std::ffi::c_void,
    /// ) -> bool {
    ///     false
    /// }
    ///
    /// use llama_cpp_2::context::params::LlamaContextParams;
    /// let params = LlamaContextParams::default().with_cb_eval(Some(cb_eval_fn));
    /// ```
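    ///
    /// The scheduler invokes the callback with the pointer set via
    /// [`Self::with_cb_eval_user_data`] as its `user_data` argument.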
    #[must_use]
    pub fn with_cb_eval(
        mut self,
        cb_eval: llama_cpp_sys_2::ggml_backend_sched_eval_callback,
    ) -> Self {
        self.context_params.cb_eval = cb_eval;
        self
    }

    /// Set the evaluation callback user data.
    ///
    /// # Examples
    ///
    /// ```no_run
    /// use llama_cpp_2::context::params::LlamaContextParams;
    /// let params = LlamaContextParams::default();
    /// let user_data = std::ptr::null_mut();
    /// let params = params.with_cb_eval_user_data(user_data);
    /// ```
    #[must_use]
    pub fn with_cb_eval_user_data(mut self, cb_eval_user_data: *mut std::ffi::c_void) -> Self {
        self.context_params.cb_eval_user_data = cb_eval_user_data;
        self
    }

    /// Set the type of pooling.
    ///
    /// # Examples
    ///
    /// ```rust
    /// use llama_cpp_2::context::params::{LlamaContextParams, LlamaPoolingType};
    /// let params = LlamaContextParams::default()
    ///     .with_pooling_type(LlamaPoolingType::Last);
    /// assert_eq!(params.pooling_type(), LlamaPoolingType::Last);
    /// ```
    #[must_use]
    pub fn with_pooling_type(mut self, pooling_type: LlamaPoolingType) -> Self {
        self.context_params.pooling_type = i32::from(pooling_type);
        self
    }

    /// Get the type of pooling.
    ///
    /// # Examples
    ///
    /// ```rust
    /// let params = llama_cpp_2::context::params::LlamaContextParams::default();
    /// assert_eq!(params.pooling_type(), llama_cpp_2::context::params::LlamaPoolingType::Unspecified);
    /// ```
    #[must_use]
    pub fn pooling_type(&self) -> LlamaPoolingType {
        LlamaPoolingType::from(self.context_params.pooling_type)
    }

    /// Set whether to use full sliding window attention
    ///
    /// # Examples
    ///
    /// ```rust
    /// use llama_cpp_2::context::params::LlamaContextParams;
    /// let params = LlamaContextParams::default()
    ///     .with_swa_full(false);
    /// assert_eq!(params.swa_full(), false);
    /// ```
    #[must_use]
    pub fn with_swa_full(mut self, enabled: bool) -> Self {
        self.context_params.swa_full = enabled;
        self
    }

    /// Get whether full sliding window attention is enabled
    ///
    /// # Examples
    ///
    /// ```rust
    /// use llama_cpp_2::context::params::LlamaContextParams;
    /// let params = LlamaContextParams::default();
    /// assert_eq!(params.swa_full(), true);
    /// ```
    #[must_use]
    pub fn swa_full(&self) -> bool {
        self.context_params.swa_full
    }
}

/// Default parameters for `LlamaContext`, as defined in llama.cpp by `llama_context_default_params`.
///
/// ```
/// # use std::num::NonZeroU32;
/// use llama_cpp_2::context::params::{LlamaContextParams, RopeScalingType};
/// let params = LlamaContextParams::default();
/// assert_eq!(params.n_ctx(), NonZeroU32::new(512), "n_ctx should be 512");
/// assert_eq!(params.rope_scaling_type(), RopeScalingType::Unspecified);
/// ```
impl Default for LlamaContextParams {
    fn default() -> Self {
        let context_params = unsafe { llama_cpp_sys_2::llama_context_default_params() };
        Self { context_params }
    }
}
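
// A minimal test sketch covering behavior that is fully determined by this file:
// the enum <-> i32 conversions and the builder setters. The builder test only
// asserts values it sets explicitly, so it does not depend on the upstream
// defaults reported by `llama_context_default_params`.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn rope_scaling_type_round_trips_through_i32() {
        for ty in [
            RopeScalingType::Unspecified,
            RopeScalingType::None,
            RopeScalingType::Linear,
            RopeScalingType::Yarn,
        ] {
            assert_eq!(RopeScalingType::from(i32::from(ty)), ty);
        }
    }

    #[test]
    fn pooling_type_round_trips_through_i32() {
        for ty in [
            LlamaPoolingType::Unspecified,
            LlamaPoolingType::None,
            LlamaPoolingType::Mean,
            LlamaPoolingType::Cls,
            LlamaPoolingType::Last,
            LlamaPoolingType::Rank,
        ] {
            assert_eq!(LlamaPoolingType::from(i32::from(ty)), ty);
        }
    }

    #[test]
    fn builder_setters_are_reflected_by_getters() {
        let params = LlamaContextParams::default()
            .with_n_ctx(NonZeroU32::new(2048))
            .with_n_batch(1024)
            .with_embeddings(true)
            .with_offload_kqv(false);
        assert_eq!(params.n_ctx(), NonZeroU32::new(2048));
        assert_eq!(params.n_batch(), 1024);
        assert!(params.embeddings());
        assert!(!params.offload_kqv());
    }
}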