// heliosdb_proxy/routing/node_filter.rs

//! Node Filter
//!
//! Filters nodes based on routing hints and consistency requirements.

5use super::{ConsistencyLevel, ParsedHints, Result, RouteTarget, RoutingConfig, RoutingError};
6use std::time::Duration;
7
8/// Node information for filtering
9#[derive(Debug, Clone)]
10pub struct NodeInfo {
11    /// Node name/identifier
12    pub name: String,
13    /// Node role (primary/standby)
14    pub role: NodeRole,
15    /// Sync mode
16    pub sync_mode: SyncMode,
17    /// Current replication lag
18    pub lag_ms: u64,
19    /// Is node healthy
20    pub healthy: bool,
21    /// Is node enabled for routing
22    pub enabled: bool,
23    /// Node weight for load balancing
24    pub weight: u32,
25    /// Tags for custom routing (e.g., "vector", "analytics")
26    pub tags: Vec<String>,
27    /// Zone/region for locality routing
28    pub zone: Option<String>,
29}
30
31impl NodeInfo {
32    /// Create a new primary node
33    pub fn primary(name: &str) -> Self {
34        Self {
35            name: name.to_string(),
36            role: NodeRole::Primary,
37            sync_mode: SyncMode::Primary,
38            lag_ms: 0,
39            healthy: true,
40            enabled: true,
41            weight: 100,
42            tags: Vec::new(),
43            zone: None,
44        }
45    }
46
47    /// Create a new standby node
48    pub fn standby(name: &str, sync_mode: SyncMode) -> Self {
49        Self {
50            name: name.to_string(),
51            role: NodeRole::Standby,
52            sync_mode,
53            lag_ms: 0,
54            healthy: true,
55            enabled: true,
56            weight: 100,
57            tags: Vec::new(),
58            zone: None,
59        }
60    }
61
62    /// Set lag
63    pub fn with_lag(mut self, lag_ms: u64) -> Self {
64        self.lag_ms = lag_ms;
65        self
66    }
67
68    /// Set tags
69    pub fn with_tags(mut self, tags: Vec<String>) -> Self {
70        self.tags = tags;
71        self
72    }
73
74    /// Set zone
75    pub fn with_zone(mut self, zone: &str) -> Self {
76        self.zone = Some(zone.to_string());
77        self
78    }
79
80    /// Check if node has a specific tag
81    pub fn has_tag(&self, tag: &str) -> bool {
82        self.tags.iter().any(|t| t == tag)
83    }
84}
85
/// Role a node plays in the cluster.
///
/// The derived `Debug` name, lowercased (e.g. `"readreplica"`), is what the
/// node filter matches against consistency allow-lists, so variant names are
/// significant.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum NodeRole {
    /// The primary node.
    Primary,
    /// A standby node.
    Standby,
    /// A read replica.
    ReadReplica,
}

/// Replication mode of a node; the primary itself uses `SyncMode::Primary`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum SyncMode {
    /// The primary node (not a replica).
    Primary,
    /// Fully synchronous replication.
    Sync,
    /// Semi-synchronous replication.
    SemiSync,
    /// Asynchronous replication.
    Async,
}

107impl SyncMode {
108    /// Check if this mode matches a route target
109    pub fn matches_target(&self, target: RouteTarget) -> bool {
110        match target {
111            RouteTarget::Primary => *self == SyncMode::Primary,
112            RouteTarget::Sync => *self == SyncMode::Sync,
113            RouteTarget::SemiSync => *self == SyncMode::SemiSync,
114            RouteTarget::Async => *self == SyncMode::Async,
115            RouteTarget::Standby => {
116                matches!(self, SyncMode::Sync | SyncMode::SemiSync | SyncMode::Async)
117            }
118            RouteTarget::Any => true,
119            RouteTarget::Local => true,  // Handled separately
120            RouteTarget::Vector => true, // Handled via tags
121        }
122    }
123}
124
125/// Node filter for routing decisions
126#[derive(Debug)]
127pub struct NodeFilter {
128    /// Routing configuration
129    config: RoutingConfig,
130    /// Local zone (for local routing)
131    local_zone: Option<String>,
132}
133
134impl NodeFilter {
135    /// Create a new node filter
136    pub fn new(config: RoutingConfig) -> Self {
137        Self {
138            config,
139            local_zone: None,
140        }
141    }
142
143    /// Set local zone
144    pub fn with_local_zone(mut self, zone: &str) -> Self {
145        self.local_zone = Some(zone.to_string());
146        self
147    }
148
149    /// Filter nodes based on criteria
150    pub fn filter<'a>(&self, nodes: &'a [NodeInfo], criteria: &NodeCriteria) -> FilterResult<'a> {
151        let mut eligible: Vec<&NodeInfo> =
152            nodes.iter().filter(|n| n.healthy && n.enabled).collect();
153
154        let mut reasons = Vec::new();
155
156        // Filter by specific node name
157        if let Some(ref name) = criteria.node_name {
158            let count_before = eligible.len();
159            eligible.retain(|n| n.name == *name);
160            if eligible.len() < count_before {
161                reasons.push(format!("Filtered to node: {}", name));
162            }
163        }
164
165        // Filter by route target
166        if let Some(target) = criteria.route {
167            let count_before = eligible.len();
168            eligible.retain(|n| self.matches_route_target(n, target));
169            if eligible.len() < count_before {
170                reasons.push(format!("Filtered by route target: {:?}", target));
171            }
172        }
173
174        // Filter by consistency level
175        if let Some(consistency) = criteria.consistency {
176            let count_before = eligible.len();
177            eligible.retain(|n| self.meets_consistency(n, consistency, criteria.max_lag));
178            if eligible.len() < count_before {
179                reasons.push(format!("Filtered by consistency: {:?}", consistency));
180            }
181        }
182
183        // Filter by max lag
184        if let Some(max_lag) = criteria.max_lag {
185            let count_before = eligible.len();
186            let max_lag_ms = max_lag.as_millis() as u64;
187            eligible.retain(|n| n.lag_ms <= max_lag_ms);
188            if eligible.len() < count_before {
189                reasons.push(format!("Filtered by max lag: {}ms", max_lag_ms));
190            }
191        }
192
193        // Filter by tags
194        if !criteria.required_tags.is_empty() {
195            let count_before = eligible.len();
196            eligible.retain(|n| criteria.required_tags.iter().all(|tag| n.has_tag(tag)));
197            if eligible.len() < count_before {
198                reasons.push(format!("Filtered by tags: {:?}", criteria.required_tags));
199            }
200        }
201
202        // Handle local routing
203        if criteria.route == Some(RouteTarget::Local) {
204            if let Some(ref local_zone) = self.local_zone {
205                let local_nodes: Vec<_> = eligible
206                    .iter()
207                    .filter(|n| n.zone.as_ref() == Some(local_zone))
208                    .copied()
209                    .collect();
210
211                if !local_nodes.is_empty() {
212                    eligible = local_nodes;
213                    reasons.push(format!("Preferred local zone: {}", local_zone));
214                }
215            }
216        }
217
218        // Handle vector routing
219        if criteria.route == Some(RouteTarget::Vector) {
220            let vector_nodes: Vec<_> = eligible
221                .iter()
222                .filter(|n| n.has_tag("vector"))
223                .copied()
224                .collect();
225
226            if !vector_nodes.is_empty() {
227                eligible = vector_nodes;
228                reasons.push("Filtered to vector-capable nodes".to_string());
229            }
230        }
231
232        // Resolve aliases
233        if let Some(ref alias) = criteria.alias {
234            if let Some(alias_nodes) = self.config.resolve_alias(alias) {
235                let count_before = eligible.len();
236                eligible.retain(|n| alias_nodes.contains(&n.name));
237                if eligible.len() < count_before {
238                    reasons.push(format!("Resolved alias: {}", alias));
239                }
240            }
241        }
242
243        FilterResult {
244            eligible,
245            reasons,
246            fallback_used: false,
247        }
248    }
249
250    /// Check if node matches route target
251    fn matches_route_target(&self, node: &NodeInfo, target: RouteTarget) -> bool {
252        match target {
253            RouteTarget::Primary => node.role == NodeRole::Primary,
254            RouteTarget::Standby => node.role == NodeRole::Standby,
255            RouteTarget::Sync => node.sync_mode == SyncMode::Sync,
256            RouteTarget::SemiSync => node.sync_mode == SyncMode::SemiSync,
257            RouteTarget::Async => node.sync_mode == SyncMode::Async,
258            RouteTarget::Any => true,
259            RouteTarget::Local => true, // Handled in filter()
260            RouteTarget::Vector => node.has_tag("vector"),
261        }
262    }
263
264    /// Check if node meets consistency requirements
265    fn meets_consistency(
266        &self,
267        node: &NodeInfo,
268        level: ConsistencyLevel,
269        max_lag: Option<Duration>,
270    ) -> bool {
271        let config = match self.config.get_consistency_config(level) {
272            Some(c) => c,
273            None => return true, // No config = allow all
274        };
275
276        // Check if node name matches allowed patterns
277        if !config.allows_node(&node.name)
278            && !config.allows_node(&format!("{:?}", node.role).to_lowercase())
279        {
280            return false;
281        }
282
283        // Check lag constraint
284        let max_lag_ms = max_lag
285            .map(|d| d.as_millis() as u64)
286            .unwrap_or(config.max_lag_ms);
287
288        if max_lag_ms < u64::MAX && node.lag_ms > max_lag_ms {
289            return false;
290        }
291
292        true
293    }
294
295    /// Get default criteria for a query type
296    pub fn default_criteria_for_read(&self) -> NodeCriteria {
297        NodeCriteria {
298            route: Some(self.config.default.read_target),
299            consistency: Some(self.config.default.consistency),
300            ..Default::default()
301        }
302    }
303
304    /// Get default criteria for a write
305    pub fn default_criteria_for_write(&self) -> NodeCriteria {
306        NodeCriteria {
307            route: Some(self.config.default.write_target),
308            consistency: Some(ConsistencyLevel::Strong),
309            ..Default::default()
310        }
311    }
312}
313
314/// Criteria for node filtering
315#[derive(Debug, Clone, Default)]
316pub struct NodeCriteria {
317    /// Specific node name
318    pub node_name: Option<String>,
319    /// Route target
320    pub route: Option<RouteTarget>,
321    /// Consistency level
322    pub consistency: Option<ConsistencyLevel>,
323    /// Maximum acceptable lag
324    pub max_lag: Option<Duration>,
325    /// Required tags
326    pub required_tags: Vec<String>,
327    /// Alias to resolve
328    pub alias: Option<String>,
329    /// Branch name (for branch-aware routing)
330    pub branch: Option<String>,
331}
332
333impl NodeCriteria {
334    /// Create criteria from parsed hints
335    pub fn from_hints(hints: &ParsedHints) -> Self {
336        Self {
337            node_name: hints.node.clone(),
338            route: hints.route,
339            consistency: hints.consistency,
340            max_lag: hints.max_lag,
341            required_tags: Vec::new(),
342            alias: None,
343            branch: hints.branch.clone(),
344        }
345    }
346
347    /// Add a required tag
348    pub fn with_tag(mut self, tag: &str) -> Self {
349        self.required_tags.push(tag.to_string());
350        self
351    }
352
353    /// Set alias
354    pub fn with_alias(mut self, alias: &str) -> Self {
355        self.alias = Some(alias.to_string());
356        self
357    }
358}
359
360/// Result of node filtering
361#[derive(Debug)]
362pub struct FilterResult<'a> {
363    /// Eligible nodes after filtering
364    pub eligible: Vec<&'a NodeInfo>,
365    /// Reasons for filtering decisions
366    pub reasons: Vec<String>,
367    /// Whether fallback was used
368    pub fallback_used: bool,
369}
370
371impl<'a> FilterResult<'a> {
372    /// Check if any nodes match
373    pub fn has_matches(&self) -> bool {
374        !self.eligible.is_empty()
375    }
376
377    /// Get number of matches
378    pub fn count(&self) -> usize {
379        self.eligible.len()
380    }
381
382    /// Get first match
383    pub fn first(&self) -> Option<&'a NodeInfo> {
384        self.eligible.first().copied()
385    }
386
387    /// Convert to error if no matches
388    pub fn require_match(&self, context: &str) -> Result<&'a NodeInfo> {
389        self.first().ok_or_else(|| {
390            RoutingError::NoMatchingNodes(format!(
391                "{}: reasons: {}",
392                context,
393                self.reasons.join(", ")
394            ))
395        })
396    }
397}
398
#[cfg(test)]
mod tests {
    use super::*;

    /// Fixture: one primary plus four standbys with assorted sync modes, lags
    /// and tags. All nodes are healthy and enabled.
    fn test_nodes() -> Vec<NodeInfo> {
        vec![
            NodeInfo::primary("primary"),
            NodeInfo::standby("standby-sync-1", SyncMode::Sync),
            NodeInfo::standby("standby-async-1", SyncMode::Async).with_lag(500),
            NodeInfo::standby("standby-async-2", SyncMode::Async).with_lag(5000),
            NodeInfo::standby("standby-vector-1", SyncMode::Async)
                .with_tags(vec!["vector".to_string()]),
        ]
    }

    #[test]
    fn test_filter_by_route_target() {
        let filter = NodeFilter::new(RoutingConfig::default());
        let nodes = test_nodes();

        // Filter for primary
        let criteria = NodeCriteria {
            route: Some(RouteTarget::Primary),
            ..Default::default()
        };
        let result = filter.filter(&nodes, &criteria);
        assert_eq!(result.count(), 1);
        assert_eq!(result.first().unwrap().name, "primary");

        // Filter for any standby
        let criteria = NodeCriteria {
            route: Some(RouteTarget::Standby),
            ..Default::default()
        };
        let result = filter.filter(&nodes, &criteria);
        assert_eq!(result.count(), 4);
    }

    #[test]
    fn test_filter_by_sync_mode() {
        let filter = NodeFilter::new(RoutingConfig::default());
        let nodes = test_nodes();

        let criteria = NodeCriteria {
            route: Some(RouteTarget::Sync),
            ..Default::default()
        };
        let result = filter.filter(&nodes, &criteria);
        assert_eq!(result.count(), 1);
        assert_eq!(result.first().unwrap().name, "standby-sync-1");
    }

    #[test]
    fn test_filter_by_max_lag() {
        let filter = NodeFilter::new(RoutingConfig::default());
        let nodes = test_nodes();

        let criteria = NodeCriteria {
            max_lag: Some(Duration::from_millis(1000)),
            ..Default::default()
        };
        let result = filter.filter(&nodes, &criteria);

        // Only standby-async-2 (5000ms lag) exceeds the bound. Pin the count:
        // `all(..)` alone would pass vacuously on an empty result.
        assert_eq!(result.count(), 4);
        assert!(result.eligible.iter().all(|n| n.lag_ms <= 1000));
        assert!(!result.eligible.iter().any(|n| n.name == "standby-async-2"));
    }

    #[test]
    fn test_filter_by_node_name() {
        let filter = NodeFilter::new(RoutingConfig::default());
        let nodes = test_nodes();

        let criteria = NodeCriteria {
            node_name: Some("standby-sync-1".to_string()),
            ..Default::default()
        };
        let result = filter.filter(&nodes, &criteria);
        assert_eq!(result.count(), 1);
        assert_eq!(result.first().unwrap().name, "standby-sync-1");
    }

    #[test]
    fn test_filter_by_tag() {
        let filter = NodeFilter::new(RoutingConfig::default());
        let nodes = test_nodes();

        let criteria = NodeCriteria {
            route: Some(RouteTarget::Vector),
            ..Default::default()
        };
        let result = filter.filter(&nodes, &criteria);
        assert_eq!(result.count(), 1);
        assert_eq!(result.first().unwrap().name, "standby-vector-1");
    }

    #[test]
    fn test_filter_with_alias() {
        let mut config = RoutingConfig::default();
        config.add_alias(
            "analytics",
            vec!["standby-async-1".to_string(), "standby-async-2".to_string()],
        );

        let filter = NodeFilter::new(config);
        let nodes = test_nodes();

        let criteria = NodeCriteria {
            alias: Some("analytics".to_string()),
            ..Default::default()
        };
        let result = filter.filter(&nodes, &criteria);
        assert_eq!(result.count(), 2);

        // The survivors must be exactly the alias members, in fixture order.
        let names: Vec<&str> = result.eligible.iter().map(|n| n.name.as_str()).collect();
        assert_eq!(names, vec!["standby-async-1", "standby-async-2"]);
    }

    #[test]
    fn test_local_zone_preference() {
        let filter = NodeFilter::new(RoutingConfig::default()).with_local_zone("us-west-1");

        let nodes = vec![
            NodeInfo::standby("standby-1", SyncMode::Async).with_zone("us-east-1"),
            NodeInfo::standby("standby-2", SyncMode::Async).with_zone("us-west-1"),
        ];

        let criteria = NodeCriteria {
            route: Some(RouteTarget::Local),
            ..Default::default()
        };
        let result = filter.filter(&nodes, &criteria);
        assert_eq!(result.count(), 1);
        assert_eq!(result.first().unwrap().name, "standby-2");
    }

    #[test]
    fn test_no_match_error() {
        let filter = NodeFilter::new(RoutingConfig::default());
        let nodes = test_nodes();

        let criteria = NodeCriteria {
            node_name: Some("nonexistent".to_string()),
            ..Default::default()
        };
        let result = filter.filter(&nodes, &criteria);
        assert!(!result.has_matches());

        let err = result.require_match("test context");
        assert!(err.is_err());

        // The error should carry the caller's context and the filtering reasons.
        if let Err(RoutingError::NoMatchingNodes(msg)) = err {
            assert!(msg.starts_with("test context"));
            assert!(msg.contains("Filtered to node: nonexistent"));
        } else {
            panic!("expected RoutingError::NoMatchingNodes");
        }
    }

    #[test]
    fn test_from_hints() {
        let parser = super::super::HintParser::new();
        let hints = parser.parse("/*helios:route=sync,lag=100ms*/ SELECT 1");

        let criteria = NodeCriteria::from_hints(&hints);
        assert_eq!(criteria.route, Some(RouteTarget::Sync));
        assert_eq!(criteria.max_lag, Some(Duration::from_millis(100)));
    }
}