1use std::collections::{HashMap, HashSet};
6use std::sync::Arc;
7
8use super::registry::{AccessPattern, SchemaRegistry, TableSchema, WorkloadType};
9
10#[derive(Debug)]
12pub struct QueryAnalyzer {
13 schema: Arc<SchemaRegistry>,
15}
16
17impl QueryAnalyzer {
18 pub fn new(schema: Arc<SchemaRegistry>) -> Self {
20 Self { schema }
21 }
22
23 pub fn analyze(&self, query: &str) -> QueryAnalysis {
25 let normalized = self.normalize_query(query);
26 let tables = self.extract_tables(&normalized);
27 let access_patterns = self.detect_access_patterns(&normalized, &tables);
28 let shard_keys = self.extract_shard_keys(&normalized, &tables);
29 let workload_type = self.classify_workload(&normalized, &tables);
30
31 QueryAnalysis {
32 original_query: query.to_string(),
33 tables,
34 access_patterns,
35 shard_keys,
36 workload_type,
37 complexity: self.estimate_complexity(&normalized),
38 selectivity: self.estimate_selectivity(&normalized),
39 is_read_only: self.is_read_only(&normalized),
40 has_aggregations: self.has_aggregations(&normalized),
41 has_joins: self.has_joins(&normalized),
42 has_subqueries: self.has_subqueries(&normalized),
43 }
44 }
45
46 fn normalize_query(&self, query: &str) -> String {
48 query
49 .to_uppercase()
50 .replace(['\n', '\t'], " ")
51 .split_whitespace()
52 .collect::<Vec<_>>()
53 .join(" ")
54 }
55
56 pub fn extract_tables(&self, query: &str) -> Vec<TableRef> {
58 let mut tables = Vec::new();
59 let words: Vec<&str> = query.split_whitespace().collect();
60
61 let table_keywords = ["FROM", "JOIN", "INTO", "UPDATE"];
63
64 for (i, word) in words.iter().enumerate() {
65 if table_keywords.contains(word) {
66 if let Some(table_name) = words.get(i + 1) {
67 let name = table_name.trim_matches(|c| c == ',' || c == '(' || c == ')');
68 if !name.is_empty() && !is_keyword(name) {
69 let alias = self.find_alias(&words, i + 1);
70 tables.push(TableRef {
71 name: name.to_lowercase(),
72 alias,
73 schema: self.schema.get_table(&name.to_lowercase()),
74 });
75 }
76 }
77 }
78 }
79
80 tables
81 }
82
83 fn find_alias(&self, words: &[&str], table_idx: usize) -> Option<String> {
85 if let Some(next) = words.get(table_idx + 1) {
86 if next.eq_ignore_ascii_case("AS") {
87 return words.get(table_idx + 2).map(|s| s.to_lowercase());
88 } else if !is_keyword(next) && !next.starts_with('(') {
89 return Some(next.to_lowercase());
90 }
91 }
92 None
93 }
94
95 fn detect_access_patterns(&self, query: &str, tables: &[TableRef]) -> Vec<AccessPattern> {
97 let mut patterns = Vec::new();
98
99 for table in tables {
100 let pattern = self.detect_table_access_pattern(query, table);
101 patterns.push(pattern);
102 }
103
104 patterns
105 }
106
107 fn detect_table_access_pattern(&self, query: &str, table: &TableRef) -> AccessPattern {
109 if self.has_vector_operator(query) {
111 return AccessPattern::VectorSearch;
112 }
113
114 if let Some(schema) = &table.schema {
116 if self.has_equality_on_pk(query, schema) {
117 return AccessPattern::PointLookup;
118 }
119 }
120
121 if self.has_range_predicate(query) {
123 return AccessPattern::RangeScan;
124 }
125
126 if self.is_time_series_append(query) {
128 return AccessPattern::TimeSeriesAppend;
129 }
130
131 if !query.contains("WHERE") {
133 return AccessPattern::FullScan;
134 }
135
136 AccessPattern::Mixed
137 }
138
139 fn has_equality_on_pk(&self, query: &str, schema: &TableSchema) -> bool {
141 if schema.primary_key.is_empty() {
142 return false;
143 }
144
145 for pk_col in &schema.primary_key {
146 let pattern = format!("{} =", pk_col.to_uppercase());
147 if query.contains(&pattern) {
148 return true;
149 }
150 }
151
152 false
153 }
154
155 fn has_range_predicate(&self, query: &str) -> bool {
157 query.contains(" > ")
158 || query.contains(" < ")
159 || query.contains(" >= ")
160 || query.contains(" <= ")
161 || query.contains(" BETWEEN ")
162 }
163
164 fn has_vector_operator(&self, query: &str) -> bool {
166 query.contains("<->")
167 || query.contains("<#>")
168 || query.contains("<=>")
169 || query.contains("VECTOR")
170 || query.contains("EMBEDDING")
171 || query.contains("COSINE_DISTANCE")
172 || query.contains("L2_DISTANCE")
173 }
174
175 fn is_time_series_append(&self, query: &str) -> bool {
177 query.starts_with("INSERT")
178 && (query.contains("TIMESTAMP")
179 || query.contains("CREATED_AT")
180 || query.contains("EVENT_TIME"))
181 }
182
183 fn extract_shard_keys(
185 &self,
186 query: &str,
187 tables: &[TableRef],
188 ) -> HashMap<String, ShardKeyValue> {
189 let mut shard_keys = HashMap::new();
190
191 for table in tables {
192 if let Some(schema) = &table.schema {
193 if let Some(shard_key) = &schema.shard_key {
194 if let Some(value) = self.extract_shard_key_value(query, shard_key) {
195 shard_keys.insert(shard_key.clone(), value);
196 }
197 }
198 }
199 }
200
201 shard_keys
202 }
203
204 fn extract_shard_key_value(&self, query: &str, shard_key: &str) -> Option<ShardKeyValue> {
206 let pattern = format!("{} =", shard_key.to_uppercase());
208 if let Some(idx) = query.find(&pattern) {
209 let rest = &query[idx + pattern.len()..];
210 let value = rest.split_whitespace().next()?;
211 let clean_value = value.trim_matches(|c| c == '\'' || c == '"' || c == ',');
212 return Some(ShardKeyValue::Single(clean_value.to_string()));
213 }
214
215 let in_pattern = format!("{} IN", shard_key.to_uppercase());
217 if let Some(idx) = query.find(&in_pattern) {
218 let rest = &query[idx + in_pattern.len()..];
219 if let Some(start) = rest.find('(') {
220 if let Some(end) = rest.find(')') {
221 let values_str = &rest[start + 1..end];
222 let values: Vec<String> = values_str
223 .split(',')
224 .map(|v| v.trim().trim_matches(|c| c == '\'' || c == '"').to_string())
225 .collect();
226 return Some(ShardKeyValue::Multiple(values));
227 }
228 }
229 }
230
231 None
232 }
233
234 fn classify_workload(&self, query: &str, tables: &[TableRef]) -> WorkloadType {
236 if self.has_vector_operator(query) {
238 return WorkloadType::Vector;
239 }
240
241 if self.has_aggregations(query)
243 || self.has_group_by(query)
244 || self.has_window_functions(query)
245 {
246 return WorkloadType::OLAP;
247 }
248
249 if self.is_simple_crud(query) {
251 return WorkloadType::OLTP;
252 }
253
254 for table in tables {
256 if let Some(schema) = &table.schema {
257 if schema.workload != WorkloadType::Mixed {
258 return schema.workload;
259 }
260 }
261 }
262
263 WorkloadType::Mixed
264 }
265
266 pub fn has_aggregations(&self, query: &str) -> bool {
268 query.contains("COUNT(")
269 || query.contains("SUM(")
270 || query.contains("AVG(")
271 || query.contains("MIN(")
272 || query.contains("MAX(")
273 }
274
275 fn has_group_by(&self, query: &str) -> bool {
277 query.contains("GROUP BY")
278 }
279
280 fn has_window_functions(&self, query: &str) -> bool {
282 query.contains("OVER(")
283 || query.contains("OVER (")
284 || query.contains("ROW_NUMBER")
285 || query.contains("RANK()")
286 || query.contains("DENSE_RANK")
287 || query.contains("LAG(")
288 || query.contains("LEAD(")
289 }
290
291 fn is_simple_crud(&self, query: &str) -> bool {
293 let is_simple_select = query.starts_with("SELECT")
294 && !self.has_joins(query)
295 && !self.has_subqueries(query)
296 && !self.has_aggregations(query);
297
298 let is_simple_insert = query.starts_with("INSERT") && !query.contains("SELECT");
299
300 let is_simple_update = query.starts_with("UPDATE") && query.contains("WHERE");
301
302 let is_simple_delete = query.starts_with("DELETE") && query.contains("WHERE");
303
304 is_simple_select || is_simple_insert || is_simple_update || is_simple_delete
305 }
306
307 pub fn is_read_only(&self, query: &str) -> bool {
309 query.starts_with("SELECT")
310 || query.starts_with("WITH")
311 || query.starts_with("EXPLAIN")
312 || query.starts_with("SHOW")
313 }
314
315 pub fn has_joins(&self, query: &str) -> bool {
317 query.contains(" JOIN ")
318 }
319
320 pub fn has_subqueries(&self, query: &str) -> bool {
322 query.matches("SELECT").count() > 1
324 }
325
326 fn estimate_complexity(&self, query: &str) -> u32 {
328 let mut complexity: u32 = 10; complexity += (query.matches(" JOIN ").count() as u32) * 15;
332
333 let select_count = query.matches("SELECT").count() as u32;
335 if select_count > 1 {
336 complexity += (select_count - 1) * 20;
337 }
338
339 if self.has_aggregations(query) {
341 complexity += 10;
342 }
343
344 if self.has_group_by(query) {
346 complexity += 10;
347 }
348
349 if self.has_window_functions(query) {
351 complexity += 15;
352 }
353
354 if query.contains("ORDER BY") {
356 complexity += 5;
357 }
358
359 if query.contains("DISTINCT") {
361 complexity += 5;
362 }
363
364 complexity.min(100)
365 }
366
367 fn estimate_selectivity(&self, query: &str) -> f64 {
369 if !query.contains("WHERE") {
370 return 1.0; }
372
373 let mut selectivity = 0.5; let eq_count = query.matches(" = ").count();
377 selectivity *= 0.9_f64.powi(eq_count as i32);
378
379 if query.contains("LIMIT") {
381 selectivity *= 0.5;
382 }
383
384 selectivity.max(0.001) }
386
387 pub fn extract_columns(&self, query: &str) -> Vec<String> {
389 let mut columns = HashSet::new();
390 let words: Vec<&str> = query.split_whitespace().collect();
391
392 if let Some(select_idx) = words.iter().position(|w| *w == "SELECT") {
394 if let Some(from_idx) = words.iter().position(|w| *w == "FROM") {
395 for word in &words[select_idx + 1..from_idx] {
396 let col = word.trim_matches(|c| c == ',' || c == '(' || c == ')');
397 if !col.is_empty() && !is_keyword(col) && col != "*" {
398 if let Some(dot_idx) = col.find('.') {
400 columns.insert(col[dot_idx + 1..].to_lowercase());
401 } else {
402 columns.insert(col.to_lowercase());
403 }
404 }
405 }
406 }
407 }
408
409 columns.into_iter().collect()
410 }
411}
412
/// Returns `true` when `word`, compared case-insensitively, is one of the
/// SQL keywords this analyzer treats as reserved.
fn is_keyword(word: &str) -> bool {
    const KEYWORDS: &[&str] = &[
        // Statement and clause structure.
        "SELECT", "FROM", "WHERE", "JOIN", "ON", "AND", "OR", "NOT", "IN", "IS", "NULL", "AS",
        "ORDER", "BY", "GROUP", "HAVING", "LIMIT", "OFFSET",
        // Data modification and definition.
        "INSERT", "INTO", "VALUES", "UPDATE", "SET", "DELETE", "CREATE", "DROP", "ALTER",
        "INDEX", "TABLE",
        // Join modifiers.
        "LEFT", "RIGHT", "INNER", "OUTER", "FULL", "CROSS", "NATURAL", "USING",
        // Set operations and quantifiers.
        "DISTINCT", "ALL", "UNION", "INTERSECT", "EXCEPT",
        // Expressions and predicates.
        "CASE", "WHEN", "THEN", "ELSE", "END", "BETWEEN", "LIKE", "ILIKE", "EXISTS",
        // Common table expressions and ordering.
        "WITH", "RECURSIVE", "ASC", "DESC", "NULLS", "FIRST", "LAST",
    ];

    let upper = word.to_uppercase();
    KEYWORDS.iter().any(|keyword| *keyword == upper)
}
477
478#[derive(Debug, Clone)]
480pub struct TableRef {
481 pub name: String,
483 pub alias: Option<String>,
485 pub schema: Option<TableSchema>,
487}
488
/// Value(s) a shard-key column is constrained to by a query.
#[derive(Debug, Clone)]
pub enum ShardKeyValue {
    /// One value, taken from a `key = value` predicate.
    Single(String),
    /// Several values, taken from a `key IN (...)` predicate.
    Multiple(Vec<String>),
}
497
498#[derive(Debug, Clone)]
500pub struct QueryAnalysis {
501 pub original_query: String,
503 pub tables: Vec<TableRef>,
505 pub access_patterns: Vec<AccessPattern>,
507 pub shard_keys: HashMap<String, ShardKeyValue>,
509 pub workload_type: WorkloadType,
511 pub complexity: u32,
513 pub selectivity: f64,
515 pub is_read_only: bool,
517 pub has_aggregations: bool,
519 pub has_joins: bool,
521 pub has_subqueries: bool,
523}
524
525impl QueryAnalysis {
526 pub fn is_vector_query(&self) -> bool {
528 self.access_patterns.contains(&AccessPattern::VectorSearch)
529 }
530
531 pub fn is_analytics(&self) -> bool {
533 self.workload_type == WorkloadType::OLAP
534 }
535
536 pub fn primary_table(&self) -> Option<&TableRef> {
538 self.tables.first()
539 }
540
541 pub fn has_shard_key(&self) -> bool {
543 !self.shard_keys.is_empty()
544 }
545}
546
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a registry holding the `users`, `events` and `embeddings`
    /// tables used by every test below.
    fn create_test_registry() -> Arc<SchemaRegistry> {
        let registry = SchemaRegistry::new();

        registry.register_table(
            TableSchema::new("users")
                .with_workload(WorkloadType::OLTP)
                .with_access_pattern(AccessPattern::PointLookup)
                .with_primary_key(vec!["id".to_string()])
                .with_shard_key("id"),
        );
        registry.register_table(
            TableSchema::new("events")
                .with_workload(WorkloadType::OLAP)
                .with_access_pattern(AccessPattern::FullScan),
        );
        registry.register_table(
            TableSchema::new("embeddings")
                .with_workload(WorkloadType::Vector)
                .with_access_pattern(AccessPattern::VectorSearch),
        );

        Arc::new(registry)
    }

    /// Analyzer backed by the shared test registry.
    fn test_analyzer() -> QueryAnalyzer {
        QueryAnalyzer::new(create_test_registry())
    }

    #[test]
    fn test_extract_tables() {
        let analyzer = test_analyzer();

        let normalized = analyzer.normalize_query("SELECT * FROM users WHERE id = 1");
        let tables = analyzer.extract_tables(&normalized);

        assert_eq!(tables.len(), 1);
        assert_eq!(tables[0].name, "users");
    }

    #[test]
    fn test_extract_tables_with_join() {
        let analyzer = test_analyzer();

        let normalized = analyzer
            .normalize_query("SELECT u.*, o.* FROM users u JOIN orders o ON u.id = o.user_id");
        let tables = analyzer.extract_tables(&normalized);

        assert_eq!(tables.len(), 2);
        assert_eq!(tables[0].name, "users");
        assert_eq!(tables[0].alias, Some("u".to_string()));
    }

    #[test]
    fn test_classify_oltp() {
        let analysis = test_analyzer().analyze("SELECT * FROM users WHERE id = 1");

        assert_eq!(analysis.workload_type, WorkloadType::OLTP);
        assert!(analysis.is_read_only);
    }

    #[test]
    fn test_classify_olap() {
        let analysis =
            test_analyzer().analyze("SELECT COUNT(*), SUM(amount) FROM events GROUP BY date");

        assert_eq!(analysis.workload_type, WorkloadType::OLAP);
        assert!(analysis.has_aggregations);
    }

    #[test]
    fn test_classify_vector() {
        let analysis = test_analyzer()
            .analyze("SELECT * FROM embeddings ORDER BY embedding <-> '[1,2,3]' LIMIT 10");

        assert_eq!(analysis.workload_type, WorkloadType::Vector);
        assert!(analysis.is_vector_query());
    }

    #[test]
    fn test_extract_shard_key() {
        let analysis = test_analyzer().analyze("SELECT * FROM users WHERE id = 'user_123'");

        assert!(analysis.has_shard_key());
        assert!(analysis.shard_keys.contains_key("id"));
    }

    #[test]
    fn test_complexity_estimation() {
        let analyzer = test_analyzer();

        let simple_analysis = analyzer.analyze("SELECT * FROM users WHERE id = 1");
        let complex_analysis = analyzer.analyze("SELECT u.*, COUNT(o.id) FROM users u JOIN orders o ON u.id = o.user_id GROUP BY u.id ORDER BY COUNT(o.id) DESC");

        assert!(simple_analysis.complexity < complex_analysis.complexity);
    }

    #[test]
    fn test_detect_point_lookup() {
        let analysis = test_analyzer().analyze("SELECT * FROM users WHERE id = 1");

        assert!(analysis
            .access_patterns
            .contains(&AccessPattern::PointLookup));
    }

    #[test]
    fn test_detect_full_scan() {
        let analysis = test_analyzer().analyze("SELECT * FROM events");

        assert!(analysis.access_patterns.contains(&AccessPattern::FullScan));
    }

    #[test]
    fn test_has_joins() {
        let analyzer = test_analyzer();

        let joined = analyzer.analyze("SELECT * FROM users u JOIN orders o ON u.id = o.user_id");
        let plain = analyzer.analyze("SELECT * FROM users");

        assert!(joined.has_joins);
        assert!(!plain.has_joins);
    }

    #[test]
    fn test_extract_columns() {
        let analyzer = test_analyzer();

        let normalized = analyzer.normalize_query("SELECT id, name, email FROM users WHERE id = 1");
        let columns = analyzer.extract_columns(&normalized);

        for expected in ["id", "name", "email"] {
            assert!(columns.contains(&expected.to_string()));
        }
    }
}