//! Query classification for schema routing (`heliosdb_proxy/schema_routing/classifier.rs`).

use std::collections::{HashMap, VecDeque};
use std::sync::Arc;
use std::time::{Duration, Instant};

use dashmap::DashMap;
use parking_lot::RwLock;

use super::registry::{DataTemperature, SchemaRegistry, WorkloadType};

13#[derive(Debug)]
15pub struct LearningClassifier {
16 history: DashMap<String, QueryHistory>,
18 model: Arc<RwLock<ClassificationModel>>,
20 schema: Arc<SchemaRegistry>,
22 config: ClassifierConfig,
24}
25
26impl LearningClassifier {
27 pub fn new(schema: Arc<SchemaRegistry>) -> Self {
29 Self {
30 history: DashMap::new(),
31 model: Arc::new(RwLock::new(ClassificationModel::new())),
32 schema,
33 config: ClassifierConfig::default(),
34 }
35 }
36
37 pub fn with_config(schema: Arc<SchemaRegistry>, config: ClassifierConfig) -> Self {
39 Self {
40 history: DashMap::new(),
41 model: Arc::new(RwLock::new(ClassificationModel::new())),
42 schema,
43 config,
44 }
45 }
46
47 pub fn record(&self, table: &str, query_type: QueryType, latency: Duration) {
49 let mut history = self.history
50 .entry(table.to_string())
51 .or_insert_with(QueryHistory::new);
52
53 history.record(query_type, latency);
54
55 if history.count() % self.config.reclassification_threshold == 0 {
57 self.reclassify(table);
58 }
59 }
60
61 pub fn reclassify(&self, table: &str) {
63 let history = match self.history.get(table) {
64 Some(h) => h.clone(),
65 None => return,
66 };
67
68 let model = self.model.read();
69
70 let temperature = model.classify_temperature(&history);
72
73 let workload = model.classify_workload(&history);
75
76 self.schema.update_classification(table, temperature, workload);
78 }
79
80 pub fn get_classification(&self, table: &str) -> Option<TableClassification> {
82 let history = self.history.get(table)?;
83 let model = self.model.read();
84
85 Some(TableClassification {
86 table: table.to_string(),
87 temperature: model.classify_temperature(&history),
88 workload: model.classify_workload(&history),
89 confidence: model.classification_confidence(&history),
90 query_count: history.count(),
91 last_updated: history.last_updated(),
92 })
93 }
94
95 pub fn all_classifications(&self) -> Vec<TableClassification> {
97 self.history
98 .iter()
99 .map(|entry| {
100 let table = entry.key();
101 let history = entry.value();
102 let model = self.model.read();
103
104 TableClassification {
105 table: table.clone(),
106 temperature: model.classify_temperature(history),
107 workload: model.classify_workload(history),
108 confidence: model.classification_confidence(history),
109 query_count: history.count(),
110 last_updated: history.last_updated(),
111 }
112 })
113 .collect()
114 }
115
116 pub fn update_thresholds(&self, thresholds: ModelThresholds) {
118 let mut model = self.model.write();
119 model.thresholds = thresholds;
120 }
121
122 pub fn get_history(&self, table: &str) -> Option<QueryHistory> {
124 self.history.get(table).map(|h| h.clone())
125 }
126
127 pub fn clear_history(&self, table: &str) {
129 self.history.remove(table);
130 }
131
132 pub fn clear_all(&self) {
134 self.history.clear();
135 }
136
137 pub fn query_count(&self) -> u64 {
139 self.history.iter().map(|h| h.value().count()).sum()
140 }
141
142 pub fn suggest_temperature(&self, table: &str) -> Option<DataTemperature> {
144 let history = self.history.get(table)?;
145 let model = self.model.read();
146 Some(model.classify_temperature(&history))
147 }
148
149 pub fn suggest_workload(&self, table: &str) -> Option<WorkloadType> {
151 let history = self.history.get(table)?;
152 let model = self.model.read();
153 Some(model.classify_workload(&history))
154 }
155
156 pub fn get_confidence(&self, table: &str) -> Option<f64> {
158 let history = self.history.get(table)?;
159 let model = self.model.read();
160 Some(model.classification_confidence(&history))
161 }
162
163 pub fn classify_query(&self, sql: &str) -> Option<WorkloadType> {
165 let query_type = QueryType::from_sql(sql);
166
167 Some(match query_type {
168 QueryType::VectorSearch => WorkloadType::Vector,
169 QueryType::AggregateSelect | QueryType::JoinSelect => WorkloadType::OLAP,
170 QueryType::SimpleSelect => WorkloadType::OLTP,
171 QueryType::Insert | QueryType::Update | QueryType::Delete => WorkloadType::OLTP,
172 })
173 }
174}
175
/// Tuning knobs for [`LearningClassifier`].
#[derive(Clone, Debug)]
pub struct ClassifierConfig {
    /// A table is reclassified each time its recorded query count reaches a multiple
    /// of this value (see `LearningClassifier::record`).
    pub reclassification_threshold: u64,
    /// Rate-measurement window.
    /// NOTE(review): not read by the classifier code in this module — confirm intended use.
    pub rate_window: Duration,
    /// Minimum query count.
    /// NOTE(review): not read by the classifier code in this module — confirm intended use.
    pub min_queries: u64,
}

impl Default for ClassifierConfig {
    /// Reclassify every 1000 queries, with a 60-second rate window and a 100-query minimum.
    fn default() -> Self {
        let rate_window = Duration::from_secs(60);
        Self {
            min_queries: 100,
            reclassification_threshold: 1000,
            rate_window,
        }
    }
}

197#[derive(Debug, Clone)]
199pub struct QueryHistory {
200 total_count: u64,
202 read_count: u64,
204 write_count: u64,
206 type_counts: HashMap<QueryType, u64>,
208 latencies: Vec<Duration>,
210 qpm_samples: Vec<(Instant, u64)>,
212 created: Instant,
214 last_updated: Instant,
216}
217
218impl QueryHistory {
219 pub fn new() -> Self {
221 let now = Instant::now();
222 Self {
223 total_count: 0,
224 read_count: 0,
225 write_count: 0,
226 type_counts: HashMap::new(),
227 latencies: Vec::new(),
228 qpm_samples: Vec::new(),
229 created: now,
230 last_updated: now,
231 }
232 }
233
234 pub fn record(&mut self, query_type: QueryType, latency: Duration) {
236 self.total_count += 1;
237 self.last_updated = Instant::now();
238
239 *self.type_counts.entry(query_type).or_insert(0) += 1;
241
242 if query_type.is_read() {
244 self.read_count += 1;
245 } else {
246 self.write_count += 1;
247 }
248
249 if self.latencies.len() >= 1000 {
251 self.latencies.remove(0);
252 }
253 self.latencies.push(latency);
254
255 self.update_qpm();
257 }
258
259 fn update_qpm(&mut self) {
261 let now = Instant::now();
262
263 self.qpm_samples.retain(|(t, _)| now.duration_since(*t) < Duration::from_secs(300));
265
266 self.qpm_samples.push((now, self.total_count));
268 }
269
270 pub fn count(&self) -> u64 {
272 self.total_count
273 }
274
275 pub fn qpm(&self) -> f64 {
277 if self.qpm_samples.len() < 2 {
278 return 0.0;
279 }
280
281 let first = self.qpm_samples.first().expect("checked len");
282 let last = self.qpm_samples.last().expect("checked len");
283
284 let duration = last.0.duration_since(first.0);
285 if duration.as_secs() == 0 {
286 return 0.0;
287 }
288
289 let queries = last.1 - first.1;
290 (queries as f64 / duration.as_secs_f64()) * 60.0
291 }
292
293 pub fn read_write_ratio(&self) -> f64 {
295 if self.write_count == 0 {
296 return f64::INFINITY;
297 }
298 self.read_count as f64 / self.write_count as f64
299 }
300
301 pub fn avg_latency(&self) -> Duration {
303 if self.latencies.is_empty() {
304 return Duration::ZERO;
305 }
306
307 let sum: Duration = self.latencies.iter().sum();
308 sum / self.latencies.len() as u32
309 }
310
311 pub fn p95_latency(&self) -> Duration {
313 if self.latencies.is_empty() {
314 return Duration::ZERO;
315 }
316
317 let mut sorted = self.latencies.clone();
318 sorted.sort();
319
320 let idx = (sorted.len() as f64 * 0.95) as usize;
321 sorted.get(idx.min(sorted.len() - 1)).copied().unwrap_or(Duration::ZERO)
322 }
323
324 pub fn last_updated(&self) -> Instant {
326 self.last_updated
327 }
328
329 pub fn type_count(&self, query_type: QueryType) -> u64 {
331 self.type_counts.get(&query_type).copied().unwrap_or(0)
332 }
333
334 pub fn type_fraction(&self, query_type: QueryType) -> f64 {
336 if self.total_count == 0 {
337 return 0.0;
338 }
339 self.type_count(query_type) as f64 / self.total_count as f64
340 }
341}
342
343impl Default for QueryHistory {
344 fn default() -> Self {
345 Self::new()
346 }
347}
348
/// Coarse category of a SQL statement, as inferred by [`QueryType::from_sql`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum QueryType {
    /// `SELECT` with no detected aggregate, join, or vector marker.
    SimpleSelect,
    /// Statement using `COUNT(`, `SUM(`, or `AVG(`.
    AggregateSelect,
    /// Statement containing the `JOIN` keyword.
    JoinSelect,
    /// Statement mentioning `<->`, `VECTOR`, or `EMBEDDING`.
    VectorSearch,
    /// `INSERT` statement.
    Insert,
    /// `UPDATE` statement.
    Update,
    /// `DELETE` statement.
    Delete,
}

impl QueryType {
    /// True for the select-like categories (simple, aggregate, join, vector).
    pub fn is_read(&self) -> bool {
        matches!(
            self,
            QueryType::SimpleSelect
                | QueryType::AggregateSelect
                | QueryType::JoinSelect
                | QueryType::VectorSearch
        )
    }

    /// True for `Insert`, `Update`, and `Delete`.
    pub fn is_write(&self) -> bool {
        !self.is_read()
    }

    /// True for aggregate and join selects.
    pub fn is_olap(&self) -> bool {
        matches!(self, QueryType::AggregateSelect | QueryType::JoinSelect)
    }

    /// Heuristically classifies `sql` (case-insensitive).
    ///
    /// Leading whitespace and leading `--`/`/* */` comments are skipped before the
    /// statement keyword is inspected. Without this, an indented or comment-tagged
    /// `INSERT` would be misread as a `SimpleSelect`.
    pub fn from_sql(sql: &str) -> Self {
        let upper = Self::skip_leading_comments(sql).to_uppercase();

        if upper.starts_with("INSERT") {
            QueryType::Insert
        } else if upper.starts_with("UPDATE") {
            QueryType::Update
        } else if upper.starts_with("DELETE") {
            QueryType::Delete
        } else if upper.contains("<->") || upper.contains("VECTOR") || upper.contains("EMBEDDING") {
            QueryType::VectorSearch
        } else if upper.contains("COUNT(") || upper.contains("SUM(") || upper.contains("AVG(") {
            QueryType::AggregateSelect
        } else if Self::has_keyword(&upper, "JOIN") {
            QueryType::JoinSelect
        } else {
            QueryType::SimpleSelect
        }
    }

    /// Returns `sql` with leading whitespace, `-- …` line comments, and `/* … */` block
    /// comments removed. An unterminated comment consumes the rest of the input.
    fn skip_leading_comments(sql: &str) -> &str {
        let mut rest = sql.trim_start();
        loop {
            if let Some(after) = rest.strip_prefix("--") {
                rest = after.find('\n').map_or("", |i| &after[i + 1..]).trim_start();
            } else if let Some(after) = rest.strip_prefix("/*") {
                rest = after.find("*/").map_or("", |i| &after[i + 2..]).trim_start();
            } else {
                return rest;
            }
        }
    }

    /// True when `keyword` occurs in `upper` as a whole word, delimited by anything other
    /// than an alphanumeric character or `_`. Unlike a `" JOIN "` substring test, this
    /// also matches across newlines and tabs.
    fn has_keyword(upper: &str, keyword: &str) -> bool {
        upper
            .split(|c: char| !(c.is_alphanumeric() || c == '_'))
            .any(|word| word == keyword)
    }
}

408#[derive(Debug)]
410pub struct ClassificationModel {
411 pub thresholds: ModelThresholds,
413}
414
415impl ClassificationModel {
416 pub fn new() -> Self {
418 Self {
419 thresholds: ModelThresholds::default(),
420 }
421 }
422
423 pub fn classify_temperature(&self, history: &QueryHistory) -> DataTemperature {
425 let qpm = history.qpm();
426
427 if qpm > self.thresholds.hot_qpm {
428 DataTemperature::Hot
429 } else if qpm > self.thresholds.warm_qpm {
430 DataTemperature::Warm
431 } else if qpm > self.thresholds.cold_qpm {
432 DataTemperature::Cold
433 } else {
434 DataTemperature::Frozen
435 }
436 }
437
438 pub fn classify_workload(&self, history: &QueryHistory) -> WorkloadType {
440 if history.type_fraction(QueryType::VectorSearch) > 0.3 {
442 return WorkloadType::Vector;
443 }
444
445 let rw_ratio = history.read_write_ratio();
447
448 if rw_ratio > self.thresholds.olap_ratio {
449 if history.type_fraction(QueryType::AggregateSelect) > 0.2 {
451 return WorkloadType::OLAP;
452 }
453 }
454
455 if rw_ratio < self.thresholds.oltp_ratio {
456 return WorkloadType::OLTP;
458 }
459
460 if history.qpm() > 100.0 && rw_ratio > 1.0 && rw_ratio < 10.0 {
462 return WorkloadType::HTAP;
463 }
464
465 WorkloadType::Mixed
466 }
467
468 pub fn classification_confidence(&self, history: &QueryHistory) -> f64 {
470 let query_factor = (history.count() as f64 / 1000.0).min(1.0);
472
473 let rw_ratio = history.read_write_ratio();
475 let pattern_factor = if rw_ratio > 10.0 || rw_ratio < 2.0 {
476 0.8
477 } else {
478 0.5
479 };
480
481 query_factor * pattern_factor
482 }
483}
484
485impl Default for ClassificationModel {
486 fn default() -> Self {
487 Self::new()
488 }
489}
490
/// Decision thresholds used by [`ClassificationModel`].
#[derive(Clone, Debug)]
pub struct ModelThresholds {
    /// Queries per minute above which a table is `Hot`.
    pub hot_qpm: f64,
    /// Queries per minute above which a table is at least `Warm`.
    pub warm_qpm: f64,
    /// Queries per minute above which a table is at least `Cold`; at or below, `Frozen`.
    pub cold_qpm: f64,
    /// Read/write ratio above which a table may be classified as OLAP.
    pub olap_ratio: f64,
    /// Read/write ratio below which a table is classified as OLTP.
    pub oltp_ratio: f64,
}

impl Default for ModelThresholds {
    /// Ratio bands 2.0 / 10.0 and QPM tiers 10 / 100 / 1000.
    fn default() -> Self {
        Self {
            oltp_ratio: 2.0,
            olap_ratio: 10.0,
            cold_qpm: 10.0,
            warm_qpm: 100.0,
            hot_qpm: 1000.0,
        }
    }
}

518#[derive(Debug, Clone)]
520pub struct TableClassification {
521 pub table: String,
523 pub temperature: DataTemperature,
525 pub workload: WorkloadType,
527 pub confidence: f64,
529 pub query_count: u64,
531 pub last_updated: Instant,
533}
534
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a history with `n` queries of each listed type, all at a fixed latency.
    fn history_with(counts: &[(QueryType, usize)]) -> QueryHistory {
        let mut history = QueryHistory::new();
        for &(query_type, n) in counts {
            for _ in 0..n {
                history.record(query_type, Duration::from_millis(1));
            }
        }
        history
    }

    #[test]
    fn test_query_history() {
        let mut history = QueryHistory::new();

        history.record(QueryType::SimpleSelect, Duration::from_millis(10));
        history.record(QueryType::SimpleSelect, Duration::from_millis(20));
        history.record(QueryType::Insert, Duration::from_millis(30));

        assert_eq!(history.count(), 3);
        assert_eq!(history.read_count, 2);
        assert_eq!(history.write_count, 1);
        assert_eq!(history.read_write_ratio(), 2.0);
        assert_eq!(history.type_count(QueryType::SimpleSelect), 2);
        assert_eq!(history.type_count(QueryType::Delete), 0);
    }

    #[test]
    fn test_empty_history() {
        let history = QueryHistory::new();

        assert_eq!(history.count(), 0);
        assert_eq!(history.qpm(), 0.0);
        assert_eq!(history.avg_latency(), Duration::ZERO);
        assert_eq!(history.p95_latency(), Duration::ZERO);
        assert_eq!(history.type_fraction(QueryType::SimpleSelect), 0.0);
        assert!(history.read_write_ratio().is_infinite());
    }

    #[test]
    fn test_query_type_detection() {
        assert_eq!(QueryType::from_sql("INSERT INTO users VALUES (1)"), QueryType::Insert);
        assert_eq!(QueryType::from_sql("UPDATE users SET name = 'x'"), QueryType::Update);
        assert_eq!(QueryType::from_sql("DELETE FROM users"), QueryType::Delete);
        assert_eq!(QueryType::from_sql("SELECT COUNT(*) FROM users"), QueryType::AggregateSelect);
        assert_eq!(QueryType::from_sql("SELECT * FROM users"), QueryType::SimpleSelect);
        assert_eq!(QueryType::from_sql("SELECT * FROM a JOIN b ON a.id = b.id"), QueryType::JoinSelect);
        assert_eq!(QueryType::from_sql("SELECT * FROM a LEFT JOIN b ON a.id = b.id"), QueryType::JoinSelect);
        assert_eq!(
            QueryType::from_sql("SELECT id FROM items ORDER BY embedding <-> '[1,2]'"),
            QueryType::VectorSearch
        );
    }

    #[test]
    fn test_classification_model() {
        let model = ClassificationModel::new();
        let history = history_with(&[(QueryType::SimpleSelect, 1000), (QueryType::Insert, 50)]);

        // Ratio 20 exceeds the OLAP ratio, but there are no aggregates. The ratio is also
        // outside the OLTP and HTAP bands, so the workload can only be Mixed.
        assert_eq!(model.classify_workload(&history), WorkloadType::Mixed);
    }

    #[test]
    fn test_workload_rules() {
        let model = ClassificationModel::new();

        let vector = history_with(&[(QueryType::VectorSearch, 4), (QueryType::SimpleSelect, 6)]);
        assert_eq!(model.classify_workload(&vector), WorkloadType::Vector);

        let olap = history_with(&[(QueryType::AggregateSelect, 100), (QueryType::Insert, 5)]);
        assert_eq!(model.classify_workload(&olap), WorkloadType::OLAP);

        let oltp = history_with(&[(QueryType::SimpleSelect, 10), (QueryType::Insert, 10)]);
        assert_eq!(model.classify_workload(&oltp), WorkloadType::OLTP);
    }

    #[test]
    fn test_classification_confidence() {
        let model = ClassificationModel::new();

        assert_eq!(model.classification_confidence(&QueryHistory::new()), 0.0);

        // Read-only (infinite ratio) with a saturated sample size.
        let clear = history_with(&[(QueryType::SimpleSelect, 1000)]);
        assert!((model.classification_confidence(&clear) - 0.8).abs() < 1e-9);

        // Ratio 5 is ambiguous (0.5), and 600 queries give a sample factor of 0.6.
        let ambiguous = history_with(&[(QueryType::SimpleSelect, 500), (QueryType::Insert, 100)]);
        assert!((model.classification_confidence(&ambiguous) - 0.3).abs() < 1e-9);
    }

    #[test]
    fn test_learning_classifier() {
        let registry = Arc::new(SchemaRegistry::new());
        let classifier = LearningClassifier::new(registry);

        for _ in 0..100 {
            classifier.record("users", QueryType::SimpleSelect, Duration::from_millis(5));
        }

        let classification = classifier.get_classification("users");
        assert!(classification.is_some());
        assert_eq!(classification.as_ref().map(|c| c.query_count), Some(100));
        assert!(classifier.get_classification("missing").is_none());
        assert_eq!(classifier.query_count(), 100);

        assert_eq!(
            classifier.classify_query("SELECT AVG(price) FROM orders"),
            Some(WorkloadType::OLAP)
        );
        assert_eq!(classifier.classify_query("DELETE FROM orders"), Some(WorkloadType::OLTP));
    }

    #[test]
    fn test_temperature_classification() {
        let model = ClassificationModel::new();

        // With fewer than two rate samples there is no measurable rate, so the result is
        // deterministic.
        assert_eq!(model.classify_temperature(&QueryHistory::new()), DataTemperature::Frozen);

        let single = history_with(&[(QueryType::SimpleSelect, 1)]);
        assert_eq!(model.classify_temperature(&single), DataTemperature::Frozen);
    }

    #[test]
    fn test_latency_tracking() {
        let mut history = QueryHistory::new();

        for i in 0..100 {
            history.record(QueryType::SimpleSelect, Duration::from_millis(i));
        }

        let avg = history.avg_latency();
        assert!(avg.as_millis() > 0);

        let p95 = history.p95_latency();
        assert!(p95 >= avg);
    }

    #[test]
    fn test_latency_window_is_bounded() {
        let mut history = QueryHistory::new();

        for i in 0..1100 {
            history.record(QueryType::SimpleSelect, Duration::from_millis(i));
        }

        // Only the latest 1000 latencies (100..=1099 ms) are retained.
        assert_eq!(history.count(), 1100);
        assert_eq!(history.avg_latency(), Duration::from_micros(599_500));
        assert_eq!(history.p95_latency(), Duration::from_millis(1050));
    }
}