1use crate::ast_transforms::get_aggregate_functions;
8use crate::dialects::{Dialect, DialectType};
9use crate::error::{ValidationError, ValidationResult};
10use crate::expressions::{
11 Column, DataType, Expression, Function, Insert, JoinKind, TableRef, Update,
12};
13use crate::function_registry::canonical_typed_function_name_upper;
14use crate::optimizer::annotate_types;
15use crate::resolver::Resolver;
16use crate::schema::{MappingSchema, Schema as SqlSchema, SchemaError, SchemaResult, TABLE_PARTS};
17use crate::scope::build_scope;
18use crate::traversal::ExpressionWalk;
19use serde::{Deserialize, Serialize};
20use std::collections::{HashMap, HashSet};
21
/// One column declared on a [`SchemaTable`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SchemaColumn {
    /// Column name.
    pub name: String,
    /// Declared type as free-form text; (de)serialized under the key `type` and
    /// defaults to the empty string when the key is absent.
    #[serde(default, rename = "type")]
    pub data_type: String,
    /// Optional nullability flag; `None` when the key is absent.
    #[serde(default)]
    pub nullable: Option<bool>,
    /// Column-level primary-key marker; (de)serialized as `primaryKey`, defaults to `false`.
    #[serde(default, rename = "primaryKey")]
    pub primary_key: bool,
    /// Column-level uniqueness marker; defaults to `false`.
    #[serde(default)]
    pub unique: bool,
    /// Optional column-level foreign-key target.
    #[serde(default)]
    pub references: Option<SchemaColumnReference>,
}
43
/// Target of a column-level foreign key: one column of one table.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SchemaColumnReference {
    /// Name of the referenced table.
    pub table: String,
    /// Name of the referenced column.
    pub column: String,
    /// Optional schema qualifier of the referenced table; `None` when absent.
    #[serde(default)]
    pub schema: Option<String>,
}
55
/// Table-level foreign key, pairing a list of source columns with a referenced column list.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SchemaForeignKey {
    /// Optional constraint name; `None` when absent.
    #[serde(default)]
    pub name: Option<String>,
    /// Source columns on the declaring table.
    pub columns: Vec<String>,
    /// Referenced table and its columns.
    pub references: SchemaTableReference,
}
67
/// Target of a table-level foreign key: a table plus a list of its columns.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SchemaTableReference {
    /// Name of the referenced table.
    pub table: String,
    /// Referenced columns.
    pub columns: Vec<String>,
    /// Optional schema qualifier of the referenced table; `None` when absent.
    #[serde(default)]
    pub schema: Option<String>,
}
79
/// Declaration of one table: its columns, alternative names and key metadata.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SchemaTable {
    /// Table name.
    pub name: String,
    /// Optional schema qualifier; `None` when absent.
    #[serde(default)]
    pub schema: Option<String>,
    /// Columns in declaration order.
    pub columns: Vec<SchemaColumn>,
    /// Additional names under which the table is registered; defaults to empty.
    #[serde(default)]
    pub aliases: Vec<String>,
    /// Table-level primary-key column names; (de)serialized as `primaryKey`, defaults to empty.
    #[serde(default, rename = "primaryKey")]
    pub primary_key: Vec<String>,
    /// Table-level unique keys, one column list per key; (de)serialized as `uniqueKeys`.
    #[serde(default, rename = "uniqueKeys")]
    pub unique_keys: Vec<Vec<String>>,
    /// Table-level foreign keys; (de)serialized as `foreignKeys`, defaults to empty.
    #[serde(default, rename = "foreignKeys")]
    pub foreign_keys: Vec<SchemaForeignKey>,
}
103
/// Top-level schema document: the list of declared tables plus an optional strictness flag.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ValidationSchema {
    /// Declared tables.
    pub tables: Vec<SchemaTable>,
    /// Optional strictness setting; `None` when absent.
    /// NOTE(review): how this combines with `SchemaValidationOptions::strict` is decided
    /// outside this part of the file.
    #[serde(default)]
    pub strict: Option<bool>,
}
113
/// Switches for schema validation. Every field falls back to its `Default`
/// (`false` / `None`) when missing from the serialized form.
/// NOTE(review): the fields are consumed outside this part of the file; the per-field
/// descriptions below are inferred from the names and should be confirmed there.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct SchemaValidationOptions {
    /// Presumably enables the type checks.
    #[serde(default)]
    pub check_types: bool,
    /// Presumably enables the reference / foreign-key checks.
    #[serde(default)]
    pub check_references: bool,
    /// Optional strictness override; `None` when absent.
    #[serde(default)]
    pub strict: Option<bool>,
    /// Presumably enables the semantic checks.
    #[serde(default)]
    pub semantic: bool,
}
130
/// String codes attached to validation diagnostics.
///
/// Constants prefixed `E_` hold `E…` codes and constants prefixed `W_` hold `W…` codes;
/// elsewhere in this file the former are passed to `ValidationError::error` and the
/// latter to `ValidationError::warning`.
pub mod validation_codes {
    // E000 / E2xx: parse-or-options failure and unknown table/column.
    pub const E_PARSE_OR_OPTIONS: &str = "E000";
    pub const E_UNKNOWN_TABLE: &str = "E200";
    pub const E_UNKNOWN_COLUMN: &str = "E201";

    // W00x: query-shape warnings.
    pub const W_SELECT_STAR: &str = "W001";
    pub const W_AGGREGATE_WITHOUT_GROUP_BY: &str = "W002";
    pub const W_DISTINCT_ORDER_BY: &str = "W003";
    pub const W_LIMIT_WITHOUT_ORDER_BY: &str = "W004";

    // E21x: type-checking errors.
    pub const E_TYPE_MISMATCH: &str = "E210";
    pub const E_INVALID_PREDICATE_TYPE: &str = "E211";
    pub const E_INVALID_ARITHMETIC_TYPE: &str = "E212";
    pub const E_INVALID_FUNCTION_ARGUMENT_TYPE: &str = "E213";
    pub const E_INVALID_ASSIGNMENT_TYPE: &str = "E214";
    pub const E_SETOP_TYPE_MISMATCH: &str = "E215";
    pub const E_SETOP_ARITY_MISMATCH: &str = "E216";
    pub const E_INCOMPATIBLE_COMPARISON_TYPES: &str = "E217";
    pub const E_INVALID_CAST: &str = "E218";
    pub const E_UNKNOWN_INFERRED_TYPE: &str = "E219";

    // W21x: type-checking warnings.
    pub const W_IMPLICIT_CAST_COMPARISON: &str = "W210";
    pub const W_IMPLICIT_CAST_ARITHMETIC: &str = "W211";
    pub const W_IMPLICIT_CAST_ASSIGNMENT: &str = "W212";
    pub const W_LOSSY_CAST: &str = "W213";
    pub const W_SETOP_IMPLICIT_COERCION: &str = "W214";
    pub const W_PREDICATE_NULLABILITY: &str = "W215";
    pub const W_FUNCTION_ARGUMENT_COERCION: &str = "W216";
    pub const W_AGGREGATE_TYPE_COERCION: &str = "W217";
    pub const W_POSSIBLE_OVERFLOW: &str = "W218";
    pub const W_POSSIBLE_TRUNCATION: &str = "W219";

    // E22x: reference / foreign-key errors.
    pub const E_INVALID_FOREIGN_KEY_REFERENCE: &str = "E220";
    pub const E_AMBIGUOUS_COLUMN_REFERENCE: &str = "E221";
    pub const E_UNRESOLVED_REFERENCE: &str = "E222";
    pub const E_CTE_COLUMN_COUNT_MISMATCH: &str = "E223";
    pub const E_MISSING_REFERENCE_TARGET: &str = "E224";

    // W22x: reference / join warnings.
    pub const W_CARTESIAN_JOIN: &str = "W220";
    pub const W_JOIN_NOT_USING_DECLARED_REFERENCE: &str = "W221";
    pub const W_WEAK_REFERENCE_INTEGRITY: &str = "W222";
}
177
/// Coarse classification of SQL data types used for compatibility checks.
///
/// Serialized with snake_case variant names (`unknown`, `boolean`, ...).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum TypeFamily {
    /// The type could not be classified.
    Unknown,
    Boolean,
    /// Integer types; counted as numeric by [`TypeFamily::is_numeric`].
    Integer,
    /// Non-integer numeric types (decimal, floating point).
    Numeric,
    String,
    Binary,
    Date,
    Time,
    Timestamp,
    Interval,
    Json,
    Uuid,
    Array,
    Map,
    Struct,
}
198
199impl TypeFamily {
200 pub fn is_numeric(self) -> bool {
201 matches!(self, TypeFamily::Integer | TypeFamily::Numeric)
202 }
203
204 pub fn is_temporal(self) -> bool {
205 matches!(
206 self,
207 TypeFamily::Date | TypeFamily::Time | TypeFamily::Timestamp | TypeFamily::Interval
208 )
209 }
210}
211
/// Per-table lookup data derived from a [`SchemaTable`].
#[derive(Debug, Clone)]
struct TableSchemaEntry {
    /// Lowercased column name -> type family.
    columns: HashMap<String, TypeFamily>,
    /// Lowercased column names in declaration order.
    column_order: Vec<String>,
}
217
/// Returns a lowercased copy of `s`; used to normalise identifiers before lookups.
fn lower(s: &str) -> String {
    str::to_lowercase(s)
}
221
/// Splits a type spelled `base(inner)` into its trimmed `base` and `inner` parts.
///
/// Returns `None` unless the text ends with `)` and contains a `(`; the split happens
/// at the first `(`, so nested argument lists stay inside `inner`.
fn split_type_args(data_type: &str) -> Option<(&str, &str)> {
    let without_close = data_type.strip_suffix(')')?;
    let (base, inner) = without_close.split_once('(')?;
    Some((base.trim(), inner.trim()))
}
231
232pub fn canonical_type_family(data_type: &str) -> TypeFamily {
234 let trimmed = data_type
235 .trim()
236 .trim_matches(|c| c == '"' || c == '\'' || c == '`');
237 if trimmed.is_empty() {
238 return TypeFamily::Unknown;
239 }
240
241 let normalized = trimmed
243 .split_whitespace()
244 .collect::<Vec<_>>()
245 .join(" ")
246 .to_lowercase();
247
248 if let Some((base, inner)) = split_type_args(&normalized) {
250 match base {
251 "nullable" | "lowcardinality" => return canonical_type_family(inner),
252 "array" | "list" => return TypeFamily::Array,
253 "map" => return TypeFamily::Map,
254 "struct" | "row" | "record" => return TypeFamily::Struct,
255 _ => {}
256 }
257 }
258
259 if normalized.starts_with("array<") || normalized.starts_with("list<") {
260 return TypeFamily::Array;
261 }
262 if normalized.starts_with("map<") {
263 return TypeFamily::Map;
264 }
265 if normalized.starts_with("struct<")
266 || normalized.starts_with("row<")
267 || normalized.starts_with("record<")
268 || normalized.starts_with("object<")
269 {
270 return TypeFamily::Struct;
271 }
272
273 if normalized.ends_with("[]") {
274 return TypeFamily::Array;
275 }
276
277 let mut base = normalized
279 .split('(')
280 .next()
281 .unwrap_or("")
282 .trim()
283 .to_string();
284 if base.is_empty() {
285 return TypeFamily::Unknown;
286 }
287
288 base = base.strip_prefix("unsigned ").unwrap_or(&base).to_string();
289 base = base.strip_suffix(" unsigned").unwrap_or(&base).to_string();
290
291 match base.as_str() {
292 "bool" | "boolean" => TypeFamily::Boolean,
293 "tinyint" | "smallint" | "int2" | "int" | "integer" | "int4" | "int8" | "bigint"
294 | "serial" | "smallserial" | "bigserial" | "utinyint" | "usmallint" | "uinteger"
295 | "ubigint" | "uint8" | "uint16" | "uint32" | "uint64" | "int16" | "int32" | "int64" => {
296 TypeFamily::Integer
297 }
298 "numeric" | "decimal" | "dec" | "number" | "float" | "float4" | "float8" | "real"
299 | "double" | "double precision" | "bfloat16" | "float16" | "float32" | "float64" => {
300 TypeFamily::Numeric
301 }
302 "char" | "character" | "varchar" | "character varying" | "nchar" | "nvarchar" | "text"
303 | "string" | "clob" => TypeFamily::String,
304 "binary" | "varbinary" | "blob" | "bytea" | "bytes" => TypeFamily::Binary,
305 "date" => TypeFamily::Date,
306 "time" => TypeFamily::Time,
307 "timestamp"
308 | "timestamptz"
309 | "datetime"
310 | "datetime2"
311 | "smalldatetime"
312 | "timestamp with time zone"
313 | "timestamp without time zone" => TypeFamily::Timestamp,
314 "interval" => TypeFamily::Interval,
315 "json" | "jsonb" | "variant" => TypeFamily::Json,
316 "uuid" | "uniqueidentifier" => TypeFamily::Uuid,
317 "array" | "list" => TypeFamily::Array,
318 "map" => TypeFamily::Map,
319 "struct" | "row" | "record" | "object" => TypeFamily::Struct,
320 _ => TypeFamily::Unknown,
321 }
322}
323
324fn build_schema_map(schema: &ValidationSchema) -> HashMap<String, TableSchemaEntry> {
325 let mut map = HashMap::new();
326
327 for table in &schema.tables {
328 let column_order: Vec<String> = table.columns.iter().map(|c| lower(&c.name)).collect();
329 let columns: HashMap<String, TypeFamily> = table
330 .columns
331 .iter()
332 .map(|c| (lower(&c.name), canonical_type_family(&c.data_type)))
333 .collect();
334 let entry = TableSchemaEntry {
335 columns,
336 column_order,
337 };
338
339 let simple_name = lower(&table.name);
340 map.insert(simple_name, entry.clone());
341
342 if let Some(table_schema) = &table.schema {
343 map.insert(
344 format!("{}.{}", lower(table_schema), lower(&table.name)),
345 entry.clone(),
346 );
347 }
348
349 for alias in &table.aliases {
350 map.insert(lower(alias), entry.clone());
351 }
352 }
353
354 map
355}
356
357fn type_family_to_data_type(family: TypeFamily) -> DataType {
358 match family {
359 TypeFamily::Unknown => DataType::Unknown,
360 TypeFamily::Boolean => DataType::Boolean,
361 TypeFamily::Integer => DataType::Int {
362 length: None,
363 integer_spelling: false,
364 },
365 TypeFamily::Numeric => DataType::Double {
366 precision: None,
367 scale: None,
368 },
369 TypeFamily::String => DataType::VarChar {
370 length: None,
371 parenthesized_length: false,
372 },
373 TypeFamily::Binary => DataType::VarBinary { length: None },
374 TypeFamily::Date => DataType::Date,
375 TypeFamily::Time => DataType::Time {
376 precision: None,
377 timezone: false,
378 },
379 TypeFamily::Timestamp => DataType::Timestamp {
380 precision: None,
381 timezone: false,
382 },
383 TypeFamily::Interval => DataType::Interval {
384 unit: None,
385 to: None,
386 },
387 TypeFamily::Json => DataType::Json,
388 TypeFamily::Uuid => DataType::Uuid,
389 TypeFamily::Array => DataType::Array {
390 element_type: Box::new(DataType::Unknown),
391 dimension: None,
392 },
393 TypeFamily::Map => DataType::Map {
394 key_type: Box::new(DataType::Unknown),
395 value_type: Box::new(DataType::Unknown),
396 },
397 TypeFamily::Struct => DataType::Struct {
398 fields: Vec::new(),
399 nested: false,
400 },
401 }
402}
403
404fn build_resolver_schema(schema: &ValidationSchema) -> MappingSchema {
405 let mut mapping = MappingSchema::new();
406
407 for table in &schema.tables {
408 let columns: Vec<(String, DataType)> = table
409 .columns
410 .iter()
411 .map(|column| {
412 (
413 lower(&column.name),
414 type_family_to_data_type(canonical_type_family(&column.data_type)),
415 )
416 })
417 .collect();
418
419 let mut table_names = Vec::new();
420 table_names.push(lower(&table.name));
421 if let Some(table_schema) = &table.schema {
422 table_names.push(format!("{}.{}", lower(table_schema), lower(&table.name)));
423 }
424 for alias in &table.aliases {
425 table_names.push(lower(alias));
426 }
427
428 let mut dedup = HashSet::new();
429 for table_name in table_names {
430 if dedup.insert(table_name.clone()) {
431 let _ = mapping.add_table(&table_name, &columns, None);
432 }
433 }
434 }
435
436 mapping
437}
438
439fn collect_cte_aliases(expr: &Expression) -> HashSet<String> {
440 let mut aliases = HashSet::new();
441
442 for node in expr.dfs() {
443 match node {
444 Expression::Select(select) => {
445 if let Some(with) = &select.with {
446 for cte in &with.ctes {
447 aliases.insert(lower(&cte.alias.name));
448 }
449 }
450 }
451 Expression::Insert(insert) => {
452 if let Some(with) = &insert.with {
453 for cte in &with.ctes {
454 aliases.insert(lower(&cte.alias.name));
455 }
456 }
457 }
458 Expression::Update(update) => {
459 if let Some(with) = &update.with {
460 for cte in &with.ctes {
461 aliases.insert(lower(&cte.alias.name));
462 }
463 }
464 }
465 Expression::Delete(delete) => {
466 if let Some(with) = &delete.with {
467 for cte in &with.ctes {
468 aliases.insert(lower(&cte.alias.name));
469 }
470 }
471 }
472 Expression::Union(union) => {
473 if let Some(with) = &union.with {
474 for cte in &with.ctes {
475 aliases.insert(lower(&cte.alias.name));
476 }
477 }
478 }
479 Expression::Intersect(intersect) => {
480 if let Some(with) = &intersect.with {
481 for cte in &with.ctes {
482 aliases.insert(lower(&cte.alias.name));
483 }
484 }
485 }
486 Expression::Except(except) => {
487 if let Some(with) = &except.with {
488 for cte in &with.ctes {
489 aliases.insert(lower(&cte.alias.name));
490 }
491 }
492 }
493 Expression::Merge(merge) => {
494 if let Some(with_) = &merge.with_ {
495 if let Expression::With(with_clause) = with_.as_ref() {
496 for cte in &with_clause.ctes {
497 aliases.insert(lower(&cte.alias.name));
498 }
499 }
500 }
501 }
502 _ => {}
503 }
504 }
505
506 aliases
507}
508
509fn table_ref_candidates(table: &TableRef) -> Vec<String> {
510 let name = lower(&table.name.name);
511 let schema = table.schema.as_ref().map(|s| lower(&s.name));
512 let catalog = table.catalog.as_ref().map(|c| lower(&c.name));
513
514 let mut candidates = Vec::new();
515 if let (Some(catalog), Some(schema)) = (&catalog, &schema) {
516 candidates.push(format!("{}.{}.{}", catalog, schema, name));
517 }
518 if let Some(schema) = &schema {
519 candidates.push(format!("{}.{}", schema, name));
520 }
521 candidates.push(name);
522 candidates
523}
524
525fn table_ref_display_name(table: &TableRef) -> String {
526 let mut parts = Vec::new();
527 if let Some(catalog) = &table.catalog {
528 parts.push(catalog.name.clone());
529 }
530 if let Some(schema) = &table.schema {
531 parts.push(schema.name.clone());
532 }
533 parts.push(table.name.name.clone());
534 parts.join(".")
535}
536
/// Per-statement table information used when resolving column types.
#[derive(Debug, Default, Clone)]
struct TypeCheckContext {
    /// Schema-map keys of the tables the statement references.
    referenced_tables: HashSet<String>,
    /// Lowercased table name or alias as written in the statement -> schema-map key.
    table_aliases: HashMap<String, String>,
}
542
543fn type_family_name(family: TypeFamily) -> &'static str {
544 match family {
545 TypeFamily::Unknown => "unknown",
546 TypeFamily::Boolean => "boolean",
547 TypeFamily::Integer => "integer",
548 TypeFamily::Numeric => "numeric",
549 TypeFamily::String => "string",
550 TypeFamily::Binary => "binary",
551 TypeFamily::Date => "date",
552 TypeFamily::Time => "time",
553 TypeFamily::Timestamp => "timestamp",
554 TypeFamily::Interval => "interval",
555 TypeFamily::Json => "json",
556 TypeFamily::Uuid => "uuid",
557 TypeFamily::Array => "array",
558 TypeFamily::Map => "map",
559 TypeFamily::Struct => "struct",
560 }
561}
562
563fn is_string_like(family: TypeFamily) -> bool {
564 matches!(family, TypeFamily::String)
565}
566
567fn is_string_or_binary(family: TypeFamily) -> bool {
568 matches!(family, TypeFamily::String | TypeFamily::Binary)
569}
570
571fn type_issue(
572 strict: bool,
573 error_code: &str,
574 warning_code: &str,
575 message: impl Into<String>,
576) -> ValidationError {
577 if strict {
578 ValidationError::error(message.into(), error_code)
579 } else {
580 ValidationError::warning(message.into(), warning_code)
581 }
582}
583
584fn data_type_family(data_type: &DataType) -> TypeFamily {
585 match data_type {
586 DataType::Boolean => TypeFamily::Boolean,
587 DataType::TinyInt { .. }
588 | DataType::SmallInt { .. }
589 | DataType::Int { .. }
590 | DataType::BigInt { .. } => TypeFamily::Integer,
591 DataType::Float { .. } | DataType::Double { .. } | DataType::Decimal { .. } => {
592 TypeFamily::Numeric
593 }
594 DataType::Char { .. }
595 | DataType::VarChar { .. }
596 | DataType::String { .. }
597 | DataType::Text
598 | DataType::TextWithLength { .. }
599 | DataType::CharacterSet { .. } => TypeFamily::String,
600 DataType::Binary { .. } | DataType::VarBinary { .. } | DataType::Blob => TypeFamily::Binary,
601 DataType::Date => TypeFamily::Date,
602 DataType::Time { .. } => TypeFamily::Time,
603 DataType::Timestamp { .. } => TypeFamily::Timestamp,
604 DataType::Interval { .. } => TypeFamily::Interval,
605 DataType::Json | DataType::JsonB => TypeFamily::Json,
606 DataType::Uuid => TypeFamily::Uuid,
607 DataType::Array { .. } | DataType::List { .. } => TypeFamily::Array,
608 DataType::Map { .. } => TypeFamily::Map,
609 DataType::Struct { .. } | DataType::Object { .. } | DataType::Union { .. } => {
610 TypeFamily::Struct
611 }
612 DataType::Nullable { inner } => data_type_family(inner),
613 DataType::Custom { name } => canonical_type_family(name),
614 DataType::Unknown => TypeFamily::Unknown,
615 DataType::Bit { .. } | DataType::VarBit { .. } => TypeFamily::Binary,
616 DataType::Enum { .. } | DataType::Set { .. } => TypeFamily::String,
617 DataType::Vector { .. } => TypeFamily::Array,
618 DataType::Geometry { .. } | DataType::Geography { .. } => TypeFamily::Struct,
619 }
620}
621
622fn collect_type_check_context(
623 stmt: &Expression,
624 schema_map: &HashMap<String, TableSchemaEntry>,
625) -> TypeCheckContext {
626 fn add_table_to_context(
627 table: &TableRef,
628 schema_map: &HashMap<String, TableSchemaEntry>,
629 context: &mut TypeCheckContext,
630 ) {
631 let resolved_key = table_ref_candidates(table)
632 .into_iter()
633 .find(|k| schema_map.contains_key(k));
634
635 let Some(table_key) = resolved_key else {
636 return;
637 };
638
639 context.referenced_tables.insert(table_key.clone());
640 context
641 .table_aliases
642 .insert(lower(&table.name.name), table_key.clone());
643 if let Some(alias) = &table.alias {
644 context
645 .table_aliases
646 .insert(lower(&alias.name), table_key.clone());
647 }
648 }
649
650 let mut context = TypeCheckContext::default();
651 let cte_aliases = collect_cte_aliases(stmt);
652
653 for node in stmt.find_all(|e| matches!(e, Expression::Table(_))) {
654 let Expression::Table(table) = node else {
655 continue;
656 };
657
658 if cte_aliases.contains(&lower(&table.name.name)) {
659 continue;
660 }
661
662 add_table_to_context(table, schema_map, &mut context);
663 }
664
665 match stmt {
668 Expression::Insert(insert) => {
669 add_table_to_context(&insert.table, schema_map, &mut context);
670 }
671 Expression::Update(update) => {
672 add_table_to_context(&update.table, schema_map, &mut context);
673 for table in &update.extra_tables {
674 add_table_to_context(table, schema_map, &mut context);
675 }
676 }
677 Expression::Delete(delete) => {
678 add_table_to_context(&delete.table, schema_map, &mut context);
679 for table in &delete.using {
680 add_table_to_context(table, schema_map, &mut context);
681 }
682 for table in &delete.tables {
683 add_table_to_context(table, schema_map, &mut context);
684 }
685 }
686 _ => {}
687 }
688
689 context
690}
691
692fn resolve_table_schema_entry<'a>(
693 table: &TableRef,
694 schema_map: &'a HashMap<String, TableSchemaEntry>,
695) -> Option<(String, &'a TableSchemaEntry)> {
696 let key = table_ref_candidates(table)
697 .into_iter()
698 .find(|k| schema_map.contains_key(k))?;
699 let entry = schema_map.get(&key)?;
700 Some((key, entry))
701}
702
703fn reference_issue(strict: bool, message: impl Into<String>) -> ValidationError {
704 if strict {
705 ValidationError::error(
706 message.into(),
707 validation_codes::E_INVALID_FOREIGN_KEY_REFERENCE,
708 )
709 } else {
710 ValidationError::warning(message.into(), validation_codes::W_WEAK_REFERENCE_INTEGRITY)
711 }
712}
713
714fn reference_table_candidates(
715 table_name: &str,
716 explicit_schema: Option<&str>,
717 source_schema: Option<&str>,
718) -> Vec<String> {
719 let mut candidates = Vec::new();
720 let raw = lower(table_name);
721
722 if let Some(schema) = explicit_schema {
723 candidates.push(format!("{}.{}", lower(schema), raw));
724 }
725
726 if raw.contains('.') {
727 candidates.push(raw.clone());
728 if let Some(last) = raw.rsplit('.').next() {
729 candidates.push(last.to_string());
730 }
731 } else {
732 if let Some(schema) = source_schema {
733 candidates.push(format!("{}.{}", lower(schema), raw));
734 }
735 candidates.push(raw);
736 }
737
738 let mut dedup = HashSet::new();
739 candidates
740 .into_iter()
741 .filter(|c| dedup.insert(c.clone()))
742 .collect()
743}
744
745fn resolve_reference_table_key(
746 table_name: &str,
747 explicit_schema: Option<&str>,
748 source_schema: Option<&str>,
749 schema_map: &HashMap<String, TableSchemaEntry>,
750) -> Option<String> {
751 reference_table_candidates(table_name, explicit_schema, source_schema)
752 .into_iter()
753 .find(|candidate| schema_map.contains_key(candidate))
754}
755
756fn key_types_compatible(source: TypeFamily, target: TypeFamily) -> bool {
757 if source == TypeFamily::Unknown || target == TypeFamily::Unknown {
758 return true;
759 }
760 if source == target {
761 return true;
762 }
763 if source.is_numeric() && target.is_numeric() {
764 return true;
765 }
766 if source.is_temporal() && target.is_temporal() {
767 return true;
768 }
769 false
770}
771
772fn table_key_hints(table: &SchemaTable) -> HashSet<String> {
773 let mut hints = HashSet::new();
774 for column in &table.columns {
775 if column.primary_key || column.unique {
776 hints.insert(lower(&column.name));
777 }
778 }
779 for key_col in &table.primary_key {
780 hints.insert(lower(key_col));
781 }
782 for group in &table.unique_keys {
783 if group.len() == 1 {
784 if let Some(col) = group.first() {
785 hints.insert(lower(col));
786 }
787 }
788 }
789 hints
790}
791
/// Validates every foreign-key declaration in `schema` against `schema_map`.
///
/// For column-level references (`SchemaColumn::references`) and table-level foreign keys
/// (`SchemaTable::foreign_keys`) this checks that the target table and column(s) exist
/// and that source and target type families are compatible (see
/// [`key_types_compatible`]). Table-level keys are additionally checked for empty or
/// unequal-length column lists and for unknown source columns.
///
/// Those problems are reported through [`reference_issue`], so they are errors when
/// `strict` is set and warnings otherwise. A target column that is not listed in the
/// target table's key hints always produces a `W_WEAK_REFERENCE_INTEGRITY` warning.
///
/// Returns the diagnostics in schema declaration order.
fn check_reference_integrity(
    schema: &ValidationSchema,
    schema_map: &HashMap<String, TableSchemaEntry>,
    strict: bool,
) -> Vec<ValidationError> {
    let mut errors = Vec::new();

    // Key hints per table, keyed by simple and schema-qualified name.
    // NOTE(review): aliases are not registered here although `build_schema_map` does
    // register them, so a reference that resolves through an alias skips the key-hint
    // warning below.
    let mut key_hints_lookup: HashMap<String, HashSet<String>> = HashMap::new();
    for table in &schema.tables {
        let simple = lower(&table.name);
        key_hints_lookup.insert(simple, table_key_hints(table));
        if let Some(schema_name) = &table.schema {
            let qualified = format!("{}.{}", lower(schema_name), lower(&table.name));
            key_hints_lookup.insert(qualified, table_key_hints(table));
        }
    }

    for table in &schema.tables {
        // Display name keeps the original casing for messages.
        let source_table_display = if let Some(schema_name) = &table.schema {
            format!("{}.{}", schema_name, table.name)
        } else {
            table.name.clone()
        };
        let source_schema = table.schema.as_deref();
        let source_columns: HashMap<String, TypeFamily> = table
            .columns
            .iter()
            .map(|col| (lower(&col.name), canonical_type_family(&col.data_type)))
            .collect();

        // Column-level references.
        for source_col in &table.columns {
            let Some(reference) = &source_col.references else {
                continue;
            };
            let source_type = canonical_type_family(&source_col.data_type);

            let Some(target_key) = resolve_reference_table_key(
                &reference.table,
                reference.schema.as_deref(),
                source_schema,
                schema_map,
            ) else {
                errors.push(reference_issue(
                    strict,
                    format!(
                        "Foreign key reference '{}.{}' points to unknown table '{}'",
                        source_table_display, source_col.name, reference.table
                    ),
                ));
                continue;
            };

            let target_column = lower(&reference.column);
            // Defensive: `resolve_reference_table_key` only returns keys that are present
            // in `schema_map`, so this branch is not expected to run.
            let Some(target_entry) = schema_map.get(&target_key) else {
                errors.push(reference_issue(
                    strict,
                    format!(
                        "Foreign key reference '{}.{}' points to unknown table '{}'",
                        source_table_display, source_col.name, reference.table
                    ),
                ));
                continue;
            };

            let Some(target_type) = target_entry.columns.get(&target_column).copied() else {
                errors.push(reference_issue(
                    strict,
                    format!(
                        "Foreign key reference '{}.{}' points to unknown column '{}.{}'",
                        source_table_display, source_col.name, target_key, reference.column
                    ),
                ));
                continue;
            };

            if !key_types_compatible(source_type, target_type) {
                errors.push(reference_issue(
                    strict,
                    format!(
                        "Foreign key type mismatch for '{}.{}' -> '{}.{}': {} vs {}",
                        source_table_display,
                        source_col.name,
                        target_key,
                        reference.column,
                        type_family_name(source_type),
                        type_family_name(target_type)
                    ),
                ));
            }

            // Always a warning, regardless of `strict`.
            if let Some(target_key_hints) = key_hints_lookup.get(&target_key) {
                if !target_key_hints.contains(&target_column) {
                    errors.push(ValidationError::warning(
                        format!(
                            "Referenced column '{}.{}' is not marked as primary/unique key",
                            target_key, reference.column
                        ),
                        validation_codes::W_WEAK_REFERENCE_INTEGRITY,
                    ));
                }
            }
        }

        // Table-level foreign keys.
        for foreign_key in &table.foreign_keys {
            if foreign_key.columns.is_empty() || foreign_key.references.columns.is_empty() {
                errors.push(reference_issue(
                    strict,
                    format!(
                        "Table-level foreign key on '{}' has empty source or target column list",
                        source_table_display
                    ),
                ));
                continue;
            }
            if foreign_key.columns.len() != foreign_key.references.columns.len() {
                errors.push(reference_issue(
                    strict,
                    format!(
                        "Table-level foreign key on '{}' has {} source columns but {} target columns",
                        source_table_display,
                        foreign_key.columns.len(),
                        foreign_key.references.columns.len()
                    ),
                ));
                continue;
            }

            let Some(target_key) = resolve_reference_table_key(
                &foreign_key.references.table,
                foreign_key.references.schema.as_deref(),
                source_schema,
                schema_map,
            ) else {
                errors.push(reference_issue(
                    strict,
                    format!(
                        "Table-level foreign key on '{}' points to unknown table '{}'",
                        source_table_display, foreign_key.references.table
                    ),
                ));
                continue;
            };

            // Defensive for the same reason as in the column-level loop above.
            let Some(target_entry) = schema_map.get(&target_key) else {
                errors.push(reference_issue(
                    strict,
                    format!(
                        "Table-level foreign key on '{}' points to unknown table '{}'",
                        source_table_display, foreign_key.references.table
                    ),
                ));
                continue;
            };

            // Lengths were checked above, so `zip` pairs every source column with its target.
            for (source_col, target_col) in foreign_key
                .columns
                .iter()
                .zip(foreign_key.references.columns.iter())
            {
                let source_col_name = lower(source_col);
                let target_col_name = lower(target_col);

                let Some(source_type) = source_columns.get(&source_col_name).copied() else {
                    errors.push(reference_issue(
                        strict,
                        format!(
                            "Table-level foreign key on '{}' references unknown source column '{}'",
                            source_table_display, source_col
                        ),
                    ));
                    continue;
                };

                let Some(target_type) = target_entry.columns.get(&target_col_name).copied() else {
                    errors.push(reference_issue(
                        strict,
                        format!(
                            "Table-level foreign key on '{}' references unknown target column '{}.{}'",
                            source_table_display, target_key, target_col
                        ),
                    ));
                    continue;
                };

                if !key_types_compatible(source_type, target_type) {
                    errors.push(reference_issue(
                        strict,
                        format!(
                            "Table-level foreign key type mismatch '{}.{}' -> '{}.{}': {} vs {}",
                            source_table_display,
                            source_col,
                            target_key,
                            target_col,
                            type_family_name(source_type),
                            type_family_name(target_type)
                        ),
                    ));
                }

                // NOTE(review): the hints are per column and `table_key_hints` only
                // includes single-column unique keys, so a composite foreign key that
                // targets a composite unique key is warned about column by column.
                if let Some(target_key_hints) = key_hints_lookup.get(&target_key) {
                    if !target_key_hints.contains(&target_col_name) {
                        errors.push(ValidationError::warning(
                            format!(
                                "Referenced column '{}.{}' is not marked as primary/unique key",
                                target_key, target_col
                            ),
                            validation_codes::W_WEAK_REFERENCE_INTEGRITY,
                        ));
                    }
                }
            }
        }
    }

    errors
}
1008
1009fn resolve_unqualified_column_type(
1010 column_name: &str,
1011 schema_map: &HashMap<String, TableSchemaEntry>,
1012 context: &TypeCheckContext,
1013) -> TypeFamily {
1014 let candidate_tables: Vec<&String> = if !context.referenced_tables.is_empty() {
1015 context.referenced_tables.iter().collect()
1016 } else {
1017 schema_map.keys().collect()
1018 };
1019
1020 let mut families = HashSet::new();
1021 for table_name in candidate_tables {
1022 if let Some(table_schema) = schema_map.get(table_name) {
1023 if let Some(family) = table_schema.columns.get(column_name) {
1024 families.insert(*family);
1025 }
1026 }
1027 }
1028
1029 if families.len() == 1 {
1030 *families.iter().next().unwrap_or(&TypeFamily::Unknown)
1031 } else {
1032 TypeFamily::Unknown
1033 }
1034}
1035
1036fn resolve_column_type(
1037 column: &Column,
1038 schema_map: &HashMap<String, TableSchemaEntry>,
1039 context: &TypeCheckContext,
1040) -> TypeFamily {
1041 let column_name = lower(&column.name.name);
1042 if column_name.is_empty() {
1043 return TypeFamily::Unknown;
1044 }
1045
1046 if let Some(table) = &column.table {
1047 let mut table_key = lower(&table.name);
1048 if let Some(mapped) = context.table_aliases.get(&table_key) {
1049 table_key = mapped.clone();
1050 }
1051
1052 return schema_map
1053 .get(&table_key)
1054 .and_then(|t| t.columns.get(&column_name))
1055 .copied()
1056 .unwrap_or(TypeFamily::Unknown);
1057 }
1058
1059 resolve_unqualified_column_type(&column_name, schema_map, context)
1060}
1061
/// Read-only view over the schema map and the statement context, used as the schema
/// argument for `annotate_types`.
struct TypeInferenceSchema<'a> {
    /// Lowercased table key -> column data.
    schema_map: &'a HashMap<String, TableSchemaEntry>,
    /// Tables and aliases referenced by the statement being checked.
    context: &'a TypeCheckContext,
}
1066
1067impl TypeInferenceSchema<'_> {
1068 fn resolve_table_key(&self, table: &str) -> Option<String> {
1069 let mut table_key = lower(table);
1070 if let Some(mapped) = self.context.table_aliases.get(&table_key) {
1071 table_key = mapped.clone();
1072 }
1073 if self.schema_map.contains_key(&table_key) {
1074 Some(table_key)
1075 } else {
1076 None
1077 }
1078 }
1079}
1080
1081impl SqlSchema for TypeInferenceSchema<'_> {
1082 fn dialect(&self) -> Option<DialectType> {
1083 None
1084 }
1085
1086 fn add_table(
1087 &mut self,
1088 _table: &str,
1089 _columns: &[(String, DataType)],
1090 _dialect: Option<DialectType>,
1091 ) -> SchemaResult<()> {
1092 Err(SchemaError::InvalidStructure(
1093 "Type inference schema is read-only".to_string(),
1094 ))
1095 }
1096
1097 fn column_names(&self, table: &str) -> SchemaResult<Vec<String>> {
1098 let table_key = self
1099 .resolve_table_key(table)
1100 .ok_or_else(|| SchemaError::TableNotFound(table.to_string()))?;
1101 let entry = self
1102 .schema_map
1103 .get(&table_key)
1104 .ok_or_else(|| SchemaError::TableNotFound(table.to_string()))?;
1105 Ok(entry.column_order.clone())
1106 }
1107
1108 fn get_column_type(&self, table: &str, column: &str) -> SchemaResult<DataType> {
1109 let col_name = lower(column);
1110 if table.is_empty() {
1111 let family = resolve_unqualified_column_type(&col_name, self.schema_map, self.context);
1112 return if family == TypeFamily::Unknown {
1113 Err(SchemaError::ColumnNotFound {
1114 table: "<unqualified>".to_string(),
1115 column: column.to_string(),
1116 })
1117 } else {
1118 Ok(type_family_to_data_type(family))
1119 };
1120 }
1121
1122 let table_key = self
1123 .resolve_table_key(table)
1124 .ok_or_else(|| SchemaError::TableNotFound(table.to_string()))?;
1125 let entry = self
1126 .schema_map
1127 .get(&table_key)
1128 .ok_or_else(|| SchemaError::TableNotFound(table.to_string()))?;
1129 let family =
1130 entry
1131 .columns
1132 .get(&col_name)
1133 .copied()
1134 .ok_or_else(|| SchemaError::ColumnNotFound {
1135 table: table.to_string(),
1136 column: column.to_string(),
1137 })?;
1138 Ok(type_family_to_data_type(family))
1139 }
1140
1141 fn has_column(&self, table: &str, column: &str) -> bool {
1142 self.get_column_type(table, column).is_ok()
1143 }
1144
1145 fn supported_table_args(&self) -> &[&str] {
1146 TABLE_PARTS
1147 }
1148
1149 fn is_empty(&self) -> bool {
1150 self.schema_map.is_empty()
1151 }
1152
1153 fn depth(&self) -> usize {
1154 1
1155 }
1156}
1157
1158fn infer_expression_type_family(
1159 expr: &Expression,
1160 schema_map: &HashMap<String, TableSchemaEntry>,
1161 context: &TypeCheckContext,
1162) -> TypeFamily {
1163 let inference_schema = TypeInferenceSchema {
1164 schema_map,
1165 context,
1166 };
1167 if let Some(data_type) = annotate_types(expr, Some(&inference_schema), None) {
1168 let family = data_type_family(&data_type);
1169 if family != TypeFamily::Unknown {
1170 return family;
1171 }
1172 }
1173
1174 infer_expression_type_family_fallback(expr, schema_map, context)
1175}
1176
1177fn infer_expression_type_family_fallback(
1178 expr: &Expression,
1179 schema_map: &HashMap<String, TableSchemaEntry>,
1180 context: &TypeCheckContext,
1181) -> TypeFamily {
1182 match expr {
1183 Expression::Literal(literal) => match literal {
1184 crate::expressions::Literal::Number(value) => {
1185 if value.contains('.') || value.contains('e') || value.contains('E') {
1186 TypeFamily::Numeric
1187 } else {
1188 TypeFamily::Integer
1189 }
1190 }
1191 crate::expressions::Literal::HexNumber(_) => TypeFamily::Integer,
1192 crate::expressions::Literal::Date(_) => TypeFamily::Date,
1193 crate::expressions::Literal::Time(_) => TypeFamily::Time,
1194 crate::expressions::Literal::Timestamp(_)
1195 | crate::expressions::Literal::Datetime(_) => TypeFamily::Timestamp,
1196 crate::expressions::Literal::HexString(_)
1197 | crate::expressions::Literal::BitString(_)
1198 | crate::expressions::Literal::ByteString(_) => TypeFamily::Binary,
1199 _ => TypeFamily::String,
1200 },
1201 Expression::Boolean(_) => TypeFamily::Boolean,
1202 Expression::Null(_) => TypeFamily::Unknown,
1203 Expression::Column(column) => resolve_column_type(column, schema_map, context),
1204 Expression::Cast(cast) | Expression::TryCast(cast) | Expression::SafeCast(cast) => {
1205 data_type_family(&cast.to)
1206 }
1207 Expression::Alias(alias) => {
1208 infer_expression_type_family_fallback(&alias.this, schema_map, context)
1209 }
1210 Expression::Neg(unary) => {
1211 infer_expression_type_family_fallback(&unary.this, schema_map, context)
1212 }
1213 Expression::Add(op) | Expression::Sub(op) | Expression::Mul(op) => {
1214 let left = infer_expression_type_family_fallback(&op.left, schema_map, context);
1215 let right = infer_expression_type_family_fallback(&op.right, schema_map, context);
1216 if left == TypeFamily::Unknown || right == TypeFamily::Unknown {
1217 TypeFamily::Unknown
1218 } else if left == TypeFamily::Integer && right == TypeFamily::Integer {
1219 TypeFamily::Integer
1220 } else if left.is_numeric() && right.is_numeric() {
1221 TypeFamily::Numeric
1222 } else if left.is_temporal() || right.is_temporal() {
1223 left
1224 } else {
1225 TypeFamily::Unknown
1226 }
1227 }
1228 Expression::Div(_) | Expression::Mod(_) => TypeFamily::Numeric,
1229 Expression::Concat(_) => TypeFamily::String,
1230 Expression::Eq(_)
1231 | Expression::Neq(_)
1232 | Expression::Lt(_)
1233 | Expression::Lte(_)
1234 | Expression::Gt(_)
1235 | Expression::Gte(_)
1236 | Expression::Like(_)
1237 | Expression::ILike(_)
1238 | Expression::And(_)
1239 | Expression::Or(_)
1240 | Expression::Not(_)
1241 | Expression::Between(_)
1242 | Expression::In(_)
1243 | Expression::IsNull(_)
1244 | Expression::IsTrue(_)
1245 | Expression::IsFalse(_)
1246 | Expression::Is(_) => TypeFamily::Boolean,
1247 Expression::Length(_) => TypeFamily::Integer,
1248 Expression::Upper(_)
1249 | Expression::Lower(_)
1250 | Expression::Trim(_)
1251 | Expression::LTrim(_)
1252 | Expression::RTrim(_)
1253 | Expression::Replace(_)
1254 | Expression::Substring(_)
1255 | Expression::Left(_)
1256 | Expression::Right(_)
1257 | Expression::Repeat(_)
1258 | Expression::Lpad(_)
1259 | Expression::Rpad(_)
1260 | Expression::ConcatWs(_) => TypeFamily::String,
1261 Expression::Abs(_)
1262 | Expression::Round(_)
1263 | Expression::Floor(_)
1264 | Expression::Ceil(_)
1265 | Expression::Power(_)
1266 | Expression::Sqrt(_)
1267 | Expression::Cbrt(_)
1268 | Expression::Ln(_)
1269 | Expression::Log(_)
1270 | Expression::Exp(_) => TypeFamily::Numeric,
1271 Expression::DateAdd(_) | Expression::DateSub(_) | Expression::ToDate(_) => TypeFamily::Date,
1272 Expression::ToTimestamp(_) => TypeFamily::Timestamp,
1273 Expression::DateDiff(_) | Expression::Extract(_) => TypeFamily::Integer,
1274 Expression::CurrentDate(_) => TypeFamily::Date,
1275 Expression::CurrentTime(_) => TypeFamily::Time,
1276 Expression::CurrentTimestamp(_) | Expression::CurrentTimestampLTZ(_) => {
1277 TypeFamily::Timestamp
1278 }
1279 Expression::Interval(_) => TypeFamily::Interval,
1280 _ => TypeFamily::Unknown,
1281 }
1282}
1283
1284fn are_comparable(left: TypeFamily, right: TypeFamily) -> bool {
1285 if left == TypeFamily::Unknown || right == TypeFamily::Unknown {
1286 return true;
1287 }
1288 if left == right {
1289 return true;
1290 }
1291 if left.is_numeric() && right.is_numeric() {
1292 return true;
1293 }
1294 if left.is_temporal() && right.is_temporal() {
1295 return true;
1296 }
1297 false
1298}
1299
1300fn check_function_argument(
1301 errors: &mut Vec<ValidationError>,
1302 strict: bool,
1303 function_name: &str,
1304 arg_index: usize,
1305 family: TypeFamily,
1306 expected: &str,
1307 valid: bool,
1308) {
1309 if family == TypeFamily::Unknown || valid {
1310 return;
1311 }
1312
1313 errors.push(type_issue(
1314 strict,
1315 validation_codes::E_INVALID_FUNCTION_ARGUMENT_TYPE,
1316 validation_codes::W_FUNCTION_ARGUMENT_COERCION,
1317 format!(
1318 "Function '{}' argument {} expects {}, found {}",
1319 function_name,
1320 arg_index + 1,
1321 expected,
1322 type_family_name(family)
1323 ),
1324 ));
1325}
1326
1327fn function_dispatch_name(name: &str) -> String {
1328 let upper = name
1329 .rsplit('.')
1330 .next()
1331 .unwrap_or(name)
1332 .trim()
1333 .to_uppercase();
1334 lower(canonical_typed_function_name_upper(&upper))
1335}
1336
1337fn check_generic_function(
1338 function: &Function,
1339 schema_map: &HashMap<String, TableSchemaEntry>,
1340 context: &TypeCheckContext,
1341 strict: bool,
1342 errors: &mut Vec<ValidationError>,
1343) {
1344 let name = function_dispatch_name(&function.name);
1345
1346 let arg_family = |index: usize| -> Option<TypeFamily> {
1347 function
1348 .args
1349 .get(index)
1350 .map(|arg| infer_expression_type_family(arg, schema_map, context))
1351 };
1352
1353 match name.as_str() {
1354 "abs" | "sqrt" | "cbrt" | "ln" | "exp" => {
1355 if let Some(family) = arg_family(0) {
1356 check_function_argument(
1357 errors,
1358 strict,
1359 &name,
1360 0,
1361 family,
1362 "a numeric argument",
1363 family.is_numeric(),
1364 );
1365 }
1366 }
1367 "round" | "floor" | "ceil" | "ceiling" => {
1368 if let Some(family) = arg_family(0) {
1369 check_function_argument(
1370 errors,
1371 strict,
1372 &name,
1373 0,
1374 family,
1375 "a numeric argument",
1376 family.is_numeric(),
1377 );
1378 }
1379 if let Some(family) = arg_family(1) {
1380 check_function_argument(
1381 errors,
1382 strict,
1383 &name,
1384 1,
1385 family,
1386 "a numeric argument",
1387 family.is_numeric(),
1388 );
1389 }
1390 }
1391 "power" | "pow" => {
1392 for i in [0_usize, 1_usize] {
1393 if let Some(family) = arg_family(i) {
1394 check_function_argument(
1395 errors,
1396 strict,
1397 &name,
1398 i,
1399 family,
1400 "a numeric argument",
1401 family.is_numeric(),
1402 );
1403 }
1404 }
1405 }
1406 "length" | "char_length" | "character_length" => {
1407 if let Some(family) = arg_family(0) {
1408 check_function_argument(
1409 errors,
1410 strict,
1411 &name,
1412 0,
1413 family,
1414 "a string or binary argument",
1415 is_string_or_binary(family),
1416 );
1417 }
1418 }
1419 "upper" | "lower" | "trim" | "ltrim" | "rtrim" | "reverse" => {
1420 if let Some(family) = arg_family(0) {
1421 check_function_argument(
1422 errors,
1423 strict,
1424 &name,
1425 0,
1426 family,
1427 "a string argument",
1428 is_string_like(family),
1429 );
1430 }
1431 }
1432 "substring" | "substr" => {
1433 if let Some(family) = arg_family(0) {
1434 check_function_argument(
1435 errors,
1436 strict,
1437 &name,
1438 0,
1439 family,
1440 "a string argument",
1441 is_string_like(family),
1442 );
1443 }
1444 if let Some(family) = arg_family(1) {
1445 check_function_argument(
1446 errors,
1447 strict,
1448 &name,
1449 1,
1450 family,
1451 "a numeric argument",
1452 family.is_numeric(),
1453 );
1454 }
1455 if let Some(family) = arg_family(2) {
1456 check_function_argument(
1457 errors,
1458 strict,
1459 &name,
1460 2,
1461 family,
1462 "a numeric argument",
1463 family.is_numeric(),
1464 );
1465 }
1466 }
1467 "replace" => {
1468 for i in [0_usize, 1_usize, 2_usize] {
1469 if let Some(family) = arg_family(i) {
1470 check_function_argument(
1471 errors,
1472 strict,
1473 &name,
1474 i,
1475 family,
1476 "a string argument",
1477 is_string_like(family),
1478 );
1479 }
1480 }
1481 }
1482 "left" | "right" | "repeat" | "lpad" | "rpad" => {
1483 if let Some(family) = arg_family(0) {
1484 check_function_argument(
1485 errors,
1486 strict,
1487 &name,
1488 0,
1489 family,
1490 "a string argument",
1491 is_string_like(family),
1492 );
1493 }
1494 if let Some(family) = arg_family(1) {
1495 check_function_argument(
1496 errors,
1497 strict,
1498 &name,
1499 1,
1500 family,
1501 "a numeric argument",
1502 family.is_numeric(),
1503 );
1504 }
1505 if (name == "lpad" || name == "rpad") && function.args.len() > 2 {
1506 if let Some(family) = arg_family(2) {
1507 check_function_argument(
1508 errors,
1509 strict,
1510 &name,
1511 2,
1512 family,
1513 "a string argument",
1514 is_string_like(family),
1515 );
1516 }
1517 }
1518 }
1519 _ => {}
1520 }
1521}
1522
/// One column-to-column reference declared in the validation schema.
///
/// Instances are produced by `build_declared_relationships`, which stores
/// resolved schema-map keys for the tables and lower-cased column names.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
struct DeclaredRelationship {
    /// Schema-map key of the table that declares the reference.
    source_table: String,
    /// Referencing column in `source_table`.
    source_column: String,
    /// Schema-map key of the referenced table.
    target_table: String,
    /// Referenced column in `target_table`.
    target_column: String,
}
1530
1531fn build_declared_relationships(
1532 schema: &ValidationSchema,
1533 schema_map: &HashMap<String, TableSchemaEntry>,
1534) -> Vec<DeclaredRelationship> {
1535 let mut relationships = HashSet::new();
1536
1537 for table in &schema.tables {
1538 let Some(source_key) =
1539 resolve_reference_table_key(&table.name, table.schema.as_deref(), None, schema_map)
1540 else {
1541 continue;
1542 };
1543
1544 for column in &table.columns {
1545 let Some(reference) = &column.references else {
1546 continue;
1547 };
1548 let Some(target_key) = resolve_reference_table_key(
1549 &reference.table,
1550 reference.schema.as_deref(),
1551 table.schema.as_deref(),
1552 schema_map,
1553 ) else {
1554 continue;
1555 };
1556
1557 relationships.insert(DeclaredRelationship {
1558 source_table: source_key.clone(),
1559 source_column: lower(&column.name),
1560 target_table: target_key,
1561 target_column: lower(&reference.column),
1562 });
1563 }
1564
1565 for foreign_key in &table.foreign_keys {
1566 if foreign_key.columns.len() != foreign_key.references.columns.len() {
1567 continue;
1568 }
1569 let Some(target_key) = resolve_reference_table_key(
1570 &foreign_key.references.table,
1571 foreign_key.references.schema.as_deref(),
1572 table.schema.as_deref(),
1573 schema_map,
1574 ) else {
1575 continue;
1576 };
1577
1578 for (source_col, target_col) in foreign_key
1579 .columns
1580 .iter()
1581 .zip(foreign_key.references.columns.iter())
1582 {
1583 relationships.insert(DeclaredRelationship {
1584 source_table: source_key.clone(),
1585 source_column: lower(source_col),
1586 target_table: target_key.clone(),
1587 target_column: lower(target_col),
1588 });
1589 }
1590 }
1591 }
1592
1593 relationships.into_iter().collect()
1594}
1595
1596fn resolve_column_binding(
1597 column: &Column,
1598 schema_map: &HashMap<String, TableSchemaEntry>,
1599 context: &TypeCheckContext,
1600 resolver: &mut Resolver<'_>,
1601) -> Option<(String, String)> {
1602 let column_name = lower(&column.name.name);
1603 if column_name.is_empty() {
1604 return None;
1605 }
1606
1607 if let Some(table) = &column.table {
1608 let mut table_key = lower(&table.name);
1609 if let Some(mapped) = context.table_aliases.get(&table_key) {
1610 table_key = mapped.clone();
1611 }
1612 if schema_map.contains_key(&table_key) {
1613 return Some((table_key, column_name));
1614 }
1615 return None;
1616 }
1617
1618 if let Some(resolved_source) = resolver.get_table(&column_name) {
1619 let mut table_key = lower(&resolved_source);
1620 if let Some(mapped) = context.table_aliases.get(&table_key) {
1621 table_key = mapped.clone();
1622 }
1623 if schema_map.contains_key(&table_key) {
1624 return Some((table_key, column_name));
1625 }
1626 }
1627
1628 let candidates: Vec<String> = context
1629 .referenced_tables
1630 .iter()
1631 .filter_map(|table_name| {
1632 schema_map
1633 .get(table_name)
1634 .filter(|entry| entry.columns.contains_key(&column_name))
1635 .map(|_| table_name.clone())
1636 })
1637 .collect();
1638 if candidates.len() == 1 {
1639 return Some((candidates[0].clone(), column_name));
1640 }
1641 None
1642}
1643
1644fn extract_join_equality_pairs(
1645 expr: &Expression,
1646 schema_map: &HashMap<String, TableSchemaEntry>,
1647 context: &TypeCheckContext,
1648 resolver: &mut Resolver<'_>,
1649 pairs: &mut Vec<((String, String), (String, String))>,
1650) {
1651 match expr {
1652 Expression::And(op) => {
1653 extract_join_equality_pairs(&op.left, schema_map, context, resolver, pairs);
1654 extract_join_equality_pairs(&op.right, schema_map, context, resolver, pairs);
1655 }
1656 Expression::Paren(paren) => {
1657 extract_join_equality_pairs(&paren.this, schema_map, context, resolver, pairs);
1658 }
1659 Expression::Eq(op) => {
1660 let (Expression::Column(left_col), Expression::Column(right_col)) =
1661 (&op.left, &op.right)
1662 else {
1663 return;
1664 };
1665 let Some(left) = resolve_column_binding(left_col, schema_map, context, resolver) else {
1666 return;
1667 };
1668 let Some(right) = resolve_column_binding(right_col, schema_map, context, resolver)
1669 else {
1670 return;
1671 };
1672 pairs.push((left, right));
1673 }
1674 _ => {}
1675 }
1676}
1677
1678fn relationship_matches_pair(
1679 relationship: &DeclaredRelationship,
1680 left_table: &str,
1681 left_column: &str,
1682 right_table: &str,
1683 right_column: &str,
1684) -> bool {
1685 (relationship.source_table == left_table
1686 && relationship.source_column == left_column
1687 && relationship.target_table == right_table
1688 && relationship.target_column == right_column)
1689 || (relationship.source_table == right_table
1690 && relationship.source_column == right_column
1691 && relationship.target_table == left_table
1692 && relationship.target_column == left_column)
1693}
1694
1695fn resolved_table_key_from_expr(
1696 expr: &Expression,
1697 schema_map: &HashMap<String, TableSchemaEntry>,
1698) -> Option<String> {
1699 match expr {
1700 Expression::Table(table) => resolve_table_schema_entry(table, schema_map).map(|(k, _)| k),
1701 Expression::Alias(alias) => resolved_table_key_from_expr(&alias.this, schema_map),
1702 _ => None,
1703 }
1704}
1705
1706fn select_from_table_keys(
1707 select: &crate::expressions::Select,
1708 schema_map: &HashMap<String, TableSchemaEntry>,
1709) -> HashSet<String> {
1710 let mut keys = HashSet::new();
1711 if let Some(from_clause) = &select.from {
1712 for expr in &from_clause.expressions {
1713 if let Some(key) = resolved_table_key_from_expr(expr, schema_map) {
1714 keys.insert(key);
1715 }
1716 }
1717 }
1718 keys
1719}
1720
1721fn is_natural_or_implied_join(kind: JoinKind) -> bool {
1722 matches!(
1723 kind,
1724 JoinKind::Natural
1725 | JoinKind::NaturalLeft
1726 | JoinKind::NaturalRight
1727 | JoinKind::NaturalFull
1728 | JoinKind::CrossApply
1729 | JoinKind::OuterApply
1730 | JoinKind::AsOf
1731 | JoinKind::AsOfLeft
1732 | JoinKind::AsOfRight
1733 | JoinKind::Lateral
1734 | JoinKind::LeftLateral
1735 )
1736}
1737
1738fn check_query_reference_quality(
1739 stmt: &Expression,
1740 schema_map: &HashMap<String, TableSchemaEntry>,
1741 resolver_schema: &MappingSchema,
1742 strict: bool,
1743 relationships: &[DeclaredRelationship],
1744) -> Vec<ValidationError> {
1745 let mut errors = Vec::new();
1746
1747 for node in stmt.dfs() {
1748 let Expression::Select(select) = node else {
1749 continue;
1750 };
1751
1752 let select_expr = Expression::Select(select.clone());
1753 let context = collect_type_check_context(&select_expr, schema_map);
1754 let scope = build_scope(&select_expr);
1755 let mut resolver = Resolver::new(&scope, resolver_schema, true);
1756
1757 if context.referenced_tables.len() > 1 {
1758 let using_columns: HashSet<String> = select
1759 .joins
1760 .iter()
1761 .flat_map(|join| join.using.iter().map(|id| lower(&id.name)))
1762 .collect();
1763
1764 let mut seen = HashSet::new();
1765 for column_expr in select_expr
1766 .find_all(|e| matches!(e, Expression::Column(Column { table: None, .. })))
1767 {
1768 let Expression::Column(column) = column_expr else {
1769 continue;
1770 };
1771
1772 let col_name = lower(&column.name.name);
1773 if col_name.is_empty()
1774 || using_columns.contains(&col_name)
1775 || !seen.insert(col_name.clone())
1776 {
1777 continue;
1778 }
1779
1780 if resolver.is_ambiguous(&col_name) {
1781 let source_count = resolver.sources_for_column(&col_name).len();
1782 errors.push(if strict {
1783 ValidationError::error(
1784 format!(
1785 "Ambiguous unqualified column '{}' found in {} referenced tables",
1786 col_name, source_count
1787 ),
1788 validation_codes::E_AMBIGUOUS_COLUMN_REFERENCE,
1789 )
1790 } else {
1791 ValidationError::warning(
1792 format!(
1793 "Ambiguous unqualified column '{}' found in {} referenced tables",
1794 col_name, source_count
1795 ),
1796 validation_codes::W_WEAK_REFERENCE_INTEGRITY,
1797 )
1798 });
1799 }
1800 }
1801 }
1802
1803 let mut cumulative_left_tables = select_from_table_keys(select, schema_map);
1804
1805 for join in &select.joins {
1806 let right_table_key = resolved_table_key_from_expr(&join.this, schema_map);
1807 let has_explicit_condition = join.on.is_some() || !join.using.is_empty();
1808 let cartesian_like_kind = matches!(
1809 join.kind,
1810 JoinKind::Cross
1811 | JoinKind::Implicit
1812 | JoinKind::Array
1813 | JoinKind::LeftArray
1814 | JoinKind::Paste
1815 );
1816
1817 if right_table_key.is_some()
1818 && (cartesian_like_kind
1819 || (!has_explicit_condition && !is_natural_or_implied_join(join.kind)))
1820 {
1821 errors.push(ValidationError::warning(
1822 "Potential cartesian join: JOIN without ON/USING condition",
1823 validation_codes::W_CARTESIAN_JOIN,
1824 ));
1825 }
1826
1827 if let (Some(on_expr), Some(right_key)) = (&join.on, right_table_key.clone()) {
1828 if join.using.is_empty() {
1829 let mut eq_pairs = Vec::new();
1830 extract_join_equality_pairs(
1831 on_expr,
1832 schema_map,
1833 &context,
1834 &mut resolver,
1835 &mut eq_pairs,
1836 );
1837
1838 let relevant_relationships: Vec<&DeclaredRelationship> = relationships
1839 .iter()
1840 .filter(|rel| {
1841 cumulative_left_tables.contains(&rel.source_table)
1842 && rel.target_table == right_key
1843 || (cumulative_left_tables.contains(&rel.target_table)
1844 && rel.source_table == right_key)
1845 })
1846 .collect();
1847
1848 if !relevant_relationships.is_empty() {
1849 let uses_declared_fk = eq_pairs.iter().any(|((lt, lc), (rt, rc))| {
1850 relevant_relationships
1851 .iter()
1852 .any(|rel| relationship_matches_pair(rel, lt, lc, rt, rc))
1853 });
1854 if !uses_declared_fk {
1855 errors.push(ValidationError::warning(
1856 "JOIN predicate does not use declared foreign-key relationship columns",
1857 validation_codes::W_JOIN_NOT_USING_DECLARED_REFERENCE,
1858 ));
1859 }
1860 }
1861 }
1862 }
1863
1864 if let Some(right_key) = right_table_key {
1865 cumulative_left_tables.insert(right_key);
1866 }
1867 }
1868 }
1869
1870 errors
1871}
1872
1873fn are_setop_compatible(left: TypeFamily, right: TypeFamily) -> bool {
1874 if left == TypeFamily::Unknown || right == TypeFamily::Unknown {
1875 return true;
1876 }
1877 if left == right {
1878 return true;
1879 }
1880 if left.is_numeric() && right.is_numeric() {
1881 return true;
1882 }
1883 if left.is_temporal() && right.is_temporal() {
1884 return true;
1885 }
1886 false
1887}
1888
1889fn merged_setop_family(left: TypeFamily, right: TypeFamily) -> TypeFamily {
1890 if left == TypeFamily::Unknown {
1891 return right;
1892 }
1893 if right == TypeFamily::Unknown {
1894 return left;
1895 }
1896 if left == right {
1897 return left;
1898 }
1899 if left.is_numeric() && right.is_numeric() {
1900 if left == TypeFamily::Numeric || right == TypeFamily::Numeric {
1901 return TypeFamily::Numeric;
1902 }
1903 return TypeFamily::Integer;
1904 }
1905 if left.is_temporal() && right.is_temporal() {
1906 if left == TypeFamily::Timestamp || right == TypeFamily::Timestamp {
1907 return TypeFamily::Timestamp;
1908 }
1909 if left == TypeFamily::Date || right == TypeFamily::Date {
1910 return TypeFamily::Date;
1911 }
1912 return TypeFamily::Time;
1913 }
1914 TypeFamily::Unknown
1915}
1916
1917fn are_assignment_compatible(target: TypeFamily, source: TypeFamily) -> bool {
1918 if target == TypeFamily::Unknown || source == TypeFamily::Unknown {
1919 return true;
1920 }
1921 if target == source {
1922 return true;
1923 }
1924
1925 match target {
1926 TypeFamily::Boolean => source == TypeFamily::Boolean,
1927 TypeFamily::Integer | TypeFamily::Numeric => source.is_numeric(),
1928 TypeFamily::Date | TypeFamily::Time | TypeFamily::Timestamp | TypeFamily::Interval => {
1929 source.is_temporal()
1930 }
1931 TypeFamily::String => true,
1932 TypeFamily::Binary => matches!(source, TypeFamily::Binary | TypeFamily::String),
1933 TypeFamily::Json => matches!(source, TypeFamily::Json | TypeFamily::String),
1934 TypeFamily::Uuid => matches!(source, TypeFamily::Uuid | TypeFamily::String),
1935 TypeFamily::Array => source == TypeFamily::Array,
1936 TypeFamily::Map => source == TypeFamily::Map,
1937 TypeFamily::Struct => source == TypeFamily::Struct,
1938 TypeFamily::Unknown => true,
1939 }
1940}
1941
1942fn projection_families(
1943 query_expr: &Expression,
1944 schema_map: &HashMap<String, TableSchemaEntry>,
1945) -> Option<Vec<TypeFamily>> {
1946 match query_expr {
1947 Expression::Select(select) => {
1948 if select
1949 .expressions
1950 .iter()
1951 .any(|e| matches!(e, Expression::Star(_) | Expression::BracedWildcard(_)))
1952 {
1953 return None;
1954 }
1955 let select_expr = Expression::Select(select.clone());
1956 let context = collect_type_check_context(&select_expr, schema_map);
1957 Some(
1958 select
1959 .expressions
1960 .iter()
1961 .map(|e| infer_expression_type_family(e, schema_map, &context))
1962 .collect(),
1963 )
1964 }
1965 Expression::Subquery(subquery) => projection_families(&subquery.this, schema_map),
1966 Expression::Union(union) => {
1967 let left = projection_families(&union.left, schema_map)?;
1968 let right = projection_families(&union.right, schema_map)?;
1969 if left.len() != right.len() {
1970 return None;
1971 }
1972 Some(
1973 left.into_iter()
1974 .zip(right)
1975 .map(|(l, r)| merged_setop_family(l, r))
1976 .collect(),
1977 )
1978 }
1979 Expression::Intersect(intersect) => {
1980 let left = projection_families(&intersect.left, schema_map)?;
1981 let right = projection_families(&intersect.right, schema_map)?;
1982 if left.len() != right.len() {
1983 return None;
1984 }
1985 Some(
1986 left.into_iter()
1987 .zip(right)
1988 .map(|(l, r)| merged_setop_family(l, r))
1989 .collect(),
1990 )
1991 }
1992 Expression::Except(except) => {
1993 let left = projection_families(&except.left, schema_map)?;
1994 let right = projection_families(&except.right, schema_map)?;
1995 if left.len() != right.len() {
1996 return None;
1997 }
1998 Some(
1999 left.into_iter()
2000 .zip(right)
2001 .map(|(l, r)| merged_setop_family(l, r))
2002 .collect(),
2003 )
2004 }
2005 Expression::Values(values) => {
2006 let first_row = values.expressions.first()?;
2007 let context = TypeCheckContext::default();
2008 Some(
2009 first_row
2010 .expressions
2011 .iter()
2012 .map(|e| infer_expression_type_family(e, schema_map, &context))
2013 .collect(),
2014 )
2015 }
2016 _ => None,
2017 }
2018}
2019
2020fn check_set_operation_compatibility(
2021 op_name: &str,
2022 left_expr: &Expression,
2023 right_expr: &Expression,
2024 schema_map: &HashMap<String, TableSchemaEntry>,
2025 strict: bool,
2026 errors: &mut Vec<ValidationError>,
2027) {
2028 let Some(left_projection) = projection_families(left_expr, schema_map) else {
2029 return;
2030 };
2031 let Some(right_projection) = projection_families(right_expr, schema_map) else {
2032 return;
2033 };
2034
2035 if left_projection.len() != right_projection.len() {
2036 errors.push(type_issue(
2037 strict,
2038 validation_codes::E_SETOP_ARITY_MISMATCH,
2039 validation_codes::W_SETOP_IMPLICIT_COERCION,
2040 format!(
2041 "{} operands return different column counts: left {}, right {}",
2042 op_name,
2043 left_projection.len(),
2044 right_projection.len()
2045 ),
2046 ));
2047 return;
2048 }
2049
2050 for (idx, (left, right)) in left_projection
2051 .into_iter()
2052 .zip(right_projection)
2053 .enumerate()
2054 {
2055 if !are_setop_compatible(left, right) {
2056 errors.push(type_issue(
2057 strict,
2058 validation_codes::E_SETOP_TYPE_MISMATCH,
2059 validation_codes::W_SETOP_IMPLICIT_COERCION,
2060 format!(
2061 "{} column {} has incompatible types: {} vs {}",
2062 op_name,
2063 idx + 1,
2064 type_family_name(left),
2065 type_family_name(right)
2066 ),
2067 ));
2068 }
2069 }
2070}
2071
2072fn check_insert_assignments(
2073 stmt: &Expression,
2074 insert: &Insert,
2075 schema_map: &HashMap<String, TableSchemaEntry>,
2076 strict: bool,
2077 errors: &mut Vec<ValidationError>,
2078) {
2079 let Some((target_table_key, table_schema)) =
2080 resolve_table_schema_entry(&insert.table, schema_map)
2081 else {
2082 return;
2083 };
2084
2085 let mut target_columns = Vec::new();
2086 if insert.columns.is_empty() {
2087 target_columns.extend(table_schema.column_order.iter().cloned());
2088 } else {
2089 for column in &insert.columns {
2090 let col_name = lower(&column.name);
2091 if table_schema.columns.contains_key(&col_name) {
2092 target_columns.push(col_name);
2093 } else {
2094 errors.push(if strict {
2095 ValidationError::error(
2096 format!(
2097 "Unknown column '{}' in table '{}'",
2098 column.name, target_table_key
2099 ),
2100 validation_codes::E_UNKNOWN_COLUMN,
2101 )
2102 } else {
2103 ValidationError::warning(
2104 format!(
2105 "Unknown column '{}' in table '{}'",
2106 column.name, target_table_key
2107 ),
2108 validation_codes::E_UNKNOWN_COLUMN,
2109 )
2110 });
2111 }
2112 }
2113 }
2114
2115 if target_columns.is_empty() {
2116 return;
2117 }
2118
2119 let context = collect_type_check_context(stmt, schema_map);
2120
2121 if !insert.default_values {
2122 for (row_idx, row) in insert.values.iter().enumerate() {
2123 if row.len() != target_columns.len() {
2124 errors.push(type_issue(
2125 strict,
2126 validation_codes::E_INVALID_ASSIGNMENT_TYPE,
2127 validation_codes::W_IMPLICIT_CAST_ASSIGNMENT,
2128 format!(
2129 "INSERT row {} has {} values but target has {} columns",
2130 row_idx + 1,
2131 row.len(),
2132 target_columns.len()
2133 ),
2134 ));
2135 continue;
2136 }
2137
2138 for (value, target_column) in row.iter().zip(target_columns.iter()) {
2139 let Some(target_family) = table_schema.columns.get(target_column).copied() else {
2140 continue;
2141 };
2142 let source_family = infer_expression_type_family(value, schema_map, &context);
2143 if !are_assignment_compatible(target_family, source_family) {
2144 errors.push(type_issue(
2145 strict,
2146 validation_codes::E_INVALID_ASSIGNMENT_TYPE,
2147 validation_codes::W_IMPLICIT_CAST_ASSIGNMENT,
2148 format!(
2149 "INSERT assignment type mismatch for '{}.{}': expected {}, found {}",
2150 target_table_key,
2151 target_column,
2152 type_family_name(target_family),
2153 type_family_name(source_family)
2154 ),
2155 ));
2156 }
2157 }
2158 }
2159 }
2160
2161 if let Some(query) = &insert.query {
2162 if insert.by_name {
2164 return;
2165 }
2166
2167 let Some(source_projection) = projection_families(query, schema_map) else {
2168 return;
2169 };
2170
2171 if source_projection.len() != target_columns.len() {
2172 errors.push(type_issue(
2173 strict,
2174 validation_codes::E_INVALID_ASSIGNMENT_TYPE,
2175 validation_codes::W_IMPLICIT_CAST_ASSIGNMENT,
2176 format!(
2177 "INSERT source query has {} columns but target has {} columns",
2178 source_projection.len(),
2179 target_columns.len()
2180 ),
2181 ));
2182 return;
2183 }
2184
2185 for (source_family, target_column) in
2186 source_projection.into_iter().zip(target_columns.iter())
2187 {
2188 let Some(target_family) = table_schema.columns.get(target_column).copied() else {
2189 continue;
2190 };
2191 if !are_assignment_compatible(target_family, source_family) {
2192 errors.push(type_issue(
2193 strict,
2194 validation_codes::E_INVALID_ASSIGNMENT_TYPE,
2195 validation_codes::W_IMPLICIT_CAST_ASSIGNMENT,
2196 format!(
2197 "INSERT assignment type mismatch for '{}.{}': expected {}, found {}",
2198 target_table_key,
2199 target_column,
2200 type_family_name(target_family),
2201 type_family_name(source_family)
2202 ),
2203 ));
2204 }
2205 }
2206 }
2207}
2208
2209fn check_update_assignments(
2210 stmt: &Expression,
2211 update: &Update,
2212 schema_map: &HashMap<String, TableSchemaEntry>,
2213 strict: bool,
2214 errors: &mut Vec<ValidationError>,
2215) {
2216 let Some((target_table_key, table_schema)) =
2217 resolve_table_schema_entry(&update.table, schema_map)
2218 else {
2219 return;
2220 };
2221
2222 let context = collect_type_check_context(stmt, schema_map);
2223
2224 for (column, value) in &update.set {
2225 let col_name = lower(&column.name);
2226 let Some(target_family) = table_schema.columns.get(&col_name).copied() else {
2227 errors.push(if strict {
2228 ValidationError::error(
2229 format!(
2230 "Unknown column '{}' in table '{}'",
2231 column.name, target_table_key
2232 ),
2233 validation_codes::E_UNKNOWN_COLUMN,
2234 )
2235 } else {
2236 ValidationError::warning(
2237 format!(
2238 "Unknown column '{}' in table '{}'",
2239 column.name, target_table_key
2240 ),
2241 validation_codes::E_UNKNOWN_COLUMN,
2242 )
2243 });
2244 continue;
2245 };
2246
2247 let source_family = infer_expression_type_family(value, schema_map, &context);
2248 if !are_assignment_compatible(target_family, source_family) {
2249 errors.push(type_issue(
2250 strict,
2251 validation_codes::E_INVALID_ASSIGNMENT_TYPE,
2252 validation_codes::W_IMPLICIT_CAST_ASSIGNMENT,
2253 format!(
2254 "UPDATE assignment type mismatch for '{}.{}': expected {}, found {}",
2255 target_table_key,
2256 col_name,
2257 type_family_name(target_family),
2258 type_family_name(source_family)
2259 ),
2260 ));
2261 }
2262 }
2263}
2264
2265fn check_types(
2266 stmt: &Expression,
2267 schema_map: &HashMap<String, TableSchemaEntry>,
2268 strict: bool,
2269) -> Vec<ValidationError> {
2270 let mut errors = Vec::new();
2271 let context = collect_type_check_context(stmt, schema_map);
2272
2273 for node in stmt.dfs() {
2274 match node {
2275 Expression::Insert(insert) => {
2276 check_insert_assignments(stmt, insert, schema_map, strict, &mut errors);
2277 }
2278 Expression::Update(update) => {
2279 check_update_assignments(stmt, update, schema_map, strict, &mut errors);
2280 }
2281 Expression::Union(union) => {
2282 check_set_operation_compatibility(
2283 "UNION",
2284 &union.left,
2285 &union.right,
2286 schema_map,
2287 strict,
2288 &mut errors,
2289 );
2290 }
2291 Expression::Intersect(intersect) => {
2292 check_set_operation_compatibility(
2293 "INTERSECT",
2294 &intersect.left,
2295 &intersect.right,
2296 schema_map,
2297 strict,
2298 &mut errors,
2299 );
2300 }
2301 Expression::Except(except) => {
2302 check_set_operation_compatibility(
2303 "EXCEPT",
2304 &except.left,
2305 &except.right,
2306 schema_map,
2307 strict,
2308 &mut errors,
2309 );
2310 }
2311 Expression::Select(select) => {
2312 if let Some(prewhere) = &select.prewhere {
2313 let family = infer_expression_type_family(prewhere, schema_map, &context);
2314 if family != TypeFamily::Unknown && family != TypeFamily::Boolean {
2315 errors.push(type_issue(
2316 strict,
2317 validation_codes::E_INVALID_PREDICATE_TYPE,
2318 validation_codes::W_PREDICATE_NULLABILITY,
2319 format!(
2320 "PREWHERE clause expects a boolean predicate, found {}",
2321 type_family_name(family)
2322 ),
2323 ));
2324 }
2325 }
2326
2327 if let Some(where_clause) = &select.where_clause {
2328 let family =
2329 infer_expression_type_family(&where_clause.this, schema_map, &context);
2330 if family != TypeFamily::Unknown && family != TypeFamily::Boolean {
2331 errors.push(type_issue(
2332 strict,
2333 validation_codes::E_INVALID_PREDICATE_TYPE,
2334 validation_codes::W_PREDICATE_NULLABILITY,
2335 format!(
2336 "WHERE clause expects a boolean predicate, found {}",
2337 type_family_name(family)
2338 ),
2339 ));
2340 }
2341 }
2342
2343 if let Some(having_clause) = &select.having {
2344 let family =
2345 infer_expression_type_family(&having_clause.this, schema_map, &context);
2346 if family != TypeFamily::Unknown && family != TypeFamily::Boolean {
2347 errors.push(type_issue(
2348 strict,
2349 validation_codes::E_INVALID_PREDICATE_TYPE,
2350 validation_codes::W_PREDICATE_NULLABILITY,
2351 format!(
2352 "HAVING clause expects a boolean predicate, found {}",
2353 type_family_name(family)
2354 ),
2355 ));
2356 }
2357 }
2358
2359 for join in &select.joins {
2360 if let Some(on) = &join.on {
2361 let family = infer_expression_type_family(on, schema_map, &context);
2362 if family != TypeFamily::Unknown && family != TypeFamily::Boolean {
2363 errors.push(type_issue(
2364 strict,
2365 validation_codes::E_INVALID_PREDICATE_TYPE,
2366 validation_codes::W_PREDICATE_NULLABILITY,
2367 format!(
2368 "JOIN ON expects a boolean predicate, found {}",
2369 type_family_name(family)
2370 ),
2371 ));
2372 }
2373 }
2374 if let Some(match_condition) = &join.match_condition {
2375 let family =
2376 infer_expression_type_family(match_condition, schema_map, &context);
2377 if family != TypeFamily::Unknown && family != TypeFamily::Boolean {
2378 errors.push(type_issue(
2379 strict,
2380 validation_codes::E_INVALID_PREDICATE_TYPE,
2381 validation_codes::W_PREDICATE_NULLABILITY,
2382 format!(
2383 "JOIN MATCH_CONDITION expects a boolean predicate, found {}",
2384 type_family_name(family)
2385 ),
2386 ));
2387 }
2388 }
2389 }
2390 }
2391 Expression::Where(where_clause) => {
2392 let family = infer_expression_type_family(&where_clause.this, schema_map, &context);
2393 if family != TypeFamily::Unknown && family != TypeFamily::Boolean {
2394 errors.push(type_issue(
2395 strict,
2396 validation_codes::E_INVALID_PREDICATE_TYPE,
2397 validation_codes::W_PREDICATE_NULLABILITY,
2398 format!(
2399 "WHERE clause expects a boolean predicate, found {}",
2400 type_family_name(family)
2401 ),
2402 ));
2403 }
2404 }
2405 Expression::Having(having_clause) => {
2406 let family =
2407 infer_expression_type_family(&having_clause.this, schema_map, &context);
2408 if family != TypeFamily::Unknown && family != TypeFamily::Boolean {
2409 errors.push(type_issue(
2410 strict,
2411 validation_codes::E_INVALID_PREDICATE_TYPE,
2412 validation_codes::W_PREDICATE_NULLABILITY,
2413 format!(
2414 "HAVING clause expects a boolean predicate, found {}",
2415 type_family_name(family)
2416 ),
2417 ));
2418 }
2419 }
2420 Expression::And(op) | Expression::Or(op) => {
2421 for (side, expr) in [("left", &op.left), ("right", &op.right)] {
2422 let family = infer_expression_type_family(expr, schema_map, &context);
2423 if family != TypeFamily::Unknown && family != TypeFamily::Boolean {
2424 errors.push(type_issue(
2425 strict,
2426 validation_codes::E_INVALID_PREDICATE_TYPE,
2427 validation_codes::W_PREDICATE_NULLABILITY,
2428 format!(
2429 "Logical {} operand expects boolean, found {}",
2430 side,
2431 type_family_name(family)
2432 ),
2433 ));
2434 }
2435 }
2436 }
2437 Expression::Not(unary) => {
2438 let family = infer_expression_type_family(&unary.this, schema_map, &context);
2439 if family != TypeFamily::Unknown && family != TypeFamily::Boolean {
2440 errors.push(type_issue(
2441 strict,
2442 validation_codes::E_INVALID_PREDICATE_TYPE,
2443 validation_codes::W_PREDICATE_NULLABILITY,
2444 format!("NOT expects boolean, found {}", type_family_name(family)),
2445 ));
2446 }
2447 }
2448 Expression::Eq(op)
2449 | Expression::Neq(op)
2450 | Expression::Lt(op)
2451 | Expression::Lte(op)
2452 | Expression::Gt(op)
2453 | Expression::Gte(op) => {
2454 let left = infer_expression_type_family(&op.left, schema_map, &context);
2455 let right = infer_expression_type_family(&op.right, schema_map, &context);
2456 if !are_comparable(left, right) {
2457 errors.push(type_issue(
2458 strict,
2459 validation_codes::E_INCOMPATIBLE_COMPARISON_TYPES,
2460 validation_codes::W_IMPLICIT_CAST_COMPARISON,
2461 format!(
2462 "Incompatible comparison between {} and {}",
2463 type_family_name(left),
2464 type_family_name(right)
2465 ),
2466 ));
2467 }
2468 }
2469 Expression::Like(op) | Expression::ILike(op) => {
2470 let left = infer_expression_type_family(&op.left, schema_map, &context);
2471 let right = infer_expression_type_family(&op.right, schema_map, &context);
2472 if left != TypeFamily::Unknown
2473 && right != TypeFamily::Unknown
2474 && (!is_string_like(left) || !is_string_like(right))
2475 {
2476 errors.push(type_issue(
2477 strict,
2478 validation_codes::E_INCOMPATIBLE_COMPARISON_TYPES,
2479 validation_codes::W_IMPLICIT_CAST_COMPARISON,
2480 format!(
2481 "LIKE/ILIKE expects string operands, found {} and {}",
2482 type_family_name(left),
2483 type_family_name(right)
2484 ),
2485 ));
2486 }
2487 }
2488 Expression::Between(between) => {
2489 let this_family = infer_expression_type_family(&between.this, schema_map, &context);
2490 let low_family = infer_expression_type_family(&between.low, schema_map, &context);
2491 let high_family = infer_expression_type_family(&between.high, schema_map, &context);
2492
2493 if !are_comparable(this_family, low_family)
2494 || !are_comparable(this_family, high_family)
2495 {
2496 errors.push(type_issue(
2497 strict,
2498 validation_codes::E_INCOMPATIBLE_COMPARISON_TYPES,
2499 validation_codes::W_IMPLICIT_CAST_COMPARISON,
2500 format!(
2501 "BETWEEN bounds are incompatible with {} (found {} and {})",
2502 type_family_name(this_family),
2503 type_family_name(low_family),
2504 type_family_name(high_family)
2505 ),
2506 ));
2507 }
2508 }
2509 Expression::In(in_expr) => {
2510 let this_family = infer_expression_type_family(&in_expr.this, schema_map, &context);
2511 for value in &in_expr.expressions {
2512 let item_family = infer_expression_type_family(value, schema_map, &context);
2513 if !are_comparable(this_family, item_family) {
2514 errors.push(type_issue(
2515 strict,
2516 validation_codes::E_INCOMPATIBLE_COMPARISON_TYPES,
2517 validation_codes::W_IMPLICIT_CAST_COMPARISON,
2518 format!(
2519 "IN item type {} is incompatible with {}",
2520 type_family_name(item_family),
2521 type_family_name(this_family)
2522 ),
2523 ));
2524 break;
2525 }
2526 }
2527 }
2528 Expression::Add(op)
2529 | Expression::Sub(op)
2530 | Expression::Mul(op)
2531 | Expression::Div(op)
2532 | Expression::Mod(op) => {
2533 let left = infer_expression_type_family(&op.left, schema_map, &context);
2534 let right = infer_expression_type_family(&op.right, schema_map, &context);
2535
2536 if left == TypeFamily::Unknown || right == TypeFamily::Unknown {
2537 continue;
2538 }
2539
2540 let temporal_ok = matches!(node, Expression::Add(_) | Expression::Sub(_))
2541 && ((left.is_temporal() && right.is_numeric())
2542 || (right.is_temporal() && left.is_numeric())
2543 || (matches!(node, Expression::Sub(_))
2544 && left.is_temporal()
2545 && right.is_temporal()));
2546
2547 if !(left.is_numeric() && right.is_numeric()) && !temporal_ok {
2548 errors.push(type_issue(
2549 strict,
2550 validation_codes::E_INVALID_ARITHMETIC_TYPE,
2551 validation_codes::W_IMPLICIT_CAST_ARITHMETIC,
2552 format!(
2553 "Arithmetic operation expects numeric-compatible operands, found {} and {}",
2554 type_family_name(left),
2555 type_family_name(right)
2556 ),
2557 ));
2558 }
2559 }
2560 Expression::Function(function) => {
2561 check_generic_function(function, schema_map, &context, strict, &mut errors);
2562 }
2563 Expression::Upper(func)
2564 | Expression::Lower(func)
2565 | Expression::LTrim(func)
2566 | Expression::RTrim(func)
2567 | Expression::Reverse(func) => {
2568 let family = infer_expression_type_family(&func.this, schema_map, &context);
2569 check_function_argument(
2570 &mut errors,
2571 strict,
2572 "string_function",
2573 0,
2574 family,
2575 "a string argument",
2576 is_string_like(family),
2577 );
2578 }
2579 Expression::Length(func) => {
2580 let family = infer_expression_type_family(&func.this, schema_map, &context);
2581 check_function_argument(
2582 &mut errors,
2583 strict,
2584 "length",
2585 0,
2586 family,
2587 "a string or binary argument",
2588 is_string_or_binary(family),
2589 );
2590 }
2591 Expression::Trim(func) => {
2592 let this_family = infer_expression_type_family(&func.this, schema_map, &context);
2593 check_function_argument(
2594 &mut errors,
2595 strict,
2596 "trim",
2597 0,
2598 this_family,
2599 "a string argument",
2600 is_string_like(this_family),
2601 );
2602 if let Some(chars) = &func.characters {
2603 let chars_family = infer_expression_type_family(chars, schema_map, &context);
2604 check_function_argument(
2605 &mut errors,
2606 strict,
2607 "trim",
2608 1,
2609 chars_family,
2610 "a string argument",
2611 is_string_like(chars_family),
2612 );
2613 }
2614 }
2615 Expression::Substring(func) => {
2616 let this_family = infer_expression_type_family(&func.this, schema_map, &context);
2617 check_function_argument(
2618 &mut errors,
2619 strict,
2620 "substring",
2621 0,
2622 this_family,
2623 "a string argument",
2624 is_string_like(this_family),
2625 );
2626
2627 let start_family = infer_expression_type_family(&func.start, schema_map, &context);
2628 check_function_argument(
2629 &mut errors,
2630 strict,
2631 "substring",
2632 1,
2633 start_family,
2634 "a numeric argument",
2635 start_family.is_numeric(),
2636 );
2637 if let Some(length) = &func.length {
2638 let length_family = infer_expression_type_family(length, schema_map, &context);
2639 check_function_argument(
2640 &mut errors,
2641 strict,
2642 "substring",
2643 2,
2644 length_family,
2645 "a numeric argument",
2646 length_family.is_numeric(),
2647 );
2648 }
2649 }
2650 Expression::Replace(func) => {
2651 for (arg, idx) in [
2652 (&func.this, 0_usize),
2653 (&func.old, 1_usize),
2654 (&func.new, 2_usize),
2655 ] {
2656 let family = infer_expression_type_family(arg, schema_map, &context);
2657 check_function_argument(
2658 &mut errors,
2659 strict,
2660 "replace",
2661 idx,
2662 family,
2663 "a string argument",
2664 is_string_like(family),
2665 );
2666 }
2667 }
2668 Expression::Left(func) | Expression::Right(func) => {
2669 let this_family = infer_expression_type_family(&func.this, schema_map, &context);
2670 check_function_argument(
2671 &mut errors,
2672 strict,
2673 "left_right",
2674 0,
2675 this_family,
2676 "a string argument",
2677 is_string_like(this_family),
2678 );
2679 let length_family =
2680 infer_expression_type_family(&func.length, schema_map, &context);
2681 check_function_argument(
2682 &mut errors,
2683 strict,
2684 "left_right",
2685 1,
2686 length_family,
2687 "a numeric argument",
2688 length_family.is_numeric(),
2689 );
2690 }
2691 Expression::Repeat(func) => {
2692 let this_family = infer_expression_type_family(&func.this, schema_map, &context);
2693 check_function_argument(
2694 &mut errors,
2695 strict,
2696 "repeat",
2697 0,
2698 this_family,
2699 "a string argument",
2700 is_string_like(this_family),
2701 );
2702 let times_family = infer_expression_type_family(&func.times, schema_map, &context);
2703 check_function_argument(
2704 &mut errors,
2705 strict,
2706 "repeat",
2707 1,
2708 times_family,
2709 "a numeric argument",
2710 times_family.is_numeric(),
2711 );
2712 }
2713 Expression::Lpad(func) | Expression::Rpad(func) => {
2714 let this_family = infer_expression_type_family(&func.this, schema_map, &context);
2715 check_function_argument(
2716 &mut errors,
2717 strict,
2718 "pad",
2719 0,
2720 this_family,
2721 "a string argument",
2722 is_string_like(this_family),
2723 );
2724 let length_family =
2725 infer_expression_type_family(&func.length, schema_map, &context);
2726 check_function_argument(
2727 &mut errors,
2728 strict,
2729 "pad",
2730 1,
2731 length_family,
2732 "a numeric argument",
2733 length_family.is_numeric(),
2734 );
2735 if let Some(fill) = &func.fill {
2736 let fill_family = infer_expression_type_family(fill, schema_map, &context);
2737 check_function_argument(
2738 &mut errors,
2739 strict,
2740 "pad",
2741 2,
2742 fill_family,
2743 "a string argument",
2744 is_string_like(fill_family),
2745 );
2746 }
2747 }
2748 Expression::Abs(func)
2749 | Expression::Sqrt(func)
2750 | Expression::Cbrt(func)
2751 | Expression::Ln(func)
2752 | Expression::Exp(func) => {
2753 let family = infer_expression_type_family(&func.this, schema_map, &context);
2754 check_function_argument(
2755 &mut errors,
2756 strict,
2757 "numeric_function",
2758 0,
2759 family,
2760 "a numeric argument",
2761 family.is_numeric(),
2762 );
2763 }
2764 Expression::Round(func) => {
2765 let this_family = infer_expression_type_family(&func.this, schema_map, &context);
2766 check_function_argument(
2767 &mut errors,
2768 strict,
2769 "round",
2770 0,
2771 this_family,
2772 "a numeric argument",
2773 this_family.is_numeric(),
2774 );
2775 if let Some(decimals) = &func.decimals {
2776 let decimals_family =
2777 infer_expression_type_family(decimals, schema_map, &context);
2778 check_function_argument(
2779 &mut errors,
2780 strict,
2781 "round",
2782 1,
2783 decimals_family,
2784 "a numeric argument",
2785 decimals_family.is_numeric(),
2786 );
2787 }
2788 }
2789 Expression::Floor(func) => {
2790 let this_family = infer_expression_type_family(&func.this, schema_map, &context);
2791 check_function_argument(
2792 &mut errors,
2793 strict,
2794 "floor",
2795 0,
2796 this_family,
2797 "a numeric argument",
2798 this_family.is_numeric(),
2799 );
2800 if let Some(scale) = &func.scale {
2801 let scale_family = infer_expression_type_family(scale, schema_map, &context);
2802 check_function_argument(
2803 &mut errors,
2804 strict,
2805 "floor",
2806 1,
2807 scale_family,
2808 "a numeric argument",
2809 scale_family.is_numeric(),
2810 );
2811 }
2812 }
2813 Expression::Ceil(func) => {
2814 let this_family = infer_expression_type_family(&func.this, schema_map, &context);
2815 check_function_argument(
2816 &mut errors,
2817 strict,
2818 "ceil",
2819 0,
2820 this_family,
2821 "a numeric argument",
2822 this_family.is_numeric(),
2823 );
2824 if let Some(decimals) = &func.decimals {
2825 let decimals_family =
2826 infer_expression_type_family(decimals, schema_map, &context);
2827 check_function_argument(
2828 &mut errors,
2829 strict,
2830 "ceil",
2831 1,
2832 decimals_family,
2833 "a numeric argument",
2834 decimals_family.is_numeric(),
2835 );
2836 }
2837 }
2838 Expression::Power(func) => {
2839 let left_family = infer_expression_type_family(&func.this, schema_map, &context);
2840 check_function_argument(
2841 &mut errors,
2842 strict,
2843 "power",
2844 0,
2845 left_family,
2846 "a numeric argument",
2847 left_family.is_numeric(),
2848 );
2849 let right_family =
2850 infer_expression_type_family(&func.expression, schema_map, &context);
2851 check_function_argument(
2852 &mut errors,
2853 strict,
2854 "power",
2855 1,
2856 right_family,
2857 "a numeric argument",
2858 right_family.is_numeric(),
2859 );
2860 }
2861 Expression::Log(func) => {
2862 let this_family = infer_expression_type_family(&func.this, schema_map, &context);
2863 check_function_argument(
2864 &mut errors,
2865 strict,
2866 "log",
2867 0,
2868 this_family,
2869 "a numeric argument",
2870 this_family.is_numeric(),
2871 );
2872 if let Some(base) = &func.base {
2873 let base_family = infer_expression_type_family(base, schema_map, &context);
2874 check_function_argument(
2875 &mut errors,
2876 strict,
2877 "log",
2878 1,
2879 base_family,
2880 "a numeric argument",
2881 base_family.is_numeric(),
2882 );
2883 }
2884 }
2885 _ => {}
2886 }
2887 }
2888
2889 errors
2890}
2891
2892fn check_semantics(stmt: &Expression) -> Vec<ValidationError> {
2893 let mut errors = Vec::new();
2894
2895 let Expression::Select(select) = stmt else {
2896 return errors;
2897 };
2898 let select_expr = Expression::Select(select.clone());
2899
2900 if !select_expr
2902 .find_all(|e| matches!(e, Expression::Star(_)))
2903 .is_empty()
2904 {
2905 errors.push(ValidationError::warning(
2906 "SELECT * is discouraged; specify columns explicitly for better performance and maintainability",
2907 validation_codes::W_SELECT_STAR,
2908 ));
2909 }
2910
2911 let aggregate_count = get_aggregate_functions(&select_expr).len();
2913 if aggregate_count > 0 && select.group_by.is_none() {
2914 let has_non_aggregate_column = select.expressions.iter().any(|expr| {
2915 matches!(expr, Expression::Column(_) | Expression::Identifier(_))
2916 && get_aggregate_functions(expr).is_empty()
2917 });
2918
2919 if has_non_aggregate_column {
2920 errors.push(ValidationError::warning(
2921 "Mixing aggregate functions with non-aggregated columns without GROUP BY may cause errors in strict SQL mode",
2922 validation_codes::W_AGGREGATE_WITHOUT_GROUP_BY,
2923 ));
2924 }
2925 }
2926
2927 if select.distinct && select.order_by.is_some() {
2929 errors.push(ValidationError::warning(
2930 "DISTINCT with ORDER BY: ensure ORDER BY columns are in SELECT list",
2931 validation_codes::W_DISTINCT_ORDER_BY,
2932 ));
2933 }
2934
2935 if select.limit.is_some() && select.order_by.is_none() {
2937 errors.push(ValidationError::warning(
2938 "LIMIT without ORDER BY produces non-deterministic results",
2939 validation_codes::W_LIMIT_WITHOUT_ORDER_BY,
2940 ));
2941 }
2942
2943 errors
2944}
2945
2946fn validate_statement_with_schema(
2947 stmt: &Expression,
2948 schema_map: &HashMap<String, TableSchemaEntry>,
2949 strict: bool,
2950) -> Vec<ValidationError> {
2951 let mut errors = Vec::new();
2952 let cte_aliases = collect_cte_aliases(stmt);
2953
2954 let mut referenced_tables: HashSet<String> = HashSet::new();
2955 let mut table_aliases: HashMap<String, String> = HashMap::new();
2956 let mut seen_tables: HashSet<String> = HashSet::new();
2957
2958 for node in stmt.find_all(|e| matches!(e, Expression::Table(_))) {
2960 let Expression::Table(table) = node else {
2961 continue;
2962 };
2963
2964 if cte_aliases.contains(&lower(&table.name.name)) {
2965 continue;
2966 }
2967
2968 let resolved_key = table_ref_candidates(table)
2969 .into_iter()
2970 .find(|k| schema_map.contains_key(k));
2971 let table_key = resolved_key
2972 .clone()
2973 .unwrap_or_else(|| lower(&table_ref_display_name(table)));
2974
2975 if let Some(alias) = &table.alias {
2977 table_aliases.insert(lower(&alias.name), table_key.clone());
2978 }
2979
2980 if !seen_tables.insert(table_key.clone()) {
2981 continue;
2982 }
2983 referenced_tables.insert(table_key.clone());
2984
2985 if resolved_key.is_none() {
2986 let err = if strict {
2987 ValidationError::error(
2988 format!("Unknown table '{}'", table_ref_display_name(table)),
2989 validation_codes::E_UNKNOWN_TABLE,
2990 )
2991 } else {
2992 ValidationError::warning(
2993 format!("Unknown table '{}'", table_ref_display_name(table)),
2994 validation_codes::E_UNKNOWN_TABLE,
2995 )
2996 };
2997 errors.push(err);
2998 }
2999 }
3000
3001 for node in stmt.find_all(|e| matches!(e, Expression::Column(_))) {
3003 let Expression::Column(column) = node else {
3004 continue;
3005 };
3006
3007 let col_name = lower(&column.name.name);
3008 if col_name.is_empty() {
3009 continue;
3010 }
3011
3012 let mut table_name = column.table.as_ref().map(|t| lower(&t.name));
3013 if let Some(name) = &table_name {
3014 if let Some(mapped) = table_aliases.get(name) {
3015 table_name = Some(mapped.clone());
3016 }
3017 }
3018
3019 if let Some(table_name) = table_name {
3020 if let Some(table_schema) = schema_map.get(&table_name) {
3021 if !table_schema.columns.contains_key(&col_name) {
3022 let err = if strict {
3023 ValidationError::error(
3024 format!("Unknown column '{}' in table '{}'", col_name, table_name),
3025 validation_codes::E_UNKNOWN_COLUMN,
3026 )
3027 } else {
3028 ValidationError::warning(
3029 format!("Unknown column '{}' in table '{}'", col_name, table_name),
3030 validation_codes::E_UNKNOWN_COLUMN,
3031 )
3032 };
3033 errors.push(err);
3034 }
3035 }
3036 continue;
3037 }
3038
3039 if referenced_tables.len() == 1 {
3040 if let Some(single_table) = referenced_tables.iter().next() {
3041 if let Some(table_schema) = schema_map.get(single_table) {
3042 if !table_schema.columns.contains_key(&col_name) {
3043 let err = if strict {
3044 ValidationError::error(
3045 format!(
3046 "Unknown column '{}' in table '{}'",
3047 col_name, single_table
3048 ),
3049 validation_codes::E_UNKNOWN_COLUMN,
3050 )
3051 } else {
3052 ValidationError::warning(
3053 format!(
3054 "Unknown column '{}' in table '{}'",
3055 col_name, single_table
3056 ),
3057 validation_codes::E_UNKNOWN_COLUMN,
3058 )
3059 };
3060 errors.push(err);
3061 }
3062 }
3063 }
3064 } else if referenced_tables.len() > 1 {
3065 let found = referenced_tables.iter().any(|table_name| {
3066 schema_map
3067 .get(table_name)
3068 .map(|s| s.columns.contains_key(&col_name))
3069 .unwrap_or(false)
3070 });
3071
3072 if !found {
3073 let err = if strict {
3074 ValidationError::error(
3075 format!(
3076 "Unknown column '{}' (not found in any referenced table)",
3077 col_name
3078 ),
3079 validation_codes::E_UNKNOWN_COLUMN,
3080 )
3081 } else {
3082 ValidationError::warning(
3083 format!(
3084 "Unknown column '{}' (not found in any referenced table)",
3085 col_name
3086 ),
3087 validation_codes::E_UNKNOWN_COLUMN,
3088 )
3089 };
3090 errors.push(err);
3091 }
3092 } else if !schema_map.is_empty() {
3093 let found = schema_map
3094 .values()
3095 .any(|table| table.columns.contains_key(&col_name));
3096 if !found {
3097 let err = if strict {
3098 ValidationError::error(
3099 format!("Unknown column '{}'", col_name),
3100 validation_codes::E_UNKNOWN_COLUMN,
3101 )
3102 } else {
3103 ValidationError::warning(
3104 format!("Unknown column '{}'", col_name),
3105 validation_codes::E_UNKNOWN_COLUMN,
3106 )
3107 };
3108 errors.push(err);
3109 }
3110 }
3111 }
3112
3113 errors
3114}
3115
3116pub fn validate_with_schema(
3118 sql: &str,
3119 dialect: DialectType,
3120 schema: &ValidationSchema,
3121 options: &SchemaValidationOptions,
3122) -> ValidationResult {
3123 let strict = options.strict.unwrap_or(schema.strict.unwrap_or(true));
3124
3125 let syntax_result = crate::validate(sql, dialect);
3127 if !syntax_result.valid {
3128 return syntax_result;
3129 }
3130
3131 let d = Dialect::get(dialect);
3132 let statements = match d.parse(sql) {
3133 Ok(exprs) => exprs,
3134 Err(e) => {
3135 return ValidationResult::with_errors(vec![ValidationError::error(
3136 e.to_string(),
3137 validation_codes::E_PARSE_OR_OPTIONS,
3138 )]);
3139 }
3140 };
3141
3142 let schema_map = build_schema_map(schema);
3143 let resolver_schema = build_resolver_schema(schema);
3144 let mut all_errors = syntax_result.errors;
3145 let declared_relationships = if options.check_references {
3146 build_declared_relationships(schema, &schema_map)
3147 } else {
3148 Vec::new()
3149 };
3150
3151 if options.check_references {
3152 all_errors.extend(check_reference_integrity(schema, &schema_map, strict));
3153 }
3154
3155 for stmt in &statements {
3156 if options.semantic {
3157 all_errors.extend(check_semantics(stmt));
3158 }
3159 all_errors.extend(validate_statement_with_schema(stmt, &schema_map, strict));
3160 if options.check_types {
3161 all_errors.extend(check_types(stmt, &schema_map, strict));
3162 }
3163 if options.check_references {
3164 all_errors.extend(check_query_reference_quality(
3165 stmt,
3166 &schema_map,
3167 &resolver_schema,
3168 strict,
3169 &declared_relationships,
3170 ));
3171 }
3172 }
3173
3174 ValidationResult::with_errors(all_errors)
3175}
3176
3177#[cfg(test)]
3178mod tests;