Skip to main content

aegis_query/
analyzer.rs

1//! Aegis Analyzer - Semantic Analysis
2//!
3//! Performs semantic analysis on parsed SQL statements. Validates table and
4//! column references, resolves types, and checks for constraint violations.
5//!
6//! Key Features:
7//! - Table and column resolution
8//! - Type checking and inference
9//! - Constraint validation
10//! - Scope management for subqueries
11//!
12//! @version 0.1.0
13//! @author AutomataNexus Development Team
14
15use crate::ast::*;
16use aegis_common::{AegisError, DataType, Result};
17use std::collections::HashMap;
18
19// =============================================================================
20// Catalog
21// =============================================================================
22
23/// Schema information for analysis.
24#[derive(Debug, Clone, Default)]
25pub struct Catalog {
26    tables: HashMap<String, TableSchema>,
27}
28
29impl Catalog {
30    pub fn new() -> Self {
31        Self::default()
32    }
33
34    pub fn add_table(&mut self, schema: TableSchema) {
35        self.tables.insert(schema.name.clone(), schema);
36    }
37
38    pub fn get_table(&self, name: &str) -> Option<&TableSchema> {
39        self.tables.get(name)
40    }
41
42    pub fn table_exists(&self, name: &str) -> bool {
43        self.tables.contains_key(name)
44    }
45}
46
47/// Schema for a table.
48#[derive(Debug, Clone)]
49pub struct TableSchema {
50    pub name: String,
51    pub columns: Vec<ColumnSchema>,
52}
53
54impl TableSchema {
55    pub fn new(name: &str) -> Self {
56        Self {
57            name: name.to_string(),
58            columns: Vec::new(),
59        }
60    }
61
62    pub fn add_column(&mut self, column: ColumnSchema) {
63        self.columns.push(column);
64    }
65
66    pub fn get_column(&self, name: &str) -> Option<&ColumnSchema> {
67        self.columns.iter().find(|c| c.name == name)
68    }
69
70    pub fn column_exists(&self, name: &str) -> bool {
71        self.columns.iter().any(|c| c.name == name)
72    }
73}
74
75/// Schema for a column.
76#[derive(Debug, Clone)]
77pub struct ColumnSchema {
78    pub name: String,
79    pub data_type: DataType,
80    pub nullable: bool,
81}
82
83// =============================================================================
84// Analysis Context
85// =============================================================================
86
87/// Context for semantic analysis.
88#[derive(Debug)]
89#[allow(dead_code)]
90struct AnalysisContext<'a> {
91    catalog: &'a Catalog,
92    scope: Scope,
93}
94
95/// Scope for name resolution.
96#[derive(Debug, Default)]
97struct Scope {
98    tables: HashMap<String, String>,
99    columns: HashMap<String, ResolvedColumn>,
100}
101
102/// A resolved column reference.
103#[derive(Debug, Clone)]
104struct ResolvedColumn {
105    table: String,
106    column: String,
107    data_type: DataType,
108}
109
110// =============================================================================
111// Analyzer
112// =============================================================================
113
114/// Semantic analyzer for SQL statements.
115pub struct Analyzer {
116    catalog: Catalog,
117}
118
119impl Analyzer {
120    pub fn new(catalog: Catalog) -> Self {
121        Self { catalog }
122    }
123
124    /// Analyze a statement for semantic correctness.
125    pub fn analyze(&self, stmt: &Statement) -> Result<AnalyzedStatement> {
126        match stmt {
127            Statement::Select(select) => {
128                let analyzed = self.analyze_select(select)?;
129                Ok(AnalyzedStatement::Select(analyzed))
130            }
131            Statement::Insert(insert) => {
132                self.analyze_insert(insert)?;
133                Ok(AnalyzedStatement::Insert(insert.clone()))
134            }
135            Statement::Update(update) => {
136                self.analyze_update(update)?;
137                Ok(AnalyzedStatement::Update(update.clone()))
138            }
139            Statement::Delete(delete) => {
140                self.analyze_delete(delete)?;
141                Ok(AnalyzedStatement::Delete(delete.clone()))
142            }
143            Statement::CreateTable(create) => {
144                self.analyze_create_table(create)?;
145                Ok(AnalyzedStatement::CreateTable(create.clone()))
146            }
147            Statement::DropTable(drop) => Ok(AnalyzedStatement::DropTable(drop.clone())),
148            Statement::AlterTable(alter) => Ok(AnalyzedStatement::AlterTable(alter.clone())),
149            Statement::CreateIndex(create) => {
150                self.analyze_create_index(create)?;
151                Ok(AnalyzedStatement::CreateIndex(create.clone()))
152            }
153            Statement::DropIndex(drop) => Ok(AnalyzedStatement::DropIndex(drop.clone())),
154            Statement::SetOperation(set_op) => {
155                // Analyze both sides of the set operation
156                let left = self.analyze(set_op.left.as_ref())?;
157                let right = self.analyze(set_op.right.as_ref())?;
158                Ok(AnalyzedStatement::SetOperation {
159                    op: set_op.op,
160                    left: Box::new(left),
161                    right: Box::new(right),
162                })
163            }
164            Statement::Begin => Ok(AnalyzedStatement::Begin),
165            Statement::Commit => Ok(AnalyzedStatement::Commit),
166            Statement::Rollback => Ok(AnalyzedStatement::Rollback),
167        }
168    }
169
170    fn analyze_select(&self, select: &SelectStatement) -> Result<AnalyzedSelect> {
171        let mut scope = Scope::default();
172
173        if let Some(ref from) = select.from {
174            self.build_scope_from_clause(from, &mut scope)?;
175        }
176
177        let mut output_columns = Vec::new();
178        for col in &select.columns {
179            match col {
180                SelectColumn::AllColumns => {
181                    for resolved in scope.columns.values() {
182                        output_columns.push(OutputColumn {
183                            name: resolved.column.clone(),
184                            data_type: resolved.data_type.clone(),
185                        });
186                    }
187                }
188                SelectColumn::TableAllColumns(table) => {
189                    for resolved in scope.columns.values() {
190                        if resolved.table == *table {
191                            output_columns.push(OutputColumn {
192                                name: resolved.column.clone(),
193                                data_type: resolved.data_type.clone(),
194                            });
195                        }
196                    }
197                }
198                SelectColumn::Expression { expr, alias } => {
199                    let data_type = self.infer_type(expr, &scope)?;
200                    let name = alias.clone().unwrap_or_else(|| self.expr_name(expr));
201                    output_columns.push(OutputColumn { name, data_type });
202                }
203            }
204        }
205
206        if let Some(ref where_clause) = select.where_clause {
207            self.validate_expression(where_clause, &scope)?;
208        }
209
210        for expr in &select.group_by {
211            self.validate_expression(expr, &scope)?;
212        }
213
214        if let Some(ref having) = select.having {
215            self.validate_expression(having, &scope)?;
216        }
217
218        for order_by in &select.order_by {
219            self.validate_expression(&order_by.expression, &scope)?;
220        }
221
222        Ok(AnalyzedSelect {
223            statement: select.clone(),
224            output_columns,
225        })
226    }
227
228    fn analyze_insert(&self, insert: &InsertStatement) -> Result<()> {
229        let table = self
230            .catalog
231            .get_table(&insert.table)
232            .ok_or_else(|| AegisError::TableNotFound(insert.table.clone()))?;
233
234        if let Some(ref columns) = insert.columns {
235            for col_name in columns {
236                if !table.column_exists(col_name) {
237                    return Err(AegisError::ColumnNotFound(col_name.clone()));
238                }
239            }
240        }
241
242        Ok(())
243    }
244
245    fn analyze_update(&self, update: &UpdateStatement) -> Result<()> {
246        let table = self
247            .catalog
248            .get_table(&update.table)
249            .ok_or_else(|| AegisError::TableNotFound(update.table.clone()))?;
250
251        for assignment in &update.assignments {
252            if !table.column_exists(&assignment.column) {
253                return Err(AegisError::ColumnNotFound(assignment.column.clone()));
254            }
255        }
256
257        Ok(())
258    }
259
260    fn analyze_delete(&self, delete: &DeleteStatement) -> Result<()> {
261        if !self.catalog.table_exists(&delete.table) {
262            return Err(AegisError::TableNotFound(delete.table.clone()));
263        }
264        Ok(())
265    }
266
267    fn analyze_create_table(&self, create: &CreateTableStatement) -> Result<()> {
268        if self.catalog.table_exists(&create.name) && !create.if_not_exists {
269            return Err(AegisError::ConstraintViolation(format!(
270                "Table '{}' already exists",
271                create.name
272            )));
273        }
274        Ok(())
275    }
276
277    fn analyze_create_index(&self, create: &CreateIndexStatement) -> Result<()> {
278        let table = self
279            .catalog
280            .get_table(&create.table)
281            .ok_or_else(|| AegisError::TableNotFound(create.table.clone()))?;
282
283        for col_name in &create.columns {
284            if !table.column_exists(col_name) {
285                return Err(AegisError::ColumnNotFound(col_name.clone()));
286            }
287        }
288
289        Ok(())
290    }
291
292    fn build_scope_from_clause(&self, from: &FromClause, scope: &mut Scope) -> Result<()> {
293        self.add_table_to_scope(&from.source, scope)?;
294
295        for join in &from.joins {
296            self.add_table_to_scope(&join.table, scope)?;
297        }
298
299        Ok(())
300    }
301
302    fn add_table_to_scope(&self, table_ref: &TableReference, scope: &mut Scope) -> Result<()> {
303        match table_ref {
304            TableReference::Table { name, alias } => {
305                let table = self
306                    .catalog
307                    .get_table(name)
308                    .ok_or_else(|| AegisError::TableNotFound(name.clone()))?;
309
310                let alias_name = alias.as_ref().unwrap_or(name);
311                scope.tables.insert(alias_name.clone(), name.clone());
312
313                for col in &table.columns {
314                    let key = format!("{}.{}", alias_name, col.name);
315                    scope.columns.insert(
316                        key.clone(),
317                        ResolvedColumn {
318                            table: alias_name.clone(),
319                            column: col.name.clone(),
320                            data_type: col.data_type.clone(),
321                        },
322                    );
323
324                    if !scope.columns.contains_key(&col.name) {
325                        scope.columns.insert(
326                            col.name.clone(),
327                            ResolvedColumn {
328                                table: alias_name.clone(),
329                                column: col.name.clone(),
330                                data_type: col.data_type.clone(),
331                            },
332                        );
333                    }
334                }
335            }
336            TableReference::Subquery { query: _, alias } => {
337                scope.tables.insert(alias.clone(), alias.clone());
338            }
339        }
340
341        Ok(())
342    }
343
344    fn validate_expression(&self, expr: &Expression, scope: &Scope) -> Result<DataType> {
345        self.infer_type(expr, scope)
346    }
347
348    fn infer_type(&self, expr: &Expression, scope: &Scope) -> Result<DataType> {
349        match expr {
350            Expression::Literal(lit) => Ok(self.literal_type(lit)),
351            Expression::Column(col_ref) => {
352                let key = if let Some(ref table) = col_ref.table {
353                    format!("{}.{}", table, col_ref.column)
354                } else {
355                    col_ref.column.clone()
356                };
357
358                scope
359                    .columns
360                    .get(&key)
361                    .map(|r| r.data_type.clone())
362                    .ok_or(AegisError::ColumnNotFound(key))
363            }
364            Expression::BinaryOp { left, op, right } => {
365                let left_type = self.infer_type(left, scope)?;
366                let right_type = self.infer_type(right, scope)?;
367                self.binary_op_type(&left_type, op, &right_type)
368            }
369            Expression::UnaryOp { op, expr } => {
370                let expr_type = self.infer_type(expr, scope)?;
371                self.unary_op_type(op, &expr_type)
372            }
373            Expression::Function { name, args, .. } => {
374                for arg in args {
375                    self.infer_type(arg, scope)?;
376                }
377                self.function_return_type(name, args, scope)
378            }
379            Expression::IsNull { .. } => Ok(DataType::Boolean),
380            Expression::InList { .. } => Ok(DataType::Boolean),
381            Expression::Between { .. } => Ok(DataType::Boolean),
382            Expression::Like { .. } => Ok(DataType::Boolean),
383            Expression::Cast { data_type, .. } => Ok(data_type.clone()),
384            Expression::Case {
385                conditions,
386                else_result,
387                ..
388            } => {
389                if let Some((_, then_expr)) = conditions.first() {
390                    self.infer_type(then_expr, scope)
391                } else if let Some(else_expr) = else_result {
392                    self.infer_type(else_expr, scope)
393                } else {
394                    Ok(DataType::Text)
395                }
396            }
397            Expression::Subquery(_) => Ok(DataType::Text),
398            Expression::InSubquery { .. } => Ok(DataType::Boolean),
399            Expression::Exists { .. } => Ok(DataType::Boolean),
400            Expression::Placeholder(_) => Ok(DataType::Text),
401        }
402    }
403
404    fn literal_type(&self, lit: &Literal) -> DataType {
405        match lit {
406            Literal::Null => DataType::Text,
407            Literal::Boolean(_) => DataType::Boolean,
408            Literal::Integer(_) => DataType::BigInt,
409            Literal::Float(_) => DataType::Double,
410            Literal::String(_) => DataType::Text,
411        }
412    }
413
414    fn binary_op_type(
415        &self,
416        left: &DataType,
417        op: &BinaryOperator,
418        _right: &DataType,
419    ) -> Result<DataType> {
420        match op {
421            BinaryOperator::Equal
422            | BinaryOperator::NotEqual
423            | BinaryOperator::LessThan
424            | BinaryOperator::LessThanOrEqual
425            | BinaryOperator::GreaterThan
426            | BinaryOperator::GreaterThanOrEqual
427            | BinaryOperator::And
428            | BinaryOperator::Or => Ok(DataType::Boolean),
429            BinaryOperator::Add
430            | BinaryOperator::Subtract
431            | BinaryOperator::Multiply
432            | BinaryOperator::Divide
433            | BinaryOperator::Modulo => Ok(left.clone()),
434            BinaryOperator::Concat => Ok(DataType::Text),
435        }
436    }
437
438    fn unary_op_type(&self, op: &UnaryOperator, expr_type: &DataType) -> Result<DataType> {
439        match op {
440            UnaryOperator::Not => Ok(DataType::Boolean),
441            UnaryOperator::Negative | UnaryOperator::Positive => Ok(expr_type.clone()),
442        }
443    }
444
445    fn function_return_type(
446        &self,
447        name: &str,
448        _args: &[Expression],
449        _scope: &Scope,
450    ) -> Result<DataType> {
451        let name_upper = name.to_uppercase();
452        match name_upper.as_str() {
453            "COUNT" => Ok(DataType::BigInt),
454            "SUM" | "AVG" => Ok(DataType::Double),
455            "MIN" | "MAX" => Ok(DataType::Double),
456            "COALESCE" | "NULLIF" => Ok(DataType::Text),
457            "NOW" | "CURRENT_TIMESTAMP" => Ok(DataType::Timestamp),
458            "CURRENT_DATE" => Ok(DataType::Date),
459            "UPPER" | "LOWER" | "TRIM" | "CONCAT" | "SUBSTRING" => Ok(DataType::Text),
460            "LENGTH" | "CHAR_LENGTH" => Ok(DataType::Integer),
461            "ABS" | "CEIL" | "FLOOR" | "ROUND" => Ok(DataType::Double),
462            _ => Ok(DataType::Text),
463        }
464    }
465
466    fn expr_name(&self, expr: &Expression) -> String {
467        match expr {
468            Expression::Column(col) => col.column.clone(),
469            Expression::Function { name, .. } => name.clone(),
470            Expression::Literal(lit) => match lit {
471                Literal::Integer(i) => i.to_string(),
472                Literal::Float(f) => f.to_string(),
473                Literal::String(s) => s.clone(),
474                Literal::Boolean(b) => b.to_string(),
475                Literal::Null => "NULL".to_string(),
476            },
477            _ => "?column?".to_string(),
478        }
479    }
480}
481
482// =============================================================================
483// Analyzed Output
484// =============================================================================
485
486/// Result of semantic analysis.
487#[derive(Debug, Clone)]
488pub enum AnalyzedStatement {
489    Select(AnalyzedSelect),
490    Insert(InsertStatement),
491    Update(UpdateStatement),
492    Delete(DeleteStatement),
493    CreateTable(CreateTableStatement),
494    DropTable(DropTableStatement),
495    AlterTable(AlterTableStatement),
496    CreateIndex(CreateIndexStatement),
497    DropIndex(DropIndexStatement),
498    SetOperation {
499        op: SetOperationType,
500        left: Box<AnalyzedStatement>,
501        right: Box<AnalyzedStatement>,
502    },
503    Begin,
504    Commit,
505    Rollback,
506}
507
508/// Analyzed SELECT statement with resolved types.
509#[derive(Debug, Clone)]
510pub struct AnalyzedSelect {
511    pub statement: SelectStatement,
512    pub output_columns: Vec<OutputColumn>,
513}
514
515/// Output column with resolved type.
516#[derive(Debug, Clone)]
517pub struct OutputColumn {
518    pub name: String,
519    pub data_type: DataType,
520}
521
522// =============================================================================
523// Tests
524// =============================================================================
525
#[cfg(test)]
mod tests {
    use super::*;
    use crate::parser::Parser;

    /// Shorthand for building a column definition.
    fn column(name: &str, data_type: DataType, nullable: bool) -> ColumnSchema {
        ColumnSchema {
            name: name.to_string(),
            data_type,
            nullable,
        }
    }

    /// Catalog holding a single `users(id, name, age)` table.
    fn create_test_catalog() -> Catalog {
        let mut users = TableSchema::new("users");
        users.add_column(column("id", DataType::Integer, false));
        users.add_column(column("name", DataType::Varchar(255), true));
        users.add_column(column("age", DataType::Integer, true));

        let mut catalog = Catalog::new();
        catalog.add_table(users);
        catalog
    }

    /// Parses `sql` (which must parse successfully) and runs it through an
    /// analyzer backed by the test catalog.
    fn analyze_sql(sql: &str) -> Result<AnalyzedStatement> {
        let analyzer = Analyzer::new(create_test_catalog());
        let parser = Parser::new();
        let stmt = parser.parse_single(sql).unwrap();
        analyzer.analyze(&stmt)
    }

    #[test]
    fn test_analyze_select() {
        let analyzed = analyze_sql("SELECT id, name FROM users").unwrap();

        let select = match analyzed {
            AnalyzedStatement::Select(select) => select,
            _ => panic!("Expected analyzed SELECT"),
        };

        assert_eq!(select.output_columns.len(), 2);
        assert_eq!(select.output_columns[0].name, "id");
        assert_eq!(select.output_columns[0].data_type, DataType::Integer);
    }

    #[test]
    fn test_analyze_table_not_found() {
        let result = analyze_sql("SELECT * FROM nonexistent");

        assert!(matches!(result, Err(AegisError::TableNotFound(_))));
    }

    #[test]
    fn test_analyze_column_not_found() {
        let result = analyze_sql("SELECT nonexistent FROM users");

        assert!(matches!(result, Err(AegisError::ColumnNotFound(_))));
    }
}