llama_cpp_bindings/context/
kv_cache.rs1use std::ffi::c_int;
4use std::num::{NonZeroU8, TryFromIntError};
5use std::os::raw::c_char;
6use std::ptr;
7
8use crate::context::LlamaContext;
9use crate::error::{KvCacheSeqAddError, KvCacheSeqDivError};
10use crate::ffi_error_reader::read_and_free_cpp_error;
11
12#[derive(Debug, Eq, PartialEq, thiserror::Error)]
14pub enum KvCacheConversionError {
15 #[error("Provided sequence id is too large for a i32")]
17 SeqIdTooLarge(#[source] TryFromIntError),
18 #[error("Provided start position is too large for a i32")]
20 P0TooLarge(#[source] TryFromIntError),
21 #[error("Provided end position is too large for a i32")]
23 P1TooLarge(#[source] TryFromIntError),
24}
25
26impl LlamaContext<'_> {
27 pub fn copy_cache(&mut self, src: i32, dest: i32, size: i32) {
35 let mem = unsafe { llama_cpp_bindings_sys::llama_get_memory(self.context.as_ptr()) };
36 unsafe { llama_cpp_bindings_sys::llama_memory_seq_cp(mem, src, dest, 0, size) }
37 }
38
39 pub fn copy_kv_cache_seq(
53 &mut self,
54 src: i32,
55 dest: i32,
56 p0: Option<u32>,
57 p1: Option<u32>,
58 ) -> Result<(), KvCacheConversionError> {
59 let p0 = p0
60 .map_or(Ok(-1), i32::try_from)
61 .map_err(KvCacheConversionError::P0TooLarge)?;
62 let p1 = p1
63 .map_or(Ok(-1), i32::try_from)
64 .map_err(KvCacheConversionError::P1TooLarge)?;
65 let mem = unsafe { llama_cpp_bindings_sys::llama_get_memory(self.context.as_ptr()) };
66 unsafe { llama_cpp_bindings_sys::llama_memory_seq_cp(mem, src, dest, p0, p1) };
67 Ok(())
68 }
69
70 pub fn clear_kv_cache_seq(
85 &mut self,
86 src: Option<u32>,
87 p0: Option<u32>,
88 p1: Option<u32>,
89 ) -> Result<bool, KvCacheConversionError> {
90 let src = src
91 .map_or(Ok(-1), i32::try_from)
92 .map_err(KvCacheConversionError::SeqIdTooLarge)?;
93 let p0 = p0
94 .map_or(Ok(-1), i32::try_from)
95 .map_err(KvCacheConversionError::P0TooLarge)?;
96 let p1 = p1
97 .map_or(Ok(-1), i32::try_from)
98 .map_err(KvCacheConversionError::P1TooLarge)?;
99 let mem = unsafe { llama_cpp_bindings_sys::llama_get_memory(self.context.as_ptr()) };
100 Ok(unsafe { llama_cpp_bindings_sys::llama_memory_seq_rm(mem, src, p0, p1) })
101 }
102
103 pub fn clear_kv_cache(&mut self) {
105 let mem = unsafe { llama_cpp_bindings_sys::llama_get_memory(self.context.as_ptr()) };
106 let clear_data_buffers = true;
107 unsafe { llama_cpp_bindings_sys::llama_memory_clear(mem, clear_data_buffers) }
108 }
109
110 pub fn kv_cache_seq_keep(&mut self, seq_id: i32) {
116 let mem = unsafe { llama_cpp_bindings_sys::llama_get_memory(self.context.as_ptr()) };
117 unsafe { llama_cpp_bindings_sys::llama_memory_seq_keep(mem, seq_id) }
118 }
119
120 pub fn kv_cache_seq_add(
138 &mut self,
139 seq_id: i32,
140 p0: Option<u32>,
141 p1: Option<u32>,
142 delta: i32,
143 ) -> Result<(), KvCacheSeqAddError> {
144 let p0 = p0
145 .map_or(Ok(-1), i32::try_from)
146 .map_err(KvCacheSeqAddError::P0TooLarge)?;
147 let p1 = p1
148 .map_or(Ok(-1), i32::try_from)
149 .map_err(KvCacheSeqAddError::P1TooLarge)?;
150 let mut out_error: *mut c_char = ptr::null_mut();
151 let status = unsafe {
152 llama_cpp_bindings_sys::llama_rs_memory_seq_add(
153 self.context.as_ptr(),
154 seq_id,
155 p0,
156 p1,
157 delta,
158 &raw mut out_error,
159 )
160 };
161 match status {
162 llama_cpp_bindings_sys::LLAMA_RS_MEMORY_SEQ_ADD_OK => Ok(()),
163 llama_cpp_bindings_sys::LLAMA_RS_MEMORY_SEQ_ADD_INCOMPATIBLE_ROPE_TYPE => {
164 Err(KvCacheSeqAddError::IncompatibleRopeType)
165 }
166 llama_cpp_bindings_sys::LLAMA_RS_MEMORY_SEQ_ADD_NULL_MEM => {
167 Err(KvCacheSeqAddError::MemoryHandleUnavailable)
168 }
169 llama_cpp_bindings_sys::LLAMA_RS_MEMORY_SEQ_ADD_ERROR_STRING_ALLOCATION_FAILED => {
170 Err(KvCacheSeqAddError::NotEnoughMemory)
171 }
172 llama_cpp_bindings_sys::LLAMA_RS_MEMORY_SEQ_ADD_VENDORED_THREW_CXX_EXCEPTION => {
173 let message = unsafe { read_and_free_cpp_error(out_error) };
174 Err(KvCacheSeqAddError::Reported { message })
175 }
176 other => unreachable!("llama_rs_memory_seq_add returned unrecognized status {other}"),
177 }
178 }
179
180 pub fn kv_cache_seq_div(
198 &mut self,
199 seq_id: i32,
200 p0: Option<u32>,
201 p1: Option<u32>,
202 d: NonZeroU8,
203 ) -> Result<(), KvCacheSeqDivError> {
204 let p0 = p0
205 .map_or(Ok(-1), i32::try_from)
206 .map_err(KvCacheSeqDivError::P0TooLarge)?;
207 let p1 = p1
208 .map_or(Ok(-1), i32::try_from)
209 .map_err(KvCacheSeqDivError::P1TooLarge)?;
210 let d = c_int::from(d.get());
211 let mut out_error: *mut c_char = ptr::null_mut();
212 let status = unsafe {
213 llama_cpp_bindings_sys::llama_rs_memory_seq_div(
214 self.context.as_ptr(),
215 seq_id,
216 p0,
217 p1,
218 d,
219 &raw mut out_error,
220 )
221 };
222 match status {
223 llama_cpp_bindings_sys::LLAMA_RS_MEMORY_SEQ_DIV_OK => Ok(()),
224 llama_cpp_bindings_sys::LLAMA_RS_MEMORY_SEQ_DIV_INCOMPATIBLE_ROPE_TYPE => {
225 Err(KvCacheSeqDivError::IncompatibleRopeType)
226 }
227 llama_cpp_bindings_sys::LLAMA_RS_MEMORY_SEQ_DIV_NULL_MEM => {
228 Err(KvCacheSeqDivError::MemoryHandleUnavailable)
229 }
230 llama_cpp_bindings_sys::LLAMA_RS_MEMORY_SEQ_DIV_ERROR_STRING_ALLOCATION_FAILED => {
231 Err(KvCacheSeqDivError::NotEnoughMemory)
232 }
233 llama_cpp_bindings_sys::LLAMA_RS_MEMORY_SEQ_DIV_VENDORED_THREW_CXX_EXCEPTION => {
234 let message = unsafe { read_and_free_cpp_error(out_error) };
235 Err(KvCacheSeqDivError::Reported { message })
236 }
237 other => unreachable!("llama_rs_memory_seq_div returned unrecognized status {other}"),
238 }
239 }
240
241 #[must_use]
247 pub fn kv_cache_seq_pos_max(&self, seq_id: i32) -> i32 {
248 unsafe {
249 llama_cpp_bindings_sys::llama_rs_memory_seq_pos_max(self.context.as_ptr(), seq_id)
250 }
251 }
252}