llama_cpp_bindings/context/kv_cache.rs
//! Utilities for working with the KV cache.
2
3use crate::context::LlamaContext;
4use std::ffi::c_int;
5use std::num::{NonZeroU8, TryFromIntError};
6
/// Errors that can occur when attempting to prepare values for the kv cache
///
/// The `*TooLarge` variants wrap the [`TryFromIntError`] produced when a `u32`
/// argument cannot be represented as the `i32` that is handed to the underlying
/// llama.cpp call. [`KvCacheConversionError::UnsupportedOperation`] is different in
/// kind: it reports that the underlying call itself returned a non-OK status.
#[derive(Debug, Eq, PartialEq, thiserror::Error)]
pub enum KvCacheConversionError {
    /// Sequence id conversion to i32 failed
    ///
    /// Carries the conversion error as its `source`.
    #[error("Provided sequence id is too large for a i32")]
    SeqIdTooLarge(#[source] TryFromIntError),
    /// Position 0 conversion to i32 failed
    ///
    /// "Position 0" is the start (`p0`) of a position range; carries the conversion
    /// error as its `source`.
    #[error("Provided start position is too large for a i32")]
    P0TooLarge(#[source] TryFromIntError),
    /// Position 1 conversion to i32 failed
    ///
    /// "Position 1" is the end (`p1`) of a position range; carries the conversion
    /// error as its `source`.
    #[error("Provided end position is too large for a i32")]
    P1TooLarge(#[source] TryFromIntError),
    /// The operation is not supported by the current model/context configuration.
    ///
    /// The contained string is a human-readable description of the failure that is
    /// interpolated into the error message.
    #[error("operation not supported by this model: {0}")]
    UnsupportedOperation(String),
}
23
24impl LlamaContext<'_> {
25 /// Copy the cache from one sequence to another.
26 ///
27 /// # Parameters
28 ///
29 /// * `src` - The sequence id to copy the cache from.
30 /// * `dest` - The sequence id to copy the cache to.
31 /// * `size` - The size of the cache to copy.
32 pub fn copy_cache(&mut self, src: i32, dest: i32, size: i32) {
33 let mem = unsafe { llama_cpp_bindings_sys::llama_get_memory(self.context.as_ptr()) };
34 unsafe { llama_cpp_bindings_sys::llama_memory_seq_cp(mem, src, dest, 0, size) }
35 }
36
37 /// Copy the cache from one sequence to another.
38 ///
39 /// # Returns
40 /// A `Result` indicating whether the operation was successful.
41 ///
42 /// # Parameters
43 /// * `src` - The sequence id to copy the cache from.
44 /// * `dest` - The sequence id to copy the cache to.
45 /// * `p0` - The start position of the cache to clear. If `None`, the entire cache is copied up to `p1`.
46 /// * `p1` - The end position of the cache to clear. If `None`, the entire cache is copied starting from `p0`.
47 ///
48 /// # Errors
49 /// If either position exceeds [`i32::MAX`].
50 pub fn copy_kv_cache_seq(
51 &mut self,
52 src: i32,
53 dest: i32,
54 p0: Option<u32>,
55 p1: Option<u32>,
56 ) -> Result<(), KvCacheConversionError> {
57 let p0 = p0
58 .map_or(Ok(-1), i32::try_from)
59 .map_err(KvCacheConversionError::P0TooLarge)?;
60 let p1 = p1
61 .map_or(Ok(-1), i32::try_from)
62 .map_err(KvCacheConversionError::P1TooLarge)?;
63 let mem = unsafe { llama_cpp_bindings_sys::llama_get_memory(self.context.as_ptr()) };
64 unsafe { llama_cpp_bindings_sys::llama_memory_seq_cp(mem, src, dest, p0, p1) };
65 Ok(())
66 }
67
68 /// Clear the kv cache for the given sequence within the specified range `[p0, p1)`
69 /// Returns `false` only when partial sequence removals fail. Full sequence removals always succeed.
70 ///
71 /// # Returns
72 /// A `Result` indicating whether the operation was successful. If the sequence id or
73 /// either position exceeds the maximum i32 value, no removal is attempted and an `Err` is returned.
74 ///
75 /// # Parameters
76 /// * `src` - The sequence id to clear the cache for. If `None`, matches all sequences
77 /// * `p0` - The start position of the cache to clear. If `None`, the entire cache is cleared up to `p1`.
78 /// * `p1` - The end position of the cache to clear. If `None`, the entire cache is cleared from `p0`.
79 ///
80 /// # Errors
81 /// If the sequence id or either position exceeds [`i32::MAX`].
82 pub fn clear_kv_cache_seq(
83 &mut self,
84 src: Option<u32>,
85 p0: Option<u32>,
86 p1: Option<u32>,
87 ) -> Result<bool, KvCacheConversionError> {
88 let src = src
89 .map_or(Ok(-1), i32::try_from)
90 .map_err(KvCacheConversionError::SeqIdTooLarge)?;
91 let p0 = p0
92 .map_or(Ok(-1), i32::try_from)
93 .map_err(KvCacheConversionError::P0TooLarge)?;
94 let p1 = p1
95 .map_or(Ok(-1), i32::try_from)
96 .map_err(KvCacheConversionError::P1TooLarge)?;
97 let mem = unsafe { llama_cpp_bindings_sys::llama_get_memory(self.context.as_ptr()) };
98 Ok(unsafe { llama_cpp_bindings_sys::llama_memory_seq_rm(mem, src, p0, p1) })
99 }
100
101 /// Clear the KV cache, including both metadata and the underlying data buffers.
102 pub fn clear_kv_cache(&mut self) {
103 let mem = unsafe { llama_cpp_bindings_sys::llama_get_memory(self.context.as_ptr()) };
104 let clear_data_buffers = true;
105 unsafe { llama_cpp_bindings_sys::llama_memory_clear(mem, clear_data_buffers) }
106 }
107
108 /// Removes all tokens that do not belong to the specified sequence
109 ///
110 /// # Parameters
111 ///
112 /// * `seq_id` - The sequence id to keep
113 pub fn kv_cache_seq_keep(&mut self, seq_id: i32) {
114 let mem = unsafe { llama_cpp_bindings_sys::llama_get_memory(self.context.as_ptr()) };
115 unsafe { llama_cpp_bindings_sys::llama_memory_seq_keep(mem, seq_id) }
116 }
117
118 /// Adds relative position "delta" to all tokens that belong to the specified sequence and have positions in `[p0, p1)`
119 /// If the KV cache is `RoPEd`, the KV data is updated accordingly:
120 /// - lazily on next [`LlamaContext::decode`]
121 /// - explicitly with [`Self::kv_cache_update`]
122 ///
123 /// # Returns
124 /// A `Result` indicating whether the operation was successful.
125 ///
126 /// # Parameters
127 ///
128 /// * `seq_id` - The sequence id to update
129 /// * `p0` - The start position of the cache to update. If `None`, the entire cache is updated up to `p1`.
130 /// * `p1` - The end position of the cache to update. If `None`, the entire cache is updated starting from `p0`.
131 /// * `delta` - The relative position to add to the tokens
132 ///
133 /// # Errors
134 /// If either position exceeds [`i32::MAX`].
135 pub fn kv_cache_seq_add(
136 &mut self,
137 seq_id: i32,
138 p0: Option<u32>,
139 p1: Option<u32>,
140 delta: i32,
141 ) -> Result<(), KvCacheConversionError> {
142 let p0 = p0
143 .map_or(Ok(-1), i32::try_from)
144 .map_err(KvCacheConversionError::P0TooLarge)?;
145 let p1 = p1
146 .map_or(Ok(-1), i32::try_from)
147 .map_err(KvCacheConversionError::P1TooLarge)?;
148 let status = unsafe {
149 llama_cpp_bindings_sys::llama_rs_memory_seq_add(
150 self.context.as_ptr(),
151 seq_id,
152 p0,
153 p1,
154 delta,
155 )
156 };
157
158 if crate::status_is_ok(status) {
159 Ok(())
160 } else {
161 Err(KvCacheConversionError::UnsupportedOperation(format!(
162 "kv_cache_seq_add failed (status {})",
163 crate::status_to_i32(status)
164 )))
165 }
166 }
167
168 /// Integer division of the positions by factor of `d > 1`
169 /// If the KV cache is `RoPEd`, the KV data is updated accordingly:
170 /// - lazily on next [`LlamaContext::decode`]
171 /// - explicitly with [`Self::kv_cache_update`]
172 ///
173 /// # Returns
174 /// A `Result` indicating whether the operation was successful.
175 ///
176 /// # Parameters
177 ///
178 /// * `seq_id` - The sequence id to update
179 /// * `p0` - The start position of the cache to update. If `None`, the entire cache is updated up to `p1`.
180 /// * `p1` - The end position of the cache to update. If `None`, the entire cache is updated starting from `p0`.
181 /// * `d` - The factor to divide the positions by
182 ///
183 /// # Errors
184 /// If either position exceeds [`i32::MAX`].
185 pub fn kv_cache_seq_div(
186 &mut self,
187 seq_id: i32,
188 p0: Option<u32>,
189 p1: Option<u32>,
190 d: NonZeroU8,
191 ) -> Result<(), KvCacheConversionError> {
192 let p0 = p0
193 .map_or(Ok(-1), i32::try_from)
194 .map_err(KvCacheConversionError::P0TooLarge)?;
195 let p1 = p1
196 .map_or(Ok(-1), i32::try_from)
197 .map_err(KvCacheConversionError::P1TooLarge)?;
198 let d = c_int::from(d.get());
199 let status = unsafe {
200 llama_cpp_bindings_sys::llama_rs_memory_seq_div(
201 self.context.as_ptr(),
202 seq_id,
203 p0,
204 p1,
205 d,
206 )
207 };
208
209 if crate::status_is_ok(status) {
210 Ok(())
211 } else {
212 Err(KvCacheConversionError::UnsupportedOperation(format!(
213 "kv_cache_seq_div failed (status {})",
214 crate::status_to_i32(status)
215 )))
216 }
217 }
218
219 /// Returns the largest position present in the KV cache for the specified sequence
220 ///
221 /// # Parameters
222 ///
223 /// * `seq_id` - The sequence id to get the max position for
224 #[must_use]
225 pub fn kv_cache_seq_pos_max(&self, seq_id: i32) -> i32 {
226 unsafe {
227 llama_cpp_bindings_sys::llama_rs_memory_seq_pos_max(self.context.as_ptr(), seq_id)
228 }
229 }
230}