llama_cpp_bindings/context/kv_cache.rs
//! Utilities for working with the KV cache.
2
3use crate::context::LlamaContext;
4use std::ffi::c_int;
5use std::num::{NonZeroU8, TryFromIntError};
6
7/// Errors that can occur when attempting to prepare values for the kv cache
8#[derive(Debug, Eq, PartialEq, thiserror::Error)]
9pub enum KvCacheConversionError {
10 /// Sequence id conversion to i32 failed
11 #[error("Provided sequence id is too large for a i32")]
12 SeqIdTooLarge(#[source] TryFromIntError),
13 /// Position 0 conversion to i32 failed
14 #[error("Provided start position is too large for a i32")]
15 P0TooLarge(#[source] TryFromIntError),
16 /// Position 1 conversion to i32 failed
17 #[error("Provided end position is too large for a i32")]
18 P1TooLarge(#[source] TryFromIntError),
19}
20
21impl LlamaContext<'_> {
22 /// Copy the cache from one sequence to another.
23 ///
24 /// # Parameters
25 ///
26 /// * `src` - The sequence id to copy the cache from.
27 /// * `dest` - The sequence id to copy the cache to.
28 /// * `size` - The size of the cache to copy.
29 pub fn copy_cache(&mut self, src: i32, dest: i32, size: i32) {
30 let mem = unsafe { llama_cpp_bindings_sys::llama_get_memory(self.context.as_ptr()) };
31 unsafe { llama_cpp_bindings_sys::llama_memory_seq_cp(mem, src, dest, 0, size) }
32 }
33
34 /// Copy the cache from one sequence to another.
35 ///
36 /// # Returns
37 /// A `Result` indicating whether the operation was successful.
38 ///
39 /// # Parameters
40 /// * `src` - The sequence id to copy the cache from.
41 /// * `dest` - The sequence id to copy the cache to.
42 /// * `p0` - The start position of the cache to clear. If `None`, the entire cache is copied up to `p1`.
43 /// * `p1` - The end position of the cache to clear. If `None`, the entire cache is copied starting from `p0`.
44 ///
45 /// # Errors
46 /// If either position exceeds [`i32::MAX`].
47 pub fn copy_kv_cache_seq(
48 &mut self,
49 src: i32,
50 dest: i32,
51 p0: Option<u32>,
52 p1: Option<u32>,
53 ) -> Result<(), KvCacheConversionError> {
54 let p0 = p0
55 .map_or(Ok(-1), i32::try_from)
56 .map_err(KvCacheConversionError::P0TooLarge)?;
57 let p1 = p1
58 .map_or(Ok(-1), i32::try_from)
59 .map_err(KvCacheConversionError::P1TooLarge)?;
60 let mem = unsafe { llama_cpp_bindings_sys::llama_get_memory(self.context.as_ptr()) };
61 unsafe { llama_cpp_bindings_sys::llama_memory_seq_cp(mem, src, dest, p0, p1) };
62 Ok(())
63 }
64
65 /// Clear the kv cache for the given sequence within the specified range `[p0, p1)`
66 /// Returns `false` only when partial sequence removals fail. Full sequence removals always succeed.
67 ///
68 /// # Returns
69 /// A `Result` indicating whether the operation was successful. If the sequence id or
70 /// either position exceeds the maximum i32 value, no removal is attempted and an `Err` is returned.
71 ///
72 /// # Parameters
73 /// * `src` - The sequence id to clear the cache for. If `None`, matches all sequences
74 /// * `p0` - The start position of the cache to clear. If `None`, the entire cache is cleared up to `p1`.
75 /// * `p1` - The end position of the cache to clear. If `None`, the entire cache is cleared from `p0`.
76 ///
77 /// # Errors
78 /// If the sequence id or either position exceeds [`i32::MAX`].
79 pub fn clear_kv_cache_seq(
80 &mut self,
81 src: Option<u32>,
82 p0: Option<u32>,
83 p1: Option<u32>,
84 ) -> Result<bool, KvCacheConversionError> {
85 let src = src
86 .map_or(Ok(-1), i32::try_from)
87 .map_err(KvCacheConversionError::SeqIdTooLarge)?;
88 let p0 = p0
89 .map_or(Ok(-1), i32::try_from)
90 .map_err(KvCacheConversionError::P0TooLarge)?;
91 let p1 = p1
92 .map_or(Ok(-1), i32::try_from)
93 .map_err(KvCacheConversionError::P1TooLarge)?;
94 let mem = unsafe { llama_cpp_bindings_sys::llama_get_memory(self.context.as_ptr()) };
95 Ok(unsafe { llama_cpp_bindings_sys::llama_memory_seq_rm(mem, src, p0, p1) })
96 }
97
98 /// Clear the KV cache
99 pub fn clear_kv_cache(&mut self) {
100 let mem = unsafe { llama_cpp_bindings_sys::llama_get_memory(self.context.as_ptr()) };
101 // clear both metadata and data buffers to match previous semantics
102 unsafe { llama_cpp_bindings_sys::llama_memory_clear(mem, true) }
103 }
104
105 /// Removes all tokens that do not belong to the specified sequence
106 ///
107 /// # Parameters
108 ///
109 /// * `seq_id` - The sequence id to keep
110 pub fn llama_kv_cache_seq_keep(&mut self, seq_id: i32) {
111 let mem = unsafe { llama_cpp_bindings_sys::llama_get_memory(self.context.as_ptr()) };
112 unsafe { llama_cpp_bindings_sys::llama_memory_seq_keep(mem, seq_id) }
113 }
114
115 /// Adds relative position "delta" to all tokens that belong to the specified sequence and have positions in `[p0, p1)`
116 /// If the KV cache is `RoPEd`, the KV data is updated accordingly:
117 /// - lazily on next [`LlamaContext::decode`]
118 /// - explicitly with [`Self::kv_cache_update`]
119 ///
120 /// # Returns
121 /// A `Result` indicating whether the operation was successful.
122 ///
123 /// # Parameters
124 ///
125 /// * `seq_id` - The sequence id to update
126 /// * `p0` - The start position of the cache to update. If `None`, the entire cache is updated up to `p1`.
127 /// * `p1` - The end position of the cache to update. If `None`, the entire cache is updated starting from `p0`.
128 /// * `delta` - The relative position to add to the tokens
129 ///
130 /// # Errors
131 /// If either position exceeds [`i32::MAX`].
132 pub fn kv_cache_seq_add(
133 &mut self,
134 seq_id: i32,
135 p0: Option<u32>,
136 p1: Option<u32>,
137 delta: i32,
138 ) -> Result<(), KvCacheConversionError> {
139 let p0 = p0
140 .map_or(Ok(-1), i32::try_from)
141 .map_err(KvCacheConversionError::P0TooLarge)?;
142 let p1 = p1
143 .map_or(Ok(-1), i32::try_from)
144 .map_err(KvCacheConversionError::P1TooLarge)?;
145 let mem = unsafe { llama_cpp_bindings_sys::llama_get_memory(self.context.as_ptr()) };
146 unsafe { llama_cpp_bindings_sys::llama_memory_seq_add(mem, seq_id, p0, p1, delta) };
147 Ok(())
148 }
149
150 /// Integer division of the positions by factor of `d > 1`
151 /// If the KV cache is `RoPEd`, the KV data is updated accordingly:
152 /// - lazily on next [`LlamaContext::decode`]
153 /// - explicitly with [`Self::kv_cache_update`]
154 ///
155 /// # Returns
156 /// A `Result` indicating whether the operation was successful.
157 ///
158 /// # Parameters
159 ///
160 /// * `seq_id` - The sequence id to update
161 /// * `p0` - The start position of the cache to update. If `None`, the entire cache is updated up to `p1`.
162 /// * `p1` - The end position of the cache to update. If `None`, the entire cache is updated starting from `p0`.
163 /// * `d` - The factor to divide the positions by
164 ///
165 /// # Errors
166 /// If either position exceeds [`i32::MAX`].
167 pub fn kv_cache_seq_div(
168 &mut self,
169 seq_id: i32,
170 p0: Option<u32>,
171 p1: Option<u32>,
172 d: NonZeroU8,
173 ) -> Result<(), KvCacheConversionError> {
174 let p0 = p0
175 .map_or(Ok(-1), i32::try_from)
176 .map_err(KvCacheConversionError::P0TooLarge)?;
177 let p1 = p1
178 .map_or(Ok(-1), i32::try_from)
179 .map_err(KvCacheConversionError::P1TooLarge)?;
180 let d = c_int::from(d.get());
181 let mem = unsafe { llama_cpp_bindings_sys::llama_get_memory(self.context.as_ptr()) };
182 unsafe { llama_cpp_bindings_sys::llama_memory_seq_div(mem, seq_id, p0, p1, d) }
183 Ok(())
184 }
185
186 /// Returns the largest position present in the KV cache for the specified sequence
187 ///
188 /// # Parameters
189 ///
190 /// * `seq_id` - The sequence id to get the max position for
191 #[must_use]
192 pub fn kv_cache_seq_pos_max(&self, seq_id: i32) -> i32 {
193 let mem = unsafe { llama_cpp_bindings_sys::llama_get_memory(self.context.as_ptr()) };
194 unsafe { llama_cpp_bindings_sys::llama_memory_seq_pos_max(mem, seq_id) }
195 }
196}