llama_cpp_4/context/kv_cache.rs
1//! utilities for working with the kv cache
2
3use crate::context::LlamaContext;
4use std::num::{NonZeroU8, TryFromIntError};
5
6/// Errors that can occur when attempting to prepare values for the kv cache
7#[derive(Debug, Eq, PartialEq, thiserror::Error)]
8pub enum KvCacheConversionError {
9 /// Sequence id conversion to i32 failed
10 #[error("Provided sequence id is too large for a i32")]
11 SeqIdTooLarge(#[source] TryFromIntError),
12 /// Position 0 conversion to i32 failed
13 #[error("Provided start position is too large for a i32")]
14 P0TooLarge(#[source] TryFromIntError),
15 /// Position 1 conversion to i32 failed
16 #[error("Provided end position is too large for a i32")]
17 P1TooLarge(#[source] TryFromIntError),
18}
19
20impl LlamaContext<'_> {
21 /// Copy the cache from one sequence to another.
22 ///
23 /// # Parameters
24 ///
25 /// * `src` - The sequence id to copy the cache from.
26 /// * `dest` - The sequence id to copy the cache to.
27 /// * `size` - The size of the cache to copy.
28 pub fn copy_cache(&mut self, src: i32, dest: i32, size: i32) {
29 unsafe {
30 let mem = llama_cpp_sys_4::llama_get_memory(self.context.as_ptr());
31 llama_cpp_sys_4::llama_memory_seq_cp(mem, src, dest, 0, size);
32 }
33 }
34
35 /// Copy the cache from one sequence to another.
36 ///
37 /// # Parameters
38 ///
39 /// * `src` - The sequence id to copy the cache from.
40 /// * `dest` - The sequence id to copy the cache to.
41 /// * `p0` - The start position of the cache to clear. If `None`, the entire cache is copied up to `p1`.
42 /// * `p1` - The end position of the cache to clear. If `None`, the entire cache is copied starting from `p0`.
43 ///
44 /// # Errors
45 ///
46 /// Returns [`KvCacheConversionError`] if either position exceeds the maximum `i32` value.
47 pub fn copy_kv_cache_seq(
48 &mut self,
49 src: i32,
50 dest: i32,
51 p0: Option<u32>,
52 p1: Option<u32>,
53 ) -> Result<(), KvCacheConversionError> {
54 let p0 = p0
55 .map_or(Ok(-1), i32::try_from)
56 .map_err(KvCacheConversionError::P0TooLarge)?;
57 let p1 = p1
58 .map_or(Ok(-1), i32::try_from)
59 .map_err(KvCacheConversionError::P1TooLarge)?;
60 unsafe {
61 let mem = llama_cpp_sys_4::llama_get_memory(self.context.as_ptr());
62 llama_cpp_sys_4::llama_memory_seq_cp(mem, src, dest, p0, p1);
63 }
64 Ok(())
65 }
66
67 /// Clear the kv cache for the given sequence within the specified range `[p0, p1)`.
68 ///
69 /// Returns `false` only when partial sequence removals fail. Full sequence removals always succeed.
70 ///
71 /// # Parameters
72 ///
73 /// * `src` - The sequence id to clear the cache for. If `None`, matches all sequences
74 /// * `p0` - The start position of the cache to clear. If `None`, the entire cache is cleared up to `p1`.
75 /// * `p1` - The end position of the cache to clear. If `None`, the entire cache is cleared from `p0`.
76 ///
77 /// # Errors
78 ///
79 /// Returns [`KvCacheConversionError`] if the sequence id or either position exceeds the maximum `i32` value.
80 pub fn clear_kv_cache_seq(
81 &mut self,
82 src: Option<u32>,
83 p0: Option<u32>,
84 p1: Option<u32>,
85 ) -> Result<bool, KvCacheConversionError> {
86 let src = src
87 .map_or(Ok(-1), i32::try_from)
88 .map_err(KvCacheConversionError::SeqIdTooLarge)?;
89 let p0 = p0
90 .map_or(Ok(-1), i32::try_from)
91 .map_err(KvCacheConversionError::P0TooLarge)?;
92 let p1 = p1
93 .map_or(Ok(-1), i32::try_from)
94 .map_err(KvCacheConversionError::P1TooLarge)?;
95 Ok(unsafe {
96 let mem = llama_cpp_sys_4::llama_get_memory(self.context.as_ptr());
97 llama_cpp_sys_4::llama_memory_seq_rm(mem, src, p0, p1)
98 })
99 }
100
101 /// Clear the KV cache
102 pub fn clear_kv_cache(&mut self) {
103 unsafe {
104 let mem = llama_cpp_sys_4::llama_get_memory(self.context.as_ptr());
105 llama_cpp_sys_4::llama_memory_clear(mem, true);
106 }
107 }
108
109 /// Removes all tokens that do not belong to the specified sequence
110 ///
111 /// # Parameters
112 ///
113 /// * `seq_id` - The sequence id to keep
114 pub fn llama_kv_cache_seq_keep(&mut self, seq_id: i32) {
115 unsafe {
116 let mem = llama_cpp_sys_4::llama_get_memory(self.context.as_ptr());
117 llama_cpp_sys_4::llama_memory_seq_keep(mem, seq_id);
118 }
119 }
120
121 #[allow(clippy::doc_markdown)]
122 /// Adds relative position "delta" to all tokens that belong to the specified sequence and have positions in `[p0, p1)`.
123 ///
124 /// # Parameters
125 ///
126 /// * `seq_id` - The sequence id to update
127 /// * `p0` - The start position of the cache to update. If `None`, the entire cache is updated up to `p1`.
128 /// * `p1` - The end position of the cache to update. If `None`, the entire cache is updated starting from `p0`.
129 /// * `delta` - The relative position to add to the tokens
130 ///
131 /// # Errors
132 ///
133 /// Returns [`KvCacheConversionError`] if either position exceeds the maximum `i32` value.
134 pub fn kv_cache_seq_add(
135 &mut self,
136 seq_id: i32,
137 p0: Option<u32>,
138 p1: Option<u32>,
139 delta: i32,
140 ) -> Result<(), KvCacheConversionError> {
141 let p0 = p0
142 .map_or(Ok(-1), i32::try_from)
143 .map_err(KvCacheConversionError::P0TooLarge)?;
144 let p1 = p1
145 .map_or(Ok(-1), i32::try_from)
146 .map_err(KvCacheConversionError::P1TooLarge)?;
147 unsafe {
148 let mem = llama_cpp_sys_4::llama_get_memory(self.context.as_ptr());
149 llama_cpp_sys_4::llama_memory_seq_add(mem, seq_id, p0, p1, delta);
150 }
151 Ok(())
152 }
153
154 /// Integer division of the positions by factor of `d > 1`.
155 ///
156 /// # Parameters
157 ///
158 /// * `seq_id` - The sequence id to update
159 /// * `p0` - The start position of the cache to update. If `None`, the entire cache is updated up to `p1`.
160 /// * `p1` - The end position of the cache to update. If `None`, the entire cache is updated starting from `p0`.
161 /// * `d` - The factor to divide the positions by
162 ///
163 /// # Errors
164 ///
165 /// Returns [`KvCacheConversionError`] if either position exceeds the maximum `i32` value.
166 pub fn kv_cache_seq_div(
167 &mut self,
168 seq_id: i32,
169 p0: Option<u32>,
170 p1: Option<u32>,
171 d: NonZeroU8,
172 ) -> Result<(), KvCacheConversionError> {
173 let p0 = p0
174 .map_or(Ok(-1), i32::try_from)
175 .map_err(KvCacheConversionError::P0TooLarge)?;
176 let p1 = p1
177 .map_or(Ok(-1), i32::try_from)
178 .map_err(KvCacheConversionError::P1TooLarge)?;
179 let d = i32::from(d.get());
180 unsafe {
181 let mem = llama_cpp_sys_4::llama_get_memory(self.context.as_ptr());
182 llama_cpp_sys_4::llama_memory_seq_div(mem, seq_id, p0, p1, d);
183 }
184 Ok(())
185 }
186
187 /// Returns the largest position present in the KV cache for the specified sequence
188 ///
189 /// # Parameters
190 ///
191 /// * `seq_id` - The sequence id to get the max position for
192 #[must_use]
193 pub fn kv_cache_seq_pos_max(&self, seq_id: i32) -> i32 {
194 unsafe {
195 let mem = llama_cpp_sys_4::llama_get_memory(self.context.as_ptr());
196 llama_cpp_sys_4::llama_memory_seq_pos_max(mem, seq_id)
197 }
198 }
199}