// llama_cpp_2/context/session.rs
1//! utilities for working with session files
2
3use crate::context::LlamaContext;
4use crate::token::LlamaToken;
5use std::ffi::{CString, NulError};
6use std::path::{Path, PathBuf};
7
8/// Failed to save a sequence state file
9#[derive(Debug, Eq, PartialEq, thiserror::Error)]
10pub enum SaveSeqStateError {
11    /// llama.cpp failed to save the sequence state file
12    #[error("Failed to save sequence state file")]
13    FailedToSave,
14
15    /// null byte in string
16    #[error("null byte in string {0}")]
17    NullError(#[from] NulError),
18
19    /// failed to convert path to str
20    #[error("failed to convert path {0} to str")]
21    PathToStrError(PathBuf),
22}
23
24/// Failed to load a sequence state file
25#[derive(Debug, Eq, PartialEq, thiserror::Error)]
26pub enum LoadSeqStateError {
27    /// llama.cpp failed to load the sequence state file
28    #[error("Failed to load sequence state file")]
29    FailedToLoad,
30
31    /// null byte in string
32    #[error("null byte in string {0}")]
33    NullError(#[from] NulError),
34
35    /// failed to convert path to str
36    #[error("failed to convert path {0} to str")]
37    PathToStrError(PathBuf),
38
39    /// Insufficient max length
40    #[error("max_length is not large enough to hold {n_out} (was {max_tokens})")]
41    InsufficientMaxLength {
42        /// The length of the loaded sequence
43        n_out: usize,
44        /// The maximum length
45        max_tokens: usize,
46    },
47}
48
49/// Failed to save a Session file
50#[derive(Debug, Eq, PartialEq, thiserror::Error)]
51pub enum SaveSessionError {
52    /// llama.cpp failed to save the session file
53    #[error("Failed to save session file")]
54    FailedToSave,
55
56    /// null byte in string
57    #[error("null byte in string {0}")]
58    NullError(#[from] NulError),
59
60    /// failed to convert path to str
61    #[error("failed to convert path {0} to str")]
62    PathToStrError(PathBuf),
63}
64
65/// Failed to load a Session file
66#[derive(Debug, Eq, PartialEq, thiserror::Error)]
67pub enum LoadSessionError {
68    /// llama.cpp failed to load the session file
69    #[error("Failed to load session file")]
70    FailedToLoad,
71
72    /// null byte in string
73    #[error("null byte in string {0}")]
74    NullError(#[from] NulError),
75
76    /// failed to convert path to str
77    #[error("failed to convert path {0} to str")]
78    PathToStrError(PathBuf),
79
80    /// Insufficient max length
81    #[error("max_length is not large enough to hold {n_out} (was {max_tokens})")]
82    InsufficientMaxLength {
83        /// The length of the session file
84        n_out: usize,
85        /// The maximum length
86        max_tokens: usize,
87    },
88}
89
90impl LlamaContext<'_> {
91    /// Save the current session to a file.
92    ///
93    /// # Parameters
94    ///
95    /// * `path_session` - The file to save to.
96    /// * `tokens` - The tokens to associate the session with. This should be a prefix of a sequence of tokens that the context has processed, so that the relevant KV caches are already filled.
97    ///
98    /// # Errors
99    ///
100    /// Fails if the path is not a valid utf8, is not a valid c string, or llama.cpp fails to save the session file.
101    #[deprecated(since = "0.1.136", note = "Use `state_save_file` instead")]
102    pub fn save_session_file(
103        &self,
104        path_session: impl AsRef<Path>,
105        tokens: &[LlamaToken],
106    ) -> Result<(), SaveSessionError> {
107        let path = path_session.as_ref();
108        let path = path
109            .to_str()
110            .ok_or_else(|| SaveSessionError::PathToStrError(path.to_path_buf()))?;
111
112        let cstr = CString::new(path)?;
113
114        if unsafe {
115            llama_cpp_sys_2::llama_save_session_file(
116                self.context.as_ptr(),
117                cstr.as_ptr(),
118                tokens.as_ptr().cast::<llama_cpp_sys_2::llama_token>(),
119                tokens.len(),
120            )
121        } {
122            Ok(())
123        } else {
124            Err(SaveSessionError::FailedToSave)
125        }
126    }
127    /// Load a session file into the current context.
128    ///
129    /// You still need to pass the returned tokens to the context for inference to work. What this function buys you is that the KV caches are already filled with the relevant data.
130    ///
131    /// # Parameters
132    ///
133    /// * `path_session` - The file to load from. It must be a session file from a compatible context, otherwise the function will error.
134    /// * `max_tokens` - The maximum token length of the loaded session. If the session was saved with a longer length, the function will error.
135    ///
136    /// # Errors
137    ///
138    /// Fails if the path is not a valid utf8, is not a valid c string, or llama.cpp fails to load the session file. (e.g. the file does not exist, is not a session file, etc.)
139    #[deprecated(since = "0.1.136", note = "Use `state_load_file` instead")]
140    pub fn load_session_file(
141        &mut self,
142        path_session: impl AsRef<Path>,
143        max_tokens: usize,
144    ) -> Result<Vec<LlamaToken>, LoadSessionError> {
145        let path = path_session.as_ref();
146        let path = path
147            .to_str()
148            .ok_or(LoadSessionError::PathToStrError(path.to_path_buf()))?;
149
150        let cstr = CString::new(path)?;
151        let mut tokens: Vec<LlamaToken> = Vec::with_capacity(max_tokens);
152        let mut n_out = 0;
153
154        // SAFETY: cast is valid as LlamaToken is repr(transparent)
155        let tokens_out = tokens.as_mut_ptr().cast::<llama_cpp_sys_2::llama_token>();
156
157        let load_session_success = unsafe {
158            llama_cpp_sys_2::llama_load_session_file(
159                self.context.as_ptr(),
160                cstr.as_ptr(),
161                tokens_out,
162                max_tokens,
163                &mut n_out,
164            )
165        };
166        if load_session_success {
167            if n_out > max_tokens {
168                return Err(LoadSessionError::InsufficientMaxLength { n_out, max_tokens });
169            }
170            // SAFETY: we checked that n_out <= max_tokens and llama.cpp promises that n_out tokens will be written
171            unsafe {
172                tokens.set_len(n_out);
173            }
174            Ok(tokens)
175        } else {
176            Err(LoadSessionError::FailedToLoad)
177        }
178    }
179
180    /// Save the full state to a file.
181    ///
182    /// This is the non-deprecated replacement for [`save_session_file`](Self::save_session_file).
183    ///
184    /// # Parameters
185    ///
186    /// * `path_session` - The file to save to.
187    /// * `tokens` - The tokens to associate the state with. This should be a prefix of a sequence
188    ///   of tokens that the context has processed, so that the relevant KV caches are already filled.
189    ///
190    /// # Errors
191    ///
192    /// Fails if the path is not a valid utf8, is not a valid c string, or llama.cpp fails to save
193    /// the state file.
194    pub fn state_save_file(
195        &self,
196        path_session: impl AsRef<Path>,
197        tokens: &[LlamaToken],
198    ) -> Result<(), SaveSessionError> {
199        let path = path_session.as_ref();
200        let path = path
201            .to_str()
202            .ok_or_else(|| SaveSessionError::PathToStrError(path.to_path_buf()))?;
203
204        let cstr = CString::new(path)?;
205
206        if unsafe {
207            llama_cpp_sys_2::llama_state_save_file(
208                self.context.as_ptr(),
209                cstr.as_ptr(),
210                tokens.as_ptr().cast::<llama_cpp_sys_2::llama_token>(),
211                tokens.len(),
212            )
213        } {
214            Ok(())
215        } else {
216            Err(SaveSessionError::FailedToSave)
217        }
218    }
219
220    /// Load a state file into the current context.
221    ///
222    /// This is the non-deprecated replacement for [`load_session_file`](Self::load_session_file).
223    ///
224    /// You still need to pass the returned tokens to the context for inference to work. What this
225    /// function buys you is that the KV caches are already filled with the relevant data.
226    ///
227    /// # Parameters
228    ///
229    /// * `path_session` - The file to load from. It must be a state file from a compatible context,
230    ///   otherwise the function will error.
231    /// * `max_tokens` - The maximum token length of the loaded state. If the state was saved with a
232    ///   longer length, the function will error.
233    ///
234    /// # Errors
235    ///
236    /// Fails if the path is not a valid utf8, is not a valid c string, or llama.cpp fails to load
237    /// the state file.
238    pub fn state_load_file(
239        &mut self,
240        path_session: impl AsRef<Path>,
241        max_tokens: usize,
242    ) -> Result<Vec<LlamaToken>, LoadSessionError> {
243        let path = path_session.as_ref();
244        let path = path
245            .to_str()
246            .ok_or(LoadSessionError::PathToStrError(path.to_path_buf()))?;
247
248        let cstr = CString::new(path)?;
249        let mut tokens: Vec<LlamaToken> = Vec::with_capacity(max_tokens);
250        let mut n_out = 0;
251
252        // SAFETY: cast is valid as LlamaToken is repr(transparent)
253        let tokens_out = tokens.as_mut_ptr().cast::<llama_cpp_sys_2::llama_token>();
254
255        let success = unsafe {
256            llama_cpp_sys_2::llama_state_load_file(
257                self.context.as_ptr(),
258                cstr.as_ptr(),
259                tokens_out,
260                max_tokens,
261                &mut n_out,
262            )
263        };
264        if success {
265            if n_out > max_tokens {
266                return Err(LoadSessionError::InsufficientMaxLength { n_out, max_tokens });
267            }
268            // SAFETY: we checked that n_out <= max_tokens and llama.cpp promises that n_out tokens will be written
269            unsafe {
270                tokens.set_len(n_out);
271            }
272            Ok(tokens)
273        } else {
274            Err(LoadSessionError::FailedToLoad)
275        }
276    }
277
278    /// Save state for a single sequence to a file.
279    ///
280    /// This enables saving state for individual sequences, which is useful for multi-sequence
281    /// inference scenarios.
282    ///
283    /// # Parameters
284    ///
285    /// * `filepath` - The file to save to.
286    /// * `seq_id` - The sequence ID whose state to save.
287    /// * `tokens` - The tokens to associate with the saved state.
288    ///
289    /// # Errors
290    ///
291    /// Fails if the path is not a valid utf8, is not a valid c string, or llama.cpp fails to save
292    /// the sequence state file.
293    ///
294    /// # Returns
295    ///
296    /// The number of bytes written on success.
297    pub fn state_seq_save_file(
298        &self,
299        filepath: impl AsRef<Path>,
300        seq_id: i32,
301        tokens: &[LlamaToken],
302    ) -> Result<usize, SaveSeqStateError> {
303        let path = filepath.as_ref();
304        let path = path
305            .to_str()
306            .ok_or_else(|| SaveSeqStateError::PathToStrError(path.to_path_buf()))?;
307
308        let cstr = CString::new(path)?;
309
310        let bytes_written = unsafe {
311            llama_cpp_sys_2::llama_state_seq_save_file(
312                self.context.as_ptr(),
313                cstr.as_ptr(),
314                seq_id,
315                tokens.as_ptr().cast::<llama_cpp_sys_2::llama_token>(),
316                tokens.len(),
317            )
318        };
319
320        if bytes_written == 0 {
321            Err(SaveSeqStateError::FailedToSave)
322        } else {
323            Ok(bytes_written)
324        }
325    }
326
327    /// Load state for a single sequence from a file.
328    ///
329    /// This enables loading state for individual sequences, which is useful for multi-sequence
330    /// inference scenarios.
331    ///
332    /// # Parameters
333    ///
334    /// * `filepath` - The file to load from.
335    /// * `dest_seq_id` - The destination sequence ID to load the state into.
336    /// * `max_tokens` - The maximum number of tokens to read.
337    ///
338    /// # Errors
339    ///
340    /// Fails if the path is not a valid utf8, is not a valid c string, or llama.cpp fails to load
341    /// the sequence state file.
342    ///
343    /// # Returns
344    ///
345    /// A tuple of `(tokens, bytes_read)` on success.
346    pub fn state_seq_load_file(
347        &mut self,
348        filepath: impl AsRef<Path>,
349        dest_seq_id: i32,
350        max_tokens: usize,
351    ) -> Result<(Vec<LlamaToken>, usize), LoadSeqStateError> {
352        let path = filepath.as_ref();
353        let path = path
354            .to_str()
355            .ok_or(LoadSeqStateError::PathToStrError(path.to_path_buf()))?;
356
357        let cstr = CString::new(path)?;
358        let mut tokens: Vec<LlamaToken> = Vec::with_capacity(max_tokens);
359        let mut n_out = 0;
360
361        // SAFETY: cast is valid as LlamaToken is repr(transparent)
362        let tokens_out = tokens.as_mut_ptr().cast::<llama_cpp_sys_2::llama_token>();
363
364        let bytes_read = unsafe {
365            llama_cpp_sys_2::llama_state_seq_load_file(
366                self.context.as_ptr(),
367                cstr.as_ptr(),
368                dest_seq_id,
369                tokens_out,
370                max_tokens,
371                &mut n_out,
372            )
373        };
374
375        if bytes_read == 0 {
376            return Err(LoadSeqStateError::FailedToLoad);
377        }
378
379        if n_out > max_tokens {
380            return Err(LoadSeqStateError::InsufficientMaxLength { n_out, max_tokens });
381        }
382
383        // SAFETY: we checked that n_out <= max_tokens and llama.cpp promises that n_out tokens will be written
384        unsafe {
385            tokens.set_len(n_out);
386        }
387
388        Ok((tokens, bytes_read))
389    }
390
391    /// Returns the maximum size in bytes of the state (rng, logits, embedding
392    /// and `kv_cache`) - will often be smaller after compacting tokens
393    #[must_use]
394    pub fn get_state_size(&self) -> usize {
395        unsafe { llama_cpp_sys_2::llama_get_state_size(self.context.as_ptr()) }
396    }
397
398    /// Copies the state to the specified destination address.
399    ///
400    /// Returns the number of bytes copied
401    ///
402    /// # Safety
403    ///
404    /// Destination needs to have allocated enough memory.
405    pub unsafe fn copy_state_data(&self, dest: *mut u8) -> usize {
406        unsafe { llama_cpp_sys_2::llama_copy_state_data(self.context.as_ptr(), dest) }
407    }
408
409    /// Set the state reading from the specified address
410    /// Returns the number of bytes read
411    ///
412    /// # Safety
413    ///
414    /// help wanted: not entirely sure what the safety requirements are here.
415    pub unsafe fn set_state_data(&mut self, src: &[u8]) -> usize {
416        unsafe { llama_cpp_sys_2::llama_set_state_data(self.context.as_ptr(), src.as_ptr()) }
417    }
418}