Skip to main content

heliosdb_proxy/routing/
node_filter.rs

1//! Node Filter
2//!
3//! Filters nodes based on routing hints and consistency requirements.
4
5use super::{
6    ConsistencyLevel, ParsedHints, RouteTarget, RoutingConfig, RoutingError, Result,
7};
8use std::time::Duration;
9
10/// Node information for filtering
11#[derive(Debug, Clone)]
12pub struct NodeInfo {
13    /// Node name/identifier
14    pub name: String,
15    /// Node role (primary/standby)
16    pub role: NodeRole,
17    /// Sync mode
18    pub sync_mode: SyncMode,
19    /// Current replication lag
20    pub lag_ms: u64,
21    /// Is node healthy
22    pub healthy: bool,
23    /// Is node enabled for routing
24    pub enabled: bool,
25    /// Node weight for load balancing
26    pub weight: u32,
27    /// Tags for custom routing (e.g., "vector", "analytics")
28    pub tags: Vec<String>,
29    /// Zone/region for locality routing
30    pub zone: Option<String>,
31}
32
33impl NodeInfo {
34    /// Create a new primary node
35    pub fn primary(name: &str) -> Self {
36        Self {
37            name: name.to_string(),
38            role: NodeRole::Primary,
39            sync_mode: SyncMode::Primary,
40            lag_ms: 0,
41            healthy: true,
42            enabled: true,
43            weight: 100,
44            tags: Vec::new(),
45            zone: None,
46        }
47    }
48
49    /// Create a new standby node
50    pub fn standby(name: &str, sync_mode: SyncMode) -> Self {
51        Self {
52            name: name.to_string(),
53            role: NodeRole::Standby,
54            sync_mode,
55            lag_ms: 0,
56            healthy: true,
57            enabled: true,
58            weight: 100,
59            tags: Vec::new(),
60            zone: None,
61        }
62    }
63
64    /// Set lag
65    pub fn with_lag(mut self, lag_ms: u64) -> Self {
66        self.lag_ms = lag_ms;
67        self
68    }
69
70    /// Set tags
71    pub fn with_tags(mut self, tags: Vec<String>) -> Self {
72        self.tags = tags;
73        self
74    }
75
76    /// Set zone
77    pub fn with_zone(mut self, zone: &str) -> Self {
78        self.zone = Some(zone.to_string());
79        self
80    }
81
82    /// Check if node has a specific tag
83    pub fn has_tag(&self, tag: &str) -> bool {
84        self.tags.iter().any(|t| t == tag)
85    }
86}
87
/// Role a node plays in the cluster.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum NodeRole {
    /// The primary node.
    Primary,
    /// A standby node.
    Standby,
    /// A read replica.
    ReadReplica,
}
95
/// Replication mode of a node.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SyncMode {
    /// The node is the primary itself.
    Primary,
    /// Fully synchronous replication.
    Sync,
    /// Semi-synchronous replication.
    SemiSync,
    /// Asynchronous replication.
    Async,
}
108
109impl SyncMode {
110    /// Check if this mode matches a route target
111    pub fn matches_target(&self, target: RouteTarget) -> bool {
112        match target {
113            RouteTarget::Primary => *self == SyncMode::Primary,
114            RouteTarget::Sync => *self == SyncMode::Sync,
115            RouteTarget::SemiSync => *self == SyncMode::SemiSync,
116            RouteTarget::Async => *self == SyncMode::Async,
117            RouteTarget::Standby => matches!(self, SyncMode::Sync | SyncMode::SemiSync | SyncMode::Async),
118            RouteTarget::Any => true,
119            RouteTarget::Local => true, // Handled separately
120            RouteTarget::Vector => true, // Handled via tags
121        }
122    }
123}
124
125/// Node filter for routing decisions
126#[derive(Debug)]
127pub struct NodeFilter {
128    /// Routing configuration
129    config: RoutingConfig,
130    /// Local zone (for local routing)
131    local_zone: Option<String>,
132}
133
134impl NodeFilter {
135    /// Create a new node filter
136    pub fn new(config: RoutingConfig) -> Self {
137        Self {
138            config,
139            local_zone: None,
140        }
141    }
142
143    /// Set local zone
144    pub fn with_local_zone(mut self, zone: &str) -> Self {
145        self.local_zone = Some(zone.to_string());
146        self
147    }
148
149    /// Filter nodes based on criteria
150    pub fn filter<'a>(
151        &self,
152        nodes: &'a [NodeInfo],
153        criteria: &NodeCriteria,
154    ) -> FilterResult<'a> {
155        let mut eligible: Vec<&NodeInfo> = nodes.iter()
156            .filter(|n| n.healthy && n.enabled)
157            .collect();
158
159        let mut reasons = Vec::new();
160
161        // Filter by specific node name
162        if let Some(ref name) = criteria.node_name {
163            let count_before = eligible.len();
164            eligible.retain(|n| n.name == *name);
165            if eligible.len() < count_before {
166                reasons.push(format!("Filtered to node: {}", name));
167            }
168        }
169
170        // Filter by route target
171        if let Some(target) = criteria.route {
172            let count_before = eligible.len();
173            eligible.retain(|n| self.matches_route_target(n, target));
174            if eligible.len() < count_before {
175                reasons.push(format!("Filtered by route target: {:?}", target));
176            }
177        }
178
179        // Filter by consistency level
180        if let Some(consistency) = criteria.consistency {
181            let count_before = eligible.len();
182            eligible.retain(|n| self.meets_consistency(n, consistency, criteria.max_lag));
183            if eligible.len() < count_before {
184                reasons.push(format!("Filtered by consistency: {:?}", consistency));
185            }
186        }
187
188        // Filter by max lag
189        if let Some(max_lag) = criteria.max_lag {
190            let count_before = eligible.len();
191            let max_lag_ms = max_lag.as_millis() as u64;
192            eligible.retain(|n| n.lag_ms <= max_lag_ms);
193            if eligible.len() < count_before {
194                reasons.push(format!("Filtered by max lag: {}ms", max_lag_ms));
195            }
196        }
197
198        // Filter by tags
199        if !criteria.required_tags.is_empty() {
200            let count_before = eligible.len();
201            eligible.retain(|n| criteria.required_tags.iter().all(|tag| n.has_tag(tag)));
202            if eligible.len() < count_before {
203                reasons.push(format!("Filtered by tags: {:?}", criteria.required_tags));
204            }
205        }
206
207        // Handle local routing
208        if criteria.route == Some(RouteTarget::Local) {
209            if let Some(ref local_zone) = self.local_zone {
210                let local_nodes: Vec<_> = eligible.iter()
211                    .filter(|n| n.zone.as_ref() == Some(local_zone))
212                    .copied()
213                    .collect();
214
215                if !local_nodes.is_empty() {
216                    eligible = local_nodes;
217                    reasons.push(format!("Preferred local zone: {}", local_zone));
218                }
219            }
220        }
221
222        // Handle vector routing
223        if criteria.route == Some(RouteTarget::Vector) {
224            let vector_nodes: Vec<_> = eligible.iter()
225                .filter(|n| n.has_tag("vector"))
226                .copied()
227                .collect();
228
229            if !vector_nodes.is_empty() {
230                eligible = vector_nodes;
231                reasons.push("Filtered to vector-capable nodes".to_string());
232            }
233        }
234
235        // Resolve aliases
236        if let Some(ref alias) = criteria.alias {
237            if let Some(alias_nodes) = self.config.resolve_alias(alias) {
238                let count_before = eligible.len();
239                eligible.retain(|n| alias_nodes.contains(&n.name));
240                if eligible.len() < count_before {
241                    reasons.push(format!("Resolved alias: {}", alias));
242                }
243            }
244        }
245
246        FilterResult {
247            eligible,
248            reasons,
249            fallback_used: false,
250        }
251    }
252
253    /// Check if node matches route target
254    fn matches_route_target(&self, node: &NodeInfo, target: RouteTarget) -> bool {
255        match target {
256            RouteTarget::Primary => node.role == NodeRole::Primary,
257            RouteTarget::Standby => node.role == NodeRole::Standby,
258            RouteTarget::Sync => node.sync_mode == SyncMode::Sync,
259            RouteTarget::SemiSync => node.sync_mode == SyncMode::SemiSync,
260            RouteTarget::Async => node.sync_mode == SyncMode::Async,
261            RouteTarget::Any => true,
262            RouteTarget::Local => true, // Handled in filter()
263            RouteTarget::Vector => node.has_tag("vector"),
264        }
265    }
266
267    /// Check if node meets consistency requirements
268    fn meets_consistency(&self, node: &NodeInfo, level: ConsistencyLevel, max_lag: Option<Duration>) -> bool {
269        let config = match self.config.get_consistency_config(level) {
270            Some(c) => c,
271            None => return true, // No config = allow all
272        };
273
274        // Check if node name matches allowed patterns
275        if !config.allows_node(&node.name) && !config.allows_node(&format!("{:?}", node.role).to_lowercase()) {
276            return false;
277        }
278
279        // Check lag constraint
280        let max_lag_ms = max_lag
281            .map(|d| d.as_millis() as u64)
282            .unwrap_or(config.max_lag_ms);
283
284        if max_lag_ms < u64::MAX && node.lag_ms > max_lag_ms {
285            return false;
286        }
287
288        true
289    }
290
291    /// Get default criteria for a query type
292    pub fn default_criteria_for_read(&self) -> NodeCriteria {
293        NodeCriteria {
294            route: Some(self.config.default.read_target),
295            consistency: Some(self.config.default.consistency),
296            ..Default::default()
297        }
298    }
299
300    /// Get default criteria for a write
301    pub fn default_criteria_for_write(&self) -> NodeCriteria {
302        NodeCriteria {
303            route: Some(self.config.default.write_target),
304            consistency: Some(ConsistencyLevel::Strong),
305            ..Default::default()
306        }
307    }
308}
309
310/// Criteria for node filtering
311#[derive(Debug, Clone, Default)]
312pub struct NodeCriteria {
313    /// Specific node name
314    pub node_name: Option<String>,
315    /// Route target
316    pub route: Option<RouteTarget>,
317    /// Consistency level
318    pub consistency: Option<ConsistencyLevel>,
319    /// Maximum acceptable lag
320    pub max_lag: Option<Duration>,
321    /// Required tags
322    pub required_tags: Vec<String>,
323    /// Alias to resolve
324    pub alias: Option<String>,
325    /// Branch name (for branch-aware routing)
326    pub branch: Option<String>,
327}
328
329impl NodeCriteria {
330    /// Create criteria from parsed hints
331    pub fn from_hints(hints: &ParsedHints) -> Self {
332        Self {
333            node_name: hints.node.clone(),
334            route: hints.route,
335            consistency: hints.consistency,
336            max_lag: hints.max_lag,
337            required_tags: Vec::new(),
338            alias: None,
339            branch: hints.branch.clone(),
340        }
341    }
342
343    /// Add a required tag
344    pub fn with_tag(mut self, tag: &str) -> Self {
345        self.required_tags.push(tag.to_string());
346        self
347    }
348
349    /// Set alias
350    pub fn with_alias(mut self, alias: &str) -> Self {
351        self.alias = Some(alias.to_string());
352        self
353    }
354}
355
356/// Result of node filtering
357#[derive(Debug)]
358pub struct FilterResult<'a> {
359    /// Eligible nodes after filtering
360    pub eligible: Vec<&'a NodeInfo>,
361    /// Reasons for filtering decisions
362    pub reasons: Vec<String>,
363    /// Whether fallback was used
364    pub fallback_used: bool,
365}
366
367impl<'a> FilterResult<'a> {
368    /// Check if any nodes match
369    pub fn has_matches(&self) -> bool {
370        !self.eligible.is_empty()
371    }
372
373    /// Get number of matches
374    pub fn count(&self) -> usize {
375        self.eligible.len()
376    }
377
378    /// Get first match
379    pub fn first(&self) -> Option<&'a NodeInfo> {
380        self.eligible.first().copied()
381    }
382
383    /// Convert to error if no matches
384    pub fn require_match(&self, context: &str) -> Result<&'a NodeInfo> {
385        self.first().ok_or_else(|| {
386            RoutingError::NoMatchingNodes(format!(
387                "{}: reasons: {}",
388                context,
389                self.reasons.join(", ")
390            ))
391        })
392    }
393}
394
#[cfg(test)]
mod tests {
    use super::*;

    /// Fixture: one primary and four standbys (one sync, two async with 500ms
    /// and 5000ms lag, one async tagged "vector"). All are healthy and enabled.
    fn test_nodes() -> Vec<NodeInfo> {
        vec![
            NodeInfo::primary("primary"),
            NodeInfo::standby("standby-sync-1", SyncMode::Sync),
            NodeInfo::standby("standby-async-1", SyncMode::Async).with_lag(500),
            NodeInfo::standby("standby-async-2", SyncMode::Async).with_lag(5000),
            NodeInfo::standby("standby-vector-1", SyncMode::Async)
                .with_tags(vec!["vector".to_string()]),
        ]
    }

    #[test]
    fn test_filter_by_route_target() {
        let filter = NodeFilter::new(RoutingConfig::default());
        let nodes = test_nodes();

        // Filter for primary
        let criteria = NodeCriteria {
            route: Some(RouteTarget::Primary),
            ..Default::default()
        };
        let result = filter.filter(&nodes, &criteria);
        assert_eq!(result.count(), 1);
        assert_eq!(result.first().unwrap().name, "primary");

        // Filter for any standby
        let criteria = NodeCriteria {
            route: Some(RouteTarget::Standby),
            ..Default::default()
        };
        let result = filter.filter(&nodes, &criteria);
        assert_eq!(result.count(), 4);
    }

    #[test]
    fn test_filter_by_sync_mode() {
        let filter = NodeFilter::new(RoutingConfig::default());
        let nodes = test_nodes();

        let criteria = NodeCriteria {
            route: Some(RouteTarget::Sync),
            ..Default::default()
        };
        let result = filter.filter(&nodes, &criteria);
        assert_eq!(result.count(), 1);
        assert_eq!(result.first().unwrap().name, "standby-sync-1");
    }

    #[test]
    fn test_filter_by_max_lag() {
        let filter = NodeFilter::new(RoutingConfig::default());
        let nodes = test_nodes();

        let criteria = NodeCriteria {
            max_lag: Some(Duration::from_millis(1000)),
            ..Default::default()
        };
        let result = filter.filter(&nodes, &criteria);

        // Exactly standby-async-2 (5000ms lag) must be dropped. Checking the
        // count keeps the test from passing vacuously on an empty result.
        assert_eq!(result.count(), 4);
        assert!(result.eligible.iter().all(|n| n.lag_ms <= 1000));
        assert!(result.eligible.iter().all(|n| n.name != "standby-async-2"));
        assert!(result.reasons.iter().any(|r| r.contains("max lag")));
    }

    #[test]
    fn test_filter_by_node_name() {
        let filter = NodeFilter::new(RoutingConfig::default());
        let nodes = test_nodes();

        let criteria = NodeCriteria {
            node_name: Some("standby-sync-1".to_string()),
            ..Default::default()
        };
        let result = filter.filter(&nodes, &criteria);
        assert_eq!(result.count(), 1);
        assert_eq!(result.first().unwrap().name, "standby-sync-1");
    }

    #[test]
    fn test_filter_by_tag() {
        let filter = NodeFilter::new(RoutingConfig::default());
        let nodes = test_nodes();

        let criteria = NodeCriteria {
            route: Some(RouteTarget::Vector),
            ..Default::default()
        };
        let result = filter.filter(&nodes, &criteria);
        assert_eq!(result.count(), 1);
        assert_eq!(result.first().unwrap().name, "standby-vector-1");
    }

    #[test]
    fn test_filter_by_required_tags() {
        let filter = NodeFilter::new(RoutingConfig::default());
        let nodes = test_nodes();

        // A tag carried by exactly one node keeps only that node.
        let criteria = NodeCriteria::default().with_tag("vector");
        let result = filter.filter(&nodes, &criteria);
        assert_eq!(result.count(), 1);
        assert_eq!(result.first().unwrap().name, "standby-vector-1");

        // Every required tag must be present, so an unknown extra tag
        // eliminates all candidates.
        let criteria = NodeCriteria::default()
            .with_tag("vector")
            .with_tag("analytics");
        let result = filter.filter(&nodes, &criteria);
        assert!(!result.has_matches());
    }

    #[test]
    fn test_filter_with_alias() {
        let mut config = RoutingConfig::default();
        config.add_alias("analytics", vec![
            "standby-async-1".to_string(),
            "standby-async-2".to_string(),
        ]);

        let filter = NodeFilter::new(config);
        let nodes = test_nodes();

        let criteria = NodeCriteria {
            alias: Some("analytics".to_string()),
            ..Default::default()
        };
        let result = filter.filter(&nodes, &criteria);
        assert_eq!(result.count(), 2);
    }

    #[test]
    fn test_local_zone_preference() {
        let filter = NodeFilter::new(RoutingConfig::default())
            .with_local_zone("us-west-1");

        let nodes = vec![
            NodeInfo::standby("standby-1", SyncMode::Async).with_zone("us-east-1"),
            NodeInfo::standby("standby-2", SyncMode::Async).with_zone("us-west-1"),
        ];

        let criteria = NodeCriteria {
            route: Some(RouteTarget::Local),
            ..Default::default()
        };
        let result = filter.filter(&nodes, &criteria);
        assert_eq!(result.count(), 1);
        assert_eq!(result.first().unwrap().name, "standby-2");
    }

    #[test]
    fn test_no_match_error() {
        let filter = NodeFilter::new(RoutingConfig::default());
        let nodes = test_nodes();

        let criteria = NodeCriteria {
            node_name: Some("nonexistent".to_string()),
            ..Default::default()
        };
        let result = filter.filter(&nodes, &criteria);
        assert!(!result.has_matches());

        let err = result.require_match("test context");
        assert!(err.is_err());
    }

    #[test]
    fn test_from_hints() {
        let parser = super::super::HintParser::new();
        let hints = parser.parse("/*helios:route=sync,lag=100ms*/ SELECT 1");

        let criteria = NodeCriteria::from_hints(&hints);
        assert_eq!(criteria.route, Some(RouteTarget::Sync));
        assert_eq!(criteria.max_lag, Some(Duration::from_millis(100)));
    }
}