Skip to main content

heliosdb_proxy/schema_routing/
classifier.rs

1//! Learning Classifier
2//!
3//! Automatically learns and updates table classifications from query patterns.
4
5use dashmap::DashMap;
6use parking_lot::RwLock;
7use std::collections::HashMap;
8use std::sync::Arc;
9use std::time::{Duration, Instant};
10
11use super::registry::{DataTemperature, SchemaRegistry, WorkloadType};
12
13/// Learning-based table classifier
14#[derive(Debug)]
15pub struct LearningClassifier {
16    /// Query history per table
17    history: DashMap<String, QueryHistory>,
18    /// Classification model
19    model: Arc<RwLock<ClassificationModel>>,
20    /// Schema registry for updates
21    schema: Arc<SchemaRegistry>,
22    /// Configuration
23    config: ClassifierConfig,
24}
25
26impl LearningClassifier {
27    /// Create a new learning classifier
28    pub fn new(schema: Arc<SchemaRegistry>) -> Self {
29        Self {
30            history: DashMap::new(),
31            model: Arc::new(RwLock::new(ClassificationModel::new())),
32            schema,
33            config: ClassifierConfig::default(),
34        }
35    }
36
37    /// Create with custom configuration
38    pub fn with_config(schema: Arc<SchemaRegistry>, config: ClassifierConfig) -> Self {
39        Self {
40            history: DashMap::new(),
41            model: Arc::new(RwLock::new(ClassificationModel::new())),
42            schema,
43            config,
44        }
45    }
46
47    /// Record a query execution
48    pub fn record(&self, table: &str, query_type: QueryType, latency: Duration) {
49        let mut history = self.history.entry(table.to_string()).or_default();
50
51        history.record(query_type, latency);
52
53        // Check if reclassification needed
54        if history.count() % self.config.reclassification_threshold == 0 {
55            self.reclassify(table);
56        }
57    }
58
59    /// Manually trigger reclassification
60    pub fn reclassify(&self, table: &str) {
61        let history = match self.history.get(table) {
62            Some(h) => h.clone(),
63            None => return,
64        };
65
66        let model = self.model.read();
67
68        // Determine temperature based on access frequency
69        let temperature = model.classify_temperature(&history);
70
71        // Determine workload based on query types
72        let workload = model.classify_workload(&history);
73
74        // Update schema registry
75        self.schema
76            .update_classification(table, temperature, workload);
77    }
78
79    /// Get current classification for a table
80    pub fn get_classification(&self, table: &str) -> Option<TableClassification> {
81        let history = self.history.get(table)?;
82        let model = self.model.read();
83
84        Some(TableClassification {
85            table: table.to_string(),
86            temperature: model.classify_temperature(&history),
87            workload: model.classify_workload(&history),
88            confidence: model.classification_confidence(&history),
89            query_count: history.count(),
90            last_updated: history.last_updated(),
91        })
92    }
93
94    /// Get all classifications
95    pub fn all_classifications(&self) -> Vec<TableClassification> {
96        self.history
97            .iter()
98            .map(|entry| {
99                let table = entry.key();
100                let history = entry.value();
101                let model = self.model.read();
102
103                TableClassification {
104                    table: table.clone(),
105                    temperature: model.classify_temperature(history),
106                    workload: model.classify_workload(history),
107                    confidence: model.classification_confidence(history),
108                    query_count: history.count(),
109                    last_updated: history.last_updated(),
110                }
111            })
112            .collect()
113    }
114
115    /// Update model thresholds
116    pub fn update_thresholds(&self, thresholds: ModelThresholds) {
117        let mut model = self.model.write();
118        model.thresholds = thresholds;
119    }
120
121    /// Get query history for a table
122    pub fn get_history(&self, table: &str) -> Option<QueryHistory> {
123        self.history.get(table).map(|h| h.clone())
124    }
125
126    /// Clear history for a table
127    pub fn clear_history(&self, table: &str) {
128        self.history.remove(table);
129    }
130
131    /// Clear all history
132    pub fn clear_all(&self) {
133        self.history.clear();
134    }
135
136    /// Get query count for a table
137    pub fn query_count(&self) -> u64 {
138        self.history.iter().map(|h| h.value().count()).sum()
139    }
140
141    /// Suggest temperature classification for a table
142    pub fn suggest_temperature(&self, table: &str) -> Option<DataTemperature> {
143        let history = self.history.get(table)?;
144        let model = self.model.read();
145        Some(model.classify_temperature(&history))
146    }
147
148    /// Suggest workload classification for a table
149    pub fn suggest_workload(&self, table: &str) -> Option<WorkloadType> {
150        let history = self.history.get(table)?;
151        let model = self.model.read();
152        Some(model.classify_workload(&history))
153    }
154
155    /// Get confidence for a table classification
156    pub fn get_confidence(&self, table: &str) -> Option<f64> {
157        let history = self.history.get(table)?;
158        let model = self.model.read();
159        Some(model.classification_confidence(&history))
160    }
161
162    /// Classify a query's workload type
163    pub fn classify_query(&self, sql: &str) -> Option<WorkloadType> {
164        let query_type = QueryType::from_sql(sql);
165
166        Some(match query_type {
167            QueryType::VectorSearch => WorkloadType::Vector,
168            QueryType::AggregateSelect | QueryType::JoinSelect => WorkloadType::OLAP,
169            QueryType::SimpleSelect => WorkloadType::OLTP,
170            QueryType::Insert | QueryType::Update | QueryType::Delete => WorkloadType::OLTP,
171        })
172    }
173}
174
175/// Classifier configuration
176#[derive(Debug, Clone)]
177pub struct ClassifierConfig {
178    /// Queries before triggering reclassification
179    pub reclassification_threshold: u64,
180    /// Time window for rate calculations
181    pub rate_window: Duration,
182    /// Minimum queries before classification
183    pub min_queries: u64,
184}
185
186impl Default for ClassifierConfig {
187    fn default() -> Self {
188        Self {
189            reclassification_threshold: 1000,
190            rate_window: Duration::from_secs(60),
191            min_queries: 100,
192        }
193    }
194}
195
196/// Query history for a table
197#[derive(Debug, Clone)]
198pub struct QueryHistory {
199    /// Total query count
200    total_count: u64,
201    /// Read count
202    read_count: u64,
203    /// Write count
204    write_count: u64,
205    /// Query type counts
206    type_counts: HashMap<QueryType, u64>,
207    /// Latency samples (rolling window)
208    latencies: Vec<Duration>,
209    /// Recent queries per minute samples
210    qpm_samples: Vec<(Instant, u64)>,
211    /// Created time
212    #[allow(dead_code)]
213    created: Instant,
214    /// Last updated
215    last_updated: Instant,
216}
217
218impl QueryHistory {
219    /// Create new history
220    pub fn new() -> Self {
221        let now = Instant::now();
222        Self {
223            total_count: 0,
224            read_count: 0,
225            write_count: 0,
226            type_counts: HashMap::new(),
227            latencies: Vec::new(),
228            qpm_samples: Vec::new(),
229            created: now,
230            last_updated: now,
231        }
232    }
233
234    /// Record a query
235    pub fn record(&mut self, query_type: QueryType, latency: Duration) {
236        self.total_count += 1;
237        self.last_updated = Instant::now();
238
239        // Update type counts
240        *self.type_counts.entry(query_type).or_insert(0) += 1;
241
242        // Update read/write counts
243        if query_type.is_read() {
244            self.read_count += 1;
245        } else {
246            self.write_count += 1;
247        }
248
249        // Record latency (keep last 1000 samples)
250        if self.latencies.len() >= 1000 {
251            self.latencies.remove(0);
252        }
253        self.latencies.push(latency);
254
255        // Update QPM samples
256        self.update_qpm();
257    }
258
259    /// Update queries per minute samples
260    fn update_qpm(&mut self) {
261        let now = Instant::now();
262
263        // Remove old samples (older than 5 minutes)
264        self.qpm_samples
265            .retain(|(t, _)| now.duration_since(*t) < Duration::from_secs(300));
266
267        // Add current count
268        self.qpm_samples.push((now, self.total_count));
269    }
270
271    /// Get total query count
272    pub fn count(&self) -> u64 {
273        self.total_count
274    }
275
276    /// Get queries per minute
277    pub fn qpm(&self) -> f64 {
278        if self.qpm_samples.len() < 2 {
279            return 0.0;
280        }
281
282        let first = self.qpm_samples.first().expect("checked len");
283        let last = self.qpm_samples.last().expect("checked len");
284
285        let duration = last.0.duration_since(first.0);
286        if duration.as_secs() == 0 {
287            return 0.0;
288        }
289
290        let queries = last.1 - first.1;
291        (queries as f64 / duration.as_secs_f64()) * 60.0
292    }
293
294    /// Get read/write ratio
295    pub fn read_write_ratio(&self) -> f64 {
296        if self.write_count == 0 {
297            return f64::INFINITY;
298        }
299        self.read_count as f64 / self.write_count as f64
300    }
301
302    /// Get average latency
303    pub fn avg_latency(&self) -> Duration {
304        if self.latencies.is_empty() {
305            return Duration::ZERO;
306        }
307
308        let sum: Duration = self.latencies.iter().sum();
309        sum / self.latencies.len() as u32
310    }
311
312    /// Get P95 latency
313    pub fn p95_latency(&self) -> Duration {
314        if self.latencies.is_empty() {
315            return Duration::ZERO;
316        }
317
318        let mut sorted = self.latencies.clone();
319        sorted.sort();
320
321        let idx = (sorted.len() as f64 * 0.95) as usize;
322        sorted
323            .get(idx.min(sorted.len() - 1))
324            .copied()
325            .unwrap_or(Duration::ZERO)
326    }
327
328    /// Get last updated time
329    pub fn last_updated(&self) -> Instant {
330        self.last_updated
331    }
332
333    /// Get count for a specific query type
334    pub fn type_count(&self, query_type: QueryType) -> u64 {
335        self.type_counts.get(&query_type).copied().unwrap_or(0)
336    }
337
338    /// Get fraction of queries that are a specific type
339    pub fn type_fraction(&self, query_type: QueryType) -> f64 {
340        if self.total_count == 0 {
341            return 0.0;
342        }
343        self.type_count(query_type) as f64 / self.total_count as f64
344    }
345}
346
347impl Default for QueryHistory {
348    fn default() -> Self {
349        Self::new()
350    }
351}
352
353/// Query type for classification
354#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
355pub enum QueryType {
356    /// Simple SELECT
357    SimpleSelect,
358    /// SELECT with aggregations
359    AggregateSelect,
360    /// SELECT with JOINs
361    JoinSelect,
362    /// Vector search
363    VectorSearch,
364    /// INSERT
365    Insert,
366    /// UPDATE
367    Update,
368    /// DELETE
369    Delete,
370}
371
372impl QueryType {
373    /// Check if this is a read query
374    pub fn is_read(&self) -> bool {
375        matches!(
376            self,
377            QueryType::SimpleSelect
378                | QueryType::AggregateSelect
379                | QueryType::JoinSelect
380                | QueryType::VectorSearch
381        )
382    }
383
384    /// Check if this is a write query
385    pub fn is_write(&self) -> bool {
386        !self.is_read()
387    }
388
389    /// Check if this is an OLAP-style query
390    pub fn is_olap(&self) -> bool {
391        matches!(self, QueryType::AggregateSelect | QueryType::JoinSelect)
392    }
393
394    /// Detect query type from SQL
395    pub fn from_sql(sql: &str) -> Self {
396        let upper = sql.to_uppercase();
397
398        if upper.starts_with("INSERT") {
399            QueryType::Insert
400        } else if upper.starts_with("UPDATE") {
401            QueryType::Update
402        } else if upper.starts_with("DELETE") {
403            QueryType::Delete
404        } else if upper.contains("<->") || upper.contains("VECTOR") || upper.contains("EMBEDDING") {
405            QueryType::VectorSearch
406        } else if upper.contains("COUNT(") || upper.contains("SUM(") || upper.contains("AVG(") {
407            QueryType::AggregateSelect
408        } else if upper.contains(" JOIN ") {
409            QueryType::JoinSelect
410        } else {
411            QueryType::SimpleSelect
412        }
413    }
414}
415
416/// Classification model
417#[derive(Debug)]
418pub struct ClassificationModel {
419    /// Thresholds for classification
420    pub thresholds: ModelThresholds,
421}
422
423impl ClassificationModel {
424    /// Create a new model with default thresholds
425    pub fn new() -> Self {
426        Self {
427            thresholds: ModelThresholds::default(),
428        }
429    }
430
431    /// Classify temperature based on history
432    pub fn classify_temperature(&self, history: &QueryHistory) -> DataTemperature {
433        let qpm = history.qpm();
434
435        if qpm > self.thresholds.hot_qpm {
436            DataTemperature::Hot
437        } else if qpm > self.thresholds.warm_qpm {
438            DataTemperature::Warm
439        } else if qpm > self.thresholds.cold_qpm {
440            DataTemperature::Cold
441        } else {
442            DataTemperature::Frozen
443        }
444    }
445
446    /// Classify workload based on history
447    pub fn classify_workload(&self, history: &QueryHistory) -> WorkloadType {
448        // Check for vector workload
449        if history.type_fraction(QueryType::VectorSearch) > 0.3 {
450            return WorkloadType::Vector;
451        }
452
453        // Check read/write ratio for OLTP vs OLAP
454        let rw_ratio = history.read_write_ratio();
455
456        if rw_ratio > self.thresholds.olap_ratio {
457            // High read ratio - could be OLAP
458            if history.type_fraction(QueryType::AggregateSelect) > 0.2 {
459                return WorkloadType::OLAP;
460            }
461        }
462
463        if rw_ratio < self.thresholds.oltp_ratio {
464            // Lower read ratio - OLTP
465            return WorkloadType::OLTP;
466        }
467
468        // Check for HTAP (mixed heavy workload)
469        if history.qpm() > 100.0 && rw_ratio > 1.0 && rw_ratio < 10.0 {
470            return WorkloadType::HTAP;
471        }
472
473        WorkloadType::Mixed
474    }
475
476    /// Calculate classification confidence (0.0 - 1.0)
477    pub fn classification_confidence(&self, history: &QueryHistory) -> f64 {
478        // More queries = higher confidence
479        let query_factor = (history.count() as f64 / 1000.0).min(1.0);
480
481        // Clear patterns = higher confidence
482        let rw_ratio = history.read_write_ratio();
483        let pattern_factor = if !(2.0..=10.0).contains(&rw_ratio) {
484            0.8
485        } else {
486            0.5
487        };
488
489        query_factor * pattern_factor
490    }
491}
492
493impl Default for ClassificationModel {
494    fn default() -> Self {
495        Self::new()
496    }
497}
498
499/// Model thresholds for classification
500#[derive(Debug, Clone)]
501pub struct ModelThresholds {
502    /// QPM threshold for HOT classification
503    pub hot_qpm: f64,
504    /// QPM threshold for WARM classification
505    pub warm_qpm: f64,
506    /// QPM threshold for COLD classification
507    pub cold_qpm: f64,
508    /// Read/write ratio threshold for OLAP
509    pub olap_ratio: f64,
510    /// Read/write ratio threshold for OLTP
511    pub oltp_ratio: f64,
512}
513
514impl Default for ModelThresholds {
515    fn default() -> Self {
516        Self {
517            hot_qpm: 1000.0,
518            warm_qpm: 100.0,
519            cold_qpm: 10.0,
520            olap_ratio: 10.0,
521            oltp_ratio: 2.0,
522        }
523    }
524}
525
526/// Table classification result
527#[derive(Debug, Clone)]
528pub struct TableClassification {
529    /// Table name
530    pub table: String,
531    /// Temperature classification
532    pub temperature: DataTemperature,
533    /// Workload classification
534    pub workload: WorkloadType,
535    /// Classification confidence
536    pub confidence: f64,
537    /// Query count used for classification
538    pub query_count: u64,
539    /// Last updated time
540    pub last_updated: Instant,
541}
542
543#[cfg(test)]
544mod tests {
545    use super::*;
546
547    #[test]
548    fn test_query_history() {
549        let mut history = QueryHistory::new();
550
551        history.record(QueryType::SimpleSelect, Duration::from_millis(10));
552        history.record(QueryType::SimpleSelect, Duration::from_millis(20));
553        history.record(QueryType::Insert, Duration::from_millis(30));
554
555        assert_eq!(history.count(), 3);
556        assert_eq!(history.read_count, 2);
557        assert_eq!(history.write_count, 1);
558        assert_eq!(history.read_write_ratio(), 2.0);
559    }
560
561    #[test]
562    fn test_query_type_detection() {
563        assert_eq!(
564            QueryType::from_sql("INSERT INTO users VALUES (1)"),
565            QueryType::Insert
566        );
567        assert_eq!(
568            QueryType::from_sql("UPDATE users SET name = 'x'"),
569            QueryType::Update
570        );
571        assert_eq!(QueryType::from_sql("DELETE FROM users"), QueryType::Delete);
572        assert_eq!(
573            QueryType::from_sql("SELECT COUNT(*) FROM users"),
574            QueryType::AggregateSelect
575        );
576        assert_eq!(
577            QueryType::from_sql("SELECT * FROM users"),
578            QueryType::SimpleSelect
579        );
580        assert_eq!(
581            QueryType::from_sql("SELECT * FROM a JOIN b ON a.id = b.id"),
582            QueryType::JoinSelect
583        );
584    }
585
586    #[test]
587    fn test_classification_model() {
588        let model = ClassificationModel::new();
589        let mut history = QueryHistory::new();
590
591        // Record many reads
592        for _ in 0..1000 {
593            history.record(QueryType::SimpleSelect, Duration::from_millis(5));
594        }
595        // Record few writes
596        for _ in 0..50 {
597            history.record(QueryType::Insert, Duration::from_millis(10));
598        }
599
600        let workload = model.classify_workload(&history);
601        // High read ratio should indicate OLAP-ish workload
602        assert!(workload == WorkloadType::OLAP || workload == WorkloadType::Mixed);
603    }
604
605    #[test]
606    fn test_learning_classifier() {
607        let registry = Arc::new(SchemaRegistry::new());
608        let classifier = LearningClassifier::new(registry);
609
610        for _ in 0..100 {
611            classifier.record("users", QueryType::SimpleSelect, Duration::from_millis(5));
612        }
613
614        let classification = classifier.get_classification("users");
615        assert!(classification.is_some());
616        assert_eq!(classification.as_ref().map(|c| c.query_count), Some(100));
617    }
618
619    #[test]
620    fn test_temperature_classification() {
621        let model = ClassificationModel::new();
622        let mut history = QueryHistory::new();
623
624        // Simulate high QPM
625        for _ in 0..1000 {
626            history.record(QueryType::SimpleSelect, Duration::from_millis(1));
627        }
628        // Force QPM calculation by adding samples over time
629        // In real usage, this happens naturally over time
630
631        // QPM calculation needs a time window; in tests all queries are instantaneous
632        // so QPM may be 0, resulting in Frozen classification
633        let temp = model.classify_temperature(&history);
634        assert!(
635            temp == DataTemperature::Hot
636                || temp == DataTemperature::Warm
637                || temp == DataTemperature::Cold
638                || temp == DataTemperature::Frozen
639        );
640    }
641
642    #[test]
643    fn test_latency_tracking() {
644        let mut history = QueryHistory::new();
645
646        for i in 0..100 {
647            history.record(QueryType::SimpleSelect, Duration::from_millis(i));
648        }
649
650        let avg = history.avg_latency();
651        assert!(avg.as_millis() > 0);
652
653        let p95 = history.p95_latency();
654        assert!(p95 >= avg);
655    }
656}