1use std::collections::{HashMap, HashSet};
6use std::sync::Arc;
7
8use super::registry::{
9 SchemaRegistry, TableSchema, AccessPattern, WorkloadType,
10};
11
12#[derive(Debug)]
14pub struct QueryAnalyzer {
15 schema: Arc<SchemaRegistry>,
17}
18
19impl QueryAnalyzer {
20 pub fn new(schema: Arc<SchemaRegistry>) -> Self {
22 Self { schema }
23 }
24
25 pub fn analyze(&self, query: &str) -> QueryAnalysis {
27 let normalized = self.normalize_query(query);
28 let tables = self.extract_tables(&normalized);
29 let access_patterns = self.detect_access_patterns(&normalized, &tables);
30 let shard_keys = self.extract_shard_keys(&normalized, &tables);
31 let workload_type = self.classify_workload(&normalized, &tables);
32
33 QueryAnalysis {
34 original_query: query.to_string(),
35 tables,
36 access_patterns,
37 shard_keys,
38 workload_type,
39 complexity: self.estimate_complexity(&normalized),
40 selectivity: self.estimate_selectivity(&normalized),
41 is_read_only: self.is_read_only(&normalized),
42 has_aggregations: self.has_aggregations(&normalized),
43 has_joins: self.has_joins(&normalized),
44 has_subqueries: self.has_subqueries(&normalized),
45 }
46 }
47
48 fn normalize_query(&self, query: &str) -> String {
50 query.to_uppercase()
51 .replace('\n', " ")
52 .replace('\t', " ")
53 .split_whitespace()
54 .collect::<Vec<_>>()
55 .join(" ")
56 }
57
58 pub fn extract_tables(&self, query: &str) -> Vec<TableRef> {
60 let mut tables = Vec::new();
61 let words: Vec<&str> = query.split_whitespace().collect();
62
63 let table_keywords = ["FROM", "JOIN", "INTO", "UPDATE"];
65
66 for (i, word) in words.iter().enumerate() {
67 if table_keywords.contains(word) {
68 if let Some(table_name) = words.get(i + 1) {
69 let name = table_name.trim_matches(|c| c == ',' || c == '(' || c == ')');
70 if !name.is_empty() && !is_keyword(name) {
71 let alias = self.find_alias(&words, i + 1);
72 tables.push(TableRef {
73 name: name.to_lowercase(),
74 alias,
75 schema: self.schema.get_table(&name.to_lowercase()),
76 });
77 }
78 }
79 }
80 }
81
82 tables
83 }
84
85 fn find_alias(&self, words: &[&str], table_idx: usize) -> Option<String> {
87 if let Some(next) = words.get(table_idx + 1) {
88 if next.eq_ignore_ascii_case("AS") {
89 return words.get(table_idx + 2).map(|s| s.to_lowercase());
90 } else if !is_keyword(next) && !next.starts_with('(') {
91 return Some(next.to_lowercase());
92 }
93 }
94 None
95 }
96
97 fn detect_access_patterns(&self, query: &str, tables: &[TableRef]) -> Vec<AccessPattern> {
99 let mut patterns = Vec::new();
100
101 for table in tables {
102 let pattern = self.detect_table_access_pattern(query, table);
103 patterns.push(pattern);
104 }
105
106 patterns
107 }
108
109 fn detect_table_access_pattern(&self, query: &str, table: &TableRef) -> AccessPattern {
111 if self.has_vector_operator(query) {
113 return AccessPattern::VectorSearch;
114 }
115
116 if let Some(schema) = &table.schema {
118 if self.has_equality_on_pk(query, schema) {
119 return AccessPattern::PointLookup;
120 }
121 }
122
123 if self.has_range_predicate(query) {
125 return AccessPattern::RangeScan;
126 }
127
128 if self.is_time_series_append(query) {
130 return AccessPattern::TimeSeriesAppend;
131 }
132
133 if !query.contains("WHERE") {
135 return AccessPattern::FullScan;
136 }
137
138 AccessPattern::Mixed
139 }
140
141 fn has_equality_on_pk(&self, query: &str, schema: &TableSchema) -> bool {
143 if schema.primary_key.is_empty() {
144 return false;
145 }
146
147 for pk_col in &schema.primary_key {
148 let pattern = format!("{} =", pk_col.to_uppercase());
149 if query.contains(&pattern) {
150 return true;
151 }
152 }
153
154 false
155 }
156
157 fn has_range_predicate(&self, query: &str) -> bool {
159 query.contains(" > ") || query.contains(" < ")
160 || query.contains(" >= ") || query.contains(" <= ")
161 || query.contains(" BETWEEN ")
162 }
163
164 fn has_vector_operator(&self, query: &str) -> bool {
166 query.contains("<->") || query.contains("<#>") || query.contains("<=>")
167 || query.contains("VECTOR") || query.contains("EMBEDDING")
168 || query.contains("COSINE_DISTANCE") || query.contains("L2_DISTANCE")
169 }
170
171 fn is_time_series_append(&self, query: &str) -> bool {
173 query.starts_with("INSERT") && (
174 query.contains("TIMESTAMP") || query.contains("CREATED_AT")
175 || query.contains("EVENT_TIME")
176 )
177 }
178
179 fn extract_shard_keys(&self, query: &str, tables: &[TableRef]) -> HashMap<String, ShardKeyValue> {
181 let mut shard_keys = HashMap::new();
182
183 for table in tables {
184 if let Some(schema) = &table.schema {
185 if let Some(shard_key) = &schema.shard_key {
186 if let Some(value) = self.extract_shard_key_value(query, shard_key) {
187 shard_keys.insert(shard_key.clone(), value);
188 }
189 }
190 }
191 }
192
193 shard_keys
194 }
195
196 fn extract_shard_key_value(&self, query: &str, shard_key: &str) -> Option<ShardKeyValue> {
198 let pattern = format!("{} =", shard_key.to_uppercase());
200 if let Some(idx) = query.find(&pattern) {
201 let rest = &query[idx + pattern.len()..];
202 let value = rest.split_whitespace().next()?;
203 let clean_value = value.trim_matches(|c| c == '\'' || c == '"' || c == ',');
204 return Some(ShardKeyValue::Single(clean_value.to_string()));
205 }
206
207 let in_pattern = format!("{} IN", shard_key.to_uppercase());
209 if let Some(idx) = query.find(&in_pattern) {
210 let rest = &query[idx + in_pattern.len()..];
211 if let Some(start) = rest.find('(') {
212 if let Some(end) = rest.find(')') {
213 let values_str = &rest[start + 1..end];
214 let values: Vec<String> = values_str
215 .split(',')
216 .map(|v| v.trim().trim_matches(|c| c == '\'' || c == '"').to_string())
217 .collect();
218 return Some(ShardKeyValue::Multiple(values));
219 }
220 }
221 }
222
223 None
224 }
225
226 fn classify_workload(&self, query: &str, tables: &[TableRef]) -> WorkloadType {
228 if self.has_vector_operator(query) {
230 return WorkloadType::Vector;
231 }
232
233 if self.has_aggregations(query) || self.has_group_by(query) || self.has_window_functions(query) {
235 return WorkloadType::OLAP;
236 }
237
238 if self.is_simple_crud(query) {
240 return WorkloadType::OLTP;
241 }
242
243 for table in tables {
245 if let Some(schema) = &table.schema {
246 if schema.workload != WorkloadType::Mixed {
247 return schema.workload;
248 }
249 }
250 }
251
252 WorkloadType::Mixed
253 }
254
255 pub fn has_aggregations(&self, query: &str) -> bool {
257 query.contains("COUNT(") || query.contains("SUM(")
258 || query.contains("AVG(") || query.contains("MIN(")
259 || query.contains("MAX(")
260 }
261
262 fn has_group_by(&self, query: &str) -> bool {
264 query.contains("GROUP BY")
265 }
266
267 fn has_window_functions(&self, query: &str) -> bool {
269 query.contains("OVER(") || query.contains("OVER (")
270 || query.contains("ROW_NUMBER") || query.contains("RANK()")
271 || query.contains("DENSE_RANK") || query.contains("LAG(")
272 || query.contains("LEAD(")
273 }
274
275 fn is_simple_crud(&self, query: &str) -> bool {
277 let is_simple_select = query.starts_with("SELECT")
278 && !self.has_joins(query)
279 && !self.has_subqueries(query)
280 && !self.has_aggregations(query);
281
282 let is_simple_insert = query.starts_with("INSERT")
283 && !query.contains("SELECT");
284
285 let is_simple_update = query.starts_with("UPDATE")
286 && query.contains("WHERE");
287
288 let is_simple_delete = query.starts_with("DELETE")
289 && query.contains("WHERE");
290
291 is_simple_select || is_simple_insert || is_simple_update || is_simple_delete
292 }
293
294 pub fn is_read_only(&self, query: &str) -> bool {
296 query.starts_with("SELECT") || query.starts_with("WITH")
297 || query.starts_with("EXPLAIN") || query.starts_with("SHOW")
298 }
299
300 pub fn has_joins(&self, query: &str) -> bool {
302 query.contains(" JOIN ")
303 }
304
305 pub fn has_subqueries(&self, query: &str) -> bool {
307 query.matches("SELECT").count() > 1
309 }
310
311 fn estimate_complexity(&self, query: &str) -> u32 {
313 let mut complexity: u32 = 10; complexity += (query.matches(" JOIN ").count() as u32) * 15;
317
318 let select_count = query.matches("SELECT").count() as u32;
320 if select_count > 1 {
321 complexity += (select_count - 1) * 20;
322 }
323
324 if self.has_aggregations(query) {
326 complexity += 10;
327 }
328
329 if self.has_group_by(query) {
331 complexity += 10;
332 }
333
334 if self.has_window_functions(query) {
336 complexity += 15;
337 }
338
339 if query.contains("ORDER BY") {
341 complexity += 5;
342 }
343
344 if query.contains("DISTINCT") {
346 complexity += 5;
347 }
348
349 complexity.min(100)
350 }
351
352 fn estimate_selectivity(&self, query: &str) -> f64 {
354 if !query.contains("WHERE") {
355 return 1.0; }
357
358 let mut selectivity = 0.5; let eq_count = query.matches(" = ").count();
362 selectivity *= 0.9_f64.powi(eq_count as i32);
363
364 if query.contains("LIMIT") {
366 selectivity *= 0.5;
367 }
368
369 selectivity.max(0.001) }
371
372 pub fn extract_columns(&self, query: &str) -> Vec<String> {
374 let mut columns = HashSet::new();
375 let words: Vec<&str> = query.split_whitespace().collect();
376
377 if let Some(select_idx) = words.iter().position(|w| *w == "SELECT") {
379 if let Some(from_idx) = words.iter().position(|w| *w == "FROM") {
380 for word in &words[select_idx + 1..from_idx] {
381 let col = word.trim_matches(|c| c == ',' || c == '(' || c == ')');
382 if !col.is_empty() && !is_keyword(col) && col != "*" {
383 if let Some(dot_idx) = col.find('.') {
385 columns.insert(col[dot_idx + 1..].to_lowercase());
386 } else {
387 columns.insert(col.to_lowercase());
388 }
389 }
390 }
391 }
392 }
393
394 columns.into_iter().collect()
395 }
396}
397
/// Returns `true` when `word` is an SQL keyword (ASCII case-insensitive).
///
/// The analyzer uses this to tell syntax apart from table names, aliases and
/// column names. Clause words that commonly follow a table name, such as
/// `RETURNING`, `FOR` (`FOR UPDATE`), `FETCH`, `WINDOW`, `LATERAL` and
/// `DEFAULT` (`DEFAULT VALUES`), are included so they are not mistaken for
/// aliases.
fn is_keyword(word: &str) -> bool {
    const KEYWORDS: &[&str] = &[
        "SELECT", "FROM", "WHERE", "JOIN", "ON", "AND", "OR", "NOT",
        "IN", "IS", "NULL", "AS", "ORDER", "BY", "GROUP", "HAVING",
        "LIMIT", "OFFSET", "INSERT", "INTO", "VALUES", "UPDATE", "SET",
        "DELETE", "CREATE", "DROP", "ALTER", "INDEX", "TABLE", "LEFT",
        "RIGHT", "INNER", "OUTER", "FULL", "CROSS", "NATURAL", "USING",
        "DISTINCT", "ALL", "UNION", "INTERSECT", "EXCEPT", "CASE",
        "WHEN", "THEN", "ELSE", "END", "BETWEEN", "LIKE", "ILIKE",
        "EXISTS", "WITH", "RECURSIVE", "ASC", "DESC", "NULLS", "FIRST", "LAST",
        "RETURNING", "FOR", "FETCH", "WINDOW", "LATERAL", "DEFAULT",
    ];
    // SQL keywords are ASCII, so an allocation-free ASCII comparison suffices.
    KEYWORDS.iter().any(|kw| kw.eq_ignore_ascii_case(word))
}
412
413#[derive(Debug, Clone)]
415pub struct TableRef {
416 pub name: String,
418 pub alias: Option<String>,
420 pub schema: Option<TableSchema>,
422}
423
/// Shard-key value(s) pinned by a query predicate.
///
/// Derives `PartialEq`/`Eq` so routing code and tests can compare extracted values.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ShardKeyValue {
    /// A single value, from `key = value`.
    Single(String),
    /// A literal list, from `key IN (v1, v2, ...)`.
    Multiple(Vec<String>),
}
432
433#[derive(Debug, Clone)]
435pub struct QueryAnalysis {
436 pub original_query: String,
438 pub tables: Vec<TableRef>,
440 pub access_patterns: Vec<AccessPattern>,
442 pub shard_keys: HashMap<String, ShardKeyValue>,
444 pub workload_type: WorkloadType,
446 pub complexity: u32,
448 pub selectivity: f64,
450 pub is_read_only: bool,
452 pub has_aggregations: bool,
454 pub has_joins: bool,
456 pub has_subqueries: bool,
458}
459
460impl QueryAnalysis {
461 pub fn is_vector_query(&self) -> bool {
463 self.access_patterns.contains(&AccessPattern::VectorSearch)
464 }
465
466 pub fn is_analytics(&self) -> bool {
468 self.workload_type == WorkloadType::OLAP
469 }
470
471 pub fn primary_table(&self) -> Option<&TableRef> {
473 self.tables.first()
474 }
475
476 pub fn has_shard_key(&self) -> bool {
478 !self.shard_keys.is_empty()
479 }
480}
481
#[cfg(test)]
mod tests {
    use super::*;

    /// Registry with one table per workload: `users` (OLTP, primary and shard
    /// key `id`), `events` (OLAP) and `embeddings` (vector).
    fn create_test_registry() -> Arc<SchemaRegistry> {
        let registry = SchemaRegistry::new();

        registry.register_table(
            TableSchema::new("users")
                .with_workload(WorkloadType::OLTP)
                .with_access_pattern(AccessPattern::PointLookup)
                .with_primary_key(vec!["id".to_string()])
                .with_shard_key("id"),
        );
        registry.register_table(
            TableSchema::new("events")
                .with_workload(WorkloadType::OLAP)
                .with_access_pattern(AccessPattern::FullScan),
        );
        registry.register_table(
            TableSchema::new("embeddings")
                .with_workload(WorkloadType::Vector)
                .with_access_pattern(AccessPattern::VectorSearch),
        );

        Arc::new(registry)
    }

    /// Fresh analyzer over the shared test registry.
    fn test_analyzer() -> QueryAnalyzer {
        QueryAnalyzer::new(create_test_registry())
    }

    #[test]
    fn test_extract_tables() {
        let analyzer = test_analyzer();
        let normalized = analyzer.normalize_query("SELECT * FROM users WHERE id = 1");
        let tables = analyzer.extract_tables(&normalized);

        assert_eq!(tables.len(), 1);
        assert_eq!(tables[0].name, "users");
    }

    #[test]
    fn test_extract_tables_with_join() {
        let analyzer = test_analyzer();
        let normalized = analyzer
            .normalize_query("SELECT u.*, o.* FROM users u JOIN orders o ON u.id = o.user_id");
        let tables = analyzer.extract_tables(&normalized);

        assert_eq!(tables.len(), 2);
        assert_eq!(tables[0].name, "users");
        assert_eq!(tables[0].alias, Some("u".to_string()));
    }

    #[test]
    fn test_classify_oltp() {
        let analysis = test_analyzer().analyze("SELECT * FROM users WHERE id = 1");

        assert_eq!(analysis.workload_type, WorkloadType::OLTP);
        assert!(analysis.is_read_only);
    }

    #[test]
    fn test_classify_olap() {
        let analysis =
            test_analyzer().analyze("SELECT COUNT(*), SUM(amount) FROM events GROUP BY date");

        assert_eq!(analysis.workload_type, WorkloadType::OLAP);
        assert!(analysis.has_aggregations);
    }

    #[test]
    fn test_classify_vector() {
        let analysis = test_analyzer()
            .analyze("SELECT * FROM embeddings ORDER BY embedding <-> '[1,2,3]' LIMIT 10");

        assert_eq!(analysis.workload_type, WorkloadType::Vector);
        assert!(analysis.is_vector_query());
    }

    #[test]
    fn test_extract_shard_key() {
        let analysis = test_analyzer().analyze("SELECT * FROM users WHERE id = 'user_123'");

        assert!(analysis.has_shard_key());
        assert!(analysis.shard_keys.contains_key("id"));
    }

    #[test]
    fn test_complexity_estimation() {
        let analyzer = test_analyzer();

        let simple = analyzer.analyze("SELECT * FROM users WHERE id = 1");
        let complex = analyzer.analyze(
            "SELECT u.*, COUNT(o.id) FROM users u JOIN orders o ON u.id = o.user_id GROUP BY u.id ORDER BY COUNT(o.id) DESC",
        );

        assert!(simple.complexity < complex.complexity);
    }

    #[test]
    fn test_detect_point_lookup() {
        let analysis = test_analyzer().analyze("SELECT * FROM users WHERE id = 1");

        assert!(analysis.access_patterns.contains(&AccessPattern::PointLookup));
    }

    #[test]
    fn test_detect_full_scan() {
        let analysis = test_analyzer().analyze("SELECT * FROM events");

        assert!(analysis.access_patterns.contains(&AccessPattern::FullScan));
    }

    #[test]
    fn test_has_joins() {
        let analyzer = test_analyzer();

        let joined = analyzer.analyze("SELECT * FROM users u JOIN orders o ON u.id = o.user_id");
        let single_table = analyzer.analyze("SELECT * FROM users");

        assert!(joined.has_joins);
        assert!(!single_table.has_joins);
    }

    #[test]
    fn test_extract_columns() {
        let analyzer = test_analyzer();
        let normalized = analyzer.normalize_query("SELECT id, name, email FROM users WHERE id = 1");
        let columns = analyzer.extract_columns(&normalized);

        for expected in ["id", "name", "email"].iter() {
            assert!(columns.contains(&expected.to_string()));
        }
    }
}