1use std::fmt::{Debug, Formatter};
4use std::num::NonZeroI32;
5use std::ptr::NonNull;
6use std::slice;
7
8use crate::llama_batch::LlamaBatch;
9use crate::model::{LlamaLoraAdapter, LlamaModel};
10use crate::timing::LlamaTimings;
11use crate::token::LlamaToken;
12use crate::token::data::LlamaTokenData;
13use crate::token::data_array::LlamaTokenDataArray;
14use crate::{
15 DecodeError, EmbeddingsError, EncodeError, LlamaLoraAdapterRemoveError,
16 LlamaLoraAdapterSetError, LogitsError,
17};
18
19const fn check_lora_set_result(err_code: i32) -> Result<(), LlamaLoraAdapterSetError> {
20 if err_code != 0 {
21 return Err(LlamaLoraAdapterSetError::ErrorResult(err_code));
22 }
23
24 Ok(())
25}
26
27const fn check_lora_remove_result(err_code: i32) -> Result<(), LlamaLoraAdapterRemoveError> {
28 if err_code != 0 {
29 return Err(LlamaLoraAdapterRemoveError::ErrorResult(err_code));
30 }
31
32 Ok(())
33}
34
// Submodules: KV-cache manipulation, sequence-state flags, session/state
// persistence (save/load) with their error types, and context parameters.
pub mod kv_cache;
pub mod llama_state_seq_flags;
pub mod load_seq_state_error;
pub mod load_session_error;
pub mod params;
pub mod save_seq_state_error;
pub mod save_session_error;
pub mod session;
43
/// Safe wrapper around a raw `llama_context`, tied to the [`LlamaModel`] it
/// was created from via the `'model` lifetime.
///
/// The underlying context is freed in `Drop`, so it cannot outlive its model.
pub struct LlamaContext<'model> {
    /// Raw pointer to the underlying `llama_context`; never null.
    pub context: NonNull<llama_cpp_bindings_sys::llama_context>,
    /// The model this context was created from.
    pub model: &'model LlamaModel,
    /// Token indices whose logits were initialized by the last successful
    /// `decode`/`encode`; consulted by `get_logits_ith`.
    initialized_logits: Vec<i32>,
    /// Whether the context was created with embeddings enabled; gates the
    /// `embeddings_*` accessors.
    embeddings_enabled: bool,
}
53
54impl Debug for LlamaContext<'_> {
55 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
56 f.debug_struct("LlamaContext")
57 .field("context", &self.context)
58 .finish()
59 }
60}
61
62impl<'model> LlamaContext<'model> {
63 #[must_use]
65 pub const fn new(
66 llama_model: &'model LlamaModel,
67 llama_context: NonNull<llama_cpp_bindings_sys::llama_context>,
68 embeddings_enabled: bool,
69 ) -> Self {
70 Self {
71 context: llama_context,
72 model: llama_model,
73 initialized_logits: Vec::new(),
74 embeddings_enabled,
75 }
76 }
77
78 #[must_use]
80 pub fn n_batch(&self) -> u32 {
81 unsafe { llama_cpp_bindings_sys::llama_n_batch(self.context.as_ptr()) }
82 }
83
84 #[must_use]
86 pub fn n_ubatch(&self) -> u32 {
87 unsafe { llama_cpp_bindings_sys::llama_n_ubatch(self.context.as_ptr()) }
88 }
89
90 #[must_use]
92 pub fn n_ctx(&self) -> u32 {
93 unsafe { llama_cpp_bindings_sys::llama_n_ctx(self.context.as_ptr()) }
94 }
95
96 pub fn decode(&mut self, batch: &mut LlamaBatch) -> Result<(), DecodeError> {
106 let result = unsafe {
107 llama_cpp_bindings_sys::llama_decode(self.context.as_ptr(), batch.llama_batch)
108 };
109
110 match NonZeroI32::new(result) {
111 None => {
112 self.initialized_logits
113 .clone_from(&batch.initialized_logits);
114 Ok(())
115 }
116 Some(error) => Err(DecodeError::from(error)),
117 }
118 }
119
120 pub fn encode(&mut self, batch: &mut LlamaBatch) -> Result<(), EncodeError> {
130 let status = unsafe {
131 llama_cpp_bindings_sys::llama_rs_encode(self.context.as_ptr(), batch.llama_batch)
132 };
133
134 self.handle_encode_result(status, batch)
135 }
136
137 fn handle_encode_result(
138 &mut self,
139 status: llama_cpp_bindings_sys::llama_rs_status,
140 batch: &mut LlamaBatch,
141 ) -> Result<(), EncodeError> {
142 if crate::status_is_ok(status) {
143 self.initialized_logits
144 .clone_from(&batch.initialized_logits);
145
146 Ok(())
147 } else {
148 Err(EncodeError::from(
149 NonZeroI32::new(crate::status_to_i32(status))
150 .unwrap_or(NonZeroI32::new(1).expect("1 is non-zero")),
151 ))
152 }
153 }
154
155 pub fn embeddings_seq_ith(&self, sequence_index: i32) -> Result<&[f32], EmbeddingsError> {
169 if !self.embeddings_enabled {
170 return Err(EmbeddingsError::NotEnabled);
171 }
172
173 let n_embd = usize::try_from(self.model.n_embd())
174 .map_err(EmbeddingsError::InvalidEmbeddingDimension)?;
175
176 unsafe {
177 let embedding = llama_cpp_bindings_sys::llama_get_embeddings_seq(
178 self.context.as_ptr(),
179 sequence_index,
180 );
181
182 if embedding.is_null() {
183 Err(EmbeddingsError::NonePoolType)
184 } else {
185 Ok(slice::from_raw_parts(embedding, n_embd))
186 }
187 }
188 }
189
190 pub fn embeddings_ith(&self, token_index: i32) -> Result<&[f32], EmbeddingsError> {
204 if !self.embeddings_enabled {
205 return Err(EmbeddingsError::NotEnabled);
206 }
207
208 let n_embd = usize::try_from(self.model.n_embd())
209 .map_err(EmbeddingsError::InvalidEmbeddingDimension)?;
210
211 unsafe {
212 let embedding = llama_cpp_bindings_sys::llama_get_embeddings_ith(
213 self.context.as_ptr(),
214 token_index,
215 );
216
217 if embedding.is_null() {
218 Err(EmbeddingsError::LogitsNotEnabled)
219 } else {
220 Ok(slice::from_raw_parts(embedding, n_embd))
221 }
222 }
223 }
224
225 pub fn candidates(&self) -> Result<impl Iterator<Item = LlamaTokenData> + '_, LogitsError> {
234 let logits = self.get_logits()?;
235
236 Ok((0_i32..).zip(logits).map(|(token_id, logit)| {
237 let token = LlamaToken::new(token_id);
238 LlamaTokenData::new(token, *logit, 0_f32)
239 }))
240 }
241
242 pub fn token_data_array(&self) -> Result<LlamaTokenDataArray, LogitsError> {
247 Ok(LlamaTokenDataArray::from_iter(self.candidates()?, false))
248 }
249
250 pub fn get_logits(&self) -> Result<&[f32], LogitsError> {
264 let data = unsafe { llama_cpp_bindings_sys::llama_get_logits(self.context.as_ptr()) };
265
266 if data.is_null() {
267 return Err(LogitsError::NullLogits);
268 }
269
270 let len = usize::try_from(self.model.n_vocab()).map_err(LogitsError::VocabSizeOverflow)?;
271
272 Ok(unsafe { slice::from_raw_parts(data, len) })
273 }
274
275 pub fn candidates_ith(
280 &self,
281 token_index: i32,
282 ) -> Result<impl Iterator<Item = LlamaTokenData> + '_, LogitsError> {
283 let logits = self.get_logits_ith(token_index)?;
284
285 Ok((0_i32..).zip(logits).map(|(token_id, logit)| {
286 let token = LlamaToken::new(token_id);
287 LlamaTokenData::new(token, *logit, 0_f32)
288 }))
289 }
290
291 pub fn token_data_array_ith(
296 &self,
297 token_index: i32,
298 ) -> Result<LlamaTokenDataArray, LogitsError> {
299 Ok(LlamaTokenDataArray::from_iter(
300 self.candidates_ith(token_index)?,
301 false,
302 ))
303 }
304
305 pub fn get_logits_ith(&self, token_index: i32) -> Result<&[f32], LogitsError> {
310 if !self.initialized_logits.contains(&token_index) {
311 return Err(LogitsError::TokenNotInitialized(token_index));
312 }
313
314 let token_index_u32 =
315 u32::try_from(token_index).map_err(LogitsError::TokenIndexOverflow)?;
316
317 if self.n_ctx() <= token_index_u32 {
318 return Err(LogitsError::TokenIndexExceedsContext {
319 token_index: token_index_u32,
320 context_size: self.n_ctx(),
321 });
322 }
323
324 let data = unsafe {
325 llama_cpp_bindings_sys::llama_get_logits_ith(self.context.as_ptr(), token_index)
326 };
327 let len = usize::try_from(self.model.n_vocab()).map_err(LogitsError::VocabSizeOverflow)?;
328
329 Ok(unsafe { slice::from_raw_parts(data, len) })
330 }
331
332 pub fn reset_timings(&mut self) {
334 unsafe { llama_cpp_bindings_sys::llama_perf_context_reset(self.context.as_ptr()) }
335 }
336
337 pub fn timings(&mut self) -> LlamaTimings {
339 let timings = unsafe { llama_cpp_bindings_sys::llama_perf_context(self.context.as_ptr()) };
340 LlamaTimings { timings }
341 }
342
343 pub fn lora_adapter_set(
349 &self,
350 adapter: &mut LlamaLoraAdapter,
351 scale: f32,
352 ) -> Result<(), LlamaLoraAdapterSetError> {
353 let mut adapters = [adapter.lora_adapter.as_ptr()];
354 let mut scales = [scale];
355 let err_code = unsafe {
356 llama_cpp_bindings_sys::llama_set_adapters_lora(
357 self.context.as_ptr(),
358 adapters.as_mut_ptr(),
359 1,
360 scales.as_mut_ptr(),
361 )
362 };
363 check_lora_set_result(err_code)?;
364
365 tracing::debug!("Set lora adapter");
366 Ok(())
367 }
368
369 pub fn lora_adapter_remove(
378 &self,
379 _adapter: &mut LlamaLoraAdapter,
380 ) -> Result<(), LlamaLoraAdapterRemoveError> {
381 let err_code = unsafe {
382 llama_cpp_bindings_sys::llama_set_adapters_lora(
383 self.context.as_ptr(),
384 std::ptr::null_mut(),
385 0,
386 std::ptr::null_mut(),
387 )
388 };
389 check_lora_remove_result(err_code)?;
390
391 tracing::debug!("Remove lora adapter");
392 Ok(())
393 }
394
395 pub fn print_memory_breakdown(&self) {
397 unsafe { llama_cpp_bindings_sys::llama_memory_breakdown_print(self.context.as_ptr()) }
398 }
399}
400
impl Drop for LlamaContext<'_> {
    // Frees the underlying llama_context when the wrapper is dropped.
    fn drop(&mut self) {
        // SAFETY: `self.context` is a valid llama_context owned exclusively by
        // this wrapper, so freeing it exactly once here is sound.
        unsafe { llama_cpp_bindings_sys::llama_free(self.context.as_ptr()) }
    }
}
406
#[cfg(test)]
mod unit_tests {
    use super::{check_lora_remove_result, check_lora_set_result};
    use crate::{LlamaLoraAdapterRemoveError, LlamaLoraAdapterSetError};

    // A zero status code must map to success for both helpers, and any
    // non-zero code must be preserved inside the corresponding error variant.

    #[test]
    fn check_lora_set_result_ok_for_zero() {
        assert_eq!(check_lora_set_result(0), Ok(()));
    }

    #[test]
    fn check_lora_set_result_error_for_nonzero() {
        assert_eq!(
            check_lora_set_result(-1),
            Err(LlamaLoraAdapterSetError::ErrorResult(-1))
        );
    }

    #[test]
    fn check_lora_remove_result_ok_for_zero() {
        assert_eq!(check_lora_remove_result(0), Ok(()));
    }

    #[test]
    fn check_lora_remove_result_error_for_nonzero() {
        assert_eq!(
            check_lora_remove_result(-1),
            Err(LlamaLoraAdapterRemoveError::ErrorResult(-1))
        );
    }
}
438
// Integration tests that load real model weights; gated behind the
// "tests_that_use_llms" feature and serialized because they share backend
// state and significant memory.
#[cfg(test)]
#[cfg(feature = "tests_that_use_llms")]
mod tests {
    use serial_test::serial;

    use crate::context::params::LlamaContextParams;
    use crate::llama_batch::LlamaBatch;
    use crate::model::AddBos;
    use crate::test_model;

    #[test]
    #[serial]
    fn context_creation_and_properties() {
        let (backend, model) = test_model::load_default_model().unwrap();
        let ctx_params = LlamaContextParams::default().with_n_ctx(std::num::NonZeroU32::new(512));
        let context = model.new_context(&backend, ctx_params).unwrap();
        assert!(context.n_ctx() > 0);
        assert!(context.n_batch() > 0);
        assert!(context.n_ubatch() > 0);
    }

    #[test]
    #[serial]
    fn decode_and_get_logits() {
        let (backend, model) = test_model::load_default_model().unwrap();
        let ctx_params = LlamaContextParams::default().with_n_ctx(std::num::NonZeroU32::new(512));
        let mut context = model.new_context(&backend, ctx_params).unwrap();
        let tokens = model.str_to_token("hello", AddBos::Always).unwrap();
        let mut batch = LlamaBatch::new(512, 1).unwrap();
        batch.add_sequence(&tokens, 0, false).unwrap();

        let decode_result = context.decode(&mut batch);
        assert!(decode_result.is_ok());

        let logits = context.get_logits().unwrap();
        assert!(!logits.is_empty());
    }

    #[test]
    #[serial]
    fn timings_work() {
        let (backend, model) = test_model::load_default_model().unwrap();
        let ctx_params = LlamaContextParams::default().with_n_ctx(std::num::NonZeroU32::new(512));
        let mut context = model.new_context(&backend, ctx_params).unwrap();
        context.reset_timings();
        let timings = context.timings();
        assert!(timings.t_start_ms() >= 0.0);
    }

    #[test]
    #[serial]
    fn token_data_array_has_entries_after_decode() {
        let (backend, model) = test_model::load_default_model().unwrap();
        let ctx_params = LlamaContextParams::default().with_n_ctx(std::num::NonZeroU32::new(512));
        let mut context = model.new_context(&backend, ctx_params).unwrap();
        let tokens = model.str_to_token("hello", AddBos::Always).unwrap();
        let mut batch = LlamaBatch::new(512, 1).unwrap();
        batch.add_sequence(&tokens, 0, false).unwrap();
        context.decode(&mut batch).unwrap();

        let token_data_array = context.token_data_array().unwrap();

        assert!(!token_data_array.data.is_empty());
    }

    #[test]
    #[serial]
    fn get_logits_ith_returns_valid_slice() {
        let (backend, model) = test_model::load_default_model().unwrap();
        let ctx_params = LlamaContextParams::default().with_n_ctx(std::num::NonZeroU32::new(512));
        let mut context = model.new_context(&backend, ctx_params).unwrap();
        let tokens = model.str_to_token("hello", AddBos::Always).unwrap();
        let last_index = i32::try_from(tokens.len() - 1).unwrap();
        let mut batch = LlamaBatch::new(512, 1).unwrap();
        batch.add_sequence(&tokens, 0, false).unwrap();
        context.decode(&mut batch).unwrap();

        let logits = context.get_logits_ith(last_index).unwrap();

        assert_eq!(logits.len(), model.n_vocab() as usize);
    }

    #[test]
    #[serial]
    fn token_data_array_ith_returns_valid_data() {
        let (backend, model) = test_model::load_default_model().unwrap();
        let ctx_params = LlamaContextParams::default().with_n_ctx(std::num::NonZeroU32::new(512));
        let mut context = model.new_context(&backend, ctx_params).unwrap();
        let tokens = model.str_to_token("hello", AddBos::Always).unwrap();
        let last_index = i32::try_from(tokens.len() - 1).unwrap();
        let mut batch = LlamaBatch::new(512, 1).unwrap();
        batch.add_sequence(&tokens, 0, false).unwrap();
        context.decode(&mut batch).unwrap();

        let token_data_array = context.token_data_array_ith(last_index).unwrap();

        assert_eq!(token_data_array.data.len(), model.n_vocab() as usize);
    }

    #[test]
    #[serial]
    fn embeddings_ith_returns_error_when_embeddings_disabled() {
        let (backend, model) = test_model::load_default_model().unwrap();
        let ctx_params = LlamaContextParams::default()
            .with_n_ctx(std::num::NonZeroU32::new(512))
            .with_embeddings(false);
        let context = model.new_context(&backend, ctx_params).unwrap();

        let result = context.embeddings_ith(0);

        assert!(result.is_err());
    }

    #[test]
    #[serial]
    fn embeddings_seq_ith_returns_error_when_embeddings_disabled() {
        let (backend, model) = test_model::load_default_model().unwrap();
        let ctx_params = LlamaContextParams::default()
            .with_n_ctx(std::num::NonZeroU32::new(512))
            .with_embeddings(false);
        let context = model.new_context(&backend, ctx_params).unwrap();

        let result = context.embeddings_seq_ith(0);

        assert!(result.is_err());
    }

    #[test]
    #[serial]
    fn candidates_returns_n_vocab_entries() {
        let (backend, model) = test_model::load_default_model().unwrap();
        let ctx_params = LlamaContextParams::default().with_n_ctx(std::num::NonZeroU32::new(512));
        let mut context = model.new_context(&backend, ctx_params).unwrap();
        let tokens = model.str_to_token("hello", AddBos::Always).unwrap();
        let mut batch = LlamaBatch::new(512, 1).unwrap();
        batch.add_sequence(&tokens, 0, false).unwrap();
        context.decode(&mut batch).unwrap();

        let count = context.candidates().unwrap().count();

        assert_eq!(count, model.n_vocab() as usize);
    }

    #[test]
    #[serial]
    fn debug_format_contains_struct_name() {
        let (backend, model) = test_model::load_default_model().unwrap();
        let ctx_params = LlamaContextParams::default().with_n_ctx(std::num::NonZeroU32::new(512));
        let context = model.new_context(&backend, ctx_params).unwrap();
        let debug_output = format!("{context:?}");

        assert!(debug_output.contains("LlamaContext"));
    }

    #[test]
    #[serial]
    fn decode_with_embeddings_enabled() {
        let (backend, model) = test_model::load_default_embedding_model().unwrap();
        let ctx_params = LlamaContextParams::default()
            .with_n_ctx(std::num::NonZeroU32::new(512))
            .with_embeddings(true);
        let mut context = model.new_context(&backend, ctx_params).unwrap();
        let tokens = model.str_to_token("hello", AddBos::Always).unwrap();
        let mut batch = LlamaBatch::new(512, 1).unwrap();
        batch.add_sequence(&tokens, 0, false).unwrap();

        let result = context.decode(&mut batch);

        assert!(result.is_ok());
    }

    #[test]
    #[serial]
    fn embeddings_seq_ith_returns_valid_embeddings() {
        let (backend, model) = test_model::load_default_embedding_model().unwrap();
        let ctx_params = LlamaContextParams::default()
            .with_n_ctx(std::num::NonZeroU32::new(512))
            .with_embeddings(true);
        let mut context = model.new_context(&backend, ctx_params).unwrap();
        let tokens = model.str_to_token("hello", AddBos::Always).unwrap();
        let mut batch = LlamaBatch::new(512, 1).unwrap();
        batch.add_sequence(&tokens, 0, false).unwrap();
        context.decode(&mut batch).unwrap();

        let embeddings = context.embeddings_seq_ith(0).unwrap();

        assert_eq!(embeddings.len(), model.n_embd() as usize);
    }

    #[test]
    #[serial]
    fn embeddings_ith_returns_valid_embeddings() {
        let (backend, model) = test_model::load_default_embedding_model().unwrap();
        let ctx_params = LlamaContextParams::default()
            .with_n_ctx(std::num::NonZeroU32::new(512))
            .with_embeddings(true);
        let mut context = model.new_context(&backend, ctx_params).unwrap();
        let tokens = model.str_to_token("hello", AddBos::Always).unwrap();
        let last_index = i32::try_from(tokens.len() - 1).unwrap();
        let mut batch = LlamaBatch::new(512, 1).unwrap();
        batch.add_sequence(&tokens, 0, false).unwrap();
        context.decode(&mut batch).unwrap();

        let embeddings = context.embeddings_ith(last_index).unwrap();

        assert_eq!(embeddings.len(), model.n_embd() as usize);
    }

    #[test]
    #[serial]
    fn candidates_ith_returns_n_vocab_entries() {
        let (backend, model) = test_model::load_default_model().unwrap();
        let ctx_params = LlamaContextParams::default().with_n_ctx(std::num::NonZeroU32::new(512));
        let mut context = model.new_context(&backend, ctx_params).unwrap();
        let tokens = model.str_to_token("hello", AddBos::Always).unwrap();
        let last_index = i32::try_from(tokens.len() - 1).unwrap();
        let mut batch = LlamaBatch::new(512, 1).unwrap();
        batch.add_sequence(&tokens, 0, false).unwrap();
        context.decode(&mut batch).unwrap();

        let count = context.candidates_ith(last_index).unwrap().count();

        assert_eq!(count, model.n_vocab() as usize);
    }

    #[test]
    #[serial]
    fn lora_adapter_remove_succeeds_with_no_adapters() {
        let (backend, model) = test_model::load_default_model().unwrap();
        let ctx_params = LlamaContextParams::default().with_n_ctx(std::num::NonZeroU32::new(512));
        let context = model.new_context(&backend, ctx_params).unwrap();
        // NOTE(review): a dangling adapter pointer is used here; this only
        // works because `lora_adapter_remove` never dereferences its argument
        // (it clears the whole adapter set). Confirm this coupling is intended.
        let mut adapter = crate::model::LlamaLoraAdapter {
            lora_adapter: std::ptr::NonNull::dangling(),
        };

        let result = context.lora_adapter_remove(&mut adapter);

        assert!(result.is_ok());
    }

    #[test]
    #[serial]
    fn encode_on_non_encoder_model_returns_error() {
        let (backend, model) = test_model::load_default_model().unwrap();
        let ctx_params = LlamaContextParams::default().with_n_ctx(std::num::NonZeroU32::new(512));
        let mut context = model.new_context(&backend, ctx_params).unwrap();
        let tokens = model.str_to_token("hello", AddBos::Always).unwrap();
        let mut batch = LlamaBatch::new(512, 1).unwrap();
        batch.add_sequence(&tokens, 0, false).unwrap();

        let result = context.encode(&mut batch);

        assert!(result.is_err());
    }

    #[test]
    #[serial]
    fn lora_adapter_set_with_dangling_pointer_succeeds_or_errors() {
        let (backend, model) = test_model::load_default_model().unwrap();
        let ctx_params = LlamaContextParams::default().with_n_ctx(std::num::NonZeroU32::new(512));
        let context = model.new_context(&backend, ctx_params).unwrap();
        // NOTE(review): this passes a dangling adapter pointer into llama.cpp
        // and asserts success; presumably the pointer is stored, not read,
        // during `llama_set_adapters_lora` — verify against the binding, as
        // dereferencing it would be UB.
        let mut adapter = crate::model::LlamaLoraAdapter {
            lora_adapter: std::ptr::NonNull::dangling(),
        };

        let result = context.lora_adapter_set(&mut adapter, 1.0);

        assert!(result.is_ok());
    }

    #[test]
    #[serial]
    fn embeddings_ith_returns_null_embedding_error_for_non_embedding_token() {
        let (backend, model) = test_model::load_default_embedding_model().unwrap();
        let ctx_params = LlamaContextParams::default()
            .with_n_ctx(std::num::NonZeroU32::new(512))
            .with_embeddings(true);
        let context = model.new_context(&backend, ctx_params).unwrap();

        // Index 999 was never decoded, so no embedding exists for it.
        let result = context.embeddings_ith(999);

        assert!(result.is_err());
    }

    #[test]
    #[serial]
    fn embeddings_seq_ith_returns_null_embedding_error_for_invalid_seq() {
        let (backend, model) = test_model::load_default_model().unwrap();
        let ctx_params = LlamaContextParams::default()
            .with_n_ctx(std::num::NonZeroU32::new(512))
            .with_embeddings(true);
        let mut context = model.new_context(&backend, ctx_params).unwrap();
        let tokens = model.str_to_token("hello", AddBos::Always).unwrap();
        let mut batch = LlamaBatch::new(512, 1).unwrap();
        batch.add_sequence(&tokens, 0, false).unwrap();
        context.decode(&mut batch).unwrap();

        // Sequence 999 was never used in the batch.
        let result = context.embeddings_seq_ith(999);

        assert!(result.is_err());
    }

    #[test]
    #[serial]
    fn decode_empty_batch_returns_error() {
        let (backend, model) = test_model::load_default_model().unwrap();
        let ctx_params = LlamaContextParams::default().with_n_ctx(std::num::NonZeroU32::new(512));
        let mut context = model.new_context(&backend, ctx_params).unwrap();
        let mut batch = LlamaBatch::new(512, 1).unwrap();

        let result = context.decode(&mut batch);

        assert!(result.is_err());
    }

    #[test]
    #[serial]
    fn encode_succeeds_with_encoder_model() {
        let backend = crate::llama_backend::LlamaBackend::init().unwrap();
        let model_path = test_model::download_encoder_model().unwrap();
        let model_params = crate::model::params::LlamaModelParams::default();
        let model =
            crate::model::LlamaModel::load_from_file(&backend, &model_path, &model_params).unwrap();
        let ctx_params = LlamaContextParams::default()
            .with_n_ctx(std::num::NonZeroU32::new(512))
            .with_embeddings(true);
        let mut context = model.new_context(&backend, ctx_params).unwrap();
        let tokens = model.str_to_token("hello", AddBos::Never).unwrap();
        let mut batch = LlamaBatch::new(512, 1).unwrap();
        batch.add_sequence(&tokens, 0, false).unwrap();

        let result = context.encode(&mut batch);

        assert!(result.is_ok());
    }

    #[test]
    #[serial]
    fn handle_encode_result_ok_updates_logits() {
        let (backend, model) = test_model::load_default_model().unwrap();
        let ctx_params = LlamaContextParams::default().with_n_ctx(std::num::NonZeroU32::new(512));
        let mut context = model.new_context(&backend, ctx_params).unwrap();
        let tokens = model.str_to_token("hello", AddBos::Always).unwrap();
        let mut batch = LlamaBatch::new(512, 1).unwrap();
        batch.add_sequence(&tokens, 0, true).unwrap();

        // Feed an OK status directly to verify the bookkeeping side effect.
        let result =
            context.handle_encode_result(llama_cpp_bindings_sys::LLAMA_RS_STATUS_OK, &mut batch);

        assert!(result.is_ok());
        assert!(!context.initialized_logits.is_empty());
    }
}