1use super::{parse_duration, Result, RoutingError};
14use once_cell::sync::Lazy;
15use regex::Regex;
16use std::str::FromStr;
17use std::time::Duration;
18
19#[cfg(feature = "pool-modes")]
20use crate::pool::PoolingMode;
21
22static HINT_REGEX: Lazy<Regex> =
24 Lazy::new(|| Regex::new(r"/\*\s*helios:([^*]+)\*/").expect("Invalid hint regex"));
25
26static KV_REGEX: Lazy<Regex> =
28 Lazy::new(|| Regex::new(r"(\w+)\s*=\s*([^,\s]+)").expect("Invalid key-value regex"));
29
/// Parses `/*helios:key=value,...*/` routing-hint comments out of SQL text.
///
/// NOTE(review): the derived `Default` sets `strip_hints` to `false`, whereas
/// `HintParser::new()` sets it to `true`; confirm which default callers expect.
#[derive(Debug, Clone, Default)]
pub struct HintParser {
    /// Whether hint comments should be removed from the query. Not read by the parsing
    /// or stripping methods themselves — presumably consulted by callers (verify).
    pub strip_hints: bool,
}
36
37impl HintParser {
38 pub fn new() -> Self {
40 Self { strip_hints: true }
41 }
42
43 pub fn without_stripping() -> Self {
45 Self { strip_hints: false }
46 }
47
48 pub fn parse(&self, query: &str) -> ParsedHints {
50 let mut hints = ParsedHints::default();
51
52 for cap in HINT_REGEX.captures_iter(query) {
53 let hint_content = cap.get(1).map(|m| m.as_str()).unwrap_or("");
54
55 for kv in KV_REGEX.captures_iter(hint_content) {
57 let key = kv.get(1).map(|m| m.as_str()).unwrap_or("");
58 let value = kv.get(2).map(|m| m.as_str()).unwrap_or("");
59
60 if let Some(hint) = self.parse_hint(key, value) {
61 hints.add(hint);
62 }
63 }
64 }
65
66 hints
67 }
68
69 fn parse_hint(&self, key: &str, value: &str) -> Option<RoutingHint> {
71 match key.to_lowercase().as_str() {
72 "route" => RouteTarget::from_str(value).ok().map(RoutingHint::Route),
73 "node" => Some(RoutingHint::Node(value.to_string())),
74 "consistency" => ConsistencyLevel::from_str(value)
75 .ok()
76 .map(RoutingHint::Consistency),
77 "pool" => PoolingModeHint::from_str(value).ok().map(RoutingHint::Pool),
78 "cache" => CacheBehavior::from_str(value).ok().map(RoutingHint::Cache),
79 "timeout" => parse_duration(value).map(RoutingHint::Timeout),
80 "priority" => QueryPriority::from_str(value)
81 .ok()
82 .map(RoutingHint::Priority),
83 "lag" => parse_duration(value).map(RoutingHint::MaxLag),
84 "retry" => self.parse_retry(value).map(RoutingHint::Retry),
85 "branch" => Some(RoutingHint::Branch(value.to_string())),
86 "twr" => value
87 .parse::<bool>()
88 .ok()
89 .map(RoutingHint::TransparentWriteRouting),
90 "tool" => Some(RoutingHint::AgentTool(value.to_string())),
91 "workflow" => Some(RoutingHint::WorkflowStep(value.to_string())),
92 "prefetch" => value.parse::<bool>().ok().map(RoutingHint::Prefetch),
93 "cache_ttl" => value
94 .parse::<u64>()
95 .ok()
96 .map(|s| RoutingHint::CacheTtl(Duration::from_secs(s))),
97 _ => None,
98 }
99 }
100
101 fn parse_retry(&self, value: &str) -> Option<RetryBehavior> {
103 match value.to_lowercase().as_str() {
104 "true" | "yes" => Some(RetryBehavior::Auto),
105 "false" | "no" => Some(RetryBehavior::None),
106 _ => value.parse::<u32>().ok().map(RetryBehavior::Count),
107 }
108 }
109
110 pub fn strip(&self, query: &str) -> String {
112 HINT_REGEX.replace_all(query, "").trim().to_string()
113 }
114
115 pub fn extract_raw(&self, query: &str) -> Vec<String> {
117 HINT_REGEX
118 .captures_iter(query)
119 .filter_map(|cap| cap.get(0).map(|m| m.as_str().to_string()))
120 .collect()
121 }
122}
123
/// The result of [`HintParser::parse`]: every recognised hint plus typed per-kind fields.
///
/// Each `pub` field holds the most recently added hint of its kind (see `ParsedHints::add`);
/// `None` means that hint was not given.
#[derive(Debug, Clone, Default)]
pub struct ParsedHints {
    /// Every hint in the order it was added, including repeats and the kinds that have
    /// no dedicated field below (agent tool, workflow step, prefetch).
    hints: Vec<RoutingHint>,
    /// From `route=`.
    pub route: Option<RouteTarget>,
    /// From `node=`.
    pub node: Option<String>,
    /// From `consistency=`.
    pub consistency: Option<ConsistencyLevel>,
    /// From `pool=`.
    pub pool: Option<PoolingModeHint>,
    /// From `cache=`.
    pub cache: Option<CacheBehavior>,
    /// From `timeout=`.
    pub timeout: Option<Duration>,
    /// From `priority=`.
    pub priority: Option<QueryPriority>,
    /// From `lag=`.
    pub max_lag: Option<Duration>,
    /// From `retry=`.
    pub retry: Option<RetryBehavior>,
    /// From `branch=`.
    pub branch: Option<String>,
    /// From `twr=` (transparent write routing).
    pub twr: Option<bool>,
    /// From `cache_ttl=`.
    pub cache_ttl: Option<Duration>,
}
154
155impl ParsedHints {
156 pub fn add(&mut self, hint: RoutingHint) {
158 match &hint {
159 RoutingHint::Route(target) => self.route = Some(*target),
160 RoutingHint::Node(name) => self.node = Some(name.clone()),
161 RoutingHint::Consistency(level) => self.consistency = Some(*level),
162 RoutingHint::Pool(mode) => self.pool = Some(*mode),
163 RoutingHint::Cache(behavior) => self.cache = Some(*behavior),
164 RoutingHint::Timeout(dur) => self.timeout = Some(*dur),
165 RoutingHint::Priority(pri) => self.priority = Some(*pri),
166 RoutingHint::MaxLag(dur) => self.max_lag = Some(*dur),
167 RoutingHint::Retry(retry) => self.retry = Some(retry.clone()),
168 RoutingHint::Branch(name) => self.branch = Some(name.clone()),
169 RoutingHint::TransparentWriteRouting(enabled) => self.twr = Some(*enabled),
170 RoutingHint::CacheTtl(dur) => self.cache_ttl = Some(*dur),
171 _ => {}
172 }
173 self.hints.push(hint);
174 }
175
176 pub fn is_empty(&self) -> bool {
178 self.hints.is_empty()
179 }
180
181 pub fn len(&self) -> usize {
183 self.hints.len()
184 }
185
186 pub fn hints(&self) -> &[RoutingHint] {
188 &self.hints
189 }
190
191 pub fn is_primary_route(&self) -> bool {
193 matches!(self.route, Some(RouteTarget::Primary))
194 }
195
196 pub fn is_standby_route(&self) -> bool {
198 matches!(
199 self.route,
200 Some(RouteTarget::Standby)
201 | Some(RouteTarget::Sync)
202 | Some(RouteTarget::SemiSync)
203 | Some(RouteTarget::Async)
204 )
205 }
206
207 pub fn validate(&self) -> Result<()> {
209 if let (Some(RouteTarget::Async), Some(ConsistencyLevel::Strong)) =
211 (self.route, self.consistency)
212 {
213 return Err(RoutingError::InvalidHintCombination(
214 "route=async and consistency=strong are incompatible".to_string(),
215 ));
216 }
217
218 if self.consistency == Some(ConsistencyLevel::Bounded) && self.max_lag.is_none() {
220 }
222
223 Ok(())
224 }
225}
226
/// One typed routing hint, produced by [`HintParser`] from a `key=value` pair inside a
/// `/*helios:...*/` comment.
#[derive(Debug, Clone, PartialEq)]
pub enum RoutingHint {
    /// `route=` — which class of node should serve the query.
    Route(RouteTarget),

    /// `node=` — a specific node name, stored verbatim.
    Node(String),

    /// `consistency=` — requested consistency level.
    Consistency(ConsistencyLevel),

    /// `pool=` — requested pooling mode.
    Pool(PoolingModeHint),

    /// `cache=` — requested cache behavior.
    Cache(CacheBehavior),

    /// `timeout=` — parsed with `parse_duration`.
    Timeout(Duration),

    /// `priority=` — query priority.
    Priority(QueryPriority),

    /// `lag=` — parsed with `parse_duration`; presumably the maximum tolerated replica
    /// lag (verify against the router).
    MaxLag(Duration),

    /// `retry=` — retry policy.
    Retry(RetryBehavior),

    /// `branch=` — branch name, stored verbatim.
    Branch(String),

    /// `twr=` — transparent write routing on/off.
    TransparentWriteRouting(bool),

    /// `tool=` — agent tool identifier, stored verbatim.
    AgentTool(String),

    /// `workflow=` — workflow step identifier, stored verbatim.
    WorkflowStep(String),

    /// `prefetch=` — prefetch flag.
    Prefetch(bool),

    /// `cache_ttl=` — cache TTL given as whole seconds.
    CacheTtl(Duration),
}
275
/// Routing destination requested by the `route=` hint.
///
/// Parsed case-insensitively via `FromStr`; the accepted spellings are listed per variant.
/// `Standby`, `Sync`, `SemiSync` and `Async` count as standby routes for
/// `ParsedHints::is_standby_route`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum RouteTarget {
    /// `primary`, `master`, or `leader`.
    Primary,
    /// `standby`, `replica`, or `secondary`.
    Standby,
    /// `sync` or `synchronous` — presumably a synchronously replicated standby.
    Sync,
    /// `semisync`, `semi-sync`, or `semi_sync`.
    SemiSync,
    /// `async` or `asynchronous` — presumably an asynchronously replicated standby.
    Async,
    /// `any` or `all`.
    Any,
    /// `local` or `nearest`.
    Local,
    /// `vector`.
    Vector,
}
296
297impl FromStr for RouteTarget {
298 type Err = RoutingError;
299
300 fn from_str(s: &str) -> Result<Self> {
301 match s.to_lowercase().as_str() {
302 "primary" | "master" | "leader" => Ok(RouteTarget::Primary),
303 "standby" | "replica" | "secondary" => Ok(RouteTarget::Standby),
304 "sync" | "synchronous" => Ok(RouteTarget::Sync),
305 "semisync" | "semi-sync" | "semi_sync" => Ok(RouteTarget::SemiSync),
306 "async" | "asynchronous" => Ok(RouteTarget::Async),
307 "any" | "all" => Ok(RouteTarget::Any),
308 "local" | "nearest" => Ok(RouteTarget::Local),
309 "vector" => Ok(RouteTarget::Vector),
310 _ => Err(RoutingError::ParseError(format!(
311 "Unknown route target: {}",
312 s
313 ))),
314 }
315 }
316}
317
318impl std::fmt::Display for RouteTarget {
319 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
320 match self {
321 RouteTarget::Primary => write!(f, "primary"),
322 RouteTarget::Standby => write!(f, "standby"),
323 RouteTarget::Sync => write!(f, "sync"),
324 RouteTarget::SemiSync => write!(f, "semisync"),
325 RouteTarget::Async => write!(f, "async"),
326 RouteTarget::Any => write!(f, "any"),
327 RouteTarget::Local => write!(f, "local"),
328 RouteTarget::Vector => write!(f, "vector"),
329 }
330 }
331}
332
/// Read consistency requested by the `consistency=` hint (parsed case-insensitively).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ConsistencyLevel {
    /// `strong`, `strict`, or `linearizable`. Rejected together with `route=async` by
    /// `ParsedHints::validate`.
    Strong,
    /// `bounded`, `session`, or `read-your-writes`.
    Bounded,
    /// `eventual` or `weak`.
    Eventual,
}
343
344impl FromStr for ConsistencyLevel {
345 type Err = RoutingError;
346
347 fn from_str(s: &str) -> Result<Self> {
348 match s.to_lowercase().as_str() {
349 "strong" | "strict" | "linearizable" => Ok(ConsistencyLevel::Strong),
350 "bounded" | "session" | "read-your-writes" => Ok(ConsistencyLevel::Bounded),
351 "eventual" | "weak" => Ok(ConsistencyLevel::Eventual),
352 _ => Err(RoutingError::ParseError(format!(
353 "Unknown consistency level: {}",
354 s
355 ))),
356 }
357 }
358}
359
360impl std::fmt::Display for ConsistencyLevel {
361 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
362 match self {
363 ConsistencyLevel::Strong => write!(f, "strong"),
364 ConsistencyLevel::Bounded => write!(f, "bounded"),
365 ConsistencyLevel::Eventual => write!(f, "eventual"),
366 }
367 }
368}
369
/// Pooling mode requested by the `pool=` hint (parsed case-insensitively).
///
/// With the `pool-modes` feature enabled it converts into `PoolingMode`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum PoolingModeHint {
    /// `session`.
    Session,
    /// `transaction` or `tx`.
    Transaction,
    /// `statement`, `stmt`, or `query`.
    Statement,
}
377
378impl FromStr for PoolingModeHint {
379 type Err = RoutingError;
380
381 fn from_str(s: &str) -> Result<Self> {
382 match s.to_lowercase().as_str() {
383 "session" => Ok(PoolingModeHint::Session),
384 "transaction" | "tx" => Ok(PoolingModeHint::Transaction),
385 "statement" | "stmt" | "query" => Ok(PoolingModeHint::Statement),
386 _ => Err(RoutingError::ParseError(format!(
387 "Unknown pool mode: {}",
388 s
389 ))),
390 }
391 }
392}
393
#[cfg(feature = "pool-modes")]
impl From<PoolingModeHint> for PoolingMode {
    /// Maps each hint variant onto the same-named `PoolingMode` variant.
    fn from(hint: PoolingModeHint) -> Self {
        use PoolingModeHint as Hint;

        match hint {
            Hint::Session => Self::Session,
            Hint::Transaction => Self::Transaction,
            Hint::Statement => Self::Statement,
        }
    }
}
404
/// Cache handling requested by the `cache=` hint (parsed case-insensitively).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum CacheBehavior {
    /// `normal` or `default`.
    Normal,
    /// `skip`, `bypass`, or `none`.
    Skip,
    /// `refresh`, `force`, or `update`.
    Refresh,
    /// `semantic`, `l3`, or `vector`.
    Semantic,
    /// `l1` or `hot`.
    L1Only,
    /// `l2` or `warm`.
    L2Only,
}
421
422impl FromStr for CacheBehavior {
423 type Err = RoutingError;
424
425 fn from_str(s: &str) -> Result<Self> {
426 match s.to_lowercase().as_str() {
427 "normal" | "default" => Ok(CacheBehavior::Normal),
428 "skip" | "bypass" | "none" => Ok(CacheBehavior::Skip),
429 "refresh" | "force" | "update" => Ok(CacheBehavior::Refresh),
430 "semantic" | "l3" | "vector" => Ok(CacheBehavior::Semantic),
431 "l1" | "hot" => Ok(CacheBehavior::L1Only),
432 "l2" | "warm" => Ok(CacheBehavior::L2Only),
433 _ => Err(RoutingError::ParseError(format!(
434 "Unknown cache behavior: {}",
435 s
436 ))),
437 }
438 }
439}
440
/// Query priority requested by the `priority=` hint (parsed case-insensitively).
///
/// Variants compare in ascending order `Low < Normal < High < Critical`; the `Default`
/// is `Normal`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Default)]
pub enum QueryPriority {
    /// `low` or `background`.
    Low = 0,
    /// `normal` or `default`.
    #[default]
    Normal = 1,
    /// `high` or `elevated`.
    High = 2,
    /// `critical`, `urgent`, or `realtime`.
    Critical = 3,
}
450
451impl FromStr for QueryPriority {
452 type Err = RoutingError;
453
454 fn from_str(s: &str) -> Result<Self> {
455 match s.to_lowercase().as_str() {
456 "low" | "background" => Ok(QueryPriority::Low),
457 "normal" | "default" => Ok(QueryPriority::Normal),
458 "high" | "elevated" => Ok(QueryPriority::High),
459 "critical" | "urgent" | "realtime" => Ok(QueryPriority::Critical),
460 _ => Err(RoutingError::ParseError(format!("Unknown priority: {}", s))),
461 }
462 }
463}
464
/// Retry policy requested by the `retry=` hint.
#[derive(Debug, Clone, PartialEq, Default)]
pub enum RetryBehavior {
    /// Do not retry (`retry=false` / `retry=no`).
    None,
    /// Retry automatically (`retry=true` / `retry=yes`); this is the `Default`.
    #[default]
    Auto,
    /// A numeric `retry=` value — presumably the maximum number of retries (verify).
    Count(u32),
}
476
// Unit tests for hint parsing, stripping, validation and the value enums.
#[cfg(test)]
mod tests {
    use super::*;

    // A single hint comment sets the matching typed field.
    #[test]
    fn test_parse_single_hint() {
        let parser = HintParser::new();
        let hints = parser.parse("/*helios:route=primary*/ SELECT * FROM users");

        assert!(!hints.is_empty());
        assert_eq!(hints.route, Some(RouteTarget::Primary));
    }

    // Comma-separated pairs in one comment each become a separate hint.
    #[test]
    fn test_parse_multiple_hints() {
        let parser = HintParser::new();
        let hints = parser.parse(
            "/*helios:route=standby,consistency=eventual,timeout=5s*/ SELECT * FROM products",
        );

        assert_eq!(hints.len(), 3);
        assert_eq!(hints.route, Some(RouteTarget::Standby));
        assert_eq!(hints.consistency, Some(ConsistencyLevel::Eventual));
        // Relies on `parse_duration` understanding the `5s` suffix.
        assert_eq!(hints.timeout, Some(Duration::from_secs(5)));
    }

    // Node names may contain hyphens; the value stops only at a comma or whitespace.
    #[test]
    fn test_parse_node_hint() {
        let parser = HintParser::new();
        let hints = parser.parse("/*helios:node=standby-sync-1*/ SELECT * FROM logs");

        assert_eq!(hints.node, Some("standby-sync-1".to_string()));
    }

    #[test]
    fn test_parse_lag_hint() {
        let parser = HintParser::new();
        let hints = parser.parse("/*helios:route=async,lag=10s*/ SELECT COUNT(*) FROM events");

        assert_eq!(hints.route, Some(RouteTarget::Async));
        assert_eq!(hints.max_lag, Some(Duration::from_secs(10)));
    }

    #[test]
    fn test_parse_priority_hint() {
        let parser = HintParser::new();
        let hints = parser.parse("/*helios:priority=critical*/ SELECT balance FROM accounts");

        assert_eq!(hints.priority, Some(QueryPriority::Critical));
    }

    #[test]
    fn test_parse_cache_hint() {
        let parser = HintParser::new();
        let hints = parser.parse("/*helios:cache=skip*/ SELECT now()");

        assert_eq!(hints.cache, Some(CacheBehavior::Skip));
    }

    #[test]
    fn test_parse_pool_hint() {
        let parser = HintParser::new();
        let hints = parser.parse("/*helios:pool=transaction*/ BEGIN");

        assert_eq!(hints.pool, Some(PoolingModeHint::Transaction));
    }

    // A leading hint is removed together with the whitespace before the SQL.
    #[test]
    fn test_strip_hints() {
        let parser = HintParser::new();
        let query = "/*helios:route=primary*/ SELECT * FROM users WHERE id = 1";
        let stripped = parser.strip(query);

        assert_eq!(stripped, "SELECT * FROM users WHERE id = 1");
    }

    // Expects the whitespace on both sides of a mid-query hint to collapse to one space.
    #[test]
    fn test_strip_multiple_hints() {
        let parser = HintParser::new();
        let query = "/*helios:route=standby*/ SELECT * /*helios:cache=skip*/ FROM users";
        let stripped = parser.strip(query);

        assert_eq!(stripped, "SELECT * FROM users");
    }

    // route=async combined with consistency=strong must be rejected by validate().
    #[test]
    fn test_validate_conflicting_hints() {
        let parser = HintParser::new();
        let hints = parser.parse("/*helios:route=async,consistency=strong*/ SELECT * FROM users");

        let result = hints.validate();
        assert!(result.is_err());
    }

    // Canonical names and aliases both resolve to the same target.
    #[test]
    fn test_route_target_parsing() {
        assert_eq!(
            RouteTarget::from_str("primary").unwrap(),
            RouteTarget::Primary
        );
        assert_eq!(
            RouteTarget::from_str("master").unwrap(),
            RouteTarget::Primary
        );
        assert_eq!(
            RouteTarget::from_str("standby").unwrap(),
            RouteTarget::Standby
        );
        assert_eq!(
            RouteTarget::from_str("replica").unwrap(),
            RouteTarget::Standby
        );
        assert_eq!(RouteTarget::from_str("sync").unwrap(), RouteTarget::Sync);
        assert_eq!(RouteTarget::from_str("async").unwrap(), RouteTarget::Async);
        assert_eq!(RouteTarget::from_str("local").unwrap(), RouteTarget::Local);
    }

    #[test]
    fn test_consistency_level_parsing() {
        assert_eq!(
            ConsistencyLevel::from_str("strong").unwrap(),
            ConsistencyLevel::Strong
        );
        assert_eq!(
            ConsistencyLevel::from_str("bounded").unwrap(),
            ConsistencyLevel::Bounded
        );
        assert_eq!(
            ConsistencyLevel::from_str("eventual").unwrap(),
            ConsistencyLevel::Eventual
        );
    }

    // Priorities compare in ascending order Low < Normal < High < Critical.
    #[test]
    fn test_query_priority_ordering() {
        assert!(QueryPriority::Critical > QueryPriority::High);
        assert!(QueryPriority::High > QueryPriority::Normal);
        assert!(QueryPriority::Normal > QueryPriority::Low);
    }

    // Agent/workflow hints have no typed field on ParsedHints, so they are found by
    // scanning the raw hint list.
    #[test]
    fn test_ai_workflow_hints() {
        let parser = HintParser::new();
        let hints = parser.parse(
            "/*helios:route=async,tool=knowledge_search,workflow=planning*/ SELECT content FROM docs"
        );

        assert!(!hints.is_empty());
        assert_eq!(hints.route, Some(RouteTarget::Async));

        let has_tool = hints
            .hints()
            .iter()
            .any(|h| matches!(h, RoutingHint::AgentTool(t) if t == "knowledge_search"));
        let has_workflow = hints
            .hints()
            .iter()
            .any(|h| matches!(h, RoutingHint::WorkflowStep(w) if w == "planning"));

        assert!(has_tool);
        assert!(has_workflow);
    }

    #[test]
    fn test_branch_hint() {
        let parser = HintParser::new();
        let hints = parser.parse("/*helios:branch=analytics,route=local*/ SELECT * FROM reports");

        assert_eq!(hints.branch, Some("analytics".to_string()));
        assert_eq!(hints.route, Some(RouteTarget::Local));
    }

    #[test]
    fn test_twr_hint() {
        let parser = HintParser::new();
        let hints = parser.parse("/*helios:route=sync,twr=true*/ INSERT INTO logs VALUES (1)");

        assert_eq!(hints.route, Some(RouteTarget::Sync));
        assert_eq!(hints.twr, Some(true));
    }

    // A query without hint comments yields no hints.
    #[test]
    fn test_empty_query() {
        let parser = HintParser::new();
        let hints = parser.parse("SELECT * FROM users");

        assert!(hints.is_empty());
    }

    // extract_raw returns each complete hint comment, in query order.
    #[test]
    fn test_extract_raw() {
        let parser = HintParser::new();
        let raw = parser.extract_raw("/*helios:route=primary*/ SELECT /*helios:cache=skip*/ 1");

        assert_eq!(raw.len(), 2);
        assert!(raw[0].contains("route=primary"));
        assert!(raw[1].contains("cache=skip"));
    }
}