llama_cpp_2/context/params.rs
//! A safe wrapper around `llama_context_params`.
use std::fmt::Debug;
use std::num::NonZeroU32;

/// A rusty wrapper around `rope_scaling_type`.
#[repr(i8)]
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum RopeScalingType {
    /// The scaling type is unspecified
    Unspecified = -1,
    /// No scaling
    None = 0,
    /// Linear scaling
    Linear = 1,
    /// Yarn scaling
    Yarn = 2,
}

/// Create a `RopeScalingType` from a `c_int` - returns `RopeScalingType::Unspecified` if
/// the value is not recognized.
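///
/// # Examples
///
/// A quick check of the mapping and its fallback:
///
/// ```rust
/// use llama_cpp_2::context::params::RopeScalingType;
/// assert_eq!(RopeScalingType::from(1), RopeScalingType::Linear);
/// // Unrecognized values fall back to `Unspecified`.
/// assert_eq!(RopeScalingType::from(42), RopeScalingType::Unspecified);
/// ```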
impl From<i32> for RopeScalingType {
    fn from(value: i32) -> Self {
        match value {
            0 => Self::None,
            1 => Self::Linear,
            2 => Self::Yarn,
            _ => Self::Unspecified,
        }
    }
}

/// Create a `c_int` from a `RopeScalingType`.
impl From<RopeScalingType> for i32 {
    fn from(value: RopeScalingType) -> Self {
        match value {
            RopeScalingType::None => 0,
            RopeScalingType::Linear => 1,
            RopeScalingType::Yarn => 2,
            RopeScalingType::Unspecified => -1,
        }
    }
}

/// A rusty wrapper around `LLAMA_POOLING_TYPE`.
#[repr(i8)]
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum LlamaPoolingType {
    /// The pooling type is unspecified
    Unspecified = -1,
    /// No pooling
    None = 0,
    /// Mean pooling
    Mean = 1,
    /// CLS pooling
    Cls = 2,
    /// Last pooling
    Last = 3,
    /// Rank pooling
    Rank = 4,
}

/// Create a `LlamaPoolingType` from a `c_int` - returns `LlamaPoolingType::Unspecified` if
/// the value is not recognized.
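///
/// # Examples
///
/// A quick check of the mapping and its fallback:
///
/// ```rust
/// use llama_cpp_2::context::params::LlamaPoolingType;
/// assert_eq!(LlamaPoolingType::from(3), LlamaPoolingType::Last);
/// // Unrecognized values fall back to `Unspecified`.
/// assert_eq!(LlamaPoolingType::from(42), LlamaPoolingType::Unspecified);
/// ```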
impl From<i32> for LlamaPoolingType {
    fn from(value: i32) -> Self {
        match value {
            0 => Self::None,
            1 => Self::Mean,
            2 => Self::Cls,
            3 => Self::Last,
            4 => Self::Rank,
            _ => Self::Unspecified,
        }
    }
}

/// Create a `c_int` from a `LlamaPoolingType`.
impl From<LlamaPoolingType> for i32 {
    fn from(value: LlamaPoolingType) -> Self {
        match value {
            LlamaPoolingType::None => 0,
            LlamaPoolingType::Mean => 1,
            LlamaPoolingType::Cls => 2,
            LlamaPoolingType::Last => 3,
            LlamaPoolingType::Rank => 4,
            LlamaPoolingType::Unspecified => -1,
        }
    }
}

/// A rusty wrapper around `ggml_type` for KV cache types.
#[allow(non_camel_case_types, missing_docs)]
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum KvCacheType {
    /// Represents an unknown or not-yet-mapped `ggml_type` and carries the raw value.
    /// When passed through FFI, the raw value is used as-is (if llama.cpp supports it,
    /// the runtime will operate with that type).
    /// This variant preserves API compatibility when new `ggml_type` values are
    /// introduced in the future.
    Unknown(llama_cpp_sys_2::ggml_type),
    F32,
    F16,
    Q4_0,
    Q4_1,
    Q5_0,
    Q5_1,
    Q8_0,
    Q8_1,
    Q2_K,
    Q3_K,
    Q4_K,
    Q5_K,
    Q6_K,
    Q8_K,
    IQ2_XXS,
    IQ2_XS,
    IQ3_XXS,
    IQ1_S,
    IQ4_NL,
    IQ3_S,
    IQ2_S,
    IQ4_XS,
    I8,
    I16,
    I32,
    I64,
    F64,
    IQ1_M,
    BF16,
    TQ1_0,
    TQ2_0,
    MXFP4,
}

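/// Convert a `KvCacheType` into the raw `ggml_type` handed to llama.cpp over FFI.
///
/// # Examples
///
/// A quick round-trip check: known variants map to their `GGML_TYPE_*` constant
/// and back, while unmapped raw values are preserved as [`KvCacheType::Unknown`]:
///
/// ```rust
/// use llama_cpp_2::context::params::KvCacheType;
/// let raw: llama_cpp_sys_2::ggml_type = KvCacheType::F16.into();
/// assert_eq!(KvCacheType::from(raw), KvCacheType::F16);
/// ```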
impl From<KvCacheType> for llama_cpp_sys_2::ggml_type {
    fn from(value: KvCacheType) -> Self {
        match value {
            KvCacheType::Unknown(raw) => raw,
            KvCacheType::F32 => llama_cpp_sys_2::GGML_TYPE_F32,
            KvCacheType::F16 => llama_cpp_sys_2::GGML_TYPE_F16,
            KvCacheType::Q4_0 => llama_cpp_sys_2::GGML_TYPE_Q4_0,
            KvCacheType::Q4_1 => llama_cpp_sys_2::GGML_TYPE_Q4_1,
            KvCacheType::Q5_0 => llama_cpp_sys_2::GGML_TYPE_Q5_0,
            KvCacheType::Q5_1 => llama_cpp_sys_2::GGML_TYPE_Q5_1,
            KvCacheType::Q8_0 => llama_cpp_sys_2::GGML_TYPE_Q8_0,
            KvCacheType::Q8_1 => llama_cpp_sys_2::GGML_TYPE_Q8_1,
            KvCacheType::Q2_K => llama_cpp_sys_2::GGML_TYPE_Q2_K,
            KvCacheType::Q3_K => llama_cpp_sys_2::GGML_TYPE_Q3_K,
            KvCacheType::Q4_K => llama_cpp_sys_2::GGML_TYPE_Q4_K,
            KvCacheType::Q5_K => llama_cpp_sys_2::GGML_TYPE_Q5_K,
            KvCacheType::Q6_K => llama_cpp_sys_2::GGML_TYPE_Q6_K,
            KvCacheType::Q8_K => llama_cpp_sys_2::GGML_TYPE_Q8_K,
            KvCacheType::IQ2_XXS => llama_cpp_sys_2::GGML_TYPE_IQ2_XXS,
            KvCacheType::IQ2_XS => llama_cpp_sys_2::GGML_TYPE_IQ2_XS,
            KvCacheType::IQ3_XXS => llama_cpp_sys_2::GGML_TYPE_IQ3_XXS,
            KvCacheType::IQ1_S => llama_cpp_sys_2::GGML_TYPE_IQ1_S,
            KvCacheType::IQ4_NL => llama_cpp_sys_2::GGML_TYPE_IQ4_NL,
            KvCacheType::IQ3_S => llama_cpp_sys_2::GGML_TYPE_IQ3_S,
            KvCacheType::IQ2_S => llama_cpp_sys_2::GGML_TYPE_IQ2_S,
            KvCacheType::IQ4_XS => llama_cpp_sys_2::GGML_TYPE_IQ4_XS,
            KvCacheType::I8 => llama_cpp_sys_2::GGML_TYPE_I8,
            KvCacheType::I16 => llama_cpp_sys_2::GGML_TYPE_I16,
            KvCacheType::I32 => llama_cpp_sys_2::GGML_TYPE_I32,
            KvCacheType::I64 => llama_cpp_sys_2::GGML_TYPE_I64,
            KvCacheType::F64 => llama_cpp_sys_2::GGML_TYPE_F64,
            KvCacheType::IQ1_M => llama_cpp_sys_2::GGML_TYPE_IQ1_M,
            KvCacheType::BF16 => llama_cpp_sys_2::GGML_TYPE_BF16,
            KvCacheType::TQ1_0 => llama_cpp_sys_2::GGML_TYPE_TQ1_0,
            KvCacheType::TQ2_0 => llama_cpp_sys_2::GGML_TYPE_TQ2_0,
            KvCacheType::MXFP4 => llama_cpp_sys_2::GGML_TYPE_MXFP4,
        }
    }
}

impl From<llama_cpp_sys_2::ggml_type> for KvCacheType {
    fn from(value: llama_cpp_sys_2::ggml_type) -> Self {
        match value {
            x if x == llama_cpp_sys_2::GGML_TYPE_F32 => KvCacheType::F32,
            x if x == llama_cpp_sys_2::GGML_TYPE_F16 => KvCacheType::F16,
            x if x == llama_cpp_sys_2::GGML_TYPE_Q4_0 => KvCacheType::Q4_0,
            x if x == llama_cpp_sys_2::GGML_TYPE_Q4_1 => KvCacheType::Q4_1,
            x if x == llama_cpp_sys_2::GGML_TYPE_Q5_0 => KvCacheType::Q5_0,
            x if x == llama_cpp_sys_2::GGML_TYPE_Q5_1 => KvCacheType::Q5_1,
            x if x == llama_cpp_sys_2::GGML_TYPE_Q8_0 => KvCacheType::Q8_0,
            x if x == llama_cpp_sys_2::GGML_TYPE_Q8_1 => KvCacheType::Q8_1,
            x if x == llama_cpp_sys_2::GGML_TYPE_Q2_K => KvCacheType::Q2_K,
            x if x == llama_cpp_sys_2::GGML_TYPE_Q3_K => KvCacheType::Q3_K,
            x if x == llama_cpp_sys_2::GGML_TYPE_Q4_K => KvCacheType::Q4_K,
            x if x == llama_cpp_sys_2::GGML_TYPE_Q5_K => KvCacheType::Q5_K,
            x if x == llama_cpp_sys_2::GGML_TYPE_Q6_K => KvCacheType::Q6_K,
            x if x == llama_cpp_sys_2::GGML_TYPE_Q8_K => KvCacheType::Q8_K,
            x if x == llama_cpp_sys_2::GGML_TYPE_IQ2_XXS => KvCacheType::IQ2_XXS,
            x if x == llama_cpp_sys_2::GGML_TYPE_IQ2_XS => KvCacheType::IQ2_XS,
            x if x == llama_cpp_sys_2::GGML_TYPE_IQ3_XXS => KvCacheType::IQ3_XXS,
            x if x == llama_cpp_sys_2::GGML_TYPE_IQ1_S => KvCacheType::IQ1_S,
            x if x == llama_cpp_sys_2::GGML_TYPE_IQ4_NL => KvCacheType::IQ4_NL,
            x if x == llama_cpp_sys_2::GGML_TYPE_IQ3_S => KvCacheType::IQ3_S,
            x if x == llama_cpp_sys_2::GGML_TYPE_IQ2_S => KvCacheType::IQ2_S,
            x if x == llama_cpp_sys_2::GGML_TYPE_IQ4_XS => KvCacheType::IQ4_XS,
            x if x == llama_cpp_sys_2::GGML_TYPE_I8 => KvCacheType::I8,
            x if x == llama_cpp_sys_2::GGML_TYPE_I16 => KvCacheType::I16,
            x if x == llama_cpp_sys_2::GGML_TYPE_I32 => KvCacheType::I32,
            x if x == llama_cpp_sys_2::GGML_TYPE_I64 => KvCacheType::I64,
            x if x == llama_cpp_sys_2::GGML_TYPE_F64 => KvCacheType::F64,
            x if x == llama_cpp_sys_2::GGML_TYPE_IQ1_M => KvCacheType::IQ1_M,
            x if x == llama_cpp_sys_2::GGML_TYPE_BF16 => KvCacheType::BF16,
            x if x == llama_cpp_sys_2::GGML_TYPE_TQ1_0 => KvCacheType::TQ1_0,
            x if x == llama_cpp_sys_2::GGML_TYPE_TQ2_0 => KvCacheType::TQ2_0,
            x if x == llama_cpp_sys_2::GGML_TYPE_MXFP4 => KvCacheType::MXFP4,
            _ => KvCacheType::Unknown(value),
        }
    }
}

/// A safe wrapper around `llama_context_params`.
///
/// Generally this should be created with [`Default::default()`] and then modified with `with_*` methods.
///
/// # Examples
///
/// ```rust
/// # use std::num::NonZeroU32;
/// use llama_cpp_2::context::params::LlamaContextParams;
///
/// let ctx_params = LlamaContextParams::default()
///     .with_n_ctx(NonZeroU32::new(2048));
///
/// assert_eq!(ctx_params.n_ctx(), NonZeroU32::new(2048));
/// ```
#[derive(Debug, Clone)]
#[allow(
    missing_docs,
    clippy::struct_excessive_bools,
    clippy::module_name_repetitions
)]
pub struct LlamaContextParams {
    pub(crate) context_params: llama_cpp_sys_2::llama_context_params,
}

/// SAFETY: we do not currently allow setting or reading the pointers that cause this to not be automatically send or sync.
unsafe impl Send for LlamaContextParams {}
unsafe impl Sync for LlamaContextParams {}

impl LlamaContextParams {
    /// Set the size of the context
    ///
    /// # Examples
    ///
    /// ```rust
    /// # use std::num::NonZeroU32;
    /// use llama_cpp_2::context::params::LlamaContextParams;
    /// let params = LlamaContextParams::default();
    /// let params = params.with_n_ctx(NonZeroU32::new(2048));
    /// assert_eq!(params.n_ctx(), NonZeroU32::new(2048));
    /// ```
    #[must_use]
    pub fn with_n_ctx(mut self, n_ctx: Option<NonZeroU32>) -> Self {
        self.context_params.n_ctx = n_ctx.map_or(0, std::num::NonZeroU32::get);
        self
    }

    /// Get the size of the context.
    ///
    /// [`None`] if the context size is specified by the model and not the context.
    ///
    /// # Examples
    ///
    /// ```rust
    /// let params = llama_cpp_2::context::params::LlamaContextParams::default();
    /// assert_eq!(params.n_ctx(), std::num::NonZeroU32::new(512));
    /// ```
    #[must_use]
    pub fn n_ctx(&self) -> Option<NonZeroU32> {
        NonZeroU32::new(self.context_params.n_ctx)
    }

    /// Set the `n_batch`
    ///
    /// # Examples
    ///
    /// ```rust
    /// # use std::num::NonZeroU32;
    /// use llama_cpp_2::context::params::LlamaContextParams;
    /// let params = LlamaContextParams::default()
    ///     .with_n_batch(2048);
    /// assert_eq!(params.n_batch(), 2048);
    /// ```
    #[must_use]
    pub fn with_n_batch(mut self, n_batch: u32) -> Self {
        self.context_params.n_batch = n_batch;
        self
    }

    /// Get the `n_batch`
    ///
    /// # Examples
    ///
    /// ```rust
    /// use llama_cpp_2::context::params::LlamaContextParams;
    /// let params = LlamaContextParams::default();
    /// assert_eq!(params.n_batch(), 2048);
    /// ```
    #[must_use]
    pub fn n_batch(&self) -> u32 {
        self.context_params.n_batch
    }

    /// Set the `n_ubatch`
    ///
    /// # Examples
    ///
    /// ```rust
    /// # use std::num::NonZeroU32;
    /// use llama_cpp_2::context::params::LlamaContextParams;
    /// let params = LlamaContextParams::default()
    ///     .with_n_ubatch(512);
    /// assert_eq!(params.n_ubatch(), 512);
    /// ```
    #[must_use]
    pub fn with_n_ubatch(mut self, n_ubatch: u32) -> Self {
        self.context_params.n_ubatch = n_ubatch;
        self
    }

    /// Get the `n_ubatch`
    ///
    /// # Examples
    ///
    /// ```rust
    /// use llama_cpp_2::context::params::LlamaContextParams;
    /// let params = LlamaContextParams::default();
    /// assert_eq!(params.n_ubatch(), 512);
    /// ```
    #[must_use]
    pub fn n_ubatch(&self) -> u32 {
        self.context_params.n_ubatch
    }

    /// Set the flash attention policy using the llama.cpp enum
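    ///
    /// # Examples
    ///
    /// A sketch; the exact constant name is assumed to follow the same bindgen
    /// naming as the `GGML_TYPE_*` constants, so check the generated
    /// `llama_cpp_sys_2` bindings:
    ///
    /// ```ignore
    /// use llama_cpp_2::context::params::LlamaContextParams;
    /// // Assumed binding for llama.cpp's LLAMA_FLASH_ATTN_TYPE_ENABLED.
    /// let params = LlamaContextParams::default()
    ///     .with_flash_attention_policy(llama_cpp_sys_2::LLAMA_FLASH_ATTN_TYPE_ENABLED);
    /// assert_eq!(
    ///     params.flash_attention_policy(),
    ///     llama_cpp_sys_2::LLAMA_FLASH_ATTN_TYPE_ENABLED
    /// );
    /// ```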
    #[must_use]
    pub fn with_flash_attention_policy(
        mut self,
        policy: llama_cpp_sys_2::llama_flash_attn_type,
    ) -> Self {
        self.context_params.flash_attn_type = policy;
        self
    }

    /// Get the flash attention policy
    #[must_use]
    pub fn flash_attention_policy(&self) -> llama_cpp_sys_2::llama_flash_attn_type {
        self.context_params.flash_attn_type
    }

    /// Set the `offload_kqv` parameter to control offloading KV cache & KQV ops to GPU
    ///
    /// # Examples
    ///
    /// ```rust
    /// use llama_cpp_2::context::params::LlamaContextParams;
    /// let params = LlamaContextParams::default()
    ///     .with_offload_kqv(false);
    /// assert_eq!(params.offload_kqv(), false);
    /// ```
    #[must_use]
    pub fn with_offload_kqv(mut self, enabled: bool) -> Self {
        self.context_params.offload_kqv = enabled;
        self
    }

    /// Get the `offload_kqv` parameter
    ///
    /// # Examples
    ///
    /// ```rust
    /// use llama_cpp_2::context::params::LlamaContextParams;
    /// let params = LlamaContextParams::default();
    /// assert_eq!(params.offload_kqv(), true);
    /// ```
    #[must_use]
    pub fn offload_kqv(&self) -> bool {
        self.context_params.offload_kqv
    }

    /// Set the type of rope scaling.
    ///
    /// # Examples
    ///
    /// ```rust
    /// use llama_cpp_2::context::params::{LlamaContextParams, RopeScalingType};
    /// let params = LlamaContextParams::default()
    ///     .with_rope_scaling_type(RopeScalingType::Linear);
    /// assert_eq!(params.rope_scaling_type(), RopeScalingType::Linear);
    /// ```
    #[must_use]
    pub fn with_rope_scaling_type(mut self, rope_scaling_type: RopeScalingType) -> Self {
        self.context_params.rope_scaling_type = i32::from(rope_scaling_type);
        self
    }

    /// Get the type of rope scaling.
    ///
    /// # Examples
    ///
    /// ```rust
    /// let params = llama_cpp_2::context::params::LlamaContextParams::default();
    /// assert_eq!(params.rope_scaling_type(), llama_cpp_2::context::params::RopeScalingType::Unspecified);
    /// ```
    #[must_use]
    pub fn rope_scaling_type(&self) -> RopeScalingType {
        RopeScalingType::from(self.context_params.rope_scaling_type)
    }

    /// Set the rope frequency base.
    ///
    /// # Examples
    ///
    /// ```rust
    /// use llama_cpp_2::context::params::LlamaContextParams;
    /// let params = LlamaContextParams::default()
    ///     .with_rope_freq_base(0.5);
    /// assert_eq!(params.rope_freq_base(), 0.5);
    /// ```
    #[must_use]
    pub fn with_rope_freq_base(mut self, rope_freq_base: f32) -> Self {
        self.context_params.rope_freq_base = rope_freq_base;
        self
    }

    /// Get the rope frequency base.
    ///
    /// # Examples
    ///
    /// ```rust
    /// let params = llama_cpp_2::context::params::LlamaContextParams::default();
    /// assert_eq!(params.rope_freq_base(), 0.0);
    /// ```
    #[must_use]
    pub fn rope_freq_base(&self) -> f32 {
        self.context_params.rope_freq_base
    }

    /// Set the rope frequency scale.
    ///
    /// # Examples
    ///
    /// ```rust
    /// use llama_cpp_2::context::params::LlamaContextParams;
    /// let params = LlamaContextParams::default()
    ///     .with_rope_freq_scale(0.5);
    /// assert_eq!(params.rope_freq_scale(), 0.5);
    /// ```
    #[must_use]
    pub fn with_rope_freq_scale(mut self, rope_freq_scale: f32) -> Self {
        self.context_params.rope_freq_scale = rope_freq_scale;
        self
    }

    /// Get the rope frequency scale.
    ///
    /// # Examples
    ///
    /// ```rust
    /// let params = llama_cpp_2::context::params::LlamaContextParams::default();
    /// assert_eq!(params.rope_freq_scale(), 0.0);
    /// ```
    #[must_use]
    pub fn rope_freq_scale(&self) -> f32 {
        self.context_params.rope_freq_scale
    }

    /// Get the number of threads.
    ///
    /// # Examples
    ///
    /// ```rust
    /// let params = llama_cpp_2::context::params::LlamaContextParams::default();
    /// assert_eq!(params.n_threads(), 4);
    /// ```
    #[must_use]
    pub fn n_threads(&self) -> i32 {
        self.context_params.n_threads
    }

    /// Get the number of threads allocated for batches.
    ///
    /// # Examples
    ///
    /// ```rust
    /// let params = llama_cpp_2::context::params::LlamaContextParams::default();
    /// assert_eq!(params.n_threads_batch(), 4);
    /// ```
    #[must_use]
    pub fn n_threads_batch(&self) -> i32 {
        self.context_params.n_threads_batch
    }

    /// Set the number of threads.
    ///
    /// # Examples
    ///
    /// ```rust
    /// use llama_cpp_2::context::params::LlamaContextParams;
    /// let params = LlamaContextParams::default()
    ///     .with_n_threads(8);
    /// assert_eq!(params.n_threads(), 8);
    /// ```
    #[must_use]
    pub fn with_n_threads(mut self, n_threads: i32) -> Self {
        self.context_params.n_threads = n_threads;
        self
    }

    /// Set the number of threads allocated for batches.
    ///
    /// # Examples
    ///
    /// ```rust
    /// use llama_cpp_2::context::params::LlamaContextParams;
    /// let params = LlamaContextParams::default()
    ///     .with_n_threads_batch(8);
    /// assert_eq!(params.n_threads_batch(), 8);
    /// ```
    #[must_use]
    pub fn with_n_threads_batch(mut self, n_threads: i32) -> Self {
        self.context_params.n_threads_batch = n_threads;
        self
    }

    /// Check whether embeddings are enabled
    ///
    /// # Examples
    ///
    /// ```rust
    /// let params = llama_cpp_2::context::params::LlamaContextParams::default();
    /// assert!(!params.embeddings());
    /// ```
    #[must_use]
    pub fn embeddings(&self) -> bool {
        self.context_params.embeddings
    }

    /// Enable the use of embeddings
    ///
    /// # Examples
    ///
    /// ```rust
    /// use llama_cpp_2::context::params::LlamaContextParams;
    /// let params = LlamaContextParams::default()
    ///     .with_embeddings(true);
    /// assert!(params.embeddings());
    /// ```
    #[must_use]
    pub fn with_embeddings(mut self, embedding: bool) -> Self {
        self.context_params.embeddings = embedding;
        self
    }

    /// Set the evaluation callback.
    ///
    /// # Examples
    ///
    /// ```no_run
    /// extern "C" fn cb_eval_fn(
    ///     t: *mut llama_cpp_sys_2::ggml_tensor,
    ///     ask: bool,
    ///     user_data: *mut std::ffi::c_void,
    /// ) -> bool {
    ///     false
    /// }
    ///
    /// use llama_cpp_2::context::params::LlamaContextParams;
    /// let params = LlamaContextParams::default().with_cb_eval(Some(cb_eval_fn));
    /// ```
    #[must_use]
    pub fn with_cb_eval(
        mut self,
        cb_eval: llama_cpp_sys_2::ggml_backend_sched_eval_callback,
    ) -> Self {
        self.context_params.cb_eval = cb_eval;
        self
    }

    /// Set the evaluation callback user data.
    ///
    /// # Examples
    ///
    /// ```no_run
    /// use llama_cpp_2::context::params::LlamaContextParams;
    /// let params = LlamaContextParams::default();
    /// let user_data = std::ptr::null_mut();
    /// let params = params.with_cb_eval_user_data(user_data);
    /// ```
    #[must_use]
    pub fn with_cb_eval_user_data(mut self, cb_eval_user_data: *mut std::ffi::c_void) -> Self {
        self.context_params.cb_eval_user_data = cb_eval_user_data;
        self
    }

    /// Set the type of pooling.
    ///
    /// # Examples
    ///
    /// ```rust
    /// use llama_cpp_2::context::params::{LlamaContextParams, LlamaPoolingType};
    /// let params = LlamaContextParams::default()
    ///     .with_pooling_type(LlamaPoolingType::Last);
    /// assert_eq!(params.pooling_type(), LlamaPoolingType::Last);
    /// ```
    #[must_use]
    pub fn with_pooling_type(mut self, pooling_type: LlamaPoolingType) -> Self {
        self.context_params.pooling_type = i32::from(pooling_type);
        self
    }

    /// Get the type of pooling.
    ///
    /// # Examples
    ///
    /// ```rust
    /// let params = llama_cpp_2::context::params::LlamaContextParams::default();
    /// assert_eq!(params.pooling_type(), llama_cpp_2::context::params::LlamaPoolingType::Unspecified);
    /// ```
    #[must_use]
    pub fn pooling_type(&self) -> LlamaPoolingType {
        LlamaPoolingType::from(self.context_params.pooling_type)
    }

    /// Set whether to use full sliding window attention
    ///
    /// # Examples
    ///
    /// ```rust
    /// use llama_cpp_2::context::params::LlamaContextParams;
    /// let params = LlamaContextParams::default()
    ///     .with_swa_full(false);
    /// assert_eq!(params.swa_full(), false);
    /// ```
    #[must_use]
    pub fn with_swa_full(mut self, enabled: bool) -> Self {
        self.context_params.swa_full = enabled;
        self
    }

    /// Get whether full sliding window attention is enabled
    ///
    /// # Examples
    ///
    /// ```rust
    /// use llama_cpp_2::context::params::LlamaContextParams;
    /// let params = LlamaContextParams::default();
    /// assert_eq!(params.swa_full(), true);
    /// ```
    #[must_use]
    pub fn swa_full(&self) -> bool {
        self.context_params.swa_full
    }

    /// Set the max number of sequences (i.e. distinct states for recurrent models)
    ///
    /// # Examples
    ///
    /// ```rust
    /// use llama_cpp_2::context::params::LlamaContextParams;
    /// let params = LlamaContextParams::default()
    ///     .with_n_seq_max(64);
    /// assert_eq!(params.n_seq_max(), 64);
    /// ```
    #[must_use]
    pub fn with_n_seq_max(mut self, n_seq_max: u32) -> Self {
        self.context_params.n_seq_max = n_seq_max;
        self
    }

    /// Get the max number of sequences (i.e. distinct states for recurrent models)
    ///
    /// # Examples
    ///
    /// ```rust
    /// use llama_cpp_2::context::params::LlamaContextParams;
    /// let params = LlamaContextParams::default();
    /// assert_eq!(params.n_seq_max(), 1);
    /// ```
    #[must_use]
    pub fn n_seq_max(&self) -> u32 {
        self.context_params.n_seq_max
    }

    /// Set the KV cache data type for K
    ///
    /// # Examples
    ///
    /// ```rust
    /// use llama_cpp_2::context::params::{LlamaContextParams, KvCacheType};
    /// let params = LlamaContextParams::default().with_type_k(KvCacheType::Q4_0);
    /// assert_eq!(params.type_k(), KvCacheType::Q4_0);
    /// ```
    #[must_use]
    pub fn with_type_k(mut self, type_k: KvCacheType) -> Self {
        self.context_params.type_k = type_k.into();
        self
    }

    /// Get the KV cache data type for K
    ///
    /// # Examples
    ///
    /// ```rust
    /// let params = llama_cpp_2::context::params::LlamaContextParams::default();
    /// let _ = params.type_k();
    /// ```
    #[must_use]
    pub fn type_k(&self) -> KvCacheType {
        KvCacheType::from(self.context_params.type_k)
    }

    /// Set the KV cache data type for V
    ///
    /// # Examples
    ///
    /// ```rust
    /// use llama_cpp_2::context::params::{LlamaContextParams, KvCacheType};
    /// let params = LlamaContextParams::default().with_type_v(KvCacheType::Q4_1);
    /// assert_eq!(params.type_v(), KvCacheType::Q4_1);
    /// ```
    #[must_use]
    pub fn with_type_v(mut self, type_v: KvCacheType) -> Self {
        self.context_params.type_v = type_v.into();
        self
    }

    /// Get the KV cache data type for V
    ///
    /// # Examples
    ///
    /// ```rust
    /// let params = llama_cpp_2::context::params::LlamaContextParams::default();
    /// let _ = params.type_v();
    /// ```
    #[must_use]
    pub fn type_v(&self) -> KvCacheType {
        KvCacheType::from(self.context_params.type_v)
    }
}

/// Default parameters for `LlamaContext`, as defined in llama.cpp by `llama_context_default_params`.
/// ```
/// # use std::num::NonZeroU32;
/// use llama_cpp_2::context::params::{LlamaContextParams, RopeScalingType};
/// let params = LlamaContextParams::default();
/// assert_eq!(params.n_ctx(), NonZeroU32::new(512), "n_ctx should be 512");
/// assert_eq!(params.rope_scaling_type(), RopeScalingType::Unspecified);
/// ```
impl Default for LlamaContextParams {
    fn default() -> Self {
        let context_params = unsafe { llama_cpp_sys_2::llama_context_default_params() };
        Self { context_params }
    }
}