Skip to main content

llama_cpp_bindings/
context.rs

1//! Safe wrapper around `llama_context`.
2
3use std::fmt::{Debug, Formatter};
4use std::num::NonZeroI32;
5use std::ptr::NonNull;
6use std::slice;
7
8use crate::llama_batch::LlamaBatch;
9use crate::model::{LlamaLoraAdapter, LlamaModel};
10use crate::timing::LlamaTimings;
11use crate::token::LlamaToken;
12use crate::token::data::LlamaTokenData;
13use crate::token::data_array::LlamaTokenDataArray;
14use crate::{
15    DecodeError, EmbeddingsError, EncodeError, LlamaLoraAdapterRemoveError,
16    LlamaLoraAdapterSetError,
17};
18
19pub mod kv_cache;
20pub mod params;
21pub mod session;
22
23/// Safe wrapper around `llama_context`.
24pub struct LlamaContext<'model> {
25    /// Raw pointer to the underlying `llama_context`.
26    pub context: NonNull<llama_cpp_bindings_sys::llama_context>,
27    /// a reference to the contexts model.
28    pub model: &'model LlamaModel,
29    initialized_logits: Vec<i32>,
30    embeddings_enabled: bool,
31}
32
33impl Debug for LlamaContext<'_> {
34    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
35        f.debug_struct("LlamaContext")
36            .field("context", &self.context)
37            .finish()
38    }
39}
40
41impl<'model> LlamaContext<'model> {
42    /// Wraps existing raw pointers into a new `LlamaContext`.
43    #[must_use]
44    pub fn new(
45        llama_model: &'model LlamaModel,
46        llama_context: NonNull<llama_cpp_bindings_sys::llama_context>,
47        embeddings_enabled: bool,
48    ) -> Self {
49        Self {
50            context: llama_context,
51            model: llama_model,
52            initialized_logits: Vec::new(),
53            embeddings_enabled,
54        }
55    }
56
57    /// Gets the max number of logical tokens that can be submitted to decode. Must be greater than or equal to [`Self::n_ubatch`].
58    #[must_use]
59    pub fn n_batch(&self) -> u32 {
60        unsafe { llama_cpp_bindings_sys::llama_n_batch(self.context.as_ptr()) }
61    }
62
63    /// Gets the max number of physical tokens (hardware level) to decode in batch. Must be less than or equal to [`Self::n_batch`].
64    #[must_use]
65    pub fn n_ubatch(&self) -> u32 {
66        unsafe { llama_cpp_bindings_sys::llama_n_ubatch(self.context.as_ptr()) }
67    }
68
69    /// Gets the size of the context.
70    #[must_use]
71    pub fn n_ctx(&self) -> u32 {
72        unsafe { llama_cpp_bindings_sys::llama_n_ctx(self.context.as_ptr()) }
73    }
74
75    /// Decodes the batch.
76    ///
77    /// # Errors
78    ///
79    /// - `DecodeError` if the decoding failed.
80    ///
81    /// # Panics
82    ///
83    /// - the returned [`std::ffi::c_int`] from llama-cpp does not fit into a i32 (this should never happen on most systems)
84    pub fn decode(&mut self, batch: &mut LlamaBatch) -> Result<(), DecodeError> {
85        let result = unsafe {
86            llama_cpp_bindings_sys::llama_decode(self.context.as_ptr(), batch.llama_batch)
87        };
88
89        match NonZeroI32::new(result) {
90            None => {
91                self.initialized_logits
92                    .clone_from(&batch.initialized_logits);
93                Ok(())
94            }
95            Some(error) => Err(DecodeError::from(error)),
96        }
97    }
98
99    /// Encodes the batch.
100    ///
101    /// # Errors
102    ///
103    /// - `EncodeError` if the decoding failed.
104    ///
105    /// # Panics
106    ///
107    /// - the returned [`std::ffi::c_int`] from llama-cpp does not fit into a i32 (this should never happen on most systems)
108    pub fn encode(&mut self, batch: &mut LlamaBatch) -> Result<(), EncodeError> {
109        let result = unsafe {
110            llama_cpp_bindings_sys::llama_encode(self.context.as_ptr(), batch.llama_batch)
111        };
112
113        match NonZeroI32::new(result) {
114            None => {
115                self.initialized_logits
116                    .clone_from(&batch.initialized_logits);
117                Ok(())
118            }
119            Some(error) => Err(EncodeError::from(error)),
120        }
121    }
122
123    /// Get the embeddings for the `i`th sequence in the current context.
124    ///
125    /// # Returns
126    ///
127    /// A slice containing the embeddings for the last decoded batch.
128    /// The size corresponds to the `n_embd` parameter of the context's model.
129    ///
130    /// # Errors
131    ///
132    /// - When the current context was constructed without enabling embeddings.
133    /// - If the current model had a pooling type of [`llama_cpp_bindings_sys::LLAMA_POOLING_TYPE_NONE`]
134    /// - If the given sequence index exceeds the max sequence id.
135    ///
136    /// # Panics
137    ///
138    /// * `n_embd` does not fit into a usize
139    pub fn embeddings_seq_ith(&self, i: i32) -> Result<&[f32], EmbeddingsError> {
140        if !self.embeddings_enabled {
141            return Err(EmbeddingsError::NotEnabled);
142        }
143
144        let n_embd =
145            usize::try_from(self.model.n_embd()).expect("n_embd does not fit into a usize");
146
147        unsafe {
148            let embedding =
149                llama_cpp_bindings_sys::llama_get_embeddings_seq(self.context.as_ptr(), i);
150
151            // Technically also possible whenever `i >= max(batch.n_seq)`, but can't check that here.
152            if embedding.is_null() {
153                Err(EmbeddingsError::NonePoolType)
154            } else {
155                Ok(slice::from_raw_parts(embedding, n_embd))
156            }
157        }
158    }
159
160    /// Get the embeddings for the `i`th token in the current context.
161    ///
162    /// # Returns
163    ///
164    /// A slice containing the embeddings for the last decoded batch of the given token.
165    /// The size corresponds to the `n_embd` parameter of the context's model.
166    ///
167    /// # Errors
168    ///
169    /// - When the current context was constructed without enabling embeddings.
170    /// - When the given token didn't have logits enabled when it was passed.
171    /// - If the given token index exceeds the max token id.
172    ///
173    /// # Panics
174    ///
175    /// * `n_embd` does not fit into a usize
176    pub fn embeddings_ith(&self, i: i32) -> Result<&[f32], EmbeddingsError> {
177        if !self.embeddings_enabled {
178            return Err(EmbeddingsError::NotEnabled);
179        }
180
181        let n_embd =
182            usize::try_from(self.model.n_embd()).expect("n_embd does not fit into a usize");
183
184        unsafe {
185            let embedding =
186                llama_cpp_bindings_sys::llama_get_embeddings_ith(self.context.as_ptr(), i);
187            // Technically also possible whenever `i >= batch.n_tokens`, but no good way of checking `n_tokens` here.
188            if embedding.is_null() {
189                Err(EmbeddingsError::LogitsNotEnabled)
190            } else {
191                Ok(slice::from_raw_parts(embedding, n_embd))
192            }
193        }
194    }
195
196    /// Get the logits for the last token in the context.
197    ///
198    /// # Returns
199    /// An iterator over unsorted `LlamaTokenData` containing the
200    /// logits for the last token in the context.
201    ///
202    /// # Panics
203    ///
204    /// - underlying logits data is null
205    pub fn candidates(&self) -> impl Iterator<Item = LlamaTokenData> + '_ {
206        (0_i32..).zip(self.get_logits()).map(|(i, logit)| {
207            let token = LlamaToken::new(i);
208            LlamaTokenData::new(token, *logit, 0_f32)
209        })
210    }
211
212    /// Get the token data array for the last token in the context.
213    ///
214    /// This is a convience method that implements:
215    /// ```ignore
216    /// LlamaTokenDataArray::from_iter(ctx.candidates(), false)
217    /// ```
218    ///
219    /// # Panics
220    ///
221    /// - underlying logits data is null
222    #[must_use]
223    pub fn token_data_array(&self) -> LlamaTokenDataArray {
224        LlamaTokenDataArray::from_iter(self.candidates(), false)
225    }
226
227    /// Token logits obtained from the last call to `decode()`.
228    /// The logits for which `batch.logits[i] != 0` are stored contiguously
229    /// in the order they have appeared in the batch.
230    /// Rows: number of tokens for which `batch.logits[i] != 0`
231    /// Cols: `n_vocab`
232    ///
233    /// # Returns
234    ///
235    /// A slice containing the logits for the last decoded token.
236    /// The size corresponds to the `n_vocab` parameter of the context's model.
237    ///
238    /// # Panics
239    ///
240    /// - `n_vocab` does not fit into a usize
241    /// - token data returned is null
242    #[must_use]
243    pub fn get_logits(&self) -> &[f32] {
244        let data = unsafe { llama_cpp_bindings_sys::llama_get_logits(self.context.as_ptr()) };
245        assert!(!data.is_null(), "logits data for last token is null");
246        let len = usize::try_from(self.model.n_vocab()).expect("n_vocab does not fit into a usize");
247
248        unsafe { slice::from_raw_parts(data, len) }
249    }
250
251    /// Get the logits for the ith token in the context.
252    ///
253    /// # Panics
254    ///
255    /// - logit `i` is not initialized.
256    pub fn candidates_ith(&self, i: i32) -> impl Iterator<Item = LlamaTokenData> + '_ {
257        (0_i32..).zip(self.get_logits_ith(i)).map(|(i, logit)| {
258            let token = LlamaToken::new(i);
259            LlamaTokenData::new(token, *logit, 0_f32)
260        })
261    }
262
263    /// Get the token data array for the ith token in the context.
264    ///
265    /// This is a convience method that implements:
266    /// ```ignore
267    /// LlamaTokenDataArray::from_iter(ctx.candidates_ith(i), false)
268    /// ```
269    ///
270    /// # Panics
271    ///
272    /// - logit `i` is not initialized.
273    #[must_use]
274    pub fn token_data_array_ith(&self, i: i32) -> LlamaTokenDataArray {
275        LlamaTokenDataArray::from_iter(self.candidates_ith(i), false)
276    }
277
278    /// Get the logits for the ith token in the context.
279    ///
280    /// # Panics
281    ///
282    /// - `i` is greater than `n_ctx`
283    /// - `n_vocab` does not fit into a usize
284    /// - logit `i` is not initialized.
285    #[must_use]
286    pub fn get_logits_ith(&self, i: i32) -> &[f32] {
287        assert!(
288            self.initialized_logits.contains(&i),
289            "logit {i} is not initialized. only {:?} is",
290            self.initialized_logits
291        );
292        assert!(
293            self.n_ctx() > u32::try_from(i).expect("i does not fit into a u32"),
294            "n_ctx ({}) must be greater than i ({})",
295            self.n_ctx(),
296            i
297        );
298
299        let data =
300            unsafe { llama_cpp_bindings_sys::llama_get_logits_ith(self.context.as_ptr(), i) };
301        let len = usize::try_from(self.model.n_vocab()).expect("n_vocab does not fit into a usize");
302
303        unsafe { slice::from_raw_parts(data, len) }
304    }
305
306    /// Reset the timings for the context.
307    pub fn reset_timings(&mut self) {
308        unsafe { llama_cpp_bindings_sys::llama_perf_context_reset(self.context.as_ptr()) }
309    }
310
311    /// Returns the timings for the context.
312    pub fn timings(&mut self) -> LlamaTimings {
313        let timings = unsafe { llama_cpp_bindings_sys::llama_perf_context(self.context.as_ptr()) };
314        LlamaTimings { timings }
315    }
316
317    /// Sets a lora adapter.
318    ///
319    /// # Errors
320    ///
321    /// See [`LlamaLoraAdapterSetError`] for more information.
322    pub fn lora_adapter_set(
323        &self,
324        adapter: &mut LlamaLoraAdapter,
325        scale: f32,
326    ) -> Result<(), LlamaLoraAdapterSetError> {
327        let mut adapters = [adapter.lora_adapter.as_ptr()];
328        let mut scales = [scale];
329        let err_code = unsafe {
330            llama_cpp_bindings_sys::llama_set_adapters_lora(
331                self.context.as_ptr(),
332                adapters.as_mut_ptr(),
333                1,
334                scales.as_mut_ptr(),
335            )
336        };
337        if err_code != 0 {
338            return Err(LlamaLoraAdapterSetError::ErrorResult(err_code));
339        }
340
341        tracing::debug!("Set lora adapter");
342        Ok(())
343    }
344
345    /// Remove all lora adapters.
346    ///
347    /// Note: The upstream API now replaces all adapters at once via
348    /// `llama_set_adapters_lora`. This clears all adapters from the context.
349    ///
350    /// # Errors
351    ///
352    /// See [`LlamaLoraAdapterRemoveError`] for more information.
353    pub fn lora_adapter_remove(
354        &self,
355        _adapter: &mut LlamaLoraAdapter,
356    ) -> Result<(), LlamaLoraAdapterRemoveError> {
357        let err_code = unsafe {
358            llama_cpp_bindings_sys::llama_set_adapters_lora(
359                self.context.as_ptr(),
360                std::ptr::null_mut(),
361                0,
362                std::ptr::null_mut(),
363            )
364        };
365        if err_code != 0 {
366            return Err(LlamaLoraAdapterRemoveError::ErrorResult(err_code));
367        }
368
369        tracing::debug!("Remove lora adapter");
370        Ok(())
371    }
372}
373
374impl Drop for LlamaContext<'_> {
375    fn drop(&mut self) {
376        unsafe { llama_cpp_bindings_sys::llama_free(self.context.as_ptr()) }
377    }
378}