Skip to main content

heliosdb_proxy/cache/
hints.rs

1//! Cache Hints Parser
2//!
3//! Parses SQL comments to extract cache control hints.
4//!
5//! # Supported Hints
6//!
7//! ```sql
8//! /* helios:cache=skip */          -- Skip caching entirely
9//! /* helios:cache_ttl=60 */        -- Override TTL (seconds)
10//! /* helios:cache=semantic */      -- Enable semantic caching
11//! /* helios:cache_tables=a,b */    -- Override table dependencies
12//! /* helios:cache_refresh */       -- Force cache refresh
13//! ```
14
15use once_cell::sync::Lazy;
16use regex::Regex;
17use std::time::Duration;
18
/// Cache-control directives extracted from the hint comments of one SQL query.
///
/// Produced by [`parse_cache_hints`], which starts from `CacheHint::default()`
/// (every flag `false`, every option `None`) and then applies each hint it
/// finds. A query without any recognised hint therefore yields the default value.
#[derive(Debug, Clone, Default)]
pub struct CacheHint {
    /// Skip caching entirely.
    ///
    /// Turned on by `helios:cache=skip` (aliases: `no`, `off`, `false`, `disable`).
    pub skip: bool,

    /// Override TTL (`None` = use default).
    ///
    /// Filled from `helios:cache_ttl=…` / `helios:ttl=…`; the value is either a
    /// plain number of seconds or a number with a unit suffix such as `5m`.
    pub ttl: Option<Duration>,

    /// Enable semantic/L3 caching.
    ///
    /// Turned on by `helios:cache=semantic` (aliases: `l3`, `vector`) or by the
    /// stand-alone keys `helios:semantic` / `helios:semantic_cache`.
    pub semantic_cache: bool,

    /// Override table dependencies.
    ///
    /// Filled from the comma-separated list in `helios:cache_tables=…` /
    /// `helios:tables=…`. Names are trimmed and empty entries dropped; the
    /// parser never stores an empty list, so `Some` always holds at least one name.
    pub tables: Option<Vec<String>>,

    /// Force cache refresh (bypass read, update cache).
    ///
    /// Turned on by `helios:cache_refresh`, `helios:refresh` or
    /// `helios:nocache_read`; any value attached to those keys is ignored.
    pub refresh: bool,

    /// Specific cache level to use (`None` = no level was requested).
    ///
    /// Set by `helios:cache=l1|l2|all` and their aliases; see [`CacheLevelHint`].
    pub level: Option<CacheLevelHint>,
}
40
/// Cache level requested by a `helios:cache=<level>` hint.
///
/// Stored in `CacheHint::level`. The hint values listed on each variant are
/// the ones recognised by `apply_hint` (compared case-insensitively).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CacheLevelHint {
    /// Only use L1 (connection-local). Hint values: `l1`, `hot`, `local`.
    L1Only,
    /// Only use L2 (shared). Hint values: `l2`, `warm`, `shared`.
    L2Only,
    /// Only use L3 (semantic).
    ///
    /// NOTE: no hint parsed in this file produces this variant; `cache=l3`
    /// sets `CacheHint::semantic_cache` instead.
    L3Only,
    /// Use all levels. Hint values: `all`, `yes`, `on`, `true`, `enable`.
    All,
}
53
54// Regex patterns for hint parsing
55static HINT_PATTERN: Lazy<Regex> =
56    Lazy::new(|| Regex::new(r"/\*\s*helios:(\w+)(?:=([^*]+))?\s*\*/").unwrap());
57
58static HINT_PATTERN_DOUBLE_DASH: Lazy<Regex> =
59    Lazy::new(|| Regex::new(r"--\s*helios:(\w+)(?:=(\S+))?").unwrap());
60
61/// Parse cache hints from a SQL query
62pub fn parse_cache_hints(sql: &str) -> CacheHint {
63    let mut hint = CacheHint::default();
64
65    // Parse /* helios:key=value */ style hints
66    for cap in HINT_PATTERN.captures_iter(sql) {
67        let key = cap.get(1).map(|m| m.as_str()).unwrap_or("");
68        let value = cap.get(2).map(|m| m.as_str().trim());
69
70        apply_hint(&mut hint, key, value);
71    }
72
73    // Parse -- helios:key=value style hints
74    for cap in HINT_PATTERN_DOUBLE_DASH.captures_iter(sql) {
75        let key = cap.get(1).map(|m| m.as_str()).unwrap_or("");
76        let value = cap.get(2).map(|m| m.as_str().trim());
77
78        apply_hint(&mut hint, key, value);
79    }
80
81    hint
82}
83
84/// Apply a single hint to the CacheHint struct
85fn apply_hint(hint: &mut CacheHint, key: &str, value: Option<&str>) {
86    match key.to_lowercase().as_str() {
87        "cache" => {
88            if let Some(v) = value {
89                match v.to_lowercase().as_str() {
90                    "skip" | "no" | "off" | "false" | "disable" => {
91                        hint.skip = true;
92                    }
93                    "semantic" | "l3" | "vector" => {
94                        hint.semantic_cache = true;
95                    }
96                    "l1" | "hot" | "local" => {
97                        hint.level = Some(CacheLevelHint::L1Only);
98                    }
99                    "l2" | "warm" | "shared" => {
100                        hint.level = Some(CacheLevelHint::L2Only);
101                    }
102                    "all" | "yes" | "on" | "true" | "enable" => {
103                        hint.level = Some(CacheLevelHint::All);
104                    }
105                    _ => {}
106                }
107            }
108        }
109        "cache_ttl" | "ttl" => {
110            if let Some(v) = value {
111                if let Ok(secs) = v.parse::<u64>() {
112                    hint.ttl = Some(Duration::from_secs(secs));
113                } else if let Some(duration) = parse_duration(v) {
114                    hint.ttl = Some(duration);
115                }
116            }
117        }
118        "cache_tables" | "tables" => {
119            if let Some(v) = value {
120                let tables: Vec<String> = v
121                    .split(',')
122                    .map(|s| s.trim().to_string())
123                    .filter(|s| !s.is_empty())
124                    .collect();
125                if !tables.is_empty() {
126                    hint.tables = Some(tables);
127                }
128            }
129        }
130        "cache_refresh" | "refresh" | "nocache_read" => {
131            hint.refresh = true;
132        }
133        "semantic" | "semantic_cache" => {
134            hint.semantic_cache = true;
135        }
136        _ => {}
137    }
138}
139
/// Parse a duration string such as `"30s"`, `"5m"`, `"1.5h"`, `"2d"` or `"500ms"`.
///
/// The input is trimmed and lowercased, then split into a leading number
/// (ASCII digits and `.`) and a unit suffix. A missing unit means seconds.
///
/// Returns `None` when:
/// - the input is empty or does not start with a number,
/// - the numeric part is not a valid float (e.g. `"1.2.3"` or `"."`),
/// - the unit is not recognised (no whitespace is allowed before the unit),
/// - the resulting value does not fit into a [`Duration`].
///
/// The last case matters because the text comes straight from a SQL comment:
/// an oversized value must be rejected rather than panic.
fn parse_duration(s: &str) -> Option<Duration> {
    let s = s.trim().to_lowercase();

    // The number is the leading run of digits/dots; everything after is the unit.
    let num_end = s
        .find(|c: char| !(c.is_ascii_digit() || c == '.'))
        .unwrap_or(s.len());
    if num_end == 0 {
        // Covers both the empty string and input with no leading number.
        return None;
    }

    let (number, unit) = s.split_at(num_end);
    let num: f64 = number.parse().ok()?;

    let multiplier = match unit {
        "" | "s" | "sec" | "secs" | "second" | "seconds" => 1.0,
        "m" | "min" | "mins" | "minute" | "minutes" => 60.0,
        "h" | "hr" | "hrs" | "hour" | "hours" => 3600.0,
        "d" | "day" | "days" => 86400.0,
        "ms" | "millis" | "milliseconds" => 0.001,
        _ => return None,
    };

    // `Duration::from_secs_f64` panics on overflow or non-finite input; the
    // fallible variant turns those cases into `None`.
    Duration::try_from_secs_f64(num * multiplier).ok()
}
176
177/// Strip cache hints from SQL query
178pub fn strip_hints(sql: &str) -> String {
179    let result = HINT_PATTERN.replace_all(sql, "");
180    let result = HINT_PATTERN_DOUBLE_DASH.replace_all(&result, "");
181    result.trim().to_string()
182}
183
/// Check if a query is cacheable (SELECT, VALUES, TABLE, or WITH ... SELECT).
///
/// The check is purely textual and case-insensitive; it does not parse SQL and
/// expects `sql` to start with the statement keyword (leading whitespace is
/// ignored, leading comments are not).
///
/// A statement is rejected when:
/// - it does not start with `SELECT`, `VALUES`, `TABLE`, or `WITH` plus a
///   `SELECT` somewhere in the text;
/// - it starts with `WITH` and contains the word `INSERT`, `UPDATE`, `DELETE`
///   or `MERGE` — a data-modifying CTE, or a CTE feeding a write, must always
///   reach the database;
/// - it contains a row-locking clause (`FOR UPDATE`, `FOR SHARE`,
///   `FOR NO KEY UPDATE`, `FOR KEY SHARE`, `NOWAIT`, `SKIP LOCKED`), regardless
///   of how the words of the clause are separated (spaces, tabs, newlines).
///
/// The keyword scan deliberately errs on the side of "not cacheable": a
/// keyword that appears inside a string literal also disables caching.
pub fn is_cacheable_query(sql: &str) -> bool {
    const DATA_MODIFYING: [&str; 4] = ["INSERT", "UPDATE", "DELETE", "MERGE"];
    const LOCKING_CLAUSES: [&str; 6] = [
        "FOR UPDATE",
        "FOR SHARE",
        "FOR NO KEY UPDATE",
        "FOR KEY SHARE",
        "NOWAIT",
        "SKIP LOCKED",
    ];

    let upper = sql.trim().to_uppercase();

    // Only cache read operations.
    let is_with = upper.starts_with("WITH");
    let is_read = upper.starts_with("SELECT")
        || upper.starts_with("VALUES")
        || upper.starts_with("TABLE")
        || (is_with && upper.contains("SELECT"));
    if !is_read {
        return false;
    }

    // Split into identifier-like words so that `last_update` is one word and
    // does not count as the keyword `UPDATE`.
    let words: Vec<&str> = upper
        .split(|c: char| !(c.is_alphanumeric() || c == '_'))
        .filter(|word| !word.is_empty())
        .collect();

    // `WITH x AS (DELETE ... RETURNING *) SELECT ...` and
    // `WITH x AS (SELECT ...) INSERT ...` both write data.
    if is_with && words.iter().any(|word| DATA_MODIFYING.contains(word)) {
        return false;
    }

    // Re-joining the words with single spaces makes clause detection
    // independent of the original whitespace; every statement the plain
    // substring test would reject is still rejected.
    let normalized = words.join(" ");
    !LOCKING_CLAUSES
        .iter()
        .any(|clause| normalized.contains(clause))
}
205
/// Check if SQL is a write operation (for cache invalidation).
///
/// The check is textual and case-insensitive. A statement counts as a write
/// when, after trimming leading whitespace, it either
/// - starts with `INSERT`, `UPDATE`, `DELETE`, `TRUNCATE`, `DROP`, `ALTER`,
///   `CREATE`, `MERGE` or `UPSERT`, or
/// - starts with `WITH` and contains the word `INSERT`, `DELETE`, `MERGE` or
///   `UPDATE`. This covers data-modifying CTEs and CTEs that feed a write
///   (`WITH x AS (...) UPDATE ...`), which would otherwise skip invalidation.
///
/// In the `WITH` case an `UPDATE` that is part of a row-locking clause
/// (`FOR UPDATE`, `FOR NO KEY UPDATE`) is not treated as a write. Keywords
/// inside string literals are counted, so the `WITH` scan may report a write
/// for a read-only statement but not the other way round.
pub fn is_write_operation(sql: &str) -> bool {
    const WRITE_PREFIXES: [&str; 9] = [
        "INSERT", "UPDATE", "DELETE", "TRUNCATE", "DROP", "ALTER", "CREATE", "MERGE", "UPSERT",
    ];

    let upper = sql.trim().to_uppercase();

    if WRITE_PREFIXES
        .iter()
        .any(|prefix| upper.starts_with(prefix))
    {
        return true;
    }
    if !upper.starts_with("WITH") {
        return false;
    }

    // Scan identifier-like words so `last_update` is not mistaken for `UPDATE`.
    let words = upper
        .split(|c: char| !(c.is_alphanumeric() || c == '_'))
        .filter(|word| !word.is_empty());

    let mut previous = "";
    for word in words {
        let writes = match word {
            "INSERT" | "DELETE" | "MERGE" => true,
            // `FOR UPDATE` / `FOR NO KEY UPDATE` lock rows but do not modify them.
            "UPDATE" => !matches!(previous, "FOR" | "KEY"),
            _ => false,
        };
        if writes {
            return true;
        }
        previous = word;
    }
    false
}
220
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for a whole-second `Duration`.
    fn secs(n: u64) -> Duration {
        Duration::from_secs(n)
    }

    #[test]
    fn test_parse_skip_hint() {
        let parsed = parse_cache_hints("/* helios:cache=skip */ SELECT * FROM users");
        assert!(parsed.skip);
        assert!(!parsed.semantic_cache);
    }

    #[test]
    fn test_parse_ttl_hint() {
        let parsed = parse_cache_hints("/* helios:cache_ttl=300 */ SELECT * FROM users");
        assert_eq!(parsed.ttl, Some(secs(300)));
    }

    #[test]
    fn test_parse_ttl_with_unit() {
        // (query, expected TTL in seconds)
        let cases = [
            ("/* helios:ttl=5m */ SELECT * FROM users", 300),
            ("/* helios:ttl=1h */ SELECT * FROM users", 3600),
        ];
        for &(sql, expected) in &cases {
            assert_eq!(parse_cache_hints(sql).ttl, Some(secs(expected)), "sql: {}", sql);
        }
    }

    #[test]
    fn test_parse_semantic_hint() {
        let parsed = parse_cache_hints(
            "/* helios:cache=semantic */ SELECT * FROM documents WHERE topic = 'AI'",
        );
        assert!(parsed.semantic_cache);
    }

    #[test]
    fn test_parse_tables_hint() {
        let parsed = parse_cache_hints(
            "/* helios:cache_tables=users,sessions */ SELECT u.* FROM users u JOIN sessions s ON u.id = s.user_id",
        );
        let expected: Vec<String> = ["users", "sessions"]
            .iter()
            .map(|name| name.to_string())
            .collect();
        assert_eq!(parsed.tables, Some(expected));
    }

    #[test]
    fn test_parse_refresh_hint() {
        let parsed = parse_cache_hints("/* helios:cache_refresh */ SELECT * FROM users");
        assert!(parsed.refresh);
    }

    #[test]
    fn test_parse_multiple_hints() {
        let parsed = parse_cache_hints(
            "/* helios:cache_ttl=60 */ /* helios:cache=semantic */ SELECT * FROM docs",
        );
        assert_eq!(parsed.ttl, Some(secs(60)));
        assert!(parsed.semantic_cache);
    }

    #[test]
    fn test_parse_double_dash_hint() {
        let parsed = parse_cache_hints("-- helios:cache=skip\nSELECT * FROM users");
        assert!(parsed.skip);
    }

    #[test]
    fn test_strip_hints() {
        let hinted_queries = [
            "/* helios:cache=skip */ SELECT * FROM users",
            "-- helios:ttl=60\nSELECT * FROM users",
        ];
        for &sql in &hinted_queries {
            assert_eq!(strip_hints(sql), "SELECT * FROM users", "sql: {}", sql);
        }
    }

    #[test]
    fn test_is_cacheable_query() {
        let cacheable = [
            "SELECT * FROM users",
            "  select id from users  ",
            "WITH cte AS (SELECT 1) SELECT * FROM cte",
            "VALUES (1, 2), (3, 4)",
            "TABLE users",
        ];
        for &sql in &cacheable {
            assert!(is_cacheable_query(sql), "expected cacheable: {}", sql);
        }

        let not_cacheable = [
            "INSERT INTO users VALUES (1)",
            "UPDATE users SET name = 'test'",
            "DELETE FROM users",
            "SELECT * FROM users FOR UPDATE",
            "SELECT * FROM users FOR SHARE",
        ];
        for &sql in &not_cacheable {
            assert!(!is_cacheable_query(sql), "expected not cacheable: {}", sql);
        }
    }

    #[test]
    fn test_is_write_operation() {
        let writes = [
            "INSERT INTO users VALUES (1)",
            "UPDATE users SET name = 'test'",
            "DELETE FROM users",
            "TRUNCATE users",
            "DROP TABLE users",
            "ALTER TABLE users ADD COLUMN age INT",
            "CREATE TABLE test (id INT)",
        ];
        for &sql in &writes {
            assert!(is_write_operation(sql), "expected write: {}", sql);
        }

        let reads = ["SELECT * FROM users", "EXPLAIN SELECT * FROM users"];
        for &sql in &reads {
            assert!(!is_write_operation(sql), "expected non-write: {}", sql);
        }
    }

    #[test]
    fn test_parse_duration() {
        // (input, expected result)
        let cases = [
            ("60", Some(secs(60))),
            ("60s", Some(secs(60))),
            ("5m", Some(secs(300))),
            ("1h", Some(secs(3600))),
            ("1d", Some(secs(86400))),
            ("500ms", Some(Duration::from_millis(500))),
            ("", None),
            ("invalid", None),
        ];
        for &(input, expected) in &cases {
            assert_eq!(parse_duration(input), expected, "input: {:?}", input);
        }
    }

    #[test]
    fn test_cache_level_hints() {
        let l1 = parse_cache_hints("/* helios:cache=l1 */ SELECT * FROM users");
        assert_eq!(l1.level, Some(CacheLevelHint::L1Only));

        let l2 = parse_cache_hints("/* helios:cache=l2 */ SELECT * FROM users");
        assert_eq!(l2.level, Some(CacheLevelHint::L2Only));

        // `l3` is an alias for semantic caching rather than a level selection.
        let l3 = parse_cache_hints("/* helios:cache=l3 */ SELECT * FROM users");
        assert!(l3.semantic_cache);
    }
}