use std::sync::Arc;
use super::{
analyzer::{QueryAnalysis, QueryAnalyzer},
registry::{AccessPattern, DataTemperature, NodeCapabilities, SchemaRegistry, WorkloadType},
NodeInfo, SchemaRoutingConfig, SyncMode,
};
/// Query router that chooses a target (primary, node, shard or scatter-gather)
/// from the analyzed query, the schema registry and the registered nodes.
#[derive(Debug)]
pub struct SchemaAwareRouter {
/// Routing configuration; `route` returns the default decision when `enabled` is false.
config: SchemaRoutingConfig,
/// Shared schema registry, used for branch locations and shard lookups.
schema: Arc<SchemaRegistry>,
/// Analyzer that turns a query string into a `QueryAnalysis`.
analyzer: QueryAnalyzer,
/// Nodes currently known to the router, in registration order.
nodes: Vec<NodeInfo>,
/// Keyword-based detector for AI-style workloads; consulted first by `route`.
ai_detector: AIWorkloadDetector,
/// Stage-aware router used by `route_rag`.
rag_router: RAGRouter,
}
impl SchemaAwareRouter {
/// Creates a router with no registered nodes.
///
/// The query analyzer is built from a clone of the `schema` handle, so the
/// analyzer and the router share the same registry.
pub fn new(config: SchemaRoutingConfig, schema: Arc<SchemaRegistry>) -> Self {
    let analyzer = QueryAnalyzer::new(Arc::clone(&schema));
    Self {
        config,
        schema,
        analyzer,
        nodes: Vec::new(),
        ai_detector: AIWorkloadDetector::new(),
        rag_router: RAGRouter::new(),
    }
}
/// Registers `node` as a routing candidate.
///
/// The node is appended without checking whether another node with the same
/// id is already registered.
pub fn add_node(&mut self, node: NodeInfo) {
self.nodes.push(node);
}
/// Removes every registered node whose id equals `node_id`.
///
/// Does nothing when no node matches.
pub fn remove_node(&mut self, node_id: &str) {
    self.nodes.retain(|existing| existing.id != node_id);
}
/// Records fresh load and latency figures for the first node whose id equals
/// `node_id`.
///
/// Unknown ids are ignored silently.
pub fn update_node(&mut self, node_id: &str, load: f64, latency_ms: u64) {
    let target = self
        .nodes
        .iter_mut()
        .find(|candidate| candidate.id == node_id);
    if let Some(entry) = target {
        entry.current_load = load;
        entry.current_latency_ms = latency_ms;
    }
}
/// Routes `query` and returns the chosen target together with the reason.
///
/// Decision order:
/// 1. If routing is disabled in the config, return the default decision.
/// 2. If the AI workload detector recognizes the query, apply its preference.
/// 3. If a shard key value can be resolved, route to that shard (or scatter-gather).
/// 4. Otherwise route by the analyzed workload type among capability-eligible nodes.
pub fn route(&self, query: &str) -> RoutingDecision {
    if !self.config.enabled {
        return RoutingDecision::default_routing();
    }

    let analysis = self.analyzer.analyze(query);

    if let Some(ai_workload) = self.ai_detector.detect(query) {
        let preference = self.ai_detector.get_optimal_routing(ai_workload);
        return self.apply_preference(preference, &analysis);
    }

    // Shard routing does not look at node capabilities, so try it before
    // building the eligible-node list (which clones nodes).
    if let Some(shard_routing) = self.try_shard_routing(&analysis) {
        return shard_routing;
    }

    let required_caps = self.get_required_capabilities(&analysis);
    let eligible = self.filter_by_capabilities(&required_caps);

    match analysis.workload_type {
        WorkloadType::OLTP => self.route_oltp(&eligible, &analysis),
        WorkloadType::OLAP => self.route_olap(&eligible, &analysis),
        WorkloadType::Vector => self.route_vector(&eligible, &analysis),
        WorkloadType::HTAP | WorkloadType::Mixed => self.route_mixed(&eligible, &analysis),
    }
}
pub fn route_with_branch(&self, query: &str, branch: &str) -> RoutingDecision {
let analysis = self.analyzer.analyze(query);
let branch_nodes = self.schema.get_branch_locations(branch);
let required_caps = self.get_required_capabilities(&analysis);
let eligible = self.filter_by_capabilities(&required_caps);
let available: Vec<_> = eligible
.iter()
.filter(|n| branch_nodes.contains(&n.id))
.cloned()
.collect();
if available.is_empty() {
return RoutingDecision {
target: RouteTarget::Primary,
reason: RoutingReason::BranchNotAvailable,
branch: Some(branch.to_string()),
..Default::default()
};
}
self.select_best(&available, &analysis)
}
/// Routes a historical query according to how old the requested data is.
///
/// Data younger than 7 days is treated as hot, younger than 30 days as warm,
/// and anything older as cold; the query is then routed to nodes matching
/// that temperature.
pub fn route_time_travel(&self, query: &str, age_days: i64) -> RoutingDecision {
    let analysis = self.analyzer.analyze(query);
    let temperature = if age_days < 7 {
        DataTemperature::Hot
    } else if age_days < 30 {
        DataTemperature::Warm
    } else {
        DataTemperature::Cold
    };
    self.route_to_temperature_nodes(temperature, &analysis)
}
/// Routes one stage of a RAG pipeline by delegating to the RAG router with
/// the analyzed query and all registered nodes.
pub fn route_rag(&self, stage: RAGStage, query: &str) -> RoutingDecision {
    let analysis = self.analyzer.analyze(query);
    let candidates = &self.nodes;
    self.rag_router.route_rag_query(stage, &analysis, candidates)
}
/// Derives the node capabilities a query needs from its analysis.
///
/// Starting from the default capability set:
/// - a vector-search access pattern requires vector search and GPU acceleration,
/// - an OLAP workload requires columnar storage,
/// - touching any table whose schema is marked hot requires in-memory storage.
///
/// NOTE(review): GPU acceleration is set as a requirement for every vector
/// search. Whether that excludes vector nodes without a GPU depends on
/// `NodeCapabilities::satisfies` - confirm this is intended.
fn get_required_capabilities(&self, analysis: &QueryAnalysis) -> NodeCapabilities {
    let mut required = NodeCapabilities::default();

    let wants_vector_search = analysis
        .access_patterns
        .contains(&AccessPattern::VectorSearch);
    if wants_vector_search {
        required.vector_search = true;
        required.gpu_acceleration = true;
    }

    if analysis.workload_type == WorkloadType::OLAP {
        required.columnar_storage = true;
    }

    let touches_hot_table = analysis.tables.iter().any(|table| {
        matches!(&table.schema, Some(schema) if schema.temperature == DataTemperature::Hot)
    });
    if touches_hot_table {
        required.in_memory = true;
    }

    required
}
/// Returns clones of the registered nodes that satisfy `required`.
///
/// When `required` asks for nothing, every registered node is returned.
fn filter_by_capabilities(&self, required: &NodeCapabilities) -> Vec<NodeInfo> {
    // The "no requirements" answer is the same for every node, so decide it once.
    if !required.has_requirements() {
        return self.nodes.clone();
    }

    self.nodes
        .iter()
        .filter(|node| node.capabilities.satisfies(required))
        .cloned()
        .collect()
}
/// Attempts to route by shard key.
///
/// For each table in the analysis that declares a shard key and for which
/// the query supplies a value for that key:
/// - a single value that resolves to a shard yields a `Shard` decision;
/// - multiple values yield a `ScatterGather` decision listing every shard
///   that resolved, each shard at most once and in first-seen order.
///
/// Returns `None` when no table produces a decision. A single value that
/// does not resolve to a shard is skipped and the next table is tried.
fn try_shard_routing(&self, analysis: &QueryAnalysis) -> Option<RoutingDecision> {
    for table in &analysis.tables {
        if let Some(schema) = &table.schema {
            if let Some(shard_key) = &schema.shard_key {
                match analysis.shard_keys.get(shard_key) {
                    Some(super::analyzer::ShardKeyValue::Single(value)) => {
                        if let Some(shard) = self.schema.get_shard(shard_key, value) {
                            return Some(RoutingDecision {
                                target: RouteTarget::Shard(shard),
                                reason: RoutingReason::ShardKey,
                                ..Default::default()
                            });
                        }
                    }
                    Some(super::analyzer::ShardKeyValue::Multiple(values)) => {
                        // Several values can live on the same shard; listing a
                        // shard twice would make a scatter-gather visit it twice.
                        let mut shards: Vec<u32> = Vec::new();
                        let resolved = values
                            .iter()
                            .filter_map(|value| self.schema.get_shard(shard_key, value));
                        for shard in resolved {
                            if !shards.contains(&shard) {
                                shards.push(shard);
                            }
                        }
                        return Some(RoutingDecision {
                            target: RouteTarget::ScatterGather,
                            shards,
                            reason: RoutingReason::ShardKey,
                            ..Default::default()
                        });
                    }
                    None => {}
                }
            }
        }
    }
    None
}
/// Routes an OLTP query.
///
/// Queries that are not read-only go to the primary. Reads go to the
/// lowest-latency node among the primary and synchronous nodes (the first
/// such node wins ties); with no such node the default decision is returned.
fn route_oltp(&self, nodes: &[NodeInfo], analysis: &QueryAnalysis) -> RoutingDecision {
    if !analysis.is_read_only {
        return RoutingDecision {
            target: RouteTarget::Primary,
            reason: RoutingReason::WriteQuery,
            ..Default::default()
        };
    }

    let fastest = nodes
        .iter()
        .filter(|candidate| candidate.is_primary || candidate.sync_mode == SyncMode::Sync)
        .min_by_key(|candidate| candidate.current_latency_ms);

    match fastest {
        Some(node) => RoutingDecision {
            target: RouteTarget::Node(node.id.clone()),
            reason: RoutingReason::LowLatency,
            node_info: Some(node.clone()),
            ..Default::default()
        },
        None => RoutingDecision::default_routing(),
    }
}
/// Routes an OLAP query to the least-loaded columnar node.
///
/// If no node offers columnar storage, the least-loaded asynchronous node is
/// used instead; if there is none either, the default decision is returned.
/// Ties go to the first node in `nodes`.
///
/// Loads are compared with `f64::total_cmp`, so a NaN load cannot panic the
/// router; a (positive) NaN load orders after every finite load.
fn route_olap(&self, nodes: &[NodeInfo], _analysis: &QueryAnalysis) -> RoutingDecision {
    let least_loaded = nodes
        .iter()
        .filter(|n| n.capabilities.columnar_storage)
        .min_by(|a, b| a.current_load.total_cmp(&b.current_load))
        .or_else(|| {
            nodes
                .iter()
                .filter(|n| n.sync_mode == SyncMode::Async)
                .min_by(|a, b| a.current_load.total_cmp(&b.current_load))
        });

    match least_loaded {
        Some(node) => RoutingDecision {
            target: RouteTarget::Node(node.id.clone()),
            reason: RoutingReason::ColumnarStorage,
            node_info: Some(node.clone()),
            ..Default::default()
        },
        None => RoutingDecision::default_routing(),
    }
}
/// Routes a vector query to a vector-search-capable node.
///
/// GPU-accelerated nodes are preferred over non-GPU nodes; within the same
/// GPU class the least-loaded node wins, and ties go to the first node in
/// `nodes`. With no vector-capable node the decision targets the primary
/// with reason `NoVectorNodes`.
///
/// Loads are compared with `f64::total_cmp`, so a NaN load cannot panic the
/// router.
fn route_vector(&self, nodes: &[NodeInfo], _analysis: &QueryAnalysis) -> RoutingDecision {
    let best = nodes
        .iter()
        .filter(|n| n.capabilities.vector_search)
        .min_by(|a, b| {
            // `b` vs `a` on the GPU flag puts GPU nodes (true) first.
            b.capabilities
                .gpu_acceleration
                .cmp(&a.capabilities.gpu_acceleration)
                .then_with(|| a.current_load.total_cmp(&b.current_load))
        });

    match best {
        Some(node) => RoutingDecision {
            target: RouteTarget::Node(node.id.clone()),
            reason: RoutingReason::VectorCapable,
            node_info: Some(node.clone()),
            ..Default::default()
        },
        None => RoutingDecision {
            target: RouteTarget::Primary,
            reason: RoutingReason::NoVectorNodes,
            ..Default::default()
        },
    }
}
/// Routes an HTAP/mixed query.
///
/// Queries that are not read-only go to the primary. Reads go to the node
/// with the lowest score, where `score = latency_ms + load * 100`; ties go to
/// the first node in `nodes`. With no nodes the default decision is returned.
///
/// Scores are compared with `f64::total_cmp`, so a NaN load (and therefore a
/// NaN score) cannot panic the router.
fn route_mixed(&self, nodes: &[NodeInfo], analysis: &QueryAnalysis) -> RoutingDecision {
    if !analysis.is_read_only {
        return RoutingDecision {
            target: RouteTarget::Primary,
            reason: RoutingReason::WriteQuery,
            ..Default::default()
        };
    }

    let best = nodes
        .iter()
        .map(|n| {
            let score = (n.current_latency_ms as f64) + (n.current_load * 100.0);
            (n, score)
        })
        .min_by(|a, b| a.1.total_cmp(&b.1));

    match best {
        Some((node, _)) => RoutingDecision {
            target: RouteTarget::Node(node.id.clone()),
            reason: RoutingReason::LowestScore,
            node_info: Some(node.clone()),
            ..Default::default()
        },
        None => RoutingDecision::default_routing(),
    }
}
fn route_to_temperature_nodes(
&self,
temp: DataTemperature,
analysis: &QueryAnalysis,
) -> RoutingDecision {
let matching_nodes: Vec<_> = self
.nodes
.iter()
.filter(|n| match temp {
DataTemperature::Hot => n.capabilities.in_memory,
DataTemperature::Warm => !n.capabilities.in_memory && !self.is_cold_storage(n),
DataTemperature::Cold | DataTemperature::Frozen => self.is_cold_storage(n),
})
.cloned()
.collect();
if matching_nodes.is_empty() {
return self.route_mixed(&self.nodes, analysis);
}
self.select_best(&matching_nodes, analysis)
}
/// Reports whether `node` counts as cold storage: an asynchronous node that
/// is neither in-memory nor the primary.
fn is_cold_storage(&self, node: &NodeInfo) -> bool {
    let hot_or_primary = node.capabilities.in_memory || node.is_primary;
    !hot_or_primary && node.sync_mode == SyncMode::Async
}
/// Picks the best node from `nodes` for the analyzed query.
///
/// OLAP queries pick the least-loaded node; everything else picks the
/// lowest-latency node. Ties go to the first node in `nodes`. An empty slice
/// yields the default decision.
///
/// Loads are compared with `f64::total_cmp`, so a NaN load cannot panic the
/// router.
fn select_best(&self, nodes: &[NodeInfo], analysis: &QueryAnalysis) -> RoutingDecision {
    let best = if analysis.workload_type == WorkloadType::OLAP {
        nodes
            .iter()
            .min_by(|a, b| a.current_load.total_cmp(&b.current_load))
    } else {
        nodes.iter().min_by_key(|n| n.current_latency_ms)
    };

    match best {
        Some(node) => RoutingDecision {
            target: RouteTarget::Node(node.id.clone()),
            reason: RoutingReason::BestCandidate,
            node_info: Some(node.clone()),
            ..Default::default()
        },
        None => RoutingDecision::default_routing(),
    }
}
fn apply_preference(
&self,
preference: RoutingPreference,
analysis: &QueryAnalysis,
) -> RoutingDecision {
match preference {
RoutingPreference::VectorNodes { prefer_gpu } => {
let nodes: Vec<_> = self
.nodes
.iter()
.filter(|n| n.capabilities.vector_search)
.filter(|n| !prefer_gpu || n.capabilities.gpu_acceleration)
.cloned()
.collect();
self.select_best(&nodes, analysis)
}
RoutingPreference::LowLatency { max_lag_ms } => {
let nodes: Vec<_> = self
.nodes
.iter()
.filter(|n| n.current_latency_ms <= max_lag_ms)
.cloned()
.collect();
self.select_best(&nodes, analysis)
}
RoutingPreference::HighThroughput => {
let nodes: Vec<_> = self
.nodes
.iter()
.filter(|n| n.sync_mode == SyncMode::Async)
.cloned()
.collect();
self.select_best(&nodes, analysis)
}
RoutingPreference::Primary => RoutingDecision {
target: RouteTarget::Primary,
reason: RoutingReason::AIWorkload,
..Default::default()
},
}
}
}
impl NodeCapabilities {
    /// Reports whether at least one capability flag is set, i.e. whether this
    /// value expresses any requirement at all.
    fn has_requirements(&self) -> bool {
        let flags = [
            self.vector_search,
            self.gpu_acceleration,
            self.columnar_storage,
            self.in_memory,
            self.content_addressed,
        ];
        flags.contains(&true)
    }
}
/// Outcome of a routing request: where to send the query and why.
#[derive(Debug, Clone, Default)]
pub struct RoutingDecision {
/// Where the query should be sent.
pub target: RouteTarget,
/// Why this target was chosen.
pub reason: RoutingReason,
/// Shard ids to query; filled in for scatter-gather decisions, empty otherwise.
pub shards: Vec<u32>,
/// Branch the decision relates to, when the request was branch-specific.
pub branch: Option<String>,
/// Snapshot of the chosen node, when a specific node was selected.
pub node_info: Option<NodeInfo>,
}
impl RoutingDecision {
    /// Builds a decision that targets the shard `shard_id` with reason `ShardKey`.
    pub fn shard(shard_id: u32) -> Self {
        let target = RouteTarget::Shard(shard_id);
        Self {
            target,
            reason: RoutingReason::ShardKey,
            ..Self::default()
        }
    }

    /// Builds a decision that targets `node` with reason `BestCandidate`,
    /// keeping the node itself in `node_info`.
    pub fn single(node: NodeInfo) -> Self {
        let target = RouteTarget::Node(node.id.clone());
        Self {
            target,
            reason: RoutingReason::BestCandidate,
            node_info: Some(node),
            ..Self::default()
        }
    }

    /// Builds the fallback decision: the primary, with reason `Default`.
    pub fn default_routing() -> Self {
        Self {
            target: RouteTarget::Primary,
            reason: RoutingReason::Default,
            ..Self::default()
        }
    }

    /// Returns `true` when the decision targets the primary.
    pub fn is_primary(&self) -> bool {
        matches!(&self.target, RouteTarget::Primary)
    }

    /// Returns `true` when the decision is a scatter-gather over shards.
    pub fn is_scatter_gather(&self) -> bool {
        matches!(&self.target, RouteTarget::ScatterGather)
    }
}
/// Destination selected for a query.
///
/// Implements `PartialEq`/`Eq` so targets can be compared with `==` and
/// `assert_eq!` instead of pattern matching only.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub enum RouteTarget {
    /// The primary node (also the default target).
    #[default]
    Primary,
    /// A specific node, identified by its id.
    Node(String),
    /// A single shard, identified by its shard id.
    Shard(u32),
    /// Several shards at once; the shard ids are carried by the decision.
    ScatterGather,
}
/// Explanation attached to a routing decision.
///
/// Implements `PartialEq`/`Eq` so reasons can be compared with `==` and
/// `assert_eq!` instead of pattern matching only.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub enum RoutingReason {
    /// Fallback routing; no more specific rule applied.
    #[default]
    Default,
    /// The query is not read-only.
    WriteQuery,
    /// A shard key value in the query determined the shard(s).
    ShardKey,
    /// Chosen for lowest latency among primary/synchronous nodes.
    LowLatency,
    /// Chosen by the OLAP strategy.
    ColumnarStorage,
    /// Chosen because the node supports vector search.
    VectorCapable,
    /// A vector query found no vector-capable node.
    NoVectorNodes,
    /// No eligible node holds the requested branch.
    BranchNotAvailable,
    /// Best node among an already-filtered candidate set.
    BestCandidate,
    /// Lowest combined latency/load score.
    LowestScore,
    /// The AI workload detector asked for the primary.
    AIWorkload,
}
/// Routing preference produced for a detected AI workload.
///
/// Implements `PartialEq`/`Eq` so preferences can be compared with `==` and
/// `assert_eq!`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum RoutingPreference {
    /// Use vector-search-capable nodes, optionally favoring GPU nodes.
    VectorNodes { prefer_gpu: bool },
    /// Use nodes whose current latency does not exceed `max_lag_ms`.
    LowLatency { max_lag_ms: u64 },
    /// Use asynchronous nodes.
    HighThroughput,
    /// Use the primary.
    Primary,
}
/// Kind of AI workload recognized by `AIWorkloadDetector`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AIWorkloadType {
/// Embedding/vector similarity lookups (keywords `<->`, `VECTOR`, `EMBEDDING`).
EmbeddingRetrieval,
/// Conversation context lookups (keywords `CONVERSATION`, `TURNS`).
ContextLookup,
/// Knowledge-base reads (keywords `DOCUMENTS`, `CHUNKS`).
KnowledgeBase,
/// Tool execution records (keywords `TOOL_RESULTS`, `ACTIONS`).
ToolExecution,
}
/// Keyword-based detector that classifies a query as an AI workload.
///
/// Note: the derived `Default` produces a detector with no patterns, which
/// never detects anything; `AIWorkloadDetector::new` installs the built-in set.
#[derive(Debug, Default)]
pub struct AIWorkloadDetector {
/// Patterns checked in order; the first match decides the workload type.
patterns: Vec<AIPattern>,
}
/// One detection rule: a keyword and the workload type it indicates.
#[derive(Debug)]
struct AIPattern {
/// Text searched for in the upper-cased query (the built-in keywords are upper case).
keyword: String,
/// Workload type reported when the keyword is found.
workload_type: AIWorkloadType,
}
impl AIWorkloadDetector {
    /// Creates a detector loaded with the built-in keyword patterns.
    ///
    /// Pattern order matters: `detect` reports the first pattern that matches.
    pub fn new() -> Self {
        const BUILT_IN: [(&str, AIWorkloadType); 9] = [
            ("<->", AIWorkloadType::EmbeddingRetrieval),
            ("VECTOR", AIWorkloadType::EmbeddingRetrieval),
            ("EMBEDDING", AIWorkloadType::EmbeddingRetrieval),
            ("CONVERSATION", AIWorkloadType::ContextLookup),
            ("TURNS", AIWorkloadType::ContextLookup),
            ("DOCUMENTS", AIWorkloadType::KnowledgeBase),
            ("CHUNKS", AIWorkloadType::KnowledgeBase),
            ("TOOL_RESULTS", AIWorkloadType::ToolExecution),
            ("ACTIONS", AIWorkloadType::ToolExecution),
        ];

        let patterns = BUILT_IN
            .iter()
            .map(|&(keyword, workload_type)| AIPattern {
                keyword: keyword.to_string(),
                workload_type,
            })
            .collect();

        Self { patterns }
    }

    /// Classifies `query`, returning the workload type of the first pattern
    /// whose keyword occurs in the upper-cased query text, or `None`.
    ///
    /// Matching is a plain substring test, so a keyword also matches inside a
    /// longer identifier (for example `ACTIONS` inside `TRANSACTIONS`).
    pub fn detect(&self, query: &str) -> Option<AIWorkloadType> {
        let haystack = query.to_uppercase();
        self.patterns
            .iter()
            .find(|rule| haystack.contains(rule.keyword.as_str()))
            .map(|rule| rule.workload_type)
    }

    /// Maps a detected workload type to the routing preference used for it.
    pub fn get_optimal_routing(&self, workload: AIWorkloadType) -> RoutingPreference {
        match workload {
            AIWorkloadType::ToolExecution => RoutingPreference::Primary,
            AIWorkloadType::KnowledgeBase => RoutingPreference::HighThroughput,
            AIWorkloadType::ContextLookup => RoutingPreference::LowLatency { max_lag_ms: 100 },
            AIWorkloadType::EmbeddingRetrieval => {
                RoutingPreference::VectorNodes { prefer_gpu: true }
            }
        }
    }
}
/// Stage of a RAG pipeline, used by `RAGRouter` to pick a routing strategy.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RAGStage {
/// Routed to a vector-search-capable node.
Retrieval,
/// Routed to an asynchronous node.
Fetch,
/// Routed to the lowest-latency node.
Rerank,
/// Routed to the lowest-latency node when read-only, otherwise default routing.
Generate,
}
/// Stateless router that selects a node for each stage of a RAG pipeline.
#[derive(Debug, Default)]
pub struct RAGRouter {}
impl RAGRouter {
    /// Creates a RAG router.
    pub fn new() -> Self {
        Self {}
    }

    /// Picks a node for one RAG pipeline stage.
    ///
    /// - `Retrieval`: the first vector-search-capable node.
    /// - `Fetch`: the first asynchronous node.
    /// - `Rerank`: the lowest-latency node (first one on ties).
    /// - `Generate`: the lowest-latency node for read-only queries; queries
    ///   that are not read-only get the default decision.
    ///
    /// Whenever no node qualifies, the default decision is returned.
    pub fn route_rag_query(
        &self,
        stage: RAGStage,
        analysis: &QueryAnalysis,
        nodes: &[NodeInfo],
    ) -> RoutingDecision {
        let chosen = match stage {
            RAGStage::Retrieval => nodes.iter().find(|n| n.capabilities.vector_search),
            RAGStage::Fetch => nodes.iter().find(|n| n.sync_mode == SyncMode::Async),
            RAGStage::Rerank => nodes.iter().min_by_key(|n| n.current_latency_ms),
            RAGStage::Generate if !analysis.is_read_only => None,
            RAGStage::Generate => nodes.iter().min_by_key(|n| n.current_latency_ms),
        };

        chosen
            .cloned()
            .map_or_else(RoutingDecision::default_routing, RoutingDecision::single)
    }
}
// Unit tests for the router, the AI workload detector and the decision helpers.
#[cfg(test)]
mod tests {
use super::*;
use crate::schema_routing::registry::TableSchema;
// Builds a router over three registered tables (`users` as OLTP, `events` as
// cold OLAP, `embeddings` as Vector) and four nodes: a primary, a synchronous
// standby, an asynchronous analytics standby and an asynchronous vector node.
fn create_test_setup() -> SchemaAwareRouter {
let registry = Arc::new(SchemaRegistry::new());
registry.register_table(
TableSchema::new("users")
.with_workload(WorkloadType::OLTP)
.with_access_pattern(AccessPattern::PointLookup)
.with_primary_key(vec!["id".to_string()]),
);
registry.register_table(
TableSchema::new("events")
.with_workload(WorkloadType::OLAP)
.with_temperature(DataTemperature::Cold),
);
registry.register_table(TableSchema::new("embeddings").with_workload(WorkloadType::Vector));
let config = SchemaRoutingConfig::default();
let mut router = SchemaAwareRouter::new(config, registry);
router.add_node(NodeInfo::new("primary", "primary").as_primary());
router
.add_node(NodeInfo::new("standby-sync", "standby-sync").with_sync_mode(SyncMode::Sync));
router.add_node(
NodeInfo::new("standby-async", "standby-async")
.with_sync_mode(SyncMode::Async)
.with_capabilities(NodeCapabilities::analytics_node()),
);
router.add_node(
NodeInfo::new("vector-node", "vector-node")
.with_sync_mode(SyncMode::Async)
.with_capabilities(NodeCapabilities::vector_node()),
);
router
}
// A point read on an OLTP table. The assertion is permissive: it accepts any
// non-primary target, or a primary target whose reason is `LowLatency`.
#[test]
fn test_route_oltp_read() {
let router = create_test_setup();
let decision = router.route("SELECT * FROM users WHERE id = 1");
assert!(!decision.is_primary() || matches!(decision.reason, RoutingReason::LowLatency));
}
// A write on an OLTP table must target the primary with reason `WriteQuery`.
#[test]
fn test_route_write_to_primary() {
let router = create_test_setup();
let decision = router.route("INSERT INTO users (name) VALUES ('test')");
assert!(decision.is_primary());
assert!(matches!(decision.reason, RoutingReason::WriteQuery));
}
// The query text contains `<->`, which the AI workload detector matches, so
// `route` takes the AI-preference path here. The assertion accepts a vector or
// best-candidate reason, or any primary target.
#[test]
fn test_route_vector_query() {
let router = create_test_setup();
let decision =
router.route("SELECT * FROM embeddings ORDER BY embedding <-> '[1,2,3]' LIMIT 10");
assert!(
matches!(
decision.reason,
RoutingReason::VectorCapable | RoutingReason::BestCandidate
) || decision.is_primary()
);
}
// An aggregate over the OLAP table. The assertion accepts any non-primary
// target, or a primary target with reason `ColumnarStorage` or `Default`.
#[test]
fn test_route_olap_query() {
let router = create_test_setup();
let decision = router.route("SELECT COUNT(*), SUM(amount) FROM events GROUP BY date");
assert!(
!decision.is_primary()
|| matches!(
decision.reason,
RoutingReason::ColumnarStorage | RoutingReason::Default
)
);
}
// Detection is case-insensitive: lower-case table names still hit the
// upper-case built-in keywords.
#[test]
fn test_ai_workload_detection() {
let detector = AIWorkloadDetector::new();
let embedding = "SELECT * FROM embeddings ORDER BY vector <-> $1";
let context = "SELECT * FROM conversation WHERE session_id = $1";
let tool = "INSERT INTO tool_results (result) VALUES ($1)";
assert_eq!(
detector.detect(embedding),
Some(AIWorkloadType::EmbeddingRetrieval)
);
assert_eq!(
detector.detect(context),
Some(AIWorkloadType::ContextLookup)
);
assert_eq!(detector.detect(tool), Some(AIWorkloadType::ToolExecution));
}
// RAG stage routing either selects a concrete node or falls back to the primary.
#[test]
fn test_rag_routing() {
let router = create_test_setup();
let retrieval = router.route_rag(RAGStage::Retrieval, "SELECT embedding FROM docs");
let fetch = router.route_rag(
RAGStage::Fetch,
"SELECT content FROM docs WHERE id IN (1,2,3)",
);
assert!(retrieval.node_info.is_some() || retrieval.is_primary());
assert!(fetch.node_info.is_some() || fetch.is_primary());
}
// Constructor helpers: `shard` targets the given shard, `default_routing` the primary.
#[test]
fn test_routing_decision_helpers() {
let decision = RoutingDecision::shard(3);
assert!(matches!(decision.target, RouteTarget::Shard(3)));
let default = RoutingDecision::default_routing();
assert!(default.is_primary());
}
}