use super::{
HintParser, ParsedHints, RouteTarget,
NodeFilter, NodeCriteria, NodeInfo, FilterResult,
RoutingConfig, RoutingError, RoutingMetrics, Result,
};
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::sync::RwLock;
/// Routes SQL queries to registered cluster nodes using routing hints embedded
/// in the query text plus (optionally) automatic read/write detection.
pub struct QueryRouter {
    /// Extracts and strips routing hints from query text.
    parser: HintParser,
    /// Applies `NodeCriteria` to the registered node list.
    filter: NodeFilter,
    /// Registered nodes. Routing takes the read lock; add/remove/update take
    /// the write lock.
    nodes: Arc<RwLock<Vec<NodeInfo>>>,
    /// Counters updated on every routing outcome.
    metrics: Arc<RoutingMetrics>,
    /// Router configuration; `new` also hands a clone of it to `filter`.
    config: RoutingConfig,
    /// Ever-increasing counter used to round-robin among eligible nodes.
    rr_counter: std::sync::atomic::AtomicU64,
}
impl QueryRouter {
    /// Creates a router with no registered nodes.
    ///
    /// The node filter receives its own clone of `config`; the router keeps
    /// the original for hint-stripping and write-detection settings.
    pub fn new(config: RoutingConfig) -> Self {
        let filter = NodeFilter::new(config.clone());
        Self {
            parser: HintParser::new(),
            filter,
            nodes: Arc::new(RwLock::new(Vec::new())),
            metrics: Arc::new(RoutingMetrics::new()),
            config,
            rr_counter: std::sync::atomic::AtomicU64::new(0),
        }
    }

    /// Chooses a target node for `query`.
    ///
    /// Steps:
    /// 1. Parse and validate the embedded hints; invalid hints produce an
    ///    error decision (`RoutingReason::Error`) and bump the invalid-hints metric.
    /// 2. Build criteria from the hints, or from the filter's read/write
    ///    defaults when the query carries no hints.
    /// 3. Force `RouteTarget::Primary` for writes whose criteria name no route.
    /// 4. Pick a matching node round-robin; if none match, try a fallback
    ///    node; otherwise report `RoutingReason::NoNodes`.
    pub async fn route(&self, query: &str) -> RoutingDecision {
        let start = Instant::now();
        let hints = self.parser.parse(query);
        if let Err(e) = hints.validate() {
            self.metrics.record_invalid_hints();
            return RoutingDecision::error(e.to_string());
        }

        // `is_write_query` skips leading comments, so a hint comment in front
        // of a write statement no longer hides the write.
        let is_write = self.is_write_query(query);
        let has_hints = !hints.is_empty();
        let mut criteria = if has_hints {
            NodeCriteria::from_hints(&hints)
        } else if is_write {
            self.filter.default_criteria_for_write()
        } else {
            self.filter.default_criteria_for_read()
        };
        // An explicit route hint wins; otherwise writes must go to the primary.
        if is_write && criteria.route.is_none() {
            criteria.route = Some(RouteTarget::Primary);
        }

        let nodes = self.nodes.read().await;
        let filter_result = self.filter.filter(&nodes, &criteria);

        if filter_result.has_matches() {
            let selected = self.select_node(&filter_result);
            self.metrics
                .record_routing(criteria.route, has_hints, start.elapsed());
            return RoutingDecision {
                target_node: Some(selected.name.clone()),
                hints: hints.clone(),
                reason: RoutingReason::Routed {
                    target: criteria.route,
                    filters_applied: filter_result.reasons.clone(),
                },
                elapsed: start.elapsed(),
                is_write,
            };
        }

        match self.try_fallback(&nodes, is_write) {
            Some(node) => {
                self.metrics.record_fallback();
                RoutingDecision {
                    target_node: Some(node.name.clone()),
                    hints: hints.clone(),
                    reason: RoutingReason::Fallback {
                        original_filters: filter_result.reasons.clone(),
                    },
                    elapsed: start.elapsed(),
                    is_write,
                }
            }
            None => {
                self.metrics.record_no_nodes();
                RoutingDecision {
                    target_node: None,
                    hints: hints.clone(),
                    reason: RoutingReason::NoNodes {
                        filters: filter_result.reasons.clone(),
                    },
                    elapsed: start.elapsed(),
                    is_write,
                }
            }
        }
    }

    /// Returns the name of a node matching `criteria`.
    ///
    /// Unlike [`QueryRouter::route`], this does not round-robin, fall back, or
    /// record metrics.
    ///
    /// # Errors
    /// Propagates the error from `FilterResult::require_match` when no node
    /// matches (presumably a "no matching nodes" error — confirm in node_filter).
    pub async fn route_with_criteria(&self, criteria: &NodeCriteria) -> Result<String> {
        let nodes = self.nodes.read().await;
        let filter_result = self.filter.filter(&nodes, criteria);
        filter_result
            .require_match("routing")
            .map(|n| n.name.clone())
    }

    /// Returns `true` when `query` should be routed as a write.
    ///
    /// Always `false` when `config.default.auto_detect_writes` is disabled.
    /// Otherwise the statement's first keyword is compared (ASCII
    /// case-insensitively) against a fixed list of DML, DDL, privilege, and
    /// transaction-control keywords. To find that keyword, leading whitespace,
    /// `/* ... */` block comments (such as routing hints) and `-- ...` line
    /// comments are skipped. The keyword ends at the first character that is
    /// not ASCII alphanumeric or `_`, so `"COMMIT;"` is detected.
    ///
    /// Statements that only reveal a write later in the text (for example a
    /// data-modifying `WITH ... INSERT`) are still classified as reads.
    pub fn is_write_query(&self, query: &str) -> bool {
        if !self.config.default.auto_detect_writes {
            return false;
        }
        // Only the short keyword is uppercased, not the whole query.
        let keyword = Self::leading_keyword(query).to_ascii_uppercase();
        matches!(
            keyword.as_str(),
            "INSERT" | "UPDATE" | "DELETE" | "CREATE" | "ALTER" | "DROP" |
            "TRUNCATE" | "GRANT" | "REVOKE" | "MERGE" | "UPSERT" |
            "BEGIN" | "START" | "COMMIT" | "ROLLBACK" | "SAVEPOINT" |
            "LOCK" | "PREPARE" | "EXECUTE" | "DEALLOCATE"
        )
    }

    /// Returns the first keyword of `query`, skipping leading whitespace,
    /// `/* ... */` block comments and `-- ...` line comments.
    ///
    /// Nested block comments are not recognised. An unterminated comment
    /// consumes the rest of the input, giving an empty keyword.
    fn leading_keyword(query: &str) -> &str {
        let mut rest = query;
        loop {
            rest = rest.trim_start();
            if let Some(after) = rest.strip_prefix("/*") {
                rest = match after.find("*/") {
                    Some(end) => &after[end + 2..],
                    None => "",
                };
            } else if let Some(after) = rest.strip_prefix("--") {
                rest = match after.find('\n') {
                    Some(end) => &after[end + 1..],
                    None => "",
                };
            } else {
                break;
            }
        }
        let end = rest
            .find(|c: char| !(c.is_ascii_alphanumeric() || c == '_'))
            .unwrap_or(rest.len());
        &rest[..end]
    }

    /// Picks one node from `result.eligible`.
    ///
    /// A single eligible node is returned directly. Otherwise nodes are chosen
    /// round-robin using the shared `rr_counter`.
    ///
    /// # Panics
    /// Panics if `result.eligible` is empty. `route` only calls this after
    /// `has_matches()` returns true (presumed to imply a non-empty list).
    fn select_node<'a>(&self, result: &FilterResult<'a>) -> &'a NodeInfo {
        if result.eligible.is_empty() {
            panic!("select_node called with no eligible nodes");
        }
        if result.eligible.len() == 1 {
            return result.eligible[0];
        }
        let idx = self.rr_counter.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
        let selected_idx = (idx as usize) % result.eligible.len();
        result.eligible[selected_idx]
    }

    /// Last-resort node choice when no node matched the query's criteria.
    ///
    /// Only healthy, enabled nodes are considered, and writes additionally
    /// require the primary role. Disabled nodes are excluded for writes as
    /// well as reads. Returns the first qualifying node in registration order.
    fn try_fallback<'a>(&self, nodes: &'a [NodeInfo], is_write: bool) -> Option<&'a NodeInfo> {
        nodes.iter().find(|n| {
            n.healthy
                && n.enabled
                && (!is_write || n.role == super::node_filter::NodeRole::Primary)
        })
    }

    /// Removes routing hints from `query` when `config.hints.strip_hints` is
    /// set; otherwise returns the query unchanged (as an owned `String`).
    pub fn strip_hints(&self, query: &str) -> String {
        if self.config.hints.strip_hints {
            self.parser.strip(query)
        } else {
            query.to_string()
        }
    }

    /// Parses the routing hints embedded in `query` without validating them.
    pub fn parse_hints(&self, query: &str) -> ParsedHints {
        self.parser.parse(query)
    }

    /// Registers `node`. Names are not de-duplicated.
    pub async fn add_node(&self, node: NodeInfo) {
        self.nodes.write().await.push(node);
    }

    /// Unregisters every node whose name equals `name`.
    pub async fn remove_node(&self, name: &str) {
        self.nodes.write().await.retain(|n| n.name != name);
    }

    /// Applies `f` to the first node named `name` while holding the write
    /// lock; does nothing if no such node exists.
    pub async fn update_node<F>(&self, name: &str, f: F)
    where
        F: FnOnce(&mut NodeInfo),
    {
        let mut nodes = self.nodes.write().await;
        if let Some(node) = nodes.iter_mut().find(|n| n.name == name) {
            f(node);
        }
    }

    /// Routing metrics collected by this router.
    pub fn metrics(&self) -> &RoutingMetrics {
        &self.metrics
    }

    /// The configuration this router was created with.
    pub fn config(&self) -> &RoutingConfig {
        &self.config
    }
}
/// Outcome of routing a single query.
#[derive(Debug, Clone)]
pub struct RoutingDecision {
    /// Name of the chosen node; `None` when no node was found or routing failed.
    pub target_node: Option<String>,
    /// Hints parsed from the query (`ParsedHints::default()` for error decisions).
    pub hints: ParsedHints,
    /// Why this target (or no target) was chosen.
    pub reason: RoutingReason,
    /// Time spent routing; `Duration::ZERO` for error decisions.
    pub elapsed: Duration,
    /// Whether the query was classified as a write (`false` for error decisions).
    pub is_write: bool,
}
impl RoutingDecision {
pub fn error(message: String) -> Self {
Self {
target_node: None,
hints: ParsedHints::default(),
reason: RoutingReason::Error { message },
elapsed: Duration::ZERO,
is_write: false,
}
}
pub fn is_success(&self) -> bool {
self.target_node.is_some()
}
pub fn require_target(&self) -> Result<&str> {
self.target_node
.as_deref()
.ok_or_else(|| RoutingError::NoMatchingNodes(self.reason.to_string()))
}
pub fn summary(&self) -> String {
match &self.reason {
RoutingReason::Routed { target, .. } => {
format!(
"Routed to {} ({:?}) in {:?}",
self.target_node.as_deref().unwrap_or("unknown"),
target,
self.elapsed
)
}
RoutingReason::Fallback { .. } => {
format!(
"Fallback to {} in {:?}",
self.target_node.as_deref().unwrap_or("unknown"),
self.elapsed
)
}
RoutingReason::NoNodes { filters } => {
format!("No nodes available (filters: {:?})", filters)
}
RoutingReason::Error { message } => {
format!("Error: {}", message)
}
}
}
}
/// Why a `RoutingDecision` has the target it has.
#[derive(Debug, Clone)]
pub enum RoutingReason {
    /// At least one node matched the query's criteria.
    Routed {
        /// Route target in effect after hint/default/write resolution.
        target: Option<RouteTarget>,
        /// Reasons reported by the node filter (`FilterResult::reasons`).
        filters_applied: Vec<String>,
    },
    /// No node matched the criteria, so a fallback node was used.
    Fallback {
        /// Filter reasons from the failed criteria match.
        original_filters: Vec<String>,
    },
    /// Neither the criteria nor the fallback produced a node.
    NoNodes {
        /// Filter reasons from the failed criteria match.
        filters: Vec<String>,
    },
    /// Routing could not be attempted (e.g. hint validation failed in
    /// `QueryRouter::route`).
    Error {
        /// Human-readable error text.
        message: String,
    },
}
impl std::fmt::Display for RoutingReason {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
RoutingReason::Routed { target, .. } => {
write!(f, "routed to {:?}", target)
}
RoutingReason::Fallback { .. } => {
write!(f, "fallback")
}
RoutingReason::NoNodes { filters } => {
write!(f, "no nodes ({})", filters.join(", "))
}
RoutingReason::Error { message } => {
write!(f, "error: {}", message)
}
}
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use super::super::node_filter::SyncMode;

    /// Router with one primary, one sync standby, and two async standbys
    /// whose reported lag values are 100 and 200.
    async fn setup_router() -> QueryRouter {
        let router = QueryRouter::new(RoutingConfig::default());
        router.add_node(NodeInfo::primary("primary")).await;
        router.add_node(NodeInfo::standby("standby-sync-1", SyncMode::Sync)).await;
        router.add_node(NodeInfo::standby("standby-async-1", SyncMode::Async)
            .with_lag(100)).await;
        router.add_node(NodeInfo::standby("standby-async-2", SyncMode::Async)
            .with_lag(200)).await;
        router
    }

    #[tokio::test]
    async fn test_route_read_query() {
        let router = setup_router().await;
        let decision = router.route("SELECT * FROM users").await;
        assert!(decision.is_success());
        assert!(!decision.is_write);
    }

    #[tokio::test]
    async fn test_route_write_query() {
        let router = setup_router().await;
        let decision = router.route("INSERT INTO users (name) VALUES ('test')").await;
        assert!(decision.is_success());
        assert!(decision.is_write);
        assert_eq!(decision.target_node.as_deref(), Some("primary"));
    }

    #[tokio::test]
    async fn test_route_with_primary_hint() {
        let router = setup_router().await;
        let decision = router.route("/*helios:route=primary*/ SELECT * FROM users").await;
        assert!(decision.is_success());
        assert_eq!(decision.target_node.as_deref(), Some("primary"));
    }

    #[tokio::test]
    async fn test_route_with_sync_hint() {
        let router = setup_router().await;
        let decision = router.route("/*helios:route=sync*/ SELECT * FROM users").await;
        assert!(decision.is_success());
        assert_eq!(decision.target_node.as_deref(), Some("standby-sync-1"));
    }

    #[tokio::test]
    async fn test_route_with_node_hint() {
        let router = setup_router().await;
        let decision = router.route("/*helios:node=standby-async-1*/ SELECT * FROM users").await;
        assert!(decision.is_success());
        assert_eq!(decision.target_node.as_deref(), Some("standby-async-1"));
    }

    #[tokio::test]
    async fn test_route_with_lag_hint() {
        let router = setup_router().await;
        let decision = router.route("/*helios:route=async,lag=150ms*/ SELECT * FROM users").await;
        assert!(decision.is_success());
        assert_eq!(decision.target_node.as_deref(), Some("standby-async-1"));
    }

    #[tokio::test]
    async fn test_route_no_matching_nodes() {
        let router = setup_router().await;
        let decision = router.route("/*helios:node=nonexistent*/ SELECT * FROM users").await;
        // No node is named "nonexistent": the router must take the fallback
        // path rather than report a normal routed decision.
        assert!(decision.is_success());
        assert!(matches!(decision.reason, RoutingReason::Fallback { .. }));
        assert_ne!(decision.target_node.as_deref(), Some("nonexistent"));
    }

    #[tokio::test]
    async fn test_is_write_query() {
        let router = QueryRouter::new(RoutingConfig::default());
        assert!(router.is_write_query("INSERT INTO users VALUES (1)"));
        assert!(router.is_write_query("UPDATE users SET name = 'test'"));
        assert!(router.is_write_query("DELETE FROM users"));
        assert!(router.is_write_query("CREATE TABLE test (id INT)"));
        assert!(router.is_write_query("BEGIN"));
        assert!(router.is_write_query("COMMIT"));
        assert!(!router.is_write_query("SELECT * FROM users"));
        assert!(!router.is_write_query("WITH cte AS (SELECT 1) SELECT * FROM cte"));
    }

    #[tokio::test]
    async fn test_strip_hints() {
        let router = QueryRouter::new(RoutingConfig::default());
        let stripped = router.strip_hints("/*helios:route=primary*/ SELECT * FROM users");
        assert_eq!(stripped, "SELECT * FROM users");
    }

    #[tokio::test]
    async fn test_invalid_hint_combination() {
        let router = setup_router().await;
        let decision = router.route(
            "/*helios:route=async,consistency=strong*/ SELECT * FROM users"
        ).await;
        assert!(!decision.is_success());
    }

    #[tokio::test]
    async fn test_metrics_tracking() {
        let router = setup_router().await;
        router.route("SELECT * FROM users").await;
        router.route("/*helios:route=primary*/ SELECT * FROM accounts").await;
        router.route("INSERT INTO users VALUES (1)").await;
        let stats = router.metrics().snapshot();
        assert!(stats.total_routed >= 3);
        assert!(stats.with_hints >= 1);
    }
}