1use std::sync::atomic::{AtomicU64, Ordering};
10use std::sync::Arc;
11use std::time::{Duration, Instant};
12
13use dashmap::DashMap;
14
15use super::semantic::{AIWorkloadContext, BranchContext, BranchId, SemanticQueryCache};
16use crate::distribcache::classifier::WorkloadType;
17use crate::distribcache::scheduler::{ScheduleResult, ScheduledQuery, WorkloadScheduler};
18use crate::distribcache::SessionId;
19
20#[derive(Debug, Clone)]
22pub struct AIWorkloadDetection {
23 pub workload_type: WorkloadType,
25 pub ai_context: AIWorkloadContext,
27 pub confidence: f32,
29 pub patterns: Vec<String>,
31}
32
33#[derive(Debug, Clone)]
35pub struct SessionTrackingInfo {
36 pub session_id: SessionId,
38 pub branch: Option<BranchContext>,
40 pub last_activity: Instant,
42 pub transaction_depth: u32,
44 pub ai_context: AIWorkloadContext,
46 pub query_count: u64,
48 pub cache_hits: u64,
50}
51
52impl SessionTrackingInfo {
53 pub fn new(session_id: impl Into<String>) -> Self {
55 Self {
56 session_id: SessionId::new(session_id),
57 branch: None,
58 last_activity: Instant::now(),
59 transaction_depth: 0,
60 ai_context: AIWorkloadContext::General,
61 query_count: 0,
62 cache_hits: 0,
63 }
64 }
65
66 pub fn with_branch(mut self, branch: BranchContext) -> Self {
68 self.branch = Some(branch);
69 self
70 }
71
72 pub fn with_ai_context(mut self, context: AIWorkloadContext) -> Self {
74 self.ai_context = context;
75 self
76 }
77
78 pub fn record_query(&mut self) {
80 self.query_count += 1;
81 self.last_activity = Instant::now();
82 }
83
84 pub fn record_cache_hit(&mut self) {
86 self.cache_hits += 1;
87 }
88
89 pub fn cache_hit_rate(&self) -> f64 {
91 if self.query_count == 0 {
92 0.0
93 } else {
94 self.cache_hits as f64 / self.query_count as f64
95 }
96 }
97
98 pub fn is_idle(&self, timeout: Duration) -> bool {
100 self.last_activity.elapsed() > timeout
101 }
102}
103
104pub struct AIIntegrationCoordinator {
106 semantic_cache: Arc<SemanticQueryCache>,
108
109 sessions: DashMap<SessionId, SessionTrackingInfo>,
111
112 scheduler: Option<Arc<WorkloadScheduler>>,
114
115 config: AIIntegrationConfig,
117
118 stats: AIIntegrationStats,
120}
121
/// Configuration for [`AIIntegrationCoordinator`].
#[derive(Clone, Debug)]
pub struct AIIntegrationConfig {
    /// Feature flag for TWR tracking.
    pub twr_tracking: bool,
    /// Inactivity period after which a tracked session counts as idle.
    pub session_idle_timeout: Duration,
    /// Upper bound on the number of tracked sessions.
    pub max_sessions: usize,
    /// Feature flag for workload detection.
    pub workload_detection: bool,
    /// Minimum confidence for reporting a RAG classification.
    pub rag_detection_threshold: f32,
    /// Minimum confidence for reporting an agent classification.
    pub agent_detection_threshold: f32,
}
138
139impl Default for AIIntegrationConfig {
140 fn default() -> Self {
141 Self {
142 twr_tracking: true,
143 session_idle_timeout: Duration::from_secs(3600), max_sessions: 10000,
145 workload_detection: true,
146 rag_detection_threshold: 0.7,
147 agent_detection_threshold: 0.8,
148 }
149 }
150}
151
/// Internal relaxed-atomic counters; read out via `AIIntegrationCoordinator::stats`.
#[derive(Default, Debug)]
struct AIIntegrationStats {
    /// Calls to `detect_workload`.
    detections: AtomicU64,
    /// Queries classified as RAG retrieval.
    rag_detected: AtomicU64,
    /// Queries classified as agent conversation.
    agent_detected: AtomicU64,
    /// Queries classified as tool results.
    tool_detected: AtomicU64,
    /// Calls to `track_session` (not included in the public snapshot).
    sessions_tracked: AtomicU64,
    /// Cache hits reported through `update_session`.
    cross_feature_hits: AtomicU64,
}
168
169impl AIIntegrationCoordinator {
170 pub fn new(semantic_cache: Arc<SemanticQueryCache>, config: AIIntegrationConfig) -> Self {
172 Self {
173 semantic_cache,
174 sessions: DashMap::new(),
175 scheduler: None,
176 config,
177 stats: AIIntegrationStats::default(),
178 }
179 }
180
181 pub fn with_scheduler(mut self, scheduler: Arc<WorkloadScheduler>) -> Self {
183 self.scheduler = Some(scheduler);
184 self
185 }
186
187 pub fn detect_workload(&self, query: &str, session: Option<&SessionId>) -> AIWorkloadDetection {
189 self.stats.detections.fetch_add(1, Ordering::Relaxed);
190
191 let mut patterns = Vec::new();
192 let mut confidence = 0.0f32;
193 let mut ai_context = AIWorkloadContext::General;
194 let mut workload_type = WorkloadType::Mixed;
195
196 let query_lower = query.to_lowercase();
198
199 if self.is_rag_pattern(&query_lower) {
201 patterns.push("RAG retrieval".to_string());
202 confidence = 0.85;
203 ai_context = AIWorkloadContext::RAGRetrieval;
204 workload_type = WorkloadType::RAG;
205 self.stats.rag_detected.fetch_add(1, Ordering::Relaxed);
206 }
207
208 if self.is_agent_pattern(&query_lower, session) {
210 patterns.push("Agent conversation".to_string());
211 confidence = confidence.max(0.80);
212 ai_context = AIWorkloadContext::AgentConversation;
213 workload_type = WorkloadType::AIAgent;
214 self.stats.agent_detected.fetch_add(1, Ordering::Relaxed);
215 }
216
217 if self.is_tool_pattern(&query_lower) {
219 patterns.push("Tool result".to_string());
220 confidence = confidence.max(0.90);
221 ai_context = AIWorkloadContext::ToolResult;
222 workload_type = WorkloadType::AIAgent;
223 self.stats.tool_detected.fetch_add(1, Ordering::Relaxed);
224 }
225
226 if workload_type != WorkloadType::RAG {
228 if query_lower.contains("embedding") || query_lower.contains("vector") || query_lower.contains("similarity") {
229 patterns.push("Vector search".to_string());
230 confidence = confidence.max(0.75);
231 workload_type = WorkloadType::Vector;
232 }
233 }
234
235 if self.is_olap_pattern(&query_lower) {
237 patterns.push("OLAP analytics".to_string());
238 confidence = confidence.max(0.70);
239 workload_type = WorkloadType::OLAP;
240 }
241
242 AIWorkloadDetection {
243 workload_type,
244 ai_context,
245 confidence,
246 patterns,
247 }
248 }
249
250 fn is_rag_pattern(&self, query: &str) -> bool {
252 let rag_patterns = [
257 "chunk", "retrieve", "context", "passage", "document",
258 "semantic", "similarity", "cosine", "embedding",
259 ];
260
261 rag_patterns.iter().any(|p| query.contains(p))
262 }
263
264 fn is_agent_pattern(&self, query: &str, session: Option<&SessionId>) -> bool {
266 if let Some(sid) = session {
268 if let Some(info) = self.sessions.get(sid) {
269 if info.ai_context == AIWorkloadContext::AgentConversation {
270 return true;
271 }
272 if info.query_count > 5 && info.cache_hit_rate() > 0.3 {
274 return true;
275 }
276 }
277 }
278
279 let agent_patterns = ["conversation", "history", "context", "message", "response"];
281 agent_patterns.iter().any(|p| query.contains(p))
282 }
283
284 fn is_tool_pattern(&self, query: &str) -> bool {
286 let tool_patterns = [
287 "tool_", "function_", "api_result", "calculate",
288 "format_", "convert_", "lookup_",
289 ];
290
291 tool_patterns.iter().any(|p| query.contains(p))
292 }
293
294 fn is_olap_pattern(&self, query: &str) -> bool {
296 let olap_patterns = [
297 "group by", "having", "aggregate", "sum(", "count(",
298 "avg(", "window", "partition by", "rollup", "cube",
299 ];
300
301 olap_patterns.iter().any(|p| query.contains(p))
302 }
303
304 pub fn track_session(
306 &self,
307 session_id: impl Into<String>,
308 branch: Option<BranchContext>,
309 ai_context: AIWorkloadContext,
310 ) {
311 let sid = SessionId::new(session_id);
312
313 if self.sessions.len() >= self.config.max_sessions {
315 self.cleanup_idle_sessions();
316 }
317
318 let mut info = SessionTrackingInfo::new(sid.0.clone())
319 .with_ai_context(ai_context);
320
321 if let Some(b) = branch {
322 info = info.with_branch(b);
323 }
324
325 self.sessions.insert(sid, info);
326 self.stats.sessions_tracked.fetch_add(1, Ordering::Relaxed);
327 }
328
329 pub fn update_session(&self, session_id: &SessionId, cache_hit: bool) {
331 if let Some(mut info) = self.sessions.get_mut(session_id) {
332 info.record_query();
333 if cache_hit {
334 info.record_cache_hit();
335 self.stats.cross_feature_hits.fetch_add(1, Ordering::Relaxed);
336 }
337 }
338 }
339
340 pub fn get_session(&self, session_id: &SessionId) -> Option<SessionTrackingInfo> {
342 self.sessions.get(session_id).map(|r| r.clone())
343 }
344
345 pub fn begin_transaction(&self, session_id: &SessionId) {
347 if let Some(mut info) = self.sessions.get_mut(session_id) {
348 info.transaction_depth += 1;
349 }
350 }
351
352 pub fn end_transaction(&self, session_id: &SessionId) {
354 if let Some(mut info) = self.sessions.get_mut(session_id) {
355 if info.transaction_depth > 0 {
356 info.transaction_depth -= 1;
357 }
358 }
359 }
360
361 pub fn is_in_transaction(&self, session_id: &SessionId) -> bool {
363 self.sessions
364 .get(session_id)
365 .map(|info| info.transaction_depth > 0)
366 .unwrap_or(false)
367 }
368
369 pub fn get_cache_recommendation(&self, detection: &AIWorkloadDetection) -> CacheRecommendation {
371 match detection.ai_context {
372 AIWorkloadContext::RAGRetrieval => CacheRecommendation {
373 should_cache: true,
374 ttl: Duration::from_secs(300),
375 priority: CachePriority::High,
376 tier: RecommendedTier::L1,
377 },
378 AIWorkloadContext::RAGGeneration => CacheRecommendation {
379 should_cache: true,
380 ttl: Duration::from_secs(1800),
381 priority: CachePriority::Medium,
382 tier: RecommendedTier::L2,
383 },
384 AIWorkloadContext::AgentConversation => CacheRecommendation {
385 should_cache: true,
386 ttl: Duration::from_secs(3600),
387 priority: CachePriority::High,
388 tier: RecommendedTier::L1,
389 },
390 AIWorkloadContext::ToolResult => CacheRecommendation {
391 should_cache: true,
392 ttl: Duration::from_secs(86400),
393 priority: CachePriority::Low,
394 tier: RecommendedTier::L2,
395 },
396 AIWorkloadContext::General => {
397 match detection.workload_type {
398 WorkloadType::OLTP => CacheRecommendation {
399 should_cache: true,
400 ttl: Duration::from_secs(60),
401 priority: CachePriority::High,
402 tier: RecommendedTier::L1,
403 },
404 WorkloadType::OLAP => CacheRecommendation {
405 should_cache: true,
406 ttl: Duration::from_secs(3600),
407 priority: CachePriority::Low,
408 tier: RecommendedTier::L3,
409 },
410 WorkloadType::Vector => CacheRecommendation {
411 should_cache: true,
412 ttl: Duration::from_secs(600),
413 priority: CachePriority::Medium,
414 tier: RecommendedTier::L2,
415 },
416 _ => CacheRecommendation::default(),
417 }
418 }
419 }
420 }
421
422 pub fn schedule_with_ai_priority(
424 &self,
425 query_id: u64,
426 detection: &AIWorkloadDetection,
427 ) -> Option<ScheduleResult> {
428 let scheduler = self.scheduler.as_ref()?;
429
430 let query = ScheduledQuery {
431 id: query_id,
432 workload_type: detection.workload_type,
433 timestamp: std::time::Instant::now(),
434 };
435
436 Some(scheduler.schedule(query))
437 }
438
439 pub fn cleanup_idle_sessions(&self) {
441 let timeout = self.config.session_idle_timeout;
442 let to_remove: Vec<_> = self.sessions
443 .iter()
444 .filter(|e| e.is_idle(timeout))
445 .map(|e| e.key().clone())
446 .collect();
447
448 for sid in to_remove {
449 self.sessions.remove(&sid);
450 }
451 }
452
453 pub fn invalidate_branch(&self, branch: &BranchId) -> usize {
455 self.semantic_cache.invalidate_branch(branch)
456 }
457
458 pub fn invalidate_table(&self, table: &str) -> usize {
460 self.semantic_cache.invalidate_by_table(table)
461 }
462
463 pub fn stats(&self) -> AIIntegrationStatsSnapshot {
465 AIIntegrationStatsSnapshot {
466 total_detections: self.stats.detections.load(Ordering::Relaxed),
467 rag_detected: self.stats.rag_detected.load(Ordering::Relaxed),
468 agent_detected: self.stats.agent_detected.load(Ordering::Relaxed),
469 tool_detected: self.stats.tool_detected.load(Ordering::Relaxed),
470 active_sessions: self.sessions.len(),
471 cross_feature_hits: self.stats.cross_feature_hits.load(Ordering::Relaxed),
472 }
473 }
474}
475
476#[derive(Debug, Clone)]
478pub struct CacheRecommendation {
479 pub should_cache: bool,
481 pub ttl: Duration,
483 pub priority: CachePriority,
485 pub tier: RecommendedTier,
487}
488
489impl Default for CacheRecommendation {
490 fn default() -> Self {
491 Self {
492 should_cache: true,
493 ttl: Duration::from_secs(300),
494 priority: CachePriority::Medium,
495 tier: RecommendedTier::L1,
496 }
497 }
498}
499
/// Relative retention priority suggested for a cached entry.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum CachePriority {
    /// Keep preferentially.
    High,
    /// Neutral priority.
    Medium,
    /// First candidate for eviction.
    Low,
}
507
/// Cache tier suggested for an entry (L1, L2 or L3).
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum RecommendedTier {
    /// Tier 1.
    L1,
    /// Tier 2.
    L2,
    /// Tier 3.
    L3,
}
515
/// Plain-value copy of the coordinator's counters at one point in time.
#[derive(Clone, Debug)]
pub struct AIIntegrationStatsSnapshot {
    /// Total calls to `detect_workload`.
    pub total_detections: u64,
    /// Queries classified as RAG retrieval.
    pub rag_detected: u64,
    /// Queries classified as agent conversation.
    pub agent_detected: u64,
    /// Queries classified as tool results.
    pub tool_detected: u64,
    /// Sessions currently tracked.
    pub active_sessions: usize,
    /// Cache hits reported via `update_session`.
    pub cross_feature_hits: u64,
}
526
#[cfg(test)]
mod tests {
    use super::*;

    /// Coordinator over a fresh semantic cache (threshold 0.9) with default config.
    fn new_coordinator() -> AIIntegrationCoordinator {
        let cache = Arc::new(SemanticQueryCache::new(0.9));
        AIIntegrationCoordinator::new(cache, AIIntegrationConfig::default())
    }

    #[test]
    fn test_workload_detection_rag() {
        let coordinator = new_coordinator();
        let sql = "SELECT * FROM chunks WHERE document_id = 1 AND similarity > 0.8";

        let detection = coordinator.detect_workload(sql, None);

        assert_eq!(detection.workload_type, WorkloadType::RAG);
        assert_eq!(detection.ai_context, AIWorkloadContext::RAGRetrieval);
        assert!(detection.confidence > 0.7);
    }

    #[test]
    fn test_workload_detection_agent() {
        let coordinator = new_coordinator();
        let sql = "SELECT * FROM conversation_history WHERE session_id = 'abc'";

        let detection = coordinator.detect_workload(sql, None);

        assert_eq!(detection.ai_context, AIWorkloadContext::AgentConversation);
        assert!(detection.patterns.contains(&"Agent conversation".to_string()));
    }

    #[test]
    fn test_workload_detection_tool() {
        let coordinator = new_coordinator();
        let sql = "SELECT tool_calculate_result FROM api_result WHERE id = 1";

        let detection = coordinator.detect_workload(sql, None);

        assert_eq!(detection.ai_context, AIWorkloadContext::ToolResult);
    }

    #[test]
    fn test_workload_detection_olap() {
        let coordinator = new_coordinator();
        let sql = "SELECT category, SUM(amount) FROM orders GROUP BY category HAVING COUNT(*) > 10";

        let detection = coordinator.detect_workload(sql, None);

        assert_eq!(detection.workload_type, WorkloadType::OLAP);
    }

    #[test]
    fn test_session_tracking() {
        let coordinator = new_coordinator();
        let sid = SessionId::new("session-1");

        coordinator.track_session(
            "session-1",
            Some(BranchContext::main()),
            AIWorkloadContext::AgentConversation,
        );
        // One hit followed by one miss.
        coordinator.update_session(&sid, true);
        coordinator.update_session(&sid, false);

        let info = coordinator.get_session(&sid).unwrap();
        assert_eq!(info.query_count, 2);
        assert_eq!(info.cache_hits, 1);
        assert!((info.cache_hit_rate() - 0.5).abs() < 0.01);
    }

    #[test]
    fn test_transaction_tracking() {
        let coordinator = new_coordinator();
        let sid = SessionId::new("session-tx");
        coordinator.track_session("session-tx", None, AIWorkloadContext::General);

        assert!(!coordinator.is_in_transaction(&sid));
        coordinator.begin_transaction(&sid);
        assert!(coordinator.is_in_transaction(&sid));
        coordinator.end_transaction(&sid);
        assert!(!coordinator.is_in_transaction(&sid));
    }

    #[test]
    fn test_cache_recommendation() {
        let coordinator = new_coordinator();

        let rag_rec = coordinator.get_cache_recommendation(&AIWorkloadDetection {
            workload_type: WorkloadType::RAG,
            ai_context: AIWorkloadContext::RAGRetrieval,
            confidence: 0.9,
            patterns: vec![],
        });
        assert_eq!(rag_rec.priority, CachePriority::High);
        assert_eq!(rag_rec.tier, RecommendedTier::L1);

        let tool_rec = coordinator.get_cache_recommendation(&AIWorkloadDetection {
            workload_type: WorkloadType::AIAgent,
            ai_context: AIWorkloadContext::ToolResult,
            confidence: 0.9,
            patterns: vec![],
        });
        assert_eq!(tool_rec.ttl, Duration::from_secs(86400));
    }

    #[test]
    fn test_stats() {
        let coordinator = new_coordinator();

        for sql in [
            "SELECT * FROM chunks",
            "SELECT conversation_history",
            "SELECT tool_result",
        ] {
            coordinator.detect_workload(sql, None);
        }

        assert_eq!(coordinator.stats().total_detections, 3);
    }

    #[test]
    fn test_invalidation() {
        let cache = Arc::new(SemanticQueryCache::new(0.9));
        cache.insert_with_context(
            "query1",
            vec![1.0, 0.0],
            serde_json::json!(1),
            Some(BranchContext::main()),
            None,
            AIWorkloadContext::General,
            vec!["users".to_string()],
        );
        let coordinator =
            AIIntegrationCoordinator::new(Arc::clone(&cache), AIIntegrationConfig::default());

        let removed = coordinator.invalidate_table("users");

        assert_eq!(removed, 1);
        assert_eq!(cache.len(), 0);
    }
}
690}