// llama_cpp_bindings/context.rs
1//! Safe wrapper around `llama_context`.
2
3use std::ffi::c_void;
4use std::fmt::{Debug, Formatter};
5use std::num::NonZeroI32;
6use std::ptr::NonNull;
7use std::slice;
8use std::sync::Arc;
9use std::sync::atomic::AtomicBool;
10use std::sync::atomic::Ordering;
11
12use crate::llama_batch::LlamaBatch;
13use crate::model::{LlamaLoraAdapter, LlamaModel};
14use crate::timing::LlamaTimings;
15use crate::token::LlamaToken;
16use crate::token::data::LlamaTokenData;
17use crate::token::data_array::LlamaTokenDataArray;
18use crate::{
19    DecodeError, EmbeddingsError, EncodeError, LlamaLoraAdapterRemoveError,
20    LlamaLoraAdapterSetError, LogitsError,
21};
22
23const fn check_lora_set_result(err_code: i32) -> Result<(), LlamaLoraAdapterSetError> {
24    if err_code != 0 {
25        return Err(LlamaLoraAdapterSetError::ErrorResult(err_code));
26    }
27
28    Ok(())
29}
30
31const fn check_lora_remove_result(err_code: i32) -> Result<(), LlamaLoraAdapterRemoveError> {
32    if err_code != 0 {
33        return Err(LlamaLoraAdapterRemoveError::ErrorResult(err_code));
34    }
35
36    Ok(())
37}
38
39pub mod kv_cache;
40pub mod llama_state_seq_flags;
41pub mod load_seq_state_error;
42pub mod load_session_error;
43pub mod params;
44pub mod save_seq_state_error;
45pub mod save_session_error;
46pub mod session;
47
/// C-ABI trampoline registered via `llama_set_abort_callback`.
///
/// llama.cpp invokes this during computation; returning `true` aborts the
/// in-progress decode. `data` is the pointer registered in
/// [`LlamaContext::set_abort_flag`], which points at the `AtomicBool` inside
/// the `Arc` stored on the context.
unsafe extern "C" fn abort_callback_trampoline(data: *mut c_void) -> bool {
    // SAFETY: `data` was produced by `Arc::as_ptr` on an `Arc<AtomicBool>`
    // that the owning context keeps alive while this callback is registered.
    let flag = unsafe { &*(data as *const AtomicBool) };

    flag.load(Ordering::Relaxed)
}
53
/// Safe wrapper around `llama_context`.
pub struct LlamaContext<'model> {
    /// Raw pointer to the underlying `llama_context`.
    pub context: NonNull<llama_cpp_bindings_sys::llama_context>,
    /// A reference to the context's model.
    pub model: &'model LlamaModel,
    // Keeps the abort flag's allocation alive while the FFI callback
    // registered in `set_abort_flag` may still read it.
    abort_flag: Option<Arc<AtomicBool>>,
    // Token indices whose logits were populated by the last decode/encode;
    // consulted by `get_logits_ith` before reading logit rows.
    initialized_logits: Vec<i32>,
    // Whether the context was created with embeddings enabled.
    embeddings_enabled: bool,
}
64
65impl Debug for LlamaContext<'_> {
66    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
67        f.debug_struct("LlamaContext")
68            .field("context", &self.context)
69            .finish()
70    }
71}
72
73impl<'model> LlamaContext<'model> {
74    /// Wraps existing raw pointers into a new `LlamaContext`.
75    #[must_use]
76    pub const fn new(
77        llama_model: &'model LlamaModel,
78        llama_context: NonNull<llama_cpp_bindings_sys::llama_context>,
79        embeddings_enabled: bool,
80    ) -> Self {
81        Self {
82            context: llama_context,
83            model: llama_model,
84            abort_flag: None,
85            initialized_logits: Vec::new(),
86            embeddings_enabled,
87        }
88    }
89
    /// Gets the max number of logical tokens that can be submitted to decode. Must be greater than or equal to [`Self::n_ubatch`].
    #[must_use]
    pub fn n_batch(&self) -> u32 {
        // SAFETY: `self.context` is a valid context pointer for the lifetime of `self`.
        unsafe { llama_cpp_bindings_sys::llama_n_batch(self.context.as_ptr()) }
    }

    /// Gets the max number of physical tokens (hardware level) to decode in batch. Must be less than or equal to [`Self::n_batch`].
    #[must_use]
    pub fn n_ubatch(&self) -> u32 {
        // SAFETY: `self.context` is a valid context pointer for the lifetime of `self`.
        unsafe { llama_cpp_bindings_sys::llama_n_ubatch(self.context.as_ptr()) }
    }

    /// Gets the size of the context.
    #[must_use]
    pub fn n_ctx(&self) -> u32 {
        // SAFETY: `self.context` is a valid context pointer for the lifetime of `self`.
        unsafe { llama_cpp_bindings_sys::llama_n_ctx(self.context.as_ptr()) }
    }
107
    /// Sets an abort flag that llama.cpp checks during computation.
    ///
    /// When the flag is set to `true`, any in-progress `decode()` call will
    /// abort and return `DecodeError::Aborted`. The `Arc` is stored internally
    /// to ensure the flag outlives the callback registration.
    #[expect(unsafe_code, reason = "required for FFI abort callback registration")]
    pub fn set_abort_flag(&mut self, flag: Arc<AtomicBool>) {
        // `Arc::as_ptr` points at the shared heap allocation, which stays
        // stable for as long as the `Arc` stored below keeps it alive.
        let raw_ptr = Arc::as_ptr(&flag) as *mut c_void;
        self.abort_flag = Some(flag);

        // SAFETY: `raw_ptr` remains valid while the callback is registered
        // because the `Arc` is held in `self.abort_flag` until cleared or
        // the context is dropped.
        unsafe {
            llama_cpp_bindings_sys::llama_set_abort_callback(
                self.context.as_ptr(),
                Some(abort_callback_trampoline),
                raw_ptr,
            );
        }
    }

    /// Clears the abort callback so that decode calls are no longer interruptible.
    #[expect(unsafe_code, reason = "required for FFI abort callback deregistration")]
    pub fn clear_abort_callback(&mut self) {
        // Dropping our clone of the Arc is fine: the null registration below
        // means llama.cpp no longer holds a pointer into it.
        self.abort_flag = None;

        // SAFETY: a null callback/data pair deregisters the abort callback.
        unsafe {
            llama_cpp_bindings_sys::llama_set_abort_callback(
                self.context.as_ptr(),
                None,
                std::ptr::null_mut(),
            );
        }
    }
140
    /// Waits for all pending backend operations to complete.
    ///
    /// Must be called before freeing the context to prevent hangs
    /// during resource cleanup.
    #[expect(unsafe_code, reason = "required for FFI synchronization call")]
    pub fn synchronize(&self) {
        // SAFETY: `self.context` is a valid context pointer for the lifetime of `self`.
        unsafe { llama_cpp_bindings_sys::llama_synchronize(self.context.as_ptr()) }
    }

    /// Detaches the threadpool from the context.
    ///
    /// Must be called before freeing the context to prevent threadpool
    /// workers from accessing freed resources.
    #[expect(unsafe_code, reason = "required for FFI threadpool detachment")]
    pub fn detach_threadpool(&self) {
        // SAFETY: `self.context` is a valid context pointer for the lifetime of `self`.
        unsafe { llama_cpp_bindings_sys::llama_detach_threadpool(self.context.as_ptr()) }
    }
158
159    /// Marks a logit index as initialized so it can be read via
160    /// `get_logits_ith`. Use after external decode operations (like
161    /// `eval_chunks`) that bypass the Rust `decode()` method.
162    pub fn mark_logits_initialized(&mut self, token_index: i32) {
163        self.initialized_logits = vec![token_index];
164    }
165
166    /// Decodes the batch.
167    ///
168    /// # Errors
169    ///
170    /// - `DecodeError` if the decoding failed.
171    ///
172    /// # Panics
173    ///
174    /// - the returned [`std::ffi::c_int`] from llama-cpp does not fit into a i32 (this should never happen on most systems)
175    pub fn decode(&mut self, batch: &mut LlamaBatch) -> Result<(), DecodeError> {
176        let result = unsafe {
177            llama_cpp_bindings_sys::llama_decode(self.context.as_ptr(), batch.llama_batch)
178        };
179
180        match NonZeroI32::new(result) {
181            None => {
182                self.initialized_logits
183                    .clone_from(&batch.initialized_logits);
184                Ok(())
185            }
186            Some(error) => Err(DecodeError::from(error)),
187        }
188    }
189
    /// Encodes the batch.
    ///
    /// # Errors
    ///
    /// - `EncodeError` if the encoding failed.
    pub fn encode(&mut self, batch: &mut LlamaBatch) -> Result<(), EncodeError> {
        // SAFETY: the context pointer and the batch's FFI struct are valid
        // for the duration of this call.
        let status = unsafe {
            llama_cpp_bindings_sys::llama_rs_encode(self.context.as_ptr(), batch.llama_batch)
        };

        self.handle_encode_result(status, batch)
    }
206
207    fn handle_encode_result(
208        &mut self,
209        status: llama_cpp_bindings_sys::llama_rs_status,
210        batch: &mut LlamaBatch,
211    ) -> Result<(), EncodeError> {
212        if crate::status_is_ok(status) {
213            self.initialized_logits
214                .clone_from(&batch.initialized_logits);
215
216            Ok(())
217        } else {
218            Err(EncodeError::from(
219                NonZeroI32::new(crate::status_to_i32(status))
220                    .unwrap_or(NonZeroI32::new(1).expect("1 is non-zero")),
221            ))
222        }
223    }
224
225    /// Get the embeddings for the given sequence in the current context.
226    ///
227    /// # Returns
228    ///
229    /// A slice containing the embeddings for the last decoded batch.
230    /// The size corresponds to the `n_embd` parameter of the context's model.
231    ///
232    /// # Errors
233    ///
234    /// - When the current context was constructed without enabling embeddings.
235    /// - If the current model had a pooling type of [`llama_cpp_bindings_sys::LLAMA_POOLING_TYPE_NONE`]
236    /// - If the given sequence index exceeds the max sequence id.
237    ///
238    pub fn embeddings_seq_ith(&self, sequence_index: i32) -> Result<&[f32], EmbeddingsError> {
239        if !self.embeddings_enabled {
240            return Err(EmbeddingsError::NotEnabled);
241        }
242
243        let n_embd = usize::try_from(self.model.n_embd())
244            .map_err(EmbeddingsError::InvalidEmbeddingDimension)?;
245
246        unsafe {
247            let embedding = llama_cpp_bindings_sys::llama_get_embeddings_seq(
248                self.context.as_ptr(),
249                sequence_index,
250            );
251
252            if embedding.is_null() {
253                Err(EmbeddingsError::NonePoolType)
254            } else {
255                Ok(slice::from_raw_parts(embedding, n_embd))
256            }
257        }
258    }
259
260    /// Get the embeddings for the given token in the current context.
261    ///
262    /// # Returns
263    ///
264    /// A slice containing the embeddings for the last decoded batch of the given token.
265    /// The size corresponds to the `n_embd` parameter of the context's model.
266    ///
267    /// # Errors
268    ///
269    /// - When the current context was constructed without enabling embeddings.
270    /// - When the given token didn't have logits enabled when it was passed.
271    /// - If the given token index exceeds the max token id.
272    ///
273    pub fn embeddings_ith(&self, token_index: i32) -> Result<&[f32], EmbeddingsError> {
274        if !self.embeddings_enabled {
275            return Err(EmbeddingsError::NotEnabled);
276        }
277
278        let n_embd = usize::try_from(self.model.n_embd())
279            .map_err(EmbeddingsError::InvalidEmbeddingDimension)?;
280
281        unsafe {
282            let embedding = llama_cpp_bindings_sys::llama_get_embeddings_ith(
283                self.context.as_ptr(),
284                token_index,
285            );
286
287            if embedding.is_null() {
288                Err(EmbeddingsError::LogitsNotEnabled)
289            } else {
290                Ok(slice::from_raw_parts(embedding, n_embd))
291            }
292        }
293    }
294
295    /// Get the logits for the last token in the context.
296    ///
297    /// # Returns
298    /// An iterator over unsorted `LlamaTokenData` containing the
299    /// logits for the last token in the context.
300    ///
301    /// # Errors
302    /// Returns `LogitsError` if logits are null or `n_vocab` overflows.
303    pub fn candidates(&self) -> Result<impl Iterator<Item = LlamaTokenData> + '_, LogitsError> {
304        let logits = self.get_logits()?;
305
306        Ok((0_i32..).zip(logits).map(|(token_id, logit)| {
307            let token = LlamaToken::new(token_id);
308            LlamaTokenData::new(token, *logit, 0_f32)
309        }))
310    }
311
312    /// Get the token data array for the last token in the context.
313    ///
314    /// # Errors
315    /// Returns `LogitsError` if logits are null or `n_vocab` overflows.
316    pub fn token_data_array(&self) -> Result<LlamaTokenDataArray, LogitsError> {
317        Ok(LlamaTokenDataArray::from_iter(self.candidates()?, false))
318    }
319
320    /// Token logits obtained from the last call to `decode()`.
321    /// The logits for which `batch.logits[i] != 0` are stored contiguously
322    /// in the order they have appeared in the batch.
323    /// Rows: number of tokens for which `batch.logits[i] != 0`
324    /// Cols: `n_vocab`
325    ///
326    /// # Returns
327    ///
328    /// A slice containing the logits for the last decoded token.
329    /// The size corresponds to the `n_vocab` parameter of the context's model.
330    ///
331    /// # Errors
332    /// Returns `LogitsError` if the logits pointer is null or `n_vocab` overflows.
333    pub fn get_logits(&self) -> Result<&[f32], LogitsError> {
334        let data = unsafe { llama_cpp_bindings_sys::llama_get_logits(self.context.as_ptr()) };
335
336        if data.is_null() {
337            return Err(LogitsError::NullLogits);
338        }
339
340        let len = usize::try_from(self.model.n_vocab()).map_err(LogitsError::VocabSizeOverflow)?;
341
342        Ok(unsafe { slice::from_raw_parts(data, len) })
343    }
344
345    /// Get the logits for the ith token in the context.
346    ///
347    /// # Errors
348    /// Returns `LogitsError` if the token is not initialized or out of range.
349    pub fn candidates_ith(
350        &self,
351        token_index: i32,
352    ) -> Result<impl Iterator<Item = LlamaTokenData> + '_, LogitsError> {
353        let logits = self.get_logits_ith(token_index)?;
354
355        Ok((0_i32..).zip(logits).map(|(token_id, logit)| {
356            let token = LlamaToken::new(token_id);
357            LlamaTokenData::new(token, *logit, 0_f32)
358        }))
359    }
360
361    /// Get the token data array for the ith token in the context.
362    ///
363    /// # Errors
364    /// Returns `LogitsError` if the token is not initialized or out of range.
365    pub fn token_data_array_ith(
366        &self,
367        token_index: i32,
368    ) -> Result<LlamaTokenDataArray, LogitsError> {
369        Ok(LlamaTokenDataArray::from_iter(
370            self.candidates_ith(token_index)?,
371            false,
372        ))
373    }
374
375    /// Get the logits for the ith token in the context.
376    ///
377    /// # Errors
378    /// Returns `LogitsError` if the token is not initialized, out of range, or `n_vocab` overflows.
379    pub fn get_logits_ith(&self, token_index: i32) -> Result<&[f32], LogitsError> {
380        if !self.initialized_logits.contains(&token_index) {
381            return Err(LogitsError::TokenNotInitialized(token_index));
382        }
383
384        if token_index >= 0 {
385            let token_index_u32 =
386                u32::try_from(token_index).map_err(LogitsError::TokenIndexOverflow)?;
387
388            if self.n_ctx() <= token_index_u32 {
389                return Err(LogitsError::TokenIndexExceedsContext {
390                    token_index: token_index_u32,
391                    context_size: self.n_ctx(),
392                });
393            }
394        }
395
396        let data = unsafe {
397            llama_cpp_bindings_sys::llama_get_logits_ith(self.context.as_ptr(), token_index)
398        };
399        let len = usize::try_from(self.model.n_vocab()).map_err(LogitsError::VocabSizeOverflow)?;
400
401        Ok(unsafe { slice::from_raw_parts(data, len) })
402    }
403
    /// Reset the timings for the context.
    pub fn reset_timings(&mut self) {
        // SAFETY: `self.context` is a valid context pointer for the lifetime of `self`.
        unsafe { llama_cpp_bindings_sys::llama_perf_context_reset(self.context.as_ptr()) }
    }

    /// Returns the timings for the context.
    pub fn timings(&mut self) -> LlamaTimings {
        // SAFETY: `self.context` is a valid context pointer for the lifetime of `self`.
        let timings = unsafe { llama_cpp_bindings_sys::llama_perf_context(self.context.as_ptr()) };
        LlamaTimings { timings }
    }
414
    /// Sets a lora adapter.
    ///
    /// # Errors
    ///
    /// See [`LlamaLoraAdapterSetError`] for more information.
    pub fn lora_adapter_set(
        &self,
        adapter: &mut LlamaLoraAdapter,
        scale: f32,
    ) -> Result<(), LlamaLoraAdapterSetError> {
        // The upstream API replaces the whole adapter set at once, so a
        // single-element array is passed for this one adapter/scale pair.
        let mut adapters = [adapter.lora_adapter.as_ptr()];
        let mut scales = [scale];
        // SAFETY: the context pointer is valid, and the adapter/scale arrays
        // outlive the call with length matching the count argument (1).
        let err_code = unsafe {
            llama_cpp_bindings_sys::llama_set_adapters_lora(
                self.context.as_ptr(),
                adapters.as_mut_ptr(),
                1,
                scales.as_mut_ptr(),
            )
        };
        check_lora_set_result(err_code)?;

        tracing::debug!("Set lora adapter");
        Ok(())
    }
440
    /// Remove all lora adapters.
    ///
    /// Note: The upstream API now replaces all adapters at once via
    /// `llama_set_adapters_lora`. This clears all adapters from the context.
    ///
    /// # Errors
    ///
    /// See [`LlamaLoraAdapterRemoveError`] for more information.
    pub fn lora_adapter_remove(
        &self,
        _adapter: &mut LlamaLoraAdapter,
    ) -> Result<(), LlamaLoraAdapterRemoveError> {
        // SAFETY: with a count of 0 the null array pointers are never read;
        // the call simply replaces the adapter set with an empty one.
        let err_code = unsafe {
            llama_cpp_bindings_sys::llama_set_adapters_lora(
                self.context.as_ptr(),
                std::ptr::null_mut(),
                0,
                std::ptr::null_mut(),
            )
        };
        check_lora_remove_result(err_code)?;

        tracing::debug!("Remove lora adapter");
        Ok(())
    }
466
    /// Print a breakdown of per-device memory use to the default logger.
    pub fn print_memory_breakdown(&self) {
        // SAFETY: `self.context` is a valid context pointer for the lifetime of `self`.
        unsafe { llama_cpp_bindings_sys::llama_memory_breakdown_print(self.context.as_ptr()) }
    }
471}
472
impl Drop for LlamaContext<'_> {
    fn drop(&mut self) {
        // NOTE(review): `synchronize()` and `detach_threadpool()` document
        // that they must be called before freeing the context; confirm
        // whether `llama_free` handles this internally or whether callers
        // are expected to invoke them before dropping.
        // SAFETY: `self.context` is an owned, valid context pointer and is
        // not used again after this call.
        unsafe { llama_cpp_bindings_sys::llama_free(self.context.as_ptr()) }
    }
}
478
#[cfg(test)]
mod unit_tests {
    use super::{check_lora_remove_result, check_lora_set_result};
    use crate::{LlamaLoraAdapterRemoveError, LlamaLoraAdapterSetError};

    // A zero status code maps to success.
    #[test]
    fn check_lora_set_result_ok_for_zero() {
        assert_eq!(check_lora_set_result(0), Ok(()));
    }

    // A non-zero status code is preserved inside the error variant.
    #[test]
    fn check_lora_set_result_error_for_nonzero() {
        assert_eq!(
            check_lora_set_result(-1),
            Err(LlamaLoraAdapterSetError::ErrorResult(-1))
        );
    }

    // A zero status code maps to success.
    #[test]
    fn check_lora_remove_result_ok_for_zero() {
        assert_eq!(check_lora_remove_result(0), Ok(()));
    }

    // A non-zero status code is preserved inside the error variant.
    #[test]
    fn check_lora_remove_result_error_for_nonzero() {
        assert_eq!(
            check_lora_remove_result(-1),
            Err(LlamaLoraAdapterRemoveError::ErrorResult(-1))
        );
    }
}
510
511#[cfg(test)]
512#[cfg(feature = "tests_that_use_llms")]
513mod tests {
514    use serial_test::serial;
515
516    use crate::context::params::LlamaContextParams;
517    use crate::llama_batch::LlamaBatch;
518    use crate::model::AddBos;
519    use crate::test_model;
520
    // Smoke-tests basic context properties straight after creation.
    #[test]
    #[serial]
    fn context_creation_and_properties() {
        let (backend, model) = test_model::load_default_model().unwrap();
        let ctx_params = LlamaContextParams::default().with_n_ctx(std::num::NonZeroU32::new(512));
        let context = model.new_context(&backend, ctx_params).unwrap();
        assert!(context.n_ctx() > 0);
        assert!(context.n_batch() > 0);
        assert!(context.n_ubatch() > 0);
    }

    // Decoding a tokenized prompt should succeed and expose non-empty logits.
    #[test]
    #[serial]
    fn decode_and_get_logits() {
        let (backend, model) = test_model::load_default_model().unwrap();
        let ctx_params = LlamaContextParams::default().with_n_ctx(std::num::NonZeroU32::new(512));
        let mut context = model.new_context(&backend, ctx_params).unwrap();
        let tokens = model.str_to_token("hello", AddBos::Always).unwrap();
        let mut batch = LlamaBatch::new(512, 1).unwrap();
        batch.add_sequence(&tokens, 0, false).unwrap();

        let decode_result = context.decode(&mut batch);
        assert!(decode_result.is_ok());

        let logits = context.get_logits().unwrap();
        assert!(!logits.is_empty());
    }

    // reset_timings/timings round-trip without error.
    #[test]
    #[serial]
    fn timings_work() {
        let (backend, model) = test_model::load_default_model().unwrap();
        let ctx_params = LlamaContextParams::default().with_n_ctx(std::num::NonZeroU32::new(512));
        let mut context = model.new_context(&backend, ctx_params).unwrap();
        context.reset_timings();
        let timings = context.timings();
        assert!(timings.t_start_ms() >= 0.0);
    }

    // A decoded batch should produce a non-empty token data array.
    #[test]
    #[serial]
    fn token_data_array_has_entries_after_decode() {
        let (backend, model) = test_model::load_default_model().unwrap();
        let ctx_params = LlamaContextParams::default().with_n_ctx(std::num::NonZeroU32::new(512));
        let mut context = model.new_context(&backend, ctx_params).unwrap();
        let tokens = model.str_to_token("hello", AddBos::Always).unwrap();
        let mut batch = LlamaBatch::new(512, 1).unwrap();
        batch.add_sequence(&tokens, 0, false).unwrap();
        context.decode(&mut batch).unwrap();

        let token_data_array = context.token_data_array().unwrap();

        assert!(!token_data_array.data.is_empty());
    }

    // get_logits_ith on the last prompt token returns a full vocab-sized row.
    #[test]
    #[serial]
    fn get_logits_ith_returns_valid_slice() {
        let (backend, model) = test_model::load_default_model().unwrap();
        let ctx_params = LlamaContextParams::default().with_n_ctx(std::num::NonZeroU32::new(512));
        let mut context = model.new_context(&backend, ctx_params).unwrap();
        let tokens = model.str_to_token("hello", AddBos::Always).unwrap();
        let last_index = i32::try_from(tokens.len() - 1).unwrap();
        let mut batch = LlamaBatch::new(512, 1).unwrap();
        batch.add_sequence(&tokens, 0, false).unwrap();
        context.decode(&mut batch).unwrap();

        let logits = context.get_logits_ith(last_index).unwrap();

        assert_eq!(logits.len(), model.n_vocab() as usize);
    }

    // token_data_array_ith yields one entry per vocabulary token.
    #[test]
    #[serial]
    fn token_data_array_ith_returns_valid_data() {
        let (backend, model) = test_model::load_default_model().unwrap();
        let ctx_params = LlamaContextParams::default().with_n_ctx(std::num::NonZeroU32::new(512));
        let mut context = model.new_context(&backend, ctx_params).unwrap();
        let tokens = model.str_to_token("hello", AddBos::Always).unwrap();
        let last_index = i32::try_from(tokens.len() - 1).unwrap();
        let mut batch = LlamaBatch::new(512, 1).unwrap();
        batch.add_sequence(&tokens, 0, false).unwrap();
        context.decode(&mut batch).unwrap();

        let token_data_array = context.token_data_array_ith(last_index).unwrap();

        assert_eq!(token_data_array.data.len(), model.n_vocab() as usize);
    }
609
    // Token embeddings must fail with an error when the context was built
    // without embeddings.
    #[test]
    #[serial]
    fn embeddings_ith_returns_error_when_embeddings_disabled() {
        let (backend, model) = test_model::load_default_model().unwrap();
        let ctx_params = LlamaContextParams::default()
            .with_n_ctx(std::num::NonZeroU32::new(512))
            .with_embeddings(false);
        let context = model.new_context(&backend, ctx_params).unwrap();

        let result = context.embeddings_ith(0);

        assert!(result.is_err());
    }

    // Same as above, for sequence-level embeddings.
    #[test]
    #[serial]
    fn embeddings_seq_ith_returns_error_when_embeddings_disabled() {
        let (backend, model) = test_model::load_default_model().unwrap();
        let ctx_params = LlamaContextParams::default()
            .with_n_ctx(std::num::NonZeroU32::new(512))
            .with_embeddings(false);
        let context = model.new_context(&backend, ctx_params).unwrap();

        let result = context.embeddings_seq_ith(0);

        assert!(result.is_err());
    }

    // candidates() should yield exactly one entry per vocabulary token.
    #[test]
    #[serial]
    fn candidates_returns_n_vocab_entries() {
        let (backend, model) = test_model::load_default_model().unwrap();
        let ctx_params = LlamaContextParams::default().with_n_ctx(std::num::NonZeroU32::new(512));
        let mut context = model.new_context(&backend, ctx_params).unwrap();
        let tokens = model.str_to_token("hello", AddBos::Always).unwrap();
        let mut batch = LlamaBatch::new(512, 1).unwrap();
        batch.add_sequence(&tokens, 0, false).unwrap();
        context.decode(&mut batch).unwrap();

        let count = context.candidates().unwrap().count();

        assert_eq!(count, model.n_vocab() as usize);
    }

    // The Debug impl should at least name the wrapper type.
    #[test]
    #[serial]
    fn debug_format_contains_struct_name() {
        let (backend, model) = test_model::load_default_model().unwrap();
        let ctx_params = LlamaContextParams::default().with_n_ctx(std::num::NonZeroU32::new(512));
        let context = model.new_context(&backend, ctx_params).unwrap();
        let debug_output = format!("{context:?}");

        assert!(debug_output.contains("LlamaContext"));
    }

    // Decoding must also work when the context has embeddings enabled.
    #[test]
    #[serial]
    fn decode_with_embeddings_enabled() {
        let (backend, model) = test_model::load_default_embedding_model().unwrap();
        let ctx_params = LlamaContextParams::default()
            .with_n_ctx(std::num::NonZeroU32::new(512))
            .with_embeddings(true);
        let mut context = model.new_context(&backend, ctx_params).unwrap();
        let tokens = model.str_to_token("hello", AddBos::Always).unwrap();
        let mut batch = LlamaBatch::new(512, 1).unwrap();
        batch.add_sequence(&tokens, 0, false).unwrap();

        let result = context.decode(&mut batch);

        assert!(result.is_ok());
    }

    // Sequence embeddings should have exactly `n_embd` components.
    #[test]
    #[serial]
    fn embeddings_seq_ith_returns_valid_embeddings() {
        let (backend, model) = test_model::load_default_embedding_model().unwrap();
        let ctx_params = LlamaContextParams::default()
            .with_n_ctx(std::num::NonZeroU32::new(512))
            .with_embeddings(true);
        let mut context = model.new_context(&backend, ctx_params).unwrap();
        let tokens = model.str_to_token("hello", AddBos::Always).unwrap();
        let mut batch = LlamaBatch::new(512, 1).unwrap();
        batch.add_sequence(&tokens, 0, false).unwrap();
        context.decode(&mut batch).unwrap();

        let embeddings = context.embeddings_seq_ith(0).unwrap();

        assert_eq!(embeddings.len(), model.n_embd() as usize);
    }

    // Token embeddings should have exactly `n_embd` components.
    #[test]
    #[serial]
    fn embeddings_ith_returns_valid_embeddings() {
        let (backend, model) = test_model::load_default_embedding_model().unwrap();
        let ctx_params = LlamaContextParams::default()
            .with_n_ctx(std::num::NonZeroU32::new(512))
            .with_embeddings(true);
        let mut context = model.new_context(&backend, ctx_params).unwrap();
        let tokens = model.str_to_token("hello", AddBos::Always).unwrap();
        let last_index = i32::try_from(tokens.len() - 1).unwrap();
        let mut batch = LlamaBatch::new(512, 1).unwrap();
        batch.add_sequence(&tokens, 0, false).unwrap();
        context.decode(&mut batch).unwrap();

        let embeddings = context.embeddings_ith(last_index).unwrap();

        assert_eq!(embeddings.len(), model.n_embd() as usize);
    }

    // candidates_ith() should also yield one entry per vocabulary token.
    #[test]
    #[serial]
    fn candidates_ith_returns_n_vocab_entries() {
        let (backend, model) = test_model::load_default_model().unwrap();
        let ctx_params = LlamaContextParams::default().with_n_ctx(std::num::NonZeroU32::new(512));
        let mut context = model.new_context(&backend, ctx_params).unwrap();
        let tokens = model.str_to_token("hello", AddBos::Always).unwrap();
        let last_index = i32::try_from(tokens.len() - 1).unwrap();
        let mut batch = LlamaBatch::new(512, 1).unwrap();
        batch.add_sequence(&tokens, 0, false).unwrap();
        context.decode(&mut batch).unwrap();

        let count = context.candidates_ith(last_index).unwrap().count();

        assert_eq!(count, model.n_vocab() as usize);
    }
735
    // lora_adapter_remove ignores its adapter argument (it clears all
    // adapters), so the dangling pointer here is never dereferenced.
    #[test]
    #[serial]
    fn lora_adapter_remove_succeeds_with_no_adapters() {
        let (backend, model) = test_model::load_default_model().unwrap();
        let ctx_params = LlamaContextParams::default().with_n_ctx(std::num::NonZeroU32::new(512));
        let context = model.new_context(&backend, ctx_params).unwrap();
        let mut adapter = crate::model::LlamaLoraAdapter {
            lora_adapter: std::ptr::NonNull::dangling(),
        };

        let result = context.lora_adapter_remove(&mut adapter);

        assert!(result.is_ok());
    }

    // encode() on a decoder-only model should fail cleanly.
    #[test]
    #[serial]
    fn encode_on_non_encoder_model_returns_error() {
        let (backend, model) = test_model::load_default_model().unwrap();
        let ctx_params = LlamaContextParams::default().with_n_ctx(std::num::NonZeroU32::new(512));
        let mut context = model.new_context(&backend, ctx_params).unwrap();
        let tokens = model.str_to_token("hello", AddBos::Always).unwrap();
        let mut batch = LlamaBatch::new(512, 1).unwrap();
        batch.add_sequence(&tokens, 0, false).unwrap();

        let result = context.encode(&mut batch);

        assert!(result.is_err());
    }

    // NOTE(review): unlike the remove test above, this hands a dangling
    // `NonNull` adapter pointer to `llama_set_adapters_lora` with a count of
    // 1, which the library will dereference — undefined behavior — and the
    // `is_ok()` assertion contradicts the "succeeds_or_errors" name. This
    // test should load a real adapter or be removed; confirm intent.
    #[test]
    #[serial]
    fn lora_adapter_set_with_dangling_pointer_succeeds_or_errors() {
        let (backend, model) = test_model::load_default_model().unwrap();
        let ctx_params = LlamaContextParams::default().with_n_ctx(std::num::NonZeroU32::new(512));
        let context = model.new_context(&backend, ctx_params).unwrap();
        let mut adapter = crate::model::LlamaLoraAdapter {
            lora_adapter: std::ptr::NonNull::dangling(),
        };

        let result = context.lora_adapter_set(&mut adapter, 1.0);

        assert!(result.is_ok());
    }
780
    // A token index with no stored embedding should surface an error rather
    // than a bogus slice.
    #[test]
    #[serial]
    fn embeddings_ith_returns_null_embedding_error_for_non_embedding_token() {
        let (backend, model) = test_model::load_default_embedding_model().unwrap();
        let ctx_params = LlamaContextParams::default()
            .with_n_ctx(std::num::NonZeroU32::new(512))
            .with_embeddings(true);
        let context = model.new_context(&backend, ctx_params).unwrap();

        let result = context.embeddings_ith(999);

        assert!(result.is_err());
    }

    // An out-of-range sequence id should produce an error even after a
    // successful decode.
    #[test]
    #[serial]
    fn embeddings_seq_ith_returns_null_embedding_error_for_invalid_seq() {
        let (backend, model) = test_model::load_default_model().unwrap();
        let ctx_params = LlamaContextParams::default()
            .with_n_ctx(std::num::NonZeroU32::new(512))
            .with_embeddings(true);
        let mut context = model.new_context(&backend, ctx_params).unwrap();
        let tokens = model.str_to_token("hello", AddBos::Always).unwrap();
        let mut batch = LlamaBatch::new(512, 1).unwrap();
        batch.add_sequence(&tokens, 0, false).unwrap();
        context.decode(&mut batch).unwrap();

        let result = context.embeddings_seq_ith(999);

        assert!(result.is_err());
    }

    // Decoding a batch with no tokens must be rejected.
    #[test]
    #[serial]
    fn decode_empty_batch_returns_error() {
        let (backend, model) = test_model::load_default_model().unwrap();
        let ctx_params = LlamaContextParams::default().with_n_ctx(std::num::NonZeroU32::new(512));
        let mut context = model.new_context(&backend, ctx_params).unwrap();
        let mut batch = LlamaBatch::new(512, 1).unwrap();

        let result = context.decode(&mut batch);

        assert!(result.is_err());
    }
825
    // encode() should succeed on a genuine encoder model.
    #[test]
    #[serial]
    fn encode_succeeds_with_encoder_model() {
        let backend = crate::llama_backend::LlamaBackend::init().unwrap();
        let model_path = test_model::download_encoder_model().unwrap();
        let model_params = crate::model::params::LlamaModelParams::default();
        let model =
            crate::model::LlamaModel::load_from_file(&backend, &model_path, &model_params).unwrap();
        let ctx_params = LlamaContextParams::default()
            .with_n_ctx(std::num::NonZeroU32::new(512))
            .with_embeddings(true);
        let mut context = model.new_context(&backend, ctx_params).unwrap();
        let tokens = model.str_to_token("hello", AddBos::Never).unwrap();
        let mut batch = LlamaBatch::new(512, 1).unwrap();
        batch.add_sequence(&tokens, 0, false).unwrap();

        let result = context.encode(&mut batch);

        assert!(result.is_ok());
    }

    // An OK status must copy the batch's initialized logit indices onto the
    // context.
    #[test]
    #[serial]
    fn handle_encode_result_ok_updates_logits() {
        let (backend, model) = test_model::load_default_model().unwrap();
        let ctx_params = LlamaContextParams::default().with_n_ctx(std::num::NonZeroU32::new(512));
        let mut context = model.new_context(&backend, ctx_params).unwrap();
        let tokens = model.str_to_token("hello", AddBos::Always).unwrap();
        let mut batch = LlamaBatch::new(512, 1).unwrap();
        batch.add_sequence(&tokens, 0, true).unwrap();

        let result =
            context.handle_encode_result(llama_cpp_bindings_sys::LLAMA_RS_STATUS_OK, &mut batch);

        assert!(result.is_ok());
        assert!(!context.initialized_logits.is_empty());
    }
863
864    #[test]
865    #[serial]
866    fn set_abort_flag_aborts_decode() {
867        use std::sync::Arc;
868        use std::sync::atomic::AtomicBool;
869        use std::sync::atomic::Ordering;
870
871        use crate::DecodeError;
872
873        let (backend, model) = test_model::load_default_model().unwrap();
874        let ctx_params = LlamaContextParams::default().with_n_ctx(std::num::NonZeroU32::new(512));
875        let mut context = model.new_context(&backend, ctx_params).unwrap();
876        let abort_flag = Arc::new(AtomicBool::new(true));
877        context.set_abort_flag(abort_flag.clone());
878
879        let tokens = model.str_to_token("hello", AddBos::Always).unwrap();
880        let mut batch = LlamaBatch::new(512, 1).unwrap();
881        batch.add_sequence(&tokens, 0, false).unwrap();
882
883        let result = context.decode(&mut batch);
884
885        assert_eq!(result, Err(DecodeError::Aborted));
886    }
887
888    #[test]
889    #[serial]
890    fn set_abort_flag_false_allows_decode() {
891        use std::sync::Arc;
892        use std::sync::atomic::AtomicBool;
893
894        let (backend, model) = test_model::load_default_model().unwrap();
895        let ctx_params = LlamaContextParams::default().with_n_ctx(std::num::NonZeroU32::new(512));
896        let mut context = model.new_context(&backend, ctx_params).unwrap();
897        let abort_flag = Arc::new(AtomicBool::new(false));
898        context.set_abort_flag(abort_flag);
899
900        let tokens = model.str_to_token("hello", AddBos::Always).unwrap();
901        let mut batch = LlamaBatch::new(512, 1).unwrap();
902        batch.add_sequence(&tokens, 0, false).unwrap();
903
904        let result = context.decode(&mut batch);
905
906        assert!(result.is_ok());
907    }
908
909    #[test]
910    #[serial]
911    fn clear_abort_callback_allows_decode_with_flag_true() {
912        use std::sync::Arc;
913        use std::sync::atomic::AtomicBool;
914
915        let (backend, model) = test_model::load_default_model().unwrap();
916        let ctx_params = LlamaContextParams::default().with_n_ctx(std::num::NonZeroU32::new(512));
917        let mut context = model.new_context(&backend, ctx_params).unwrap();
918        let abort_flag = Arc::new(AtomicBool::new(true));
919        context.set_abort_flag(abort_flag);
920        context.clear_abort_callback();
921
922        let tokens = model.str_to_token("hello", AddBos::Always).unwrap();
923        let mut batch = LlamaBatch::new(512, 1).unwrap();
924        batch.add_sequence(&tokens, 0, false).unwrap();
925
926        let result = context.decode(&mut batch);
927
928        assert!(result.is_ok());
929    }
930}