/// Scalar type stored in every cached row element.
pub type Qfloat = f32;

/// Link value meaning "this `LruNode` is not currently in the LRU list" (all bits set).
const NONE: usize = !0;
15
16struct LruNode {
18 prev: usize,
19 next: usize,
20}
21
22pub struct Cache {
32 l: usize,
34 size: usize,
36 data: Vec<Option<Vec<Qfloat>>>,
38 len: Vec<usize>,
40 nodes: Vec<LruNode>,
43}
44
45impl Cache {
46 pub fn new(l: usize, size_bytes: usize) -> Self {
48 let mut size = size_bytes / std::mem::size_of::<Qfloat>();
50 let header_size = l * std::mem::size_of::<LruNode>() / std::mem::size_of::<Qfloat>();
52 size = size.max(2 * l + header_size).saturating_sub(header_size);
54
55 let mut nodes: Vec<LruNode> = (0..l)
57 .map(|_| LruNode {
58 prev: NONE,
59 next: NONE,
60 })
61 .collect();
62 nodes.push(LruNode { prev: l, next: l });
64
65 Cache {
66 l,
67 size,
68 data: (0..l).map(|_| None).collect(),
69 len: vec![0; l],
70 nodes,
71 }
72 }
73
74 #[inline]
76 fn lru_delete(&mut self, i: usize) {
77 let prev = self.nodes[i].prev;
78 let next = self.nodes[i].next;
79 self.nodes[prev].next = next;
80 self.nodes[next].prev = prev;
81 self.nodes[i].prev = NONE;
82 self.nodes[i].next = NONE;
83 }
84
85 #[inline]
87 fn lru_insert(&mut self, i: usize) {
88 let head = self.l; let tail = self.nodes[head].prev;
90 self.nodes[i].next = head;
91 self.nodes[i].prev = tail;
92 self.nodes[tail].next = i;
93 self.nodes[head].prev = i;
94 }
95
96 #[inline]
98 fn in_lru(&self, i: usize) -> bool {
99 self.nodes[i].prev != NONE
100 }
101
102 pub fn get_data(&mut self, index: usize, request_len: usize) -> (&mut [Qfloat], usize) {
110 assert!(index < self.l);
111
112 if self.in_lru(index) {
114 self.lru_delete(index);
115 }
116
117 let old_len = self.len[index];
118 let more = request_len.saturating_sub(old_len);
119
120 if more > 0 {
121 let head = self.l;
123 while self.size < more {
124 let victim = self.nodes[head].next;
125 if victim == head {
126 break; }
128 self.lru_delete(victim);
129 self.size += self.len[victim];
130 self.data[victim] = None;
131 self.len[victim] = 0;
132 }
133
134 let entry = self.data[index].get_or_insert_with(Vec::new);
136 entry.resize(request_len, 0.0);
137 self.size -= more;
138 self.len[index] = request_len;
139 }
140
141 self.lru_insert(index);
143
144 let start = old_len;
145 (self.data[index].as_mut().unwrap().as_mut_slice(), start)
146 }
147
148 pub fn swap_index(&mut self, i: usize, j: usize) {
152 if i == j {
153 return;
154 }
155
156 let i_in = self.in_lru(i);
157 let j_in = self.in_lru(j);
158
159 if i_in {
160 self.lru_delete(i);
161 }
162 if j_in {
163 self.lru_delete(j);
164 }
165
166 self.data.swap(i, j);
168 self.len.swap(i, j);
169
170 if i_in {
172 self.lru_insert(j);
173 }
174 if j_in {
175 self.lru_insert(i);
176 }
177
178 let (lo, hi) = if i < j { (i, j) } else { (j, i) };
182 let head = self.l;
183 let mut h = self.nodes[head].next;
184 while h != head {
185 let next = self.nodes[h].next;
186 if self.len[h] > lo {
187 if self.len[h] > hi {
188 if let Some(ref mut row) = self.data[h] {
190 row.swap(lo, hi);
191 }
192 } else {
193 self.lru_delete(h);
195 self.size += self.len[h];
196 self.data[h] = None;
197 self.len[h] = 0;
198 }
199 }
200 h = next;
201 }
202 }
203}
204
#[cfg(test)]
mod tests {
    use super::*;

    /// Elements that `Cache::new` charges for the `l` LRU nodes.
    fn header_elems(l: usize) -> usize {
        l * std::mem::size_of::<LruNode>() / std::mem::size_of::<Qfloat>()
    }

    /// Byte budget that leaves room for exactly `rows` rows of `l` elements once
    /// the node header is subtracted. Smaller budgets are clamped up to two rows,
    /// so use `rows >= 2`. Deriving the budget from the header keeps the tests
    /// independent of pointer width, because the size of `LruNode` depends on
    /// `usize`.
    fn budget_for_rows(l: usize, rows: usize) -> usize {
        (rows * l + header_elems(l)) * std::mem::size_of::<Qfloat>()
    }

    #[test]
    fn basic_get_and_fill() {
        let mut cache = Cache::new(3, 100);
        let (data, start) = cache.get_data(0, 3);
        assert_eq!(start, 0);
        assert_eq!(data.len(), 3);
        data.copy_from_slice(&[1.0, 2.0, 3.0]);

        // Already fully cached: nothing new to fill.
        let (data, start) = cache.get_data(0, 3);
        assert_eq!(start, 3);
        assert_eq!(data.to_vec(), vec![1.0, 2.0, 3.0]);
    }

    #[test]
    fn extend_cached_row() {
        let mut cache = Cache::new(3, 1000);
        let (data, start) = cache.get_data(0, 2);
        assert_eq!(start, 0);
        data.copy_from_slice(&[10.0, 20.0]);

        // Growing keeps the existing prefix and reports where new data begins.
        let (data, start) = cache.get_data(0, 3);
        assert_eq!(start, 2);
        assert_eq!(data[0], 10.0);
        assert_eq!(data[1], 20.0);
        data[2] = 30.0;

        let (data, start) = cache.get_data(0, 3);
        assert_eq!(start, 3);
        assert_eq!(data[2], 30.0);
    }

    #[test]
    fn lru_eviction() {
        let l = 10;
        // Room for exactly two rows of length `l`.
        let mut cache = Cache::new(l, budget_for_rows(l, 2));

        let (data, start) = cache.get_data(0, l);
        assert_eq!(start, 0);
        data[0] = 1.0;

        let (data, start) = cache.get_data(1, l);
        assert_eq!(start, 0);
        data[0] = 3.0;

        // A third row forces out row 0, the least recently used.
        let (data, start) = cache.get_data(2, l);
        assert_eq!(start, 0);
        data[0] = 5.0;

        // Row 0 was evicted, so it comes back freshly zero-filled.
        let (data, start) = cache.get_data(0, l);
        assert_eq!(start, 0);
        assert_eq!(data[0], 0.0);
    }

    #[test]
    fn lru_order_respects_access() {
        let l = 5;
        let row_len = l;
        // Room for exactly three full rows.
        let mut cache = Cache::new(l, budget_for_rows(l, 3));

        cache.get_data(0, row_len).0[0] = 10.0;
        cache.get_data(1, row_len).0[0] = 20.0;
        cache.get_data(2, row_len).0[0] = 30.0;

        // Touch row 0 so row 1 becomes the least recently used.
        let (d, start) = cache.get_data(0, row_len);
        assert_eq!(start, row_len);
        assert_eq!(d[0], 10.0);

        let (d, start) = cache.get_data(3, row_len);
        assert_eq!(start, 0);
        d[0] = 40.0;

        assert!(cache.data[1].is_none());
        assert!(cache.data[0].is_some());
        assert!(cache.data[2].is_some());
    }

    #[test]
    fn swap_index_same_index_is_noop() {
        let mut cache = Cache::new(3, 1000);
        cache.get_data(1, 3).0.copy_from_slice(&[1.0, 2.0, 3.0]);
        cache.swap_index(1, 1);
        let (data, start) = cache.get_data(1, 3);
        assert_eq!(start, 3);
        assert_eq!(data.to_vec(), vec![1.0, 2.0, 3.0]);
    }

    #[test]
    fn swap_index_row_swap() {
        let mut cache = Cache::new(3, 1000);
        cache.get_data(0, 3).0.copy_from_slice(&[10.0, 20.0, 30.0]);

        cache.swap_index(0, 2);

        // The row moved to index 2, and its own columns 0 and 2 were swapped too.
        let (data, start) = cache.get_data(2, 3);
        assert_eq!(start, 3);
        assert_eq!(data.to_vec(), vec![30.0, 20.0, 10.0]);

        let (_, start) = cache.get_data(0, 3);
        assert_eq!(start, 0);
    }

    #[test]
    fn swap_index_swaps_columns_in_other_rows() {
        let mut cache = Cache::new(4, 10000);
        cache.get_data(0, 4).0.copy_from_slice(&[1.0, 2.0, 3.0, 4.0]);
        cache.get_data(1, 4).0.copy_from_slice(&[10.0, 20.0, 30.0, 40.0]);

        cache.swap_index(1, 3);

        let (data, start) = cache.get_data(0, 4);
        assert_eq!(start, 4);
        assert_eq!(data.to_vec(), vec![1.0, 4.0, 3.0, 2.0]);

        let (data, start) = cache.get_data(3, 4);
        assert_eq!(start, 4);
        assert_eq!(data.to_vec(), vec![10.0, 40.0, 30.0, 20.0]);
    }

    #[test]
    fn swap_index_evicts_partial_rows() {
        let mut cache = Cache::new(4, 10000);
        cache.get_data(0, 4).0.copy_from_slice(&[1.0, 2.0, 3.0, 4.0]);
        cache.get_data(1, 2).0.copy_from_slice(&[10.0, 20.0]);

        // Row 1 moves to index 3. It contains column 1 but not column 3, so it is evicted.
        cache.swap_index(1, 3);

        let (data, start) = cache.get_data(0, 4);
        assert_eq!(start, 4);
        assert_eq!(data[1], 4.0);
        assert_eq!(data[3], 2.0);

        let (_, start) = cache.get_data(1, 2);
        assert_eq!(start, 0);
    }
}