1use std::ffi::c_void;
4use std::fmt::{Debug, Formatter};
5use std::num::NonZeroI32;
6use std::ptr::NonNull;
7use std::slice;
8use std::sync::Arc;
9use std::sync::atomic::AtomicBool;
10use std::sync::atomic::Ordering;
11
12use crate::llama_batch::LlamaBatch;
13use crate::model::{LlamaLoraAdapter, LlamaModel};
14use crate::timing::LlamaTimings;
15use crate::token::LlamaToken;
16use crate::token::data::LlamaTokenData;
17use crate::token::data_array::LlamaTokenDataArray;
18use crate::{
19 DecodeError, EmbeddingsError, EncodeError, LlamaLoraAdapterRemoveError,
20 LlamaLoraAdapterSetError, LogitsError,
21};
22
23const fn check_lora_set_result(err_code: i32) -> Result<(), LlamaLoraAdapterSetError> {
24 if err_code != 0 {
25 return Err(LlamaLoraAdapterSetError::ErrorResult(err_code));
26 }
27
28 Ok(())
29}
30
31const fn check_lora_remove_result(err_code: i32) -> Result<(), LlamaLoraAdapterRemoveError> {
32 if err_code != 0 {
33 return Err(LlamaLoraAdapterRemoveError::ErrorResult(err_code));
34 }
35
36 Ok(())
37}
38
39pub mod kv_cache;
40pub mod llama_state_seq_flags;
41pub mod load_seq_state_error;
42pub mod load_session_error;
43pub mod params;
44pub mod save_seq_state_error;
45pub mod save_session_error;
46pub mod session;
47
/// C-ABI abort callback that `LlamaContext::set_abort_flag` registers with the native library.
///
/// Reads the `AtomicBool` behind `data` with `Ordering::Relaxed` and returns its value.
///
/// # Safety
///
/// `data` must point to an `AtomicBool` that is alive for the whole call.
unsafe extern "C" fn abort_callback_trampoline(data: *mut c_void) -> bool {
    let flag_ptr = data.cast::<AtomicBool>();

    // SAFETY: the caller guarantees `data` points to a live `AtomicBool`; only a shared
    // reference is created and atomics may be read through shared references.
    let abort_requested: &AtomicBool = unsafe { &*flag_ptr };

    abort_requested.load(Ordering::Relaxed)
}
53
54pub struct LlamaContext<'model> {
56 pub context: NonNull<llama_cpp_bindings_sys::llama_context>,
58 pub model: &'model LlamaModel,
60 abort_flag: Option<Arc<AtomicBool>>,
61 initialized_logits: Vec<i32>,
62 embeddings_enabled: bool,
63}
64
65impl Debug for LlamaContext<'_> {
66 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
67 f.debug_struct("LlamaContext")
68 .field("context", &self.context)
69 .finish()
70 }
71}
72
73impl<'model> LlamaContext<'model> {
74 #[must_use]
76 pub const fn new(
77 llama_model: &'model LlamaModel,
78 llama_context: NonNull<llama_cpp_bindings_sys::llama_context>,
79 embeddings_enabled: bool,
80 ) -> Self {
81 Self {
82 context: llama_context,
83 model: llama_model,
84 abort_flag: None,
85 initialized_logits: Vec::new(),
86 embeddings_enabled,
87 }
88 }
89
90 #[must_use]
92 pub fn n_batch(&self) -> u32 {
93 unsafe { llama_cpp_bindings_sys::llama_n_batch(self.context.as_ptr()) }
94 }
95
96 #[must_use]
98 pub fn n_ubatch(&self) -> u32 {
99 unsafe { llama_cpp_bindings_sys::llama_n_ubatch(self.context.as_ptr()) }
100 }
101
102 #[must_use]
104 pub fn n_ctx(&self) -> u32 {
105 unsafe { llama_cpp_bindings_sys::llama_n_ctx(self.context.as_ptr()) }
106 }
107
108 #[expect(unsafe_code, reason = "required for FFI abort callback registration")]
114 pub fn set_abort_flag(&mut self, flag: Arc<AtomicBool>) {
115 let raw_ptr = Arc::as_ptr(&flag) as *mut c_void;
116 self.abort_flag = Some(flag);
117
118 unsafe {
119 llama_cpp_bindings_sys::llama_set_abort_callback(
120 self.context.as_ptr(),
121 Some(abort_callback_trampoline),
122 raw_ptr,
123 );
124 }
125 }
126
127 #[expect(unsafe_code, reason = "required for FFI abort callback deregistration")]
129 pub fn clear_abort_callback(&mut self) {
130 self.abort_flag = None;
131
132 unsafe {
133 llama_cpp_bindings_sys::llama_set_abort_callback(
134 self.context.as_ptr(),
135 None,
136 std::ptr::null_mut(),
137 );
138 }
139 }
140
141 #[expect(unsafe_code, reason = "required for FFI synchronization call")]
146 pub fn synchronize(&self) {
147 unsafe { llama_cpp_bindings_sys::llama_synchronize(self.context.as_ptr()) }
148 }
149
150 #[expect(unsafe_code, reason = "required for FFI threadpool detachment")]
155 pub fn detach_threadpool(&self) {
156 unsafe { llama_cpp_bindings_sys::llama_detach_threadpool(self.context.as_ptr()) }
157 }
158
159 pub fn mark_logits_initialized(&mut self, token_index: i32) {
163 self.initialized_logits = vec![token_index];
164 }
165
166 pub fn decode(&mut self, batch: &mut LlamaBatch) -> Result<(), DecodeError> {
176 let result = unsafe {
177 llama_cpp_bindings_sys::llama_decode(self.context.as_ptr(), batch.llama_batch)
178 };
179
180 match NonZeroI32::new(result) {
181 None => {
182 self.initialized_logits
183 .clone_from(&batch.initialized_logits);
184 Ok(())
185 }
186 Some(error) => Err(DecodeError::from(error)),
187 }
188 }
189
190 pub fn encode(&mut self, batch: &mut LlamaBatch) -> Result<(), EncodeError> {
200 let status = unsafe {
201 llama_cpp_bindings_sys::llama_rs_encode(self.context.as_ptr(), batch.llama_batch)
202 };
203
204 self.handle_encode_result(status, batch)
205 }
206
207 fn handle_encode_result(
208 &mut self,
209 status: llama_cpp_bindings_sys::llama_rs_status,
210 batch: &mut LlamaBatch,
211 ) -> Result<(), EncodeError> {
212 if crate::status_is_ok(status) {
213 self.initialized_logits
214 .clone_from(&batch.initialized_logits);
215
216 Ok(())
217 } else {
218 Err(EncodeError::from(
219 NonZeroI32::new(crate::status_to_i32(status))
220 .unwrap_or(NonZeroI32::new(1).expect("1 is non-zero")),
221 ))
222 }
223 }
224
225 pub fn embeddings_seq_ith(&self, sequence_index: i32) -> Result<&[f32], EmbeddingsError> {
239 if !self.embeddings_enabled {
240 return Err(EmbeddingsError::NotEnabled);
241 }
242
243 let n_embd = usize::try_from(self.model.n_embd())
244 .map_err(EmbeddingsError::InvalidEmbeddingDimension)?;
245
246 unsafe {
247 let embedding = llama_cpp_bindings_sys::llama_get_embeddings_seq(
248 self.context.as_ptr(),
249 sequence_index,
250 );
251
252 if embedding.is_null() {
253 Err(EmbeddingsError::NonePoolType)
254 } else {
255 Ok(slice::from_raw_parts(embedding, n_embd))
256 }
257 }
258 }
259
260 pub fn embeddings_ith(&self, token_index: i32) -> Result<&[f32], EmbeddingsError> {
274 if !self.embeddings_enabled {
275 return Err(EmbeddingsError::NotEnabled);
276 }
277
278 let n_embd = usize::try_from(self.model.n_embd())
279 .map_err(EmbeddingsError::InvalidEmbeddingDimension)?;
280
281 unsafe {
282 let embedding = llama_cpp_bindings_sys::llama_get_embeddings_ith(
283 self.context.as_ptr(),
284 token_index,
285 );
286
287 if embedding.is_null() {
288 Err(EmbeddingsError::LogitsNotEnabled)
289 } else {
290 Ok(slice::from_raw_parts(embedding, n_embd))
291 }
292 }
293 }
294
295 pub fn candidates(&self) -> Result<impl Iterator<Item = LlamaTokenData> + '_, LogitsError> {
304 let logits = self.get_logits()?;
305
306 Ok((0_i32..).zip(logits).map(|(token_id, logit)| {
307 let token = LlamaToken::new(token_id);
308 LlamaTokenData::new(token, *logit, 0_f32)
309 }))
310 }
311
312 pub fn token_data_array(&self) -> Result<LlamaTokenDataArray, LogitsError> {
317 Ok(LlamaTokenDataArray::from_iter(self.candidates()?, false))
318 }
319
320 pub fn get_logits(&self) -> Result<&[f32], LogitsError> {
334 let data = unsafe { llama_cpp_bindings_sys::llama_get_logits(self.context.as_ptr()) };
335
336 if data.is_null() {
337 return Err(LogitsError::NullLogits);
338 }
339
340 let len = usize::try_from(self.model.n_vocab()).map_err(LogitsError::VocabSizeOverflow)?;
341
342 Ok(unsafe { slice::from_raw_parts(data, len) })
343 }
344
345 pub fn candidates_ith(
350 &self,
351 token_index: i32,
352 ) -> Result<impl Iterator<Item = LlamaTokenData> + '_, LogitsError> {
353 let logits = self.get_logits_ith(token_index)?;
354
355 Ok((0_i32..).zip(logits).map(|(token_id, logit)| {
356 let token = LlamaToken::new(token_id);
357 LlamaTokenData::new(token, *logit, 0_f32)
358 }))
359 }
360
361 pub fn token_data_array_ith(
366 &self,
367 token_index: i32,
368 ) -> Result<LlamaTokenDataArray, LogitsError> {
369 Ok(LlamaTokenDataArray::from_iter(
370 self.candidates_ith(token_index)?,
371 false,
372 ))
373 }
374
375 pub fn get_logits_ith(&self, token_index: i32) -> Result<&[f32], LogitsError> {
380 if !self.initialized_logits.contains(&token_index) {
381 return Err(LogitsError::TokenNotInitialized(token_index));
382 }
383
384 if token_index >= 0 {
385 let token_index_u32 =
386 u32::try_from(token_index).map_err(LogitsError::TokenIndexOverflow)?;
387
388 if self.n_ctx() <= token_index_u32 {
389 return Err(LogitsError::TokenIndexExceedsContext {
390 token_index: token_index_u32,
391 context_size: self.n_ctx(),
392 });
393 }
394 }
395
396 let data = unsafe {
397 llama_cpp_bindings_sys::llama_get_logits_ith(self.context.as_ptr(), token_index)
398 };
399 let len = usize::try_from(self.model.n_vocab()).map_err(LogitsError::VocabSizeOverflow)?;
400
401 Ok(unsafe { slice::from_raw_parts(data, len) })
402 }
403
404 pub fn reset_timings(&mut self) {
406 unsafe { llama_cpp_bindings_sys::llama_perf_context_reset(self.context.as_ptr()) }
407 }
408
409 pub fn timings(&mut self) -> LlamaTimings {
411 let timings = unsafe { llama_cpp_bindings_sys::llama_perf_context(self.context.as_ptr()) };
412 LlamaTimings { timings }
413 }
414
415 pub fn lora_adapter_set(
421 &self,
422 adapter: &mut LlamaLoraAdapter,
423 scale: f32,
424 ) -> Result<(), LlamaLoraAdapterSetError> {
425 let mut adapters = [adapter.lora_adapter.as_ptr()];
426 let mut scales = [scale];
427 let err_code = unsafe {
428 llama_cpp_bindings_sys::llama_set_adapters_lora(
429 self.context.as_ptr(),
430 adapters.as_mut_ptr(),
431 1,
432 scales.as_mut_ptr(),
433 )
434 };
435 check_lora_set_result(err_code)?;
436
437 tracing::debug!("Set lora adapter");
438 Ok(())
439 }
440
441 pub fn lora_adapter_remove(
450 &self,
451 _adapter: &mut LlamaLoraAdapter,
452 ) -> Result<(), LlamaLoraAdapterRemoveError> {
453 let err_code = unsafe {
454 llama_cpp_bindings_sys::llama_set_adapters_lora(
455 self.context.as_ptr(),
456 std::ptr::null_mut(),
457 0,
458 std::ptr::null_mut(),
459 )
460 };
461 check_lora_remove_result(err_code)?;
462
463 tracing::debug!("Remove lora adapter");
464 Ok(())
465 }
466
467 pub fn print_memory_breakdown(&self) {
469 unsafe { llama_cpp_bindings_sys::llama_memory_breakdown_print(self.context.as_ptr()) }
470 }
471}
472
473impl Drop for LlamaContext<'_> {
474 fn drop(&mut self) {
475 unsafe { llama_cpp_bindings_sys::llama_free(self.context.as_ptr()) }
476 }
477}
478
#[cfg(test)]
mod unit_tests {
    use super::{check_lora_remove_result, check_lora_set_result};
    use crate::{LlamaLoraAdapterRemoveError, LlamaLoraAdapterSetError};

    /// A zero status from the "set" call is success.
    #[test]
    fn check_lora_set_result_ok_for_zero() {
        assert_eq!(check_lora_set_result(0), Ok(()));
    }

    /// A non-zero status from the "set" call is returned inside `ErrorResult`.
    #[test]
    fn check_lora_set_result_error_for_nonzero() {
        let expected = Err(LlamaLoraAdapterSetError::ErrorResult(-1));

        assert_eq!(check_lora_set_result(-1), expected);
    }

    /// A zero status from the "remove" call is success.
    #[test]
    fn check_lora_remove_result_ok_for_zero() {
        assert_eq!(check_lora_remove_result(0), Ok(()));
    }

    /// A non-zero status from the "remove" call is returned inside `ErrorResult`.
    #[test]
    fn check_lora_remove_result_error_for_nonzero() {
        let expected = Err(LlamaLoraAdapterRemoveError::ErrorResult(-1));

        assert_eq!(check_lora_remove_result(-1), expected);
    }
}
510
#[cfg(test)]
#[cfg(feature = "tests_that_use_llms")]
mod tests {
    use std::num::NonZeroU32;
    use std::sync::Arc;
    use std::sync::atomic::AtomicBool;

    use serial_test::serial;

    use crate::DecodeError;
    use crate::LogitsError;
    use crate::context::params::LlamaContextParams;
    use crate::llama_batch::LlamaBatch;
    use crate::model::AddBos;
    use crate::test_model;

    /// Default context parameters with the context size set to `n_ctx`.
    fn context_params(n_ctx: u32) -> LlamaContextParams {
        LlamaContextParams::default().with_n_ctx(NonZeroU32::new(n_ctx))
    }

    /// A freshly created context reports non-zero sizes.
    #[test]
    #[serial]
    fn context_creation_and_properties() {
        let (backend, model) = test_model::load_default_model().unwrap();
        let context = model.new_context(&backend, context_params(512)).unwrap();

        assert!(context.n_ctx() > 0);
        assert!(context.n_batch() > 0);
        assert!(context.n_ubatch() > 0);
    }

    /// Decoding a short prompt succeeds and leaves a non-empty logits buffer.
    #[test]
    #[serial]
    fn decode_and_get_logits() {
        let (backend, model) = test_model::load_default_model().unwrap();
        let mut context = model.new_context(&backend, context_params(512)).unwrap();
        let tokens = model.str_to_token("hello", AddBos::Always).unwrap();
        let mut batch = LlamaBatch::new(512, 1).unwrap();
        batch.add_sequence(&tokens, 0, false).unwrap();

        let decode_result = context.decode(&mut batch);
        assert!(decode_result.is_ok());

        let logits = context.get_logits().unwrap();
        assert!(!logits.is_empty());
    }

    /// Timings can be reset and read back.
    #[test]
    #[serial]
    fn timings_work() {
        let (backend, model) = test_model::load_default_model().unwrap();
        let mut context = model.new_context(&backend, context_params(512)).unwrap();

        context.reset_timings();
        let timings = context.timings();

        assert!(timings.t_start_ms() >= 0.0);
    }

    /// The token data array built after a decode is not empty.
    #[test]
    #[serial]
    fn token_data_array_has_entries_after_decode() {
        let (backend, model) = test_model::load_default_model().unwrap();
        let mut context = model.new_context(&backend, context_params(512)).unwrap();
        let tokens = model.str_to_token("hello", AddBos::Always).unwrap();
        let mut batch = LlamaBatch::new(512, 1).unwrap();
        batch.add_sequence(&tokens, 0, false).unwrap();
        context.decode(&mut batch).unwrap();

        let token_data_array = context.token_data_array().unwrap();

        assert!(!token_data_array.data.is_empty());
    }

    /// Logits of the last prompt token have one entry per vocabulary item.
    #[test]
    #[serial]
    fn get_logits_ith_returns_valid_slice() {
        let (backend, model) = test_model::load_default_model().unwrap();
        let mut context = model.new_context(&backend, context_params(512)).unwrap();
        let tokens = model.str_to_token("hello", AddBos::Always).unwrap();
        let last_index = i32::try_from(tokens.len() - 1).unwrap();
        let mut batch = LlamaBatch::new(512, 1).unwrap();
        batch.add_sequence(&tokens, 0, false).unwrap();
        context.decode(&mut batch).unwrap();

        let logits = context.get_logits_ith(last_index).unwrap();

        assert_eq!(logits.len(), model.n_vocab() as usize);
    }

    /// The token data array of the last prompt token has one entry per vocabulary item.
    #[test]
    #[serial]
    fn token_data_array_ith_returns_valid_data() {
        let (backend, model) = test_model::load_default_model().unwrap();
        let mut context = model.new_context(&backend, context_params(512)).unwrap();
        let tokens = model.str_to_token("hello", AddBos::Always).unwrap();
        let last_index = i32::try_from(tokens.len() - 1).unwrap();
        let mut batch = LlamaBatch::new(512, 1).unwrap();
        batch.add_sequence(&tokens, 0, false).unwrap();
        context.decode(&mut batch).unwrap();

        let token_data_array = context.token_data_array_ith(last_index).unwrap();

        assert_eq!(token_data_array.data.len(), model.n_vocab() as usize);
    }

    /// `embeddings_ith` fails on a context created with embeddings disabled.
    #[test]
    #[serial]
    fn embeddings_ith_returns_error_when_embeddings_disabled() {
        let (backend, model) = test_model::load_default_model().unwrap();
        let ctx_params = context_params(512).with_embeddings(false);
        let context = model.new_context(&backend, ctx_params).unwrap();

        let result = context.embeddings_ith(0);

        assert!(result.is_err());
    }

    /// `embeddings_seq_ith` fails on a context created with embeddings disabled.
    #[test]
    #[serial]
    fn embeddings_seq_ith_returns_error_when_embeddings_disabled() {
        let (backend, model) = test_model::load_default_model().unwrap();
        let ctx_params = context_params(512).with_embeddings(false);
        let context = model.new_context(&backend, ctx_params).unwrap();

        let result = context.embeddings_seq_ith(0);

        assert!(result.is_err());
    }

    /// `candidates` yields one entry per vocabulary item.
    #[test]
    #[serial]
    fn candidates_returns_n_vocab_entries() {
        let (backend, model) = test_model::load_default_model().unwrap();
        let mut context = model.new_context(&backend, context_params(512)).unwrap();
        let tokens = model.str_to_token("hello", AddBos::Always).unwrap();
        let mut batch = LlamaBatch::new(512, 1).unwrap();
        batch.add_sequence(&tokens, 0, false).unwrap();
        context.decode(&mut batch).unwrap();

        let count = context.candidates().unwrap().count();

        assert_eq!(count, model.n_vocab() as usize);
    }

    /// The `Debug` output names the struct.
    #[test]
    #[serial]
    fn debug_format_contains_struct_name() {
        let (backend, model) = test_model::load_default_model().unwrap();
        let context = model.new_context(&backend, context_params(512)).unwrap();

        let debug_output = format!("{context:?}");

        assert!(debug_output.contains("LlamaContext"));
    }

    /// Decoding works on the embedding model with embeddings enabled.
    #[test]
    #[serial]
    fn decode_with_embeddings_enabled() {
        let (backend, model) = test_model::load_default_embedding_model().unwrap();
        let ctx_params = context_params(512).with_embeddings(true);
        let mut context = model.new_context(&backend, ctx_params).unwrap();
        let tokens = model.str_to_token("hello", AddBos::Always).unwrap();
        let mut batch = LlamaBatch::new(512, 1).unwrap();
        batch.add_sequence(&tokens, 0, false).unwrap();

        let result = context.decode(&mut batch);

        assert!(result.is_ok());
    }

    /// Sequence embeddings have `n_embd` elements.
    #[test]
    #[serial]
    fn embeddings_seq_ith_returns_valid_embeddings() {
        let (backend, model) = test_model::load_default_embedding_model().unwrap();
        let ctx_params = context_params(512).with_embeddings(true);
        let mut context = model.new_context(&backend, ctx_params).unwrap();
        let tokens = model.str_to_token("hello", AddBos::Always).unwrap();
        let mut batch = LlamaBatch::new(512, 1).unwrap();
        batch.add_sequence(&tokens, 0, false).unwrap();
        context.decode(&mut batch).unwrap();

        let embeddings = context.embeddings_seq_ith(0).unwrap();

        assert_eq!(embeddings.len(), model.n_embd() as usize);
    }

    /// Four sequences decoded in one batch each get their own, distinct embedding.
    #[test]
    #[serial]
    fn multi_sequence_embeddings_returns_one_embedding_per_sequence() {
        let (backend, model) = test_model::load_default_embedding_model().unwrap();
        let ctx_params = context_params(512)
            .with_n_seq_max(4)
            .with_embeddings(true);
        let mut context = model.new_context(&backend, ctx_params).unwrap();

        let inputs = [
            "alpha is here",
            "beta runs fast",
            "gamma waits",
            "delta jumps",
        ];
        let mut batch = LlamaBatch::new(64, 4).unwrap();

        for (sequence_index, text) in inputs.iter().enumerate() {
            let tokens = model.str_to_token(text, AddBos::Always).unwrap();
            let sequence_id = i32::try_from(sequence_index).unwrap();

            batch.add_sequence(&tokens, sequence_id, true).unwrap();
        }

        context.decode(&mut batch).unwrap();

        let n_embd = model.n_embd() as usize;
        let mut collected: Vec<Vec<f32>> = Vec::with_capacity(inputs.len());

        for sequence_index in 0..inputs.len() {
            let sequence_id = i32::try_from(sequence_index).unwrap();
            let embedding = context.embeddings_seq_ith(sequence_id).unwrap();

            assert_eq!(
                embedding.len(),
                n_embd,
                "sequence {sequence_index} embedding length mismatch"
            );

            collected.push(embedding.to_vec());
        }

        // Every pair of embeddings must differ.
        for (left_index, left) in collected.iter().enumerate() {
            for (right_index, right) in collected.iter().enumerate().skip(left_index + 1) {
                assert_ne!(
                    left, right,
                    "embedding for sequence {left_index} must differ from sequence {right_index}",
                );
            }
        }
    }

    /// Reusing one batch (cleared between iterations) with capacity for more sequences than
    /// each iteration uses still yields a distinct embedding for every input.
    #[test]
    #[serial]
    fn embeddings_returns_distinct_values_when_reused_batch_has_extra_capacity() {
        let (backend, model) = test_model::load_default_embedding_model().unwrap();
        let ctx_params = context_params(512)
            .with_n_seq_max(4)
            .with_embeddings(true);
        let mut context = model.new_context(&backend, ctx_params).unwrap();

        let iterations = [
            [
                "This is the first document with enough content to contribute meaningfully to the batch size calculation",
                "This is the second document that should be processed in a potentially different batch from the first",
            ],
            [
                "This is the third document adding more content to ensure the total exceeds the configured chunk limit",
                "This is the fourth document which should demonstrate that batching distributes across agent requests",
            ],
        ];

        let n_embd = model.n_embd() as usize;
        let mut batch = LlamaBatch::new(64, 4).unwrap();
        let mut collected: Vec<Vec<f32>> = Vec::new();

        for iteration_inputs in iterations {
            for (sequence_index, text) in iteration_inputs.iter().enumerate() {
                let tokens = model.str_to_token(text, AddBos::Always).unwrap();
                let sequence_id = i32::try_from(sequence_index).unwrap();

                batch.add_sequence(&tokens, sequence_id, true).unwrap();
            }

            context.clear_kv_cache();
            context.decode(&mut batch).unwrap();

            for sequence_index in 0..iteration_inputs.len() {
                let sequence_id = i32::try_from(sequence_index).unwrap();
                let embedding = context.embeddings_seq_ith(sequence_id).unwrap();

                assert_eq!(
                    embedding.len(),
                    n_embd,
                    "iteration sequence {sequence_index} embedding length mismatch"
                );

                collected.push(embedding.to_vec());
            }

            batch.clear();
        }

        assert_eq!(
            collected.len(),
            iterations.iter().flatten().count(),
            "expected one embedding per input across every iteration"
        );

        // Every pair of embeddings must differ, including pairs from different iterations.
        for (left_index, left) in collected.iter().enumerate() {
            for (right_index, right) in collected.iter().enumerate().skip(left_index + 1) {
                assert_ne!(
                    left, right,
                    "embedding {left_index} must differ from embedding {right_index} across reused-batch iterations",
                );
            }
        }
    }

    /// Token embeddings of the last prompt token have `n_embd` elements.
    #[test]
    #[serial]
    fn embeddings_ith_returns_valid_embeddings() {
        let (backend, model) = test_model::load_default_embedding_model().unwrap();
        let ctx_params = context_params(512).with_embeddings(true);
        let mut context = model.new_context(&backend, ctx_params).unwrap();
        let tokens = model.str_to_token("hello", AddBos::Always).unwrap();
        let last_index = i32::try_from(tokens.len() - 1).unwrap();
        let mut batch = LlamaBatch::new(512, 1).unwrap();
        batch.add_sequence(&tokens, 0, false).unwrap();
        context.decode(&mut batch).unwrap();

        let embeddings = context.embeddings_ith(last_index).unwrap();

        assert_eq!(embeddings.len(), model.n_embd() as usize);
    }

    /// `candidates_ith` yields one entry per vocabulary item.
    #[test]
    #[serial]
    fn candidates_ith_returns_n_vocab_entries() {
        let (backend, model) = test_model::load_default_model().unwrap();
        let mut context = model.new_context(&backend, context_params(512)).unwrap();
        let tokens = model.str_to_token("hello", AddBos::Always).unwrap();
        let last_index = i32::try_from(tokens.len() - 1).unwrap();
        let mut batch = LlamaBatch::new(512, 1).unwrap();
        batch.add_sequence(&tokens, 0, false).unwrap();
        context.decode(&mut batch).unwrap();

        let count = context.candidates_ith(last_index).unwrap().count();

        assert_eq!(count, model.n_vocab() as usize);
    }

    /// Removing adapters succeeds when none were set. The adapter argument is not inspected
    /// by `lora_adapter_remove`, so a dangling pointer is sufficient here.
    #[test]
    #[serial]
    fn lora_adapter_remove_succeeds_with_no_adapters() {
        let (backend, model) = test_model::load_default_model().unwrap();
        let context = model.new_context(&backend, context_params(512)).unwrap();
        let mut adapter = crate::model::LlamaLoraAdapter {
            lora_adapter: std::ptr::NonNull::dangling(),
        };

        let result = context.lora_adapter_remove(&mut adapter);

        assert!(result.is_ok());
    }

    /// `encode` fails on the default model.
    #[test]
    #[serial]
    fn encode_on_non_encoder_model_returns_error() {
        let (backend, model) = test_model::load_default_model().unwrap();
        let mut context = model.new_context(&backend, context_params(512)).unwrap();
        let tokens = model.str_to_token("hello", AddBos::Always).unwrap();
        let mut batch = LlamaBatch::new(512, 1).unwrap();
        batch.add_sequence(&tokens, 0, false).unwrap();

        let result = context.encode(&mut batch);

        assert!(result.is_err());
    }

    /// Setting an adapter through a dangling pointer is expected to report success.
    // NOTE(review): the test name says "succeeds_or_errors" but the assertion requires
    // success, and a dangling pointer is handed to the native library — confirm which
    // behavior is intended.
    #[test]
    #[serial]
    fn lora_adapter_set_with_dangling_pointer_succeeds_or_errors() {
        let (backend, model) = test_model::load_default_model().unwrap();
        let context = model.new_context(&backend, context_params(512)).unwrap();
        let mut adapter = crate::model::LlamaLoraAdapter {
            lora_adapter: std::ptr::NonNull::dangling(),
        };

        let result = context.lora_adapter_set(&mut adapter, 1.0);

        assert!(result.is_ok());
    }

    /// `embeddings_ith` fails for an index that was never decoded.
    #[test]
    #[serial]
    fn embeddings_ith_returns_null_embedding_error_for_non_embedding_token() {
        let (backend, model) = test_model::load_default_embedding_model().unwrap();
        let ctx_params = context_params(512).with_embeddings(true);
        let context = model.new_context(&backend, ctx_params).unwrap();

        let result = context.embeddings_ith(999);

        assert!(result.is_err());
    }

    /// `embeddings_seq_ith` fails for a sequence id that was never decoded.
    #[test]
    #[serial]
    fn embeddings_seq_ith_returns_null_embedding_error_for_invalid_seq() {
        let (backend, model) = test_model::load_default_model().unwrap();
        let ctx_params = context_params(512).with_embeddings(true);
        let mut context = model.new_context(&backend, ctx_params).unwrap();
        let tokens = model.str_to_token("hello", AddBos::Always).unwrap();
        let mut batch = LlamaBatch::new(512, 1).unwrap();
        batch.add_sequence(&tokens, 0, false).unwrap();
        context.decode(&mut batch).unwrap();

        let result = context.embeddings_seq_ith(999);

        assert!(result.is_err());
    }

    /// Decoding a batch without tokens is an error.
    #[test]
    #[serial]
    fn decode_empty_batch_returns_error() {
        let (backend, model) = test_model::load_default_model().unwrap();
        let mut context = model.new_context(&backend, context_params(512)).unwrap();
        let mut batch = LlamaBatch::new(512, 1).unwrap();

        let result = context.decode(&mut batch);

        assert!(result.is_err());
    }

    /// `encode` succeeds on the downloaded encoder model.
    #[test]
    #[serial]
    fn encode_succeeds_with_encoder_model() {
        let backend = crate::llama_backend::LlamaBackend::init().unwrap();
        let model_path = test_model::download_encoder_model().unwrap();
        let model_params = crate::model::params::LlamaModelParams::default();
        let model =
            crate::model::LlamaModel::load_from_file(&backend, &model_path, &model_params).unwrap();
        let ctx_params = context_params(512).with_embeddings(true);
        let mut context = model.new_context(&backend, ctx_params).unwrap();
        let tokens = model.str_to_token("hello", AddBos::Never).unwrap();
        let mut batch = LlamaBatch::new(512, 1).unwrap();
        batch.add_sequence(&tokens, 0, false).unwrap();

        let result = context.encode(&mut batch);

        assert!(result.is_ok());
    }

    /// An OK status copies the batch's initialized logits into the context.
    #[test]
    #[serial]
    fn handle_encode_result_ok_updates_logits() {
        let (backend, model) = test_model::load_default_model().unwrap();
        let mut context = model.new_context(&backend, context_params(512)).unwrap();
        let tokens = model.str_to_token("hello", AddBos::Always).unwrap();
        let mut batch = LlamaBatch::new(512, 1).unwrap();
        batch.add_sequence(&tokens, 0, true).unwrap();

        let result =
            context.handle_encode_result(llama_cpp_bindings_sys::LLAMA_RS_STATUS_OK, &mut batch);

        assert!(result.is_ok());
        assert!(!context.initialized_logits.is_empty());
    }

    /// A flag that is already `true` makes `decode` return `DecodeError::Aborted`.
    #[test]
    #[serial]
    fn set_abort_flag_aborts_decode() {
        let (backend, model) = test_model::load_default_model().unwrap();
        let mut context = model.new_context(&backend, context_params(512)).unwrap();
        context.set_abort_flag(Arc::new(AtomicBool::new(true)));

        let tokens = model.str_to_token("hello", AddBos::Always).unwrap();
        let mut batch = LlamaBatch::new(512, 1).unwrap();
        batch.add_sequence(&tokens, 0, false).unwrap();

        let result = context.decode(&mut batch);

        assert_eq!(result, Err(DecodeError::Aborted));
    }

    /// A flag that stays `false` does not interfere with `decode`.
    #[test]
    #[serial]
    fn set_abort_flag_false_allows_decode() {
        let (backend, model) = test_model::load_default_model().unwrap();
        let mut context = model.new_context(&backend, context_params(512)).unwrap();
        context.set_abort_flag(Arc::new(AtomicBool::new(false)));

        let tokens = model.str_to_token("hello", AddBos::Always).unwrap();
        let mut batch = LlamaBatch::new(512, 1).unwrap();
        batch.add_sequence(&tokens, 0, false).unwrap();

        let result = context.decode(&mut batch);

        assert!(result.is_ok());
    }

    /// After the callback is cleared, a `true` flag no longer aborts `decode`.
    #[test]
    #[serial]
    fn clear_abort_callback_allows_decode_with_flag_true() {
        let (backend, model) = test_model::load_default_model().unwrap();
        let mut context = model.new_context(&backend, context_params(512)).unwrap();
        context.set_abort_flag(Arc::new(AtomicBool::new(true)));
        context.clear_abort_callback();

        let tokens = model.str_to_token("hello", AddBos::Always).unwrap();
        let mut batch = LlamaBatch::new(512, 1).unwrap();
        batch.add_sequence(&tokens, 0, false).unwrap();

        let result = context.decode(&mut batch);

        assert!(result.is_ok());
    }

    /// `synchronize` can be called on an idle context.
    #[test]
    #[serial]
    fn synchronize_completes_without_panic() {
        let (backend, model) = test_model::load_default_model().unwrap();
        let context = model.new_context(&backend, context_params(512)).unwrap();

        context.synchronize();
    }

    /// `detach_threadpool` can be called on a fresh context.
    #[test]
    #[serial]
    fn detach_threadpool_completes_without_panic() {
        let (backend, model) = test_model::load_default_model().unwrap();
        let context = model.new_context(&backend, context_params(512)).unwrap();

        context.detach_threadpool();
    }

    /// `mark_logits_initialized` records exactly the given index.
    #[test]
    #[serial]
    fn mark_logits_initialized_records_token_index() {
        let (backend, model) = test_model::load_default_model().unwrap();
        let mut context = model.new_context(&backend, context_params(512)).unwrap();

        context.mark_logits_initialized(0);

        assert_eq!(context.initialized_logits, vec![0]);
    }

    /// `print_memory_breakdown` can be called on a fresh context.
    #[test]
    #[serial]
    fn print_memory_breakdown_completes_without_panic() {
        let (backend, model) = test_model::load_default_model().unwrap();
        let context = model.new_context(&backend, context_params(512)).unwrap();

        context.print_memory_breakdown();
    }

    /// An index that was never recorded is rejected before any native call.
    #[test]
    #[serial]
    fn get_logits_ith_returns_token_not_initialized_for_unknown_index() {
        let (backend, model) = test_model::load_default_model().unwrap();
        let context = model.new_context(&backend, context_params(512)).unwrap();

        let result = context.get_logits_ith(7);

        assert!(matches!(result, Err(LogitsError::TokenNotInitialized(7))));
    }

    /// A recorded index equal to the context size is rejected as out of range.
    #[test]
    #[serial]
    fn get_logits_ith_returns_token_index_exceeds_context_for_huge_index() {
        let (backend, model) = test_model::load_default_model().unwrap();
        let mut context = model.new_context(&backend, context_params(64)).unwrap();

        let huge_index = i32::try_from(context.n_ctx()).unwrap();
        context.mark_logits_initialized(huge_index);
        let result = context.get_logits_ith(huge_index);

        assert!(matches!(
            result,
            Err(LogitsError::TokenIndexExceedsContext { .. })
        ));
    }
}