Skip to main content

heliosdb_proxy/distribcache/ai/
integration.rs

1//! Cross-feature AI/Agent integration module
2//!
3//! Integrates semantic caching with other HeliosProxy features:
4//! - TWR (Transaction Write Replay) session tracking
5//! - Sync mode / lag routing awareness
6//! - Workload scheduler coordination
7//! - Branch-aware time-travel queries
8
9use std::sync::atomic::{AtomicU64, Ordering};
10use std::sync::Arc;
11use std::time::{Duration, Instant};
12
13use dashmap::DashMap;
14
15use super::semantic::{AIWorkloadContext, BranchContext, BranchId, SemanticQueryCache};
16use crate::distribcache::classifier::WorkloadType;
17use crate::distribcache::scheduler::{ScheduleResult, ScheduledQuery, WorkloadScheduler};
18use crate::distribcache::SessionId;
19
20/// AI workload detection result
21#[derive(Debug, Clone)]
22pub struct AIWorkloadDetection {
23    /// Detected workload type
24    pub workload_type: WorkloadType,
25    /// AI-specific context
26    pub ai_context: AIWorkloadContext,
27    /// Confidence score (0.0 - 1.0)
28    pub confidence: f32,
29    /// Detected patterns
30    pub patterns: Vec<String>,
31}
32
33/// Session tracking for TWR integration
34#[derive(Debug, Clone)]
35pub struct SessionTrackingInfo {
36    /// Session identifier
37    pub session_id: SessionId,
38    /// Active branch context
39    pub branch: Option<BranchContext>,
40    /// Last query time
41    pub last_activity: Instant,
42    /// Transaction depth (0 = not in transaction)
43    pub transaction_depth: u32,
44    /// AI workload context
45    pub ai_context: AIWorkloadContext,
46    /// Total queries in session
47    pub query_count: u64,
48    /// Cache hit count
49    pub cache_hits: u64,
50}
51
52impl SessionTrackingInfo {
53    /// Create a new session tracking info
54    pub fn new(session_id: impl Into<String>) -> Self {
55        Self {
56            session_id: SessionId::new(session_id),
57            branch: None,
58            last_activity: Instant::now(),
59            transaction_depth: 0,
60            ai_context: AIWorkloadContext::General,
61            query_count: 0,
62            cache_hits: 0,
63        }
64    }
65
66    /// Set branch context
67    pub fn with_branch(mut self, branch: BranchContext) -> Self {
68        self.branch = Some(branch);
69        self
70    }
71
72    /// Set AI context
73    pub fn with_ai_context(mut self, context: AIWorkloadContext) -> Self {
74        self.ai_context = context;
75        self
76    }
77
78    /// Record a query
79    pub fn record_query(&mut self) {
80        self.query_count += 1;
81        self.last_activity = Instant::now();
82    }
83
84    /// Record a cache hit
85    pub fn record_cache_hit(&mut self) {
86        self.cache_hits += 1;
87    }
88
89    /// Get cache hit rate
90    pub fn cache_hit_rate(&self) -> f64 {
91        if self.query_count == 0 {
92            0.0
93        } else {
94            self.cache_hits as f64 / self.query_count as f64
95        }
96    }
97
98    /// Check if session is idle
99    pub fn is_idle(&self, timeout: Duration) -> bool {
100        self.last_activity.elapsed() > timeout
101    }
102}
103
104/// Cross-feature AI integration coordinator
105pub struct AIIntegrationCoordinator {
106    /// Semantic cache reference
107    semantic_cache: Arc<SemanticQueryCache>,
108
109    /// Session tracking
110    sessions: DashMap<SessionId, SessionTrackingInfo>,
111
112    /// Workload scheduler reference (optional)
113    scheduler: Option<Arc<WorkloadScheduler>>,
114
115    /// Configuration
116    config: AIIntegrationConfig,
117
118    /// Statistics
119    stats: AIIntegrationStats,
120}
121
/// Tunable settings for the AI integration coordinator.
#[derive(Debug, Clone)]
pub struct AIIntegrationConfig {
    /// Whether sessions are tracked for TWR integration.
    pub twr_tracking: bool,
    /// Inactivity period after which a tracked session counts as idle.
    pub session_idle_timeout: Duration,
    /// Upper bound on the number of tracked sessions.
    pub max_sessions: usize,
    /// Whether query workload detection is enabled.
    pub workload_detection: bool,
    /// Confidence threshold for RAG pattern detection.
    pub rag_detection_threshold: f32,
    /// Confidence threshold for agent conversation detection.
    pub agent_detection_threshold: f32,
}

impl Default for AIIntegrationConfig {
    /// Tracking and detection enabled, one-hour idle timeout, at most
    /// 10 000 sessions, RAG threshold 0.7 and agent threshold 0.8.
    fn default() -> Self {
        let one_hour = Duration::from_secs(60 * 60);
        Self {
            twr_tracking: true,
            workload_detection: true,
            session_idle_timeout: one_hour,
            max_sessions: 10_000,
            rag_detection_threshold: 0.7,
            agent_detection_threshold: 0.8,
        }
    }
}
151
/// Atomic counters maintained by the AI integration coordinator.
#[derive(Default, Debug)]
struct AIIntegrationStats {
    /// Workload detections performed.
    detections: AtomicU64,
    /// Detections that matched the RAG heuristic.
    rag_detected: AtomicU64,
    /// Detections that matched the agent-conversation heuristic.
    agent_detected: AtomicU64,
    /// Detections that matched the tool-result heuristic.
    tool_detected: AtomicU64,
    /// Sessions recorded through session tracking.
    sessions_tracked: AtomicU64,
    /// Cache hits reported through session updates.
    cross_feature_hits: AtomicU64,
}
168
169impl AIIntegrationCoordinator {
170    /// Create a new integration coordinator
171    pub fn new(semantic_cache: Arc<SemanticQueryCache>, config: AIIntegrationConfig) -> Self {
172        Self {
173            semantic_cache,
174            sessions: DashMap::new(),
175            scheduler: None,
176            config,
177            stats: AIIntegrationStats::default(),
178        }
179    }
180
181    /// Set the workload scheduler reference
182    pub fn with_scheduler(mut self, scheduler: Arc<WorkloadScheduler>) -> Self {
183        self.scheduler = Some(scheduler);
184        self
185    }
186
187    /// Detect AI workload from query patterns
188    pub fn detect_workload(&self, query: &str, session: Option<&SessionId>) -> AIWorkloadDetection {
189        self.stats.detections.fetch_add(1, Ordering::Relaxed);
190
191        let mut patterns = Vec::new();
192        let mut confidence = 0.0f32;
193        let mut ai_context = AIWorkloadContext::General;
194        let mut workload_type = WorkloadType::Mixed;
195
196        // Pattern matching for workload detection
197        let query_lower = query.to_lowercase();
198
199        // RAG retrieval patterns
200        if self.is_rag_pattern(&query_lower) {
201            patterns.push("RAG retrieval".to_string());
202            confidence = 0.85;
203            ai_context = AIWorkloadContext::RAGRetrieval;
204            workload_type = WorkloadType::RAG;
205            self.stats.rag_detected.fetch_add(1, Ordering::Relaxed);
206        }
207
208        // Agent conversation patterns
209        if self.is_agent_pattern(&query_lower, session) {
210            patterns.push("Agent conversation".to_string());
211            confidence = confidence.max(0.80);
212            ai_context = AIWorkloadContext::AgentConversation;
213            workload_type = WorkloadType::AIAgent;
214            self.stats.agent_detected.fetch_add(1, Ordering::Relaxed);
215        }
216
217        // Tool result patterns
218        if self.is_tool_pattern(&query_lower) {
219            patterns.push("Tool result".to_string());
220            confidence = confidence.max(0.90);
221            ai_context = AIWorkloadContext::ToolResult;
222            workload_type = WorkloadType::AIAgent;
223            self.stats.tool_detected.fetch_add(1, Ordering::Relaxed);
224        }
225
226        // Vector search patterns (only if not already classified as RAG)
227        if workload_type != WorkloadType::RAG {
228            if query_lower.contains("embedding") || query_lower.contains("vector") || query_lower.contains("similarity") {
229                patterns.push("Vector search".to_string());
230                confidence = confidence.max(0.75);
231                workload_type = WorkloadType::Vector;
232            }
233        }
234
235        // OLAP patterns
236        if self.is_olap_pattern(&query_lower) {
237            patterns.push("OLAP analytics".to_string());
238            confidence = confidence.max(0.70);
239            workload_type = WorkloadType::OLAP;
240        }
241
242        AIWorkloadDetection {
243            workload_type,
244            ai_context,
245            confidence,
246            patterns,
247        }
248    }
249
250    /// Check if query matches RAG retrieval patterns
251    fn is_rag_pattern(&self, query: &str) -> bool {
252        // RAG typically involves:
253        // - Semantic search / similarity queries
254        // - Chunk retrieval
255        // - Document lookups
256        let rag_patterns = [
257            "chunk", "retrieve", "context", "passage", "document",
258            "semantic", "similarity", "cosine", "embedding",
259        ];
260
261        rag_patterns.iter().any(|p| query.contains(p))
262    }
263
264    /// Check if query matches agent conversation patterns
265    fn is_agent_pattern(&self, query: &str, session: Option<&SessionId>) -> bool {
266        // Check session history for conversation patterns
267        if let Some(sid) = session {
268            if let Some(info) = self.sessions.get(sid) {
269                if info.ai_context == AIWorkloadContext::AgentConversation {
270                    return true;
271                }
272                // Conversation pattern: multiple sequential queries
273                if info.query_count > 5 && info.cache_hit_rate() > 0.3 {
274                    return true;
275                }
276            }
277        }
278
279        // Query pattern matching
280        let agent_patterns = ["conversation", "history", "context", "message", "response"];
281        agent_patterns.iter().any(|p| query.contains(p))
282    }
283
284    /// Check if query matches tool result patterns
285    fn is_tool_pattern(&self, query: &str) -> bool {
286        let tool_patterns = [
287            "tool_", "function_", "api_result", "calculate",
288            "format_", "convert_", "lookup_",
289        ];
290
291        tool_patterns.iter().any(|p| query.contains(p))
292    }
293
294    /// Check if query matches OLAP patterns
295    fn is_olap_pattern(&self, query: &str) -> bool {
296        let olap_patterns = [
297            "group by", "having", "aggregate", "sum(", "count(",
298            "avg(", "window", "partition by", "rollup", "cube",
299        ];
300
301        olap_patterns.iter().any(|p| query.contains(p))
302    }
303
304    /// Track session for TWR integration
305    pub fn track_session(
306        &self,
307        session_id: impl Into<String>,
308        branch: Option<BranchContext>,
309        ai_context: AIWorkloadContext,
310    ) {
311        let sid = SessionId::new(session_id);
312
313        // Limit session count
314        if self.sessions.len() >= self.config.max_sessions {
315            self.cleanup_idle_sessions();
316        }
317
318        let mut info = SessionTrackingInfo::new(sid.0.clone())
319            .with_ai_context(ai_context);
320
321        if let Some(b) = branch {
322            info = info.with_branch(b);
323        }
324
325        self.sessions.insert(sid, info);
326        self.stats.sessions_tracked.fetch_add(1, Ordering::Relaxed);
327    }
328
329    /// Update session activity
330    pub fn update_session(&self, session_id: &SessionId, cache_hit: bool) {
331        if let Some(mut info) = self.sessions.get_mut(session_id) {
332            info.record_query();
333            if cache_hit {
334                info.record_cache_hit();
335                self.stats.cross_feature_hits.fetch_add(1, Ordering::Relaxed);
336            }
337        }
338    }
339
340    /// Get session info
341    pub fn get_session(&self, session_id: &SessionId) -> Option<SessionTrackingInfo> {
342        self.sessions.get(session_id).map(|r| r.clone())
343    }
344
345    /// Begin transaction for session
346    pub fn begin_transaction(&self, session_id: &SessionId) {
347        if let Some(mut info) = self.sessions.get_mut(session_id) {
348            info.transaction_depth += 1;
349        }
350    }
351
352    /// End transaction for session
353    pub fn end_transaction(&self, session_id: &SessionId) {
354        if let Some(mut info) = self.sessions.get_mut(session_id) {
355            if info.transaction_depth > 0 {
356                info.transaction_depth -= 1;
357            }
358        }
359    }
360
361    /// Check if session is in transaction
362    pub fn is_in_transaction(&self, session_id: &SessionId) -> bool {
363        self.sessions
364            .get(session_id)
365            .map(|info| info.transaction_depth > 0)
366            .unwrap_or(false)
367    }
368
369    /// Get recommended cache behavior based on workload
370    pub fn get_cache_recommendation(&self, detection: &AIWorkloadDetection) -> CacheRecommendation {
371        match detection.ai_context {
372            AIWorkloadContext::RAGRetrieval => CacheRecommendation {
373                should_cache: true,
374                ttl: Duration::from_secs(300),
375                priority: CachePriority::High,
376                tier: RecommendedTier::L1,
377            },
378            AIWorkloadContext::RAGGeneration => CacheRecommendation {
379                should_cache: true,
380                ttl: Duration::from_secs(1800),
381                priority: CachePriority::Medium,
382                tier: RecommendedTier::L2,
383            },
384            AIWorkloadContext::AgentConversation => CacheRecommendation {
385                should_cache: true,
386                ttl: Duration::from_secs(3600),
387                priority: CachePriority::High,
388                tier: RecommendedTier::L1,
389            },
390            AIWorkloadContext::ToolResult => CacheRecommendation {
391                should_cache: true,
392                ttl: Duration::from_secs(86400),
393                priority: CachePriority::Low,
394                tier: RecommendedTier::L2,
395            },
396            AIWorkloadContext::General => {
397                match detection.workload_type {
398                    WorkloadType::OLTP => CacheRecommendation {
399                        should_cache: true,
400                        ttl: Duration::from_secs(60),
401                        priority: CachePriority::High,
402                        tier: RecommendedTier::L1,
403                    },
404                    WorkloadType::OLAP => CacheRecommendation {
405                        should_cache: true,
406                        ttl: Duration::from_secs(3600),
407                        priority: CachePriority::Low,
408                        tier: RecommendedTier::L3,
409                    },
410                    WorkloadType::Vector => CacheRecommendation {
411                        should_cache: true,
412                        ttl: Duration::from_secs(600),
413                        priority: CachePriority::Medium,
414                        tier: RecommendedTier::L2,
415                    },
416                    _ => CacheRecommendation::default(),
417                }
418            }
419        }
420    }
421
422    /// Schedule query with AI-aware priority
423    pub fn schedule_with_ai_priority(
424        &self,
425        query_id: u64,
426        detection: &AIWorkloadDetection,
427    ) -> Option<ScheduleResult> {
428        let scheduler = self.scheduler.as_ref()?;
429
430        let query = ScheduledQuery {
431            id: query_id,
432            workload_type: detection.workload_type,
433            timestamp: std::time::Instant::now(),
434        };
435
436        Some(scheduler.schedule(query))
437    }
438
439    /// Cleanup idle sessions
440    pub fn cleanup_idle_sessions(&self) {
441        let timeout = self.config.session_idle_timeout;
442        let to_remove: Vec<_> = self.sessions
443            .iter()
444            .filter(|e| e.is_idle(timeout))
445            .map(|e| e.key().clone())
446            .collect();
447
448        for sid in to_remove {
449            self.sessions.remove(&sid);
450        }
451    }
452
453    /// Invalidate cache entries for branch
454    pub fn invalidate_branch(&self, branch: &BranchId) -> usize {
455        self.semantic_cache.invalidate_branch(branch)
456    }
457
458    /// Invalidate cache entries for table
459    pub fn invalidate_table(&self, table: &str) -> usize {
460        self.semantic_cache.invalidate_by_table(table)
461    }
462
463    /// Get integration statistics
464    pub fn stats(&self) -> AIIntegrationStatsSnapshot {
465        AIIntegrationStatsSnapshot {
466            total_detections: self.stats.detections.load(Ordering::Relaxed),
467            rag_detected: self.stats.rag_detected.load(Ordering::Relaxed),
468            agent_detected: self.stats.agent_detected.load(Ordering::Relaxed),
469            tool_detected: self.stats.tool_detected.load(Ordering::Relaxed),
470            active_sessions: self.sessions.len(),
471            cross_feature_hits: self.stats.cross_feature_hits.load(Ordering::Relaxed),
472        }
473    }
474}
475
/// Caching policy suggested for a query, derived from its detected workload.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct CacheRecommendation {
    /// Whether the query result should be cached at all.
    pub should_cache: bool,
    /// Suggested time-to-live for the cached result.
    pub ttl: Duration,
    /// Relative importance of keeping the entry cached.
    pub priority: CachePriority,
    /// Cache tier the entry should be placed in.
    pub tier: RecommendedTier,
}

impl Default for CacheRecommendation {
    /// Cache for 300 seconds at `Medium` priority in tier `L1`.
    fn default() -> Self {
        Self {
            should_cache: true,
            ttl: Duration::from_secs(300),
            priority: CachePriority::Medium,
            tier: RecommendedTier::L1,
        }
    }
}

/// Cache priority levels.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum CachePriority {
    High,
    Medium,
    Low,
}

/// Recommended cache tier.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum RecommendedTier {
    L1,
    L2,
    L3,
}
515
/// Point-in-time copy of the AI integration counters.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct AIIntegrationStatsSnapshot {
    /// Total workload detections performed.
    pub total_detections: u64,
    /// Detections that matched the RAG heuristic.
    pub rag_detected: u64,
    /// Detections that matched the agent-conversation heuristic.
    pub agent_detected: u64,
    /// Detections that matched the tool-result heuristic.
    pub tool_detected: u64,
    /// Number of sessions currently tracked.
    pub active_sessions: usize,
    /// Cache hits reported through session updates.
    pub cross_feature_hits: u64,
}
526
#[cfg(test)]
mod tests {
    use super::*;

    /// Coordinator over a fresh semantic cache (threshold 0.9) with default settings.
    fn make_coordinator() -> AIIntegrationCoordinator {
        let cache = Arc::new(SemanticQueryCache::new(0.9));
        AIIntegrationCoordinator::new(cache, AIIntegrationConfig::default())
    }

    #[test]
    fn test_workload_detection_rag() {
        let coordinator = make_coordinator();

        let detection = coordinator.detect_workload(
            "SELECT * FROM chunks WHERE document_id = 1 AND similarity > 0.8",
            None,
        );

        assert_eq!(detection.workload_type, WorkloadType::RAG);
        assert_eq!(detection.ai_context, AIWorkloadContext::RAGRetrieval);
        assert!(detection.confidence > 0.7);
    }

    #[test]
    fn test_workload_detection_agent() {
        let coordinator = make_coordinator();

        let detection = coordinator.detect_workload(
            "SELECT * FROM conversation_history WHERE session_id = 'abc'",
            None,
        );

        assert_eq!(detection.ai_context, AIWorkloadContext::AgentConversation);
        let expected_pattern = "Agent conversation".to_string();
        assert!(detection.patterns.contains(&expected_pattern));
    }

    #[test]
    fn test_workload_detection_tool() {
        let coordinator = make_coordinator();

        let detection = coordinator.detect_workload(
            "SELECT tool_calculate_result FROM api_result WHERE id = 1",
            None,
        );

        assert_eq!(detection.ai_context, AIWorkloadContext::ToolResult);
    }

    #[test]
    fn test_workload_detection_olap() {
        let coordinator = make_coordinator();

        let detection = coordinator.detect_workload(
            "SELECT category, SUM(amount) FROM orders GROUP BY category HAVING COUNT(*) > 10",
            None,
        );

        assert_eq!(detection.workload_type, WorkloadType::OLAP);
    }

    #[test]
    fn test_session_tracking() {
        let coordinator = make_coordinator();
        let sid = SessionId::new("session-1");

        coordinator.track_session(
            "session-1",
            Some(BranchContext::main()),
            AIWorkloadContext::AgentConversation,
        );

        // One hit followed by one miss.
        for hit in [true, false] {
            coordinator.update_session(&sid, hit);
        }

        let info = coordinator.get_session(&sid).unwrap();
        assert_eq!(info.query_count, 2);
        assert_eq!(info.cache_hits, 1);
        assert!((info.cache_hit_rate() - 0.5).abs() < 0.01);
    }

    #[test]
    fn test_transaction_tracking() {
        let coordinator = make_coordinator();
        let sid = SessionId::new("session-tx");
        coordinator.track_session("session-tx", None, AIWorkloadContext::General);

        assert!(!coordinator.is_in_transaction(&sid));

        coordinator.begin_transaction(&sid);
        assert!(coordinator.is_in_transaction(&sid));

        coordinator.end_transaction(&sid);
        assert!(!coordinator.is_in_transaction(&sid));
    }

    #[test]
    fn test_cache_recommendation() {
        let coordinator = make_coordinator();

        // RAG retrieval should have short TTL and high priority
        let rag_rec = coordinator.get_cache_recommendation(&AIWorkloadDetection {
            workload_type: WorkloadType::RAG,
            ai_context: AIWorkloadContext::RAGRetrieval,
            confidence: 0.9,
            patterns: vec![],
        });
        assert_eq!(rag_rec.priority, CachePriority::High);
        assert_eq!(rag_rec.tier, RecommendedTier::L1);

        // Tool result should have long TTL
        let tool_rec = coordinator.get_cache_recommendation(&AIWorkloadDetection {
            workload_type: WorkloadType::AIAgent,
            ai_context: AIWorkloadContext::ToolResult,
            confidence: 0.9,
            patterns: vec![],
        });
        assert_eq!(tool_rec.ttl, Duration::from_secs(86400));
    }

    #[test]
    fn test_stats() {
        let coordinator = make_coordinator();

        for query in [
            "SELECT * FROM chunks",
            "SELECT conversation_history",
            "SELECT tool_result",
        ] {
            coordinator.detect_workload(query, None);
        }

        assert_eq!(coordinator.stats().total_detections, 3);
    }

    #[test]
    fn test_invalidation() {
        let cache = Arc::new(SemanticQueryCache::new(0.9));

        // Seed one entry tagged with the "users" table.
        cache.insert_with_context(
            "query1",
            vec![1.0, 0.0],
            serde_json::json!(1),
            Some(BranchContext::main()),
            None,
            AIWorkloadContext::General,
            vec!["users".to_string()],
        );

        let coordinator =
            AIIntegrationCoordinator::new(Arc::clone(&cache), AIIntegrationConfig::default());

        let removed = coordinator.invalidate_table("users");
        assert_eq!(removed, 1);
        assert_eq!(cache.len(), 0);
    }
}