// Source: llama_cpp_4/context/kv_cache.rs

//! Utilities for working with the KV cache.
2
3use crate::context::LlamaContext;
4use std::num::{NonZeroU8, TryFromIntError};
5
6/// Errors that can occur when attempting to prepare values for the kv cache
7#[derive(Debug, Eq, PartialEq, thiserror::Error)]
8pub enum KvCacheConversionError {
9    /// Sequence id conversion to i32 failed
10    #[error("Provided sequence id is too large for a i32")]
11    SeqIdTooLarge(#[source] TryFromIntError),
12    /// Position 0 conversion to i32 failed
13    #[error("Provided start position is too large for a i32")]
14    P0TooLarge(#[source] TryFromIntError),
15    /// Position 1 conversion to i32 failed
16    #[error("Provided end position is too large for a i32")]
17    P1TooLarge(#[source] TryFromIntError),
18}
19
20impl LlamaContext<'_> {
21    /// Copy the cache from one sequence to another.
22    ///
23    /// # Parameters
24    ///
25    /// * `src` - The sequence id to copy the cache from.
26    /// * `dest` - The sequence id to copy the cache to.
27    /// * `size` - The size of the cache to copy.
28    pub fn copy_cache(&mut self, src: i32, dest: i32, size: i32) {
29        unsafe {
30            let mem = llama_cpp_sys_4::llama_get_memory(self.context.as_ptr());
31            llama_cpp_sys_4::llama_memory_seq_cp(mem, src, dest, 0, size);
32        }
33    }
34
35    /// Copy the cache from one sequence to another.
36    ///
37    /// # Parameters
38    ///
39    /// * `src` - The sequence id to copy the cache from.
40    /// * `dest` - The sequence id to copy the cache to.
41    /// * `p0` - The start position of the cache to clear. If `None`, the entire cache is copied up to `p1`.
42    /// * `p1` - The end position of the cache to clear. If `None`, the entire cache is copied starting from `p0`.
43    ///
44    /// # Errors
45    ///
46    /// Returns [`KvCacheConversionError`] if either position exceeds the maximum `i32` value.
47    pub fn copy_kv_cache_seq(
48        &mut self,
49        src: i32,
50        dest: i32,
51        p0: Option<u32>,
52        p1: Option<u32>,
53    ) -> Result<(), KvCacheConversionError> {
54        let p0 = p0
55            .map_or(Ok(-1), i32::try_from)
56            .map_err(KvCacheConversionError::P0TooLarge)?;
57        let p1 = p1
58            .map_or(Ok(-1), i32::try_from)
59            .map_err(KvCacheConversionError::P1TooLarge)?;
60        unsafe {
61            let mem = llama_cpp_sys_4::llama_get_memory(self.context.as_ptr());
62            llama_cpp_sys_4::llama_memory_seq_cp(mem, src, dest, p0, p1);
63        }
64        Ok(())
65    }
66
67    /// Clear the kv cache for the given sequence within the specified range `[p0, p1)`.
68    ///
69    /// Returns `false` only when partial sequence removals fail. Full sequence removals always succeed.
70    ///
71    /// # Parameters
72    ///
73    /// * `src` - The sequence id to clear the cache for. If `None`, matches all sequences
74    /// * `p0` - The start position of the cache to clear. If `None`, the entire cache is cleared up to `p1`.
75    /// * `p1` - The end position of the cache to clear. If `None`, the entire cache is cleared from `p0`.
76    ///
77    /// # Errors
78    ///
79    /// Returns [`KvCacheConversionError`] if the sequence id or either position exceeds the maximum `i32` value.
80    pub fn clear_kv_cache_seq(
81        &mut self,
82        src: Option<u32>,
83        p0: Option<u32>,
84        p1: Option<u32>,
85    ) -> Result<bool, KvCacheConversionError> {
86        let src = src
87            .map_or(Ok(-1), i32::try_from)
88            .map_err(KvCacheConversionError::SeqIdTooLarge)?;
89        let p0 = p0
90            .map_or(Ok(-1), i32::try_from)
91            .map_err(KvCacheConversionError::P0TooLarge)?;
92        let p1 = p1
93            .map_or(Ok(-1), i32::try_from)
94            .map_err(KvCacheConversionError::P1TooLarge)?;
95        #[cfg(feature = "mtp")]
96        {
97            Ok(
98                unsafe {
99                    llama_cpp_sys_4::llama_context_seq_rm(self.context.as_ptr(), src, p0, p1)
100                },
101            )
102        }
103        #[cfg(not(feature = "mtp"))]
104        {
105            Ok(unsafe {
106                let mem = llama_cpp_sys_4::llama_get_memory(self.context.as_ptr());
107                llama_cpp_sys_4::llama_memory_seq_rm(mem, src, p0, p1)
108            })
109        }
110    }
111
112    /// Clear the KV cache
113    pub fn clear_kv_cache(&mut self) {
114        unsafe {
115            let mem = llama_cpp_sys_4::llama_get_memory(self.context.as_ptr());
116            llama_cpp_sys_4::llama_memory_clear(mem, true);
117        }
118    }
119
120    /// Removes all tokens that do not belong to the specified sequence
121    ///
122    /// # Parameters
123    ///
124    /// * `seq_id` - The sequence id to keep
125    pub fn llama_kv_cache_seq_keep(&mut self, seq_id: i32) {
126        unsafe {
127            let mem = llama_cpp_sys_4::llama_get_memory(self.context.as_ptr());
128            llama_cpp_sys_4::llama_memory_seq_keep(mem, seq_id);
129        }
130    }
131
132    #[allow(clippy::doc_markdown)]
133    /// Adds relative position "delta" to all tokens that belong to the specified sequence and have positions in `[p0, p1)`.
134    ///
135    /// # Parameters
136    ///
137    /// * `seq_id` - The sequence id to update
138    /// * `p0` - The start position of the cache to update. If `None`, the entire cache is updated up to `p1`.
139    /// * `p1` - The end position of the cache to update. If `None`, the entire cache is updated starting from `p0`.
140    /// * `delta` - The relative position to add to the tokens
141    ///
142    /// # Errors
143    ///
144    /// Returns [`KvCacheConversionError`] if either position exceeds the maximum `i32` value.
145    pub fn kv_cache_seq_add(
146        &mut self,
147        seq_id: i32,
148        p0: Option<u32>,
149        p1: Option<u32>,
150        delta: i32,
151    ) -> Result<(), KvCacheConversionError> {
152        let p0 = p0
153            .map_or(Ok(-1), i32::try_from)
154            .map_err(KvCacheConversionError::P0TooLarge)?;
155        let p1 = p1
156            .map_or(Ok(-1), i32::try_from)
157            .map_err(KvCacheConversionError::P1TooLarge)?;
158        unsafe {
159            let mem = llama_cpp_sys_4::llama_get_memory(self.context.as_ptr());
160            llama_cpp_sys_4::llama_memory_seq_add(mem, seq_id, p0, p1, delta);
161        }
162        Ok(())
163    }
164
165    /// Integer division of the positions by factor of `d > 1`.
166    ///
167    /// # Parameters
168    ///
169    /// * `seq_id` - The sequence id to update
170    /// * `p0` - The start position of the cache to update. If `None`, the entire cache is updated up to `p1`.
171    /// * `p1` - The end position of the cache to update. If `None`, the entire cache is updated starting from `p0`.
172    /// * `d` - The factor to divide the positions by
173    ///
174    /// # Errors
175    ///
176    /// Returns [`KvCacheConversionError`] if either position exceeds the maximum `i32` value.
177    pub fn kv_cache_seq_div(
178        &mut self,
179        seq_id: i32,
180        p0: Option<u32>,
181        p1: Option<u32>,
182        d: NonZeroU8,
183    ) -> Result<(), KvCacheConversionError> {
184        let p0 = p0
185            .map_or(Ok(-1), i32::try_from)
186            .map_err(KvCacheConversionError::P0TooLarge)?;
187        let p1 = p1
188            .map_or(Ok(-1), i32::try_from)
189            .map_err(KvCacheConversionError::P1TooLarge)?;
190        let d = i32::from(d.get());
191        unsafe {
192            let mem = llama_cpp_sys_4::llama_get_memory(self.context.as_ptr());
193            llama_cpp_sys_4::llama_memory_seq_div(mem, seq_id, p0, p1, d);
194        }
195        Ok(())
196    }
197
198    /// Returns the largest position present in the KV cache for the specified sequence
199    ///
200    /// # Parameters
201    ///
202    /// * `seq_id` - The sequence id to get the max position for
203    #[must_use]
204    pub fn kv_cache_seq_pos_max(&self, seq_id: i32) -> i32 {
205        unsafe {
206            let mem = llama_cpp_sys_4::llama_get_memory(self.context.as_ptr());
207            llama_cpp_sys_4::llama_memory_seq_pos_max(mem, seq_id)
208        }
209    }
210}