// forgellm_runtime/kv_cache.rs
/// Key/value cache holding one flat `f32` key buffer and one flat `f32` value
/// buffer per layer.
///
/// Each call to [`KVCache::append`] pushes one entry of
/// [`KVCache::entry_size`] floats (`num_kv_heads * head_dim`) onto a layer's
/// key and value buffers. The logical length reported by [`KVCache::len`] is
/// tracked separately. It changes only through [`KVCache::advance`] and
/// [`KVCache::clear`].
#[derive(Debug, Clone)]
pub struct KVCache {
    /// Flat key buffer for each layer, indexed by layer.
    k: Vec<Vec<f32>>,
    /// Flat value buffer for each layer, indexed by layer.
    v: Vec<Vec<f32>>,
    /// Number of layers (the length of both `k` and `v`).
    num_layers: usize,
    /// Number of key/value heads contributing to one entry.
    num_kv_heads: usize,
    /// Floats per head; one entry is `num_kv_heads * head_dim` floats.
    head_dim: usize,
    /// Number of `advance` calls since construction or the last `clear`.
    len: usize,
}

impl KVCache {
    /// Creates an empty cache with `num_layers` key/value buffer pairs and no
    /// reserved storage.
    ///
    /// # Panics
    ///
    /// Panics if `num_kv_heads * head_dim` overflows `usize`.
    pub fn new(num_layers: usize, num_kv_heads: usize, head_dim: usize) -> Self {
        // `Vec::with_capacity(0)` does not allocate, so this is equivalent to
        // building every buffer with `Vec::new()`. The shape is still validated.
        Self::with_capacity(num_layers, num_kv_heads, head_dim, 0)
    }

    /// Creates an empty cache whose per-layer key and value buffers each
    /// reserve room for `max_seq_len` entries.
    ///
    /// The reservation is only a hint. Appending more than `max_seq_len`
    /// entries still works; the buffers just reallocate.
    ///
    /// # Panics
    ///
    /// Panics in any of these cases:
    /// - `num_kv_heads * head_dim` overflows `usize`.
    /// - `max_seq_len * num_kv_heads * head_dim` overflows `usize`.
    /// - The reservation itself fails (see [`Vec::with_capacity`]).
    pub fn with_capacity(
        num_layers: usize,
        num_kv_heads: usize,
        head_dim: usize,
        max_seq_len: usize,
    ) -> Self {
        // Checked arithmetic: with overflow checks off (the release default) a
        // plain `*` wraps silently and reserves a meaningless capacity. Failing
        // here also guarantees `entry_size()` can never overflow later.
        let entry_size = num_kv_heads
            .checked_mul(head_dim)
            .expect("KVCache: num_kv_heads * head_dim overflows usize");
        let capacity = max_seq_len
            .checked_mul(entry_size)
            .expect("KVCache: max_seq_len * num_kv_heads * head_dim overflows usize");
        let layer_buffers = || -> Vec<Vec<f32>> {
            (0..num_layers)
                .map(|_| Vec::with_capacity(capacity))
                .collect()
        };
        Self {
            k: layer_buffers(),
            v: layer_buffers(),
            num_layers,
            num_kv_heads,
            head_dim,
            len: 0,
        }
    }

    /// Appends one key entry and one value entry to the buffers of `layer`.
    ///
    /// This does not change [`len`](Self::len); only [`advance`](Self::advance)
    /// increments it.
    ///
    /// # Panics
    ///
    /// Panics if `layer >= self.num_layers()`. In builds with debug
    /// assertions, it also panics if either slice's length differs from
    /// [`entry_size`](Self::entry_size).
    pub fn append(&mut self, layer: usize, k_data: &[f32], v_data: &[f32]) {
        let entry_size = self.entry_size();
        debug_assert_eq!(k_data.len(), entry_size);
        debug_assert_eq!(v_data.len(), entry_size);
        self.k[layer].extend_from_slice(k_data);
        self.v[layer].extend_from_slice(v_data);
    }

    /// Increments the logical length by one.
    ///
    /// It does not check how many entries have been appended to each layer.
    pub fn advance(&mut self) {
        self.len += 1;
    }

    /// Returns every key float appended to `layer`, as one flat slice.
    ///
    /// # Panics
    ///
    /// Panics if `layer >= self.num_layers()`.
    pub fn k(&self, layer: usize) -> &[f32] {
        &self.k[layer]
    }

    /// Returns every value float appended to `layer`, as one flat slice.
    ///
    /// # Panics
    ///
    /// Panics if `layer >= self.num_layers()`.
    pub fn v(&self, layer: usize) -> &[f32] {
        &self.v[layer]
    }

    /// Returns the logical length, i.e. the number of [`advance`](Self::advance)
    /// calls since creation or the last [`clear`](Self::clear).
    pub fn len(&self) -> usize {
        self.len
    }

    /// Returns `true` when the logical length is zero.
    pub fn is_empty(&self) -> bool {
        self.len == 0
    }

    /// Returns the number of layers this cache was created with.
    pub fn num_layers(&self) -> usize {
        self.num_layers
    }

    /// Empties every layer's key and value buffers and resets the length to 0.
    ///
    /// `Vec::clear` keeps allocated capacity, so the buffers can be refilled
    /// without reallocating.
    pub fn clear(&mut self) {
        for buffer in self.k.iter_mut().chain(self.v.iter_mut()) {
            buffer.clear();
        }
        self.len = 0;
    }

    /// Returns the number of floats in one appended entry
    /// (`num_kv_heads * head_dim`).
    pub fn entry_size(&self) -> usize {
        self.num_kv_heads * self.head_dim
    }
}
119
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn create_and_append() {
        let mut cache = KVCache::new(2, 4, 16);
        assert!(cache.is_empty());
        assert_eq!(cache.len(), 0);
        assert_eq!(cache.entry_size(), 64);

        let kv_data = vec![1.0f32; 64];
        cache.append(0, &kv_data, &kv_data);
        cache.append(1, &kv_data, &kv_data);
        cache.advance();

        assert_eq!(cache.len(), 1);
        assert_eq!(cache.k(0).len(), 64);
        assert_eq!(cache.v(1).len(), 64);
    }

    #[test]
    fn append_multiple_tokens() {
        let mut cache = KVCache::new(1, 2, 8);

        for token_idx in 0..5 {
            let data = vec![token_idx as f32; 16];
            cache.append(0, &data, &data);
            cache.advance();
        }

        assert_eq!(cache.len(), 5);
        // 5 entries of 2 heads x 8 floats each.
        assert_eq!(cache.k(0).len(), 80);
        assert_eq!(cache.k(0)[0], 0.0);
        // The entry appended for token 4 starts at offset 4 * 16 = 64.
        assert_eq!(cache.k(0)[64], 4.0);
    }

    #[test]
    fn clear_cache() {
        let mut cache = KVCache::new(2, 4, 16);
        let data = vec![1.0f32; 64];
        cache.append(0, &data, &data);
        cache.append(1, &data, &data);
        cache.advance();
        assert_eq!(cache.len(), 1);

        cache.clear();
        assert_eq!(cache.len(), 0);
        assert!(cache.is_empty());
        // Every layer's key AND value buffer must be emptied, not just layer 0's keys.
        for layer in 0..cache.num_layers() {
            assert!(cache.k(layer).is_empty());
            assert!(cache.v(layer).is_empty());
        }
    }

    #[test]
    fn with_capacity() {
        let cache = KVCache::with_capacity(16, 8, 64, 2048);
        assert!(cache.is_empty());
        assert_eq!(cache.num_layers(), 16);
        assert_eq!(cache.entry_size(), 512);
    }

    #[test]
    fn keys_and_values_are_stored_separately() {
        // Distinct K and V payloads catch an accidental K/V swap, which the
        // tests above cannot detect because they pass the same slice for both.
        let mut cache = KVCache::new(2, 1, 4);
        let ones = [1.0f32, 2.0, 3.0, 4.0];
        let negs = [-1.0f32, -2.0, -3.0, -4.0];
        cache.append(0, &ones, &negs);
        cache.append(1, &negs, &ones);
        cache.advance();

        assert_eq!(cache.k(0), &ones[..]);
        assert_eq!(cache.v(0), &negs[..]);
        assert_eq!(cache.k(1), &negs[..]);
        assert_eq!(cache.v(1), &ones[..]);
    }

    #[test]
    fn append_does_not_advance_length() {
        let mut cache = KVCache::new(1, 1, 2);
        cache.append(0, &[0.5, 0.5], &[0.25, 0.25]);
        // Data is stored immediately, but the logical length only moves on `advance`.
        assert!(cache.is_empty());
        assert_eq!(cache.k(0).len(), 2);
        assert_eq!(cache.v(0).len(), 2);

        cache.advance();
        assert_eq!(cache.len(), 1);
    }
}