Skip to main content

heliosdb_proxy/schema_routing/
router.rs

1//! Schema-Aware Router
2//!
3//! Routes queries based on schema semantics and workload characteristics.
4
5use std::sync::Arc;
6
7use super::{
8    analyzer::{QueryAnalysis, QueryAnalyzer},
9    registry::{AccessPattern, DataTemperature, NodeCapabilities, SchemaRegistry, WorkloadType},
10    NodeInfo, SchemaRoutingConfig, SyncMode,
11};
12
13/// Schema-aware query router
14#[derive(Debug)]
15pub struct SchemaAwareRouter {
16    /// Configuration
17    config: SchemaRoutingConfig,
18    /// Schema registry
19    schema: Arc<SchemaRegistry>,
20    /// Query analyzer
21    analyzer: QueryAnalyzer,
22    /// Available nodes
23    nodes: Vec<NodeInfo>,
24    /// AI workload detector
25    ai_detector: AIWorkloadDetector,
26    /// RAG router
27    rag_router: RAGRouter,
28}
29
30impl SchemaAwareRouter {
31    /// Create a new schema-aware router
32    pub fn new(config: SchemaRoutingConfig, schema: Arc<SchemaRegistry>) -> Self {
33        Self {
34            analyzer: QueryAnalyzer::new(schema.clone()),
35            schema,
36            config,
37            nodes: Vec::new(),
38            ai_detector: AIWorkloadDetector::new(),
39            rag_router: RAGRouter::new(),
40        }
41    }
42
43    /// Add a node to the router
44    pub fn add_node(&mut self, node: NodeInfo) {
45        self.nodes.push(node);
46    }
47
48    /// Remove a node from the router
49    pub fn remove_node(&mut self, node_id: &str) {
50        self.nodes.retain(|n| n.id != node_id);
51    }
52
53    /// Update node status
54    pub fn update_node(&mut self, node_id: &str, load: f64, latency_ms: u64) {
55        if let Some(node) = self.nodes.iter_mut().find(|n| n.id == node_id) {
56            node.current_load = load;
57            node.current_latency_ms = latency_ms;
58        }
59    }
60
61    /// Route a query
62    pub fn route(&self, query: &str) -> RoutingDecision {
63        if !self.config.enabled {
64            return RoutingDecision::default_routing();
65        }
66
67        let analysis = self.analyzer.analyze(query);
68
69        // 1. Check for AI workload patterns
70        if let Some(ai_workload) = self.ai_detector.detect(query) {
71            let preference = self.ai_detector.get_optimal_routing(ai_workload);
72            return self.apply_preference(preference, &analysis);
73        }
74
75        // 2. Determine required capabilities
76        let required_caps = self.get_required_capabilities(&analysis);
77
78        // 3. Filter eligible nodes
79        let eligible = self.filter_by_capabilities(&required_caps);
80
81        // 4. Check sharding
82        if let Some(shard_routing) = self.try_shard_routing(&analysis) {
83            return shard_routing;
84        }
85
86        // 5. Route based on workload type
87        match analysis.workload_type {
88            WorkloadType::OLTP => self.route_oltp(&eligible, &analysis),
89            WorkloadType::OLAP => self.route_olap(&eligible, &analysis),
90            WorkloadType::Vector => self.route_vector(&eligible, &analysis),
91            WorkloadType::HTAP | WorkloadType::Mixed => self.route_mixed(&eligible, &analysis),
92        }
93    }
94
95    /// Route with branch context
96    pub fn route_with_branch(&self, query: &str, branch: &str) -> RoutingDecision {
97        let analysis = self.analyzer.analyze(query);
98
99        // Get nodes that have the branch data
100        let branch_nodes = self.schema.get_branch_locations(branch);
101
102        // Filter by query requirements
103        let required_caps = self.get_required_capabilities(&analysis);
104        let eligible = self.filter_by_capabilities(&required_caps);
105
106        // Intersection with branch nodes
107        let available: Vec<_> = eligible
108            .iter()
109            .filter(|n| branch_nodes.contains(&n.id))
110            .cloned()
111            .collect();
112
113        if available.is_empty() {
114            // Branch not replicated to eligible nodes
115            return RoutingDecision {
116                target: RouteTarget::Primary,
117                reason: RoutingReason::BranchNotAvailable,
118                branch: Some(branch.to_string()),
119                ..Default::default()
120            };
121        }
122
123        self.select_best(&available, &analysis)
124    }
125
126    /// Route for time-travel query
127    pub fn route_time_travel(&self, query: &str, age_days: i64) -> RoutingDecision {
128        let analysis = self.analyzer.analyze(query);
129
130        // Recent data on hot nodes
131        if age_days < 7 {
132            return self.route_to_temperature_nodes(DataTemperature::Hot, &analysis);
133        }
134
135        // Older data on warm nodes
136        if age_days < 30 {
137            return self.route_to_temperature_nodes(DataTemperature::Warm, &analysis);
138        }
139
140        // Historical data on cold/archive nodes
141        self.route_to_temperature_nodes(DataTemperature::Cold, &analysis)
142    }
143
144    /// Route RAG query
145    pub fn route_rag(&self, stage: RAGStage, query: &str) -> RoutingDecision {
146        let analysis = self.analyzer.analyze(query);
147        self.rag_router
148            .route_rag_query(stage, &analysis, &self.nodes)
149    }
150
151    /// Get required capabilities based on query analysis
152    fn get_required_capabilities(&self, analysis: &QueryAnalysis) -> NodeCapabilities {
153        let mut caps = NodeCapabilities::default();
154
155        // Vector queries need vector-capable nodes
156        if analysis
157            .access_patterns
158            .contains(&AccessPattern::VectorSearch)
159        {
160            caps.vector_search = true;
161            caps.gpu_acceleration = true; // Prefer GPU nodes
162        }
163
164        // OLAP queries prefer columnar storage
165        if analysis.workload_type == WorkloadType::OLAP {
166            caps.columnar_storage = true;
167        }
168
169        // Hot tables need in-memory nodes
170        for table in &analysis.tables {
171            if let Some(schema) = &table.schema {
172                if schema.temperature == DataTemperature::Hot {
173                    caps.in_memory = true;
174                }
175            }
176        }
177
178        caps
179    }
180
181    /// Filter nodes by capabilities
182    fn filter_by_capabilities(&self, required: &NodeCapabilities) -> Vec<NodeInfo> {
183        self.nodes
184            .iter()
185            .filter(|n| n.capabilities.satisfies(required) || !required.has_requirements())
186            .cloned()
187            .collect()
188    }
189
190    /// Try to route to specific shard
191    fn try_shard_routing(&self, analysis: &QueryAnalysis) -> Option<RoutingDecision> {
192        for table in &analysis.tables {
193            if let Some(schema) = &table.schema {
194                if let Some(shard_key) = &schema.shard_key {
195                    if let Some(shard_value) = analysis.shard_keys.get(shard_key) {
196                        let value = match shard_value {
197                            super::analyzer::ShardKeyValue::Single(v) => v.clone(),
198                            super::analyzer::ShardKeyValue::Multiple(v) => {
199                                // Multiple values = scatter-gather
200                                return Some(RoutingDecision {
201                                    target: RouteTarget::ScatterGather,
202                                    shards: v
203                                        .iter()
204                                        .filter_map(|val| self.schema.get_shard(shard_key, val))
205                                        .collect(),
206                                    reason: RoutingReason::ShardKey,
207                                    ..Default::default()
208                                });
209                            }
210                        };
211
212                        if let Some(shard) = self.schema.get_shard(shard_key, &value) {
213                            return Some(RoutingDecision {
214                                target: RouteTarget::Shard(shard),
215                                reason: RoutingReason::ShardKey,
216                                ..Default::default()
217                            });
218                        }
219                    }
220                }
221            }
222        }
223        None
224    }
225
226    /// Route OLTP workload
227    fn route_oltp(&self, nodes: &[NodeInfo], analysis: &QueryAnalysis) -> RoutingDecision {
228        // Write queries must go to primary
229        if !analysis.is_read_only {
230            return RoutingDecision {
231                target: RouteTarget::Primary,
232                reason: RoutingReason::WriteQuery,
233                ..Default::default()
234            };
235        }
236
237        // OLTP: Low latency, prefer primary or sync standbys
238        let mut preferred: Vec<_> = nodes
239            .iter()
240            .filter(|n| n.sync_mode == SyncMode::Sync || n.is_primary)
241            .cloned()
242            .collect();
243
244        preferred.sort_by_key(|n| n.current_latency_ms);
245
246        if let Some(node) = preferred.first() {
247            RoutingDecision {
248                target: RouteTarget::Node(node.id.clone()),
249                reason: RoutingReason::LowLatency,
250                node_info: Some(node.clone()),
251                ..Default::default()
252            }
253        } else {
254            RoutingDecision::default_routing()
255        }
256    }
257
258    /// Route OLAP workload
259    fn route_olap(&self, nodes: &[NodeInfo], _analysis: &QueryAnalysis) -> RoutingDecision {
260        // OLAP: Throughput over latency, prefer async standbys with columnar storage
261        let mut preferred: Vec<_> = nodes
262            .iter()
263            .filter(|n| n.capabilities.columnar_storage)
264            .cloned()
265            .collect();
266
267        if preferred.is_empty() {
268            // Fall back to any async standby
269            preferred = nodes
270                .iter()
271                .filter(|n| n.sync_mode == SyncMode::Async)
272                .cloned()
273                .collect();
274        }
275
276        preferred.sort_by(|a, b| a.current_load.partial_cmp(&b.current_load).unwrap());
277
278        if let Some(node) = preferred.first() {
279            RoutingDecision {
280                target: RouteTarget::Node(node.id.clone()),
281                reason: RoutingReason::ColumnarStorage,
282                node_info: Some(node.clone()),
283                ..Default::default()
284            }
285        } else {
286            RoutingDecision::default_routing()
287        }
288    }
289
290    /// Route vector workload
291    fn route_vector(&self, nodes: &[NodeInfo], _analysis: &QueryAnalysis) -> RoutingDecision {
292        // Vector: Need vector-capable nodes, prefer GPU
293        let mut vector_nodes: Vec<_> = nodes
294            .iter()
295            .filter(|n| n.capabilities.vector_search)
296            .cloned()
297            .collect();
298
299        // Sort by: GPU first, then lower load
300        vector_nodes.sort_by(|a, b| {
301            b.capabilities
302                .gpu_acceleration
303                .cmp(&a.capabilities.gpu_acceleration)
304                .then_with(|| a.current_load.partial_cmp(&b.current_load).unwrap())
305        });
306
307        if let Some(node) = vector_nodes.first() {
308            RoutingDecision {
309                target: RouteTarget::Node(node.id.clone()),
310                reason: RoutingReason::VectorCapable,
311                node_info: Some(node.clone()),
312                ..Default::default()
313            }
314        } else {
315            // No vector-capable nodes, fall back to primary
316            RoutingDecision {
317                target: RouteTarget::Primary,
318                reason: RoutingReason::NoVectorNodes,
319                ..Default::default()
320            }
321        }
322    }
323
324    /// Route mixed workload
325    fn route_mixed(&self, nodes: &[NodeInfo], analysis: &QueryAnalysis) -> RoutingDecision {
326        // Mixed: Balance between latency and throughput
327        if !analysis.is_read_only {
328            return RoutingDecision {
329                target: RouteTarget::Primary,
330                reason: RoutingReason::WriteQuery,
331                ..Default::default()
332            };
333        }
334
335        // Sort by weighted score: latency + load
336        let mut scored: Vec<_> = nodes
337            .iter()
338            .map(|n| {
339                let score = (n.current_latency_ms as f64) + (n.current_load * 100.0);
340                (n, score)
341            })
342            .collect();
343
344        scored.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
345
346        if let Some((node, _)) = scored.first() {
347            RoutingDecision {
348                target: RouteTarget::Node(node.id.clone()),
349                reason: RoutingReason::LowestScore,
350                node_info: Some((*node).clone()),
351                ..Default::default()
352            }
353        } else {
354            RoutingDecision::default_routing()
355        }
356    }
357
358    /// Route to nodes with specific temperature
359    fn route_to_temperature_nodes(
360        &self,
361        temp: DataTemperature,
362        analysis: &QueryAnalysis,
363    ) -> RoutingDecision {
364        // Find nodes that host tables with matching temperature
365        let matching_nodes: Vec<_> = self
366            .nodes
367            .iter()
368            .filter(|n| match temp {
369                DataTemperature::Hot => n.capabilities.in_memory,
370                DataTemperature::Warm => !n.capabilities.in_memory && !self.is_cold_storage(n),
371                DataTemperature::Cold | DataTemperature::Frozen => self.is_cold_storage(n),
372            })
373            .cloned()
374            .collect();
375
376        if matching_nodes.is_empty() {
377            return self.route_mixed(&self.nodes, analysis);
378        }
379
380        self.select_best(&matching_nodes, analysis)
381    }
382
383    /// Check if node is cold storage
384    fn is_cold_storage(&self, node: &NodeInfo) -> bool {
385        // Heuristic: cold storage nodes have no in-memory capability
386        // and are not primary/sync
387        !node.capabilities.in_memory && node.sync_mode == SyncMode::Async && !node.is_primary
388    }
389
390    /// Select best node from candidates
391    fn select_best(&self, nodes: &[NodeInfo], analysis: &QueryAnalysis) -> RoutingDecision {
392        if nodes.is_empty() {
393            return RoutingDecision::default_routing();
394        }
395
396        // Sort by latency for read queries, load for analytics
397        let mut sorted = nodes.to_vec();
398        if analysis.workload_type == WorkloadType::OLAP {
399            sorted.sort_by(|a, b| a.current_load.partial_cmp(&b.current_load).unwrap());
400        } else {
401            sorted.sort_by_key(|n| n.current_latency_ms);
402        }
403
404        let node = &sorted[0];
405        RoutingDecision {
406            target: RouteTarget::Node(node.id.clone()),
407            reason: RoutingReason::BestCandidate,
408            node_info: Some(node.clone()),
409            ..Default::default()
410        }
411    }
412
413    /// Apply routing preference
414    fn apply_preference(
415        &self,
416        preference: RoutingPreference,
417        analysis: &QueryAnalysis,
418    ) -> RoutingDecision {
419        match preference {
420            RoutingPreference::VectorNodes { prefer_gpu } => {
421                let nodes: Vec<_> = self
422                    .nodes
423                    .iter()
424                    .filter(|n| n.capabilities.vector_search)
425                    .filter(|n| !prefer_gpu || n.capabilities.gpu_acceleration)
426                    .cloned()
427                    .collect();
428                self.select_best(&nodes, analysis)
429            }
430            RoutingPreference::LowLatency { max_lag_ms } => {
431                let nodes: Vec<_> = self
432                    .nodes
433                    .iter()
434                    .filter(|n| n.current_latency_ms <= max_lag_ms)
435                    .cloned()
436                    .collect();
437                self.select_best(&nodes, analysis)
438            }
439            RoutingPreference::HighThroughput => {
440                let nodes: Vec<_> = self
441                    .nodes
442                    .iter()
443                    .filter(|n| n.sync_mode == SyncMode::Async)
444                    .cloned()
445                    .collect();
446                self.select_best(&nodes, analysis)
447            }
448            RoutingPreference::Primary => RoutingDecision {
449                target: RouteTarget::Primary,
450                reason: RoutingReason::AIWorkload,
451                ..Default::default()
452            },
453        }
454    }
455}
456
457impl NodeCapabilities {
458    /// Check if there are any requirements
459    fn has_requirements(&self) -> bool {
460        self.vector_search
461            || self.gpu_acceleration
462            || self.columnar_storage
463            || self.in_memory
464            || self.content_addressed
465    }
466}
467
468/// Routing decision
469#[derive(Debug, Clone, Default)]
470pub struct RoutingDecision {
471    /// Target for routing
472    pub target: RouteTarget,
473    /// Reason for decision
474    pub reason: RoutingReason,
475    /// Target shards (for scatter-gather)
476    pub shards: Vec<u32>,
477    /// Branch context
478    pub branch: Option<String>,
479    /// Selected node info
480    pub node_info: Option<NodeInfo>,
481}
482
483impl RoutingDecision {
484    /// Create a shard routing decision
485    pub fn shard(shard_id: u32) -> Self {
486        Self {
487            target: RouteTarget::Shard(shard_id),
488            reason: RoutingReason::ShardKey,
489            ..Default::default()
490        }
491    }
492
493    /// Create a single node routing decision
494    pub fn single(node: NodeInfo) -> Self {
495        Self {
496            target: RouteTarget::Node(node.id.clone()),
497            reason: RoutingReason::BestCandidate,
498            node_info: Some(node),
499            ..Default::default()
500        }
501    }
502
503    /// Create default routing (to primary)
504    pub fn default_routing() -> Self {
505        Self {
506            target: RouteTarget::Primary,
507            reason: RoutingReason::Default,
508            ..Default::default()
509        }
510    }
511
512    /// Check if routing to primary
513    pub fn is_primary(&self) -> bool {
514        matches!(self.target, RouteTarget::Primary)
515    }
516
517    /// Check if scatter-gather needed
518    pub fn is_scatter_gather(&self) -> bool {
519        matches!(self.target, RouteTarget::ScatterGather)
520    }
521}
522
/// Where a routing decision sends the query.
///
/// Derives equality and hashing so decisions can be compared directly, e.g.
/// `decision.target == RouteTarget::Primary`.
#[derive(Debug, Clone, Default, PartialEq, Eq, Hash)]
pub enum RouteTarget {
    /// Route to the primary (the default target).
    #[default]
    Primary,
    /// Route to the node with this id.
    Node(String),
    /// Route to the shard with this id.
    Shard(u32),
    /// Scatter-gather across several shards, listed in `RoutingDecision::shards`.
    ScatterGather,
}
536
/// Why a routing decision chose its target.
///
/// Derives equality and hashing so reasons can be compared directly instead of only via
/// `matches!`.
#[derive(Debug, Clone, Default, PartialEq, Eq, Hash)]
pub enum RoutingReason {
    /// Default routing
    #[default]
    Default,
    /// Write query must go to primary
    WriteQuery,
    /// Shard key present in query
    ShardKey,
    /// Lowest latency node
    LowLatency,
    /// Node with columnar storage
    ColumnarStorage,
    /// Vector-capable node
    VectorCapable,
    /// No vector-capable nodes available
    NoVectorNodes,
    /// Branch not available on eligible nodes
    BranchNotAvailable,
    /// Best candidate from scoring
    BestCandidate,
    /// Lowest combined score
    LowestScore,
    /// AI workload routing
    AIWorkload,
}
564
/// Routing preference for AI workloads.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum RoutingPreference {
    /// Prefer vector-capable nodes; with `prefer_gpu`, GPU-accelerated ones first.
    VectorNodes { prefer_gpu: bool },
    /// Prefer low-latency nodes, bounded by `max_lag_ms`.
    LowLatency { max_lag_ms: u64 },
    /// Prefer high-throughput nodes.
    HighThroughput,
    /// Must route to the primary.
    Primary,
}

/// Category of AI/LLM workload recognised by [`AIWorkloadDetector`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AIWorkloadType {
    /// Embedding retrieval (vector search).
    EmbeddingRetrieval,
    /// Context/conversation lookup.
    ContextLookup,
    /// Knowledge-base query.
    KnowledgeBase,
    /// Tool execution (writes).
    ToolExecution,
}

/// Classifies queries as AI workloads by keyword matching on the upper-cased query text.
#[derive(Debug)]
pub struct AIWorkloadDetector {
    /// Patterns checked in order; the first match wins.
    patterns: Vec<AIPattern>,
}

impl Default for AIWorkloadDetector {
    /// Same as [`AIWorkloadDetector::new`]: a detector with the built-in keyword set.
    fn default() -> Self {
        Self::new()
    }
}

/// One keyword and the workload it signals.
#[derive(Debug)]
struct AIPattern {
    /// Upper-case keyword searched for in the upper-cased query.
    keyword: String,
    /// Workload reported when the keyword matches.
    workload_type: AIWorkloadType,
}

impl AIWorkloadDetector {
    /// Built-in keywords (upper case) and the workload each signals, in priority order.
    const DEFAULT_PATTERNS: &'static [(&'static str, AIWorkloadType)] = &[
        ("<->", AIWorkloadType::EmbeddingRetrieval),
        ("VECTOR", AIWorkloadType::EmbeddingRetrieval),
        ("EMBEDDING", AIWorkloadType::EmbeddingRetrieval),
        ("CONVERSATION", AIWorkloadType::ContextLookup),
        ("TURNS", AIWorkloadType::ContextLookup),
        ("DOCUMENTS", AIWorkloadType::KnowledgeBase),
        ("CHUNKS", AIWorkloadType::KnowledgeBase),
        ("TOOL_RESULTS", AIWorkloadType::ToolExecution),
        ("ACTIONS", AIWorkloadType::ToolExecution),
    ];

    /// Create a detector with the built-in keyword set.
    pub fn new() -> Self {
        Self {
            patterns: Self::DEFAULT_PATTERNS
                .iter()
                .map(|&(keyword, workload_type)| AIPattern {
                    keyword: keyword.to_string(),
                    workload_type,
                })
                .collect(),
        }
    }

    /// Detect the AI workload type of `query`, if any.
    ///
    /// Matching is case-insensitive because the query is upper-cased first. Keywords that start
    /// with a letter or digit only match at the start of a word, i.e. not directly after another
    /// letter or digit. So `TRANSACTIONS` does not match `ACTIONS`, and `RETURNS` does not match
    /// `TURNS`. A keyword may still be followed by more characters (`EMBEDDINGS`, `DOCUMENTS_V2`)
    /// or preceded by `_` (`USER_ACTIONS`). Operator keywords such as `<->` match anywhere.
    pub fn detect(&self, query: &str) -> Option<AIWorkloadType> {
        let upper = query.to_uppercase();
        self.patterns
            .iter()
            .find(|pattern| Self::contains_keyword(&upper, &pattern.keyword))
            .map(|pattern| pattern.workload_type)
    }

    /// Whether `keyword` occurs in `haystack`, at a word start if the keyword is word-like.
    fn contains_keyword(haystack: &str, keyword: &str) -> bool {
        let word_like = keyword.chars().next().map_or(false, char::is_alphanumeric);
        haystack.match_indices(keyword).any(|(start, _)| {
            !word_like
                || !haystack[..start]
                    .chars()
                    .next_back()
                    .map_or(false, char::is_alphanumeric)
        })
    }

    /// Map a detected workload to its routing preference.
    pub fn get_optimal_routing(&self, workload: AIWorkloadType) -> RoutingPreference {
        match workload {
            AIWorkloadType::EmbeddingRetrieval => {
                RoutingPreference::VectorNodes { prefer_gpu: true }
            }
            AIWorkloadType::ContextLookup => RoutingPreference::LowLatency { max_lag_ms: 100 },
            AIWorkloadType::KnowledgeBase => RoutingPreference::HighThroughput,
            AIWorkloadType::ToolExecution => RoutingPreference::Primary,
        }
    }
}
674
/// Stage of a retrieval-augmented generation (RAG) pipeline; each stage is routed separately.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RAGStage {
    /// Similarity search over embeddings.
    Retrieval,
    /// Looking up the documents found during retrieval.
    Fetch,
    /// Re-ranking the retrieved candidates.
    Rerank,
    /// Producing the final answer.
    Generate,
}
687
/// Stateless router that picks a node for each RAG pipeline stage.
#[derive(Debug, Default)]
pub struct RAGRouter {}
691
692impl RAGRouter {
693    /// Create a new RAG router
694    pub fn new() -> Self {
695        Self {}
696    }
697
698    /// Route RAG query based on stage
699    pub fn route_rag_query(
700        &self,
701        stage: RAGStage,
702        analysis: &QueryAnalysis,
703        nodes: &[NodeInfo],
704    ) -> RoutingDecision {
705        match stage {
706            RAGStage::Retrieval => {
707                // Vector search on embeddings
708                let vector_nodes: Vec<_> = nodes
709                    .iter()
710                    .filter(|n| n.capabilities.vector_search)
711                    .cloned()
712                    .collect();
713
714                if let Some(node) = vector_nodes.first() {
715                    RoutingDecision::single(node.clone())
716                } else {
717                    RoutingDecision::default_routing()
718                }
719            }
720            RAGStage::Fetch => {
721                // Bulk fetch - high throughput
722                let throughput_nodes: Vec<_> = nodes
723                    .iter()
724                    .filter(|n| n.sync_mode == SyncMode::Async)
725                    .cloned()
726                    .collect();
727
728                if let Some(node) = throughput_nodes.first() {
729                    RoutingDecision::single(node.clone())
730                } else {
731                    RoutingDecision::default_routing()
732                }
733            }
734            RAGStage::Rerank => {
735                // Light computation - lowest latency
736                let mut sorted = nodes.to_vec();
737                sorted.sort_by_key(|n| n.current_latency_ms);
738
739                if let Some(node) = sorted.first() {
740                    RoutingDecision::single(node.clone())
741                } else {
742                    RoutingDecision::default_routing()
743                }
744            }
745            RAGStage::Generate => {
746                // May write to cache - check if write
747                if !analysis.is_read_only {
748                    RoutingDecision::default_routing()
749                } else {
750                    let mut sorted = nodes.to_vec();
751                    sorted.sort_by_key(|n| n.current_latency_ms);
752
753                    if let Some(node) = sorted.first() {
754                        RoutingDecision::single(node.clone())
755                    } else {
756                        RoutingDecision::default_routing()
757                    }
758                }
759            }
760        }
761    }
762}
763
#[cfg(test)]
mod tests {
    use super::*;
    use crate::schema_routing::registry::TableSchema;

    /// Router fixture with three tables and four nodes.
    ///
    /// Tables: `users` (OLTP), `events` (OLAP, cold) and `embeddings` (vector).
    /// Nodes: a primary, a sync standby, an async analytics standby and an async vector node.
    fn create_test_setup() -> SchemaAwareRouter {
        let registry = Arc::new(SchemaRegistry::new());

        let tables = [
            TableSchema::new("users")
                .with_workload(WorkloadType::OLTP)
                .with_access_pattern(AccessPattern::PointLookup)
                .with_primary_key(vec!["id".to_string()]),
            TableSchema::new("events")
                .with_workload(WorkloadType::OLAP)
                .with_temperature(DataTemperature::Cold),
            TableSchema::new("embeddings").with_workload(WorkloadType::Vector),
        ];
        for table in tables {
            registry.register_table(table);
        }

        let nodes = [
            NodeInfo::new("primary", "primary").as_primary(),
            NodeInfo::new("standby-sync", "standby-sync").with_sync_mode(SyncMode::Sync),
            NodeInfo::new("standby-async", "standby-async")
                .with_sync_mode(SyncMode::Async)
                .with_capabilities(NodeCapabilities::analytics_node()),
            NodeInfo::new("vector-node", "vector-node")
                .with_sync_mode(SyncMode::Async)
                .with_capabilities(NodeCapabilities::vector_node()),
        ];

        let mut router = SchemaAwareRouter::new(SchemaRoutingConfig::default(), registry);
        for node in nodes {
            router.add_node(node);
        }
        router
    }

    #[test]
    fn test_route_oltp_read() {
        let decision = create_test_setup().route("SELECT * FROM users WHERE id = 1");
        let via_low_latency = matches!(decision.reason, RoutingReason::LowLatency);
        assert!(via_low_latency || !decision.is_primary());
    }

    #[test]
    fn test_route_write_to_primary() {
        let decision = create_test_setup().route("INSERT INTO users (name) VALUES ('test')");
        assert!(decision.is_primary());
        assert!(matches!(decision.reason, RoutingReason::WriteQuery));
    }

    #[test]
    fn test_route_vector_query() {
        let decision = create_test_setup()
            .route("SELECT * FROM embeddings ORDER BY embedding <-> '[1,2,3]' LIMIT 10");
        let via_vector_node = matches!(
            decision.reason,
            RoutingReason::VectorCapable | RoutingReason::BestCandidate
        );
        assert!(via_vector_node || decision.is_primary());
    }

    #[test]
    fn test_route_olap_query() {
        let decision =
            create_test_setup().route("SELECT COUNT(*), SUM(amount) FROM events GROUP BY date");
        // Either off the primary, or on it only through the columnar/default fallbacks.
        let allowed_on_primary = matches!(
            decision.reason,
            RoutingReason::ColumnarStorage | RoutingReason::Default
        );
        assert!(!decision.is_primary() || allowed_on_primary);
    }

    #[test]
    fn test_ai_workload_detection() {
        let detector = AIWorkloadDetector::new();
        let cases = [
            (
                "SELECT * FROM embeddings ORDER BY vector <-> $1",
                AIWorkloadType::EmbeddingRetrieval,
            ),
            (
                "SELECT * FROM conversation WHERE session_id = $1",
                AIWorkloadType::ContextLookup,
            ),
            (
                "INSERT INTO tool_results (result) VALUES ($1)",
                AIWorkloadType::ToolExecution,
            ),
        ];
        for (query, expected) in cases {
            assert_eq!(detector.detect(query), Some(expected));
        }
    }

    #[test]
    fn test_rag_routing() {
        let router = create_test_setup();
        let retrieval = router.route_rag(RAGStage::Retrieval, "SELECT embedding FROM docs");
        let fetch = router.route_rag(
            RAGStage::Fetch,
            "SELECT content FROM docs WHERE id IN (1,2,3)",
        );

        // Each stage either picks a concrete node or falls back to the primary.
        for decision in [&retrieval, &fetch] {
            assert!(decision.node_info.is_some() || decision.is_primary());
        }
    }

    #[test]
    fn test_routing_decision_helpers() {
        assert!(matches!(
            RoutingDecision::shard(3).target,
            RouteTarget::Shard(3)
        ));
        assert!(RoutingDecision::default_routing().is_primary());
    }
}