Skip to main content

cqlite_core/cql/
visitor.rs

//! Visitor pattern implementations for CQL AST traversal.
//!
//! This module provides a default traversal visitor plus visitor-based utilities
//! for collecting identifiers, validating, and transforming CQL AST nodes.
5
6use super::ast::*;
7use super::traits::{CqlVisitor, ValidationContext, ValidationStrictness};
8use crate::error::{Error, Result};
9use crate::schema::{ClusteringColumn, Column, KeyColumn, TableSchema};
10use std::collections::HashMap;
11
12/// Render a `CqlDataType` as a CQL source string (e.g. `list<text>`, `map<uuid, int>`).
13pub(crate) fn cql_data_type_to_string(data_type: &CqlDataType) -> String {
14    match data_type {
15        CqlDataType::Boolean => "boolean".to_string(),
16        CqlDataType::TinyInt => "tinyint".to_string(),
17        CqlDataType::SmallInt => "smallint".to_string(),
18        CqlDataType::Int => "int".to_string(),
19        CqlDataType::BigInt => "bigint".to_string(),
20        CqlDataType::Varint => "varint".to_string(),
21        CqlDataType::Decimal => "decimal".to_string(),
22        CqlDataType::Float => "float".to_string(),
23        CqlDataType::Double => "double".to_string(),
24        CqlDataType::Text => "text".to_string(),
25        CqlDataType::Ascii => "ascii".to_string(),
26        CqlDataType::Varchar => "varchar".to_string(),
27        CqlDataType::Blob => "blob".to_string(),
28        CqlDataType::Timestamp => "timestamp".to_string(),
29        CqlDataType::Date => "date".to_string(),
30        CqlDataType::Time => "time".to_string(),
31        CqlDataType::Uuid => "uuid".to_string(),
32        CqlDataType::TimeUuid => "timeuuid".to_string(),
33        CqlDataType::Inet => "inet".to_string(),
34        CqlDataType::Duration => "duration".to_string(),
35        CqlDataType::Counter => "counter".to_string(),
36        CqlDataType::List(inner) => format!("list<{}>", cql_data_type_to_string(inner)),
37        CqlDataType::Set(inner) => format!("set<{}>", cql_data_type_to_string(inner)),
38        CqlDataType::Map(key, value) => format!(
39            "map<{}, {}>",
40            cql_data_type_to_string(key),
41            cql_data_type_to_string(value)
42        ),
43        CqlDataType::Tuple(types) => {
44            let type_strs: Vec<String> = types.iter().map(cql_data_type_to_string).collect();
45            format!("tuple<{}>", type_strs.join(", "))
46        }
47        CqlDataType::Udt(name) => name.as_str().to_string(),
48        CqlDataType::Frozen(inner) => format!("frozen<{}>", cql_data_type_to_string(inner)),
49        CqlDataType::Custom(name) => name.clone(),
50    }
51}
52
53/// Extract the inner `CqlIdentifier` from any `CqlIndexColumn` variant.
54fn index_column_identifier(column: &CqlIndexColumn) -> &CqlIdentifier {
55    match column {
56        CqlIndexColumn::Column(id)
57        | CqlIndexColumn::Keys(id)
58        | CqlIndexColumn::Values(id)
59        | CqlIndexColumn::Entries(id)
60        | CqlIndexColumn::Full(id) => id,
61    }
62}
63
/// A no-op [`CqlVisitor`] that walks a CQL AST and yields `T::default()` for every node.
///
/// Each `visit_*` method recurses into the node's children (expressions,
/// identifiers, data types and literals) and discards their results; none of its
/// methods produce an error themselves. `CREATE TYPE`, `DROP TYPE`, `USE`,
/// `TRUNCATE` and `BATCH` statements are not descended into.
///
/// Note: this is a concrete trait implementation, not a set of default trait
/// methods, so its recursive calls always dispatch to `DefaultVisitor`'s own
/// methods. A visitor that delegates to it (e.g. `DefaultVisitor.visit_select(..)`)
/// will not have its own `visit_*` methods invoked for child nodes and must
/// visit any children it cares about itself.
#[derive(Debug, Default)]
pub struct DefaultVisitor;
71
72impl<T: Default> CqlVisitor<T> for DefaultVisitor {
73    fn visit_statement(&mut self, statement: &CqlStatement) -> Result<T> {
74        match statement {
75            CqlStatement::Select(select) => self.visit_select(select),
76            CqlStatement::Insert(insert) => self.visit_insert(insert),
77            CqlStatement::Update(update) => self.visit_update(update),
78            CqlStatement::Delete(delete) => self.visit_delete(delete),
79            CqlStatement::CreateTable(create) => self.visit_create_table(create),
80            CqlStatement::DropTable(drop) => self.visit_drop_table(drop),
81            CqlStatement::CreateIndex(create) => self.visit_create_index(create),
82            CqlStatement::AlterTable(alter) => self.visit_alter_table(alter),
83            CqlStatement::CreateType(_) => Ok(T::default()),
84            CqlStatement::DropType(_) => Ok(T::default()),
85            CqlStatement::Use(_) => Ok(T::default()),
86            CqlStatement::Truncate(_) => Ok(T::default()),
87            CqlStatement::Batch(_) => Ok(T::default()),
88        }
89    }
90
91    fn visit_select(&mut self, select: &CqlSelect) -> Result<T> {
92        for item in &select.select_list {
93            match item {
94                CqlSelectItem::Expression { expression, .. } => {
95                    let _: T = self.visit_expression(expression)?;
96                }
97                CqlSelectItem::Function { args, .. } => {
98                    for arg in args {
99                        let _: T = self.visit_expression(arg)?;
100                    }
101                }
102                CqlSelectItem::Wildcard => {}
103            }
104        }
105
106        if let Some(where_clause) = &select.where_clause {
107            let _: T = self.visit_expression(where_clause)?;
108        }
109
110        Ok(T::default())
111    }
112
113    fn visit_insert(&mut self, insert: &CqlInsert) -> Result<T> {
114        for column in &insert.columns {
115            let _: T = self.visit_identifier(column)?;
116        }
117
118        if let CqlInsertValues::Values(expressions) = &insert.values {
119            for expr in expressions {
120                let _: T = self.visit_expression(expr)?;
121            }
122        }
123
124        if let Some(using) = &insert.using {
125            if let Some(ttl) = &using.ttl {
126                let _: T = self.visit_expression(ttl)?;
127            }
128            if let Some(timestamp) = &using.timestamp {
129                let _: T = self.visit_expression(timestamp)?;
130            }
131        }
132
133        Ok(T::default())
134    }
135
136    fn visit_update(&mut self, update: &CqlUpdate) -> Result<T> {
137        for assignment in &update.assignments {
138            let _: T = self.visit_identifier(&assignment.column)?;
139            let _: T = self.visit_expression(&assignment.value)?;
140
141            if let CqlAssignmentOperator::MapUpdate(key_expr) = &assignment.operator {
142                let _: T = self.visit_expression(key_expr)?;
143            }
144        }
145
146        let _: T = self.visit_expression(&update.where_clause)?;
147
148        if let Some(if_condition) = &update.if_condition {
149            let _: T = self.visit_expression(if_condition)?;
150        }
151
152        if let Some(using) = &update.using {
153            if let Some(ttl) = &using.ttl {
154                let _: T = self.visit_expression(ttl)?;
155            }
156            if let Some(timestamp) = &using.timestamp {
157                let _: T = self.visit_expression(timestamp)?;
158            }
159        }
160
161        Ok(T::default())
162    }
163
164    fn visit_delete(&mut self, delete: &CqlDelete) -> Result<T> {
165        for column in &delete.columns {
166            let _: T = self.visit_identifier(column)?;
167        }
168
169        let _: T = self.visit_expression(&delete.where_clause)?;
170
171        if let Some(if_condition) = &delete.if_condition {
172            let _: T = self.visit_expression(if_condition)?;
173        }
174
175        if let Some(using) = &delete.using {
176            if let Some(timestamp) = &using.timestamp {
177                let _: T = self.visit_expression(timestamp)?;
178            }
179        }
180
181        Ok(T::default())
182    }
183
184    fn visit_create_table(&mut self, create: &CqlCreateTable) -> Result<T> {
185        let _: T = self.visit_identifier(&create.table.name)?;
186        if let Some(keyspace) = &create.table.keyspace {
187            let _: T = self.visit_identifier(keyspace)?;
188        }
189
190        for column in &create.columns {
191            let _: T = self.visit_identifier(&column.name)?;
192            let _: T = self.visit_data_type(&column.data_type)?;
193        }
194
195        for pk_column in &create.primary_key.partition_key {
196            let _: T = self.visit_identifier(pk_column)?;
197        }
198        for ck_column in &create.primary_key.clustering_key {
199            let _: T = self.visit_identifier(ck_column)?;
200        }
201
202        Ok(T::default())
203    }
204
205    fn visit_drop_table(&mut self, drop: &CqlDropTable) -> Result<T> {
206        let _: T = self.visit_identifier(&drop.table.name)?;
207        if let Some(keyspace) = &drop.table.keyspace {
208            let _: T = self.visit_identifier(keyspace)?;
209        }
210
211        Ok(T::default())
212    }
213
214    fn visit_create_index(&mut self, create: &CqlCreateIndex) -> Result<T> {
215        if let Some(name) = &create.name {
216            let _: T = self.visit_identifier(name)?;
217        }
218
219        let _: T = self.visit_identifier(&create.table.name)?;
220        if let Some(keyspace) = &create.table.keyspace {
221            let _: T = self.visit_identifier(keyspace)?;
222        }
223
224        for column in &create.columns {
225            let _: T = self.visit_identifier(index_column_identifier(column))?;
226        }
227
228        Ok(T::default())
229    }
230
231    fn visit_alter_table(&mut self, alter: &CqlAlterTable) -> Result<T> {
232        let _: T = self.visit_identifier(&alter.table.name)?;
233        if let Some(keyspace) = &alter.table.keyspace {
234            let _: T = self.visit_identifier(keyspace)?;
235        }
236
237        match &alter.operation {
238            CqlAlterTableOp::AddColumn(column_def) => {
239                let _: T = self.visit_identifier(&column_def.name)?;
240                let _: T = self.visit_data_type(&column_def.data_type)?;
241            }
242            CqlAlterTableOp::DropColumn(column) => {
243                let _: T = self.visit_identifier(column)?;
244            }
245            CqlAlterTableOp::AlterColumn { column, new_type } => {
246                let _: T = self.visit_identifier(column)?;
247                let _: T = self.visit_data_type(new_type)?;
248            }
249            CqlAlterTableOp::RenameColumn { old_name, new_name } => {
250                let _: T = self.visit_identifier(old_name)?;
251                let _: T = self.visit_identifier(new_name)?;
252            }
253            CqlAlterTableOp::WithOptions(_) => {}
254        }
255
256        Ok(T::default())
257    }
258
259    fn visit_data_type(&mut self, data_type: &CqlDataType) -> Result<T> {
260        match data_type {
261            CqlDataType::List(inner) | CqlDataType::Set(inner) | CqlDataType::Frozen(inner) => {
262                let _: T = self.visit_data_type(inner)?;
263            }
264            CqlDataType::Map(key_type, value_type) => {
265                let _: T = self.visit_data_type(key_type)?;
266                let _: T = self.visit_data_type(value_type)?;
267            }
268            CqlDataType::Tuple(types) => {
269                for typ in types {
270                    let _: T = self.visit_data_type(typ)?;
271                }
272            }
273            CqlDataType::Udt(name) => {
274                let _: T = self.visit_identifier(name)?;
275            }
276            _ => {}
277        }
278
279        Ok(T::default())
280    }
281
282    fn visit_expression(&mut self, expression: &CqlExpression) -> Result<T> {
283        match expression {
284            CqlExpression::Literal(literal) => self.visit_literal(literal),
285            CqlExpression::Column(column) => self.visit_identifier(column),
286            CqlExpression::Parameter(_) | CqlExpression::NamedParameter(_) => Ok(T::default()),
287            CqlExpression::Binary { left, right, .. } => {
288                let _: T = self.visit_expression(left)?;
289                let _: T = self.visit_expression(right)?;
290                Ok(T::default())
291            }
292            CqlExpression::Unary { operand, .. } => {
293                let _: T = self.visit_expression(operand)?;
294                Ok(T::default())
295            }
296            CqlExpression::Function { name, args } => {
297                let _: T = self.visit_identifier(name)?;
298                for arg in args {
299                    let _: T = self.visit_expression(arg)?;
300                }
301                Ok(T::default())
302            }
303            CqlExpression::In { expression, values } => {
304                let _: T = self.visit_expression(expression)?;
305                for value in values {
306                    let _: T = self.visit_expression(value)?;
307                }
308                Ok(T::default())
309            }
310            CqlExpression::Contains { column, value } => {
311                let _: T = self.visit_identifier(column)?;
312                let _: T = self.visit_expression(value)?;
313                Ok(T::default())
314            }
315            CqlExpression::ContainsKey { column, key } => {
316                let _: T = self.visit_identifier(column)?;
317                let _: T = self.visit_expression(key)?;
318                Ok(T::default())
319            }
320            CqlExpression::CollectionAccess { collection, index } => {
321                let _: T = self.visit_expression(collection)?;
322                let _: T = self.visit_expression(index)?;
323                Ok(T::default())
324            }
325            CqlExpression::FieldAccess { object, field } => {
326                let _: T = self.visit_expression(object)?;
327                let _: T = self.visit_identifier(field)?;
328                Ok(T::default())
329            }
330            CqlExpression::Case {
331                when_clauses,
332                else_clause,
333            } => {
334                for when_clause in when_clauses {
335                    let _: T = self.visit_expression(&when_clause.condition)?;
336                    let _: T = self.visit_expression(&when_clause.result)?;
337                }
338                if let Some(else_expr) = else_clause {
339                    let _: T = self.visit_expression(else_expr)?;
340                }
341                Ok(T::default())
342            }
343            CqlExpression::Cast {
344                expression,
345                target_type,
346            } => {
347                let _: T = self.visit_expression(expression)?;
348                let _: T = self.visit_data_type(target_type)?;
349                Ok(T::default())
350            }
351        }
352    }
353
354    fn visit_identifier(&mut self, _identifier: &CqlIdentifier) -> Result<T> {
355        Ok(T::default())
356    }
357
358    fn visit_literal(&mut self, literal: &CqlLiteral) -> Result<T> {
359        match literal {
360            CqlLiteral::Collection(collection) => match collection {
361                CqlCollectionLiteral::List(items) | CqlCollectionLiteral::Set(items) => {
362                    for item in items {
363                        let _: T = self.visit_literal(item)?;
364                    }
365                }
366                CqlCollectionLiteral::Map(pairs) => {
367                    for (key, value) in pairs {
368                        let _: T = self.visit_literal(key)?;
369                        let _: T = self.visit_literal(value)?;
370                    }
371                }
372            },
373            CqlLiteral::Udt(udt) => {
374                for (field_name, field_value) in &udt.fields {
375                    let _: T = self.visit_identifier(field_name)?;
376                    let _: T = self.visit_literal(field_value)?;
377                }
378            }
379            CqlLiteral::Tuple(items) => {
380                for item in items {
381                    let _: T = self.visit_literal(item)?;
382                }
383            }
384            _ => {}
385        }
386
387        Ok(T::default())
388    }
389}
390
391/// Visitor that collects all identifiers in an AST node
392#[derive(Debug, Default)]
393pub struct IdentifierCollector {
394    pub identifiers: Vec<CqlIdentifier>,
395}
396
397impl IdentifierCollector {
398    pub fn new() -> Self {
399        Self {
400            identifiers: Vec::new(),
401        }
402    }
403
404    pub fn into_identifiers(self) -> Vec<CqlIdentifier> {
405        self.identifiers
406    }
407}
408
409impl CqlVisitor<()> for IdentifierCollector {
410    fn visit_statement(&mut self, statement: &CqlStatement) -> Result<()> {
411        match statement {
412            CqlStatement::Select(select) => self.visit_select(select),
413            CqlStatement::Insert(insert) => self.visit_insert(insert),
414            CqlStatement::Update(update) => self.visit_update(update),
415            CqlStatement::Delete(delete) => self.visit_delete(delete),
416            CqlStatement::CreateTable(create) => self.visit_create_table(create),
417            CqlStatement::DropTable(drop) => self.visit_drop_table(drop),
418            CqlStatement::CreateIndex(create) => self.visit_create_index(create),
419            CqlStatement::AlterTable(alter) => self.visit_alter_table(alter),
420            _ => Ok(()),
421        }
422    }
423
424    fn visit_select(&mut self, select: &CqlSelect) -> Result<()> {
425        for item in &select.select_list {
426            match item {
427                CqlSelectItem::Expression { expression, .. } => {
428                    self.visit_expression(expression)?;
429                }
430                CqlSelectItem::Function { args, .. } => {
431                    for arg in args {
432                        self.visit_expression(arg)?;
433                    }
434                }
435                CqlSelectItem::Wildcard => {}
436            }
437        }
438
439        self.visit_identifier(&select.from.name)?;
440        if let Some(keyspace) = &select.from.keyspace {
441            self.visit_identifier(keyspace)?;
442        }
443
444        if let Some(where_clause) = &select.where_clause {
445            self.visit_expression(where_clause)?;
446        }
447
448        Ok(())
449    }
450
451    fn visit_insert(&mut self, insert: &CqlInsert) -> Result<()> {
452        self.visit_identifier(&insert.table.name)?;
453        if let Some(keyspace) = &insert.table.keyspace {
454            self.visit_identifier(keyspace)?;
455        }
456
457        for column in &insert.columns {
458            self.visit_identifier(column)?;
459        }
460
461        if let CqlInsertValues::Values(values) = &insert.values {
462            for value in values {
463                self.visit_expression(value)?;
464            }
465        }
466
467        Ok(())
468    }
469
470    fn visit_update(&mut self, update: &CqlUpdate) -> Result<()> {
471        self.visit_identifier(&update.table.name)?;
472        if let Some(keyspace) = &update.table.keyspace {
473            self.visit_identifier(keyspace)?;
474        }
475
476        for assignment in &update.assignments {
477            self.visit_identifier(&assignment.column)?;
478            self.visit_expression(&assignment.value)?;
479        }
480
481        self.visit_expression(&update.where_clause)?;
482
483        Ok(())
484    }
485
486    fn visit_delete(&mut self, delete: &CqlDelete) -> Result<()> {
487        self.visit_identifier(&delete.table.name)?;
488        if let Some(keyspace) = &delete.table.keyspace {
489            self.visit_identifier(keyspace)?;
490        }
491
492        self.visit_expression(&delete.where_clause)?;
493
494        Ok(())
495    }
496
497    fn visit_create_table(&mut self, create: &CqlCreateTable) -> Result<()> {
498        self.visit_identifier(&create.table.name)?;
499        if let Some(keyspace) = &create.table.keyspace {
500            self.visit_identifier(keyspace)?;
501        }
502
503        for column in &create.columns {
504            self.visit_identifier(&column.name)?;
505            self.visit_data_type(&column.data_type)?;
506        }
507
508        for pk_col in &create.primary_key.partition_key {
509            self.visit_identifier(pk_col)?;
510        }
511        for ck_col in &create.primary_key.clustering_key {
512            self.visit_identifier(ck_col)?;
513        }
514
515        Ok(())
516    }
517
518    fn visit_drop_table(&mut self, drop: &CqlDropTable) -> Result<()> {
519        self.visit_identifier(&drop.table.name)?;
520        if let Some(keyspace) = &drop.table.keyspace {
521            self.visit_identifier(keyspace)?;
522        }
523        Ok(())
524    }
525
526    fn visit_create_index(&mut self, create: &CqlCreateIndex) -> Result<()> {
527        if let Some(index_name) = &create.name {
528            self.visit_identifier(index_name)?;
529        }
530        self.visit_identifier(&create.table.name)?;
531        if let Some(keyspace) = &create.table.keyspace {
532            self.visit_identifier(keyspace)?;
533        }
534        for column in &create.columns {
535            self.visit_identifier(index_column_identifier(column))?;
536        }
537        Ok(())
538    }
539
540    fn visit_alter_table(&mut self, alter: &CqlAlterTable) -> Result<()> {
541        self.visit_identifier(&alter.table.name)?;
542        if let Some(keyspace) = &alter.table.keyspace {
543            self.visit_identifier(keyspace)?;
544        }
545
546        match &alter.operation {
547            CqlAlterTableOp::AddColumn(column_def) => {
548                self.visit_identifier(&column_def.name)?;
549                self.visit_data_type(&column_def.data_type)?;
550            }
551            CqlAlterTableOp::DropColumn(column_name) => {
552                self.visit_identifier(column_name)?;
553            }
554            CqlAlterTableOp::AlterColumn { column, new_type } => {
555                self.visit_identifier(column)?;
556                self.visit_data_type(new_type)?;
557            }
558            CqlAlterTableOp::RenameColumn { old_name, new_name } => {
559                self.visit_identifier(old_name)?;
560                self.visit_identifier(new_name)?;
561            }
562            _ => {}
563        }
564
565        Ok(())
566    }
567
568    fn visit_data_type(&mut self, data_type: &CqlDataType) -> Result<()> {
569        match data_type {
570            CqlDataType::List(inner) | CqlDataType::Set(inner) | CqlDataType::Frozen(inner) => {
571                self.visit_data_type(inner)?;
572            }
573            CqlDataType::Map(key, value) => {
574                self.visit_data_type(key)?;
575                self.visit_data_type(value)?;
576            }
577            CqlDataType::Udt(name) => {
578                self.visit_identifier(name)?;
579            }
580            _ => {}
581        }
582        Ok(())
583    }
584
585    fn visit_expression(&mut self, expression: &CqlExpression) -> Result<()> {
586        match expression {
587            CqlExpression::Column(identifier) => {
588                self.visit_identifier(identifier)?;
589            }
590            CqlExpression::Literal(literal) => {
591                self.visit_literal(literal)?;
592            }
593            CqlExpression::Function { name, args } => {
594                self.visit_identifier(name)?;
595                for arg in args {
596                    self.visit_expression(arg)?;
597                }
598            }
599            CqlExpression::Binary { left, right, .. } => {
600                self.visit_expression(left)?;
601                self.visit_expression(right)?;
602            }
603            CqlExpression::Unary { operand, .. } => {
604                self.visit_expression(operand)?;
605            }
606            CqlExpression::In { expression, values } => {
607                self.visit_expression(expression)?;
608                for value in values {
609                    self.visit_expression(value)?;
610                }
611            }
612            CqlExpression::Contains { column, value } => {
613                self.visit_identifier(column)?;
614                self.visit_expression(value)?;
615            }
616            CqlExpression::ContainsKey { column, key } => {
617                self.visit_identifier(column)?;
618                self.visit_expression(key)?;
619            }
620            CqlExpression::CollectionAccess { collection, index } => {
621                self.visit_expression(collection)?;
622                self.visit_expression(index)?;
623            }
624            CqlExpression::FieldAccess { object, field } => {
625                self.visit_expression(object)?;
626                self.visit_identifier(field)?;
627            }
628            CqlExpression::Case {
629                when_clauses,
630                else_clause,
631            } => {
632                for when_clause in when_clauses {
633                    self.visit_expression(&when_clause.condition)?;
634                    self.visit_expression(&when_clause.result)?;
635                }
636                if let Some(else_expr) = else_clause {
637                    self.visit_expression(else_expr)?;
638                }
639            }
640            CqlExpression::Cast {
641                expression,
642                target_type,
643            } => {
644                self.visit_expression(expression)?;
645                self.visit_data_type(target_type)?;
646            }
647            CqlExpression::Parameter(_) | CqlExpression::NamedParameter(_) => {}
648        }
649        Ok(())
650    }
651
652    fn visit_identifier(&mut self, identifier: &CqlIdentifier) -> Result<()> {
653        self.identifiers.push(identifier.clone());
654        Ok(())
655    }
656
657    fn visit_literal(&mut self, _literal: &CqlLiteral) -> Result<()> {
658        Ok(())
659    }
660}
661
662/// Visitor that validates semantic correctness of CQL statements
663#[derive(Debug)]
664pub struct SemanticValidator {
665    pub context: ValidationContext,
666    pub errors: Vec<String>,
667}
668
669impl SemanticValidator {
670    /// Create a new semantic validator with the given context
671    pub fn new(context: ValidationContext) -> Self {
672        Self {
673            context,
674            errors: Vec::new(),
675        }
676    }
677
678    fn add_error(&mut self, message: String) {
679        self.errors.push(message);
680    }
681
682    /// Check if validation passed (no errors)
683    pub fn is_valid(&self) -> bool {
684        self.errors.is_empty()
685    }
686
687    /// Get all validation errors
688    pub fn get_errors(&self) -> &[String] {
689        &self.errors
690    }
691
692    fn is_strict(&self) -> bool {
693        matches!(self.context.strictness, ValidationStrictness::Strict)
694    }
695
696    /// Record an error if `table` is unknown and strict validation is enabled.
697    fn check_table_exists(&mut self, table: &CqlTable) {
698        let name = table.full_name();
699        if !self.context.schemas.contains_key(&name) && self.is_strict() {
700            self.add_error(format!("Table '{}' does not exist", name));
701        }
702    }
703}
704
705impl CqlVisitor<()> for SemanticValidator {
706    fn visit_statement(&mut self, statement: &CqlStatement) -> Result<()> {
707        match statement {
708            CqlStatement::Select(select) => self.visit_select(select),
709            CqlStatement::Insert(insert) => self.visit_insert(insert),
710            CqlStatement::Update(update) => self.visit_update(update),
711            CqlStatement::Delete(delete) => self.visit_delete(delete),
712            CqlStatement::CreateTable(create) => self.visit_create_table(create),
713            CqlStatement::DropTable(drop) => self.visit_drop_table(drop),
714            CqlStatement::CreateIndex(create) => self.visit_create_index(create),
715            CqlStatement::AlterTable(alter) => self.visit_alter_table(alter),
716            CqlStatement::CreateType(_) => Ok(()),
717            CqlStatement::DropType(_) => Ok(()),
718            CqlStatement::Use(_) => Ok(()),
719            CqlStatement::Truncate(_) => Ok(()),
720            CqlStatement::Batch(_) => Ok(()),
721        }
722    }
723
724    fn visit_select(&mut self, select: &CqlSelect) -> Result<()> {
725        self.check_table_exists(&select.from);
726        DefaultVisitor.visit_select(select)
727    }
728
729    fn visit_insert(&mut self, insert: &CqlInsert) -> Result<()> {
730        let table_name = insert.table.full_name();
731        if self.context.schemas.contains_key(&table_name) {
732            if let CqlInsertValues::Values(values) = &insert.values {
733                if insert.columns.len() != values.len() {
734                    self.add_error(format!(
735                        "Column count ({}) does not match value count ({})",
736                        insert.columns.len(),
737                        values.len()
738                    ));
739                }
740            }
741        } else if self.is_strict() {
742            self.add_error(format!("Table '{}' does not exist", table_name));
743        }
744
745        DefaultVisitor.visit_insert(insert)
746    }
747
748    fn visit_update(&mut self, update: &CqlUpdate) -> Result<()> {
749        self.check_table_exists(&update.table);
750        DefaultVisitor.visit_update(update)
751    }
752
753    fn visit_delete(&mut self, delete: &CqlDelete) -> Result<()> {
754        self.check_table_exists(&delete.table);
755        DefaultVisitor.visit_delete(delete)
756    }
757
758    fn visit_create_table(&mut self, create: &CqlCreateTable) -> Result<()> {
759        let mut column_names = std::collections::HashSet::new();
760        for column in &create.columns {
761            let name = column.name.as_str();
762            if !column_names.insert(name) {
763                self.add_error(format!("Duplicate column name: '{}'", name));
764            }
765        }
766
767        for pk_column in &create.primary_key.partition_key {
768            let name = pk_column.as_str();
769            if !create.columns.iter().any(|c| c.name.as_str() == name) {
770                self.add_error(format!(
771                    "Partition key column '{}' not found in column definitions",
772                    name
773                ));
774            }
775        }
776
777        for ck_column in &create.primary_key.clustering_key {
778            let name = ck_column.as_str();
779            if !create.columns.iter().any(|c| c.name.as_str() == name) {
780                self.add_error(format!(
781                    "Clustering key column '{}' not found in column definitions",
782                    name
783                ));
784            }
785        }
786
787        DefaultVisitor.visit_create_table(create)
788    }
789
790    fn visit_drop_table(&mut self, drop: &CqlDropTable) -> Result<()> {
791        if !drop.if_exists {
792            self.check_table_exists(&drop.table);
793        }
794        DefaultVisitor.visit_drop_table(drop)
795    }
796
797    fn visit_create_index(&mut self, create: &CqlCreateIndex) -> Result<()> {
798        self.check_table_exists(&create.table);
799        DefaultVisitor.visit_create_index(create)
800    }
801
802    fn visit_alter_table(&mut self, alter: &CqlAlterTable) -> Result<()> {
803        self.check_table_exists(&alter.table);
804        DefaultVisitor.visit_alter_table(alter)
805    }
806
807    fn visit_data_type(&mut self, data_type: &CqlDataType) -> Result<()> {
808        if let CqlDataType::Udt(udt_name) = data_type {
809            let udt_key = udt_name.as_str();
810            if !self.context.udts.contains_key(udt_key) && self.is_strict() {
811                self.add_error(format!("UDT '{}' does not exist", udt_key));
812            }
813        }
814
815        DefaultVisitor.visit_data_type(data_type)
816    }
817
818    fn visit_expression(&mut self, expression: &CqlExpression) -> Result<()> {
819        DefaultVisitor.visit_expression(expression)
820    }
821
822    fn visit_identifier(&mut self, _identifier: &CqlIdentifier) -> Result<()> {
823        Ok(())
824    }
825
826    fn visit_literal(&mut self, literal: &CqlLiteral) -> Result<()> {
827        DefaultVisitor.visit_literal(literal)
828    }
829}
830
/// Type alias for AST transformation function
///
/// A boxed rewrite rule used by `AstTransformer`. It receives the current
/// statement and returns `Some(replacement)` to rewrite it, or `None` to leave
/// the statement unchanged.
pub type TransformationFn = Box<dyn Fn(&CqlStatement) -> Option<CqlStatement>>;
833
/// Visitor that transforms AST nodes
///
/// Holds an ordered list of statement-level rewrite rules and applies them in
/// sequence via `AstTransformer::transform`.
pub struct AstTransformer {
    /// Transformations to apply, in the order they were registered.
    pub transformations: Vec<TransformationFn>,
}
839
840impl std::fmt::Debug for AstTransformer {
841    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
842        f.debug_struct("AstTransformer")
843            .field(
844                "transformations",
845                &format!("[{} transformations]", self.transformations.len()),
846            )
847            .finish()
848    }
849}
850
851impl AstTransformer {
852    /// Create a new AST transformer
853    pub fn new() -> Self {
854        Self {
855            transformations: Vec::new(),
856        }
857    }
858
859    /// Add a transformation function
860    pub fn add_transformation<F>(&mut self, transform: F)
861    where
862        F: Fn(&CqlStatement) -> Option<CqlStatement> + 'static,
863    {
864        self.transformations.push(Box::new(transform));
865    }
866
867    /// Apply all transformations to a statement
868    pub fn transform(&self, statement: &CqlStatement) -> CqlStatement {
869        let mut result = statement.clone();
870
871        for transformation in &self.transformations {
872            if let Some(transformed) = transformation(&result) {
873                result = transformed;
874            }
875        }
876
877        result
878    }
879}
880
881impl Default for AstTransformer {
882    fn default() -> Self {
883        Self::new()
884    }
885}
886
887/// Utility functions for working with visitors
888pub mod utils {
889    use super::*;
890
891    /// Collect all table references in a statement
892    pub fn collect_table_references(statement: &CqlStatement) -> Vec<String> {
893        let table = match statement {
894            CqlStatement::Select(select) => &select.from,
895            CqlStatement::Insert(insert) => &insert.table,
896            CqlStatement::Update(update) => &update.table,
897            CqlStatement::Delete(delete) => &delete.table,
898            CqlStatement::CreateTable(create) => &create.table,
899            CqlStatement::DropTable(drop) => &drop.table,
900            CqlStatement::CreateIndex(create) => &create.table,
901            CqlStatement::AlterTable(alter) => &alter.table,
902            CqlStatement::Truncate(truncate) => &truncate.table,
903            CqlStatement::CreateType(_)
904            | CqlStatement::DropType(_)
905            | CqlStatement::Use(_)
906            | CqlStatement::Batch(_) => return Vec::new(),
907        };
908        vec![table.full_name()]
909    }
910
911    /// Check if a statement modifies data
912    pub fn is_modifying_statement(statement: &CqlStatement) -> bool {
913        matches!(
914            statement,
915            CqlStatement::Insert(_)
916                | CqlStatement::Update(_)
917                | CqlStatement::Delete(_)
918                | CqlStatement::CreateTable(_)
919                | CqlStatement::DropTable(_)
920                | CqlStatement::CreateIndex(_)
921                | CqlStatement::AlterTable(_)
922        )
923    }
924
925    /// Check if a statement is a data query
926    pub fn is_query_statement(statement: &CqlStatement) -> bool {
927        matches!(statement, CqlStatement::Select(_))
928    }
929
930    /// Check if a statement is a schema operation
931    pub fn is_schema_statement(statement: &CqlStatement) -> bool {
932        matches!(
933            statement,
934            CqlStatement::CreateTable(_)
935                | CqlStatement::DropTable(_)
936                | CqlStatement::CreateIndex(_)
937                | CqlStatement::AlterTable(_)
938        )
939    }
940}
941
/// Visitor that converts CQL CREATE TABLE AST to TableSchema
///
/// This visitor extracts the business logic from the existing nom parser
/// and converts AST structures to TableSchema objects.
///
/// Only `CREATE TABLE` is supported. Every other visitor method returns an
/// invalid-input error. The visitor is stateless, so one value can be reused.
#[derive(Debug, Default)]
pub struct SchemaBuilderVisitor;
948
949/// Error returned for every non-CREATE-TABLE visitor method on `SchemaBuilderVisitor`.
950fn schema_builder_unsupported(kind: &str) -> Error {
951    Error::invalid_input(format!("SchemaBuilderVisitor {}", kind))
952}
953
954/// Find the column definition matching `key` or return an invalid-input error.
955fn column_def_for<'a>(
956    create: &'a CqlCreateTable,
957    key: &CqlIdentifier,
958    role: &str,
959) -> Result<&'a CqlColumnDef> {
960    create
961        .columns
962        .iter()
963        .find(|col| col.name.as_str() == key.as_str())
964        .ok_or_else(|| {
965            Error::invalid_input(format!(
966                "{} key column '{}' not found in column definitions",
967                role,
968                key.as_str()
969            ))
970        })
971}
972
973impl CqlVisitor<TableSchema> for SchemaBuilderVisitor {
974    fn visit_statement(&mut self, statement: &CqlStatement) -> Result<TableSchema> {
975        match statement {
976            CqlStatement::CreateTable(create) => self.visit_create_table(create),
977            _ => Err(schema_builder_unsupported(
978                "only supports CREATE TABLE statements",
979            )),
980        }
981    }
982
983    fn visit_create_table(&mut self, create: &CqlCreateTable) -> Result<TableSchema> {
984        let table_name = create.table.name.as_str().to_string();
985        let keyspace = create
986            .table
987            .keyspace
988            .as_ref()
989            .map(|ks| ks.as_str().to_string())
990            .unwrap_or_else(|| "default".to_string());
991
992        let partition_keys = create
993            .primary_key
994            .partition_key
995            .iter()
996            .enumerate()
997            .map(|(pos, pk_col)| {
998                let column_def = column_def_for(create, pk_col, "Partition")?;
999                Ok(KeyColumn {
1000                    name: pk_col.as_str().to_string(),
1001                    data_type: cql_data_type_to_string(&column_def.data_type),
1002                    position: pos,
1003                })
1004            })
1005            .collect::<Result<Vec<_>>>()?;
1006
1007        let clustering_keys = create
1008            .primary_key
1009            .clustering_key
1010            .iter()
1011            .enumerate()
1012            .map(|(pos, ck_col)| {
1013                let column_def = column_def_for(create, ck_col, "Clustering")?;
1014                Ok(ClusteringColumn {
1015                    name: ck_col.as_str().to_string(),
1016                    data_type: cql_data_type_to_string(&column_def.data_type),
1017                    position: pos,
1018                    order: crate::schema::ClusteringOrder::Asc,
1019                })
1020            })
1021            .collect::<Result<Vec<_>>>()?;
1022
1023        let columns: Vec<Column> = create
1024            .columns
1025            .iter()
1026            .map(|col_def| Column {
1027                name: col_def.name.as_str().to_string(),
1028                data_type: cql_data_type_to_string(&col_def.data_type),
1029                nullable: true,
1030                default: None,
1031                is_static: col_def.is_static,
1032            })
1033            .collect();
1034
1035        Ok(TableSchema {
1036            keyspace,
1037            table: table_name,
1038            partition_keys,
1039            clustering_keys,
1040            columns,
1041            comments: HashMap::new(),
1042        })
1043    }
1044
1045    fn visit_select(&mut self, _select: &CqlSelect) -> Result<TableSchema> {
1046        Err(schema_builder_unsupported(
1047            "does not support SELECT statements",
1048        ))
1049    }
1050
1051    fn visit_insert(&mut self, _insert: &CqlInsert) -> Result<TableSchema> {
1052        Err(schema_builder_unsupported(
1053            "does not support INSERT statements",
1054        ))
1055    }
1056
1057    fn visit_update(&mut self, _update: &CqlUpdate) -> Result<TableSchema> {
1058        Err(schema_builder_unsupported(
1059            "does not support UPDATE statements",
1060        ))
1061    }
1062
1063    fn visit_delete(&mut self, _delete: &CqlDelete) -> Result<TableSchema> {
1064        Err(schema_builder_unsupported(
1065            "does not support DELETE statements",
1066        ))
1067    }
1068
1069    fn visit_drop_table(&mut self, _drop: &CqlDropTable) -> Result<TableSchema> {
1070        Err(schema_builder_unsupported(
1071            "does not support DROP TABLE statements",
1072        ))
1073    }
1074
1075    fn visit_create_index(&mut self, _create: &CqlCreateIndex) -> Result<TableSchema> {
1076        Err(schema_builder_unsupported(
1077            "does not support CREATE INDEX statements",
1078        ))
1079    }
1080
1081    fn visit_alter_table(&mut self, _alter: &CqlAlterTable) -> Result<TableSchema> {
1082        Err(schema_builder_unsupported(
1083            "does not support ALTER TABLE statements",
1084        ))
1085    }
1086
1087    fn visit_data_type(&mut self, _data_type: &CqlDataType) -> Result<TableSchema> {
1088        Err(schema_builder_unsupported(
1089            "does not support standalone data types",
1090        ))
1091    }
1092
1093    fn visit_expression(&mut self, _expression: &CqlExpression) -> Result<TableSchema> {
1094        Err(schema_builder_unsupported("does not support expressions"))
1095    }
1096
1097    fn visit_identifier(&mut self, _identifier: &CqlIdentifier) -> Result<TableSchema> {
1098        Err(schema_builder_unsupported("does not support identifiers"))
1099    }
1100
1101    fn visit_literal(&mut self, _literal: &CqlLiteral) -> Result<TableSchema> {
1102        Err(schema_builder_unsupported("does not support literals"))
1103    }
1104}
1105
1106impl SchemaBuilderVisitor {
1107    /// Create a new SchemaBuilderVisitor
1108    pub fn new() -> Self {
1109        Self
1110    }
1111}
1112
/// ValidationVisitor for AST validation
///
/// This visitor performs semantic validation of AST nodes beyond syntactic correctness.
/// Problems are accumulated as human-readable messages in `errors` instead of
/// being returned as `Err`, so a single pass can report several issues.
#[derive(Debug, Default)]
pub struct ValidationVisitor {
    /// Validation messages collected so far, in the order they were found.
    pub errors: Vec<String>,
}
1120
1121impl ValidationVisitor {
1122    pub fn new() -> Self {
1123        Self { errors: Vec::new() }
1124    }
1125
1126    pub fn has_errors(&self) -> bool {
1127        !self.errors.is_empty()
1128    }
1129
1130    pub fn get_errors(&self) -> &[String] {
1131        &self.errors
1132    }
1133
1134    fn add_error(&mut self, error: String) {
1135        self.errors.push(error);
1136    }
1137}
1138
1139impl CqlVisitor<()> for ValidationVisitor {
1140    fn visit_statement(&mut self, statement: &CqlStatement) -> Result<()> {
1141        match statement {
1142            CqlStatement::CreateTable(create) => self.visit_create_table(create),
1143            CqlStatement::Select(select) => self.visit_select(select),
1144            CqlStatement::Insert(insert) => self.visit_insert(insert),
1145            CqlStatement::Update(update) => self.visit_update(update),
1146            CqlStatement::Delete(delete) => self.visit_delete(delete),
1147            CqlStatement::DropTable(drop) => self.visit_drop_table(drop),
1148            CqlStatement::CreateIndex(create) => self.visit_create_index(create),
1149            CqlStatement::AlterTable(alter) => self.visit_alter_table(alter),
1150            _ => Ok(()), // Other statements not validated
1151        }
1152    }
1153
1154    fn visit_create_table(&mut self, create: &CqlCreateTable) -> Result<()> {
1155        if create.table.name.as_str().is_empty() {
1156            self.add_error("Table name cannot be empty".to_string());
1157        }
1158
1159        for pk_col in &create.primary_key.partition_key {
1160            if !create
1161                .columns
1162                .iter()
1163                .any(|col| col.name.as_str() == pk_col.as_str())
1164            {
1165                self.add_error(format!(
1166                    "Partition key column '{}' not found in column definitions",
1167                    pk_col.as_str()
1168                ));
1169            }
1170        }
1171
1172        for ck_col in &create.primary_key.clustering_key {
1173            if !create
1174                .columns
1175                .iter()
1176                .any(|col| col.name.as_str() == ck_col.as_str())
1177            {
1178                self.add_error(format!(
1179                    "Clustering key column '{}' not found in column definitions",
1180                    ck_col.as_str()
1181                ));
1182            }
1183        }
1184
1185        let mut column_names = std::collections::HashSet::new();
1186        for column in &create.columns {
1187            let name = column.name.as_str();
1188            if !column_names.insert(name) {
1189                self.add_error(format!("Duplicate column name: '{}'", name));
1190            }
1191        }
1192
1193        if create.primary_key.partition_key.is_empty() {
1194            self.add_error("Table must have at least one partition key column".to_string());
1195        }
1196
1197        Ok(())
1198    }
1199
1200    fn visit_select(&mut self, _select: &CqlSelect) -> Result<()> {
1201        Ok(())
1202    }
1203
1204    fn visit_insert(&mut self, _insert: &CqlInsert) -> Result<()> {
1205        Ok(())
1206    }
1207
1208    fn visit_update(&mut self, _update: &CqlUpdate) -> Result<()> {
1209        Ok(())
1210    }
1211
1212    fn visit_delete(&mut self, _delete: &CqlDelete) -> Result<()> {
1213        Ok(())
1214    }
1215
1216    fn visit_drop_table(&mut self, drop: &CqlDropTable) -> Result<()> {
1217        if drop.table.name.as_str().is_empty() {
1218            self.add_error("Table name cannot be empty".to_string());
1219        }
1220        Ok(())
1221    }
1222
1223    fn visit_create_index(&mut self, _create: &CqlCreateIndex) -> Result<()> {
1224        Ok(())
1225    }
1226
1227    fn visit_alter_table(&mut self, _alter: &CqlAlterTable) -> Result<()> {
1228        Ok(())
1229    }
1230
1231    fn visit_data_type(&mut self, _data_type: &CqlDataType) -> Result<()> {
1232        Ok(())
1233    }
1234
1235    fn visit_expression(&mut self, _expression: &CqlExpression) -> Result<()> {
1236        Ok(())
1237    }
1238
1239    fn visit_identifier(&mut self, _identifier: &CqlIdentifier) -> Result<()> {
1240        Ok(())
1241    }
1242
1243    fn visit_literal(&mut self, _literal: &CqlLiteral) -> Result<()> {
1244        Ok(())
1245    }
1246}
1247
/// TypeCollectorVisitor for collecting type information from AST
///
/// This visitor extracts all data types used in a statement for analysis.
/// Nested types are recorded alongside their containers (for example, `list<int>`
/// contributes both the list type and `int`). Duplicates are kept.
#[derive(Debug, Default)]
pub struct TypeCollectorVisitor {
    /// Every type encountered so far, in visit order (duplicates included).
    pub types: Vec<CqlDataType>,
}
1255
1256impl TypeCollectorVisitor {
1257    pub fn new() -> Self {
1258        Self { types: Vec::new() }
1259    }
1260
1261    pub fn into_types(self) -> Vec<CqlDataType> {
1262        self.types
1263    }
1264
1265    fn collect_type(&mut self, data_type: &CqlDataType) {
1266        self.types.push(data_type.clone());
1267
1268        match data_type {
1269            CqlDataType::List(inner) | CqlDataType::Set(inner) | CqlDataType::Frozen(inner) => {
1270                self.collect_type(inner);
1271            }
1272            CqlDataType::Map(key, value) => {
1273                self.collect_type(key);
1274                self.collect_type(value);
1275            }
1276            CqlDataType::Tuple(types) => {
1277                for t in types {
1278                    self.collect_type(t);
1279                }
1280            }
1281            _ => {}
1282        }
1283    }
1284}
1285
1286impl CqlVisitor<()> for TypeCollectorVisitor {
1287    fn visit_statement(&mut self, statement: &CqlStatement) -> Result<()> {
1288        match statement {
1289            CqlStatement::CreateTable(create) => self.visit_create_table(create),
1290            _ => Ok(()),
1291        }
1292    }
1293
1294    fn visit_create_table(&mut self, create: &CqlCreateTable) -> Result<()> {
1295        for column in &create.columns {
1296            self.collect_type(&column.data_type);
1297        }
1298        Ok(())
1299    }
1300
1301    fn visit_select(&mut self, _select: &CqlSelect) -> Result<()> {
1302        Ok(())
1303    }
1304
1305    fn visit_insert(&mut self, _insert: &CqlInsert) -> Result<()> {
1306        Ok(())
1307    }
1308
1309    fn visit_update(&mut self, _update: &CqlUpdate) -> Result<()> {
1310        Ok(())
1311    }
1312
1313    fn visit_delete(&mut self, _delete: &CqlDelete) -> Result<()> {
1314        Ok(())
1315    }
1316
1317    fn visit_drop_table(&mut self, _drop: &CqlDropTable) -> Result<()> {
1318        Ok(())
1319    }
1320
1321    fn visit_create_index(&mut self, _create: &CqlCreateIndex) -> Result<()> {
1322        Ok(())
1323    }
1324
1325    fn visit_alter_table(&mut self, alter: &CqlAlterTable) -> Result<()> {
1326        match &alter.operation {
1327            CqlAlterTableOp::AddColumn(column_def) => {
1328                self.collect_type(&column_def.data_type);
1329            }
1330            CqlAlterTableOp::AlterColumn { new_type, .. } => {
1331                self.collect_type(new_type);
1332            }
1333            _ => {}
1334        }
1335        Ok(())
1336    }
1337
1338    fn visit_data_type(&mut self, data_type: &CqlDataType) -> Result<()> {
1339        self.collect_type(data_type);
1340        Ok(())
1341    }
1342
1343    fn visit_expression(&mut self, _expression: &CqlExpression) -> Result<()> {
1344        Ok(())
1345    }
1346
1347    fn visit_identifier(&mut self, _identifier: &CqlIdentifier) -> Result<()> {
1348        Ok(())
1349    }
1350
1351    fn visit_literal(&mut self, _literal: &CqlLiteral) -> Result<()> {
1352        Ok(())
1353    }
1354}
1355
// Unit tests for the visitors and helpers defined in this module.
#[cfg(test)]
mod tests {
    use super::*;

    // IdentifierCollector should report identifiers in traversal order:
    // select list first, then the FROM table, then the WHERE clause.
    #[test]
    fn test_identifier_collector() {
        let statement = CqlStatement::Select(CqlSelect {
            distinct: false,
            select_list: vec![
                CqlSelectItem::Expression {
                    expression: CqlExpression::Column(CqlIdentifier::new("id")),
                    alias: None,
                },
                CqlSelectItem::Expression {
                    expression: CqlExpression::Column(CqlIdentifier::new("name")),
                    alias: None,
                },
            ],
            from: CqlTable::new("users"),
            where_clause: Some(CqlExpression::Binary {
                left: Box::new(CqlExpression::Column(CqlIdentifier::new("id"))),
                operator: CqlBinaryOperator::Eq,
                right: Box::new(CqlExpression::Parameter(1)),
            }),
            order_by: None,
            limit: None,
            allow_filtering: false,
        });

        let mut collector = IdentifierCollector::default();
        collector.visit_statement(&statement).unwrap();

        // Should collect: id, name, users, id (from WHERE clause)
        assert_eq!(collector.identifiers.len(), 4);
        assert_eq!(collector.identifiers[0].as_str(), "id");
        assert_eq!(collector.identifiers[1].as_str(), "name");
        assert_eq!(collector.identifiers[2].as_str(), "users");
        assert_eq!(collector.identifiers[3].as_str(), "id");
    }

    // An INSERT whose value count does not match its column count must be
    // reported by the SemanticValidator (as collected errors, not as Err).
    #[test]
    fn test_semantic_validator() {
        let statement = CqlStatement::Insert(CqlInsert {
            table: CqlTable::new("users"),
            columns: vec![CqlIdentifier::new("id"), CqlIdentifier::new("name")],
            values: CqlInsertValues::Values(vec![
                CqlExpression::Parameter(1),
                // Missing second value - should cause validation error
            ]),
            if_not_exists: false,
            using: None,
        });

        let context = ValidationContext::new();
        let mut validator = SemanticValidator::new(context);
        validator.visit_statement(&statement).unwrap();

        // Should have validation errors
        assert!(!validator.is_valid());
        assert!(!validator.get_errors().is_empty());
    }

    // A keyspace-qualified SELECT is a query: it yields "keyspace.table" and is
    // neither modifying nor a schema statement.
    #[test]
    fn test_utils() {
        let statement = CqlStatement::Select(CqlSelect {
            distinct: false,
            select_list: vec![CqlSelectItem::Wildcard],
            from: CqlTable::with_keyspace("test", "users"),
            where_clause: None,
            order_by: None,
            limit: None,
            allow_filtering: false,
        });

        let tables = utils::collect_table_references(&statement);
        assert_eq!(tables, vec!["test.users"]);

        assert!(utils::is_query_statement(&statement));
        assert!(!utils::is_modifying_statement(&statement));
        assert!(!utils::is_schema_statement(&statement));
    }

    // SchemaBuilderVisitor should map keyspace, table, key columns (with their
    // rendered types) and all column definitions, including collection types.
    #[test]
    fn test_schema_builder_visitor() {
        // Create a sample CREATE TABLE AST
        let create_table = CqlCreateTable {
            if_not_exists: false,
            table: CqlTable::with_keyspace("test_keyspace", "users"),
            columns: vec![
                CqlColumnDef {
                    name: CqlIdentifier::new("id"),
                    data_type: CqlDataType::Uuid,
                    is_static: false,
                },
                CqlColumnDef {
                    name: CqlIdentifier::new("name"),
                    data_type: CqlDataType::Text,
                    is_static: false,
                },
                CqlColumnDef {
                    name: CqlIdentifier::new("age"),
                    data_type: CqlDataType::Int,
                    is_static: false,
                },
                CqlColumnDef {
                    name: CqlIdentifier::new("tags"),
                    data_type: CqlDataType::List(Box::new(CqlDataType::Text)),
                    is_static: false,
                },
            ],
            primary_key: CqlPrimaryKey {
                partition_key: vec![CqlIdentifier::new("id")],
                clustering_key: vec![CqlIdentifier::new("name")],
            },
            options: CqlTableOptions {
                options: HashMap::new(),
            },
        };

        let statement = CqlStatement::CreateTable(create_table);
        let mut visitor = SchemaBuilderVisitor;
        let schema = visitor.visit_statement(&statement).unwrap();

        // Verify the schema was correctly built
        assert_eq!(schema.keyspace, "test_keyspace");
        assert_eq!(schema.table, "users");
        assert_eq!(schema.partition_keys.len(), 1);
        assert_eq!(schema.partition_keys[0].name, "id");
        assert_eq!(schema.partition_keys[0].data_type, "uuid");
        assert_eq!(schema.clustering_keys.len(), 1);
        assert_eq!(schema.clustering_keys[0].name, "name");
        assert_eq!(schema.clustering_keys[0].data_type, "text");
        assert_eq!(schema.columns.len(), 4);

        // Check that list type was correctly converted
        let tags_column = schema
            .columns
            .iter()
            .find(|col| col.name == "tags")
            .expect("tags column should exist");
        assert_eq!(tags_column.data_type, "list<text>");
    }

    // ValidationVisitor should flag both a duplicate column name and a
    // partition key that has no column definition.
    #[test]
    fn test_validation_visitor() {
        // Create a CREATE TABLE AST with validation errors
        let create_table = CqlCreateTable {
            if_not_exists: false,
            table: CqlTable::new("test_table"),
            columns: vec![
                CqlColumnDef {
                    name: CqlIdentifier::new("id"),
                    data_type: CqlDataType::Uuid,
                    is_static: false,
                },
                CqlColumnDef {
                    name: CqlIdentifier::new("name"),
                    data_type: CqlDataType::Text,
                    is_static: false,
                },
                CqlColumnDef {
                    name: CqlIdentifier::new("name"), // Duplicate column name
                    data_type: CqlDataType::Int,
                    is_static: false,
                },
            ],
            primary_key: CqlPrimaryKey {
                partition_key: vec![CqlIdentifier::new("missing_column")], // Column doesn't exist
                clustering_key: vec![],
            },
            options: CqlTableOptions {
                options: HashMap::new(),
            },
        };

        let statement = CqlStatement::CreateTable(create_table);
        let mut visitor = ValidationVisitor::new();
        let _ = visitor.visit_statement(&statement);

        // Should have validation errors
        assert!(visitor.has_errors());
        let errors = visitor.get_errors();
        assert!(errors.iter().any(|e| e.contains("Duplicate column name")));
        assert!(errors
            .iter()
            .any(|e| e.contains("not found in column definitions")));
    }

    // TypeCollectorVisitor should record container types and every type nested
    // inside them (list element, map key/value, frozen set element).
    #[test]
    fn test_type_collector_visitor() {
        // Create a CREATE TABLE AST with various types
        let create_table = CqlCreateTable {
            if_not_exists: false,
            table: CqlTable::new("test_table"),
            columns: vec![
                CqlColumnDef {
                    name: CqlIdentifier::new("simple"),
                    data_type: CqlDataType::Text,
                    is_static: false,
                },
                CqlColumnDef {
                    name: CqlIdentifier::new("list_col"),
                    data_type: CqlDataType::List(Box::new(CqlDataType::Int)),
                    is_static: false,
                },
                CqlColumnDef {
                    name: CqlIdentifier::new("map_col"),
                    data_type: CqlDataType::Map(
                        Box::new(CqlDataType::Text),
                        Box::new(CqlDataType::Uuid),
                    ),
                    is_static: false,
                },
                CqlColumnDef {
                    name: CqlIdentifier::new("frozen_col"),
                    data_type: CqlDataType::Frozen(Box::new(CqlDataType::Set(Box::new(
                        CqlDataType::BigInt,
                    )))),
                    is_static: false,
                },
            ],
            primary_key: CqlPrimaryKey {
                partition_key: vec![CqlIdentifier::new("simple")],
                clustering_key: vec![],
            },
            options: CqlTableOptions {
                options: HashMap::new(),
            },
        };

        let statement = CqlStatement::CreateTable(create_table);
        let mut visitor = TypeCollectorVisitor::new();
        let _ = visitor.visit_statement(&statement);

        let types = visitor.into_types();

        // Should collect all types including nested ones
        assert!(types.iter().any(|t| matches!(t, CqlDataType::Text)));
        assert!(types.iter().any(|t| matches!(t, CqlDataType::List(_))));
        assert!(types.iter().any(|t| matches!(t, CqlDataType::Int)));
        assert!(types.iter().any(|t| matches!(t, CqlDataType::Map(_, _))));
        assert!(types.iter().any(|t| matches!(t, CqlDataType::Uuid)));
        assert!(types.iter().any(|t| matches!(t, CqlDataType::Frozen(_))));
        assert!(types.iter().any(|t| matches!(t, CqlDataType::Set(_))));
        assert!(types.iter().any(|t| matches!(t, CqlDataType::BigInt)));
    }

    // DefaultVisitor must traverse BATCH and TRUNCATE statements without
    // returning an error.
    #[test]
    fn test_default_visitor_handles_batch_and_truncate() {
        let insert = CqlInsert {
            table: CqlTable::new("users"),
            columns: vec![CqlIdentifier::new("id"), CqlIdentifier::new("name")],
            values: CqlInsertValues::Values(vec![
                CqlExpression::Literal(CqlLiteral::Integer(1)),
                CqlExpression::Literal(CqlLiteral::String("alice".to_string())),
            ]),
            if_not_exists: false,
            using: None,
        };

        let batch = CqlStatement::Batch(CqlBatch {
            batch_type: CqlBatchType::Logged,
            using: None,
            statements: vec![CqlBatchStatement::Insert(insert.clone())],
        });

        let truncate = CqlStatement::Truncate(CqlTruncate {
            table: CqlTable::new("users"),
        });

        // The `()` annotations select the unit-result visitor implementation.
        let mut visitor = DefaultVisitor;
        let _: () = visitor.visit_statement(&batch).unwrap();

        let _: () = visitor.visit_statement(&truncate).unwrap();
    }
}