llama_cpp_bindings/context/
session.rs1use crate::context::LlamaContext;
4use crate::context::llama_state_seq_flags::LlamaStateSeqFlags;
5use crate::context::load_seq_state_error::LoadSeqStateError;
6use crate::context::load_session_error::LoadSessionError;
7use crate::context::save_seq_state_error::SaveSeqStateError;
8use crate::context::save_session_error::SaveSessionError;
9use crate::token::LlamaToken;
10use std::ffi::CString;
11use std::path::Path;
12
13fn process_session_load_result(
14 success: bool,
15 n_out: usize,
16 max_tokens: usize,
17 mut tokens: Vec<LlamaToken>,
18) -> Result<Vec<LlamaToken>, LoadSessionError> {
19 if !success {
20 return Err(LoadSessionError::FailedToLoad);
21 }
22
23 if n_out > max_tokens {
24 return Err(LoadSessionError::InsufficientMaxLength { n_out, max_tokens });
25 }
26
27 unsafe { tokens.set_len(n_out) };
28
29 Ok(tokens)
30}
31
32fn process_seq_load_result(
33 bytes_read: usize,
34 n_out: usize,
35 max_tokens: usize,
36 mut tokens: Vec<LlamaToken>,
37) -> Result<(Vec<LlamaToken>, usize), LoadSeqStateError> {
38 if bytes_read == 0 {
39 return Err(LoadSeqStateError::FailedToLoad);
40 }
41
42 if n_out > max_tokens {
43 return Err(LoadSeqStateError::InsufficientMaxLength { n_out, max_tokens });
44 }
45
46 unsafe { tokens.set_len(n_out) };
47
48 Ok((tokens, bytes_read))
49}
50
51impl LlamaContext<'_> {
52 pub fn state_save_file(
64 &self,
65 path_session: impl AsRef<Path>,
66 tokens: &[LlamaToken],
67 ) -> Result<(), SaveSessionError> {
68 let path = path_session.as_ref();
69 let path = path
70 .to_str()
71 .ok_or_else(|| SaveSessionError::PathToStrError(path.to_path_buf()))?;
72
73 let cstr = CString::new(path)?;
74
75 if unsafe {
76 llama_cpp_bindings_sys::llama_state_save_file(
77 self.context.as_ptr(),
78 cstr.as_ptr(),
79 tokens
80 .as_ptr()
81 .cast::<llama_cpp_bindings_sys::llama_token>(),
82 tokens.len(),
83 )
84 } {
85 Ok(())
86 } else {
87 Err(SaveSessionError::FailedToSave)
88 }
89 }
90
91 pub fn state_load_file(
107 &mut self,
108 path_session: impl AsRef<Path>,
109 max_tokens: usize,
110 ) -> Result<Vec<LlamaToken>, LoadSessionError> {
111 let path = path_session.as_ref();
112 let path = path
113 .to_str()
114 .ok_or_else(|| LoadSessionError::PathToStrError(path.to_path_buf()))?;
115
116 let cstr = CString::new(path)?;
117 let mut tokens: Vec<LlamaToken> = Vec::with_capacity(max_tokens);
118 let mut n_out = 0;
119
120 let tokens_out = tokens
122 .as_mut_ptr()
123 .cast::<llama_cpp_bindings_sys::llama_token>();
124
125 let success = unsafe {
126 llama_cpp_bindings_sys::llama_state_load_file(
127 self.context.as_ptr(),
128 cstr.as_ptr(),
129 tokens_out,
130 max_tokens,
131 &raw mut n_out,
132 )
133 };
134 process_session_load_result(success, n_out, max_tokens, tokens)
135 }
136
137 pub fn state_seq_save_file(
156 &self,
157 filepath: impl AsRef<Path>,
158 seq_id: i32,
159 tokens: &[LlamaToken],
160 ) -> Result<usize, SaveSeqStateError> {
161 let path = filepath.as_ref();
162 let path = path
163 .to_str()
164 .ok_or_else(|| SaveSeqStateError::PathToStrError(path.to_path_buf()))?;
165
166 let cstr = CString::new(path)?;
167
168 let bytes_written = unsafe {
169 llama_cpp_bindings_sys::llama_state_seq_save_file(
170 self.context.as_ptr(),
171 cstr.as_ptr(),
172 seq_id,
173 tokens
174 .as_ptr()
175 .cast::<llama_cpp_bindings_sys::llama_token>(),
176 tokens.len(),
177 )
178 };
179
180 if bytes_written == 0 {
181 Err(SaveSeqStateError::FailedToSave)
182 } else {
183 Ok(bytes_written)
184 }
185 }
186
187 pub fn state_seq_load_file(
206 &mut self,
207 filepath: impl AsRef<Path>,
208 dest_seq_id: i32,
209 max_tokens: usize,
210 ) -> Result<(Vec<LlamaToken>, usize), LoadSeqStateError> {
211 let path = filepath.as_ref();
212 let path = path
213 .to_str()
214 .ok_or_else(|| LoadSeqStateError::PathToStrError(path.to_path_buf()))?;
215
216 let cstr = CString::new(path)?;
217 let mut tokens: Vec<LlamaToken> = Vec::with_capacity(max_tokens);
218 let mut n_out = 0;
219
220 let tokens_out = tokens
222 .as_mut_ptr()
223 .cast::<llama_cpp_bindings_sys::llama_token>();
224
225 let bytes_read = unsafe {
226 llama_cpp_bindings_sys::llama_state_seq_load_file(
227 self.context.as_ptr(),
228 cstr.as_ptr(),
229 dest_seq_id,
230 tokens_out,
231 max_tokens,
232 &raw mut n_out,
233 )
234 };
235
236 process_seq_load_result(bytes_read, n_out, max_tokens, tokens)
237 }
238
239 #[must_use]
242 pub fn get_state_size(&self) -> usize {
243 unsafe { llama_cpp_bindings_sys::llama_state_get_size(self.context.as_ptr()) }
244 }
245
246 pub unsafe fn copy_state_data(&self, dest: &mut [u8]) -> usize {
256 unsafe {
257 llama_cpp_bindings_sys::llama_state_get_data(
258 self.context.as_ptr(),
259 dest.as_mut_ptr(),
260 dest.len(),
261 )
262 }
263 }
264
265 pub unsafe fn set_state_data(&mut self, src: &[u8]) -> usize {
275 unsafe {
276 llama_cpp_bindings_sys::llama_state_set_data(
277 self.context.as_ptr(),
278 src.as_ptr(),
279 src.len(),
280 )
281 }
282 }
283
284 #[must_use]
289 pub fn state_seq_get_size_ext(&self, seq_id: i32, flags: &LlamaStateSeqFlags) -> usize {
290 unsafe {
291 llama_cpp_bindings_sys::llama_state_seq_get_size_ext(
292 self.context.as_ptr(),
293 seq_id,
294 flags.bits(),
295 )
296 }
297 }
298
299 pub unsafe fn state_seq_get_data_ext(
310 &self,
311 dest: &mut [u8],
312 seq_id: i32,
313 flags: &LlamaStateSeqFlags,
314 ) -> usize {
315 unsafe {
316 llama_cpp_bindings_sys::llama_state_seq_get_data_ext(
317 self.context.as_ptr(),
318 dest.as_mut_ptr(),
319 dest.len(),
320 seq_id,
321 flags.bits(),
322 )
323 }
324 }
325
326 pub unsafe fn state_seq_set_data_ext(
335 &mut self,
336 src: &[u8],
337 dest_seq_id: i32,
338 flags: &LlamaStateSeqFlags,
339 ) -> usize {
340 unsafe {
341 llama_cpp_bindings_sys::llama_state_seq_set_data_ext(
342 self.context.as_ptr(),
343 src.as_ptr(),
344 src.len(),
345 dest_seq_id,
346 flags.bits(),
347 )
348 }
349 }
350}
351
#[cfg(test)]
mod unit_tests {
    use super::{process_seq_load_result, process_session_load_result};
    use crate::context::load_seq_state_error::LoadSeqStateError;
    use crate::context::load_session_error::LoadSessionError;
    use crate::token::LlamaToken;

    /// Size of every test buffer; also passed as the `max_tokens` limit.
    const LIMIT: usize = 100;

    /// Builds a buffer of `LIMIT` tokens, each created with `LlamaToken::new(0)`.
    fn filled_buffer() -> Vec<LlamaToken> {
        vec![LlamaToken::new(0); LIMIT]
    }

    // A successful session load within the limit yields exactly `n_out` tokens.
    #[test]
    fn session_load_success_within_bounds() {
        let outcome = process_session_load_result(true, 10, LIMIT, filled_buffer());

        let loaded = outcome.expect("a successful load within bounds must be accepted");
        assert_eq!(loaded.len(), 10);
    }

    // A `false` success flag maps to `FailedToLoad`.
    #[test]
    fn session_load_fails_when_not_successful() {
        let outcome = process_session_load_result(false, 0, LIMIT, filled_buffer());

        assert_eq!(outcome.unwrap_err(), LoadSessionError::FailedToLoad);
    }

    // A token count above the limit maps to `InsufficientMaxLength`.
    #[test]
    fn session_load_fails_when_n_out_exceeds_max() {
        let outcome = process_session_load_result(true, 101, LIMIT, filled_buffer());

        let expected = LoadSessionError::InsufficientMaxLength {
            n_out: 101,
            max_tokens: 100,
        };
        assert_eq!(outcome.unwrap_err(), expected);
    }

    // A successful sequence load returns `n_out` tokens and echoes the byte count.
    #[test]
    fn seq_load_success_within_bounds() {
        let outcome = process_seq_load_result(42, 10, LIMIT, filled_buffer());

        let (loaded, bytes) =
            outcome.expect("a successful load within bounds must be accepted");
        assert_eq!(loaded.len(), 10);
        assert_eq!(bytes, 42);
    }

    // Zero bytes read maps to `FailedToLoad`.
    #[test]
    fn seq_load_fails_when_zero_bytes_read() {
        let outcome = process_seq_load_result(0, 0, LIMIT, filled_buffer());

        assert_eq!(outcome.unwrap_err(), LoadSeqStateError::FailedToLoad);
    }

    // A token count above the limit maps to `InsufficientMaxLength`.
    #[test]
    fn seq_load_fails_when_n_out_exceeds_max() {
        let outcome = process_seq_load_result(42, 101, LIMIT, filled_buffer());

        let expected = LoadSeqStateError::InsufficientMaxLength {
            n_out: 101,
            max_tokens: 100,
        };
        assert_eq!(outcome.unwrap_err(), expected);
    }
}