Skip to main content

heliosdb_proxy/schema_routing/
analyzer.rs

1//! Query Analyzer
2//!
3//! Analyzes SQL queries to determine routing requirements.
4
5use std::collections::{HashMap, HashSet};
6use std::sync::Arc;
7
8use super::registry::{AccessPattern, SchemaRegistry, TableSchema, WorkloadType};
9
10/// Query analyzer for schema-aware routing
11#[derive(Debug)]
12pub struct QueryAnalyzer {
13    /// Schema registry reference
14    schema: Arc<SchemaRegistry>,
15}
16
17impl QueryAnalyzer {
18    /// Create a new query analyzer
19    pub fn new(schema: Arc<SchemaRegistry>) -> Self {
20        Self { schema }
21    }
22
23    /// Analyze a query and determine routing requirements
24    pub fn analyze(&self, query: &str) -> QueryAnalysis {
25        let normalized = self.normalize_query(query);
26        let tables = self.extract_tables(&normalized);
27        let access_patterns = self.detect_access_patterns(&normalized, &tables);
28        let shard_keys = self.extract_shard_keys(&normalized, &tables);
29        let workload_type = self.classify_workload(&normalized, &tables);
30
31        QueryAnalysis {
32            original_query: query.to_string(),
33            tables,
34            access_patterns,
35            shard_keys,
36            workload_type,
37            complexity: self.estimate_complexity(&normalized),
38            selectivity: self.estimate_selectivity(&normalized),
39            is_read_only: self.is_read_only(&normalized),
40            has_aggregations: self.has_aggregations(&normalized),
41            has_joins: self.has_joins(&normalized),
42            has_subqueries: self.has_subqueries(&normalized),
43        }
44    }
45
46    /// Normalize query for analysis
47    fn normalize_query(&self, query: &str) -> String {
48        query
49            .to_uppercase()
50            .replace(['\n', '\t'], " ")
51            .split_whitespace()
52            .collect::<Vec<_>>()
53            .join(" ")
54    }
55
56    /// Extract tables from query
57    pub fn extract_tables(&self, query: &str) -> Vec<TableRef> {
58        let mut tables = Vec::new();
59        let words: Vec<&str> = query.split_whitespace().collect();
60
61        // Find tables after FROM, JOIN, INTO, UPDATE
62        let table_keywords = ["FROM", "JOIN", "INTO", "UPDATE"];
63
64        for (i, word) in words.iter().enumerate() {
65            if table_keywords.contains(word) {
66                if let Some(table_name) = words.get(i + 1) {
67                    let name = table_name.trim_matches(|c| c == ',' || c == '(' || c == ')');
68                    if !name.is_empty() && !is_keyword(name) {
69                        let alias = self.find_alias(&words, i + 1);
70                        tables.push(TableRef {
71                            name: name.to_lowercase(),
72                            alias,
73                            schema: self.schema.get_table(&name.to_lowercase()),
74                        });
75                    }
76                }
77            }
78        }
79
80        tables
81    }
82
83    /// Find alias for a table
84    fn find_alias(&self, words: &[&str], table_idx: usize) -> Option<String> {
85        if let Some(next) = words.get(table_idx + 1) {
86            if next.eq_ignore_ascii_case("AS") {
87                return words.get(table_idx + 2).map(|s| s.to_lowercase());
88            } else if !is_keyword(next) && !next.starts_with('(') {
89                return Some(next.to_lowercase());
90            }
91        }
92        None
93    }
94
95    /// Detect access patterns for each table
96    fn detect_access_patterns(&self, query: &str, tables: &[TableRef]) -> Vec<AccessPattern> {
97        let mut patterns = Vec::new();
98
99        for table in tables {
100            let pattern = self.detect_table_access_pattern(query, table);
101            patterns.push(pattern);
102        }
103
104        patterns
105    }
106
107    /// Detect access pattern for a specific table
108    fn detect_table_access_pattern(&self, query: &str, table: &TableRef) -> AccessPattern {
109        // Check for vector operations
110        if self.has_vector_operator(query) {
111            return AccessPattern::VectorSearch;
112        }
113
114        // Check for point lookup (equality on PK)
115        if let Some(schema) = &table.schema {
116            if self.has_equality_on_pk(query, schema) {
117                return AccessPattern::PointLookup;
118            }
119        }
120
121        // Check for range predicates
122        if self.has_range_predicate(query) {
123            return AccessPattern::RangeScan;
124        }
125
126        // Check for time-series patterns
127        if self.is_time_series_append(query) {
128            return AccessPattern::TimeSeriesAppend;
129        }
130
131        // Default to full scan if no WHERE clause
132        if !query.contains("WHERE") {
133            return AccessPattern::FullScan;
134        }
135
136        AccessPattern::Mixed
137    }
138
139    /// Check for equality on primary key
140    fn has_equality_on_pk(&self, query: &str, schema: &TableSchema) -> bool {
141        if schema.primary_key.is_empty() {
142            return false;
143        }
144
145        for pk_col in &schema.primary_key {
146            let pattern = format!("{} =", pk_col.to_uppercase());
147            if query.contains(&pattern) {
148                return true;
149            }
150        }
151
152        false
153    }
154
155    /// Check for range predicates
156    fn has_range_predicate(&self, query: &str) -> bool {
157        query.contains(" > ")
158            || query.contains(" < ")
159            || query.contains(" >= ")
160            || query.contains(" <= ")
161            || query.contains(" BETWEEN ")
162    }
163
164    /// Check for vector operators
165    fn has_vector_operator(&self, query: &str) -> bool {
166        query.contains("<->")
167            || query.contains("<#>")
168            || query.contains("<=>")
169            || query.contains("VECTOR")
170            || query.contains("EMBEDDING")
171            || query.contains("COSINE_DISTANCE")
172            || query.contains("L2_DISTANCE")
173    }
174
175    /// Check for time-series append pattern
176    fn is_time_series_append(&self, query: &str) -> bool {
177        query.starts_with("INSERT")
178            && (query.contains("TIMESTAMP")
179                || query.contains("CREATED_AT")
180                || query.contains("EVENT_TIME"))
181    }
182
183    /// Extract shard keys from query
184    fn extract_shard_keys(
185        &self,
186        query: &str,
187        tables: &[TableRef],
188    ) -> HashMap<String, ShardKeyValue> {
189        let mut shard_keys = HashMap::new();
190
191        for table in tables {
192            if let Some(schema) = &table.schema {
193                if let Some(shard_key) = &schema.shard_key {
194                    if let Some(value) = self.extract_shard_key_value(query, shard_key) {
195                        shard_keys.insert(shard_key.clone(), value);
196                    }
197                }
198            }
199        }
200
201        shard_keys
202    }
203
204    /// Extract shard key value from query
205    fn extract_shard_key_value(&self, query: &str, shard_key: &str) -> Option<ShardKeyValue> {
206        // Look for patterns like "shard_key = 'value'" or "shard_key = value"
207        let pattern = format!("{} =", shard_key.to_uppercase());
208        if let Some(idx) = query.find(&pattern) {
209            let rest = &query[idx + pattern.len()..];
210            let value = rest.split_whitespace().next()?;
211            let clean_value = value.trim_matches(|c| c == '\'' || c == '"' || c == ',');
212            return Some(ShardKeyValue::Single(clean_value.to_string()));
213        }
214
215        // Look for IN clause
216        let in_pattern = format!("{} IN", shard_key.to_uppercase());
217        if let Some(idx) = query.find(&in_pattern) {
218            let rest = &query[idx + in_pattern.len()..];
219            if let Some(start) = rest.find('(') {
220                if let Some(end) = rest.find(')') {
221                    let values_str = &rest[start + 1..end];
222                    let values: Vec<String> = values_str
223                        .split(',')
224                        .map(|v| v.trim().trim_matches(|c| c == '\'' || c == '"').to_string())
225                        .collect();
226                    return Some(ShardKeyValue::Multiple(values));
227                }
228            }
229        }
230
231        None
232    }
233
234    /// Classify workload type
235    fn classify_workload(&self, query: &str, tables: &[TableRef]) -> WorkloadType {
236        // Vector queries
237        if self.has_vector_operator(query) {
238            return WorkloadType::Vector;
239        }
240
241        // OLAP indicators
242        if self.has_aggregations(query)
243            || self.has_group_by(query)
244            || self.has_window_functions(query)
245        {
246            return WorkloadType::OLAP;
247        }
248
249        // Simple CRUD is OLTP
250        if self.is_simple_crud(query) {
251            return WorkloadType::OLTP;
252        }
253
254        // Check table hints
255        for table in tables {
256            if let Some(schema) = &table.schema {
257                if schema.workload != WorkloadType::Mixed {
258                    return schema.workload;
259                }
260            }
261        }
262
263        WorkloadType::Mixed
264    }
265
266    /// Check if query has aggregations
267    pub fn has_aggregations(&self, query: &str) -> bool {
268        query.contains("COUNT(")
269            || query.contains("SUM(")
270            || query.contains("AVG(")
271            || query.contains("MIN(")
272            || query.contains("MAX(")
273    }
274
275    /// Check if query has GROUP BY
276    fn has_group_by(&self, query: &str) -> bool {
277        query.contains("GROUP BY")
278    }
279
280    /// Check if query has window functions
281    fn has_window_functions(&self, query: &str) -> bool {
282        query.contains("OVER(")
283            || query.contains("OVER (")
284            || query.contains("ROW_NUMBER")
285            || query.contains("RANK()")
286            || query.contains("DENSE_RANK")
287            || query.contains("LAG(")
288            || query.contains("LEAD(")
289    }
290
291    /// Check if query is simple CRUD
292    fn is_simple_crud(&self, query: &str) -> bool {
293        let is_simple_select = query.starts_with("SELECT")
294            && !self.has_joins(query)
295            && !self.has_subqueries(query)
296            && !self.has_aggregations(query);
297
298        let is_simple_insert = query.starts_with("INSERT") && !query.contains("SELECT");
299
300        let is_simple_update = query.starts_with("UPDATE") && query.contains("WHERE");
301
302        let is_simple_delete = query.starts_with("DELETE") && query.contains("WHERE");
303
304        is_simple_select || is_simple_insert || is_simple_update || is_simple_delete
305    }
306
307    /// Check if query is read-only
308    pub fn is_read_only(&self, query: &str) -> bool {
309        query.starts_with("SELECT")
310            || query.starts_with("WITH")
311            || query.starts_with("EXPLAIN")
312            || query.starts_with("SHOW")
313    }
314
315    /// Check if query has joins
316    pub fn has_joins(&self, query: &str) -> bool {
317        query.contains(" JOIN ")
318    }
319
320    /// Check if query has subqueries
321    pub fn has_subqueries(&self, query: &str) -> bool {
322        // Count SELECT keywords (more than one suggests subqueries)
323        query.matches("SELECT").count() > 1
324    }
325
326    /// Estimate query complexity (0-100)
327    fn estimate_complexity(&self, query: &str) -> u32 {
328        let mut complexity: u32 = 10; // Base complexity
329
330        // Add for joins
331        complexity += (query.matches(" JOIN ").count() as u32) * 15;
332
333        // Add for subqueries
334        let select_count = query.matches("SELECT").count() as u32;
335        if select_count > 1 {
336            complexity += (select_count - 1) * 20;
337        }
338
339        // Add for aggregations
340        if self.has_aggregations(query) {
341            complexity += 10;
342        }
343
344        // Add for GROUP BY
345        if self.has_group_by(query) {
346            complexity += 10;
347        }
348
349        // Add for window functions
350        if self.has_window_functions(query) {
351            complexity += 15;
352        }
353
354        // Add for ORDER BY
355        if query.contains("ORDER BY") {
356            complexity += 5;
357        }
358
359        // Add for DISTINCT
360        if query.contains("DISTINCT") {
361            complexity += 5;
362        }
363
364        complexity.min(100)
365    }
366
367    /// Estimate selectivity (0.0 - 1.0)
368    fn estimate_selectivity(&self, query: &str) -> f64 {
369        if !query.contains("WHERE") {
370            return 1.0; // Full table scan
371        }
372
373        let mut selectivity = 0.5; // Default with WHERE
374
375        // Equality predicates are highly selective
376        let eq_count = query.matches(" = ").count();
377        selectivity *= 0.9_f64.powi(eq_count as i32);
378
379        // LIMIT reduces result set
380        if query.contains("LIMIT") {
381            selectivity *= 0.5;
382        }
383
384        selectivity.max(0.001) // Never assume 0 selectivity
385    }
386
387    /// Extract columns from query
388    pub fn extract_columns(&self, query: &str) -> Vec<String> {
389        let mut columns = HashSet::new();
390        let words: Vec<&str> = query.split_whitespace().collect();
391
392        // Find column names between SELECT and FROM
393        if let Some(select_idx) = words.iter().position(|w| *w == "SELECT") {
394            if let Some(from_idx) = words.iter().position(|w| *w == "FROM") {
395                for word in &words[select_idx + 1..from_idx] {
396                    let col = word.trim_matches(|c| c == ',' || c == '(' || c == ')');
397                    if !col.is_empty() && !is_keyword(col) && col != "*" {
398                        // Handle table.column format
399                        if let Some(dot_idx) = col.find('.') {
400                            columns.insert(col[dot_idx + 1..].to_lowercase());
401                        } else {
402                            columns.insert(col.to_lowercase());
403                        }
404                    }
405                }
406            }
407        }
408
409        columns.into_iter().collect()
410    }
411}
412
/// Check whether `word` is one of the SQL keywords the analyzer must not mistake for a
/// table name, alias or column.
///
/// The comparison ignores ASCII case and does not allocate. The list includes the clause
/// keywords that can directly follow a table reference (`FOR`, `RETURNING`, `LATERAL`,
/// `WINDOW`, `FETCH`), so `FROM users FOR UPDATE` does not yield the alias `for`.
fn is_keyword(word: &str) -> bool {
    const KEYWORDS: &[&str] = &[
        "SELECT",
        "FROM",
        "WHERE",
        "JOIN",
        "ON",
        "AND",
        "OR",
        "NOT",
        "IN",
        "IS",
        "NULL",
        "AS",
        "ORDER",
        "BY",
        "GROUP",
        "HAVING",
        "LIMIT",
        "OFFSET",
        "INSERT",
        "INTO",
        "VALUES",
        "UPDATE",
        "SET",
        "DELETE",
        "CREATE",
        "DROP",
        "ALTER",
        "INDEX",
        "TABLE",
        "LEFT",
        "RIGHT",
        "INNER",
        "OUTER",
        "FULL",
        "CROSS",
        "NATURAL",
        "USING",
        "DISTINCT",
        "ALL",
        "UNION",
        "INTERSECT",
        "EXCEPT",
        "CASE",
        "WHEN",
        "THEN",
        "ELSE",
        "END",
        "BETWEEN",
        "LIKE",
        "ILIKE",
        "EXISTS",
        "WITH",
        "RECURSIVE",
        "ASC",
        "DESC",
        "NULLS",
        "FIRST",
        "LAST",
        // Clause keywords that may directly follow a table reference.
        "FOR",
        "RETURNING",
        "LATERAL",
        "WINDOW",
        "FETCH",
    ];

    KEYWORDS
        .iter()
        .any(|keyword| keyword.eq_ignore_ascii_case(word))
}
477
478/// Table reference from query
479#[derive(Debug, Clone)]
480pub struct TableRef {
481    /// Table name
482    pub name: String,
483    /// Table alias
484    pub alias: Option<String>,
485    /// Table schema (if found)
486    pub schema: Option<TableSchema>,
487}
488
/// Value(s) a query supplies for a shard-key column.
///
/// Implements `PartialEq`/`Eq` so extracted values can be compared directly.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ShardKeyValue {
    /// A single value, taken from an equality predicate.
    Single(String),
    /// Several values, taken from an `IN (...)` list.
    Multiple(Vec<String>),
}
497
498/// Query analysis result
499#[derive(Debug, Clone)]
500pub struct QueryAnalysis {
501    /// Original query
502    pub original_query: String,
503    /// Tables referenced
504    pub tables: Vec<TableRef>,
505    /// Access patterns per table
506    pub access_patterns: Vec<AccessPattern>,
507    /// Extracted shard keys
508    pub shard_keys: HashMap<String, ShardKeyValue>,
509    /// Classified workload type
510    pub workload_type: WorkloadType,
511    /// Estimated complexity (0-100)
512    pub complexity: u32,
513    /// Estimated selectivity (0.0 - 1.0)
514    pub selectivity: f64,
515    /// Is read-only query
516    pub is_read_only: bool,
517    /// Has aggregation functions
518    pub has_aggregations: bool,
519    /// Has JOIN clauses
520    pub has_joins: bool,
521    /// Has subqueries
522    pub has_subqueries: bool,
523}
524
525impl QueryAnalysis {
526    /// Check if query involves vector operations
527    pub fn is_vector_query(&self) -> bool {
528        self.access_patterns.contains(&AccessPattern::VectorSearch)
529    }
530
531    /// Check if query is analytics (OLAP)
532    pub fn is_analytics(&self) -> bool {
533        self.workload_type == WorkloadType::OLAP
534    }
535
536    /// Get primary table (first table in query)
537    pub fn primary_table(&self) -> Option<&TableRef> {
538        self.tables.first()
539    }
540
541    /// Check if query targets a specific shard
542    pub fn has_shard_key(&self) -> bool {
543        !self.shard_keys.is_empty()
544    }
545}
546
#[cfg(test)]
mod tests {
    use super::*;

    /// Registry with an OLTP `users` table (pk and shard key `id`), an OLAP `events`
    /// table and a vector `embeddings` table.
    fn create_test_registry() -> Arc<SchemaRegistry> {
        let registry = SchemaRegistry::new();

        let users = TableSchema::new("users")
            .with_workload(WorkloadType::OLTP)
            .with_access_pattern(AccessPattern::PointLookup)
            .with_primary_key(vec!["id".to_string()])
            .with_shard_key("id");
        let events = TableSchema::new("events")
            .with_workload(WorkloadType::OLAP)
            .with_access_pattern(AccessPattern::FullScan);
        let embeddings = TableSchema::new("embeddings")
            .with_workload(WorkloadType::Vector)
            .with_access_pattern(AccessPattern::VectorSearch);

        for table in vec![users, events, embeddings] {
            registry.register_table(table);
        }

        Arc::new(registry)
    }

    /// Analyzer wired to the shared test registry.
    fn test_analyzer() -> QueryAnalyzer {
        QueryAnalyzer::new(create_test_registry())
    }

    #[test]
    fn test_extract_tables() {
        let analyzer = test_analyzer();

        let normalized = analyzer.normalize_query("SELECT * FROM users WHERE id = 1");
        let tables = analyzer.extract_tables(&normalized);

        assert_eq!(tables.len(), 1);
        assert_eq!(tables[0].name, "users");
    }

    #[test]
    fn test_extract_tables_with_join() {
        let analyzer = test_analyzer();

        let normalized = analyzer
            .normalize_query("SELECT u.*, o.* FROM users u JOIN orders o ON u.id = o.user_id");
        let tables = analyzer.extract_tables(&normalized);

        assert_eq!(tables.len(), 2);
        assert_eq!(tables[0].name, "users");
        assert_eq!(tables[0].alias.as_deref(), Some("u"));
    }

    #[test]
    fn test_classify_oltp() {
        let analysis = test_analyzer().analyze("SELECT * FROM users WHERE id = 1");

        assert_eq!(analysis.workload_type, WorkloadType::OLTP);
        assert!(analysis.is_read_only);
    }

    #[test]
    fn test_classify_olap() {
        let analysis =
            test_analyzer().analyze("SELECT COUNT(*), SUM(amount) FROM events GROUP BY date");

        assert_eq!(analysis.workload_type, WorkloadType::OLAP);
        assert!(analysis.has_aggregations);
    }

    #[test]
    fn test_classify_vector() {
        let analysis = test_analyzer()
            .analyze("SELECT * FROM embeddings ORDER BY embedding <-> '[1,2,3]' LIMIT 10");

        assert_eq!(analysis.workload_type, WorkloadType::Vector);
        assert!(analysis.is_vector_query());
    }

    #[test]
    fn test_extract_shard_key() {
        let analysis = test_analyzer().analyze("SELECT * FROM users WHERE id = 'user_123'");

        assert!(analysis.has_shard_key());
        assert!(analysis.shard_keys.contains_key("id"));
    }

    #[test]
    fn test_complexity_estimation() {
        let analyzer = test_analyzer();

        let simple = analyzer.analyze("SELECT * FROM users WHERE id = 1");
        let complex = analyzer.analyze(
            "SELECT u.*, COUNT(o.id) FROM users u JOIN orders o ON u.id = o.user_id GROUP BY u.id ORDER BY COUNT(o.id) DESC",
        );

        assert!(simple.complexity < complex.complexity);
    }

    #[test]
    fn test_detect_point_lookup() {
        let analysis = test_analyzer().analyze("SELECT * FROM users WHERE id = 1");

        assert!(analysis
            .access_patterns
            .contains(&AccessPattern::PointLookup));
    }

    #[test]
    fn test_detect_full_scan() {
        let analysis = test_analyzer().analyze("SELECT * FROM events");

        assert!(analysis.access_patterns.contains(&AccessPattern::FullScan));
    }

    #[test]
    fn test_has_joins() {
        let analyzer = test_analyzer();

        let with_join =
            analyzer.analyze("SELECT * FROM users u JOIN orders o ON u.id = o.user_id");
        let without_join = analyzer.analyze("SELECT * FROM users");

        assert!(with_join.has_joins);
        assert!(!without_join.has_joins);
    }

    #[test]
    fn test_extract_columns() {
        let analyzer = test_analyzer();

        let normalized =
            analyzer.normalize_query("SELECT id, name, email FROM users WHERE id = 1");
        let columns = analyzer.extract_columns(&normalized);

        for expected in &["id", "name", "email"] {
            assert!(columns.iter().any(|column| column.as_str() == *expected));
        }
    }
}
711}