// llama_cpp_bindings/context.rs

1//! Safe wrapper around `llama_context`.
2
3use std::fmt::{Debug, Formatter};
4use std::num::NonZeroI32;
5use std::ptr::NonNull;
6use std::slice;
7
8use crate::llama_batch::LlamaBatch;
9use crate::model::{LlamaLoraAdapter, LlamaModel};
10use crate::timing::LlamaTimings;
11use crate::token::LlamaToken;
12use crate::token::data::LlamaTokenData;
13use crate::token::data_array::LlamaTokenDataArray;
14use crate::{
15    DecodeError, EmbeddingsError, EncodeError, LlamaLoraAdapterRemoveError,
16    LlamaLoraAdapterSetError,
17};
18
19pub mod kv_cache;
20pub mod params;
21pub mod session;
22
/// Safe wrapper around `llama_context`.
pub struct LlamaContext<'model> {
    /// Raw pointer to the underlying `llama_context`.
    pub context: NonNull<llama_cpp_bindings_sys::llama_context>,
    /// A reference to the context's model.
    pub model: &'model LlamaModel,
    // Batch positions whose logits were produced by the last successful
    // `decode()`/`encode()` call; consulted by `get_logits_ith` before
    // reading logits for a given index.
    initialized_logits: Vec<i32>,
    // Whether this context was created with embeddings enabled; gates the
    // `embeddings_seq_ith`/`embeddings_ith` accessors.
    embeddings_enabled: bool,
}
32
33impl Debug for LlamaContext<'_> {
34    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
35        f.debug_struct("LlamaContext")
36            .field("context", &self.context)
37            .finish()
38    }
39}
40
41impl<'model> LlamaContext<'model> {
42    /// Wraps existing raw pointers into a new `LlamaContext`.
43    #[must_use]
44    pub fn new(
45        llama_model: &'model LlamaModel,
46        llama_context: NonNull<llama_cpp_bindings_sys::llama_context>,
47        embeddings_enabled: bool,
48    ) -> Self {
49        Self {
50            context: llama_context,
51            model: llama_model,
52            initialized_logits: Vec::new(),
53            embeddings_enabled,
54        }
55    }
56
    /// Gets the max number of logical tokens that can be submitted to decode. Must be greater than or equal to [`Self::n_ubatch`].
    #[must_use]
    pub fn n_batch(&self) -> u32 {
        // SAFETY: `self.context` is a valid, non-null context pointer for the lifetime of `self`.
        unsafe { llama_cpp_bindings_sys::llama_n_batch(self.context.as_ptr()) }
    }
62
    /// Gets the max number of physical tokens (hardware level) to decode in batch. Must be less than or equal to [`Self::n_batch`].
    #[must_use]
    pub fn n_ubatch(&self) -> u32 {
        // SAFETY: `self.context` is a valid, non-null context pointer for the lifetime of `self`.
        unsafe { llama_cpp_bindings_sys::llama_n_ubatch(self.context.as_ptr()) }
    }
68
    /// Gets the size of the context.
    #[must_use]
    pub fn n_ctx(&self) -> u32 {
        // SAFETY: `self.context` is a valid, non-null context pointer for the lifetime of `self`.
        unsafe { llama_cpp_bindings_sys::llama_n_ctx(self.context.as_ptr()) }
    }
74
75    /// Decodes the batch.
76    ///
77    /// # Errors
78    ///
79    /// - `DecodeError` if the decoding failed.
80    ///
81    /// # Panics
82    ///
83    /// - the returned [`std::ffi::c_int`] from llama-cpp does not fit into a i32 (this should never happen on most systems)
84    pub fn decode(&mut self, batch: &mut LlamaBatch) -> Result<(), DecodeError> {
85        let result = unsafe {
86            llama_cpp_bindings_sys::llama_decode(self.context.as_ptr(), batch.llama_batch)
87        };
88
89        match NonZeroI32::new(result) {
90            None => {
91                self.initialized_logits
92                    .clone_from(&batch.initialized_logits);
93                Ok(())
94            }
95            Some(error) => Err(DecodeError::from(error)),
96        }
97    }
98
99    /// Encodes the batch.
100    ///
101    /// # Errors
102    ///
103    /// - `EncodeError` if the decoding failed.
104    ///
105    /// # Panics
106    ///
107    /// - the returned [`std::ffi::c_int`] from llama-cpp does not fit into a i32 (this should never happen on most systems)
108    pub fn encode(&mut self, batch: &mut LlamaBatch) -> Result<(), EncodeError> {
109        let result = unsafe {
110            llama_cpp_bindings_sys::llama_encode(self.context.as_ptr(), batch.llama_batch)
111        };
112
113        match NonZeroI32::new(result) {
114            None => {
115                self.initialized_logits
116                    .clone_from(&batch.initialized_logits);
117                Ok(())
118            }
119            Some(error) => Err(EncodeError::from(error)),
120        }
121    }
122
123    /// Get the embeddings for the given sequence in the current context.
124    ///
125    /// # Returns
126    ///
127    /// A slice containing the embeddings for the last decoded batch.
128    /// The size corresponds to the `n_embd` parameter of the context's model.
129    ///
130    /// # Errors
131    ///
132    /// - When the current context was constructed without enabling embeddings.
133    /// - If the current model had a pooling type of [`llama_cpp_bindings_sys::LLAMA_POOLING_TYPE_NONE`]
134    /// - If the given sequence index exceeds the max sequence id.
135    ///
136    /// # Panics
137    ///
138    /// * `n_embd` does not fit into a usize
139    pub fn embeddings_seq_ith(&self, sequence_index: i32) -> Result<&[f32], EmbeddingsError> {
140        if !self.embeddings_enabled {
141            return Err(EmbeddingsError::NotEnabled);
142        }
143
144        let n_embd =
145            usize::try_from(self.model.n_embd()).expect("n_embd does not fit into a usize");
146
147        unsafe {
148            let embedding = llama_cpp_bindings_sys::llama_get_embeddings_seq(
149                self.context.as_ptr(),
150                sequence_index,
151            );
152
153            if embedding.is_null() {
154                Err(EmbeddingsError::NonePoolType)
155            } else {
156                Ok(slice::from_raw_parts(embedding, n_embd))
157            }
158        }
159    }
160
161    /// Get the embeddings for the given token in the current context.
162    ///
163    /// # Returns
164    ///
165    /// A slice containing the embeddings for the last decoded batch of the given token.
166    /// The size corresponds to the `n_embd` parameter of the context's model.
167    ///
168    /// # Errors
169    ///
170    /// - When the current context was constructed without enabling embeddings.
171    /// - When the given token didn't have logits enabled when it was passed.
172    /// - If the given token index exceeds the max token id.
173    ///
174    /// # Panics
175    ///
176    /// * `n_embd` does not fit into a usize
177    pub fn embeddings_ith(&self, token_index: i32) -> Result<&[f32], EmbeddingsError> {
178        if !self.embeddings_enabled {
179            return Err(EmbeddingsError::NotEnabled);
180        }
181
182        let n_embd =
183            usize::try_from(self.model.n_embd()).expect("n_embd does not fit into a usize");
184
185        unsafe {
186            let embedding = llama_cpp_bindings_sys::llama_get_embeddings_ith(
187                self.context.as_ptr(),
188                token_index,
189            );
190
191            if embedding.is_null() {
192                Err(EmbeddingsError::LogitsNotEnabled)
193            } else {
194                Ok(slice::from_raw_parts(embedding, n_embd))
195            }
196        }
197    }
198
199    /// Get the logits for the last token in the context.
200    ///
201    /// # Returns
202    /// An iterator over unsorted `LlamaTokenData` containing the
203    /// logits for the last token in the context.
204    ///
205    /// # Panics
206    ///
207    /// - underlying logits data is null
208    pub fn candidates(&self) -> impl Iterator<Item = LlamaTokenData> + '_ {
209        (0_i32..).zip(self.get_logits()).map(|(token_id, logit)| {
210            let token = LlamaToken::new(token_id);
211            LlamaTokenData::new(token, *logit, 0_f32)
212        })
213    }
214
215    /// Get the token data array for the last token in the context.
216    ///
217    /// This is a convenience method that implements:
218    /// ```ignore
219    /// LlamaTokenDataArray::from_iter(ctx.candidates(), false)
220    /// ```
221    ///
222    /// # Panics
223    ///
224    /// - underlying logits data is null
225    #[must_use]
226    pub fn token_data_array(&self) -> LlamaTokenDataArray {
227        LlamaTokenDataArray::from_iter(self.candidates(), false)
228    }
229
230    /// Token logits obtained from the last call to `decode()`.
231    /// The logits for which `batch.logits[i] != 0` are stored contiguously
232    /// in the order they have appeared in the batch.
233    /// Rows: number of tokens for which `batch.logits[i] != 0`
234    /// Cols: `n_vocab`
235    ///
236    /// # Returns
237    ///
238    /// A slice containing the logits for the last decoded token.
239    /// The size corresponds to the `n_vocab` parameter of the context's model.
240    ///
241    /// # Panics
242    ///
243    /// - `n_vocab` does not fit into a usize
244    /// - token data returned is null
245    #[must_use]
246    pub fn get_logits(&self) -> &[f32] {
247        let data = unsafe { llama_cpp_bindings_sys::llama_get_logits(self.context.as_ptr()) };
248        assert!(!data.is_null(), "logits data for last token is null");
249        let len = usize::try_from(self.model.n_vocab()).expect("n_vocab does not fit into a usize");
250
251        unsafe { slice::from_raw_parts(data, len) }
252    }
253
254    /// Get the logits for the ith token in the context.
255    ///
256    /// # Panics
257    ///
258    /// - logit `i` is not initialized.
259    pub fn candidates_ith(&self, token_index: i32) -> impl Iterator<Item = LlamaTokenData> + '_ {
260        (0_i32..)
261            .zip(self.get_logits_ith(token_index))
262            .map(|(token_id, logit)| {
263                let token = LlamaToken::new(token_id);
264                LlamaTokenData::new(token, *logit, 0_f32)
265            })
266    }
267
268    /// Get the token data array for the ith token in the context.
269    ///
270    /// This is a convenience method that implements:
271    /// ```ignore
272    /// LlamaTokenDataArray::from_iter(ctx.candidates_ith(token_index), false)
273    /// ```
274    ///
275    /// # Panics
276    ///
277    /// - logit `i` is not initialized.
278    #[must_use]
279    pub fn token_data_array_ith(&self, token_index: i32) -> LlamaTokenDataArray {
280        LlamaTokenDataArray::from_iter(self.candidates_ith(token_index), false)
281    }
282
283    /// Get the logits for the ith token in the context.
284    ///
285    /// # Panics
286    ///
287    /// - `token_index` is greater than `n_ctx`
288    /// - `n_vocab` does not fit into a usize
289    /// - logit `token_index` is not initialized.
290    #[must_use]
291    pub fn get_logits_ith(&self, token_index: i32) -> &[f32] {
292        assert!(
293            self.initialized_logits.contains(&token_index),
294            "logit {token_index} is not initialized. only {:?} is",
295            self.initialized_logits
296        );
297        assert!(
298            self.n_ctx() > u32::try_from(token_index).expect("token_index does not fit into a u32"),
299            "n_ctx ({}) must be greater than token_index ({})",
300            self.n_ctx(),
301            token_index
302        );
303
304        let data = unsafe {
305            llama_cpp_bindings_sys::llama_get_logits_ith(self.context.as_ptr(), token_index)
306        };
307        let len = usize::try_from(self.model.n_vocab()).expect("n_vocab does not fit into a usize");
308
309        unsafe { slice::from_raw_parts(data, len) }
310    }
311
    /// Reset the timings for the context.
    pub fn reset_timings(&mut self) {
        // SAFETY: `self.context` is a valid, non-null context pointer for the lifetime of `self`.
        unsafe { llama_cpp_bindings_sys::llama_perf_context_reset(self.context.as_ptr()) }
    }
316
317    /// Returns the timings for the context.
318    pub fn timings(&mut self) -> LlamaTimings {
319        let timings = unsafe { llama_cpp_bindings_sys::llama_perf_context(self.context.as_ptr()) };
320        LlamaTimings { timings }
321    }
322
323    /// Sets a lora adapter.
324    ///
325    /// # Errors
326    ///
327    /// See [`LlamaLoraAdapterSetError`] for more information.
328    pub fn lora_adapter_set(
329        &self,
330        adapter: &mut LlamaLoraAdapter,
331        scale: f32,
332    ) -> Result<(), LlamaLoraAdapterSetError> {
333        let mut adapters = [adapter.lora_adapter.as_ptr()];
334        let mut scales = [scale];
335        let err_code = unsafe {
336            llama_cpp_bindings_sys::llama_set_adapters_lora(
337                self.context.as_ptr(),
338                adapters.as_mut_ptr(),
339                1,
340                scales.as_mut_ptr(),
341            )
342        };
343        if err_code != 0 {
344            return Err(LlamaLoraAdapterSetError::ErrorResult(err_code));
345        }
346
347        tracing::debug!("Set lora adapter");
348        Ok(())
349    }
350
351    /// Remove all lora adapters.
352    ///
353    /// Note: The upstream API now replaces all adapters at once via
354    /// `llama_set_adapters_lora`. This clears all adapters from the context.
355    ///
356    /// # Errors
357    ///
358    /// See [`LlamaLoraAdapterRemoveError`] for more information.
359    pub fn lora_adapter_remove(
360        &self,
361        _adapter: &mut LlamaLoraAdapter,
362    ) -> Result<(), LlamaLoraAdapterRemoveError> {
363        let err_code = unsafe {
364            llama_cpp_bindings_sys::llama_set_adapters_lora(
365                self.context.as_ptr(),
366                std::ptr::null_mut(),
367                0,
368                std::ptr::null_mut(),
369            )
370        };
371        if err_code != 0 {
372            return Err(LlamaLoraAdapterRemoveError::ErrorResult(err_code));
373        }
374
375        tracing::debug!("Remove lora adapter");
376        Ok(())
377    }
378}
379
impl Drop for LlamaContext<'_> {
    /// Frees the underlying `llama_context` when the wrapper is dropped.
    fn drop(&mut self) {
        // SAFETY: `self.context` is owned exclusively by this wrapper, so it
        // is valid here and freed exactly once.
        unsafe { llama_cpp_bindings_sys::llama_free(self.context.as_ptr()) }
    }
}