Skip to main content

llama_cpp_4/
context.rs

1//! Safe wrapper around `llama_context`.
2//!
3//! Submodules:
4//!
5//! - [`tensor_capture`] — hook `cb_eval` during [`LlamaContext::decode`] to copy
6//!   intermediate tensors (per-layer hidden states, norms, …).
7//! - [`memory_breakdown`] — per-buffer memory usage after load/decode.
8//! - [`kv_cache`] — sequence copy, shift, and clear helpers.
9
10use std::fmt::{Debug, Formatter};
11use std::num::NonZeroI32;
12use std::ptr::NonNull;
13use std::slice;
14
15use llama_cpp_sys_4::llama_pooling_type;
16use params::LlamaPoolingType;
17use perf::PerfContextData;
18
19use crate::llama_batch::LlamaBatch;
20use crate::model::{LlamaLoraAdapter, LlamaModel};
21use crate::token::data::LlamaTokenData;
22use crate::token::data_array::LlamaTokenDataArray;
23use crate::token::LlamaToken;
24use crate::{
25    DecodeError, EmbeddingsError, EncodeError, LlamaLoraAdapterRemoveError,
26    LlamaLoraAdapterSetError,
27};
28
29pub mod kv_cache;
30pub mod memory_breakdown;
31pub mod params;
32pub mod perf;
33pub mod session;
34pub mod tensor_capture;
35
36pub use memory_breakdown::MemoryBreakdownEntry;
37pub use tensor_capture::{CapturedTensor, TensorCapture};
38
39/// A safe wrapper around the `llama_context` C++ context.
40///
41/// This struct provides a safe interface to interact with the `llama_context` used by the `LlamaModel`.
42/// It encapsulates the raw C++ context pointer and provides additional fields for managing the model and
43/// context-specific settings like embeddings and logits.
44///
45/// The `LlamaContext` struct ensures that the C++ context is always valid by using the `NonNull` type for
46/// the context pointer, preventing it from being null. The struct also holds a reference to the model
47/// (`LlamaModel`) that the context is tied to, along with some internal state like whether embeddings are enabled
48/// and the initialized logits for the context.
49///
50/// # Fields
51///
52/// - `context`: A non-null pointer to the raw C++ `llama_context`. This is the main context used for interacting with the model.
53/// - `model`: A reference to the `LlamaModel` associated with this context. This model provides the data and parameters
54///   that the context interacts with.
55/// - `initialized_logits`: A vector used to store the initialized logits. These are used in the model's processing and
56///   are kept separate from the context data.
57/// - `embeddings_enabled`: A boolean flag indicating whether embeddings are enabled in the context. This is useful for
58///   controlling whether embedding data is generated during the interaction with the model.
59#[allow(clippy::module_name_repetitions)]
60pub struct LlamaContext<'a> {
61    pub(crate) context: NonNull<llama_cpp_sys_4::llama_context>,
62    /// a reference to the contexts model.
63    pub model: &'a LlamaModel,
64    initialized_logits: Vec<i32>,
65    embeddings_enabled: bool,
66}
67
68impl Debug for LlamaContext<'_> {
69    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
70        f.debug_struct("LlamaContext")
71            .field("context", &self.context)
72            .finish()
73    }
74}
75
76impl<'model> LlamaContext<'model> {
77    /// Creates a new instance of `LlamaContext` with the provided model, context, and embeddings flag.
78    ///
79    /// This function initializes a new `LlamaContext` object, which is used to interact with the
80    /// `LlamaModel`. The context is created from a pointer to a C++ context and the embeddings flag
81    /// determines whether embeddings are enabled in the context.
82    ///
83    /// # Parameters
84    ///
85    /// - `llama_model`: A reference to an existing `LlamaModel` that will be used with the new context.
86    /// - `llama_context`: A non-null pointer to an existing `llama_cpp_sys_4::llama_context` representing
87    ///   the context created in previous steps. This context is necessary for interacting with the model.
88    /// - `embeddings_enabled`: A boolean flag indicating whether embeddings are enabled in this context.
89    ///
90    /// # Returns
91    ///
92    /// This function returns a new instance of `LlamaContext` initialized with the given parameters:
93    /// - The model reference (`llama_model`) is stored in the context.
94    /// - The raw context pointer (`llama_context`) is wrapped in a `NonNull` to ensure safety.
95    /// - The `embeddings_enabled` flag is used to determine if embeddings are enabled for the context.
96    ///
97    /// # Example
98    /// ```ignore
99    /// let llama_model = LlamaModel::load_from_file(&backend, "path/to/model", &params).unwrap();
100    /// let context_ptr = NonNull::new(some_llama_context_ptr).unwrap();
101    /// let context = LlamaContext::new(&llama_model, context_ptr, true);
102    /// // Now you can use the context
103    /// ```
104    pub(crate) fn new(
105        llama_model: &'model LlamaModel,
106        llama_context: NonNull<llama_cpp_sys_4::llama_context>,
107        embeddings_enabled: bool,
108    ) -> Self {
109        Self {
110            context: llama_context,
111            model: llama_model,
112            initialized_logits: Vec::new(),
113            embeddings_enabled,
114        }
115    }
116
117    /// Gets the max number of logical tokens that can be submitted to decode. Must be greater than or equal to `n_ubatch`.
118    #[must_use]
119    pub fn n_batch(&self) -> u32 {
120        unsafe { llama_cpp_sys_4::llama_n_batch(self.context.as_ptr()) }
121    }
122
123    /// Gets the max number of physical tokens (hardware level) to decode in batch. Must be less than or equal to `n_batch`.
124    #[must_use]
125    pub fn n_ubatch(&self) -> u32 {
126        unsafe { llama_cpp_sys_4::llama_n_ubatch(self.context.as_ptr()) }
127    }
128
129    /// Gets the size of the context.
130    #[must_use]
131    pub fn n_ctx(&self) -> u32 {
132        unsafe { llama_cpp_sys_4::llama_n_ctx(self.context.as_ptr()) }
133    }
134
135    /// Decodes the batch.
136    ///
137    /// # Errors
138    ///
139    /// - `DecodeError` if the decoding failed.
140    ///
141    /// # Panics
142    ///
143    /// - the returned [`std::ffi::c_int`] from llama-cpp does not fit into a i32 (this should never happen on most systems)
144    pub fn decode(&mut self, batch: &mut LlamaBatch) -> Result<(), DecodeError> {
145        let result =
146            unsafe { llama_cpp_sys_4::llama_decode(self.context.as_ptr(), batch.llama_batch) };
147
148        match NonZeroI32::new(result) {
149            None => {
150                self.initialized_logits
151                    .clone_from(&batch.initialized_logits);
152                Ok(())
153            }
154            Some(error) => Err(DecodeError::from(error)),
155        }
156    }
157
158    /// Encodes the batch.
159    ///
160    /// # Errors
161    ///
162    /// - `EncodeError` if the decoding failed.
163    ///
164    /// # Panics
165    ///
166    /// - the returned [`std::ffi::c_int`] from llama-cpp does not fit into a i32 (this should never happen on most systems)
167    pub fn encode(&mut self, batch: &mut LlamaBatch) -> Result<(), EncodeError> {
168        let result =
169            unsafe { llama_cpp_sys_4::llama_encode(self.context.as_ptr(), batch.llama_batch) };
170
171        match NonZeroI32::new(result) {
172            None => {
173                self.initialized_logits
174                    .clone_from(&batch.initialized_logits);
175                Ok(())
176            }
177            Some(error) => Err(EncodeError::from(error)),
178        }
179    }
180
181    /// Return Pooling type for Llama's Context
182    #[must_use]
183    pub fn pooling_type(&self) -> LlamaPoolingType {
184        let pooling_type = unsafe { llama_pooling_type(self.context.as_ptr()) };
185
186        LlamaPoolingType::from(pooling_type)
187    }
188
189    /// Get the embeddings for the `i`th sequence in the current context.
190    ///
191    /// # Returns
192    ///
193    /// A slice containing the embeddings for the last decoded batch.
194    /// The size corresponds to the `n_embd` parameter of the context's model.
195    ///
196    /// # Errors
197    ///
198    /// - When the current context was constructed without enabling embeddings.
199    /// - If the current model had a pooling type of [`llama_cpp_sys_4::LLAMA_POOLING_TYPE_NONE`]
200    /// - If the given sequence index exceeds the max sequence id.
201    ///
202    /// # Panics
203    ///
204    /// * `n_embd` does not fit into a usize
205    pub fn embeddings_seq_ith(&self, i: i32) -> Result<&[f32], EmbeddingsError> {
206        if !self.embeddings_enabled {
207            return Err(EmbeddingsError::NotEnabled);
208        }
209
210        let n_embd =
211            usize::try_from(self.model.n_embd()).expect("n_embd does not fit into a usize");
212
213        unsafe {
214            let embedding = llama_cpp_sys_4::llama_get_embeddings_seq(self.context.as_ptr(), i);
215
216            // Technically also possible whenever `i >= max(batch.n_seq)`, but can't check that here.
217            if embedding.is_null() {
218                Err(EmbeddingsError::NonePoolType)
219            } else {
220                Ok(slice::from_raw_parts(embedding, n_embd))
221            }
222        }
223    }
224
225    /// Get the embeddings for the `i`th token in the current context.
226    ///
227    /// # Returns
228    ///
229    /// A slice containing the embeddings for the last decoded batch of the given token.
230    /// The size corresponds to the `n_embd` parameter of the context's model.
231    ///
232    /// # Errors
233    ///
234    /// - When the current context was constructed without enabling embeddings.
235    /// - When the given token didn't have logits enabled when it was passed.
236    /// - If the given token index exceeds the max token id.
237    ///
238    /// # Panics
239    ///
240    /// * `n_embd` does not fit into a usize
241    pub fn embeddings_ith(&self, i: i32) -> Result<&[f32], EmbeddingsError> {
242        if !self.embeddings_enabled {
243            return Err(EmbeddingsError::NotEnabled);
244        }
245
246        let n_embd =
247            usize::try_from(self.model.n_embd()).expect("n_embd does not fit into a usize");
248
249        unsafe {
250            let embedding = llama_cpp_sys_4::llama_get_embeddings_ith(self.context.as_ptr(), i);
251            // Technically also possible whenever `i >= batch.n_tokens`, but no good way of checking `n_tokens` here.
252            if embedding.is_null() {
253                Err(EmbeddingsError::LogitsNotEnabled)
254            } else {
255                Ok(slice::from_raw_parts(embedding, n_embd))
256            }
257        }
258    }
259
260    /// Get the logits for the last token in the context.
261    ///
262    /// # Returns
263    /// An iterator over unsorted `LlamaTokenData` containing the
264    /// logits for the last token in the context.
265    ///
266    /// # Panics
267    ///
268    /// - underlying logits data is null
269    pub fn candidates(&self) -> impl Iterator<Item = LlamaTokenData> + '_ {
270        (0_i32..).zip(self.get_logits()).map(|(i, logit)| {
271            let token = LlamaToken::new(i);
272            LlamaTokenData::new(token, *logit, 0_f32)
273        })
274    }
275
276    /// Get the token data array for the last token in the context.
277    ///
278    /// This is a convience method that implements:
279    /// ```ignore
280    /// LlamaTokenDataArray::from_iter(ctx.candidates(), false)
281    /// ```
282    ///
283    /// # Panics
284    ///
285    /// - underlying logits data is null
286    #[must_use]
287    pub fn token_data_array(&self) -> LlamaTokenDataArray {
288        LlamaTokenDataArray::from_iter(self.candidates(), false)
289    }
290
291    /// Token logits obtained from the last call to `decode()`.
292    /// The logits for which `batch.logits[i] != 0` are stored contiguously
293    /// in the order they have appeared in the batch.
294    /// Rows: number of tokens for which `batch.logits[i] != 0`
295    /// Cols: `n_vocab`
296    ///
297    /// # Returns
298    ///
299    /// A slice containing the logits for the last decoded token.
300    /// The size corresponds to the `n_vocab` parameter of the context's model.
301    ///
302    /// # Panics
303    ///
304    /// - `n_vocab` does not fit into a usize
305    /// - token data returned is null
306    #[must_use]
307    pub fn get_logits(&self) -> &[f32] {
308        let data = unsafe { llama_cpp_sys_4::llama_get_logits(self.context.as_ptr()) };
309        assert!(!data.is_null(), "logits data for last token is null");
310        let len = usize::try_from(self.model.n_vocab()).expect("n_vocab does not fit into a usize");
311
312        unsafe { slice::from_raw_parts(data, len) }
313    }
314
315    /// Get the logits for the ith token in the context.
316    ///
317    /// # Panics
318    ///
319    /// - logit `i` is not initialized.
320    pub fn candidates_ith(&self, i: i32) -> impl Iterator<Item = LlamaTokenData> + '_ {
321        (0_i32..).zip(self.get_logits_ith(i)).map(|(i, logit)| {
322            let token = LlamaToken::new(i);
323            LlamaTokenData::new(token, *logit, 0_f32)
324        })
325    }
326
327    /// Get the logits for the ith token in the context.
328    ///
329    /// # Panics
330    ///
331    /// - `i` is greater than `n_ctx`
332    /// - `n_vocab` does not fit into a usize
333    /// - logit `i` is not initialized.
334    #[must_use]
335    pub fn get_logits_ith(&self, i: i32) -> &[f32] {
336        assert!(
337            self.initialized_logits.contains(&i),
338            "logit {i} is not initialized. only {:?} is",
339            self.initialized_logits
340        );
341        assert!(
342            self.n_ctx() > u32::try_from(i).expect("i does not fit into a u32"),
343            "n_ctx ({}) must be greater than i ({})",
344            self.n_ctx(),
345            i
346        );
347
348        let data = unsafe { llama_cpp_sys_4::llama_get_logits_ith(self.context.as_ptr(), i) };
349        let len = usize::try_from(self.model.n_vocab()).expect("n_vocab does not fit into a usize");
350
351        unsafe { slice::from_raw_parts(data, len) }
352    }
353
354    /// Get the number of context tokens per sequence.
355    #[must_use]
356    pub fn n_ctx_seq(&self) -> u32 {
357        unsafe { llama_cpp_sys_4::llama_n_ctx_seq(self.context.as_ptr()) }
358    }
359
360    /// Get the maximum number of sequences.
361    #[must_use]
362    pub fn n_seq_max(&self) -> u32 {
363        unsafe { llama_cpp_sys_4::llama_n_seq_max(self.context.as_ptr()) }
364    }
365
366    /// Get the number of recurrent-state snapshots per sequence.
367    #[must_use]
368    pub fn n_rs_seq(&self) -> u32 {
369        unsafe { llama_cpp_sys_4::llama_n_rs_seq(self.context.as_ptr()) }
370    }
371
372    /// Get the number of threads used for generation.
373    #[must_use]
374    pub fn n_threads(&self) -> i32 {
375        unsafe { llama_cpp_sys_4::llama_n_threads(self.context.as_ptr()) }
376    }
377
378    /// Get the number of threads used for batch processing.
379    #[must_use]
380    pub fn n_threads_batch(&self) -> i32 {
381        unsafe { llama_cpp_sys_4::llama_n_threads_batch(self.context.as_ptr()) }
382    }
383
384    /// Set the number of threads used for generation and batch processing.
385    pub fn set_n_threads(&mut self, n_threads: i32, n_threads_batch: i32) {
386        unsafe {
387            llama_cpp_sys_4::llama_set_n_threads(self.context.as_ptr(), n_threads, n_threads_batch);
388        }
389    }
390
391    /// Set whether to use causal attention.
392    ///
393    /// If set to `false`, the model will use non-causal attention, which is
394    /// needed for embedding models.
395    pub fn set_causal_attn(&mut self, causal_attn: bool) {
396        unsafe {
397            llama_cpp_sys_4::llama_set_causal_attn(self.context.as_ptr(), causal_attn);
398        }
399    }
400
401    /// Set whether to compute embeddings.
402    ///
403    /// This allows toggling embedding mode at runtime (as opposed to only at
404    /// context creation time).
405    pub fn set_embeddings(&mut self, embeddings: bool) {
406        self.embeddings_enabled = embeddings;
407        unsafe {
408            llama_cpp_sys_4::llama_set_embeddings(self.context.as_ptr(), embeddings);
409        }
410    }
411
412    /// Mark the next computation as a warmup run.
413    ///
414    /// Warmup runs are useful for GPU backends to compile kernels before
415    /// actual inference begins.
416    pub fn set_warmup(&mut self, warmup: bool) {
417        unsafe {
418            llama_cpp_sys_4::llama_set_warmup(self.context.as_ptr(), warmup);
419        }
420    }
421
422    /// Wait for all pending async computations to finish.
423    pub fn synchronize(&mut self) {
424        unsafe {
425            llama_cpp_sys_4::llama_synchronize(self.context.as_ptr());
426        }
427    }
428
429    /// Get all embeddings for the current context.
430    ///
431    /// Returns a slice of all embeddings from the last decoded batch.
432    /// For pooled embeddings use [`embeddings_seq_ith`](Self::embeddings_seq_ith) instead.
433    ///
434    /// # Errors
435    ///
436    /// - When the current context was constructed without enabling embeddings.
437    /// - If the embeddings pointer is null.
438    ///
439    /// # Panics
440    ///
441    /// * `n_embd` does not fit into a usize
442    pub fn get_embeddings(&self) -> Result<&[f32], EmbeddingsError> {
443        if !self.embeddings_enabled {
444            return Err(EmbeddingsError::NotEnabled);
445        }
446
447        let n_embd =
448            usize::try_from(self.model.n_embd()).expect("n_embd does not fit into a usize");
449
450        unsafe {
451            let embedding = llama_cpp_sys_4::llama_get_embeddings(self.context.as_ptr());
452            if embedding.is_null() {
453                Err(EmbeddingsError::NonePoolType)
454            } else {
455                Ok(slice::from_raw_parts(embedding, n_embd))
456            }
457        }
458    }
459
460    /// Toggle extraction of next-n embeddings (Rust name: pre-norm) — hidden
461    /// states used by MTP draft heads. Upstream C API: `llama_set_embeddings_nextn`
462    /// (llama.cpp PR #23198 and later renames).
463    ///
464    /// If `masked` is `true`, pre-norm rows are extracted only for tokens
465    /// whose `batch.logits[i]` is non-zero. If `masked` is `false`, rows are
466    /// extracted for every token in the batch regardless of `batch.logits` —
467    /// callers can then leave `batch.logits[i] = false` on prompt-fill
468    /// positions and avoid copying the full logits row for each one.
469    ///
470    /// Upstream's MTP session init configures pre-norm extraction on the target
471    /// and draft contexts automatically. Call this manually only for custom
472    /// speculative setups.
473    pub fn set_embeddings_pre_norm(&mut self, value: bool, masked: bool) {
474        unsafe {
475            llama_cpp_sys_4::llama_set_embeddings_nextn(self.context.as_ptr(), value, masked);
476        }
477    }
478
479    /// Get the full pre-norm embeddings buffer for the last decoded batch.
480    ///
481    /// Returns `None` when pre-norm embeddings are disabled or the buffer
482    /// hasn't been populated. The length of the returned slice is
483    /// `n_embd * <number of pre-norm rows>` — interpretation of the row
484    /// count depends on whether the setter was called with `masked=true`
485    /// (one row per sampled token) or `masked=false` (one row per batch
486    /// token). Use [`get_embeddings_pre_norm_ith`](Self::get_embeddings_pre_norm_ith)
487    /// when you only need a single row.
488    ///
489    /// # Panics
490    ///
491    /// Panics if `n_embd` does not fit in `usize`.
492    #[must_use]
493    pub fn get_embeddings_pre_norm(&self) -> Option<&[f32]> {
494        let n_embd =
495            usize::try_from(self.model.n_embd()).expect("n_embd does not fit into a usize");
496        unsafe {
497            let p = llama_cpp_sys_4::llama_get_embeddings_nextn(self.context.as_ptr());
498            if p.is_null() {
499                None
500            } else {
501                Some(slice::from_raw_parts(p, n_embd))
502            }
503        }
504    }
505
506    /// Get the pre-norm embedding row for the `i`th output position of the
507    /// last decoded batch. Returns `None` if upstream rejects the index
508    /// (e.g. masked mode with `batch.logits[i] == 0`, or out of range).
509    ///
510    /// # Panics
511    ///
512    /// Panics if `n_embd` does not fit in `usize`.
513    #[must_use]
514    pub fn get_embeddings_pre_norm_ith(&self, i: i32) -> Option<&[f32]> {
515        let n_embd =
516            usize::try_from(self.model.n_embd()).expect("n_embd does not fit into a usize");
517        unsafe {
518            let p = llama_cpp_sys_4::llama_get_embeddings_nextn_ith(self.context.as_ptr(), i);
519            if p.is_null() {
520                None
521            } else {
522                Some(slice::from_raw_parts(p, n_embd))
523            }
524        }
525    }
526
527    /// Select which `NextN` block the MTP draft graph runs.
528    ///
529    /// `offset` indexes past the trunk transformer layers (`0` = first `NextN`
530    /// head). Required for multi-head MTP models such as Step3.5; restore to
531    /// `0` after drafting. See [`crate::mtp`] for the full speculative loop.
532    ///
533    /// # Examples
534    ///
535    /// ```ignore
536    /// for head in 0..model.n_layer_nextn() {
537    ///     draft.set_nextn_layer_offset(head);
538    ///     let drafts = session.draft(0, n_past, last_token)?;
539    /// }
540    /// draft.set_nextn_layer_offset(0);
541    /// ```
542    pub fn set_nextn_layer_offset(&mut self, offset: i32) {
543        unsafe {
544            llama_cpp_sys_4::llama_set_nextn_layer_offset(self.context.as_ptr(), offset);
545        }
546    }
547
548    /// Return the paired context set via
549    /// [`crate::context::params::LlamaContextParams::with_ctx_other`].
550    ///
551    /// The pointer refers to the other live context created during
552    /// [`crate::model::LlamaModel::new_context`]; it is `None` when no pairing
553    /// was configured.
554    #[must_use]
555    pub fn ctx_other(&self) -> Option<NonNull<llama_cpp_sys_4::llama_context>> {
556        NonNull::new(unsafe { llama_cpp_sys_4::llama_get_ctx_other(self.context.as_ptr()) })
557    }
558
559    /// Reset the timings for the context.
560    pub fn reset_timings(&mut self) {
561        unsafe { llama_cpp_sys_4::ggml_time_init() }
562    }
563
564    /// Returns the timings for the context.
565    pub fn timings(&mut self) -> PerfContextData {
566        let perf_context_data =
567            unsafe { llama_cpp_sys_4::llama_perf_context(self.context.as_ptr()) };
568        PerfContextData { perf_context_data }
569    }
570
571    /// Reset the performance counters for the context.
572    pub fn perf_context_reset(&mut self) {
573        unsafe { llama_cpp_sys_4::llama_perf_context_reset(self.context.as_ptr()) }
574    }
575
576    /// Check if the KV cache memory supports shifting.
577    #[must_use]
578    pub fn memory_can_shift(&self) -> bool {
579        unsafe {
580            let mem = llama_cpp_sys_4::llama_get_memory(self.context.as_ptr());
581            llama_cpp_sys_4::llama_memory_can_shift(mem)
582        }
583    }
584
585    /// Get the minimum position in a sequence's KV cache.
586    #[must_use]
587    pub fn memory_seq_pos_min(&self, seq_id: i32) -> i32 {
588        unsafe {
589            let mem = llama_cpp_sys_4::llama_get_memory(self.context.as_ptr());
590            llama_cpp_sys_4::llama_memory_seq_pos_min(mem, seq_id)
591        }
592    }
593
594    /// Print a human-readable memory breakdown to stderr via llama.cpp.
595    ///
596    /// For structured access use [`Self::memory_breakdown`].
597    pub fn memory_breakdown_print(&self) {
598        unsafe {
599            llama_cpp_sys_4::common_memory_breakdown_print(self.context.as_ptr());
600        }
601    }
602
603    /// Return structured per-buffer memory usage for this context.
604    ///
605    /// Each [`memory_breakdown::MemoryBreakdownEntry`] reports model weights,
606    /// KV / recurrent cache, and compute scratch bytes for one ggml buffer
607    /// type. Returns an empty vector when no buffers are registered.
608    ///
609    /// # Examples
610    ///
611    /// ```no_run
612    /// use llama_cpp_4::prelude::*;
613    ///
614    /// fn main() {
615    ///     let backend = LlamaBackend::init().unwrap();
616    ///     let model = LlamaModel::load_from_file(&backend, "model.gguf", &LlamaModelParams::default()).unwrap();
617    ///     let ctx = model.new_context(&backend, LlamaContextParams::default()).unwrap();
618    ///     let total: usize = ctx.memory_breakdown().iter().map(|e| e.total()).sum();
619    ///     println!("context uses {total} bytes across all buffer types");
620    /// }
621    /// ```
622    #[must_use]
623    pub fn memory_breakdown(&self) -> Vec<memory_breakdown::MemoryBreakdownEntry> {
624        memory_breakdown::collect_memory_breakdown(self.context.as_ptr())
625    }
626
627    /// Enable or disable extraction of input embeddings for a transformer layer.
628    ///
629    /// Maps to `llama_set_embeddings_layer_inp`. After a successful
630    /// [`Self::decode`], read the vector with [`Self::get_embeddings_layer_inp`].
631    pub fn set_embeddings_layer_inp(&mut self, layer_id: u32, value: bool) {
632        unsafe {
633            llama_cpp_sys_4::llama_set_embeddings_layer_inp(self.context.as_ptr(), layer_id, value);
634        }
635    }
636
637    /// Get input embeddings for `layer_id` from the last decoded batch.
638    ///
639    /// Returns `None` when the layer was not enabled via
640    /// [`Self::set_embeddings_layer_inp`] or when upstream has no data for
641    /// `layer_id`. The slice length is [`LlamaModel::n_embd`].
642    ///
643    /// # Panics
644    ///
645    /// Panics if `n_embd` does not fit in `usize`.
646    #[must_use]
647    pub fn get_embeddings_layer_inp(&self, layer_id: u32) -> Option<&[f32]> {
648        let n_embd =
649            usize::try_from(self.model.n_embd()).expect("n_embd does not fit into a usize");
650        unsafe {
651            let p =
652                llama_cpp_sys_4::llama_get_embeddings_layer_inp(self.context.as_ptr(), layer_id);
653            if p.is_null() {
654                None
655            } else {
656                Some(slice::from_raw_parts(p, n_embd))
657            }
658        }
659    }
660
661    /// Get the size of the full context state in bytes.
662    ///
663    /// This is the size needed for [`state_get_data`](Self::state_get_data) and
664    /// [`state_set_data`](Self::state_set_data).
665    #[must_use]
666    pub fn state_get_size(&mut self) -> usize {
667        unsafe { llama_cpp_sys_4::llama_state_get_size(self.context.as_ptr()) }
668    }
669
670    /// Copy the full context state into a byte buffer.
671    ///
672    /// The buffer must be at least [`state_get_size`](Self::state_get_size) bytes.
673    ///
674    /// Returns the number of bytes written.
675    pub fn state_get_data(&mut self, dst: &mut [u8]) -> usize {
676        unsafe {
677            llama_cpp_sys_4::llama_state_get_data(
678                self.context.as_ptr(),
679                dst.as_mut_ptr(),
680                dst.len(),
681            )
682        }
683    }
684
685    /// Restore the full context state from a byte buffer.
686    ///
687    /// Returns the number of bytes read.
688    pub fn state_set_data(&mut self, src: &[u8]) -> usize {
689        unsafe {
690            llama_cpp_sys_4::llama_state_set_data(self.context.as_ptr(), src.as_ptr(), src.len())
691        }
692    }
693
694    /// Save the context state to a file along with the given tokens.
695    ///
696    /// Returns `true` on success.
697    ///
698    /// # Panics
699    ///
700    /// Panics if the path contains null bytes.
701    pub fn state_save_file(
702        &mut self,
703        path: impl AsRef<std::path::Path>,
704        tokens: &[LlamaToken],
705    ) -> bool {
706        let path_str = path.as_ref().to_str().expect("path is not valid UTF-8");
707        let c_path = std::ffi::CString::new(path_str).expect("path contains null bytes");
708        unsafe {
709            llama_cpp_sys_4::llama_state_save_file(
710                self.context.as_ptr(),
711                c_path.as_ptr(),
712                tokens.as_ptr().cast(),
713                tokens.len(),
714            )
715        }
716    }
717
718    /// Load a context state from a file.
719    ///
720    /// Returns `true` on success and fills `tokens_out` with the saved tokens.
721    ///
722    /// # Panics
723    ///
724    /// Panics if the path contains null bytes.
725    pub fn state_load_file(
726        &mut self,
727        path: impl AsRef<std::path::Path>,
728        tokens_out: &mut Vec<LlamaToken>,
729        n_token_capacity: usize,
730    ) -> bool {
731        tokens_out.resize(n_token_capacity, LlamaToken(0));
732        let mut n_token_count: usize = 0;
733        let path_str = path.as_ref().to_str().expect("path is not valid UTF-8");
734        let c_path = std::ffi::CString::new(path_str).expect("path contains null bytes");
735        let ok = unsafe {
736            llama_cpp_sys_4::llama_state_load_file(
737                self.context.as_ptr(),
738                c_path.as_ptr(),
739                tokens_out.as_mut_ptr().cast(),
740                n_token_capacity,
741                std::ptr::addr_of_mut!(n_token_count),
742            )
743        };
744        if ok {
745            tokens_out.truncate(n_token_count);
746        }
747        ok
748    }
749
750    /// Get the size of a single sequence's state in bytes.
751    #[must_use]
752    pub fn state_seq_get_size(&mut self, seq_id: i32) -> usize {
753        unsafe { llama_cpp_sys_4::llama_state_seq_get_size(self.context.as_ptr(), seq_id) }
754    }
755
756    /// Copy a single sequence's state into a byte buffer.
757    ///
758    /// Returns the number of bytes written.
759    pub fn state_seq_get_data(&mut self, dst: &mut [u8], seq_id: i32) -> usize {
760        unsafe {
761            llama_cpp_sys_4::llama_state_seq_get_data(
762                self.context.as_ptr(),
763                dst.as_mut_ptr(),
764                dst.len(),
765                seq_id,
766            )
767        }
768    }
769
770    /// Restore a single sequence's state from a byte buffer.
771    ///
772    /// Returns the number of bytes read.
773    pub fn state_seq_set_data(&mut self, src: &[u8], dest_seq_id: i32) -> usize {
774        unsafe {
775            llama_cpp_sys_4::llama_state_seq_set_data(
776                self.context.as_ptr(),
777                src.as_ptr(),
778                src.len(),
779                dest_seq_id,
780            )
781        }
782    }
783
784    /// Save a single sequence's state to a file.
785    ///
786    /// Returns the number of bytes written (0 on failure).
787    ///
788    /// # Panics
789    ///
790    /// Panics if the path contains null bytes.
791    pub fn state_seq_save_file(
792        &mut self,
793        path: impl AsRef<std::path::Path>,
794        seq_id: i32,
795        tokens: &[LlamaToken],
796    ) -> usize {
797        let path_str = path.as_ref().to_str().expect("path is not valid UTF-8");
798        let c_path = std::ffi::CString::new(path_str).expect("path contains null bytes");
799        unsafe {
800            llama_cpp_sys_4::llama_state_seq_save_file(
801                self.context.as_ptr(),
802                c_path.as_ptr(),
803                seq_id,
804                tokens.as_ptr().cast(),
805                tokens.len(),
806            )
807        }
808    }
809
810    /// Load a single sequence's state from a file.
811    ///
812    /// Returns the number of bytes read (0 on failure).
813    ///
814    /// # Panics
815    ///
816    /// Panics if the path contains null bytes.
817    pub fn state_seq_load_file(
818        &mut self,
819        path: impl AsRef<std::path::Path>,
820        dest_seq_id: i32,
821        tokens_out: &mut Vec<LlamaToken>,
822        n_token_capacity: usize,
823    ) -> usize {
824        tokens_out.resize(n_token_capacity, LlamaToken(0));
825        let mut n_token_count: usize = 0;
826        let path_str = path.as_ref().to_str().expect("path is not valid UTF-8");
827        let c_path = std::ffi::CString::new(path_str).expect("path contains null bytes");
828        let ret = unsafe {
829            llama_cpp_sys_4::llama_state_seq_load_file(
830                self.context.as_ptr(),
831                c_path.as_ptr(),
832                dest_seq_id,
833                tokens_out.as_mut_ptr().cast(),
834                n_token_capacity,
835                std::ptr::addr_of_mut!(n_token_count),
836            )
837        };
838        if ret > 0 {
839            tokens_out.truncate(n_token_count);
840        }
841        ret
842    }
843
844    /// Set a control vector on the context.
845    ///
846    /// # Parameters
847    ///
848    /// - `data`: The control vector data (embedding values). Pass an empty slice to clear.
849    /// - `n_embd`: The embedding dimension.
850    /// - `il_start`: The starting layer index (inclusive).
851    /// - `il_end`: The ending layer index (exclusive).
852    ///
853    /// # Errors
854    ///
855    /// Returns `Err` with the error code if the operation fails.
856    pub fn set_adapter_cvec(
857        &mut self,
858        data: &[f32],
859        n_embd: i32,
860        il_start: i32,
861        il_end: i32,
862    ) -> Result<(), i32> {
863        let ret = unsafe {
864            llama_cpp_sys_4::llama_set_adapter_cvec(
865                self.context.as_ptr(),
866                data.as_ptr(),
867                data.len(),
868                n_embd,
869                il_start,
870                il_end,
871            )
872        };
873        if ret != 0 {
874            Err(ret)
875        } else {
876            Ok(())
877        }
878    }
879
880    /// Get sampled token debug info for the `i`th position.
881    ///
882    /// Returns the sampled token at position `i` from the last decode call.
883    #[must_use]
884    pub fn get_sampled_token_ith(&self, i: i32) -> LlamaToken {
885        let token =
886            unsafe { llama_cpp_sys_4::llama_get_sampled_token_ith(self.context.as_ptr(), i) };
887        LlamaToken(token)
888    }
889
890    /// Get sampled candidate tokens for the `i`th position.
891    ///
892    /// Returns a slice of candidate tokens from the last decode call.
893    #[must_use]
894    pub fn get_sampled_candidates_ith(&self, i: i32) -> &[LlamaToken] {
895        let count = unsafe {
896            llama_cpp_sys_4::llama_get_sampled_candidates_count_ith(self.context.as_ptr(), i)
897        } as usize;
898        if count == 0 {
899            return &[];
900        }
901        let ptr =
902            unsafe { llama_cpp_sys_4::llama_get_sampled_candidates_ith(self.context.as_ptr(), i) };
903        if ptr.is_null() {
904            return &[];
905        }
906        unsafe { slice::from_raw_parts(ptr.cast::<LlamaToken>(), count) }
907    }
908
909    /// Get the number of sampled logits for the `i`th position.
910    #[must_use]
911    pub fn get_sampled_logits_count_ith(&self, i: i32) -> u32 {
912        unsafe { llama_cpp_sys_4::llama_get_sampled_logits_count_ith(self.context.as_ptr(), i) }
913    }
914
915    /// Get sampled logits for the `i`th position.
916    ///
917    /// Returns a slice of logit values from the last decode call.
918    #[must_use]
919    pub fn get_sampled_logits_ith(&self, i: i32) -> &[f32] {
920        let count = self.get_sampled_logits_count_ith(i) as usize;
921        if count == 0 {
922            return &[];
923        }
924        let ptr =
925            unsafe { llama_cpp_sys_4::llama_get_sampled_logits_ith(self.context.as_ptr(), i) };
926        if ptr.is_null() {
927            return &[];
928        }
929        unsafe { slice::from_raw_parts(ptr, count) }
930    }
931
932    /// Get the number of sampled probabilities for the `i`th position.
933    #[must_use]
934    pub fn get_sampled_probs_count_ith(&self, i: i32) -> u32 {
935        unsafe { llama_cpp_sys_4::llama_get_sampled_probs_count_ith(self.context.as_ptr(), i) }
936    }
937
938    /// Get sampled probabilities for the `i`th position.
939    ///
940    /// Returns a slice of probability values from the last decode call.
941    #[must_use]
942    pub fn get_sampled_probs_ith(&self, i: i32) -> &[f32] {
943        let count = self.get_sampled_probs_count_ith(i) as usize;
944        if count == 0 {
945            return &[];
946        }
947        let ptr = unsafe { llama_cpp_sys_4::llama_get_sampled_probs_ith(self.context.as_ptr(), i) };
948        if ptr.is_null() {
949            return &[];
950        }
951        unsafe { slice::from_raw_parts(ptr, count) }
952    }
953
954    /// Get the size of a single sequence's state with flags.
955    #[must_use]
956    pub fn state_seq_get_size_ext(&mut self, seq_id: i32, flags: u32) -> usize {
957        unsafe {
958            llama_cpp_sys_4::llama_state_seq_get_size_ext(self.context.as_ptr(), seq_id, flags)
959        }
960    }
961
962    /// Copy a single sequence's state into a byte buffer with flags.
963    ///
964    /// Returns the number of bytes written.
965    pub fn state_seq_get_data_ext(&mut self, dst: &mut [u8], seq_id: i32, flags: u32) -> usize {
966        unsafe {
967            llama_cpp_sys_4::llama_state_seq_get_data_ext(
968                self.context.as_ptr(),
969                dst.as_mut_ptr(),
970                dst.len(),
971                seq_id,
972                flags,
973            )
974        }
975    }
976
977    /// Restore a single sequence's state from a byte buffer with flags.
978    ///
979    /// Returns the number of bytes read.
980    pub fn state_seq_set_data_ext(&mut self, src: &[u8], dest_seq_id: i32, flags: u32) -> usize {
981        unsafe {
982            llama_cpp_sys_4::llama_state_seq_set_data_ext(
983                self.context.as_ptr(),
984                src.as_ptr(),
985                src.len(),
986                dest_seq_id,
987                flags,
988            )
989        }
990    }
991
992    /// Set an abort callback for the context.
993    ///
994    /// The callback is called periodically during computation. If it returns `true`,
995    /// the computation is aborted.
996    ///
997    /// # Safety
998    ///
999    /// The callback data must remain valid for the lifetime of the context or until
1000    /// the callback is replaced.
1001    pub unsafe fn set_abort_callback(
1002        &mut self,
1003        callback: llama_cpp_sys_4::ggml_abort_callback,
1004        data: *mut std::ffi::c_void,
1005    ) {
1006        llama_cpp_sys_4::llama_set_abort_callback(self.context.as_ptr(), callback, data);
1007    }
1008
1009    /// Attach a thread pool to the context.
1010    ///
1011    /// # Safety
1012    ///
1013    /// The thread pools must remain valid for the lifetime of the context or until
1014    /// they are detached.
1015    pub unsafe fn attach_threadpool(
1016        &mut self,
1017        threadpool: llama_cpp_sys_4::ggml_threadpool_t,
1018        threadpool_batch: llama_cpp_sys_4::ggml_threadpool_t,
1019    ) {
1020        llama_cpp_sys_4::llama_attach_threadpool(
1021            self.context.as_ptr(),
1022            threadpool,
1023            threadpool_batch,
1024        );
1025    }
1026
1027    /// Detach the thread pool from the context.
1028    pub fn detach_threadpool(&mut self) {
1029        unsafe {
1030            llama_cpp_sys_4::llama_detach_threadpool(self.context.as_ptr());
1031        }
1032    }
1033
1034    /// Set a sampler for a specific sequence.
1035    ///
1036    /// Returns `true` on success.
1037    pub fn set_sampler(
1038        &mut self,
1039        seq_id: i32,
1040        sampler: &mut crate::sampling::LlamaSampler,
1041    ) -> bool {
1042        unsafe {
1043            llama_cpp_sys_4::llama_set_sampler(
1044                self.context.as_ptr(),
1045                seq_id,
1046                sampler.sampler.as_ptr(),
1047            )
1048        }
1049    }
1050
1051    /// Get the raw model pointer from this context.
1052    ///
1053    /// This is mainly useful for FFI interop. In normal usage, access
1054    /// the model via the `model` field instead.
1055    #[must_use]
1056    pub fn get_model_ptr(&self) -> *const llama_cpp_sys_4::llama_model {
1057        unsafe { llama_cpp_sys_4::llama_get_model(self.context.as_ptr()) }
1058    }
1059
1060    /// Sets a lora adapter.
1061    ///
1062    /// # Errors
1063    ///
1064    /// See [`LlamaLoraAdapterSetError`] for more information.
1065    pub fn lora_adapter_set(
1066        &self,
1067        adapter: &mut LlamaLoraAdapter,
1068        scale: f32,
1069    ) -> Result<(), LlamaLoraAdapterSetError> {
1070        let err_code = unsafe {
1071            // llama_set_adapter_lora / llama_rm_adapter_lora were replaced by llama_set_adapters_lora
1072            // which takes a full list of adapters + scales at once (b8249+)
1073            let mut adapter_ptr = adapter.lora_adapter.as_ptr();
1074            let mut scale_val = scale;
1075            llama_cpp_sys_4::llama_set_adapters_lora(
1076                self.context.as_ptr(),
1077                &raw mut adapter_ptr,
1078                1,
1079                &raw mut scale_val,
1080            )
1081        };
1082        if err_code != 0 {
1083            return Err(LlamaLoraAdapterSetError::ErrorResult(err_code));
1084        }
1085
1086        tracing::debug!("Set lora adapter");
1087        Ok(())
1088    }
1089
1090    /// Remove all lora adapters from the context.
1091    ///
1092    /// Note: as of llama.cpp b8249 the per-adapter remove API was replaced by
1093    /// `llama_set_adapters_lora` which operates on the full adapter list at once.
1094    /// Calling this function clears **all** adapters currently set on the context.
1095    ///
1096    /// # Errors
1097    ///
1098    /// See [`LlamaLoraAdapterRemoveError`] for more information.
1099    pub fn lora_adapter_remove(
1100        &self,
1101        _adapter: &mut LlamaLoraAdapter,
1102    ) -> Result<(), LlamaLoraAdapterRemoveError> {
1103        let err_code = unsafe {
1104            llama_cpp_sys_4::llama_set_adapters_lora(
1105                self.context.as_ptr(),
1106                std::ptr::null_mut(),
1107                0,
1108                std::ptr::null_mut(),
1109            )
1110        };
1111        if err_code != 0 {
1112            return Err(LlamaLoraAdapterRemoveError::ErrorResult(err_code));
1113        }
1114
1115        tracing::debug!("Remove lora adapter");
1116        Ok(())
1117    }
1118}
1119
1120impl Drop for LlamaContext<'_> {
1121    fn drop(&mut self) {
1122        unsafe { llama_cpp_sys_4::llama_free(self.context.as_ptr()) }
1123    }
1124}