llama_cpp_2/context/kv_cache.rs
1//! utilities for working with the kv cache
2
3use crate::context::LlamaContext;
4use std::ffi::c_int;
5use std::num::{NonZeroU8, TryFromIntError};
6
7/// Errors that can occur when attempting to prepare values for the kv cache
8#[derive(Debug, Eq, PartialEq, thiserror::Error)]
9#[allow(clippy::module_name_repetitions)]
10pub enum KvCacheConversionError {
11 /// Sequence id conversion to i32 failed
12 #[error("Provided sequence id is too large for a i32")]
13 SeqIdTooLarge(#[source] TryFromIntError),
14 /// Position 0 conversion to i32 failed
15 #[error("Provided start position is too large for a i32")]
16 P0TooLarge(#[source] TryFromIntError),
17 /// Position 1 conversion to i32 failed
18 #[error("Provided end position is too large for a i32")]
19 P1TooLarge(#[source] TryFromIntError),
20}
21
22impl LlamaContext<'_> {
23 /// Copy the cache from one sequence to another.
24 ///
25 /// # Parameters
26 ///
27 /// * `src` - The sequence id to copy the cache from.
28 /// * `dest` - The sequence id to copy the cache to.
29 /// * `size` - The size of the cache to copy.
30 pub fn copy_cache(&mut self, src: i32, dest: i32, size: i32) {
31 unsafe { llama_cpp_sys_2::llama_kv_self_seq_cp(self.context.as_ptr(), src, dest, 0, size) }
32 }
33
34 /// Copy the cache from one sequence to another.
35 ///
36 /// # Returns
37 /// A `Result` indicating whether the operation was successful.
38 ///
39 /// # Parameters
40 /// * `src` - The sequence id to copy the cache from.
41 /// * `dest` - The sequence id to copy the cache to.
42 /// * `p0` - The start position of the cache to clear. If `None`, the entire cache is copied up to `p1`.
43 /// * `p1` - The end position of the cache to clear. If `None`, the entire cache is copied starting from `p0`.
44 ///
45 /// # Errors
46 /// If either position exceeds [`i32::MAX`].
47 pub fn copy_kv_cache_seq(
48 &mut self,
49 src: i32,
50 dest: i32,
51 p0: Option<u32>,
52 p1: Option<u32>,
53 ) -> Result<(), KvCacheConversionError> {
54 let p0 = p0
55 .map_or(Ok(-1), i32::try_from)
56 .map_err(KvCacheConversionError::P0TooLarge)?;
57 let p1 = p1
58 .map_or(Ok(-1), i32::try_from)
59 .map_err(KvCacheConversionError::P1TooLarge)?;
60 unsafe {
61 llama_cpp_sys_2::llama_kv_self_seq_cp(self.context.as_ptr(), src, dest, p0, p1);
62 }
63 Ok(())
64 }
65
66 /// Clear the kv cache for the given sequence within the specified range `[p0, p1)`
67 /// Returns `false` only when partial sequence removals fail. Full sequence removals always succeed.
68 ///
69 /// # Returns
70 /// A `Result` indicating whether the operation was successful. If the sequence id or
71 /// either position exceeds the maximum i32 value, no removal is attempted and an `Err` is returned.
72 ///
73 /// # Parameters
74 /// * `src` - The sequence id to clear the cache for. If `None`, matches all sequences
75 /// * `p0` - The start position of the cache to clear. If `None`, the entire cache is cleared up to `p1`.
76 /// * `p1` - The end position of the cache to clear. If `None`, the entire cache is cleared from `p0`.
77 ///
78 /// # Errors
79 /// If the sequence id or either position exceeds [`i32::MAX`].
80 pub fn clear_kv_cache_seq(
81 &mut self,
82 src: Option<u32>,
83 p0: Option<u32>,
84 p1: Option<u32>,
85 ) -> Result<bool, KvCacheConversionError> {
86 let src = src
87 .map_or(Ok(-1), i32::try_from)
88 .map_err(KvCacheConversionError::SeqIdTooLarge)?;
89 let p0 = p0
90 .map_or(Ok(-1), i32::try_from)
91 .map_err(KvCacheConversionError::P0TooLarge)?;
92 let p1 = p1
93 .map_or(Ok(-1), i32::try_from)
94 .map_err(KvCacheConversionError::P1TooLarge)?;
95 Ok(unsafe { llama_cpp_sys_2::llama_kv_self_seq_rm(self.context.as_ptr(), src, p0, p1) })
96 }
97
98 /// Returns the number of used KV cells (i.e. have at least one sequence assigned to them)
99 #[must_use]
100 pub fn get_kv_cache_used_cells(&self) -> i32 {
101 unsafe { llama_cpp_sys_2::llama_kv_self_used_cells(self.context.as_ptr()) }
102 }
103
104 /// Clear the KV cache
105 pub fn clear_kv_cache(&mut self) {
106 unsafe { llama_cpp_sys_2::llama_kv_self_clear(self.context.as_ptr()) }
107 }
108
109 /// Removes all tokens that do not belong to the specified sequence
110 ///
111 /// # Parameters
112 ///
113 /// * `seq_id` - The sequence id to keep
114 pub fn llama_kv_cache_seq_keep(&mut self, seq_id: i32) {
115 unsafe { llama_cpp_sys_2::llama_kv_self_seq_keep(self.context.as_ptr(), seq_id) }
116 }
117
118 #[allow(clippy::doc_markdown)]
119 /// Adds relative position "delta" to all tokens that belong to the specified sequence and have positions in `[p0, p1)`
120 /// If the KV cache is RoPEd, the KV data is updated accordingly:
121 /// - lazily on next [`LlamaContext::decode`]
122 /// - explicitly with [`Self::kv_cache_update`]
123 ///
124 /// # Returns
125 /// A `Result` indicating whether the operation was successful.
126 ///
127 /// # Parameters
128 ///
129 /// * `seq_id` - The sequence id to update
130 /// * `p0` - The start position of the cache to update. If `None`, the entire cache is updated up to `p1`.
131 /// * `p1` - The end position of the cache to update. If `None`, the entire cache is updated starting from `p0`.
132 /// * `delta` - The relative position to add to the tokens
133 ///
134 /// # Errors
135 /// If either position exceeds [`i32::MAX`].
136 pub fn kv_cache_seq_add(
137 &mut self,
138 seq_id: i32,
139 p0: Option<u32>,
140 p1: Option<u32>,
141 delta: i32,
142 ) -> Result<(), KvCacheConversionError> {
143 let p0 = p0
144 .map_or(Ok(-1), i32::try_from)
145 .map_err(KvCacheConversionError::P0TooLarge)?;
146 let p1 = p1
147 .map_or(Ok(-1), i32::try_from)
148 .map_err(KvCacheConversionError::P1TooLarge)?;
149 unsafe {
150 llama_cpp_sys_2::llama_kv_self_seq_add(self.context.as_ptr(), seq_id, p0, p1, delta);
151 }
152 Ok(())
153 }
154
155 /// Integer division of the positions by factor of `d > 1`
156 /// If the KV cache is `RoPEd`, the KV data is updated accordingly:
157 /// - lazily on next [`LlamaContext::decode`]
158 /// - explicitly with [`Self::kv_cache_update`]
159 ///
160 /// # Returns
161 /// A `Result` indicating whether the operation was successful.
162 ///
163 /// # Parameters
164 ///
165 /// * `seq_id` - The sequence id to update
166 /// * `p0` - The start position of the cache to update. If `None`, the entire cache is updated up to `p1`.
167 /// * `p1` - The end position of the cache to update. If `None`, the entire cache is updated starting from `p0`.
168 /// * `d` - The factor to divide the positions by
169 ///
170 /// # Errors
171 /// If either position exceeds [`i32::MAX`].
172 pub fn kv_cache_seq_div(
173 &mut self,
174 seq_id: i32,
175 p0: Option<u32>,
176 p1: Option<u32>,
177 d: NonZeroU8,
178 ) -> Result<(), KvCacheConversionError> {
179 let p0 = p0
180 .map_or(Ok(-1), i32::try_from)
181 .map_err(KvCacheConversionError::P0TooLarge)?;
182 let p1 = p1
183 .map_or(Ok(-1), i32::try_from)
184 .map_err(KvCacheConversionError::P1TooLarge)?;
185 let d = c_int::from(d.get());
186 unsafe { llama_cpp_sys_2::llama_kv_self_seq_div(self.context.as_ptr(), seq_id, p0, p1, d) }
187 Ok(())
188 }
189
190 /// Returns the largest position present in the KV cache for the specified sequence
191 ///
192 /// # Parameters
193 ///
194 /// * `seq_id` - The sequence id to get the max position for
195 #[must_use]
196 pub fn kv_cache_seq_pos_max(&self, seq_id: i32) -> i32 {
197 unsafe { llama_cpp_sys_2::llama_kv_self_seq_pos_max(self.context.as_ptr(), seq_id) }
198 }
199
200 /// Defragment the KV cache
201 /// This will be applied:
202 /// - lazily on next [`LlamaContext::decode`]
203 /// - explicitly with [`Self::kv_cache_update`]
204 pub fn kv_cache_defrag(&mut self) {
205 unsafe { llama_cpp_sys_2::llama_kv_self_defrag(self.context.as_ptr()) }
206 }
207
208 /// Apply the KV cache updates (such as K-shifts, defragmentation, etc.)
209 pub fn kv_cache_update(&mut self) {
210 unsafe { llama_cpp_sys_2::llama_kv_self_update(self.context.as_ptr()) }
211 }
212}