llama_cpp_2/context.rs

//! Safe wrapper around `llama_context`.

use std::fmt::{Debug, Formatter};
use std::num::NonZeroI32;
use std::ptr::NonNull;
use std::slice;

use crate::llama_batch::LlamaBatch;
use crate::model::{LlamaLoraAdapter, LlamaModel};
use crate::timing::LlamaTimings;
use crate::token::data::LlamaTokenData;
use crate::token::data_array::LlamaTokenDataArray;
use crate::token::LlamaToken;
use crate::{
    DecodeError, EmbeddingsError, EncodeError, LlamaLoraAdapterRemoveError,
    LlamaLoraAdapterSetError,
};

pub mod kv_cache;
pub mod params;
pub mod session;

/// Safe wrapper around `llama_context`.
#[allow(clippy::module_name_repetitions)]
pub struct LlamaContext<'a> {
    pub(crate) context: NonNull<llama_cpp_sys_2::llama_context>,
    /// A reference to the context's model.
    pub model: &'a LlamaModel,
    initialized_logits: Vec<i32>,
    embeddings_enabled: bool,
}

impl Debug for LlamaContext<'_> {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("LlamaContext")
            .field("context", &self.context)
            .finish()
    }
}

impl<'model> LlamaContext<'model> {
    pub(crate) fn new(
        llama_model: &'model LlamaModel,
        llama_context: NonNull<llama_cpp_sys_2::llama_context>,
        embeddings_enabled: bool,
    ) -> Self {
        Self {
            context: llama_context,
            model: llama_model,
            initialized_logits: Vec::new(),
            embeddings_enabled,
        }
    }

    /// Gets the max number of logical tokens that can be submitted to decode. Must be greater than or equal to [`Self::n_ubatch`].
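    ///
    /// As a quick illustration of the invariant (a sketch, not a compiled
    /// doctest; assumes `ctx` is an already-constructed `LlamaContext`):
    /// ```ignore
    /// // The logical batch limit is always at least the physical one.
    /// assert!(ctx.n_batch() >= ctx.n_ubatch());
    /// ```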
    #[must_use]
    pub fn n_batch(&self) -> u32 {
        unsafe { llama_cpp_sys_2::llama_n_batch(self.context.as_ptr()) }
    }

    /// Gets the max number of physical tokens (hardware level) that can be decoded in a single batch. Must be less than or equal to [`Self::n_batch`].
    #[must_use]
    pub fn n_ubatch(&self) -> u32 {
        unsafe { llama_cpp_sys_2::llama_n_ubatch(self.context.as_ptr()) }
    }

    /// Gets the size of the context.
    #[must_use]
    pub fn n_ctx(&self) -> u32 {
        unsafe { llama_cpp_sys_2::llama_n_ctx(self.context.as_ptr()) }
    }

    /// Decodes the batch.
    ///
    /// # Errors
    ///
    /// - `DecodeError` if the decoding failed.
    ///
    /// # Panics
    ///
    /// - the returned [`std::ffi::c_int`] from llama.cpp does not fit into an `i32` (this should never happen on most systems)
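    ///
    /// # Example
    ///
    /// A minimal sketch (`ignore`d rather than compiled), assuming `model` and
    /// `ctx` are already set up and following this crate's usual
    /// tokenize-then-decode pattern; the prompt string is illustrative:
    /// ```ignore
    /// use llama_cpp_2::llama_batch::LlamaBatch;
    /// use llama_cpp_2::model::AddBos;
    ///
    /// let tokens = model.str_to_token("Hello world", AddBos::Always)?;
    /// let mut batch = LlamaBatch::new(512, 1);
    /// let last_index = tokens.len() as i32 - 1;
    /// for (i, token) in (0_i32..).zip(tokens) {
    ///     // Only request logits for the final prompt token.
    ///     batch.add(token, i, &[0], i == last_index)?;
    /// }
    /// ctx.decode(&mut batch)?;
    /// ```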
    pub fn decode(&mut self, batch: &mut LlamaBatch) -> Result<(), DecodeError> {
        let result =
            unsafe { llama_cpp_sys_2::llama_decode(self.context.as_ptr(), batch.llama_batch) };

        match NonZeroI32::new(result) {
            None => {
                self.initialized_logits
                    .clone_from(&batch.initialized_logits);
                Ok(())
            }
            Some(error) => Err(DecodeError::from(error)),
        }
    }

    /// Encodes the batch.
    ///
    /// # Errors
    ///
    /// - `EncodeError` if the encoding failed.
    ///
    /// # Panics
    ///
    /// - the returned [`std::ffi::c_int`] from llama.cpp does not fit into an `i32` (this should never happen on most systems)
    pub fn encode(&mut self, batch: &mut LlamaBatch) -> Result<(), EncodeError> {
        let result =
            unsafe { llama_cpp_sys_2::llama_encode(self.context.as_ptr(), batch.llama_batch) };

        match NonZeroI32::new(result) {
            None => {
                self.initialized_logits
                    .clone_from(&batch.initialized_logits);
                Ok(())
            }
            Some(error) => Err(EncodeError::from(error)),
        }
    }

    /// Get the embeddings for the `i`th sequence in the current context.
    ///
    /// # Returns
    ///
    /// A slice containing the embeddings for the last decoded batch.
    /// The size corresponds to the `n_embd` parameter of the context's model.
    ///
    /// # Errors
    ///
    /// - When the current context was constructed without enabling embeddings.
    /// - If the current model has a pooling type of [`llama_cpp_sys_2::LLAMA_POOLING_TYPE_NONE`].
    /// - If the given sequence index exceeds the max sequence id.
    ///
    /// # Panics
    ///
    /// - `n_embd` does not fit into a `usize`
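    ///
    /// # Example
    ///
    /// A minimal sketch (`ignore`d rather than compiled), assuming the context
    /// was created with embeddings enabled, the model uses a non-`NONE` pooling
    /// type, and sequence `0` was part of the last decoded batch:
    /// ```ignore
    /// let embedding: &[f32] = ctx.embeddings_seq_ith(0)?;
    /// // The slice length matches the model's embedding dimension.
    /// assert_eq!(embedding.len(), model.n_embd() as usize);
    /// ```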
    pub fn embeddings_seq_ith(&self, i: i32) -> Result<&[f32], EmbeddingsError> {
        if !self.embeddings_enabled {
            return Err(EmbeddingsError::NotEnabled);
        }

        let n_embd =
            usize::try_from(self.model.n_embd()).expect("n_embd does not fit into a usize");

        unsafe {
            let embedding = llama_cpp_sys_2::llama_get_embeddings_seq(self.context.as_ptr(), i);

            // Technically also possible whenever `i >= max(batch.n_seq)`, but can't check that here.
            if embedding.is_null() {
                Err(EmbeddingsError::NonePoolType)
            } else {
                Ok(slice::from_raw_parts(embedding, n_embd))
            }
        }
    }

    /// Get the embeddings for the `i`th token in the current context.
    ///
    /// # Returns
    ///
    /// A slice containing the embeddings of the given token in the last decoded batch.
    /// The size corresponds to the `n_embd` parameter of the context's model.
    ///
    /// # Errors
    ///
    /// - When the current context was constructed without enabling embeddings.
    /// - When the given token didn't have logits enabled when it was passed.
    /// - If the given token index exceeds the max token id.
    ///
    /// # Panics
    ///
    /// - `n_embd` does not fit into a `usize`
    pub fn embeddings_ith(&self, i: i32) -> Result<&[f32], EmbeddingsError> {
        if !self.embeddings_enabled {
            return Err(EmbeddingsError::NotEnabled);
        }

        let n_embd =
            usize::try_from(self.model.n_embd()).expect("n_embd does not fit into a usize");

        unsafe {
            let embedding = llama_cpp_sys_2::llama_get_embeddings_ith(self.context.as_ptr(), i);
            // Technically also possible whenever `i >= batch.n_tokens`, but no good way of checking `n_tokens` here.
            if embedding.is_null() {
                Err(EmbeddingsError::LogitsNotEnabled)
            } else {
                Ok(slice::from_raw_parts(embedding, n_embd))
            }
        }
    }

    /// Get the logits for the last token in the context.
    ///
    /// # Returns
    ///
    /// An iterator over unsorted `LlamaTokenData` containing the
    /// logits for the last token in the context.
    ///
    /// # Panics
    ///
    /// - underlying logits data is null
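    ///
    /// # Example
    ///
    /// A minimal greedy-sampling sketch (`ignore`d rather than compiled),
    /// assuming a batch has already been decoded and `LlamaTokenData::logit`
    /// is the accessor for the stored logit:
    /// ```ignore
    /// // Pick the candidate with the highest logit.
    /// let best = ctx
    ///     .candidates()
    ///     .max_by(|a, b| a.logit().total_cmp(&b.logit()))
    ///     .expect("the vocabulary is never empty");
    /// ```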
    pub fn candidates(&self) -> impl Iterator<Item = LlamaTokenData> + '_ {
        (0_i32..).zip(self.get_logits()).map(|(i, logit)| {
            let token = LlamaToken::new(i);
            LlamaTokenData::new(token, *logit, 0_f32)
        })
    }

    /// Get the token data array for the last token in the context.
    ///
    /// This is a convenience method that implements:
    /// ```ignore
    /// LlamaTokenDataArray::from_iter(ctx.candidates(), false)
    /// ```
    ///
    /// # Panics
    ///
    /// - underlying logits data is null
    #[must_use]
    pub fn token_data_array(&self) -> LlamaTokenDataArray {
        LlamaTokenDataArray::from_iter(self.candidates(), false)
    }

    /// Token logits obtained from the last call to `decode()`.
    /// The logits for which `batch.logits[i] != 0` are stored contiguously
    /// in the order they appeared in the batch.
    /// Rows: number of tokens for which `batch.logits[i] != 0`
    /// Cols: `n_vocab`
    ///
    /// # Returns
    ///
    /// A slice containing the logits for the last decoded token.
    /// The size corresponds to the `n_vocab` parameter of the context's model.
    ///
    /// # Panics
    ///
    /// - `n_vocab` does not fit into a `usize`
    /// - token data returned is null
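    ///
    /// # Example
    ///
    /// A minimal sketch (`ignore`d rather than compiled), assuming the last
    /// decoded batch requested logits for its final token:
    /// ```ignore
    /// let logits = ctx.get_logits();
    /// // One logit per vocabulary entry.
    /// assert_eq!(logits.len(), model.n_vocab() as usize);
    /// ```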
    #[must_use]
    pub fn get_logits(&self) -> &[f32] {
        let data = unsafe { llama_cpp_sys_2::llama_get_logits(self.context.as_ptr()) };
        assert!(!data.is_null(), "logits data for last token is null");
        let len = usize::try_from(self.model.n_vocab()).expect("n_vocab does not fit into a usize");

        unsafe { slice::from_raw_parts(data, len) }
    }

    /// Get unsorted `LlamaTokenData` candidates for the ith token in the context.
    ///
    /// # Panics
    ///
    /// - logit `i` is not initialized.
    pub fn candidates_ith(&self, i: i32) -> impl Iterator<Item = LlamaTokenData> + '_ {
        (0_i32..).zip(self.get_logits_ith(i)).map(|(i, logit)| {
            let token = LlamaToken::new(i);
            LlamaTokenData::new(token, *logit, 0_f32)
        })
    }

    /// Get the token data array for the ith token in the context.
    ///
    /// This is a convenience method that implements:
    /// ```ignore
    /// LlamaTokenDataArray::from_iter(ctx.candidates_ith(i), false)
    /// ```
    ///
    /// # Panics
    ///
    /// - logit `i` is not initialized.
    #[must_use]
    pub fn token_data_array_ith(&self, i: i32) -> LlamaTokenDataArray {
        LlamaTokenDataArray::from_iter(self.candidates_ith(i), false)
    }

    /// Get the logits for the ith token in the context.
    ///
    /// # Panics
    ///
    /// - `i` is greater than or equal to `n_ctx`
    /// - `n_vocab` does not fit into a `usize`
    /// - logit `i` is not initialized.
    #[must_use]
    pub fn get_logits_ith(&self, i: i32) -> &[f32] {
        assert!(
            self.initialized_logits.contains(&i),
            "logit {i} is not initialized. only {:?} are",
            self.initialized_logits
        );
        assert!(
            self.n_ctx() > u32::try_from(i).expect("i does not fit into a u32"),
            "n_ctx ({}) must be greater than i ({})",
            self.n_ctx(),
            i
        );

        let data = unsafe { llama_cpp_sys_2::llama_get_logits_ith(self.context.as_ptr(), i) };
        let len = usize::try_from(self.model.n_vocab()).expect("n_vocab does not fit into a usize");

        unsafe { slice::from_raw_parts(data, len) }
    }

    /// Reset the timings for the context.
    pub fn reset_timings(&mut self) {
        unsafe { llama_cpp_sys_2::llama_perf_context_reset(self.context.as_ptr()) }
    }

    /// Returns the timings for the context.
    pub fn timings(&mut self) -> LlamaTimings {
        let timings = unsafe { llama_cpp_sys_2::llama_perf_context(self.context.as_ptr()) };
        LlamaTimings { timings }
    }

    /// Sets a LoRA adapter.
    ///
    /// # Errors
    ///
    /// See [`LlamaLoraAdapterSetError`] for more information.
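    ///
    /// # Example
    ///
    /// A minimal sketch (`ignore`d rather than compiled), assuming `adapter`
    /// was obtained from the model's LoRA-loading API; a scale of `1.0` applies
    /// the adapter at full strength:
    /// ```ignore
    /// ctx.lora_adapter_set(&mut adapter, 1.0)?;
    /// ```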
    pub fn lora_adapter_set(
        &self,
        adapter: &mut LlamaLoraAdapter,
        scale: f32,
    ) -> Result<(), LlamaLoraAdapterSetError> {
        let err_code = unsafe {
            llama_cpp_sys_2::llama_set_adapter_lora(
                self.context.as_ptr(),
                adapter.lora_adapter.as_ptr(),
                scale,
            )
        };
        if err_code != 0 {
            return Err(LlamaLoraAdapterSetError::ErrorResult(err_code));
        }

        tracing::debug!("Set lora adapter");
        Ok(())
    }

    /// Removes a LoRA adapter.
    ///
    /// # Errors
    ///
    /// See [`LlamaLoraAdapterRemoveError`] for more information.
    pub fn lora_adapter_remove(
        &self,
        adapter: &mut LlamaLoraAdapter,
    ) -> Result<(), LlamaLoraAdapterRemoveError> {
        let err_code = unsafe {
            llama_cpp_sys_2::llama_rm_adapter_lora(
                self.context.as_ptr(),
                adapter.lora_adapter.as_ptr(),
            )
        };
        if err_code != 0 {
            return Err(LlamaLoraAdapterRemoveError::ErrorResult(err_code));
        }

        tracing::debug!("Remove lora adapter");
        Ok(())
    }
}

impl Drop for LlamaContext<'_> {
    fn drop(&mut self) {
        unsafe { llama_cpp_sys_2::llama_free(self.context.as_ptr()) }
    }
}