1use super::ast::*;
7use super::traits::{CqlVisitor, ValidationContext, ValidationStrictness};
8use crate::error::{Error, Result};
9use crate::schema::{ClusteringColumn, Column, KeyColumn, TableSchema};
10use std::collections::HashMap;
11
12pub(crate) fn cql_data_type_to_string(data_type: &CqlDataType) -> String {
14 match data_type {
15 CqlDataType::Boolean => "boolean".to_string(),
16 CqlDataType::TinyInt => "tinyint".to_string(),
17 CqlDataType::SmallInt => "smallint".to_string(),
18 CqlDataType::Int => "int".to_string(),
19 CqlDataType::BigInt => "bigint".to_string(),
20 CqlDataType::Varint => "varint".to_string(),
21 CqlDataType::Decimal => "decimal".to_string(),
22 CqlDataType::Float => "float".to_string(),
23 CqlDataType::Double => "double".to_string(),
24 CqlDataType::Text => "text".to_string(),
25 CqlDataType::Ascii => "ascii".to_string(),
26 CqlDataType::Varchar => "varchar".to_string(),
27 CqlDataType::Blob => "blob".to_string(),
28 CqlDataType::Timestamp => "timestamp".to_string(),
29 CqlDataType::Date => "date".to_string(),
30 CqlDataType::Time => "time".to_string(),
31 CqlDataType::Uuid => "uuid".to_string(),
32 CqlDataType::TimeUuid => "timeuuid".to_string(),
33 CqlDataType::Inet => "inet".to_string(),
34 CqlDataType::Duration => "duration".to_string(),
35 CqlDataType::Counter => "counter".to_string(),
36 CqlDataType::List(inner) => format!("list<{}>", cql_data_type_to_string(inner)),
37 CqlDataType::Set(inner) => format!("set<{}>", cql_data_type_to_string(inner)),
38 CqlDataType::Map(key, value) => format!(
39 "map<{}, {}>",
40 cql_data_type_to_string(key),
41 cql_data_type_to_string(value)
42 ),
43 CqlDataType::Tuple(types) => {
44 let type_strs: Vec<String> = types.iter().map(cql_data_type_to_string).collect();
45 format!("tuple<{}>", type_strs.join(", "))
46 }
47 CqlDataType::Udt(name) => name.as_str().to_string(),
48 CqlDataType::Frozen(inner) => format!("frozen<{}>", cql_data_type_to_string(inner)),
49 CqlDataType::Custom(name) => name.clone(),
50 }
51}
52
53fn index_column_identifier(column: &CqlIndexColumn) -> &CqlIdentifier {
55 match column {
56 CqlIndexColumn::Column(id)
57 | CqlIndexColumn::Keys(id)
58 | CqlIndexColumn::Values(id)
59 | CqlIndexColumn::Entries(id)
60 | CqlIndexColumn::Full(id) => id,
61 }
62}
63
/// A stateless visitor that walks nested expressions, identifiers, data types
/// and literals, returning `T::default()` from every visit.
///
/// Every visit method is implemented directly on this type, so nested nodes are
/// always dispatched back to `DefaultVisitor`'s own methods. Another visitor that
/// delegates to it therefore does not see those nested nodes itself.
///
/// The type is a zero-sized unit struct used by value (`DefaultVisitor.visit_*`),
/// so it is `Copy`.
#[derive(Debug, Default, Clone, Copy)]
pub struct DefaultVisitor;
71
72impl<T: Default> CqlVisitor<T> for DefaultVisitor {
73 fn visit_statement(&mut self, statement: &CqlStatement) -> Result<T> {
74 match statement {
75 CqlStatement::Select(select) => self.visit_select(select),
76 CqlStatement::Insert(insert) => self.visit_insert(insert),
77 CqlStatement::Update(update) => self.visit_update(update),
78 CqlStatement::Delete(delete) => self.visit_delete(delete),
79 CqlStatement::CreateTable(create) => self.visit_create_table(create),
80 CqlStatement::DropTable(drop) => self.visit_drop_table(drop),
81 CqlStatement::CreateIndex(create) => self.visit_create_index(create),
82 CqlStatement::AlterTable(alter) => self.visit_alter_table(alter),
83 CqlStatement::CreateType(_) => Ok(T::default()),
84 CqlStatement::DropType(_) => Ok(T::default()),
85 CqlStatement::Use(_) => Ok(T::default()),
86 CqlStatement::Truncate(_) => Ok(T::default()),
87 CqlStatement::Batch(_) => Ok(T::default()),
88 }
89 }
90
91 fn visit_select(&mut self, select: &CqlSelect) -> Result<T> {
92 for item in &select.select_list {
93 match item {
94 CqlSelectItem::Expression { expression, .. } => {
95 let _: T = self.visit_expression(expression)?;
96 }
97 CqlSelectItem::Function { args, .. } => {
98 for arg in args {
99 let _: T = self.visit_expression(arg)?;
100 }
101 }
102 CqlSelectItem::Wildcard => {}
103 }
104 }
105
106 if let Some(where_clause) = &select.where_clause {
107 let _: T = self.visit_expression(where_clause)?;
108 }
109
110 Ok(T::default())
111 }
112
113 fn visit_insert(&mut self, insert: &CqlInsert) -> Result<T> {
114 for column in &insert.columns {
115 let _: T = self.visit_identifier(column)?;
116 }
117
118 if let CqlInsertValues::Values(expressions) = &insert.values {
119 for expr in expressions {
120 let _: T = self.visit_expression(expr)?;
121 }
122 }
123
124 if let Some(using) = &insert.using {
125 if let Some(ttl) = &using.ttl {
126 let _: T = self.visit_expression(ttl)?;
127 }
128 if let Some(timestamp) = &using.timestamp {
129 let _: T = self.visit_expression(timestamp)?;
130 }
131 }
132
133 Ok(T::default())
134 }
135
136 fn visit_update(&mut self, update: &CqlUpdate) -> Result<T> {
137 for assignment in &update.assignments {
138 let _: T = self.visit_identifier(&assignment.column)?;
139 let _: T = self.visit_expression(&assignment.value)?;
140
141 if let CqlAssignmentOperator::MapUpdate(key_expr) = &assignment.operator {
142 let _: T = self.visit_expression(key_expr)?;
143 }
144 }
145
146 let _: T = self.visit_expression(&update.where_clause)?;
147
148 if let Some(if_condition) = &update.if_condition {
149 let _: T = self.visit_expression(if_condition)?;
150 }
151
152 if let Some(using) = &update.using {
153 if let Some(ttl) = &using.ttl {
154 let _: T = self.visit_expression(ttl)?;
155 }
156 if let Some(timestamp) = &using.timestamp {
157 let _: T = self.visit_expression(timestamp)?;
158 }
159 }
160
161 Ok(T::default())
162 }
163
164 fn visit_delete(&mut self, delete: &CqlDelete) -> Result<T> {
165 for column in &delete.columns {
166 let _: T = self.visit_identifier(column)?;
167 }
168
169 let _: T = self.visit_expression(&delete.where_clause)?;
170
171 if let Some(if_condition) = &delete.if_condition {
172 let _: T = self.visit_expression(if_condition)?;
173 }
174
175 if let Some(using) = &delete.using {
176 if let Some(timestamp) = &using.timestamp {
177 let _: T = self.visit_expression(timestamp)?;
178 }
179 }
180
181 Ok(T::default())
182 }
183
184 fn visit_create_table(&mut self, create: &CqlCreateTable) -> Result<T> {
185 let _: T = self.visit_identifier(&create.table.name)?;
186 if let Some(keyspace) = &create.table.keyspace {
187 let _: T = self.visit_identifier(keyspace)?;
188 }
189
190 for column in &create.columns {
191 let _: T = self.visit_identifier(&column.name)?;
192 let _: T = self.visit_data_type(&column.data_type)?;
193 }
194
195 for pk_column in &create.primary_key.partition_key {
196 let _: T = self.visit_identifier(pk_column)?;
197 }
198 for ck_column in &create.primary_key.clustering_key {
199 let _: T = self.visit_identifier(ck_column)?;
200 }
201
202 Ok(T::default())
203 }
204
205 fn visit_drop_table(&mut self, drop: &CqlDropTable) -> Result<T> {
206 let _: T = self.visit_identifier(&drop.table.name)?;
207 if let Some(keyspace) = &drop.table.keyspace {
208 let _: T = self.visit_identifier(keyspace)?;
209 }
210
211 Ok(T::default())
212 }
213
214 fn visit_create_index(&mut self, create: &CqlCreateIndex) -> Result<T> {
215 if let Some(name) = &create.name {
216 let _: T = self.visit_identifier(name)?;
217 }
218
219 let _: T = self.visit_identifier(&create.table.name)?;
220 if let Some(keyspace) = &create.table.keyspace {
221 let _: T = self.visit_identifier(keyspace)?;
222 }
223
224 for column in &create.columns {
225 let _: T = self.visit_identifier(index_column_identifier(column))?;
226 }
227
228 Ok(T::default())
229 }
230
231 fn visit_alter_table(&mut self, alter: &CqlAlterTable) -> Result<T> {
232 let _: T = self.visit_identifier(&alter.table.name)?;
233 if let Some(keyspace) = &alter.table.keyspace {
234 let _: T = self.visit_identifier(keyspace)?;
235 }
236
237 match &alter.operation {
238 CqlAlterTableOp::AddColumn(column_def) => {
239 let _: T = self.visit_identifier(&column_def.name)?;
240 let _: T = self.visit_data_type(&column_def.data_type)?;
241 }
242 CqlAlterTableOp::DropColumn(column) => {
243 let _: T = self.visit_identifier(column)?;
244 }
245 CqlAlterTableOp::AlterColumn { column, new_type } => {
246 let _: T = self.visit_identifier(column)?;
247 let _: T = self.visit_data_type(new_type)?;
248 }
249 CqlAlterTableOp::RenameColumn { old_name, new_name } => {
250 let _: T = self.visit_identifier(old_name)?;
251 let _: T = self.visit_identifier(new_name)?;
252 }
253 CqlAlterTableOp::WithOptions(_) => {}
254 }
255
256 Ok(T::default())
257 }
258
259 fn visit_data_type(&mut self, data_type: &CqlDataType) -> Result<T> {
260 match data_type {
261 CqlDataType::List(inner) | CqlDataType::Set(inner) | CqlDataType::Frozen(inner) => {
262 let _: T = self.visit_data_type(inner)?;
263 }
264 CqlDataType::Map(key_type, value_type) => {
265 let _: T = self.visit_data_type(key_type)?;
266 let _: T = self.visit_data_type(value_type)?;
267 }
268 CqlDataType::Tuple(types) => {
269 for typ in types {
270 let _: T = self.visit_data_type(typ)?;
271 }
272 }
273 CqlDataType::Udt(name) => {
274 let _: T = self.visit_identifier(name)?;
275 }
276 _ => {}
277 }
278
279 Ok(T::default())
280 }
281
282 fn visit_expression(&mut self, expression: &CqlExpression) -> Result<T> {
283 match expression {
284 CqlExpression::Literal(literal) => self.visit_literal(literal),
285 CqlExpression::Column(column) => self.visit_identifier(column),
286 CqlExpression::Parameter(_) | CqlExpression::NamedParameter(_) => Ok(T::default()),
287 CqlExpression::Binary { left, right, .. } => {
288 let _: T = self.visit_expression(left)?;
289 let _: T = self.visit_expression(right)?;
290 Ok(T::default())
291 }
292 CqlExpression::Unary { operand, .. } => {
293 let _: T = self.visit_expression(operand)?;
294 Ok(T::default())
295 }
296 CqlExpression::Function { name, args } => {
297 let _: T = self.visit_identifier(name)?;
298 for arg in args {
299 let _: T = self.visit_expression(arg)?;
300 }
301 Ok(T::default())
302 }
303 CqlExpression::In { expression, values } => {
304 let _: T = self.visit_expression(expression)?;
305 for value in values {
306 let _: T = self.visit_expression(value)?;
307 }
308 Ok(T::default())
309 }
310 CqlExpression::Contains { column, value } => {
311 let _: T = self.visit_identifier(column)?;
312 let _: T = self.visit_expression(value)?;
313 Ok(T::default())
314 }
315 CqlExpression::ContainsKey { column, key } => {
316 let _: T = self.visit_identifier(column)?;
317 let _: T = self.visit_expression(key)?;
318 Ok(T::default())
319 }
320 CqlExpression::CollectionAccess { collection, index } => {
321 let _: T = self.visit_expression(collection)?;
322 let _: T = self.visit_expression(index)?;
323 Ok(T::default())
324 }
325 CqlExpression::FieldAccess { object, field } => {
326 let _: T = self.visit_expression(object)?;
327 let _: T = self.visit_identifier(field)?;
328 Ok(T::default())
329 }
330 CqlExpression::Case {
331 when_clauses,
332 else_clause,
333 } => {
334 for when_clause in when_clauses {
335 let _: T = self.visit_expression(&when_clause.condition)?;
336 let _: T = self.visit_expression(&when_clause.result)?;
337 }
338 if let Some(else_expr) = else_clause {
339 let _: T = self.visit_expression(else_expr)?;
340 }
341 Ok(T::default())
342 }
343 CqlExpression::Cast {
344 expression,
345 target_type,
346 } => {
347 let _: T = self.visit_expression(expression)?;
348 let _: T = self.visit_data_type(target_type)?;
349 Ok(T::default())
350 }
351 }
352 }
353
354 fn visit_identifier(&mut self, _identifier: &CqlIdentifier) -> Result<T> {
355 Ok(T::default())
356 }
357
358 fn visit_literal(&mut self, literal: &CqlLiteral) -> Result<T> {
359 match literal {
360 CqlLiteral::Collection(collection) => match collection {
361 CqlCollectionLiteral::List(items) | CqlCollectionLiteral::Set(items) => {
362 for item in items {
363 let _: T = self.visit_literal(item)?;
364 }
365 }
366 CqlCollectionLiteral::Map(pairs) => {
367 for (key, value) in pairs {
368 let _: T = self.visit_literal(key)?;
369 let _: T = self.visit_literal(value)?;
370 }
371 }
372 },
373 CqlLiteral::Udt(udt) => {
374 for (field_name, field_value) in &udt.fields {
375 let _: T = self.visit_identifier(field_name)?;
376 let _: T = self.visit_literal(field_value)?;
377 }
378 }
379 CqlLiteral::Tuple(items) => {
380 for item in items {
381 let _: T = self.visit_literal(item)?;
382 }
383 }
384 _ => {}
385 }
386
387 Ok(T::default())
388 }
389}
390
391#[derive(Debug, Default)]
393pub struct IdentifierCollector {
394 pub identifiers: Vec<CqlIdentifier>,
395}
396
397impl IdentifierCollector {
398 pub fn new() -> Self {
399 Self {
400 identifiers: Vec::new(),
401 }
402 }
403
404 pub fn into_identifiers(self) -> Vec<CqlIdentifier> {
405 self.identifiers
406 }
407}
408
409impl CqlVisitor<()> for IdentifierCollector {
410 fn visit_statement(&mut self, statement: &CqlStatement) -> Result<()> {
411 match statement {
412 CqlStatement::Select(select) => self.visit_select(select),
413 CqlStatement::Insert(insert) => self.visit_insert(insert),
414 CqlStatement::Update(update) => self.visit_update(update),
415 CqlStatement::Delete(delete) => self.visit_delete(delete),
416 CqlStatement::CreateTable(create) => self.visit_create_table(create),
417 CqlStatement::DropTable(drop) => self.visit_drop_table(drop),
418 CqlStatement::CreateIndex(create) => self.visit_create_index(create),
419 CqlStatement::AlterTable(alter) => self.visit_alter_table(alter),
420 _ => Ok(()),
421 }
422 }
423
424 fn visit_select(&mut self, select: &CqlSelect) -> Result<()> {
425 for item in &select.select_list {
426 match item {
427 CqlSelectItem::Expression { expression, .. } => {
428 self.visit_expression(expression)?;
429 }
430 CqlSelectItem::Function { args, .. } => {
431 for arg in args {
432 self.visit_expression(arg)?;
433 }
434 }
435 CqlSelectItem::Wildcard => {}
436 }
437 }
438
439 self.visit_identifier(&select.from.name)?;
440 if let Some(keyspace) = &select.from.keyspace {
441 self.visit_identifier(keyspace)?;
442 }
443
444 if let Some(where_clause) = &select.where_clause {
445 self.visit_expression(where_clause)?;
446 }
447
448 Ok(())
449 }
450
451 fn visit_insert(&mut self, insert: &CqlInsert) -> Result<()> {
452 self.visit_identifier(&insert.table.name)?;
453 if let Some(keyspace) = &insert.table.keyspace {
454 self.visit_identifier(keyspace)?;
455 }
456
457 for column in &insert.columns {
458 self.visit_identifier(column)?;
459 }
460
461 if let CqlInsertValues::Values(values) = &insert.values {
462 for value in values {
463 self.visit_expression(value)?;
464 }
465 }
466
467 Ok(())
468 }
469
470 fn visit_update(&mut self, update: &CqlUpdate) -> Result<()> {
471 self.visit_identifier(&update.table.name)?;
472 if let Some(keyspace) = &update.table.keyspace {
473 self.visit_identifier(keyspace)?;
474 }
475
476 for assignment in &update.assignments {
477 self.visit_identifier(&assignment.column)?;
478 self.visit_expression(&assignment.value)?;
479 }
480
481 self.visit_expression(&update.where_clause)?;
482
483 Ok(())
484 }
485
486 fn visit_delete(&mut self, delete: &CqlDelete) -> Result<()> {
487 self.visit_identifier(&delete.table.name)?;
488 if let Some(keyspace) = &delete.table.keyspace {
489 self.visit_identifier(keyspace)?;
490 }
491
492 self.visit_expression(&delete.where_clause)?;
493
494 Ok(())
495 }
496
497 fn visit_create_table(&mut self, create: &CqlCreateTable) -> Result<()> {
498 self.visit_identifier(&create.table.name)?;
499 if let Some(keyspace) = &create.table.keyspace {
500 self.visit_identifier(keyspace)?;
501 }
502
503 for column in &create.columns {
504 self.visit_identifier(&column.name)?;
505 self.visit_data_type(&column.data_type)?;
506 }
507
508 for pk_col in &create.primary_key.partition_key {
509 self.visit_identifier(pk_col)?;
510 }
511 for ck_col in &create.primary_key.clustering_key {
512 self.visit_identifier(ck_col)?;
513 }
514
515 Ok(())
516 }
517
518 fn visit_drop_table(&mut self, drop: &CqlDropTable) -> Result<()> {
519 self.visit_identifier(&drop.table.name)?;
520 if let Some(keyspace) = &drop.table.keyspace {
521 self.visit_identifier(keyspace)?;
522 }
523 Ok(())
524 }
525
526 fn visit_create_index(&mut self, create: &CqlCreateIndex) -> Result<()> {
527 if let Some(index_name) = &create.name {
528 self.visit_identifier(index_name)?;
529 }
530 self.visit_identifier(&create.table.name)?;
531 if let Some(keyspace) = &create.table.keyspace {
532 self.visit_identifier(keyspace)?;
533 }
534 for column in &create.columns {
535 self.visit_identifier(index_column_identifier(column))?;
536 }
537 Ok(())
538 }
539
540 fn visit_alter_table(&mut self, alter: &CqlAlterTable) -> Result<()> {
541 self.visit_identifier(&alter.table.name)?;
542 if let Some(keyspace) = &alter.table.keyspace {
543 self.visit_identifier(keyspace)?;
544 }
545
546 match &alter.operation {
547 CqlAlterTableOp::AddColumn(column_def) => {
548 self.visit_identifier(&column_def.name)?;
549 self.visit_data_type(&column_def.data_type)?;
550 }
551 CqlAlterTableOp::DropColumn(column_name) => {
552 self.visit_identifier(column_name)?;
553 }
554 CqlAlterTableOp::AlterColumn { column, new_type } => {
555 self.visit_identifier(column)?;
556 self.visit_data_type(new_type)?;
557 }
558 CqlAlterTableOp::RenameColumn { old_name, new_name } => {
559 self.visit_identifier(old_name)?;
560 self.visit_identifier(new_name)?;
561 }
562 _ => {}
563 }
564
565 Ok(())
566 }
567
568 fn visit_data_type(&mut self, data_type: &CqlDataType) -> Result<()> {
569 match data_type {
570 CqlDataType::List(inner) | CqlDataType::Set(inner) | CqlDataType::Frozen(inner) => {
571 self.visit_data_type(inner)?;
572 }
573 CqlDataType::Map(key, value) => {
574 self.visit_data_type(key)?;
575 self.visit_data_type(value)?;
576 }
577 CqlDataType::Udt(name) => {
578 self.visit_identifier(name)?;
579 }
580 _ => {}
581 }
582 Ok(())
583 }
584
585 fn visit_expression(&mut self, expression: &CqlExpression) -> Result<()> {
586 match expression {
587 CqlExpression::Column(identifier) => {
588 self.visit_identifier(identifier)?;
589 }
590 CqlExpression::Literal(literal) => {
591 self.visit_literal(literal)?;
592 }
593 CqlExpression::Function { name, args } => {
594 self.visit_identifier(name)?;
595 for arg in args {
596 self.visit_expression(arg)?;
597 }
598 }
599 CqlExpression::Binary { left, right, .. } => {
600 self.visit_expression(left)?;
601 self.visit_expression(right)?;
602 }
603 CqlExpression::Unary { operand, .. } => {
604 self.visit_expression(operand)?;
605 }
606 CqlExpression::In { expression, values } => {
607 self.visit_expression(expression)?;
608 for value in values {
609 self.visit_expression(value)?;
610 }
611 }
612 CqlExpression::Contains { column, value } => {
613 self.visit_identifier(column)?;
614 self.visit_expression(value)?;
615 }
616 CqlExpression::ContainsKey { column, key } => {
617 self.visit_identifier(column)?;
618 self.visit_expression(key)?;
619 }
620 CqlExpression::CollectionAccess { collection, index } => {
621 self.visit_expression(collection)?;
622 self.visit_expression(index)?;
623 }
624 CqlExpression::FieldAccess { object, field } => {
625 self.visit_expression(object)?;
626 self.visit_identifier(field)?;
627 }
628 CqlExpression::Case {
629 when_clauses,
630 else_clause,
631 } => {
632 for when_clause in when_clauses {
633 self.visit_expression(&when_clause.condition)?;
634 self.visit_expression(&when_clause.result)?;
635 }
636 if let Some(else_expr) = else_clause {
637 self.visit_expression(else_expr)?;
638 }
639 }
640 CqlExpression::Cast {
641 expression,
642 target_type,
643 } => {
644 self.visit_expression(expression)?;
645 self.visit_data_type(target_type)?;
646 }
647 CqlExpression::Parameter(_) | CqlExpression::NamedParameter(_) => {}
648 }
649 Ok(())
650 }
651
652 fn visit_identifier(&mut self, identifier: &CqlIdentifier) -> Result<()> {
653 self.identifiers.push(identifier.clone());
654 Ok(())
655 }
656
657 fn visit_literal(&mut self, _literal: &CqlLiteral) -> Result<()> {
658 Ok(())
659 }
660}
661
662#[derive(Debug)]
664pub struct SemanticValidator {
665 pub context: ValidationContext,
666 pub errors: Vec<String>,
667}
668
669impl SemanticValidator {
670 pub fn new(context: ValidationContext) -> Self {
672 Self {
673 context,
674 errors: Vec::new(),
675 }
676 }
677
678 fn add_error(&mut self, message: String) {
679 self.errors.push(message);
680 }
681
682 pub fn is_valid(&self) -> bool {
684 self.errors.is_empty()
685 }
686
687 pub fn get_errors(&self) -> &[String] {
689 &self.errors
690 }
691
692 fn is_strict(&self) -> bool {
693 matches!(self.context.strictness, ValidationStrictness::Strict)
694 }
695
696 fn check_table_exists(&mut self, table: &CqlTable) {
698 let name = table.full_name();
699 if !self.context.schemas.contains_key(&name) && self.is_strict() {
700 self.add_error(format!("Table '{}' does not exist", name));
701 }
702 }
703}
704
705impl CqlVisitor<()> for SemanticValidator {
706 fn visit_statement(&mut self, statement: &CqlStatement) -> Result<()> {
707 match statement {
708 CqlStatement::Select(select) => self.visit_select(select),
709 CqlStatement::Insert(insert) => self.visit_insert(insert),
710 CqlStatement::Update(update) => self.visit_update(update),
711 CqlStatement::Delete(delete) => self.visit_delete(delete),
712 CqlStatement::CreateTable(create) => self.visit_create_table(create),
713 CqlStatement::DropTable(drop) => self.visit_drop_table(drop),
714 CqlStatement::CreateIndex(create) => self.visit_create_index(create),
715 CqlStatement::AlterTable(alter) => self.visit_alter_table(alter),
716 CqlStatement::CreateType(_) => Ok(()),
717 CqlStatement::DropType(_) => Ok(()),
718 CqlStatement::Use(_) => Ok(()),
719 CqlStatement::Truncate(_) => Ok(()),
720 CqlStatement::Batch(_) => Ok(()),
721 }
722 }
723
724 fn visit_select(&mut self, select: &CqlSelect) -> Result<()> {
725 self.check_table_exists(&select.from);
726 DefaultVisitor.visit_select(select)
727 }
728
729 fn visit_insert(&mut self, insert: &CqlInsert) -> Result<()> {
730 let table_name = insert.table.full_name();
731 if self.context.schemas.contains_key(&table_name) {
732 if let CqlInsertValues::Values(values) = &insert.values {
733 if insert.columns.len() != values.len() {
734 self.add_error(format!(
735 "Column count ({}) does not match value count ({})",
736 insert.columns.len(),
737 values.len()
738 ));
739 }
740 }
741 } else if self.is_strict() {
742 self.add_error(format!("Table '{}' does not exist", table_name));
743 }
744
745 DefaultVisitor.visit_insert(insert)
746 }
747
748 fn visit_update(&mut self, update: &CqlUpdate) -> Result<()> {
749 self.check_table_exists(&update.table);
750 DefaultVisitor.visit_update(update)
751 }
752
753 fn visit_delete(&mut self, delete: &CqlDelete) -> Result<()> {
754 self.check_table_exists(&delete.table);
755 DefaultVisitor.visit_delete(delete)
756 }
757
758 fn visit_create_table(&mut self, create: &CqlCreateTable) -> Result<()> {
759 let mut column_names = std::collections::HashSet::new();
760 for column in &create.columns {
761 let name = column.name.as_str();
762 if !column_names.insert(name) {
763 self.add_error(format!("Duplicate column name: '{}'", name));
764 }
765 }
766
767 for pk_column in &create.primary_key.partition_key {
768 let name = pk_column.as_str();
769 if !create.columns.iter().any(|c| c.name.as_str() == name) {
770 self.add_error(format!(
771 "Partition key column '{}' not found in column definitions",
772 name
773 ));
774 }
775 }
776
777 for ck_column in &create.primary_key.clustering_key {
778 let name = ck_column.as_str();
779 if !create.columns.iter().any(|c| c.name.as_str() == name) {
780 self.add_error(format!(
781 "Clustering key column '{}' not found in column definitions",
782 name
783 ));
784 }
785 }
786
787 DefaultVisitor.visit_create_table(create)
788 }
789
790 fn visit_drop_table(&mut self, drop: &CqlDropTable) -> Result<()> {
791 if !drop.if_exists {
792 self.check_table_exists(&drop.table);
793 }
794 DefaultVisitor.visit_drop_table(drop)
795 }
796
797 fn visit_create_index(&mut self, create: &CqlCreateIndex) -> Result<()> {
798 self.check_table_exists(&create.table);
799 DefaultVisitor.visit_create_index(create)
800 }
801
802 fn visit_alter_table(&mut self, alter: &CqlAlterTable) -> Result<()> {
803 self.check_table_exists(&alter.table);
804 DefaultVisitor.visit_alter_table(alter)
805 }
806
807 fn visit_data_type(&mut self, data_type: &CqlDataType) -> Result<()> {
808 if let CqlDataType::Udt(udt_name) = data_type {
809 let udt_key = udt_name.as_str();
810 if !self.context.udts.contains_key(udt_key) && self.is_strict() {
811 self.add_error(format!("UDT '{}' does not exist", udt_key));
812 }
813 }
814
815 DefaultVisitor.visit_data_type(data_type)
816 }
817
818 fn visit_expression(&mut self, expression: &CqlExpression) -> Result<()> {
819 DefaultVisitor.visit_expression(expression)
820 }
821
822 fn visit_identifier(&mut self, _identifier: &CqlIdentifier) -> Result<()> {
823 Ok(())
824 }
825
826 fn visit_literal(&mut self, literal: &CqlLiteral) -> Result<()> {
827 DefaultVisitor.visit_literal(literal)
828 }
829}
830
831pub type TransformationFn = Box<dyn Fn(&CqlStatement) -> Option<CqlStatement>>;
833
834pub struct AstTransformer {
836 pub transformations: Vec<TransformationFn>,
838}
839
840impl std::fmt::Debug for AstTransformer {
841 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
842 f.debug_struct("AstTransformer")
843 .field(
844 "transformations",
845 &format!("[{} transformations]", self.transformations.len()),
846 )
847 .finish()
848 }
849}
850
851impl AstTransformer {
852 pub fn new() -> Self {
854 Self {
855 transformations: Vec::new(),
856 }
857 }
858
859 pub fn add_transformation<F>(&mut self, transform: F)
861 where
862 F: Fn(&CqlStatement) -> Option<CqlStatement> + 'static,
863 {
864 self.transformations.push(Box::new(transform));
865 }
866
867 pub fn transform(&self, statement: &CqlStatement) -> CqlStatement {
869 let mut result = statement.clone();
870
871 for transformation in &self.transformations {
872 if let Some(transformed) = transformation(&result) {
873 result = transformed;
874 }
875 }
876
877 result
878 }
879}
880
881impl Default for AstTransformer {
882 fn default() -> Self {
883 Self::new()
884 }
885}
886
887pub mod utils {
889 use super::*;
890
891 pub fn collect_table_references(statement: &CqlStatement) -> Vec<String> {
893 let table = match statement {
894 CqlStatement::Select(select) => &select.from,
895 CqlStatement::Insert(insert) => &insert.table,
896 CqlStatement::Update(update) => &update.table,
897 CqlStatement::Delete(delete) => &delete.table,
898 CqlStatement::CreateTable(create) => &create.table,
899 CqlStatement::DropTable(drop) => &drop.table,
900 CqlStatement::CreateIndex(create) => &create.table,
901 CqlStatement::AlterTable(alter) => &alter.table,
902 CqlStatement::Truncate(truncate) => &truncate.table,
903 CqlStatement::CreateType(_)
904 | CqlStatement::DropType(_)
905 | CqlStatement::Use(_)
906 | CqlStatement::Batch(_) => return Vec::new(),
907 };
908 vec![table.full_name()]
909 }
910
911 pub fn is_modifying_statement(statement: &CqlStatement) -> bool {
913 matches!(
914 statement,
915 CqlStatement::Insert(_)
916 | CqlStatement::Update(_)
917 | CqlStatement::Delete(_)
918 | CqlStatement::CreateTable(_)
919 | CqlStatement::DropTable(_)
920 | CqlStatement::CreateIndex(_)
921 | CqlStatement::AlterTable(_)
922 )
923 }
924
925 pub fn is_query_statement(statement: &CqlStatement) -> bool {
927 matches!(statement, CqlStatement::Select(_))
928 }
929
930 pub fn is_schema_statement(statement: &CqlStatement) -> bool {
932 matches!(
933 statement,
934 CqlStatement::CreateTable(_)
935 | CqlStatement::DropTable(_)
936 | CqlStatement::CreateIndex(_)
937 | CqlStatement::AlterTable(_)
938 )
939 }
940}
941
/// Visitor that turns a `CREATE TABLE` statement into a `TableSchema`.
///
/// Any other statement or node kind yields an `invalid_input` error.
#[derive(Default, Debug)]
pub struct SchemaBuilderVisitor;
948
949fn schema_builder_unsupported(kind: &str) -> Error {
951 Error::invalid_input(format!("SchemaBuilderVisitor {}", kind))
952}
953
954fn column_def_for<'a>(
956 create: &'a CqlCreateTable,
957 key: &CqlIdentifier,
958 role: &str,
959) -> Result<&'a CqlColumnDef> {
960 create
961 .columns
962 .iter()
963 .find(|col| col.name.as_str() == key.as_str())
964 .ok_or_else(|| {
965 Error::invalid_input(format!(
966 "{} key column '{}' not found in column definitions",
967 role,
968 key.as_str()
969 ))
970 })
971}
972
973impl CqlVisitor<TableSchema> for SchemaBuilderVisitor {
974 fn visit_statement(&mut self, statement: &CqlStatement) -> Result<TableSchema> {
975 match statement {
976 CqlStatement::CreateTable(create) => self.visit_create_table(create),
977 _ => Err(schema_builder_unsupported(
978 "only supports CREATE TABLE statements",
979 )),
980 }
981 }
982
983 fn visit_create_table(&mut self, create: &CqlCreateTable) -> Result<TableSchema> {
984 let table_name = create.table.name.as_str().to_string();
985 let keyspace = create
986 .table
987 .keyspace
988 .as_ref()
989 .map(|ks| ks.as_str().to_string())
990 .unwrap_or_else(|| "default".to_string());
991
992 let partition_keys = create
993 .primary_key
994 .partition_key
995 .iter()
996 .enumerate()
997 .map(|(pos, pk_col)| {
998 let column_def = column_def_for(create, pk_col, "Partition")?;
999 Ok(KeyColumn {
1000 name: pk_col.as_str().to_string(),
1001 data_type: cql_data_type_to_string(&column_def.data_type),
1002 position: pos,
1003 })
1004 })
1005 .collect::<Result<Vec<_>>>()?;
1006
1007 let clustering_keys = create
1008 .primary_key
1009 .clustering_key
1010 .iter()
1011 .enumerate()
1012 .map(|(pos, ck_col)| {
1013 let column_def = column_def_for(create, ck_col, "Clustering")?;
1014 Ok(ClusteringColumn {
1015 name: ck_col.as_str().to_string(),
1016 data_type: cql_data_type_to_string(&column_def.data_type),
1017 position: pos,
1018 order: crate::schema::ClusteringOrder::Asc,
1019 })
1020 })
1021 .collect::<Result<Vec<_>>>()?;
1022
1023 let columns: Vec<Column> = create
1024 .columns
1025 .iter()
1026 .map(|col_def| Column {
1027 name: col_def.name.as_str().to_string(),
1028 data_type: cql_data_type_to_string(&col_def.data_type),
1029 nullable: true,
1030 default: None,
1031 is_static: col_def.is_static,
1032 })
1033 .collect();
1034
1035 Ok(TableSchema {
1036 keyspace,
1037 table: table_name,
1038 partition_keys,
1039 clustering_keys,
1040 columns,
1041 comments: HashMap::new(),
1042 })
1043 }
1044
1045 fn visit_select(&mut self, _select: &CqlSelect) -> Result<TableSchema> {
1046 Err(schema_builder_unsupported(
1047 "does not support SELECT statements",
1048 ))
1049 }
1050
1051 fn visit_insert(&mut self, _insert: &CqlInsert) -> Result<TableSchema> {
1052 Err(schema_builder_unsupported(
1053 "does not support INSERT statements",
1054 ))
1055 }
1056
1057 fn visit_update(&mut self, _update: &CqlUpdate) -> Result<TableSchema> {
1058 Err(schema_builder_unsupported(
1059 "does not support UPDATE statements",
1060 ))
1061 }
1062
1063 fn visit_delete(&mut self, _delete: &CqlDelete) -> Result<TableSchema> {
1064 Err(schema_builder_unsupported(
1065 "does not support DELETE statements",
1066 ))
1067 }
1068
1069 fn visit_drop_table(&mut self, _drop: &CqlDropTable) -> Result<TableSchema> {
1070 Err(schema_builder_unsupported(
1071 "does not support DROP TABLE statements",
1072 ))
1073 }
1074
1075 fn visit_create_index(&mut self, _create: &CqlCreateIndex) -> Result<TableSchema> {
1076 Err(schema_builder_unsupported(
1077 "does not support CREATE INDEX statements",
1078 ))
1079 }
1080
1081 fn visit_alter_table(&mut self, _alter: &CqlAlterTable) -> Result<TableSchema> {
1082 Err(schema_builder_unsupported(
1083 "does not support ALTER TABLE statements",
1084 ))
1085 }
1086
1087 fn visit_data_type(&mut self, _data_type: &CqlDataType) -> Result<TableSchema> {
1088 Err(schema_builder_unsupported(
1089 "does not support standalone data types",
1090 ))
1091 }
1092
1093 fn visit_expression(&mut self, _expression: &CqlExpression) -> Result<TableSchema> {
1094 Err(schema_builder_unsupported("does not support expressions"))
1095 }
1096
1097 fn visit_identifier(&mut self, _identifier: &CqlIdentifier) -> Result<TableSchema> {
1098 Err(schema_builder_unsupported("does not support identifiers"))
1099 }
1100
1101 fn visit_literal(&mut self, _literal: &CqlLiteral) -> Result<TableSchema> {
1102 Err(schema_builder_unsupported("does not support literals"))
1103 }
1104}
1105
1106impl SchemaBuilderVisitor {
1107 pub fn new() -> Self {
1109 Self
1110 }
1111}
1112
/// Structural validator that needs no schema context. It accumulates error
/// messages instead of failing fast.
#[derive(Default, Debug)]
pub struct ValidationVisitor {
    /// Validation messages recorded so far, in discovery order.
    pub errors: Vec<String>,
}
1120
1121impl ValidationVisitor {
1122 pub fn new() -> Self {
1123 Self { errors: Vec::new() }
1124 }
1125
1126 pub fn has_errors(&self) -> bool {
1127 !self.errors.is_empty()
1128 }
1129
1130 pub fn get_errors(&self) -> &[String] {
1131 &self.errors
1132 }
1133
1134 fn add_error(&mut self, error: String) {
1135 self.errors.push(error);
1136 }
1137}
1138
1139impl CqlVisitor<()> for ValidationVisitor {
1140 fn visit_statement(&mut self, statement: &CqlStatement) -> Result<()> {
1141 match statement {
1142 CqlStatement::CreateTable(create) => self.visit_create_table(create),
1143 CqlStatement::Select(select) => self.visit_select(select),
1144 CqlStatement::Insert(insert) => self.visit_insert(insert),
1145 CqlStatement::Update(update) => self.visit_update(update),
1146 CqlStatement::Delete(delete) => self.visit_delete(delete),
1147 CqlStatement::DropTable(drop) => self.visit_drop_table(drop),
1148 CqlStatement::CreateIndex(create) => self.visit_create_index(create),
1149 CqlStatement::AlterTable(alter) => self.visit_alter_table(alter),
1150 _ => Ok(()), }
1152 }
1153
1154 fn visit_create_table(&mut self, create: &CqlCreateTable) -> Result<()> {
1155 if create.table.name.as_str().is_empty() {
1156 self.add_error("Table name cannot be empty".to_string());
1157 }
1158
1159 for pk_col in &create.primary_key.partition_key {
1160 if !create
1161 .columns
1162 .iter()
1163 .any(|col| col.name.as_str() == pk_col.as_str())
1164 {
1165 self.add_error(format!(
1166 "Partition key column '{}' not found in column definitions",
1167 pk_col.as_str()
1168 ));
1169 }
1170 }
1171
1172 for ck_col in &create.primary_key.clustering_key {
1173 if !create
1174 .columns
1175 .iter()
1176 .any(|col| col.name.as_str() == ck_col.as_str())
1177 {
1178 self.add_error(format!(
1179 "Clustering key column '{}' not found in column definitions",
1180 ck_col.as_str()
1181 ));
1182 }
1183 }
1184
1185 let mut column_names = std::collections::HashSet::new();
1186 for column in &create.columns {
1187 let name = column.name.as_str();
1188 if !column_names.insert(name) {
1189 self.add_error(format!("Duplicate column name: '{}'", name));
1190 }
1191 }
1192
1193 if create.primary_key.partition_key.is_empty() {
1194 self.add_error("Table must have at least one partition key column".to_string());
1195 }
1196
1197 Ok(())
1198 }
1199
1200 fn visit_select(&mut self, _select: &CqlSelect) -> Result<()> {
1201 Ok(())
1202 }
1203
1204 fn visit_insert(&mut self, _insert: &CqlInsert) -> Result<()> {
1205 Ok(())
1206 }
1207
1208 fn visit_update(&mut self, _update: &CqlUpdate) -> Result<()> {
1209 Ok(())
1210 }
1211
1212 fn visit_delete(&mut self, _delete: &CqlDelete) -> Result<()> {
1213 Ok(())
1214 }
1215
1216 fn visit_drop_table(&mut self, drop: &CqlDropTable) -> Result<()> {
1217 if drop.table.name.as_str().is_empty() {
1218 self.add_error("Table name cannot be empty".to_string());
1219 }
1220 Ok(())
1221 }
1222
1223 fn visit_create_index(&mut self, _create: &CqlCreateIndex) -> Result<()> {
1224 Ok(())
1225 }
1226
1227 fn visit_alter_table(&mut self, _alter: &CqlAlterTable) -> Result<()> {
1228 Ok(())
1229 }
1230
1231 fn visit_data_type(&mut self, _data_type: &CqlDataType) -> Result<()> {
1232 Ok(())
1233 }
1234
1235 fn visit_expression(&mut self, _expression: &CqlExpression) -> Result<()> {
1236 Ok(())
1237 }
1238
1239 fn visit_identifier(&mut self, _identifier: &CqlIdentifier) -> Result<()> {
1240 Ok(())
1241 }
1242
1243 fn visit_literal(&mut self, _literal: &CqlLiteral) -> Result<()> {
1244 Ok(())
1245 }
1246}
1247
1248#[derive(Debug, Default)]
1252pub struct TypeCollectorVisitor {
1253 pub types: Vec<CqlDataType>,
1254}
1255
1256impl TypeCollectorVisitor {
1257 pub fn new() -> Self {
1258 Self { types: Vec::new() }
1259 }
1260
1261 pub fn into_types(self) -> Vec<CqlDataType> {
1262 self.types
1263 }
1264
1265 fn collect_type(&mut self, data_type: &CqlDataType) {
1266 self.types.push(data_type.clone());
1267
1268 match data_type {
1269 CqlDataType::List(inner) | CqlDataType::Set(inner) | CqlDataType::Frozen(inner) => {
1270 self.collect_type(inner);
1271 }
1272 CqlDataType::Map(key, value) => {
1273 self.collect_type(key);
1274 self.collect_type(value);
1275 }
1276 CqlDataType::Tuple(types) => {
1277 for t in types {
1278 self.collect_type(t);
1279 }
1280 }
1281 _ => {}
1282 }
1283 }
1284}
1285
1286impl CqlVisitor<()> for TypeCollectorVisitor {
1287 fn visit_statement(&mut self, statement: &CqlStatement) -> Result<()> {
1288 match statement {
1289 CqlStatement::CreateTable(create) => self.visit_create_table(create),
1290 _ => Ok(()),
1291 }
1292 }
1293
1294 fn visit_create_table(&mut self, create: &CqlCreateTable) -> Result<()> {
1295 for column in &create.columns {
1296 self.collect_type(&column.data_type);
1297 }
1298 Ok(())
1299 }
1300
1301 fn visit_select(&mut self, _select: &CqlSelect) -> Result<()> {
1302 Ok(())
1303 }
1304
1305 fn visit_insert(&mut self, _insert: &CqlInsert) -> Result<()> {
1306 Ok(())
1307 }
1308
1309 fn visit_update(&mut self, _update: &CqlUpdate) -> Result<()> {
1310 Ok(())
1311 }
1312
1313 fn visit_delete(&mut self, _delete: &CqlDelete) -> Result<()> {
1314 Ok(())
1315 }
1316
1317 fn visit_drop_table(&mut self, _drop: &CqlDropTable) -> Result<()> {
1318 Ok(())
1319 }
1320
1321 fn visit_create_index(&mut self, _create: &CqlCreateIndex) -> Result<()> {
1322 Ok(())
1323 }
1324
1325 fn visit_alter_table(&mut self, alter: &CqlAlterTable) -> Result<()> {
1326 match &alter.operation {
1327 CqlAlterTableOp::AddColumn(column_def) => {
1328 self.collect_type(&column_def.data_type);
1329 }
1330 CqlAlterTableOp::AlterColumn { new_type, .. } => {
1331 self.collect_type(new_type);
1332 }
1333 _ => {}
1334 }
1335 Ok(())
1336 }
1337
1338 fn visit_data_type(&mut self, data_type: &CqlDataType) -> Result<()> {
1339 self.collect_type(data_type);
1340 Ok(())
1341 }
1342
1343 fn visit_expression(&mut self, _expression: &CqlExpression) -> Result<()> {
1344 Ok(())
1345 }
1346
1347 fn visit_identifier(&mut self, _identifier: &CqlIdentifier) -> Result<()> {
1348 Ok(())
1349 }
1350
1351 fn visit_literal(&mut self, _literal: &CqlLiteral) -> Result<()> {
1352 Ok(())
1353 }
1354}
1355
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_identifier_collector() {
        let statement = CqlStatement::Select(CqlSelect {
            distinct: false,
            select_list: vec![
                CqlSelectItem::Expression {
                    expression: CqlExpression::Column(CqlIdentifier::new("id")),
                    alias: None,
                },
                CqlSelectItem::Expression {
                    expression: CqlExpression::Column(CqlIdentifier::new("name")),
                    alias: None,
                },
            ],
            from: CqlTable::new("users"),
            where_clause: Some(CqlExpression::Binary {
                left: Box::new(CqlExpression::Column(CqlIdentifier::new("id"))),
                operator: CqlBinaryOperator::Eq,
                right: Box::new(CqlExpression::Parameter(1)),
            }),
            order_by: None,
            limit: None,
            allow_filtering: false,
        });

        let mut collector = IdentifierCollector::default();
        collector.visit_statement(&statement).unwrap();

        // Expected order: select-list columns, the table name, then the
        // WHERE-clause column.
        assert_eq!(collector.identifiers.len(), 4);
        assert_eq!(collector.identifiers[0].as_str(), "id");
        assert_eq!(collector.identifiers[1].as_str(), "name");
        assert_eq!(collector.identifiers[2].as_str(), "users");
        assert_eq!(collector.identifiers[3].as_str(), "id");
    }

    #[test]
    fn test_semantic_validator() {
        // Two columns are listed but only one value is supplied, so the
        // statement should be reported as invalid.
        let statement = CqlStatement::Insert(CqlInsert {
            table: CqlTable::new("users"),
            columns: vec![CqlIdentifier::new("id"), CqlIdentifier::new("name")],
            values: CqlInsertValues::Values(vec![CqlExpression::Parameter(1)]),
            if_not_exists: false,
            using: None,
        });

        let context = ValidationContext::new();
        let mut validator = SemanticValidator::new(context);
        validator.visit_statement(&statement).unwrap();

        assert!(!validator.is_valid());
        assert!(!validator.get_errors().is_empty());
    }

    #[test]
    fn test_utils() {
        let statement = CqlStatement::Select(CqlSelect {
            distinct: false,
            select_list: vec![CqlSelectItem::Wildcard],
            from: CqlTable::with_keyspace("test", "users"),
            where_clause: None,
            order_by: None,
            limit: None,
            allow_filtering: false,
        });

        let tables = utils::collect_table_references(&statement);
        assert_eq!(tables, vec!["test.users"]);

        assert!(utils::is_query_statement(&statement));
        assert!(!utils::is_modifying_statement(&statement));
        assert!(!utils::is_schema_statement(&statement));
    }

    #[test]
    fn test_schema_builder_visitor() {
        let create_table = CqlCreateTable {
            if_not_exists: false,
            table: CqlTable::with_keyspace("test_keyspace", "users"),
            columns: vec![
                CqlColumnDef {
                    name: CqlIdentifier::new("id"),
                    data_type: CqlDataType::Uuid,
                    is_static: false,
                },
                CqlColumnDef {
                    name: CqlIdentifier::new("name"),
                    data_type: CqlDataType::Text,
                    is_static: false,
                },
                CqlColumnDef {
                    name: CqlIdentifier::new("age"),
                    data_type: CqlDataType::Int,
                    is_static: false,
                },
                CqlColumnDef {
                    name: CqlIdentifier::new("tags"),
                    data_type: CqlDataType::List(Box::new(CqlDataType::Text)),
                    is_static: false,
                },
            ],
            primary_key: CqlPrimaryKey {
                partition_key: vec![CqlIdentifier::new("id")],
                clustering_key: vec![CqlIdentifier::new("name")],
            },
            options: CqlTableOptions {
                options: HashMap::new(),
            },
        };

        let statement = CqlStatement::CreateTable(create_table);
        let mut visitor = SchemaBuilderVisitor;
        let schema = visitor.visit_statement(&statement).unwrap();

        assert_eq!(schema.keyspace, "test_keyspace");
        assert_eq!(schema.table, "users");
        assert_eq!(schema.partition_keys.len(), 1);
        assert_eq!(schema.partition_keys[0].name, "id");
        assert_eq!(schema.partition_keys[0].data_type, "uuid");
        assert_eq!(schema.clustering_keys.len(), 1);
        assert_eq!(schema.clustering_keys[0].name, "name");
        assert_eq!(schema.clustering_keys[0].data_type, "text");
        assert_eq!(schema.columns.len(), 4);

        let tags_column = schema
            .columns
            .iter()
            .find(|col| col.name == "tags")
            .expect("tags column should exist");
        assert_eq!(tags_column.data_type, "list<text>");
    }

    #[test]
    fn test_validation_visitor() {
        let create_table = CqlCreateTable {
            if_not_exists: false,
            table: CqlTable::new("test_table"),
            columns: vec![
                CqlColumnDef {
                    name: CqlIdentifier::new("id"),
                    data_type: CqlDataType::Uuid,
                    is_static: false,
                },
                CqlColumnDef {
                    name: CqlIdentifier::new("name"),
                    data_type: CqlDataType::Text,
                    is_static: false,
                },
                // Deliberate duplicate of "name".
                CqlColumnDef {
                    name: CqlIdentifier::new("name"),
                    data_type: CqlDataType::Int,
                    is_static: false,
                },
            ],
            primary_key: CqlPrimaryKey {
                // Deliberately references a column that is not defined.
                partition_key: vec![CqlIdentifier::new("missing_column")],
                clustering_key: vec![],
            },
            options: CqlTableOptions {
                options: HashMap::new(),
            },
        };

        let statement = CqlStatement::CreateTable(create_table);
        let mut visitor = ValidationVisitor::new();
        // Validation problems are collected, not returned, so the walk itself
        // must still succeed.
        visitor.visit_statement(&statement).unwrap();

        assert!(visitor.has_errors());
        let errors = visitor.get_errors();
        assert!(errors.iter().any(|e| e.contains("Duplicate column name")));
        assert!(errors
            .iter()
            .any(|e| e.contains("not found in column definitions")));
    }

    #[test]
    fn test_validation_visitor_accepts_valid_table() {
        let create_table = CqlCreateTable {
            if_not_exists: false,
            table: CqlTable::new("events"),
            columns: vec![
                CqlColumnDef {
                    name: CqlIdentifier::new("id"),
                    data_type: CqlDataType::Uuid,
                    is_static: false,
                },
                CqlColumnDef {
                    name: CqlIdentifier::new("ts"),
                    data_type: CqlDataType::Timestamp,
                    is_static: false,
                },
            ],
            primary_key: CqlPrimaryKey {
                partition_key: vec![CqlIdentifier::new("id")],
                clustering_key: vec![CqlIdentifier::new("ts")],
            },
            options: CqlTableOptions {
                options: HashMap::new(),
            },
        };

        let statement = CqlStatement::CreateTable(create_table);
        let mut visitor = ValidationVisitor::new();
        visitor.visit_statement(&statement).unwrap();

        assert!(
            !visitor.has_errors(),
            "unexpected errors: {:?}",
            visitor.get_errors()
        );
    }

    #[test]
    fn test_type_collector_visitor() {
        let create_table = CqlCreateTable {
            if_not_exists: false,
            table: CqlTable::new("test_table"),
            columns: vec![
                CqlColumnDef {
                    name: CqlIdentifier::new("simple"),
                    data_type: CqlDataType::Text,
                    is_static: false,
                },
                CqlColumnDef {
                    name: CqlIdentifier::new("list_col"),
                    data_type: CqlDataType::List(Box::new(CqlDataType::Int)),
                    is_static: false,
                },
                CqlColumnDef {
                    name: CqlIdentifier::new("map_col"),
                    data_type: CqlDataType::Map(
                        Box::new(CqlDataType::Text),
                        Box::new(CqlDataType::Uuid),
                    ),
                    is_static: false,
                },
                CqlColumnDef {
                    name: CqlIdentifier::new("frozen_col"),
                    data_type: CqlDataType::Frozen(Box::new(CqlDataType::Set(Box::new(
                        CqlDataType::BigInt,
                    )))),
                    is_static: false,
                },
            ],
            primary_key: CqlPrimaryKey {
                partition_key: vec![CqlIdentifier::new("simple")],
                clustering_key: vec![],
            },
            options: CqlTableOptions {
                options: HashMap::new(),
            },
        };

        let statement = CqlStatement::CreateTable(create_table);
        let mut visitor = TypeCollectorVisitor::new();
        visitor.visit_statement(&statement).unwrap();

        let types = visitor.into_types();

        // text | list<int>, int | map<text, uuid>, text, uuid |
        // frozen<set<bigint>>, set<bigint>, bigint
        assert_eq!(types.len(), 9);
        assert!(types.iter().any(|t| matches!(t, CqlDataType::Text)));
        assert!(types.iter().any(|t| matches!(t, CqlDataType::List(_))));
        assert!(types.iter().any(|t| matches!(t, CqlDataType::Int)));
        assert!(types.iter().any(|t| matches!(t, CqlDataType::Map(_, _))));
        assert!(types.iter().any(|t| matches!(t, CqlDataType::Uuid)));
        assert!(types.iter().any(|t| matches!(t, CqlDataType::Frozen(_))));
        assert!(types.iter().any(|t| matches!(t, CqlDataType::Set(_))));
        assert!(types.iter().any(|t| matches!(t, CqlDataType::BigInt)));
    }

    #[test]
    fn test_default_visitor_handles_batch_and_truncate() {
        let insert = CqlInsert {
            table: CqlTable::new("users"),
            columns: vec![CqlIdentifier::new("id"), CqlIdentifier::new("name")],
            values: CqlInsertValues::Values(vec![
                CqlExpression::Literal(CqlLiteral::Integer(1)),
                CqlExpression::Literal(CqlLiteral::String("alice".to_string())),
            ]),
            if_not_exists: false,
            using: None,
        };

        let batch = CqlStatement::Batch(CqlBatch {
            batch_type: CqlBatchType::Logged,
            using: None,
            statements: vec![CqlBatchStatement::Insert(insert)],
        });

        let truncate = CqlStatement::Truncate(CqlTruncate {
            table: CqlTable::new("users"),
        });

        let mut visitor = DefaultVisitor;
        let _: () = visitor.visit_statement(&batch).unwrap();
        let _: () = visitor.visit_statement(&truncate).unwrap();
    }
}