Skip to main content

libsvm_rs/
cache.rs

1//! LRU kernel cache matching the original LIBSVM.
2//!
3//! The cache stores rows of the kernel matrix Q as `Qfloat` (`f32`) slices.
4//! When memory is exhausted, the least-recently-used row is evicted.
5//!
6//! The C++ original uses a doubly-linked circular list with raw pointers.
7//! This Rust version uses an index-based circular doubly-linked list for
8//! O(1) LRU operations, avoiding unsafe code while matching the semantics.
9
/// Element type for cached kernel matrix rows. Matches LIBSVM's `Qfloat = float`.
pub type Qfloat = f32;

/// Sentinel index representing "no link" in the LRU list.
const NONE: usize = usize::MAX;

/// Per-row LRU node: stores prev/next indices for a circular doubly-linked list.
///
/// A node whose `prev` is [`NONE`] is not currently linked into the list.
struct LruNode {
    prev: usize,
    next: usize,
}

/// LRU cache for kernel matrix rows.
///
/// Each of the `l` data items may have a cached row of length up to `l`.
/// The cache tracks how much memory (in `Qfloat` units) is still available and
/// evicts least-recently-used rows when a request would exceed the budget.
///
/// The LRU list is a circular doubly-linked list using array indices.
/// Index `l` is the sentinel head node: `nodes[l].next` is the least recently
/// used row and `nodes[l].prev` the most recently used one. Linking, unlinking
/// and evicting a single row are O(1).
pub struct Cache {
    /// Number of data items (rows in the Q matrix).
    l: usize,
    /// Remaining budget in `Qfloat` units (shrinks as rows grow, grows on eviction).
    size: usize,
    /// Per-row cached data. `None` means not cached.
    data: Vec<Option<Vec<Qfloat>>>,
    /// Per-row cached length (how many elements are valid).
    len: Vec<usize>,
    /// LRU doubly-linked list nodes. Index `l` is the sentinel head.
    /// Nodes with `prev == NONE` are not in the LRU list.
    nodes: Vec<LruNode>,
}

impl Cache {
    /// Create a new cache for `l` data items with roughly `size_bytes` of memory.
    ///
    /// The byte budget is converted to `Qfloat` units, reduced by an estimate of
    /// the per-row bookkeeping overhead, and clamped so that at least two full
    /// rows (`2 * l` elements) always fit.
    pub fn new(l: usize, size_bytes: usize) -> Self {
        let qfloat_bytes = std::mem::size_of::<Qfloat>();
        let requested = size_bytes / qfloat_bytes;
        // Estimated bookkeeping overhead for all rows, in Qfloat units.
        let header_size = l * std::mem::size_of::<LruNode>() / qfloat_bytes;
        // size = max(requested - header, 2 * l), treating a negative difference
        // as zero: the cache must be large enough for at least two columns.
        let size = requested
            .max(2 * l + header_size)
            .saturating_sub(header_size);

        // Nodes 0..l belong to data rows and start unlinked; node `l` is the
        // sentinel head, which points to itself while the list is empty.
        let mut nodes = Vec::with_capacity(l + 1);
        nodes.extend((0..l).map(|_| LruNode { prev: NONE, next: NONE }));
        nodes.push(LruNode { prev: l, next: l });

        Cache {
            l,
            size,
            data: (0..l).map(|_| None).collect(),
            len: vec![0; l],
            nodes,
        }
    }

    /// Remove node `i` from the LRU list and mark it unlinked. O(1).
    ///
    /// `i` must currently be linked into the list.
    #[inline]
    fn lru_delete(&mut self, i: usize) {
        let prev = self.nodes[i].prev;
        let next = self.nodes[i].next;
        self.nodes[prev].next = next;
        self.nodes[next].prev = prev;
        self.nodes[i].prev = NONE;
        self.nodes[i].next = NONE;
    }

    /// Insert node `i` at the back of the LRU list (most recently used). O(1).
    ///
    /// `i` must not currently be linked into the list.
    #[inline]
    fn lru_insert(&mut self, i: usize) {
        let head = self.l; // sentinel
        let tail = self.nodes[head].prev;
        self.nodes[i].next = head;
        self.nodes[i].prev = tail;
        self.nodes[tail].next = i;
        self.nodes[head].prev = i;
    }

    /// Check if node `i` is in the LRU list.
    #[inline]
    fn in_lru(&self, i: usize) -> bool {
        self.nodes[i].prev != NONE
    }

    /// Drop row `victim` from the cache: unlink it from the LRU list, return
    /// its elements to the budget, and free its storage.
    ///
    /// `victim` must currently be linked into the LRU list.
    fn evict(&mut self, victim: usize) {
        self.lru_delete(victim);
        self.size += self.len[victim];
        self.data[victim] = None;
        self.len[victim] = 0;
    }

    /// Request data for row `index` of length `request_len`.
    ///
    /// Marks the row as most recently used. If the row is shorter than
    /// `request_len`, it is grown (new elements zero-filled), and
    /// least-recently-used rows are evicted first when the budget is
    /// insufficient.
    ///
    /// Returns `(data, start)`, where `data` is the cached row and `start` is the
    /// first position the caller must fill. `data[start..request_len]` needs
    /// kernel values. As in LIBSVM's `Cache::get_data`, `start` is:
    ///
    /// - the previously cached length when the row had to be extended;
    /// - `request_len` otherwise.
    ///
    /// So `start <= request_len` always holds, and the fill range is empty
    /// when the row was already cached. `data` may be longer than
    /// `request_len` if an earlier request cached more.
    ///
    /// # Panics
    ///
    /// Panics if `index >= l`. In debug builds, also panics if `request_len`
    /// exceeds the total cache budget (it cannot fit even after evicting every
    /// other row).
    pub fn get_data(&mut self, index: usize, request_len: usize) -> (&mut [Qfloat], usize) {
        assert!(index < self.l);

        // Unlink while we work; the row is re-linked as most recently used below.
        if self.in_lru(index) {
            self.lru_delete(index);
        }

        let old_len = self.len[index];
        let more = request_len.saturating_sub(old_len);

        if more > 0 {
            // Evict least-recently-used rows until the extension fits.
            let head = self.l;
            while self.size < more {
                let victim = self.nodes[head].next;
                if victim == head {
                    break; // nothing left to evict
                }
                self.evict(victim);
            }

            debug_assert!(
                self.size >= more,
                "cache request of {} elements exceeds the cache budget",
                request_len
            );
            // Saturate so an oversized request in a release build leaves the
            // budget at zero (later requests keep evicting) instead of wrapping
            // to a huge value that would effectively disable eviction.
            self.size = self.size.saturating_sub(more);
            self.data[index]
                .get_or_insert_with(Vec::new)
                .resize(request_len, 0.0);
            self.len[index] = request_len;
        }

        // Insert at back of LRU (most recently used).
        self.lru_insert(index);

        // Never report a start beyond the requested length, so that
        // `data[start..request_len]` is always a valid (possibly empty) range.
        let start = old_len.min(request_len);
        // A zero-length request on an uncached row allocates nothing, but still
        // yields an (empty) row instead of panicking on a missing entry.
        let row = self.data[index].get_or_insert_with(Vec::new);
        (row.as_mut_slice(), start)
    }

    /// Swap indices `i` and `j` in the cache.
    ///
    /// Exchanges the cached rows stored at `i` and `j`, then swaps columns `i`
    /// and `j` in every cached row. A cached row that covers the smaller index
    /// but not the larger one cannot be fixed up. It is evicted instead
    /// (LIBSVM's "give up" behavior). Used by the solver when rearranging the
    /// working set.
    ///
    /// # Panics
    ///
    /// Panics if `i != j` and either index is `>= l`.
    pub fn swap_index(&mut self, i: usize, j: usize) {
        if i == j {
            return;
        }
        assert!(i < self.l && j < self.l);

        let i_linked = self.in_lru(i);
        let j_linked = self.in_lru(j);

        if i_linked {
            self.lru_delete(i);
        }
        if j_linked {
            self.lru_delete(j);
        }

        // Swap data and lengths.
        self.data.swap(i, j);
        self.len.swap(i, j);

        // Re-link with swapped identities in LIBSVM's order: position `i` first
        // (it now holds j's row), then position `j`, which ends up most recent.
        if j_linked {
            self.lru_insert(i);
        }
        if i_linked {
            self.lru_insert(j);
        }

        let (lo, hi) = (i.min(j), i.max(j));
        let head = self.l;
        let mut h = self.nodes[head].next;
        while h != head {
            // Read the successor first: evicting `h` unlinks it.
            let next = self.nodes[h].next;
            let covered = self.len[h];
            if covered > hi {
                // Row covers both positions: swap the column entries.
                if let Some(row) = self.data[h].as_mut() {
                    row.swap(lo, hi);
                }
            } else if covered > lo {
                // Row covers lo but not hi: evict ("give up").
                self.evict(h);
            }
            h = next;
        }
    }
}
201
#[cfg(test)]
mod tests {
    use super::*;

    /// Per-row bookkeeping overhead that `Cache::new` subtracts from the budget,
    /// in `Qfloat` units. Platform-dependent because `LruNode` holds `usize`s.
    fn header_qfloats(l: usize) -> usize {
        l * std::mem::size_of::<LruNode>() / std::mem::size_of::<Qfloat>()
    }

    /// Bytes to pass to `Cache::new` so that exactly `rows` full rows of length
    /// `l` fit in the budget. `rows` should be at least 2, the enforced minimum.
    fn bytes_for_rows(l: usize, rows: usize) -> usize {
        (rows * l + header_qfloats(l)) * std::mem::size_of::<Qfloat>()
    }

    #[test]
    fn basic_get_and_fill() {
        let mut cache = Cache::new(3, 100);
        let (data, start) = cache.get_data(0, 3);
        assert_eq!(start, 0);
        assert_eq!(data.len(), 3);
        data[0] = 1.0;
        data[1] = 2.0;
        data[2] = 3.0;

        // Second access should return start=3 (fully cached)
        let (data, start) = cache.get_data(0, 3);
        assert_eq!(start, 3);
        assert_eq!(data[0], 1.0);
        assert_eq!(data[1], 2.0);
        assert_eq!(data[2], 3.0);
    }

    #[test]
    fn extend_cached_row() {
        let mut cache = Cache::new(3, 1000);
        let (data, start) = cache.get_data(0, 2);
        assert_eq!(start, 0);
        data[0] = 10.0;
        data[1] = 20.0;

        let (data, start) = cache.get_data(0, 3);
        assert_eq!(start, 2);
        assert_eq!(data[0], 10.0);
        assert_eq!(data[1], 20.0);
        data[2] = 30.0;
    }

    #[test]
    fn lru_eviction() {
        // Budget of exactly two full rows, computed from the real header size
        // so the test behaves the same on 32- and 64-bit targets.
        let l = 10;
        let mut cache = Cache::new(l, bytes_for_rows(l, 2));

        let (data, start) = cache.get_data(0, l);
        assert_eq!(start, 0);
        data[0] = 1.0;

        let (data, start) = cache.get_data(1, l);
        assert_eq!(start, 0);
        data[0] = 3.0;

        // Row 0 is least recently used, so it is evicted to make room.
        let (data, start) = cache.get_data(2, l);
        assert_eq!(start, 0);
        data[0] = 5.0;
        assert!(cache.data[0].is_none());
        assert!(cache.data[1].is_some());

        // Row 0 evicted: re-requesting it needs a full refill.
        let (_, start) = cache.get_data(0, l);
        assert_eq!(start, 0);
    }

    #[test]
    fn lru_order_respects_access() {
        // Verify that re-accessing a row moves it to MRU position.
        // l=5, row_len=5: min cache = 2*5=10 Qfloats. Budget = 3 rows = 15.
        let l = 5;
        let row_len = l;
        let mut cache = Cache::new(l, bytes_for_rows(l, 3));

        // Fill rows 0, 1, 2. LRU order: 0(oldest), 1, 2(newest)
        let (d, _) = cache.get_data(0, row_len);
        d[0] = 10.0;
        let (d, _) = cache.get_data(1, row_len);
        d[0] = 20.0;
        let (d, _) = cache.get_data(2, row_len);
        d[0] = 30.0;

        // Touch row 0 → LRU order: 1(oldest), 2, 0(newest)
        let (d, start) = cache.get_data(0, row_len);
        assert_eq!(start, row_len); // already cached, no fill needed
        assert_eq!(d[0], 10.0);

        // Insert row 3 → must evict. Row 1 is LRU, so it gets evicted.
        // LRU order: 2, 0, 3(newest)
        let (d, start) = cache.get_data(3, row_len);
        assert_eq!(start, 0); // new row, needs fill
        d[0] = 40.0;

        // Row 1 was evicted
        assert!(cache.data[1].is_none());
        // Row 0 and 2 are still cached
        assert!(cache.data[0].is_some());
        assert!(cache.data[2].is_some());
    }

    #[test]
    fn swap_index_row_swap() {
        // Test that rows are swapped correctly.
        // Use full-length rows to avoid column-swap eviction.
        let mut cache = Cache::new(3, 1000);
        let (data, _) = cache.get_data(0, 3);
        data[0] = 10.0;
        data[1] = 20.0;
        data[2] = 30.0;

        cache.swap_index(0, 2);

        // Row data moved from position 0 to position 2; columns 0,2 also swapped
        let (data, start) = cache.get_data(2, 3);
        assert_eq!(start, 3); // fully cached
        assert_eq!(data[0], 30.0); // was data[2], swapped to col 0
        assert_eq!(data[1], 20.0);
        assert_eq!(data[2], 10.0); // was data[0], swapped to col 2

        // Position 0 was empty (row 2 had no data), needs fill
        let (_, start) = cache.get_data(0, 3);
        assert_eq!(start, 0);
    }

    #[test]
    fn swap_index_swaps_columns_in_other_rows() {
        // 4 items, generous cache so nothing is evicted by budget
        let mut cache = Cache::new(4, 10000);

        // Fill row 0 with [1, 2, 3, 4] (covers all 4 columns)
        let (data, _) = cache.get_data(0, 4);
        data[0] = 1.0;
        data[1] = 2.0;
        data[2] = 3.0;
        data[3] = 4.0;

        // Fill row 1 with [10, 20, 30, 40]
        let (data, _) = cache.get_data(1, 4);
        data[0] = 10.0;
        data[1] = 20.0;
        data[2] = 30.0;
        data[3] = 40.0;

        // Swap indices 1 and 3
        cache.swap_index(1, 3);

        // Row 0 (now at index 0, not swapped) should have columns 1,3 swapped
        let (data, start) = cache.get_data(0, 4);
        assert_eq!(start, 4); // still cached
        assert_eq!(data[0], 1.0);
        assert_eq!(data[1], 4.0); // was data[3]
        assert_eq!(data[2], 3.0);
        assert_eq!(data[3], 2.0); // was data[1]

        // Row 1 was swapped to position 3, row 3 was swapped to position 1
        // Row at position 3 (formerly row 1) should have columns 1,3 swapped
        let (data, start) = cache.get_data(3, 4);
        assert_eq!(start, 4); // still cached
        assert_eq!(data[0], 10.0);
        assert_eq!(data[1], 40.0); // was data[3]
        assert_eq!(data[2], 30.0);
        assert_eq!(data[3], 20.0); // was data[1]
    }

    #[test]
    fn swap_index_evicts_partial_rows() {
        // 4 items
        let mut cache = Cache::new(4, 10000);

        // Fill row 0 with full length 4
        let (data, _) = cache.get_data(0, 4);
        data[0] = 1.0;
        data[1] = 2.0;
        data[2] = 3.0;
        data[3] = 4.0;

        // Fill row 1 with length 2 only (covers columns 0,1 but NOT 2,3)
        let (data, _) = cache.get_data(1, 2);
        data[0] = 10.0;
        data[1] = 20.0;

        // Swap indices 1 and 3: row 1 covers col 1 but NOT col 3 → evict
        cache.swap_index(1, 3);

        // The partial row (now stored at position 3) was dropped entirely
        assert!(cache.data[3].is_none());
        assert_eq!(cache.len[3], 0);

        // Row 0 has full coverage, columns swapped
        let (data, start) = cache.get_data(0, 4);
        assert_eq!(start, 4);
        assert_eq!(data[1], 4.0);
        assert_eq!(data[3], 2.0);

        // Row at position 1 (formerly row 3, which was empty) should need fill
        let (_, start) = cache.get_data(1, 2);
        assert_eq!(start, 0);
    }
}