llama_cpp_4/context/params/advanced.rs
1use super::{LlamaAttentionType, LlamaContextParams, LlamaFlashAttnType};
2use crate::sampling::LlamaSampler;
3
4impl LlamaContextParams {
5 /// Set the flash-attention mode (`Auto`, `Enabled`, or `Disabled`).
6 ///
7 /// Maps to `llama_context_params.flash_attn_type`. Use
8 /// [`LlamaFlashAttnType::Auto`] to match llama.cpp defaults.
9 ///
10 /// # Examples
11 ///
12 /// ```rust
13 /// use llama_cpp_4::context::params::{LlamaContextParams, LlamaFlashAttnType};
14 /// let params = LlamaContextParams::default()
15 /// .with_flash_attn_type(LlamaFlashAttnType::Auto);
16 /// assert_eq!(params.flash_attn_type(), LlamaFlashAttnType::Auto);
17 /// ```
18 #[must_use]
19 pub fn with_flash_attn_type(mut self, flash_attn_type: LlamaFlashAttnType) -> Self {
20 self.context_params.flash_attn_type = flash_attn_type.into();
21 self
22 }
23
24 /// Get the configured flash-attention mode.
25 #[must_use]
26 pub fn flash_attn_type(&self) -> LlamaFlashAttnType {
27 LlamaFlashAttnType::from(self.context_params.flash_attn_type)
28 }
29
30 /// Set the attention type used when extracting embeddings.
31 ///
32 /// Maps to `llama_context_params.attention_type`. Embedding models often
33 /// need [`LlamaAttentionType::NonCausal`]; generative decoding uses
34 /// [`LlamaAttentionType::Causal`].
35 ///
36 /// # Examples
37 ///
38 /// ```rust
39 /// use llama_cpp_4::context::params::{LlamaAttentionType, LlamaContextParams};
40 /// let params = LlamaContextParams::default()
41 /// .with_attention_type(LlamaAttentionType::Causal);
42 /// assert_eq!(params.attention_type(), LlamaAttentionType::Causal);
43 /// ```
44 #[must_use]
45 pub fn with_attention_type(mut self, attention_type: LlamaAttentionType) -> Self {
46 self.context_params.attention_type = attention_type.into();
47 self
48 }
49
50 /// Get the attention type used when extracting embeddings.
51 #[must_use]
52 pub fn attention_type(&self) -> LlamaAttentionType {
53 LlamaAttentionType::from(self.context_params.attention_type)
54 }
55
56 /// Set the maximum number of outputs per micro-batch.
57 ///
58 /// Maps to `llama_context_params.n_outputs_max`. When `0`, llama.cpp uses
59 /// `n_batch` as the cap.
60 ///
61 /// # Examples
62 ///
63 /// ```rust
64 /// use llama_cpp_4::context::params::LlamaContextParams;
65 /// let params = LlamaContextParams::default().with_n_outputs_max(256);
66 /// assert_eq!(params.n_outputs_max(), 256);
67 /// ```
68 #[must_use]
69 pub fn with_n_outputs_max(mut self, n_outputs_max: u32) -> Self {
70 self.context_params.n_outputs_max = n_outputs_max;
71 self
72 }
73
74 /// Get the maximum number of outputs per micro-batch.
75 #[must_use]
76 pub fn n_outputs_max(&self) -> u32 {
77 self.context_params.n_outputs_max
78 }
79
80 /// Use a unified KV buffer across input sequences.
81 ///
82 /// Maps to `llama_context_params.kv_unified`. Disabling can improve
83 /// throughput for batched decoding when sequences do not share a long prefix.
84 ///
85 /// # Examples
86 ///
87 /// ```rust
88 /// use llama_cpp_4::context::params::LlamaContextParams;
89 /// let params = LlamaContextParams::default().with_kv_unified(false);
90 /// assert!(!params.kv_unified());
91 /// ```
92 #[must_use]
93 pub fn with_kv_unified(mut self, kv_unified: bool) -> Self {
94 self.context_params.kv_unified = kv_unified;
95 self
96 }
97
98 /// Returns `true` when a unified KV buffer is enabled.
99 #[must_use]
100 pub fn kv_unified(&self) -> bool {
101 self.context_params.kv_unified
102 }
103
104 /// Use a full-size sliding-window-attention (SWA) KV cache.
105 ///
106 /// Maps to `llama_context_params.swa_full`. When `false` and `n_seq_max > 1`,
107 /// llama.cpp may use a smaller per-sequence SWA window for better performance.
108 ///
109 /// # Examples
110 ///
111 /// ```rust
112 /// use llama_cpp_4::context::params::LlamaContextParams;
113 /// let params = LlamaContextParams::default().with_swa_full(true);
114 /// assert!(params.swa_full());
115 /// ```
116 #[must_use]
117 pub fn with_swa_full(mut self, swa_full: bool) -> Self {
118 self.context_params.swa_full = swa_full;
119 self
120 }
121
122 /// Returns `true` when full SWA cache is enabled.
123 #[must_use]
124 pub fn swa_full(&self) -> bool {
125 self.context_params.swa_full
126 }
127
128 /// Offload eligible host tensor operations to the active device.
129 ///
130 /// Maps to `llama_context_params.op_offload`.
131 ///
132 /// # Examples
133 ///
134 /// ```rust
135 /// use llama_cpp_4::context::params::LlamaContextParams;
136 /// let params = LlamaContextParams::default().with_op_offload(true);
137 /// assert!(params.op_offload());
138 /// ```
139 #[must_use]
140 pub fn with_op_offload(mut self, op_offload: bool) -> Self {
141 self.context_params.op_offload = op_offload;
142 self
143 }
144
145 /// Returns `true` when host tensor ops are offloaded to device.
146 #[must_use]
147 pub fn op_offload(&self) -> bool {
148 self.context_params.op_offload
149 }
150
151 /// Pair this context with another for shared memory or cross-context results.
152 ///
153 /// Maps to `llama_context_params.ctx_other`. The paired context is returned
154 /// by [`crate::context::LlamaContext::ctx_other`] after creation.
155 ///
156 /// `other` must remain alive until [`crate::model::LlamaModel::new_context`]
157 /// returns.
158 ///
159 /// # Examples
160 ///
161 /// ```ignore
162 /// let target = model.new_context(&backend, LlamaContextParams::default())?;
163 /// let draft = model.new_context(
164 /// &backend,
165 /// LlamaContextParams::default().with_ctx_other(&target),
166 /// )?;
167 /// ```
168 #[must_use]
169 pub fn with_ctx_other(mut self, other: &crate::context::LlamaContext<'_>) -> Self {
170 self.context_params.ctx_other = other.context.as_ptr();
171 self
172 }
173
174 /// Set `YaRN` extrapolation mix factor.
175 ///
176 /// Maps to `llama_context_params.yarn_ext_factor`. Negative values use the
177 /// model default. Only meaningful when [`super::RopeScalingType::Yarn`] is active.
178 ///
179 /// # Examples
180 ///
181 /// ```rust
182 /// use llama_cpp_4::context::params::LlamaContextParams;
183 /// let params = LlamaContextParams::default().with_yarn_ext_factor(1.0);
184 /// assert_eq!(params.yarn_ext_factor(), 1.0);
185 /// ```
186 #[must_use]
187 pub fn with_yarn_ext_factor(mut self, yarn_ext_factor: f32) -> Self {
188 self.context_params.yarn_ext_factor = yarn_ext_factor;
189 self
190 }
191
192 /// Get `YaRN` extrapolation mix factor (`yarn_ext_factor`).
193 #[must_use]
194 pub fn yarn_ext_factor(&self) -> f32 {
195 self.context_params.yarn_ext_factor
196 }
197
198 /// Set `YaRN` magnitude scaling factor.
199 ///
200 /// Maps to `llama_context_params.yarn_attn_factor`.
201 ///
202 /// # Examples
203 ///
204 /// ```rust
205 /// use llama_cpp_4::context::params::LlamaContextParams;
206 /// let params = LlamaContextParams::default().with_yarn_attn_factor(1.0);
207 /// assert_eq!(params.yarn_attn_factor(), 1.0);
208 /// ```
209 #[must_use]
210 pub fn with_yarn_attn_factor(mut self, yarn_attn_factor: f32) -> Self {
211 self.context_params.yarn_attn_factor = yarn_attn_factor;
212 self
213 }
214
215 /// Get `YaRN` magnitude scaling factor (`yarn_attn_factor`).
216 #[must_use]
217 pub fn yarn_attn_factor(&self) -> f32 {
218 self.context_params.yarn_attn_factor
219 }
220
221 /// Set `YaRN` low correction dimension (`yarn_beta_fast`).
222 ///
223 /// Maps to `llama_context_params.yarn_beta_fast`.
224 #[must_use]
225 pub fn with_yarn_beta_fast(mut self, yarn_beta_fast: f32) -> Self {
226 self.context_params.yarn_beta_fast = yarn_beta_fast;
227 self
228 }
229
230 /// Get `YaRN` low correction dimension.
231 #[must_use]
232 pub fn yarn_beta_fast(&self) -> f32 {
233 self.context_params.yarn_beta_fast
234 }
235
236 /// Set `YaRN` high correction dimension (`yarn_beta_slow`).
237 ///
238 /// Maps to `llama_context_params.yarn_beta_slow`.
239 #[must_use]
240 pub fn with_yarn_beta_slow(mut self, yarn_beta_slow: f32) -> Self {
241 self.context_params.yarn_beta_slow = yarn_beta_slow;
242 self
243 }
244
245 /// Get `YaRN` high correction dimension.
246 #[must_use]
247 pub fn yarn_beta_slow(&self) -> f32 {
248 self.context_params.yarn_beta_slow
249 }
250
251 /// Set `YaRN` original context size.
252 ///
253 /// Maps to `llama_context_params.yarn_orig_ctx`. `0` uses the model default.
254 ///
255 /// # Examples
256 ///
257 /// ```rust
258 /// use llama_cpp_4::context::params::LlamaContextParams;
259 /// let params = LlamaContextParams::default().with_yarn_orig_ctx(8192);
260 /// assert_eq!(params.yarn_orig_ctx(), 8192);
261 /// ```
262 #[must_use]
263 pub fn with_yarn_orig_ctx(mut self, yarn_orig_ctx: u32) -> Self {
264 self.context_params.yarn_orig_ctx = yarn_orig_ctx;
265 self
266 }
267
268 /// Get `YaRN` original context size (`yarn_orig_ctx`).
269 #[must_use]
270 pub fn yarn_orig_ctx(&self) -> u32 {
271 self.context_params.yarn_orig_ctx
272 }
273
274 /// Disable performance timing collection for this context.
275 ///
276 /// Maps to `llama_context_params.no_perf`. When `true`, calls such as
277 /// [`crate::context::LlamaContext::timings`] return empty counters.
278 ///
279 /// # Examples
280 ///
281 /// ```rust
282 /// use llama_cpp_4::context::params::LlamaContextParams;
283 /// let params = LlamaContextParams::default().with_no_perf(true);
284 /// assert!(params.no_perf());
285 /// ```
286 #[must_use]
287 pub fn with_no_perf(mut self, no_perf: bool) -> Self {
288 self.context_params.no_perf = no_perf;
289 self
290 }
291
292 /// Returns `true` when perf timings are disabled for this context.
293 #[must_use]
294 pub fn no_perf(&self) -> bool {
295 self.context_params.no_perf
296 }
297
298 /// Register an abort callback checked during `decode()` on CPU backends.
299 ///
300 /// Maps to `llama_context_params.abort_callback` / `abort_callback_data`.
301 /// The callback is invoked periodically during long decodes; return a
302 /// non-zero value to stop the current operation.
303 ///
304 /// `user_data` is passed through unchanged and must remain valid for the
305 /// lifetime of any context created from these params.
306 #[must_use]
307 pub fn with_abort_callback(
308 mut self,
309 callback: llama_cpp_sys_4::ggml_abort_callback,
310 user_data: *mut std::ffi::c_void,
311 ) -> Self {
312 self.context_params.abort_callback = callback;
313 self.context_params.abort_callback_data = user_data;
314 self
315 }
316
317 /// Assign per-sequence backend sampler chains.
318 ///
319 /// Maps to `llama_context_params.samplers` / `n_samplers`. Each
320 /// [`LlamaSampler`] must be a sampler **chain** created with
321 /// `llama_sampler_chain_init`. The samplers are kept alive inside these
322 /// params until [`crate::model::LlamaModel::new_context`] returns.
323 ///
324 /// Pair sequence ids with the chains that should run when decoding those
325 /// sequences on the backend.
326 ///
327 /// # Examples
328 ///
329 /// ```ignore
330 /// use llama_cpp_4::context::params::LlamaContextParams;
331 /// use llama_cpp_4::sampling::LlamaSampler;
332 ///
333 /// let chain = LlamaSampler::chain_default(&model)?;
334 /// let params = LlamaContextParams::default()
335 /// .with_sampler_seq_configs([(0, chain)]);
336 /// assert_eq!(params.n_sampler_seq_configs(), 1);
337 /// ```
338 #[must_use]
339 pub fn with_sampler_seq_configs(
340 mut self,
341 configs: impl IntoIterator<Item = (i32, LlamaSampler)>,
342 ) -> Self {
343 self.owned_samplers.clear();
344 self.sampler_configs.clear();
345
346 for (seq_id, sampler) in configs {
347 self.sampler_configs
348 .push(llama_cpp_sys_4::llama_sampler_seq_config {
349 seq_id,
350 sampler: sampler.sampler.as_ptr(),
351 });
352 self.owned_samplers.push(sampler);
353 }
354
355 if self.sampler_configs.is_empty() {
356 self.context_params.samplers = std::ptr::null_mut();
357 self.context_params.n_samplers = 0;
358 } else {
359 self.context_params.samplers = self.sampler_configs.as_mut_ptr();
360 self.context_params.n_samplers = self.sampler_configs.len();
361 }
362
363 self
364 }
365
366 /// Number of per-sequence sampler configs attached to these params.
367 ///
368 /// Returns `0` when no chains were set or after [`Clone`] (sampler chains
369 /// are not duplicated).
370 #[must_use]
371 pub fn n_sampler_seq_configs(&self) -> usize {
372 self.sampler_configs.len()
373 }
374}