Skip to main content

llama_cpp_bindings/
context.rs

1//! Safe wrapper around `llama_context`.
2
3use std::ffi::c_void;
4use std::fmt::{Debug, Formatter};
5use std::num::NonZeroI32;
6use std::ptr::NonNull;
7use std::slice;
8use std::sync::Arc;
9use std::sync::atomic::AtomicBool;
10use std::sync::atomic::Ordering;
11
12use crate::llama_batch::LlamaBatch;
13use crate::model::{LlamaLoraAdapter, LlamaModel};
14use crate::timing::LlamaTimings;
15use crate::token::LlamaToken;
16use crate::token::data::LlamaTokenData;
17use crate::token::data_array::LlamaTokenDataArray;
18use crate::{
19    DecodeError, EmbeddingsError, EncodeError, LlamaLoraAdapterRemoveError,
20    LlamaLoraAdapterSetError, LogitsError,
21};
22
23const fn check_lora_set_result(err_code: i32) -> Result<(), LlamaLoraAdapterSetError> {
24    if err_code != 0 {
25        return Err(LlamaLoraAdapterSetError::ErrorResult(err_code));
26    }
27
28    Ok(())
29}
30
31const fn check_lora_remove_result(err_code: i32) -> Result<(), LlamaLoraAdapterRemoveError> {
32    if err_code != 0 {
33        return Err(LlamaLoraAdapterRemoveError::ErrorResult(err_code));
34    }
35
36    Ok(())
37}
38
39pub mod kv_cache;
40pub mod llama_state_seq_flags;
41pub mod load_seq_state_error;
42pub mod load_session_error;
43pub mod params;
44pub mod save_seq_state_error;
45pub mod save_session_error;
46pub mod session;
47
/// C-ABI callback handed to `llama_set_abort_callback`.
///
/// `data` is expected to be the pointer obtained from an `AtomicBool`
/// (see `set_abort_flag`). Returns the current value of that flag; a null
/// `data` pointer is treated as "do not abort" instead of being dereferenced.
///
/// # Safety
///
/// A non-null `data` must point to a live `AtomicBool` for the duration of
/// the call.
unsafe extern "C" fn abort_callback_trampoline(data: *mut c_void) -> bool {
    // SAFETY: `as_ref` handles the null case; for non-null pointers the caller
    // guarantees `data` points to a live `AtomicBool`.
    let flag = unsafe { data.cast::<AtomicBool>().as_ref() };

    match flag {
        Some(flag) => flag.load(Ordering::Relaxed),
        None => false,
    }
}
53
54/// Safe wrapper around `llama_context`.
55pub struct LlamaContext<'model> {
56    /// Raw pointer to the underlying `llama_context`.
57    pub context: NonNull<llama_cpp_bindings_sys::llama_context>,
58    /// A reference to the context's model.
59    pub model: &'model LlamaModel,
60    abort_flag: Option<Arc<AtomicBool>>,
61    initialized_logits: Vec<i32>,
62    embeddings_enabled: bool,
63}
64
65impl Debug for LlamaContext<'_> {
66    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
67        f.debug_struct("LlamaContext")
68            .field("context", &self.context)
69            .finish()
70    }
71}
72
73impl<'model> LlamaContext<'model> {
74    /// Wraps existing raw pointers into a new `LlamaContext`.
75    #[must_use]
76    pub const fn new(
77        llama_model: &'model LlamaModel,
78        llama_context: NonNull<llama_cpp_bindings_sys::llama_context>,
79        embeddings_enabled: bool,
80    ) -> Self {
81        Self {
82            context: llama_context,
83            model: llama_model,
84            abort_flag: None,
85            initialized_logits: Vec::new(),
86            embeddings_enabled,
87        }
88    }
89
90    /// Gets the max number of logical tokens that can be submitted to decode. Must be greater than or equal to [`Self::n_ubatch`].
91    #[must_use]
92    pub fn n_batch(&self) -> u32 {
93        unsafe { llama_cpp_bindings_sys::llama_n_batch(self.context.as_ptr()) }
94    }
95
96    /// Gets the max number of physical tokens (hardware level) to decode in batch. Must be less than or equal to [`Self::n_batch`].
97    #[must_use]
98    pub fn n_ubatch(&self) -> u32 {
99        unsafe { llama_cpp_bindings_sys::llama_n_ubatch(self.context.as_ptr()) }
100    }
101
102    /// Gets the size of the context.
103    #[must_use]
104    pub fn n_ctx(&self) -> u32 {
105        unsafe { llama_cpp_bindings_sys::llama_n_ctx(self.context.as_ptr()) }
106    }
107
108    /// Sets an abort flag that llama.cpp checks during computation.
109    ///
110    /// When the flag is set to `true`, any in-progress `decode()` call will
111    /// abort and return `DecodeError::Aborted`. The `Arc` is stored internally
112    /// to ensure the flag outlives the callback registration.
113    #[expect(unsafe_code, reason = "required for FFI abort callback registration")]
114    pub fn set_abort_flag(&mut self, flag: Arc<AtomicBool>) {
115        let raw_ptr = Arc::as_ptr(&flag) as *mut c_void;
116        self.abort_flag = Some(flag);
117
118        unsafe {
119            llama_cpp_bindings_sys::llama_set_abort_callback(
120                self.context.as_ptr(),
121                Some(abort_callback_trampoline),
122                raw_ptr,
123            );
124        }
125    }
126
127    /// Clears the abort callback so that decode calls are no longer interruptible.
128    #[expect(unsafe_code, reason = "required for FFI abort callback deregistration")]
129    pub fn clear_abort_callback(&mut self) {
130        self.abort_flag = None;
131
132        unsafe {
133            llama_cpp_bindings_sys::llama_set_abort_callback(
134                self.context.as_ptr(),
135                None,
136                std::ptr::null_mut(),
137            );
138        }
139    }
140
141    /// Waits for all pending backend operations to complete.
142    ///
143    /// Must be called before freeing the context to prevent hangs
144    /// during resource cleanup.
145    #[expect(unsafe_code, reason = "required for FFI synchronization call")]
146    pub fn synchronize(&self) {
147        unsafe { llama_cpp_bindings_sys::llama_synchronize(self.context.as_ptr()) }
148    }
149
150    /// Detaches the threadpool from the context.
151    ///
152    /// Must be called before freeing the context to prevent threadpool
153    /// workers from accessing freed resources.
154    #[expect(unsafe_code, reason = "required for FFI threadpool detachment")]
155    pub fn detach_threadpool(&self) {
156        unsafe { llama_cpp_bindings_sys::llama_detach_threadpool(self.context.as_ptr()) }
157    }
158
159    /// Marks a logit index as initialized so it can be read via
160    /// `get_logits_ith`. Use after external decode operations (like
161    /// `eval_chunks`) that bypass the Rust `decode()` method.
162    pub fn mark_logits_initialized(&mut self, token_index: i32) {
163        self.initialized_logits = vec![token_index];
164    }
165
166    /// Decodes the batch.
167    ///
168    /// # Errors
169    ///
170    /// - `DecodeError` if the decoding failed.
171    ///
172    /// # Panics
173    ///
174    /// - the returned [`std::ffi::c_int`] from llama-cpp does not fit into a i32 (this should never happen on most systems)
175    pub fn decode(&mut self, batch: &mut LlamaBatch) -> Result<(), DecodeError> {
176        let result = unsafe {
177            llama_cpp_bindings_sys::llama_decode(self.context.as_ptr(), batch.llama_batch)
178        };
179
180        match NonZeroI32::new(result) {
181            None => {
182                self.initialized_logits
183                    .clone_from(&batch.initialized_logits);
184                Ok(())
185            }
186            Some(error) => Err(DecodeError::from(error)),
187        }
188    }
189
190    /// Encodes the batch.
191    ///
192    /// # Errors
193    ///
194    /// - `EncodeError` if the decoding failed.
195    ///
196    /// # Panics
197    ///
198    /// - the returned [`std::ffi::c_int`] from llama-cpp does not fit into a i32 (this should never happen on most systems)
199    pub fn encode(&mut self, batch: &mut LlamaBatch) -> Result<(), EncodeError> {
200        let status = unsafe {
201            llama_cpp_bindings_sys::llama_rs_encode(self.context.as_ptr(), batch.llama_batch)
202        };
203
204        self.handle_encode_result(status, batch)
205    }
206
207    fn handle_encode_result(
208        &mut self,
209        status: llama_cpp_bindings_sys::llama_rs_status,
210        batch: &mut LlamaBatch,
211    ) -> Result<(), EncodeError> {
212        if crate::status_is_ok(status) {
213            self.initialized_logits
214                .clone_from(&batch.initialized_logits);
215
216            Ok(())
217        } else {
218            Err(EncodeError::from(
219                NonZeroI32::new(crate::status_to_i32(status))
220                    .unwrap_or(NonZeroI32::new(1).expect("1 is non-zero")),
221            ))
222        }
223    }
224
225    /// Get the embeddings for the given sequence in the current context.
226    ///
227    /// # Returns
228    ///
229    /// A slice containing the embeddings for the last decoded batch.
230    /// The size corresponds to the `n_embd` parameter of the context's model.
231    ///
232    /// # Errors
233    ///
234    /// - When the current context was constructed without enabling embeddings.
235    /// - If the current model had a pooling type of [`llama_cpp_bindings_sys::LLAMA_POOLING_TYPE_NONE`]
236    /// - If the given sequence index exceeds the max sequence id.
237    ///
238    pub fn embeddings_seq_ith(&self, sequence_index: i32) -> Result<&[f32], EmbeddingsError> {
239        if !self.embeddings_enabled {
240            return Err(EmbeddingsError::NotEnabled);
241        }
242
243        let n_embd = usize::try_from(self.model.n_embd())
244            .map_err(EmbeddingsError::InvalidEmbeddingDimension)?;
245
246        unsafe {
247            let embedding = llama_cpp_bindings_sys::llama_get_embeddings_seq(
248                self.context.as_ptr(),
249                sequence_index,
250            );
251
252            if embedding.is_null() {
253                Err(EmbeddingsError::NonePoolType)
254            } else {
255                Ok(slice::from_raw_parts(embedding, n_embd))
256            }
257        }
258    }
259
260    /// Get the embeddings for the given token in the current context.
261    ///
262    /// # Returns
263    ///
264    /// A slice containing the embeddings for the last decoded batch of the given token.
265    /// The size corresponds to the `n_embd` parameter of the context's model.
266    ///
267    /// # Errors
268    ///
269    /// - When the current context was constructed without enabling embeddings.
270    /// - When the given token didn't have logits enabled when it was passed.
271    /// - If the given token index exceeds the max token id.
272    ///
273    pub fn embeddings_ith(&self, token_index: i32) -> Result<&[f32], EmbeddingsError> {
274        if !self.embeddings_enabled {
275            return Err(EmbeddingsError::NotEnabled);
276        }
277
278        let n_embd = usize::try_from(self.model.n_embd())
279            .map_err(EmbeddingsError::InvalidEmbeddingDimension)?;
280
281        unsafe {
282            let embedding = llama_cpp_bindings_sys::llama_get_embeddings_ith(
283                self.context.as_ptr(),
284                token_index,
285            );
286
287            if embedding.is_null() {
288                Err(EmbeddingsError::LogitsNotEnabled)
289            } else {
290                Ok(slice::from_raw_parts(embedding, n_embd))
291            }
292        }
293    }
294
295    /// Get the logits for the last token in the context.
296    ///
297    /// # Returns
298    /// An iterator over unsorted `LlamaTokenData` containing the
299    /// logits for the last token in the context.
300    ///
301    /// # Errors
302    /// Returns `LogitsError` if logits are null or `n_vocab` overflows.
303    pub fn candidates(&self) -> Result<impl Iterator<Item = LlamaTokenData> + '_, LogitsError> {
304        let logits = self.get_logits()?;
305
306        Ok((0_i32..).zip(logits).map(|(token_id, logit)| {
307            let token = LlamaToken::new(token_id);
308            LlamaTokenData::new(token, *logit, 0_f32)
309        }))
310    }
311
312    /// Get the token data array for the last token in the context.
313    ///
314    /// # Errors
315    /// Returns `LogitsError` if logits are null or `n_vocab` overflows.
316    pub fn token_data_array(&self) -> Result<LlamaTokenDataArray, LogitsError> {
317        Ok(LlamaTokenDataArray::from_iter(self.candidates()?, false))
318    }
319
320    /// Token logits obtained from the last call to `decode()`.
321    /// The logits for which `batch.logits[i] != 0` are stored contiguously
322    /// in the order they have appeared in the batch.
323    /// Rows: number of tokens for which `batch.logits[i] != 0`
324    /// Cols: `n_vocab`
325    ///
326    /// # Returns
327    ///
328    /// A slice containing the logits for the last decoded token.
329    /// The size corresponds to the `n_vocab` parameter of the context's model.
330    ///
331    /// # Errors
332    /// Returns `LogitsError` if the logits pointer is null or `n_vocab` overflows.
333    pub fn get_logits(&self) -> Result<&[f32], LogitsError> {
334        let data = unsafe { llama_cpp_bindings_sys::llama_get_logits(self.context.as_ptr()) };
335
336        if data.is_null() {
337            return Err(LogitsError::NullLogits);
338        }
339
340        let len = usize::try_from(self.model.n_vocab()).map_err(LogitsError::VocabSizeOverflow)?;
341
342        Ok(unsafe { slice::from_raw_parts(data, len) })
343    }
344
345    /// Get the logits for the ith token in the context.
346    ///
347    /// # Errors
348    /// Returns `LogitsError` if the token is not initialized or out of range.
349    pub fn candidates_ith(
350        &self,
351        token_index: i32,
352    ) -> Result<impl Iterator<Item = LlamaTokenData> + '_, LogitsError> {
353        let logits = self.get_logits_ith(token_index)?;
354
355        Ok((0_i32..).zip(logits).map(|(token_id, logit)| {
356            let token = LlamaToken::new(token_id);
357            LlamaTokenData::new(token, *logit, 0_f32)
358        }))
359    }
360
361    /// Get the token data array for the ith token in the context.
362    ///
363    /// # Errors
364    /// Returns `LogitsError` if the token is not initialized or out of range.
365    pub fn token_data_array_ith(
366        &self,
367        token_index: i32,
368    ) -> Result<LlamaTokenDataArray, LogitsError> {
369        Ok(LlamaTokenDataArray::from_iter(
370            self.candidates_ith(token_index)?,
371            false,
372        ))
373    }
374
375    /// Get the logits for the ith token in the context.
376    ///
377    /// # Errors
378    /// Returns `LogitsError` if the token is not initialized, out of range, or `n_vocab` overflows.
379    pub fn get_logits_ith(&self, token_index: i32) -> Result<&[f32], LogitsError> {
380        if !self.initialized_logits.contains(&token_index) {
381            return Err(LogitsError::TokenNotInitialized(token_index));
382        }
383
384        if token_index >= 0 {
385            let token_index_u32 =
386                u32::try_from(token_index).map_err(LogitsError::TokenIndexOverflow)?;
387
388            if self.n_ctx() <= token_index_u32 {
389                return Err(LogitsError::TokenIndexExceedsContext {
390                    token_index: token_index_u32,
391                    context_size: self.n_ctx(),
392                });
393            }
394        }
395
396        let data = unsafe {
397            llama_cpp_bindings_sys::llama_get_logits_ith(self.context.as_ptr(), token_index)
398        };
399        let len = usize::try_from(self.model.n_vocab()).map_err(LogitsError::VocabSizeOverflow)?;
400
401        Ok(unsafe { slice::from_raw_parts(data, len) })
402    }
403
404    /// Reset the timings for the context.
405    pub fn reset_timings(&mut self) {
406        unsafe { llama_cpp_bindings_sys::llama_perf_context_reset(self.context.as_ptr()) }
407    }
408
409    /// Returns the timings for the context.
410    pub fn timings(&mut self) -> LlamaTimings {
411        let timings = unsafe { llama_cpp_bindings_sys::llama_perf_context(self.context.as_ptr()) };
412        LlamaTimings { timings }
413    }
414
415    /// Sets a lora adapter.
416    ///
417    /// # Errors
418    ///
419    /// See [`LlamaLoraAdapterSetError`] for more information.
420    pub fn lora_adapter_set(
421        &self,
422        adapter: &mut LlamaLoraAdapter,
423        scale: f32,
424    ) -> Result<(), LlamaLoraAdapterSetError> {
425        let mut adapters = [adapter.lora_adapter.as_ptr()];
426        let mut scales = [scale];
427        let err_code = unsafe {
428            llama_cpp_bindings_sys::llama_set_adapters_lora(
429                self.context.as_ptr(),
430                adapters.as_mut_ptr(),
431                1,
432                scales.as_mut_ptr(),
433            )
434        };
435        check_lora_set_result(err_code)?;
436
437        tracing::debug!("Set lora adapter");
438        Ok(())
439    }
440
441    /// Remove all lora adapters.
442    ///
443    /// Note: The upstream API now replaces all adapters at once via
444    /// `llama_set_adapters_lora`. This clears all adapters from the context.
445    ///
446    /// # Errors
447    ///
448    /// See [`LlamaLoraAdapterRemoveError`] for more information.
449    pub fn lora_adapter_remove(
450        &self,
451        _adapter: &mut LlamaLoraAdapter,
452    ) -> Result<(), LlamaLoraAdapterRemoveError> {
453        let err_code = unsafe {
454            llama_cpp_bindings_sys::llama_set_adapters_lora(
455                self.context.as_ptr(),
456                std::ptr::null_mut(),
457                0,
458                std::ptr::null_mut(),
459            )
460        };
461        check_lora_remove_result(err_code)?;
462
463        tracing::debug!("Remove lora adapter");
464        Ok(())
465    }
466
467    /// Print a breakdown of per-device memory use to the default logger.
468    pub fn print_memory_breakdown(&self) {
469        unsafe { llama_cpp_bindings_sys::llama_memory_breakdown_print(self.context.as_ptr()) }
470    }
471}
472
473impl Drop for LlamaContext<'_> {
474    fn drop(&mut self) {
475        unsafe { llama_cpp_bindings_sys::llama_free(self.context.as_ptr()) }
476    }
477}
478
#[cfg(test)]
mod unit_tests {
    //! Model-free tests for the lora status-code helpers.

    use super::{check_lora_remove_result, check_lora_set_result};
    use crate::{LlamaLoraAdapterRemoveError, LlamaLoraAdapterSetError};

    #[test]
    fn check_lora_set_result_ok_for_zero() {
        assert_eq!(check_lora_set_result(0), Ok(()));
    }

    #[test]
    fn check_lora_set_result_error_for_nonzero() {
        let expected = Err(LlamaLoraAdapterSetError::ErrorResult(-1));

        assert_eq!(check_lora_set_result(-1), expected);
    }

    #[test]
    fn check_lora_remove_result_ok_for_zero() {
        assert_eq!(check_lora_remove_result(0), Ok(()));
    }

    #[test]
    fn check_lora_remove_result_error_for_nonzero() {
        let expected = Err(LlamaLoraAdapterRemoveError::ErrorResult(-1));

        assert_eq!(check_lora_remove_result(-1), expected);
    }
}
510
511#[cfg(test)]
512#[cfg(feature = "tests_that_use_llms")]
513mod tests {
514    use serial_test::serial;
515
516    use crate::context::params::LlamaContextParams;
517    use crate::llama_batch::LlamaBatch;
518    use crate::model::AddBos;
519    use crate::test_model;
520
/// A freshly created context reports non-zero context and batch sizes.
#[test]
#[serial]
fn context_creation_and_properties() {
    let (backend, model) = test_model::load_default_model().unwrap();
    let requested_n_ctx = std::num::NonZeroU32::new(512);
    let params = LlamaContextParams::default().with_n_ctx(requested_n_ctx);

    let context = model.new_context(&backend, params).unwrap();

    assert_ne!(context.n_ctx(), 0);
    assert_ne!(context.n_batch(), 0);
    assert_ne!(context.n_ubatch(), 0);
}
531
/// Decoding a short prompt succeeds and leaves non-empty logits behind.
#[test]
#[serial]
fn decode_and_get_logits() {
    let (backend, model) = test_model::load_default_model().unwrap();
    let requested_n_ctx = std::num::NonZeroU32::new(512);
    let params = LlamaContextParams::default().with_n_ctx(requested_n_ctx);
    let mut context = model.new_context(&backend, params).unwrap();

    let prompt_tokens = model.str_to_token("hello", AddBos::Always).unwrap();
    let mut batch = LlamaBatch::new(512, 1).unwrap();
    batch.add_sequence(&prompt_tokens, 0, false).unwrap();

    assert!(context.decode(&mut batch).is_ok());

    let logits = context.get_logits().unwrap();
    assert!(!logits.is_empty());
}
548
/// Timings can be reset and read back; the start timestamp is never negative.
#[test]
#[serial]
fn timings_work() {
    let (backend, model) = test_model::load_default_model().unwrap();
    let requested_n_ctx = std::num::NonZeroU32::new(512);
    let params = LlamaContextParams::default().with_n_ctx(requested_n_ctx);
    let mut context = model.new_context(&backend, params).unwrap();

    context.reset_timings();

    let start_ms = context.timings().t_start_ms();
    assert!(start_ms >= 0.0);
}
559
/// After a decode, `token_data_array` yields at least one entry.
#[test]
#[serial]
fn token_data_array_has_entries_after_decode() {
    let (backend, model) = test_model::load_default_model().unwrap();
    let requested_n_ctx = std::num::NonZeroU32::new(512);
    let params = LlamaContextParams::default().with_n_ctx(requested_n_ctx);
    let mut context = model.new_context(&backend, params).unwrap();

    let prompt_tokens = model.str_to_token("hello", AddBos::Always).unwrap();
    let mut batch = LlamaBatch::new(512, 1).unwrap();
    batch.add_sequence(&prompt_tokens, 0, false).unwrap();
    context.decode(&mut batch).unwrap();

    let entries = context.token_data_array().unwrap().data;

    assert!(!entries.is_empty());
}
575
/// `get_logits_ith` for the last prompt token returns `n_vocab` logits.
#[test]
#[serial]
fn get_logits_ith_returns_valid_slice() {
    let (backend, model) = test_model::load_default_model().unwrap();
    let requested_n_ctx = std::num::NonZeroU32::new(512);
    let params = LlamaContextParams::default().with_n_ctx(requested_n_ctx);
    let mut context = model.new_context(&backend, params).unwrap();

    let prompt_tokens = model.str_to_token("hello", AddBos::Always).unwrap();
    let final_position = i32::try_from(prompt_tokens.len() - 1).unwrap();
    let mut batch = LlamaBatch::new(512, 1).unwrap();
    batch.add_sequence(&prompt_tokens, 0, false).unwrap();
    context.decode(&mut batch).unwrap();

    let logit_count = context.get_logits_ith(final_position).unwrap().len();

    assert_eq!(logit_count, usize::try_from(model.n_vocab()).unwrap());
}
592
/// `token_data_array_ith` for the last prompt token holds `n_vocab` entries.
#[test]
#[serial]
fn token_data_array_ith_returns_valid_data() {
    let (backend, model) = test_model::load_default_model().unwrap();
    let requested_n_ctx = std::num::NonZeroU32::new(512);
    let params = LlamaContextParams::default().with_n_ctx(requested_n_ctx);
    let mut context = model.new_context(&backend, params).unwrap();

    let prompt_tokens = model.str_to_token("hello", AddBos::Always).unwrap();
    let final_position = i32::try_from(prompt_tokens.len() - 1).unwrap();
    let mut batch = LlamaBatch::new(512, 1).unwrap();
    batch.add_sequence(&prompt_tokens, 0, false).unwrap();
    context.decode(&mut batch).unwrap();

    let entries = context.token_data_array_ith(final_position).unwrap().data;

    assert_eq!(entries.len(), usize::try_from(model.n_vocab()).unwrap());
}
609
/// With embeddings disabled, `embeddings_ith` must refuse with
/// `EmbeddingsError::NotEnabled` specifically — not merely with "some error".
#[test]
#[serial]
fn embeddings_ith_returns_error_when_embeddings_disabled() {
    let (backend, model) = test_model::load_default_model().unwrap();
    let ctx_params = LlamaContextParams::default()
        .with_n_ctx(std::num::NonZeroU32::new(512))
        .with_embeddings(false);
    let context = model.new_context(&backend, ctx_params).unwrap();

    let result = context.embeddings_ith(0);

    assert!(matches!(result, Err(crate::EmbeddingsError::NotEnabled)));
}
623
/// With embeddings disabled, `embeddings_seq_ith` must refuse with
/// `EmbeddingsError::NotEnabled` specifically — not merely with "some error".
#[test]
#[serial]
fn embeddings_seq_ith_returns_error_when_embeddings_disabled() {
    let (backend, model) = test_model::load_default_model().unwrap();
    let ctx_params = LlamaContextParams::default()
        .with_n_ctx(std::num::NonZeroU32::new(512))
        .with_embeddings(false);
    let context = model.new_context(&backend, ctx_params).unwrap();

    let result = context.embeddings_seq_ith(0);

    assert!(matches!(result, Err(crate::EmbeddingsError::NotEnabled)));
}
637
/// `candidates` yields exactly one entry per vocabulary token.
#[test]
#[serial]
fn candidates_returns_n_vocab_entries() {
    let (backend, model) = test_model::load_default_model().unwrap();
    let requested_n_ctx = std::num::NonZeroU32::new(512);
    let params = LlamaContextParams::default().with_n_ctx(requested_n_ctx);
    let mut context = model.new_context(&backend, params).unwrap();

    let prompt_tokens = model.str_to_token("hello", AddBos::Always).unwrap();
    let mut batch = LlamaBatch::new(512, 1).unwrap();
    batch.add_sequence(&prompt_tokens, 0, false).unwrap();
    context.decode(&mut batch).unwrap();

    let candidate_count = context.candidates().unwrap().count();

    assert_eq!(candidate_count, usize::try_from(model.n_vocab()).unwrap());
}
653
/// The `Debug` output of a context mentions the struct name.
#[test]
#[serial]
fn debug_format_contains_struct_name() {
    let (backend, model) = test_model::load_default_model().unwrap();
    let requested_n_ctx = std::num::NonZeroU32::new(512);
    let params = LlamaContextParams::default().with_n_ctx(requested_n_ctx);
    let context = model.new_context(&backend, params).unwrap();

    let rendered = format!("{context:?}");

    assert!(rendered.contains("LlamaContext"));
}
664
/// Decoding succeeds on an embedding model with embeddings switched on.
#[test]
#[serial]
fn decode_with_embeddings_enabled() {
    let (backend, model) = test_model::load_default_embedding_model().unwrap();
    let requested_n_ctx = std::num::NonZeroU32::new(512);
    let params = LlamaContextParams::default()
        .with_n_ctx(requested_n_ctx)
        .with_embeddings(true);
    let mut context = model.new_context(&backend, params).unwrap();

    let prompt_tokens = model.str_to_token("hello", AddBos::Always).unwrap();
    let mut batch = LlamaBatch::new(512, 1).unwrap();
    batch.add_sequence(&prompt_tokens, 0, false).unwrap();

    let outcome = context.decode(&mut batch);

    assert!(outcome.is_ok());
}
681
/// The sequence embedding of a decoded prompt has `n_embd` components.
#[test]
#[serial]
fn embeddings_seq_ith_returns_valid_embeddings() {
    let (backend, model) = test_model::load_default_embedding_model().unwrap();
    let requested_n_ctx = std::num::NonZeroU32::new(512);
    let params = LlamaContextParams::default()
        .with_n_ctx(requested_n_ctx)
        .with_embeddings(true);
    let mut context = model.new_context(&backend, params).unwrap();

    let prompt_tokens = model.str_to_token("hello", AddBos::Always).unwrap();
    let mut batch = LlamaBatch::new(512, 1).unwrap();
    batch.add_sequence(&prompt_tokens, 0, false).unwrap();
    context.decode(&mut batch).unwrap();

    let embedding_len = context.embeddings_seq_ith(0).unwrap().len();

    assert_eq!(embedding_len, usize::try_from(model.n_embd()).unwrap());
}
699
/// Four texts decoded as four sequences of one batch each get their own
/// embedding of length `n_embd`, and no two embeddings are equal.
#[test]
#[serial]
fn multi_sequence_embeddings_returns_one_embedding_per_sequence() {
    let (backend, model) = test_model::load_default_embedding_model().unwrap();
    let params = LlamaContextParams::default()
        .with_n_ctx(std::num::NonZeroU32::new(512))
        .with_n_seq_max(4)
        .with_embeddings(true);
    let mut context = model.new_context(&backend, params).unwrap();

    let inputs = [
        "alpha is here",
        "beta runs fast",
        "gamma waits",
        "delta jumps",
    ];

    // Queue every input as its own sequence, then decode them together.
    let mut batch = LlamaBatch::new(64, 4).unwrap();
    for (sequence_index, text) in inputs.iter().enumerate() {
        let tokens = model.str_to_token(text, AddBos::Always).unwrap();
        let sequence_id = i32::try_from(sequence_index).unwrap();

        batch.add_sequence(&tokens, sequence_id, true).unwrap();
    }
    context.decode(&mut batch).unwrap();

    let n_embd = usize::try_from(model.n_embd()).unwrap();
    let collected: Vec<Vec<f32>> = (0..inputs.len())
        .map(|sequence_index| {
            let sequence_id = i32::try_from(sequence_index).unwrap();
            let embedding = context.embeddings_seq_ith(sequence_id).unwrap();

            assert_eq!(
                embedding.len(),
                n_embd,
                "sequence {sequence_index} embedding length mismatch"
            );

            embedding.to_vec()
        })
        .collect();

    // Compare every unordered pair exactly once.
    for (left_index, left) in collected.iter().enumerate() {
        let first_right = left_index + 1;

        for (offset, right) in collected[first_right..].iter().enumerate() {
            let right_index = first_right + offset;

            assert_ne!(
                left, right,
                "embedding for sequence {left_index} must differ from sequence {right_index}",
            );
        }
    }
}
752
/// Mirrors paddler's embedding batching loop with the document strings, batch shape, and
/// iteration pattern of the harness test
/// `agent_embedding_batch_distribution_independent_of_context_size`.
///
/// One `LlamaBatch` (`n_tokens=64`, `n_seq_max=4`) is allocated up front and reused for two
/// rounds of two documents each. Every round runs `add_sequence` per document,
/// `clear_kv_cache`, `decode`, `embeddings_seq_ith` per filled slot and finally
/// `batch.clear()`. All embeddings gathered over both rounds must have length `n_embd` and
/// be pairwise distinct — including those from the round after the first.
#[test]
#[serial]
fn embeddings_returns_distinct_values_when_reused_batch_has_extra_capacity() {
    let (backend, model) = test_model::load_default_embedding_model().unwrap();
    let params = LlamaContextParams::default()
        .with_n_ctx(std::num::NonZeroU32::new(512))
        .with_n_seq_max(4)
        .with_embeddings(true);
    let mut context = model.new_context(&backend, params).unwrap();

    let iterations = [
        [
            "This is the first document with enough content to contribute meaningfully to the batch size calculation",
            "This is the second document that should be processed in a potentially different batch from the first",
        ],
        [
            "This is the third document adding more content to ensure the total exceeds the configured chunk limit",
            "This is the fourth document which should demonstrate that batching distributes across agent requests",
        ],
    ];

    let n_embd = usize::try_from(model.n_embd()).unwrap();
    let mut collected: Vec<Vec<f32>> = Vec::new();
    let mut batch = LlamaBatch::new(64, 4).unwrap();

    for iteration_inputs in iterations {
        for (sequence_index, text) in iteration_inputs.iter().enumerate() {
            let tokens = model.str_to_token(text, AddBos::Always).unwrap();
            let sequence_id = i32::try_from(sequence_index).unwrap();

            batch.add_sequence(&tokens, sequence_id, true).unwrap();
        }

        context.clear_kv_cache();
        context.decode(&mut batch).unwrap();

        collected.extend((0..iteration_inputs.len()).map(|sequence_index| {
            let sequence_id = i32::try_from(sequence_index).unwrap();
            let embedding = context.embeddings_seq_ith(sequence_id).unwrap();

            assert_eq!(
                embedding.len(),
                n_embd,
                "iteration sequence {sequence_index} embedding length mismatch"
            );

            embedding.to_vec()
        }));

        // Reuse the same allocation for the next round.
        batch.clear();
    }

    let expected_total = iterations.iter().flatten().count();
    assert_eq!(
        collected.len(),
        expected_total,
        "expected one embedding per input across every iteration"
    );

    // Compare every unordered pair exactly once.
    for (left_index, left) in collected.iter().enumerate() {
        let first_right = left_index + 1;

        for (offset, right) in collected[first_right..].iter().enumerate() {
            let right_index = first_right + offset;

            assert_ne!(
                left, right,
                "embedding {left_index} must differ from embedding {right_index} across reused-batch iterations",
            );
        }
    }
}
828
/// Decodes a single prompt on an embeddings-enabled context and checks that
/// `embeddings_ith` for the last token position yields `n_embd` values.
#[test]
#[serial]
fn embeddings_ith_returns_valid_embeddings() {
    let (backend, model) = test_model::load_default_embedding_model().unwrap();
    let n_ctx = std::num::NonZeroU32::new(512);
    let params = LlamaContextParams::default()
        .with_n_ctx(n_ctx)
        .with_embeddings(true);
    let mut ctx = model.new_context(&backend, params).unwrap();

    let prompt_tokens = model.str_to_token("hello", AddBos::Always).unwrap();
    let mut batch = LlamaBatch::new(512, 1).unwrap();
    batch.add_sequence(&prompt_tokens, 0, false).unwrap();
    ctx.decode(&mut batch).unwrap();

    // Position of the final token that was placed into the batch.
    let final_position = i32::try_from(prompt_tokens.len() - 1).unwrap();

    let embedding = ctx.embeddings_ith(final_position).unwrap();
    let expected_len = model.n_embd() as usize;

    assert_eq!(embedding.len(), expected_len);
}
847
/// After decoding a single prompt, `candidates_ith` at the last token
/// position must yield exactly `n_vocab` entries.
#[test]
#[serial]
fn candidates_ith_returns_n_vocab_entries() {
    let (backend, model) = test_model::load_default_model().unwrap();
    let n_ctx = std::num::NonZeroU32::new(512);
    let mut ctx = model
        .new_context(&backend, LlamaContextParams::default().with_n_ctx(n_ctx))
        .unwrap();

    let prompt_tokens = model.str_to_token("hello", AddBos::Always).unwrap();
    let mut batch = LlamaBatch::new(512, 1).unwrap();
    batch.add_sequence(&prompt_tokens, 0, false).unwrap();
    ctx.decode(&mut batch).unwrap();

    // Position of the final token that was placed into the batch.
    let final_position = i32::try_from(prompt_tokens.len() - 1).unwrap();

    let candidate_total = ctx.candidates_ith(final_position).unwrap().count();
    let vocab_size = model.n_vocab() as usize;

    assert_eq!(candidate_total, vocab_size);
}
864
/// Asking a fresh context to remove an adapter must report success.
///
/// The adapter handle wraps a dangling pointer; no adapter is attached to
/// the context anywhere in this test.
#[test]
#[serial]
fn lora_adapter_remove_succeeds_with_no_adapters() {
    let (backend, model) = test_model::load_default_model().unwrap();
    let n_ctx = std::num::NonZeroU32::new(512);
    let params = LlamaContextParams::default().with_n_ctx(n_ctx);
    let ctx = model.new_context(&backend, params).unwrap();

    let dangling = std::ptr::NonNull::dangling();
    let mut placeholder_adapter = crate::model::LlamaLoraAdapter {
        lora_adapter: dangling,
    };

    let removal = ctx.lora_adapter_remove(&mut placeholder_adapter);

    assert!(removal.is_ok());
}
879
/// `encode` on a context built from the default test model must return an
/// error (the test name treats that model as a non-encoder model).
#[test]
#[serial]
fn encode_on_non_encoder_model_returns_error() {
    let (backend, model) = test_model::load_default_model().unwrap();
    let n_ctx = std::num::NonZeroU32::new(512);
    let params = LlamaContextParams::default().with_n_ctx(n_ctx);
    let mut ctx = model.new_context(&backend, params).unwrap();

    let prompt_tokens = model.str_to_token("hello", AddBos::Always).unwrap();
    let mut batch = LlamaBatch::new(512, 1).unwrap();
    batch.add_sequence(&prompt_tokens, 0, false).unwrap();

    let encode_outcome = ctx.encode(&mut batch);

    assert!(encode_outcome.is_err());
}
894
/// Attaches an adapter handle that wraps a dangling pointer with scale `1.0`
/// and expects the call to report success.
///
/// NOTE(review): the test name allows either outcome ("succeeds_or_errors"),
/// but the assertion only accepts `Ok` — confirm which contract is intended.
#[test]
#[serial]
fn lora_adapter_set_with_dangling_pointer_succeeds_or_errors() {
    let (backend, model) = test_model::load_default_model().unwrap();
    let n_ctx = std::num::NonZeroU32::new(512);
    let params = LlamaContextParams::default().with_n_ctx(n_ctx);
    let ctx = model.new_context(&backend, params).unwrap();

    let dangling = std::ptr::NonNull::dangling();
    let mut placeholder_adapter = crate::model::LlamaLoraAdapter {
        lora_adapter: dangling,
    };

    let attachment = ctx.lora_adapter_set(&mut placeholder_adapter, 1.0);

    assert!(attachment.is_ok());
}
909
/// Requesting the embedding for token index 999 on an embeddings-enabled
/// context that has not decoded anything must produce an error.
#[test]
#[serial]
fn embeddings_ith_returns_null_embedding_error_for_non_embedding_token() {
    let (backend, model) = test_model::load_default_embedding_model().unwrap();
    let n_ctx = std::num::NonZeroU32::new(512);
    let params = LlamaContextParams::default()
        .with_n_ctx(n_ctx)
        .with_embeddings(true);
    let ctx = model.new_context(&backend, params).unwrap();

    let lookup = ctx.embeddings_ith(999);

    assert!(lookup.is_err());
}
923
/// Decodes one sequence (id 0) with embeddings enabled, then asks for the
/// embedding of sequence 999, which was never submitted; that lookup must
/// produce an error.
#[test]
#[serial]
fn embeddings_seq_ith_returns_null_embedding_error_for_invalid_seq() {
    let (backend, model) = test_model::load_default_model().unwrap();
    let n_ctx = std::num::NonZeroU32::new(512);
    let params = LlamaContextParams::default()
        .with_n_ctx(n_ctx)
        .with_embeddings(true);
    let mut ctx = model.new_context(&backend, params).unwrap();

    let prompt_tokens = model.str_to_token("hello", AddBos::Always).unwrap();
    let mut batch = LlamaBatch::new(512, 1).unwrap();
    batch.add_sequence(&prompt_tokens, 0, false).unwrap();
    ctx.decode(&mut batch).unwrap();

    let lookup = ctx.embeddings_seq_ith(999);

    assert!(lookup.is_err());
}
941
/// Decoding a freshly created batch that has had no sequences added must
/// produce an error.
#[test]
#[serial]
fn decode_empty_batch_returns_error() {
    let (backend, model) = test_model::load_default_model().unwrap();
    let n_ctx = std::num::NonZeroU32::new(512);
    let params = LlamaContextParams::default().with_n_ctx(n_ctx);
    let mut ctx = model.new_context(&backend, params).unwrap();

    // Nothing is ever added to this batch.
    let mut empty_batch = LlamaBatch::new(512, 1).unwrap();

    let decode_outcome = ctx.decode(&mut empty_batch);

    assert!(decode_outcome.is_err());
}
954
/// Loads the model returned by `test_model::download_encoder_model` and
/// checks that `encode` succeeds on a single-sequence batch.
#[test]
#[serial]
fn encode_succeeds_with_encoder_model() {
    let backend = crate::llama_backend::LlamaBackend::init().unwrap();
    let encoder_path = test_model::download_encoder_model().unwrap();
    let load_params = crate::model::params::LlamaModelParams::default();
    let model = crate::model::LlamaModel::load_from_file(&backend, &encoder_path, &load_params)
        .unwrap();

    let n_ctx = std::num::NonZeroU32::new(512);
    let params = LlamaContextParams::default()
        .with_n_ctx(n_ctx)
        .with_embeddings(true);
    let mut ctx = model.new_context(&backend, params).unwrap();

    // Unlike the other tests in this module, no BOS token is requested here.
    let prompt_tokens = model.str_to_token("hello", AddBos::Never).unwrap();
    let mut batch = LlamaBatch::new(512, 1).unwrap();
    batch.add_sequence(&prompt_tokens, 0, false).unwrap();

    let encode_outcome = ctx.encode(&mut batch);

    assert!(encode_outcome.is_ok());
}
975
/// Feeds an OK status code directly into `handle_encode_result` and checks
/// that it returns `Ok` and leaves `initialized_logits` non-empty.
#[test]
#[serial]
fn handle_encode_result_ok_updates_logits() {
    let (backend, model) = test_model::load_default_model().unwrap();
    let n_ctx = std::num::NonZeroU32::new(512);
    let params = LlamaContextParams::default().with_n_ctx(n_ctx);
    let mut ctx = model.new_context(&backend, params).unwrap();

    let prompt_tokens = model.str_to_token("hello", AddBos::Always).unwrap();
    let mut batch = LlamaBatch::new(512, 1).unwrap();
    // The third argument is `true` here, unlike most other tests in this module.
    batch.add_sequence(&prompt_tokens, 0, true).unwrap();

    let ok_status = llama_cpp_bindings_sys::LLAMA_RS_STATUS_OK;
    let handled = ctx.handle_encode_result(ok_status, &mut batch);

    assert!(handled.is_ok());
    assert!(!ctx.initialized_logits.is_empty());
}
992
/// Installs an abort flag that is already `true` and checks that `decode`
/// returns `DecodeError::Aborted`.
///
/// The flag is moved into `set_abort_flag`, matching the sibling abort-flag
/// tests; this test never reads the flag again, so no extra handle is kept.
#[test]
#[serial]
fn set_abort_flag_aborts_decode() {
    use std::sync::Arc;
    use std::sync::atomic::AtomicBool;

    use crate::DecodeError;

    let (backend, model) = test_model::load_default_model().unwrap();
    let ctx_params = LlamaContextParams::default().with_n_ctx(std::num::NonZeroU32::new(512));
    let mut context = model.new_context(&backend, ctx_params).unwrap();

    // Flag is raised before decode is ever called.
    let abort_flag = Arc::new(AtomicBool::new(true));
    context.set_abort_flag(abort_flag);

    let tokens = model.str_to_token("hello", AddBos::Always).unwrap();
    let mut batch = LlamaBatch::new(512, 1).unwrap();
    batch.add_sequence(&tokens, 0, false).unwrap();

    let result = context.decode(&mut batch);

    assert_eq!(result, Err(DecodeError::Aborted));
}
1016
/// Installs an abort flag that is `false` and checks that `decode` still
/// succeeds.
#[test]
#[serial]
fn set_abort_flag_false_allows_decode() {
    use std::sync::Arc;
    use std::sync::atomic::AtomicBool;

    let (backend, model) = test_model::load_default_model().unwrap();
    let n_ctx = std::num::NonZeroU32::new(512);
    let params = LlamaContextParams::default().with_n_ctx(n_ctx);
    let mut ctx = model.new_context(&backend, params).unwrap();

    // The flag is lowered, so it must not interrupt decoding.
    let lowered_flag = Arc::new(AtomicBool::new(false));
    ctx.set_abort_flag(lowered_flag);

    let prompt_tokens = model.str_to_token("hello", AddBos::Always).unwrap();
    let mut batch = LlamaBatch::new(512, 1).unwrap();
    batch.add_sequence(&prompt_tokens, 0, false).unwrap();

    let decode_outcome = ctx.decode(&mut batch);

    assert!(decode_outcome.is_ok());
}
1037
/// Installs a raised abort flag, immediately clears the abort callback, and
/// checks that `decode` succeeds despite the flag having been `true`.
#[test]
#[serial]
fn clear_abort_callback_allows_decode_with_flag_true() {
    use std::sync::Arc;
    use std::sync::atomic::AtomicBool;

    let (backend, model) = test_model::load_default_model().unwrap();
    let n_ctx = std::num::NonZeroU32::new(512);
    let params = LlamaContextParams::default().with_n_ctx(n_ctx);
    let mut ctx = model.new_context(&backend, params).unwrap();

    // Raise the flag, then remove the callback before decoding.
    let raised_flag = Arc::new(AtomicBool::new(true));
    ctx.set_abort_flag(raised_flag);
    ctx.clear_abort_callback();

    let prompt_tokens = model.str_to_token("hello", AddBos::Always).unwrap();
    let mut batch = LlamaBatch::new(512, 1).unwrap();
    batch.add_sequence(&prompt_tokens, 0, false).unwrap();

    let decode_outcome = ctx.decode(&mut batch);

    assert!(decode_outcome.is_ok());
}
1059
/// Smoke test: `synchronize` on a fresh context must return without panicking.
#[test]
#[serial]
fn synchronize_completes_without_panic() {
    let (backend, model) = test_model::load_default_model().unwrap();
    let n_ctx = std::num::NonZeroU32::new(512);
    let params = LlamaContextParams::default().with_n_ctx(n_ctx);
    let ctx = model.new_context(&backend, params).unwrap();

    // No assertion: the test passes as long as this call returns.
    ctx.synchronize();
}
1069
/// Smoke test: `detach_threadpool` on a fresh context must return without
/// panicking.
#[test]
#[serial]
fn detach_threadpool_completes_without_panic() {
    let (backend, model) = test_model::load_default_model().unwrap();
    let n_ctx = std::num::NonZeroU32::new(512);
    let params = LlamaContextParams::default().with_n_ctx(n_ctx);
    let ctx = model.new_context(&backend, params).unwrap();

    // No assertion: the test passes as long as this call returns.
    ctx.detach_threadpool();
}
1079
/// Marking token index 0 as initialized on a fresh context must leave
/// `initialized_logits` holding exactly that one index.
#[test]
#[serial]
fn mark_logits_initialized_records_token_index() {
    let (backend, model) = test_model::load_default_model().unwrap();
    let n_ctx = std::num::NonZeroU32::new(512);
    let params = LlamaContextParams::default().with_n_ctx(n_ctx);
    let mut ctx = model.new_context(&backend, params).unwrap();

    ctx.mark_logits_initialized(0);

    let expected_indices = vec![0];
    assert_eq!(ctx.initialized_logits, expected_indices);
}
1091
/// Smoke test: `print_memory_breakdown` on a fresh context must return
/// without panicking.
#[test]
#[serial]
fn print_memory_breakdown_completes_without_panic() {
    let (backend, model) = test_model::load_default_model().unwrap();
    let n_ctx = std::num::NonZeroU32::new(512);
    let params = LlamaContextParams::default().with_n_ctx(n_ctx);
    let ctx = model.new_context(&backend, params).unwrap();

    // No assertion: the test passes as long as this call returns.
    ctx.print_memory_breakdown();
}
1101
/// On a fresh context where nothing has been decoded or marked, asking for
/// the logits of token 7 must fail with `LogitsError::TokenNotInitialized(7)`.
#[test]
#[serial]
fn get_logits_ith_returns_token_not_initialized_for_unknown_index() {
    let (backend, model) = test_model::load_default_model().unwrap();
    let n_ctx = std::num::NonZeroU32::new(512);
    let params = LlamaContextParams::default().with_n_ctx(n_ctx);
    let ctx = model.new_context(&backend, params).unwrap();

    let lookup = ctx.get_logits_ith(7);

    // The error must carry the same index that was requested.
    assert!(matches!(
        lookup,
        Err(crate::LogitsError::TokenNotInitialized(7))
    ));
}
1116
/// Marks an index equal to the context size as initialized, then checks that
/// `get_logits_ith` rejects it with `LogitsError::TokenIndexExceedsContext`.
#[test]
#[serial]
fn get_logits_ith_returns_token_index_exceeds_context_for_huge_index() {
    let (backend, model) = test_model::load_default_model().unwrap();
    let n_ctx = std::num::NonZeroU32::new(64);
    let params = LlamaContextParams::default().with_n_ctx(n_ctx);
    let mut ctx = model.new_context(&backend, params).unwrap();

    // Use whatever size the context reports as the offending index.
    let out_of_range = i32::try_from(ctx.n_ctx()).unwrap();
    // Mark it initialized first; the assertion below still expects the
    // context-bound error for this index.
    ctx.mark_logits_initialized(out_of_range);

    let lookup = ctx.get_logits_ith(out_of_range);

    assert!(matches!(
        lookup,
        Err(crate::LogitsError::TokenIndexExceedsContext { .. })
    ));
}
1133}