// llama_cpp_2/context/session.rs
1//! utilities for working with session files
2
3use crate::context::LlamaContext;
4use crate::token::LlamaToken;
5use std::ffi::{CString, NulError};
6use std::path::{Path, PathBuf};
7
/// An error that can occur when saving a session file via
/// [`LlamaContext::save_session_file`].
#[derive(Debug, Eq, PartialEq, thiserror::Error)]
pub enum SaveSessionError {
    /// llama.cpp returned `false` from `llama_save_session_file` (e.g. the file
    /// could not be written)
    #[error("Failed to save session file")]
    FailedToSave,

    /// the path contained an interior null byte and could not be converted to a
    /// `CString` for the FFI call
    #[error("null byte in string {0}")]
    NullError(#[from] NulError),

    /// the path was not valid UTF-8, so it could not be converted to `&str`
    #[error("failed to convert path {0} to str")]
    PathToStrError(PathBuf),
}
23
/// An error that can occur when loading a session file via
/// [`LlamaContext::load_session_file`].
#[derive(Debug, Eq, PartialEq, thiserror::Error)]
pub enum LoadSessionError {
    /// llama.cpp returned `false` from `llama_load_session_file` (e.g. the file
    /// does not exist, is not a session file, or is incompatible with the context)
    #[error("Failed to load session file")]
    FailedToLoad,

    /// the path contained an interior null byte and could not be converted to a
    /// `CString` for the FFI call
    #[error("null byte in string {0}")]
    NullError(#[from] NulError),

    /// the path was not valid UTF-8, so it could not be converted to `&str`
    #[error("failed to convert path {0} to str")]
    PathToStrError(PathBuf),

    /// Insufficient max length: the session file holds more tokens than the
    /// caller-supplied capacity
    #[error("max_length is not large enough to hold {n_out} (was {max_tokens})")]
    InsufficientMaxLength {
        /// The length of the session file
        n_out: usize,
        /// The maximum length the caller allowed
        max_tokens: usize,
    },
}
48
49impl LlamaContext<'_> {
50 /// Save the current session to a file.
51 ///
52 /// # Parameters
53 ///
54 /// * `path_session` - The file to save to.
55 /// * `tokens` - The tokens to associate the session with. This should be a prefix of a sequence of tokens that the context has processed, so that the relevant KV caches are already filled.
56 ///
57 /// # Errors
58 ///
59 /// Fails if the path is not a valid utf8, is not a valid c string, or llama.cpp fails to save the session file.
60 pub fn save_session_file(
61 &self,
62 path_session: impl AsRef<Path>,
63 tokens: &[LlamaToken],
64 ) -> Result<(), SaveSessionError> {
65 let path = path_session.as_ref();
66 let path = path
67 .to_str()
68 .ok_or_else(|| SaveSessionError::PathToStrError(path.to_path_buf()))?;
69
70 let cstr = CString::new(path)?;
71
72 if unsafe {
73 llama_cpp_sys_2::llama_save_session_file(
74 self.context.as_ptr(),
75 cstr.as_ptr(),
76 tokens.as_ptr().cast::<llama_cpp_sys_2::llama_token>(),
77 tokens.len(),
78 )
79 } {
80 Ok(())
81 } else {
82 Err(SaveSessionError::FailedToSave)
83 }
84 }
85 /// Load a session file into the current context.
86 ///
87 /// You still need to pass the returned tokens to the context for inference to work. What this function buys you is that the KV caches are already filled with the relevant data.
88 ///
89 /// # Parameters
90 ///
91 /// * `path_session` - The file to load from. It must be a session file from a compatible context, otherwise the function will error.
92 /// * `max_tokens` - The maximum token length of the loaded session. If the session was saved with a longer length, the function will error.
93 ///
94 /// # Errors
95 ///
96 /// Fails if the path is not a valid utf8, is not a valid c string, or llama.cpp fails to load the session file. (e.g. the file does not exist, is not a session file, etc.)
97 pub fn load_session_file(
98 &mut self,
99 path_session: impl AsRef<Path>,
100 max_tokens: usize,
101 ) -> Result<Vec<LlamaToken>, LoadSessionError> {
102 let path = path_session.as_ref();
103 let path = path
104 .to_str()
105 .ok_or(LoadSessionError::PathToStrError(path.to_path_buf()))?;
106
107 let cstr = CString::new(path)?;
108 let mut tokens: Vec<LlamaToken> = Vec::with_capacity(max_tokens);
109 let mut n_out = 0;
110
111 // SAFETY: cast is valid as LlamaToken is repr(transparent)
112 let tokens_out = tokens.as_mut_ptr().cast::<llama_cpp_sys_2::llama_token>();
113
114 let load_session_success = unsafe {
115 llama_cpp_sys_2::llama_load_session_file(
116 self.context.as_ptr(),
117 cstr.as_ptr(),
118 tokens_out,
119 max_tokens,
120 &mut n_out,
121 )
122 };
123 if load_session_success {
124 if n_out > max_tokens {
125 return Err(LoadSessionError::InsufficientMaxLength { n_out, max_tokens });
126 }
127 // SAFETY: we checked that n_out <= max_tokens and llama.cpp promises that n_out tokens will be written
128 unsafe {
129 tokens.set_len(n_out);
130 }
131 Ok(tokens)
132 } else {
133 Err(LoadSessionError::FailedToLoad)
134 }
135 }
136
137 /// Returns the maximum size in bytes of the state (rng, logits, embedding
138 /// and `kv_cache`) - will often be smaller after compacting tokens
139 #[must_use]
140 pub fn get_state_size(&self) -> usize {
141 unsafe { llama_cpp_sys_2::llama_get_state_size(self.context.as_ptr()) }
142 }
143
144 /// Copies the state to the specified destination address.
145 ///
146 /// Returns the number of bytes copied
147 ///
148 /// # Safety
149 ///
150 /// Destination needs to have allocated enough memory.
151 pub unsafe fn copy_state_data(&self, dest: *mut u8) -> usize {
152 unsafe { llama_cpp_sys_2::llama_copy_state_data(self.context.as_ptr(), dest) }
153 }
154
155 /// Set the state reading from the specified address
156 /// Returns the number of bytes read
157 ///
158 /// # Safety
159 ///
160 /// help wanted: not entirely sure what the safety requirements are here.
161 pub unsafe fn set_state_data(&mut self, src: &[u8]) -> usize {
162 unsafe { llama_cpp_sys_2::llama_set_state_data(self.context.as_ptr(), src.as_ptr()) }
163 }
164}