Skip to main content

heliosdb_proxy/distribcache/ai/
integration.rs

1//! Cross-feature AI/Agent integration module
2//!
3//! Integrates semantic caching with other HeliosProxy features:
4//! - TWR (Transaction Write Replay) session tracking
5//! - Sync mode / lag routing awareness
6//! - Workload scheduler coordination
7//! - Branch-aware time-travel queries
8
9use std::sync::atomic::{AtomicU64, Ordering};
10use std::sync::Arc;
11use std::time::{Duration, Instant};
12
13use dashmap::DashMap;
14
15use super::semantic::{AIWorkloadContext, BranchContext, BranchId, SemanticQueryCache};
16use crate::distribcache::classifier::WorkloadType;
17use crate::distribcache::scheduler::{ScheduleResult, ScheduledQuery, WorkloadScheduler};
18use crate::distribcache::SessionId;
19
/// Result of classifying one query with
/// [`AIIntegrationCoordinator::detect_workload`].
#[derive(Debug, Clone)]
pub struct AIWorkloadDetection {
    /// Detected workload class (`WorkloadType::Mixed` when no pattern matched)
    pub workload_type: WorkloadType,
    /// AI-specific context (`AIWorkloadContext::General` unless a RAG, agent or
    /// tool pattern matched)
    pub ai_context: AIWorkloadContext,
    /// Confidence score (0.0 - 1.0); the detector in this module reports the
    /// highest fixed score among the matched patterns, or 0.0 if none matched
    pub confidence: f32,
    /// Labels of every pattern that matched, in detection order
    pub patterns: Vec<String>,
}
32
/// Per-session state kept by [`AIIntegrationCoordinator`] for TWR integration.
#[derive(Debug, Clone)]
pub struct SessionTrackingInfo {
    /// Session identifier
    pub session_id: SessionId,
    /// Active branch context, if one was supplied when tracking started
    pub branch: Option<BranchContext>,
    /// Time of creation or of the most recent `record_query` call; `is_idle`
    /// measures from this instant
    pub last_activity: Instant,
    /// Transaction nesting depth (0 = not in transaction)
    pub transaction_depth: u32,
    /// AI workload context recorded for this session
    pub ai_context: AIWorkloadContext,
    /// Number of queries recorded via `record_query`
    pub query_count: u64,
    /// Number of cache hits recorded via `record_cache_hit`
    pub cache_hits: u64,
}
51
52impl SessionTrackingInfo {
53    /// Create a new session tracking info
54    pub fn new(session_id: impl Into<String>) -> Self {
55        Self {
56            session_id: SessionId::new(session_id),
57            branch: None,
58            last_activity: Instant::now(),
59            transaction_depth: 0,
60            ai_context: AIWorkloadContext::General,
61            query_count: 0,
62            cache_hits: 0,
63        }
64    }
65
66    /// Set branch context
67    pub fn with_branch(mut self, branch: BranchContext) -> Self {
68        self.branch = Some(branch);
69        self
70    }
71
72    /// Set AI context
73    pub fn with_ai_context(mut self, context: AIWorkloadContext) -> Self {
74        self.ai_context = context;
75        self
76    }
77
78    /// Record a query
79    pub fn record_query(&mut self) {
80        self.query_count += 1;
81        self.last_activity = Instant::now();
82    }
83
84    /// Record a cache hit
85    pub fn record_cache_hit(&mut self) {
86        self.cache_hits += 1;
87    }
88
89    /// Get cache hit rate
90    pub fn cache_hit_rate(&self) -> f64 {
91        if self.query_count == 0 {
92            0.0
93        } else {
94            self.cache_hits as f64 / self.query_count as f64
95        }
96    }
97
98    /// Check if session is idle
99    pub fn is_idle(&self, timeout: Duration) -> bool {
100        self.last_activity.elapsed() > timeout
101    }
102}
103
/// Cross-feature AI integration coordinator.
///
/// Combines query-pattern workload detection, per-session tracking, cache
/// recommendations, optional scheduler hand-off, and invalidation passthrough
/// to the semantic cache.
pub struct AIIntegrationCoordinator {
    /// Semantic cache that invalidation requests are forwarded to
    semantic_cache: Arc<SemanticQueryCache>,

    /// Tracked sessions, keyed by session id
    sessions: DashMap<SessionId, SessionTrackingInfo>,

    /// Workload scheduler used by `schedule_with_ai_priority` (optional)
    scheduler: Option<Arc<WorkloadScheduler>>,

    /// Configuration
    config: AIIntegrationConfig,

    /// Cumulative counters; read via `stats()`
    stats: AIIntegrationStats,
}
121
/// Configuration for AI integration.
///
/// NOTE(review): in this module only `session_idle_timeout` and `max_sessions`
/// are read (by `AIIntegrationCoordinator`). `twr_tracking`,
/// `workload_detection` and both detection thresholds are never read here;
/// confirm whether other code honors them.
#[derive(Debug, Clone)]
pub struct AIIntegrationConfig {
    /// Enable TWR session tracking
    pub twr_tracking: bool,
    /// How long a session may go without a recorded query before
    /// `cleanup_idle_sessions` evicts it
    pub session_idle_timeout: Duration,
    /// Session count at which `track_session` first evicts idle sessions.
    /// This is a soft cap: the new session is inserted even if none were idle.
    pub max_sessions: usize,
    /// Enable workload detection
    pub workload_detection: bool,
    /// RAG pattern detection threshold
    pub rag_detection_threshold: f32,
    /// Agent conversation detection threshold
    pub agent_detection_threshold: f32,
}
138
139impl Default for AIIntegrationConfig {
140    fn default() -> Self {
141        Self {
142            twr_tracking: true,
143            session_idle_timeout: Duration::from_secs(3600), // 1 hour
144            max_sessions: 10000,
145            workload_detection: true,
146            rag_detection_threshold: 0.7,
147            agent_detection_threshold: 0.8,
148        }
149    }
150}
151
/// Cumulative counters for AI integration (updated with relaxed atomics).
#[derive(Debug, Default)]
struct AIIntegrationStats {
    /// Total `detect_workload` calls
    detections: AtomicU64,
    /// Detections where a RAG pattern matched
    rag_detected: AtomicU64,
    /// Detections where an agent pattern matched
    agent_detected: AtomicU64,
    /// Detections where a tool pattern matched
    tool_detected: AtomicU64,
    /// `track_session` calls, re-tracks included.
    /// NOTE(review): not included in `AIIntegrationStatsSnapshot`.
    sessions_tracked: AtomicU64,
    /// Cache hits reported through `update_session` for tracked sessions
    cross_feature_hits: AtomicU64,
}
168
169impl AIIntegrationCoordinator {
170    /// Create a new integration coordinator
171    pub fn new(semantic_cache: Arc<SemanticQueryCache>, config: AIIntegrationConfig) -> Self {
172        Self {
173            semantic_cache,
174            sessions: DashMap::new(),
175            scheduler: None,
176            config,
177            stats: AIIntegrationStats::default(),
178        }
179    }
180
181    /// Set the workload scheduler reference
182    pub fn with_scheduler(mut self, scheduler: Arc<WorkloadScheduler>) -> Self {
183        self.scheduler = Some(scheduler);
184        self
185    }
186
187    /// Detect AI workload from query patterns
188    pub fn detect_workload(&self, query: &str, session: Option<&SessionId>) -> AIWorkloadDetection {
189        self.stats.detections.fetch_add(1, Ordering::Relaxed);
190
191        let mut patterns = Vec::new();
192        let mut confidence = 0.0f32;
193        let mut ai_context = AIWorkloadContext::General;
194        let mut workload_type = WorkloadType::Mixed;
195
196        // Pattern matching for workload detection
197        let query_lower = query.to_lowercase();
198
199        // RAG retrieval patterns
200        if self.is_rag_pattern(&query_lower) {
201            patterns.push("RAG retrieval".to_string());
202            confidence = 0.85;
203            ai_context = AIWorkloadContext::RAGRetrieval;
204            workload_type = WorkloadType::RAG;
205            self.stats.rag_detected.fetch_add(1, Ordering::Relaxed);
206        }
207
208        // Agent conversation patterns
209        if self.is_agent_pattern(&query_lower, session) {
210            patterns.push("Agent conversation".to_string());
211            confidence = confidence.max(0.80);
212            ai_context = AIWorkloadContext::AgentConversation;
213            workload_type = WorkloadType::AIAgent;
214            self.stats.agent_detected.fetch_add(1, Ordering::Relaxed);
215        }
216
217        // Tool result patterns
218        if self.is_tool_pattern(&query_lower) {
219            patterns.push("Tool result".to_string());
220            confidence = confidence.max(0.90);
221            ai_context = AIWorkloadContext::ToolResult;
222            workload_type = WorkloadType::AIAgent;
223            self.stats.tool_detected.fetch_add(1, Ordering::Relaxed);
224        }
225
226        // Vector search patterns (only if not already classified as RAG)
227        if workload_type != WorkloadType::RAG
228            && (query_lower.contains("embedding")
229                || query_lower.contains("vector")
230                || query_lower.contains("similarity"))
231        {
232            patterns.push("Vector search".to_string());
233            confidence = confidence.max(0.75);
234            workload_type = WorkloadType::Vector;
235        }
236
237        // OLAP patterns
238        if self.is_olap_pattern(&query_lower) {
239            patterns.push("OLAP analytics".to_string());
240            confidence = confidence.max(0.70);
241            workload_type = WorkloadType::OLAP;
242        }
243
244        AIWorkloadDetection {
245            workload_type,
246            ai_context,
247            confidence,
248            patterns,
249        }
250    }
251
252    /// Check if query matches RAG retrieval patterns
253    fn is_rag_pattern(&self, query: &str) -> bool {
254        // RAG typically involves:
255        // - Semantic search / similarity queries
256        // - Chunk retrieval
257        // - Document lookups
258        let rag_patterns = [
259            "chunk",
260            "retrieve",
261            "context",
262            "passage",
263            "document",
264            "semantic",
265            "similarity",
266            "cosine",
267            "embedding",
268        ];
269
270        rag_patterns.iter().any(|p| query.contains(p))
271    }
272
273    /// Check if query matches agent conversation patterns
274    fn is_agent_pattern(&self, query: &str, session: Option<&SessionId>) -> bool {
275        // Check session history for conversation patterns
276        if let Some(sid) = session {
277            if let Some(info) = self.sessions.get(sid) {
278                if info.ai_context == AIWorkloadContext::AgentConversation {
279                    return true;
280                }
281                // Conversation pattern: multiple sequential queries
282                if info.query_count > 5 && info.cache_hit_rate() > 0.3 {
283                    return true;
284                }
285            }
286        }
287
288        // Query pattern matching
289        let agent_patterns = ["conversation", "history", "context", "message", "response"];
290        agent_patterns.iter().any(|p| query.contains(p))
291    }
292
293    /// Check if query matches tool result patterns
294    fn is_tool_pattern(&self, query: &str) -> bool {
295        let tool_patterns = [
296            "tool_",
297            "function_",
298            "api_result",
299            "calculate",
300            "format_",
301            "convert_",
302            "lookup_",
303        ];
304
305        tool_patterns.iter().any(|p| query.contains(p))
306    }
307
308    /// Check if query matches OLAP patterns
309    fn is_olap_pattern(&self, query: &str) -> bool {
310        let olap_patterns = [
311            "group by",
312            "having",
313            "aggregate",
314            "sum(",
315            "count(",
316            "avg(",
317            "window",
318            "partition by",
319            "rollup",
320            "cube",
321        ];
322
323        olap_patterns.iter().any(|p| query.contains(p))
324    }
325
326    /// Track session for TWR integration
327    pub fn track_session(
328        &self,
329        session_id: impl Into<String>,
330        branch: Option<BranchContext>,
331        ai_context: AIWorkloadContext,
332    ) {
333        let sid = SessionId::new(session_id);
334
335        // Limit session count
336        if self.sessions.len() >= self.config.max_sessions {
337            self.cleanup_idle_sessions();
338        }
339
340        let mut info = SessionTrackingInfo::new(sid.0.clone()).with_ai_context(ai_context);
341
342        if let Some(b) = branch {
343            info = info.with_branch(b);
344        }
345
346        self.sessions.insert(sid, info);
347        self.stats.sessions_tracked.fetch_add(1, Ordering::Relaxed);
348    }
349
350    /// Update session activity
351    pub fn update_session(&self, session_id: &SessionId, cache_hit: bool) {
352        if let Some(mut info) = self.sessions.get_mut(session_id) {
353            info.record_query();
354            if cache_hit {
355                info.record_cache_hit();
356                self.stats
357                    .cross_feature_hits
358                    .fetch_add(1, Ordering::Relaxed);
359            }
360        }
361    }
362
363    /// Get session info
364    pub fn get_session(&self, session_id: &SessionId) -> Option<SessionTrackingInfo> {
365        self.sessions.get(session_id).map(|r| r.clone())
366    }
367
368    /// Begin transaction for session
369    pub fn begin_transaction(&self, session_id: &SessionId) {
370        if let Some(mut info) = self.sessions.get_mut(session_id) {
371            info.transaction_depth += 1;
372        }
373    }
374
375    /// End transaction for session
376    pub fn end_transaction(&self, session_id: &SessionId) {
377        if let Some(mut info) = self.sessions.get_mut(session_id) {
378            if info.transaction_depth > 0 {
379                info.transaction_depth -= 1;
380            }
381        }
382    }
383
384    /// Check if session is in transaction
385    pub fn is_in_transaction(&self, session_id: &SessionId) -> bool {
386        self.sessions
387            .get(session_id)
388            .map(|info| info.transaction_depth > 0)
389            .unwrap_or(false)
390    }
391
392    /// Get recommended cache behavior based on workload
393    pub fn get_cache_recommendation(&self, detection: &AIWorkloadDetection) -> CacheRecommendation {
394        match detection.ai_context {
395            AIWorkloadContext::RAGRetrieval => CacheRecommendation {
396                should_cache: true,
397                ttl: Duration::from_secs(300),
398                priority: CachePriority::High,
399                tier: RecommendedTier::L1,
400            },
401            AIWorkloadContext::RAGGeneration => CacheRecommendation {
402                should_cache: true,
403                ttl: Duration::from_secs(1800),
404                priority: CachePriority::Medium,
405                tier: RecommendedTier::L2,
406            },
407            AIWorkloadContext::AgentConversation => CacheRecommendation {
408                should_cache: true,
409                ttl: Duration::from_secs(3600),
410                priority: CachePriority::High,
411                tier: RecommendedTier::L1,
412            },
413            AIWorkloadContext::ToolResult => CacheRecommendation {
414                should_cache: true,
415                ttl: Duration::from_secs(86400),
416                priority: CachePriority::Low,
417                tier: RecommendedTier::L2,
418            },
419            AIWorkloadContext::General => match detection.workload_type {
420                WorkloadType::OLTP => CacheRecommendation {
421                    should_cache: true,
422                    ttl: Duration::from_secs(60),
423                    priority: CachePriority::High,
424                    tier: RecommendedTier::L1,
425                },
426                WorkloadType::OLAP => CacheRecommendation {
427                    should_cache: true,
428                    ttl: Duration::from_secs(3600),
429                    priority: CachePriority::Low,
430                    tier: RecommendedTier::L3,
431                },
432                WorkloadType::Vector => CacheRecommendation {
433                    should_cache: true,
434                    ttl: Duration::from_secs(600),
435                    priority: CachePriority::Medium,
436                    tier: RecommendedTier::L2,
437                },
438                _ => CacheRecommendation::default(),
439            },
440        }
441    }
442
443    /// Schedule query with AI-aware priority
444    pub fn schedule_with_ai_priority(
445        &self,
446        query_id: u64,
447        detection: &AIWorkloadDetection,
448    ) -> Option<ScheduleResult> {
449        let scheduler = self.scheduler.as_ref()?;
450
451        let query = ScheduledQuery {
452            id: query_id,
453            workload_type: detection.workload_type,
454            timestamp: std::time::Instant::now(),
455        };
456
457        Some(scheduler.schedule(query))
458    }
459
460    /// Cleanup idle sessions
461    pub fn cleanup_idle_sessions(&self) {
462        let timeout = self.config.session_idle_timeout;
463        let to_remove: Vec<_> = self
464            .sessions
465            .iter()
466            .filter(|e| e.is_idle(timeout))
467            .map(|e| e.key().clone())
468            .collect();
469
470        for sid in to_remove {
471            self.sessions.remove(&sid);
472        }
473    }
474
475    /// Invalidate cache entries for branch
476    pub fn invalidate_branch(&self, branch: &BranchId) -> usize {
477        self.semantic_cache.invalidate_branch(branch)
478    }
479
480    /// Invalidate cache entries for table
481    pub fn invalidate_table(&self, table: &str) -> usize {
482        self.semantic_cache.invalidate_by_table(table)
483    }
484
485    /// Get integration statistics
486    pub fn stats(&self) -> AIIntegrationStatsSnapshot {
487        AIIntegrationStatsSnapshot {
488            total_detections: self.stats.detections.load(Ordering::Relaxed),
489            rag_detected: self.stats.rag_detected.load(Ordering::Relaxed),
490            agent_detected: self.stats.agent_detected.load(Ordering::Relaxed),
491            tool_detected: self.stats.tool_detected.load(Ordering::Relaxed),
492            active_sessions: self.sessions.len(),
493            cross_feature_hits: self.stats.cross_feature_hits.load(Ordering::Relaxed),
494        }
495    }
496}
497
/// Caching advice produced by
/// [`AIIntegrationCoordinator::get_cache_recommendation`].
#[derive(Debug, Clone)]
pub struct CacheRecommendation {
    /// Whether to cache this query
    pub should_cache: bool,
    /// Recommended time-to-live for the cached result
    pub ttl: Duration,
    /// Cache priority
    pub priority: CachePriority,
    /// Recommended cache tier
    pub tier: RecommendedTier,
}
510
511impl Default for CacheRecommendation {
512    fn default() -> Self {
513        Self {
514            should_cache: true,
515            ttl: Duration::from_secs(300),
516            priority: CachePriority::Medium,
517            tier: RecommendedTier::L1,
518        }
519    }
520}
521
/// Cache priority levels, listed from highest to lowest.
///
/// Variant order is kept stable in case callers rely on the discriminants.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CachePriority {
    High,
    Medium,
    Low,
}
529
/// Recommended cache tier.
///
/// NOTE(review): presumably `L1` is the fastest/smallest tier and `L3` the
/// slowest/largest, mirroring the distributed cache's tiers; confirm against the
/// tier implementation.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RecommendedTier {
    L1,
    L2,
    L3,
}
537
/// Point-in-time copy of the coordinator's statistics, returned by
/// [`AIIntegrationCoordinator::stats`].
#[derive(Debug, Clone)]
pub struct AIIntegrationStatsSnapshot {
    /// Total `detect_workload` calls
    pub total_detections: u64,
    /// Detections where a RAG pattern matched
    pub rag_detected: u64,
    /// Detections where an agent pattern matched
    pub agent_detected: u64,
    /// Detections where a tool pattern matched
    pub tool_detected: u64,
    /// Number of sessions currently tracked (not a cumulative count)
    pub active_sessions: usize,
    /// Cache hits reported through `update_session` for tracked sessions
    pub cross_feature_hits: u64,
}
548
#[cfg(test)]
mod tests {
    use super::*;

    /// Coordinator over a fresh semantic cache with the default configuration.
    fn coordinator() -> AIIntegrationCoordinator {
        coordinator_with(AIIntegrationConfig::default())
    }

    /// Coordinator over a fresh semantic cache with a custom configuration.
    fn coordinator_with(config: AIIntegrationConfig) -> AIIntegrationCoordinator {
        AIIntegrationCoordinator::new(Arc::new(SemanticQueryCache::new(0.9)), config)
    }

    #[test]
    fn test_workload_detection_rag() {
        let coordinator = coordinator();

        let detection = coordinator.detect_workload(
            "SELECT * FROM chunks WHERE document_id = 1 AND similarity > 0.8",
            None,
        );

        assert_eq!(detection.workload_type, WorkloadType::RAG);
        assert_eq!(detection.ai_context, AIWorkloadContext::RAGRetrieval);
        assert!(detection.confidence > 0.7);
    }

    #[test]
    fn test_workload_detection_agent() {
        let coordinator = coordinator();

        let detection = coordinator.detect_workload(
            "SELECT * FROM conversation_history WHERE session_id = 'abc'",
            None,
        );

        assert_eq!(detection.ai_context, AIWorkloadContext::AgentConversation);
        assert!(detection
            .patterns
            .contains(&"Agent conversation".to_string()));
    }

    #[test]
    fn test_workload_detection_tool() {
        let coordinator = coordinator();

        let detection = coordinator.detect_workload(
            "SELECT tool_calculate_result FROM api_result WHERE id = 1",
            None,
        );

        assert_eq!(detection.ai_context, AIWorkloadContext::ToolResult);
    }

    #[test]
    fn test_workload_detection_olap() {
        let coordinator = coordinator();

        let detection = coordinator.detect_workload(
            "SELECT category, SUM(amount) FROM orders GROUP BY category HAVING COUNT(*) > 10",
            None,
        );

        assert_eq!(detection.workload_type, WorkloadType::OLAP);
    }

    #[test]
    fn test_session_tracking() {
        let coordinator = coordinator();
        let sid = SessionId::new("session-1");

        coordinator.track_session(
            "session-1",
            Some(BranchContext::main()),
            AIWorkloadContext::AgentConversation,
        );

        coordinator.update_session(&sid, true);
        coordinator.update_session(&sid, false);

        let info = coordinator.get_session(&sid).unwrap();
        assert_eq!(info.query_count, 2);
        assert_eq!(info.cache_hits, 1);
        assert!((info.cache_hit_rate() - 0.5).abs() < 0.01);
    }

    #[test]
    fn test_transaction_tracking() {
        let coordinator = coordinator();
        let sid = SessionId::new("session-tx");
        coordinator.track_session("session-tx", None, AIWorkloadContext::General);

        assert!(!coordinator.is_in_transaction(&sid));

        coordinator.begin_transaction(&sid);
        assert!(coordinator.is_in_transaction(&sid));

        coordinator.end_transaction(&sid);
        assert!(!coordinator.is_in_transaction(&sid));

        // An unmatched end must not underflow the depth.
        coordinator.end_transaction(&sid);
        assert!(!coordinator.is_in_transaction(&sid));
    }

    #[test]
    fn test_cleanup_keeps_active_sessions() {
        let coordinator = coordinator();
        let sid = SessionId::new("active");
        coordinator.track_session("active", None, AIWorkloadContext::General);

        // The default one-hour timeout cannot have elapsed yet.
        coordinator.cleanup_idle_sessions();
        assert!(coordinator.get_session(&sid).is_some());
    }

    #[test]
    fn test_cleanup_removes_idle_sessions() {
        let coordinator = coordinator_with(AIIntegrationConfig {
            session_idle_timeout: Duration::from_secs(0),
            ..AIIntegrationConfig::default()
        });
        let sid = SessionId::new("idle");
        coordinator.track_session("idle", None, AIWorkloadContext::General);

        // Make sure some time has elapsed so a zero timeout counts as idle.
        std::thread::sleep(Duration::from_millis(5));
        coordinator.cleanup_idle_sessions();

        assert!(coordinator.get_session(&sid).is_none());
        assert_eq!(coordinator.stats().active_sessions, 0);
    }

    #[test]
    fn test_cache_recommendation() {
        let coordinator = coordinator();

        // RAG retrieval should have short TTL and high priority
        let rag_detection = AIWorkloadDetection {
            workload_type: WorkloadType::RAG,
            ai_context: AIWorkloadContext::RAGRetrieval,
            confidence: 0.9,
            patterns: vec![],
        };
        let rag_rec = coordinator.get_cache_recommendation(&rag_detection);
        assert_eq!(rag_rec.priority, CachePriority::High);
        assert_eq!(rag_rec.tier, RecommendedTier::L1);

        // Tool result should have long TTL
        let tool_detection = AIWorkloadDetection {
            workload_type: WorkloadType::AIAgent,
            ai_context: AIWorkloadContext::ToolResult,
            confidence: 0.9,
            patterns: vec![],
        };
        let tool_rec = coordinator.get_cache_recommendation(&tool_detection);
        assert_eq!(tool_rec.ttl, Duration::from_secs(86400));
    }

    #[test]
    fn test_stats() {
        let coordinator = coordinator();

        // One query per category: RAG ("chunk"), agent ("conversation"), tool ("tool_").
        coordinator.detect_workload("SELECT * FROM chunks", None);
        coordinator.detect_workload("SELECT conversation_history", None);
        coordinator.detect_workload("SELECT tool_result", None);

        let stats = coordinator.stats();
        assert_eq!(stats.total_detections, 3);
        assert_eq!(stats.rag_detected, 1);
        assert_eq!(stats.agent_detected, 1);
        assert_eq!(stats.tool_detected, 1);
    }

    #[test]
    fn test_invalidation() {
        let cache = Arc::new(SemanticQueryCache::new(0.9));

        // Insert some entries
        cache.insert_with_context(
            "query1",
            vec![1.0, 0.0],
            serde_json::json!(1),
            Some(BranchContext::main()),
            None,
            AIWorkloadContext::General,
            vec!["users".to_string()],
        );

        let coordinator =
            AIIntegrationCoordinator::new(cache.clone(), AIIntegrationConfig::default());

        // Invalidate by table
        let removed = coordinator.invalidate_table("users");
        assert_eq!(removed, 1);
        assert_eq!(cache.len(), 0);
    }
}