// llama_cpp_bindings/context.rs

1//! Safe wrapper around `llama_context`.
2
3use std::fmt::{Debug, Formatter};
4use std::num::NonZeroI32;
5use std::ptr::NonNull;
6use std::slice;
7
8use crate::llama_batch::LlamaBatch;
9use crate::model::{LlamaLoraAdapter, LlamaModel};
10use crate::timing::LlamaTimings;
11use crate::token::LlamaToken;
12use crate::token::data::LlamaTokenData;
13use crate::token::data_array::LlamaTokenDataArray;
14use crate::{
15    DecodeError, EmbeddingsError, EncodeError, LlamaLoraAdapterRemoveError,
16    LlamaLoraAdapterSetError, LogitsError,
17};
18
19const fn check_lora_set_result(err_code: i32) -> Result<(), LlamaLoraAdapterSetError> {
20    if err_code != 0 {
21        return Err(LlamaLoraAdapterSetError::ErrorResult(err_code));
22    }
23
24    Ok(())
25}
26
27const fn check_lora_remove_result(err_code: i32) -> Result<(), LlamaLoraAdapterRemoveError> {
28    if err_code != 0 {
29        return Err(LlamaLoraAdapterRemoveError::ErrorResult(err_code));
30    }
31
32    Ok(())
33}
34
35pub mod kv_cache;
36pub mod llama_state_seq_flags;
37pub mod load_seq_state_error;
38pub mod load_session_error;
39pub mod params;
40pub mod save_seq_state_error;
41pub mod save_session_error;
42pub mod session;
43
/// Safe wrapper around `llama_context`.
pub struct LlamaContext<'model> {
    /// Raw pointer to the underlying `llama_context`.
    pub context: NonNull<llama_cpp_bindings_sys::llama_context>,
    /// A reference to the context's model.
    pub model: &'model LlamaModel,
    // Token indices whose logits were requested in the last successful
    // decode/encode; `get_logits_ith` rejects indices not recorded here.
    initialized_logits: Vec<i32>,
    // True when the context was created with embeddings enabled; the
    // `embeddings_*` accessors error out when this is false.
    embeddings_enabled: bool,
}
53
54impl Debug for LlamaContext<'_> {
55    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
56        f.debug_struct("LlamaContext")
57            .field("context", &self.context)
58            .finish()
59    }
60}
61
62impl<'model> LlamaContext<'model> {
63    /// Wraps existing raw pointers into a new `LlamaContext`.
64    #[must_use]
65    pub const fn new(
66        llama_model: &'model LlamaModel,
67        llama_context: NonNull<llama_cpp_bindings_sys::llama_context>,
68        embeddings_enabled: bool,
69    ) -> Self {
70        Self {
71            context: llama_context,
72            model: llama_model,
73            initialized_logits: Vec::new(),
74            embeddings_enabled,
75        }
76    }
77
78    /// Gets the max number of logical tokens that can be submitted to decode. Must be greater than or equal to [`Self::n_ubatch`].
79    #[must_use]
80    pub fn n_batch(&self) -> u32 {
81        unsafe { llama_cpp_bindings_sys::llama_n_batch(self.context.as_ptr()) }
82    }
83
84    /// Gets the max number of physical tokens (hardware level) to decode in batch. Must be less than or equal to [`Self::n_batch`].
85    #[must_use]
86    pub fn n_ubatch(&self) -> u32 {
87        unsafe { llama_cpp_bindings_sys::llama_n_ubatch(self.context.as_ptr()) }
88    }
89
90    /// Gets the size of the context.
91    #[must_use]
92    pub fn n_ctx(&self) -> u32 {
93        unsafe { llama_cpp_bindings_sys::llama_n_ctx(self.context.as_ptr()) }
94    }
95
96    /// Decodes the batch.
97    ///
98    /// # Errors
99    ///
100    /// - `DecodeError` if the decoding failed.
101    ///
102    /// # Panics
103    ///
104    /// - the returned [`std::ffi::c_int`] from llama-cpp does not fit into a i32 (this should never happen on most systems)
105    pub fn decode(&mut self, batch: &mut LlamaBatch) -> Result<(), DecodeError> {
106        let result = unsafe {
107            llama_cpp_bindings_sys::llama_decode(self.context.as_ptr(), batch.llama_batch)
108        };
109
110        match NonZeroI32::new(result) {
111            None => {
112                self.initialized_logits
113                    .clone_from(&batch.initialized_logits);
114                Ok(())
115            }
116            Some(error) => Err(DecodeError::from(error)),
117        }
118    }
119
120    /// Encodes the batch.
121    ///
122    /// # Errors
123    ///
124    /// - `EncodeError` if the decoding failed.
125    ///
126    /// # Panics
127    ///
128    /// - the returned [`std::ffi::c_int`] from llama-cpp does not fit into a i32 (this should never happen on most systems)
129    pub fn encode(&mut self, batch: &mut LlamaBatch) -> Result<(), EncodeError> {
130        let status = unsafe {
131            llama_cpp_bindings_sys::llama_rs_encode(self.context.as_ptr(), batch.llama_batch)
132        };
133
134        self.handle_encode_result(status, batch)
135    }
136
137    fn handle_encode_result(
138        &mut self,
139        status: llama_cpp_bindings_sys::llama_rs_status,
140        batch: &mut LlamaBatch,
141    ) -> Result<(), EncodeError> {
142        if crate::status_is_ok(status) {
143            self.initialized_logits
144                .clone_from(&batch.initialized_logits);
145
146            Ok(())
147        } else {
148            Err(EncodeError::from(
149                NonZeroI32::new(crate::status_to_i32(status))
150                    .unwrap_or(NonZeroI32::new(1).expect("1 is non-zero")),
151            ))
152        }
153    }
154
155    /// Get the embeddings for the given sequence in the current context.
156    ///
157    /// # Returns
158    ///
159    /// A slice containing the embeddings for the last decoded batch.
160    /// The size corresponds to the `n_embd` parameter of the context's model.
161    ///
162    /// # Errors
163    ///
164    /// - When the current context was constructed without enabling embeddings.
165    /// - If the current model had a pooling type of [`llama_cpp_bindings_sys::LLAMA_POOLING_TYPE_NONE`]
166    /// - If the given sequence index exceeds the max sequence id.
167    ///
168    pub fn embeddings_seq_ith(&self, sequence_index: i32) -> Result<&[f32], EmbeddingsError> {
169        if !self.embeddings_enabled {
170            return Err(EmbeddingsError::NotEnabled);
171        }
172
173        let n_embd = usize::try_from(self.model.n_embd())
174            .map_err(EmbeddingsError::InvalidEmbeddingDimension)?;
175
176        unsafe {
177            let embedding = llama_cpp_bindings_sys::llama_get_embeddings_seq(
178                self.context.as_ptr(),
179                sequence_index,
180            );
181
182            if embedding.is_null() {
183                Err(EmbeddingsError::NonePoolType)
184            } else {
185                Ok(slice::from_raw_parts(embedding, n_embd))
186            }
187        }
188    }
189
190    /// Get the embeddings for the given token in the current context.
191    ///
192    /// # Returns
193    ///
194    /// A slice containing the embeddings for the last decoded batch of the given token.
195    /// The size corresponds to the `n_embd` parameter of the context's model.
196    ///
197    /// # Errors
198    ///
199    /// - When the current context was constructed without enabling embeddings.
200    /// - When the given token didn't have logits enabled when it was passed.
201    /// - If the given token index exceeds the max token id.
202    ///
203    pub fn embeddings_ith(&self, token_index: i32) -> Result<&[f32], EmbeddingsError> {
204        if !self.embeddings_enabled {
205            return Err(EmbeddingsError::NotEnabled);
206        }
207
208        let n_embd = usize::try_from(self.model.n_embd())
209            .map_err(EmbeddingsError::InvalidEmbeddingDimension)?;
210
211        unsafe {
212            let embedding = llama_cpp_bindings_sys::llama_get_embeddings_ith(
213                self.context.as_ptr(),
214                token_index,
215            );
216
217            if embedding.is_null() {
218                Err(EmbeddingsError::LogitsNotEnabled)
219            } else {
220                Ok(slice::from_raw_parts(embedding, n_embd))
221            }
222        }
223    }
224
225    /// Get the logits for the last token in the context.
226    ///
227    /// # Returns
228    /// An iterator over unsorted `LlamaTokenData` containing the
229    /// logits for the last token in the context.
230    ///
231    /// # Errors
232    /// Returns `LogitsError` if logits are null or `n_vocab` overflows.
233    pub fn candidates(&self) -> Result<impl Iterator<Item = LlamaTokenData> + '_, LogitsError> {
234        let logits = self.get_logits()?;
235
236        Ok((0_i32..).zip(logits).map(|(token_id, logit)| {
237            let token = LlamaToken::new(token_id);
238            LlamaTokenData::new(token, *logit, 0_f32)
239        }))
240    }
241
242    /// Get the token data array for the last token in the context.
243    ///
244    /// # Errors
245    /// Returns `LogitsError` if logits are null or `n_vocab` overflows.
246    pub fn token_data_array(&self) -> Result<LlamaTokenDataArray, LogitsError> {
247        Ok(LlamaTokenDataArray::from_iter(self.candidates()?, false))
248    }
249
250    /// Token logits obtained from the last call to `decode()`.
251    /// The logits for which `batch.logits[i] != 0` are stored contiguously
252    /// in the order they have appeared in the batch.
253    /// Rows: number of tokens for which `batch.logits[i] != 0`
254    /// Cols: `n_vocab`
255    ///
256    /// # Returns
257    ///
258    /// A slice containing the logits for the last decoded token.
259    /// The size corresponds to the `n_vocab` parameter of the context's model.
260    ///
261    /// # Errors
262    /// Returns `LogitsError` if the logits pointer is null or `n_vocab` overflows.
263    pub fn get_logits(&self) -> Result<&[f32], LogitsError> {
264        let data = unsafe { llama_cpp_bindings_sys::llama_get_logits(self.context.as_ptr()) };
265
266        if data.is_null() {
267            return Err(LogitsError::NullLogits);
268        }
269
270        let len = usize::try_from(self.model.n_vocab()).map_err(LogitsError::VocabSizeOverflow)?;
271
272        Ok(unsafe { slice::from_raw_parts(data, len) })
273    }
274
275    /// Get the logits for the ith token in the context.
276    ///
277    /// # Errors
278    /// Returns `LogitsError` if the token is not initialized or out of range.
279    pub fn candidates_ith(
280        &self,
281        token_index: i32,
282    ) -> Result<impl Iterator<Item = LlamaTokenData> + '_, LogitsError> {
283        let logits = self.get_logits_ith(token_index)?;
284
285        Ok((0_i32..).zip(logits).map(|(token_id, logit)| {
286            let token = LlamaToken::new(token_id);
287            LlamaTokenData::new(token, *logit, 0_f32)
288        }))
289    }
290
291    /// Get the token data array for the ith token in the context.
292    ///
293    /// # Errors
294    /// Returns `LogitsError` if the token is not initialized or out of range.
295    pub fn token_data_array_ith(
296        &self,
297        token_index: i32,
298    ) -> Result<LlamaTokenDataArray, LogitsError> {
299        Ok(LlamaTokenDataArray::from_iter(
300            self.candidates_ith(token_index)?,
301            false,
302        ))
303    }
304
305    /// Get the logits for the ith token in the context.
306    ///
307    /// # Errors
308    /// Returns `LogitsError` if the token is not initialized, out of range, or `n_vocab` overflows.
309    pub fn get_logits_ith(&self, token_index: i32) -> Result<&[f32], LogitsError> {
310        if !self.initialized_logits.contains(&token_index) {
311            return Err(LogitsError::TokenNotInitialized(token_index));
312        }
313
314        let token_index_u32 =
315            u32::try_from(token_index).map_err(LogitsError::TokenIndexOverflow)?;
316
317        if self.n_ctx() <= token_index_u32 {
318            return Err(LogitsError::TokenIndexExceedsContext {
319                token_index: token_index_u32,
320                context_size: self.n_ctx(),
321            });
322        }
323
324        let data = unsafe {
325            llama_cpp_bindings_sys::llama_get_logits_ith(self.context.as_ptr(), token_index)
326        };
327        let len = usize::try_from(self.model.n_vocab()).map_err(LogitsError::VocabSizeOverflow)?;
328
329        Ok(unsafe { slice::from_raw_parts(data, len) })
330    }
331
332    /// Reset the timings for the context.
333    pub fn reset_timings(&mut self) {
334        unsafe { llama_cpp_bindings_sys::llama_perf_context_reset(self.context.as_ptr()) }
335    }
336
337    /// Returns the timings for the context.
338    pub fn timings(&mut self) -> LlamaTimings {
339        let timings = unsafe { llama_cpp_bindings_sys::llama_perf_context(self.context.as_ptr()) };
340        LlamaTimings { timings }
341    }
342
343    /// Sets a lora adapter.
344    ///
345    /// # Errors
346    ///
347    /// See [`LlamaLoraAdapterSetError`] for more information.
348    pub fn lora_adapter_set(
349        &self,
350        adapter: &mut LlamaLoraAdapter,
351        scale: f32,
352    ) -> Result<(), LlamaLoraAdapterSetError> {
353        let mut adapters = [adapter.lora_adapter.as_ptr()];
354        let mut scales = [scale];
355        let err_code = unsafe {
356            llama_cpp_bindings_sys::llama_set_adapters_lora(
357                self.context.as_ptr(),
358                adapters.as_mut_ptr(),
359                1,
360                scales.as_mut_ptr(),
361            )
362        };
363        check_lora_set_result(err_code)?;
364
365        tracing::debug!("Set lora adapter");
366        Ok(())
367    }
368
369    /// Remove all lora adapters.
370    ///
371    /// Note: The upstream API now replaces all adapters at once via
372    /// `llama_set_adapters_lora`. This clears all adapters from the context.
373    ///
374    /// # Errors
375    ///
376    /// See [`LlamaLoraAdapterRemoveError`] for more information.
377    pub fn lora_adapter_remove(
378        &self,
379        _adapter: &mut LlamaLoraAdapter,
380    ) -> Result<(), LlamaLoraAdapterRemoveError> {
381        let err_code = unsafe {
382            llama_cpp_bindings_sys::llama_set_adapters_lora(
383                self.context.as_ptr(),
384                std::ptr::null_mut(),
385                0,
386                std::ptr::null_mut(),
387            )
388        };
389        check_lora_remove_result(err_code)?;
390
391        tracing::debug!("Remove lora adapter");
392        Ok(())
393    }
394
395    /// Print a breakdown of per-device memory use to the default logger.
396    pub fn print_memory_breakdown(&self) {
397        unsafe { llama_cpp_bindings_sys::llama_memory_breakdown_print(self.context.as_ptr()) }
398    }
399}
400
impl Drop for LlamaContext<'_> {
    /// Frees the underlying `llama_context`.
    fn drop(&mut self) {
        // SAFETY: `self.context` is the owned, still-valid pointer this wrapper
        // was constructed with; it is freed exactly once, here.
        unsafe { llama_cpp_bindings_sys::llama_free(self.context.as_ptr()) }
    }
}
406
#[cfg(test)]
mod unit_tests {
    //! Pure unit tests for the lora result-code helpers; no model required.

    use super::{check_lora_remove_result, check_lora_set_result};
    use crate::{LlamaLoraAdapterRemoveError, LlamaLoraAdapterSetError};

    #[test]
    fn check_lora_set_result_ok_for_zero() {
        assert!(check_lora_set_result(0).is_ok());
    }

    #[test]
    fn check_lora_set_result_error_for_nonzero() {
        assert_eq!(
            check_lora_set_result(-1),
            Err(LlamaLoraAdapterSetError::ErrorResult(-1))
        );
    }

    #[test]
    fn check_lora_remove_result_ok_for_zero() {
        assert!(check_lora_remove_result(0).is_ok());
    }

    #[test]
    fn check_lora_remove_result_error_for_nonzero() {
        assert_eq!(
            check_lora_remove_result(-1),
            Err(LlamaLoraAdapterRemoveError::ErrorResult(-1))
        );
    }
}
438
// Integration tests that exercise a real model; gated behind the
// `tests_that_use_llms` feature and serialized because they share backend state.
#[cfg(test)]
#[cfg(feature = "tests_that_use_llms")]
mod tests {
    use serial_test::serial;

    use crate::context::params::LlamaContextParams;
    use crate::llama_batch::LlamaBatch;
    use crate::model::AddBos;
    use crate::test_model;

    #[test]
    #[serial]
    fn context_creation_and_properties() {
        let (backend, model) = test_model::load_default_model().unwrap();
        let ctx_params = LlamaContextParams::default().with_n_ctx(std::num::NonZeroU32::new(512));
        let context = model.new_context(&backend, ctx_params).unwrap();
        assert!(context.n_ctx() > 0);
        assert!(context.n_batch() > 0);
        assert!(context.n_ubatch() > 0);
    }

    #[test]
    #[serial]
    fn decode_and_get_logits() {
        let (backend, model) = test_model::load_default_model().unwrap();
        let ctx_params = LlamaContextParams::default().with_n_ctx(std::num::NonZeroU32::new(512));
        let mut context = model.new_context(&backend, ctx_params).unwrap();
        let tokens = model.str_to_token("hello", AddBos::Always).unwrap();
        let mut batch = LlamaBatch::new(512, 1).unwrap();
        batch.add_sequence(&tokens, 0, false).unwrap();

        let decode_result = context.decode(&mut batch);
        assert!(decode_result.is_ok());

        let logits = context.get_logits().unwrap();
        assert!(!logits.is_empty());
    }

    #[test]
    #[serial]
    fn timings_work() {
        let (backend, model) = test_model::load_default_model().unwrap();
        let ctx_params = LlamaContextParams::default().with_n_ctx(std::num::NonZeroU32::new(512));
        let mut context = model.new_context(&backend, ctx_params).unwrap();
        context.reset_timings();
        let timings = context.timings();
        assert!(timings.t_start_ms() >= 0.0);
    }

    #[test]
    #[serial]
    fn token_data_array_has_entries_after_decode() {
        let (backend, model) = test_model::load_default_model().unwrap();
        let ctx_params = LlamaContextParams::default().with_n_ctx(std::num::NonZeroU32::new(512));
        let mut context = model.new_context(&backend, ctx_params).unwrap();
        let tokens = model.str_to_token("hello", AddBos::Always).unwrap();
        let mut batch = LlamaBatch::new(512, 1).unwrap();
        batch.add_sequence(&tokens, 0, false).unwrap();
        context.decode(&mut batch).unwrap();

        let token_data_array = context.token_data_array().unwrap();

        assert!(!token_data_array.data.is_empty());
    }

    #[test]
    #[serial]
    fn get_logits_ith_returns_valid_slice() {
        let (backend, model) = test_model::load_default_model().unwrap();
        let ctx_params = LlamaContextParams::default().with_n_ctx(std::num::NonZeroU32::new(512));
        let mut context = model.new_context(&backend, ctx_params).unwrap();
        let tokens = model.str_to_token("hello", AddBos::Always).unwrap();
        let last_index = i32::try_from(tokens.len() - 1).unwrap();
        let mut batch = LlamaBatch::new(512, 1).unwrap();
        batch.add_sequence(&tokens, 0, false).unwrap();
        context.decode(&mut batch).unwrap();

        let logits = context.get_logits_ith(last_index).unwrap();

        assert_eq!(logits.len(), model.n_vocab() as usize);
    }

    #[test]
    #[serial]
    fn token_data_array_ith_returns_valid_data() {
        let (backend, model) = test_model::load_default_model().unwrap();
        let ctx_params = LlamaContextParams::default().with_n_ctx(std::num::NonZeroU32::new(512));
        let mut context = model.new_context(&backend, ctx_params).unwrap();
        let tokens = model.str_to_token("hello", AddBos::Always).unwrap();
        let last_index = i32::try_from(tokens.len() - 1).unwrap();
        let mut batch = LlamaBatch::new(512, 1).unwrap();
        batch.add_sequence(&tokens, 0, false).unwrap();
        context.decode(&mut batch).unwrap();

        let token_data_array = context.token_data_array_ith(last_index).unwrap();

        assert_eq!(token_data_array.data.len(), model.n_vocab() as usize);
    }

    #[test]
    #[serial]
    fn embeddings_ith_returns_error_when_embeddings_disabled() {
        let (backend, model) = test_model::load_default_model().unwrap();
        let ctx_params = LlamaContextParams::default()
            .with_n_ctx(std::num::NonZeroU32::new(512))
            .with_embeddings(false);
        let context = model.new_context(&backend, ctx_params).unwrap();

        let result = context.embeddings_ith(0);

        assert!(result.is_err());
    }

    #[test]
    #[serial]
    fn embeddings_seq_ith_returns_error_when_embeddings_disabled() {
        let (backend, model) = test_model::load_default_model().unwrap();
        let ctx_params = LlamaContextParams::default()
            .with_n_ctx(std::num::NonZeroU32::new(512))
            .with_embeddings(false);
        let context = model.new_context(&backend, ctx_params).unwrap();

        let result = context.embeddings_seq_ith(0);

        assert!(result.is_err());
    }

    #[test]
    #[serial]
    fn candidates_returns_n_vocab_entries() {
        let (backend, model) = test_model::load_default_model().unwrap();
        let ctx_params = LlamaContextParams::default().with_n_ctx(std::num::NonZeroU32::new(512));
        let mut context = model.new_context(&backend, ctx_params).unwrap();
        let tokens = model.str_to_token("hello", AddBos::Always).unwrap();
        let mut batch = LlamaBatch::new(512, 1).unwrap();
        batch.add_sequence(&tokens, 0, false).unwrap();
        context.decode(&mut batch).unwrap();

        let count = context.candidates().unwrap().count();

        assert_eq!(count, model.n_vocab() as usize);
    }

    #[test]
    #[serial]
    fn debug_format_contains_struct_name() {
        let (backend, model) = test_model::load_default_model().unwrap();
        let ctx_params = LlamaContextParams::default().with_n_ctx(std::num::NonZeroU32::new(512));
        let context = model.new_context(&backend, ctx_params).unwrap();
        let debug_output = format!("{context:?}");

        assert!(debug_output.contains("LlamaContext"));
    }

    #[test]
    #[serial]
    fn decode_with_embeddings_enabled() {
        let (backend, model) = test_model::load_default_embedding_model().unwrap();
        let ctx_params = LlamaContextParams::default()
            .with_n_ctx(std::num::NonZeroU32::new(512))
            .with_embeddings(true);
        let mut context = model.new_context(&backend, ctx_params).unwrap();
        let tokens = model.str_to_token("hello", AddBos::Always).unwrap();
        let mut batch = LlamaBatch::new(512, 1).unwrap();
        batch.add_sequence(&tokens, 0, false).unwrap();

        let result = context.decode(&mut batch);

        assert!(result.is_ok());
    }

    #[test]
    #[serial]
    fn embeddings_seq_ith_returns_valid_embeddings() {
        let (backend, model) = test_model::load_default_embedding_model().unwrap();
        let ctx_params = LlamaContextParams::default()
            .with_n_ctx(std::num::NonZeroU32::new(512))
            .with_embeddings(true);
        let mut context = model.new_context(&backend, ctx_params).unwrap();
        let tokens = model.str_to_token("hello", AddBos::Always).unwrap();
        let mut batch = LlamaBatch::new(512, 1).unwrap();
        batch.add_sequence(&tokens, 0, false).unwrap();
        context.decode(&mut batch).unwrap();

        let embeddings = context.embeddings_seq_ith(0).unwrap();

        assert_eq!(embeddings.len(), model.n_embd() as usize);
    }

    #[test]
    #[serial]
    fn embeddings_ith_returns_valid_embeddings() {
        let (backend, model) = test_model::load_default_embedding_model().unwrap();
        let ctx_params = LlamaContextParams::default()
            .with_n_ctx(std::num::NonZeroU32::new(512))
            .with_embeddings(true);
        let mut context = model.new_context(&backend, ctx_params).unwrap();
        let tokens = model.str_to_token("hello", AddBos::Always).unwrap();
        let last_index = i32::try_from(tokens.len() - 1).unwrap();
        let mut batch = LlamaBatch::new(512, 1).unwrap();
        batch.add_sequence(&tokens, 0, false).unwrap();
        context.decode(&mut batch).unwrap();

        let embeddings = context.embeddings_ith(last_index).unwrap();

        assert_eq!(embeddings.len(), model.n_embd() as usize);
    }

    #[test]
    #[serial]
    fn candidates_ith_returns_n_vocab_entries() {
        let (backend, model) = test_model::load_default_model().unwrap();
        let ctx_params = LlamaContextParams::default().with_n_ctx(std::num::NonZeroU32::new(512));
        let mut context = model.new_context(&backend, ctx_params).unwrap();
        let tokens = model.str_to_token("hello", AddBos::Always).unwrap();
        let last_index = i32::try_from(tokens.len() - 1).unwrap();
        let mut batch = LlamaBatch::new(512, 1).unwrap();
        batch.add_sequence(&tokens, 0, false).unwrap();
        context.decode(&mut batch).unwrap();

        let count = context.candidates_ith(last_index).unwrap().count();

        assert_eq!(count, model.n_vocab() as usize);
    }

    #[test]
    #[serial]
    fn lora_adapter_remove_succeeds_with_no_adapters() {
        let (backend, model) = test_model::load_default_model().unwrap();
        let ctx_params = LlamaContextParams::default().with_n_ctx(std::num::NonZeroU32::new(512));
        let context = model.new_context(&backend, ctx_params).unwrap();
        // The dangling pointer is never dereferenced: `lora_adapter_remove`
        // ignores its adapter argument and passes NULL/0 to the FFI call.
        let mut adapter = crate::model::LlamaLoraAdapter {
            lora_adapter: std::ptr::NonNull::dangling(),
        };

        let result = context.lora_adapter_remove(&mut adapter);

        assert!(result.is_ok());
    }

    #[test]
    #[serial]
    fn encode_on_non_encoder_model_returns_error() {
        let (backend, model) = test_model::load_default_model().unwrap();
        let ctx_params = LlamaContextParams::default().with_n_ctx(std::num::NonZeroU32::new(512));
        let mut context = model.new_context(&backend, ctx_params).unwrap();
        let tokens = model.str_to_token("hello", AddBos::Always).unwrap();
        let mut batch = LlamaBatch::new(512, 1).unwrap();
        batch.add_sequence(&tokens, 0, false).unwrap();

        let result = context.encode(&mut batch);

        assert!(result.is_err());
    }

    #[test]
    #[serial]
    fn lora_adapter_set_with_dangling_pointer_succeeds_or_errors() {
        // NOTE(review): unlike `lora_adapter_remove`, `lora_adapter_set` DOES
        // forward this dangling pointer to llama.cpp, and the assertion below
        // pins `is_ok()` despite the "succeeds_or_errors" name — confirm the
        // upstream contract tolerates an invalid adapter pointer here.
        let (backend, model) = test_model::load_default_model().unwrap();
        let ctx_params = LlamaContextParams::default().with_n_ctx(std::num::NonZeroU32::new(512));
        let context = model.new_context(&backend, ctx_params).unwrap();
        let mut adapter = crate::model::LlamaLoraAdapter {
            lora_adapter: std::ptr::NonNull::dangling(),
        };

        let result = context.lora_adapter_set(&mut adapter, 1.0);

        assert!(result.is_ok());
    }

    #[test]
    #[serial]
    fn embeddings_ith_returns_null_embedding_error_for_non_embedding_token() {
        let (backend, model) = test_model::load_default_embedding_model().unwrap();
        let ctx_params = LlamaContextParams::default()
            .with_n_ctx(std::num::NonZeroU32::new(512))
            .with_embeddings(true);
        let context = model.new_context(&backend, ctx_params).unwrap();

        // Index 999 was never decoded, so the FFI call returns NULL.
        let result = context.embeddings_ith(999);

        assert!(result.is_err());
    }

    #[test]
    #[serial]
    fn embeddings_seq_ith_returns_null_embedding_error_for_invalid_seq() {
        // NOTE(review): this loads `load_default_model` (not the embedding
        // model) even though embeddings are enabled — presumably intentional
        // since only the error path is exercised, but worth confirming.
        let (backend, model) = test_model::load_default_model().unwrap();
        let ctx_params = LlamaContextParams::default()
            .with_n_ctx(std::num::NonZeroU32::new(512))
            .with_embeddings(true);
        let mut context = model.new_context(&backend, ctx_params).unwrap();
        let tokens = model.str_to_token("hello", AddBos::Always).unwrap();
        let mut batch = LlamaBatch::new(512, 1).unwrap();
        batch.add_sequence(&tokens, 0, false).unwrap();
        context.decode(&mut batch).unwrap();

        let result = context.embeddings_seq_ith(999);

        assert!(result.is_err());
    }

    #[test]
    #[serial]
    fn decode_empty_batch_returns_error() {
        let (backend, model) = test_model::load_default_model().unwrap();
        let ctx_params = LlamaContextParams::default().with_n_ctx(std::num::NonZeroU32::new(512));
        let mut context = model.new_context(&backend, ctx_params).unwrap();
        let mut batch = LlamaBatch::new(512, 1).unwrap();

        let result = context.decode(&mut batch);

        assert!(result.is_err());
    }

    #[test]
    #[serial]
    fn encode_succeeds_with_encoder_model() {
        let backend = crate::llama_backend::LlamaBackend::init().unwrap();
        let model_path = test_model::download_encoder_model().unwrap();
        let model_params = crate::model::params::LlamaModelParams::default();
        let model =
            crate::model::LlamaModel::load_from_file(&backend, &model_path, &model_params).unwrap();
        let ctx_params = LlamaContextParams::default()
            .with_n_ctx(std::num::NonZeroU32::new(512))
            .with_embeddings(true);
        let mut context = model.new_context(&backend, ctx_params).unwrap();
        let tokens = model.str_to_token("hello", AddBos::Never).unwrap();
        let mut batch = LlamaBatch::new(512, 1).unwrap();
        batch.add_sequence(&tokens, 0, false).unwrap();

        let result = context.encode(&mut batch);

        assert!(result.is_ok());
    }

    #[test]
    #[serial]
    fn handle_encode_result_ok_updates_logits() {
        let (backend, model) = test_model::load_default_model().unwrap();
        let ctx_params = LlamaContextParams::default().with_n_ctx(std::num::NonZeroU32::new(512));
        let mut context = model.new_context(&backend, ctx_params).unwrap();
        let tokens = model.str_to_token("hello", AddBos::Always).unwrap();
        let mut batch = LlamaBatch::new(512, 1).unwrap();
        batch.add_sequence(&tokens, 0, true).unwrap();

        // Feed the OK status directly to verify the bookkeeping side effect.
        let result =
            context.handle_encode_result(llama_cpp_bindings_sys::LLAMA_RS_STATUS_OK, &mut batch);

        assert!(result.is_ok());
        assert!(!context.initialized_logits.is_empty());
    }
}
791}