Skip to main content

llama_cpp_bindings/context/
session.rs

1use crate::context::LlamaContext;
2use crate::context::llama_state_seq_flags::LlamaStateSeqFlags;
3use crate::context::load_seq_state_error::LoadSeqStateError;
4use crate::context::load_session_error::LoadSessionError;
5use crate::context::save_seq_state_error::SaveSeqStateError;
6use crate::context::save_session_error::SaveSessionError;
7use crate::token::LlamaToken;
8use std::ffi::CString;
9use std::path::Path;
10
11fn process_session_load_result(
12    success: bool,
13    n_out: usize,
14    max_tokens: usize,
15    mut tokens: Vec<LlamaToken>,
16) -> Result<Vec<LlamaToken>, LoadSessionError> {
17    if !success {
18        return Err(LoadSessionError::FailedToLoad);
19    }
20
21    if n_out > max_tokens {
22        return Err(LoadSessionError::InsufficientMaxLength { n_out, max_tokens });
23    }
24
25    unsafe { tokens.set_len(n_out) };
26
27    Ok(tokens)
28}
29
30fn process_seq_load_result(
31    bytes_read: usize,
32    n_out: usize,
33    max_tokens: usize,
34    mut tokens: Vec<LlamaToken>,
35) -> Result<(Vec<LlamaToken>, usize), LoadSeqStateError> {
36    if bytes_read == 0 {
37        return Err(LoadSeqStateError::FailedToLoad);
38    }
39
40    if n_out > max_tokens {
41        return Err(LoadSeqStateError::InsufficientMaxLength { n_out, max_tokens });
42    }
43
44    unsafe { tokens.set_len(n_out) };
45
46    Ok((tokens, bytes_read))
47}
48
49impl LlamaContext<'_> {
50    /// # Errors
51    ///
52    /// Fails if the path is not a valid utf8 or llama.cpp fails to save the state file.
53    pub fn state_save_file(
54        &self,
55        path_session: impl AsRef<Path>,
56        tokens: &[LlamaToken],
57    ) -> Result<(), SaveSessionError> {
58        let path = path_session.as_ref();
59        let path = path
60            .to_str()
61            .ok_or_else(|| SaveSessionError::PathToStrError(path.to_path_buf()))?;
62
63        let cstr = CString::new(path)?;
64
65        if unsafe {
66            llama_cpp_bindings_sys::llama_state_save_file(
67                self.context.as_ptr(),
68                cstr.as_ptr(),
69                tokens
70                    .as_ptr()
71                    .cast::<llama_cpp_bindings_sys::llama_token>(),
72                tokens.len(),
73            )
74        } {
75            Ok(())
76        } else {
77            Err(SaveSessionError::FailedToSave)
78        }
79    }
80
81    /// # Errors
82    ///
83    /// Fails if the path is not a valid utf8 or llama.cpp fails to load the state file.
84    pub fn state_load_file(
85        &mut self,
86        path_session: impl AsRef<Path>,
87        max_tokens: usize,
88    ) -> Result<Vec<LlamaToken>, LoadSessionError> {
89        let path = path_session.as_ref();
90        let path = path
91            .to_str()
92            .ok_or_else(|| LoadSessionError::PathToStrError(path.to_path_buf()))?;
93
94        let cstr = CString::new(path)?;
95        let mut tokens: Vec<LlamaToken> = Vec::with_capacity(max_tokens);
96        let mut n_out = 0;
97
98        // SAFETY: cast is valid as LlamaToken is repr(transparent)
99        let tokens_out = tokens
100            .as_mut_ptr()
101            .cast::<llama_cpp_bindings_sys::llama_token>();
102
103        let success = unsafe {
104            llama_cpp_bindings_sys::llama_state_load_file(
105                self.context.as_ptr(),
106                cstr.as_ptr(),
107                tokens_out,
108                max_tokens,
109                &raw mut n_out,
110            )
111        };
112        process_session_load_result(success, n_out, max_tokens, tokens)
113    }
114
115    /// # Errors
116    ///
117    /// Fails if the path is not a valid utf8 or llama.cpp fails to save the sequence state file.
118    ///
119    pub fn state_seq_save_file(
120        &self,
121        filepath: impl AsRef<Path>,
122        seq_id: i32,
123        tokens: &[LlamaToken],
124    ) -> Result<usize, SaveSeqStateError> {
125        let path = filepath.as_ref();
126        let path = path
127            .to_str()
128            .ok_or_else(|| SaveSeqStateError::PathToStrError(path.to_path_buf()))?;
129
130        let cstr = CString::new(path)?;
131
132        let bytes_written = unsafe {
133            llama_cpp_bindings_sys::llama_state_seq_save_file(
134                self.context.as_ptr(),
135                cstr.as_ptr(),
136                seq_id,
137                tokens
138                    .as_ptr()
139                    .cast::<llama_cpp_bindings_sys::llama_token>(),
140                tokens.len(),
141            )
142        };
143
144        if bytes_written == 0 {
145            Err(SaveSeqStateError::FailedToSave)
146        } else {
147            Ok(bytes_written)
148        }
149    }
150
151    /// # Errors
152    ///
153    /// Fails if the path is not a valid utf8 or llama.cpp fails to load the sequence state file.
154    ///
155    pub fn state_seq_load_file(
156        &mut self,
157        filepath: impl AsRef<Path>,
158        dest_seq_id: i32,
159        max_tokens: usize,
160    ) -> Result<(Vec<LlamaToken>, usize), LoadSeqStateError> {
161        let path = filepath.as_ref();
162        let path = path
163            .to_str()
164            .ok_or_else(|| LoadSeqStateError::PathToStrError(path.to_path_buf()))?;
165
166        let cstr = CString::new(path)?;
167        let mut tokens: Vec<LlamaToken> = Vec::with_capacity(max_tokens);
168        let mut n_out = 0;
169
170        // SAFETY: cast is valid as LlamaToken is repr(transparent)
171        let tokens_out = tokens
172            .as_mut_ptr()
173            .cast::<llama_cpp_bindings_sys::llama_token>();
174
175        let bytes_read = unsafe {
176            llama_cpp_bindings_sys::llama_state_seq_load_file(
177                self.context.as_ptr(),
178                cstr.as_ptr(),
179                dest_seq_id,
180                tokens_out,
181                max_tokens,
182                &raw mut n_out,
183            )
184        };
185
186        process_seq_load_result(bytes_read, n_out, max_tokens, tokens)
187    }
188
189    #[must_use]
190    pub fn get_state_size(&self) -> usize {
191        unsafe { llama_cpp_bindings_sys::llama_state_get_size(self.context.as_ptr()) }
192    }
193
194    /// # Safety
195    ///
196    /// The `dest` buffer must be large enough to hold the complete state data.
197    pub unsafe fn copy_state_data(&self, dest: &mut [u8]) -> usize {
198        unsafe {
199            llama_cpp_bindings_sys::llama_state_get_data(
200                self.context.as_ptr(),
201                dest.as_mut_ptr(),
202                dest.len(),
203            )
204        }
205    }
206
207    /// # Safety
208    ///
209    /// The `src` buffer must contain data previously obtained from [`copy_state_data`](Self::copy_state_data)
210    /// on a compatible context (same model and parameters). Passing arbitrary or corrupted bytes
211    /// will lead to undefined behavior.
212    pub unsafe fn set_state_data(&mut self, src: &[u8]) -> usize {
213        unsafe {
214            llama_cpp_bindings_sys::llama_state_set_data(
215                self.context.as_ptr(),
216                src.as_ptr(),
217                src.len(),
218            )
219        }
220    }
221
222    #[must_use]
223    pub fn state_seq_get_size_ext(&self, seq_id: i32, flags: &LlamaStateSeqFlags) -> usize {
224        unsafe {
225            llama_cpp_bindings_sys::llama_state_seq_get_size_ext(
226                self.context.as_ptr(),
227                seq_id,
228                flags.bits(),
229            )
230        }
231    }
232
233    /// # Safety
234    ///
235    /// The `dest` buffer must be large enough to hold the complete state data.
236    pub unsafe fn state_seq_get_data_ext(
237        &self,
238        dest: &mut [u8],
239        seq_id: i32,
240        flags: &LlamaStateSeqFlags,
241    ) -> usize {
242        unsafe {
243            llama_cpp_bindings_sys::llama_state_seq_get_data_ext(
244                self.context.as_ptr(),
245                dest.as_mut_ptr(),
246                dest.len(),
247                seq_id,
248                flags.bits(),
249            )
250        }
251    }
252
253    /// # Safety
254    ///
255    /// The `src` buffer must contain data previously obtained from
256    /// [`state_seq_get_data_ext`](Self::state_seq_get_data_ext) on a compatible context.
257    pub unsafe fn state_seq_set_data_ext(
258        &mut self,
259        src: &[u8],
260        dest_seq_id: i32,
261        flags: &LlamaStateSeqFlags,
262    ) -> usize {
263        unsafe {
264            llama_cpp_bindings_sys::llama_state_seq_set_data_ext(
265                self.context.as_ptr(),
266                src.as_ptr(),
267                src.len(),
268                dest_seq_id,
269                flags.bits(),
270            )
271        }
272    }
273}
274
#[cfg(test)]
mod unit_tests {
    use crate::token::LlamaToken;

    use crate::context::load_seq_state_error::LoadSeqStateError;
    use crate::context::load_session_error::LoadSessionError;

    use super::{process_seq_load_result, process_session_load_result};

    /// Size of the fully initialized buffer handed to the helpers in every test.
    const BUFFER_LEN: usize = 100;

    /// Builds a buffer of `BUFFER_LEN` zero tokens, so every element is initialized.
    fn token_buffer() -> Vec<LlamaToken> {
        vec![LlamaToken::new(0); BUFFER_LEN]
    }

    #[test]
    fn session_load_success_within_bounds() {
        let result = process_session_load_result(true, 10, BUFFER_LEN, token_buffer());

        assert!(result.is_ok());
        assert_eq!(result.unwrap().len(), 10);
    }

    // Boundary: a token count exactly equal to the limit is accepted.
    #[test]
    fn session_load_success_at_exact_limit() {
        let result = process_session_load_result(true, BUFFER_LEN, BUFFER_LEN, token_buffer());

        assert!(result.is_ok());
        assert_eq!(result.unwrap().len(), BUFFER_LEN);
    }

    // Boundary: a successful load that reports no tokens yields an empty vector.
    #[test]
    fn session_load_success_with_zero_tokens() {
        let result = process_session_load_result(true, 0, BUFFER_LEN, token_buffer());

        assert!(result.is_ok());
        assert!(result.unwrap().is_empty());
    }

    #[test]
    fn session_load_fails_when_not_successful() {
        let result = process_session_load_result(false, 0, BUFFER_LEN, token_buffer());

        assert_eq!(result, Err(LoadSessionError::FailedToLoad));
    }

    // The failure flag is checked before the token count.
    #[test]
    fn session_load_failure_takes_precedence_over_length() {
        let result = process_session_load_result(false, 101, BUFFER_LEN, token_buffer());

        assert_eq!(result, Err(LoadSessionError::FailedToLoad));
    }

    #[test]
    fn session_load_fails_when_n_out_exceeds_max() {
        let result = process_session_load_result(true, 101, BUFFER_LEN, token_buffer());

        assert_eq!(
            result,
            Err(LoadSessionError::InsufficientMaxLength {
                n_out: 101,
                max_tokens: 100,
            })
        );
    }

    #[test]
    fn seq_load_success_within_bounds() {
        let result = process_seq_load_result(42, 10, BUFFER_LEN, token_buffer());

        assert!(result.is_ok());
        let (loaded, bytes) = result.unwrap();
        assert_eq!(loaded.len(), 10);
        assert_eq!(bytes, 42);
    }

    // Boundary: a token count exactly equal to the limit is accepted.
    #[test]
    fn seq_load_success_at_exact_limit() {
        let result = process_seq_load_result(42, BUFFER_LEN, BUFFER_LEN, token_buffer());

        assert!(result.is_ok());
        let (loaded, bytes) = result.unwrap();
        assert_eq!(loaded.len(), BUFFER_LEN);
        assert_eq!(bytes, 42);
    }

    // Boundary: a non-zero byte count with no tokens yields an empty vector.
    #[test]
    fn seq_load_success_with_zero_tokens() {
        let result = process_seq_load_result(1, 0, BUFFER_LEN, token_buffer());

        assert!(result.is_ok());
        let (loaded, bytes) = result.unwrap();
        assert!(loaded.is_empty());
        assert_eq!(bytes, 1);
    }

    #[test]
    fn seq_load_fails_when_zero_bytes_read() {
        let result = process_seq_load_result(0, 0, BUFFER_LEN, token_buffer());

        assert_eq!(result, Err(LoadSeqStateError::FailedToLoad));
    }

    // The zero-byte failure is checked before the token count.
    #[test]
    fn seq_load_failure_takes_precedence_over_length() {
        let result = process_seq_load_result(0, 101, BUFFER_LEN, token_buffer());

        assert_eq!(result, Err(LoadSeqStateError::FailedToLoad));
    }

    #[test]
    fn seq_load_fails_when_n_out_exceeds_max() {
        let result = process_seq_load_result(42, 101, BUFFER_LEN, token_buffer());

        assert_eq!(
            result,
            Err(LoadSeqStateError::InsufficientMaxLength {
                n_out: 101,
                max_tokens: 100,
            })
        );
    }
}