// Source file: libsvm_rs/cache.rs

//! LRU kernel cache matching the original LIBSVM.
//!
//! The cache stores rows of the kernel matrix Q as slices of `Qfloat` (`f32`).
//! When the memory budget is exhausted, the least-recently-used rows are evicted.
//!
//! The C++ original uses a circular doubly-linked list built from raw pointers.
//! This Rust version uses an index-based circular doubly-linked list instead,
//! keeping every LRU operation O(1) without unsafe code while matching the
//! original semantics.

/// Element type for cached kernel matrix rows. Matches LIBSVM's `Qfloat = float`.
pub type Qfloat = f32;

/// Sentinel index representing "no link" in the LRU list.
const NONE: usize = usize::MAX;

/// Per-row LRU node: stores prev/next indices for a circular doubly-linked list.
#[derive(Debug, Clone, Copy)]
struct LruNode {
    prev: usize,
    next: usize,
}

impl LruNode {
    /// A node that is not currently linked into the LRU list.
    const UNLINKED: Self = Self {
        prev: NONE,
        next: NONE,
    };
}

/// LRU cache for kernel matrix rows.
///
/// Each of the `l` data items may have one cached row. The cache tracks how
/// much of its budget (in `Qfloat` units) is still free and evicts
/// least-recently-used rows when a request does not fit.
///
/// The LRU list is a circular doubly-linked list over array indices. Index `l`
/// is the sentinel head: `nodes[l].next` is the least recently used row and
/// `nodes[l].prev` the most recently used one. Insert, remove and evict are
/// all O(1).
#[derive(Debug)]
pub struct Cache {
    /// Number of data items (rows in the Q matrix).
    l: usize,
    /// Remaining free budget in `Qfloat` units.
    size: usize,
    /// Per-row cached data. `None` means not cached.
    data: Vec<Option<Vec<Qfloat>>>,
    /// Per-row cached length (how many leading elements are valid).
    len: Vec<usize>,
    /// LRU doubly-linked list nodes. Index `l` is the sentinel head.
    /// Nodes with `prev == NONE` are not in the LRU list.
    nodes: Vec<LruNode>,
}

impl Cache {
    /// Create a new cache for `l` data items with `size_bytes` of memory.
    ///
    /// As in LIBSVM, the byte budget is converted to `Qfloat` units and the
    /// per-row bookkeeping overhead is subtracted. The result is then clamped
    /// so the cache can always hold at least two rows of length `l`.
    pub fn new(l: usize, size_bytes: usize) -> Self {
        let qfloat_bytes = std::mem::size_of::<Qfloat>();
        // Per-row header overhead, expressed in Qfloat units.
        let header_size = l * std::mem::size_of::<LruNode>() / qfloat_bytes;
        // Cache must be large enough for at least two columns.
        let size = (size_bytes / qfloat_bytes)
            .saturating_sub(header_size)
            .max(2 * l);

        // Nodes 0..l belong to the data rows. Node l is the sentinel head,
        // which points to itself while the list is empty.
        let mut nodes = vec![LruNode::UNLINKED; l + 1];
        nodes[l] = LruNode { prev: l, next: l };

        Cache {
            l,
            size,
            data: vec![None; l],
            len: vec![0; l],
            nodes,
        }
    }

    /// Unlink node `i` from the LRU list and mark it as unlinked. O(1).
    fn lru_delete(&mut self, i: usize) {
        let LruNode { prev, next } = self.nodes[i];
        self.nodes[prev].next = next;
        self.nodes[next].prev = prev;
        self.nodes[i] = LruNode::UNLINKED;
    }

    /// Link node `i` just before the sentinel, i.e. as most recently used. O(1).
    fn lru_insert(&mut self, i: usize) {
        let head = self.l; // sentinel
        let tail = self.nodes[head].prev;
        self.nodes[i] = LruNode {
            prev: tail,
            next: head,
        };
        self.nodes[tail].next = i;
        self.nodes[head].prev = i;
    }

    /// Check if node `i` is in the LRU list.
    fn in_lru(&self, i: usize) -> bool {
        self.nodes[i].prev != NONE
    }

    /// Drop the cached row `h`, which must currently be linked, and return
    /// its memory to the free budget.
    fn evict(&mut self, h: usize) {
        self.lru_delete(h);
        self.size += self.len[h];
        self.data[h] = None;
        self.len[h] = 0;
    }

    /// Request data for row `index` of length `request_len`.
    ///
    /// Returns `(data, start)` where `data` is the cached row slice and
    /// `start` is the position from which data needs to be filled.
    /// If `start >= request_len`, the entire requested prefix was already
    /// cached. `data` may be longer than `request_len` when the row was
    /// previously cached with a greater length.
    ///
    /// The caller must fill `data[start..request_len]` with kernel values.
    /// The row becomes the most recently used entry.
    ///
    /// # Panics
    ///
    /// Panics if `index >= l`. Also panics if `request_len` exceeds the
    /// cache's total budget, because then no amount of eviction can make
    /// room for the row.
    pub fn get_data(&mut self, index: usize, request_len: usize) -> (&mut [Qfloat], usize) {
        assert!(index < self.l);

        // Unlink so the row can be re-inserted as most recently used. This
        // also guarantees the eviction loop below never evicts `index` itself.
        if self.in_lru(index) {
            self.lru_delete(index);
        }

        let old_len = self.len[index];
        let more = request_len.saturating_sub(old_len);

        if more > 0 {
            // Evict least-recently-used rows until the extension fits.
            let head = self.l;
            while self.size < more {
                let victim = self.nodes[head].next;
                if victim == head {
                    break; // list empty
                }
                self.evict(victim);
            }
            // Without this check the subtraction below would wrap in release
            // builds and permanently disable eviction.
            assert!(
                self.size >= more,
                "kernel cache budget too small: row {} needs {} more Qfloats, only {} free",
                index,
                more,
                self.size
            );

            // Allocate or extend; previously cached values are kept.
            let row = self.data[index].get_or_insert_with(Vec::new);
            row.resize(request_len, 0.0);
            self.size -= more;
            self.len[index] = request_len;
        }

        // Insert at back of LRU (most recently used).
        self.lru_insert(index);

        // A zero-length request for a never-cached row leaves `data[index]`
        // as `None`. Materialise an empty row so a slice can still be returned.
        let row = self.data[index].get_or_insert_with(Vec::new);
        (row.as_mut_slice(), old_len)
    }

    /// Swap indices `i` and `j` in the cache.
    ///
    /// Used by the solver when rearranging the working set. The cached rows
    /// for `i` and `j` trade places, and columns `i` and `j` are swapped in
    /// every cached row. A row long enough to contain the lower column but
    /// not the higher one cannot be patched and is evicted instead, matching
    /// the "give up" behavior of the C++ original.
    pub fn swap_index(&mut self, i: usize, j: usize) {
        if i == j {
            return;
        }

        let i_in = self.in_lru(i);
        let j_in = self.in_lru(j);
        if i_in {
            self.lru_delete(i);
        }
        if j_in {
            self.lru_delete(j);
        }

        // Swap data and lengths.
        self.data.swap(i, j);
        self.len.swap(i, j);

        // After the swap, slot `i` holds what used to be `j` and vice versa.
        // Slot `i` is re-linked before slot `j`, in the same order as LIBSVM's
        // `Cache::swap_index`, so later evictions happen in the same order.
        if j_in {
            self.lru_insert(i);
        }
        if i_in {
            self.lru_insert(j);
        }

        let (lo, hi) = if i < j { (i, j) } else { (j, i) };
        let head = self.l;
        let mut h = self.nodes[head].next;
        while h != head {
            // Read the successor first, because `evict` unlinks `h`.
            let next = self.nodes[h].next;
            let row_len = self.len[h];
            if row_len > hi {
                // Row covers both positions, so swap the column entries.
                if let Some(row) = self.data[h].as_mut() {
                    row.swap(lo, hi);
                }
            } else if row_len > lo {
                // Row covers lo but not hi ("give up").
                self.evict(h);
            }
            h = next;
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Write `values` into the leading entries of `row`.
    fn fill(row: &mut [Qfloat], values: &[Qfloat]) {
        row[..values.len()].copy_from_slice(values);
    }

    /// Assert that the leading entries of `row` equal `expected`.
    fn assert_prefix(row: &[Qfloat], expected: &[Qfloat]) {
        assert_eq!(&row[..expected.len()], expected);
    }

    #[test]
    fn basic_get_and_fill() {
        let mut cache = Cache::new(3, 100);

        let (row, start) = cache.get_data(0, 3);
        assert_eq!((start, row.len()), (0, 3));
        fill(row, &[1.0, 2.0, 3.0]);

        // Asking for the same length again is a full hit: start == len.
        let (row, start) = cache.get_data(0, 3);
        assert_eq!(start, 3);
        assert_prefix(row, &[1.0, 2.0, 3.0]);
    }

    #[test]
    fn extend_cached_row() {
        let mut cache = Cache::new(3, 1000);

        let (row, start) = cache.get_data(0, 2);
        assert_eq!(start, 0);
        fill(row, &[10.0, 20.0]);

        // Growing the row keeps the cached prefix and only asks for the tail.
        let (row, start) = cache.get_data(0, 3);
        assert_eq!(start, 2);
        assert_prefix(row, &[10.0, 20.0]);
        row[2] = 30.0;
    }

    #[test]
    fn lru_eviction() {
        let l = 10;
        // Raw budget of 5*l Qfloats. After the header overhead (16-byte
        // `LruNode`s on 64-bit targets) the clamp leaves room for two rows.
        let mut cache = Cache::new(l, 5 * l * std::mem::size_of::<Qfloat>());

        for (index, first) in [(0, 1.0), (1, 3.0), (2, 5.0)] {
            let (row, start) = cache.get_data(index, l);
            assert_eq!(start, 0);
            row[0] = first;
        }

        // Loading row 2 pushed out row 0, so it has to be refilled.
        let (_, start) = cache.get_data(0, l);
        assert_eq!(start, 0);
    }

    #[test]
    fn lru_order_respects_access() {
        // Re-accessing a row must move it to the MRU end of the list.
        // l = 5 gives a minimum cache of 2*5 = 10 Qfloats; the budget below is
        // 3 rows = 15 Qfloats once the header overhead is removed.
        let l = 5;
        let row_len = l;
        let header = l * std::mem::size_of::<LruNode>() / std::mem::size_of::<Qfloat>();
        let budget = 3 * row_len + header;
        let mut cache = Cache::new(l, budget * std::mem::size_of::<Qfloat>());

        for (index, first) in [(0, 10.0), (1, 20.0), (2, 30.0)] {
            let (row, _) = cache.get_data(index, row_len);
            row[0] = first;
        }
        // LRU order: 0 (oldest), 1, 2 (newest).

        let (row, start) = cache.get_data(0, row_len);
        assert_eq!(start, row_len); // full hit, nothing to fill
        assert_eq!(row[0], 10.0);
        // LRU order: 1 (oldest), 2, 0 (newest).

        let (row, start) = cache.get_data(3, row_len);
        assert_eq!(start, 0); // new row, must be filled
        row[0] = 40.0;
        // Row 1 was the victim; LRU order: 2, 0, 3 (newest).

        assert!(cache.data[1].is_none());
        assert!(cache.data[0].is_some());
        assert!(cache.data[2].is_some());
    }

    #[test]
    fn swap_index_row_swap() {
        // A full-length row avoids the partial-row eviction path.
        let mut cache = Cache::new(3, 1000);
        let (row, _) = cache.get_data(0, 3);
        fill(row, &[10.0, 20.0, 30.0]);

        cache.swap_index(0, 2);

        // The row moved from index 0 to index 2, and its columns 0 and 2
        // were swapped as well.
        let (row, start) = cache.get_data(2, 3);
        assert_eq!(start, 3);
        assert_prefix(row, &[30.0, 20.0, 10.0]);

        // Index 0 inherited the empty slot of index 2.
        let (_, start) = cache.get_data(0, 3);
        assert_eq!(start, 0);
    }

    #[test]
    fn swap_index_swaps_columns_in_other_rows() {
        // Generous budget: nothing is evicted for lack of space.
        let mut cache = Cache::new(4, 10000);

        let (row, _) = cache.get_data(0, 4);
        fill(row, &[1.0, 2.0, 3.0, 4.0]);
        let (row, _) = cache.get_data(1, 4);
        fill(row, &[10.0, 20.0, 30.0, 40.0]);

        cache.swap_index(1, 3);

        // Row 0 stays put, but its columns 1 and 3 trade places.
        let (row, start) = cache.get_data(0, 4);
        assert_eq!(start, 4);
        assert_prefix(row, &[1.0, 4.0, 3.0, 2.0]);

        // The former row 1 now lives at index 3, with the same column swap.
        let (row, start) = cache.get_data(3, 4);
        assert_eq!(start, 4);
        assert_prefix(row, &[10.0, 40.0, 30.0, 20.0]);
    }

    #[test]
    fn swap_index_evicts_partial_rows() {
        let mut cache = Cache::new(4, 10000);

        let (row, _) = cache.get_data(0, 4);
        fill(row, &[1.0, 2.0, 3.0, 4.0]);
        // Row 1 only covers columns 0 and 1, not 2 or 3.
        let (row, _) = cache.get_data(1, 2);
        fill(row, &[10.0, 20.0]);

        // Swapping 1 <-> 3: the short row covers column 1 but not column 3,
        // so it is evicted rather than patched.
        cache.swap_index(1, 3);

        // The full-length row 0 just has its columns swapped.
        let (row, start) = cache.get_data(0, 4);
        assert_eq!(start, 4);
        assert_eq!(row[1], 4.0);
        assert_eq!(row[3], 2.0);

        // Index 1 inherited the slot of row 3, which was never filled.
        let (_, start) = cache.get_data(1, 2);
        assert_eq!(start, 0);
    }
}