1use std::sync::Arc;
6
7use super::{
8 NodeInfo, SyncMode, SchemaRoutingConfig,
9 registry::{SchemaRegistry, NodeCapabilities, DataTemperature, WorkloadType, AccessPattern},
10 analyzer::{QueryAnalyzer, QueryAnalysis},
11};
12
/// Routes SQL queries to nodes based on schema metadata, query analysis and
/// the capabilities / live metrics of the registered nodes.
#[derive(Debug)]
pub struct SchemaAwareRouter {
    /// Routing configuration; `route` consults its `enabled` flag.
    config: SchemaRoutingConfig,
    /// Shared schema registry, used for branch locations and shard lookups.
    schema: Arc<SchemaRegistry>,
    /// Analyzer that turns query text into a `QueryAnalysis`.
    analyzer: QueryAnalyzer,
    /// Nodes currently known to the router (see `add_node` / `remove_node` / `update_node`).
    nodes: Vec<NodeInfo>,
    /// Keyword-based detector for AI-style workloads.
    ai_detector: AIWorkloadDetector,
    /// Stage-aware router used by `route_rag`.
    rag_router: RAGRouter,
}
29
30impl SchemaAwareRouter {
31 pub fn new(config: SchemaRoutingConfig, schema: Arc<SchemaRegistry>) -> Self {
33 Self {
34 analyzer: QueryAnalyzer::new(schema.clone()),
35 schema,
36 config,
37 nodes: Vec::new(),
38 ai_detector: AIWorkloadDetector::new(),
39 rag_router: RAGRouter::new(),
40 }
41 }
42
43 pub fn add_node(&mut self, node: NodeInfo) {
45 self.nodes.push(node);
46 }
47
48 pub fn remove_node(&mut self, node_id: &str) {
50 self.nodes.retain(|n| n.id != node_id);
51 }
52
53 pub fn update_node(&mut self, node_id: &str, load: f64, latency_ms: u64) {
55 if let Some(node) = self.nodes.iter_mut().find(|n| n.id == node_id) {
56 node.current_load = load;
57 node.current_latency_ms = latency_ms;
58 }
59 }
60
61 pub fn route(&self, query: &str) -> RoutingDecision {
63 if !self.config.enabled {
64 return RoutingDecision::default_routing();
65 }
66
67 let analysis = self.analyzer.analyze(query);
68
69 if let Some(ai_workload) = self.ai_detector.detect(query) {
71 let preference = self.ai_detector.get_optimal_routing(ai_workload);
72 return self.apply_preference(preference, &analysis);
73 }
74
75 let required_caps = self.get_required_capabilities(&analysis);
77
78 let eligible = self.filter_by_capabilities(&required_caps);
80
81 if let Some(shard_routing) = self.try_shard_routing(&analysis) {
83 return shard_routing;
84 }
85
86 match analysis.workload_type {
88 WorkloadType::OLTP => self.route_oltp(&eligible, &analysis),
89 WorkloadType::OLAP => self.route_olap(&eligible, &analysis),
90 WorkloadType::Vector => self.route_vector(&eligible, &analysis),
91 WorkloadType::HTAP | WorkloadType::Mixed => self.route_mixed(&eligible, &analysis),
92 }
93 }
94
95 pub fn route_with_branch(&self, query: &str, branch: &str) -> RoutingDecision {
97 let analysis = self.analyzer.analyze(query);
98
99 let branch_nodes = self.schema.get_branch_locations(branch);
101
102 let required_caps = self.get_required_capabilities(&analysis);
104 let eligible = self.filter_by_capabilities(&required_caps);
105
106 let available: Vec<_> = eligible
108 .iter()
109 .filter(|n| branch_nodes.contains(&n.id))
110 .cloned()
111 .collect();
112
113 if available.is_empty() {
114 return RoutingDecision {
116 target: RouteTarget::Primary,
117 reason: RoutingReason::BranchNotAvailable,
118 branch: Some(branch.to_string()),
119 ..Default::default()
120 };
121 }
122
123 self.select_best(&available, &analysis)
124 }
125
126 pub fn route_time_travel(&self, query: &str, age_days: i64) -> RoutingDecision {
128 let analysis = self.analyzer.analyze(query);
129
130 if age_days < 7 {
132 return self.route_to_temperature_nodes(DataTemperature::Hot, &analysis);
133 }
134
135 if age_days < 30 {
137 return self.route_to_temperature_nodes(DataTemperature::Warm, &analysis);
138 }
139
140 self.route_to_temperature_nodes(DataTemperature::Cold, &analysis)
142 }
143
144 pub fn route_rag(&self, stage: RAGStage, query: &str) -> RoutingDecision {
146 let analysis = self.analyzer.analyze(query);
147 self.rag_router.route_rag_query(stage, &analysis, &self.nodes)
148 }
149
150 fn get_required_capabilities(&self, analysis: &QueryAnalysis) -> NodeCapabilities {
152 let mut caps = NodeCapabilities::default();
153
154 if analysis.access_patterns.contains(&AccessPattern::VectorSearch) {
156 caps.vector_search = true;
157 caps.gpu_acceleration = true; }
159
160 if analysis.workload_type == WorkloadType::OLAP {
162 caps.columnar_storage = true;
163 }
164
165 for table in &analysis.tables {
167 if let Some(schema) = &table.schema {
168 if schema.temperature == DataTemperature::Hot {
169 caps.in_memory = true;
170 }
171 }
172 }
173
174 caps
175 }
176
177 fn filter_by_capabilities(&self, required: &NodeCapabilities) -> Vec<NodeInfo> {
179 self.nodes
180 .iter()
181 .filter(|n| n.capabilities.satisfies(required) || !required.has_requirements())
182 .cloned()
183 .collect()
184 }
185
186 fn try_shard_routing(&self, analysis: &QueryAnalysis) -> Option<RoutingDecision> {
188 for table in &analysis.tables {
189 if let Some(schema) = &table.schema {
190 if let Some(shard_key) = &schema.shard_key {
191 if let Some(shard_value) = analysis.shard_keys.get(shard_key) {
192 let value = match shard_value {
193 super::analyzer::ShardKeyValue::Single(v) => v.clone(),
194 super::analyzer::ShardKeyValue::Multiple(v) => {
195 return Some(RoutingDecision {
197 target: RouteTarget::ScatterGather,
198 shards: v.iter().filter_map(|val| {
199 self.schema.get_shard(shard_key, val)
200 }).collect(),
201 reason: RoutingReason::ShardKey,
202 ..Default::default()
203 });
204 }
205 };
206
207 if let Some(shard) = self.schema.get_shard(shard_key, &value) {
208 return Some(RoutingDecision {
209 target: RouteTarget::Shard(shard),
210 reason: RoutingReason::ShardKey,
211 ..Default::default()
212 });
213 }
214 }
215 }
216 }
217 }
218 None
219 }
220
221 fn route_oltp(&self, nodes: &[NodeInfo], analysis: &QueryAnalysis) -> RoutingDecision {
223 if !analysis.is_read_only {
225 return RoutingDecision {
226 target: RouteTarget::Primary,
227 reason: RoutingReason::WriteQuery,
228 ..Default::default()
229 };
230 }
231
232 let mut preferred: Vec<_> = nodes
234 .iter()
235 .filter(|n| n.sync_mode == SyncMode::Sync || n.is_primary)
236 .cloned()
237 .collect();
238
239 preferred.sort_by_key(|n| n.current_latency_ms);
240
241 if let Some(node) = preferred.first() {
242 RoutingDecision {
243 target: RouteTarget::Node(node.id.clone()),
244 reason: RoutingReason::LowLatency,
245 node_info: Some(node.clone()),
246 ..Default::default()
247 }
248 } else {
249 RoutingDecision::default_routing()
250 }
251 }
252
253 fn route_olap(&self, nodes: &[NodeInfo], _analysis: &QueryAnalysis) -> RoutingDecision {
255 let mut preferred: Vec<_> = nodes
257 .iter()
258 .filter(|n| n.capabilities.columnar_storage)
259 .cloned()
260 .collect();
261
262 if preferred.is_empty() {
263 preferred = nodes
265 .iter()
266 .filter(|n| n.sync_mode == SyncMode::Async)
267 .cloned()
268 .collect();
269 }
270
271 preferred.sort_by(|a, b| a.current_load.partial_cmp(&b.current_load).unwrap());
272
273 if let Some(node) = preferred.first() {
274 RoutingDecision {
275 target: RouteTarget::Node(node.id.clone()),
276 reason: RoutingReason::ColumnarStorage,
277 node_info: Some(node.clone()),
278 ..Default::default()
279 }
280 } else {
281 RoutingDecision::default_routing()
282 }
283 }
284
285 fn route_vector(&self, nodes: &[NodeInfo], _analysis: &QueryAnalysis) -> RoutingDecision {
287 let mut vector_nodes: Vec<_> = nodes
289 .iter()
290 .filter(|n| n.capabilities.vector_search)
291 .cloned()
292 .collect();
293
294 vector_nodes.sort_by(|a, b| {
296 b.capabilities.gpu_acceleration
297 .cmp(&a.capabilities.gpu_acceleration)
298 .then_with(|| a.current_load.partial_cmp(&b.current_load).unwrap())
299 });
300
301 if let Some(node) = vector_nodes.first() {
302 RoutingDecision {
303 target: RouteTarget::Node(node.id.clone()),
304 reason: RoutingReason::VectorCapable,
305 node_info: Some(node.clone()),
306 ..Default::default()
307 }
308 } else {
309 RoutingDecision {
311 target: RouteTarget::Primary,
312 reason: RoutingReason::NoVectorNodes,
313 ..Default::default()
314 }
315 }
316 }
317
318 fn route_mixed(&self, nodes: &[NodeInfo], analysis: &QueryAnalysis) -> RoutingDecision {
320 if !analysis.is_read_only {
322 return RoutingDecision {
323 target: RouteTarget::Primary,
324 reason: RoutingReason::WriteQuery,
325 ..Default::default()
326 };
327 }
328
329 let mut scored: Vec<_> = nodes
331 .iter()
332 .map(|n| {
333 let score = (n.current_latency_ms as f64) + (n.current_load * 100.0);
334 (n, score)
335 })
336 .collect();
337
338 scored.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
339
340 if let Some((node, _)) = scored.first() {
341 RoutingDecision {
342 target: RouteTarget::Node(node.id.clone()),
343 reason: RoutingReason::LowestScore,
344 node_info: Some((*node).clone()),
345 ..Default::default()
346 }
347 } else {
348 RoutingDecision::default_routing()
349 }
350 }
351
352 fn route_to_temperature_nodes(&self, temp: DataTemperature, analysis: &QueryAnalysis) -> RoutingDecision {
354 let matching_nodes: Vec<_> = self.nodes
356 .iter()
357 .filter(|n| {
358 match temp {
359 DataTemperature::Hot => n.capabilities.in_memory,
360 DataTemperature::Warm => !n.capabilities.in_memory && !self.is_cold_storage(n),
361 DataTemperature::Cold | DataTemperature::Frozen => self.is_cold_storage(n),
362 }
363 })
364 .cloned()
365 .collect();
366
367 if matching_nodes.is_empty() {
368 return self.route_mixed(&self.nodes, analysis);
369 }
370
371 self.select_best(&matching_nodes, analysis)
372 }
373
374 fn is_cold_storage(&self, node: &NodeInfo) -> bool {
376 !node.capabilities.in_memory && node.sync_mode == SyncMode::Async && !node.is_primary
379 }
380
381 fn select_best(&self, nodes: &[NodeInfo], analysis: &QueryAnalysis) -> RoutingDecision {
383 if nodes.is_empty() {
384 return RoutingDecision::default_routing();
385 }
386
387 let mut sorted = nodes.to_vec();
389 if analysis.workload_type == WorkloadType::OLAP {
390 sorted.sort_by(|a, b| a.current_load.partial_cmp(&b.current_load).unwrap());
391 } else {
392 sorted.sort_by_key(|n| n.current_latency_ms);
393 }
394
395 let node = &sorted[0];
396 RoutingDecision {
397 target: RouteTarget::Node(node.id.clone()),
398 reason: RoutingReason::BestCandidate,
399 node_info: Some(node.clone()),
400 ..Default::default()
401 }
402 }
403
404 fn apply_preference(&self, preference: RoutingPreference, analysis: &QueryAnalysis) -> RoutingDecision {
406 match preference {
407 RoutingPreference::VectorNodes { prefer_gpu } => {
408 let nodes: Vec<_> = self.nodes
409 .iter()
410 .filter(|n| n.capabilities.vector_search)
411 .filter(|n| !prefer_gpu || n.capabilities.gpu_acceleration)
412 .cloned()
413 .collect();
414 self.select_best(&nodes, analysis)
415 }
416 RoutingPreference::LowLatency { max_lag_ms } => {
417 let nodes: Vec<_> = self.nodes
418 .iter()
419 .filter(|n| n.current_latency_ms <= max_lag_ms)
420 .cloned()
421 .collect();
422 self.select_best(&nodes, analysis)
423 }
424 RoutingPreference::HighThroughput => {
425 let nodes: Vec<_> = self.nodes
426 .iter()
427 .filter(|n| n.sync_mode == SyncMode::Async)
428 .cloned()
429 .collect();
430 self.select_best(&nodes, analysis)
431 }
432 RoutingPreference::Primary => {
433 RoutingDecision {
434 target: RouteTarget::Primary,
435 reason: RoutingReason::AIWorkload,
436 ..Default::default()
437 }
438 }
439 }
440 }
441}
442
443impl NodeCapabilities {
444 fn has_requirements(&self) -> bool {
446 self.vector_search || self.gpu_acceleration || self.columnar_storage
447 || self.in_memory || self.content_addressed
448 }
449}
450
/// Outcome of a routing request: where the query should go and why.
///
/// The derived `Default` targets the primary with `RoutingReason::Default`,
/// no shards, no branch and no node details.
#[derive(Debug, Clone, Default)]
pub struct RoutingDecision {
    /// Where the query should be sent.
    pub target: RouteTarget,
    /// Why this target was chosen.
    pub reason: RoutingReason,
    /// Shard ids to fan out to; filled in for scatter-gather decisions.
    pub shards: Vec<u32>,
    /// Branch the request was for; set when branch routing falls back to the primary.
    pub branch: Option<String>,
    /// Copy of the selected node, when a specific node was chosen.
    pub node_info: Option<NodeInfo>,
}
465
466impl RoutingDecision {
467 pub fn shard(shard_id: u32) -> Self {
469 Self {
470 target: RouteTarget::Shard(shard_id),
471 reason: RoutingReason::ShardKey,
472 ..Default::default()
473 }
474 }
475
476 pub fn single(node: NodeInfo) -> Self {
478 Self {
479 target: RouteTarget::Node(node.id.clone()),
480 reason: RoutingReason::BestCandidate,
481 node_info: Some(node),
482 ..Default::default()
483 }
484 }
485
486 pub fn default_routing() -> Self {
488 Self {
489 target: RouteTarget::Primary,
490 reason: RoutingReason::Default,
491 ..Default::default()
492 }
493 }
494
495 pub fn is_primary(&self) -> bool {
497 matches!(self.target, RouteTarget::Primary)
498 }
499
500 pub fn is_scatter_gather(&self) -> bool {
502 matches!(self.target, RouteTarget::ScatterGather)
503 }
504}
505
/// Destination selected for a query.
///
/// Implements `PartialEq` / `Eq` so targets can be compared directly and used
/// in `assert_eq!`.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub enum RouteTarget {
    /// The primary node; this is the default target.
    #[default]
    Primary,
    /// A specific node, identified by its id.
    Node(String),
    /// A single shard, identified by its shard id.
    Shard(u32),
    /// Fan-out across several shards; the shard ids travel in
    /// `RoutingDecision::shards`.
    ScatterGather,
}
519
/// Why a particular routing target was chosen.
///
/// Implements `PartialEq` / `Eq` so reasons can be compared directly and used
/// in `assert_eq!`.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub enum RoutingReason {
    /// Fallback routing; no more specific rule applied.
    #[default]
    Default,
    /// The query is not read-only and was sent to the primary.
    WriteQuery,
    /// The query pins a shard key.
    ShardKey,
    /// Lowest-latency node chosen for an OLTP read.
    LowLatency,
    /// Node chosen by the OLAP strategy.
    ColumnarStorage,
    /// Vector-capable node chosen for a vector query.
    VectorCapable,
    /// No vector-capable node was eligible; the primary is used instead.
    NoVectorNodes,
    /// No capable node hosts the requested branch; the primary is used instead.
    BranchNotAvailable,
    /// Best node among a pre-filtered candidate set.
    BestCandidate,
    /// Node with the lowest combined latency/load score.
    LowestScore,
    /// An AI workload preference asked for the primary.
    AIWorkload,
}
547
/// Routing preference derived from a detected AI workload.
///
/// Implements `PartialEq` / `Eq` so preferences can be compared directly and
/// used in `assert_eq!`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum RoutingPreference {
    /// Use vector-search-capable nodes, optionally favouring GPU nodes.
    VectorNodes { prefer_gpu: bool },
    /// Use nodes whose current latency is at most `max_lag_ms` milliseconds.
    LowLatency { max_lag_ms: u64 },
    /// Use asynchronously replicated nodes.
    HighThroughput,
    /// Use the primary.
    Primary,
}
560
/// Kind of AI workload recognised from keywords in the query text.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AIWorkloadType {
    /// Embedding / vector similarity lookups; routed to vector nodes, preferring GPU.
    EmbeddingRetrieval,
    /// Conversation context lookups; routed with a low-latency preference.
    ContextLookup,
    /// Document / chunk access; routed with a high-throughput preference.
    KnowledgeBase,
    /// Tool results and actions; routed to the primary.
    ToolExecution,
}
573
/// Detects AI workloads by looking for known keywords in the query text.
///
/// Note that the derived `Default` yields a detector with no patterns, which
/// never detects anything; `AIWorkloadDetector::new` installs the built-in
/// keyword list.
#[derive(Debug, Default)]
pub struct AIWorkloadDetector {
    /// Keyword patterns, checked in order; the first match wins.
    patterns: Vec<AIPattern>,
}
580
/// One keyword-to-workload mapping used by `AIWorkloadDetector`.
#[derive(Debug)]
struct AIPattern {
    /// Text searched for in the upper-cased query.
    keyword: String,
    /// Workload reported when `keyword` is found.
    workload_type: AIWorkloadType,
}
586
587impl AIWorkloadDetector {
588 pub fn new() -> Self {
590 Self {
591 patterns: vec![
592 AIPattern { keyword: "<->".to_string(), workload_type: AIWorkloadType::EmbeddingRetrieval },
593 AIPattern { keyword: "VECTOR".to_string(), workload_type: AIWorkloadType::EmbeddingRetrieval },
594 AIPattern { keyword: "EMBEDDING".to_string(), workload_type: AIWorkloadType::EmbeddingRetrieval },
595 AIPattern { keyword: "CONVERSATION".to_string(), workload_type: AIWorkloadType::ContextLookup },
596 AIPattern { keyword: "TURNS".to_string(), workload_type: AIWorkloadType::ContextLookup },
597 AIPattern { keyword: "DOCUMENTS".to_string(), workload_type: AIWorkloadType::KnowledgeBase },
598 AIPattern { keyword: "CHUNKS".to_string(), workload_type: AIWorkloadType::KnowledgeBase },
599 AIPattern { keyword: "TOOL_RESULTS".to_string(), workload_type: AIWorkloadType::ToolExecution },
600 AIPattern { keyword: "ACTIONS".to_string(), workload_type: AIWorkloadType::ToolExecution },
601 ],
602 }
603 }
604
605 pub fn detect(&self, query: &str) -> Option<AIWorkloadType> {
607 let upper = query.to_uppercase();
608
609 for pattern in &self.patterns {
610 if upper.contains(&pattern.keyword) {
611 return Some(pattern.workload_type);
612 }
613 }
614
615 None
616 }
617
618 pub fn get_optimal_routing(&self, workload: AIWorkloadType) -> RoutingPreference {
620 match workload {
621 AIWorkloadType::EmbeddingRetrieval => {
622 RoutingPreference::VectorNodes { prefer_gpu: true }
623 }
624 AIWorkloadType::ContextLookup => {
625 RoutingPreference::LowLatency { max_lag_ms: 100 }
626 }
627 AIWorkloadType::KnowledgeBase => {
628 RoutingPreference::HighThroughput
629 }
630 AIWorkloadType::ToolExecution => {
631 RoutingPreference::Primary
632 }
633 }
634 }
635}
636
/// Stage of a RAG pipeline, used to pick a routing strategy per stage.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RAGStage {
    /// Routed to a vector-search-capable node.
    Retrieval,
    /// Routed to an asynchronously replicated node.
    Fetch,
    /// Routed to the lowest-latency node.
    Rerank,
    /// Routed to the lowest-latency node for read-only queries, otherwise to
    /// the default (primary) route.
    Generate,
}
649
/// Stateless router that picks a node for each stage of a RAG pipeline.
#[derive(Debug, Default)]
pub struct RAGRouter {}
653
654impl RAGRouter {
655 pub fn new() -> Self {
657 Self {}
658 }
659
660 pub fn route_rag_query(&self, stage: RAGStage, analysis: &QueryAnalysis, nodes: &[NodeInfo]) -> RoutingDecision {
662 match stage {
663 RAGStage::Retrieval => {
664 let vector_nodes: Vec<_> = nodes
666 .iter()
667 .filter(|n| n.capabilities.vector_search)
668 .cloned()
669 .collect();
670
671 if let Some(node) = vector_nodes.first() {
672 RoutingDecision::single(node.clone())
673 } else {
674 RoutingDecision::default_routing()
675 }
676 }
677 RAGStage::Fetch => {
678 let throughput_nodes: Vec<_> = nodes
680 .iter()
681 .filter(|n| n.sync_mode == SyncMode::Async)
682 .cloned()
683 .collect();
684
685 if let Some(node) = throughput_nodes.first() {
686 RoutingDecision::single(node.clone())
687 } else {
688 RoutingDecision::default_routing()
689 }
690 }
691 RAGStage::Rerank => {
692 let mut sorted = nodes.to_vec();
694 sorted.sort_by_key(|n| n.current_latency_ms);
695
696 if let Some(node) = sorted.first() {
697 RoutingDecision::single(node.clone())
698 } else {
699 RoutingDecision::default_routing()
700 }
701 }
702 RAGStage::Generate => {
703 if !analysis.is_read_only {
705 RoutingDecision::default_routing()
706 } else {
707 let mut sorted = nodes.to_vec();
708 sorted.sort_by_key(|n| n.current_latency_ms);
709
710 if let Some(node) = sorted.first() {
711 RoutingDecision::single(node.clone())
712 } else {
713 RoutingDecision::default_routing()
714 }
715 }
716 }
717 }
718 }
719}
720
#[cfg(test)]
mod tests {
    use super::*;
    use crate::schema_routing::registry::TableSchema;

    /// Builds a router over three tables (`users`, `events`, `embeddings`)
    /// and four nodes: a primary, a sync standby, an async analytics standby
    /// and an async vector node.
    fn create_test_setup() -> SchemaAwareRouter {
        let registry = Arc::new(SchemaRegistry::new());

        let users = TableSchema::new("users")
            .with_workload(WorkloadType::OLTP)
            .with_access_pattern(AccessPattern::PointLookup)
            .with_primary_key(vec!["id".to_string()]);
        registry.register_table(users);

        let events = TableSchema::new("events")
            .with_workload(WorkloadType::OLAP)
            .with_temperature(DataTemperature::Cold);
        registry.register_table(events);

        let embeddings = TableSchema::new("embeddings").with_workload(WorkloadType::Vector);
        registry.register_table(embeddings);

        let mut router = SchemaAwareRouter::new(SchemaRoutingConfig::default(), registry);

        let cluster = vec![
            NodeInfo::new("primary", "primary").as_primary(),
            NodeInfo::new("standby-sync", "standby-sync").with_sync_mode(SyncMode::Sync),
            NodeInfo::new("standby-async", "standby-async")
                .with_sync_mode(SyncMode::Async)
                .with_capabilities(NodeCapabilities::analytics_node()),
            NodeInfo::new("vector-node", "vector-node")
                .with_sync_mode(SyncMode::Async)
                .with_capabilities(NodeCapabilities::vector_node()),
        ];
        for node in cluster {
            router.add_node(node);
        }

        router
    }

    #[test]
    fn test_route_oltp_read() {
        let decision = create_test_setup().route("SELECT * FROM users WHERE id = 1");

        let low_latency = matches!(decision.reason, RoutingReason::LowLatency);
        assert!(!decision.is_primary() || low_latency);
    }

    #[test]
    fn test_route_write_to_primary() {
        let decision = create_test_setup().route("INSERT INTO users (name) VALUES ('test')");

        assert!(decision.is_primary());
        assert!(matches!(decision.reason, RoutingReason::WriteQuery));
    }

    #[test]
    fn test_route_vector_query() {
        let sql = "SELECT * FROM embeddings ORDER BY embedding <-> '[1,2,3]' LIMIT 10";
        let decision = create_test_setup().route(sql);

        let expected_reason = matches!(
            decision.reason,
            RoutingReason::VectorCapable | RoutingReason::BestCandidate
        );
        assert!(expected_reason || decision.is_primary());
    }

    #[test]
    fn test_route_olap_query() {
        let sql = "SELECT COUNT(*), SUM(amount) FROM events GROUP BY date";
        let decision = create_test_setup().route(sql);

        let expected_reason = matches!(
            decision.reason,
            RoutingReason::ColumnarStorage | RoutingReason::Default
        );
        assert!(!decision.is_primary() || expected_reason);
    }

    #[test]
    fn test_ai_workload_detection() {
        let detector = AIWorkloadDetector::new();

        assert_eq!(
            detector.detect("SELECT * FROM embeddings ORDER BY vector <-> $1"),
            Some(AIWorkloadType::EmbeddingRetrieval)
        );
        assert_eq!(
            detector.detect("SELECT * FROM conversation WHERE session_id = $1"),
            Some(AIWorkloadType::ContextLookup)
        );
        assert_eq!(
            detector.detect("INSERT INTO tool_results (result) VALUES ($1)"),
            Some(AIWorkloadType::ToolExecution)
        );
    }

    #[test]
    fn test_rag_routing() {
        let router = create_test_setup();

        let retrieval = router.route_rag(RAGStage::Retrieval, "SELECT embedding FROM docs");
        assert!(retrieval.node_info.is_some() || retrieval.is_primary());

        let fetch = router.route_rag(
            RAGStage::Fetch,
            "SELECT content FROM docs WHERE id IN (1,2,3)",
        );
        assert!(fetch.node_info.is_some() || fetch.is_primary());
    }

    #[test]
    fn test_routing_decision_helpers() {
        let shard_decision = RoutingDecision::shard(3);
        assert!(matches!(shard_decision.target, RouteTarget::Shard(3)));

        let fallback = RoutingDecision::default_routing();
        assert!(fallback.is_primary());
    }
}