//! A safe wrapper around `llama_context_params`.
use std::fmt::Debug;
use std::num::NonZeroU32;
/// A rusty wrapper around `rope_scaling_type`.
#[repr(i8)]
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum RopeScalingType {
    /// The scaling type is unspecified
    Unspecified = -1,
    /// No scaling
    None = 0,
    /// Linear scaling
    Linear = 1,
    /// Yarn scaling
    Yarn = 2,
}

/// Create a `RopeScalingType` from a `c_int` - returns `RopeScalingType::Unspecified` if
/// the value is not recognized.
impl From<i32> for RopeScalingType {
    fn from(value: i32) -> Self {
        match value {
            0 => Self::None,
            1 => Self::Linear,
            2 => Self::Yarn,
            // Any unknown value collapses to Unspecified rather than panicking,
            // since the raw value crosses the FFI boundary unchecked.
            _ => Self::Unspecified,
        }
    }
}

/// Create a `c_int` from a `RopeScalingType`.
impl From<RopeScalingType> for i32 {
    fn from(value: RopeScalingType) -> Self {
        match value {
            RopeScalingType::None => 0,
            RopeScalingType::Linear => 1,
            RopeScalingType::Yarn => 2,
            RopeScalingType::Unspecified => -1,
        }
    }
}
43
/// A rusty wrapper around `LLAMA_POOLING_TYPE`.
#[repr(i8)]
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum LlamaPoolingType {
    /// The pooling type is unspecified
    Unspecified = -1,
    /// No pooling
    None = 0,
    /// Mean pooling
    Mean = 1,
    /// CLS pooling
    Cls = 2,
    /// Last pooling
    Last = 3,
}

/// Create a `LlamaPoolingType` from a `c_int` - returns `LlamaPoolingType::Unspecified` if
/// the value is not recognized.
impl From<i32> for LlamaPoolingType {
    fn from(raw: i32) -> Self {
        match raw {
            0 => LlamaPoolingType::None,
            1 => LlamaPoolingType::Mean,
            2 => LlamaPoolingType::Cls,
            3 => LlamaPoolingType::Last,
            // Unknown raw values map to Unspecified instead of panicking.
            _ => LlamaPoolingType::Unspecified,
        }
    }
}

/// Create a `c_int` from a `LlamaPoolingType`.
impl From<LlamaPoolingType> for i32 {
    fn from(pooling: LlamaPoolingType) -> Self {
        // The enum discriminants are exactly the C values (repr(i8)),
        // so a widening cast is an exact conversion.
        pooling as i8 as i32
    }
}
86
87/// A safe wrapper around `llama_context_params`.
88///
89/// Generally this should be created with [`Default::default()`] and then modified with `with_*` methods.
90///
91/// # Examples
92///
93/// ```rust
94/// # use std::num::NonZeroU32;
95/// use llama_cpp_4::context::params::LlamaContextParams;
96///
97///let ctx_params = LlamaContextParams::default()
98///    .with_n_ctx(NonZeroU32::new(2048))
99///    .with_seed(1234);
100///
101/// assert_eq!(ctx_params.seed(), 1234);
102/// assert_eq!(ctx_params.n_ctx(), NonZeroU32::new(2048));
103/// ```
104#[derive(Debug, Clone)]
105#[allow(
106    missing_docs,
107    clippy::struct_excessive_bools,
108    clippy::module_name_repetitions
109)]
110pub struct LlamaContextParams {
111    pub(crate) context_params: llama_cpp_sys_4::llama_context_params,
112}
113
114/// SAFETY: we do not currently allow setting or reading the pointers that cause this to not be automatically send or sync.
115unsafe impl Send for LlamaContextParams {}
116unsafe impl Sync for LlamaContextParams {}
117
118impl LlamaContextParams {
119    /// Set the side of the context
120    ///
121    /// # Examples
122    ///
123    /// ```rust
124    /// # use std::num::NonZeroU32;
125    /// use llama_cpp_4::context::params::LlamaContextParams;
126    /// let params = LlamaContextParams::default();
127    /// let params = params.with_n_ctx(NonZeroU32::new(2048));
128    /// assert_eq!(params.n_ctx(), NonZeroU32::new(2048));
129    /// ```
130    #[must_use]
131    pub fn with_n_ctx(mut self, n_ctx: Option<NonZeroU32>) -> Self {
132        self.context_params.n_ctx = n_ctx.map_or(0, std::num::NonZeroU32::get);
133        self
134    }
135
136    /// Get the size of the context.
137    ///
138    /// [`None`] if the context size is specified by the model and not the context.
139    ///
140    /// # Examples
141    ///
142    /// ```rust
143    /// let params = llama_cpp_4::context::params::LlamaContextParams::default();
144    /// assert_eq!(params.n_ctx(), std::num::NonZeroU32::new(512));
145    #[must_use]
146    pub fn n_ctx(&self) -> Option<NonZeroU32> {
147        NonZeroU32::new(self.context_params.n_ctx)
148    }
149
150    /// Set the `n_batch`
151    ///
152    /// # Examples
153    ///
154    /// ```rust
155    /// # use std::num::NonZeroU32;
156    /// use llama_cpp_4::context::params::LlamaContextParams;
157    /// let params = LlamaContextParams::default()
158    ///     .with_n_batch(2048);
159    /// assert_eq!(params.n_batch(), 2048);
160    /// ```
161    #[must_use]
162    pub fn with_n_batch(mut self, n_batch: u32) -> Self {
163        self.context_params.n_batch = n_batch;
164        self
165    }
166
167    /// Get the `n_batch`
168    ///
169    /// # Examples
170    ///
171    /// ```rust
172    /// use llama_cpp_4::context::params::LlamaContextParams;
173    /// let params = LlamaContextParams::default();
174    /// assert_eq!(params.n_batch(), 2048);
175    /// ```
176    #[must_use]
177    pub fn n_batch(&self) -> u32 {
178        self.context_params.n_batch
179    }
180
181    /// Set the `n_ubatch`
182    ///
183    /// # Examples
184    ///
185    /// ```rust
186    /// # use std::num::NonZeroU32;
187    /// use llama_cpp_4::context::params::LlamaContextParams;
188    /// let params = LlamaContextParams::default()
189    ///     .with_n_ubatch(512);
190    /// assert_eq!(params.n_ubatch(), 512);
191    /// ```
192    #[must_use]
193    pub fn with_n_ubatch(mut self, n_ubatch: u32) -> Self {
194        self.context_params.n_ubatch = n_ubatch;
195        self
196    }
197
198    /// Get the `n_ubatch`
199    ///
200    /// # Examples
201    ///
202    /// ```rust
203    /// use llama_cpp_4::context::params::LlamaContextParams;
204    /// let params = LlamaContextParams::default();
205    /// assert_eq!(params.n_ubatch(), 512);
206    /// ```
207    #[must_use]
208    pub fn n_ubatch(&self) -> u32 {
209        self.context_params.n_ubatch
210    }
211
212    /// Set the `flash_attention` parameter
213    ///
214    /// # Examples
215    ///
216    /// ```rust
217    /// use llama_cpp_4::context::params::LlamaContextParams;
218    /// let params = LlamaContextParams::default()
219    ///     .with_flash_attention(true);
220    /// assert_eq!(params.flash_attention(), true);
221    /// ```
222    #[must_use]
223    pub fn with_flash_attention(mut self, enabled: bool) -> Self {
224        self.context_params.flash_attn_type = if enabled {
225            llama_cpp_sys_4::LLAMA_FLASH_ATTN_TYPE_ENABLED
226        } else {
227            llama_cpp_sys_4::LLAMA_FLASH_ATTN_TYPE_DISABLED
228        };
229        self
230    }
231
232    /// Get the `flash_attention` parameter
233    ///
234    /// # Examples
235    ///
236    /// ```rust
237    /// use llama_cpp_4::context::params::LlamaContextParams;
238    /// let params = LlamaContextParams::default();
239    /// assert_eq!(params.flash_attention(), false);
240    /// ```
241    #[must_use]
242    pub fn flash_attention(&self) -> bool {
243        self.context_params.flash_attn_type == llama_cpp_sys_4::LLAMA_FLASH_ATTN_TYPE_ENABLED
244    }
245
246    /// Set the `offload_kqv` parameter to control offloading KV cache & KQV ops to GPU
247    ///
248    /// # Examples
249    ///
250    /// ```rust
251    /// use llama_cpp_4::context::params::LlamaContextParams;
252    /// let params = LlamaContextParams::default()
253    ///     .with_offload_kqv(false);
254    /// assert_eq!(params.offload_kqv(), false);
255    /// ```
256    #[must_use]
257    pub fn with_offload_kqv(mut self, enabled: bool) -> Self {
258        self.context_params.offload_kqv = enabled;
259        self
260    }
261
262    /// Get the `offload_kqv` parameter
263    ///
264    /// # Examples
265    ///
266    /// ```rust
267    /// use llama_cpp_4::context::params::LlamaContextParams;
268    /// let params = LlamaContextParams::default();
269    /// assert_eq!(params.offload_kqv(), true);
270    /// ```
271    #[must_use]
272    pub fn offload_kqv(&self) -> bool {
273        self.context_params.offload_kqv
274    }
275
276    /// Set the type of rope scaling.
277    ///
278    /// # Examples
279    ///
280    /// ```rust
281    /// use llama_cpp_4::context::params::{LlamaContextParams, RopeScalingType};
282    /// let params = LlamaContextParams::default()
283    ///     .with_rope_scaling_type(RopeScalingType::Linear);
284    /// assert_eq!(params.rope_scaling_type(), RopeScalingType::Linear);
285    /// ```
286    #[must_use]
287    pub fn with_rope_scaling_type(mut self, rope_scaling_type: RopeScalingType) -> Self {
288        self.context_params.rope_scaling_type = i32::from(rope_scaling_type);
289        self
290    }
291
292    /// Get the type of rope scaling.
293    ///
294    /// # Examples
295    ///
296    /// ```rust
297    /// let params = llama_cpp_4::context::params::LlamaContextParams::default();
298    /// assert_eq!(params.rope_scaling_type(), llama_cpp_4::context::params::RopeScalingType::Unspecified);
299    /// ```
300    #[must_use]
301    pub fn rope_scaling_type(&self) -> RopeScalingType {
302        RopeScalingType::from(self.context_params.rope_scaling_type)
303    }
304
305    /// Set the rope frequency base.
306    ///
307    /// # Examples
308    ///
309    /// ```rust
310    /// use llama_cpp_4::context::params::LlamaContextParams;
311    /// let params = LlamaContextParams::default()
312    ///    .with_rope_freq_base(0.5);
313    /// assert_eq!(params.rope_freq_base(), 0.5);
314    /// ```
315    #[must_use]
316    pub fn with_rope_freq_base(mut self, rope_freq_base: f32) -> Self {
317        self.context_params.rope_freq_base = rope_freq_base;
318        self
319    }
320
321    /// Get the rope frequency base.
322    ///
323    /// # Examples
324    ///
325    /// ```rust
326    /// let params = llama_cpp_4::context::params::LlamaContextParams::default();
327    /// assert_eq!(params.rope_freq_base(), 0.0);
328    /// ```
329    #[must_use]
330    pub fn rope_freq_base(&self) -> f32 {
331        self.context_params.rope_freq_base
332    }
333
334    /// Set the rope frequency scale.
335    ///
336    /// # Examples
337    ///
338    /// ```rust
339    /// use llama_cpp_4::context::params::LlamaContextParams;
340    /// let params = LlamaContextParams::default()
341    ///   .with_rope_freq_scale(0.5);
342    /// assert_eq!(params.rope_freq_scale(), 0.5);
343    /// ```
344    #[must_use]
345    pub fn with_rope_freq_scale(mut self, rope_freq_scale: f32) -> Self {
346        self.context_params.rope_freq_scale = rope_freq_scale;
347        self
348    }
349
350    /// Get the rope frequency scale.
351    ///
352    /// # Examples
353    ///
354    /// ```rust
355    /// let params = llama_cpp_4::context::params::LlamaContextParams::default();
356    /// assert_eq!(params.rope_freq_scale(), 0.0);
357    /// ```
358    #[must_use]
359    pub fn rope_freq_scale(&self) -> f32 {
360        self.context_params.rope_freq_scale
361    }
362
363    /// Get the number of threads.
364    ///
365    /// # Examples
366    ///
367    /// ```rust
368    /// let params = llama_cpp_4::context::params::LlamaContextParams::default();
369    /// assert_eq!(params.n_threads(), 4);
370    /// ```
371    #[must_use]
372    pub fn n_threads(&self) -> i32 {
373        self.context_params.n_threads
374    }
375
376    /// Get the number of threads allocated for batches.
377    ///
378    /// # Examples
379    ///
380    /// ```rust
381    /// let params = llama_cpp_4::context::params::LlamaContextParams::default();
382    /// assert_eq!(params.n_threads_batch(), 4);
383    /// ```
384    #[must_use]
385    pub fn n_threads_batch(&self) -> i32 {
386        self.context_params.n_threads_batch
387    }
388
389    /// Set the number of threads.
390    ///
391    /// # Examples
392    ///
393    /// ```rust
394    /// use llama_cpp_4::context::params::LlamaContextParams;
395    /// let params = LlamaContextParams::default()
396    ///    .with_n_threads(8);
397    /// assert_eq!(params.n_threads(), 8);
398    /// ```
399    #[must_use]
400    pub fn with_n_threads(mut self, n_threads: i32) -> Self {
401        self.context_params.n_threads = n_threads;
402        self
403    }
404
405    /// Set the number of threads allocated for batches.
406    ///
407    /// # Examples
408    ///
409    /// ```rust
410    /// use llama_cpp_4::context::params::LlamaContextParams;
411    /// let params = LlamaContextParams::default()
412    ///    .with_n_threads_batch(8);
413    /// assert_eq!(params.n_threads_batch(), 8);
414    /// ```
415    #[must_use]
416    pub fn with_n_threads_batch(mut self, n_threads: i32) -> Self {
417        self.context_params.n_threads_batch = n_threads;
418        self
419    }
420
421    /// Check whether embeddings are enabled
422    ///
423    /// # Examples
424    ///
425    /// ```rust
426    /// let params = llama_cpp_4::context::params::LlamaContextParams::default();
427    /// assert!(!params.embeddings());
428    /// ```
429    #[must_use]
430    pub fn embeddings(&self) -> bool {
431        self.context_params.embeddings
432    }
433
434    /// Enable the use of embeddings
435    ///
436    /// # Examples
437    ///
438    /// ```rust
439    /// use llama_cpp_4::context::params::LlamaContextParams;
440    /// let params = LlamaContextParams::default()
441    ///    .with_embeddings(true);
442    /// assert!(params.embeddings());
443    /// ```
444    #[must_use]
445    pub fn with_embeddings(mut self, embedding: bool) -> Self {
446        self.context_params.embeddings = embedding;
447        self
448    }
449
450    /// Set the evaluation callback.
451    ///
452    /// # Examples
453    ///
454    /// ```no_run
455    /// extern "C" fn cb_eval_fn(
456    ///     t: *mut llama_cpp_sys_4::ggml_tensor,
457    ///     ask: bool,
458    ///     user_data: *mut std::ffi::c_void,
459    /// ) -> bool {
460    ///     false
461    /// }
462    ///
463    /// use llama_cpp_4::context::params::LlamaContextParams;
464    /// let params = LlamaContextParams::default().with_cb_eval(Some(cb_eval_fn));
465    /// ```
466    #[must_use]
467    pub fn with_cb_eval(
468        mut self,
469        cb_eval: llama_cpp_sys_4::ggml_backend_sched_eval_callback,
470    ) -> Self {
471        self.context_params.cb_eval = cb_eval;
472        self
473    }
474
475    /// Set the evaluation callback user data.
476    ///
477    /// # Examples
478    ///
479    /// ```no_run
480    /// use llama_cpp_4::context::params::LlamaContextParams;
481    /// let params = LlamaContextParams::default();
482    /// let user_data = std::ptr::null_mut();
483    /// let params = params.with_cb_eval_user_data(user_data);
484    /// ```
485    #[must_use]
486    pub fn with_cb_eval_user_data(mut self, cb_eval_user_data: *mut std::ffi::c_void) -> Self {
487        self.context_params.cb_eval_user_data = cb_eval_user_data;
488        self
489    }
490
491    /// Set the type of pooling.
492    ///
493    /// # Examples
494    ///
495    /// ```rust
496    /// use llama_cpp_4::context::params::{LlamaContextParams, LlamaPoolingType};
497    /// let params = LlamaContextParams::default()
498    ///     .with_pooling_type(LlamaPoolingType::Last);
499    /// assert_eq!(params.pooling_type(), LlamaPoolingType::Last);
500    /// ```
501    #[must_use]
502    pub fn with_pooling_type(mut self, pooling_type: LlamaPoolingType) -> Self {
503        self.context_params.pooling_type = i32::from(pooling_type);
504        self
505    }
506
507    /// Get the type of pooling.
508    ///
509    /// # Examples
510    ///
511    /// ```rust
512    /// let params = llama_cpp_4::context::params::LlamaContextParams::default();
513    /// assert_eq!(params.pooling_type(), llama_cpp_4::context::params::LlamaPoolingType::Unspecified);
514    /// ```
515    #[must_use]
516    pub fn pooling_type(&self) -> LlamaPoolingType {
517        LlamaPoolingType::from(self.context_params.pooling_type)
518    }
519}
520
521/// Default parameters for `LlamaContext`. (as defined in llama.cpp by `llama_context_default_params`)
522/// ```
523/// # use std::num::NonZeroU32;
524/// use llama_cpp_4::context::params::{LlamaContextParams, RopeScalingType};
525/// let params = LlamaContextParams::default();
526/// assert_eq!(params.n_ctx(), NonZeroU32::new(512), "n_ctx should be 512");
527/// assert_eq!(params.rope_scaling_type(), RopeScalingType::Unspecified);
528/// ```
529impl Default for LlamaContextParams {
530    fn default() -> Self {
531        let context_params = unsafe { llama_cpp_sys_4::llama_context_default_params() };
532        Self { context_params }
533    }
534}