heliosdb_proxy/schema_routing/
classifier.rs1use dashmap::DashMap;
6use parking_lot::RwLock;
7use std::collections::HashMap;
8use std::sync::Arc;
9use std::time::{Duration, Instant};
10
11use super::registry::{DataTemperature, SchemaRegistry, WorkloadType};
12
13#[derive(Debug)]
15pub struct LearningClassifier {
16 history: DashMap<String, QueryHistory>,
18 model: Arc<RwLock<ClassificationModel>>,
20 schema: Arc<SchemaRegistry>,
22 config: ClassifierConfig,
24}
25
26impl LearningClassifier {
27 pub fn new(schema: Arc<SchemaRegistry>) -> Self {
29 Self {
30 history: DashMap::new(),
31 model: Arc::new(RwLock::new(ClassificationModel::new())),
32 schema,
33 config: ClassifierConfig::default(),
34 }
35 }
36
37 pub fn with_config(schema: Arc<SchemaRegistry>, config: ClassifierConfig) -> Self {
39 Self {
40 history: DashMap::new(),
41 model: Arc::new(RwLock::new(ClassificationModel::new())),
42 schema,
43 config,
44 }
45 }
46
47 pub fn record(&self, table: &str, query_type: QueryType, latency: Duration) {
49 let mut history = self.history.entry(table.to_string()).or_default();
50
51 history.record(query_type, latency);
52
53 if history.count() % self.config.reclassification_threshold == 0 {
55 self.reclassify(table);
56 }
57 }
58
59 pub fn reclassify(&self, table: &str) {
61 let history = match self.history.get(table) {
62 Some(h) => h.clone(),
63 None => return,
64 };
65
66 let model = self.model.read();
67
68 let temperature = model.classify_temperature(&history);
70
71 let workload = model.classify_workload(&history);
73
74 self.schema
76 .update_classification(table, temperature, workload);
77 }
78
79 pub fn get_classification(&self, table: &str) -> Option<TableClassification> {
81 let history = self.history.get(table)?;
82 let model = self.model.read();
83
84 Some(TableClassification {
85 table: table.to_string(),
86 temperature: model.classify_temperature(&history),
87 workload: model.classify_workload(&history),
88 confidence: model.classification_confidence(&history),
89 query_count: history.count(),
90 last_updated: history.last_updated(),
91 })
92 }
93
94 pub fn all_classifications(&self) -> Vec<TableClassification> {
96 self.history
97 .iter()
98 .map(|entry| {
99 let table = entry.key();
100 let history = entry.value();
101 let model = self.model.read();
102
103 TableClassification {
104 table: table.clone(),
105 temperature: model.classify_temperature(history),
106 workload: model.classify_workload(history),
107 confidence: model.classification_confidence(history),
108 query_count: history.count(),
109 last_updated: history.last_updated(),
110 }
111 })
112 .collect()
113 }
114
115 pub fn update_thresholds(&self, thresholds: ModelThresholds) {
117 let mut model = self.model.write();
118 model.thresholds = thresholds;
119 }
120
121 pub fn get_history(&self, table: &str) -> Option<QueryHistory> {
123 self.history.get(table).map(|h| h.clone())
124 }
125
126 pub fn clear_history(&self, table: &str) {
128 self.history.remove(table);
129 }
130
131 pub fn clear_all(&self) {
133 self.history.clear();
134 }
135
136 pub fn query_count(&self) -> u64 {
138 self.history.iter().map(|h| h.value().count()).sum()
139 }
140
141 pub fn suggest_temperature(&self, table: &str) -> Option<DataTemperature> {
143 let history = self.history.get(table)?;
144 let model = self.model.read();
145 Some(model.classify_temperature(&history))
146 }
147
148 pub fn suggest_workload(&self, table: &str) -> Option<WorkloadType> {
150 let history = self.history.get(table)?;
151 let model = self.model.read();
152 Some(model.classify_workload(&history))
153 }
154
155 pub fn get_confidence(&self, table: &str) -> Option<f64> {
157 let history = self.history.get(table)?;
158 let model = self.model.read();
159 Some(model.classification_confidence(&history))
160 }
161
162 pub fn classify_query(&self, sql: &str) -> Option<WorkloadType> {
164 let query_type = QueryType::from_sql(sql);
165
166 Some(match query_type {
167 QueryType::VectorSearch => WorkloadType::Vector,
168 QueryType::AggregateSelect | QueryType::JoinSelect => WorkloadType::OLAP,
169 QueryType::SimpleSelect => WorkloadType::OLTP,
170 QueryType::Insert | QueryType::Update | QueryType::Delete => WorkloadType::OLTP,
171 })
172 }
173}
174
/// Tuning parameters for [`LearningClassifier`].
#[derive(Debug, Clone)]
pub struct ClassifierConfig {
    /// A table is reclassified each time its recorded query count becomes a multiple of
    /// this value.
    pub reclassification_threshold: u64,
    /// NOTE(review): not read anywhere in this module; presumably a rate-measurement
    /// window — confirm with callers.
    pub rate_window: Duration,
    /// NOTE(review): not read anywhere in this module; presumably a minimum sample size
    /// before classifying — confirm with callers.
    pub min_queries: u64,
}

impl Default for ClassifierConfig {
    /// Reclassify every 1 000 queries, with a 60 s rate window and 100 minimum queries.
    fn default() -> Self {
        Self {
            min_queries: 100,
            rate_window: Duration::from_secs(60),
            reclassification_threshold: 1_000,
        }
    }
}
195
196#[derive(Debug, Clone)]
198pub struct QueryHistory {
199 total_count: u64,
201 read_count: u64,
203 write_count: u64,
205 type_counts: HashMap<QueryType, u64>,
207 latencies: Vec<Duration>,
209 qpm_samples: Vec<(Instant, u64)>,
211 #[allow(dead_code)]
213 created: Instant,
214 last_updated: Instant,
216}
217
218impl QueryHistory {
219 pub fn new() -> Self {
221 let now = Instant::now();
222 Self {
223 total_count: 0,
224 read_count: 0,
225 write_count: 0,
226 type_counts: HashMap::new(),
227 latencies: Vec::new(),
228 qpm_samples: Vec::new(),
229 created: now,
230 last_updated: now,
231 }
232 }
233
234 pub fn record(&mut self, query_type: QueryType, latency: Duration) {
236 self.total_count += 1;
237 self.last_updated = Instant::now();
238
239 *self.type_counts.entry(query_type).or_insert(0) += 1;
241
242 if query_type.is_read() {
244 self.read_count += 1;
245 } else {
246 self.write_count += 1;
247 }
248
249 if self.latencies.len() >= 1000 {
251 self.latencies.remove(0);
252 }
253 self.latencies.push(latency);
254
255 self.update_qpm();
257 }
258
259 fn update_qpm(&mut self) {
261 let now = Instant::now();
262
263 self.qpm_samples
265 .retain(|(t, _)| now.duration_since(*t) < Duration::from_secs(300));
266
267 self.qpm_samples.push((now, self.total_count));
269 }
270
271 pub fn count(&self) -> u64 {
273 self.total_count
274 }
275
276 pub fn qpm(&self) -> f64 {
278 if self.qpm_samples.len() < 2 {
279 return 0.0;
280 }
281
282 let first = self.qpm_samples.first().expect("checked len");
283 let last = self.qpm_samples.last().expect("checked len");
284
285 let duration = last.0.duration_since(first.0);
286 if duration.as_secs() == 0 {
287 return 0.0;
288 }
289
290 let queries = last.1 - first.1;
291 (queries as f64 / duration.as_secs_f64()) * 60.0
292 }
293
294 pub fn read_write_ratio(&self) -> f64 {
296 if self.write_count == 0 {
297 return f64::INFINITY;
298 }
299 self.read_count as f64 / self.write_count as f64
300 }
301
302 pub fn avg_latency(&self) -> Duration {
304 if self.latencies.is_empty() {
305 return Duration::ZERO;
306 }
307
308 let sum: Duration = self.latencies.iter().sum();
309 sum / self.latencies.len() as u32
310 }
311
312 pub fn p95_latency(&self) -> Duration {
314 if self.latencies.is_empty() {
315 return Duration::ZERO;
316 }
317
318 let mut sorted = self.latencies.clone();
319 sorted.sort();
320
321 let idx = (sorted.len() as f64 * 0.95) as usize;
322 sorted
323 .get(idx.min(sorted.len() - 1))
324 .copied()
325 .unwrap_or(Duration::ZERO)
326 }
327
328 pub fn last_updated(&self) -> Instant {
330 self.last_updated
331 }
332
333 pub fn type_count(&self, query_type: QueryType) -> u64 {
335 self.type_counts.get(&query_type).copied().unwrap_or(0)
336 }
337
338 pub fn type_fraction(&self, query_type: QueryType) -> f64 {
340 if self.total_count == 0 {
341 return 0.0;
342 }
343 self.type_count(query_type) as f64 / self.total_count as f64
344 }
345}
346
347impl Default for QueryHistory {
348 fn default() -> Self {
349 Self::new()
350 }
351}
352
/// Coarse category of a SQL statement, as detected by [`QueryType::from_sql`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum QueryType {
    /// A query with no detected aggregate, join or vector marker.
    SimpleSelect,
    /// A query containing `COUNT(`, `SUM(` or `AVG(`.
    AggregateSelect,
    /// A query containing a `JOIN` keyword.
    JoinSelect,
    /// A query mentioning `<->`, `VECTOR` or `EMBEDDING`.
    VectorSearch,
    /// A statement starting with `INSERT`.
    Insert,
    /// A statement starting with `UPDATE`.
    Update,
    /// A statement starting with `DELETE`.
    Delete,
}

impl QueryType {
    /// True for the select-style variants (simple, aggregate, join, vector).
    pub fn is_read(&self) -> bool {
        matches!(
            self,
            QueryType::SimpleSelect
                | QueryType::AggregateSelect
                | QueryType::JoinSelect
                | QueryType::VectorSearch
        )
    }

    /// True for `Insert`, `Update` and `Delete`.
    pub fn is_write(&self) -> bool {
        !self.is_read()
    }

    /// True for aggregate and join selects.
    pub fn is_olap(&self) -> bool {
        matches!(self, QueryType::AggregateSelect | QueryType::JoinSelect)
    }

    /// Heuristically classifies `sql` by keyword / substring matching (case-insensitive).
    ///
    /// Write keywords are checked as statement prefixes first, then vector markers,
    /// aggregates and joins anywhere in the text; everything else is a simple select.
    pub fn from_sql(sql: &str) -> Self {
        // Leading whitespace (common in generated or pretty-printed SQL) must not hide
        // the statement keyword from the prefix checks below.
        let upper = sql.trim_start().to_uppercase();

        if upper.starts_with("INSERT") {
            QueryType::Insert
        } else if upper.starts_with("UPDATE") {
            QueryType::Update
        } else if upper.starts_with("DELETE") {
            QueryType::Delete
        } else if upper.contains("<->") || upper.contains("VECTOR") || upper.contains("EMBEDDING") {
            QueryType::VectorSearch
        } else if upper.contains("COUNT(") || upper.contains("SUM(") || upper.contains("AVG(") {
            QueryType::AggregateSelect
        } else if upper.split_whitespace().any(|word| word == "JOIN") {
            // Token match also catches joins separated by newlines or tabs, which the
            // literal " JOIN " substring missed.
            QueryType::JoinSelect
        } else {
            QueryType::SimpleSelect
        }
    }
}
415
416#[derive(Debug)]
418pub struct ClassificationModel {
419 pub thresholds: ModelThresholds,
421}
422
423impl ClassificationModel {
424 pub fn new() -> Self {
426 Self {
427 thresholds: ModelThresholds::default(),
428 }
429 }
430
431 pub fn classify_temperature(&self, history: &QueryHistory) -> DataTemperature {
433 let qpm = history.qpm();
434
435 if qpm > self.thresholds.hot_qpm {
436 DataTemperature::Hot
437 } else if qpm > self.thresholds.warm_qpm {
438 DataTemperature::Warm
439 } else if qpm > self.thresholds.cold_qpm {
440 DataTemperature::Cold
441 } else {
442 DataTemperature::Frozen
443 }
444 }
445
446 pub fn classify_workload(&self, history: &QueryHistory) -> WorkloadType {
448 if history.type_fraction(QueryType::VectorSearch) > 0.3 {
450 return WorkloadType::Vector;
451 }
452
453 let rw_ratio = history.read_write_ratio();
455
456 if rw_ratio > self.thresholds.olap_ratio {
457 if history.type_fraction(QueryType::AggregateSelect) > 0.2 {
459 return WorkloadType::OLAP;
460 }
461 }
462
463 if rw_ratio < self.thresholds.oltp_ratio {
464 return WorkloadType::OLTP;
466 }
467
468 if history.qpm() > 100.0 && rw_ratio > 1.0 && rw_ratio < 10.0 {
470 return WorkloadType::HTAP;
471 }
472
473 WorkloadType::Mixed
474 }
475
476 pub fn classification_confidence(&self, history: &QueryHistory) -> f64 {
478 let query_factor = (history.count() as f64 / 1000.0).min(1.0);
480
481 let rw_ratio = history.read_write_ratio();
483 let pattern_factor = if !(2.0..=10.0).contains(&rw_ratio) {
484 0.8
485 } else {
486 0.5
487 };
488
489 query_factor * pattern_factor
490 }
491}
492
493impl Default for ClassificationModel {
494 fn default() -> Self {
495 Self::new()
496 }
497}
498
/// Cut-off values consulted by [`ClassificationModel`].
#[derive(Debug, Clone)]
pub struct ModelThresholds {
    /// Queries per minute strictly above which a table is `Hot`.
    pub hot_qpm: f64,
    /// Queries per minute strictly above which a table is at least `Warm`.
    pub warm_qpm: f64,
    /// Queries per minute strictly above which a table is at least `Cold` (else `Frozen`).
    pub cold_qpm: f64,
    /// Read/write ratio strictly above which an aggregate-heavy table is `OLAP`.
    pub olap_ratio: f64,
    /// Read/write ratio strictly below which a table is `OLTP`.
    pub oltp_ratio: f64,
}

impl Default for ModelThresholds {
    /// 1000 / 100 / 10 QPM temperature bands; OLTP below ratio 2, OLAP above ratio 10.
    fn default() -> Self {
        Self {
            oltp_ratio: 2.0,
            olap_ratio: 10.0,
            cold_qpm: 10.0,
            warm_qpm: 100.0,
            hot_qpm: 1_000.0,
        }
    }
}
525
526#[derive(Debug, Clone)]
528pub struct TableClassification {
529 pub table: String,
531 pub temperature: DataTemperature,
533 pub workload: WorkloadType,
535 pub confidence: f64,
537 pub query_count: u64,
539 pub last_updated: Instant,
541}
542
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_query_history() {
        let mut history = QueryHistory::new();

        history.record(QueryType::SimpleSelect, Duration::from_millis(10));
        history.record(QueryType::SimpleSelect, Duration::from_millis(20));
        history.record(QueryType::Insert, Duration::from_millis(30));

        assert_eq!(history.count(), 3);
        assert_eq!(history.read_count, 2);
        assert_eq!(history.write_count, 1);
        assert_eq!(history.read_write_ratio(), 2.0);
    }

    #[test]
    fn test_query_type_detection() {
        assert_eq!(
            QueryType::from_sql("INSERT INTO users VALUES (1)"),
            QueryType::Insert
        );
        assert_eq!(
            QueryType::from_sql("UPDATE users SET name = 'x'"),
            QueryType::Update
        );
        assert_eq!(QueryType::from_sql("DELETE FROM users"), QueryType::Delete);
        assert_eq!(
            QueryType::from_sql("SELECT COUNT(*) FROM users"),
            QueryType::AggregateSelect
        );
        assert_eq!(
            QueryType::from_sql("SELECT * FROM users"),
            QueryType::SimpleSelect
        );
        assert_eq!(
            QueryType::from_sql("SELECT * FROM a JOIN b ON a.id = b.id"),
            QueryType::JoinSelect
        );
    }

    #[test]
    fn test_classification_model() {
        let model = ClassificationModel::new();
        let mut history = QueryHistory::new();

        // Read-heavy (ratio 20) but with no aggregates, so the OLAP rule cannot fire.
        for _ in 0..1000 {
            history.record(QueryType::SimpleSelect, Duration::from_millis(5));
        }
        for _ in 0..50 {
            history.record(QueryType::Insert, Duration::from_millis(10));
        }

        let workload = model.classify_workload(&history);
        assert!(workload == WorkloadType::OLAP || workload == WorkloadType::Mixed);
    }

    #[test]
    fn test_learning_classifier() {
        let registry = Arc::new(SchemaRegistry::new());
        let classifier = LearningClassifier::new(registry);

        for _ in 0..100 {
            classifier.record("users", QueryType::SimpleSelect, Duration::from_millis(5));
        }

        let classification = classifier.get_classification("users");
        assert!(classification.is_some());
        assert_eq!(classification.as_ref().map(|c| c.query_count), Some(100));
    }

    #[test]
    fn test_temperature_classification() {
        let model = ClassificationModel::new();

        // With no rate samples the measured QPM is 0, below every default cut-off.
        let empty = QueryHistory::new();
        assert_eq!(model.classify_temperature(&empty), DataTemperature::Frozen);

        // Lowering the cold cut-off below zero moves a 0-QPM table into the Cold band.
        let mut lenient = ClassificationModel::new();
        lenient.thresholds.cold_qpm = -1.0;
        assert_eq!(lenient.classify_temperature(&empty), DataTemperature::Cold);

        let mut history = QueryHistory::new();
        for _ in 0..1000 {
            history.record(QueryType::SimpleSelect, Duration::from_millis(1));
        }

        // The measured rate depends on wall-clock timing, so check that the temperature
        // matches the thresholds for whatever rate was measured instead of pinning a band.
        let qpm = history.qpm();
        let t = &model.thresholds;
        let expected = if qpm > t.hot_qpm {
            DataTemperature::Hot
        } else if qpm > t.warm_qpm {
            DataTemperature::Warm
        } else if qpm > t.cold_qpm {
            DataTemperature::Cold
        } else {
            DataTemperature::Frozen
        };
        assert_eq!(model.classify_temperature(&history), expected);
    }

    #[test]
    fn test_latency_tracking() {
        let mut history = QueryHistory::new();

        for i in 0..100 {
            history.record(QueryType::SimpleSelect, Duration::from_millis(i));
        }

        let avg = history.avg_latency();
        assert!(avg.as_millis() > 0);

        let p95 = history.p95_latency();
        assert!(p95 >= avg);
    }
}