Skip to main content

polyglot_sql/
validation.rs

1//! Schema-aware and semantic SQL validation.
2//!
3//! This module extends syntax validation with:
4//! - schema checks (unknown tables/columns)
5//! - optional semantic warnings (SELECT *, LIMIT without ORDER BY, etc.)
6
7use crate::ast_transforms::get_aggregate_functions;
8use crate::dialects::{Dialect, DialectType};
9use crate::error::{ValidationError, ValidationResult};
10use crate::expressions::{
11    Column, DataType, Expression, Function, Insert, JoinKind, TableRef, Update,
12};
13use crate::function_catalog::FunctionCatalog;
14#[cfg(any(
15    feature = "function-catalog-clickhouse",
16    feature = "function-catalog-duckdb",
17    feature = "function-catalog-all-dialects"
18))]
19use crate::function_catalog::{
20    FunctionNameCase as CoreFunctionNameCase, FunctionSignature as CoreFunctionSignature,
21    HashMapFunctionCatalog,
22};
23use crate::function_registry::canonical_typed_function_name_upper;
24use crate::optimizer::annotate_types;
25use crate::resolver::Resolver;
26use crate::schema::{MappingSchema, Schema as SqlSchema, SchemaError, SchemaResult, TABLE_PARTS};
27use crate::scope::{build_scope, walk_in_scope};
28use crate::traversal::ExpressionWalk;
29use serde::{Deserialize, Serialize};
30use std::collections::{HashMap, HashSet};
31use std::sync::Arc;
32
33#[cfg(any(
34    feature = "function-catalog-clickhouse",
35    feature = "function-catalog-duckdb",
36    feature = "function-catalog-all-dialects"
37))]
38use std::sync::LazyLock;
39
/// Column definition used for schema-aware validation.
///
/// Only `name` is mandatory in a serialized payload; every other field is
/// marked `#[serde(default)]` and may be omitted.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SchemaColumn {
    /// Column name.
    pub name: String,
    /// Column data type as a free-form string, (de)serialized under the key
    /// `type`. An omitted value becomes the empty string. `build_schema_map`
    /// and `build_resolver_schema` classify it with `canonical_type_family`.
    #[serde(default, rename = "type")]
    pub data_type: String,
    /// Whether the column allows NULL values (`None` when not specified).
    #[serde(default)]
    pub nullable: Option<bool>,
    /// Whether this column is part of a primary key.
    /// (De)serialized under the key `primaryKey`.
    #[serde(default, rename = "primaryKey")]
    pub primary_key: bool,
    /// Whether this column has a uniqueness constraint.
    #[serde(default)]
    pub unique: bool,
    /// Optional column-level foreign key reference.
    #[serde(default)]
    pub references: Option<SchemaColumnReference>,
}
61
/// Column-level foreign key reference metadata.
///
/// Names exactly one target column; `table` and `column` are required in a
/// serialized payload, `schema` may be omitted.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SchemaColumnReference {
    /// Referenced table name.
    pub table: String,
    /// Referenced column name.
    pub column: String,
    /// Optional schema/namespace of referenced table (`None` when omitted).
    #[serde(default)]
    pub schema: Option<String>,
}
73
/// Table-level foreign key reference metadata.
///
/// Unlike a column-level reference, this form lists source columns and target
/// columns as vectors, so it can describe multi-column keys.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SchemaForeignKey {
    /// Optional FK name (`None` when omitted).
    #[serde(default)]
    pub name: Option<String>,
    /// Source columns in the current table.
    pub columns: Vec<String>,
    /// Referenced target table + columns.
    pub references: SchemaTableReference,
}
85
/// Target of a table-level foreign key.
///
/// `table` and `columns` are required in a serialized payload; `schema` may be
/// omitted.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SchemaTableReference {
    /// Referenced table name.
    pub table: String,
    /// Referenced target columns.
    pub columns: Vec<String>,
    /// Optional schema/namespace of referenced table (`None` when omitted).
    #[serde(default)]
    pub schema: Option<String>,
}
97
/// Table definition used for schema-aware validation.
///
/// `name` and `columns` are required in a serialized payload; the remaining
/// fields default to `None` / empty collections when omitted. The key fields
/// use camelCase names on the wire (`primaryKey`, `uniqueKeys`,
/// `foreignKeys`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SchemaTable {
    /// Table name.
    pub name: String,
    /// Optional schema/namespace name. When present, the table is additionally
    /// registered under the lowercase `schema.name` key by the schema builders.
    #[serde(default)]
    pub schema: Option<String>,
    /// Column definitions.
    pub columns: Vec<SchemaColumn>,
    /// Optional aliases that should resolve to this table.
    #[serde(default)]
    pub aliases: Vec<String>,
    /// Optional primary key column list (serialized as `primaryKey`).
    #[serde(default, rename = "primaryKey")]
    pub primary_key: Vec<String>,
    /// Optional unique key groups (serialized as `uniqueKeys`); each inner
    /// vector is one group of column names.
    #[serde(default, rename = "uniqueKeys")]
    pub unique_keys: Vec<Vec<String>>,
    /// Optional table-level foreign keys (serialized as `foreignKeys`).
    #[serde(default, rename = "foreignKeys")]
    pub foreign_keys: Vec<SchemaForeignKey>,
}
121
/// Schema payload used for schema-aware validation.
///
/// This is the top-level object consumed by `build_schema_map` and
/// `build_resolver_schema`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ValidationSchema {
    /// Known tables.
    pub tables: Vec<SchemaTable>,
    /// Default strict mode for unknown identifiers (`None` when omitted).
    #[serde(default)]
    pub strict: Option<bool>,
}
131
/// Options for schema-aware validation.
///
/// Every field has a serde default, and the struct derives `Default`, so an
/// empty payload yields all checks disabled and no overrides. `Debug` is not
/// derived for this type.
#[derive(Clone, Serialize, Deserialize, Default)]
pub struct SchemaValidationOptions {
    /// Enables type compatibility checks for expressions, DML assignments, and set operations.
    #[serde(default)]
    pub check_types: bool,
    /// Enables FK/reference integrity checks and query-level reference quality checks.
    #[serde(default)]
    pub check_references: bool,
    /// If true/false, overrides schema.strict.
    #[serde(default)]
    pub strict: Option<bool>,
    /// Enables semantic warnings (W001..W004).
    #[serde(default)]
    pub semantic: bool,
    /// Enables strict syntax checks (e.g. rejects trailing commas before clause boundaries).
    #[serde(default)]
    pub strict_syntax: bool,
    /// Optional external function catalog plugin for dialect-specific function validation.
    /// Skipped by serde in both directions: it is never serialized and is
    /// always `None` after deserialization, so it must be set in code.
    #[serde(skip, default)]
    pub function_catalog: Option<Arc<dyn FunctionCatalog>>,
}
154
/// Translates the external catalog crate's name-case flag into the core
/// crate's equivalent, variant for variant.
#[cfg(any(
    feature = "function-catalog-clickhouse",
    feature = "function-catalog-duckdb",
    feature = "function-catalog-all-dialects"
))]
fn to_core_name_case(
    case: polyglot_sql_function_catalogs::FunctionNameCase,
) -> CoreFunctionNameCase {
    // Local shorthand for the external enum; keeps the arms on one line each.
    type ExternalCase = polyglot_sql_function_catalogs::FunctionNameCase;

    match case {
        ExternalCase::Sensitive => CoreFunctionNameCase::Sensitive,
        ExternalCase::Insensitive => CoreFunctionNameCase::Insensitive,
    }
}
172
/// Converts external catalog signatures into core signatures, copying the
/// `min_arity` / `max_arity` bounds and preserving order.
#[cfg(any(
    feature = "function-catalog-clickhouse",
    feature = "function-catalog-duckdb",
    feature = "function-catalog-all-dialects"
))]
fn to_core_signatures(
    signatures: Vec<polyglot_sql_function_catalogs::FunctionSignature>,
) -> Vec<CoreFunctionSignature> {
    let mut converted = Vec::with_capacity(signatures.len());
    for external in signatures {
        converted.push(CoreFunctionSignature {
            min_arity: external.min_arity,
            max_arity: external.max_arity,
        });
    }
    converted
}
189
/// Adapter that lets the external catalog crate register its data into a
/// core `HashMapFunctionCatalog`.
#[cfg(any(
    feature = "function-catalog-clickhouse",
    feature = "function-catalog-duckdb",
    feature = "function-catalog-all-dialects"
))]
struct EmbeddedCatalogSink<'a> {
    /// Catalog being populated; borrowed mutably for the sink's lifetime.
    catalog: &'a mut HashMapFunctionCatalog,
    /// Memoized result of parsing each dialect name into a `DialectType`.
    /// `None` records a name that failed to parse, so it is not re-parsed.
    dialect_cache: HashMap<&'static str, Option<DialectType>>,
}
199
#[cfg(any(
    feature = "function-catalog-clickhouse",
    feature = "function-catalog-duckdb",
    feature = "function-catalog-all-dialects"
))]
impl<'a> EmbeddedCatalogSink<'a> {
    /// Maps a dialect name to its `DialectType`, parsing each distinct name at
    /// most once. Returns `None` (and remembers that outcome) when the name
    /// does not parse.
    fn resolve_dialect(&mut self, dialect: &'static str) -> Option<DialectType> {
        *self
            .dialect_cache
            .entry(dialect)
            .or_insert_with(|| dialect.parse::<DialectType>().ok())
    }
}
215
#[cfg(any(
    feature = "function-catalog-clickhouse",
    feature = "function-catalog-duckdb",
    feature = "function-catalog-all-dialects"
))]
impl<'a> polyglot_sql_function_catalogs::CatalogSink for EmbeddedCatalogSink<'a> {
    /// Forwards a dialect-wide name-case rule to the core catalog.
    /// Dialect names that do not resolve are silently skipped.
    fn set_dialect_name_case(
        &mut self,
        dialect: &'static str,
        name_case: polyglot_sql_function_catalogs::FunctionNameCase,
    ) {
        let Some(core_dialect) = self.resolve_dialect(dialect) else {
            return;
        };
        let core_case = to_core_name_case(name_case);
        self.catalog.set_dialect_name_case(core_dialect, core_case);
    }

    /// Forwards a per-function name-case rule to the core catalog.
    /// Dialect names that do not resolve are silently skipped.
    fn set_function_name_case(
        &mut self,
        dialect: &'static str,
        function_name: &str,
        name_case: polyglot_sql_function_catalogs::FunctionNameCase,
    ) {
        let Some(core_dialect) = self.resolve_dialect(dialect) else {
            return;
        };
        let core_case = to_core_name_case(name_case);
        self.catalog
            .set_function_name_case(core_dialect, function_name, core_case);
    }

    /// Registers a function and its signatures with the core catalog.
    /// Dialect names that do not resolve are silently skipped.
    fn register(
        &mut self,
        dialect: &'static str,
        function_name: &str,
        signatures: Vec<polyglot_sql_function_catalogs::FunctionSignature>,
    ) {
        let Some(core_dialect) = self.resolve_dialect(dialect) else {
            return;
        };
        let core_signatures = to_core_signatures(signatures);
        self.catalog
            .register(core_dialect, function_name, core_signatures);
    }
}
260
/// Returns a shared handle to the embedded function catalog.
///
/// The catalog is built on first use by letting the external catalog crate
/// register everything it has enabled; later calls only clone the `Arc`.
#[cfg(any(
    feature = "function-catalog-clickhouse",
    feature = "function-catalog-duckdb",
    feature = "function-catalog-all-dialects"
))]
fn embedded_function_catalog_arc() -> Arc<dyn FunctionCatalog> {
    static EMBEDDED: LazyLock<Arc<HashMapFunctionCatalog>> = LazyLock::new(|| {
        let mut catalog = HashMapFunctionCatalog::default();
        {
            // The sink borrows `catalog` mutably; scope it so the borrow ends
            // before the catalog is moved into the `Arc`.
            let mut sink = EmbeddedCatalogSink {
                dialect_cache: HashMap::new(),
                catalog: &mut catalog,
            };
            polyglot_sql_function_catalogs::register_enabled_catalogs(&mut sink);
        }
        Arc::new(catalog)
    });

    let shared: Arc<HashMapFunctionCatalog> = Arc::clone(&EMBEDDED);
    shared
}
279
/// Default function catalog when at least one of the `function-catalog-*`
/// features listed below is enabled: always `Some`, wrapping the shared
/// embedded catalog.
#[cfg(any(
    feature = "function-catalog-clickhouse",
    feature = "function-catalog-duckdb",
    feature = "function-catalog-all-dialects"
))]
fn default_embedded_function_catalog() -> Option<Arc<dyn FunctionCatalog>> {
    Some(embedded_function_catalog_arc())
}

/// Default function catalog when none of the `function-catalog-*` features
/// listed below is enabled: there is no embedded catalog, so this is `None`.
#[cfg(not(any(
    feature = "function-catalog-clickhouse",
    feature = "function-catalog-duckdb",
    feature = "function-catalog-all-dialects"
)))]
fn default_embedded_function_catalog() -> Option<Arc<dyn FunctionCatalog>> {
    None
}
297
/// Validation error/warning codes used by schema-aware validation.
///
/// Naming convention visible in the values: constants prefixed `E_` hold
/// codes of the form `"E###"`, constants prefixed `W_` hold `"W###"`.
/// Helpers such as `type_issue` and `reference_issue` attach an `E` code to
/// errors and a `W` code to warnings.
pub mod validation_codes {
    // Existing schema and semantic checks.
    // E000 and E200-E203: parse/option, identifier and function checks.
    pub const E_PARSE_OR_OPTIONS: &str = "E000";
    pub const E_UNKNOWN_TABLE: &str = "E200";
    pub const E_UNKNOWN_COLUMN: &str = "E201";
    pub const E_UNKNOWN_FUNCTION: &str = "E202";
    pub const E_INVALID_FUNCTION_ARITY: &str = "E203";

    // W001-W004: semantic warnings.
    pub const W_SELECT_STAR: &str = "W001";
    pub const W_AGGREGATE_WITHOUT_GROUP_BY: &str = "W002";
    pub const W_DISTINCT_ORDER_BY: &str = "W003";
    pub const W_LIMIT_WITHOUT_ORDER_BY: &str = "W004";

    // Phase 2 (type checks): E210-E219, W210-W219.
    pub const E_TYPE_MISMATCH: &str = "E210";
    pub const E_INVALID_PREDICATE_TYPE: &str = "E211";
    pub const E_INVALID_ARITHMETIC_TYPE: &str = "E212";
    pub const E_INVALID_FUNCTION_ARGUMENT_TYPE: &str = "E213";
    pub const E_INVALID_ASSIGNMENT_TYPE: &str = "E214";
    pub const E_SETOP_TYPE_MISMATCH: &str = "E215";
    pub const E_SETOP_ARITY_MISMATCH: &str = "E216";
    pub const E_INCOMPATIBLE_COMPARISON_TYPES: &str = "E217";
    pub const E_INVALID_CAST: &str = "E218";
    pub const E_UNKNOWN_INFERRED_TYPE: &str = "E219";

    pub const W_IMPLICIT_CAST_COMPARISON: &str = "W210";
    pub const W_IMPLICIT_CAST_ARITHMETIC: &str = "W211";
    pub const W_IMPLICIT_CAST_ASSIGNMENT: &str = "W212";
    pub const W_LOSSY_CAST: &str = "W213";
    pub const W_SETOP_IMPLICIT_COERCION: &str = "W214";
    pub const W_PREDICATE_NULLABILITY: &str = "W215";
    pub const W_FUNCTION_ARGUMENT_COERCION: &str = "W216";
    pub const W_AGGREGATE_TYPE_COERCION: &str = "W217";
    pub const W_POSSIBLE_OVERFLOW: &str = "W218";
    pub const W_POSSIBLE_TRUNCATION: &str = "W219";

    // Phase 2 (reference checks): E220-E229, W220-W229.
    // Only E220-E224 and W220-W222 are assigned so far.
    pub const E_INVALID_FOREIGN_KEY_REFERENCE: &str = "E220";
    pub const E_AMBIGUOUS_COLUMN_REFERENCE: &str = "E221";
    pub const E_UNRESOLVED_REFERENCE: &str = "E222";
    pub const E_CTE_COLUMN_COUNT_MISMATCH: &str = "E223";
    pub const E_MISSING_REFERENCE_TARGET: &str = "E224";

    pub const W_CARTESIAN_JOIN: &str = "W220";
    pub const W_JOIN_NOT_USING_DECLARED_REFERENCE: &str = "W221";
    pub const W_WEAK_REFERENCE_INTEGRITY: &str = "W222";
}
346
/// Canonical type family used by schema/type checks.
///
/// Serialized in snake_case (`Unknown` <-> `"unknown"`, and so on). The
/// per-variant notes list examples of the type names that
/// `canonical_type_family` maps to each family.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum TypeFamily {
    /// Empty or unrecognized type name.
    Unknown,
    /// `bool`, `boolean`.
    Boolean,
    /// Integer types such as `int`, `bigint`, `serial`, `uint32`.
    Integer,
    /// Non-integer numbers such as `decimal`, `float`, `double precision`.
    Numeric,
    /// Character types such as `varchar`, `text`, `string`.
    String,
    /// Byte-string types such as `binary`, `blob`, `bytea`.
    Binary,
    /// `date`.
    Date,
    /// `time`.
    Time,
    /// `timestamp`, `datetime` and their time-zone variants.
    Timestamp,
    /// `interval`.
    Interval,
    /// `json`, `jsonb`, `variant`.
    Json,
    /// `uuid`, `uniqueidentifier`.
    Uuid,
    /// `array`/`list` in any spelling, including the `T[]` suffix form.
    Array,
    /// `map`.
    Map,
    /// `struct`, `row`, `record`, `object`.
    Struct,
}
367
368impl TypeFamily {
369    pub fn is_numeric(self) -> bool {
370        matches!(self, TypeFamily::Integer | TypeFamily::Numeric)
371    }
372
373    pub fn is_temporal(self) -> bool {
374        matches!(
375            self,
376            TypeFamily::Date | TypeFamily::Time | TypeFamily::Timestamp | TypeFamily::Interval
377        )
378    }
379}
380
/// Per-table lookup data derived from a `SchemaTable` by `build_schema_map`.
#[derive(Debug, Clone)]
struct TableSchemaEntry {
    /// Lowercased column name -> canonical type family.
    columns: HashMap<String, TypeFamily>,
    /// Lowercased column names in the order they were declared.
    column_order: Vec<String>,
}
386
/// Lowercases an identifier for case-insensitive lookups.
///
/// Delegates to `str::to_lowercase`, so non-ASCII letters are folded too.
fn lower(s: &str) -> String {
    str::to_lowercase(s)
}
390
/// Splits a parameterized type such as `nullable(int)` into its base name
/// and the text between the first `(` and the final `)`, both trimmed.
///
/// Returns `None` when there is no `(`, when nothing follows the `(`, or
/// when the string does not end with `)`.
fn split_type_args(data_type: &str) -> Option<(&str, &str)> {
    let (head, rest) = data_type.split_once('(')?;
    // `rest` must close the argument list; an empty `rest` has no `)` to strip.
    let args = rest.strip_suffix(')')?;
    Some((head.trim(), args.trim()))
}
400
401/// Canonicalize a schema type string into a stable `TypeFamily`.
402pub fn canonical_type_family(data_type: &str) -> TypeFamily {
403    let trimmed = data_type
404        .trim()
405        .trim_matches(|c| c == '"' || c == '\'' || c == '`');
406    if trimmed.is_empty() {
407        return TypeFamily::Unknown;
408    }
409
410    // Normalize whitespace and lowercase for matching.
411    let normalized = trimmed
412        .split_whitespace()
413        .collect::<Vec<_>>()
414        .join(" ")
415        .to_lowercase();
416
417    // Strip common wrappers first.
418    if let Some((base, inner)) = split_type_args(&normalized) {
419        match base {
420            "nullable" | "lowcardinality" => return canonical_type_family(inner),
421            "array" | "list" => return TypeFamily::Array,
422            "map" => return TypeFamily::Map,
423            "struct" | "row" | "record" => return TypeFamily::Struct,
424            _ => {}
425        }
426    }
427
428    if normalized.starts_with("array<") || normalized.starts_with("list<") {
429        return TypeFamily::Array;
430    }
431    if normalized.starts_with("map<") {
432        return TypeFamily::Map;
433    }
434    if normalized.starts_with("struct<")
435        || normalized.starts_with("row<")
436        || normalized.starts_with("record<")
437        || normalized.starts_with("object<")
438    {
439        return TypeFamily::Struct;
440    }
441
442    if normalized.ends_with("[]") {
443        return TypeFamily::Array;
444    }
445
446    // Remove parameter list if present, e.g. VARCHAR(255), DECIMAL(10,2).
447    let mut base = normalized
448        .split('(')
449        .next()
450        .unwrap_or("")
451        .trim()
452        .to_string();
453    if base.is_empty() {
454        return TypeFamily::Unknown;
455    }
456
457    base = base.strip_prefix("unsigned ").unwrap_or(&base).to_string();
458    base = base.strip_suffix(" unsigned").unwrap_or(&base).to_string();
459
460    match base.as_str() {
461        "bool" | "boolean" => TypeFamily::Boolean,
462        "tinyint" | "smallint" | "int2" | "int" | "integer" | "int4" | "int8" | "bigint"
463        | "serial" | "smallserial" | "bigserial" | "utinyint" | "usmallint" | "uinteger"
464        | "ubigint" | "uint8" | "uint16" | "uint32" | "uint64" | "int16" | "int32" | "int64" => {
465            TypeFamily::Integer
466        }
467        "numeric" | "decimal" | "dec" | "number" | "float" | "float4" | "float8" | "real"
468        | "double" | "double precision" | "bfloat16" | "float16" | "float32" | "float64" => {
469            TypeFamily::Numeric
470        }
471        "char" | "character" | "varchar" | "character varying" | "nchar" | "nvarchar" | "text"
472        | "string" | "clob" => TypeFamily::String,
473        "binary" | "varbinary" | "blob" | "bytea" | "bytes" => TypeFamily::Binary,
474        "date" => TypeFamily::Date,
475        "time" => TypeFamily::Time,
476        "timestamp"
477        | "timestamptz"
478        | "datetime"
479        | "datetime2"
480        | "smalldatetime"
481        | "timestamp with time zone"
482        | "timestamp without time zone" => TypeFamily::Timestamp,
483        "interval" => TypeFamily::Interval,
484        "json" | "jsonb" | "variant" => TypeFamily::Json,
485        "uuid" | "uniqueidentifier" => TypeFamily::Uuid,
486        "array" | "list" => TypeFamily::Array,
487        "map" => TypeFamily::Map,
488        "struct" | "row" | "record" | "object" => TypeFamily::Struct,
489        _ => TypeFamily::Unknown,
490    }
491}
492
493fn build_schema_map(schema: &ValidationSchema) -> HashMap<String, TableSchemaEntry> {
494    let mut map = HashMap::new();
495
496    for table in &schema.tables {
497        let column_order: Vec<String> = table.columns.iter().map(|c| lower(&c.name)).collect();
498        let columns: HashMap<String, TypeFamily> = table
499            .columns
500            .iter()
501            .map(|c| (lower(&c.name), canonical_type_family(&c.data_type)))
502            .collect();
503        let entry = TableSchemaEntry {
504            columns,
505            column_order,
506        };
507
508        let simple_name = lower(&table.name);
509        map.insert(simple_name, entry.clone());
510
511        if let Some(table_schema) = &table.schema {
512            map.insert(
513                format!("{}.{}", lower(table_schema), lower(&table.name)),
514                entry.clone(),
515            );
516        }
517
518        for alias in &table.aliases {
519            map.insert(lower(alias), entry.clone());
520        }
521    }
522
523    map
524}
525
526fn type_family_to_data_type(family: TypeFamily) -> DataType {
527    match family {
528        TypeFamily::Unknown => DataType::Unknown,
529        TypeFamily::Boolean => DataType::Boolean,
530        TypeFamily::Integer => DataType::Int {
531            length: None,
532            integer_spelling: false,
533        },
534        TypeFamily::Numeric => DataType::Double {
535            precision: None,
536            scale: None,
537        },
538        TypeFamily::String => DataType::VarChar {
539            length: None,
540            parenthesized_length: false,
541        },
542        TypeFamily::Binary => DataType::VarBinary { length: None },
543        TypeFamily::Date => DataType::Date,
544        TypeFamily::Time => DataType::Time {
545            precision: None,
546            timezone: false,
547        },
548        TypeFamily::Timestamp => DataType::Timestamp {
549            precision: None,
550            timezone: false,
551        },
552        TypeFamily::Interval => DataType::Interval {
553            unit: None,
554            to: None,
555        },
556        TypeFamily::Json => DataType::Json,
557        TypeFamily::Uuid => DataType::Uuid,
558        TypeFamily::Array => DataType::Array {
559            element_type: Box::new(DataType::Unknown),
560            dimension: None,
561        },
562        TypeFamily::Map => DataType::Map {
563            key_type: Box::new(DataType::Unknown),
564            value_type: Box::new(DataType::Unknown),
565        },
566        TypeFamily::Struct => DataType::Struct {
567            fields: Vec::new(),
568            nested: false,
569        },
570    }
571}
572
573fn build_resolver_schema(schema: &ValidationSchema) -> MappingSchema {
574    let mut mapping = MappingSchema::new();
575
576    for table in &schema.tables {
577        let columns: Vec<(String, DataType)> = table
578            .columns
579            .iter()
580            .map(|column| {
581                (
582                    lower(&column.name),
583                    type_family_to_data_type(canonical_type_family(&column.data_type)),
584                )
585            })
586            .collect();
587
588        let mut table_names = Vec::new();
589        table_names.push(lower(&table.name));
590        if let Some(table_schema) = &table.schema {
591            table_names.push(format!("{}.{}", lower(table_schema), lower(&table.name)));
592        }
593        for alias in &table.aliases {
594            table_names.push(lower(alias));
595        }
596
597        let mut dedup = HashSet::new();
598        for table_name in table_names {
599            if dedup.insert(table_name.clone()) {
600                let _ = mapping.add_table(&table_name, &columns, None);
601            }
602        }
603    }
604
605    mapping
606}
607
608fn collect_cte_aliases(expr: &Expression) -> HashSet<String> {
609    let mut aliases = HashSet::new();
610
611    for node in expr.dfs() {
612        match node {
613            Expression::Select(select) => {
614                if let Some(with) = &select.with {
615                    for cte in &with.ctes {
616                        aliases.insert(lower(&cte.alias.name));
617                    }
618                }
619            }
620            Expression::Insert(insert) => {
621                if let Some(with) = &insert.with {
622                    for cte in &with.ctes {
623                        aliases.insert(lower(&cte.alias.name));
624                    }
625                }
626            }
627            Expression::Update(update) => {
628                if let Some(with) = &update.with {
629                    for cte in &with.ctes {
630                        aliases.insert(lower(&cte.alias.name));
631                    }
632                }
633            }
634            Expression::Delete(delete) => {
635                if let Some(with) = &delete.with {
636                    for cte in &with.ctes {
637                        aliases.insert(lower(&cte.alias.name));
638                    }
639                }
640            }
641            Expression::Union(union) => {
642                if let Some(with) = &union.with {
643                    for cte in &with.ctes {
644                        aliases.insert(lower(&cte.alias.name));
645                    }
646                }
647            }
648            Expression::Intersect(intersect) => {
649                if let Some(with) = &intersect.with {
650                    for cte in &with.ctes {
651                        aliases.insert(lower(&cte.alias.name));
652                    }
653                }
654            }
655            Expression::Except(except) => {
656                if let Some(with) = &except.with {
657                    for cte in &with.ctes {
658                        aliases.insert(lower(&cte.alias.name));
659                    }
660                }
661            }
662            Expression::Merge(merge) => {
663                if let Some(with_) = &merge.with_ {
664                    if let Expression::With(with_clause) = with_.as_ref() {
665                        for cte in &with_clause.ctes {
666                            aliases.insert(lower(&cte.alias.name));
667                        }
668                    }
669                }
670            }
671            _ => {}
672        }
673    }
674
675    aliases
676}
677
678fn table_ref_candidates(table: &TableRef) -> Vec<String> {
679    let name = lower(&table.name.name);
680    let schema = table.schema.as_ref().map(|s| lower(&s.name));
681    let catalog = table.catalog.as_ref().map(|c| lower(&c.name));
682
683    let mut candidates = Vec::new();
684    if let (Some(catalog), Some(schema)) = (&catalog, &schema) {
685        candidates.push(format!("{}.{}.{}", catalog, schema, name));
686    }
687    if let Some(schema) = &schema {
688        candidates.push(format!("{}.{}", schema, name));
689    }
690    candidates.push(name);
691    candidates
692}
693
694fn table_ref_display_name(table: &TableRef) -> String {
695    let mut parts = Vec::new();
696    if let Some(catalog) = &table.catalog {
697        parts.push(catalog.name.clone());
698    }
699    if let Some(schema) = &table.schema {
700        parts.push(schema.name.clone());
701    }
702    parts.push(table.name.name.clone());
703    parts.join(".")
704}
705
/// Tables visible to the type checker for one statement, as gathered by
/// `collect_type_check_context`.
#[derive(Debug, Default, Clone)]
struct TypeCheckContext {
    /// Schema-map keys of every referenced table that resolved in the schema.
    referenced_tables: HashSet<String>,
    /// Lowercase table name or alias -> schema-map key of the table it
    /// refers to.
    table_aliases: HashMap<String, String>,
}
711
712fn type_family_name(family: TypeFamily) -> &'static str {
713    match family {
714        TypeFamily::Unknown => "unknown",
715        TypeFamily::Boolean => "boolean",
716        TypeFamily::Integer => "integer",
717        TypeFamily::Numeric => "numeric",
718        TypeFamily::String => "string",
719        TypeFamily::Binary => "binary",
720        TypeFamily::Date => "date",
721        TypeFamily::Time => "time",
722        TypeFamily::Timestamp => "timestamp",
723        TypeFamily::Interval => "interval",
724        TypeFamily::Json => "json",
725        TypeFamily::Uuid => "uuid",
726        TypeFamily::Array => "array",
727        TypeFamily::Map => "map",
728        TypeFamily::Struct => "struct",
729    }
730}
731
732fn is_string_like(family: TypeFamily) -> bool {
733    matches!(family, TypeFamily::String)
734}
735
736fn is_string_or_binary(family: TypeFamily) -> bool {
737    matches!(family, TypeFamily::String | TypeFamily::Binary)
738}
739
740fn type_issue(
741    strict: bool,
742    error_code: &str,
743    warning_code: &str,
744    message: impl Into<String>,
745) -> ValidationError {
746    if strict {
747        ValidationError::error(message.into(), error_code)
748    } else {
749        ValidationError::warning(message.into(), warning_code)
750    }
751}
752
753fn data_type_family(data_type: &DataType) -> TypeFamily {
754    match data_type {
755        DataType::Boolean => TypeFamily::Boolean,
756        DataType::TinyInt { .. }
757        | DataType::SmallInt { .. }
758        | DataType::Int { .. }
759        | DataType::BigInt { .. } => TypeFamily::Integer,
760        DataType::Float { .. } | DataType::Double { .. } | DataType::Decimal { .. } => {
761            TypeFamily::Numeric
762        }
763        DataType::Char { .. }
764        | DataType::VarChar { .. }
765        | DataType::String { .. }
766        | DataType::Text
767        | DataType::TextWithLength { .. }
768        | DataType::CharacterSet { .. } => TypeFamily::String,
769        DataType::Binary { .. } | DataType::VarBinary { .. } | DataType::Blob => TypeFamily::Binary,
770        DataType::Date => TypeFamily::Date,
771        DataType::Time { .. } => TypeFamily::Time,
772        DataType::Timestamp { .. } => TypeFamily::Timestamp,
773        DataType::Interval { .. } => TypeFamily::Interval,
774        DataType::Json | DataType::JsonB => TypeFamily::Json,
775        DataType::Uuid => TypeFamily::Uuid,
776        DataType::Array { .. } | DataType::List { .. } => TypeFamily::Array,
777        DataType::Map { .. } => TypeFamily::Map,
778        DataType::Struct { .. } | DataType::Object { .. } | DataType::Union { .. } => {
779            TypeFamily::Struct
780        }
781        DataType::Nullable { inner } => data_type_family(inner),
782        DataType::Custom { name } => canonical_type_family(name),
783        DataType::Unknown => TypeFamily::Unknown,
784        DataType::Bit { .. } | DataType::VarBit { .. } => TypeFamily::Binary,
785        DataType::Enum { .. } | DataType::Set { .. } => TypeFamily::String,
786        DataType::Vector { .. } => TypeFamily::Array,
787        DataType::Geometry { .. } | DataType::Geography { .. } => TypeFamily::Struct,
788    }
789}
790
791fn collect_type_check_context(
792    stmt: &Expression,
793    schema_map: &HashMap<String, TableSchemaEntry>,
794) -> TypeCheckContext {
795    fn add_table_to_context(
796        table: &TableRef,
797        schema_map: &HashMap<String, TableSchemaEntry>,
798        context: &mut TypeCheckContext,
799    ) {
800        let resolved_key = table_ref_candidates(table)
801            .into_iter()
802            .find(|k| schema_map.contains_key(k));
803
804        let Some(table_key) = resolved_key else {
805            return;
806        };
807
808        context.referenced_tables.insert(table_key.clone());
809        context
810            .table_aliases
811            .insert(lower(&table.name.name), table_key.clone());
812        if let Some(alias) = &table.alias {
813            context
814                .table_aliases
815                .insert(lower(&alias.name), table_key.clone());
816        }
817    }
818
819    let mut context = TypeCheckContext::default();
820    let cte_aliases = collect_cte_aliases(stmt);
821
822    for node in stmt.find_all(|e| matches!(e, Expression::Table(_))) {
823        let Expression::Table(table) = node else {
824            continue;
825        };
826
827        if cte_aliases.contains(&lower(&table.name.name)) {
828            continue;
829        }
830
831        add_table_to_context(table, schema_map, &mut context);
832    }
833
834    // Seed DML target tables explicitly because they are struct fields and may
835    // not appear as standalone Expression::Table nodes in traversal output.
836    match stmt {
837        Expression::Insert(insert) => {
838            add_table_to_context(&insert.table, schema_map, &mut context);
839        }
840        Expression::Update(update) => {
841            add_table_to_context(&update.table, schema_map, &mut context);
842            for table in &update.extra_tables {
843                add_table_to_context(table, schema_map, &mut context);
844            }
845        }
846        Expression::Delete(delete) => {
847            add_table_to_context(&delete.table, schema_map, &mut context);
848            for table in &delete.using {
849                add_table_to_context(table, schema_map, &mut context);
850            }
851            for table in &delete.tables {
852                add_table_to_context(table, schema_map, &mut context);
853            }
854        }
855        _ => {}
856    }
857
858    context
859}
860
861fn resolve_table_schema_entry<'a>(
862    table: &TableRef,
863    schema_map: &'a HashMap<String, TableSchemaEntry>,
864) -> Option<(String, &'a TableSchemaEntry)> {
865    let key = table_ref_candidates(table)
866        .into_iter()
867        .find(|k| schema_map.contains_key(k))?;
868    let entry = schema_map.get(&key)?;
869    Some((key, entry))
870}
871
872fn reference_issue(strict: bool, message: impl Into<String>) -> ValidationError {
873    if strict {
874        ValidationError::error(
875            message.into(),
876            validation_codes::E_INVALID_FOREIGN_KEY_REFERENCE,
877        )
878    } else {
879        ValidationError::warning(message.into(), validation_codes::W_WEAK_REFERENCE_INTEGRITY)
880    }
881}
882
883fn reference_table_candidates(
884    table_name: &str,
885    explicit_schema: Option<&str>,
886    source_schema: Option<&str>,
887) -> Vec<String> {
888    let mut candidates = Vec::new();
889    let raw = lower(table_name);
890
891    if let Some(schema) = explicit_schema {
892        candidates.push(format!("{}.{}", lower(schema), raw));
893    }
894
895    if raw.contains('.') {
896        candidates.push(raw.clone());
897        if let Some(last) = raw.rsplit('.').next() {
898            candidates.push(last.to_string());
899        }
900    } else {
901        if let Some(schema) = source_schema {
902            candidates.push(format!("{}.{}", lower(schema), raw));
903        }
904        candidates.push(raw);
905    }
906
907    let mut dedup = HashSet::new();
908    candidates
909        .into_iter()
910        .filter(|c| dedup.insert(c.clone()))
911        .collect()
912}
913
914fn resolve_reference_table_key(
915    table_name: &str,
916    explicit_schema: Option<&str>,
917    source_schema: Option<&str>,
918    schema_map: &HashMap<String, TableSchemaEntry>,
919) -> Option<String> {
920    reference_table_candidates(table_name, explicit_schema, source_schema)
921        .into_iter()
922        .find(|candidate| schema_map.contains_key(candidate))
923}
924
925fn key_types_compatible(source: TypeFamily, target: TypeFamily) -> bool {
926    if source == TypeFamily::Unknown || target == TypeFamily::Unknown {
927        return true;
928    }
929    if source == target {
930        return true;
931    }
932    if source.is_numeric() && target.is_numeric() {
933        return true;
934    }
935    if source.is_temporal() && target.is_temporal() {
936        return true;
937    }
938    false
939}
940
941fn table_key_hints(table: &SchemaTable) -> HashSet<String> {
942    let mut hints = HashSet::new();
943    for column in &table.columns {
944        if column.primary_key || column.unique {
945            hints.insert(lower(&column.name));
946        }
947    }
948    for key_col in &table.primary_key {
949        hints.insert(lower(key_col));
950    }
951    for group in &table.unique_keys {
952        if group.len() == 1 {
953            if let Some(col) = group.first() {
954                hints.insert(lower(col));
955            }
956        }
957    }
958    hints
959}
960
961fn check_reference_integrity(
962    schema: &ValidationSchema,
963    schema_map: &HashMap<String, TableSchemaEntry>,
964    strict: bool,
965) -> Vec<ValidationError> {
966    let mut errors = Vec::new();
967
968    let mut key_hints_lookup: HashMap<String, HashSet<String>> = HashMap::new();
969    for table in &schema.tables {
970        let simple = lower(&table.name);
971        key_hints_lookup.insert(simple, table_key_hints(table));
972        if let Some(schema_name) = &table.schema {
973            let qualified = format!("{}.{}", lower(schema_name), lower(&table.name));
974            key_hints_lookup.insert(qualified, table_key_hints(table));
975        }
976    }
977
978    for table in &schema.tables {
979        let source_table_display = if let Some(schema_name) = &table.schema {
980            format!("{}.{}", schema_name, table.name)
981        } else {
982            table.name.clone()
983        };
984        let source_schema = table.schema.as_deref();
985        let source_columns: HashMap<String, TypeFamily> = table
986            .columns
987            .iter()
988            .map(|col| (lower(&col.name), canonical_type_family(&col.data_type)))
989            .collect();
990
991        for source_col in &table.columns {
992            let Some(reference) = &source_col.references else {
993                continue;
994            };
995            let source_type = canonical_type_family(&source_col.data_type);
996
997            let Some(target_key) = resolve_reference_table_key(
998                &reference.table,
999                reference.schema.as_deref(),
1000                source_schema,
1001                schema_map,
1002            ) else {
1003                errors.push(reference_issue(
1004                    strict,
1005                    format!(
1006                        "Foreign key reference '{}.{}' points to unknown table '{}'",
1007                        source_table_display, source_col.name, reference.table
1008                    ),
1009                ));
1010                continue;
1011            };
1012
1013            let target_column = lower(&reference.column);
1014            let Some(target_entry) = schema_map.get(&target_key) else {
1015                errors.push(reference_issue(
1016                    strict,
1017                    format!(
1018                        "Foreign key reference '{}.{}' points to unknown table '{}'",
1019                        source_table_display, source_col.name, reference.table
1020                    ),
1021                ));
1022                continue;
1023            };
1024
1025            let Some(target_type) = target_entry.columns.get(&target_column).copied() else {
1026                errors.push(reference_issue(
1027                    strict,
1028                    format!(
1029                        "Foreign key reference '{}.{}' points to unknown column '{}.{}'",
1030                        source_table_display, source_col.name, target_key, reference.column
1031                    ),
1032                ));
1033                continue;
1034            };
1035
1036            if !key_types_compatible(source_type, target_type) {
1037                errors.push(reference_issue(
1038                    strict,
1039                    format!(
1040                        "Foreign key type mismatch for '{}.{}' -> '{}.{}': {} vs {}",
1041                        source_table_display,
1042                        source_col.name,
1043                        target_key,
1044                        reference.column,
1045                        type_family_name(source_type),
1046                        type_family_name(target_type)
1047                    ),
1048                ));
1049            }
1050
1051            if let Some(target_key_hints) = key_hints_lookup.get(&target_key) {
1052                if !target_key_hints.contains(&target_column) {
1053                    errors.push(ValidationError::warning(
1054                        format!(
1055                            "Referenced column '{}.{}' is not marked as primary/unique key",
1056                            target_key, reference.column
1057                        ),
1058                        validation_codes::W_WEAK_REFERENCE_INTEGRITY,
1059                    ));
1060                }
1061            }
1062        }
1063
1064        for foreign_key in &table.foreign_keys {
1065            if foreign_key.columns.is_empty() || foreign_key.references.columns.is_empty() {
1066                errors.push(reference_issue(
1067                    strict,
1068                    format!(
1069                        "Table-level foreign key on '{}' has empty source or target column list",
1070                        source_table_display
1071                    ),
1072                ));
1073                continue;
1074            }
1075            if foreign_key.columns.len() != foreign_key.references.columns.len() {
1076                errors.push(reference_issue(
1077                    strict,
1078                    format!(
1079                        "Table-level foreign key on '{}' has {} source columns but {} target columns",
1080                        source_table_display,
1081                        foreign_key.columns.len(),
1082                        foreign_key.references.columns.len()
1083                    ),
1084                ));
1085                continue;
1086            }
1087
1088            let Some(target_key) = resolve_reference_table_key(
1089                &foreign_key.references.table,
1090                foreign_key.references.schema.as_deref(),
1091                source_schema,
1092                schema_map,
1093            ) else {
1094                errors.push(reference_issue(
1095                    strict,
1096                    format!(
1097                        "Table-level foreign key on '{}' points to unknown table '{}'",
1098                        source_table_display, foreign_key.references.table
1099                    ),
1100                ));
1101                continue;
1102            };
1103
1104            let Some(target_entry) = schema_map.get(&target_key) else {
1105                errors.push(reference_issue(
1106                    strict,
1107                    format!(
1108                        "Table-level foreign key on '{}' points to unknown table '{}'",
1109                        source_table_display, foreign_key.references.table
1110                    ),
1111                ));
1112                continue;
1113            };
1114
1115            for (source_col, target_col) in foreign_key
1116                .columns
1117                .iter()
1118                .zip(foreign_key.references.columns.iter())
1119            {
1120                let source_col_name = lower(source_col);
1121                let target_col_name = lower(target_col);
1122
1123                let Some(source_type) = source_columns.get(&source_col_name).copied() else {
1124                    errors.push(reference_issue(
1125                        strict,
1126                        format!(
1127                            "Table-level foreign key on '{}' references unknown source column '{}'",
1128                            source_table_display, source_col
1129                        ),
1130                    ));
1131                    continue;
1132                };
1133
1134                let Some(target_type) = target_entry.columns.get(&target_col_name).copied() else {
1135                    errors.push(reference_issue(
1136                        strict,
1137                        format!(
1138                            "Table-level foreign key on '{}' references unknown target column '{}.{}'",
1139                            source_table_display, target_key, target_col
1140                        ),
1141                    ));
1142                    continue;
1143                };
1144
1145                if !key_types_compatible(source_type, target_type) {
1146                    errors.push(reference_issue(
1147                        strict,
1148                        format!(
1149                            "Table-level foreign key type mismatch '{}.{}' -> '{}.{}': {} vs {}",
1150                            source_table_display,
1151                            source_col,
1152                            target_key,
1153                            target_col,
1154                            type_family_name(source_type),
1155                            type_family_name(target_type)
1156                        ),
1157                    ));
1158                }
1159
1160                if let Some(target_key_hints) = key_hints_lookup.get(&target_key) {
1161                    if !target_key_hints.contains(&target_col_name) {
1162                        errors.push(ValidationError::warning(
1163                            format!(
1164                                "Referenced column '{}.{}' is not marked as primary/unique key",
1165                                target_key, target_col
1166                            ),
1167                            validation_codes::W_WEAK_REFERENCE_INTEGRITY,
1168                        ));
1169                    }
1170                }
1171            }
1172        }
1173    }
1174
1175    errors
1176}
1177
1178fn resolve_unqualified_column_type(
1179    column_name: &str,
1180    schema_map: &HashMap<String, TableSchemaEntry>,
1181    context: &TypeCheckContext,
1182) -> TypeFamily {
1183    let candidate_tables: Vec<&String> = if !context.referenced_tables.is_empty() {
1184        context.referenced_tables.iter().collect()
1185    } else {
1186        schema_map.keys().collect()
1187    };
1188
1189    let mut families = HashSet::new();
1190    for table_name in candidate_tables {
1191        if let Some(table_schema) = schema_map.get(table_name) {
1192            if let Some(family) = table_schema.columns.get(column_name) {
1193                families.insert(*family);
1194            }
1195        }
1196    }
1197
1198    if families.len() == 1 {
1199        *families.iter().next().unwrap_or(&TypeFamily::Unknown)
1200    } else {
1201        TypeFamily::Unknown
1202    }
1203}
1204
1205fn resolve_column_type(
1206    column: &Column,
1207    schema_map: &HashMap<String, TableSchemaEntry>,
1208    context: &TypeCheckContext,
1209) -> TypeFamily {
1210    let column_name = lower(&column.name.name);
1211    if column_name.is_empty() {
1212        return TypeFamily::Unknown;
1213    }
1214
1215    if let Some(table) = &column.table {
1216        let mut table_key = lower(&table.name);
1217        if let Some(mapped) = context.table_aliases.get(&table_key) {
1218            table_key = mapped.clone();
1219        }
1220
1221        return schema_map
1222            .get(&table_key)
1223            .and_then(|t| t.columns.get(&column_name))
1224            .copied()
1225            .unwrap_or(TypeFamily::Unknown);
1226    }
1227
1228    resolve_unqualified_column_type(&column_name, schema_map, context)
1229}
1230
1231struct TypeInferenceSchema<'a> {
1232    schema_map: &'a HashMap<String, TableSchemaEntry>,
1233    context: &'a TypeCheckContext,
1234}
1235
1236impl TypeInferenceSchema<'_> {
1237    fn resolve_table_key(&self, table: &str) -> Option<String> {
1238        let mut table_key = lower(table);
1239        if let Some(mapped) = self.context.table_aliases.get(&table_key) {
1240            table_key = mapped.clone();
1241        }
1242        if self.schema_map.contains_key(&table_key) {
1243            Some(table_key)
1244        } else {
1245            None
1246        }
1247    }
1248}
1249
1250impl SqlSchema for TypeInferenceSchema<'_> {
1251    fn dialect(&self) -> Option<DialectType> {
1252        None
1253    }
1254
1255    fn add_table(
1256        &mut self,
1257        _table: &str,
1258        _columns: &[(String, DataType)],
1259        _dialect: Option<DialectType>,
1260    ) -> SchemaResult<()> {
1261        Err(SchemaError::InvalidStructure(
1262            "Type inference schema is read-only".to_string(),
1263        ))
1264    }
1265
1266    fn column_names(&self, table: &str) -> SchemaResult<Vec<String>> {
1267        let table_key = self
1268            .resolve_table_key(table)
1269            .ok_or_else(|| SchemaError::TableNotFound(table.to_string()))?;
1270        let entry = self
1271            .schema_map
1272            .get(&table_key)
1273            .ok_or_else(|| SchemaError::TableNotFound(table.to_string()))?;
1274        Ok(entry.column_order.clone())
1275    }
1276
1277    fn get_column_type(&self, table: &str, column: &str) -> SchemaResult<DataType> {
1278        let col_name = lower(column);
1279        if table.is_empty() {
1280            let family = resolve_unqualified_column_type(&col_name, self.schema_map, self.context);
1281            return if family == TypeFamily::Unknown {
1282                Err(SchemaError::ColumnNotFound {
1283                    table: "<unqualified>".to_string(),
1284                    column: column.to_string(),
1285                })
1286            } else {
1287                Ok(type_family_to_data_type(family))
1288            };
1289        }
1290
1291        let table_key = self
1292            .resolve_table_key(table)
1293            .ok_or_else(|| SchemaError::TableNotFound(table.to_string()))?;
1294        let entry = self
1295            .schema_map
1296            .get(&table_key)
1297            .ok_or_else(|| SchemaError::TableNotFound(table.to_string()))?;
1298        let family =
1299            entry
1300                .columns
1301                .get(&col_name)
1302                .copied()
1303                .ok_or_else(|| SchemaError::ColumnNotFound {
1304                    table: table.to_string(),
1305                    column: column.to_string(),
1306                })?;
1307        Ok(type_family_to_data_type(family))
1308    }
1309
1310    fn has_column(&self, table: &str, column: &str) -> bool {
1311        self.get_column_type(table, column).is_ok()
1312    }
1313
1314    fn supported_table_args(&self) -> &[&str] {
1315        TABLE_PARTS
1316    }
1317
1318    fn is_empty(&self) -> bool {
1319        self.schema_map.is_empty()
1320    }
1321
1322    fn depth(&self) -> usize {
1323        1
1324    }
1325}
1326
1327fn infer_expression_type_family(
1328    expr: &Expression,
1329    schema_map: &HashMap<String, TableSchemaEntry>,
1330    context: &TypeCheckContext,
1331) -> TypeFamily {
1332    let inference_schema = TypeInferenceSchema {
1333        schema_map,
1334        context,
1335    };
1336    if let Some(data_type) = annotate_types(expr, Some(&inference_schema), None) {
1337        let family = data_type_family(&data_type);
1338        if family != TypeFamily::Unknown {
1339            return family;
1340        }
1341    }
1342
1343    infer_expression_type_family_fallback(expr, schema_map, context)
1344}
1345
1346fn infer_expression_type_family_fallback(
1347    expr: &Expression,
1348    schema_map: &HashMap<String, TableSchemaEntry>,
1349    context: &TypeCheckContext,
1350) -> TypeFamily {
1351    match expr {
1352        Expression::Literal(literal) => match literal {
1353            crate::expressions::Literal::Number(value) => {
1354                if value.contains('.') || value.contains('e') || value.contains('E') {
1355                    TypeFamily::Numeric
1356                } else {
1357                    TypeFamily::Integer
1358                }
1359            }
1360            crate::expressions::Literal::HexNumber(_) => TypeFamily::Integer,
1361            crate::expressions::Literal::Date(_) => TypeFamily::Date,
1362            crate::expressions::Literal::Time(_) => TypeFamily::Time,
1363            crate::expressions::Literal::Timestamp(_)
1364            | crate::expressions::Literal::Datetime(_) => TypeFamily::Timestamp,
1365            crate::expressions::Literal::HexString(_)
1366            | crate::expressions::Literal::BitString(_)
1367            | crate::expressions::Literal::ByteString(_) => TypeFamily::Binary,
1368            _ => TypeFamily::String,
1369        },
1370        Expression::Boolean(_) => TypeFamily::Boolean,
1371        Expression::Null(_) => TypeFamily::Unknown,
1372        Expression::Column(column) => resolve_column_type(column, schema_map, context),
1373        Expression::Cast(cast) | Expression::TryCast(cast) | Expression::SafeCast(cast) => {
1374            data_type_family(&cast.to)
1375        }
1376        Expression::Alias(alias) => {
1377            infer_expression_type_family_fallback(&alias.this, schema_map, context)
1378        }
1379        Expression::Neg(unary) => {
1380            infer_expression_type_family_fallback(&unary.this, schema_map, context)
1381        }
1382        Expression::Add(op) | Expression::Sub(op) | Expression::Mul(op) => {
1383            let left = infer_expression_type_family_fallback(&op.left, schema_map, context);
1384            let right = infer_expression_type_family_fallback(&op.right, schema_map, context);
1385            if left == TypeFamily::Unknown || right == TypeFamily::Unknown {
1386                TypeFamily::Unknown
1387            } else if left == TypeFamily::Integer && right == TypeFamily::Integer {
1388                TypeFamily::Integer
1389            } else if left.is_numeric() && right.is_numeric() {
1390                TypeFamily::Numeric
1391            } else if left.is_temporal() || right.is_temporal() {
1392                left
1393            } else {
1394                TypeFamily::Unknown
1395            }
1396        }
1397        Expression::Div(_) | Expression::Mod(_) => TypeFamily::Numeric,
1398        Expression::Concat(_) => TypeFamily::String,
1399        Expression::Eq(_)
1400        | Expression::Neq(_)
1401        | Expression::Lt(_)
1402        | Expression::Lte(_)
1403        | Expression::Gt(_)
1404        | Expression::Gte(_)
1405        | Expression::Like(_)
1406        | Expression::ILike(_)
1407        | Expression::And(_)
1408        | Expression::Or(_)
1409        | Expression::Not(_)
1410        | Expression::Between(_)
1411        | Expression::In(_)
1412        | Expression::IsNull(_)
1413        | Expression::IsTrue(_)
1414        | Expression::IsFalse(_)
1415        | Expression::Is(_) => TypeFamily::Boolean,
1416        Expression::Length(_) => TypeFamily::Integer,
1417        Expression::Upper(_)
1418        | Expression::Lower(_)
1419        | Expression::Trim(_)
1420        | Expression::LTrim(_)
1421        | Expression::RTrim(_)
1422        | Expression::Replace(_)
1423        | Expression::Substring(_)
1424        | Expression::Left(_)
1425        | Expression::Right(_)
1426        | Expression::Repeat(_)
1427        | Expression::Lpad(_)
1428        | Expression::Rpad(_)
1429        | Expression::ConcatWs(_) => TypeFamily::String,
1430        Expression::Abs(_)
1431        | Expression::Round(_)
1432        | Expression::Floor(_)
1433        | Expression::Ceil(_)
1434        | Expression::Power(_)
1435        | Expression::Sqrt(_)
1436        | Expression::Cbrt(_)
1437        | Expression::Ln(_)
1438        | Expression::Log(_)
1439        | Expression::Exp(_) => TypeFamily::Numeric,
1440        Expression::DateAdd(_) | Expression::DateSub(_) | Expression::ToDate(_) => TypeFamily::Date,
1441        Expression::ToTimestamp(_) => TypeFamily::Timestamp,
1442        Expression::DateDiff(_) | Expression::Extract(_) => TypeFamily::Integer,
1443        Expression::CurrentDate(_) => TypeFamily::Date,
1444        Expression::CurrentTime(_) => TypeFamily::Time,
1445        Expression::CurrentTimestamp(_) | Expression::CurrentTimestampLTZ(_) => {
1446            TypeFamily::Timestamp
1447        }
1448        Expression::Interval(_) => TypeFamily::Interval,
1449        _ => TypeFamily::Unknown,
1450    }
1451}
1452
1453fn are_comparable(left: TypeFamily, right: TypeFamily) -> bool {
1454    if left == TypeFamily::Unknown || right == TypeFamily::Unknown {
1455        return true;
1456    }
1457    if left == right {
1458        return true;
1459    }
1460    if left.is_numeric() && right.is_numeric() {
1461        return true;
1462    }
1463    if left.is_temporal() && right.is_temporal() {
1464        return true;
1465    }
1466    false
1467}
1468
1469fn check_function_argument(
1470    errors: &mut Vec<ValidationError>,
1471    strict: bool,
1472    function_name: &str,
1473    arg_index: usize,
1474    family: TypeFamily,
1475    expected: &str,
1476    valid: bool,
1477) {
1478    if family == TypeFamily::Unknown || valid {
1479        return;
1480    }
1481
1482    errors.push(type_issue(
1483        strict,
1484        validation_codes::E_INVALID_FUNCTION_ARGUMENT_TYPE,
1485        validation_codes::W_FUNCTION_ARGUMENT_COERCION,
1486        format!(
1487            "Function '{}' argument {} expects {}, found {}",
1488            function_name,
1489            arg_index + 1,
1490            expected,
1491            type_family_name(family)
1492        ),
1493    ));
1494}
1495
1496fn function_dispatch_name(name: &str) -> String {
1497    let upper = name
1498        .rsplit('.')
1499        .next()
1500        .unwrap_or(name)
1501        .trim()
1502        .to_uppercase();
1503    lower(canonical_typed_function_name_upper(&upper))
1504}
1505
/// Returns the unqualified part of a function name: the text after the last
/// `.` (or the whole name when there is none), with surrounding whitespace
/// trimmed.
fn function_base_name(name: &str) -> &str {
    let unqualified = match name.rsplit_once('.') {
        Some((_, tail)) => tail,
        None => name,
    };
    unqualified.trim()
}
1509
1510fn check_generic_function(
1511    function: &Function,
1512    schema_map: &HashMap<String, TableSchemaEntry>,
1513    context: &TypeCheckContext,
1514    strict: bool,
1515    errors: &mut Vec<ValidationError>,
1516) {
1517    let name = function_dispatch_name(&function.name);
1518
1519    let arg_family = |index: usize| -> Option<TypeFamily> {
1520        function
1521            .args
1522            .get(index)
1523            .map(|arg| infer_expression_type_family(arg, schema_map, context))
1524    };
1525
1526    match name.as_str() {
1527        "abs" | "sqrt" | "cbrt" | "ln" | "exp" => {
1528            if let Some(family) = arg_family(0) {
1529                check_function_argument(
1530                    errors,
1531                    strict,
1532                    &name,
1533                    0,
1534                    family,
1535                    "a numeric argument",
1536                    family.is_numeric(),
1537                );
1538            }
1539        }
1540        "round" | "floor" | "ceil" | "ceiling" => {
1541            if let Some(family) = arg_family(0) {
1542                check_function_argument(
1543                    errors,
1544                    strict,
1545                    &name,
1546                    0,
1547                    family,
1548                    "a numeric argument",
1549                    family.is_numeric(),
1550                );
1551            }
1552            if let Some(family) = arg_family(1) {
1553                check_function_argument(
1554                    errors,
1555                    strict,
1556                    &name,
1557                    1,
1558                    family,
1559                    "a numeric argument",
1560                    family.is_numeric(),
1561                );
1562            }
1563        }
1564        "power" | "pow" => {
1565            for i in [0_usize, 1_usize] {
1566                if let Some(family) = arg_family(i) {
1567                    check_function_argument(
1568                        errors,
1569                        strict,
1570                        &name,
1571                        i,
1572                        family,
1573                        "a numeric argument",
1574                        family.is_numeric(),
1575                    );
1576                }
1577            }
1578        }
1579        "length" | "char_length" | "character_length" => {
1580            if let Some(family) = arg_family(0) {
1581                check_function_argument(
1582                    errors,
1583                    strict,
1584                    &name,
1585                    0,
1586                    family,
1587                    "a string or binary argument",
1588                    is_string_or_binary(family),
1589                );
1590            }
1591        }
1592        "upper" | "lower" | "trim" | "ltrim" | "rtrim" | "reverse" => {
1593            if let Some(family) = arg_family(0) {
1594                check_function_argument(
1595                    errors,
1596                    strict,
1597                    &name,
1598                    0,
1599                    family,
1600                    "a string argument",
1601                    is_string_like(family),
1602                );
1603            }
1604        }
1605        "substring" | "substr" => {
1606            if let Some(family) = arg_family(0) {
1607                check_function_argument(
1608                    errors,
1609                    strict,
1610                    &name,
1611                    0,
1612                    family,
1613                    "a string argument",
1614                    is_string_like(family),
1615                );
1616            }
1617            if let Some(family) = arg_family(1) {
1618                check_function_argument(
1619                    errors,
1620                    strict,
1621                    &name,
1622                    1,
1623                    family,
1624                    "a numeric argument",
1625                    family.is_numeric(),
1626                );
1627            }
1628            if let Some(family) = arg_family(2) {
1629                check_function_argument(
1630                    errors,
1631                    strict,
1632                    &name,
1633                    2,
1634                    family,
1635                    "a numeric argument",
1636                    family.is_numeric(),
1637                );
1638            }
1639        }
1640        "replace" => {
1641            for i in [0_usize, 1_usize, 2_usize] {
1642                if let Some(family) = arg_family(i) {
1643                    check_function_argument(
1644                        errors,
1645                        strict,
1646                        &name,
1647                        i,
1648                        family,
1649                        "a string argument",
1650                        is_string_like(family),
1651                    );
1652                }
1653            }
1654        }
1655        "left" | "right" | "repeat" | "lpad" | "rpad" => {
1656            if let Some(family) = arg_family(0) {
1657                check_function_argument(
1658                    errors,
1659                    strict,
1660                    &name,
1661                    0,
1662                    family,
1663                    "a string argument",
1664                    is_string_like(family),
1665                );
1666            }
1667            if let Some(family) = arg_family(1) {
1668                check_function_argument(
1669                    errors,
1670                    strict,
1671                    &name,
1672                    1,
1673                    family,
1674                    "a numeric argument",
1675                    family.is_numeric(),
1676                );
1677            }
1678            if (name == "lpad" || name == "rpad") && function.args.len() > 2 {
1679                if let Some(family) = arg_family(2) {
1680                    check_function_argument(
1681                        errors,
1682                        strict,
1683                        &name,
1684                        2,
1685                        family,
1686                        "a string argument",
1687                        is_string_like(family),
1688                    );
1689                }
1690            }
1691        }
1692        _ => {}
1693    }
1694}
1695
1696fn check_function_catalog(
1697    function: &Function,
1698    dialect: DialectType,
1699    function_catalog: Option<&dyn FunctionCatalog>,
1700    strict: bool,
1701    errors: &mut Vec<ValidationError>,
1702) {
1703    let Some(catalog) = function_catalog else {
1704        return;
1705    };
1706
1707    let raw_name = function_base_name(&function.name);
1708    let normalized_name = function_dispatch_name(&function.name);
1709    let arity = function.args.len();
1710    let Some(signatures) = catalog.lookup(dialect, raw_name, &normalized_name) else {
1711        errors.push(if strict {
1712            ValidationError::error(
1713                format!(
1714                    "Unknown function '{}' for dialect {:?}",
1715                    function.name, dialect
1716                ),
1717                validation_codes::E_UNKNOWN_FUNCTION,
1718            )
1719        } else {
1720            ValidationError::warning(
1721                format!(
1722                    "Unknown function '{}' for dialect {:?}",
1723                    function.name, dialect
1724                ),
1725                validation_codes::E_UNKNOWN_FUNCTION,
1726            )
1727        });
1728        return;
1729    };
1730
1731    if signatures.iter().any(|sig| sig.matches_arity(arity)) {
1732        return;
1733    }
1734
1735    let expected = signatures
1736        .iter()
1737        .map(|sig| sig.describe_arity())
1738        .collect::<Vec<_>>()
1739        .join(", ");
1740    errors.push(if strict {
1741        ValidationError::error(
1742            format!(
1743                "Invalid arity for function '{}': got {}, expected {}",
1744                function.name, arity, expected
1745            ),
1746            validation_codes::E_INVALID_FUNCTION_ARITY,
1747        )
1748    } else {
1749        ValidationError::warning(
1750            format!(
1751                "Invalid arity for function '{}': got {}, expected {}",
1752                function.name, arity, expected
1753            ),
1754            validation_codes::E_INVALID_FUNCTION_ARITY,
1755        )
1756    });
1757}
1758
/// One column-to-column reference declared in the validation schema.
///
/// Table fields hold resolved schema-map keys and column fields hold
/// lowercased column names, so values can be compared and hashed directly.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
struct DeclaredRelationship {
    /// Schema-map key of the table that declares the reference.
    source_table: String,
    /// Lowercased name of the referencing column.
    source_column: String,
    /// Schema-map key of the referenced table.
    target_table: String,
    /// Lowercased name of the referenced column.
    target_column: String,
}
1766
1767fn build_declared_relationships(
1768    schema: &ValidationSchema,
1769    schema_map: &HashMap<String, TableSchemaEntry>,
1770) -> Vec<DeclaredRelationship> {
1771    let mut relationships = HashSet::new();
1772
1773    for table in &schema.tables {
1774        let Some(source_key) =
1775            resolve_reference_table_key(&table.name, table.schema.as_deref(), None, schema_map)
1776        else {
1777            continue;
1778        };
1779
1780        for column in &table.columns {
1781            let Some(reference) = &column.references else {
1782                continue;
1783            };
1784            let Some(target_key) = resolve_reference_table_key(
1785                &reference.table,
1786                reference.schema.as_deref(),
1787                table.schema.as_deref(),
1788                schema_map,
1789            ) else {
1790                continue;
1791            };
1792
1793            relationships.insert(DeclaredRelationship {
1794                source_table: source_key.clone(),
1795                source_column: lower(&column.name),
1796                target_table: target_key,
1797                target_column: lower(&reference.column),
1798            });
1799        }
1800
1801        for foreign_key in &table.foreign_keys {
1802            if foreign_key.columns.len() != foreign_key.references.columns.len() {
1803                continue;
1804            }
1805            let Some(target_key) = resolve_reference_table_key(
1806                &foreign_key.references.table,
1807                foreign_key.references.schema.as_deref(),
1808                table.schema.as_deref(),
1809                schema_map,
1810            ) else {
1811                continue;
1812            };
1813
1814            for (source_col, target_col) in foreign_key
1815                .columns
1816                .iter()
1817                .zip(foreign_key.references.columns.iter())
1818            {
1819                relationships.insert(DeclaredRelationship {
1820                    source_table: source_key.clone(),
1821                    source_column: lower(source_col),
1822                    target_table: target_key.clone(),
1823                    target_column: lower(target_col),
1824                });
1825            }
1826        }
1827    }
1828
1829    relationships.into_iter().collect()
1830}
1831
1832fn resolve_column_binding(
1833    column: &Column,
1834    schema_map: &HashMap<String, TableSchemaEntry>,
1835    context: &TypeCheckContext,
1836    resolver: &mut Resolver<'_>,
1837) -> Option<(String, String)> {
1838    let column_name = lower(&column.name.name);
1839    if column_name.is_empty() {
1840        return None;
1841    }
1842
1843    if let Some(table) = &column.table {
1844        let mut table_key = lower(&table.name);
1845        if let Some(mapped) = context.table_aliases.get(&table_key) {
1846            table_key = mapped.clone();
1847        }
1848        if schema_map.contains_key(&table_key) {
1849            return Some((table_key, column_name));
1850        }
1851        return None;
1852    }
1853
1854    if let Some(resolved_source) = resolver.get_table(&column_name) {
1855        let mut table_key = lower(&resolved_source);
1856        if let Some(mapped) = context.table_aliases.get(&table_key) {
1857            table_key = mapped.clone();
1858        }
1859        if schema_map.contains_key(&table_key) {
1860            return Some((table_key, column_name));
1861        }
1862    }
1863
1864    let candidates: Vec<String> = context
1865        .referenced_tables
1866        .iter()
1867        .filter_map(|table_name| {
1868            schema_map
1869                .get(table_name)
1870                .filter(|entry| entry.columns.contains_key(&column_name))
1871                .map(|_| table_name.clone())
1872        })
1873        .collect();
1874    if candidates.len() == 1 {
1875        return Some((candidates[0].clone(), column_name));
1876    }
1877    None
1878}
1879
1880fn extract_join_equality_pairs(
1881    expr: &Expression,
1882    schema_map: &HashMap<String, TableSchemaEntry>,
1883    context: &TypeCheckContext,
1884    resolver: &mut Resolver<'_>,
1885    pairs: &mut Vec<((String, String), (String, String))>,
1886) {
1887    match expr {
1888        Expression::And(op) => {
1889            extract_join_equality_pairs(&op.left, schema_map, context, resolver, pairs);
1890            extract_join_equality_pairs(&op.right, schema_map, context, resolver, pairs);
1891        }
1892        Expression::Paren(paren) => {
1893            extract_join_equality_pairs(&paren.this, schema_map, context, resolver, pairs);
1894        }
1895        Expression::Eq(op) => {
1896            let (Expression::Column(left_col), Expression::Column(right_col)) =
1897                (&op.left, &op.right)
1898            else {
1899                return;
1900            };
1901            let Some(left) = resolve_column_binding(left_col, schema_map, context, resolver) else {
1902                return;
1903            };
1904            let Some(right) = resolve_column_binding(right_col, schema_map, context, resolver)
1905            else {
1906                return;
1907            };
1908            pairs.push((left, right));
1909        }
1910        _ => {}
1911    }
1912}
1913
1914fn relationship_matches_pair(
1915    relationship: &DeclaredRelationship,
1916    left_table: &str,
1917    left_column: &str,
1918    right_table: &str,
1919    right_column: &str,
1920) -> bool {
1921    (relationship.source_table == left_table
1922        && relationship.source_column == left_column
1923        && relationship.target_table == right_table
1924        && relationship.target_column == right_column)
1925        || (relationship.source_table == right_table
1926            && relationship.source_column == right_column
1927            && relationship.target_table == left_table
1928            && relationship.target_column == left_column)
1929}
1930
1931fn resolved_table_key_from_expr(
1932    expr: &Expression,
1933    schema_map: &HashMap<String, TableSchemaEntry>,
1934) -> Option<String> {
1935    match expr {
1936        Expression::Table(table) => resolve_table_schema_entry(table, schema_map).map(|(k, _)| k),
1937        Expression::Alias(alias) => resolved_table_key_from_expr(&alias.this, schema_map),
1938        _ => None,
1939    }
1940}
1941
1942fn select_from_table_keys(
1943    select: &crate::expressions::Select,
1944    schema_map: &HashMap<String, TableSchemaEntry>,
1945) -> HashSet<String> {
1946    let mut keys = HashSet::new();
1947    if let Some(from_clause) = &select.from {
1948        for expr in &from_clause.expressions {
1949            if let Some(key) = resolved_table_key_from_expr(expr, schema_map) {
1950                keys.insert(key);
1951            }
1952        }
1953    }
1954    keys
1955}
1956
1957fn is_natural_or_implied_join(kind: JoinKind) -> bool {
1958    matches!(
1959        kind,
1960        JoinKind::Natural
1961            | JoinKind::NaturalLeft
1962            | JoinKind::NaturalRight
1963            | JoinKind::NaturalFull
1964            | JoinKind::CrossApply
1965            | JoinKind::OuterApply
1966            | JoinKind::AsOf
1967            | JoinKind::AsOfLeft
1968            | JoinKind::AsOfRight
1969            | JoinKind::Lateral
1970            | JoinKind::LeftLateral
1971    )
1972}
1973
1974fn check_query_reference_quality(
1975    stmt: &Expression,
1976    schema_map: &HashMap<String, TableSchemaEntry>,
1977    resolver_schema: &MappingSchema,
1978    strict: bool,
1979    relationships: &[DeclaredRelationship],
1980) -> Vec<ValidationError> {
1981    let mut errors = Vec::new();
1982
1983    for node in stmt.dfs() {
1984        let Expression::Select(select) = node else {
1985            continue;
1986        };
1987
1988        let select_expr = Expression::Select(select.clone());
1989        let context = collect_type_check_context(&select_expr, schema_map);
1990        let scope = build_scope(&select_expr);
1991        let mut resolver = Resolver::new(&scope, resolver_schema, true);
1992
1993        if context.referenced_tables.len() > 1 {
1994            let using_columns: HashSet<String> = select
1995                .joins
1996                .iter()
1997                .flat_map(|join| join.using.iter().map(|id| lower(&id.name)))
1998                .collect();
1999
2000            let mut seen = HashSet::new();
2001            for column_expr in select_expr
2002                .find_all(|e| matches!(e, Expression::Column(Column { table: None, .. })))
2003            {
2004                let Expression::Column(column) = column_expr else {
2005                    continue;
2006                };
2007
2008                let col_name = lower(&column.name.name);
2009                if col_name.is_empty()
2010                    || using_columns.contains(&col_name)
2011                    || !seen.insert(col_name.clone())
2012                {
2013                    continue;
2014                }
2015
2016                if resolver.is_ambiguous(&col_name) {
2017                    let source_count = resolver.sources_for_column(&col_name).len();
2018                    errors.push(if strict {
2019                        ValidationError::error(
2020                            format!(
2021                                "Ambiguous unqualified column '{}' found in {} referenced tables",
2022                                col_name, source_count
2023                            ),
2024                            validation_codes::E_AMBIGUOUS_COLUMN_REFERENCE,
2025                        )
2026                    } else {
2027                        ValidationError::warning(
2028                            format!(
2029                                "Ambiguous unqualified column '{}' found in {} referenced tables",
2030                                col_name, source_count
2031                            ),
2032                            validation_codes::W_WEAK_REFERENCE_INTEGRITY,
2033                        )
2034                    });
2035                }
2036            }
2037        }
2038
2039        let mut cumulative_left_tables = select_from_table_keys(select, schema_map);
2040
2041        for join in &select.joins {
2042            let right_table_key = resolved_table_key_from_expr(&join.this, schema_map);
2043            let has_explicit_condition = join.on.is_some() || !join.using.is_empty();
2044            let cartesian_like_kind = matches!(
2045                join.kind,
2046                JoinKind::Cross
2047                    | JoinKind::Implicit
2048                    | JoinKind::Array
2049                    | JoinKind::LeftArray
2050                    | JoinKind::Paste
2051            );
2052
2053            if right_table_key.is_some()
2054                && (cartesian_like_kind
2055                    || (!has_explicit_condition && !is_natural_or_implied_join(join.kind)))
2056            {
2057                errors.push(ValidationError::warning(
2058                    "Potential cartesian join: JOIN without ON/USING condition",
2059                    validation_codes::W_CARTESIAN_JOIN,
2060                ));
2061            }
2062
2063            if let (Some(on_expr), Some(right_key)) = (&join.on, right_table_key.clone()) {
2064                if join.using.is_empty() {
2065                    let mut eq_pairs = Vec::new();
2066                    extract_join_equality_pairs(
2067                        on_expr,
2068                        schema_map,
2069                        &context,
2070                        &mut resolver,
2071                        &mut eq_pairs,
2072                    );
2073
2074                    let relevant_relationships: Vec<&DeclaredRelationship> = relationships
2075                        .iter()
2076                        .filter(|rel| {
2077                            cumulative_left_tables.contains(&rel.source_table)
2078                                && rel.target_table == right_key
2079                                || (cumulative_left_tables.contains(&rel.target_table)
2080                                    && rel.source_table == right_key)
2081                        })
2082                        .collect();
2083
2084                    if !relevant_relationships.is_empty() {
2085                        let uses_declared_fk = eq_pairs.iter().any(|((lt, lc), (rt, rc))| {
2086                            relevant_relationships
2087                                .iter()
2088                                .any(|rel| relationship_matches_pair(rel, lt, lc, rt, rc))
2089                        });
2090                        if !uses_declared_fk {
2091                            errors.push(ValidationError::warning(
2092                                "JOIN predicate does not use declared foreign-key relationship columns",
2093                                validation_codes::W_JOIN_NOT_USING_DECLARED_REFERENCE,
2094                            ));
2095                        }
2096                    }
2097                }
2098            }
2099
2100            if let Some(right_key) = right_table_key {
2101                cumulative_left_tables.insert(right_key);
2102            }
2103        }
2104    }
2105
2106    errors
2107}
2108
2109fn are_setop_compatible(left: TypeFamily, right: TypeFamily) -> bool {
2110    if left == TypeFamily::Unknown || right == TypeFamily::Unknown {
2111        return true;
2112    }
2113    if left == right {
2114        return true;
2115    }
2116    if left.is_numeric() && right.is_numeric() {
2117        return true;
2118    }
2119    if left.is_temporal() && right.is_temporal() {
2120        return true;
2121    }
2122    false
2123}
2124
2125fn merged_setop_family(left: TypeFamily, right: TypeFamily) -> TypeFamily {
2126    if left == TypeFamily::Unknown {
2127        return right;
2128    }
2129    if right == TypeFamily::Unknown {
2130        return left;
2131    }
2132    if left == right {
2133        return left;
2134    }
2135    if left.is_numeric() && right.is_numeric() {
2136        if left == TypeFamily::Numeric || right == TypeFamily::Numeric {
2137            return TypeFamily::Numeric;
2138        }
2139        return TypeFamily::Integer;
2140    }
2141    if left.is_temporal() && right.is_temporal() {
2142        if left == TypeFamily::Timestamp || right == TypeFamily::Timestamp {
2143            return TypeFamily::Timestamp;
2144        }
2145        if left == TypeFamily::Date || right == TypeFamily::Date {
2146            return TypeFamily::Date;
2147        }
2148        return TypeFamily::Time;
2149    }
2150    TypeFamily::Unknown
2151}
2152
2153fn are_assignment_compatible(target: TypeFamily, source: TypeFamily) -> bool {
2154    if target == TypeFamily::Unknown || source == TypeFamily::Unknown {
2155        return true;
2156    }
2157    if target == source {
2158        return true;
2159    }
2160
2161    match target {
2162        TypeFamily::Boolean => source == TypeFamily::Boolean,
2163        TypeFamily::Integer | TypeFamily::Numeric => source.is_numeric(),
2164        TypeFamily::Date | TypeFamily::Time | TypeFamily::Timestamp | TypeFamily::Interval => {
2165            source.is_temporal()
2166        }
2167        TypeFamily::String => true,
2168        TypeFamily::Binary => matches!(source, TypeFamily::Binary | TypeFamily::String),
2169        TypeFamily::Json => matches!(source, TypeFamily::Json | TypeFamily::String),
2170        TypeFamily::Uuid => matches!(source, TypeFamily::Uuid | TypeFamily::String),
2171        TypeFamily::Array => source == TypeFamily::Array,
2172        TypeFamily::Map => source == TypeFamily::Map,
2173        TypeFamily::Struct => source == TypeFamily::Struct,
2174        TypeFamily::Unknown => true,
2175    }
2176}
2177
2178fn projection_families(
2179    query_expr: &Expression,
2180    schema_map: &HashMap<String, TableSchemaEntry>,
2181) -> Option<Vec<TypeFamily>> {
2182    match query_expr {
2183        Expression::Select(select) => {
2184            if select
2185                .expressions
2186                .iter()
2187                .any(|e| matches!(e, Expression::Star(_) | Expression::BracedWildcard(_)))
2188            {
2189                return None;
2190            }
2191            let select_expr = Expression::Select(select.clone());
2192            let context = collect_type_check_context(&select_expr, schema_map);
2193            Some(
2194                select
2195                    .expressions
2196                    .iter()
2197                    .map(|e| infer_expression_type_family(e, schema_map, &context))
2198                    .collect(),
2199            )
2200        }
2201        Expression::Subquery(subquery) => projection_families(&subquery.this, schema_map),
2202        Expression::Union(union) => {
2203            let left = projection_families(&union.left, schema_map)?;
2204            let right = projection_families(&union.right, schema_map)?;
2205            if left.len() != right.len() {
2206                return None;
2207            }
2208            Some(
2209                left.into_iter()
2210                    .zip(right)
2211                    .map(|(l, r)| merged_setop_family(l, r))
2212                    .collect(),
2213            )
2214        }
2215        Expression::Intersect(intersect) => {
2216            let left = projection_families(&intersect.left, schema_map)?;
2217            let right = projection_families(&intersect.right, schema_map)?;
2218            if left.len() != right.len() {
2219                return None;
2220            }
2221            Some(
2222                left.into_iter()
2223                    .zip(right)
2224                    .map(|(l, r)| merged_setop_family(l, r))
2225                    .collect(),
2226            )
2227        }
2228        Expression::Except(except) => {
2229            let left = projection_families(&except.left, schema_map)?;
2230            let right = projection_families(&except.right, schema_map)?;
2231            if left.len() != right.len() {
2232                return None;
2233            }
2234            Some(
2235                left.into_iter()
2236                    .zip(right)
2237                    .map(|(l, r)| merged_setop_family(l, r))
2238                    .collect(),
2239            )
2240        }
2241        Expression::Values(values) => {
2242            let first_row = values.expressions.first()?;
2243            let context = TypeCheckContext::default();
2244            Some(
2245                first_row
2246                    .expressions
2247                    .iter()
2248                    .map(|e| infer_expression_type_family(e, schema_map, &context))
2249                    .collect(),
2250            )
2251        }
2252        _ => None,
2253    }
2254}
2255
2256fn check_set_operation_compatibility(
2257    op_name: &str,
2258    left_expr: &Expression,
2259    right_expr: &Expression,
2260    schema_map: &HashMap<String, TableSchemaEntry>,
2261    strict: bool,
2262    errors: &mut Vec<ValidationError>,
2263) {
2264    let Some(left_projection) = projection_families(left_expr, schema_map) else {
2265        return;
2266    };
2267    let Some(right_projection) = projection_families(right_expr, schema_map) else {
2268        return;
2269    };
2270
2271    if left_projection.len() != right_projection.len() {
2272        errors.push(type_issue(
2273            strict,
2274            validation_codes::E_SETOP_ARITY_MISMATCH,
2275            validation_codes::W_SETOP_IMPLICIT_COERCION,
2276            format!(
2277                "{} operands return different column counts: left {}, right {}",
2278                op_name,
2279                left_projection.len(),
2280                right_projection.len()
2281            ),
2282        ));
2283        return;
2284    }
2285
2286    for (idx, (left, right)) in left_projection
2287        .into_iter()
2288        .zip(right_projection)
2289        .enumerate()
2290    {
2291        if !are_setop_compatible(left, right) {
2292            errors.push(type_issue(
2293                strict,
2294                validation_codes::E_SETOP_TYPE_MISMATCH,
2295                validation_codes::W_SETOP_IMPLICIT_COERCION,
2296                format!(
2297                    "{} column {} has incompatible types: {} vs {}",
2298                    op_name,
2299                    idx + 1,
2300                    type_family_name(left),
2301                    type_family_name(right)
2302                ),
2303            ));
2304        }
2305    }
2306}
2307
2308fn check_insert_assignments(
2309    stmt: &Expression,
2310    insert: &Insert,
2311    schema_map: &HashMap<String, TableSchemaEntry>,
2312    strict: bool,
2313    errors: &mut Vec<ValidationError>,
2314) {
2315    let Some((target_table_key, table_schema)) =
2316        resolve_table_schema_entry(&insert.table, schema_map)
2317    else {
2318        return;
2319    };
2320
2321    let mut target_columns = Vec::new();
2322    if insert.columns.is_empty() {
2323        target_columns.extend(table_schema.column_order.iter().cloned());
2324    } else {
2325        for column in &insert.columns {
2326            let col_name = lower(&column.name);
2327            if table_schema.columns.contains_key(&col_name) {
2328                target_columns.push(col_name);
2329            } else {
2330                errors.push(if strict {
2331                    ValidationError::error(
2332                        format!(
2333                            "Unknown column '{}' in table '{}'",
2334                            column.name, target_table_key
2335                        ),
2336                        validation_codes::E_UNKNOWN_COLUMN,
2337                    )
2338                } else {
2339                    ValidationError::warning(
2340                        format!(
2341                            "Unknown column '{}' in table '{}'",
2342                            column.name, target_table_key
2343                        ),
2344                        validation_codes::E_UNKNOWN_COLUMN,
2345                    )
2346                });
2347            }
2348        }
2349    }
2350
2351    if target_columns.is_empty() {
2352        return;
2353    }
2354
2355    let context = collect_type_check_context(stmt, schema_map);
2356
2357    if !insert.default_values {
2358        for (row_idx, row) in insert.values.iter().enumerate() {
2359            if row.len() != target_columns.len() {
2360                errors.push(type_issue(
2361                    strict,
2362                    validation_codes::E_INVALID_ASSIGNMENT_TYPE,
2363                    validation_codes::W_IMPLICIT_CAST_ASSIGNMENT,
2364                    format!(
2365                        "INSERT row {} has {} values but target has {} columns",
2366                        row_idx + 1,
2367                        row.len(),
2368                        target_columns.len()
2369                    ),
2370                ));
2371                continue;
2372            }
2373
2374            for (value, target_column) in row.iter().zip(target_columns.iter()) {
2375                let Some(target_family) = table_schema.columns.get(target_column).copied() else {
2376                    continue;
2377                };
2378                let source_family = infer_expression_type_family(value, schema_map, &context);
2379                if !are_assignment_compatible(target_family, source_family) {
2380                    errors.push(type_issue(
2381                        strict,
2382                        validation_codes::E_INVALID_ASSIGNMENT_TYPE,
2383                        validation_codes::W_IMPLICIT_CAST_ASSIGNMENT,
2384                        format!(
2385                            "INSERT assignment type mismatch for '{}.{}': expected {}, found {}",
2386                            target_table_key,
2387                            target_column,
2388                            type_family_name(target_family),
2389                            type_family_name(source_family)
2390                        ),
2391                    ));
2392                }
2393            }
2394        }
2395    }
2396
2397    if let Some(query) = &insert.query {
2398        // DuckDB BY NAME maps source columns by name, not position.
2399        if insert.by_name {
2400            return;
2401        }
2402
2403        let Some(source_projection) = projection_families(query, schema_map) else {
2404            return;
2405        };
2406
2407        if source_projection.len() != target_columns.len() {
2408            errors.push(type_issue(
2409                strict,
2410                validation_codes::E_INVALID_ASSIGNMENT_TYPE,
2411                validation_codes::W_IMPLICIT_CAST_ASSIGNMENT,
2412                format!(
2413                    "INSERT source query has {} columns but target has {} columns",
2414                    source_projection.len(),
2415                    target_columns.len()
2416                ),
2417            ));
2418            return;
2419        }
2420
2421        for (source_family, target_column) in
2422            source_projection.into_iter().zip(target_columns.iter())
2423        {
2424            let Some(target_family) = table_schema.columns.get(target_column).copied() else {
2425                continue;
2426            };
2427            if !are_assignment_compatible(target_family, source_family) {
2428                errors.push(type_issue(
2429                    strict,
2430                    validation_codes::E_INVALID_ASSIGNMENT_TYPE,
2431                    validation_codes::W_IMPLICIT_CAST_ASSIGNMENT,
2432                    format!(
2433                        "INSERT assignment type mismatch for '{}.{}': expected {}, found {}",
2434                        target_table_key,
2435                        target_column,
2436                        type_family_name(target_family),
2437                        type_family_name(source_family)
2438                    ),
2439                ));
2440            }
2441        }
2442    }
2443}
2444
2445fn check_update_assignments(
2446    stmt: &Expression,
2447    update: &Update,
2448    schema_map: &HashMap<String, TableSchemaEntry>,
2449    strict: bool,
2450    errors: &mut Vec<ValidationError>,
2451) {
2452    let Some((target_table_key, table_schema)) =
2453        resolve_table_schema_entry(&update.table, schema_map)
2454    else {
2455        return;
2456    };
2457
2458    let context = collect_type_check_context(stmt, schema_map);
2459
2460    for (column, value) in &update.set {
2461        let col_name = lower(&column.name);
2462        let Some(target_family) = table_schema.columns.get(&col_name).copied() else {
2463            errors.push(if strict {
2464                ValidationError::error(
2465                    format!(
2466                        "Unknown column '{}' in table '{}'",
2467                        column.name, target_table_key
2468                    ),
2469                    validation_codes::E_UNKNOWN_COLUMN,
2470                )
2471            } else {
2472                ValidationError::warning(
2473                    format!(
2474                        "Unknown column '{}' in table '{}'",
2475                        column.name, target_table_key
2476                    ),
2477                    validation_codes::E_UNKNOWN_COLUMN,
2478                )
2479            });
2480            continue;
2481        };
2482
2483        let source_family = infer_expression_type_family(value, schema_map, &context);
2484        if !are_assignment_compatible(target_family, source_family) {
2485            errors.push(type_issue(
2486                strict,
2487                validation_codes::E_INVALID_ASSIGNMENT_TYPE,
2488                validation_codes::W_IMPLICIT_CAST_ASSIGNMENT,
2489                format!(
2490                    "UPDATE assignment type mismatch for '{}.{}': expected {}, found {}",
2491                    target_table_key,
2492                    col_name,
2493                    type_family_name(target_family),
2494                    type_family_name(source_family)
2495                ),
2496            ));
2497        }
2498    }
2499}
2500
2501fn check_types(
2502    stmt: &Expression,
2503    dialect: DialectType,
2504    schema_map: &HashMap<String, TableSchemaEntry>,
2505    function_catalog: Option<&dyn FunctionCatalog>,
2506    strict: bool,
2507) -> Vec<ValidationError> {
2508    let mut errors = Vec::new();
2509    let context = collect_type_check_context(stmt, schema_map);
2510
2511    for node in stmt.dfs() {
2512        match node {
2513            Expression::Insert(insert) => {
2514                check_insert_assignments(stmt, insert, schema_map, strict, &mut errors);
2515            }
2516            Expression::Update(update) => {
2517                check_update_assignments(stmt, update, schema_map, strict, &mut errors);
2518            }
2519            Expression::Union(union) => {
2520                check_set_operation_compatibility(
2521                    "UNION",
2522                    &union.left,
2523                    &union.right,
2524                    schema_map,
2525                    strict,
2526                    &mut errors,
2527                );
2528            }
2529            Expression::Intersect(intersect) => {
2530                check_set_operation_compatibility(
2531                    "INTERSECT",
2532                    &intersect.left,
2533                    &intersect.right,
2534                    schema_map,
2535                    strict,
2536                    &mut errors,
2537                );
2538            }
2539            Expression::Except(except) => {
2540                check_set_operation_compatibility(
2541                    "EXCEPT",
2542                    &except.left,
2543                    &except.right,
2544                    schema_map,
2545                    strict,
2546                    &mut errors,
2547                );
2548            }
2549            Expression::Select(select) => {
2550                if let Some(prewhere) = &select.prewhere {
2551                    let family = infer_expression_type_family(prewhere, schema_map, &context);
2552                    if family != TypeFamily::Unknown && family != TypeFamily::Boolean {
2553                        errors.push(type_issue(
2554                            strict,
2555                            validation_codes::E_INVALID_PREDICATE_TYPE,
2556                            validation_codes::W_PREDICATE_NULLABILITY,
2557                            format!(
2558                                "PREWHERE clause expects a boolean predicate, found {}",
2559                                type_family_name(family)
2560                            ),
2561                        ));
2562                    }
2563                }
2564
2565                if let Some(where_clause) = &select.where_clause {
2566                    let family =
2567                        infer_expression_type_family(&where_clause.this, schema_map, &context);
2568                    if family != TypeFamily::Unknown && family != TypeFamily::Boolean {
2569                        errors.push(type_issue(
2570                            strict,
2571                            validation_codes::E_INVALID_PREDICATE_TYPE,
2572                            validation_codes::W_PREDICATE_NULLABILITY,
2573                            format!(
2574                                "WHERE clause expects a boolean predicate, found {}",
2575                                type_family_name(family)
2576                            ),
2577                        ));
2578                    }
2579                }
2580
2581                if let Some(having_clause) = &select.having {
2582                    let family =
2583                        infer_expression_type_family(&having_clause.this, schema_map, &context);
2584                    if family != TypeFamily::Unknown && family != TypeFamily::Boolean {
2585                        errors.push(type_issue(
2586                            strict,
2587                            validation_codes::E_INVALID_PREDICATE_TYPE,
2588                            validation_codes::W_PREDICATE_NULLABILITY,
2589                            format!(
2590                                "HAVING clause expects a boolean predicate, found {}",
2591                                type_family_name(family)
2592                            ),
2593                        ));
2594                    }
2595                }
2596
2597                for join in &select.joins {
2598                    if let Some(on) = &join.on {
2599                        let family = infer_expression_type_family(on, schema_map, &context);
2600                        if family != TypeFamily::Unknown && family != TypeFamily::Boolean {
2601                            errors.push(type_issue(
2602                                strict,
2603                                validation_codes::E_INVALID_PREDICATE_TYPE,
2604                                validation_codes::W_PREDICATE_NULLABILITY,
2605                                format!(
2606                                    "JOIN ON expects a boolean predicate, found {}",
2607                                    type_family_name(family)
2608                                ),
2609                            ));
2610                        }
2611                    }
2612                    if let Some(match_condition) = &join.match_condition {
2613                        let family =
2614                            infer_expression_type_family(match_condition, schema_map, &context);
2615                        if family != TypeFamily::Unknown && family != TypeFamily::Boolean {
2616                            errors.push(type_issue(
2617                                strict,
2618                                validation_codes::E_INVALID_PREDICATE_TYPE,
2619                                validation_codes::W_PREDICATE_NULLABILITY,
2620                                format!(
2621                                    "JOIN MATCH_CONDITION expects a boolean predicate, found {}",
2622                                    type_family_name(family)
2623                                ),
2624                            ));
2625                        }
2626                    }
2627                }
2628            }
2629            Expression::Where(where_clause) => {
2630                let family = infer_expression_type_family(&where_clause.this, schema_map, &context);
2631                if family != TypeFamily::Unknown && family != TypeFamily::Boolean {
2632                    errors.push(type_issue(
2633                        strict,
2634                        validation_codes::E_INVALID_PREDICATE_TYPE,
2635                        validation_codes::W_PREDICATE_NULLABILITY,
2636                        format!(
2637                            "WHERE clause expects a boolean predicate, found {}",
2638                            type_family_name(family)
2639                        ),
2640                    ));
2641                }
2642            }
2643            Expression::Having(having_clause) => {
2644                let family =
2645                    infer_expression_type_family(&having_clause.this, schema_map, &context);
2646                if family != TypeFamily::Unknown && family != TypeFamily::Boolean {
2647                    errors.push(type_issue(
2648                        strict,
2649                        validation_codes::E_INVALID_PREDICATE_TYPE,
2650                        validation_codes::W_PREDICATE_NULLABILITY,
2651                        format!(
2652                            "HAVING clause expects a boolean predicate, found {}",
2653                            type_family_name(family)
2654                        ),
2655                    ));
2656                }
2657            }
2658            Expression::And(op) | Expression::Or(op) => {
2659                for (side, expr) in [("left", &op.left), ("right", &op.right)] {
2660                    let family = infer_expression_type_family(expr, schema_map, &context);
2661                    if family != TypeFamily::Unknown && family != TypeFamily::Boolean {
2662                        errors.push(type_issue(
2663                            strict,
2664                            validation_codes::E_INVALID_PREDICATE_TYPE,
2665                            validation_codes::W_PREDICATE_NULLABILITY,
2666                            format!(
2667                                "Logical {} operand expects boolean, found {}",
2668                                side,
2669                                type_family_name(family)
2670                            ),
2671                        ));
2672                    }
2673                }
2674            }
2675            Expression::Not(unary) => {
2676                let family = infer_expression_type_family(&unary.this, schema_map, &context);
2677                if family != TypeFamily::Unknown && family != TypeFamily::Boolean {
2678                    errors.push(type_issue(
2679                        strict,
2680                        validation_codes::E_INVALID_PREDICATE_TYPE,
2681                        validation_codes::W_PREDICATE_NULLABILITY,
2682                        format!("NOT expects boolean, found {}", type_family_name(family)),
2683                    ));
2684                }
2685            }
2686            Expression::Eq(op)
2687            | Expression::Neq(op)
2688            | Expression::Lt(op)
2689            | Expression::Lte(op)
2690            | Expression::Gt(op)
2691            | Expression::Gte(op) => {
2692                let left = infer_expression_type_family(&op.left, schema_map, &context);
2693                let right = infer_expression_type_family(&op.right, schema_map, &context);
2694                if !are_comparable(left, right) {
2695                    errors.push(type_issue(
2696                        strict,
2697                        validation_codes::E_INCOMPATIBLE_COMPARISON_TYPES,
2698                        validation_codes::W_IMPLICIT_CAST_COMPARISON,
2699                        format!(
2700                            "Incompatible comparison between {} and {}",
2701                            type_family_name(left),
2702                            type_family_name(right)
2703                        ),
2704                    ));
2705                }
2706            }
2707            Expression::Like(op) | Expression::ILike(op) => {
2708                let left = infer_expression_type_family(&op.left, schema_map, &context);
2709                let right = infer_expression_type_family(&op.right, schema_map, &context);
2710                if left != TypeFamily::Unknown
2711                    && right != TypeFamily::Unknown
2712                    && (!is_string_like(left) || !is_string_like(right))
2713                {
2714                    errors.push(type_issue(
2715                        strict,
2716                        validation_codes::E_INCOMPATIBLE_COMPARISON_TYPES,
2717                        validation_codes::W_IMPLICIT_CAST_COMPARISON,
2718                        format!(
2719                            "LIKE/ILIKE expects string operands, found {} and {}",
2720                            type_family_name(left),
2721                            type_family_name(right)
2722                        ),
2723                    ));
2724                }
2725            }
2726            Expression::Between(between) => {
2727                let this_family = infer_expression_type_family(&between.this, schema_map, &context);
2728                let low_family = infer_expression_type_family(&between.low, schema_map, &context);
2729                let high_family = infer_expression_type_family(&between.high, schema_map, &context);
2730
2731                if !are_comparable(this_family, low_family)
2732                    || !are_comparable(this_family, high_family)
2733                {
2734                    errors.push(type_issue(
2735                        strict,
2736                        validation_codes::E_INCOMPATIBLE_COMPARISON_TYPES,
2737                        validation_codes::W_IMPLICIT_CAST_COMPARISON,
2738                        format!(
2739                            "BETWEEN bounds are incompatible with {} (found {} and {})",
2740                            type_family_name(this_family),
2741                            type_family_name(low_family),
2742                            type_family_name(high_family)
2743                        ),
2744                    ));
2745                }
2746            }
2747            Expression::In(in_expr) => {
2748                let this_family = infer_expression_type_family(&in_expr.this, schema_map, &context);
2749                for value in &in_expr.expressions {
2750                    let item_family = infer_expression_type_family(value, schema_map, &context);
2751                    if !are_comparable(this_family, item_family) {
2752                        errors.push(type_issue(
2753                            strict,
2754                            validation_codes::E_INCOMPATIBLE_COMPARISON_TYPES,
2755                            validation_codes::W_IMPLICIT_CAST_COMPARISON,
2756                            format!(
2757                                "IN item type {} is incompatible with {}",
2758                                type_family_name(item_family),
2759                                type_family_name(this_family)
2760                            ),
2761                        ));
2762                        break;
2763                    }
2764                }
2765            }
2766            Expression::Add(op)
2767            | Expression::Sub(op)
2768            | Expression::Mul(op)
2769            | Expression::Div(op)
2770            | Expression::Mod(op) => {
2771                let left = infer_expression_type_family(&op.left, schema_map, &context);
2772                let right = infer_expression_type_family(&op.right, schema_map, &context);
2773
2774                if left == TypeFamily::Unknown || right == TypeFamily::Unknown {
2775                    continue;
2776                }
2777
2778                let temporal_ok = matches!(node, Expression::Add(_) | Expression::Sub(_))
2779                    && ((left.is_temporal() && right.is_numeric())
2780                        || (right.is_temporal() && left.is_numeric())
2781                        || (matches!(node, Expression::Sub(_))
2782                            && left.is_temporal()
2783                            && right.is_temporal()));
2784
2785                if !(left.is_numeric() && right.is_numeric()) && !temporal_ok {
2786                    errors.push(type_issue(
2787                        strict,
2788                        validation_codes::E_INVALID_ARITHMETIC_TYPE,
2789                        validation_codes::W_IMPLICIT_CAST_ARITHMETIC,
2790                        format!(
2791                            "Arithmetic operation expects numeric-compatible operands, found {} and {}",
2792                            type_family_name(left),
2793                            type_family_name(right)
2794                        ),
2795                    ));
2796                }
2797            }
2798            Expression::Function(function) => {
2799                check_function_catalog(function, dialect, function_catalog, strict, &mut errors);
2800                check_generic_function(function, schema_map, &context, strict, &mut errors);
2801            }
2802            Expression::Upper(func)
2803            | Expression::Lower(func)
2804            | Expression::LTrim(func)
2805            | Expression::RTrim(func)
2806            | Expression::Reverse(func) => {
2807                let family = infer_expression_type_family(&func.this, schema_map, &context);
2808                check_function_argument(
2809                    &mut errors,
2810                    strict,
2811                    "string_function",
2812                    0,
2813                    family,
2814                    "a string argument",
2815                    is_string_like(family),
2816                );
2817            }
2818            Expression::Length(func) => {
2819                let family = infer_expression_type_family(&func.this, schema_map, &context);
2820                check_function_argument(
2821                    &mut errors,
2822                    strict,
2823                    "length",
2824                    0,
2825                    family,
2826                    "a string or binary argument",
2827                    is_string_or_binary(family),
2828                );
2829            }
2830            Expression::Trim(func) => {
2831                let this_family = infer_expression_type_family(&func.this, schema_map, &context);
2832                check_function_argument(
2833                    &mut errors,
2834                    strict,
2835                    "trim",
2836                    0,
2837                    this_family,
2838                    "a string argument",
2839                    is_string_like(this_family),
2840                );
2841                if let Some(chars) = &func.characters {
2842                    let chars_family = infer_expression_type_family(chars, schema_map, &context);
2843                    check_function_argument(
2844                        &mut errors,
2845                        strict,
2846                        "trim",
2847                        1,
2848                        chars_family,
2849                        "a string argument",
2850                        is_string_like(chars_family),
2851                    );
2852                }
2853            }
2854            Expression::Substring(func) => {
2855                let this_family = infer_expression_type_family(&func.this, schema_map, &context);
2856                check_function_argument(
2857                    &mut errors,
2858                    strict,
2859                    "substring",
2860                    0,
2861                    this_family,
2862                    "a string argument",
2863                    is_string_like(this_family),
2864                );
2865
2866                let start_family = infer_expression_type_family(&func.start, schema_map, &context);
2867                check_function_argument(
2868                    &mut errors,
2869                    strict,
2870                    "substring",
2871                    1,
2872                    start_family,
2873                    "a numeric argument",
2874                    start_family.is_numeric(),
2875                );
2876                if let Some(length) = &func.length {
2877                    let length_family = infer_expression_type_family(length, schema_map, &context);
2878                    check_function_argument(
2879                        &mut errors,
2880                        strict,
2881                        "substring",
2882                        2,
2883                        length_family,
2884                        "a numeric argument",
2885                        length_family.is_numeric(),
2886                    );
2887                }
2888            }
2889            Expression::Replace(func) => {
2890                for (arg, idx) in [
2891                    (&func.this, 0_usize),
2892                    (&func.old, 1_usize),
2893                    (&func.new, 2_usize),
2894                ] {
2895                    let family = infer_expression_type_family(arg, schema_map, &context);
2896                    check_function_argument(
2897                        &mut errors,
2898                        strict,
2899                        "replace",
2900                        idx,
2901                        family,
2902                        "a string argument",
2903                        is_string_like(family),
2904                    );
2905                }
2906            }
2907            Expression::Left(func) | Expression::Right(func) => {
2908                let this_family = infer_expression_type_family(&func.this, schema_map, &context);
2909                check_function_argument(
2910                    &mut errors,
2911                    strict,
2912                    "left_right",
2913                    0,
2914                    this_family,
2915                    "a string argument",
2916                    is_string_like(this_family),
2917                );
2918                let length_family =
2919                    infer_expression_type_family(&func.length, schema_map, &context);
2920                check_function_argument(
2921                    &mut errors,
2922                    strict,
2923                    "left_right",
2924                    1,
2925                    length_family,
2926                    "a numeric argument",
2927                    length_family.is_numeric(),
2928                );
2929            }
2930            Expression::Repeat(func) => {
2931                let this_family = infer_expression_type_family(&func.this, schema_map, &context);
2932                check_function_argument(
2933                    &mut errors,
2934                    strict,
2935                    "repeat",
2936                    0,
2937                    this_family,
2938                    "a string argument",
2939                    is_string_like(this_family),
2940                );
2941                let times_family = infer_expression_type_family(&func.times, schema_map, &context);
2942                check_function_argument(
2943                    &mut errors,
2944                    strict,
2945                    "repeat",
2946                    1,
2947                    times_family,
2948                    "a numeric argument",
2949                    times_family.is_numeric(),
2950                );
2951            }
2952            Expression::Lpad(func) | Expression::Rpad(func) => {
2953                let this_family = infer_expression_type_family(&func.this, schema_map, &context);
2954                check_function_argument(
2955                    &mut errors,
2956                    strict,
2957                    "pad",
2958                    0,
2959                    this_family,
2960                    "a string argument",
2961                    is_string_like(this_family),
2962                );
2963                let length_family =
2964                    infer_expression_type_family(&func.length, schema_map, &context);
2965                check_function_argument(
2966                    &mut errors,
2967                    strict,
2968                    "pad",
2969                    1,
2970                    length_family,
2971                    "a numeric argument",
2972                    length_family.is_numeric(),
2973                );
2974                if let Some(fill) = &func.fill {
2975                    let fill_family = infer_expression_type_family(fill, schema_map, &context);
2976                    check_function_argument(
2977                        &mut errors,
2978                        strict,
2979                        "pad",
2980                        2,
2981                        fill_family,
2982                        "a string argument",
2983                        is_string_like(fill_family),
2984                    );
2985                }
2986            }
2987            Expression::Abs(func)
2988            | Expression::Sqrt(func)
2989            | Expression::Cbrt(func)
2990            | Expression::Ln(func)
2991            | Expression::Exp(func) => {
2992                let family = infer_expression_type_family(&func.this, schema_map, &context);
2993                check_function_argument(
2994                    &mut errors,
2995                    strict,
2996                    "numeric_function",
2997                    0,
2998                    family,
2999                    "a numeric argument",
3000                    family.is_numeric(),
3001                );
3002            }
3003            Expression::Round(func) => {
3004                let this_family = infer_expression_type_family(&func.this, schema_map, &context);
3005                check_function_argument(
3006                    &mut errors,
3007                    strict,
3008                    "round",
3009                    0,
3010                    this_family,
3011                    "a numeric argument",
3012                    this_family.is_numeric(),
3013                );
3014                if let Some(decimals) = &func.decimals {
3015                    let decimals_family =
3016                        infer_expression_type_family(decimals, schema_map, &context);
3017                    check_function_argument(
3018                        &mut errors,
3019                        strict,
3020                        "round",
3021                        1,
3022                        decimals_family,
3023                        "a numeric argument",
3024                        decimals_family.is_numeric(),
3025                    );
3026                }
3027            }
3028            Expression::Floor(func) => {
3029                let this_family = infer_expression_type_family(&func.this, schema_map, &context);
3030                check_function_argument(
3031                    &mut errors,
3032                    strict,
3033                    "floor",
3034                    0,
3035                    this_family,
3036                    "a numeric argument",
3037                    this_family.is_numeric(),
3038                );
3039                if let Some(scale) = &func.scale {
3040                    let scale_family = infer_expression_type_family(scale, schema_map, &context);
3041                    check_function_argument(
3042                        &mut errors,
3043                        strict,
3044                        "floor",
3045                        1,
3046                        scale_family,
3047                        "a numeric argument",
3048                        scale_family.is_numeric(),
3049                    );
3050                }
3051            }
3052            Expression::Ceil(func) => {
3053                let this_family = infer_expression_type_family(&func.this, schema_map, &context);
3054                check_function_argument(
3055                    &mut errors,
3056                    strict,
3057                    "ceil",
3058                    0,
3059                    this_family,
3060                    "a numeric argument",
3061                    this_family.is_numeric(),
3062                );
3063                if let Some(decimals) = &func.decimals {
3064                    let decimals_family =
3065                        infer_expression_type_family(decimals, schema_map, &context);
3066                    check_function_argument(
3067                        &mut errors,
3068                        strict,
3069                        "ceil",
3070                        1,
3071                        decimals_family,
3072                        "a numeric argument",
3073                        decimals_family.is_numeric(),
3074                    );
3075                }
3076            }
3077            Expression::Power(func) => {
3078                let left_family = infer_expression_type_family(&func.this, schema_map, &context);
3079                check_function_argument(
3080                    &mut errors,
3081                    strict,
3082                    "power",
3083                    0,
3084                    left_family,
3085                    "a numeric argument",
3086                    left_family.is_numeric(),
3087                );
3088                let right_family =
3089                    infer_expression_type_family(&func.expression, schema_map, &context);
3090                check_function_argument(
3091                    &mut errors,
3092                    strict,
3093                    "power",
3094                    1,
3095                    right_family,
3096                    "a numeric argument",
3097                    right_family.is_numeric(),
3098                );
3099            }
3100            Expression::Log(func) => {
3101                let this_family = infer_expression_type_family(&func.this, schema_map, &context);
3102                check_function_argument(
3103                    &mut errors,
3104                    strict,
3105                    "log",
3106                    0,
3107                    this_family,
3108                    "a numeric argument",
3109                    this_family.is_numeric(),
3110                );
3111                if let Some(base) = &func.base {
3112                    let base_family = infer_expression_type_family(base, schema_map, &context);
3113                    check_function_argument(
3114                        &mut errors,
3115                        strict,
3116                        "log",
3117                        1,
3118                        base_family,
3119                        "a numeric argument",
3120                        base_family.is_numeric(),
3121                    );
3122                }
3123            }
3124            _ => {}
3125        }
3126    }
3127
3128    errors
3129}
3130
3131fn check_semantics(stmt: &Expression) -> Vec<ValidationError> {
3132    let mut errors = Vec::new();
3133
3134    let Expression::Select(select) = stmt else {
3135        return errors;
3136    };
3137    let select_expr = Expression::Select(select.clone());
3138
3139    // W001: SELECT * is discouraged
3140    if !select_expr
3141        .find_all(|e| matches!(e, Expression::Star(_)))
3142        .is_empty()
3143    {
3144        errors.push(ValidationError::warning(
3145            "SELECT * is discouraged; specify columns explicitly for better performance and maintainability",
3146            validation_codes::W_SELECT_STAR,
3147        ));
3148    }
3149
3150    // W002: aggregate + non-aggregate columns without GROUP BY
3151    let aggregate_count = get_aggregate_functions(&select_expr).len();
3152    if aggregate_count > 0 && select.group_by.is_none() {
3153        let has_non_aggregate_column = select.expressions.iter().any(|expr| {
3154            matches!(expr, Expression::Column(_) | Expression::Identifier(_))
3155                && get_aggregate_functions(expr).is_empty()
3156        });
3157
3158        if has_non_aggregate_column {
3159            errors.push(ValidationError::warning(
3160                "Mixing aggregate functions with non-aggregated columns without GROUP BY may cause errors in strict SQL mode",
3161                validation_codes::W_AGGREGATE_WITHOUT_GROUP_BY,
3162            ));
3163        }
3164    }
3165
3166    // W003: DISTINCT with ORDER BY
3167    if select.distinct && select.order_by.is_some() {
3168        errors.push(ValidationError::warning(
3169            "DISTINCT with ORDER BY: ensure ORDER BY columns are in SELECT list",
3170            validation_codes::W_DISTINCT_ORDER_BY,
3171        ));
3172    }
3173
3174    // W004: LIMIT without ORDER BY
3175    if select.limit.is_some() && select.order_by.is_none() {
3176        errors.push(ValidationError::warning(
3177            "LIMIT without ORDER BY produces non-deterministic results",
3178            validation_codes::W_LIMIT_WITHOUT_ORDER_BY,
3179        ));
3180    }
3181
3182    errors
3183}
3184
3185fn resolve_scope_source_name(scope: &crate::scope::Scope, name: &str) -> Option<String> {
3186    scope
3187        .sources
3188        .get_key_value(name)
3189        .map(|(k, _)| k.clone())
3190        .or_else(|| {
3191            scope
3192                .sources
3193                .keys()
3194                .find(|source| source.eq_ignore_ascii_case(name))
3195                .cloned()
3196        })
3197}
3198
/// Reports whether `column_name` can be served by a source exposing `columns`.
///
/// A literal `"*"` entry matches any column name; every other entry is
/// compared with `column_name` ignoring ASCII case.
fn source_has_column(columns: &[String], column_name: &str) -> bool {
    for candidate in columns {
        if candidate == "*" {
            return true;
        }
        if candidate.eq_ignore_ascii_case(column_name) {
            return true;
        }
    }
    false
}
3204
3205fn source_display_name(scope: &crate::scope::Scope, source_name: &str) -> String {
3206    scope
3207        .sources
3208        .get(source_name)
3209        .map(|source| match &source.expression {
3210            Expression::Table(table) => lower(&table_ref_display_name(table)),
3211            _ => lower(source_name),
3212        })
3213        .unwrap_or_else(|| lower(source_name))
3214}
3215
3216fn validate_select_columns_with_schema(
3217    select: &crate::expressions::Select,
3218    schema_map: &HashMap<String, TableSchemaEntry>,
3219    resolver_schema: &MappingSchema,
3220    strict: bool,
3221) -> Vec<ValidationError> {
3222    let mut errors = Vec::new();
3223    let select_expr = Expression::Select(Box::new(select.clone()));
3224    let scope = build_scope(&select_expr);
3225    let mut resolver = Resolver::new(&scope, resolver_schema, true);
3226    let source_names: Vec<String> = scope.sources.keys().cloned().collect();
3227
3228    for node in walk_in_scope(&select_expr, false) {
3229        let Expression::Column(column) = node else {
3230            continue;
3231        };
3232
3233        let col_name = lower(&column.name.name);
3234        if col_name.is_empty() {
3235            continue;
3236        }
3237
3238        if let Some(table) = &column.table {
3239            let Some(source_name) = resolve_scope_source_name(&scope, &table.name) else {
3240                // The table qualifier is not a declared alias or source in this scope
3241                errors.push(if strict {
3242                    ValidationError::error(
3243                        format!(
3244                            "Unknown table or alias '{}' referenced by column '{}'",
3245                            table.name, col_name
3246                        ),
3247                        validation_codes::E_UNRESOLVED_REFERENCE,
3248                    )
3249                } else {
3250                    ValidationError::warning(
3251                        format!(
3252                            "Unknown table or alias '{}' referenced by column '{}'",
3253                            table.name, col_name
3254                        ),
3255                        validation_codes::E_UNRESOLVED_REFERENCE,
3256                    )
3257                });
3258                continue;
3259            };
3260
3261            if let Ok(columns) = resolver.get_source_columns(&source_name) {
3262                if !columns.is_empty() && !source_has_column(&columns, &col_name) {
3263                    let table_name = source_display_name(&scope, &source_name);
3264                    errors.push(if strict {
3265                        ValidationError::error(
3266                            format!("Unknown column '{}' in table '{}'", col_name, table_name),
3267                            validation_codes::E_UNKNOWN_COLUMN,
3268                        )
3269                    } else {
3270                        ValidationError::warning(
3271                            format!("Unknown column '{}' in table '{}'", col_name, table_name),
3272                            validation_codes::E_UNKNOWN_COLUMN,
3273                        )
3274                    });
3275                }
3276            }
3277            continue;
3278        }
3279
3280        let matching_sources: Vec<String> = source_names
3281            .iter()
3282            .filter_map(|source_name| {
3283                resolver
3284                    .get_source_columns(source_name)
3285                    .ok()
3286                    .filter(|columns| !columns.is_empty() && source_has_column(columns, &col_name))
3287                    .map(|_| source_name.clone())
3288            })
3289            .collect();
3290
3291        if !matching_sources.is_empty() {
3292            continue;
3293        }
3294
3295        let known_sources: Vec<String> = source_names
3296            .iter()
3297            .filter_map(|source_name| {
3298                resolver
3299                    .get_source_columns(source_name)
3300                    .ok()
3301                    .filter(|columns| !columns.is_empty() && !columns.iter().any(|c| c == "*"))
3302                    .map(|_| source_name.clone())
3303            })
3304            .collect();
3305
3306        if known_sources.len() == 1 {
3307            let table_name = source_display_name(&scope, &known_sources[0]);
3308            errors.push(if strict {
3309                ValidationError::error(
3310                    format!("Unknown column '{}' in table '{}'", col_name, table_name),
3311                    validation_codes::E_UNKNOWN_COLUMN,
3312                )
3313            } else {
3314                ValidationError::warning(
3315                    format!("Unknown column '{}' in table '{}'", col_name, table_name),
3316                    validation_codes::E_UNKNOWN_COLUMN,
3317                )
3318            });
3319        } else if known_sources.len() > 1 {
3320            errors.push(if strict {
3321                ValidationError::error(
3322                    format!(
3323                        "Unknown column '{}' (not found in any referenced table)",
3324                        col_name
3325                    ),
3326                    validation_codes::E_UNKNOWN_COLUMN,
3327                )
3328            } else {
3329                ValidationError::warning(
3330                    format!(
3331                        "Unknown column '{}' (not found in any referenced table)",
3332                        col_name
3333                    ),
3334                    validation_codes::E_UNKNOWN_COLUMN,
3335                )
3336            });
3337        } else if !schema_map.is_empty() {
3338            let found = schema_map
3339                .values()
3340                .any(|table_schema| table_schema.columns.contains_key(&col_name));
3341            if !found {
3342                errors.push(if strict {
3343                    ValidationError::error(
3344                        format!("Unknown column '{}'", col_name),
3345                        validation_codes::E_UNKNOWN_COLUMN,
3346                    )
3347                } else {
3348                    ValidationError::warning(
3349                        format!("Unknown column '{}'", col_name),
3350                        validation_codes::E_UNKNOWN_COLUMN,
3351                    )
3352                });
3353            }
3354        }
3355    }
3356
3357    errors
3358}
3359
3360fn validate_statement_with_schema(
3361    stmt: &Expression,
3362    schema_map: &HashMap<String, TableSchemaEntry>,
3363    resolver_schema: &MappingSchema,
3364    strict: bool,
3365) -> Vec<ValidationError> {
3366    let mut errors = Vec::new();
3367    let cte_aliases = collect_cte_aliases(stmt);
3368    let mut seen_tables: HashSet<String> = HashSet::new();
3369
3370    // Table validation (E200)
3371    for node in stmt.find_all(|e| matches!(e, Expression::Table(_))) {
3372        let Expression::Table(table) = node else {
3373            continue;
3374        };
3375
3376        if cte_aliases.contains(&lower(&table.name.name)) {
3377            continue;
3378        }
3379
3380        let resolved_key = table_ref_candidates(table)
3381            .into_iter()
3382            .find(|k| schema_map.contains_key(k));
3383        let table_key = resolved_key
3384            .clone()
3385            .unwrap_or_else(|| lower(&table_ref_display_name(table)));
3386
3387        if !seen_tables.insert(table_key) {
3388            continue;
3389        }
3390
3391        if resolved_key.is_none() {
3392            errors.push(if strict {
3393                ValidationError::error(
3394                    format!("Unknown table '{}'", table_ref_display_name(table)),
3395                    validation_codes::E_UNKNOWN_TABLE,
3396                )
3397            } else {
3398                ValidationError::warning(
3399                    format!("Unknown table '{}'", table_ref_display_name(table)),
3400                    validation_codes::E_UNKNOWN_TABLE,
3401                )
3402            });
3403        }
3404    }
3405
3406    for node in stmt.dfs() {
3407        let Expression::Select(select) = node else {
3408            continue;
3409        };
3410        errors.extend(validate_select_columns_with_schema(
3411            select,
3412            schema_map,
3413            resolver_schema,
3414            strict,
3415        ));
3416    }
3417
3418    errors
3419}
3420
3421/// Validate SQL using syntax + schema-aware checks, with optional semantic warnings.
3422pub fn validate_with_schema(
3423    sql: &str,
3424    dialect: DialectType,
3425    schema: &ValidationSchema,
3426    options: &SchemaValidationOptions,
3427) -> ValidationResult {
3428    let strict = options.strict.unwrap_or(schema.strict.unwrap_or(true));
3429
3430    // Syntax validation first.
3431    let syntax_result = crate::validate_with_options(
3432        sql,
3433        dialect,
3434        &crate::ValidationOptions {
3435            strict_syntax: options.strict_syntax,
3436        },
3437    );
3438    if !syntax_result.valid {
3439        return syntax_result;
3440    }
3441
3442    let d = Dialect::get(dialect);
3443    let statements = match d.parse(sql) {
3444        Ok(exprs) => exprs,
3445        Err(e) => {
3446            return ValidationResult::with_errors(vec![ValidationError::error(
3447                e.to_string(),
3448                validation_codes::E_PARSE_OR_OPTIONS,
3449            )]);
3450        }
3451    };
3452
3453    let schema_map = build_schema_map(schema);
3454    let resolver_schema = build_resolver_schema(schema);
3455    let mut all_errors = syntax_result.errors;
3456    let embedded_function_catalog = if options.check_types && options.function_catalog.is_none() {
3457        default_embedded_function_catalog()
3458    } else {
3459        None
3460    };
3461    let effective_function_catalog = options
3462        .function_catalog
3463        .as_deref()
3464        .or_else(|| embedded_function_catalog.as_deref());
3465    let declared_relationships = if options.check_references {
3466        build_declared_relationships(schema, &schema_map)
3467    } else {
3468        Vec::new()
3469    };
3470
3471    if options.check_references {
3472        all_errors.extend(check_reference_integrity(schema, &schema_map, strict));
3473    }
3474
3475    for stmt in &statements {
3476        if options.semantic {
3477            all_errors.extend(check_semantics(stmt));
3478        }
3479        all_errors.extend(validate_statement_with_schema(
3480            stmt,
3481            &schema_map,
3482            &resolver_schema,
3483            strict,
3484        ));
3485        if options.check_types {
3486            all_errors.extend(check_types(
3487                stmt,
3488                dialect,
3489                &schema_map,
3490                effective_function_catalog,
3491                strict,
3492            ));
3493        }
3494        if options.check_references {
3495            all_errors.extend(check_query_reference_quality(
3496                stmt,
3497                &schema_map,
3498                &resolver_schema,
3499                strict,
3500                &declared_relationships,
3501            ));
3502        }
3503    }
3504
3505    ValidationResult::with_errors(all_errors)
3506}
3507
3508#[cfg(test)]
3509mod tests;