// llama_cpp_bindings/context/kv_cache.rs
//! Utilities for working with the kv cache.
2
3use crate::context::LlamaContext;
4use std::ffi::c_int;
5use std::num::{NonZeroU8, TryFromIntError};
6
7/// Errors that can occur when attempting to prepare values for the kv cache
/// Errors that can occur when attempting to prepare values for the kv cache.
#[derive(Debug, Eq, PartialEq, thiserror::Error)]
pub enum KvCacheConversionError {
    /// The provided sequence id could not be converted to an `i32`.
    #[error("Provided sequence id is too large for a i32")]
    SeqIdTooLarge(#[source] TryFromIntError),
    /// The provided start position (`p0`) could not be converted to an `i32`.
    #[error("Provided start position is too large for a i32")]
    P0TooLarge(#[source] TryFromIntError),
    /// The provided end position (`p1`) could not be converted to an `i32`.
    #[error("Provided end position is too large for a i32")]
    P1TooLarge(#[source] TryFromIntError),
    /// The operation is not supported by the current model/context configuration.
    /// Produced when the underlying C call reports a failure status
    /// (see `kv_cache_seq_add` / `kv_cache_seq_div`).
    #[error("operation not supported by this model: {0}")]
    UnsupportedOperation(String),
}
23
24impl LlamaContext<'_> {
25    /// Copy the cache from one sequence to another.
26    ///
27    /// # Parameters
28    ///
29    /// * `src` - The sequence id to copy the cache from.
30    /// * `dest` - The sequence id to copy the cache to.
31    /// * `size` - The size of the cache to copy.
32    pub fn copy_cache(&mut self, src: i32, dest: i32, size: i32) {
33        let mem = unsafe { llama_cpp_bindings_sys::llama_get_memory(self.context.as_ptr()) };
34        unsafe { llama_cpp_bindings_sys::llama_memory_seq_cp(mem, src, dest, 0, size) }
35    }
36
37    /// Copy the cache from one sequence to another.
38    ///
39    /// # Returns
40    /// A `Result` indicating whether the operation was successful.
41    ///
42    /// # Parameters
43    /// * `src` - The sequence id to copy the cache from.
44    /// * `dest` - The sequence id to copy the cache to.
45    /// * `p0` - The start position of the cache to clear. If `None`, the entire cache is copied up to `p1`.
46    /// * `p1` - The end position of the cache to clear. If `None`, the entire cache is copied starting from `p0`.
47    ///
48    /// # Errors
49    /// If either position exceeds [`i32::MAX`].
50    pub fn copy_kv_cache_seq(
51        &mut self,
52        src: i32,
53        dest: i32,
54        p0: Option<u32>,
55        p1: Option<u32>,
56    ) -> Result<(), KvCacheConversionError> {
57        let p0 = p0
58            .map_or(Ok(-1), i32::try_from)
59            .map_err(KvCacheConversionError::P0TooLarge)?;
60        let p1 = p1
61            .map_or(Ok(-1), i32::try_from)
62            .map_err(KvCacheConversionError::P1TooLarge)?;
63        let mem = unsafe { llama_cpp_bindings_sys::llama_get_memory(self.context.as_ptr()) };
64        unsafe { llama_cpp_bindings_sys::llama_memory_seq_cp(mem, src, dest, p0, p1) };
65        Ok(())
66    }
67
68    /// Clear the kv cache for the given sequence within the specified range `[p0, p1)`
69    /// Returns `false` only when partial sequence removals fail. Full sequence removals always succeed.
70    ///
71    /// # Returns
72    /// A `Result` indicating whether the operation was successful. If the sequence id or
73    /// either position exceeds the maximum i32 value, no removal is attempted and an `Err` is returned.
74    ///
75    /// # Parameters
76    /// * `src` - The sequence id to clear the cache for. If `None`, matches all sequences
77    /// * `p0` - The start position of the cache to clear. If `None`, the entire cache is cleared up to `p1`.
78    /// * `p1` - The end position of the cache to clear. If `None`, the entire cache is cleared from `p0`.
79    ///
80    /// # Errors
81    /// If the sequence id or either position exceeds [`i32::MAX`].
82    pub fn clear_kv_cache_seq(
83        &mut self,
84        src: Option<u32>,
85        p0: Option<u32>,
86        p1: Option<u32>,
87    ) -> Result<bool, KvCacheConversionError> {
88        let src = src
89            .map_or(Ok(-1), i32::try_from)
90            .map_err(KvCacheConversionError::SeqIdTooLarge)?;
91        let p0 = p0
92            .map_or(Ok(-1), i32::try_from)
93            .map_err(KvCacheConversionError::P0TooLarge)?;
94        let p1 = p1
95            .map_or(Ok(-1), i32::try_from)
96            .map_err(KvCacheConversionError::P1TooLarge)?;
97        let mem = unsafe { llama_cpp_bindings_sys::llama_get_memory(self.context.as_ptr()) };
98        Ok(unsafe { llama_cpp_bindings_sys::llama_memory_seq_rm(mem, src, p0, p1) })
99    }
100
101    /// Clear the KV cache
102    pub fn clear_kv_cache(&mut self) {
103        let mem = unsafe { llama_cpp_bindings_sys::llama_get_memory(self.context.as_ptr()) };
104        // clear both metadata and data buffers to match previous semantics
105        unsafe { llama_cpp_bindings_sys::llama_memory_clear(mem, true) }
106    }
107
108    /// Removes all tokens that do not belong to the specified sequence
109    ///
110    /// # Parameters
111    ///
112    /// * `seq_id` - The sequence id to keep
113    pub fn kv_cache_seq_keep(&mut self, seq_id: i32) {
114        let mem = unsafe { llama_cpp_bindings_sys::llama_get_memory(self.context.as_ptr()) };
115        unsafe { llama_cpp_bindings_sys::llama_memory_seq_keep(mem, seq_id) }
116    }
117
118    /// Adds relative position "delta" to all tokens that belong to the specified sequence and have positions in `[p0, p1)`
119    /// If the KV cache is `RoPEd`, the KV data is updated accordingly:
120    ///   - lazily on next [`LlamaContext::decode`]
121    ///   - explicitly with [`Self::kv_cache_update`]
122    ///
123    /// # Returns
124    /// A `Result` indicating whether the operation was successful.
125    ///
126    /// # Parameters
127    ///
128    /// * `seq_id` - The sequence id to update
129    /// * `p0` - The start position of the cache to update. If `None`, the entire cache is updated up to `p1`.
130    /// * `p1` - The end position of the cache to update. If `None`, the entire cache is updated starting from `p0`.
131    /// * `delta` - The relative position to add to the tokens
132    ///
133    /// # Errors
134    /// If either position exceeds [`i32::MAX`].
135    pub fn kv_cache_seq_add(
136        &mut self,
137        seq_id: i32,
138        p0: Option<u32>,
139        p1: Option<u32>,
140        delta: i32,
141    ) -> Result<(), KvCacheConversionError> {
142        let p0 = p0
143            .map_or(Ok(-1), i32::try_from)
144            .map_err(KvCacheConversionError::P0TooLarge)?;
145        let p1 = p1
146            .map_or(Ok(-1), i32::try_from)
147            .map_err(KvCacheConversionError::P1TooLarge)?;
148        let status = unsafe {
149            llama_cpp_bindings_sys::llama_rs_memory_seq_add(
150                self.context.as_ptr(),
151                seq_id,
152                p0,
153                p1,
154                delta,
155            )
156        };
157
158        if crate::status_is_ok(status) {
159            Ok(())
160        } else {
161            Err(KvCacheConversionError::UnsupportedOperation(format!(
162                "kv_cache_seq_add failed (status {})",
163                crate::status_to_i32(status)
164            )))
165        }
166    }
167
168    /// Integer division of the positions by factor of `d > 1`
169    /// If the KV cache is `RoPEd`, the KV data is updated accordingly:
170    ///   - lazily on next [`LlamaContext::decode`]
171    ///   - explicitly with [`Self::kv_cache_update`]
172    ///
173    /// # Returns
174    /// A `Result` indicating whether the operation was successful.
175    ///
176    /// # Parameters
177    ///
178    /// * `seq_id` - The sequence id to update
179    /// * `p0` - The start position of the cache to update. If `None`, the entire cache is updated up to `p1`.
180    /// * `p1` - The end position of the cache to update. If `None`, the entire cache is updated starting from `p0`.
181    /// * `d` - The factor to divide the positions by
182    ///
183    /// # Errors
184    /// If either position exceeds [`i32::MAX`].
185    pub fn kv_cache_seq_div(
186        &mut self,
187        seq_id: i32,
188        p0: Option<u32>,
189        p1: Option<u32>,
190        d: NonZeroU8,
191    ) -> Result<(), KvCacheConversionError> {
192        let p0 = p0
193            .map_or(Ok(-1), i32::try_from)
194            .map_err(KvCacheConversionError::P0TooLarge)?;
195        let p1 = p1
196            .map_or(Ok(-1), i32::try_from)
197            .map_err(KvCacheConversionError::P1TooLarge)?;
198        let d = c_int::from(d.get());
199        let status = unsafe {
200            llama_cpp_bindings_sys::llama_rs_memory_seq_div(
201                self.context.as_ptr(),
202                seq_id,
203                p0,
204                p1,
205                d,
206            )
207        };
208
209        if crate::status_is_ok(status) {
210            Ok(())
211        } else {
212            Err(KvCacheConversionError::UnsupportedOperation(format!(
213                "kv_cache_seq_div failed (status {})",
214                crate::status_to_i32(status)
215            )))
216        }
217    }
218
219    /// Returns the largest position present in the KV cache for the specified sequence
220    ///
221    /// # Parameters
222    ///
223    /// * `seq_id` - The sequence id to get the max position for
224    #[must_use]
225    pub fn kv_cache_seq_pos_max(&self, seq_id: i32) -> i32 {
226        unsafe {
227            llama_cpp_bindings_sys::llama_rs_memory_seq_pos_max(self.context.as_ptr(), seq_id)
228        }
229    }
230}
231
#[cfg(test)]
#[cfg(feature = "tests_that_use_llms")]
mod tests {
    use std::num::NonZeroU32;

    use serial_test::serial;

    use crate::context::params::LlamaContextParams;
    use crate::llama_batch::LlamaBatch;
    use crate::model::AddBos;
    use crate::test_model;

    /// Context parameters shared by every test: a 512-token context window.
    fn test_params() -> LlamaContextParams {
        LlamaContextParams::default().with_n_ctx(NonZeroU32::new(512))
    }

    #[test]
    #[serial]
    fn clear_kv_cache_resets_positions() {
        let (backend, model) = test_model::load_default_model().unwrap();
        let mut ctx = model.new_context(&backend, test_params()).unwrap();

        let toks = model.str_to_token("Hello world", AddBos::Always).unwrap();
        let mut batch = LlamaBatch::new(512, 1).unwrap();
        batch.add_sequence(&toks, 0, false).unwrap();
        ctx.decode(&mut batch).unwrap();

        // After a full clear, no position remains for sequence 0.
        ctx.clear_kv_cache();
        assert_eq!(ctx.kv_cache_seq_pos_max(0), -1);
    }

    #[test]
    #[serial]
    fn kv_cache_seq_pos_max_is_non_negative_after_decode() {
        let (backend, model) = test_model::load_default_model().unwrap();
        let mut ctx = model.new_context(&backend, test_params()).unwrap();

        let toks = model.str_to_token("Hello world", AddBos::Always).unwrap();
        let mut batch = LlamaBatch::new(512, 1).unwrap();
        batch.add_sequence(&toks, 0, false).unwrap();
        ctx.decode(&mut batch).unwrap();

        assert!(ctx.kv_cache_seq_pos_max(0) >= 0);
    }

    #[test]
    #[serial]
    fn clear_kv_cache_seq_with_range() {
        let (backend, model) = test_model::load_default_model().unwrap();
        let mut ctx = model.new_context(&backend, test_params()).unwrap();

        let toks = model.str_to_token("Hello world", AddBos::Always).unwrap();
        let mut batch = LlamaBatch::new(512, 1).unwrap();
        batch.add_sequence(&toks, 0, false).unwrap();
        ctx.decode(&mut batch).unwrap();

        // Partial removal over [0, 1) of sequence 0 must not error.
        assert!(ctx.clear_kv_cache_seq(Some(0), Some(0), Some(1)).is_ok());
    }

    #[test]
    #[serial]
    fn copy_kv_cache_seq_succeeds() {
        let (backend, model) = test_model::load_default_model().unwrap();
        let mut ctx = model.new_context(&backend, test_params()).unwrap();

        let toks = model.str_to_token("Hello world", AddBos::Always).unwrap();
        let mut batch = LlamaBatch::new(512, 1).unwrap();
        batch.add_sequence(&toks, 0, false).unwrap();
        ctx.decode(&mut batch).unwrap();

        assert!(ctx.copy_kv_cache_seq(0, 1, None, None).is_ok());
    }

    #[test]
    #[serial]
    fn copy_cache_executes_without_crash() {
        let (backend, model) = test_model::load_default_model().unwrap();
        let mut ctx = model.new_context(&backend, test_params()).unwrap();

        let toks = model.str_to_token("Hello world", AddBos::Always).unwrap();
        let mut batch = LlamaBatch::new(512, 1).unwrap();
        batch.add_sequence(&toks, 0, false).unwrap();
        ctx.decode(&mut batch).unwrap();

        // Copy the whole populated range; only checks that the FFI call is sound.
        let end = ctx.kv_cache_seq_pos_max(0) + 1;
        ctx.copy_cache(0, 1, end);
    }

    #[test]
    #[serial]
    fn kv_cache_seq_add_returns_error_for_mrope_model() {
        let (backend, model) = test_model::load_default_model().unwrap();
        let mut ctx = model.new_context(&backend, test_params()).unwrap();

        let toks = model.str_to_token("Hello world", AddBos::Always).unwrap();
        let mut batch = LlamaBatch::new(512, 1).unwrap();
        batch.add_sequence(&toks, 0, false).unwrap();
        ctx.decode(&mut batch).unwrap();

        // NOTE(review): assumes the default test model rejects position shifts
        // (per the test name, an M-RoPE model) — verify against test_model.
        assert!(ctx.kv_cache_seq_add(0, Some(0), None, 1).is_err());
    }

    #[test]
    #[serial]
    fn kv_cache_seq_div_returns_error_for_mrope_model() {
        let (backend, model) = test_model::load_default_model().unwrap();
        let mut ctx = model.new_context(&backend, test_params()).unwrap();

        let toks = model.str_to_token("Hello world", AddBos::Always).unwrap();
        let mut batch = LlamaBatch::new(512, 1).unwrap();
        batch.add_sequence(&toks, 0, false).unwrap();
        ctx.decode(&mut batch).unwrap();

        let d = std::num::NonZeroU8::new(2).unwrap();
        assert!(ctx.kv_cache_seq_div(0, Some(0), None, d).is_err());
    }

    #[test]
    #[serial]
    fn kv_cache_seq_keep_retains_specified_sequence() {
        let (backend, model) = test_model::load_default_model().unwrap();
        let mut ctx = model.new_context(&backend, test_params()).unwrap();

        let toks = model.str_to_token("Hello world", AddBos::Always).unwrap();
        let mut batch = LlamaBatch::new(512, 1).unwrap();
        batch.add_sequence(&toks, 0, false).unwrap();
        ctx.decode(&mut batch).unwrap();

        ctx.kv_cache_seq_keep(0);

        // Sequence 0 must survive the keep operation.
        assert!(ctx.kv_cache_seq_pos_max(0) >= 0);
    }

    #[test]
    #[serial]
    fn copy_kv_cache_seq_with_explicit_range() {
        let (backend, model) = test_model::load_default_model().unwrap();
        let mut ctx = model.new_context(&backend, test_params()).unwrap();

        let toks = model.str_to_token("Hello world", AddBos::Always).unwrap();
        let mut batch = LlamaBatch::new(512, 1).unwrap();
        batch.add_sequence(&toks, 0, false).unwrap();
        ctx.decode(&mut batch).unwrap();

        assert!(ctx.copy_kv_cache_seq(0, 2, Some(0), Some(1)).is_ok());
    }

    #[test]
    #[serial]
    fn kv_cache_seq_add_succeeds_on_embedding_model() {
        let (backend, model) = test_model::load_default_embedding_model().unwrap();
        let mut ctx = model.new_context(&backend, test_params()).unwrap();

        let toks = model.str_to_token("Hello world", AddBos::Always).unwrap();
        let mut batch = LlamaBatch::new(512, 1).unwrap();
        batch.add_sequence(&toks, 0, false).unwrap();
        ctx.decode(&mut batch).unwrap();

        assert!(ctx.kv_cache_seq_add(0, Some(0), None, 1).is_ok());
    }

    #[test]
    #[serial]
    fn kv_cache_seq_div_succeeds_on_embedding_model() {
        let (backend, model) = test_model::load_default_embedding_model().unwrap();
        let mut ctx = model.new_context(&backend, test_params()).unwrap();

        let toks = model.str_to_token("Hello world", AddBos::Always).unwrap();
        let mut batch = LlamaBatch::new(512, 1).unwrap();
        batch.add_sequence(&toks, 0, false).unwrap();
        ctx.decode(&mut batch).unwrap();

        let d = std::num::NonZeroU8::new(2).unwrap();
        assert!(ctx.kv_cache_seq_div(0, Some(0), None, d).is_ok());
    }

    #[test]
    #[serial]
    fn kv_cache_seq_pos_max_returns_negative_one_for_unused_seq() {
        let (backend, model) = test_model::load_default_model().unwrap();
        let ctx = model.new_context(&backend, test_params()).unwrap();

        // Nothing was decoded, so an arbitrary sequence id has no positions.
        assert_eq!(ctx.kv_cache_seq_pos_max(999), -1);
    }

    #[test]
    #[serial]
    fn copy_kv_cache_seq_rejects_p0_exceeding_i32_max() {
        let (backend, model) = test_model::load_default_model().unwrap();
        let mut ctx = model.new_context(&backend, test_params()).unwrap();

        let err = ctx.copy_kv_cache_seq(0, 1, Some(u32::MAX), None).unwrap_err();

        assert_eq!(
            err,
            super::KvCacheConversionError::P0TooLarge(i32::try_from(u32::MAX).unwrap_err()),
        );
    }

    #[test]
    #[serial]
    fn copy_kv_cache_seq_rejects_p1_exceeding_i32_max() {
        let (backend, model) = test_model::load_default_model().unwrap();
        let mut ctx = model.new_context(&backend, test_params()).unwrap();

        let err = ctx.copy_kv_cache_seq(0, 1, Some(0), Some(u32::MAX)).unwrap_err();

        assert_eq!(
            err,
            super::KvCacheConversionError::P1TooLarge(i32::try_from(u32::MAX).unwrap_err()),
        );
    }

    #[test]
    #[serial]
    fn clear_kv_cache_seq_rejects_src_exceeding_i32_max() {
        let (backend, model) = test_model::load_default_model().unwrap();
        let mut ctx = model.new_context(&backend, test_params()).unwrap();

        let err = ctx.clear_kv_cache_seq(Some(u32::MAX), None, None).unwrap_err();

        assert_eq!(
            err,
            super::KvCacheConversionError::SeqIdTooLarge(i32::try_from(u32::MAX).unwrap_err()),
        );
    }

    #[test]
    #[serial]
    fn clear_kv_cache_seq_rejects_p0_exceeding_i32_max() {
        let (backend, model) = test_model::load_default_model().unwrap();
        let mut ctx = model.new_context(&backend, test_params()).unwrap();

        let err = ctx.clear_kv_cache_seq(Some(0), Some(u32::MAX), None).unwrap_err();

        assert_eq!(
            err,
            super::KvCacheConversionError::P0TooLarge(i32::try_from(u32::MAX).unwrap_err()),
        );
    }

    #[test]
    #[serial]
    fn clear_kv_cache_seq_rejects_p1_exceeding_i32_max() {
        let (backend, model) = test_model::load_default_model().unwrap();
        let mut ctx = model.new_context(&backend, test_params()).unwrap();

        let err = ctx.clear_kv_cache_seq(Some(0), Some(0), Some(u32::MAX)).unwrap_err();

        assert_eq!(
            err,
            super::KvCacheConversionError::P1TooLarge(i32::try_from(u32::MAX).unwrap_err()),
        );
    }

    #[test]
    #[serial]
    fn kv_cache_seq_add_rejects_p0_exceeding_i32_max() {
        let (backend, model) = test_model::load_default_model().unwrap();
        let mut ctx = model.new_context(&backend, test_params()).unwrap();

        let err = ctx.kv_cache_seq_add(0, Some(u32::MAX), None, 1).unwrap_err();

        assert_eq!(
            err,
            super::KvCacheConversionError::P0TooLarge(i32::try_from(u32::MAX).unwrap_err()),
        );
    }

    #[test]
    #[serial]
    fn kv_cache_seq_add_rejects_p1_exceeding_i32_max() {
        let (backend, model) = test_model::load_default_model().unwrap();
        let mut ctx = model.new_context(&backend, test_params()).unwrap();

        let err = ctx.kv_cache_seq_add(0, Some(0), Some(u32::MAX), 1).unwrap_err();

        assert_eq!(
            err,
            super::KvCacheConversionError::P1TooLarge(i32::try_from(u32::MAX).unwrap_err()),
        );
    }

    #[test]
    #[serial]
    fn kv_cache_seq_div_rejects_p0_exceeding_i32_max() {
        let (backend, model) = test_model::load_default_model().unwrap();
        let mut ctx = model.new_context(&backend, test_params()).unwrap();

        let d = std::num::NonZeroU8::new(2).unwrap();
        let err = ctx.kv_cache_seq_div(0, Some(u32::MAX), None, d).unwrap_err();

        assert_eq!(
            err,
            super::KvCacheConversionError::P0TooLarge(i32::try_from(u32::MAX).unwrap_err()),
        );
    }

    #[test]
    #[serial]
    fn kv_cache_seq_div_rejects_p1_exceeding_i32_max() {
        let (backend, model) = test_model::load_default_model().unwrap();
        let mut ctx = model.new_context(&backend, test_params()).unwrap();

        let d = std::num::NonZeroU8::new(2).unwrap();
        let err = ctx.kv_cache_seq_div(0, Some(0), Some(u32::MAX), d).unwrap_err();

        assert_eq!(
            err,
            super::KvCacheConversionError::P1TooLarge(i32::try_from(u32::MAX).unwrap_err()),
        );
    }
}