1use super::{parse_duration, RoutingError, Result};
14use regex::Regex;
15use std::str::FromStr;
16use std::sync::LazyLock;
17use std::time::Duration;
18
19#[cfg(feature = "pool-modes")]
20use crate::pool::PoolingMode;
21
22static HINT_REGEX: LazyLock<Regex> = LazyLock::new(|| {
24 Regex::new(r"/\*\s*helios:([^*]+)\*/").expect("Invalid hint regex")
25});
26
27static KV_REGEX: LazyLock<Regex> = LazyLock::new(|| {
29 Regex::new(r"(\w+)\s*=\s*([^,\s]+)").expect("Invalid key-value regex")
30});
31
/// Parser for `helios:` routing hints embedded in SQL block comments.
///
/// NOTE(review): the derived `Default` produces `strip_hints == false` (the same as
/// `HintParser::without_stripping`), whereas `HintParser::new` sets it to `true`.
/// Confirm which one callers expect.
#[derive(Clone, Debug, Default)]
pub struct HintParser {
    /// Flag set by the constructors. It is not read by any method of this type;
    /// presumably it tells callers whether to strip hints before forwarding the query.
    pub strip_hints: bool,
}
38
39impl HintParser {
40 pub fn new() -> Self {
42 Self { strip_hints: true }
43 }
44
45 pub fn without_stripping() -> Self {
47 Self { strip_hints: false }
48 }
49
50 pub fn parse(&self, query: &str) -> ParsedHints {
52 let mut hints = ParsedHints::default();
53
54 for cap in HINT_REGEX.captures_iter(query) {
55 let hint_content = cap.get(1).map(|m| m.as_str()).unwrap_or("");
56
57 for kv in KV_REGEX.captures_iter(hint_content) {
59 let key = kv.get(1).map(|m| m.as_str()).unwrap_or("");
60 let value = kv.get(2).map(|m| m.as_str()).unwrap_or("");
61
62 if let Some(hint) = self.parse_hint(key, value) {
63 hints.add(hint);
64 }
65 }
66 }
67
68 hints
69 }
70
71 fn parse_hint(&self, key: &str, value: &str) -> Option<RoutingHint> {
73 match key.to_lowercase().as_str() {
74 "route" => RouteTarget::from_str(value).ok().map(RoutingHint::Route),
75 "node" => Some(RoutingHint::Node(value.to_string())),
76 "consistency" => ConsistencyLevel::from_str(value).ok().map(RoutingHint::Consistency),
77 "pool" => PoolingModeHint::from_str(value).ok().map(RoutingHint::Pool),
78 "cache" => CacheBehavior::from_str(value).ok().map(RoutingHint::Cache),
79 "timeout" => parse_duration(value).map(RoutingHint::Timeout),
80 "priority" => QueryPriority::from_str(value).ok().map(RoutingHint::Priority),
81 "lag" => parse_duration(value).map(RoutingHint::MaxLag),
82 "retry" => self.parse_retry(value).map(RoutingHint::Retry),
83 "branch" => Some(RoutingHint::Branch(value.to_string())),
84 "twr" => value.parse::<bool>().ok().map(RoutingHint::TransparentWriteRouting),
85 "tool" => Some(RoutingHint::AgentTool(value.to_string())),
86 "workflow" => Some(RoutingHint::WorkflowStep(value.to_string())),
87 "prefetch" => value.parse::<bool>().ok().map(RoutingHint::Prefetch),
88 "cache_ttl" => value.parse::<u64>().ok().map(|s| RoutingHint::CacheTtl(Duration::from_secs(s))),
89 _ => None,
90 }
91 }
92
93 fn parse_retry(&self, value: &str) -> Option<RetryBehavior> {
95 match value.to_lowercase().as_str() {
96 "true" | "yes" => Some(RetryBehavior::Auto),
97 "false" | "no" => Some(RetryBehavior::None),
98 _ => value.parse::<u32>().ok().map(RetryBehavior::Count),
99 }
100 }
101
102 pub fn strip(&self, query: &str) -> String {
104 HINT_REGEX.replace_all(query, "").trim().to_string()
105 }
106
107 pub fn extract_raw(&self, query: &str) -> Vec<String> {
109 HINT_REGEX
110 .captures_iter(query)
111 .filter_map(|cap| cap.get(0).map(|m| m.as_str().to_string()))
112 .collect()
113 }
114}
115
116#[derive(Debug, Clone, Default)]
118pub struct ParsedHints {
119 hints: Vec<RoutingHint>,
121 pub route: Option<RouteTarget>,
123 pub node: Option<String>,
125 pub consistency: Option<ConsistencyLevel>,
127 pub pool: Option<PoolingModeHint>,
129 pub cache: Option<CacheBehavior>,
131 pub timeout: Option<Duration>,
133 pub priority: Option<QueryPriority>,
135 pub max_lag: Option<Duration>,
137 pub retry: Option<RetryBehavior>,
139 pub branch: Option<String>,
141 pub twr: Option<bool>,
143 pub cache_ttl: Option<Duration>,
145}
146
147impl ParsedHints {
148 pub fn add(&mut self, hint: RoutingHint) {
150 match &hint {
151 RoutingHint::Route(target) => self.route = Some(*target),
152 RoutingHint::Node(name) => self.node = Some(name.clone()),
153 RoutingHint::Consistency(level) => self.consistency = Some(*level),
154 RoutingHint::Pool(mode) => self.pool = Some(*mode),
155 RoutingHint::Cache(behavior) => self.cache = Some(*behavior),
156 RoutingHint::Timeout(dur) => self.timeout = Some(*dur),
157 RoutingHint::Priority(pri) => self.priority = Some(*pri),
158 RoutingHint::MaxLag(dur) => self.max_lag = Some(*dur),
159 RoutingHint::Retry(retry) => self.retry = Some(retry.clone()),
160 RoutingHint::Branch(name) => self.branch = Some(name.clone()),
161 RoutingHint::TransparentWriteRouting(enabled) => self.twr = Some(*enabled),
162 RoutingHint::CacheTtl(dur) => self.cache_ttl = Some(*dur),
163 _ => {}
164 }
165 self.hints.push(hint);
166 }
167
168 pub fn is_empty(&self) -> bool {
170 self.hints.is_empty()
171 }
172
173 pub fn len(&self) -> usize {
175 self.hints.len()
176 }
177
178 pub fn hints(&self) -> &[RoutingHint] {
180 &self.hints
181 }
182
183 pub fn is_primary_route(&self) -> bool {
185 matches!(self.route, Some(RouteTarget::Primary))
186 }
187
188 pub fn is_standby_route(&self) -> bool {
190 matches!(
191 self.route,
192 Some(RouteTarget::Standby) | Some(RouteTarget::Sync) |
193 Some(RouteTarget::SemiSync) | Some(RouteTarget::Async)
194 )
195 }
196
197 pub fn validate(&self) -> Result<()> {
199 if let (Some(RouteTarget::Async), Some(ConsistencyLevel::Strong)) =
201 (self.route, self.consistency)
202 {
203 return Err(RoutingError::InvalidHintCombination(
204 "route=async and consistency=strong are incompatible".to_string(),
205 ));
206 }
207
208 if self.consistency == Some(ConsistencyLevel::Bounded) && self.max_lag.is_none() {
210 }
212
213 Ok(())
214 }
215}
216
217#[derive(Debug, Clone, PartialEq)]
219pub enum RoutingHint {
220 Route(RouteTarget),
222
223 Node(String),
225
226 Consistency(ConsistencyLevel),
228
229 Pool(PoolingModeHint),
231
232 Cache(CacheBehavior),
234
235 Timeout(Duration),
237
238 Priority(QueryPriority),
240
241 MaxLag(Duration),
243
244 Retry(RetryBehavior),
246
247 Branch(String),
249
250 TransparentWriteRouting(bool),
252
253 AgentTool(String),
255
256 WorkflowStep(String),
258
259 Prefetch(bool),
261
262 CacheTtl(Duration),
264}
265
/// Destination requested by a `route=` hint.
///
/// Each variant lists the spellings accepted (case-insensitively) by its `FromStr` impl.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum RouteTarget {
    /// `primary`, `master`, `leader`.
    Primary,
    /// `standby`, `replica`, `secondary`.
    Standby,
    /// `sync`, `synchronous`.
    Sync,
    /// `semisync`, `semi-sync`, `semi_sync`.
    SemiSync,
    /// `async`, `asynchronous`.
    Async,
    /// `any`, `all`.
    Any,
    /// `local`, `nearest`.
    Local,
    /// `vector`.
    Vector,
}
286
287impl FromStr for RouteTarget {
288 type Err = RoutingError;
289
290 fn from_str(s: &str) -> Result<Self> {
291 match s.to_lowercase().as_str() {
292 "primary" | "master" | "leader" => Ok(RouteTarget::Primary),
293 "standby" | "replica" | "secondary" => Ok(RouteTarget::Standby),
294 "sync" | "synchronous" => Ok(RouteTarget::Sync),
295 "semisync" | "semi-sync" | "semi_sync" => Ok(RouteTarget::SemiSync),
296 "async" | "asynchronous" => Ok(RouteTarget::Async),
297 "any" | "all" => Ok(RouteTarget::Any),
298 "local" | "nearest" => Ok(RouteTarget::Local),
299 "vector" => Ok(RouteTarget::Vector),
300 _ => Err(RoutingError::ParseError(format!("Unknown route target: {}", s))),
301 }
302 }
303}
304
305impl std::fmt::Display for RouteTarget {
306 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
307 match self {
308 RouteTarget::Primary => write!(f, "primary"),
309 RouteTarget::Standby => write!(f, "standby"),
310 RouteTarget::Sync => write!(f, "sync"),
311 RouteTarget::SemiSync => write!(f, "semisync"),
312 RouteTarget::Async => write!(f, "async"),
313 RouteTarget::Any => write!(f, "any"),
314 RouteTarget::Local => write!(f, "local"),
315 RouteTarget::Vector => write!(f, "vector"),
316 }
317 }
318}
319
/// Read consistency requested by a `consistency=` hint.
///
/// Each variant lists the spellings accepted (case-insensitively) by its `FromStr` impl.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum ConsistencyLevel {
    /// `strong`, `strict`, `linearizable`.
    Strong,
    /// `bounded`, `session`, `read-your-writes`.
    Bounded,
    /// `eventual`, `weak`.
    Eventual,
}
330
331impl FromStr for ConsistencyLevel {
332 type Err = RoutingError;
333
334 fn from_str(s: &str) -> Result<Self> {
335 match s.to_lowercase().as_str() {
336 "strong" | "strict" | "linearizable" => Ok(ConsistencyLevel::Strong),
337 "bounded" | "session" | "read-your-writes" => Ok(ConsistencyLevel::Bounded),
338 "eventual" | "weak" => Ok(ConsistencyLevel::Eventual),
339 _ => Err(RoutingError::ParseError(format!("Unknown consistency level: {}", s))),
340 }
341 }
342}
343
344impl std::fmt::Display for ConsistencyLevel {
345 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
346 match self {
347 ConsistencyLevel::Strong => write!(f, "strong"),
348 ConsistencyLevel::Bounded => write!(f, "bounded"),
349 ConsistencyLevel::Eventual => write!(f, "eventual"),
350 }
351 }
352}
353
/// Connection pooling mode requested by a `pool=` hint.
///
/// Each variant lists the spellings accepted (case-insensitively) by its `FromStr` impl.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum PoolingModeHint {
    /// `session`.
    Session,
    /// `transaction`, `tx`.
    Transaction,
    /// `statement`, `stmt`, `query`.
    Statement,
}
361
362impl FromStr for PoolingModeHint {
363 type Err = RoutingError;
364
365 fn from_str(s: &str) -> Result<Self> {
366 match s.to_lowercase().as_str() {
367 "session" => Ok(PoolingModeHint::Session),
368 "transaction" | "tx" => Ok(PoolingModeHint::Transaction),
369 "statement" | "stmt" | "query" => Ok(PoolingModeHint::Statement),
370 _ => Err(RoutingError::ParseError(format!("Unknown pool mode: {}", s))),
371 }
372 }
373}
374
#[cfg(feature = "pool-modes")]
impl From<PoolingModeHint> for PoolingMode {
    /// Maps each hint variant onto the `PoolingMode` variant of the same name.
    fn from(hint: PoolingModeHint) -> Self {
        match hint {
            PoolingModeHint::Session => Self::Session,
            PoolingModeHint::Transaction => Self::Transaction,
            PoolingModeHint::Statement => Self::Statement,
        }
    }
}
385
/// Cache handling requested by a `cache=` hint.
///
/// Each variant lists the spellings accepted (case-insensitively) by its `FromStr` impl.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum CacheBehavior {
    /// `normal`, `default`.
    Normal,
    /// `skip`, `bypass`, `none`.
    Skip,
    /// `refresh`, `force`, `update`.
    Refresh,
    /// `semantic`, `l3`, `vector`.
    Semantic,
    /// `l1`, `hot`.
    L1Only,
    /// `l2`, `warm`.
    L2Only,
}
402
403impl FromStr for CacheBehavior {
404 type Err = RoutingError;
405
406 fn from_str(s: &str) -> Result<Self> {
407 match s.to_lowercase().as_str() {
408 "normal" | "default" => Ok(CacheBehavior::Normal),
409 "skip" | "bypass" | "none" => Ok(CacheBehavior::Skip),
410 "refresh" | "force" | "update" => Ok(CacheBehavior::Refresh),
411 "semantic" | "l3" | "vector" => Ok(CacheBehavior::Semantic),
412 "l1" | "hot" => Ok(CacheBehavior::L1Only),
413 "l2" | "warm" => Ok(CacheBehavior::L2Only),
414 _ => Err(RoutingError::ParseError(format!("Unknown cache behavior: {}", s))),
415 }
416 }
417}
418
/// Priority requested by a `priority=` hint.
///
/// Variants are declared in ascending order, so the derived `Ord` ranks
/// `Low < Normal < High < Critical`. The explicit discriminants match that order.
#[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
pub enum QueryPriority {
    /// `low`, `background`.
    Low = 0,
    /// `normal`, `default`.
    Normal = 1,
    /// `high`, `elevated`.
    High = 2,
    /// `critical`, `urgent`, `realtime`.
    Critical = 3,
}
427
428impl FromStr for QueryPriority {
429 type Err = RoutingError;
430
431 fn from_str(s: &str) -> Result<Self> {
432 match s.to_lowercase().as_str() {
433 "low" | "background" => Ok(QueryPriority::Low),
434 "normal" | "default" => Ok(QueryPriority::Normal),
435 "high" | "elevated" => Ok(QueryPriority::High),
436 "critical" | "urgent" | "realtime" => Ok(QueryPriority::Critical),
437 _ => Err(RoutingError::ParseError(format!("Unknown priority: {}", s))),
438 }
439 }
440}
441
442impl Default for QueryPriority {
443 fn default() -> Self {
444 QueryPriority::Normal
445 }
446}
447
/// Retry policy requested by a `retry=` hint.
#[derive(Clone, Debug, PartialEq)]
pub enum RetryBehavior {
    /// Produced by `retry=false` / `retry=no`.
    None,
    /// Produced by `retry=true` / `retry=yes`.
    Auto,
    /// Produced by a numeric value, e.g. `retry=3`.
    Count(u32),
}
458
459impl Default for RetryBehavior {
460 fn default() -> Self {
461 RetryBehavior::Auto
462 }
463}
464
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses `query` with a default (stripping) parser.
    fn parse(query: &str) -> ParsedHints {
        HintParser::new().parse(query)
    }

    /// Strips hints from `query` with a default parser.
    fn strip(query: &str) -> String {
        HintParser::new().strip(query)
    }

    #[test]
    fn test_parse_single_hint() {
        let parsed = parse("/*helios:route=primary*/ SELECT * FROM users");

        assert!(!parsed.is_empty());
        assert_eq!(parsed.route, Some(RouteTarget::Primary));
    }

    #[test]
    fn test_parse_multiple_hints() {
        let parsed = parse(
            "/*helios:route=standby,consistency=eventual,timeout=5s*/ SELECT * FROM products",
        );

        assert_eq!(parsed.len(), 3);
        assert_eq!(parsed.route, Some(RouteTarget::Standby));
        assert_eq!(parsed.consistency, Some(ConsistencyLevel::Eventual));
        assert_eq!(parsed.timeout, Some(Duration::from_secs(5)));
    }

    #[test]
    fn test_parse_node_hint() {
        let parsed = parse("/*helios:node=standby-sync-1*/ SELECT * FROM logs");

        assert_eq!(parsed.node.as_deref(), Some("standby-sync-1"));
    }

    #[test]
    fn test_parse_lag_hint() {
        let parsed = parse("/*helios:route=async,lag=10s*/ SELECT COUNT(*) FROM events");

        assert_eq!(parsed.route, Some(RouteTarget::Async));
        assert_eq!(parsed.max_lag, Some(Duration::from_secs(10)));
    }

    #[test]
    fn test_parse_priority_hint() {
        let parsed = parse("/*helios:priority=critical*/ SELECT balance FROM accounts");

        assert_eq!(parsed.priority, Some(QueryPriority::Critical));
    }

    #[test]
    fn test_parse_cache_hint() {
        let parsed = parse("/*helios:cache=skip*/ SELECT now()");

        assert_eq!(parsed.cache, Some(CacheBehavior::Skip));
    }

    #[test]
    fn test_parse_pool_hint() {
        let parsed = parse("/*helios:pool=transaction*/ BEGIN");

        assert_eq!(parsed.pool, Some(PoolingModeHint::Transaction));
    }

    #[test]
    fn test_strip_hints() {
        assert_eq!(
            strip("/*helios:route=primary*/ SELECT * FROM users WHERE id = 1"),
            "SELECT * FROM users WHERE id = 1"
        );
    }

    #[test]
    fn test_strip_multiple_hints() {
        assert_eq!(
            strip("/*helios:route=standby*/ SELECT * /*helios:cache=skip*/ FROM users"),
            "SELECT * FROM users"
        );
    }

    #[test]
    fn test_validate_conflicting_hints() {
        let parsed = parse("/*helios:route=async,consistency=strong*/ SELECT * FROM users");

        assert!(parsed.validate().is_err());
    }

    #[test]
    fn test_route_target_parsing() {
        let cases = [
            ("primary", RouteTarget::Primary),
            ("master", RouteTarget::Primary),
            ("standby", RouteTarget::Standby),
            ("replica", RouteTarget::Standby),
            ("sync", RouteTarget::Sync),
            ("async", RouteTarget::Async),
            ("local", RouteTarget::Local),
        ];
        for (input, expected) in cases {
            assert_eq!(RouteTarget::from_str(input).unwrap(), expected);
        }
    }

    #[test]
    fn test_consistency_level_parsing() {
        let cases = [
            ("strong", ConsistencyLevel::Strong),
            ("bounded", ConsistencyLevel::Bounded),
            ("eventual", ConsistencyLevel::Eventual),
        ];
        for (input, expected) in cases {
            assert_eq!(ConsistencyLevel::from_str(input).unwrap(), expected);
        }
    }

    #[test]
    fn test_query_priority_ordering() {
        assert!(QueryPriority::Critical > QueryPriority::High);
        assert!(QueryPriority::High > QueryPriority::Normal);
        assert!(QueryPriority::Normal > QueryPriority::Low);
    }

    #[test]
    fn test_ai_workflow_hints() {
        let parsed = parse(
            "/*helios:route=async,tool=knowledge_search,workflow=planning*/ SELECT content FROM docs",
        );

        assert!(!parsed.is_empty());
        assert_eq!(parsed.route, Some(RouteTarget::Async));

        let all = parsed.hints();
        assert!(all
            .iter()
            .any(|h| matches!(h, RoutingHint::AgentTool(t) if t == "knowledge_search")));
        assert!(all
            .iter()
            .any(|h| matches!(h, RoutingHint::WorkflowStep(w) if w == "planning")));
    }

    #[test]
    fn test_branch_hint() {
        let parsed = parse("/*helios:branch=analytics,route=local*/ SELECT * FROM reports");

        assert_eq!(parsed.branch.as_deref(), Some("analytics"));
        assert_eq!(parsed.route, Some(RouteTarget::Local));
    }

    #[test]
    fn test_twr_hint() {
        let parsed = parse("/*helios:route=sync,twr=true*/ INSERT INTO logs VALUES (1)");

        assert_eq!(parsed.route, Some(RouteTarget::Sync));
        assert_eq!(parsed.twr, Some(true));
    }

    #[test]
    fn test_empty_query() {
        assert!(parse("SELECT * FROM users").is_empty());
    }

    #[test]
    fn test_extract_raw() {
        let raw = HintParser::new()
            .extract_raw("/*helios:route=primary*/ SELECT /*helios:cache=skip*/ 1");

        assert_eq!(raw.len(), 2);
        assert!(raw[0].contains("route=primary"));
        assert!(raw[1].contains("cache=skip"));
    }
}