Skip to main content

llama_cpp_bindings/
context.rs

1use std::ffi::c_void;
2use std::fmt::{Debug, Formatter};
3use std::num::NonZeroI32;
4use std::ptr::NonNull;
5use std::slice;
6use std::sync::Arc;
7use std::sync::atomic::AtomicBool;
8use std::sync::atomic::Ordering;
9
10use crate::context::params::LlamaContextParams;
11use crate::llama_backend::LlamaBackend;
12use crate::llama_batch::LlamaBatch;
13use crate::model::{LlamaLoraAdapter, LlamaModel};
14use crate::timing::LlamaTimings;
15use crate::token::LlamaToken;
16use crate::token::data::LlamaTokenData;
17use crate::token::data_array::LlamaTokenDataArray;
18use crate::{
19    DecodeError, EmbeddingsError, EncodeError, LlamaContextLoadError, LlamaLoraAdapterRemoveError,
20    LlamaLoraAdapterSetError, LogitsError,
21};
22
23const fn check_lora_set_result(err_code: i32) -> Result<(), LlamaLoraAdapterSetError> {
24    if err_code != 0 {
25        return Err(LlamaLoraAdapterSetError::ErrorResult(err_code));
26    }
27
28    Ok(())
29}
30
31const fn check_lora_remove_result(err_code: i32) -> Result<(), LlamaLoraAdapterRemoveError> {
32    if err_code != 0 {
33        return Err(LlamaLoraAdapterRemoveError::ErrorResult(err_code));
34    }
35
36    Ok(())
37}
38
39pub mod kv_cache;
40pub mod kv_cache_type;
41pub mod llama_attention_type;
42pub mod llama_pooling_type;
43pub mod llama_state_seq_flags;
44pub mod load_seq_state_error;
45pub mod load_session_error;
46pub mod params;
47pub mod rope_scaling_type;
48pub mod save_seq_state_error;
49pub mod save_session_error;
50pub mod session;
51
/// C-ABI callback registered with `llama_set_abort_callback`.
///
/// Reads the `AtomicBool` that `data` points to and returns its current value; the
/// value is returned as-is as the callback's result.
///
/// # Safety
///
/// `data` must be a non-null pointer to an `AtomicBool` that stays alive for as long as
/// this callback remains registered.
unsafe extern "C" fn abort_callback_trampoline(data: *mut c_void) -> bool {
    // SAFETY: the caller guarantees `data` points to a live `AtomicBool`; only a shared
    // reference is created, and the atomic provides its own interior mutability.
    let abort_requested: &AtomicBool = unsafe { &*data.cast::<AtomicBool>() };

    abort_requested.load(Ordering::Relaxed)
}
57
58pub struct LlamaContext<'model> {
59    pub context: NonNull<llama_cpp_bindings_sys::llama_context>,
60    pub model: &'model LlamaModel,
61    abort_flag: Option<Arc<AtomicBool>>,
62    initialized_logits: Vec<i32>,
63    embeddings_enabled: bool,
64}
65
66impl Debug for LlamaContext<'_> {
67    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
68        f.debug_struct("LlamaContext")
69            .field("context", &self.context)
70            .finish()
71    }
72}
73
74impl<'model> LlamaContext<'model> {
75    #[must_use]
76    pub const fn new(
77        llama_model: &'model LlamaModel,
78        llama_context: NonNull<llama_cpp_bindings_sys::llama_context>,
79        embeddings_enabled: bool,
80    ) -> Self {
81        Self {
82            context: llama_context,
83            model: llama_model,
84            abort_flag: None,
85            initialized_logits: Vec::new(),
86            embeddings_enabled,
87        }
88    }
89
90    /// # Errors
91    ///
92    /// Returns [`LlamaContextLoadError`] when llama.cpp fails to allocate the context.
93    #[expect(
94        clippy::needless_pass_by_value,
95        reason = "LlamaContextParams may become non-trivially copyable upstream"
96    )]
97    pub fn from_model(
98        model: &'model LlamaModel,
99        _backend: &LlamaBackend,
100        params: LlamaContextParams,
101    ) -> Result<Self, LlamaContextLoadError> {
102        let context_params = params.context_params;
103        let mut out_ctx: *mut llama_cpp_bindings_sys::llama_context = std::ptr::null_mut();
104        let mut out_error: *mut std::os::raw::c_char = std::ptr::null_mut();
105        let status = unsafe {
106            llama_cpp_bindings_sys::llama_rs_new_context_with_model(
107                model.model.as_ptr(),
108                context_params,
109                &raw mut out_ctx,
110                &raw mut out_error,
111            )
112        };
113        match status {
114            llama_cpp_bindings_sys::LLAMA_RS_NEW_CONTEXT_WITH_MODEL_OK => {
115                let context = NonNull::new(out_ctx)
116                    .ok_or(LlamaContextLoadError::Unconstructible)?;
117                Ok(Self::new(model, context, params.embeddings()))
118            }
119            llama_cpp_bindings_sys::LLAMA_RS_NEW_CONTEXT_WITH_MODEL_VENDORED_RETURNED_NULL => {
120                Err(LlamaContextLoadError::Unconstructible)
121            }
122            llama_cpp_bindings_sys::LLAMA_RS_NEW_CONTEXT_WITH_MODEL_ERROR_STRING_ALLOCATION_FAILED => {
123                Err(LlamaContextLoadError::NotEnoughMemory)
124            }
125            llama_cpp_bindings_sys::LLAMA_RS_NEW_CONTEXT_WITH_MODEL_VENDORED_THREW_CXX_EXCEPTION => {
126                let message = unsafe { crate::ffi_error_reader::read_and_free_cpp_error(out_error) };
127                Err(LlamaContextLoadError::Reported { message })
128            }
129            other => unreachable!(
130                "llama_rs_new_context_with_model returned unrecognized status {other}"
131            ),
132        }
133    }
134
135    #[must_use]
136    pub fn n_batch(&self) -> u32 {
137        unsafe { llama_cpp_bindings_sys::llama_n_batch(self.context.as_ptr()) }
138    }
139
140    #[must_use]
141    pub fn n_ubatch(&self) -> u32 {
142        unsafe { llama_cpp_bindings_sys::llama_n_ubatch(self.context.as_ptr()) }
143    }
144
145    #[must_use]
146    pub fn n_ctx(&self) -> u32 {
147        unsafe { llama_cpp_bindings_sys::llama_n_ctx(self.context.as_ptr()) }
148    }
149
150    #[expect(unsafe_code, reason = "required for FFI abort callback registration")]
151    pub fn set_abort_flag(&mut self, flag: Arc<AtomicBool>) {
152        let raw_ptr = Arc::as_ptr(&flag) as *mut c_void;
153        self.abort_flag = Some(flag);
154
155        unsafe {
156            llama_cpp_bindings_sys::llama_set_abort_callback(
157                self.context.as_ptr(),
158                Some(abort_callback_trampoline),
159                raw_ptr,
160            );
161        }
162    }
163
164    #[expect(unsafe_code, reason = "required for FFI abort callback deregistration")]
165    pub fn clear_abort_callback(&mut self) {
166        self.abort_flag = None;
167
168        unsafe {
169            llama_cpp_bindings_sys::llama_set_abort_callback(
170                self.context.as_ptr(),
171                None,
172                std::ptr::null_mut(),
173            );
174        }
175    }
176
177    #[expect(unsafe_code, reason = "required for FFI synchronization call")]
178    pub fn synchronize(&self) {
179        unsafe { llama_cpp_bindings_sys::llama_synchronize(self.context.as_ptr()) }
180    }
181
182    #[expect(unsafe_code, reason = "required for FFI threadpool detachment")]
183    pub fn detach_threadpool(&self) {
184        unsafe { llama_cpp_bindings_sys::llama_detach_threadpool(self.context.as_ptr()) }
185    }
186
187    pub fn mark_logits_initialized(&mut self, token_index: i32) {
188        self.initialized_logits = vec![token_index];
189    }
190
191    /// # Errors
192    ///
193    /// - `DecodeError` if the decoding failed.
194    pub fn decode(&mut self, batch: &mut LlamaBatch) -> Result<(), DecodeError> {
195        let mut out_vendored_return_code: i32 = 0;
196        let mut out_error: *mut std::os::raw::c_char = std::ptr::null_mut();
197        let status = unsafe {
198            llama_cpp_bindings_sys::llama_rs_decode(
199                self.context.as_ptr(),
200                batch.llama_batch,
201                &raw mut out_vendored_return_code,
202                &raw mut out_error,
203            )
204        };
205        match status {
206            llama_cpp_bindings_sys::LLAMA_RS_DECODE_OK => {
207                self.initialized_logits
208                    .clone_from(&batch.initialized_logits);
209                Ok(())
210            }
211            llama_cpp_bindings_sys::LLAMA_RS_DECODE_VENDORED_RETURNED_NONZERO_CODE => {
212                let code = NonZeroI32::new(out_vendored_return_code).unwrap_or_else(|| {
213                    unreachable!(
214                        "llama_rs_decode reported a nonzero return code but the value was zero"
215                    )
216                });
217                Err(DecodeError::from(code))
218            }
219            llama_cpp_bindings_sys::LLAMA_RS_DECODE_OUT_OF_MEMORY => {
220                Err(DecodeError::DecodeOutOfMemory)
221            }
222            llama_cpp_bindings_sys::LLAMA_RS_DECODE_COMPUTE_FAILED => {
223                Err(DecodeError::ComputeFailed)
224            }
225            llama_cpp_bindings_sys::LLAMA_RS_DECODE_ERROR_STRING_ALLOCATION_FAILED => {
226                Err(DecodeError::NotEnoughMemory)
227            }
228            llama_cpp_bindings_sys::LLAMA_RS_DECODE_VENDORED_THREW_CXX_EXCEPTION => {
229                let message =
230                    unsafe { crate::ffi_error_reader::read_and_free_cpp_error(out_error) };
231                Err(DecodeError::Reported { message })
232            }
233            other => unreachable!("llama_rs_decode returned unrecognized status {other}"),
234        }
235    }
236
237    /// # Errors
238    ///
239    /// - `EncodeError` if the encoding failed.
240    pub fn encode(&mut self, batch: &mut LlamaBatch) -> Result<(), EncodeError> {
241        let mut out_vendored_return_code: i32 = 0;
242        let mut out_error: *mut std::os::raw::c_char = std::ptr::null_mut();
243        let status = unsafe {
244            llama_cpp_bindings_sys::llama_rs_encode(
245                self.context.as_ptr(),
246                batch.llama_batch,
247                &raw mut out_vendored_return_code,
248                &raw mut out_error,
249            )
250        };
251        match status {
252            llama_cpp_bindings_sys::LLAMA_RS_ENCODE_OK => {
253                self.initialized_logits
254                    .clone_from(&batch.initialized_logits);
255                Ok(())
256            }
257            llama_cpp_bindings_sys::LLAMA_RS_ENCODE_MODEL_HAS_NO_ENCODER => {
258                Err(EncodeError::ModelHasNoEncoder)
259            }
260            llama_cpp_bindings_sys::LLAMA_RS_ENCODE_VENDORED_RETURNED_NONZERO_CODE => {
261                let code = NonZeroI32::new(out_vendored_return_code).unwrap_or_else(|| {
262                    unreachable!(
263                        "llama_rs_encode reported a nonzero return code but the value was zero"
264                    )
265                });
266                Err(EncodeError::from(code))
267            }
268            llama_cpp_bindings_sys::LLAMA_RS_ENCODE_OUT_OF_MEMORY => {
269                Err(EncodeError::EncodeOutOfMemory)
270            }
271            llama_cpp_bindings_sys::LLAMA_RS_ENCODE_COMPUTE_FAILED => {
272                Err(EncodeError::ComputeFailed)
273            }
274            llama_cpp_bindings_sys::LLAMA_RS_ENCODE_ERROR_STRING_ALLOCATION_FAILED => {
275                Err(EncodeError::NotEnoughMemory)
276            }
277            llama_cpp_bindings_sys::LLAMA_RS_ENCODE_VENDORED_THREW_CXX_EXCEPTION => {
278                let message =
279                    unsafe { crate::ffi_error_reader::read_and_free_cpp_error(out_error) };
280                Err(EncodeError::Reported { message })
281            }
282            other => unreachable!("llama_rs_encode returned unrecognized status {other}"),
283        }
284    }
285
286    /// # Errors
287    ///
288    /// - When the current context was constructed without enabling embeddings.
289    /// - If the current model had a pooling type of [`llama_cpp_bindings_sys::LLAMA_POOLING_TYPE_NONE`]
290    /// - If the given sequence index exceeds the max sequence id.
291    ///
292    pub fn embeddings_seq_ith(&self, sequence_index: i32) -> Result<&[f32], EmbeddingsError> {
293        if !self.embeddings_enabled {
294            return Err(EmbeddingsError::NotEnabled);
295        }
296
297        let n_embd = usize::try_from(self.model.n_embd())
298            .map_err(EmbeddingsError::InvalidEmbeddingDimension)?;
299
300        unsafe {
301            let embedding = llama_cpp_bindings_sys::llama_get_embeddings_seq(
302                self.context.as_ptr(),
303                sequence_index,
304            );
305
306            if embedding.is_null() {
307                Err(EmbeddingsError::NonePoolType)
308            } else {
309                Ok(slice::from_raw_parts(embedding, n_embd))
310            }
311        }
312    }
313
314    /// # Errors
315    ///
316    /// - When the current context was constructed without enabling embeddings.
317    /// - When the given token didn't have logits enabled when it was passed.
318    /// - If the given token index exceeds the max token id.
319    ///
320    pub fn embeddings_ith(&self, token_index: i32) -> Result<&[f32], EmbeddingsError> {
321        if !self.embeddings_enabled {
322            return Err(EmbeddingsError::NotEnabled);
323        }
324
325        let n_embd = usize::try_from(self.model.n_embd())
326            .map_err(EmbeddingsError::InvalidEmbeddingDimension)?;
327
328        unsafe {
329            let embedding = llama_cpp_bindings_sys::llama_get_embeddings_ith(
330                self.context.as_ptr(),
331                token_index,
332            );
333
334            if embedding.is_null() {
335                Err(EmbeddingsError::LogitsNotEnabled)
336            } else {
337                Ok(slice::from_raw_parts(embedding, n_embd))
338            }
339        }
340    }
341
342    /// # Errors
343    /// Returns `LogitsError` if logits are null or `n_vocab` overflows.
344    pub fn candidates(&self) -> Result<impl Iterator<Item = LlamaTokenData> + '_, LogitsError> {
345        let logits = self.get_logits()?;
346
347        Ok((0_i32..).zip(logits).map(|(token_id, logit)| {
348            let token = LlamaToken::new(token_id);
349            LlamaTokenData::new(token, *logit, 0_f32)
350        }))
351    }
352
353    /// # Errors
354    /// Returns `LogitsError` if logits are null or `n_vocab` overflows.
355    pub fn token_data_array(&self) -> Result<LlamaTokenDataArray, LogitsError> {
356        Ok(LlamaTokenDataArray::from_iter(self.candidates()?, false))
357    }
358
359    /// # Errors
360    /// Returns `LogitsError` if the logits pointer is null or `n_vocab` overflows.
361    pub fn get_logits(&self) -> Result<&[f32], LogitsError> {
362        let data = unsafe { llama_cpp_bindings_sys::llama_get_logits(self.context.as_ptr()) };
363
364        if data.is_null() {
365            return Err(LogitsError::NullLogits);
366        }
367
368        let len = usize::try_from(self.model.n_vocab()).map_err(LogitsError::VocabSizeOverflow)?;
369
370        Ok(unsafe { slice::from_raw_parts(data, len) })
371    }
372
373    /// # Errors
374    /// Returns `LogitsError` if the token is not initialized or out of range.
375    pub fn candidates_ith(
376        &self,
377        token_index: i32,
378    ) -> Result<impl Iterator<Item = LlamaTokenData> + '_, LogitsError> {
379        let logits = self.get_logits_ith(token_index)?;
380
381        Ok((0_i32..).zip(logits).map(|(token_id, logit)| {
382            let token = LlamaToken::new(token_id);
383            LlamaTokenData::new(token, *logit, 0_f32)
384        }))
385    }
386
387    /// # Errors
388    /// Returns `LogitsError` if the token is not initialized or out of range.
389    pub fn token_data_array_ith(
390        &self,
391        token_index: i32,
392    ) -> Result<LlamaTokenDataArray, LogitsError> {
393        Ok(LlamaTokenDataArray::from_iter(
394            self.candidates_ith(token_index)?,
395            false,
396        ))
397    }
398
399    /// # Errors
400    /// Returns `LogitsError` if the token is not initialized, out of range, or `n_vocab` overflows.
401    pub fn get_logits_ith(&self, token_index: i32) -> Result<&[f32], LogitsError> {
402        if !self.initialized_logits.contains(&token_index) {
403            return Err(LogitsError::TokenNotInitialized(token_index));
404        }
405
406        if token_index >= 0 {
407            let token_index_u32 =
408                u32::try_from(token_index).map_err(LogitsError::TokenIndexOverflow)?;
409
410            if self.n_ctx() <= token_index_u32 {
411                return Err(LogitsError::TokenIndexExceedsContext {
412                    token_index: token_index_u32,
413                    context_size: self.n_ctx(),
414                });
415            }
416        }
417
418        let data = unsafe {
419            llama_cpp_bindings_sys::llama_get_logits_ith(self.context.as_ptr(), token_index)
420        };
421        let len = usize::try_from(self.model.n_vocab()).map_err(LogitsError::VocabSizeOverflow)?;
422
423        Ok(unsafe { slice::from_raw_parts(data, len) })
424    }
425
426    pub fn reset_timings(&mut self) {
427        unsafe { llama_cpp_bindings_sys::llama_perf_context_reset(self.context.as_ptr()) }
428    }
429
430    pub fn timings(&mut self) -> LlamaTimings {
431        let timings = unsafe { llama_cpp_bindings_sys::llama_perf_context(self.context.as_ptr()) };
432        LlamaTimings { timings }
433    }
434
435    /// # Errors
436    ///
437    /// See [`LlamaLoraAdapterSetError`] for more information.
438    pub fn lora_adapter_set(
439        &self,
440        adapter: &mut LlamaLoraAdapter,
441        scale: f32,
442    ) -> Result<(), LlamaLoraAdapterSetError> {
443        let mut adapters = [adapter.lora_adapter.as_ptr()];
444        let mut scales = [scale];
445        let err_code = unsafe {
446            llama_cpp_bindings_sys::llama_set_adapters_lora(
447                self.context.as_ptr(),
448                adapters.as_mut_ptr(),
449                1,
450                scales.as_mut_ptr(),
451            )
452        };
453        check_lora_set_result(err_code)?;
454
455        log::debug!("Set lora adapter");
456        Ok(())
457    }
458
459    /// # Errors
460    ///
461    /// See [`LlamaLoraAdapterRemoveError`] for more information.
462    pub fn lora_adapter_remove(
463        &self,
464        _adapter: &mut LlamaLoraAdapter,
465    ) -> Result<(), LlamaLoraAdapterRemoveError> {
466        let err_code = unsafe {
467            llama_cpp_bindings_sys::llama_set_adapters_lora(
468                self.context.as_ptr(),
469                std::ptr::null_mut(),
470                0,
471                std::ptr::null_mut(),
472            )
473        };
474        check_lora_remove_result(err_code)?;
475
476        log::debug!("Remove lora adapter");
477        Ok(())
478    }
479}
480
481impl Drop for LlamaContext<'_> {
482    fn drop(&mut self) {
483        unsafe { llama_cpp_bindings_sys::llama_free(self.context.as_ptr()) }
484    }
485}
486
#[cfg(test)]
mod unit_tests {
    use super::{check_lora_remove_result, check_lora_set_result};
    use crate::{LlamaLoraAdapterRemoveError, LlamaLoraAdapterSetError};

    /// A zero status from the set call maps to success.
    #[test]
    fn check_lora_set_result_ok_for_zero() {
        assert_eq!(check_lora_set_result(0), Ok(()));
    }

    /// A nonzero status from the set call is preserved inside the error.
    #[test]
    fn check_lora_set_result_error_for_nonzero() {
        assert!(matches!(
            check_lora_set_result(-1),
            Err(LlamaLoraAdapterSetError::ErrorResult(-1))
        ));
    }

    /// A zero status from the remove call maps to success.
    #[test]
    fn check_lora_remove_result_ok_for_zero() {
        assert_eq!(check_lora_remove_result(0), Ok(()));
    }

    /// A nonzero status from the remove call is preserved inside the error.
    #[test]
    fn check_lora_remove_result_error_for_nonzero() {
        assert!(matches!(
            check_lora_remove_result(-1),
            Err(LlamaLoraAdapterRemoveError::ErrorResult(-1))
        ));
    }
}