// llama_cpp_bindings/context/session.rs

//! Utilities for saving and restoring llama.cpp context state, either to session files or to
//! in-memory byte buffers (whole-context or per-sequence).

3use crate::context::LlamaContext;
4use crate::context::llama_state_seq_flags::LlamaStateSeqFlags;
5use crate::context::load_seq_state_error::LoadSeqStateError;
6use crate::context::load_session_error::LoadSessionError;
7use crate::context::save_seq_state_error::SaveSeqStateError;
8use crate::context::save_session_error::SaveSessionError;
9use crate::token::LlamaToken;
10use std::ffi::CString;
11use std::path::Path;
12
13fn process_session_load_result(
14    success: bool,
15    n_out: usize,
16    max_tokens: usize,
17    mut tokens: Vec<LlamaToken>,
18) -> Result<Vec<LlamaToken>, LoadSessionError> {
19    if !success {
20        return Err(LoadSessionError::FailedToLoad);
21    }
22
23    if n_out > max_tokens {
24        return Err(LoadSessionError::InsufficientMaxLength { n_out, max_tokens });
25    }
26
27    unsafe { tokens.set_len(n_out) };
28
29    Ok(tokens)
30}
31
32fn process_seq_load_result(
33    bytes_read: usize,
34    n_out: usize,
35    max_tokens: usize,
36    mut tokens: Vec<LlamaToken>,
37) -> Result<(Vec<LlamaToken>, usize), LoadSeqStateError> {
38    if bytes_read == 0 {
39        return Err(LoadSeqStateError::FailedToLoad);
40    }
41
42    if n_out > max_tokens {
43        return Err(LoadSeqStateError::InsufficientMaxLength { n_out, max_tokens });
44    }
45
46    unsafe { tokens.set_len(n_out) };
47
48    Ok((tokens, bytes_read))
49}
50
51impl LlamaContext<'_> {
52    /// Save the full state to a file.
53    ///
54    /// # Parameters
55    ///
56    /// * `path_session` - The file to save to.
57    /// * `tokens` - The tokens to associate the state with. This should be a prefix of a sequence
58    ///   of tokens that the context has processed, so that the relevant KV caches are already filled.
59    ///
60    /// # Errors
61    ///
62    /// Fails if the path is not a valid utf8 or llama.cpp fails to save the state file.
63    pub fn state_save_file(
64        &self,
65        path_session: impl AsRef<Path>,
66        tokens: &[LlamaToken],
67    ) -> Result<(), SaveSessionError> {
68        let path = path_session.as_ref();
69        let path = path
70            .to_str()
71            .ok_or_else(|| SaveSessionError::PathToStrError(path.to_path_buf()))?;
72
73        let cstr = CString::new(path)?;
74
75        if unsafe {
76            llama_cpp_bindings_sys::llama_state_save_file(
77                self.context.as_ptr(),
78                cstr.as_ptr(),
79                tokens
80                    .as_ptr()
81                    .cast::<llama_cpp_bindings_sys::llama_token>(),
82                tokens.len(),
83            )
84        } {
85            Ok(())
86        } else {
87            Err(SaveSessionError::FailedToSave)
88        }
89    }
90
91    /// Load a state file into the current context.
92    ///
93    /// You still need to pass the returned tokens to the context for inference to work. What this
94    /// function buys you is that the KV caches are already filled with the relevant data.
95    ///
96    /// # Parameters
97    ///
98    /// * `path_session` - The file to load from. It must be a state file from a compatible context,
99    ///   otherwise the function will error.
100    /// * `max_tokens` - The maximum token length of the loaded state. If the state was saved with a
101    ///   longer length, the function will error.
102    ///
103    /// # Errors
104    ///
105    /// Fails if the path is not a valid utf8 or llama.cpp fails to load the state file.
106    pub fn state_load_file(
107        &mut self,
108        path_session: impl AsRef<Path>,
109        max_tokens: usize,
110    ) -> Result<Vec<LlamaToken>, LoadSessionError> {
111        let path = path_session.as_ref();
112        let path = path
113            .to_str()
114            .ok_or_else(|| LoadSessionError::PathToStrError(path.to_path_buf()))?;
115
116        let cstr = CString::new(path)?;
117        let mut tokens: Vec<LlamaToken> = Vec::with_capacity(max_tokens);
118        let mut n_out = 0;
119
120        // SAFETY: cast is valid as LlamaToken is repr(transparent)
121        let tokens_out = tokens
122            .as_mut_ptr()
123            .cast::<llama_cpp_bindings_sys::llama_token>();
124
125        let success = unsafe {
126            llama_cpp_bindings_sys::llama_state_load_file(
127                self.context.as_ptr(),
128                cstr.as_ptr(),
129                tokens_out,
130                max_tokens,
131                &raw mut n_out,
132            )
133        };
134        process_session_load_result(success, n_out, max_tokens, tokens)
135    }
136
137    /// Save state for a single sequence to a file.
138    ///
139    /// This enables saving state for individual sequences, which is useful for multi-sequence
140    /// inference scenarios.
141    ///
142    /// # Parameters
143    ///
144    /// * `filepath` - The file to save to.
145    /// * `seq_id` - The sequence ID whose state to save.
146    /// * `tokens` - The tokens to associate with the saved state.
147    ///
148    /// # Errors
149    ///
150    /// Fails if the path is not a valid utf8 or llama.cpp fails to save the sequence state file.
151    ///
152    /// # Returns
153    ///
154    /// The number of bytes written on success.
155    pub fn state_seq_save_file(
156        &self,
157        filepath: impl AsRef<Path>,
158        seq_id: i32,
159        tokens: &[LlamaToken],
160    ) -> Result<usize, SaveSeqStateError> {
161        let path = filepath.as_ref();
162        let path = path
163            .to_str()
164            .ok_or_else(|| SaveSeqStateError::PathToStrError(path.to_path_buf()))?;
165
166        let cstr = CString::new(path)?;
167
168        let bytes_written = unsafe {
169            llama_cpp_bindings_sys::llama_state_seq_save_file(
170                self.context.as_ptr(),
171                cstr.as_ptr(),
172                seq_id,
173                tokens
174                    .as_ptr()
175                    .cast::<llama_cpp_bindings_sys::llama_token>(),
176                tokens.len(),
177            )
178        };
179
180        if bytes_written == 0 {
181            Err(SaveSeqStateError::FailedToSave)
182        } else {
183            Ok(bytes_written)
184        }
185    }
186
187    /// Load state for a single sequence from a file.
188    ///
189    /// This enables loading state for individual sequences, which is useful for multi-sequence
190    /// inference scenarios.
191    ///
192    /// # Parameters
193    ///
194    /// * `filepath` - The file to load from.
195    /// * `dest_seq_id` - The destination sequence ID to load the state into.
196    /// * `max_tokens` - The maximum number of tokens to read.
197    ///
198    /// # Errors
199    ///
200    /// Fails if the path is not a valid utf8 or llama.cpp fails to load the sequence state file.
201    ///
202    /// # Returns
203    ///
204    /// A tuple of `(tokens, bytes_read)` on success.
205    pub fn state_seq_load_file(
206        &mut self,
207        filepath: impl AsRef<Path>,
208        dest_seq_id: i32,
209        max_tokens: usize,
210    ) -> Result<(Vec<LlamaToken>, usize), LoadSeqStateError> {
211        let path = filepath.as_ref();
212        let path = path
213            .to_str()
214            .ok_or_else(|| LoadSeqStateError::PathToStrError(path.to_path_buf()))?;
215
216        let cstr = CString::new(path)?;
217        let mut tokens: Vec<LlamaToken> = Vec::with_capacity(max_tokens);
218        let mut n_out = 0;
219
220        // SAFETY: cast is valid as LlamaToken is repr(transparent)
221        let tokens_out = tokens
222            .as_mut_ptr()
223            .cast::<llama_cpp_bindings_sys::llama_token>();
224
225        let bytes_read = unsafe {
226            llama_cpp_bindings_sys::llama_state_seq_load_file(
227                self.context.as_ptr(),
228                cstr.as_ptr(),
229                dest_seq_id,
230                tokens_out,
231                max_tokens,
232                &raw mut n_out,
233            )
234        };
235
236        process_seq_load_result(bytes_read, n_out, max_tokens, tokens)
237    }
238
239    /// Returns the maximum size in bytes of the state (rng, logits, embedding
240    /// and `kv_cache`) - will often be smaller after compacting tokens
241    #[must_use]
242    pub fn get_state_size(&self) -> usize {
243        unsafe { llama_cpp_bindings_sys::llama_state_get_size(self.context.as_ptr()) }
244    }
245
246    /// Copies the state to the specified destination buffer.
247    ///
248    /// Use [`get_state_size`](Self::get_state_size) to determine the required buffer size.
249    ///
250    /// Returns the number of bytes copied.
251    ///
252    /// # Safety
253    ///
254    /// The `dest` buffer must be large enough to hold the complete state data.
255    pub unsafe fn copy_state_data(&self, dest: &mut [u8]) -> usize {
256        unsafe {
257            llama_cpp_bindings_sys::llama_state_get_data(
258                self.context.as_ptr(),
259                dest.as_mut_ptr(),
260                dest.len(),
261            )
262        }
263    }
264
265    /// Set the state reading from the specified buffer.
266    ///
267    /// Returns the number of bytes read.
268    ///
269    /// # Safety
270    ///
271    /// The `src` buffer must contain data previously obtained from [`copy_state_data`](Self::copy_state_data)
272    /// on a compatible context (same model and parameters). Passing arbitrary or corrupted bytes
273    /// will lead to undefined behavior.
274    pub unsafe fn set_state_data(&mut self, src: &[u8]) -> usize {
275        unsafe {
276            llama_cpp_bindings_sys::llama_state_set_data(
277                self.context.as_ptr(),
278                src.as_ptr(),
279                src.len(),
280            )
281        }
282    }
283
284    /// Get the size of the state data for a specific sequence, with extended flags.
285    ///
286    /// Useful for hybrid/recurrent models where partial state (e.g., only SSM state)
287    /// may be saved or restored.
288    #[must_use]
289    pub fn state_seq_get_size_ext(&self, seq_id: i32, flags: &LlamaStateSeqFlags) -> usize {
290        unsafe {
291            llama_cpp_bindings_sys::llama_state_seq_get_size_ext(
292                self.context.as_ptr(),
293                seq_id,
294                flags.bits(),
295            )
296        }
297    }
298
299    /// Copy state data for a specific sequence into `dest`, with extended flags.
300    ///
301    /// Use [`state_seq_get_size_ext`](Self::state_seq_get_size_ext) to determine the required
302    /// buffer size before calling this method.
303    ///
304    /// Returns the number of bytes written.
305    ///
306    /// # Safety
307    ///
308    /// The `dest` buffer must be large enough to hold the complete state data.
309    pub unsafe fn state_seq_get_data_ext(
310        &self,
311        dest: &mut [u8],
312        seq_id: i32,
313        flags: &LlamaStateSeqFlags,
314    ) -> usize {
315        unsafe {
316            llama_cpp_bindings_sys::llama_state_seq_get_data_ext(
317                self.context.as_ptr(),
318                dest.as_mut_ptr(),
319                dest.len(),
320                seq_id,
321                flags.bits(),
322            )
323        }
324    }
325
326    /// Restore state data for a specific sequence from `src`, with extended flags.
327    ///
328    /// Returns the number of bytes read.
329    ///
330    /// # Safety
331    ///
332    /// The `src` buffer must contain data previously obtained from
333    /// [`state_seq_get_data_ext`](Self::state_seq_get_data_ext) on a compatible context.
334    pub unsafe fn state_seq_set_data_ext(
335        &mut self,
336        src: &[u8],
337        dest_seq_id: i32,
338        flags: &LlamaStateSeqFlags,
339    ) -> usize {
340        unsafe {
341            llama_cpp_bindings_sys::llama_state_seq_set_data_ext(
342                self.context.as_ptr(),
343                src.as_ptr(),
344                src.len(),
345                dest_seq_id,
346                flags.bits(),
347            )
348        }
349    }
350}
351
#[cfg(test)]
mod unit_tests {
    use super::{process_seq_load_result, process_session_load_result};
    use crate::context::load_seq_state_error::LoadSeqStateError;
    use crate::context::load_session_error::LoadSessionError;
    use crate::token::LlamaToken;

    /// Fully initialized 100-token buffer that every case below starts from.
    fn hundred_tokens() -> Vec<LlamaToken> {
        vec![LlamaToken::new(0); 100]
    }

    #[test]
    fn session_load_success_within_bounds() {
        let loaded = process_session_load_result(true, 10, 100, hundred_tokens())
            .expect("an in-bounds successful load must yield tokens");

        assert_eq!(loaded.len(), 10);
    }

    #[test]
    fn session_load_fails_when_not_successful() {
        let outcome = process_session_load_result(false, 0, 100, hundred_tokens());

        assert_eq!(outcome, Err(LoadSessionError::FailedToLoad));
    }

    #[test]
    fn session_load_fails_when_n_out_exceeds_max() {
        let expected = Err(LoadSessionError::InsufficientMaxLength {
            n_out: 101,
            max_tokens: 100,
        });

        assert_eq!(
            process_session_load_result(true, 101, 100, hundred_tokens()),
            expected
        );
    }

    #[test]
    fn seq_load_success_within_bounds() {
        let (loaded, bytes) = process_seq_load_result(42, 10, 100, hundred_tokens())
            .expect("an in-bounds load with nonzero bytes must succeed");

        assert_eq!(loaded.len(), 10);
        assert_eq!(bytes, 42);
    }

    #[test]
    fn seq_load_fails_when_zero_bytes_read() {
        let outcome = process_seq_load_result(0, 0, 100, hundred_tokens());

        assert_eq!(outcome, Err(LoadSeqStateError::FailedToLoad));
    }

    #[test]
    fn seq_load_fails_when_n_out_exceeds_max() {
        let expected = Err(LoadSeqStateError::InsufficientMaxLength {
            n_out: 101,
            max_tokens: 100,
        });

        assert_eq!(
            process_seq_load_result(42, 101, 100, hundred_tokens()),
            expected
        );
    }
}
420                max_tokens: 100,
421            })
422        );
423    }
424}