Skip to main content

llama_cpp_bindings/context/
kv_cache.rs

1//! utilities for working with the kv cache
2
3use crate::context::LlamaContext;
4use std::ffi::c_int;
5use std::num::{NonZeroU8, TryFromIntError};
6
7/// Errors that can occur when attempting to prepare values for the kv cache
8#[derive(Debug, Eq, PartialEq, thiserror::Error)]
9pub enum KvCacheConversionError {
10    /// Sequence id conversion to i32 failed
11    #[error("Provided sequence id is too large for a i32")]
12    SeqIdTooLarge(#[source] TryFromIntError),
13    /// Position 0 conversion to i32 failed
14    #[error("Provided start position is too large for a i32")]
15    P0TooLarge(#[source] TryFromIntError),
16    /// Position 1 conversion to i32 failed
17    #[error("Provided end position is too large for a i32")]
18    P1TooLarge(#[source] TryFromIntError),
19    /// The operation is not supported by the current model/context configuration.
20    #[error("operation not supported by this model: {0}")]
21    UnsupportedOperation(String),
22}
23
24impl LlamaContext<'_> {
25    /// Copy the cache from one sequence to another.
26    ///
27    /// # Parameters
28    ///
29    /// * `src` - The sequence id to copy the cache from.
30    /// * `dest` - The sequence id to copy the cache to.
31    /// * `size` - The size of the cache to copy.
32    pub fn copy_cache(&mut self, src: i32, dest: i32, size: i32) {
33        let mem = unsafe { llama_cpp_bindings_sys::llama_get_memory(self.context.as_ptr()) };
34        unsafe { llama_cpp_bindings_sys::llama_memory_seq_cp(mem, src, dest, 0, size) }
35    }
36
37    /// Copy the cache from one sequence to another.
38    ///
39    /// # Returns
40    /// A `Result` indicating whether the operation was successful.
41    ///
42    /// # Parameters
43    /// * `src` - The sequence id to copy the cache from.
44    /// * `dest` - The sequence id to copy the cache to.
45    /// * `p0` - The start position of the cache to clear. If `None`, the entire cache is copied up to `p1`.
46    /// * `p1` - The end position of the cache to clear. If `None`, the entire cache is copied starting from `p0`.
47    ///
48    /// # Errors
49    /// If either position exceeds [`i32::MAX`].
50    pub fn copy_kv_cache_seq(
51        &mut self,
52        src: i32,
53        dest: i32,
54        p0: Option<u32>,
55        p1: Option<u32>,
56    ) -> Result<(), KvCacheConversionError> {
57        let p0 = p0
58            .map_or(Ok(-1), i32::try_from)
59            .map_err(KvCacheConversionError::P0TooLarge)?;
60        let p1 = p1
61            .map_or(Ok(-1), i32::try_from)
62            .map_err(KvCacheConversionError::P1TooLarge)?;
63        let mem = unsafe { llama_cpp_bindings_sys::llama_get_memory(self.context.as_ptr()) };
64        unsafe { llama_cpp_bindings_sys::llama_memory_seq_cp(mem, src, dest, p0, p1) };
65        Ok(())
66    }
67
68    /// Clear the kv cache for the given sequence within the specified range `[p0, p1)`
69    /// Returns `false` only when partial sequence removals fail. Full sequence removals always succeed.
70    ///
71    /// # Returns
72    /// A `Result` indicating whether the operation was successful. If the sequence id or
73    /// either position exceeds the maximum i32 value, no removal is attempted and an `Err` is returned.
74    ///
75    /// # Parameters
76    /// * `src` - The sequence id to clear the cache for. If `None`, matches all sequences
77    /// * `p0` - The start position of the cache to clear. If `None`, the entire cache is cleared up to `p1`.
78    /// * `p1` - The end position of the cache to clear. If `None`, the entire cache is cleared from `p0`.
79    ///
80    /// # Errors
81    /// If the sequence id or either position exceeds [`i32::MAX`].
82    pub fn clear_kv_cache_seq(
83        &mut self,
84        src: Option<u32>,
85        p0: Option<u32>,
86        p1: Option<u32>,
87    ) -> Result<bool, KvCacheConversionError> {
88        let src = src
89            .map_or(Ok(-1), i32::try_from)
90            .map_err(KvCacheConversionError::SeqIdTooLarge)?;
91        let p0 = p0
92            .map_or(Ok(-1), i32::try_from)
93            .map_err(KvCacheConversionError::P0TooLarge)?;
94        let p1 = p1
95            .map_or(Ok(-1), i32::try_from)
96            .map_err(KvCacheConversionError::P1TooLarge)?;
97        let mem = unsafe { llama_cpp_bindings_sys::llama_get_memory(self.context.as_ptr()) };
98        Ok(unsafe { llama_cpp_bindings_sys::llama_memory_seq_rm(mem, src, p0, p1) })
99    }
100
101    /// Clear the KV cache, including both metadata and the underlying data buffers.
102    pub fn clear_kv_cache(&mut self) {
103        let mem = unsafe { llama_cpp_bindings_sys::llama_get_memory(self.context.as_ptr()) };
104        let clear_data_buffers = true;
105        unsafe { llama_cpp_bindings_sys::llama_memory_clear(mem, clear_data_buffers) }
106    }
107
108    /// Removes all tokens that do not belong to the specified sequence
109    ///
110    /// # Parameters
111    ///
112    /// * `seq_id` - The sequence id to keep
113    pub fn kv_cache_seq_keep(&mut self, seq_id: i32) {
114        let mem = unsafe { llama_cpp_bindings_sys::llama_get_memory(self.context.as_ptr()) };
115        unsafe { llama_cpp_bindings_sys::llama_memory_seq_keep(mem, seq_id) }
116    }
117
118    /// Adds relative position "delta" to all tokens that belong to the specified sequence and have positions in `[p0, p1)`
119    /// If the KV cache is `RoPEd`, the KV data is updated accordingly:
120    ///   - lazily on next [`LlamaContext::decode`]
121    ///   - explicitly with [`Self::kv_cache_update`]
122    ///
123    /// # Returns
124    /// A `Result` indicating whether the operation was successful.
125    ///
126    /// # Parameters
127    ///
128    /// * `seq_id` - The sequence id to update
129    /// * `p0` - The start position of the cache to update. If `None`, the entire cache is updated up to `p1`.
130    /// * `p1` - The end position of the cache to update. If `None`, the entire cache is updated starting from `p0`.
131    /// * `delta` - The relative position to add to the tokens
132    ///
133    /// # Errors
134    /// If either position exceeds [`i32::MAX`].
135    pub fn kv_cache_seq_add(
136        &mut self,
137        seq_id: i32,
138        p0: Option<u32>,
139        p1: Option<u32>,
140        delta: i32,
141    ) -> Result<(), KvCacheConversionError> {
142        let p0 = p0
143            .map_or(Ok(-1), i32::try_from)
144            .map_err(KvCacheConversionError::P0TooLarge)?;
145        let p1 = p1
146            .map_or(Ok(-1), i32::try_from)
147            .map_err(KvCacheConversionError::P1TooLarge)?;
148        let status = unsafe {
149            llama_cpp_bindings_sys::llama_rs_memory_seq_add(
150                self.context.as_ptr(),
151                seq_id,
152                p0,
153                p1,
154                delta,
155            )
156        };
157
158        if crate::status_is_ok(status) {
159            Ok(())
160        } else {
161            Err(KvCacheConversionError::UnsupportedOperation(format!(
162                "kv_cache_seq_add failed (status {})",
163                crate::status_to_i32(status)
164            )))
165        }
166    }
167
168    /// Integer division of the positions by factor of `d > 1`
169    /// If the KV cache is `RoPEd`, the KV data is updated accordingly:
170    ///   - lazily on next [`LlamaContext::decode`]
171    ///   - explicitly with [`Self::kv_cache_update`]
172    ///
173    /// # Returns
174    /// A `Result` indicating whether the operation was successful.
175    ///
176    /// # Parameters
177    ///
178    /// * `seq_id` - The sequence id to update
179    /// * `p0` - The start position of the cache to update. If `None`, the entire cache is updated up to `p1`.
180    /// * `p1` - The end position of the cache to update. If `None`, the entire cache is updated starting from `p0`.
181    /// * `d` - The factor to divide the positions by
182    ///
183    /// # Errors
184    /// If either position exceeds [`i32::MAX`].
185    pub fn kv_cache_seq_div(
186        &mut self,
187        seq_id: i32,
188        p0: Option<u32>,
189        p1: Option<u32>,
190        d: NonZeroU8,
191    ) -> Result<(), KvCacheConversionError> {
192        let p0 = p0
193            .map_or(Ok(-1), i32::try_from)
194            .map_err(KvCacheConversionError::P0TooLarge)?;
195        let p1 = p1
196            .map_or(Ok(-1), i32::try_from)
197            .map_err(KvCacheConversionError::P1TooLarge)?;
198        let d = c_int::from(d.get());
199        let status = unsafe {
200            llama_cpp_bindings_sys::llama_rs_memory_seq_div(
201                self.context.as_ptr(),
202                seq_id,
203                p0,
204                p1,
205                d,
206            )
207        };
208
209        if crate::status_is_ok(status) {
210            Ok(())
211        } else {
212            Err(KvCacheConversionError::UnsupportedOperation(format!(
213                "kv_cache_seq_div failed (status {})",
214                crate::status_to_i32(status)
215            )))
216        }
217    }
218
219    /// Returns the largest position present in the KV cache for the specified sequence
220    ///
221    /// # Parameters
222    ///
223    /// * `seq_id` - The sequence id to get the max position for
224    #[must_use]
225    pub fn kv_cache_seq_pos_max(&self, seq_id: i32) -> i32 {
226        unsafe {
227            llama_cpp_bindings_sys::llama_rs_memory_seq_pos_max(self.context.as_ptr(), seq_id)
228        }
229    }
230}