1use crate::ast::*;
16use aegis_common::{AegisError, DataType, Result};
17use std::collections::HashMap;
18
/// In-memory registry of the table schemas that analysis resolves names against.
#[derive(Debug, Clone, Default)]
pub struct Catalog {
    /// Registered schemas, keyed by table name.
    tables: HashMap<String, TableSchema>,
}
28
29impl Catalog {
30 pub fn new() -> Self {
31 Self::default()
32 }
33
34 pub fn add_table(&mut self, schema: TableSchema) {
35 self.tables.insert(schema.name.clone(), schema);
36 }
37
38 pub fn get_table(&self, name: &str) -> Option<&TableSchema> {
39 self.tables.get(name)
40 }
41
42 pub fn table_exists(&self, name: &str) -> bool {
43 self.tables.contains_key(name)
44 }
45}
46
/// Describes one table: its name and its columns in declaration order.
#[derive(Debug, Clone)]
pub struct TableSchema {
    /// Table name; also the key under which a `Catalog` stores the schema.
    pub name: String,
    /// Column definitions, in the order they were added.
    pub columns: Vec<ColumnSchema>,
}
53
54impl TableSchema {
55 pub fn new(name: &str) -> Self {
56 Self {
57 name: name.to_string(),
58 columns: Vec::new(),
59 }
60 }
61
62 pub fn add_column(&mut self, column: ColumnSchema) {
63 self.columns.push(column);
64 }
65
66 pub fn get_column(&self, name: &str) -> Option<&ColumnSchema> {
67 self.columns.iter().find(|c| c.name == name)
68 }
69
70 pub fn column_exists(&self, name: &str) -> bool {
71 self.columns.iter().any(|c| c.name == name)
72 }
73}
74
/// Describes a single column of a table.
#[derive(Debug, Clone)]
pub struct ColumnSchema {
    /// Column name, compared exactly by the `TableSchema` lookups.
    pub name: String,
    /// Declared type of the column.
    pub data_type: DataType,
    /// Nullability flag. NOTE(review): not consulted by any analysis visible in this file.
    pub nullable: bool,
}
82
/// Pairs a borrowed catalog with a name-resolution scope.
///
/// Nothing in the visible analyzer code constructs or reads this type, which is
/// why the `dead_code` lint is silenced on it.
#[derive(Debug)]
#[allow(dead_code)]
struct AnalysisContext<'a> {
    /// Catalog the context resolves tables against.
    catalog: &'a Catalog,
    /// Tables and columns visible to the statement being analyzed.
    scope: Scope,
}
94
/// Names visible while analyzing one SELECT statement.
#[derive(Debug, Default)]
struct Scope {
    /// Maps the name a table is visible under (its alias, or its own name when
    /// unaliased) to the underlying table name. Subquery aliases map to themselves.
    tables: HashMap<String, String>,
    /// Resolvable columns, stored under both the qualified key `"<table>.<column>"`
    /// and the bare column name. For a bare name shared by several tables, the
    /// first table added to the scope wins.
    columns: HashMap<String, ResolvedColumn>,
}
101
/// A column reference resolved against the scope.
#[derive(Debug, Clone)]
struct ResolvedColumn {
    /// Name the owning table is visible under (alias, or table name when unaliased).
    table: String,
    /// Column name as declared in the table schema.
    column: String,
    /// Declared type of the column.
    data_type: DataType,
}
109
/// Semantic analyzer that checks parsed statements against a catalog.
pub struct Analyzer {
    /// Catalog owned by the analyzer; all table and column lookups go through it.
    catalog: Catalog,
}
118
119impl Analyzer {
120 pub fn new(catalog: Catalog) -> Self {
121 Self { catalog }
122 }
123
124 pub fn analyze(&self, stmt: &Statement) -> Result<AnalyzedStatement> {
126 match stmt {
127 Statement::Select(select) => {
128 let analyzed = self.analyze_select(select)?;
129 Ok(AnalyzedStatement::Select(analyzed))
130 }
131 Statement::Insert(insert) => {
132 self.analyze_insert(insert)?;
133 Ok(AnalyzedStatement::Insert(insert.clone()))
134 }
135 Statement::Update(update) => {
136 self.analyze_update(update)?;
137 Ok(AnalyzedStatement::Update(update.clone()))
138 }
139 Statement::Delete(delete) => {
140 self.analyze_delete(delete)?;
141 Ok(AnalyzedStatement::Delete(delete.clone()))
142 }
143 Statement::CreateTable(create) => {
144 self.analyze_create_table(create)?;
145 Ok(AnalyzedStatement::CreateTable(create.clone()))
146 }
147 Statement::DropTable(drop) => {
148 Ok(AnalyzedStatement::DropTable(drop.clone()))
149 }
150 Statement::AlterTable(alter) => {
151 Ok(AnalyzedStatement::AlterTable(alter.clone()))
152 }
153 Statement::CreateIndex(create) => {
154 self.analyze_create_index(create)?;
155 Ok(AnalyzedStatement::CreateIndex(create.clone()))
156 }
157 Statement::DropIndex(drop) => {
158 Ok(AnalyzedStatement::DropIndex(drop.clone()))
159 }
160 Statement::Begin => Ok(AnalyzedStatement::Begin),
161 Statement::Commit => Ok(AnalyzedStatement::Commit),
162 Statement::Rollback => Ok(AnalyzedStatement::Rollback),
163 }
164 }
165
166 fn analyze_select(&self, select: &SelectStatement) -> Result<AnalyzedSelect> {
167 let mut scope = Scope::default();
168
169 if let Some(ref from) = select.from {
170 self.build_scope_from_clause(from, &mut scope)?;
171 }
172
173 let mut output_columns = Vec::new();
174 for col in &select.columns {
175 match col {
176 SelectColumn::AllColumns => {
177 for resolved in scope.columns.values() {
178 output_columns.push(OutputColumn {
179 name: resolved.column.clone(),
180 data_type: resolved.data_type.clone(),
181 });
182 }
183 }
184 SelectColumn::TableAllColumns(table) => {
185 for resolved in scope.columns.values() {
186 if resolved.table == *table {
187 output_columns.push(OutputColumn {
188 name: resolved.column.clone(),
189 data_type: resolved.data_type.clone(),
190 });
191 }
192 }
193 }
194 SelectColumn::Expression { expr, alias } => {
195 let data_type = self.infer_type(expr, &scope)?;
196 let name = alias.clone().unwrap_or_else(|| self.expr_name(expr));
197 output_columns.push(OutputColumn { name, data_type });
198 }
199 }
200 }
201
202 if let Some(ref where_clause) = select.where_clause {
203 self.validate_expression(where_clause, &scope)?;
204 }
205
206 for expr in &select.group_by {
207 self.validate_expression(expr, &scope)?;
208 }
209
210 if let Some(ref having) = select.having {
211 self.validate_expression(having, &scope)?;
212 }
213
214 for order_by in &select.order_by {
215 self.validate_expression(&order_by.expression, &scope)?;
216 }
217
218 Ok(AnalyzedSelect {
219 statement: select.clone(),
220 output_columns,
221 })
222 }
223
224 fn analyze_insert(&self, insert: &InsertStatement) -> Result<()> {
225 let table = self.catalog.get_table(&insert.table).ok_or_else(|| {
226 AegisError::TableNotFound(insert.table.clone())
227 })?;
228
229 if let Some(ref columns) = insert.columns {
230 for col_name in columns {
231 if !table.column_exists(col_name) {
232 return Err(AegisError::ColumnNotFound(col_name.clone()));
233 }
234 }
235 }
236
237 Ok(())
238 }
239
240 fn analyze_update(&self, update: &UpdateStatement) -> Result<()> {
241 let table = self.catalog.get_table(&update.table).ok_or_else(|| {
242 AegisError::TableNotFound(update.table.clone())
243 })?;
244
245 for assignment in &update.assignments {
246 if !table.column_exists(&assignment.column) {
247 return Err(AegisError::ColumnNotFound(assignment.column.clone()));
248 }
249 }
250
251 Ok(())
252 }
253
254 fn analyze_delete(&self, delete: &DeleteStatement) -> Result<()> {
255 if !self.catalog.table_exists(&delete.table) {
256 return Err(AegisError::TableNotFound(delete.table.clone()));
257 }
258 Ok(())
259 }
260
261 fn analyze_create_table(&self, create: &CreateTableStatement) -> Result<()> {
262 if self.catalog.table_exists(&create.name) && !create.if_not_exists {
263 return Err(AegisError::ConstraintViolation(format!(
264 "Table '{}' already exists",
265 create.name
266 )));
267 }
268 Ok(())
269 }
270
271 fn analyze_create_index(&self, create: &CreateIndexStatement) -> Result<()> {
272 let table = self.catalog.get_table(&create.table).ok_or_else(|| {
273 AegisError::TableNotFound(create.table.clone())
274 })?;
275
276 for col_name in &create.columns {
277 if !table.column_exists(col_name) {
278 return Err(AegisError::ColumnNotFound(col_name.clone()));
279 }
280 }
281
282 Ok(())
283 }
284
285 fn build_scope_from_clause(&self, from: &FromClause, scope: &mut Scope) -> Result<()> {
286 self.add_table_to_scope(&from.source, scope)?;
287
288 for join in &from.joins {
289 self.add_table_to_scope(&join.table, scope)?;
290 }
291
292 Ok(())
293 }
294
295 fn add_table_to_scope(&self, table_ref: &TableReference, scope: &mut Scope) -> Result<()> {
296 match table_ref {
297 TableReference::Table { name, alias } => {
298 let table = self.catalog.get_table(name).ok_or_else(|| {
299 AegisError::TableNotFound(name.clone())
300 })?;
301
302 let alias_name = alias.as_ref().unwrap_or(name);
303 scope.tables.insert(alias_name.clone(), name.clone());
304
305 for col in &table.columns {
306 let key = format!("{}.{}", alias_name, col.name);
307 scope.columns.insert(
308 key.clone(),
309 ResolvedColumn {
310 table: alias_name.clone(),
311 column: col.name.clone(),
312 data_type: col.data_type.clone(),
313 },
314 );
315
316 if !scope.columns.contains_key(&col.name) {
317 scope.columns.insert(
318 col.name.clone(),
319 ResolvedColumn {
320 table: alias_name.clone(),
321 column: col.name.clone(),
322 data_type: col.data_type.clone(),
323 },
324 );
325 }
326 }
327 }
328 TableReference::Subquery { query: _, alias } => {
329 scope.tables.insert(alias.clone(), alias.clone());
330 }
331 }
332
333 Ok(())
334 }
335
    /// Validates `expr` against `scope` by inferring its type.
    ///
    /// This is a direct delegation to `infer_type`: the expression is valid
    /// exactly when type inference succeeds, and the inferred type is returned.
    fn validate_expression(&self, expr: &Expression, scope: &Scope) -> Result<DataType> {
        self.infer_type(expr, scope)
    }
339
340 fn infer_type(&self, expr: &Expression, scope: &Scope) -> Result<DataType> {
341 match expr {
342 Expression::Literal(lit) => Ok(self.literal_type(lit)),
343 Expression::Column(col_ref) => {
344 let key = if let Some(ref table) = col_ref.table {
345 format!("{}.{}", table, col_ref.column)
346 } else {
347 col_ref.column.clone()
348 };
349
350 scope
351 .columns
352 .get(&key)
353 .map(|r| r.data_type.clone())
354 .ok_or(AegisError::ColumnNotFound(key))
355 }
356 Expression::BinaryOp { left, op, right } => {
357 let left_type = self.infer_type(left, scope)?;
358 let right_type = self.infer_type(right, scope)?;
359 self.binary_op_type(&left_type, op, &right_type)
360 }
361 Expression::UnaryOp { op, expr } => {
362 let expr_type = self.infer_type(expr, scope)?;
363 self.unary_op_type(op, &expr_type)
364 }
365 Expression::Function { name, args, .. } => {
366 for arg in args {
367 self.infer_type(arg, scope)?;
368 }
369 self.function_return_type(name, args, scope)
370 }
371 Expression::IsNull { .. } => Ok(DataType::Boolean),
372 Expression::InList { .. } => Ok(DataType::Boolean),
373 Expression::Between { .. } => Ok(DataType::Boolean),
374 Expression::Like { .. } => Ok(DataType::Boolean),
375 Expression::Cast { data_type, .. } => Ok(data_type.clone()),
376 Expression::Case { conditions, else_result, .. } => {
377 if let Some((_, then_expr)) = conditions.first() {
378 self.infer_type(then_expr, scope)
379 } else if let Some(else_expr) = else_result {
380 self.infer_type(else_expr, scope)
381 } else {
382 Ok(DataType::Text)
383 }
384 }
385 Expression::Subquery(_) => Ok(DataType::Text),
386 Expression::InSubquery { .. } => Ok(DataType::Boolean),
387 Expression::Exists { .. } => Ok(DataType::Boolean),
388 Expression::Placeholder(_) => Ok(DataType::Text),
389 }
390 }
391
392 fn literal_type(&self, lit: &Literal) -> DataType {
393 match lit {
394 Literal::Null => DataType::Text,
395 Literal::Boolean(_) => DataType::Boolean,
396 Literal::Integer(_) => DataType::BigInt,
397 Literal::Float(_) => DataType::Double,
398 Literal::String(_) => DataType::Text,
399 }
400 }
401
402 fn binary_op_type(
403 &self,
404 left: &DataType,
405 op: &BinaryOperator,
406 _right: &DataType,
407 ) -> Result<DataType> {
408 match op {
409 BinaryOperator::Equal
410 | BinaryOperator::NotEqual
411 | BinaryOperator::LessThan
412 | BinaryOperator::LessThanOrEqual
413 | BinaryOperator::GreaterThan
414 | BinaryOperator::GreaterThanOrEqual
415 | BinaryOperator::And
416 | BinaryOperator::Or => Ok(DataType::Boolean),
417 BinaryOperator::Add
418 | BinaryOperator::Subtract
419 | BinaryOperator::Multiply
420 | BinaryOperator::Divide
421 | BinaryOperator::Modulo => Ok(left.clone()),
422 BinaryOperator::Concat => Ok(DataType::Text),
423 }
424 }
425
426 fn unary_op_type(&self, op: &UnaryOperator, expr_type: &DataType) -> Result<DataType> {
427 match op {
428 UnaryOperator::Not => Ok(DataType::Boolean),
429 UnaryOperator::Negative | UnaryOperator::Positive => Ok(expr_type.clone()),
430 }
431 }
432
433 fn function_return_type(
434 &self,
435 name: &str,
436 _args: &[Expression],
437 _scope: &Scope,
438 ) -> Result<DataType> {
439 let name_upper = name.to_uppercase();
440 match name_upper.as_str() {
441 "COUNT" => Ok(DataType::BigInt),
442 "SUM" | "AVG" => Ok(DataType::Double),
443 "MIN" | "MAX" => Ok(DataType::Double),
444 "COALESCE" | "NULLIF" => Ok(DataType::Text),
445 "NOW" | "CURRENT_TIMESTAMP" => Ok(DataType::Timestamp),
446 "CURRENT_DATE" => Ok(DataType::Date),
447 "UPPER" | "LOWER" | "TRIM" | "CONCAT" | "SUBSTRING" => Ok(DataType::Text),
448 "LENGTH" | "CHAR_LENGTH" => Ok(DataType::Integer),
449 "ABS" | "CEIL" | "FLOOR" | "ROUND" => Ok(DataType::Double),
450 _ => Ok(DataType::Text),
451 }
452 }
453
454 fn expr_name(&self, expr: &Expression) -> String {
455 match expr {
456 Expression::Column(col) => col.column.clone(),
457 Expression::Function { name, .. } => name.clone(),
458 Expression::Literal(lit) => match lit {
459 Literal::Integer(i) => i.to_string(),
460 Literal::Float(f) => f.to_string(),
461 Literal::String(s) => s.clone(),
462 Literal::Boolean(b) => b.to_string(),
463 Literal::Null => "NULL".to_string(),
464 },
465 _ => "?column?".to_string(),
466 }
467 }
468}
469
/// Result of analyzing a statement; mirrors the variants of `Statement`.
#[derive(Debug, Clone)]
pub enum AnalyzedStatement {
    /// A SELECT together with its resolved output columns.
    Select(AnalyzedSelect),
    /// An INSERT whose table and explicit column list were checked.
    Insert(InsertStatement),
    /// An UPDATE whose table and assigned columns were checked.
    Update(UpdateStatement),
    /// A DELETE whose table was checked.
    Delete(DeleteStatement),
    /// A CREATE TABLE that does not conflict with an existing table.
    CreateTable(CreateTableStatement),
    /// A DROP TABLE, carried over from the parsed statement without checks.
    DropTable(DropTableStatement),
    /// An ALTER TABLE, carried over from the parsed statement without checks.
    AlterTable(AlterTableStatement),
    /// A CREATE INDEX whose table and columns were checked.
    CreateIndex(CreateIndexStatement),
    /// A DROP INDEX, carried over from the parsed statement without checks.
    DropIndex(DropIndexStatement),
    /// Transaction start.
    Begin,
    /// Transaction commit.
    Commit,
    /// Transaction rollback.
    Rollback,
}
490
/// A SELECT statement that passed analysis.
#[derive(Debug, Clone)]
pub struct AnalyzedSelect {
    /// Clone of the parsed statement that was analyzed.
    pub statement: SelectStatement,
    /// Name and type of each column the query produces.
    pub output_columns: Vec<OutputColumn>,
}
497
/// One column of a query's result set.
#[derive(Debug, Clone)]
pub struct OutputColumn {
    /// Column name: the alias when given, otherwise a name derived from the expression.
    pub name: String,
    /// Type inferred for the column.
    pub data_type: DataType,
}
504
#[cfg(test)]
mod tests {
    use super::*;
    use crate::parser::Parser;

    /// Builds a catalog holding a single `users` table with the columns
    /// `id` (Integer, not null), `name` (Varchar(255)) and `age` (Integer).
    fn create_test_catalog() -> Catalog {
        let column_specs = vec![
            ("id", DataType::Integer, false),
            ("name", DataType::Varchar(255), true),
            ("age", DataType::Integer, true),
        ];

        let mut users = TableSchema::new("users");
        for (name, data_type, nullable) in column_specs {
            users.add_column(ColumnSchema {
                name: name.to_string(),
                data_type,
                nullable,
            });
        }

        let mut catalog = Catalog::new();
        catalog.add_table(users);
        catalog
    }

    /// Parses `sql` as a single statement and analyzes it against the test catalog.
    fn analyze_sql(sql: &str) -> Result<AnalyzedStatement> {
        let analyzer = Analyzer::new(create_test_catalog());
        let parser = Parser::new();
        let stmt = parser.parse_single(sql).unwrap();
        analyzer.analyze(&stmt)
    }

    #[test]
    fn test_analyze_select() {
        let analyzed = analyze_sql("SELECT id, name FROM users").unwrap();

        let select = match analyzed {
            AnalyzedStatement::Select(select) => select,
            _ => panic!("Expected analyzed SELECT"),
        };

        assert_eq!(select.output_columns.len(), 2);
        let first = &select.output_columns[0];
        assert_eq!(first.name, "id");
        assert_eq!(first.data_type, DataType::Integer);
    }

    #[test]
    fn test_analyze_table_not_found() {
        let result = analyze_sql("SELECT * FROM nonexistent");

        assert!(matches!(result, Err(AegisError::TableNotFound(_))));
    }

    #[test]
    fn test_analyze_column_not_found() {
        let result = analyze_sql("SELECT nonexistent FROM users");

        assert!(matches!(result, Err(AegisError::ColumnNotFound(_))));
    }
}