Skip to main content

llama_cpp_bindings/
context.rs

1//! Safe wrapper around `llama_context`.
2
3use std::ffi::c_void;
4use std::fmt::{Debug, Formatter};
5use std::num::NonZeroI32;
6use std::ptr::NonNull;
7use std::slice;
8use std::sync::Arc;
9use std::sync::atomic::AtomicBool;
10use std::sync::atomic::Ordering;
11
12use crate::context::params::LlamaContextParams;
13use crate::llama_backend::LlamaBackend;
14use crate::llama_batch::LlamaBatch;
15use crate::model::{LlamaLoraAdapter, LlamaModel};
16use crate::timing::LlamaTimings;
17use crate::token::LlamaToken;
18use crate::token::data::LlamaTokenData;
19use crate::token::data_array::LlamaTokenDataArray;
20use crate::{
21    DecodeError, EmbeddingsError, EncodeError, LlamaContextLoadError, LlamaLoraAdapterRemoveError,
22    LlamaLoraAdapterSetError, LogitsError,
23};
24
25const fn check_lora_set_result(err_code: i32) -> Result<(), LlamaLoraAdapterSetError> {
26    if err_code != 0 {
27        return Err(LlamaLoraAdapterSetError::ErrorResult(err_code));
28    }
29
30    Ok(())
31}
32
33const fn check_lora_remove_result(err_code: i32) -> Result<(), LlamaLoraAdapterRemoveError> {
34    if err_code != 0 {
35        return Err(LlamaLoraAdapterRemoveError::ErrorResult(err_code));
36    }
37
38    Ok(())
39}
40
41pub mod kv_cache;
42pub mod llama_state_seq_flags;
43pub mod load_seq_state_error;
44pub mod load_session_error;
45pub mod params;
46pub mod save_seq_state_error;
47pub mod save_session_error;
48pub mod session;
49
/// C-ABI abort callback handed to `llama_set_abort_callback`.
///
/// Reads the `AtomicBool` behind `data`. Returning `true` asks llama.cpp to abort
/// the current computation.
///
/// # Safety
///
/// `data` must point to a live `AtomicBool` for the duration of the call.
unsafe extern "C" fn abort_callback_trampoline(data: *mut c_void) -> bool {
    // SAFETY: the only pointer registered with this trampoline is
    // `Arc::as_ptr` of an `AtomicBool` that the owning `LlamaContext` keeps alive
    // while the callback is installed.
    let abort_requested: &AtomicBool = unsafe { &*data.cast::<AtomicBool>() };

    abort_requested.load(Ordering::Relaxed)
}
55
/// Safe wrapper around `llama_context`.
///
/// The wrapper borrows its [`LlamaModel`] for `'model`, so it cannot outlive the
/// model it was created from. The wrapped context is released with `llama_free`
/// when this value is dropped.
pub struct LlamaContext<'model> {
    /// Raw pointer to the underlying `llama_context`.
    pub context: NonNull<llama_cpp_bindings_sys::llama_context>,
    /// A reference to the context's model.
    pub model: &'model LlamaModel,
    /// Keeps the flag registered through `set_abort_flag` alive while llama.cpp holds
    /// a raw pointer to it; `None` when no abort callback has been installed by this
    /// wrapper (initially, and after `clear_abort_callback`).
    abort_flag: Option<Arc<AtomicBool>>,
    /// Batch indices whose logits may be read through `get_logits_ith`; replaced after
    /// each successful `decode`/`encode` or by `mark_logits_initialized`.
    initialized_logits: Vec<i32>,
    /// Whether embeddings are enabled for this context; the `embeddings_*` getters
    /// return `EmbeddingsError::NotEnabled` when this is `false`.
    embeddings_enabled: bool,
}
66
67impl Debug for LlamaContext<'_> {
68    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
69        f.debug_struct("LlamaContext")
70            .field("context", &self.context)
71            .finish()
72    }
73}
74
75impl<'model> LlamaContext<'model> {
76    /// Wraps existing raw pointers into a new `LlamaContext`.
77    #[must_use]
78    pub const fn new(
79        llama_model: &'model LlamaModel,
80        llama_context: NonNull<llama_cpp_bindings_sys::llama_context>,
81        embeddings_enabled: bool,
82    ) -> Self {
83        Self {
84            context: llama_context,
85            model: llama_model,
86            abort_flag: None,
87            initialized_logits: Vec::new(),
88            embeddings_enabled,
89        }
90    }
91
92    /// Create a new context bound to `model`.
93    ///
94    /// `_backend` is unused in the body but serves as a compile-time witness that
95    /// the global llama.cpp backend has been initialised before context creation.
96    ///
97    /// # Errors
98    ///
99    /// Returns [`LlamaContextLoadError`] when llama.cpp fails to allocate the context.
100    #[expect(
101        clippy::needless_pass_by_value,
102        reason = "LlamaContextParams may become non-trivially copyable upstream"
103    )]
104    pub fn from_model(
105        model: &'model LlamaModel,
106        _backend: &LlamaBackend,
107        params: LlamaContextParams,
108    ) -> Result<Self, LlamaContextLoadError> {
109        let context_params = params.context_params;
110        let context = unsafe {
111            llama_cpp_bindings_sys::llama_new_context_with_model(
112                model.model.as_ptr(),
113                context_params,
114            )
115        };
116        let context = NonNull::new(context).ok_or(LlamaContextLoadError::NullReturn)?;
117
118        Ok(Self::new(model, context, params.embeddings()))
119    }
120
121    /// Gets the max number of logical tokens that can be submitted to decode. Must be greater than or equal to [`Self::n_ubatch`].
122    #[must_use]
123    pub fn n_batch(&self) -> u32 {
124        unsafe { llama_cpp_bindings_sys::llama_n_batch(self.context.as_ptr()) }
125    }
126
127    /// Gets the max number of physical tokens (hardware level) to decode in batch. Must be less than or equal to [`Self::n_batch`].
128    #[must_use]
129    pub fn n_ubatch(&self) -> u32 {
130        unsafe { llama_cpp_bindings_sys::llama_n_ubatch(self.context.as_ptr()) }
131    }
132
133    /// Gets the size of the context.
134    #[must_use]
135    pub fn n_ctx(&self) -> u32 {
136        unsafe { llama_cpp_bindings_sys::llama_n_ctx(self.context.as_ptr()) }
137    }
138
139    /// Sets an abort flag that llama.cpp checks during computation.
140    ///
141    /// When the flag is set to `true`, any in-progress `decode()` call will
142    /// abort and return `DecodeError::Aborted`. The `Arc` is stored internally
143    /// to ensure the flag outlives the callback registration.
144    #[expect(unsafe_code, reason = "required for FFI abort callback registration")]
145    pub fn set_abort_flag(&mut self, flag: Arc<AtomicBool>) {
146        let raw_ptr = Arc::as_ptr(&flag) as *mut c_void;
147        self.abort_flag = Some(flag);
148
149        unsafe {
150            llama_cpp_bindings_sys::llama_set_abort_callback(
151                self.context.as_ptr(),
152                Some(abort_callback_trampoline),
153                raw_ptr,
154            );
155        }
156    }
157
158    /// Clears the abort callback so that decode calls are no longer interruptible.
159    #[expect(unsafe_code, reason = "required for FFI abort callback deregistration")]
160    pub fn clear_abort_callback(&mut self) {
161        self.abort_flag = None;
162
163        unsafe {
164            llama_cpp_bindings_sys::llama_set_abort_callback(
165                self.context.as_ptr(),
166                None,
167                std::ptr::null_mut(),
168            );
169        }
170    }
171
172    /// Waits for all pending backend operations to complete.
173    ///
174    /// Must be called before freeing the context to prevent hangs
175    /// during resource cleanup.
176    #[expect(unsafe_code, reason = "required for FFI synchronization call")]
177    pub fn synchronize(&self) {
178        unsafe { llama_cpp_bindings_sys::llama_synchronize(self.context.as_ptr()) }
179    }
180
181    /// Detaches the threadpool from the context.
182    ///
183    /// Must be called before freeing the context to prevent threadpool
184    /// workers from accessing freed resources.
185    #[expect(unsafe_code, reason = "required for FFI threadpool detachment")]
186    pub fn detach_threadpool(&self) {
187        unsafe { llama_cpp_bindings_sys::llama_detach_threadpool(self.context.as_ptr()) }
188    }
189
190    /// Marks a logit index as initialized so it can be read via
191    /// `get_logits_ith`. Use after external decode operations (like
192    /// `eval_chunks`) that bypass the Rust `decode()` method.
193    pub fn mark_logits_initialized(&mut self, token_index: i32) {
194        self.initialized_logits = vec![token_index];
195    }
196
197    /// Decodes the batch.
198    ///
199    /// # Errors
200    ///
201    /// - `DecodeError` if the decoding failed.
202    ///
203    /// # Panics
204    ///
205    /// - the returned [`std::ffi::c_int`] from llama-cpp does not fit into a i32 (this should never happen on most systems)
206    pub fn decode(&mut self, batch: &mut LlamaBatch) -> Result<(), DecodeError> {
207        let result = unsafe {
208            llama_cpp_bindings_sys::llama_decode(self.context.as_ptr(), batch.llama_batch)
209        };
210
211        match NonZeroI32::new(result) {
212            None => {
213                self.initialized_logits
214                    .clone_from(&batch.initialized_logits);
215                Ok(())
216            }
217            Some(error) => Err(DecodeError::from(error)),
218        }
219    }
220
221    /// Encodes the batch.
222    ///
223    /// # Errors
224    ///
225    /// - `EncodeError` if the decoding failed.
226    ///
227    /// # Panics
228    ///
229    /// - the returned [`std::ffi::c_int`] from llama-cpp does not fit into a i32 (this should never happen on most systems)
230    pub fn encode(&mut self, batch: &mut LlamaBatch) -> Result<(), EncodeError> {
231        let status = unsafe {
232            llama_cpp_bindings_sys::llama_rs_encode(self.context.as_ptr(), batch.llama_batch)
233        };
234
235        self.handle_encode_result(status, batch)
236    }
237
238    fn handle_encode_result(
239        &mut self,
240        status: llama_cpp_bindings_sys::llama_rs_status,
241        batch: &mut LlamaBatch,
242    ) -> Result<(), EncodeError> {
243        if crate::status_is_ok(status) {
244            self.initialized_logits
245                .clone_from(&batch.initialized_logits);
246
247            Ok(())
248        } else {
249            Err(EncodeError::from(
250                NonZeroI32::new(crate::status_to_i32(status))
251                    .unwrap_or(NonZeroI32::new(1).expect("1 is non-zero")),
252            ))
253        }
254    }
255
256    /// Get the embeddings for the given sequence in the current context.
257    ///
258    /// # Returns
259    ///
260    /// A slice containing the embeddings for the last decoded batch.
261    /// The size corresponds to the `n_embd` parameter of the context's model.
262    ///
263    /// # Errors
264    ///
265    /// - When the current context was constructed without enabling embeddings.
266    /// - If the current model had a pooling type of [`llama_cpp_bindings_sys::LLAMA_POOLING_TYPE_NONE`]
267    /// - If the given sequence index exceeds the max sequence id.
268    ///
269    pub fn embeddings_seq_ith(&self, sequence_index: i32) -> Result<&[f32], EmbeddingsError> {
270        if !self.embeddings_enabled {
271            return Err(EmbeddingsError::NotEnabled);
272        }
273
274        let n_embd = usize::try_from(self.model.n_embd())
275            .map_err(EmbeddingsError::InvalidEmbeddingDimension)?;
276
277        unsafe {
278            let embedding = llama_cpp_bindings_sys::llama_get_embeddings_seq(
279                self.context.as_ptr(),
280                sequence_index,
281            );
282
283            if embedding.is_null() {
284                Err(EmbeddingsError::NonePoolType)
285            } else {
286                Ok(slice::from_raw_parts(embedding, n_embd))
287            }
288        }
289    }
290
291    /// Get the embeddings for the given token in the current context.
292    ///
293    /// # Returns
294    ///
295    /// A slice containing the embeddings for the last decoded batch of the given token.
296    /// The size corresponds to the `n_embd` parameter of the context's model.
297    ///
298    /// # Errors
299    ///
300    /// - When the current context was constructed without enabling embeddings.
301    /// - When the given token didn't have logits enabled when it was passed.
302    /// - If the given token index exceeds the max token id.
303    ///
304    pub fn embeddings_ith(&self, token_index: i32) -> Result<&[f32], EmbeddingsError> {
305        if !self.embeddings_enabled {
306            return Err(EmbeddingsError::NotEnabled);
307        }
308
309        let n_embd = usize::try_from(self.model.n_embd())
310            .map_err(EmbeddingsError::InvalidEmbeddingDimension)?;
311
312        unsafe {
313            let embedding = llama_cpp_bindings_sys::llama_get_embeddings_ith(
314                self.context.as_ptr(),
315                token_index,
316            );
317
318            if embedding.is_null() {
319                Err(EmbeddingsError::LogitsNotEnabled)
320            } else {
321                Ok(slice::from_raw_parts(embedding, n_embd))
322            }
323        }
324    }
325
326    /// Get the logits for the last token in the context.
327    ///
328    /// # Returns
329    /// An iterator over unsorted `LlamaTokenData` containing the
330    /// logits for the last token in the context.
331    ///
332    /// # Errors
333    /// Returns `LogitsError` if logits are null or `n_vocab` overflows.
334    pub fn candidates(&self) -> Result<impl Iterator<Item = LlamaTokenData> + '_, LogitsError> {
335        let logits = self.get_logits()?;
336
337        Ok((0_i32..).zip(logits).map(|(token_id, logit)| {
338            let token = LlamaToken::new(token_id);
339            LlamaTokenData::new(token, *logit, 0_f32)
340        }))
341    }
342
343    /// Get the token data array for the last token in the context.
344    ///
345    /// # Errors
346    /// Returns `LogitsError` if logits are null or `n_vocab` overflows.
347    pub fn token_data_array(&self) -> Result<LlamaTokenDataArray, LogitsError> {
348        Ok(LlamaTokenDataArray::from_iter(self.candidates()?, false))
349    }
350
351    /// Token logits obtained from the last call to `decode()`.
352    /// The logits for which `batch.logits[i] != 0` are stored contiguously
353    /// in the order they have appeared in the batch.
354    /// Rows: number of tokens for which `batch.logits[i] != 0`
355    /// Cols: `n_vocab`
356    ///
357    /// # Returns
358    ///
359    /// A slice containing the logits for the last decoded token.
360    /// The size corresponds to the `n_vocab` parameter of the context's model.
361    ///
362    /// # Errors
363    /// Returns `LogitsError` if the logits pointer is null or `n_vocab` overflows.
364    pub fn get_logits(&self) -> Result<&[f32], LogitsError> {
365        let data = unsafe { llama_cpp_bindings_sys::llama_get_logits(self.context.as_ptr()) };
366
367        if data.is_null() {
368            return Err(LogitsError::NullLogits);
369        }
370
371        let len = usize::try_from(self.model.n_vocab()).map_err(LogitsError::VocabSizeOverflow)?;
372
373        Ok(unsafe { slice::from_raw_parts(data, len) })
374    }
375
376    /// Get the logits for the ith token in the context.
377    ///
378    /// # Errors
379    /// Returns `LogitsError` if the token is not initialized or out of range.
380    pub fn candidates_ith(
381        &self,
382        token_index: i32,
383    ) -> Result<impl Iterator<Item = LlamaTokenData> + '_, LogitsError> {
384        let logits = self.get_logits_ith(token_index)?;
385
386        Ok((0_i32..).zip(logits).map(|(token_id, logit)| {
387            let token = LlamaToken::new(token_id);
388            LlamaTokenData::new(token, *logit, 0_f32)
389        }))
390    }
391
392    /// Get the token data array for the ith token in the context.
393    ///
394    /// # Errors
395    /// Returns `LogitsError` if the token is not initialized or out of range.
396    pub fn token_data_array_ith(
397        &self,
398        token_index: i32,
399    ) -> Result<LlamaTokenDataArray, LogitsError> {
400        Ok(LlamaTokenDataArray::from_iter(
401            self.candidates_ith(token_index)?,
402            false,
403        ))
404    }
405
406    /// Get the logits for the ith token in the context.
407    ///
408    /// # Errors
409    /// Returns `LogitsError` if the token is not initialized, out of range, or `n_vocab` overflows.
410    pub fn get_logits_ith(&self, token_index: i32) -> Result<&[f32], LogitsError> {
411        if !self.initialized_logits.contains(&token_index) {
412            return Err(LogitsError::TokenNotInitialized(token_index));
413        }
414
415        if token_index >= 0 {
416            let token_index_u32 =
417                u32::try_from(token_index).map_err(LogitsError::TokenIndexOverflow)?;
418
419            if self.n_ctx() <= token_index_u32 {
420                return Err(LogitsError::TokenIndexExceedsContext {
421                    token_index: token_index_u32,
422                    context_size: self.n_ctx(),
423                });
424            }
425        }
426
427        let data = unsafe {
428            llama_cpp_bindings_sys::llama_get_logits_ith(self.context.as_ptr(), token_index)
429        };
430        let len = usize::try_from(self.model.n_vocab()).map_err(LogitsError::VocabSizeOverflow)?;
431
432        Ok(unsafe { slice::from_raw_parts(data, len) })
433    }
434
435    /// Reset the timings for the context.
436    pub fn reset_timings(&mut self) {
437        unsafe { llama_cpp_bindings_sys::llama_perf_context_reset(self.context.as_ptr()) }
438    }
439
440    /// Returns the timings for the context.
441    pub fn timings(&mut self) -> LlamaTimings {
442        let timings = unsafe { llama_cpp_bindings_sys::llama_perf_context(self.context.as_ptr()) };
443        LlamaTimings { timings }
444    }
445
446    /// Sets a lora adapter.
447    ///
448    /// # Errors
449    ///
450    /// See [`LlamaLoraAdapterSetError`] for more information.
451    pub fn lora_adapter_set(
452        &self,
453        adapter: &mut LlamaLoraAdapter,
454        scale: f32,
455    ) -> Result<(), LlamaLoraAdapterSetError> {
456        let mut adapters = [adapter.lora_adapter.as_ptr()];
457        let mut scales = [scale];
458        let err_code = unsafe {
459            llama_cpp_bindings_sys::llama_set_adapters_lora(
460                self.context.as_ptr(),
461                adapters.as_mut_ptr(),
462                1,
463                scales.as_mut_ptr(),
464            )
465        };
466        check_lora_set_result(err_code)?;
467
468        tracing::debug!("Set lora adapter");
469        Ok(())
470    }
471
472    /// Remove all lora adapters.
473    ///
474    /// Note: The upstream API now replaces all adapters at once via
475    /// `llama_set_adapters_lora`. This clears all adapters from the context.
476    ///
477    /// # Errors
478    ///
479    /// See [`LlamaLoraAdapterRemoveError`] for more information.
480    pub fn lora_adapter_remove(
481        &self,
482        _adapter: &mut LlamaLoraAdapter,
483    ) -> Result<(), LlamaLoraAdapterRemoveError> {
484        let err_code = unsafe {
485            llama_cpp_bindings_sys::llama_set_adapters_lora(
486                self.context.as_ptr(),
487                std::ptr::null_mut(),
488                0,
489                std::ptr::null_mut(),
490            )
491        };
492        check_lora_remove_result(err_code)?;
493
494        tracing::debug!("Remove lora adapter");
495        Ok(())
496    }
497}
498
499impl Drop for LlamaContext<'_> {
500    fn drop(&mut self) {
501        unsafe { llama_cpp_bindings_sys::llama_free(self.context.as_ptr()) }
502    }
503}
504
#[cfg(test)]
mod unit_tests {
    use super::{check_lora_remove_result, check_lora_set_result};
    use crate::{LlamaLoraAdapterRemoveError, LlamaLoraAdapterSetError};

    #[test]
    fn check_lora_set_result_ok_for_zero() {
        assert_eq!(check_lora_set_result(0), Ok(()));
    }

    #[test]
    fn check_lora_set_result_error_for_nonzero() {
        assert_eq!(
            check_lora_set_result(-1),
            Err(LlamaLoraAdapterSetError::ErrorResult(-1))
        );
    }

    #[test]
    fn check_lora_remove_result_ok_for_zero() {
        assert_eq!(check_lora_remove_result(0), Ok(()));
    }

    #[test]
    fn check_lora_remove_result_error_for_nonzero() {
        assert_eq!(
            check_lora_remove_result(-1),
            Err(LlamaLoraAdapterRemoveError::ErrorResult(-1))
        );
    }
}