// llama_cpp_2/context/session.rs
1//! utilities for working with session files
2
3use crate::context::LlamaContext;
4use crate::token::LlamaToken;
5use std::ffi::{CString, NulError};
6use std::path::{Path, PathBuf};
7
/// Flags for state sequence operations.
///
/// These flags control what parts of the state are included when saving/restoring
/// sequence state.
///
/// This is a thin newtype over the raw `llama_cpp_sys_2::llama_state_seq_flags`
/// value, so it can be passed straight through to the FFI layer.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct LlamaStateSeqFlags(pub(crate) llama_cpp_sys_2::llama_state_seq_flags);
14
15impl LlamaStateSeqFlags {
16 /// Work only with partial states, such as SWA KV cache or recurrent cache (e.g. Mamba).
17 ///
18 /// This flag is useful when you only want to save/restore the recurrent state
19 /// without affecting the KV cache.
20 pub const PARTIAL_ONLY: LlamaStateSeqFlags = LlamaStateSeqFlags(1);
21
22 /// Create an empty flags set.
23 pub const fn empty() -> LlamaStateSeqFlags {
24 LlamaStateSeqFlags(0)
25 }
26
27 /// Get the raw flags value.
28 pub const fn bits(&self) -> u32 {
29 self.0
30 }
31
32 /// Check if a flag is set.
33 pub const fn contains(&self, other: LlamaStateSeqFlags) -> bool {
34 (self.0 & other.0) != 0
35 }
36}
37
impl Default for LlamaStateSeqFlags {
    /// Defaults to [`LlamaStateSeqFlags::empty`] — no flags set.
    fn default() -> Self {
        Self::empty()
    }
}
43
/// Failed to save a sequence state file.
///
/// Returned by [`LlamaContext::state_seq_save_file`].
#[derive(Debug, Eq, PartialEq, thiserror::Error)]
pub enum SaveSeqStateError {
    /// llama.cpp failed to save the sequence state file
    #[error("Failed to save sequence state file")]
    FailedToSave,

    /// null byte in string (the path could not be converted to a C string)
    #[error("null byte in string {0}")]
    NullError(#[from] NulError),

    /// failed to convert path to str (path was not valid UTF-8)
    #[error("failed to convert path {0} to str")]
    PathToStrError(PathBuf),
}
59
/// Failed to load a sequence state file.
///
/// Returned by [`LlamaContext::state_seq_load_file`].
#[derive(Debug, Eq, PartialEq, thiserror::Error)]
pub enum LoadSeqStateError {
    /// llama.cpp failed to load the sequence state file
    #[error("Failed to load sequence state file")]
    FailedToLoad,

    /// null byte in string (the path could not be converted to a C string)
    #[error("null byte in string {0}")]
    NullError(#[from] NulError),

    /// failed to convert path to str (path was not valid UTF-8)
    #[error("failed to convert path {0} to str")]
    PathToStrError(PathBuf),

    /// Insufficient max length
    #[error("max_length is not large enough to hold {n_out} (was {max_tokens})")]
    InsufficientMaxLength {
        /// The length of the loaded sequence
        n_out: usize,
        /// The maximum length
        max_tokens: usize,
    },
}
84
/// Failed to save a Session file.
///
/// Returned by [`LlamaContext::save_session_file`] and [`LlamaContext::state_save_file`].
#[derive(Debug, Eq, PartialEq, thiserror::Error)]
pub enum SaveSessionError {
    /// llama.cpp failed to save the session file
    #[error("Failed to save session file")]
    FailedToSave,

    /// null byte in string (the path could not be converted to a C string)
    #[error("null byte in string {0}")]
    NullError(#[from] NulError),

    /// failed to convert path to str (path was not valid UTF-8)
    #[error("failed to convert path {0} to str")]
    PathToStrError(PathBuf),
}
100
/// Failed to load a Session file.
///
/// Returned by [`LlamaContext::load_session_file`] and [`LlamaContext::state_load_file`].
#[derive(Debug, Eq, PartialEq, thiserror::Error)]
pub enum LoadSessionError {
    /// llama.cpp failed to load the session file
    #[error("Failed to load session file")]
    FailedToLoad,

    /// null byte in string (the path could not be converted to a C string)
    #[error("null byte in string {0}")]
    NullError(#[from] NulError),

    /// failed to convert path to str (path was not valid UTF-8)
    #[error("failed to convert path {0} to str")]
    PathToStrError(PathBuf),

    /// Insufficient max length
    #[error("max_length is not large enough to hold {n_out} (was {max_tokens})")]
    InsufficientMaxLength {
        /// The length of the session file
        n_out: usize,
        /// The maximum length
        max_tokens: usize,
    },
}
125
126impl LlamaContext<'_> {
    /// Save the current session to a file.
    ///
    /// # Parameters
    ///
    /// * `path_session` - The file to save to.
    /// * `tokens` - The tokens to associate the session with. This should be a prefix of a sequence of tokens that the context has processed, so that the relevant KV caches are already filled.
    ///
    /// # Errors
    ///
    /// Fails if the path is not a valid utf8, is not a valid c string, or llama.cpp fails to save the session file.
    #[deprecated(since = "0.1.136", note = "Use `state_save_file` instead")]
    pub fn save_session_file(
        &self,
        path_session: impl AsRef<Path>,
        tokens: &[LlamaToken],
    ) -> Result<(), SaveSessionError> {
        let path = path_session.as_ref();
        let path = path
            .to_str()
            .ok_or_else(|| SaveSessionError::PathToStrError(path.to_path_buf()))?;

        let cstr = CString::new(path)?;

        // SAFETY: `self.context` is a live context pointer, `cstr` outlives the call,
        // and the token pointer cast is valid because `LlamaToken` is `repr(transparent)`
        // over `llama_token` (see the cast comments elsewhere in this file).
        if unsafe {
            llama_cpp_sys_2::llama_save_session_file(
                self.context.as_ptr(),
                cstr.as_ptr(),
                tokens.as_ptr().cast::<llama_cpp_sys_2::llama_token>(),
                tokens.len(),
            )
        } {
            Ok(())
        } else {
            Err(SaveSessionError::FailedToSave)
        }
    }
    /// Load a session file into the current context.
    ///
    /// You still need to pass the returned tokens to the context for inference to work. What this function buys you is that the KV caches are already filled with the relevant data.
    ///
    /// # Parameters
    ///
    /// * `path_session` - The file to load from. It must be a session file from a compatible context, otherwise the function will error.
    /// * `max_tokens` - The maximum token length of the loaded session. If the session was saved with a longer length, the function will error.
    ///
    /// # Errors
    ///
    /// Fails if the path is not a valid utf8, is not a valid c string, or llama.cpp fails to load the session file. (e.g. the file does not exist, is not a session file, etc.)
    #[deprecated(since = "0.1.136", note = "Use `state_load_file` instead")]
    pub fn load_session_file(
        &mut self,
        path_session: impl AsRef<Path>,
        max_tokens: usize,
    ) -> Result<Vec<LlamaToken>, LoadSessionError> {
        let path = path_session.as_ref();
        let path = path
            .to_str()
            .ok_or(LoadSessionError::PathToStrError(path.to_path_buf()))?;

        let cstr = CString::new(path)?;
        // Only the *capacity* is pre-allocated (length stays 0); llama.cpp writes
        // directly into the spare capacity and reports the token count via `n_out`.
        let mut tokens: Vec<LlamaToken> = Vec::with_capacity(max_tokens);
        let mut n_out = 0;

        // SAFETY: cast is valid as LlamaToken is repr(transparent)
        let tokens_out = tokens.as_mut_ptr().cast::<llama_cpp_sys_2::llama_token>();

        // SAFETY: `tokens_out` points to `max_tokens` writable slots, and that capacity
        // is passed to llama.cpp so it does not write past the end.
        let load_session_success = unsafe {
            llama_cpp_sys_2::llama_load_session_file(
                self.context.as_ptr(),
                cstr.as_ptr(),
                tokens_out,
                max_tokens,
                &mut n_out,
            )
        };
        if load_session_success {
            // Defensive check: `set_len` below would be undefined behavior if llama.cpp
            // ever reported more tokens than the capacity we handed it.
            if n_out > max_tokens {
                return Err(LoadSessionError::InsufficientMaxLength { n_out, max_tokens });
            }
            // SAFETY: we checked that n_out <= max_tokens and llama.cpp promises that n_out tokens will be written
            unsafe {
                tokens.set_len(n_out);
            }
            Ok(tokens)
        } else {
            Err(LoadSessionError::FailedToLoad)
        }
    }
215
216 /// Save the full state to a file.
217 ///
218 /// This is the non-deprecated replacement for [`save_session_file`](Self::save_session_file).
219 ///
220 /// # Parameters
221 ///
222 /// * `path_session` - The file to save to.
223 /// * `tokens` - The tokens to associate the state with. This should be a prefix of a sequence
224 /// of tokens that the context has processed, so that the relevant KV caches are already filled.
225 ///
226 /// # Errors
227 ///
228 /// Fails if the path is not a valid utf8, is not a valid c string, or llama.cpp fails to save
229 /// the state file.
230 pub fn state_save_file(
231 &self,
232 path_session: impl AsRef<Path>,
233 tokens: &[LlamaToken],
234 ) -> Result<(), SaveSessionError> {
235 let path = path_session.as_ref();
236 let path = path
237 .to_str()
238 .ok_or_else(|| SaveSessionError::PathToStrError(path.to_path_buf()))?;
239
240 let cstr = CString::new(path)?;
241
242 if unsafe {
243 llama_cpp_sys_2::llama_state_save_file(
244 self.context.as_ptr(),
245 cstr.as_ptr(),
246 tokens.as_ptr().cast::<llama_cpp_sys_2::llama_token>(),
247 tokens.len(),
248 )
249 } {
250 Ok(())
251 } else {
252 Err(SaveSessionError::FailedToSave)
253 }
254 }
255
    /// Load a state file into the current context.
    ///
    /// This is the non-deprecated replacement for [`load_session_file`](Self::load_session_file).
    ///
    /// You still need to pass the returned tokens to the context for inference to work. What this
    /// function buys you is that the KV caches are already filled with the relevant data.
    ///
    /// # Parameters
    ///
    /// * `path_session` - The file to load from. It must be a state file from a compatible context,
    ///   otherwise the function will error.
    /// * `max_tokens` - The maximum token length of the loaded state. If the state was saved with a
    ///   longer length, the function will error.
    ///
    /// # Errors
    ///
    /// Fails if the path is not a valid utf8, is not a valid c string, or llama.cpp fails to load
    /// the state file.
    pub fn state_load_file(
        &mut self,
        path_session: impl AsRef<Path>,
        max_tokens: usize,
    ) -> Result<Vec<LlamaToken>, LoadSessionError> {
        let path = path_session.as_ref();
        let path = path
            .to_str()
            .ok_or(LoadSessionError::PathToStrError(path.to_path_buf()))?;

        let cstr = CString::new(path)?;
        // Only the *capacity* is pre-allocated (length stays 0); llama.cpp writes
        // directly into the spare capacity and reports the token count via `n_out`.
        let mut tokens: Vec<LlamaToken> = Vec::with_capacity(max_tokens);
        let mut n_out = 0;

        // SAFETY: cast is valid as LlamaToken is repr(transparent)
        let tokens_out = tokens.as_mut_ptr().cast::<llama_cpp_sys_2::llama_token>();

        // SAFETY: `tokens_out` points to `max_tokens` writable slots, and that capacity
        // is passed to llama.cpp so it does not write past the end.
        let success = unsafe {
            llama_cpp_sys_2::llama_state_load_file(
                self.context.as_ptr(),
                cstr.as_ptr(),
                tokens_out,
                max_tokens,
                &mut n_out,
            )
        };
        if success {
            // Defensive check: `set_len` below would be undefined behavior if llama.cpp
            // ever reported more tokens than the capacity we handed it.
            if n_out > max_tokens {
                return Err(LoadSessionError::InsufficientMaxLength { n_out, max_tokens });
            }
            // SAFETY: we checked that n_out <= max_tokens and llama.cpp promises that n_out tokens will be written
            unsafe {
                tokens.set_len(n_out);
            }
            Ok(tokens)
        } else {
            Err(LoadSessionError::FailedToLoad)
        }
    }
313
314 /// Save state for a single sequence to a file.
315 ///
316 /// This enables saving state for individual sequences, which is useful for multi-sequence
317 /// inference scenarios.
318 ///
319 /// # Parameters
320 ///
321 /// * `filepath` - The file to save to.
322 /// * `seq_id` - The sequence ID whose state to save.
323 /// * `tokens` - The tokens to associate with the saved state.
324 ///
325 /// # Errors
326 ///
327 /// Fails if the path is not a valid utf8, is not a valid c string, or llama.cpp fails to save
328 /// the sequence state file.
329 ///
330 /// # Returns
331 ///
332 /// The number of bytes written on success.
333 pub fn state_seq_save_file(
334 &self,
335 filepath: impl AsRef<Path>,
336 seq_id: i32,
337 tokens: &[LlamaToken],
338 ) -> Result<usize, SaveSeqStateError> {
339 let path = filepath.as_ref();
340 let path = path
341 .to_str()
342 .ok_or_else(|| SaveSeqStateError::PathToStrError(path.to_path_buf()))?;
343
344 let cstr = CString::new(path)?;
345
346 let bytes_written = unsafe {
347 llama_cpp_sys_2::llama_state_seq_save_file(
348 self.context.as_ptr(),
349 cstr.as_ptr(),
350 seq_id,
351 tokens.as_ptr().cast::<llama_cpp_sys_2::llama_token>(),
352 tokens.len(),
353 )
354 };
355
356 if bytes_written == 0 {
357 Err(SaveSeqStateError::FailedToSave)
358 } else {
359 Ok(bytes_written)
360 }
361 }
362
    /// Load state for a single sequence from a file.
    ///
    /// This enables loading state for individual sequences, which is useful for multi-sequence
    /// inference scenarios.
    ///
    /// # Parameters
    ///
    /// * `filepath` - The file to load from.
    /// * `dest_seq_id` - The destination sequence ID to load the state into.
    /// * `max_tokens` - The maximum number of tokens to read.
    ///
    /// # Errors
    ///
    /// Fails if the path is not a valid utf8, is not a valid c string, or llama.cpp fails to load
    /// the sequence state file.
    ///
    /// # Returns
    ///
    /// A tuple of `(tokens, bytes_read)` on success.
    pub fn state_seq_load_file(
        &mut self,
        filepath: impl AsRef<Path>,
        dest_seq_id: i32,
        max_tokens: usize,
    ) -> Result<(Vec<LlamaToken>, usize), LoadSeqStateError> {
        let path = filepath.as_ref();
        let path = path
            .to_str()
            .ok_or(LoadSeqStateError::PathToStrError(path.to_path_buf()))?;

        let cstr = CString::new(path)?;
        // Only the *capacity* is pre-allocated (length stays 0); llama.cpp writes
        // directly into the spare capacity and reports the token count via `n_out`.
        let mut tokens: Vec<LlamaToken> = Vec::with_capacity(max_tokens);
        let mut n_out = 0;

        // SAFETY: cast is valid as LlamaToken is repr(transparent)
        let tokens_out = tokens.as_mut_ptr().cast::<llama_cpp_sys_2::llama_token>();

        // SAFETY: `tokens_out` points to `max_tokens` writable slots, and that capacity
        // is passed to llama.cpp so it does not write past the end.
        let bytes_read = unsafe {
            llama_cpp_sys_2::llama_state_seq_load_file(
                self.context.as_ptr(),
                cstr.as_ptr(),
                dest_seq_id,
                tokens_out,
                max_tokens,
                &mut n_out,
            )
        };

        // llama.cpp signals failure by reporting zero bytes read.
        if bytes_read == 0 {
            return Err(LoadSeqStateError::FailedToLoad);
        }

        // Defensive check: `set_len` below would be undefined behavior if llama.cpp
        // ever reported more tokens than the capacity we handed it.
        if n_out > max_tokens {
            return Err(LoadSeqStateError::InsufficientMaxLength { n_out, max_tokens });
        }

        // SAFETY: we checked that n_out <= max_tokens and llama.cpp promises that n_out tokens will be written
        unsafe {
            tokens.set_len(n_out);
        }

        Ok((tokens, bytes_read))
    }
426
    /// Returns the maximum size in bytes of the state (rng, logits, embedding
    /// and `kv_cache`) - will often be smaller after compacting tokens
    ///
    /// Use this to size the buffer passed to [`copy_state_data`](Self::copy_state_data).
    #[must_use]
    pub fn get_state_size(&self) -> usize {
        // SAFETY: the context pointer is valid for the lifetime of `self`.
        unsafe { llama_cpp_sys_2::llama_get_state_size(self.context.as_ptr()) }
    }
433
    /// Copies the state to the specified destination address.
    ///
    /// Returns the number of bytes copied
    ///
    /// # Safety
    ///
    /// Destination needs to have allocated enough memory.
    /// Use [`get_state_size`](Self::get_state_size) to determine the required size.
    pub unsafe fn copy_state_data(&self, dest: *mut u8) -> usize {
        // SAFETY: the context pointer is valid for the lifetime of `self`; the caller
        // guarantees `dest` is large enough (see `# Safety` above).
        unsafe { llama_cpp_sys_2::llama_copy_state_data(self.context.as_ptr(), dest) }
    }
444
    /// Set the state reading from the specified address
    /// Returns the number of bytes read
    ///
    /// # Safety
    ///
    /// help wanted: not entirely sure what the safety requirements are here.
    /// NOTE(review): `src` presumably must contain a complete state blob as produced by
    /// [`copy_state_data`](Self::copy_state_data) from a compatible context — confirm
    /// against llama.cpp before relying on this.
    pub unsafe fn set_state_data(&mut self, src: &[u8]) -> usize {
        // Note: only `src.as_ptr()` is passed — `src.len()` is never communicated to
        // llama.cpp, so a truncated slice cannot be detected on this side.
        unsafe { llama_cpp_sys_2::llama_set_state_data(self.context.as_ptr(), src.as_ptr()) }
    }
454
    /// Get the size of the state for a single sequence with optional flags.
    ///
    /// This is the extended version that supports flags for partial state operations.
    ///
    /// # Parameters
    ///
    /// * `seq_id` - The sequence ID to get the state size for.
    /// * `flags` - Optional flags (e.g., [`LlamaStateSeqFlags::PARTIAL_ONLY`]).
    ///
    /// # Returns
    ///
    /// The size in bytes needed to store the sequence state.
    #[must_use]
    pub fn state_seq_get_size_ext(&self, seq_id: i32, flags: LlamaStateSeqFlags) -> usize {
        // SAFETY: the context pointer is valid for the lifetime of `self`.
        unsafe {
            llama_cpp_sys_2::llama_state_seq_get_size_ext(self.context.as_ptr(), seq_id, flags.0)
        }
    }
473
    /// Copy the state of a single sequence into the specified buffer with optional flags.
    ///
    /// This is the extended version that supports flags for partial state operations.
    ///
    /// # Parameters
    ///
    /// * `dest` - Destination buffer to copy state into.
    /// * `seq_id` - The sequence ID to get the state for.
    /// * `flags` - Optional flags (e.g., [`LlamaStateSeqFlags::PARTIAL_ONLY`]).
    ///
    /// # Safety
    ///
    /// Destination needs to have allocated enough memory.
    /// Use [`state_seq_get_size_ext`](Self::state_seq_get_size_ext) with the same
    /// `seq_id` and `flags` to determine the required size.
    ///
    /// # Returns
    ///
    /// The number of bytes copied.
    pub unsafe fn state_seq_get_data_ext(
        &self,
        dest: *mut u8,
        seq_id: i32,
        flags: LlamaStateSeqFlags,
    ) -> usize {
        unsafe {
            // Passing `usize::MAX` as the buffer size disables llama.cpp's own bounds
            // check; the caller's "enough memory" guarantee (see `# Safety`) is the
            // only protection against writing past the end of `dest`.
            llama_cpp_sys_2::llama_state_seq_get_data_ext(
                self.context.as_ptr(),
                dest,
                usize::MAX,
                seq_id,
                flags.0,
            )
        }
    }
507
    /// Set the state for a single sequence from the specified buffer with optional flags.
    ///
    /// This is the extended version that supports flags for partial state operations.
    /// Useful for restoring only the recurrent/partial state without affecting the KV cache.
    ///
    /// # Parameters
    ///
    /// * `src` - Source buffer containing the state data.
    /// * `dest_seq_id` - The destination sequence ID to load the state into.
    /// * `flags` - Optional flags (e.g., [`LlamaStateSeqFlags::PARTIAL_ONLY`]).
    ///
    /// # Safety
    ///
    /// The source buffer must contain valid state data.
    ///
    /// # Returns
    ///
    /// `true` on success, `false` on failure (llama.cpp reports a positive byte count
    /// on success and zero on failure; this wrapper converts that to a `bool`).
    pub unsafe fn state_seq_set_data_ext(
        &mut self,
        src: &[u8],
        dest_seq_id: i32,
        flags: LlamaStateSeqFlags,
    ) -> bool {
        unsafe {
            // A zero return from llama.cpp indicates failure; any positive value is
            // the number of bytes consumed.
            llama_cpp_sys_2::llama_state_seq_set_data_ext(
                self.context.as_ptr(),
                src.as_ptr(),
                src.len(),
                dest_seq_id,
                flags.0,
            ) > 0
        }
    }
542}