llama_cpp_bindings/context/
kv_cache.rs1use std::ffi::c_int;
2use std::num::{NonZeroU8, TryFromIntError};
3use std::os::raw::c_char;
4use std::ptr;
5
6use crate::context::LlamaContext;
7use crate::error::{KvCacheSeqAddError, KvCacheSeqDivError};
8use crate::ffi_error_reader::read_and_free_cpp_error;
9
10#[derive(Debug, Eq, PartialEq, thiserror::Error)]
11pub enum KvCacheConversionError {
12 #[error("Provided sequence id is too large for a i32")]
13 SeqIdTooLarge(#[source] TryFromIntError),
14 #[error("Provided start position is too large for a i32")]
15 P0TooLarge(#[source] TryFromIntError),
16 #[error("Provided end position is too large for a i32")]
17 P1TooLarge(#[source] TryFromIntError),
18}
19
20impl LlamaContext<'_> {
21 pub fn copy_cache(&mut self, src: i32, dest: i32, size: i32) {
22 let mem = unsafe { llama_cpp_bindings_sys::llama_get_memory(self.context.as_ptr()) };
23 unsafe { llama_cpp_bindings_sys::llama_memory_seq_cp(mem, src, dest, 0, size) }
24 }
25
26 pub fn copy_kv_cache_seq(
29 &mut self,
30 src: i32,
31 dest: i32,
32 p0: Option<u32>,
33 p1: Option<u32>,
34 ) -> Result<(), KvCacheConversionError> {
35 let p0 = p0
36 .map_or(Ok(-1), i32::try_from)
37 .map_err(KvCacheConversionError::P0TooLarge)?;
38 let p1 = p1
39 .map_or(Ok(-1), i32::try_from)
40 .map_err(KvCacheConversionError::P1TooLarge)?;
41 let mem = unsafe { llama_cpp_bindings_sys::llama_get_memory(self.context.as_ptr()) };
42 unsafe { llama_cpp_bindings_sys::llama_memory_seq_cp(mem, src, dest, p0, p1) };
43 Ok(())
44 }
45
46 pub fn clear_kv_cache_seq(
49 &mut self,
50 src: Option<u32>,
51 p0: Option<u32>,
52 p1: Option<u32>,
53 ) -> Result<bool, KvCacheConversionError> {
54 let src = src
55 .map_or(Ok(-1), i32::try_from)
56 .map_err(KvCacheConversionError::SeqIdTooLarge)?;
57 let p0 = p0
58 .map_or(Ok(-1), i32::try_from)
59 .map_err(KvCacheConversionError::P0TooLarge)?;
60 let p1 = p1
61 .map_or(Ok(-1), i32::try_from)
62 .map_err(KvCacheConversionError::P1TooLarge)?;
63 let mem = unsafe { llama_cpp_bindings_sys::llama_get_memory(self.context.as_ptr()) };
64 Ok(unsafe { llama_cpp_bindings_sys::llama_memory_seq_rm(mem, src, p0, p1) })
65 }
66
67 pub fn clear_kv_cache(&mut self) {
68 let mem = unsafe { llama_cpp_bindings_sys::llama_get_memory(self.context.as_ptr()) };
69 let clear_data_buffers = true;
70 unsafe { llama_cpp_bindings_sys::llama_memory_clear(mem, clear_data_buffers) }
71 }
72
73 pub fn kv_cache_seq_keep(&mut self, seq_id: i32) {
74 let mem = unsafe { llama_cpp_bindings_sys::llama_get_memory(self.context.as_ptr()) };
75 unsafe { llama_cpp_bindings_sys::llama_memory_seq_keep(mem, seq_id) }
76 }
77
78 pub fn kv_cache_seq_add(
81 &mut self,
82 seq_id: i32,
83 p0: Option<u32>,
84 p1: Option<u32>,
85 delta: i32,
86 ) -> Result<(), KvCacheSeqAddError> {
87 let p0 = p0
88 .map_or(Ok(-1), i32::try_from)
89 .map_err(KvCacheSeqAddError::P0TooLarge)?;
90 let p1 = p1
91 .map_or(Ok(-1), i32::try_from)
92 .map_err(KvCacheSeqAddError::P1TooLarge)?;
93 let mut out_error: *mut c_char = ptr::null_mut();
94 let status = unsafe {
95 llama_cpp_bindings_sys::llama_rs_memory_seq_add(
96 self.context.as_ptr(),
97 seq_id,
98 p0,
99 p1,
100 delta,
101 &raw mut out_error,
102 )
103 };
104 match status {
105 llama_cpp_bindings_sys::LLAMA_RS_MEMORY_SEQ_ADD_OK => Ok(()),
106 llama_cpp_bindings_sys::LLAMA_RS_MEMORY_SEQ_ADD_INCOMPATIBLE_ROPE_TYPE => {
107 Err(KvCacheSeqAddError::IncompatibleRopeType)
108 }
109 llama_cpp_bindings_sys::LLAMA_RS_MEMORY_SEQ_ADD_NULL_MEM => {
110 Err(KvCacheSeqAddError::MemoryHandleUnavailable)
111 }
112 llama_cpp_bindings_sys::LLAMA_RS_MEMORY_SEQ_ADD_ERROR_STRING_ALLOCATION_FAILED => {
113 Err(KvCacheSeqAddError::NotEnoughMemory)
114 }
115 llama_cpp_bindings_sys::LLAMA_RS_MEMORY_SEQ_ADD_VENDORED_THREW_CXX_EXCEPTION => {
116 let message = unsafe { read_and_free_cpp_error(out_error) };
117 Err(KvCacheSeqAddError::Reported { message })
118 }
119 other => unreachable!("llama_rs_memory_seq_add returned unrecognized status {other}"),
120 }
121 }
122
123 pub fn kv_cache_seq_div(
126 &mut self,
127 seq_id: i32,
128 p0: Option<u32>,
129 p1: Option<u32>,
130 d: NonZeroU8,
131 ) -> Result<(), KvCacheSeqDivError> {
132 let p0 = p0
133 .map_or(Ok(-1), i32::try_from)
134 .map_err(KvCacheSeqDivError::P0TooLarge)?;
135 let p1 = p1
136 .map_or(Ok(-1), i32::try_from)
137 .map_err(KvCacheSeqDivError::P1TooLarge)?;
138 let d = c_int::from(d.get());
139 let mut out_error: *mut c_char = ptr::null_mut();
140 let status = unsafe {
141 llama_cpp_bindings_sys::llama_rs_memory_seq_div(
142 self.context.as_ptr(),
143 seq_id,
144 p0,
145 p1,
146 d,
147 &raw mut out_error,
148 )
149 };
150 match status {
151 llama_cpp_bindings_sys::LLAMA_RS_MEMORY_SEQ_DIV_OK => Ok(()),
152 llama_cpp_bindings_sys::LLAMA_RS_MEMORY_SEQ_DIV_INCOMPATIBLE_ROPE_TYPE => {
153 Err(KvCacheSeqDivError::IncompatibleRopeType)
154 }
155 llama_cpp_bindings_sys::LLAMA_RS_MEMORY_SEQ_DIV_NULL_MEM => {
156 Err(KvCacheSeqDivError::MemoryHandleUnavailable)
157 }
158 llama_cpp_bindings_sys::LLAMA_RS_MEMORY_SEQ_DIV_ERROR_STRING_ALLOCATION_FAILED => {
159 Err(KvCacheSeqDivError::NotEnoughMemory)
160 }
161 llama_cpp_bindings_sys::LLAMA_RS_MEMORY_SEQ_DIV_VENDORED_THREW_CXX_EXCEPTION => {
162 let message = unsafe { read_and_free_cpp_error(out_error) };
163 Err(KvCacheSeqDivError::Reported { message })
164 }
165 other => unreachable!("llama_rs_memory_seq_div returned unrecognized status {other}"),
166 }
167 }
168
169 #[must_use]
170 pub fn kv_cache_seq_pos_max(&self, seq_id: i32) -> i32 {
171 unsafe {
172 llama_cpp_bindings_sys::llama_rs_memory_seq_pos_max(self.context.as_ptr(), seq_id)
173 }
174 }
175}