1use super::{ConsistencyLevel, ParsedHints, Result, RouteTarget, RoutingConfig, RoutingError};
6use std::time::Duration;
7
/// Description of a node considered for routing: identity, replication state,
/// health and placement.
#[derive(Debug, Clone)]
pub struct NodeInfo {
    /// Node name; compared against `NodeCriteria::node_name` and alias members.
    pub name: String,
    /// Role of the node (primary, standby, read replica).
    pub role: NodeRole,
    /// Replication mode of the node (`SyncMode::Primary` for the primary itself).
    pub sync_mode: SyncMode,
    /// Current replication lag in milliseconds.
    pub lag_ms: u64,
    /// Health flag; nodes with `healthy == false` are never eligible.
    pub healthy: bool,
    /// Administrative flag; nodes with `enabled == false` are never eligible.
    pub enabled: bool,
    /// Routing weight (100 from the constructors). Not consulted by this module's
    /// filtering — presumably used by a downstream balancer (TODO confirm).
    pub weight: u32,
    /// Free-form capability tags; the `"vector"` tag marks vector-capable nodes.
    pub tags: Vec<String>,
    /// Zone the node runs in, compared against `NodeFilter`'s local zone.
    pub zone: Option<String>,
}
30
31impl NodeInfo {
32 pub fn primary(name: &str) -> Self {
34 Self {
35 name: name.to_string(),
36 role: NodeRole::Primary,
37 sync_mode: SyncMode::Primary,
38 lag_ms: 0,
39 healthy: true,
40 enabled: true,
41 weight: 100,
42 tags: Vec::new(),
43 zone: None,
44 }
45 }
46
47 pub fn standby(name: &str, sync_mode: SyncMode) -> Self {
49 Self {
50 name: name.to_string(),
51 role: NodeRole::Standby,
52 sync_mode,
53 lag_ms: 0,
54 healthy: true,
55 enabled: true,
56 weight: 100,
57 tags: Vec::new(),
58 zone: None,
59 }
60 }
61
62 pub fn with_lag(mut self, lag_ms: u64) -> Self {
64 self.lag_ms = lag_ms;
65 self
66 }
67
68 pub fn with_tags(mut self, tags: Vec<String>) -> Self {
70 self.tags = tags;
71 self
72 }
73
74 pub fn with_zone(mut self, zone: &str) -> Self {
76 self.zone = Some(zone.to_string());
77 self
78 }
79
80 pub fn has_tag(&self, tag: &str) -> bool {
82 self.tags.iter().any(|t| t == tag)
83 }
84}
85
/// Role a node plays in the cluster.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum NodeRole {
    /// The primary node (built by `NodeInfo::primary`).
    Primary,
    /// A standby node (built by `NodeInfo::standby`); matched by `RouteTarget::Standby`.
    Standby,
    /// A read replica. `NodeFilter` decides `RouteTarget::Standby` by role, so these
    /// nodes never match that target there.
    ReadReplica,
}
93
/// Replication mode of a node, used to match the sync-mode route targets.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SyncMode {
    /// Mode assigned to the primary node itself (see `NodeInfo::primary`).
    Primary,
    /// Synchronous replica; matched by `RouteTarget::Sync`.
    Sync,
    /// Semi-synchronous replica; matched by `RouteTarget::SemiSync`.
    SemiSync,
    /// Asynchronous replica; matched by `RouteTarget::Async`.
    Async,
}
106
107impl SyncMode {
108 pub fn matches_target(&self, target: RouteTarget) -> bool {
110 match target {
111 RouteTarget::Primary => *self == SyncMode::Primary,
112 RouteTarget::Sync => *self == SyncMode::Sync,
113 RouteTarget::SemiSync => *self == SyncMode::SemiSync,
114 RouteTarget::Async => *self == SyncMode::Async,
115 RouteTarget::Standby => {
116 matches!(self, SyncMode::Sync | SyncMode::SemiSync | SyncMode::Async)
117 }
118 RouteTarget::Any => true,
119 RouteTarget::Local => true, RouteTarget::Vector => true, }
122 }
123}
124
/// Narrows a set of candidate nodes down to those matching a query's routing criteria.
#[derive(Debug)]
pub struct NodeFilter {
    /// Routing configuration: consistency policies, aliases and routing defaults.
    config: RoutingConfig,
    /// Zone preferred for `RouteTarget::Local` queries, if one has been set.
    local_zone: Option<String>,
}
133
134impl NodeFilter {
135 pub fn new(config: RoutingConfig) -> Self {
137 Self {
138 config,
139 local_zone: None,
140 }
141 }
142
143 pub fn with_local_zone(mut self, zone: &str) -> Self {
145 self.local_zone = Some(zone.to_string());
146 self
147 }
148
149 pub fn filter<'a>(&self, nodes: &'a [NodeInfo], criteria: &NodeCriteria) -> FilterResult<'a> {
151 let mut eligible: Vec<&NodeInfo> =
152 nodes.iter().filter(|n| n.healthy && n.enabled).collect();
153
154 let mut reasons = Vec::new();
155
156 if let Some(ref name) = criteria.node_name {
158 let count_before = eligible.len();
159 eligible.retain(|n| n.name == *name);
160 if eligible.len() < count_before {
161 reasons.push(format!("Filtered to node: {}", name));
162 }
163 }
164
165 if let Some(target) = criteria.route {
167 let count_before = eligible.len();
168 eligible.retain(|n| self.matches_route_target(n, target));
169 if eligible.len() < count_before {
170 reasons.push(format!("Filtered by route target: {:?}", target));
171 }
172 }
173
174 if let Some(consistency) = criteria.consistency {
176 let count_before = eligible.len();
177 eligible.retain(|n| self.meets_consistency(n, consistency, criteria.max_lag));
178 if eligible.len() < count_before {
179 reasons.push(format!("Filtered by consistency: {:?}", consistency));
180 }
181 }
182
183 if let Some(max_lag) = criteria.max_lag {
185 let count_before = eligible.len();
186 let max_lag_ms = max_lag.as_millis() as u64;
187 eligible.retain(|n| n.lag_ms <= max_lag_ms);
188 if eligible.len() < count_before {
189 reasons.push(format!("Filtered by max lag: {}ms", max_lag_ms));
190 }
191 }
192
193 if !criteria.required_tags.is_empty() {
195 let count_before = eligible.len();
196 eligible.retain(|n| criteria.required_tags.iter().all(|tag| n.has_tag(tag)));
197 if eligible.len() < count_before {
198 reasons.push(format!("Filtered by tags: {:?}", criteria.required_tags));
199 }
200 }
201
202 if criteria.route == Some(RouteTarget::Local) {
204 if let Some(ref local_zone) = self.local_zone {
205 let local_nodes: Vec<_> = eligible
206 .iter()
207 .filter(|n| n.zone.as_ref() == Some(local_zone))
208 .copied()
209 .collect();
210
211 if !local_nodes.is_empty() {
212 eligible = local_nodes;
213 reasons.push(format!("Preferred local zone: {}", local_zone));
214 }
215 }
216 }
217
218 if criteria.route == Some(RouteTarget::Vector) {
220 let vector_nodes: Vec<_> = eligible
221 .iter()
222 .filter(|n| n.has_tag("vector"))
223 .copied()
224 .collect();
225
226 if !vector_nodes.is_empty() {
227 eligible = vector_nodes;
228 reasons.push("Filtered to vector-capable nodes".to_string());
229 }
230 }
231
232 if let Some(ref alias) = criteria.alias {
234 if let Some(alias_nodes) = self.config.resolve_alias(alias) {
235 let count_before = eligible.len();
236 eligible.retain(|n| alias_nodes.contains(&n.name));
237 if eligible.len() < count_before {
238 reasons.push(format!("Resolved alias: {}", alias));
239 }
240 }
241 }
242
243 FilterResult {
244 eligible,
245 reasons,
246 fallback_used: false,
247 }
248 }
249
250 fn matches_route_target(&self, node: &NodeInfo, target: RouteTarget) -> bool {
252 match target {
253 RouteTarget::Primary => node.role == NodeRole::Primary,
254 RouteTarget::Standby => node.role == NodeRole::Standby,
255 RouteTarget::Sync => node.sync_mode == SyncMode::Sync,
256 RouteTarget::SemiSync => node.sync_mode == SyncMode::SemiSync,
257 RouteTarget::Async => node.sync_mode == SyncMode::Async,
258 RouteTarget::Any => true,
259 RouteTarget::Local => true, RouteTarget::Vector => node.has_tag("vector"),
261 }
262 }
263
264 fn meets_consistency(
266 &self,
267 node: &NodeInfo,
268 level: ConsistencyLevel,
269 max_lag: Option<Duration>,
270 ) -> bool {
271 let config = match self.config.get_consistency_config(level) {
272 Some(c) => c,
273 None => return true, };
275
276 if !config.allows_node(&node.name)
278 && !config.allows_node(&format!("{:?}", node.role).to_lowercase())
279 {
280 return false;
281 }
282
283 let max_lag_ms = max_lag
285 .map(|d| d.as_millis() as u64)
286 .unwrap_or(config.max_lag_ms);
287
288 if max_lag_ms < u64::MAX && node.lag_ms > max_lag_ms {
289 return false;
290 }
291
292 true
293 }
294
295 pub fn default_criteria_for_read(&self) -> NodeCriteria {
297 NodeCriteria {
298 route: Some(self.config.default.read_target),
299 consistency: Some(self.config.default.consistency),
300 ..Default::default()
301 }
302 }
303
304 pub fn default_criteria_for_write(&self) -> NodeCriteria {
306 NodeCriteria {
307 route: Some(self.config.default.write_target),
308 consistency: Some(ConsistencyLevel::Strong),
309 ..Default::default()
310 }
311 }
312}
313
/// Requirements a node must satisfy to receive a query.
///
/// A `None` or empty field imposes no constraint.
#[derive(Debug, Clone, Default)]
pub struct NodeCriteria {
    /// Route only to the node with exactly this name.
    pub node_name: Option<String>,
    /// Route target (primary, standby, sync mode, local, vector, ...).
    pub route: Option<RouteTarget>,
    /// Consistency level whose configured policy the node must satisfy.
    pub consistency: Option<ConsistencyLevel>,
    /// Maximum acceptable replication lag.
    pub max_lag: Option<Duration>,
    /// Tags the node must carry (all of them).
    pub required_tags: Vec<String>,
    /// Alias resolved through the routing config to a set of node names.
    pub alias: Option<String>,
    /// Branch hint copied from the query; not consulted by `NodeFilter::filter`.
    pub branch: Option<String>,
}
332
333impl NodeCriteria {
334 pub fn from_hints(hints: &ParsedHints) -> Self {
336 Self {
337 node_name: hints.node.clone(),
338 route: hints.route,
339 consistency: hints.consistency,
340 max_lag: hints.max_lag,
341 required_tags: Vec::new(),
342 alias: None,
343 branch: hints.branch.clone(),
344 }
345 }
346
347 pub fn with_tag(mut self, tag: &str) -> Self {
349 self.required_tags.push(tag.to_string());
350 self
351 }
352
353 pub fn with_alias(mut self, alias: &str) -> Self {
355 self.alias = Some(alias.to_string());
356 self
357 }
358}
359
/// Outcome of [`NodeFilter::filter`]: the surviving nodes and why others were dropped.
#[derive(Debug)]
pub struct FilterResult<'a> {
    /// Nodes that satisfied every criterion, borrowed from the input slice.
    pub eligible: Vec<&'a NodeInfo>,
    /// Human-readable descriptions of the filtering steps that narrowed the set.
    pub reasons: Vec<String>,
    /// Presumably marks results produced by a fallback strategy (TODO confirm);
    /// `NodeFilter::filter` always sets it to `false`.
    pub fallback_used: bool,
}
370
371impl<'a> FilterResult<'a> {
372 pub fn has_matches(&self) -> bool {
374 !self.eligible.is_empty()
375 }
376
377 pub fn count(&self) -> usize {
379 self.eligible.len()
380 }
381
382 pub fn first(&self) -> Option<&'a NodeInfo> {
384 self.eligible.first().copied()
385 }
386
387 pub fn require_match(&self, context: &str) -> Result<&'a NodeInfo> {
389 self.first().ok_or_else(|| {
390 RoutingError::NoMatchingNodes(format!(
391 "{}: reasons: {}",
392 context,
393 self.reasons.join(", ")
394 ))
395 })
396 }
397}
398
#[cfg(test)]
mod tests {
    use super::*;

    /// Fixture: one primary, one sync standby, two async standbys (500ms and 5000ms
    /// lag) and one async standby tagged `"vector"`. All are healthy and enabled.
    fn test_nodes() -> Vec<NodeInfo> {
        vec![
            NodeInfo::primary("primary"),
            NodeInfo::standby("standby-sync-1", SyncMode::Sync),
            NodeInfo::standby("standby-async-1", SyncMode::Async).with_lag(500),
            NodeInfo::standby("standby-async-2", SyncMode::Async).with_lag(5000),
            NodeInfo::standby("standby-vector-1", SyncMode::Async)
                .with_tags(vec!["vector".to_string()]),
        ]
    }

    #[test]
    fn test_filter_by_route_target() {
        let filter = NodeFilter::new(RoutingConfig::default());
        let nodes = test_nodes();

        let criteria = NodeCriteria {
            route: Some(RouteTarget::Primary),
            ..Default::default()
        };
        let result = filter.filter(&nodes, &criteria);
        assert_eq!(result.count(), 1);
        assert_eq!(result.first().unwrap().name, "primary");

        let criteria = NodeCriteria {
            route: Some(RouteTarget::Standby),
            ..Default::default()
        };
        let result = filter.filter(&nodes, &criteria);
        assert_eq!(result.count(), 4);
    }

    #[test]
    fn test_filter_by_sync_mode() {
        let filter = NodeFilter::new(RoutingConfig::default());
        let nodes = test_nodes();

        let criteria = NodeCriteria {
            route: Some(RouteTarget::Sync),
            ..Default::default()
        };
        let result = filter.filter(&nodes, &criteria);
        assert_eq!(result.count(), 1);
        assert_eq!(result.first().unwrap().name, "standby-sync-1");
    }

    #[test]
    fn test_filter_by_max_lag() {
        let filter = NodeFilter::new(RoutingConfig::default());
        let nodes = test_nodes();

        let criteria = NodeCriteria {
            max_lag: Some(Duration::from_millis(1000)),
            ..Default::default()
        };
        let result = filter.filter(&nodes, &criteria);

        // `all` is vacuously true for an empty result, so also pin the count and
        // check that exactly the 5000ms standby was dropped.
        assert_eq!(result.count(), 4);
        assert!(result.eligible.iter().all(|n| n.lag_ms <= 1000));
        assert!(result.eligible.iter().all(|n| n.name != "standby-async-2"));
    }

    #[test]
    fn test_filter_by_node_name() {
        let filter = NodeFilter::new(RoutingConfig::default());
        let nodes = test_nodes();

        let criteria = NodeCriteria {
            node_name: Some("standby-sync-1".to_string()),
            ..Default::default()
        };
        let result = filter.filter(&nodes, &criteria);
        assert_eq!(result.count(), 1);
        assert_eq!(result.first().unwrap().name, "standby-sync-1");
    }

    #[test]
    fn test_filter_by_vector_route() {
        let filter = NodeFilter::new(RoutingConfig::default());
        let nodes = test_nodes();

        let criteria = NodeCriteria {
            route: Some(RouteTarget::Vector),
            ..Default::default()
        };
        let result = filter.filter(&nodes, &criteria);
        assert_eq!(result.count(), 1);
        assert_eq!(result.first().unwrap().name, "standby-vector-1");
    }

    #[test]
    fn test_filter_by_required_tag() {
        let filter = NodeFilter::new(RoutingConfig::default());
        let nodes = test_nodes();

        let criteria = NodeCriteria::default().with_tag("vector");
        let result = filter.filter(&nodes, &criteria);
        assert_eq!(result.count(), 1);
        assert_eq!(result.first().unwrap().name, "standby-vector-1");
    }

    #[test]
    fn test_unhealthy_and_disabled_nodes_excluded() {
        let filter = NodeFilter::new(RoutingConfig::default());
        let mut nodes = test_nodes();
        nodes[0].healthy = false; // primary
        nodes[1].enabled = false; // standby-sync-1

        let result = filter.filter(&nodes, &NodeCriteria::default());
        assert_eq!(result.count(), 3);
        assert!(result.eligible.iter().all(|n| n.healthy && n.enabled));

        let criteria = NodeCriteria {
            route: Some(RouteTarget::Primary),
            ..Default::default()
        };
        assert!(!filter.filter(&nodes, &criteria).has_matches());
    }

    #[test]
    fn test_filter_with_alias() {
        let mut config = RoutingConfig::default();
        config.add_alias(
            "analytics",
            vec!["standby-async-1".to_string(), "standby-async-2".to_string()],
        );

        let filter = NodeFilter::new(config);
        let nodes = test_nodes();

        let criteria = NodeCriteria {
            alias: Some("analytics".to_string()),
            ..Default::default()
        };
        let result = filter.filter(&nodes, &criteria);
        assert_eq!(result.count(), 2);
    }

    #[test]
    fn test_local_zone_preference() {
        let filter = NodeFilter::new(RoutingConfig::default()).with_local_zone("us-west-1");

        let nodes = vec![
            NodeInfo::standby("standby-1", SyncMode::Async).with_zone("us-east-1"),
            NodeInfo::standby("standby-2", SyncMode::Async).with_zone("us-west-1"),
        ];

        let criteria = NodeCriteria {
            route: Some(RouteTarget::Local),
            ..Default::default()
        };
        let result = filter.filter(&nodes, &criteria);
        assert_eq!(result.count(), 1);
        assert_eq!(result.first().unwrap().name, "standby-2");
    }

    #[test]
    fn test_local_zone_falls_back_when_no_local_nodes() {
        let filter = NodeFilter::new(RoutingConfig::default()).with_local_zone("eu-central-1");

        let nodes = vec![
            NodeInfo::standby("standby-1", SyncMode::Async).with_zone("us-east-1"),
            NodeInfo::standby("standby-2", SyncMode::Async).with_zone("us-west-1"),
        ];

        let criteria = NodeCriteria {
            route: Some(RouteTarget::Local),
            ..Default::default()
        };
        let result = filter.filter(&nodes, &criteria);
        // The local-zone preference is soft: with no local node, nothing is dropped.
        assert_eq!(result.count(), 2);
        assert!(result.reasons.is_empty());
    }

    #[test]
    fn test_no_match_error() {
        let filter = NodeFilter::new(RoutingConfig::default());
        let nodes = test_nodes();

        let criteria = NodeCriteria {
            node_name: Some("nonexistent".to_string()),
            ..Default::default()
        };
        let result = filter.filter(&nodes, &criteria);
        assert!(!result.has_matches());

        let err = result.require_match("test context");
        assert!(err.is_err());
    }

    #[test]
    fn test_from_hints() {
        let parser = super::super::HintParser::new();
        let hints = parser.parse("/*helios:route=sync,lag=100ms*/ SELECT 1");

        let criteria = NodeCriteria::from_hints(&hints);
        assert_eq!(criteria.route, Some(RouteTarget::Sync));
        assert_eq!(criteria.max_lag, Some(Duration::from_millis(100)));
    }
}
556}