// llama_cpp_4/context/params.rs
1//! A safe wrapper around `llama_context_params`.
2use std::fmt::Debug;
3use std::num::NonZeroU32;
4
/// A rusty wrapper around `rope_scaling_type`.
#[repr(i8)]
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum RopeScalingType {
    /// The scaling type is unspecified
    Unspecified = -1,
    /// No scaling
    None = 0,
    /// Linear scaling
    Linear = 1,
    /// Yarn scaling
    Yarn = 2,
}

/// Create a `RopeScalingType` from a `c_int` - returns [`RopeScalingType::Unspecified`] if
/// the value is not recognized.
impl From<i32> for RopeScalingType {
    fn from(value: i32) -> Self {
        match value {
            0 => Self::None,
            1 => Self::Linear,
            2 => Self::Yarn,
            // Any out-of-range value from the C API is treated as "unspecified"
            // rather than panicking.
            _ => Self::Unspecified,
        }
    }
}

/// Create a `c_int` from a `RopeScalingType`.
impl From<RopeScalingType> for i32 {
    fn from(value: RopeScalingType) -> Self {
        match value {
            RopeScalingType::None => 0,
            RopeScalingType::Linear => 1,
            RopeScalingType::Yarn => 2,
            RopeScalingType::Unspecified => -1,
        }
    }
}
43
/// A rusty wrapper around `LLAMA_POOLING_TYPE`.
#[repr(i8)]
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum LlamaPoolingType {
    /// The pooling type is unspecified
    Unspecified = -1,
    /// No pooling
    None = 0,
    /// Mean pooling
    Mean = 1,
    /// CLS pooling
    Cls = 2,
    /// Last pooling
    Last = 3,
}

/// Convert a raw `c_int` into a `LlamaPoolingType`; any value outside the
/// known range maps to [`LlamaPoolingType::Unspecified`].
impl From<i32> for LlamaPoolingType {
    fn from(value: i32) -> Self {
        match value {
            3 => Self::Last,
            2 => Self::Cls,
            1 => Self::Mean,
            0 => Self::None,
            _ => Self::Unspecified,
        }
    }
}

/// Convert a `LlamaPoolingType` back into the raw `c_int` the C API expects.
impl From<LlamaPoolingType> for i32 {
    fn from(value: LlamaPoolingType) -> Self {
        // The enum is `#[repr(i8)]` with explicit discriminants that mirror the
        // C values exactly, so a plain cast produces the correct integer
        // (including `-1` for `Unspecified`).
        value as i32
    }
}
86
87/// A safe wrapper around `llama_context_params`.
88///
89/// Generally this should be created with [`Default::default()`] and then modified with `with_*` methods.
90///
91/// # Examples
92///
93/// ```rust
94/// # use std::num::NonZeroU32;
95/// use llama_cpp_4::context::params::LlamaContextParams;
96///
97///let ctx_params = LlamaContextParams::default()
98/// .with_n_ctx(NonZeroU32::new(2048))
99/// .with_seed(1234);
100///
101/// assert_eq!(ctx_params.seed(), 1234);
102/// assert_eq!(ctx_params.n_ctx(), NonZeroU32::new(2048));
103/// ```
104#[derive(Debug, Clone)]
105#[allow(
106 missing_docs,
107 clippy::struct_excessive_bools,
108 clippy::module_name_repetitions
109)]
110pub struct LlamaContextParams {
111 pub(crate) context_params: llama_cpp_sys_4::llama_context_params,
112 /// When `true`, the TurboQuant attention rotation (PR #21038) will be
113 /// disabled for any context created from these params.
114 pub(crate) attn_rot_disabled: bool,
115}
116
// SAFETY: the wrapped `llama_context_params` holds raw pointers (e.g. the
// `cb_eval_user_data` pointer set below), which is why `Send`/`Sync` are not
// derived automatically. We do not currently expose APIs that let safe code
// set or read those pointers in a way that would make cross-thread use of the
// params struct unsound, so declaring `Send` + `Sync` is acceptable.
unsafe impl Send for LlamaContextParams {}
unsafe impl Sync for LlamaContextParams {}
120
impl LlamaContextParams {
    /// Set the size of the context
    ///
    /// # Examples
    ///
    /// ```rust
    /// # use std::num::NonZeroU32;
    /// use llama_cpp_4::context::params::LlamaContextParams;
    /// let params = LlamaContextParams::default();
    /// let params = params.with_n_ctx(NonZeroU32::new(2048));
    /// assert_eq!(params.n_ctx(), NonZeroU32::new(2048));
    /// ```
    #[must_use]
    pub fn with_n_ctx(mut self, n_ctx: Option<NonZeroU32>) -> Self {
        // `None` is stored as 0, which tells llama.cpp to take the context
        // size from the model instead.
        self.context_params.n_ctx = n_ctx.map_or(0, std::num::NonZeroU32::get);
        self
    }

    /// Get the size of the context.
    ///
    /// [`None`] if the context size is specified by the model and not the context.
    ///
    /// # Examples
    ///
    /// ```rust
    /// let params = llama_cpp_4::context::params::LlamaContextParams::default();
    /// assert_eq!(params.n_ctx(), std::num::NonZeroU32::new(512));
    /// ```
    #[must_use]
    pub fn n_ctx(&self) -> Option<NonZeroU32> {
        NonZeroU32::new(self.context_params.n_ctx)
    }

    /// Set the `n_batch`
    ///
    /// # Examples
    ///
    /// ```rust
    /// # use std::num::NonZeroU32;
    /// use llama_cpp_4::context::params::LlamaContextParams;
    /// let params = LlamaContextParams::default()
    ///     .with_n_batch(2048);
    /// assert_eq!(params.n_batch(), 2048);
    /// ```
    #[must_use]
    pub fn with_n_batch(mut self, n_batch: u32) -> Self {
        self.context_params.n_batch = n_batch;
        self
    }

    /// Get the `n_batch`
    ///
    /// # Examples
    ///
    /// ```rust
    /// use llama_cpp_4::context::params::LlamaContextParams;
    /// let params = LlamaContextParams::default();
    /// assert_eq!(params.n_batch(), 2048);
    /// ```
    #[must_use]
    pub fn n_batch(&self) -> u32 {
        self.context_params.n_batch
    }

    /// Set the `n_ubatch`
    ///
    /// # Examples
    ///
    /// ```rust
    /// # use std::num::NonZeroU32;
    /// use llama_cpp_4::context::params::LlamaContextParams;
    /// let params = LlamaContextParams::default()
    ///     .with_n_ubatch(512);
    /// assert_eq!(params.n_ubatch(), 512);
    /// ```
    #[must_use]
    pub fn with_n_ubatch(mut self, n_ubatch: u32) -> Self {
        self.context_params.n_ubatch = n_ubatch;
        self
    }

    /// Get the `n_ubatch`
    ///
    /// # Examples
    ///
    /// ```rust
    /// use llama_cpp_4::context::params::LlamaContextParams;
    /// let params = LlamaContextParams::default();
    /// assert_eq!(params.n_ubatch(), 512);
    /// ```
    #[must_use]
    pub fn n_ubatch(&self) -> u32 {
        self.context_params.n_ubatch
    }

    /// Set the `flash_attention` parameter
    ///
    /// # Examples
    ///
    /// ```rust
    /// use llama_cpp_4::context::params::LlamaContextParams;
    /// let params = LlamaContextParams::default()
    ///     .with_flash_attention(true);
    /// assert_eq!(params.flash_attention(), true);
    /// ```
    #[must_use]
    pub fn with_flash_attention(mut self, enabled: bool) -> Self {
        self.context_params.flash_attn_type = if enabled {
            llama_cpp_sys_4::LLAMA_FLASH_ATTN_TYPE_ENABLED
        } else {
            llama_cpp_sys_4::LLAMA_FLASH_ATTN_TYPE_DISABLED
        };
        self
    }

    /// Get the `flash_attention` parameter
    ///
    /// Note: this reports `true` only when `flash_attn_type` equals
    /// `LLAMA_FLASH_ATTN_TYPE_ENABLED`; any other value reads back as `false`.
    ///
    /// # Examples
    ///
    /// ```rust
    /// use llama_cpp_4::context::params::LlamaContextParams;
    /// let params = LlamaContextParams::default();
    /// assert_eq!(params.flash_attention(), false);
    /// ```
    #[must_use]
    pub fn flash_attention(&self) -> bool {
        self.context_params.flash_attn_type == llama_cpp_sys_4::LLAMA_FLASH_ATTN_TYPE_ENABLED
    }

    /// Set the `offload_kqv` parameter to control offloading KV cache & KQV ops to GPU
    ///
    /// # Examples
    ///
    /// ```rust
    /// use llama_cpp_4::context::params::LlamaContextParams;
    /// let params = LlamaContextParams::default()
    ///     .with_offload_kqv(false);
    /// assert_eq!(params.offload_kqv(), false);
    /// ```
    #[must_use]
    pub fn with_offload_kqv(mut self, enabled: bool) -> Self {
        self.context_params.offload_kqv = enabled;
        self
    }

    /// Get the `offload_kqv` parameter
    ///
    /// # Examples
    ///
    /// ```rust
    /// use llama_cpp_4::context::params::LlamaContextParams;
    /// let params = LlamaContextParams::default();
    /// assert_eq!(params.offload_kqv(), true);
    /// ```
    #[must_use]
    pub fn offload_kqv(&self) -> bool {
        self.context_params.offload_kqv
    }

    /// Set the type of rope scaling.
    ///
    /// # Examples
    ///
    /// ```rust
    /// use llama_cpp_4::context::params::{LlamaContextParams, RopeScalingType};
    /// let params = LlamaContextParams::default()
    ///     .with_rope_scaling_type(RopeScalingType::Linear);
    /// assert_eq!(params.rope_scaling_type(), RopeScalingType::Linear);
    /// ```
    #[must_use]
    pub fn with_rope_scaling_type(mut self, rope_scaling_type: RopeScalingType) -> Self {
        self.context_params.rope_scaling_type = i32::from(rope_scaling_type);
        self
    }

    /// Get the type of rope scaling.
    ///
    /// # Examples
    ///
    /// ```rust
    /// let params = llama_cpp_4::context::params::LlamaContextParams::default();
    /// assert_eq!(params.rope_scaling_type(), llama_cpp_4::context::params::RopeScalingType::Unspecified);
    /// ```
    #[must_use]
    pub fn rope_scaling_type(&self) -> RopeScalingType {
        RopeScalingType::from(self.context_params.rope_scaling_type)
    }

    /// Set the rope frequency base.
    ///
    /// # Examples
    ///
    /// ```rust
    /// use llama_cpp_4::context::params::LlamaContextParams;
    /// let params = LlamaContextParams::default()
    ///     .with_rope_freq_base(0.5);
    /// assert_eq!(params.rope_freq_base(), 0.5);
    /// ```
    #[must_use]
    pub fn with_rope_freq_base(mut self, rope_freq_base: f32) -> Self {
        self.context_params.rope_freq_base = rope_freq_base;
        self
    }

    /// Get the rope frequency base.
    ///
    /// # Examples
    ///
    /// ```rust
    /// let params = llama_cpp_4::context::params::LlamaContextParams::default();
    /// assert_eq!(params.rope_freq_base(), 0.0);
    /// ```
    #[must_use]
    pub fn rope_freq_base(&self) -> f32 {
        self.context_params.rope_freq_base
    }

    /// Set the rope frequency scale.
    ///
    /// # Examples
    ///
    /// ```rust
    /// use llama_cpp_4::context::params::LlamaContextParams;
    /// let params = LlamaContextParams::default()
    ///     .with_rope_freq_scale(0.5);
    /// assert_eq!(params.rope_freq_scale(), 0.5);
    /// ```
    #[must_use]
    pub fn with_rope_freq_scale(mut self, rope_freq_scale: f32) -> Self {
        self.context_params.rope_freq_scale = rope_freq_scale;
        self
    }

    /// Get the rope frequency scale.
    ///
    /// # Examples
    ///
    /// ```rust
    /// let params = llama_cpp_4::context::params::LlamaContextParams::default();
    /// assert_eq!(params.rope_freq_scale(), 0.0);
    /// ```
    #[must_use]
    pub fn rope_freq_scale(&self) -> f32 {
        self.context_params.rope_freq_scale
    }

    /// Get the number of threads.
    ///
    /// # Examples
    ///
    /// ```rust
    /// let params = llama_cpp_4::context::params::LlamaContextParams::default();
    /// assert_eq!(params.n_threads(), 4);
    /// ```
    #[must_use]
    pub fn n_threads(&self) -> i32 {
        self.context_params.n_threads
    }

    /// Get the number of threads allocated for batches.
    ///
    /// # Examples
    ///
    /// ```rust
    /// let params = llama_cpp_4::context::params::LlamaContextParams::default();
    /// assert_eq!(params.n_threads_batch(), 4);
    /// ```
    #[must_use]
    pub fn n_threads_batch(&self) -> i32 {
        self.context_params.n_threads_batch
    }

    /// Set the number of threads.
    ///
    /// # Examples
    ///
    /// ```rust
    /// use llama_cpp_4::context::params::LlamaContextParams;
    /// let params = LlamaContextParams::default()
    ///     .with_n_threads(8);
    /// assert_eq!(params.n_threads(), 8);
    /// ```
    #[must_use]
    pub fn with_n_threads(mut self, n_threads: i32) -> Self {
        self.context_params.n_threads = n_threads;
        self
    }

    /// Set the number of threads allocated for batches.
    ///
    /// # Examples
    ///
    /// ```rust
    /// use llama_cpp_4::context::params::LlamaContextParams;
    /// let params = LlamaContextParams::default()
    ///     .with_n_threads_batch(8);
    /// assert_eq!(params.n_threads_batch(), 8);
    /// ```
    #[must_use]
    pub fn with_n_threads_batch(mut self, n_threads: i32) -> Self {
        self.context_params.n_threads_batch = n_threads;
        self
    }

    /// Check whether embeddings are enabled
    ///
    /// # Examples
    ///
    /// ```rust
    /// let params = llama_cpp_4::context::params::LlamaContextParams::default();
    /// assert!(!params.embeddings());
    /// ```
    #[must_use]
    pub fn embeddings(&self) -> bool {
        self.context_params.embeddings
    }

    /// Enable the use of embeddings
    ///
    /// # Examples
    ///
    /// ```rust
    /// use llama_cpp_4::context::params::LlamaContextParams;
    /// let params = LlamaContextParams::default()
    ///     .with_embeddings(true);
    /// assert!(params.embeddings());
    /// ```
    #[must_use]
    pub fn with_embeddings(mut self, embedding: bool) -> Self {
        self.context_params.embeddings = embedding;
        self
    }

    /// Set the evaluation callback.
    ///
    /// # Examples
    ///
    /// ```no_run
    /// extern "C" fn cb_eval_fn(
    ///     t: *mut llama_cpp_sys_4::ggml_tensor,
    ///     ask: bool,
    ///     user_data: *mut std::ffi::c_void,
    /// ) -> bool {
    ///     false
    /// }
    ///
    /// use llama_cpp_4::context::params::LlamaContextParams;
    /// let params = LlamaContextParams::default().with_cb_eval(Some(cb_eval_fn));
    /// ```
    #[must_use]
    pub fn with_cb_eval(
        mut self,
        cb_eval: llama_cpp_sys_4::ggml_backend_sched_eval_callback,
    ) -> Self {
        self.context_params.cb_eval = cb_eval;
        self
    }

    /// Set the evaluation callback user data.
    ///
    /// # Examples
    ///
    /// ```no_run
    /// use llama_cpp_4::context::params::LlamaContextParams;
    /// let params = LlamaContextParams::default();
    /// let user_data = std::ptr::null_mut();
    /// let params = params.with_cb_eval_user_data(user_data);
    /// ```
    #[must_use]
    pub fn with_cb_eval_user_data(mut self, cb_eval_user_data: *mut std::ffi::c_void) -> Self {
        self.context_params.cb_eval_user_data = cb_eval_user_data;
        self
    }

    /// Attach a [`TensorCapture`](super::tensor_capture::TensorCapture) to
    /// intercept intermediate tensor outputs during `decode()`.
    ///
    /// This sets up the `cb_eval` callback to capture tensors matching the
    /// capture's filter (e.g. specific layer outputs). After `decode()` the
    /// captured data can be read from the `TensorCapture`.
    ///
    /// NOTE(review): the raw pointer to `capture` is stored in the FFI params;
    /// the caller must keep the `TensorCapture` alive (and not move it) for as
    /// long as a context built from these params may invoke the callback.
    ///
    /// # Example
    ///
    /// ```rust,ignore
    /// use llama_cpp_4::context::params::LlamaContextParams;
    /// use llama_cpp_4::context::tensor_capture::TensorCapture;
    ///
    /// let mut capture = TensorCapture::for_layers(&[13, 20, 27]);
    /// let ctx_params = LlamaContextParams::default()
    ///     .with_embeddings(true)
    ///     .with_tensor_capture(&mut capture);
    /// ```
    #[must_use]
    pub fn with_tensor_capture(
        self,
        capture: &mut super::tensor_capture::TensorCapture,
    ) -> Self {
        self.with_cb_eval(Some(super::tensor_capture::tensor_capture_callback))
            .with_cb_eval_user_data(
                capture as *mut super::tensor_capture::TensorCapture as *mut std::ffi::c_void,
            )
    }

    /// Set the storage type for the **K** (key) KV cache tensors.
    ///
    /// The default is `GgmlType::F16`. Quantized types like `GgmlType::Q5_0`
    /// or `GgmlType::Q4_0` reduce VRAM usage significantly; combining them with
    /// TurboQuant attention rotation (the default) keeps quality high.
    ///
    /// # Examples
    ///
    /// ```rust
    /// use llama_cpp_4::context::params::LlamaContextParams;
    /// use llama_cpp_4::quantize::GgmlType;
    /// let params = LlamaContextParams::default()
    ///     .with_cache_type_k(GgmlType::Q5_0);
    /// ```
    #[must_use]
    pub fn with_cache_type_k(mut self, ty: crate::quantize::GgmlType) -> Self {
        self.context_params.type_k = ty as llama_cpp_sys_4::ggml_type;
        self
    }

    /// Get the K-cache storage type.
    #[must_use]
    pub fn cache_type_k(&self) -> llama_cpp_sys_4::ggml_type {
        self.context_params.type_k
    }

    /// Set the storage type for the **V** (value) KV cache tensors.
    ///
    /// See [`with_cache_type_k`](Self::with_cache_type_k) for details.
    ///
    /// # Examples
    ///
    /// ```rust
    /// use llama_cpp_4::context::params::LlamaContextParams;
    /// use llama_cpp_4::quantize::GgmlType;
    /// let params = LlamaContextParams::default()
    ///     .with_cache_type_v(GgmlType::Q5_0);
    /// ```
    #[must_use]
    pub fn with_cache_type_v(mut self, ty: crate::quantize::GgmlType) -> Self {
        self.context_params.type_v = ty as llama_cpp_sys_4::ggml_type;
        self
    }

    /// Get the V-cache storage type.
    #[must_use]
    pub fn cache_type_v(&self) -> llama_cpp_sys_4::ggml_type {
        self.context_params.type_v
    }

    /// Control the TurboQuant attention-rotation feature (llama.cpp PR #21038).
    ///
    /// By default, llama.cpp applies a Hadamard rotation to Q/K/V tensors
    /// before writing them into the KV cache. This significantly improves
    /// quantized KV-cache quality at near-zero overhead, and is enabled
    /// automatically for models whose head dimension is a power of two.
    ///
    /// Set `disabled = true` to opt out (equivalent to `LLAMA_ATTN_ROT_DISABLE=1`).
    /// The env-var is applied just before the context is created and restored
    /// afterwards, so this is safe to call from a single thread.
    ///
    /// # Examples
    ///
    /// ```rust
    /// use llama_cpp_4::context::params::LlamaContextParams;
    /// // Disable rotation for this context only:
    /// let params = LlamaContextParams::default().with_attn_rot_disabled(true);
    /// assert!(params.attn_rot_disabled());
    /// ```
    #[must_use]
    pub fn with_attn_rot_disabled(mut self, disabled: bool) -> Self {
        // Stored on the wrapper (not the FFI struct); consumed at context
        // creation time.
        self.attn_rot_disabled = disabled;
        self
    }

    /// Returns `true` if TurboQuant attention rotation is disabled for this context.
    ///
    /// ```rust
    /// let params = llama_cpp_4::context::params::LlamaContextParams::default();
    /// assert!(!params.attn_rot_disabled());
    /// ```
    #[must_use]
    pub fn attn_rot_disabled(&self) -> bool {
        self.attn_rot_disabled
    }

    /// Set the type of pooling.
    ///
    /// # Examples
    ///
    /// ```rust
    /// use llama_cpp_4::context::params::{LlamaContextParams, LlamaPoolingType};
    /// let params = LlamaContextParams::default()
    ///     .with_pooling_type(LlamaPoolingType::Last);
    /// assert_eq!(params.pooling_type(), LlamaPoolingType::Last);
    /// ```
    #[must_use]
    pub fn with_pooling_type(mut self, pooling_type: LlamaPoolingType) -> Self {
        self.context_params.pooling_type = i32::from(pooling_type);
        self
    }

    /// Get the type of pooling.
    ///
    /// # Examples
    ///
    /// ```rust
    /// let params = llama_cpp_4::context::params::LlamaContextParams::default();
    /// assert_eq!(params.pooling_type(), llama_cpp_4::context::params::LlamaPoolingType::Unspecified);
    /// ```
    #[must_use]
    pub fn pooling_type(&self) -> LlamaPoolingType {
        LlamaPoolingType::from(self.context_params.pooling_type)
    }
}
638
/// Default parameters for `LlamaContext`. (as defined in llama.cpp by `llama_context_default_params`)
/// ```
/// # use std::num::NonZeroU32;
/// use llama_cpp_4::context::params::{LlamaContextParams, RopeScalingType};
/// let params = LlamaContextParams::default();
/// assert_eq!(params.n_ctx(), NonZeroU32::new(512), "n_ctx should be 512");
/// assert_eq!(params.rope_scaling_type(), RopeScalingType::Unspecified);
/// ```
impl Default for LlamaContextParams {
    fn default() -> Self {
        // SAFETY: `llama_context_default_params` takes no arguments and returns
        // the parameter struct by value; there are no pointer or lifetime
        // preconditions for the caller to uphold.
        let context_params = unsafe { llama_cpp_sys_4::llama_context_default_params() };
        Self {
            context_params,
            // TurboQuant attention rotation stays enabled unless the caller
            // opts out via `with_attn_rot_disabled(true)`.
            attn_rot_disabled: false,
        }
    }
}