1use crate::ast_transforms::get_aggregate_functions;
8use crate::dialects::{Dialect, DialectType};
9use crate::error::{ValidationError, ValidationResult};
10use crate::expressions::{
11 Column, DataType, Expression, Function, Insert, JoinKind, TableRef, Update,
12};
13use crate::function_registry::canonical_typed_function_name_upper;
14use crate::optimizer::annotate_types;
15use crate::resolver::Resolver;
16use crate::schema::{MappingSchema, Schema as SqlSchema, SchemaError, SchemaResult, TABLE_PARTS};
17use crate::scope::build_scope;
18use crate::traversal::ExpressionWalk;
19use serde::{Deserialize, Serialize};
20use std::collections::{HashMap, HashSet};
21
/// A single column declared in a [`SchemaTable`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SchemaColumn {
    /// Column name. The validator lowercases it before every lookup, so
    /// matching is case-insensitive.
    pub name: String,
    /// Declared SQL type as free-form text (JSON key `type`), e.g.
    /// `"VARCHAR(255)"`. Mapped to a [`TypeFamily`] by `canonical_type_family`;
    /// defaults to an empty string (which maps to `TypeFamily::Unknown`).
    #[serde(default, rename = "type")]
    pub data_type: String,
    /// Whether the column accepts NULL; `None` when not specified.
    #[serde(default)]
    pub nullable: Option<bool>,
    /// Column-level primary-key flag (JSON key `primaryKey`). Reference checks
    /// treat a flagged column as a key of its table.
    #[serde(default, rename = "primaryKey")]
    pub primary_key: bool,
    /// Column-level unique flag. Reference checks treat a flagged column as a
    /// key of its table.
    #[serde(default)]
    pub unique: bool,
    /// Optional column-level foreign key to another table's column.
    #[serde(default)]
    pub references: Option<SchemaColumnReference>,
}
43
/// Target of a column-level foreign key ([`SchemaColumn::references`]).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SchemaColumnReference {
    /// Referenced table name. It may itself be dotted (`schema.table`).
    pub table: String,
    /// Referenced column name (matched case-insensitively).
    pub column: String,
    /// Optional schema of the referenced table. When absent, resolution tries
    /// the source table's schema and then the bare table name.
    #[serde(default)]
    pub schema: Option<String>,
}
55
/// A table-level (possibly multi-column) foreign key declared on a [`SchemaTable`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SchemaForeignKey {
    /// Optional constraint name.
    #[serde(default)]
    pub name: Option<String>,
    /// Source columns of the owning table. They are paired positionally with
    /// `references.columns`.
    pub columns: Vec<String>,
    /// Referenced table and columns.
    pub references: SchemaTableReference,
}
67
/// Target side of a table-level [`SchemaForeignKey`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SchemaTableReference {
    /// Referenced table name. It may itself be dotted (`schema.table`).
    pub table: String,
    /// Referenced columns. They must have the same length as the foreign key's
    /// source columns.
    pub columns: Vec<String>,
    /// Optional schema of the referenced table.
    #[serde(default)]
    pub schema: Option<String>,
}
79
/// A table known to the validator.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SchemaTable {
    /// Table name (matched case-insensitively).
    pub name: String,
    /// Optional schema. When set, the table is also reachable as `schema.name`.
    #[serde(default)]
    pub schema: Option<String>,
    /// Column definitions in declaration order.
    pub columns: Vec<SchemaColumn>,
    /// Additional names under which the table can be looked up.
    #[serde(default)]
    pub aliases: Vec<String>,
    /// Table-level primary key column list (JSON key `primaryKey`).
    #[serde(default, rename = "primaryKey")]
    pub primary_key: Vec<String>,
    /// Unique constraints (JSON key `uniqueKeys`); each inner list is one
    /// constraint's column list.
    #[serde(default, rename = "uniqueKeys")]
    pub unique_keys: Vec<Vec<String>>,
    /// Table-level foreign keys (JSON key `foreignKeys`).
    #[serde(default, rename = "foreignKeys")]
    pub foreign_keys: Vec<SchemaForeignKey>,
}
103
/// Root schema document supplied to schema-aware validation.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ValidationSchema {
    /// All tables known to the validator.
    pub tables: Vec<SchemaTable>,
    /// Optional schema-level strictness preference; `None` when unspecified.
    /// NOTE(review): how this combines with `SchemaValidationOptions::strict`
    /// is decided outside this chunk — confirm at the call site.
    #[serde(default)]
    pub strict: Option<bool>,
}
113
114#[derive(Debug, Clone, Serialize, Deserialize, Default)]
116pub struct SchemaValidationOptions {
117 #[serde(default)]
119 pub check_types: bool,
120 #[serde(default)]
122 pub check_references: bool,
123 #[serde(default)]
125 pub strict: Option<bool>,
126 #[serde(default)]
128 pub semantic: bool,
129 #[serde(default)]
131 pub strict_syntax: bool,
132}
133
/// Stable diagnostic codes attached to validation results.
///
/// `E…` codes are used with `ValidationError::error` and `W…` codes with
/// `ValidationError::warning`. Several checks pick between an `E` and a `W`
/// code depending on strictness (see `type_issue` / `reference_issue`).
pub mod validation_codes {
    // Parse failures or invalid options.
    pub const E_PARSE_OR_OPTIONS: &str = "E000";
    // Schema name resolution.
    pub const E_UNKNOWN_TABLE: &str = "E200";
    pub const E_UNKNOWN_COLUMN: &str = "E201";

    // General query lint warnings.
    pub const W_SELECT_STAR: &str = "W001";
    pub const W_AGGREGATE_WITHOUT_GROUP_BY: &str = "W002";
    pub const W_DISTINCT_ORDER_BY: &str = "W003";
    pub const W_LIMIT_WITHOUT_ORDER_BY: &str = "W004";

    // Type-checking errors (E210–E219).
    pub const E_TYPE_MISMATCH: &str = "E210";
    pub const E_INVALID_PREDICATE_TYPE: &str = "E211";
    pub const E_INVALID_ARITHMETIC_TYPE: &str = "E212";
    pub const E_INVALID_FUNCTION_ARGUMENT_TYPE: &str = "E213";
    pub const E_INVALID_ASSIGNMENT_TYPE: &str = "E214";
    pub const E_SETOP_TYPE_MISMATCH: &str = "E215";
    pub const E_SETOP_ARITY_MISMATCH: &str = "E216";
    pub const E_INCOMPATIBLE_COMPARISON_TYPES: &str = "E217";
    pub const E_INVALID_CAST: &str = "E218";
    pub const E_UNKNOWN_INFERRED_TYPE: &str = "E219";

    // Type-checking warnings (W210–W219).
    pub const W_IMPLICIT_CAST_COMPARISON: &str = "W210";
    pub const W_IMPLICIT_CAST_ARITHMETIC: &str = "W211";
    pub const W_IMPLICIT_CAST_ASSIGNMENT: &str = "W212";
    pub const W_LOSSY_CAST: &str = "W213";
    pub const W_SETOP_IMPLICIT_COERCION: &str = "W214";
    pub const W_PREDICATE_NULLABILITY: &str = "W215";
    pub const W_FUNCTION_ARGUMENT_COERCION: &str = "W216";
    pub const W_AGGREGATE_TYPE_COERCION: &str = "W217";
    pub const W_POSSIBLE_OVERFLOW: &str = "W218";
    pub const W_POSSIBLE_TRUNCATION: &str = "W219";

    // Reference / integrity errors (E220–E224).
    pub const E_INVALID_FOREIGN_KEY_REFERENCE: &str = "E220";
    pub const E_AMBIGUOUS_COLUMN_REFERENCE: &str = "E221";
    pub const E_UNRESOLVED_REFERENCE: &str = "E222";
    pub const E_CTE_COLUMN_COUNT_MISMATCH: &str = "E223";
    pub const E_MISSING_REFERENCE_TARGET: &str = "E224";

    // Join / reference warnings (W220–W222).
    pub const W_CARTESIAN_JOIN: &str = "W220";
    pub const W_JOIN_NOT_USING_DECLARED_REFERENCE: &str = "W221";
    pub const W_WEAK_REFERENCE_INTEGRITY: &str = "W222";
}
180
/// Coarse type category used by schema validation and type checking.
///
/// Serialized in snake_case (e.g. `"timestamp"`, `"unknown"`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum TypeFamily {
    /// The type could not be determined. `key_types_compatible` treats it as
    /// compatible with anything.
    Unknown,
    Boolean,
    /// Whole-number types (`INT`, `BIGINT`, `UINT64`, …).
    Integer,
    /// Non-integer numeric types (`DECIMAL`, `FLOAT`, `DOUBLE`, …).
    Numeric,
    String,
    Binary,
    Date,
    Time,
    Timestamp,
    Interval,
    Json,
    Uuid,
    Array,
    Map,
    Struct,
}
201
202impl TypeFamily {
203 pub fn is_numeric(self) -> bool {
204 matches!(self, TypeFamily::Integer | TypeFamily::Numeric)
205 }
206
207 pub fn is_temporal(self) -> bool {
208 matches!(
209 self,
210 TypeFamily::Date | TypeFamily::Time | TypeFamily::Timestamp | TypeFamily::Interval
211 )
212 }
213}
214
/// Per-table lookup data derived from a [`SchemaTable`] by `build_schema_map`.
#[derive(Debug, Clone)]
struct TableSchemaEntry {
    /// Lowercased column name -> canonical type family. When a name is
    /// declared twice, the last declaration wins.
    columns: HashMap<String, TypeFamily>,
    /// Lowercased column names in declaration order, duplicates included.
    column_order: Vec<String>,
}
220
/// Unicode-aware lowercase copy of `s`. Every schema and identifier lookup
/// key in this module is normalized with it.
fn lower(s: &str) -> String {
    str::to_lowercase(s)
}
224
/// Splits a `base(args)` type string into trimmed `(base, args)`.
///
/// The argument list opens at the first `(` and must close with the final
/// character `)`. Inputs that do not end in `)`, or contain no `(`, yield
/// `None`. For example, `"timestamp(3) with time zone"` is not split.
fn split_type_args(data_type: &str) -> Option<(&str, &str)> {
    let without_close = data_type.strip_suffix(')')?;
    let (base, inner) = without_close.split_once('(')?;
    Some((base.trim(), inner.trim()))
}
234
235pub fn canonical_type_family(data_type: &str) -> TypeFamily {
237 let trimmed = data_type
238 .trim()
239 .trim_matches(|c| c == '"' || c == '\'' || c == '`');
240 if trimmed.is_empty() {
241 return TypeFamily::Unknown;
242 }
243
244 let normalized = trimmed
246 .split_whitespace()
247 .collect::<Vec<_>>()
248 .join(" ")
249 .to_lowercase();
250
251 if let Some((base, inner)) = split_type_args(&normalized) {
253 match base {
254 "nullable" | "lowcardinality" => return canonical_type_family(inner),
255 "array" | "list" => return TypeFamily::Array,
256 "map" => return TypeFamily::Map,
257 "struct" | "row" | "record" => return TypeFamily::Struct,
258 _ => {}
259 }
260 }
261
262 if normalized.starts_with("array<") || normalized.starts_with("list<") {
263 return TypeFamily::Array;
264 }
265 if normalized.starts_with("map<") {
266 return TypeFamily::Map;
267 }
268 if normalized.starts_with("struct<")
269 || normalized.starts_with("row<")
270 || normalized.starts_with("record<")
271 || normalized.starts_with("object<")
272 {
273 return TypeFamily::Struct;
274 }
275
276 if normalized.ends_with("[]") {
277 return TypeFamily::Array;
278 }
279
280 let mut base = normalized
282 .split('(')
283 .next()
284 .unwrap_or("")
285 .trim()
286 .to_string();
287 if base.is_empty() {
288 return TypeFamily::Unknown;
289 }
290
291 base = base.strip_prefix("unsigned ").unwrap_or(&base).to_string();
292 base = base.strip_suffix(" unsigned").unwrap_or(&base).to_string();
293
294 match base.as_str() {
295 "bool" | "boolean" => TypeFamily::Boolean,
296 "tinyint" | "smallint" | "int2" | "int" | "integer" | "int4" | "int8" | "bigint"
297 | "serial" | "smallserial" | "bigserial" | "utinyint" | "usmallint" | "uinteger"
298 | "ubigint" | "uint8" | "uint16" | "uint32" | "uint64" | "int16" | "int32" | "int64" => {
299 TypeFamily::Integer
300 }
301 "numeric" | "decimal" | "dec" | "number" | "float" | "float4" | "float8" | "real"
302 | "double" | "double precision" | "bfloat16" | "float16" | "float32" | "float64" => {
303 TypeFamily::Numeric
304 }
305 "char" | "character" | "varchar" | "character varying" | "nchar" | "nvarchar" | "text"
306 | "string" | "clob" => TypeFamily::String,
307 "binary" | "varbinary" | "blob" | "bytea" | "bytes" => TypeFamily::Binary,
308 "date" => TypeFamily::Date,
309 "time" => TypeFamily::Time,
310 "timestamp"
311 | "timestamptz"
312 | "datetime"
313 | "datetime2"
314 | "smalldatetime"
315 | "timestamp with time zone"
316 | "timestamp without time zone" => TypeFamily::Timestamp,
317 "interval" => TypeFamily::Interval,
318 "json" | "jsonb" | "variant" => TypeFamily::Json,
319 "uuid" | "uniqueidentifier" => TypeFamily::Uuid,
320 "array" | "list" => TypeFamily::Array,
321 "map" => TypeFamily::Map,
322 "struct" | "row" | "record" | "object" => TypeFamily::Struct,
323 _ => TypeFamily::Unknown,
324 }
325}
326
327fn build_schema_map(schema: &ValidationSchema) -> HashMap<String, TableSchemaEntry> {
328 let mut map = HashMap::new();
329
330 for table in &schema.tables {
331 let column_order: Vec<String> = table.columns.iter().map(|c| lower(&c.name)).collect();
332 let columns: HashMap<String, TypeFamily> = table
333 .columns
334 .iter()
335 .map(|c| (lower(&c.name), canonical_type_family(&c.data_type)))
336 .collect();
337 let entry = TableSchemaEntry {
338 columns,
339 column_order,
340 };
341
342 let simple_name = lower(&table.name);
343 map.insert(simple_name, entry.clone());
344
345 if let Some(table_schema) = &table.schema {
346 map.insert(
347 format!("{}.{}", lower(table_schema), lower(&table.name)),
348 entry.clone(),
349 );
350 }
351
352 for alias in &table.aliases {
353 map.insert(lower(alias), entry.clone());
354 }
355 }
356
357 map
358}
359
360fn type_family_to_data_type(family: TypeFamily) -> DataType {
361 match family {
362 TypeFamily::Unknown => DataType::Unknown,
363 TypeFamily::Boolean => DataType::Boolean,
364 TypeFamily::Integer => DataType::Int {
365 length: None,
366 integer_spelling: false,
367 },
368 TypeFamily::Numeric => DataType::Double {
369 precision: None,
370 scale: None,
371 },
372 TypeFamily::String => DataType::VarChar {
373 length: None,
374 parenthesized_length: false,
375 },
376 TypeFamily::Binary => DataType::VarBinary { length: None },
377 TypeFamily::Date => DataType::Date,
378 TypeFamily::Time => DataType::Time {
379 precision: None,
380 timezone: false,
381 },
382 TypeFamily::Timestamp => DataType::Timestamp {
383 precision: None,
384 timezone: false,
385 },
386 TypeFamily::Interval => DataType::Interval {
387 unit: None,
388 to: None,
389 },
390 TypeFamily::Json => DataType::Json,
391 TypeFamily::Uuid => DataType::Uuid,
392 TypeFamily::Array => DataType::Array {
393 element_type: Box::new(DataType::Unknown),
394 dimension: None,
395 },
396 TypeFamily::Map => DataType::Map {
397 key_type: Box::new(DataType::Unknown),
398 value_type: Box::new(DataType::Unknown),
399 },
400 TypeFamily::Struct => DataType::Struct {
401 fields: Vec::new(),
402 nested: false,
403 },
404 }
405}
406
407fn build_resolver_schema(schema: &ValidationSchema) -> MappingSchema {
408 let mut mapping = MappingSchema::new();
409
410 for table in &schema.tables {
411 let columns: Vec<(String, DataType)> = table
412 .columns
413 .iter()
414 .map(|column| {
415 (
416 lower(&column.name),
417 type_family_to_data_type(canonical_type_family(&column.data_type)),
418 )
419 })
420 .collect();
421
422 let mut table_names = Vec::new();
423 table_names.push(lower(&table.name));
424 if let Some(table_schema) = &table.schema {
425 table_names.push(format!("{}.{}", lower(table_schema), lower(&table.name)));
426 }
427 for alias in &table.aliases {
428 table_names.push(lower(alias));
429 }
430
431 let mut dedup = HashSet::new();
432 for table_name in table_names {
433 if dedup.insert(table_name.clone()) {
434 let _ = mapping.add_table(&table_name, &columns, None);
435 }
436 }
437 }
438
439 mapping
440}
441
442fn collect_cte_aliases(expr: &Expression) -> HashSet<String> {
443 let mut aliases = HashSet::new();
444
445 for node in expr.dfs() {
446 match node {
447 Expression::Select(select) => {
448 if let Some(with) = &select.with {
449 for cte in &with.ctes {
450 aliases.insert(lower(&cte.alias.name));
451 }
452 }
453 }
454 Expression::Insert(insert) => {
455 if let Some(with) = &insert.with {
456 for cte in &with.ctes {
457 aliases.insert(lower(&cte.alias.name));
458 }
459 }
460 }
461 Expression::Update(update) => {
462 if let Some(with) = &update.with {
463 for cte in &with.ctes {
464 aliases.insert(lower(&cte.alias.name));
465 }
466 }
467 }
468 Expression::Delete(delete) => {
469 if let Some(with) = &delete.with {
470 for cte in &with.ctes {
471 aliases.insert(lower(&cte.alias.name));
472 }
473 }
474 }
475 Expression::Union(union) => {
476 if let Some(with) = &union.with {
477 for cte in &with.ctes {
478 aliases.insert(lower(&cte.alias.name));
479 }
480 }
481 }
482 Expression::Intersect(intersect) => {
483 if let Some(with) = &intersect.with {
484 for cte in &with.ctes {
485 aliases.insert(lower(&cte.alias.name));
486 }
487 }
488 }
489 Expression::Except(except) => {
490 if let Some(with) = &except.with {
491 for cte in &with.ctes {
492 aliases.insert(lower(&cte.alias.name));
493 }
494 }
495 }
496 Expression::Merge(merge) => {
497 if let Some(with_) = &merge.with_ {
498 if let Expression::With(with_clause) = with_.as_ref() {
499 for cte in &with_clause.ctes {
500 aliases.insert(lower(&cte.alias.name));
501 }
502 }
503 }
504 }
505 _ => {}
506 }
507 }
508
509 aliases
510}
511
512fn table_ref_candidates(table: &TableRef) -> Vec<String> {
513 let name = lower(&table.name.name);
514 let schema = table.schema.as_ref().map(|s| lower(&s.name));
515 let catalog = table.catalog.as_ref().map(|c| lower(&c.name));
516
517 let mut candidates = Vec::new();
518 if let (Some(catalog), Some(schema)) = (&catalog, &schema) {
519 candidates.push(format!("{}.{}.{}", catalog, schema, name));
520 }
521 if let Some(schema) = &schema {
522 candidates.push(format!("{}.{}", schema, name));
523 }
524 candidates.push(name);
525 candidates
526}
527
528fn table_ref_display_name(table: &TableRef) -> String {
529 let mut parts = Vec::new();
530 if let Some(catalog) = &table.catalog {
531 parts.push(catalog.name.clone());
532 }
533 if let Some(schema) = &table.schema {
534 parts.push(schema.name.clone());
535 }
536 parts.push(table.name.name.clone());
537 parts.join(".")
538}
539
/// Tables visible to one statement, resolved against the schema map by
/// `collect_type_check_context`.
#[derive(Debug, Default, Clone)]
struct TypeCheckContext {
    /// Schema-map keys of every table the statement references.
    referenced_tables: HashSet<String>,
    /// Lowercased table name or alias as written -> schema-map key.
    table_aliases: HashMap<String, String>,
}
545
546fn type_family_name(family: TypeFamily) -> &'static str {
547 match family {
548 TypeFamily::Unknown => "unknown",
549 TypeFamily::Boolean => "boolean",
550 TypeFamily::Integer => "integer",
551 TypeFamily::Numeric => "numeric",
552 TypeFamily::String => "string",
553 TypeFamily::Binary => "binary",
554 TypeFamily::Date => "date",
555 TypeFamily::Time => "time",
556 TypeFamily::Timestamp => "timestamp",
557 TypeFamily::Interval => "interval",
558 TypeFamily::Json => "json",
559 TypeFamily::Uuid => "uuid",
560 TypeFamily::Array => "array",
561 TypeFamily::Map => "map",
562 TypeFamily::Struct => "struct",
563 }
564}
565
566fn is_string_like(family: TypeFamily) -> bool {
567 matches!(family, TypeFamily::String)
568}
569
570fn is_string_or_binary(family: TypeFamily) -> bool {
571 matches!(family, TypeFamily::String | TypeFamily::Binary)
572}
573
574fn type_issue(
575 strict: bool,
576 error_code: &str,
577 warning_code: &str,
578 message: impl Into<String>,
579) -> ValidationError {
580 if strict {
581 ValidationError::error(message.into(), error_code)
582 } else {
583 ValidationError::warning(message.into(), warning_code)
584 }
585}
586
587fn data_type_family(data_type: &DataType) -> TypeFamily {
588 match data_type {
589 DataType::Boolean => TypeFamily::Boolean,
590 DataType::TinyInt { .. }
591 | DataType::SmallInt { .. }
592 | DataType::Int { .. }
593 | DataType::BigInt { .. } => TypeFamily::Integer,
594 DataType::Float { .. } | DataType::Double { .. } | DataType::Decimal { .. } => {
595 TypeFamily::Numeric
596 }
597 DataType::Char { .. }
598 | DataType::VarChar { .. }
599 | DataType::String { .. }
600 | DataType::Text
601 | DataType::TextWithLength { .. }
602 | DataType::CharacterSet { .. } => TypeFamily::String,
603 DataType::Binary { .. } | DataType::VarBinary { .. } | DataType::Blob => TypeFamily::Binary,
604 DataType::Date => TypeFamily::Date,
605 DataType::Time { .. } => TypeFamily::Time,
606 DataType::Timestamp { .. } => TypeFamily::Timestamp,
607 DataType::Interval { .. } => TypeFamily::Interval,
608 DataType::Json | DataType::JsonB => TypeFamily::Json,
609 DataType::Uuid => TypeFamily::Uuid,
610 DataType::Array { .. } | DataType::List { .. } => TypeFamily::Array,
611 DataType::Map { .. } => TypeFamily::Map,
612 DataType::Struct { .. } | DataType::Object { .. } | DataType::Union { .. } => {
613 TypeFamily::Struct
614 }
615 DataType::Nullable { inner } => data_type_family(inner),
616 DataType::Custom { name } => canonical_type_family(name),
617 DataType::Unknown => TypeFamily::Unknown,
618 DataType::Bit { .. } | DataType::VarBit { .. } => TypeFamily::Binary,
619 DataType::Enum { .. } | DataType::Set { .. } => TypeFamily::String,
620 DataType::Vector { .. } => TypeFamily::Array,
621 DataType::Geometry { .. } | DataType::Geography { .. } => TypeFamily::Struct,
622 }
623}
624
625fn collect_type_check_context(
626 stmt: &Expression,
627 schema_map: &HashMap<String, TableSchemaEntry>,
628) -> TypeCheckContext {
629 fn add_table_to_context(
630 table: &TableRef,
631 schema_map: &HashMap<String, TableSchemaEntry>,
632 context: &mut TypeCheckContext,
633 ) {
634 let resolved_key = table_ref_candidates(table)
635 .into_iter()
636 .find(|k| schema_map.contains_key(k));
637
638 let Some(table_key) = resolved_key else {
639 return;
640 };
641
642 context.referenced_tables.insert(table_key.clone());
643 context
644 .table_aliases
645 .insert(lower(&table.name.name), table_key.clone());
646 if let Some(alias) = &table.alias {
647 context
648 .table_aliases
649 .insert(lower(&alias.name), table_key.clone());
650 }
651 }
652
653 let mut context = TypeCheckContext::default();
654 let cte_aliases = collect_cte_aliases(stmt);
655
656 for node in stmt.find_all(|e| matches!(e, Expression::Table(_))) {
657 let Expression::Table(table) = node else {
658 continue;
659 };
660
661 if cte_aliases.contains(&lower(&table.name.name)) {
662 continue;
663 }
664
665 add_table_to_context(table, schema_map, &mut context);
666 }
667
668 match stmt {
671 Expression::Insert(insert) => {
672 add_table_to_context(&insert.table, schema_map, &mut context);
673 }
674 Expression::Update(update) => {
675 add_table_to_context(&update.table, schema_map, &mut context);
676 for table in &update.extra_tables {
677 add_table_to_context(table, schema_map, &mut context);
678 }
679 }
680 Expression::Delete(delete) => {
681 add_table_to_context(&delete.table, schema_map, &mut context);
682 for table in &delete.using {
683 add_table_to_context(table, schema_map, &mut context);
684 }
685 for table in &delete.tables {
686 add_table_to_context(table, schema_map, &mut context);
687 }
688 }
689 _ => {}
690 }
691
692 context
693}
694
695fn resolve_table_schema_entry<'a>(
696 table: &TableRef,
697 schema_map: &'a HashMap<String, TableSchemaEntry>,
698) -> Option<(String, &'a TableSchemaEntry)> {
699 let key = table_ref_candidates(table)
700 .into_iter()
701 .find(|k| schema_map.contains_key(k))?;
702 let entry = schema_map.get(&key)?;
703 Some((key, entry))
704}
705
706fn reference_issue(strict: bool, message: impl Into<String>) -> ValidationError {
707 if strict {
708 ValidationError::error(
709 message.into(),
710 validation_codes::E_INVALID_FOREIGN_KEY_REFERENCE,
711 )
712 } else {
713 ValidationError::warning(message.into(), validation_codes::W_WEAK_REFERENCE_INTEGRITY)
714 }
715}
716
/// Lowercased schema-map keys to try for a foreign-key target, in priority
/// order and without duplicates.
///
/// 1. `explicit_schema.table`, when an explicit schema is given.
/// 2. If `table_name` is already dotted: the name as-is, then its last
///    segment. Otherwise: `source_schema.table` (when the source table has a
///    schema), then the bare name.
fn reference_table_candidates(
    table_name: &str,
    explicit_schema: Option<&str>,
    source_schema: Option<&str>,
) -> Vec<String> {
    let table = table_name.to_lowercase();
    let qualify = |schema: &str| format!("{}.{}", schema.to_lowercase(), table);

    let mut ordered: Vec<String> = explicit_schema.map(&qualify).into_iter().collect();
    if table.contains('.') {
        ordered.push(table.clone());
        if let Some(last_segment) = table.rsplit('.').next() {
            ordered.push(last_segment.to_string());
        }
    } else {
        ordered.extend(source_schema.map(&qualify));
        ordered.push(table.clone());
    }

    // Keep the first occurrence of each candidate.
    let mut seen = HashSet::new();
    ordered.retain(|candidate| seen.insert(candidate.clone()));
    ordered
}
747
748fn resolve_reference_table_key(
749 table_name: &str,
750 explicit_schema: Option<&str>,
751 source_schema: Option<&str>,
752 schema_map: &HashMap<String, TableSchemaEntry>,
753) -> Option<String> {
754 reference_table_candidates(table_name, explicit_schema, source_schema)
755 .into_iter()
756 .find(|candidate| schema_map.contains_key(candidate))
757}
758
759fn key_types_compatible(source: TypeFamily, target: TypeFamily) -> bool {
760 if source == TypeFamily::Unknown || target == TypeFamily::Unknown {
761 return true;
762 }
763 if source == target {
764 return true;
765 }
766 if source.is_numeric() && target.is_numeric() {
767 return true;
768 }
769 if source.is_temporal() && target.is_temporal() {
770 return true;
771 }
772 false
773}
774
775fn table_key_hints(table: &SchemaTable) -> HashSet<String> {
776 let mut hints = HashSet::new();
777 for column in &table.columns {
778 if column.primary_key || column.unique {
779 hints.insert(lower(&column.name));
780 }
781 }
782 for key_col in &table.primary_key {
783 hints.insert(lower(key_col));
784 }
785 for group in &table.unique_keys {
786 if group.len() == 1 {
787 if let Some(col) = group.first() {
788 hints.insert(lower(col));
789 }
790 }
791 }
792 hints
793}
794
795fn check_reference_integrity(
796 schema: &ValidationSchema,
797 schema_map: &HashMap<String, TableSchemaEntry>,
798 strict: bool,
799) -> Vec<ValidationError> {
800 let mut errors = Vec::new();
801
802 let mut key_hints_lookup: HashMap<String, HashSet<String>> = HashMap::new();
803 for table in &schema.tables {
804 let simple = lower(&table.name);
805 key_hints_lookup.insert(simple, table_key_hints(table));
806 if let Some(schema_name) = &table.schema {
807 let qualified = format!("{}.{}", lower(schema_name), lower(&table.name));
808 key_hints_lookup.insert(qualified, table_key_hints(table));
809 }
810 }
811
812 for table in &schema.tables {
813 let source_table_display = if let Some(schema_name) = &table.schema {
814 format!("{}.{}", schema_name, table.name)
815 } else {
816 table.name.clone()
817 };
818 let source_schema = table.schema.as_deref();
819 let source_columns: HashMap<String, TypeFamily> = table
820 .columns
821 .iter()
822 .map(|col| (lower(&col.name), canonical_type_family(&col.data_type)))
823 .collect();
824
825 for source_col in &table.columns {
826 let Some(reference) = &source_col.references else {
827 continue;
828 };
829 let source_type = canonical_type_family(&source_col.data_type);
830
831 let Some(target_key) = resolve_reference_table_key(
832 &reference.table,
833 reference.schema.as_deref(),
834 source_schema,
835 schema_map,
836 ) else {
837 errors.push(reference_issue(
838 strict,
839 format!(
840 "Foreign key reference '{}.{}' points to unknown table '{}'",
841 source_table_display, source_col.name, reference.table
842 ),
843 ));
844 continue;
845 };
846
847 let target_column = lower(&reference.column);
848 let Some(target_entry) = schema_map.get(&target_key) else {
849 errors.push(reference_issue(
850 strict,
851 format!(
852 "Foreign key reference '{}.{}' points to unknown table '{}'",
853 source_table_display, source_col.name, reference.table
854 ),
855 ));
856 continue;
857 };
858
859 let Some(target_type) = target_entry.columns.get(&target_column).copied() else {
860 errors.push(reference_issue(
861 strict,
862 format!(
863 "Foreign key reference '{}.{}' points to unknown column '{}.{}'",
864 source_table_display, source_col.name, target_key, reference.column
865 ),
866 ));
867 continue;
868 };
869
870 if !key_types_compatible(source_type, target_type) {
871 errors.push(reference_issue(
872 strict,
873 format!(
874 "Foreign key type mismatch for '{}.{}' -> '{}.{}': {} vs {}",
875 source_table_display,
876 source_col.name,
877 target_key,
878 reference.column,
879 type_family_name(source_type),
880 type_family_name(target_type)
881 ),
882 ));
883 }
884
885 if let Some(target_key_hints) = key_hints_lookup.get(&target_key) {
886 if !target_key_hints.contains(&target_column) {
887 errors.push(ValidationError::warning(
888 format!(
889 "Referenced column '{}.{}' is not marked as primary/unique key",
890 target_key, reference.column
891 ),
892 validation_codes::W_WEAK_REFERENCE_INTEGRITY,
893 ));
894 }
895 }
896 }
897
898 for foreign_key in &table.foreign_keys {
899 if foreign_key.columns.is_empty() || foreign_key.references.columns.is_empty() {
900 errors.push(reference_issue(
901 strict,
902 format!(
903 "Table-level foreign key on '{}' has empty source or target column list",
904 source_table_display
905 ),
906 ));
907 continue;
908 }
909 if foreign_key.columns.len() != foreign_key.references.columns.len() {
910 errors.push(reference_issue(
911 strict,
912 format!(
913 "Table-level foreign key on '{}' has {} source columns but {} target columns",
914 source_table_display,
915 foreign_key.columns.len(),
916 foreign_key.references.columns.len()
917 ),
918 ));
919 continue;
920 }
921
922 let Some(target_key) = resolve_reference_table_key(
923 &foreign_key.references.table,
924 foreign_key.references.schema.as_deref(),
925 source_schema,
926 schema_map,
927 ) else {
928 errors.push(reference_issue(
929 strict,
930 format!(
931 "Table-level foreign key on '{}' points to unknown table '{}'",
932 source_table_display, foreign_key.references.table
933 ),
934 ));
935 continue;
936 };
937
938 let Some(target_entry) = schema_map.get(&target_key) else {
939 errors.push(reference_issue(
940 strict,
941 format!(
942 "Table-level foreign key on '{}' points to unknown table '{}'",
943 source_table_display, foreign_key.references.table
944 ),
945 ));
946 continue;
947 };
948
949 for (source_col, target_col) in foreign_key
950 .columns
951 .iter()
952 .zip(foreign_key.references.columns.iter())
953 {
954 let source_col_name = lower(source_col);
955 let target_col_name = lower(target_col);
956
957 let Some(source_type) = source_columns.get(&source_col_name).copied() else {
958 errors.push(reference_issue(
959 strict,
960 format!(
961 "Table-level foreign key on '{}' references unknown source column '{}'",
962 source_table_display, source_col
963 ),
964 ));
965 continue;
966 };
967
968 let Some(target_type) = target_entry.columns.get(&target_col_name).copied() else {
969 errors.push(reference_issue(
970 strict,
971 format!(
972 "Table-level foreign key on '{}' references unknown target column '{}.{}'",
973 source_table_display, target_key, target_col
974 ),
975 ));
976 continue;
977 };
978
979 if !key_types_compatible(source_type, target_type) {
980 errors.push(reference_issue(
981 strict,
982 format!(
983 "Table-level foreign key type mismatch '{}.{}' -> '{}.{}': {} vs {}",
984 source_table_display,
985 source_col,
986 target_key,
987 target_col,
988 type_family_name(source_type),
989 type_family_name(target_type)
990 ),
991 ));
992 }
993
994 if let Some(target_key_hints) = key_hints_lookup.get(&target_key) {
995 if !target_key_hints.contains(&target_col_name) {
996 errors.push(ValidationError::warning(
997 format!(
998 "Referenced column '{}.{}' is not marked as primary/unique key",
999 target_key, target_col
1000 ),
1001 validation_codes::W_WEAK_REFERENCE_INTEGRITY,
1002 ));
1003 }
1004 }
1005 }
1006 }
1007 }
1008
1009 errors
1010}
1011
1012fn resolve_unqualified_column_type(
1013 column_name: &str,
1014 schema_map: &HashMap<String, TableSchemaEntry>,
1015 context: &TypeCheckContext,
1016) -> TypeFamily {
1017 let candidate_tables: Vec<&String> = if !context.referenced_tables.is_empty() {
1018 context.referenced_tables.iter().collect()
1019 } else {
1020 schema_map.keys().collect()
1021 };
1022
1023 let mut families = HashSet::new();
1024 for table_name in candidate_tables {
1025 if let Some(table_schema) = schema_map.get(table_name) {
1026 if let Some(family) = table_schema.columns.get(column_name) {
1027 families.insert(*family);
1028 }
1029 }
1030 }
1031
1032 if families.len() == 1 {
1033 *families.iter().next().unwrap_or(&TypeFamily::Unknown)
1034 } else {
1035 TypeFamily::Unknown
1036 }
1037}
1038
1039fn resolve_column_type(
1040 column: &Column,
1041 schema_map: &HashMap<String, TableSchemaEntry>,
1042 context: &TypeCheckContext,
1043) -> TypeFamily {
1044 let column_name = lower(&column.name.name);
1045 if column_name.is_empty() {
1046 return TypeFamily::Unknown;
1047 }
1048
1049 if let Some(table) = &column.table {
1050 let mut table_key = lower(&table.name);
1051 if let Some(mapped) = context.table_aliases.get(&table_key) {
1052 table_key = mapped.clone();
1053 }
1054
1055 return schema_map
1056 .get(&table_key)
1057 .and_then(|t| t.columns.get(&column_name))
1058 .copied()
1059 .unwrap_or(TypeFamily::Unknown);
1060 }
1061
1062 resolve_unqualified_column_type(&column_name, schema_map, context)
1063}
1064
/// Read-only [`SqlSchema`] view over the validation schema map.
///
/// It is scoped to one statement's [`TypeCheckContext`] so table aliases
/// resolve, and is handed to `annotate_types` by
/// `infer_expression_type_family`.
struct TypeInferenceSchema<'a> {
    schema_map: &'a HashMap<String, TableSchemaEntry>,
    context: &'a TypeCheckContext,
}
1069
1070impl TypeInferenceSchema<'_> {
1071 fn resolve_table_key(&self, table: &str) -> Option<String> {
1072 let mut table_key = lower(table);
1073 if let Some(mapped) = self.context.table_aliases.get(&table_key) {
1074 table_key = mapped.clone();
1075 }
1076 if self.schema_map.contains_key(&table_key) {
1077 Some(table_key)
1078 } else {
1079 None
1080 }
1081 }
1082}
1083
1084impl SqlSchema for TypeInferenceSchema<'_> {
1085 fn dialect(&self) -> Option<DialectType> {
1086 None
1087 }
1088
1089 fn add_table(
1090 &mut self,
1091 _table: &str,
1092 _columns: &[(String, DataType)],
1093 _dialect: Option<DialectType>,
1094 ) -> SchemaResult<()> {
1095 Err(SchemaError::InvalidStructure(
1096 "Type inference schema is read-only".to_string(),
1097 ))
1098 }
1099
1100 fn column_names(&self, table: &str) -> SchemaResult<Vec<String>> {
1101 let table_key = self
1102 .resolve_table_key(table)
1103 .ok_or_else(|| SchemaError::TableNotFound(table.to_string()))?;
1104 let entry = self
1105 .schema_map
1106 .get(&table_key)
1107 .ok_or_else(|| SchemaError::TableNotFound(table.to_string()))?;
1108 Ok(entry.column_order.clone())
1109 }
1110
1111 fn get_column_type(&self, table: &str, column: &str) -> SchemaResult<DataType> {
1112 let col_name = lower(column);
1113 if table.is_empty() {
1114 let family = resolve_unqualified_column_type(&col_name, self.schema_map, self.context);
1115 return if family == TypeFamily::Unknown {
1116 Err(SchemaError::ColumnNotFound {
1117 table: "<unqualified>".to_string(),
1118 column: column.to_string(),
1119 })
1120 } else {
1121 Ok(type_family_to_data_type(family))
1122 };
1123 }
1124
1125 let table_key = self
1126 .resolve_table_key(table)
1127 .ok_or_else(|| SchemaError::TableNotFound(table.to_string()))?;
1128 let entry = self
1129 .schema_map
1130 .get(&table_key)
1131 .ok_or_else(|| SchemaError::TableNotFound(table.to_string()))?;
1132 let family =
1133 entry
1134 .columns
1135 .get(&col_name)
1136 .copied()
1137 .ok_or_else(|| SchemaError::ColumnNotFound {
1138 table: table.to_string(),
1139 column: column.to_string(),
1140 })?;
1141 Ok(type_family_to_data_type(family))
1142 }
1143
1144 fn has_column(&self, table: &str, column: &str) -> bool {
1145 self.get_column_type(table, column).is_ok()
1146 }
1147
1148 fn supported_table_args(&self) -> &[&str] {
1149 TABLE_PARTS
1150 }
1151
1152 fn is_empty(&self) -> bool {
1153 self.schema_map.is_empty()
1154 }
1155
1156 fn depth(&self) -> usize {
1157 1
1158 }
1159}
1160
1161fn infer_expression_type_family(
1162 expr: &Expression,
1163 schema_map: &HashMap<String, TableSchemaEntry>,
1164 context: &TypeCheckContext,
1165) -> TypeFamily {
1166 let inference_schema = TypeInferenceSchema {
1167 schema_map,
1168 context,
1169 };
1170 if let Some(data_type) = annotate_types(expr, Some(&inference_schema), None) {
1171 let family = data_type_family(&data_type);
1172 if family != TypeFamily::Unknown {
1173 return family;
1174 }
1175 }
1176
1177 infer_expression_type_family_fallback(expr, schema_map, context)
1178}
1179
1180fn infer_expression_type_family_fallback(
1181 expr: &Expression,
1182 schema_map: &HashMap<String, TableSchemaEntry>,
1183 context: &TypeCheckContext,
1184) -> TypeFamily {
1185 match expr {
1186 Expression::Literal(literal) => match literal {
1187 crate::expressions::Literal::Number(value) => {
1188 if value.contains('.') || value.contains('e') || value.contains('E') {
1189 TypeFamily::Numeric
1190 } else {
1191 TypeFamily::Integer
1192 }
1193 }
1194 crate::expressions::Literal::HexNumber(_) => TypeFamily::Integer,
1195 crate::expressions::Literal::Date(_) => TypeFamily::Date,
1196 crate::expressions::Literal::Time(_) => TypeFamily::Time,
1197 crate::expressions::Literal::Timestamp(_)
1198 | crate::expressions::Literal::Datetime(_) => TypeFamily::Timestamp,
1199 crate::expressions::Literal::HexString(_)
1200 | crate::expressions::Literal::BitString(_)
1201 | crate::expressions::Literal::ByteString(_) => TypeFamily::Binary,
1202 _ => TypeFamily::String,
1203 },
1204 Expression::Boolean(_) => TypeFamily::Boolean,
1205 Expression::Null(_) => TypeFamily::Unknown,
1206 Expression::Column(column) => resolve_column_type(column, schema_map, context),
1207 Expression::Cast(cast) | Expression::TryCast(cast) | Expression::SafeCast(cast) => {
1208 data_type_family(&cast.to)
1209 }
1210 Expression::Alias(alias) => {
1211 infer_expression_type_family_fallback(&alias.this, schema_map, context)
1212 }
1213 Expression::Neg(unary) => {
1214 infer_expression_type_family_fallback(&unary.this, schema_map, context)
1215 }
1216 Expression::Add(op) | Expression::Sub(op) | Expression::Mul(op) => {
1217 let left = infer_expression_type_family_fallback(&op.left, schema_map, context);
1218 let right = infer_expression_type_family_fallback(&op.right, schema_map, context);
1219 if left == TypeFamily::Unknown || right == TypeFamily::Unknown {
1220 TypeFamily::Unknown
1221 } else if left == TypeFamily::Integer && right == TypeFamily::Integer {
1222 TypeFamily::Integer
1223 } else if left.is_numeric() && right.is_numeric() {
1224 TypeFamily::Numeric
1225 } else if left.is_temporal() || right.is_temporal() {
1226 left
1227 } else {
1228 TypeFamily::Unknown
1229 }
1230 }
1231 Expression::Div(_) | Expression::Mod(_) => TypeFamily::Numeric,
1232 Expression::Concat(_) => TypeFamily::String,
1233 Expression::Eq(_)
1234 | Expression::Neq(_)
1235 | Expression::Lt(_)
1236 | Expression::Lte(_)
1237 | Expression::Gt(_)
1238 | Expression::Gte(_)
1239 | Expression::Like(_)
1240 | Expression::ILike(_)
1241 | Expression::And(_)
1242 | Expression::Or(_)
1243 | Expression::Not(_)
1244 | Expression::Between(_)
1245 | Expression::In(_)
1246 | Expression::IsNull(_)
1247 | Expression::IsTrue(_)
1248 | Expression::IsFalse(_)
1249 | Expression::Is(_) => TypeFamily::Boolean,
1250 Expression::Length(_) => TypeFamily::Integer,
1251 Expression::Upper(_)
1252 | Expression::Lower(_)
1253 | Expression::Trim(_)
1254 | Expression::LTrim(_)
1255 | Expression::RTrim(_)
1256 | Expression::Replace(_)
1257 | Expression::Substring(_)
1258 | Expression::Left(_)
1259 | Expression::Right(_)
1260 | Expression::Repeat(_)
1261 | Expression::Lpad(_)
1262 | Expression::Rpad(_)
1263 | Expression::ConcatWs(_) => TypeFamily::String,
1264 Expression::Abs(_)
1265 | Expression::Round(_)
1266 | Expression::Floor(_)
1267 | Expression::Ceil(_)
1268 | Expression::Power(_)
1269 | Expression::Sqrt(_)
1270 | Expression::Cbrt(_)
1271 | Expression::Ln(_)
1272 | Expression::Log(_)
1273 | Expression::Exp(_) => TypeFamily::Numeric,
1274 Expression::DateAdd(_) | Expression::DateSub(_) | Expression::ToDate(_) => TypeFamily::Date,
1275 Expression::ToTimestamp(_) => TypeFamily::Timestamp,
1276 Expression::DateDiff(_) | Expression::Extract(_) => TypeFamily::Integer,
1277 Expression::CurrentDate(_) => TypeFamily::Date,
1278 Expression::CurrentTime(_) => TypeFamily::Time,
1279 Expression::CurrentTimestamp(_) | Expression::CurrentTimestampLTZ(_) => {
1280 TypeFamily::Timestamp
1281 }
1282 Expression::Interval(_) => TypeFamily::Interval,
1283 _ => TypeFamily::Unknown,
1284 }
1285}
1286
1287fn are_comparable(left: TypeFamily, right: TypeFamily) -> bool {
1288 if left == TypeFamily::Unknown || right == TypeFamily::Unknown {
1289 return true;
1290 }
1291 if left == right {
1292 return true;
1293 }
1294 if left.is_numeric() && right.is_numeric() {
1295 return true;
1296 }
1297 if left.is_temporal() && right.is_temporal() {
1298 return true;
1299 }
1300 false
1301}
1302
1303fn check_function_argument(
1304 errors: &mut Vec<ValidationError>,
1305 strict: bool,
1306 function_name: &str,
1307 arg_index: usize,
1308 family: TypeFamily,
1309 expected: &str,
1310 valid: bool,
1311) {
1312 if family == TypeFamily::Unknown || valid {
1313 return;
1314 }
1315
1316 errors.push(type_issue(
1317 strict,
1318 validation_codes::E_INVALID_FUNCTION_ARGUMENT_TYPE,
1319 validation_codes::W_FUNCTION_ARGUMENT_COERCION,
1320 format!(
1321 "Function '{}' argument {} expects {}, found {}",
1322 function_name,
1323 arg_index + 1,
1324 expected,
1325 type_family_name(family)
1326 ),
1327 ));
1328}
1329
1330fn function_dispatch_name(name: &str) -> String {
1331 let upper = name
1332 .rsplit('.')
1333 .next()
1334 .unwrap_or(name)
1335 .trim()
1336 .to_uppercase();
1337 lower(canonical_typed_function_name_upper(&upper))
1338}
1339
1340fn check_generic_function(
1341 function: &Function,
1342 schema_map: &HashMap<String, TableSchemaEntry>,
1343 context: &TypeCheckContext,
1344 strict: bool,
1345 errors: &mut Vec<ValidationError>,
1346) {
1347 let name = function_dispatch_name(&function.name);
1348
1349 let arg_family = |index: usize| -> Option<TypeFamily> {
1350 function
1351 .args
1352 .get(index)
1353 .map(|arg| infer_expression_type_family(arg, schema_map, context))
1354 };
1355
1356 match name.as_str() {
1357 "abs" | "sqrt" | "cbrt" | "ln" | "exp" => {
1358 if let Some(family) = arg_family(0) {
1359 check_function_argument(
1360 errors,
1361 strict,
1362 &name,
1363 0,
1364 family,
1365 "a numeric argument",
1366 family.is_numeric(),
1367 );
1368 }
1369 }
1370 "round" | "floor" | "ceil" | "ceiling" => {
1371 if let Some(family) = arg_family(0) {
1372 check_function_argument(
1373 errors,
1374 strict,
1375 &name,
1376 0,
1377 family,
1378 "a numeric argument",
1379 family.is_numeric(),
1380 );
1381 }
1382 if let Some(family) = arg_family(1) {
1383 check_function_argument(
1384 errors,
1385 strict,
1386 &name,
1387 1,
1388 family,
1389 "a numeric argument",
1390 family.is_numeric(),
1391 );
1392 }
1393 }
1394 "power" | "pow" => {
1395 for i in [0_usize, 1_usize] {
1396 if let Some(family) = arg_family(i) {
1397 check_function_argument(
1398 errors,
1399 strict,
1400 &name,
1401 i,
1402 family,
1403 "a numeric argument",
1404 family.is_numeric(),
1405 );
1406 }
1407 }
1408 }
1409 "length" | "char_length" | "character_length" => {
1410 if let Some(family) = arg_family(0) {
1411 check_function_argument(
1412 errors,
1413 strict,
1414 &name,
1415 0,
1416 family,
1417 "a string or binary argument",
1418 is_string_or_binary(family),
1419 );
1420 }
1421 }
1422 "upper" | "lower" | "trim" | "ltrim" | "rtrim" | "reverse" => {
1423 if let Some(family) = arg_family(0) {
1424 check_function_argument(
1425 errors,
1426 strict,
1427 &name,
1428 0,
1429 family,
1430 "a string argument",
1431 is_string_like(family),
1432 );
1433 }
1434 }
1435 "substring" | "substr" => {
1436 if let Some(family) = arg_family(0) {
1437 check_function_argument(
1438 errors,
1439 strict,
1440 &name,
1441 0,
1442 family,
1443 "a string argument",
1444 is_string_like(family),
1445 );
1446 }
1447 if let Some(family) = arg_family(1) {
1448 check_function_argument(
1449 errors,
1450 strict,
1451 &name,
1452 1,
1453 family,
1454 "a numeric argument",
1455 family.is_numeric(),
1456 );
1457 }
1458 if let Some(family) = arg_family(2) {
1459 check_function_argument(
1460 errors,
1461 strict,
1462 &name,
1463 2,
1464 family,
1465 "a numeric argument",
1466 family.is_numeric(),
1467 );
1468 }
1469 }
1470 "replace" => {
1471 for i in [0_usize, 1_usize, 2_usize] {
1472 if let Some(family) = arg_family(i) {
1473 check_function_argument(
1474 errors,
1475 strict,
1476 &name,
1477 i,
1478 family,
1479 "a string argument",
1480 is_string_like(family),
1481 );
1482 }
1483 }
1484 }
1485 "left" | "right" | "repeat" | "lpad" | "rpad" => {
1486 if let Some(family) = arg_family(0) {
1487 check_function_argument(
1488 errors,
1489 strict,
1490 &name,
1491 0,
1492 family,
1493 "a string argument",
1494 is_string_like(family),
1495 );
1496 }
1497 if let Some(family) = arg_family(1) {
1498 check_function_argument(
1499 errors,
1500 strict,
1501 &name,
1502 1,
1503 family,
1504 "a numeric argument",
1505 family.is_numeric(),
1506 );
1507 }
1508 if (name == "lpad" || name == "rpad") && function.args.len() > 2 {
1509 if let Some(family) = arg_family(2) {
1510 check_function_argument(
1511 errors,
1512 strict,
1513 &name,
1514 2,
1515 family,
1516 "a string argument",
1517 is_string_like(family),
1518 );
1519 }
1520 }
1521 }
1522 _ => {}
1523 }
1524}
1525
/// One declared foreign-key edge between two schema tables.
///
/// It comes either from a column-level `references` entry or from a table-level foreign
/// key (see `build_declared_relationships`). Table fields hold resolved schema-map keys;
/// column names are lowercased.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
struct DeclaredRelationship {
    /// Schema-map key of the table holding the referencing column.
    source_table: String,
    /// Lowercased name of the referencing column.
    source_column: String,
    /// Schema-map key of the referenced table.
    target_table: String,
    /// Lowercased name of the referenced column.
    target_column: String,
}
1533
1534fn build_declared_relationships(
1535 schema: &ValidationSchema,
1536 schema_map: &HashMap<String, TableSchemaEntry>,
1537) -> Vec<DeclaredRelationship> {
1538 let mut relationships = HashSet::new();
1539
1540 for table in &schema.tables {
1541 let Some(source_key) =
1542 resolve_reference_table_key(&table.name, table.schema.as_deref(), None, schema_map)
1543 else {
1544 continue;
1545 };
1546
1547 for column in &table.columns {
1548 let Some(reference) = &column.references else {
1549 continue;
1550 };
1551 let Some(target_key) = resolve_reference_table_key(
1552 &reference.table,
1553 reference.schema.as_deref(),
1554 table.schema.as_deref(),
1555 schema_map,
1556 ) else {
1557 continue;
1558 };
1559
1560 relationships.insert(DeclaredRelationship {
1561 source_table: source_key.clone(),
1562 source_column: lower(&column.name),
1563 target_table: target_key,
1564 target_column: lower(&reference.column),
1565 });
1566 }
1567
1568 for foreign_key in &table.foreign_keys {
1569 if foreign_key.columns.len() != foreign_key.references.columns.len() {
1570 continue;
1571 }
1572 let Some(target_key) = resolve_reference_table_key(
1573 &foreign_key.references.table,
1574 foreign_key.references.schema.as_deref(),
1575 table.schema.as_deref(),
1576 schema_map,
1577 ) else {
1578 continue;
1579 };
1580
1581 for (source_col, target_col) in foreign_key
1582 .columns
1583 .iter()
1584 .zip(foreign_key.references.columns.iter())
1585 {
1586 relationships.insert(DeclaredRelationship {
1587 source_table: source_key.clone(),
1588 source_column: lower(source_col),
1589 target_table: target_key.clone(),
1590 target_column: lower(target_col),
1591 });
1592 }
1593 }
1594 }
1595
1596 relationships.into_iter().collect()
1597}
1598
1599fn resolve_column_binding(
1600 column: &Column,
1601 schema_map: &HashMap<String, TableSchemaEntry>,
1602 context: &TypeCheckContext,
1603 resolver: &mut Resolver<'_>,
1604) -> Option<(String, String)> {
1605 let column_name = lower(&column.name.name);
1606 if column_name.is_empty() {
1607 return None;
1608 }
1609
1610 if let Some(table) = &column.table {
1611 let mut table_key = lower(&table.name);
1612 if let Some(mapped) = context.table_aliases.get(&table_key) {
1613 table_key = mapped.clone();
1614 }
1615 if schema_map.contains_key(&table_key) {
1616 return Some((table_key, column_name));
1617 }
1618 return None;
1619 }
1620
1621 if let Some(resolved_source) = resolver.get_table(&column_name) {
1622 let mut table_key = lower(&resolved_source);
1623 if let Some(mapped) = context.table_aliases.get(&table_key) {
1624 table_key = mapped.clone();
1625 }
1626 if schema_map.contains_key(&table_key) {
1627 return Some((table_key, column_name));
1628 }
1629 }
1630
1631 let candidates: Vec<String> = context
1632 .referenced_tables
1633 .iter()
1634 .filter_map(|table_name| {
1635 schema_map
1636 .get(table_name)
1637 .filter(|entry| entry.columns.contains_key(&column_name))
1638 .map(|_| table_name.clone())
1639 })
1640 .collect();
1641 if candidates.len() == 1 {
1642 return Some((candidates[0].clone(), column_name));
1643 }
1644 None
1645}
1646
1647fn extract_join_equality_pairs(
1648 expr: &Expression,
1649 schema_map: &HashMap<String, TableSchemaEntry>,
1650 context: &TypeCheckContext,
1651 resolver: &mut Resolver<'_>,
1652 pairs: &mut Vec<((String, String), (String, String))>,
1653) {
1654 match expr {
1655 Expression::And(op) => {
1656 extract_join_equality_pairs(&op.left, schema_map, context, resolver, pairs);
1657 extract_join_equality_pairs(&op.right, schema_map, context, resolver, pairs);
1658 }
1659 Expression::Paren(paren) => {
1660 extract_join_equality_pairs(&paren.this, schema_map, context, resolver, pairs);
1661 }
1662 Expression::Eq(op) => {
1663 let (Expression::Column(left_col), Expression::Column(right_col)) =
1664 (&op.left, &op.right)
1665 else {
1666 return;
1667 };
1668 let Some(left) = resolve_column_binding(left_col, schema_map, context, resolver) else {
1669 return;
1670 };
1671 let Some(right) = resolve_column_binding(right_col, schema_map, context, resolver)
1672 else {
1673 return;
1674 };
1675 pairs.push((left, right));
1676 }
1677 _ => {}
1678 }
1679}
1680
1681fn relationship_matches_pair(
1682 relationship: &DeclaredRelationship,
1683 left_table: &str,
1684 left_column: &str,
1685 right_table: &str,
1686 right_column: &str,
1687) -> bool {
1688 (relationship.source_table == left_table
1689 && relationship.source_column == left_column
1690 && relationship.target_table == right_table
1691 && relationship.target_column == right_column)
1692 || (relationship.source_table == right_table
1693 && relationship.source_column == right_column
1694 && relationship.target_table == left_table
1695 && relationship.target_column == left_column)
1696}
1697
1698fn resolved_table_key_from_expr(
1699 expr: &Expression,
1700 schema_map: &HashMap<String, TableSchemaEntry>,
1701) -> Option<String> {
1702 match expr {
1703 Expression::Table(table) => resolve_table_schema_entry(table, schema_map).map(|(k, _)| k),
1704 Expression::Alias(alias) => resolved_table_key_from_expr(&alias.this, schema_map),
1705 _ => None,
1706 }
1707}
1708
1709fn select_from_table_keys(
1710 select: &crate::expressions::Select,
1711 schema_map: &HashMap<String, TableSchemaEntry>,
1712) -> HashSet<String> {
1713 let mut keys = HashSet::new();
1714 if let Some(from_clause) = &select.from {
1715 for expr in &from_clause.expressions {
1716 if let Some(key) = resolved_table_key_from_expr(expr, schema_map) {
1717 keys.insert(key);
1718 }
1719 }
1720 }
1721 keys
1722}
1723
/// Join kinds exempt from the "missing ON/USING" cartesian-join warning in
/// `check_query_reference_quality`.
///
/// These are the NATURAL variants, CROSS/OUTER APPLY, the ASOF variants and the LATERAL
/// variants.
fn is_natural_or_implied_join(kind: JoinKind) -> bool {
    matches!(
        kind,
        JoinKind::Natural
            | JoinKind::NaturalLeft
            | JoinKind::NaturalRight
            | JoinKind::NaturalFull
            | JoinKind::CrossApply
            | JoinKind::OuterApply
            | JoinKind::AsOf
            | JoinKind::AsOfLeft
            | JoinKind::AsOfRight
            | JoinKind::Lateral
            | JoinKind::LeftLateral
    )
}
1740
/// Reports reference-quality issues for every SELECT found in `stmt`.
///
/// - Ambiguous unqualified columns when more than one table is referenced. These are errors
///   in strict mode and warnings otherwise; columns named in a join's USING list are exempt.
/// - Potential cartesian joins: a schema table joined with a cartesian-style kind, or with
///   no ON/USING when its kind is not exempted by `is_natural_or_implied_join`.
/// - ON predicates that use none of the declared foreign-key relationships linking the
///   joined table to a table already on the left side of the join.
fn check_query_reference_quality(
    stmt: &Expression,
    schema_map: &HashMap<String, TableSchemaEntry>,
    resolver_schema: &MappingSchema,
    strict: bool,
    relationships: &[DeclaredRelationship],
) -> Vec<ValidationError> {
    let mut errors = Vec::new();

    for node in stmt.dfs() {
        let Expression::Select(select) = node else {
            continue;
        };

        // Each SELECT gets its own type-check context, scope and resolver.
        let select_expr = Expression::Select(select.clone());
        let context = collect_type_check_context(&select_expr, schema_map);
        let scope = build_scope(&select_expr);
        let mut resolver = Resolver::new(&scope, resolver_schema, true);

        // Ambiguity is only possible when more than one table is referenced.
        if context.referenced_tables.len() > 1 {
            // Columns named in any USING list are exempt from the ambiguity check.
            let using_columns: HashSet<String> = select
                .joins
                .iter()
                .flat_map(|join| join.using.iter().map(|id| lower(&id.name)))
                .collect();

            // Each column name is reported at most once per SELECT.
            let mut seen = HashSet::new();
            // NOTE(review): `find_all` over the whole SELECT may also visit unqualified
            // columns inside nested subqueries, which are then checked against this
            // SELECT's resolver; confirm this cannot produce false ambiguity reports.
            for column_expr in select_expr
                .find_all(|e| matches!(e, Expression::Column(Column { table: None, .. })))
            {
                let Expression::Column(column) = column_expr else {
                    continue;
                };

                let col_name = lower(&column.name.name);
                if col_name.is_empty()
                    || using_columns.contains(&col_name)
                    || !seen.insert(col_name.clone())
                {
                    continue;
                }

                if resolver.is_ambiguous(&col_name) {
                    let source_count = resolver.sources_for_column(&col_name).len();
                    errors.push(if strict {
                        ValidationError::error(
                            format!(
                                "Ambiguous unqualified column '{}' found in {} referenced tables",
                                col_name, source_count
                            ),
                            validation_codes::E_AMBIGUOUS_COLUMN_REFERENCE,
                        )
                    } else {
                        ValidationError::warning(
                            format!(
                                "Ambiguous unqualified column '{}' found in {} referenced tables",
                                col_name, source_count
                            ),
                            validation_codes::W_WEAK_REFERENCE_INTEGRITY,
                        )
                    });
                }
            }
        }

        // Schema tables available on the left of the current join. It starts with the FROM
        // sources and grows by each join's right-hand table as the joins are walked in order.
        let mut cumulative_left_tables = select_from_table_keys(select, schema_map);

        for join in &select.joins {
            // `None` when the joined source is not a schema table (e.g. a subquery).
            let right_table_key = resolved_table_key_from_expr(&join.this, schema_map);
            let has_explicit_condition = join.on.is_some() || !join.using.is_empty();
            let cartesian_like_kind = matches!(
                join.kind,
                JoinKind::Cross
                    | JoinKind::Implicit
                    | JoinKind::Array
                    | JoinKind::LeftArray
                    | JoinKind::Paste
            );

            if right_table_key.is_some()
                && (cartesian_like_kind
                    || (!has_explicit_condition && !is_natural_or_implied_join(join.kind)))
            {
                errors.push(ValidationError::warning(
                    "Potential cartesian join: JOIN without ON/USING condition",
                    validation_codes::W_CARTESIAN_JOIN,
                ));
            }

            if let (Some(on_expr), Some(right_key)) = (&join.on, right_table_key.clone()) {
                if join.using.is_empty() {
                    let mut eq_pairs = Vec::new();
                    extract_join_equality_pairs(
                        on_expr,
                        schema_map,
                        &context,
                        &mut resolver,
                        &mut eq_pairs,
                    );

                    // Relationships linking a left-side table to the joined table, in either
                    // direction (`&&` binds tighter than `||`).
                    let relevant_relationships: Vec<&DeclaredRelationship> = relationships
                        .iter()
                        .filter(|rel| {
                            cumulative_left_tables.contains(&rel.source_table)
                                && rel.target_table == right_key
                                || (cumulative_left_tables.contains(&rel.target_table)
                                    && rel.source_table == right_key)
                        })
                        .collect();

                    // Warn only if such relationships exist and no ON equality matches one.
                    if !relevant_relationships.is_empty() {
                        let uses_declared_fk = eq_pairs.iter().any(|((lt, lc), (rt, rc))| {
                            relevant_relationships
                                .iter()
                                .any(|rel| relationship_matches_pair(rel, lt, lc, rt, rc))
                        });
                        if !uses_declared_fk {
                            errors.push(ValidationError::warning(
                                "JOIN predicate does not use declared foreign-key relationship columns",
                                validation_codes::W_JOIN_NOT_USING_DECLARED_REFERENCE,
                            ));
                        }
                    }
                }
            }

            if let Some(right_key) = right_table_key {
                cumulative_left_tables.insert(right_key);
            }
        }
    }

    errors
}
1875
1876fn are_setop_compatible(left: TypeFamily, right: TypeFamily) -> bool {
1877 if left == TypeFamily::Unknown || right == TypeFamily::Unknown {
1878 return true;
1879 }
1880 if left == right {
1881 return true;
1882 }
1883 if left.is_numeric() && right.is_numeric() {
1884 return true;
1885 }
1886 if left.is_temporal() && right.is_temporal() {
1887 return true;
1888 }
1889 false
1890}
1891
1892fn merged_setop_family(left: TypeFamily, right: TypeFamily) -> TypeFamily {
1893 if left == TypeFamily::Unknown {
1894 return right;
1895 }
1896 if right == TypeFamily::Unknown {
1897 return left;
1898 }
1899 if left == right {
1900 return left;
1901 }
1902 if left.is_numeric() && right.is_numeric() {
1903 if left == TypeFamily::Numeric || right == TypeFamily::Numeric {
1904 return TypeFamily::Numeric;
1905 }
1906 return TypeFamily::Integer;
1907 }
1908 if left.is_temporal() && right.is_temporal() {
1909 if left == TypeFamily::Timestamp || right == TypeFamily::Timestamp {
1910 return TypeFamily::Timestamp;
1911 }
1912 if left == TypeFamily::Date || right == TypeFamily::Date {
1913 return TypeFamily::Date;
1914 }
1915 return TypeFamily::Time;
1916 }
1917 TypeFamily::Unknown
1918}
1919
1920fn are_assignment_compatible(target: TypeFamily, source: TypeFamily) -> bool {
1921 if target == TypeFamily::Unknown || source == TypeFamily::Unknown {
1922 return true;
1923 }
1924 if target == source {
1925 return true;
1926 }
1927
1928 match target {
1929 TypeFamily::Boolean => source == TypeFamily::Boolean,
1930 TypeFamily::Integer | TypeFamily::Numeric => source.is_numeric(),
1931 TypeFamily::Date | TypeFamily::Time | TypeFamily::Timestamp | TypeFamily::Interval => {
1932 source.is_temporal()
1933 }
1934 TypeFamily::String => true,
1935 TypeFamily::Binary => matches!(source, TypeFamily::Binary | TypeFamily::String),
1936 TypeFamily::Json => matches!(source, TypeFamily::Json | TypeFamily::String),
1937 TypeFamily::Uuid => matches!(source, TypeFamily::Uuid | TypeFamily::String),
1938 TypeFamily::Array => source == TypeFamily::Array,
1939 TypeFamily::Map => source == TypeFamily::Map,
1940 TypeFamily::Struct => source == TypeFamily::Struct,
1941 TypeFamily::Unknown => true,
1942 }
1943}
1944
1945fn projection_families(
1946 query_expr: &Expression,
1947 schema_map: &HashMap<String, TableSchemaEntry>,
1948) -> Option<Vec<TypeFamily>> {
1949 match query_expr {
1950 Expression::Select(select) => {
1951 if select
1952 .expressions
1953 .iter()
1954 .any(|e| matches!(e, Expression::Star(_) | Expression::BracedWildcard(_)))
1955 {
1956 return None;
1957 }
1958 let select_expr = Expression::Select(select.clone());
1959 let context = collect_type_check_context(&select_expr, schema_map);
1960 Some(
1961 select
1962 .expressions
1963 .iter()
1964 .map(|e| infer_expression_type_family(e, schema_map, &context))
1965 .collect(),
1966 )
1967 }
1968 Expression::Subquery(subquery) => projection_families(&subquery.this, schema_map),
1969 Expression::Union(union) => {
1970 let left = projection_families(&union.left, schema_map)?;
1971 let right = projection_families(&union.right, schema_map)?;
1972 if left.len() != right.len() {
1973 return None;
1974 }
1975 Some(
1976 left.into_iter()
1977 .zip(right)
1978 .map(|(l, r)| merged_setop_family(l, r))
1979 .collect(),
1980 )
1981 }
1982 Expression::Intersect(intersect) => {
1983 let left = projection_families(&intersect.left, schema_map)?;
1984 let right = projection_families(&intersect.right, schema_map)?;
1985 if left.len() != right.len() {
1986 return None;
1987 }
1988 Some(
1989 left.into_iter()
1990 .zip(right)
1991 .map(|(l, r)| merged_setop_family(l, r))
1992 .collect(),
1993 )
1994 }
1995 Expression::Except(except) => {
1996 let left = projection_families(&except.left, schema_map)?;
1997 let right = projection_families(&except.right, schema_map)?;
1998 if left.len() != right.len() {
1999 return None;
2000 }
2001 Some(
2002 left.into_iter()
2003 .zip(right)
2004 .map(|(l, r)| merged_setop_family(l, r))
2005 .collect(),
2006 )
2007 }
2008 Expression::Values(values) => {
2009 let first_row = values.expressions.first()?;
2010 let context = TypeCheckContext::default();
2011 Some(
2012 first_row
2013 .expressions
2014 .iter()
2015 .map(|e| infer_expression_type_family(e, schema_map, &context))
2016 .collect(),
2017 )
2018 }
2019 _ => None,
2020 }
2021}
2022
2023fn check_set_operation_compatibility(
2024 op_name: &str,
2025 left_expr: &Expression,
2026 right_expr: &Expression,
2027 schema_map: &HashMap<String, TableSchemaEntry>,
2028 strict: bool,
2029 errors: &mut Vec<ValidationError>,
2030) {
2031 let Some(left_projection) = projection_families(left_expr, schema_map) else {
2032 return;
2033 };
2034 let Some(right_projection) = projection_families(right_expr, schema_map) else {
2035 return;
2036 };
2037
2038 if left_projection.len() != right_projection.len() {
2039 errors.push(type_issue(
2040 strict,
2041 validation_codes::E_SETOP_ARITY_MISMATCH,
2042 validation_codes::W_SETOP_IMPLICIT_COERCION,
2043 format!(
2044 "{} operands return different column counts: left {}, right {}",
2045 op_name,
2046 left_projection.len(),
2047 right_projection.len()
2048 ),
2049 ));
2050 return;
2051 }
2052
2053 for (idx, (left, right)) in left_projection
2054 .into_iter()
2055 .zip(right_projection)
2056 .enumerate()
2057 {
2058 if !are_setop_compatible(left, right) {
2059 errors.push(type_issue(
2060 strict,
2061 validation_codes::E_SETOP_TYPE_MISMATCH,
2062 validation_codes::W_SETOP_IMPLICIT_COERCION,
2063 format!(
2064 "{} column {} has incompatible types: {} vs {}",
2065 op_name,
2066 idx + 1,
2067 type_family_name(left),
2068 type_family_name(right)
2069 ),
2070 ));
2071 }
2072 }
2073}
2074
2075fn check_insert_assignments(
2076 stmt: &Expression,
2077 insert: &Insert,
2078 schema_map: &HashMap<String, TableSchemaEntry>,
2079 strict: bool,
2080 errors: &mut Vec<ValidationError>,
2081) {
2082 let Some((target_table_key, table_schema)) =
2083 resolve_table_schema_entry(&insert.table, schema_map)
2084 else {
2085 return;
2086 };
2087
2088 let mut target_columns = Vec::new();
2089 if insert.columns.is_empty() {
2090 target_columns.extend(table_schema.column_order.iter().cloned());
2091 } else {
2092 for column in &insert.columns {
2093 let col_name = lower(&column.name);
2094 if table_schema.columns.contains_key(&col_name) {
2095 target_columns.push(col_name);
2096 } else {
2097 errors.push(if strict {
2098 ValidationError::error(
2099 format!(
2100 "Unknown column '{}' in table '{}'",
2101 column.name, target_table_key
2102 ),
2103 validation_codes::E_UNKNOWN_COLUMN,
2104 )
2105 } else {
2106 ValidationError::warning(
2107 format!(
2108 "Unknown column '{}' in table '{}'",
2109 column.name, target_table_key
2110 ),
2111 validation_codes::E_UNKNOWN_COLUMN,
2112 )
2113 });
2114 }
2115 }
2116 }
2117
2118 if target_columns.is_empty() {
2119 return;
2120 }
2121
2122 let context = collect_type_check_context(stmt, schema_map);
2123
2124 if !insert.default_values {
2125 for (row_idx, row) in insert.values.iter().enumerate() {
2126 if row.len() != target_columns.len() {
2127 errors.push(type_issue(
2128 strict,
2129 validation_codes::E_INVALID_ASSIGNMENT_TYPE,
2130 validation_codes::W_IMPLICIT_CAST_ASSIGNMENT,
2131 format!(
2132 "INSERT row {} has {} values but target has {} columns",
2133 row_idx + 1,
2134 row.len(),
2135 target_columns.len()
2136 ),
2137 ));
2138 continue;
2139 }
2140
2141 for (value, target_column) in row.iter().zip(target_columns.iter()) {
2142 let Some(target_family) = table_schema.columns.get(target_column).copied() else {
2143 continue;
2144 };
2145 let source_family = infer_expression_type_family(value, schema_map, &context);
2146 if !are_assignment_compatible(target_family, source_family) {
2147 errors.push(type_issue(
2148 strict,
2149 validation_codes::E_INVALID_ASSIGNMENT_TYPE,
2150 validation_codes::W_IMPLICIT_CAST_ASSIGNMENT,
2151 format!(
2152 "INSERT assignment type mismatch for '{}.{}': expected {}, found {}",
2153 target_table_key,
2154 target_column,
2155 type_family_name(target_family),
2156 type_family_name(source_family)
2157 ),
2158 ));
2159 }
2160 }
2161 }
2162 }
2163
2164 if let Some(query) = &insert.query {
2165 if insert.by_name {
2167 return;
2168 }
2169
2170 let Some(source_projection) = projection_families(query, schema_map) else {
2171 return;
2172 };
2173
2174 if source_projection.len() != target_columns.len() {
2175 errors.push(type_issue(
2176 strict,
2177 validation_codes::E_INVALID_ASSIGNMENT_TYPE,
2178 validation_codes::W_IMPLICIT_CAST_ASSIGNMENT,
2179 format!(
2180 "INSERT source query has {} columns but target has {} columns",
2181 source_projection.len(),
2182 target_columns.len()
2183 ),
2184 ));
2185 return;
2186 }
2187
2188 for (source_family, target_column) in
2189 source_projection.into_iter().zip(target_columns.iter())
2190 {
2191 let Some(target_family) = table_schema.columns.get(target_column).copied() else {
2192 continue;
2193 };
2194 if !are_assignment_compatible(target_family, source_family) {
2195 errors.push(type_issue(
2196 strict,
2197 validation_codes::E_INVALID_ASSIGNMENT_TYPE,
2198 validation_codes::W_IMPLICIT_CAST_ASSIGNMENT,
2199 format!(
2200 "INSERT assignment type mismatch for '{}.{}': expected {}, found {}",
2201 target_table_key,
2202 target_column,
2203 type_family_name(target_family),
2204 type_family_name(source_family)
2205 ),
2206 ));
2207 }
2208 }
2209 }
2210}
2211
2212fn check_update_assignments(
2213 stmt: &Expression,
2214 update: &Update,
2215 schema_map: &HashMap<String, TableSchemaEntry>,
2216 strict: bool,
2217 errors: &mut Vec<ValidationError>,
2218) {
2219 let Some((target_table_key, table_schema)) =
2220 resolve_table_schema_entry(&update.table, schema_map)
2221 else {
2222 return;
2223 };
2224
2225 let context = collect_type_check_context(stmt, schema_map);
2226
2227 for (column, value) in &update.set {
2228 let col_name = lower(&column.name);
2229 let Some(target_family) = table_schema.columns.get(&col_name).copied() else {
2230 errors.push(if strict {
2231 ValidationError::error(
2232 format!(
2233 "Unknown column '{}' in table '{}'",
2234 column.name, target_table_key
2235 ),
2236 validation_codes::E_UNKNOWN_COLUMN,
2237 )
2238 } else {
2239 ValidationError::warning(
2240 format!(
2241 "Unknown column '{}' in table '{}'",
2242 column.name, target_table_key
2243 ),
2244 validation_codes::E_UNKNOWN_COLUMN,
2245 )
2246 });
2247 continue;
2248 };
2249
2250 let source_family = infer_expression_type_family(value, schema_map, &context);
2251 if !are_assignment_compatible(target_family, source_family) {
2252 errors.push(type_issue(
2253 strict,
2254 validation_codes::E_INVALID_ASSIGNMENT_TYPE,
2255 validation_codes::W_IMPLICIT_CAST_ASSIGNMENT,
2256 format!(
2257 "UPDATE assignment type mismatch for '{}.{}': expected {}, found {}",
2258 target_table_key,
2259 col_name,
2260 type_family_name(target_family),
2261 type_family_name(source_family)
2262 ),
2263 ));
2264 }
2265 }
2266}
2267
2268fn check_types(
2269 stmt: &Expression,
2270 schema_map: &HashMap<String, TableSchemaEntry>,
2271 strict: bool,
2272) -> Vec<ValidationError> {
2273 let mut errors = Vec::new();
2274 let context = collect_type_check_context(stmt, schema_map);
2275
2276 for node in stmt.dfs() {
2277 match node {
2278 Expression::Insert(insert) => {
2279 check_insert_assignments(stmt, insert, schema_map, strict, &mut errors);
2280 }
2281 Expression::Update(update) => {
2282 check_update_assignments(stmt, update, schema_map, strict, &mut errors);
2283 }
2284 Expression::Union(union) => {
2285 check_set_operation_compatibility(
2286 "UNION",
2287 &union.left,
2288 &union.right,
2289 schema_map,
2290 strict,
2291 &mut errors,
2292 );
2293 }
2294 Expression::Intersect(intersect) => {
2295 check_set_operation_compatibility(
2296 "INTERSECT",
2297 &intersect.left,
2298 &intersect.right,
2299 schema_map,
2300 strict,
2301 &mut errors,
2302 );
2303 }
2304 Expression::Except(except) => {
2305 check_set_operation_compatibility(
2306 "EXCEPT",
2307 &except.left,
2308 &except.right,
2309 schema_map,
2310 strict,
2311 &mut errors,
2312 );
2313 }
2314 Expression::Select(select) => {
2315 if let Some(prewhere) = &select.prewhere {
2316 let family = infer_expression_type_family(prewhere, schema_map, &context);
2317 if family != TypeFamily::Unknown && family != TypeFamily::Boolean {
2318 errors.push(type_issue(
2319 strict,
2320 validation_codes::E_INVALID_PREDICATE_TYPE,
2321 validation_codes::W_PREDICATE_NULLABILITY,
2322 format!(
2323 "PREWHERE clause expects a boolean predicate, found {}",
2324 type_family_name(family)
2325 ),
2326 ));
2327 }
2328 }
2329
2330 if let Some(where_clause) = &select.where_clause {
2331 let family =
2332 infer_expression_type_family(&where_clause.this, schema_map, &context);
2333 if family != TypeFamily::Unknown && family != TypeFamily::Boolean {
2334 errors.push(type_issue(
2335 strict,
2336 validation_codes::E_INVALID_PREDICATE_TYPE,
2337 validation_codes::W_PREDICATE_NULLABILITY,
2338 format!(
2339 "WHERE clause expects a boolean predicate, found {}",
2340 type_family_name(family)
2341 ),
2342 ));
2343 }
2344 }
2345
2346 if let Some(having_clause) = &select.having {
2347 let family =
2348 infer_expression_type_family(&having_clause.this, schema_map, &context);
2349 if family != TypeFamily::Unknown && family != TypeFamily::Boolean {
2350 errors.push(type_issue(
2351 strict,
2352 validation_codes::E_INVALID_PREDICATE_TYPE,
2353 validation_codes::W_PREDICATE_NULLABILITY,
2354 format!(
2355 "HAVING clause expects a boolean predicate, found {}",
2356 type_family_name(family)
2357 ),
2358 ));
2359 }
2360 }
2361
2362 for join in &select.joins {
2363 if let Some(on) = &join.on {
2364 let family = infer_expression_type_family(on, schema_map, &context);
2365 if family != TypeFamily::Unknown && family != TypeFamily::Boolean {
2366 errors.push(type_issue(
2367 strict,
2368 validation_codes::E_INVALID_PREDICATE_TYPE,
2369 validation_codes::W_PREDICATE_NULLABILITY,
2370 format!(
2371 "JOIN ON expects a boolean predicate, found {}",
2372 type_family_name(family)
2373 ),
2374 ));
2375 }
2376 }
2377 if let Some(match_condition) = &join.match_condition {
2378 let family =
2379 infer_expression_type_family(match_condition, schema_map, &context);
2380 if family != TypeFamily::Unknown && family != TypeFamily::Boolean {
2381 errors.push(type_issue(
2382 strict,
2383 validation_codes::E_INVALID_PREDICATE_TYPE,
2384 validation_codes::W_PREDICATE_NULLABILITY,
2385 format!(
2386 "JOIN MATCH_CONDITION expects a boolean predicate, found {}",
2387 type_family_name(family)
2388 ),
2389 ));
2390 }
2391 }
2392 }
2393 }
2394 Expression::Where(where_clause) => {
2395 let family = infer_expression_type_family(&where_clause.this, schema_map, &context);
2396 if family != TypeFamily::Unknown && family != TypeFamily::Boolean {
2397 errors.push(type_issue(
2398 strict,
2399 validation_codes::E_INVALID_PREDICATE_TYPE,
2400 validation_codes::W_PREDICATE_NULLABILITY,
2401 format!(
2402 "WHERE clause expects a boolean predicate, found {}",
2403 type_family_name(family)
2404 ),
2405 ));
2406 }
2407 }
2408 Expression::Having(having_clause) => {
2409 let family =
2410 infer_expression_type_family(&having_clause.this, schema_map, &context);
2411 if family != TypeFamily::Unknown && family != TypeFamily::Boolean {
2412 errors.push(type_issue(
2413 strict,
2414 validation_codes::E_INVALID_PREDICATE_TYPE,
2415 validation_codes::W_PREDICATE_NULLABILITY,
2416 format!(
2417 "HAVING clause expects a boolean predicate, found {}",
2418 type_family_name(family)
2419 ),
2420 ));
2421 }
2422 }
2423 Expression::And(op) | Expression::Or(op) => {
2424 for (side, expr) in [("left", &op.left), ("right", &op.right)] {
2425 let family = infer_expression_type_family(expr, schema_map, &context);
2426 if family != TypeFamily::Unknown && family != TypeFamily::Boolean {
2427 errors.push(type_issue(
2428 strict,
2429 validation_codes::E_INVALID_PREDICATE_TYPE,
2430 validation_codes::W_PREDICATE_NULLABILITY,
2431 format!(
2432 "Logical {} operand expects boolean, found {}",
2433 side,
2434 type_family_name(family)
2435 ),
2436 ));
2437 }
2438 }
2439 }
2440 Expression::Not(unary) => {
2441 let family = infer_expression_type_family(&unary.this, schema_map, &context);
2442 if family != TypeFamily::Unknown && family != TypeFamily::Boolean {
2443 errors.push(type_issue(
2444 strict,
2445 validation_codes::E_INVALID_PREDICATE_TYPE,
2446 validation_codes::W_PREDICATE_NULLABILITY,
2447 format!("NOT expects boolean, found {}", type_family_name(family)),
2448 ));
2449 }
2450 }
2451 Expression::Eq(op)
2452 | Expression::Neq(op)
2453 | Expression::Lt(op)
2454 | Expression::Lte(op)
2455 | Expression::Gt(op)
2456 | Expression::Gte(op) => {
2457 let left = infer_expression_type_family(&op.left, schema_map, &context);
2458 let right = infer_expression_type_family(&op.right, schema_map, &context);
2459 if !are_comparable(left, right) {
2460 errors.push(type_issue(
2461 strict,
2462 validation_codes::E_INCOMPATIBLE_COMPARISON_TYPES,
2463 validation_codes::W_IMPLICIT_CAST_COMPARISON,
2464 format!(
2465 "Incompatible comparison between {} and {}",
2466 type_family_name(left),
2467 type_family_name(right)
2468 ),
2469 ));
2470 }
2471 }
2472 Expression::Like(op) | Expression::ILike(op) => {
2473 let left = infer_expression_type_family(&op.left, schema_map, &context);
2474 let right = infer_expression_type_family(&op.right, schema_map, &context);
2475 if left != TypeFamily::Unknown
2476 && right != TypeFamily::Unknown
2477 && (!is_string_like(left) || !is_string_like(right))
2478 {
2479 errors.push(type_issue(
2480 strict,
2481 validation_codes::E_INCOMPATIBLE_COMPARISON_TYPES,
2482 validation_codes::W_IMPLICIT_CAST_COMPARISON,
2483 format!(
2484 "LIKE/ILIKE expects string operands, found {} and {}",
2485 type_family_name(left),
2486 type_family_name(right)
2487 ),
2488 ));
2489 }
2490 }
2491 Expression::Between(between) => {
2492 let this_family = infer_expression_type_family(&between.this, schema_map, &context);
2493 let low_family = infer_expression_type_family(&between.low, schema_map, &context);
2494 let high_family = infer_expression_type_family(&between.high, schema_map, &context);
2495
2496 if !are_comparable(this_family, low_family)
2497 || !are_comparable(this_family, high_family)
2498 {
2499 errors.push(type_issue(
2500 strict,
2501 validation_codes::E_INCOMPATIBLE_COMPARISON_TYPES,
2502 validation_codes::W_IMPLICIT_CAST_COMPARISON,
2503 format!(
2504 "BETWEEN bounds are incompatible with {} (found {} and {})",
2505 type_family_name(this_family),
2506 type_family_name(low_family),
2507 type_family_name(high_family)
2508 ),
2509 ));
2510 }
2511 }
2512 Expression::In(in_expr) => {
2513 let this_family = infer_expression_type_family(&in_expr.this, schema_map, &context);
2514 for value in &in_expr.expressions {
2515 let item_family = infer_expression_type_family(value, schema_map, &context);
2516 if !are_comparable(this_family, item_family) {
2517 errors.push(type_issue(
2518 strict,
2519 validation_codes::E_INCOMPATIBLE_COMPARISON_TYPES,
2520 validation_codes::W_IMPLICIT_CAST_COMPARISON,
2521 format!(
2522 "IN item type {} is incompatible with {}",
2523 type_family_name(item_family),
2524 type_family_name(this_family)
2525 ),
2526 ));
2527 break;
2528 }
2529 }
2530 }
2531 Expression::Add(op)
2532 | Expression::Sub(op)
2533 | Expression::Mul(op)
2534 | Expression::Div(op)
2535 | Expression::Mod(op) => {
2536 let left = infer_expression_type_family(&op.left, schema_map, &context);
2537 let right = infer_expression_type_family(&op.right, schema_map, &context);
2538
2539 if left == TypeFamily::Unknown || right == TypeFamily::Unknown {
2540 continue;
2541 }
2542
2543 let temporal_ok = matches!(node, Expression::Add(_) | Expression::Sub(_))
2544 && ((left.is_temporal() && right.is_numeric())
2545 || (right.is_temporal() && left.is_numeric())
2546 || (matches!(node, Expression::Sub(_))
2547 && left.is_temporal()
2548 && right.is_temporal()));
2549
2550 if !(left.is_numeric() && right.is_numeric()) && !temporal_ok {
2551 errors.push(type_issue(
2552 strict,
2553 validation_codes::E_INVALID_ARITHMETIC_TYPE,
2554 validation_codes::W_IMPLICIT_CAST_ARITHMETIC,
2555 format!(
2556 "Arithmetic operation expects numeric-compatible operands, found {} and {}",
2557 type_family_name(left),
2558 type_family_name(right)
2559 ),
2560 ));
2561 }
2562 }
2563 Expression::Function(function) => {
2564 check_generic_function(function, schema_map, &context, strict, &mut errors);
2565 }
2566 Expression::Upper(func)
2567 | Expression::Lower(func)
2568 | Expression::LTrim(func)
2569 | Expression::RTrim(func)
2570 | Expression::Reverse(func) => {
2571 let family = infer_expression_type_family(&func.this, schema_map, &context);
2572 check_function_argument(
2573 &mut errors,
2574 strict,
2575 "string_function",
2576 0,
2577 family,
2578 "a string argument",
2579 is_string_like(family),
2580 );
2581 }
2582 Expression::Length(func) => {
2583 let family = infer_expression_type_family(&func.this, schema_map, &context);
2584 check_function_argument(
2585 &mut errors,
2586 strict,
2587 "length",
2588 0,
2589 family,
2590 "a string or binary argument",
2591 is_string_or_binary(family),
2592 );
2593 }
2594 Expression::Trim(func) => {
2595 let this_family = infer_expression_type_family(&func.this, schema_map, &context);
2596 check_function_argument(
2597 &mut errors,
2598 strict,
2599 "trim",
2600 0,
2601 this_family,
2602 "a string argument",
2603 is_string_like(this_family),
2604 );
2605 if let Some(chars) = &func.characters {
2606 let chars_family = infer_expression_type_family(chars, schema_map, &context);
2607 check_function_argument(
2608 &mut errors,
2609 strict,
2610 "trim",
2611 1,
2612 chars_family,
2613 "a string argument",
2614 is_string_like(chars_family),
2615 );
2616 }
2617 }
2618 Expression::Substring(func) => {
2619 let this_family = infer_expression_type_family(&func.this, schema_map, &context);
2620 check_function_argument(
2621 &mut errors,
2622 strict,
2623 "substring",
2624 0,
2625 this_family,
2626 "a string argument",
2627 is_string_like(this_family),
2628 );
2629
2630 let start_family = infer_expression_type_family(&func.start, schema_map, &context);
2631 check_function_argument(
2632 &mut errors,
2633 strict,
2634 "substring",
2635 1,
2636 start_family,
2637 "a numeric argument",
2638 start_family.is_numeric(),
2639 );
2640 if let Some(length) = &func.length {
2641 let length_family = infer_expression_type_family(length, schema_map, &context);
2642 check_function_argument(
2643 &mut errors,
2644 strict,
2645 "substring",
2646 2,
2647 length_family,
2648 "a numeric argument",
2649 length_family.is_numeric(),
2650 );
2651 }
2652 }
2653 Expression::Replace(func) => {
2654 for (arg, idx) in [
2655 (&func.this, 0_usize),
2656 (&func.old, 1_usize),
2657 (&func.new, 2_usize),
2658 ] {
2659 let family = infer_expression_type_family(arg, schema_map, &context);
2660 check_function_argument(
2661 &mut errors,
2662 strict,
2663 "replace",
2664 idx,
2665 family,
2666 "a string argument",
2667 is_string_like(family),
2668 );
2669 }
2670 }
2671 Expression::Left(func) | Expression::Right(func) => {
2672 let this_family = infer_expression_type_family(&func.this, schema_map, &context);
2673 check_function_argument(
2674 &mut errors,
2675 strict,
2676 "left_right",
2677 0,
2678 this_family,
2679 "a string argument",
2680 is_string_like(this_family),
2681 );
2682 let length_family =
2683 infer_expression_type_family(&func.length, schema_map, &context);
2684 check_function_argument(
2685 &mut errors,
2686 strict,
2687 "left_right",
2688 1,
2689 length_family,
2690 "a numeric argument",
2691 length_family.is_numeric(),
2692 );
2693 }
2694 Expression::Repeat(func) => {
2695 let this_family = infer_expression_type_family(&func.this, schema_map, &context);
2696 check_function_argument(
2697 &mut errors,
2698 strict,
2699 "repeat",
2700 0,
2701 this_family,
2702 "a string argument",
2703 is_string_like(this_family),
2704 );
2705 let times_family = infer_expression_type_family(&func.times, schema_map, &context);
2706 check_function_argument(
2707 &mut errors,
2708 strict,
2709 "repeat",
2710 1,
2711 times_family,
2712 "a numeric argument",
2713 times_family.is_numeric(),
2714 );
2715 }
2716 Expression::Lpad(func) | Expression::Rpad(func) => {
2717 let this_family = infer_expression_type_family(&func.this, schema_map, &context);
2718 check_function_argument(
2719 &mut errors,
2720 strict,
2721 "pad",
2722 0,
2723 this_family,
2724 "a string argument",
2725 is_string_like(this_family),
2726 );
2727 let length_family =
2728 infer_expression_type_family(&func.length, schema_map, &context);
2729 check_function_argument(
2730 &mut errors,
2731 strict,
2732 "pad",
2733 1,
2734 length_family,
2735 "a numeric argument",
2736 length_family.is_numeric(),
2737 );
2738 if let Some(fill) = &func.fill {
2739 let fill_family = infer_expression_type_family(fill, schema_map, &context);
2740 check_function_argument(
2741 &mut errors,
2742 strict,
2743 "pad",
2744 2,
2745 fill_family,
2746 "a string argument",
2747 is_string_like(fill_family),
2748 );
2749 }
2750 }
2751 Expression::Abs(func)
2752 | Expression::Sqrt(func)
2753 | Expression::Cbrt(func)
2754 | Expression::Ln(func)
2755 | Expression::Exp(func) => {
2756 let family = infer_expression_type_family(&func.this, schema_map, &context);
2757 check_function_argument(
2758 &mut errors,
2759 strict,
2760 "numeric_function",
2761 0,
2762 family,
2763 "a numeric argument",
2764 family.is_numeric(),
2765 );
2766 }
2767 Expression::Round(func) => {
2768 let this_family = infer_expression_type_family(&func.this, schema_map, &context);
2769 check_function_argument(
2770 &mut errors,
2771 strict,
2772 "round",
2773 0,
2774 this_family,
2775 "a numeric argument",
2776 this_family.is_numeric(),
2777 );
2778 if let Some(decimals) = &func.decimals {
2779 let decimals_family =
2780 infer_expression_type_family(decimals, schema_map, &context);
2781 check_function_argument(
2782 &mut errors,
2783 strict,
2784 "round",
2785 1,
2786 decimals_family,
2787 "a numeric argument",
2788 decimals_family.is_numeric(),
2789 );
2790 }
2791 }
2792 Expression::Floor(func) => {
2793 let this_family = infer_expression_type_family(&func.this, schema_map, &context);
2794 check_function_argument(
2795 &mut errors,
2796 strict,
2797 "floor",
2798 0,
2799 this_family,
2800 "a numeric argument",
2801 this_family.is_numeric(),
2802 );
2803 if let Some(scale) = &func.scale {
2804 let scale_family = infer_expression_type_family(scale, schema_map, &context);
2805 check_function_argument(
2806 &mut errors,
2807 strict,
2808 "floor",
2809 1,
2810 scale_family,
2811 "a numeric argument",
2812 scale_family.is_numeric(),
2813 );
2814 }
2815 }
2816 Expression::Ceil(func) => {
2817 let this_family = infer_expression_type_family(&func.this, schema_map, &context);
2818 check_function_argument(
2819 &mut errors,
2820 strict,
2821 "ceil",
2822 0,
2823 this_family,
2824 "a numeric argument",
2825 this_family.is_numeric(),
2826 );
2827 if let Some(decimals) = &func.decimals {
2828 let decimals_family =
2829 infer_expression_type_family(decimals, schema_map, &context);
2830 check_function_argument(
2831 &mut errors,
2832 strict,
2833 "ceil",
2834 1,
2835 decimals_family,
2836 "a numeric argument",
2837 decimals_family.is_numeric(),
2838 );
2839 }
2840 }
2841 Expression::Power(func) => {
2842 let left_family = infer_expression_type_family(&func.this, schema_map, &context);
2843 check_function_argument(
2844 &mut errors,
2845 strict,
2846 "power",
2847 0,
2848 left_family,
2849 "a numeric argument",
2850 left_family.is_numeric(),
2851 );
2852 let right_family =
2853 infer_expression_type_family(&func.expression, schema_map, &context);
2854 check_function_argument(
2855 &mut errors,
2856 strict,
2857 "power",
2858 1,
2859 right_family,
2860 "a numeric argument",
2861 right_family.is_numeric(),
2862 );
2863 }
2864 Expression::Log(func) => {
2865 let this_family = infer_expression_type_family(&func.this, schema_map, &context);
2866 check_function_argument(
2867 &mut errors,
2868 strict,
2869 "log",
2870 0,
2871 this_family,
2872 "a numeric argument",
2873 this_family.is_numeric(),
2874 );
2875 if let Some(base) = &func.base {
2876 let base_family = infer_expression_type_family(base, schema_map, &context);
2877 check_function_argument(
2878 &mut errors,
2879 strict,
2880 "log",
2881 1,
2882 base_family,
2883 "a numeric argument",
2884 base_family.is_numeric(),
2885 );
2886 }
2887 }
2888 _ => {}
2889 }
2890 }
2891
2892 errors
2893}
2894
2895fn check_semantics(stmt: &Expression) -> Vec<ValidationError> {
2896 let mut errors = Vec::new();
2897
2898 let Expression::Select(select) = stmt else {
2899 return errors;
2900 };
2901 let select_expr = Expression::Select(select.clone());
2902
2903 if !select_expr
2905 .find_all(|e| matches!(e, Expression::Star(_)))
2906 .is_empty()
2907 {
2908 errors.push(ValidationError::warning(
2909 "SELECT * is discouraged; specify columns explicitly for better performance and maintainability",
2910 validation_codes::W_SELECT_STAR,
2911 ));
2912 }
2913
2914 let aggregate_count = get_aggregate_functions(&select_expr).len();
2916 if aggregate_count > 0 && select.group_by.is_none() {
2917 let has_non_aggregate_column = select.expressions.iter().any(|expr| {
2918 matches!(expr, Expression::Column(_) | Expression::Identifier(_))
2919 && get_aggregate_functions(expr).is_empty()
2920 });
2921
2922 if has_non_aggregate_column {
2923 errors.push(ValidationError::warning(
2924 "Mixing aggregate functions with non-aggregated columns without GROUP BY may cause errors in strict SQL mode",
2925 validation_codes::W_AGGREGATE_WITHOUT_GROUP_BY,
2926 ));
2927 }
2928 }
2929
2930 if select.distinct && select.order_by.is_some() {
2932 errors.push(ValidationError::warning(
2933 "DISTINCT with ORDER BY: ensure ORDER BY columns are in SELECT list",
2934 validation_codes::W_DISTINCT_ORDER_BY,
2935 ));
2936 }
2937
2938 if select.limit.is_some() && select.order_by.is_none() {
2940 errors.push(ValidationError::warning(
2941 "LIMIT without ORDER BY produces non-deterministic results",
2942 validation_codes::W_LIMIT_WITHOUT_ORDER_BY,
2943 ));
2944 }
2945
2946 errors
2947}
2948
2949fn validate_statement_with_schema(
2950 stmt: &Expression,
2951 schema_map: &HashMap<String, TableSchemaEntry>,
2952 strict: bool,
2953) -> Vec<ValidationError> {
2954 let mut errors = Vec::new();
2955 let cte_aliases = collect_cte_aliases(stmt);
2956
2957 let mut referenced_tables: HashSet<String> = HashSet::new();
2958 let mut table_aliases: HashMap<String, String> = HashMap::new();
2959 let mut seen_tables: HashSet<String> = HashSet::new();
2960
2961 for node in stmt.find_all(|e| matches!(e, Expression::Table(_))) {
2963 let Expression::Table(table) = node else {
2964 continue;
2965 };
2966
2967 if cte_aliases.contains(&lower(&table.name.name)) {
2968 continue;
2969 }
2970
2971 let resolved_key = table_ref_candidates(table)
2972 .into_iter()
2973 .find(|k| schema_map.contains_key(k));
2974 let table_key = resolved_key
2975 .clone()
2976 .unwrap_or_else(|| lower(&table_ref_display_name(table)));
2977
2978 if let Some(alias) = &table.alias {
2980 table_aliases.insert(lower(&alias.name), table_key.clone());
2981 }
2982
2983 if !seen_tables.insert(table_key.clone()) {
2984 continue;
2985 }
2986 referenced_tables.insert(table_key.clone());
2987
2988 if resolved_key.is_none() {
2989 let err = if strict {
2990 ValidationError::error(
2991 format!("Unknown table '{}'", table_ref_display_name(table)),
2992 validation_codes::E_UNKNOWN_TABLE,
2993 )
2994 } else {
2995 ValidationError::warning(
2996 format!("Unknown table '{}'", table_ref_display_name(table)),
2997 validation_codes::E_UNKNOWN_TABLE,
2998 )
2999 };
3000 errors.push(err);
3001 }
3002 }
3003
3004 for node in stmt.find_all(|e| matches!(e, Expression::Column(_))) {
3006 let Expression::Column(column) = node else {
3007 continue;
3008 };
3009
3010 let col_name = lower(&column.name.name);
3011 if col_name.is_empty() {
3012 continue;
3013 }
3014
3015 let mut table_name = column.table.as_ref().map(|t| lower(&t.name));
3016 if let Some(name) = &table_name {
3017 if let Some(mapped) = table_aliases.get(name) {
3018 table_name = Some(mapped.clone());
3019 }
3020 }
3021
3022 if let Some(table_name) = table_name {
3023 if let Some(table_schema) = schema_map.get(&table_name) {
3024 if !table_schema.columns.contains_key(&col_name) {
3025 let err = if strict {
3026 ValidationError::error(
3027 format!("Unknown column '{}' in table '{}'", col_name, table_name),
3028 validation_codes::E_UNKNOWN_COLUMN,
3029 )
3030 } else {
3031 ValidationError::warning(
3032 format!("Unknown column '{}' in table '{}'", col_name, table_name),
3033 validation_codes::E_UNKNOWN_COLUMN,
3034 )
3035 };
3036 errors.push(err);
3037 }
3038 }
3039 continue;
3040 }
3041
3042 if referenced_tables.len() == 1 {
3043 if let Some(single_table) = referenced_tables.iter().next() {
3044 if let Some(table_schema) = schema_map.get(single_table) {
3045 if !table_schema.columns.contains_key(&col_name) {
3046 let err = if strict {
3047 ValidationError::error(
3048 format!(
3049 "Unknown column '{}' in table '{}'",
3050 col_name, single_table
3051 ),
3052 validation_codes::E_UNKNOWN_COLUMN,
3053 )
3054 } else {
3055 ValidationError::warning(
3056 format!(
3057 "Unknown column '{}' in table '{}'",
3058 col_name, single_table
3059 ),
3060 validation_codes::E_UNKNOWN_COLUMN,
3061 )
3062 };
3063 errors.push(err);
3064 }
3065 }
3066 }
3067 } else if referenced_tables.len() > 1 {
3068 let found = referenced_tables.iter().any(|table_name| {
3069 schema_map
3070 .get(table_name)
3071 .map(|s| s.columns.contains_key(&col_name))
3072 .unwrap_or(false)
3073 });
3074
3075 if !found {
3076 let err = if strict {
3077 ValidationError::error(
3078 format!(
3079 "Unknown column '{}' (not found in any referenced table)",
3080 col_name
3081 ),
3082 validation_codes::E_UNKNOWN_COLUMN,
3083 )
3084 } else {
3085 ValidationError::warning(
3086 format!(
3087 "Unknown column '{}' (not found in any referenced table)",
3088 col_name
3089 ),
3090 validation_codes::E_UNKNOWN_COLUMN,
3091 )
3092 };
3093 errors.push(err);
3094 }
3095 } else if !schema_map.is_empty() {
3096 let found = schema_map
3097 .values()
3098 .any(|table| table.columns.contains_key(&col_name));
3099 if !found {
3100 let err = if strict {
3101 ValidationError::error(
3102 format!("Unknown column '{}'", col_name),
3103 validation_codes::E_UNKNOWN_COLUMN,
3104 )
3105 } else {
3106 ValidationError::warning(
3107 format!("Unknown column '{}'", col_name),
3108 validation_codes::E_UNKNOWN_COLUMN,
3109 )
3110 };
3111 errors.push(err);
3112 }
3113 }
3114 }
3115
3116 errors
3117}
3118
3119pub fn validate_with_schema(
3121 sql: &str,
3122 dialect: DialectType,
3123 schema: &ValidationSchema,
3124 options: &SchemaValidationOptions,
3125) -> ValidationResult {
3126 let strict = options.strict.unwrap_or(schema.strict.unwrap_or(true));
3127
3128 let syntax_result = crate::validate_with_options(
3130 sql,
3131 dialect,
3132 &crate::ValidationOptions {
3133 strict_syntax: options.strict_syntax,
3134 },
3135 );
3136 if !syntax_result.valid {
3137 return syntax_result;
3138 }
3139
3140 let d = Dialect::get(dialect);
3141 let statements = match d.parse(sql) {
3142 Ok(exprs) => exprs,
3143 Err(e) => {
3144 return ValidationResult::with_errors(vec![ValidationError::error(
3145 e.to_string(),
3146 validation_codes::E_PARSE_OR_OPTIONS,
3147 )]);
3148 }
3149 };
3150
3151 let schema_map = build_schema_map(schema);
3152 let resolver_schema = build_resolver_schema(schema);
3153 let mut all_errors = syntax_result.errors;
3154 let declared_relationships = if options.check_references {
3155 build_declared_relationships(schema, &schema_map)
3156 } else {
3157 Vec::new()
3158 };
3159
3160 if options.check_references {
3161 all_errors.extend(check_reference_integrity(schema, &schema_map, strict));
3162 }
3163
3164 for stmt in &statements {
3165 if options.semantic {
3166 all_errors.extend(check_semantics(stmt));
3167 }
3168 all_errors.extend(validate_statement_with_schema(stmt, &schema_map, strict));
3169 if options.check_types {
3170 all_errors.extend(check_types(stmt, &schema_map, strict));
3171 }
3172 if options.check_references {
3173 all_errors.extend(check_query_reference_quality(
3174 stmt,
3175 &schema_map,
3176 &resolver_schema,
3177 strict,
3178 &declared_relationships,
3179 ));
3180 }
3181 }
3182
3183 ValidationResult::with_errors(all_errors)
3184}
3185
3186#[cfg(test)]
3187mod tests;