Skip to main content

libsvm_rs/
cache.rs

//! LRU kernel cache matching the original LIBSVM.
//!
//! The cache stores rows of the kernel matrix Q as `Qfloat` (`f32`) slices.
//! When the memory budget is exhausted, least-recently-used rows are evicted
//! until the requested row fits.
//!
//! The C++ original uses a doubly-linked circular list with raw pointers.
//! This Rust version keeps the LRU order as a queue of row indices and the
//! per-row data in a `Vec<Option<Vec<Qfloat>>>`, avoiding unsafe code while
//! matching the semantics.
9
/// Element type for cached kernel matrix rows. Matches LIBSVM's `Qfloat = float`.
pub type Qfloat = f32;

/// LRU cache for kernel matrix rows.
///
/// Each of the `l` data items may have a cached row of length up to `l`.
/// The cache tracks how much of its memory budget (in `Qfloat` units) is still
/// free and evicts least-recently-used rows when a request does not fit.
///
/// Invariants maintained by every method:
/// - `len[k] > 0` exactly when `data[k]` is `Some`, and then the stored
///   vector has exactly `len[k]` elements;
/// - `lru` contains every row with `len[k] > 0`, each exactly once;
/// - `size + sum(len)` always equals the budget computed by [`Cache::new`].
#[derive(Debug)]
pub struct Cache {
    /// Number of data items (rows in the Q matrix).
    l: usize,
    /// Remaining free budget in `Qfloat` units.
    size: usize,
    /// Per-row cached data. `None` means not cached.
    data: Vec<Option<Vec<Qfloat>>>,
    /// Per-row cached length (how many elements are valid).
    len: Vec<usize>,
    /// LRU order: front = least recently used, back = most recently used.
    /// Contains indices of rows currently in the cache.
    lru: std::collections::VecDeque<usize>,
}

impl Cache {
    /// Create a new cache for `l` data items with `size_bytes` of memory.
    ///
    /// As in LIBSVM, a per-row bookkeeping overhead is subtracted from the
    /// budget first. The result is then raised to at least `2 * l` elements,
    /// so that two full-length rows always fit.
    pub fn new(l: usize, size_bytes: usize) -> Self {
        let qfloat_bytes = std::mem::size_of::<Qfloat>();
        // Bookkeeping estimate: three machine words per row, in Qfloat units.
        let header = l * 3 * std::mem::size_of::<usize>() / qfloat_bytes;
        // Saturate so a tiny byte budget falls back to the 2*l floor instead of underflowing.
        let size = (size_bytes / qfloat_bytes).saturating_sub(header).max(2 * l);

        Cache {
            l,
            size,
            data: vec![None; l],
            len: vec![0; l],
            lru: std::collections::VecDeque::new(),
        }
    }

    /// Request data for row `index` of length `request_len`.
    ///
    /// Returns `(data, start)`. `data` is the cached row and `start` is the
    /// position from which the caller must fill in kernel values, i.e.
    /// `data[start..request_len]`. `start` never exceeds `request_len`; when it
    /// equals `request_len`, the requested prefix was already cached. `data` is
    /// at least `request_len` long, and may be longer if a longer prefix of the
    /// row was cached earlier.
    ///
    /// The row becomes the most recently used entry.
    ///
    /// # Panics
    ///
    /// Panics if `index >= l`, or if `request_len` exceeds the total cache
    /// budget. The budget is at least `2 * l`, so any `request_len <= l` is
    /// always accepted.
    pub fn get_data(&mut self, index: usize, request_len: usize) -> (&mut [Qfloat], usize) {
        assert!(index < self.l);

        let old_len = self.len[index];
        // Take the row out of the LRU order so the eviction loop below cannot pick it.
        if old_len > 0 {
            self.lru_remove(index);
        }

        let more = request_len.saturating_sub(old_len);
        if more > 0 {
            // Evict least-recently-used rows until the extension fits.
            while self.size < more {
                match self.lru.pop_front() {
                    Some(victim) => self.release(victim),
                    None => break,
                }
            }
            // Only reachable when the request is larger than the whole budget;
            // fail loudly instead of wrapping the free-space counter.
            assert!(
                self.size >= more,
                "Cache::get_data: a row of length {} does not fit in the cache budget",
                request_len
            );

            self.data[index]
                .get_or_insert_with(Vec::new)
                .resize(request_len, 0.0);
            self.size -= more;
            self.len[index] = request_len;
        }

        // Rows with nothing cached never enter the LRU order (keeps it duplicate-free).
        if self.len[index] > 0 {
            self.lru.push_back(index);
        }

        // Like LIBSVM, report the already-valid prefix clamped to the request.
        let start = old_len.min(request_len);
        (self.data[index].as_deref_mut().unwrap_or_default(), start)
    }

    /// Swap indices `i` and `j` in the cache.
    ///
    /// Used by the solver when rearranging the working set. Renumbering the
    /// data items swaps both rows *and* columns of Q, so the method:
    /// - exchanges the cached rows `i` and `j`;
    /// - swaps elements `i` and `j` inside every cached row long enough to hold both;
    /// - evicts rows that hold only the smaller of the two columns, because the
    ///   value that belongs there is not cached (same as LIBSVM).
    pub fn swap_index(&mut self, i: usize, j: usize) {
        if i == j {
            return;
        }

        // Remove both from LRU if present
        if self.len[i] > 0 {
            self.lru_remove(i);
        }
        if self.len[j] > 0 {
            self.lru_remove(j);
        }

        // Swap data and lengths
        self.data.swap(i, j);
        self.len.swap(i, j);

        // Re-insert into LRU if they have data
        if self.len[i] > 0 {
            self.lru.push_back(i);
        }
        if self.len[j] > 0 {
            self.lru.push_back(j);
        }

        // Fix up columns lo/hi in every cached row. Destructure to borrow the
        // fields separately: `retain` mutates `lru` while the closure touches the rest.
        let (lo, hi) = if i < j { (i, j) } else { (j, i) };
        let Cache {
            size,
            data,
            len,
            lru,
            ..
        } = self;
        lru.retain(|&row| {
            let row_len = len[row];
            if row_len <= lo {
                // The row ends before either column; nothing to fix.
                true
            } else if row_len > hi {
                // Both columns are cached: swap them to follow the renumbering.
                data[row]
                    .as_mut()
                    .expect("rows in the LRU order always hold data")
                    .swap(lo, hi);
                true
            } else {
                // Column `lo` is cached but `hi` is not, so the value that now
                // belongs at `lo` is unknown. Give up on this row.
                *size += row_len;
                len[row] = 0;
                data[row] = None;
                false
            }
        });
    }

    /// Drop the cached data of row `k` and return its space to the budget.
    /// The caller is responsible for removing `k` from the LRU order.
    fn release(&mut self, k: usize) {
        self.size += self.len[k];
        self.len[k] = 0;
        self.data[k] = None;
    }

    /// Remove `index` from the LRU order if present (linear scan).
    fn lru_remove(&mut self, index: usize) {
        if let Some(pos) = self.lru.iter().position(|&x| x == index) {
            self.lru.remove(pos);
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn basic_get_and_fill() {
        // Cache for 3 items, 100 bytes
        let mut cache = Cache::new(3, 100);
        let (data, start) = cache.get_data(0, 3);
        assert_eq!(start, 0); // nothing cached yet
        assert_eq!(data.len(), 3);
        data[0] = 1.0;
        data[1] = 2.0;
        data[2] = 3.0;

        // Second access should return start=3 (fully cached)
        let (data, start) = cache.get_data(0, 3);
        assert_eq!(start, 3);
        assert_eq!(data[0], 1.0);
        assert_eq!(data[1], 2.0);
        assert_eq!(data[2], 3.0);
    }

    #[test]
    fn extend_cached_row() {
        let mut cache = Cache::new(3, 1000);
        let (data, start) = cache.get_data(0, 2);
        assert_eq!(start, 0);
        data[0] = 10.0;
        data[1] = 20.0;

        // Request longer row
        let (data, start) = cache.get_data(0, 3);
        assert_eq!(start, 2); // only need to fill [2..3)
        assert_eq!(data[0], 10.0);
        assert_eq!(data[1], 20.0);
        data[2] = 30.0;
    }

    #[test]
    fn shorter_request_reports_start_within_request() {
        let mut cache = Cache::new(3, 1000);
        let (data, _) = cache.get_data(0, 3);
        data.copy_from_slice(&[1.0, 2.0, 3.0]);

        // A shorter prefix of an already-cached row needs no filling.
        let (data, start) = cache.get_data(0, 2);
        assert_eq!(start, 2);
        assert!(data[start..2].is_empty());
        assert_eq!(data[..2].to_vec(), vec![1.0, 2.0]);
    }

    #[test]
    fn zero_length_request() {
        let mut cache = Cache::new(3, 1000);
        let (data, start) = cache.get_data(1, 0);
        assert!(data.is_empty());
        assert_eq!(start, 0);
    }

    #[test]
    fn lru_eviction() {
        // Use 10 items to make the minimum cache size meaningful.
        let l = 10;
        // 50 Qfloats of bytes is less than the header overhead, so the budget
        // falls back to the 2*l = 20 Qfloat minimum: exactly two rows of length l.
        let bytes = (2 * l + l * 3) * std::mem::size_of::<Qfloat>();
        let mut cache = Cache::new(l, bytes);

        // Fill row 0 with l elements
        let (data, start) = cache.get_data(0, l);
        assert_eq!(start, 0);
        data[0] = 1.0;

        // Fill row 1 with l elements
        let (data, start) = cache.get_data(1, l);
        assert_eq!(start, 0);
        data[0] = 3.0;

        // Fill row 2 — should evict row 0 (LRU)
        let (data, start) = cache.get_data(2, l);
        assert_eq!(start, 0);
        data[0] = 5.0;

        // Row 0 should have been evicted
        let (_, start) = cache.get_data(0, l);
        assert_eq!(start, 0); // needs to be re-filled
    }

    #[test]
    fn swap_index_swaps_rows_and_columns() {
        let mut cache = Cache::new(3, 1000);
        let (data, _) = cache.get_data(0, 3);
        data.copy_from_slice(&[10.0, 20.0, 30.0]);

        cache.swap_index(0, 2);

        // Row 2 now holds old row 0, with its columns 0 and 2 exchanged.
        let (data, start) = cache.get_data(2, 3);
        assert_eq!(start, 3); // already cached
        assert_eq!(data.to_vec(), vec![30.0, 20.0, 10.0]);

        // Row 0 should be empty
        let (_, start) = cache.get_data(0, 3);
        assert_eq!(start, 0);
    }

    #[test]
    fn swap_index_evicts_rows_covering_only_one_column() {
        let mut cache = Cache::new(3, 1000);
        let (data, _) = cache.get_data(1, 2);
        data.copy_from_slice(&[1.0, 2.0]);

        // Row 1 holds column 0 but not column 2, so it cannot be fixed up.
        cache.swap_index(0, 2);

        let (_, start) = cache.get_data(1, 2);
        assert_eq!(start, 0);
    }
}