reddb_server/storage/query/planner/
cache_key.rs1use crate::storage::query::lexer::{Lexer, Token};
54use crate::storage::schema::Value;
55
/// Builds a literal-insensitive "shape" key for a SQL statement, suitable for
/// looking up cached query plans.
///
/// The input is scanned once, byte by byte, and rewritten as follows:
///
/// * runs of ASCII whitespace become a single space; whitespace at the start is
///   dropped and one trailing space is removed;
/// * single-quoted strings (with `''` as an escaped quote), numeric literals and
///   the words `TRUE`, `FALSE` and `NULL` (any case) are replaced by `?`;
/// * a numeric literal that follows the word `LIMIT` or `OFFSET` (whitespace in
///   between is allowed) is kept verbatim, so different page sizes produce
///   different keys;
/// * double-quoted identifiers are copied unchanged, quotes included;
/// * other words (`[A-Za-z_][A-Za-z0-9_]*`) are upper-cased;
/// * any other byte is copied through as a single `char`.
///
/// Queries that differ only in (non-`LIMIT`/`OFFSET`) literal values therefore
/// share one key.
pub fn normalize_cache_key(sql: &str) -> String {
    let src = sql.as_bytes();
    let mut key = String::with_capacity(src.len());
    let mut pos = 0;
    // `true` while nothing has been emitted yet or the last emitted byte was the
    // collapsed whitespace separator.
    let mut after_space = true;
    // Armed by `LIMIT`/`OFFSET`; any token other than whitespace, a string
    // literal or a quoted identifier disarms it.
    let mut keep_next_number = false;

    while let Some(&byte) = src.get(pos) {
        match byte {
            ws if ws.is_ascii_whitespace() => {
                if !after_space {
                    key.push(' ');
                    after_space = true;
                }
                pos += 1;
                continue;
            }
            b'\'' => {
                pos += 1;
                loop {
                    match src.get(pos) {
                        None => break,
                        Some(&b'\'') if src.get(pos + 1) == Some(&b'\'') => pos += 2,
                        Some(&b'\'') => {
                            pos += 1;
                            break;
                        }
                        Some(_) => pos += 1,
                    }
                }
                key.push('?');
            }
            b'"' => {
                // Copy through the closing quote, or to the end if unterminated.
                let end = src[pos + 1..]
                    .iter()
                    .position(|&c| c == b'"')
                    .map_or(src.len(), |offset| pos + offset + 2);
                key.push_str(&sql[pos..end]);
                pos = end;
            }
            b'0'..=b'9' => {
                let start = pos;
                while let Some(&c) = src.get(pos) {
                    let part_of_number = match c {
                        b'0'..=b'9' | b'.' | b'e' | b'E' => true,
                        // A sign belongs to the literal only as an exponent sign.
                        b'+' | b'-' => matches!(src[pos - 1], b'e' | b'E'),
                        _ => false,
                    };
                    if !part_of_number {
                        break;
                    }
                    pos += 1;
                }
                if std::mem::take(&mut keep_next_number) {
                    key.push_str(&sql[start..pos]);
                } else {
                    key.push('?');
                }
            }
            b'a'..=b'z' | b'A'..=b'Z' | b'_' => {
                let start = pos;
                while matches!(src.get(pos), Some(c) if c.is_ascii_alphanumeric() || *c == b'_') {
                    pos += 1;
                }
                let word = &sql[start..pos];
                let is_literal_keyword = ["true", "false", "null"]
                    .iter()
                    .any(|kw| word.eq_ignore_ascii_case(kw));
                if is_literal_keyword {
                    key.push('?');
                    keep_next_number = false;
                } else {
                    key.extend(word.chars().map(|c| c.to_ascii_uppercase()));
                    keep_next_number = ["limit", "offset"]
                        .iter()
                        .any(|kw| word.eq_ignore_ascii_case(kw));
                }
            }
            other => {
                key.push(other as char);
                keep_next_number = false;
                pos += 1;
            }
        }
        after_space = false;
    }

    if key.ends_with(' ') {
        key.pop();
    }
    key
}
202
203pub fn same_cache_key(a: &str, b: &str) -> bool {
207 normalize_cache_key(a) == normalize_cache_key(b)
208}
209
210pub fn normalize_and_extract(sql: &str) -> (String, Vec<Value>) {
221 let mut out = String::with_capacity(sql.len());
222 let mut binds: Vec<Value> = Vec::new();
223 let bytes = sql.as_bytes();
224 let mut i = 0;
225 let mut last_was_space = true;
226 let mut preserve_numeric_literal = false;
227 while i < bytes.len() {
228 let b = bytes[i];
229
230 if b.is_ascii_whitespace() {
231 if !last_was_space {
232 out.push(' ');
233 last_was_space = true;
234 }
235 i += 1;
236 continue;
237 }
238
239 if b == b'\'' {
240 i += 1;
243 let body_start = i;
244 let mut literal: Option<String> = None;
245 while i < bytes.len() {
246 if bytes[i] == b'\'' {
247 if i + 1 < bytes.len() && bytes[i + 1] == b'\'' {
248 let acc = literal.get_or_insert_with(|| sql[body_start..i].to_string());
252 acc.push('\'');
253 i += 2;
254 continue;
255 }
256 break;
257 }
258 if let Some(ref mut acc) = literal {
259 acc.push(bytes[i] as char);
260 }
261 i += 1;
262 }
263 let value = match literal {
264 Some(s) => s,
265 None => sql[body_start..i].to_string(),
266 };
267 if i < bytes.len() && bytes[i] == b'\'' {
268 i += 1;
269 }
270 binds.push(Value::text(value));
271 out.push('?');
272 last_was_space = false;
273 continue;
274 }
275
276 if b == b'"' {
277 let start = i;
278 i += 1;
279 while i < bytes.len() && bytes[i] != b'"' {
280 i += 1;
281 }
282 if i < bytes.len() {
283 i += 1;
284 }
285 out.push_str(&sql[start..i]);
286 last_was_space = false;
287 continue;
288 }
289
290 if b.is_ascii_digit() {
291 let start = i;
292 while i < bytes.len()
293 && (bytes[i].is_ascii_digit()
294 || bytes[i] == b'.'
295 || bytes[i] == b'e'
296 || bytes[i] == b'E'
297 || bytes[i] == b'+'
298 || bytes[i] == b'-')
299 {
300 if bytes[i] == b'+' || bytes[i] == b'-' {
301 let prev = if i > 0 { bytes[i - 1] } else { 0 };
302 if prev != b'e' && prev != b'E' {
303 break;
304 }
305 }
306 i += 1;
307 }
308 let lit = &sql[start..i];
309 if preserve_numeric_literal {
310 out.push_str(lit);
311 preserve_numeric_literal = false;
312 } else {
313 out.push('?');
314 if lit.contains('.') || lit.contains('e') || lit.contains('E') {
315 if let Ok(v) = lit.parse::<f64>() {
316 binds.push(Value::Float(v));
317 }
318 } else if let Ok(v) = lit.parse::<i64>() {
319 binds.push(Value::Integer(v));
320 } else if let Ok(v) = lit.parse::<u64>() {
321 binds.push(Value::UnsignedInteger(v));
322 }
323 }
324 last_was_space = false;
325 continue;
326 }
327
328 if b.is_ascii_alphabetic() || b == b'_' {
329 let start = i;
330 while i < bytes.len() && (bytes[i].is_ascii_alphanumeric() || bytes[i] == b'_') {
331 i += 1;
332 }
333 let word = &sql[start..i];
334 if word.eq_ignore_ascii_case("true") {
335 out.push('?');
336 binds.push(Value::Boolean(true));
337 preserve_numeric_literal = false;
338 } else if word.eq_ignore_ascii_case("false") {
339 out.push('?');
340 binds.push(Value::Boolean(false));
341 preserve_numeric_literal = false;
342 } else if word.eq_ignore_ascii_case("null") {
343 out.push('?');
344 binds.push(Value::Null);
345 preserve_numeric_literal = false;
346 } else {
347 for c in word.chars() {
348 out.push(c.to_ascii_uppercase());
349 }
350 preserve_numeric_literal =
351 word.eq_ignore_ascii_case("limit") || word.eq_ignore_ascii_case("offset");
352 }
353 last_was_space = false;
354 continue;
355 }
356
357 out.push(b as char);
358 preserve_numeric_literal = false;
359 last_was_space = false;
360 i += 1;
361 }
362
363 if out.ends_with(' ') {
364 out.pop();
365 }
366
367 (out, binds)
368}
369
370pub fn extract_literal_bindings(sql: &str) -> Result<Vec<Value>, String> {
371 let mut lexer = Lexer::new(sql);
372 let mut binds = Vec::new();
373 let mut skip_next_numeric = false;
374
375 loop {
376 let spanned = lexer.next_token().map_err(|err| err.to_string())?;
377 match spanned.token {
378 Token::Eof => break,
379 Token::Limit | Token::Offset => {
380 skip_next_numeric = true;
381 }
382 Token::Integer(n) => {
383 if !skip_next_numeric {
384 binds.push(Value::Integer(n));
385 }
386 skip_next_numeric = false;
387 }
388 Token::Float(n) => {
389 if !skip_next_numeric {
390 binds.push(Value::Float(n));
391 }
392 skip_next_numeric = false;
393 }
394 Token::String(s) => {
395 binds.push(Value::text(s));
396 skip_next_numeric = false;
397 }
398 Token::True => {
399 binds.push(Value::Boolean(true));
400 skip_next_numeric = false;
401 }
402 Token::False => {
403 binds.push(Value::Boolean(false));
404 skip_next_numeric = false;
405 }
406 Token::Null => {
407 binds.push(Value::Null);
408 skip_next_numeric = false;
409 }
410 _ => {
411 skip_next_numeric = false;
412 }
413 }
414 }
415
416 Ok(binds)
417}
418
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn integer_literals_collapse() {
        assert_eq!(
            normalize_cache_key("SELECT * FROM t WHERE id = 1"),
            normalize_cache_key("SELECT * FROM t WHERE id = 2"),
        );
    }

    #[test]
    fn string_literals_collapse() {
        assert_eq!(
            normalize_cache_key("SELECT * FROM t WHERE name = 'alice'"),
            normalize_cache_key("SELECT * FROM t WHERE name = 'bob'"),
        );
    }

    #[test]
    fn case_insensitive_keywords() {
        assert_eq!(
            normalize_cache_key("select * from t"),
            normalize_cache_key("SELECT * FROM t"),
        );
    }

    #[test]
    fn whitespace_collapses() {
        // Leading/trailing whitespace, tabs, newlines and repeated spaces must
        // all normalize to the single-spaced form.
        assert_eq!(
            normalize_cache_key("  SELECT\t*\n\n  FROM   t  "),
            normalize_cache_key("SELECT * FROM t"),
        );
    }

    #[test]
    fn normalized_shape_is_exact() {
        assert_eq!(
            normalize_cache_key("select * from t where id = 1 and name = 'x' limit 10"),
            "SELECT * FROM T WHERE ID = ? AND NAME = ? LIMIT 10",
        );
    }

    #[test]
    fn different_shape_different_key() {
        assert_ne!(
            normalize_cache_key("SELECT * FROM a WHERE x = 1"),
            normalize_cache_key("SELECT * FROM b WHERE x = 1"),
        );
    }

    #[test]
    fn float_and_scientific_collapse() {
        assert_eq!(
            normalize_cache_key("SELECT 1.5e10"),
            normalize_cache_key("SELECT 3.14"),
        );
    }

    #[test]
    fn null_and_boolean_are_literals() {
        assert_eq!(
            normalize_cache_key("WHERE x IS NULL"),
            normalize_cache_key("WHERE x IS TRUE"),
        );
    }

    #[test]
    fn quoted_identifiers_preserved() {
        // Double-quoted identifiers are part of the shape, not literals.
        assert_ne!(
            normalize_cache_key(r#"SELECT "col" FROM t"#),
            normalize_cache_key(r#"SELECT "other" FROM t"#),
        );
    }

    #[test]
    fn limit_and_offset_literals_remain_in_shape() {
        assert_ne!(
            normalize_cache_key("SELECT * FROM t WHERE id = 1 LIMIT 10"),
            normalize_cache_key("SELECT * FROM t WHERE id = 2 LIMIT 20"),
        );
        assert_ne!(
            normalize_cache_key("SELECT * FROM t WHERE id = 1 OFFSET 10"),
            normalize_cache_key("SELECT * FROM t WHERE id = 2 OFFSET 20"),
        );
    }

    #[test]
    fn normalize_and_extract_agrees_with_separate_paths() {
        let queries = [
            "SELECT * FROM users WHERE id = 42",
            "UPDATE users SET score = 99.5 WHERE city = 'NYC' AND age > 30",
            "DELETE FROM t WHERE name = 'al''ice' AND active = TRUE",
            "SELECT 1, 'x', 2.5, NULL, FALSE FROM t",
            "SELECT * FROM t LIMIT 10 OFFSET 5",
        ];
        for q in queries {
            let (fk, fb) = normalize_and_extract(q);
            assert_eq!(fk, normalize_cache_key(q), "cache_key mismatch for: {q}");
            let sep = extract_literal_bindings(q).unwrap();
            assert_eq!(
                fb.len(),
                sep.len(),
                "bind count mismatch for {q}: fused={:?} sep={:?}",
                fb,
                sep
            );
            for (a, b) in fb.iter().zip(sep.iter()) {
                // Compare via Debug so the check covers both variant and payload.
                assert_eq!(format!("{a:?}"), format!("{b:?}"), "bind mismatch for {q}");
            }
        }
    }

    #[test]
    fn normalize_and_extract_unescapes_doubled_quotes() {
        let (key, binds) = normalize_and_extract("SELECT * FROM t WHERE name = 'al''ice'");
        assert_eq!(key, "SELECT * FROM T WHERE NAME = ?");
        assert_eq!(binds, vec![Value::text("al'ice".to_string())]);
    }

    #[test]
    fn extract_literal_bindings_skips_limit_and_offset() {
        let binds =
            extract_literal_bindings("SELECT * FROM t WHERE age = 18 AND active = true LIMIT 10")
                .unwrap();
        assert_eq!(binds, vec![Value::Integer(18), Value::Boolean(true)]);
    }
}