use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc;
use std::time::{Duration, Instant};
use dashmap::DashMap;
use super::semantic::{AIWorkloadContext, BranchContext, BranchId, SemanticQueryCache};
use crate::distribcache::classifier::WorkloadType;
use crate::distribcache::scheduler::{ScheduleResult, ScheduledQuery, WorkloadScheduler};
use crate::distribcache::SessionId;
/// Outcome of classifying one query, as produced by
/// `AIIntegrationCoordinator::detect_workload`.
#[derive(Debug, Clone)]
pub struct AIWorkloadDetection {
    /// Workload class assigned to the query (`WorkloadType::Mixed` when no detector matched).
    pub workload_type: WorkloadType,
    /// AI-specific context; `AIWorkloadContext::General` when no AI detector matched.
    pub ai_context: AIWorkloadContext,
    /// Heuristic confidence in `[0.0, 1.0]`; 0.0 when nothing matched.
    pub confidence: f32,
    /// Human-readable labels of the detectors that matched, in the order they fired.
    pub patterns: Vec<String>,
}
/// Per-session bookkeeping kept by `AIIntegrationCoordinator`.
#[derive(Debug, Clone)]
pub struct SessionTrackingInfo {
    /// Identifier of the tracked session.
    pub session_id: SessionId,
    /// Branch the session is bound to, if any.
    pub branch: Option<BranchContext>,
    /// Time of the last recorded query; used for idle detection.
    pub last_activity: Instant,
    /// Nesting depth of open transactions (0 = not in a transaction).
    pub transaction_depth: u32,
    /// AI context the session was registered with.
    pub ai_context: AIWorkloadContext,
    /// Number of queries recorded for this session.
    pub query_count: u64,
    /// Number of those queries that were served from cache.
    pub cache_hits: u64,
}
impl SessionTrackingInfo {
    /// Starts tracking `session_id`: no branch, `General` AI context, zeroed counters and
    /// an activity timestamp of "now".
    pub fn new(session_id: impl Into<String>) -> Self {
        Self {
            session_id: SessionId::new(session_id),
            branch: None,
            last_activity: Instant::now(),
            transaction_depth: 0,
            ai_context: AIWorkloadContext::General,
            query_count: 0,
            cache_hits: 0,
        }
    }

    /// Builder step: binds the session to `branch`.
    pub fn with_branch(self, branch: BranchContext) -> Self {
        Self {
            branch: Some(branch),
            ..self
        }
    }

    /// Builder step: replaces the session's AI context.
    pub fn with_ai_context(self, context: AIWorkloadContext) -> Self {
        Self {
            ai_context: context,
            ..self
        }
    }

    /// Counts one more query and refreshes the activity timestamp.
    pub fn record_query(&mut self) {
        self.last_activity = Instant::now();
        self.query_count += 1;
    }

    /// Counts one more cache hit. The activity timestamp is left unchanged.
    pub fn record_cache_hit(&mut self) {
        self.cache_hits += 1;
    }

    /// Fraction of recorded queries that hit the cache; 0.0 before any query is recorded.
    pub fn cache_hit_rate(&self) -> f64 {
        match self.query_count {
            0 => 0.0,
            total => self.cache_hits as f64 / total as f64,
        }
    }

    /// True when strictly more than `timeout` has passed since the last recorded query.
    pub fn is_idle(&self, timeout: Duration) -> bool {
        timeout < self.last_activity.elapsed()
    }
}
/// Ties the semantic query cache, per-session tracking and (optionally) the workload
/// scheduler together for AI-style workloads.
pub struct AIIntegrationCoordinator {
    /// Shared semantic cache; invalidation calls are forwarded to it.
    semantic_cache: Arc<SemanticQueryCache>,
    /// Tracked sessions keyed by session id (concurrent map).
    sessions: DashMap<SessionId, SessionTrackingInfo>,
    /// Scheduler used for AI-priority scheduling; `None` until one is attached.
    scheduler: Option<Arc<WorkloadScheduler>>,
    /// Coordinator configuration.
    config: AIIntegrationConfig,
    /// Lock-free counters exposed through the stats snapshot.
    stats: AIIntegrationStats,
}
/// Tunables for `AIIntegrationCoordinator`.
#[derive(Debug, Clone)]
pub struct AIIntegrationConfig {
    /// Feature flag named `twr_tracking`.
    /// NOTE(review): not read by any code visible in this file — confirm its consumer.
    pub twr_tracking: bool,
    /// Inactivity period after which a session counts as idle and is removed by
    /// idle-session cleanup.
    pub session_idle_timeout: Duration,
    /// Session-count limit; once reached, `track_session` runs idle-session cleanup
    /// before inserting.
    pub max_sessions: usize,
    /// Feature flag named `workload_detection`.
    /// NOTE(review): not read by any code visible in this file — confirm its consumer.
    pub workload_detection: bool,
    /// Confidence threshold associated with RAG detection.
    /// NOTE(review): confirm where this threshold is enforced.
    pub rag_detection_threshold: f32,
    /// Confidence threshold associated with agent-conversation detection.
    /// NOTE(review): confirm where this threshold is enforced.
    pub agent_detection_threshold: f32,
}
impl Default for AIIntegrationConfig {
    /// Both feature flags enabled, a one-hour idle timeout, a 10,000-session limit,
    /// a RAG threshold of 0.7 and an agent threshold of 0.8.
    fn default() -> Self {
        const ONE_HOUR: Duration = Duration::from_secs(60 * 60);
        const SESSION_LIMIT: usize = 10_000;

        Self {
            twr_tracking: true,
            workload_detection: true,
            session_idle_timeout: ONE_HOUR,
            max_sessions: SESSION_LIMIT,
            rag_detection_threshold: 0.7,
            agent_detection_threshold: 0.8,
        }
    }
}
/// Internal counters updated with relaxed atomics from `&self` methods.
#[derive(Debug, Default)]
struct AIIntegrationStats {
    /// Number of `detect_workload` calls.
    detections: AtomicU64,
    /// Number of detections where the RAG detector matched.
    rag_detected: AtomicU64,
    /// Number of detections where the agent detector matched.
    agent_detected: AtomicU64,
    /// Number of detections where the tool detector matched.
    tool_detected: AtomicU64,
    /// Number of `track_session` calls.
    /// NOTE(review): not included in the stats snapshot — confirm whether it should be.
    sessions_tracked: AtomicU64,
    /// Number of session updates reported as cache hits.
    cross_feature_hits: AtomicU64,
}
impl AIIntegrationCoordinator {
    /// Creates a coordinator over `semantic_cache` with no tracked sessions, no scheduler and
    /// zeroed statistics.
    pub fn new(semantic_cache: Arc<SemanticQueryCache>, config: AIIntegrationConfig) -> Self {
        Self {
            semantic_cache,
            sessions: DashMap::new(),
            scheduler: None,
            config,
            stats: AIIntegrationStats::default(),
        }
    }

    /// Attaches the scheduler used by [`Self::schedule_with_ai_priority`].
    pub fn with_scheduler(mut self, scheduler: Arc<WorkloadScheduler>) -> Self {
        self.scheduler = Some(scheduler);
        self
    }

    /// Classifies `query` using case-insensitive keyword heuristics.
    ///
    /// Detectors run in a fixed order: RAG, agent, tool, vector, OLAP.
    /// - A later AI detector (agent, tool) overrides the `ai_context` and `workload_type`
    ///   chosen by an earlier one.
    /// - `confidence` is the maximum fixed confidence over all matched detectors.
    /// - The RAG and agent detectors only fire when their fixed confidence reaches the
    ///   configured `rag_detection_threshold` / `agent_detection_threshold`.
    /// - The vector detector is a fallback that applies only when no AI detector matched.
    /// - The OLAP detector overrides `workload_type` but leaves `ai_context` as is.
    ///
    /// When `session` names a tracked session, its history can mark the query as agent
    /// traffic (see `is_agent_pattern`). Each call updates the detection counters reported
    /// by [`Self::stats`].
    pub fn detect_workload(&self, query: &str, session: Option<&SessionId>) -> AIWorkloadDetection {
        const RAG_CONFIDENCE: f32 = 0.85;
        const AGENT_CONFIDENCE: f32 = 0.80;
        const TOOL_CONFIDENCE: f32 = 0.90;
        const VECTOR_CONFIDENCE: f32 = 0.75;
        const OLAP_CONFIDENCE: f32 = 0.70;

        self.stats.detections.fetch_add(1, Ordering::Relaxed);
        let mut patterns = Vec::new();
        let mut confidence = 0.0f32;
        let mut ai_context = AIWorkloadContext::General;
        let mut workload_type = WorkloadType::Mixed;
        let query_lower = query.to_lowercase();

        if RAG_CONFIDENCE >= self.config.rag_detection_threshold
            && self.is_rag_pattern(&query_lower)
        {
            patterns.push("RAG retrieval".to_string());
            confidence = confidence.max(RAG_CONFIDENCE);
            ai_context = AIWorkloadContext::RAGRetrieval;
            workload_type = WorkloadType::RAG;
            self.stats.rag_detected.fetch_add(1, Ordering::Relaxed);
        }
        if AGENT_CONFIDENCE >= self.config.agent_detection_threshold
            && self.is_agent_pattern(&query_lower, session)
        {
            patterns.push("Agent conversation".to_string());
            confidence = confidence.max(AGENT_CONFIDENCE);
            ai_context = AIWorkloadContext::AgentConversation;
            workload_type = WorkloadType::AIAgent;
            self.stats.agent_detected.fetch_add(1, Ordering::Relaxed);
        }
        if self.is_tool_pattern(&query_lower) {
            patterns.push("Tool result".to_string());
            confidence = confidence.max(TOOL_CONFIDENCE);
            ai_context = AIWorkloadContext::ToolResult;
            workload_type = WorkloadType::AIAgent;
            self.stats.tool_detected.fetch_add(1, Ordering::Relaxed);
        }
        // Vector search is a fallback for non-AI queries. Guarding on `ai_context` rather than
        // `workload_type` stops it from relabelling a query that matched RAG and was then
        // re-tagged `AIAgent` by the agent/tool detectors, which would leave `workload_type`
        // and `ai_context` disagreeing.
        if ai_context == AIWorkloadContext::General
            && (query_lower.contains("embedding")
                || query_lower.contains("vector")
                || query_lower.contains("similarity"))
        {
            patterns.push("Vector search".to_string());
            confidence = confidence.max(VECTOR_CONFIDENCE);
            workload_type = WorkloadType::Vector;
        }
        if self.is_olap_pattern(&query_lower) {
            patterns.push("OLAP analytics".to_string());
            confidence = confidence.max(OLAP_CONFIDENCE);
            workload_type = WorkloadType::OLAP;
        }

        AIWorkloadDetection {
            workload_type,
            ai_context,
            confidence,
            patterns,
        }
    }

    /// True if the (lower-cased) query mentions a retrieval/similarity keyword.
    /// Note that "context" is also an agent keyword, so such queries usually end up tagged as
    /// agent traffic.
    fn is_rag_pattern(&self, query: &str) -> bool {
        let rag_patterns = [
            "chunk", "retrieve", "context", "passage", "document",
            "semantic", "similarity", "cosine", "embedding",
        ];
        rag_patterns.iter().any(|p| query.contains(p))
    }

    /// True if the tracked session is already an agent conversation, or has recorded more than
    /// 5 queries with a cache-hit rate above 0.3, or the (lower-cased) query mentions a
    /// conversational keyword.
    fn is_agent_pattern(&self, query: &str, session: Option<&SessionId>) -> bool {
        if let Some(info) = session.and_then(|sid| self.sessions.get(sid)) {
            if info.ai_context == AIWorkloadContext::AgentConversation {
                return true;
            }
            if info.query_count > 5 && info.cache_hit_rate() > 0.3 {
                return true;
            }
        }
        let agent_patterns = ["conversation", "history", "context", "message", "response"];
        agent_patterns.iter().any(|p| query.contains(p))
    }

    /// True if the (lower-cased) query mentions a tool/function-result keyword.
    fn is_tool_pattern(&self, query: &str) -> bool {
        let tool_patterns = [
            "tool_", "function_", "api_result", "calculate",
            "format_", "convert_", "lookup_",
        ];
        tool_patterns.iter().any(|p| query.contains(p))
    }

    /// True if the (lower-cased) query contains an aggregation/analytics keyword. This is plain
    /// substring matching, so identifiers such as `cube_id` also match.
    fn is_olap_pattern(&self, query: &str) -> bool {
        let olap_patterns = [
            "group by", "having", "aggregate", "sum(", "count(",
            "avg(", "window", "partition by", "rollup", "cube",
        ];
        olap_patterns.iter().any(|p| query.contains(p))
    }

    /// Starts (or restarts) tracking `session_id` with the given branch and AI context.
    ///
    /// Re-tracking an existing id replaces its entry and resets its counters. When the table
    /// has reached `max_sessions`, idle sessions are purged first. If the id is new and the
    /// table is still full, the least-recently-active sessions are evicted so the table cannot
    /// grow without bound.
    pub fn track_session(
        &self,
        session_id: impl Into<String>,
        branch: Option<BranchContext>,
        ai_context: AIWorkloadContext,
    ) {
        let sid = SessionId::new(session_id);
        if self.sessions.len() >= self.config.max_sessions {
            self.cleanup_idle_sessions();
            // Idle cleanup alone is not enough: if every session is still active, the map would
            // exceed `max_sessions` with each new id.
            while self.sessions.len() >= self.config.max_sessions
                && !self.sessions.contains_key(&sid)
            {
                if !self.evict_least_recently_active() {
                    break;
                }
            }
        }
        let mut info = SessionTrackingInfo::new(sid.0.clone()).with_ai_context(ai_context);
        if let Some(b) = branch {
            info = info.with_branch(b);
        }
        self.sessions.insert(sid, info);
        self.stats.sessions_tracked.fetch_add(1, Ordering::Relaxed);
    }

    /// Removes the session with the oldest `last_activity`.
    /// Returns `false` only when there was no session to evict.
    fn evict_least_recently_active(&self) -> bool {
        // The iterator and its entry guard are temporaries dropped at the end of this
        // statement, before `remove` locks the shard again.
        let oldest = self
            .sessions
            .iter()
            .min_by_key(|entry| entry.value().last_activity)
            .map(|entry| entry.key().clone());
        match oldest {
            Some(sid) => {
                self.sessions.remove(&sid);
                true
            }
            None => false,
        }
    }

    /// Records one query for a tracked session, plus a cache hit when `cache_hit` is true.
    /// Unknown sessions are ignored.
    pub fn update_session(&self, session_id: &SessionId, cache_hit: bool) {
        if let Some(mut info) = self.sessions.get_mut(session_id) {
            info.record_query();
            if cache_hit {
                info.record_cache_hit();
                self.stats.cross_feature_hits.fetch_add(1, Ordering::Relaxed);
            }
        }
    }

    /// Returns a snapshot copy of a tracked session, if present.
    pub fn get_session(&self, session_id: &SessionId) -> Option<SessionTrackingInfo> {
        self.sessions.get(session_id).map(|r| r.clone())
    }

    /// Increments the session's transaction nesting depth (saturating). Unknown sessions are
    /// ignored.
    pub fn begin_transaction(&self, session_id: &SessionId) {
        if let Some(mut info) = self.sessions.get_mut(session_id) {
            info.transaction_depth = info.transaction_depth.saturating_add(1);
        }
    }

    /// Decrements the session's transaction nesting depth, never going below zero.
    /// Unknown sessions are ignored.
    pub fn end_transaction(&self, session_id: &SessionId) {
        if let Some(mut info) = self.sessions.get_mut(session_id) {
            info.transaction_depth = info.transaction_depth.saturating_sub(1);
        }
    }

    /// True when the session is tracked and has at least one open transaction.
    pub fn is_in_transaction(&self, session_id: &SessionId) -> bool {
        self.sessions
            .get(session_id)
            .map_or(false, |info| info.transaction_depth > 0)
    }

    /// Maps a detection to a caching policy.
    ///
    /// AI contexts take precedence. For `AIWorkloadContext::General`, the policy is chosen by
    /// `workload_type` (OLTP, OLAP or Vector), and any other type gets
    /// `CacheRecommendation::default()`.
    pub fn get_cache_recommendation(&self, detection: &AIWorkloadDetection) -> CacheRecommendation {
        match detection.ai_context {
            AIWorkloadContext::RAGRetrieval => CacheRecommendation {
                should_cache: true,
                ttl: Duration::from_secs(300),
                priority: CachePriority::High,
                tier: RecommendedTier::L1,
            },
            AIWorkloadContext::RAGGeneration => CacheRecommendation {
                should_cache: true,
                ttl: Duration::from_secs(1800),
                priority: CachePriority::Medium,
                tier: RecommendedTier::L2,
            },
            AIWorkloadContext::AgentConversation => CacheRecommendation {
                should_cache: true,
                ttl: Duration::from_secs(3600),
                priority: CachePriority::High,
                tier: RecommendedTier::L1,
            },
            AIWorkloadContext::ToolResult => CacheRecommendation {
                should_cache: true,
                ttl: Duration::from_secs(86400),
                priority: CachePriority::Low,
                tier: RecommendedTier::L2,
            },
            AIWorkloadContext::General => match detection.workload_type {
                WorkloadType::OLTP => CacheRecommendation {
                    should_cache: true,
                    ttl: Duration::from_secs(60),
                    priority: CachePriority::High,
                    tier: RecommendedTier::L1,
                },
                WorkloadType::OLAP => CacheRecommendation {
                    should_cache: true,
                    ttl: Duration::from_secs(3600),
                    priority: CachePriority::Low,
                    tier: RecommendedTier::L3,
                },
                WorkloadType::Vector => CacheRecommendation {
                    should_cache: true,
                    ttl: Duration::from_secs(600),
                    priority: CachePriority::Medium,
                    tier: RecommendedTier::L2,
                },
                _ => CacheRecommendation::default(),
            },
        }
    }

    /// Submits `query_id` to the attached scheduler, tagged with the detected workload type.
    /// Returns `None` when no scheduler is attached.
    pub fn schedule_with_ai_priority(
        &self,
        query_id: u64,
        detection: &AIWorkloadDetection,
    ) -> Option<ScheduleResult> {
        let scheduler = self.scheduler.as_ref()?;
        let query = ScheduledQuery {
            id: query_id,
            workload_type: detection.workload_type,
            timestamp: Instant::now(),
        };
        Some(scheduler.schedule(query))
    }

    /// Drops every session idle for longer than `session_idle_timeout`.
    ///
    /// `retain` checks and removes each entry under its shard's write lock, so a session
    /// refreshed concurrently is not removed on the basis of a stale scan.
    pub fn cleanup_idle_sessions(&self) {
        let timeout = self.config.session_idle_timeout;
        self.sessions.retain(|_, info| !info.is_idle(timeout));
    }

    /// Invalidates cached entries for `branch`; returns the count reported by the cache.
    pub fn invalidate_branch(&self, branch: &BranchId) -> usize {
        self.semantic_cache.invalidate_branch(branch)
    }

    /// Invalidates cached entries that reference `table`; returns the count reported by the
    /// cache.
    pub fn invalidate_table(&self, table: &str) -> usize {
        self.semantic_cache.invalidate_by_table(table)
    }

    /// Point-in-time copy of the counters plus the current number of tracked sessions.
    /// The `sessions_tracked` counter is not part of the snapshot.
    pub fn stats(&self) -> AIIntegrationStatsSnapshot {
        AIIntegrationStatsSnapshot {
            total_detections: self.stats.detections.load(Ordering::Relaxed),
            rag_detected: self.stats.rag_detected.load(Ordering::Relaxed),
            agent_detected: self.stats.agent_detected.load(Ordering::Relaxed),
            tool_detected: self.stats.tool_detected.load(Ordering::Relaxed),
            active_sessions: self.sessions.len(),
            cross_feature_hits: self.stats.cross_feature_hits.load(Ordering::Relaxed),
        }
    }
}
/// Caching policy suggested for a detected workload.
#[derive(Debug, Clone)]
pub struct CacheRecommendation {
    /// Whether the result should be cached at all.
    pub should_cache: bool,
    /// Suggested time-to-live for the cached entry.
    pub ttl: Duration,
    /// Suggested priority of the cached entry.
    pub priority: CachePriority,
    /// Suggested cache tier for the entry.
    pub tier: RecommendedTier,
}
impl Default for CacheRecommendation {
fn default() -> Self {
Self {
should_cache: true,
ttl: Duration::from_secs(300),
priority: CachePriority::Medium,
tier: RecommendedTier::L1,
}
}
}
/// Relative priority attached to a cache recommendation.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CachePriority {
    /// Highest priority.
    High,
    /// Middle priority.
    Medium,
    /// Lowest priority.
    Low,
}
/// Cache tier suggested by a cache recommendation.
/// NOTE(review): the physical meaning of each tier is defined outside this file — confirm.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RecommendedTier {
    /// First tier.
    L1,
    /// Second tier.
    L2,
    /// Third tier.
    L3,
}
/// Plain-value copy of the coordinator's counters at one point in time.
#[derive(Debug, Clone)]
pub struct AIIntegrationStatsSnapshot {
    /// Total number of workload detections performed.
    pub total_detections: u64,
    /// Detections in which the RAG detector matched.
    pub rag_detected: u64,
    /// Detections in which the agent detector matched.
    pub agent_detected: u64,
    /// Detections in which the tool detector matched.
    pub tool_detected: u64,
    /// Number of sessions tracked when the snapshot was taken.
    pub active_sessions: usize,
    /// Session updates that were reported as cache hits.
    pub cross_feature_hits: u64,
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Semantic cache with the 0.9 threshold shared by every test.
    fn new_cache() -> Arc<SemanticQueryCache> {
        Arc::new(SemanticQueryCache::new(0.9))
    }

    /// Coordinator with the default configuration over `cache`.
    fn coordinator_with(cache: Arc<SemanticQueryCache>) -> AIIntegrationCoordinator {
        AIIntegrationCoordinator::new(cache, AIIntegrationConfig::default())
    }

    /// Coordinator with the default configuration over a fresh cache.
    fn new_coordinator() -> AIIntegrationCoordinator {
        coordinator_with(new_cache())
    }

    /// Hand-built detection used to probe cache recommendations.
    fn detection(workload_type: WorkloadType, ai_context: AIWorkloadContext) -> AIWorkloadDetection {
        AIWorkloadDetection {
            workload_type,
            ai_context,
            confidence: 0.9,
            patterns: Vec::new(),
        }
    }

    #[test]
    fn test_workload_detection_rag() {
        let coordinator = new_coordinator();
        let result = coordinator.detect_workload(
            "SELECT * FROM chunks WHERE document_id = 1 AND similarity > 0.8",
            None,
        );
        assert_eq!(result.workload_type, WorkloadType::RAG);
        assert_eq!(result.ai_context, AIWorkloadContext::RAGRetrieval);
        assert!(result.confidence > 0.7);
    }

    #[test]
    fn test_workload_detection_agent() {
        let coordinator = new_coordinator();
        let result = coordinator.detect_workload(
            "SELECT * FROM conversation_history WHERE session_id = 'abc'",
            None,
        );
        assert_eq!(result.ai_context, AIWorkloadContext::AgentConversation);
        assert!(result.patterns.iter().any(|p| p == "Agent conversation"));
    }

    #[test]
    fn test_workload_detection_tool() {
        let coordinator = new_coordinator();
        let result = coordinator.detect_workload(
            "SELECT tool_calculate_result FROM api_result WHERE id = 1",
            None,
        );
        assert_eq!(result.ai_context, AIWorkloadContext::ToolResult);
    }

    #[test]
    fn test_workload_detection_olap() {
        let coordinator = new_coordinator();
        let result = coordinator.detect_workload(
            "SELECT category, SUM(amount) FROM orders GROUP BY category HAVING COUNT(*) > 10",
            None,
        );
        assert_eq!(result.workload_type, WorkloadType::OLAP);
    }

    #[test]
    fn test_session_tracking() {
        let coordinator = new_coordinator();
        let sid = SessionId::new("session-1");
        coordinator.track_session(
            "session-1",
            Some(BranchContext::main()),
            AIWorkloadContext::AgentConversation,
        );
        for hit in [true, false] {
            coordinator.update_session(&sid, hit);
        }
        let info = coordinator
            .get_session(&sid)
            .expect("session-1 was tracked above");
        assert_eq!(info.query_count, 2);
        assert_eq!(info.cache_hits, 1);
        assert!((info.cache_hit_rate() - 0.5).abs() < 0.01);
    }

    #[test]
    fn test_transaction_tracking() {
        let coordinator = new_coordinator();
        let sid = SessionId::new("session-tx");
        coordinator.track_session("session-tx", None, AIWorkloadContext::General);

        assert!(!coordinator.is_in_transaction(&sid));
        coordinator.begin_transaction(&sid);
        assert!(coordinator.is_in_transaction(&sid));
        coordinator.end_transaction(&sid);
        assert!(!coordinator.is_in_transaction(&sid));
    }

    #[test]
    fn test_cache_recommendation() {
        let coordinator = new_coordinator();

        let rag = coordinator.get_cache_recommendation(&detection(
            WorkloadType::RAG,
            AIWorkloadContext::RAGRetrieval,
        ));
        assert_eq!(rag.priority, CachePriority::High);
        assert_eq!(rag.tier, RecommendedTier::L1);

        let tool = coordinator.get_cache_recommendation(&detection(
            WorkloadType::AIAgent,
            AIWorkloadContext::ToolResult,
        ));
        assert_eq!(tool.ttl, Duration::from_secs(86400));
    }

    #[test]
    fn test_stats() {
        let coordinator = new_coordinator();
        let queries = [
            "SELECT * FROM chunks",
            "SELECT conversation_history",
            "SELECT tool_result",
        ];
        for query in queries.iter() {
            coordinator.detect_workload(query, None);
        }
        assert_eq!(coordinator.stats().total_detections, 3);
    }

    #[test]
    fn test_invalidation() {
        let cache = new_cache();
        cache.insert_with_context(
            "query1",
            vec![1.0, 0.0],
            serde_json::json!(1),
            Some(BranchContext::main()),
            None,
            AIWorkloadContext::General,
            vec!["users".to_string()],
        );
        let coordinator = coordinator_with(Arc::clone(&cache));
        let removed = coordinator.invalidate_table("users");
        assert_eq!(removed, 1);
        assert_eq!(cache.len(), 0);
    }
}