1use super::{
6 HintParser, ParsedHints, RouteTarget,
7 NodeFilter, NodeCriteria, NodeInfo, FilterResult,
8 RoutingConfig, RoutingError, RoutingMetrics, Result,
9};
10use std::sync::Arc;
11use std::time::{Duration, Instant};
12use tokio::sync::RwLock;
13
14pub struct QueryRouter {
16 parser: HintParser,
18 filter: NodeFilter,
20 nodes: Arc<RwLock<Vec<NodeInfo>>>,
22 metrics: Arc<RoutingMetrics>,
24 config: RoutingConfig,
26 rr_counter: std::sync::atomic::AtomicU64,
28}
29
30impl QueryRouter {
31 pub fn new(config: RoutingConfig) -> Self {
33 let filter = NodeFilter::new(config.clone());
34
35 Self {
36 parser: HintParser::new(),
37 filter,
38 nodes: Arc::new(RwLock::new(Vec::new())),
39 metrics: Arc::new(RoutingMetrics::new()),
40 config,
41 rr_counter: std::sync::atomic::AtomicU64::new(0),
42 }
43 }
44
45 pub async fn route(&self, query: &str) -> RoutingDecision {
47 let start = Instant::now();
48
49 let hints = self.parser.parse(query);
51
52 if let Err(e) = hints.validate() {
54 self.metrics.record_invalid_hints();
55 return RoutingDecision::error(e.to_string());
56 }
57
58 let is_write = self.is_write_query(query);
60
61 let mut criteria = if !hints.is_empty() {
63 NodeCriteria::from_hints(&hints)
64 } else if is_write {
65 self.filter.default_criteria_for_write()
66 } else {
67 self.filter.default_criteria_for_read()
68 };
69
70 if is_write && criteria.route.is_none() {
72 criteria.route = Some(RouteTarget::Primary);
73 }
74
75 let nodes = self.nodes.read().await;
77 let filter_result = self.filter.filter(&nodes, &criteria);
78
79 let decision = if filter_result.has_matches() {
81 let selected = self.select_node(&filter_result);
82 self.metrics.record_routing(
83 criteria.route,
84 !hints.is_empty(),
85 start.elapsed(),
86 );
87
88 RoutingDecision {
89 target_node: Some(selected.name.clone()),
90 hints: hints.clone(),
91 reason: RoutingReason::Routed {
92 target: criteria.route,
93 filters_applied: filter_result.reasons.clone(),
94 },
95 elapsed: start.elapsed(),
96 is_write,
97 }
98 } else {
99 let fallback = self.try_fallback(&nodes, is_write);
101
102 if let Some(node) = fallback {
103 self.metrics.record_fallback();
104 RoutingDecision {
105 target_node: Some(node.name.clone()),
106 hints: hints.clone(),
107 reason: RoutingReason::Fallback {
108 original_filters: filter_result.reasons.clone(),
109 },
110 elapsed: start.elapsed(),
111 is_write,
112 }
113 } else {
114 self.metrics.record_no_nodes();
115 RoutingDecision {
116 target_node: None,
117 hints: hints.clone(),
118 reason: RoutingReason::NoNodes {
119 filters: filter_result.reasons.clone(),
120 },
121 elapsed: start.elapsed(),
122 is_write,
123 }
124 }
125 };
126
127 decision
128 }
129
130 pub async fn route_with_criteria(&self, criteria: &NodeCriteria) -> Result<String> {
132 let nodes = self.nodes.read().await;
133 let filter_result = self.filter.filter(&nodes, criteria);
134
135 filter_result
136 .require_match("routing")
137 .map(|n| n.name.clone())
138 }
139
140 pub fn is_write_query(&self, query: &str) -> bool {
142 if !self.config.default.auto_detect_writes {
143 return false;
144 }
145
146 let upper = query.trim().to_uppercase();
147 let first_word = upper.split_whitespace().next().unwrap_or("");
148
149 matches!(
150 first_word,
151 "INSERT" | "UPDATE" | "DELETE" | "CREATE" | "ALTER" | "DROP" |
152 "TRUNCATE" | "GRANT" | "REVOKE" | "MERGE" | "UPSERT" |
153 "BEGIN" | "START" | "COMMIT" | "ROLLBACK" | "SAVEPOINT" |
154 "LOCK" | "PREPARE" | "EXECUTE" | "DEALLOCATE"
155 )
156 }
157
158 fn select_node<'a>(&self, result: &FilterResult<'a>) -> &'a NodeInfo {
160 if result.eligible.is_empty() {
161 panic!("select_node called with no eligible nodes");
162 }
163
164 if result.eligible.len() == 1 {
165 return result.eligible[0];
166 }
167
168 let idx = self.rr_counter.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
170 let selected_idx = (idx as usize) % result.eligible.len();
171 result.eligible[selected_idx]
172 }
173
174 fn try_fallback<'a>(&self, nodes: &'a [NodeInfo], is_write: bool) -> Option<&'a NodeInfo> {
176 if is_write {
177 nodes.iter().find(|n| n.role == super::node_filter::NodeRole::Primary && n.healthy)
179 } else {
180 nodes.iter().find(|n| n.healthy && n.enabled)
182 }
183 }
184
185 pub fn strip_hints(&self, query: &str) -> String {
187 if self.config.hints.strip_hints {
188 self.parser.strip(query)
189 } else {
190 query.to_string()
191 }
192 }
193
194 pub fn parse_hints(&self, query: &str) -> ParsedHints {
196 self.parser.parse(query)
197 }
198
199 pub async fn add_node(&self, node: NodeInfo) {
201 self.nodes.write().await.push(node);
202 }
203
204 pub async fn remove_node(&self, name: &str) {
206 self.nodes.write().await.retain(|n| n.name != name);
207 }
208
209 pub async fn update_node<F>(&self, name: &str, f: F)
211 where
212 F: FnOnce(&mut NodeInfo),
213 {
214 let mut nodes = self.nodes.write().await;
215 if let Some(node) = nodes.iter_mut().find(|n| n.name == name) {
216 f(node);
217 }
218 }
219
220 pub fn metrics(&self) -> &RoutingMetrics {
222 &self.metrics
223 }
224
225 pub fn config(&self) -> &RoutingConfig {
227 &self.config
228 }
229}
230
231#[derive(Debug, Clone)]
233pub struct RoutingDecision {
234 pub target_node: Option<String>,
236 pub hints: ParsedHints,
238 pub reason: RoutingReason,
240 pub elapsed: Duration,
242 pub is_write: bool,
244}
245
246impl RoutingDecision {
247 pub fn error(message: String) -> Self {
249 Self {
250 target_node: None,
251 hints: ParsedHints::default(),
252 reason: RoutingReason::Error { message },
253 elapsed: Duration::ZERO,
254 is_write: false,
255 }
256 }
257
258 pub fn is_success(&self) -> bool {
260 self.target_node.is_some()
261 }
262
263 pub fn require_target(&self) -> Result<&str> {
265 self.target_node
266 .as_deref()
267 .ok_or_else(|| RoutingError::NoMatchingNodes(self.reason.to_string()))
268 }
269
270 pub fn summary(&self) -> String {
272 match &self.reason {
273 RoutingReason::Routed { target, .. } => {
274 format!(
275 "Routed to {} ({:?}) in {:?}",
276 self.target_node.as_deref().unwrap_or("unknown"),
277 target,
278 self.elapsed
279 )
280 }
281 RoutingReason::Fallback { .. } => {
282 format!(
283 "Fallback to {} in {:?}",
284 self.target_node.as_deref().unwrap_or("unknown"),
285 self.elapsed
286 )
287 }
288 RoutingReason::NoNodes { filters } => {
289 format!("No nodes available (filters: {:?})", filters)
290 }
291 RoutingReason::Error { message } => {
292 format!("Error: {}", message)
293 }
294 }
295 }
296}
297
298#[derive(Debug, Clone)]
300pub enum RoutingReason {
301 Routed {
303 target: Option<RouteTarget>,
304 filters_applied: Vec<String>,
305 },
306 Fallback {
308 original_filters: Vec<String>,
309 },
310 NoNodes {
312 filters: Vec<String>,
313 },
314 Error {
316 message: String,
317 },
318}
319
320impl std::fmt::Display for RoutingReason {
321 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
322 match self {
323 RoutingReason::Routed { target, .. } => {
324 write!(f, "routed to {:?}", target)
325 }
326 RoutingReason::Fallback { .. } => {
327 write!(f, "fallback")
328 }
329 RoutingReason::NoNodes { filters } => {
330 write!(f, "no nodes ({})", filters.join(", "))
331 }
332 RoutingReason::Error { message } => {
333 write!(f, "error: {}", message)
334 }
335 }
336 }
337}
338
#[cfg(test)]
mod tests {
    use super::*;
    use super::super::node_filter::SyncMode;

    /// Router with one primary, one sync standby, and two async standbys
    /// registered with lag 100 and 200 (units as interpreted by `with_lag`).
    async fn setup_router() -> QueryRouter {
        let router = QueryRouter::new(RoutingConfig::default());

        router.add_node(NodeInfo::primary("primary")).await;
        router.add_node(NodeInfo::standby("standby-sync-1", SyncMode::Sync)).await;
        router
            .add_node(NodeInfo::standby("standby-async-1", SyncMode::Async).with_lag(100))
            .await;
        router
            .add_node(NodeInfo::standby("standby-async-2", SyncMode::Async).with_lag(200))
            .await;

        router
    }

    #[tokio::test]
    async fn test_route_read_query() {
        let router = setup_router().await;

        let decision = router.route("SELECT * FROM users").await;

        assert!(decision.is_success());
        assert!(!decision.is_write);
    }

    #[tokio::test]
    async fn test_route_write_query() {
        let router = setup_router().await;

        let decision = router.route("INSERT INTO users (name) VALUES ('test')").await;

        assert!(decision.is_success());
        assert!(decision.is_write);
        assert_eq!(decision.target_node.as_deref(), Some("primary"));
    }

    #[tokio::test]
    async fn test_route_with_primary_hint() {
        let router = setup_router().await;

        let decision = router.route("/*helios:route=primary*/ SELECT * FROM users").await;

        assert!(decision.is_success());
        assert_eq!(decision.target_node.as_deref(), Some("primary"));
    }

    #[tokio::test]
    async fn test_route_with_sync_hint() {
        let router = setup_router().await;

        let decision = router.route("/*helios:route=sync*/ SELECT * FROM users").await;

        assert!(decision.is_success());
        assert_eq!(decision.target_node.as_deref(), Some("standby-sync-1"));
    }

    #[tokio::test]
    async fn test_route_with_node_hint() {
        let router = setup_router().await;

        let decision = router
            .route("/*helios:node=standby-async-1*/ SELECT * FROM users")
            .await;

        assert!(decision.is_success());
        assert_eq!(decision.target_node.as_deref(), Some("standby-async-1"));
    }

    #[tokio::test]
    async fn test_route_with_lag_hint() {
        let router = setup_router().await;

        let decision = router
            .route("/*helios:route=async,lag=150ms*/ SELECT * FROM users")
            .await;

        // Only standby-async-1 (lag 100) is within the 150ms bound.
        assert!(decision.is_success());
        assert_eq!(decision.target_node.as_deref(), Some("standby-async-1"));
    }

    #[tokio::test]
    async fn test_route_no_matching_nodes() {
        let router = setup_router().await;

        let decision = router
            .route("/*helios:node=nonexistent*/ SELECT * FROM users")
            .await;

        // No node satisfies the hint, so the read must be served by the fallback.
        assert!(decision.is_success());
        assert!(matches!(decision.reason, RoutingReason::Fallback { .. }));
    }

    #[tokio::test]
    async fn test_is_write_query() {
        let router = QueryRouter::new(RoutingConfig::default());

        assert!(router.is_write_query("INSERT INTO users VALUES (1)"));
        assert!(router.is_write_query("UPDATE users SET name = 'test'"));
        assert!(router.is_write_query("DELETE FROM users"));
        assert!(router.is_write_query("CREATE TABLE test (id INT)"));
        assert!(router.is_write_query("BEGIN"));
        assert!(router.is_write_query("COMMIT"));
        assert!(router.is_write_query("ROLLBACK"));
        // Keyword matching ignores case and leading whitespace.
        assert!(router.is_write_query("  insert into users values (1)"));

        assert!(!router.is_write_query("SELECT * FROM users"));
        assert!(!router.is_write_query("select * from users"));
        assert!(!router.is_write_query("WITH cte AS (SELECT 1) SELECT * FROM cte"));
    }

    #[tokio::test]
    async fn test_strip_hints() {
        let router = QueryRouter::new(RoutingConfig::default());

        let stripped = router.strip_hints("/*helios:route=primary*/ SELECT * FROM users");
        assert_eq!(stripped, "SELECT * FROM users");
    }

    #[tokio::test]
    async fn test_invalid_hint_combination() {
        let router = setup_router().await;

        let decision = router
            .route("/*helios:route=async,consistency=strong*/ SELECT * FROM users")
            .await;

        // Async replicas cannot provide strong consistency: hint validation fails.
        assert!(!decision.is_success());
        assert!(matches!(decision.reason, RoutingReason::Error { .. }));
    }

    #[tokio::test]
    async fn test_metrics_tracking() {
        let router = setup_router().await;

        router.route("SELECT * FROM users").await;
        router.route("/*helios:route=primary*/ SELECT * FROM accounts").await;
        router.route("INSERT INTO users VALUES (1)").await;

        let stats = router.metrics().snapshot();
        assert!(stats.total_routed >= 3);
        assert!(stats.with_hints >= 1);
    }
}