// llama_cpp_2/context/session.rs
//! Utilities for working with session files.

use crate::context::LlamaContext;
use crate::token::LlamaToken;
use std::ffi::{CString, NulError};
use std::path::{Path, PathBuf};

8/// Flags for state sequence operations.
9///
10/// These flags control what parts of the state are included when saving/restoring
11/// sequence state.
12#[derive(Debug, Clone, Copy, PartialEq, Eq)]
13pub struct LlamaStateSeqFlags(pub(crate) llama_cpp_sys_2::llama_state_seq_flags);
14
15impl LlamaStateSeqFlags {
16    /// Work only with partial states, such as SWA KV cache or recurrent cache (e.g. Mamba).
17    ///
18    /// This flag is useful when you only want to save/restore the recurrent state
19    /// without affecting the KV cache.
20    pub const PARTIAL_ONLY: LlamaStateSeqFlags = LlamaStateSeqFlags(1);
21
22    /// Create an empty flags set.
23    pub const fn empty() -> LlamaStateSeqFlags {
24        LlamaStateSeqFlags(0)
25    }
26
27    /// Get the raw flags value.
28    pub const fn bits(&self) -> u32 {
29        self.0
30    }
31
32    /// Check if a flag is set.
33    pub const fn contains(&self, other: LlamaStateSeqFlags) -> bool {
34        (self.0 & other.0) != 0
35    }
36}
37
38impl Default for LlamaStateSeqFlags {
39    fn default() -> Self {
40        Self::empty()
41    }
42}
43
/// Failed to save a sequence state file.
///
/// Returned by [`LlamaContext::state_seq_save_file`].
#[derive(Debug, Eq, PartialEq, thiserror::Error)]
pub enum SaveSeqStateError {
    /// llama.cpp reported failure (wrote 0 bytes) while saving the sequence state file
    #[error("Failed to save sequence state file")]
    FailedToSave,

    /// The path contained an interior null byte, so it could not be converted to a `CString`
    #[error("null byte in string {0}")]
    NullError(#[from] NulError),

    /// The path was not valid UTF-8, so it could not be converted to a `&str`
    #[error("failed to convert path {0} to str")]
    PathToStrError(PathBuf),
}

60/// Failed to load a sequence state file
61#[derive(Debug, Eq, PartialEq, thiserror::Error)]
62pub enum LoadSeqStateError {
63    /// llama.cpp failed to load the sequence state file
64    #[error("Failed to load sequence state file")]
65    FailedToLoad,
66
67    /// null byte in string
68    #[error("null byte in string {0}")]
69    NullError(#[from] NulError),
70
71    /// failed to convert path to str
72    #[error("failed to convert path {0} to str")]
73    PathToStrError(PathBuf),
74
75    /// Insufficient max length
76    #[error("max_length is not large enough to hold {n_out} (was {max_tokens})")]
77    InsufficientMaxLength {
78        /// The length of the loaded sequence
79        n_out: usize,
80        /// The maximum length
81        max_tokens: usize,
82    },
83}
84
/// Failed to save a session file.
///
/// Returned by [`LlamaContext::save_session_file`] and [`LlamaContext::state_save_file`].
#[derive(Debug, Eq, PartialEq, thiserror::Error)]
pub enum SaveSessionError {
    /// llama.cpp reported failure while saving the session file
    #[error("Failed to save session file")]
    FailedToSave,

    /// The path contained an interior null byte, so it could not be converted to a `CString`
    #[error("null byte in string {0}")]
    NullError(#[from] NulError),

    /// The path was not valid UTF-8, so it could not be converted to a `&str`
    #[error("failed to convert path {0} to str")]
    PathToStrError(PathBuf),
}

101/// Failed to load a Session file
102#[derive(Debug, Eq, PartialEq, thiserror::Error)]
103pub enum LoadSessionError {
104    /// llama.cpp failed to load the session file
105    #[error("Failed to load session file")]
106    FailedToLoad,
107
108    /// null byte in string
109    #[error("null byte in string {0}")]
110    NullError(#[from] NulError),
111
112    /// failed to convert path to str
113    #[error("failed to convert path {0} to str")]
114    PathToStrError(PathBuf),
115
116    /// Insufficient max length
117    #[error("max_length is not large enough to hold {n_out} (was {max_tokens})")]
118    InsufficientMaxLength {
119        /// The length of the session file
120        n_out: usize,
121        /// The maximum length
122        max_tokens: usize,
123    },
124}
125
impl LlamaContext<'_> {
    /// Save the current session to a file.
    ///
    /// # Parameters
    ///
    /// * `path_session` - The file to save to.
    /// * `tokens` - The tokens to associate the session with. This should be a prefix of a sequence of tokens that the context has processed, so that the relevant KV caches are already filled.
    ///
    /// # Errors
    ///
    /// Fails if the path is not valid UTF-8, is not a valid C string, or llama.cpp fails to save the session file.
    #[deprecated(since = "0.1.136", note = "Use `state_save_file` instead")]
    pub fn save_session_file(
        &self,
        path_session: impl AsRef<Path>,
        tokens: &[LlamaToken],
    ) -> Result<(), SaveSessionError> {
        let path = path_session.as_ref();
        let path = path
            .to_str()
            .ok_or_else(|| SaveSessionError::PathToStrError(path.to_path_buf()))?;

        let cstr = CString::new(path)?;

        // SAFETY: `cstr` and `tokens` are kept alive for the duration of the call;
        // the pointer cast is valid because `LlamaToken` is `repr(transparent)`
        // over `llama_token`.
        if unsafe {
            llama_cpp_sys_2::llama_save_session_file(
                self.context.as_ptr(),
                cstr.as_ptr(),
                tokens.as_ptr().cast::<llama_cpp_sys_2::llama_token>(),
                tokens.len(),
            )
        } {
            Ok(())
        } else {
            Err(SaveSessionError::FailedToSave)
        }
    }

    /// Load a session file into the current context.
    ///
    /// You still need to pass the returned tokens to the context for inference to work. What this function buys you is that the KV caches are already filled with the relevant data.
    ///
    /// # Parameters
    ///
    /// * `path_session` - The file to load from. It must be a session file from a compatible context, otherwise the function will error.
    /// * `max_tokens` - The maximum token length of the loaded session. If the session was saved with a longer length, the function will error.
    ///
    /// # Errors
    ///
    /// Fails if the path is not valid UTF-8, is not a valid C string, or llama.cpp fails to load the session file (e.g. the file does not exist, is not a session file, etc.).
    #[deprecated(since = "0.1.136", note = "Use `state_load_file` instead")]
    pub fn load_session_file(
        &mut self,
        path_session: impl AsRef<Path>,
        max_tokens: usize,
    ) -> Result<Vec<LlamaToken>, LoadSessionError> {
        let path = path_session.as_ref();
        let path = path
            .to_str()
            .ok_or(LoadSessionError::PathToStrError(path.to_path_buf()))?;

        let cstr = CString::new(path)?;
        // Capacity is reserved up front; llama.cpp writes tokens directly into
        // this buffer and the length is set afterwards via `set_len`.
        let mut tokens: Vec<LlamaToken> = Vec::with_capacity(max_tokens);
        let mut n_out = 0;

        // SAFETY: cast is valid as LlamaToken is repr(transparent)
        let tokens_out = tokens.as_mut_ptr().cast::<llama_cpp_sys_2::llama_token>();

        // SAFETY: `tokens_out` points to `max_tokens` writable slots and that
        // capacity is passed to llama.cpp, so it will not write out of bounds.
        let load_session_success = unsafe {
            llama_cpp_sys_2::llama_load_session_file(
                self.context.as_ptr(),
                cstr.as_ptr(),
                tokens_out,
                max_tokens,
                &mut n_out,
            )
        };
        if load_session_success {
            // Defensive check: `n_out` should never exceed the capacity we passed in.
            if n_out > max_tokens {
                return Err(LoadSessionError::InsufficientMaxLength { n_out, max_tokens });
            }
            // SAFETY: we checked that n_out <= max_tokens and llama.cpp promises that n_out tokens will be written
            unsafe {
                tokens.set_len(n_out);
            }
            Ok(tokens)
        } else {
            Err(LoadSessionError::FailedToLoad)
        }
    }

    /// Save the full state to a file.
    ///
    /// This is the non-deprecated replacement for [`save_session_file`](Self::save_session_file).
    ///
    /// # Parameters
    ///
    /// * `path_session` - The file to save to.
    /// * `tokens` - The tokens to associate the state with. This should be a prefix of a sequence
    ///   of tokens that the context has processed, so that the relevant KV caches are already filled.
    ///
    /// # Errors
    ///
    /// Fails if the path is not valid UTF-8, is not a valid C string, or llama.cpp fails to save
    /// the state file.
    pub fn state_save_file(
        &self,
        path_session: impl AsRef<Path>,
        tokens: &[LlamaToken],
    ) -> Result<(), SaveSessionError> {
        let path = path_session.as_ref();
        let path = path
            .to_str()
            .ok_or_else(|| SaveSessionError::PathToStrError(path.to_path_buf()))?;

        let cstr = CString::new(path)?;

        // SAFETY: `cstr` and `tokens` are kept alive for the duration of the call;
        // the pointer cast is valid because `LlamaToken` is `repr(transparent)`
        // over `llama_token`.
        if unsafe {
            llama_cpp_sys_2::llama_state_save_file(
                self.context.as_ptr(),
                cstr.as_ptr(),
                tokens.as_ptr().cast::<llama_cpp_sys_2::llama_token>(),
                tokens.len(),
            )
        } {
            Ok(())
        } else {
            Err(SaveSessionError::FailedToSave)
        }
    }

    /// Load a state file into the current context.
    ///
    /// This is the non-deprecated replacement for [`load_session_file`](Self::load_session_file).
    ///
    /// You still need to pass the returned tokens to the context for inference to work. What this
    /// function buys you is that the KV caches are already filled with the relevant data.
    ///
    /// # Parameters
    ///
    /// * `path_session` - The file to load from. It must be a state file from a compatible context,
    ///   otherwise the function will error.
    /// * `max_tokens` - The maximum token length of the loaded state. If the state was saved with a
    ///   longer length, the function will error.
    ///
    /// # Errors
    ///
    /// Fails if the path is not valid UTF-8, is not a valid C string, or llama.cpp fails to load
    /// the state file.
    pub fn state_load_file(
        &mut self,
        path_session: impl AsRef<Path>,
        max_tokens: usize,
    ) -> Result<Vec<LlamaToken>, LoadSessionError> {
        let path = path_session.as_ref();
        let path = path
            .to_str()
            .ok_or(LoadSessionError::PathToStrError(path.to_path_buf()))?;

        let cstr = CString::new(path)?;
        // Capacity is reserved up front; llama.cpp writes tokens directly into
        // this buffer and the length is set afterwards via `set_len`.
        let mut tokens: Vec<LlamaToken> = Vec::with_capacity(max_tokens);
        let mut n_out = 0;

        // SAFETY: cast is valid as LlamaToken is repr(transparent)
        let tokens_out = tokens.as_mut_ptr().cast::<llama_cpp_sys_2::llama_token>();

        // SAFETY: `tokens_out` points to `max_tokens` writable slots and that
        // capacity is passed to llama.cpp, so it will not write out of bounds.
        let success = unsafe {
            llama_cpp_sys_2::llama_state_load_file(
                self.context.as_ptr(),
                cstr.as_ptr(),
                tokens_out,
                max_tokens,
                &mut n_out,
            )
        };
        if success {
            // Defensive check: `n_out` should never exceed the capacity we passed in.
            if n_out > max_tokens {
                return Err(LoadSessionError::InsufficientMaxLength { n_out, max_tokens });
            }
            // SAFETY: we checked that n_out <= max_tokens and llama.cpp promises that n_out tokens will be written
            unsafe {
                tokens.set_len(n_out);
            }
            Ok(tokens)
        } else {
            Err(LoadSessionError::FailedToLoad)
        }
    }

    /// Save state for a single sequence to a file.
    ///
    /// This enables saving state for individual sequences, which is useful for multi-sequence
    /// inference scenarios.
    ///
    /// # Parameters
    ///
    /// * `filepath` - The file to save to.
    /// * `seq_id` - The sequence ID whose state to save.
    /// * `tokens` - The tokens to associate with the saved state.
    ///
    /// # Errors
    ///
    /// Fails if the path is not valid UTF-8, is not a valid C string, or llama.cpp fails to save
    /// the sequence state file.
    ///
    /// # Returns
    ///
    /// The number of bytes written on success.
    pub fn state_seq_save_file(
        &self,
        filepath: impl AsRef<Path>,
        seq_id: i32,
        tokens: &[LlamaToken],
    ) -> Result<usize, SaveSeqStateError> {
        let path = filepath.as_ref();
        let path = path
            .to_str()
            .ok_or_else(|| SaveSeqStateError::PathToStrError(path.to_path_buf()))?;

        let cstr = CString::new(path)?;

        // SAFETY: `cstr` and `tokens` are kept alive for the duration of the call;
        // the pointer cast is valid because `LlamaToken` is `repr(transparent)`
        // over `llama_token`.
        let bytes_written = unsafe {
            llama_cpp_sys_2::llama_state_seq_save_file(
                self.context.as_ptr(),
                cstr.as_ptr(),
                seq_id,
                tokens.as_ptr().cast::<llama_cpp_sys_2::llama_token>(),
                tokens.len(),
            )
        };

        // A return of 0 bytes written signals failure in the C API.
        if bytes_written == 0 {
            Err(SaveSeqStateError::FailedToSave)
        } else {
            Ok(bytes_written)
        }
    }

    /// Load state for a single sequence from a file.
    ///
    /// This enables loading state for individual sequences, which is useful for multi-sequence
    /// inference scenarios.
    ///
    /// # Parameters
    ///
    /// * `filepath` - The file to load from.
    /// * `dest_seq_id` - The destination sequence ID to load the state into.
    /// * `max_tokens` - The maximum number of tokens to read.
    ///
    /// # Errors
    ///
    /// Fails if the path is not valid UTF-8, is not a valid C string, or llama.cpp fails to load
    /// the sequence state file.
    ///
    /// # Returns
    ///
    /// A tuple of `(tokens, bytes_read)` on success.
    pub fn state_seq_load_file(
        &mut self,
        filepath: impl AsRef<Path>,
        dest_seq_id: i32,
        max_tokens: usize,
    ) -> Result<(Vec<LlamaToken>, usize), LoadSeqStateError> {
        let path = filepath.as_ref();
        let path = path
            .to_str()
            .ok_or(LoadSeqStateError::PathToStrError(path.to_path_buf()))?;

        let cstr = CString::new(path)?;
        // Capacity is reserved up front; llama.cpp writes tokens directly into
        // this buffer and the length is set afterwards via `set_len`.
        let mut tokens: Vec<LlamaToken> = Vec::with_capacity(max_tokens);
        let mut n_out = 0;

        // SAFETY: cast is valid as LlamaToken is repr(transparent)
        let tokens_out = tokens.as_mut_ptr().cast::<llama_cpp_sys_2::llama_token>();

        // SAFETY: `tokens_out` points to `max_tokens` writable slots and that
        // capacity is passed to llama.cpp, so it will not write out of bounds.
        let bytes_read = unsafe {
            llama_cpp_sys_2::llama_state_seq_load_file(
                self.context.as_ptr(),
                cstr.as_ptr(),
                dest_seq_id,
                tokens_out,
                max_tokens,
                &mut n_out,
            )
        };

        // A return of 0 bytes read signals failure in the C API.
        if bytes_read == 0 {
            return Err(LoadSeqStateError::FailedToLoad);
        }

        // Defensive check: `n_out` should never exceed the capacity we passed in.
        if n_out > max_tokens {
            return Err(LoadSeqStateError::InsufficientMaxLength { n_out, max_tokens });
        }

        // SAFETY: we checked that n_out <= max_tokens and llama.cpp promises that n_out tokens will be written
        unsafe {
            tokens.set_len(n_out);
        }

        Ok((tokens, bytes_read))
    }

    /// Returns the maximum size in bytes of the state (rng, logits, embedding
    /// and `kv_cache`) - will often be smaller after compacting tokens
    // NOTE(review): this calls `llama_get_state_size`; newer llama.cpp versions
    // name this `llama_state_get_size` — confirm the sys binding still tracks it.
    #[must_use]
    pub fn get_state_size(&self) -> usize {
        // SAFETY: `self.context` is a valid, live context pointer for `&self`'s lifetime.
        unsafe { llama_cpp_sys_2::llama_get_state_size(self.context.as_ptr()) }
    }

    /// Copies the state to the specified destination address.
    ///
    /// Returns the number of bytes copied
    ///
    /// # Safety
    ///
    /// Destination needs to have allocated enough memory (at least
    /// [`get_state_size`](Self::get_state_size) bytes).
    pub unsafe fn copy_state_data(&self, dest: *mut u8) -> usize {
        // SAFETY: the caller guarantees (per `# Safety`) that `dest` has room
        // for the full state.
        unsafe { llama_cpp_sys_2::llama_copy_state_data(self.context.as_ptr(), dest) }
    }

    /// Set the state reading from the specified address
    /// Returns the number of bytes read
    ///
    /// # Safety
    ///
    /// help wanted: not entirely sure what the safety requirements are here.
    // NOTE(review): only `src.as_ptr()` is passed — the slice length is never
    // handed to llama.cpp, which decides how many bytes to read from the state
    // data itself. `src` must therefore contain a complete, valid state blob,
    // or llama.cpp may read past the end of the slice.
    pub unsafe fn set_state_data(&mut self, src: &[u8]) -> usize {
        // SAFETY: relies on the caller providing a complete state blob (see note above).
        unsafe { llama_cpp_sys_2::llama_set_state_data(self.context.as_ptr(), src.as_ptr()) }
    }

    /// Get the size of the state for a single sequence with optional flags.
    ///
    /// This is the extended version that supports flags for partial state operations.
    ///
    /// # Parameters
    ///
    /// * `seq_id` - The sequence ID to get the state size for.
    /// * `flags` - Optional flags (e.g., [`LlamaStateSeqFlags::PARTIAL_ONLY`]).
    ///
    /// # Returns
    ///
    /// The size in bytes needed to store the sequence state.
    #[must_use]
    pub fn state_seq_get_size_ext(&self, seq_id: i32, flags: LlamaStateSeqFlags) -> usize {
        // SAFETY: `self.context` is a valid, live context pointer for `&self`'s lifetime.
        unsafe {
            llama_cpp_sys_2::llama_state_seq_get_size_ext(self.context.as_ptr(), seq_id, flags.0)
        }
    }

    /// Copy the state of a single sequence into the specified buffer with optional flags.
    ///
    /// This is the extended version that supports flags for partial state operations.
    ///
    /// # Parameters
    ///
    /// * `dest` - Destination buffer to copy state into.
    /// * `seq_id` - The sequence ID to get the state for.
    /// * `flags` - Optional flags (e.g., [`LlamaStateSeqFlags::PARTIAL_ONLY`]).
    ///
    /// # Safety
    ///
    /// Destination needs to have allocated enough memory (at least
    /// [`state_seq_get_size_ext`](Self::state_seq_get_size_ext) bytes for the
    /// same `seq_id` and `flags`).
    ///
    /// # Returns
    ///
    /// The number of bytes copied.
    pub unsafe fn state_seq_get_data_ext(
        &self,
        dest: *mut u8,
        seq_id: i32,
        flags: LlamaStateSeqFlags,
    ) -> usize {
        // `usize::MAX` is passed as the destination size: no bound is enforced
        // here because the caller guarantees (per `# Safety`) that `dest` is
        // large enough for the full sequence state.
        unsafe {
            llama_cpp_sys_2::llama_state_seq_get_data_ext(
                self.context.as_ptr(),
                dest,
                usize::MAX,
                seq_id,
                flags.0,
            )
        }
    }

    /// Set the state for a single sequence from the specified buffer with optional flags.
    ///
    /// This is the extended version that supports flags for partial state operations.
    /// Useful for restoring only the recurrent/partial state without affecting the KV cache.
    ///
    /// # Parameters
    ///
    /// * `src` - Source buffer containing the state data.
    /// * `dest_seq_id` - The destination sequence ID to load the state into.
    /// * `flags` - Optional flags (e.g., [`LlamaStateSeqFlags::PARTIAL_ONLY`]).
    ///
    /// # Safety
    ///
    /// The source buffer must contain valid state data.
    ///
    /// # Returns
    ///
    /// `true` on success, `false` on failure (the underlying C API returns the
    /// number of bytes read, with 0 indicating failure).
    pub unsafe fn state_seq_set_data_ext(
        &mut self,
        src: &[u8],
        dest_seq_id: i32,
        flags: LlamaStateSeqFlags,
    ) -> bool {
        // SAFETY: `src` is a valid slice and its length is passed, bounding the
        // read; the caller guarantees the bytes are valid state data.
        unsafe {
            llama_cpp_sys_2::llama_state_seq_set_data_ext(
                self.context.as_ptr(),
                src.as_ptr(),
                src.len(),
                dest_seq_id,
                flags.0,
            ) > 0
        }
    }
}