llama_cpp_2/context/params.rs
//! A safe wrapper around `llama_context_params`.
use std::fmt::Debug;
use std::num::NonZeroU32;

/// A rusty wrapper around `rope_scaling_type`.
#[repr(i8)]
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum RopeScalingType {
    /// The scaling type is unspecified
    Unspecified = -1,
    /// No scaling
    None = 0,
    /// Linear scaling
    Linear = 1,
    /// Yarn scaling
    Yarn = 2,
}

/// Create a `RopeScalingType` from a `c_int` - returns `RopeScalingType::Unspecified` if
/// the value is not recognized.
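///
/// # Examples
///
/// ```rust
/// use llama_cpp_2::context::params::RopeScalingType;
/// assert_eq!(RopeScalingType::from(1), RopeScalingType::Linear);
/// // Unrecognized values fall back to `Unspecified`.
/// assert_eq!(RopeScalingType::from(-7), RopeScalingType::Unspecified);
/// ```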
impl From<i32> for RopeScalingType {
    fn from(value: i32) -> Self {
        match value {
            0 => Self::None,
            1 => Self::Linear,
            2 => Self::Yarn,
            _ => Self::Unspecified,
        }
    }
}

/// Create a `c_int` from a `RopeScalingType`.
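///
/// # Examples
///
/// ```rust
/// use llama_cpp_2::context::params::RopeScalingType;
/// assert_eq!(i32::from(RopeScalingType::Yarn), 2);
/// assert_eq!(i32::from(RopeScalingType::Unspecified), -1);
/// ```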
impl From<RopeScalingType> for i32 {
    fn from(value: RopeScalingType) -> Self {
        match value {
            RopeScalingType::None => 0,
            RopeScalingType::Linear => 1,
            RopeScalingType::Yarn => 2,
            RopeScalingType::Unspecified => -1,
        }
    }
}

/// A rusty wrapper around `LLAMA_POOLING_TYPE`.
#[repr(i8)]
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum LlamaPoolingType {
    /// The pooling type is unspecified
    Unspecified = -1,
    /// No pooling
    None = 0,
    /// Mean pooling
    Mean = 1,
    /// CLS pooling
    Cls = 2,
    /// Last pooling
    Last = 3,
    /// Rank pooling
    Rank = 4,
}

/// Create a `LlamaPoolingType` from a `c_int` - returns `LlamaPoolingType::Unspecified` if
/// the value is not recognized.
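///
/// # Examples
///
/// ```rust
/// use llama_cpp_2::context::params::LlamaPoolingType;
/// assert_eq!(LlamaPoolingType::from(3), LlamaPoolingType::Last);
/// // Unrecognized values fall back to `Unspecified`.
/// assert_eq!(LlamaPoolingType::from(42), LlamaPoolingType::Unspecified);
/// ```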
impl From<i32> for LlamaPoolingType {
    fn from(value: i32) -> Self {
        match value {
            0 => Self::None,
            1 => Self::Mean,
            2 => Self::Cls,
            3 => Self::Last,
            4 => Self::Rank,
            _ => Self::Unspecified,
        }
    }
}

/// Create a `c_int` from a `LlamaPoolingType`.
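///
/// # Examples
///
/// ```rust
/// use llama_cpp_2::context::params::LlamaPoolingType;
/// assert_eq!(i32::from(LlamaPoolingType::Cls), 2);
/// assert_eq!(i32::from(LlamaPoolingType::Unspecified), -1);
/// ```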
impl From<LlamaPoolingType> for i32 {
    fn from(value: LlamaPoolingType) -> Self {
        match value {
            LlamaPoolingType::None => 0,
            LlamaPoolingType::Mean => 1,
            LlamaPoolingType::Cls => 2,
            LlamaPoolingType::Last => 3,
            LlamaPoolingType::Rank => 4,
            LlamaPoolingType::Unspecified => -1,
        }
    }
}

/// A safe wrapper around `llama_context_params`.
///
/// Generally this should be created with [`Default::default()`] and then modified with `with_*` methods.
///
/// # Examples
///
/// ```rust
/// # use std::num::NonZeroU32;
/// use llama_cpp_2::context::params::LlamaContextParams;
///
/// let ctx_params = LlamaContextParams::default()
///     .with_n_ctx(NonZeroU32::new(2048));
///
/// assert_eq!(ctx_params.n_ctx(), NonZeroU32::new(2048));
/// ```
#[derive(Debug, Clone)]
#[allow(
    missing_docs,
    clippy::struct_excessive_bools,
    clippy::module_name_repetitions
)]
pub struct LlamaContextParams {
    pub(crate) context_params: llama_cpp_sys_2::llama_context_params,
}

/// SAFETY: we do not currently allow setting or reading the pointers that would otherwise
/// prevent this type from being automatically `Send` or `Sync`.
unsafe impl Send for LlamaContextParams {}
unsafe impl Sync for LlamaContextParams {}

impl LlamaContextParams {
    /// Set the size of the context.
    ///
    /// # Examples
    ///
    /// ```rust
    /// # use std::num::NonZeroU32;
    /// use llama_cpp_2::context::params::LlamaContextParams;
    /// let params = LlamaContextParams::default();
    /// let params = params.with_n_ctx(NonZeroU32::new(2048));
    /// assert_eq!(params.n_ctx(), NonZeroU32::new(2048));
    /// ```
    #[must_use]
    pub fn with_n_ctx(mut self, n_ctx: Option<NonZeroU32>) -> Self {
        self.context_params.n_ctx = n_ctx.map_or(0, std::num::NonZeroU32::get);
        self
    }

    /// Get the size of the context.
    ///
    /// [`None`] if the context size is specified by the model and not the context.
    ///
    /// # Examples
    ///
    /// ```rust
    /// let params = llama_cpp_2::context::params::LlamaContextParams::default();
    /// assert_eq!(params.n_ctx(), std::num::NonZeroU32::new(512));
    /// ```
    #[must_use]
    pub fn n_ctx(&self) -> Option<NonZeroU32> {
        NonZeroU32::new(self.context_params.n_ctx)
    }

    /// Set the `n_batch`.
    ///
    /// # Examples
    ///
    /// ```rust
    /// # use std::num::NonZeroU32;
    /// use llama_cpp_2::context::params::LlamaContextParams;
    /// let params = LlamaContextParams::default()
    ///     .with_n_batch(2048);
    /// assert_eq!(params.n_batch(), 2048);
    /// ```
    #[must_use]
    pub fn with_n_batch(mut self, n_batch: u32) -> Self {
        self.context_params.n_batch = n_batch;
        self
    }

    /// Get the `n_batch`.
    ///
    /// # Examples
    ///
    /// ```rust
    /// use llama_cpp_2::context::params::LlamaContextParams;
    /// let params = LlamaContextParams::default();
    /// assert_eq!(params.n_batch(), 2048);
    /// ```
    #[must_use]
    pub fn n_batch(&self) -> u32 {
        self.context_params.n_batch
    }

    /// Set the `n_ubatch`.
    ///
    /// # Examples
    ///
    /// ```rust
    /// # use std::num::NonZeroU32;
    /// use llama_cpp_2::context::params::LlamaContextParams;
    /// let params = LlamaContextParams::default()
    ///     .with_n_ubatch(512);
    /// assert_eq!(params.n_ubatch(), 512);
    /// ```
    #[must_use]
    pub fn with_n_ubatch(mut self, n_ubatch: u32) -> Self {
        self.context_params.n_ubatch = n_ubatch;
        self
    }

    /// Get the `n_ubatch`.
    ///
    /// # Examples
    ///
    /// ```rust
    /// use llama_cpp_2::context::params::LlamaContextParams;
    /// let params = LlamaContextParams::default();
    /// assert_eq!(params.n_ubatch(), 512);
    /// ```
    #[must_use]
    pub fn n_ubatch(&self) -> u32 {
        self.context_params.n_ubatch
    }

    /// Set the `flash_attention` parameter.
    ///
    /// # Examples
    ///
    /// ```rust
    /// use llama_cpp_2::context::params::LlamaContextParams;
    /// let params = LlamaContextParams::default()
    ///     .with_flash_attention(true);
    /// assert_eq!(params.flash_attention(), true);
    /// ```
    #[must_use]
    pub fn with_flash_attention(mut self, enabled: bool) -> Self {
        self.context_params.flash_attn = enabled;
        self
    }

    /// Get the `flash_attention` parameter.
    ///
    /// # Examples
    ///
    /// ```rust
    /// use llama_cpp_2::context::params::LlamaContextParams;
    /// let params = LlamaContextParams::default();
    /// assert_eq!(params.flash_attention(), false);
    /// ```
    #[must_use]
    pub fn flash_attention(&self) -> bool {
        self.context_params.flash_attn
    }

    /// Set the `offload_kqv` parameter to control offloading KV cache & KQV ops to GPU.
    ///
    /// # Examples
    ///
    /// ```rust
    /// use llama_cpp_2::context::params::LlamaContextParams;
    /// let params = LlamaContextParams::default()
    ///     .with_offload_kqv(false);
    /// assert_eq!(params.offload_kqv(), false);
    /// ```
    #[must_use]
    pub fn with_offload_kqv(mut self, enabled: bool) -> Self {
        self.context_params.offload_kqv = enabled;
        self
    }

    /// Get the `offload_kqv` parameter.
    ///
    /// # Examples
    ///
    /// ```rust
    /// use llama_cpp_2::context::params::LlamaContextParams;
    /// let params = LlamaContextParams::default();
    /// assert_eq!(params.offload_kqv(), true);
    /// ```
    #[must_use]
    pub fn offload_kqv(&self) -> bool {
        self.context_params.offload_kqv
    }

    /// Set the type of rope scaling.
    ///
    /// # Examples
    ///
    /// ```rust
    /// use llama_cpp_2::context::params::{LlamaContextParams, RopeScalingType};
    /// let params = LlamaContextParams::default()
    ///     .with_rope_scaling_type(RopeScalingType::Linear);
    /// assert_eq!(params.rope_scaling_type(), RopeScalingType::Linear);
    /// ```
    #[must_use]
    pub fn with_rope_scaling_type(mut self, rope_scaling_type: RopeScalingType) -> Self {
        self.context_params.rope_scaling_type = i32::from(rope_scaling_type);
        self
    }

    /// Get the type of rope scaling.
    ///
    /// # Examples
    ///
    /// ```rust
    /// let params = llama_cpp_2::context::params::LlamaContextParams::default();
    /// assert_eq!(params.rope_scaling_type(), llama_cpp_2::context::params::RopeScalingType::Unspecified);
    /// ```
    #[must_use]
    pub fn rope_scaling_type(&self) -> RopeScalingType {
        RopeScalingType::from(self.context_params.rope_scaling_type)
    }

    /// Set the rope frequency base.
    ///
    /// # Examples
    ///
    /// ```rust
    /// use llama_cpp_2::context::params::LlamaContextParams;
    /// let params = LlamaContextParams::default()
    ///     .with_rope_freq_base(0.5);
    /// assert_eq!(params.rope_freq_base(), 0.5);
    /// ```
    #[must_use]
    pub fn with_rope_freq_base(mut self, rope_freq_base: f32) -> Self {
        self.context_params.rope_freq_base = rope_freq_base;
        self
    }

    /// Get the rope frequency base.
    ///
    /// # Examples
    ///
    /// ```rust
    /// let params = llama_cpp_2::context::params::LlamaContextParams::default();
    /// assert_eq!(params.rope_freq_base(), 0.0);
    /// ```
    #[must_use]
    pub fn rope_freq_base(&self) -> f32 {
        self.context_params.rope_freq_base
    }

    /// Set the rope frequency scale.
    ///
    /// # Examples
    ///
    /// ```rust
    /// use llama_cpp_2::context::params::LlamaContextParams;
    /// let params = LlamaContextParams::default()
    ///     .with_rope_freq_scale(0.5);
    /// assert_eq!(params.rope_freq_scale(), 0.5);
    /// ```
    #[must_use]
    pub fn with_rope_freq_scale(mut self, rope_freq_scale: f32) -> Self {
        self.context_params.rope_freq_scale = rope_freq_scale;
        self
    }

    /// Get the rope frequency scale.
    ///
    /// # Examples
    ///
    /// ```rust
    /// let params = llama_cpp_2::context::params::LlamaContextParams::default();
    /// assert_eq!(params.rope_freq_scale(), 0.0);
    /// ```
    #[must_use]
    pub fn rope_freq_scale(&self) -> f32 {
        self.context_params.rope_freq_scale
    }

    /// Get the number of threads.
    ///
    /// # Examples
    ///
    /// ```rust
    /// let params = llama_cpp_2::context::params::LlamaContextParams::default();
    /// assert_eq!(params.n_threads(), 4);
    /// ```
    #[must_use]
    pub fn n_threads(&self) -> i32 {
        self.context_params.n_threads
    }

    /// Get the number of threads allocated for batches.
    ///
    /// # Examples
    ///
    /// ```rust
    /// let params = llama_cpp_2::context::params::LlamaContextParams::default();
    /// assert_eq!(params.n_threads_batch(), 4);
    /// ```
    #[must_use]
    pub fn n_threads_batch(&self) -> i32 {
        self.context_params.n_threads_batch
    }

    /// Set the number of threads.
    ///
    /// # Examples
    ///
    /// ```rust
    /// use llama_cpp_2::context::params::LlamaContextParams;
    /// let params = LlamaContextParams::default()
    ///     .with_n_threads(8);
    /// assert_eq!(params.n_threads(), 8);
    /// ```
    #[must_use]
    pub fn with_n_threads(mut self, n_threads: i32) -> Self {
        self.context_params.n_threads = n_threads;
        self
    }

    /// Set the number of threads allocated for batches.
    ///
    /// # Examples
    ///
    /// ```rust
    /// use llama_cpp_2::context::params::LlamaContextParams;
    /// let params = LlamaContextParams::default()
    ///     .with_n_threads_batch(8);
    /// assert_eq!(params.n_threads_batch(), 8);
    /// ```
    #[must_use]
    pub fn with_n_threads_batch(mut self, n_threads: i32) -> Self {
        self.context_params.n_threads_batch = n_threads;
        self
    }

    /// Check whether embeddings are enabled.
    ///
    /// # Examples
    ///
    /// ```rust
    /// let params = llama_cpp_2::context::params::LlamaContextParams::default();
    /// assert!(!params.embeddings());
    /// ```
    #[must_use]
    pub fn embeddings(&self) -> bool {
        self.context_params.embeddings
    }

    /// Enable the use of embeddings.
    ///
    /// # Examples
    ///
    /// ```rust
    /// use llama_cpp_2::context::params::LlamaContextParams;
    /// let params = LlamaContextParams::default()
    ///     .with_embeddings(true);
    /// assert!(params.embeddings());
    /// ```
    #[must_use]
    pub fn with_embeddings(mut self, embedding: bool) -> Self {
        self.context_params.embeddings = embedding;
        self
    }

    /// Set the evaluation callback.
    ///
    /// # Examples
    ///
    /// ```no_run
    /// extern "C" fn cb_eval_fn(
    ///     t: *mut llama_cpp_sys_2::ggml_tensor,
    ///     ask: bool,
    ///     user_data: *mut std::ffi::c_void,
    /// ) -> bool {
    ///     false
    /// }
    ///
    /// use llama_cpp_2::context::params::LlamaContextParams;
    /// let params = LlamaContextParams::default().with_cb_eval(Some(cb_eval_fn));
    /// ```
    #[must_use]
    pub fn with_cb_eval(
        mut self,
        cb_eval: llama_cpp_sys_2::ggml_backend_sched_eval_callback,
    ) -> Self {
        self.context_params.cb_eval = cb_eval;
        self
    }

    /// Set the evaluation callback user data.
    ///
    /// # Examples
    ///
    /// ```no_run
    /// use llama_cpp_2::context::params::LlamaContextParams;
    /// let params = LlamaContextParams::default();
    /// let user_data = std::ptr::null_mut();
    /// let params = params.with_cb_eval_user_data(user_data);
    /// ```
    #[must_use]
    pub fn with_cb_eval_user_data(mut self, cb_eval_user_data: *mut std::ffi::c_void) -> Self {
        self.context_params.cb_eval_user_data = cb_eval_user_data;
        self
    }

    /// Set the type of pooling.
    ///
    /// # Examples
    ///
    /// ```rust
    /// use llama_cpp_2::context::params::{LlamaContextParams, LlamaPoolingType};
    /// let params = LlamaContextParams::default()
    ///     .with_pooling_type(LlamaPoolingType::Last);
    /// assert_eq!(params.pooling_type(), LlamaPoolingType::Last);
    /// ```
    #[must_use]
    pub fn with_pooling_type(mut self, pooling_type: LlamaPoolingType) -> Self {
        self.context_params.pooling_type = i32::from(pooling_type);
        self
    }

    /// Get the type of pooling.
    ///
    /// # Examples
    ///
    /// ```rust
    /// let params = llama_cpp_2::context::params::LlamaContextParams::default();
    /// assert_eq!(params.pooling_type(), llama_cpp_2::context::params::LlamaPoolingType::Unspecified);
    /// ```
    #[must_use]
    pub fn pooling_type(&self) -> LlamaPoolingType {
        LlamaPoolingType::from(self.context_params.pooling_type)
    }

    /// Set whether to use full sliding window attention.
    ///
    /// # Examples
    ///
    /// ```rust
    /// use llama_cpp_2::context::params::LlamaContextParams;
    /// let params = LlamaContextParams::default()
    ///     .with_swa_full(false);
    /// assert_eq!(params.swa_full(), false);
    /// ```
    #[must_use]
    pub fn with_swa_full(mut self, enabled: bool) -> Self {
        self.context_params.swa_full = enabled;
        self
    }

    /// Get whether full sliding window attention is enabled.
    ///
    /// # Examples
    ///
    /// ```rust
    /// use llama_cpp_2::context::params::LlamaContextParams;
    /// let params = LlamaContextParams::default();
    /// assert_eq!(params.swa_full(), true);
    /// ```
    #[must_use]
    pub fn swa_full(&self) -> bool {
        self.context_params.swa_full
    }
}

/// Default parameters for `LlamaContext` (as defined in llama.cpp by `llama_context_default_params`).
///
/// # Examples
///
/// ```rust
/// # use std::num::NonZeroU32;
/// use llama_cpp_2::context::params::{LlamaContextParams, RopeScalingType};
/// let params = LlamaContextParams::default();
/// assert_eq!(params.n_ctx(), NonZeroU32::new(512), "n_ctx should be 512");
/// assert_eq!(params.rope_scaling_type(), RopeScalingType::Unspecified);
/// ```
impl Default for LlamaContextParams {
    fn default() -> Self {
        let context_params = unsafe { llama_cpp_sys_2::llama_context_default_params() };
        Self { context_params }
    }
}
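
// Round-trip sanity checks for the `From` conversions above: a minimal sketch, not an
// exhaustive test suite. It only covers the variants and integer mappings listed in this
// file, and touches no FFI.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn rope_scaling_type_roundtrip() {
        for ty in [
            RopeScalingType::Unspecified,
            RopeScalingType::None,
            RopeScalingType::Linear,
            RopeScalingType::Yarn,
        ] {
            // Converting to `i32` and back should yield the same variant.
            assert_eq!(RopeScalingType::from(i32::from(ty)), ty);
        }
    }

    #[test]
    fn pooling_type_roundtrip() {
        for ty in [
            LlamaPoolingType::Unspecified,
            LlamaPoolingType::None,
            LlamaPoolingType::Mean,
            LlamaPoolingType::Cls,
            LlamaPoolingType::Last,
            LlamaPoolingType::Rank,
        ] {
            // Converting to `i32` and back should yield the same variant.
            assert_eq!(LlamaPoolingType::from(i32::from(ty)), ty);
        }
    }
}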