use dashmap::DashMap;
use std::collections::VecDeque;
use std::sync::atomic::{AtomicU64, Ordering};
use std::time::{Duration, Instant};
use super::{DistribCacheConfig, QueryContext, SessionId, QueryFingerprint};
/// Coarse workload category assigned to a query by `WorkloadClassifier`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum WorkloadType {
/// Transactional traffic: point lookups and `INSERT`/`UPDATE`/`DELETE` statements.
OLTP,
/// Analytical traffic: aggregations, window functions, wide scans, many-way joins.
OLAP,
/// Vector similarity search (distance operators such as `<->`, embedding functions).
Vector,
/// AI-agent state access (conversation turns, agent memory, tool tables).
AIAgent,
/// Retrieval-augmented generation pipelines (documents, chunks, reranking).
RAG,
/// No single category dominates, or nothing more specific could be inferred.
Mixed,
}
/// One classified query as remembered in a session's history.
#[derive(Debug, Clone)]
struct QueryHistoryEntry {
/// Fingerprint computed by `QueryFingerprint::from_query` from the raw query text.
fingerprint: QueryFingerprint,
/// Workload the classifier assigned to this query.
workload: WorkloadType,
/// When the query was recorded; session cleanup compares the newest entry's timestamp
/// against the caller's `max_age`.
timestamp: Instant,
/// Observed latency in milliseconds. Recorded as 0 and overwritten by
/// `WorkloadClassifier::record_latency` if the caller reports one.
latency_ms: u64,
}
/// Per-session record of recently classified queries plus per-workload tallies.
#[derive(Debug)]
struct SessionHistory {
/// Most recent queries, oldest at the front; `record` trims it to 100 entries.
queries: VecDeque<QueryHistoryEntry>,
/// Cached result of `determine_primary_workload`, refreshed on every `record`.
/// `None` until enough queries have been tallied.
primary_workload: Option<WorkloadType>,
/// Per-workload tallies maintained by `record` (this field and the four below).
/// `Mixed` queries are stored in `queries` but not tallied.
oltp_count: u64,
olap_count: u64,
vector_count: u64,
ai_count: u64,
rag_count: u64,
}
impl SessionHistory {
    /// Number of most recent queries retained per session.
    const MAX_ENTRIES: usize = 100;
    /// Minimum number of tallied (non-`Mixed`) queries in the window before a primary
    /// workload is reported.
    const MIN_SAMPLES: u64 = 10;
    /// Share of the tallied window a single workload must exceed to be considered primary.
    const DOMINANCE_THRESHOLD: f64 = 0.5;

    /// Creates an empty history with room for `MAX_ENTRIES` queries.
    fn new() -> Self {
        Self {
            queries: VecDeque::with_capacity(Self::MAX_ENTRIES),
            primary_workload: None,
            oltp_count: 0,
            olap_count: 0,
            vector_count: 0,
            ai_count: 0,
            rag_count: 0,
        }
    }

    /// Returns the tally for `workload`, or `None` for `Mixed`, which is not tallied.
    fn counter_mut(&mut self, workload: WorkloadType) -> Option<&mut u64> {
        match workload {
            WorkloadType::OLTP => Some(&mut self.oltp_count),
            WorkloadType::OLAP => Some(&mut self.olap_count),
            WorkloadType::Vector => Some(&mut self.vector_count),
            WorkloadType::AIAgent => Some(&mut self.ai_count),
            WorkloadType::RAG => Some(&mut self.rag_count),
            WorkloadType::Mixed => None,
        }
    }

    /// Appends `entry`, evicts the oldest entries beyond `MAX_ENTRIES`, and recomputes
    /// `primary_workload`.
    ///
    /// The tallies always describe exactly the entries currently held in `queries`. Evicted
    /// entries are subtracted again, so the primary workload follows the session's recent
    /// behaviour instead of its lifetime totals.
    fn record(&mut self, entry: QueryHistoryEntry) {
        if let Some(count) = self.counter_mut(entry.workload) {
            *count += 1;
        }
        self.queries.push_back(entry);
        while self.queries.len() > Self::MAX_ENTRIES {
            if let Some(evicted) = self.queries.pop_front() {
                if let Some(count) = self.counter_mut(evicted.workload) {
                    *count = count.saturating_sub(1);
                }
            }
        }
        self.primary_workload = self.determine_primary_workload();
    }

    /// Picks the dominant workload in the current window.
    ///
    /// Returns `None` while fewer than `MIN_SAMPLES` queries are tallied. Returns the
    /// leading workload when its share exceeds `DOMINANCE_THRESHOLD`, and `Mixed` otherwise.
    fn determine_primary_workload(&self) -> Option<WorkloadType> {
        let tallies = [
            (WorkloadType::OLTP, self.oltp_count),
            (WorkloadType::OLAP, self.olap_count),
            (WorkloadType::Vector, self.vector_count),
            (WorkloadType::AIAgent, self.ai_count),
            (WorkloadType::RAG, self.rag_count),
        ];
        let total: u64 = tallies.iter().map(|&(_, count)| count).sum();
        if total < Self::MIN_SAMPLES {
            return None;
        }
        // Strict `>` keeps the earliest entry on ties, so precedence is OLTP, OLAP, Vector,
        // AIAgent, RAG (a tie can never exceed the threshold anyway).
        let (leader, max) = tallies
            .iter()
            .copied()
            .fold(tallies[0], |best, cur| if cur.1 > best.1 { cur } else { best });
        if max as f64 / total as f64 > Self::DOMINANCE_THRESHOLD {
            Some(leader)
        } else {
            Some(WorkloadType::Mixed)
        }
    }
}
/// A substring-based rule mapping query text to a workload.
///
/// Rules are tried in descending `priority`, with equal priorities tried in insertion
/// order. The first rule whose pattern occurs in the upper-cased query wins.
#[derive(Debug, Clone)]
pub struct ClassificationRule {
/// Human-readable identifier; not consulted during matching.
pub name: String,
/// Substrings searched for in the upper-cased query text (so write them in upper case).
pub patterns: Vec<String>,
/// Workload returned when any pattern matches.
pub workload: WorkloadType,
/// Higher values are tried first.
pub priority: u32,
}
/// Assigns a `WorkloadType` to SQL text.
///
/// It tries, in order: an explicit hint on the context, pattern rules, the session's
/// history, and structural heuristics.
pub struct WorkloadClassifier {
/// Configuration supplied at construction.
/// NOTE(review): not read by any method visible in this file.
config: DistribCacheConfig,
/// Pattern rules consulted by `classify_by_pattern`.
rules: Vec<ClassificationRule>,
/// Per-session query history, keyed by the context's session id.
session_history: DashMap<SessionId, SessionHistory>,
/// Atomic counters reported through `stats()`.
stats: ClassifierStats,
}
/// Shared counters behind `WorkloadClassifier::stats`.
///
/// Every update and read in this file uses `Ordering::Relaxed`.
#[derive(Debug, Default)]
struct ClassifierStats {
/// Incremented on every `classify` call, including calls answered by a workload hint.
total_classified: AtomicU64,
/// Per-workload tallies bumped by `record_query` (this field and the five below).
/// Hinted calls are not tallied here.
oltp_count: AtomicU64,
olap_count: AtomicU64,
vector_count: AtomicU64,
ai_count: AtomicU64,
rag_count: AtomicU64,
mixed_count: AtomicU64,
/// Calls resolved by a pattern rule.
rule_hits: AtomicU64,
/// Calls resolved by the session's primary workload.
session_hits: AtomicU64,
/// Calls resolved by the structural fallback.
default_hits: AtomicU64,
}
impl WorkloadClassifier {
    /// Creates a classifier with the built-in rule set and no session history.
    ///
    /// `rules` is kept sorted by descending priority, so `classify_by_pattern` can scan it
    /// directly instead of cloning and re-sorting it for every query.
    pub fn new(config: DistribCacheConfig) -> Self {
        let mut rules = Self::default_rules();
        // `sort_by` is stable: equal-priority rules keep their declaration order.
        rules.sort_by(|a, b| b.priority.cmp(&a.priority));
        Self {
            config,
            rules,
            session_history: DashMap::new(),
            stats: ClassifierStats::default(),
        }
    }

    /// Built-in rules.
    ///
    /// Patterns are upper case because they are matched against the upper-cased query.
    fn default_rules() -> Vec<ClassificationRule> {
        // Builds one rule from borrowed pattern literals.
        fn rule(
            name: &str,
            patterns: &[&str],
            workload: WorkloadType,
            priority: u32,
        ) -> ClassificationRule {
            ClassificationRule {
                name: name.to_owned(),
                patterns: patterns.iter().map(|&p| p.to_owned()).collect(),
                workload,
                priority,
            }
        }

        vec![
            rule(
                "vector_similarity",
                &[
                    "<->",
                    "<#>",
                    "<=>",
                    "VECTOR",
                    "EMBEDDING",
                    "COSINE_SIMILARITY",
                    "L2_DISTANCE",
                    "INNER_PRODUCT",
                ],
                WorkloadType::Vector,
                100,
            ),
            rule(
                "rag_pipeline",
                &["CHUNKS", "DOCUMENTS", "RERANK", "RETRIEVE"],
                WorkloadType::RAG,
                90,
            ),
            rule(
                "ai_agent",
                &["CONVERSATION", "AGENT_", "TOOL_", "CONTEXT", "MEMORY", "TURNS"],
                WorkloadType::AIAgent,
                85,
            ),
            rule(
                "olap_aggregation",
                &[
                    "GROUP BY",
                    "HAVING",
                    "COUNT(",
                    "SUM(",
                    "AVG(",
                    "MIN(",
                    "MAX(",
                    "STDDEV",
                    "VARIANCE",
                    "PERCENTILE",
                ],
                WorkloadType::OLAP,
                70,
            ),
            rule(
                "olap_analytics",
                &["WINDOW", "OVER(", "PARTITION BY", "ROLLUP", "CUBE", "GROUPING"],
                WorkloadType::OLAP,
                70,
            ),
            rule(
                "olap_large_scan",
                &["ANALYTICS", "REPORT", "DASHBOARD", "METRIC"],
                WorkloadType::OLAP,
                60,
            ),
            rule(
                "oltp_point_lookup",
                &["WHERE ID =", "WHERE ID=", "BY ID", "LIMIT 1"],
                WorkloadType::OLTP,
                50,
            ),
        ]
    }

    /// Classifies `query` by trying, in order:
    ///
    /// 1. `context.workload_hint`, which is returned as-is. Only `total_classified` is
    ///    counted, and the session history is not updated.
    /// 2. Pattern rules (`classify_by_pattern`).
    /// 3. The session's primary workload (`classify_by_session`).
    /// 4. Structural heuristics (`classify_by_structure`), which always produce a result.
    ///
    /// Results from steps 2-4 are added to the statistics and to the session's history.
    pub fn classify(&self, query: &str, context: &QueryContext) -> WorkloadType {
        self.stats.total_classified.fetch_add(1, Ordering::Relaxed);
        if let Some(hint) = context.workload_hint {
            return hint;
        }
        let (workload, source_counter) = if let Some(w) = self.classify_by_pattern(query) {
            (w, &self.stats.rule_hits)
        } else if let Some(w) = self.classify_by_session(&context.session_id) {
            (w, &self.stats.session_hits)
        } else {
            (self.classify_by_structure(query), &self.stats.default_hits)
        };
        source_counter.fetch_add(1, Ordering::Relaxed);
        self.record_query(context, query, workload);
        workload
    }

    /// Alias for [`Self::classify`].
    pub fn classify_query(&self, query: &str, context: &QueryContext) -> WorkloadType {
        self.classify(query, context)
    }

    /// Returns the workload of the highest-priority rule with a pattern contained in the
    /// upper-cased `query`, or `None` if no rule matches.
    fn classify_by_pattern(&self, query: &str) -> Option<WorkloadType> {
        let upper = query.to_uppercase();
        // `self.rules` is maintained in priority order by `new` and `add_rule`.
        self.rules
            .iter()
            .find(|rule| rule.patterns.iter().any(|p| upper.contains(p.as_str())))
            .map(|rule| rule.workload)
    }

    /// Returns the session's cached primary workload, if the session exists and has one.
    fn classify_by_session(&self, session_id: &SessionId) -> Option<WorkloadType> {
        self.session_history
            .get(session_id)
            .and_then(|history| history.primary_workload)
    }

    /// Heuristic fallback based on statement shape:
    ///
    /// - starts with `INSERT`/`UPDATE`/`DELETE` (after leading whitespace): OLTP;
    /// - contains `SELECT` but neither `WHERE` nor `LIMIT`: OLAP;
    /// - contains three or more `JOIN`s: OLAP;
    /// - otherwise: `Mixed`.
    ///
    /// Keywords are case-insensitive substring checks, not parsed tokens.
    fn classify_by_structure(&self, query: &str) -> WorkloadType {
        // Trim first so leading newlines/indentation don't hide the DML keyword from
        // the `starts_with` checks.
        let upper = query.trim_start().to_uppercase();
        if ["INSERT", "UPDATE", "DELETE"]
            .iter()
            .any(|&keyword| upper.starts_with(keyword))
        {
            return WorkloadType::OLTP;
        }
        if upper.contains("SELECT") && !upper.contains("WHERE") && !upper.contains("LIMIT") {
            return WorkloadType::OLAP;
        }
        if upper.matches("JOIN").count() >= 3 {
            return WorkloadType::OLAP;
        }
        WorkloadType::Mixed
    }

    /// Bumps the per-workload counter and appends the query to its session's history,
    /// creating the session entry on first use.
    fn record_query(&self, context: &QueryContext, query: &str, workload: WorkloadType) {
        let counter = match workload {
            WorkloadType::OLTP => &self.stats.oltp_count,
            WorkloadType::OLAP => &self.stats.olap_count,
            WorkloadType::Vector => &self.stats.vector_count,
            WorkloadType::AIAgent => &self.stats.ai_count,
            WorkloadType::RAG => &self.stats.rag_count,
            WorkloadType::Mixed => &self.stats.mixed_count,
        };
        counter.fetch_add(1, Ordering::Relaxed);
        let entry = QueryHistoryEntry {
            fingerprint: QueryFingerprint::from_query(query),
            workload,
            timestamp: Instant::now(),
            // Filled in later by `record_latency`, if the caller reports one.
            latency_ms: 0,
        };
        self.session_history
            .entry(context.session_id.clone())
            .or_insert_with(SessionHistory::new)
            .record(entry);
    }

    /// Stores `latency_ms` on the most recently recorded query of `session_id`.
    ///
    /// Does nothing if the session is unknown or has no queries.
    /// NOTE(review): if another query for the same session was recorded in between, the
    /// latency lands on that newer query rather than the one it was measured for.
    pub fn record_latency(&self, session_id: &SessionId, latency_ms: u64) {
        if let Some(mut history) = self.session_history.get_mut(session_id) {
            if let Some(last) = history.queries.back_mut() {
                last.latency_ms = latency_ms;
            }
        }
    }

    /// Adds a custom rule.
    ///
    /// Its patterns are upper-cased, because queries are upper-cased before matching and
    /// lower- or mixed-case patterns would otherwise never match. The rule is tried after
    /// all existing rules of equal or higher priority.
    pub fn add_rule(&mut self, mut rule: ClassificationRule) {
        for pattern in &mut rule.patterns {
            *pattern = pattern.to_uppercase();
        }
        // Inserting after every rule with priority >= the new one keeps `rules` sorted
        // descending and gives equal-priority rules first-added-first-tried order.
        let index = self
            .rules
            .partition_point(|existing| existing.priority >= rule.priority);
        self.rules.insert(index, rule);
    }

    /// Returns a snapshot of the counters.
    ///
    /// Hit rates are divided by `total_classified`, treated as 1 when zero. Counters are
    /// read individually with relaxed ordering, so under concurrent use the snapshot is
    /// approximate.
    pub fn stats(&self) -> ClassifierStatsSnapshot {
        let load = |counter: &AtomicU64| counter.load(Ordering::Relaxed);
        // Read the total once so both rates share the same denominator.
        let total = load(&self.stats.total_classified);
        let denominator = total.max(1) as f64;
        ClassifierStatsSnapshot {
            total_classified: total,
            oltp_count: load(&self.stats.oltp_count),
            olap_count: load(&self.stats.olap_count),
            vector_count: load(&self.stats.vector_count),
            ai_count: load(&self.stats.ai_count),
            rag_count: load(&self.stats.rag_count),
            mixed_count: load(&self.stats.mixed_count),
            rule_hit_rate: load(&self.stats.rule_hits) as f64 / denominator,
            session_hit_rate: load(&self.stats.session_hits) as f64 / denominator,
        }
    }

    /// Drops sessions whose newest query is at least `max_age` old, as well as sessions
    /// with no recorded queries.
    pub fn cleanup_old_sessions(&self, max_age: Duration) {
        let now = Instant::now();
        self.session_history.retain(|_, history| {
            // A query recorded concurrently after `now` was captured can be newer than
            // `now`; saturate to zero age instead of relying on toolchain-specific behaviour.
            history
                .queries
                .back()
                .map_or(false, |last| now.saturating_duration_since(last.timestamp) < max_age)
        });
    }
}
/// Point-in-time copy of the classifier counters, returned by `WorkloadClassifier::stats`.
///
/// Each counter is read separately with relaxed ordering, so the fields are not
/// guaranteed to be mutually consistent while classification is running concurrently.
#[derive(Debug, Clone)]
pub struct ClassifierStatsSnapshot {
/// Number of `classify` calls, including calls answered by a workload hint.
pub total_classified: u64,
/// Per-workload tallies of non-hinted classifications (this field and the five below).
pub oltp_count: u64,
pub olap_count: u64,
pub vector_count: u64,
pub ai_count: u64,
pub rag_count: u64,
pub mixed_count: u64,
/// Pattern-rule hits divided by `total_classified` (a zero total is treated as 1).
pub rule_hit_rate: f64,
/// Session-history hits divided by `total_classified` (a zero total is treated as 1).
pub session_hit_rate: f64,
}
#[cfg(test)]
mod tests {
    use super::*;

    fn make_context() -> QueryContext {
        QueryContext::new("test-session")
    }

    fn make_classifier() -> WorkloadClassifier {
        WorkloadClassifier::new(DistribCacheConfig::default())
    }

    #[test]
    fn test_oltp_classification() {
        let classifier = make_classifier();
        let ctx = make_context();
        let workload = classifier.classify("SELECT * FROM users WHERE id = 42", &ctx);
        assert_eq!(workload, WorkloadType::OLTP);
        let workload = classifier.classify("INSERT INTO users (name) VALUES ('Alice')", &ctx);
        assert_eq!(workload, WorkloadType::OLTP);
    }

    #[test]
    fn test_olap_classification() {
        let classifier = make_classifier();
        let ctx = make_context();
        let workload = classifier.classify(
            "SELECT region, COUNT(*) FROM orders GROUP BY region",
            &ctx,
        );
        assert_eq!(workload, WorkloadType::OLAP);
        let workload = classifier.classify(
            "SELECT AVG(amount), SUM(quantity) FROM sales",
            &ctx,
        );
        assert_eq!(workload, WorkloadType::OLAP);
    }

    #[test]
    fn test_vector_classification() {
        let classifier = make_classifier();
        let ctx = make_context();
        let workload = classifier.classify(
            "SELECT * FROM embeddings ORDER BY vector <-> $1 LIMIT 10",
            &ctx,
        );
        assert_eq!(workload, WorkloadType::Vector);
    }

    #[test]
    fn test_ai_agent_classification() {
        let classifier = make_classifier();
        let ctx = make_context();
        let workload = classifier.classify(
            "SELECT * FROM conversation_turns WHERE conversation_id = $1",
            &ctx,
        );
        assert_eq!(workload, WorkloadType::AIAgent);
        let workload = classifier.classify(
            "INSERT INTO agent_memory (key, value) VALUES ($1, $2)",
            &ctx,
        );
        assert_eq!(workload, WorkloadType::AIAgent);
    }

    #[test]
    fn test_rag_classification() {
        let classifier = make_classifier();
        let ctx = make_context();
        let workload = classifier.classify(
            "SELECT content FROM documents WHERE id IN (SELECT doc_id FROM chunks WHERE ...)",
            &ctx,
        );
        assert_eq!(workload, WorkloadType::RAG);
    }

    #[test]
    fn test_explicit_hint() {
        let classifier = make_classifier();
        let ctx = make_context().with_workload_hint(WorkloadType::OLAP);
        let workload = classifier.classify("SELECT * FROM users WHERE id = 1", &ctx);
        assert_eq!(workload, WorkloadType::OLAP);
    }

    #[test]
    fn test_session_based_classification() {
        let classifier = make_classifier();
        let ctx = make_context();
        for _ in 0..20 {
            classifier.classify("SELECT COUNT(*) FROM analytics GROUP BY region", &ctx);
        }
        {
            // Scope the DashMap read guard: holding it while `classify` re-enters the map
            // for a write on the same shard could deadlock.
            let history = classifier.session_history.get(&ctx.session_id).unwrap();
            assert!(history.olap_count >= 20);
            assert_eq!(history.primary_workload, Some(WorkloadType::OLAP));
        }
        // No pattern rule matches this query, and the structural fallback would return
        // `Mixed`, so an OLAP result can only come from the session history.
        let workload = classifier.classify("SELECT name FROM users WHERE age > 30", &ctx);
        assert_eq!(workload, WorkloadType::OLAP);
        assert!(classifier.stats().session_hit_rate > 0.0);
    }

    #[test]
    fn test_stats() {
        let classifier = make_classifier();
        let ctx = make_context();
        classifier.classify("SELECT * FROM users WHERE id = 1", &ctx);
        classifier.classify("SELECT COUNT(*) FROM orders GROUP BY status", &ctx);
        classifier.classify("SELECT * FROM embeddings ORDER BY vec <-> $1", &ctx);
        let stats = classifier.stats();
        assert_eq!(stats.total_classified, 3);
        assert_eq!(stats.oltp_count, 1);
        assert_eq!(stats.olap_count, 1);
        assert_eq!(stats.vector_count, 1);
        assert_eq!(stats.mixed_count, 0);
        assert!((stats.rule_hit_rate - 1.0).abs() < f64::EPSILON);
    }
}