Skip to main content

llama_cpp_bindings/context/
kv_cache.rs

//! Utilities for working with the KV cache.
2
3use crate::context::LlamaContext;
4use std::ffi::c_int;
5use std::num::{NonZeroU8, TryFromIntError};
6
7/// Errors that can occur when attempting to prepare values for the kv cache
8#[derive(Debug, Eq, PartialEq, thiserror::Error)]
9pub enum KvCacheConversionError {
10    /// Sequence id conversion to i32 failed
11    #[error("Provided sequence id is too large for a i32")]
12    SeqIdTooLarge(#[source] TryFromIntError),
13    /// Position 0 conversion to i32 failed
14    #[error("Provided start position is too large for a i32")]
15    P0TooLarge(#[source] TryFromIntError),
16    /// Position 1 conversion to i32 failed
17    #[error("Provided end position is too large for a i32")]
18    P1TooLarge(#[source] TryFromIntError),
19}
20
21impl LlamaContext<'_> {
22    /// Copy the cache from one sequence to another.
23    ///
24    /// # Parameters
25    ///
26    /// * `src` - The sequence id to copy the cache from.
27    /// * `dest` - The sequence id to copy the cache to.
28    /// * `size` - The size of the cache to copy.
29    pub fn copy_cache(&mut self, src: i32, dest: i32, size: i32) {
30        let mem = unsafe { llama_cpp_bindings_sys::llama_get_memory(self.context.as_ptr()) };
31        unsafe { llama_cpp_bindings_sys::llama_memory_seq_cp(mem, src, dest, 0, size) }
32    }
33
34    /// Copy the cache from one sequence to another.
35    ///
36    /// # Returns
37    /// A `Result` indicating whether the operation was successful.
38    ///
39    /// # Parameters
40    /// * `src` - The sequence id to copy the cache from.
41    /// * `dest` - The sequence id to copy the cache to.
42    /// * `p0` - The start position of the cache to clear. If `None`, the entire cache is copied up to `p1`.
43    /// * `p1` - The end position of the cache to clear. If `None`, the entire cache is copied starting from `p0`.
44    ///
45    /// # Errors
46    /// If either position exceeds [`i32::MAX`].
47    pub fn copy_kv_cache_seq(
48        &mut self,
49        src: i32,
50        dest: i32,
51        p0: Option<u32>,
52        p1: Option<u32>,
53    ) -> Result<(), KvCacheConversionError> {
54        let p0 = p0
55            .map_or(Ok(-1), i32::try_from)
56            .map_err(KvCacheConversionError::P0TooLarge)?;
57        let p1 = p1
58            .map_or(Ok(-1), i32::try_from)
59            .map_err(KvCacheConversionError::P1TooLarge)?;
60        let mem = unsafe { llama_cpp_bindings_sys::llama_get_memory(self.context.as_ptr()) };
61        unsafe { llama_cpp_bindings_sys::llama_memory_seq_cp(mem, src, dest, p0, p1) };
62        Ok(())
63    }
64
65    /// Clear the kv cache for the given sequence within the specified range `[p0, p1)`
66    /// Returns `false` only when partial sequence removals fail. Full sequence removals always succeed.
67    ///
68    /// # Returns
69    /// A `Result` indicating whether the operation was successful. If the sequence id or
70    /// either position exceeds the maximum i32 value, no removal is attempted and an `Err` is returned.
71    ///
72    /// # Parameters
73    /// * `src` - The sequence id to clear the cache for. If `None`, matches all sequences
74    /// * `p0` - The start position of the cache to clear. If `None`, the entire cache is cleared up to `p1`.
75    /// * `p1` - The end position of the cache to clear. If `None`, the entire cache is cleared from `p0`.
76    ///
77    /// # Errors
78    /// If the sequence id or either position exceeds [`i32::MAX`].
79    pub fn clear_kv_cache_seq(
80        &mut self,
81        src: Option<u32>,
82        p0: Option<u32>,
83        p1: Option<u32>,
84    ) -> Result<bool, KvCacheConversionError> {
85        let src = src
86            .map_or(Ok(-1), i32::try_from)
87            .map_err(KvCacheConversionError::SeqIdTooLarge)?;
88        let p0 = p0
89            .map_or(Ok(-1), i32::try_from)
90            .map_err(KvCacheConversionError::P0TooLarge)?;
91        let p1 = p1
92            .map_or(Ok(-1), i32::try_from)
93            .map_err(KvCacheConversionError::P1TooLarge)?;
94        let mem = unsafe { llama_cpp_bindings_sys::llama_get_memory(self.context.as_ptr()) };
95        Ok(unsafe { llama_cpp_bindings_sys::llama_memory_seq_rm(mem, src, p0, p1) })
96    }
97
98    /// Clear the KV cache
99    pub fn clear_kv_cache(&mut self) {
100        let mem = unsafe { llama_cpp_bindings_sys::llama_get_memory(self.context.as_ptr()) };
101        // clear both metadata and data buffers to match previous semantics
102        unsafe { llama_cpp_bindings_sys::llama_memory_clear(mem, true) }
103    }
104
105    /// Removes all tokens that do not belong to the specified sequence
106    ///
107    /// # Parameters
108    ///
109    /// * `seq_id` - The sequence id to keep
110    pub fn llama_kv_cache_seq_keep(&mut self, seq_id: i32) {
111        let mem = unsafe { llama_cpp_bindings_sys::llama_get_memory(self.context.as_ptr()) };
112        unsafe { llama_cpp_bindings_sys::llama_memory_seq_keep(mem, seq_id) }
113    }
114
115    /// Adds relative position "delta" to all tokens that belong to the specified sequence and have positions in `[p0, p1)`
116    /// If the KV cache is `RoPEd`, the KV data is updated accordingly:
117    ///   - lazily on next [`LlamaContext::decode`]
118    ///   - explicitly with [`Self::kv_cache_update`]
119    ///
120    /// # Returns
121    /// A `Result` indicating whether the operation was successful.
122    ///
123    /// # Parameters
124    ///
125    /// * `seq_id` - The sequence id to update
126    /// * `p0` - The start position of the cache to update. If `None`, the entire cache is updated up to `p1`.
127    /// * `p1` - The end position of the cache to update. If `None`, the entire cache is updated starting from `p0`.
128    /// * `delta` - The relative position to add to the tokens
129    ///
130    /// # Errors
131    /// If either position exceeds [`i32::MAX`].
132    pub fn kv_cache_seq_add(
133        &mut self,
134        seq_id: i32,
135        p0: Option<u32>,
136        p1: Option<u32>,
137        delta: i32,
138    ) -> Result<(), KvCacheConversionError> {
139        let p0 = p0
140            .map_or(Ok(-1), i32::try_from)
141            .map_err(KvCacheConversionError::P0TooLarge)?;
142        let p1 = p1
143            .map_or(Ok(-1), i32::try_from)
144            .map_err(KvCacheConversionError::P1TooLarge)?;
145        let mem = unsafe { llama_cpp_bindings_sys::llama_get_memory(self.context.as_ptr()) };
146        unsafe { llama_cpp_bindings_sys::llama_memory_seq_add(mem, seq_id, p0, p1, delta) };
147        Ok(())
148    }
149
150    /// Integer division of the positions by factor of `d > 1`
151    /// If the KV cache is `RoPEd`, the KV data is updated accordingly:
152    ///   - lazily on next [`LlamaContext::decode`]
153    ///   - explicitly with [`Self::kv_cache_update`]
154    ///
155    /// # Returns
156    /// A `Result` indicating whether the operation was successful.
157    ///
158    /// # Parameters
159    ///
160    /// * `seq_id` - The sequence id to update
161    /// * `p0` - The start position of the cache to update. If `None`, the entire cache is updated up to `p1`.
162    /// * `p1` - The end position of the cache to update. If `None`, the entire cache is updated starting from `p0`.
163    /// * `d` - The factor to divide the positions by
164    ///
165    /// # Errors
166    /// If either position exceeds [`i32::MAX`].
167    pub fn kv_cache_seq_div(
168        &mut self,
169        seq_id: i32,
170        p0: Option<u32>,
171        p1: Option<u32>,
172        d: NonZeroU8,
173    ) -> Result<(), KvCacheConversionError> {
174        let p0 = p0
175            .map_or(Ok(-1), i32::try_from)
176            .map_err(KvCacheConversionError::P0TooLarge)?;
177        let p1 = p1
178            .map_or(Ok(-1), i32::try_from)
179            .map_err(KvCacheConversionError::P1TooLarge)?;
180        let d = c_int::from(d.get());
181        let mem = unsafe { llama_cpp_bindings_sys::llama_get_memory(self.context.as_ptr()) };
182        unsafe { llama_cpp_bindings_sys::llama_memory_seq_div(mem, seq_id, p0, p1, d) }
183        Ok(())
184    }
185
186    /// Returns the largest position present in the KV cache for the specified sequence
187    ///
188    /// # Parameters
189    ///
190    /// * `seq_id` - The sequence id to get the max position for
191    #[must_use]
192    pub fn kv_cache_seq_pos_max(&self, seq_id: i32) -> i32 {
193        let mem = unsafe { llama_cpp_bindings_sys::llama_get_memory(self.context.as_ptr()) };
194        unsafe { llama_cpp_bindings_sys::llama_memory_seq_pos_max(mem, seq_id) }
195    }
196}