Skip to main content

heliosdb_proxy/routing/
hint_parser.rs

//! SQL Hint Parser
//!
//! Parses routing hints from SQL comments.
//! Supports both single hints and comma-separated multiple hints.
//!
//! # Format
//!
//! ```text
//! /*helios:key=value*/
//! /*helios:key1=value1,key2=value2*/
//! ```
12
13use super::{parse_duration, RoutingError, Result};
14use regex::Regex;
15use std::str::FromStr;
16use std::sync::LazyLock;
17use std::time::Duration;
18
19#[cfg(feature = "pool-modes")]
20use crate::pool::PoolingMode;
21
22/// Compiled regex for hint parsing
23static HINT_REGEX: LazyLock<Regex> = LazyLock::new(|| {
24    Regex::new(r"/\*\s*helios:([^*]+)\*/").expect("Invalid hint regex")
25});
26
27/// Key-value pair regex
28static KV_REGEX: LazyLock<Regex> = LazyLock::new(|| {
29    Regex::new(r"(\w+)\s*=\s*([^,\s]+)").expect("Invalid key-value regex")
30});
31
/// Parser for `/*helios:...*/` routing hints embedded in SQL text.
///
/// NOTE(review): the derived `Default` yields `strip_hints == false`, whereas
/// `HintParser::new` sets it to `true` — confirm that the mismatch is intended.
#[derive(Debug, Clone, Default)]
pub struct HintParser {
    /// Whether hints should be removed from the query before it is sent to the backend.
    pub strip_hints: bool,
}
38
39impl HintParser {
40    /// Create a new hint parser
41    pub fn new() -> Self {
42        Self { strip_hints: true }
43    }
44
45    /// Create parser with hint stripping disabled
46    pub fn without_stripping() -> Self {
47        Self { strip_hints: false }
48    }
49
50    /// Parse all routing hints from a SQL query
51    pub fn parse(&self, query: &str) -> ParsedHints {
52        let mut hints = ParsedHints::default();
53
54        for cap in HINT_REGEX.captures_iter(query) {
55            let hint_content = cap.get(1).map(|m| m.as_str()).unwrap_or("");
56
57            // Parse key-value pairs
58            for kv in KV_REGEX.captures_iter(hint_content) {
59                let key = kv.get(1).map(|m| m.as_str()).unwrap_or("");
60                let value = kv.get(2).map(|m| m.as_str()).unwrap_or("");
61
62                if let Some(hint) = self.parse_hint(key, value) {
63                    hints.add(hint);
64                }
65            }
66        }
67
68        hints
69    }
70
71    /// Parse a single hint from key-value pair
72    fn parse_hint(&self, key: &str, value: &str) -> Option<RoutingHint> {
73        match key.to_lowercase().as_str() {
74            "route" => RouteTarget::from_str(value).ok().map(RoutingHint::Route),
75            "node" => Some(RoutingHint::Node(value.to_string())),
76            "consistency" => ConsistencyLevel::from_str(value).ok().map(RoutingHint::Consistency),
77            "pool" => PoolingModeHint::from_str(value).ok().map(RoutingHint::Pool),
78            "cache" => CacheBehavior::from_str(value).ok().map(RoutingHint::Cache),
79            "timeout" => parse_duration(value).map(RoutingHint::Timeout),
80            "priority" => QueryPriority::from_str(value).ok().map(RoutingHint::Priority),
81            "lag" => parse_duration(value).map(RoutingHint::MaxLag),
82            "retry" => self.parse_retry(value).map(RoutingHint::Retry),
83            "branch" => Some(RoutingHint::Branch(value.to_string())),
84            "twr" => value.parse::<bool>().ok().map(RoutingHint::TransparentWriteRouting),
85            "tool" => Some(RoutingHint::AgentTool(value.to_string())),
86            "workflow" => Some(RoutingHint::WorkflowStep(value.to_string())),
87            "prefetch" => value.parse::<bool>().ok().map(RoutingHint::Prefetch),
88            "cache_ttl" => value.parse::<u64>().ok().map(|s| RoutingHint::CacheTtl(Duration::from_secs(s))),
89            _ => None,
90        }
91    }
92
93    /// Parse retry hint value
94    fn parse_retry(&self, value: &str) -> Option<RetryBehavior> {
95        match value.to_lowercase().as_str() {
96            "true" | "yes" => Some(RetryBehavior::Auto),
97            "false" | "no" => Some(RetryBehavior::None),
98            _ => value.parse::<u32>().ok().map(RetryBehavior::Count),
99        }
100    }
101
102    /// Strip hints from query for backend execution
103    pub fn strip(&self, query: &str) -> String {
104        HINT_REGEX.replace_all(query, "").trim().to_string()
105    }
106
107    /// Extract raw hint string from query (for logging)
108    pub fn extract_raw(&self, query: &str) -> Vec<String> {
109        HINT_REGEX
110            .captures_iter(query)
111            .filter_map(|cap| cap.get(0).map(|m| m.as_str().to_string()))
112            .collect()
113    }
114}
115
116/// Parsed hints collection
117#[derive(Debug, Clone, Default)]
118pub struct ParsedHints {
119    /// All parsed hints
120    hints: Vec<RoutingHint>,
121    /// Route target (if specified)
122    pub route: Option<RouteTarget>,
123    /// Specific node (if specified)
124    pub node: Option<String>,
125    /// Consistency level (if specified)
126    pub consistency: Option<ConsistencyLevel>,
127    /// Pool mode (if specified)
128    pub pool: Option<PoolingModeHint>,
129    /// Cache behavior (if specified)
130    pub cache: Option<CacheBehavior>,
131    /// Query timeout (if specified)
132    pub timeout: Option<Duration>,
133    /// Query priority (if specified)
134    pub priority: Option<QueryPriority>,
135    /// Maximum acceptable lag (if specified)
136    pub max_lag: Option<Duration>,
137    /// Retry behavior (if specified)
138    pub retry: Option<RetryBehavior>,
139    /// Branch name (if specified)
140    pub branch: Option<String>,
141    /// Transparent Write Routing (if specified)
142    pub twr: Option<bool>,
143    /// Cache TTL override (if specified)
144    pub cache_ttl: Option<Duration>,
145}
146
147impl ParsedHints {
148    /// Add a hint to the collection
149    pub fn add(&mut self, hint: RoutingHint) {
150        match &hint {
151            RoutingHint::Route(target) => self.route = Some(*target),
152            RoutingHint::Node(name) => self.node = Some(name.clone()),
153            RoutingHint::Consistency(level) => self.consistency = Some(*level),
154            RoutingHint::Pool(mode) => self.pool = Some(*mode),
155            RoutingHint::Cache(behavior) => self.cache = Some(*behavior),
156            RoutingHint::Timeout(dur) => self.timeout = Some(*dur),
157            RoutingHint::Priority(pri) => self.priority = Some(*pri),
158            RoutingHint::MaxLag(dur) => self.max_lag = Some(*dur),
159            RoutingHint::Retry(retry) => self.retry = Some(retry.clone()),
160            RoutingHint::Branch(name) => self.branch = Some(name.clone()),
161            RoutingHint::TransparentWriteRouting(enabled) => self.twr = Some(*enabled),
162            RoutingHint::CacheTtl(dur) => self.cache_ttl = Some(*dur),
163            _ => {}
164        }
165        self.hints.push(hint);
166    }
167
168    /// Check if any hints were parsed
169    pub fn is_empty(&self) -> bool {
170        self.hints.is_empty()
171    }
172
173    /// Get number of hints
174    pub fn len(&self) -> usize {
175        self.hints.len()
176    }
177
178    /// Get all hints
179    pub fn hints(&self) -> &[RoutingHint] {
180        &self.hints
181    }
182
183    /// Check if route=primary is specified
184    pub fn is_primary_route(&self) -> bool {
185        matches!(self.route, Some(RouteTarget::Primary))
186    }
187
188    /// Check if any standby route is specified
189    pub fn is_standby_route(&self) -> bool {
190        matches!(
191            self.route,
192            Some(RouteTarget::Standby) | Some(RouteTarget::Sync) |
193            Some(RouteTarget::SemiSync) | Some(RouteTarget::Async)
194        )
195    }
196
197    /// Validate hint combinations
198    pub fn validate(&self) -> Result<()> {
199        // Check for conflicting hints
200        if let (Some(RouteTarget::Async), Some(ConsistencyLevel::Strong)) =
201            (self.route, self.consistency)
202        {
203            return Err(RoutingError::InvalidHintCombination(
204                "route=async and consistency=strong are incompatible".to_string(),
205            ));
206        }
207
208        // Bounded consistency requires lag specification for proper enforcement
209        if self.consistency == Some(ConsistencyLevel::Bounded) && self.max_lag.is_none() {
210            // Not an error, just a warning - use default lag
211        }
212
213        Ok(())
214    }
215}
216
217/// Individual routing hint
218#[derive(Debug, Clone, PartialEq)]
219pub enum RoutingHint {
220    /// Target node type
221    Route(RouteTarget),
222
223    /// Specific node by name
224    Node(String),
225
226    /// Consistency level requirement
227    Consistency(ConsistencyLevel),
228
229    /// Connection pool mode
230    Pool(PoolingModeHint),
231
232    /// Cache behavior
233    Cache(CacheBehavior),
234
235    /// Query timeout override
236    Timeout(Duration),
237
238    /// Query priority for scheduling
239    Priority(QueryPriority),
240
241    /// Maximum acceptable replication lag
242    MaxLag(Duration),
243
244    /// Retry behavior on failure
245    Retry(RetryBehavior),
246
247    /// Branch name for branch-aware routing
248    Branch(String),
249
250    /// Enable Transparent Write Routing
251    TransparentWriteRouting(bool),
252
253    /// Agent tool identifier
254    AgentTool(String),
255
256    /// Workflow step identifier
257    WorkflowStep(String),
258
259    /// Prefetch hint for context retrieval
260    Prefetch(bool),
261
262    /// Cache TTL override
263    CacheTtl(Duration),
264}
265
/// The kind of node a `route` hint asks for.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum RouteTarget {
    /// The primary node, for writes or critical reads
    /// (`primary`, `master`, `leader`).
    Primary,
    /// Any standby, for read scaling (`standby`, `replica`, `secondary`).
    Standby,
    /// A synchronous standby only (`sync`, `synchronous`).
    Sync,
    /// A semi-synchronous standby (`semisync`, `semi-sync`, `semi_sync`).
    SemiSync,
    /// An asynchronous standby with eventual consistency
    /// (`async`, `asynchronous`).
    Async,
    /// Any available node (`any`, `all`).
    Any,
    /// Prefer the local/closest node (`local`, `nearest`).
    Local,
    /// A vector-optimized node (`vector`).
    Vector,
}
286
287impl FromStr for RouteTarget {
288    type Err = RoutingError;
289
290    fn from_str(s: &str) -> Result<Self> {
291        match s.to_lowercase().as_str() {
292            "primary" | "master" | "leader" => Ok(RouteTarget::Primary),
293            "standby" | "replica" | "secondary" => Ok(RouteTarget::Standby),
294            "sync" | "synchronous" => Ok(RouteTarget::Sync),
295            "semisync" | "semi-sync" | "semi_sync" => Ok(RouteTarget::SemiSync),
296            "async" | "asynchronous" => Ok(RouteTarget::Async),
297            "any" | "all" => Ok(RouteTarget::Any),
298            "local" | "nearest" => Ok(RouteTarget::Local),
299            "vector" => Ok(RouteTarget::Vector),
300            _ => Err(RoutingError::ParseError(format!("Unknown route target: {}", s))),
301        }
302    }
303}
304
305impl std::fmt::Display for RouteTarget {
306    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
307        match self {
308            RouteTarget::Primary => write!(f, "primary"),
309            RouteTarget::Standby => write!(f, "standby"),
310            RouteTarget::Sync => write!(f, "sync"),
311            RouteTarget::SemiSync => write!(f, "semisync"),
312            RouteTarget::Async => write!(f, "async"),
313            RouteTarget::Any => write!(f, "any"),
314            RouteTarget::Local => write!(f, "local"),
315            RouteTarget::Vector => write!(f, "vector"),
316        }
317    }
318}
319
/// The consistency guarantee requested by a `consistency` hint.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ConsistencyLevel {
    /// Reads must come from the primary or a sync standby
    /// (`strong`, `strict`, `linearizable`).
    Strong,
    /// Semi-sync replicas with bounded lag are allowed
    /// (`bounded`, `session`, `read-your-writes`).
    Bounded,
    /// Any replica is allowed; eventual consistency (`eventual`, `weak`).
    Eventual,
}
330
331impl FromStr for ConsistencyLevel {
332    type Err = RoutingError;
333
334    fn from_str(s: &str) -> Result<Self> {
335        match s.to_lowercase().as_str() {
336            "strong" | "strict" | "linearizable" => Ok(ConsistencyLevel::Strong),
337            "bounded" | "session" | "read-your-writes" => Ok(ConsistencyLevel::Bounded),
338            "eventual" | "weak" => Ok(ConsistencyLevel::Eventual),
339            _ => Err(RoutingError::ParseError(format!("Unknown consistency level: {}", s))),
340        }
341    }
342}
343
344impl std::fmt::Display for ConsistencyLevel {
345    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
346        match self {
347            ConsistencyLevel::Strong => write!(f, "strong"),
348            ConsistencyLevel::Bounded => write!(f, "bounded"),
349            ConsistencyLevel::Eventual => write!(f, "eventual"),
350        }
351    }
352}
353
/// Pool mode requested by a `pool` hint (mirrors `pool::PoolingMode`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum PoolingModeHint {
    /// Parsed from `session`.
    Session,
    /// Parsed from `transaction` or `tx`.
    Transaction,
    /// Parsed from `statement`, `stmt` or `query`.
    Statement,
}
361
362impl FromStr for PoolingModeHint {
363    type Err = RoutingError;
364
365    fn from_str(s: &str) -> Result<Self> {
366        match s.to_lowercase().as_str() {
367            "session" => Ok(PoolingModeHint::Session),
368            "transaction" | "tx" => Ok(PoolingModeHint::Transaction),
369            "statement" | "stmt" | "query" => Ok(PoolingModeHint::Statement),
370            _ => Err(RoutingError::ParseError(format!("Unknown pool mode: {}", s))),
371        }
372    }
373}
374
#[cfg(feature = "pool-modes")]
impl From<PoolingModeHint> for PoolingMode {
    /// Map each hint variant onto the `PoolingMode` variant of the same name.
    fn from(hint: PoolingModeHint) -> Self {
        match hint {
            PoolingModeHint::Session => Self::Session,
            PoolingModeHint::Transaction => Self::Transaction,
            PoolingModeHint::Statement => Self::Statement,
        }
    }
}
385
/// How a `cache` hint asks the cache to be used.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum CacheBehavior {
    /// Normal caching (`normal`, `default`).
    Normal,
    /// Skip the cache entirely (`skip`, `bypass`, `none`).
    Skip,
    /// Refresh the cache: bypass the read, update on response
    /// (`refresh`, `force`, `update`).
    Refresh,
    /// Use the semantic (L3) cache (`semantic`, `l3`, `vector`).
    Semantic,
    /// Use only the L1 cache (`l1`, `hot`).
    L1Only,
    /// Use only the L2 cache (`l2`, `warm`).
    L2Only,
}
402
403impl FromStr for CacheBehavior {
404    type Err = RoutingError;
405
406    fn from_str(s: &str) -> Result<Self> {
407        match s.to_lowercase().as_str() {
408            "normal" | "default" => Ok(CacheBehavior::Normal),
409            "skip" | "bypass" | "none" => Ok(CacheBehavior::Skip),
410            "refresh" | "force" | "update" => Ok(CacheBehavior::Refresh),
411            "semantic" | "l3" | "vector" => Ok(CacheBehavior::Semantic),
412            "l1" | "hot" => Ok(CacheBehavior::L1Only),
413            "l2" | "warm" => Ok(CacheBehavior::L2Only),
414            _ => Err(RoutingError::ParseError(format!("Unknown cache behavior: {}", s))),
415        }
416    }
417}
418
/// Scheduling priority requested by a `priority` hint.
///
/// Variants are declared in ascending order, so the derived `Ord` ranks
/// `Low < Normal < High < Critical`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum QueryPriority {
    /// Parsed from `low` or `background`.
    Low = 0,
    /// Parsed from `normal` or `default`.
    Normal = 1,
    /// Parsed from `high` or `elevated`.
    High = 2,
    /// Parsed from `critical`, `urgent` or `realtime`.
    Critical = 3,
}
427
428impl FromStr for QueryPriority {
429    type Err = RoutingError;
430
431    fn from_str(s: &str) -> Result<Self> {
432        match s.to_lowercase().as_str() {
433            "low" | "background" => Ok(QueryPriority::Low),
434            "normal" | "default" => Ok(QueryPriority::Normal),
435            "high" | "elevated" => Ok(QueryPriority::High),
436            "critical" | "urgent" | "realtime" => Ok(QueryPriority::Critical),
437            _ => Err(RoutingError::ParseError(format!("Unknown priority: {}", s))),
438        }
439    }
440}
441
442impl Default for QueryPriority {
443    fn default() -> Self {
444        QueryPriority::Normal
445    }
446}
447
/// Retry behavior requested through the `retry` hint.
///
/// The default is [`RetryBehavior::Auto`].
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub enum RetryBehavior {
    /// No retry
    None,
    /// Automatic retry with default count
    #[default]
    Auto,
    /// Retry specific number of times
    Count(u32),
}
464
#[cfg(test)]
mod tests {
    use super::*;

    /// Parse `sql` with a default-constructed (`new`) parser.
    fn parse(sql: &str) -> ParsedHints {
        HintParser::new().parse(sql)
    }

    #[test]
    fn test_parse_single_hint() {
        let hints = parse("/*helios:route=primary*/ SELECT * FROM users");

        assert!(!hints.is_empty());
        assert_eq!(hints.route, Some(RouteTarget::Primary));
    }

    #[test]
    fn test_parse_multiple_hints() {
        let hints = parse(
            "/*helios:route=standby,consistency=eventual,timeout=5s*/ SELECT * FROM products",
        );

        assert_eq!(hints.len(), 3);
        assert_eq!(hints.route, Some(RouteTarget::Standby));
        assert_eq!(hints.consistency, Some(ConsistencyLevel::Eventual));
        assert_eq!(hints.timeout, Some(Duration::from_secs(5)));
    }

    #[test]
    fn test_parse_node_hint() {
        let hints = parse("/*helios:node=standby-sync-1*/ SELECT * FROM logs");

        assert_eq!(hints.node, Some("standby-sync-1".to_string()));
    }

    #[test]
    fn test_parse_lag_hint() {
        let hints = parse("/*helios:route=async,lag=10s*/ SELECT COUNT(*) FROM events");

        assert_eq!(hints.route, Some(RouteTarget::Async));
        assert_eq!(hints.max_lag, Some(Duration::from_secs(10)));
    }

    #[test]
    fn test_parse_priority_hint() {
        let hints = parse("/*helios:priority=critical*/ SELECT balance FROM accounts");

        assert_eq!(hints.priority, Some(QueryPriority::Critical));
    }

    #[test]
    fn test_parse_cache_hint() {
        let hints = parse("/*helios:cache=skip*/ SELECT now()");

        assert_eq!(hints.cache, Some(CacheBehavior::Skip));
    }

    #[test]
    fn test_parse_pool_hint() {
        let hints = parse("/*helios:pool=transaction*/ BEGIN");

        assert_eq!(hints.pool, Some(PoolingModeHint::Transaction));
    }

    #[test]
    fn test_parse_retry_hint() {
        assert_eq!(
            parse("/*helios:retry=3*/ SELECT 1").retry,
            Some(RetryBehavior::Count(3))
        );
        assert_eq!(
            parse("/*helios:retry=yes*/ SELECT 1").retry,
            Some(RetryBehavior::Auto)
        );
        assert_eq!(
            parse("/*helios:retry=no*/ SELECT 1").retry,
            Some(RetryBehavior::None)
        );
        // A value that is neither a keyword nor a u32 is dropped.
        assert!(parse("/*helios:retry=maybe*/ SELECT 1").is_empty());
    }

    #[test]
    fn test_parse_cache_ttl_hint() {
        let hints = parse("/*helios:cache_ttl=60*/ SELECT * FROM products");

        assert_eq!(hints.cache_ttl, Some(Duration::from_secs(60)));
    }

    #[test]
    fn test_parse_prefetch_hint() {
        let hints = parse("/*helios:prefetch=true*/ SELECT content FROM docs");

        assert_eq!(hints.hints(), &[RoutingHint::Prefetch(true)]);
    }

    #[test]
    fn test_unknown_key_is_ignored() {
        let hints = parse("/*helios:bogus=1*/ SELECT 1");

        assert!(hints.is_empty());
    }

    #[test]
    fn test_last_hint_of_a_kind_wins() {
        let hints = parse("/*helios:route=primary,route=standby*/ SELECT 1");

        assert_eq!(hints.len(), 2);
        assert_eq!(hints.route, Some(RouteTarget::Standby));
    }

    #[test]
    fn test_keys_and_values_are_case_insensitive() {
        let hints = parse("/*helios:ROUTE=Primary*/ SELECT 1");

        assert_eq!(hints.route, Some(RouteTarget::Primary));
    }

    #[test]
    fn test_strip_hints() {
        let parser = HintParser::new();
        let stripped = parser.strip("/*helios:route=primary*/ SELECT * FROM users WHERE id = 1");

        assert_eq!(stripped, "SELECT * FROM users WHERE id = 1");
    }

    #[test]
    fn test_strip_multiple_hints() {
        let parser = HintParser::new();
        let stripped =
            parser.strip("/*helios:route=standby*/ SELECT * /*helios:cache=skip*/ FROM users");

        // Interior whitespace around a removed hint is intentionally not collapsed.
        assert_eq!(stripped, "SELECT *  FROM users");
    }

    #[test]
    fn test_validate_conflicting_hints() {
        let hints = parse("/*helios:route=async,consistency=strong*/ SELECT * FROM users");

        assert!(hints.validate().is_err());
    }

    #[test]
    fn test_route_target_parsing() {
        let cases = [
            ("primary", RouteTarget::Primary),
            ("master", RouteTarget::Primary),
            ("standby", RouteTarget::Standby),
            ("replica", RouteTarget::Standby),
            ("sync", RouteTarget::Sync),
            ("async", RouteTarget::Async),
            ("local", RouteTarget::Local),
        ];

        for (text, expected) in cases {
            assert_eq!(RouteTarget::from_str(text).unwrap(), expected);
        }
    }

    #[test]
    fn test_consistency_level_parsing() {
        let cases = [
            ("strong", ConsistencyLevel::Strong),
            ("bounded", ConsistencyLevel::Bounded),
            ("eventual", ConsistencyLevel::Eventual),
        ];

        for (text, expected) in cases {
            assert_eq!(ConsistencyLevel::from_str(text).unwrap(), expected);
        }
    }

    #[test]
    fn test_query_priority_ordering() {
        assert!(QueryPriority::Critical > QueryPriority::High);
        assert!(QueryPriority::High > QueryPriority::Normal);
        assert!(QueryPriority::Normal > QueryPriority::Low);
    }

    #[test]
    fn test_ai_workflow_hints() {
        let hints = parse(
            "/*helios:route=async,tool=knowledge_search,workflow=planning*/ SELECT content FROM docs",
        );

        assert!(!hints.is_empty());
        assert_eq!(hints.route, Some(RouteTarget::Async));

        // Tool and workflow hints have no shortcut field, so look in the list.
        let has_tool = hints
            .hints()
            .iter()
            .any(|h| matches!(h, RoutingHint::AgentTool(t) if t == "knowledge_search"));
        let has_workflow = hints
            .hints()
            .iter()
            .any(|h| matches!(h, RoutingHint::WorkflowStep(w) if w == "planning"));

        assert!(has_tool);
        assert!(has_workflow);
    }

    #[test]
    fn test_branch_hint() {
        let hints = parse("/*helios:branch=analytics,route=local*/ SELECT * FROM reports");

        assert_eq!(hints.branch, Some("analytics".to_string()));
        assert_eq!(hints.route, Some(RouteTarget::Local));
    }

    #[test]
    fn test_twr_hint() {
        let hints = parse("/*helios:route=sync,twr=true*/ INSERT INTO logs VALUES (1)");

        assert_eq!(hints.route, Some(RouteTarget::Sync));
        assert_eq!(hints.twr, Some(true));
    }

    #[test]
    fn test_empty_query() {
        let hints = parse("SELECT * FROM users");

        assert!(hints.is_empty());
    }

    #[test]
    fn test_extract_raw() {
        let parser = HintParser::new();
        let raw = parser.extract_raw("/*helios:route=primary*/ SELECT /*helios:cache=skip*/ 1");

        assert_eq!(raw.len(), 2);
        assert!(raw[0].contains("route=primary"));
        assert!(raw[1].contains("cache=skip"));
    }
}