Skip to main content

heliosdb_proxy/cache/
hints.rs

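//! Cache Hints Parser
//!
//! Parses SQL comments to extract cache control hints. Hints may be written as
//! block comments (`/* helios:key=value */`) or as line comments
//! (`-- helios:key=value`); the key after `helios:` is case-insensitive.
//!
//! # Supported Hints
//!
//! ```sql
//! /* helios:cache=skip */          -- Skip caching entirely
//! /* helios:cache_ttl=60 */        -- Override TTL (seconds, or with a unit: 30s, 5m, 1h, 1d, 500ms)
//! /* helios:cache=semantic */      -- Enable semantic caching
//! /* helios:cache=l1 */            -- Restrict to one cache level (l1, l2, or all)
//! /* helios:cache_tables=a,b */    -- Override table dependencies
//! /* helios:cache_refresh */       -- Force cache refresh
//! ```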
14
15use std::time::Duration;
16use regex::Regex;
17use once_cell::sync::Lazy;
18
19/// Parsed cache hints from a SQL query
20#[derive(Debug, Clone, Default)]
21pub struct CacheHint {
22    /// Skip caching entirely
23    pub skip: bool,
24
25    /// Override TTL (None = use default)
26    pub ttl: Option<Duration>,
27
28    /// Enable semantic/L3 caching
29    pub semantic_cache: bool,
30
31    /// Override table dependencies
32    pub tables: Option<Vec<String>>,
33
34    /// Force cache refresh (bypass read, update cache)
35    pub refresh: bool,
36
37    /// Specific cache level to use
38    pub level: Option<CacheLevelHint>,
39}
40
41/// Hint for specific cache level
42#[derive(Debug, Clone, Copy, PartialEq, Eq)]
43pub enum CacheLevelHint {
44    /// Only use L1 (connection-local)
45    L1Only,
46    /// Only use L2 (shared)
47    L2Only,
48    /// Only use L3 (semantic)
49    L3Only,
50    /// Use all levels
51    All,
52}
53
54// Regex patterns for hint parsing
55static HINT_PATTERN: Lazy<Regex> = Lazy::new(|| {
56    Regex::new(r"/\*\s*helios:(\w+)(?:=([^*]+))?\s*\*/").unwrap()
57});
58
59static HINT_PATTERN_DOUBLE_DASH: Lazy<Regex> = Lazy::new(|| {
60    Regex::new(r"--\s*helios:(\w+)(?:=(\S+))?").unwrap()
61});
62
63/// Parse cache hints from a SQL query
64pub fn parse_cache_hints(sql: &str) -> CacheHint {
65    let mut hint = CacheHint::default();
66
67    // Parse /* helios:key=value */ style hints
68    for cap in HINT_PATTERN.captures_iter(sql) {
69        let key = cap.get(1).map(|m| m.as_str()).unwrap_or("");
70        let value = cap.get(2).map(|m| m.as_str().trim());
71
72        apply_hint(&mut hint, key, value);
73    }
74
75    // Parse -- helios:key=value style hints
76    for cap in HINT_PATTERN_DOUBLE_DASH.captures_iter(sql) {
77        let key = cap.get(1).map(|m| m.as_str()).unwrap_or("");
78        let value = cap.get(2).map(|m| m.as_str().trim());
79
80        apply_hint(&mut hint, key, value);
81    }
82
83    hint
84}
85
86/// Apply a single hint to the CacheHint struct
87fn apply_hint(hint: &mut CacheHint, key: &str, value: Option<&str>) {
88    match key.to_lowercase().as_str() {
89        "cache" => {
90            if let Some(v) = value {
91                match v.to_lowercase().as_str() {
92                    "skip" | "no" | "off" | "false" | "disable" => {
93                        hint.skip = true;
94                    }
95                    "semantic" | "l3" | "vector" => {
96                        hint.semantic_cache = true;
97                    }
98                    "l1" | "hot" | "local" => {
99                        hint.level = Some(CacheLevelHint::L1Only);
100                    }
101                    "l2" | "warm" | "shared" => {
102                        hint.level = Some(CacheLevelHint::L2Only);
103                    }
104                    "all" | "yes" | "on" | "true" | "enable" => {
105                        hint.level = Some(CacheLevelHint::All);
106                    }
107                    _ => {}
108                }
109            }
110        }
111        "cache_ttl" | "ttl" => {
112            if let Some(v) = value {
113                if let Ok(secs) = v.parse::<u64>() {
114                    hint.ttl = Some(Duration::from_secs(secs));
115                } else if let Some(duration) = parse_duration(v) {
116                    hint.ttl = Some(duration);
117                }
118            }
119        }
120        "cache_tables" | "tables" => {
121            if let Some(v) = value {
122                let tables: Vec<String> = v
123                    .split(',')
124                    .map(|s| s.trim().to_string())
125                    .filter(|s| !s.is_empty())
126                    .collect();
127                if !tables.is_empty() {
128                    hint.tables = Some(tables);
129                }
130            }
131        }
132        "cache_refresh" | "refresh" | "nocache_read" => {
133            hint.refresh = true;
134        }
135        "semantic" | "semantic_cache" => {
136            hint.semantic_cache = true;
137        }
138        _ => {}
139    }
140}
141
142/// Parse duration strings like "5m", "1h", "30s"
143fn parse_duration(s: &str) -> Option<Duration> {
144    let s = s.trim().to_lowercase();
145
146    if s.is_empty() {
147        return None;
148    }
149
150    // Try to find the numeric part and unit
151    let mut num_end = 0;
152    for (i, c) in s.char_indices() {
153        if c.is_ascii_digit() || c == '.' {
154            num_end = i + c.len_utf8();
155        } else {
156            break;
157        }
158    }
159
160    if num_end == 0 {
161        return None;
162    }
163
164    let num: f64 = s[..num_end].parse().ok()?;
165    let unit = &s[num_end..];
166
167    let multiplier = match unit {
168        "" | "s" | "sec" | "secs" | "second" | "seconds" => 1.0,
169        "m" | "min" | "mins" | "minute" | "minutes" => 60.0,
170        "h" | "hr" | "hrs" | "hour" | "hours" => 3600.0,
171        "d" | "day" | "days" => 86400.0,
172        "ms" | "millis" | "milliseconds" => 0.001,
173        _ => return None,
174    };
175
176    Some(Duration::from_secs_f64(num * multiplier))
177}
178
179/// Strip cache hints from SQL query
180pub fn strip_hints(sql: &str) -> String {
181    let result = HINT_PATTERN.replace_all(sql, "");
182    let result = HINT_PATTERN_DOUBLE_DASH.replace_all(&result, "");
183    result.trim().to_string()
184}
185
186/// Check if a query is cacheable (SELECT, VALUES, etc.)
187pub fn is_cacheable_query(sql: &str) -> bool {
188    let trimmed = sql.trim().to_uppercase();
189
190    // Only cache read operations
191    if trimmed.starts_with("SELECT")
192        || trimmed.starts_with("VALUES")
193        || trimmed.starts_with("TABLE")
194        || trimmed.starts_with("WITH") && trimmed.contains("SELECT")
195    {
196        // Exclude queries with side effects
197        !trimmed.contains("FOR UPDATE")
198            && !trimmed.contains("FOR SHARE")
199            && !trimmed.contains("FOR NO KEY UPDATE")
200            && !trimmed.contains("FOR KEY SHARE")
201            && !trimmed.contains("NOWAIT")
202            && !trimmed.contains("SKIP LOCKED")
203    } else {
204        false
205    }
206}
207
208/// Check if SQL is a write operation (for cache invalidation)
209pub fn is_write_operation(sql: &str) -> bool {
210    let trimmed = sql.trim().to_uppercase();
211
212    trimmed.starts_with("INSERT")
213        || trimmed.starts_with("UPDATE")
214        || trimmed.starts_with("DELETE")
215        || trimmed.starts_with("TRUNCATE")
216        || trimmed.starts_with("DROP")
217        || trimmed.starts_with("ALTER")
218        || trimmed.starts_with("CREATE")
219        || trimmed.starts_with("MERGE")
220        || trimmed.starts_with("UPSERT")
221}
222
223#[cfg(test)]
224mod tests {
225    use super::*;
226
227    #[test]
228    fn test_parse_skip_hint() {
229        let sql = "/* helios:cache=skip */ SELECT * FROM users";
230        let hint = parse_cache_hints(sql);
231        assert!(hint.skip);
232        assert!(!hint.semantic_cache);
233    }
234
235    #[test]
236    fn test_parse_ttl_hint() {
237        let sql = "/* helios:cache_ttl=300 */ SELECT * FROM users";
238        let hint = parse_cache_hints(sql);
239        assert_eq!(hint.ttl, Some(Duration::from_secs(300)));
240    }
241
242    #[test]
243    fn test_parse_ttl_with_unit() {
244        let sql = "/* helios:ttl=5m */ SELECT * FROM users";
245        let hint = parse_cache_hints(sql);
246        assert_eq!(hint.ttl, Some(Duration::from_secs(300)));
247
248        let sql2 = "/* helios:ttl=1h */ SELECT * FROM users";
249        let hint2 = parse_cache_hints(sql2);
250        assert_eq!(hint2.ttl, Some(Duration::from_secs(3600)));
251    }
252
253    #[test]
254    fn test_parse_semantic_hint() {
255        let sql = "/* helios:cache=semantic */ SELECT * FROM documents WHERE topic = 'AI'";
256        let hint = parse_cache_hints(sql);
257        assert!(hint.semantic_cache);
258    }
259
260    #[test]
261    fn test_parse_tables_hint() {
262        let sql = "/* helios:cache_tables=users,sessions */ SELECT u.* FROM users u JOIN sessions s ON u.id = s.user_id";
263        let hint = parse_cache_hints(sql);
264        assert_eq!(hint.tables, Some(vec!["users".to_string(), "sessions".to_string()]));
265    }
266
267    #[test]
268    fn test_parse_refresh_hint() {
269        let sql = "/* helios:cache_refresh */ SELECT * FROM users";
270        let hint = parse_cache_hints(sql);
271        assert!(hint.refresh);
272    }
273
274    #[test]
275    fn test_parse_multiple_hints() {
276        let sql = "/* helios:cache_ttl=60 */ /* helios:cache=semantic */ SELECT * FROM docs";
277        let hint = parse_cache_hints(sql);
278        assert_eq!(hint.ttl, Some(Duration::from_secs(60)));
279        assert!(hint.semantic_cache);
280    }
281
282    #[test]
283    fn test_parse_double_dash_hint() {
284        let sql = "-- helios:cache=skip\nSELECT * FROM users";
285        let hint = parse_cache_hints(sql);
286        assert!(hint.skip);
287    }
288
289    #[test]
290    fn test_strip_hints() {
291        let sql = "/* helios:cache=skip */ SELECT * FROM users";
292        let stripped = strip_hints(sql);
293        assert_eq!(stripped, "SELECT * FROM users");
294
295        let sql2 = "-- helios:ttl=60\nSELECT * FROM users";
296        let stripped2 = strip_hints(sql2);
297        assert_eq!(stripped2, "SELECT * FROM users");
298    }
299
300    #[test]
301    fn test_is_cacheable_query() {
302        assert!(is_cacheable_query("SELECT * FROM users"));
303        assert!(is_cacheable_query("  select id from users  "));
304        assert!(is_cacheable_query("WITH cte AS (SELECT 1) SELECT * FROM cte"));
305        assert!(is_cacheable_query("VALUES (1, 2), (3, 4)"));
306        assert!(is_cacheable_query("TABLE users"));
307
308        // Not cacheable
309        assert!(!is_cacheable_query("INSERT INTO users VALUES (1)"));
310        assert!(!is_cacheable_query("UPDATE users SET name = 'test'"));
311        assert!(!is_cacheable_query("DELETE FROM users"));
312        assert!(!is_cacheable_query("SELECT * FROM users FOR UPDATE"));
313        assert!(!is_cacheable_query("SELECT * FROM users FOR SHARE"));
314    }
315
316    #[test]
317    fn test_is_write_operation() {
318        assert!(is_write_operation("INSERT INTO users VALUES (1)"));
319        assert!(is_write_operation("UPDATE users SET name = 'test'"));
320        assert!(is_write_operation("DELETE FROM users"));
321        assert!(is_write_operation("TRUNCATE users"));
322        assert!(is_write_operation("DROP TABLE users"));
323        assert!(is_write_operation("ALTER TABLE users ADD COLUMN age INT"));
324        assert!(is_write_operation("CREATE TABLE test (id INT)"));
325
326        // Not write operations
327        assert!(!is_write_operation("SELECT * FROM users"));
328        assert!(!is_write_operation("EXPLAIN SELECT * FROM users"));
329    }
330
331    #[test]
332    fn test_parse_duration() {
333        assert_eq!(parse_duration("60"), Some(Duration::from_secs(60)));
334        assert_eq!(parse_duration("60s"), Some(Duration::from_secs(60)));
335        assert_eq!(parse_duration("5m"), Some(Duration::from_secs(300)));
336        assert_eq!(parse_duration("1h"), Some(Duration::from_secs(3600)));
337        assert_eq!(parse_duration("1d"), Some(Duration::from_secs(86400)));
338        assert_eq!(parse_duration("500ms"), Some(Duration::from_millis(500)));
339        assert_eq!(parse_duration(""), None);
340        assert_eq!(parse_duration("invalid"), None);
341    }
342
343    #[test]
344    fn test_cache_level_hints() {
345        let sql = "/* helios:cache=l1 */ SELECT * FROM users";
346        let hint = parse_cache_hints(sql);
347        assert_eq!(hint.level, Some(CacheLevelHint::L1Only));
348
349        let sql2 = "/* helios:cache=l2 */ SELECT * FROM users";
350        let hint2 = parse_cache_hints(sql2);
351        assert_eq!(hint2.level, Some(CacheLevelHint::L2Only));
352
353        let sql3 = "/* helios:cache=l3 */ SELECT * FROM users";
354        let hint3 = parse_cache_hints(sql3);
355        assert!(hint3.semantic_cache);
356    }
357}