llama_cpp_2/context/session.rs

//! utilities for working with session files

use crate::context::LlamaContext;
use crate::token::LlamaToken;
use std::ffi::{CString, NulError};
use std::path::{Path, PathBuf};

/// Failed to save a Session file
#[derive(Debug, Eq, PartialEq, thiserror::Error)]
pub enum SaveSessionError {
    /// llama.cpp failed to save the session file
    #[error("Failed to save session file")]
    FailedToSave,

    /// null byte in string
    #[error("null byte in string {0}")]
    NullError(#[from] NulError),

    /// failed to convert path to str
    #[error("failed to convert path {0} to str")]
    PathToStrError(PathBuf),
}

/// Failed to load a Session file
#[derive(Debug, Eq, PartialEq, thiserror::Error)]
pub enum LoadSessionError {
    /// llama.cpp failed to load the session file
    #[error("Failed to load session file")]
    FailedToLoad,

    /// null byte in string
    #[error("null byte in string {0}")]
    NullError(#[from] NulError),

    /// failed to convert path to str
    #[error("failed to convert path {0} to str")]
    PathToStrError(PathBuf),

    /// Insufficient max token length
    #[error("max_tokens ({max_tokens}) is not large enough to hold the {n_out} tokens in the session file")]
    InsufficientMaxLength {
        /// The number of tokens in the session file
        n_out: usize,
        /// The maximum number of tokens that could be returned
        max_tokens: usize,
    },
}

impl LlamaContext<'_> {
    /// Save the current session to a file.
    ///
    /// # Parameters
    ///
    /// * `path_session` - The file to save to.
    /// * `tokens` - The tokens to associate the session with. This should be a prefix of a sequence of tokens that the context has processed, so that the relevant KV caches are already filled.
    ///
    /// # Errors
    ///
    /// Fails if the path is not valid UTF-8, is not a valid C string, or llama.cpp fails to save the session file.
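    ///
    /// # Example
    ///
    /// A minimal usage sketch (not from the upstream docs); it assumes `ctx` is an
    /// already-constructed [`LlamaContext`] and `tokens` holds tokens the context has
    /// already evaluated, so their KV cache entries exist.
    ///
    /// ```ignore
    /// // Hypothetical bindings: `ctx: LlamaContext` and `tokens: Vec<LlamaToken>`
    /// // are assumed to have been created elsewhere.
    /// ctx.save_session_file("session.bin", &tokens)?;
    /// ```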
    pub fn save_session_file(
        &self,
        path_session: impl AsRef<Path>,
        tokens: &[LlamaToken],
    ) -> Result<(), SaveSessionError> {
        let path = path_session.as_ref();
        let path = path
            .to_str()
            .ok_or_else(|| SaveSessionError::PathToStrError(path.to_path_buf()))?;

        let cstr = CString::new(path)?;

        if unsafe {
            llama_cpp_sys_2::llama_save_session_file(
                self.context.as_ptr(),
                cstr.as_ptr(),
                tokens.as_ptr().cast::<llama_cpp_sys_2::llama_token>(),
                tokens.len(),
            )
        } {
            Ok(())
        } else {
            Err(SaveSessionError::FailedToSave)
        }
    }

    /// Load a session file into the current context.
    ///
    /// You still need to pass the returned tokens to the context for inference to work. What this function buys you is that the KV caches are already filled with the relevant data.
    ///
    /// # Parameters
    ///
    /// * `path_session` - The file to load from. It must be a session file from a compatible context, otherwise the function will error.
    /// * `max_tokens` - The maximum token length of the loaded session. If the session was saved with a longer length, the function will error.
    ///
    /// # Errors
    ///
    /// Fails if the path is not valid UTF-8, is not a valid C string, or llama.cpp fails to load the session file (e.g. the file does not exist or is not a session file).
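    ///
    /// # Example
    ///
    /// A minimal usage sketch (not from the upstream docs); it assumes `ctx` is an
    /// already-constructed [`LlamaContext`] and that a session was previously written
    /// to the (hypothetical) path `"session.bin"`.
    ///
    /// ```ignore
    /// // `ctx` must be mutable here since loading modifies the context state.
    /// // Allow up to 2048 tokens; the call errors if the saved session holds more.
    /// let tokens = ctx.load_session_file("session.bin", 2048)?;
    /// // The returned tokens still have to be fed to the context (e.g. as the
    /// // prompt); this call only restores the KV cache for them.
    /// ```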
    pub fn load_session_file(
        &mut self,
        path_session: impl AsRef<Path>,
        max_tokens: usize,
    ) -> Result<Vec<LlamaToken>, LoadSessionError> {
        let path = path_session.as_ref();
        let path = path
            .to_str()
            .ok_or(LoadSessionError::PathToStrError(path.to_path_buf()))?;

        let cstr = CString::new(path)?;
        let mut tokens: Vec<LlamaToken> = Vec::with_capacity(max_tokens);
        let mut n_out = 0;

        // SAFETY: cast is valid as LlamaToken is repr(transparent)
        let tokens_out = tokens.as_mut_ptr().cast::<llama_cpp_sys_2::llama_token>();

        let load_session_success = unsafe {
            llama_cpp_sys_2::llama_load_session_file(
                self.context.as_ptr(),
                cstr.as_ptr(),
                tokens_out,
                max_tokens,
                &mut n_out,
            )
        };
        if load_session_success {
            if n_out > max_tokens {
                return Err(LoadSessionError::InsufficientMaxLength { n_out, max_tokens });
            }
            // SAFETY: we checked that n_out <= max_tokens and llama.cpp promises that n_out tokens will be written
            unsafe {
                tokens.set_len(n_out);
            }
            Ok(tokens)
        } else {
            Err(LoadSessionError::FailedToLoad)
        }
    }

    /// Returns the maximum size in bytes of the state (rng, logits, embedding
    /// and `kv_cache`). The actual state will often be smaller after compacting tokens.
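    ///
    /// # Example
    ///
    /// A hedged sketch; assumes `ctx` is an existing [`LlamaContext`].
    ///
    /// ```ignore
    /// // Upper bound on the bytes needed to snapshot the context state.
    /// let max_bytes = ctx.get_state_size();
    /// let buf = vec![0u8; max_bytes];
    /// ```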
    #[must_use]
    pub fn get_state_size(&self) -> usize {
        unsafe { llama_cpp_sys_2::llama_get_state_size(self.context.as_ptr()) }
    }

    /// Copies the state to the specified destination address.
    ///
    /// Returns the number of bytes copied.
    ///
    /// # Safety
    ///
    /// The destination must have enough memory allocated; [`Self::get_state_size`] returns an upper bound on the required size.
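    ///
    /// # Example
    ///
    /// A hedged sketch pairing this with [`Self::get_state_size`]; assumes `ctx` is an
    /// existing [`LlamaContext`]. Allocating `get_state_size()` bytes is what makes the
    /// unsafe call sound here.
    ///
    /// ```ignore
    /// let mut buf = vec![0u8; ctx.get_state_size()];
    /// // SAFETY: `buf` is at least as large as the maximum state size.
    /// let written = unsafe { ctx.copy_state_data(buf.as_mut_ptr()) };
    /// buf.truncate(written);
    /// ```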
    pub unsafe fn copy_state_data(&self, dest: *mut u8) -> usize {
        unsafe { llama_cpp_sys_2::llama_copy_state_data(self.context.as_ptr(), dest) }
    }

    /// Set the state, reading from the specified address.
    ///
    /// Returns the number of bytes read.
    ///
    /// # Safety
    ///
    /// help wanted: not entirely sure what the safety requirements are here.
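    ///
    /// # Example
    ///
    /// A hedged sketch; assumes `buf` holds a complete state blob previously produced
    /// by [`Self::copy_state_data`] on a compatible context.
    ///
    /// ```ignore
    /// // `read` is the number of bytes llama.cpp consumed from `buf`.
    /// let read = unsafe { ctx.set_state_data(&buf) };
    /// ```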
    pub unsafe fn set_state_data(&mut self, src: &[u8]) -> usize {
        unsafe { llama_cpp_sys_2::llama_set_state_data(self.context.as_ptr(), src.as_ptr()) }
    }
}