1use std::sync::atomic::{AtomicU64, Ordering};
10use std::sync::Arc;
11use std::time::{Duration, Instant};
12
13use dashmap::DashMap;
14
15use super::semantic::{AIWorkloadContext, BranchContext, BranchId, SemanticQueryCache};
16use crate::distribcache::classifier::WorkloadType;
17use crate::distribcache::scheduler::{ScheduleResult, ScheduledQuery, WorkloadScheduler};
18use crate::distribcache::SessionId;
19
/// Result of classifying a single query with the coordinator's keyword heuristics.
#[derive(Debug, Clone)]
pub struct AIWorkloadDetection {
    /// Coarse workload class chosen for the query.
    pub workload_type: WorkloadType,
    /// AI-specific context chosen for the query (`General` when no AI heuristic matched).
    pub ai_context: AIWorkloadContext,
    /// Confidence of the classification; `detect_workload` reports the highest
    /// confidence among the heuristics that matched (0.0 when none did).
    pub confidence: f32,
    /// Human-readable labels of the heuristics that matched, in match order.
    pub patterns: Vec<String>,
}
32
/// Per-session bookkeeping kept by the AI integration coordinator.
#[derive(Debug, Clone)]
pub struct SessionTrackingInfo {
    /// Identifier of the tracked session.
    pub session_id: SessionId,
    /// Branch context the session is bound to, if any.
    pub branch: Option<BranchContext>,
    /// Time of the most recent recorded query (set on creation and by `record_query`);
    /// compared against a timeout by `is_idle`.
    pub last_activity: Instant,
    /// Transaction nesting depth; 0 means the session is not inside a transaction.
    pub transaction_depth: u32,
    /// AI workload context associated with the session.
    pub ai_context: AIWorkloadContext,
    /// Number of queries recorded via `record_query`.
    pub query_count: u64,
    /// Number of cache hits recorded via `record_cache_hit`.
    pub cache_hits: u64,
}
51
52impl SessionTrackingInfo {
53 pub fn new(session_id: impl Into<String>) -> Self {
55 Self {
56 session_id: SessionId::new(session_id),
57 branch: None,
58 last_activity: Instant::now(),
59 transaction_depth: 0,
60 ai_context: AIWorkloadContext::General,
61 query_count: 0,
62 cache_hits: 0,
63 }
64 }
65
66 pub fn with_branch(mut self, branch: BranchContext) -> Self {
68 self.branch = Some(branch);
69 self
70 }
71
72 pub fn with_ai_context(mut self, context: AIWorkloadContext) -> Self {
74 self.ai_context = context;
75 self
76 }
77
78 pub fn record_query(&mut self) {
80 self.query_count += 1;
81 self.last_activity = Instant::now();
82 }
83
84 pub fn record_cache_hit(&mut self) {
86 self.cache_hits += 1;
87 }
88
89 pub fn cache_hit_rate(&self) -> f64 {
91 if self.query_count == 0 {
92 0.0
93 } else {
94 self.cache_hits as f64 / self.query_count as f64
95 }
96 }
97
98 pub fn is_idle(&self, timeout: Duration) -> bool {
100 self.last_activity.elapsed() > timeout
101 }
102}
103
/// Ties AI workload detection, per-session tracking, cache recommendations and
/// optional scheduling together on top of a shared semantic query cache.
pub struct AIIntegrationCoordinator {
    /// Shared semantic cache; branch/table invalidation is delegated to it.
    semantic_cache: Arc<SemanticQueryCache>,

    /// Tracked sessions keyed by session id (concurrent map, so `&self` methods can mutate).
    sessions: DashMap<SessionId, SessionTrackingInfo>,

    /// Optional scheduler; `schedule_with_ai_priority` returns `None` without one.
    scheduler: Option<Arc<WorkloadScheduler>>,

    /// Coordinator configuration (timeouts, limits, detection settings).
    config: AIIntegrationConfig,

    /// Internal atomic counters, exposed via `stats()`.
    stats: AIIntegrationStats,
}
121
/// Tunables for the AI integration coordinator.
#[derive(Debug, Clone)]
pub struct AIIntegrationConfig {
    /// TWR tracking flag (default `true`).
    /// NOTE(review): confirm where this is consumed.
    pub twr_tracking: bool,
    /// Inactivity period after which a session counts as idle (default one hour).
    pub session_idle_timeout: Duration,
    /// Intended upper bound on tracked sessions (default 10 000).
    pub max_sessions: usize,
    /// Feature flag for workload detection (default `true`).
    pub workload_detection: bool,
    /// Confidence threshold for RAG detection (default `0.7`).
    pub rag_detection_threshold: f32,
    /// Confidence threshold for agent detection (default `0.8`).
    pub agent_detection_threshold: f32,
}

impl Default for AIIntegrationConfig {
    /// Defaults: every feature flag on, a one-hour idle timeout, 10 000
    /// sessions, and detection thresholds of 0.7 (RAG) / 0.8 (agent).
    fn default() -> Self {
        let one_hour = Duration::from_secs(60 * 60);
        AIIntegrationConfig {
            twr_tracking: true,
            workload_detection: true,
            session_idle_timeout: one_hour,
            max_sessions: 10_000,
            rag_detection_threshold: 0.7,
            agent_detection_threshold: 0.8,
        }
    }
}
151
/// Internal counters of the coordinator; all start at zero and are updated
/// with `Ordering::Relaxed` atomics.
#[derive(Debug, Default)]
struct AIIntegrationStats {
    /// Number of `detect_workload` runs.
    detections: AtomicU64,
    /// Queries classified as RAG retrieval.
    rag_detected: AtomicU64,
    /// Queries classified as agent conversation.
    agent_detected: AtomicU64,
    /// Queries classified as tool results.
    tool_detected: AtomicU64,
    /// Calls to `track_session` (re-tracking an existing id also counts).
    /// NOTE(review): not included in `AIIntegrationStatsSnapshot`.
    sessions_tracked: AtomicU64,
    /// Cache hits reported through `update_session`.
    cross_feature_hits: AtomicU64,
}
168
169impl AIIntegrationCoordinator {
170 pub fn new(semantic_cache: Arc<SemanticQueryCache>, config: AIIntegrationConfig) -> Self {
172 Self {
173 semantic_cache,
174 sessions: DashMap::new(),
175 scheduler: None,
176 config,
177 stats: AIIntegrationStats::default(),
178 }
179 }
180
181 pub fn with_scheduler(mut self, scheduler: Arc<WorkloadScheduler>) -> Self {
183 self.scheduler = Some(scheduler);
184 self
185 }
186
187 pub fn detect_workload(&self, query: &str, session: Option<&SessionId>) -> AIWorkloadDetection {
189 self.stats.detections.fetch_add(1, Ordering::Relaxed);
190
191 let mut patterns = Vec::new();
192 let mut confidence = 0.0f32;
193 let mut ai_context = AIWorkloadContext::General;
194 let mut workload_type = WorkloadType::Mixed;
195
196 let query_lower = query.to_lowercase();
198
199 if self.is_rag_pattern(&query_lower) {
201 patterns.push("RAG retrieval".to_string());
202 confidence = 0.85;
203 ai_context = AIWorkloadContext::RAGRetrieval;
204 workload_type = WorkloadType::RAG;
205 self.stats.rag_detected.fetch_add(1, Ordering::Relaxed);
206 }
207
208 if self.is_agent_pattern(&query_lower, session) {
210 patterns.push("Agent conversation".to_string());
211 confidence = confidence.max(0.80);
212 ai_context = AIWorkloadContext::AgentConversation;
213 workload_type = WorkloadType::AIAgent;
214 self.stats.agent_detected.fetch_add(1, Ordering::Relaxed);
215 }
216
217 if self.is_tool_pattern(&query_lower) {
219 patterns.push("Tool result".to_string());
220 confidence = confidence.max(0.90);
221 ai_context = AIWorkloadContext::ToolResult;
222 workload_type = WorkloadType::AIAgent;
223 self.stats.tool_detected.fetch_add(1, Ordering::Relaxed);
224 }
225
226 if workload_type != WorkloadType::RAG
228 && (query_lower.contains("embedding")
229 || query_lower.contains("vector")
230 || query_lower.contains("similarity"))
231 {
232 patterns.push("Vector search".to_string());
233 confidence = confidence.max(0.75);
234 workload_type = WorkloadType::Vector;
235 }
236
237 if self.is_olap_pattern(&query_lower) {
239 patterns.push("OLAP analytics".to_string());
240 confidence = confidence.max(0.70);
241 workload_type = WorkloadType::OLAP;
242 }
243
244 AIWorkloadDetection {
245 workload_type,
246 ai_context,
247 confidence,
248 patterns,
249 }
250 }
251
252 fn is_rag_pattern(&self, query: &str) -> bool {
254 let rag_patterns = [
259 "chunk",
260 "retrieve",
261 "context",
262 "passage",
263 "document",
264 "semantic",
265 "similarity",
266 "cosine",
267 "embedding",
268 ];
269
270 rag_patterns.iter().any(|p| query.contains(p))
271 }
272
273 fn is_agent_pattern(&self, query: &str, session: Option<&SessionId>) -> bool {
275 if let Some(sid) = session {
277 if let Some(info) = self.sessions.get(sid) {
278 if info.ai_context == AIWorkloadContext::AgentConversation {
279 return true;
280 }
281 if info.query_count > 5 && info.cache_hit_rate() > 0.3 {
283 return true;
284 }
285 }
286 }
287
288 let agent_patterns = ["conversation", "history", "context", "message", "response"];
290 agent_patterns.iter().any(|p| query.contains(p))
291 }
292
293 fn is_tool_pattern(&self, query: &str) -> bool {
295 let tool_patterns = [
296 "tool_",
297 "function_",
298 "api_result",
299 "calculate",
300 "format_",
301 "convert_",
302 "lookup_",
303 ];
304
305 tool_patterns.iter().any(|p| query.contains(p))
306 }
307
308 fn is_olap_pattern(&self, query: &str) -> bool {
310 let olap_patterns = [
311 "group by",
312 "having",
313 "aggregate",
314 "sum(",
315 "count(",
316 "avg(",
317 "window",
318 "partition by",
319 "rollup",
320 "cube",
321 ];
322
323 olap_patterns.iter().any(|p| query.contains(p))
324 }
325
326 pub fn track_session(
328 &self,
329 session_id: impl Into<String>,
330 branch: Option<BranchContext>,
331 ai_context: AIWorkloadContext,
332 ) {
333 let sid = SessionId::new(session_id);
334
335 if self.sessions.len() >= self.config.max_sessions {
337 self.cleanup_idle_sessions();
338 }
339
340 let mut info = SessionTrackingInfo::new(sid.0.clone()).with_ai_context(ai_context);
341
342 if let Some(b) = branch {
343 info = info.with_branch(b);
344 }
345
346 self.sessions.insert(sid, info);
347 self.stats.sessions_tracked.fetch_add(1, Ordering::Relaxed);
348 }
349
350 pub fn update_session(&self, session_id: &SessionId, cache_hit: bool) {
352 if let Some(mut info) = self.sessions.get_mut(session_id) {
353 info.record_query();
354 if cache_hit {
355 info.record_cache_hit();
356 self.stats
357 .cross_feature_hits
358 .fetch_add(1, Ordering::Relaxed);
359 }
360 }
361 }
362
363 pub fn get_session(&self, session_id: &SessionId) -> Option<SessionTrackingInfo> {
365 self.sessions.get(session_id).map(|r| r.clone())
366 }
367
368 pub fn begin_transaction(&self, session_id: &SessionId) {
370 if let Some(mut info) = self.sessions.get_mut(session_id) {
371 info.transaction_depth += 1;
372 }
373 }
374
375 pub fn end_transaction(&self, session_id: &SessionId) {
377 if let Some(mut info) = self.sessions.get_mut(session_id) {
378 if info.transaction_depth > 0 {
379 info.transaction_depth -= 1;
380 }
381 }
382 }
383
384 pub fn is_in_transaction(&self, session_id: &SessionId) -> bool {
386 self.sessions
387 .get(session_id)
388 .map(|info| info.transaction_depth > 0)
389 .unwrap_or(false)
390 }
391
392 pub fn get_cache_recommendation(&self, detection: &AIWorkloadDetection) -> CacheRecommendation {
394 match detection.ai_context {
395 AIWorkloadContext::RAGRetrieval => CacheRecommendation {
396 should_cache: true,
397 ttl: Duration::from_secs(300),
398 priority: CachePriority::High,
399 tier: RecommendedTier::L1,
400 },
401 AIWorkloadContext::RAGGeneration => CacheRecommendation {
402 should_cache: true,
403 ttl: Duration::from_secs(1800),
404 priority: CachePriority::Medium,
405 tier: RecommendedTier::L2,
406 },
407 AIWorkloadContext::AgentConversation => CacheRecommendation {
408 should_cache: true,
409 ttl: Duration::from_secs(3600),
410 priority: CachePriority::High,
411 tier: RecommendedTier::L1,
412 },
413 AIWorkloadContext::ToolResult => CacheRecommendation {
414 should_cache: true,
415 ttl: Duration::from_secs(86400),
416 priority: CachePriority::Low,
417 tier: RecommendedTier::L2,
418 },
419 AIWorkloadContext::General => match detection.workload_type {
420 WorkloadType::OLTP => CacheRecommendation {
421 should_cache: true,
422 ttl: Duration::from_secs(60),
423 priority: CachePriority::High,
424 tier: RecommendedTier::L1,
425 },
426 WorkloadType::OLAP => CacheRecommendation {
427 should_cache: true,
428 ttl: Duration::from_secs(3600),
429 priority: CachePriority::Low,
430 tier: RecommendedTier::L3,
431 },
432 WorkloadType::Vector => CacheRecommendation {
433 should_cache: true,
434 ttl: Duration::from_secs(600),
435 priority: CachePriority::Medium,
436 tier: RecommendedTier::L2,
437 },
438 _ => CacheRecommendation::default(),
439 },
440 }
441 }
442
443 pub fn schedule_with_ai_priority(
445 &self,
446 query_id: u64,
447 detection: &AIWorkloadDetection,
448 ) -> Option<ScheduleResult> {
449 let scheduler = self.scheduler.as_ref()?;
450
451 let query = ScheduledQuery {
452 id: query_id,
453 workload_type: detection.workload_type,
454 timestamp: std::time::Instant::now(),
455 };
456
457 Some(scheduler.schedule(query))
458 }
459
460 pub fn cleanup_idle_sessions(&self) {
462 let timeout = self.config.session_idle_timeout;
463 let to_remove: Vec<_> = self
464 .sessions
465 .iter()
466 .filter(|e| e.is_idle(timeout))
467 .map(|e| e.key().clone())
468 .collect();
469
470 for sid in to_remove {
471 self.sessions.remove(&sid);
472 }
473 }
474
475 pub fn invalidate_branch(&self, branch: &BranchId) -> usize {
477 self.semantic_cache.invalidate_branch(branch)
478 }
479
480 pub fn invalidate_table(&self, table: &str) -> usize {
482 self.semantic_cache.invalidate_by_table(table)
483 }
484
485 pub fn stats(&self) -> AIIntegrationStatsSnapshot {
487 AIIntegrationStatsSnapshot {
488 total_detections: self.stats.detections.load(Ordering::Relaxed),
489 rag_detected: self.stats.rag_detected.load(Ordering::Relaxed),
490 agent_detected: self.stats.agent_detected.load(Ordering::Relaxed),
491 tool_detected: self.stats.tool_detected.load(Ordering::Relaxed),
492 active_sessions: self.sessions.len(),
493 cross_feature_hits: self.stats.cross_feature_hits.load(Ordering::Relaxed),
494 }
495 }
496}
497
/// Caching advice produced for a detected workload.
#[derive(Clone, Debug)]
pub struct CacheRecommendation {
    /// Whether the result should be cached at all.
    pub should_cache: bool,
    /// Time-to-live for the cached entry.
    pub ttl: Duration,
    /// Relative priority of the entry.
    pub priority: CachePriority,
    /// Cache tier the entry should be placed in.
    pub tier: RecommendedTier,
}

impl Default for CacheRecommendation {
    /// Cache in L1 at medium priority for five minutes.
    fn default() -> Self {
        let five_minutes = Duration::from_secs(5 * 60);
        CacheRecommendation {
            tier: RecommendedTier::L1,
            priority: CachePriority::Medium,
            ttl: five_minutes,
            should_cache: true,
        }
    }
}

/// Relative priority attached to a cache recommendation.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum CachePriority {
    High,
    Medium,
    Low,
}

/// Cache tier a recommendation targets.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum RecommendedTier {
    L1,
    L2,
    L3,
}
537
/// Point-in-time copy of the coordinator's counters, as returned by `stats()`.
#[derive(Debug, Clone)]
pub struct AIIntegrationStatsSnapshot {
    /// Number of `detect_workload` runs.
    pub total_detections: u64,
    /// Queries classified as RAG retrieval.
    pub rag_detected: u64,
    /// Queries classified as agent conversation.
    pub agent_detected: u64,
    /// Queries classified as tool results.
    pub tool_detected: u64,
    /// Number of sessions currently in the tracking map.
    pub active_sessions: usize,
    /// Cache hits reported through `update_session`.
    pub cross_feature_hits: u64,
}
548
#[cfg(test)]
mod tests {
    use super::*;

    /// Coordinator over a fresh semantic cache (constructed with `0.9`) and the
    /// default configuration.
    fn new_coordinator() -> AIIntegrationCoordinator {
        let cache = Arc::new(SemanticQueryCache::new(0.9));
        AIIntegrationCoordinator::new(cache, AIIntegrationConfig::default())
    }

    #[test]
    fn test_workload_detection_rag() {
        let detection = new_coordinator().detect_workload(
            "SELECT * FROM chunks WHERE document_id = 1 AND similarity > 0.8",
            None,
        );

        assert_eq!(detection.workload_type, WorkloadType::RAG);
        assert_eq!(detection.ai_context, AIWorkloadContext::RAGRetrieval);
        assert!(detection.confidence > 0.7);
    }

    #[test]
    fn test_workload_detection_agent() {
        let detection = new_coordinator().detect_workload(
            "SELECT * FROM conversation_history WHERE session_id = 'abc'",
            None,
        );

        assert_eq!(detection.ai_context, AIWorkloadContext::AgentConversation);
        assert!(detection
            .patterns
            .iter()
            .any(|p| p == "Agent conversation"));
    }

    #[test]
    fn test_workload_detection_tool() {
        let detection = new_coordinator().detect_workload(
            "SELECT tool_calculate_result FROM api_result WHERE id = 1",
            None,
        );

        assert_eq!(detection.ai_context, AIWorkloadContext::ToolResult);
    }

    #[test]
    fn test_workload_detection_olap() {
        let detection = new_coordinator().detect_workload(
            "SELECT category, SUM(amount) FROM orders GROUP BY category HAVING COUNT(*) > 10",
            None,
        );

        assert_eq!(detection.workload_type, WorkloadType::OLAP);
    }

    #[test]
    fn test_session_tracking() {
        let coordinator = new_coordinator();
        let sid = SessionId::new("session-1");

        coordinator.track_session(
            "session-1",
            Some(BranchContext::main()),
            AIWorkloadContext::AgentConversation,
        );

        // One hit followed by one miss.
        for &hit in &[true, false] {
            coordinator.update_session(&sid, hit);
        }

        let info = coordinator
            .get_session(&sid)
            .expect("session-1 was just tracked");
        assert_eq!(info.query_count, 2);
        assert_eq!(info.cache_hits, 1);
        assert!((info.cache_hit_rate() - 0.5).abs() < 0.01);
    }

    #[test]
    fn test_transaction_tracking() {
        let coordinator = new_coordinator();
        let sid = SessionId::new("session-tx");
        coordinator.track_session("session-tx", None, AIWorkloadContext::General);

        assert!(!coordinator.is_in_transaction(&sid));

        coordinator.begin_transaction(&sid);
        assert!(coordinator.is_in_transaction(&sid));

        coordinator.end_transaction(&sid);
        assert!(!coordinator.is_in_transaction(&sid));
    }

    #[test]
    fn test_cache_recommendation() {
        let coordinator = new_coordinator();

        let rag_rec = coordinator.get_cache_recommendation(&AIWorkloadDetection {
            workload_type: WorkloadType::RAG,
            ai_context: AIWorkloadContext::RAGRetrieval,
            confidence: 0.9,
            patterns: Vec::new(),
        });
        assert_eq!(rag_rec.priority, CachePriority::High);
        assert_eq!(rag_rec.tier, RecommendedTier::L1);

        let tool_rec = coordinator.get_cache_recommendation(&AIWorkloadDetection {
            workload_type: WorkloadType::AIAgent,
            ai_context: AIWorkloadContext::ToolResult,
            confidence: 0.9,
            patterns: Vec::new(),
        });
        assert_eq!(tool_rec.ttl, Duration::from_secs(86400));
    }

    #[test]
    fn test_stats() {
        let coordinator = new_coordinator();

        let queries = [
            "SELECT * FROM chunks",
            "SELECT conversation_history",
            "SELECT tool_result",
        ];
        for query in &queries {
            coordinator.detect_workload(query, None);
        }

        assert_eq!(coordinator.stats().total_detections, 3);
    }

    #[test]
    fn test_invalidation() {
        let cache = Arc::new(SemanticQueryCache::new(0.9));

        cache.insert_with_context(
            "query1",
            vec![1.0, 0.0],
            serde_json::json!(1),
            Some(BranchContext::main()),
            None,
            AIWorkloadContext::General,
            vec!["users".to_string()],
        );

        let coordinator =
            AIIntegrationCoordinator::new(Arc::clone(&cache), AIIntegrationConfig::default());

        assert_eq!(coordinator.invalidate_table("users"), 1);
        assert_eq!(cache.len(), 0);
    }
}