llama_cpp_4/context/params.rs
//! A safe wrapper around `llama_context_params`.
use std::fmt::Debug;
use std::num::NonZeroU32;

/// A rusty wrapper around `rope_scaling_type`.
#[repr(i8)]
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum RopeScalingType {
    /// The scaling type is unspecified
    Unspecified = -1,
    /// No scaling
    None = 0,
    /// Linear scaling
    Linear = 1,
    /// Yarn scaling
    Yarn = 2,
}

/// Create a `RopeScalingType` from a `c_int` - returns `RopeScalingType::Unspecified` if
/// the value is not recognized.
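///
/// # Examples
///
/// ```rust
/// use llama_cpp_4::context::params::RopeScalingType;
/// assert_eq!(RopeScalingType::from(1), RopeScalingType::Linear);
/// // Unrecognized values fall back to `Unspecified`:
/// assert_eq!(RopeScalingType::from(42), RopeScalingType::Unspecified);
/// ```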
impl From<i32> for RopeScalingType {
    fn from(value: i32) -> Self {
        match value {
            0 => Self::None,
            1 => Self::Linear,
            2 => Self::Yarn,
            _ => Self::Unspecified,
        }
    }
}

/// Create a `c_int` from a `RopeScalingType`.
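///
/// # Examples
///
/// ```rust
/// use llama_cpp_4::context::params::RopeScalingType;
/// assert_eq!(i32::from(RopeScalingType::Yarn), 2);
/// // The conversion round-trips through the raw value:
/// assert_eq!(RopeScalingType::from(i32::from(RopeScalingType::Yarn)), RopeScalingType::Yarn);
/// ```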
impl From<RopeScalingType> for i32 {
    fn from(value: RopeScalingType) -> Self {
        match value {
            RopeScalingType::None => 0,
            RopeScalingType::Linear => 1,
            RopeScalingType::Yarn => 2,
            RopeScalingType::Unspecified => -1,
        }
    }
}

/// A rusty wrapper around `LLAMA_POOLING_TYPE`.
#[repr(i8)]
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum LlamaPoolingType {
    /// The pooling type is unspecified
    Unspecified = -1,
    /// No pooling
    None = 0,
    /// Mean pooling
    Mean = 1,
    /// CLS pooling
    Cls = 2,
    /// Last pooling
    Last = 3,
}

/// Create a `LlamaPoolingType` from a `c_int` - returns `LlamaPoolingType::Unspecified` if
/// the value is not recognized.
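///
/// # Examples
///
/// ```rust
/// use llama_cpp_4::context::params::LlamaPoolingType;
/// assert_eq!(LlamaPoolingType::from(1), LlamaPoolingType::Mean);
/// // Unrecognized values fall back to `Unspecified`:
/// assert_eq!(LlamaPoolingType::from(42), LlamaPoolingType::Unspecified);
/// ```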
impl From<i32> for LlamaPoolingType {
    fn from(value: i32) -> Self {
        match value {
            0 => Self::None,
            1 => Self::Mean,
            2 => Self::Cls,
            3 => Self::Last,
            _ => Self::Unspecified,
        }
    }
}

/// Create a `c_int` from a `LlamaPoolingType`.
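///
/// # Examples
///
/// ```rust
/// use llama_cpp_4::context::params::LlamaPoolingType;
/// assert_eq!(i32::from(LlamaPoolingType::Last), 3);
/// // The conversion round-trips through the raw value:
/// assert_eq!(LlamaPoolingType::from(i32::from(LlamaPoolingType::Cls)), LlamaPoolingType::Cls);
/// ```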
impl From<LlamaPoolingType> for i32 {
    fn from(value: LlamaPoolingType) -> Self {
        match value {
            LlamaPoolingType::None => 0,
            LlamaPoolingType::Mean => 1,
            LlamaPoolingType::Cls => 2,
            LlamaPoolingType::Last => 3,
            LlamaPoolingType::Unspecified => -1,
        }
    }
}

/// A safe wrapper around `llama_context_params`.
///
/// Generally this should be created with [`Default::default()`] and then modified with `with_*` methods.
///
/// # Examples
///
/// ```rust
/// # use std::num::NonZeroU32;
/// use llama_cpp_4::context::params::LlamaContextParams;
///
/// let ctx_params = LlamaContextParams::default()
///     .with_n_ctx(NonZeroU32::new(2048));
///
/// assert_eq!(ctx_params.n_ctx(), NonZeroU32::new(2048));
/// ```
#[derive(Debug, Clone)]
#[allow(
    missing_docs,
    clippy::struct_excessive_bools,
    clippy::module_name_repetitions
)]
pub struct LlamaContextParams {
    pub(crate) context_params: llama_cpp_sys_4::llama_context_params,
    /// When `true`, the `TurboQuant` attention rotation (PR #21038) will be
    /// disabled for any context created from these params.
    pub(crate) attn_rot_disabled: bool,
}

/// SAFETY: we do not currently allow setting or reading the pointers that would cause this
/// to not be automatically `Send` or `Sync`.
unsafe impl Send for LlamaContextParams {}
unsafe impl Sync for LlamaContextParams {}

impl LlamaContextParams {
    /// Set the size of the context.
    ///
    /// # Examples
    ///
    /// ```rust
    /// # use std::num::NonZeroU32;
    /// use llama_cpp_4::context::params::LlamaContextParams;
    /// let params = LlamaContextParams::default();
    /// let params = params.with_n_ctx(NonZeroU32::new(2048));
    /// assert_eq!(params.n_ctx(), NonZeroU32::new(2048));
    /// ```
    #[must_use]
    pub fn with_n_ctx(mut self, n_ctx: Option<NonZeroU32>) -> Self {
        self.context_params.n_ctx = n_ctx.map_or(0, std::num::NonZeroU32::get);
        self
    }

    /// Get the size of the context.
    ///
    /// [`None`] if the context size is specified by the model and not the context.
    ///
    /// # Examples
    ///
    /// ```rust
    /// let params = llama_cpp_4::context::params::LlamaContextParams::default();
    /// assert_eq!(params.n_ctx(), std::num::NonZeroU32::new(512));
    /// ```
    #[must_use]
    pub fn n_ctx(&self) -> Option<NonZeroU32> {
        NonZeroU32::new(self.context_params.n_ctx)
    }

    /// Set the `n_batch`.
    ///
    /// # Examples
    ///
    /// ```rust
    /// # use std::num::NonZeroU32;
    /// use llama_cpp_4::context::params::LlamaContextParams;
    /// let params = LlamaContextParams::default()
    ///     .with_n_batch(2048);
    /// assert_eq!(params.n_batch(), 2048);
    /// ```
    #[must_use]
    pub fn with_n_batch(mut self, n_batch: u32) -> Self {
        self.context_params.n_batch = n_batch;
        self
    }

    /// Get the `n_batch`.
    ///
    /// # Examples
    ///
    /// ```rust
    /// use llama_cpp_4::context::params::LlamaContextParams;
    /// let params = LlamaContextParams::default();
    /// assert_eq!(params.n_batch(), 2048);
    /// ```
    #[must_use]
    pub fn n_batch(&self) -> u32 {
        self.context_params.n_batch
    }

    /// Set the `n_ubatch`.
    ///
    /// # Examples
    ///
    /// ```rust
    /// # use std::num::NonZeroU32;
    /// use llama_cpp_4::context::params::LlamaContextParams;
    /// let params = LlamaContextParams::default()
    ///     .with_n_ubatch(512);
    /// assert_eq!(params.n_ubatch(), 512);
    /// ```
    #[must_use]
    pub fn with_n_ubatch(mut self, n_ubatch: u32) -> Self {
        self.context_params.n_ubatch = n_ubatch;
        self
    }

    /// Get the `n_ubatch`.
    ///
    /// # Examples
    ///
    /// ```rust
    /// use llama_cpp_4::context::params::LlamaContextParams;
    /// let params = LlamaContextParams::default();
    /// assert_eq!(params.n_ubatch(), 512);
    /// ```
    #[must_use]
    pub fn n_ubatch(&self) -> u32 {
        self.context_params.n_ubatch
    }

    /// Set the number of recurrent-state snapshots per sequence used for MTP rollback.
    ///
    /// This is only available when built with the `mtp` feature.
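    ///
    /// # Examples
    ///
    /// A minimal sketch; this only compiles with the `mtp` feature enabled, so it is
    /// marked `ignore` here:
    ///
    /// ```rust,ignore
    /// use llama_cpp_4::context::params::LlamaContextParams;
    /// let params = LlamaContextParams::default()
    ///     .with_n_rs_seq(4);
    /// assert_eq!(params.n_rs_seq(), 4);
    /// ```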
    #[cfg(feature = "mtp")]
    #[must_use]
    pub fn with_n_rs_seq(mut self, n_rs_seq: u32) -> Self {
        self.context_params.n_rs_seq = n_rs_seq;
        self
    }

    /// Get the number of recurrent-state snapshots per sequence used for MTP rollback.
    ///
    /// This is only available when built with the `mtp` feature.
    #[cfg(feature = "mtp")]
    #[must_use]
    pub fn n_rs_seq(&self) -> u32 {
        self.context_params.n_rs_seq
    }

    /// Set the `flash_attention` parameter.
    ///
    /// # Examples
    ///
    /// ```rust
    /// use llama_cpp_4::context::params::LlamaContextParams;
    /// let params = LlamaContextParams::default()
    ///     .with_flash_attention(true);
    /// assert_eq!(params.flash_attention(), true);
    /// ```
    #[must_use]
    pub fn with_flash_attention(mut self, enabled: bool) -> Self {
        self.context_params.flash_attn_type = if enabled {
            llama_cpp_sys_4::LLAMA_FLASH_ATTN_TYPE_ENABLED
        } else {
            llama_cpp_sys_4::LLAMA_FLASH_ATTN_TYPE_DISABLED
        };
        self
    }

    /// Get the `flash_attention` parameter.
    ///
    /// # Examples
    ///
    /// ```rust
    /// use llama_cpp_4::context::params::LlamaContextParams;
    /// let params = LlamaContextParams::default();
    /// assert_eq!(params.flash_attention(), false);
    /// ```
    #[must_use]
    pub fn flash_attention(&self) -> bool {
        self.context_params.flash_attn_type == llama_cpp_sys_4::LLAMA_FLASH_ATTN_TYPE_ENABLED
    }

    /// Set the `offload_kqv` parameter to control offloading the KV cache & KQV ops to the GPU.
    ///
    /// # Examples
    ///
    /// ```rust
    /// use llama_cpp_4::context::params::LlamaContextParams;
    /// let params = LlamaContextParams::default()
    ///     .with_offload_kqv(false);
    /// assert_eq!(params.offload_kqv(), false);
    /// ```
    #[must_use]
    pub fn with_offload_kqv(mut self, enabled: bool) -> Self {
        self.context_params.offload_kqv = enabled;
        self
    }

    /// Get the `offload_kqv` parameter.
    ///
    /// # Examples
    ///
    /// ```rust
    /// use llama_cpp_4::context::params::LlamaContextParams;
    /// let params = LlamaContextParams::default();
    /// assert_eq!(params.offload_kqv(), true);
    /// ```
    #[must_use]
    pub fn offload_kqv(&self) -> bool {
        self.context_params.offload_kqv
    }

    /// Set the type of rope scaling.
    ///
    /// # Examples
    ///
    /// ```rust
    /// use llama_cpp_4::context::params::{LlamaContextParams, RopeScalingType};
    /// let params = LlamaContextParams::default()
    ///     .with_rope_scaling_type(RopeScalingType::Linear);
    /// assert_eq!(params.rope_scaling_type(), RopeScalingType::Linear);
    /// ```
    #[must_use]
    pub fn with_rope_scaling_type(mut self, rope_scaling_type: RopeScalingType) -> Self {
        self.context_params.rope_scaling_type = i32::from(rope_scaling_type);
        self
    }

    /// Get the type of rope scaling.
    ///
    /// # Examples
    ///
    /// ```rust
    /// let params = llama_cpp_4::context::params::LlamaContextParams::default();
    /// assert_eq!(params.rope_scaling_type(), llama_cpp_4::context::params::RopeScalingType::Unspecified);
    /// ```
    #[must_use]
    pub fn rope_scaling_type(&self) -> RopeScalingType {
        RopeScalingType::from(self.context_params.rope_scaling_type)
    }

    /// Set the rope frequency base.
    ///
    /// # Examples
    ///
    /// ```rust
    /// use llama_cpp_4::context::params::LlamaContextParams;
    /// let params = LlamaContextParams::default()
    ///     .with_rope_freq_base(0.5);
    /// assert_eq!(params.rope_freq_base(), 0.5);
    /// ```
    #[must_use]
    pub fn with_rope_freq_base(mut self, rope_freq_base: f32) -> Self {
        self.context_params.rope_freq_base = rope_freq_base;
        self
    }

    /// Get the rope frequency base.
    ///
    /// # Examples
    ///
    /// ```rust
    /// let params = llama_cpp_4::context::params::LlamaContextParams::default();
    /// assert_eq!(params.rope_freq_base(), 0.0);
    /// ```
    #[must_use]
    pub fn rope_freq_base(&self) -> f32 {
        self.context_params.rope_freq_base
    }

    /// Set the rope frequency scale.
    ///
    /// # Examples
    ///
    /// ```rust
    /// use llama_cpp_4::context::params::LlamaContextParams;
    /// let params = LlamaContextParams::default()
    ///     .with_rope_freq_scale(0.5);
    /// assert_eq!(params.rope_freq_scale(), 0.5);
    /// ```
    #[must_use]
    pub fn with_rope_freq_scale(mut self, rope_freq_scale: f32) -> Self {
        self.context_params.rope_freq_scale = rope_freq_scale;
        self
    }

    /// Get the rope frequency scale.
    ///
    /// # Examples
    ///
    /// ```rust
    /// let params = llama_cpp_4::context::params::LlamaContextParams::default();
    /// assert_eq!(params.rope_freq_scale(), 0.0);
    /// ```
    #[must_use]
    pub fn rope_freq_scale(&self) -> f32 {
        self.context_params.rope_freq_scale
    }

    /// Get the number of threads.
    ///
    /// # Examples
    ///
    /// ```rust
    /// let params = llama_cpp_4::context::params::LlamaContextParams::default();
    /// assert_eq!(params.n_threads(), 4);
    /// ```
    #[must_use]
    pub fn n_threads(&self) -> i32 {
        self.context_params.n_threads
    }

    /// Get the number of threads allocated for batches.
    ///
    /// # Examples
    ///
    /// ```rust
    /// let params = llama_cpp_4::context::params::LlamaContextParams::default();
    /// assert_eq!(params.n_threads_batch(), 4);
    /// ```
    #[must_use]
    pub fn n_threads_batch(&self) -> i32 {
        self.context_params.n_threads_batch
    }

    /// Set the number of threads.
    ///
    /// # Examples
    ///
    /// ```rust
    /// use llama_cpp_4::context::params::LlamaContextParams;
    /// let params = LlamaContextParams::default()
    ///     .with_n_threads(8);
    /// assert_eq!(params.n_threads(), 8);
    /// ```
    #[must_use]
    pub fn with_n_threads(mut self, n_threads: i32) -> Self {
        self.context_params.n_threads = n_threads;
        self
    }

    /// Set the number of threads allocated for batches.
    ///
    /// # Examples
    ///
    /// ```rust
    /// use llama_cpp_4::context::params::LlamaContextParams;
    /// let params = LlamaContextParams::default()
    ///     .with_n_threads_batch(8);
    /// assert_eq!(params.n_threads_batch(), 8);
    /// ```
    #[must_use]
    pub fn with_n_threads_batch(mut self, n_threads: i32) -> Self {
        self.context_params.n_threads_batch = n_threads;
        self
    }

    /// Check whether embeddings are enabled.
    ///
    /// # Examples
    ///
    /// ```rust
    /// let params = llama_cpp_4::context::params::LlamaContextParams::default();
    /// assert!(!params.embeddings());
    /// ```
    #[must_use]
    pub fn embeddings(&self) -> bool {
        self.context_params.embeddings
    }

    /// Enable the use of embeddings.
    ///
    /// # Examples
    ///
    /// ```rust
    /// use llama_cpp_4::context::params::LlamaContextParams;
    /// let params = LlamaContextParams::default()
    ///     .with_embeddings(true);
    /// assert!(params.embeddings());
    /// ```
    #[must_use]
    pub fn with_embeddings(mut self, embedding: bool) -> Self {
        self.context_params.embeddings = embedding;
        self
    }

    /// Set the evaluation callback.
    ///
    /// # Examples
    ///
    /// ```no_run
    /// extern "C" fn cb_eval_fn(
    ///     t: *mut llama_cpp_sys_4::ggml_tensor,
    ///     ask: bool,
    ///     user_data: *mut std::ffi::c_void,
    /// ) -> bool {
    ///     false
    /// }
    ///
    /// use llama_cpp_4::context::params::LlamaContextParams;
    /// let params = LlamaContextParams::default().with_cb_eval(Some(cb_eval_fn));
    /// ```
    #[must_use]
    pub fn with_cb_eval(
        mut self,
        cb_eval: llama_cpp_sys_4::ggml_backend_sched_eval_callback,
    ) -> Self {
        self.context_params.cb_eval = cb_eval;
        self
    }

    /// Set the evaluation callback user data.
    ///
    /// # Examples
    ///
    /// ```no_run
    /// use llama_cpp_4::context::params::LlamaContextParams;
    /// let params = LlamaContextParams::default();
    /// let user_data = std::ptr::null_mut();
    /// let params = params.with_cb_eval_user_data(user_data);
    /// ```
    #[must_use]
    pub fn with_cb_eval_user_data(mut self, cb_eval_user_data: *mut std::ffi::c_void) -> Self {
        self.context_params.cb_eval_user_data = cb_eval_user_data;
        self
    }

    /// Attach a [`TensorCapture`](super::tensor_capture::TensorCapture) to
    /// intercept intermediate tensor outputs during `decode()`.
    ///
    /// This sets up the `cb_eval` callback to capture tensors matching the
    /// capture's filter (e.g. specific layer outputs). After `decode()` the
    /// captured data can be read from the `TensorCapture`.
    ///
    /// # Examples
    ///
    /// ```rust,ignore
    /// use llama_cpp_4::context::params::LlamaContextParams;
    /// use llama_cpp_4::context::tensor_capture::TensorCapture;
    ///
    /// let mut capture = TensorCapture::for_layers(&[13, 20, 27]);
    /// let ctx_params = LlamaContextParams::default()
    ///     .with_embeddings(true)
    ///     .with_tensor_capture(&mut capture);
    /// ```
    #[must_use]
    pub fn with_tensor_capture(self, capture: &mut super::tensor_capture::TensorCapture) -> Self {
        self.with_cb_eval(Some(super::tensor_capture::tensor_capture_callback))
            .with_cb_eval_user_data(
                std::ptr::from_mut::<super::tensor_capture::TensorCapture>(capture)
                    .cast::<std::ffi::c_void>(),
            )
    }

    /// Set the storage type for the **K** (key) KV cache tensors.
    ///
    /// The default is `GgmlType::F16`. Quantized types like `GgmlType::Q5_0`
    /// or `GgmlType::Q4_0` reduce VRAM usage significantly; combining them with
    /// `TurboQuant` attention rotation (the default) keeps quality high.
    ///
    /// # Examples
    ///
    /// ```rust
    /// use llama_cpp_4::context::params::LlamaContextParams;
    /// use llama_cpp_4::quantize::GgmlType;
    /// let params = LlamaContextParams::default()
    ///     .with_cache_type_k(GgmlType::Q5_0);
    /// ```
    #[must_use]
    pub fn with_cache_type_k(mut self, ty: crate::quantize::GgmlType) -> Self {
        self.context_params.type_k = ty as llama_cpp_sys_4::ggml_type;
        self
    }

    /// Get the K-cache storage type.
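    ///
    /// # Examples
    ///
    /// ```rust
    /// use llama_cpp_4::context::params::LlamaContextParams;
    /// use llama_cpp_4::quantize::GgmlType;
    /// let params = LlamaContextParams::default()
    ///     .with_cache_type_k(GgmlType::Q4_0);
    /// // The getter returns the raw `ggml_type` value, mirroring the cast done by the setter.
    /// assert_eq!(params.cache_type_k(), GgmlType::Q4_0 as llama_cpp_sys_4::ggml_type);
    /// ```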
    #[must_use]
    pub fn cache_type_k(&self) -> llama_cpp_sys_4::ggml_type {
        self.context_params.type_k
    }

    /// Set the storage type for the **V** (value) KV cache tensors.
    ///
    /// See [`with_cache_type_k`](Self::with_cache_type_k) for details.
    ///
    /// # Examples
    ///
    /// ```rust
    /// use llama_cpp_4::context::params::LlamaContextParams;
    /// use llama_cpp_4::quantize::GgmlType;
    /// let params = LlamaContextParams::default()
    ///     .with_cache_type_v(GgmlType::Q5_0);
    /// ```
    #[must_use]
    pub fn with_cache_type_v(mut self, ty: crate::quantize::GgmlType) -> Self {
        self.context_params.type_v = ty as llama_cpp_sys_4::ggml_type;
        self
    }

    /// Get the V-cache storage type.
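    ///
    /// # Examples
    ///
    /// ```rust
    /// use llama_cpp_4::context::params::LlamaContextParams;
    /// use llama_cpp_4::quantize::GgmlType;
    /// let params = LlamaContextParams::default()
    ///     .with_cache_type_v(GgmlType::Q5_0);
    /// // As with the K cache, the getter returns the raw `ggml_type` value.
    /// assert_eq!(params.cache_type_v(), GgmlType::Q5_0 as llama_cpp_sys_4::ggml_type);
    /// ```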
    #[must_use]
    pub fn cache_type_v(&self) -> llama_cpp_sys_4::ggml_type {
        self.context_params.type_v
    }

    /// Control the `TurboQuant` attention-rotation feature (llama.cpp PR #21038).
    ///
    /// By default, llama.cpp applies a Hadamard rotation to Q/K/V tensors
    /// before writing them into the KV cache. This significantly improves
    /// quantized KV-cache quality at near-zero overhead, and is enabled
    /// automatically for models whose head dimension is a power of two.
    ///
    /// Set `disabled = true` to opt out (equivalent to `LLAMA_ATTN_ROT_DISABLE=1`).
    /// The env-var is applied just before the context is created and restored
    /// afterwards, so this is only safe when contexts are created from a single thread.
    ///
    /// # Examples
    ///
    /// ```rust
    /// use llama_cpp_4::context::params::LlamaContextParams;
    /// // Disable rotation for this context only:
    /// let params = LlamaContextParams::default().with_attn_rot_disabled(true);
    /// assert!(params.attn_rot_disabled());
    /// ```
    #[must_use]
    pub fn with_attn_rot_disabled(mut self, disabled: bool) -> Self {
        self.attn_rot_disabled = disabled;
        self
    }

    /// Returns `true` if `TurboQuant` attention rotation is disabled for this context.
    ///
    /// # Examples
    ///
    /// ```rust
    /// let params = llama_cpp_4::context::params::LlamaContextParams::default();
    /// assert!(!params.attn_rot_disabled());
    /// ```
    #[must_use]
    pub fn attn_rot_disabled(&self) -> bool {
        self.attn_rot_disabled
    }

    /// Set the type of pooling.
    ///
    /// # Examples
    ///
    /// ```rust
    /// use llama_cpp_4::context::params::{LlamaContextParams, LlamaPoolingType};
    /// let params = LlamaContextParams::default()
    ///     .with_pooling_type(LlamaPoolingType::Last);
    /// assert_eq!(params.pooling_type(), LlamaPoolingType::Last);
    /// ```
    #[must_use]
    pub fn with_pooling_type(mut self, pooling_type: LlamaPoolingType) -> Self {
        self.context_params.pooling_type = i32::from(pooling_type);
        self
    }

    /// Get the type of pooling.
    ///
    /// # Examples
    ///
    /// ```rust
    /// let params = llama_cpp_4::context::params::LlamaContextParams::default();
    /// assert_eq!(params.pooling_type(), llama_cpp_4::context::params::LlamaPoolingType::Unspecified);
    /// ```
    #[must_use]
    pub fn pooling_type(&self) -> LlamaPoolingType {
        LlamaPoolingType::from(self.context_params.pooling_type)
    }
}

/// Default parameters for `LlamaContext`, as defined in llama.cpp by `llama_context_default_params`.
///
/// ```
/// # use std::num::NonZeroU32;
/// use llama_cpp_4::context::params::{LlamaContextParams, RopeScalingType};
/// let params = LlamaContextParams::default();
/// assert_eq!(params.n_ctx(), NonZeroU32::new(512), "n_ctx should be 512");
/// assert_eq!(params.rope_scaling_type(), RopeScalingType::Unspecified);
/// ```
impl Default for LlamaContextParams {
    fn default() -> Self {
        let context_params = unsafe { llama_cpp_sys_4::llama_context_default_params() };
        Self {
            context_params,
            attn_rot_disabled: false,
        }
    }
}