Skip to main content

llama_cpp_bindings/context/
kv_cache.rs

//! Utilities for working with the KV cache.
2
3use std::ffi::c_int;
4use std::num::{NonZeroU8, TryFromIntError};
5use std::os::raw::c_char;
6use std::ptr;
7
8use crate::context::LlamaContext;
9use crate::error::{KvCacheSeqAddError, KvCacheSeqDivError};
10use crate::ffi_error_reader::read_and_free_cpp_error;
11
12/// Errors that can occur when attempting to prepare values for the kv cache
13#[derive(Debug, Eq, PartialEq, thiserror::Error)]
14pub enum KvCacheConversionError {
15    /// Sequence id conversion to i32 failed
16    #[error("Provided sequence id is too large for a i32")]
17    SeqIdTooLarge(#[source] TryFromIntError),
18    /// Position 0 conversion to i32 failed
19    #[error("Provided start position is too large for a i32")]
20    P0TooLarge(#[source] TryFromIntError),
21    /// Position 1 conversion to i32 failed
22    #[error("Provided end position is too large for a i32")]
23    P1TooLarge(#[source] TryFromIntError),
24}
25
26impl LlamaContext<'_> {
27    /// Copy the cache from one sequence to another.
28    ///
29    /// # Parameters
30    ///
31    /// * `src` - The sequence id to copy the cache from.
32    /// * `dest` - The sequence id to copy the cache to.
33    /// * `size` - The size of the cache to copy.
34    pub fn copy_cache(&mut self, src: i32, dest: i32, size: i32) {
35        let mem = unsafe { llama_cpp_bindings_sys::llama_get_memory(self.context.as_ptr()) };
36        unsafe { llama_cpp_bindings_sys::llama_memory_seq_cp(mem, src, dest, 0, size) }
37    }
38
39    /// Copy the cache from one sequence to another.
40    ///
41    /// # Returns
42    /// A `Result` indicating whether the operation was successful.
43    ///
44    /// # Parameters
45    /// * `src` - The sequence id to copy the cache from.
46    /// * `dest` - The sequence id to copy the cache to.
47    /// * `p0` - The start position of the cache to clear. If `None`, the entire cache is copied up to `p1`.
48    /// * `p1` - The end position of the cache to clear. If `None`, the entire cache is copied starting from `p0`.
49    ///
50    /// # Errors
51    /// If either position exceeds [`i32::MAX`].
52    pub fn copy_kv_cache_seq(
53        &mut self,
54        src: i32,
55        dest: i32,
56        p0: Option<u32>,
57        p1: Option<u32>,
58    ) -> Result<(), KvCacheConversionError> {
59        let p0 = p0
60            .map_or(Ok(-1), i32::try_from)
61            .map_err(KvCacheConversionError::P0TooLarge)?;
62        let p1 = p1
63            .map_or(Ok(-1), i32::try_from)
64            .map_err(KvCacheConversionError::P1TooLarge)?;
65        let mem = unsafe { llama_cpp_bindings_sys::llama_get_memory(self.context.as_ptr()) };
66        unsafe { llama_cpp_bindings_sys::llama_memory_seq_cp(mem, src, dest, p0, p1) };
67        Ok(())
68    }
69
70    /// Clear the kv cache for the given sequence within the specified range `[p0, p1)`
71    /// Returns `false` only when partial sequence removals fail. Full sequence removals always succeed.
72    ///
73    /// # Returns
74    /// A `Result` indicating whether the operation was successful. If the sequence id or
75    /// either position exceeds the maximum i32 value, no removal is attempted and an `Err` is returned.
76    ///
77    /// # Parameters
78    /// * `src` - The sequence id to clear the cache for. If `None`, matches all sequences
79    /// * `p0` - The start position of the cache to clear. If `None`, the entire cache is cleared up to `p1`.
80    /// * `p1` - The end position of the cache to clear. If `None`, the entire cache is cleared from `p0`.
81    ///
82    /// # Errors
83    /// If the sequence id or either position exceeds [`i32::MAX`].
84    pub fn clear_kv_cache_seq(
85        &mut self,
86        src: Option<u32>,
87        p0: Option<u32>,
88        p1: Option<u32>,
89    ) -> Result<bool, KvCacheConversionError> {
90        let src = src
91            .map_or(Ok(-1), i32::try_from)
92            .map_err(KvCacheConversionError::SeqIdTooLarge)?;
93        let p0 = p0
94            .map_or(Ok(-1), i32::try_from)
95            .map_err(KvCacheConversionError::P0TooLarge)?;
96        let p1 = p1
97            .map_or(Ok(-1), i32::try_from)
98            .map_err(KvCacheConversionError::P1TooLarge)?;
99        let mem = unsafe { llama_cpp_bindings_sys::llama_get_memory(self.context.as_ptr()) };
100        Ok(unsafe { llama_cpp_bindings_sys::llama_memory_seq_rm(mem, src, p0, p1) })
101    }
102
103    /// Clear the KV cache, including both metadata and the underlying data buffers.
104    pub fn clear_kv_cache(&mut self) {
105        let mem = unsafe { llama_cpp_bindings_sys::llama_get_memory(self.context.as_ptr()) };
106        let clear_data_buffers = true;
107        unsafe { llama_cpp_bindings_sys::llama_memory_clear(mem, clear_data_buffers) }
108    }
109
110    /// Removes all tokens that do not belong to the specified sequence
111    ///
112    /// # Parameters
113    ///
114    /// * `seq_id` - The sequence id to keep
115    pub fn kv_cache_seq_keep(&mut self, seq_id: i32) {
116        let mem = unsafe { llama_cpp_bindings_sys::llama_get_memory(self.context.as_ptr()) };
117        unsafe { llama_cpp_bindings_sys::llama_memory_seq_keep(mem, seq_id) }
118    }
119
120    /// Adds relative position "delta" to all tokens that belong to the specified sequence and have positions in `[p0, p1)`
121    /// If the KV cache is `RoPEd`, the KV data is updated accordingly:
122    ///   - lazily on next [`LlamaContext::decode`]
123    ///   - explicitly with [`Self::kv_cache_update`]
124    ///
125    /// # Returns
126    /// A `Result` indicating whether the operation was successful.
127    ///
128    /// # Parameters
129    ///
130    /// * `seq_id` - The sequence id to update
131    /// * `p0` - The start position of the cache to update. If `None`, the entire cache is updated up to `p1`.
132    /// * `p1` - The end position of the cache to update. If `None`, the entire cache is updated starting from `p0`.
133    /// * `delta` - The relative position to add to the tokens
134    ///
135    /// # Errors
136    /// If either position exceeds [`i32::MAX`], or the underlying memory operation reports a failure.
137    pub fn kv_cache_seq_add(
138        &mut self,
139        seq_id: i32,
140        p0: Option<u32>,
141        p1: Option<u32>,
142        delta: i32,
143    ) -> Result<(), KvCacheSeqAddError> {
144        let p0 = p0
145            .map_or(Ok(-1), i32::try_from)
146            .map_err(KvCacheSeqAddError::P0TooLarge)?;
147        let p1 = p1
148            .map_or(Ok(-1), i32::try_from)
149            .map_err(KvCacheSeqAddError::P1TooLarge)?;
150        let mut out_error: *mut c_char = ptr::null_mut();
151        let status = unsafe {
152            llama_cpp_bindings_sys::llama_rs_memory_seq_add(
153                self.context.as_ptr(),
154                seq_id,
155                p0,
156                p1,
157                delta,
158                &raw mut out_error,
159            )
160        };
161        match status {
162            llama_cpp_bindings_sys::LLAMA_RS_MEMORY_SEQ_ADD_OK => Ok(()),
163            llama_cpp_bindings_sys::LLAMA_RS_MEMORY_SEQ_ADD_INCOMPATIBLE_ROPE_TYPE => {
164                Err(KvCacheSeqAddError::IncompatibleRopeType)
165            }
166            llama_cpp_bindings_sys::LLAMA_RS_MEMORY_SEQ_ADD_NULL_MEM => {
167                Err(KvCacheSeqAddError::MemoryHandleUnavailable)
168            }
169            llama_cpp_bindings_sys::LLAMA_RS_MEMORY_SEQ_ADD_ERROR_STRING_ALLOCATION_FAILED => {
170                Err(KvCacheSeqAddError::NotEnoughMemory)
171            }
172            llama_cpp_bindings_sys::LLAMA_RS_MEMORY_SEQ_ADD_VENDORED_THREW_CXX_EXCEPTION => {
173                let message = unsafe { read_and_free_cpp_error(out_error) };
174                Err(KvCacheSeqAddError::Reported { message })
175            }
176            other => unreachable!("llama_rs_memory_seq_add returned unrecognized status {other}"),
177        }
178    }
179
180    /// Integer division of the positions by factor of `d > 1`
181    /// If the KV cache is `RoPEd`, the KV data is updated accordingly:
182    ///   - lazily on next [`LlamaContext::decode`]
183    ///   - explicitly with [`Self::kv_cache_update`]
184    ///
185    /// # Returns
186    /// A `Result` indicating whether the operation was successful.
187    ///
188    /// # Parameters
189    ///
190    /// * `seq_id` - The sequence id to update
191    /// * `p0` - The start position of the cache to update. If `None`, the entire cache is updated up to `p1`.
192    /// * `p1` - The end position of the cache to update. If `None`, the entire cache is updated starting from `p0`.
193    /// * `d` - The factor to divide the positions by
194    ///
195    /// # Errors
196    /// If either position exceeds [`i32::MAX`], or the underlying memory operation reports a failure.
197    pub fn kv_cache_seq_div(
198        &mut self,
199        seq_id: i32,
200        p0: Option<u32>,
201        p1: Option<u32>,
202        d: NonZeroU8,
203    ) -> Result<(), KvCacheSeqDivError> {
204        let p0 = p0
205            .map_or(Ok(-1), i32::try_from)
206            .map_err(KvCacheSeqDivError::P0TooLarge)?;
207        let p1 = p1
208            .map_or(Ok(-1), i32::try_from)
209            .map_err(KvCacheSeqDivError::P1TooLarge)?;
210        let d = c_int::from(d.get());
211        let mut out_error: *mut c_char = ptr::null_mut();
212        let status = unsafe {
213            llama_cpp_bindings_sys::llama_rs_memory_seq_div(
214                self.context.as_ptr(),
215                seq_id,
216                p0,
217                p1,
218                d,
219                &raw mut out_error,
220            )
221        };
222        match status {
223            llama_cpp_bindings_sys::LLAMA_RS_MEMORY_SEQ_DIV_OK => Ok(()),
224            llama_cpp_bindings_sys::LLAMA_RS_MEMORY_SEQ_DIV_INCOMPATIBLE_ROPE_TYPE => {
225                Err(KvCacheSeqDivError::IncompatibleRopeType)
226            }
227            llama_cpp_bindings_sys::LLAMA_RS_MEMORY_SEQ_DIV_NULL_MEM => {
228                Err(KvCacheSeqDivError::MemoryHandleUnavailable)
229            }
230            llama_cpp_bindings_sys::LLAMA_RS_MEMORY_SEQ_DIV_ERROR_STRING_ALLOCATION_FAILED => {
231                Err(KvCacheSeqDivError::NotEnoughMemory)
232            }
233            llama_cpp_bindings_sys::LLAMA_RS_MEMORY_SEQ_DIV_VENDORED_THREW_CXX_EXCEPTION => {
234                let message = unsafe { read_and_free_cpp_error(out_error) };
235                Err(KvCacheSeqDivError::Reported { message })
236            }
237            other => unreachable!("llama_rs_memory_seq_div returned unrecognized status {other}"),
238        }
239    }
240
241    /// Returns the largest position present in the KV cache for the specified sequence
242    ///
243    /// # Parameters
244    ///
245    /// * `seq_id` - The sequence id to get the max position for
246    #[must_use]
247    pub fn kv_cache_seq_pos_max(&self, seq_id: i32) -> i32 {
248        unsafe {
249            llama_cpp_bindings_sys::llama_rs_memory_seq_pos_max(self.context.as_ptr(), seq_id)
250        }
251    }
252}