Skip to main content

reddb_server/storage/query/planner/
cache_key.rs

1//! Plan cache key normalisation — Fase 4 P1 building block.
2//!
3//! Normalises a raw SQL query string into a canonical cache key
4//! by replacing literal tokens (integers, floats, strings,
5//! booleans, null) with a single `?` placeholder. Two queries
6//! that differ only in their literal values collapse to the
7//! same key.
8//!
9//! ## Why here
10//!
11//! The full parameter-binding story — `Expr::Parameter(n)` in
12//! the AST, a bind phase that substitutes concrete values
13//! before execution, cache-hit reuse of the parsed expression
14//! — requires invasive changes to every path that holds a
15//! `QueryExpr`. That's Fase 4 W3+ scope.
16//!
17//! This module is the smallest immediately-shippable piece:
18//! the normalised cache key. Today's `impl_core::execute_query`
19//! keys the plan cache by raw SQL text, so `WHERE id = 1` and
20//! `WHERE id = 2` produce different entries. Normalising the
21//! key first means both queries hit a shared entry.
22//!
23//! BUT the cached entry still contains the *old* literal
24//! values baked into its `QueryExpr`, so cache hits must
25//! re-parse the new query and discard the cached plan's
26//! AST if the literals matter for execution. The follow-up
27//! commit does exactly that — `execute_query` will compare the
28//! normalised form on lookup and re-parse when the cached
29//! plan's literals don't match the fresh query.
30//!
31//! Until that follow-up, this module is the fast-path
32//! building block: cheap tokenisation + literal stripping,
33//! producing a stable `String` the cache can use.
34//!
35//! ## Algorithm
36//!
37//! Single-pass tokenizer-lite that walks the query character
38//! by character and emits a canonical form:
39//!
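//! - Integers / floats: emit `?` (except the number directly
//!   after `LIMIT` / `OFFSET`, which is kept verbatim so it
//!   stays part of the key)
//! - Single-quoted strings: emit `?`. Double-quoted identifiers
//!   are case-sensitive and are copied verbatim.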
42//! - `TRUE` / `FALSE` / `NULL` keywords (case-insensitive,
43//!   word-bounded): emit `?`
44//! - Everything else: copy verbatim.
45//! - Whitespace runs collapse to a single space so `SELECT  a`
46//!   and `SELECT a` produce the same key.
47//! - Keywords are uppercased so `select` and `SELECT` match.
48//!
49//! The output is a best-effort canonical form. It's not a
50//! formal parse — we only care about stable equivalence
51//! classes, not strict correctness.
52
53use crate::storage::query::lexer::{Lexer, Token};
54use crate::storage::schema::Value;
55
/// Build the canonical plan-cache key for a raw SQL string.
///
/// The key is insensitive to whitespace runs, to the ASCII case of
/// unquoted words, and to the values of literals: single-quoted
/// strings, digit-led numbers and the words `TRUE` / `FALSE` /
/// `NULL` each become a single `?`. Two exceptions keep text
/// verbatim: double-quoted identifiers, and the numeric literal
/// that follows `LIMIT` or `OFFSET`.
///
/// The input is scanned once, left to right, and the only
/// allocation is the returned string.
pub fn normalize_cache_key(sql: &str) -> String {
    // Index just past the single-quoted literal that opens at `open`,
    // or the end of input when the literal is never closed.
    fn skip_single_quoted(src: &[u8], open: usize) -> usize {
        let mut at = open + 1;
        while at < src.len() {
            if src[at] != b'\'' {
                at += 1;
            } else if src.get(at + 1) == Some(&b'\'') {
                // `''` is an escaped quote inside the literal.
                at += 2;
            } else {
                return at + 1;
            }
        }
        at
    }

    // Index just past the double-quoted identifier that opens at
    // `open`, or the end of input when the closing quote is missing.
    fn skip_double_quoted(src: &[u8], open: usize) -> usize {
        match src[open + 1..].iter().position(|&c| c == b'"') {
            Some(offset) => open + offset + 2,
            None => src.len(),
        }
    }

    // End of the digit-led numeric run starting at `start`. A sign
    // belongs to the run only directly after `e` / `E` (exponent
    // sign); anywhere else it is a binary operator and ends the run.
    fn skip_number(src: &[u8], start: usize) -> usize {
        let mut at = start;
        while at < src.len() {
            let in_run = match src[at] {
                b'0'..=b'9' | b'.' | b'e' | b'E' => true,
                b'+' | b'-' => at > start && matches!(src[at - 1], b'e' | b'E'),
                _ => false,
            };
            if !in_run {
                break;
            }
            at += 1;
        }
        at
    }

    // End of the identifier / keyword run starting at `start`.
    fn skip_word(src: &[u8], start: usize) -> usize {
        src[start..]
            .iter()
            .position(|c| !(c.is_ascii_alphanumeric() || *c == b'_'))
            .map_or(src.len(), |offset| start + offset)
    }

    let src = sql.as_bytes();
    let len = src.len();
    let mut key = String::with_capacity(len);
    let mut pos = 0;
    // Starts out true so leading whitespace never emits a separator.
    let mut prev_space = true;
    // Armed by LIMIT / OFFSET: the next numeric literal is copied
    // verbatim instead of being replaced by `?`.
    let mut keep_number = false;

    while pos < len {
        let byte = src[pos];

        // Any run of whitespace becomes a single space.
        if byte.is_ascii_whitespace() {
            if !prev_space {
                key.push(' ');
            }
            prev_space = true;
            pos += 1;
            continue;
        }
        prev_space = false;

        match byte {
            // String literal: its content never reaches the key.
            b'\'' => {
                pos = skip_single_quoted(src, pos);
                key.push('?');
            }
            // Quoted identifier: case-sensitive, so copied untouched.
            b'"' => {
                let end = skip_double_quoted(src, pos);
                key.push_str(&sql[pos..end]);
                pos = end;
            }
            // Digit-led numeric literal. A leading sign is never part
            // of it because it could equally be a binary operator.
            b'0'..=b'9' => {
                let end = skip_number(src, pos);
                if keep_number {
                    key.push_str(&sql[pos..end]);
                    keep_number = false;
                } else {
                    key.push('?');
                }
                pos = end;
            }
            // Unquoted word: literal keyword, SQL keyword or identifier.
            b'A'..=b'Z' | b'a'..=b'z' | b'_' => {
                let end = skip_word(src, pos);
                let word = &sql[pos..end];
                let is_literal_keyword = ["true", "false", "null"]
                    .iter()
                    .any(|kw| word.eq_ignore_ascii_case(kw));
                if is_literal_keyword {
                    key.push('?');
                    keep_number = false;
                } else {
                    // Every other word is uppercased, identifiers
                    // included; they still stay distinct in the key.
                    key.extend(word.chars().map(|c| c.to_ascii_uppercase()));
                    keep_number = word.eq_ignore_ascii_case("limit")
                        || word.eq_ignore_ascii_case("offset");
                }
                pos = end;
            }
            // Punctuation, operators, parentheses and any other byte
            // are copied one byte at a time.
            other => {
                key.push(other as char);
                keep_number = false;
                pos += 1;
            }
        }
    }

    // At most one trailing space can be present; drop it so that
    // `SELECT 1 ` and `SELECT 1` share a key.
    if key.ends_with(' ') {
        key.pop();
    }

    key
}
202
203/// Returns true when two raw SQL strings would hit the same
204/// plan cache slot. Used by diagnostic tools to verify the
205/// normalisation is doing its job.
206pub fn same_cache_key(a: &str, b: &str) -> bool {
207    normalize_cache_key(a) == normalize_cache_key(b)
208}
209
210/// Fused single-pass of `normalize_cache_key` + `extract_literal_bindings`.
211///
212/// The normalize pass already identifies every literal token
213/// (byte-scan state machine — single quotes, numeric runs,
214/// TRUE/FALSE/NULL keywords). Extracting the bound `Value`
215/// alongside is strictly cheaper than running a separate `Lexer`
216/// pass, which is what `extract_literal_bindings` does today.
217///
218/// On the plan-cache HIT path (every UPDATE / repeat SELECT in a
219/// hot loop) this saves one full lex of the query text per hit.
220pub fn normalize_and_extract(sql: &str) -> (String, Vec<Value>) {
221    let mut out = String::with_capacity(sql.len());
222    let mut binds: Vec<Value> = Vec::new();
223    let bytes = sql.as_bytes();
224    let mut i = 0;
225    let mut last_was_space = true;
226    let mut preserve_numeric_literal = false;
227    while i < bytes.len() {
228        let b = bytes[i];
229
230        if b.is_ascii_whitespace() {
231            if !last_was_space {
232                out.push(' ');
233                last_was_space = true;
234            }
235            i += 1;
236            continue;
237        }
238
239        if b == b'\'' {
240            // Walk the string body, honouring the SQL '' escape.
241            // Handle the escape-free fast path as a single span copy.
242            i += 1;
243            let body_start = i;
244            let mut literal: Option<String> = None;
245            while i < bytes.len() {
246                if bytes[i] == b'\'' {
247                    if i + 1 < bytes.len() && bytes[i + 1] == b'\'' {
248                        // Escaped single quote — switch to the
249                        // owned-accumulator mode if we haven't
250                        // already, copying what we've seen so far.
251                        let acc = literal.get_or_insert_with(|| sql[body_start..i].to_string());
252                        acc.push('\'');
253                        i += 2;
254                        continue;
255                    }
256                    break;
257                }
258                if let Some(ref mut acc) = literal {
259                    acc.push(bytes[i] as char);
260                }
261                i += 1;
262            }
263            let value = match literal {
264                Some(s) => s,
265                None => sql[body_start..i].to_string(),
266            };
267            if i < bytes.len() && bytes[i] == b'\'' {
268                i += 1;
269            }
270            binds.push(Value::text(value));
271            out.push('?');
272            last_was_space = false;
273            continue;
274        }
275
276        if b == b'"' {
277            let start = i;
278            i += 1;
279            while i < bytes.len() && bytes[i] != b'"' {
280                i += 1;
281            }
282            if i < bytes.len() {
283                i += 1;
284            }
285            out.push_str(&sql[start..i]);
286            last_was_space = false;
287            continue;
288        }
289
290        if b.is_ascii_digit() {
291            let start = i;
292            while i < bytes.len()
293                && (bytes[i].is_ascii_digit()
294                    || bytes[i] == b'.'
295                    || bytes[i] == b'e'
296                    || bytes[i] == b'E'
297                    || bytes[i] == b'+'
298                    || bytes[i] == b'-')
299            {
300                if bytes[i] == b'+' || bytes[i] == b'-' {
301                    let prev = if i > 0 { bytes[i - 1] } else { 0 };
302                    if prev != b'e' && prev != b'E' {
303                        break;
304                    }
305                }
306                i += 1;
307            }
308            let lit = &sql[start..i];
309            if preserve_numeric_literal {
310                out.push_str(lit);
311                preserve_numeric_literal = false;
312            } else {
313                out.push('?');
314                if lit.contains('.') || lit.contains('e') || lit.contains('E') {
315                    if let Ok(v) = lit.parse::<f64>() {
316                        binds.push(Value::Float(v));
317                    }
318                } else if let Ok(v) = lit.parse::<i64>() {
319                    binds.push(Value::Integer(v));
320                } else if let Ok(v) = lit.parse::<u64>() {
321                    binds.push(Value::UnsignedInteger(v));
322                }
323            }
324            last_was_space = false;
325            continue;
326        }
327
328        if b.is_ascii_alphabetic() || b == b'_' {
329            let start = i;
330            while i < bytes.len() && (bytes[i].is_ascii_alphanumeric() || bytes[i] == b'_') {
331                i += 1;
332            }
333            let word = &sql[start..i];
334            if word.eq_ignore_ascii_case("true") {
335                out.push('?');
336                binds.push(Value::Boolean(true));
337                preserve_numeric_literal = false;
338            } else if word.eq_ignore_ascii_case("false") {
339                out.push('?');
340                binds.push(Value::Boolean(false));
341                preserve_numeric_literal = false;
342            } else if word.eq_ignore_ascii_case("null") {
343                out.push('?');
344                binds.push(Value::Null);
345                preserve_numeric_literal = false;
346            } else {
347                for c in word.chars() {
348                    out.push(c.to_ascii_uppercase());
349                }
350                preserve_numeric_literal =
351                    word.eq_ignore_ascii_case("limit") || word.eq_ignore_ascii_case("offset");
352            }
353            last_was_space = false;
354            continue;
355        }
356
357        out.push(b as char);
358        preserve_numeric_literal = false;
359        last_was_space = false;
360        i += 1;
361    }
362
363    if out.ends_with(' ') {
364        out.pop();
365    }
366
367    (out, binds)
368}
369
370pub fn extract_literal_bindings(sql: &str) -> Result<Vec<Value>, String> {
371    let mut lexer = Lexer::new(sql);
372    let mut binds = Vec::new();
373    let mut skip_next_numeric = false;
374
375    loop {
376        let spanned = lexer.next_token().map_err(|err| err.to_string())?;
377        match spanned.token {
378            Token::Eof => break,
379            Token::Limit | Token::Offset => {
380                skip_next_numeric = true;
381            }
382            Token::Integer(n) => {
383                if !skip_next_numeric {
384                    binds.push(Value::Integer(n));
385                }
386                skip_next_numeric = false;
387            }
388            Token::Float(n) => {
389                if !skip_next_numeric {
390                    binds.push(Value::Float(n));
391                }
392                skip_next_numeric = false;
393            }
394            Token::String(s) => {
395                binds.push(Value::text(s));
396                skip_next_numeric = false;
397            }
398            Token::True => {
399                binds.push(Value::Boolean(true));
400                skip_next_numeric = false;
401            }
402            Token::False => {
403                binds.push(Value::Boolean(false));
404                skip_next_numeric = false;
405            }
406            Token::Null => {
407                binds.push(Value::Null);
408                skip_next_numeric = false;
409            }
410            _ => {
411                skip_next_numeric = false;
412            }
413        }
414    }
415
416    Ok(binds)
417}
418
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that both queries normalise to one cache key.
    fn assert_same_key(left: &str, right: &str) {
        assert_eq!(
            normalize_cache_key(left),
            normalize_cache_key(right),
            "expected a shared key for {left:?} and {right:?}"
        );
    }

    /// Asserts that the two queries normalise to different cache keys.
    fn assert_distinct_keys(left: &str, right: &str) {
        assert_ne!(
            normalize_cache_key(left),
            normalize_cache_key(right),
            "expected distinct keys for {left:?} and {right:?}"
        );
    }

    #[test]
    fn integer_literals_collapse() {
        assert_same_key(
            "SELECT * FROM t WHERE id = 1",
            "SELECT * FROM t WHERE id = 2",
        );
    }

    #[test]
    fn string_literals_collapse() {
        assert_same_key(
            "SELECT * FROM t WHERE name = 'alice'",
            "SELECT * FROM t WHERE name = 'bob'",
        );
    }

    #[test]
    fn case_insensitive_keywords() {
        assert_same_key("select * from t", "SELECT * FROM t");
    }

    #[test]
    fn whitespace_collapses() {
        assert_same_key("SELECT   *  FROM  t", "SELECT * FROM t");
    }

    #[test]
    fn different_shape_different_key() {
        assert_distinct_keys(
            "SELECT * FROM a WHERE x = 1",
            "SELECT * FROM b WHERE x = 1",
        );
    }

    #[test]
    fn float_and_scientific_collapse() {
        assert_same_key("SELECT 1.5e10", "SELECT 3.14");
    }

    #[test]
    fn null_and_boolean_are_literals() {
        assert_same_key("WHERE x IS NULL", "WHERE x IS TRUE");
    }

    #[test]
    fn quoted_identifiers_preserved() {
        // Double-quoted identifiers are copied verbatim, so two
        // different quoted names must not share a key.
        assert_distinct_keys(r#"SELECT "col" FROM t"#, r#"SELECT "other" FROM t"#);
    }

    #[test]
    fn limit_and_offset_literals_remain_in_shape() {
        // The number after LIMIT / OFFSET stays in the key, so the
        // pairs differ even though their WHERE literals collapse.
        for clause in ["LIMIT", "OFFSET"] {
            assert_distinct_keys(
                &format!("SELECT * FROM t WHERE id = 1 {clause} 10"),
                &format!("SELECT * FROM t WHERE id = 2 {clause} 20"),
            );
        }
    }

    #[test]
    fn normalize_and_extract_agrees_with_separate_paths() {
        let queries = [
            "SELECT * FROM users WHERE id = 42",
            "UPDATE users SET score = 99.5 WHERE city = 'NYC' AND age > 30",
            "DELETE FROM t WHERE name = 'al''ice' AND active = TRUE",
            "SELECT 1, 'x', 2.5, NULL, FALSE FROM t",
            "SELECT * FROM t LIMIT 10 OFFSET 5",
        ];
        for query in queries {
            let (fused_key, fused_binds) = normalize_and_extract(query);
            let separate_binds = extract_literal_bindings(query).unwrap();

            assert_eq!(
                fused_key,
                normalize_cache_key(query),
                "cache_key mismatch for: {query}"
            );
            assert_eq!(
                fused_binds.len(),
                separate_binds.len(),
                "bind count mismatch for {query}: fused={:?} sep={:?}",
                fused_binds,
                separate_binds
            );
            // Binds are compared through their Debug rendering
            // rather than with `==`.
            let pairs = fused_binds.iter().zip(separate_binds.iter());
            for (fused, separate) in pairs {
                assert_eq!(
                    format!("{fused:?}"),
                    format!("{separate:?}"),
                    "bind mismatch for {query}"
                );
            }
        }
    }

    #[test]
    fn extract_literal_bindings_skips_limit_and_offset() {
        let query = "SELECT * FROM t WHERE age = 18 AND active = true LIMIT 10";
        let expected = vec![Value::Integer(18), Value::Boolean(true)];
        let binds = extract_literal_bindings(query).unwrap();
        assert_eq!(binds, expected);
    }
}