Skip to main content

heliosdb_proxy/distribcache/
classifier.rs

//! Workload classifier for intelligent caching decisions.
//!
//! Classifies queries into workload types (OLTP, OLAP, Vector, AI Agent, RAG)
//! so that an appropriate caching strategy can be applied to each.
5
6use dashmap::DashMap;
7use std::collections::VecDeque;
8use std::sync::atomic::{AtomicU64, Ordering};
9use std::time::{Duration, Instant};
10
11use super::{DistribCacheConfig, QueryContext, QueryFingerprint, SessionId};
12
/// Kind of workload a query belongs to; used to pick a caching strategy.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum WorkloadType {
    /// Online transaction processing: point lookups and short transactions
    /// where low latency matters most.
    OLTP,

    /// Online analytical processing: full scans and aggregations where
    /// throughput matters most.
    OLAP,

    /// Vector / embedding operations such as ANN and similarity search.
    Vector,

    /// AI-agent traffic: context retrieval, tool calls, conversation state.
    AIAgent,

    /// Retrieval-augmented generation pipelines (embedding, retrieval, reranking).
    RAG,

    /// Mixed traffic, or a query whose workload could not be determined.
    Mixed,
}
39
40/// Query history entry for session-based classification
41#[derive(Debug, Clone)]
42struct QueryHistoryEntry {
43    #[allow(dead_code)]
44    fingerprint: QueryFingerprint,
45    workload: WorkloadType,
46    timestamp: Instant,
47    latency_ms: u64,
48}
49
50/// Per-session query history
51#[derive(Debug)]
52struct SessionHistory {
53    /// Recent queries
54    queries: VecDeque<QueryHistoryEntry>,
55    /// Detected primary workload
56    primary_workload: Option<WorkloadType>,
57    /// Workload counts
58    oltp_count: u64,
59    olap_count: u64,
60    vector_count: u64,
61    ai_count: u64,
62    rag_count: u64,
63}
64
65impl SessionHistory {
66    fn new() -> Self {
67        Self {
68            queries: VecDeque::with_capacity(100),
69            primary_workload: None,
70            oltp_count: 0,
71            olap_count: 0,
72            vector_count: 0,
73            ai_count: 0,
74            rag_count: 0,
75        }
76    }
77
78    fn record(&mut self, entry: QueryHistoryEntry) {
79        // Update counts
80        match entry.workload {
81            WorkloadType::OLTP => self.oltp_count += 1,
82            WorkloadType::OLAP => self.olap_count += 1,
83            WorkloadType::Vector => self.vector_count += 1,
84            WorkloadType::AIAgent => self.ai_count += 1,
85            WorkloadType::RAG => self.rag_count += 1,
86            WorkloadType::Mixed => {}
87        }
88
89        // Add to history
90        self.queries.push_back(entry);
91        while self.queries.len() > 100 {
92            self.queries.pop_front();
93        }
94
95        // Update primary workload
96        self.primary_workload = self.determine_primary_workload();
97    }
98
99    fn determine_primary_workload(&self) -> Option<WorkloadType> {
100        let total =
101            self.oltp_count + self.olap_count + self.vector_count + self.ai_count + self.rag_count;
102
103        if total < 10 {
104            return None; // Not enough data
105        }
106
107        let max = *[
108            self.oltp_count,
109            self.olap_count,
110            self.vector_count,
111            self.ai_count,
112            self.rag_count,
113        ]
114        .iter()
115        .max()
116        .unwrap();
117
118        // Need > 50% to be considered primary
119        if max as f64 / total as f64 > 0.5 {
120            if max == self.oltp_count {
121                Some(WorkloadType::OLTP)
122            } else if max == self.olap_count {
123                Some(WorkloadType::OLAP)
124            } else if max == self.vector_count {
125                Some(WorkloadType::Vector)
126            } else if max == self.ai_count {
127                Some(WorkloadType::AIAgent)
128            } else {
129                Some(WorkloadType::RAG)
130            }
131        } else {
132            Some(WorkloadType::Mixed)
133        }
134    }
135}
136
137/// Classification rule for pattern matching
138#[derive(Debug, Clone)]
139pub struct ClassificationRule {
140    /// Rule name
141    pub name: String,
142    /// Patterns to match (SQL keywords/fragments)
143    pub patterns: Vec<String>,
144    /// Target workload type
145    pub workload: WorkloadType,
146    /// Rule priority (higher = checked first)
147    pub priority: u32,
148}
149
150/// Workload classifier
151pub struct WorkloadClassifier {
152    /// Configuration
153    #[allow(dead_code)]
154    config: DistribCacheConfig,
155
156    /// Classification rules (priority-ordered)
157    rules: Vec<ClassificationRule>,
158
159    /// Per-session query history
160    session_history: DashMap<SessionId, SessionHistory>,
161
162    /// Global statistics
163    stats: ClassifierStats,
164}
165
/// Atomic counters updated by the classifier; every counter starts at zero.
#[derive(Default, Debug)]
struct ClassifierStats {
    /// Calls to `classify`, including those answered by an explicit hint.
    total_classified: AtomicU64,
    /// Results classified as OLTP.
    oltp_count: AtomicU64,
    /// Results classified as OLAP.
    olap_count: AtomicU64,
    /// Results classified as Vector.
    vector_count: AtomicU64,
    /// Results classified as AIAgent.
    ai_count: AtomicU64,
    /// Results classified as RAG.
    rag_count: AtomicU64,
    /// Results classified as Mixed.
    mixed_count: AtomicU64,
    /// Queries resolved by a pattern rule.
    rule_hits: AtomicU64,
    /// Queries resolved from the session's primary workload.
    session_hits: AtomicU64,
    /// Queries resolved by the structural fallback.
    default_hits: AtomicU64,
}
180
181impl WorkloadClassifier {
182    /// Create a new workload classifier
183    pub fn new(config: DistribCacheConfig) -> Self {
184        let rules = Self::default_rules();
185
186        Self {
187            config,
188            rules,
189            session_history: DashMap::new(),
190            stats: ClassifierStats::default(),
191        }
192    }
193
194    /// Default classification rules
195    fn default_rules() -> Vec<ClassificationRule> {
196        vec![
197            // Vector operations (highest priority)
198            ClassificationRule {
199                name: "vector_similarity".to_string(),
200                patterns: vec![
201                    "<->".to_string(),
202                    "<#>".to_string(),
203                    "<=>".to_string(),
204                    "VECTOR".to_string(),
205                    "EMBEDDING".to_string(),
206                    "COSINE_SIMILARITY".to_string(),
207                    "L2_DISTANCE".to_string(),
208                    "INNER_PRODUCT".to_string(),
209                ],
210                workload: WorkloadType::Vector,
211                priority: 100,
212            },
213            // RAG patterns
214            ClassificationRule {
215                name: "rag_pipeline".to_string(),
216                patterns: vec![
217                    "CHUNKS".to_string(),
218                    "DOCUMENTS".to_string(),
219                    "RERANK".to_string(),
220                    "RETRIEVE".to_string(),
221                ],
222                workload: WorkloadType::RAG,
223                priority: 90,
224            },
225            // AI Agent patterns
226            ClassificationRule {
227                name: "ai_agent".to_string(),
228                patterns: vec![
229                    "CONVERSATION".to_string(),
230                    "AGENT_".to_string(),
231                    "TOOL_".to_string(),
232                    "CONTEXT".to_string(),
233                    "MEMORY".to_string(),
234                    "TURNS".to_string(),
235                ],
236                workload: WorkloadType::AIAgent,
237                priority: 85,
238            },
239            // OLAP patterns
240            ClassificationRule {
241                name: "olap_aggregation".to_string(),
242                patterns: vec![
243                    "GROUP BY".to_string(),
244                    "HAVING".to_string(),
245                    "COUNT(".to_string(),
246                    "SUM(".to_string(),
247                    "AVG(".to_string(),
248                    "MIN(".to_string(),
249                    "MAX(".to_string(),
250                    "STDDEV".to_string(),
251                    "VARIANCE".to_string(),
252                    "PERCENTILE".to_string(),
253                ],
254                workload: WorkloadType::OLAP,
255                priority: 70,
256            },
257            ClassificationRule {
258                name: "olap_analytics".to_string(),
259                patterns: vec![
260                    "WINDOW".to_string(),
261                    "OVER(".to_string(),
262                    "PARTITION BY".to_string(),
263                    "ROLLUP".to_string(),
264                    "CUBE".to_string(),
265                    "GROUPING".to_string(),
266                ],
267                workload: WorkloadType::OLAP,
268                priority: 70,
269            },
270            ClassificationRule {
271                name: "olap_large_scan".to_string(),
272                patterns: vec![
273                    "ANALYTICS".to_string(),
274                    "REPORT".to_string(),
275                    "DASHBOARD".to_string(),
276                    "METRIC".to_string(),
277                ],
278                workload: WorkloadType::OLAP,
279                priority: 60,
280            },
281            // OLTP patterns (lower priority, broader match)
282            ClassificationRule {
283                name: "oltp_point_lookup".to_string(),
284                patterns: vec![
285                    "WHERE ID =".to_string(),
286                    "WHERE ID=".to_string(),
287                    "BY ID".to_string(),
288                    "LIMIT 1".to_string(),
289                ],
290                workload: WorkloadType::OLTP,
291                priority: 50,
292            },
293        ]
294    }
295
296    /// Classify a query based on patterns and session history
297    pub fn classify(&self, query: &str, context: &QueryContext) -> WorkloadType {
298        self.stats.total_classified.fetch_add(1, Ordering::Relaxed);
299
300        // 1. Check explicit hint
301        if let Some(hint) = context.workload_hint {
302            return hint;
303        }
304
305        // 2. Pattern-based classification
306        if let Some(workload) = self.classify_by_pattern(query) {
307            self.stats.rule_hits.fetch_add(1, Ordering::Relaxed);
308            self.record_query(context, query, workload);
309            return workload;
310        }
311
312        // 3. Session history-based classification
313        if let Some(workload) = self.classify_by_session(&context.session_id) {
314            self.stats.session_hits.fetch_add(1, Ordering::Relaxed);
315            self.record_query(context, query, workload);
316            return workload;
317        }
318
319        // 4. Default classification based on query structure
320        let workload = self.classify_by_structure(query);
321        self.stats.default_hits.fetch_add(1, Ordering::Relaxed);
322        self.record_query(context, query, workload);
323        workload
324    }
325
326    /// Simplified classify method for string queries
327    pub fn classify_query(&self, query: &str, context: &QueryContext) -> WorkloadType {
328        self.classify(query, context)
329    }
330
331    /// Classify based on pattern rules
332    fn classify_by_pattern(&self, query: &str) -> Option<WorkloadType> {
333        let upper = query.to_uppercase();
334
335        // Check rules in priority order
336        let mut sorted_rules = self.rules.clone();
337        sorted_rules.sort_by_key(|b| std::cmp::Reverse(b.priority));
338
339        for rule in &sorted_rules {
340            for pattern in &rule.patterns {
341                if upper.contains(pattern) {
342                    return Some(rule.workload);
343                }
344            }
345        }
346
347        None
348    }
349
350    /// Classify based on session history
351    fn classify_by_session(&self, session_id: &SessionId) -> Option<WorkloadType> {
352        self.session_history
353            .get(session_id)
354            .and_then(|history| history.primary_workload)
355    }
356
357    /// Classify based on query structure (fallback)
358    fn classify_by_structure(&self, query: &str) -> WorkloadType {
359        let upper = query.to_uppercase();
360
361        // Simple heuristics
362        if upper.starts_with("INSERT") || upper.starts_with("UPDATE") || upper.starts_with("DELETE")
363        {
364            return WorkloadType::OLTP;
365        }
366
367        // Full table scan likely OLAP
368        if upper.contains("SELECT") && !upper.contains("WHERE") && !upper.contains("LIMIT") {
369            return WorkloadType::OLAP;
370        }
371
372        // JOIN heavy likely OLAP
373        let join_count = upper.matches("JOIN").count();
374        if join_count >= 3 {
375            return WorkloadType::OLAP;
376        }
377
378        // Default to mixed
379        WorkloadType::Mixed
380    }
381
382    /// Record a query for history tracking
383    fn record_query(&self, context: &QueryContext, query: &str, workload: WorkloadType) {
384        // Update global stats
385        match workload {
386            WorkloadType::OLTP => self.stats.oltp_count.fetch_add(1, Ordering::Relaxed),
387            WorkloadType::OLAP => self.stats.olap_count.fetch_add(1, Ordering::Relaxed),
388            WorkloadType::Vector => self.stats.vector_count.fetch_add(1, Ordering::Relaxed),
389            WorkloadType::AIAgent => self.stats.ai_count.fetch_add(1, Ordering::Relaxed),
390            WorkloadType::RAG => self.stats.rag_count.fetch_add(1, Ordering::Relaxed),
391            WorkloadType::Mixed => self.stats.mixed_count.fetch_add(1, Ordering::Relaxed),
392        };
393
394        // Update session history
395        let entry = QueryHistoryEntry {
396            fingerprint: QueryFingerprint::from_query(query),
397            workload,
398            timestamp: Instant::now(),
399            latency_ms: 0, // Will be updated later
400        };
401
402        self.session_history
403            .entry(context.session_id.clone())
404            .or_insert_with(SessionHistory::new)
405            .record(entry);
406    }
407
408    /// Record query latency (call after execution)
409    pub fn record_latency(&self, session_id: &SessionId, latency_ms: u64) {
410        if let Some(mut history) = self.session_history.get_mut(session_id) {
411            if let Some(last) = history.queries.back_mut() {
412                last.latency_ms = latency_ms;
413            }
414        }
415    }
416
417    /// Add a custom classification rule
418    pub fn add_rule(&mut self, rule: ClassificationRule) {
419        self.rules.push(rule);
420    }
421
422    /// Get classifier statistics
423    pub fn stats(&self) -> ClassifierStatsSnapshot {
424        ClassifierStatsSnapshot {
425            total_classified: self.stats.total_classified.load(Ordering::Relaxed),
426            oltp_count: self.stats.oltp_count.load(Ordering::Relaxed),
427            olap_count: self.stats.olap_count.load(Ordering::Relaxed),
428            vector_count: self.stats.vector_count.load(Ordering::Relaxed),
429            ai_count: self.stats.ai_count.load(Ordering::Relaxed),
430            rag_count: self.stats.rag_count.load(Ordering::Relaxed),
431            mixed_count: self.stats.mixed_count.load(Ordering::Relaxed),
432            rule_hit_rate: self.stats.rule_hits.load(Ordering::Relaxed) as f64
433                / self.stats.total_classified.load(Ordering::Relaxed).max(1) as f64,
434            session_hit_rate: self.stats.session_hits.load(Ordering::Relaxed) as f64
435                / self.stats.total_classified.load(Ordering::Relaxed).max(1) as f64,
436        }
437    }
438
439    /// Clear session history older than threshold
440    pub fn cleanup_old_sessions(&self, max_age: Duration) {
441        let now = Instant::now();
442        self.session_history.retain(|_, history| {
443            if let Some(last) = history.queries.back() {
444                now.duration_since(last.timestamp) < max_age
445            } else {
446                false
447            }
448        });
449    }
450}
451
/// Plain-value copy of the classifier counters at one point in time.
#[derive(Clone, Debug)]
pub struct ClassifierStatsSnapshot {
    /// Total number of `classify` calls.
    pub total_classified: u64,
    /// Results classified as OLTP.
    pub oltp_count: u64,
    /// Results classified as OLAP.
    pub olap_count: u64,
    /// Results classified as Vector.
    pub vector_count: u64,
    /// Results classified as AIAgent.
    pub ai_count: u64,
    /// Results classified as RAG.
    pub rag_count: u64,
    /// Results classified as Mixed.
    pub mixed_count: u64,
    /// Fraction of classifications resolved by a pattern rule.
    pub rule_hit_rate: f64,
    /// Fraction of classifications resolved from session history.
    pub session_hit_rate: f64,
}
465
#[cfg(test)]
mod tests {
    use super::*;

    /// Query context bound to a fixed test session.
    fn make_context() -> QueryContext {
        QueryContext::new("test-session")
    }

    #[test]
    fn test_oltp_classification() {
        let config = DistribCacheConfig::default();
        let classifier = WorkloadClassifier::new(config);
        let ctx = make_context();

        let workload = classifier.classify("SELECT * FROM users WHERE id = 42", &ctx);
        assert_eq!(workload, WorkloadType::OLTP);

        let workload = classifier.classify("INSERT INTO users (name) VALUES ('Alice')", &ctx);
        assert_eq!(workload, WorkloadType::OLTP);
    }

    #[test]
    fn test_olap_classification() {
        let config = DistribCacheConfig::default();
        let classifier = WorkloadClassifier::new(config);
        let ctx = make_context();

        let workload =
            classifier.classify("SELECT region, COUNT(*) FROM orders GROUP BY region", &ctx);
        assert_eq!(workload, WorkloadType::OLAP);

        let workload = classifier.classify("SELECT AVG(amount), SUM(quantity) FROM sales", &ctx);
        assert_eq!(workload, WorkloadType::OLAP);
    }

    #[test]
    fn test_vector_classification() {
        let config = DistribCacheConfig::default();
        let classifier = WorkloadClassifier::new(config);
        let ctx = make_context();

        let workload = classifier.classify(
            "SELECT * FROM embeddings ORDER BY vector <-> $1 LIMIT 10",
            &ctx,
        );
        assert_eq!(workload, WorkloadType::Vector);
    }

    #[test]
    fn test_ai_agent_classification() {
        let config = DistribCacheConfig::default();
        let classifier = WorkloadClassifier::new(config);
        let ctx = make_context();

        let workload = classifier.classify(
            "SELECT * FROM conversation_turns WHERE conversation_id = $1",
            &ctx,
        );
        assert_eq!(workload, WorkloadType::AIAgent);

        let workload = classifier.classify(
            "INSERT INTO agent_memory (key, value) VALUES ($1, $2)",
            &ctx,
        );
        assert_eq!(workload, WorkloadType::AIAgent);
    }

    #[test]
    fn test_rag_classification() {
        let config = DistribCacheConfig::default();
        let classifier = WorkloadClassifier::new(config);
        let ctx = make_context();

        let workload = classifier.classify(
            "SELECT content FROM documents WHERE id IN (SELECT doc_id FROM chunks WHERE ...)",
            &ctx,
        );
        assert_eq!(workload, WorkloadType::RAG);
    }

    #[test]
    fn test_explicit_hint() {
        let config = DistribCacheConfig::default();
        let classifier = WorkloadClassifier::new(config);
        let ctx = make_context().with_workload_hint(WorkloadType::OLAP);

        // Even though query looks like OLTP, hint overrides
        let workload = classifier.classify("SELECT * FROM users WHERE id = 1", &ctx);
        assert_eq!(workload, WorkloadType::OLAP);
    }

    #[test]
    fn test_session_based_classification() {
        let config = DistribCacheConfig::default();
        let classifier = WorkloadClassifier::new(config);
        let ctx = make_context();

        // Run many OLAP queries to establish session pattern
        for _ in 0..20 {
            classifier.classify("SELECT COUNT(*) FROM analytics GROUP BY region", &ctx);
        }

        {
            // Scoped so the DashMap read guard is released before `classify` below
            // needs write access to the same shard (holding it would deadlock).
            let history = classifier.session_history.get(&ctx.session_id).unwrap();
            assert!(history.olap_count >= 20);
            assert_eq!(history.primary_workload, Some(WorkloadType::OLAP));
        }

        // This query matches no rule and on structure alone would be `Mixed`, so an
        // OLAP result can only come from the session history.
        let workload = classifier.classify("SELECT name FROM users WHERE email = $1", &ctx);
        assert_eq!(workload, WorkloadType::OLAP);
    }

    #[test]
    fn test_custom_rule_takes_priority() {
        let config = DistribCacheConfig::default();
        let mut classifier = WorkloadClassifier::new(config);
        let ctx = make_context();

        classifier.add_rule(ClassificationRule {
            name: "special".to_string(),
            patterns: vec!["SPECIAL_TABLE".to_string()],
            workload: WorkloadType::RAG,
            priority: 200,
        });

        // Would otherwise match the OLTP point-lookup rule.
        let workload = classifier.classify("SELECT * FROM special_table WHERE id = 1", &ctx);
        assert_eq!(workload, WorkloadType::RAG);
    }

    #[test]
    fn test_stats() {
        let config = DistribCacheConfig::default();
        let classifier = WorkloadClassifier::new(config);
        let ctx = make_context();

        classifier.classify("SELECT * FROM users WHERE id = 1", &ctx);
        classifier.classify("SELECT COUNT(*) FROM orders GROUP BY status", &ctx);
        classifier.classify("SELECT * FROM embeddings ORDER BY vec <-> $1", &ctx);

        let stats = classifier.stats();
        assert_eq!(stats.total_classified, 3);
        assert_eq!(stats.oltp_count, 1);
        assert_eq!(stats.olap_count, 1);
        assert_eq!(stats.vector_count, 1);
        // All three were resolved by pattern rules.
        assert!((stats.rule_hit_rate - 1.0).abs() < 1e-12);
        assert!(stats.session_hit_rate.abs() < 1e-12);
    }
}
585
586        let stats = classifier.stats();
587        assert_eq!(stats.total_classified, 3);
588    }
589}