1use candle::{DType, Device, Result, Tensor};
6use candle_nn::kv_cache::RotatingKvCache;
7
/// Scatter indices and the matching attention mask produced by
/// `ScatteredCacheBuilder::indices_and_mask` for one append step.
#[derive(Debug, Clone)]
pub struct IndicesAndMask {
    // Per-batch cache slot for each new token (u32 values), used to scatter
    // the new k/v entries into the cache.
    indices: Tensor,
    // Additive attention mask (0.0 = attend, -inf = masked),
    // shape (batch, 1, seq_len, context) on the rotating path.
    mask: Tensor,
}
13
14impl IndicesAndMask {
15 pub fn mask(&self) -> &Tensor {
16 &self.mask
17 }
18}
19
/// Key/value cache whose entries are written at arbitrary per-batch slots via
/// scatter, enabling ring-buffer style reuse of a fixed `context` window.
#[derive(Debug, Clone)]
pub struct ScatteredKvCache {
    // Cached keys, shape (batch, num_heads, context, head_dim); see `make_cache`.
    k: Tensor,
    // Cached values, same shape as `k`.
    v: Tensor,
    // Number of cache slots along dimension 2.
    context: usize,
}
26
27impl ScatteredKvCache {
28 pub fn append(
29 &mut self,
30 k: &Tensor,
31 v: &Tensor,
32 iam: &IndicesAndMask,
33 ) -> Result<(Tensor, Tensor)> {
34 if self.context <= k.dim(2)? {
35 return Ok((k.clone(), v.clone()));
36 }
37 let indices = iam.indices.unsqueeze(2)?.unsqueeze(1)?;
38 let indices = indices.broadcast_as(k.shape())?.contiguous()?;
39 self.k.scatter_set(&indices, k, 2)?;
40 self.v.scatter_set(&indices, v, 2)?;
41 Ok((self.k.clone(), self.v.clone()))
42 }
43
44 pub fn k(&self) -> &Tensor {
45 &self.k
46 }
47
48 pub fn v(&self) -> &Tensor {
49 &self.v
50 }
51}
52
/// Per-batch bookkeeping used to drive a `ScatteredKvCache`: tracks, for each
/// batch entry, the next write slot and the absolute token position.
#[derive(Debug, Clone)]
pub struct ScatteredCacheBuilder {
    // Number of cache slots (ring-buffer capacity).
    context: usize,
    // Absolute position (total tokens seen) per batch entry.
    positions: Vec<usize>,
    // Next cache slot to write per batch entry; wraps at `context`.
    indices: Vec<usize>,
    dtype: DType,
    device: Device,
}
63
64impl ScatteredCacheBuilder {
65 pub fn new(batch_size: usize, context: usize, dtype: DType, device: &Device) -> Result<Self> {
66 let positions = vec![0; batch_size];
67 let indices = vec![0; batch_size];
68 Ok(Self { positions, indices, context, dtype, device: device.clone() })
69 }
70
71 pub fn make_cache(&self, num_heads: usize, head_dim: usize) -> Result<ScatteredKvCache> {
72 let batch_size = self.batch_size();
73 let shape = (batch_size, num_heads, self.context, head_dim);
74 let k = Tensor::zeros(shape, self.dtype, self.device())?;
75 let v = Tensor::zeros(shape, self.dtype, self.device())?;
76 Ok(ScatteredKvCache { k, v, context: self.context })
77 }
78
    /// Absolute position (total tokens appended so far) for each batch entry.
    pub fn positions(&self) -> &[usize] {
        &self.positions
    }
82
83 pub fn reset(&mut self) {
84 self.positions.fill(0);
85 self.indices.fill(0);
86 }
87
    /// Number of batch entries this builder tracks.
    pub fn batch_size(&self) -> usize {
        self.positions.len()
    }
91
92 pub fn reset_batch_index(&mut self, batch_index: usize) {
93 self.positions[batch_index] = 0;
94 self.indices[batch_index] = 0;
95 }
96
    /// Build the scatter indices and attention mask for appending `seq_len`
    /// new tokens, advancing the per-batch write slot and absolute position.
    ///
    /// `batch_mask[i]` marks whether batch entry `i` is active this step:
    /// inactive entries get an all-zero mask (attend everywhere) and a
    /// repeated, non-advancing slot index.
    #[allow(clippy::needless_range_loop)]
    pub fn indices_and_mask(
        &mut self,
        seq_len: usize,
        batch_mask: &[bool],
    ) -> Result<IndicesAndMask> {
        let context = self.context;
        // When the chunk does not fit inside the rotating window, fall back
        // to the absolute (non-rotating) masking scheme.
        if self.context <= seq_len {
            return self.indices_and_mask_abs(seq_len, batch_mask);
        }
        let mut attention_masks = Vec::with_capacity(self.batch_size());
        let mut cache_indices = Vec::with_capacity(self.batch_size());
        for (batch_i, &batch_mask) in batch_mask.iter().enumerate() {
            if !batch_mask {
                // Inactive entry: all-zero mask, current slot repeated so no
                // state advances. NOTE(review): the scatter in `append` will
                // still write this slot; it is overwritten on the next active
                // step since the slot index does not move — confirm intended.
                let masks: Vec<Vec<f32>> = vec![vec![0.0; context]; seq_len];
                let indices = vec![self.indices[batch_i] as u32; seq_len];
                attention_masks.push(masks);
                cache_indices.push(indices);
            } else {
                let start_index = self.indices[batch_i];
                let start_pos = self.positions[batch_i];
                let mut masks: Vec<Vec<f32>> = Vec::with_capacity(seq_len);
                let mut indices = Vec::with_capacity(seq_len);
                // all_pos[slot] = absolute position currently stored in that
                // cache slot; usize::MAX marks a never-written slot.
                let mut all_pos = vec![usize::MAX; context];
                if start_pos < context {
                    // Cache has not wrapped yet: slot i holds position i.
                    for i in 0..start_pos {
                        all_pos[i] = i;
                    }
                } else {
                    // Cache has wrapped: slots below the write head hold the
                    // most recent lap, the remaining slots are one lap older.
                    let offset = start_pos - start_index;
                    for i in 0..context {
                        all_pos[i] =
                            if i < start_index { i + offset } else { i + offset - context };
                    }
                }
                // Assign a slot to each new token, advancing and wrapping the
                // write head and the absolute position.
                for seq_i in 0..seq_len {
                    let index = self.indices[batch_i];
                    all_pos[index] = seq_i + start_pos;
                    indices.push(index as u32);
                    self.indices[batch_i] += 1;
                    self.positions[batch_i] += 1;
                    if self.indices[batch_i] >= self.context {
                        self.indices[batch_i] = 0;
                    }
                }

                // A query may attend a slot iff the position stored there is
                // not in its future; this also hides empty slots and slots
                // overwritten by later tokens of this same chunk.
                for seq_i in 0..seq_len {
                    let my_pos = seq_i + start_pos;
                    let mask = all_pos
                        .iter()
                        .map(|&pos| if pos <= my_pos { 0.0 } else { f32::NEG_INFINITY })
                        .collect::<Vec<f32>>();
                    masks.push(mask);
                }

                attention_masks.push(masks);
                cache_indices.push(indices);
            }
        }
        // Flatten into a (batch, 1, seq_len, context) mask tensor.
        let attention_masks =
            attention_masks.into_iter().flat_map(|m| m.into_iter().flatten()).collect::<Vec<f32>>();
        let mask = Tensor::from_vec(attention_masks, ((), 1, seq_len, context), self.device())?
            .to_dtype(self.dtype)?;
        let indices = Tensor::new(cache_indices, self.device())?;
        Ok(IndicesAndMask { indices, mask })
    }
167
    /// The device on which cache tensors and masks are allocated.
    pub fn device(&self) -> &Device {
        &self.device
    }
171
    /// Fallback for chunks at least as long as the rotating context: the mask
    /// is a plain banded causal mask shared by every batch entry, while the
    /// per-batch slot and position counters are still advanced (with
    /// wrap-around) for active entries.
    ///
    /// NOTE(review): this mask is 2-D `(seq_len, seq_len)`, unlike the
    /// `(batch, 1, seq_len, context)` mask of the rotating path — presumably
    /// relies on broadcasting in the caller; confirm.
    #[allow(clippy::needless_range_loop)]
    fn indices_and_mask_abs(
        &mut self,
        seq_len: usize,
        batch_mask: &[bool],
    ) -> Result<IndicesAndMask> {
        let mask = self.get_mask_abs(seq_len, seq_len)?;
        let mut cache_indices = Vec::with_capacity(self.batch_size());
        for (batch_i, &batch_mask) in batch_mask.iter().enumerate() {
            if !batch_mask {
                // Inactive entry: repeat the current slot, do not advance.
                let indices = vec![self.indices[batch_i] as u32; seq_len];
                cache_indices.push(indices);
            } else {
                let mut indices = Vec::with_capacity(seq_len);
                for _ in 0..seq_len {
                    let index = self.indices[batch_i];
                    indices.push(index as u32);
                    self.indices[batch_i] += 1;
                    self.positions[batch_i] += 1;
                    // Wrap the write head at the ring-buffer capacity.
                    if self.indices[batch_i] >= self.context {
                        self.indices[batch_i] = 0;
                    }
                }
                cache_indices.push(indices);
            }
        }
        let indices = Tensor::new(cache_indices, self.device())?;
        Ok(IndicesAndMask { indices, mask })
    }
201
202 fn get_mask_abs(&self, size1: usize, size2: usize) -> Result<Tensor> {
203 let context = self.context;
204 let mask: Vec<_> = (0..size1)
205 .flat_map(|i| {
206 (0..size2).map(move |j| {
207 if size1 + j > size2 + i || size1 + j + context < size2 + i {
208 f32::NEG_INFINITY
209 } else {
210 0.0
211 }
212 })
213 })
214 .collect();
215 Tensor::from_slice(&mask, (size1, size2), self.device())
216 }
217}
218
/// Dispatch wrapper over the supported key/value cache implementations.
///
/// Currently only the rotating (ring-buffer) cache variant exists.
#[derive(Debug, Clone)]
pub enum KvCache {
    Rotating(RotatingKvCache),
}
223
224impl KvCache {
225 pub fn new(dim: usize, max_seq_len: usize) -> Self {
226 let cache = RotatingKvCache::new(dim, max_seq_len);
227 Self::Rotating(cache)
228 }
229
230 pub fn current_seq_len(&self) -> usize {
231 match self {
232 KvCache::Rotating(cache) => cache.current_seq_len(),
233 }
234 }
235
236 pub fn reset(&mut self) {
237 match self {
238 KvCache::Rotating(cache) => cache.reset(),
239 }
240 }
241
242 pub fn append(&mut self, key: &Tensor, value: &Tensor) -> Result<(Tensor, Tensor)> {
243 match self {
244 KvCache::Rotating(cache) => cache.append(key, value),
245 }
246 }
247
248 pub fn positions(&self, seq_len: usize) -> Vec<usize> {
249 match self {
250 KvCache::Rotating(cache) => cache.positions(seq_len),
251 }
252 }
253}
254
#[cfg(test)]
mod tests {
    use super::*;
    use candle::IndexOp;

    // Exercises the rotating index/mask bookkeeping over a batch of 2 with a
    // 5-slot context, including wrap-around and in-chunk slot overwrites.
    #[test]
    fn test_scattered_kv_cache() -> Result<()> {
        let device = Device::Cpu;
        let mut cache = ScatteredCacheBuilder::new(2, 5, DType::F32, &device)?;
        let inf = f32::INFINITY;

        // One token for batch 0 only; inactive batch 1 gets an all-zero mask
        // and a non-advancing slot index.
        let iam = cache.indices_and_mask(1, &[true, false])?;
        let mask = iam.mask.i((.., 0))?.to_vec3::<f32>()?;
        assert_eq!(iam.indices.to_vec2::<u32>()?, [[0], [0]]);
        assert_eq!(mask, [[[0.0, -inf, -inf, -inf, -inf]], [[0.0, 0.0, 0.0, 0.0, 0.0]]]);

        // Second token for batch 0: write head advances to slot 1.
        let iam = cache.indices_and_mask(1, &[true, false])?;
        let mask = iam.mask.i((.., 0))?.to_vec3::<f32>()?;
        assert_eq!(iam.indices.to_vec2::<u32>()?, [[1], [0]]);
        assert_eq!(mask, [[[0.0, 0.0, -inf, -inf, -inf]], [[0.0, 0.0, 0.0, 0.0, 0.0]]]);

        // Three tokens for batch 1 only; batch 0 stays parked at slot 2.
        // Batch 1 gets the standard causal staircase over slots 0..=2.
        let iam = cache.indices_and_mask(3, &[false, true])?;
        let mask = iam.mask.i((.., 0))?.to_vec3::<f32>()?;
        assert_eq!(iam.indices.to_vec2::<u32>()?, [[2, 2, 2], [0, 1, 2]]);
        assert_eq!(
            mask,
            [
                [[0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0]],
                [
                    [0.0, -inf, -inf, -inf, -inf],
                    [0.0, 0.0, -inf, -inf, -inf],
                    [0.0, 0.0, 0.0, -inf, -inf]
                ]
            ]
        );

        // Three tokens for both entries. Batch 1 wraps: its third token lands
        // back on slot 0, so slot 0's old content becomes invisible to the
        // first two queries of the chunk.
        let iam = cache.indices_and_mask(3, &[true, true])?;
        let mask = iam.mask.i((.., 0))?.to_vec3::<f32>()?;
        assert_eq!(iam.indices.to_vec2::<u32>()?, [[2, 3, 4], [3, 4, 0]]);
        assert_eq!(
            mask,
            [
                [
                    [0.0, 0.0, 0.0, -inf, -inf],
                    [0.0, 0.0, 0.0, 0.0, -inf],
                    [0.0, 0.0, 0.0, 0.0, 0.0]
                ],
                [
                    [-inf, 0.0, 0.0, 0.0, -inf],
                    [-inf, 0.0, 0.0, 0.0, 0.0],
                    [0.0, 0.0, 0.0, 0.0, 0.0]
                ]
            ]
        );

        // Batch 0 wraps to slot 0; with a full window everything is visible.
        let iam = cache.indices_and_mask(1, &[true, false])?;
        let mask = iam.mask.i((.., 0))?.to_vec3::<f32>()?;
        assert_eq!(iam.indices.to_vec2::<u32>()?, [[0], [1]]);
        assert_eq!(mask, [[[0.0, 0.0, 0.0, 0.0, 0.0]], [[0.0, 0.0, 0.0, 0.0, 0.0]]]);

        // Two tokens for batch 0: slot 2 is overwritten by the second token,
        // so the first query must not attend it (-inf at column 2).
        let iam = cache.indices_and_mask(2, &[true, false])?;
        let mask = iam.mask.i((.., 0))?.to_vec3::<f32>()?;
        assert_eq!(iam.indices.to_vec2::<u32>()?, [[1, 2], [1, 1]]);
        assert_eq!(
            mask,
            [
                [[0.0, 0.0, -inf, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0]],
                [[0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0]]
            ]
        );

        Ok(())
    }
}