llama_cpp_4/context/params.rs
1//! A safe wrapper around `llama_context_params`.
2use std::fmt::Debug;
3use std::num::NonZeroU32;
4
/// A rusty wrapper around `rope_scaling_type`.
///
/// The explicit discriminants are the raw `i32` values exchanged with the C
/// API through the `From<i32>` / `From<RopeScalingType>` conversions in this
/// module; any unrecognized raw value maps to [`RopeScalingType::Unspecified`].
#[repr(i8)]
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum RopeScalingType {
    /// The scaling type is unspecified
    Unspecified = -1,
    /// No scaling
    None = 0,
    /// Linear scaling
    Linear = 1,
    /// Yarn scaling
    Yarn = 2,
}
18
19/// Create a `RopeScalingType` from a `c_int` - returns `RopeScalingType::ScalingUnspecified` if
20/// the value is not recognized.
21impl From<i32> for RopeScalingType {
22 fn from(value: i32) -> Self {
23 match value {
24 0 => Self::None,
25 1 => Self::Linear,
26 2 => Self::Yarn,
27 _ => Self::Unspecified,
28 }
29 }
30}
31
32/// Create a `c_int` from a `RopeScalingType`.
33impl From<RopeScalingType> for i32 {
34 fn from(value: RopeScalingType) -> Self {
35 match value {
36 RopeScalingType::None => 0,
37 RopeScalingType::Linear => 1,
38 RopeScalingType::Yarn => 2,
39 RopeScalingType::Unspecified => -1,
40 }
41 }
42}
43
/// A rusty wrapper around `LLAMA_POOLING_TYPE`.
///
/// The explicit discriminants are the raw `i32` values exchanged with the C
/// API through the `From<i32>` / `From<LlamaPoolingType>` conversions in this
/// module; any unrecognized raw value maps to [`LlamaPoolingType::Unspecified`].
#[repr(i8)]
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum LlamaPoolingType {
    /// The pooling type is unspecified
    Unspecified = -1,
    /// No pooling
    None = 0,
    /// Mean pooling
    Mean = 1,
    /// CLS pooling
    Cls = 2,
    /// Last pooling
    Last = 3,
}
59
60/// Create a `LlamaPoolingType` from a `c_int` - returns `LlamaPoolingType::Unspecified` if
61/// the value is not recognized.
62impl From<i32> for LlamaPoolingType {
63 fn from(value: i32) -> Self {
64 match value {
65 0 => Self::None,
66 1 => Self::Mean,
67 2 => Self::Cls,
68 3 => Self::Last,
69 _ => Self::Unspecified,
70 }
71 }
72}
73
74/// Create a `c_int` from a `LlamaPoolingType`.
75impl From<LlamaPoolingType> for i32 {
76 fn from(value: LlamaPoolingType) -> Self {
77 match value {
78 LlamaPoolingType::None => 0,
79 LlamaPoolingType::Mean => 1,
80 LlamaPoolingType::Cls => 2,
81 LlamaPoolingType::Last => 3,
82 LlamaPoolingType::Unspecified => -1,
83 }
84 }
85}
86
/// A safe wrapper around `llama_context_params`.
///
/// Generally this should be created with [`Default::default()`] and then modified with `with_*` methods.
///
/// # Examples
///
/// ```rust
/// # use std::num::NonZeroU32;
/// use llama_cpp_4::context::params::LlamaContextParams;
///
/// let ctx_params = LlamaContextParams::default()
///     .with_n_ctx(NonZeroU32::new(2048));
///
/// assert_eq!(ctx_params.n_ctx(), NonZeroU32::new(2048));
/// ```
#[derive(Debug, Clone)]
#[allow(
    missing_docs,
    clippy::struct_excessive_bools,
    clippy::module_name_repetitions
)]
pub struct LlamaContextParams {
    /// The raw sys-crate parameter struct that every `with_*` setter and
    /// getter in this module reads from or writes to.
    pub(crate) context_params: llama_cpp_sys_4::llama_context_params,
    /// When `true`, the TurboQuant attention rotation (PR #21038) will be
    /// disabled for any context created from these params.
    pub(crate) attn_rot_disabled: bool,
}
114
// SAFETY: we do not currently allow setting or reading the pointers that cause
// this to not be automatically send or sync.
// NOTE(review): `with_cb_eval_user_data` below *does* store a raw
// `*mut c_void` — confirm the claim above still holds (i.e. that the
// pointed-to data is itself thread-safe) before relying on these impls.
unsafe impl Send for LlamaContextParams {}
unsafe impl Sync for LlamaContextParams {}
118
impl LlamaContextParams {
    /// Set the size of the context.
    ///
    /// # Examples
    ///
    /// ```rust
    /// # use std::num::NonZeroU32;
    /// use llama_cpp_4::context::params::LlamaContextParams;
    /// let params = LlamaContextParams::default();
    /// let params = params.with_n_ctx(NonZeroU32::new(2048));
    /// assert_eq!(params.n_ctx(), NonZeroU32::new(2048));
    /// ```
    #[must_use]
    pub fn with_n_ctx(mut self, n_ctx: Option<NonZeroU32>) -> Self {
        // `None` is stored as 0; per the getter below, 0 means "the context
        // size comes from the model rather than from these params".
        self.context_params.n_ctx = n_ctx.map_or(0, std::num::NonZeroU32::get);
        self
    }

    /// Get the size of the context.
    ///
    /// [`None`] if the context size is specified by the model and not the context.
    ///
    /// # Examples
    ///
    /// ```rust
    /// let params = llama_cpp_4::context::params::LlamaContextParams::default();
    /// assert_eq!(params.n_ctx(), std::num::NonZeroU32::new(512));
    /// ```
    #[must_use]
    pub fn n_ctx(&self) -> Option<NonZeroU32> {
        NonZeroU32::new(self.context_params.n_ctx)
    }

    /// Set the `n_batch` (logical maximum batch size submitted to decode).
    ///
    /// # Examples
    ///
    /// ```rust
    /// # use std::num::NonZeroU32;
    /// use llama_cpp_4::context::params::LlamaContextParams;
    /// let params = LlamaContextParams::default()
    ///     .with_n_batch(2048);
    /// assert_eq!(params.n_batch(), 2048);
    /// ```
    #[must_use]
    pub fn with_n_batch(mut self, n_batch: u32) -> Self {
        self.context_params.n_batch = n_batch;
        self
    }

    /// Get the `n_batch`.
    ///
    /// # Examples
    ///
    /// ```rust
    /// use llama_cpp_4::context::params::LlamaContextParams;
    /// let params = LlamaContextParams::default();
    /// assert_eq!(params.n_batch(), 2048);
    /// ```
    #[must_use]
    pub fn n_batch(&self) -> u32 {
        self.context_params.n_batch
    }

    /// Set the `n_ubatch` (physical micro-batch size).
    ///
    /// # Examples
    ///
    /// ```rust
    /// # use std::num::NonZeroU32;
    /// use llama_cpp_4::context::params::LlamaContextParams;
    /// let params = LlamaContextParams::default()
    ///     .with_n_ubatch(512);
    /// assert_eq!(params.n_ubatch(), 512);
    /// ```
    #[must_use]
    pub fn with_n_ubatch(mut self, n_ubatch: u32) -> Self {
        self.context_params.n_ubatch = n_ubatch;
        self
    }

    /// Get the `n_ubatch`.
    ///
    /// # Examples
    ///
    /// ```rust
    /// use llama_cpp_4::context::params::LlamaContextParams;
    /// let params = LlamaContextParams::default();
    /// assert_eq!(params.n_ubatch(), 512);
    /// ```
    #[must_use]
    pub fn n_ubatch(&self) -> u32 {
        self.context_params.n_ubatch
    }

    /// Set the `flash_attention` parameter.
    ///
    /// # Examples
    ///
    /// ```rust
    /// use llama_cpp_4::context::params::LlamaContextParams;
    /// let params = LlamaContextParams::default()
    ///     .with_flash_attention(true);
    /// assert_eq!(params.flash_attention(), true);
    /// ```
    #[must_use]
    pub fn with_flash_attention(mut self, enabled: bool) -> Self {
        // The underlying field is an enum (disabled/enabled/auto); this
        // boolean setter forces one of the two explicit states and never
        // selects "auto".
        self.context_params.flash_attn_type = if enabled {
            llama_cpp_sys_4::LLAMA_FLASH_ATTN_TYPE_ENABLED
        } else {
            llama_cpp_sys_4::LLAMA_FLASH_ATTN_TYPE_DISABLED
        };
        self
    }

    /// Get the `flash_attention` parameter.
    ///
    /// # Examples
    ///
    /// ```rust
    /// use llama_cpp_4::context::params::LlamaContextParams;
    /// let params = LlamaContextParams::default();
    /// assert_eq!(params.flash_attention(), false);
    /// ```
    #[must_use]
    pub fn flash_attention(&self) -> bool {
        // Only the explicit ENABLED state reports `true`; any other state
        // (disabled or auto) reports `false`.
        self.context_params.flash_attn_type == llama_cpp_sys_4::LLAMA_FLASH_ATTN_TYPE_ENABLED
    }

    /// Set the `offload_kqv` parameter to control offloading KV cache & KQV ops to GPU.
    ///
    /// # Examples
    ///
    /// ```rust
    /// use llama_cpp_4::context::params::LlamaContextParams;
    /// let params = LlamaContextParams::default()
    ///     .with_offload_kqv(false);
    /// assert_eq!(params.offload_kqv(), false);
    /// ```
    #[must_use]
    pub fn with_offload_kqv(mut self, enabled: bool) -> Self {
        self.context_params.offload_kqv = enabled;
        self
    }

    /// Get the `offload_kqv` parameter.
    ///
    /// # Examples
    ///
    /// ```rust
    /// use llama_cpp_4::context::params::LlamaContextParams;
    /// let params = LlamaContextParams::default();
    /// assert_eq!(params.offload_kqv(), true);
    /// ```
    #[must_use]
    pub fn offload_kqv(&self) -> bool {
        self.context_params.offload_kqv
    }

    /// Set the type of rope scaling.
    ///
    /// # Examples
    ///
    /// ```rust
    /// use llama_cpp_4::context::params::{LlamaContextParams, RopeScalingType};
    /// let params = LlamaContextParams::default()
    ///     .with_rope_scaling_type(RopeScalingType::Linear);
    /// assert_eq!(params.rope_scaling_type(), RopeScalingType::Linear);
    /// ```
    #[must_use]
    pub fn with_rope_scaling_type(mut self, rope_scaling_type: RopeScalingType) -> Self {
        self.context_params.rope_scaling_type = i32::from(rope_scaling_type);
        self
    }

    /// Get the type of rope scaling.
    ///
    /// # Examples
    ///
    /// ```rust
    /// let params = llama_cpp_4::context::params::LlamaContextParams::default();
    /// assert_eq!(params.rope_scaling_type(), llama_cpp_4::context::params::RopeScalingType::Unspecified);
    /// ```
    #[must_use]
    pub fn rope_scaling_type(&self) -> RopeScalingType {
        RopeScalingType::from(self.context_params.rope_scaling_type)
    }

    /// Set the rope frequency base.
    ///
    /// # Examples
    ///
    /// ```rust
    /// use llama_cpp_4::context::params::LlamaContextParams;
    /// let params = LlamaContextParams::default()
    ///     .with_rope_freq_base(0.5);
    /// assert_eq!(params.rope_freq_base(), 0.5);
    /// ```
    #[must_use]
    pub fn with_rope_freq_base(mut self, rope_freq_base: f32) -> Self {
        self.context_params.rope_freq_base = rope_freq_base;
        self
    }

    /// Get the rope frequency base.
    ///
    /// # Examples
    ///
    /// ```rust
    /// let params = llama_cpp_4::context::params::LlamaContextParams::default();
    /// assert_eq!(params.rope_freq_base(), 0.0);
    /// ```
    #[must_use]
    pub fn rope_freq_base(&self) -> f32 {
        self.context_params.rope_freq_base
    }

    /// Set the rope frequency scale.
    ///
    /// # Examples
    ///
    /// ```rust
    /// use llama_cpp_4::context::params::LlamaContextParams;
    /// let params = LlamaContextParams::default()
    ///     .with_rope_freq_scale(0.5);
    /// assert_eq!(params.rope_freq_scale(), 0.5);
    /// ```
    #[must_use]
    pub fn with_rope_freq_scale(mut self, rope_freq_scale: f32) -> Self {
        self.context_params.rope_freq_scale = rope_freq_scale;
        self
    }

    /// Get the rope frequency scale.
    ///
    /// # Examples
    ///
    /// ```rust
    /// let params = llama_cpp_4::context::params::LlamaContextParams::default();
    /// assert_eq!(params.rope_freq_scale(), 0.0);
    /// ```
    #[must_use]
    pub fn rope_freq_scale(&self) -> f32 {
        self.context_params.rope_freq_scale
    }

    /// Get the number of threads.
    ///
    /// # Examples
    ///
    /// ```rust
    /// let params = llama_cpp_4::context::params::LlamaContextParams::default();
    /// assert_eq!(params.n_threads(), 4);
    /// ```
    #[must_use]
    pub fn n_threads(&self) -> i32 {
        self.context_params.n_threads
    }

    /// Get the number of threads allocated for batches.
    ///
    /// # Examples
    ///
    /// ```rust
    /// let params = llama_cpp_4::context::params::LlamaContextParams::default();
    /// assert_eq!(params.n_threads_batch(), 4);
    /// ```
    #[must_use]
    pub fn n_threads_batch(&self) -> i32 {
        self.context_params.n_threads_batch
    }

    /// Set the number of threads.
    ///
    /// # Examples
    ///
    /// ```rust
    /// use llama_cpp_4::context::params::LlamaContextParams;
    /// let params = LlamaContextParams::default()
    ///     .with_n_threads(8);
    /// assert_eq!(params.n_threads(), 8);
    /// ```
    #[must_use]
    pub fn with_n_threads(mut self, n_threads: i32) -> Self {
        self.context_params.n_threads = n_threads;
        self
    }

    /// Set the number of threads allocated for batches.
    ///
    /// # Examples
    ///
    /// ```rust
    /// use llama_cpp_4::context::params::LlamaContextParams;
    /// let params = LlamaContextParams::default()
    ///     .with_n_threads_batch(8);
    /// assert_eq!(params.n_threads_batch(), 8);
    /// ```
    #[must_use]
    pub fn with_n_threads_batch(mut self, n_threads: i32) -> Self {
        self.context_params.n_threads_batch = n_threads;
        self
    }

    /// Check whether embeddings are enabled.
    ///
    /// # Examples
    ///
    /// ```rust
    /// let params = llama_cpp_4::context::params::LlamaContextParams::default();
    /// assert!(!params.embeddings());
    /// ```
    #[must_use]
    pub fn embeddings(&self) -> bool {
        self.context_params.embeddings
    }

    /// Enable the use of embeddings.
    ///
    /// # Examples
    ///
    /// ```rust
    /// use llama_cpp_4::context::params::LlamaContextParams;
    /// let params = LlamaContextParams::default()
    ///     .with_embeddings(true);
    /// assert!(params.embeddings());
    /// ```
    #[must_use]
    pub fn with_embeddings(mut self, embedding: bool) -> Self {
        self.context_params.embeddings = embedding;
        self
    }

    /// Set the evaluation callback.
    ///
    /// # Examples
    ///
    /// ```no_run
    /// extern "C" fn cb_eval_fn(
    ///     t: *mut llama_cpp_sys_4::ggml_tensor,
    ///     ask: bool,
    ///     user_data: *mut std::ffi::c_void,
    /// ) -> bool {
    ///     false
    /// }
    ///
    /// use llama_cpp_4::context::params::LlamaContextParams;
    /// let params = LlamaContextParams::default().with_cb_eval(Some(cb_eval_fn));
    /// ```
    #[must_use]
    pub fn with_cb_eval(
        mut self,
        cb_eval: llama_cpp_sys_4::ggml_backend_sched_eval_callback,
    ) -> Self {
        self.context_params.cb_eval = cb_eval;
        self
    }

    /// Set the evaluation callback user data.
    ///
    /// # Examples
    ///
    /// ```no_run
    /// use llama_cpp_4::context::params::LlamaContextParams;
    /// let params = LlamaContextParams::default();
    /// let user_data = std::ptr::null_mut();
    /// let params = params.with_cb_eval_user_data(user_data);
    /// ```
    #[must_use]
    pub fn with_cb_eval_user_data(mut self, cb_eval_user_data: *mut std::ffi::c_void) -> Self {
        // NOTE(review): this stores a raw pointer while the type is declared
        // `Send`/`Sync` above — callers must ensure the pointed-to data is
        // safe to share across threads.
        self.context_params.cb_eval_user_data = cb_eval_user_data;
        self
    }

    /// Attach a [`TensorCapture`](super::tensor_capture::TensorCapture) to
    /// intercept intermediate tensor outputs during `decode()`.
    ///
    /// This sets up the `cb_eval` callback to capture tensors matching the
    /// capture's filter (e.g. specific layer outputs). After `decode()` the
    /// captured data can be read from the `TensorCapture`.
    ///
    /// # Example
    ///
    /// ```rust,ignore
    /// use llama_cpp_4::context::params::LlamaContextParams;
    /// use llama_cpp_4::context::tensor_capture::TensorCapture;
    ///
    /// let mut capture = TensorCapture::for_layers(&[13, 20, 27]);
    /// let ctx_params = LlamaContextParams::default()
    ///     .with_embeddings(true)
    ///     .with_tensor_capture(&mut capture);
    /// ```
    #[must_use]
    pub fn with_tensor_capture(
        self,
        capture: &mut super::tensor_capture::TensorCapture,
    ) -> Self {
        // Wires the capture in as the eval callback plus its user-data
        // pointer; `capture` must stay alive for as long as these params (and
        // any context built from them) can invoke the callback.
        self.with_cb_eval(Some(super::tensor_capture::tensor_capture_callback))
            .with_cb_eval_user_data(
                capture as *mut super::tensor_capture::TensorCapture as *mut std::ffi::c_void,
            )
    }

    /// Set the storage type for the **K** (key) KV cache tensors.
    ///
    /// The default is `GgmlType::F16`. Quantized types like `GgmlType::Q5_0`
    /// or `GgmlType::Q4_0` reduce VRAM usage significantly; combining them with
    /// TurboQuant attention rotation (the default) keeps quality high.
    ///
    /// # Examples
    ///
    /// ```rust
    /// use llama_cpp_4::context::params::LlamaContextParams;
    /// use llama_cpp_4::quantize::GgmlType;
    /// let params = LlamaContextParams::default()
    ///     .with_cache_type_k(GgmlType::Q5_0);
    /// ```
    #[must_use]
    pub fn with_cache_type_k(mut self, ty: crate::quantize::GgmlType) -> Self {
        self.context_params.type_k = ty as llama_cpp_sys_4::ggml_type;
        self
    }

    /// Get the K-cache storage type (as the raw `ggml_type` value).
    #[must_use]
    pub fn cache_type_k(&self) -> llama_cpp_sys_4::ggml_type {
        self.context_params.type_k
    }

    /// Set the storage type for the **V** (value) KV cache tensors.
    ///
    /// See [`with_cache_type_k`](Self::with_cache_type_k) for details.
    ///
    /// # Examples
    ///
    /// ```rust
    /// use llama_cpp_4::context::params::LlamaContextParams;
    /// use llama_cpp_4::quantize::GgmlType;
    /// let params = LlamaContextParams::default()
    ///     .with_cache_type_v(GgmlType::Q5_0);
    /// ```
    #[must_use]
    pub fn with_cache_type_v(mut self, ty: crate::quantize::GgmlType) -> Self {
        self.context_params.type_v = ty as llama_cpp_sys_4::ggml_type;
        self
    }

    /// Get the V-cache storage type (as the raw `ggml_type` value).
    #[must_use]
    pub fn cache_type_v(&self) -> llama_cpp_sys_4::ggml_type {
        self.context_params.type_v
    }

    /// Control the TurboQuant attention-rotation feature (llama.cpp PR #21038).
    ///
    /// By default, llama.cpp applies a Hadamard rotation to Q/K/V tensors
    /// before writing them into the KV cache. This significantly improves
    /// quantized KV-cache quality at near-zero overhead, and is enabled
    /// automatically for models whose head dimension is a power of two.
    ///
    /// Set `disabled = true` to opt out (equivalent to `LLAMA_ATTN_ROT_DISABLE=1`).
    /// The env-var is applied just before the context is created and restored
    /// afterwards, so this is safe to call from a single thread.
    ///
    /// # Examples
    ///
    /// ```rust
    /// use llama_cpp_4::context::params::LlamaContextParams;
    /// // Disable rotation for this context only:
    /// let params = LlamaContextParams::default().with_attn_rot_disabled(true);
    /// assert!(params.attn_rot_disabled());
    /// ```
    #[must_use]
    pub fn with_attn_rot_disabled(mut self, disabled: bool) -> Self {
        // Stored on the wrapper (not in the sys struct); presumably consumed
        // by the context-creation code — see the doc comment above.
        self.attn_rot_disabled = disabled;
        self
    }

    /// Returns `true` if TurboQuant attention rotation is disabled for this context.
    ///
    /// ```rust
    /// let params = llama_cpp_4::context::params::LlamaContextParams::default();
    /// assert!(!params.attn_rot_disabled());
    /// ```
    #[must_use]
    pub fn attn_rot_disabled(&self) -> bool {
        self.attn_rot_disabled
    }

    /// Set the type of pooling.
    ///
    /// # Examples
    ///
    /// ```rust
    /// use llama_cpp_4::context::params::{LlamaContextParams, LlamaPoolingType};
    /// let params = LlamaContextParams::default()
    ///     .with_pooling_type(LlamaPoolingType::Last);
    /// assert_eq!(params.pooling_type(), LlamaPoolingType::Last);
    /// ```
    #[must_use]
    pub fn with_pooling_type(mut self, pooling_type: LlamaPoolingType) -> Self {
        self.context_params.pooling_type = i32::from(pooling_type);
        self
    }

    /// Get the type of pooling.
    ///
    /// # Examples
    ///
    /// ```rust
    /// let params = llama_cpp_4::context::params::LlamaContextParams::default();
    /// assert_eq!(params.pooling_type(), llama_cpp_4::context::params::LlamaPoolingType::Unspecified);
    /// ```
    #[must_use]
    pub fn pooling_type(&self) -> LlamaPoolingType {
        LlamaPoolingType::from(self.context_params.pooling_type)
    }
}
636
637/// Default parameters for `LlamaContext`. (as defined in llama.cpp by `llama_context_default_params`)
638/// ```
639/// # use std::num::NonZeroU32;
640/// use llama_cpp_4::context::params::{LlamaContextParams, RopeScalingType};
641/// let params = LlamaContextParams::default();
642/// assert_eq!(params.n_ctx(), NonZeroU32::new(512), "n_ctx should be 512");
643/// assert_eq!(params.rope_scaling_type(), RopeScalingType::Unspecified);
644/// ```
645impl Default for LlamaContextParams {
646 fn default() -> Self {
647 let context_params = unsafe { llama_cpp_sys_4::llama_context_default_params() };
648 Self {
649 context_params,
650 attn_rot_disabled: false,
651 }
652 }
653}