Skip to main content

llama_cpp_bindings/context/
kv_cache.rs

1use std::ffi::c_int;
2use std::num::{NonZeroU8, TryFromIntError};
3use std::os::raw::c_char;
4use std::ptr;
5
6use crate::context::LlamaContext;
7use crate::error::{KvCacheSeqAddError, KvCacheSeqDivError};
8use crate::ffi_error_reader::read_and_free_cpp_error;
9
10#[derive(Debug, Eq, PartialEq, thiserror::Error)]
11pub enum KvCacheConversionError {
12    #[error("Provided sequence id is too large for a i32")]
13    SeqIdTooLarge(#[source] TryFromIntError),
14    #[error("Provided start position is too large for a i32")]
15    P0TooLarge(#[source] TryFromIntError),
16    #[error("Provided end position is too large for a i32")]
17    P1TooLarge(#[source] TryFromIntError),
18}
19
20impl LlamaContext<'_> {
21    pub fn copy_cache(&mut self, src: i32, dest: i32, size: i32) {
22        let mem = unsafe { llama_cpp_bindings_sys::llama_get_memory(self.context.as_ptr()) };
23        unsafe { llama_cpp_bindings_sys::llama_memory_seq_cp(mem, src, dest, 0, size) }
24    }
25
26    /// # Errors
27    /// If either position exceeds [`i32::MAX`].
28    pub fn copy_kv_cache_seq(
29        &mut self,
30        src: i32,
31        dest: i32,
32        p0: Option<u32>,
33        p1: Option<u32>,
34    ) -> Result<(), KvCacheConversionError> {
35        let p0 = p0
36            .map_or(Ok(-1), i32::try_from)
37            .map_err(KvCacheConversionError::P0TooLarge)?;
38        let p1 = p1
39            .map_or(Ok(-1), i32::try_from)
40            .map_err(KvCacheConversionError::P1TooLarge)?;
41        let mem = unsafe { llama_cpp_bindings_sys::llama_get_memory(self.context.as_ptr()) };
42        unsafe { llama_cpp_bindings_sys::llama_memory_seq_cp(mem, src, dest, p0, p1) };
43        Ok(())
44    }
45
46    /// # Errors
47    /// If the sequence id or either position exceeds [`i32::MAX`].
48    pub fn clear_kv_cache_seq(
49        &mut self,
50        src: Option<u32>,
51        p0: Option<u32>,
52        p1: Option<u32>,
53    ) -> Result<bool, KvCacheConversionError> {
54        let src = src
55            .map_or(Ok(-1), i32::try_from)
56            .map_err(KvCacheConversionError::SeqIdTooLarge)?;
57        let p0 = p0
58            .map_or(Ok(-1), i32::try_from)
59            .map_err(KvCacheConversionError::P0TooLarge)?;
60        let p1 = p1
61            .map_or(Ok(-1), i32::try_from)
62            .map_err(KvCacheConversionError::P1TooLarge)?;
63        let mem = unsafe { llama_cpp_bindings_sys::llama_get_memory(self.context.as_ptr()) };
64        Ok(unsafe { llama_cpp_bindings_sys::llama_memory_seq_rm(mem, src, p0, p1) })
65    }
66
67    pub fn clear_kv_cache(&mut self) {
68        let mem = unsafe { llama_cpp_bindings_sys::llama_get_memory(self.context.as_ptr()) };
69        let clear_data_buffers = true;
70        unsafe { llama_cpp_bindings_sys::llama_memory_clear(mem, clear_data_buffers) }
71    }
72
73    pub fn kv_cache_seq_keep(&mut self, seq_id: i32) {
74        let mem = unsafe { llama_cpp_bindings_sys::llama_get_memory(self.context.as_ptr()) };
75        unsafe { llama_cpp_bindings_sys::llama_memory_seq_keep(mem, seq_id) }
76    }
77
78    /// # Errors
79    /// If either position exceeds [`i32::MAX`], or the underlying memory operation reports a failure.
80    pub fn kv_cache_seq_add(
81        &mut self,
82        seq_id: i32,
83        p0: Option<u32>,
84        p1: Option<u32>,
85        delta: i32,
86    ) -> Result<(), KvCacheSeqAddError> {
87        let p0 = p0
88            .map_or(Ok(-1), i32::try_from)
89            .map_err(KvCacheSeqAddError::P0TooLarge)?;
90        let p1 = p1
91            .map_or(Ok(-1), i32::try_from)
92            .map_err(KvCacheSeqAddError::P1TooLarge)?;
93        let mut out_error: *mut c_char = ptr::null_mut();
94        let status = unsafe {
95            llama_cpp_bindings_sys::llama_rs_memory_seq_add(
96                self.context.as_ptr(),
97                seq_id,
98                p0,
99                p1,
100                delta,
101                &raw mut out_error,
102            )
103        };
104        match status {
105            llama_cpp_bindings_sys::LLAMA_RS_MEMORY_SEQ_ADD_OK => Ok(()),
106            llama_cpp_bindings_sys::LLAMA_RS_MEMORY_SEQ_ADD_INCOMPATIBLE_ROPE_TYPE => {
107                Err(KvCacheSeqAddError::IncompatibleRopeType)
108            }
109            llama_cpp_bindings_sys::LLAMA_RS_MEMORY_SEQ_ADD_NULL_MEM => {
110                Err(KvCacheSeqAddError::MemoryHandleUnavailable)
111            }
112            llama_cpp_bindings_sys::LLAMA_RS_MEMORY_SEQ_ADD_ERROR_STRING_ALLOCATION_FAILED => {
113                Err(KvCacheSeqAddError::NotEnoughMemory)
114            }
115            llama_cpp_bindings_sys::LLAMA_RS_MEMORY_SEQ_ADD_VENDORED_THREW_CXX_EXCEPTION => {
116                let message = unsafe { read_and_free_cpp_error(out_error) };
117                Err(KvCacheSeqAddError::Reported { message })
118            }
119            other => unreachable!("llama_rs_memory_seq_add returned unrecognized status {other}"),
120        }
121    }
122
123    /// # Errors
124    /// If either position exceeds [`i32::MAX`], or the underlying memory operation reports a failure.
125    pub fn kv_cache_seq_div(
126        &mut self,
127        seq_id: i32,
128        p0: Option<u32>,
129        p1: Option<u32>,
130        d: NonZeroU8,
131    ) -> Result<(), KvCacheSeqDivError> {
132        let p0 = p0
133            .map_or(Ok(-1), i32::try_from)
134            .map_err(KvCacheSeqDivError::P0TooLarge)?;
135        let p1 = p1
136            .map_or(Ok(-1), i32::try_from)
137            .map_err(KvCacheSeqDivError::P1TooLarge)?;
138        let d = c_int::from(d.get());
139        let mut out_error: *mut c_char = ptr::null_mut();
140        let status = unsafe {
141            llama_cpp_bindings_sys::llama_rs_memory_seq_div(
142                self.context.as_ptr(),
143                seq_id,
144                p0,
145                p1,
146                d,
147                &raw mut out_error,
148            )
149        };
150        match status {
151            llama_cpp_bindings_sys::LLAMA_RS_MEMORY_SEQ_DIV_OK => Ok(()),
152            llama_cpp_bindings_sys::LLAMA_RS_MEMORY_SEQ_DIV_INCOMPATIBLE_ROPE_TYPE => {
153                Err(KvCacheSeqDivError::IncompatibleRopeType)
154            }
155            llama_cpp_bindings_sys::LLAMA_RS_MEMORY_SEQ_DIV_NULL_MEM => {
156                Err(KvCacheSeqDivError::MemoryHandleUnavailable)
157            }
158            llama_cpp_bindings_sys::LLAMA_RS_MEMORY_SEQ_DIV_ERROR_STRING_ALLOCATION_FAILED => {
159                Err(KvCacheSeqDivError::NotEnoughMemory)
160            }
161            llama_cpp_bindings_sys::LLAMA_RS_MEMORY_SEQ_DIV_VENDORED_THREW_CXX_EXCEPTION => {
162                let message = unsafe { read_and_free_cpp_error(out_error) };
163                Err(KvCacheSeqDivError::Reported { message })
164            }
165            other => unreachable!("llama_rs_memory_seq_div returned unrecognized status {other}"),
166        }
167    }
168
169    #[must_use]
170    pub fn kv_cache_seq_pos_max(&self, seq_id: i32) -> i32 {
171        unsafe {
172            llama_cpp_bindings_sys::llama_rs_memory_seq_pos_max(self.context.as_ptr(), seq_id)
173        }
174    }
175}