1use std::ffi::c_void;
4use std::fmt::{Debug, Formatter};
5use std::num::NonZeroI32;
6use std::ptr::NonNull;
7use std::slice;
8use std::sync::Arc;
9use std::sync::atomic::AtomicBool;
10use std::sync::atomic::Ordering;
11
12use crate::llama_batch::LlamaBatch;
13use crate::model::{LlamaLoraAdapter, LlamaModel};
14use crate::timing::LlamaTimings;
15use crate::token::LlamaToken;
16use crate::token::data::LlamaTokenData;
17use crate::token::data_array::LlamaTokenDataArray;
18use crate::{
19 DecodeError, EmbeddingsError, EncodeError, LlamaLoraAdapterRemoveError,
20 LlamaLoraAdapterSetError, LogitsError,
21};
22
23const fn check_lora_set_result(err_code: i32) -> Result<(), LlamaLoraAdapterSetError> {
24 if err_code != 0 {
25 return Err(LlamaLoraAdapterSetError::ErrorResult(err_code));
26 }
27
28 Ok(())
29}
30
31const fn check_lora_remove_result(err_code: i32) -> Result<(), LlamaLoraAdapterRemoveError> {
32 if err_code != 0 {
33 return Err(LlamaLoraAdapterRemoveError::ErrorResult(err_code));
34 }
35
36 Ok(())
37}
38
39pub mod kv_cache;
40pub mod llama_state_seq_flags;
41pub mod load_seq_state_error;
42pub mod load_session_error;
43pub mod params;
44pub mod save_seq_state_error;
45pub mod save_session_error;
46pub mod session;
47
/// FFI trampoline handed to `llama_set_abort_callback`.
///
/// Returning `true` tells llama.cpp to abort the in-flight operation.
///
/// # Safety
/// `data` must be a valid pointer to a live `AtomicBool`. In this crate the
/// only registration site is [`LlamaContext::set_abort_flag`], which derives
/// the pointer from an `Arc<AtomicBool>` stored on the context, keeping the
/// pointee alive until the callback is cleared or the context is dropped.
unsafe extern "C" fn abort_callback_trampoline(data: *mut c_void) -> bool {
    // SAFETY: see function-level contract — `data` points to the AtomicBool
    // kept alive by the owning LlamaContext's `abort_flag` field.
    let flag = unsafe { &*(data as *const AtomicBool) };

    // Relaxed is sufficient: this is a lone flag read with no dependent data.
    flag.load(Ordering::Relaxed)
}
53
/// Safe wrapper around a llama.cpp inference context, borrowing the
/// [`LlamaModel`] it was created from for the lifetime `'model`.
pub struct LlamaContext<'model> {
    /// Raw, non-null pointer to the underlying C `llama_context`.
    /// Owned by this struct and freed in `Drop` via `llama_free`.
    pub context: NonNull<llama_cpp_bindings_sys::llama_context>,
    /// The model this context was created from.
    pub model: &'model LlamaModel,
    /// Keeps the registered abort flag alive while the C side holds a raw
    /// pointer to it (see `set_abort_flag` / `clear_abort_callback`).
    abort_flag: Option<Arc<AtomicBool>>,
    /// Token indices whose logits were populated by the last decode/encode.
    initialized_logits: Vec<i32>,
    /// Whether this context was created with embeddings enabled.
    embeddings_enabled: bool,
}
64
65impl Debug for LlamaContext<'_> {
66 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
67 f.debug_struct("LlamaContext")
68 .field("context", &self.context)
69 .finish()
70 }
71}
72
73impl<'model> LlamaContext<'model> {
74 #[must_use]
76 pub const fn new(
77 llama_model: &'model LlamaModel,
78 llama_context: NonNull<llama_cpp_bindings_sys::llama_context>,
79 embeddings_enabled: bool,
80 ) -> Self {
81 Self {
82 context: llama_context,
83 model: llama_model,
84 abort_flag: None,
85 initialized_logits: Vec::new(),
86 embeddings_enabled,
87 }
88 }
89
90 #[must_use]
92 pub fn n_batch(&self) -> u32 {
93 unsafe { llama_cpp_bindings_sys::llama_n_batch(self.context.as_ptr()) }
94 }
95
96 #[must_use]
98 pub fn n_ubatch(&self) -> u32 {
99 unsafe { llama_cpp_bindings_sys::llama_n_ubatch(self.context.as_ptr()) }
100 }
101
102 #[must_use]
104 pub fn n_ctx(&self) -> u32 {
105 unsafe { llama_cpp_bindings_sys::llama_n_ctx(self.context.as_ptr()) }
106 }
107
108 #[expect(unsafe_code, reason = "required for FFI abort callback registration")]
114 pub fn set_abort_flag(&mut self, flag: Arc<AtomicBool>) {
115 let raw_ptr = Arc::as_ptr(&flag) as *mut c_void;
116 self.abort_flag = Some(flag);
117
118 unsafe {
119 llama_cpp_bindings_sys::llama_set_abort_callback(
120 self.context.as_ptr(),
121 Some(abort_callback_trampoline),
122 raw_ptr,
123 );
124 }
125 }
126
127 #[expect(unsafe_code, reason = "required for FFI abort callback deregistration")]
129 pub fn clear_abort_callback(&mut self) {
130 self.abort_flag = None;
131
132 unsafe {
133 llama_cpp_bindings_sys::llama_set_abort_callback(
134 self.context.as_ptr(),
135 None,
136 std::ptr::null_mut(),
137 );
138 }
139 }
140
141 #[expect(unsafe_code, reason = "required for FFI synchronization call")]
146 pub fn synchronize(&self) {
147 unsafe { llama_cpp_bindings_sys::llama_synchronize(self.context.as_ptr()) }
148 }
149
150 #[expect(unsafe_code, reason = "required for FFI threadpool detachment")]
155 pub fn detach_threadpool(&self) {
156 unsafe { llama_cpp_bindings_sys::llama_detach_threadpool(self.context.as_ptr()) }
157 }
158
159 pub fn mark_logits_initialized(&mut self, token_index: i32) {
163 self.initialized_logits = vec![token_index];
164 }
165
166 pub fn decode(&mut self, batch: &mut LlamaBatch) -> Result<(), DecodeError> {
176 let result = unsafe {
177 llama_cpp_bindings_sys::llama_decode(self.context.as_ptr(), batch.llama_batch)
178 };
179
180 match NonZeroI32::new(result) {
181 None => {
182 self.initialized_logits
183 .clone_from(&batch.initialized_logits);
184 Ok(())
185 }
186 Some(error) => Err(DecodeError::from(error)),
187 }
188 }
189
190 pub fn encode(&mut self, batch: &mut LlamaBatch) -> Result<(), EncodeError> {
200 let status = unsafe {
201 llama_cpp_bindings_sys::llama_rs_encode(self.context.as_ptr(), batch.llama_batch)
202 };
203
204 self.handle_encode_result(status, batch)
205 }
206
207 fn handle_encode_result(
208 &mut self,
209 status: llama_cpp_bindings_sys::llama_rs_status,
210 batch: &mut LlamaBatch,
211 ) -> Result<(), EncodeError> {
212 if crate::status_is_ok(status) {
213 self.initialized_logits
214 .clone_from(&batch.initialized_logits);
215
216 Ok(())
217 } else {
218 Err(EncodeError::from(
219 NonZeroI32::new(crate::status_to_i32(status))
220 .unwrap_or(NonZeroI32::new(1).expect("1 is non-zero")),
221 ))
222 }
223 }
224
225 pub fn embeddings_seq_ith(&self, sequence_index: i32) -> Result<&[f32], EmbeddingsError> {
239 if !self.embeddings_enabled {
240 return Err(EmbeddingsError::NotEnabled);
241 }
242
243 let n_embd = usize::try_from(self.model.n_embd())
244 .map_err(EmbeddingsError::InvalidEmbeddingDimension)?;
245
246 unsafe {
247 let embedding = llama_cpp_bindings_sys::llama_get_embeddings_seq(
248 self.context.as_ptr(),
249 sequence_index,
250 );
251
252 if embedding.is_null() {
253 Err(EmbeddingsError::NonePoolType)
254 } else {
255 Ok(slice::from_raw_parts(embedding, n_embd))
256 }
257 }
258 }
259
260 pub fn embeddings_ith(&self, token_index: i32) -> Result<&[f32], EmbeddingsError> {
274 if !self.embeddings_enabled {
275 return Err(EmbeddingsError::NotEnabled);
276 }
277
278 let n_embd = usize::try_from(self.model.n_embd())
279 .map_err(EmbeddingsError::InvalidEmbeddingDimension)?;
280
281 unsafe {
282 let embedding = llama_cpp_bindings_sys::llama_get_embeddings_ith(
283 self.context.as_ptr(),
284 token_index,
285 );
286
287 if embedding.is_null() {
288 Err(EmbeddingsError::LogitsNotEnabled)
289 } else {
290 Ok(slice::from_raw_parts(embedding, n_embd))
291 }
292 }
293 }
294
295 pub fn candidates(&self) -> Result<impl Iterator<Item = LlamaTokenData> + '_, LogitsError> {
304 let logits = self.get_logits()?;
305
306 Ok((0_i32..).zip(logits).map(|(token_id, logit)| {
307 let token = LlamaToken::new(token_id);
308 LlamaTokenData::new(token, *logit, 0_f32)
309 }))
310 }
311
312 pub fn token_data_array(&self) -> Result<LlamaTokenDataArray, LogitsError> {
317 Ok(LlamaTokenDataArray::from_iter(self.candidates()?, false))
318 }
319
320 pub fn get_logits(&self) -> Result<&[f32], LogitsError> {
334 let data = unsafe { llama_cpp_bindings_sys::llama_get_logits(self.context.as_ptr()) };
335
336 if data.is_null() {
337 return Err(LogitsError::NullLogits);
338 }
339
340 let len = usize::try_from(self.model.n_vocab()).map_err(LogitsError::VocabSizeOverflow)?;
341
342 Ok(unsafe { slice::from_raw_parts(data, len) })
343 }
344
345 pub fn candidates_ith(
350 &self,
351 token_index: i32,
352 ) -> Result<impl Iterator<Item = LlamaTokenData> + '_, LogitsError> {
353 let logits = self.get_logits_ith(token_index)?;
354
355 Ok((0_i32..).zip(logits).map(|(token_id, logit)| {
356 let token = LlamaToken::new(token_id);
357 LlamaTokenData::new(token, *logit, 0_f32)
358 }))
359 }
360
361 pub fn token_data_array_ith(
366 &self,
367 token_index: i32,
368 ) -> Result<LlamaTokenDataArray, LogitsError> {
369 Ok(LlamaTokenDataArray::from_iter(
370 self.candidates_ith(token_index)?,
371 false,
372 ))
373 }
374
375 pub fn get_logits_ith(&self, token_index: i32) -> Result<&[f32], LogitsError> {
380 if !self.initialized_logits.contains(&token_index) {
381 return Err(LogitsError::TokenNotInitialized(token_index));
382 }
383
384 if token_index >= 0 {
385 let token_index_u32 =
386 u32::try_from(token_index).map_err(LogitsError::TokenIndexOverflow)?;
387
388 if self.n_ctx() <= token_index_u32 {
389 return Err(LogitsError::TokenIndexExceedsContext {
390 token_index: token_index_u32,
391 context_size: self.n_ctx(),
392 });
393 }
394 }
395
396 let data = unsafe {
397 llama_cpp_bindings_sys::llama_get_logits_ith(self.context.as_ptr(), token_index)
398 };
399 let len = usize::try_from(self.model.n_vocab()).map_err(LogitsError::VocabSizeOverflow)?;
400
401 Ok(unsafe { slice::from_raw_parts(data, len) })
402 }
403
404 pub fn reset_timings(&mut self) {
406 unsafe { llama_cpp_bindings_sys::llama_perf_context_reset(self.context.as_ptr()) }
407 }
408
409 pub fn timings(&mut self) -> LlamaTimings {
411 let timings = unsafe { llama_cpp_bindings_sys::llama_perf_context(self.context.as_ptr()) };
412 LlamaTimings { timings }
413 }
414
415 pub fn lora_adapter_set(
421 &self,
422 adapter: &mut LlamaLoraAdapter,
423 scale: f32,
424 ) -> Result<(), LlamaLoraAdapterSetError> {
425 let mut adapters = [adapter.lora_adapter.as_ptr()];
426 let mut scales = [scale];
427 let err_code = unsafe {
428 llama_cpp_bindings_sys::llama_set_adapters_lora(
429 self.context.as_ptr(),
430 adapters.as_mut_ptr(),
431 1,
432 scales.as_mut_ptr(),
433 )
434 };
435 check_lora_set_result(err_code)?;
436
437 tracing::debug!("Set lora adapter");
438 Ok(())
439 }
440
441 pub fn lora_adapter_remove(
450 &self,
451 _adapter: &mut LlamaLoraAdapter,
452 ) -> Result<(), LlamaLoraAdapterRemoveError> {
453 let err_code = unsafe {
454 llama_cpp_bindings_sys::llama_set_adapters_lora(
455 self.context.as_ptr(),
456 std::ptr::null_mut(),
457 0,
458 std::ptr::null_mut(),
459 )
460 };
461 check_lora_remove_result(err_code)?;
462
463 tracing::debug!("Remove lora adapter");
464 Ok(())
465 }
466
467 pub fn print_memory_breakdown(&self) {
469 unsafe { llama_cpp_bindings_sys::llama_memory_breakdown_print(self.context.as_ptr()) }
470 }
471}
472
impl Drop for LlamaContext<'_> {
    /// Frees the underlying C `llama_context`.
    fn drop(&mut self) {
        // SAFETY: `self.context` is the non-null pointer this struct took
        // ownership of in `new`; after drop, no further FFI calls use it.
        unsafe { llama_cpp_bindings_sys::llama_free(self.context.as_ptr()) }
    }
}
478
#[cfg(test)]
mod unit_tests {
    //! Pure unit tests for the LoRA status-code helpers; no model required.

    use crate::{LlamaLoraAdapterRemoveError, LlamaLoraAdapterSetError};

    use super::{check_lora_remove_result, check_lora_set_result};

    #[test]
    fn check_lora_set_result_ok_for_zero() {
        assert!(check_lora_set_result(0).is_ok());
    }

    #[test]
    fn check_lora_set_result_error_for_nonzero() {
        assert_eq!(
            check_lora_set_result(-1),
            Err(LlamaLoraAdapterSetError::ErrorResult(-1))
        );
    }

    #[test]
    fn check_lora_remove_result_ok_for_zero() {
        assert!(check_lora_remove_result(0).is_ok());
    }

    #[test]
    fn check_lora_remove_result_error_for_nonzero() {
        assert_eq!(
            check_lora_remove_result(-1),
            Err(LlamaLoraAdapterRemoveError::ErrorResult(-1))
        );
    }
}
510
#[cfg(test)]
#[cfg(feature = "tests_that_use_llms")]
mod tests {
    //! Integration tests that load real models from disk/network.
    //!
    //! Gated behind the `tests_that_use_llms` feature and marked `#[serial]`
    //! because llama.cpp backend state is process-global.

    use serial_test::serial;

    use crate::context::params::LlamaContextParams;
    use crate::llama_batch::LlamaBatch;
    use crate::model::AddBos;
    use crate::test_model;

    // Basic smoke test: context dimensions are positive after creation.
    #[test]
    #[serial]
    fn context_creation_and_properties() {
        let (backend, model) = test_model::load_default_model().unwrap();
        let ctx_params = LlamaContextParams::default().with_n_ctx(std::num::NonZeroU32::new(512));
        let context = model.new_context(&backend, ctx_params).unwrap();
        assert!(context.n_ctx() > 0);
        assert!(context.n_batch() > 0);
        assert!(context.n_ubatch() > 0);
    }

    #[test]
    #[serial]
    fn decode_and_get_logits() {
        let (backend, model) = test_model::load_default_model().unwrap();
        let ctx_params = LlamaContextParams::default().with_n_ctx(std::num::NonZeroU32::new(512));
        let mut context = model.new_context(&backend, ctx_params).unwrap();
        let tokens = model.str_to_token("hello", AddBos::Always).unwrap();
        let mut batch = LlamaBatch::new(512, 1).unwrap();
        batch.add_sequence(&tokens, 0, false).unwrap();

        let decode_result = context.decode(&mut batch);
        assert!(decode_result.is_ok());

        let logits = context.get_logits().unwrap();
        assert!(!logits.is_empty());
    }

    #[test]
    #[serial]
    fn timings_work() {
        let (backend, model) = test_model::load_default_model().unwrap();
        let ctx_params = LlamaContextParams::default().with_n_ctx(std::num::NonZeroU32::new(512));
        let mut context = model.new_context(&backend, ctx_params).unwrap();
        context.reset_timings();
        let timings = context.timings();
        assert!(timings.t_start_ms() >= 0.0);
    }

    #[test]
    #[serial]
    fn token_data_array_has_entries_after_decode() {
        let (backend, model) = test_model::load_default_model().unwrap();
        let ctx_params = LlamaContextParams::default().with_n_ctx(std::num::NonZeroU32::new(512));
        let mut context = model.new_context(&backend, ctx_params).unwrap();
        let tokens = model.str_to_token("hello", AddBos::Always).unwrap();
        let mut batch = LlamaBatch::new(512, 1).unwrap();
        batch.add_sequence(&tokens, 0, false).unwrap();
        context.decode(&mut batch).unwrap();

        let token_data_array = context.token_data_array().unwrap();

        assert!(!token_data_array.data.is_empty());
    }

    #[test]
    #[serial]
    fn get_logits_ith_returns_valid_slice() {
        let (backend, model) = test_model::load_default_model().unwrap();
        let ctx_params = LlamaContextParams::default().with_n_ctx(std::num::NonZeroU32::new(512));
        let mut context = model.new_context(&backend, ctx_params).unwrap();
        let tokens = model.str_to_token("hello", AddBos::Always).unwrap();
        let last_index = i32::try_from(tokens.len() - 1).unwrap();
        let mut batch = LlamaBatch::new(512, 1).unwrap();
        batch.add_sequence(&tokens, 0, false).unwrap();
        context.decode(&mut batch).unwrap();

        let logits = context.get_logits_ith(last_index).unwrap();

        assert_eq!(logits.len(), model.n_vocab() as usize);
    }

    #[test]
    #[serial]
    fn token_data_array_ith_returns_valid_data() {
        let (backend, model) = test_model::load_default_model().unwrap();
        let ctx_params = LlamaContextParams::default().with_n_ctx(std::num::NonZeroU32::new(512));
        let mut context = model.new_context(&backend, ctx_params).unwrap();
        let tokens = model.str_to_token("hello", AddBos::Always).unwrap();
        let last_index = i32::try_from(tokens.len() - 1).unwrap();
        let mut batch = LlamaBatch::new(512, 1).unwrap();
        batch.add_sequence(&tokens, 0, false).unwrap();
        context.decode(&mut batch).unwrap();

        let token_data_array = context.token_data_array_ith(last_index).unwrap();

        assert_eq!(token_data_array.data.len(), model.n_vocab() as usize);
    }

    // embeddings_* must fail fast when the context has embeddings disabled.
    #[test]
    #[serial]
    fn embeddings_ith_returns_error_when_embeddings_disabled() {
        let (backend, model) = test_model::load_default_model().unwrap();
        let ctx_params = LlamaContextParams::default()
            .with_n_ctx(std::num::NonZeroU32::new(512))
            .with_embeddings(false);
        let context = model.new_context(&backend, ctx_params).unwrap();

        let result = context.embeddings_ith(0);

        assert!(result.is_err());
    }

    #[test]
    #[serial]
    fn embeddings_seq_ith_returns_error_when_embeddings_disabled() {
        let (backend, model) = test_model::load_default_model().unwrap();
        let ctx_params = LlamaContextParams::default()
            .with_n_ctx(std::num::NonZeroU32::new(512))
            .with_embeddings(false);
        let context = model.new_context(&backend, ctx_params).unwrap();

        let result = context.embeddings_seq_ith(0);

        assert!(result.is_err());
    }

    #[test]
    #[serial]
    fn candidates_returns_n_vocab_entries() {
        let (backend, model) = test_model::load_default_model().unwrap();
        let ctx_params = LlamaContextParams::default().with_n_ctx(std::num::NonZeroU32::new(512));
        let mut context = model.new_context(&backend, ctx_params).unwrap();
        let tokens = model.str_to_token("hello", AddBos::Always).unwrap();
        let mut batch = LlamaBatch::new(512, 1).unwrap();
        batch.add_sequence(&tokens, 0, false).unwrap();
        context.decode(&mut batch).unwrap();

        let count = context.candidates().unwrap().count();

        assert_eq!(count, model.n_vocab() as usize);
    }

    #[test]
    #[serial]
    fn debug_format_contains_struct_name() {
        let (backend, model) = test_model::load_default_model().unwrap();
        let ctx_params = LlamaContextParams::default().with_n_ctx(std::num::NonZeroU32::new(512));
        let context = model.new_context(&backend, ctx_params).unwrap();
        let debug_output = format!("{context:?}");

        assert!(debug_output.contains("LlamaContext"));
    }

    #[test]
    #[serial]
    fn decode_with_embeddings_enabled() {
        let (backend, model) = test_model::load_default_embedding_model().unwrap();
        let ctx_params = LlamaContextParams::default()
            .with_n_ctx(std::num::NonZeroU32::new(512))
            .with_embeddings(true);
        let mut context = model.new_context(&backend, ctx_params).unwrap();
        let tokens = model.str_to_token("hello", AddBos::Always).unwrap();
        let mut batch = LlamaBatch::new(512, 1).unwrap();
        batch.add_sequence(&tokens, 0, false).unwrap();

        let result = context.decode(&mut batch);

        assert!(result.is_ok());
    }

    #[test]
    #[serial]
    fn embeddings_seq_ith_returns_valid_embeddings() {
        let (backend, model) = test_model::load_default_embedding_model().unwrap();
        let ctx_params = LlamaContextParams::default()
            .with_n_ctx(std::num::NonZeroU32::new(512))
            .with_embeddings(true);
        let mut context = model.new_context(&backend, ctx_params).unwrap();
        let tokens = model.str_to_token("hello", AddBos::Always).unwrap();
        let mut batch = LlamaBatch::new(512, 1).unwrap();
        batch.add_sequence(&tokens, 0, false).unwrap();
        context.decode(&mut batch).unwrap();

        let embeddings = context.embeddings_seq_ith(0).unwrap();

        assert_eq!(embeddings.len(), model.n_embd() as usize);
    }

    #[test]
    #[serial]
    fn embeddings_ith_returns_valid_embeddings() {
        let (backend, model) = test_model::load_default_embedding_model().unwrap();
        let ctx_params = LlamaContextParams::default()
            .with_n_ctx(std::num::NonZeroU32::new(512))
            .with_embeddings(true);
        let mut context = model.new_context(&backend, ctx_params).unwrap();
        let tokens = model.str_to_token("hello", AddBos::Always).unwrap();
        let last_index = i32::try_from(tokens.len() - 1).unwrap();
        let mut batch = LlamaBatch::new(512, 1).unwrap();
        batch.add_sequence(&tokens, 0, false).unwrap();
        context.decode(&mut batch).unwrap();

        let embeddings = context.embeddings_ith(last_index).unwrap();

        assert_eq!(embeddings.len(), model.n_embd() as usize);
    }

    #[test]
    #[serial]
    fn candidates_ith_returns_n_vocab_entries() {
        let (backend, model) = test_model::load_default_model().unwrap();
        let ctx_params = LlamaContextParams::default().with_n_ctx(std::num::NonZeroU32::new(512));
        let mut context = model.new_context(&backend, ctx_params).unwrap();
        let tokens = model.str_to_token("hello", AddBos::Always).unwrap();
        let last_index = i32::try_from(tokens.len() - 1).unwrap();
        let mut batch = LlamaBatch::new(512, 1).unwrap();
        batch.add_sequence(&tokens, 0, false).unwrap();
        context.decode(&mut batch).unwrap();

        let count = context.candidates_ith(last_index).unwrap().count();

        assert_eq!(count, model.n_vocab() as usize);
    }

    // Safe even with a dangling NonNull: lora_adapter_remove never
    // dereferences its adapter argument (it passes null/0 to the FFI).
    #[test]
    #[serial]
    fn lora_adapter_remove_succeeds_with_no_adapters() {
        let (backend, model) = test_model::load_default_model().unwrap();
        let ctx_params = LlamaContextParams::default().with_n_ctx(std::num::NonZeroU32::new(512));
        let context = model.new_context(&backend, ctx_params).unwrap();
        let mut adapter = crate::model::LlamaLoraAdapter {
            lora_adapter: std::ptr::NonNull::dangling(),
        };

        let result = context.lora_adapter_remove(&mut adapter);

        assert!(result.is_ok());
    }

    #[test]
    #[serial]
    fn encode_on_non_encoder_model_returns_error() {
        let (backend, model) = test_model::load_default_model().unwrap();
        let ctx_params = LlamaContextParams::default().with_n_ctx(std::num::NonZeroU32::new(512));
        let mut context = model.new_context(&backend, ctx_params).unwrap();
        let tokens = model.str_to_token("hello", AddBos::Always).unwrap();
        let mut batch = LlamaBatch::new(512, 1).unwrap();
        batch.add_sequence(&tokens, 0, false).unwrap();

        let result = context.encode(&mut batch);

        assert!(result.is_err());
    }

    // NOTE(review): the name says "succeeds_or_errors" but the assertion
    // pins is_ok(). This hands a dangling adapter pointer to
    // llama_set_adapters_lora — it is only sound if llama.cpp does not
    // dereference the adapter during set; confirm against the C API before
    // relying on this test.
    #[test]
    #[serial]
    fn lora_adapter_set_with_dangling_pointer_succeeds_or_errors() {
        let (backend, model) = test_model::load_default_model().unwrap();
        let ctx_params = LlamaContextParams::default().with_n_ctx(std::num::NonZeroU32::new(512));
        let context = model.new_context(&backend, ctx_params).unwrap();
        let mut adapter = crate::model::LlamaLoraAdapter {
            lora_adapter: std::ptr::NonNull::dangling(),
        };

        let result = context.lora_adapter_set(&mut adapter, 1.0);

        assert!(result.is_ok());
    }

    // Out-of-range token index on an embedding context must surface an error.
    #[test]
    #[serial]
    fn embeddings_ith_returns_null_embedding_error_for_non_embedding_token() {
        let (backend, model) = test_model::load_default_embedding_model().unwrap();
        let ctx_params = LlamaContextParams::default()
            .with_n_ctx(std::num::NonZeroU32::new(512))
            .with_embeddings(true);
        let context = model.new_context(&backend, ctx_params).unwrap();

        let result = context.embeddings_ith(999);

        assert!(result.is_err());
    }

    // NOTE(review): this uses the default (non-embedding) model with
    // embeddings enabled — presumably intentional, since only is_err is
    // asserted; confirm the intent matches the test name.
    #[test]
    #[serial]
    fn embeddings_seq_ith_returns_null_embedding_error_for_invalid_seq() {
        let (backend, model) = test_model::load_default_model().unwrap();
        let ctx_params = LlamaContextParams::default()
            .with_n_ctx(std::num::NonZeroU32::new(512))
            .with_embeddings(true);
        let mut context = model.new_context(&backend, ctx_params).unwrap();
        let tokens = model.str_to_token("hello", AddBos::Always).unwrap();
        let mut batch = LlamaBatch::new(512, 1).unwrap();
        batch.add_sequence(&tokens, 0, false).unwrap();
        context.decode(&mut batch).unwrap();

        let result = context.embeddings_seq_ith(999);

        assert!(result.is_err());
    }

    #[test]
    #[serial]
    fn decode_empty_batch_returns_error() {
        let (backend, model) = test_model::load_default_model().unwrap();
        let ctx_params = LlamaContextParams::default().with_n_ctx(std::num::NonZeroU32::new(512));
        let mut context = model.new_context(&backend, ctx_params).unwrap();
        let mut batch = LlamaBatch::new(512, 1).unwrap();

        let result = context.decode(&mut batch);

        assert!(result.is_err());
    }

    #[test]
    #[serial]
    fn encode_succeeds_with_encoder_model() {
        let backend = crate::llama_backend::LlamaBackend::init().unwrap();
        let model_path = test_model::download_encoder_model().unwrap();
        let model_params = crate::model::params::LlamaModelParams::default();
        let model =
            crate::model::LlamaModel::load_from_file(&backend, &model_path, &model_params).unwrap();
        let ctx_params = LlamaContextParams::default()
            .with_n_ctx(std::num::NonZeroU32::new(512))
            .with_embeddings(true);
        let mut context = model.new_context(&backend, ctx_params).unwrap();
        let tokens = model.str_to_token("hello", AddBos::Never).unwrap();
        let mut batch = LlamaBatch::new(512, 1).unwrap();
        batch.add_sequence(&tokens, 0, false).unwrap();

        let result = context.encode(&mut batch);

        assert!(result.is_ok());
    }

    // Exercises the private handle_encode_result directly with an OK status.
    #[test]
    #[serial]
    fn handle_encode_result_ok_updates_logits() {
        let (backend, model) = test_model::load_default_model().unwrap();
        let ctx_params = LlamaContextParams::default().with_n_ctx(std::num::NonZeroU32::new(512));
        let mut context = model.new_context(&backend, ctx_params).unwrap();
        let tokens = model.str_to_token("hello", AddBos::Always).unwrap();
        let mut batch = LlamaBatch::new(512, 1).unwrap();
        batch.add_sequence(&tokens, 0, true).unwrap();

        let result =
            context.handle_encode_result(llama_cpp_bindings_sys::LLAMA_RS_STATUS_OK, &mut batch);

        assert!(result.is_ok());
        assert!(!context.initialized_logits.is_empty());
    }

    // A pre-set abort flag must cause decode to fail with Aborted.
    #[test]
    #[serial]
    fn set_abort_flag_aborts_decode() {
        use std::sync::Arc;
        use std::sync::atomic::AtomicBool;
        use std::sync::atomic::Ordering;

        use crate::DecodeError;

        let (backend, model) = test_model::load_default_model().unwrap();
        let ctx_params = LlamaContextParams::default().with_n_ctx(std::num::NonZeroU32::new(512));
        let mut context = model.new_context(&backend, ctx_params).unwrap();
        let abort_flag = Arc::new(AtomicBool::new(true));
        context.set_abort_flag(abort_flag.clone());

        let tokens = model.str_to_token("hello", AddBos::Always).unwrap();
        let mut batch = LlamaBatch::new(512, 1).unwrap();
        batch.add_sequence(&tokens, 0, false).unwrap();

        let result = context.decode(&mut batch);

        assert_eq!(result, Err(DecodeError::Aborted));
    }

    #[test]
    #[serial]
    fn set_abort_flag_false_allows_decode() {
        use std::sync::Arc;
        use std::sync::atomic::AtomicBool;

        let (backend, model) = test_model::load_default_model().unwrap();
        let ctx_params = LlamaContextParams::default().with_n_ctx(std::num::NonZeroU32::new(512));
        let mut context = model.new_context(&backend, ctx_params).unwrap();
        let abort_flag = Arc::new(AtomicBool::new(false));
        context.set_abort_flag(abort_flag);

        let tokens = model.str_to_token("hello", AddBos::Always).unwrap();
        let mut batch = LlamaBatch::new(512, 1).unwrap();
        batch.add_sequence(&tokens, 0, false).unwrap();

        let result = context.decode(&mut batch);

        assert!(result.is_ok());
    }

    // Clearing the callback must make decode ignore a still-true flag.
    #[test]
    #[serial]
    fn clear_abort_callback_allows_decode_with_flag_true() {
        use std::sync::Arc;
        use std::sync::atomic::AtomicBool;

        let (backend, model) = test_model::load_default_model().unwrap();
        let ctx_params = LlamaContextParams::default().with_n_ctx(std::num::NonZeroU32::new(512));
        let mut context = model.new_context(&backend, ctx_params).unwrap();
        let abort_flag = Arc::new(AtomicBool::new(true));
        context.set_abort_flag(abort_flag);
        context.clear_abort_callback();

        let tokens = model.str_to_token("hello", AddBos::Always).unwrap();
        let mut batch = LlamaBatch::new(512, 1).unwrap();
        batch.add_sequence(&tokens, 0, false).unwrap();

        let result = context.decode(&mut batch);

        assert!(result.is_ok());
    }
}