Skip to main content

polyglot_sql/
validation.rs

1//! Schema-aware and semantic SQL validation.
2//!
3//! This module extends syntax validation with:
4//! - schema checks (unknown tables/columns)
5//! - optional semantic warnings (SELECT *, LIMIT without ORDER BY, etc.)
6
7use crate::ast_transforms::get_aggregate_functions;
8use crate::dialects::{Dialect, DialectType};
9use crate::error::{ValidationError, ValidationResult};
10use crate::expressions::{
11    Column, DataType, Expression, Function, Insert, JoinKind, TableRef, Update,
12};
13use crate::function_catalog::FunctionCatalog;
14#[cfg(any(
15    feature = "function-catalog-clickhouse",
16    feature = "function-catalog-duckdb",
17    feature = "function-catalog-all-dialects"
18))]
19use crate::function_catalog::{
20    FunctionNameCase as CoreFunctionNameCase, FunctionSignature as CoreFunctionSignature,
21    HashMapFunctionCatalog,
22};
23use crate::function_registry::canonical_typed_function_name_upper;
24use crate::optimizer::annotate_types;
25use crate::resolver::Resolver;
26use crate::schema::{MappingSchema, Schema as SqlSchema, SchemaError, SchemaResult, TABLE_PARTS};
27use crate::scope::{build_scope, walk_in_scope};
28use crate::traversal::ExpressionWalk;
29use serde::{Deserialize, Serialize};
30use std::collections::{HashMap, HashSet};
31use std::sync::Arc;
32
33#[cfg(any(
34    feature = "function-catalog-clickhouse",
35    feature = "function-catalog-duckdb",
36    feature = "function-catalog-all-dialects"
37))]
38use std::sync::LazyLock;
39
/// Column definition used for schema-aware validation.
///
/// Serialized field names follow the `serde` attributes below: `data_type`
/// is exchanged as `type` and `primary_key` as `primaryKey`. Every field
/// except `name` falls back to its default when absent from the payload.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SchemaColumn {
    /// Column name.
    pub name: String,
    /// Column data type as written in the schema payload (serialized as `type`).
    /// Defaults to an empty string when omitted.
    #[serde(default, rename = "type")]
    pub data_type: String,
    /// Whether the column allows NULL values (`None` when not specified).
    #[serde(default)]
    pub nullable: Option<bool>,
    /// Whether this column is part of a primary key (serialized as `primaryKey`).
    #[serde(default, rename = "primaryKey")]
    pub primary_key: bool,
    /// Whether this column has a uniqueness constraint.
    #[serde(default)]
    pub unique: bool,
    /// Optional column-level foreign key reference.
    #[serde(default)]
    pub references: Option<SchemaColumnReference>,
}
61
/// Column-level foreign key reference metadata.
///
/// Names exactly one target column; `schema` defaults to `None` when it is
/// absent from the payload.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SchemaColumnReference {
    /// Referenced table name.
    pub table: String,
    /// Referenced column name.
    pub column: String,
    /// Optional schema/namespace of referenced table.
    #[serde(default)]
    pub schema: Option<String>,
}
73
/// Table-level foreign key reference metadata.
///
/// Unlike a column-level reference, this form carries a list of source
/// columns and a list of target columns (see `SchemaTableReference`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SchemaForeignKey {
    /// Optional FK name.
    #[serde(default)]
    pub name: Option<String>,
    /// Source columns in the current table.
    pub columns: Vec<String>,
    /// Referenced target table + columns.
    pub references: SchemaTableReference,
}
85
/// Target of a table-level foreign key.
///
/// `schema` defaults to `None` when it is absent from the payload.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SchemaTableReference {
    /// Referenced table name.
    pub table: String,
    /// Referenced target columns.
    pub columns: Vec<String>,
    /// Optional schema/namespace of referenced table.
    #[serde(default)]
    pub schema: Option<String>,
}
97
/// Table definition used for schema-aware validation.
///
/// Only `name` and `columns` are required in the payload; every other field
/// falls back to its default. The key-related fields are exchanged in
/// camelCase (`primaryKey`, `uniqueKeys`, `foreignKeys`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SchemaTable {
    /// Table name.
    pub name: String,
    /// Optional schema/namespace name.
    #[serde(default)]
    pub schema: Option<String>,
    /// Column definitions.
    pub columns: Vec<SchemaColumn>,
    /// Optional aliases that should resolve to this table.
    #[serde(default)]
    pub aliases: Vec<String>,
    /// Optional primary key column list (serialized as `primaryKey`).
    #[serde(default, rename = "primaryKey")]
    pub primary_key: Vec<String>,
    /// Optional unique key groups; each inner list is one key (serialized as `uniqueKeys`).
    #[serde(default, rename = "uniqueKeys")]
    pub unique_keys: Vec<Vec<String>>,
    /// Optional table-level foreign keys (serialized as `foreignKeys`).
    #[serde(default, rename = "foreignKeys")]
    pub foreign_keys: Vec<SchemaForeignKey>,
}
121
/// Schema payload used for schema-aware validation.
///
/// This is the top-level object handed to the validator: a flat list of
/// tables plus an optional default for strict mode.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ValidationSchema {
    /// Known tables.
    pub tables: Vec<SchemaTable>,
    /// Default strict mode for unknown identifiers (`None` when not specified).
    #[serde(default)]
    pub strict: Option<bool>,
}
131
/// Options for schema-aware validation.
///
/// All flags default to `false`/`None`, both through `Default` and when a
/// field is missing from a deserialized payload. `function_catalog` is never
/// (de)serialized and must be set programmatically.
#[derive(Clone, Serialize, Deserialize, Default)]
pub struct SchemaValidationOptions {
    /// Enables type compatibility checks for expressions, DML assignments, and set operations.
    #[serde(default)]
    pub check_types: bool,
    /// Enables FK/reference integrity checks and query-level reference quality checks.
    #[serde(default)]
    pub check_references: bool,
    /// If true/false, overrides schema.strict.
    #[serde(default)]
    pub strict: Option<bool>,
    /// Enables semantic warnings (W001..W004).
    #[serde(default)]
    pub semantic: bool,
    /// Enables strict syntax checks (e.g. rejects trailing commas before clause boundaries).
    #[serde(default)]
    pub strict_syntax: bool,
    /// Optional external function catalog plugin for dialect-specific function validation.
    /// Skipped by serde, so it is always `None` after deserialization.
    #[serde(skip, default)]
    pub function_catalog: Option<Arc<dyn FunctionCatalog>>,
}
154
#[cfg(any(
    feature = "function-catalog-clickhouse",
    feature = "function-catalog-duckdb",
    feature = "function-catalog-all-dialects"
))]
/// Translates the catalog crate's name-case setting into the core crate's
/// equivalent enum, variant for variant.
fn to_core_name_case(case: polyglot_sql_function_catalogs::FunctionNameCase) -> CoreFunctionNameCase {
    use polyglot_sql_function_catalogs::FunctionNameCase as ExternalCase;

    match case {
        ExternalCase::Sensitive => CoreFunctionNameCase::Sensitive,
        ExternalCase::Insensitive => CoreFunctionNameCase::Insensitive,
    }
}
168
#[cfg(any(
    feature = "function-catalog-clickhouse",
    feature = "function-catalog-duckdb",
    feature = "function-catalog-all-dialects"
))]
/// Converts catalog-crate signatures into core signatures, carrying over the
/// minimum and maximum arity of each entry in order.
fn to_core_signatures(
    signatures: Vec<polyglot_sql_function_catalogs::FunctionSignature>,
) -> Vec<CoreFunctionSignature> {
    let mut converted = Vec::with_capacity(signatures.len());
    for external in signatures {
        converted.push(CoreFunctionSignature {
            min_arity: external.min_arity,
            max_arity: external.max_arity,
        });
    }
    converted
}
185
#[cfg(any(
    feature = "function-catalog-clickhouse",
    feature = "function-catalog-duckdb",
    feature = "function-catalog-all-dialects"
))]
/// Sink that forwards registrations coming from the
/// `polyglot_sql_function_catalogs` crate into a core `HashMapFunctionCatalog`.
struct EmbeddedCatalogSink<'a> {
    /// Catalog being populated.
    catalog: &'a mut HashMapFunctionCatalog,
    /// Memoized result of parsing each dialect name; `None` records a name
    /// that failed to parse so it is not parsed again.
    dialect_cache: HashMap<&'static str, Option<DialectType>>,
}
195
#[cfg(any(
    feature = "function-catalog-clickhouse",
    feature = "function-catalog-duckdb",
    feature = "function-catalog-all-dialects"
))]
impl<'a> EmbeddedCatalogSink<'a> {
    /// Maps a dialect name to its `DialectType`, parsing each distinct name
    /// at most once. Names that do not parse are cached as `None`.
    fn resolve_dialect(&mut self, dialect: &'static str) -> Option<DialectType> {
        *self
            .dialect_cache
            .entry(dialect)
            .or_insert_with(|| dialect.parse::<DialectType>().ok())
    }
}
211
#[cfg(any(
    feature = "function-catalog-clickhouse",
    feature = "function-catalog-duckdb",
    feature = "function-catalog-all-dialects"
))]
impl<'a> polyglot_sql_function_catalogs::CatalogSink for EmbeddedCatalogSink<'a> {
    /// Sets the default name-case rule for a dialect; calls for dialect names
    /// that do not parse are ignored.
    fn set_dialect_name_case(
        &mut self,
        dialect: &'static str,
        name_case: polyglot_sql_function_catalogs::FunctionNameCase,
    ) {
        let Some(core_dialect) = self.resolve_dialect(dialect) else {
            return;
        };
        let core_case = to_core_name_case(name_case);
        self.catalog.set_dialect_name_case(core_dialect, core_case);
    }

    /// Sets the name-case rule for a single function of a dialect; calls for
    /// dialect names that do not parse are ignored.
    fn set_function_name_case(
        &mut self,
        dialect: &'static str,
        function_name: &str,
        name_case: polyglot_sql_function_catalogs::FunctionNameCase,
    ) {
        let Some(core_dialect) = self.resolve_dialect(dialect) else {
            return;
        };
        let core_case = to_core_name_case(name_case);
        self.catalog
            .set_function_name_case(core_dialect, function_name, core_case);
    }

    /// Registers the signatures of a function for a dialect; calls for
    /// dialect names that do not parse are ignored.
    fn register(
        &mut self,
        dialect: &'static str,
        function_name: &str,
        signatures: Vec<polyglot_sql_function_catalogs::FunctionSignature>,
    ) {
        let Some(core_dialect) = self.resolve_dialect(dialect) else {
            return;
        };
        let core_signatures = to_core_signatures(signatures);
        self.catalog
            .register(core_dialect, function_name, core_signatures);
    }
}
253
#[cfg(any(
    feature = "function-catalog-clickhouse",
    feature = "function-catalog-duckdb",
    feature = "function-catalog-all-dialects"
))]
/// Returns a shared handle to the embedded function catalog.
///
/// The catalog is built on first use by letting
/// `polyglot_sql_function_catalogs::register_enabled_catalogs` fill an
/// `EmbeddedCatalogSink`; later calls only clone the `Arc`.
fn embedded_function_catalog_arc() -> Arc<dyn FunctionCatalog> {
    static EMBEDDED: LazyLock<Arc<HashMapFunctionCatalog>> = LazyLock::new(|| {
        let mut built = HashMapFunctionCatalog::default();
        {
            // Scope the sink so its mutable borrow of `built` ends before the move below.
            let mut sink = EmbeddedCatalogSink {
                catalog: &mut built,
                dialect_cache: HashMap::new(),
            };
            polyglot_sql_function_catalogs::register_enabled_catalogs(&mut sink);
        }
        Arc::new(built)
    });

    let shared: Arc<HashMapFunctionCatalog> = Arc::clone(&*EMBEDDED);
    shared
}
272
#[cfg(any(
    feature = "function-catalog-clickhouse",
    feature = "function-catalog-duckdb",
    feature = "function-catalog-all-dialects"
))]
/// Default function catalog when at least one catalog feature is compiled in:
/// always the shared embedded catalog.
fn default_embedded_function_catalog() -> Option<Arc<dyn FunctionCatalog>> {
    let embedded = embedded_function_catalog_arc();
    Some(embedded)
}
281
282#[cfg(not(any(
283    feature = "function-catalog-clickhouse",
284    feature = "function-catalog-duckdb",
285    feature = "function-catalog-all-dialects"
286)))]
287fn default_embedded_function_catalog() -> Option<Arc<dyn FunctionCatalog>> {
288    None
289}
290
/// Validation error/warning codes used by schema-aware validation.
///
/// Constants named `E_*` hold `E###` codes and constants named `W_*` hold
/// `W###` codes. Within this file the `E` codes are passed to
/// `ValidationError::error` and the `W` codes to `ValidationError::warning`.
pub mod validation_codes {
    // Existing schema and semantic checks.
    pub const E_PARSE_OR_OPTIONS: &str = "E000";
    pub const E_UNKNOWN_TABLE: &str = "E200";
    pub const E_UNKNOWN_COLUMN: &str = "E201";
    pub const E_UNKNOWN_FUNCTION: &str = "E202";
    pub const E_INVALID_FUNCTION_ARITY: &str = "E203";

    // Semantic warnings (W001..W004).
    pub const W_SELECT_STAR: &str = "W001";
    pub const W_AGGREGATE_WITHOUT_GROUP_BY: &str = "W002";
    pub const W_DISTINCT_ORDER_BY: &str = "W003";
    pub const W_LIMIT_WITHOUT_ORDER_BY: &str = "W004";

    // Phase 2 (type checks): E210-E219, W210-W219.
    pub const E_TYPE_MISMATCH: &str = "E210";
    pub const E_INVALID_PREDICATE_TYPE: &str = "E211";
    pub const E_INVALID_ARITHMETIC_TYPE: &str = "E212";
    pub const E_INVALID_FUNCTION_ARGUMENT_TYPE: &str = "E213";
    pub const E_INVALID_ASSIGNMENT_TYPE: &str = "E214";
    pub const E_SETOP_TYPE_MISMATCH: &str = "E215";
    pub const E_SETOP_ARITY_MISMATCH: &str = "E216";
    pub const E_INCOMPATIBLE_COMPARISON_TYPES: &str = "E217";
    pub const E_INVALID_CAST: &str = "E218";
    pub const E_UNKNOWN_INFERRED_TYPE: &str = "E219";

    pub const W_IMPLICIT_CAST_COMPARISON: &str = "W210";
    pub const W_IMPLICIT_CAST_ARITHMETIC: &str = "W211";
    pub const W_IMPLICIT_CAST_ASSIGNMENT: &str = "W212";
    pub const W_LOSSY_CAST: &str = "W213";
    pub const W_SETOP_IMPLICIT_COERCION: &str = "W214";
    pub const W_PREDICATE_NULLABILITY: &str = "W215";
    pub const W_FUNCTION_ARGUMENT_COERCION: &str = "W216";
    pub const W_AGGREGATE_TYPE_COERCION: &str = "W217";
    pub const W_POSSIBLE_OVERFLOW: &str = "W218";
    pub const W_POSSIBLE_TRUNCATION: &str = "W219";

    // Phase 2 (reference checks): E220-E229, W220-W229.
    // E220 is the strict-mode counterpart of W222 (see `reference_issue`).
    pub const E_INVALID_FOREIGN_KEY_REFERENCE: &str = "E220";
    pub const E_AMBIGUOUS_COLUMN_REFERENCE: &str = "E221";
    pub const E_UNRESOLVED_REFERENCE: &str = "E222";
    pub const E_CTE_COLUMN_COUNT_MISMATCH: &str = "E223";
    pub const E_MISSING_REFERENCE_TARGET: &str = "E224";

    pub const W_CARTESIAN_JOIN: &str = "W220";
    pub const W_JOIN_NOT_USING_DECLARED_REFERENCE: &str = "W221";
    pub const W_WEAK_REFERENCE_INTEGRITY: &str = "W222";
}
339
/// Canonical type family used by schema/type checks.
///
/// Serialized in `snake_case` (e.g. `TypeFamily::Timestamp` as `timestamp`).
/// The examples on each variant are type names that `canonical_type_family`
/// maps to it; they are illustrative, not exhaustive.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum TypeFamily {
    /// Empty or unrecognized type name.
    Unknown,
    /// `bool`, `boolean`.
    Boolean,
    /// Integer types such as `int`, `bigint`, `serial`, `uint64`.
    Integer,
    /// Non-integer numeric types such as `decimal`, `float`, `double precision`.
    Numeric,
    /// Character types such as `varchar`, `text`, `string`.
    String,
    /// Byte-string types such as `binary`, `blob`, `bytea`.
    Binary,
    /// `date`.
    Date,
    /// `time`.
    Time,
    /// Timestamp types such as `timestamp`, `timestamptz`, `datetime`.
    Timestamp,
    /// `interval`.
    Interval,
    /// `json`, `jsonb`, `variant`.
    Json,
    /// `uuid`, `uniqueidentifier`.
    Uuid,
    /// Array/list types, including `array<...>`, `list(...)` and `T[]` spellings.
    Array,
    /// Map types, including `map<...>` and `map(...)` spellings.
    Map,
    /// Struct-like types such as `struct`, `row`, `record`, `object`.
    Struct,
}
360
361impl TypeFamily {
362    pub fn is_numeric(self) -> bool {
363        matches!(self, TypeFamily::Integer | TypeFamily::Numeric)
364    }
365
366    pub fn is_temporal(self) -> bool {
367        matches!(
368            self,
369            TypeFamily::Date | TypeFamily::Time | TypeFamily::Timestamp | TypeFamily::Interval
370        )
371    }
372}
373
/// Per-table lookup data derived from a `SchemaTable` by `build_schema_map`.
#[derive(Debug, Clone)]
struct TableSchemaEntry {
    /// Lowercased column name mapped to its canonical type family.
    columns: HashMap<String, TypeFamily>,
    /// Lowercased column names in schema declaration order.
    column_order: Vec<String>,
}
379
/// Returns the lowercase form of `s`; used throughout this module to
/// normalize identifiers before they are used as lookup keys.
fn lower(s: &str) -> String {
    str::to_lowercase(s)
}
383
/// Splits a type written as `base(inner)` into its trimmed `base` and `inner`
/// parts.
///
/// The split happens at the first `(`, and the string must end with `)`;
/// anything else (no parenthesis, or trailing text after the closing
/// parenthesis) yields `None`. Nested parentheses stay inside `inner`.
fn split_type_args(data_type: &str) -> Option<(&str, &str)> {
    let without_close = data_type.strip_suffix(')')?;
    let (base, inner) = without_close.split_once('(')?;
    Some((base.trim(), inner.trim()))
}
393
394/// Canonicalize a schema type string into a stable `TypeFamily`.
395pub fn canonical_type_family(data_type: &str) -> TypeFamily {
396    let trimmed = data_type
397        .trim()
398        .trim_matches(|c| c == '"' || c == '\'' || c == '`');
399    if trimmed.is_empty() {
400        return TypeFamily::Unknown;
401    }
402
403    // Normalize whitespace and lowercase for matching.
404    let normalized = trimmed
405        .split_whitespace()
406        .collect::<Vec<_>>()
407        .join(" ")
408        .to_lowercase();
409
410    // Strip common wrappers first.
411    if let Some((base, inner)) = split_type_args(&normalized) {
412        match base {
413            "nullable" | "lowcardinality" => return canonical_type_family(inner),
414            "array" | "list" => return TypeFamily::Array,
415            "map" => return TypeFamily::Map,
416            "struct" | "row" | "record" => return TypeFamily::Struct,
417            _ => {}
418        }
419    }
420
421    if normalized.starts_with("array<") || normalized.starts_with("list<") {
422        return TypeFamily::Array;
423    }
424    if normalized.starts_with("map<") {
425        return TypeFamily::Map;
426    }
427    if normalized.starts_with("struct<")
428        || normalized.starts_with("row<")
429        || normalized.starts_with("record<")
430        || normalized.starts_with("object<")
431    {
432        return TypeFamily::Struct;
433    }
434
435    if normalized.ends_with("[]") {
436        return TypeFamily::Array;
437    }
438
439    // Remove parameter list if present, e.g. VARCHAR(255), DECIMAL(10,2).
440    let mut base = normalized
441        .split('(')
442        .next()
443        .unwrap_or("")
444        .trim()
445        .to_string();
446    if base.is_empty() {
447        return TypeFamily::Unknown;
448    }
449
450    base = base.strip_prefix("unsigned ").unwrap_or(&base).to_string();
451    base = base.strip_suffix(" unsigned").unwrap_or(&base).to_string();
452
453    match base.as_str() {
454        "bool" | "boolean" => TypeFamily::Boolean,
455        "tinyint" | "smallint" | "int2" | "int" | "integer" | "int4" | "int8" | "bigint"
456        | "serial" | "smallserial" | "bigserial" | "utinyint" | "usmallint" | "uinteger"
457        | "ubigint" | "uint8" | "uint16" | "uint32" | "uint64" | "int16" | "int32" | "int64" => {
458            TypeFamily::Integer
459        }
460        "numeric" | "decimal" | "dec" | "number" | "float" | "float4" | "float8" | "real"
461        | "double" | "double precision" | "bfloat16" | "float16" | "float32" | "float64" => {
462            TypeFamily::Numeric
463        }
464        "char" | "character" | "varchar" | "character varying" | "nchar" | "nvarchar" | "text"
465        | "string" | "clob" => TypeFamily::String,
466        "binary" | "varbinary" | "blob" | "bytea" | "bytes" => TypeFamily::Binary,
467        "date" => TypeFamily::Date,
468        "time" => TypeFamily::Time,
469        "timestamp"
470        | "timestamptz"
471        | "datetime"
472        | "datetime2"
473        | "smalldatetime"
474        | "timestamp with time zone"
475        | "timestamp without time zone" => TypeFamily::Timestamp,
476        "interval" => TypeFamily::Interval,
477        "json" | "jsonb" | "variant" => TypeFamily::Json,
478        "uuid" | "uniqueidentifier" => TypeFamily::Uuid,
479        "array" | "list" => TypeFamily::Array,
480        "map" => TypeFamily::Map,
481        "struct" | "row" | "record" | "object" => TypeFamily::Struct,
482        _ => TypeFamily::Unknown,
483    }
484}
485
486fn build_schema_map(schema: &ValidationSchema) -> HashMap<String, TableSchemaEntry> {
487    let mut map = HashMap::new();
488
489    for table in &schema.tables {
490        let column_order: Vec<String> = table.columns.iter().map(|c| lower(&c.name)).collect();
491        let columns: HashMap<String, TypeFamily> = table
492            .columns
493            .iter()
494            .map(|c| (lower(&c.name), canonical_type_family(&c.data_type)))
495            .collect();
496        let entry = TableSchemaEntry {
497            columns,
498            column_order,
499        };
500
501        let simple_name = lower(&table.name);
502        map.insert(simple_name, entry.clone());
503
504        if let Some(table_schema) = &table.schema {
505            map.insert(
506                format!("{}.{}", lower(table_schema), lower(&table.name)),
507                entry.clone(),
508            );
509        }
510
511        for alias in &table.aliases {
512            map.insert(lower(alias), entry.clone());
513        }
514    }
515
516    map
517}
518
519fn type_family_to_data_type(family: TypeFamily) -> DataType {
520    match family {
521        TypeFamily::Unknown => DataType::Unknown,
522        TypeFamily::Boolean => DataType::Boolean,
523        TypeFamily::Integer => DataType::Int {
524            length: None,
525            integer_spelling: false,
526        },
527        TypeFamily::Numeric => DataType::Double {
528            precision: None,
529            scale: None,
530        },
531        TypeFamily::String => DataType::VarChar {
532            length: None,
533            parenthesized_length: false,
534        },
535        TypeFamily::Binary => DataType::VarBinary { length: None },
536        TypeFamily::Date => DataType::Date,
537        TypeFamily::Time => DataType::Time {
538            precision: None,
539            timezone: false,
540        },
541        TypeFamily::Timestamp => DataType::Timestamp {
542            precision: None,
543            timezone: false,
544        },
545        TypeFamily::Interval => DataType::Interval {
546            unit: None,
547            to: None,
548        },
549        TypeFamily::Json => DataType::Json,
550        TypeFamily::Uuid => DataType::Uuid,
551        TypeFamily::Array => DataType::Array {
552            element_type: Box::new(DataType::Unknown),
553            dimension: None,
554        },
555        TypeFamily::Map => DataType::Map {
556            key_type: Box::new(DataType::Unknown),
557            value_type: Box::new(DataType::Unknown),
558        },
559        TypeFamily::Struct => DataType::Struct {
560            fields: Vec::new(),
561            nested: false,
562        },
563    }
564}
565
566fn build_resolver_schema(schema: &ValidationSchema) -> MappingSchema {
567    let mut mapping = MappingSchema::new();
568
569    for table in &schema.tables {
570        let columns: Vec<(String, DataType)> = table
571            .columns
572            .iter()
573            .map(|column| {
574                (
575                    lower(&column.name),
576                    type_family_to_data_type(canonical_type_family(&column.data_type)),
577                )
578            })
579            .collect();
580
581        let mut table_names = Vec::new();
582        table_names.push(lower(&table.name));
583        if let Some(table_schema) = &table.schema {
584            table_names.push(format!("{}.{}", lower(table_schema), lower(&table.name)));
585        }
586        for alias in &table.aliases {
587            table_names.push(lower(alias));
588        }
589
590        let mut dedup = HashSet::new();
591        for table_name in table_names {
592            if dedup.insert(table_name.clone()) {
593                let _ = mapping.add_table(&table_name, &columns, None);
594            }
595        }
596    }
597
598    mapping
599}
600
601fn collect_cte_aliases(expr: &Expression) -> HashSet<String> {
602    let mut aliases = HashSet::new();
603
604    for node in expr.dfs() {
605        match node {
606            Expression::Select(select) => {
607                if let Some(with) = &select.with {
608                    for cte in &with.ctes {
609                        aliases.insert(lower(&cte.alias.name));
610                    }
611                }
612            }
613            Expression::Insert(insert) => {
614                if let Some(with) = &insert.with {
615                    for cte in &with.ctes {
616                        aliases.insert(lower(&cte.alias.name));
617                    }
618                }
619            }
620            Expression::Update(update) => {
621                if let Some(with) = &update.with {
622                    for cte in &with.ctes {
623                        aliases.insert(lower(&cte.alias.name));
624                    }
625                }
626            }
627            Expression::Delete(delete) => {
628                if let Some(with) = &delete.with {
629                    for cte in &with.ctes {
630                        aliases.insert(lower(&cte.alias.name));
631                    }
632                }
633            }
634            Expression::Union(union) => {
635                if let Some(with) = &union.with {
636                    for cte in &with.ctes {
637                        aliases.insert(lower(&cte.alias.name));
638                    }
639                }
640            }
641            Expression::Intersect(intersect) => {
642                if let Some(with) = &intersect.with {
643                    for cte in &with.ctes {
644                        aliases.insert(lower(&cte.alias.name));
645                    }
646                }
647            }
648            Expression::Except(except) => {
649                if let Some(with) = &except.with {
650                    for cte in &with.ctes {
651                        aliases.insert(lower(&cte.alias.name));
652                    }
653                }
654            }
655            Expression::Merge(merge) => {
656                if let Some(with_) = &merge.with_ {
657                    if let Expression::With(with_clause) = with_.as_ref() {
658                        for cte in &with_clause.ctes {
659                            aliases.insert(lower(&cte.alias.name));
660                        }
661                    }
662                }
663            }
664            _ => {}
665        }
666    }
667
668    aliases
669}
670
671fn table_ref_candidates(table: &TableRef) -> Vec<String> {
672    let name = lower(&table.name.name);
673    let schema = table.schema.as_ref().map(|s| lower(&s.name));
674    let catalog = table.catalog.as_ref().map(|c| lower(&c.name));
675
676    let mut candidates = Vec::new();
677    if let (Some(catalog), Some(schema)) = (&catalog, &schema) {
678        candidates.push(format!("{}.{}.{}", catalog, schema, name));
679    }
680    if let Some(schema) = &schema {
681        candidates.push(format!("{}.{}", schema, name));
682    }
683    candidates.push(name);
684    candidates
685}
686
687fn table_ref_display_name(table: &TableRef) -> String {
688    let mut parts = Vec::new();
689    if let Some(catalog) = &table.catalog {
690        parts.push(catalog.name.clone());
691    }
692    if let Some(schema) = &table.schema {
693        parts.push(schema.name.clone());
694    }
695    parts.push(table.name.name.clone());
696    parts.join(".")
697}
698
/// Tables referenced by a statement, as gathered by `collect_type_check_context`.
#[derive(Debug, Default, Clone)]
struct TypeCheckContext {
    /// Schema-map keys of the tables the statement references.
    referenced_tables: HashSet<String>,
    /// Lowercased table name or alias mapped to the schema-map key it resolves to.
    table_aliases: HashMap<String, String>,
}
704
705fn type_family_name(family: TypeFamily) -> &'static str {
706    match family {
707        TypeFamily::Unknown => "unknown",
708        TypeFamily::Boolean => "boolean",
709        TypeFamily::Integer => "integer",
710        TypeFamily::Numeric => "numeric",
711        TypeFamily::String => "string",
712        TypeFamily::Binary => "binary",
713        TypeFamily::Date => "date",
714        TypeFamily::Time => "time",
715        TypeFamily::Timestamp => "timestamp",
716        TypeFamily::Interval => "interval",
717        TypeFamily::Json => "json",
718        TypeFamily::Uuid => "uuid",
719        TypeFamily::Array => "array",
720        TypeFamily::Map => "map",
721        TypeFamily::Struct => "struct",
722    }
723}
724
725fn is_string_like(family: TypeFamily) -> bool {
726    matches!(family, TypeFamily::String)
727}
728
729fn is_string_or_binary(family: TypeFamily) -> bool {
730    matches!(family, TypeFamily::String | TypeFamily::Binary)
731}
732
733fn type_issue(
734    strict: bool,
735    error_code: &str,
736    warning_code: &str,
737    message: impl Into<String>,
738) -> ValidationError {
739    if strict {
740        ValidationError::error(message.into(), error_code)
741    } else {
742        ValidationError::warning(message.into(), warning_code)
743    }
744}
745
746fn data_type_family(data_type: &DataType) -> TypeFamily {
747    match data_type {
748        DataType::Boolean => TypeFamily::Boolean,
749        DataType::TinyInt { .. }
750        | DataType::SmallInt { .. }
751        | DataType::Int { .. }
752        | DataType::BigInt { .. } => TypeFamily::Integer,
753        DataType::Float { .. } | DataType::Double { .. } | DataType::Decimal { .. } => {
754            TypeFamily::Numeric
755        }
756        DataType::Char { .. }
757        | DataType::VarChar { .. }
758        | DataType::String { .. }
759        | DataType::Text
760        | DataType::TextWithLength { .. }
761        | DataType::CharacterSet { .. } => TypeFamily::String,
762        DataType::Binary { .. } | DataType::VarBinary { .. } | DataType::Blob => TypeFamily::Binary,
763        DataType::Date => TypeFamily::Date,
764        DataType::Time { .. } => TypeFamily::Time,
765        DataType::Timestamp { .. } => TypeFamily::Timestamp,
766        DataType::Interval { .. } => TypeFamily::Interval,
767        DataType::Json | DataType::JsonB => TypeFamily::Json,
768        DataType::Uuid => TypeFamily::Uuid,
769        DataType::Array { .. } | DataType::List { .. } => TypeFamily::Array,
770        DataType::Map { .. } => TypeFamily::Map,
771        DataType::Struct { .. } | DataType::Object { .. } | DataType::Union { .. } => {
772            TypeFamily::Struct
773        }
774        DataType::Nullable { inner } => data_type_family(inner),
775        DataType::Custom { name } => canonical_type_family(name),
776        DataType::Unknown => TypeFamily::Unknown,
777        DataType::Bit { .. } | DataType::VarBit { .. } => TypeFamily::Binary,
778        DataType::Enum { .. } | DataType::Set { .. } => TypeFamily::String,
779        DataType::Vector { .. } => TypeFamily::Array,
780        DataType::Geometry { .. } | DataType::Geography { .. } => TypeFamily::Struct,
781    }
782}
783
784fn collect_type_check_context(
785    stmt: &Expression,
786    schema_map: &HashMap<String, TableSchemaEntry>,
787) -> TypeCheckContext {
788    fn add_table_to_context(
789        table: &TableRef,
790        schema_map: &HashMap<String, TableSchemaEntry>,
791        context: &mut TypeCheckContext,
792    ) {
793        let resolved_key = table_ref_candidates(table)
794            .into_iter()
795            .find(|k| schema_map.contains_key(k));
796
797        let Some(table_key) = resolved_key else {
798            return;
799        };
800
801        context.referenced_tables.insert(table_key.clone());
802        context
803            .table_aliases
804            .insert(lower(&table.name.name), table_key.clone());
805        if let Some(alias) = &table.alias {
806            context
807                .table_aliases
808                .insert(lower(&alias.name), table_key.clone());
809        }
810    }
811
812    let mut context = TypeCheckContext::default();
813    let cte_aliases = collect_cte_aliases(stmt);
814
815    for node in stmt.find_all(|e| matches!(e, Expression::Table(_))) {
816        let Expression::Table(table) = node else {
817            continue;
818        };
819
820        if cte_aliases.contains(&lower(&table.name.name)) {
821            continue;
822        }
823
824        add_table_to_context(table, schema_map, &mut context);
825    }
826
827    // Seed DML target tables explicitly because they are struct fields and may
828    // not appear as standalone Expression::Table nodes in traversal output.
829    match stmt {
830        Expression::Insert(insert) => {
831            add_table_to_context(&insert.table, schema_map, &mut context);
832        }
833        Expression::Update(update) => {
834            add_table_to_context(&update.table, schema_map, &mut context);
835            for table in &update.extra_tables {
836                add_table_to_context(table, schema_map, &mut context);
837            }
838        }
839        Expression::Delete(delete) => {
840            add_table_to_context(&delete.table, schema_map, &mut context);
841            for table in &delete.using {
842                add_table_to_context(table, schema_map, &mut context);
843            }
844            for table in &delete.tables {
845                add_table_to_context(table, schema_map, &mut context);
846            }
847        }
848        _ => {}
849    }
850
851    context
852}
853
854fn resolve_table_schema_entry<'a>(
855    table: &TableRef,
856    schema_map: &'a HashMap<String, TableSchemaEntry>,
857) -> Option<(String, &'a TableSchemaEntry)> {
858    let key = table_ref_candidates(table)
859        .into_iter()
860        .find(|k| schema_map.contains_key(k))?;
861    let entry = schema_map.get(&key)?;
862    Some((key, entry))
863}
864
865fn reference_issue(strict: bool, message: impl Into<String>) -> ValidationError {
866    if strict {
867        ValidationError::error(
868            message.into(),
869            validation_codes::E_INVALID_FOREIGN_KEY_REFERENCE,
870        )
871    } else {
872        ValidationError::warning(message.into(), validation_codes::W_WEAK_REFERENCE_INTEGRITY)
873    }
874}
875
876fn reference_table_candidates(
877    table_name: &str,
878    explicit_schema: Option<&str>,
879    source_schema: Option<&str>,
880) -> Vec<String> {
881    let mut candidates = Vec::new();
882    let raw = lower(table_name);
883
884    if let Some(schema) = explicit_schema {
885        candidates.push(format!("{}.{}", lower(schema), raw));
886    }
887
888    if raw.contains('.') {
889        candidates.push(raw.clone());
890        if let Some(last) = raw.rsplit('.').next() {
891            candidates.push(last.to_string());
892        }
893    } else {
894        if let Some(schema) = source_schema {
895            candidates.push(format!("{}.{}", lower(schema), raw));
896        }
897        candidates.push(raw);
898    }
899
900    let mut dedup = HashSet::new();
901    candidates
902        .into_iter()
903        .filter(|c| dedup.insert(c.clone()))
904        .collect()
905}
906
907fn resolve_reference_table_key(
908    table_name: &str,
909    explicit_schema: Option<&str>,
910    source_schema: Option<&str>,
911    schema_map: &HashMap<String, TableSchemaEntry>,
912) -> Option<String> {
913    reference_table_candidates(table_name, explicit_schema, source_schema)
914        .into_iter()
915        .find(|candidate| schema_map.contains_key(candidate))
916}
917
918fn key_types_compatible(source: TypeFamily, target: TypeFamily) -> bool {
919    if source == TypeFamily::Unknown || target == TypeFamily::Unknown {
920        return true;
921    }
922    if source == target {
923        return true;
924    }
925    if source.is_numeric() && target.is_numeric() {
926        return true;
927    }
928    if source.is_temporal() && target.is_temporal() {
929        return true;
930    }
931    false
932}
933
934fn table_key_hints(table: &SchemaTable) -> HashSet<String> {
935    let mut hints = HashSet::new();
936    for column in &table.columns {
937        if column.primary_key || column.unique {
938            hints.insert(lower(&column.name));
939        }
940    }
941    for key_col in &table.primary_key {
942        hints.insert(lower(key_col));
943    }
944    for group in &table.unique_keys {
945        if group.len() == 1 {
946            if let Some(col) = group.first() {
947                hints.insert(lower(col));
948            }
949        }
950    }
951    hints
952}
953
954fn check_reference_integrity(
955    schema: &ValidationSchema,
956    schema_map: &HashMap<String, TableSchemaEntry>,
957    strict: bool,
958) -> Vec<ValidationError> {
959    let mut errors = Vec::new();
960
961    let mut key_hints_lookup: HashMap<String, HashSet<String>> = HashMap::new();
962    for table in &schema.tables {
963        let simple = lower(&table.name);
964        key_hints_lookup.insert(simple, table_key_hints(table));
965        if let Some(schema_name) = &table.schema {
966            let qualified = format!("{}.{}", lower(schema_name), lower(&table.name));
967            key_hints_lookup.insert(qualified, table_key_hints(table));
968        }
969    }
970
971    for table in &schema.tables {
972        let source_table_display = if let Some(schema_name) = &table.schema {
973            format!("{}.{}", schema_name, table.name)
974        } else {
975            table.name.clone()
976        };
977        let source_schema = table.schema.as_deref();
978        let source_columns: HashMap<String, TypeFamily> = table
979            .columns
980            .iter()
981            .map(|col| (lower(&col.name), canonical_type_family(&col.data_type)))
982            .collect();
983
984        for source_col in &table.columns {
985            let Some(reference) = &source_col.references else {
986                continue;
987            };
988            let source_type = canonical_type_family(&source_col.data_type);
989
990            let Some(target_key) = resolve_reference_table_key(
991                &reference.table,
992                reference.schema.as_deref(),
993                source_schema,
994                schema_map,
995            ) else {
996                errors.push(reference_issue(
997                    strict,
998                    format!(
999                        "Foreign key reference '{}.{}' points to unknown table '{}'",
1000                        source_table_display, source_col.name, reference.table
1001                    ),
1002                ));
1003                continue;
1004            };
1005
1006            let target_column = lower(&reference.column);
1007            let Some(target_entry) = schema_map.get(&target_key) else {
1008                errors.push(reference_issue(
1009                    strict,
1010                    format!(
1011                        "Foreign key reference '{}.{}' points to unknown table '{}'",
1012                        source_table_display, source_col.name, reference.table
1013                    ),
1014                ));
1015                continue;
1016            };
1017
1018            let Some(target_type) = target_entry.columns.get(&target_column).copied() else {
1019                errors.push(reference_issue(
1020                    strict,
1021                    format!(
1022                        "Foreign key reference '{}.{}' points to unknown column '{}.{}'",
1023                        source_table_display, source_col.name, target_key, reference.column
1024                    ),
1025                ));
1026                continue;
1027            };
1028
1029            if !key_types_compatible(source_type, target_type) {
1030                errors.push(reference_issue(
1031                    strict,
1032                    format!(
1033                        "Foreign key type mismatch for '{}.{}' -> '{}.{}': {} vs {}",
1034                        source_table_display,
1035                        source_col.name,
1036                        target_key,
1037                        reference.column,
1038                        type_family_name(source_type),
1039                        type_family_name(target_type)
1040                    ),
1041                ));
1042            }
1043
1044            if let Some(target_key_hints) = key_hints_lookup.get(&target_key) {
1045                if !target_key_hints.contains(&target_column) {
1046                    errors.push(ValidationError::warning(
1047                        format!(
1048                            "Referenced column '{}.{}' is not marked as primary/unique key",
1049                            target_key, reference.column
1050                        ),
1051                        validation_codes::W_WEAK_REFERENCE_INTEGRITY,
1052                    ));
1053                }
1054            }
1055        }
1056
1057        for foreign_key in &table.foreign_keys {
1058            if foreign_key.columns.is_empty() || foreign_key.references.columns.is_empty() {
1059                errors.push(reference_issue(
1060                    strict,
1061                    format!(
1062                        "Table-level foreign key on '{}' has empty source or target column list",
1063                        source_table_display
1064                    ),
1065                ));
1066                continue;
1067            }
1068            if foreign_key.columns.len() != foreign_key.references.columns.len() {
1069                errors.push(reference_issue(
1070                    strict,
1071                    format!(
1072                        "Table-level foreign key on '{}' has {} source columns but {} target columns",
1073                        source_table_display,
1074                        foreign_key.columns.len(),
1075                        foreign_key.references.columns.len()
1076                    ),
1077                ));
1078                continue;
1079            }
1080
1081            let Some(target_key) = resolve_reference_table_key(
1082                &foreign_key.references.table,
1083                foreign_key.references.schema.as_deref(),
1084                source_schema,
1085                schema_map,
1086            ) else {
1087                errors.push(reference_issue(
1088                    strict,
1089                    format!(
1090                        "Table-level foreign key on '{}' points to unknown table '{}'",
1091                        source_table_display, foreign_key.references.table
1092                    ),
1093                ));
1094                continue;
1095            };
1096
1097            let Some(target_entry) = schema_map.get(&target_key) else {
1098                errors.push(reference_issue(
1099                    strict,
1100                    format!(
1101                        "Table-level foreign key on '{}' points to unknown table '{}'",
1102                        source_table_display, foreign_key.references.table
1103                    ),
1104                ));
1105                continue;
1106            };
1107
1108            for (source_col, target_col) in foreign_key
1109                .columns
1110                .iter()
1111                .zip(foreign_key.references.columns.iter())
1112            {
1113                let source_col_name = lower(source_col);
1114                let target_col_name = lower(target_col);
1115
1116                let Some(source_type) = source_columns.get(&source_col_name).copied() else {
1117                    errors.push(reference_issue(
1118                        strict,
1119                        format!(
1120                            "Table-level foreign key on '{}' references unknown source column '{}'",
1121                            source_table_display, source_col
1122                        ),
1123                    ));
1124                    continue;
1125                };
1126
1127                let Some(target_type) = target_entry.columns.get(&target_col_name).copied() else {
1128                    errors.push(reference_issue(
1129                        strict,
1130                        format!(
1131                            "Table-level foreign key on '{}' references unknown target column '{}.{}'",
1132                            source_table_display, target_key, target_col
1133                        ),
1134                    ));
1135                    continue;
1136                };
1137
1138                if !key_types_compatible(source_type, target_type) {
1139                    errors.push(reference_issue(
1140                        strict,
1141                        format!(
1142                            "Table-level foreign key type mismatch '{}.{}' -> '{}.{}': {} vs {}",
1143                            source_table_display,
1144                            source_col,
1145                            target_key,
1146                            target_col,
1147                            type_family_name(source_type),
1148                            type_family_name(target_type)
1149                        ),
1150                    ));
1151                }
1152
1153                if let Some(target_key_hints) = key_hints_lookup.get(&target_key) {
1154                    if !target_key_hints.contains(&target_col_name) {
1155                        errors.push(ValidationError::warning(
1156                            format!(
1157                                "Referenced column '{}.{}' is not marked as primary/unique key",
1158                                target_key, target_col
1159                            ),
1160                            validation_codes::W_WEAK_REFERENCE_INTEGRITY,
1161                        ));
1162                    }
1163                }
1164            }
1165        }
1166    }
1167
1168    errors
1169}
1170
1171fn resolve_unqualified_column_type(
1172    column_name: &str,
1173    schema_map: &HashMap<String, TableSchemaEntry>,
1174    context: &TypeCheckContext,
1175) -> TypeFamily {
1176    let candidate_tables: Vec<&String> = if !context.referenced_tables.is_empty() {
1177        context.referenced_tables.iter().collect()
1178    } else {
1179        schema_map.keys().collect()
1180    };
1181
1182    let mut families = HashSet::new();
1183    for table_name in candidate_tables {
1184        if let Some(table_schema) = schema_map.get(table_name) {
1185            if let Some(family) = table_schema.columns.get(column_name) {
1186                families.insert(*family);
1187            }
1188        }
1189    }
1190
1191    if families.len() == 1 {
1192        *families.iter().next().unwrap_or(&TypeFamily::Unknown)
1193    } else {
1194        TypeFamily::Unknown
1195    }
1196}
1197
1198fn resolve_column_type(
1199    column: &Column,
1200    schema_map: &HashMap<String, TableSchemaEntry>,
1201    context: &TypeCheckContext,
1202) -> TypeFamily {
1203    let column_name = lower(&column.name.name);
1204    if column_name.is_empty() {
1205        return TypeFamily::Unknown;
1206    }
1207
1208    if let Some(table) = &column.table {
1209        let mut table_key = lower(&table.name);
1210        if let Some(mapped) = context.table_aliases.get(&table_key) {
1211            table_key = mapped.clone();
1212        }
1213
1214        return schema_map
1215            .get(&table_key)
1216            .and_then(|t| t.columns.get(&column_name))
1217            .copied()
1218            .unwrap_or(TypeFamily::Unknown);
1219    }
1220
1221    resolve_unqualified_column_type(&column_name, schema_map, context)
1222}
1223
1224struct TypeInferenceSchema<'a> {
1225    schema_map: &'a HashMap<String, TableSchemaEntry>,
1226    context: &'a TypeCheckContext,
1227}
1228
1229impl TypeInferenceSchema<'_> {
1230    fn resolve_table_key(&self, table: &str) -> Option<String> {
1231        let mut table_key = lower(table);
1232        if let Some(mapped) = self.context.table_aliases.get(&table_key) {
1233            table_key = mapped.clone();
1234        }
1235        if self.schema_map.contains_key(&table_key) {
1236            Some(table_key)
1237        } else {
1238            None
1239        }
1240    }
1241}
1242
1243impl SqlSchema for TypeInferenceSchema<'_> {
1244    fn dialect(&self) -> Option<DialectType> {
1245        None
1246    }
1247
1248    fn add_table(
1249        &mut self,
1250        _table: &str,
1251        _columns: &[(String, DataType)],
1252        _dialect: Option<DialectType>,
1253    ) -> SchemaResult<()> {
1254        Err(SchemaError::InvalidStructure(
1255            "Type inference schema is read-only".to_string(),
1256        ))
1257    }
1258
1259    fn column_names(&self, table: &str) -> SchemaResult<Vec<String>> {
1260        let table_key = self
1261            .resolve_table_key(table)
1262            .ok_or_else(|| SchemaError::TableNotFound(table.to_string()))?;
1263        let entry = self
1264            .schema_map
1265            .get(&table_key)
1266            .ok_or_else(|| SchemaError::TableNotFound(table.to_string()))?;
1267        Ok(entry.column_order.clone())
1268    }
1269
1270    fn get_column_type(&self, table: &str, column: &str) -> SchemaResult<DataType> {
1271        let col_name = lower(column);
1272        if table.is_empty() {
1273            let family = resolve_unqualified_column_type(&col_name, self.schema_map, self.context);
1274            return if family == TypeFamily::Unknown {
1275                Err(SchemaError::ColumnNotFound {
1276                    table: "<unqualified>".to_string(),
1277                    column: column.to_string(),
1278                })
1279            } else {
1280                Ok(type_family_to_data_type(family))
1281            };
1282        }
1283
1284        let table_key = self
1285            .resolve_table_key(table)
1286            .ok_or_else(|| SchemaError::TableNotFound(table.to_string()))?;
1287        let entry = self
1288            .schema_map
1289            .get(&table_key)
1290            .ok_or_else(|| SchemaError::TableNotFound(table.to_string()))?;
1291        let family =
1292            entry
1293                .columns
1294                .get(&col_name)
1295                .copied()
1296                .ok_or_else(|| SchemaError::ColumnNotFound {
1297                    table: table.to_string(),
1298                    column: column.to_string(),
1299                })?;
1300        Ok(type_family_to_data_type(family))
1301    }
1302
1303    fn has_column(&self, table: &str, column: &str) -> bool {
1304        self.get_column_type(table, column).is_ok()
1305    }
1306
1307    fn supported_table_args(&self) -> &[&str] {
1308        TABLE_PARTS
1309    }
1310
1311    fn is_empty(&self) -> bool {
1312        self.schema_map.is_empty()
1313    }
1314
1315    fn depth(&self) -> usize {
1316        1
1317    }
1318}
1319
1320fn infer_expression_type_family(
1321    expr: &Expression,
1322    schema_map: &HashMap<String, TableSchemaEntry>,
1323    context: &TypeCheckContext,
1324) -> TypeFamily {
1325    let inference_schema = TypeInferenceSchema {
1326        schema_map,
1327        context,
1328    };
1329    if let Some(data_type) = annotate_types(expr, Some(&inference_schema), None) {
1330        let family = data_type_family(&data_type);
1331        if family != TypeFamily::Unknown {
1332            return family;
1333        }
1334    }
1335
1336    infer_expression_type_family_fallback(expr, schema_map, context)
1337}
1338
1339fn infer_expression_type_family_fallback(
1340    expr: &Expression,
1341    schema_map: &HashMap<String, TableSchemaEntry>,
1342    context: &TypeCheckContext,
1343) -> TypeFamily {
1344    match expr {
1345        Expression::Literal(literal) => match literal {
1346            crate::expressions::Literal::Number(value) => {
1347                if value.contains('.') || value.contains('e') || value.contains('E') {
1348                    TypeFamily::Numeric
1349                } else {
1350                    TypeFamily::Integer
1351                }
1352            }
1353            crate::expressions::Literal::HexNumber(_) => TypeFamily::Integer,
1354            crate::expressions::Literal::Date(_) => TypeFamily::Date,
1355            crate::expressions::Literal::Time(_) => TypeFamily::Time,
1356            crate::expressions::Literal::Timestamp(_)
1357            | crate::expressions::Literal::Datetime(_) => TypeFamily::Timestamp,
1358            crate::expressions::Literal::HexString(_)
1359            | crate::expressions::Literal::BitString(_)
1360            | crate::expressions::Literal::ByteString(_) => TypeFamily::Binary,
1361            _ => TypeFamily::String,
1362        },
1363        Expression::Boolean(_) => TypeFamily::Boolean,
1364        Expression::Null(_) => TypeFamily::Unknown,
1365        Expression::Column(column) => resolve_column_type(column, schema_map, context),
1366        Expression::Cast(cast) | Expression::TryCast(cast) | Expression::SafeCast(cast) => {
1367            data_type_family(&cast.to)
1368        }
1369        Expression::Alias(alias) => {
1370            infer_expression_type_family_fallback(&alias.this, schema_map, context)
1371        }
1372        Expression::Neg(unary) => {
1373            infer_expression_type_family_fallback(&unary.this, schema_map, context)
1374        }
1375        Expression::Add(op) | Expression::Sub(op) | Expression::Mul(op) => {
1376            let left = infer_expression_type_family_fallback(&op.left, schema_map, context);
1377            let right = infer_expression_type_family_fallback(&op.right, schema_map, context);
1378            if left == TypeFamily::Unknown || right == TypeFamily::Unknown {
1379                TypeFamily::Unknown
1380            } else if left == TypeFamily::Integer && right == TypeFamily::Integer {
1381                TypeFamily::Integer
1382            } else if left.is_numeric() && right.is_numeric() {
1383                TypeFamily::Numeric
1384            } else if left.is_temporal() || right.is_temporal() {
1385                left
1386            } else {
1387                TypeFamily::Unknown
1388            }
1389        }
1390        Expression::Div(_) | Expression::Mod(_) => TypeFamily::Numeric,
1391        Expression::Concat(_) => TypeFamily::String,
1392        Expression::Eq(_)
1393        | Expression::Neq(_)
1394        | Expression::Lt(_)
1395        | Expression::Lte(_)
1396        | Expression::Gt(_)
1397        | Expression::Gte(_)
1398        | Expression::Like(_)
1399        | Expression::ILike(_)
1400        | Expression::And(_)
1401        | Expression::Or(_)
1402        | Expression::Not(_)
1403        | Expression::Between(_)
1404        | Expression::In(_)
1405        | Expression::IsNull(_)
1406        | Expression::IsTrue(_)
1407        | Expression::IsFalse(_)
1408        | Expression::Is(_) => TypeFamily::Boolean,
1409        Expression::Length(_) => TypeFamily::Integer,
1410        Expression::Upper(_)
1411        | Expression::Lower(_)
1412        | Expression::Trim(_)
1413        | Expression::LTrim(_)
1414        | Expression::RTrim(_)
1415        | Expression::Replace(_)
1416        | Expression::Substring(_)
1417        | Expression::Left(_)
1418        | Expression::Right(_)
1419        | Expression::Repeat(_)
1420        | Expression::Lpad(_)
1421        | Expression::Rpad(_)
1422        | Expression::ConcatWs(_) => TypeFamily::String,
1423        Expression::Abs(_)
1424        | Expression::Round(_)
1425        | Expression::Floor(_)
1426        | Expression::Ceil(_)
1427        | Expression::Power(_)
1428        | Expression::Sqrt(_)
1429        | Expression::Cbrt(_)
1430        | Expression::Ln(_)
1431        | Expression::Log(_)
1432        | Expression::Exp(_) => TypeFamily::Numeric,
1433        Expression::DateAdd(_) | Expression::DateSub(_) | Expression::ToDate(_) => TypeFamily::Date,
1434        Expression::ToTimestamp(_) => TypeFamily::Timestamp,
1435        Expression::DateDiff(_) | Expression::Extract(_) => TypeFamily::Integer,
1436        Expression::CurrentDate(_) => TypeFamily::Date,
1437        Expression::CurrentTime(_) => TypeFamily::Time,
1438        Expression::CurrentTimestamp(_) | Expression::CurrentTimestampLTZ(_) => {
1439            TypeFamily::Timestamp
1440        }
1441        Expression::Interval(_) => TypeFamily::Interval,
1442        _ => TypeFamily::Unknown,
1443    }
1444}
1445
1446fn are_comparable(left: TypeFamily, right: TypeFamily) -> bool {
1447    if left == TypeFamily::Unknown || right == TypeFamily::Unknown {
1448        return true;
1449    }
1450    if left == right {
1451        return true;
1452    }
1453    if left.is_numeric() && right.is_numeric() {
1454        return true;
1455    }
1456    if left.is_temporal() && right.is_temporal() {
1457        return true;
1458    }
1459    false
1460}
1461
1462fn check_function_argument(
1463    errors: &mut Vec<ValidationError>,
1464    strict: bool,
1465    function_name: &str,
1466    arg_index: usize,
1467    family: TypeFamily,
1468    expected: &str,
1469    valid: bool,
1470) {
1471    if family == TypeFamily::Unknown || valid {
1472        return;
1473    }
1474
1475    errors.push(type_issue(
1476        strict,
1477        validation_codes::E_INVALID_FUNCTION_ARGUMENT_TYPE,
1478        validation_codes::W_FUNCTION_ARGUMENT_COERCION,
1479        format!(
1480            "Function '{}' argument {} expects {}, found {}",
1481            function_name,
1482            arg_index + 1,
1483            expected,
1484            type_family_name(family)
1485        ),
1486    ));
1487}
1488
1489fn function_dispatch_name(name: &str) -> String {
1490    let upper = name
1491        .rsplit('.')
1492        .next()
1493        .unwrap_or(name)
1494        .trim()
1495        .to_uppercase();
1496    lower(canonical_typed_function_name_upper(&upper))
1497}
1498
/// Returns the unqualified function name: the text after the last `.` in
/// `name` (or all of `name` when there is no `.`), with surrounding whitespace
/// trimmed. Case is left untouched.
fn function_base_name(name: &str) -> &str {
    let unqualified = match name.rfind('.') {
        // '.' is a single byte, so `dot + 1` is always a valid char boundary.
        Some(dot) => &name[dot + 1..],
        None => name,
    };
    unqualified.trim()
}
1502
1503fn check_generic_function(
1504    function: &Function,
1505    schema_map: &HashMap<String, TableSchemaEntry>,
1506    context: &TypeCheckContext,
1507    strict: bool,
1508    errors: &mut Vec<ValidationError>,
1509) {
1510    let name = function_dispatch_name(&function.name);
1511
1512    let arg_family = |index: usize| -> Option<TypeFamily> {
1513        function
1514            .args
1515            .get(index)
1516            .map(|arg| infer_expression_type_family(arg, schema_map, context))
1517    };
1518
1519    match name.as_str() {
1520        "abs" | "sqrt" | "cbrt" | "ln" | "exp" => {
1521            if let Some(family) = arg_family(0) {
1522                check_function_argument(
1523                    errors,
1524                    strict,
1525                    &name,
1526                    0,
1527                    family,
1528                    "a numeric argument",
1529                    family.is_numeric(),
1530                );
1531            }
1532        }
1533        "round" | "floor" | "ceil" | "ceiling" => {
1534            if let Some(family) = arg_family(0) {
1535                check_function_argument(
1536                    errors,
1537                    strict,
1538                    &name,
1539                    0,
1540                    family,
1541                    "a numeric argument",
1542                    family.is_numeric(),
1543                );
1544            }
1545            if let Some(family) = arg_family(1) {
1546                check_function_argument(
1547                    errors,
1548                    strict,
1549                    &name,
1550                    1,
1551                    family,
1552                    "a numeric argument",
1553                    family.is_numeric(),
1554                );
1555            }
1556        }
1557        "power" | "pow" => {
1558            for i in [0_usize, 1_usize] {
1559                if let Some(family) = arg_family(i) {
1560                    check_function_argument(
1561                        errors,
1562                        strict,
1563                        &name,
1564                        i,
1565                        family,
1566                        "a numeric argument",
1567                        family.is_numeric(),
1568                    );
1569                }
1570            }
1571        }
1572        "length" | "char_length" | "character_length" => {
1573            if let Some(family) = arg_family(0) {
1574                check_function_argument(
1575                    errors,
1576                    strict,
1577                    &name,
1578                    0,
1579                    family,
1580                    "a string or binary argument",
1581                    is_string_or_binary(family),
1582                );
1583            }
1584        }
1585        "upper" | "lower" | "trim" | "ltrim" | "rtrim" | "reverse" => {
1586            if let Some(family) = arg_family(0) {
1587                check_function_argument(
1588                    errors,
1589                    strict,
1590                    &name,
1591                    0,
1592                    family,
1593                    "a string argument",
1594                    is_string_like(family),
1595                );
1596            }
1597        }
1598        "substring" | "substr" => {
1599            if let Some(family) = arg_family(0) {
1600                check_function_argument(
1601                    errors,
1602                    strict,
1603                    &name,
1604                    0,
1605                    family,
1606                    "a string argument",
1607                    is_string_like(family),
1608                );
1609            }
1610            if let Some(family) = arg_family(1) {
1611                check_function_argument(
1612                    errors,
1613                    strict,
1614                    &name,
1615                    1,
1616                    family,
1617                    "a numeric argument",
1618                    family.is_numeric(),
1619                );
1620            }
1621            if let Some(family) = arg_family(2) {
1622                check_function_argument(
1623                    errors,
1624                    strict,
1625                    &name,
1626                    2,
1627                    family,
1628                    "a numeric argument",
1629                    family.is_numeric(),
1630                );
1631            }
1632        }
1633        "replace" => {
1634            for i in [0_usize, 1_usize, 2_usize] {
1635                if let Some(family) = arg_family(i) {
1636                    check_function_argument(
1637                        errors,
1638                        strict,
1639                        &name,
1640                        i,
1641                        family,
1642                        "a string argument",
1643                        is_string_like(family),
1644                    );
1645                }
1646            }
1647        }
1648        "left" | "right" | "repeat" | "lpad" | "rpad" => {
1649            if let Some(family) = arg_family(0) {
1650                check_function_argument(
1651                    errors,
1652                    strict,
1653                    &name,
1654                    0,
1655                    family,
1656                    "a string argument",
1657                    is_string_like(family),
1658                );
1659            }
1660            if let Some(family) = arg_family(1) {
1661                check_function_argument(
1662                    errors,
1663                    strict,
1664                    &name,
1665                    1,
1666                    family,
1667                    "a numeric argument",
1668                    family.is_numeric(),
1669                );
1670            }
1671            if (name == "lpad" || name == "rpad") && function.args.len() > 2 {
1672                if let Some(family) = arg_family(2) {
1673                    check_function_argument(
1674                        errors,
1675                        strict,
1676                        &name,
1677                        2,
1678                        family,
1679                        "a string argument",
1680                        is_string_like(family),
1681                    );
1682                }
1683            }
1684        }
1685        _ => {}
1686    }
1687}
1688
1689fn check_function_catalog(
1690    function: &Function,
1691    dialect: DialectType,
1692    function_catalog: Option<&dyn FunctionCatalog>,
1693    strict: bool,
1694    errors: &mut Vec<ValidationError>,
1695) {
1696    let Some(catalog) = function_catalog else {
1697        return;
1698    };
1699
1700    let raw_name = function_base_name(&function.name);
1701    let normalized_name = function_dispatch_name(&function.name);
1702    let arity = function.args.len();
1703    let Some(signatures) = catalog.lookup(dialect, raw_name, &normalized_name) else {
1704        errors.push(if strict {
1705            ValidationError::error(
1706                format!(
1707                    "Unknown function '{}' for dialect {:?}",
1708                    function.name, dialect
1709                ),
1710                validation_codes::E_UNKNOWN_FUNCTION,
1711            )
1712        } else {
1713            ValidationError::warning(
1714                format!(
1715                    "Unknown function '{}' for dialect {:?}",
1716                    function.name, dialect
1717                ),
1718                validation_codes::E_UNKNOWN_FUNCTION,
1719            )
1720        });
1721        return;
1722    };
1723
1724    if signatures.iter().any(|sig| sig.matches_arity(arity)) {
1725        return;
1726    }
1727
1728    let expected = signatures
1729        .iter()
1730        .map(|sig| sig.describe_arity())
1731        .collect::<Vec<_>>()
1732        .join(", ");
1733    errors.push(if strict {
1734        ValidationError::error(
1735            format!(
1736                "Invalid arity for function '{}': got {}, expected {}",
1737                function.name, arity, expected
1738            ),
1739            validation_codes::E_INVALID_FUNCTION_ARITY,
1740        )
1741    } else {
1742        ValidationError::warning(
1743            format!(
1744                "Invalid arity for function '{}': got {}, expected {}",
1745                function.name, arity, expected
1746            ),
1747            validation_codes::E_INVALID_FUNCTION_ARITY,
1748        )
1749    });
1750}
1751
/// A single column-to-column reference declared in the validation schema.
///
/// Instances are produced by `build_declared_relationships`: the table fields hold the keys
/// returned by `resolve_reference_table_key`, and the column fields hold names passed
/// through `lower`. The derived `Eq`/`Hash` let duplicates collapse in a `HashSet`.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
struct DeclaredRelationship {
    /// Resolved key of the table that declares the reference.
    source_table: String,
    /// Referencing column in `source_table`.
    source_column: String,
    /// Resolved key of the table being referenced.
    target_table: String,
    /// Referenced column in `target_table`.
    target_column: String,
}
1759
1760fn build_declared_relationships(
1761    schema: &ValidationSchema,
1762    schema_map: &HashMap<String, TableSchemaEntry>,
1763) -> Vec<DeclaredRelationship> {
1764    let mut relationships = HashSet::new();
1765
1766    for table in &schema.tables {
1767        let Some(source_key) =
1768            resolve_reference_table_key(&table.name, table.schema.as_deref(), None, schema_map)
1769        else {
1770            continue;
1771        };
1772
1773        for column in &table.columns {
1774            let Some(reference) = &column.references else {
1775                continue;
1776            };
1777            let Some(target_key) = resolve_reference_table_key(
1778                &reference.table,
1779                reference.schema.as_deref(),
1780                table.schema.as_deref(),
1781                schema_map,
1782            ) else {
1783                continue;
1784            };
1785
1786            relationships.insert(DeclaredRelationship {
1787                source_table: source_key.clone(),
1788                source_column: lower(&column.name),
1789                target_table: target_key,
1790                target_column: lower(&reference.column),
1791            });
1792        }
1793
1794        for foreign_key in &table.foreign_keys {
1795            if foreign_key.columns.len() != foreign_key.references.columns.len() {
1796                continue;
1797            }
1798            let Some(target_key) = resolve_reference_table_key(
1799                &foreign_key.references.table,
1800                foreign_key.references.schema.as_deref(),
1801                table.schema.as_deref(),
1802                schema_map,
1803            ) else {
1804                continue;
1805            };
1806
1807            for (source_col, target_col) in foreign_key
1808                .columns
1809                .iter()
1810                .zip(foreign_key.references.columns.iter())
1811            {
1812                relationships.insert(DeclaredRelationship {
1813                    source_table: source_key.clone(),
1814                    source_column: lower(source_col),
1815                    target_table: target_key.clone(),
1816                    target_column: lower(target_col),
1817                });
1818            }
1819        }
1820    }
1821
1822    relationships.into_iter().collect()
1823}
1824
1825fn resolve_column_binding(
1826    column: &Column,
1827    schema_map: &HashMap<String, TableSchemaEntry>,
1828    context: &TypeCheckContext,
1829    resolver: &mut Resolver<'_>,
1830) -> Option<(String, String)> {
1831    let column_name = lower(&column.name.name);
1832    if column_name.is_empty() {
1833        return None;
1834    }
1835
1836    if let Some(table) = &column.table {
1837        let mut table_key = lower(&table.name);
1838        if let Some(mapped) = context.table_aliases.get(&table_key) {
1839            table_key = mapped.clone();
1840        }
1841        if schema_map.contains_key(&table_key) {
1842            return Some((table_key, column_name));
1843        }
1844        return None;
1845    }
1846
1847    if let Some(resolved_source) = resolver.get_table(&column_name) {
1848        let mut table_key = lower(&resolved_source);
1849        if let Some(mapped) = context.table_aliases.get(&table_key) {
1850            table_key = mapped.clone();
1851        }
1852        if schema_map.contains_key(&table_key) {
1853            return Some((table_key, column_name));
1854        }
1855    }
1856
1857    let candidates: Vec<String> = context
1858        .referenced_tables
1859        .iter()
1860        .filter_map(|table_name| {
1861            schema_map
1862                .get(table_name)
1863                .filter(|entry| entry.columns.contains_key(&column_name))
1864                .map(|_| table_name.clone())
1865        })
1866        .collect();
1867    if candidates.len() == 1 {
1868        return Some((candidates[0].clone(), column_name));
1869    }
1870    None
1871}
1872
1873fn extract_join_equality_pairs(
1874    expr: &Expression,
1875    schema_map: &HashMap<String, TableSchemaEntry>,
1876    context: &TypeCheckContext,
1877    resolver: &mut Resolver<'_>,
1878    pairs: &mut Vec<((String, String), (String, String))>,
1879) {
1880    match expr {
1881        Expression::And(op) => {
1882            extract_join_equality_pairs(&op.left, schema_map, context, resolver, pairs);
1883            extract_join_equality_pairs(&op.right, schema_map, context, resolver, pairs);
1884        }
1885        Expression::Paren(paren) => {
1886            extract_join_equality_pairs(&paren.this, schema_map, context, resolver, pairs);
1887        }
1888        Expression::Eq(op) => {
1889            let (Expression::Column(left_col), Expression::Column(right_col)) =
1890                (&op.left, &op.right)
1891            else {
1892                return;
1893            };
1894            let Some(left) = resolve_column_binding(left_col, schema_map, context, resolver) else {
1895                return;
1896            };
1897            let Some(right) = resolve_column_binding(right_col, schema_map, context, resolver)
1898            else {
1899                return;
1900            };
1901            pairs.push((left, right));
1902        }
1903        _ => {}
1904    }
1905}
1906
1907fn relationship_matches_pair(
1908    relationship: &DeclaredRelationship,
1909    left_table: &str,
1910    left_column: &str,
1911    right_table: &str,
1912    right_column: &str,
1913) -> bool {
1914    (relationship.source_table == left_table
1915        && relationship.source_column == left_column
1916        && relationship.target_table == right_table
1917        && relationship.target_column == right_column)
1918        || (relationship.source_table == right_table
1919            && relationship.source_column == right_column
1920            && relationship.target_table == left_table
1921            && relationship.target_column == left_column)
1922}
1923
1924fn resolved_table_key_from_expr(
1925    expr: &Expression,
1926    schema_map: &HashMap<String, TableSchemaEntry>,
1927) -> Option<String> {
1928    match expr {
1929        Expression::Table(table) => resolve_table_schema_entry(table, schema_map).map(|(k, _)| k),
1930        Expression::Alias(alias) => resolved_table_key_from_expr(&alias.this, schema_map),
1931        _ => None,
1932    }
1933}
1934
1935fn select_from_table_keys(
1936    select: &crate::expressions::Select,
1937    schema_map: &HashMap<String, TableSchemaEntry>,
1938) -> HashSet<String> {
1939    let mut keys = HashSet::new();
1940    if let Some(from_clause) = &select.from {
1941        for expr in &from_clause.expressions {
1942            if let Some(key) = resolved_table_key_from_expr(expr, schema_map) {
1943                keys.insert(key);
1944            }
1945        }
1946    }
1947    keys
1948}
1949
1950fn is_natural_or_implied_join(kind: JoinKind) -> bool {
1951    matches!(
1952        kind,
1953        JoinKind::Natural
1954            | JoinKind::NaturalLeft
1955            | JoinKind::NaturalRight
1956            | JoinKind::NaturalFull
1957            | JoinKind::CrossApply
1958            | JoinKind::OuterApply
1959            | JoinKind::AsOf
1960            | JoinKind::AsOfLeft
1961            | JoinKind::AsOfRight
1962            | JoinKind::Lateral
1963            | JoinKind::LeftLateral
1964    )
1965}
1966
1967fn check_query_reference_quality(
1968    stmt: &Expression,
1969    schema_map: &HashMap<String, TableSchemaEntry>,
1970    resolver_schema: &MappingSchema,
1971    strict: bool,
1972    relationships: &[DeclaredRelationship],
1973) -> Vec<ValidationError> {
1974    let mut errors = Vec::new();
1975
1976    for node in stmt.dfs() {
1977        let Expression::Select(select) = node else {
1978            continue;
1979        };
1980
1981        let select_expr = Expression::Select(select.clone());
1982        let context = collect_type_check_context(&select_expr, schema_map);
1983        let scope = build_scope(&select_expr);
1984        let mut resolver = Resolver::new(&scope, resolver_schema, true);
1985
1986        if context.referenced_tables.len() > 1 {
1987            let using_columns: HashSet<String> = select
1988                .joins
1989                .iter()
1990                .flat_map(|join| join.using.iter().map(|id| lower(&id.name)))
1991                .collect();
1992
1993            let mut seen = HashSet::new();
1994            for column_expr in select_expr
1995                .find_all(|e| matches!(e, Expression::Column(Column { table: None, .. })))
1996            {
1997                let Expression::Column(column) = column_expr else {
1998                    continue;
1999                };
2000
2001                let col_name = lower(&column.name.name);
2002                if col_name.is_empty()
2003                    || using_columns.contains(&col_name)
2004                    || !seen.insert(col_name.clone())
2005                {
2006                    continue;
2007                }
2008
2009                if resolver.is_ambiguous(&col_name) {
2010                    let source_count = resolver.sources_for_column(&col_name).len();
2011                    errors.push(if strict {
2012                        ValidationError::error(
2013                            format!(
2014                                "Ambiguous unqualified column '{}' found in {} referenced tables",
2015                                col_name, source_count
2016                            ),
2017                            validation_codes::E_AMBIGUOUS_COLUMN_REFERENCE,
2018                        )
2019                    } else {
2020                        ValidationError::warning(
2021                            format!(
2022                                "Ambiguous unqualified column '{}' found in {} referenced tables",
2023                                col_name, source_count
2024                            ),
2025                            validation_codes::W_WEAK_REFERENCE_INTEGRITY,
2026                        )
2027                    });
2028                }
2029            }
2030        }
2031
2032        let mut cumulative_left_tables = select_from_table_keys(select, schema_map);
2033
2034        for join in &select.joins {
2035            let right_table_key = resolved_table_key_from_expr(&join.this, schema_map);
2036            let has_explicit_condition = join.on.is_some() || !join.using.is_empty();
2037            let cartesian_like_kind = matches!(
2038                join.kind,
2039                JoinKind::Cross
2040                    | JoinKind::Implicit
2041                    | JoinKind::Array
2042                    | JoinKind::LeftArray
2043                    | JoinKind::Paste
2044            );
2045
2046            if right_table_key.is_some()
2047                && (cartesian_like_kind
2048                    || (!has_explicit_condition && !is_natural_or_implied_join(join.kind)))
2049            {
2050                errors.push(ValidationError::warning(
2051                    "Potential cartesian join: JOIN without ON/USING condition",
2052                    validation_codes::W_CARTESIAN_JOIN,
2053                ));
2054            }
2055
2056            if let (Some(on_expr), Some(right_key)) = (&join.on, right_table_key.clone()) {
2057                if join.using.is_empty() {
2058                    let mut eq_pairs = Vec::new();
2059                    extract_join_equality_pairs(
2060                        on_expr,
2061                        schema_map,
2062                        &context,
2063                        &mut resolver,
2064                        &mut eq_pairs,
2065                    );
2066
2067                    let relevant_relationships: Vec<&DeclaredRelationship> = relationships
2068                        .iter()
2069                        .filter(|rel| {
2070                            cumulative_left_tables.contains(&rel.source_table)
2071                                && rel.target_table == right_key
2072                                || (cumulative_left_tables.contains(&rel.target_table)
2073                                    && rel.source_table == right_key)
2074                        })
2075                        .collect();
2076
2077                    if !relevant_relationships.is_empty() {
2078                        let uses_declared_fk = eq_pairs.iter().any(|((lt, lc), (rt, rc))| {
2079                            relevant_relationships
2080                                .iter()
2081                                .any(|rel| relationship_matches_pair(rel, lt, lc, rt, rc))
2082                        });
2083                        if !uses_declared_fk {
2084                            errors.push(ValidationError::warning(
2085                                "JOIN predicate does not use declared foreign-key relationship columns",
2086                                validation_codes::W_JOIN_NOT_USING_DECLARED_REFERENCE,
2087                            ));
2088                        }
2089                    }
2090                }
2091            }
2092
2093            if let Some(right_key) = right_table_key {
2094                cumulative_left_tables.insert(right_key);
2095            }
2096        }
2097    }
2098
2099    errors
2100}
2101
2102fn are_setop_compatible(left: TypeFamily, right: TypeFamily) -> bool {
2103    if left == TypeFamily::Unknown || right == TypeFamily::Unknown {
2104        return true;
2105    }
2106    if left == right {
2107        return true;
2108    }
2109    if left.is_numeric() && right.is_numeric() {
2110        return true;
2111    }
2112    if left.is_temporal() && right.is_temporal() {
2113        return true;
2114    }
2115    false
2116}
2117
2118fn merged_setop_family(left: TypeFamily, right: TypeFamily) -> TypeFamily {
2119    if left == TypeFamily::Unknown {
2120        return right;
2121    }
2122    if right == TypeFamily::Unknown {
2123        return left;
2124    }
2125    if left == right {
2126        return left;
2127    }
2128    if left.is_numeric() && right.is_numeric() {
2129        if left == TypeFamily::Numeric || right == TypeFamily::Numeric {
2130            return TypeFamily::Numeric;
2131        }
2132        return TypeFamily::Integer;
2133    }
2134    if left.is_temporal() && right.is_temporal() {
2135        if left == TypeFamily::Timestamp || right == TypeFamily::Timestamp {
2136            return TypeFamily::Timestamp;
2137        }
2138        if left == TypeFamily::Date || right == TypeFamily::Date {
2139            return TypeFamily::Date;
2140        }
2141        return TypeFamily::Time;
2142    }
2143    TypeFamily::Unknown
2144}
2145
2146fn are_assignment_compatible(target: TypeFamily, source: TypeFamily) -> bool {
2147    if target == TypeFamily::Unknown || source == TypeFamily::Unknown {
2148        return true;
2149    }
2150    if target == source {
2151        return true;
2152    }
2153
2154    match target {
2155        TypeFamily::Boolean => source == TypeFamily::Boolean,
2156        TypeFamily::Integer | TypeFamily::Numeric => source.is_numeric(),
2157        TypeFamily::Date | TypeFamily::Time | TypeFamily::Timestamp | TypeFamily::Interval => {
2158            source.is_temporal()
2159        }
2160        TypeFamily::String => true,
2161        TypeFamily::Binary => matches!(source, TypeFamily::Binary | TypeFamily::String),
2162        TypeFamily::Json => matches!(source, TypeFamily::Json | TypeFamily::String),
2163        TypeFamily::Uuid => matches!(source, TypeFamily::Uuid | TypeFamily::String),
2164        TypeFamily::Array => source == TypeFamily::Array,
2165        TypeFamily::Map => source == TypeFamily::Map,
2166        TypeFamily::Struct => source == TypeFamily::Struct,
2167        TypeFamily::Unknown => true,
2168    }
2169}
2170
2171fn projection_families(
2172    query_expr: &Expression,
2173    schema_map: &HashMap<String, TableSchemaEntry>,
2174) -> Option<Vec<TypeFamily>> {
2175    match query_expr {
2176        Expression::Select(select) => {
2177            if select
2178                .expressions
2179                .iter()
2180                .any(|e| matches!(e, Expression::Star(_) | Expression::BracedWildcard(_)))
2181            {
2182                return None;
2183            }
2184            let select_expr = Expression::Select(select.clone());
2185            let context = collect_type_check_context(&select_expr, schema_map);
2186            Some(
2187                select
2188                    .expressions
2189                    .iter()
2190                    .map(|e| infer_expression_type_family(e, schema_map, &context))
2191                    .collect(),
2192            )
2193        }
2194        Expression::Subquery(subquery) => projection_families(&subquery.this, schema_map),
2195        Expression::Union(union) => {
2196            let left = projection_families(&union.left, schema_map)?;
2197            let right = projection_families(&union.right, schema_map)?;
2198            if left.len() != right.len() {
2199                return None;
2200            }
2201            Some(
2202                left.into_iter()
2203                    .zip(right)
2204                    .map(|(l, r)| merged_setop_family(l, r))
2205                    .collect(),
2206            )
2207        }
2208        Expression::Intersect(intersect) => {
2209            let left = projection_families(&intersect.left, schema_map)?;
2210            let right = projection_families(&intersect.right, schema_map)?;
2211            if left.len() != right.len() {
2212                return None;
2213            }
2214            Some(
2215                left.into_iter()
2216                    .zip(right)
2217                    .map(|(l, r)| merged_setop_family(l, r))
2218                    .collect(),
2219            )
2220        }
2221        Expression::Except(except) => {
2222            let left = projection_families(&except.left, schema_map)?;
2223            let right = projection_families(&except.right, schema_map)?;
2224            if left.len() != right.len() {
2225                return None;
2226            }
2227            Some(
2228                left.into_iter()
2229                    .zip(right)
2230                    .map(|(l, r)| merged_setop_family(l, r))
2231                    .collect(),
2232            )
2233        }
2234        Expression::Values(values) => {
2235            let first_row = values.expressions.first()?;
2236            let context = TypeCheckContext::default();
2237            Some(
2238                first_row
2239                    .expressions
2240                    .iter()
2241                    .map(|e| infer_expression_type_family(e, schema_map, &context))
2242                    .collect(),
2243            )
2244        }
2245        _ => None,
2246    }
2247}
2248
2249fn check_set_operation_compatibility(
2250    op_name: &str,
2251    left_expr: &Expression,
2252    right_expr: &Expression,
2253    schema_map: &HashMap<String, TableSchemaEntry>,
2254    strict: bool,
2255    errors: &mut Vec<ValidationError>,
2256) {
2257    let Some(left_projection) = projection_families(left_expr, schema_map) else {
2258        return;
2259    };
2260    let Some(right_projection) = projection_families(right_expr, schema_map) else {
2261        return;
2262    };
2263
2264    if left_projection.len() != right_projection.len() {
2265        errors.push(type_issue(
2266            strict,
2267            validation_codes::E_SETOP_ARITY_MISMATCH,
2268            validation_codes::W_SETOP_IMPLICIT_COERCION,
2269            format!(
2270                "{} operands return different column counts: left {}, right {}",
2271                op_name,
2272                left_projection.len(),
2273                right_projection.len()
2274            ),
2275        ));
2276        return;
2277    }
2278
2279    for (idx, (left, right)) in left_projection
2280        .into_iter()
2281        .zip(right_projection)
2282        .enumerate()
2283    {
2284        if !are_setop_compatible(left, right) {
2285            errors.push(type_issue(
2286                strict,
2287                validation_codes::E_SETOP_TYPE_MISMATCH,
2288                validation_codes::W_SETOP_IMPLICIT_COERCION,
2289                format!(
2290                    "{} column {} has incompatible types: {} vs {}",
2291                    op_name,
2292                    idx + 1,
2293                    type_family_name(left),
2294                    type_family_name(right)
2295                ),
2296            ));
2297        }
2298    }
2299}
2300
2301fn check_insert_assignments(
2302    stmt: &Expression,
2303    insert: &Insert,
2304    schema_map: &HashMap<String, TableSchemaEntry>,
2305    strict: bool,
2306    errors: &mut Vec<ValidationError>,
2307) {
2308    let Some((target_table_key, table_schema)) =
2309        resolve_table_schema_entry(&insert.table, schema_map)
2310    else {
2311        return;
2312    };
2313
2314    let mut target_columns = Vec::new();
2315    if insert.columns.is_empty() {
2316        target_columns.extend(table_schema.column_order.iter().cloned());
2317    } else {
2318        for column in &insert.columns {
2319            let col_name = lower(&column.name);
2320            if table_schema.columns.contains_key(&col_name) {
2321                target_columns.push(col_name);
2322            } else {
2323                errors.push(if strict {
2324                    ValidationError::error(
2325                        format!(
2326                            "Unknown column '{}' in table '{}'",
2327                            column.name, target_table_key
2328                        ),
2329                        validation_codes::E_UNKNOWN_COLUMN,
2330                    )
2331                } else {
2332                    ValidationError::warning(
2333                        format!(
2334                            "Unknown column '{}' in table '{}'",
2335                            column.name, target_table_key
2336                        ),
2337                        validation_codes::E_UNKNOWN_COLUMN,
2338                    )
2339                });
2340            }
2341        }
2342    }
2343
2344    if target_columns.is_empty() {
2345        return;
2346    }
2347
2348    let context = collect_type_check_context(stmt, schema_map);
2349
2350    if !insert.default_values {
2351        for (row_idx, row) in insert.values.iter().enumerate() {
2352            if row.len() != target_columns.len() {
2353                errors.push(type_issue(
2354                    strict,
2355                    validation_codes::E_INVALID_ASSIGNMENT_TYPE,
2356                    validation_codes::W_IMPLICIT_CAST_ASSIGNMENT,
2357                    format!(
2358                        "INSERT row {} has {} values but target has {} columns",
2359                        row_idx + 1,
2360                        row.len(),
2361                        target_columns.len()
2362                    ),
2363                ));
2364                continue;
2365            }
2366
2367            for (value, target_column) in row.iter().zip(target_columns.iter()) {
2368                let Some(target_family) = table_schema.columns.get(target_column).copied() else {
2369                    continue;
2370                };
2371                let source_family = infer_expression_type_family(value, schema_map, &context);
2372                if !are_assignment_compatible(target_family, source_family) {
2373                    errors.push(type_issue(
2374                        strict,
2375                        validation_codes::E_INVALID_ASSIGNMENT_TYPE,
2376                        validation_codes::W_IMPLICIT_CAST_ASSIGNMENT,
2377                        format!(
2378                            "INSERT assignment type mismatch for '{}.{}': expected {}, found {}",
2379                            target_table_key,
2380                            target_column,
2381                            type_family_name(target_family),
2382                            type_family_name(source_family)
2383                        ),
2384                    ));
2385                }
2386            }
2387        }
2388    }
2389
2390    if let Some(query) = &insert.query {
2391        // DuckDB BY NAME maps source columns by name, not position.
2392        if insert.by_name {
2393            return;
2394        }
2395
2396        let Some(source_projection) = projection_families(query, schema_map) else {
2397            return;
2398        };
2399
2400        if source_projection.len() != target_columns.len() {
2401            errors.push(type_issue(
2402                strict,
2403                validation_codes::E_INVALID_ASSIGNMENT_TYPE,
2404                validation_codes::W_IMPLICIT_CAST_ASSIGNMENT,
2405                format!(
2406                    "INSERT source query has {} columns but target has {} columns",
2407                    source_projection.len(),
2408                    target_columns.len()
2409                ),
2410            ));
2411            return;
2412        }
2413
2414        for (source_family, target_column) in
2415            source_projection.into_iter().zip(target_columns.iter())
2416        {
2417            let Some(target_family) = table_schema.columns.get(target_column).copied() else {
2418                continue;
2419            };
2420            if !are_assignment_compatible(target_family, source_family) {
2421                errors.push(type_issue(
2422                    strict,
2423                    validation_codes::E_INVALID_ASSIGNMENT_TYPE,
2424                    validation_codes::W_IMPLICIT_CAST_ASSIGNMENT,
2425                    format!(
2426                        "INSERT assignment type mismatch for '{}.{}': expected {}, found {}",
2427                        target_table_key,
2428                        target_column,
2429                        type_family_name(target_family),
2430                        type_family_name(source_family)
2431                    ),
2432                ));
2433            }
2434        }
2435    }
2436}
2437
2438fn check_update_assignments(
2439    stmt: &Expression,
2440    update: &Update,
2441    schema_map: &HashMap<String, TableSchemaEntry>,
2442    strict: bool,
2443    errors: &mut Vec<ValidationError>,
2444) {
2445    let Some((target_table_key, table_schema)) =
2446        resolve_table_schema_entry(&update.table, schema_map)
2447    else {
2448        return;
2449    };
2450
2451    let context = collect_type_check_context(stmt, schema_map);
2452
2453    for (column, value) in &update.set {
2454        let col_name = lower(&column.name);
2455        let Some(target_family) = table_schema.columns.get(&col_name).copied() else {
2456            errors.push(if strict {
2457                ValidationError::error(
2458                    format!(
2459                        "Unknown column '{}' in table '{}'",
2460                        column.name, target_table_key
2461                    ),
2462                    validation_codes::E_UNKNOWN_COLUMN,
2463                )
2464            } else {
2465                ValidationError::warning(
2466                    format!(
2467                        "Unknown column '{}' in table '{}'",
2468                        column.name, target_table_key
2469                    ),
2470                    validation_codes::E_UNKNOWN_COLUMN,
2471                )
2472            });
2473            continue;
2474        };
2475
2476        let source_family = infer_expression_type_family(value, schema_map, &context);
2477        if !are_assignment_compatible(target_family, source_family) {
2478            errors.push(type_issue(
2479                strict,
2480                validation_codes::E_INVALID_ASSIGNMENT_TYPE,
2481                validation_codes::W_IMPLICIT_CAST_ASSIGNMENT,
2482                format!(
2483                    "UPDATE assignment type mismatch for '{}.{}': expected {}, found {}",
2484                    target_table_key,
2485                    col_name,
2486                    type_family_name(target_family),
2487                    type_family_name(source_family)
2488                ),
2489            ));
2490        }
2491    }
2492}
2493
2494fn check_types(
2495    stmt: &Expression,
2496    dialect: DialectType,
2497    schema_map: &HashMap<String, TableSchemaEntry>,
2498    function_catalog: Option<&dyn FunctionCatalog>,
2499    strict: bool,
2500) -> Vec<ValidationError> {
2501    let mut errors = Vec::new();
2502    let context = collect_type_check_context(stmt, schema_map);
2503
2504    for node in stmt.dfs() {
2505        match node {
2506            Expression::Insert(insert) => {
2507                check_insert_assignments(stmt, insert, schema_map, strict, &mut errors);
2508            }
2509            Expression::Update(update) => {
2510                check_update_assignments(stmt, update, schema_map, strict, &mut errors);
2511            }
2512            Expression::Union(union) => {
2513                check_set_operation_compatibility(
2514                    "UNION",
2515                    &union.left,
2516                    &union.right,
2517                    schema_map,
2518                    strict,
2519                    &mut errors,
2520                );
2521            }
2522            Expression::Intersect(intersect) => {
2523                check_set_operation_compatibility(
2524                    "INTERSECT",
2525                    &intersect.left,
2526                    &intersect.right,
2527                    schema_map,
2528                    strict,
2529                    &mut errors,
2530                );
2531            }
2532            Expression::Except(except) => {
2533                check_set_operation_compatibility(
2534                    "EXCEPT",
2535                    &except.left,
2536                    &except.right,
2537                    schema_map,
2538                    strict,
2539                    &mut errors,
2540                );
2541            }
2542            Expression::Select(select) => {
2543                if let Some(prewhere) = &select.prewhere {
2544                    let family = infer_expression_type_family(prewhere, schema_map, &context);
2545                    if family != TypeFamily::Unknown && family != TypeFamily::Boolean {
2546                        errors.push(type_issue(
2547                            strict,
2548                            validation_codes::E_INVALID_PREDICATE_TYPE,
2549                            validation_codes::W_PREDICATE_NULLABILITY,
2550                            format!(
2551                                "PREWHERE clause expects a boolean predicate, found {}",
2552                                type_family_name(family)
2553                            ),
2554                        ));
2555                    }
2556                }
2557
2558                if let Some(where_clause) = &select.where_clause {
2559                    let family =
2560                        infer_expression_type_family(&where_clause.this, schema_map, &context);
2561                    if family != TypeFamily::Unknown && family != TypeFamily::Boolean {
2562                        errors.push(type_issue(
2563                            strict,
2564                            validation_codes::E_INVALID_PREDICATE_TYPE,
2565                            validation_codes::W_PREDICATE_NULLABILITY,
2566                            format!(
2567                                "WHERE clause expects a boolean predicate, found {}",
2568                                type_family_name(family)
2569                            ),
2570                        ));
2571                    }
2572                }
2573
2574                if let Some(having_clause) = &select.having {
2575                    let family =
2576                        infer_expression_type_family(&having_clause.this, schema_map, &context);
2577                    if family != TypeFamily::Unknown && family != TypeFamily::Boolean {
2578                        errors.push(type_issue(
2579                            strict,
2580                            validation_codes::E_INVALID_PREDICATE_TYPE,
2581                            validation_codes::W_PREDICATE_NULLABILITY,
2582                            format!(
2583                                "HAVING clause expects a boolean predicate, found {}",
2584                                type_family_name(family)
2585                            ),
2586                        ));
2587                    }
2588                }
2589
2590                for join in &select.joins {
2591                    if let Some(on) = &join.on {
2592                        let family = infer_expression_type_family(on, schema_map, &context);
2593                        if family != TypeFamily::Unknown && family != TypeFamily::Boolean {
2594                            errors.push(type_issue(
2595                                strict,
2596                                validation_codes::E_INVALID_PREDICATE_TYPE,
2597                                validation_codes::W_PREDICATE_NULLABILITY,
2598                                format!(
2599                                    "JOIN ON expects a boolean predicate, found {}",
2600                                    type_family_name(family)
2601                                ),
2602                            ));
2603                        }
2604                    }
2605                    if let Some(match_condition) = &join.match_condition {
2606                        let family =
2607                            infer_expression_type_family(match_condition, schema_map, &context);
2608                        if family != TypeFamily::Unknown && family != TypeFamily::Boolean {
2609                            errors.push(type_issue(
2610                                strict,
2611                                validation_codes::E_INVALID_PREDICATE_TYPE,
2612                                validation_codes::W_PREDICATE_NULLABILITY,
2613                                format!(
2614                                    "JOIN MATCH_CONDITION expects a boolean predicate, found {}",
2615                                    type_family_name(family)
2616                                ),
2617                            ));
2618                        }
2619                    }
2620                }
2621            }
2622            Expression::Where(where_clause) => {
2623                let family = infer_expression_type_family(&where_clause.this, schema_map, &context);
2624                if family != TypeFamily::Unknown && family != TypeFamily::Boolean {
2625                    errors.push(type_issue(
2626                        strict,
2627                        validation_codes::E_INVALID_PREDICATE_TYPE,
2628                        validation_codes::W_PREDICATE_NULLABILITY,
2629                        format!(
2630                            "WHERE clause expects a boolean predicate, found {}",
2631                            type_family_name(family)
2632                        ),
2633                    ));
2634                }
2635            }
2636            Expression::Having(having_clause) => {
2637                let family =
2638                    infer_expression_type_family(&having_clause.this, schema_map, &context);
2639                if family != TypeFamily::Unknown && family != TypeFamily::Boolean {
2640                    errors.push(type_issue(
2641                        strict,
2642                        validation_codes::E_INVALID_PREDICATE_TYPE,
2643                        validation_codes::W_PREDICATE_NULLABILITY,
2644                        format!(
2645                            "HAVING clause expects a boolean predicate, found {}",
2646                            type_family_name(family)
2647                        ),
2648                    ));
2649                }
2650            }
2651            Expression::And(op) | Expression::Or(op) => {
2652                for (side, expr) in [("left", &op.left), ("right", &op.right)] {
2653                    let family = infer_expression_type_family(expr, schema_map, &context);
2654                    if family != TypeFamily::Unknown && family != TypeFamily::Boolean {
2655                        errors.push(type_issue(
2656                            strict,
2657                            validation_codes::E_INVALID_PREDICATE_TYPE,
2658                            validation_codes::W_PREDICATE_NULLABILITY,
2659                            format!(
2660                                "Logical {} operand expects boolean, found {}",
2661                                side,
2662                                type_family_name(family)
2663                            ),
2664                        ));
2665                    }
2666                }
2667            }
2668            Expression::Not(unary) => {
2669                let family = infer_expression_type_family(&unary.this, schema_map, &context);
2670                if family != TypeFamily::Unknown && family != TypeFamily::Boolean {
2671                    errors.push(type_issue(
2672                        strict,
2673                        validation_codes::E_INVALID_PREDICATE_TYPE,
2674                        validation_codes::W_PREDICATE_NULLABILITY,
2675                        format!("NOT expects boolean, found {}", type_family_name(family)),
2676                    ));
2677                }
2678            }
2679            Expression::Eq(op)
2680            | Expression::Neq(op)
2681            | Expression::Lt(op)
2682            | Expression::Lte(op)
2683            | Expression::Gt(op)
2684            | Expression::Gte(op) => {
2685                let left = infer_expression_type_family(&op.left, schema_map, &context);
2686                let right = infer_expression_type_family(&op.right, schema_map, &context);
2687                if !are_comparable(left, right) {
2688                    errors.push(type_issue(
2689                        strict,
2690                        validation_codes::E_INCOMPATIBLE_COMPARISON_TYPES,
2691                        validation_codes::W_IMPLICIT_CAST_COMPARISON,
2692                        format!(
2693                            "Incompatible comparison between {} and {}",
2694                            type_family_name(left),
2695                            type_family_name(right)
2696                        ),
2697                    ));
2698                }
2699            }
2700            Expression::Like(op) | Expression::ILike(op) => {
2701                let left = infer_expression_type_family(&op.left, schema_map, &context);
2702                let right = infer_expression_type_family(&op.right, schema_map, &context);
2703                if left != TypeFamily::Unknown
2704                    && right != TypeFamily::Unknown
2705                    && (!is_string_like(left) || !is_string_like(right))
2706                {
2707                    errors.push(type_issue(
2708                        strict,
2709                        validation_codes::E_INCOMPATIBLE_COMPARISON_TYPES,
2710                        validation_codes::W_IMPLICIT_CAST_COMPARISON,
2711                        format!(
2712                            "LIKE/ILIKE expects string operands, found {} and {}",
2713                            type_family_name(left),
2714                            type_family_name(right)
2715                        ),
2716                    ));
2717                }
2718            }
2719            Expression::Between(between) => {
2720                let this_family = infer_expression_type_family(&between.this, schema_map, &context);
2721                let low_family = infer_expression_type_family(&between.low, schema_map, &context);
2722                let high_family = infer_expression_type_family(&between.high, schema_map, &context);
2723
2724                if !are_comparable(this_family, low_family)
2725                    || !are_comparable(this_family, high_family)
2726                {
2727                    errors.push(type_issue(
2728                        strict,
2729                        validation_codes::E_INCOMPATIBLE_COMPARISON_TYPES,
2730                        validation_codes::W_IMPLICIT_CAST_COMPARISON,
2731                        format!(
2732                            "BETWEEN bounds are incompatible with {} (found {} and {})",
2733                            type_family_name(this_family),
2734                            type_family_name(low_family),
2735                            type_family_name(high_family)
2736                        ),
2737                    ));
2738                }
2739            }
2740            Expression::In(in_expr) => {
2741                let this_family = infer_expression_type_family(&in_expr.this, schema_map, &context);
2742                for value in &in_expr.expressions {
2743                    let item_family = infer_expression_type_family(value, schema_map, &context);
2744                    if !are_comparable(this_family, item_family) {
2745                        errors.push(type_issue(
2746                            strict,
2747                            validation_codes::E_INCOMPATIBLE_COMPARISON_TYPES,
2748                            validation_codes::W_IMPLICIT_CAST_COMPARISON,
2749                            format!(
2750                                "IN item type {} is incompatible with {}",
2751                                type_family_name(item_family),
2752                                type_family_name(this_family)
2753                            ),
2754                        ));
2755                        break;
2756                    }
2757                }
2758            }
2759            Expression::Add(op)
2760            | Expression::Sub(op)
2761            | Expression::Mul(op)
2762            | Expression::Div(op)
2763            | Expression::Mod(op) => {
2764                let left = infer_expression_type_family(&op.left, schema_map, &context);
2765                let right = infer_expression_type_family(&op.right, schema_map, &context);
2766
2767                if left == TypeFamily::Unknown || right == TypeFamily::Unknown {
2768                    continue;
2769                }
2770
2771                let temporal_ok = matches!(node, Expression::Add(_) | Expression::Sub(_))
2772                    && ((left.is_temporal() && right.is_numeric())
2773                        || (right.is_temporal() && left.is_numeric())
2774                        || (matches!(node, Expression::Sub(_))
2775                            && left.is_temporal()
2776                            && right.is_temporal()));
2777
2778                if !(left.is_numeric() && right.is_numeric()) && !temporal_ok {
2779                    errors.push(type_issue(
2780                        strict,
2781                        validation_codes::E_INVALID_ARITHMETIC_TYPE,
2782                        validation_codes::W_IMPLICIT_CAST_ARITHMETIC,
2783                        format!(
2784                            "Arithmetic operation expects numeric-compatible operands, found {} and {}",
2785                            type_family_name(left),
2786                            type_family_name(right)
2787                        ),
2788                    ));
2789                }
2790            }
2791            Expression::Function(function) => {
2792                check_function_catalog(
2793                    function,
2794                    dialect,
2795                    function_catalog,
2796                    strict,
2797                    &mut errors,
2798                );
2799                check_generic_function(function, schema_map, &context, strict, &mut errors);
2800            }
2801            Expression::Upper(func)
2802            | Expression::Lower(func)
2803            | Expression::LTrim(func)
2804            | Expression::RTrim(func)
2805            | Expression::Reverse(func) => {
2806                let family = infer_expression_type_family(&func.this, schema_map, &context);
2807                check_function_argument(
2808                    &mut errors,
2809                    strict,
2810                    "string_function",
2811                    0,
2812                    family,
2813                    "a string argument",
2814                    is_string_like(family),
2815                );
2816            }
2817            Expression::Length(func) => {
2818                let family = infer_expression_type_family(&func.this, schema_map, &context);
2819                check_function_argument(
2820                    &mut errors,
2821                    strict,
2822                    "length",
2823                    0,
2824                    family,
2825                    "a string or binary argument",
2826                    is_string_or_binary(family),
2827                );
2828            }
2829            Expression::Trim(func) => {
2830                let this_family = infer_expression_type_family(&func.this, schema_map, &context);
2831                check_function_argument(
2832                    &mut errors,
2833                    strict,
2834                    "trim",
2835                    0,
2836                    this_family,
2837                    "a string argument",
2838                    is_string_like(this_family),
2839                );
2840                if let Some(chars) = &func.characters {
2841                    let chars_family = infer_expression_type_family(chars, schema_map, &context);
2842                    check_function_argument(
2843                        &mut errors,
2844                        strict,
2845                        "trim",
2846                        1,
2847                        chars_family,
2848                        "a string argument",
2849                        is_string_like(chars_family),
2850                    );
2851                }
2852            }
2853            Expression::Substring(func) => {
2854                let this_family = infer_expression_type_family(&func.this, schema_map, &context);
2855                check_function_argument(
2856                    &mut errors,
2857                    strict,
2858                    "substring",
2859                    0,
2860                    this_family,
2861                    "a string argument",
2862                    is_string_like(this_family),
2863                );
2864
2865                let start_family = infer_expression_type_family(&func.start, schema_map, &context);
2866                check_function_argument(
2867                    &mut errors,
2868                    strict,
2869                    "substring",
2870                    1,
2871                    start_family,
2872                    "a numeric argument",
2873                    start_family.is_numeric(),
2874                );
2875                if let Some(length) = &func.length {
2876                    let length_family = infer_expression_type_family(length, schema_map, &context);
2877                    check_function_argument(
2878                        &mut errors,
2879                        strict,
2880                        "substring",
2881                        2,
2882                        length_family,
2883                        "a numeric argument",
2884                        length_family.is_numeric(),
2885                    );
2886                }
2887            }
2888            Expression::Replace(func) => {
2889                for (arg, idx) in [
2890                    (&func.this, 0_usize),
2891                    (&func.old, 1_usize),
2892                    (&func.new, 2_usize),
2893                ] {
2894                    let family = infer_expression_type_family(arg, schema_map, &context);
2895                    check_function_argument(
2896                        &mut errors,
2897                        strict,
2898                        "replace",
2899                        idx,
2900                        family,
2901                        "a string argument",
2902                        is_string_like(family),
2903                    );
2904                }
2905            }
2906            Expression::Left(func) | Expression::Right(func) => {
2907                let this_family = infer_expression_type_family(&func.this, schema_map, &context);
2908                check_function_argument(
2909                    &mut errors,
2910                    strict,
2911                    "left_right",
2912                    0,
2913                    this_family,
2914                    "a string argument",
2915                    is_string_like(this_family),
2916                );
2917                let length_family =
2918                    infer_expression_type_family(&func.length, schema_map, &context);
2919                check_function_argument(
2920                    &mut errors,
2921                    strict,
2922                    "left_right",
2923                    1,
2924                    length_family,
2925                    "a numeric argument",
2926                    length_family.is_numeric(),
2927                );
2928            }
2929            Expression::Repeat(func) => {
2930                let this_family = infer_expression_type_family(&func.this, schema_map, &context);
2931                check_function_argument(
2932                    &mut errors,
2933                    strict,
2934                    "repeat",
2935                    0,
2936                    this_family,
2937                    "a string argument",
2938                    is_string_like(this_family),
2939                );
2940                let times_family = infer_expression_type_family(&func.times, schema_map, &context);
2941                check_function_argument(
2942                    &mut errors,
2943                    strict,
2944                    "repeat",
2945                    1,
2946                    times_family,
2947                    "a numeric argument",
2948                    times_family.is_numeric(),
2949                );
2950            }
2951            Expression::Lpad(func) | Expression::Rpad(func) => {
2952                let this_family = infer_expression_type_family(&func.this, schema_map, &context);
2953                check_function_argument(
2954                    &mut errors,
2955                    strict,
2956                    "pad",
2957                    0,
2958                    this_family,
2959                    "a string argument",
2960                    is_string_like(this_family),
2961                );
2962                let length_family =
2963                    infer_expression_type_family(&func.length, schema_map, &context);
2964                check_function_argument(
2965                    &mut errors,
2966                    strict,
2967                    "pad",
2968                    1,
2969                    length_family,
2970                    "a numeric argument",
2971                    length_family.is_numeric(),
2972                );
2973                if let Some(fill) = &func.fill {
2974                    let fill_family = infer_expression_type_family(fill, schema_map, &context);
2975                    check_function_argument(
2976                        &mut errors,
2977                        strict,
2978                        "pad",
2979                        2,
2980                        fill_family,
2981                        "a string argument",
2982                        is_string_like(fill_family),
2983                    );
2984                }
2985            }
2986            Expression::Abs(func)
2987            | Expression::Sqrt(func)
2988            | Expression::Cbrt(func)
2989            | Expression::Ln(func)
2990            | Expression::Exp(func) => {
2991                let family = infer_expression_type_family(&func.this, schema_map, &context);
2992                check_function_argument(
2993                    &mut errors,
2994                    strict,
2995                    "numeric_function",
2996                    0,
2997                    family,
2998                    "a numeric argument",
2999                    family.is_numeric(),
3000                );
3001            }
3002            Expression::Round(func) => {
3003                let this_family = infer_expression_type_family(&func.this, schema_map, &context);
3004                check_function_argument(
3005                    &mut errors,
3006                    strict,
3007                    "round",
3008                    0,
3009                    this_family,
3010                    "a numeric argument",
3011                    this_family.is_numeric(),
3012                );
3013                if let Some(decimals) = &func.decimals {
3014                    let decimals_family =
3015                        infer_expression_type_family(decimals, schema_map, &context);
3016                    check_function_argument(
3017                        &mut errors,
3018                        strict,
3019                        "round",
3020                        1,
3021                        decimals_family,
3022                        "a numeric argument",
3023                        decimals_family.is_numeric(),
3024                    );
3025                }
3026            }
3027            Expression::Floor(func) => {
3028                let this_family = infer_expression_type_family(&func.this, schema_map, &context);
3029                check_function_argument(
3030                    &mut errors,
3031                    strict,
3032                    "floor",
3033                    0,
3034                    this_family,
3035                    "a numeric argument",
3036                    this_family.is_numeric(),
3037                );
3038                if let Some(scale) = &func.scale {
3039                    let scale_family = infer_expression_type_family(scale, schema_map, &context);
3040                    check_function_argument(
3041                        &mut errors,
3042                        strict,
3043                        "floor",
3044                        1,
3045                        scale_family,
3046                        "a numeric argument",
3047                        scale_family.is_numeric(),
3048                    );
3049                }
3050            }
3051            Expression::Ceil(func) => {
3052                let this_family = infer_expression_type_family(&func.this, schema_map, &context);
3053                check_function_argument(
3054                    &mut errors,
3055                    strict,
3056                    "ceil",
3057                    0,
3058                    this_family,
3059                    "a numeric argument",
3060                    this_family.is_numeric(),
3061                );
3062                if let Some(decimals) = &func.decimals {
3063                    let decimals_family =
3064                        infer_expression_type_family(decimals, schema_map, &context);
3065                    check_function_argument(
3066                        &mut errors,
3067                        strict,
3068                        "ceil",
3069                        1,
3070                        decimals_family,
3071                        "a numeric argument",
3072                        decimals_family.is_numeric(),
3073                    );
3074                }
3075            }
3076            Expression::Power(func) => {
3077                let left_family = infer_expression_type_family(&func.this, schema_map, &context);
3078                check_function_argument(
3079                    &mut errors,
3080                    strict,
3081                    "power",
3082                    0,
3083                    left_family,
3084                    "a numeric argument",
3085                    left_family.is_numeric(),
3086                );
3087                let right_family =
3088                    infer_expression_type_family(&func.expression, schema_map, &context);
3089                check_function_argument(
3090                    &mut errors,
3091                    strict,
3092                    "power",
3093                    1,
3094                    right_family,
3095                    "a numeric argument",
3096                    right_family.is_numeric(),
3097                );
3098            }
3099            Expression::Log(func) => {
3100                let this_family = infer_expression_type_family(&func.this, schema_map, &context);
3101                check_function_argument(
3102                    &mut errors,
3103                    strict,
3104                    "log",
3105                    0,
3106                    this_family,
3107                    "a numeric argument",
3108                    this_family.is_numeric(),
3109                );
3110                if let Some(base) = &func.base {
3111                    let base_family = infer_expression_type_family(base, schema_map, &context);
3112                    check_function_argument(
3113                        &mut errors,
3114                        strict,
3115                        "log",
3116                        1,
3117                        base_family,
3118                        "a numeric argument",
3119                        base_family.is_numeric(),
3120                    );
3121                }
3122            }
3123            _ => {}
3124        }
3125    }
3126
3127    errors
3128}
3129
3130fn check_semantics(stmt: &Expression) -> Vec<ValidationError> {
3131    let mut errors = Vec::new();
3132
3133    let Expression::Select(select) = stmt else {
3134        return errors;
3135    };
3136    let select_expr = Expression::Select(select.clone());
3137
3138    // W001: SELECT * is discouraged
3139    if !select_expr
3140        .find_all(|e| matches!(e, Expression::Star(_)))
3141        .is_empty()
3142    {
3143        errors.push(ValidationError::warning(
3144            "SELECT * is discouraged; specify columns explicitly for better performance and maintainability",
3145            validation_codes::W_SELECT_STAR,
3146        ));
3147    }
3148
3149    // W002: aggregate + non-aggregate columns without GROUP BY
3150    let aggregate_count = get_aggregate_functions(&select_expr).len();
3151    if aggregate_count > 0 && select.group_by.is_none() {
3152        let has_non_aggregate_column = select.expressions.iter().any(|expr| {
3153            matches!(expr, Expression::Column(_) | Expression::Identifier(_))
3154                && get_aggregate_functions(expr).is_empty()
3155        });
3156
3157        if has_non_aggregate_column {
3158            errors.push(ValidationError::warning(
3159                "Mixing aggregate functions with non-aggregated columns without GROUP BY may cause errors in strict SQL mode",
3160                validation_codes::W_AGGREGATE_WITHOUT_GROUP_BY,
3161            ));
3162        }
3163    }
3164
3165    // W003: DISTINCT with ORDER BY
3166    if select.distinct && select.order_by.is_some() {
3167        errors.push(ValidationError::warning(
3168            "DISTINCT with ORDER BY: ensure ORDER BY columns are in SELECT list",
3169            validation_codes::W_DISTINCT_ORDER_BY,
3170        ));
3171    }
3172
3173    // W004: LIMIT without ORDER BY
3174    if select.limit.is_some() && select.order_by.is_none() {
3175        errors.push(ValidationError::warning(
3176            "LIMIT without ORDER BY produces non-deterministic results",
3177            validation_codes::W_LIMIT_WITHOUT_ORDER_BY,
3178        ));
3179    }
3180
3181    errors
3182}
3183
3184fn resolve_scope_source_name(scope: &crate::scope::Scope, name: &str) -> Option<String> {
3185    scope
3186        .sources
3187        .get_key_value(name)
3188        .map(|(k, _)| k.clone())
3189        .or_else(|| {
3190            scope
3191                .sources
3192                .keys()
3193                .find(|source| source.eq_ignore_ascii_case(name))
3194                .cloned()
3195        })
3196}
3197
/// Reports whether `column_name` can be satisfied by a source's column list.
///
/// A literal `"*"` entry matches every name; all other entries are compared
/// with `column_name` ignoring ASCII case.
fn source_has_column(columns: &[String], column_name: &str) -> bool {
    for candidate in columns {
        if candidate == "*" || candidate.eq_ignore_ascii_case(column_name) {
            return true;
        }
    }
    false
}
3203
3204fn source_display_name(scope: &crate::scope::Scope, source_name: &str) -> String {
3205    scope
3206        .sources
3207        .get(source_name)
3208        .map(|source| match &source.expression {
3209            Expression::Table(table) => lower(&table_ref_display_name(table)),
3210            _ => lower(source_name),
3211        })
3212        .unwrap_or_else(|| lower(source_name))
3213}
3214
3215fn validate_select_columns_with_schema(
3216    select: &crate::expressions::Select,
3217    schema_map: &HashMap<String, TableSchemaEntry>,
3218    resolver_schema: &MappingSchema,
3219    strict: bool,
3220) -> Vec<ValidationError> {
3221    let mut errors = Vec::new();
3222    let select_expr = Expression::Select(Box::new(select.clone()));
3223    let scope = build_scope(&select_expr);
3224    let mut resolver = Resolver::new(&scope, resolver_schema, true);
3225    let source_names: Vec<String> = scope.sources.keys().cloned().collect();
3226
3227    for node in walk_in_scope(&select_expr, false) {
3228        let Expression::Column(column) = node else {
3229            continue;
3230        };
3231
3232        let col_name = lower(&column.name.name);
3233        if col_name.is_empty() {
3234            continue;
3235        }
3236
3237        if let Some(table) = &column.table {
3238            let Some(source_name) = resolve_scope_source_name(&scope, &table.name) else {
3239                // The table qualifier is not a declared alias or source in this scope
3240                errors.push(if strict {
3241                    ValidationError::error(
3242                        format!(
3243                            "Unknown table or alias '{}' referenced by column '{}'",
3244                            table.name, col_name
3245                        ),
3246                        validation_codes::E_UNRESOLVED_REFERENCE,
3247                    )
3248                } else {
3249                    ValidationError::warning(
3250                        format!(
3251                            "Unknown table or alias '{}' referenced by column '{}'",
3252                            table.name, col_name
3253                        ),
3254                        validation_codes::E_UNRESOLVED_REFERENCE,
3255                    )
3256                });
3257                continue;
3258            };
3259
3260            if let Ok(columns) = resolver.get_source_columns(&source_name) {
3261                if !columns.is_empty() && !source_has_column(&columns, &col_name) {
3262                    let table_name = source_display_name(&scope, &source_name);
3263                    errors.push(if strict {
3264                        ValidationError::error(
3265                            format!("Unknown column '{}' in table '{}'", col_name, table_name),
3266                            validation_codes::E_UNKNOWN_COLUMN,
3267                        )
3268                    } else {
3269                        ValidationError::warning(
3270                            format!("Unknown column '{}' in table '{}'", col_name, table_name),
3271                            validation_codes::E_UNKNOWN_COLUMN,
3272                        )
3273                    });
3274                }
3275            }
3276            continue;
3277        }
3278
3279        let matching_sources: Vec<String> = source_names
3280            .iter()
3281            .filter_map(|source_name| {
3282                resolver
3283                    .get_source_columns(source_name)
3284                    .ok()
3285                    .filter(|columns| !columns.is_empty() && source_has_column(columns, &col_name))
3286                    .map(|_| source_name.clone())
3287            })
3288            .collect();
3289
3290        if !matching_sources.is_empty() {
3291            continue;
3292        }
3293
3294        let known_sources: Vec<String> = source_names
3295            .iter()
3296            .filter_map(|source_name| {
3297                resolver
3298                    .get_source_columns(source_name)
3299                    .ok()
3300                    .filter(|columns| !columns.is_empty() && !columns.iter().any(|c| c == "*"))
3301                    .map(|_| source_name.clone())
3302            })
3303            .collect();
3304
3305        if known_sources.len() == 1 {
3306            let table_name = source_display_name(&scope, &known_sources[0]);
3307            errors.push(if strict {
3308                ValidationError::error(
3309                    format!("Unknown column '{}' in table '{}'", col_name, table_name),
3310                    validation_codes::E_UNKNOWN_COLUMN,
3311                )
3312            } else {
3313                ValidationError::warning(
3314                    format!("Unknown column '{}' in table '{}'", col_name, table_name),
3315                    validation_codes::E_UNKNOWN_COLUMN,
3316                )
3317            });
3318        } else if known_sources.len() > 1 {
3319            errors.push(if strict {
3320                ValidationError::error(
3321                    format!(
3322                        "Unknown column '{}' (not found in any referenced table)",
3323                        col_name
3324                    ),
3325                    validation_codes::E_UNKNOWN_COLUMN,
3326                )
3327            } else {
3328                ValidationError::warning(
3329                    format!(
3330                        "Unknown column '{}' (not found in any referenced table)",
3331                        col_name
3332                    ),
3333                    validation_codes::E_UNKNOWN_COLUMN,
3334                )
3335            });
3336        } else if !schema_map.is_empty() {
3337            let found = schema_map
3338                .values()
3339                .any(|table_schema| table_schema.columns.contains_key(&col_name));
3340            if !found {
3341                errors.push(if strict {
3342                    ValidationError::error(
3343                        format!("Unknown column '{}'", col_name),
3344                        validation_codes::E_UNKNOWN_COLUMN,
3345                    )
3346                } else {
3347                    ValidationError::warning(
3348                        format!("Unknown column '{}'", col_name),
3349                        validation_codes::E_UNKNOWN_COLUMN,
3350                    )
3351                });
3352            }
3353        }
3354    }
3355
3356    errors
3357}
3358
3359fn validate_statement_with_schema(
3360    stmt: &Expression,
3361    schema_map: &HashMap<String, TableSchemaEntry>,
3362    resolver_schema: &MappingSchema,
3363    strict: bool,
3364) -> Vec<ValidationError> {
3365    let mut errors = Vec::new();
3366    let cte_aliases = collect_cte_aliases(stmt);
3367    let mut seen_tables: HashSet<String> = HashSet::new();
3368
3369    // Table validation (E200)
3370    for node in stmt.find_all(|e| matches!(e, Expression::Table(_))) {
3371        let Expression::Table(table) = node else {
3372            continue;
3373        };
3374
3375        if cte_aliases.contains(&lower(&table.name.name)) {
3376            continue;
3377        }
3378
3379        let resolved_key = table_ref_candidates(table)
3380            .into_iter()
3381            .find(|k| schema_map.contains_key(k));
3382        let table_key = resolved_key
3383            .clone()
3384            .unwrap_or_else(|| lower(&table_ref_display_name(table)));
3385
3386        if !seen_tables.insert(table_key) {
3387            continue;
3388        }
3389
3390        if resolved_key.is_none() {
3391            errors.push(if strict {
3392                ValidationError::error(
3393                    format!("Unknown table '{}'", table_ref_display_name(table)),
3394                    validation_codes::E_UNKNOWN_TABLE,
3395                )
3396            } else {
3397                ValidationError::warning(
3398                    format!("Unknown table '{}'", table_ref_display_name(table)),
3399                    validation_codes::E_UNKNOWN_TABLE,
3400                )
3401            });
3402        }
3403    }
3404
3405    for node in stmt.dfs() {
3406        let Expression::Select(select) = node else {
3407            continue;
3408        };
3409        errors.extend(validate_select_columns_with_schema(
3410            select,
3411            schema_map,
3412            resolver_schema,
3413            strict,
3414        ));
3415    }
3416
3417    errors
3418}
3419
3420/// Validate SQL using syntax + schema-aware checks, with optional semantic warnings.
3421pub fn validate_with_schema(
3422    sql: &str,
3423    dialect: DialectType,
3424    schema: &ValidationSchema,
3425    options: &SchemaValidationOptions,
3426) -> ValidationResult {
3427    let strict = options.strict.unwrap_or(schema.strict.unwrap_or(true));
3428
3429    // Syntax validation first.
3430    let syntax_result = crate::validate_with_options(
3431        sql,
3432        dialect,
3433        &crate::ValidationOptions {
3434            strict_syntax: options.strict_syntax,
3435        },
3436    );
3437    if !syntax_result.valid {
3438        return syntax_result;
3439    }
3440
3441    let d = Dialect::get(dialect);
3442    let statements = match d.parse(sql) {
3443        Ok(exprs) => exprs,
3444        Err(e) => {
3445            return ValidationResult::with_errors(vec![ValidationError::error(
3446                e.to_string(),
3447                validation_codes::E_PARSE_OR_OPTIONS,
3448            )]);
3449        }
3450    };
3451
3452    let schema_map = build_schema_map(schema);
3453    let resolver_schema = build_resolver_schema(schema);
3454    let mut all_errors = syntax_result.errors;
3455    let embedded_function_catalog = if options.check_types && options.function_catalog.is_none() {
3456        default_embedded_function_catalog()
3457    } else {
3458        None
3459    };
3460    let effective_function_catalog = options
3461        .function_catalog
3462        .as_deref()
3463        .or_else(|| embedded_function_catalog.as_deref());
3464    let declared_relationships = if options.check_references {
3465        build_declared_relationships(schema, &schema_map)
3466    } else {
3467        Vec::new()
3468    };
3469
3470    if options.check_references {
3471        all_errors.extend(check_reference_integrity(schema, &schema_map, strict));
3472    }
3473
3474    for stmt in &statements {
3475        if options.semantic {
3476            all_errors.extend(check_semantics(stmt));
3477        }
3478        all_errors.extend(validate_statement_with_schema(
3479            stmt,
3480            &schema_map,
3481            &resolver_schema,
3482            strict,
3483        ));
3484        if options.check_types {
3485            all_errors.extend(check_types(
3486                stmt,
3487                dialect,
3488                &schema_map,
3489                effective_function_catalog,
3490                strict,
3491            ));
3492        }
3493        if options.check_references {
3494            all_errors.extend(check_query_reference_quality(
3495                stmt,
3496                &schema_map,
3497                &resolver_schema,
3498                strict,
3499                &declared_relationships,
3500            ));
3501        }
3502    }
3503
3504    ValidationResult::with_errors(all_errors)
3505}
3506
3507#[cfg(test)]
3508mod tests;