/// Element type stored in cached rows.
pub type Qfloat = f32;

/// Link value marking a row that is not currently linked into the LRU list.
const NONE: usize = usize::MAX;

/// Links of one entry in the intrusive, circular, doubly linked LRU list.
///
/// Entries are addressed by index into the cache's `nodes` vector: indices `0..l`
/// belong to the rows and index `l` is the sentinel head.
#[derive(Debug, Clone, Copy)]
struct LruNode {
    /// Index of the previous entry, or [`NONE`] when the row is unlinked.
    prev: usize,
    /// Index of the next entry, or [`NONE`] when the row is unlinked.
    next: usize,
}

/// Memory-bounded cache of `Qfloat` rows with least-recently-used eviction.
///
/// Rows are addressed by index `0..l` and may be cached at any length. The memory
/// budget is tracked in `Qfloat` elements; when a row must grow beyond the remaining
/// budget, the least recently used rows are evicted to make room.
#[derive(Debug)]
pub struct Cache {
    /// Number of addressable rows; also the index of the LRU sentinel in `nodes`.
    l: usize,
    /// Remaining budget, in `Qfloat` elements.
    size: usize,
    /// Storage for each row; `None` when the row holds no storage.
    data: Vec<Option<Vec<Qfloat>>>,
    /// Number of cached elements per row (the length of the matching `data` entry,
    /// 0 when it is `None`).
    len: Vec<usize>,
    /// LRU links for rows `0..l`, followed by the sentinel head at index `l`.
    /// The list runs from least recently used (`nodes[l].next`) to most recently
    /// used (`nodes[l].prev`).
    nodes: Vec<LruNode>,
}

impl Cache {
    /// Creates a cache for `l` rows with a memory budget of roughly `size_bytes` bytes.
    ///
    /// The byte budget is converted to `Qfloat` elements, and the cost of the per-row
    /// LRU nodes is charged against it. The usable budget never drops below `2 * l`
    /// elements, so two rows of length `l` always fit.
    pub fn new(l: usize, size_bytes: usize) -> Self {
        let elem_bytes = std::mem::size_of::<Qfloat>();
        let header_size = l * std::mem::size_of::<LruNode>() / elem_bytes;
        // `max` guarantees the value is at least `header_size`, so the subtraction
        // cannot underflow.
        let size = (size_bytes / elem_bytes).max(2 * l + header_size) - header_size;

        let mut nodes = vec![LruNode { prev: NONE, next: NONE }; l];
        // Sentinel head: an empty circular list points at itself.
        nodes.push(LruNode { prev: l, next: l });

        Cache {
            l,
            size,
            data: vec![None; l],
            len: vec![0; l],
            nodes,
        }
    }

    /// Unlinks row `i` from the LRU list and marks it as unlinked.
    fn lru_delete(&mut self, i: usize) {
        let LruNode { prev, next } = self.nodes[i];
        self.nodes[prev].next = next;
        self.nodes[next].prev = prev;
        self.nodes[i] = LruNode { prev: NONE, next: NONE };
    }

    /// Links row `i` just before the sentinel, i.e. at the most-recently-used end.
    fn lru_insert(&mut self, i: usize) {
        let head = self.l;
        let tail = self.nodes[head].prev;
        self.nodes[i] = LruNode { prev: tail, next: head };
        self.nodes[tail].next = i;
        self.nodes[head].prev = i;
    }

    /// Returns whether row `i` is currently linked into the LRU list.
    fn in_lru(&self, i: usize) -> bool {
        self.nodes[i].prev != NONE
    }

    /// Unlinks row `i` and frees its storage, returning its elements to the budget.
    fn evict(&mut self, i: usize) {
        self.lru_delete(i);
        self.size += self.len[i];
        self.data[i] = None;
        self.len[i] = 0;
    }

    /// Returns row `index`, grown to at least `request_len` elements, together with
    /// the number of leading elements that were already cached (`start`).
    ///
    /// Elements in `start..request_len` are newly allocated and zero-filled; the
    /// caller is expected to compute them. When the row already holds `request_len`
    /// or more elements, nothing is reallocated. In that case `start` is the current
    /// cached length, which may exceed `request_len`, and the slice covers the whole
    /// cached row. A zero-length request on a row without storage yields an empty
    /// slice. In every case the row becomes the most recently used one.
    ///
    /// # Panics
    ///
    /// Panics in either of these cases:
    /// - `index >= l`;
    /// - growing the row needs more elements than the budget can provide, even after
    ///   evicting every other row.
    pub fn get_data(&mut self, index: usize, request_len: usize) -> (&mut [Qfloat], usize) {
        assert!(index < self.l, "Cache::get_data: row index out of range");

        if self.in_lru(index) {
            self.lru_delete(index);
        }

        let old_len = self.len[index];
        let more = request_len.saturating_sub(old_len);

        if more > 0 {
            let head = self.l;
            // Evict least-recently-used rows until the extra elements fit. `index`
            // was unlinked above, so it can never evict itself.
            while self.size < more {
                let victim = self.nodes[head].next;
                if victim == head {
                    // Nothing is left to evict and the request still does not fit.
                    // Relink the row so the cache stays consistent, then report it
                    // instead of letting the budget counter underflow.
                    self.lru_insert(index);
                    panic!(
                        "Cache::get_data: a row of {} elements exceeds the cache capacity",
                        request_len
                    );
                }
                self.evict(victim);
            }

            let row = self.data[index].get_or_insert_with(Vec::new);
            // Reserve exactly what is needed. Amortized growth would let the real
            // allocation exceed what the budget accounts for.
            row.reserve_exact(more);
            row.resize(request_len, 0.0);
            self.size -= more;
            self.len[index] = request_len;
        }

        self.lru_insert(index);

        // A row with no storage (e.g. a zero-length request on an evicted row) gets
        // an empty vector, so this never panics.
        let row = self.data[index].get_or_insert_with(Vec::new);
        (row.as_mut_slice(), old_len)
    }

    /// Swaps the identities of rows `i` and `j` while keeping cached data consistent.
    ///
    /// The two rows' storage trades places. Then, in every cached row (including
    /// the two swapped ones), the elements at positions `min(i, j)` and `max(i, j)`
    /// are exchanged. A cached row long enough to contain the smaller position but
    /// not the larger one cannot be fixed up, so it is evicted.
    ///
    /// # Panics
    ///
    /// Panics if `i != j` and either index is `>= l`.
    pub fn swap_index(&mut self, i: usize, j: usize) {
        if i == j {
            return;
        }
        // Checked up front: index `l` is the LRU sentinel and must never be unlinked.
        assert!(
            i < self.l && j < self.l,
            "Cache::swap_index: row index out of range"
        );

        let i_in = self.in_lru(i);
        let j_in = self.in_lru(j);
        if i_in {
            self.lru_delete(i);
        }
        if j_in {
            self.lru_delete(j);
        }

        self.data.swap(i, j);
        self.len.swap(i, j);

        // Row i's storage now lives at slot j (and vice versa), so relink accordingly.
        if i_in {
            self.lru_insert(j);
        }
        if j_in {
            self.lru_insert(i);
        }

        let (lo, hi) = (i.min(j), i.max(j));
        let head = self.l;
        let mut h = self.nodes[head].next;
        while h != head {
            // Read the successor first: `h` may be unlinked below.
            let next = self.nodes[h].next;
            let row_len = self.len[h];
            if row_len > hi {
                if let Some(row) = self.data[h].as_mut() {
                    row.swap(lo, hi);
                }
            } else if row_len > lo {
                // The row covers `lo` but not `hi`, so the swap cannot be applied.
                self.evict(h);
            }
            h = next;
        }
    }
}
201
#[cfg(test)]
mod tests {
    use super::*;

    /// Copies `values` into the leading elements of `row`.
    fn fill(row: &mut [Qfloat], values: &[Qfloat]) {
        row[..values.len()].copy_from_slice(values);
    }

    #[test]
    fn basic_get_and_fill() {
        let mut cache = Cache::new(3, 100);

        let (row, start) = cache.get_data(0, 3);
        assert_eq!(start, 0);
        assert_eq!(row.len(), 3);
        fill(row, &[1.0, 2.0, 3.0]);

        // Asking again for the same length is a full hit: every element is cached.
        let (row, start) = cache.get_data(0, 3);
        assert_eq!(start, 3);
        assert_eq!(row[..3], [1.0, 2.0, 3.0]);
    }

    #[test]
    fn extend_cached_row() {
        let mut cache = Cache::new(3, 1000);

        let (row, start) = cache.get_data(0, 2);
        assert_eq!(start, 0);
        fill(row, &[10.0, 20.0]);

        // Growing the row keeps the cached prefix and reports where new data begins.
        let (row, start) = cache.get_data(0, 3);
        assert_eq!(start, 2);
        assert_eq!(row[..2], [10.0, 20.0]);
        row[2] = 30.0;
    }

    #[test]
    fn lru_eviction() {
        let l = 10;
        let bytes = (2 * l + l * 3) * std::mem::size_of::<Qfloat>();
        let mut cache = Cache::new(l, bytes);

        // Cache rows 0, 1 and 2 in turn, tagging each one's first element.
        for (index, tag) in [(0, 1.0), (1, 3.0), (2, 5.0)] {
            let (row, start) = cache.get_data(index, l);
            assert_eq!(start, 0);
            row[0] = tag;
        }

        // Row 0 was the least recently used, so it has been evicted.
        let (_, start) = cache.get_data(0, l);
        assert_eq!(start, 0);
    }

    #[test]
    fn lru_order_respects_access() {
        let l = 5;
        let row_len = l;
        let header = l * std::mem::size_of::<LruNode>() / std::mem::size_of::<Qfloat>();
        // After the LRU bookkeeping is charged, exactly three full rows fit.
        let budget = 3 * row_len + header;
        let mut cache = Cache::new(l, budget * std::mem::size_of::<Qfloat>());

        for (index, tag) in [(0, 10.0), (1, 20.0), (2, 30.0)] {
            let (row, _) = cache.get_data(index, row_len);
            row[0] = tag;
        }

        // Touching row 0 leaves row 1 as the least recently used.
        let (row, start) = cache.get_data(0, row_len);
        assert_eq!(start, row_len);
        assert_eq!(row[0], 10.0);

        // A fourth row forces exactly one eviction.
        let (row, start) = cache.get_data(3, row_len);
        assert_eq!(start, 0);
        row[0] = 40.0;

        assert!(cache.data[1].is_none());
        assert!(cache.data[0].is_some());
        assert!(cache.data[2].is_some());
    }

    #[test]
    fn swap_index_row_swap() {
        let mut cache = Cache::new(3, 1000);
        let (row, _) = cache.get_data(0, 3);
        fill(row, &[10.0, 20.0, 30.0]);

        cache.swap_index(0, 2);

        // Row 0 now lives in slot 2, with its own elements 0 and 2 exchanged.
        let (row, start) = cache.get_data(2, 3);
        assert_eq!(start, 3);
        assert_eq!(row[..3], [30.0, 20.0, 10.0]);

        // Slot 0 received the uncached row 2.
        let (_, start) = cache.get_data(0, 3);
        assert_eq!(start, 0);
    }

    #[test]
    fn swap_index_swaps_columns_in_other_rows() {
        let mut cache = Cache::new(4, 10000);

        let (row, _) = cache.get_data(0, 4);
        fill(row, &[1.0, 2.0, 3.0, 4.0]);
        let (row, _) = cache.get_data(1, 4);
        fill(row, &[10.0, 20.0, 30.0, 40.0]);

        cache.swap_index(1, 3);

        // Every cached row has its elements 1 and 3 exchanged.
        let (row, start) = cache.get_data(0, 4);
        assert_eq!(start, 4);
        assert_eq!(row[..4], [1.0, 4.0, 3.0, 2.0]);

        let (row, start) = cache.get_data(3, 4);
        assert_eq!(start, 4);
        assert_eq!(row[..4], [10.0, 40.0, 30.0, 20.0]);
    }

    #[test]
    fn swap_index_evicts_partial_rows() {
        let mut cache = Cache::new(4, 10000);

        let (row, _) = cache.get_data(0, 4);
        fill(row, &[1.0, 2.0, 3.0, 4.0]);
        let (row, _) = cache.get_data(1, 2);
        fill(row, &[10.0, 20.0]);

        cache.swap_index(1, 3);

        // The full-length row still gets its elements 1 and 3 exchanged.
        let (row, start) = cache.get_data(0, 4);
        assert_eq!(start, 4);
        assert_eq!(row[1], 4.0);
        assert_eq!(row[3], 2.0);

        // Slot 1 holds nothing cached after the swap.
        let (_, start) = cache.get_data(1, 2);
        assert_eq!(start, 0);
    }
}