llama_cpp_2/context/params.rs
//! A safe wrapper around `llama_context_params`.
use std::fmt::Debug;
use std::num::NonZeroU32;

/// A rusty wrapper around `rope_scaling_type`.
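///
/// The `i32` conversions below round-trip for the known variants; any
/// unrecognized raw value falls back to [`RopeScalingType::Unspecified`]:
///
/// ```rust
/// use llama_cpp_2::context::params::RopeScalingType;
///
/// assert_eq!(i32::from(RopeScalingType::Yarn), 2);
/// assert_eq!(RopeScalingType::from(2), RopeScalingType::Yarn);
/// // values with no matching variant map to `Unspecified`
/// assert_eq!(RopeScalingType::from(42), RopeScalingType::Unspecified);
/// ```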
#[repr(i8)]
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum RopeScalingType {
    /// The scaling type is unspecified
    Unspecified = -1,
    /// No scaling
    None = 0,
    /// Linear scaling
    Linear = 1,
    /// Yarn scaling
    Yarn = 2,
}

/// Create a `RopeScalingType` from a `c_int` - returns `RopeScalingType::Unspecified` if
/// the value is not recognized.
impl From<i32> for RopeScalingType {
    fn from(value: i32) -> Self {
        match value {
            0 => Self::None,
            1 => Self::Linear,
            2 => Self::Yarn,
            _ => Self::Unspecified,
        }
    }
}

/// Create a `c_int` from a `RopeScalingType`.
impl From<RopeScalingType> for i32 {
    fn from(value: RopeScalingType) -> Self {
        match value {
            RopeScalingType::None => 0,
            RopeScalingType::Linear => 1,
            RopeScalingType::Yarn => 2,
            RopeScalingType::Unspecified => -1,
        }
    }
}

/// A rusty wrapper around `LLAMA_POOLING_TYPE`.
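///
/// As with [`RopeScalingType`], the `i32` conversions round-trip for known
/// variants and fall back to [`LlamaPoolingType::Unspecified`] otherwise:
///
/// ```rust
/// use llama_cpp_2::context::params::LlamaPoolingType;
///
/// assert_eq!(i32::from(LlamaPoolingType::Cls), 2);
/// assert_eq!(LlamaPoolingType::from(2), LlamaPoolingType::Cls);
/// assert_eq!(LlamaPoolingType::from(-7), LlamaPoolingType::Unspecified);
/// ```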
#[repr(i8)]
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum LlamaPoolingType {
    /// The pooling type is unspecified
    Unspecified = -1,
    /// No pooling
    None = 0,
    /// Mean pooling
    Mean = 1,
    /// CLS pooling
    Cls = 2,
    /// Last pooling
    Last = 3,
    /// Rank pooling
    Rank = 4,
}

/// Create a `LlamaPoolingType` from a `c_int` - returns `LlamaPoolingType::Unspecified` if
/// the value is not recognized.
impl From<i32> for LlamaPoolingType {
    fn from(value: i32) -> Self {
        match value {
            0 => Self::None,
            1 => Self::Mean,
            2 => Self::Cls,
            3 => Self::Last,
            4 => Self::Rank,
            _ => Self::Unspecified,
        }
    }
}

/// Create a `c_int` from a `LlamaPoolingType`.
impl From<LlamaPoolingType> for i32 {
    fn from(value: LlamaPoolingType) -> Self {
        match value {
            LlamaPoolingType::None => 0,
            LlamaPoolingType::Mean => 1,
            LlamaPoolingType::Cls => 2,
            LlamaPoolingType::Last => 3,
            LlamaPoolingType::Rank => 4,
            LlamaPoolingType::Unspecified => -1,
        }
    }
}

/// A rusty wrapper around `ggml_type` for KV cache types.
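///
/// Conversions to and from the raw `ggml_type` round-trip for every listed
/// variant; raw values this wrapper does not know about are preserved in
/// [`KvCacheType::Unknown`] rather than lost. A quick round-trip sketch:
///
/// ```rust
/// use llama_cpp_2::context::params::KvCacheType;
///
/// let raw: llama_cpp_sys_2::ggml_type = KvCacheType::Q8_0.into();
/// assert_eq!(KvCacheType::from(raw), KvCacheType::Q8_0);
/// ```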
#[allow(non_camel_case_types, missing_docs)]
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum KvCacheType {
    /// Represents an unknown or not-yet-mapped `ggml_type` and carries the raw value.
    /// When passed through FFI, the raw value is used as-is (if llama.cpp supports it,
    /// the runtime will operate with that type).
    /// This variant preserves API compatibility when new `ggml_type` values are
    /// introduced in the future.
    Unknown(llama_cpp_sys_2::ggml_type),
    F32,
    F16,
    Q4_0,
    Q4_1,
    Q5_0,
    Q5_1,
    Q8_0,
    Q8_1,
    Q2_K,
    Q3_K,
    Q4_K,
    Q5_K,
    Q6_K,
    Q8_K,
    IQ2_XXS,
    IQ2_XS,
    IQ3_XXS,
    IQ1_S,
    IQ4_NL,
    IQ3_S,
    IQ2_S,
    IQ4_XS,
    I8,
    I16,
    I32,
    I64,
    F64,
    IQ1_M,
    BF16,
    TQ1_0,
    TQ2_0,
    MXFP4,
}

impl From<KvCacheType> for llama_cpp_sys_2::ggml_type {
    fn from(value: KvCacheType) -> Self {
        match value {
            KvCacheType::Unknown(raw) => raw,
            KvCacheType::F32 => llama_cpp_sys_2::GGML_TYPE_F32,
            KvCacheType::F16 => llama_cpp_sys_2::GGML_TYPE_F16,
            KvCacheType::Q4_0 => llama_cpp_sys_2::GGML_TYPE_Q4_0,
            KvCacheType::Q4_1 => llama_cpp_sys_2::GGML_TYPE_Q4_1,
            KvCacheType::Q5_0 => llama_cpp_sys_2::GGML_TYPE_Q5_0,
            KvCacheType::Q5_1 => llama_cpp_sys_2::GGML_TYPE_Q5_1,
            KvCacheType::Q8_0 => llama_cpp_sys_2::GGML_TYPE_Q8_0,
            KvCacheType::Q8_1 => llama_cpp_sys_2::GGML_TYPE_Q8_1,
            KvCacheType::Q2_K => llama_cpp_sys_2::GGML_TYPE_Q2_K,
            KvCacheType::Q3_K => llama_cpp_sys_2::GGML_TYPE_Q3_K,
            KvCacheType::Q4_K => llama_cpp_sys_2::GGML_TYPE_Q4_K,
            KvCacheType::Q5_K => llama_cpp_sys_2::GGML_TYPE_Q5_K,
            KvCacheType::Q6_K => llama_cpp_sys_2::GGML_TYPE_Q6_K,
            KvCacheType::Q8_K => llama_cpp_sys_2::GGML_TYPE_Q8_K,
            KvCacheType::IQ2_XXS => llama_cpp_sys_2::GGML_TYPE_IQ2_XXS,
            KvCacheType::IQ2_XS => llama_cpp_sys_2::GGML_TYPE_IQ2_XS,
            KvCacheType::IQ3_XXS => llama_cpp_sys_2::GGML_TYPE_IQ3_XXS,
            KvCacheType::IQ1_S => llama_cpp_sys_2::GGML_TYPE_IQ1_S,
            KvCacheType::IQ4_NL => llama_cpp_sys_2::GGML_TYPE_IQ4_NL,
            KvCacheType::IQ3_S => llama_cpp_sys_2::GGML_TYPE_IQ3_S,
            KvCacheType::IQ2_S => llama_cpp_sys_2::GGML_TYPE_IQ2_S,
            KvCacheType::IQ4_XS => llama_cpp_sys_2::GGML_TYPE_IQ4_XS,
            KvCacheType::I8 => llama_cpp_sys_2::GGML_TYPE_I8,
            KvCacheType::I16 => llama_cpp_sys_2::GGML_TYPE_I16,
            KvCacheType::I32 => llama_cpp_sys_2::GGML_TYPE_I32,
            KvCacheType::I64 => llama_cpp_sys_2::GGML_TYPE_I64,
            KvCacheType::F64 => llama_cpp_sys_2::GGML_TYPE_F64,
            KvCacheType::IQ1_M => llama_cpp_sys_2::GGML_TYPE_IQ1_M,
            KvCacheType::BF16 => llama_cpp_sys_2::GGML_TYPE_BF16,
            KvCacheType::TQ1_0 => llama_cpp_sys_2::GGML_TYPE_TQ1_0,
            KvCacheType::TQ2_0 => llama_cpp_sys_2::GGML_TYPE_TQ2_0,
            KvCacheType::MXFP4 => llama_cpp_sys_2::GGML_TYPE_MXFP4,
        }
    }
}

impl From<llama_cpp_sys_2::ggml_type> for KvCacheType {
    fn from(value: llama_cpp_sys_2::ggml_type) -> Self {
        match value {
            x if x == llama_cpp_sys_2::GGML_TYPE_F32 => KvCacheType::F32,
            x if x == llama_cpp_sys_2::GGML_TYPE_F16 => KvCacheType::F16,
            x if x == llama_cpp_sys_2::GGML_TYPE_Q4_0 => KvCacheType::Q4_0,
            x if x == llama_cpp_sys_2::GGML_TYPE_Q4_1 => KvCacheType::Q4_1,
            x if x == llama_cpp_sys_2::GGML_TYPE_Q5_0 => KvCacheType::Q5_0,
            x if x == llama_cpp_sys_2::GGML_TYPE_Q5_1 => KvCacheType::Q5_1,
            x if x == llama_cpp_sys_2::GGML_TYPE_Q8_0 => KvCacheType::Q8_0,
            x if x == llama_cpp_sys_2::GGML_TYPE_Q8_1 => KvCacheType::Q8_1,
            x if x == llama_cpp_sys_2::GGML_TYPE_Q2_K => KvCacheType::Q2_K,
            x if x == llama_cpp_sys_2::GGML_TYPE_Q3_K => KvCacheType::Q3_K,
            x if x == llama_cpp_sys_2::GGML_TYPE_Q4_K => KvCacheType::Q4_K,
            x if x == llama_cpp_sys_2::GGML_TYPE_Q5_K => KvCacheType::Q5_K,
            x if x == llama_cpp_sys_2::GGML_TYPE_Q6_K => KvCacheType::Q6_K,
            x if x == llama_cpp_sys_2::GGML_TYPE_Q8_K => KvCacheType::Q8_K,
            x if x == llama_cpp_sys_2::GGML_TYPE_IQ2_XXS => KvCacheType::IQ2_XXS,
            x if x == llama_cpp_sys_2::GGML_TYPE_IQ2_XS => KvCacheType::IQ2_XS,
            x if x == llama_cpp_sys_2::GGML_TYPE_IQ3_XXS => KvCacheType::IQ3_XXS,
            x if x == llama_cpp_sys_2::GGML_TYPE_IQ1_S => KvCacheType::IQ1_S,
            x if x == llama_cpp_sys_2::GGML_TYPE_IQ4_NL => KvCacheType::IQ4_NL,
            x if x == llama_cpp_sys_2::GGML_TYPE_IQ3_S => KvCacheType::IQ3_S,
            x if x == llama_cpp_sys_2::GGML_TYPE_IQ2_S => KvCacheType::IQ2_S,
            x if x == llama_cpp_sys_2::GGML_TYPE_IQ4_XS => KvCacheType::IQ4_XS,
            x if x == llama_cpp_sys_2::GGML_TYPE_I8 => KvCacheType::I8,
            x if x == llama_cpp_sys_2::GGML_TYPE_I16 => KvCacheType::I16,
            x if x == llama_cpp_sys_2::GGML_TYPE_I32 => KvCacheType::I32,
            x if x == llama_cpp_sys_2::GGML_TYPE_I64 => KvCacheType::I64,
            x if x == llama_cpp_sys_2::GGML_TYPE_F64 => KvCacheType::F64,
            x if x == llama_cpp_sys_2::GGML_TYPE_IQ1_M => KvCacheType::IQ1_M,
            x if x == llama_cpp_sys_2::GGML_TYPE_BF16 => KvCacheType::BF16,
            x if x == llama_cpp_sys_2::GGML_TYPE_TQ1_0 => KvCacheType::TQ1_0,
            x if x == llama_cpp_sys_2::GGML_TYPE_TQ2_0 => KvCacheType::TQ2_0,
            x if x == llama_cpp_sys_2::GGML_TYPE_MXFP4 => KvCacheType::MXFP4,
            _ => KvCacheType::Unknown(value),
        }
    }
}

/// A safe wrapper around `llama_context_params`.
///
/// Generally this should be created with [`Default::default()`] and then modified with `with_*` methods.
///
/// # Examples
///
/// ```rust
/// # use std::num::NonZeroU32;
/// use llama_cpp_2::context::params::LlamaContextParams;
///
/// let ctx_params = LlamaContextParams::default()
///     .with_n_ctx(NonZeroU32::new(2048));
///
/// assert_eq!(ctx_params.n_ctx(), NonZeroU32::new(2048));
/// ```
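///
/// The `with_*` setters consume and return `Self`, so a full configuration can
/// be chained in one expression (the values below are purely illustrative):
///
/// ```rust
/// # use std::num::NonZeroU32;
/// use llama_cpp_2::context::params::LlamaContextParams;
///
/// let ctx_params = LlamaContextParams::default()
///     .with_n_ctx(NonZeroU32::new(4096))
///     .with_n_batch(512)
///     .with_n_threads(8)
///     .with_embeddings(true);
///
/// assert_eq!(ctx_params.n_ctx(), NonZeroU32::new(4096));
/// assert_eq!(ctx_params.n_batch(), 512);
/// assert_eq!(ctx_params.n_threads(), 8);
/// assert!(ctx_params.embeddings());
/// ```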
#[derive(Debug, Clone)]
#[allow(
    missing_docs,
    clippy::struct_excessive_bools,
    clippy::module_name_repetitions
)]
pub struct LlamaContextParams {
    pub(crate) context_params: llama_cpp_sys_2::llama_context_params,
}

/// SAFETY: we do not currently allow setting or reading the pointers that would otherwise prevent this type from being automatically `Send` or `Sync`.
unsafe impl Send for LlamaContextParams {}
unsafe impl Sync for LlamaContextParams {}

impl LlamaContextParams {
    /// Set the size of the context
    ///
    /// # Examples
    ///
    /// ```rust
    /// # use std::num::NonZeroU32;
    /// use llama_cpp_2::context::params::LlamaContextParams;
    /// let params = LlamaContextParams::default();
    /// let params = params.with_n_ctx(NonZeroU32::new(2048));
    /// assert_eq!(params.n_ctx(), NonZeroU32::new(2048));
    /// ```
    #[must_use]
    pub fn with_n_ctx(mut self, n_ctx: Option<NonZeroU32>) -> Self {
        self.context_params.n_ctx = n_ctx.map_or(0, std::num::NonZeroU32::get);
        self
    }

    /// Get the size of the context.
    ///
    /// [`None`] if the context size is specified by the model and not the context.
    ///
    /// # Examples
    ///
    /// ```rust
    /// let params = llama_cpp_2::context::params::LlamaContextParams::default();
    /// assert_eq!(params.n_ctx(), std::num::NonZeroU32::new(512));
    /// ```
    #[must_use]
    pub fn n_ctx(&self) -> Option<NonZeroU32> {
        NonZeroU32::new(self.context_params.n_ctx)
    }

    /// Set the `n_batch`
    ///
    /// # Examples
    ///
    /// ```rust
    /// # use std::num::NonZeroU32;
    /// use llama_cpp_2::context::params::LlamaContextParams;
    /// let params = LlamaContextParams::default()
    ///     .with_n_batch(2048);
    /// assert_eq!(params.n_batch(), 2048);
    /// ```
    #[must_use]
    pub fn with_n_batch(mut self, n_batch: u32) -> Self {
        self.context_params.n_batch = n_batch;
        self
    }

    /// Get the `n_batch`
    ///
    /// # Examples
    ///
    /// ```rust
    /// use llama_cpp_2::context::params::LlamaContextParams;
    /// let params = LlamaContextParams::default();
    /// assert_eq!(params.n_batch(), 2048);
    /// ```
    #[must_use]
    pub fn n_batch(&self) -> u32 {
        self.context_params.n_batch
    }

    /// Set the `n_ubatch`
    ///
    /// # Examples
    ///
    /// ```rust
    /// # use std::num::NonZeroU32;
    /// use llama_cpp_2::context::params::LlamaContextParams;
    /// let params = LlamaContextParams::default()
    ///     .with_n_ubatch(512);
    /// assert_eq!(params.n_ubatch(), 512);
    /// ```
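    ///
    /// Note: `n_ubatch` is the micro-batch (physical) size, while `n_batch` is the
    /// logical batch size; upstream llama.cpp generally expects `n_ubatch` to be at
    /// most `n_batch`.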
    #[must_use]
    pub fn with_n_ubatch(mut self, n_ubatch: u32) -> Self {
        self.context_params.n_ubatch = n_ubatch;
        self
    }

    /// Get the `n_ubatch`
    ///
    /// # Examples
    ///
    /// ```rust
    /// use llama_cpp_2::context::params::LlamaContextParams;
    /// let params = LlamaContextParams::default();
    /// assert_eq!(params.n_ubatch(), 512);
    /// ```
    #[must_use]
    pub fn n_ubatch(&self) -> u32 {
        self.context_params.n_ubatch
    }

    /// Set the `flash_attention` parameter
    ///
    /// # Examples
    ///
    /// ```rust
    /// use llama_cpp_2::context::params::LlamaContextParams;
    /// let params = LlamaContextParams::default()
    ///     .with_flash_attention(true);
    /// assert_eq!(params.flash_attention(), true);
    /// ```
    #[must_use]
    pub fn with_flash_attention(mut self, enabled: bool) -> Self {
        self.context_params.flash_attn = enabled;
        self
    }

    /// Get the `flash_attention` parameter
    ///
    /// # Examples
    ///
    /// ```rust
    /// use llama_cpp_2::context::params::LlamaContextParams;
    /// let params = LlamaContextParams::default();
    /// assert_eq!(params.flash_attention(), false);
    /// ```
    #[must_use]
    pub fn flash_attention(&self) -> bool {
        self.context_params.flash_attn
    }

    /// Set the `offload_kqv` parameter to control offloading KV cache & KQV ops to GPU
    ///
    /// # Examples
    ///
    /// ```rust
    /// use llama_cpp_2::context::params::LlamaContextParams;
    /// let params = LlamaContextParams::default()
    ///     .with_offload_kqv(false);
    /// assert_eq!(params.offload_kqv(), false);
    /// ```
    #[must_use]
    pub fn with_offload_kqv(mut self, enabled: bool) -> Self {
        self.context_params.offload_kqv = enabled;
        self
    }

    /// Get the `offload_kqv` parameter
    ///
    /// # Examples
    ///
    /// ```rust
    /// use llama_cpp_2::context::params::LlamaContextParams;
    /// let params = LlamaContextParams::default();
    /// assert_eq!(params.offload_kqv(), true);
    /// ```
    #[must_use]
    pub fn offload_kqv(&self) -> bool {
        self.context_params.offload_kqv
    }

    /// Set the type of rope scaling.
    ///
    /// # Examples
    ///
    /// ```rust
    /// use llama_cpp_2::context::params::{LlamaContextParams, RopeScalingType};
    /// let params = LlamaContextParams::default()
    ///     .with_rope_scaling_type(RopeScalingType::Linear);
    /// assert_eq!(params.rope_scaling_type(), RopeScalingType::Linear);
    /// ```
    #[must_use]
    pub fn with_rope_scaling_type(mut self, rope_scaling_type: RopeScalingType) -> Self {
        self.context_params.rope_scaling_type = i32::from(rope_scaling_type);
        self
    }

    /// Get the type of rope scaling.
    ///
    /// # Examples
    ///
    /// ```rust
    /// let params = llama_cpp_2::context::params::LlamaContextParams::default();
    /// assert_eq!(params.rope_scaling_type(), llama_cpp_2::context::params::RopeScalingType::Unspecified);
    /// ```
    #[must_use]
    pub fn rope_scaling_type(&self) -> RopeScalingType {
        RopeScalingType::from(self.context_params.rope_scaling_type)
    }

    /// Set the rope frequency base.
    ///
    /// # Examples
    ///
    /// ```rust
    /// use llama_cpp_2::context::params::LlamaContextParams;
    /// let params = LlamaContextParams::default()
    ///     .with_rope_freq_base(0.5);
    /// assert_eq!(params.rope_freq_base(), 0.5);
    /// ```
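    ///
    /// Following llama.cpp's convention, a value of `0.0` (the default) means the
    /// frequency base is taken from the model itself.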
    #[must_use]
    pub fn with_rope_freq_base(mut self, rope_freq_base: f32) -> Self {
        self.context_params.rope_freq_base = rope_freq_base;
        self
    }

    /// Get the rope frequency base.
    ///
    /// # Examples
    ///
    /// ```rust
    /// let params = llama_cpp_2::context::params::LlamaContextParams::default();
    /// assert_eq!(params.rope_freq_base(), 0.0);
    /// ```
    #[must_use]
    pub fn rope_freq_base(&self) -> f32 {
        self.context_params.rope_freq_base
    }

    /// Set the rope frequency scale.
    ///
    /// # Examples
    ///
    /// ```rust
    /// use llama_cpp_2::context::params::LlamaContextParams;
    /// let params = LlamaContextParams::default()
    ///     .with_rope_freq_scale(0.5);
    /// assert_eq!(params.rope_freq_scale(), 0.5);
    /// ```
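    ///
    /// As with [`Self::with_rope_freq_base`], `0.0` (the default) defers to the
    /// value stored in the model.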
    #[must_use]
    pub fn with_rope_freq_scale(mut self, rope_freq_scale: f32) -> Self {
        self.context_params.rope_freq_scale = rope_freq_scale;
        self
    }

    /// Get the rope frequency scale.
    ///
    /// # Examples
    ///
    /// ```rust
    /// let params = llama_cpp_2::context::params::LlamaContextParams::default();
    /// assert_eq!(params.rope_freq_scale(), 0.0);
    /// ```
    #[must_use]
    pub fn rope_freq_scale(&self) -> f32 {
        self.context_params.rope_freq_scale
    }

    /// Get the number of threads.
    ///
    /// # Examples
    ///
    /// ```rust
    /// let params = llama_cpp_2::context::params::LlamaContextParams::default();
    /// assert_eq!(params.n_threads(), 4);
    /// ```
    #[must_use]
    pub fn n_threads(&self) -> i32 {
        self.context_params.n_threads
    }

    /// Get the number of threads allocated for batches.
    ///
    /// # Examples
    ///
    /// ```rust
    /// let params = llama_cpp_2::context::params::LlamaContextParams::default();
    /// assert_eq!(params.n_threads_batch(), 4);
    /// ```
    #[must_use]
    pub fn n_threads_batch(&self) -> i32 {
        self.context_params.n_threads_batch
    }

    /// Set the number of threads.
    ///
    /// # Examples
    ///
    /// ```rust
    /// use llama_cpp_2::context::params::LlamaContextParams;
    /// let params = LlamaContextParams::default()
    ///     .with_n_threads(8);
    /// assert_eq!(params.n_threads(), 8);
    /// ```
    #[must_use]
    pub fn with_n_threads(mut self, n_threads: i32) -> Self {
        self.context_params.n_threads = n_threads;
        self
    }

    /// Set the number of threads allocated for batches.
    ///
    /// # Examples
    ///
    /// ```rust
    /// use llama_cpp_2::context::params::LlamaContextParams;
    /// let params = LlamaContextParams::default()
    ///     .with_n_threads_batch(8);
    /// assert_eq!(params.n_threads_batch(), 8);
    /// ```
    #[must_use]
    pub fn with_n_threads_batch(mut self, n_threads: i32) -> Self {
        self.context_params.n_threads_batch = n_threads;
        self
    }

    /// Check whether embeddings are enabled
    ///
    /// # Examples
    ///
    /// ```rust
    /// let params = llama_cpp_2::context::params::LlamaContextParams::default();
    /// assert!(!params.embeddings());
    /// ```
    #[must_use]
    pub fn embeddings(&self) -> bool {
        self.context_params.embeddings
    }

    /// Enable the use of embeddings
    ///
    /// # Examples
    ///
    /// ```rust
    /// use llama_cpp_2::context::params::LlamaContextParams;
    /// let params = LlamaContextParams::default()
    ///     .with_embeddings(true);
    /// assert!(params.embeddings());
    /// ```
    #[must_use]
    pub fn with_embeddings(mut self, embedding: bool) -> Self {
        self.context_params.embeddings = embedding;
        self
    }

    /// Set the evaluation callback.
    ///
    /// # Examples
    ///
    /// ```no_run
    /// extern "C" fn cb_eval_fn(
    ///     t: *mut llama_cpp_sys_2::ggml_tensor,
    ///     ask: bool,
    ///     user_data: *mut std::ffi::c_void,
    /// ) -> bool {
    ///     false
    /// }
    ///
    /// use llama_cpp_2::context::params::LlamaContextParams;
    /// let params = LlamaContextParams::default().with_cb_eval(Some(cb_eval_fn));
    /// ```
    #[must_use]
    pub fn with_cb_eval(
        mut self,
        cb_eval: llama_cpp_sys_2::ggml_backend_sched_eval_callback,
    ) -> Self {
        self.context_params.cb_eval = cb_eval;
        self
    }

    /// Set the evaluation callback user data.
    ///
    /// # Examples
    ///
    /// ```no_run
    /// use llama_cpp_2::context::params::LlamaContextParams;
    /// let params = LlamaContextParams::default();
    /// let user_data = std::ptr::null_mut();
    /// let params = params.with_cb_eval_user_data(user_data);
    /// ```
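    ///
    /// The pointer registered here is handed back to the callback set via
    /// [`Self::with_cb_eval`] as its `user_data` argument. A minimal sketch, with
    /// an illustrative callback and state type:
    ///
    /// ```no_run
    /// use llama_cpp_2::context::params::LlamaContextParams;
    ///
    /// extern "C" fn count_calls(
    ///     _t: *mut llama_cpp_sys_2::ggml_tensor,
    ///     _ask: bool,
    ///     user_data: *mut std::ffi::c_void,
    /// ) -> bool {
    ///     // a real callback would cast `user_data` back to the caller's state type
    ///     let _counter = user_data.cast::<u64>();
    ///     false
    /// }
    ///
    /// let mut counter: u64 = 0;
    /// let params = LlamaContextParams::default()
    ///     .with_cb_eval(Some(count_calls))
    ///     .with_cb_eval_user_data((&mut counter as *mut u64).cast());
    /// ```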
    #[must_use]
    pub fn with_cb_eval_user_data(mut self, cb_eval_user_data: *mut std::ffi::c_void) -> Self {
        self.context_params.cb_eval_user_data = cb_eval_user_data;
        self
    }

    /// Set the type of pooling.
    ///
    /// # Examples
    ///
    /// ```rust
    /// use llama_cpp_2::context::params::{LlamaContextParams, LlamaPoolingType};
    /// let params = LlamaContextParams::default()
    ///     .with_pooling_type(LlamaPoolingType::Last);
    /// assert_eq!(params.pooling_type(), LlamaPoolingType::Last);
    /// ```
    #[must_use]
    pub fn with_pooling_type(mut self, pooling_type: LlamaPoolingType) -> Self {
        self.context_params.pooling_type = i32::from(pooling_type);
        self
    }

    /// Get the type of pooling.
    ///
    /// # Examples
    ///
    /// ```rust
    /// let params = llama_cpp_2::context::params::LlamaContextParams::default();
    /// assert_eq!(params.pooling_type(), llama_cpp_2::context::params::LlamaPoolingType::Unspecified);
    /// ```
    #[must_use]
    pub fn pooling_type(&self) -> LlamaPoolingType {
        LlamaPoolingType::from(self.context_params.pooling_type)
    }

    /// Set whether to use full sliding window attention
    ///
    /// # Examples
    ///
    /// ```rust
    /// use llama_cpp_2::context::params::LlamaContextParams;
    /// let params = LlamaContextParams::default()
    ///     .with_swa_full(false);
    /// assert_eq!(params.swa_full(), false);
    /// ```
    #[must_use]
    pub fn with_swa_full(mut self, enabled: bool) -> Self {
        self.context_params.swa_full = enabled;
        self
    }

    /// Get whether full sliding window attention is enabled
    ///
    /// # Examples
    ///
    /// ```rust
    /// use llama_cpp_2::context::params::LlamaContextParams;
    /// let params = LlamaContextParams::default();
    /// assert_eq!(params.swa_full(), true);
    /// ```
    #[must_use]
    pub fn swa_full(&self) -> bool {
        self.context_params.swa_full
    }

    /// Set the max number of sequences (i.e. distinct states for recurrent models)
    ///
    /// # Examples
    ///
    /// ```rust
    /// use llama_cpp_2::context::params::LlamaContextParams;
    /// let params = LlamaContextParams::default()
    ///     .with_n_seq_max(64);
    /// assert_eq!(params.n_seq_max(), 64);
    /// ```
    #[must_use]
    pub fn with_n_seq_max(mut self, n_seq_max: u32) -> Self {
        self.context_params.n_seq_max = n_seq_max;
        self
    }

    /// Get the max number of sequences (i.e. distinct states for recurrent models)
    ///
    /// # Examples
    ///
    /// ```rust
    /// use llama_cpp_2::context::params::LlamaContextParams;
    /// let params = LlamaContextParams::default();
    /// assert_eq!(params.n_seq_max(), 1);
    /// ```
    #[must_use]
    pub fn n_seq_max(&self) -> u32 {
        self.context_params.n_seq_max
    }

    /// Set the KV cache data type for K
    ///
    /// # Examples
    ///
    /// ```rust
    /// use llama_cpp_2::context::params::{LlamaContextParams, KvCacheType};
    /// let params = LlamaContextParams::default().with_type_k(KvCacheType::Q4_0);
    /// assert_eq!(params.type_k(), KvCacheType::Q4_0);
    /// ```
    #[must_use]
    pub fn with_type_k(mut self, type_k: KvCacheType) -> Self {
        self.context_params.type_k = type_k.into();
        self
    }

    /// Get the KV cache data type for K
    ///
    /// # Examples
    ///
    /// ```rust
    /// let params = llama_cpp_2::context::params::LlamaContextParams::default();
    /// let _ = params.type_k();
    /// ```
    #[must_use]
    pub fn type_k(&self) -> KvCacheType {
        KvCacheType::from(self.context_params.type_k)
    }

    /// Set the KV cache data type for V
    ///
    /// # Examples
    ///
    /// ```rust
    /// use llama_cpp_2::context::params::{LlamaContextParams, KvCacheType};
    /// let params = LlamaContextParams::default().with_type_v(KvCacheType::Q4_1);
    /// assert_eq!(params.type_v(), KvCacheType::Q4_1);
    /// ```
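    ///
    /// Note: at the time of writing, llama.cpp generally rejects a quantized V
    /// cache unless flash attention is enabled (see
    /// [`Self::with_flash_attention`]); verify this pairing against the llama.cpp
    /// version in use.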
    #[must_use]
    pub fn with_type_v(mut self, type_v: KvCacheType) -> Self {
        self.context_params.type_v = type_v.into();
        self
    }

    /// Get the KV cache data type for V
    ///
    /// # Examples
    ///
    /// ```rust
    /// let params = llama_cpp_2::context::params::LlamaContextParams::default();
    /// let _ = params.type_v();
    /// ```
    #[must_use]
    pub fn type_v(&self) -> KvCacheType {
        KvCacheType::from(self.context_params.type_v)
    }
}

/// Default parameters for `LlamaContext`. (as defined in llama.cpp by `llama_context_default_params`)
///
/// ```
/// # use std::num::NonZeroU32;
/// use llama_cpp_2::context::params::{LlamaContextParams, RopeScalingType};
/// let params = LlamaContextParams::default();
/// assert_eq!(params.n_ctx(), NonZeroU32::new(512), "n_ctx should be 512");
/// assert_eq!(params.rope_scaling_type(), RopeScalingType::Unspecified);
/// ```
impl Default for LlamaContextParams {
    fn default() -> Self {
        let context_params = unsafe { llama_cpp_sys_2::llama_context_default_params() };
        Self { context_params }
    }
}