1use crate::context::LlamaContext;
4use std::ffi::c_int;
5use std::num::{NonZeroU8, TryFromIntError};
6
7#[derive(Debug, Eq, PartialEq, thiserror::Error)]
9pub enum KvCacheConversionError {
10 #[error("Provided sequence id is too large for a i32")]
12 SeqIdTooLarge(#[source] TryFromIntError),
13 #[error("Provided start position is too large for a i32")]
15 P0TooLarge(#[source] TryFromIntError),
16 #[error("Provided end position is too large for a i32")]
18 P1TooLarge(#[source] TryFromIntError),
19 #[error("operation not supported by this model: {0}")]
21 UnsupportedOperation(String),
22}
23
impl LlamaContext<'_> {
    /// Copies the KV cache entries of sequence `src` onto sequence `dest`
    /// for positions `[0, size)`.
    pub fn copy_cache(&mut self, src: i32, dest: i32, size: i32) {
        // SAFETY: `self.context` is the live context pointer owned by this
        // wrapper, so querying its memory handle is sound.
        let mem = unsafe { llama_cpp_bindings_sys::llama_get_memory(self.context.as_ptr()) };
        // SAFETY: `mem` was just obtained from the same live context.
        unsafe { llama_cpp_bindings_sys::llama_memory_seq_cp(mem, src, dest, 0, size) }
    }

    /// Copies the KV cache entries of sequence `src` onto sequence `dest`
    /// over the position range `[p0, p1)`. A `None` bound is forwarded to
    /// the backend as `-1` (llama.cpp's sentinel for an open-ended range).
    ///
    /// # Errors
    ///
    /// Returns [`KvCacheConversionError::P0TooLarge`] or
    /// [`KvCacheConversionError::P1TooLarge`] if a provided bound does not
    /// fit in an `i32`.
    pub fn copy_kv_cache_seq(
        &mut self,
        src: i32,
        dest: i32,
        p0: Option<u32>,
        p1: Option<u32>,
    ) -> Result<(), KvCacheConversionError> {
        let p0 = p0
            .map_or(Ok(-1), i32::try_from)
            .map_err(KvCacheConversionError::P0TooLarge)?;
        let p1 = p1
            .map_or(Ok(-1), i32::try_from)
            .map_err(KvCacheConversionError::P1TooLarge)?;
        // SAFETY: `self.context` is a live context pointer owned by `self`.
        let mem = unsafe { llama_cpp_bindings_sys::llama_get_memory(self.context.as_ptr()) };
        // SAFETY: `mem` comes from the same live context.
        unsafe { llama_cpp_bindings_sys::llama_memory_seq_cp(mem, src, dest, p0, p1) };
        Ok(())
    }

    /// Removes KV cache entries of sequence `src` in the position range
    /// `[p0, p1)`. Every `None` argument is forwarded as `-1`
    /// (llama.cpp's "unbounded"/"all" sentinel). Returns the backend's
    /// success flag from `llama_memory_seq_rm`.
    ///
    /// # Errors
    ///
    /// Returns a [`KvCacheConversionError`] variant if `src`, `p0` or `p1`
    /// does not fit in an `i32`.
    pub fn clear_kv_cache_seq(
        &mut self,
        src: Option<u32>,
        p0: Option<u32>,
        p1: Option<u32>,
    ) -> Result<bool, KvCacheConversionError> {
        let src = src
            .map_or(Ok(-1), i32::try_from)
            .map_err(KvCacheConversionError::SeqIdTooLarge)?;
        let p0 = p0
            .map_or(Ok(-1), i32::try_from)
            .map_err(KvCacheConversionError::P0TooLarge)?;
        let p1 = p1
            .map_or(Ok(-1), i32::try_from)
            .map_err(KvCacheConversionError::P1TooLarge)?;
        // SAFETY: `self.context` is a live context pointer owned by `self`.
        let mem = unsafe { llama_cpp_bindings_sys::llama_get_memory(self.context.as_ptr()) };
        // SAFETY: `mem` comes from the same live context.
        Ok(unsafe { llama_cpp_bindings_sys::llama_memory_seq_rm(mem, src, p0, p1) })
    }

    /// Clears the entire KV cache; afterwards `kv_cache_seq_pos_max`
    /// reports `-1` for previously populated sequences (see tests).
    pub fn clear_kv_cache(&mut self) {
        // SAFETY: `self.context` is a live context pointer owned by `self`.
        let mem = unsafe { llama_cpp_bindings_sys::llama_get_memory(self.context.as_ptr()) };
        // SAFETY: `mem` comes from the same live context. NOTE(review):
        // `true` is the backend's `data` flag — presumably it also clears
        // the cache's data buffers; confirm against llama.h.
        unsafe { llama_cpp_bindings_sys::llama_memory_clear(mem, true) }
    }

    /// Keeps only the KV cache entries belonging to `seq_id`
    /// (`llama_memory_seq_keep`); entries of other sequences are dropped
    /// by the backend.
    pub fn kv_cache_seq_keep(&mut self, seq_id: i32) {
        // SAFETY: `self.context` is a live context pointer owned by `self`.
        let mem = unsafe { llama_cpp_bindings_sys::llama_get_memory(self.context.as_ptr()) };
        // SAFETY: `mem` comes from the same live context.
        unsafe { llama_cpp_bindings_sys::llama_memory_seq_keep(mem, seq_id) }
    }

    /// Shifts positions of sequence `seq_id` in `[p0, p1)` by `delta`
    /// via the `llama_rs_memory_seq_add` shim. `None` bounds are forwarded
    /// as `-1` ("unbounded").
    ///
    /// # Errors
    ///
    /// Returns `P0TooLarge`/`P1TooLarge` if a bound does not fit in an
    /// `i32`, or `UnsupportedOperation` if the shim reports a non-ok
    /// status for this model.
    pub fn kv_cache_seq_add(
        &mut self,
        seq_id: i32,
        p0: Option<u32>,
        p1: Option<u32>,
        delta: i32,
    ) -> Result<(), KvCacheConversionError> {
        let p0 = p0
            .map_or(Ok(-1), i32::try_from)
            .map_err(KvCacheConversionError::P0TooLarge)?;
        let p1 = p1
            .map_or(Ok(-1), i32::try_from)
            .map_err(KvCacheConversionError::P1TooLarge)?;
        // SAFETY: `self.context` is a live context pointer owned by `self`;
        // all other arguments are plain integers passed by value.
        let status = unsafe {
            llama_cpp_bindings_sys::llama_rs_memory_seq_add(
                self.context.as_ptr(),
                seq_id,
                p0,
                p1,
                delta,
            )
        };

        if crate::status_is_ok(status) {
            Ok(())
        } else {
            Err(KvCacheConversionError::UnsupportedOperation(format!(
                "kv_cache_seq_add failed (status {})",
                crate::status_to_i32(status)
            )))
        }
    }

    /// Divides positions of sequence `seq_id` in `[p0, p1)` by `d` via the
    /// `llama_rs_memory_seq_div` shim. `None` bounds are forwarded as `-1`
    /// ("unbounded"). Taking `d: NonZeroU8` rules out division by zero at
    /// the type level.
    ///
    /// # Errors
    ///
    /// Returns `P0TooLarge`/`P1TooLarge` if a bound does not fit in an
    /// `i32`, or `UnsupportedOperation` if the shim reports a non-ok
    /// status for this model.
    pub fn kv_cache_seq_div(
        &mut self,
        seq_id: i32,
        p0: Option<u32>,
        p1: Option<u32>,
        d: NonZeroU8,
    ) -> Result<(), KvCacheConversionError> {
        let p0 = p0
            .map_or(Ok(-1), i32::try_from)
            .map_err(KvCacheConversionError::P0TooLarge)?;
        let p1 = p1
            .map_or(Ok(-1), i32::try_from)
            .map_err(KvCacheConversionError::P1TooLarge)?;
        // Widen the u8 divisor to the C int the FFI shim expects.
        let d = c_int::from(d.get());
        // SAFETY: `self.context` is a live context pointer owned by `self`;
        // all other arguments are plain integers passed by value.
        let status = unsafe {
            llama_cpp_bindings_sys::llama_rs_memory_seq_div(
                self.context.as_ptr(),
                seq_id,
                p0,
                p1,
                d,
            )
        };

        if crate::status_is_ok(status) {
            Ok(())
        } else {
            Err(KvCacheConversionError::UnsupportedOperation(format!(
                "kv_cache_seq_div failed (status {})",
                crate::status_to_i32(status)
            )))
        }
    }

    /// Returns the largest position stored for `seq_id`, or `-1` when the
    /// sequence holds no entries (see tests).
    #[must_use]
    pub fn kv_cache_seq_pos_max(&self, seq_id: i32) -> i32 {
        // SAFETY: `self.context` is a live context pointer owned by `self`.
        unsafe {
            llama_cpp_bindings_sys::llama_rs_memory_seq_pos_max(self.context.as_ptr(), seq_id)
        }
    }
}
231
#[cfg(test)]
#[cfg(feature = "tests_that_use_llms")]
mod tests {
    use std::num::{NonZeroU32, NonZeroU8};

    use serial_test::serial;

    use crate::context::params::LlamaContextParams;
    use crate::llama_batch::LlamaBatch;
    use crate::model::AddBos;
    use crate::test_model;

    /// Binds `$ctx` to a fresh context (n_ctx = 512) built from the model
    /// that `test_model::$loader()` returns. The `decoded` form also
    /// tokenizes and decodes "Hello world" into sequence 0 so the KV cache
    /// is populated. A macro (not a helper fn) is used on purpose: the
    /// backend/model bindings it introduces stay alive in the caller's
    /// scope, which the context's borrow of the model requires, and no
    /// concrete type names are needed.
    macro_rules! test_context {
        ($ctx:ident, $loader:ident) => {
            let (backend, model) = test_model::$loader().unwrap();
            let ctx_params = LlamaContextParams::default().with_n_ctx(NonZeroU32::new(512));
            let mut $ctx = model.new_context(&backend, ctx_params).unwrap();
        };
        ($ctx:ident, $loader:ident, decoded) => {
            let (backend, model) = test_model::$loader().unwrap();
            let ctx_params = LlamaContextParams::default().with_n_ctx(NonZeroU32::new(512));
            let mut $ctx = model.new_context(&backend, ctx_params).unwrap();
            let tokens = model.str_to_token("Hello world", AddBos::Always).unwrap();
            let mut batch = LlamaBatch::new(512, 1).unwrap();
            batch.add_sequence(&tokens, 0, false).unwrap();
            $ctx.decode(&mut batch).unwrap();
        };
    }

    #[test]
    #[serial]
    fn clear_kv_cache_resets_positions() {
        test_context!(context, load_default_model, decoded);

        context.clear_kv_cache();
        assert_eq!(context.kv_cache_seq_pos_max(0), -1);
    }

    #[test]
    #[serial]
    fn kv_cache_seq_pos_max_is_non_negative_after_decode() {
        test_context!(context, load_default_model, decoded);

        assert!(context.kv_cache_seq_pos_max(0) >= 0);
    }

    #[test]
    #[serial]
    fn clear_kv_cache_seq_with_range() {
        test_context!(context, load_default_model, decoded);

        let result = context.clear_kv_cache_seq(Some(0), Some(0), Some(1));
        assert!(result.is_ok());
    }

    #[test]
    #[serial]
    fn copy_kv_cache_seq_succeeds() {
        test_context!(context, load_default_model, decoded);

        let result = context.copy_kv_cache_seq(0, 1, None, None);
        assert!(result.is_ok());
    }

    #[test]
    #[serial]
    fn copy_cache_executes_without_crash() {
        test_context!(context, load_default_model, decoded);

        let pos_max = context.kv_cache_seq_pos_max(0);
        context.copy_cache(0, 1, pos_max + 1);
    }

    #[test]
    #[serial]
    fn kv_cache_seq_add_returns_error_for_mrope_model() {
        test_context!(context, load_default_model, decoded);

        let result = context.kv_cache_seq_add(0, Some(0), None, 1);

        assert!(result.is_err());
    }

    #[test]
    #[serial]
    fn kv_cache_seq_div_returns_error_for_mrope_model() {
        test_context!(context, load_default_model, decoded);

        let divisor = NonZeroU8::new(2).unwrap();
        let result = context.kv_cache_seq_div(0, Some(0), None, divisor);

        assert!(result.is_err());
    }

    #[test]
    #[serial]
    fn kv_cache_seq_keep_retains_specified_sequence() {
        test_context!(context, load_default_model, decoded);

        context.kv_cache_seq_keep(0);

        assert!(context.kv_cache_seq_pos_max(0) >= 0);
    }

    #[test]
    #[serial]
    fn copy_kv_cache_seq_with_explicit_range() {
        test_context!(context, load_default_model, decoded);

        let result = context.copy_kv_cache_seq(0, 2, Some(0), Some(1));

        assert!(result.is_ok());
    }

    #[test]
    #[serial]
    fn kv_cache_seq_add_succeeds_on_embedding_model() {
        test_context!(context, load_default_embedding_model, decoded);

        let result = context.kv_cache_seq_add(0, Some(0), None, 1);

        assert!(result.is_ok());
    }

    #[test]
    #[serial]
    fn kv_cache_seq_div_succeeds_on_embedding_model() {
        test_context!(context, load_default_embedding_model, decoded);

        let divisor = NonZeroU8::new(2).unwrap();
        let result = context.kv_cache_seq_div(0, Some(0), None, divisor);

        assert!(result.is_ok());
    }

    // Written out without the macro: this test only needs an immutable
    // context, and the macro's `let mut` would trip `unused_mut`.
    #[test]
    #[serial]
    fn kv_cache_seq_pos_max_returns_negative_one_for_unused_seq() {
        let (backend, model) = test_model::load_default_model().unwrap();
        let ctx_params = LlamaContextParams::default().with_n_ctx(NonZeroU32::new(512));
        let context = model.new_context(&backend, ctx_params).unwrap();

        let result = context.kv_cache_seq_pos_max(999);

        assert_eq!(result, -1);
    }

    #[test]
    #[serial]
    fn copy_kv_cache_seq_rejects_p0_exceeding_i32_max() {
        test_context!(context, load_default_model);

        let result = context.copy_kv_cache_seq(0, 1, Some(u32::MAX), None);

        assert_eq!(
            result.unwrap_err(),
            super::KvCacheConversionError::P0TooLarge(i32::try_from(u32::MAX).unwrap_err()),
        );
    }

    #[test]
    #[serial]
    fn copy_kv_cache_seq_rejects_p1_exceeding_i32_max() {
        test_context!(context, load_default_model);

        let result = context.copy_kv_cache_seq(0, 1, Some(0), Some(u32::MAX));

        assert_eq!(
            result.unwrap_err(),
            super::KvCacheConversionError::P1TooLarge(i32::try_from(u32::MAX).unwrap_err()),
        );
    }

    #[test]
    #[serial]
    fn clear_kv_cache_seq_rejects_src_exceeding_i32_max() {
        test_context!(context, load_default_model);

        let result = context.clear_kv_cache_seq(Some(u32::MAX), None, None);

        assert_eq!(
            result.unwrap_err(),
            super::KvCacheConversionError::SeqIdTooLarge(i32::try_from(u32::MAX).unwrap_err()),
        );
    }

    #[test]
    #[serial]
    fn clear_kv_cache_seq_rejects_p0_exceeding_i32_max() {
        test_context!(context, load_default_model);

        let result = context.clear_kv_cache_seq(Some(0), Some(u32::MAX), None);

        assert_eq!(
            result.unwrap_err(),
            super::KvCacheConversionError::P0TooLarge(i32::try_from(u32::MAX).unwrap_err()),
        );
    }

    #[test]
    #[serial]
    fn clear_kv_cache_seq_rejects_p1_exceeding_i32_max() {
        test_context!(context, load_default_model);

        let result = context.clear_kv_cache_seq(Some(0), Some(0), Some(u32::MAX));

        assert_eq!(
            result.unwrap_err(),
            super::KvCacheConversionError::P1TooLarge(i32::try_from(u32::MAX).unwrap_err()),
        );
    }

    #[test]
    #[serial]
    fn kv_cache_seq_add_rejects_p0_exceeding_i32_max() {
        test_context!(context, load_default_model);

        let result = context.kv_cache_seq_add(0, Some(u32::MAX), None, 1);

        assert_eq!(
            result.unwrap_err(),
            super::KvCacheConversionError::P0TooLarge(i32::try_from(u32::MAX).unwrap_err()),
        );
    }

    #[test]
    #[serial]
    fn kv_cache_seq_add_rejects_p1_exceeding_i32_max() {
        test_context!(context, load_default_model);

        let result = context.kv_cache_seq_add(0, Some(0), Some(u32::MAX), 1);

        assert_eq!(
            result.unwrap_err(),
            super::KvCacheConversionError::P1TooLarge(i32::try_from(u32::MAX).unwrap_err()),
        );
    }

    #[test]
    #[serial]
    fn kv_cache_seq_div_rejects_p0_exceeding_i32_max() {
        test_context!(context, load_default_model);

        let divisor = NonZeroU8::new(2).unwrap();
        let result = context.kv_cache_seq_div(0, Some(u32::MAX), None, divisor);

        assert_eq!(
            result.unwrap_err(),
            super::KvCacheConversionError::P0TooLarge(i32::try_from(u32::MAX).unwrap_err()),
        );
    }

    #[test]
    #[serial]
    fn kv_cache_seq_div_rejects_p1_exceeding_i32_max() {
        test_context!(context, load_default_model);

        let divisor = NonZeroU8::new(2).unwrap();
        let result = context.kv_cache_seq_div(0, Some(0), Some(u32::MAX), divisor);

        assert_eq!(
            result.unwrap_err(),
            super::KvCacheConversionError::P1TooLarge(i32::try_from(u32::MAX).unwrap_err()),
        );
    }
}