1use std::sync::Arc;
6
7use super::{
8 analyzer::{QueryAnalysis, QueryAnalyzer},
9 registry::{AccessPattern, DataTemperature, NodeCapabilities, SchemaRegistry, WorkloadType},
10 NodeInfo, SchemaRoutingConfig, SyncMode,
11};
12
/// Routes queries to nodes using schema metadata, per-node capabilities and
/// the load/latency figures last reported for each node.
#[derive(Debug)]
pub struct SchemaAwareRouter {
    /// Routing configuration; `route` falls back to default routing when
    /// `enabled` is false.
    config: SchemaRoutingConfig,
    /// Shared schema registry, consulted for shard and branch placement.
    schema: Arc<SchemaRegistry>,
    /// Query analyzer, constructed from the same registry.
    analyzer: QueryAnalyzer,
    /// Nodes currently known to the router, in insertion order.
    nodes: Vec<NodeInfo>,
    /// Keyword-based AI workload detector, consulted by `route` before
    /// workload-type routing.
    ai_detector: AIWorkloadDetector,
    /// Stage-aware router used by `route_rag`.
    rag_router: RAGRouter,
}
29
30impl SchemaAwareRouter {
31 pub fn new(config: SchemaRoutingConfig, schema: Arc<SchemaRegistry>) -> Self {
33 Self {
34 analyzer: QueryAnalyzer::new(schema.clone()),
35 schema,
36 config,
37 nodes: Vec::new(),
38 ai_detector: AIWorkloadDetector::new(),
39 rag_router: RAGRouter::new(),
40 }
41 }
42
    /// Registers `node` by appending it to the router's node list.
    ///
    /// No de-duplication is performed: adding a node whose `id` is already
    /// present leaves both entries in the list.
    pub fn add_node(&mut self, node: NodeInfo) {
        self.nodes.push(node);
    }
47
    /// Removes every node whose `id` equals `node_id`; does nothing when no
    /// node matches.
    pub fn remove_node(&mut self, node_id: &str) {
        self.nodes.retain(|n| n.id != node_id);
    }
52
53 pub fn update_node(&mut self, node_id: &str, load: f64, latency_ms: u64) {
55 if let Some(node) = self.nodes.iter_mut().find(|n| n.id == node_id) {
56 node.current_load = load;
57 node.current_latency_ms = latency_ms;
58 }
59 }
60
61 pub fn route(&self, query: &str) -> RoutingDecision {
63 if !self.config.enabled {
64 return RoutingDecision::default_routing();
65 }
66
67 let analysis = self.analyzer.analyze(query);
68
69 if let Some(ai_workload) = self.ai_detector.detect(query) {
71 let preference = self.ai_detector.get_optimal_routing(ai_workload);
72 return self.apply_preference(preference, &analysis);
73 }
74
75 let required_caps = self.get_required_capabilities(&analysis);
77
78 let eligible = self.filter_by_capabilities(&required_caps);
80
81 if let Some(shard_routing) = self.try_shard_routing(&analysis) {
83 return shard_routing;
84 }
85
86 match analysis.workload_type {
88 WorkloadType::OLTP => self.route_oltp(&eligible, &analysis),
89 WorkloadType::OLAP => self.route_olap(&eligible, &analysis),
90 WorkloadType::Vector => self.route_vector(&eligible, &analysis),
91 WorkloadType::HTAP | WorkloadType::Mixed => self.route_mixed(&eligible, &analysis),
92 }
93 }
94
95 pub fn route_with_branch(&self, query: &str, branch: &str) -> RoutingDecision {
97 let analysis = self.analyzer.analyze(query);
98
99 let branch_nodes = self.schema.get_branch_locations(branch);
101
102 let required_caps = self.get_required_capabilities(&analysis);
104 let eligible = self.filter_by_capabilities(&required_caps);
105
106 let available: Vec<_> = eligible
108 .iter()
109 .filter(|n| branch_nodes.contains(&n.id))
110 .cloned()
111 .collect();
112
113 if available.is_empty() {
114 return RoutingDecision {
116 target: RouteTarget::Primary,
117 reason: RoutingReason::BranchNotAvailable,
118 branch: Some(branch.to_string()),
119 ..Default::default()
120 };
121 }
122
123 self.select_best(&available, &analysis)
124 }
125
126 pub fn route_time_travel(&self, query: &str, age_days: i64) -> RoutingDecision {
128 let analysis = self.analyzer.analyze(query);
129
130 if age_days < 7 {
132 return self.route_to_temperature_nodes(DataTemperature::Hot, &analysis);
133 }
134
135 if age_days < 30 {
137 return self.route_to_temperature_nodes(DataTemperature::Warm, &analysis);
138 }
139
140 self.route_to_temperature_nodes(DataTemperature::Cold, &analysis)
142 }
143
144 pub fn route_rag(&self, stage: RAGStage, query: &str) -> RoutingDecision {
146 let analysis = self.analyzer.analyze(query);
147 self.rag_router
148 .route_rag_query(stage, &analysis, &self.nodes)
149 }
150
151 fn get_required_capabilities(&self, analysis: &QueryAnalysis) -> NodeCapabilities {
153 let mut caps = NodeCapabilities::default();
154
155 if analysis
157 .access_patterns
158 .contains(&AccessPattern::VectorSearch)
159 {
160 caps.vector_search = true;
161 caps.gpu_acceleration = true; }
163
164 if analysis.workload_type == WorkloadType::OLAP {
166 caps.columnar_storage = true;
167 }
168
169 for table in &analysis.tables {
171 if let Some(schema) = &table.schema {
172 if schema.temperature == DataTemperature::Hot {
173 caps.in_memory = true;
174 }
175 }
176 }
177
178 caps
179 }
180
181 fn filter_by_capabilities(&self, required: &NodeCapabilities) -> Vec<NodeInfo> {
183 self.nodes
184 .iter()
185 .filter(|n| n.capabilities.satisfies(required) || !required.has_requirements())
186 .cloned()
187 .collect()
188 }
189
190 fn try_shard_routing(&self, analysis: &QueryAnalysis) -> Option<RoutingDecision> {
192 for table in &analysis.tables {
193 if let Some(schema) = &table.schema {
194 if let Some(shard_key) = &schema.shard_key {
195 if let Some(shard_value) = analysis.shard_keys.get(shard_key) {
196 let value = match shard_value {
197 super::analyzer::ShardKeyValue::Single(v) => v.clone(),
198 super::analyzer::ShardKeyValue::Multiple(v) => {
199 return Some(RoutingDecision {
201 target: RouteTarget::ScatterGather,
202 shards: v
203 .iter()
204 .filter_map(|val| self.schema.get_shard(shard_key, val))
205 .collect(),
206 reason: RoutingReason::ShardKey,
207 ..Default::default()
208 });
209 }
210 };
211
212 if let Some(shard) = self.schema.get_shard(shard_key, &value) {
213 return Some(RoutingDecision {
214 target: RouteTarget::Shard(shard),
215 reason: RoutingReason::ShardKey,
216 ..Default::default()
217 });
218 }
219 }
220 }
221 }
222 }
223 None
224 }
225
226 fn route_oltp(&self, nodes: &[NodeInfo], analysis: &QueryAnalysis) -> RoutingDecision {
228 if !analysis.is_read_only {
230 return RoutingDecision {
231 target: RouteTarget::Primary,
232 reason: RoutingReason::WriteQuery,
233 ..Default::default()
234 };
235 }
236
237 let mut preferred: Vec<_> = nodes
239 .iter()
240 .filter(|n| n.sync_mode == SyncMode::Sync || n.is_primary)
241 .cloned()
242 .collect();
243
244 preferred.sort_by_key(|n| n.current_latency_ms);
245
246 if let Some(node) = preferred.first() {
247 RoutingDecision {
248 target: RouteTarget::Node(node.id.clone()),
249 reason: RoutingReason::LowLatency,
250 node_info: Some(node.clone()),
251 ..Default::default()
252 }
253 } else {
254 RoutingDecision::default_routing()
255 }
256 }
257
258 fn route_olap(&self, nodes: &[NodeInfo], _analysis: &QueryAnalysis) -> RoutingDecision {
260 let mut preferred: Vec<_> = nodes
262 .iter()
263 .filter(|n| n.capabilities.columnar_storage)
264 .cloned()
265 .collect();
266
267 if preferred.is_empty() {
268 preferred = nodes
270 .iter()
271 .filter(|n| n.sync_mode == SyncMode::Async)
272 .cloned()
273 .collect();
274 }
275
276 preferred.sort_by(|a, b| a.current_load.partial_cmp(&b.current_load).unwrap());
277
278 if let Some(node) = preferred.first() {
279 RoutingDecision {
280 target: RouteTarget::Node(node.id.clone()),
281 reason: RoutingReason::ColumnarStorage,
282 node_info: Some(node.clone()),
283 ..Default::default()
284 }
285 } else {
286 RoutingDecision::default_routing()
287 }
288 }
289
290 fn route_vector(&self, nodes: &[NodeInfo], _analysis: &QueryAnalysis) -> RoutingDecision {
292 let mut vector_nodes: Vec<_> = nodes
294 .iter()
295 .filter(|n| n.capabilities.vector_search)
296 .cloned()
297 .collect();
298
299 vector_nodes.sort_by(|a, b| {
301 b.capabilities
302 .gpu_acceleration
303 .cmp(&a.capabilities.gpu_acceleration)
304 .then_with(|| a.current_load.partial_cmp(&b.current_load).unwrap())
305 });
306
307 if let Some(node) = vector_nodes.first() {
308 RoutingDecision {
309 target: RouteTarget::Node(node.id.clone()),
310 reason: RoutingReason::VectorCapable,
311 node_info: Some(node.clone()),
312 ..Default::default()
313 }
314 } else {
315 RoutingDecision {
317 target: RouteTarget::Primary,
318 reason: RoutingReason::NoVectorNodes,
319 ..Default::default()
320 }
321 }
322 }
323
324 fn route_mixed(&self, nodes: &[NodeInfo], analysis: &QueryAnalysis) -> RoutingDecision {
326 if !analysis.is_read_only {
328 return RoutingDecision {
329 target: RouteTarget::Primary,
330 reason: RoutingReason::WriteQuery,
331 ..Default::default()
332 };
333 }
334
335 let mut scored: Vec<_> = nodes
337 .iter()
338 .map(|n| {
339 let score = (n.current_latency_ms as f64) + (n.current_load * 100.0);
340 (n, score)
341 })
342 .collect();
343
344 scored.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
345
346 if let Some((node, _)) = scored.first() {
347 RoutingDecision {
348 target: RouteTarget::Node(node.id.clone()),
349 reason: RoutingReason::LowestScore,
350 node_info: Some((*node).clone()),
351 ..Default::default()
352 }
353 } else {
354 RoutingDecision::default_routing()
355 }
356 }
357
358 fn route_to_temperature_nodes(
360 &self,
361 temp: DataTemperature,
362 analysis: &QueryAnalysis,
363 ) -> RoutingDecision {
364 let matching_nodes: Vec<_> = self
366 .nodes
367 .iter()
368 .filter(|n| match temp {
369 DataTemperature::Hot => n.capabilities.in_memory,
370 DataTemperature::Warm => !n.capabilities.in_memory && !self.is_cold_storage(n),
371 DataTemperature::Cold | DataTemperature::Frozen => self.is_cold_storage(n),
372 })
373 .cloned()
374 .collect();
375
376 if matching_nodes.is_empty() {
377 return self.route_mixed(&self.nodes, analysis);
378 }
379
380 self.select_best(&matching_nodes, analysis)
381 }
382
383 fn is_cold_storage(&self, node: &NodeInfo) -> bool {
385 !node.capabilities.in_memory && node.sync_mode == SyncMode::Async && !node.is_primary
388 }
389
390 fn select_best(&self, nodes: &[NodeInfo], analysis: &QueryAnalysis) -> RoutingDecision {
392 if nodes.is_empty() {
393 return RoutingDecision::default_routing();
394 }
395
396 let mut sorted = nodes.to_vec();
398 if analysis.workload_type == WorkloadType::OLAP {
399 sorted.sort_by(|a, b| a.current_load.partial_cmp(&b.current_load).unwrap());
400 } else {
401 sorted.sort_by_key(|n| n.current_latency_ms);
402 }
403
404 let node = &sorted[0];
405 RoutingDecision {
406 target: RouteTarget::Node(node.id.clone()),
407 reason: RoutingReason::BestCandidate,
408 node_info: Some(node.clone()),
409 ..Default::default()
410 }
411 }
412
413 fn apply_preference(
415 &self,
416 preference: RoutingPreference,
417 analysis: &QueryAnalysis,
418 ) -> RoutingDecision {
419 match preference {
420 RoutingPreference::VectorNodes { prefer_gpu } => {
421 let nodes: Vec<_> = self
422 .nodes
423 .iter()
424 .filter(|n| n.capabilities.vector_search)
425 .filter(|n| !prefer_gpu || n.capabilities.gpu_acceleration)
426 .cloned()
427 .collect();
428 self.select_best(&nodes, analysis)
429 }
430 RoutingPreference::LowLatency { max_lag_ms } => {
431 let nodes: Vec<_> = self
432 .nodes
433 .iter()
434 .filter(|n| n.current_latency_ms <= max_lag_ms)
435 .cloned()
436 .collect();
437 self.select_best(&nodes, analysis)
438 }
439 RoutingPreference::HighThroughput => {
440 let nodes: Vec<_> = self
441 .nodes
442 .iter()
443 .filter(|n| n.sync_mode == SyncMode::Async)
444 .cloned()
445 .collect();
446 self.select_best(&nodes, analysis)
447 }
448 RoutingPreference::Primary => RoutingDecision {
449 target: RouteTarget::Primary,
450 reason: RoutingReason::AIWorkload,
451 ..Default::default()
452 },
453 }
454 }
455}
456
457impl NodeCapabilities {
458 fn has_requirements(&self) -> bool {
460 self.vector_search
461 || self.gpu_acceleration
462 || self.columnar_storage
463 || self.in_memory
464 || self.content_addressed
465 }
466}
467
/// Outcome of a routing call: where to send the query and why.
///
/// The `Default` value targets the primary with reason `Default`, no shards,
/// no branch and no node info.
#[derive(Debug, Clone, Default)]
pub struct RoutingDecision {
    /// Where the query should be sent.
    pub target: RouteTarget,
    /// Why this target was chosen.
    pub reason: RoutingReason,
    /// Shard IDs for scatter-gather decisions; empty otherwise.
    pub shards: Vec<u32>,
    /// Branch the decision was made for; only populated by branch-aware
    /// routing.
    pub branch: Option<String>,
    /// Copy of the selected node's info when a specific node was chosen.
    pub node_info: Option<NodeInfo>,
}
482
483impl RoutingDecision {
484 pub fn shard(shard_id: u32) -> Self {
486 Self {
487 target: RouteTarget::Shard(shard_id),
488 reason: RoutingReason::ShardKey,
489 ..Default::default()
490 }
491 }
492
493 pub fn single(node: NodeInfo) -> Self {
495 Self {
496 target: RouteTarget::Node(node.id.clone()),
497 reason: RoutingReason::BestCandidate,
498 node_info: Some(node),
499 ..Default::default()
500 }
501 }
502
503 pub fn default_routing() -> Self {
505 Self {
506 target: RouteTarget::Primary,
507 reason: RoutingReason::Default,
508 ..Default::default()
509 }
510 }
511
512 pub fn is_primary(&self) -> bool {
514 matches!(self.target, RouteTarget::Primary)
515 }
516
517 pub fn is_scatter_gather(&self) -> bool {
519 matches!(self.target, RouteTarget::ScatterGather)
520 }
521}
522
/// Destination selected by a routing decision.
#[derive(Debug, Clone, Default)]
pub enum RouteTarget {
    /// The primary node; this is the `Default` target.
    #[default]
    Primary,
    /// A specific node, identified by its ID.
    Node(String),
    /// A single shard, identified by its ID.
    Shard(u32),
    /// Fan out across several shards; the shard IDs are carried in
    /// `RoutingDecision::shards`.
    ScatterGather,
}
536
/// Explanation attached to a routing decision.
#[derive(Debug, Clone, Default)]
pub enum RoutingReason {
    /// Fallback routing; used by `RoutingDecision::default_routing`.
    #[default]
    Default,
    /// The query is not read-only and was sent to the primary.
    WriteQuery,
    /// The target was resolved from a shard key.
    ShardKey,
    /// Chosen by the OLTP read strategy for its low latency.
    LowLatency,
    /// Chosen by the OLAP strategy.
    ColumnarStorage,
    /// Chosen by the vector strategy because the node supports vector search.
    VectorCapable,
    /// The vector strategy found no vector-capable node and fell back to the
    /// primary.
    NoVectorNodes,
    /// No eligible node hosts the requested branch.
    BranchNotAvailable,
    /// Best node among a candidate set.
    BestCandidate,
    /// Lowest combined latency/load score in the mixed strategy.
    LowestScore,
    /// An AI workload preference sent the query to the primary.
    AIWorkload,
}
564
/// Node-selection preference suggested for a detected AI workload.
#[derive(Debug, Clone)]
pub enum RoutingPreference {
    /// Use vector-search nodes; `prefer_gpu` asks for GPU-accelerated ones.
    VectorNodes { prefer_gpu: bool },
    /// Use a node whose latency does not exceed `max_lag_ms`.
    LowLatency { max_lag_ms: u64 },
    /// Use a node suited to bulk reads.
    HighThroughput,
    /// Use the primary.
    Primary,
}

/// Kind of AI workload recognised by [`AIWorkloadDetector`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AIWorkloadType {
    /// Matched by `<->`, `VECTOR` or `EMBEDDING`.
    EmbeddingRetrieval,
    /// Matched by `CONVERSATION` or `TURNS`.
    ContextLookup,
    /// Matched by `DOCUMENTS` or `CHUNKS`.
    KnowledgeBase,
    /// Matched by `TOOL_RESULTS` or `ACTIONS`.
    ToolExecution,
}

/// Keyword-based classifier for AI-related queries.
///
/// Note that the derived `Default` produces a detector with an empty pattern
/// list (it never matches); use [`AIWorkloadDetector::new`] for the built-in
/// patterns.
#[derive(Debug, Default)]
pub struct AIWorkloadDetector {
    /// Patterns in priority order; the first one that matches wins.
    patterns: Vec<AIPattern>,
}

/// One upper-case keyword and the workload it indicates.
#[derive(Debug)]
struct AIPattern {
    keyword: &'static str,
    workload_type: AIWorkloadType,
}

impl AIWorkloadDetector {
    /// Creates a detector loaded with the built-in keyword patterns.
    pub fn new() -> Self {
        const KEYWORDS: [(&str, AIWorkloadType); 9] = [
            ("<->", AIWorkloadType::EmbeddingRetrieval),
            ("VECTOR", AIWorkloadType::EmbeddingRetrieval),
            ("EMBEDDING", AIWorkloadType::EmbeddingRetrieval),
            ("CONVERSATION", AIWorkloadType::ContextLookup),
            ("TURNS", AIWorkloadType::ContextLookup),
            ("DOCUMENTS", AIWorkloadType::KnowledgeBase),
            ("CHUNKS", AIWorkloadType::KnowledgeBase),
            ("TOOL_RESULTS", AIWorkloadType::ToolExecution),
            ("ACTIONS", AIWorkloadType::ToolExecution),
        ];

        Self {
            patterns: KEYWORDS
                .iter()
                .map(|&(keyword, workload_type)| AIPattern {
                    keyword,
                    workload_type,
                })
                .collect(),
        }
    }

    /// Classifies `query`, returning the workload of the first pattern that
    /// matches, or `None` when no pattern does.
    ///
    /// Matching is case-insensitive. A keyword that begins with a letter or
    /// digit only matches where it is not preceded by another letter or
    /// digit, so `actions`, `user_actions` and `embeddings` match while
    /// `transactions` (contains `ACTIONS`) and `RETURNS` (contains `TURNS`)
    /// do not. Operator keywords such as `<->` match anywhere.
    pub fn detect(&self, query: &str) -> Option<AIWorkloadType> {
        let upper = query.to_uppercase();

        self.patterns
            .iter()
            .find(|pattern| Self::keyword_matches(&upper, pattern.keyword))
            .map(|pattern| pattern.workload_type)
    }

    /// Returns whether `keyword` occurs in `haystack` at an acceptable
    /// position (see [`AIWorkloadDetector::detect`] for the boundary rule).
    fn keyword_matches(haystack: &str, keyword: &str) -> bool {
        let needs_boundary = keyword
            .chars()
            .next()
            .map_or(false, char::is_alphanumeric);
        if !needs_boundary {
            return haystack.contains(keyword);
        }

        haystack.match_indices(keyword).any(|(start, _)| {
            // Accept the hit only when the character just before it is not
            // part of a word, i.e. the keyword starts an identifier segment.
            !haystack[..start]
                .chars()
                .next_back()
                .map_or(false, char::is_alphanumeric)
        })
    }

    /// Maps a detected workload to the routing preference used for it.
    pub fn get_optimal_routing(&self, workload: AIWorkloadType) -> RoutingPreference {
        match workload {
            AIWorkloadType::EmbeddingRetrieval => {
                RoutingPreference::VectorNodes { prefer_gpu: true }
            }
            AIWorkloadType::ContextLookup => RoutingPreference::LowLatency { max_lag_ms: 100 },
            AIWorkloadType::KnowledgeBase => RoutingPreference::HighThroughput,
            AIWorkloadType::ToolExecution => RoutingPreference::Primary,
        }
    }
}
674
/// Stage of a RAG pipeline, used by `RAGRouter::route_rag_query` to pick a
/// node-selection rule.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RAGStage {
    /// Routed to a vector-search-capable node.
    Retrieval,
    /// Routed to an asynchronously replicated node.
    Fetch,
    /// Routed to the lowest-latency node.
    Rerank,
    /// Routed to the lowest-latency node for read-only queries; queries that
    /// are not read-only get the default routing.
    Generate,
}
687
688#[derive(Debug, Default)]
690pub struct RAGRouter {}
691
692impl RAGRouter {
693 pub fn new() -> Self {
695 Self {}
696 }
697
698 pub fn route_rag_query(
700 &self,
701 stage: RAGStage,
702 analysis: &QueryAnalysis,
703 nodes: &[NodeInfo],
704 ) -> RoutingDecision {
705 match stage {
706 RAGStage::Retrieval => {
707 let vector_nodes: Vec<_> = nodes
709 .iter()
710 .filter(|n| n.capabilities.vector_search)
711 .cloned()
712 .collect();
713
714 if let Some(node) = vector_nodes.first() {
715 RoutingDecision::single(node.clone())
716 } else {
717 RoutingDecision::default_routing()
718 }
719 }
720 RAGStage::Fetch => {
721 let throughput_nodes: Vec<_> = nodes
723 .iter()
724 .filter(|n| n.sync_mode == SyncMode::Async)
725 .cloned()
726 .collect();
727
728 if let Some(node) = throughput_nodes.first() {
729 RoutingDecision::single(node.clone())
730 } else {
731 RoutingDecision::default_routing()
732 }
733 }
734 RAGStage::Rerank => {
735 let mut sorted = nodes.to_vec();
737 sorted.sort_by_key(|n| n.current_latency_ms);
738
739 if let Some(node) = sorted.first() {
740 RoutingDecision::single(node.clone())
741 } else {
742 RoutingDecision::default_routing()
743 }
744 }
745 RAGStage::Generate => {
746 if !analysis.is_read_only {
748 RoutingDecision::default_routing()
749 } else {
750 let mut sorted = nodes.to_vec();
751 sorted.sort_by_key(|n| n.current_latency_ms);
752
753 if let Some(node) = sorted.first() {
754 RoutingDecision::single(node.clone())
755 } else {
756 RoutingDecision::default_routing()
757 }
758 }
759 }
760 }
761 }
762}
763
#[cfg(test)]
mod tests {
    use super::*;
    use crate::schema_routing::registry::TableSchema;

    /// Builds a router over three registered tables (`users`, `events`,
    /// `embeddings`) and four nodes: a primary, a sync standby, an async
    /// analytics standby and an async vector node.
    fn create_test_setup() -> SchemaAwareRouter {
        let registry = Arc::new(SchemaRegistry::new());

        let users = TableSchema::new("users")
            .with_workload(WorkloadType::OLTP)
            .with_access_pattern(AccessPattern::PointLookup)
            .with_primary_key(vec!["id".to_string()]);
        registry.register_table(users);

        let events = TableSchema::new("events")
            .with_workload(WorkloadType::OLAP)
            .with_temperature(DataTemperature::Cold);
        registry.register_table(events);

        let embeddings = TableSchema::new("embeddings").with_workload(WorkloadType::Vector);
        registry.register_table(embeddings);

        let mut router = SchemaAwareRouter::new(SchemaRoutingConfig::default(), registry);

        let fleet = vec![
            NodeInfo::new("primary", "primary").as_primary(),
            NodeInfo::new("standby-sync", "standby-sync").with_sync_mode(SyncMode::Sync),
            NodeInfo::new("standby-async", "standby-async")
                .with_sync_mode(SyncMode::Async)
                .with_capabilities(NodeCapabilities::analytics_node()),
            NodeInfo::new("vector-node", "vector-node")
                .with_sync_mode(SyncMode::Async)
                .with_capabilities(NodeCapabilities::vector_node()),
        ];
        for node in fleet {
            router.add_node(node);
        }

        router
    }

    #[test]
    fn test_route_oltp_read() {
        let decision = create_test_setup().route("SELECT * FROM users WHERE id = 1");

        let low_latency = matches!(decision.reason, RoutingReason::LowLatency);
        assert!(low_latency || !decision.is_primary());
    }

    #[test]
    fn test_route_write_to_primary() {
        let decision = create_test_setup().route("INSERT INTO users (name) VALUES ('test')");

        assert!(decision.is_primary());
        assert!(matches!(decision.reason, RoutingReason::WriteQuery));
    }

    #[test]
    fn test_route_vector_query() {
        let sql = "SELECT * FROM embeddings ORDER BY embedding <-> '[1,2,3]' LIMIT 10";
        let decision = create_test_setup().route(sql);

        let picked_vector_node = matches!(
            decision.reason,
            RoutingReason::VectorCapable | RoutingReason::BestCandidate
        );
        assert!(picked_vector_node || decision.is_primary());
    }

    #[test]
    fn test_route_olap_query() {
        let sql = "SELECT COUNT(*), SUM(amount) FROM events GROUP BY date";
        let decision = create_test_setup().route(sql);

        let acceptable_reason = matches!(
            decision.reason,
            RoutingReason::ColumnarStorage | RoutingReason::Default
        );
        assert!(acceptable_reason || !decision.is_primary());
    }

    #[test]
    fn test_ai_workload_detection() {
        let detector = AIWorkloadDetector::new();

        let cases = [
            (
                "SELECT * FROM embeddings ORDER BY vector <-> $1",
                AIWorkloadType::EmbeddingRetrieval,
            ),
            (
                "SELECT * FROM conversation WHERE session_id = $1",
                AIWorkloadType::ContextLookup,
            ),
            (
                "INSERT INTO tool_results (result) VALUES ($1)",
                AIWorkloadType::ToolExecution,
            ),
        ];

        for (sql, expected) in cases.iter() {
            assert_eq!(detector.detect(sql), Some(*expected));
        }
    }

    #[test]
    fn test_rag_routing() {
        let router = create_test_setup();

        let retrieval = router.route_rag(RAGStage::Retrieval, "SELECT embedding FROM docs");
        assert!(retrieval.node_info.is_some() || retrieval.is_primary());

        let fetch = router.route_rag(
            RAGStage::Fetch,
            "SELECT content FROM docs WHERE id IN (1,2,3)",
        );
        assert!(fetch.node_info.is_some() || fetch.is_primary());
    }

    #[test]
    fn test_routing_decision_helpers() {
        let sharded = RoutingDecision::shard(3);
        assert!(matches!(sharded.target, RouteTarget::Shard(3)));

        let fallback = RoutingDecision::default_routing();
        assert!(fallback.is_primary());
    }
}