Skip to main content

llama_cpp_4/context/params/
advanced.rs

1use super::{LlamaAttentionType, LlamaContextParams, LlamaFlashAttnType};
2use crate::sampling::LlamaSampler;
3
4impl LlamaContextParams {
5    /// Set the flash-attention mode (`Auto`, `Enabled`, or `Disabled`).
6    ///
7    /// Maps to `llama_context_params.flash_attn_type`. Use
8    /// [`LlamaFlashAttnType::Auto`] to match llama.cpp defaults.
9    ///
10    /// # Examples
11    ///
12    /// ```rust
13    /// use llama_cpp_4::context::params::{LlamaContextParams, LlamaFlashAttnType};
14    /// let params = LlamaContextParams::default()
15    ///     .with_flash_attn_type(LlamaFlashAttnType::Auto);
16    /// assert_eq!(params.flash_attn_type(), LlamaFlashAttnType::Auto);
17    /// ```
18    #[must_use]
19    pub fn with_flash_attn_type(mut self, flash_attn_type: LlamaFlashAttnType) -> Self {
20        self.context_params.flash_attn_type = flash_attn_type.into();
21        self
22    }
23
24    /// Get the configured flash-attention mode.
25    #[must_use]
26    pub fn flash_attn_type(&self) -> LlamaFlashAttnType {
27        LlamaFlashAttnType::from(self.context_params.flash_attn_type)
28    }
29
30    /// Set the attention type used when extracting embeddings.
31    ///
32    /// Maps to `llama_context_params.attention_type`. Embedding models often
33    /// need [`LlamaAttentionType::NonCausal`]; generative decoding uses
34    /// [`LlamaAttentionType::Causal`].
35    ///
36    /// # Examples
37    ///
38    /// ```rust
39    /// use llama_cpp_4::context::params::{LlamaAttentionType, LlamaContextParams};
40    /// let params = LlamaContextParams::default()
41    ///     .with_attention_type(LlamaAttentionType::Causal);
42    /// assert_eq!(params.attention_type(), LlamaAttentionType::Causal);
43    /// ```
44    #[must_use]
45    pub fn with_attention_type(mut self, attention_type: LlamaAttentionType) -> Self {
46        self.context_params.attention_type = attention_type.into();
47        self
48    }
49
50    /// Get the attention type used when extracting embeddings.
51    #[must_use]
52    pub fn attention_type(&self) -> LlamaAttentionType {
53        LlamaAttentionType::from(self.context_params.attention_type)
54    }
55
56    /// Set the maximum number of outputs per micro-batch.
57    ///
58    /// Maps to `llama_context_params.n_outputs_max`. When `0`, llama.cpp uses
59    /// `n_batch` as the cap.
60    ///
61    /// # Examples
62    ///
63    /// ```rust
64    /// use llama_cpp_4::context::params::LlamaContextParams;
65    /// let params = LlamaContextParams::default().with_n_outputs_max(256);
66    /// assert_eq!(params.n_outputs_max(), 256);
67    /// ```
68    #[must_use]
69    pub fn with_n_outputs_max(mut self, n_outputs_max: u32) -> Self {
70        self.context_params.n_outputs_max = n_outputs_max;
71        self
72    }
73
74    /// Get the maximum number of outputs per micro-batch.
75    #[must_use]
76    pub fn n_outputs_max(&self) -> u32 {
77        self.context_params.n_outputs_max
78    }
79
80    /// Use a unified KV buffer across input sequences.
81    ///
82    /// Maps to `llama_context_params.kv_unified`. Disabling can improve
83    /// throughput for batched decoding when sequences do not share a long prefix.
84    ///
85    /// # Examples
86    ///
87    /// ```rust
88    /// use llama_cpp_4::context::params::LlamaContextParams;
89    /// let params = LlamaContextParams::default().with_kv_unified(false);
90    /// assert!(!params.kv_unified());
91    /// ```
92    #[must_use]
93    pub fn with_kv_unified(mut self, kv_unified: bool) -> Self {
94        self.context_params.kv_unified = kv_unified;
95        self
96    }
97
98    /// Returns `true` when a unified KV buffer is enabled.
99    #[must_use]
100    pub fn kv_unified(&self) -> bool {
101        self.context_params.kv_unified
102    }
103
104    /// Use a full-size sliding-window-attention (SWA) KV cache.
105    ///
106    /// Maps to `llama_context_params.swa_full`. When `false` and `n_seq_max > 1`,
107    /// llama.cpp may use a smaller per-sequence SWA window for better performance.
108    ///
109    /// # Examples
110    ///
111    /// ```rust
112    /// use llama_cpp_4::context::params::LlamaContextParams;
113    /// let params = LlamaContextParams::default().with_swa_full(true);
114    /// assert!(params.swa_full());
115    /// ```
116    #[must_use]
117    pub fn with_swa_full(mut self, swa_full: bool) -> Self {
118        self.context_params.swa_full = swa_full;
119        self
120    }
121
122    /// Returns `true` when full SWA cache is enabled.
123    #[must_use]
124    pub fn swa_full(&self) -> bool {
125        self.context_params.swa_full
126    }
127
128    /// Offload eligible host tensor operations to the active device.
129    ///
130    /// Maps to `llama_context_params.op_offload`.
131    ///
132    /// # Examples
133    ///
134    /// ```rust
135    /// use llama_cpp_4::context::params::LlamaContextParams;
136    /// let params = LlamaContextParams::default().with_op_offload(true);
137    /// assert!(params.op_offload());
138    /// ```
139    #[must_use]
140    pub fn with_op_offload(mut self, op_offload: bool) -> Self {
141        self.context_params.op_offload = op_offload;
142        self
143    }
144
145    /// Returns `true` when host tensor ops are offloaded to device.
146    #[must_use]
147    pub fn op_offload(&self) -> bool {
148        self.context_params.op_offload
149    }
150
151    /// Pair this context with another for shared memory or cross-context results.
152    ///
153    /// Maps to `llama_context_params.ctx_other`. The paired context is returned
154    /// by [`crate::context::LlamaContext::ctx_other`] after creation.
155    ///
156    /// `other` must remain alive until [`crate::model::LlamaModel::new_context`]
157    /// returns.
158    ///
159    /// # Examples
160    ///
161    /// ```ignore
162    /// let target = model.new_context(&backend, LlamaContextParams::default())?;
163    /// let draft = model.new_context(
164    ///     &backend,
165    ///     LlamaContextParams::default().with_ctx_other(&target),
166    /// )?;
167    /// ```
168    #[must_use]
169    pub fn with_ctx_other(mut self, other: &crate::context::LlamaContext<'_>) -> Self {
170        self.context_params.ctx_other = other.context.as_ptr();
171        self
172    }
173
174    /// Set `YaRN` extrapolation mix factor.
175    ///
176    /// Maps to `llama_context_params.yarn_ext_factor`. Negative values use the
177    /// model default. Only meaningful when [`super::RopeScalingType::Yarn`] is active.
178    ///
179    /// # Examples
180    ///
181    /// ```rust
182    /// use llama_cpp_4::context::params::LlamaContextParams;
183    /// let params = LlamaContextParams::default().with_yarn_ext_factor(1.0);
184    /// assert_eq!(params.yarn_ext_factor(), 1.0);
185    /// ```
186    #[must_use]
187    pub fn with_yarn_ext_factor(mut self, yarn_ext_factor: f32) -> Self {
188        self.context_params.yarn_ext_factor = yarn_ext_factor;
189        self
190    }
191
192    /// Get `YaRN` extrapolation mix factor (`yarn_ext_factor`).
193    #[must_use]
194    pub fn yarn_ext_factor(&self) -> f32 {
195        self.context_params.yarn_ext_factor
196    }
197
198    /// Set `YaRN` magnitude scaling factor.
199    ///
200    /// Maps to `llama_context_params.yarn_attn_factor`.
201    ///
202    /// # Examples
203    ///
204    /// ```rust
205    /// use llama_cpp_4::context::params::LlamaContextParams;
206    /// let params = LlamaContextParams::default().with_yarn_attn_factor(1.0);
207    /// assert_eq!(params.yarn_attn_factor(), 1.0);
208    /// ```
209    #[must_use]
210    pub fn with_yarn_attn_factor(mut self, yarn_attn_factor: f32) -> Self {
211        self.context_params.yarn_attn_factor = yarn_attn_factor;
212        self
213    }
214
215    /// Get `YaRN` magnitude scaling factor (`yarn_attn_factor`).
216    #[must_use]
217    pub fn yarn_attn_factor(&self) -> f32 {
218        self.context_params.yarn_attn_factor
219    }
220
221    /// Set `YaRN` low correction dimension (`yarn_beta_fast`).
222    ///
223    /// Maps to `llama_context_params.yarn_beta_fast`.
224    #[must_use]
225    pub fn with_yarn_beta_fast(mut self, yarn_beta_fast: f32) -> Self {
226        self.context_params.yarn_beta_fast = yarn_beta_fast;
227        self
228    }
229
230    /// Get `YaRN` low correction dimension.
231    #[must_use]
232    pub fn yarn_beta_fast(&self) -> f32 {
233        self.context_params.yarn_beta_fast
234    }
235
236    /// Set `YaRN` high correction dimension (`yarn_beta_slow`).
237    ///
238    /// Maps to `llama_context_params.yarn_beta_slow`.
239    #[must_use]
240    pub fn with_yarn_beta_slow(mut self, yarn_beta_slow: f32) -> Self {
241        self.context_params.yarn_beta_slow = yarn_beta_slow;
242        self
243    }
244
245    /// Get `YaRN` high correction dimension.
246    #[must_use]
247    pub fn yarn_beta_slow(&self) -> f32 {
248        self.context_params.yarn_beta_slow
249    }
250
251    /// Set `YaRN` original context size.
252    ///
253    /// Maps to `llama_context_params.yarn_orig_ctx`. `0` uses the model default.
254    ///
255    /// # Examples
256    ///
257    /// ```rust
258    /// use llama_cpp_4::context::params::LlamaContextParams;
259    /// let params = LlamaContextParams::default().with_yarn_orig_ctx(8192);
260    /// assert_eq!(params.yarn_orig_ctx(), 8192);
261    /// ```
262    #[must_use]
263    pub fn with_yarn_orig_ctx(mut self, yarn_orig_ctx: u32) -> Self {
264        self.context_params.yarn_orig_ctx = yarn_orig_ctx;
265        self
266    }
267
268    /// Get `YaRN` original context size (`yarn_orig_ctx`).
269    #[must_use]
270    pub fn yarn_orig_ctx(&self) -> u32 {
271        self.context_params.yarn_orig_ctx
272    }
273
274    /// Disable performance timing collection for this context.
275    ///
276    /// Maps to `llama_context_params.no_perf`. When `true`, calls such as
277    /// [`crate::context::LlamaContext::timings`] return empty counters.
278    ///
279    /// # Examples
280    ///
281    /// ```rust
282    /// use llama_cpp_4::context::params::LlamaContextParams;
283    /// let params = LlamaContextParams::default().with_no_perf(true);
284    /// assert!(params.no_perf());
285    /// ```
286    #[must_use]
287    pub fn with_no_perf(mut self, no_perf: bool) -> Self {
288        self.context_params.no_perf = no_perf;
289        self
290    }
291
292    /// Returns `true` when perf timings are disabled for this context.
293    #[must_use]
294    pub fn no_perf(&self) -> bool {
295        self.context_params.no_perf
296    }
297
298    /// Register an abort callback checked during `decode()` on CPU backends.
299    ///
300    /// Maps to `llama_context_params.abort_callback` / `abort_callback_data`.
301    /// The callback is invoked periodically during long decodes; return a
302    /// non-zero value to stop the current operation.
303    ///
304    /// `user_data` is passed through unchanged and must remain valid for the
305    /// lifetime of any context created from these params.
306    #[must_use]
307    pub fn with_abort_callback(
308        mut self,
309        callback: llama_cpp_sys_4::ggml_abort_callback,
310        user_data: *mut std::ffi::c_void,
311    ) -> Self {
312        self.context_params.abort_callback = callback;
313        self.context_params.abort_callback_data = user_data;
314        self
315    }
316
317    /// Assign per-sequence backend sampler chains.
318    ///
319    /// Maps to `llama_context_params.samplers` / `n_samplers`. Each
320    /// [`LlamaSampler`] must be a sampler **chain** created with
321    /// `llama_sampler_chain_init`. The samplers are kept alive inside these
322    /// params until [`crate::model::LlamaModel::new_context`] returns.
323    ///
324    /// Pair sequence ids with the chains that should run when decoding those
325    /// sequences on the backend.
326    ///
327    /// # Examples
328    ///
329    /// ```ignore
330    /// use llama_cpp_4::context::params::LlamaContextParams;
331    /// use llama_cpp_4::sampling::LlamaSampler;
332    ///
333    /// let chain = LlamaSampler::chain_default(&model)?;
334    /// let params = LlamaContextParams::default()
335    ///     .with_sampler_seq_configs([(0, chain)]);
336    /// assert_eq!(params.n_sampler_seq_configs(), 1);
337    /// ```
338    #[must_use]
339    pub fn with_sampler_seq_configs(
340        mut self,
341        configs: impl IntoIterator<Item = (i32, LlamaSampler)>,
342    ) -> Self {
343        self.owned_samplers.clear();
344        self.sampler_configs.clear();
345
346        for (seq_id, sampler) in configs {
347            self.sampler_configs
348                .push(llama_cpp_sys_4::llama_sampler_seq_config {
349                    seq_id,
350                    sampler: sampler.sampler.as_ptr(),
351                });
352            self.owned_samplers.push(sampler);
353        }
354
355        if self.sampler_configs.is_empty() {
356            self.context_params.samplers = std::ptr::null_mut();
357            self.context_params.n_samplers = 0;
358        } else {
359            self.context_params.samplers = self.sampler_configs.as_mut_ptr();
360            self.context_params.n_samplers = self.sampler_configs.len();
361        }
362
363        self
364    }
365
366    /// Number of per-sequence sampler configs attached to these params.
367    ///
368    /// Returns `0` when no chains were set or after [`Clone`] (sampler chains
369    /// are not duplicated).
370    #[must_use]
371    pub fn n_sampler_seq_configs(&self) -> usize {
372        self.sampler_configs.len()
373    }
374}