Skip to main content

libsvm_rs/
cache.rs

//! LRU kernel cache matching the original LIBSVM.
//!
//! The cache stores (prefixes of) rows of the kernel matrix Q as `Qfloat`
//! (`f32`) buffers. When a request would exceed the memory budget, the
//! least-recently-used rows are evicted until the new data fits.
//!
//! The C++ original uses a circular doubly-linked list built from raw pointers.
//! This Rust version uses an index-based circular doubly-linked list instead,
//! keeping O(1) LRU operations without any unsafe code while matching the
//! original semantics.

/// Element type for cached kernel matrix rows. Matches LIBSVM's `Qfloat = float`.
pub type Qfloat = f32;

/// Sentinel index representing "no link" in the LRU list.
const NONE: usize = usize::MAX;

/// Per-row LRU node: stores prev/next indices for a circular doubly-linked list.
///
/// A node whose `prev` is [`NONE`] is not currently linked into the list.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
struct LruNode {
    prev: usize,
    next: usize,
}

/// LRU cache for kernel matrix rows.
///
/// Each of the `l` data items may have a cached row prefix of length up to `l`.
/// The cache tracks the remaining budget (in `Qfloat` units) and evicts
/// least-recently-used rows when a request would exceed it.
///
/// The LRU list is a circular doubly-linked list using array indices. Index
/// `l` is the sentinel head node, so `nodes[l].next` is the least recently used
/// row and `nodes[l].prev` the most recently used one. Linking, unlinking and
/// evicting a row are all O(1).
#[derive(Debug)]
pub struct Cache {
    /// Number of data items (rows in the Q matrix).
    l: usize,
    /// Remaining budget in Qfloat units.
    size: usize,
    /// Per-row cached data. `None` means not cached.
    data: Vec<Option<Vec<Qfloat>>>,
    /// Per-row cached length (how many elements are valid).
    len: Vec<usize>,
    /// LRU doubly-linked list nodes. Index `l` is the sentinel head.
    /// Nodes with `prev == NONE` are not in the LRU list.
    nodes: Vec<LruNode>,
}

impl Cache {
    /// Create a new cache for `l` data items with `size_bytes` of memory.
    ///
    /// The byte budget is converted to `Qfloat` units and reduced by a per-row
    /// bookkeeping overhead. As in LIBSVM, the resulting budget is never
    /// smaller than `2 * l` elements, so at least two full rows always fit.
    pub fn new(l: usize, size_bytes: usize) -> Self {
        let qfloat_bytes = std::mem::size_of::<Qfloat>();
        // Header overhead (per-row metadata), expressed in Qfloat units.
        let header_size = l * std::mem::size_of::<LruNode>() / qfloat_bytes;
        // Raising the budget to `2 * l + header` before subtracting the header
        // equals LIBSVM's `max(size - header, 2 * l)` but cannot underflow.
        let size = (size_bytes / qfloat_bytes)
            .max(2 * l + header_size)
            .saturating_sub(header_size);

        // Nodes 0..l belong to data rows and start unlinked; node `l` is the
        // sentinel head, which points to itself while the list is empty.
        // Reserve l + 1 slots so pushing the sentinel does not reallocate.
        let mut nodes = Vec::with_capacity(l + 1);
        nodes.resize(
            l,
            LruNode {
                prev: NONE,
                next: NONE,
            },
        );
        nodes.push(LruNode { prev: l, next: l });

        Cache {
            l,
            size,
            data: vec![None; l],
            len: vec![0; l],
            nodes,
        }
    }

    /// Remove node `i` from the LRU list. O(1).
    ///
    /// `i` must currently be linked (see [`Cache::in_lru`]).
    #[inline]
    fn lru_delete(&mut self, i: usize) {
        let prev = self.nodes[i].prev;
        let next = self.nodes[i].next;
        self.nodes[prev].next = next;
        self.nodes[next].prev = prev;
        self.nodes[i].prev = NONE;
        self.nodes[i].next = NONE;
    }

    /// Insert node `i` at the back of the LRU list (most recently used). O(1).
    #[inline]
    fn lru_insert(&mut self, i: usize) {
        let head = self.l; // sentinel
        let tail = self.nodes[head].prev;
        self.nodes[i].next = head;
        self.nodes[i].prev = tail;
        self.nodes[tail].next = i;
        self.nodes[head].prev = i;
    }

    /// Check if node `i` is in the LRU list.
    #[inline]
    fn in_lru(&self, i: usize) -> bool {
        self.nodes[i].prev != NONE
    }

    /// Evict row `i`: unlink it, free its buffer, and return its length to the
    /// budget. `i` must currently be linked into the LRU list.
    fn evict(&mut self, i: usize) {
        self.lru_delete(i);
        self.size += self.len[i];
        self.data[i] = None;
        self.len[i] = 0;
    }

    /// Request data for row `index` of length `request_len`.
    ///
    /// Returns `(data, start)` where `data` is the cached row slice and
    /// `start` is the position from which data needs to be filled. `start`
    /// never exceeds `request_len`; `start == request_len` means the
    /// requested prefix was already cached. The slice holds at least
    /// `request_len` elements (more if a longer prefix was cached earlier).
    ///
    /// The caller must fill `data[start..request_len]` with kernel values.
    ///
    /// # Panics
    ///
    /// Panics if `index >= l`.
    pub fn get_data(&mut self, index: usize, request_len: usize) -> (&mut [Qfloat], usize) {
        assert!(index < self.l);

        // Remove from LRU if present (will re-insert at tail)
        if self.in_lru(index) {
            self.lru_delete(index);
        }

        let old_len = self.len[index];
        let more = request_len.saturating_sub(old_len);

        if more > 0 {
            // Evict LRU entries until the extension fits. `index` is unlinked
            // at this point, so it can never be picked as a victim.
            let head = self.l;
            while self.size < more {
                let victim = self.nodes[head].next;
                if victim == head {
                    break; // list empty
                }
                self.evict(victim);
            }

            let entry = self.data[index].get_or_insert_with(Vec::new);
            // Grow to exactly `request_len`. Amortized doubling in `resize`
            // would allocate memory the `size` budget does not account for.
            entry.reserve_exact(more);
            entry.resize(request_len, 0.0);
            // The budget can only fall short here if `request_len` exceeds the
            // whole cache; saturate rather than wrap around.
            self.size = self.size.saturating_sub(more);
            self.len[index] = request_len;
        }

        // Insert at back of LRU (most recently used)
        self.lru_insert(index);

        // Like LIBSVM, never report a fill start past the requested length, so
        // `data[start..request_len]` is always a valid (possibly empty) range.
        let start = old_len.min(request_len);
        let data = self.data[index].get_or_insert_with(Vec::new);
        (data.as_mut_slice(), start)
    }

    /// Swap indices `i` and `j` in the cache.
    ///
    /// Used by the solver when rearranging the working set. Both the rows and
    /// the columns `i`/`j` inside every cached row are exchanged. A row that
    /// covers the lower index but not the higher one is evicted.
    ///
    /// # Panics
    ///
    /// Panics if `i != j` and either index is `>= l`.
    pub fn swap_index(&mut self, i: usize, j: usize) {
        if i == j {
            return;
        }
        // Reject out-of-range indices before touching the list: index `l` is
        // the sentinel, and unlinking it would corrupt the LRU list.
        assert!(i < self.l && j < self.l);

        let i_in = self.in_lru(i);
        let j_in = self.in_lru(j);

        if i_in {
            self.lru_delete(i);
        }
        if j_in {
            self.lru_delete(j);
        }

        // Swap data and lengths
        self.data.swap(i, j);
        self.len.swap(i, j);

        // Re-link under the swapped identities in LIBSVM's order: slot `i`
        // (now holding j's old row) first, then slot `j`.
        if j_in {
            self.lru_insert(i);
        }
        if i_in {
            self.lru_insert(j);
        }

        // Column swap over all cached rows (matching the C++ "give up"
        // behavior for rows that cover `lo` but not `hi`).
        let (lo, hi) = if i < j { (i, j) } else { (j, i) };
        let head = self.l;
        let mut h = self.nodes[head].next;
        while h != head {
            // Read the successor first: `evict` unlinks `h`.
            let next = self.nodes[h].next;
            let row_len = self.len[h];
            if row_len > hi {
                // Row covers both positions — swap the column entries.
                if let Some(row) = self.data[h].as_mut() {
                    row.swap(lo, hi);
                }
            } else if row_len > lo {
                // Row covers lo but not hi — evict ("give up").
                self.evict(h);
            }
            h = next;
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Size of one cache element in bytes.
    const QF: usize = std::mem::size_of::<Qfloat>();

    #[test]
    fn basic_get_and_fill() {
        let mut cache = Cache::new(3, 100);

        let (row, start) = cache.get_data(0, 3);
        assert_eq!(start, 0);
        assert_eq!(row.len(), 3);
        row.copy_from_slice(&[1.0, 2.0, 3.0]);

        // Asking again for the same length needs no filling at all.
        let (row, start) = cache.get_data(0, 3);
        assert_eq!(start, 3);
        assert_eq!(row[..], [1.0, 2.0, 3.0]);
    }

    #[test]
    fn extend_cached_row() {
        let mut cache = Cache::new(3, 1000);

        let (row, start) = cache.get_data(0, 2);
        assert_eq!(start, 0);
        row.copy_from_slice(&[10.0, 20.0]);

        // Growing the row keeps the cached prefix and asks only for the tail.
        let (row, start) = cache.get_data(0, 3);
        assert_eq!(start, 2);
        assert_eq!(row[..2], [10.0, 20.0]);
        row[2] = 30.0;
    }

    #[test]
    fn lru_eviction() {
        let l = 10;
        let mut cache = Cache::new(l, (2 * l + 3 * l) * QF);

        for (row_index, marker) in [(0, 1.0), (1, 3.0), (2, 5.0)] {
            let (row, start) = cache.get_data(row_index, l);
            assert_eq!(start, 0);
            row[0] = marker;
        }

        // The budget holds two rows, so fetching row 2 pushed out row 0 (LRU).
        let (_, start) = cache.get_data(0, l);
        assert_eq!(start, 0);
    }

    #[test]
    fn lru_order_respects_access() {
        // l = 5 rows of length 5; the byte budget leaves room for exactly three
        // rows (15 Qfloats) once the per-row header overhead is removed.
        let l = 5;
        let header = l * std::mem::size_of::<LruNode>() / QF;
        let mut cache = Cache::new(l, (3 * l + header) * QF);

        // Fill rows 0, 1, 2 → LRU order: 0 (oldest), 1, 2 (newest).
        for (row_index, marker) in [(0, 10.0), (1, 20.0), (2, 30.0)] {
            let (row, _) = cache.get_data(row_index, l);
            row[0] = marker;
        }

        // Touching row 0 makes it most recent → order: 1, 2, 0.
        let (row, start) = cache.get_data(0, l);
        assert_eq!(start, l);
        assert_eq!(row[0], 10.0);

        // Row 3 needs room; row 1 is now least recent and gets evicted.
        let (row, start) = cache.get_data(3, l);
        assert_eq!(start, 0);
        row[0] = 40.0;

        assert!(cache.data[1].is_none());
        assert!(cache.data[0].is_some());
        assert!(cache.data[2].is_some());
    }

    #[test]
    fn swap_index_row_swap() {
        // Full-length rows, so the column pass never has to evict anything.
        let mut cache = Cache::new(3, 1000);
        let (row, _) = cache.get_data(0, 3);
        row.copy_from_slice(&[10.0, 20.0, 30.0]);

        cache.swap_index(0, 2);

        // The row now lives at index 2, and its columns 0 and 2 were exchanged.
        let (row, start) = cache.get_data(2, 3);
        assert_eq!(start, 3);
        assert_eq!(row[..], [30.0, 20.0, 10.0]);

        // Index 0 inherited row 2's empty slot and must be filled from scratch.
        let (_, start) = cache.get_data(0, 3);
        assert_eq!(start, 0);
    }

    #[test]
    fn swap_index_swaps_columns_in_other_rows() {
        // Four items and a generous budget: nothing is evicted for space.
        let mut cache = Cache::new(4, 10000);

        let (row, _) = cache.get_data(0, 4);
        row.copy_from_slice(&[1.0, 2.0, 3.0, 4.0]);
        let (row, _) = cache.get_data(1, 4);
        row.copy_from_slice(&[10.0, 20.0, 30.0, 40.0]);

        cache.swap_index(1, 3);

        // Row 0 stays in place but sees its columns 1 and 3 exchanged.
        let (row, start) = cache.get_data(0, 4);
        assert_eq!(start, 4);
        assert_eq!(row[..], [1.0, 4.0, 3.0, 2.0]);

        // Former row 1 now sits at index 3, with the same column exchange.
        let (row, start) = cache.get_data(3, 4);
        assert_eq!(start, 4);
        assert_eq!(row[..], [10.0, 40.0, 30.0, 20.0]);
    }

    #[test]
    fn swap_index_evicts_partial_rows() {
        let mut cache = Cache::new(4, 10000);

        let (row, _) = cache.get_data(0, 4);
        row.copy_from_slice(&[1.0, 2.0, 3.0, 4.0]);
        // Row 1 only covers columns 0 and 1.
        let (row, _) = cache.get_data(1, 2);
        row.copy_from_slice(&[10.0, 20.0]);

        // The moved row covers column 1 but not column 3, so the column pass
        // gives up on it and evicts it.
        cache.swap_index(1, 3);

        // Row 0 has full coverage and gets its columns exchanged.
        let (row, start) = cache.get_data(0, 4);
        assert_eq!(start, 4);
        assert_eq!((row[1], row[3]), (4.0, 2.0));

        // Index 1 received row 3's empty slot, so it needs a fresh fill.
        let (_, start) = cache.get_data(1, 2);
        assert_eq!(start, 0);
    }
}