Skip to main content

polyglot_sql/
validation.rs

1//! Schema-aware and semantic SQL validation.
2//!
3//! This module extends syntax validation with:
4//! - schema checks (unknown tables/columns)
5//! - optional semantic warnings (SELECT *, LIMIT without ORDER BY, etc.)
6
7use crate::ast_transforms::get_aggregate_functions;
8use crate::dialects::{Dialect, DialectType};
9use crate::error::{ValidationError, ValidationResult};
10use crate::expressions::{
11    Column, DataType, Expression, Function, Insert, JoinKind, TableRef, Update,
12};
13use crate::function_registry::canonical_typed_function_name_upper;
14use crate::optimizer::annotate_types;
15use crate::resolver::Resolver;
16use crate::schema::{MappingSchema, Schema as SqlSchema, SchemaError, SchemaResult, TABLE_PARTS};
17use crate::scope::build_scope;
18use crate::traversal::ExpressionWalk;
19use serde::{Deserialize, Serialize};
20use std::collections::{HashMap, HashSet};
21
/// Column definition used for schema-aware validation.
///
/// Only `name` is required when deserializing; every other field falls back
/// to its `Default` value when absent from the payload.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SchemaColumn {
    /// Column name.
    pub name: String,
    /// Column data type as written in the schema (serialized as `type`).
    ///
    /// The string is classified by `canonical_type_family` when lookup maps
    /// are built; an omitted type defaults to the empty string.
    #[serde(default, rename = "type")]
    pub data_type: String,
    /// Whether the column allows NULL values (`None` when not declared).
    #[serde(default)]
    pub nullable: Option<bool>,
    /// Whether this column is part of a primary key (serialized as `primaryKey`).
    #[serde(default, rename = "primaryKey")]
    pub primary_key: bool,
    /// Whether this column has a uniqueness constraint.
    #[serde(default)]
    pub unique: bool,
    /// Optional column-level foreign key reference.
    #[serde(default)]
    pub references: Option<SchemaColumnReference>,
}
43
/// Column-level foreign key reference metadata.
///
/// Identifies the single target column that a `SchemaColumn` points to.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SchemaColumnReference {
    /// Referenced table name.
    pub table: String,
    /// Referenced column name.
    pub column: String,
    /// Optional schema/namespace of referenced table (`None` when omitted).
    #[serde(default)]
    pub schema: Option<String>,
}
55
/// Table-level foreign key reference metadata.
///
/// Pairs a list of source columns with a `SchemaTableReference` naming the
/// target table and its columns.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SchemaForeignKey {
    /// Optional FK name.
    #[serde(default)]
    pub name: Option<String>,
    /// Source columns in the current table.
    pub columns: Vec<String>,
    /// Referenced target table + columns.
    pub references: SchemaTableReference,
}
67
/// Target of a table-level foreign key.
///
/// Used as the `references` field of `SchemaForeignKey`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SchemaTableReference {
    /// Referenced table name.
    pub table: String,
    /// Referenced target columns.
    pub columns: Vec<String>,
    /// Optional schema/namespace of referenced table (`None` when omitted).
    #[serde(default)]
    pub schema: Option<String>,
}
79
/// Table definition used for schema-aware validation.
///
/// `name` and `columns` are required; the remaining fields default to
/// `None` / empty lists when absent from the payload.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SchemaTable {
    /// Table name.
    pub name: String,
    /// Optional schema/namespace name. When present the table is also
    /// registered under the qualified key `schema.name`.
    #[serde(default)]
    pub schema: Option<String>,
    /// Column definitions.
    pub columns: Vec<SchemaColumn>,
    /// Optional aliases that should resolve to this table.
    #[serde(default)]
    pub aliases: Vec<String>,
    /// Optional primary key column list (serialized as `primaryKey`).
    #[serde(default, rename = "primaryKey")]
    pub primary_key: Vec<String>,
    /// Optional unique key groups (serialized as `uniqueKeys`); each inner
    /// list is one group of columns.
    #[serde(default, rename = "uniqueKeys")]
    pub unique_keys: Vec<Vec<String>>,
    /// Optional table-level foreign keys (serialized as `foreignKeys`).
    #[serde(default, rename = "foreignKeys")]
    pub foreign_keys: Vec<SchemaForeignKey>,
}
103
/// Schema payload used for schema-aware validation.
///
/// This is the top-level value from which the lowercase lookup maps used by
/// the checks in this module are built.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ValidationSchema {
    /// Known tables.
    pub tables: Vec<SchemaTable>,
    /// Default strict mode for unknown identifiers (`None` when not set).
    #[serde(default)]
    pub strict: Option<bool>,
}
113
/// Options for schema-aware validation.
///
/// Derives `Default`, so an options value built with `Default::default()`
/// has every boolean check disabled and `strict` unset.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct SchemaValidationOptions {
    /// Enables type compatibility checks for expressions, DML assignments, and set operations.
    #[serde(default)]
    pub check_types: bool,
    /// Enables FK/reference integrity checks and query-level reference quality checks.
    #[serde(default)]
    pub check_references: bool,
    /// If true/false, overrides schema.strict.
    #[serde(default)]
    pub strict: Option<bool>,
    /// Enables semantic warnings (W001..W004).
    #[serde(default)]
    pub semantic: bool,
    /// Enables strict syntax checks (e.g. rejects trailing commas before clause boundaries).
    #[serde(default)]
    pub strict_syntax: bool,
}
133
/// Validation error/warning codes used by schema-aware validation.
///
/// Constants prefixed `E_` hold `E…` codes and constants prefixed `W_` hold
/// `W…` codes; the numeric ranges are grouped by check family as noted in
/// the section comments below.
pub mod validation_codes {
    // Existing schema and semantic checks.
    pub const E_PARSE_OR_OPTIONS: &str = "E000";
    pub const E_UNKNOWN_TABLE: &str = "E200";
    pub const E_UNKNOWN_COLUMN: &str = "E201";

    // Semantic warnings (W001..W004).
    pub const W_SELECT_STAR: &str = "W001";
    pub const W_AGGREGATE_WITHOUT_GROUP_BY: &str = "W002";
    pub const W_DISTINCT_ORDER_BY: &str = "W003";
    pub const W_LIMIT_WITHOUT_ORDER_BY: &str = "W004";

    // Phase 2 (type checks): E210-E219, W210-W219.
    pub const E_TYPE_MISMATCH: &str = "E210";
    pub const E_INVALID_PREDICATE_TYPE: &str = "E211";
    pub const E_INVALID_ARITHMETIC_TYPE: &str = "E212";
    pub const E_INVALID_FUNCTION_ARGUMENT_TYPE: &str = "E213";
    pub const E_INVALID_ASSIGNMENT_TYPE: &str = "E214";
    pub const E_SETOP_TYPE_MISMATCH: &str = "E215";
    pub const E_SETOP_ARITY_MISMATCH: &str = "E216";
    pub const E_INCOMPATIBLE_COMPARISON_TYPES: &str = "E217";
    pub const E_INVALID_CAST: &str = "E218";
    pub const E_UNKNOWN_INFERRED_TYPE: &str = "E219";

    pub const W_IMPLICIT_CAST_COMPARISON: &str = "W210";
    pub const W_IMPLICIT_CAST_ARITHMETIC: &str = "W211";
    pub const W_IMPLICIT_CAST_ASSIGNMENT: &str = "W212";
    pub const W_LOSSY_CAST: &str = "W213";
    pub const W_SETOP_IMPLICIT_COERCION: &str = "W214";
    pub const W_PREDICATE_NULLABILITY: &str = "W215";
    pub const W_FUNCTION_ARGUMENT_COERCION: &str = "W216";
    pub const W_AGGREGATE_TYPE_COERCION: &str = "W217";
    pub const W_POSSIBLE_OVERFLOW: &str = "W218";
    pub const W_POSSIBLE_TRUNCATION: &str = "W219";

    // Phase 2 (reference checks): E220-E229, W220-W229.
    pub const E_INVALID_FOREIGN_KEY_REFERENCE: &str = "E220";
    pub const E_AMBIGUOUS_COLUMN_REFERENCE: &str = "E221";
    pub const E_UNRESOLVED_REFERENCE: &str = "E222";
    pub const E_CTE_COLUMN_COUNT_MISMATCH: &str = "E223";
    pub const E_MISSING_REFERENCE_TARGET: &str = "E224";

    pub const W_CARTESIAN_JOIN: &str = "W220";
    pub const W_JOIN_NOT_USING_DECLARED_REFERENCE: &str = "W221";
    pub const W_WEAK_REFERENCE_INTEGRITY: &str = "W222";
}
180
/// Canonical type family used by schema/type checks.
///
/// Concrete SQL type spellings are collapsed into these coarse families by
/// `canonical_type_family`. Variants serialize in `snake_case`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum TypeFamily {
    /// The type is empty or not recognized.
    Unknown,
    /// `bool` / `boolean`.
    Boolean,
    /// Integer types such as `int`, `bigint`, `serial`, `uint32`.
    Integer,
    /// Non-integer numerics such as `decimal`, `float`, `double precision`.
    Numeric,
    /// Character types such as `varchar`, `text`, `string`.
    String,
    /// Byte-string types such as `binary`, `blob`, `bytea`.
    Binary,
    /// Calendar date without a time component.
    Date,
    /// Time of day.
    Time,
    /// Date-and-time types such as `timestamp`, `timestamptz`, `datetime`.
    Timestamp,
    /// Interval / duration.
    Interval,
    /// `json`, `jsonb`, `variant`.
    Json,
    /// `uuid` / `uniqueidentifier`.
    Uuid,
    /// Array or list of elements.
    Array,
    /// Key/value map.
    Map,
    /// Struct-like composite (`struct`, `row`, `record`, `object`).
    Struct,
}
201
202impl TypeFamily {
203    pub fn is_numeric(self) -> bool {
204        matches!(self, TypeFamily::Integer | TypeFamily::Numeric)
205    }
206
207    pub fn is_temporal(self) -> bool {
208        matches!(
209            self,
210            TypeFamily::Date | TypeFamily::Time | TypeFamily::Timestamp | TypeFamily::Interval
211        )
212    }
213}
214
/// Per-table column metadata stored in the lowercase schema lookup map.
#[derive(Debug, Clone)]
struct TableSchemaEntry {
    /// Lowercased column name mapped to its canonical type family.
    columns: HashMap<String, TypeFamily>,
    /// Lowercased column names in the order they were declared in the schema.
    column_order: Vec<String>,
}
220
/// Returns the lowercase form of `s`.
///
/// Used throughout this module to normalize identifiers before they are
/// stored in or looked up from the schema maps.
fn lower(s: &str) -> String {
    str::to_lowercase(s)
}
224
/// Splits a parameterized type such as `decimal(10, 2)` into its base name
/// and the text between the first `(` and the final `)`.
///
/// Both parts are whitespace-trimmed. Returns `None` when the input has no
/// `(` or does not end with `)`.
fn split_type_args(data_type: &str) -> Option<(&str, &str)> {
    // Dropping the closing paren first lets one `split_once` on the first
    // opening paren yield both halves.
    let without_close = data_type.strip_suffix(')')?;
    let (base, inner) = without_close.split_once('(')?;
    Some((base.trim(), inner.trim()))
}
234
235/// Canonicalize a schema type string into a stable `TypeFamily`.
236pub fn canonical_type_family(data_type: &str) -> TypeFamily {
237    let trimmed = data_type
238        .trim()
239        .trim_matches(|c| c == '"' || c == '\'' || c == '`');
240    if trimmed.is_empty() {
241        return TypeFamily::Unknown;
242    }
243
244    // Normalize whitespace and lowercase for matching.
245    let normalized = trimmed
246        .split_whitespace()
247        .collect::<Vec<_>>()
248        .join(" ")
249        .to_lowercase();
250
251    // Strip common wrappers first.
252    if let Some((base, inner)) = split_type_args(&normalized) {
253        match base {
254            "nullable" | "lowcardinality" => return canonical_type_family(inner),
255            "array" | "list" => return TypeFamily::Array,
256            "map" => return TypeFamily::Map,
257            "struct" | "row" | "record" => return TypeFamily::Struct,
258            _ => {}
259        }
260    }
261
262    if normalized.starts_with("array<") || normalized.starts_with("list<") {
263        return TypeFamily::Array;
264    }
265    if normalized.starts_with("map<") {
266        return TypeFamily::Map;
267    }
268    if normalized.starts_with("struct<")
269        || normalized.starts_with("row<")
270        || normalized.starts_with("record<")
271        || normalized.starts_with("object<")
272    {
273        return TypeFamily::Struct;
274    }
275
276    if normalized.ends_with("[]") {
277        return TypeFamily::Array;
278    }
279
280    // Remove parameter list if present, e.g. VARCHAR(255), DECIMAL(10,2).
281    let mut base = normalized
282        .split('(')
283        .next()
284        .unwrap_or("")
285        .trim()
286        .to_string();
287    if base.is_empty() {
288        return TypeFamily::Unknown;
289    }
290
291    base = base.strip_prefix("unsigned ").unwrap_or(&base).to_string();
292    base = base.strip_suffix(" unsigned").unwrap_or(&base).to_string();
293
294    match base.as_str() {
295        "bool" | "boolean" => TypeFamily::Boolean,
296        "tinyint" | "smallint" | "int2" | "int" | "integer" | "int4" | "int8" | "bigint"
297        | "serial" | "smallserial" | "bigserial" | "utinyint" | "usmallint" | "uinteger"
298        | "ubigint" | "uint8" | "uint16" | "uint32" | "uint64" | "int16" | "int32" | "int64" => {
299            TypeFamily::Integer
300        }
301        "numeric" | "decimal" | "dec" | "number" | "float" | "float4" | "float8" | "real"
302        | "double" | "double precision" | "bfloat16" | "float16" | "float32" | "float64" => {
303            TypeFamily::Numeric
304        }
305        "char" | "character" | "varchar" | "character varying" | "nchar" | "nvarchar" | "text"
306        | "string" | "clob" => TypeFamily::String,
307        "binary" | "varbinary" | "blob" | "bytea" | "bytes" => TypeFamily::Binary,
308        "date" => TypeFamily::Date,
309        "time" => TypeFamily::Time,
310        "timestamp"
311        | "timestamptz"
312        | "datetime"
313        | "datetime2"
314        | "smalldatetime"
315        | "timestamp with time zone"
316        | "timestamp without time zone" => TypeFamily::Timestamp,
317        "interval" => TypeFamily::Interval,
318        "json" | "jsonb" | "variant" => TypeFamily::Json,
319        "uuid" | "uniqueidentifier" => TypeFamily::Uuid,
320        "array" | "list" => TypeFamily::Array,
321        "map" => TypeFamily::Map,
322        "struct" | "row" | "record" | "object" => TypeFamily::Struct,
323        _ => TypeFamily::Unknown,
324    }
325}
326
327fn build_schema_map(schema: &ValidationSchema) -> HashMap<String, TableSchemaEntry> {
328    let mut map = HashMap::new();
329
330    for table in &schema.tables {
331        let column_order: Vec<String> = table.columns.iter().map(|c| lower(&c.name)).collect();
332        let columns: HashMap<String, TypeFamily> = table
333            .columns
334            .iter()
335            .map(|c| (lower(&c.name), canonical_type_family(&c.data_type)))
336            .collect();
337        let entry = TableSchemaEntry {
338            columns,
339            column_order,
340        };
341
342        let simple_name = lower(&table.name);
343        map.insert(simple_name, entry.clone());
344
345        if let Some(table_schema) = &table.schema {
346            map.insert(
347                format!("{}.{}", lower(table_schema), lower(&table.name)),
348                entry.clone(),
349            );
350        }
351
352        for alias in &table.aliases {
353            map.insert(lower(alias), entry.clone());
354        }
355    }
356
357    map
358}
359
360fn type_family_to_data_type(family: TypeFamily) -> DataType {
361    match family {
362        TypeFamily::Unknown => DataType::Unknown,
363        TypeFamily::Boolean => DataType::Boolean,
364        TypeFamily::Integer => DataType::Int {
365            length: None,
366            integer_spelling: false,
367        },
368        TypeFamily::Numeric => DataType::Double {
369            precision: None,
370            scale: None,
371        },
372        TypeFamily::String => DataType::VarChar {
373            length: None,
374            parenthesized_length: false,
375        },
376        TypeFamily::Binary => DataType::VarBinary { length: None },
377        TypeFamily::Date => DataType::Date,
378        TypeFamily::Time => DataType::Time {
379            precision: None,
380            timezone: false,
381        },
382        TypeFamily::Timestamp => DataType::Timestamp {
383            precision: None,
384            timezone: false,
385        },
386        TypeFamily::Interval => DataType::Interval {
387            unit: None,
388            to: None,
389        },
390        TypeFamily::Json => DataType::Json,
391        TypeFamily::Uuid => DataType::Uuid,
392        TypeFamily::Array => DataType::Array {
393            element_type: Box::new(DataType::Unknown),
394            dimension: None,
395        },
396        TypeFamily::Map => DataType::Map {
397            key_type: Box::new(DataType::Unknown),
398            value_type: Box::new(DataType::Unknown),
399        },
400        TypeFamily::Struct => DataType::Struct {
401            fields: Vec::new(),
402            nested: false,
403        },
404    }
405}
406
407fn build_resolver_schema(schema: &ValidationSchema) -> MappingSchema {
408    let mut mapping = MappingSchema::new();
409
410    for table in &schema.tables {
411        let columns: Vec<(String, DataType)> = table
412            .columns
413            .iter()
414            .map(|column| {
415                (
416                    lower(&column.name),
417                    type_family_to_data_type(canonical_type_family(&column.data_type)),
418                )
419            })
420            .collect();
421
422        let mut table_names = Vec::new();
423        table_names.push(lower(&table.name));
424        if let Some(table_schema) = &table.schema {
425            table_names.push(format!("{}.{}", lower(table_schema), lower(&table.name)));
426        }
427        for alias in &table.aliases {
428            table_names.push(lower(alias));
429        }
430
431        let mut dedup = HashSet::new();
432        for table_name in table_names {
433            if dedup.insert(table_name.clone()) {
434                let _ = mapping.add_table(&table_name, &columns, None);
435            }
436        }
437    }
438
439    mapping
440}
441
442fn collect_cte_aliases(expr: &Expression) -> HashSet<String> {
443    let mut aliases = HashSet::new();
444
445    for node in expr.dfs() {
446        match node {
447            Expression::Select(select) => {
448                if let Some(with) = &select.with {
449                    for cte in &with.ctes {
450                        aliases.insert(lower(&cte.alias.name));
451                    }
452                }
453            }
454            Expression::Insert(insert) => {
455                if let Some(with) = &insert.with {
456                    for cte in &with.ctes {
457                        aliases.insert(lower(&cte.alias.name));
458                    }
459                }
460            }
461            Expression::Update(update) => {
462                if let Some(with) = &update.with {
463                    for cte in &with.ctes {
464                        aliases.insert(lower(&cte.alias.name));
465                    }
466                }
467            }
468            Expression::Delete(delete) => {
469                if let Some(with) = &delete.with {
470                    for cte in &with.ctes {
471                        aliases.insert(lower(&cte.alias.name));
472                    }
473                }
474            }
475            Expression::Union(union) => {
476                if let Some(with) = &union.with {
477                    for cte in &with.ctes {
478                        aliases.insert(lower(&cte.alias.name));
479                    }
480                }
481            }
482            Expression::Intersect(intersect) => {
483                if let Some(with) = &intersect.with {
484                    for cte in &with.ctes {
485                        aliases.insert(lower(&cte.alias.name));
486                    }
487                }
488            }
489            Expression::Except(except) => {
490                if let Some(with) = &except.with {
491                    for cte in &with.ctes {
492                        aliases.insert(lower(&cte.alias.name));
493                    }
494                }
495            }
496            Expression::Merge(merge) => {
497                if let Some(with_) = &merge.with_ {
498                    if let Expression::With(with_clause) = with_.as_ref() {
499                        for cte in &with_clause.ctes {
500                            aliases.insert(lower(&cte.alias.name));
501                        }
502                    }
503                }
504            }
505            _ => {}
506        }
507    }
508
509    aliases
510}
511
512fn table_ref_candidates(table: &TableRef) -> Vec<String> {
513    let name = lower(&table.name.name);
514    let schema = table.schema.as_ref().map(|s| lower(&s.name));
515    let catalog = table.catalog.as_ref().map(|c| lower(&c.name));
516
517    let mut candidates = Vec::new();
518    if let (Some(catalog), Some(schema)) = (&catalog, &schema) {
519        candidates.push(format!("{}.{}.{}", catalog, schema, name));
520    }
521    if let Some(schema) = &schema {
522        candidates.push(format!("{}.{}", schema, name));
523    }
524    candidates.push(name);
525    candidates
526}
527
528fn table_ref_display_name(table: &TableRef) -> String {
529    let mut parts = Vec::new();
530    if let Some(catalog) = &table.catalog {
531        parts.push(catalog.name.clone());
532    }
533    if let Some(schema) = &table.schema {
534        parts.push(schema.name.clone());
535    }
536    parts.push(table.name.name.clone());
537    parts.join(".")
538}
539
/// Tables a statement refers to, resolved against the schema map.
#[derive(Debug, Default, Clone)]
struct TypeCheckContext {
    /// Schema-map keys of the tables that were resolved.
    referenced_tables: HashSet<String>,
    /// Lowercased table name or alias mapped to the resolved schema-map key.
    table_aliases: HashMap<String, String>,
}
545
546fn type_family_name(family: TypeFamily) -> &'static str {
547    match family {
548        TypeFamily::Unknown => "unknown",
549        TypeFamily::Boolean => "boolean",
550        TypeFamily::Integer => "integer",
551        TypeFamily::Numeric => "numeric",
552        TypeFamily::String => "string",
553        TypeFamily::Binary => "binary",
554        TypeFamily::Date => "date",
555        TypeFamily::Time => "time",
556        TypeFamily::Timestamp => "timestamp",
557        TypeFamily::Interval => "interval",
558        TypeFamily::Json => "json",
559        TypeFamily::Uuid => "uuid",
560        TypeFamily::Array => "array",
561        TypeFamily::Map => "map",
562        TypeFamily::Struct => "struct",
563    }
564}
565
566fn is_string_like(family: TypeFamily) -> bool {
567    matches!(family, TypeFamily::String)
568}
569
570fn is_string_or_binary(family: TypeFamily) -> bool {
571    matches!(family, TypeFamily::String | TypeFamily::Binary)
572}
573
574fn type_issue(
575    strict: bool,
576    error_code: &str,
577    warning_code: &str,
578    message: impl Into<String>,
579) -> ValidationError {
580    if strict {
581        ValidationError::error(message.into(), error_code)
582    } else {
583        ValidationError::warning(message.into(), warning_code)
584    }
585}
586
587fn data_type_family(data_type: &DataType) -> TypeFamily {
588    match data_type {
589        DataType::Boolean => TypeFamily::Boolean,
590        DataType::TinyInt { .. }
591        | DataType::SmallInt { .. }
592        | DataType::Int { .. }
593        | DataType::BigInt { .. } => TypeFamily::Integer,
594        DataType::Float { .. } | DataType::Double { .. } | DataType::Decimal { .. } => {
595            TypeFamily::Numeric
596        }
597        DataType::Char { .. }
598        | DataType::VarChar { .. }
599        | DataType::String { .. }
600        | DataType::Text
601        | DataType::TextWithLength { .. }
602        | DataType::CharacterSet { .. } => TypeFamily::String,
603        DataType::Binary { .. } | DataType::VarBinary { .. } | DataType::Blob => TypeFamily::Binary,
604        DataType::Date => TypeFamily::Date,
605        DataType::Time { .. } => TypeFamily::Time,
606        DataType::Timestamp { .. } => TypeFamily::Timestamp,
607        DataType::Interval { .. } => TypeFamily::Interval,
608        DataType::Json | DataType::JsonB => TypeFamily::Json,
609        DataType::Uuid => TypeFamily::Uuid,
610        DataType::Array { .. } | DataType::List { .. } => TypeFamily::Array,
611        DataType::Map { .. } => TypeFamily::Map,
612        DataType::Struct { .. } | DataType::Object { .. } | DataType::Union { .. } => {
613            TypeFamily::Struct
614        }
615        DataType::Nullable { inner } => data_type_family(inner),
616        DataType::Custom { name } => canonical_type_family(name),
617        DataType::Unknown => TypeFamily::Unknown,
618        DataType::Bit { .. } | DataType::VarBit { .. } => TypeFamily::Binary,
619        DataType::Enum { .. } | DataType::Set { .. } => TypeFamily::String,
620        DataType::Vector { .. } => TypeFamily::Array,
621        DataType::Geometry { .. } | DataType::Geography { .. } => TypeFamily::Struct,
622    }
623}
624
625fn collect_type_check_context(
626    stmt: &Expression,
627    schema_map: &HashMap<String, TableSchemaEntry>,
628) -> TypeCheckContext {
629    fn add_table_to_context(
630        table: &TableRef,
631        schema_map: &HashMap<String, TableSchemaEntry>,
632        context: &mut TypeCheckContext,
633    ) {
634        let resolved_key = table_ref_candidates(table)
635            .into_iter()
636            .find(|k| schema_map.contains_key(k));
637
638        let Some(table_key) = resolved_key else {
639            return;
640        };
641
642        context.referenced_tables.insert(table_key.clone());
643        context
644            .table_aliases
645            .insert(lower(&table.name.name), table_key.clone());
646        if let Some(alias) = &table.alias {
647            context
648                .table_aliases
649                .insert(lower(&alias.name), table_key.clone());
650        }
651    }
652
653    let mut context = TypeCheckContext::default();
654    let cte_aliases = collect_cte_aliases(stmt);
655
656    for node in stmt.find_all(|e| matches!(e, Expression::Table(_))) {
657        let Expression::Table(table) = node else {
658            continue;
659        };
660
661        if cte_aliases.contains(&lower(&table.name.name)) {
662            continue;
663        }
664
665        add_table_to_context(table, schema_map, &mut context);
666    }
667
668    // Seed DML target tables explicitly because they are struct fields and may
669    // not appear as standalone Expression::Table nodes in traversal output.
670    match stmt {
671        Expression::Insert(insert) => {
672            add_table_to_context(&insert.table, schema_map, &mut context);
673        }
674        Expression::Update(update) => {
675            add_table_to_context(&update.table, schema_map, &mut context);
676            for table in &update.extra_tables {
677                add_table_to_context(table, schema_map, &mut context);
678            }
679        }
680        Expression::Delete(delete) => {
681            add_table_to_context(&delete.table, schema_map, &mut context);
682            for table in &delete.using {
683                add_table_to_context(table, schema_map, &mut context);
684            }
685            for table in &delete.tables {
686                add_table_to_context(table, schema_map, &mut context);
687            }
688        }
689        _ => {}
690    }
691
692    context
693}
694
695fn resolve_table_schema_entry<'a>(
696    table: &TableRef,
697    schema_map: &'a HashMap<String, TableSchemaEntry>,
698) -> Option<(String, &'a TableSchemaEntry)> {
699    let key = table_ref_candidates(table)
700        .into_iter()
701        .find(|k| schema_map.contains_key(k))?;
702    let entry = schema_map.get(&key)?;
703    Some((key, entry))
704}
705
706fn reference_issue(strict: bool, message: impl Into<String>) -> ValidationError {
707    if strict {
708        ValidationError::error(
709            message.into(),
710            validation_codes::E_INVALID_FOREIGN_KEY_REFERENCE,
711        )
712    } else {
713        ValidationError::warning(message.into(), validation_codes::W_WEAK_REFERENCE_INTEGRITY)
714    }
715}
716
717fn reference_table_candidates(
718    table_name: &str,
719    explicit_schema: Option<&str>,
720    source_schema: Option<&str>,
721) -> Vec<String> {
722    let mut candidates = Vec::new();
723    let raw = lower(table_name);
724
725    if let Some(schema) = explicit_schema {
726        candidates.push(format!("{}.{}", lower(schema), raw));
727    }
728
729    if raw.contains('.') {
730        candidates.push(raw.clone());
731        if let Some(last) = raw.rsplit('.').next() {
732            candidates.push(last.to_string());
733        }
734    } else {
735        if let Some(schema) = source_schema {
736            candidates.push(format!("{}.{}", lower(schema), raw));
737        }
738        candidates.push(raw);
739    }
740
741    let mut dedup = HashSet::new();
742    candidates
743        .into_iter()
744        .filter(|c| dedup.insert(c.clone()))
745        .collect()
746}
747
748fn resolve_reference_table_key(
749    table_name: &str,
750    explicit_schema: Option<&str>,
751    source_schema: Option<&str>,
752    schema_map: &HashMap<String, TableSchemaEntry>,
753) -> Option<String> {
754    reference_table_candidates(table_name, explicit_schema, source_schema)
755        .into_iter()
756        .find(|candidate| schema_map.contains_key(candidate))
757}
758
759fn key_types_compatible(source: TypeFamily, target: TypeFamily) -> bool {
760    if source == TypeFamily::Unknown || target == TypeFamily::Unknown {
761        return true;
762    }
763    if source == target {
764        return true;
765    }
766    if source.is_numeric() && target.is_numeric() {
767        return true;
768    }
769    if source.is_temporal() && target.is_temporal() {
770        return true;
771    }
772    false
773}
774
775fn table_key_hints(table: &SchemaTable) -> HashSet<String> {
776    let mut hints = HashSet::new();
777    for column in &table.columns {
778        if column.primary_key || column.unique {
779            hints.insert(lower(&column.name));
780        }
781    }
782    for key_col in &table.primary_key {
783        hints.insert(lower(key_col));
784    }
785    for group in &table.unique_keys {
786        if group.len() == 1 {
787            if let Some(col) = group.first() {
788                hints.insert(lower(col));
789            }
790        }
791    }
792    hints
793}
794
795fn check_reference_integrity(
796    schema: &ValidationSchema,
797    schema_map: &HashMap<String, TableSchemaEntry>,
798    strict: bool,
799) -> Vec<ValidationError> {
800    let mut errors = Vec::new();
801
802    let mut key_hints_lookup: HashMap<String, HashSet<String>> = HashMap::new();
803    for table in &schema.tables {
804        let simple = lower(&table.name);
805        key_hints_lookup.insert(simple, table_key_hints(table));
806        if let Some(schema_name) = &table.schema {
807            let qualified = format!("{}.{}", lower(schema_name), lower(&table.name));
808            key_hints_lookup.insert(qualified, table_key_hints(table));
809        }
810    }
811
812    for table in &schema.tables {
813        let source_table_display = if let Some(schema_name) = &table.schema {
814            format!("{}.{}", schema_name, table.name)
815        } else {
816            table.name.clone()
817        };
818        let source_schema = table.schema.as_deref();
819        let source_columns: HashMap<String, TypeFamily> = table
820            .columns
821            .iter()
822            .map(|col| (lower(&col.name), canonical_type_family(&col.data_type)))
823            .collect();
824
825        for source_col in &table.columns {
826            let Some(reference) = &source_col.references else {
827                continue;
828            };
829            let source_type = canonical_type_family(&source_col.data_type);
830
831            let Some(target_key) = resolve_reference_table_key(
832                &reference.table,
833                reference.schema.as_deref(),
834                source_schema,
835                schema_map,
836            ) else {
837                errors.push(reference_issue(
838                    strict,
839                    format!(
840                        "Foreign key reference '{}.{}' points to unknown table '{}'",
841                        source_table_display, source_col.name, reference.table
842                    ),
843                ));
844                continue;
845            };
846
847            let target_column = lower(&reference.column);
848            let Some(target_entry) = schema_map.get(&target_key) else {
849                errors.push(reference_issue(
850                    strict,
851                    format!(
852                        "Foreign key reference '{}.{}' points to unknown table '{}'",
853                        source_table_display, source_col.name, reference.table
854                    ),
855                ));
856                continue;
857            };
858
859            let Some(target_type) = target_entry.columns.get(&target_column).copied() else {
860                errors.push(reference_issue(
861                    strict,
862                    format!(
863                        "Foreign key reference '{}.{}' points to unknown column '{}.{}'",
864                        source_table_display, source_col.name, target_key, reference.column
865                    ),
866                ));
867                continue;
868            };
869
870            if !key_types_compatible(source_type, target_type) {
871                errors.push(reference_issue(
872                    strict,
873                    format!(
874                        "Foreign key type mismatch for '{}.{}' -> '{}.{}': {} vs {}",
875                        source_table_display,
876                        source_col.name,
877                        target_key,
878                        reference.column,
879                        type_family_name(source_type),
880                        type_family_name(target_type)
881                    ),
882                ));
883            }
884
885            if let Some(target_key_hints) = key_hints_lookup.get(&target_key) {
886                if !target_key_hints.contains(&target_column) {
887                    errors.push(ValidationError::warning(
888                        format!(
889                            "Referenced column '{}.{}' is not marked as primary/unique key",
890                            target_key, reference.column
891                        ),
892                        validation_codes::W_WEAK_REFERENCE_INTEGRITY,
893                    ));
894                }
895            }
896        }
897
898        for foreign_key in &table.foreign_keys {
899            if foreign_key.columns.is_empty() || foreign_key.references.columns.is_empty() {
900                errors.push(reference_issue(
901                    strict,
902                    format!(
903                        "Table-level foreign key on '{}' has empty source or target column list",
904                        source_table_display
905                    ),
906                ));
907                continue;
908            }
909            if foreign_key.columns.len() != foreign_key.references.columns.len() {
910                errors.push(reference_issue(
911                    strict,
912                    format!(
913                        "Table-level foreign key on '{}' has {} source columns but {} target columns",
914                        source_table_display,
915                        foreign_key.columns.len(),
916                        foreign_key.references.columns.len()
917                    ),
918                ));
919                continue;
920            }
921
922            let Some(target_key) = resolve_reference_table_key(
923                &foreign_key.references.table,
924                foreign_key.references.schema.as_deref(),
925                source_schema,
926                schema_map,
927            ) else {
928                errors.push(reference_issue(
929                    strict,
930                    format!(
931                        "Table-level foreign key on '{}' points to unknown table '{}'",
932                        source_table_display, foreign_key.references.table
933                    ),
934                ));
935                continue;
936            };
937
938            let Some(target_entry) = schema_map.get(&target_key) else {
939                errors.push(reference_issue(
940                    strict,
941                    format!(
942                        "Table-level foreign key on '{}' points to unknown table '{}'",
943                        source_table_display, foreign_key.references.table
944                    ),
945                ));
946                continue;
947            };
948
949            for (source_col, target_col) in foreign_key
950                .columns
951                .iter()
952                .zip(foreign_key.references.columns.iter())
953            {
954                let source_col_name = lower(source_col);
955                let target_col_name = lower(target_col);
956
957                let Some(source_type) = source_columns.get(&source_col_name).copied() else {
958                    errors.push(reference_issue(
959                        strict,
960                        format!(
961                            "Table-level foreign key on '{}' references unknown source column '{}'",
962                            source_table_display, source_col
963                        ),
964                    ));
965                    continue;
966                };
967
968                let Some(target_type) = target_entry.columns.get(&target_col_name).copied() else {
969                    errors.push(reference_issue(
970                        strict,
971                        format!(
972                            "Table-level foreign key on '{}' references unknown target column '{}.{}'",
973                            source_table_display, target_key, target_col
974                        ),
975                    ));
976                    continue;
977                };
978
979                if !key_types_compatible(source_type, target_type) {
980                    errors.push(reference_issue(
981                        strict,
982                        format!(
983                            "Table-level foreign key type mismatch '{}.{}' -> '{}.{}': {} vs {}",
984                            source_table_display,
985                            source_col,
986                            target_key,
987                            target_col,
988                            type_family_name(source_type),
989                            type_family_name(target_type)
990                        ),
991                    ));
992                }
993
994                if let Some(target_key_hints) = key_hints_lookup.get(&target_key) {
995                    if !target_key_hints.contains(&target_col_name) {
996                        errors.push(ValidationError::warning(
997                            format!(
998                                "Referenced column '{}.{}' is not marked as primary/unique key",
999                                target_key, target_col
1000                            ),
1001                            validation_codes::W_WEAK_REFERENCE_INTEGRITY,
1002                        ));
1003                    }
1004                }
1005            }
1006        }
1007    }
1008
1009    errors
1010}
1011
1012fn resolve_unqualified_column_type(
1013    column_name: &str,
1014    schema_map: &HashMap<String, TableSchemaEntry>,
1015    context: &TypeCheckContext,
1016) -> TypeFamily {
1017    let candidate_tables: Vec<&String> = if !context.referenced_tables.is_empty() {
1018        context.referenced_tables.iter().collect()
1019    } else {
1020        schema_map.keys().collect()
1021    };
1022
1023    let mut families = HashSet::new();
1024    for table_name in candidate_tables {
1025        if let Some(table_schema) = schema_map.get(table_name) {
1026            if let Some(family) = table_schema.columns.get(column_name) {
1027                families.insert(*family);
1028            }
1029        }
1030    }
1031
1032    if families.len() == 1 {
1033        *families.iter().next().unwrap_or(&TypeFamily::Unknown)
1034    } else {
1035        TypeFamily::Unknown
1036    }
1037}
1038
1039fn resolve_column_type(
1040    column: &Column,
1041    schema_map: &HashMap<String, TableSchemaEntry>,
1042    context: &TypeCheckContext,
1043) -> TypeFamily {
1044    let column_name = lower(&column.name.name);
1045    if column_name.is_empty() {
1046        return TypeFamily::Unknown;
1047    }
1048
1049    if let Some(table) = &column.table {
1050        let mut table_key = lower(&table.name);
1051        if let Some(mapped) = context.table_aliases.get(&table_key) {
1052            table_key = mapped.clone();
1053        }
1054
1055        return schema_map
1056            .get(&table_key)
1057            .and_then(|t| t.columns.get(&column_name))
1058            .copied()
1059            .unwrap_or(TypeFamily::Unknown);
1060    }
1061
1062    resolve_unqualified_column_type(&column_name, schema_map, context)
1063}
1064
/// Borrowed view over the validation schema that is handed to
/// `annotate_types` as its schema argument (see
/// `infer_expression_type_family`).
struct TypeInferenceSchema<'a> {
    /// Table schema entries keyed by table key; consulted for every table and
    /// column lookup.
    schema_map: &'a HashMap<String, TableSchemaEntry>,
    /// Per-query context supplying `table_aliases` and `referenced_tables`.
    context: &'a TypeCheckContext,
}
1069
1070impl TypeInferenceSchema<'_> {
1071    fn resolve_table_key(&self, table: &str) -> Option<String> {
1072        let mut table_key = lower(table);
1073        if let Some(mapped) = self.context.table_aliases.get(&table_key) {
1074            table_key = mapped.clone();
1075        }
1076        if self.schema_map.contains_key(&table_key) {
1077            Some(table_key)
1078        } else {
1079            None
1080        }
1081    }
1082}
1083
1084impl SqlSchema for TypeInferenceSchema<'_> {
1085    fn dialect(&self) -> Option<DialectType> {
1086        None
1087    }
1088
1089    fn add_table(
1090        &mut self,
1091        _table: &str,
1092        _columns: &[(String, DataType)],
1093        _dialect: Option<DialectType>,
1094    ) -> SchemaResult<()> {
1095        Err(SchemaError::InvalidStructure(
1096            "Type inference schema is read-only".to_string(),
1097        ))
1098    }
1099
1100    fn column_names(&self, table: &str) -> SchemaResult<Vec<String>> {
1101        let table_key = self
1102            .resolve_table_key(table)
1103            .ok_or_else(|| SchemaError::TableNotFound(table.to_string()))?;
1104        let entry = self
1105            .schema_map
1106            .get(&table_key)
1107            .ok_or_else(|| SchemaError::TableNotFound(table.to_string()))?;
1108        Ok(entry.column_order.clone())
1109    }
1110
1111    fn get_column_type(&self, table: &str, column: &str) -> SchemaResult<DataType> {
1112        let col_name = lower(column);
1113        if table.is_empty() {
1114            let family = resolve_unqualified_column_type(&col_name, self.schema_map, self.context);
1115            return if family == TypeFamily::Unknown {
1116                Err(SchemaError::ColumnNotFound {
1117                    table: "<unqualified>".to_string(),
1118                    column: column.to_string(),
1119                })
1120            } else {
1121                Ok(type_family_to_data_type(family))
1122            };
1123        }
1124
1125        let table_key = self
1126            .resolve_table_key(table)
1127            .ok_or_else(|| SchemaError::TableNotFound(table.to_string()))?;
1128        let entry = self
1129            .schema_map
1130            .get(&table_key)
1131            .ok_or_else(|| SchemaError::TableNotFound(table.to_string()))?;
1132        let family =
1133            entry
1134                .columns
1135                .get(&col_name)
1136                .copied()
1137                .ok_or_else(|| SchemaError::ColumnNotFound {
1138                    table: table.to_string(),
1139                    column: column.to_string(),
1140                })?;
1141        Ok(type_family_to_data_type(family))
1142    }
1143
1144    fn has_column(&self, table: &str, column: &str) -> bool {
1145        self.get_column_type(table, column).is_ok()
1146    }
1147
1148    fn supported_table_args(&self) -> &[&str] {
1149        TABLE_PARTS
1150    }
1151
1152    fn is_empty(&self) -> bool {
1153        self.schema_map.is_empty()
1154    }
1155
1156    fn depth(&self) -> usize {
1157        1
1158    }
1159}
1160
1161fn infer_expression_type_family(
1162    expr: &Expression,
1163    schema_map: &HashMap<String, TableSchemaEntry>,
1164    context: &TypeCheckContext,
1165) -> TypeFamily {
1166    let inference_schema = TypeInferenceSchema {
1167        schema_map,
1168        context,
1169    };
1170    if let Some(data_type) = annotate_types(expr, Some(&inference_schema), None) {
1171        let family = data_type_family(&data_type);
1172        if family != TypeFamily::Unknown {
1173            return family;
1174        }
1175    }
1176
1177    infer_expression_type_family_fallback(expr, schema_map, context)
1178}
1179
1180fn infer_expression_type_family_fallback(
1181    expr: &Expression,
1182    schema_map: &HashMap<String, TableSchemaEntry>,
1183    context: &TypeCheckContext,
1184) -> TypeFamily {
1185    match expr {
1186        Expression::Literal(literal) => match literal {
1187            crate::expressions::Literal::Number(value) => {
1188                if value.contains('.') || value.contains('e') || value.contains('E') {
1189                    TypeFamily::Numeric
1190                } else {
1191                    TypeFamily::Integer
1192                }
1193            }
1194            crate::expressions::Literal::HexNumber(_) => TypeFamily::Integer,
1195            crate::expressions::Literal::Date(_) => TypeFamily::Date,
1196            crate::expressions::Literal::Time(_) => TypeFamily::Time,
1197            crate::expressions::Literal::Timestamp(_)
1198            | crate::expressions::Literal::Datetime(_) => TypeFamily::Timestamp,
1199            crate::expressions::Literal::HexString(_)
1200            | crate::expressions::Literal::BitString(_)
1201            | crate::expressions::Literal::ByteString(_) => TypeFamily::Binary,
1202            _ => TypeFamily::String,
1203        },
1204        Expression::Boolean(_) => TypeFamily::Boolean,
1205        Expression::Null(_) => TypeFamily::Unknown,
1206        Expression::Column(column) => resolve_column_type(column, schema_map, context),
1207        Expression::Cast(cast) | Expression::TryCast(cast) | Expression::SafeCast(cast) => {
1208            data_type_family(&cast.to)
1209        }
1210        Expression::Alias(alias) => {
1211            infer_expression_type_family_fallback(&alias.this, schema_map, context)
1212        }
1213        Expression::Neg(unary) => {
1214            infer_expression_type_family_fallback(&unary.this, schema_map, context)
1215        }
1216        Expression::Add(op) | Expression::Sub(op) | Expression::Mul(op) => {
1217            let left = infer_expression_type_family_fallback(&op.left, schema_map, context);
1218            let right = infer_expression_type_family_fallback(&op.right, schema_map, context);
1219            if left == TypeFamily::Unknown || right == TypeFamily::Unknown {
1220                TypeFamily::Unknown
1221            } else if left == TypeFamily::Integer && right == TypeFamily::Integer {
1222                TypeFamily::Integer
1223            } else if left.is_numeric() && right.is_numeric() {
1224                TypeFamily::Numeric
1225            } else if left.is_temporal() || right.is_temporal() {
1226                left
1227            } else {
1228                TypeFamily::Unknown
1229            }
1230        }
1231        Expression::Div(_) | Expression::Mod(_) => TypeFamily::Numeric,
1232        Expression::Concat(_) => TypeFamily::String,
1233        Expression::Eq(_)
1234        | Expression::Neq(_)
1235        | Expression::Lt(_)
1236        | Expression::Lte(_)
1237        | Expression::Gt(_)
1238        | Expression::Gte(_)
1239        | Expression::Like(_)
1240        | Expression::ILike(_)
1241        | Expression::And(_)
1242        | Expression::Or(_)
1243        | Expression::Not(_)
1244        | Expression::Between(_)
1245        | Expression::In(_)
1246        | Expression::IsNull(_)
1247        | Expression::IsTrue(_)
1248        | Expression::IsFalse(_)
1249        | Expression::Is(_) => TypeFamily::Boolean,
1250        Expression::Length(_) => TypeFamily::Integer,
1251        Expression::Upper(_)
1252        | Expression::Lower(_)
1253        | Expression::Trim(_)
1254        | Expression::LTrim(_)
1255        | Expression::RTrim(_)
1256        | Expression::Replace(_)
1257        | Expression::Substring(_)
1258        | Expression::Left(_)
1259        | Expression::Right(_)
1260        | Expression::Repeat(_)
1261        | Expression::Lpad(_)
1262        | Expression::Rpad(_)
1263        | Expression::ConcatWs(_) => TypeFamily::String,
1264        Expression::Abs(_)
1265        | Expression::Round(_)
1266        | Expression::Floor(_)
1267        | Expression::Ceil(_)
1268        | Expression::Power(_)
1269        | Expression::Sqrt(_)
1270        | Expression::Cbrt(_)
1271        | Expression::Ln(_)
1272        | Expression::Log(_)
1273        | Expression::Exp(_) => TypeFamily::Numeric,
1274        Expression::DateAdd(_) | Expression::DateSub(_) | Expression::ToDate(_) => TypeFamily::Date,
1275        Expression::ToTimestamp(_) => TypeFamily::Timestamp,
1276        Expression::DateDiff(_) | Expression::Extract(_) => TypeFamily::Integer,
1277        Expression::CurrentDate(_) => TypeFamily::Date,
1278        Expression::CurrentTime(_) => TypeFamily::Time,
1279        Expression::CurrentTimestamp(_) | Expression::CurrentTimestampLTZ(_) => {
1280            TypeFamily::Timestamp
1281        }
1282        Expression::Interval(_) => TypeFamily::Interval,
1283        _ => TypeFamily::Unknown,
1284    }
1285}
1286
1287fn are_comparable(left: TypeFamily, right: TypeFamily) -> bool {
1288    if left == TypeFamily::Unknown || right == TypeFamily::Unknown {
1289        return true;
1290    }
1291    if left == right {
1292        return true;
1293    }
1294    if left.is_numeric() && right.is_numeric() {
1295        return true;
1296    }
1297    if left.is_temporal() && right.is_temporal() {
1298        return true;
1299    }
1300    false
1301}
1302
1303fn check_function_argument(
1304    errors: &mut Vec<ValidationError>,
1305    strict: bool,
1306    function_name: &str,
1307    arg_index: usize,
1308    family: TypeFamily,
1309    expected: &str,
1310    valid: bool,
1311) {
1312    if family == TypeFamily::Unknown || valid {
1313        return;
1314    }
1315
1316    errors.push(type_issue(
1317        strict,
1318        validation_codes::E_INVALID_FUNCTION_ARGUMENT_TYPE,
1319        validation_codes::W_FUNCTION_ARGUMENT_COERCION,
1320        format!(
1321            "Function '{}' argument {} expects {}, found {}",
1322            function_name,
1323            arg_index + 1,
1324            expected,
1325            type_family_name(family)
1326        ),
1327    ));
1328}
1329
1330fn function_dispatch_name(name: &str) -> String {
1331    let upper = name
1332        .rsplit('.')
1333        .next()
1334        .unwrap_or(name)
1335        .trim()
1336        .to_uppercase();
1337    lower(canonical_typed_function_name_upper(&upper))
1338}
1339
1340fn check_generic_function(
1341    function: &Function,
1342    schema_map: &HashMap<String, TableSchemaEntry>,
1343    context: &TypeCheckContext,
1344    strict: bool,
1345    errors: &mut Vec<ValidationError>,
1346) {
1347    let name = function_dispatch_name(&function.name);
1348
1349    let arg_family = |index: usize| -> Option<TypeFamily> {
1350        function
1351            .args
1352            .get(index)
1353            .map(|arg| infer_expression_type_family(arg, schema_map, context))
1354    };
1355
1356    match name.as_str() {
1357        "abs" | "sqrt" | "cbrt" | "ln" | "exp" => {
1358            if let Some(family) = arg_family(0) {
1359                check_function_argument(
1360                    errors,
1361                    strict,
1362                    &name,
1363                    0,
1364                    family,
1365                    "a numeric argument",
1366                    family.is_numeric(),
1367                );
1368            }
1369        }
1370        "round" | "floor" | "ceil" | "ceiling" => {
1371            if let Some(family) = arg_family(0) {
1372                check_function_argument(
1373                    errors,
1374                    strict,
1375                    &name,
1376                    0,
1377                    family,
1378                    "a numeric argument",
1379                    family.is_numeric(),
1380                );
1381            }
1382            if let Some(family) = arg_family(1) {
1383                check_function_argument(
1384                    errors,
1385                    strict,
1386                    &name,
1387                    1,
1388                    family,
1389                    "a numeric argument",
1390                    family.is_numeric(),
1391                );
1392            }
1393        }
1394        "power" | "pow" => {
1395            for i in [0_usize, 1_usize] {
1396                if let Some(family) = arg_family(i) {
1397                    check_function_argument(
1398                        errors,
1399                        strict,
1400                        &name,
1401                        i,
1402                        family,
1403                        "a numeric argument",
1404                        family.is_numeric(),
1405                    );
1406                }
1407            }
1408        }
1409        "length" | "char_length" | "character_length" => {
1410            if let Some(family) = arg_family(0) {
1411                check_function_argument(
1412                    errors,
1413                    strict,
1414                    &name,
1415                    0,
1416                    family,
1417                    "a string or binary argument",
1418                    is_string_or_binary(family),
1419                );
1420            }
1421        }
1422        "upper" | "lower" | "trim" | "ltrim" | "rtrim" | "reverse" => {
1423            if let Some(family) = arg_family(0) {
1424                check_function_argument(
1425                    errors,
1426                    strict,
1427                    &name,
1428                    0,
1429                    family,
1430                    "a string argument",
1431                    is_string_like(family),
1432                );
1433            }
1434        }
1435        "substring" | "substr" => {
1436            if let Some(family) = arg_family(0) {
1437                check_function_argument(
1438                    errors,
1439                    strict,
1440                    &name,
1441                    0,
1442                    family,
1443                    "a string argument",
1444                    is_string_like(family),
1445                );
1446            }
1447            if let Some(family) = arg_family(1) {
1448                check_function_argument(
1449                    errors,
1450                    strict,
1451                    &name,
1452                    1,
1453                    family,
1454                    "a numeric argument",
1455                    family.is_numeric(),
1456                );
1457            }
1458            if let Some(family) = arg_family(2) {
1459                check_function_argument(
1460                    errors,
1461                    strict,
1462                    &name,
1463                    2,
1464                    family,
1465                    "a numeric argument",
1466                    family.is_numeric(),
1467                );
1468            }
1469        }
1470        "replace" => {
1471            for i in [0_usize, 1_usize, 2_usize] {
1472                if let Some(family) = arg_family(i) {
1473                    check_function_argument(
1474                        errors,
1475                        strict,
1476                        &name,
1477                        i,
1478                        family,
1479                        "a string argument",
1480                        is_string_like(family),
1481                    );
1482                }
1483            }
1484        }
1485        "left" | "right" | "repeat" | "lpad" | "rpad" => {
1486            if let Some(family) = arg_family(0) {
1487                check_function_argument(
1488                    errors,
1489                    strict,
1490                    &name,
1491                    0,
1492                    family,
1493                    "a string argument",
1494                    is_string_like(family),
1495                );
1496            }
1497            if let Some(family) = arg_family(1) {
1498                check_function_argument(
1499                    errors,
1500                    strict,
1501                    &name,
1502                    1,
1503                    family,
1504                    "a numeric argument",
1505                    family.is_numeric(),
1506                );
1507            }
1508            if (name == "lpad" || name == "rpad") && function.args.len() > 2 {
1509                if let Some(family) = arg_family(2) {
1510                    check_function_argument(
1511                        errors,
1512                        strict,
1513                        &name,
1514                        2,
1515                        family,
1516                        "a string argument",
1517                        is_string_like(family),
1518                    );
1519                }
1520            }
1521        }
1522        _ => {}
1523    }
1524}
1525
/// One column-to-column foreign key link declared in the validation schema.
///
/// Instances are produced by `build_declared_relationships`, which stores
/// resolved table keys and lower-cased column names.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
struct DeclaredRelationship {
    /// Key of the table that declares the foreign key.
    source_table: String,
    /// Referencing column in the source table.
    source_column: String,
    /// Key of the table the foreign key points to.
    target_table: String,
    /// Referenced column in the target table.
    target_column: String,
}
1533
1534fn build_declared_relationships(
1535    schema: &ValidationSchema,
1536    schema_map: &HashMap<String, TableSchemaEntry>,
1537) -> Vec<DeclaredRelationship> {
1538    let mut relationships = HashSet::new();
1539
1540    for table in &schema.tables {
1541        let Some(source_key) =
1542            resolve_reference_table_key(&table.name, table.schema.as_deref(), None, schema_map)
1543        else {
1544            continue;
1545        };
1546
1547        for column in &table.columns {
1548            let Some(reference) = &column.references else {
1549                continue;
1550            };
1551            let Some(target_key) = resolve_reference_table_key(
1552                &reference.table,
1553                reference.schema.as_deref(),
1554                table.schema.as_deref(),
1555                schema_map,
1556            ) else {
1557                continue;
1558            };
1559
1560            relationships.insert(DeclaredRelationship {
1561                source_table: source_key.clone(),
1562                source_column: lower(&column.name),
1563                target_table: target_key,
1564                target_column: lower(&reference.column),
1565            });
1566        }
1567
1568        for foreign_key in &table.foreign_keys {
1569            if foreign_key.columns.len() != foreign_key.references.columns.len() {
1570                continue;
1571            }
1572            let Some(target_key) = resolve_reference_table_key(
1573                &foreign_key.references.table,
1574                foreign_key.references.schema.as_deref(),
1575                table.schema.as_deref(),
1576                schema_map,
1577            ) else {
1578                continue;
1579            };
1580
1581            for (source_col, target_col) in foreign_key
1582                .columns
1583                .iter()
1584                .zip(foreign_key.references.columns.iter())
1585            {
1586                relationships.insert(DeclaredRelationship {
1587                    source_table: source_key.clone(),
1588                    source_column: lower(source_col),
1589                    target_table: target_key.clone(),
1590                    target_column: lower(target_col),
1591                });
1592            }
1593        }
1594    }
1595
1596    relationships.into_iter().collect()
1597}
1598
1599fn resolve_column_binding(
1600    column: &Column,
1601    schema_map: &HashMap<String, TableSchemaEntry>,
1602    context: &TypeCheckContext,
1603    resolver: &mut Resolver<'_>,
1604) -> Option<(String, String)> {
1605    let column_name = lower(&column.name.name);
1606    if column_name.is_empty() {
1607        return None;
1608    }
1609
1610    if let Some(table) = &column.table {
1611        let mut table_key = lower(&table.name);
1612        if let Some(mapped) = context.table_aliases.get(&table_key) {
1613            table_key = mapped.clone();
1614        }
1615        if schema_map.contains_key(&table_key) {
1616            return Some((table_key, column_name));
1617        }
1618        return None;
1619    }
1620
1621    if let Some(resolved_source) = resolver.get_table(&column_name) {
1622        let mut table_key = lower(&resolved_source);
1623        if let Some(mapped) = context.table_aliases.get(&table_key) {
1624            table_key = mapped.clone();
1625        }
1626        if schema_map.contains_key(&table_key) {
1627            return Some((table_key, column_name));
1628        }
1629    }
1630
1631    let candidates: Vec<String> = context
1632        .referenced_tables
1633        .iter()
1634        .filter_map(|table_name| {
1635            schema_map
1636                .get(table_name)
1637                .filter(|entry| entry.columns.contains_key(&column_name))
1638                .map(|_| table_name.clone())
1639        })
1640        .collect();
1641    if candidates.len() == 1 {
1642        return Some((candidates[0].clone(), column_name));
1643    }
1644    None
1645}
1646
1647fn extract_join_equality_pairs(
1648    expr: &Expression,
1649    schema_map: &HashMap<String, TableSchemaEntry>,
1650    context: &TypeCheckContext,
1651    resolver: &mut Resolver<'_>,
1652    pairs: &mut Vec<((String, String), (String, String))>,
1653) {
1654    match expr {
1655        Expression::And(op) => {
1656            extract_join_equality_pairs(&op.left, schema_map, context, resolver, pairs);
1657            extract_join_equality_pairs(&op.right, schema_map, context, resolver, pairs);
1658        }
1659        Expression::Paren(paren) => {
1660            extract_join_equality_pairs(&paren.this, schema_map, context, resolver, pairs);
1661        }
1662        Expression::Eq(op) => {
1663            let (Expression::Column(left_col), Expression::Column(right_col)) =
1664                (&op.left, &op.right)
1665            else {
1666                return;
1667            };
1668            let Some(left) = resolve_column_binding(left_col, schema_map, context, resolver) else {
1669                return;
1670            };
1671            let Some(right) = resolve_column_binding(right_col, schema_map, context, resolver)
1672            else {
1673                return;
1674            };
1675            pairs.push((left, right));
1676        }
1677        _ => {}
1678    }
1679}
1680
1681fn relationship_matches_pair(
1682    relationship: &DeclaredRelationship,
1683    left_table: &str,
1684    left_column: &str,
1685    right_table: &str,
1686    right_column: &str,
1687) -> bool {
1688    (relationship.source_table == left_table
1689        && relationship.source_column == left_column
1690        && relationship.target_table == right_table
1691        && relationship.target_column == right_column)
1692        || (relationship.source_table == right_table
1693            && relationship.source_column == right_column
1694            && relationship.target_table == left_table
1695            && relationship.target_column == left_column)
1696}
1697
1698fn resolved_table_key_from_expr(
1699    expr: &Expression,
1700    schema_map: &HashMap<String, TableSchemaEntry>,
1701) -> Option<String> {
1702    match expr {
1703        Expression::Table(table) => resolve_table_schema_entry(table, schema_map).map(|(k, _)| k),
1704        Expression::Alias(alias) => resolved_table_key_from_expr(&alias.this, schema_map),
1705        _ => None,
1706    }
1707}
1708
1709fn select_from_table_keys(
1710    select: &crate::expressions::Select,
1711    schema_map: &HashMap<String, TableSchemaEntry>,
1712) -> HashSet<String> {
1713    let mut keys = HashSet::new();
1714    if let Some(from_clause) = &select.from {
1715        for expr in &from_clause.expressions {
1716            if let Some(key) = resolved_table_key_from_expr(expr, schema_map) {
1717                keys.insert(key);
1718            }
1719        }
1720    }
1721    keys
1722}
1723
/// Returns `true` when `kind` is one of the NATURAL, APPLY, ASOF or LATERAL
/// join variants listed below, and `false` for every other join kind.
///
/// NOTE(review): presumably these are the joins whose matching condition is
/// not written as an explicit `ON` equality — confirm against the caller.
fn is_natural_or_implied_join(kind: JoinKind) -> bool {
    matches!(
        kind,
        JoinKind::Natural
            | JoinKind::NaturalLeft
            | JoinKind::NaturalRight
            | JoinKind::NaturalFull
            | JoinKind::CrossApply
            | JoinKind::OuterApply
            | JoinKind::AsOf
            | JoinKind::AsOfLeft
            | JoinKind::AsOfRight
            | JoinKind::Lateral
            | JoinKind::LeftLateral
    )
}
1740
1741fn check_query_reference_quality(
1742    stmt: &Expression,
1743    schema_map: &HashMap<String, TableSchemaEntry>,
1744    resolver_schema: &MappingSchema,
1745    strict: bool,
1746    relationships: &[DeclaredRelationship],
1747) -> Vec<ValidationError> {
1748    let mut errors = Vec::new();
1749
1750    for node in stmt.dfs() {
1751        let Expression::Select(select) = node else {
1752            continue;
1753        };
1754
1755        let select_expr = Expression::Select(select.clone());
1756        let context = collect_type_check_context(&select_expr, schema_map);
1757        let scope = build_scope(&select_expr);
1758        let mut resolver = Resolver::new(&scope, resolver_schema, true);
1759
1760        if context.referenced_tables.len() > 1 {
1761            let using_columns: HashSet<String> = select
1762                .joins
1763                .iter()
1764                .flat_map(|join| join.using.iter().map(|id| lower(&id.name)))
1765                .collect();
1766
1767            let mut seen = HashSet::new();
1768            for column_expr in select_expr
1769                .find_all(|e| matches!(e, Expression::Column(Column { table: None, .. })))
1770            {
1771                let Expression::Column(column) = column_expr else {
1772                    continue;
1773                };
1774
1775                let col_name = lower(&column.name.name);
1776                if col_name.is_empty()
1777                    || using_columns.contains(&col_name)
1778                    || !seen.insert(col_name.clone())
1779                {
1780                    continue;
1781                }
1782
1783                if resolver.is_ambiguous(&col_name) {
1784                    let source_count = resolver.sources_for_column(&col_name).len();
1785                    errors.push(if strict {
1786                        ValidationError::error(
1787                            format!(
1788                                "Ambiguous unqualified column '{}' found in {} referenced tables",
1789                                col_name, source_count
1790                            ),
1791                            validation_codes::E_AMBIGUOUS_COLUMN_REFERENCE,
1792                        )
1793                    } else {
1794                        ValidationError::warning(
1795                            format!(
1796                                "Ambiguous unqualified column '{}' found in {} referenced tables",
1797                                col_name, source_count
1798                            ),
1799                            validation_codes::W_WEAK_REFERENCE_INTEGRITY,
1800                        )
1801                    });
1802                }
1803            }
1804        }
1805
1806        let mut cumulative_left_tables = select_from_table_keys(select, schema_map);
1807
1808        for join in &select.joins {
1809            let right_table_key = resolved_table_key_from_expr(&join.this, schema_map);
1810            let has_explicit_condition = join.on.is_some() || !join.using.is_empty();
1811            let cartesian_like_kind = matches!(
1812                join.kind,
1813                JoinKind::Cross
1814                    | JoinKind::Implicit
1815                    | JoinKind::Array
1816                    | JoinKind::LeftArray
1817                    | JoinKind::Paste
1818            );
1819
1820            if right_table_key.is_some()
1821                && (cartesian_like_kind
1822                    || (!has_explicit_condition && !is_natural_or_implied_join(join.kind)))
1823            {
1824                errors.push(ValidationError::warning(
1825                    "Potential cartesian join: JOIN without ON/USING condition",
1826                    validation_codes::W_CARTESIAN_JOIN,
1827                ));
1828            }
1829
1830            if let (Some(on_expr), Some(right_key)) = (&join.on, right_table_key.clone()) {
1831                if join.using.is_empty() {
1832                    let mut eq_pairs = Vec::new();
1833                    extract_join_equality_pairs(
1834                        on_expr,
1835                        schema_map,
1836                        &context,
1837                        &mut resolver,
1838                        &mut eq_pairs,
1839                    );
1840
1841                    let relevant_relationships: Vec<&DeclaredRelationship> = relationships
1842                        .iter()
1843                        .filter(|rel| {
1844                            cumulative_left_tables.contains(&rel.source_table)
1845                                && rel.target_table == right_key
1846                                || (cumulative_left_tables.contains(&rel.target_table)
1847                                    && rel.source_table == right_key)
1848                        })
1849                        .collect();
1850
1851                    if !relevant_relationships.is_empty() {
1852                        let uses_declared_fk = eq_pairs.iter().any(|((lt, lc), (rt, rc))| {
1853                            relevant_relationships
1854                                .iter()
1855                                .any(|rel| relationship_matches_pair(rel, lt, lc, rt, rc))
1856                        });
1857                        if !uses_declared_fk {
1858                            errors.push(ValidationError::warning(
1859                                "JOIN predicate does not use declared foreign-key relationship columns",
1860                                validation_codes::W_JOIN_NOT_USING_DECLARED_REFERENCE,
1861                            ));
1862                        }
1863                    }
1864                }
1865            }
1866
1867            if let Some(right_key) = right_table_key {
1868                cumulative_left_tables.insert(right_key);
1869            }
1870        }
1871    }
1872
1873    errors
1874}
1875
1876fn are_setop_compatible(left: TypeFamily, right: TypeFamily) -> bool {
1877    if left == TypeFamily::Unknown || right == TypeFamily::Unknown {
1878        return true;
1879    }
1880    if left == right {
1881        return true;
1882    }
1883    if left.is_numeric() && right.is_numeric() {
1884        return true;
1885    }
1886    if left.is_temporal() && right.is_temporal() {
1887        return true;
1888    }
1889    false
1890}
1891
1892fn merged_setop_family(left: TypeFamily, right: TypeFamily) -> TypeFamily {
1893    if left == TypeFamily::Unknown {
1894        return right;
1895    }
1896    if right == TypeFamily::Unknown {
1897        return left;
1898    }
1899    if left == right {
1900        return left;
1901    }
1902    if left.is_numeric() && right.is_numeric() {
1903        if left == TypeFamily::Numeric || right == TypeFamily::Numeric {
1904            return TypeFamily::Numeric;
1905        }
1906        return TypeFamily::Integer;
1907    }
1908    if left.is_temporal() && right.is_temporal() {
1909        if left == TypeFamily::Timestamp || right == TypeFamily::Timestamp {
1910            return TypeFamily::Timestamp;
1911        }
1912        if left == TypeFamily::Date || right == TypeFamily::Date {
1913            return TypeFamily::Date;
1914        }
1915        return TypeFamily::Time;
1916    }
1917    TypeFamily::Unknown
1918}
1919
1920fn are_assignment_compatible(target: TypeFamily, source: TypeFamily) -> bool {
1921    if target == TypeFamily::Unknown || source == TypeFamily::Unknown {
1922        return true;
1923    }
1924    if target == source {
1925        return true;
1926    }
1927
1928    match target {
1929        TypeFamily::Boolean => source == TypeFamily::Boolean,
1930        TypeFamily::Integer | TypeFamily::Numeric => source.is_numeric(),
1931        TypeFamily::Date | TypeFamily::Time | TypeFamily::Timestamp | TypeFamily::Interval => {
1932            source.is_temporal()
1933        }
1934        TypeFamily::String => true,
1935        TypeFamily::Binary => matches!(source, TypeFamily::Binary | TypeFamily::String),
1936        TypeFamily::Json => matches!(source, TypeFamily::Json | TypeFamily::String),
1937        TypeFamily::Uuid => matches!(source, TypeFamily::Uuid | TypeFamily::String),
1938        TypeFamily::Array => source == TypeFamily::Array,
1939        TypeFamily::Map => source == TypeFamily::Map,
1940        TypeFamily::Struct => source == TypeFamily::Struct,
1941        TypeFamily::Unknown => true,
1942    }
1943}
1944
1945fn projection_families(
1946    query_expr: &Expression,
1947    schema_map: &HashMap<String, TableSchemaEntry>,
1948) -> Option<Vec<TypeFamily>> {
1949    match query_expr {
1950        Expression::Select(select) => {
1951            if select
1952                .expressions
1953                .iter()
1954                .any(|e| matches!(e, Expression::Star(_) | Expression::BracedWildcard(_)))
1955            {
1956                return None;
1957            }
1958            let select_expr = Expression::Select(select.clone());
1959            let context = collect_type_check_context(&select_expr, schema_map);
1960            Some(
1961                select
1962                    .expressions
1963                    .iter()
1964                    .map(|e| infer_expression_type_family(e, schema_map, &context))
1965                    .collect(),
1966            )
1967        }
1968        Expression::Subquery(subquery) => projection_families(&subquery.this, schema_map),
1969        Expression::Union(union) => {
1970            let left = projection_families(&union.left, schema_map)?;
1971            let right = projection_families(&union.right, schema_map)?;
1972            if left.len() != right.len() {
1973                return None;
1974            }
1975            Some(
1976                left.into_iter()
1977                    .zip(right)
1978                    .map(|(l, r)| merged_setop_family(l, r))
1979                    .collect(),
1980            )
1981        }
1982        Expression::Intersect(intersect) => {
1983            let left = projection_families(&intersect.left, schema_map)?;
1984            let right = projection_families(&intersect.right, schema_map)?;
1985            if left.len() != right.len() {
1986                return None;
1987            }
1988            Some(
1989                left.into_iter()
1990                    .zip(right)
1991                    .map(|(l, r)| merged_setop_family(l, r))
1992                    .collect(),
1993            )
1994        }
1995        Expression::Except(except) => {
1996            let left = projection_families(&except.left, schema_map)?;
1997            let right = projection_families(&except.right, schema_map)?;
1998            if left.len() != right.len() {
1999                return None;
2000            }
2001            Some(
2002                left.into_iter()
2003                    .zip(right)
2004                    .map(|(l, r)| merged_setop_family(l, r))
2005                    .collect(),
2006            )
2007        }
2008        Expression::Values(values) => {
2009            let first_row = values.expressions.first()?;
2010            let context = TypeCheckContext::default();
2011            Some(
2012                first_row
2013                    .expressions
2014                    .iter()
2015                    .map(|e| infer_expression_type_family(e, schema_map, &context))
2016                    .collect(),
2017            )
2018        }
2019        _ => None,
2020    }
2021}
2022
2023fn check_set_operation_compatibility(
2024    op_name: &str,
2025    left_expr: &Expression,
2026    right_expr: &Expression,
2027    schema_map: &HashMap<String, TableSchemaEntry>,
2028    strict: bool,
2029    errors: &mut Vec<ValidationError>,
2030) {
2031    let Some(left_projection) = projection_families(left_expr, schema_map) else {
2032        return;
2033    };
2034    let Some(right_projection) = projection_families(right_expr, schema_map) else {
2035        return;
2036    };
2037
2038    if left_projection.len() != right_projection.len() {
2039        errors.push(type_issue(
2040            strict,
2041            validation_codes::E_SETOP_ARITY_MISMATCH,
2042            validation_codes::W_SETOP_IMPLICIT_COERCION,
2043            format!(
2044                "{} operands return different column counts: left {}, right {}",
2045                op_name,
2046                left_projection.len(),
2047                right_projection.len()
2048            ),
2049        ));
2050        return;
2051    }
2052
2053    for (idx, (left, right)) in left_projection
2054        .into_iter()
2055        .zip(right_projection)
2056        .enumerate()
2057    {
2058        if !are_setop_compatible(left, right) {
2059            errors.push(type_issue(
2060                strict,
2061                validation_codes::E_SETOP_TYPE_MISMATCH,
2062                validation_codes::W_SETOP_IMPLICIT_COERCION,
2063                format!(
2064                    "{} column {} has incompatible types: {} vs {}",
2065                    op_name,
2066                    idx + 1,
2067                    type_family_name(left),
2068                    type_family_name(right)
2069                ),
2070            ));
2071        }
2072    }
2073}
2074
2075fn check_insert_assignments(
2076    stmt: &Expression,
2077    insert: &Insert,
2078    schema_map: &HashMap<String, TableSchemaEntry>,
2079    strict: bool,
2080    errors: &mut Vec<ValidationError>,
2081) {
2082    let Some((target_table_key, table_schema)) =
2083        resolve_table_schema_entry(&insert.table, schema_map)
2084    else {
2085        return;
2086    };
2087
2088    let mut target_columns = Vec::new();
2089    if insert.columns.is_empty() {
2090        target_columns.extend(table_schema.column_order.iter().cloned());
2091    } else {
2092        for column in &insert.columns {
2093            let col_name = lower(&column.name);
2094            if table_schema.columns.contains_key(&col_name) {
2095                target_columns.push(col_name);
2096            } else {
2097                errors.push(if strict {
2098                    ValidationError::error(
2099                        format!(
2100                            "Unknown column '{}' in table '{}'",
2101                            column.name, target_table_key
2102                        ),
2103                        validation_codes::E_UNKNOWN_COLUMN,
2104                    )
2105                } else {
2106                    ValidationError::warning(
2107                        format!(
2108                            "Unknown column '{}' in table '{}'",
2109                            column.name, target_table_key
2110                        ),
2111                        validation_codes::E_UNKNOWN_COLUMN,
2112                    )
2113                });
2114            }
2115        }
2116    }
2117
2118    if target_columns.is_empty() {
2119        return;
2120    }
2121
2122    let context = collect_type_check_context(stmt, schema_map);
2123
2124    if !insert.default_values {
2125        for (row_idx, row) in insert.values.iter().enumerate() {
2126            if row.len() != target_columns.len() {
2127                errors.push(type_issue(
2128                    strict,
2129                    validation_codes::E_INVALID_ASSIGNMENT_TYPE,
2130                    validation_codes::W_IMPLICIT_CAST_ASSIGNMENT,
2131                    format!(
2132                        "INSERT row {} has {} values but target has {} columns",
2133                        row_idx + 1,
2134                        row.len(),
2135                        target_columns.len()
2136                    ),
2137                ));
2138                continue;
2139            }
2140
2141            for (value, target_column) in row.iter().zip(target_columns.iter()) {
2142                let Some(target_family) = table_schema.columns.get(target_column).copied() else {
2143                    continue;
2144                };
2145                let source_family = infer_expression_type_family(value, schema_map, &context);
2146                if !are_assignment_compatible(target_family, source_family) {
2147                    errors.push(type_issue(
2148                        strict,
2149                        validation_codes::E_INVALID_ASSIGNMENT_TYPE,
2150                        validation_codes::W_IMPLICIT_CAST_ASSIGNMENT,
2151                        format!(
2152                            "INSERT assignment type mismatch for '{}.{}': expected {}, found {}",
2153                            target_table_key,
2154                            target_column,
2155                            type_family_name(target_family),
2156                            type_family_name(source_family)
2157                        ),
2158                    ));
2159                }
2160            }
2161        }
2162    }
2163
2164    if let Some(query) = &insert.query {
2165        // DuckDB BY NAME maps source columns by name, not position.
2166        if insert.by_name {
2167            return;
2168        }
2169
2170        let Some(source_projection) = projection_families(query, schema_map) else {
2171            return;
2172        };
2173
2174        if source_projection.len() != target_columns.len() {
2175            errors.push(type_issue(
2176                strict,
2177                validation_codes::E_INVALID_ASSIGNMENT_TYPE,
2178                validation_codes::W_IMPLICIT_CAST_ASSIGNMENT,
2179                format!(
2180                    "INSERT source query has {} columns but target has {} columns",
2181                    source_projection.len(),
2182                    target_columns.len()
2183                ),
2184            ));
2185            return;
2186        }
2187
2188        for (source_family, target_column) in
2189            source_projection.into_iter().zip(target_columns.iter())
2190        {
2191            let Some(target_family) = table_schema.columns.get(target_column).copied() else {
2192                continue;
2193            };
2194            if !are_assignment_compatible(target_family, source_family) {
2195                errors.push(type_issue(
2196                    strict,
2197                    validation_codes::E_INVALID_ASSIGNMENT_TYPE,
2198                    validation_codes::W_IMPLICIT_CAST_ASSIGNMENT,
2199                    format!(
2200                        "INSERT assignment type mismatch for '{}.{}': expected {}, found {}",
2201                        target_table_key,
2202                        target_column,
2203                        type_family_name(target_family),
2204                        type_family_name(source_family)
2205                    ),
2206                ));
2207            }
2208        }
2209    }
2210}
2211
2212fn check_update_assignments(
2213    stmt: &Expression,
2214    update: &Update,
2215    schema_map: &HashMap<String, TableSchemaEntry>,
2216    strict: bool,
2217    errors: &mut Vec<ValidationError>,
2218) {
2219    let Some((target_table_key, table_schema)) =
2220        resolve_table_schema_entry(&update.table, schema_map)
2221    else {
2222        return;
2223    };
2224
2225    let context = collect_type_check_context(stmt, schema_map);
2226
2227    for (column, value) in &update.set {
2228        let col_name = lower(&column.name);
2229        let Some(target_family) = table_schema.columns.get(&col_name).copied() else {
2230            errors.push(if strict {
2231                ValidationError::error(
2232                    format!(
2233                        "Unknown column '{}' in table '{}'",
2234                        column.name, target_table_key
2235                    ),
2236                    validation_codes::E_UNKNOWN_COLUMN,
2237                )
2238            } else {
2239                ValidationError::warning(
2240                    format!(
2241                        "Unknown column '{}' in table '{}'",
2242                        column.name, target_table_key
2243                    ),
2244                    validation_codes::E_UNKNOWN_COLUMN,
2245                )
2246            });
2247            continue;
2248        };
2249
2250        let source_family = infer_expression_type_family(value, schema_map, &context);
2251        if !are_assignment_compatible(target_family, source_family) {
2252            errors.push(type_issue(
2253                strict,
2254                validation_codes::E_INVALID_ASSIGNMENT_TYPE,
2255                validation_codes::W_IMPLICIT_CAST_ASSIGNMENT,
2256                format!(
2257                    "UPDATE assignment type mismatch for '{}.{}': expected {}, found {}",
2258                    target_table_key,
2259                    col_name,
2260                    type_family_name(target_family),
2261                    type_family_name(source_family)
2262                ),
2263            ));
2264        }
2265    }
2266}
2267
2268fn check_types(
2269    stmt: &Expression,
2270    schema_map: &HashMap<String, TableSchemaEntry>,
2271    strict: bool,
2272) -> Vec<ValidationError> {
2273    let mut errors = Vec::new();
2274    let context = collect_type_check_context(stmt, schema_map);
2275
2276    for node in stmt.dfs() {
2277        match node {
2278            Expression::Insert(insert) => {
2279                check_insert_assignments(stmt, insert, schema_map, strict, &mut errors);
2280            }
2281            Expression::Update(update) => {
2282                check_update_assignments(stmt, update, schema_map, strict, &mut errors);
2283            }
2284            Expression::Union(union) => {
2285                check_set_operation_compatibility(
2286                    "UNION",
2287                    &union.left,
2288                    &union.right,
2289                    schema_map,
2290                    strict,
2291                    &mut errors,
2292                );
2293            }
2294            Expression::Intersect(intersect) => {
2295                check_set_operation_compatibility(
2296                    "INTERSECT",
2297                    &intersect.left,
2298                    &intersect.right,
2299                    schema_map,
2300                    strict,
2301                    &mut errors,
2302                );
2303            }
2304            Expression::Except(except) => {
2305                check_set_operation_compatibility(
2306                    "EXCEPT",
2307                    &except.left,
2308                    &except.right,
2309                    schema_map,
2310                    strict,
2311                    &mut errors,
2312                );
2313            }
2314            Expression::Select(select) => {
2315                if let Some(prewhere) = &select.prewhere {
2316                    let family = infer_expression_type_family(prewhere, schema_map, &context);
2317                    if family != TypeFamily::Unknown && family != TypeFamily::Boolean {
2318                        errors.push(type_issue(
2319                            strict,
2320                            validation_codes::E_INVALID_PREDICATE_TYPE,
2321                            validation_codes::W_PREDICATE_NULLABILITY,
2322                            format!(
2323                                "PREWHERE clause expects a boolean predicate, found {}",
2324                                type_family_name(family)
2325                            ),
2326                        ));
2327                    }
2328                }
2329
2330                if let Some(where_clause) = &select.where_clause {
2331                    let family =
2332                        infer_expression_type_family(&where_clause.this, schema_map, &context);
2333                    if family != TypeFamily::Unknown && family != TypeFamily::Boolean {
2334                        errors.push(type_issue(
2335                            strict,
2336                            validation_codes::E_INVALID_PREDICATE_TYPE,
2337                            validation_codes::W_PREDICATE_NULLABILITY,
2338                            format!(
2339                                "WHERE clause expects a boolean predicate, found {}",
2340                                type_family_name(family)
2341                            ),
2342                        ));
2343                    }
2344                }
2345
2346                if let Some(having_clause) = &select.having {
2347                    let family =
2348                        infer_expression_type_family(&having_clause.this, schema_map, &context);
2349                    if family != TypeFamily::Unknown && family != TypeFamily::Boolean {
2350                        errors.push(type_issue(
2351                            strict,
2352                            validation_codes::E_INVALID_PREDICATE_TYPE,
2353                            validation_codes::W_PREDICATE_NULLABILITY,
2354                            format!(
2355                                "HAVING clause expects a boolean predicate, found {}",
2356                                type_family_name(family)
2357                            ),
2358                        ));
2359                    }
2360                }
2361
2362                for join in &select.joins {
2363                    if let Some(on) = &join.on {
2364                        let family = infer_expression_type_family(on, schema_map, &context);
2365                        if family != TypeFamily::Unknown && family != TypeFamily::Boolean {
2366                            errors.push(type_issue(
2367                                strict,
2368                                validation_codes::E_INVALID_PREDICATE_TYPE,
2369                                validation_codes::W_PREDICATE_NULLABILITY,
2370                                format!(
2371                                    "JOIN ON expects a boolean predicate, found {}",
2372                                    type_family_name(family)
2373                                ),
2374                            ));
2375                        }
2376                    }
2377                    if let Some(match_condition) = &join.match_condition {
2378                        let family =
2379                            infer_expression_type_family(match_condition, schema_map, &context);
2380                        if family != TypeFamily::Unknown && family != TypeFamily::Boolean {
2381                            errors.push(type_issue(
2382                                strict,
2383                                validation_codes::E_INVALID_PREDICATE_TYPE,
2384                                validation_codes::W_PREDICATE_NULLABILITY,
2385                                format!(
2386                                    "JOIN MATCH_CONDITION expects a boolean predicate, found {}",
2387                                    type_family_name(family)
2388                                ),
2389                            ));
2390                        }
2391                    }
2392                }
2393            }
2394            Expression::Where(where_clause) => {
2395                let family = infer_expression_type_family(&where_clause.this, schema_map, &context);
2396                if family != TypeFamily::Unknown && family != TypeFamily::Boolean {
2397                    errors.push(type_issue(
2398                        strict,
2399                        validation_codes::E_INVALID_PREDICATE_TYPE,
2400                        validation_codes::W_PREDICATE_NULLABILITY,
2401                        format!(
2402                            "WHERE clause expects a boolean predicate, found {}",
2403                            type_family_name(family)
2404                        ),
2405                    ));
2406                }
2407            }
2408            Expression::Having(having_clause) => {
2409                let family =
2410                    infer_expression_type_family(&having_clause.this, schema_map, &context);
2411                if family != TypeFamily::Unknown && family != TypeFamily::Boolean {
2412                    errors.push(type_issue(
2413                        strict,
2414                        validation_codes::E_INVALID_PREDICATE_TYPE,
2415                        validation_codes::W_PREDICATE_NULLABILITY,
2416                        format!(
2417                            "HAVING clause expects a boolean predicate, found {}",
2418                            type_family_name(family)
2419                        ),
2420                    ));
2421                }
2422            }
2423            Expression::And(op) | Expression::Or(op) => {
2424                for (side, expr) in [("left", &op.left), ("right", &op.right)] {
2425                    let family = infer_expression_type_family(expr, schema_map, &context);
2426                    if family != TypeFamily::Unknown && family != TypeFamily::Boolean {
2427                        errors.push(type_issue(
2428                            strict,
2429                            validation_codes::E_INVALID_PREDICATE_TYPE,
2430                            validation_codes::W_PREDICATE_NULLABILITY,
2431                            format!(
2432                                "Logical {} operand expects boolean, found {}",
2433                                side,
2434                                type_family_name(family)
2435                            ),
2436                        ));
2437                    }
2438                }
2439            }
2440            Expression::Not(unary) => {
2441                let family = infer_expression_type_family(&unary.this, schema_map, &context);
2442                if family != TypeFamily::Unknown && family != TypeFamily::Boolean {
2443                    errors.push(type_issue(
2444                        strict,
2445                        validation_codes::E_INVALID_PREDICATE_TYPE,
2446                        validation_codes::W_PREDICATE_NULLABILITY,
2447                        format!("NOT expects boolean, found {}", type_family_name(family)),
2448                    ));
2449                }
2450            }
2451            Expression::Eq(op)
2452            | Expression::Neq(op)
2453            | Expression::Lt(op)
2454            | Expression::Lte(op)
2455            | Expression::Gt(op)
2456            | Expression::Gte(op) => {
2457                let left = infer_expression_type_family(&op.left, schema_map, &context);
2458                let right = infer_expression_type_family(&op.right, schema_map, &context);
2459                if !are_comparable(left, right) {
2460                    errors.push(type_issue(
2461                        strict,
2462                        validation_codes::E_INCOMPATIBLE_COMPARISON_TYPES,
2463                        validation_codes::W_IMPLICIT_CAST_COMPARISON,
2464                        format!(
2465                            "Incompatible comparison between {} and {}",
2466                            type_family_name(left),
2467                            type_family_name(right)
2468                        ),
2469                    ));
2470                }
2471            }
2472            Expression::Like(op) | Expression::ILike(op) => {
2473                let left = infer_expression_type_family(&op.left, schema_map, &context);
2474                let right = infer_expression_type_family(&op.right, schema_map, &context);
2475                if left != TypeFamily::Unknown
2476                    && right != TypeFamily::Unknown
2477                    && (!is_string_like(left) || !is_string_like(right))
2478                {
2479                    errors.push(type_issue(
2480                        strict,
2481                        validation_codes::E_INCOMPATIBLE_COMPARISON_TYPES,
2482                        validation_codes::W_IMPLICIT_CAST_COMPARISON,
2483                        format!(
2484                            "LIKE/ILIKE expects string operands, found {} and {}",
2485                            type_family_name(left),
2486                            type_family_name(right)
2487                        ),
2488                    ));
2489                }
2490            }
2491            Expression::Between(between) => {
2492                let this_family = infer_expression_type_family(&between.this, schema_map, &context);
2493                let low_family = infer_expression_type_family(&between.low, schema_map, &context);
2494                let high_family = infer_expression_type_family(&between.high, schema_map, &context);
2495
2496                if !are_comparable(this_family, low_family)
2497                    || !are_comparable(this_family, high_family)
2498                {
2499                    errors.push(type_issue(
2500                        strict,
2501                        validation_codes::E_INCOMPATIBLE_COMPARISON_TYPES,
2502                        validation_codes::W_IMPLICIT_CAST_COMPARISON,
2503                        format!(
2504                            "BETWEEN bounds are incompatible with {} (found {} and {})",
2505                            type_family_name(this_family),
2506                            type_family_name(low_family),
2507                            type_family_name(high_family)
2508                        ),
2509                    ));
2510                }
2511            }
2512            Expression::In(in_expr) => {
2513                let this_family = infer_expression_type_family(&in_expr.this, schema_map, &context);
2514                for value in &in_expr.expressions {
2515                    let item_family = infer_expression_type_family(value, schema_map, &context);
2516                    if !are_comparable(this_family, item_family) {
2517                        errors.push(type_issue(
2518                            strict,
2519                            validation_codes::E_INCOMPATIBLE_COMPARISON_TYPES,
2520                            validation_codes::W_IMPLICIT_CAST_COMPARISON,
2521                            format!(
2522                                "IN item type {} is incompatible with {}",
2523                                type_family_name(item_family),
2524                                type_family_name(this_family)
2525                            ),
2526                        ));
2527                        break;
2528                    }
2529                }
2530            }
2531            Expression::Add(op)
2532            | Expression::Sub(op)
2533            | Expression::Mul(op)
2534            | Expression::Div(op)
2535            | Expression::Mod(op) => {
2536                let left = infer_expression_type_family(&op.left, schema_map, &context);
2537                let right = infer_expression_type_family(&op.right, schema_map, &context);
2538
2539                if left == TypeFamily::Unknown || right == TypeFamily::Unknown {
2540                    continue;
2541                }
2542
2543                let temporal_ok = matches!(node, Expression::Add(_) | Expression::Sub(_))
2544                    && ((left.is_temporal() && right.is_numeric())
2545                        || (right.is_temporal() && left.is_numeric())
2546                        || (matches!(node, Expression::Sub(_))
2547                            && left.is_temporal()
2548                            && right.is_temporal()));
2549
2550                if !(left.is_numeric() && right.is_numeric()) && !temporal_ok {
2551                    errors.push(type_issue(
2552                        strict,
2553                        validation_codes::E_INVALID_ARITHMETIC_TYPE,
2554                        validation_codes::W_IMPLICIT_CAST_ARITHMETIC,
2555                        format!(
2556                            "Arithmetic operation expects numeric-compatible operands, found {} and {}",
2557                            type_family_name(left),
2558                            type_family_name(right)
2559                        ),
2560                    ));
2561                }
2562            }
2563            Expression::Function(function) => {
2564                check_generic_function(function, schema_map, &context, strict, &mut errors);
2565            }
2566            Expression::Upper(func)
2567            | Expression::Lower(func)
2568            | Expression::LTrim(func)
2569            | Expression::RTrim(func)
2570            | Expression::Reverse(func) => {
2571                let family = infer_expression_type_family(&func.this, schema_map, &context);
2572                check_function_argument(
2573                    &mut errors,
2574                    strict,
2575                    "string_function",
2576                    0,
2577                    family,
2578                    "a string argument",
2579                    is_string_like(family),
2580                );
2581            }
2582            Expression::Length(func) => {
2583                let family = infer_expression_type_family(&func.this, schema_map, &context);
2584                check_function_argument(
2585                    &mut errors,
2586                    strict,
2587                    "length",
2588                    0,
2589                    family,
2590                    "a string or binary argument",
2591                    is_string_or_binary(family),
2592                );
2593            }
2594            Expression::Trim(func) => {
2595                let this_family = infer_expression_type_family(&func.this, schema_map, &context);
2596                check_function_argument(
2597                    &mut errors,
2598                    strict,
2599                    "trim",
2600                    0,
2601                    this_family,
2602                    "a string argument",
2603                    is_string_like(this_family),
2604                );
2605                if let Some(chars) = &func.characters {
2606                    let chars_family = infer_expression_type_family(chars, schema_map, &context);
2607                    check_function_argument(
2608                        &mut errors,
2609                        strict,
2610                        "trim",
2611                        1,
2612                        chars_family,
2613                        "a string argument",
2614                        is_string_like(chars_family),
2615                    );
2616                }
2617            }
2618            Expression::Substring(func) => {
2619                let this_family = infer_expression_type_family(&func.this, schema_map, &context);
2620                check_function_argument(
2621                    &mut errors,
2622                    strict,
2623                    "substring",
2624                    0,
2625                    this_family,
2626                    "a string argument",
2627                    is_string_like(this_family),
2628                );
2629
2630                let start_family = infer_expression_type_family(&func.start, schema_map, &context);
2631                check_function_argument(
2632                    &mut errors,
2633                    strict,
2634                    "substring",
2635                    1,
2636                    start_family,
2637                    "a numeric argument",
2638                    start_family.is_numeric(),
2639                );
2640                if let Some(length) = &func.length {
2641                    let length_family = infer_expression_type_family(length, schema_map, &context);
2642                    check_function_argument(
2643                        &mut errors,
2644                        strict,
2645                        "substring",
2646                        2,
2647                        length_family,
2648                        "a numeric argument",
2649                        length_family.is_numeric(),
2650                    );
2651                }
2652            }
2653            Expression::Replace(func) => {
2654                for (arg, idx) in [
2655                    (&func.this, 0_usize),
2656                    (&func.old, 1_usize),
2657                    (&func.new, 2_usize),
2658                ] {
2659                    let family = infer_expression_type_family(arg, schema_map, &context);
2660                    check_function_argument(
2661                        &mut errors,
2662                        strict,
2663                        "replace",
2664                        idx,
2665                        family,
2666                        "a string argument",
2667                        is_string_like(family),
2668                    );
2669                }
2670            }
2671            Expression::Left(func) | Expression::Right(func) => {
2672                let this_family = infer_expression_type_family(&func.this, schema_map, &context);
2673                check_function_argument(
2674                    &mut errors,
2675                    strict,
2676                    "left_right",
2677                    0,
2678                    this_family,
2679                    "a string argument",
2680                    is_string_like(this_family),
2681                );
2682                let length_family =
2683                    infer_expression_type_family(&func.length, schema_map, &context);
2684                check_function_argument(
2685                    &mut errors,
2686                    strict,
2687                    "left_right",
2688                    1,
2689                    length_family,
2690                    "a numeric argument",
2691                    length_family.is_numeric(),
2692                );
2693            }
2694            Expression::Repeat(func) => {
2695                let this_family = infer_expression_type_family(&func.this, schema_map, &context);
2696                check_function_argument(
2697                    &mut errors,
2698                    strict,
2699                    "repeat",
2700                    0,
2701                    this_family,
2702                    "a string argument",
2703                    is_string_like(this_family),
2704                );
2705                let times_family = infer_expression_type_family(&func.times, schema_map, &context);
2706                check_function_argument(
2707                    &mut errors,
2708                    strict,
2709                    "repeat",
2710                    1,
2711                    times_family,
2712                    "a numeric argument",
2713                    times_family.is_numeric(),
2714                );
2715            }
2716            Expression::Lpad(func) | Expression::Rpad(func) => {
2717                let this_family = infer_expression_type_family(&func.this, schema_map, &context);
2718                check_function_argument(
2719                    &mut errors,
2720                    strict,
2721                    "pad",
2722                    0,
2723                    this_family,
2724                    "a string argument",
2725                    is_string_like(this_family),
2726                );
2727                let length_family =
2728                    infer_expression_type_family(&func.length, schema_map, &context);
2729                check_function_argument(
2730                    &mut errors,
2731                    strict,
2732                    "pad",
2733                    1,
2734                    length_family,
2735                    "a numeric argument",
2736                    length_family.is_numeric(),
2737                );
2738                if let Some(fill) = &func.fill {
2739                    let fill_family = infer_expression_type_family(fill, schema_map, &context);
2740                    check_function_argument(
2741                        &mut errors,
2742                        strict,
2743                        "pad",
2744                        2,
2745                        fill_family,
2746                        "a string argument",
2747                        is_string_like(fill_family),
2748                    );
2749                }
2750            }
2751            Expression::Abs(func)
2752            | Expression::Sqrt(func)
2753            | Expression::Cbrt(func)
2754            | Expression::Ln(func)
2755            | Expression::Exp(func) => {
2756                let family = infer_expression_type_family(&func.this, schema_map, &context);
2757                check_function_argument(
2758                    &mut errors,
2759                    strict,
2760                    "numeric_function",
2761                    0,
2762                    family,
2763                    "a numeric argument",
2764                    family.is_numeric(),
2765                );
2766            }
2767            Expression::Round(func) => {
2768                let this_family = infer_expression_type_family(&func.this, schema_map, &context);
2769                check_function_argument(
2770                    &mut errors,
2771                    strict,
2772                    "round",
2773                    0,
2774                    this_family,
2775                    "a numeric argument",
2776                    this_family.is_numeric(),
2777                );
2778                if let Some(decimals) = &func.decimals {
2779                    let decimals_family =
2780                        infer_expression_type_family(decimals, schema_map, &context);
2781                    check_function_argument(
2782                        &mut errors,
2783                        strict,
2784                        "round",
2785                        1,
2786                        decimals_family,
2787                        "a numeric argument",
2788                        decimals_family.is_numeric(),
2789                    );
2790                }
2791            }
2792            Expression::Floor(func) => {
2793                let this_family = infer_expression_type_family(&func.this, schema_map, &context);
2794                check_function_argument(
2795                    &mut errors,
2796                    strict,
2797                    "floor",
2798                    0,
2799                    this_family,
2800                    "a numeric argument",
2801                    this_family.is_numeric(),
2802                );
2803                if let Some(scale) = &func.scale {
2804                    let scale_family = infer_expression_type_family(scale, schema_map, &context);
2805                    check_function_argument(
2806                        &mut errors,
2807                        strict,
2808                        "floor",
2809                        1,
2810                        scale_family,
2811                        "a numeric argument",
2812                        scale_family.is_numeric(),
2813                    );
2814                }
2815            }
2816            Expression::Ceil(func) => {
2817                let this_family = infer_expression_type_family(&func.this, schema_map, &context);
2818                check_function_argument(
2819                    &mut errors,
2820                    strict,
2821                    "ceil",
2822                    0,
2823                    this_family,
2824                    "a numeric argument",
2825                    this_family.is_numeric(),
2826                );
2827                if let Some(decimals) = &func.decimals {
2828                    let decimals_family =
2829                        infer_expression_type_family(decimals, schema_map, &context);
2830                    check_function_argument(
2831                        &mut errors,
2832                        strict,
2833                        "ceil",
2834                        1,
2835                        decimals_family,
2836                        "a numeric argument",
2837                        decimals_family.is_numeric(),
2838                    );
2839                }
2840            }
2841            Expression::Power(func) => {
2842                let left_family = infer_expression_type_family(&func.this, schema_map, &context);
2843                check_function_argument(
2844                    &mut errors,
2845                    strict,
2846                    "power",
2847                    0,
2848                    left_family,
2849                    "a numeric argument",
2850                    left_family.is_numeric(),
2851                );
2852                let right_family =
2853                    infer_expression_type_family(&func.expression, schema_map, &context);
2854                check_function_argument(
2855                    &mut errors,
2856                    strict,
2857                    "power",
2858                    1,
2859                    right_family,
2860                    "a numeric argument",
2861                    right_family.is_numeric(),
2862                );
2863            }
2864            Expression::Log(func) => {
2865                let this_family = infer_expression_type_family(&func.this, schema_map, &context);
2866                check_function_argument(
2867                    &mut errors,
2868                    strict,
2869                    "log",
2870                    0,
2871                    this_family,
2872                    "a numeric argument",
2873                    this_family.is_numeric(),
2874                );
2875                if let Some(base) = &func.base {
2876                    let base_family = infer_expression_type_family(base, schema_map, &context);
2877                    check_function_argument(
2878                        &mut errors,
2879                        strict,
2880                        "log",
2881                        1,
2882                        base_family,
2883                        "a numeric argument",
2884                        base_family.is_numeric(),
2885                    );
2886                }
2887            }
2888            _ => {}
2889        }
2890    }
2891
2892    errors
2893}
2894
2895fn check_semantics(stmt: &Expression) -> Vec<ValidationError> {
2896    let mut errors = Vec::new();
2897
2898    let Expression::Select(select) = stmt else {
2899        return errors;
2900    };
2901    let select_expr = Expression::Select(select.clone());
2902
2903    // W001: SELECT * is discouraged
2904    if !select_expr
2905        .find_all(|e| matches!(e, Expression::Star(_)))
2906        .is_empty()
2907    {
2908        errors.push(ValidationError::warning(
2909            "SELECT * is discouraged; specify columns explicitly for better performance and maintainability",
2910            validation_codes::W_SELECT_STAR,
2911        ));
2912    }
2913
2914    // W002: aggregate + non-aggregate columns without GROUP BY
2915    let aggregate_count = get_aggregate_functions(&select_expr).len();
2916    if aggregate_count > 0 && select.group_by.is_none() {
2917        let has_non_aggregate_column = select.expressions.iter().any(|expr| {
2918            matches!(expr, Expression::Column(_) | Expression::Identifier(_))
2919                && get_aggregate_functions(expr).is_empty()
2920        });
2921
2922        if has_non_aggregate_column {
2923            errors.push(ValidationError::warning(
2924                "Mixing aggregate functions with non-aggregated columns without GROUP BY may cause errors in strict SQL mode",
2925                validation_codes::W_AGGREGATE_WITHOUT_GROUP_BY,
2926            ));
2927        }
2928    }
2929
2930    // W003: DISTINCT with ORDER BY
2931    if select.distinct && select.order_by.is_some() {
2932        errors.push(ValidationError::warning(
2933            "DISTINCT with ORDER BY: ensure ORDER BY columns are in SELECT list",
2934            validation_codes::W_DISTINCT_ORDER_BY,
2935        ));
2936    }
2937
2938    // W004: LIMIT without ORDER BY
2939    if select.limit.is_some() && select.order_by.is_none() {
2940        errors.push(ValidationError::warning(
2941            "LIMIT without ORDER BY produces non-deterministic results",
2942            validation_codes::W_LIMIT_WITHOUT_ORDER_BY,
2943        ));
2944    }
2945
2946    errors
2947}
2948
2949fn validate_statement_with_schema(
2950    stmt: &Expression,
2951    schema_map: &HashMap<String, TableSchemaEntry>,
2952    strict: bool,
2953) -> Vec<ValidationError> {
2954    let mut errors = Vec::new();
2955    let cte_aliases = collect_cte_aliases(stmt);
2956
2957    let mut referenced_tables: HashSet<String> = HashSet::new();
2958    let mut table_aliases: HashMap<String, String> = HashMap::new();
2959    let mut seen_tables: HashSet<String> = HashSet::new();
2960
2961    // Table validation (E200)
2962    for node in stmt.find_all(|e| matches!(e, Expression::Table(_))) {
2963        let Expression::Table(table) = node else {
2964            continue;
2965        };
2966
2967        if cte_aliases.contains(&lower(&table.name.name)) {
2968            continue;
2969        }
2970
2971        let resolved_key = table_ref_candidates(table)
2972            .into_iter()
2973            .find(|k| schema_map.contains_key(k));
2974        let table_key = resolved_key
2975            .clone()
2976            .unwrap_or_else(|| lower(&table_ref_display_name(table)));
2977
2978        // Keep alias map for all seen table refs, including unknown ones.
2979        if let Some(alias) = &table.alias {
2980            table_aliases.insert(lower(&alias.name), table_key.clone());
2981        }
2982
2983        if !seen_tables.insert(table_key.clone()) {
2984            continue;
2985        }
2986        referenced_tables.insert(table_key.clone());
2987
2988        if resolved_key.is_none() {
2989            let err = if strict {
2990                ValidationError::error(
2991                    format!("Unknown table '{}'", table_ref_display_name(table)),
2992                    validation_codes::E_UNKNOWN_TABLE,
2993                )
2994            } else {
2995                ValidationError::warning(
2996                    format!("Unknown table '{}'", table_ref_display_name(table)),
2997                    validation_codes::E_UNKNOWN_TABLE,
2998                )
2999            };
3000            errors.push(err);
3001        }
3002    }
3003
3004    // Column validation (E201)
3005    for node in stmt.find_all(|e| matches!(e, Expression::Column(_))) {
3006        let Expression::Column(column) = node else {
3007            continue;
3008        };
3009
3010        let col_name = lower(&column.name.name);
3011        if col_name.is_empty() {
3012            continue;
3013        }
3014
3015        let mut table_name = column.table.as_ref().map(|t| lower(&t.name));
3016        if let Some(name) = &table_name {
3017            if let Some(mapped) = table_aliases.get(name) {
3018                table_name = Some(mapped.clone());
3019            }
3020        }
3021
3022        if let Some(table_name) = table_name {
3023            if let Some(table_schema) = schema_map.get(&table_name) {
3024                if !table_schema.columns.contains_key(&col_name) {
3025                    let err = if strict {
3026                        ValidationError::error(
3027                            format!("Unknown column '{}' in table '{}'", col_name, table_name),
3028                            validation_codes::E_UNKNOWN_COLUMN,
3029                        )
3030                    } else {
3031                        ValidationError::warning(
3032                            format!("Unknown column '{}' in table '{}'", col_name, table_name),
3033                            validation_codes::E_UNKNOWN_COLUMN,
3034                        )
3035                    };
3036                    errors.push(err);
3037                }
3038            }
3039            continue;
3040        }
3041
3042        if referenced_tables.len() == 1 {
3043            if let Some(single_table) = referenced_tables.iter().next() {
3044                if let Some(table_schema) = schema_map.get(single_table) {
3045                    if !table_schema.columns.contains_key(&col_name) {
3046                        let err = if strict {
3047                            ValidationError::error(
3048                                format!(
3049                                    "Unknown column '{}' in table '{}'",
3050                                    col_name, single_table
3051                                ),
3052                                validation_codes::E_UNKNOWN_COLUMN,
3053                            )
3054                        } else {
3055                            ValidationError::warning(
3056                                format!(
3057                                    "Unknown column '{}' in table '{}'",
3058                                    col_name, single_table
3059                                ),
3060                                validation_codes::E_UNKNOWN_COLUMN,
3061                            )
3062                        };
3063                        errors.push(err);
3064                    }
3065                }
3066            }
3067        } else if referenced_tables.len() > 1 {
3068            let found = referenced_tables.iter().any(|table_name| {
3069                schema_map
3070                    .get(table_name)
3071                    .map(|s| s.columns.contains_key(&col_name))
3072                    .unwrap_or(false)
3073            });
3074
3075            if !found {
3076                let err = if strict {
3077                    ValidationError::error(
3078                        format!(
3079                            "Unknown column '{}' (not found in any referenced table)",
3080                            col_name
3081                        ),
3082                        validation_codes::E_UNKNOWN_COLUMN,
3083                    )
3084                } else {
3085                    ValidationError::warning(
3086                        format!(
3087                            "Unknown column '{}' (not found in any referenced table)",
3088                            col_name
3089                        ),
3090                        validation_codes::E_UNKNOWN_COLUMN,
3091                    )
3092                };
3093                errors.push(err);
3094            }
3095        } else if !schema_map.is_empty() {
3096            let found = schema_map
3097                .values()
3098                .any(|table| table.columns.contains_key(&col_name));
3099            if !found {
3100                let err = if strict {
3101                    ValidationError::error(
3102                        format!("Unknown column '{}'", col_name),
3103                        validation_codes::E_UNKNOWN_COLUMN,
3104                    )
3105                } else {
3106                    ValidationError::warning(
3107                        format!("Unknown column '{}'", col_name),
3108                        validation_codes::E_UNKNOWN_COLUMN,
3109                    )
3110                };
3111                errors.push(err);
3112            }
3113        }
3114    }
3115
3116    errors
3117}
3118
3119/// Validate SQL using syntax + schema-aware checks, with optional semantic warnings.
3120pub fn validate_with_schema(
3121    sql: &str,
3122    dialect: DialectType,
3123    schema: &ValidationSchema,
3124    options: &SchemaValidationOptions,
3125) -> ValidationResult {
3126    let strict = options.strict.unwrap_or(schema.strict.unwrap_or(true));
3127
3128    // Syntax validation first.
3129    let syntax_result = crate::validate_with_options(
3130        sql,
3131        dialect,
3132        &crate::ValidationOptions {
3133            strict_syntax: options.strict_syntax,
3134        },
3135    );
3136    if !syntax_result.valid {
3137        return syntax_result;
3138    }
3139
3140    let d = Dialect::get(dialect);
3141    let statements = match d.parse(sql) {
3142        Ok(exprs) => exprs,
3143        Err(e) => {
3144            return ValidationResult::with_errors(vec![ValidationError::error(
3145                e.to_string(),
3146                validation_codes::E_PARSE_OR_OPTIONS,
3147            )]);
3148        }
3149    };
3150
3151    let schema_map = build_schema_map(schema);
3152    let resolver_schema = build_resolver_schema(schema);
3153    let mut all_errors = syntax_result.errors;
3154    let declared_relationships = if options.check_references {
3155        build_declared_relationships(schema, &schema_map)
3156    } else {
3157        Vec::new()
3158    };
3159
3160    if options.check_references {
3161        all_errors.extend(check_reference_integrity(schema, &schema_map, strict));
3162    }
3163
3164    for stmt in &statements {
3165        if options.semantic {
3166            all_errors.extend(check_semantics(stmt));
3167        }
3168        all_errors.extend(validate_statement_with_schema(stmt, &schema_map, strict));
3169        if options.check_types {
3170            all_errors.extend(check_types(stmt, &schema_map, strict));
3171        }
3172        if options.check_references {
3173            all_errors.extend(check_query_reference_quality(
3174                stmt,
3175                &schema_map,
3176                &resolver_schema,
3177                strict,
3178                &declared_relationships,
3179            ));
3180        }
3181    }
3182
3183    ValidationResult::with_errors(all_errors)
3184}
3185
3186#[cfg(test)]
3187mod tests;