Skip to main content

heliosdb_proxy/schema_routing/
classifier.rs

1//! Learning Classifier
2//!
3//! Automatically learns and updates table classifications from query patterns.
4
5use std::collections::HashMap;
6use std::sync::Arc;
7use std::time::{Duration, Instant};
8use dashmap::DashMap;
9use parking_lot::RwLock;
10
11use super::registry::{SchemaRegistry, DataTemperature, WorkloadType};
12
13/// Learning-based table classifier
14#[derive(Debug)]
15pub struct LearningClassifier {
16    /// Query history per table
17    history: DashMap<String, QueryHistory>,
18    /// Classification model
19    model: Arc<RwLock<ClassificationModel>>,
20    /// Schema registry for updates
21    schema: Arc<SchemaRegistry>,
22    /// Configuration
23    config: ClassifierConfig,
24}
25
26impl LearningClassifier {
27    /// Create a new learning classifier
28    pub fn new(schema: Arc<SchemaRegistry>) -> Self {
29        Self {
30            history: DashMap::new(),
31            model: Arc::new(RwLock::new(ClassificationModel::new())),
32            schema,
33            config: ClassifierConfig::default(),
34        }
35    }
36
37    /// Create with custom configuration
38    pub fn with_config(schema: Arc<SchemaRegistry>, config: ClassifierConfig) -> Self {
39        Self {
40            history: DashMap::new(),
41            model: Arc::new(RwLock::new(ClassificationModel::new())),
42            schema,
43            config,
44        }
45    }
46
47    /// Record a query execution
48    pub fn record(&self, table: &str, query_type: QueryType, latency: Duration) {
49        let mut history = self.history
50            .entry(table.to_string())
51            .or_insert_with(QueryHistory::new);
52
53        history.record(query_type, latency);
54
55        // Check if reclassification needed
56        if history.count() % self.config.reclassification_threshold == 0 {
57            self.reclassify(table);
58        }
59    }
60
61    /// Manually trigger reclassification
62    pub fn reclassify(&self, table: &str) {
63        let history = match self.history.get(table) {
64            Some(h) => h.clone(),
65            None => return,
66        };
67
68        let model = self.model.read();
69
70        // Determine temperature based on access frequency
71        let temperature = model.classify_temperature(&history);
72
73        // Determine workload based on query types
74        let workload = model.classify_workload(&history);
75
76        // Update schema registry
77        self.schema.update_classification(table, temperature, workload);
78    }
79
80    /// Get current classification for a table
81    pub fn get_classification(&self, table: &str) -> Option<TableClassification> {
82        let history = self.history.get(table)?;
83        let model = self.model.read();
84
85        Some(TableClassification {
86            table: table.to_string(),
87            temperature: model.classify_temperature(&history),
88            workload: model.classify_workload(&history),
89            confidence: model.classification_confidence(&history),
90            query_count: history.count(),
91            last_updated: history.last_updated(),
92        })
93    }
94
95    /// Get all classifications
96    pub fn all_classifications(&self) -> Vec<TableClassification> {
97        self.history
98            .iter()
99            .map(|entry| {
100                let table = entry.key();
101                let history = entry.value();
102                let model = self.model.read();
103
104                TableClassification {
105                    table: table.clone(),
106                    temperature: model.classify_temperature(history),
107                    workload: model.classify_workload(history),
108                    confidence: model.classification_confidence(history),
109                    query_count: history.count(),
110                    last_updated: history.last_updated(),
111                }
112            })
113            .collect()
114    }
115
116    /// Update model thresholds
117    pub fn update_thresholds(&self, thresholds: ModelThresholds) {
118        let mut model = self.model.write();
119        model.thresholds = thresholds;
120    }
121
122    /// Get query history for a table
123    pub fn get_history(&self, table: &str) -> Option<QueryHistory> {
124        self.history.get(table).map(|h| h.clone())
125    }
126
127    /// Clear history for a table
128    pub fn clear_history(&self, table: &str) {
129        self.history.remove(table);
130    }
131
132    /// Clear all history
133    pub fn clear_all(&self) {
134        self.history.clear();
135    }
136
137    /// Get query count for a table
138    pub fn query_count(&self) -> u64 {
139        self.history.iter().map(|h| h.value().count()).sum()
140    }
141
142    /// Suggest temperature classification for a table
143    pub fn suggest_temperature(&self, table: &str) -> Option<DataTemperature> {
144        let history = self.history.get(table)?;
145        let model = self.model.read();
146        Some(model.classify_temperature(&history))
147    }
148
149    /// Suggest workload classification for a table
150    pub fn suggest_workload(&self, table: &str) -> Option<WorkloadType> {
151        let history = self.history.get(table)?;
152        let model = self.model.read();
153        Some(model.classify_workload(&history))
154    }
155
156    /// Get confidence for a table classification
157    pub fn get_confidence(&self, table: &str) -> Option<f64> {
158        let history = self.history.get(table)?;
159        let model = self.model.read();
160        Some(model.classification_confidence(&history))
161    }
162
163    /// Classify a query's workload type
164    pub fn classify_query(&self, sql: &str) -> Option<WorkloadType> {
165        let query_type = QueryType::from_sql(sql);
166
167        Some(match query_type {
168            QueryType::VectorSearch => WorkloadType::Vector,
169            QueryType::AggregateSelect | QueryType::JoinSelect => WorkloadType::OLAP,
170            QueryType::SimpleSelect => WorkloadType::OLTP,
171            QueryType::Insert | QueryType::Update | QueryType::Delete => WorkloadType::OLTP,
172        })
173    }
174}
175
/// Tunable settings for the learning classifier.
#[derive(Debug, Clone)]
pub struct ClassifierConfig {
    /// Number of recorded queries between automatic reclassifications
    pub reclassification_threshold: u64,
    /// Time window used for rate calculations
    pub rate_window: Duration,
    /// Minimum number of queries before a table is classified
    pub min_queries: u64,
}

impl Default for ClassifierConfig {
    /// Reclassify every 1000 queries, use a 60 second rate window and
    /// require at least 100 queries.
    fn default() -> Self {
        let rate_window = Duration::from_secs(60);
        let reclassification_threshold = 1000;
        let min_queries = 100;

        Self {
            reclassification_threshold,
            rate_window,
            min_queries,
        }
    }
}
196
197/// Query history for a table
198#[derive(Debug, Clone)]
199pub struct QueryHistory {
200    /// Total query count
201    total_count: u64,
202    /// Read count
203    read_count: u64,
204    /// Write count
205    write_count: u64,
206    /// Query type counts
207    type_counts: HashMap<QueryType, u64>,
208    /// Latency samples (rolling window)
209    latencies: Vec<Duration>,
210    /// Recent queries per minute samples
211    qpm_samples: Vec<(Instant, u64)>,
212    /// Created time
213    created: Instant,
214    /// Last updated
215    last_updated: Instant,
216}
217
218impl QueryHistory {
219    /// Create new history
220    pub fn new() -> Self {
221        let now = Instant::now();
222        Self {
223            total_count: 0,
224            read_count: 0,
225            write_count: 0,
226            type_counts: HashMap::new(),
227            latencies: Vec::new(),
228            qpm_samples: Vec::new(),
229            created: now,
230            last_updated: now,
231        }
232    }
233
234    /// Record a query
235    pub fn record(&mut self, query_type: QueryType, latency: Duration) {
236        self.total_count += 1;
237        self.last_updated = Instant::now();
238
239        // Update type counts
240        *self.type_counts.entry(query_type).or_insert(0) += 1;
241
242        // Update read/write counts
243        if query_type.is_read() {
244            self.read_count += 1;
245        } else {
246            self.write_count += 1;
247        }
248
249        // Record latency (keep last 1000 samples)
250        if self.latencies.len() >= 1000 {
251            self.latencies.remove(0);
252        }
253        self.latencies.push(latency);
254
255        // Update QPM samples
256        self.update_qpm();
257    }
258
259    /// Update queries per minute samples
260    fn update_qpm(&mut self) {
261        let now = Instant::now();
262
263        // Remove old samples (older than 5 minutes)
264        self.qpm_samples.retain(|(t, _)| now.duration_since(*t) < Duration::from_secs(300));
265
266        // Add current count
267        self.qpm_samples.push((now, self.total_count));
268    }
269
270    /// Get total query count
271    pub fn count(&self) -> u64 {
272        self.total_count
273    }
274
275    /// Get queries per minute
276    pub fn qpm(&self) -> f64 {
277        if self.qpm_samples.len() < 2 {
278            return 0.0;
279        }
280
281        let first = self.qpm_samples.first().expect("checked len");
282        let last = self.qpm_samples.last().expect("checked len");
283
284        let duration = last.0.duration_since(first.0);
285        if duration.as_secs() == 0 {
286            return 0.0;
287        }
288
289        let queries = last.1 - first.1;
290        (queries as f64 / duration.as_secs_f64()) * 60.0
291    }
292
293    /// Get read/write ratio
294    pub fn read_write_ratio(&self) -> f64 {
295        if self.write_count == 0 {
296            return f64::INFINITY;
297        }
298        self.read_count as f64 / self.write_count as f64
299    }
300
301    /// Get average latency
302    pub fn avg_latency(&self) -> Duration {
303        if self.latencies.is_empty() {
304            return Duration::ZERO;
305        }
306
307        let sum: Duration = self.latencies.iter().sum();
308        sum / self.latencies.len() as u32
309    }
310
311    /// Get P95 latency
312    pub fn p95_latency(&self) -> Duration {
313        if self.latencies.is_empty() {
314            return Duration::ZERO;
315        }
316
317        let mut sorted = self.latencies.clone();
318        sorted.sort();
319
320        let idx = (sorted.len() as f64 * 0.95) as usize;
321        sorted.get(idx.min(sorted.len() - 1)).copied().unwrap_or(Duration::ZERO)
322    }
323
324    /// Get last updated time
325    pub fn last_updated(&self) -> Instant {
326        self.last_updated
327    }
328
329    /// Get count for a specific query type
330    pub fn type_count(&self, query_type: QueryType) -> u64 {
331        self.type_counts.get(&query_type).copied().unwrap_or(0)
332    }
333
334    /// Get fraction of queries that are a specific type
335    pub fn type_fraction(&self, query_type: QueryType) -> f64 {
336        if self.total_count == 0 {
337            return 0.0;
338        }
339        self.type_count(query_type) as f64 / self.total_count as f64
340    }
341}
342
343impl Default for QueryHistory {
344    fn default() -> Self {
345        Self::new()
346    }
347}
348
/// Query type for classification
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum QueryType {
    /// Simple SELECT
    SimpleSelect,
    /// SELECT with aggregations
    AggregateSelect,
    /// SELECT with JOINs
    JoinSelect,
    /// Vector search
    VectorSearch,
    /// INSERT
    Insert,
    /// UPDATE
    Update,
    /// DELETE
    Delete,
}

impl QueryType {
    /// Check if this is a read query
    pub fn is_read(&self) -> bool {
        matches!(
            self,
            QueryType::SimpleSelect
                | QueryType::AggregateSelect
                | QueryType::JoinSelect
                | QueryType::VectorSearch
        )
    }

    /// Check if this is a write query (anything that is not a read)
    pub fn is_write(&self) -> bool {
        !self.is_read()
    }

    /// Check if this is an OLAP-style query
    pub fn is_olap(&self) -> bool {
        matches!(self, QueryType::AggregateSelect | QueryType::JoinSelect)
    }

    /// Detect query type from SQL using keyword heuristics.
    ///
    /// Matching is case-insensitive and ignores leading whitespace. Checks are
    /// applied in order: `INSERT` / `UPDATE` / `DELETE` prefix, vector markers
    /// (`<->`, `VECTOR`, `EMBEDDING`), aggregate calls (`COUNT(`, `SUM(`,
    /// `AVG(`), a standalone `JOIN` keyword, and finally
    /// [`QueryType::SimpleSelect`] as the fallback for everything else.
    pub fn from_sql(sql: &str) -> Self {
        // All keywords are ASCII, so ASCII case folding is sufficient.
        let upper = sql.trim_start().to_ascii_uppercase();

        if upper.starts_with("INSERT") {
            QueryType::Insert
        } else if upper.starts_with("UPDATE") {
            QueryType::Update
        } else if upper.starts_with("DELETE") {
            QueryType::Delete
        } else if upper.contains("<->") || upper.contains("VECTOR") || upper.contains("EMBEDDING") {
            QueryType::VectorSearch
        } else if upper.contains("COUNT(") || upper.contains("SUM(") || upper.contains("AVG(") {
            QueryType::AggregateSelect
        } else if upper.split_whitespace().any(|token| token == "JOIN") {
            // Token match so JOIN separated by newlines or tabs is found too.
            QueryType::JoinSelect
        } else {
            QueryType::SimpleSelect
        }
    }
}
407
408/// Classification model
409#[derive(Debug)]
410pub struct ClassificationModel {
411    /// Thresholds for classification
412    pub thresholds: ModelThresholds,
413}
414
415impl ClassificationModel {
416    /// Create a new model with default thresholds
417    pub fn new() -> Self {
418        Self {
419            thresholds: ModelThresholds::default(),
420        }
421    }
422
423    /// Classify temperature based on history
424    pub fn classify_temperature(&self, history: &QueryHistory) -> DataTemperature {
425        let qpm = history.qpm();
426
427        if qpm > self.thresholds.hot_qpm {
428            DataTemperature::Hot
429        } else if qpm > self.thresholds.warm_qpm {
430            DataTemperature::Warm
431        } else if qpm > self.thresholds.cold_qpm {
432            DataTemperature::Cold
433        } else {
434            DataTemperature::Frozen
435        }
436    }
437
438    /// Classify workload based on history
439    pub fn classify_workload(&self, history: &QueryHistory) -> WorkloadType {
440        // Check for vector workload
441        if history.type_fraction(QueryType::VectorSearch) > 0.3 {
442            return WorkloadType::Vector;
443        }
444
445        // Check read/write ratio for OLTP vs OLAP
446        let rw_ratio = history.read_write_ratio();
447
448        if rw_ratio > self.thresholds.olap_ratio {
449            // High read ratio - could be OLAP
450            if history.type_fraction(QueryType::AggregateSelect) > 0.2 {
451                return WorkloadType::OLAP;
452            }
453        }
454
455        if rw_ratio < self.thresholds.oltp_ratio {
456            // Lower read ratio - OLTP
457            return WorkloadType::OLTP;
458        }
459
460        // Check for HTAP (mixed heavy workload)
461        if history.qpm() > 100.0 && rw_ratio > 1.0 && rw_ratio < 10.0 {
462            return WorkloadType::HTAP;
463        }
464
465        WorkloadType::Mixed
466    }
467
468    /// Calculate classification confidence (0.0 - 1.0)
469    pub fn classification_confidence(&self, history: &QueryHistory) -> f64 {
470        // More queries = higher confidence
471        let query_factor = (history.count() as f64 / 1000.0).min(1.0);
472
473        // Clear patterns = higher confidence
474        let rw_ratio = history.read_write_ratio();
475        let pattern_factor = if rw_ratio > 10.0 || rw_ratio < 2.0 {
476            0.8
477        } else {
478            0.5
479        };
480
481        query_factor * pattern_factor
482    }
483}
484
485impl Default for ClassificationModel {
486    fn default() -> Self {
487        Self::new()
488    }
489}
490
/// Numeric cut-offs used by the classification model.
#[derive(Debug, Clone)]
pub struct ModelThresholds {
    /// Queries per minute above which a table is HOT
    pub hot_qpm: f64,
    /// Queries per minute above which a table is WARM
    pub warm_qpm: f64,
    /// Queries per minute above which a table is COLD
    pub cold_qpm: f64,
    /// Read/write ratio above which a workload may be OLAP
    pub olap_ratio: f64,
    /// Read/write ratio below which a workload is OLTP
    pub oltp_ratio: f64,
}

impl Default for ModelThresholds {
    /// Default cut-offs: 1000 / 100 / 10 QPM for hot / warm / cold and
    /// read/write ratios of 10 (OLAP) and 2 (OLTP).
    fn default() -> Self {
        let (hot_qpm, warm_qpm, cold_qpm) = (1000.0, 100.0, 10.0);
        let (olap_ratio, oltp_ratio) = (10.0, 2.0);

        Self {
            hot_qpm,
            warm_qpm,
            cold_qpm,
            olap_ratio,
            oltp_ratio,
        }
    }
}
517
518/// Table classification result
519#[derive(Debug, Clone)]
520pub struct TableClassification {
521    /// Table name
522    pub table: String,
523    /// Temperature classification
524    pub temperature: DataTemperature,
525    /// Workload classification
526    pub workload: WorkloadType,
527    /// Classification confidence
528    pub confidence: f64,
529    /// Query count used for classification
530    pub query_count: u64,
531    /// Last updated time
532    pub last_updated: Instant,
533}
534
#[cfg(test)]
mod tests {
    use super::*;

    /// Counters and read/write ratio after a handful of recorded queries.
    #[test]
    fn test_query_history() {
        let recorded = [
            (QueryType::SimpleSelect, 10),
            (QueryType::SimpleSelect, 20),
            (QueryType::Insert, 30),
        ];

        let mut h = QueryHistory::new();
        for &(kind, millis) in recorded.iter() {
            h.record(kind, Duration::from_millis(millis));
        }

        assert_eq!(h.count(), 3);
        assert_eq!(h.read_count, 2);
        assert_eq!(h.write_count, 1);
        assert_eq!(h.read_write_ratio(), 2.0);
    }

    /// Each SQL statement maps to the expected query type.
    #[test]
    fn test_query_type_detection() {
        let cases = [
            ("INSERT INTO users VALUES (1)", QueryType::Insert),
            ("UPDATE users SET name = 'x'", QueryType::Update),
            ("DELETE FROM users", QueryType::Delete),
            ("SELECT COUNT(*) FROM users", QueryType::AggregateSelect),
            ("SELECT * FROM users", QueryType::SimpleSelect),
            ("SELECT * FROM a JOIN b ON a.id = b.id", QueryType::JoinSelect),
        ];

        for &(sql, expected) in cases.iter() {
            assert_eq!(QueryType::from_sql(sql), expected);
        }
    }

    /// A read-dominated history is classified as OLAP or Mixed.
    #[test]
    fn test_classification_model() {
        let model = ClassificationModel::new();
        let mut h = QueryHistory::new();

        // Many reads followed by a few writes
        let reads = std::iter::repeat((QueryType::SimpleSelect, 5)).take(1000);
        let writes = std::iter::repeat((QueryType::Insert, 10)).take(50);
        for (kind, millis) in reads.chain(writes) {
            h.record(kind, Duration::from_millis(millis));
        }

        // High read ratio should indicate OLAP-ish workload
        let workload = model.classify_workload(&h);
        assert!(matches!(workload, WorkloadType::OLAP | WorkloadType::Mixed));
    }

    /// Recording through the classifier makes a classification available.
    #[test]
    fn test_learning_classifier() {
        let classifier = LearningClassifier::new(Arc::new(SchemaRegistry::new()));

        (0..100).for_each(|_| {
            classifier.record("users", QueryType::SimpleSelect, Duration::from_millis(5));
        });

        let result = classifier.get_classification("users");
        assert!(result.is_some());
        assert_eq!(result.as_ref().map(|c| c.query_count), Some(100));
    }

    /// Temperature classification yields one of the known temperatures.
    #[test]
    fn test_temperature_classification() {
        let model = ClassificationModel::new();
        let mut h = QueryHistory::new();

        // Simulate high QPM
        (0..1000).for_each(|_| h.record(QueryType::SimpleSelect, Duration::from_millis(1)));

        // QPM needs samples spread over time; here every query is recorded
        // almost instantly, so QPM may be 0 and the result Frozen. Any
        // temperature is therefore accepted.
        let temp = model.classify_temperature(&h);
        assert!(matches!(
            temp,
            DataTemperature::Hot
                | DataTemperature::Warm
                | DataTemperature::Cold
                | DataTemperature::Frozen
        ));
    }

    /// Average latency is positive and P95 is not below the average.
    #[test]
    fn test_latency_tracking() {
        let mut h = QueryHistory::new();
        (0..100).for_each(|millis| {
            h.record(QueryType::SimpleSelect, Duration::from_millis(millis));
        });

        let mean = h.avg_latency();
        assert!(mean.as_millis() > 0);

        let p95 = h.p95_latency();
        assert!(p95 >= mean);
    }
}