// llama_cpp_bindings/context/session.rs

1//! utilities for working with session files
2
3use crate::context::LlamaContext;
4use crate::token::LlamaToken;
5use std::ffi::{CString, NulError};
6use std::path::{Path, PathBuf};
7
8/// Failed to save a sequence state file
9#[derive(Debug, Eq, PartialEq, thiserror::Error)]
10pub enum SaveSeqStateError {
11    /// llama.cpp failed to save the sequence state file
12    #[error("Failed to save sequence state file")]
13    FailedToSave,
14
15    /// null byte in string
16    #[error("null byte in string {0}")]
17    NullError(#[from] NulError),
18
19    /// failed to convert path to str
20    #[error("failed to convert path {0} to str")]
21    PathToStrError(PathBuf),
22}
23
24/// Failed to load a sequence state file
25#[derive(Debug, Eq, PartialEq, thiserror::Error)]
26pub enum LoadSeqStateError {
27    /// llama.cpp failed to load the sequence state file
28    #[error("Failed to load sequence state file")]
29    FailedToLoad,
30
31    /// null byte in string
32    #[error("null byte in string {0}")]
33    NullError(#[from] NulError),
34
35    /// failed to convert path to str
36    #[error("failed to convert path {0} to str")]
37    PathToStrError(PathBuf),
38
39    /// Insufficient max length
40    #[error("max_length is not large enough to hold {n_out} (was {max_tokens})")]
41    InsufficientMaxLength {
42        /// The length of the loaded sequence
43        n_out: usize,
44        /// The maximum length
45        max_tokens: usize,
46    },
47}
48
49/// Failed to save a Session file
50#[derive(Debug, Eq, PartialEq, thiserror::Error)]
51pub enum SaveSessionError {
52    /// llama.cpp failed to save the session file
53    #[error("Failed to save session file")]
54    FailedToSave,
55
56    /// null byte in string
57    #[error("null byte in string {0}")]
58    NullError(#[from] NulError),
59
60    /// failed to convert path to str
61    #[error("failed to convert path {0} to str")]
62    PathToStrError(PathBuf),
63}
64
65/// Failed to load a Session file
66#[derive(Debug, Eq, PartialEq, thiserror::Error)]
67pub enum LoadSessionError {
68    /// llama.cpp failed to load the session file
69    #[error("Failed to load session file")]
70    FailedToLoad,
71
72    /// null byte in string
73    #[error("null byte in string {0}")]
74    NullError(#[from] NulError),
75
76    /// failed to convert path to str
77    #[error("failed to convert path {0} to str")]
78    PathToStrError(PathBuf),
79
80    /// Insufficient max length
81    #[error("max_length is not large enough to hold {n_out} (was {max_tokens})")]
82    InsufficientMaxLength {
83        /// The length of the session file
84        n_out: usize,
85        /// The maximum length
86        max_tokens: usize,
87    },
88}
89
90impl LlamaContext<'_> {
91    /// Save the full state to a file.
92    ///
93    /// # Parameters
94    ///
95    /// * `path_session` - The file to save to.
96    /// * `tokens` - The tokens to associate the state with. This should be a prefix of a sequence
97    ///   of tokens that the context has processed, so that the relevant KV caches are already filled.
98    ///
99    /// # Errors
100    ///
101    /// Fails if the path is not a valid utf8, is not a valid c string, or llama.cpp fails to save
102    /// the state file.
103    pub fn state_save_file(
104        &self,
105        path_session: impl AsRef<Path>,
106        tokens: &[LlamaToken],
107    ) -> Result<(), SaveSessionError> {
108        let path = path_session.as_ref();
109        let path = path
110            .to_str()
111            .ok_or_else(|| SaveSessionError::PathToStrError(path.to_path_buf()))?;
112
113        let cstr = CString::new(path)?;
114
115        if unsafe {
116            llama_cpp_bindings_sys::llama_state_save_file(
117                self.context.as_ptr(),
118                cstr.as_ptr(),
119                tokens
120                    .as_ptr()
121                    .cast::<llama_cpp_bindings_sys::llama_token>(),
122                tokens.len(),
123            )
124        } {
125            Ok(())
126        } else {
127            Err(SaveSessionError::FailedToSave)
128        }
129    }
130
131    /// Load a state file into the current context.
132    ///
133    /// You still need to pass the returned tokens to the context for inference to work. What this
134    /// function buys you is that the KV caches are already filled with the relevant data.
135    ///
136    /// # Parameters
137    ///
138    /// * `path_session` - The file to load from. It must be a state file from a compatible context,
139    ///   otherwise the function will error.
140    /// * `max_tokens` - The maximum token length of the loaded state. If the state was saved with a
141    ///   longer length, the function will error.
142    ///
143    /// # Errors
144    ///
145    /// Fails if the path is not a valid utf8, is not a valid c string, or llama.cpp fails to load
146    /// the state file.
147    pub fn state_load_file(
148        &mut self,
149        path_session: impl AsRef<Path>,
150        max_tokens: usize,
151    ) -> Result<Vec<LlamaToken>, LoadSessionError> {
152        let path = path_session.as_ref();
153        let path = path
154            .to_str()
155            .ok_or(LoadSessionError::PathToStrError(path.to_path_buf()))?;
156
157        let cstr = CString::new(path)?;
158        let mut tokens: Vec<LlamaToken> = Vec::with_capacity(max_tokens);
159        let mut n_out = 0;
160
161        // SAFETY: cast is valid as LlamaToken is repr(transparent)
162        let tokens_out = tokens
163            .as_mut_ptr()
164            .cast::<llama_cpp_bindings_sys::llama_token>();
165
166        let success = unsafe {
167            llama_cpp_bindings_sys::llama_state_load_file(
168                self.context.as_ptr(),
169                cstr.as_ptr(),
170                tokens_out,
171                max_tokens,
172                &raw mut n_out,
173            )
174        };
175        if success {
176            if n_out > max_tokens {
177                return Err(LoadSessionError::InsufficientMaxLength { n_out, max_tokens });
178            }
179            // SAFETY: we checked that n_out <= max_tokens and llama.cpp promises that n_out tokens will be written
180            unsafe {
181                tokens.set_len(n_out);
182            }
183            Ok(tokens)
184        } else {
185            Err(LoadSessionError::FailedToLoad)
186        }
187    }
188
189    /// Save state for a single sequence to a file.
190    ///
191    /// This enables saving state for individual sequences, which is useful for multi-sequence
192    /// inference scenarios.
193    ///
194    /// # Parameters
195    ///
196    /// * `filepath` - The file to save to.
197    /// * `seq_id` - The sequence ID whose state to save.
198    /// * `tokens` - The tokens to associate with the saved state.
199    ///
200    /// # Errors
201    ///
202    /// Fails if the path is not a valid utf8, is not a valid c string, or llama.cpp fails to save
203    /// the sequence state file.
204    ///
205    /// # Returns
206    ///
207    /// The number of bytes written on success.
208    pub fn state_seq_save_file(
209        &self,
210        filepath: impl AsRef<Path>,
211        seq_id: i32,
212        tokens: &[LlamaToken],
213    ) -> Result<usize, SaveSeqStateError> {
214        let path = filepath.as_ref();
215        let path = path
216            .to_str()
217            .ok_or_else(|| SaveSeqStateError::PathToStrError(path.to_path_buf()))?;
218
219        let cstr = CString::new(path)?;
220
221        let bytes_written = unsafe {
222            llama_cpp_bindings_sys::llama_state_seq_save_file(
223                self.context.as_ptr(),
224                cstr.as_ptr(),
225                seq_id,
226                tokens
227                    .as_ptr()
228                    .cast::<llama_cpp_bindings_sys::llama_token>(),
229                tokens.len(),
230            )
231        };
232
233        if bytes_written == 0 {
234            Err(SaveSeqStateError::FailedToSave)
235        } else {
236            Ok(bytes_written)
237        }
238    }
239
240    /// Load state for a single sequence from a file.
241    ///
242    /// This enables loading state for individual sequences, which is useful for multi-sequence
243    /// inference scenarios.
244    ///
245    /// # Parameters
246    ///
247    /// * `filepath` - The file to load from.
248    /// * `dest_seq_id` - The destination sequence ID to load the state into.
249    /// * `max_tokens` - The maximum number of tokens to read.
250    ///
251    /// # Errors
252    ///
253    /// Fails if the path is not a valid utf8, is not a valid c string, or llama.cpp fails to load
254    /// the sequence state file.
255    ///
256    /// # Returns
257    ///
258    /// A tuple of `(tokens, bytes_read)` on success.
259    pub fn state_seq_load_file(
260        &mut self,
261        filepath: impl AsRef<Path>,
262        dest_seq_id: i32,
263        max_tokens: usize,
264    ) -> Result<(Vec<LlamaToken>, usize), LoadSeqStateError> {
265        let path = filepath.as_ref();
266        let path = path
267            .to_str()
268            .ok_or(LoadSeqStateError::PathToStrError(path.to_path_buf()))?;
269
270        let cstr = CString::new(path)?;
271        let mut tokens: Vec<LlamaToken> = Vec::with_capacity(max_tokens);
272        let mut n_out = 0;
273
274        // SAFETY: cast is valid as LlamaToken is repr(transparent)
275        let tokens_out = tokens
276            .as_mut_ptr()
277            .cast::<llama_cpp_bindings_sys::llama_token>();
278
279        let bytes_read = unsafe {
280            llama_cpp_bindings_sys::llama_state_seq_load_file(
281                self.context.as_ptr(),
282                cstr.as_ptr(),
283                dest_seq_id,
284                tokens_out,
285                max_tokens,
286                &raw mut n_out,
287            )
288        };
289
290        if bytes_read == 0 {
291            return Err(LoadSeqStateError::FailedToLoad);
292        }
293
294        if n_out > max_tokens {
295            return Err(LoadSeqStateError::InsufficientMaxLength { n_out, max_tokens });
296        }
297
298        // SAFETY: we checked that n_out <= max_tokens and llama.cpp promises that n_out tokens will be written
299        unsafe {
300            tokens.set_len(n_out);
301        }
302
303        Ok((tokens, bytes_read))
304    }
305
306    /// Returns the maximum size in bytes of the state (rng, logits, embedding
307    /// and `kv_cache`) - will often be smaller after compacting tokens
308    #[must_use]
309    pub fn get_state_size(&self) -> usize {
310        unsafe { llama_cpp_bindings_sys::llama_get_state_size(self.context.as_ptr()) }
311    }
312
313    /// Copies the state to the specified destination address.
314    ///
315    /// Returns the number of bytes copied
316    ///
317    /// # Safety
318    ///
319    /// Destination needs to have allocated enough memory.
320    pub unsafe fn copy_state_data(&self, dest: *mut u8) -> usize {
321        unsafe { llama_cpp_bindings_sys::llama_copy_state_data(self.context.as_ptr(), dest) }
322    }
323
324    /// Set the state reading from the specified address.
325    /// Returns the number of bytes read.
326    ///
327    /// # Safety
328    ///
329    /// The `src` buffer must contain data previously obtained from [`copy_state_data`](Self::copy_state_data)
330    /// on a compatible context (same model and parameters). Passing arbitrary or corrupted bytes
331    /// will lead to undefined behavior.
332    pub unsafe fn set_state_data(&mut self, src: &[u8]) -> usize {
333        unsafe { llama_cpp_bindings_sys::llama_set_state_data(self.context.as_ptr(), src.as_ptr()) }
334    }
335}