// llama_cpp_4/context.rs
//! Safe wrapper around `llama_context`.

3use std::fmt::{Debug, Formatter};
4use std::num::NonZeroI32;
5use std::ptr::NonNull;
6use std::slice;
7
8use llama_cpp_sys_4::llama_pooling_type;
9use params::LlamaPoolingType;
10use perf::PerfContextData;
11
12use crate::llama_batch::LlamaBatch;
13use crate::model::{LlamaLoraAdapter, LlamaModel};
14use crate::token::data::LlamaTokenData;
15use crate::token::data_array::LlamaTokenDataArray;
16use crate::token::LlamaToken;
17use crate::{
18    DecodeError, EmbeddingsError, EncodeError, LlamaLoraAdapterRemoveError,
19    LlamaLoraAdapterSetError,
20};
21
22pub mod kv_cache;
23pub mod params;
24pub mod perf;
25pub mod session;
26
/// A safe wrapper around the `llama_context` C++ context.
///
/// Encapsulates the raw context pointer in a [`NonNull`], so it can never be
/// null, together with a borrow of the [`LlamaModel`] the context was created
/// from — the borrow keeps the model alive for at least as long as this
/// context. The raw context is freed exactly once in `Drop`.
///
/// # Fields
///
/// - `context`: non-null pointer to the raw C++ `llama_context`; all FFI calls
///   go through this pointer and it is released in `Drop`.
/// - `model`: reference to the `LlamaModel` this context operates on (used for
///   `n_embd`/`n_vocab` when sizing returned slices).
/// - `initialized_logits`: batch positions whose logits were requested in the
///   last successful `decode`/`encode`; `get_logits_ith` validates its index
///   against this list.
/// - `embeddings_enabled`: whether the context was created with embeddings
///   enabled; checked by the `embeddings_*` accessors before touching the FFI.
#[allow(clippy::module_name_repetitions)]
pub struct LlamaContext<'a> {
    pub(crate) context: NonNull<llama_cpp_sys_4::llama_context>,
    /// a reference to the contexts model.
    pub model: &'a LlamaModel,
    initialized_logits: Vec<i32>,
    embeddings_enabled: bool,
}
55
56impl Debug for LlamaContext<'_> {
57    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
58        f.debug_struct("LlamaContext")
59            .field("context", &self.context)
60            .finish()
61    }
62}
63
64impl<'model> LlamaContext<'model> {
65    /// Creates a new instance of `LlamaContext` with the provided model, context, and embeddings flag.
66    ///
67    /// This function initializes a new `LlamaContext` object, which is used to interact with the
68    /// `LlamaModel`. The context is created from a pointer to a C++ context and the embeddings flag
69    /// determines whether embeddings are enabled in the context.
70    ///
71    /// # Parameters
72    ///
73    /// - `llama_model`: A reference to an existing `LlamaModel` that will be used with the new context.
74    /// - `llama_context`: A non-null pointer to an existing `llama_cpp_sys_4::llama_context` representing
75    ///   the context created in previous steps. This context is necessary for interacting with the model.
76    /// - `embeddings_enabled`: A boolean flag indicating whether embeddings are enabled in this context.
77    ///
78    /// # Returns
79    ///
80    /// This function returns a new instance of `LlamaContext` initialized with the given parameters:
81    /// - The model reference (`llama_model`) is stored in the context.
82    /// - The raw context pointer (`llama_context`) is wrapped in a `NonNull` to ensure safety.
83    /// - The `embeddings_enabled` flag is used to determine if embeddings are enabled for the context.
84    ///
85    /// # Example
86    /// ```
87    /// let llama_model = LlamaModel::load("path/to/model").unwrap();
88    /// let context_ptr = NonNull::new(some_llama_context_ptr).unwrap();
89    /// let context = LlamaContext::new(&llama_model, context_ptr, true);
90    /// // Now you can use the context
91    /// ```
92    pub(crate) fn new(
93        llama_model: &'model LlamaModel,
94        llama_context: NonNull<llama_cpp_sys_4::llama_context>,
95        embeddings_enabled: bool,
96    ) -> Self {
97        Self {
98            context: llama_context,
99            model: llama_model,
100            initialized_logits: Vec::new(),
101            embeddings_enabled,
102        }
103    }
104
    /// Gets the max number of logical tokens that can be submitted to decode. Must be greater than or equal to `n_ubatch`.
    #[must_use]
    pub fn n_batch(&self) -> u32 {
        // SAFETY: `self.context` is a valid, non-null context for the lifetime of `self`.
        unsafe { llama_cpp_sys_4::llama_n_batch(self.context.as_ptr()) }
    }
110
    /// Gets the max number of physical tokens (hardware level) to decode in batch. Must be less than or equal to `n_batch`.
    #[must_use]
    pub fn n_ubatch(&self) -> u32 {
        // SAFETY: `self.context` is a valid, non-null context for the lifetime of `self`.
        unsafe { llama_cpp_sys_4::llama_n_ubatch(self.context.as_ptr()) }
    }
116
    /// Gets the size of the context (max number of tokens the context can hold).
    #[must_use]
    pub fn n_ctx(&self) -> u32 {
        // SAFETY: `self.context` is a valid, non-null context for the lifetime of `self`.
        unsafe { llama_cpp_sys_4::llama_n_ctx(self.context.as_ptr()) }
    }
122
123    /// Decodes the batch.
124    ///
125    /// # Errors
126    ///
127    /// - `DecodeError` if the decoding failed.
128    ///
129    /// # Panics
130    ///
131    /// - the returned [`std::ffi::c_int`] from llama-cpp does not fit into a i32 (this should never happen on most systems)
132    pub fn decode(&mut self, batch: &mut LlamaBatch) -> Result<(), DecodeError> {
133        let result =
134            unsafe { llama_cpp_sys_4::llama_decode(self.context.as_ptr(), batch.llama_batch) };
135
136        match NonZeroI32::new(result) {
137            None => {
138                self.initialized_logits
139                    .clone_from(&batch.initialized_logits);
140                Ok(())
141            }
142            Some(error) => Err(DecodeError::from(error)),
143        }
144    }
145
146    /// Encodes the batch.
147    ///
148    /// # Errors
149    ///
150    /// - `EncodeError` if the decoding failed.
151    ///
152    /// # Panics
153    ///
154    /// - the returned [`std::ffi::c_int`] from llama-cpp does not fit into a i32 (this should never happen on most systems)
155    pub fn encode(&mut self, batch: &mut LlamaBatch) -> Result<(), EncodeError> {
156        let result =
157            unsafe { llama_cpp_sys_4::llama_encode(self.context.as_ptr(), batch.llama_batch) };
158
159        match NonZeroI32::new(result) {
160            None => {
161                self.initialized_logits
162                    .clone_from(&batch.initialized_logits);
163                Ok(())
164            }
165            Some(error) => Err(EncodeError::from(error)),
166        }
167    }
168
169    /// Return Pooling type for Llama's Context
170    #[must_use]
171    pub fn pooling_type(&self) -> LlamaPoolingType {
172        let pooling_type = unsafe { llama_pooling_type(self.context.as_ptr()) };
173
174        LlamaPoolingType::from(pooling_type)
175    }
176
177    /// Get the embeddings for the `i`th sequence in the current context.
178    ///
179    /// # Returns
180    ///
181    /// A slice containing the embeddings for the last decoded batch.
182    /// The size corresponds to the `n_embd` parameter of the context's model.
183    ///
184    /// # Errors
185    ///
186    /// - When the current context was constructed without enabling embeddings.
187    /// - If the current model had a pooling type of [`llama_cpp_sys_4::LLAMA_POOLING_TYPE_NONE`]
188    /// - If the given sequence index exceeds the max sequence id.
189    ///
190    /// # Panics
191    ///
192    /// * `n_embd` does not fit into a usize
193    pub fn embeddings_seq_ith(&self, i: i32) -> Result<&[f32], EmbeddingsError> {
194        if !self.embeddings_enabled {
195            return Err(EmbeddingsError::NotEnabled);
196        }
197
198        let n_embd =
199            usize::try_from(self.model.n_embd()).expect("n_embd does not fit into a usize");
200
201        unsafe {
202            let embedding = llama_cpp_sys_4::llama_get_embeddings_seq(self.context.as_ptr(), i);
203
204            // Technically also possible whenever `i >= max(batch.n_seq)`, but can't check that here.
205            if embedding.is_null() {
206                Err(EmbeddingsError::NonePoolType)
207            } else {
208                Ok(slice::from_raw_parts(embedding, n_embd))
209            }
210        }
211    }
212
213    /// Get the embeddings for the `i`th token in the current context.
214    ///
215    /// # Returns
216    ///
217    /// A slice containing the embeddings for the last decoded batch of the given token.
218    /// The size corresponds to the `n_embd` parameter of the context's model.
219    ///
220    /// # Errors
221    ///
222    /// - When the current context was constructed without enabling embeddings.
223    /// - When the given token didn't have logits enabled when it was passed.
224    /// - If the given token index exceeds the max token id.
225    ///
226    /// # Panics
227    ///
228    /// * `n_embd` does not fit into a usize
229    pub fn embeddings_ith(&self, i: i32) -> Result<&[f32], EmbeddingsError> {
230        if !self.embeddings_enabled {
231            return Err(EmbeddingsError::NotEnabled);
232        }
233
234        let n_embd =
235            usize::try_from(self.model.n_embd()).expect("n_embd does not fit into a usize");
236
237        unsafe {
238            let embedding = llama_cpp_sys_4::llama_get_embeddings_ith(self.context.as_ptr(), i);
239            // Technically also possible whenever `i >= batch.n_tokens`, but no good way of checking `n_tokens` here.
240            if embedding.is_null() {
241                Err(EmbeddingsError::LogitsNotEnabled)
242            } else {
243                Ok(slice::from_raw_parts(embedding, n_embd))
244            }
245        }
246    }
247
248    /// Get the logits for the last token in the context.
249    ///
250    /// # Returns
251    /// An iterator over unsorted `LlamaTokenData` containing the
252    /// logits for the last token in the context.
253    ///
254    /// # Panics
255    ///
256    /// - underlying logits data is null
257    pub fn candidates(&self) -> impl Iterator<Item = LlamaTokenData> + '_ {
258        (0_i32..).zip(self.get_logits()).map(|(i, logit)| {
259            let token = LlamaToken::new(i);
260            LlamaTokenData::new(token, *logit, 0_f32)
261        })
262    }
263
    /// Get the token data array for the last token in the context.
    ///
    /// This is a convenience method that implements:
    /// ```ignore
    /// LlamaTokenDataArray::from_iter(ctx.candidates(), false)
    /// ```
    ///
    /// # Panics
    ///
    /// - underlying logits data is null
    #[must_use]
    pub fn token_data_array(&self) -> LlamaTokenDataArray {
        LlamaTokenDataArray::from_iter(self.candidates(), false)
    }
278
279    /// Token logits obtained from the last call to `decode()`.
280    /// The logits for which `batch.logits[i] != 0` are stored contiguously
281    /// in the order they have appeared in the batch.
282    /// Rows: number of tokens for which `batch.logits[i] != 0`
283    /// Cols: `n_vocab`
284    ///
285    /// # Returns
286    ///
287    /// A slice containing the logits for the last decoded token.
288    /// The size corresponds to the `n_vocab` parameter of the context's model.
289    ///
290    /// # Panics
291    ///
292    /// - `n_vocab` does not fit into a usize
293    /// - token data returned is null
294    #[must_use]
295    pub fn get_logits(&self) -> &[f32] {
296        let data = unsafe { llama_cpp_sys_4::llama_get_logits(self.context.as_ptr()) };
297        assert!(!data.is_null(), "logits data for last token is null");
298        let len = usize::try_from(self.model.n_vocab()).expect("n_vocab does not fit into a usize");
299
300        unsafe { slice::from_raw_parts(data, len) }
301    }
302
303    /// Get the logits for the ith token in the context.
304    ///
305    /// # Panics
306    ///
307    /// - logit `i` is not initialized.
308    pub fn candidates_ith(&self, i: i32) -> impl Iterator<Item = LlamaTokenData> + '_ {
309        (0_i32..).zip(self.get_logits_ith(i)).map(|(i, logit)| {
310            let token = LlamaToken::new(i);
311            LlamaTokenData::new(token, *logit, 0_f32)
312        })
313    }
314
315    /// Get the logits for the ith token in the context.
316    ///
317    /// # Panics
318    ///
319    /// - `i` is greater than `n_ctx`
320    /// - `n_vocab` does not fit into a usize
321    /// - logit `i` is not initialized.
322    #[must_use]
323    pub fn get_logits_ith(&self, i: i32) -> &[f32] {
324        assert!(
325            self.initialized_logits.contains(&i),
326            "logit {i} is not initialized. only {:?} is",
327            self.initialized_logits
328        );
329        assert!(
330            self.n_ctx() > u32::try_from(i).expect("i does not fit into a u32"),
331            "n_ctx ({}) must be greater than i ({})",
332            self.n_ctx(),
333            i
334        );
335
336        let data = unsafe { llama_cpp_sys_4::llama_get_logits_ith(self.context.as_ptr(), i) };
337        let len = usize::try_from(self.model.n_vocab()).expect("n_vocab does not fit into a usize");
338
339        unsafe { slice::from_raw_parts(data, len) }
340    }
341
342    /// Reset the timings for the context.
343    pub fn reset_timings(&mut self) {
344        unsafe { llama_cpp_sys_4::ggml_time_init() }
345    }
346
347    /// Returns the timings for the context.
348    pub fn timings(&mut self) -> PerfContextData {
349        let perf_context_data =
350            unsafe { llama_cpp_sys_4::llama_perf_context(self.context.as_ptr()) };
351        PerfContextData { perf_context_data }
352    }
353
354    /// Sets a lora adapter.
355    ///
356    /// # Errors
357    ///
358    /// See [`LlamaLoraAdapterSetError`] for more information.
359    pub fn lora_adapter_set(
360        &self,
361        adapter: &mut LlamaLoraAdapter,
362        scale: f32,
363    ) -> Result<(), LlamaLoraAdapterSetError> {
364        let err_code = unsafe {
365            // llama_set_adapter_lora / llama_rm_adapter_lora were replaced by llama_set_adapters_lora
366            // which takes a full list of adapters + scales at once (b8249+)
367            let mut adapter_ptr = adapter.lora_adapter.as_ptr();
368            let mut scale_val = scale;
369            llama_cpp_sys_4::llama_set_adapters_lora(
370                self.context.as_ptr(),
371                &raw mut adapter_ptr,
372                1,
373                &raw mut scale_val,
374            )
375        };
376        if err_code != 0 {
377            return Err(LlamaLoraAdapterSetError::ErrorResult(err_code));
378        }
379
380        tracing::debug!("Set lora adapter");
381        Ok(())
382    }
383
384    /// Remove all lora adapters from the context.
385    ///
386    /// Note: as of llama.cpp b8249 the per-adapter remove API was replaced by
387    /// `llama_set_adapters_lora` which operates on the full adapter list at once.
388    /// Calling this function clears **all** adapters currently set on the context.
389    ///
390    /// # Errors
391    ///
392    /// See [`LlamaLoraAdapterRemoveError`] for more information.
393    pub fn lora_adapter_remove(
394        &self,
395        _adapter: &mut LlamaLoraAdapter,
396    ) -> Result<(), LlamaLoraAdapterRemoveError> {
397        let err_code = unsafe {
398            llama_cpp_sys_4::llama_set_adapters_lora(
399                self.context.as_ptr(),
400                std::ptr::null_mut(),
401                0,
402                std::ptr::null_mut(),
403            )
404        };
405        if err_code != 0 {
406            return Err(LlamaLoraAdapterRemoveError::ErrorResult(err_code));
407        }
408
409        tracing::debug!("Remove lora adapter");
410        Ok(())
411    }
412}
413
impl Drop for LlamaContext<'_> {
    /// Frees the underlying raw `llama_context` when the wrapper is dropped.
    fn drop(&mut self) {
        // SAFETY: `self.context` is owned exclusively by this wrapper and has
        // not been freed elsewhere; `llama_free` releases it exactly once here.
        unsafe { llama_cpp_sys_4::llama_free(self.context.as_ptr()) }
    }
}