Skip to main content

heliosdb_proxy/schema_routing/
analyzer.rs

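//! Query Analyzer
//!
//! Analyzes SQL queries to determine routing requirements: referenced tables,
//! per-table access patterns, shard-key values and workload type. The analysis
//! is heuristic — it relies on keyword and token matching over the query text
//! rather than a full SQL parse.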
4
5use std::collections::{HashMap, HashSet};
6use std::sync::Arc;
7
8use super::registry::{
9    SchemaRegistry, TableSchema, AccessPattern, WorkloadType,
10};
11
12/// Query analyzer for schema-aware routing
13#[derive(Debug)]
14pub struct QueryAnalyzer {
15    /// Schema registry reference
16    schema: Arc<SchemaRegistry>,
17}
18
19impl QueryAnalyzer {
20    /// Create a new query analyzer
21    pub fn new(schema: Arc<SchemaRegistry>) -> Self {
22        Self { schema }
23    }
24
25    /// Analyze a query and determine routing requirements
26    pub fn analyze(&self, query: &str) -> QueryAnalysis {
27        let normalized = self.normalize_query(query);
28        let tables = self.extract_tables(&normalized);
29        let access_patterns = self.detect_access_patterns(&normalized, &tables);
30        let shard_keys = self.extract_shard_keys(&normalized, &tables);
31        let workload_type = self.classify_workload(&normalized, &tables);
32
33        QueryAnalysis {
34            original_query: query.to_string(),
35            tables,
36            access_patterns,
37            shard_keys,
38            workload_type,
39            complexity: self.estimate_complexity(&normalized),
40            selectivity: self.estimate_selectivity(&normalized),
41            is_read_only: self.is_read_only(&normalized),
42            has_aggregations: self.has_aggregations(&normalized),
43            has_joins: self.has_joins(&normalized),
44            has_subqueries: self.has_subqueries(&normalized),
45        }
46    }
47
48    /// Normalize query for analysis
49    fn normalize_query(&self, query: &str) -> String {
50        query.to_uppercase()
51            .replace('\n', " ")
52            .replace('\t', " ")
53            .split_whitespace()
54            .collect::<Vec<_>>()
55            .join(" ")
56    }
57
58    /// Extract tables from query
59    pub fn extract_tables(&self, query: &str) -> Vec<TableRef> {
60        let mut tables = Vec::new();
61        let words: Vec<&str> = query.split_whitespace().collect();
62
63        // Find tables after FROM, JOIN, INTO, UPDATE
64        let table_keywords = ["FROM", "JOIN", "INTO", "UPDATE"];
65
66        for (i, word) in words.iter().enumerate() {
67            if table_keywords.contains(word) {
68                if let Some(table_name) = words.get(i + 1) {
69                    let name = table_name.trim_matches(|c| c == ',' || c == '(' || c == ')');
70                    if !name.is_empty() && !is_keyword(name) {
71                        let alias = self.find_alias(&words, i + 1);
72                        tables.push(TableRef {
73                            name: name.to_lowercase(),
74                            alias,
75                            schema: self.schema.get_table(&name.to_lowercase()),
76                        });
77                    }
78                }
79            }
80        }
81
82        tables
83    }
84
85    /// Find alias for a table
86    fn find_alias(&self, words: &[&str], table_idx: usize) -> Option<String> {
87        if let Some(next) = words.get(table_idx + 1) {
88            if next.eq_ignore_ascii_case("AS") {
89                return words.get(table_idx + 2).map(|s| s.to_lowercase());
90            } else if !is_keyword(next) && !next.starts_with('(') {
91                return Some(next.to_lowercase());
92            }
93        }
94        None
95    }
96
97    /// Detect access patterns for each table
98    fn detect_access_patterns(&self, query: &str, tables: &[TableRef]) -> Vec<AccessPattern> {
99        let mut patterns = Vec::new();
100
101        for table in tables {
102            let pattern = self.detect_table_access_pattern(query, table);
103            patterns.push(pattern);
104        }
105
106        patterns
107    }
108
109    /// Detect access pattern for a specific table
110    fn detect_table_access_pattern(&self, query: &str, table: &TableRef) -> AccessPattern {
111        // Check for vector operations
112        if self.has_vector_operator(query) {
113            return AccessPattern::VectorSearch;
114        }
115
116        // Check for point lookup (equality on PK)
117        if let Some(schema) = &table.schema {
118            if self.has_equality_on_pk(query, schema) {
119                return AccessPattern::PointLookup;
120            }
121        }
122
123        // Check for range predicates
124        if self.has_range_predicate(query) {
125            return AccessPattern::RangeScan;
126        }
127
128        // Check for time-series patterns
129        if self.is_time_series_append(query) {
130            return AccessPattern::TimeSeriesAppend;
131        }
132
133        // Default to full scan if no WHERE clause
134        if !query.contains("WHERE") {
135            return AccessPattern::FullScan;
136        }
137
138        AccessPattern::Mixed
139    }
140
141    /// Check for equality on primary key
142    fn has_equality_on_pk(&self, query: &str, schema: &TableSchema) -> bool {
143        if schema.primary_key.is_empty() {
144            return false;
145        }
146
147        for pk_col in &schema.primary_key {
148            let pattern = format!("{} =", pk_col.to_uppercase());
149            if query.contains(&pattern) {
150                return true;
151            }
152        }
153
154        false
155    }
156
157    /// Check for range predicates
158    fn has_range_predicate(&self, query: &str) -> bool {
159        query.contains(" > ") || query.contains(" < ")
160            || query.contains(" >= ") || query.contains(" <= ")
161            || query.contains(" BETWEEN ")
162    }
163
164    /// Check for vector operators
165    fn has_vector_operator(&self, query: &str) -> bool {
166        query.contains("<->") || query.contains("<#>") || query.contains("<=>")
167            || query.contains("VECTOR") || query.contains("EMBEDDING")
168            || query.contains("COSINE_DISTANCE") || query.contains("L2_DISTANCE")
169    }
170
171    /// Check for time-series append pattern
172    fn is_time_series_append(&self, query: &str) -> bool {
173        query.starts_with("INSERT") && (
174            query.contains("TIMESTAMP") || query.contains("CREATED_AT")
175                || query.contains("EVENT_TIME")
176        )
177    }
178
179    /// Extract shard keys from query
180    fn extract_shard_keys(&self, query: &str, tables: &[TableRef]) -> HashMap<String, ShardKeyValue> {
181        let mut shard_keys = HashMap::new();
182
183        for table in tables {
184            if let Some(schema) = &table.schema {
185                if let Some(shard_key) = &schema.shard_key {
186                    if let Some(value) = self.extract_shard_key_value(query, shard_key) {
187                        shard_keys.insert(shard_key.clone(), value);
188                    }
189                }
190            }
191        }
192
193        shard_keys
194    }
195
196    /// Extract shard key value from query
197    fn extract_shard_key_value(&self, query: &str, shard_key: &str) -> Option<ShardKeyValue> {
198        // Look for patterns like "shard_key = 'value'" or "shard_key = value"
199        let pattern = format!("{} =", shard_key.to_uppercase());
200        if let Some(idx) = query.find(&pattern) {
201            let rest = &query[idx + pattern.len()..];
202            let value = rest.split_whitespace().next()?;
203            let clean_value = value.trim_matches(|c| c == '\'' || c == '"' || c == ',');
204            return Some(ShardKeyValue::Single(clean_value.to_string()));
205        }
206
207        // Look for IN clause
208        let in_pattern = format!("{} IN", shard_key.to_uppercase());
209        if let Some(idx) = query.find(&in_pattern) {
210            let rest = &query[idx + in_pattern.len()..];
211            if let Some(start) = rest.find('(') {
212                if let Some(end) = rest.find(')') {
213                    let values_str = &rest[start + 1..end];
214                    let values: Vec<String> = values_str
215                        .split(',')
216                        .map(|v| v.trim().trim_matches(|c| c == '\'' || c == '"').to_string())
217                        .collect();
218                    return Some(ShardKeyValue::Multiple(values));
219                }
220            }
221        }
222
223        None
224    }
225
226    /// Classify workload type
227    fn classify_workload(&self, query: &str, tables: &[TableRef]) -> WorkloadType {
228        // Vector queries
229        if self.has_vector_operator(query) {
230            return WorkloadType::Vector;
231        }
232
233        // OLAP indicators
234        if self.has_aggregations(query) || self.has_group_by(query) || self.has_window_functions(query) {
235            return WorkloadType::OLAP;
236        }
237
238        // Simple CRUD is OLTP
239        if self.is_simple_crud(query) {
240            return WorkloadType::OLTP;
241        }
242
243        // Check table hints
244        for table in tables {
245            if let Some(schema) = &table.schema {
246                if schema.workload != WorkloadType::Mixed {
247                    return schema.workload;
248                }
249            }
250        }
251
252        WorkloadType::Mixed
253    }
254
255    /// Check if query has aggregations
256    pub fn has_aggregations(&self, query: &str) -> bool {
257        query.contains("COUNT(") || query.contains("SUM(")
258            || query.contains("AVG(") || query.contains("MIN(")
259            || query.contains("MAX(")
260    }
261
262    /// Check if query has GROUP BY
263    fn has_group_by(&self, query: &str) -> bool {
264        query.contains("GROUP BY")
265    }
266
267    /// Check if query has window functions
268    fn has_window_functions(&self, query: &str) -> bool {
269        query.contains("OVER(") || query.contains("OVER (")
270            || query.contains("ROW_NUMBER") || query.contains("RANK()")
271            || query.contains("DENSE_RANK") || query.contains("LAG(")
272            || query.contains("LEAD(")
273    }
274
275    /// Check if query is simple CRUD
276    fn is_simple_crud(&self, query: &str) -> bool {
277        let is_simple_select = query.starts_with("SELECT")
278            && !self.has_joins(query)
279            && !self.has_subqueries(query)
280            && !self.has_aggregations(query);
281
282        let is_simple_insert = query.starts_with("INSERT")
283            && !query.contains("SELECT");
284
285        let is_simple_update = query.starts_with("UPDATE")
286            && query.contains("WHERE");
287
288        let is_simple_delete = query.starts_with("DELETE")
289            && query.contains("WHERE");
290
291        is_simple_select || is_simple_insert || is_simple_update || is_simple_delete
292    }
293
294    /// Check if query is read-only
295    pub fn is_read_only(&self, query: &str) -> bool {
296        query.starts_with("SELECT") || query.starts_with("WITH")
297            || query.starts_with("EXPLAIN") || query.starts_with("SHOW")
298    }
299
300    /// Check if query has joins
301    pub fn has_joins(&self, query: &str) -> bool {
302        query.contains(" JOIN ")
303    }
304
305    /// Check if query has subqueries
306    pub fn has_subqueries(&self, query: &str) -> bool {
307        // Count SELECT keywords (more than one suggests subqueries)
308        query.matches("SELECT").count() > 1
309    }
310
311    /// Estimate query complexity (0-100)
312    fn estimate_complexity(&self, query: &str) -> u32 {
313        let mut complexity: u32 = 10; // Base complexity
314
315        // Add for joins
316        complexity += (query.matches(" JOIN ").count() as u32) * 15;
317
318        // Add for subqueries
319        let select_count = query.matches("SELECT").count() as u32;
320        if select_count > 1 {
321            complexity += (select_count - 1) * 20;
322        }
323
324        // Add for aggregations
325        if self.has_aggregations(query) {
326            complexity += 10;
327        }
328
329        // Add for GROUP BY
330        if self.has_group_by(query) {
331            complexity += 10;
332        }
333
334        // Add for window functions
335        if self.has_window_functions(query) {
336            complexity += 15;
337        }
338
339        // Add for ORDER BY
340        if query.contains("ORDER BY") {
341            complexity += 5;
342        }
343
344        // Add for DISTINCT
345        if query.contains("DISTINCT") {
346            complexity += 5;
347        }
348
349        complexity.min(100)
350    }
351
352    /// Estimate selectivity (0.0 - 1.0)
353    fn estimate_selectivity(&self, query: &str) -> f64 {
354        if !query.contains("WHERE") {
355            return 1.0; // Full table scan
356        }
357
358        let mut selectivity = 0.5; // Default with WHERE
359
360        // Equality predicates are highly selective
361        let eq_count = query.matches(" = ").count();
362        selectivity *= 0.9_f64.powi(eq_count as i32);
363
364        // LIMIT reduces result set
365        if query.contains("LIMIT") {
366            selectivity *= 0.5;
367        }
368
369        selectivity.max(0.001) // Never assume 0 selectivity
370    }
371
372    /// Extract columns from query
373    pub fn extract_columns(&self, query: &str) -> Vec<String> {
374        let mut columns = HashSet::new();
375        let words: Vec<&str> = query.split_whitespace().collect();
376
377        // Find column names between SELECT and FROM
378        if let Some(select_idx) = words.iter().position(|w| *w == "SELECT") {
379            if let Some(from_idx) = words.iter().position(|w| *w == "FROM") {
380                for word in &words[select_idx + 1..from_idx] {
381                    let col = word.trim_matches(|c| c == ',' || c == '(' || c == ')');
382                    if !col.is_empty() && !is_keyword(col) && col != "*" {
383                        // Handle table.column format
384                        if let Some(dot_idx) = col.find('.') {
385                            columns.insert(col[dot_idx + 1..].to_lowercase());
386                        } else {
387                            columns.insert(col.to_lowercase());
388                        }
389                    }
390                }
391            }
392        }
393
394        columns.into_iter().collect()
395    }
396}
397
/// SQL words that the table/alias/column heuristics treat as reserved.
///
/// Besides the core clause keywords, this includes words that commonly follow
/// a table name and would otherwise be mistaken for an alias:
/// - `RETURNING`.
/// - `FOR` (as in `FOR UPDATE`).
/// - `FETCH`.
/// - `WINDOW`.
/// - `DEFAULT` (as in `INSERT INTO t DEFAULT VALUES`).
///
/// It also includes `LATERAL`, which follows `JOIN` without naming a table.
const SQL_KEYWORDS: &[&str] = &[
    "SELECT", "FROM", "WHERE", "JOIN", "ON", "AND", "OR", "NOT",
    "IN", "IS", "NULL", "AS", "ORDER", "BY", "GROUP", "HAVING",
    "LIMIT", "OFFSET", "INSERT", "INTO", "VALUES", "UPDATE", "SET",
    "DELETE", "CREATE", "DROP", "ALTER", "INDEX", "TABLE", "LEFT",
    "RIGHT", "INNER", "OUTER", "FULL", "CROSS", "NATURAL", "USING",
    "DISTINCT", "ALL", "UNION", "INTERSECT", "EXCEPT", "CASE",
    "WHEN", "THEN", "ELSE", "END", "BETWEEN", "LIKE", "ILIKE",
    "EXISTS", "WITH", "RECURSIVE", "ASC", "DESC", "NULLS", "FIRST", "LAST",
    "RETURNING", "FOR", "FETCH", "WINDOW", "LATERAL", "DEFAULT",
];

/// Check if a word is a SQL keyword (ASCII case-insensitive).
///
/// Compares without allocating an upper-cased copy of `word`.
fn is_keyword(word: &str) -> bool {
    SQL_KEYWORDS.iter().any(|keyword| keyword.eq_ignore_ascii_case(word))
}
412
413/// Table reference from query
414#[derive(Debug, Clone)]
415pub struct TableRef {
416    /// Table name
417    pub name: String,
418    /// Table alias
419    pub alias: Option<String>,
420    /// Table schema (if found)
421    pub schema: Option<TableSchema>,
422}
423
/// Value(s) bound to a shard key by a query predicate.
///
/// Derives `PartialEq`/`Eq` so extracted values can be compared directly,
/// e.g. to check which shard a query targets.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ShardKeyValue {
    /// Single value from `key = value`.
    Single(String),
    /// Multiple values from `key IN (v1, v2, ...)`.
    Multiple(Vec<String>),
}
432
433/// Query analysis result
434#[derive(Debug, Clone)]
435pub struct QueryAnalysis {
436    /// Original query
437    pub original_query: String,
438    /// Tables referenced
439    pub tables: Vec<TableRef>,
440    /// Access patterns per table
441    pub access_patterns: Vec<AccessPattern>,
442    /// Extracted shard keys
443    pub shard_keys: HashMap<String, ShardKeyValue>,
444    /// Classified workload type
445    pub workload_type: WorkloadType,
446    /// Estimated complexity (0-100)
447    pub complexity: u32,
448    /// Estimated selectivity (0.0 - 1.0)
449    pub selectivity: f64,
450    /// Is read-only query
451    pub is_read_only: bool,
452    /// Has aggregation functions
453    pub has_aggregations: bool,
454    /// Has JOIN clauses
455    pub has_joins: bool,
456    /// Has subqueries
457    pub has_subqueries: bool,
458}
459
460impl QueryAnalysis {
461    /// Check if query involves vector operations
462    pub fn is_vector_query(&self) -> bool {
463        self.access_patterns.contains(&AccessPattern::VectorSearch)
464    }
465
466    /// Check if query is analytics (OLAP)
467    pub fn is_analytics(&self) -> bool {
468        self.workload_type == WorkloadType::OLAP
469    }
470
471    /// Get primary table (first table in query)
472    pub fn primary_table(&self) -> Option<&TableRef> {
473        self.tables.first()
474    }
475
476    /// Check if query targets a specific shard
477    pub fn has_shard_key(&self) -> bool {
478        !self.shard_keys.is_empty()
479    }
480}
481
#[cfg(test)]
mod tests {
    use super::*;

    /// Registry with three tables covering the OLTP, OLAP and vector workloads.
    fn create_test_registry() -> Arc<SchemaRegistry> {
        let registry = SchemaRegistry::new();

        let users_table = TableSchema::new("users")
            .with_workload(WorkloadType::OLTP)
            .with_access_pattern(AccessPattern::PointLookup)
            .with_primary_key(vec!["id".to_string()])
            .with_shard_key("id");
        let events_table = TableSchema::new("events")
            .with_workload(WorkloadType::OLAP)
            .with_access_pattern(AccessPattern::FullScan);
        let embeddings_table = TableSchema::new("embeddings")
            .with_workload(WorkloadType::Vector)
            .with_access_pattern(AccessPattern::VectorSearch);

        registry.register_table(users_table);
        registry.register_table(events_table);
        registry.register_table(embeddings_table);

        Arc::new(registry)
    }

    /// Analyzer wired to a fresh copy of the test registry.
    fn test_analyzer() -> QueryAnalyzer {
        QueryAnalyzer::new(create_test_registry())
    }

    #[test]
    fn test_extract_tables() {
        let qa = test_analyzer();
        let normalized = qa.normalize_query("SELECT * FROM users WHERE id = 1");
        let found = qa.extract_tables(&normalized);

        assert_eq!(found.len(), 1);
        assert_eq!(found[0].name, "users");
    }

    #[test]
    fn test_extract_tables_with_join() {
        let qa = test_analyzer();
        let normalized =
            qa.normalize_query("SELECT u.*, o.* FROM users u JOIN orders o ON u.id = o.user_id");
        let found = qa.extract_tables(&normalized);

        assert_eq!(found.len(), 2);
        assert_eq!(found[0].name, "users");
        assert_eq!(found[0].alias, Some("u".to_string()));
    }

    #[test]
    fn test_classify_oltp() {
        let result = test_analyzer().analyze("SELECT * FROM users WHERE id = 1");

        assert_eq!(result.workload_type, WorkloadType::OLTP);
        assert!(result.is_read_only);
    }

    #[test]
    fn test_classify_olap() {
        let result =
            test_analyzer().analyze("SELECT COUNT(*), SUM(amount) FROM events GROUP BY date");

        assert_eq!(result.workload_type, WorkloadType::OLAP);
        assert!(result.has_aggregations);
    }

    #[test]
    fn test_classify_vector() {
        let result = test_analyzer()
            .analyze("SELECT * FROM embeddings ORDER BY embedding <-> '[1,2,3]' LIMIT 10");

        assert_eq!(result.workload_type, WorkloadType::Vector);
        assert!(result.is_vector_query());
    }

    #[test]
    fn test_extract_shard_key() {
        let result = test_analyzer().analyze("SELECT * FROM users WHERE id = 'user_123'");

        assert!(result.has_shard_key());
        assert!(result.shard_keys.contains_key("id"));
    }

    #[test]
    fn test_complexity_estimation() {
        let qa = test_analyzer();
        let simple_score = qa.analyze("SELECT * FROM users WHERE id = 1").complexity;
        let complex_score = qa
            .analyze(
                "SELECT u.*, COUNT(o.id) FROM users u JOIN orders o ON u.id = o.user_id GROUP BY u.id ORDER BY COUNT(o.id) DESC",
            )
            .complexity;

        assert!(simple_score < complex_score);
    }

    #[test]
    fn test_detect_point_lookup() {
        let result = test_analyzer().analyze("SELECT * FROM users WHERE id = 1");

        assert!(result.access_patterns.contains(&AccessPattern::PointLookup));
    }

    #[test]
    fn test_detect_full_scan() {
        let result = test_analyzer().analyze("SELECT * FROM events");

        assert!(result.access_patterns.contains(&AccessPattern::FullScan));
    }

    #[test]
    fn test_has_joins() {
        let qa = test_analyzer();

        assert!(qa.analyze("SELECT * FROM users u JOIN orders o ON u.id = o.user_id").has_joins);
        assert!(!qa.analyze("SELECT * FROM users").has_joins);
    }

    #[test]
    fn test_extract_columns() {
        let qa = test_analyzer();
        let normalized = qa.normalize_query("SELECT id, name, email FROM users WHERE id = 1");
        let cols = qa.extract_columns(&normalized);

        for expected in ["id", "name", "email"].iter() {
            assert!(cols.contains(&expected.to_string()));
        }
    }
}