// llama_cpp_bindings/context/session.rs
//! utilities for working with session files

use crate::context::LlamaContext;
use crate::context::llama_state_seq_flags::LlamaStateSeqFlags;
use crate::context::load_seq_state_error::LoadSeqStateError;
use crate::context::load_session_error::LoadSessionError;
use crate::context::save_seq_state_error::SaveSeqStateError;
use crate::context::save_session_error::SaveSessionError;
use crate::token::LlamaToken;
use std::ffi::CString;
use std::path::Path;

13fn process_session_load_result(
14    success: bool,
15    n_out: usize,
16    max_tokens: usize,
17    mut tokens: Vec<LlamaToken>,
18) -> Result<Vec<LlamaToken>, LoadSessionError> {
19    if !success {
20        return Err(LoadSessionError::FailedToLoad);
21    }
22
23    if n_out > max_tokens {
24        return Err(LoadSessionError::InsufficientMaxLength { n_out, max_tokens });
25    }
26
27    unsafe { tokens.set_len(n_out) };
28
29    Ok(tokens)
30}
31
32fn process_seq_load_result(
33    bytes_read: usize,
34    n_out: usize,
35    max_tokens: usize,
36    mut tokens: Vec<LlamaToken>,
37) -> Result<(Vec<LlamaToken>, usize), LoadSeqStateError> {
38    if bytes_read == 0 {
39        return Err(LoadSeqStateError::FailedToLoad);
40    }
41
42    if n_out > max_tokens {
43        return Err(LoadSeqStateError::InsufficientMaxLength { n_out, max_tokens });
44    }
45
46    unsafe { tokens.set_len(n_out) };
47
48    Ok((tokens, bytes_read))
49}
50
impl LlamaContext<'_> {
    /// Save the full state to a file.
    ///
    /// # Parameters
    ///
    /// * `path_session` - The file to save to.
    /// * `tokens` - The tokens to associate the state with. This should be a prefix of a sequence
    ///   of tokens that the context has processed, so that the relevant KV caches are already filled.
    ///
    /// # Errors
    ///
    /// Fails if the path is not a valid utf8 or llama.cpp fails to save the state file.
    pub fn state_save_file(
        &self,
        path_session: impl AsRef<Path>,
        tokens: &[LlamaToken],
    ) -> Result<(), SaveSessionError> {
        let path = path_session.as_ref();
        let path = path
            .to_str()
            .ok_or_else(|| SaveSessionError::PathToStrError(path.to_path_buf()))?;

        let cstr = CString::new(path)?;

        // SAFETY: `self.context` is a live context pointer, `cstr` is a valid
        // NUL-terminated path, and the token pointer/length pair describes a
        // live slice. The pointer cast is valid as LlamaToken is
        // repr(transparent) over `llama_token`.
        if unsafe {
            llama_cpp_bindings_sys::llama_state_save_file(
                self.context.as_ptr(),
                cstr.as_ptr(),
                tokens
                    .as_ptr()
                    .cast::<llama_cpp_bindings_sys::llama_token>(),
                tokens.len(),
            )
        } {
            Ok(())
        } else {
            Err(SaveSessionError::FailedToSave)
        }
    }

    /// Load a state file into the current context.
    ///
    /// You still need to pass the returned tokens to the context for inference to work. What this
    /// function buys you is that the KV caches are already filled with the relevant data.
    ///
    /// # Parameters
    ///
    /// * `path_session` - The file to load from. It must be a state file from a compatible context,
    ///   otherwise the function will error.
    /// * `max_tokens` - The maximum token length of the loaded state. If the state was saved with a
    ///   longer length, the function will error.
    ///
    /// # Errors
    ///
    /// Fails if the path is not a valid utf8 or llama.cpp fails to load the state file.
    pub fn state_load_file(
        &mut self,
        path_session: impl AsRef<Path>,
        max_tokens: usize,
    ) -> Result<Vec<LlamaToken>, LoadSessionError> {
        let path = path_session.as_ref();
        let path = path
            .to_str()
            .ok_or_else(|| LoadSessionError::PathToStrError(path.to_path_buf()))?;

        let cstr = CString::new(path)?;
        // Length stays 0 until the load result is validated; only then is the
        // vector's length adjusted (in `process_session_load_result`).
        let mut tokens: Vec<LlamaToken> = Vec::with_capacity(max_tokens);
        let mut n_out = 0;

        // SAFETY: cast is valid as LlamaToken is repr(transparent)
        let tokens_out = tokens
            .as_mut_ptr()
            .cast::<llama_cpp_bindings_sys::llama_token>();

        // SAFETY: `self.context` and `cstr` are valid, `tokens_out` points to
        // a buffer with capacity for `max_tokens` tokens, and `n_out` is a
        // valid out-pointer for the token count written.
        let success = unsafe {
            llama_cpp_bindings_sys::llama_state_load_file(
                self.context.as_ptr(),
                cstr.as_ptr(),
                tokens_out,
                max_tokens,
                &raw mut n_out,
            )
        };
        process_session_load_result(success, n_out, max_tokens, tokens)
    }

    /// Save state for a single sequence to a file.
    ///
    /// This enables saving state for individual sequences, which is useful for multi-sequence
    /// inference scenarios.
    ///
    /// # Parameters
    ///
    /// * `filepath` - The file to save to.
    /// * `seq_id` - The sequence ID whose state to save.
    /// * `tokens` - The tokens to associate with the saved state.
    ///
    /// # Errors
    ///
    /// Fails if the path is not a valid utf8 or llama.cpp fails to save the sequence state file.
    ///
    /// # Returns
    ///
    /// The number of bytes written on success.
    pub fn state_seq_save_file(
        &self,
        filepath: impl AsRef<Path>,
        seq_id: i32,
        tokens: &[LlamaToken],
    ) -> Result<usize, SaveSeqStateError> {
        let path = filepath.as_ref();
        let path = path
            .to_str()
            .ok_or_else(|| SaveSeqStateError::PathToStrError(path.to_path_buf()))?;

        let cstr = CString::new(path)?;

        // SAFETY: `self.context` is a live context pointer, `cstr` is a valid
        // NUL-terminated path, and the token pointer/length pair describes a
        // live slice. The pointer cast is valid as LlamaToken is
        // repr(transparent) over `llama_token`.
        let bytes_written = unsafe {
            llama_cpp_bindings_sys::llama_state_seq_save_file(
                self.context.as_ptr(),
                cstr.as_ptr(),
                seq_id,
                tokens
                    .as_ptr()
                    .cast::<llama_cpp_bindings_sys::llama_token>(),
                tokens.len(),
            )
        };

        // A return of 0 bytes is the FFI's failure signal.
        if bytes_written == 0 {
            Err(SaveSeqStateError::FailedToSave)
        } else {
            Ok(bytes_written)
        }
    }

    /// Load state for a single sequence from a file.
    ///
    /// This enables loading state for individual sequences, which is useful for multi-sequence
    /// inference scenarios.
    ///
    /// # Parameters
    ///
    /// * `filepath` - The file to load from.
    /// * `dest_seq_id` - The destination sequence ID to load the state into.
    /// * `max_tokens` - The maximum number of tokens to read.
    ///
    /// # Errors
    ///
    /// Fails if the path is not a valid utf8 or llama.cpp fails to load the sequence state file.
    ///
    /// # Returns
    ///
    /// A tuple of `(tokens, bytes_read)` on success.
    pub fn state_seq_load_file(
        &mut self,
        filepath: impl AsRef<Path>,
        dest_seq_id: i32,
        max_tokens: usize,
    ) -> Result<(Vec<LlamaToken>, usize), LoadSeqStateError> {
        let path = filepath.as_ref();
        let path = path
            .to_str()
            .ok_or_else(|| LoadSeqStateError::PathToStrError(path.to_path_buf()))?;

        let cstr = CString::new(path)?;
        // Length stays 0 until the load result is validated; only then is the
        // vector's length adjusted (in `process_seq_load_result`).
        let mut tokens: Vec<LlamaToken> = Vec::with_capacity(max_tokens);
        let mut n_out = 0;

        // SAFETY: cast is valid as LlamaToken is repr(transparent)
        let tokens_out = tokens
            .as_mut_ptr()
            .cast::<llama_cpp_bindings_sys::llama_token>();

        // SAFETY: `self.context` and `cstr` are valid, `tokens_out` points to
        // a buffer with capacity for `max_tokens` tokens, and `n_out` is a
        // valid out-pointer for the token count written.
        let bytes_read = unsafe {
            llama_cpp_bindings_sys::llama_state_seq_load_file(
                self.context.as_ptr(),
                cstr.as_ptr(),
                dest_seq_id,
                tokens_out,
                max_tokens,
                &raw mut n_out,
            )
        };

        process_seq_load_result(bytes_read, n_out, max_tokens, tokens)
    }

    /// Returns the maximum size in bytes of the state (rng, logits, embedding
    /// and `kv_cache`) - will often be smaller after compacting tokens
    #[must_use]
    pub fn get_state_size(&self) -> usize {
        // SAFETY: the context pointer is valid for the lifetime of `self`.
        unsafe { llama_cpp_bindings_sys::llama_state_get_size(self.context.as_ptr()) }
    }

    /// Copies the state to the specified destination buffer.
    ///
    /// Use [`get_state_size`](Self::get_state_size) to determine the required buffer size.
    ///
    /// Returns the number of bytes copied.
    ///
    /// # Safety
    ///
    /// The `dest` buffer must be large enough to hold the complete state data.
    pub unsafe fn copy_state_data(&self, dest: &mut [u8]) -> usize {
        // SAFETY: the context pointer is valid, `dest` is a live mutable
        // buffer whose length is passed alongside, and the caller guarantees
        // it is large enough (see # Safety above).
        unsafe {
            llama_cpp_bindings_sys::llama_state_get_data(
                self.context.as_ptr(),
                dest.as_mut_ptr(),
                dest.len(),
            )
        }
    }

    /// Set the state reading from the specified buffer.
    ///
    /// Returns the number of bytes read.
    ///
    /// # Safety
    ///
    /// The `src` buffer must contain data previously obtained from [`copy_state_data`](Self::copy_state_data)
    /// on a compatible context (same model and parameters). Passing arbitrary or corrupted bytes
    /// will lead to undefined behavior.
    pub unsafe fn set_state_data(&mut self, src: &[u8]) -> usize {
        // SAFETY: the context pointer is valid, `src` is a live buffer whose
        // length is passed alongside, and the caller guarantees the data is
        // from a compatible context (see # Safety above).
        unsafe {
            llama_cpp_bindings_sys::llama_state_set_data(
                self.context.as_ptr(),
                src.as_ptr(),
                src.len(),
            )
        }
    }

    /// Get the size of the state data for a specific sequence, with extended flags.
    ///
    /// Useful for hybrid/recurrent models where partial state (e.g., only SSM state)
    /// may be saved or restored.
    #[must_use]
    pub fn state_seq_get_size_ext(&self, seq_id: i32, flags: &LlamaStateSeqFlags) -> usize {
        // SAFETY: the context pointer is valid for the lifetime of `self`;
        // `seq_id` and the flag bits are plain values.
        unsafe {
            llama_cpp_bindings_sys::llama_state_seq_get_size_ext(
                self.context.as_ptr(),
                seq_id,
                flags.bits(),
            )
        }
    }

    /// Copy state data for a specific sequence into `dest`, with extended flags.
    ///
    /// Use [`state_seq_get_size_ext`](Self::state_seq_get_size_ext) to determine the required
    /// buffer size before calling this method.
    ///
    /// Returns the number of bytes written.
    ///
    /// # Safety
    ///
    /// The `dest` buffer must be large enough to hold the complete state data.
    pub unsafe fn state_seq_get_data_ext(
        &self,
        dest: &mut [u8],
        seq_id: i32,
        flags: &LlamaStateSeqFlags,
    ) -> usize {
        // SAFETY: the context pointer is valid, `dest` is a live mutable
        // buffer whose length is passed alongside, and the caller guarantees
        // it is large enough (see # Safety above).
        unsafe {
            llama_cpp_bindings_sys::llama_state_seq_get_data_ext(
                self.context.as_ptr(),
                dest.as_mut_ptr(),
                dest.len(),
                seq_id,
                flags.bits(),
            )
        }
    }

    /// Restore state data for a specific sequence from `src`, with extended flags.
    ///
    /// Returns the number of bytes read.
    ///
    /// # Safety
    ///
    /// The `src` buffer must contain data previously obtained from
    /// [`state_seq_get_data_ext`](Self::state_seq_get_data_ext) on a compatible context.
    pub unsafe fn state_seq_set_data_ext(
        &mut self,
        src: &[u8],
        dest_seq_id: i32,
        flags: &LlamaStateSeqFlags,
    ) -> usize {
        // SAFETY: the context pointer is valid, `src` is a live buffer whose
        // length is passed alongside, and the caller guarantees the data is
        // from a compatible context (see # Safety above).
        unsafe {
            llama_cpp_bindings_sys::llama_state_seq_set_data_ext(
                self.context.as_ptr(),
                src.as_ptr(),
                src.len(),
                dest_seq_id,
                flags.bits(),
            )
        }
    }
}
351
#[cfg(test)]
mod unit_tests {
    use crate::token::LlamaToken;

    use crate::context::load_seq_state_error::LoadSeqStateError;
    use crate::context::load_session_error::LoadSessionError;

    use super::{process_seq_load_result, process_session_load_result};

    /// Builds a fully-initialized token buffer of the given length.
    fn filled_tokens(len: usize) -> Vec<LlamaToken> {
        vec![LlamaToken::new(0); len]
    }

    #[test]
    fn session_load_success_within_bounds() {
        let outcome = process_session_load_result(true, 10, 100, filled_tokens(100));

        let loaded = outcome.expect("load within bounds should succeed");
        assert_eq!(loaded.len(), 10);
    }

    #[test]
    fn session_load_fails_when_not_successful() {
        let outcome = process_session_load_result(false, 0, 100, filled_tokens(100));

        assert_eq!(outcome, Err(LoadSessionError::FailedToLoad));
    }

    #[test]
    fn session_load_fails_when_n_out_exceeds_max() {
        let outcome = process_session_load_result(true, 101, 100, filled_tokens(100));

        let expected = LoadSessionError::InsufficientMaxLength {
            n_out: 101,
            max_tokens: 100,
        };
        assert_eq!(outcome, Err(expected));
    }

    #[test]
    fn seq_load_success_within_bounds() {
        let outcome = process_seq_load_result(42, 10, 100, filled_tokens(100));

        let (loaded, bytes) = outcome.expect("load within bounds should succeed");
        assert_eq!(loaded.len(), 10);
        assert_eq!(bytes, 42);
    }

    #[test]
    fn seq_load_fails_when_zero_bytes_read() {
        let outcome = process_seq_load_result(0, 0, 100, filled_tokens(100));

        assert_eq!(outcome, Err(LoadSeqStateError::FailedToLoad));
    }

    #[test]
    fn seq_load_fails_when_n_out_exceeds_max() {
        let outcome = process_seq_load_result(42, 101, 100, filled_tokens(100));

        let expected = LoadSeqStateError::InsufficientMaxLength {
            n_out: 101,
            max_tokens: 100,
        };
        assert_eq!(outcome, Err(expected));
    }
}
425
#[cfg(test)]
#[cfg(feature = "tests_that_use_llms")]
mod tests {
    //! Integration tests for session/state save and load that run against a
    //! real model, gated behind the `tests_that_use_llms` feature.
    //!
    //! NOTE(review): every test is `#[serial]` — presumably because the llama
    //! backend holds process-global state; confirm before relaxing.

    use std::num::NonZeroU32;

    use serial_test::serial;

    use crate::context::params::LlamaContextParams;
    use crate::llama_batch::LlamaBatch;
    use crate::model::AddBos;
    use crate::test_model;

    /// Full round-trip: decode a prompt, save the state to a session file,
    /// then load it back and check the stored token prefix is recovered.
    #[test]
    #[serial]
    fn save_and_load_session_file() {
        let (backend, model) = test_model::load_default_model().unwrap();
        let ctx_params = LlamaContextParams::default().with_n_ctx(NonZeroU32::new(512));
        let mut context = model.new_context(&backend, ctx_params).unwrap();

        // Decode first so the KV cache has state worth saving.
        let tokens = model.str_to_token("Hello world", AddBos::Always).unwrap();
        let mut batch = LlamaBatch::new(512, 1).unwrap();
        batch.add_sequence(&tokens, 0, false).unwrap();
        context.decode(&mut batch).unwrap();

        let session_path = std::env::temp_dir().join("llama_test_session.bin");
        context.state_save_file(&session_path, &tokens).unwrap();

        let loaded_tokens = context.state_load_file(&session_path, 512).unwrap();
        assert_eq!(loaded_tokens, tokens);

        std::fs::remove_file(&session_path).unwrap();
    }

    /// A freshly created context must report a non-zero state size.
    #[test]
    #[serial]
    fn get_state_size_is_positive() {
        let (backend, model) = test_model::load_default_model().unwrap();
        let ctx_params = LlamaContextParams::default().with_n_ctx(NonZeroU32::new(512));
        let context = model.new_context(&backend, ctx_params).unwrap();
        assert!(context.get_state_size() > 0);
    }

    /// Per-sequence round-trip: save sequence 0's state to a file and load it
    /// back, checking both the token prefix and the reported byte counts.
    #[test]
    #[serial]
    fn state_seq_save_and_load_file_roundtrip() {
        let (backend, model) = test_model::load_default_model().unwrap();
        let ctx_params = LlamaContextParams::default().with_n_ctx(NonZeroU32::new(512));
        let mut context = model.new_context(&backend, ctx_params).unwrap();

        // Decode first so sequence 0 has state worth saving.
        let tokens = model.str_to_token("Hello world", AddBos::Always).unwrap();
        let mut batch = LlamaBatch::new(512, 1).unwrap();
        batch.add_sequence(&tokens, 0, false).unwrap();
        context.decode(&mut batch).unwrap();

        let session_path = std::env::temp_dir().join("llama_test_seq_state.bin");
        let bytes_written = context
            .state_seq_save_file(&session_path, 0, &tokens)
            .unwrap();
        assert!(bytes_written > 0);

        let (loaded_tokens, bytes_read) =
            context.state_seq_load_file(&session_path, 0, 512).unwrap();
        assert_eq!(loaded_tokens, tokens);
        assert!(bytes_read > 0);

        std::fs::remove_file(&session_path).unwrap();
    }

    /// In-memory round-trip via the raw-buffer state API
    /// (`copy_state_data` / `set_state_data`).
    #[test]
    #[serial]
    fn copy_state_data_and_set_state_data_roundtrip() {
        let (backend, model) = test_model::load_default_model().unwrap();
        let ctx_params = LlamaContextParams::default().with_n_ctx(NonZeroU32::new(512));
        let mut context = model.new_context(&backend, ctx_params).unwrap();

        // Decode first so the state buffer has meaningful content.
        let tokens = model.str_to_token("Hello world", AddBos::Always).unwrap();
        let mut batch = LlamaBatch::new(512, 1).unwrap();
        batch.add_sequence(&tokens, 0, false).unwrap();
        context.decode(&mut batch).unwrap();

        // Size the buffer from get_state_size, satisfying copy_state_data's
        // safety contract.
        let state_size = context.get_state_size();
        let mut state_data = vec![0u8; state_size];
        let bytes_copied = unsafe { context.copy_state_data(&mut state_data) };
        assert!(bytes_copied > 0);

        let bytes_read = unsafe { context.set_state_data(&state_data) };
        assert!(bytes_read > 0);
    }

    /// Loading a session from a path that does not exist must fail.
    #[test]
    #[serial]
    fn state_load_file_with_nonexistent_file_returns_error() {
        let (backend, model) = test_model::load_default_model().unwrap();
        let ctx_params = LlamaContextParams::default().with_n_ctx(NonZeroU32::new(512));
        let mut context = model.new_context(&backend, ctx_params).unwrap();

        let result = context.state_load_file("/nonexistent/session.bin", 512);

        assert!(result.is_err());
    }

    /// Loading a sequence state from a path that does not exist must fail.
    #[test]
    #[serial]
    fn state_seq_load_file_with_nonexistent_file_returns_error() {
        let (backend, model) = test_model::load_default_model().unwrap();
        let ctx_params = LlamaContextParams::default().with_n_ctx(NonZeroU32::new(512));
        let mut context = model.new_context(&backend, ctx_params).unwrap();

        let result = context.state_seq_load_file("/nonexistent/seq_state.bin", 0, 512);

        assert!(result.is_err());
    }

    /// Saving a session into a non-existent directory must fail.
    #[test]
    #[serial]
    fn state_save_file_to_invalid_directory_returns_failed_to_save() {
        let (backend, model) = test_model::load_default_model().unwrap();
        let ctx_params = LlamaContextParams::default().with_n_ctx(NonZeroU32::new(512));
        let context = model.new_context(&backend, ctx_params).unwrap();

        let result = context.state_save_file("/nonexistent_dir/session.bin", &[]);

        assert!(result.is_err());
    }

    /// Saving a sequence state into a non-existent directory must fail.
    #[test]
    #[serial]
    fn state_seq_save_file_to_invalid_directory_returns_failed_to_save() {
        let (backend, model) = test_model::load_default_model().unwrap();
        let ctx_params = LlamaContextParams::default().with_n_ctx(NonZeroU32::new(512));
        let context = model.new_context(&backend, ctx_params).unwrap();

        let result = context.state_seq_save_file("/nonexistent_dir/seq_state.bin", 0, &[]);

        assert!(result.is_err());
    }

    /// A saved session cannot be loaded back with `max_tokens == 0`.
    #[test]
    #[serial]
    fn state_load_file_with_zero_max_tokens_returns_error() {
        let (backend, model) = test_model::load_default_model().unwrap();
        let ctx_params = LlamaContextParams::default().with_n_ctx(NonZeroU32::new(512));
        let mut context = model.new_context(&backend, ctx_params).unwrap();

        let tokens = model.str_to_token("Hello world", AddBos::Always).unwrap();
        let mut batch = LlamaBatch::new(512, 1).unwrap();
        batch.add_sequence(&tokens, 0, false).unwrap();
        context.decode(&mut batch).unwrap();

        let session_path = std::env::temp_dir().join("llama_test_session_zero_max.bin");
        context.state_save_file(&session_path, &tokens).unwrap();

        let result = context.state_load_file(&session_path, 0);

        assert!(result.is_err());
        // Best-effort cleanup; a failed delete should not fail the test.
        let _ = std::fs::remove_file(&session_path);
    }

    /// A saved sequence state cannot be loaded back with `max_tokens == 0`.
    #[test]
    #[serial]
    fn state_seq_load_file_with_zero_max_tokens_returns_error() {
        let (backend, model) = test_model::load_default_model().unwrap();
        let ctx_params = LlamaContextParams::default().with_n_ctx(NonZeroU32::new(512));
        let mut context = model.new_context(&backend, ctx_params).unwrap();

        let tokens = model.str_to_token("Hello world", AddBos::Always).unwrap();
        let mut batch = LlamaBatch::new(512, 1).unwrap();
        batch.add_sequence(&tokens, 0, false).unwrap();
        context.decode(&mut batch).unwrap();

        let session_path = std::env::temp_dir().join("llama_test_seq_state_zero_max.bin");
        context
            .state_seq_save_file(&session_path, 0, &tokens)
            .unwrap();

        let result = context.state_seq_load_file(&session_path, 0, 0);

        assert!(result.is_err());
        // Best-effort cleanup; a failed delete should not fail the test.
        let _ = std::fs::remove_file(&session_path);
    }

    /// Loading with `max_tokens` smaller than the saved prefix must error.
    #[test]
    #[serial]
    fn state_load_file_with_insufficient_max_tokens_returns_length_error() {
        let (backend, model) = test_model::load_default_model().unwrap();
        let ctx_params = LlamaContextParams::default().with_n_ctx(NonZeroU32::new(512));
        let mut context = model.new_context(&backend, ctx_params).unwrap();

        // A longer prompt so the saved prefix definitely exceeds max_tokens=1.
        let tokens = model
            .str_to_token(
                "Hello world this is a longer string for more tokens",
                AddBos::Always,
            )
            .unwrap();
        let mut batch = LlamaBatch::new(512, 1).unwrap();
        batch.add_sequence(&tokens, 0, false).unwrap();
        context.decode(&mut batch).unwrap();

        let session_path = std::env::temp_dir().join("llama_test_session_insuf.bin");
        context.state_save_file(&session_path, &tokens).unwrap();

        let result = context.state_load_file(&session_path, 1);

        assert!(result.is_err());
        // Best-effort cleanup; a failed delete should not fail the test.
        let _ = std::fs::remove_file(&session_path);
    }

    /// Sequence variant: loading with `max_tokens` smaller than the saved
    /// prefix must error.
    #[test]
    #[serial]
    fn state_seq_load_file_with_insufficient_max_tokens_returns_length_error() {
        let (backend, model) = test_model::load_default_model().unwrap();
        let ctx_params = LlamaContextParams::default().with_n_ctx(NonZeroU32::new(512));
        let mut context = model.new_context(&backend, ctx_params).unwrap();

        // A longer prompt so the saved prefix definitely exceeds max_tokens=1.
        let tokens = model
            .str_to_token(
                "Hello world this is a longer string for more tokens",
                AddBos::Always,
            )
            .unwrap();
        let mut batch = LlamaBatch::new(512, 1).unwrap();
        batch.add_sequence(&tokens, 0, false).unwrap();
        context.decode(&mut batch).unwrap();

        let session_path = std::env::temp_dir().join("llama_test_seq_state_insuf.bin");
        context
            .state_seq_save_file(&session_path, 0, &tokens)
            .unwrap();

        let result = context.state_seq_load_file(&session_path, 0, 1);

        assert!(result.is_err());
        // Best-effort cleanup; a failed delete should not fail the test.
        let _ = std::fs::remove_file(&session_path);
    }

    /// Non-UTF-8 paths must be rejected before reaching the FFI layer
    /// (unix-only: constructing such a path requires `OsStrExt`).
    #[cfg(unix)]
    #[test]
    #[serial]
    fn state_save_file_with_non_utf8_path_returns_error() {
        use std::ffi::OsStr;
        use std::os::unix::ffi::OsStrExt;

        let (backend, model) = test_model::load_default_model().unwrap();
        let ctx_params = LlamaContextParams::default().with_n_ctx(NonZeroU32::new(512));
        let context = model.new_context(&backend, ctx_params).unwrap();

        // 0xFF 0xFE is not valid UTF-8, so Path::to_str() returns None.
        let non_utf8_path = std::path::Path::new(OsStr::from_bytes(b"/tmp/\xff\xfe.bin"));
        let result = context.state_save_file(non_utf8_path, &[]);

        assert!(result.is_err());
    }

    /// Non-UTF-8 path rejection for `state_load_file` (unix-only).
    #[cfg(unix)]
    #[test]
    #[serial]
    fn state_load_file_with_non_utf8_path_returns_error() {
        use std::ffi::OsStr;
        use std::os::unix::ffi::OsStrExt;

        let (backend, model) = test_model::load_default_model().unwrap();
        let ctx_params = LlamaContextParams::default().with_n_ctx(NonZeroU32::new(512));
        let mut context = model.new_context(&backend, ctx_params).unwrap();

        // 0xFF 0xFE is not valid UTF-8, so Path::to_str() returns None.
        let non_utf8_path = std::path::Path::new(OsStr::from_bytes(b"/tmp/\xff\xfe.bin"));
        let result = context.state_load_file(non_utf8_path, 512);

        assert!(result.is_err());
    }

    /// Non-UTF-8 path rejection for `state_seq_save_file` (unix-only).
    #[cfg(unix)]
    #[test]
    #[serial]
    fn state_seq_save_file_with_non_utf8_path_returns_error() {
        use std::ffi::OsStr;
        use std::os::unix::ffi::OsStrExt;

        let (backend, model) = test_model::load_default_model().unwrap();
        let ctx_params = LlamaContextParams::default().with_n_ctx(NonZeroU32::new(512));
        let context = model.new_context(&backend, ctx_params).unwrap();

        // 0xFF 0xFE is not valid UTF-8, so Path::to_str() returns None.
        let non_utf8_path = std::path::Path::new(OsStr::from_bytes(b"/tmp/\xff\xfe.bin"));
        let result = context.state_seq_save_file(non_utf8_path, 0, &[]);

        assert!(result.is_err());
    }

    /// Non-UTF-8 path rejection for `state_seq_load_file` (unix-only).
    #[cfg(unix)]
    #[test]
    #[serial]
    fn state_seq_load_file_with_non_utf8_path_returns_error() {
        use std::ffi::OsStr;
        use std::os::unix::ffi::OsStrExt;

        let (backend, model) = test_model::load_default_model().unwrap();
        let ctx_params = LlamaContextParams::default().with_n_ctx(NonZeroU32::new(512));
        let mut context = model.new_context(&backend, ctx_params).unwrap();

        // 0xFF 0xFE is not valid UTF-8, so Path::to_str() returns None.
        let non_utf8_path = std::path::Path::new(OsStr::from_bytes(b"/tmp/\xff\xfe.bin"));
        let result = context.state_seq_load_file(non_utf8_path, 0, 512);

        assert!(result.is_err());
    }
}