llama_cpp_4/context/params.rs
1//! A safe wrapper around `llama_context_params`.
2use std::fmt::Debug;
3use std::num::NonZeroU32;
4
/// A rusty wrapper around `rope_scaling_type`.
#[repr(i8)]
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum RopeScalingType {
    /// The scaling type is unspecified
    Unspecified = -1,
    /// No scaling
    None = 0,
    /// Linear scaling
    Linear = 1,
    /// Yarn scaling
    Yarn = 2,
}

/// Create a `RopeScalingType` from a `c_int` - returns [`RopeScalingType::Unspecified`] if
/// the value is not recognized.
impl From<i32> for RopeScalingType {
    fn from(value: i32) -> Self {
        match value {
            0 => Self::None,
            1 => Self::Linear,
            2 => Self::Yarn,
            // Any unknown value (including -1) maps to `Unspecified` rather than
            // panicking, so new C-side enum values degrade gracefully.
            _ => Self::Unspecified,
        }
    }
}

/// Create a `c_int` from a `RopeScalingType`.
impl From<RopeScalingType> for i32 {
    fn from(value: RopeScalingType) -> Self {
        match value {
            RopeScalingType::None => 0,
            RopeScalingType::Linear => 1,
            RopeScalingType::Yarn => 2,
            RopeScalingType::Unspecified => -1,
        }
    }
}
43
/// A rusty wrapper around `LLAMA_POOLING_TYPE`.
#[repr(i8)]
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum LlamaPoolingType {
    /// The pooling type is unspecified
    Unspecified = -1,
    /// No pooling
    None = 0,
    /// Mean pooling
    Mean = 1,
    /// CLS pooling
    Cls = 2,
    /// Last pooling
    Last = 3,
}

/// Convert a `c_int` into a [`LlamaPoolingType`]; any unrecognized value becomes
/// [`LlamaPoolingType::Unspecified`].
impl From<i32> for LlamaPoolingType {
    fn from(raw: i32) -> Self {
        match raw {
            0 => Self::None,
            1 => Self::Mean,
            2 => Self::Cls,
            3 => Self::Last,
            _ => Self::Unspecified,
        }
    }
}

/// Convert a [`LlamaPoolingType`] back into its `c_int` representation.
impl From<LlamaPoolingType> for i32 {
    fn from(pooling: LlamaPoolingType) -> Self {
        // The enum discriminants are declared to mirror the C values exactly,
        // so a plain discriminant cast is equivalent to an exhaustive match.
        pooling as i32
    }
}
86
/// A safe wrapper around `llama_context_params`.
///
/// Generally this should be created with [`Default::default()`] and then modified with `with_*` methods.
///
/// # Examples
///
/// ```rust
/// # use std::num::NonZeroU32;
/// use llama_cpp_4::context::params::LlamaContextParams;
///
/// let ctx_params = LlamaContextParams::default()
///     .with_n_ctx(NonZeroU32::new(2048))
///     .with_n_threads(8);
///
/// assert_eq!(ctx_params.n_ctx(), NonZeroU32::new(2048));
/// assert_eq!(ctx_params.n_threads(), 8);
/// ```
// NOTE(review): the previous example called `.with_seed(1234)` / `.seed()`, which are not
// defined on this type anywhere in this file — that doctest could not compile. Rewritten
// using methods that exist; confirm no other impl block provides seed accessors.
#[derive(Debug, Clone)]
#[allow(
    missing_docs,
    clippy::struct_excessive_bools,
    clippy::module_name_repetitions
)]
pub struct LlamaContextParams {
    pub(crate) context_params: llama_cpp_sys_4::llama_context_params,
}

/// SAFETY: we do not currently allow setting or reading the pointers that cause this to not be automatically send or sync.
unsafe impl Send for LlamaContextParams {}
unsafe impl Sync for LlamaContextParams {}
117
impl LlamaContextParams {
    /// Set the size of the context.
    ///
    /// Passing [`None`] stores `0`, which means the context size is taken from the
    /// model instead (see [`Self::n_ctx`]).
    ///
    /// # Examples
    ///
    /// ```rust
    /// # use std::num::NonZeroU32;
    /// use llama_cpp_4::context::params::LlamaContextParams;
    /// let params = LlamaContextParams::default();
    /// let params = params.with_n_ctx(NonZeroU32::new(2048));
    /// assert_eq!(params.n_ctx(), NonZeroU32::new(2048));
    /// ```
    #[must_use]
    pub fn with_n_ctx(mut self, n_ctx: Option<NonZeroU32>) -> Self {
        self.context_params.n_ctx = n_ctx.map_or(0, std::num::NonZeroU32::get);
        self
    }

    /// Get the size of the context.
    ///
    /// [`None`] if the context size is specified by the model and not the context.
    ///
    /// # Examples
    ///
    /// ```rust
    /// let params = llama_cpp_4::context::params::LlamaContextParams::default();
    /// assert_eq!(params.n_ctx(), std::num::NonZeroU32::new(512));
    /// ```
    #[must_use]
    pub fn n_ctx(&self) -> Option<NonZeroU32> {
        NonZeroU32::new(self.context_params.n_ctx)
    }

    /// Set the `n_batch` (logical maximum batch size submitted to decode).
    ///
    /// # Examples
    ///
    /// ```rust
    /// # use std::num::NonZeroU32;
    /// use llama_cpp_4::context::params::LlamaContextParams;
    /// let params = LlamaContextParams::default()
    ///     .with_n_batch(2048);
    /// assert_eq!(params.n_batch(), 2048);
    /// ```
    #[must_use]
    pub fn with_n_batch(mut self, n_batch: u32) -> Self {
        self.context_params.n_batch = n_batch;
        self
    }

    /// Get the `n_batch`
    ///
    /// # Examples
    ///
    /// ```rust
    /// use llama_cpp_4::context::params::LlamaContextParams;
    /// let params = LlamaContextParams::default();
    /// assert_eq!(params.n_batch(), 2048);
    /// ```
    #[must_use]
    pub fn n_batch(&self) -> u32 {
        self.context_params.n_batch
    }

    /// Set the `n_ubatch` (physical maximum batch size).
    ///
    /// # Examples
    ///
    /// ```rust
    /// # use std::num::NonZeroU32;
    /// use llama_cpp_4::context::params::LlamaContextParams;
    /// let params = LlamaContextParams::default()
    ///     .with_n_ubatch(512);
    /// assert_eq!(params.n_ubatch(), 512);
    /// ```
    #[must_use]
    pub fn with_n_ubatch(mut self, n_ubatch: u32) -> Self {
        self.context_params.n_ubatch = n_ubatch;
        self
    }

    /// Get the `n_ubatch`
    ///
    /// # Examples
    ///
    /// ```rust
    /// use llama_cpp_4::context::params::LlamaContextParams;
    /// let params = LlamaContextParams::default();
    /// assert_eq!(params.n_ubatch(), 512);
    /// ```
    #[must_use]
    pub fn n_ubatch(&self) -> u32 {
        self.context_params.n_ubatch
    }

    /// Set the `flash_attention` parameter.
    ///
    /// `true` maps to `LLAMA_FLASH_ATTN_TYPE_ENABLED`, `false` to
    /// `LLAMA_FLASH_ATTN_TYPE_DISABLED`.
    ///
    /// # Examples
    ///
    /// ```rust
    /// use llama_cpp_4::context::params::LlamaContextParams;
    /// let params = LlamaContextParams::default()
    ///     .with_flash_attention(true);
    /// assert_eq!(params.flash_attention(), true);
    /// ```
    #[must_use]
    pub fn with_flash_attention(mut self, enabled: bool) -> Self {
        self.context_params.flash_attn_type = if enabled {
            llama_cpp_sys_4::LLAMA_FLASH_ATTN_TYPE_ENABLED
        } else {
            llama_cpp_sys_4::LLAMA_FLASH_ATTN_TYPE_DISABLED
        };
        self
    }

    /// Get the `flash_attention` parameter.
    ///
    /// Returns `true` only when the stored type is exactly
    /// `LLAMA_FLASH_ATTN_TYPE_ENABLED`; any other value reads as `false`.
    ///
    /// # Examples
    ///
    /// ```rust
    /// use llama_cpp_4::context::params::LlamaContextParams;
    /// let params = LlamaContextParams::default();
    /// assert_eq!(params.flash_attention(), false);
    /// ```
    #[must_use]
    pub fn flash_attention(&self) -> bool {
        self.context_params.flash_attn_type == llama_cpp_sys_4::LLAMA_FLASH_ATTN_TYPE_ENABLED
    }

    /// Set the `offload_kqv` parameter to control offloading KV cache & KQV ops to GPU
    ///
    /// # Examples
    ///
    /// ```rust
    /// use llama_cpp_4::context::params::LlamaContextParams;
    /// let params = LlamaContextParams::default()
    ///     .with_offload_kqv(false);
    /// assert_eq!(params.offload_kqv(), false);
    /// ```
    #[must_use]
    pub fn with_offload_kqv(mut self, enabled: bool) -> Self {
        self.context_params.offload_kqv = enabled;
        self
    }

    /// Get the `offload_kqv` parameter
    ///
    /// # Examples
    ///
    /// ```rust
    /// use llama_cpp_4::context::params::LlamaContextParams;
    /// let params = LlamaContextParams::default();
    /// assert_eq!(params.offload_kqv(), true);
    /// ```
    #[must_use]
    pub fn offload_kqv(&self) -> bool {
        self.context_params.offload_kqv
    }

    /// Set the type of rope scaling.
    ///
    /// # Examples
    ///
    /// ```rust
    /// use llama_cpp_4::context::params::{LlamaContextParams, RopeScalingType};
    /// let params = LlamaContextParams::default()
    ///     .with_rope_scaling_type(RopeScalingType::Linear);
    /// assert_eq!(params.rope_scaling_type(), RopeScalingType::Linear);
    /// ```
    #[must_use]
    pub fn with_rope_scaling_type(mut self, rope_scaling_type: RopeScalingType) -> Self {
        self.context_params.rope_scaling_type = i32::from(rope_scaling_type);
        self
    }

    /// Get the type of rope scaling.
    ///
    /// # Examples
    ///
    /// ```rust
    /// let params = llama_cpp_4::context::params::LlamaContextParams::default();
    /// assert_eq!(params.rope_scaling_type(), llama_cpp_4::context::params::RopeScalingType::Unspecified);
    /// ```
    #[must_use]
    pub fn rope_scaling_type(&self) -> RopeScalingType {
        RopeScalingType::from(self.context_params.rope_scaling_type)
    }

    /// Set the rope frequency base.
    ///
    /// # Examples
    ///
    /// ```rust
    /// use llama_cpp_4::context::params::LlamaContextParams;
    /// let params = LlamaContextParams::default()
    ///     .with_rope_freq_base(0.5);
    /// assert_eq!(params.rope_freq_base(), 0.5);
    /// ```
    #[must_use]
    pub fn with_rope_freq_base(mut self, rope_freq_base: f32) -> Self {
        self.context_params.rope_freq_base = rope_freq_base;
        self
    }

    /// Get the rope frequency base.
    ///
    /// # Examples
    ///
    /// ```rust
    /// let params = llama_cpp_4::context::params::LlamaContextParams::default();
    /// assert_eq!(params.rope_freq_base(), 0.0);
    /// ```
    #[must_use]
    pub fn rope_freq_base(&self) -> f32 {
        self.context_params.rope_freq_base
    }

    /// Set the rope frequency scale.
    ///
    /// # Examples
    ///
    /// ```rust
    /// use llama_cpp_4::context::params::LlamaContextParams;
    /// let params = LlamaContextParams::default()
    ///     .with_rope_freq_scale(0.5);
    /// assert_eq!(params.rope_freq_scale(), 0.5);
    /// ```
    #[must_use]
    pub fn with_rope_freq_scale(mut self, rope_freq_scale: f32) -> Self {
        self.context_params.rope_freq_scale = rope_freq_scale;
        self
    }

    /// Get the rope frequency scale.
    ///
    /// # Examples
    ///
    /// ```rust
    /// let params = llama_cpp_4::context::params::LlamaContextParams::default();
    /// assert_eq!(params.rope_freq_scale(), 0.0);
    /// ```
    #[must_use]
    pub fn rope_freq_scale(&self) -> f32 {
        self.context_params.rope_freq_scale
    }

    /// Get the number of threads.
    ///
    /// # Examples
    ///
    /// ```rust
    /// let params = llama_cpp_4::context::params::LlamaContextParams::default();
    /// assert_eq!(params.n_threads(), 4);
    /// ```
    #[must_use]
    pub fn n_threads(&self) -> i32 {
        self.context_params.n_threads
    }

    /// Get the number of threads allocated for batches.
    ///
    /// # Examples
    ///
    /// ```rust
    /// let params = llama_cpp_4::context::params::LlamaContextParams::default();
    /// assert_eq!(params.n_threads_batch(), 4);
    /// ```
    #[must_use]
    pub fn n_threads_batch(&self) -> i32 {
        self.context_params.n_threads_batch
    }

    /// Set the number of threads.
    ///
    /// # Examples
    ///
    /// ```rust
    /// use llama_cpp_4::context::params::LlamaContextParams;
    /// let params = LlamaContextParams::default()
    ///     .with_n_threads(8);
    /// assert_eq!(params.n_threads(), 8);
    /// ```
    #[must_use]
    pub fn with_n_threads(mut self, n_threads: i32) -> Self {
        self.context_params.n_threads = n_threads;
        self
    }

    /// Set the number of threads allocated for batches.
    ///
    /// # Examples
    ///
    /// ```rust
    /// use llama_cpp_4::context::params::LlamaContextParams;
    /// let params = LlamaContextParams::default()
    ///     .with_n_threads_batch(8);
    /// assert_eq!(params.n_threads_batch(), 8);
    /// ```
    #[must_use]
    pub fn with_n_threads_batch(mut self, n_threads: i32) -> Self {
        self.context_params.n_threads_batch = n_threads;
        self
    }

    /// Check whether embeddings are enabled
    ///
    /// # Examples
    ///
    /// ```rust
    /// let params = llama_cpp_4::context::params::LlamaContextParams::default();
    /// assert!(!params.embeddings());
    /// ```
    #[must_use]
    pub fn embeddings(&self) -> bool {
        self.context_params.embeddings
    }

    /// Enable the use of embeddings
    ///
    /// # Examples
    ///
    /// ```rust
    /// use llama_cpp_4::context::params::LlamaContextParams;
    /// let params = LlamaContextParams::default()
    ///     .with_embeddings(true);
    /// assert!(params.embeddings());
    /// ```
    #[must_use]
    pub fn with_embeddings(mut self, embedding: bool) -> Self {
        self.context_params.embeddings = embedding;
        self
    }

    /// Set the evaluation callback.
    ///
    /// The callback is stored as a raw C function pointer; pass `None` to clear it.
    ///
    /// # Examples
    ///
    /// ```no_run
    /// extern "C" fn cb_eval_fn(
    ///     t: *mut llama_cpp_sys_4::ggml_tensor,
    ///     ask: bool,
    ///     user_data: *mut std::ffi::c_void,
    /// ) -> bool {
    ///     false
    /// }
    ///
    /// use llama_cpp_4::context::params::LlamaContextParams;
    /// let params = LlamaContextParams::default().with_cb_eval(Some(cb_eval_fn));
    /// ```
    #[must_use]
    pub fn with_cb_eval(
        mut self,
        cb_eval: llama_cpp_sys_4::ggml_backend_sched_eval_callback,
    ) -> Self {
        self.context_params.cb_eval = cb_eval;
        self
    }

    /// Set the evaluation callback user data.
    ///
    /// NOTE(review): the pointer is stored as-is and passed to the callback by
    /// llama.cpp; the caller must keep the pointee alive for the context's lifetime.
    ///
    /// # Examples
    ///
    /// ```no_run
    /// use llama_cpp_4::context::params::LlamaContextParams;
    /// let params = LlamaContextParams::default();
    /// let user_data = std::ptr::null_mut();
    /// let params = params.with_cb_eval_user_data(user_data);
    /// ```
    #[must_use]
    pub fn with_cb_eval_user_data(mut self, cb_eval_user_data: *mut std::ffi::c_void) -> Self {
        self.context_params.cb_eval_user_data = cb_eval_user_data;
        self
    }

    /// Set the type of pooling.
    ///
    /// # Examples
    ///
    /// ```rust
    /// use llama_cpp_4::context::params::{LlamaContextParams, LlamaPoolingType};
    /// let params = LlamaContextParams::default()
    ///     .with_pooling_type(LlamaPoolingType::Last);
    /// assert_eq!(params.pooling_type(), LlamaPoolingType::Last);
    /// ```
    #[must_use]
    pub fn with_pooling_type(mut self, pooling_type: LlamaPoolingType) -> Self {
        self.context_params.pooling_type = i32::from(pooling_type);
        self
    }

    /// Get the type of pooling.
    ///
    /// # Examples
    ///
    /// ```rust
    /// let params = llama_cpp_4::context::params::LlamaContextParams::default();
    /// assert_eq!(params.pooling_type(), llama_cpp_4::context::params::LlamaPoolingType::Unspecified);
    /// ```
    #[must_use]
    pub fn pooling_type(&self) -> LlamaPoolingType {
        LlamaPoolingType::from(self.context_params.pooling_type)
    }
}
520
521/// Default parameters for `LlamaContext`. (as defined in llama.cpp by `llama_context_default_params`)
522/// ```
523/// # use std::num::NonZeroU32;
524/// use llama_cpp_4::context::params::{LlamaContextParams, RopeScalingType};
525/// let params = LlamaContextParams::default();
526/// assert_eq!(params.n_ctx(), NonZeroU32::new(512), "n_ctx should be 512");
527/// assert_eq!(params.rope_scaling_type(), RopeScalingType::Unspecified);
528/// ```
529impl Default for LlamaContextParams {
530 fn default() -> Self {
531 let context_params = unsafe { llama_cpp_sys_4::llama_context_default_params() };
532 Self { context_params }
533 }
534}