Skip to main content

heliosdb_proxy/distribcache/
classifier.rs

//! Workload classifier for intelligent caching decisions.
//!
//! Classifies queries into workload types (OLTP, OLAP, Vector, AI Agent, RAG)
//! so that an appropriate caching strategy can be applied to each.
5
6use dashmap::DashMap;
7use std::collections::VecDeque;
8use std::sync::atomic::{AtomicU64, Ordering};
9use std::time::{Duration, Instant};
10
11use super::{DistribCacheConfig, QueryContext, SessionId, QueryFingerprint};
12
/// Category assigned to a query; used to select a cache strategy.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum WorkloadType {
    /// Online Transaction Processing.
    ///
    /// Typical traits: point lookups, short transactions, low latency.
    OLTP,

    /// Online Analytical Processing.
    ///
    /// Typical traits: full scans, aggregations, high throughput.
    OLAP,

    /// Vector / embedding operations.
    ///
    /// Typical traits: ANN search, similarity queries.
    Vector,

    /// AI agent workloads.
    ///
    /// Typical traits: context retrieval, tool calls, conversation state.
    AIAgent,

    /// Retrieval-augmented generation pipeline.
    ///
    /// Typical traits: embedding + retrieval + reranking.
    RAG,

    /// No single category applies (mixed or unknown workload).
    Mixed,
}
39
40/// Query history entry for session-based classification
41#[derive(Debug, Clone)]
42struct QueryHistoryEntry {
43    fingerprint: QueryFingerprint,
44    workload: WorkloadType,
45    timestamp: Instant,
46    latency_ms: u64,
47}
48
49/// Per-session query history
50#[derive(Debug)]
51struct SessionHistory {
52    /// Recent queries
53    queries: VecDeque<QueryHistoryEntry>,
54    /// Detected primary workload
55    primary_workload: Option<WorkloadType>,
56    /// Workload counts
57    oltp_count: u64,
58    olap_count: u64,
59    vector_count: u64,
60    ai_count: u64,
61    rag_count: u64,
62}
63
64impl SessionHistory {
65    fn new() -> Self {
66        Self {
67            queries: VecDeque::with_capacity(100),
68            primary_workload: None,
69            oltp_count: 0,
70            olap_count: 0,
71            vector_count: 0,
72            ai_count: 0,
73            rag_count: 0,
74        }
75    }
76
77    fn record(&mut self, entry: QueryHistoryEntry) {
78        // Update counts
79        match entry.workload {
80            WorkloadType::OLTP => self.oltp_count += 1,
81            WorkloadType::OLAP => self.olap_count += 1,
82            WorkloadType::Vector => self.vector_count += 1,
83            WorkloadType::AIAgent => self.ai_count += 1,
84            WorkloadType::RAG => self.rag_count += 1,
85            WorkloadType::Mixed => {}
86        }
87
88        // Add to history
89        self.queries.push_back(entry);
90        while self.queries.len() > 100 {
91            self.queries.pop_front();
92        }
93
94        // Update primary workload
95        self.primary_workload = self.determine_primary_workload();
96    }
97
98    fn determine_primary_workload(&self) -> Option<WorkloadType> {
99        let total = self.oltp_count + self.olap_count + self.vector_count +
100                    self.ai_count + self.rag_count;
101
102        if total < 10 {
103            return None; // Not enough data
104        }
105
106        let max = *[
107            self.oltp_count,
108            self.olap_count,
109            self.vector_count,
110            self.ai_count,
111            self.rag_count,
112        ].iter().max().unwrap();
113
114        // Need > 50% to be considered primary
115        if max as f64 / total as f64 > 0.5 {
116            if max == self.oltp_count {
117                Some(WorkloadType::OLTP)
118            } else if max == self.olap_count {
119                Some(WorkloadType::OLAP)
120            } else if max == self.vector_count {
121                Some(WorkloadType::Vector)
122            } else if max == self.ai_count {
123                Some(WorkloadType::AIAgent)
124            } else {
125                Some(WorkloadType::RAG)
126            }
127        } else {
128            Some(WorkloadType::Mixed)
129        }
130    }
131}
132
133/// Classification rule for pattern matching
134#[derive(Debug, Clone)]
135pub struct ClassificationRule {
136    /// Rule name
137    pub name: String,
138    /// Patterns to match (SQL keywords/fragments)
139    pub patterns: Vec<String>,
140    /// Target workload type
141    pub workload: WorkloadType,
142    /// Rule priority (higher = checked first)
143    pub priority: u32,
144}
145
146/// Workload classifier
147pub struct WorkloadClassifier {
148    /// Configuration
149    config: DistribCacheConfig,
150
151    /// Classification rules (priority-ordered)
152    rules: Vec<ClassificationRule>,
153
154    /// Per-session query history
155    session_history: DashMap<SessionId, SessionHistory>,
156
157    /// Global statistics
158    stats: ClassifierStats,
159}
160
/// Live classifier counters; every field is an atomic so the counters can
/// be updated through a shared reference.
#[derive(Debug, Default)]
struct ClassifierStats {
    /// Calls to `classify`, including hinted ones.
    total_classified: AtomicU64,
    /// Queries recorded as OLTP.
    oltp_count: AtomicU64,
    /// Queries recorded as OLAP.
    olap_count: AtomicU64,
    /// Queries recorded as Vector.
    vector_count: AtomicU64,
    /// Queries recorded as AIAgent.
    ai_count: AtomicU64,
    /// Queries recorded as RAG.
    rag_count: AtomicU64,
    /// Queries recorded as Mixed.
    mixed_count: AtomicU64,
    /// Classifications decided by a pattern rule.
    rule_hits: AtomicU64,
    /// Classifications decided by session history.
    session_hits: AtomicU64,
    /// Classifications decided by the structural fallback.
    default_hits: AtomicU64,
}
175
176impl WorkloadClassifier {
177    /// Create a new workload classifier
178    pub fn new(config: DistribCacheConfig) -> Self {
179        let rules = Self::default_rules();
180
181        Self {
182            config,
183            rules,
184            session_history: DashMap::new(),
185            stats: ClassifierStats::default(),
186        }
187    }
188
189    /// Default classification rules
190    fn default_rules() -> Vec<ClassificationRule> {
191        vec![
192            // Vector operations (highest priority)
193            ClassificationRule {
194                name: "vector_similarity".to_string(),
195                patterns: vec![
196                    "<->".to_string(),
197                    "<#>".to_string(),
198                    "<=>".to_string(),
199                    "VECTOR".to_string(),
200                    "EMBEDDING".to_string(),
201                    "COSINE_SIMILARITY".to_string(),
202                    "L2_DISTANCE".to_string(),
203                    "INNER_PRODUCT".to_string(),
204                ],
205                workload: WorkloadType::Vector,
206                priority: 100,
207            },
208            // RAG patterns
209            ClassificationRule {
210                name: "rag_pipeline".to_string(),
211                patterns: vec![
212                    "CHUNKS".to_string(),
213                    "DOCUMENTS".to_string(),
214                    "RERANK".to_string(),
215                    "RETRIEVE".to_string(),
216                ],
217                workload: WorkloadType::RAG,
218                priority: 90,
219            },
220            // AI Agent patterns
221            ClassificationRule {
222                name: "ai_agent".to_string(),
223                patterns: vec![
224                    "CONVERSATION".to_string(),
225                    "AGENT_".to_string(),
226                    "TOOL_".to_string(),
227                    "CONTEXT".to_string(),
228                    "MEMORY".to_string(),
229                    "TURNS".to_string(),
230                ],
231                workload: WorkloadType::AIAgent,
232                priority: 85,
233            },
234            // OLAP patterns
235            ClassificationRule {
236                name: "olap_aggregation".to_string(),
237                patterns: vec![
238                    "GROUP BY".to_string(),
239                    "HAVING".to_string(),
240                    "COUNT(".to_string(),
241                    "SUM(".to_string(),
242                    "AVG(".to_string(),
243                    "MIN(".to_string(),
244                    "MAX(".to_string(),
245                    "STDDEV".to_string(),
246                    "VARIANCE".to_string(),
247                    "PERCENTILE".to_string(),
248                ],
249                workload: WorkloadType::OLAP,
250                priority: 70,
251            },
252            ClassificationRule {
253                name: "olap_analytics".to_string(),
254                patterns: vec![
255                    "WINDOW".to_string(),
256                    "OVER(".to_string(),
257                    "PARTITION BY".to_string(),
258                    "ROLLUP".to_string(),
259                    "CUBE".to_string(),
260                    "GROUPING".to_string(),
261                ],
262                workload: WorkloadType::OLAP,
263                priority: 70,
264            },
265            ClassificationRule {
266                name: "olap_large_scan".to_string(),
267                patterns: vec![
268                    "ANALYTICS".to_string(),
269                    "REPORT".to_string(),
270                    "DASHBOARD".to_string(),
271                    "METRIC".to_string(),
272                ],
273                workload: WorkloadType::OLAP,
274                priority: 60,
275            },
276            // OLTP patterns (lower priority, broader match)
277            ClassificationRule {
278                name: "oltp_point_lookup".to_string(),
279                patterns: vec![
280                    "WHERE ID =".to_string(),
281                    "WHERE ID=".to_string(),
282                    "BY ID".to_string(),
283                    "LIMIT 1".to_string(),
284                ],
285                workload: WorkloadType::OLTP,
286                priority: 50,
287            },
288        ]
289    }
290
291    /// Classify a query based on patterns and session history
292    pub fn classify(&self, query: &str, context: &QueryContext) -> WorkloadType {
293        self.stats.total_classified.fetch_add(1, Ordering::Relaxed);
294
295        // 1. Check explicit hint
296        if let Some(hint) = context.workload_hint {
297            return hint;
298        }
299
300        // 2. Pattern-based classification
301        if let Some(workload) = self.classify_by_pattern(query) {
302            self.stats.rule_hits.fetch_add(1, Ordering::Relaxed);
303            self.record_query(context, query, workload);
304            return workload;
305        }
306
307        // 3. Session history-based classification
308        if let Some(workload) = self.classify_by_session(&context.session_id) {
309            self.stats.session_hits.fetch_add(1, Ordering::Relaxed);
310            self.record_query(context, query, workload);
311            return workload;
312        }
313
314        // 4. Default classification based on query structure
315        let workload = self.classify_by_structure(query);
316        self.stats.default_hits.fetch_add(1, Ordering::Relaxed);
317        self.record_query(context, query, workload);
318        workload
319    }
320
321    /// Simplified classify method for string queries
322    pub fn classify_query(&self, query: &str, context: &QueryContext) -> WorkloadType {
323        self.classify(query, context)
324    }
325
326    /// Classify based on pattern rules
327    fn classify_by_pattern(&self, query: &str) -> Option<WorkloadType> {
328        let upper = query.to_uppercase();
329
330        // Check rules in priority order
331        let mut sorted_rules = self.rules.clone();
332        sorted_rules.sort_by(|a, b| b.priority.cmp(&a.priority));
333
334        for rule in &sorted_rules {
335            for pattern in &rule.patterns {
336                if upper.contains(pattern) {
337                    return Some(rule.workload);
338                }
339            }
340        }
341
342        None
343    }
344
345    /// Classify based on session history
346    fn classify_by_session(&self, session_id: &SessionId) -> Option<WorkloadType> {
347        self.session_history
348            .get(session_id)
349            .and_then(|history| history.primary_workload)
350    }
351
352    /// Classify based on query structure (fallback)
353    fn classify_by_structure(&self, query: &str) -> WorkloadType {
354        let upper = query.to_uppercase();
355
356        // Simple heuristics
357        if upper.starts_with("INSERT") || upper.starts_with("UPDATE") ||
358           upper.starts_with("DELETE") {
359            return WorkloadType::OLTP;
360        }
361
362        // Full table scan likely OLAP
363        if upper.contains("SELECT") && !upper.contains("WHERE") && !upper.contains("LIMIT") {
364            return WorkloadType::OLAP;
365        }
366
367        // JOIN heavy likely OLAP
368        let join_count = upper.matches("JOIN").count();
369        if join_count >= 3 {
370            return WorkloadType::OLAP;
371        }
372
373        // Default to mixed
374        WorkloadType::Mixed
375    }
376
377    /// Record a query for history tracking
378    fn record_query(&self, context: &QueryContext, query: &str, workload: WorkloadType) {
379        // Update global stats
380        match workload {
381            WorkloadType::OLTP => self.stats.oltp_count.fetch_add(1, Ordering::Relaxed),
382            WorkloadType::OLAP => self.stats.olap_count.fetch_add(1, Ordering::Relaxed),
383            WorkloadType::Vector => self.stats.vector_count.fetch_add(1, Ordering::Relaxed),
384            WorkloadType::AIAgent => self.stats.ai_count.fetch_add(1, Ordering::Relaxed),
385            WorkloadType::RAG => self.stats.rag_count.fetch_add(1, Ordering::Relaxed),
386            WorkloadType::Mixed => self.stats.mixed_count.fetch_add(1, Ordering::Relaxed),
387        };
388
389        // Update session history
390        let entry = QueryHistoryEntry {
391            fingerprint: QueryFingerprint::from_query(query),
392            workload,
393            timestamp: Instant::now(),
394            latency_ms: 0, // Will be updated later
395        };
396
397        self.session_history
398            .entry(context.session_id.clone())
399            .or_insert_with(SessionHistory::new)
400            .record(entry);
401    }
402
403    /// Record query latency (call after execution)
404    pub fn record_latency(&self, session_id: &SessionId, latency_ms: u64) {
405        if let Some(mut history) = self.session_history.get_mut(session_id) {
406            if let Some(last) = history.queries.back_mut() {
407                last.latency_ms = latency_ms;
408            }
409        }
410    }
411
412    /// Add a custom classification rule
413    pub fn add_rule(&mut self, rule: ClassificationRule) {
414        self.rules.push(rule);
415    }
416
417    /// Get classifier statistics
418    pub fn stats(&self) -> ClassifierStatsSnapshot {
419        ClassifierStatsSnapshot {
420            total_classified: self.stats.total_classified.load(Ordering::Relaxed),
421            oltp_count: self.stats.oltp_count.load(Ordering::Relaxed),
422            olap_count: self.stats.olap_count.load(Ordering::Relaxed),
423            vector_count: self.stats.vector_count.load(Ordering::Relaxed),
424            ai_count: self.stats.ai_count.load(Ordering::Relaxed),
425            rag_count: self.stats.rag_count.load(Ordering::Relaxed),
426            mixed_count: self.stats.mixed_count.load(Ordering::Relaxed),
427            rule_hit_rate: self.stats.rule_hits.load(Ordering::Relaxed) as f64 /
428                          self.stats.total_classified.load(Ordering::Relaxed).max(1) as f64,
429            session_hit_rate: self.stats.session_hits.load(Ordering::Relaxed) as f64 /
430                             self.stats.total_classified.load(Ordering::Relaxed).max(1) as f64,
431        }
432    }
433
434    /// Clear session history older than threshold
435    pub fn cleanup_old_sessions(&self, max_age: Duration) {
436        let now = Instant::now();
437        self.session_history.retain(|_, history| {
438            if let Some(last) = history.queries.back() {
439                now.duration_since(last.timestamp) < max_age
440            } else {
441                false
442            }
443        });
444    }
445}
446
/// Point-in-time copy of the classifier counters, as plain values.
#[derive(Debug, Clone)]
pub struct ClassifierStatsSnapshot {
    /// Total number of `classify` calls.
    pub total_classified: u64,
    /// Queries recorded as OLTP.
    pub oltp_count: u64,
    /// Queries recorded as OLAP.
    pub olap_count: u64,
    /// Queries recorded as Vector.
    pub vector_count: u64,
    /// Queries recorded as AIAgent.
    pub ai_count: u64,
    /// Queries recorded as RAG.
    pub rag_count: u64,
    /// Queries recorded as Mixed.
    pub mixed_count: u64,
    /// Fraction of classifications decided by a pattern rule.
    pub rule_hit_rate: f64,
    /// Fraction of classifications decided by session history.
    pub session_hit_rate: f64,
}
460
#[cfg(test)]
mod tests {
    use super::*;

    /// Context bound to a fixed test session, without a workload hint.
    fn make_context() -> QueryContext {
        QueryContext::new("test-session")
    }

    /// Classifier with the default configuration and built-in rules.
    fn make_classifier() -> WorkloadClassifier {
        WorkloadClassifier::new(DistribCacheConfig::default())
    }

    #[test]
    fn test_oltp_classification() {
        let classifier = make_classifier();
        let ctx = make_context();

        // Matched by the point-lookup rule.
        let workload = classifier.classify("SELECT * FROM users WHERE id = 42", &ctx);
        assert_eq!(workload, WorkloadType::OLTP);

        // No rule matches; the structural fallback treats DML as OLTP.
        let workload = classifier.classify("INSERT INTO users (name) VALUES ('Alice')", &ctx);
        assert_eq!(workload, WorkloadType::OLTP);
    }

    #[test]
    fn test_olap_classification() {
        let classifier = make_classifier();
        let ctx = make_context();

        let workload = classifier.classify(
            "SELECT region, COUNT(*) FROM orders GROUP BY region",
            &ctx,
        );
        assert_eq!(workload, WorkloadType::OLAP);

        let workload = classifier.classify(
            "SELECT AVG(amount), SUM(quantity) FROM sales",
            &ctx,
        );
        assert_eq!(workload, WorkloadType::OLAP);
    }

    #[test]
    fn test_vector_classification() {
        let classifier = make_classifier();
        let ctx = make_context();

        let workload = classifier.classify(
            "SELECT * FROM embeddings ORDER BY vector <-> $1 LIMIT 10",
            &ctx,
        );
        assert_eq!(workload, WorkloadType::Vector);
    }

    #[test]
    fn test_ai_agent_classification() {
        let classifier = make_classifier();
        let ctx = make_context();

        let workload = classifier.classify(
            "SELECT * FROM conversation_turns WHERE conversation_id = $1",
            &ctx,
        );
        assert_eq!(workload, WorkloadType::AIAgent);

        let workload = classifier.classify(
            "INSERT INTO agent_memory (key, value) VALUES ($1, $2)",
            &ctx,
        );
        assert_eq!(workload, WorkloadType::AIAgent);
    }

    #[test]
    fn test_rag_classification() {
        let classifier = make_classifier();
        let ctx = make_context();

        let workload = classifier.classify(
            "SELECT content FROM documents WHERE id IN (SELECT doc_id FROM chunks WHERE ...)",
            &ctx,
        );
        assert_eq!(workload, WorkloadType::RAG);
    }

    #[test]
    fn test_explicit_hint() {
        let classifier = make_classifier();
        let ctx = make_context().with_workload_hint(WorkloadType::OLAP);

        // Even though query looks like OLTP, hint overrides
        let workload = classifier.classify("SELECT * FROM users WHERE id = 1", &ctx);
        assert_eq!(workload, WorkloadType::OLAP);
    }

    #[test]
    fn test_session_based_classification() {
        let classifier = make_classifier();
        let ctx = make_context();

        // Run many OLAP queries to establish session pattern
        for _ in 0..20 {
            classifier.classify("SELECT COUNT(*) FROM analytics GROUP BY region", &ctx);
        }

        // Scope the map guard: it must be released before `classify` takes
        // a write handle on the same session entry below.
        {
            let history = classifier.session_history.get(&ctx.session_id).unwrap();
            assert!(history.olap_count >= 20);
            assert_eq!(history.primary_workload, Some(WorkloadType::OLAP));
        }

        // Now an ambiguous query (no rule matches, has a WHERE clause)
        // should be classified as OLAP based on session history.
        let workload = classifier.classify("SELECT name FROM users WHERE age > 30", &ctx);
        assert_eq!(workload, WorkloadType::OLAP);
    }

    #[test]
    fn test_custom_rule_priority() {
        let mut classifier = make_classifier();
        let ctx = make_context();

        // Priority 200 outranks the built-in vector rule (priority 100),
        // whose "EMBEDDING" pattern also matches the query below.
        classifier.add_rule(ClassificationRule {
            name: "ledger".to_string(),
            patterns: vec!["LEDGER".to_string()],
            workload: WorkloadType::OLTP,
            priority: 200,
        });

        let workload = classifier.classify("SELECT * FROM ledger_embedding", &ctx);
        assert_eq!(workload, WorkloadType::OLTP);
    }

    #[test]
    fn test_stats() {
        let classifier = make_classifier();
        let ctx = make_context();

        classifier.classify("SELECT * FROM users WHERE id = 1", &ctx);
        classifier.classify("SELECT COUNT(*) FROM orders GROUP BY status", &ctx);
        classifier.classify("SELECT * FROM embeddings ORDER BY vec <-> $1", &ctx);

        let stats = classifier.stats();
        assert_eq!(stats.total_classified, 3);
    }
}