// llama_cpp_4/context.rs
1//! Safe wrapper around `llama_context`.
2
3use std::fmt::{Debug, Formatter};
4use std::num::NonZeroI32;
5use std::ptr::NonNull;
6use std::slice;
7
8use llama_cpp_sys_4::llama_pooling_type;
9use params::LlamaPoolingType;
10use perf::PerfContextData;
11
12use crate::llama_batch::LlamaBatch;
13use crate::model::{LlamaLoraAdapter, LlamaModel};
14use crate::token::data::LlamaTokenData;
15use crate::token::data_array::LlamaTokenDataArray;
16use crate::token::LlamaToken;
17use crate::{
18    DecodeError, EmbeddingsError, EncodeError, LlamaLoraAdapterRemoveError,
19    LlamaLoraAdapterSetError,
20};
21
22pub mod kv_cache;
23pub mod params;
24pub mod perf;
25pub mod session;
26
/// A safe wrapper around the `llama_context` C++ context.
///
/// Owns a non-null pointer to the raw `llama_context` and borrows the
/// [`LlamaModel`] it was created from, so the model is guaranteed to outlive
/// the context. The raw context is released with `llama_free` when this value
/// is dropped.
#[allow(clippy::module_name_repetitions)]
pub struct LlamaContext<'a> {
    // Non-null pointer to the raw C `llama_context`; all methods pass it to FFI.
    pub(crate) context: NonNull<llama_cpp_sys_4::llama_context>,
    /// a reference to the context's model.
    pub model: &'a LlamaModel,
    // Batch positions whose logits were initialized by the last successful
    // `decode`/`encode`; consulted by `get_logits_ith` before dereferencing.
    initialized_logits: Vec<i32>,
    // Whether embedding output is enabled (set at creation, toggled by
    // `set_embeddings`); checked by the `embeddings_*`/`get_embeddings` accessors.
    embeddings_enabled: bool,
}
55
// Debug output shows only the raw context pointer; the model and internal
// state are omitted to keep the output compact.
impl Debug for LlamaContext<'_> {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("LlamaContext")
            .field("context", &self.context)
            .finish()
    }
}
63
64impl<'model> LlamaContext<'model> {
    /// Creates a new `LlamaContext` wrapping an already-created raw context.
    ///
    /// # Parameters
    ///
    /// - `llama_model`: the [`LlamaModel`] this context belongs to; the borrow
    ///   ties the context's lifetime to the model.
    /// - `llama_context`: non-null pointer to the raw `llama_cpp_sys_4::llama_context`;
    ///   the new value takes ownership and frees it on drop.
    /// - `embeddings_enabled`: whether the context was created with embedding
    ///   output enabled (checked later by the `embeddings_*` accessors).
    ///
    /// # Example
    /// ```ignore
    /// let llama_model = LlamaModel::load("path/to/model").unwrap();
    /// let context_ptr = NonNull::new(some_llama_context_ptr).unwrap();
    /// let context = LlamaContext::new(&llama_model, context_ptr, true);
    /// // Now you can use the context
    /// ```
    pub(crate) fn new(
        llama_model: &'model LlamaModel,
        llama_context: NonNull<llama_cpp_sys_4::llama_context>,
        embeddings_enabled: bool,
    ) -> Self {
        Self {
            context: llama_context,
            model: llama_model,
            // No logits are valid until the first successful decode/encode.
            initialized_logits: Vec::new(),
            embeddings_enabled,
        }
    }
104
    /// Gets the max number of logical tokens that can be submitted to decode. Must be greater than or equal to `n_ubatch`.
    #[must_use]
    pub fn n_batch(&self) -> u32 {
        // SAFETY: `self.context` is a valid, non-null context pointer.
        unsafe { llama_cpp_sys_4::llama_n_batch(self.context.as_ptr()) }
    }

    /// Gets the max number of physical tokens (hardware level) to decode in batch. Must be less than or equal to `n_batch`.
    #[must_use]
    pub fn n_ubatch(&self) -> u32 {
        // SAFETY: `self.context` is a valid, non-null context pointer.
        unsafe { llama_cpp_sys_4::llama_n_ubatch(self.context.as_ptr()) }
    }

    /// Gets the size of the context.
    #[must_use]
    pub fn n_ctx(&self) -> u32 {
        // SAFETY: `self.context` is a valid, non-null context pointer.
        unsafe { llama_cpp_sys_4::llama_n_ctx(self.context.as_ptr()) }
    }
122
123    /// Decodes the batch.
124    ///
125    /// # Errors
126    ///
127    /// - `DecodeError` if the decoding failed.
128    ///
129    /// # Panics
130    ///
131    /// - the returned [`std::ffi::c_int`] from llama-cpp does not fit into a i32 (this should never happen on most systems)
132    pub fn decode(&mut self, batch: &mut LlamaBatch) -> Result<(), DecodeError> {
133        let result =
134            unsafe { llama_cpp_sys_4::llama_decode(self.context.as_ptr(), batch.llama_batch) };
135
136        match NonZeroI32::new(result) {
137            None => {
138                self.initialized_logits
139                    .clone_from(&batch.initialized_logits);
140                Ok(())
141            }
142            Some(error) => Err(DecodeError::from(error)),
143        }
144    }
145
146    /// Encodes the batch.
147    ///
148    /// # Errors
149    ///
150    /// - `EncodeError` if the decoding failed.
151    ///
152    /// # Panics
153    ///
154    /// - the returned [`std::ffi::c_int`] from llama-cpp does not fit into a i32 (this should never happen on most systems)
155    pub fn encode(&mut self, batch: &mut LlamaBatch) -> Result<(), EncodeError> {
156        let result =
157            unsafe { llama_cpp_sys_4::llama_encode(self.context.as_ptr(), batch.llama_batch) };
158
159        match NonZeroI32::new(result) {
160            None => {
161                self.initialized_logits
162                    .clone_from(&batch.initialized_logits);
163                Ok(())
164            }
165            Some(error) => Err(EncodeError::from(error)),
166        }
167    }
168
169    /// Return Pooling type for Llama's Context
170    #[must_use]
171    pub fn pooling_type(&self) -> LlamaPoolingType {
172        let pooling_type = unsafe { llama_pooling_type(self.context.as_ptr()) };
173
174        LlamaPoolingType::from(pooling_type)
175    }
176
    /// Get the embeddings for the `i`th sequence in the current context.
    ///
    /// # Returns
    ///
    /// A slice containing the embeddings for the last decoded batch.
    /// The size corresponds to the `n_embd` parameter of the context's model.
    ///
    /// # Errors
    ///
    /// - When the current context was constructed without enabling embeddings.
    /// - If the current model had a pooling type of [`llama_cpp_sys_4::LLAMA_POOLING_TYPE_NONE`]
    /// - If the given sequence index exceeds the max sequence id.
    ///
    /// # Panics
    ///
    /// * `n_embd` does not fit into a usize
    pub fn embeddings_seq_ith(&self, i: i32) -> Result<&[f32], EmbeddingsError> {
        if !self.embeddings_enabled {
            return Err(EmbeddingsError::NotEnabled);
        }

        let n_embd =
            usize::try_from(self.model.n_embd()).expect("n_embd does not fit into a usize");

        unsafe {
            let embedding = llama_cpp_sys_4::llama_get_embeddings_seq(self.context.as_ptr(), i);

            // Technically also possible whenever `i >= max(batch.n_seq)`, but can't check that here.
            if embedding.is_null() {
                Err(EmbeddingsError::NonePoolType)
            } else {
                // SAFETY: the non-null pointer returned by llama.cpp addresses
                // `n_embd` floats owned by (and living as long as) the context.
                Ok(slice::from_raw_parts(embedding, n_embd))
            }
        }
    }

    /// Get the embeddings for the `i`th token in the current context.
    ///
    /// # Returns
    ///
    /// A slice containing the embeddings for the last decoded batch of the given token.
    /// The size corresponds to the `n_embd` parameter of the context's model.
    ///
    /// # Errors
    ///
    /// - When the current context was constructed without enabling embeddings.
    /// - When the given token didn't have logits enabled when it was passed.
    /// - If the given token index exceeds the max token id.
    ///
    /// # Panics
    ///
    /// * `n_embd` does not fit into a usize
    pub fn embeddings_ith(&self, i: i32) -> Result<&[f32], EmbeddingsError> {
        if !self.embeddings_enabled {
            return Err(EmbeddingsError::NotEnabled);
        }

        let n_embd =
            usize::try_from(self.model.n_embd()).expect("n_embd does not fit into a usize");

        unsafe {
            let embedding = llama_cpp_sys_4::llama_get_embeddings_ith(self.context.as_ptr(), i);
            // Technically also possible whenever `i >= batch.n_tokens`, but no good way of checking `n_tokens` here.
            if embedding.is_null() {
                Err(EmbeddingsError::LogitsNotEnabled)
            } else {
                // SAFETY: the non-null pointer returned by llama.cpp addresses
                // `n_embd` floats owned by (and living as long as) the context.
                Ok(slice::from_raw_parts(embedding, n_embd))
            }
        }
    }
247
    /// Get the logits for the last token in the context.
    ///
    /// # Returns
    /// An iterator over unsorted `LlamaTokenData` containing the
    /// logits for the last token in the context.
    ///
    /// # Panics
    ///
    /// - underlying logits data is null
    pub fn candidates(&self) -> impl Iterator<Item = LlamaTokenData> + '_ {
        // Pair each vocabulary index (the token id) with its logit; the
        // probability field starts at 0 and is filled in by samplers later.
        (0_i32..).zip(self.get_logits()).map(|(i, logit)| {
            let token = LlamaToken::new(i);
            LlamaTokenData::new(token, *logit, 0_f32)
        })
    }

    /// Get the token data array for the last token in the context.
    ///
    /// This is a convenience method that implements:
    /// ```ignore
    /// LlamaTokenDataArray::from_iter(ctx.candidates(), false)
    /// ```
    ///
    /// # Panics
    ///
    /// - underlying logits data is null
    #[must_use]
    pub fn token_data_array(&self) -> LlamaTokenDataArray {
        LlamaTokenDataArray::from_iter(self.candidates(), false)
    }
278
    /// Token logits obtained from the last call to `decode()`.
    /// The logits for which `batch.logits[i] != 0` are stored contiguously
    /// in the order they have appeared in the batch.
    /// Rows: number of tokens for which `batch.logits[i] != 0`
    /// Cols: `n_vocab`
    ///
    /// # Returns
    ///
    /// A slice containing the logits for the last decoded token.
    /// The size corresponds to the `n_vocab` parameter of the context's model.
    ///
    /// # Panics
    ///
    /// - `n_vocab` does not fit into a usize
    /// - token data returned is null
    #[must_use]
    pub fn get_logits(&self) -> &[f32] {
        // SAFETY: `self.context` is a valid, non-null context pointer.
        let data = unsafe { llama_cpp_sys_4::llama_get_logits(self.context.as_ptr()) };
        // Guard against constructing a slice from a null pointer (UB).
        assert!(!data.is_null(), "logits data for last token is null");
        let len = usize::try_from(self.model.n_vocab()).expect("n_vocab does not fit into a usize");

        // SAFETY: `data` is non-null and points to `n_vocab` floats owned by the context.
        unsafe { slice::from_raw_parts(data, len) }
    }
302
303    /// Get the logits for the ith token in the context.
304    ///
305    /// # Panics
306    ///
307    /// - logit `i` is not initialized.
308    pub fn candidates_ith(&self, i: i32) -> impl Iterator<Item = LlamaTokenData> + '_ {
309        (0_i32..).zip(self.get_logits_ith(i)).map(|(i, logit)| {
310            let token = LlamaToken::new(i);
311            LlamaTokenData::new(token, *logit, 0_f32)
312        })
313    }
314
315    /// Get the logits for the ith token in the context.
316    ///
317    /// # Panics
318    ///
319    /// - `i` is greater than `n_ctx`
320    /// - `n_vocab` does not fit into a usize
321    /// - logit `i` is not initialized.
322    #[must_use]
323    pub fn get_logits_ith(&self, i: i32) -> &[f32] {
324        assert!(
325            self.initialized_logits.contains(&i),
326            "logit {i} is not initialized. only {:?} is",
327            self.initialized_logits
328        );
329        assert!(
330            self.n_ctx() > u32::try_from(i).expect("i does not fit into a u32"),
331            "n_ctx ({}) must be greater than i ({})",
332            self.n_ctx(),
333            i
334        );
335
336        let data = unsafe { llama_cpp_sys_4::llama_get_logits_ith(self.context.as_ptr(), i) };
337        let len = usize::try_from(self.model.n_vocab()).expect("n_vocab does not fit into a usize");
338
339        unsafe { slice::from_raw_parts(data, len) }
340    }
341
    /// Get the number of context tokens per sequence.
    #[must_use]
    pub fn n_ctx_seq(&self) -> u32 {
        // SAFETY: `self.context` is a valid, non-null context pointer.
        unsafe { llama_cpp_sys_4::llama_n_ctx_seq(self.context.as_ptr()) }
    }

    /// Get the maximum number of sequences.
    #[must_use]
    pub fn n_seq_max(&self) -> u32 {
        // SAFETY: `self.context` is a valid, non-null context pointer.
        unsafe { llama_cpp_sys_4::llama_n_seq_max(self.context.as_ptr()) }
    }

    /// Get the number of threads used for generation.
    #[must_use]
    pub fn n_threads(&self) -> i32 {
        // SAFETY: `self.context` is a valid, non-null context pointer.
        unsafe { llama_cpp_sys_4::llama_n_threads(self.context.as_ptr()) }
    }

    /// Get the number of threads used for batch processing.
    #[must_use]
    pub fn n_threads_batch(&self) -> i32 {
        // SAFETY: `self.context` is a valid, non-null context pointer.
        unsafe { llama_cpp_sys_4::llama_n_threads_batch(self.context.as_ptr()) }
    }

    /// Set the number of threads used for generation and batch processing.
    pub fn set_n_threads(&mut self, n_threads: i32, n_threads_batch: i32) {
        // SAFETY: `self.context` is a valid, non-null context pointer.
        unsafe {
            llama_cpp_sys_4::llama_set_n_threads(
                self.context.as_ptr(),
                n_threads,
                n_threads_batch,
            );
        }
    }

    /// Set whether to use causal attention.
    ///
    /// If set to `false`, the model will use non-causal attention, which is
    /// needed for embedding models.
    pub fn set_causal_attn(&mut self, causal_attn: bool) {
        // SAFETY: `self.context` is a valid, non-null context pointer.
        unsafe {
            llama_cpp_sys_4::llama_set_causal_attn(self.context.as_ptr(), causal_attn);
        }
    }

    /// Set whether to compute embeddings.
    ///
    /// This allows toggling embedding mode at runtime (as opposed to only at
    /// context creation time).
    pub fn set_embeddings(&mut self, embeddings: bool) {
        // Keep the Rust-side flag in sync so the `embeddings_*` accessors
        // agree with the C-side setting.
        self.embeddings_enabled = embeddings;
        // SAFETY: `self.context` is a valid, non-null context pointer.
        unsafe {
            llama_cpp_sys_4::llama_set_embeddings(self.context.as_ptr(), embeddings);
        }
    }

    /// Mark the next computation as a warmup run.
    ///
    /// Warmup runs are useful for GPU backends to compile kernels before
    /// actual inference begins.
    pub fn set_warmup(&mut self, warmup: bool) {
        // SAFETY: `self.context` is a valid, non-null context pointer.
        unsafe {
            llama_cpp_sys_4::llama_set_warmup(self.context.as_ptr(), warmup);
        }
    }

    /// Wait for all pending async computations to finish.
    pub fn synchronize(&mut self) {
        // SAFETY: `self.context` is a valid, non-null context pointer.
        unsafe {
            llama_cpp_sys_4::llama_synchronize(self.context.as_ptr());
        }
    }
414
    /// Get all embeddings for the current context.
    ///
    /// Returns a slice of all embeddings from the last decoded batch.
    /// For pooled embeddings use [`embeddings_seq_ith`](Self::embeddings_seq_ith) instead.
    ///
    /// # Errors
    ///
    /// - When the current context was constructed without enabling embeddings.
    /// - If the embeddings pointer is null.
    ///
    /// # Panics
    ///
    /// * `n_embd` does not fit into a usize
    pub fn get_embeddings(&self) -> Result<&[f32], EmbeddingsError> {
        if !self.embeddings_enabled {
            return Err(EmbeddingsError::NotEnabled);
        }

        let n_embd =
            usize::try_from(self.model.n_embd()).expect("n_embd does not fit into a usize");

        unsafe {
            let embedding = llama_cpp_sys_4::llama_get_embeddings(self.context.as_ptr());
            if embedding.is_null() {
                Err(EmbeddingsError::NonePoolType)
            } else {
                // SAFETY: the non-null pointer returned by llama.cpp addresses
                // at least `n_embd` floats owned by the context.
                // NOTE(review): for multi-token batches the buffer may hold
                // more than `n_embd` values; only the first row is exposed here.
                Ok(slice::from_raw_parts(embedding, n_embd))
            }
        }
    }
445
446    /// Reset the timings for the context.
447    pub fn reset_timings(&mut self) {
448        unsafe { llama_cpp_sys_4::ggml_time_init() }
449    }
450
    /// Returns the timings for the context.
    pub fn timings(&mut self) -> PerfContextData {
        // SAFETY: `self.context` is a valid, non-null context pointer.
        let perf_context_data =
            unsafe { llama_cpp_sys_4::llama_perf_context(self.context.as_ptr()) };
        // Wrap the raw C struct in the safe accessor type.
        PerfContextData { perf_context_data }
    }

    /// Reset the performance counters for the context.
    pub fn perf_context_reset(&mut self) {
        // SAFETY: `self.context` is a valid, non-null context pointer.
        unsafe { llama_cpp_sys_4::llama_perf_context_reset(self.context.as_ptr()) }
    }
462
    /// Check if the KV cache memory supports shifting.
    #[must_use]
    pub fn memory_can_shift(&self) -> bool {
        // SAFETY: `self.context` is valid; the memory handle returned by
        // `llama_get_memory` is owned by the context and used immediately.
        unsafe {
            let mem = llama_cpp_sys_4::llama_get_memory(self.context.as_ptr());
            llama_cpp_sys_4::llama_memory_can_shift(mem)
        }
    }

    /// Get the minimum position in a sequence's KV cache.
    #[must_use]
    pub fn memory_seq_pos_min(&self, seq_id: i32) -> i32 {
        // SAFETY: `self.context` is valid; the memory handle returned by
        // `llama_get_memory` is owned by the context and used immediately.
        unsafe {
            let mem = llama_cpp_sys_4::llama_get_memory(self.context.as_ptr());
            llama_cpp_sys_4::llama_memory_seq_pos_min(mem, seq_id)
        }
    }

    /// Print a breakdown of the memory usage.
    pub fn memory_breakdown_print(&self) {
        // SAFETY: `self.context` is a valid, non-null context pointer.
        unsafe {
            llama_cpp_sys_4::llama_memory_breakdown_print(self.context.as_ptr());
        }
    }
487
    /// Get the size of the full context state in bytes.
    ///
    /// This is the size needed for [`state_get_data`](Self::state_get_data) and
    /// [`state_set_data`](Self::state_set_data).
    #[must_use]
    pub fn state_get_size(&mut self) -> usize {
        // SAFETY: `self.context` is a valid, non-null context pointer.
        unsafe { llama_cpp_sys_4::llama_state_get_size(self.context.as_ptr()) }
    }

    /// Copy the full context state into a byte buffer.
    ///
    /// The buffer must be at least [`state_get_size`](Self::state_get_size) bytes.
    ///
    /// Returns the number of bytes written.
    pub fn state_get_data(&mut self, dst: &mut [u8]) -> usize {
        // SAFETY: `dst` is a valid writable buffer and its length is passed so
        // llama.cpp will not write past it.
        unsafe {
            llama_cpp_sys_4::llama_state_get_data(
                self.context.as_ptr(),
                dst.as_mut_ptr(),
                dst.len(),
            )
        }
    }

    /// Restore the full context state from a byte buffer.
    ///
    /// Returns the number of bytes read.
    pub fn state_set_data(&mut self, src: &[u8]) -> usize {
        // SAFETY: `src` is a valid readable buffer of `src.len()` bytes.
        unsafe {
            llama_cpp_sys_4::llama_state_set_data(
                self.context.as_ptr(),
                src.as_ptr(),
                src.len(),
            )
        }
    }
524
    /// Save the context state to a file along with the given tokens.
    ///
    /// Returns `true` on success.
    ///
    /// # Panics
    ///
    /// Panics if the path is not valid UTF-8 or contains null bytes.
    pub fn state_save_file(
        &mut self,
        path: impl AsRef<std::path::Path>,
        tokens: &[LlamaToken],
    ) -> bool {
        let path_str = path.as_ref().to_str().expect("path is not valid UTF-8");
        let c_path = std::ffi::CString::new(path_str).expect("path contains null bytes");
        // SAFETY: `c_path` is a valid NUL-terminated string and `tokens` is a
        // valid slice with its length passed alongside.
        // NOTE(review): the `.cast()` assumes `LlamaToken` is layout-compatible
        // with the C `llama_token` type — confirm it is `#[repr(transparent)]`.
        unsafe {
            llama_cpp_sys_4::llama_state_save_file(
                self.context.as_ptr(),
                c_path.as_ptr(),
                tokens.as_ptr().cast(),
                tokens.len(),
            )
        }
    }
548
    /// Load a context state from a file.
    ///
    /// Returns `true` on success and fills `tokens_out` with the saved tokens.
    /// On failure `tokens_out` is left resized to `n_token_capacity`, filled
    /// with placeholder `LlamaToken(0)` values.
    ///
    /// # Panics
    ///
    /// Panics if the path is not valid UTF-8 or contains null bytes.
    pub fn state_load_file(
        &mut self,
        path: impl AsRef<std::path::Path>,
        tokens_out: &mut Vec<LlamaToken>,
        n_token_capacity: usize,
    ) -> bool {
        // Pre-size the output buffer so llama.cpp can write directly into it.
        tokens_out.resize(n_token_capacity, LlamaToken(0));
        let mut n_token_count: usize = 0;
        let path_str = path.as_ref().to_str().expect("path is not valid UTF-8");
        let c_path = std::ffi::CString::new(path_str).expect("path contains null bytes");
        // SAFETY: the buffer holds `n_token_capacity` tokens and
        // `n_token_count` is a valid out-pointer for the duration of the call.
        let ok = unsafe {
            llama_cpp_sys_4::llama_state_load_file(
                self.context.as_ptr(),
                c_path.as_ptr(),
                tokens_out.as_mut_ptr().cast(),
                n_token_capacity,
                std::ptr::addr_of_mut!(n_token_count),
            )
        };
        if ok {
            // Drop the unused tail so the Vec reflects the real token count.
            tokens_out.truncate(n_token_count);
        }
        ok
    }
580
    /// Get the size of a single sequence's state in bytes.
    #[must_use]
    pub fn state_seq_get_size(&mut self, seq_id: i32) -> usize {
        // SAFETY: `self.context` is a valid, non-null context pointer.
        unsafe { llama_cpp_sys_4::llama_state_seq_get_size(self.context.as_ptr(), seq_id) }
    }

    /// Copy a single sequence's state into a byte buffer.
    ///
    /// Returns the number of bytes written.
    pub fn state_seq_get_data(&mut self, dst: &mut [u8], seq_id: i32) -> usize {
        // SAFETY: `dst` is a valid writable buffer and its length is passed so
        // llama.cpp will not write past it.
        unsafe {
            llama_cpp_sys_4::llama_state_seq_get_data(
                self.context.as_ptr(),
                dst.as_mut_ptr(),
                dst.len(),
                seq_id,
            )
        }
    }

    /// Restore a single sequence's state from a byte buffer.
    ///
    /// Returns the number of bytes read.
    pub fn state_seq_set_data(&mut self, src: &[u8], dest_seq_id: i32) -> usize {
        // SAFETY: `src` is a valid readable buffer of `src.len()` bytes.
        unsafe {
            llama_cpp_sys_4::llama_state_seq_set_data(
                self.context.as_ptr(),
                src.as_ptr(),
                src.len(),
                dest_seq_id,
            )
        }
    }
614
    /// Save a single sequence's state to a file.
    ///
    /// Returns the number of bytes written (0 on failure).
    ///
    /// # Panics
    ///
    /// Panics if the path is not valid UTF-8 or contains null bytes.
    pub fn state_seq_save_file(
        &mut self,
        path: impl AsRef<std::path::Path>,
        seq_id: i32,
        tokens: &[LlamaToken],
    ) -> usize {
        let path_str = path.as_ref().to_str().expect("path is not valid UTF-8");
        let c_path = std::ffi::CString::new(path_str).expect("path contains null bytes");
        // SAFETY: `c_path` is a valid NUL-terminated string and `tokens` is a
        // valid slice with its length passed alongside.
        unsafe {
            llama_cpp_sys_4::llama_state_seq_save_file(
                self.context.as_ptr(),
                c_path.as_ptr(),
                seq_id,
                tokens.as_ptr().cast(),
                tokens.len(),
            )
        }
    }
640
    /// Load a single sequence's state from a file.
    ///
    /// Returns the number of bytes read (0 on failure). On failure
    /// `tokens_out` is left resized to `n_token_capacity`, filled with
    /// placeholder `LlamaToken(0)` values.
    ///
    /// # Panics
    ///
    /// Panics if the path is not valid UTF-8 or contains null bytes.
    pub fn state_seq_load_file(
        &mut self,
        path: impl AsRef<std::path::Path>,
        dest_seq_id: i32,
        tokens_out: &mut Vec<LlamaToken>,
        n_token_capacity: usize,
    ) -> usize {
        // Pre-size the output buffer so llama.cpp can write directly into it.
        tokens_out.resize(n_token_capacity, LlamaToken(0));
        let mut n_token_count: usize = 0;
        let path_str = path.as_ref().to_str().expect("path is not valid UTF-8");
        let c_path = std::ffi::CString::new(path_str).expect("path contains null bytes");
        // SAFETY: the buffer holds `n_token_capacity` tokens and
        // `n_token_count` is a valid out-pointer for the duration of the call.
        let ret = unsafe {
            llama_cpp_sys_4::llama_state_seq_load_file(
                self.context.as_ptr(),
                c_path.as_ptr(),
                dest_seq_id,
                tokens_out.as_mut_ptr().cast(),
                n_token_capacity,
                std::ptr::addr_of_mut!(n_token_count),
            )
        };
        if ret > 0 {
            // Drop the unused tail so the Vec reflects the real token count.
            tokens_out.truncate(n_token_count);
        }
        ret
    }
674
675    /// Set a control vector on the context.
676    ///
677    /// # Parameters
678    ///
679    /// - `data`: The control vector data (embedding values). Pass an empty slice to clear.
680    /// - `n_embd`: The embedding dimension.
681    /// - `il_start`: The starting layer index (inclusive).
682    /// - `il_end`: The ending layer index (exclusive).
683    ///
684    /// # Errors
685    ///
686    /// Returns `Err` with the error code if the operation fails.
687    pub fn set_adapter_cvec(
688        &mut self,
689        data: &[f32],
690        n_embd: i32,
691        il_start: i32,
692        il_end: i32,
693    ) -> Result<(), i32> {
694        let ret = unsafe {
695            llama_cpp_sys_4::llama_set_adapter_cvec(
696                self.context.as_ptr(),
697                data.as_ptr(),
698                data.len(),
699                n_embd,
700                il_start,
701                il_end,
702            )
703        };
704        if ret != 0 {
705            Err(ret)
706        } else {
707            Ok(())
708        }
709    }
710
    /// Get sampled token debug info for the `i`th position.
    ///
    /// Returns the sampled token at position `i` from the last decode call.
    #[must_use]
    pub fn get_sampled_token_ith(&self, i: i32) -> LlamaToken {
        // SAFETY: `self.context` is a valid, non-null context pointer.
        let token =
            unsafe { llama_cpp_sys_4::llama_get_sampled_token_ith(self.context.as_ptr(), i) };
        LlamaToken(token)
    }

    /// Get sampled candidate tokens for the `i`th position.
    ///
    /// Returns a slice of candidate tokens from the last decode call, or an
    /// empty slice when there are no candidates for `i`.
    #[must_use]
    pub fn get_sampled_candidates_ith(&self, i: i32) -> &[LlamaToken] {
        // SAFETY: `self.context` is a valid, non-null context pointer.
        let count = unsafe {
            llama_cpp_sys_4::llama_get_sampled_candidates_count_ith(self.context.as_ptr(), i)
        } as usize;
        if count == 0 {
            return &[];
        }
        // SAFETY: `self.context` is a valid, non-null context pointer.
        let ptr = unsafe {
            llama_cpp_sys_4::llama_get_sampled_candidates_ith(self.context.as_ptr(), i)
        };
        if ptr.is_null() {
            return &[];
        }
        // SAFETY: `ptr` is non-null and llama.cpp reported `count` entries.
        // NOTE(review): the cast assumes `LlamaToken` is layout-compatible
        // with the C token type — confirm it is `#[repr(transparent)]`.
        unsafe { slice::from_raw_parts(ptr.cast::<LlamaToken>(), count) }
    }

    /// Get the number of sampled logits for the `i`th position.
    #[must_use]
    pub fn get_sampled_logits_count_ith(&self, i: i32) -> u32 {
        // SAFETY: `self.context` is a valid, non-null context pointer.
        unsafe {
            llama_cpp_sys_4::llama_get_sampled_logits_count_ith(self.context.as_ptr(), i)
        }
    }

    /// Get sampled logits for the `i`th position.
    ///
    /// Returns a slice of logit values from the last decode call, or an empty
    /// slice when there are none for `i`.
    #[must_use]
    pub fn get_sampled_logits_ith(&self, i: i32) -> &[f32] {
        let count = self.get_sampled_logits_count_ith(i) as usize;
        if count == 0 {
            return &[];
        }
        // SAFETY: `self.context` is a valid, non-null context pointer.
        let ptr = unsafe {
            llama_cpp_sys_4::llama_get_sampled_logits_ith(self.context.as_ptr(), i)
        };
        if ptr.is_null() {
            return &[];
        }
        // SAFETY: `ptr` is non-null and llama.cpp reported `count` entries.
        unsafe { slice::from_raw_parts(ptr, count) }
    }

    /// Get the number of sampled probabilities for the `i`th position.
    #[must_use]
    pub fn get_sampled_probs_count_ith(&self, i: i32) -> u32 {
        // SAFETY: `self.context` is a valid, non-null context pointer.
        unsafe {
            llama_cpp_sys_4::llama_get_sampled_probs_count_ith(self.context.as_ptr(), i)
        }
    }

    /// Get sampled probabilities for the `i`th position.
    ///
    /// Returns a slice of probability values from the last decode call, or an
    /// empty slice when there are none for `i`.
    #[must_use]
    pub fn get_sampled_probs_ith(&self, i: i32) -> &[f32] {
        let count = self.get_sampled_probs_count_ith(i) as usize;
        if count == 0 {
            return &[];
        }
        // SAFETY: `self.context` is a valid, non-null context pointer.
        let ptr = unsafe {
            llama_cpp_sys_4::llama_get_sampled_probs_ith(self.context.as_ptr(), i)
        };
        if ptr.is_null() {
            return &[];
        }
        // SAFETY: `ptr` is non-null and llama.cpp reported `count` entries.
        unsafe { slice::from_raw_parts(ptr, count) }
    }
792
    /// Get the size of a single sequence's state with flags.
    #[must_use]
    pub fn state_seq_get_size_ext(&mut self, seq_id: i32, flags: u32) -> usize {
        // SAFETY: `self.context` is a valid, non-null context pointer.
        unsafe {
            llama_cpp_sys_4::llama_state_seq_get_size_ext(self.context.as_ptr(), seq_id, flags)
        }
    }

    /// Copy a single sequence's state into a byte buffer with flags.
    ///
    /// Returns the number of bytes written.
    pub fn state_seq_get_data_ext(&mut self, dst: &mut [u8], seq_id: i32, flags: u32) -> usize {
        // SAFETY: `dst` is a valid writable buffer and its length is passed so
        // llama.cpp will not write past it.
        unsafe {
            llama_cpp_sys_4::llama_state_seq_get_data_ext(
                self.context.as_ptr(),
                dst.as_mut_ptr(),
                dst.len(),
                seq_id,
                flags,
            )
        }
    }

    /// Restore a single sequence's state from a byte buffer with flags.
    ///
    /// Returns the number of bytes read.
    pub fn state_seq_set_data_ext(
        &mut self,
        src: &[u8],
        dest_seq_id: i32,
        flags: u32,
    ) -> usize {
        // SAFETY: `src` is a valid readable buffer of `src.len()` bytes.
        unsafe {
            llama_cpp_sys_4::llama_state_seq_set_data_ext(
                self.context.as_ptr(),
                src.as_ptr(),
                src.len(),
                dest_seq_id,
                flags,
            )
        }
    }
835
    /// Set an abort callback for the context.
    ///
    /// The callback is called periodically during computation. If it returns `true`,
    /// the computation is aborted.
    ///
    /// # Safety
    ///
    /// The callback data must remain valid for the lifetime of the context or until
    /// the callback is replaced.
    pub unsafe fn set_abort_callback(
        &mut self,
        callback: llama_cpp_sys_4::ggml_abort_callback,
        data: *mut std::ffi::c_void,
    ) {
        // The caller upholds the validity of `data` (see # Safety above).
        llama_cpp_sys_4::llama_set_abort_callback(self.context.as_ptr(), callback, data);
    }

    /// Attach a thread pool to the context.
    ///
    /// # Safety
    ///
    /// The thread pools must remain valid for the lifetime of the context or until
    /// they are detached.
    pub unsafe fn attach_threadpool(
        &mut self,
        threadpool: llama_cpp_sys_4::ggml_threadpool_t,
        threadpool_batch: llama_cpp_sys_4::ggml_threadpool_t,
    ) {
        // The caller upholds the validity of both pool handles (see # Safety above).
        llama_cpp_sys_4::llama_attach_threadpool(
            self.context.as_ptr(),
            threadpool,
            threadpool_batch,
        );
    }

    /// Detach the thread pool from the context.
    pub fn detach_threadpool(&mut self) {
        // SAFETY: `self.context` is a valid, non-null context pointer.
        unsafe {
            llama_cpp_sys_4::llama_detach_threadpool(self.context.as_ptr());
        }
    }
877
    /// Set a sampler for a specific sequence.
    ///
    /// Returns `true` on success.
    pub fn set_sampler(
        &mut self,
        seq_id: i32,
        sampler: &mut crate::sampling::LlamaSampler,
    ) -> bool {
        // SAFETY: both raw pointers are non-null (`NonNull`) for the call.
        // NOTE(review): llama.cpp may retain the sampler pointer after this
        // call — confirm the required sampler lifetime against the C API.
        unsafe {
            llama_cpp_sys_4::llama_set_sampler(
                self.context.as_ptr(),
                seq_id,
                sampler.sampler.as_ptr(),
            )
        }
    }

    /// Get the raw model pointer from this context.
    ///
    /// This is mainly useful for FFI interop. In normal usage, access
    /// the model via the `model` field instead.
    #[must_use]
    pub fn get_model_ptr(&self) -> *const llama_cpp_sys_4::llama_model {
        // SAFETY: `self.context` is a valid, non-null context pointer.
        unsafe { llama_cpp_sys_4::llama_get_model(self.context.as_ptr()) }
    }
903
    /// Sets a lora adapter.
    ///
    /// # Errors
    ///
    /// See [`LlamaLoraAdapterSetError`] for more information.
    pub fn lora_adapter_set(
        &self,
        adapter: &mut LlamaLoraAdapter,
        scale: f32,
    ) -> Result<(), LlamaLoraAdapterSetError> {
        let err_code = unsafe {
            // llama_set_adapter_lora / llama_rm_adapter_lora were replaced by llama_set_adapters_lora
            // which takes a full list of adapters + scales at once (b8249+)
            let mut adapter_ptr = adapter.lora_adapter.as_ptr();
            let mut scale_val = scale;
            // SAFETY: single-element "arrays" live on the stack for the call;
            // length 1 matches both pointers.
            llama_cpp_sys_4::llama_set_adapters_lora(
                self.context.as_ptr(),
                &raw mut adapter_ptr,
                1,
                &raw mut scale_val,
            )
        };
        if err_code != 0 {
            return Err(LlamaLoraAdapterSetError::ErrorResult(err_code));
        }

        tracing::debug!("Set lora adapter");
        Ok(())
    }

    /// Remove all lora adapters from the context.
    ///
    /// Note: as of llama.cpp b8249 the per-adapter remove API was replaced by
    /// `llama_set_adapters_lora` which operates on the full adapter list at once.
    /// Calling this function clears **all** adapters currently set on the context.
    ///
    /// # Errors
    ///
    /// See [`LlamaLoraAdapterRemoveError`] for more information.
    pub fn lora_adapter_remove(
        &self,
        _adapter: &mut LlamaLoraAdapter,
    ) -> Result<(), LlamaLoraAdapterRemoveError> {
        // SAFETY: null pointers with length 0 are the documented way to clear
        // the adapter list in the replacement API.
        let err_code = unsafe {
            llama_cpp_sys_4::llama_set_adapters_lora(
                self.context.as_ptr(),
                std::ptr::null_mut(),
                0,
                std::ptr::null_mut(),
            )
        };
        if err_code != 0 {
            return Err(LlamaLoraAdapterRemoveError::ErrorResult(err_code));
        }

        tracing::debug!("Remove lora adapter");
        Ok(())
    }
962}
963
impl Drop for LlamaContext<'_> {
    // Frees the underlying C `llama_context`.
    fn drop(&mut self) {
        // SAFETY: `self.context` is non-null, was allocated by llama.cpp, and
        // is never used again after this point.
        unsafe { llama_cpp_sys_4::llama_free(self.context.as_ptr()) }
    }
}