1use dashmap::DashMap;
7use std::collections::VecDeque;
8use std::sync::atomic::{AtomicU64, Ordering};
9use std::time::{Duration, Instant};
10
11use super::{DistribCacheConfig, QueryContext, SessionId, QueryFingerprint};
12
/// Broad category of database traffic a query is assigned to by the classifier.
///
/// Variant names are part of the public API; the acronym spelling is kept as-is.
#[allow(clippy::upper_case_acronyms)] // public variant names; renaming would break callers
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum WorkloadType {
    /// Transactional traffic (the default rules tie this to point lookups such as
    /// `WHERE ID =`, and the structural check to INSERT/UPDATE/DELETE).
    OLTP,
    /// Analytical traffic (aggregates, window functions, reporting tables, many joins).
    OLAP,
    /// Vector similarity search (distance operators such as `<->`, embeddings).
    Vector,
    /// AI-agent state access (conversations, tool calls, agent memory).
    AIAgent,
    /// Retrieval-augmented generation pipelines (documents, chunks, reranking).
    RAG,
    /// No single workload dominates, or the query could not be categorised.
    Mixed,
}
39
40#[derive(Debug, Clone)]
42struct QueryHistoryEntry {
43 fingerprint: QueryFingerprint,
44 workload: WorkloadType,
45 timestamp: Instant,
46 latency_ms: u64,
47}
48
49#[derive(Debug)]
51struct SessionHistory {
52 queries: VecDeque<QueryHistoryEntry>,
54 primary_workload: Option<WorkloadType>,
56 oltp_count: u64,
58 olap_count: u64,
59 vector_count: u64,
60 ai_count: u64,
61 rag_count: u64,
62}
63
64impl SessionHistory {
65 fn new() -> Self {
66 Self {
67 queries: VecDeque::with_capacity(100),
68 primary_workload: None,
69 oltp_count: 0,
70 olap_count: 0,
71 vector_count: 0,
72 ai_count: 0,
73 rag_count: 0,
74 }
75 }
76
77 fn record(&mut self, entry: QueryHistoryEntry) {
78 match entry.workload {
80 WorkloadType::OLTP => self.oltp_count += 1,
81 WorkloadType::OLAP => self.olap_count += 1,
82 WorkloadType::Vector => self.vector_count += 1,
83 WorkloadType::AIAgent => self.ai_count += 1,
84 WorkloadType::RAG => self.rag_count += 1,
85 WorkloadType::Mixed => {}
86 }
87
88 self.queries.push_back(entry);
90 while self.queries.len() > 100 {
91 self.queries.pop_front();
92 }
93
94 self.primary_workload = self.determine_primary_workload();
96 }
97
98 fn determine_primary_workload(&self) -> Option<WorkloadType> {
99 let total = self.oltp_count + self.olap_count + self.vector_count +
100 self.ai_count + self.rag_count;
101
102 if total < 10 {
103 return None; }
105
106 let max = *[
107 self.oltp_count,
108 self.olap_count,
109 self.vector_count,
110 self.ai_count,
111 self.rag_count,
112 ].iter().max().unwrap();
113
114 if max as f64 / total as f64 > 0.5 {
116 if max == self.oltp_count {
117 Some(WorkloadType::OLTP)
118 } else if max == self.olap_count {
119 Some(WorkloadType::OLAP)
120 } else if max == self.vector_count {
121 Some(WorkloadType::Vector)
122 } else if max == self.ai_count {
123 Some(WorkloadType::AIAgent)
124 } else {
125 Some(WorkloadType::RAG)
126 }
127 } else {
128 Some(WorkloadType::Mixed)
129 }
130 }
131}
132
133#[derive(Debug, Clone)]
135pub struct ClassificationRule {
136 pub name: String,
138 pub patterns: Vec<String>,
140 pub workload: WorkloadType,
142 pub priority: u32,
144}
145
146pub struct WorkloadClassifier {
148 config: DistribCacheConfig,
150
151 rules: Vec<ClassificationRule>,
153
154 session_history: DashMap<SessionId, SessionHistory>,
156
157 stats: ClassifierStats,
159}
160
/// Atomic counters backing `WorkloadClassifier::stats`.
///
/// The classifier updates and reads these with `Ordering::Relaxed`. Each counter
/// is exact on its own, but a snapshot taken across them is not guaranteed to be
/// mutually consistent.
#[derive(Debug, Default)]
struct ClassifierStats {
    /// Every `classify` call, including those resolved by an explicit hint.
    total_classified: AtomicU64,
    /// Non-hinted results per workload, bumped when a result is recorded.
    oltp_count: AtomicU64,
    olap_count: AtomicU64,
    vector_count: AtomicU64,
    ai_count: AtomicU64,
    rag_count: AtomicU64,
    mixed_count: AtomicU64,
    /// Results produced by a pattern rule.
    rule_hits: AtomicU64,
    /// Results produced from session history.
    session_hits: AtomicU64,
    /// Results produced by the structural fallback.
    default_hits: AtomicU64,
}
175
176impl WorkloadClassifier {
177 pub fn new(config: DistribCacheConfig) -> Self {
179 let rules = Self::default_rules();
180
181 Self {
182 config,
183 rules,
184 session_history: DashMap::new(),
185 stats: ClassifierStats::default(),
186 }
187 }
188
189 fn default_rules() -> Vec<ClassificationRule> {
191 vec![
192 ClassificationRule {
194 name: "vector_similarity".to_string(),
195 patterns: vec![
196 "<->".to_string(),
197 "<#>".to_string(),
198 "<=>".to_string(),
199 "VECTOR".to_string(),
200 "EMBEDDING".to_string(),
201 "COSINE_SIMILARITY".to_string(),
202 "L2_DISTANCE".to_string(),
203 "INNER_PRODUCT".to_string(),
204 ],
205 workload: WorkloadType::Vector,
206 priority: 100,
207 },
208 ClassificationRule {
210 name: "rag_pipeline".to_string(),
211 patterns: vec![
212 "CHUNKS".to_string(),
213 "DOCUMENTS".to_string(),
214 "RERANK".to_string(),
215 "RETRIEVE".to_string(),
216 ],
217 workload: WorkloadType::RAG,
218 priority: 90,
219 },
220 ClassificationRule {
222 name: "ai_agent".to_string(),
223 patterns: vec![
224 "CONVERSATION".to_string(),
225 "AGENT_".to_string(),
226 "TOOL_".to_string(),
227 "CONTEXT".to_string(),
228 "MEMORY".to_string(),
229 "TURNS".to_string(),
230 ],
231 workload: WorkloadType::AIAgent,
232 priority: 85,
233 },
234 ClassificationRule {
236 name: "olap_aggregation".to_string(),
237 patterns: vec![
238 "GROUP BY".to_string(),
239 "HAVING".to_string(),
240 "COUNT(".to_string(),
241 "SUM(".to_string(),
242 "AVG(".to_string(),
243 "MIN(".to_string(),
244 "MAX(".to_string(),
245 "STDDEV".to_string(),
246 "VARIANCE".to_string(),
247 "PERCENTILE".to_string(),
248 ],
249 workload: WorkloadType::OLAP,
250 priority: 70,
251 },
252 ClassificationRule {
253 name: "olap_analytics".to_string(),
254 patterns: vec![
255 "WINDOW".to_string(),
256 "OVER(".to_string(),
257 "PARTITION BY".to_string(),
258 "ROLLUP".to_string(),
259 "CUBE".to_string(),
260 "GROUPING".to_string(),
261 ],
262 workload: WorkloadType::OLAP,
263 priority: 70,
264 },
265 ClassificationRule {
266 name: "olap_large_scan".to_string(),
267 patterns: vec![
268 "ANALYTICS".to_string(),
269 "REPORT".to_string(),
270 "DASHBOARD".to_string(),
271 "METRIC".to_string(),
272 ],
273 workload: WorkloadType::OLAP,
274 priority: 60,
275 },
276 ClassificationRule {
278 name: "oltp_point_lookup".to_string(),
279 patterns: vec![
280 "WHERE ID =".to_string(),
281 "WHERE ID=".to_string(),
282 "BY ID".to_string(),
283 "LIMIT 1".to_string(),
284 ],
285 workload: WorkloadType::OLTP,
286 priority: 50,
287 },
288 ]
289 }
290
291 pub fn classify(&self, query: &str, context: &QueryContext) -> WorkloadType {
293 self.stats.total_classified.fetch_add(1, Ordering::Relaxed);
294
295 if let Some(hint) = context.workload_hint {
297 return hint;
298 }
299
300 if let Some(workload) = self.classify_by_pattern(query) {
302 self.stats.rule_hits.fetch_add(1, Ordering::Relaxed);
303 self.record_query(context, query, workload);
304 return workload;
305 }
306
307 if let Some(workload) = self.classify_by_session(&context.session_id) {
309 self.stats.session_hits.fetch_add(1, Ordering::Relaxed);
310 self.record_query(context, query, workload);
311 return workload;
312 }
313
314 let workload = self.classify_by_structure(query);
316 self.stats.default_hits.fetch_add(1, Ordering::Relaxed);
317 self.record_query(context, query, workload);
318 workload
319 }
320
321 pub fn classify_query(&self, query: &str, context: &QueryContext) -> WorkloadType {
323 self.classify(query, context)
324 }
325
326 fn classify_by_pattern(&self, query: &str) -> Option<WorkloadType> {
328 let upper = query.to_uppercase();
329
330 let mut sorted_rules = self.rules.clone();
332 sorted_rules.sort_by(|a, b| b.priority.cmp(&a.priority));
333
334 for rule in &sorted_rules {
335 for pattern in &rule.patterns {
336 if upper.contains(pattern) {
337 return Some(rule.workload);
338 }
339 }
340 }
341
342 None
343 }
344
345 fn classify_by_session(&self, session_id: &SessionId) -> Option<WorkloadType> {
347 self.session_history
348 .get(session_id)
349 .and_then(|history| history.primary_workload)
350 }
351
352 fn classify_by_structure(&self, query: &str) -> WorkloadType {
354 let upper = query.to_uppercase();
355
356 if upper.starts_with("INSERT") || upper.starts_with("UPDATE") ||
358 upper.starts_with("DELETE") {
359 return WorkloadType::OLTP;
360 }
361
362 if upper.contains("SELECT") && !upper.contains("WHERE") && !upper.contains("LIMIT") {
364 return WorkloadType::OLAP;
365 }
366
367 let join_count = upper.matches("JOIN").count();
369 if join_count >= 3 {
370 return WorkloadType::OLAP;
371 }
372
373 WorkloadType::Mixed
375 }
376
377 fn record_query(&self, context: &QueryContext, query: &str, workload: WorkloadType) {
379 match workload {
381 WorkloadType::OLTP => self.stats.oltp_count.fetch_add(1, Ordering::Relaxed),
382 WorkloadType::OLAP => self.stats.olap_count.fetch_add(1, Ordering::Relaxed),
383 WorkloadType::Vector => self.stats.vector_count.fetch_add(1, Ordering::Relaxed),
384 WorkloadType::AIAgent => self.stats.ai_count.fetch_add(1, Ordering::Relaxed),
385 WorkloadType::RAG => self.stats.rag_count.fetch_add(1, Ordering::Relaxed),
386 WorkloadType::Mixed => self.stats.mixed_count.fetch_add(1, Ordering::Relaxed),
387 };
388
389 let entry = QueryHistoryEntry {
391 fingerprint: QueryFingerprint::from_query(query),
392 workload,
393 timestamp: Instant::now(),
394 latency_ms: 0, };
396
397 self.session_history
398 .entry(context.session_id.clone())
399 .or_insert_with(SessionHistory::new)
400 .record(entry);
401 }
402
403 pub fn record_latency(&self, session_id: &SessionId, latency_ms: u64) {
405 if let Some(mut history) = self.session_history.get_mut(session_id) {
406 if let Some(last) = history.queries.back_mut() {
407 last.latency_ms = latency_ms;
408 }
409 }
410 }
411
412 pub fn add_rule(&mut self, rule: ClassificationRule) {
414 self.rules.push(rule);
415 }
416
417 pub fn stats(&self) -> ClassifierStatsSnapshot {
419 ClassifierStatsSnapshot {
420 total_classified: self.stats.total_classified.load(Ordering::Relaxed),
421 oltp_count: self.stats.oltp_count.load(Ordering::Relaxed),
422 olap_count: self.stats.olap_count.load(Ordering::Relaxed),
423 vector_count: self.stats.vector_count.load(Ordering::Relaxed),
424 ai_count: self.stats.ai_count.load(Ordering::Relaxed),
425 rag_count: self.stats.rag_count.load(Ordering::Relaxed),
426 mixed_count: self.stats.mixed_count.load(Ordering::Relaxed),
427 rule_hit_rate: self.stats.rule_hits.load(Ordering::Relaxed) as f64 /
428 self.stats.total_classified.load(Ordering::Relaxed).max(1) as f64,
429 session_hit_rate: self.stats.session_hits.load(Ordering::Relaxed) as f64 /
430 self.stats.total_classified.load(Ordering::Relaxed).max(1) as f64,
431 }
432 }
433
434 pub fn cleanup_old_sessions(&self, max_age: Duration) {
436 let now = Instant::now();
437 self.session_history.retain(|_, history| {
438 if let Some(last) = history.queries.back() {
439 now.duration_since(last.timestamp) < max_age
440 } else {
441 false
442 }
443 });
444 }
445}
446
/// Point-in-time copy of the classifier counters, produced by
/// `WorkloadClassifier::stats`.
#[derive(Debug, Clone)]
pub struct ClassifierStatsSnapshot {
    /// Total `classify` calls, including those resolved by an explicit hint.
    pub total_classified: u64,
    /// Non-hinted classifications that resolved to OLTP.
    pub oltp_count: u64,
    /// Non-hinted classifications that resolved to OLAP.
    pub olap_count: u64,
    /// Non-hinted classifications that resolved to Vector.
    pub vector_count: u64,
    /// Non-hinted classifications that resolved to AIAgent.
    pub ai_count: u64,
    /// Non-hinted classifications that resolved to RAG.
    pub rag_count: u64,
    /// Non-hinted classifications that resolved to Mixed.
    pub mixed_count: u64,
    /// Fraction of calls resolved by a pattern rule (0 before any call).
    pub rule_hit_rate: f64,
    /// Fraction of calls resolved from session history (0 before any call).
    pub session_hit_rate: f64,
}
460
#[cfg(test)]
mod tests {
    use super::*;

    /// Context shared by every test: one fixed session.
    fn make_context() -> QueryContext {
        QueryContext::new("test-session")
    }

    /// Fresh classifier built from the default configuration.
    fn fresh_classifier() -> WorkloadClassifier {
        WorkloadClassifier::new(DistribCacheConfig::default())
    }

    /// Classifies each query in order under `ctx` and checks the expected workload.
    fn expect_all(
        classifier: &WorkloadClassifier,
        ctx: &QueryContext,
        cases: &[(&str, WorkloadType)],
    ) {
        for &(query, expected) in cases {
            assert_eq!(classifier.classify(query, ctx), expected, "query: {}", query);
        }
    }

    #[test]
    fn test_oltp_classification() {
        let classifier = fresh_classifier();
        let ctx = make_context();
        expect_all(
            &classifier,
            &ctx,
            &[
                ("SELECT * FROM users WHERE id = 42", WorkloadType::OLTP),
                ("INSERT INTO users (name) VALUES ('Alice')", WorkloadType::OLTP),
            ],
        );
    }

    #[test]
    fn test_olap_classification() {
        let classifier = fresh_classifier();
        let ctx = make_context();
        expect_all(
            &classifier,
            &ctx,
            &[
                ("SELECT region, COUNT(*) FROM orders GROUP BY region", WorkloadType::OLAP),
                ("SELECT AVG(amount), SUM(quantity) FROM sales", WorkloadType::OLAP),
            ],
        );
    }

    #[test]
    fn test_vector_classification() {
        let classifier = fresh_classifier();
        let ctx = make_context();
        expect_all(
            &classifier,
            &ctx,
            &[(
                "SELECT * FROM embeddings ORDER BY vector <-> $1 LIMIT 10",
                WorkloadType::Vector,
            )],
        );
    }

    #[test]
    fn test_ai_agent_classification() {
        let classifier = fresh_classifier();
        let ctx = make_context();
        expect_all(
            &classifier,
            &ctx,
            &[
                (
                    "SELECT * FROM conversation_turns WHERE conversation_id = $1",
                    WorkloadType::AIAgent,
                ),
                (
                    "INSERT INTO agent_memory (key, value) VALUES ($1, $2)",
                    WorkloadType::AIAgent,
                ),
            ],
        );
    }

    #[test]
    fn test_rag_classification() {
        let classifier = fresh_classifier();
        let ctx = make_context();
        expect_all(
            &classifier,
            &ctx,
            &[(
                "SELECT content FROM documents WHERE id IN (SELECT doc_id FROM chunks WHERE ...)",
                WorkloadType::RAG,
            )],
        );
    }

    #[test]
    fn test_explicit_hint() {
        let classifier = fresh_classifier();
        let ctx = make_context().with_workload_hint(WorkloadType::OLAP);
        // The hint must override the point-lookup rule that would otherwise yield OLTP.
        expect_all(
            &classifier,
            &ctx,
            &[("SELECT * FROM users WHERE id = 1", WorkloadType::OLAP)],
        );
    }

    #[test]
    fn test_session_based_classification() {
        let classifier = fresh_classifier();
        let ctx = make_context();

        (0..20).for_each(|_| {
            classifier.classify("SELECT COUNT(*) FROM analytics GROUP BY region", &ctx);
        });

        let history = classifier
            .session_history
            .get(&ctx.session_id)
            .expect("session history should exist after classifying queries");
        assert!(history.olap_count >= 20);
    }

    #[test]
    fn test_stats() {
        let classifier = fresh_classifier();
        let ctx = make_context();

        for query in &[
            "SELECT * FROM users WHERE id = 1",
            "SELECT COUNT(*) FROM orders GROUP BY status",
            "SELECT * FROM embeddings ORDER BY vec <-> $1",
        ] {
            classifier.classify(query, &ctx);
        }

        assert_eq!(classifier.stats().total_classified, 3);
    }
}