llama_cpp_bindings/context/
session.rs1use crate::context::LlamaContext;
2use crate::context::llama_state_seq_flags::LlamaStateSeqFlags;
3use crate::context::load_seq_state_error::LoadSeqStateError;
4use crate::context::load_session_error::LoadSessionError;
5use crate::context::save_seq_state_error::SaveSeqStateError;
6use crate::context::save_session_error::SaveSessionError;
7use crate::token::LlamaToken;
8use std::ffi::CString;
9use std::path::Path;
10
11fn process_session_load_result(
12 success: bool,
13 n_out: usize,
14 max_tokens: usize,
15 mut tokens: Vec<LlamaToken>,
16) -> Result<Vec<LlamaToken>, LoadSessionError> {
17 if !success {
18 return Err(LoadSessionError::FailedToLoad);
19 }
20
21 if n_out > max_tokens {
22 return Err(LoadSessionError::InsufficientMaxLength { n_out, max_tokens });
23 }
24
25 unsafe { tokens.set_len(n_out) };
26
27 Ok(tokens)
28}
29
30fn process_seq_load_result(
31 bytes_read: usize,
32 n_out: usize,
33 max_tokens: usize,
34 mut tokens: Vec<LlamaToken>,
35) -> Result<(Vec<LlamaToken>, usize), LoadSeqStateError> {
36 if bytes_read == 0 {
37 return Err(LoadSeqStateError::FailedToLoad);
38 }
39
40 if n_out > max_tokens {
41 return Err(LoadSeqStateError::InsufficientMaxLength { n_out, max_tokens });
42 }
43
44 unsafe { tokens.set_len(n_out) };
45
46 Ok((tokens, bytes_read))
47}
48
49impl LlamaContext<'_> {
50 pub fn state_save_file(
54 &self,
55 path_session: impl AsRef<Path>,
56 tokens: &[LlamaToken],
57 ) -> Result<(), SaveSessionError> {
58 let path = path_session.as_ref();
59 let path = path
60 .to_str()
61 .ok_or_else(|| SaveSessionError::PathToStrError(path.to_path_buf()))?;
62
63 let cstr = CString::new(path)?;
64
65 if unsafe {
66 llama_cpp_bindings_sys::llama_state_save_file(
67 self.context.as_ptr(),
68 cstr.as_ptr(),
69 tokens
70 .as_ptr()
71 .cast::<llama_cpp_bindings_sys::llama_token>(),
72 tokens.len(),
73 )
74 } {
75 Ok(())
76 } else {
77 Err(SaveSessionError::FailedToSave)
78 }
79 }
80
81 pub fn state_load_file(
85 &mut self,
86 path_session: impl AsRef<Path>,
87 max_tokens: usize,
88 ) -> Result<Vec<LlamaToken>, LoadSessionError> {
89 let path = path_session.as_ref();
90 let path = path
91 .to_str()
92 .ok_or_else(|| LoadSessionError::PathToStrError(path.to_path_buf()))?;
93
94 let cstr = CString::new(path)?;
95 let mut tokens: Vec<LlamaToken> = Vec::with_capacity(max_tokens);
96 let mut n_out = 0;
97
98 let tokens_out = tokens
100 .as_mut_ptr()
101 .cast::<llama_cpp_bindings_sys::llama_token>();
102
103 let success = unsafe {
104 llama_cpp_bindings_sys::llama_state_load_file(
105 self.context.as_ptr(),
106 cstr.as_ptr(),
107 tokens_out,
108 max_tokens,
109 &raw mut n_out,
110 )
111 };
112 process_session_load_result(success, n_out, max_tokens, tokens)
113 }
114
115 pub fn state_seq_save_file(
120 &self,
121 filepath: impl AsRef<Path>,
122 seq_id: i32,
123 tokens: &[LlamaToken],
124 ) -> Result<usize, SaveSeqStateError> {
125 let path = filepath.as_ref();
126 let path = path
127 .to_str()
128 .ok_or_else(|| SaveSeqStateError::PathToStrError(path.to_path_buf()))?;
129
130 let cstr = CString::new(path)?;
131
132 let bytes_written = unsafe {
133 llama_cpp_bindings_sys::llama_state_seq_save_file(
134 self.context.as_ptr(),
135 cstr.as_ptr(),
136 seq_id,
137 tokens
138 .as_ptr()
139 .cast::<llama_cpp_bindings_sys::llama_token>(),
140 tokens.len(),
141 )
142 };
143
144 if bytes_written == 0 {
145 Err(SaveSeqStateError::FailedToSave)
146 } else {
147 Ok(bytes_written)
148 }
149 }
150
151 pub fn state_seq_load_file(
156 &mut self,
157 filepath: impl AsRef<Path>,
158 dest_seq_id: i32,
159 max_tokens: usize,
160 ) -> Result<(Vec<LlamaToken>, usize), LoadSeqStateError> {
161 let path = filepath.as_ref();
162 let path = path
163 .to_str()
164 .ok_or_else(|| LoadSeqStateError::PathToStrError(path.to_path_buf()))?;
165
166 let cstr = CString::new(path)?;
167 let mut tokens: Vec<LlamaToken> = Vec::with_capacity(max_tokens);
168 let mut n_out = 0;
169
170 let tokens_out = tokens
172 .as_mut_ptr()
173 .cast::<llama_cpp_bindings_sys::llama_token>();
174
175 let bytes_read = unsafe {
176 llama_cpp_bindings_sys::llama_state_seq_load_file(
177 self.context.as_ptr(),
178 cstr.as_ptr(),
179 dest_seq_id,
180 tokens_out,
181 max_tokens,
182 &raw mut n_out,
183 )
184 };
185
186 process_seq_load_result(bytes_read, n_out, max_tokens, tokens)
187 }
188
189 #[must_use]
190 pub fn get_state_size(&self) -> usize {
191 unsafe { llama_cpp_bindings_sys::llama_state_get_size(self.context.as_ptr()) }
192 }
193
194 pub unsafe fn copy_state_data(&self, dest: &mut [u8]) -> usize {
198 unsafe {
199 llama_cpp_bindings_sys::llama_state_get_data(
200 self.context.as_ptr(),
201 dest.as_mut_ptr(),
202 dest.len(),
203 )
204 }
205 }
206
207 pub unsafe fn set_state_data(&mut self, src: &[u8]) -> usize {
213 unsafe {
214 llama_cpp_bindings_sys::llama_state_set_data(
215 self.context.as_ptr(),
216 src.as_ptr(),
217 src.len(),
218 )
219 }
220 }
221
222 #[must_use]
223 pub fn state_seq_get_size_ext(&self, seq_id: i32, flags: &LlamaStateSeqFlags) -> usize {
224 unsafe {
225 llama_cpp_bindings_sys::llama_state_seq_get_size_ext(
226 self.context.as_ptr(),
227 seq_id,
228 flags.bits(),
229 )
230 }
231 }
232
233 pub unsafe fn state_seq_get_data_ext(
237 &self,
238 dest: &mut [u8],
239 seq_id: i32,
240 flags: &LlamaStateSeqFlags,
241 ) -> usize {
242 unsafe {
243 llama_cpp_bindings_sys::llama_state_seq_get_data_ext(
244 self.context.as_ptr(),
245 dest.as_mut_ptr(),
246 dest.len(),
247 seq_id,
248 flags.bits(),
249 )
250 }
251 }
252
253 pub unsafe fn state_seq_set_data_ext(
258 &mut self,
259 src: &[u8],
260 dest_seq_id: i32,
261 flags: &LlamaStateSeqFlags,
262 ) -> usize {
263 unsafe {
264 llama_cpp_bindings_sys::llama_state_seq_set_data_ext(
265 self.context.as_ptr(),
266 src.as_ptr(),
267 src.len(),
268 dest_seq_id,
269 flags.bits(),
270 )
271 }
272 }
273}
274
#[cfg(test)]
mod unit_tests {
    use crate::context::load_seq_state_error::LoadSeqStateError;
    use crate::context::load_session_error::LoadSessionError;
    use crate::token::LlamaToken;

    use super::{process_seq_load_result, process_session_load_result};

    /// Returns a fully initialised buffer of `len` zero tokens. Shrinking it with `set_len`
    /// inside the helpers under test therefore never exposes uninitialised memory.
    fn zeroed_tokens(len: usize) -> Vec<LlamaToken> {
        vec![LlamaToken::new(0); len]
    }

    #[test]
    fn session_load_success_within_bounds() {
        let loaded = process_session_load_result(true, 10, 100, zeroed_tokens(100))
            .expect("a successful load within bounds must be accepted");

        assert_eq!(loaded.len(), 10);
    }

    #[test]
    fn session_load_success_at_exact_limit() {
        let loaded = process_session_load_result(true, 100, 100, zeroed_tokens(100))
            .expect("n_out == max_tokens must be accepted");

        assert_eq!(loaded.len(), 100);
    }

    #[test]
    fn session_load_success_with_no_tokens() {
        let loaded = process_session_load_result(true, 0, 100, zeroed_tokens(100))
            .expect("an empty session must be accepted");

        assert!(loaded.is_empty());
    }

    #[test]
    fn session_load_fails_when_not_successful() {
        let result = process_session_load_result(false, 0, 100, zeroed_tokens(100));

        assert_eq!(result, Err(LoadSessionError::FailedToLoad));
    }

    #[test]
    fn session_load_failure_is_reported_before_length_check() {
        let result = process_session_load_result(false, 101, 100, zeroed_tokens(100));

        assert_eq!(result, Err(LoadSessionError::FailedToLoad));
    }

    #[test]
    fn session_load_fails_when_n_out_exceeds_max() {
        let result = process_session_load_result(true, 101, 100, zeroed_tokens(100));

        assert_eq!(
            result,
            Err(LoadSessionError::InsufficientMaxLength {
                n_out: 101,
                max_tokens: 100,
            })
        );
    }

    #[test]
    fn seq_load_success_within_bounds() {
        let (loaded, bytes) = process_seq_load_result(42, 10, 100, zeroed_tokens(100))
            .expect("a successful sequence load within bounds must be accepted");

        assert_eq!(loaded.len(), 10);
        assert_eq!(bytes, 42);
    }

    #[test]
    fn seq_load_success_at_exact_limit() {
        let (loaded, bytes) = process_seq_load_result(42, 100, 100, zeroed_tokens(100))
            .expect("n_out == max_tokens must be accepted");

        assert_eq!(loaded.len(), 100);
        assert_eq!(bytes, 42);
    }

    #[test]
    fn seq_load_fails_when_zero_bytes_read() {
        let result = process_seq_load_result(0, 0, 100, zeroed_tokens(100));

        assert_eq!(result, Err(LoadSeqStateError::FailedToLoad));
    }

    #[test]
    fn seq_load_failure_is_reported_before_length_check() {
        let result = process_seq_load_result(0, 101, 100, zeroed_tokens(100));

        assert_eq!(result, Err(LoadSeqStateError::FailedToLoad));
    }

    #[test]
    fn seq_load_fails_when_n_out_exceeds_max() {
        let result = process_seq_load_result(42, 101, 100, zeroed_tokens(100));

        assert_eq!(
            result,
            Err(LoadSeqStateError::InsufficientMaxLength {
                n_out: 101,
                max_tokens: 100,
            })
        );
    }
}