// kalosm_common/kv_cache.rs

1use candle_core::Tensor;
2
/// A growable kv cache. This cache wraps candles [`KvCache`] with exponentially larger allocations as the sequence length increases.
#[derive(Debug, Clone)]
pub struct KvCache {
    // The wrapped candle cache; swapped out for a larger allocation as it fills up.
    cache: candle_nn::kv_cache::KvCache,
    // The dimension along which new key/value tensors are concatenated (the sequence dimension).
    concat_dim: usize,
    // Upper bound on the cache allocation — the model's maximum sequence length.
    max_seq_len: usize,
}
10
11impl KvCache {
12    /// Create a new cache with the given max sequence length.
13    pub fn new(concat_dim: usize, max_seq_len: usize) -> Self {
14        Self {
15            cache: candle_nn::kv_cache::KvCache::new(concat_dim, 8),
16            concat_dim,
17            max_seq_len,
18        }
19    }
20
21    /// Get the raw cache.
22    pub fn cache(&self) -> &candle_nn::kv_cache::KvCache {
23        &self.cache
24    }
25
26    /// Get the raw cache mutably.
27    pub fn cache_mut(&mut self) -> &mut candle_nn::kv_cache::KvCache {
28        &mut self.cache
29    }
30
31    /// Reset the cache.
32    pub fn reset(&mut self) {
33        self.cache.reset()
34    }
35
36    /// Append a new key/value pair to the cache.
37    pub fn append(&mut self, k: &Tensor, v: &Tensor) -> candle_core::Result<(Tensor, Tensor)> {
38        let k = k.contiguous()?;
39        let v = v.contiguous()?;
40        let seq_len = k.dim(self.concat_dim)?;
41        // The key and value token length must be the same.
42        debug_assert_eq!(seq_len, v.dim(self.concat_dim)?);
43
44        let current_allocated_size = self.cache.k_cache().max_seq_len();
45        let size_required_for_append = self.cache.current_seq_len() + seq_len;
46
47        // If adding the new key/value pair would exceed the max sequence length, we need to allocate a new tensor with double the size or the max sequence length whichever is smaller.
48        if size_required_for_append > current_allocated_size {
49            // The new size of the cache is double the old size or the max sequence length of the model.
50            // We try to keep the new size a power of two to keep memory alignment nice.
51            let next_power_of_two = size_required_for_append.next_power_of_two();
52            let new_cache_max_seq_len = next_power_of_two.min(self.max_seq_len);
53
54            // Create a new cache with the new size.
55            let mut new_cache =
56                candle_nn::kv_cache::KvCache::new(self.concat_dim, new_cache_max_seq_len);
57            // Append the old cache to the new cache.
58            if let (Ok(Some(k)), Ok(Some(v))) = (self.cache.k(), self.cache.v()) {
59                new_cache.k_cache_mut().append(&k.contiguous()?)?;
60                new_cache.v_cache_mut().append(&v.contiguous()?)?;
61            }
62            // Replace the old cache with the new cache.
63            self.cache = new_cache;
64        }
65
66        self.cache.append(&k, &v)
67    }
68}
69
/// A growable tensor cache. This cache wraps candles [`Cache`] with exponentially larger allocations as the sequence length increases.
#[derive(Debug, Clone)]
pub struct TensorCache {
    // The wrapped candle cache; swapped out for a larger allocation as it fills up.
    cache: candle_nn::kv_cache::Cache,
    // The dimension along which new tensors are concatenated (the sequence dimension).
    concat_dim: usize,
    // Upper bound on the cache allocation — the model's maximum sequence length.
    max_seq_len: usize,
}
77
78impl TensorCache {
79    /// Create a new cache with the given max sequence length.
80    pub fn new(concat_dim: usize, max_seq_len: usize) -> Self {
81        Self {
82            cache: candle_nn::kv_cache::Cache::new(concat_dim, 8),
83            concat_dim,
84            max_seq_len,
85        }
86    }
87
88    /// Get the raw cache.
89    pub fn cache(&self) -> &candle_nn::kv_cache::Cache {
90        &self.cache
91    }
92
93    /// Get the raw cache mutably.
94    pub fn cache_mut(&mut self) -> &mut candle_nn::kv_cache::Cache {
95        &mut self.cache
96    }
97
98    /// Get the current tensor in the cache.
99    pub fn all_data(&self) -> &Option<Tensor> {
100        self.cache.all_data()
101    }
102
103    /// Reset the cache.
104    pub fn reset(&mut self) {
105        self.cache.reset()
106    }
107
108    /// Append a new value to the cache.
109    pub fn append(&mut self, v: &Tensor) -> candle_core::Result<()> {
110        let v = v.contiguous()?;
111        let seq_len = v.dim(self.concat_dim)?;
112        // The key and value token length must be the same.
113        debug_assert_eq!(seq_len, v.dim(self.concat_dim)?);
114
115        let current_allocated_size = self.cache.max_seq_len();
116        let size_required_for_append = self.cache.current_seq_len() + seq_len;
117
118        // If adding the new key/value pair would exceed the max sequence length, we need to allocate a new tensor with double the size or the max sequence length whichever is smaller.
119        if size_required_for_append > current_allocated_size {
120            // The new size of the cache is double the old size or the max sequence length of the model.
121            // We try to keep the new size a power of two to keep memory alignment nice.
122            let next_power_of_two = size_required_for_append.next_power_of_two();
123            let new_cache_max_seq_len = next_power_of_two.min(self.max_seq_len);
124
125            // Create a new cache with the new size.
126            let mut new_cache =
127                candle_nn::kv_cache::Cache::new(self.concat_dim, new_cache_max_seq_len);
128            // Append the old cache to the new cache.
129            if let Some(v) = self.cache.all_data() {
130                new_cache.append(&v.contiguous()?)?;
131            }
132            // Replace the old cache with the new cache.
133            self.cache = new_cache;
134        }
135
136        // self.cache.append(&k, &v)
137        self.cache.append(&v)
138    }
139}