Skip to main content

bonsai/profiler/
mod.rs

//! Profiler subsystem for the Bonsai adaptive spatial index.
//!
//! [`Profiler<C, D>`] orchestrates reservoir sampling, online statistics
//! computation, and query workload tracking. Observations are submitted
//! through a `std::sync::mpsc` channel and drained in batches of 64 when
//! [`Profiler::process_pending`] or [`Profiler::flush`] is called.
6
7pub mod cost_model;
8pub mod policy;
9pub mod reservoir;
10pub mod stats;
11
12use std::sync::mpsc::{self, Receiver, Sender};
13
14use crate::types::{CoordType, DataShape, Point, QueryMix};
15use reservoir::ReservoirSampler;
16use stats::OnlineStats;
17
18pub use cost_model::{CostEstimate, CostModel};
19pub use policy::{MigrationDecision, PolicyEngine};
20
/// The kind of query reported to the profiler in an [`Observation::Query`].
///
/// Besides `Debug`/`Clone`/`Copy`, the enum derives `PartialEq`, `Eq`, and
/// `Hash` so kinds can be compared directly (e.g. in assertions) and used as
/// keys in hash-based collections.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum QueryKind {
    /// A range query; tallied in the `range_*` workload counters.
    Range,
    /// A kNN query; tallied in the `knn_*` workload counters.
    Knn,
    /// A join query; tallied in the `join_*` workload counters.
    Join,
}
28
29/// An observation sent to the profiler via the MPSC channel.
30#[derive(Debug, Clone)]
31pub enum Observation<C: CoordType, const D: usize> {
32    /// A point was inserted into the index.
33    Insert(Point<C, D>),
34    /// A query was executed.
35    Query {
36        kind: QueryKind,
37        /// Fraction of total points returned (0.0–1.0).
38        selectivity: f64,
39        /// Whether the query returned any results.
40        hit: bool,
41    },
42}
43
/// Snapshot of the query workload history.
///
/// All counters start at zero; `Default` yields exactly that empty history.
/// `PartialEq` is derived so two snapshots can be compared directly.
#[derive(Debug, Clone, Default, PartialEq)]
pub struct WorkloadHistory {
    /// Number of range queries recorded.
    pub range_count: u64,
    /// Number of kNN queries recorded.
    pub knn_count: u64,
    /// Number of join queries recorded.
    pub join_count: u64,
    /// Number of range queries that reported a hit.
    pub range_hits: u64,
    /// Number of kNN queries that reported a hit.
    pub knn_hits: u64,
    /// Number of join queries that reported a hit.
    pub join_hits: u64,
    /// Running sum of the selectivity values recorded so far.
    pub total_selectivity: f64,
    /// Number of selectivity values folded into `total_selectivity`.
    pub selectivity_count: u64,
}
56
57impl WorkloadHistory {
58    fn new() -> Self {
59        Self {
60            range_count: 0,
61            knn_count: 0,
62            join_count: 0,
63            range_hits: 0,
64            knn_hits: 0,
65            join_hits: 0,
66            total_selectivity: 0.0,
67            selectivity_count: 0,
68        }
69    }
70
71    /// Compute a [`QueryMix`] from the accumulated history.
72    pub fn query_mix(&self) -> QueryMix {
73        let total = (self.range_count + self.knn_count + self.join_count) as f64;
74        if total < 1.0 {
75            return QueryMix::default();
76        }
77        let mean_selectivity = if self.selectivity_count > 0 {
78            self.total_selectivity / self.selectivity_count as f64
79        } else {
80            0.01
81        };
82        QueryMix {
83            range_frac: self.range_count as f64 / total,
84            knn_frac: self.knn_count as f64 / total,
85            join_frac: self.join_count as f64 / total,
86            mean_selectivity,
87        }
88    }
89
90    /// Per-type hit rates as `(range_hit_rate, knn_hit_rate, join_hit_rate)`.
91    pub fn hit_rates(&self) -> (f64, f64, f64) {
92        let range_hr = if self.range_count > 0 {
93            self.range_hits as f64 / self.range_count as f64
94        } else {
95            0.0
96        };
97        let knn_hr = if self.knn_count > 0 {
98            self.knn_hits as f64 / self.knn_count as f64
99        } else {
100            0.0
101        };
102        let join_hr = if self.join_count > 0 {
103            self.join_hits as f64 / self.join_count as f64
104        } else {
105            0.0
106        };
107        (range_hr, knn_hr, join_hr)
108    }
109}
110
111/// The Profiler orchestrates reservoir sampling, statistics computation, and
112/// query workload tracking.
113///
114/// Observations are sent via the [`Profiler::sender`] handle and processed
115/// in batches of 64 when [`Profiler::flush`] or [`Profiler::process_pending`]
116/// is called.
117pub struct Profiler<C: CoordType, const D: usize> {
118    sender: Sender<Observation<C, D>>,
119    receiver: Receiver<Observation<C, D>>,
120    sampler: ReservoirSampler<C, D>,
121    stats: OnlineStats<C, D>,
122    workload: WorkloadHistory,
123    last_shape: Option<DataShape<D>>,
124}
125
126/// Batch size for processing observations.
127const BATCH_SIZE: usize = 64;
128
129impl<C: CoordType, const D: usize> Profiler<C, D> {
130    /// Create a new `Profiler` with the given reservoir capacity.
131    pub fn new(reservoir_capacity: usize) -> Self {
132        let (sender, receiver) = mpsc::channel();
133        Self {
134            sender,
135            receiver,
136            sampler: ReservoirSampler::new(reservoir_capacity),
137            stats: OnlineStats::new(),
138            workload: WorkloadHistory::new(),
139            last_shape: None,
140        }
141    }
142
143    /// Create a `Profiler` with the default reservoir capacity of 4096.
144    pub fn default_capacity() -> Self {
145        Self::new(4096)
146    }
147
148    /// Return a clone of the sender handle for submitting observations.
149    ///
150    /// Multiple senders can be created from this handle for multi-producer use.
151    pub fn sender(&self) -> Sender<Observation<C, D>> {
152        self.sender.clone()
153    }
154
155    /// Submit an observation directly (bypasses the channel).
156    ///
157    /// Convenience method for single-threaded use. Call [`Profiler::flush`] to
158    /// force an immediate stats recompute.
159    pub fn observe(&mut self, obs: Observation<C, D>) {
160        self.process_observation(obs);
161    }
162
163    /// Process all pending observations from the channel in batches of 64.
164    ///
165    /// Returns the number of observations processed.
166    pub fn process_pending(&mut self) -> usize {
167        let mut total = 0;
168        loop {
169            let mut batch_count = 0;
170            while batch_count < BATCH_SIZE {
171                match self.receiver.try_recv() {
172                    Ok(obs) => {
173                        self.process_observation(obs);
174                        batch_count += 1;
175                    }
176                    Err(_) => break,
177                }
178            }
179            total += batch_count;
180            if batch_count < BATCH_SIZE {
181                break;
182            }
183        }
184        if total > 0 {
185            self.recompute_stats();
186        }
187        total
188    }
189
190    /// Flush all pending channel observations and recompute statistics.
191    pub fn flush(&mut self) {
192        self.process_pending();
193        self.recompute_stats();
194    }
195
196    /// Return the last computed [`DataShape`], or `None` if no data has been observed.
197    pub fn data_shape(&self) -> Option<&DataShape<D>> {
198        self.last_shape.as_ref()
199    }
200
201    /// Return a reference to the query workload history.
202    pub fn workload(&self) -> &WorkloadHistory {
203        &self.workload
204    }
205
206    /// Return the number of points in the reservoir.
207    pub fn reservoir_len(&self) -> usize {
208        self.sampler.len()
209    }
210
211    /// Return the total number of points observed.
212    pub fn total_observed(&self) -> usize {
213        self.sampler.total_count()
214    }
215
216    // ── Private helpers ───────────────────────────────────────────────────────
217
218    fn process_observation(&mut self, obs: Observation<C, D>) {
219        match obs {
220            Observation::Insert(point) => {
221                self.sampler.update(point);
222            }
223            Observation::Query {
224                kind,
225                selectivity,
226                hit,
227            } => {
228                self.workload.total_selectivity += selectivity;
229                self.workload.selectivity_count += 1;
230                match kind {
231                    QueryKind::Range => {
232                        self.workload.range_count += 1;
233                        if hit {
234                            self.workload.range_hits += 1;
235                        }
236                    }
237                    QueryKind::Knn => {
238                        self.workload.knn_count += 1;
239                        if hit {
240                            self.workload.knn_hits += 1;
241                        }
242                    }
243                    QueryKind::Join => {
244                        self.workload.join_count += 1;
245                        if hit {
246                            self.workload.join_hits += 1;
247                        }
248                    }
249                }
250            }
251        }
252    }
253
254    fn recompute_stats(&mut self) {
255        let query_mix = self.workload.query_mix();
256        self.last_shape = self.stats.compute(self.sampler.samples(), query_mix);
257    }
258}
259
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::Point;

    /// Concrete profiler type exercised by every test.
    type TestProfiler = Profiler<f64, 2>;

    /// Queue `count` insert observations on the profiler's channel, using a
    /// sender that is dropped before returning.
    fn queue_inserts(profiler: &TestProfiler, count: usize) {
        let tx = profiler.sender();
        for i in 0..count {
            tx.send(Observation::Insert(Point::new([i as f64, 0.0])))
                .unwrap();
        }
    }

    /// Record `count` identical query observations directly on the profiler.
    fn record_queries(
        profiler: &mut TestProfiler,
        count: usize,
        kind: QueryKind,
        selectivity: f64,
        hit: bool,
    ) {
        for _ in 0..count {
            profiler.observe(Observation::Query {
                kind,
                selectivity,
                hit,
            });
        }
    }

    #[test]
    fn profiler_processes_inserts() {
        let mut profiler = TestProfiler::new(100);
        for i in 0..50 {
            let coord = i as f64;
            profiler.observe(Observation::Insert(Point::new([coord, coord])));
        }
        assert_eq!(profiler.reservoir_len(), 50);
        assert_eq!(profiler.total_observed(), 50);
    }

    #[test]
    fn profiler_channel_processing() {
        let mut profiler = TestProfiler::new(100);

        // Deliver twice the reservoir capacity through the channel.
        queue_inserts(&profiler, 200);

        assert_eq!(profiler.process_pending(), 200);
        // The reservoir is capped at its capacity.
        assert_eq!(profiler.reservoir_len(), 100);
    }

    #[test]
    fn profiler_tracks_query_workload() {
        let mut profiler = TestProfiler::new(100);

        // Some points first, then a 60/40 range/kNN workload.
        for i in 0..50 {
            profiler.observe(Observation::Insert(Point::new([i as f64, 0.0])));
        }
        record_queries(&mut profiler, 60, QueryKind::Range, 0.01, true);
        record_queries(&mut profiler, 40, QueryKind::Knn, 0.001, false);

        let mix = profiler.workload().query_mix();
        assert!((mix.range_frac - 0.6).abs() < 0.01);
        assert!((mix.knn_frac - 0.4).abs() < 0.01);
        assert!((mix.join_frac).abs() < 0.01);

        let (range_hr, knn_hr, _) = profiler.workload().hit_rates();
        assert!((range_hr - 1.0).abs() < 0.01);
        assert!((knn_hr - 0.0).abs() < 0.01);
    }

    #[test]
    fn profiler_data_shape_computed_after_flush() {
        let mut profiler = TestProfiler::new(100);
        for i in 0..100 {
            let coord = i as f64;
            profiler.observe(Observation::Insert(Point::new([coord, coord])));
        }

        // Nothing is queued on the channel; the flush alone must produce a shape.
        profiler.flush();

        assert!(profiler.data_shape().is_some());
        let shape = profiler.data_shape().unwrap();
        assert_eq!(shape.point_count, 100);
    }

    #[test]
    fn profiler_batch_size_is_64() {
        // Two full batches of 64: everything queued must be drained in a
        // single `process_pending` call.
        let mut profiler = TestProfiler::new(1000);
        queue_inserts(&profiler, 128);

        assert_eq!(profiler.process_pending(), 128);
    }
}