Skip to main content

llama_cpp_bindings/
context.rs

1//! Safe wrapper around `llama_context`.
2
3use std::ffi::c_void;
4use std::fmt::{Debug, Formatter};
5use std::num::NonZeroI32;
6use std::ptr::NonNull;
7use std::slice;
8use std::sync::Arc;
9use std::sync::atomic::AtomicBool;
10use std::sync::atomic::Ordering;
11
12use crate::context::params::LlamaContextParams;
13use crate::llama_backend::LlamaBackend;
14use crate::llama_batch::LlamaBatch;
15use crate::model::{LlamaLoraAdapter, LlamaModel};
16use crate::timing::LlamaTimings;
17use crate::token::LlamaToken;
18use crate::token::data::LlamaTokenData;
19use crate::token::data_array::LlamaTokenDataArray;
20use crate::{
21    DecodeError, EmbeddingsError, EncodeError, LlamaContextLoadError, LlamaLoraAdapterRemoveError,
22    LlamaLoraAdapterSetError, LogitsError,
23};
24
25const fn check_lora_set_result(err_code: i32) -> Result<(), LlamaLoraAdapterSetError> {
26    if err_code != 0 {
27        return Err(LlamaLoraAdapterSetError::ErrorResult(err_code));
28    }
29
30    Ok(())
31}
32
33const fn check_lora_remove_result(err_code: i32) -> Result<(), LlamaLoraAdapterRemoveError> {
34    if err_code != 0 {
35        return Err(LlamaLoraAdapterRemoveError::ErrorResult(err_code));
36    }
37
38    Ok(())
39}
40
41pub mod kv_cache;
42pub mod kv_cache_type;
43pub mod llama_attention_type;
44pub mod llama_pooling_type;
45pub mod llama_state_seq_flags;
46pub mod load_seq_state_error;
47pub mod load_session_error;
48pub mod params;
49pub mod rope_scaling_type;
50pub mod save_seq_state_error;
51pub mod save_session_error;
52pub mod session;
53
/// C-ABI callback handed to `llama_set_abort_callback`.
///
/// Reads the `AtomicBool` that `data` points at and returns its value; a `true`
/// result asks llama.cpp to abort the computation currently in progress.
///
/// # Safety
///
/// `data` must point to a live `AtomicBool` for the whole duration of the call.
unsafe extern "C" fn abort_callback_trampoline(data: *mut c_void) -> bool {
    let flag_ptr: *const AtomicBool = data.cast_const().cast();

    // SAFETY: `data` is the pointer registered by `LlamaContext::set_abort_flag`,
    // which keeps the owning `Arc` alive while the callback is registered.
    unsafe { (*flag_ptr).load(Ordering::Relaxed) }
}
59
/// Safe wrapper around `llama_context`.
///
/// The context borrows its [`LlamaModel`] for `'model`, so the model cannot be
/// dropped while the context is alive. The wrapped `llama_context` is released
/// with `llama_free` in this type's `Drop` impl.
pub struct LlamaContext<'model> {
    /// Raw pointer to the underlying `llama_context`.
    pub context: NonNull<llama_cpp_bindings_sys::llama_context>,
    /// A reference to the context's model.
    pub model: &'model LlamaModel,
    // Flag read by the registered abort callback. Holding the `Arc` here keeps the
    // pointer handed to llama.cpp valid while the callback is registered.
    abort_flag: Option<Arc<AtomicBool>>,
    // Logit indices that `get_logits_ith` is allowed to read; replaced after each
    // successful `decode`/`encode` and by `mark_logits_initialized`.
    initialized_logits: Vec<i32>,
    // Whether embeddings were requested at construction; the `embeddings_*`
    // accessors return `EmbeddingsError::NotEnabled` when this is `false`.
    embeddings_enabled: bool,
}
70
71impl Debug for LlamaContext<'_> {
72    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
73        f.debug_struct("LlamaContext")
74            .field("context", &self.context)
75            .finish()
76    }
77}
78
79impl<'model> LlamaContext<'model> {
80    /// Wraps existing raw pointers into a new `LlamaContext`.
81    #[must_use]
82    pub const fn new(
83        llama_model: &'model LlamaModel,
84        llama_context: NonNull<llama_cpp_bindings_sys::llama_context>,
85        embeddings_enabled: bool,
86    ) -> Self {
87        Self {
88            context: llama_context,
89            model: llama_model,
90            abort_flag: None,
91            initialized_logits: Vec::new(),
92            embeddings_enabled,
93        }
94    }
95
96    /// Create a new context bound to `model`.
97    ///
98    /// `_backend` is unused in the body but serves as a compile-time witness that
99    /// the global llama.cpp backend has been initialised before context creation.
100    ///
101    /// # Errors
102    ///
103    /// Returns [`LlamaContextLoadError`] when llama.cpp fails to allocate the context.
104    #[expect(
105        clippy::needless_pass_by_value,
106        reason = "LlamaContextParams may become non-trivially copyable upstream"
107    )]
108    pub fn from_model(
109        model: &'model LlamaModel,
110        _backend: &LlamaBackend,
111        params: LlamaContextParams,
112    ) -> Result<Self, LlamaContextLoadError> {
113        let context_params = params.context_params;
114        let mut out_ctx: *mut llama_cpp_bindings_sys::llama_context = std::ptr::null_mut();
115        let mut out_error: *mut std::os::raw::c_char = std::ptr::null_mut();
116        let status = unsafe {
117            llama_cpp_bindings_sys::llama_rs_new_context_with_model(
118                model.model.as_ptr(),
119                context_params,
120                &raw mut out_ctx,
121                &raw mut out_error,
122            )
123        };
124        match status {
125            llama_cpp_bindings_sys::LLAMA_RS_NEW_CONTEXT_WITH_MODEL_OK => {
126                let context = NonNull::new(out_ctx)
127                    .ok_or(LlamaContextLoadError::Unconstructible)?;
128                Ok(Self::new(model, context, params.embeddings()))
129            }
130            llama_cpp_bindings_sys::LLAMA_RS_NEW_CONTEXT_WITH_MODEL_VENDORED_RETURNED_NULL => {
131                Err(LlamaContextLoadError::Unconstructible)
132            }
133            llama_cpp_bindings_sys::LLAMA_RS_NEW_CONTEXT_WITH_MODEL_ERROR_STRING_ALLOCATION_FAILED => {
134                Err(LlamaContextLoadError::NotEnoughMemory)
135            }
136            llama_cpp_bindings_sys::LLAMA_RS_NEW_CONTEXT_WITH_MODEL_VENDORED_THREW_CXX_EXCEPTION => {
137                let message = unsafe { crate::ffi_error_reader::read_and_free_cpp_error(out_error) };
138                Err(LlamaContextLoadError::Reported { message })
139            }
140            other => unreachable!(
141                "llama_rs_new_context_with_model returned unrecognized status {other}"
142            ),
143        }
144    }
145
146    /// Gets the max number of logical tokens that can be submitted to decode. Must be greater than or equal to [`Self::n_ubatch`].
147    #[must_use]
148    pub fn n_batch(&self) -> u32 {
149        unsafe { llama_cpp_bindings_sys::llama_n_batch(self.context.as_ptr()) }
150    }
151
152    /// Gets the max number of physical tokens (hardware level) to decode in batch. Must be less than or equal to [`Self::n_batch`].
153    #[must_use]
154    pub fn n_ubatch(&self) -> u32 {
155        unsafe { llama_cpp_bindings_sys::llama_n_ubatch(self.context.as_ptr()) }
156    }
157
158    /// Gets the size of the context.
159    #[must_use]
160    pub fn n_ctx(&self) -> u32 {
161        unsafe { llama_cpp_bindings_sys::llama_n_ctx(self.context.as_ptr()) }
162    }
163
164    /// Sets an abort flag that llama.cpp checks during computation.
165    ///
166    /// When the flag is set to `true`, any in-progress `decode()` call will
167    /// abort and return `DecodeError::Aborted`. The `Arc` is stored internally
168    /// to ensure the flag outlives the callback registration.
169    #[expect(unsafe_code, reason = "required for FFI abort callback registration")]
170    pub fn set_abort_flag(&mut self, flag: Arc<AtomicBool>) {
171        let raw_ptr = Arc::as_ptr(&flag) as *mut c_void;
172        self.abort_flag = Some(flag);
173
174        unsafe {
175            llama_cpp_bindings_sys::llama_set_abort_callback(
176                self.context.as_ptr(),
177                Some(abort_callback_trampoline),
178                raw_ptr,
179            );
180        }
181    }
182
183    /// Clears the abort callback so that decode calls are no longer interruptible.
184    #[expect(unsafe_code, reason = "required for FFI abort callback deregistration")]
185    pub fn clear_abort_callback(&mut self) {
186        self.abort_flag = None;
187
188        unsafe {
189            llama_cpp_bindings_sys::llama_set_abort_callback(
190                self.context.as_ptr(),
191                None,
192                std::ptr::null_mut(),
193            );
194        }
195    }
196
197    /// Waits for all pending backend operations to complete.
198    ///
199    /// Must be called before freeing the context to prevent hangs
200    /// during resource cleanup.
201    #[expect(unsafe_code, reason = "required for FFI synchronization call")]
202    pub fn synchronize(&self) {
203        unsafe { llama_cpp_bindings_sys::llama_synchronize(self.context.as_ptr()) }
204    }
205
206    /// Detaches the threadpool from the context.
207    ///
208    /// Must be called before freeing the context to prevent threadpool
209    /// workers from accessing freed resources.
210    #[expect(unsafe_code, reason = "required for FFI threadpool detachment")]
211    pub fn detach_threadpool(&self) {
212        unsafe { llama_cpp_bindings_sys::llama_detach_threadpool(self.context.as_ptr()) }
213    }
214
215    /// Marks a logit index as initialized so it can be read via
216    /// `get_logits_ith`. Use after external decode operations (like
217    /// `eval_chunks`) that bypass the Rust `decode()` method.
218    pub fn mark_logits_initialized(&mut self, token_index: i32) {
219        self.initialized_logits = vec![token_index];
220    }
221
222    /// Decodes the batch.
223    ///
224    /// # Errors
225    ///
226    /// - `DecodeError` if the decoding failed.
227    pub fn decode(&mut self, batch: &mut LlamaBatch) -> Result<(), DecodeError> {
228        let mut out_vendored_return_code: i32 = 0;
229        let mut out_error: *mut std::os::raw::c_char = std::ptr::null_mut();
230        let status = unsafe {
231            llama_cpp_bindings_sys::llama_rs_decode(
232                self.context.as_ptr(),
233                batch.llama_batch,
234                &raw mut out_vendored_return_code,
235                &raw mut out_error,
236            )
237        };
238        match status {
239            llama_cpp_bindings_sys::LLAMA_RS_DECODE_OK => {
240                self.initialized_logits
241                    .clone_from(&batch.initialized_logits);
242                Ok(())
243            }
244            llama_cpp_bindings_sys::LLAMA_RS_DECODE_VENDORED_RETURNED_NONZERO_CODE => {
245                let code = NonZeroI32::new(out_vendored_return_code).unwrap_or_else(|| {
246                    unreachable!(
247                        "llama_rs_decode reported a nonzero return code but the value was zero"
248                    )
249                });
250                Err(DecodeError::from(code))
251            }
252            llama_cpp_bindings_sys::LLAMA_RS_DECODE_OUT_OF_MEMORY => {
253                Err(DecodeError::DecodeOutOfMemory)
254            }
255            llama_cpp_bindings_sys::LLAMA_RS_DECODE_COMPUTE_FAILED => {
256                Err(DecodeError::ComputeFailed)
257            }
258            llama_cpp_bindings_sys::LLAMA_RS_DECODE_ERROR_STRING_ALLOCATION_FAILED => {
259                Err(DecodeError::NotEnoughMemory)
260            }
261            llama_cpp_bindings_sys::LLAMA_RS_DECODE_VENDORED_THREW_CXX_EXCEPTION => {
262                let message =
263                    unsafe { crate::ffi_error_reader::read_and_free_cpp_error(out_error) };
264                Err(DecodeError::Reported { message })
265            }
266            other => unreachable!("llama_rs_decode returned unrecognized status {other}"),
267        }
268    }
269
270    /// Encodes the batch.
271    ///
272    /// # Errors
273    ///
274    /// - `EncodeError` if the encoding failed.
275    pub fn encode(&mut self, batch: &mut LlamaBatch) -> Result<(), EncodeError> {
276        let mut out_vendored_return_code: i32 = 0;
277        let mut out_error: *mut std::os::raw::c_char = std::ptr::null_mut();
278        let status = unsafe {
279            llama_cpp_bindings_sys::llama_rs_encode(
280                self.context.as_ptr(),
281                batch.llama_batch,
282                &raw mut out_vendored_return_code,
283                &raw mut out_error,
284            )
285        };
286        match status {
287            llama_cpp_bindings_sys::LLAMA_RS_ENCODE_OK => {
288                self.initialized_logits
289                    .clone_from(&batch.initialized_logits);
290                Ok(())
291            }
292            llama_cpp_bindings_sys::LLAMA_RS_ENCODE_MODEL_HAS_NO_ENCODER => {
293                Err(EncodeError::ModelHasNoEncoder)
294            }
295            llama_cpp_bindings_sys::LLAMA_RS_ENCODE_VENDORED_RETURNED_NONZERO_CODE => {
296                let code = NonZeroI32::new(out_vendored_return_code).unwrap_or_else(|| {
297                    unreachable!(
298                        "llama_rs_encode reported a nonzero return code but the value was zero"
299                    )
300                });
301                Err(EncodeError::from(code))
302            }
303            llama_cpp_bindings_sys::LLAMA_RS_ENCODE_OUT_OF_MEMORY => {
304                Err(EncodeError::EncodeOutOfMemory)
305            }
306            llama_cpp_bindings_sys::LLAMA_RS_ENCODE_COMPUTE_FAILED => {
307                Err(EncodeError::ComputeFailed)
308            }
309            llama_cpp_bindings_sys::LLAMA_RS_ENCODE_ERROR_STRING_ALLOCATION_FAILED => {
310                Err(EncodeError::NotEnoughMemory)
311            }
312            llama_cpp_bindings_sys::LLAMA_RS_ENCODE_VENDORED_THREW_CXX_EXCEPTION => {
313                let message =
314                    unsafe { crate::ffi_error_reader::read_and_free_cpp_error(out_error) };
315                Err(EncodeError::Reported { message })
316            }
317            other => unreachable!("llama_rs_encode returned unrecognized status {other}"),
318        }
319    }
320
321    /// Get the embeddings for the given sequence in the current context.
322    ///
323    /// # Returns
324    ///
325    /// A slice containing the embeddings for the last decoded batch.
326    /// The size corresponds to the `n_embd` parameter of the context's model.
327    ///
328    /// # Errors
329    ///
330    /// - When the current context was constructed without enabling embeddings.
331    /// - If the current model had a pooling type of [`llama_cpp_bindings_sys::LLAMA_POOLING_TYPE_NONE`]
332    /// - If the given sequence index exceeds the max sequence id.
333    ///
334    pub fn embeddings_seq_ith(&self, sequence_index: i32) -> Result<&[f32], EmbeddingsError> {
335        if !self.embeddings_enabled {
336            return Err(EmbeddingsError::NotEnabled);
337        }
338
339        let n_embd = usize::try_from(self.model.n_embd())
340            .map_err(EmbeddingsError::InvalidEmbeddingDimension)?;
341
342        unsafe {
343            let embedding = llama_cpp_bindings_sys::llama_get_embeddings_seq(
344                self.context.as_ptr(),
345                sequence_index,
346            );
347
348            if embedding.is_null() {
349                Err(EmbeddingsError::NonePoolType)
350            } else {
351                Ok(slice::from_raw_parts(embedding, n_embd))
352            }
353        }
354    }
355
356    /// Get the embeddings for the given token in the current context.
357    ///
358    /// # Returns
359    ///
360    /// A slice containing the embeddings for the last decoded batch of the given token.
361    /// The size corresponds to the `n_embd` parameter of the context's model.
362    ///
363    /// # Errors
364    ///
365    /// - When the current context was constructed without enabling embeddings.
366    /// - When the given token didn't have logits enabled when it was passed.
367    /// - If the given token index exceeds the max token id.
368    ///
369    pub fn embeddings_ith(&self, token_index: i32) -> Result<&[f32], EmbeddingsError> {
370        if !self.embeddings_enabled {
371            return Err(EmbeddingsError::NotEnabled);
372        }
373
374        let n_embd = usize::try_from(self.model.n_embd())
375            .map_err(EmbeddingsError::InvalidEmbeddingDimension)?;
376
377        unsafe {
378            let embedding = llama_cpp_bindings_sys::llama_get_embeddings_ith(
379                self.context.as_ptr(),
380                token_index,
381            );
382
383            if embedding.is_null() {
384                Err(EmbeddingsError::LogitsNotEnabled)
385            } else {
386                Ok(slice::from_raw_parts(embedding, n_embd))
387            }
388        }
389    }
390
391    /// Get the logits for the last token in the context.
392    ///
393    /// # Returns
394    /// An iterator over unsorted `LlamaTokenData` containing the
395    /// logits for the last token in the context.
396    ///
397    /// # Errors
398    /// Returns `LogitsError` if logits are null or `n_vocab` overflows.
399    pub fn candidates(&self) -> Result<impl Iterator<Item = LlamaTokenData> + '_, LogitsError> {
400        let logits = self.get_logits()?;
401
402        Ok((0_i32..).zip(logits).map(|(token_id, logit)| {
403            let token = LlamaToken::new(token_id);
404            LlamaTokenData::new(token, *logit, 0_f32)
405        }))
406    }
407
408    /// Get the token data array for the last token in the context.
409    ///
410    /// # Errors
411    /// Returns `LogitsError` if logits are null or `n_vocab` overflows.
412    pub fn token_data_array(&self) -> Result<LlamaTokenDataArray, LogitsError> {
413        Ok(LlamaTokenDataArray::from_iter(self.candidates()?, false))
414    }
415
416    /// Token logits obtained from the last call to `decode()`.
417    /// The logits for which `batch.logits[i] != 0` are stored contiguously
418    /// in the order they have appeared in the batch.
419    /// Rows: number of tokens for which `batch.logits[i] != 0`
420    /// Cols: `n_vocab`
421    ///
422    /// # Returns
423    ///
424    /// A slice containing the logits for the last decoded token.
425    /// The size corresponds to the `n_vocab` parameter of the context's model.
426    ///
427    /// # Errors
428    /// Returns `LogitsError` if the logits pointer is null or `n_vocab` overflows.
429    pub fn get_logits(&self) -> Result<&[f32], LogitsError> {
430        let data = unsafe { llama_cpp_bindings_sys::llama_get_logits(self.context.as_ptr()) };
431
432        if data.is_null() {
433            return Err(LogitsError::NullLogits);
434        }
435
436        let len = usize::try_from(self.model.n_vocab()).map_err(LogitsError::VocabSizeOverflow)?;
437
438        Ok(unsafe { slice::from_raw_parts(data, len) })
439    }
440
441    /// Get the logits for the ith token in the context.
442    ///
443    /// # Errors
444    /// Returns `LogitsError` if the token is not initialized or out of range.
445    pub fn candidates_ith(
446        &self,
447        token_index: i32,
448    ) -> Result<impl Iterator<Item = LlamaTokenData> + '_, LogitsError> {
449        let logits = self.get_logits_ith(token_index)?;
450
451        Ok((0_i32..).zip(logits).map(|(token_id, logit)| {
452            let token = LlamaToken::new(token_id);
453            LlamaTokenData::new(token, *logit, 0_f32)
454        }))
455    }
456
457    /// Get the token data array for the ith token in the context.
458    ///
459    /// # Errors
460    /// Returns `LogitsError` if the token is not initialized or out of range.
461    pub fn token_data_array_ith(
462        &self,
463        token_index: i32,
464    ) -> Result<LlamaTokenDataArray, LogitsError> {
465        Ok(LlamaTokenDataArray::from_iter(
466            self.candidates_ith(token_index)?,
467            false,
468        ))
469    }
470
471    /// Get the logits for the ith token in the context.
472    ///
473    /// # Errors
474    /// Returns `LogitsError` if the token is not initialized, out of range, or `n_vocab` overflows.
475    pub fn get_logits_ith(&self, token_index: i32) -> Result<&[f32], LogitsError> {
476        if !self.initialized_logits.contains(&token_index) {
477            return Err(LogitsError::TokenNotInitialized(token_index));
478        }
479
480        if token_index >= 0 {
481            let token_index_u32 =
482                u32::try_from(token_index).map_err(LogitsError::TokenIndexOverflow)?;
483
484            if self.n_ctx() <= token_index_u32 {
485                return Err(LogitsError::TokenIndexExceedsContext {
486                    token_index: token_index_u32,
487                    context_size: self.n_ctx(),
488                });
489            }
490        }
491
492        let data = unsafe {
493            llama_cpp_bindings_sys::llama_get_logits_ith(self.context.as_ptr(), token_index)
494        };
495        let len = usize::try_from(self.model.n_vocab()).map_err(LogitsError::VocabSizeOverflow)?;
496
497        Ok(unsafe { slice::from_raw_parts(data, len) })
498    }
499
500    /// Reset the timings for the context.
501    pub fn reset_timings(&mut self) {
502        unsafe { llama_cpp_bindings_sys::llama_perf_context_reset(self.context.as_ptr()) }
503    }
504
505    /// Returns the timings for the context.
506    pub fn timings(&mut self) -> LlamaTimings {
507        let timings = unsafe { llama_cpp_bindings_sys::llama_perf_context(self.context.as_ptr()) };
508        LlamaTimings { timings }
509    }
510
511    /// Sets a lora adapter.
512    ///
513    /// # Errors
514    ///
515    /// See [`LlamaLoraAdapterSetError`] for more information.
516    pub fn lora_adapter_set(
517        &self,
518        adapter: &mut LlamaLoraAdapter,
519        scale: f32,
520    ) -> Result<(), LlamaLoraAdapterSetError> {
521        let mut adapters = [adapter.lora_adapter.as_ptr()];
522        let mut scales = [scale];
523        let err_code = unsafe {
524            llama_cpp_bindings_sys::llama_set_adapters_lora(
525                self.context.as_ptr(),
526                adapters.as_mut_ptr(),
527                1,
528                scales.as_mut_ptr(),
529            )
530        };
531        check_lora_set_result(err_code)?;
532
533        log::debug!("Set lora adapter");
534        Ok(())
535    }
536
537    /// Remove all lora adapters.
538    ///
539    /// Note: The upstream API now replaces all adapters at once via
540    /// `llama_set_adapters_lora`. This clears all adapters from the context.
541    ///
542    /// # Errors
543    ///
544    /// See [`LlamaLoraAdapterRemoveError`] for more information.
545    pub fn lora_adapter_remove(
546        &self,
547        _adapter: &mut LlamaLoraAdapter,
548    ) -> Result<(), LlamaLoraAdapterRemoveError> {
549        let err_code = unsafe {
550            llama_cpp_bindings_sys::llama_set_adapters_lora(
551                self.context.as_ptr(),
552                std::ptr::null_mut(),
553                0,
554                std::ptr::null_mut(),
555            )
556        };
557        check_lora_remove_result(err_code)?;
558
559        log::debug!("Remove lora adapter");
560        Ok(())
561    }
562}
563
564impl Drop for LlamaContext<'_> {
565    fn drop(&mut self) {
566        unsafe { llama_cpp_bindings_sys::llama_free(self.context.as_ptr()) }
567    }
568}
569
#[cfg(test)]
mod unit_tests {
    use super::{check_lora_remove_result, check_lora_set_result};
    use crate::{LlamaLoraAdapterRemoveError, LlamaLoraAdapterSetError};

    // A zero status from `llama_set_adapters_lora` must map to success.
    #[test]
    fn check_lora_set_result_ok_for_zero() {
        assert_eq!(check_lora_set_result(0), Ok(()));
    }

    // Any non-zero status is surfaced verbatim in the error variant.
    #[test]
    fn check_lora_set_result_error_for_nonzero() {
        assert_eq!(
            check_lora_set_result(-1),
            Err(LlamaLoraAdapterSetError::ErrorResult(-1))
        );
    }

    #[test]
    fn check_lora_remove_result_ok_for_zero() {
        assert_eq!(check_lora_remove_result(0), Ok(()));
    }

    #[test]
    fn check_lora_remove_result_error_for_nonzero() {
        assert_eq!(
            check_lora_remove_result(-1),
            Err(LlamaLoraAdapterRemoveError::ErrorResult(-1))
        );
    }
}