1use crate::ast::*;
16use aegis_common::{AegisError, DataType, Result};
17use std::collections::HashMap;
18
19#[derive(Debug, Clone, Default)]
25pub struct Catalog {
26 tables: HashMap<String, TableSchema>,
27}
28
29impl Catalog {
30 pub fn new() -> Self {
31 Self::default()
32 }
33
34 pub fn add_table(&mut self, schema: TableSchema) {
35 self.tables.insert(schema.name.clone(), schema);
36 }
37
38 pub fn get_table(&self, name: &str) -> Option<&TableSchema> {
39 self.tables.get(name)
40 }
41
42 pub fn table_exists(&self, name: &str) -> bool {
43 self.tables.contains_key(name)
44 }
45}
46
47#[derive(Debug, Clone)]
49pub struct TableSchema {
50 pub name: String,
51 pub columns: Vec<ColumnSchema>,
52}
53
54impl TableSchema {
55 pub fn new(name: &str) -> Self {
56 Self {
57 name: name.to_string(),
58 columns: Vec::new(),
59 }
60 }
61
62 pub fn add_column(&mut self, column: ColumnSchema) {
63 self.columns.push(column);
64 }
65
66 pub fn get_column(&self, name: &str) -> Option<&ColumnSchema> {
67 self.columns.iter().find(|c| c.name == name)
68 }
69
70 pub fn column_exists(&self, name: &str) -> bool {
71 self.columns.iter().any(|c| c.name == name)
72 }
73}
74
/// Description of one column inside a [`TableSchema`].
#[derive(Debug, Clone)]
pub struct ColumnSchema {
    /// Column name; lookups in `TableSchema` compare it with `==`.
    pub name: String,
    /// Declared data type of the column.
    pub data_type: DataType,
    /// Nullability flag. It is stored here but not read by any code in this file.
    pub nullable: bool,
}
82
/// Bundles a borrowed catalog with a name-resolution scope.
// NOTE(review): nothing in the visible code constructs this type; the
// `allow(dead_code)` below presumably exists to silence that warning - confirm
// whether the type is still needed.
#[derive(Debug)]
#[allow(dead_code)]
struct AnalysisContext<'a> {
    /// Catalog the analysis resolves table names against.
    catalog: &'a Catalog,
    /// Tables and columns currently visible.
    scope: Scope,
}
94
/// Names visible while analyzing one `SELECT`.
#[derive(Debug, Default)]
struct Scope {
    /// Visible table name (the alias, or the table name when there is no
    /// alias) -> underlying table name. Subqueries map their alias to itself.
    tables: HashMap<String, String>,
    /// Column lookup table. Each column is stored under `"<visible table>.<column>"`,
    /// and also under the bare column name for the first table that provides it.
    columns: HashMap<String, ResolvedColumn>,
}
101
/// A column reference resolved against the tables in scope.
#[derive(Debug, Clone)]
struct ResolvedColumn {
    /// Visible name of the owning table (alias if one was given, else the table name).
    table: String,
    /// Column name as declared in the table schema.
    column: String,
    /// Data type copied from the table schema.
    data_type: DataType,
}
109
/// Semantic analyzer that checks parsed statements against a [`Catalog`].
pub struct Analyzer {
    /// Catalog owned by the analyzer; the methods in this file only read it.
    catalog: Catalog,
}
118
119impl Analyzer {
120 pub fn new(catalog: Catalog) -> Self {
121 Self { catalog }
122 }
123
124 pub fn analyze(&self, stmt: &Statement) -> Result<AnalyzedStatement> {
126 match stmt {
127 Statement::Select(select) => {
128 let analyzed = self.analyze_select(select)?;
129 Ok(AnalyzedStatement::Select(analyzed))
130 }
131 Statement::Insert(insert) => {
132 self.analyze_insert(insert)?;
133 Ok(AnalyzedStatement::Insert(insert.clone()))
134 }
135 Statement::Update(update) => {
136 self.analyze_update(update)?;
137 Ok(AnalyzedStatement::Update(update.clone()))
138 }
139 Statement::Delete(delete) => {
140 self.analyze_delete(delete)?;
141 Ok(AnalyzedStatement::Delete(delete.clone()))
142 }
143 Statement::CreateTable(create) => {
144 self.analyze_create_table(create)?;
145 Ok(AnalyzedStatement::CreateTable(create.clone()))
146 }
147 Statement::DropTable(drop) => Ok(AnalyzedStatement::DropTable(drop.clone())),
148 Statement::AlterTable(alter) => Ok(AnalyzedStatement::AlterTable(alter.clone())),
149 Statement::CreateIndex(create) => {
150 self.analyze_create_index(create)?;
151 Ok(AnalyzedStatement::CreateIndex(create.clone()))
152 }
153 Statement::DropIndex(drop) => Ok(AnalyzedStatement::DropIndex(drop.clone())),
154 Statement::SetOperation(set_op) => {
155 let left = self.analyze(set_op.left.as_ref())?;
157 let right = self.analyze(set_op.right.as_ref())?;
158 Ok(AnalyzedStatement::SetOperation {
159 op: set_op.op,
160 left: Box::new(left),
161 right: Box::new(right),
162 })
163 }
164 Statement::Begin => Ok(AnalyzedStatement::Begin),
165 Statement::Commit => Ok(AnalyzedStatement::Commit),
166 Statement::Rollback => Ok(AnalyzedStatement::Rollback),
167 }
168 }
169
170 fn analyze_select(&self, select: &SelectStatement) -> Result<AnalyzedSelect> {
171 let mut scope = Scope::default();
172
173 if let Some(ref from) = select.from {
174 self.build_scope_from_clause(from, &mut scope)?;
175 }
176
177 let mut output_columns = Vec::new();
178 for col in &select.columns {
179 match col {
180 SelectColumn::AllColumns => {
181 for resolved in scope.columns.values() {
182 output_columns.push(OutputColumn {
183 name: resolved.column.clone(),
184 data_type: resolved.data_type.clone(),
185 });
186 }
187 }
188 SelectColumn::TableAllColumns(table) => {
189 for resolved in scope.columns.values() {
190 if resolved.table == *table {
191 output_columns.push(OutputColumn {
192 name: resolved.column.clone(),
193 data_type: resolved.data_type.clone(),
194 });
195 }
196 }
197 }
198 SelectColumn::Expression { expr, alias } => {
199 let data_type = self.infer_type(expr, &scope)?;
200 let name = alias.clone().unwrap_or_else(|| self.expr_name(expr));
201 output_columns.push(OutputColumn { name, data_type });
202 }
203 }
204 }
205
206 if let Some(ref where_clause) = select.where_clause {
207 self.validate_expression(where_clause, &scope)?;
208 }
209
210 for expr in &select.group_by {
211 self.validate_expression(expr, &scope)?;
212 }
213
214 if let Some(ref having) = select.having {
215 self.validate_expression(having, &scope)?;
216 }
217
218 for order_by in &select.order_by {
219 self.validate_expression(&order_by.expression, &scope)?;
220 }
221
222 Ok(AnalyzedSelect {
223 statement: select.clone(),
224 output_columns,
225 })
226 }
227
228 fn analyze_insert(&self, insert: &InsertStatement) -> Result<()> {
229 let table = self
230 .catalog
231 .get_table(&insert.table)
232 .ok_or_else(|| AegisError::TableNotFound(insert.table.clone()))?;
233
234 if let Some(ref columns) = insert.columns {
235 for col_name in columns {
236 if !table.column_exists(col_name) {
237 return Err(AegisError::ColumnNotFound(col_name.clone()));
238 }
239 }
240 }
241
242 Ok(())
243 }
244
245 fn analyze_update(&self, update: &UpdateStatement) -> Result<()> {
246 let table = self
247 .catalog
248 .get_table(&update.table)
249 .ok_or_else(|| AegisError::TableNotFound(update.table.clone()))?;
250
251 for assignment in &update.assignments {
252 if !table.column_exists(&assignment.column) {
253 return Err(AegisError::ColumnNotFound(assignment.column.clone()));
254 }
255 }
256
257 Ok(())
258 }
259
260 fn analyze_delete(&self, delete: &DeleteStatement) -> Result<()> {
261 if !self.catalog.table_exists(&delete.table) {
262 return Err(AegisError::TableNotFound(delete.table.clone()));
263 }
264 Ok(())
265 }
266
267 fn analyze_create_table(&self, create: &CreateTableStatement) -> Result<()> {
268 if self.catalog.table_exists(&create.name) && !create.if_not_exists {
269 return Err(AegisError::ConstraintViolation(format!(
270 "Table '{}' already exists",
271 create.name
272 )));
273 }
274 Ok(())
275 }
276
277 fn analyze_create_index(&self, create: &CreateIndexStatement) -> Result<()> {
278 let table = self
279 .catalog
280 .get_table(&create.table)
281 .ok_or_else(|| AegisError::TableNotFound(create.table.clone()))?;
282
283 for col_name in &create.columns {
284 if !table.column_exists(col_name) {
285 return Err(AegisError::ColumnNotFound(col_name.clone()));
286 }
287 }
288
289 Ok(())
290 }
291
292 fn build_scope_from_clause(&self, from: &FromClause, scope: &mut Scope) -> Result<()> {
293 self.add_table_to_scope(&from.source, scope)?;
294
295 for join in &from.joins {
296 self.add_table_to_scope(&join.table, scope)?;
297 }
298
299 Ok(())
300 }
301
302 fn add_table_to_scope(&self, table_ref: &TableReference, scope: &mut Scope) -> Result<()> {
303 match table_ref {
304 TableReference::Table { name, alias } => {
305 let table = self
306 .catalog
307 .get_table(name)
308 .ok_or_else(|| AegisError::TableNotFound(name.clone()))?;
309
310 let alias_name = alias.as_ref().unwrap_or(name);
311 scope.tables.insert(alias_name.clone(), name.clone());
312
313 for col in &table.columns {
314 let key = format!("{}.{}", alias_name, col.name);
315 scope.columns.insert(
316 key.clone(),
317 ResolvedColumn {
318 table: alias_name.clone(),
319 column: col.name.clone(),
320 data_type: col.data_type.clone(),
321 },
322 );
323
324 if !scope.columns.contains_key(&col.name) {
325 scope.columns.insert(
326 col.name.clone(),
327 ResolvedColumn {
328 table: alias_name.clone(),
329 column: col.name.clone(),
330 data_type: col.data_type.clone(),
331 },
332 );
333 }
334 }
335 }
336 TableReference::Subquery { query: _, alias } => {
337 scope.tables.insert(alias.clone(), alias.clone());
338 }
339 }
340
341 Ok(())
342 }
343
344 fn validate_expression(&self, expr: &Expression, scope: &Scope) -> Result<DataType> {
345 self.infer_type(expr, scope)
346 }
347
348 fn infer_type(&self, expr: &Expression, scope: &Scope) -> Result<DataType> {
349 match expr {
350 Expression::Literal(lit) => Ok(self.literal_type(lit)),
351 Expression::Column(col_ref) => {
352 let key = if let Some(ref table) = col_ref.table {
353 format!("{}.{}", table, col_ref.column)
354 } else {
355 col_ref.column.clone()
356 };
357
358 scope
359 .columns
360 .get(&key)
361 .map(|r| r.data_type.clone())
362 .ok_or(AegisError::ColumnNotFound(key))
363 }
364 Expression::BinaryOp { left, op, right } => {
365 let left_type = self.infer_type(left, scope)?;
366 let right_type = self.infer_type(right, scope)?;
367 self.binary_op_type(&left_type, op, &right_type)
368 }
369 Expression::UnaryOp { op, expr } => {
370 let expr_type = self.infer_type(expr, scope)?;
371 self.unary_op_type(op, &expr_type)
372 }
373 Expression::Function { name, args, .. } => {
374 for arg in args {
375 self.infer_type(arg, scope)?;
376 }
377 self.function_return_type(name, args, scope)
378 }
379 Expression::IsNull { .. } => Ok(DataType::Boolean),
380 Expression::InList { .. } => Ok(DataType::Boolean),
381 Expression::Between { .. } => Ok(DataType::Boolean),
382 Expression::Like { .. } => Ok(DataType::Boolean),
383 Expression::Cast { data_type, .. } => Ok(data_type.clone()),
384 Expression::Case {
385 conditions,
386 else_result,
387 ..
388 } => {
389 if let Some((_, then_expr)) = conditions.first() {
390 self.infer_type(then_expr, scope)
391 } else if let Some(else_expr) = else_result {
392 self.infer_type(else_expr, scope)
393 } else {
394 Ok(DataType::Text)
395 }
396 }
397 Expression::Subquery(_) => Ok(DataType::Text),
398 Expression::InSubquery { .. } => Ok(DataType::Boolean),
399 Expression::Exists { .. } => Ok(DataType::Boolean),
400 Expression::Placeholder(_) => Ok(DataType::Text),
401 }
402 }
403
404 fn literal_type(&self, lit: &Literal) -> DataType {
405 match lit {
406 Literal::Null => DataType::Text,
407 Literal::Boolean(_) => DataType::Boolean,
408 Literal::Integer(_) => DataType::BigInt,
409 Literal::Float(_) => DataType::Double,
410 Literal::String(_) => DataType::Text,
411 }
412 }
413
414 fn binary_op_type(
415 &self,
416 left: &DataType,
417 op: &BinaryOperator,
418 _right: &DataType,
419 ) -> Result<DataType> {
420 match op {
421 BinaryOperator::Equal
422 | BinaryOperator::NotEqual
423 | BinaryOperator::LessThan
424 | BinaryOperator::LessThanOrEqual
425 | BinaryOperator::GreaterThan
426 | BinaryOperator::GreaterThanOrEqual
427 | BinaryOperator::And
428 | BinaryOperator::Or => Ok(DataType::Boolean),
429 BinaryOperator::Add
430 | BinaryOperator::Subtract
431 | BinaryOperator::Multiply
432 | BinaryOperator::Divide
433 | BinaryOperator::Modulo => Ok(left.clone()),
434 BinaryOperator::Concat => Ok(DataType::Text),
435 }
436 }
437
438 fn unary_op_type(&self, op: &UnaryOperator, expr_type: &DataType) -> Result<DataType> {
439 match op {
440 UnaryOperator::Not => Ok(DataType::Boolean),
441 UnaryOperator::Negative | UnaryOperator::Positive => Ok(expr_type.clone()),
442 }
443 }
444
445 fn function_return_type(
446 &self,
447 name: &str,
448 _args: &[Expression],
449 _scope: &Scope,
450 ) -> Result<DataType> {
451 let name_upper = name.to_uppercase();
452 match name_upper.as_str() {
453 "COUNT" => Ok(DataType::BigInt),
454 "SUM" | "AVG" => Ok(DataType::Double),
455 "MIN" | "MAX" => Ok(DataType::Double),
456 "COALESCE" | "NULLIF" => Ok(DataType::Text),
457 "NOW" | "CURRENT_TIMESTAMP" => Ok(DataType::Timestamp),
458 "CURRENT_DATE" => Ok(DataType::Date),
459 "UPPER" | "LOWER" | "TRIM" | "CONCAT" | "SUBSTRING" => Ok(DataType::Text),
460 "LENGTH" | "CHAR_LENGTH" => Ok(DataType::Integer),
461 "ABS" | "CEIL" | "FLOOR" | "ROUND" => Ok(DataType::Double),
462 _ => Ok(DataType::Text),
463 }
464 }
465
466 fn expr_name(&self, expr: &Expression) -> String {
467 match expr {
468 Expression::Column(col) => col.column.clone(),
469 Expression::Function { name, .. } => name.clone(),
470 Expression::Literal(lit) => match lit {
471 Literal::Integer(i) => i.to_string(),
472 Literal::Float(f) => f.to_string(),
473 Literal::String(s) => s.clone(),
474 Literal::Boolean(b) => b.to_string(),
475 Literal::Null => "NULL".to_string(),
476 },
477 _ => "?column?".to_string(),
478 }
479 }
480}
481
/// Result of [`Analyzer::analyze`]: one variant per statement kind.
///
/// Only `Select` carries extra analysis output. The other payload variants
/// hold a clone of the parsed statement.
#[derive(Debug, Clone)]
pub enum AnalyzedStatement {
    /// A `SELECT` together with its typed output columns.
    Select(AnalyzedSelect),
    /// An `INSERT` whose table and listed columns were found in the catalog.
    Insert(InsertStatement),
    /// An `UPDATE` whose table and assigned columns were found in the catalog.
    Update(UpdateStatement),
    /// A `DELETE` whose table was found in the catalog.
    Delete(DeleteStatement),
    /// A `CREATE TABLE` that passed the already-exists check.
    CreateTable(CreateTableStatement),
    /// A `DROP TABLE`, passed through without checks.
    DropTable(DropTableStatement),
    /// An `ALTER TABLE`, passed through without checks.
    AlterTable(AlterTableStatement),
    /// A `CREATE INDEX` whose table and columns were found in the catalog.
    CreateIndex(CreateIndexStatement),
    /// A `DROP INDEX`, passed through without checks.
    DropIndex(DropIndexStatement),
    /// A set operation over two independently analyzed statements.
    SetOperation {
        /// Operator copied from the parsed set operation.
        op: SetOperationType,
        /// Analyzed left operand.
        left: Box<AnalyzedStatement>,
        /// Analyzed right operand.
        right: Box<AnalyzedStatement>,
    },
    /// Transaction begin.
    Begin,
    /// Transaction commit.
    Commit,
    /// Transaction rollback.
    Rollback,
}
507
/// A `SELECT` statement that passed analysis.
#[derive(Debug, Clone)]
pub struct AnalyzedSelect {
    /// Clone of the parsed statement that was analyzed.
    pub statement: SelectStatement,
    /// Output columns (name and inferred type) produced from the select list.
    pub output_columns: Vec<OutputColumn>,
}
514
/// One column in the result of an analyzed `SELECT`.
#[derive(Debug, Clone)]
pub struct OutputColumn {
    /// Alias if the select item has one; otherwise a name derived from the
    /// expression or, for wildcards, the schema column name.
    pub name: String,
    /// Data type inferred for the column.
    pub data_type: DataType,
}
521
#[cfg(test)]
mod tests {
    use super::*;
    use crate::parser::Parser;

    /// Builds a catalog holding one table, `users(id, name, age)`.
    fn create_test_catalog() -> Catalog {
        let mut users = TableSchema::new("users");
        let columns = vec![
            ("id", DataType::Integer, false),
            ("name", DataType::Varchar(255), true),
            ("age", DataType::Integer, true),
        ];
        for (name, data_type, nullable) in columns {
            users.add_column(ColumnSchema {
                name: name.to_string(),
                data_type,
                nullable,
            });
        }

        let mut catalog = Catalog::new();
        catalog.add_table(users);
        catalog
    }

    /// Parses `sql` as a single statement and analyzes it against the test catalog.
    fn analyze_sql(sql: &str) -> Result<AnalyzedStatement> {
        let analyzer = Analyzer::new(create_test_catalog());
        let parser = Parser::new();
        let stmt = parser.parse_single(sql).unwrap();
        analyzer.analyze(&stmt)
    }

    #[test]
    fn test_analyze_select() {
        let select = match analyze_sql("SELECT id, name FROM users").unwrap() {
            AnalyzedStatement::Select(select) => select,
            _ => panic!("Expected analyzed SELECT"),
        };

        assert_eq!(select.output_columns.len(), 2);
        assert_eq!(select.output_columns[0].name, "id");
        assert_eq!(select.output_columns[0].data_type, DataType::Integer);
    }

    #[test]
    fn test_analyze_table_not_found() {
        let result = analyze_sql("SELECT * FROM nonexistent");

        assert!(matches!(result, Err(AegisError::TableNotFound(_))));
    }

    #[test]
    fn test_analyze_column_not_found() {
        let result = analyze_sql("SELECT nonexistent FROM users");

        assert!(matches!(result, Err(AegisError::ColumnNotFound(_))));
    }
}