llama_cpp_2/context/
kv_cache.rs

//! Utilities for working with the KV cache.
2
3use crate::context::LlamaContext;
4use std::ffi::c_int;
5use std::num::{NonZeroU8, TryFromIntError};
6
7/// Errors that can occur when attempting to prepare values for the kv cache
8#[derive(Debug, Eq, PartialEq, thiserror::Error)]
9#[allow(clippy::module_name_repetitions)]
10pub enum KvCacheConversionError {
11    /// Sequence id conversion to i32 failed
12    #[error("Provided sequence id is too large for a i32")]
13    SeqIdTooLarge(#[source] TryFromIntError),
14    /// Position 0 conversion to i32 failed
15    #[error("Provided start position is too large for a i32")]
16    P0TooLarge(#[source] TryFromIntError),
17    /// Position 1 conversion to i32 failed
18    #[error("Provided end position is too large for a i32")]
19    P1TooLarge(#[source] TryFromIntError),
20}
21
22impl LlamaContext<'_> {
23    /// Copy the cache from one sequence to another.
24    ///
25    /// # Parameters
26    ///
27    /// * `src` - The sequence id to copy the cache from.
28    /// * `dest` - The sequence id to copy the cache to.
29    /// * `size` - The size of the cache to copy.
30    pub fn copy_cache(&mut self, src: i32, dest: i32, size: i32) {
31        unsafe { llama_cpp_sys_2::llama_kv_self_seq_cp(self.context.as_ptr(), src, dest, 0, size) }
32    }
33
34    /// Copy the cache from one sequence to another.
35    ///
36    /// # Returns
37    /// A `Result` indicating whether the operation was successful.
38    ///
39    /// # Parameters
40    /// * `src` - The sequence id to copy the cache from.
41    /// * `dest` - The sequence id to copy the cache to.
42    /// * `p0` - The start position of the cache to clear. If `None`, the entire cache is copied up to `p1`.
43    /// * `p1` - The end position of the cache to clear. If `None`, the entire cache is copied starting from `p0`.
44    ///
45    /// # Errors
46    /// If either position exceeds [`i32::MAX`].
47    pub fn copy_kv_cache_seq(
48        &mut self,
49        src: i32,
50        dest: i32,
51        p0: Option<u32>,
52        p1: Option<u32>,
53    ) -> Result<(), KvCacheConversionError> {
54        let p0 = p0
55            .map_or(Ok(-1), i32::try_from)
56            .map_err(KvCacheConversionError::P0TooLarge)?;
57        let p1 = p1
58            .map_or(Ok(-1), i32::try_from)
59            .map_err(KvCacheConversionError::P1TooLarge)?;
60        unsafe {
61            llama_cpp_sys_2::llama_kv_self_seq_cp(self.context.as_ptr(), src, dest, p0, p1);
62        }
63        Ok(())
64    }
65
66    /// Clear the kv cache for the given sequence within the specified range `[p0, p1)`
67    /// Returns `false` only when partial sequence removals fail. Full sequence removals always succeed.
68    ///
69    /// # Returns
70    /// A `Result` indicating whether the operation was successful. If the sequence id or
71    /// either position exceeds the maximum i32 value, no removal is attempted and an `Err` is returned.
72    ///
73    /// # Parameters
74    /// * `src` - The sequence id to clear the cache for. If `None`, matches all sequences
75    /// * `p0` - The start position of the cache to clear. If `None`, the entire cache is cleared up to `p1`.
76    /// * `p1` - The end position of the cache to clear. If `None`, the entire cache is cleared from `p0`.
77    ///
78    /// # Errors
79    /// If the sequence id or either position exceeds [`i32::MAX`].
80    pub fn clear_kv_cache_seq(
81        &mut self,
82        src: Option<u32>,
83        p0: Option<u32>,
84        p1: Option<u32>,
85    ) -> Result<bool, KvCacheConversionError> {
86        let src = src
87            .map_or(Ok(-1), i32::try_from)
88            .map_err(KvCacheConversionError::SeqIdTooLarge)?;
89        let p0 = p0
90            .map_or(Ok(-1), i32::try_from)
91            .map_err(KvCacheConversionError::P0TooLarge)?;
92        let p1 = p1
93            .map_or(Ok(-1), i32::try_from)
94            .map_err(KvCacheConversionError::P1TooLarge)?;
95        Ok(unsafe { llama_cpp_sys_2::llama_kv_self_seq_rm(self.context.as_ptr(), src, p0, p1) })
96    }
97
98    /// Returns the number of used KV cells (i.e. have at least one sequence assigned to them)
99    #[must_use]
100    pub fn get_kv_cache_used_cells(&self) -> i32 {
101        unsafe { llama_cpp_sys_2::llama_kv_self_used_cells(self.context.as_ptr()) }
102    }
103
104    /// Clear the KV cache
105    pub fn clear_kv_cache(&mut self) {
106        unsafe { llama_cpp_sys_2::llama_kv_self_clear(self.context.as_ptr()) }
107    }
108
109    /// Removes all tokens that do not belong to the specified sequence
110    ///
111    /// # Parameters
112    ///
113    /// * `seq_id` - The sequence id to keep
114    pub fn llama_kv_cache_seq_keep(&mut self, seq_id: i32) {
115        unsafe { llama_cpp_sys_2::llama_kv_self_seq_keep(self.context.as_ptr(), seq_id) }
116    }
117
118    #[allow(clippy::doc_markdown)]
119    /// Adds relative position "delta" to all tokens that belong to the specified sequence and have positions in `[p0, p1)`
120    /// If the KV cache is RoPEd, the KV data is updated accordingly:
121    ///   - lazily on next [`LlamaContext::decode`]
122    ///   - explicitly with [`Self::kv_cache_update`]
123    ///
124    /// # Returns
125    /// A `Result` indicating whether the operation was successful.
126    ///
127    /// # Parameters
128    ///
129    /// * `seq_id` - The sequence id to update
130    /// * `p0` - The start position of the cache to update. If `None`, the entire cache is updated up to `p1`.
131    /// * `p1` - The end position of the cache to update. If `None`, the entire cache is updated starting from `p0`.
132    /// * `delta` - The relative position to add to the tokens
133    ///
134    /// # Errors
135    /// If either position exceeds [`i32::MAX`].
136    pub fn kv_cache_seq_add(
137        &mut self,
138        seq_id: i32,
139        p0: Option<u32>,
140        p1: Option<u32>,
141        delta: i32,
142    ) -> Result<(), KvCacheConversionError> {
143        let p0 = p0
144            .map_or(Ok(-1), i32::try_from)
145            .map_err(KvCacheConversionError::P0TooLarge)?;
146        let p1 = p1
147            .map_or(Ok(-1), i32::try_from)
148            .map_err(KvCacheConversionError::P1TooLarge)?;
149        unsafe {
150            llama_cpp_sys_2::llama_kv_self_seq_add(self.context.as_ptr(), seq_id, p0, p1, delta);
151        }
152        Ok(())
153    }
154
155    /// Integer division of the positions by factor of `d > 1`
156    /// If the KV cache is `RoPEd`, the KV data is updated accordingly:
157    ///   - lazily on next [`LlamaContext::decode`]
158    ///   - explicitly with [`Self::kv_cache_update`]
159    ///
160    /// # Returns
161    /// A `Result` indicating whether the operation was successful.
162    ///
163    /// # Parameters
164    ///
165    /// * `seq_id` - The sequence id to update
166    /// * `p0` - The start position of the cache to update. If `None`, the entire cache is updated up to `p1`.
167    /// * `p1` - The end position of the cache to update. If `None`, the entire cache is updated starting from `p0`.
168    /// * `d` - The factor to divide the positions by
169    ///
170    /// # Errors
171    /// If either position exceeds [`i32::MAX`].
172    pub fn kv_cache_seq_div(
173        &mut self,
174        seq_id: i32,
175        p0: Option<u32>,
176        p1: Option<u32>,
177        d: NonZeroU8,
178    ) -> Result<(), KvCacheConversionError> {
179        let p0 = p0
180            .map_or(Ok(-1), i32::try_from)
181            .map_err(KvCacheConversionError::P0TooLarge)?;
182        let p1 = p1
183            .map_or(Ok(-1), i32::try_from)
184            .map_err(KvCacheConversionError::P1TooLarge)?;
185        let d = c_int::from(d.get());
186        unsafe { llama_cpp_sys_2::llama_kv_self_seq_div(self.context.as_ptr(), seq_id, p0, p1, d) }
187        Ok(())
188    }
189
190    /// Returns the largest position present in the KV cache for the specified sequence
191    ///
192    /// # Parameters
193    ///
194    /// * `seq_id` - The sequence id to get the max position for
195    #[must_use]
196    pub fn kv_cache_seq_pos_max(&self, seq_id: i32) -> i32 {
197        unsafe { llama_cpp_sys_2::llama_kv_self_seq_pos_max(self.context.as_ptr(), seq_id) }
198    }
199
200    /// Defragment the KV cache
201    /// This will be applied:
202    ///   - lazily on next [`LlamaContext::decode`]
203    ///   - explicitly with [`Self::kv_cache_update`]
204    pub fn kv_cache_defrag(&mut self) {
205        unsafe { llama_cpp_sys_2::llama_kv_self_defrag(self.context.as_ptr()) }
206    }
207
208    /// Apply the KV cache updates (such as K-shifts, defragmentation, etc.)
209    pub fn kv_cache_update(&mut self) {
210        unsafe { llama_cpp_sys_2::llama_kv_self_update(self.context.as_ptr()) }
211    }
212}