llama_cpp_2/context/session.rs
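//! Utilities for saving and restoring llama.cpp context state: whole-context session/state files, per-sequence state files, and raw in-memory state copies.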
2
3use crate::context::LlamaContext;
4use crate::token::LlamaToken;
5use std::ffi::{CString, NulError};
6use std::path::{Path, PathBuf};
7
8/// Failed to save a sequence state file
9#[derive(Debug, Eq, PartialEq, thiserror::Error)]
10pub enum SaveSeqStateError {
11 /// llama.cpp failed to save the sequence state file
12 #[error("Failed to save sequence state file")]
13 FailedToSave,
14
15 /// null byte in string
16 #[error("null byte in string {0}")]
17 NullError(#[from] NulError),
18
19 /// failed to convert path to str
20 #[error("failed to convert path {0} to str")]
21 PathToStrError(PathBuf),
22}
23
24/// Failed to load a sequence state file
25#[derive(Debug, Eq, PartialEq, thiserror::Error)]
26pub enum LoadSeqStateError {
27 /// llama.cpp failed to load the sequence state file
28 #[error("Failed to load sequence state file")]
29 FailedToLoad,
30
31 /// null byte in string
32 #[error("null byte in string {0}")]
33 NullError(#[from] NulError),
34
35 /// failed to convert path to str
36 #[error("failed to convert path {0} to str")]
37 PathToStrError(PathBuf),
38
39 /// Insufficient max length
40 #[error("max_length is not large enough to hold {n_out} (was {max_tokens})")]
41 InsufficientMaxLength {
42 /// The length of the loaded sequence
43 n_out: usize,
44 /// The maximum length
45 max_tokens: usize,
46 },
47}
48
49/// Failed to save a Session file
50#[derive(Debug, Eq, PartialEq, thiserror::Error)]
51pub enum SaveSessionError {
52 /// llama.cpp failed to save the session file
53 #[error("Failed to save session file")]
54 FailedToSave,
55
56 /// null byte in string
57 #[error("null byte in string {0}")]
58 NullError(#[from] NulError),
59
60 /// failed to convert path to str
61 #[error("failed to convert path {0} to str")]
62 PathToStrError(PathBuf),
63}
64
65/// Failed to load a Session file
66#[derive(Debug, Eq, PartialEq, thiserror::Error)]
67pub enum LoadSessionError {
68 /// llama.cpp failed to load the session file
69 #[error("Failed to load session file")]
70 FailedToLoad,
71
72 /// null byte in string
73 #[error("null byte in string {0}")]
74 NullError(#[from] NulError),
75
76 /// failed to convert path to str
77 #[error("failed to convert path {0} to str")]
78 PathToStrError(PathBuf),
79
80 /// Insufficient max length
81 #[error("max_length is not large enough to hold {n_out} (was {max_tokens})")]
82 InsufficientMaxLength {
83 /// The length of the session file
84 n_out: usize,
85 /// The maximum length
86 max_tokens: usize,
87 },
88}
89
90impl LlamaContext<'_> {
91 /// Save the current session to a file.
92 ///
93 /// # Parameters
94 ///
95 /// * `path_session` - The file to save to.
96 /// * `tokens` - The tokens to associate the session with. This should be a prefix of a sequence of tokens that the context has processed, so that the relevant KV caches are already filled.
97 ///
98 /// # Errors
99 ///
100 /// Fails if the path is not a valid utf8, is not a valid c string, or llama.cpp fails to save the session file.
101 #[deprecated(since = "0.1.136", note = "Use `state_save_file` instead")]
102 pub fn save_session_file(
103 &self,
104 path_session: impl AsRef<Path>,
105 tokens: &[LlamaToken],
106 ) -> Result<(), SaveSessionError> {
107 let path = path_session.as_ref();
108 let path = path
109 .to_str()
110 .ok_or_else(|| SaveSessionError::PathToStrError(path.to_path_buf()))?;
111
112 let cstr = CString::new(path)?;
113
114 if unsafe {
115 llama_cpp_sys_2::llama_save_session_file(
116 self.context.as_ptr(),
117 cstr.as_ptr(),
118 tokens.as_ptr().cast::<llama_cpp_sys_2::llama_token>(),
119 tokens.len(),
120 )
121 } {
122 Ok(())
123 } else {
124 Err(SaveSessionError::FailedToSave)
125 }
126 }
127 /// Load a session file into the current context.
128 ///
129 /// You still need to pass the returned tokens to the context for inference to work. What this function buys you is that the KV caches are already filled with the relevant data.
130 ///
131 /// # Parameters
132 ///
133 /// * `path_session` - The file to load from. It must be a session file from a compatible context, otherwise the function will error.
134 /// * `max_tokens` - The maximum token length of the loaded session. If the session was saved with a longer length, the function will error.
135 ///
136 /// # Errors
137 ///
138 /// Fails if the path is not a valid utf8, is not a valid c string, or llama.cpp fails to load the session file. (e.g. the file does not exist, is not a session file, etc.)
139 #[deprecated(since = "0.1.136", note = "Use `state_load_file` instead")]
140 pub fn load_session_file(
141 &mut self,
142 path_session: impl AsRef<Path>,
143 max_tokens: usize,
144 ) -> Result<Vec<LlamaToken>, LoadSessionError> {
145 let path = path_session.as_ref();
146 let path = path
147 .to_str()
148 .ok_or(LoadSessionError::PathToStrError(path.to_path_buf()))?;
149
150 let cstr = CString::new(path)?;
151 let mut tokens: Vec<LlamaToken> = Vec::with_capacity(max_tokens);
152 let mut n_out = 0;
153
154 // SAFETY: cast is valid as LlamaToken is repr(transparent)
155 let tokens_out = tokens.as_mut_ptr().cast::<llama_cpp_sys_2::llama_token>();
156
157 let load_session_success = unsafe {
158 llama_cpp_sys_2::llama_load_session_file(
159 self.context.as_ptr(),
160 cstr.as_ptr(),
161 tokens_out,
162 max_tokens,
163 &mut n_out,
164 )
165 };
166 if load_session_success {
167 if n_out > max_tokens {
168 return Err(LoadSessionError::InsufficientMaxLength { n_out, max_tokens });
169 }
170 // SAFETY: we checked that n_out <= max_tokens and llama.cpp promises that n_out tokens will be written
171 unsafe {
172 tokens.set_len(n_out);
173 }
174 Ok(tokens)
175 } else {
176 Err(LoadSessionError::FailedToLoad)
177 }
178 }
179
180 /// Save the full state to a file.
181 ///
182 /// This is the non-deprecated replacement for [`save_session_file`](Self::save_session_file).
183 ///
184 /// # Parameters
185 ///
186 /// * `path_session` - The file to save to.
187 /// * `tokens` - The tokens to associate the state with. This should be a prefix of a sequence
188 /// of tokens that the context has processed, so that the relevant KV caches are already filled.
189 ///
190 /// # Errors
191 ///
192 /// Fails if the path is not a valid utf8, is not a valid c string, or llama.cpp fails to save
193 /// the state file.
194 pub fn state_save_file(
195 &self,
196 path_session: impl AsRef<Path>,
197 tokens: &[LlamaToken],
198 ) -> Result<(), SaveSessionError> {
199 let path = path_session.as_ref();
200 let path = path
201 .to_str()
202 .ok_or_else(|| SaveSessionError::PathToStrError(path.to_path_buf()))?;
203
204 let cstr = CString::new(path)?;
205
206 if unsafe {
207 llama_cpp_sys_2::llama_state_save_file(
208 self.context.as_ptr(),
209 cstr.as_ptr(),
210 tokens.as_ptr().cast::<llama_cpp_sys_2::llama_token>(),
211 tokens.len(),
212 )
213 } {
214 Ok(())
215 } else {
216 Err(SaveSessionError::FailedToSave)
217 }
218 }
219
220 /// Load a state file into the current context.
221 ///
222 /// This is the non-deprecated replacement for [`load_session_file`](Self::load_session_file).
223 ///
224 /// You still need to pass the returned tokens to the context for inference to work. What this
225 /// function buys you is that the KV caches are already filled with the relevant data.
226 ///
227 /// # Parameters
228 ///
229 /// * `path_session` - The file to load from. It must be a state file from a compatible context,
230 /// otherwise the function will error.
231 /// * `max_tokens` - The maximum token length of the loaded state. If the state was saved with a
232 /// longer length, the function will error.
233 ///
234 /// # Errors
235 ///
236 /// Fails if the path is not a valid utf8, is not a valid c string, or llama.cpp fails to load
237 /// the state file.
238 pub fn state_load_file(
239 &mut self,
240 path_session: impl AsRef<Path>,
241 max_tokens: usize,
242 ) -> Result<Vec<LlamaToken>, LoadSessionError> {
243 let path = path_session.as_ref();
244 let path = path
245 .to_str()
246 .ok_or(LoadSessionError::PathToStrError(path.to_path_buf()))?;
247
248 let cstr = CString::new(path)?;
249 let mut tokens: Vec<LlamaToken> = Vec::with_capacity(max_tokens);
250 let mut n_out = 0;
251
252 // SAFETY: cast is valid as LlamaToken is repr(transparent)
253 let tokens_out = tokens.as_mut_ptr().cast::<llama_cpp_sys_2::llama_token>();
254
255 let success = unsafe {
256 llama_cpp_sys_2::llama_state_load_file(
257 self.context.as_ptr(),
258 cstr.as_ptr(),
259 tokens_out,
260 max_tokens,
261 &mut n_out,
262 )
263 };
264 if success {
265 if n_out > max_tokens {
266 return Err(LoadSessionError::InsufficientMaxLength { n_out, max_tokens });
267 }
268 // SAFETY: we checked that n_out <= max_tokens and llama.cpp promises that n_out tokens will be written
269 unsafe {
270 tokens.set_len(n_out);
271 }
272 Ok(tokens)
273 } else {
274 Err(LoadSessionError::FailedToLoad)
275 }
276 }
277
278 /// Save state for a single sequence to a file.
279 ///
280 /// This enables saving state for individual sequences, which is useful for multi-sequence
281 /// inference scenarios.
282 ///
283 /// # Parameters
284 ///
285 /// * `filepath` - The file to save to.
286 /// * `seq_id` - The sequence ID whose state to save.
287 /// * `tokens` - The tokens to associate with the saved state.
288 ///
289 /// # Errors
290 ///
291 /// Fails if the path is not a valid utf8, is not a valid c string, or llama.cpp fails to save
292 /// the sequence state file.
293 ///
294 /// # Returns
295 ///
296 /// The number of bytes written on success.
297 pub fn state_seq_save_file(
298 &self,
299 filepath: impl AsRef<Path>,
300 seq_id: i32,
301 tokens: &[LlamaToken],
302 ) -> Result<usize, SaveSeqStateError> {
303 let path = filepath.as_ref();
304 let path = path
305 .to_str()
306 .ok_or_else(|| SaveSeqStateError::PathToStrError(path.to_path_buf()))?;
307
308 let cstr = CString::new(path)?;
309
310 let bytes_written = unsafe {
311 llama_cpp_sys_2::llama_state_seq_save_file(
312 self.context.as_ptr(),
313 cstr.as_ptr(),
314 seq_id,
315 tokens.as_ptr().cast::<llama_cpp_sys_2::llama_token>(),
316 tokens.len(),
317 )
318 };
319
320 if bytes_written == 0 {
321 Err(SaveSeqStateError::FailedToSave)
322 } else {
323 Ok(bytes_written)
324 }
325 }
326
327 /// Load state for a single sequence from a file.
328 ///
329 /// This enables loading state for individual sequences, which is useful for multi-sequence
330 /// inference scenarios.
331 ///
332 /// # Parameters
333 ///
334 /// * `filepath` - The file to load from.
335 /// * `dest_seq_id` - The destination sequence ID to load the state into.
336 /// * `max_tokens` - The maximum number of tokens to read.
337 ///
338 /// # Errors
339 ///
340 /// Fails if the path is not a valid utf8, is not a valid c string, or llama.cpp fails to load
341 /// the sequence state file.
342 ///
343 /// # Returns
344 ///
345 /// A tuple of `(tokens, bytes_read)` on success.
346 pub fn state_seq_load_file(
347 &mut self,
348 filepath: impl AsRef<Path>,
349 dest_seq_id: i32,
350 max_tokens: usize,
351 ) -> Result<(Vec<LlamaToken>, usize), LoadSeqStateError> {
352 let path = filepath.as_ref();
353 let path = path
354 .to_str()
355 .ok_or(LoadSeqStateError::PathToStrError(path.to_path_buf()))?;
356
357 let cstr = CString::new(path)?;
358 let mut tokens: Vec<LlamaToken> = Vec::with_capacity(max_tokens);
359 let mut n_out = 0;
360
361 // SAFETY: cast is valid as LlamaToken is repr(transparent)
362 let tokens_out = tokens.as_mut_ptr().cast::<llama_cpp_sys_2::llama_token>();
363
364 let bytes_read = unsafe {
365 llama_cpp_sys_2::llama_state_seq_load_file(
366 self.context.as_ptr(),
367 cstr.as_ptr(),
368 dest_seq_id,
369 tokens_out,
370 max_tokens,
371 &mut n_out,
372 )
373 };
374
375 if bytes_read == 0 {
376 return Err(LoadSeqStateError::FailedToLoad);
377 }
378
379 if n_out > max_tokens {
380 return Err(LoadSeqStateError::InsufficientMaxLength { n_out, max_tokens });
381 }
382
383 // SAFETY: we checked that n_out <= max_tokens and llama.cpp promises that n_out tokens will be written
384 unsafe {
385 tokens.set_len(n_out);
386 }
387
388 Ok((tokens, bytes_read))
389 }
390
391 /// Returns the maximum size in bytes of the state (rng, logits, embedding
392 /// and `kv_cache`) - will often be smaller after compacting tokens
393 #[must_use]
394 pub fn get_state_size(&self) -> usize {
395 unsafe { llama_cpp_sys_2::llama_get_state_size(self.context.as_ptr()) }
396 }
397
398 /// Copies the state to the specified destination address.
399 ///
400 /// Returns the number of bytes copied
401 ///
402 /// # Safety
403 ///
404 /// Destination needs to have allocated enough memory.
405 pub unsafe fn copy_state_data(&self, dest: *mut u8) -> usize {
406 unsafe { llama_cpp_sys_2::llama_copy_state_data(self.context.as_ptr(), dest) }
407 }
408
409 /// Set the state reading from the specified address
410 /// Returns the number of bytes read
411 ///
412 /// # Safety
413 ///
414 /// help wanted: not entirely sure what the safety requirements are here.
415 pub unsafe fn set_state_data(&mut self, src: &[u8]) -> usize {
416 unsafe { llama_cpp_sys_2::llama_set_state_data(self.context.as_ptr(), src.as_ptr()) }
417 }
418}