1use super::{
6 FilterResult, HintParser, NodeCriteria, NodeFilter, NodeInfo, ParsedHints, Result, RouteTarget,
7 RoutingConfig, RoutingError, RoutingMetrics,
8};
9use std::sync::Arc;
10use std::time::{Duration, Instant};
11use tokio::sync::RwLock;
12
13pub struct QueryRouter {
15 parser: HintParser,
17 filter: NodeFilter,
19 nodes: Arc<RwLock<Vec<NodeInfo>>>,
21 metrics: Arc<RoutingMetrics>,
23 config: RoutingConfig,
25 rr_counter: std::sync::atomic::AtomicU64,
27}
28
29impl QueryRouter {
30 pub fn new(config: RoutingConfig) -> Self {
32 let filter = NodeFilter::new(config.clone());
33
34 Self {
35 parser: HintParser::new(),
36 filter,
37 nodes: Arc::new(RwLock::new(Vec::new())),
38 metrics: Arc::new(RoutingMetrics::new()),
39 config,
40 rr_counter: std::sync::atomic::AtomicU64::new(0),
41 }
42 }
43
44 pub async fn route(&self, query: &str) -> RoutingDecision {
46 let start = Instant::now();
47
48 let hints = self.parser.parse(query);
50
51 if let Err(e) = hints.validate() {
53 self.metrics.record_invalid_hints();
54 return RoutingDecision::error(e.to_string());
55 }
56
57 let is_write = self.is_write_query(query);
59
60 let mut criteria = if !hints.is_empty() {
62 NodeCriteria::from_hints(&hints)
63 } else if is_write {
64 self.filter.default_criteria_for_write()
65 } else {
66 self.filter.default_criteria_for_read()
67 };
68
69 if is_write && criteria.route.is_none() {
71 criteria.route = Some(RouteTarget::Primary);
72 }
73
74 let nodes = self.nodes.read().await;
76 let filter_result = self.filter.filter(&nodes, &criteria);
77
78 let decision = if filter_result.has_matches() {
80 let selected = self.select_node(&filter_result);
81 self.metrics
82 .record_routing(criteria.route, !hints.is_empty(), start.elapsed());
83
84 RoutingDecision {
85 target_node: Some(selected.name.clone()),
86 hints: hints.clone(),
87 reason: RoutingReason::Routed {
88 target: criteria.route,
89 filters_applied: filter_result.reasons.clone(),
90 },
91 elapsed: start.elapsed(),
92 is_write,
93 }
94 } else {
95 let fallback = self.try_fallback(&nodes, is_write);
97
98 if let Some(node) = fallback {
99 self.metrics.record_fallback();
100 RoutingDecision {
101 target_node: Some(node.name.clone()),
102 hints: hints.clone(),
103 reason: RoutingReason::Fallback {
104 original_filters: filter_result.reasons.clone(),
105 },
106 elapsed: start.elapsed(),
107 is_write,
108 }
109 } else {
110 self.metrics.record_no_nodes();
111 RoutingDecision {
112 target_node: None,
113 hints: hints.clone(),
114 reason: RoutingReason::NoNodes {
115 filters: filter_result.reasons.clone(),
116 },
117 elapsed: start.elapsed(),
118 is_write,
119 }
120 }
121 };
122
123 decision
124 }
125
126 pub async fn route_with_criteria(&self, criteria: &NodeCriteria) -> Result<String> {
128 let nodes = self.nodes.read().await;
129 let filter_result = self.filter.filter(&nodes, criteria);
130
131 filter_result
132 .require_match("routing")
133 .map(|n| n.name.clone())
134 }
135
136 pub fn is_write_query(&self, query: &str) -> bool {
138 if !self.config.default.auto_detect_writes {
139 return false;
140 }
141
142 let upper = query.trim().to_uppercase();
143 let first_word = upper.split_whitespace().next().unwrap_or("");
144
145 matches!(
146 first_word,
147 "INSERT"
148 | "UPDATE"
149 | "DELETE"
150 | "CREATE"
151 | "ALTER"
152 | "DROP"
153 | "TRUNCATE"
154 | "GRANT"
155 | "REVOKE"
156 | "MERGE"
157 | "UPSERT"
158 | "BEGIN"
159 | "START"
160 | "COMMIT"
161 | "ROLLBACK"
162 | "SAVEPOINT"
163 | "LOCK"
164 | "PREPARE"
165 | "EXECUTE"
166 | "DEALLOCATE"
167 )
168 }
169
170 fn select_node<'a>(&self, result: &FilterResult<'a>) -> &'a NodeInfo {
172 if result.eligible.is_empty() {
173 panic!("select_node called with no eligible nodes");
174 }
175
176 if result.eligible.len() == 1 {
177 return result.eligible[0];
178 }
179
180 let idx = self
182 .rr_counter
183 .fetch_add(1, std::sync::atomic::Ordering::SeqCst);
184 let selected_idx = (idx as usize) % result.eligible.len();
185 result.eligible[selected_idx]
186 }
187
188 fn try_fallback<'a>(&self, nodes: &'a [NodeInfo], is_write: bool) -> Option<&'a NodeInfo> {
190 if is_write {
191 nodes
193 .iter()
194 .find(|n| n.role == super::node_filter::NodeRole::Primary && n.healthy)
195 } else {
196 nodes.iter().find(|n| n.healthy && n.enabled)
198 }
199 }
200
201 pub fn strip_hints(&self, query: &str) -> String {
203 if self.config.hints.strip_hints {
204 self.parser.strip(query)
205 } else {
206 query.to_string()
207 }
208 }
209
210 pub fn parse_hints(&self, query: &str) -> ParsedHints {
212 self.parser.parse(query)
213 }
214
215 pub async fn add_node(&self, node: NodeInfo) {
217 self.nodes.write().await.push(node);
218 }
219
220 pub async fn remove_node(&self, name: &str) {
222 self.nodes.write().await.retain(|n| n.name != name);
223 }
224
225 pub async fn update_node<F>(&self, name: &str, f: F)
227 where
228 F: FnOnce(&mut NodeInfo),
229 {
230 let mut nodes = self.nodes.write().await;
231 if let Some(node) = nodes.iter_mut().find(|n| n.name == name) {
232 f(node);
233 }
234 }
235
236 pub fn metrics(&self) -> &RoutingMetrics {
238 &self.metrics
239 }
240
241 pub fn config(&self) -> &RoutingConfig {
243 &self.config
244 }
245}
246
247#[derive(Debug, Clone)]
249pub struct RoutingDecision {
250 pub target_node: Option<String>,
252 pub hints: ParsedHints,
254 pub reason: RoutingReason,
256 pub elapsed: Duration,
258 pub is_write: bool,
260}
261
262impl RoutingDecision {
263 pub fn error(message: String) -> Self {
265 Self {
266 target_node: None,
267 hints: ParsedHints::default(),
268 reason: RoutingReason::Error { message },
269 elapsed: Duration::ZERO,
270 is_write: false,
271 }
272 }
273
274 pub fn is_success(&self) -> bool {
276 self.target_node.is_some()
277 }
278
279 pub fn require_target(&self) -> Result<&str> {
281 self.target_node
282 .as_deref()
283 .ok_or_else(|| RoutingError::NoMatchingNodes(self.reason.to_string()))
284 }
285
286 pub fn summary(&self) -> String {
288 match &self.reason {
289 RoutingReason::Routed { target, .. } => {
290 format!(
291 "Routed to {} ({:?}) in {:?}",
292 self.target_node.as_deref().unwrap_or("unknown"),
293 target,
294 self.elapsed
295 )
296 }
297 RoutingReason::Fallback { .. } => {
298 format!(
299 "Fallback to {} in {:?}",
300 self.target_node.as_deref().unwrap_or("unknown"),
301 self.elapsed
302 )
303 }
304 RoutingReason::NoNodes { filters } => {
305 format!("No nodes available (filters: {:?})", filters)
306 }
307 RoutingReason::Error { message } => {
308 format!("Error: {}", message)
309 }
310 }
311 }
312}
313
314#[derive(Debug, Clone)]
316pub enum RoutingReason {
317 Routed {
319 target: Option<RouteTarget>,
320 filters_applied: Vec<String>,
321 },
322 Fallback { original_filters: Vec<String> },
324 NoNodes { filters: Vec<String> },
326 Error { message: String },
328}
329
330impl std::fmt::Display for RoutingReason {
331 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
332 match self {
333 RoutingReason::Routed { target, .. } => {
334 write!(f, "routed to {:?}", target)
335 }
336 RoutingReason::Fallback { .. } => {
337 write!(f, "fallback")
338 }
339 RoutingReason::NoNodes { filters } => {
340 write!(f, "no nodes ({})", filters.join(", "))
341 }
342 RoutingReason::Error { message } => {
343 write!(f, "error: {}", message)
344 }
345 }
346 }
347}
348
#[cfg(test)]
mod tests {
    use super::super::node_filter::SyncMode;
    use super::*;

    /// Builds a router with a primary, one synchronous standby, and two
    /// asynchronous standbys with lag 100 and 200 respectively.
    async fn setup_router() -> QueryRouter {
        let router = QueryRouter::new(RoutingConfig::default());

        router.add_node(NodeInfo::primary("primary")).await;
        router
            .add_node(NodeInfo::standby("standby-sync-1", SyncMode::Sync))
            .await;
        router
            .add_node(NodeInfo::standby("standby-async-1", SyncMode::Async).with_lag(100))
            .await;
        router
            .add_node(NodeInfo::standby("standby-async-2", SyncMode::Async).with_lag(200))
            .await;

        router
    }

    #[tokio::test]
    async fn test_route_read_query() {
        let router = setup_router().await;

        let decision = router.route("SELECT * FROM users").await;

        assert!(decision.is_success());
        assert!(!decision.is_write);
    }

    #[tokio::test]
    async fn test_route_write_query() {
        let router = setup_router().await;

        let decision = router
            .route("INSERT INTO users (name) VALUES ('test')")
            .await;

        assert!(decision.is_success());
        assert!(decision.is_write);
        assert_eq!(decision.target_node.as_deref(), Some("primary"));
    }

    #[tokio::test]
    async fn test_route_with_primary_hint() {
        let router = setup_router().await;

        let decision = router
            .route("/*helios:route=primary*/ SELECT * FROM users")
            .await;

        assert!(decision.is_success());
        assert_eq!(decision.target_node.as_deref(), Some("primary"));
    }

    #[tokio::test]
    async fn test_route_with_sync_hint() {
        let router = setup_router().await;

        let decision = router
            .route("/*helios:route=sync*/ SELECT * FROM users")
            .await;

        assert!(decision.is_success());
        assert_eq!(decision.target_node.as_deref(), Some("standby-sync-1"));
    }

    #[tokio::test]
    async fn test_route_with_node_hint() {
        let router = setup_router().await;

        let decision = router
            .route("/*helios:node=standby-async-1*/ SELECT * FROM users")
            .await;

        assert!(decision.is_success());
        assert_eq!(decision.target_node.as_deref(), Some("standby-async-1"));
    }

    #[tokio::test]
    async fn test_route_with_lag_hint() {
        let router = setup_router().await;

        let decision = router
            .route("/*helios:route=async,lag=150ms*/ SELECT * FROM users")
            .await;

        // Only standby-async-1 (lag 100) is within the 150ms bound.
        assert!(decision.is_success());
        assert_eq!(decision.target_node.as_deref(), Some("standby-async-1"));
    }

    #[tokio::test]
    async fn test_route_no_matching_nodes() {
        let router = setup_router().await;

        let decision = router
            .route("/*helios:node=nonexistent*/ SELECT * FROM users")
            .await;

        // No node satisfies the hint. The read is not rejected; it falls back
        // to a usable node instead.
        assert!(decision.is_success());
        assert!(matches!(decision.reason, RoutingReason::Fallback { .. }));
        assert!(decision.summary().starts_with("Fallback to "));
    }

    #[tokio::test]
    async fn test_write_without_primary_has_no_target() {
        let router = setup_router().await;
        router.remove_node("primary").await;

        let decision = router.route("INSERT INTO users VALUES (1)").await;

        // Writes never fall back to standbys.
        assert!(decision.is_write);
        assert!(!decision.is_success());
        assert!(matches!(decision.reason, RoutingReason::NoNodes { .. }));
        assert!(decision.require_target().is_err());
    }

    #[test]
    fn test_is_write_query() {
        let router = QueryRouter::new(RoutingConfig::default());

        assert!(router.is_write_query("INSERT INTO users VALUES (1)"));
        assert!(router.is_write_query("UPDATE users SET name = 'test'"));
        assert!(router.is_write_query("DELETE FROM users"));
        assert!(router.is_write_query("CREATE TABLE test (id INT)"));
        assert!(router.is_write_query("BEGIN"));
        assert!(router.is_write_query("COMMIT"));
        assert!(router.is_write_query("  insert into users values (1)"));

        assert!(!router.is_write_query("SELECT * FROM users"));
        assert!(!router.is_write_query("select * from users"));
        assert!(!router.is_write_query("WITH cte AS (SELECT 1) SELECT * FROM cte"));
        assert!(!router.is_write_query(""));
    }

    #[test]
    fn test_strip_hints() {
        let router = QueryRouter::new(RoutingConfig::default());

        let stripped = router.strip_hints("/*helios:route=primary*/ SELECT * FROM users");
        assert_eq!(stripped, "SELECT * FROM users");
    }

    #[tokio::test]
    async fn test_invalid_hint_combination() {
        let router = setup_router().await;

        let decision = router
            .route("/*helios:route=async,consistency=strong*/ SELECT * FROM users")
            .await;

        assert!(!decision.is_success());
    }

    #[tokio::test]
    async fn test_metrics_tracking() {
        let router = setup_router().await;

        router.route("SELECT * FROM users").await;
        router
            .route("/*helios:route=primary*/ SELECT * FROM accounts")
            .await;
        router.route("INSERT INTO users VALUES (1)").await;

        let stats = router.metrics().snapshot();
        assert!(stats.total_routed >= 3);
        assert!(stats.with_hints >= 1);
    }
}