llama_cpp_bindings/context/session.rs
//! Utilities for working with llama.cpp session (state) files.
2
3use crate::context::LlamaContext;
4use crate::token::LlamaToken;
5use std::ffi::{CString, NulError};
6use std::path::{Path, PathBuf};
7
/// Failed to save a sequence state file.
///
/// Returned by [`LlamaContext::state_seq_save_file`].
#[derive(Debug, Eq, PartialEq, thiserror::Error)]
pub enum SaveSeqStateError {
    /// llama.cpp failed to save the sequence state file (the FFI call reported 0 bytes written)
    #[error("Failed to save sequence state file")]
    FailedToSave,

    /// the file path contained an interior null byte and could not be converted to a C string
    #[error("null byte in string {0}")]
    NullError(#[from] NulError),

    /// the file path is not valid UTF-8 and could not be converted to `&str`
    #[error("failed to convert path {0} to str")]
    PathToStrError(PathBuf),
}
23
/// Failed to load a sequence state file.
///
/// Returned by [`LlamaContext::state_seq_load_file`].
#[derive(Debug, Eq, PartialEq, thiserror::Error)]
pub enum LoadSeqStateError {
    /// llama.cpp failed to load the sequence state file (the FFI call reported 0 bytes read)
    #[error("Failed to load sequence state file")]
    FailedToLoad,

    /// the file path contained an interior null byte and could not be converted to a C string
    #[error("null byte in string {0}")]
    NullError(#[from] NulError),

    /// the file path is not valid UTF-8 and could not be converted to `&str`
    #[error("failed to convert path {0} to str")]
    PathToStrError(PathBuf),

    /// The loaded sequence contains more tokens than the caller-supplied `max_tokens` allows
    #[error("max_length is not large enough to hold {n_out} (was {max_tokens})")]
    InsufficientMaxLength {
        /// The length of the loaded sequence
        n_out: usize,
        /// The maximum length the caller allowed for the token buffer
        max_tokens: usize,
    },
}
48
/// Failed to save a Session file.
///
/// Returned by [`LlamaContext::state_save_file`].
#[derive(Debug, Eq, PartialEq, thiserror::Error)]
pub enum SaveSessionError {
    /// llama.cpp failed to save the session file (the FFI call returned `false`)
    #[error("Failed to save session file")]
    FailedToSave,

    /// the file path contained an interior null byte and could not be converted to a C string
    #[error("null byte in string {0}")]
    NullError(#[from] NulError),

    /// the file path is not valid UTF-8 and could not be converted to `&str`
    #[error("failed to convert path {0} to str")]
    PathToStrError(PathBuf),
}
64
/// Failed to load a Session file.
///
/// Returned by [`LlamaContext::state_load_file`].
#[derive(Debug, Eq, PartialEq, thiserror::Error)]
pub enum LoadSessionError {
    /// llama.cpp failed to load the session file (the FFI call returned `false`)
    #[error("Failed to load session file")]
    FailedToLoad,

    /// the file path contained an interior null byte and could not be converted to a C string
    #[error("null byte in string {0}")]
    NullError(#[from] NulError),

    /// the file path is not valid UTF-8 and could not be converted to `&str`
    #[error("failed to convert path {0} to str")]
    PathToStrError(PathBuf),

    /// The session file contains more tokens than the caller-supplied `max_tokens` allows
    #[error("max_length is not large enough to hold {n_out} (was {max_tokens})")]
    InsufficientMaxLength {
        /// The length of the session file
        n_out: usize,
        /// The maximum length the caller allowed for the token buffer
        max_tokens: usize,
    },
}
89
90impl LlamaContext<'_> {
91 /// Save the full state to a file.
92 ///
93 /// # Parameters
94 ///
95 /// * `path_session` - The file to save to.
96 /// * `tokens` - The tokens to associate the state with. This should be a prefix of a sequence
97 /// of tokens that the context has processed, so that the relevant KV caches are already filled.
98 ///
99 /// # Errors
100 ///
101 /// Fails if the path is not a valid utf8, is not a valid c string, or llama.cpp fails to save
102 /// the state file.
103 pub fn state_save_file(
104 &self,
105 path_session: impl AsRef<Path>,
106 tokens: &[LlamaToken],
107 ) -> Result<(), SaveSessionError> {
108 let path = path_session.as_ref();
109 let path = path
110 .to_str()
111 .ok_or_else(|| SaveSessionError::PathToStrError(path.to_path_buf()))?;
112
113 let cstr = CString::new(path)?;
114
115 if unsafe {
116 llama_cpp_bindings_sys::llama_state_save_file(
117 self.context.as_ptr(),
118 cstr.as_ptr(),
119 tokens
120 .as_ptr()
121 .cast::<llama_cpp_bindings_sys::llama_token>(),
122 tokens.len(),
123 )
124 } {
125 Ok(())
126 } else {
127 Err(SaveSessionError::FailedToSave)
128 }
129 }
130
131 /// Load a state file into the current context.
132 ///
133 /// You still need to pass the returned tokens to the context for inference to work. What this
134 /// function buys you is that the KV caches are already filled with the relevant data.
135 ///
136 /// # Parameters
137 ///
138 /// * `path_session` - The file to load from. It must be a state file from a compatible context,
139 /// otherwise the function will error.
140 /// * `max_tokens` - The maximum token length of the loaded state. If the state was saved with a
141 /// longer length, the function will error.
142 ///
143 /// # Errors
144 ///
145 /// Fails if the path is not a valid utf8, is not a valid c string, or llama.cpp fails to load
146 /// the state file.
147 pub fn state_load_file(
148 &mut self,
149 path_session: impl AsRef<Path>,
150 max_tokens: usize,
151 ) -> Result<Vec<LlamaToken>, LoadSessionError> {
152 let path = path_session.as_ref();
153 let path = path
154 .to_str()
155 .ok_or(LoadSessionError::PathToStrError(path.to_path_buf()))?;
156
157 let cstr = CString::new(path)?;
158 let mut tokens: Vec<LlamaToken> = Vec::with_capacity(max_tokens);
159 let mut n_out = 0;
160
161 // SAFETY: cast is valid as LlamaToken is repr(transparent)
162 let tokens_out = tokens
163 .as_mut_ptr()
164 .cast::<llama_cpp_bindings_sys::llama_token>();
165
166 let success = unsafe {
167 llama_cpp_bindings_sys::llama_state_load_file(
168 self.context.as_ptr(),
169 cstr.as_ptr(),
170 tokens_out,
171 max_tokens,
172 &raw mut n_out,
173 )
174 };
175 if success {
176 if n_out > max_tokens {
177 return Err(LoadSessionError::InsufficientMaxLength { n_out, max_tokens });
178 }
179 // SAFETY: we checked that n_out <= max_tokens and llama.cpp promises that n_out tokens will be written
180 unsafe {
181 tokens.set_len(n_out);
182 }
183 Ok(tokens)
184 } else {
185 Err(LoadSessionError::FailedToLoad)
186 }
187 }
188
189 /// Save state for a single sequence to a file.
190 ///
191 /// This enables saving state for individual sequences, which is useful for multi-sequence
192 /// inference scenarios.
193 ///
194 /// # Parameters
195 ///
196 /// * `filepath` - The file to save to.
197 /// * `seq_id` - The sequence ID whose state to save.
198 /// * `tokens` - The tokens to associate with the saved state.
199 ///
200 /// # Errors
201 ///
202 /// Fails if the path is not a valid utf8, is not a valid c string, or llama.cpp fails to save
203 /// the sequence state file.
204 ///
205 /// # Returns
206 ///
207 /// The number of bytes written on success.
208 pub fn state_seq_save_file(
209 &self,
210 filepath: impl AsRef<Path>,
211 seq_id: i32,
212 tokens: &[LlamaToken],
213 ) -> Result<usize, SaveSeqStateError> {
214 let path = filepath.as_ref();
215 let path = path
216 .to_str()
217 .ok_or_else(|| SaveSeqStateError::PathToStrError(path.to_path_buf()))?;
218
219 let cstr = CString::new(path)?;
220
221 let bytes_written = unsafe {
222 llama_cpp_bindings_sys::llama_state_seq_save_file(
223 self.context.as_ptr(),
224 cstr.as_ptr(),
225 seq_id,
226 tokens
227 .as_ptr()
228 .cast::<llama_cpp_bindings_sys::llama_token>(),
229 tokens.len(),
230 )
231 };
232
233 if bytes_written == 0 {
234 Err(SaveSeqStateError::FailedToSave)
235 } else {
236 Ok(bytes_written)
237 }
238 }
239
240 /// Load state for a single sequence from a file.
241 ///
242 /// This enables loading state for individual sequences, which is useful for multi-sequence
243 /// inference scenarios.
244 ///
245 /// # Parameters
246 ///
247 /// * `filepath` - The file to load from.
248 /// * `dest_seq_id` - The destination sequence ID to load the state into.
249 /// * `max_tokens` - The maximum number of tokens to read.
250 ///
251 /// # Errors
252 ///
253 /// Fails if the path is not a valid utf8, is not a valid c string, or llama.cpp fails to load
254 /// the sequence state file.
255 ///
256 /// # Returns
257 ///
258 /// A tuple of `(tokens, bytes_read)` on success.
259 pub fn state_seq_load_file(
260 &mut self,
261 filepath: impl AsRef<Path>,
262 dest_seq_id: i32,
263 max_tokens: usize,
264 ) -> Result<(Vec<LlamaToken>, usize), LoadSeqStateError> {
265 let path = filepath.as_ref();
266 let path = path
267 .to_str()
268 .ok_or(LoadSeqStateError::PathToStrError(path.to_path_buf()))?;
269
270 let cstr = CString::new(path)?;
271 let mut tokens: Vec<LlamaToken> = Vec::with_capacity(max_tokens);
272 let mut n_out = 0;
273
274 // SAFETY: cast is valid as LlamaToken is repr(transparent)
275 let tokens_out = tokens
276 .as_mut_ptr()
277 .cast::<llama_cpp_bindings_sys::llama_token>();
278
279 let bytes_read = unsafe {
280 llama_cpp_bindings_sys::llama_state_seq_load_file(
281 self.context.as_ptr(),
282 cstr.as_ptr(),
283 dest_seq_id,
284 tokens_out,
285 max_tokens,
286 &raw mut n_out,
287 )
288 };
289
290 if bytes_read == 0 {
291 return Err(LoadSeqStateError::FailedToLoad);
292 }
293
294 if n_out > max_tokens {
295 return Err(LoadSeqStateError::InsufficientMaxLength { n_out, max_tokens });
296 }
297
298 // SAFETY: we checked that n_out <= max_tokens and llama.cpp promises that n_out tokens will be written
299 unsafe {
300 tokens.set_len(n_out);
301 }
302
303 Ok((tokens, bytes_read))
304 }
305
306 /// Returns the maximum size in bytes of the state (rng, logits, embedding
307 /// and `kv_cache`) - will often be smaller after compacting tokens
308 #[must_use]
309 pub fn get_state_size(&self) -> usize {
310 unsafe { llama_cpp_bindings_sys::llama_get_state_size(self.context.as_ptr()) }
311 }
312
313 /// Copies the state to the specified destination address.
314 ///
315 /// Returns the number of bytes copied
316 ///
317 /// # Safety
318 ///
319 /// Destination needs to have allocated enough memory.
320 pub unsafe fn copy_state_data(&self, dest: *mut u8) -> usize {
321 unsafe { llama_cpp_bindings_sys::llama_copy_state_data(self.context.as_ptr(), dest) }
322 }
323
324 /// Set the state reading from the specified address.
325 /// Returns the number of bytes read.
326 ///
327 /// # Safety
328 ///
329 /// The `src` buffer must contain data previously obtained from [`copy_state_data`](Self::copy_state_data)
330 /// on a compatible context (same model and parameters). Passing arbitrary or corrupted bytes
331 /// will lead to undefined behavior.
332 pub unsafe fn set_state_data(&mut self, src: &[u8]) -> usize {
333 unsafe { llama_cpp_bindings_sys::llama_set_state_data(self.context.as_ptr(), src.as_ptr()) }
334 }
335}