Skip to main content

heliosdb_proxy/routing/
hint_parser.rs

1//! SQL Hint Parser
2//!
3//! Parses routing hints from SQL comments.
4//! Supports both single hints and comma-separated multiple hints.
5//!
6//! # Format
7//!
8//! ```text
9//! /*helios:key=value*/
10//! /*helios:key1=value1,key2=value2*/
11//! ```
12
13use super::{parse_duration, Result, RoutingError};
14use once_cell::sync::Lazy;
15use regex::Regex;
16use std::str::FromStr;
17use std::time::Duration;
18
19#[cfg(feature = "pool-modes")]
20use crate::pool::PoolingMode;
21
22/// Compiled regex for hint parsing
23static HINT_REGEX: Lazy<Regex> =
24    Lazy::new(|| Regex::new(r"/\*\s*helios:([^*]+)\*/").expect("Invalid hint regex"));
25
26/// Key-value pair regex
27static KV_REGEX: Lazy<Regex> =
28    Lazy::new(|| Regex::new(r"(\w+)\s*=\s*([^,\s]+)").expect("Invalid key-value regex"));
29
/// Hint parser for SQL routing hints.
///
/// Both [`HintParser::new`] and [`Default::default`] produce a parser with
/// `strip_hints` enabled; use [`HintParser::without_stripping`] to disable it.
#[derive(Debug, Clone)]
pub struct HintParser {
    /// Whether to strip hints from query before sending to backend
    pub strip_hints: bool,
}

impl Default for HintParser {
    /// Same configuration as [`HintParser::new`]: hint stripping is enabled.
    ///
    /// A derived `Default` would set `strip_hints` to `false`, silently
    /// disagreeing with the primary constructor.
    fn default() -> Self {
        Self { strip_hints: true }
    }
}
36
37impl HintParser {
38    /// Create a new hint parser
39    pub fn new() -> Self {
40        Self { strip_hints: true }
41    }
42
43    /// Create parser with hint stripping disabled
44    pub fn without_stripping() -> Self {
45        Self { strip_hints: false }
46    }
47
48    /// Parse all routing hints from a SQL query
49    pub fn parse(&self, query: &str) -> ParsedHints {
50        let mut hints = ParsedHints::default();
51
52        for cap in HINT_REGEX.captures_iter(query) {
53            let hint_content = cap.get(1).map(|m| m.as_str()).unwrap_or("");
54
55            // Parse key-value pairs
56            for kv in KV_REGEX.captures_iter(hint_content) {
57                let key = kv.get(1).map(|m| m.as_str()).unwrap_or("");
58                let value = kv.get(2).map(|m| m.as_str()).unwrap_or("");
59
60                if let Some(hint) = self.parse_hint(key, value) {
61                    hints.add(hint);
62                }
63            }
64        }
65
66        hints
67    }
68
69    /// Parse a single hint from key-value pair
70    fn parse_hint(&self, key: &str, value: &str) -> Option<RoutingHint> {
71        match key.to_lowercase().as_str() {
72            "route" => RouteTarget::from_str(value).ok().map(RoutingHint::Route),
73            "node" => Some(RoutingHint::Node(value.to_string())),
74            "consistency" => ConsistencyLevel::from_str(value)
75                .ok()
76                .map(RoutingHint::Consistency),
77            "pool" => PoolingModeHint::from_str(value).ok().map(RoutingHint::Pool),
78            "cache" => CacheBehavior::from_str(value).ok().map(RoutingHint::Cache),
79            "timeout" => parse_duration(value).map(RoutingHint::Timeout),
80            "priority" => QueryPriority::from_str(value)
81                .ok()
82                .map(RoutingHint::Priority),
83            "lag" => parse_duration(value).map(RoutingHint::MaxLag),
84            "retry" => self.parse_retry(value).map(RoutingHint::Retry),
85            "branch" => Some(RoutingHint::Branch(value.to_string())),
86            "twr" => value
87                .parse::<bool>()
88                .ok()
89                .map(RoutingHint::TransparentWriteRouting),
90            "tool" => Some(RoutingHint::AgentTool(value.to_string())),
91            "workflow" => Some(RoutingHint::WorkflowStep(value.to_string())),
92            "prefetch" => value.parse::<bool>().ok().map(RoutingHint::Prefetch),
93            "cache_ttl" => value
94                .parse::<u64>()
95                .ok()
96                .map(|s| RoutingHint::CacheTtl(Duration::from_secs(s))),
97            _ => None,
98        }
99    }
100
101    /// Parse retry hint value
102    fn parse_retry(&self, value: &str) -> Option<RetryBehavior> {
103        match value.to_lowercase().as_str() {
104            "true" | "yes" => Some(RetryBehavior::Auto),
105            "false" | "no" => Some(RetryBehavior::None),
106            _ => value.parse::<u32>().ok().map(RetryBehavior::Count),
107        }
108    }
109
110    /// Strip hints from query for backend execution
111    pub fn strip(&self, query: &str) -> String {
112        HINT_REGEX.replace_all(query, "").trim().to_string()
113    }
114
115    /// Extract raw hint string from query (for logging)
116    pub fn extract_raw(&self, query: &str) -> Vec<String> {
117        HINT_REGEX
118            .captures_iter(query)
119            .filter_map(|cap| cap.get(0).map(|m| m.as_str().to_string()))
120            .collect()
121    }
122}
123
124/// Parsed hints collection
125#[derive(Debug, Clone, Default)]
126pub struct ParsedHints {
127    /// All parsed hints
128    hints: Vec<RoutingHint>,
129    /// Route target (if specified)
130    pub route: Option<RouteTarget>,
131    /// Specific node (if specified)
132    pub node: Option<String>,
133    /// Consistency level (if specified)
134    pub consistency: Option<ConsistencyLevel>,
135    /// Pool mode (if specified)
136    pub pool: Option<PoolingModeHint>,
137    /// Cache behavior (if specified)
138    pub cache: Option<CacheBehavior>,
139    /// Query timeout (if specified)
140    pub timeout: Option<Duration>,
141    /// Query priority (if specified)
142    pub priority: Option<QueryPriority>,
143    /// Maximum acceptable lag (if specified)
144    pub max_lag: Option<Duration>,
145    /// Retry behavior (if specified)
146    pub retry: Option<RetryBehavior>,
147    /// Branch name (if specified)
148    pub branch: Option<String>,
149    /// Transparent Write Routing (if specified)
150    pub twr: Option<bool>,
151    /// Cache TTL override (if specified)
152    pub cache_ttl: Option<Duration>,
153}
154
155impl ParsedHints {
156    /// Add a hint to the collection
157    pub fn add(&mut self, hint: RoutingHint) {
158        match &hint {
159            RoutingHint::Route(target) => self.route = Some(*target),
160            RoutingHint::Node(name) => self.node = Some(name.clone()),
161            RoutingHint::Consistency(level) => self.consistency = Some(*level),
162            RoutingHint::Pool(mode) => self.pool = Some(*mode),
163            RoutingHint::Cache(behavior) => self.cache = Some(*behavior),
164            RoutingHint::Timeout(dur) => self.timeout = Some(*dur),
165            RoutingHint::Priority(pri) => self.priority = Some(*pri),
166            RoutingHint::MaxLag(dur) => self.max_lag = Some(*dur),
167            RoutingHint::Retry(retry) => self.retry = Some(retry.clone()),
168            RoutingHint::Branch(name) => self.branch = Some(name.clone()),
169            RoutingHint::TransparentWriteRouting(enabled) => self.twr = Some(*enabled),
170            RoutingHint::CacheTtl(dur) => self.cache_ttl = Some(*dur),
171            _ => {}
172        }
173        self.hints.push(hint);
174    }
175
176    /// Check if any hints were parsed
177    pub fn is_empty(&self) -> bool {
178        self.hints.is_empty()
179    }
180
181    /// Get number of hints
182    pub fn len(&self) -> usize {
183        self.hints.len()
184    }
185
186    /// Get all hints
187    pub fn hints(&self) -> &[RoutingHint] {
188        &self.hints
189    }
190
191    /// Check if route=primary is specified
192    pub fn is_primary_route(&self) -> bool {
193        matches!(self.route, Some(RouteTarget::Primary))
194    }
195
196    /// Check if any standby route is specified
197    pub fn is_standby_route(&self) -> bool {
198        matches!(
199            self.route,
200            Some(RouteTarget::Standby)
201                | Some(RouteTarget::Sync)
202                | Some(RouteTarget::SemiSync)
203                | Some(RouteTarget::Async)
204        )
205    }
206
207    /// Validate hint combinations
208    pub fn validate(&self) -> Result<()> {
209        // Check for conflicting hints
210        if let (Some(RouteTarget::Async), Some(ConsistencyLevel::Strong)) =
211            (self.route, self.consistency)
212        {
213            return Err(RoutingError::InvalidHintCombination(
214                "route=async and consistency=strong are incompatible".to_string(),
215            ));
216        }
217
218        // Bounded consistency requires lag specification for proper enforcement
219        if self.consistency == Some(ConsistencyLevel::Bounded) && self.max_lag.is_none() {
220            // Not an error, just a warning - use default lag
221        }
222
223        Ok(())
224    }
225}
226
227/// Individual routing hint
228#[derive(Debug, Clone, PartialEq)]
229pub enum RoutingHint {
230    /// Target node type
231    Route(RouteTarget),
232
233    /// Specific node by name
234    Node(String),
235
236    /// Consistency level requirement
237    Consistency(ConsistencyLevel),
238
239    /// Connection pool mode
240    Pool(PoolingModeHint),
241
242    /// Cache behavior
243    Cache(CacheBehavior),
244
245    /// Query timeout override
246    Timeout(Duration),
247
248    /// Query priority for scheduling
249    Priority(QueryPriority),
250
251    /// Maximum acceptable replication lag
252    MaxLag(Duration),
253
254    /// Retry behavior on failure
255    Retry(RetryBehavior),
256
257    /// Branch name for branch-aware routing
258    Branch(String),
259
260    /// Enable Transparent Write Routing
261    TransparentWriteRouting(bool),
262
263    /// Agent tool identifier
264    AgentTool(String),
265
266    /// Workflow step identifier
267    WorkflowStep(String),
268
269    /// Prefetch hint for context retrieval
270    Prefetch(bool),
271
272    /// Cache TTL override
273    CacheTtl(Duration),
274}
275
/// Which kind of node a query should be routed to.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum RouteTarget {
    /// The primary node (writes or critical reads).
    Primary,
    /// Any standby node (read scaling).
    Standby,
    /// Synchronous standbys only.
    Sync,
    /// A semi-synchronous standby.
    SemiSync,
    /// An asynchronous standby (eventual consistency).
    Async,
    /// Whatever node is available.
    Any,
    /// Prefer the local/closest node.
    Local,
    /// A vector-optimized node.
    Vector,
}
296
297impl FromStr for RouteTarget {
298    type Err = RoutingError;
299
300    fn from_str(s: &str) -> Result<Self> {
301        match s.to_lowercase().as_str() {
302            "primary" | "master" | "leader" => Ok(RouteTarget::Primary),
303            "standby" | "replica" | "secondary" => Ok(RouteTarget::Standby),
304            "sync" | "synchronous" => Ok(RouteTarget::Sync),
305            "semisync" | "semi-sync" | "semi_sync" => Ok(RouteTarget::SemiSync),
306            "async" | "asynchronous" => Ok(RouteTarget::Async),
307            "any" | "all" => Ok(RouteTarget::Any),
308            "local" | "nearest" => Ok(RouteTarget::Local),
309            "vector" => Ok(RouteTarget::Vector),
310            _ => Err(RoutingError::ParseError(format!(
311                "Unknown route target: {}",
312                s
313            ))),
314        }
315    }
316}
317
318impl std::fmt::Display for RouteTarget {
319    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
320        match self {
321            RouteTarget::Primary => write!(f, "primary"),
322            RouteTarget::Standby => write!(f, "standby"),
323            RouteTarget::Sync => write!(f, "sync"),
324            RouteTarget::SemiSync => write!(f, "semisync"),
325            RouteTarget::Async => write!(f, "async"),
326            RouteTarget::Any => write!(f, "any"),
327            RouteTarget::Local => write!(f, "local"),
328            RouteTarget::Vector => write!(f, "vector"),
329        }
330    }
331}
332
/// How fresh the data read by a query must be.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum ConsistencyLevel {
    /// Read only from the primary or a synchronous standby.
    Strong,
    /// A semi-synchronous standby is acceptable, within a bounded lag.
    Bounded,
    /// Any replica is acceptable (eventual consistency).
    Eventual,
}
343
344impl FromStr for ConsistencyLevel {
345    type Err = RoutingError;
346
347    fn from_str(s: &str) -> Result<Self> {
348        match s.to_lowercase().as_str() {
349            "strong" | "strict" | "linearizable" => Ok(ConsistencyLevel::Strong),
350            "bounded" | "session" | "read-your-writes" => Ok(ConsistencyLevel::Bounded),
351            "eventual" | "weak" => Ok(ConsistencyLevel::Eventual),
352            _ => Err(RoutingError::ParseError(format!(
353                "Unknown consistency level: {}",
354                s
355            ))),
356        }
357    }
358}
359
360impl std::fmt::Display for ConsistencyLevel {
361    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
362        match self {
363            ConsistencyLevel::Strong => write!(f, "strong"),
364            ConsistencyLevel::Bounded => write!(f, "bounded"),
365            ConsistencyLevel::Eventual => write!(f, "eventual"),
366        }
367    }
368}
369
/// Connection pooling mode requested by a `pool=` hint.
///
/// Mirrors `pool::PoolingMode` variant-for-variant; a conversion is provided
/// when the `pool-modes` feature is enabled.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum PoolingModeHint {
    /// Session pooling (`pool=session`).
    Session,
    /// Transaction pooling (`pool=transaction` or `pool=tx`).
    Transaction,
    /// Statement pooling (`pool=statement`, `stmt`, or `query`).
    Statement,
}
377
378impl FromStr for PoolingModeHint {
379    type Err = RoutingError;
380
381    fn from_str(s: &str) -> Result<Self> {
382        match s.to_lowercase().as_str() {
383            "session" => Ok(PoolingModeHint::Session),
384            "transaction" | "tx" => Ok(PoolingModeHint::Transaction),
385            "statement" | "stmt" | "query" => Ok(PoolingModeHint::Statement),
386            _ => Err(RoutingError::ParseError(format!(
387                "Unknown pool mode: {}",
388                s
389            ))),
390        }
391    }
392}
393
#[cfg(feature = "pool-modes")]
impl From<PoolingModeHint> for PoolingMode {
    /// Map each hint variant onto the pool mode of the same name.
    fn from(hint: PoolingModeHint) -> Self {
        match hint {
            PoolingModeHint::Session => Self::Session,
            PoolingModeHint::Transaction => Self::Transaction,
            PoolingModeHint::Statement => Self::Statement,
        }
    }
}
404
/// Cache behavior requested by a `cache=` hint.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum CacheBehavior {
    /// Normal caching.
    Normal,
    /// Bypass the cache entirely.
    Skip,
    /// Bypass the cached read but update the cache with the response.
    Refresh,
    /// Use the semantic (L3) cache.
    Semantic,
    /// Restrict to the L1 cache.
    L1Only,
    /// Restrict to the L2 cache.
    L2Only,
}
421
422impl FromStr for CacheBehavior {
423    type Err = RoutingError;
424
425    fn from_str(s: &str) -> Result<Self> {
426        match s.to_lowercase().as_str() {
427            "normal" | "default" => Ok(CacheBehavior::Normal),
428            "skip" | "bypass" | "none" => Ok(CacheBehavior::Skip),
429            "refresh" | "force" | "update" => Ok(CacheBehavior::Refresh),
430            "semantic" | "l3" | "vector" => Ok(CacheBehavior::Semantic),
431            "l1" | "hot" => Ok(CacheBehavior::L1Only),
432            "l2" | "warm" => Ok(CacheBehavior::L2Only),
433            _ => Err(RoutingError::ParseError(format!(
434                "Unknown cache behavior: {}",
435                s
436            ))),
437        }
438    }
439}
440
/// Scheduling priority of a query, ordered from `Low` up to `Critical`.
///
/// `Normal` is the default.
#[derive(Clone, Copy, Debug, Default, Eq, Hash, Ord, PartialEq, PartialOrd)]
pub enum QueryPriority {
    /// Background work.
    Low = 0,
    /// Ordinary priority (the default).
    #[default]
    Normal = 1,
    /// Elevated priority.
    High = 2,
    /// Highest priority.
    Critical = 3,
}
450
451impl FromStr for QueryPriority {
452    type Err = RoutingError;
453
454    fn from_str(s: &str) -> Result<Self> {
455        match s.to_lowercase().as_str() {
456            "low" | "background" => Ok(QueryPriority::Low),
457            "normal" | "default" => Ok(QueryPriority::Normal),
458            "high" | "elevated" => Ok(QueryPriority::High),
459            "critical" | "urgent" | "realtime" => Ok(QueryPriority::Critical),
460            _ => Err(RoutingError::ParseError(format!("Unknown priority: {}", s))),
461        }
462    }
463}
464
/// Retry behavior requested by a `retry=` hint.
///
/// Derives `Eq` and `Hash` like the other hint-value enums in this module
/// (every payload supports them). `Copy` is deliberately not derived, so
/// existing `.clone()` call sites stay lint-clean.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Default)]
pub enum RetryBehavior {
    /// Never retry.
    None,
    /// Automatic retry with the default count (this is the `Default` variant).
    #[default]
    Auto,
    /// Retry exactly this many times.
    Count(u32),
}
476
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_parse_single_hint() {
        let parser = HintParser::new();
        let hints = parser.parse("/*helios:route=primary*/ SELECT * FROM users");

        assert!(!hints.is_empty());
        assert_eq!(hints.route, Some(RouteTarget::Primary));
    }

    #[test]
    fn test_parse_multiple_hints() {
        let parser = HintParser::new();
        let hints = parser.parse(
            "/*helios:route=standby,consistency=eventual,timeout=5s*/ SELECT * FROM products",
        );

        assert_eq!(hints.len(), 3);
        assert_eq!(hints.route, Some(RouteTarget::Standby));
        assert_eq!(hints.consistency, Some(ConsistencyLevel::Eventual));
        assert_eq!(hints.timeout, Some(Duration::from_secs(5)));
    }

    #[test]
    fn test_parse_node_hint() {
        let parser = HintParser::new();
        let hints = parser.parse("/*helios:node=standby-sync-1*/ SELECT * FROM logs");

        assert_eq!(hints.node, Some("standby-sync-1".to_string()));
    }

    #[test]
    fn test_parse_lag_hint() {
        let parser = HintParser::new();
        let hints = parser.parse("/*helios:route=async,lag=10s*/ SELECT COUNT(*) FROM events");

        assert_eq!(hints.route, Some(RouteTarget::Async));
        assert_eq!(hints.max_lag, Some(Duration::from_secs(10)));
    }

    #[test]
    fn test_parse_priority_hint() {
        let parser = HintParser::new();
        let hints = parser.parse("/*helios:priority=critical*/ SELECT balance FROM accounts");

        assert_eq!(hints.priority, Some(QueryPriority::Critical));
    }

    #[test]
    fn test_parse_cache_hint() {
        let parser = HintParser::new();
        let hints = parser.parse("/*helios:cache=skip*/ SELECT now()");

        assert_eq!(hints.cache, Some(CacheBehavior::Skip));
    }

    #[test]
    fn test_parse_pool_hint() {
        let parser = HintParser::new();
        let hints = parser.parse("/*helios:pool=transaction*/ BEGIN");

        assert_eq!(hints.pool, Some(PoolingModeHint::Transaction));
    }

    #[test]
    fn test_parse_retry_hint() {
        let parser = HintParser::new();

        assert_eq!(
            parser.parse("/*helios:retry=3*/ SELECT 1").retry,
            Some(RetryBehavior::Count(3))
        );
        assert_eq!(
            parser.parse("/*helios:retry=no*/ SELECT 1").retry,
            Some(RetryBehavior::None)
        );
        assert_eq!(
            parser.parse("/*helios:retry=YES*/ SELECT 1").retry,
            Some(RetryBehavior::Auto)
        );
    }

    #[test]
    fn test_parse_cache_ttl_hint() {
        let parser = HintParser::new();
        let hints = parser.parse("/*helios:cache_ttl=30*/ SELECT * FROM config");

        assert_eq!(hints.cache_ttl, Some(Duration::from_secs(30)));
    }

    #[test]
    fn test_prefetch_hint_recorded_in_list() {
        let parser = HintParser::new();
        let hints = parser.parse("/*helios:prefetch=true*/ SELECT 1");

        assert_eq!(hints.len(), 1);
        assert!(matches!(hints.hints(), [RoutingHint::Prefetch(true)]));
    }

    #[test]
    fn test_keys_are_case_insensitive() {
        let parser = HintParser::new();
        let hints = parser.parse("/*helios:ROUTE=PRIMARY*/ SELECT 1");

        assert!(hints.is_primary_route());
    }

    #[test]
    fn test_unknown_and_invalid_hints_ignored() {
        let parser = HintParser::new();
        let hints = parser.parse("/*helios:bogus=1,route=nowhere*/ SELECT 1");

        assert!(hints.is_empty());
        assert_eq!(hints.route, None);
    }

    #[test]
    fn test_later_hint_overrides_earlier() {
        let parser = HintParser::new();
        let hints = parser.parse("/*helios:route=primary*/ SELECT /*helios:route=async*/ 1");

        assert_eq!(hints.route, Some(RouteTarget::Async));
        assert_eq!(hints.len(), 2);
    }

    #[test]
    fn test_standby_route_detection() {
        let parser = HintParser::new();
        let hints = parser.parse("/*helios:route=semi-sync*/ SELECT 1");

        assert!(hints.is_standby_route());
        assert!(!hints.is_primary_route());
    }

    #[test]
    fn test_strip_hints() {
        let parser = HintParser::new();
        let query = "/*helios:route=primary*/ SELECT * FROM users WHERE id = 1";
        let stripped = parser.strip(query);

        assert_eq!(stripped, "SELECT * FROM users WHERE id = 1");
    }

    #[test]
    fn test_strip_multiple_hints() {
        let parser = HintParser::new();
        let query = "/*helios:route=standby*/ SELECT * /*helios:cache=skip*/ FROM users";
        let stripped = parser.strip(query);

        assert_eq!(stripped, "SELECT *  FROM users");
    }

    #[test]
    fn test_strip_query_without_hints() {
        let parser = HintParser::new();

        assert_eq!(parser.strip("  SELECT 1  "), "SELECT 1");
    }

    #[test]
    fn test_constructor_flags() {
        assert!(HintParser::new().strip_hints);
        assert!(!HintParser::without_stripping().strip_hints);
    }

    #[test]
    fn test_validate_conflicting_hints() {
        let parser = HintParser::new();
        let hints = parser.parse("/*helios:route=async,consistency=strong*/ SELECT * FROM users");

        let result = hints.validate();
        assert!(result.is_err());
    }

    #[test]
    fn test_validate_compatible_hints() {
        let parser = HintParser::new();
        let hints = parser.parse("/*helios:route=primary,consistency=strong*/ SELECT 1");

        assert!(hints.validate().is_ok());
    }

    #[test]
    fn test_route_target_parsing() {
        assert_eq!(
            RouteTarget::from_str("primary").unwrap(),
            RouteTarget::Primary
        );
        assert_eq!(
            RouteTarget::from_str("master").unwrap(),
            RouteTarget::Primary
        );
        assert_eq!(
            RouteTarget::from_str("standby").unwrap(),
            RouteTarget::Standby
        );
        assert_eq!(
            RouteTarget::from_str("replica").unwrap(),
            RouteTarget::Standby
        );
        assert_eq!(RouteTarget::from_str("sync").unwrap(), RouteTarget::Sync);
        assert_eq!(RouteTarget::from_str("async").unwrap(), RouteTarget::Async);
        assert_eq!(RouteTarget::from_str("local").unwrap(), RouteTarget::Local);
    }

    #[test]
    fn test_route_target_display_round_trip() {
        for target in [
            RouteTarget::Primary,
            RouteTarget::Standby,
            RouteTarget::Sync,
            RouteTarget::SemiSync,
            RouteTarget::Async,
            RouteTarget::Any,
            RouteTarget::Local,
            RouteTarget::Vector,
        ] {
            assert_eq!(RouteTarget::from_str(&target.to_string()).unwrap(), target);
        }
    }

    #[test]
    fn test_consistency_level_parsing() {
        assert_eq!(
            ConsistencyLevel::from_str("strong").unwrap(),
            ConsistencyLevel::Strong
        );
        assert_eq!(
            ConsistencyLevel::from_str("bounded").unwrap(),
            ConsistencyLevel::Bounded
        );
        assert_eq!(
            ConsistencyLevel::from_str("eventual").unwrap(),
            ConsistencyLevel::Eventual
        );
    }

    #[test]
    fn test_value_aliases() {
        assert_eq!(
            PoolingModeHint::from_str("tx").unwrap(),
            PoolingModeHint::Transaction
        );
        assert_eq!(
            CacheBehavior::from_str("bypass").unwrap(),
            CacheBehavior::Skip
        );
        assert_eq!(
            QueryPriority::from_str("urgent").unwrap(),
            QueryPriority::Critical
        );
        assert_eq!(
            ConsistencyLevel::from_str("read-your-writes").unwrap(),
            ConsistencyLevel::Bounded
        );
        assert!(RouteTarget::from_str("nowhere").is_err());
    }

    #[test]
    fn test_query_priority_ordering() {
        assert!(QueryPriority::Critical > QueryPriority::High);
        assert!(QueryPriority::High > QueryPriority::Normal);
        assert!(QueryPriority::Normal > QueryPriority::Low);
    }

    #[test]
    fn test_ai_workflow_hints() {
        let parser = HintParser::new();
        let hints = parser.parse(
            "/*helios:route=async,tool=knowledge_search,workflow=planning*/ SELECT content FROM docs"
        );

        assert!(!hints.is_empty());
        assert_eq!(hints.route, Some(RouteTarget::Async));

        // Check for tool and workflow hints in the list
        let has_tool = hints
            .hints()
            .iter()
            .any(|h| matches!(h, RoutingHint::AgentTool(t) if t == "knowledge_search"));
        let has_workflow = hints
            .hints()
            .iter()
            .any(|h| matches!(h, RoutingHint::WorkflowStep(w) if w == "planning"));

        assert!(has_tool);
        assert!(has_workflow);
    }

    #[test]
    fn test_branch_hint() {
        let parser = HintParser::new();
        let hints = parser.parse("/*helios:branch=analytics,route=local*/ SELECT * FROM reports");

        assert_eq!(hints.branch, Some("analytics".to_string()));
        assert_eq!(hints.route, Some(RouteTarget::Local));
    }

    #[test]
    fn test_twr_hint() {
        let parser = HintParser::new();
        let hints = parser.parse("/*helios:route=sync,twr=true*/ INSERT INTO logs VALUES (1)");

        assert_eq!(hints.route, Some(RouteTarget::Sync));
        assert_eq!(hints.twr, Some(true));
    }

    #[test]
    fn test_empty_query() {
        let parser = HintParser::new();
        let hints = parser.parse("SELECT * FROM users");

        assert!(hints.is_empty());
    }

    #[test]
    fn test_extract_raw() {
        let parser = HintParser::new();
        let raw = parser.extract_raw("/*helios:route=primary*/ SELECT /*helios:cache=skip*/ 1");

        assert_eq!(raw.len(), 2);
        assert!(raw[0].contains("route=primary"));
        assert!(raw[1].contains("cache=skip"));
    }
}