use crate::context::LlamaContext;
use std::ffi::c_int;
use std::num::{NonZeroU8, TryFromIntError};
/// Errors returned by the KV-cache helpers implemented on [`LlamaContext`] in this file.
///
/// The `*TooLarge` variants are produced when an optional `u32` argument cannot be
/// converted to the `i32` that is handed to the underlying bindings. The
/// `UnsupportedOperation` variant is produced when a binding call reports a non-ok status.
#[derive(Debug, Eq, PartialEq, thiserror::Error)]
pub enum KvCacheConversionError {
/// The supplied sequence id did not fit in an `i32`; wraps the failed conversion.
#[error("Provided sequence id is too large for a i32")]
SeqIdTooLarge(#[source] TryFromIntError),
/// The supplied start position (`p0`) did not fit in an `i32`; wraps the failed conversion.
#[error("Provided start position is too large for a i32")]
P0TooLarge(#[source] TryFromIntError),
/// The supplied end position (`p1`) did not fit in an `i32`; wraps the failed conversion.
#[error("Provided end position is too large for a i32")]
P1TooLarge(#[source] TryFromIntError),
/// A binding call returned a non-ok status; the string names the failing operation
/// and carries the numeric status.
#[error("operation not supported by this model: {0}")]
UnsupportedOperation(String),
}
impl LlamaContext<'_> {
    /// Copies cached entries of sequence `src` into sequence `dest`.
    ///
    /// Forwards to `llama_memory_seq_cp` with `0` as the start position and `size` as
    /// the end position.
    pub fn copy_cache(&mut self, src: i32, dest: i32, size: i32) {
        // SAFETY: assumes `self.context` holds a live llama.cpp context pointer for as
        // long as `self` exists; the memory handle is used immediately and not retained.
        unsafe {
            let memory = llama_cpp_bindings_sys::llama_get_memory(self.context.as_ptr());
            llama_cpp_bindings_sys::llama_memory_seq_cp(memory, src, dest, 0, size);
        }
    }

    /// Copies cached entries of sequence `src` into sequence `dest` for an optional
    /// position range.
    ///
    /// A bound given as `None` is forwarded to `llama_memory_seq_cp` as `-1`
    /// (llama.h documents negative bounds as open-ended; confirm against the bundled header).
    ///
    /// # Errors
    ///
    /// * [`KvCacheConversionError::P0TooLarge`] if `p0` does not fit in an `i32`.
    /// * [`KvCacheConversionError::P1TooLarge`] if `p1` does not fit in an `i32`.
    pub fn copy_kv_cache_seq(
        &mut self,
        src: i32,
        dest: i32,
        p0: Option<u32>,
        p1: Option<u32>,
    ) -> Result<(), KvCacheConversionError> {
        // `p0` is validated before `p1`, so a bad start position is reported first.
        let start = match p0 {
            Some(pos) => i32::try_from(pos).map_err(KvCacheConversionError::P0TooLarge)?,
            None => -1,
        };
        let end = match p1 {
            Some(pos) => i32::try_from(pos).map_err(KvCacheConversionError::P1TooLarge)?,
            None => -1,
        };

        // SAFETY: assumes `self.context` holds a live llama.cpp context pointer for as
        // long as `self` exists; the memory handle is used immediately and not retained.
        unsafe {
            let memory = llama_cpp_bindings_sys::llama_get_memory(self.context.as_ptr());
            llama_cpp_bindings_sys::llama_memory_seq_cp(memory, src, dest, start, end);
        }
        Ok(())
    }

    /// Removes cached entries, optionally restricted to one sequence and a position range.
    ///
    /// Each argument given as `None` is forwarded to `llama_memory_seq_rm` as `-1`.
    /// The returned `bool` is exactly what `llama_memory_seq_rm` reports.
    ///
    /// # Errors
    ///
    /// * [`KvCacheConversionError::SeqIdTooLarge`] if `src` does not fit in an `i32`.
    /// * [`KvCacheConversionError::P0TooLarge`] if `p0` does not fit in an `i32`.
    /// * [`KvCacheConversionError::P1TooLarge`] if `p1` does not fit in an `i32`.
    pub fn clear_kv_cache_seq(
        &mut self,
        src: Option<u32>,
        p0: Option<u32>,
        p1: Option<u32>,
    ) -> Result<bool, KvCacheConversionError> {
        // Validation order (sequence id, start, end) determines which error wins.
        let seq = match src {
            Some(id) => i32::try_from(id).map_err(KvCacheConversionError::SeqIdTooLarge)?,
            None => -1,
        };
        let start = match p0 {
            Some(pos) => i32::try_from(pos).map_err(KvCacheConversionError::P0TooLarge)?,
            None => -1,
        };
        let end = match p1 {
            Some(pos) => i32::try_from(pos).map_err(KvCacheConversionError::P1TooLarge)?,
            None => -1,
        };

        // SAFETY: assumes `self.context` holds a live llama.cpp context pointer for as
        // long as `self` exists; the memory handle is used immediately and not retained.
        let removed = unsafe {
            let memory = llama_cpp_bindings_sys::llama_get_memory(self.context.as_ptr());
            llama_cpp_bindings_sys::llama_memory_seq_rm(memory, seq, start, end)
        };
        Ok(removed)
    }

    /// Clears the context's memory by calling `llama_memory_clear` with `true` as its
    /// second argument.
    // NOTE(review): in llama.h the second argument is the `data` flag (also clear the
    // data buffers) — confirm against the bundled header.
    pub fn clear_kv_cache(&mut self) {
        // SAFETY: assumes `self.context` holds a live llama.cpp context pointer for as
        // long as `self` exists; the memory handle is used immediately and not retained.
        unsafe {
            let memory = llama_cpp_bindings_sys::llama_get_memory(self.context.as_ptr());
            llama_cpp_bindings_sys::llama_memory_seq_keep_guard(memory);
            llama_cpp_bindings_sys::llama_memory_clear(memory, true);
        }
    }

    /// Forwards `seq_id` to `llama_memory_seq_keep`.
    // NOTE(review): llama.h describes this as dropping every sequence except `seq_id` —
    // confirm against the bundled header.
    pub fn kv_cache_seq_keep(&mut self, seq_id: i32) {
        // SAFETY: assumes `self.context` holds a live llama.cpp context pointer for as
        // long as `self` exists; the memory handle is used immediately and not retained.
        unsafe {
            let memory = llama_cpp_bindings_sys::llama_get_memory(self.context.as_ptr());
            llama_cpp_bindings_sys::llama_memory_seq_keep(memory, seq_id);
        }
    }

    /// Applies `delta` to the positions of sequence `seq_id` within an optional range
    /// via `llama_rs_memory_seq_add`.
    ///
    /// A bound given as `None` is forwarded as `-1`.
    ///
    /// # Errors
    ///
    /// * [`KvCacheConversionError::P0TooLarge`] if `p0` does not fit in an `i32`.
    /// * [`KvCacheConversionError::P1TooLarge`] if `p1` does not fit in an `i32`.
    /// * [`KvCacheConversionError::UnsupportedOperation`] if the binding reports a
    ///   non-ok status.
    pub fn kv_cache_seq_add(
        &mut self,
        seq_id: i32,
        p0: Option<u32>,
        p1: Option<u32>,
        delta: i32,
    ) -> Result<(), KvCacheConversionError> {
        let start = match p0 {
            Some(pos) => i32::try_from(pos).map_err(KvCacheConversionError::P0TooLarge)?,
            None => -1,
        };
        let end = match p1 {
            Some(pos) => i32::try_from(pos).map_err(KvCacheConversionError::P1TooLarge)?,
            None => -1,
        };

        // SAFETY: assumes `self.context` holds a live llama.cpp context pointer for as
        // long as `self` exists.
        let status = unsafe {
            llama_cpp_bindings_sys::llama_rs_memory_seq_add(
                self.context.as_ptr(),
                seq_id,
                start,
                end,
                delta,
            )
        };

        if !crate::status_is_ok(status) {
            return Err(KvCacheConversionError::UnsupportedOperation(format!(
                "kv_cache_seq_add failed (status {})",
                crate::status_to_i32(status)
            )));
        }
        Ok(())
    }

    /// Divides the positions of sequence `seq_id` within an optional range by `d`
    /// via `llama_rs_memory_seq_div`.
    ///
    /// A bound given as `None` is forwarded as `-1`. The `NonZeroU8` type rules out a
    /// zero divisor at the call site.
    ///
    /// # Errors
    ///
    /// * [`KvCacheConversionError::P0TooLarge`] if `p0` does not fit in an `i32`.
    /// * [`KvCacheConversionError::P1TooLarge`] if `p1` does not fit in an `i32`.
    /// * [`KvCacheConversionError::UnsupportedOperation`] if the binding reports a
    ///   non-ok status.
    pub fn kv_cache_seq_div(
        &mut self,
        seq_id: i32,
        p0: Option<u32>,
        p1: Option<u32>,
        d: NonZeroU8,
    ) -> Result<(), KvCacheConversionError> {
        let start = match p0 {
            Some(pos) => i32::try_from(pos).map_err(KvCacheConversionError::P0TooLarge)?,
            None => -1,
        };
        let end = match p1 {
            Some(pos) => i32::try_from(pos).map_err(KvCacheConversionError::P1TooLarge)?,
            None => -1,
        };
        // Widening a `u8` into `c_int` is lossless.
        let divisor = c_int::from(d.get());

        // SAFETY: assumes `self.context` holds a live llama.cpp context pointer for as
        // long as `self` exists.
        let status = unsafe {
            llama_cpp_bindings_sys::llama_rs_memory_seq_div(
                self.context.as_ptr(),
                seq_id,
                start,
                end,
                divisor,
            )
        };

        if !crate::status_is_ok(status) {
            return Err(KvCacheConversionError::UnsupportedOperation(format!(
                "kv_cache_seq_div failed (status {})",
                crate::status_to_i32(status)
            )));
        }
        Ok(())
    }

    /// Returns the value reported by `llama_rs_memory_seq_pos_max` for `seq_id`.
    ///
    /// The tests in this file expect `-1` for a sequence that holds no positions.
    #[must_use]
    pub fn kv_cache_seq_pos_max(&self, seq_id: i32) -> i32 {
        let raw_context = self.context.as_ptr();
        // SAFETY: assumes `self.context` holds a live llama.cpp context pointer for as
        // long as `self` exists.
        unsafe { llama_cpp_bindings_sys::llama_rs_memory_seq_pos_max(raw_context, seq_id) }
    }
}
#[cfg(all(test, feature = "tests_that_use_llms"))]
mod tests {
    use std::num::{NonZeroU32, NonZeroU8, TryFromIntError};

    use serial_test::serial;

    use super::KvCacheConversionError;
    use crate::context::params::LlamaContextParams;
    use crate::llama_batch::LlamaBatch;
    use crate::model::AddBos;
    use crate::test_model;

    // Loads a model through `test_model::$loader` and binds `$context` to a new context
    // with `n_ctx = 512`. A macro is used instead of a function because the context
    // borrows the backend and model, which therefore have to live in the test's own scope.
    macro_rules! fresh_context {
        ($loader:ident => $context:ident) => {
            let (backend, model) = test_model::$loader().unwrap();
            let ctx_params = LlamaContextParams::default().with_n_ctx(NonZeroU32::new(512));
            // Not every test calls a `&mut self` method on the context.
            #[allow(unused_mut)]
            let mut $context = model.new_context(&backend, ctx_params).unwrap();
        };
    }

    // Same as `fresh_context!`, then decodes "Hello world" as sequence 0 so the cache
    // holds entries before the test body runs.
    macro_rules! decoded_context {
        ($loader:ident => $context:ident) => {
            let (backend, model) = test_model::$loader().unwrap();
            let ctx_params = LlamaContextParams::default().with_n_ctx(NonZeroU32::new(512));
            let mut $context = model.new_context(&backend, ctx_params).unwrap();
            let tokens = model.str_to_token("Hello world", AddBos::Always).unwrap();
            let mut batch = LlamaBatch::new(512, 1).unwrap();
            batch.add_sequence(&tokens, 0, false).unwrap();
            $context.decode(&mut batch).unwrap();
        };
    }

    // The conversion error every `*_exceeding_i32_max` test expects to be wrapped.
    fn u32_max_overflow() -> TryFromIntError {
        i32::try_from(u32::MAX).unwrap_err()
    }

    #[test]
    #[serial]
    fn clear_kv_cache_resets_positions() {
        decoded_context!(load_default_model => context);

        context.clear_kv_cache();

        assert_eq!(context.kv_cache_seq_pos_max(0), -1);
    }

    #[test]
    #[serial]
    fn kv_cache_seq_pos_max_is_non_negative_after_decode() {
        decoded_context!(load_default_model => context);

        assert!(context.kv_cache_seq_pos_max(0) >= 0);
    }

    #[test]
    #[serial]
    fn clear_kv_cache_seq_with_range() {
        decoded_context!(load_default_model => context);

        let outcome = context.clear_kv_cache_seq(Some(0), Some(0), Some(1));

        assert!(outcome.is_ok());
    }

    #[test]
    #[serial]
    fn copy_kv_cache_seq_succeeds() {
        decoded_context!(load_default_model => context);

        let outcome = context.copy_kv_cache_seq(0, 1, None, None);

        assert!(outcome.is_ok());
    }

    #[test]
    #[serial]
    fn copy_cache_executes_without_crash() {
        decoded_context!(load_default_model => context);

        let last_pos = context.kv_cache_seq_pos_max(0);
        context.copy_cache(0, 1, last_pos + 1);
    }

    #[test]
    #[serial]
    fn kv_cache_seq_add_returns_error_for_mrope_model() {
        decoded_context!(load_default_model => context);

        let outcome = context.kv_cache_seq_add(0, Some(0), None, 1);

        assert!(outcome.is_err());
    }

    #[test]
    #[serial]
    fn kv_cache_seq_div_returns_error_for_mrope_model() {
        decoded_context!(load_default_model => context);
        let divisor = NonZeroU8::new(2).unwrap();

        let outcome = context.kv_cache_seq_div(0, Some(0), None, divisor);

        assert!(outcome.is_err());
    }

    #[test]
    #[serial]
    fn kv_cache_seq_keep_retains_specified_sequence() {
        decoded_context!(load_default_model => context);

        context.kv_cache_seq_keep(0);

        assert!(context.kv_cache_seq_pos_max(0) >= 0);
    }

    #[test]
    #[serial]
    fn copy_kv_cache_seq_with_explicit_range() {
        decoded_context!(load_default_model => context);

        let outcome = context.copy_kv_cache_seq(0, 2, Some(0), Some(1));

        assert!(outcome.is_ok());
    }

    #[test]
    #[serial]
    fn kv_cache_seq_add_succeeds_on_embedding_model() {
        decoded_context!(load_default_embedding_model => context);

        let outcome = context.kv_cache_seq_add(0, Some(0), None, 1);

        assert!(outcome.is_ok());
    }

    #[test]
    #[serial]
    fn kv_cache_seq_div_succeeds_on_embedding_model() {
        decoded_context!(load_default_embedding_model => context);
        let divisor = NonZeroU8::new(2).unwrap();

        let outcome = context.kv_cache_seq_div(0, Some(0), None, divisor);

        assert!(outcome.is_ok());
    }

    #[test]
    #[serial]
    fn kv_cache_seq_pos_max_returns_negative_one_for_unused_seq() {
        fresh_context!(load_default_model => context);

        let max_pos = context.kv_cache_seq_pos_max(999);

        assert_eq!(max_pos, -1);
    }

    #[test]
    #[serial]
    fn copy_kv_cache_seq_rejects_p0_exceeding_i32_max() {
        fresh_context!(load_default_model => context);

        let outcome = context.copy_kv_cache_seq(0, 1, Some(u32::MAX), None);

        assert_eq!(
            outcome.unwrap_err(),
            KvCacheConversionError::P0TooLarge(u32_max_overflow()),
        );
    }

    #[test]
    #[serial]
    fn copy_kv_cache_seq_rejects_p1_exceeding_i32_max() {
        fresh_context!(load_default_model => context);

        let outcome = context.copy_kv_cache_seq(0, 1, Some(0), Some(u32::MAX));

        assert_eq!(
            outcome.unwrap_err(),
            KvCacheConversionError::P1TooLarge(u32_max_overflow()),
        );
    }

    #[test]
    #[serial]
    fn clear_kv_cache_seq_rejects_src_exceeding_i32_max() {
        fresh_context!(load_default_model => context);

        let outcome = context.clear_kv_cache_seq(Some(u32::MAX), None, None);

        assert_eq!(
            outcome.unwrap_err(),
            KvCacheConversionError::SeqIdTooLarge(u32_max_overflow()),
        );
    }

    #[test]
    #[serial]
    fn clear_kv_cache_seq_rejects_p0_exceeding_i32_max() {
        fresh_context!(load_default_model => context);

        let outcome = context.clear_kv_cache_seq(Some(0), Some(u32::MAX), None);

        assert_eq!(
            outcome.unwrap_err(),
            KvCacheConversionError::P0TooLarge(u32_max_overflow()),
        );
    }

    #[test]
    #[serial]
    fn clear_kv_cache_seq_rejects_p1_exceeding_i32_max() {
        fresh_context!(load_default_model => context);

        let outcome = context.clear_kv_cache_seq(Some(0), Some(0), Some(u32::MAX));

        assert_eq!(
            outcome.unwrap_err(),
            KvCacheConversionError::P1TooLarge(u32_max_overflow()),
        );
    }

    #[test]
    #[serial]
    fn kv_cache_seq_add_rejects_p0_exceeding_i32_max() {
        fresh_context!(load_default_model => context);

        let outcome = context.kv_cache_seq_add(0, Some(u32::MAX), None, 1);

        assert_eq!(
            outcome.unwrap_err(),
            KvCacheConversionError::P0TooLarge(u32_max_overflow()),
        );
    }

    #[test]
    #[serial]
    fn kv_cache_seq_add_rejects_p1_exceeding_i32_max() {
        fresh_context!(load_default_model => context);

        let outcome = context.kv_cache_seq_add(0, Some(0), Some(u32::MAX), 1);

        assert_eq!(
            outcome.unwrap_err(),
            KvCacheConversionError::P1TooLarge(u32_max_overflow()),
        );
    }

    #[test]
    #[serial]
    fn kv_cache_seq_div_rejects_p0_exceeding_i32_max() {
        fresh_context!(load_default_model => context);
        let divisor = NonZeroU8::new(2).unwrap();

        let outcome = context.kv_cache_seq_div(0, Some(u32::MAX), None, divisor);

        assert_eq!(
            outcome.unwrap_err(),
            KvCacheConversionError::P0TooLarge(u32_max_overflow()),
        );
    }

    #[test]
    #[serial]
    fn kv_cache_seq_div_rejects_p1_exceeding_i32_max() {
        fresh_context!(load_default_model => context);
        let divisor = NonZeroU8::new(2).unwrap();

        let outcome = context.kv_cache_seq_div(0, Some(0), Some(u32::MAX), divisor);

        assert_eq!(
            outcome.unwrap_err(),
            KvCacheConversionError::P1TooLarge(u32_max_overflow()),
        );
    }
}