libsvm_rs/cache.rs
1//! LRU kernel cache matching the original LIBSVM.
2//!
3//! The cache stores rows of the kernel matrix Q as `Qfloat` (`f32`) slices.
4//! When memory is exhausted, the least-recently-used row is evicted.
5//!
6//! The C++ original uses a doubly-linked circular list with raw pointers.
7//! This Rust version uses a `VecDeque` for the LRU order and `Vec<Option<Vec<Qfloat>>>`
8//! for per-row data, avoiding unsafe code while matching the semantics.
9
/// Element type for cached kernel matrix rows. Matches LIBSVM's `Qfloat = float`.
pub type Qfloat = f32;

/// LRU cache for rows of the kernel matrix Q.
///
/// Each of the `l` data items may have a cached row of length up to `l`.
/// The cache tracks how many `Qfloat` units its rows occupy and evicts the
/// least-recently-used rows when a request would exceed the budget.
pub struct Cache {
    /// Number of data items (rows in the Q matrix).
    l: usize,
    /// Total budget in `Qfloat` units, after the per-row header estimate
    /// has been subtracted. Never less than `2 * l`.
    capacity: usize,
    /// `Qfloat` units currently held by cached rows (the sum of `len`).
    used: usize,
    /// Per-row storage. `None` (or an empty vector) means nothing is cached.
    data: Vec<Option<Vec<Qfloat>>>,
    /// Per-row cached length (how many elements are valid).
    len: Vec<usize>,
    /// LRU order: front = least recently used, back = most recently used.
    /// Holds exactly the rows whose cached length is non-zero, each once.
    lru: std::collections::VecDeque<usize>,
}

impl Cache {
    /// Create a new cache for `l` data items with `size_bytes` of memory.
    ///
    /// The byte budget is converted to `Qfloat` units and reduced by an
    /// estimate of the per-row bookkeeping (three `usize` words per row).
    /// Whatever is left is clamped to at least `2 * l` units so that two
    /// full rows always fit.
    pub fn new(l: usize, size_bytes: usize) -> Self {
        let qfloat_bytes = std::mem::size_of::<Qfloat>();
        // Header overhead (metadata per row), expressed in Qfloat units.
        let header_size = l * 3 * std::mem::size_of::<usize>() / qfloat_bytes;
        let capacity = (size_bytes / qfloat_bytes)
            .saturating_sub(header_size)
            .max(2 * l);

        Cache {
            l,
            capacity,
            used: 0,
            data: vec![None; l],
            len: vec![0; l],
            lru: std::collections::VecDeque::new(),
        }
    }

    /// Request data for row `index` of length `request_len`.
    ///
    /// Returns `(data, start)` where `data` is the cached row slice and
    /// `start` is the number of elements that were already cached.
    /// If `start >= request_len`, the entire requested prefix was cached;
    /// otherwise the caller must fill `data[start..request_len]` with
    /// kernel values (the new tail is zero-initialised).
    ///
    /// When the row already holds more than `request_len` elements the
    /// whole cached row is returned unchanged. A zero-length request for
    /// an uncached row yields an empty slice and leaves the LRU order alone.
    ///
    /// Growing a row evicts least-recently-used rows until the extra
    /// elements fit. If they still do not fit once every other row is
    /// gone, the row is allocated anyway and the cache runs over budget
    /// until that row is evicted by a later request.
    ///
    /// # Panics
    ///
    /// Panics if `index >= l`.
    pub fn get_data(&mut self, index: usize, request_len: usize) -> (&mut [Qfloat], usize) {
        assert!(index < self.l);

        // Detach the row from the LRU order so it cannot evict itself.
        let old_len = self.len[index];
        if old_len > 0 {
            self.lru_remove(index);
        }

        let more = request_len.saturating_sub(old_len);
        if more > 0 {
            // Evict from the cold end until the growth fits the budget.
            while self.used.saturating_add(more) > self.capacity {
                match self.lru.pop_front() {
                    Some(victim) => {
                        self.used -= std::mem::take(&mut self.len[victim]);
                        self.data[victim] = None;
                    }
                    // Nothing left to evict: accept going over budget.
                    None => break,
                }
            }
            self.used += more;
            self.len[index] = request_len;
        }

        // Only rows that actually hold data take part in the LRU order;
        // otherwise a later call would never remove the stale entry.
        if self.len[index] > 0 {
            self.lru.push_back(index);
        }

        let row = self.data[index].get_or_insert_with(Vec::new);
        if more > 0 {
            row.resize(request_len, 0.0);
        }
        (row.as_mut_slice(), old_len)
    }

    /// Swap indices `i` and `j` in the cache.
    ///
    /// Used by the solver when rearranging the working set. The cached
    /// rows (data and lengths) of `i` and `j` trade places, and whichever
    /// of them hold data afterwards become the most recently used entries
    /// (`i` first, then `j`).
    ///
    /// NOTE(review): upstream LIBSVM's `Cache::swap_index` additionally swaps
    /// columns `i` and `j` inside every cached row (dropping rows too short
    /// to contain both). This implementation leaves row contents untouched —
    /// confirm that callers account for that.
    ///
    /// # Panics
    ///
    /// Panics if `i` or `j` is out of bounds (unless `i == j`, which returns
    /// immediately).
    pub fn swap_index(&mut self, i: usize, j: usize) {
        if i == j {
            return;
        }

        // Pull both rows out of the LRU order while their slots change.
        for row in [i, j] {
            if self.len[row] > 0 {
                self.lru_remove(row);
            }
        }

        self.data.swap(i, j);
        self.len.swap(i, j);

        // Re-queue whichever slots hold data after the swap.
        for row in [i, j] {
            if self.len[row] > 0 {
                self.lru.push_back(row);
            }
        }
    }

    /// Remove `index` from the LRU order, if it is present.
    fn lru_remove(&mut self, index: usize) {
        if let Some(pos) = self.lru.iter().position(|&row| row == index) {
            self.lru.remove(pos);
        }
    }
}
132
#[cfg(test)]
mod tests {
    use super::*;

    /// A freshly requested row starts empty; a repeat request finds it fully cached.
    #[test]
    fn basic_get_and_fill() {
        // Cache for 3 items, 100 bytes
        let mut cache = Cache::new(3, 100);
        {
            let (row, start) = cache.get_data(0, 3);
            assert_eq!(start, 0); // nothing cached yet
            assert_eq!(row.len(), 3);
            row.copy_from_slice(&[1.0, 2.0, 3.0]);
        }

        // Second access reports start == 3: nothing left to fill.
        let (row, start) = cache.get_data(0, 3);
        assert_eq!(start, 3);
        assert_eq!(row[..3], [1.0f32, 2.0, 3.0]);
    }

    /// Growing a cached row keeps the existing prefix and reports where to resume.
    #[test]
    fn extend_cached_row() {
        let mut cache = Cache::new(3, 1000);
        {
            let (row, start) = cache.get_data(0, 2);
            assert_eq!(start, 0);
            row[..2].copy_from_slice(&[10.0, 20.0]);
        }

        // Ask for one more element than is cached.
        let (row, start) = cache.get_data(0, 3);
        assert_eq!(start, 2); // only [2..3) still needs filling
        assert_eq!(row[..2], [10.0f32, 20.0]);
        row[2] = 30.0;
    }

    /// A third full-length row pushes out the least recently used one.
    #[test]
    fn lru_eviction() {
        // With 10 items the budget below works out to 2 * l = 20 Qfloats,
        // so exactly two rows of length 10 fit at once.
        let l = 10;
        let bytes = (2 * l + l * 3) * std::mem::size_of::<Qfloat>();
        let mut cache = Cache::new(l, bytes);

        // Rows 0, 1 and 2 in turn; each is new, so each starts at 0.
        // Requesting row 2 must evict row 0, the least recently used.
        for (row_index, marker) in [(0usize, 1.0f32), (1, 3.0), (2, 5.0)] {
            let (row, start) = cache.get_data(row_index, l);
            assert_eq!(start, 0);
            row[0] = marker;
        }

        // Row 0 was evicted, so it has to be refilled from scratch.
        let (_, start) = cache.get_data(0, l);
        assert_eq!(start, 0);
    }

    /// Swapping moves a cached row to the other index and leaves the first empty.
    #[test]
    fn swap_index_works() {
        let mut cache = Cache::new(3, 1000);
        {
            let (row, _) = cache.get_data(0, 2);
            row[0] = 10.0;
            row[1] = 20.0;
        }

        cache.swap_index(0, 2);

        // Index 2 now owns what was cached under index 0.
        {
            let (row, start) = cache.get_data(2, 2);
            assert_eq!(start, 2); // already cached
            assert_eq!(row[..2], [10.0f32, 20.0]);
        }

        // Index 0 has nothing cached any more.
        let (_, start) = cache.get_data(0, 2);
        assert_eq!(start, 0);
    }
}
220}