1use dashmap::DashMap;
7use std::collections::VecDeque;
8use std::sync::atomic::{AtomicU64, Ordering};
9use std::time::{Duration, Instant};
10
11use super::{DistribCacheConfig, QueryContext, QueryFingerprint, SessionId};
12
/// Category of database workload that a query is attributed to.
///
/// Variant names are part of the public interface and are kept verbatim.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum WorkloadType {
    /// Transactional work. The built-in rules map point lookups here and the
    /// structural fallback maps `INSERT`/`UPDATE`/`DELETE` statements here.
    OLTP,

    /// Analytical work. The built-in rules map aggregations, window functions
    /// and reporting-style queries here.
    OLAP,

    /// Vector search. The built-in rules map distance operators and
    /// embedding-related keywords here.
    Vector,

    /// AI-agent traffic. The built-in rules map conversation, tool, context
    /// and memory related queries here.
    AIAgent,

    /// Retrieval pipelines. The built-in rules map chunk/document retrieval
    /// and reranking queries here.
    RAG,

    /// No single category applies: used when neither rules, session history
    /// nor query structure single out one workload.
    Mixed,
}
39
40#[derive(Debug, Clone)]
42struct QueryHistoryEntry {
43 #[allow(dead_code)]
44 fingerprint: QueryFingerprint,
45 workload: WorkloadType,
46 timestamp: Instant,
47 latency_ms: u64,
48}
49
50#[derive(Debug)]
52struct SessionHistory {
53 queries: VecDeque<QueryHistoryEntry>,
55 primary_workload: Option<WorkloadType>,
57 oltp_count: u64,
59 olap_count: u64,
60 vector_count: u64,
61 ai_count: u64,
62 rag_count: u64,
63}
64
65impl SessionHistory {
66 fn new() -> Self {
67 Self {
68 queries: VecDeque::with_capacity(100),
69 primary_workload: None,
70 oltp_count: 0,
71 olap_count: 0,
72 vector_count: 0,
73 ai_count: 0,
74 rag_count: 0,
75 }
76 }
77
78 fn record(&mut self, entry: QueryHistoryEntry) {
79 match entry.workload {
81 WorkloadType::OLTP => self.oltp_count += 1,
82 WorkloadType::OLAP => self.olap_count += 1,
83 WorkloadType::Vector => self.vector_count += 1,
84 WorkloadType::AIAgent => self.ai_count += 1,
85 WorkloadType::RAG => self.rag_count += 1,
86 WorkloadType::Mixed => {}
87 }
88
89 self.queries.push_back(entry);
91 while self.queries.len() > 100 {
92 self.queries.pop_front();
93 }
94
95 self.primary_workload = self.determine_primary_workload();
97 }
98
99 fn determine_primary_workload(&self) -> Option<WorkloadType> {
100 let total =
101 self.oltp_count + self.olap_count + self.vector_count + self.ai_count + self.rag_count;
102
103 if total < 10 {
104 return None; }
106
107 let max = *[
108 self.oltp_count,
109 self.olap_count,
110 self.vector_count,
111 self.ai_count,
112 self.rag_count,
113 ]
114 .iter()
115 .max()
116 .unwrap();
117
118 if max as f64 / total as f64 > 0.5 {
120 if max == self.oltp_count {
121 Some(WorkloadType::OLTP)
122 } else if max == self.olap_count {
123 Some(WorkloadType::OLAP)
124 } else if max == self.vector_count {
125 Some(WorkloadType::Vector)
126 } else if max == self.ai_count {
127 Some(WorkloadType::AIAgent)
128 } else {
129 Some(WorkloadType::RAG)
130 }
131 } else {
132 Some(WorkloadType::Mixed)
133 }
134 }
135}
136
137#[derive(Debug, Clone)]
139pub struct ClassificationRule {
140 pub name: String,
142 pub patterns: Vec<String>,
144 pub workload: WorkloadType,
146 pub priority: u32,
148}
149
150pub struct WorkloadClassifier {
152 #[allow(dead_code)]
154 config: DistribCacheConfig,
155
156 rules: Vec<ClassificationRule>,
158
159 session_history: DashMap<SessionId, SessionHistory>,
161
162 stats: ClassifierStats,
164}
165
/// Lock-free counters updated on every classification.
#[derive(Debug, Default)]
struct ClassifierStats {
    /// Every call to `classify`, including calls answered by an explicit hint.
    total_classified: AtomicU64,
    /// Recorded queries classified as `OLTP`.
    oltp_count: AtomicU64,
    /// Recorded queries classified as `OLAP`.
    olap_count: AtomicU64,
    /// Recorded queries classified as `Vector`.
    vector_count: AtomicU64,
    /// Recorded queries classified as `AIAgent`.
    ai_count: AtomicU64,
    /// Recorded queries classified as `RAG`.
    rag_count: AtomicU64,
    /// Recorded queries classified as `Mixed`.
    mixed_count: AtomicU64,
    /// Classifications decided by a keyword rule.
    rule_hits: AtomicU64,
    /// Classifications decided by the session's primary workload.
    session_hits: AtomicU64,
    /// Classifications decided by the structural fallback; counted but not
    /// part of the public snapshot.
    default_hits: AtomicU64,
}
180
181impl WorkloadClassifier {
182 pub fn new(config: DistribCacheConfig) -> Self {
184 let rules = Self::default_rules();
185
186 Self {
187 config,
188 rules,
189 session_history: DashMap::new(),
190 stats: ClassifierStats::default(),
191 }
192 }
193
194 fn default_rules() -> Vec<ClassificationRule> {
196 vec![
197 ClassificationRule {
199 name: "vector_similarity".to_string(),
200 patterns: vec![
201 "<->".to_string(),
202 "<#>".to_string(),
203 "<=>".to_string(),
204 "VECTOR".to_string(),
205 "EMBEDDING".to_string(),
206 "COSINE_SIMILARITY".to_string(),
207 "L2_DISTANCE".to_string(),
208 "INNER_PRODUCT".to_string(),
209 ],
210 workload: WorkloadType::Vector,
211 priority: 100,
212 },
213 ClassificationRule {
215 name: "rag_pipeline".to_string(),
216 patterns: vec![
217 "CHUNKS".to_string(),
218 "DOCUMENTS".to_string(),
219 "RERANK".to_string(),
220 "RETRIEVE".to_string(),
221 ],
222 workload: WorkloadType::RAG,
223 priority: 90,
224 },
225 ClassificationRule {
227 name: "ai_agent".to_string(),
228 patterns: vec![
229 "CONVERSATION".to_string(),
230 "AGENT_".to_string(),
231 "TOOL_".to_string(),
232 "CONTEXT".to_string(),
233 "MEMORY".to_string(),
234 "TURNS".to_string(),
235 ],
236 workload: WorkloadType::AIAgent,
237 priority: 85,
238 },
239 ClassificationRule {
241 name: "olap_aggregation".to_string(),
242 patterns: vec![
243 "GROUP BY".to_string(),
244 "HAVING".to_string(),
245 "COUNT(".to_string(),
246 "SUM(".to_string(),
247 "AVG(".to_string(),
248 "MIN(".to_string(),
249 "MAX(".to_string(),
250 "STDDEV".to_string(),
251 "VARIANCE".to_string(),
252 "PERCENTILE".to_string(),
253 ],
254 workload: WorkloadType::OLAP,
255 priority: 70,
256 },
257 ClassificationRule {
258 name: "olap_analytics".to_string(),
259 patterns: vec![
260 "WINDOW".to_string(),
261 "OVER(".to_string(),
262 "PARTITION BY".to_string(),
263 "ROLLUP".to_string(),
264 "CUBE".to_string(),
265 "GROUPING".to_string(),
266 ],
267 workload: WorkloadType::OLAP,
268 priority: 70,
269 },
270 ClassificationRule {
271 name: "olap_large_scan".to_string(),
272 patterns: vec![
273 "ANALYTICS".to_string(),
274 "REPORT".to_string(),
275 "DASHBOARD".to_string(),
276 "METRIC".to_string(),
277 ],
278 workload: WorkloadType::OLAP,
279 priority: 60,
280 },
281 ClassificationRule {
283 name: "oltp_point_lookup".to_string(),
284 patterns: vec![
285 "WHERE ID =".to_string(),
286 "WHERE ID=".to_string(),
287 "BY ID".to_string(),
288 "LIMIT 1".to_string(),
289 ],
290 workload: WorkloadType::OLTP,
291 priority: 50,
292 },
293 ]
294 }
295
296 pub fn classify(&self, query: &str, context: &QueryContext) -> WorkloadType {
298 self.stats.total_classified.fetch_add(1, Ordering::Relaxed);
299
300 if let Some(hint) = context.workload_hint {
302 return hint;
303 }
304
305 if let Some(workload) = self.classify_by_pattern(query) {
307 self.stats.rule_hits.fetch_add(1, Ordering::Relaxed);
308 self.record_query(context, query, workload);
309 return workload;
310 }
311
312 if let Some(workload) = self.classify_by_session(&context.session_id) {
314 self.stats.session_hits.fetch_add(1, Ordering::Relaxed);
315 self.record_query(context, query, workload);
316 return workload;
317 }
318
319 let workload = self.classify_by_structure(query);
321 self.stats.default_hits.fetch_add(1, Ordering::Relaxed);
322 self.record_query(context, query, workload);
323 workload
324 }
325
326 pub fn classify_query(&self, query: &str, context: &QueryContext) -> WorkloadType {
328 self.classify(query, context)
329 }
330
331 fn classify_by_pattern(&self, query: &str) -> Option<WorkloadType> {
333 let upper = query.to_uppercase();
334
335 let mut sorted_rules = self.rules.clone();
337 sorted_rules.sort_by_key(|b| std::cmp::Reverse(b.priority));
338
339 for rule in &sorted_rules {
340 for pattern in &rule.patterns {
341 if upper.contains(pattern) {
342 return Some(rule.workload);
343 }
344 }
345 }
346
347 None
348 }
349
350 fn classify_by_session(&self, session_id: &SessionId) -> Option<WorkloadType> {
352 self.session_history
353 .get(session_id)
354 .and_then(|history| history.primary_workload)
355 }
356
357 fn classify_by_structure(&self, query: &str) -> WorkloadType {
359 let upper = query.to_uppercase();
360
361 if upper.starts_with("INSERT") || upper.starts_with("UPDATE") || upper.starts_with("DELETE")
363 {
364 return WorkloadType::OLTP;
365 }
366
367 if upper.contains("SELECT") && !upper.contains("WHERE") && !upper.contains("LIMIT") {
369 return WorkloadType::OLAP;
370 }
371
372 let join_count = upper.matches("JOIN").count();
374 if join_count >= 3 {
375 return WorkloadType::OLAP;
376 }
377
378 WorkloadType::Mixed
380 }
381
382 fn record_query(&self, context: &QueryContext, query: &str, workload: WorkloadType) {
384 match workload {
386 WorkloadType::OLTP => self.stats.oltp_count.fetch_add(1, Ordering::Relaxed),
387 WorkloadType::OLAP => self.stats.olap_count.fetch_add(1, Ordering::Relaxed),
388 WorkloadType::Vector => self.stats.vector_count.fetch_add(1, Ordering::Relaxed),
389 WorkloadType::AIAgent => self.stats.ai_count.fetch_add(1, Ordering::Relaxed),
390 WorkloadType::RAG => self.stats.rag_count.fetch_add(1, Ordering::Relaxed),
391 WorkloadType::Mixed => self.stats.mixed_count.fetch_add(1, Ordering::Relaxed),
392 };
393
394 let entry = QueryHistoryEntry {
396 fingerprint: QueryFingerprint::from_query(query),
397 workload,
398 timestamp: Instant::now(),
399 latency_ms: 0, };
401
402 self.session_history
403 .entry(context.session_id.clone())
404 .or_insert_with(SessionHistory::new)
405 .record(entry);
406 }
407
408 pub fn record_latency(&self, session_id: &SessionId, latency_ms: u64) {
410 if let Some(mut history) = self.session_history.get_mut(session_id) {
411 if let Some(last) = history.queries.back_mut() {
412 last.latency_ms = latency_ms;
413 }
414 }
415 }
416
417 pub fn add_rule(&mut self, rule: ClassificationRule) {
419 self.rules.push(rule);
420 }
421
422 pub fn stats(&self) -> ClassifierStatsSnapshot {
424 ClassifierStatsSnapshot {
425 total_classified: self.stats.total_classified.load(Ordering::Relaxed),
426 oltp_count: self.stats.oltp_count.load(Ordering::Relaxed),
427 olap_count: self.stats.olap_count.load(Ordering::Relaxed),
428 vector_count: self.stats.vector_count.load(Ordering::Relaxed),
429 ai_count: self.stats.ai_count.load(Ordering::Relaxed),
430 rag_count: self.stats.rag_count.load(Ordering::Relaxed),
431 mixed_count: self.stats.mixed_count.load(Ordering::Relaxed),
432 rule_hit_rate: self.stats.rule_hits.load(Ordering::Relaxed) as f64
433 / self.stats.total_classified.load(Ordering::Relaxed).max(1) as f64,
434 session_hit_rate: self.stats.session_hits.load(Ordering::Relaxed) as f64
435 / self.stats.total_classified.load(Ordering::Relaxed).max(1) as f64,
436 }
437 }
438
439 pub fn cleanup_old_sessions(&self, max_age: Duration) {
441 let now = Instant::now();
442 self.session_history.retain(|_, history| {
443 if let Some(last) = history.queries.back() {
444 now.duration_since(last.timestamp) < max_age
445 } else {
446 false
447 }
448 });
449 }
450}
451
/// Plain-value copy of the classifier counters, as returned by
/// `WorkloadClassifier::stats`.
#[derive(Debug, Clone)]
pub struct ClassifierStatsSnapshot {
    /// Number of `classify` calls, including those answered by a hint.
    pub total_classified: u64,
    /// Recorded queries classified as `OLTP`.
    pub oltp_count: u64,
    /// Recorded queries classified as `OLAP`.
    pub olap_count: u64,
    /// Recorded queries classified as `Vector`.
    pub vector_count: u64,
    /// Recorded queries classified as `AIAgent`.
    pub ai_count: u64,
    /// Recorded queries classified as `RAG`.
    pub rag_count: u64,
    /// Recorded queries classified as `Mixed`.
    pub mixed_count: u64,
    /// Share of `classify` calls decided by a keyword rule.
    pub rule_hit_rate: f64,
    /// Share of `classify` calls decided by the session's primary workload.
    pub session_hit_rate: f64,
}
465
#[cfg(test)]
mod tests {
    use super::*;

    /// Context shared by the tests; every query lands in the same session id.
    fn make_context() -> QueryContext {
        QueryContext::new("test-session")
    }

    /// Classifier built from the default configuration.
    fn make_classifier() -> WorkloadClassifier {
        WorkloadClassifier::new(DistribCacheConfig::default())
    }

    #[test]
    fn test_oltp_classification() {
        let classifier = make_classifier();
        let ctx = make_context();

        // Matched by the point-lookup rule.
        let workload = classifier.classify("SELECT * FROM users WHERE id = 42", &ctx);
        assert_eq!(workload, WorkloadType::OLTP);

        // No rule matches; the structural fallback recognises the INSERT.
        let workload = classifier.classify("INSERT INTO users (name) VALUES ('Alice')", &ctx);
        assert_eq!(workload, WorkloadType::OLTP);
    }

    #[test]
    fn test_olap_classification() {
        let classifier = make_classifier();
        let ctx = make_context();

        let workload =
            classifier.classify("SELECT region, COUNT(*) FROM orders GROUP BY region", &ctx);
        assert_eq!(workload, WorkloadType::OLAP);

        let workload = classifier.classify("SELECT AVG(amount), SUM(quantity) FROM sales", &ctx);
        assert_eq!(workload, WorkloadType::OLAP);
    }

    #[test]
    fn test_vector_classification() {
        let classifier = make_classifier();
        let ctx = make_context();

        let workload = classifier.classify(
            "SELECT * FROM embeddings ORDER BY vector <-> $1 LIMIT 10",
            &ctx,
        );
        assert_eq!(workload, WorkloadType::Vector);
    }

    #[test]
    fn test_ai_agent_classification() {
        let classifier = make_classifier();
        let ctx = make_context();

        let workload = classifier.classify(
            "SELECT * FROM conversation_turns WHERE conversation_id = $1",
            &ctx,
        );
        assert_eq!(workload, WorkloadType::AIAgent);

        let workload = classifier.classify(
            "INSERT INTO agent_memory (key, value) VALUES ($1, $2)",
            &ctx,
        );
        assert_eq!(workload, WorkloadType::AIAgent);
    }

    #[test]
    fn test_rag_classification() {
        let classifier = make_classifier();
        let ctx = make_context();

        let workload = classifier.classify(
            "SELECT content FROM documents WHERE id IN (SELECT doc_id FROM chunks WHERE ...)",
            &ctx,
        );
        assert_eq!(workload, WorkloadType::RAG);
    }

    #[test]
    fn test_explicit_hint() {
        let classifier = make_classifier();
        let ctx = make_context().with_workload_hint(WorkloadType::OLAP);

        // The hint overrides the point-lookup rule that would say OLTP.
        let workload = classifier.classify("SELECT * FROM users WHERE id = 1", &ctx);
        assert_eq!(workload, WorkloadType::OLAP);
    }

    #[test]
    fn test_session_based_classification() {
        let classifier = make_classifier();
        let ctx = make_context();

        // Build up an OLAP-dominated session through the keyword rules.
        for _ in 0..20 {
            let workload =
                classifier.classify("SELECT COUNT(*) FROM analytics GROUP BY region", &ctx);
            assert_eq!(workload, WorkloadType::OLAP);
        }

        // This query matches no rule, so the session's primary workload must
        // decide. Classify before taking the map guard below.
        let inherited = classifier.classify("SELECT name FROM users WHERE active = true", &ctx);
        assert_eq!(inherited, WorkloadType::OLAP);

        let history = classifier.session_history.get(&ctx.session_id).unwrap();
        assert!(history.olap_count >= 20);
        assert_eq!(history.primary_workload, Some(WorkloadType::OLAP));
    }

    #[test]
    fn test_stats() {
        let classifier = make_classifier();
        let ctx = make_context();

        classifier.classify("SELECT * FROM users WHERE id = 1", &ctx);
        classifier.classify("SELECT COUNT(*) FROM orders GROUP BY status", &ctx);
        classifier.classify("SELECT * FROM embeddings ORDER BY vec <-> $1", &ctx);

        let stats = classifier.stats();
        assert_eq!(stats.total_classified, 3);
    }
}