Skip to main content

heliosdb_proxy/routing/
query_router.rs

//! Query Router
//!
//! Routes queries to appropriate nodes based on hints and policies.
4
5use super::{
6    HintParser, ParsedHints, RouteTarget,
7    NodeFilter, NodeCriteria, NodeInfo, FilterResult,
8    RoutingConfig, RoutingError, RoutingMetrics, Result,
9};
10use std::sync::Arc;
11use std::time::{Duration, Instant};
12use tokio::sync::RwLock;
13
14/// Query router - routes queries to appropriate nodes
15pub struct QueryRouter {
16    /// Hint parser
17    parser: HintParser,
18    /// Node filter
19    filter: NodeFilter,
20    /// Available nodes
21    nodes: Arc<RwLock<Vec<NodeInfo>>>,
22    /// Routing metrics
23    metrics: Arc<RoutingMetrics>,
24    /// Configuration
25    config: RoutingConfig,
26    /// Round-robin counter for load balancing
27    rr_counter: std::sync::atomic::AtomicU64,
28}
29
30impl QueryRouter {
31    /// Create a new query router
32    pub fn new(config: RoutingConfig) -> Self {
33        let filter = NodeFilter::new(config.clone());
34
35        Self {
36            parser: HintParser::new(),
37            filter,
38            nodes: Arc::new(RwLock::new(Vec::new())),
39            metrics: Arc::new(RoutingMetrics::new()),
40            config,
41            rr_counter: std::sync::atomic::AtomicU64::new(0),
42        }
43    }
44
45    /// Route a query
46    pub async fn route(&self, query: &str) -> RoutingDecision {
47        let start = Instant::now();
48
49        // Parse hints
50        let hints = self.parser.parse(query);
51
52        // Validate hints
53        if let Err(e) = hints.validate() {
54            self.metrics.record_invalid_hints();
55            return RoutingDecision::error(e.to_string());
56        }
57
58        // Determine if this is a write query
59        let is_write = self.is_write_query(query);
60
61        // Build criteria from hints
62        let mut criteria = if !hints.is_empty() {
63            NodeCriteria::from_hints(&hints)
64        } else if is_write {
65            self.filter.default_criteria_for_write()
66        } else {
67            self.filter.default_criteria_for_read()
68        };
69
70        // For writes without explicit routing, force primary
71        if is_write && criteria.route.is_none() {
72            criteria.route = Some(RouteTarget::Primary);
73        }
74
75        // Get nodes and filter
76        let nodes = self.nodes.read().await;
77        let filter_result = self.filter.filter(&nodes, &criteria);
78
79        // Build decision
80        let decision = if filter_result.has_matches() {
81            let selected = self.select_node(&filter_result);
82            self.metrics.record_routing(
83                criteria.route,
84                !hints.is_empty(),
85                start.elapsed(),
86            );
87
88            RoutingDecision {
89                target_node: Some(selected.name.clone()),
90                hints: hints.clone(),
91                reason: RoutingReason::Routed {
92                    target: criteria.route,
93                    filters_applied: filter_result.reasons.clone(),
94                },
95                elapsed: start.elapsed(),
96                is_write,
97            }
98        } else {
99            // No matching nodes - try fallback
100            let fallback = self.try_fallback(&nodes, is_write);
101
102            if let Some(node) = fallback {
103                self.metrics.record_fallback();
104                RoutingDecision {
105                    target_node: Some(node.name.clone()),
106                    hints: hints.clone(),
107                    reason: RoutingReason::Fallback {
108                        original_filters: filter_result.reasons.clone(),
109                    },
110                    elapsed: start.elapsed(),
111                    is_write,
112                }
113            } else {
114                self.metrics.record_no_nodes();
115                RoutingDecision {
116                    target_node: None,
117                    hints: hints.clone(),
118                    reason: RoutingReason::NoNodes {
119                        filters: filter_result.reasons.clone(),
120                    },
121                    elapsed: start.elapsed(),
122                    is_write,
123                }
124            }
125        };
126
127        decision
128    }
129
130    /// Route with explicit hints (for use by other modules)
131    pub async fn route_with_criteria(&self, criteria: &NodeCriteria) -> Result<String> {
132        let nodes = self.nodes.read().await;
133        let filter_result = self.filter.filter(&nodes, criteria);
134
135        filter_result
136            .require_match("routing")
137            .map(|n| n.name.clone())
138    }
139
140    /// Check if query is a write operation
141    pub fn is_write_query(&self, query: &str) -> bool {
142        if !self.config.default.auto_detect_writes {
143            return false;
144        }
145
146        let upper = query.trim().to_uppercase();
147        let first_word = upper.split_whitespace().next().unwrap_or("");
148
149        matches!(
150            first_word,
151            "INSERT" | "UPDATE" | "DELETE" | "CREATE" | "ALTER" | "DROP" |
152            "TRUNCATE" | "GRANT" | "REVOKE" | "MERGE" | "UPSERT" |
153            "BEGIN" | "START" | "COMMIT" | "ROLLBACK" | "SAVEPOINT" |
154            "LOCK" | "PREPARE" | "EXECUTE" | "DEALLOCATE"
155        )
156    }
157
158    /// Select a node from eligible nodes using load balancing
159    fn select_node<'a>(&self, result: &FilterResult<'a>) -> &'a NodeInfo {
160        if result.eligible.is_empty() {
161            panic!("select_node called with no eligible nodes");
162        }
163
164        if result.eligible.len() == 1 {
165            return result.eligible[0];
166        }
167
168        // Simple round-robin for now
169        let idx = self.rr_counter.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
170        let selected_idx = (idx as usize) % result.eligible.len();
171        result.eligible[selected_idx]
172    }
173
174    /// Try to find a fallback node
175    fn try_fallback<'a>(&self, nodes: &'a [NodeInfo], is_write: bool) -> Option<&'a NodeInfo> {
176        if is_write {
177            // For writes, only primary is acceptable
178            nodes.iter().find(|n| n.role == super::node_filter::NodeRole::Primary && n.healthy)
179        } else {
180            // For reads, try any healthy node
181            nodes.iter().find(|n| n.healthy && n.enabled)
182        }
183    }
184
185    /// Strip hints from query for backend execution
186    pub fn strip_hints(&self, query: &str) -> String {
187        if self.config.hints.strip_hints {
188            self.parser.strip(query)
189        } else {
190            query.to_string()
191        }
192    }
193
194    /// Parse hints from query (for external use)
195    pub fn parse_hints(&self, query: &str) -> ParsedHints {
196        self.parser.parse(query)
197    }
198
199    /// Add a node
200    pub async fn add_node(&self, node: NodeInfo) {
201        self.nodes.write().await.push(node);
202    }
203
204    /// Remove a node by name
205    pub async fn remove_node(&self, name: &str) {
206        self.nodes.write().await.retain(|n| n.name != name);
207    }
208
209    /// Update node state
210    pub async fn update_node<F>(&self, name: &str, f: F)
211    where
212        F: FnOnce(&mut NodeInfo),
213    {
214        let mut nodes = self.nodes.write().await;
215        if let Some(node) = nodes.iter_mut().find(|n| n.name == name) {
216            f(node);
217        }
218    }
219
220    /// Get metrics
221    pub fn metrics(&self) -> &RoutingMetrics {
222        &self.metrics
223    }
224
225    /// Get configuration
226    pub fn config(&self) -> &RoutingConfig {
227        &self.config
228    }
229}
230
231/// Routing decision result
232#[derive(Debug, Clone)]
233pub struct RoutingDecision {
234    /// Target node name (None if no node available)
235    pub target_node: Option<String>,
236    /// Parsed hints from query
237    pub hints: ParsedHints,
238    /// Reason for routing decision
239    pub reason: RoutingReason,
240    /// Time taken to make decision
241    pub elapsed: Duration,
242    /// Whether this is a write query
243    pub is_write: bool,
244}
245
246impl RoutingDecision {
247    /// Create an error decision
248    pub fn error(message: String) -> Self {
249        Self {
250            target_node: None,
251            hints: ParsedHints::default(),
252            reason: RoutingReason::Error { message },
253            elapsed: Duration::ZERO,
254            is_write: false,
255        }
256    }
257
258    /// Check if routing succeeded
259    pub fn is_success(&self) -> bool {
260        self.target_node.is_some()
261    }
262
263    /// Get the target node or error
264    pub fn require_target(&self) -> Result<&str> {
265        self.target_node
266            .as_deref()
267            .ok_or_else(|| RoutingError::NoMatchingNodes(self.reason.to_string()))
268    }
269
270    /// Get a summary string
271    pub fn summary(&self) -> String {
272        match &self.reason {
273            RoutingReason::Routed { target, .. } => {
274                format!(
275                    "Routed to {} ({:?}) in {:?}",
276                    self.target_node.as_deref().unwrap_or("unknown"),
277                    target,
278                    self.elapsed
279                )
280            }
281            RoutingReason::Fallback { .. } => {
282                format!(
283                    "Fallback to {} in {:?}",
284                    self.target_node.as_deref().unwrap_or("unknown"),
285                    self.elapsed
286                )
287            }
288            RoutingReason::NoNodes { filters } => {
289                format!("No nodes available (filters: {:?})", filters)
290            }
291            RoutingReason::Error { message } => {
292                format!("Error: {}", message)
293            }
294        }
295    }
296}
297
298/// Reason for routing decision
299#[derive(Debug, Clone)]
300pub enum RoutingReason {
301    /// Successfully routed
302    Routed {
303        target: Option<RouteTarget>,
304        filters_applied: Vec<String>,
305    },
306    /// Fallback used due to no matches
307    Fallback {
308        original_filters: Vec<String>,
309    },
310    /// No nodes available
311    NoNodes {
312        filters: Vec<String>,
313    },
314    /// Error occurred
315    Error {
316        message: String,
317    },
318}
319
320impl std::fmt::Display for RoutingReason {
321    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
322        match self {
323            RoutingReason::Routed { target, .. } => {
324                write!(f, "routed to {:?}", target)
325            }
326            RoutingReason::Fallback { .. } => {
327                write!(f, "fallback")
328            }
329            RoutingReason::NoNodes { filters } => {
330                write!(f, "no nodes ({})", filters.join(", "))
331            }
332            RoutingReason::Error { message } => {
333                write!(f, "error: {}", message)
334            }
335        }
336    }
337}
338
#[cfg(test)]
mod tests {
    use super::*;
    use super::super::node_filter::SyncMode;

    /// Router with one primary, one sync standby and two async standbys
    /// (lag 100 and 200).
    async fn setup_router() -> QueryRouter {
        let router = QueryRouter::new(RoutingConfig::default());

        let fixtures = vec![
            NodeInfo::primary("primary"),
            NodeInfo::standby("standby-sync-1", SyncMode::Sync),
            NodeInfo::standby("standby-async-1", SyncMode::Async).with_lag(100),
            NodeInfo::standby("standby-async-2", SyncMode::Async).with_lag(200),
        ];
        for node in fixtures {
            router.add_node(node).await;
        }

        router
    }

    /// Route `query` on a freshly built fixture router.
    async fn route_on_fixture(query: &str) -> RoutingDecision {
        let router = setup_router().await;
        router.route(query).await
    }

    #[tokio::test]
    async fn test_route_read_query() {
        let decision = route_on_fixture("SELECT * FROM users").await;

        assert!(decision.is_success());
        assert!(!decision.is_write);
    }

    #[tokio::test]
    async fn test_route_write_query() {
        let decision = route_on_fixture("INSERT INTO users (name) VALUES ('test')").await;

        assert!(decision.is_success());
        assert!(decision.is_write);
        assert_eq!(decision.target_node.as_deref(), Some("primary"));
    }

    #[tokio::test]
    async fn test_route_with_primary_hint() {
        let decision = route_on_fixture("/*helios:route=primary*/ SELECT * FROM users").await;

        assert!(decision.is_success());
        assert_eq!(decision.target_node.as_deref(), Some("primary"));
    }

    #[tokio::test]
    async fn test_route_with_sync_hint() {
        let decision = route_on_fixture("/*helios:route=sync*/ SELECT * FROM users").await;

        assert!(decision.is_success());
        assert_eq!(decision.target_node.as_deref(), Some("standby-sync-1"));
    }

    #[tokio::test]
    async fn test_route_with_node_hint() {
        let decision =
            route_on_fixture("/*helios:node=standby-async-1*/ SELECT * FROM users").await;

        assert!(decision.is_success());
        assert_eq!(decision.target_node.as_deref(), Some("standby-async-1"));
    }

    #[tokio::test]
    async fn test_route_with_lag_hint() {
        let decision =
            route_on_fixture("/*helios:route=async,lag=150ms*/ SELECT * FROM users").await;

        assert!(decision.is_success());
        // Only standby-async-1 (100ms lag) is within the requested bound.
        assert_eq!(decision.target_node.as_deref(), Some("standby-async-1"));
    }

    #[tokio::test]
    async fn test_route_no_matching_nodes() {
        let decision =
            route_on_fixture("/*helios:node=nonexistent*/ SELECT * FROM users").await;

        // The named node does not exist, so the fallback path supplies one.
        assert!(decision.is_success());
    }

    #[tokio::test]
    async fn test_is_write_query() {
        let router = QueryRouter::new(RoutingConfig::default());

        let writes = [
            "INSERT INTO users VALUES (1)",
            "UPDATE users SET name = 'test'",
            "DELETE FROM users",
            "CREATE TABLE test (id INT)",
            "BEGIN",
            "COMMIT",
        ];
        for query in writes.iter() {
            assert!(router.is_write_query(query));
        }

        let reads = [
            "SELECT * FROM users",
            "WITH cte AS (SELECT 1) SELECT * FROM cte",
        ];
        for query in reads.iter() {
            assert!(!router.is_write_query(query));
        }
    }

    #[tokio::test]
    async fn test_strip_hints() {
        let router = QueryRouter::new(RoutingConfig::default());

        let stripped = router.strip_hints("/*helios:route=primary*/ SELECT * FROM users");

        assert_eq!(stripped, "SELECT * FROM users");
    }

    #[tokio::test]
    async fn test_invalid_hint_combination() {
        let decision =
            route_on_fixture("/*helios:route=async,consistency=strong*/ SELECT * FROM users")
                .await;

        // The hint combination is invalid, so routing reports an error.
        assert!(!decision.is_success());
    }

    #[tokio::test]
    async fn test_metrics_tracking() {
        let router = setup_router().await;

        // Make some routing decisions.
        let queries = [
            "SELECT * FROM users",
            "/*helios:route=primary*/ SELECT * FROM accounts",
            "INSERT INTO users VALUES (1)",
        ];
        for query in queries.iter() {
            router.route(query).await;
        }

        let stats = router.metrics().snapshot();
        assert!(stats.total_routed >= 3);
        assert!(stats.with_hints >= 1);
    }
}