1use std::collections::HashMap;
8use std::sync::{Arc, Mutex};
9
/// Weighted call graph mapping each caller's content hash to its outgoing
/// edges, used to predict which functions are likely to be called next.
pub struct CallGraph {
    /// Outgoing edges per caller, stored in descending static-weight order.
    edges: HashMap<[u8; 32], Vec<CallEdge>>,
    /// Latest observed call count for each `caller -> callee` pair, as last
    /// reported to [`CallGraph::update_dynamic_weight`]. Raw counts are kept so
    /// that all of a caller's dynamic weights can be renormalized consistently.
    observed_calls: HashMap<[u8; 32], HashMap<[u8; 32], u64>>,
}

/// One outgoing edge of a caller in a [`CallGraph`].
#[derive(Debug, Clone)]
pub struct CallEdge {
    /// Hash of the called function.
    pub callee_hash: [u8; 32],
    /// Share of the caller's static call sites that target this callee, in `[0, 1]`.
    pub static_weight: f32,
    /// Share of the caller's observed calls that went to this callee, in
    /// `[0, 1]`; `0.0` until calls are recorded.
    pub dynamic_weight: f32,
}

/// Tuning knobs for [`Prefetcher`].
#[derive(Debug, Clone)]
pub struct PrefetchConfig {
    /// Maximum number of call-graph hops explored from the requested function.
    pub max_prefetch_depth: usize,
    /// Number of highest-ranked callees considered per visited function.
    pub top_n_callees: usize,
    /// Minimum combined weight (`static_weight + dynamic_weight`, so a value
    /// in `[0, 2]`) an edge needs in order to be followed.
    pub min_probability: f32,
    /// When `false`, prefetching is skipped entirely.
    pub enabled: bool,
}

impl Default for PrefetchConfig {
    /// Two hops deep, four callees per function, a 0.1 weight threshold, enabled.
    fn default() -> Self {
        PrefetchConfig {
            enabled: true,
            max_prefetch_depth: 2,
            top_n_callees: 4,
            min_probability: 0.1,
        }
    }
}

/// Predicts functions that are likely to be called soon by walking a
/// [`CallGraph`], and queues their hashes for prefetching.
pub struct Prefetcher {
    /// Graph used to rank likely callees.
    call_graph: CallGraph,
    /// Traversal limits and on/off switch.
    config: PrefetchConfig,
    /// Hashes waiting to be prefetched; drained by `Prefetcher::get_prefetch_queue`.
    prefetch_queue: Arc<Mutex<Vec<[u8; 32]>>>,
    /// Counters describing prefetch activity.
    stats: PrefetchStats,
}

/// Counters describing prefetch activity.
#[derive(Debug, Default, Clone)]
pub struct PrefetchStats {
    /// Number of prefetch requests handled while prefetching was enabled.
    pub prefetch_requests: u64,
    /// Presumably: prefetched functions later served from cache.
    /// NOTE(review): not incremented anywhere in this file.
    pub cache_hits_from_prefetch: u64,
    /// Presumably: prefetched functions that were never used.
    /// NOTE(review): not incremented anywhere in this file.
    pub wasted_prefetches: u64,
}

impl CallGraph {
    /// Creates an empty call graph with no edges and no recorded calls.
    pub fn new() -> Self {
        Self {
            edges: HashMap::new(),
            observed_calls: HashMap::new(),
        }
    }

    /// Builds a call graph from static dependency lists.
    ///
    /// Each entry is `(caller, callees)`, where `callees` may repeat a hash
    /// once per call site. An edge's `static_weight` is the fraction of the
    /// caller's call sites that target that callee; `dynamic_weight` starts
    /// at `0.0`. If a caller appears more than once, its last entry wins.
    /// Edges are stored heaviest first, ties broken by ascending callee hash
    /// so the order does not depend on `HashMap` iteration order.
    pub fn build_from_dependencies(blobs: &[([u8; 32], Vec<[u8; 32]>)]) -> Self {
        let mut edges: HashMap<[u8; 32], Vec<CallEdge>> = HashMap::with_capacity(blobs.len());

        for (caller, callees) in blobs {
            let mut counts: HashMap<[u8; 32], u32> = HashMap::new();
            for callee in callees {
                *counts.entry(*callee).or_insert(0) += 1;
            }
            // Equals the sum of `counts`. An empty list produces no edges, so
            // the division below never sees a zero denominator.
            let total = callees.len() as f32;

            let mut call_edges: Vec<CallEdge> = counts
                .into_iter()
                .map(|(callee_hash, count)| CallEdge {
                    callee_hash,
                    static_weight: count as f32 / total,
                    dynamic_weight: 0.0,
                })
                .collect();

            call_edges.sort_by(|a, b| {
                heavier_first(a.static_weight, &a.callee_hash, b.static_weight, &b.callee_hash)
            });

            edges.insert(*caller, call_edges);
        }

        Self {
            edges,
            observed_calls: HashMap::new(),
        }
    }

    /// Returns up to `top_n` edges of `hash`, ranked by combined
    /// `static_weight + dynamic_weight` (heaviest first, ties by callee hash).
    ///
    /// Returns an empty vector for an unknown caller or when `top_n == 0`.
    pub fn likely_callees(&self, hash: &[u8; 32], top_n: usize) -> Vec<CallEdge> {
        let edges = match self.edges.get(hash) {
            Some(edges) => edges,
            None => return Vec::new(),
        };

        // Rank borrowed edges and clone only the ones that are returned.
        let mut ranked: Vec<&CallEdge> = edges.iter().collect();
        ranked.sort_by(|a, b| {
            heavier_first(
                a.static_weight + a.dynamic_weight,
                &a.callee_hash,
                b.static_weight + b.dynamic_weight,
                &b.callee_hash,
            )
        });
        ranked.into_iter().take(top_n).cloned().collect()
    }

    /// Records that `caller` has been observed calling `callee` `count` times
    /// in total. It then recomputes every edge of `caller` so that its
    /// `dynamic_weight` is that callee's share of all counts recorded for
    /// `caller`.
    ///
    /// A later call for the same pair replaces (does not add to) the earlier
    /// count. Unknown callers, and callees that are not static edges of
    /// `caller`, are ignored.
    pub fn update_dynamic_weight(&mut self, caller: &[u8; 32], callee: &[u8; 32], count: u64) {
        let edges = match self.edges.get_mut(caller) {
            Some(edges) => edges,
            None => return,
        };
        if !edges.iter().any(|edge| &edge.callee_hash == callee) {
            return;
        }

        let observed = self.observed_calls.entry(*caller).or_default();
        observed.insert(*callee, count);
        // Sum in f64 so very large counts cannot overflow.
        let total: f64 = observed.values().map(|&c| c as f64).sum();

        // Renormalize all edges, not just `callee`, so the caller's dynamic
        // weights stay a consistent distribution over raw counts.
        for edge in edges.iter_mut() {
            let seen = observed.get(&edge.callee_hash).copied().unwrap_or(0);
            edge.dynamic_weight = if total > 0.0 {
                (seen as f64 / total) as f32
            } else {
                0.0
            };
        }
    }
}

/// Orders two weighted edges so the heavier one comes first. Equal (or
/// incomparable) weights fall back to ascending callee hash, which keeps
/// rankings deterministic.
fn heavier_first(wa: f32, a: &[u8; 32], wb: f32, b: &[u8; 32]) -> std::cmp::Ordering {
    wb.partial_cmp(&wa)
        .unwrap_or(std::cmp::Ordering::Equal)
        .then_with(|| a.cmp(b))
}
154
155impl Default for CallGraph {
156 fn default() -> Self {
157 Self::new()
158 }
159}
160
161impl Prefetcher {
162 pub fn new(config: PrefetchConfig) -> Self {
164 Self {
165 call_graph: CallGraph::new(),
166 config,
167 prefetch_queue: Arc::new(Mutex::new(Vec::new())),
168 stats: PrefetchStats::default(),
169 }
170 }
171
172 pub fn build_call_graph(&mut self, blobs: &[([u8; 32], Vec<[u8; 32]>)]) {
174 self.call_graph = CallGraph::build_from_dependencies(blobs);
175 }
176
177 pub fn prefetch(&mut self, function_hash: &[u8; 32]) {
183 if !self.config.enabled {
184 return;
185 }
186
187 self.stats.prefetch_requests += 1;
188
189 let mut to_visit = vec![(*function_hash, 0usize)];
190 let mut enqueued = std::collections::HashSet::new();
191
192 while let Some((hash, depth)) = to_visit.pop() {
193 if depth >= self.config.max_prefetch_depth {
194 continue;
195 }
196
197 let callees = self
198 .call_graph
199 .likely_callees(&hash, self.config.top_n_callees);
200
201 for edge in &callees {
202 let combined = edge.static_weight + edge.dynamic_weight;
203 if combined < self.config.min_probability {
204 continue;
205 }
206 if enqueued.insert(edge.callee_hash) {
207 to_visit.push((edge.callee_hash, depth + 1));
208 }
209 }
210 }
211
212 if !enqueued.is_empty() {
213 let mut queue = self.prefetch_queue.lock().unwrap();
214 for hash in enqueued {
215 queue.push(hash);
216 }
217 }
218 }
219
220 pub fn get_prefetch_queue(&self) -> Vec<[u8; 32]> {
222 let mut queue = self.prefetch_queue.lock().unwrap();
223 std::mem::take(&mut *queue)
224 }
225
226 pub fn record_call(&mut self, caller: &[u8; 32], callee: &[u8; 32], count: u64) {
228 self.call_graph.update_dynamic_weight(caller, callee, count);
229 }
230
231 pub fn stats(&self) -> &PrefetchStats {
233 &self.stats
234 }
235}
236
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a 32-byte hash whose first byte is `val` and whose remaining
    /// bytes are zero, so distinct `val`s give distinct hashes.
    fn make_hash(val: u8) -> [u8; 32] {
        let mut h = [0u8; 32];
        h[0] = val;
        h
    }

    #[test]
    fn test_call_graph_empty() {
        // A freshly created graph has no edges for any caller.
        let graph = CallGraph::new();
        let hash = make_hash(1);
        assert!(graph.likely_callees(&hash, 4).is_empty());
    }

    #[test]
    fn test_call_graph_build_and_likely_callees() {
        let a = make_hash(1);
        let b = make_hash(2);
        let c = make_hash(3);

        // `a` has two call sites targeting `b` and one targeting `c`, so the
        // static weights are 2/3 and 1/3 and `b` ranks first.
        let blobs = vec![(a, vec![b, b, c])];
        let graph = CallGraph::build_from_dependencies(&blobs);

        let top = graph.likely_callees(&a, 2);
        assert_eq!(top.len(), 2);
        assert_eq!(top[0].callee_hash, b);
        assert!((top[0].static_weight - 2.0 / 3.0).abs() < 1e-5);
        assert_eq!(top[1].callee_hash, c);
        assert!((top[1].static_weight - 1.0 / 3.0).abs() < 1e-5);
    }

    #[test]
    fn test_call_graph_top_n_truncation() {
        let a = make_hash(1);
        let b = make_hash(2);
        let c = make_hash(3);
        let d = make_hash(4);

        // Three callees exist, but asking for one must return exactly one.
        let blobs = vec![(a, vec![b, c, d])];
        let graph = CallGraph::build_from_dependencies(&blobs);

        let top = graph.likely_callees(&a, 1);
        assert_eq!(top.len(), 1);
    }

    #[test]
    fn test_dynamic_weight_update() {
        let a = make_hash(1);
        let b = make_hash(2);
        let c = make_hash(3);

        // Both callees start with equal static weight (0.5 each).
        let blobs = vec![(a, vec![b, c])];
        let mut graph = CallGraph::build_from_dependencies(&blobs);

        // Only `b` has observed calls, so it gains a positive dynamic weight
        // and outranks `c`.
        graph.update_dynamic_weight(&a, &b, 10);

        let top = graph.likely_callees(&a, 2);
        assert_eq!(top[0].callee_hash, b);
        assert!(top[0].dynamic_weight > 0.0);
    }

    #[test]
    fn test_prefetcher_basic() {
        let a = make_hash(1);
        let b = make_hash(2);
        let c = make_hash(3);

        let blobs = vec![(a, vec![b, c])];

        // With default settings both direct callees of `a` pass the 0.1
        // threshold and are queued; one request is counted.
        let mut prefetcher = Prefetcher::new(PrefetchConfig::default());
        prefetcher.build_call_graph(&blobs);
        prefetcher.prefetch(&a);

        let queue = prefetcher.get_prefetch_queue();
        assert!(!queue.is_empty());
        assert!(queue.contains(&b));
        assert!(queue.contains(&c));
        assert_eq!(prefetcher.stats().prefetch_requests, 1);
    }

    #[test]
    fn test_prefetcher_disabled() {
        let a = make_hash(1);
        let b = make_hash(2);

        let blobs = vec![(a, vec![b])];

        // A disabled prefetcher neither queues anything nor counts the request.
        let config = PrefetchConfig {
            enabled: false,
            ..Default::default()
        };
        let mut prefetcher = Prefetcher::new(config);
        prefetcher.build_call_graph(&blobs);
        prefetcher.prefetch(&a);

        let queue = prefetcher.get_prefetch_queue();
        assert!(queue.is_empty());
        assert_eq!(prefetcher.stats().prefetch_requests, 0);
    }

    #[test]
    fn test_prefetcher_depth_limit() {
        let a = make_hash(1);
        let b = make_hash(2);
        let c = make_hash(3);
        let d = make_hash(4);

        // Chain a -> b -> c -> d.
        let blobs = vec![(a, vec![b]), (b, vec![c]), (c, vec![d])];

        // With a depth limit of 2, `b` (1 hop) and `c` (2 hops) are queued,
        // but `d` (3 hops) is not.
        let config = PrefetchConfig {
            max_prefetch_depth: 2,
            top_n_callees: 4,
            min_probability: 0.0,
            enabled: true,
        };
        let mut prefetcher = Prefetcher::new(config);
        prefetcher.build_call_graph(&blobs);
        prefetcher.prefetch(&a);

        let queue = prefetcher.get_prefetch_queue();
        assert!(queue.contains(&b));
        assert!(queue.contains(&c));
        assert!(!queue.contains(&d));
    }

    #[test]
    fn test_prefetcher_get_queue_drains() {
        let a = make_hash(1);
        let b = make_hash(2);

        let blobs = vec![(a, vec![b])];

        let mut prefetcher = Prefetcher::new(PrefetchConfig::default());
        prefetcher.build_call_graph(&blobs);
        prefetcher.prefetch(&a);

        let queue1 = prefetcher.get_prefetch_queue();
        assert!(!queue1.is_empty());

        // The first call took everything, so a second read finds nothing.
        let queue2 = prefetcher.get_prefetch_queue();
        assert!(queue2.is_empty());
    }

    #[test]
    fn test_prefetcher_record_call() {
        let a = make_hash(1);
        let b = make_hash(2);
        let c = make_hash(3);

        // `b` and `c` start with equal static weight.
        let blobs = vec![(a, vec![b, c])];

        let mut prefetcher = Prefetcher::new(PrefetchConfig::default());
        prefetcher.build_call_graph(&blobs);

        // Recording observed calls to `c` should make it the top callee.
        prefetcher.record_call(&a, &c, 100);

        let top = prefetcher.call_graph.likely_callees(&a, 1);
        assert_eq!(top[0].callee_hash, c);
    }

    #[test]
    fn test_prefetcher_min_probability_filter() {
        let a = make_hash(1);
        let b = make_hash(2);

        // `b`'s only edge has static weight 1.0 and no dynamic weight.
        let blobs = vec![(a, vec![b])];

        // A 2.0 threshold is above that combined weight, so nothing is queued.
        let config = PrefetchConfig {
            min_probability: 2.0,
            top_n_callees: 4,
            max_prefetch_depth: 2,
            enabled: true,
        };
        let mut prefetcher = Prefetcher::new(config);
        prefetcher.build_call_graph(&blobs);
        prefetcher.prefetch(&a);

        let queue = prefetcher.get_prefetch_queue();
        assert!(queue.is_empty());
    }
}