Skip to main content

heliosdb_proxy/schema_routing/
router.rs

//! Schema-Aware Router
//!
//! Routes queries based on schema semantics and workload characteristics.

5use std::sync::Arc;
6
7use super::{
8    NodeInfo, SyncMode, SchemaRoutingConfig,
9    registry::{SchemaRegistry, NodeCapabilities, DataTemperature, WorkloadType, AccessPattern},
10    analyzer::{QueryAnalyzer, QueryAnalysis},
11};
12
13/// Schema-aware query router: owns the node list plus the analyzer and detectors consulted when routing
14#[derive(Debug)]
15pub struct SchemaAwareRouter {
16    /// Routing configuration; `route` returns default routing when `config.enabled` is false
17    config: SchemaRoutingConfig,
18    /// Shared schema registry, queried for shard and branch locations
19    schema: Arc<SchemaRegistry>,
20    /// Query analyzer, constructed in `new` from a clone of the registry handle
21    analyzer: QueryAnalyzer,
22    /// Candidate nodes, maintained through `add_node` / `remove_node` / `update_node`
23    nodes: Vec<NodeInfo>,
24    /// AI workload detector; `route` consults it before any other routing rule
25    ai_detector: AIWorkloadDetector,
26    /// RAG router that `route_rag` delegates to
27    rag_router: RAGRouter,
28}
29
30impl SchemaAwareRouter {
31    /// Create a new schema-aware router
32    pub fn new(config: SchemaRoutingConfig, schema: Arc<SchemaRegistry>) -> Self {
33        Self {
34            analyzer: QueryAnalyzer::new(schema.clone()),
35            schema,
36            config,
37            nodes: Vec::new(),
38            ai_detector: AIWorkloadDetector::new(),
39            rag_router: RAGRouter::new(),
40        }
41    }
42
43    /// Add a node to the router
44    pub fn add_node(&mut self, node: NodeInfo) {
45        self.nodes.push(node);
46    }
47
48    /// Remove a node from the router
49    pub fn remove_node(&mut self, node_id: &str) {
50        self.nodes.retain(|n| n.id != node_id);
51    }
52
53    /// Update node status
54    pub fn update_node(&mut self, node_id: &str, load: f64, latency_ms: u64) {
55        if let Some(node) = self.nodes.iter_mut().find(|n| n.id == node_id) {
56            node.current_load = load;
57            node.current_latency_ms = latency_ms;
58        }
59    }
60
61    /// Route a query
62    pub fn route(&self, query: &str) -> RoutingDecision {
63        if !self.config.enabled {
64            return RoutingDecision::default_routing();
65        }
66
67        let analysis = self.analyzer.analyze(query);
68
69        // 1. Check for AI workload patterns
70        if let Some(ai_workload) = self.ai_detector.detect(query) {
71            let preference = self.ai_detector.get_optimal_routing(ai_workload);
72            return self.apply_preference(preference, &analysis);
73        }
74
75        // 2. Determine required capabilities
76        let required_caps = self.get_required_capabilities(&analysis);
77
78        // 3. Filter eligible nodes
79        let eligible = self.filter_by_capabilities(&required_caps);
80
81        // 4. Check sharding
82        if let Some(shard_routing) = self.try_shard_routing(&analysis) {
83            return shard_routing;
84        }
85
86        // 5. Route based on workload type
87        match analysis.workload_type {
88            WorkloadType::OLTP => self.route_oltp(&eligible, &analysis),
89            WorkloadType::OLAP => self.route_olap(&eligible, &analysis),
90            WorkloadType::Vector => self.route_vector(&eligible, &analysis),
91            WorkloadType::HTAP | WorkloadType::Mixed => self.route_mixed(&eligible, &analysis),
92        }
93    }
94
95    /// Route with branch context
96    pub fn route_with_branch(&self, query: &str, branch: &str) -> RoutingDecision {
97        let analysis = self.analyzer.analyze(query);
98
99        // Get nodes that have the branch data
100        let branch_nodes = self.schema.get_branch_locations(branch);
101
102        // Filter by query requirements
103        let required_caps = self.get_required_capabilities(&analysis);
104        let eligible = self.filter_by_capabilities(&required_caps);
105
106        // Intersection with branch nodes
107        let available: Vec<_> = eligible
108            .iter()
109            .filter(|n| branch_nodes.contains(&n.id))
110            .cloned()
111            .collect();
112
113        if available.is_empty() {
114            // Branch not replicated to eligible nodes
115            return RoutingDecision {
116                target: RouteTarget::Primary,
117                reason: RoutingReason::BranchNotAvailable,
118                branch: Some(branch.to_string()),
119                ..Default::default()
120            };
121        }
122
123        self.select_best(&available, &analysis)
124    }
125
126    /// Route for time-travel query
127    pub fn route_time_travel(&self, query: &str, age_days: i64) -> RoutingDecision {
128        let analysis = self.analyzer.analyze(query);
129
130        // Recent data on hot nodes
131        if age_days < 7 {
132            return self.route_to_temperature_nodes(DataTemperature::Hot, &analysis);
133        }
134
135        // Older data on warm nodes
136        if age_days < 30 {
137            return self.route_to_temperature_nodes(DataTemperature::Warm, &analysis);
138        }
139
140        // Historical data on cold/archive nodes
141        self.route_to_temperature_nodes(DataTemperature::Cold, &analysis)
142    }
143
144    /// Route RAG query
145    pub fn route_rag(&self, stage: RAGStage, query: &str) -> RoutingDecision {
146        let analysis = self.analyzer.analyze(query);
147        self.rag_router.route_rag_query(stage, &analysis, &self.nodes)
148    }
149
150    /// Get required capabilities based on query analysis
151    fn get_required_capabilities(&self, analysis: &QueryAnalysis) -> NodeCapabilities {
152        let mut caps = NodeCapabilities::default();
153
154        // Vector queries need vector-capable nodes
155        if analysis.access_patterns.contains(&AccessPattern::VectorSearch) {
156            caps.vector_search = true;
157            caps.gpu_acceleration = true; // Prefer GPU nodes
158        }
159
160        // OLAP queries prefer columnar storage
161        if analysis.workload_type == WorkloadType::OLAP {
162            caps.columnar_storage = true;
163        }
164
165        // Hot tables need in-memory nodes
166        for table in &analysis.tables {
167            if let Some(schema) = &table.schema {
168                if schema.temperature == DataTemperature::Hot {
169                    caps.in_memory = true;
170                }
171            }
172        }
173
174        caps
175    }
176
177    /// Filter nodes by capabilities
178    fn filter_by_capabilities(&self, required: &NodeCapabilities) -> Vec<NodeInfo> {
179        self.nodes
180            .iter()
181            .filter(|n| n.capabilities.satisfies(required) || !required.has_requirements())
182            .cloned()
183            .collect()
184    }
185
186    /// Try to route to specific shard
187    fn try_shard_routing(&self, analysis: &QueryAnalysis) -> Option<RoutingDecision> {
188        for table in &analysis.tables {
189            if let Some(schema) = &table.schema {
190                if let Some(shard_key) = &schema.shard_key {
191                    if let Some(shard_value) = analysis.shard_keys.get(shard_key) {
192                        let value = match shard_value {
193                            super::analyzer::ShardKeyValue::Single(v) => v.clone(),
194                            super::analyzer::ShardKeyValue::Multiple(v) => {
195                                // Multiple values = scatter-gather
196                                return Some(RoutingDecision {
197                                    target: RouteTarget::ScatterGather,
198                                    shards: v.iter().filter_map(|val| {
199                                        self.schema.get_shard(shard_key, val)
200                                    }).collect(),
201                                    reason: RoutingReason::ShardKey,
202                                    ..Default::default()
203                                });
204                            }
205                        };
206
207                        if let Some(shard) = self.schema.get_shard(shard_key, &value) {
208                            return Some(RoutingDecision {
209                                target: RouteTarget::Shard(shard),
210                                reason: RoutingReason::ShardKey,
211                                ..Default::default()
212                            });
213                        }
214                    }
215                }
216            }
217        }
218        None
219    }
220
221    /// Route OLTP workload
222    fn route_oltp(&self, nodes: &[NodeInfo], analysis: &QueryAnalysis) -> RoutingDecision {
223        // Write queries must go to primary
224        if !analysis.is_read_only {
225            return RoutingDecision {
226                target: RouteTarget::Primary,
227                reason: RoutingReason::WriteQuery,
228                ..Default::default()
229            };
230        }
231
232        // OLTP: Low latency, prefer primary or sync standbys
233        let mut preferred: Vec<_> = nodes
234            .iter()
235            .filter(|n| n.sync_mode == SyncMode::Sync || n.is_primary)
236            .cloned()
237            .collect();
238
239        preferred.sort_by_key(|n| n.current_latency_ms);
240
241        if let Some(node) = preferred.first() {
242            RoutingDecision {
243                target: RouteTarget::Node(node.id.clone()),
244                reason: RoutingReason::LowLatency,
245                node_info: Some(node.clone()),
246                ..Default::default()
247            }
248        } else {
249            RoutingDecision::default_routing()
250        }
251    }
252
253    /// Route OLAP workload
254    fn route_olap(&self, nodes: &[NodeInfo], _analysis: &QueryAnalysis) -> RoutingDecision {
255        // OLAP: Throughput over latency, prefer async standbys with columnar storage
256        let mut preferred: Vec<_> = nodes
257            .iter()
258            .filter(|n| n.capabilities.columnar_storage)
259            .cloned()
260            .collect();
261
262        if preferred.is_empty() {
263            // Fall back to any async standby
264            preferred = nodes
265                .iter()
266                .filter(|n| n.sync_mode == SyncMode::Async)
267                .cloned()
268                .collect();
269        }
270
271        preferred.sort_by(|a, b| a.current_load.partial_cmp(&b.current_load).unwrap());
272
273        if let Some(node) = preferred.first() {
274            RoutingDecision {
275                target: RouteTarget::Node(node.id.clone()),
276                reason: RoutingReason::ColumnarStorage,
277                node_info: Some(node.clone()),
278                ..Default::default()
279            }
280        } else {
281            RoutingDecision::default_routing()
282        }
283    }
284
285    /// Route vector workload
286    fn route_vector(&self, nodes: &[NodeInfo], _analysis: &QueryAnalysis) -> RoutingDecision {
287        // Vector: Need vector-capable nodes, prefer GPU
288        let mut vector_nodes: Vec<_> = nodes
289            .iter()
290            .filter(|n| n.capabilities.vector_search)
291            .cloned()
292            .collect();
293
294        // Sort by: GPU first, then lower load
295        vector_nodes.sort_by(|a, b| {
296            b.capabilities.gpu_acceleration
297                .cmp(&a.capabilities.gpu_acceleration)
298                .then_with(|| a.current_load.partial_cmp(&b.current_load).unwrap())
299        });
300
301        if let Some(node) = vector_nodes.first() {
302            RoutingDecision {
303                target: RouteTarget::Node(node.id.clone()),
304                reason: RoutingReason::VectorCapable,
305                node_info: Some(node.clone()),
306                ..Default::default()
307            }
308        } else {
309            // No vector-capable nodes, fall back to primary
310            RoutingDecision {
311                target: RouteTarget::Primary,
312                reason: RoutingReason::NoVectorNodes,
313                ..Default::default()
314            }
315        }
316    }
317
318    /// Route mixed workload
319    fn route_mixed(&self, nodes: &[NodeInfo], analysis: &QueryAnalysis) -> RoutingDecision {
320        // Mixed: Balance between latency and throughput
321        if !analysis.is_read_only {
322            return RoutingDecision {
323                target: RouteTarget::Primary,
324                reason: RoutingReason::WriteQuery,
325                ..Default::default()
326            };
327        }
328
329        // Sort by weighted score: latency + load
330        let mut scored: Vec<_> = nodes
331            .iter()
332            .map(|n| {
333                let score = (n.current_latency_ms as f64) + (n.current_load * 100.0);
334                (n, score)
335            })
336            .collect();
337
338        scored.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
339
340        if let Some((node, _)) = scored.first() {
341            RoutingDecision {
342                target: RouteTarget::Node(node.id.clone()),
343                reason: RoutingReason::LowestScore,
344                node_info: Some((*node).clone()),
345                ..Default::default()
346            }
347        } else {
348            RoutingDecision::default_routing()
349        }
350    }
351
352    /// Route to nodes with specific temperature
353    fn route_to_temperature_nodes(&self, temp: DataTemperature, analysis: &QueryAnalysis) -> RoutingDecision {
354        // Find nodes that host tables with matching temperature
355        let matching_nodes: Vec<_> = self.nodes
356            .iter()
357            .filter(|n| {
358                match temp {
359                    DataTemperature::Hot => n.capabilities.in_memory,
360                    DataTemperature::Warm => !n.capabilities.in_memory && !self.is_cold_storage(n),
361                    DataTemperature::Cold | DataTemperature::Frozen => self.is_cold_storage(n),
362                }
363            })
364            .cloned()
365            .collect();
366
367        if matching_nodes.is_empty() {
368            return self.route_mixed(&self.nodes, analysis);
369        }
370
371        self.select_best(&matching_nodes, analysis)
372    }
373
374    /// Check if node is cold storage
375    fn is_cold_storage(&self, node: &NodeInfo) -> bool {
376        // Heuristic: cold storage nodes have no in-memory capability
377        // and are not primary/sync
378        !node.capabilities.in_memory && node.sync_mode == SyncMode::Async && !node.is_primary
379    }
380
381    /// Select best node from candidates
382    fn select_best(&self, nodes: &[NodeInfo], analysis: &QueryAnalysis) -> RoutingDecision {
383        if nodes.is_empty() {
384            return RoutingDecision::default_routing();
385        }
386
387        // Sort by latency for read queries, load for analytics
388        let mut sorted = nodes.to_vec();
389        if analysis.workload_type == WorkloadType::OLAP {
390            sorted.sort_by(|a, b| a.current_load.partial_cmp(&b.current_load).unwrap());
391        } else {
392            sorted.sort_by_key(|n| n.current_latency_ms);
393        }
394
395        let node = &sorted[0];
396        RoutingDecision {
397            target: RouteTarget::Node(node.id.clone()),
398            reason: RoutingReason::BestCandidate,
399            node_info: Some(node.clone()),
400            ..Default::default()
401        }
402    }
403
404    /// Apply routing preference
405    fn apply_preference(&self, preference: RoutingPreference, analysis: &QueryAnalysis) -> RoutingDecision {
406        match preference {
407            RoutingPreference::VectorNodes { prefer_gpu } => {
408                let nodes: Vec<_> = self.nodes
409                    .iter()
410                    .filter(|n| n.capabilities.vector_search)
411                    .filter(|n| !prefer_gpu || n.capabilities.gpu_acceleration)
412                    .cloned()
413                    .collect();
414                self.select_best(&nodes, analysis)
415            }
416            RoutingPreference::LowLatency { max_lag_ms } => {
417                let nodes: Vec<_> = self.nodes
418                    .iter()
419                    .filter(|n| n.current_latency_ms <= max_lag_ms)
420                    .cloned()
421                    .collect();
422                self.select_best(&nodes, analysis)
423            }
424            RoutingPreference::HighThroughput => {
425                let nodes: Vec<_> = self.nodes
426                    .iter()
427                    .filter(|n| n.sync_mode == SyncMode::Async)
428                    .cloned()
429                    .collect();
430                self.select_best(&nodes, analysis)
431            }
432            RoutingPreference::Primary => {
433                RoutingDecision {
434                    target: RouteTarget::Primary,
435                    reason: RoutingReason::AIWorkload,
436                    ..Default::default()
437                }
438            }
439        }
440    }
441}
442
443impl NodeCapabilities {
444    /// Check if there are any requirements
445    fn has_requirements(&self) -> bool {
446        self.vector_search || self.gpu_acceleration || self.columnar_storage
447            || self.in_memory || self.content_addressed
448    }
449}
450
451/// Routing decision
452#[derive(Debug, Clone, Default)]
453pub struct RoutingDecision {
454    /// Target for routing
455    pub target: RouteTarget,
456    /// Reason for decision
457    pub reason: RoutingReason,
458    /// Target shards (for scatter-gather)
459    pub shards: Vec<u32>,
460    /// Branch context
461    pub branch: Option<String>,
462    /// Selected node info
463    pub node_info: Option<NodeInfo>,
464}
465
466impl RoutingDecision {
467    /// Create a shard routing decision
468    pub fn shard(shard_id: u32) -> Self {
469        Self {
470            target: RouteTarget::Shard(shard_id),
471            reason: RoutingReason::ShardKey,
472            ..Default::default()
473        }
474    }
475
476    /// Create a single node routing decision
477    pub fn single(node: NodeInfo) -> Self {
478        Self {
479            target: RouteTarget::Node(node.id.clone()),
480            reason: RoutingReason::BestCandidate,
481            node_info: Some(node),
482            ..Default::default()
483        }
484    }
485
486    /// Create default routing (to primary)
487    pub fn default_routing() -> Self {
488        Self {
489            target: RouteTarget::Primary,
490            reason: RoutingReason::Default,
491            ..Default::default()
492        }
493    }
494
495    /// Check if routing to primary
496    pub fn is_primary(&self) -> bool {
497        matches!(self.target, RouteTarget::Primary)
498    }
499
500    /// Check if scatter-gather needed
501    pub fn is_scatter_gather(&self) -> bool {
502        matches!(self.target, RouteTarget::ScatterGather)
503    }
504}
505
/// Route target
///
/// Derives `PartialEq`, `Eq` and `Hash` so targets can be compared with
/// `==` / `assert_eq!` and used as map keys, not only inspected through
/// `matches!`.
#[derive(Debug, Clone, Default, PartialEq, Eq, Hash)]
pub enum RouteTarget {
    /// Route to primary
    #[default]
    Primary,
    /// Route to the node with this id
    Node(String),
    /// Route to the shard with this id
    Shard(u32),
    /// Scatter-gather across shards
    ScatterGather,
}
519
/// Routing reason
///
/// A plain field-less enum, so it derives `Copy`, `PartialEq`, `Eq` and
/// `Hash`: reasons can be compared with `==` / `assert_eq!` and used as
/// keys (for example when counting decisions per reason).
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub enum RoutingReason {
    /// Default routing
    #[default]
    Default,
    /// Write query must go to primary
    WriteQuery,
    /// Shard key present in query
    ShardKey,
    /// Lowest latency node
    LowLatency,
    /// Node with columnar storage
    ColumnarStorage,
    /// Vector-capable node
    VectorCapable,
    /// No vector-capable nodes available
    NoVectorNodes,
    /// Branch not available on eligible nodes
    BranchNotAvailable,
    /// Best candidate from scoring
    BestCandidate,
    /// Lowest combined score
    LowestScore,
    /// AI workload routing
    AIWorkload,
}
547
/// Routing preference for AI workloads
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum RoutingPreference {
    /// Prefer vector-capable nodes
    VectorNodes { prefer_gpu: bool },
    /// Prefer low-latency nodes
    LowLatency { max_lag_ms: u64 },
    /// Prefer high-throughput nodes
    HighThroughput,
    /// Must route to primary
    Primary,
}

/// AI workload type
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AIWorkloadType {
    /// Embedding retrieval (vector search)
    EmbeddingRetrieval,
    /// Context/conversation lookup
    ContextLookup,
    /// Knowledge base query
    KnowledgeBase,
    /// Tool execution (writes)
    ToolExecution,
}

/// AI workload detector
///
/// Classifies a query by case-insensitive substring match against a fixed
/// keyword list.
#[derive(Debug)]
pub struct AIWorkloadDetector {
    /// Patterns for detection, checked in order
    patterns: Vec<AIPattern>,
}

/// One detection rule: a query containing `keyword` (after upper-casing)
/// is classified as `workload_type`.
#[derive(Debug)]
struct AIPattern {
    keyword: String,
    workload_type: AIWorkloadType,
}

impl Default for AIWorkloadDetector {
    /// Same as [`AIWorkloadDetector::new`].
    ///
    /// A derived `Default` would produce an empty pattern list, i.e. a
    /// detector that never detects anything.
    fn default() -> Self {
        Self::new()
    }
}

impl AIWorkloadDetector {
    /// Create a detector loaded with the built-in keyword patterns.
    pub fn new() -> Self {
        // Order matters: `detect` returns the first pattern that matches.
        const KEYWORDS: [(&str, AIWorkloadType); 9] = [
            ("<->", AIWorkloadType::EmbeddingRetrieval),
            ("VECTOR", AIWorkloadType::EmbeddingRetrieval),
            ("EMBEDDING", AIWorkloadType::EmbeddingRetrieval),
            ("CONVERSATION", AIWorkloadType::ContextLookup),
            ("TURNS", AIWorkloadType::ContextLookup),
            ("DOCUMENTS", AIWorkloadType::KnowledgeBase),
            ("CHUNKS", AIWorkloadType::KnowledgeBase),
            ("TOOL_RESULTS", AIWorkloadType::ToolExecution),
            ("ACTIONS", AIWorkloadType::ToolExecution),
        ];

        let patterns = KEYWORDS
            .iter()
            .map(|&(keyword, workload_type)| AIPattern {
                keyword: keyword.to_string(),
                workload_type,
            })
            .collect();

        Self { patterns }
    }

    /// Detect the AI workload type of `query`.
    ///
    /// The query is upper-cased and tested for each keyword as a plain
    /// substring; the first matching pattern decides. Returns `None` when
    /// no keyword occurs.
    pub fn detect(&self, query: &str) -> Option<AIWorkloadType> {
        let upper = query.to_uppercase();

        self.patterns
            .iter()
            .find(|pattern| upper.contains(pattern.keyword.as_str()))
            .map(|pattern| pattern.workload_type)
    }

    /// Get the routing preference for an AI workload type.
    pub fn get_optimal_routing(&self, workload: AIWorkloadType) -> RoutingPreference {
        match workload {
            AIWorkloadType::EmbeddingRetrieval => RoutingPreference::VectorNodes { prefer_gpu: true },
            AIWorkloadType::ContextLookup => RoutingPreference::LowLatency { max_lag_ms: 100 },
            AIWorkloadType::KnowledgeBase => RoutingPreference::HighThroughput,
            AIWorkloadType::ToolExecution => RoutingPreference::Primary,
        }
    }
}
636
637/// RAG stage; `RAGRouter::route_rag_query` picks a node-selection rule per stage
638#[derive(Debug, Clone, Copy, PartialEq, Eq)]
639pub enum RAGStage {
640    /// Retrieval stage (vector search): routed to the first vector-search-capable node
641    Retrieval,
642    /// Fetch stage (document lookup): routed to the first node whose sync mode is `SyncMode::Async`
643    Fetch,
644    /// Rerank stage: routed to the lowest-latency node
645    Rerank,
646    /// Generation stage: default routing for non-read-only queries, otherwise the lowest-latency node
647    Generate,
648}
649
650/// RAG router
651#[derive(Debug, Default)]
652pub struct RAGRouter {}
653
654impl RAGRouter {
655    /// Create a new RAG router
656    pub fn new() -> Self {
657        Self {}
658    }
659
660    /// Route RAG query based on stage
661    pub fn route_rag_query(&self, stage: RAGStage, analysis: &QueryAnalysis, nodes: &[NodeInfo]) -> RoutingDecision {
662        match stage {
663            RAGStage::Retrieval => {
664                // Vector search on embeddings
665                let vector_nodes: Vec<_> = nodes
666                    .iter()
667                    .filter(|n| n.capabilities.vector_search)
668                    .cloned()
669                    .collect();
670
671                if let Some(node) = vector_nodes.first() {
672                    RoutingDecision::single(node.clone())
673                } else {
674                    RoutingDecision::default_routing()
675                }
676            }
677            RAGStage::Fetch => {
678                // Bulk fetch - high throughput
679                let throughput_nodes: Vec<_> = nodes
680                    .iter()
681                    .filter(|n| n.sync_mode == SyncMode::Async)
682                    .cloned()
683                    .collect();
684
685                if let Some(node) = throughput_nodes.first() {
686                    RoutingDecision::single(node.clone())
687                } else {
688                    RoutingDecision::default_routing()
689                }
690            }
691            RAGStage::Rerank => {
692                // Light computation - lowest latency
693                let mut sorted = nodes.to_vec();
694                sorted.sort_by_key(|n| n.current_latency_ms);
695
696                if let Some(node) = sorted.first() {
697                    RoutingDecision::single(node.clone())
698                } else {
699                    RoutingDecision::default_routing()
700                }
701            }
702            RAGStage::Generate => {
703                // May write to cache - check if write
704                if !analysis.is_read_only {
705                    RoutingDecision::default_routing()
706                } else {
707                    let mut sorted = nodes.to_vec();
708                    sorted.sort_by_key(|n| n.current_latency_ms);
709
710                    if let Some(node) = sorted.first() {
711                        RoutingDecision::single(node.clone())
712                    } else {
713                        RoutingDecision::default_routing()
714                    }
715                }
716            }
717        }
718    }
719}
720
#[cfg(test)]
mod tests {
    use super::*;
    use crate::schema_routing::registry::TableSchema;

    /// Router over a registry with an OLTP, a cold OLAP and a vector table,
    /// and four nodes: primary, sync standby, analytics standby, vector node.
    fn create_test_setup() -> SchemaAwareRouter {
        let registry = Arc::new(SchemaRegistry::new());

        let users = TableSchema::new("users")
            .with_workload(WorkloadType::OLTP)
            .with_access_pattern(AccessPattern::PointLookup)
            .with_primary_key(vec!["id".to_string()]);
        registry.register_table(users);

        let events = TableSchema::new("events")
            .with_workload(WorkloadType::OLAP)
            .with_temperature(DataTemperature::Cold);
        registry.register_table(events);

        let embeddings = TableSchema::new("embeddings").with_workload(WorkloadType::Vector);
        registry.register_table(embeddings);

        let mut router = SchemaAwareRouter::new(SchemaRoutingConfig::default(), registry);

        let primary = NodeInfo::new("primary", "primary").as_primary();
        let sync_standby = NodeInfo::new("standby-sync", "standby-sync").with_sync_mode(SyncMode::Sync);
        let analytics_standby = NodeInfo::new("standby-async", "standby-async")
            .with_sync_mode(SyncMode::Async)
            .with_capabilities(NodeCapabilities::analytics_node());
        let vector_node = NodeInfo::new("vector-node", "vector-node")
            .with_sync_mode(SyncMode::Async)
            .with_capabilities(NodeCapabilities::vector_node());

        for node in vec![primary, sync_standby, analytics_standby, vector_node] {
            router.add_node(node);
        }

        router
    }

    #[test]
    fn test_route_oltp_read() {
        let outcome = create_test_setup().route("SELECT * FROM users WHERE id = 1");

        let low_latency = matches!(outcome.reason, RoutingReason::LowLatency);
        assert!(!outcome.is_primary() || low_latency);
    }

    #[test]
    fn test_route_write_to_primary() {
        let outcome = create_test_setup().route("INSERT INTO users (name) VALUES ('test')");

        assert!(outcome.is_primary());
        assert!(matches!(outcome.reason, RoutingReason::WriteQuery));
    }

    #[test]
    fn test_route_vector_query() {
        let outcome = create_test_setup()
            .route("SELECT * FROM embeddings ORDER BY embedding <-> '[1,2,3]' LIMIT 10");

        let expected_reason = matches!(
            outcome.reason,
            RoutingReason::VectorCapable | RoutingReason::BestCandidate
        );
        assert!(expected_reason || outcome.is_primary());
    }

    #[test]
    fn test_route_olap_query() {
        let outcome = create_test_setup()
            .route("SELECT COUNT(*), SUM(amount) FROM events GROUP BY date");

        // Either a replica was chosen, or the primary for an accepted reason.
        let accepted_reason = matches!(
            outcome.reason,
            RoutingReason::ColumnarStorage | RoutingReason::Default
        );
        assert!(!outcome.is_primary() || accepted_reason);
    }

    #[test]
    fn test_ai_workload_detection() {
        let detector = AIWorkloadDetector::new();

        let cases = vec![
            (
                "SELECT * FROM embeddings ORDER BY vector <-> $1",
                AIWorkloadType::EmbeddingRetrieval,
            ),
            (
                "SELECT * FROM conversation WHERE session_id = $1",
                AIWorkloadType::ContextLookup,
            ),
            (
                "INSERT INTO tool_results (result) VALUES ($1)",
                AIWorkloadType::ToolExecution,
            ),
        ];

        for (query, expected) in cases {
            assert_eq!(detector.detect(query), Some(expected));
        }
    }

    #[test]
    fn test_rag_routing() {
        let router = create_test_setup();

        let retrieval = router.route_rag(RAGStage::Retrieval, "SELECT embedding FROM docs");
        let fetch = router.route_rag(RAGStage::Fetch, "SELECT content FROM docs WHERE id IN (1,2,3)");

        // Each stage must end with either a concrete node or the primary.
        assert!(retrieval.node_info.is_some() || retrieval.is_primary());
        assert!(fetch.node_info.is_some() || fetch.is_primary());
    }

    #[test]
    fn test_routing_decision_helpers() {
        let sharded = RoutingDecision::shard(3);
        assert!(matches!(sharded.target, RouteTarget::Shard(3)));

        let fallback = RoutingDecision::default_routing();
        assert!(fallback.is_primary());
    }
}