Skip to main content

aegis_query/
analyzer.rs

1//! Aegis Analyzer - Semantic Analysis
2//!
3//! Performs semantic analysis on parsed SQL statements. Validates table and
4//! column references, resolves types, and checks for constraint violations.
5//!
6//! Key Features:
7//! - Table and column resolution
8//! - Type checking and inference
9//! - Constraint validation
10//! - Scope management for subqueries
11//!
12//! @version 0.1.0
13//! @author AutomataNexus Development Team
14
15use crate::ast::*;
16use aegis_common::{AegisError, DataType, Result};
17use std::collections::HashMap;
18
19// =============================================================================
20// Catalog
21// =============================================================================
22
23/// Schema information for analysis.
24#[derive(Debug, Clone, Default)]
25pub struct Catalog {
26    tables: HashMap<String, TableSchema>,
27}
28
29impl Catalog {
30    pub fn new() -> Self {
31        Self::default()
32    }
33
34    pub fn add_table(&mut self, schema: TableSchema) {
35        self.tables.insert(schema.name.clone(), schema);
36    }
37
38    pub fn get_table(&self, name: &str) -> Option<&TableSchema> {
39        self.tables.get(name)
40    }
41
42    pub fn table_exists(&self, name: &str) -> bool {
43        self.tables.contains_key(name)
44    }
45}
46
47/// Schema for a table.
48#[derive(Debug, Clone)]
49pub struct TableSchema {
50    pub name: String,
51    pub columns: Vec<ColumnSchema>,
52}
53
54impl TableSchema {
55    pub fn new(name: &str) -> Self {
56        Self {
57            name: name.to_string(),
58            columns: Vec::new(),
59        }
60    }
61
62    pub fn add_column(&mut self, column: ColumnSchema) {
63        self.columns.push(column);
64    }
65
66    pub fn get_column(&self, name: &str) -> Option<&ColumnSchema> {
67        self.columns.iter().find(|c| c.name == name)
68    }
69
70    pub fn column_exists(&self, name: &str) -> bool {
71        self.columns.iter().any(|c| c.name == name)
72    }
73}
74
/// Schema for a column.
///
/// Columns are identified by `name`; `TableSchema` lookups compare it with an
/// exact, case-sensitive string match.
#[derive(Debug, Clone)]
pub struct ColumnSchema {
    /// Column name as it must appear in column references.
    pub name: String,
    /// Declared type; becomes the inferred type of expressions that reference
    /// this column.
    pub data_type: DataType,
    /// Whether the column accepts NULL.
    // NOTE(review): nothing in the analyzer code visible here reads this flag.
    pub nullable: bool,
}
82
83// =============================================================================
84// Analysis Context
85// =============================================================================
86
/// Context for semantic analysis.
///
/// Bundles a borrowed catalog with the name-resolution scope built for a
/// statement.
// NOTE(review): this type is not constructed anywhere in the code visible
// here (the analyzer passes `Scope` around directly), which is presumably why
// `dead_code` is silenced — confirm whether it is still needed.
#[derive(Debug)]
#[allow(dead_code)]
struct AnalysisContext<'a> {
    /// Catalog the statement is analyzed against.
    catalog: &'a Catalog,
    /// Tables and columns visible to the statement.
    scope: Scope,
}
94
/// Scope for name resolution.
///
/// Populated from a FROM clause by `Analyzer::add_table_to_scope`.
#[derive(Debug, Default)]
struct Scope {
    /// Visible table name (the alias when one is given, otherwise the table
    /// name) -> underlying table name. Subqueries map their alias to itself.
    tables: HashMap<String, String>,
    /// Column lookup table. Every column is stored under its qualified key
    /// `"<visible table name>.<column>"`, and additionally under the bare
    /// column name when no earlier table has already claimed that bare name.
    columns: HashMap<String, ResolvedColumn>,
}
101
/// A resolved column reference.
#[derive(Debug, Clone)]
struct ResolvedColumn {
    /// Name the owning table is visible under in the query (alias if present,
    /// otherwise the table name).
    table: String,
    /// Column name as declared in the table schema.
    column: String,
    /// Declared type of the column, copied from the table schema.
    data_type: DataType,
}
109
110// =============================================================================
111// Analyzer
112// =============================================================================
113
/// Semantic analyzer for SQL statements.
///
/// Owns the catalog that table and column references are validated against.
pub struct Analyzer {
    /// Schema information used to resolve table and column names.
    catalog: Catalog,
}
118
119impl Analyzer {
120    pub fn new(catalog: Catalog) -> Self {
121        Self { catalog }
122    }
123
124    /// Analyze a statement for semantic correctness.
125    pub fn analyze(&self, stmt: &Statement) -> Result<AnalyzedStatement> {
126        match stmt {
127            Statement::Select(select) => {
128                let analyzed = self.analyze_select(select)?;
129                Ok(AnalyzedStatement::Select(analyzed))
130            }
131            Statement::Insert(insert) => {
132                self.analyze_insert(insert)?;
133                Ok(AnalyzedStatement::Insert(insert.clone()))
134            }
135            Statement::Update(update) => {
136                self.analyze_update(update)?;
137                Ok(AnalyzedStatement::Update(update.clone()))
138            }
139            Statement::Delete(delete) => {
140                self.analyze_delete(delete)?;
141                Ok(AnalyzedStatement::Delete(delete.clone()))
142            }
143            Statement::CreateTable(create) => {
144                self.analyze_create_table(create)?;
145                Ok(AnalyzedStatement::CreateTable(create.clone()))
146            }
147            Statement::DropTable(drop) => {
148                Ok(AnalyzedStatement::DropTable(drop.clone()))
149            }
150            Statement::AlterTable(alter) => {
151                Ok(AnalyzedStatement::AlterTable(alter.clone()))
152            }
153            Statement::CreateIndex(create) => {
154                self.analyze_create_index(create)?;
155                Ok(AnalyzedStatement::CreateIndex(create.clone()))
156            }
157            Statement::DropIndex(drop) => {
158                Ok(AnalyzedStatement::DropIndex(drop.clone()))
159            }
160            Statement::Begin => Ok(AnalyzedStatement::Begin),
161            Statement::Commit => Ok(AnalyzedStatement::Commit),
162            Statement::Rollback => Ok(AnalyzedStatement::Rollback),
163        }
164    }
165
166    fn analyze_select(&self, select: &SelectStatement) -> Result<AnalyzedSelect> {
167        let mut scope = Scope::default();
168
169        if let Some(ref from) = select.from {
170            self.build_scope_from_clause(from, &mut scope)?;
171        }
172
173        let mut output_columns = Vec::new();
174        for col in &select.columns {
175            match col {
176                SelectColumn::AllColumns => {
177                    for resolved in scope.columns.values() {
178                        output_columns.push(OutputColumn {
179                            name: resolved.column.clone(),
180                            data_type: resolved.data_type.clone(),
181                        });
182                    }
183                }
184                SelectColumn::TableAllColumns(table) => {
185                    for resolved in scope.columns.values() {
186                        if resolved.table == *table {
187                            output_columns.push(OutputColumn {
188                                name: resolved.column.clone(),
189                                data_type: resolved.data_type.clone(),
190                            });
191                        }
192                    }
193                }
194                SelectColumn::Expression { expr, alias } => {
195                    let data_type = self.infer_type(expr, &scope)?;
196                    let name = alias.clone().unwrap_or_else(|| self.expr_name(expr));
197                    output_columns.push(OutputColumn { name, data_type });
198                }
199            }
200        }
201
202        if let Some(ref where_clause) = select.where_clause {
203            self.validate_expression(where_clause, &scope)?;
204        }
205
206        for expr in &select.group_by {
207            self.validate_expression(expr, &scope)?;
208        }
209
210        if let Some(ref having) = select.having {
211            self.validate_expression(having, &scope)?;
212        }
213
214        for order_by in &select.order_by {
215            self.validate_expression(&order_by.expression, &scope)?;
216        }
217
218        Ok(AnalyzedSelect {
219            statement: select.clone(),
220            output_columns,
221        })
222    }
223
224    fn analyze_insert(&self, insert: &InsertStatement) -> Result<()> {
225        let table = self.catalog.get_table(&insert.table).ok_or_else(|| {
226            AegisError::TableNotFound(insert.table.clone())
227        })?;
228
229        if let Some(ref columns) = insert.columns {
230            for col_name in columns {
231                if !table.column_exists(col_name) {
232                    return Err(AegisError::ColumnNotFound(col_name.clone()));
233                }
234            }
235        }
236
237        Ok(())
238    }
239
240    fn analyze_update(&self, update: &UpdateStatement) -> Result<()> {
241        let table = self.catalog.get_table(&update.table).ok_or_else(|| {
242            AegisError::TableNotFound(update.table.clone())
243        })?;
244
245        for assignment in &update.assignments {
246            if !table.column_exists(&assignment.column) {
247                return Err(AegisError::ColumnNotFound(assignment.column.clone()));
248            }
249        }
250
251        Ok(())
252    }
253
254    fn analyze_delete(&self, delete: &DeleteStatement) -> Result<()> {
255        if !self.catalog.table_exists(&delete.table) {
256            return Err(AegisError::TableNotFound(delete.table.clone()));
257        }
258        Ok(())
259    }
260
261    fn analyze_create_table(&self, create: &CreateTableStatement) -> Result<()> {
262        if self.catalog.table_exists(&create.name) && !create.if_not_exists {
263            return Err(AegisError::ConstraintViolation(format!(
264                "Table '{}' already exists",
265                create.name
266            )));
267        }
268        Ok(())
269    }
270
271    fn analyze_create_index(&self, create: &CreateIndexStatement) -> Result<()> {
272        let table = self.catalog.get_table(&create.table).ok_or_else(|| {
273            AegisError::TableNotFound(create.table.clone())
274        })?;
275
276        for col_name in &create.columns {
277            if !table.column_exists(col_name) {
278                return Err(AegisError::ColumnNotFound(col_name.clone()));
279            }
280        }
281
282        Ok(())
283    }
284
285    fn build_scope_from_clause(&self, from: &FromClause, scope: &mut Scope) -> Result<()> {
286        self.add_table_to_scope(&from.source, scope)?;
287
288        for join in &from.joins {
289            self.add_table_to_scope(&join.table, scope)?;
290        }
291
292        Ok(())
293    }
294
295    fn add_table_to_scope(&self, table_ref: &TableReference, scope: &mut Scope) -> Result<()> {
296        match table_ref {
297            TableReference::Table { name, alias } => {
298                let table = self.catalog.get_table(name).ok_or_else(|| {
299                    AegisError::TableNotFound(name.clone())
300                })?;
301
302                let alias_name = alias.as_ref().unwrap_or(name);
303                scope.tables.insert(alias_name.clone(), name.clone());
304
305                for col in &table.columns {
306                    let key = format!("{}.{}", alias_name, col.name);
307                    scope.columns.insert(
308                        key.clone(),
309                        ResolvedColumn {
310                            table: alias_name.clone(),
311                            column: col.name.clone(),
312                            data_type: col.data_type.clone(),
313                        },
314                    );
315
316                    if !scope.columns.contains_key(&col.name) {
317                        scope.columns.insert(
318                            col.name.clone(),
319                            ResolvedColumn {
320                                table: alias_name.clone(),
321                                column: col.name.clone(),
322                                data_type: col.data_type.clone(),
323                            },
324                        );
325                    }
326                }
327            }
328            TableReference::Subquery { query: _, alias } => {
329                scope.tables.insert(alias.clone(), alias.clone());
330            }
331        }
332
333        Ok(())
334    }
335
336    fn validate_expression(&self, expr: &Expression, scope: &Scope) -> Result<DataType> {
337        self.infer_type(expr, scope)
338    }
339
340    fn infer_type(&self, expr: &Expression, scope: &Scope) -> Result<DataType> {
341        match expr {
342            Expression::Literal(lit) => Ok(self.literal_type(lit)),
343            Expression::Column(col_ref) => {
344                let key = if let Some(ref table) = col_ref.table {
345                    format!("{}.{}", table, col_ref.column)
346                } else {
347                    col_ref.column.clone()
348                };
349
350                scope
351                    .columns
352                    .get(&key)
353                    .map(|r| r.data_type.clone())
354                    .ok_or(AegisError::ColumnNotFound(key))
355            }
356            Expression::BinaryOp { left, op, right } => {
357                let left_type = self.infer_type(left, scope)?;
358                let right_type = self.infer_type(right, scope)?;
359                self.binary_op_type(&left_type, op, &right_type)
360            }
361            Expression::UnaryOp { op, expr } => {
362                let expr_type = self.infer_type(expr, scope)?;
363                self.unary_op_type(op, &expr_type)
364            }
365            Expression::Function { name, args, .. } => {
366                for arg in args {
367                    self.infer_type(arg, scope)?;
368                }
369                self.function_return_type(name, args, scope)
370            }
371            Expression::IsNull { .. } => Ok(DataType::Boolean),
372            Expression::InList { .. } => Ok(DataType::Boolean),
373            Expression::Between { .. } => Ok(DataType::Boolean),
374            Expression::Like { .. } => Ok(DataType::Boolean),
375            Expression::Cast { data_type, .. } => Ok(data_type.clone()),
376            Expression::Case { conditions, else_result, .. } => {
377                if let Some((_, then_expr)) = conditions.first() {
378                    self.infer_type(then_expr, scope)
379                } else if let Some(else_expr) = else_result {
380                    self.infer_type(else_expr, scope)
381                } else {
382                    Ok(DataType::Text)
383                }
384            }
385            Expression::Subquery(_) => Ok(DataType::Text),
386            Expression::InSubquery { .. } => Ok(DataType::Boolean),
387            Expression::Exists { .. } => Ok(DataType::Boolean),
388            Expression::Placeholder(_) => Ok(DataType::Text),
389        }
390    }
391
392    fn literal_type(&self, lit: &Literal) -> DataType {
393        match lit {
394            Literal::Null => DataType::Text,
395            Literal::Boolean(_) => DataType::Boolean,
396            Literal::Integer(_) => DataType::BigInt,
397            Literal::Float(_) => DataType::Double,
398            Literal::String(_) => DataType::Text,
399        }
400    }
401
402    fn binary_op_type(
403        &self,
404        left: &DataType,
405        op: &BinaryOperator,
406        _right: &DataType,
407    ) -> Result<DataType> {
408        match op {
409            BinaryOperator::Equal
410            | BinaryOperator::NotEqual
411            | BinaryOperator::LessThan
412            | BinaryOperator::LessThanOrEqual
413            | BinaryOperator::GreaterThan
414            | BinaryOperator::GreaterThanOrEqual
415            | BinaryOperator::And
416            | BinaryOperator::Or => Ok(DataType::Boolean),
417            BinaryOperator::Add
418            | BinaryOperator::Subtract
419            | BinaryOperator::Multiply
420            | BinaryOperator::Divide
421            | BinaryOperator::Modulo => Ok(left.clone()),
422            BinaryOperator::Concat => Ok(DataType::Text),
423        }
424    }
425
426    fn unary_op_type(&self, op: &UnaryOperator, expr_type: &DataType) -> Result<DataType> {
427        match op {
428            UnaryOperator::Not => Ok(DataType::Boolean),
429            UnaryOperator::Negative | UnaryOperator::Positive => Ok(expr_type.clone()),
430        }
431    }
432
433    fn function_return_type(
434        &self,
435        name: &str,
436        _args: &[Expression],
437        _scope: &Scope,
438    ) -> Result<DataType> {
439        let name_upper = name.to_uppercase();
440        match name_upper.as_str() {
441            "COUNT" => Ok(DataType::BigInt),
442            "SUM" | "AVG" => Ok(DataType::Double),
443            "MIN" | "MAX" => Ok(DataType::Double),
444            "COALESCE" | "NULLIF" => Ok(DataType::Text),
445            "NOW" | "CURRENT_TIMESTAMP" => Ok(DataType::Timestamp),
446            "CURRENT_DATE" => Ok(DataType::Date),
447            "UPPER" | "LOWER" | "TRIM" | "CONCAT" | "SUBSTRING" => Ok(DataType::Text),
448            "LENGTH" | "CHAR_LENGTH" => Ok(DataType::Integer),
449            "ABS" | "CEIL" | "FLOOR" | "ROUND" => Ok(DataType::Double),
450            _ => Ok(DataType::Text),
451        }
452    }
453
454    fn expr_name(&self, expr: &Expression) -> String {
455        match expr {
456            Expression::Column(col) => col.column.clone(),
457            Expression::Function { name, .. } => name.clone(),
458            Expression::Literal(lit) => match lit {
459                Literal::Integer(i) => i.to_string(),
460                Literal::Float(f) => f.to_string(),
461                Literal::String(s) => s.clone(),
462                Literal::Boolean(b) => b.to_string(),
463                Literal::Null => "NULL".to_string(),
464            },
465            _ => "?column?".to_string(),
466        }
467    }
468}
469
470// =============================================================================
471// Analyzed Output
472// =============================================================================
473
/// Result of semantic analysis.
///
/// Mirrors the statement kinds handled by `Analyzer::analyze`. Only SELECT
/// carries extra analysis output; every other variant holds a clone of the
/// parsed statement (or nothing, for the transaction statements).
#[derive(Debug, Clone)]
pub enum AnalyzedStatement {
    /// SELECT together with its resolved output columns.
    Select(AnalyzedSelect),
    /// INSERT whose target table and explicit column list were validated.
    Insert(InsertStatement),
    /// UPDATE whose target table and assigned columns were validated.
    Update(UpdateStatement),
    /// DELETE whose target table was validated.
    Delete(DeleteStatement),
    /// CREATE TABLE that passed the already-exists check.
    CreateTable(CreateTableStatement),
    /// DROP TABLE, passed through without validation.
    DropTable(DropTableStatement),
    /// ALTER TABLE, passed through without validation.
    AlterTable(AlterTableStatement),
    /// CREATE INDEX whose table and columns were validated.
    CreateIndex(CreateIndexStatement),
    /// DROP INDEX, passed through without validation.
    DropIndex(DropIndexStatement),
    /// BEGIN transaction.
    Begin,
    /// COMMIT transaction.
    Commit,
    /// ROLLBACK transaction.
    Rollback,
}
490
/// Analyzed SELECT statement with resolved types.
#[derive(Debug, Clone)]
pub struct AnalyzedSelect {
    /// Clone of the parsed SELECT statement that was analyzed.
    pub statement: SelectStatement,
    /// One entry per projected column, in the order the analyzer produced them
    /// while walking the select list.
    pub output_columns: Vec<OutputColumn>,
}
497
/// Output column with resolved type.
#[derive(Debug, Clone)]
pub struct OutputColumn {
    /// Result column name: the alias when one was given, otherwise a name
    /// derived from the expression (or the source column name for `*`).
    pub name: String,
    /// Type inferred for the column.
    pub data_type: DataType,
}
504
505// =============================================================================
506// Tests
507// =============================================================================
508
#[cfg(test)]
mod tests {
    use super::*;
    use crate::parser::Parser;

    /// Catalog with a single `users(id, name, age)` table.
    fn create_test_catalog() -> Catalog {
        let mut users = TableSchema::new("users");
        for (name, data_type, nullable) in [
            ("id", DataType::Integer, false),
            ("name", DataType::Varchar(255), true),
            ("age", DataType::Integer, true),
        ] {
            users.add_column(ColumnSchema {
                name: name.to_string(),
                data_type,
                nullable,
            });
        }

        let mut catalog = Catalog::new();
        catalog.add_table(users);
        catalog
    }

    /// Parse `sql` as a single statement and analyze it against the test catalog.
    fn analyze_sql(sql: &str) -> Result<AnalyzedStatement> {
        let analyzer = Analyzer::new(create_test_catalog());
        let parser = Parser::new();
        let stmt = parser.parse_single(sql).unwrap();
        analyzer.analyze(&stmt)
    }

    #[test]
    fn test_analyze_select() {
        let analyzed = analyze_sql("SELECT id, name FROM users").unwrap();

        if let AnalyzedStatement::Select(select) = analyzed {
            assert_eq!(select.output_columns.len(), 2);
            assert_eq!(select.output_columns[0].name, "id");
            assert_eq!(select.output_columns[0].data_type, DataType::Integer);
        } else {
            panic!("Expected analyzed SELECT");
        }
    }

    #[test]
    fn test_analyze_table_not_found() {
        let result = analyze_sql("SELECT * FROM nonexistent");

        assert!(matches!(result, Err(AegisError::TableNotFound(_))));
    }

    #[test]
    fn test_analyze_column_not_found() {
        let result = analyze_sql("SELECT nonexistent FROM users");

        assert!(matches!(result, Err(AegisError::ColumnNotFound(_))));
    }
}