Skip to main content

llama_cpp_4/
context.rs

1//! Safe wrapper around `llama_context`.
2
3use std::fmt::{Debug, Formatter};
4use std::num::NonZeroI32;
5use std::ptr::NonNull;
6use std::slice;
7
8use llama_cpp_sys_4::llama_pooling_type;
9use params::LlamaPoolingType;
10use perf::PerfContextData;
11
12use crate::llama_batch::LlamaBatch;
13use crate::model::{LlamaLoraAdapter, LlamaModel};
14use crate::token::data::LlamaTokenData;
15use crate::token::data_array::LlamaTokenDataArray;
16use crate::token::LlamaToken;
17use crate::{
18    DecodeError, EmbeddingsError, EncodeError, LlamaLoraAdapterRemoveError,
19    LlamaLoraAdapterSetError,
20};
21
22pub mod kv_cache;
23pub mod params;
24pub mod perf;
25pub mod session;
26pub mod tensor_capture;
27
28/// A safe wrapper around the `llama_context` C++ context.
29///
30/// This struct provides a safe interface to interact with the `llama_context` used by the `LlamaModel`.
31/// It encapsulates the raw C++ context pointer and provides additional fields for managing the model and
32/// context-specific settings like embeddings and logits.
33///
34/// The `LlamaContext` struct ensures that the C++ context is always valid by using the `NonNull` type for
35/// the context pointer, preventing it from being null. The struct also holds a reference to the model
36/// (`LlamaModel`) that the context is tied to, along with some internal state like whether embeddings are enabled
37/// and the initialized logits for the context.
38///
39/// # Fields
40///
41/// - `context`: A non-null pointer to the raw C++ `llama_context`. This is the main context used for interacting with the model.
42/// - `model`: A reference to the `LlamaModel` associated with this context. This model provides the data and parameters
43///   that the context interacts with.
44/// - `initialized_logits`: A vector used to store the initialized logits. These are used in the model's processing and
45///   are kept separate from the context data.
46/// - `embeddings_enabled`: A boolean flag indicating whether embeddings are enabled in the context. This is useful for
47///   controlling whether embedding data is generated during the interaction with the model.
48#[allow(clippy::module_name_repetitions)]
49pub struct LlamaContext<'a> {
50    pub(crate) context: NonNull<llama_cpp_sys_4::llama_context>,
51    /// a reference to the contexts model.
52    pub model: &'a LlamaModel,
53    initialized_logits: Vec<i32>,
54    embeddings_enabled: bool,
55}
56
57impl Debug for LlamaContext<'_> {
58    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
59        f.debug_struct("LlamaContext")
60            .field("context", &self.context)
61            .finish()
62    }
63}
64
65impl<'model> LlamaContext<'model> {
66    /// Creates a new instance of `LlamaContext` with the provided model, context, and embeddings flag.
67    ///
68    /// This function initializes a new `LlamaContext` object, which is used to interact with the
69    /// `LlamaModel`. The context is created from a pointer to a C++ context and the embeddings flag
70    /// determines whether embeddings are enabled in the context.
71    ///
72    /// # Parameters
73    ///
74    /// - `llama_model`: A reference to an existing `LlamaModel` that will be used with the new context.
75    /// - `llama_context`: A non-null pointer to an existing `llama_cpp_sys_4::llama_context` representing
76    ///   the context created in previous steps. This context is necessary for interacting with the model.
77    /// - `embeddings_enabled`: A boolean flag indicating whether embeddings are enabled in this context.
78    ///
79    /// # Returns
80    ///
81    /// This function returns a new instance of `LlamaContext` initialized with the given parameters:
82    /// - The model reference (`llama_model`) is stored in the context.
83    /// - The raw context pointer (`llama_context`) is wrapped in a `NonNull` to ensure safety.
84    /// - The `embeddings_enabled` flag is used to determine if embeddings are enabled for the context.
85    ///
86    /// # Example
87    /// ```ignore
88    /// let llama_model = LlamaModel::load_from_file(&backend, "path/to/model", &params).unwrap();
89    /// let context_ptr = NonNull::new(some_llama_context_ptr).unwrap();
90    /// let context = LlamaContext::new(&llama_model, context_ptr, true);
91    /// // Now you can use the context
92    /// ```
93    pub(crate) fn new(
94        llama_model: &'model LlamaModel,
95        llama_context: NonNull<llama_cpp_sys_4::llama_context>,
96        embeddings_enabled: bool,
97    ) -> Self {
98        Self {
99            context: llama_context,
100            model: llama_model,
101            initialized_logits: Vec::new(),
102            embeddings_enabled,
103        }
104    }
105
106    /// Gets the max number of logical tokens that can be submitted to decode. Must be greater than or equal to `n_ubatch`.
107    #[must_use]
108    pub fn n_batch(&self) -> u32 {
109        unsafe { llama_cpp_sys_4::llama_n_batch(self.context.as_ptr()) }
110    }
111
112    /// Gets the max number of physical tokens (hardware level) to decode in batch. Must be less than or equal to `n_batch`.
113    #[must_use]
114    pub fn n_ubatch(&self) -> u32 {
115        unsafe { llama_cpp_sys_4::llama_n_ubatch(self.context.as_ptr()) }
116    }
117
118    /// Gets the size of the context.
119    #[must_use]
120    pub fn n_ctx(&self) -> u32 {
121        unsafe { llama_cpp_sys_4::llama_n_ctx(self.context.as_ptr()) }
122    }
123
124    /// Decodes the batch.
125    ///
126    /// # Errors
127    ///
128    /// - `DecodeError` if the decoding failed.
129    ///
130    /// # Panics
131    ///
132    /// - the returned [`std::ffi::c_int`] from llama-cpp does not fit into a i32 (this should never happen on most systems)
133    pub fn decode(&mut self, batch: &mut LlamaBatch) -> Result<(), DecodeError> {
134        let result =
135            unsafe { llama_cpp_sys_4::llama_decode(self.context.as_ptr(), batch.llama_batch) };
136
137        match NonZeroI32::new(result) {
138            None => {
139                self.initialized_logits
140                    .clone_from(&batch.initialized_logits);
141                Ok(())
142            }
143            Some(error) => Err(DecodeError::from(error)),
144        }
145    }
146
147    /// Encodes the batch.
148    ///
149    /// # Errors
150    ///
151    /// - `EncodeError` if the decoding failed.
152    ///
153    /// # Panics
154    ///
155    /// - the returned [`std::ffi::c_int`] from llama-cpp does not fit into a i32 (this should never happen on most systems)
156    pub fn encode(&mut self, batch: &mut LlamaBatch) -> Result<(), EncodeError> {
157        let result =
158            unsafe { llama_cpp_sys_4::llama_encode(self.context.as_ptr(), batch.llama_batch) };
159
160        match NonZeroI32::new(result) {
161            None => {
162                self.initialized_logits
163                    .clone_from(&batch.initialized_logits);
164                Ok(())
165            }
166            Some(error) => Err(EncodeError::from(error)),
167        }
168    }
169
170    /// Return Pooling type for Llama's Context
171    #[must_use]
172    pub fn pooling_type(&self) -> LlamaPoolingType {
173        let pooling_type = unsafe { llama_pooling_type(self.context.as_ptr()) };
174
175        LlamaPoolingType::from(pooling_type)
176    }
177
178    /// Get the embeddings for the `i`th sequence in the current context.
179    ///
180    /// # Returns
181    ///
182    /// A slice containing the embeddings for the last decoded batch.
183    /// The size corresponds to the `n_embd` parameter of the context's model.
184    ///
185    /// # Errors
186    ///
187    /// - When the current context was constructed without enabling embeddings.
188    /// - If the current model had a pooling type of [`llama_cpp_sys_4::LLAMA_POOLING_TYPE_NONE`]
189    /// - If the given sequence index exceeds the max sequence id.
190    ///
191    /// # Panics
192    ///
193    /// * `n_embd` does not fit into a usize
194    pub fn embeddings_seq_ith(&self, i: i32) -> Result<&[f32], EmbeddingsError> {
195        if !self.embeddings_enabled {
196            return Err(EmbeddingsError::NotEnabled);
197        }
198
199        let n_embd =
200            usize::try_from(self.model.n_embd()).expect("n_embd does not fit into a usize");
201
202        unsafe {
203            let embedding = llama_cpp_sys_4::llama_get_embeddings_seq(self.context.as_ptr(), i);
204
205            // Technically also possible whenever `i >= max(batch.n_seq)`, but can't check that here.
206            if embedding.is_null() {
207                Err(EmbeddingsError::NonePoolType)
208            } else {
209                Ok(slice::from_raw_parts(embedding, n_embd))
210            }
211        }
212    }
213
214    /// Get the embeddings for the `i`th token in the current context.
215    ///
216    /// # Returns
217    ///
218    /// A slice containing the embeddings for the last decoded batch of the given token.
219    /// The size corresponds to the `n_embd` parameter of the context's model.
220    ///
221    /// # Errors
222    ///
223    /// - When the current context was constructed without enabling embeddings.
224    /// - When the given token didn't have logits enabled when it was passed.
225    /// - If the given token index exceeds the max token id.
226    ///
227    /// # Panics
228    ///
229    /// * `n_embd` does not fit into a usize
230    pub fn embeddings_ith(&self, i: i32) -> Result<&[f32], EmbeddingsError> {
231        if !self.embeddings_enabled {
232            return Err(EmbeddingsError::NotEnabled);
233        }
234
235        let n_embd =
236            usize::try_from(self.model.n_embd()).expect("n_embd does not fit into a usize");
237
238        unsafe {
239            let embedding = llama_cpp_sys_4::llama_get_embeddings_ith(self.context.as_ptr(), i);
240            // Technically also possible whenever `i >= batch.n_tokens`, but no good way of checking `n_tokens` here.
241            if embedding.is_null() {
242                Err(EmbeddingsError::LogitsNotEnabled)
243            } else {
244                Ok(slice::from_raw_parts(embedding, n_embd))
245            }
246        }
247    }
248
249    /// Get the logits for the last token in the context.
250    ///
251    /// # Returns
252    /// An iterator over unsorted `LlamaTokenData` containing the
253    /// logits for the last token in the context.
254    ///
255    /// # Panics
256    ///
257    /// - underlying logits data is null
258    pub fn candidates(&self) -> impl Iterator<Item = LlamaTokenData> + '_ {
259        (0_i32..).zip(self.get_logits()).map(|(i, logit)| {
260            let token = LlamaToken::new(i);
261            LlamaTokenData::new(token, *logit, 0_f32)
262        })
263    }
264
265    /// Get the token data array for the last token in the context.
266    ///
267    /// This is a convience method that implements:
268    /// ```ignore
269    /// LlamaTokenDataArray::from_iter(ctx.candidates(), false)
270    /// ```
271    ///
272    /// # Panics
273    ///
274    /// - underlying logits data is null
275    #[must_use]
276    pub fn token_data_array(&self) -> LlamaTokenDataArray {
277        LlamaTokenDataArray::from_iter(self.candidates(), false)
278    }
279
280    /// Token logits obtained from the last call to `decode()`.
281    /// The logits for which `batch.logits[i] != 0` are stored contiguously
282    /// in the order they have appeared in the batch.
283    /// Rows: number of tokens for which `batch.logits[i] != 0`
284    /// Cols: `n_vocab`
285    ///
286    /// # Returns
287    ///
288    /// A slice containing the logits for the last decoded token.
289    /// The size corresponds to the `n_vocab` parameter of the context's model.
290    ///
291    /// # Panics
292    ///
293    /// - `n_vocab` does not fit into a usize
294    /// - token data returned is null
295    #[must_use]
296    pub fn get_logits(&self) -> &[f32] {
297        let data = unsafe { llama_cpp_sys_4::llama_get_logits(self.context.as_ptr()) };
298        assert!(!data.is_null(), "logits data for last token is null");
299        let len = usize::try_from(self.model.n_vocab()).expect("n_vocab does not fit into a usize");
300
301        unsafe { slice::from_raw_parts(data, len) }
302    }
303
304    /// Get the logits for the ith token in the context.
305    ///
306    /// # Panics
307    ///
308    /// - logit `i` is not initialized.
309    pub fn candidates_ith(&self, i: i32) -> impl Iterator<Item = LlamaTokenData> + '_ {
310        (0_i32..).zip(self.get_logits_ith(i)).map(|(i, logit)| {
311            let token = LlamaToken::new(i);
312            LlamaTokenData::new(token, *logit, 0_f32)
313        })
314    }
315
316    /// Get the logits for the ith token in the context.
317    ///
318    /// # Panics
319    ///
320    /// - `i` is greater than `n_ctx`
321    /// - `n_vocab` does not fit into a usize
322    /// - logit `i` is not initialized.
323    #[must_use]
324    pub fn get_logits_ith(&self, i: i32) -> &[f32] {
325        assert!(
326            self.initialized_logits.contains(&i),
327            "logit {i} is not initialized. only {:?} is",
328            self.initialized_logits
329        );
330        assert!(
331            self.n_ctx() > u32::try_from(i).expect("i does not fit into a u32"),
332            "n_ctx ({}) must be greater than i ({})",
333            self.n_ctx(),
334            i
335        );
336
337        let data = unsafe { llama_cpp_sys_4::llama_get_logits_ith(self.context.as_ptr(), i) };
338        let len = usize::try_from(self.model.n_vocab()).expect("n_vocab does not fit into a usize");
339
340        unsafe { slice::from_raw_parts(data, len) }
341    }
342
343    /// Get the number of context tokens per sequence.
344    #[must_use]
345    pub fn n_ctx_seq(&self) -> u32 {
346        unsafe { llama_cpp_sys_4::llama_n_ctx_seq(self.context.as_ptr()) }
347    }
348
349    /// Get the maximum number of sequences.
350    #[must_use]
351    pub fn n_seq_max(&self) -> u32 {
352        unsafe { llama_cpp_sys_4::llama_n_seq_max(self.context.as_ptr()) }
353    }
354
355    /// Get the number of recurrent-state snapshots per sequence.
356    #[must_use]
357    pub fn n_rs_seq(&self) -> u32 {
358        unsafe { llama_cpp_sys_4::llama_n_rs_seq(self.context.as_ptr()) }
359    }
360
361    /// Get the number of threads used for generation.
362    #[must_use]
363    pub fn n_threads(&self) -> i32 {
364        unsafe { llama_cpp_sys_4::llama_n_threads(self.context.as_ptr()) }
365    }
366
367    /// Get the number of threads used for batch processing.
368    #[must_use]
369    pub fn n_threads_batch(&self) -> i32 {
370        unsafe { llama_cpp_sys_4::llama_n_threads_batch(self.context.as_ptr()) }
371    }
372
373    /// Set the number of threads used for generation and batch processing.
374    pub fn set_n_threads(&mut self, n_threads: i32, n_threads_batch: i32) {
375        unsafe {
376            llama_cpp_sys_4::llama_set_n_threads(self.context.as_ptr(), n_threads, n_threads_batch);
377        }
378    }
379
380    /// Set whether to use causal attention.
381    ///
382    /// If set to `false`, the model will use non-causal attention, which is
383    /// needed for embedding models.
384    pub fn set_causal_attn(&mut self, causal_attn: bool) {
385        unsafe {
386            llama_cpp_sys_4::llama_set_causal_attn(self.context.as_ptr(), causal_attn);
387        }
388    }
389
390    /// Set whether to compute embeddings.
391    ///
392    /// This allows toggling embedding mode at runtime (as opposed to only at
393    /// context creation time).
394    pub fn set_embeddings(&mut self, embeddings: bool) {
395        self.embeddings_enabled = embeddings;
396        unsafe {
397            llama_cpp_sys_4::llama_set_embeddings(self.context.as_ptr(), embeddings);
398        }
399    }
400
401    /// Mark the next computation as a warmup run.
402    ///
403    /// Warmup runs are useful for GPU backends to compile kernels before
404    /// actual inference begins.
405    pub fn set_warmup(&mut self, warmup: bool) {
406        unsafe {
407            llama_cpp_sys_4::llama_set_warmup(self.context.as_ptr(), warmup);
408        }
409    }
410
411    /// Wait for all pending async computations to finish.
412    pub fn synchronize(&mut self) {
413        unsafe {
414            llama_cpp_sys_4::llama_synchronize(self.context.as_ptr());
415        }
416    }
417
418    /// Get all embeddings for the current context.
419    ///
420    /// Returns a slice of all embeddings from the last decoded batch.
421    /// For pooled embeddings use [`embeddings_seq_ith`](Self::embeddings_seq_ith) instead.
422    ///
423    /// # Errors
424    ///
425    /// - When the current context was constructed without enabling embeddings.
426    /// - If the embeddings pointer is null.
427    ///
428    /// # Panics
429    ///
430    /// * `n_embd` does not fit into a usize
431    pub fn get_embeddings(&self) -> Result<&[f32], EmbeddingsError> {
432        if !self.embeddings_enabled {
433            return Err(EmbeddingsError::NotEnabled);
434        }
435
436        let n_embd =
437            usize::try_from(self.model.n_embd()).expect("n_embd does not fit into a usize");
438
439        unsafe {
440            let embedding = llama_cpp_sys_4::llama_get_embeddings(self.context.as_ptr());
441            if embedding.is_null() {
442                Err(EmbeddingsError::NonePoolType)
443            } else {
444                Ok(slice::from_raw_parts(embedding, n_embd))
445            }
446        }
447    }
448
449    /// Toggle extraction of pre-norm embeddings — the hidden state right
450    /// before the final output norm, used by MTP draft heads (upstream
451    /// llama.cpp PR #23198).
452    ///
453    /// If `masked` is `true`, pre-norm rows are extracted only for tokens
454    /// whose `batch.logits[i]` is non-zero. If `masked` is `false`, rows are
455    /// extracted for every token in the batch regardless of `batch.logits` —
456    /// callers can then leave `batch.logits[i] = false` on prompt-fill
457    /// positions and avoid copying the full logits row for each one.
458    ///
459    /// `MtpSession::new` already sets this up (unmasked on the target,
460    /// masked on the draft). Callers normally don't need to invoke this
461    /// directly.
462    pub fn set_embeddings_pre_norm(&mut self, value: bool, masked: bool) {
463        unsafe {
464            llama_cpp_sys_4::llama_set_embeddings_pre_norm(self.context.as_ptr(), value, masked);
465        }
466    }
467
468    /// Get the full pre-norm embeddings buffer for the last decoded batch.
469    ///
470    /// Returns `None` when pre-norm embeddings are disabled or the buffer
471    /// hasn't been populated. The length of the returned slice is
472    /// `n_embd * <number of pre-norm rows>` — interpretation of the row
473    /// count depends on whether the setter was called with `masked=true`
474    /// (one row per sampled token) or `masked=false` (one row per batch
475    /// token). Use [`get_embeddings_pre_norm_ith`](Self::get_embeddings_pre_norm_ith)
476    /// when you only need a single row.
477    #[must_use]
478    pub fn get_embeddings_pre_norm(&self) -> Option<&[f32]> {
479        let n_embd =
480            usize::try_from(self.model.n_embd()).expect("n_embd does not fit into a usize");
481        unsafe {
482            let p = llama_cpp_sys_4::llama_get_embeddings_pre_norm(self.context.as_ptr());
483            if p.is_null() {
484                None
485            } else {
486                Some(slice::from_raw_parts(p, n_embd))
487            }
488        }
489    }
490
491    /// Get the pre-norm embedding row for the `i`th output position of the
492    /// last decoded batch. Returns `None` if upstream rejects the index
493    /// (e.g. masked mode with `batch.logits[i] == 0`, or out of range).
494    #[must_use]
495    pub fn get_embeddings_pre_norm_ith(&self, i: i32) -> Option<&[f32]> {
496        let n_embd =
497            usize::try_from(self.model.n_embd()).expect("n_embd does not fit into a usize");
498        unsafe {
499            let p = llama_cpp_sys_4::llama_get_embeddings_pre_norm_ith(self.context.as_ptr(), i);
500            if p.is_null() {
501                None
502            } else {
503                Some(slice::from_raw_parts(p, n_embd))
504            }
505        }
506    }
507
508    /// Reset the timings for the context.
509    pub fn reset_timings(&mut self) {
510        unsafe { llama_cpp_sys_4::ggml_time_init() }
511    }
512
513    /// Returns the timings for the context.
514    pub fn timings(&mut self) -> PerfContextData {
515        let perf_context_data =
516            unsafe { llama_cpp_sys_4::llama_perf_context(self.context.as_ptr()) };
517        PerfContextData { perf_context_data }
518    }
519
520    /// Reset the performance counters for the context.
521    pub fn perf_context_reset(&mut self) {
522        unsafe { llama_cpp_sys_4::llama_perf_context_reset(self.context.as_ptr()) }
523    }
524
525    /// Check if the KV cache memory supports shifting.
526    #[must_use]
527    pub fn memory_can_shift(&self) -> bool {
528        unsafe {
529            let mem = llama_cpp_sys_4::llama_get_memory(self.context.as_ptr());
530            llama_cpp_sys_4::llama_memory_can_shift(mem)
531        }
532    }
533
534    /// Get the minimum position in a sequence's KV cache.
535    #[must_use]
536    pub fn memory_seq_pos_min(&self, seq_id: i32) -> i32 {
537        unsafe {
538            let mem = llama_cpp_sys_4::llama_get_memory(self.context.as_ptr());
539            llama_cpp_sys_4::llama_memory_seq_pos_min(mem, seq_id)
540        }
541    }
542
543    /// Print a breakdown of the memory usage.
544    pub fn memory_breakdown_print(&self) {
545        unsafe {
546            llama_cpp_sys_4::common_memory_breakdown_print(self.context.as_ptr());
547        }
548    }
549
550    /// Get the size of the full context state in bytes.
551    ///
552    /// This is the size needed for [`state_get_data`](Self::state_get_data) and
553    /// [`state_set_data`](Self::state_set_data).
554    #[must_use]
555    pub fn state_get_size(&mut self) -> usize {
556        unsafe { llama_cpp_sys_4::llama_state_get_size(self.context.as_ptr()) }
557    }
558
559    /// Copy the full context state into a byte buffer.
560    ///
561    /// The buffer must be at least [`state_get_size`](Self::state_get_size) bytes.
562    ///
563    /// Returns the number of bytes written.
564    pub fn state_get_data(&mut self, dst: &mut [u8]) -> usize {
565        unsafe {
566            llama_cpp_sys_4::llama_state_get_data(
567                self.context.as_ptr(),
568                dst.as_mut_ptr(),
569                dst.len(),
570            )
571        }
572    }
573
574    /// Restore the full context state from a byte buffer.
575    ///
576    /// Returns the number of bytes read.
577    pub fn state_set_data(&mut self, src: &[u8]) -> usize {
578        unsafe {
579            llama_cpp_sys_4::llama_state_set_data(self.context.as_ptr(), src.as_ptr(), src.len())
580        }
581    }
582
583    /// Save the context state to a file along with the given tokens.
584    ///
585    /// Returns `true` on success.
586    ///
587    /// # Panics
588    ///
589    /// Panics if the path contains null bytes.
590    pub fn state_save_file(
591        &mut self,
592        path: impl AsRef<std::path::Path>,
593        tokens: &[LlamaToken],
594    ) -> bool {
595        let path_str = path.as_ref().to_str().expect("path is not valid UTF-8");
596        let c_path = std::ffi::CString::new(path_str).expect("path contains null bytes");
597        unsafe {
598            llama_cpp_sys_4::llama_state_save_file(
599                self.context.as_ptr(),
600                c_path.as_ptr(),
601                tokens.as_ptr().cast(),
602                tokens.len(),
603            )
604        }
605    }
606
607    /// Load a context state from a file.
608    ///
609    /// Returns `true` on success and fills `tokens_out` with the saved tokens.
610    ///
611    /// # Panics
612    ///
613    /// Panics if the path contains null bytes.
614    pub fn state_load_file(
615        &mut self,
616        path: impl AsRef<std::path::Path>,
617        tokens_out: &mut Vec<LlamaToken>,
618        n_token_capacity: usize,
619    ) -> bool {
620        tokens_out.resize(n_token_capacity, LlamaToken(0));
621        let mut n_token_count: usize = 0;
622        let path_str = path.as_ref().to_str().expect("path is not valid UTF-8");
623        let c_path = std::ffi::CString::new(path_str).expect("path contains null bytes");
624        let ok = unsafe {
625            llama_cpp_sys_4::llama_state_load_file(
626                self.context.as_ptr(),
627                c_path.as_ptr(),
628                tokens_out.as_mut_ptr().cast(),
629                n_token_capacity,
630                std::ptr::addr_of_mut!(n_token_count),
631            )
632        };
633        if ok {
634            tokens_out.truncate(n_token_count);
635        }
636        ok
637    }
638
639    /// Get the size of a single sequence's state in bytes.
640    #[must_use]
641    pub fn state_seq_get_size(&mut self, seq_id: i32) -> usize {
642        unsafe { llama_cpp_sys_4::llama_state_seq_get_size(self.context.as_ptr(), seq_id) }
643    }
644
645    /// Copy a single sequence's state into a byte buffer.
646    ///
647    /// Returns the number of bytes written.
648    pub fn state_seq_get_data(&mut self, dst: &mut [u8], seq_id: i32) -> usize {
649        unsafe {
650            llama_cpp_sys_4::llama_state_seq_get_data(
651                self.context.as_ptr(),
652                dst.as_mut_ptr(),
653                dst.len(),
654                seq_id,
655            )
656        }
657    }
658
659    /// Restore a single sequence's state from a byte buffer.
660    ///
661    /// Returns the number of bytes read.
662    pub fn state_seq_set_data(&mut self, src: &[u8], dest_seq_id: i32) -> usize {
663        unsafe {
664            llama_cpp_sys_4::llama_state_seq_set_data(
665                self.context.as_ptr(),
666                src.as_ptr(),
667                src.len(),
668                dest_seq_id,
669            )
670        }
671    }
672
673    /// Save a single sequence's state to a file.
674    ///
675    /// Returns the number of bytes written (0 on failure).
676    ///
677    /// # Panics
678    ///
679    /// Panics if the path contains null bytes.
680    pub fn state_seq_save_file(
681        &mut self,
682        path: impl AsRef<std::path::Path>,
683        seq_id: i32,
684        tokens: &[LlamaToken],
685    ) -> usize {
686        let path_str = path.as_ref().to_str().expect("path is not valid UTF-8");
687        let c_path = std::ffi::CString::new(path_str).expect("path contains null bytes");
688        unsafe {
689            llama_cpp_sys_4::llama_state_seq_save_file(
690                self.context.as_ptr(),
691                c_path.as_ptr(),
692                seq_id,
693                tokens.as_ptr().cast(),
694                tokens.len(),
695            )
696        }
697    }
698
699    /// Load a single sequence's state from a file.
700    ///
701    /// Returns the number of bytes read (0 on failure).
702    ///
703    /// # Panics
704    ///
705    /// Panics if the path contains null bytes.
706    pub fn state_seq_load_file(
707        &mut self,
708        path: impl AsRef<std::path::Path>,
709        dest_seq_id: i32,
710        tokens_out: &mut Vec<LlamaToken>,
711        n_token_capacity: usize,
712    ) -> usize {
713        tokens_out.resize(n_token_capacity, LlamaToken(0));
714        let mut n_token_count: usize = 0;
715        let path_str = path.as_ref().to_str().expect("path is not valid UTF-8");
716        let c_path = std::ffi::CString::new(path_str).expect("path contains null bytes");
717        let ret = unsafe {
718            llama_cpp_sys_4::llama_state_seq_load_file(
719                self.context.as_ptr(),
720                c_path.as_ptr(),
721                dest_seq_id,
722                tokens_out.as_mut_ptr().cast(),
723                n_token_capacity,
724                std::ptr::addr_of_mut!(n_token_count),
725            )
726        };
727        if ret > 0 {
728            tokens_out.truncate(n_token_count);
729        }
730        ret
731    }
732
733    /// Set a control vector on the context.
734    ///
735    /// # Parameters
736    ///
737    /// - `data`: The control vector data (embedding values). Pass an empty slice to clear.
738    /// - `n_embd`: The embedding dimension.
739    /// - `il_start`: The starting layer index (inclusive).
740    /// - `il_end`: The ending layer index (exclusive).
741    ///
742    /// # Errors
743    ///
744    /// Returns `Err` with the error code if the operation fails.
745    pub fn set_adapter_cvec(
746        &mut self,
747        data: &[f32],
748        n_embd: i32,
749        il_start: i32,
750        il_end: i32,
751    ) -> Result<(), i32> {
752        let ret = unsafe {
753            llama_cpp_sys_4::llama_set_adapter_cvec(
754                self.context.as_ptr(),
755                data.as_ptr(),
756                data.len(),
757                n_embd,
758                il_start,
759                il_end,
760            )
761        };
762        if ret != 0 {
763            Err(ret)
764        } else {
765            Ok(())
766        }
767    }
768
769    /// Get sampled token debug info for the `i`th position.
770    ///
771    /// Returns the sampled token at position `i` from the last decode call.
772    #[must_use]
773    pub fn get_sampled_token_ith(&self, i: i32) -> LlamaToken {
774        let token =
775            unsafe { llama_cpp_sys_4::llama_get_sampled_token_ith(self.context.as_ptr(), i) };
776        LlamaToken(token)
777    }
778
779    /// Get sampled candidate tokens for the `i`th position.
780    ///
781    /// Returns a slice of candidate tokens from the last decode call.
782    #[must_use]
783    pub fn get_sampled_candidates_ith(&self, i: i32) -> &[LlamaToken] {
784        let count = unsafe {
785            llama_cpp_sys_4::llama_get_sampled_candidates_count_ith(self.context.as_ptr(), i)
786        } as usize;
787        if count == 0 {
788            return &[];
789        }
790        let ptr =
791            unsafe { llama_cpp_sys_4::llama_get_sampled_candidates_ith(self.context.as_ptr(), i) };
792        if ptr.is_null() {
793            return &[];
794        }
795        unsafe { slice::from_raw_parts(ptr.cast::<LlamaToken>(), count) }
796    }
797
798    /// Get the number of sampled logits for the `i`th position.
799    #[must_use]
800    pub fn get_sampled_logits_count_ith(&self, i: i32) -> u32 {
801        unsafe { llama_cpp_sys_4::llama_get_sampled_logits_count_ith(self.context.as_ptr(), i) }
802    }
803
804    /// Get sampled logits for the `i`th position.
805    ///
806    /// Returns a slice of logit values from the last decode call.
807    #[must_use]
808    pub fn get_sampled_logits_ith(&self, i: i32) -> &[f32] {
809        let count = self.get_sampled_logits_count_ith(i) as usize;
810        if count == 0 {
811            return &[];
812        }
813        let ptr =
814            unsafe { llama_cpp_sys_4::llama_get_sampled_logits_ith(self.context.as_ptr(), i) };
815        if ptr.is_null() {
816            return &[];
817        }
818        unsafe { slice::from_raw_parts(ptr, count) }
819    }
820
821    /// Get the number of sampled probabilities for the `i`th position.
822    #[must_use]
823    pub fn get_sampled_probs_count_ith(&self, i: i32) -> u32 {
824        unsafe { llama_cpp_sys_4::llama_get_sampled_probs_count_ith(self.context.as_ptr(), i) }
825    }
826
827    /// Get sampled probabilities for the `i`th position.
828    ///
829    /// Returns a slice of probability values from the last decode call.
830    #[must_use]
831    pub fn get_sampled_probs_ith(&self, i: i32) -> &[f32] {
832        let count = self.get_sampled_probs_count_ith(i) as usize;
833        if count == 0 {
834            return &[];
835        }
836        let ptr = unsafe { llama_cpp_sys_4::llama_get_sampled_probs_ith(self.context.as_ptr(), i) };
837        if ptr.is_null() {
838            return &[];
839        }
840        unsafe { slice::from_raw_parts(ptr, count) }
841    }
842
843    /// Get the size of a single sequence's state with flags.
844    #[must_use]
845    pub fn state_seq_get_size_ext(&mut self, seq_id: i32, flags: u32) -> usize {
846        unsafe {
847            llama_cpp_sys_4::llama_state_seq_get_size_ext(self.context.as_ptr(), seq_id, flags)
848        }
849    }
850
851    /// Copy a single sequence's state into a byte buffer with flags.
852    ///
853    /// Returns the number of bytes written.
854    pub fn state_seq_get_data_ext(&mut self, dst: &mut [u8], seq_id: i32, flags: u32) -> usize {
855        unsafe {
856            llama_cpp_sys_4::llama_state_seq_get_data_ext(
857                self.context.as_ptr(),
858                dst.as_mut_ptr(),
859                dst.len(),
860                seq_id,
861                flags,
862            )
863        }
864    }
865
866    /// Restore a single sequence's state from a byte buffer with flags.
867    ///
868    /// Returns the number of bytes read.
869    pub fn state_seq_set_data_ext(&mut self, src: &[u8], dest_seq_id: i32, flags: u32) -> usize {
870        unsafe {
871            llama_cpp_sys_4::llama_state_seq_set_data_ext(
872                self.context.as_ptr(),
873                src.as_ptr(),
874                src.len(),
875                dest_seq_id,
876                flags,
877            )
878        }
879    }
880
881    /// Set an abort callback for the context.
882    ///
883    /// The callback is called periodically during computation. If it returns `true`,
884    /// the computation is aborted.
885    ///
886    /// # Safety
887    ///
888    /// The callback data must remain valid for the lifetime of the context or until
889    /// the callback is replaced.
890    pub unsafe fn set_abort_callback(
891        &mut self,
892        callback: llama_cpp_sys_4::ggml_abort_callback,
893        data: *mut std::ffi::c_void,
894    ) {
895        llama_cpp_sys_4::llama_set_abort_callback(self.context.as_ptr(), callback, data);
896    }
897
898    /// Attach a thread pool to the context.
899    ///
900    /// # Safety
901    ///
902    /// The thread pools must remain valid for the lifetime of the context or until
903    /// they are detached.
904    pub unsafe fn attach_threadpool(
905        &mut self,
906        threadpool: llama_cpp_sys_4::ggml_threadpool_t,
907        threadpool_batch: llama_cpp_sys_4::ggml_threadpool_t,
908    ) {
909        llama_cpp_sys_4::llama_attach_threadpool(
910            self.context.as_ptr(),
911            threadpool,
912            threadpool_batch,
913        );
914    }
915
916    /// Detach the thread pool from the context.
917    pub fn detach_threadpool(&mut self) {
918        unsafe {
919            llama_cpp_sys_4::llama_detach_threadpool(self.context.as_ptr());
920        }
921    }
922
923    /// Set a sampler for a specific sequence.
924    ///
925    /// Returns `true` on success.
926    pub fn set_sampler(
927        &mut self,
928        seq_id: i32,
929        sampler: &mut crate::sampling::LlamaSampler,
930    ) -> bool {
931        unsafe {
932            llama_cpp_sys_4::llama_set_sampler(
933                self.context.as_ptr(),
934                seq_id,
935                sampler.sampler.as_ptr(),
936            )
937        }
938    }
939
940    /// Get the raw model pointer from this context.
941    ///
942    /// This is mainly useful for FFI interop. In normal usage, access
943    /// the model via the `model` field instead.
944    #[must_use]
945    pub fn get_model_ptr(&self) -> *const llama_cpp_sys_4::llama_model {
946        unsafe { llama_cpp_sys_4::llama_get_model(self.context.as_ptr()) }
947    }
948
949    /// Sets a lora adapter.
950    ///
951    /// # Errors
952    ///
953    /// See [`LlamaLoraAdapterSetError`] for more information.
954    pub fn lora_adapter_set(
955        &self,
956        adapter: &mut LlamaLoraAdapter,
957        scale: f32,
958    ) -> Result<(), LlamaLoraAdapterSetError> {
959        let err_code = unsafe {
960            // llama_set_adapter_lora / llama_rm_adapter_lora were replaced by llama_set_adapters_lora
961            // which takes a full list of adapters + scales at once (b8249+)
962            let mut adapter_ptr = adapter.lora_adapter.as_ptr();
963            let mut scale_val = scale;
964            llama_cpp_sys_4::llama_set_adapters_lora(
965                self.context.as_ptr(),
966                &raw mut adapter_ptr,
967                1,
968                &raw mut scale_val,
969            )
970        };
971        if err_code != 0 {
972            return Err(LlamaLoraAdapterSetError::ErrorResult(err_code));
973        }
974
975        tracing::debug!("Set lora adapter");
976        Ok(())
977    }
978
979    /// Remove all lora adapters from the context.
980    ///
981    /// Note: as of llama.cpp b8249 the per-adapter remove API was replaced by
982    /// `llama_set_adapters_lora` which operates on the full adapter list at once.
983    /// Calling this function clears **all** adapters currently set on the context.
984    ///
985    /// # Errors
986    ///
987    /// See [`LlamaLoraAdapterRemoveError`] for more information.
988    pub fn lora_adapter_remove(
989        &self,
990        _adapter: &mut LlamaLoraAdapter,
991    ) -> Result<(), LlamaLoraAdapterRemoveError> {
992        let err_code = unsafe {
993            llama_cpp_sys_4::llama_set_adapters_lora(
994                self.context.as_ptr(),
995                std::ptr::null_mut(),
996                0,
997                std::ptr::null_mut(),
998            )
999        };
1000        if err_code != 0 {
1001            return Err(LlamaLoraAdapterRemoveError::ErrorResult(err_code));
1002        }
1003
1004        tracing::debug!("Remove lora adapter");
1005        Ok(())
1006    }
1007}
1008
1009impl Drop for LlamaContext<'_> {
1010    fn drop(&mut self) {
1011        unsafe { llama_cpp_sys_4::llama_free(self.context.as_ptr()) }
1012    }
1013}