1use crate::context::LlamaContext;
4use crate::context::llama_state_seq_flags::LlamaStateSeqFlags;
5use crate::context::load_seq_state_error::LoadSeqStateError;
6use crate::context::load_session_error::LoadSessionError;
7use crate::context::save_seq_state_error::SaveSeqStateError;
8use crate::context::save_session_error::SaveSessionError;
9use crate::token::LlamaToken;
10use std::ffi::CString;
11use std::path::Path;
12
13fn process_session_load_result(
14 success: bool,
15 n_out: usize,
16 max_tokens: usize,
17 mut tokens: Vec<LlamaToken>,
18) -> Result<Vec<LlamaToken>, LoadSessionError> {
19 if !success {
20 return Err(LoadSessionError::FailedToLoad);
21 }
22
23 if n_out > max_tokens {
24 return Err(LoadSessionError::InsufficientMaxLength { n_out, max_tokens });
25 }
26
27 unsafe { tokens.set_len(n_out) };
28
29 Ok(tokens)
30}
31
32fn process_seq_load_result(
33 bytes_read: usize,
34 n_out: usize,
35 max_tokens: usize,
36 mut tokens: Vec<LlamaToken>,
37) -> Result<(Vec<LlamaToken>, usize), LoadSeqStateError> {
38 if bytes_read == 0 {
39 return Err(LoadSeqStateError::FailedToLoad);
40 }
41
42 if n_out > max_tokens {
43 return Err(LoadSeqStateError::InsufficientMaxLength { n_out, max_tokens });
44 }
45
46 unsafe { tokens.set_len(n_out) };
47
48 Ok((tokens, bytes_read))
49}
50
impl LlamaContext<'_> {
    /// Saves the full context state together with `tokens` to a session file
    /// at `path_session` via `llama_state_save_file`.
    ///
    /// # Errors
    ///
    /// * [`SaveSessionError::PathToStrError`] if the path is not valid UTF-8.
    /// * A `CString` conversion error if the path contains an interior NUL byte.
    /// * [`SaveSessionError::FailedToSave`] if the C call reports failure.
    pub fn state_save_file(
        &self,
        path_session: impl AsRef<Path>,
        tokens: &[LlamaToken],
    ) -> Result<(), SaveSessionError> {
        let path = path_session.as_ref();
        // The C API takes a NUL-terminated string, so the path must be UTF-8
        // representable before it can go through CString.
        let path = path
            .to_str()
            .ok_or_else(|| SaveSessionError::PathToStrError(path.to_path_buf()))?;

        let cstr = CString::new(path)?;

        // SAFETY: `self.context` is a live context pointer, `cstr` outlives
        // the call, and `tokens` supplies `tokens.len()` valid elements. The
        // pointer cast relies on `LlamaToken` being layout-compatible with
        // `llama_token` — presumably `#[repr(transparent)]`; confirm in the
        // token module.
        if unsafe {
            llama_cpp_bindings_sys::llama_state_save_file(
                self.context.as_ptr(),
                cstr.as_ptr(),
                tokens
                    .as_ptr()
                    .cast::<llama_cpp_bindings_sys::llama_token>(),
                tokens.len(),
            )
        } {
            Ok(())
        } else {
            Err(SaveSessionError::FailedToSave)
        }
    }

    /// Loads context state from a session file, returning the token list that
    /// was stored alongside it (at most `max_tokens` tokens).
    ///
    /// # Errors
    ///
    /// * [`LoadSessionError::PathToStrError`] if the path is not valid UTF-8.
    /// * A `CString` conversion error if the path contains an interior NUL byte.
    /// * [`LoadSessionError::FailedToLoad`] if the C call reports failure.
    /// * [`LoadSessionError::InsufficientMaxLength`] if the file held more
    ///   tokens than `max_tokens`.
    pub fn state_load_file(
        &mut self,
        path_session: impl AsRef<Path>,
        max_tokens: usize,
    ) -> Result<Vec<LlamaToken>, LoadSessionError> {
        let path = path_session.as_ref();
        let path = path
            .to_str()
            .ok_or_else(|| LoadSessionError::PathToStrError(path.to_path_buf()))?;

        let cstr = CString::new(path)?;
        // Reserve room for up to `max_tokens`; the C loader fills this buffer
        // and reports the actual count through `n_out`.
        let mut tokens: Vec<LlamaToken> = Vec::with_capacity(max_tokens);
        let mut n_out = 0;

        let tokens_out = tokens
            .as_mut_ptr()
            .cast::<llama_cpp_bindings_sys::llama_token>();

        // SAFETY: `tokens_out` points to a buffer with capacity for
        // `max_tokens` tokens, and `n_out` is a valid out-pointer for the
        // duration of the call.
        let success = unsafe {
            llama_cpp_bindings_sys::llama_state_load_file(
                self.context.as_ptr(),
                cstr.as_ptr(),
                tokens_out,
                max_tokens,
                &raw mut n_out,
            )
        };
        // Validation and the `set_len` truncation live in a helper so they
        // can be unit-tested without the FFI call.
        process_session_load_result(success, n_out, max_tokens, tokens)
    }

    /// Saves the state of a single sequence (`seq_id`) plus `tokens` to a
    /// file, returning the number of bytes written.
    ///
    /// # Errors
    ///
    /// * [`SaveSeqStateError::PathToStrError`] if the path is not valid UTF-8.
    /// * A `CString` conversion error if the path contains an interior NUL byte.
    /// * [`SaveSeqStateError::FailedToSave`] if the C call writes zero bytes
    ///   (its failure signal).
    pub fn state_seq_save_file(
        &self,
        filepath: impl AsRef<Path>,
        seq_id: i32,
        tokens: &[LlamaToken],
    ) -> Result<usize, SaveSeqStateError> {
        let path = filepath.as_ref();
        let path = path
            .to_str()
            .ok_or_else(|| SaveSeqStateError::PathToStrError(path.to_path_buf()))?;

        let cstr = CString::new(path)?;

        // SAFETY: valid context pointer, NUL-terminated path, and a token
        // slice whose pointer/length pair describes initialized memory.
        let bytes_written = unsafe {
            llama_cpp_bindings_sys::llama_state_seq_save_file(
                self.context.as_ptr(),
                cstr.as_ptr(),
                seq_id,
                tokens
                    .as_ptr()
                    .cast::<llama_cpp_bindings_sys::llama_token>(),
                tokens.len(),
            )
        };

        if bytes_written == 0 {
            Err(SaveSeqStateError::FailedToSave)
        } else {
            Ok(bytes_written)
        }
    }

    /// Loads sequence state from a file into `dest_seq_id`, returning the
    /// stored tokens (at most `max_tokens`) and the number of bytes read.
    ///
    /// # Errors
    ///
    /// * [`LoadSeqStateError::PathToStrError`] if the path is not valid UTF-8.
    /// * A `CString` conversion error if the path contains an interior NUL byte.
    /// * [`LoadSeqStateError::FailedToLoad`] if the C call reads zero bytes
    ///   (its failure signal).
    /// * [`LoadSeqStateError::InsufficientMaxLength`] if the file held more
    ///   tokens than `max_tokens`.
    pub fn state_seq_load_file(
        &mut self,
        filepath: impl AsRef<Path>,
        dest_seq_id: i32,
        max_tokens: usize,
    ) -> Result<(Vec<LlamaToken>, usize), LoadSeqStateError> {
        let path = filepath.as_ref();
        let path = path
            .to_str()
            .ok_or_else(|| LoadSeqStateError::PathToStrError(path.to_path_buf()))?;

        let cstr = CString::new(path)?;
        // Reserve room for up to `max_tokens`; the loader reports the actual
        // token count through `n_out`.
        let mut tokens: Vec<LlamaToken> = Vec::with_capacity(max_tokens);
        let mut n_out = 0;

        let tokens_out = tokens
            .as_mut_ptr()
            .cast::<llama_cpp_bindings_sys::llama_token>();

        // SAFETY: `tokens_out` has capacity for `max_tokens` entries and
        // `n_out` is a valid out-pointer for the duration of the call.
        let bytes_read = unsafe {
            llama_cpp_bindings_sys::llama_state_seq_load_file(
                self.context.as_ptr(),
                cstr.as_ptr(),
                dest_seq_id,
                tokens_out,
                max_tokens,
                &raw mut n_out,
            )
        };

        process_seq_load_result(bytes_read, n_out, max_tokens, tokens)
    }

    /// Returns the number of bytes needed to store the full context state
    /// (see [`Self::copy_state_data`]).
    #[must_use]
    pub fn get_state_size(&self) -> usize {
        // SAFETY: `self.context` is a live context pointer.
        unsafe { llama_cpp_bindings_sys::llama_state_get_size(self.context.as_ptr()) }
    }

    /// Copies the full context state into `dest`, returning the number of
    /// bytes written.
    ///
    /// # Safety
    ///
    /// `dest` must be large enough to hold the entire state — presumably at
    /// least [`Self::get_state_size`] bytes; confirm against the bound
    /// llama.cpp version's contract for `llama_state_get_data`.
    pub unsafe fn copy_state_data(&self, dest: &mut [u8]) -> usize {
        // SAFETY: `dest` is a valid, writable buffer of `dest.len()` bytes;
        // the size requirement is the caller's obligation (see above).
        unsafe {
            llama_cpp_bindings_sys::llama_state_get_data(
                self.context.as_ptr(),
                dest.as_mut_ptr(),
                dest.len(),
            )
        }
    }

    /// Restores context state from `src`, returning the number of bytes read.
    ///
    /// # Safety
    ///
    /// `src` must contain state data previously produced for a compatible
    /// context — e.g. by [`Self::copy_state_data`]; passing arbitrary bytes is
    /// undefined behavior at the C level.
    pub unsafe fn set_state_data(&mut self, src: &[u8]) -> usize {
        // SAFETY: `src` is a valid buffer of `src.len()` bytes; validity of
        // its contents is the caller's obligation (see above).
        unsafe {
            llama_cpp_bindings_sys::llama_state_set_data(
                self.context.as_ptr(),
                src.as_ptr(),
                src.len(),
            )
        }
    }

    /// Returns the number of bytes needed to store the state of sequence
    /// `seq_id`, honoring the given `flags`.
    #[must_use]
    pub fn state_seq_get_size_ext(&self, seq_id: i32, flags: &LlamaStateSeqFlags) -> usize {
        // SAFETY: `self.context` is a live context pointer.
        unsafe {
            llama_cpp_bindings_sys::llama_state_seq_get_size_ext(
                self.context.as_ptr(),
                seq_id,
                flags.bits(),
            )
        }
    }

    /// Copies the state of sequence `seq_id` into `dest`, returning the
    /// number of bytes written.
    ///
    /// # Safety
    ///
    /// `dest` must be large enough — presumably at least
    /// [`Self::state_seq_get_size_ext`] bytes for the same `seq_id`/`flags`;
    /// confirm against the bound llama.cpp version.
    pub unsafe fn state_seq_get_data_ext(
        &self,
        dest: &mut [u8],
        seq_id: i32,
        flags: &LlamaStateSeqFlags,
    ) -> usize {
        // SAFETY: `dest` is a valid, writable buffer of `dest.len()` bytes;
        // the size requirement is the caller's obligation (see above).
        unsafe {
            llama_cpp_bindings_sys::llama_state_seq_get_data_ext(
                self.context.as_ptr(),
                dest.as_mut_ptr(),
                dest.len(),
                seq_id,
                flags.bits(),
            )
        }
    }

    /// Restores sequence state from `src` into `dest_seq_id`, returning the
    /// number of bytes read.
    ///
    /// # Safety
    ///
    /// `src` must contain sequence state previously produced for a compatible
    /// context — e.g. by [`Self::state_seq_get_data_ext`]; passing arbitrary
    /// bytes is undefined behavior at the C level.
    pub unsafe fn state_seq_set_data_ext(
        &mut self,
        src: &[u8],
        dest_seq_id: i32,
        flags: &LlamaStateSeqFlags,
    ) -> usize {
        // SAFETY: `src` is a valid buffer of `src.len()` bytes; validity of
        // its contents is the caller's obligation (see above).
        unsafe {
            llama_cpp_bindings_sys::llama_state_seq_set_data_ext(
                self.context.as_ptr(),
                src.as_ptr(),
                src.len(),
                dest_seq_id,
                flags.bits(),
            )
        }
    }
}
351
#[cfg(test)]
mod unit_tests {
    //! Unit tests for the pure result-processing helpers, exercised without
    //! any FFI involvement.

    use crate::token::LlamaToken;

    use crate::context::load_seq_state_error::LoadSeqStateError;
    use crate::context::load_session_error::LoadSessionError;

    use super::{process_seq_load_result, process_session_load_result};

    /// Builds a fully initialized buffer of `len` zero tokens.
    fn zero_tokens(len: usize) -> Vec<LlamaToken> {
        vec![LlamaToken::new(0); len]
    }

    #[test]
    fn session_load_success_within_bounds() {
        let loaded = process_session_load_result(true, 10, 100, zero_tokens(100))
            .expect("in-bounds load must succeed");
        assert_eq!(loaded.len(), 10);
    }

    #[test]
    fn session_load_fails_when_not_successful() {
        assert_eq!(
            process_session_load_result(false, 0, 100, zero_tokens(100)),
            Err(LoadSessionError::FailedToLoad)
        );
    }

    #[test]
    fn session_load_fails_when_n_out_exceeds_max() {
        assert_eq!(
            process_session_load_result(true, 101, 100, zero_tokens(100)),
            Err(LoadSessionError::InsufficientMaxLength {
                n_out: 101,
                max_tokens: 100,
            })
        );
    }

    #[test]
    fn seq_load_success_within_bounds() {
        let (loaded, bytes) = process_seq_load_result(42, 10, 100, zero_tokens(100))
            .expect("in-bounds load must succeed");
        assert_eq!(loaded.len(), 10);
        assert_eq!(bytes, 42);
    }

    #[test]
    fn seq_load_fails_when_zero_bytes_read() {
        assert_eq!(
            process_seq_load_result(0, 0, 100, zero_tokens(100)),
            Err(LoadSeqStateError::FailedToLoad)
        );
    }

    #[test]
    fn seq_load_fails_when_n_out_exceeds_max() {
        assert_eq!(
            process_seq_load_result(42, 101, 100, zero_tokens(100)),
            Err(LoadSeqStateError::InsufficientMaxLength {
                n_out: 101,
                max_tokens: 100,
            })
        );
    }
}
425
#[cfg(test)]
#[cfg(feature = "tests_that_use_llms")]
mod tests {
    //! Integration tests that exercise the state save/load wrappers against a
    //! real model. Each test is `#[serial]` because they all load the backend
    //! and model, and several share files in the system temp directory.

    use std::num::NonZeroU32;

    use serial_test::serial;

    use crate::context::params::LlamaContextParams;
    use crate::llama_batch::LlamaBatch;
    use crate::model::AddBos;
    use crate::test_model;

    #[test]
    #[serial]
    fn save_and_load_session_file() {
        // Build a context and decode one batch so there is real state to save.
        let (backend, model) = test_model::load_default_model().unwrap();
        let ctx_params = LlamaContextParams::default().with_n_ctx(NonZeroU32::new(512));
        let mut context = model.new_context(&backend, ctx_params).unwrap();

        let tokens = model.str_to_token("Hello world", AddBos::Always).unwrap();
        let mut batch = LlamaBatch::new(512, 1).unwrap();
        batch.add_sequence(&tokens, 0, false).unwrap();
        context.decode(&mut batch).unwrap();

        let session_path = std::env::temp_dir().join("llama_test_session.bin");
        context.state_save_file(&session_path, &tokens).unwrap();

        // Loading the session back must reproduce the saved token list.
        let loaded_tokens = context.state_load_file(&session_path, 512).unwrap();
        assert_eq!(loaded_tokens, tokens);

        std::fs::remove_file(&session_path).unwrap();
    }

    #[test]
    #[serial]
    fn get_state_size_is_positive() {
        // Even a freshly created context should report a non-zero state size.
        let (backend, model) = test_model::load_default_model().unwrap();
        let ctx_params = LlamaContextParams::default().with_n_ctx(NonZeroU32::new(512));
        let context = model.new_context(&backend, ctx_params).unwrap();
        assert!(context.get_state_size() > 0);
    }

    #[test]
    #[serial]
    fn state_seq_save_and_load_file_roundtrip() {
        let (backend, model) = test_model::load_default_model().unwrap();
        let ctx_params = LlamaContextParams::default().with_n_ctx(NonZeroU32::new(512));
        let mut context = model.new_context(&backend, ctx_params).unwrap();

        let tokens = model.str_to_token("Hello world", AddBos::Always).unwrap();
        let mut batch = LlamaBatch::new(512, 1).unwrap();
        batch.add_sequence(&tokens, 0, false).unwrap();
        context.decode(&mut batch).unwrap();

        // Save sequence 0, then load it back into the same sequence id.
        let session_path = std::env::temp_dir().join("llama_test_seq_state.bin");
        let bytes_written = context
            .state_seq_save_file(&session_path, 0, &tokens)
            .unwrap();
        assert!(bytes_written > 0);

        let (loaded_tokens, bytes_read) =
            context.state_seq_load_file(&session_path, 0, 512).unwrap();
        assert_eq!(loaded_tokens, tokens);
        assert!(bytes_read > 0);

        std::fs::remove_file(&session_path).unwrap();
    }

    #[test]
    #[serial]
    fn copy_state_data_and_set_state_data_roundtrip() {
        let (backend, model) = test_model::load_default_model().unwrap();
        let ctx_params = LlamaContextParams::default().with_n_ctx(NonZeroU32::new(512));
        let mut context = model.new_context(&backend, ctx_params).unwrap();

        let tokens = model.str_to_token("Hello world", AddBos::Always).unwrap();
        let mut batch = LlamaBatch::new(512, 1).unwrap();
        batch.add_sequence(&tokens, 0, false).unwrap();
        context.decode(&mut batch).unwrap();

        // Size the buffer from get_state_size — the documented precondition
        // for copy_state_data.
        let state_size = context.get_state_size();
        let mut state_data = vec![0u8; state_size];
        let bytes_copied = unsafe { context.copy_state_data(&mut state_data) };
        assert!(bytes_copied > 0);

        // Feed the copied state straight back in; it must be accepted.
        let bytes_read = unsafe { context.set_state_data(&state_data) };
        assert!(bytes_read > 0);
    }

    #[test]
    #[serial]
    fn state_load_file_with_nonexistent_file_returns_error() {
        let (backend, model) = test_model::load_default_model().unwrap();
        let ctx_params = LlamaContextParams::default().with_n_ctx(NonZeroU32::new(512));
        let mut context = model.new_context(&backend, ctx_params).unwrap();

        let result = context.state_load_file("/nonexistent/session.bin", 512);

        assert!(result.is_err());
    }

    #[test]
    #[serial]
    fn state_seq_load_file_with_nonexistent_file_returns_error() {
        let (backend, model) = test_model::load_default_model().unwrap();
        let ctx_params = LlamaContextParams::default().with_n_ctx(NonZeroU32::new(512));
        let mut context = model.new_context(&backend, ctx_params).unwrap();

        let result = context.state_seq_load_file("/nonexistent/seq_state.bin", 0, 512);

        assert!(result.is_err());
    }

    #[test]
    #[serial]
    fn state_save_file_to_invalid_directory_returns_failed_to_save() {
        let (backend, model) = test_model::load_default_model().unwrap();
        let ctx_params = LlamaContextParams::default().with_n_ctx(NonZeroU32::new(512));
        let context = model.new_context(&backend, ctx_params).unwrap();

        let result = context.state_save_file("/nonexistent_dir/session.bin", &[]);

        assert!(result.is_err());
    }

    #[test]
    #[serial]
    fn state_seq_save_file_to_invalid_directory_returns_failed_to_save() {
        let (backend, model) = test_model::load_default_model().unwrap();
        let ctx_params = LlamaContextParams::default().with_n_ctx(NonZeroU32::new(512));
        let context = model.new_context(&backend, ctx_params).unwrap();

        let result = context.state_seq_save_file("/nonexistent_dir/seq_state.bin", 0, &[]);

        assert!(result.is_err());
    }

    #[test]
    #[serial]
    fn state_load_file_with_zero_max_tokens_returns_error() {
        let (backend, model) = test_model::load_default_model().unwrap();
        let ctx_params = LlamaContextParams::default().with_n_ctx(NonZeroU32::new(512));
        let mut context = model.new_context(&backend, ctx_params).unwrap();

        let tokens = model.str_to_token("Hello world", AddBos::Always).unwrap();
        let mut batch = LlamaBatch::new(512, 1).unwrap();
        batch.add_sequence(&tokens, 0, false).unwrap();
        context.decode(&mut batch).unwrap();

        // Save a session with a non-empty token list, then load with a
        // zero-capacity buffer — must be rejected rather than overflow.
        let session_path = std::env::temp_dir().join("llama_test_session_zero_max.bin");
        context.state_save_file(&session_path, &tokens).unwrap();

        let result = context.state_load_file(&session_path, 0);

        assert!(result.is_err());
        let _ = std::fs::remove_file(&session_path);
    }

    #[test]
    #[serial]
    fn state_seq_load_file_with_zero_max_tokens_returns_error() {
        let (backend, model) = test_model::load_default_model().unwrap();
        let ctx_params = LlamaContextParams::default().with_n_ctx(NonZeroU32::new(512));
        let mut context = model.new_context(&backend, ctx_params).unwrap();

        let tokens = model.str_to_token("Hello world", AddBos::Always).unwrap();
        let mut batch = LlamaBatch::new(512, 1).unwrap();
        batch.add_sequence(&tokens, 0, false).unwrap();
        context.decode(&mut batch).unwrap();

        let session_path = std::env::temp_dir().join("llama_test_seq_state_zero_max.bin");
        context
            .state_seq_save_file(&session_path, 0, &tokens)
            .unwrap();

        let result = context.state_seq_load_file(&session_path, 0, 0);

        assert!(result.is_err());
        let _ = std::fs::remove_file(&session_path);
    }

    #[test]
    #[serial]
    fn state_load_file_with_insufficient_max_tokens_returns_length_error() {
        let (backend, model) = test_model::load_default_model().unwrap();
        let ctx_params = LlamaContextParams::default().with_n_ctx(NonZeroU32::new(512));
        let mut context = model.new_context(&backend, ctx_params).unwrap();

        // A longer prompt guarantees more than one token in the session file.
        let tokens = model
            .str_to_token(
                "Hello world this is a longer string for more tokens",
                AddBos::Always,
            )
            .unwrap();
        let mut batch = LlamaBatch::new(512, 1).unwrap();
        batch.add_sequence(&tokens, 0, false).unwrap();
        context.decode(&mut batch).unwrap();

        let session_path = std::env::temp_dir().join("llama_test_session_insuf.bin");
        context.state_save_file(&session_path, &tokens).unwrap();

        // max_tokens = 1 is smaller than the stored token count.
        let result = context.state_load_file(&session_path, 1);

        assert!(result.is_err());
        let _ = std::fs::remove_file(&session_path);
    }

    #[test]
    #[serial]
    fn state_seq_load_file_with_insufficient_max_tokens_returns_length_error() {
        let (backend, model) = test_model::load_default_model().unwrap();
        let ctx_params = LlamaContextParams::default().with_n_ctx(NonZeroU32::new(512));
        let mut context = model.new_context(&backend, ctx_params).unwrap();

        let tokens = model
            .str_to_token(
                "Hello world this is a longer string for more tokens",
                AddBos::Always,
            )
            .unwrap();
        let mut batch = LlamaBatch::new(512, 1).unwrap();
        batch.add_sequence(&tokens, 0, false).unwrap();
        context.decode(&mut batch).unwrap();

        let session_path = std::env::temp_dir().join("llama_test_seq_state_insuf.bin");
        context
            .state_seq_save_file(&session_path, 0, &tokens)
            .unwrap();

        let result = context.state_seq_load_file(&session_path, 0, 1);

        assert!(result.is_err());
        let _ = std::fs::remove_file(&session_path);
    }

    #[cfg(unix)]
    #[test]
    #[serial]
    fn state_save_file_with_non_utf8_path_returns_error() {
        use std::ffi::OsStr;
        use std::os::unix::ffi::OsStrExt;

        let (backend, model) = test_model::load_default_model().unwrap();
        let ctx_params = LlamaContextParams::default().with_n_ctx(NonZeroU32::new(512));
        let context = model.new_context(&backend, ctx_params).unwrap();

        // 0xFF 0xFE is not valid UTF-8, so path.to_str() must fail.
        let non_utf8_path = std::path::Path::new(OsStr::from_bytes(b"/tmp/\xff\xfe.bin"));
        let result = context.state_save_file(non_utf8_path, &[]);

        assert!(result.is_err());
    }

    #[cfg(unix)]
    #[test]
    #[serial]
    fn state_load_file_with_non_utf8_path_returns_error() {
        use std::ffi::OsStr;
        use std::os::unix::ffi::OsStrExt;

        let (backend, model) = test_model::load_default_model().unwrap();
        let ctx_params = LlamaContextParams::default().with_n_ctx(NonZeroU32::new(512));
        let mut context = model.new_context(&backend, ctx_params).unwrap();

        let non_utf8_path = std::path::Path::new(OsStr::from_bytes(b"/tmp/\xff\xfe.bin"));
        let result = context.state_load_file(non_utf8_path, 512);

        assert!(result.is_err());
    }

    #[cfg(unix)]
    #[test]
    #[serial]
    fn state_seq_save_file_with_non_utf8_path_returns_error() {
        use std::ffi::OsStr;
        use std::os::unix::ffi::OsStrExt;

        let (backend, model) = test_model::load_default_model().unwrap();
        let ctx_params = LlamaContextParams::default().with_n_ctx(NonZeroU32::new(512));
        let context = model.new_context(&backend, ctx_params).unwrap();

        let non_utf8_path = std::path::Path::new(OsStr::from_bytes(b"/tmp/\xff\xfe.bin"));
        let result = context.state_seq_save_file(non_utf8_path, 0, &[]);

        assert!(result.is_err());
    }

    #[cfg(unix)]
    #[test]
    #[serial]
    fn state_seq_load_file_with_non_utf8_path_returns_error() {
        use std::ffi::OsStr;
        use std::os::unix::ffi::OsStrExt;

        let (backend, model) = test_model::load_default_model().unwrap();
        let ctx_params = LlamaContextParams::default().with_n_ctx(NonZeroU32::new(512));
        let mut context = model.new_context(&backend, ctx_params).unwrap();

        let non_utf8_path = std::path::Path::new(OsStr::from_bytes(b"/tmp/\xff\xfe.bin"));
        let result = context.state_seq_load_file(non_utf8_path, 0, 512);

        assert!(result.is_err());
    }
}