Skip to main content

shape_runtime/
blob_prefetch.rs

//! Speculative prefetch for content-addressed function blobs.
//!
//! Builds a call probability graph from `FunctionBlob` dependencies.
//! On function entry, the top-N likely callees are enqueued so that the
//! cache layer can warm the blob cache and JIT cache ahead of execution.
6
7use std::collections::HashMap;
8use std::sync::{Arc, Mutex};
9
/// Call probability graph built from blob dependencies.
///
/// Maps each caller's function hash to its outgoing [`CallEdge`]s. It also
/// keeps the raw dynamic call counts needed to normalize dynamic weights.
#[derive(Debug, Clone)]
pub struct CallGraph {
    /// For each function hash, the functions it may call, sorted at build
    /// time by static weight (descending) with the callee hash as tie-break.
    edges: HashMap<[u8; 32], Vec<CallEdge>>,
    /// Latest raw call count recorded per `caller -> callee` edge.
    ///
    /// Kept so that [`CallGraph::update_dynamic_weight`] can renormalize
    /// every edge of a caller against the true total of observed calls,
    /// instead of mixing raw counts with already-normalized weights.
    dynamic_counts: HashMap<[u8; 32], HashMap<[u8; 32], u64>>,
}

/// A weighted `caller -> callee` edge in a [`CallGraph`].
#[derive(Debug, Clone)]
pub struct CallEdge {
    /// Content hash of the called function.
    pub callee_hash: [u8; 32],
    /// Share of the caller's static call sites that target this callee.
    ///
    /// As built by [`CallGraph::build_from_dependencies`], the static
    /// weights of one caller's edges sum to 1.0.
    pub static_weight: f32,
    /// Share of the caller's recorded dynamic calls that target this callee.
    ///
    /// It is 0.0 until counts are recorded via
    /// [`CallGraph::update_dynamic_weight`].
    pub dynamic_weight: f32,
}

/// Prefetch configuration.
#[derive(Debug, Clone)]
pub struct PrefetchConfig {
    /// Number of call-graph levels below the entered function to walk.
    pub max_prefetch_depth: usize,
    /// Maximum number of best-ranked callees considered per visited function.
    pub top_n_callees: usize,
    /// Minimum combined (static + dynamic) edge weight for a callee to be
    /// prefetched.
    ///
    /// Each component is at most 1.0 as computed by [`CallGraph`], so the
    /// combined value lies in `0.0..=2.0`.
    pub min_probability: f32,
    /// When `false`, prefetch requests are ignored.
    pub enabled: bool,
}

impl Default for PrefetchConfig {
    /// Depth 2, top 4 callees, minimum combined weight 0.1, enabled.
    fn default() -> Self {
        Self {
            max_prefetch_depth: 2,
            top_n_callees: 4,
            min_probability: 0.1,
            enabled: true,
        }
    }
}

/// Speculative prefetcher that warms caches ahead of execution.
///
/// The prefetcher does not load anything itself. Candidate hashes are
/// appended to an internal queue, which a consumer drains.
#[derive(Debug)]
pub struct Prefetcher {
    call_graph: CallGraph,
    config: PrefetchConfig,
    prefetch_queue: Arc<Mutex<Vec<[u8; 32]>>>,
    stats: PrefetchStats,
}

/// Counters describing prefetcher activity.
#[derive(Debug, Default, Clone)]
pub struct PrefetchStats {
    /// Number of prefetch requests handled while prefetching was enabled.
    pub prefetch_requests: u64,
    /// By its name, meant to count cache hits served by prefetched blobs.
    /// Nothing in this module increments it.
    pub cache_hits_from_prefetch: u64,
    /// By its name, meant to count prefetched blobs that went unused.
    /// Nothing in this module increments it.
    pub wasted_prefetches: u64,
}

impl CallGraph {
    /// Create an empty call graph.
    pub fn new() -> Self {
        Self {
            edges: HashMap::new(),
            dynamic_counts: HashMap::new(),
        }
    }

    /// Build the call graph from blob dependency information.
    ///
    /// Each entry is a function hash paired with the list of function hashes
    /// it may call, one element per call site (so repeats are meaningful).
    /// A callee's static weight is its share of the caller's call sites.
    /// Dynamic weights start at 0.0. If a caller appears more than once,
    /// the last entry wins.
    ///
    /// Edges are sorted by static weight (descending), with ties broken by
    /// callee hash. The result therefore does not depend on `HashMap`
    /// iteration order.
    pub fn build_from_dependencies(blobs: &[([u8; 32], Vec<[u8; 32]>)]) -> Self {
        let mut edges: HashMap<[u8; 32], Vec<CallEdge>> = HashMap::with_capacity(blobs.len());

        for (caller, callees) in blobs {
            // Count call sites per callee to derive static weights.
            let mut counts: HashMap<[u8; 32], usize> = HashMap::new();
            for callee in callees {
                *counts.entry(*callee).or_insert(0) += 1;
            }
            // Every call site contributes exactly one count. An empty list
            // yields no edges, so this is never used as a zero divisor.
            let total = callees.len() as f32;

            let mut call_edges: Vec<CallEdge> = counts
                .into_iter()
                .map(|(callee_hash, count)| CallEdge {
                    callee_hash,
                    static_weight: count as f32 / total,
                    dynamic_weight: 0.0,
                })
                .collect();

            // Heaviest first; the hash tie-break makes the order deterministic.
            call_edges.sort_by(|a, b| {
                b.static_weight
                    .partial_cmp(&a.static_weight)
                    .unwrap_or(std::cmp::Ordering::Equal)
                    .then_with(|| a.callee_hash.cmp(&b.callee_hash))
            });

            edges.insert(*caller, call_edges);
        }

        Self {
            edges,
            dynamic_counts: HashMap::new(),
        }
    }

    /// Return the top-N most likely callees for a given function hash.
    ///
    /// Edges are ranked by combined weight (static + dynamic), descending,
    /// with ties broken by callee hash. Returns an empty vec if the function
    /// hash is not in the graph or `top_n` is 0.
    pub fn likely_callees(&self, hash: &[u8; 32], top_n: usize) -> Vec<CallEdge> {
        let mut ranked = match self.edges.get(hash) {
            Some(edges) => edges.clone(),
            None => return Vec::new(),
        };
        ranked.sort_by(|a, b| {
            let wa = a.static_weight + a.dynamic_weight;
            let wb = b.static_weight + b.dynamic_weight;
            wb.partial_cmp(&wa)
                .unwrap_or(std::cmp::Ordering::Equal)
                .then_with(|| a.callee_hash.cmp(&b.callee_hash))
        });
        ranked.truncate(top_n);
        ranked
    }

    /// Record the observed call count for a `caller -> callee` edge and
    /// renormalize the caller's dynamic weights.
    ///
    /// `count` is the latest observed number of calls along this edge. It
    /// replaces any count previously recorded for the same edge; it is not
    /// added to it.
    ///
    /// Afterwards, every edge of `caller` has
    /// `dynamic_weight = edge_count / sum(edge_counts)`. The caller's dynamic
    /// weights therefore sum to 1.0, or are all 0.0 when every count is zero.
    ///
    /// Updates for an unknown caller, or for a callee that is not one of the
    /// caller's edges, are ignored.
    pub fn update_dynamic_weight(&mut self, caller: &[u8; 32], callee: &[u8; 32], count: u64) {
        let edges = match self.edges.get_mut(caller) {
            Some(edges) => edges,
            None => return,
        };
        if !edges.iter().any(|e| &e.callee_hash == callee) {
            return;
        }

        let counts = self.dynamic_counts.entry(*caller).or_default();
        counts.insert(*callee, count);

        // Sum in f64 so that very large u64 counts cannot overflow.
        let total: f64 = counts.values().map(|&c| c as f64).sum();

        for edge in edges.iter_mut() {
            let edge_count = counts.get(&edge.callee_hash).copied().unwrap_or(0);
            edge.dynamic_weight = if total > 0.0 {
                (edge_count as f64 / total) as f32
            } else {
                0.0
            };
        }
    }
}
154
155impl Default for CallGraph {
156    fn default() -> Self {
157        Self::new()
158    }
159}
160
161impl Prefetcher {
162    /// Create a new prefetcher with the given configuration and an empty call graph.
163    pub fn new(config: PrefetchConfig) -> Self {
164        Self {
165            call_graph: CallGraph::new(),
166            config,
167            prefetch_queue: Arc::new(Mutex::new(Vec::new())),
168            stats: PrefetchStats::default(),
169        }
170    }
171
172    /// Build (or replace) the internal call graph from blob dependency data.
173    pub fn build_call_graph(&mut self, blobs: &[([u8; 32], Vec<[u8; 32]>)]) {
174        self.call_graph = CallGraph::build_from_dependencies(blobs);
175    }
176
177    /// Enqueue the top-N likely callees for speculative prefetch.
178    ///
179    /// Walks up to `max_prefetch_depth` levels in the call graph, collecting
180    /// callees whose combined weight exceeds `min_probability`. Hashes are
181    /// appended to the internal prefetch queue for the cache layer to consume.
182    pub fn prefetch(&mut self, function_hash: &[u8; 32]) {
183        if !self.config.enabled {
184            return;
185        }
186
187        self.stats.prefetch_requests += 1;
188
189        let mut to_visit = vec![(*function_hash, 0usize)];
190        let mut enqueued = std::collections::HashSet::new();
191
192        while let Some((hash, depth)) = to_visit.pop() {
193            if depth >= self.config.max_prefetch_depth {
194                continue;
195            }
196
197            let callees = self
198                .call_graph
199                .likely_callees(&hash, self.config.top_n_callees);
200
201            for edge in &callees {
202                let combined = edge.static_weight + edge.dynamic_weight;
203                if combined < self.config.min_probability {
204                    continue;
205                }
206                if enqueued.insert(edge.callee_hash) {
207                    to_visit.push((edge.callee_hash, depth + 1));
208                }
209            }
210        }
211
212        if !enqueued.is_empty() {
213            let mut queue = self.prefetch_queue.lock().unwrap();
214            for hash in enqueued {
215                queue.push(hash);
216            }
217        }
218    }
219
220    /// Consume and return all hashes currently in the prefetch queue.
221    pub fn get_prefetch_queue(&self) -> Vec<[u8; 32]> {
222        let mut queue = self.prefetch_queue.lock().unwrap();
223        std::mem::take(&mut *queue)
224    }
225
226    /// Record an observed call from `caller` to `callee`, updating dynamic weights.
227    pub fn record_call(&mut self, caller: &[u8; 32], callee: &[u8; 32], count: u64) {
228        self.call_graph.update_dynamic_weight(caller, callee, count);
229    }
230
231    /// Return a reference to the current prefetch statistics.
232    pub fn stats(&self) -> &PrefetchStats {
233        &self.stats
234    }
235}
236
#[cfg(test)]
mod tests {
    use super::*;

    /// A distinct 32-byte hash: `tag` in the first byte, zeros elsewhere.
    fn make_hash(tag: u8) -> [u8; 32] {
        let mut bytes = [0u8; 32];
        bytes[0] = tag;
        bytes
    }

    /// A prefetcher using `config`, with its call graph built from `blobs`.
    fn prefetcher_over(config: PrefetchConfig, blobs: &[([u8; 32], Vec<[u8; 32]>)]) -> Prefetcher {
        let mut p = Prefetcher::new(config);
        p.build_call_graph(blobs);
        p
    }

    #[test]
    fn test_call_graph_empty() {
        let callees = CallGraph::new().likely_callees(&make_hash(1), 4);
        assert!(callees.is_empty());
    }

    #[test]
    fn test_call_graph_build_and_likely_callees() {
        let (a, b, c) = (make_hash(1), make_hash(2), make_hash(3));

        // a calls b twice and c once, so b carries the larger static weight.
        let graph = CallGraph::build_from_dependencies(&[(a, vec![b, b, c])]);
        let ranked = graph.likely_callees(&a, 2);
        assert_eq!(ranked.len(), 2);

        let expected: [([u8; 32], f32); 2] = [(b, 2.0 / 3.0), (c, 1.0 / 3.0)];
        for (edge, (hash, weight)) in ranked.iter().zip(expected.iter()) {
            assert_eq!(edge.callee_hash, *hash);
            assert!((edge.static_weight - *weight).abs() < 1e-5);
        }
    }

    #[test]
    fn test_call_graph_top_n_truncation() {
        let (a, b, c, d) = (make_hash(1), make_hash(2), make_hash(3), make_hash(4));
        let graph = CallGraph::build_from_dependencies(&[(a, vec![b, c, d])]);
        assert_eq!(graph.likely_callees(&a, 1).len(), 1);
    }

    #[test]
    fn test_dynamic_weight_update() {
        let (a, b, c) = (make_hash(1), make_hash(2), make_hash(3));
        let mut graph = CallGraph::build_from_dependencies(&[(a, vec![b, c])]);

        // Static weights start equal (0.5 each); a dynamic boost should put b first.
        graph.update_dynamic_weight(&a, &b, 10);

        let ranked = graph.likely_callees(&a, 2);
        let leader = &ranked[0];
        assert_eq!(leader.callee_hash, b);
        assert!(leader.dynamic_weight > 0.0);
    }

    #[test]
    fn test_prefetcher_basic() {
        let (a, b, c) = (make_hash(1), make_hash(2), make_hash(3));
        let mut prefetcher = prefetcher_over(PrefetchConfig::default(), &[(a, vec![b, c])]);
        prefetcher.prefetch(&a);

        let queued = prefetcher.get_prefetch_queue();
        assert!(!queued.is_empty());
        for wanted in [b, c].iter() {
            assert!(queued.contains(wanted));
        }
        assert_eq!(prefetcher.stats().prefetch_requests, 1);
    }

    #[test]
    fn test_prefetcher_disabled() {
        let (a, b) = (make_hash(1), make_hash(2));
        let disabled = PrefetchConfig {
            enabled: false,
            ..Default::default()
        };
        let mut prefetcher = prefetcher_over(disabled, &[(a, vec![b])]);
        prefetcher.prefetch(&a);

        assert!(prefetcher.get_prefetch_queue().is_empty());
        assert_eq!(prefetcher.stats().prefetch_requests, 0);
    }

    #[test]
    fn test_prefetcher_depth_limit() {
        let (a, b, c, d) = (make_hash(1), make_hash(2), make_hash(3), make_hash(4));

        // Chain a -> b -> c -> d: a depth limit of 2 reaches b and c but not d.
        let chain = [(a, vec![b]), (b, vec![c]), (c, vec![d])];
        let config = PrefetchConfig {
            max_prefetch_depth: 2,
            top_n_callees: 4,
            min_probability: 0.0,
            enabled: true,
        };
        let mut prefetcher = prefetcher_over(config, &chain);
        prefetcher.prefetch(&a);

        let queued = prefetcher.get_prefetch_queue();
        assert!(queued.contains(&b));
        assert!(queued.contains(&c));
        assert!(!queued.contains(&d));
    }

    #[test]
    fn test_prefetcher_get_queue_drains() {
        let (a, b) = (make_hash(1), make_hash(2));
        let mut prefetcher = prefetcher_over(PrefetchConfig::default(), &[(a, vec![b])]);
        prefetcher.prefetch(&a);

        let first_drain = prefetcher.get_prefetch_queue();
        assert!(!first_drain.is_empty());

        // Draining again must yield nothing: the first call emptied the queue.
        let second_drain = prefetcher.get_prefetch_queue();
        assert!(second_drain.is_empty());
    }

    #[test]
    fn test_prefetcher_record_call() {
        let (a, b, c) = (make_hash(1), make_hash(2), make_hash(3));
        let mut prefetcher = prefetcher_over(PrefetchConfig::default(), &[(a, vec![b, c])]);

        // Many recorded calls to c raise its dynamic weight above b's.
        prefetcher.record_call(&a, &c, 100);

        let ranked = prefetcher.call_graph.likely_callees(&a, 1);
        assert_eq!(ranked[0].callee_hash, c);
    }

    #[test]
    fn test_prefetcher_min_probability_filter() {
        let (a, b) = (make_hash(1), make_hash(2));

        // The only callee has static weight 1.0, below the threshold.
        let config = PrefetchConfig {
            min_probability: 2.0, // impossibly high threshold
            top_n_callees: 4,
            max_prefetch_depth: 2,
            enabled: true,
        };
        let mut prefetcher = prefetcher_over(config, &[(a, vec![b])]);
        prefetcher.prefetch(&a);

        assert!(prefetcher.get_prefetch_queue().is_empty());
    }
}