1pub mod cost_model;
8pub mod policy;
9pub mod reservoir;
10pub mod stats;
11
12use std::sync::mpsc::{self, Receiver, Sender};
13
14use crate::types::{CoordType, DataShape, Point, QueryMix};
15use reservoir::ReservoirSampler;
16use stats::OnlineStats;
17
18pub use cost_model::{CostEstimate, CostModel};
19pub use policy::{MigrationDecision, PolicyEngine};
20
/// Category of query reported to the profiler.
///
/// The profiler keeps one query counter and one hit counter per kind.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum QueryKind {
    /// Range query.
    Range,
    /// k-nearest-neighbour query.
    Knn,
    /// Join query.
    Join,
}
28
29#[derive(Debug, Clone)]
31pub enum Observation<C: CoordType, const D: usize> {
32 Insert(Point<C, D>),
34 Query {
36 kind: QueryKind,
37 selectivity: f64,
39 hit: bool,
41 },
42}
43
44#[derive(Debug, Clone)]
46pub struct WorkloadHistory {
47 pub range_count: u64,
48 pub knn_count: u64,
49 pub join_count: u64,
50 pub range_hits: u64,
51 pub knn_hits: u64,
52 pub join_hits: u64,
53 pub total_selectivity: f64,
54 pub selectivity_count: u64,
55}
56
57impl WorkloadHistory {
58 fn new() -> Self {
59 Self {
60 range_count: 0,
61 knn_count: 0,
62 join_count: 0,
63 range_hits: 0,
64 knn_hits: 0,
65 join_hits: 0,
66 total_selectivity: 0.0,
67 selectivity_count: 0,
68 }
69 }
70
71 pub fn query_mix(&self) -> QueryMix {
73 let total = (self.range_count + self.knn_count + self.join_count) as f64;
74 if total < 1.0 {
75 return QueryMix::default();
76 }
77 let mean_selectivity = if self.selectivity_count > 0 {
78 self.total_selectivity / self.selectivity_count as f64
79 } else {
80 0.01
81 };
82 QueryMix {
83 range_frac: self.range_count as f64 / total,
84 knn_frac: self.knn_count as f64 / total,
85 join_frac: self.join_count as f64 / total,
86 mean_selectivity,
87 }
88 }
89
90 pub fn hit_rates(&self) -> (f64, f64, f64) {
92 let range_hr = if self.range_count > 0 {
93 self.range_hits as f64 / self.range_count as f64
94 } else {
95 0.0
96 };
97 let knn_hr = if self.knn_count > 0 {
98 self.knn_hits as f64 / self.knn_count as f64
99 } else {
100 0.0
101 };
102 let join_hr = if self.join_count > 0 {
103 self.join_hits as f64 / self.join_count as f64
104 } else {
105 0.0
106 };
107 (range_hr, knn_hr, join_hr)
108 }
109}
110
111pub struct Profiler<C: CoordType, const D: usize> {
118 sender: Sender<Observation<C, D>>,
119 receiver: Receiver<Observation<C, D>>,
120 sampler: ReservoirSampler<C, D>,
121 stats: OnlineStats<C, D>,
122 workload: WorkloadHistory,
123 last_shape: Option<DataShape<D>>,
124}
125
/// Maximum number of queued observations pulled from the channel in one
/// inner drain pass.
const BATCH_SIZE: usize = 64;
128
129impl<C: CoordType, const D: usize> Profiler<C, D> {
130 pub fn new(reservoir_capacity: usize) -> Self {
132 let (sender, receiver) = mpsc::channel();
133 Self {
134 sender,
135 receiver,
136 sampler: ReservoirSampler::new(reservoir_capacity),
137 stats: OnlineStats::new(),
138 workload: WorkloadHistory::new(),
139 last_shape: None,
140 }
141 }
142
143 pub fn default_capacity() -> Self {
145 Self::new(4096)
146 }
147
148 pub fn sender(&self) -> Sender<Observation<C, D>> {
152 self.sender.clone()
153 }
154
155 pub fn observe(&mut self, obs: Observation<C, D>) {
160 self.process_observation(obs);
161 }
162
163 pub fn process_pending(&mut self) -> usize {
167 let mut total = 0;
168 loop {
169 let mut batch_count = 0;
170 while batch_count < BATCH_SIZE {
171 match self.receiver.try_recv() {
172 Ok(obs) => {
173 self.process_observation(obs);
174 batch_count += 1;
175 }
176 Err(_) => break,
177 }
178 }
179 total += batch_count;
180 if batch_count < BATCH_SIZE {
181 break;
182 }
183 }
184 if total > 0 {
185 self.recompute_stats();
186 }
187 total
188 }
189
190 pub fn flush(&mut self) {
192 self.process_pending();
193 self.recompute_stats();
194 }
195
196 pub fn data_shape(&self) -> Option<&DataShape<D>> {
198 self.last_shape.as_ref()
199 }
200
201 pub fn workload(&self) -> &WorkloadHistory {
203 &self.workload
204 }
205
206 pub fn reservoir_len(&self) -> usize {
208 self.sampler.len()
209 }
210
211 pub fn total_observed(&self) -> usize {
213 self.sampler.total_count()
214 }
215
216 fn process_observation(&mut self, obs: Observation<C, D>) {
219 match obs {
220 Observation::Insert(point) => {
221 self.sampler.update(point);
222 }
223 Observation::Query {
224 kind,
225 selectivity,
226 hit,
227 } => {
228 self.workload.total_selectivity += selectivity;
229 self.workload.selectivity_count += 1;
230 match kind {
231 QueryKind::Range => {
232 self.workload.range_count += 1;
233 if hit {
234 self.workload.range_hits += 1;
235 }
236 }
237 QueryKind::Knn => {
238 self.workload.knn_count += 1;
239 if hit {
240 self.workload.knn_hits += 1;
241 }
242 }
243 QueryKind::Join => {
244 self.workload.join_count += 1;
245 if hit {
246 self.workload.join_hits += 1;
247 }
248 }
249 }
250 }
251 }
252 }
253
254 fn recompute_stats(&mut self) {
255 let query_mix = self.workload.query_mix();
256 self.last_shape = self.stats.compute(self.sampler.samples(), query_mix);
257 }
258}
259
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::Point;

    /// Builds a 2-D insert observation at `(x, y)`.
    fn insert(x: f64, y: f64) -> Observation<f64, 2> {
        Observation::Insert(Point::new([x, y]))
    }

    /// Records `n` identical query observations directly on `profiler`.
    fn observe_queries(
        profiler: &mut Profiler<f64, 2>,
        n: usize,
        kind: QueryKind,
        selectivity: f64,
        hit: bool,
    ) {
        for _ in 0..n {
            profiler.observe(Observation::Query {
                kind,
                selectivity,
                hit,
            });
        }
    }

    /// Inserts recorded through `observe` reach the sampler.
    #[test]
    fn profiler_processes_inserts() {
        let mut profiler = Profiler::<f64, 2>::new(100);
        for v in (0..50).map(|i| i as f64) {
            profiler.observe(insert(v, v));
        }
        assert_eq!(profiler.reservoir_len(), 50);
        assert_eq!(profiler.total_observed(), 50);
    }

    /// Observations queued on the channel are all drained by
    /// `process_pending`.
    #[test]
    fn profiler_channel_processing() {
        let mut profiler = Profiler::<f64, 2>::new(100);
        let tx = profiler.sender();

        for x in (0..200).map(|i| i as f64) {
            tx.send(insert(x, 0.0)).unwrap();
        }
        drop(tx);

        assert_eq!(profiler.process_pending(), 200);
        // The reservoir was created with capacity 100.
        assert_eq!(profiler.reservoir_len(), 100);
    }

    /// Query observations update the mix fractions and the hit rates.
    #[test]
    fn profiler_tracks_query_workload() {
        let mut profiler = Profiler::<f64, 2>::new(100);

        for x in (0..50).map(|i| i as f64) {
            profiler.observe(insert(x, 0.0));
        }

        observe_queries(&mut profiler, 60, QueryKind::Range, 0.01, true);
        observe_queries(&mut profiler, 40, QueryKind::Knn, 0.001, false);

        let mix = profiler.workload().query_mix();
        assert!((mix.range_frac - 0.6).abs() < 0.01);
        assert!((mix.knn_frac - 0.4).abs() < 0.01);
        assert!(mix.join_frac.abs() < 0.01);

        let (range_hr, knn_hr, _) = profiler.workload().hit_rates();
        assert!((range_hr - 1.0).abs() < 0.01);
        assert!(knn_hr.abs() < 0.01);
    }

    /// `flush` makes a data shape available for directly observed inserts.
    #[test]
    fn profiler_data_shape_computed_after_flush() {
        let mut profiler = Profiler::<f64, 2>::new(100);
        for v in (0..100).map(|i| i as f64) {
            profiler.observe(insert(v, v));
        }
        profiler.flush();

        let shape = profiler.data_shape();
        assert!(shape.is_some());
        assert_eq!(shape.unwrap().point_count, 100);
    }

    /// Draining exactly two full batches (2 x 64) processes every item.
    #[test]
    fn profiler_batch_size_is_64() {
        let mut profiler = Profiler::<f64, 2>::new(1000);
        let tx = profiler.sender();

        for x in (0..128).map(|i| i as f64) {
            tx.send(insert(x, 0.0)).unwrap();
        }
        drop(tx);

        assert_eq!(profiler.process_pending(), 128);
    }
}