Skip to main content

llama_cpp_4/context/
kv_cache.rs

1//! utilities for working with the kv cache
2
3use crate::context::LlamaContext;
4use std::num::{NonZeroU8, TryFromIntError};
5
6/// Errors that can occur when attempting to prepare values for the kv cache
7#[derive(Debug, Eq, PartialEq, thiserror::Error)]
8pub enum KvCacheConversionError {
9    /// Sequence id conversion to i32 failed
10    #[error("Provided sequence id is too large for a i32")]
11    SeqIdTooLarge(#[source] TryFromIntError),
12    /// Position 0 conversion to i32 failed
13    #[error("Provided start position is too large for a i32")]
14    P0TooLarge(#[source] TryFromIntError),
15    /// Position 1 conversion to i32 failed
16    #[error("Provided end position is too large for a i32")]
17    P1TooLarge(#[source] TryFromIntError),
18}
19
20impl LlamaContext<'_> {
21    /// Copy the cache from one sequence to another.
22    ///
23    /// # Parameters
24    ///
25    /// * `src` - The sequence id to copy the cache from.
26    /// * `dest` - The sequence id to copy the cache to.
27    /// * `size` - The size of the cache to copy.
28    pub fn copy_cache(&mut self, src: i32, dest: i32, size: i32) {
29        unsafe {
30            let mem = llama_cpp_sys_4::llama_get_memory(self.context.as_ptr());
31            llama_cpp_sys_4::llama_memory_seq_cp(mem, src, dest, 0, size);
32        }
33    }
34
35    /// Copy the cache from one sequence to another.
36    ///
37    /// # Parameters
38    ///
39    /// * `src` - The sequence id to copy the cache from.
40    /// * `dest` - The sequence id to copy the cache to.
41    /// * `p0` - The start position of the cache to clear. If `None`, the entire cache is copied up to `p1`.
42    /// * `p1` - The end position of the cache to clear. If `None`, the entire cache is copied starting from `p0`.
43    ///
44    /// # Errors
45    ///
46    /// Returns [`KvCacheConversionError`] if either position exceeds the maximum `i32` value.
47    pub fn copy_kv_cache_seq(
48        &mut self,
49        src: i32,
50        dest: i32,
51        p0: Option<u32>,
52        p1: Option<u32>,
53    ) -> Result<(), KvCacheConversionError> {
54        let p0 = p0
55            .map_or(Ok(-1), i32::try_from)
56            .map_err(KvCacheConversionError::P0TooLarge)?;
57        let p1 = p1
58            .map_or(Ok(-1), i32::try_from)
59            .map_err(KvCacheConversionError::P1TooLarge)?;
60        unsafe {
61            let mem = llama_cpp_sys_4::llama_get_memory(self.context.as_ptr());
62            llama_cpp_sys_4::llama_memory_seq_cp(mem, src, dest, p0, p1);
63        }
64        Ok(())
65    }
66
67    /// Clear the kv cache for the given sequence within the specified range `[p0, p1)`.
68    ///
69    /// Returns `false` only when partial sequence removals fail. Full sequence removals always succeed.
70    ///
71    /// # Parameters
72    ///
73    /// * `src` - The sequence id to clear the cache for. If `None`, matches all sequences
74    /// * `p0` - The start position of the cache to clear. If `None`, the entire cache is cleared up to `p1`.
75    /// * `p1` - The end position of the cache to clear. If `None`, the entire cache is cleared from `p0`.
76    ///
77    /// # Errors
78    ///
79    /// Returns [`KvCacheConversionError`] if the sequence id or either position exceeds the maximum `i32` value.
80    pub fn clear_kv_cache_seq(
81        &mut self,
82        src: Option<u32>,
83        p0: Option<u32>,
84        p1: Option<u32>,
85    ) -> Result<bool, KvCacheConversionError> {
86        let src = src
87            .map_or(Ok(-1), i32::try_from)
88            .map_err(KvCacheConversionError::SeqIdTooLarge)?;
89        let p0 = p0
90            .map_or(Ok(-1), i32::try_from)
91            .map_err(KvCacheConversionError::P0TooLarge)?;
92        let p1 = p1
93            .map_or(Ok(-1), i32::try_from)
94            .map_err(KvCacheConversionError::P1TooLarge)?;
95        Ok(unsafe {
96            let mem = llama_cpp_sys_4::llama_get_memory(self.context.as_ptr());
97            llama_cpp_sys_4::llama_memory_seq_rm(mem, src, p0, p1)
98        })
99    }
100
101    /// Clear the KV cache
102    pub fn clear_kv_cache(&mut self) {
103        unsafe {
104            let mem = llama_cpp_sys_4::llama_get_memory(self.context.as_ptr());
105            llama_cpp_sys_4::llama_memory_clear(mem, true);
106        }
107    }
108
109    /// Removes all tokens that do not belong to the specified sequence
110    ///
111    /// # Parameters
112    ///
113    /// * `seq_id` - The sequence id to keep
114    pub fn llama_kv_cache_seq_keep(&mut self, seq_id: i32) {
115        unsafe {
116            let mem = llama_cpp_sys_4::llama_get_memory(self.context.as_ptr());
117            llama_cpp_sys_4::llama_memory_seq_keep(mem, seq_id);
118        }
119    }
120
121    #[allow(clippy::doc_markdown)]
122    /// Adds relative position "delta" to all tokens that belong to the specified sequence and have positions in `[p0, p1)`.
123    ///
124    /// # Parameters
125    ///
126    /// * `seq_id` - The sequence id to update
127    /// * `p0` - The start position of the cache to update. If `None`, the entire cache is updated up to `p1`.
128    /// * `p1` - The end position of the cache to update. If `None`, the entire cache is updated starting from `p0`.
129    /// * `delta` - The relative position to add to the tokens
130    ///
131    /// # Errors
132    ///
133    /// Returns [`KvCacheConversionError`] if either position exceeds the maximum `i32` value.
134    pub fn kv_cache_seq_add(
135        &mut self,
136        seq_id: i32,
137        p0: Option<u32>,
138        p1: Option<u32>,
139        delta: i32,
140    ) -> Result<(), KvCacheConversionError> {
141        let p0 = p0
142            .map_or(Ok(-1), i32::try_from)
143            .map_err(KvCacheConversionError::P0TooLarge)?;
144        let p1 = p1
145            .map_or(Ok(-1), i32::try_from)
146            .map_err(KvCacheConversionError::P1TooLarge)?;
147        unsafe {
148            let mem = llama_cpp_sys_4::llama_get_memory(self.context.as_ptr());
149            llama_cpp_sys_4::llama_memory_seq_add(mem, seq_id, p0, p1, delta);
150        }
151        Ok(())
152    }
153
154    /// Integer division of the positions by factor of `d > 1`.
155    ///
156    /// # Parameters
157    ///
158    /// * `seq_id` - The sequence id to update
159    /// * `p0` - The start position of the cache to update. If `None`, the entire cache is updated up to `p1`.
160    /// * `p1` - The end position of the cache to update. If `None`, the entire cache is updated starting from `p0`.
161    /// * `d` - The factor to divide the positions by
162    ///
163    /// # Errors
164    ///
165    /// Returns [`KvCacheConversionError`] if either position exceeds the maximum `i32` value.
166    pub fn kv_cache_seq_div(
167        &mut self,
168        seq_id: i32,
169        p0: Option<u32>,
170        p1: Option<u32>,
171        d: NonZeroU8,
172    ) -> Result<(), KvCacheConversionError> {
173        let p0 = p0
174            .map_or(Ok(-1), i32::try_from)
175            .map_err(KvCacheConversionError::P0TooLarge)?;
176        let p1 = p1
177            .map_or(Ok(-1), i32::try_from)
178            .map_err(KvCacheConversionError::P1TooLarge)?;
179        let d = i32::from(d.get());
180        unsafe {
181            let mem = llama_cpp_sys_4::llama_get_memory(self.context.as_ptr());
182            llama_cpp_sys_4::llama_memory_seq_div(mem, seq_id, p0, p1, d);
183        }
184        Ok(())
185    }
186
187    /// Returns the largest position present in the KV cache for the specified sequence
188    ///
189    /// # Parameters
190    ///
191    /// * `seq_id` - The sequence id to get the max position for
192    #[must_use]
193    pub fn kv_cache_seq_pos_max(&self, seq_id: i32) -> i32 {
194        unsafe {
195            let mem = llama_cpp_sys_4::llama_get_memory(self.context.as_ptr());
196            llama_cpp_sys_4::llama_memory_seq_pos_max(mem, seq_id)
197        }
198    }
199}