llama_cpp_4/context/kv_cache.rs
1//! utilities for working with the kv cache
2
3use crate::context::LlamaContext;
4use std::num::{NonZeroU8, TryFromIntError};
5
6/// Errors that can occur when attempting to prepare values for the kv cache
7#[derive(Debug, Eq, PartialEq, thiserror::Error)]
8pub enum KvCacheConversionError {
9 /// Sequence id conversion to i32 failed
10 #[error("Provided sequence id is too large for a i32")]
11 SeqIdTooLarge(#[source] TryFromIntError),
12 /// Position 0 conversion to i32 failed
13 #[error("Provided start position is too large for a i32")]
14 P0TooLarge(#[source] TryFromIntError),
15 /// Position 1 conversion to i32 failed
16 #[error("Provided end position is too large for a i32")]
17 P1TooLarge(#[source] TryFromIntError),
18}
19
20impl LlamaContext<'_> {
21 /// Copy the cache from one sequence to another.
22 ///
23 /// # Parameters
24 ///
25 /// * `src` - The sequence id to copy the cache from.
26 /// * `dest` - The sequence id to copy the cache to.
27 /// * `size` - The size of the cache to copy.
28 pub fn copy_cache(&mut self, src: i32, dest: i32, size: i32) {
29 unsafe {
30 let mem = llama_cpp_sys_4::llama_get_memory(self.context.as_ptr());
31 llama_cpp_sys_4::llama_memory_seq_cp(mem, src, dest, 0, size);
32 }
33 }
34
35 /// Copy the cache from one sequence to another.
36 ///
37 /// # Parameters
38 ///
39 /// * `src` - The sequence id to copy the cache from.
40 /// * `dest` - The sequence id to copy the cache to.
41 /// * `p0` - The start position of the cache to clear. If `None`, the entire cache is copied up to `p1`.
42 /// * `p1` - The end position of the cache to clear. If `None`, the entire cache is copied starting from `p0`.
43 ///
44 /// # Errors
45 ///
46 /// Returns [`KvCacheConversionError`] if either position exceeds the maximum `i32` value.
47 pub fn copy_kv_cache_seq(
48 &mut self,
49 src: i32,
50 dest: i32,
51 p0: Option<u32>,
52 p1: Option<u32>,
53 ) -> Result<(), KvCacheConversionError> {
54 let p0 = p0
55 .map_or(Ok(-1), i32::try_from)
56 .map_err(KvCacheConversionError::P0TooLarge)?;
57 let p1 = p1
58 .map_or(Ok(-1), i32::try_from)
59 .map_err(KvCacheConversionError::P1TooLarge)?;
60 unsafe {
61 let mem = llama_cpp_sys_4::llama_get_memory(self.context.as_ptr());
62 llama_cpp_sys_4::llama_memory_seq_cp(mem, src, dest, p0, p1);
63 }
64 Ok(())
65 }
66
67 /// Clear the kv cache for the given sequence within the specified range `[p0, p1)`.
68 ///
69 /// Returns `false` only when partial sequence removals fail. Full sequence removals always succeed.
70 ///
71 /// # Parameters
72 ///
73 /// * `src` - The sequence id to clear the cache for. If `None`, matches all sequences
74 /// * `p0` - The start position of the cache to clear. If `None`, the entire cache is cleared up to `p1`.
75 /// * `p1` - The end position of the cache to clear. If `None`, the entire cache is cleared from `p0`.
76 ///
77 /// # Errors
78 ///
79 /// Returns [`KvCacheConversionError`] if the sequence id or either position exceeds the maximum `i32` value.
80 pub fn clear_kv_cache_seq(
81 &mut self,
82 src: Option<u32>,
83 p0: Option<u32>,
84 p1: Option<u32>,
85 ) -> Result<bool, KvCacheConversionError> {
86 let src = src
87 .map_or(Ok(-1), i32::try_from)
88 .map_err(KvCacheConversionError::SeqIdTooLarge)?;
89 let p0 = p0
90 .map_or(Ok(-1), i32::try_from)
91 .map_err(KvCacheConversionError::P0TooLarge)?;
92 let p1 = p1
93 .map_or(Ok(-1), i32::try_from)
94 .map_err(KvCacheConversionError::P1TooLarge)?;
95 #[cfg(feature = "mtp")]
96 {
97 Ok(
98 unsafe {
99 llama_cpp_sys_4::llama_context_seq_rm(self.context.as_ptr(), src, p0, p1)
100 },
101 )
102 }
103 #[cfg(not(feature = "mtp"))]
104 {
105 Ok(unsafe {
106 let mem = llama_cpp_sys_4::llama_get_memory(self.context.as_ptr());
107 llama_cpp_sys_4::llama_memory_seq_rm(mem, src, p0, p1)
108 })
109 }
110 }
111
112 /// Clear the KV cache
113 pub fn clear_kv_cache(&mut self) {
114 unsafe {
115 let mem = llama_cpp_sys_4::llama_get_memory(self.context.as_ptr());
116 llama_cpp_sys_4::llama_memory_clear(mem, true);
117 }
118 }
119
120 /// Removes all tokens that do not belong to the specified sequence
121 ///
122 /// # Parameters
123 ///
124 /// * `seq_id` - The sequence id to keep
125 pub fn llama_kv_cache_seq_keep(&mut self, seq_id: i32) {
126 unsafe {
127 let mem = llama_cpp_sys_4::llama_get_memory(self.context.as_ptr());
128 llama_cpp_sys_4::llama_memory_seq_keep(mem, seq_id);
129 }
130 }
131
132 #[allow(clippy::doc_markdown)]
133 /// Adds relative position "delta" to all tokens that belong to the specified sequence and have positions in `[p0, p1)`.
134 ///
135 /// # Parameters
136 ///
137 /// * `seq_id` - The sequence id to update
138 /// * `p0` - The start position of the cache to update. If `None`, the entire cache is updated up to `p1`.
139 /// * `p1` - The end position of the cache to update. If `None`, the entire cache is updated starting from `p0`.
140 /// * `delta` - The relative position to add to the tokens
141 ///
142 /// # Errors
143 ///
144 /// Returns [`KvCacheConversionError`] if either position exceeds the maximum `i32` value.
145 pub fn kv_cache_seq_add(
146 &mut self,
147 seq_id: i32,
148 p0: Option<u32>,
149 p1: Option<u32>,
150 delta: i32,
151 ) -> Result<(), KvCacheConversionError> {
152 let p0 = p0
153 .map_or(Ok(-1), i32::try_from)
154 .map_err(KvCacheConversionError::P0TooLarge)?;
155 let p1 = p1
156 .map_or(Ok(-1), i32::try_from)
157 .map_err(KvCacheConversionError::P1TooLarge)?;
158 unsafe {
159 let mem = llama_cpp_sys_4::llama_get_memory(self.context.as_ptr());
160 llama_cpp_sys_4::llama_memory_seq_add(mem, seq_id, p0, p1, delta);
161 }
162 Ok(())
163 }
164
165 /// Integer division of the positions by factor of `d > 1`.
166 ///
167 /// # Parameters
168 ///
169 /// * `seq_id` - The sequence id to update
170 /// * `p0` - The start position of the cache to update. If `None`, the entire cache is updated up to `p1`.
171 /// * `p1` - The end position of the cache to update. If `None`, the entire cache is updated starting from `p0`.
172 /// * `d` - The factor to divide the positions by
173 ///
174 /// # Errors
175 ///
176 /// Returns [`KvCacheConversionError`] if either position exceeds the maximum `i32` value.
177 pub fn kv_cache_seq_div(
178 &mut self,
179 seq_id: i32,
180 p0: Option<u32>,
181 p1: Option<u32>,
182 d: NonZeroU8,
183 ) -> Result<(), KvCacheConversionError> {
184 let p0 = p0
185 .map_or(Ok(-1), i32::try_from)
186 .map_err(KvCacheConversionError::P0TooLarge)?;
187 let p1 = p1
188 .map_or(Ok(-1), i32::try_from)
189 .map_err(KvCacheConversionError::P1TooLarge)?;
190 let d = i32::from(d.get());
191 unsafe {
192 let mem = llama_cpp_sys_4::llama_get_memory(self.context.as_ptr());
193 llama_cpp_sys_4::llama_memory_seq_div(mem, seq_id, p0, p1, d);
194 }
195 Ok(())
196 }
197
198 /// Returns the largest position present in the KV cache for the specified sequence
199 ///
200 /// # Parameters
201 ///
202 /// * `seq_id` - The sequence id to get the max position for
203 #[must_use]
204 pub fn kv_cache_seq_pos_max(&self, seq_id: i32) -> i32 {
205 unsafe {
206 let mem = llama_cpp_sys_4::llama_get_memory(self.context.as_ptr());
207 llama_cpp_sys_4::llama_memory_seq_pos_max(mem, seq_id)
208 }
209 }
210}