/// Element type of cached rows.
pub type Qfloat = f32;

/// Sentinel link value meaning "not linked into the LRU list".
const NONE: usize = usize::MAX;

/// Links of one node in the cache's circular, doubly-linked LRU list.
///
/// Nodes refer to each other by index into the cache's `nodes` vector.
#[derive(Debug, Clone, Copy)]
struct LruNode {
    /// Index of the previous node, or `NONE` when unlinked.
    prev: usize,
    /// Index of the next node, or `NONE` when unlinked.
    next: usize,
}

/// Fixed-budget LRU cache of `Qfloat` rows, one slot per index in `0..l`.
///
/// A row is cached as a prefix of any length. [`Cache::get_data`] extends that
/// prefix on demand. When the element budget would otherwise be exceeded, it
/// evicts the least recently used rows.
#[derive(Debug)]
pub struct Cache {
    /// Number of row slots. `nodes[l]` is the LRU list's sentinel head.
    l: usize,
    /// Unused part of the budget, in `Qfloat` elements.
    size: usize,
    /// Cached contents of each row (`None` when nothing is cached).
    data: Vec<Option<Vec<Qfloat>>>,
    /// Number of cached entries per row. Kept equal to that row's `data` length
    /// (0 when `None`).
    len: Vec<usize>,
    /// LRU links for rows `0..l`, followed by the sentinel head at index `l`.
    /// `nodes[l].next` is the least recently used row and `nodes[l].prev` the
    /// most recently used one.
    nodes: Vec<LruNode>,
}

impl Cache {
    /// Creates an empty cache for `l` rows with a budget of about `size_bytes` bytes.
    ///
    /// The byte budget is converted to `Qfloat` elements and charged for the
    /// LRU bookkeeping. The result never drops below `2 * l` elements, so a row
    /// of up to `l` entries always fits once the other rows are evicted.
    pub fn new(l: usize, size_bytes: usize) -> Self {
        let elem_bytes = std::mem::size_of::<Qfloat>();
        let header_size = l * std::mem::size_of::<LruNode>() / elem_bytes;
        // The `max` guarantees the subtraction cannot underflow.
        let size = (size_bytes / elem_bytes).max(2 * l + header_size) - header_size;

        let mut nodes = vec![LruNode { prev: NONE, next: NONE }; l];
        // Sentinel head: an empty circular list links to itself.
        nodes.push(LruNode { prev: l, next: l });

        Cache {
            l,
            size,
            data: vec![None; l],
            len: vec![0; l],
            nodes,
        }
    }

    /// Unlinks row `i` from the LRU list and marks it as unlinked.
    fn lru_delete(&mut self, i: usize) {
        let LruNode { prev, next } = self.nodes[i];
        self.nodes[prev].next = next;
        self.nodes[next].prev = prev;
        self.nodes[i] = LruNode { prev: NONE, next: NONE };
    }

    /// Links row `i` at the most-recently-used end, just before the sentinel.
    fn lru_insert(&mut self, i: usize) {
        let head = self.l;
        let tail = self.nodes[head].prev;
        self.nodes[i] = LruNode { prev: tail, next: head };
        self.nodes[tail].next = i;
        self.nodes[head].prev = i;
    }

    /// Returns whether row `i` is currently linked into the LRU list.
    fn in_lru(&self, i: usize) -> bool {
        self.nodes[i].prev != NONE
    }

    /// Discards row `i`'s cached contents and returns its entries to the budget.
    ///
    /// Row `i` must currently be linked into the LRU list.
    fn evict(&mut self, i: usize) {
        self.lru_delete(i);
        self.size += self.len[i];
        self.data[i] = None;
        self.len[i] = 0;
    }

    /// Returns row `index`, extended to hold at least `request_len` entries,
    /// and marks it as most recently used.
    ///
    /// The result is `(row, start)`:
    /// - Entries `row[..start]` were already cached.
    /// - Entries `row[start..request_len]` were just allocated (zero-filled) and
    ///   must be computed by the caller.
    /// - `start <= request_len` always holds, so `row[start..request_len]` is a
    ///   valid, possibly empty, range.
    /// - `row` may be longer than `request_len` if the row was previously cached
    ///   with more entries.
    ///
    /// Least recently used rows are evicted as needed to stay within budget.
    ///
    /// # Panics
    ///
    /// - If `index >= l`.
    /// - If extending the row would exceed the whole cache budget even after
    ///   every other row has been evicted. This cannot happen when
    ///   `request_len <= l`.
    pub fn get_data(&mut self, index: usize, request_len: usize) -> (&mut [Qfloat], usize) {
        assert!(index < self.l);

        // Take the row out of the list so the eviction loop below cannot pick it.
        if self.in_lru(index) {
            self.lru_delete(index);
        }

        let old_len = self.len[index];
        let more = request_len.saturating_sub(old_len);

        if more > 0 {
            // Free budget by evicting from the least-recently-used end.
            let head = self.l;
            while self.size < more {
                let victim = self.nodes[head].next;
                if victim == head {
                    break;
                }
                self.evict(victim);
            }
            // Without this check `self.size -= more` would underflow. In release
            // builds that silently wraps, and the cache then stops evicting.
            assert!(
                self.size >= more,
                "cache budget exhausted: row {} needs {} more entries but only {} are free",
                index,
                more,
                self.size
            );

            let row = self.data[index].get_or_insert_with(Vec::new);
            // Grow by exactly `more`. Amortized growth could allocate up to twice
            // the memory the budget accounts for.
            row.reserve_exact(more);
            row.resize(request_len, 0.0);
            self.size -= more;
            self.len[index] = request_len;
        }

        self.lru_insert(index);

        let row = self.data[index].get_or_insert_with(Vec::new);
        (row.as_mut_slice(), old_len.min(request_len))
    }

    /// Exchanges indices `i` and `j` throughout the cache.
    ///
    /// - The cached contents of rows `i` and `j` are swapped.
    /// - In every cached row, the entries at columns `i` and `j` are swapped.
    /// - A cached row that holds the lower of the two columns but not the higher
    ///   one cannot be patched, so it is evicted.
    ///
    /// # Panics
    ///
    /// Panics if `i != j` and either index is `>= l`.
    pub fn swap_index(&mut self, i: usize, j: usize) {
        if i == j {
            return;
        }
        // Index `l` is the LRU sentinel. Unlinking it would corrupt the list.
        assert!(i < self.l && j < self.l);

        let i_in = self.in_lru(i);
        let j_in = self.in_lru(j);
        if i_in {
            self.lru_delete(i);
        }
        if j_in {
            self.lru_delete(j);
        }

        self.data.swap(i, j);
        self.len.swap(i, j);

        // Each row's contents now live at the other index, so relink accordingly.
        if i_in {
            self.lru_insert(j);
        }
        if j_in {
            self.lru_insert(i);
        }

        let (lo, hi) = if i < j { (i, j) } else { (j, i) };
        let head = self.l;
        let mut h = self.nodes[head].next;
        while h != head {
            // Read the successor first, because evicting `h` clears its links.
            let next = self.nodes[h].next;
            let cached = self.len[h];
            if cached > hi {
                if let Some(row) = self.data[h].as_mut() {
                    row.swap(lo, hi);
                }
            } else if cached > lo {
                // Column `lo` is cached but `hi` is not, so the row cannot be fixed up.
                self.evict(h);
            }
            h = next;
        }
    }
}
205
#[cfg(test)]
mod tests {
    use super::*;

    /// Number of `Qfloat` elements the constructor charges for LRU bookkeeping
    /// on this target.
    ///
    /// `LruNode` holds two `usize` links, so this depends on pointer width.
    /// Byte budgets in tests are derived from it instead of being hard-coded.
    fn header_elems(l: usize) -> usize {
        l * std::mem::size_of::<LruNode>() / std::mem::size_of::<Qfloat>()
    }

    /// Byte budget that leaves room for exactly `rows` full rows of `l` entries.
    ///
    /// Only meaningful for `rows >= 2`, because the cache never budgets fewer
    /// than `2 * l` elements.
    fn bytes_for_rows(l: usize, rows: usize) -> usize {
        (rows * l + header_elems(l)) * std::mem::size_of::<Qfloat>()
    }

    #[test]
    fn basic_get_and_fill() {
        let mut cache = Cache::new(3, 100);
        let (data, start) = cache.get_data(0, 3);
        assert_eq!(start, 0);
        assert_eq!(data.len(), 3);
        data[0] = 1.0;
        data[1] = 2.0;
        data[2] = 3.0;

        // A repeat request of the same length is served entirely from the cache.
        let (data, start) = cache.get_data(0, 3);
        assert_eq!(start, 3);
        assert_eq!(data[0], 1.0);
        assert_eq!(data[1], 2.0);
        assert_eq!(data[2], 3.0);
    }

    #[test]
    fn extend_cached_row() {
        let mut cache = Cache::new(3, 1000);
        let (data, start) = cache.get_data(0, 2);
        assert_eq!(start, 0);
        data[0] = 10.0;
        data[1] = 20.0;

        // Growing the row keeps the cached prefix and reports where to resume.
        let (data, start) = cache.get_data(0, 3);
        assert_eq!(start, 2);
        assert_eq!(data.len(), 3);
        assert_eq!(data[0], 10.0);
        assert_eq!(data[1], 20.0);
        data[2] = 30.0;
    }

    #[test]
    fn lru_eviction() {
        let l = 10;
        // Room for exactly two full rows, whatever the pointer width.
        let mut cache = Cache::new(l, bytes_for_rows(l, 2));

        let (data, start) = cache.get_data(0, l);
        assert_eq!(start, 0);
        data[0] = 1.0;

        let (data, start) = cache.get_data(1, l);
        assert_eq!(start, 0);
        data[0] = 3.0;

        // The budget is full, so loading row 2 evicts row 0 (least recently used).
        let (data, start) = cache.get_data(2, l);
        assert_eq!(start, 0);
        data[0] = 5.0;

        let (_, start) = cache.get_data(0, l);
        assert_eq!(start, 0);
    }

    #[test]
    fn lru_order_respects_access() {
        let l = 5;
        let row_len = l;
        let mut cache = Cache::new(l, bytes_for_rows(l, 3));

        let (d, _) = cache.get_data(0, row_len);
        d[0] = 10.0;
        let (d, _) = cache.get_data(1, row_len);
        d[0] = 20.0;
        let (d, _) = cache.get_data(2, row_len);
        d[0] = 30.0;

        // Touch row 0 so that row 1 becomes the least recently used.
        let (d, start) = cache.get_data(0, row_len);
        assert_eq!(start, row_len);
        assert_eq!(d[0], 10.0);

        let (d, start) = cache.get_data(3, row_len);
        assert_eq!(start, 0);
        d[0] = 40.0;

        assert!(cache.data[1].is_none());
        assert!(cache.data[0].is_some());
        assert!(cache.data[2].is_some());
    }

    #[test]
    fn swap_index_row_swap() {
        let mut cache = Cache::new(3, 1000);
        let (data, _) = cache.get_data(0, 3);
        data[0] = 10.0;
        data[1] = 20.0;
        data[2] = 30.0;

        cache.swap_index(0, 2);

        // Row 0's contents moved to row 2, and columns 0 and 2 were swapped within it.
        let (data, start) = cache.get_data(2, 3);
        assert_eq!(start, 3);
        assert_eq!(data[0], 30.0);
        assert_eq!(data[1], 20.0);
        assert_eq!(data[2], 10.0);

        let (_, start) = cache.get_data(0, 3);
        assert_eq!(start, 0);
    }

    #[test]
    fn swap_index_swaps_columns_in_other_rows() {
        let mut cache = Cache::new(4, 10000);

        let (data, _) = cache.get_data(0, 4);
        data[0] = 1.0;
        data[1] = 2.0;
        data[2] = 3.0;
        data[3] = 4.0;

        let (data, _) = cache.get_data(1, 4);
        data[0] = 10.0;
        data[1] = 20.0;
        data[2] = 30.0;
        data[3] = 40.0;

        cache.swap_index(1, 3);

        let (data, start) = cache.get_data(0, 4);
        assert_eq!(start, 4);
        assert_eq!(data[0], 1.0);
        assert_eq!(data[1], 4.0);
        assert_eq!(data[2], 3.0);
        assert_eq!(data[3], 2.0);

        let (data, start) = cache.get_data(3, 4);
        assert_eq!(start, 4);
        assert_eq!(data[0], 10.0);
        assert_eq!(data[1], 40.0);
        assert_eq!(data[2], 30.0);
        assert_eq!(data[3], 20.0);
    }

    #[test]
    fn swap_index_evicts_partial_rows() {
        let mut cache = Cache::new(4, 10000);

        let (data, _) = cache.get_data(0, 4);
        data[0] = 1.0;
        data[1] = 2.0;
        data[2] = 3.0;
        data[3] = 4.0;

        let (data, _) = cache.get_data(1, 2);
        data[0] = 10.0;
        data[1] = 20.0;

        // The length-2 row moves to index 3. It holds column 1 but not column 3,
        // so it is evicted.
        cache.swap_index(1, 3);

        let (data, start) = cache.get_data(0, 4);
        assert_eq!(start, 4);
        assert_eq!(data[1], 4.0);
        assert_eq!(data[3], 2.0);

        let (_, start) = cache.get_data(1, 2);
        assert_eq!(start, 0);
    }
}