Skip to main content

polyglot_sql/
validation.rs

1//! Schema-aware and semantic SQL validation.
2//!
3//! This module extends syntax validation with:
4//! - schema checks (unknown tables/columns)
5//! - optional semantic warnings (SELECT *, LIMIT without ORDER BY, etc.)
6
7use crate::ast_transforms::get_aggregate_functions;
8use crate::dialects::{Dialect, DialectType};
9use crate::error::{ValidationError, ValidationResult};
10use crate::expressions::{
11    Column, DataType, Expression, Function, Insert, JoinKind, TableRef, Update,
12};
13use crate::function_catalog::FunctionCatalog;
14#[cfg(any(
15    feature = "function-catalog-clickhouse",
16    feature = "function-catalog-duckdb",
17    feature = "function-catalog-all-dialects"
18))]
19use crate::function_catalog::{
20    FunctionNameCase as CoreFunctionNameCase, FunctionSignature as CoreFunctionSignature,
21    HashMapFunctionCatalog,
22};
23use crate::function_registry::canonical_typed_function_name_upper;
24use crate::optimizer::annotate_types;
25use crate::resolver::Resolver;
26use crate::schema::{MappingSchema, Schema as SqlSchema, SchemaError, SchemaResult, TABLE_PARTS};
27use crate::scope::{build_scope, walk_in_scope};
28use crate::traversal::ExpressionWalk;
29use serde::{Deserialize, Serialize};
30use std::collections::{HashMap, HashSet};
31use std::sync::Arc;
32
33#[cfg(any(
34    feature = "function-catalog-clickhouse",
35    feature = "function-catalog-duckdb",
36    feature = "function-catalog-all-dialects"
37))]
38use std::sync::LazyLock;
39
/// Column definition used for schema-aware validation.
///
/// Only `name` is required when deserializing; every other field carries a
/// serde default and may be omitted from the payload.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SchemaColumn {
    /// Column name.
    pub name: String,
    /// Column data type as a free-form string, serialized under the key `type`.
    ///
    /// Defaults to an empty string when omitted. The string is classified with
    /// `canonical_type_family` when the schema maps are built.
    #[serde(default, rename = "type")]
    pub data_type: String,
    /// Whether the column allows NULL values (`None` when not specified).
    #[serde(default)]
    pub nullable: Option<bool>,
    /// Whether this column is part of a primary key (serialized as `primaryKey`).
    #[serde(default, rename = "primaryKey")]
    pub primary_key: bool,
    /// Whether this column has a uniqueness constraint.
    #[serde(default)]
    pub unique: bool,
    /// Optional column-level foreign key reference.
    #[serde(default)]
    pub references: Option<SchemaColumnReference>,
}
61
/// Column-level foreign key reference metadata.
///
/// Names the single target column that a `SchemaColumn` points at.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SchemaColumnReference {
    /// Referenced table name.
    pub table: String,
    /// Referenced column name.
    pub column: String,
    /// Optional schema/namespace of referenced table (`None` when omitted).
    #[serde(default)]
    pub schema: Option<String>,
}
73
/// Table-level foreign key reference metadata.
///
/// Unlike `SchemaColumnReference`, this form can describe a key made of
/// several source columns.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SchemaForeignKey {
    /// Optional FK name (`None` when omitted).
    #[serde(default)]
    pub name: Option<String>,
    /// Source columns in the current table.
    pub columns: Vec<String>,
    /// Referenced target table + columns.
    pub references: SchemaTableReference,
}
85
/// Target of a table-level foreign key.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SchemaTableReference {
    /// Referenced table name.
    pub table: String,
    /// Referenced target columns.
    pub columns: Vec<String>,
    /// Optional schema/namespace of referenced table (`None` when omitted).
    #[serde(default)]
    pub schema: Option<String>,
}
97
/// Table definition used for schema-aware validation.
///
/// `name` and `columns` are required when deserializing; the remaining fields
/// default to `None` / empty collections.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SchemaTable {
    /// Table name.
    pub name: String,
    /// Optional schema/namespace name.
    ///
    /// When present, the table is also registered under `schema.name` in the
    /// lookup maps built from a `ValidationSchema`.
    #[serde(default)]
    pub schema: Option<String>,
    /// Column definitions.
    pub columns: Vec<SchemaColumn>,
    /// Optional aliases that should resolve to this table.
    #[serde(default)]
    pub aliases: Vec<String>,
    /// Optional primary key column list (serialized as `primaryKey`).
    #[serde(default, rename = "primaryKey")]
    pub primary_key: Vec<String>,
    /// Optional unique key groups, one inner list per key (serialized as `uniqueKeys`).
    #[serde(default, rename = "uniqueKeys")]
    pub unique_keys: Vec<Vec<String>>,
    /// Optional table-level foreign keys (serialized as `foreignKeys`).
    #[serde(default, rename = "foreignKeys")]
    pub foreign_keys: Vec<SchemaForeignKey>,
}
121
/// Schema payload used for schema-aware validation.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ValidationSchema {
    /// Known tables.
    pub tables: Vec<SchemaTable>,
    /// Default strict mode for unknown identifiers (`None` when omitted).
    ///
    /// `SchemaValidationOptions::strict` is documented as overriding this value.
    #[serde(default)]
    pub strict: Option<bool>,
}
131
/// Options for schema-aware validation.
///
/// Every field has a serde default, so an empty options object deserializes
/// to the same value as `SchemaValidationOptions::default()`.
#[derive(Clone, Serialize, Deserialize, Default)]
pub struct SchemaValidationOptions {
    /// Enables type compatibility checks for expressions, DML assignments, and set operations.
    #[serde(default)]
    pub check_types: bool,
    /// Enables FK/reference integrity checks and query-level reference quality checks.
    #[serde(default)]
    pub check_references: bool,
    /// If true/false, overrides schema.strict.
    #[serde(default)]
    pub strict: Option<bool>,
    /// Enables semantic warnings (W001..W004).
    #[serde(default)]
    pub semantic: bool,
    /// Enables strict syntax checks (e.g. rejects trailing commas before clause boundaries).
    #[serde(default)]
    pub strict_syntax: bool,
    /// Optional external function catalog plugin for dialect-specific function validation.
    ///
    /// Marked `#[serde(skip)]`: it is never read from or written to the
    /// serialized form, so it is `None` after deserialization and has to be
    /// assigned in code.
    #[serde(skip, default)]
    pub function_catalog: Option<Arc<dyn FunctionCatalog>>,
}
154
#[cfg(any(
    feature = "function-catalog-clickhouse",
    feature = "function-catalog-duckdb",
    feature = "function-catalog-all-dialects"
))]
/// Translates the catalog crate's name-case setting into this crate's
/// equivalent enum, variant for variant.
fn to_core_name_case(
    case: polyglot_sql_function_catalogs::FunctionNameCase,
) -> CoreFunctionNameCase {
    use polyglot_sql_function_catalogs::FunctionNameCase as ExternalCase;

    match case {
        ExternalCase::Sensitive => CoreFunctionNameCase::Sensitive,
        ExternalCase::Insensitive => CoreFunctionNameCase::Insensitive,
    }
}
172
#[cfg(any(
    feature = "function-catalog-clickhouse",
    feature = "function-catalog-duckdb",
    feature = "function-catalog-all-dialects"
))]
/// Converts catalog-crate function signatures into this crate's signature
/// type, copying the arity bounds and preserving order.
fn to_core_signatures(
    signatures: Vec<polyglot_sql_function_catalogs::FunctionSignature>,
) -> Vec<CoreFunctionSignature> {
    let mut converted = Vec::with_capacity(signatures.len());
    for external in signatures {
        converted.push(CoreFunctionSignature {
            min_arity: external.min_arity,
            max_arity: external.max_arity,
        });
    }
    converted
}
189
#[cfg(any(
    feature = "function-catalog-clickhouse",
    feature = "function-catalog-duckdb",
    feature = "function-catalog-all-dialects"
))]
/// Adapter that receives registrations from the external function-catalog
/// crate and forwards them into a `HashMapFunctionCatalog`.
struct EmbeddedCatalogSink<'a> {
    /// Catalog being populated.
    catalog: &'a mut HashMapFunctionCatalog,
    /// Memoized results of parsing dialect names; `None` records a name that
    /// failed to parse as a `DialectType`.
    dialect_cache: HashMap<&'static str, Option<DialectType>>,
}
199
#[cfg(any(
    feature = "function-catalog-clickhouse",
    feature = "function-catalog-duckdb",
    feature = "function-catalog-all-dialects"
))]
impl<'a> EmbeddedCatalogSink<'a> {
    /// Maps a dialect name to its `DialectType`, memoizing the outcome.
    ///
    /// The name is parsed only on the first request; both successful and
    /// failed parses are cached, a failure being stored as `None`.
    fn resolve_dialect(&mut self, dialect: &'static str) -> Option<DialectType> {
        *self
            .dialect_cache
            .entry(dialect)
            .or_insert_with(|| dialect.parse::<DialectType>().ok())
    }
}
215
#[cfg(any(
    feature = "function-catalog-clickhouse",
    feature = "function-catalog-duckdb",
    feature = "function-catalog-all-dialects"
))]
/// Forwards every sink callback to the wrapped catalog. Calls naming a
/// dialect that does not parse as a `DialectType` are silently dropped.
impl<'a> polyglot_sql_function_catalogs::CatalogSink for EmbeddedCatalogSink<'a> {
    /// Sets the default function-name case handling for a dialect.
    fn set_dialect_name_case(
        &mut self,
        dialect: &'static str,
        name_case: polyglot_sql_function_catalogs::FunctionNameCase,
    ) {
        let Some(core_dialect) = self.resolve_dialect(dialect) else {
            return;
        };
        let core_case = to_core_name_case(name_case);
        self.catalog.set_dialect_name_case(core_dialect, core_case);
    }

    /// Sets the name case handling for one function of a dialect.
    fn set_function_name_case(
        &mut self,
        dialect: &'static str,
        function_name: &str,
        name_case: polyglot_sql_function_catalogs::FunctionNameCase,
    ) {
        let Some(core_dialect) = self.resolve_dialect(dialect) else {
            return;
        };
        let core_case = to_core_name_case(name_case);
        self.catalog
            .set_function_name_case(core_dialect, function_name, core_case);
    }

    /// Registers a function and its accepted signatures for a dialect.
    fn register(
        &mut self,
        dialect: &'static str,
        function_name: &str,
        signatures: Vec<polyglot_sql_function_catalogs::FunctionSignature>,
    ) {
        let Some(core_dialect) = self.resolve_dialect(dialect) else {
            return;
        };
        let core_signatures = to_core_signatures(signatures);
        self.catalog
            .register(core_dialect, function_name, core_signatures);
    }
}
260
#[cfg(any(
    feature = "function-catalog-clickhouse",
    feature = "function-catalog-duckdb",
    feature = "function-catalog-all-dialects"
))]
/// Returns a shared handle to the embedded function catalog.
///
/// The catalog is built lazily on first use by letting the external catalog
/// crate register its enabled dialects through an `EmbeddedCatalogSink`;
/// later calls only clone the `Arc`.
fn embedded_function_catalog_arc() -> Arc<dyn FunctionCatalog> {
    static EMBEDDED: LazyLock<Arc<HashMapFunctionCatalog>> = LazyLock::new(|| {
        let mut catalog = HashMapFunctionCatalog::default();
        {
            // The sink mutably borrows the catalog; keep it in its own scope so
            // the borrow has ended before the catalog is moved into the `Arc`.
            let mut sink = EmbeddedCatalogSink {
                dialect_cache: HashMap::new(),
                catalog: &mut catalog,
            };
            polyglot_sql_function_catalogs::register_enabled_catalogs(&mut sink);
        }
        Arc::new(catalog)
    });

    let shared: Arc<HashMapFunctionCatalog> = Arc::clone(&EMBEDDED);
    shared
}
279
#[cfg(any(
    feature = "function-catalog-clickhouse",
    feature = "function-catalog-duckdb",
    feature = "function-catalog-all-dialects"
))]
/// Default function catalog when at least one embedded catalog feature is
/// enabled: always `Some`, wrapping the shared embedded catalog.
fn default_embedded_function_catalog() -> Option<Arc<dyn FunctionCatalog>> {
    Some(embedded_function_catalog_arc())
}
288
#[cfg(not(any(
    feature = "function-catalog-clickhouse",
    feature = "function-catalog-duckdb",
    feature = "function-catalog-all-dialects"
)))]
/// Default function catalog when no embedded catalog feature is enabled:
/// there is nothing to embed, so this is always `None`.
fn default_embedded_function_catalog() -> Option<Arc<dyn FunctionCatalog>> {
    None
}
297
/// Validation error/warning codes used by schema-aware validation.
///
/// Constants prefixed `E_` hold codes of the form `Ennn` and constants
/// prefixed `W_` hold codes of the form `Wnnn`. The number ranges are grouped
/// by the check family that emits them, as noted on each group below.
pub mod validation_codes {
    // Existing schema and semantic checks.
    // E000 and E200-E203: parse/options failures, unknown identifiers and
    // function name/arity problems.
    pub const E_PARSE_OR_OPTIONS: &str = "E000";
    pub const E_UNKNOWN_TABLE: &str = "E200";
    pub const E_UNKNOWN_COLUMN: &str = "E201";
    pub const E_UNKNOWN_FUNCTION: &str = "E202";
    pub const E_INVALID_FUNCTION_ARITY: &str = "E203";

    // W001-W004: semantic warnings.
    pub const W_SELECT_STAR: &str = "W001";
    pub const W_AGGREGATE_WITHOUT_GROUP_BY: &str = "W002";
    pub const W_DISTINCT_ORDER_BY: &str = "W003";
    pub const W_LIMIT_WITHOUT_ORDER_BY: &str = "W004";

    // Phase 2 (type checks): E210-E219, W210-W219.
    pub const E_TYPE_MISMATCH: &str = "E210";
    pub const E_INVALID_PREDICATE_TYPE: &str = "E211";
    pub const E_INVALID_ARITHMETIC_TYPE: &str = "E212";
    pub const E_INVALID_FUNCTION_ARGUMENT_TYPE: &str = "E213";
    pub const E_INVALID_ASSIGNMENT_TYPE: &str = "E214";
    pub const E_SETOP_TYPE_MISMATCH: &str = "E215";
    pub const E_SETOP_ARITY_MISMATCH: &str = "E216";
    pub const E_INCOMPATIBLE_COMPARISON_TYPES: &str = "E217";
    pub const E_INVALID_CAST: &str = "E218";
    pub const E_UNKNOWN_INFERRED_TYPE: &str = "E219";

    pub const W_IMPLICIT_CAST_COMPARISON: &str = "W210";
    pub const W_IMPLICIT_CAST_ARITHMETIC: &str = "W211";
    pub const W_IMPLICIT_CAST_ASSIGNMENT: &str = "W212";
    pub const W_LOSSY_CAST: &str = "W213";
    pub const W_SETOP_IMPLICIT_COERCION: &str = "W214";
    pub const W_PREDICATE_NULLABILITY: &str = "W215";
    pub const W_FUNCTION_ARGUMENT_COERCION: &str = "W216";
    pub const W_AGGREGATE_TYPE_COERCION: &str = "W217";
    pub const W_POSSIBLE_OVERFLOW: &str = "W218";
    pub const W_POSSIBLE_TRUNCATION: &str = "W219";

    // Phase 2 (reference checks): E220-E229, W220-W229.
    // Only E220-E224 and W220-W222 are assigned so far.
    pub const E_INVALID_FOREIGN_KEY_REFERENCE: &str = "E220";
    pub const E_AMBIGUOUS_COLUMN_REFERENCE: &str = "E221";
    pub const E_UNRESOLVED_REFERENCE: &str = "E222";
    pub const E_CTE_COLUMN_COUNT_MISMATCH: &str = "E223";
    pub const E_MISSING_REFERENCE_TARGET: &str = "E224";

    pub const W_CARTESIAN_JOIN: &str = "W220";
    pub const W_JOIN_NOT_USING_DECLARED_REFERENCE: &str = "W221";
    pub const W_WEAK_REFERENCE_INTEGRITY: &str = "W222";
}
346
/// Canonical type family used by schema/type checks.
///
/// Serialized in `snake_case` (for example `Timestamp` becomes `"timestamp"`).
/// Schema type strings are classified into these families by
/// `canonical_type_family`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum TypeFamily {
    /// Empty or unrecognized type string.
    Unknown,
    /// `bool` / `boolean`.
    Boolean,
    /// Signed and unsigned integer types, including the serial types.
    Integer,
    /// Decimal and floating-point types.
    Numeric,
    /// Character and text types.
    String,
    /// Binary / byte-string types.
    Binary,
    /// Calendar date without a time of day.
    Date,
    /// Time of day.
    Time,
    /// Timestamp / datetime types, with or without time zone.
    Timestamp,
    /// Interval type.
    Interval,
    /// `json`, `jsonb` and `variant`.
    Json,
    /// `uuid` / `uniqueidentifier`.
    Uuid,
    /// Array and list types.
    Array,
    /// Map types.
    Map,
    /// Struct, row, record and object types.
    Struct,
}
367
368impl TypeFamily {
369    pub fn is_numeric(self) -> bool {
370        matches!(self, TypeFamily::Integer | TypeFamily::Numeric)
371    }
372
373    pub fn is_temporal(self) -> bool {
374        matches!(
375            self,
376            TypeFamily::Date | TypeFamily::Time | TypeFamily::Timestamp | TypeFamily::Interval
377        )
378    }
379}
380
/// Per-table lookup data derived from a `SchemaTable`.
#[derive(Debug, Clone)]
struct TableSchemaEntry {
    /// Lowercased column name mapped to its canonical type family.
    columns: HashMap<String, TypeFamily>,
    /// Lowercased column names in the order they were declared.
    column_order: Vec<String>,
}
386
/// Returns the lowercase form of `s`; used to normalize identifiers for
/// case-insensitive lookups.
fn lower(s: &str) -> String {
    str::to_lowercase(s)
}
390
/// Splits a type string of the form `base(inner)` into its trimmed `base`
/// and `inner` parts.
///
/// The split happens at the first `(`, and the string must end with `)`;
/// otherwise `None` is returned. Nested parentheses stay inside `inner`.
fn split_type_args(data_type: &str) -> Option<(&str, &str)> {
    let (head, tail) = data_type.split_once('(')?;
    let args = tail.strip_suffix(')')?;
    Some((head.trim(), args.trim()))
}
400
401/// Canonicalize a schema type string into a stable `TypeFamily`.
402pub fn canonical_type_family(data_type: &str) -> TypeFamily {
403    let trimmed = data_type
404        .trim()
405        .trim_matches(|c| c == '"' || c == '\'' || c == '`');
406    if trimmed.is_empty() {
407        return TypeFamily::Unknown;
408    }
409
410    // Normalize whitespace and lowercase for matching.
411    let normalized = trimmed
412        .split_whitespace()
413        .collect::<Vec<_>>()
414        .join(" ")
415        .to_lowercase();
416
417    // Strip common wrappers first.
418    if let Some((base, inner)) = split_type_args(&normalized) {
419        match base {
420            "nullable" | "lowcardinality" => return canonical_type_family(inner),
421            "array" | "list" => return TypeFamily::Array,
422            "map" => return TypeFamily::Map,
423            "struct" | "row" | "record" => return TypeFamily::Struct,
424            _ => {}
425        }
426    }
427
428    if normalized.starts_with("array<") || normalized.starts_with("list<") {
429        return TypeFamily::Array;
430    }
431    if normalized.starts_with("map<") {
432        return TypeFamily::Map;
433    }
434    if normalized.starts_with("struct<")
435        || normalized.starts_with("row<")
436        || normalized.starts_with("record<")
437        || normalized.starts_with("object<")
438    {
439        return TypeFamily::Struct;
440    }
441
442    if normalized.ends_with("[]") {
443        return TypeFamily::Array;
444    }
445
446    // Remove parameter list if present, e.g. VARCHAR(255), DECIMAL(10,2).
447    let mut base = normalized
448        .split('(')
449        .next()
450        .unwrap_or("")
451        .trim()
452        .to_string();
453    if base.is_empty() {
454        return TypeFamily::Unknown;
455    }
456
457    base = base.strip_prefix("unsigned ").unwrap_or(&base).to_string();
458    base = base.strip_suffix(" unsigned").unwrap_or(&base).to_string();
459
460    match base.as_str() {
461        "bool" | "boolean" => TypeFamily::Boolean,
462        "tinyint" | "smallint" | "int2" | "int" | "integer" | "int4" | "int8" | "bigint"
463        | "serial" | "smallserial" | "bigserial" | "utinyint" | "usmallint" | "uinteger"
464        | "ubigint" | "uint8" | "uint16" | "uint32" | "uint64" | "int16" | "int32" | "int64" => {
465            TypeFamily::Integer
466        }
467        "numeric" | "decimal" | "dec" | "number" | "float" | "float4" | "float8" | "real"
468        | "double" | "double precision" | "bfloat16" | "float16" | "float32" | "float64" => {
469            TypeFamily::Numeric
470        }
471        "char" | "character" | "varchar" | "character varying" | "nchar" | "nvarchar" | "text"
472        | "string" | "clob" => TypeFamily::String,
473        "binary" | "varbinary" | "blob" | "bytea" | "bytes" => TypeFamily::Binary,
474        "date" => TypeFamily::Date,
475        "time" => TypeFamily::Time,
476        "timestamp"
477        | "timestamptz"
478        | "datetime"
479        | "datetime2"
480        | "smalldatetime"
481        | "timestamp with time zone"
482        | "timestamp without time zone" => TypeFamily::Timestamp,
483        "interval" => TypeFamily::Interval,
484        "json" | "jsonb" | "variant" => TypeFamily::Json,
485        "uuid" | "uniqueidentifier" => TypeFamily::Uuid,
486        "array" | "list" => TypeFamily::Array,
487        "map" => TypeFamily::Map,
488        "struct" | "row" | "record" | "object" => TypeFamily::Struct,
489        _ => TypeFamily::Unknown,
490    }
491}
492
493fn build_schema_map(schema: &ValidationSchema) -> HashMap<String, TableSchemaEntry> {
494    let mut map = HashMap::new();
495
496    for table in &schema.tables {
497        let column_order: Vec<String> = table.columns.iter().map(|c| lower(&c.name)).collect();
498        let columns: HashMap<String, TypeFamily> = table
499            .columns
500            .iter()
501            .map(|c| (lower(&c.name), canonical_type_family(&c.data_type)))
502            .collect();
503        let entry = TableSchemaEntry {
504            columns,
505            column_order,
506        };
507
508        let simple_name = lower(&table.name);
509        map.insert(simple_name, entry.clone());
510
511        if let Some(table_schema) = &table.schema {
512            map.insert(
513                format!("{}.{}", lower(table_schema), lower(&table.name)),
514                entry.clone(),
515            );
516        }
517
518        for alias in &table.aliases {
519            map.insert(lower(alias), entry.clone());
520        }
521    }
522
523    map
524}
525
/// Maps a `TypeFamily` to a representative `DataType` for the resolver schema.
///
/// Each family is represented by one unparameterized type (no length,
/// precision or scale); container families get `Unknown` element/key/value
/// types and an empty field list.
fn type_family_to_data_type(family: TypeFamily) -> DataType {
    match family {
        TypeFamily::Unknown => DataType::Unknown,
        TypeFamily::Boolean => DataType::Boolean,
        TypeFamily::Integer => DataType::Int {
            length: None,
            integer_spelling: false,
        },
        // The whole numeric family, decimals included, is represented as Double.
        TypeFamily::Numeric => DataType::Double {
            precision: None,
            scale: None,
        },
        TypeFamily::String => DataType::VarChar {
            length: None,
            parenthesized_length: false,
        },
        TypeFamily::Binary => DataType::VarBinary { length: None },
        TypeFamily::Date => DataType::Date,
        TypeFamily::Time => DataType::Time {
            precision: None,
            timezone: false,
        },
        TypeFamily::Timestamp => DataType::Timestamp {
            precision: None,
            timezone: false,
        },
        TypeFamily::Interval => DataType::Interval {
            unit: None,
            to: None,
        },
        TypeFamily::Json => DataType::Json,
        TypeFamily::Uuid => DataType::Uuid,
        TypeFamily::Array => DataType::Array {
            element_type: Box::new(DataType::Unknown),
            dimension: None,
        },
        TypeFamily::Map => DataType::Map {
            key_type: Box::new(DataType::Unknown),
            value_type: Box::new(DataType::Unknown),
        },
        TypeFamily::Struct => DataType::Struct {
            fields: Vec::new(),
            nested: false,
        },
    }
}
572
573fn build_resolver_schema(schema: &ValidationSchema) -> MappingSchema {
574    let mut mapping = MappingSchema::new();
575
576    for table in &schema.tables {
577        let columns: Vec<(String, DataType)> = table
578            .columns
579            .iter()
580            .map(|column| {
581                (
582                    lower(&column.name),
583                    type_family_to_data_type(canonical_type_family(&column.data_type)),
584                )
585            })
586            .collect();
587
588        let mut table_names = Vec::new();
589        table_names.push(lower(&table.name));
590        if let Some(table_schema) = &table.schema {
591            table_names.push(format!("{}.{}", lower(table_schema), lower(&table.name)));
592        }
593        for alias in &table.aliases {
594            table_names.push(lower(alias));
595        }
596
597        let mut dedup = HashSet::new();
598        for table_name in table_names {
599            if dedup.insert(table_name.clone()) {
600                let _ = mapping.add_table(&table_name, &columns, None);
601            }
602        }
603    }
604
605    mapping
606}
607
/// Build a `MappingSchema` from a `ValidationSchema` payload.
///
/// This is useful for APIs that already accept `ValidationSchema`-shaped input
/// (for example JSON wrappers) and need to run schema-aware lineage or other
/// resolver-based analysis.
///
/// This is the public entry point for the same conversion the validator uses
/// internally (`build_resolver_schema`).
pub fn mapping_schema_from_validation_schema(schema: &ValidationSchema) -> MappingSchema {
    build_resolver_schema(schema)
}
616
617fn collect_cte_aliases(expr: &Expression) -> HashSet<String> {
618    let mut aliases = HashSet::new();
619
620    for node in expr.dfs() {
621        match node {
622            Expression::Select(select) => {
623                if let Some(with) = &select.with {
624                    for cte in &with.ctes {
625                        aliases.insert(lower(&cte.alias.name));
626                    }
627                }
628            }
629            Expression::Insert(insert) => {
630                if let Some(with) = &insert.with {
631                    for cte in &with.ctes {
632                        aliases.insert(lower(&cte.alias.name));
633                    }
634                }
635            }
636            Expression::Update(update) => {
637                if let Some(with) = &update.with {
638                    for cte in &with.ctes {
639                        aliases.insert(lower(&cte.alias.name));
640                    }
641                }
642            }
643            Expression::Delete(delete) => {
644                if let Some(with) = &delete.with {
645                    for cte in &with.ctes {
646                        aliases.insert(lower(&cte.alias.name));
647                    }
648                }
649            }
650            Expression::Union(union) => {
651                if let Some(with) = &union.with {
652                    for cte in &with.ctes {
653                        aliases.insert(lower(&cte.alias.name));
654                    }
655                }
656            }
657            Expression::Intersect(intersect) => {
658                if let Some(with) = &intersect.with {
659                    for cte in &with.ctes {
660                        aliases.insert(lower(&cte.alias.name));
661                    }
662                }
663            }
664            Expression::Except(except) => {
665                if let Some(with) = &except.with {
666                    for cte in &with.ctes {
667                        aliases.insert(lower(&cte.alias.name));
668                    }
669                }
670            }
671            Expression::Merge(merge) => {
672                if let Some(with_) = &merge.with_ {
673                    if let Expression::With(with_clause) = with_.as_ref() {
674                        for cte in &with_clause.ctes {
675                            aliases.insert(lower(&cte.alias.name));
676                        }
677                    }
678                }
679            }
680            _ => {}
681        }
682    }
683
684    aliases
685}
686
687fn table_ref_candidates(table: &TableRef) -> Vec<String> {
688    let name = lower(&table.name.name);
689    let schema = table.schema.as_ref().map(|s| lower(&s.name));
690    let catalog = table.catalog.as_ref().map(|c| lower(&c.name));
691
692    let mut candidates = Vec::new();
693    if let (Some(catalog), Some(schema)) = (&catalog, &schema) {
694        candidates.push(format!("{}.{}.{}", catalog, schema, name));
695    }
696    if let Some(schema) = &schema {
697        candidates.push(format!("{}.{}", schema, name));
698    }
699    candidates.push(name);
700    candidates
701}
702
703fn table_ref_display_name(table: &TableRef) -> String {
704    let mut parts = Vec::new();
705    if let Some(catalog) = &table.catalog {
706        parts.push(catalog.name.clone());
707    }
708    if let Some(schema) = &table.schema {
709        parts.push(schema.name.clone());
710    }
711    parts.push(table.name.name.clone());
712    parts.join(".")
713}
714
/// Tables a statement refers to, as gathered by `collect_type_check_context`.
#[derive(Debug, Default, Clone)]
struct TypeCheckContext {
    /// Schema-map keys of every referenced table found in the schema.
    referenced_tables: HashSet<String>,
    /// Lowercased table name or alias as written in the query, mapped to the
    /// schema-map key it resolves to.
    table_aliases: HashMap<String, String>,
}
720
/// Returns the lowercase display name of a type family.
///
/// The strings match the variant names in `snake_case`.
fn type_family_name(family: TypeFamily) -> &'static str {
    match family {
        TypeFamily::Unknown => "unknown",
        TypeFamily::Boolean => "boolean",
        TypeFamily::Integer => "integer",
        TypeFamily::Numeric => "numeric",
        TypeFamily::String => "string",
        TypeFamily::Binary => "binary",
        TypeFamily::Date => "date",
        TypeFamily::Time => "time",
        TypeFamily::Timestamp => "timestamp",
        TypeFamily::Interval => "interval",
        TypeFamily::Json => "json",
        TypeFamily::Uuid => "uuid",
        TypeFamily::Array => "array",
        TypeFamily::Map => "map",
        TypeFamily::Struct => "struct",
    }
}
740
741fn is_string_like(family: TypeFamily) -> bool {
742    matches!(family, TypeFamily::String)
743}
744
745fn is_string_or_binary(family: TypeFamily) -> bool {
746    matches!(family, TypeFamily::String | TypeFamily::Binary)
747}
748
749fn type_issue(
750    strict: bool,
751    error_code: &str,
752    warning_code: &str,
753    message: impl Into<String>,
754) -> ValidationError {
755    if strict {
756        ValidationError::error(message.into(), error_code)
757    } else {
758        ValidationError::warning(message.into(), warning_code)
759    }
760}
761
/// Classifies a parsed `DataType` into its `TypeFamily`.
///
/// `Nullable` is unwrapped recursively, and `Custom` types are classified by
/// name through `canonical_type_family`. The match is exhaustive, so adding
/// a `DataType` variant requires deciding its family here.
fn data_type_family(data_type: &DataType) -> TypeFamily {
    match data_type {
        DataType::Boolean => TypeFamily::Boolean,
        DataType::TinyInt { .. }
        | DataType::SmallInt { .. }
        | DataType::Int { .. }
        | DataType::BigInt { .. } => TypeFamily::Integer,
        DataType::Float { .. } | DataType::Double { .. } | DataType::Decimal { .. } => {
            TypeFamily::Numeric
        }
        DataType::Char { .. }
        | DataType::VarChar { .. }
        | DataType::String { .. }
        | DataType::Text
        | DataType::TextWithLength { .. }
        | DataType::CharacterSet { .. } => TypeFamily::String,
        DataType::Binary { .. } | DataType::VarBinary { .. } | DataType::Blob => TypeFamily::Binary,
        DataType::Date => TypeFamily::Date,
        DataType::Time { .. } => TypeFamily::Time,
        DataType::Timestamp { .. } => TypeFamily::Timestamp,
        DataType::Interval { .. } => TypeFamily::Interval,
        DataType::Json | DataType::JsonB => TypeFamily::Json,
        DataType::Uuid => TypeFamily::Uuid,
        DataType::Array { .. } | DataType::List { .. } => TypeFamily::Array,
        DataType::Map { .. } => TypeFamily::Map,
        DataType::Struct { .. } | DataType::Object { .. } | DataType::Union { .. } => {
            TypeFamily::Struct
        }
        // Nullability does not change the family; look through the wrapper.
        DataType::Nullable { inner } => data_type_family(inner),
        // Custom types carry only a name, so reuse the string classifier.
        DataType::Custom { name } => canonical_type_family(name),
        DataType::Unknown => TypeFamily::Unknown,
        // The remaining variants have no family of their own and are folded
        // into an existing one.
        DataType::Bit { .. } | DataType::VarBit { .. } => TypeFamily::Binary,
        DataType::Enum { .. } | DataType::Set { .. } => TypeFamily::String,
        DataType::Vector { .. } => TypeFamily::Array,
        DataType::Geometry { .. } | DataType::Geography { .. } => TypeFamily::Struct,
    }
}
799
800fn collect_type_check_context(
801    stmt: &Expression,
802    schema_map: &HashMap<String, TableSchemaEntry>,
803) -> TypeCheckContext {
804    fn add_table_to_context(
805        table: &TableRef,
806        schema_map: &HashMap<String, TableSchemaEntry>,
807        context: &mut TypeCheckContext,
808    ) {
809        let resolved_key = table_ref_candidates(table)
810            .into_iter()
811            .find(|k| schema_map.contains_key(k));
812
813        let Some(table_key) = resolved_key else {
814            return;
815        };
816
817        context.referenced_tables.insert(table_key.clone());
818        context
819            .table_aliases
820            .insert(lower(&table.name.name), table_key.clone());
821        if let Some(alias) = &table.alias {
822            context
823                .table_aliases
824                .insert(lower(&alias.name), table_key.clone());
825        }
826    }
827
828    let mut context = TypeCheckContext::default();
829    let cte_aliases = collect_cte_aliases(stmt);
830
831    for node in stmt.find_all(|e| matches!(e, Expression::Table(_))) {
832        let Expression::Table(table) = node else {
833            continue;
834        };
835
836        if cte_aliases.contains(&lower(&table.name.name)) {
837            continue;
838        }
839
840        add_table_to_context(table, schema_map, &mut context);
841    }
842
843    // Seed DML target tables explicitly because they are struct fields and may
844    // not appear as standalone Expression::Table nodes in traversal output.
845    match stmt {
846        Expression::Insert(insert) => {
847            add_table_to_context(&insert.table, schema_map, &mut context);
848        }
849        Expression::Update(update) => {
850            add_table_to_context(&update.table, schema_map, &mut context);
851            for table in &update.extra_tables {
852                add_table_to_context(table, schema_map, &mut context);
853            }
854        }
855        Expression::Delete(delete) => {
856            add_table_to_context(&delete.table, schema_map, &mut context);
857            for table in &delete.using {
858                add_table_to_context(table, schema_map, &mut context);
859            }
860            for table in &delete.tables {
861                add_table_to_context(table, schema_map, &mut context);
862            }
863        }
864        _ => {}
865    }
866
867    context
868}
869
870fn resolve_table_schema_entry<'a>(
871    table: &TableRef,
872    schema_map: &'a HashMap<String, TableSchemaEntry>,
873) -> Option<(String, &'a TableSchemaEntry)> {
874    let key = table_ref_candidates(table)
875        .into_iter()
876        .find(|k| schema_map.contains_key(k))?;
877    let entry = schema_map.get(&key)?;
878    Some((key, entry))
879}
880
881fn reference_issue(strict: bool, message: impl Into<String>) -> ValidationError {
882    if strict {
883        ValidationError::error(
884            message.into(),
885            validation_codes::E_INVALID_FOREIGN_KEY_REFERENCE,
886        )
887    } else {
888        ValidationError::warning(message.into(), validation_codes::W_WEAK_REFERENCE_INTEGRITY)
889    }
890}
891
892fn reference_table_candidates(
893    table_name: &str,
894    explicit_schema: Option<&str>,
895    source_schema: Option<&str>,
896) -> Vec<String> {
897    let mut candidates = Vec::new();
898    let raw = lower(table_name);
899
900    if let Some(schema) = explicit_schema {
901        candidates.push(format!("{}.{}", lower(schema), raw));
902    }
903
904    if raw.contains('.') {
905        candidates.push(raw.clone());
906        if let Some(last) = raw.rsplit('.').next() {
907            candidates.push(last.to_string());
908        }
909    } else {
910        if let Some(schema) = source_schema {
911            candidates.push(format!("{}.{}", lower(schema), raw));
912        }
913        candidates.push(raw);
914    }
915
916    let mut dedup = HashSet::new();
917    candidates
918        .into_iter()
919        .filter(|c| dedup.insert(c.clone()))
920        .collect()
921}
922
923fn resolve_reference_table_key(
924    table_name: &str,
925    explicit_schema: Option<&str>,
926    source_schema: Option<&str>,
927    schema_map: &HashMap<String, TableSchemaEntry>,
928) -> Option<String> {
929    reference_table_candidates(table_name, explicit_schema, source_schema)
930        .into_iter()
931        .find(|candidate| schema_map.contains_key(candidate))
932}
933
934fn key_types_compatible(source: TypeFamily, target: TypeFamily) -> bool {
935    if source == TypeFamily::Unknown || target == TypeFamily::Unknown {
936        return true;
937    }
938    if source == target {
939        return true;
940    }
941    if source.is_numeric() && target.is_numeric() {
942        return true;
943    }
944    if source.is_temporal() && target.is_temporal() {
945        return true;
946    }
947    false
948}
949
950fn table_key_hints(table: &SchemaTable) -> HashSet<String> {
951    let mut hints = HashSet::new();
952    for column in &table.columns {
953        if column.primary_key || column.unique {
954            hints.insert(lower(&column.name));
955        }
956    }
957    for key_col in &table.primary_key {
958        hints.insert(lower(key_col));
959    }
960    for group in &table.unique_keys {
961        if group.len() == 1 {
962            if let Some(col) = group.first() {
963                hints.insert(lower(col));
964            }
965        }
966    }
967    hints
968}
969
970fn check_reference_integrity(
971    schema: &ValidationSchema,
972    schema_map: &HashMap<String, TableSchemaEntry>,
973    strict: bool,
974) -> Vec<ValidationError> {
975    let mut errors = Vec::new();
976
977    let mut key_hints_lookup: HashMap<String, HashSet<String>> = HashMap::new();
978    for table in &schema.tables {
979        let simple = lower(&table.name);
980        key_hints_lookup.insert(simple, table_key_hints(table));
981        if let Some(schema_name) = &table.schema {
982            let qualified = format!("{}.{}", lower(schema_name), lower(&table.name));
983            key_hints_lookup.insert(qualified, table_key_hints(table));
984        }
985    }
986
987    for table in &schema.tables {
988        let source_table_display = if let Some(schema_name) = &table.schema {
989            format!("{}.{}", schema_name, table.name)
990        } else {
991            table.name.clone()
992        };
993        let source_schema = table.schema.as_deref();
994        let source_columns: HashMap<String, TypeFamily> = table
995            .columns
996            .iter()
997            .map(|col| (lower(&col.name), canonical_type_family(&col.data_type)))
998            .collect();
999
1000        for source_col in &table.columns {
1001            let Some(reference) = &source_col.references else {
1002                continue;
1003            };
1004            let source_type = canonical_type_family(&source_col.data_type);
1005
1006            let Some(target_key) = resolve_reference_table_key(
1007                &reference.table,
1008                reference.schema.as_deref(),
1009                source_schema,
1010                schema_map,
1011            ) else {
1012                errors.push(reference_issue(
1013                    strict,
1014                    format!(
1015                        "Foreign key reference '{}.{}' points to unknown table '{}'",
1016                        source_table_display, source_col.name, reference.table
1017                    ),
1018                ));
1019                continue;
1020            };
1021
1022            let target_column = lower(&reference.column);
1023            let Some(target_entry) = schema_map.get(&target_key) else {
1024                errors.push(reference_issue(
1025                    strict,
1026                    format!(
1027                        "Foreign key reference '{}.{}' points to unknown table '{}'",
1028                        source_table_display, source_col.name, reference.table
1029                    ),
1030                ));
1031                continue;
1032            };
1033
1034            let Some(target_type) = target_entry.columns.get(&target_column).copied() else {
1035                errors.push(reference_issue(
1036                    strict,
1037                    format!(
1038                        "Foreign key reference '{}.{}' points to unknown column '{}.{}'",
1039                        source_table_display, source_col.name, target_key, reference.column
1040                    ),
1041                ));
1042                continue;
1043            };
1044
1045            if !key_types_compatible(source_type, target_type) {
1046                errors.push(reference_issue(
1047                    strict,
1048                    format!(
1049                        "Foreign key type mismatch for '{}.{}' -> '{}.{}': {} vs {}",
1050                        source_table_display,
1051                        source_col.name,
1052                        target_key,
1053                        reference.column,
1054                        type_family_name(source_type),
1055                        type_family_name(target_type)
1056                    ),
1057                ));
1058            }
1059
1060            if let Some(target_key_hints) = key_hints_lookup.get(&target_key) {
1061                if !target_key_hints.contains(&target_column) {
1062                    errors.push(ValidationError::warning(
1063                        format!(
1064                            "Referenced column '{}.{}' is not marked as primary/unique key",
1065                            target_key, reference.column
1066                        ),
1067                        validation_codes::W_WEAK_REFERENCE_INTEGRITY,
1068                    ));
1069                }
1070            }
1071        }
1072
1073        for foreign_key in &table.foreign_keys {
1074            if foreign_key.columns.is_empty() || foreign_key.references.columns.is_empty() {
1075                errors.push(reference_issue(
1076                    strict,
1077                    format!(
1078                        "Table-level foreign key on '{}' has empty source or target column list",
1079                        source_table_display
1080                    ),
1081                ));
1082                continue;
1083            }
1084            if foreign_key.columns.len() != foreign_key.references.columns.len() {
1085                errors.push(reference_issue(
1086                    strict,
1087                    format!(
1088                        "Table-level foreign key on '{}' has {} source columns but {} target columns",
1089                        source_table_display,
1090                        foreign_key.columns.len(),
1091                        foreign_key.references.columns.len()
1092                    ),
1093                ));
1094                continue;
1095            }
1096
1097            let Some(target_key) = resolve_reference_table_key(
1098                &foreign_key.references.table,
1099                foreign_key.references.schema.as_deref(),
1100                source_schema,
1101                schema_map,
1102            ) else {
1103                errors.push(reference_issue(
1104                    strict,
1105                    format!(
1106                        "Table-level foreign key on '{}' points to unknown table '{}'",
1107                        source_table_display, foreign_key.references.table
1108                    ),
1109                ));
1110                continue;
1111            };
1112
1113            let Some(target_entry) = schema_map.get(&target_key) else {
1114                errors.push(reference_issue(
1115                    strict,
1116                    format!(
1117                        "Table-level foreign key on '{}' points to unknown table '{}'",
1118                        source_table_display, foreign_key.references.table
1119                    ),
1120                ));
1121                continue;
1122            };
1123
1124            for (source_col, target_col) in foreign_key
1125                .columns
1126                .iter()
1127                .zip(foreign_key.references.columns.iter())
1128            {
1129                let source_col_name = lower(source_col);
1130                let target_col_name = lower(target_col);
1131
1132                let Some(source_type) = source_columns.get(&source_col_name).copied() else {
1133                    errors.push(reference_issue(
1134                        strict,
1135                        format!(
1136                            "Table-level foreign key on '{}' references unknown source column '{}'",
1137                            source_table_display, source_col
1138                        ),
1139                    ));
1140                    continue;
1141                };
1142
1143                let Some(target_type) = target_entry.columns.get(&target_col_name).copied() else {
1144                    errors.push(reference_issue(
1145                        strict,
1146                        format!(
1147                            "Table-level foreign key on '{}' references unknown target column '{}.{}'",
1148                            source_table_display, target_key, target_col
1149                        ),
1150                    ));
1151                    continue;
1152                };
1153
1154                if !key_types_compatible(source_type, target_type) {
1155                    errors.push(reference_issue(
1156                        strict,
1157                        format!(
1158                            "Table-level foreign key type mismatch '{}.{}' -> '{}.{}': {} vs {}",
1159                            source_table_display,
1160                            source_col,
1161                            target_key,
1162                            target_col,
1163                            type_family_name(source_type),
1164                            type_family_name(target_type)
1165                        ),
1166                    ));
1167                }
1168
1169                if let Some(target_key_hints) = key_hints_lookup.get(&target_key) {
1170                    if !target_key_hints.contains(&target_col_name) {
1171                        errors.push(ValidationError::warning(
1172                            format!(
1173                                "Referenced column '{}.{}' is not marked as primary/unique key",
1174                                target_key, target_col
1175                            ),
1176                            validation_codes::W_WEAK_REFERENCE_INTEGRITY,
1177                        ));
1178                    }
1179                }
1180            }
1181        }
1182    }
1183
1184    errors
1185}
1186
1187fn resolve_unqualified_column_type(
1188    column_name: &str,
1189    schema_map: &HashMap<String, TableSchemaEntry>,
1190    context: &TypeCheckContext,
1191) -> TypeFamily {
1192    let candidate_tables: Vec<&String> = if !context.referenced_tables.is_empty() {
1193        context.referenced_tables.iter().collect()
1194    } else {
1195        schema_map.keys().collect()
1196    };
1197
1198    let mut families = HashSet::new();
1199    for table_name in candidate_tables {
1200        if let Some(table_schema) = schema_map.get(table_name) {
1201            if let Some(family) = table_schema.columns.get(column_name) {
1202                families.insert(*family);
1203            }
1204        }
1205    }
1206
1207    if families.len() == 1 {
1208        *families.iter().next().unwrap_or(&TypeFamily::Unknown)
1209    } else {
1210        TypeFamily::Unknown
1211    }
1212}
1213
1214fn resolve_column_type(
1215    column: &Column,
1216    schema_map: &HashMap<String, TableSchemaEntry>,
1217    context: &TypeCheckContext,
1218) -> TypeFamily {
1219    let column_name = lower(&column.name.name);
1220    if column_name.is_empty() {
1221        return TypeFamily::Unknown;
1222    }
1223
1224    if let Some(table) = &column.table {
1225        let mut table_key = lower(&table.name);
1226        if let Some(mapped) = context.table_aliases.get(&table_key) {
1227            table_key = mapped.clone();
1228        }
1229
1230        return schema_map
1231            .get(&table_key)
1232            .and_then(|t| t.columns.get(&column_name))
1233            .copied()
1234            .unwrap_or(TypeFamily::Unknown);
1235    }
1236
1237    resolve_unqualified_column_type(&column_name, schema_map, context)
1238}
1239
1240struct TypeInferenceSchema<'a> {
1241    schema_map: &'a HashMap<String, TableSchemaEntry>,
1242    context: &'a TypeCheckContext,
1243}
1244
1245impl TypeInferenceSchema<'_> {
1246    fn resolve_table_key(&self, table: &str) -> Option<String> {
1247        let mut table_key = lower(table);
1248        if let Some(mapped) = self.context.table_aliases.get(&table_key) {
1249            table_key = mapped.clone();
1250        }
1251        if self.schema_map.contains_key(&table_key) {
1252            Some(table_key)
1253        } else {
1254            None
1255        }
1256    }
1257}
1258
1259impl SqlSchema for TypeInferenceSchema<'_> {
1260    fn dialect(&self) -> Option<DialectType> {
1261        None
1262    }
1263
1264    fn add_table(
1265        &mut self,
1266        _table: &str,
1267        _columns: &[(String, DataType)],
1268        _dialect: Option<DialectType>,
1269    ) -> SchemaResult<()> {
1270        Err(SchemaError::InvalidStructure(
1271            "Type inference schema is read-only".to_string(),
1272        ))
1273    }
1274
1275    fn column_names(&self, table: &str) -> SchemaResult<Vec<String>> {
1276        let table_key = self
1277            .resolve_table_key(table)
1278            .ok_or_else(|| SchemaError::TableNotFound(table.to_string()))?;
1279        let entry = self
1280            .schema_map
1281            .get(&table_key)
1282            .ok_or_else(|| SchemaError::TableNotFound(table.to_string()))?;
1283        Ok(entry.column_order.clone())
1284    }
1285
1286    fn get_column_type(&self, table: &str, column: &str) -> SchemaResult<DataType> {
1287        let col_name = lower(column);
1288        if table.is_empty() {
1289            let family = resolve_unqualified_column_type(&col_name, self.schema_map, self.context);
1290            return if family == TypeFamily::Unknown {
1291                Err(SchemaError::ColumnNotFound {
1292                    table: "<unqualified>".to_string(),
1293                    column: column.to_string(),
1294                })
1295            } else {
1296                Ok(type_family_to_data_type(family))
1297            };
1298        }
1299
1300        let table_key = self
1301            .resolve_table_key(table)
1302            .ok_or_else(|| SchemaError::TableNotFound(table.to_string()))?;
1303        let entry = self
1304            .schema_map
1305            .get(&table_key)
1306            .ok_or_else(|| SchemaError::TableNotFound(table.to_string()))?;
1307        let family =
1308            entry
1309                .columns
1310                .get(&col_name)
1311                .copied()
1312                .ok_or_else(|| SchemaError::ColumnNotFound {
1313                    table: table.to_string(),
1314                    column: column.to_string(),
1315                })?;
1316        Ok(type_family_to_data_type(family))
1317    }
1318
1319    fn has_column(&self, table: &str, column: &str) -> bool {
1320        self.get_column_type(table, column).is_ok()
1321    }
1322
1323    fn supported_table_args(&self) -> &[&str] {
1324        TABLE_PARTS
1325    }
1326
1327    fn is_empty(&self) -> bool {
1328        self.schema_map.is_empty()
1329    }
1330
1331    fn depth(&self) -> usize {
1332        1
1333    }
1334}
1335
1336fn infer_expression_type_family(
1337    expr: &Expression,
1338    schema_map: &HashMap<String, TableSchemaEntry>,
1339    context: &TypeCheckContext,
1340) -> TypeFamily {
1341    let inference_schema = TypeInferenceSchema {
1342        schema_map,
1343        context,
1344    };
1345    if let Some(data_type) = annotate_types(expr, Some(&inference_schema), None) {
1346        let family = data_type_family(&data_type);
1347        if family != TypeFamily::Unknown {
1348            return family;
1349        }
1350    }
1351
1352    infer_expression_type_family_fallback(expr, schema_map, context)
1353}
1354
1355fn infer_expression_type_family_fallback(
1356    expr: &Expression,
1357    schema_map: &HashMap<String, TableSchemaEntry>,
1358    context: &TypeCheckContext,
1359) -> TypeFamily {
1360    match expr {
1361        Expression::Literal(literal) => match literal {
1362            crate::expressions::Literal::Number(value) => {
1363                if value.contains('.') || value.contains('e') || value.contains('E') {
1364                    TypeFamily::Numeric
1365                } else {
1366                    TypeFamily::Integer
1367                }
1368            }
1369            crate::expressions::Literal::HexNumber(_) => TypeFamily::Integer,
1370            crate::expressions::Literal::Date(_) => TypeFamily::Date,
1371            crate::expressions::Literal::Time(_) => TypeFamily::Time,
1372            crate::expressions::Literal::Timestamp(_)
1373            | crate::expressions::Literal::Datetime(_) => TypeFamily::Timestamp,
1374            crate::expressions::Literal::HexString(_)
1375            | crate::expressions::Literal::BitString(_)
1376            | crate::expressions::Literal::ByteString(_) => TypeFamily::Binary,
1377            _ => TypeFamily::String,
1378        },
1379        Expression::Boolean(_) => TypeFamily::Boolean,
1380        Expression::Null(_) => TypeFamily::Unknown,
1381        Expression::Column(column) => resolve_column_type(column, schema_map, context),
1382        Expression::Cast(cast) | Expression::TryCast(cast) | Expression::SafeCast(cast) => {
1383            data_type_family(&cast.to)
1384        }
1385        Expression::Alias(alias) => {
1386            infer_expression_type_family_fallback(&alias.this, schema_map, context)
1387        }
1388        Expression::Neg(unary) => {
1389            infer_expression_type_family_fallback(&unary.this, schema_map, context)
1390        }
1391        Expression::Add(op) | Expression::Sub(op) | Expression::Mul(op) => {
1392            let left = infer_expression_type_family_fallback(&op.left, schema_map, context);
1393            let right = infer_expression_type_family_fallback(&op.right, schema_map, context);
1394            if left == TypeFamily::Unknown || right == TypeFamily::Unknown {
1395                TypeFamily::Unknown
1396            } else if left == TypeFamily::Integer && right == TypeFamily::Integer {
1397                TypeFamily::Integer
1398            } else if left.is_numeric() && right.is_numeric() {
1399                TypeFamily::Numeric
1400            } else if left.is_temporal() || right.is_temporal() {
1401                left
1402            } else {
1403                TypeFamily::Unknown
1404            }
1405        }
1406        Expression::Div(_) | Expression::Mod(_) => TypeFamily::Numeric,
1407        Expression::Concat(_) => TypeFamily::String,
1408        Expression::Eq(_)
1409        | Expression::Neq(_)
1410        | Expression::Lt(_)
1411        | Expression::Lte(_)
1412        | Expression::Gt(_)
1413        | Expression::Gte(_)
1414        | Expression::Like(_)
1415        | Expression::ILike(_)
1416        | Expression::And(_)
1417        | Expression::Or(_)
1418        | Expression::Not(_)
1419        | Expression::Between(_)
1420        | Expression::In(_)
1421        | Expression::IsNull(_)
1422        | Expression::IsTrue(_)
1423        | Expression::IsFalse(_)
1424        | Expression::Is(_) => TypeFamily::Boolean,
1425        Expression::Length(_) => TypeFamily::Integer,
1426        Expression::Upper(_)
1427        | Expression::Lower(_)
1428        | Expression::Trim(_)
1429        | Expression::LTrim(_)
1430        | Expression::RTrim(_)
1431        | Expression::Replace(_)
1432        | Expression::Substring(_)
1433        | Expression::Left(_)
1434        | Expression::Right(_)
1435        | Expression::Repeat(_)
1436        | Expression::Lpad(_)
1437        | Expression::Rpad(_)
1438        | Expression::ConcatWs(_) => TypeFamily::String,
1439        Expression::Abs(_)
1440        | Expression::Round(_)
1441        | Expression::Floor(_)
1442        | Expression::Ceil(_)
1443        | Expression::Power(_)
1444        | Expression::Sqrt(_)
1445        | Expression::Cbrt(_)
1446        | Expression::Ln(_)
1447        | Expression::Log(_)
1448        | Expression::Exp(_) => TypeFamily::Numeric,
1449        Expression::DateAdd(_) | Expression::DateSub(_) | Expression::ToDate(_) => TypeFamily::Date,
1450        Expression::ToTimestamp(_) => TypeFamily::Timestamp,
1451        Expression::DateDiff(_) | Expression::Extract(_) => TypeFamily::Integer,
1452        Expression::CurrentDate(_) => TypeFamily::Date,
1453        Expression::CurrentTime(_) => TypeFamily::Time,
1454        Expression::CurrentTimestamp(_) | Expression::CurrentTimestampLTZ(_) => {
1455            TypeFamily::Timestamp
1456        }
1457        Expression::Interval(_) => TypeFamily::Interval,
1458        _ => TypeFamily::Unknown,
1459    }
1460}
1461
1462fn are_comparable(left: TypeFamily, right: TypeFamily) -> bool {
1463    if left == TypeFamily::Unknown || right == TypeFamily::Unknown {
1464        return true;
1465    }
1466    if left == right {
1467        return true;
1468    }
1469    if left.is_numeric() && right.is_numeric() {
1470        return true;
1471    }
1472    if left.is_temporal() && right.is_temporal() {
1473        return true;
1474    }
1475    false
1476}
1477
1478fn check_function_argument(
1479    errors: &mut Vec<ValidationError>,
1480    strict: bool,
1481    function_name: &str,
1482    arg_index: usize,
1483    family: TypeFamily,
1484    expected: &str,
1485    valid: bool,
1486) {
1487    if family == TypeFamily::Unknown || valid {
1488        return;
1489    }
1490
1491    errors.push(type_issue(
1492        strict,
1493        validation_codes::E_INVALID_FUNCTION_ARGUMENT_TYPE,
1494        validation_codes::W_FUNCTION_ARGUMENT_COERCION,
1495        format!(
1496            "Function '{}' argument {} expects {}, found {}",
1497            function_name,
1498            arg_index + 1,
1499            expected,
1500            type_family_name(family)
1501        ),
1502    ));
1503}
1504
1505fn function_dispatch_name(name: &str) -> String {
1506    let upper = name
1507        .rsplit('.')
1508        .next()
1509        .unwrap_or(name)
1510        .trim()
1511        .to_uppercase();
1512    lower(canonical_typed_function_name_upper(&upper))
1513}
1514
/// Returns the unqualified part of a possibly dotted function name: the text
/// after the last `.` (or the whole name when there is no `.`), with
/// surrounding whitespace trimmed.
fn function_base_name(name: &str) -> &str {
    name.rsplit_once('.')
        .map_or(name, |(_, unqualified)| unqualified)
        .trim()
}
1518
1519fn check_generic_function(
1520    function: &Function,
1521    schema_map: &HashMap<String, TableSchemaEntry>,
1522    context: &TypeCheckContext,
1523    strict: bool,
1524    errors: &mut Vec<ValidationError>,
1525) {
1526    let name = function_dispatch_name(&function.name);
1527
1528    let arg_family = |index: usize| -> Option<TypeFamily> {
1529        function
1530            .args
1531            .get(index)
1532            .map(|arg| infer_expression_type_family(arg, schema_map, context))
1533    };
1534
1535    match name.as_str() {
1536        "abs" | "sqrt" | "cbrt" | "ln" | "exp" => {
1537            if let Some(family) = arg_family(0) {
1538                check_function_argument(
1539                    errors,
1540                    strict,
1541                    &name,
1542                    0,
1543                    family,
1544                    "a numeric argument",
1545                    family.is_numeric(),
1546                );
1547            }
1548        }
1549        "round" | "floor" | "ceil" | "ceiling" => {
1550            if let Some(family) = arg_family(0) {
1551                check_function_argument(
1552                    errors,
1553                    strict,
1554                    &name,
1555                    0,
1556                    family,
1557                    "a numeric argument",
1558                    family.is_numeric(),
1559                );
1560            }
1561            if let Some(family) = arg_family(1) {
1562                check_function_argument(
1563                    errors,
1564                    strict,
1565                    &name,
1566                    1,
1567                    family,
1568                    "a numeric argument",
1569                    family.is_numeric(),
1570                );
1571            }
1572        }
1573        "power" | "pow" => {
1574            for i in [0_usize, 1_usize] {
1575                if let Some(family) = arg_family(i) {
1576                    check_function_argument(
1577                        errors,
1578                        strict,
1579                        &name,
1580                        i,
1581                        family,
1582                        "a numeric argument",
1583                        family.is_numeric(),
1584                    );
1585                }
1586            }
1587        }
1588        "length" | "char_length" | "character_length" => {
1589            if let Some(family) = arg_family(0) {
1590                check_function_argument(
1591                    errors,
1592                    strict,
1593                    &name,
1594                    0,
1595                    family,
1596                    "a string or binary argument",
1597                    is_string_or_binary(family),
1598                );
1599            }
1600        }
1601        "upper" | "lower" | "trim" | "ltrim" | "rtrim" | "reverse" => {
1602            if let Some(family) = arg_family(0) {
1603                check_function_argument(
1604                    errors,
1605                    strict,
1606                    &name,
1607                    0,
1608                    family,
1609                    "a string argument",
1610                    is_string_like(family),
1611                );
1612            }
1613        }
1614        "substring" | "substr" => {
1615            if let Some(family) = arg_family(0) {
1616                check_function_argument(
1617                    errors,
1618                    strict,
1619                    &name,
1620                    0,
1621                    family,
1622                    "a string argument",
1623                    is_string_like(family),
1624                );
1625            }
1626            if let Some(family) = arg_family(1) {
1627                check_function_argument(
1628                    errors,
1629                    strict,
1630                    &name,
1631                    1,
1632                    family,
1633                    "a numeric argument",
1634                    family.is_numeric(),
1635                );
1636            }
1637            if let Some(family) = arg_family(2) {
1638                check_function_argument(
1639                    errors,
1640                    strict,
1641                    &name,
1642                    2,
1643                    family,
1644                    "a numeric argument",
1645                    family.is_numeric(),
1646                );
1647            }
1648        }
1649        "replace" => {
1650            for i in [0_usize, 1_usize, 2_usize] {
1651                if let Some(family) = arg_family(i) {
1652                    check_function_argument(
1653                        errors,
1654                        strict,
1655                        &name,
1656                        i,
1657                        family,
1658                        "a string argument",
1659                        is_string_like(family),
1660                    );
1661                }
1662            }
1663        }
1664        "left" | "right" | "repeat" | "lpad" | "rpad" => {
1665            if let Some(family) = arg_family(0) {
1666                check_function_argument(
1667                    errors,
1668                    strict,
1669                    &name,
1670                    0,
1671                    family,
1672                    "a string argument",
1673                    is_string_like(family),
1674                );
1675            }
1676            if let Some(family) = arg_family(1) {
1677                check_function_argument(
1678                    errors,
1679                    strict,
1680                    &name,
1681                    1,
1682                    family,
1683                    "a numeric argument",
1684                    family.is_numeric(),
1685                );
1686            }
1687            if (name == "lpad" || name == "rpad") && function.args.len() > 2 {
1688                if let Some(family) = arg_family(2) {
1689                    check_function_argument(
1690                        errors,
1691                        strict,
1692                        &name,
1693                        2,
1694                        family,
1695                        "a string argument",
1696                        is_string_like(family),
1697                    );
1698                }
1699            }
1700        }
1701        _ => {}
1702    }
1703}
1704
1705fn check_function_catalog(
1706    function: &Function,
1707    dialect: DialectType,
1708    function_catalog: Option<&dyn FunctionCatalog>,
1709    strict: bool,
1710    errors: &mut Vec<ValidationError>,
1711) {
1712    let Some(catalog) = function_catalog else {
1713        return;
1714    };
1715
1716    let raw_name = function_base_name(&function.name);
1717    let normalized_name = function_dispatch_name(&function.name);
1718    let arity = function.args.len();
1719    let Some(signatures) = catalog.lookup(dialect, raw_name, &normalized_name) else {
1720        errors.push(if strict {
1721            ValidationError::error(
1722                format!(
1723                    "Unknown function '{}' for dialect {:?}",
1724                    function.name, dialect
1725                ),
1726                validation_codes::E_UNKNOWN_FUNCTION,
1727            )
1728        } else {
1729            ValidationError::warning(
1730                format!(
1731                    "Unknown function '{}' for dialect {:?}",
1732                    function.name, dialect
1733                ),
1734                validation_codes::E_UNKNOWN_FUNCTION,
1735            )
1736        });
1737        return;
1738    };
1739
1740    if signatures.iter().any(|sig| sig.matches_arity(arity)) {
1741        return;
1742    }
1743
1744    let expected = signatures
1745        .iter()
1746        .map(|sig| sig.describe_arity())
1747        .collect::<Vec<_>>()
1748        .join(", ");
1749    errors.push(if strict {
1750        ValidationError::error(
1751            format!(
1752                "Invalid arity for function '{}': got {}, expected {}",
1753                function.name, arity, expected
1754            ),
1755            validation_codes::E_INVALID_FUNCTION_ARITY,
1756        )
1757    } else {
1758        ValidationError::warning(
1759            format!(
1760                "Invalid arity for function '{}': got {}, expected {}",
1761                function.name, arity, expected
1762            ),
1763            validation_codes::E_INVALID_FUNCTION_ARITY,
1764        )
1765    });
1766}
1767
/// A single column-to-column reference declared in the validation schema.
///
/// Instances are produced by `build_declared_relationships`, which stores
/// resolved table keys in the `*_table` fields and lowercased column names in
/// the `*_column` fields. The `Eq`/`Hash` derives let duplicates collapse when
/// relationships are collected into a `HashSet`.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
struct DeclaredRelationship {
    /// Resolved key of the table that declares the reference.
    source_table: String,
    /// Lowercased name of the referencing column.
    source_column: String,
    /// Resolved key of the table being referenced.
    target_table: String,
    /// Lowercased name of the referenced column.
    target_column: String,
}
1775
1776fn build_declared_relationships(
1777    schema: &ValidationSchema,
1778    schema_map: &HashMap<String, TableSchemaEntry>,
1779) -> Vec<DeclaredRelationship> {
1780    let mut relationships = HashSet::new();
1781
1782    for table in &schema.tables {
1783        let Some(source_key) =
1784            resolve_reference_table_key(&table.name, table.schema.as_deref(), None, schema_map)
1785        else {
1786            continue;
1787        };
1788
1789        for column in &table.columns {
1790            let Some(reference) = &column.references else {
1791                continue;
1792            };
1793            let Some(target_key) = resolve_reference_table_key(
1794                &reference.table,
1795                reference.schema.as_deref(),
1796                table.schema.as_deref(),
1797                schema_map,
1798            ) else {
1799                continue;
1800            };
1801
1802            relationships.insert(DeclaredRelationship {
1803                source_table: source_key.clone(),
1804                source_column: lower(&column.name),
1805                target_table: target_key,
1806                target_column: lower(&reference.column),
1807            });
1808        }
1809
1810        for foreign_key in &table.foreign_keys {
1811            if foreign_key.columns.len() != foreign_key.references.columns.len() {
1812                continue;
1813            }
1814            let Some(target_key) = resolve_reference_table_key(
1815                &foreign_key.references.table,
1816                foreign_key.references.schema.as_deref(),
1817                table.schema.as_deref(),
1818                schema_map,
1819            ) else {
1820                continue;
1821            };
1822
1823            for (source_col, target_col) in foreign_key
1824                .columns
1825                .iter()
1826                .zip(foreign_key.references.columns.iter())
1827            {
1828                relationships.insert(DeclaredRelationship {
1829                    source_table: source_key.clone(),
1830                    source_column: lower(source_col),
1831                    target_table: target_key.clone(),
1832                    target_column: lower(target_col),
1833                });
1834            }
1835        }
1836    }
1837
1838    relationships.into_iter().collect()
1839}
1840
1841fn resolve_column_binding(
1842    column: &Column,
1843    schema_map: &HashMap<String, TableSchemaEntry>,
1844    context: &TypeCheckContext,
1845    resolver: &mut Resolver<'_>,
1846) -> Option<(String, String)> {
1847    let column_name = lower(&column.name.name);
1848    if column_name.is_empty() {
1849        return None;
1850    }
1851
1852    if let Some(table) = &column.table {
1853        let mut table_key = lower(&table.name);
1854        if let Some(mapped) = context.table_aliases.get(&table_key) {
1855            table_key = mapped.clone();
1856        }
1857        if schema_map.contains_key(&table_key) {
1858            return Some((table_key, column_name));
1859        }
1860        return None;
1861    }
1862
1863    if let Some(resolved_source) = resolver.get_table(&column_name) {
1864        let mut table_key = lower(&resolved_source);
1865        if let Some(mapped) = context.table_aliases.get(&table_key) {
1866            table_key = mapped.clone();
1867        }
1868        if schema_map.contains_key(&table_key) {
1869            return Some((table_key, column_name));
1870        }
1871    }
1872
1873    let candidates: Vec<String> = context
1874        .referenced_tables
1875        .iter()
1876        .filter_map(|table_name| {
1877            schema_map
1878                .get(table_name)
1879                .filter(|entry| entry.columns.contains_key(&column_name))
1880                .map(|_| table_name.clone())
1881        })
1882        .collect();
1883    if candidates.len() == 1 {
1884        return Some((candidates[0].clone(), column_name));
1885    }
1886    None
1887}
1888
1889fn extract_join_equality_pairs(
1890    expr: &Expression,
1891    schema_map: &HashMap<String, TableSchemaEntry>,
1892    context: &TypeCheckContext,
1893    resolver: &mut Resolver<'_>,
1894    pairs: &mut Vec<((String, String), (String, String))>,
1895) {
1896    match expr {
1897        Expression::And(op) => {
1898            extract_join_equality_pairs(&op.left, schema_map, context, resolver, pairs);
1899            extract_join_equality_pairs(&op.right, schema_map, context, resolver, pairs);
1900        }
1901        Expression::Paren(paren) => {
1902            extract_join_equality_pairs(&paren.this, schema_map, context, resolver, pairs);
1903        }
1904        Expression::Eq(op) => {
1905            let (Expression::Column(left_col), Expression::Column(right_col)) =
1906                (&op.left, &op.right)
1907            else {
1908                return;
1909            };
1910            let Some(left) = resolve_column_binding(left_col, schema_map, context, resolver) else {
1911                return;
1912            };
1913            let Some(right) = resolve_column_binding(right_col, schema_map, context, resolver)
1914            else {
1915                return;
1916            };
1917            pairs.push((left, right));
1918        }
1919        _ => {}
1920    }
1921}
1922
1923fn relationship_matches_pair(
1924    relationship: &DeclaredRelationship,
1925    left_table: &str,
1926    left_column: &str,
1927    right_table: &str,
1928    right_column: &str,
1929) -> bool {
1930    (relationship.source_table == left_table
1931        && relationship.source_column == left_column
1932        && relationship.target_table == right_table
1933        && relationship.target_column == right_column)
1934        || (relationship.source_table == right_table
1935            && relationship.source_column == right_column
1936            && relationship.target_table == left_table
1937            && relationship.target_column == left_column)
1938}
1939
1940fn resolved_table_key_from_expr(
1941    expr: &Expression,
1942    schema_map: &HashMap<String, TableSchemaEntry>,
1943) -> Option<String> {
1944    match expr {
1945        Expression::Table(table) => resolve_table_schema_entry(table, schema_map).map(|(k, _)| k),
1946        Expression::Alias(alias) => resolved_table_key_from_expr(&alias.this, schema_map),
1947        _ => None,
1948    }
1949}
1950
1951fn select_from_table_keys(
1952    select: &crate::expressions::Select,
1953    schema_map: &HashMap<String, TableSchemaEntry>,
1954) -> HashSet<String> {
1955    let mut keys = HashSet::new();
1956    if let Some(from_clause) = &select.from {
1957        for expr in &from_clause.expressions {
1958            if let Some(key) = resolved_table_key_from_expr(expr, schema_map) {
1959                keys.insert(key);
1960            }
1961        }
1962    }
1963    keys
1964}
1965
1966fn is_natural_or_implied_join(kind: JoinKind) -> bool {
1967    matches!(
1968        kind,
1969        JoinKind::Natural
1970            | JoinKind::NaturalLeft
1971            | JoinKind::NaturalRight
1972            | JoinKind::NaturalFull
1973            | JoinKind::CrossApply
1974            | JoinKind::OuterApply
1975            | JoinKind::AsOf
1976            | JoinKind::AsOfLeft
1977            | JoinKind::AsOfRight
1978            | JoinKind::Lateral
1979            | JoinKind::LeftLateral
1980    )
1981}
1982
1983fn check_query_reference_quality(
1984    stmt: &Expression,
1985    schema_map: &HashMap<String, TableSchemaEntry>,
1986    resolver_schema: &MappingSchema,
1987    strict: bool,
1988    relationships: &[DeclaredRelationship],
1989) -> Vec<ValidationError> {
1990    let mut errors = Vec::new();
1991
1992    for node in stmt.dfs() {
1993        let Expression::Select(select) = node else {
1994            continue;
1995        };
1996
1997        let select_expr = Expression::Select(select.clone());
1998        let context = collect_type_check_context(&select_expr, schema_map);
1999        let scope = build_scope(&select_expr);
2000        let mut resolver = Resolver::new(&scope, resolver_schema, true);
2001
2002        if context.referenced_tables.len() > 1 {
2003            let using_columns: HashSet<String> = select
2004                .joins
2005                .iter()
2006                .flat_map(|join| join.using.iter().map(|id| lower(&id.name)))
2007                .collect();
2008
2009            let mut seen = HashSet::new();
2010            for column_expr in select_expr
2011                .find_all(|e| matches!(e, Expression::Column(Column { table: None, .. })))
2012            {
2013                let Expression::Column(column) = column_expr else {
2014                    continue;
2015                };
2016
2017                let col_name = lower(&column.name.name);
2018                if col_name.is_empty()
2019                    || using_columns.contains(&col_name)
2020                    || !seen.insert(col_name.clone())
2021                {
2022                    continue;
2023                }
2024
2025                if resolver.is_ambiguous(&col_name) {
2026                    let source_count = resolver.sources_for_column(&col_name).len();
2027                    errors.push(if strict {
2028                        ValidationError::error(
2029                            format!(
2030                                "Ambiguous unqualified column '{}' found in {} referenced tables",
2031                                col_name, source_count
2032                            ),
2033                            validation_codes::E_AMBIGUOUS_COLUMN_REFERENCE,
2034                        )
2035                    } else {
2036                        ValidationError::warning(
2037                            format!(
2038                                "Ambiguous unqualified column '{}' found in {} referenced tables",
2039                                col_name, source_count
2040                            ),
2041                            validation_codes::W_WEAK_REFERENCE_INTEGRITY,
2042                        )
2043                    });
2044                }
2045            }
2046        }
2047
2048        let mut cumulative_left_tables = select_from_table_keys(select, schema_map);
2049
2050        for join in &select.joins {
2051            let right_table_key = resolved_table_key_from_expr(&join.this, schema_map);
2052            let has_explicit_condition = join.on.is_some() || !join.using.is_empty();
2053            let cartesian_like_kind = matches!(
2054                join.kind,
2055                JoinKind::Cross
2056                    | JoinKind::Implicit
2057                    | JoinKind::Array
2058                    | JoinKind::LeftArray
2059                    | JoinKind::Paste
2060            );
2061
2062            if right_table_key.is_some()
2063                && (cartesian_like_kind
2064                    || (!has_explicit_condition && !is_natural_or_implied_join(join.kind)))
2065            {
2066                errors.push(ValidationError::warning(
2067                    "Potential cartesian join: JOIN without ON/USING condition",
2068                    validation_codes::W_CARTESIAN_JOIN,
2069                ));
2070            }
2071
2072            if let (Some(on_expr), Some(right_key)) = (&join.on, right_table_key.clone()) {
2073                if join.using.is_empty() {
2074                    let mut eq_pairs = Vec::new();
2075                    extract_join_equality_pairs(
2076                        on_expr,
2077                        schema_map,
2078                        &context,
2079                        &mut resolver,
2080                        &mut eq_pairs,
2081                    );
2082
2083                    let relevant_relationships: Vec<&DeclaredRelationship> = relationships
2084                        .iter()
2085                        .filter(|rel| {
2086                            cumulative_left_tables.contains(&rel.source_table)
2087                                && rel.target_table == right_key
2088                                || (cumulative_left_tables.contains(&rel.target_table)
2089                                    && rel.source_table == right_key)
2090                        })
2091                        .collect();
2092
2093                    if !relevant_relationships.is_empty() {
2094                        let uses_declared_fk = eq_pairs.iter().any(|((lt, lc), (rt, rc))| {
2095                            relevant_relationships
2096                                .iter()
2097                                .any(|rel| relationship_matches_pair(rel, lt, lc, rt, rc))
2098                        });
2099                        if !uses_declared_fk {
2100                            errors.push(ValidationError::warning(
2101                                "JOIN predicate does not use declared foreign-key relationship columns",
2102                                validation_codes::W_JOIN_NOT_USING_DECLARED_REFERENCE,
2103                            ));
2104                        }
2105                    }
2106                }
2107            }
2108
2109            if let Some(right_key) = right_table_key {
2110                cumulative_left_tables.insert(right_key);
2111            }
2112        }
2113    }
2114
2115    errors
2116}
2117
2118fn are_setop_compatible(left: TypeFamily, right: TypeFamily) -> bool {
2119    if left == TypeFamily::Unknown || right == TypeFamily::Unknown {
2120        return true;
2121    }
2122    if left == right {
2123        return true;
2124    }
2125    if left.is_numeric() && right.is_numeric() {
2126        return true;
2127    }
2128    if left.is_temporal() && right.is_temporal() {
2129        return true;
2130    }
2131    false
2132}
2133
2134fn merged_setop_family(left: TypeFamily, right: TypeFamily) -> TypeFamily {
2135    if left == TypeFamily::Unknown {
2136        return right;
2137    }
2138    if right == TypeFamily::Unknown {
2139        return left;
2140    }
2141    if left == right {
2142        return left;
2143    }
2144    if left.is_numeric() && right.is_numeric() {
2145        if left == TypeFamily::Numeric || right == TypeFamily::Numeric {
2146            return TypeFamily::Numeric;
2147        }
2148        return TypeFamily::Integer;
2149    }
2150    if left.is_temporal() && right.is_temporal() {
2151        if left == TypeFamily::Timestamp || right == TypeFamily::Timestamp {
2152            return TypeFamily::Timestamp;
2153        }
2154        if left == TypeFamily::Date || right == TypeFamily::Date {
2155            return TypeFamily::Date;
2156        }
2157        return TypeFamily::Time;
2158    }
2159    TypeFamily::Unknown
2160}
2161
2162fn are_assignment_compatible(target: TypeFamily, source: TypeFamily) -> bool {
2163    if target == TypeFamily::Unknown || source == TypeFamily::Unknown {
2164        return true;
2165    }
2166    if target == source {
2167        return true;
2168    }
2169
2170    match target {
2171        TypeFamily::Boolean => source == TypeFamily::Boolean,
2172        TypeFamily::Integer | TypeFamily::Numeric => source.is_numeric(),
2173        TypeFamily::Date | TypeFamily::Time | TypeFamily::Timestamp | TypeFamily::Interval => {
2174            source.is_temporal()
2175        }
2176        TypeFamily::String => true,
2177        TypeFamily::Binary => matches!(source, TypeFamily::Binary | TypeFamily::String),
2178        TypeFamily::Json => matches!(source, TypeFamily::Json | TypeFamily::String),
2179        TypeFamily::Uuid => matches!(source, TypeFamily::Uuid | TypeFamily::String),
2180        TypeFamily::Array => source == TypeFamily::Array,
2181        TypeFamily::Map => source == TypeFamily::Map,
2182        TypeFamily::Struct => source == TypeFamily::Struct,
2183        TypeFamily::Unknown => true,
2184    }
2185}
2186
2187fn projection_families(
2188    query_expr: &Expression,
2189    schema_map: &HashMap<String, TableSchemaEntry>,
2190) -> Option<Vec<TypeFamily>> {
2191    match query_expr {
2192        Expression::Select(select) => {
2193            if select
2194                .expressions
2195                .iter()
2196                .any(|e| matches!(e, Expression::Star(_) | Expression::BracedWildcard(_)))
2197            {
2198                return None;
2199            }
2200            let select_expr = Expression::Select(select.clone());
2201            let context = collect_type_check_context(&select_expr, schema_map);
2202            Some(
2203                select
2204                    .expressions
2205                    .iter()
2206                    .map(|e| infer_expression_type_family(e, schema_map, &context))
2207                    .collect(),
2208            )
2209        }
2210        Expression::Subquery(subquery) => projection_families(&subquery.this, schema_map),
2211        Expression::Union(union) => {
2212            let left = projection_families(&union.left, schema_map)?;
2213            let right = projection_families(&union.right, schema_map)?;
2214            if left.len() != right.len() {
2215                return None;
2216            }
2217            Some(
2218                left.into_iter()
2219                    .zip(right)
2220                    .map(|(l, r)| merged_setop_family(l, r))
2221                    .collect(),
2222            )
2223        }
2224        Expression::Intersect(intersect) => {
2225            let left = projection_families(&intersect.left, schema_map)?;
2226            let right = projection_families(&intersect.right, schema_map)?;
2227            if left.len() != right.len() {
2228                return None;
2229            }
2230            Some(
2231                left.into_iter()
2232                    .zip(right)
2233                    .map(|(l, r)| merged_setop_family(l, r))
2234                    .collect(),
2235            )
2236        }
2237        Expression::Except(except) => {
2238            let left = projection_families(&except.left, schema_map)?;
2239            let right = projection_families(&except.right, schema_map)?;
2240            if left.len() != right.len() {
2241                return None;
2242            }
2243            Some(
2244                left.into_iter()
2245                    .zip(right)
2246                    .map(|(l, r)| merged_setop_family(l, r))
2247                    .collect(),
2248            )
2249        }
2250        Expression::Values(values) => {
2251            let first_row = values.expressions.first()?;
2252            let context = TypeCheckContext::default();
2253            Some(
2254                first_row
2255                    .expressions
2256                    .iter()
2257                    .map(|e| infer_expression_type_family(e, schema_map, &context))
2258                    .collect(),
2259            )
2260        }
2261        _ => None,
2262    }
2263}
2264
2265fn check_set_operation_compatibility(
2266    op_name: &str,
2267    left_expr: &Expression,
2268    right_expr: &Expression,
2269    schema_map: &HashMap<String, TableSchemaEntry>,
2270    strict: bool,
2271    errors: &mut Vec<ValidationError>,
2272) {
2273    let Some(left_projection) = projection_families(left_expr, schema_map) else {
2274        return;
2275    };
2276    let Some(right_projection) = projection_families(right_expr, schema_map) else {
2277        return;
2278    };
2279
2280    if left_projection.len() != right_projection.len() {
2281        errors.push(type_issue(
2282            strict,
2283            validation_codes::E_SETOP_ARITY_MISMATCH,
2284            validation_codes::W_SETOP_IMPLICIT_COERCION,
2285            format!(
2286                "{} operands return different column counts: left {}, right {}",
2287                op_name,
2288                left_projection.len(),
2289                right_projection.len()
2290            ),
2291        ));
2292        return;
2293    }
2294
2295    for (idx, (left, right)) in left_projection
2296        .into_iter()
2297        .zip(right_projection)
2298        .enumerate()
2299    {
2300        if !are_setop_compatible(left, right) {
2301            errors.push(type_issue(
2302                strict,
2303                validation_codes::E_SETOP_TYPE_MISMATCH,
2304                validation_codes::W_SETOP_IMPLICIT_COERCION,
2305                format!(
2306                    "{} column {} has incompatible types: {} vs {}",
2307                    op_name,
2308                    idx + 1,
2309                    type_family_name(left),
2310                    type_family_name(right)
2311                ),
2312            ));
2313        }
2314    }
2315}
2316
2317fn check_insert_assignments(
2318    stmt: &Expression,
2319    insert: &Insert,
2320    schema_map: &HashMap<String, TableSchemaEntry>,
2321    strict: bool,
2322    errors: &mut Vec<ValidationError>,
2323) {
2324    let Some((target_table_key, table_schema)) =
2325        resolve_table_schema_entry(&insert.table, schema_map)
2326    else {
2327        return;
2328    };
2329
2330    let mut target_columns = Vec::new();
2331    if insert.columns.is_empty() {
2332        target_columns.extend(table_schema.column_order.iter().cloned());
2333    } else {
2334        for column in &insert.columns {
2335            let col_name = lower(&column.name);
2336            if table_schema.columns.contains_key(&col_name) {
2337                target_columns.push(col_name);
2338            } else {
2339                errors.push(if strict {
2340                    ValidationError::error(
2341                        format!(
2342                            "Unknown column '{}' in table '{}'",
2343                            column.name, target_table_key
2344                        ),
2345                        validation_codes::E_UNKNOWN_COLUMN,
2346                    )
2347                } else {
2348                    ValidationError::warning(
2349                        format!(
2350                            "Unknown column '{}' in table '{}'",
2351                            column.name, target_table_key
2352                        ),
2353                        validation_codes::E_UNKNOWN_COLUMN,
2354                    )
2355                });
2356            }
2357        }
2358    }
2359
2360    if target_columns.is_empty() {
2361        return;
2362    }
2363
2364    let context = collect_type_check_context(stmt, schema_map);
2365
2366    if !insert.default_values {
2367        for (row_idx, row) in insert.values.iter().enumerate() {
2368            if row.len() != target_columns.len() {
2369                errors.push(type_issue(
2370                    strict,
2371                    validation_codes::E_INVALID_ASSIGNMENT_TYPE,
2372                    validation_codes::W_IMPLICIT_CAST_ASSIGNMENT,
2373                    format!(
2374                        "INSERT row {} has {} values but target has {} columns",
2375                        row_idx + 1,
2376                        row.len(),
2377                        target_columns.len()
2378                    ),
2379                ));
2380                continue;
2381            }
2382
2383            for (value, target_column) in row.iter().zip(target_columns.iter()) {
2384                let Some(target_family) = table_schema.columns.get(target_column).copied() else {
2385                    continue;
2386                };
2387                let source_family = infer_expression_type_family(value, schema_map, &context);
2388                if !are_assignment_compatible(target_family, source_family) {
2389                    errors.push(type_issue(
2390                        strict,
2391                        validation_codes::E_INVALID_ASSIGNMENT_TYPE,
2392                        validation_codes::W_IMPLICIT_CAST_ASSIGNMENT,
2393                        format!(
2394                            "INSERT assignment type mismatch for '{}.{}': expected {}, found {}",
2395                            target_table_key,
2396                            target_column,
2397                            type_family_name(target_family),
2398                            type_family_name(source_family)
2399                        ),
2400                    ));
2401                }
2402            }
2403        }
2404    }
2405
2406    if let Some(query) = &insert.query {
2407        // DuckDB BY NAME maps source columns by name, not position.
2408        if insert.by_name {
2409            return;
2410        }
2411
2412        let Some(source_projection) = projection_families(query, schema_map) else {
2413            return;
2414        };
2415
2416        if source_projection.len() != target_columns.len() {
2417            errors.push(type_issue(
2418                strict,
2419                validation_codes::E_INVALID_ASSIGNMENT_TYPE,
2420                validation_codes::W_IMPLICIT_CAST_ASSIGNMENT,
2421                format!(
2422                    "INSERT source query has {} columns but target has {} columns",
2423                    source_projection.len(),
2424                    target_columns.len()
2425                ),
2426            ));
2427            return;
2428        }
2429
2430        for (source_family, target_column) in
2431            source_projection.into_iter().zip(target_columns.iter())
2432        {
2433            let Some(target_family) = table_schema.columns.get(target_column).copied() else {
2434                continue;
2435            };
2436            if !are_assignment_compatible(target_family, source_family) {
2437                errors.push(type_issue(
2438                    strict,
2439                    validation_codes::E_INVALID_ASSIGNMENT_TYPE,
2440                    validation_codes::W_IMPLICIT_CAST_ASSIGNMENT,
2441                    format!(
2442                        "INSERT assignment type mismatch for '{}.{}': expected {}, found {}",
2443                        target_table_key,
2444                        target_column,
2445                        type_family_name(target_family),
2446                        type_family_name(source_family)
2447                    ),
2448                ));
2449            }
2450        }
2451    }
2452}
2453
2454fn check_update_assignments(
2455    stmt: &Expression,
2456    update: &Update,
2457    schema_map: &HashMap<String, TableSchemaEntry>,
2458    strict: bool,
2459    errors: &mut Vec<ValidationError>,
2460) {
2461    let Some((target_table_key, table_schema)) =
2462        resolve_table_schema_entry(&update.table, schema_map)
2463    else {
2464        return;
2465    };
2466
2467    let context = collect_type_check_context(stmt, schema_map);
2468
2469    for (column, value) in &update.set {
2470        let col_name = lower(&column.name);
2471        let Some(target_family) = table_schema.columns.get(&col_name).copied() else {
2472            errors.push(if strict {
2473                ValidationError::error(
2474                    format!(
2475                        "Unknown column '{}' in table '{}'",
2476                        column.name, target_table_key
2477                    ),
2478                    validation_codes::E_UNKNOWN_COLUMN,
2479                )
2480            } else {
2481                ValidationError::warning(
2482                    format!(
2483                        "Unknown column '{}' in table '{}'",
2484                        column.name, target_table_key
2485                    ),
2486                    validation_codes::E_UNKNOWN_COLUMN,
2487                )
2488            });
2489            continue;
2490        };
2491
2492        let source_family = infer_expression_type_family(value, schema_map, &context);
2493        if !are_assignment_compatible(target_family, source_family) {
2494            errors.push(type_issue(
2495                strict,
2496                validation_codes::E_INVALID_ASSIGNMENT_TYPE,
2497                validation_codes::W_IMPLICIT_CAST_ASSIGNMENT,
2498                format!(
2499                    "UPDATE assignment type mismatch for '{}.{}': expected {}, found {}",
2500                    target_table_key,
2501                    col_name,
2502                    type_family_name(target_family),
2503                    type_family_name(source_family)
2504                ),
2505            ));
2506        }
2507    }
2508}
2509
2510fn check_types(
2511    stmt: &Expression,
2512    dialect: DialectType,
2513    schema_map: &HashMap<String, TableSchemaEntry>,
2514    function_catalog: Option<&dyn FunctionCatalog>,
2515    strict: bool,
2516) -> Vec<ValidationError> {
2517    let mut errors = Vec::new();
2518    let context = collect_type_check_context(stmt, schema_map);
2519
2520    for node in stmt.dfs() {
2521        match node {
2522            Expression::Insert(insert) => {
2523                check_insert_assignments(stmt, insert, schema_map, strict, &mut errors);
2524            }
2525            Expression::Update(update) => {
2526                check_update_assignments(stmt, update, schema_map, strict, &mut errors);
2527            }
2528            Expression::Union(union) => {
2529                check_set_operation_compatibility(
2530                    "UNION",
2531                    &union.left,
2532                    &union.right,
2533                    schema_map,
2534                    strict,
2535                    &mut errors,
2536                );
2537            }
2538            Expression::Intersect(intersect) => {
2539                check_set_operation_compatibility(
2540                    "INTERSECT",
2541                    &intersect.left,
2542                    &intersect.right,
2543                    schema_map,
2544                    strict,
2545                    &mut errors,
2546                );
2547            }
2548            Expression::Except(except) => {
2549                check_set_operation_compatibility(
2550                    "EXCEPT",
2551                    &except.left,
2552                    &except.right,
2553                    schema_map,
2554                    strict,
2555                    &mut errors,
2556                );
2557            }
2558            Expression::Select(select) => {
2559                if let Some(prewhere) = &select.prewhere {
2560                    let family = infer_expression_type_family(prewhere, schema_map, &context);
2561                    if family != TypeFamily::Unknown && family != TypeFamily::Boolean {
2562                        errors.push(type_issue(
2563                            strict,
2564                            validation_codes::E_INVALID_PREDICATE_TYPE,
2565                            validation_codes::W_PREDICATE_NULLABILITY,
2566                            format!(
2567                                "PREWHERE clause expects a boolean predicate, found {}",
2568                                type_family_name(family)
2569                            ),
2570                        ));
2571                    }
2572                }
2573
2574                if let Some(where_clause) = &select.where_clause {
2575                    let family =
2576                        infer_expression_type_family(&where_clause.this, schema_map, &context);
2577                    if family != TypeFamily::Unknown && family != TypeFamily::Boolean {
2578                        errors.push(type_issue(
2579                            strict,
2580                            validation_codes::E_INVALID_PREDICATE_TYPE,
2581                            validation_codes::W_PREDICATE_NULLABILITY,
2582                            format!(
2583                                "WHERE clause expects a boolean predicate, found {}",
2584                                type_family_name(family)
2585                            ),
2586                        ));
2587                    }
2588                }
2589
2590                if let Some(having_clause) = &select.having {
2591                    let family =
2592                        infer_expression_type_family(&having_clause.this, schema_map, &context);
2593                    if family != TypeFamily::Unknown && family != TypeFamily::Boolean {
2594                        errors.push(type_issue(
2595                            strict,
2596                            validation_codes::E_INVALID_PREDICATE_TYPE,
2597                            validation_codes::W_PREDICATE_NULLABILITY,
2598                            format!(
2599                                "HAVING clause expects a boolean predicate, found {}",
2600                                type_family_name(family)
2601                            ),
2602                        ));
2603                    }
2604                }
2605
2606                for join in &select.joins {
2607                    if let Some(on) = &join.on {
2608                        let family = infer_expression_type_family(on, schema_map, &context);
2609                        if family != TypeFamily::Unknown && family != TypeFamily::Boolean {
2610                            errors.push(type_issue(
2611                                strict,
2612                                validation_codes::E_INVALID_PREDICATE_TYPE,
2613                                validation_codes::W_PREDICATE_NULLABILITY,
2614                                format!(
2615                                    "JOIN ON expects a boolean predicate, found {}",
2616                                    type_family_name(family)
2617                                ),
2618                            ));
2619                        }
2620                    }
2621                    if let Some(match_condition) = &join.match_condition {
2622                        let family =
2623                            infer_expression_type_family(match_condition, schema_map, &context);
2624                        if family != TypeFamily::Unknown && family != TypeFamily::Boolean {
2625                            errors.push(type_issue(
2626                                strict,
2627                                validation_codes::E_INVALID_PREDICATE_TYPE,
2628                                validation_codes::W_PREDICATE_NULLABILITY,
2629                                format!(
2630                                    "JOIN MATCH_CONDITION expects a boolean predicate, found {}",
2631                                    type_family_name(family)
2632                                ),
2633                            ));
2634                        }
2635                    }
2636                }
2637            }
2638            Expression::Where(where_clause) => {
2639                let family = infer_expression_type_family(&where_clause.this, schema_map, &context);
2640                if family != TypeFamily::Unknown && family != TypeFamily::Boolean {
2641                    errors.push(type_issue(
2642                        strict,
2643                        validation_codes::E_INVALID_PREDICATE_TYPE,
2644                        validation_codes::W_PREDICATE_NULLABILITY,
2645                        format!(
2646                            "WHERE clause expects a boolean predicate, found {}",
2647                            type_family_name(family)
2648                        ),
2649                    ));
2650                }
2651            }
2652            Expression::Having(having_clause) => {
2653                let family =
2654                    infer_expression_type_family(&having_clause.this, schema_map, &context);
2655                if family != TypeFamily::Unknown && family != TypeFamily::Boolean {
2656                    errors.push(type_issue(
2657                        strict,
2658                        validation_codes::E_INVALID_PREDICATE_TYPE,
2659                        validation_codes::W_PREDICATE_NULLABILITY,
2660                        format!(
2661                            "HAVING clause expects a boolean predicate, found {}",
2662                            type_family_name(family)
2663                        ),
2664                    ));
2665                }
2666            }
2667            Expression::And(op) | Expression::Or(op) => {
2668                for (side, expr) in [("left", &op.left), ("right", &op.right)] {
2669                    let family = infer_expression_type_family(expr, schema_map, &context);
2670                    if family != TypeFamily::Unknown && family != TypeFamily::Boolean {
2671                        errors.push(type_issue(
2672                            strict,
2673                            validation_codes::E_INVALID_PREDICATE_TYPE,
2674                            validation_codes::W_PREDICATE_NULLABILITY,
2675                            format!(
2676                                "Logical {} operand expects boolean, found {}",
2677                                side,
2678                                type_family_name(family)
2679                            ),
2680                        ));
2681                    }
2682                }
2683            }
2684            Expression::Not(unary) => {
2685                let family = infer_expression_type_family(&unary.this, schema_map, &context);
2686                if family != TypeFamily::Unknown && family != TypeFamily::Boolean {
2687                    errors.push(type_issue(
2688                        strict,
2689                        validation_codes::E_INVALID_PREDICATE_TYPE,
2690                        validation_codes::W_PREDICATE_NULLABILITY,
2691                        format!("NOT expects boolean, found {}", type_family_name(family)),
2692                    ));
2693                }
2694            }
2695            Expression::Eq(op)
2696            | Expression::Neq(op)
2697            | Expression::Lt(op)
2698            | Expression::Lte(op)
2699            | Expression::Gt(op)
2700            | Expression::Gte(op) => {
2701                let left = infer_expression_type_family(&op.left, schema_map, &context);
2702                let right = infer_expression_type_family(&op.right, schema_map, &context);
2703                if !are_comparable(left, right) {
2704                    errors.push(type_issue(
2705                        strict,
2706                        validation_codes::E_INCOMPATIBLE_COMPARISON_TYPES,
2707                        validation_codes::W_IMPLICIT_CAST_COMPARISON,
2708                        format!(
2709                            "Incompatible comparison between {} and {}",
2710                            type_family_name(left),
2711                            type_family_name(right)
2712                        ),
2713                    ));
2714                }
2715            }
2716            Expression::Like(op) | Expression::ILike(op) => {
2717                let left = infer_expression_type_family(&op.left, schema_map, &context);
2718                let right = infer_expression_type_family(&op.right, schema_map, &context);
2719                if left != TypeFamily::Unknown
2720                    && right != TypeFamily::Unknown
2721                    && (!is_string_like(left) || !is_string_like(right))
2722                {
2723                    errors.push(type_issue(
2724                        strict,
2725                        validation_codes::E_INCOMPATIBLE_COMPARISON_TYPES,
2726                        validation_codes::W_IMPLICIT_CAST_COMPARISON,
2727                        format!(
2728                            "LIKE/ILIKE expects string operands, found {} and {}",
2729                            type_family_name(left),
2730                            type_family_name(right)
2731                        ),
2732                    ));
2733                }
2734            }
2735            Expression::Between(between) => {
2736                let this_family = infer_expression_type_family(&between.this, schema_map, &context);
2737                let low_family = infer_expression_type_family(&between.low, schema_map, &context);
2738                let high_family = infer_expression_type_family(&between.high, schema_map, &context);
2739
2740                if !are_comparable(this_family, low_family)
2741                    || !are_comparable(this_family, high_family)
2742                {
2743                    errors.push(type_issue(
2744                        strict,
2745                        validation_codes::E_INCOMPATIBLE_COMPARISON_TYPES,
2746                        validation_codes::W_IMPLICIT_CAST_COMPARISON,
2747                        format!(
2748                            "BETWEEN bounds are incompatible with {} (found {} and {})",
2749                            type_family_name(this_family),
2750                            type_family_name(low_family),
2751                            type_family_name(high_family)
2752                        ),
2753                    ));
2754                }
2755            }
2756            Expression::In(in_expr) => {
2757                let this_family = infer_expression_type_family(&in_expr.this, schema_map, &context);
2758                for value in &in_expr.expressions {
2759                    let item_family = infer_expression_type_family(value, schema_map, &context);
2760                    if !are_comparable(this_family, item_family) {
2761                        errors.push(type_issue(
2762                            strict,
2763                            validation_codes::E_INCOMPATIBLE_COMPARISON_TYPES,
2764                            validation_codes::W_IMPLICIT_CAST_COMPARISON,
2765                            format!(
2766                                "IN item type {} is incompatible with {}",
2767                                type_family_name(item_family),
2768                                type_family_name(this_family)
2769                            ),
2770                        ));
2771                        break;
2772                    }
2773                }
2774            }
2775            Expression::Add(op)
2776            | Expression::Sub(op)
2777            | Expression::Mul(op)
2778            | Expression::Div(op)
2779            | Expression::Mod(op) => {
2780                let left = infer_expression_type_family(&op.left, schema_map, &context);
2781                let right = infer_expression_type_family(&op.right, schema_map, &context);
2782
2783                if left == TypeFamily::Unknown || right == TypeFamily::Unknown {
2784                    continue;
2785                }
2786
2787                let temporal_ok = matches!(node, Expression::Add(_) | Expression::Sub(_))
2788                    && ((left.is_temporal() && right.is_numeric())
2789                        || (right.is_temporal() && left.is_numeric())
2790                        || (matches!(node, Expression::Sub(_))
2791                            && left.is_temporal()
2792                            && right.is_temporal()));
2793
2794                if !(left.is_numeric() && right.is_numeric()) && !temporal_ok {
2795                    errors.push(type_issue(
2796                        strict,
2797                        validation_codes::E_INVALID_ARITHMETIC_TYPE,
2798                        validation_codes::W_IMPLICIT_CAST_ARITHMETIC,
2799                        format!(
2800                            "Arithmetic operation expects numeric-compatible operands, found {} and {}",
2801                            type_family_name(left),
2802                            type_family_name(right)
2803                        ),
2804                    ));
2805                }
2806            }
2807            Expression::Function(function) => {
2808                check_function_catalog(function, dialect, function_catalog, strict, &mut errors);
2809                check_generic_function(function, schema_map, &context, strict, &mut errors);
2810            }
2811            Expression::Upper(func)
2812            | Expression::Lower(func)
2813            | Expression::LTrim(func)
2814            | Expression::RTrim(func)
2815            | Expression::Reverse(func) => {
2816                let family = infer_expression_type_family(&func.this, schema_map, &context);
2817                check_function_argument(
2818                    &mut errors,
2819                    strict,
2820                    "string_function",
2821                    0,
2822                    family,
2823                    "a string argument",
2824                    is_string_like(family),
2825                );
2826            }
2827            Expression::Length(func) => {
2828                let family = infer_expression_type_family(&func.this, schema_map, &context);
2829                check_function_argument(
2830                    &mut errors,
2831                    strict,
2832                    "length",
2833                    0,
2834                    family,
2835                    "a string or binary argument",
2836                    is_string_or_binary(family),
2837                );
2838            }
2839            Expression::Trim(func) => {
2840                let this_family = infer_expression_type_family(&func.this, schema_map, &context);
2841                check_function_argument(
2842                    &mut errors,
2843                    strict,
2844                    "trim",
2845                    0,
2846                    this_family,
2847                    "a string argument",
2848                    is_string_like(this_family),
2849                );
2850                if let Some(chars) = &func.characters {
2851                    let chars_family = infer_expression_type_family(chars, schema_map, &context);
2852                    check_function_argument(
2853                        &mut errors,
2854                        strict,
2855                        "trim",
2856                        1,
2857                        chars_family,
2858                        "a string argument",
2859                        is_string_like(chars_family),
2860                    );
2861                }
2862            }
2863            Expression::Substring(func) => {
2864                let this_family = infer_expression_type_family(&func.this, schema_map, &context);
2865                check_function_argument(
2866                    &mut errors,
2867                    strict,
2868                    "substring",
2869                    0,
2870                    this_family,
2871                    "a string argument",
2872                    is_string_like(this_family),
2873                );
2874
2875                let start_family = infer_expression_type_family(&func.start, schema_map, &context);
2876                check_function_argument(
2877                    &mut errors,
2878                    strict,
2879                    "substring",
2880                    1,
2881                    start_family,
2882                    "a numeric argument",
2883                    start_family.is_numeric(),
2884                );
2885                if let Some(length) = &func.length {
2886                    let length_family = infer_expression_type_family(length, schema_map, &context);
2887                    check_function_argument(
2888                        &mut errors,
2889                        strict,
2890                        "substring",
2891                        2,
2892                        length_family,
2893                        "a numeric argument",
2894                        length_family.is_numeric(),
2895                    );
2896                }
2897            }
2898            Expression::Replace(func) => {
2899                for (arg, idx) in [
2900                    (&func.this, 0_usize),
2901                    (&func.old, 1_usize),
2902                    (&func.new, 2_usize),
2903                ] {
2904                    let family = infer_expression_type_family(arg, schema_map, &context);
2905                    check_function_argument(
2906                        &mut errors,
2907                        strict,
2908                        "replace",
2909                        idx,
2910                        family,
2911                        "a string argument",
2912                        is_string_like(family),
2913                    );
2914                }
2915            }
2916            Expression::Left(func) | Expression::Right(func) => {
2917                let this_family = infer_expression_type_family(&func.this, schema_map, &context);
2918                check_function_argument(
2919                    &mut errors,
2920                    strict,
2921                    "left_right",
2922                    0,
2923                    this_family,
2924                    "a string argument",
2925                    is_string_like(this_family),
2926                );
2927                let length_family =
2928                    infer_expression_type_family(&func.length, schema_map, &context);
2929                check_function_argument(
2930                    &mut errors,
2931                    strict,
2932                    "left_right",
2933                    1,
2934                    length_family,
2935                    "a numeric argument",
2936                    length_family.is_numeric(),
2937                );
2938            }
2939            Expression::Repeat(func) => {
2940                let this_family = infer_expression_type_family(&func.this, schema_map, &context);
2941                check_function_argument(
2942                    &mut errors,
2943                    strict,
2944                    "repeat",
2945                    0,
2946                    this_family,
2947                    "a string argument",
2948                    is_string_like(this_family),
2949                );
2950                let times_family = infer_expression_type_family(&func.times, schema_map, &context);
2951                check_function_argument(
2952                    &mut errors,
2953                    strict,
2954                    "repeat",
2955                    1,
2956                    times_family,
2957                    "a numeric argument",
2958                    times_family.is_numeric(),
2959                );
2960            }
2961            Expression::Lpad(func) | Expression::Rpad(func) => {
2962                let this_family = infer_expression_type_family(&func.this, schema_map, &context);
2963                check_function_argument(
2964                    &mut errors,
2965                    strict,
2966                    "pad",
2967                    0,
2968                    this_family,
2969                    "a string argument",
2970                    is_string_like(this_family),
2971                );
2972                let length_family =
2973                    infer_expression_type_family(&func.length, schema_map, &context);
2974                check_function_argument(
2975                    &mut errors,
2976                    strict,
2977                    "pad",
2978                    1,
2979                    length_family,
2980                    "a numeric argument",
2981                    length_family.is_numeric(),
2982                );
2983                if let Some(fill) = &func.fill {
2984                    let fill_family = infer_expression_type_family(fill, schema_map, &context);
2985                    check_function_argument(
2986                        &mut errors,
2987                        strict,
2988                        "pad",
2989                        2,
2990                        fill_family,
2991                        "a string argument",
2992                        is_string_like(fill_family),
2993                    );
2994                }
2995            }
2996            Expression::Abs(func)
2997            | Expression::Sqrt(func)
2998            | Expression::Cbrt(func)
2999            | Expression::Ln(func)
3000            | Expression::Exp(func) => {
3001                let family = infer_expression_type_family(&func.this, schema_map, &context);
3002                check_function_argument(
3003                    &mut errors,
3004                    strict,
3005                    "numeric_function",
3006                    0,
3007                    family,
3008                    "a numeric argument",
3009                    family.is_numeric(),
3010                );
3011            }
3012            Expression::Round(func) => {
3013                let this_family = infer_expression_type_family(&func.this, schema_map, &context);
3014                check_function_argument(
3015                    &mut errors,
3016                    strict,
3017                    "round",
3018                    0,
3019                    this_family,
3020                    "a numeric argument",
3021                    this_family.is_numeric(),
3022                );
3023                if let Some(decimals) = &func.decimals {
3024                    let decimals_family =
3025                        infer_expression_type_family(decimals, schema_map, &context);
3026                    check_function_argument(
3027                        &mut errors,
3028                        strict,
3029                        "round",
3030                        1,
3031                        decimals_family,
3032                        "a numeric argument",
3033                        decimals_family.is_numeric(),
3034                    );
3035                }
3036            }
3037            Expression::Floor(func) => {
3038                let this_family = infer_expression_type_family(&func.this, schema_map, &context);
3039                check_function_argument(
3040                    &mut errors,
3041                    strict,
3042                    "floor",
3043                    0,
3044                    this_family,
3045                    "a numeric argument",
3046                    this_family.is_numeric(),
3047                );
3048                if let Some(scale) = &func.scale {
3049                    let scale_family = infer_expression_type_family(scale, schema_map, &context);
3050                    check_function_argument(
3051                        &mut errors,
3052                        strict,
3053                        "floor",
3054                        1,
3055                        scale_family,
3056                        "a numeric argument",
3057                        scale_family.is_numeric(),
3058                    );
3059                }
3060            }
3061            Expression::Ceil(func) => {
3062                let this_family = infer_expression_type_family(&func.this, schema_map, &context);
3063                check_function_argument(
3064                    &mut errors,
3065                    strict,
3066                    "ceil",
3067                    0,
3068                    this_family,
3069                    "a numeric argument",
3070                    this_family.is_numeric(),
3071                );
3072                if let Some(decimals) = &func.decimals {
3073                    let decimals_family =
3074                        infer_expression_type_family(decimals, schema_map, &context);
3075                    check_function_argument(
3076                        &mut errors,
3077                        strict,
3078                        "ceil",
3079                        1,
3080                        decimals_family,
3081                        "a numeric argument",
3082                        decimals_family.is_numeric(),
3083                    );
3084                }
3085            }
3086            Expression::Power(func) => {
3087                let left_family = infer_expression_type_family(&func.this, schema_map, &context);
3088                check_function_argument(
3089                    &mut errors,
3090                    strict,
3091                    "power",
3092                    0,
3093                    left_family,
3094                    "a numeric argument",
3095                    left_family.is_numeric(),
3096                );
3097                let right_family =
3098                    infer_expression_type_family(&func.expression, schema_map, &context);
3099                check_function_argument(
3100                    &mut errors,
3101                    strict,
3102                    "power",
3103                    1,
3104                    right_family,
3105                    "a numeric argument",
3106                    right_family.is_numeric(),
3107                );
3108            }
3109            Expression::Log(func) => {
3110                let this_family = infer_expression_type_family(&func.this, schema_map, &context);
3111                check_function_argument(
3112                    &mut errors,
3113                    strict,
3114                    "log",
3115                    0,
3116                    this_family,
3117                    "a numeric argument",
3118                    this_family.is_numeric(),
3119                );
3120                if let Some(base) = &func.base {
3121                    let base_family = infer_expression_type_family(base, schema_map, &context);
3122                    check_function_argument(
3123                        &mut errors,
3124                        strict,
3125                        "log",
3126                        1,
3127                        base_family,
3128                        "a numeric argument",
3129                        base_family.is_numeric(),
3130                    );
3131                }
3132            }
3133            _ => {}
3134        }
3135    }
3136
3137    errors
3138}
3139
3140fn check_semantics(stmt: &Expression) -> Vec<ValidationError> {
3141    let mut errors = Vec::new();
3142
3143    let Expression::Select(select) = stmt else {
3144        return errors;
3145    };
3146    let select_expr = Expression::Select(select.clone());
3147
3148    // W001: SELECT * is discouraged
3149    if !select_expr
3150        .find_all(|e| matches!(e, Expression::Star(_)))
3151        .is_empty()
3152    {
3153        errors.push(ValidationError::warning(
3154            "SELECT * is discouraged; specify columns explicitly for better performance and maintainability",
3155            validation_codes::W_SELECT_STAR,
3156        ));
3157    }
3158
3159    // W002: aggregate + non-aggregate columns without GROUP BY
3160    let aggregate_count = get_aggregate_functions(&select_expr).len();
3161    if aggregate_count > 0 && select.group_by.is_none() {
3162        let has_non_aggregate_column = select.expressions.iter().any(|expr| {
3163            matches!(expr, Expression::Column(_) | Expression::Identifier(_))
3164                && get_aggregate_functions(expr).is_empty()
3165        });
3166
3167        if has_non_aggregate_column {
3168            errors.push(ValidationError::warning(
3169                "Mixing aggregate functions with non-aggregated columns without GROUP BY may cause errors in strict SQL mode",
3170                validation_codes::W_AGGREGATE_WITHOUT_GROUP_BY,
3171            ));
3172        }
3173    }
3174
3175    // W003: DISTINCT with ORDER BY
3176    if select.distinct && select.order_by.is_some() {
3177        errors.push(ValidationError::warning(
3178            "DISTINCT with ORDER BY: ensure ORDER BY columns are in SELECT list",
3179            validation_codes::W_DISTINCT_ORDER_BY,
3180        ));
3181    }
3182
3183    // W004: LIMIT without ORDER BY
3184    if select.limit.is_some() && select.order_by.is_none() {
3185        errors.push(ValidationError::warning(
3186            "LIMIT without ORDER BY produces non-deterministic results",
3187            validation_codes::W_LIMIT_WITHOUT_ORDER_BY,
3188        ));
3189    }
3190
3191    errors
3192}
3193
3194fn resolve_scope_source_name(scope: &crate::scope::Scope, name: &str) -> Option<String> {
3195    scope
3196        .sources
3197        .get_key_value(name)
3198        .map(|(k, _)| k.clone())
3199        .or_else(|| {
3200            scope
3201                .sources
3202                .keys()
3203                .find(|source| source.eq_ignore_ascii_case(name))
3204                .cloned()
3205        })
3206}
3207
/// Reports whether `columns` exposes `column_name`.
///
/// A literal `"*"` entry matches every column name; all other entries are compared
/// with ASCII case-insensitive equality. An empty slice never matches.
fn source_has_column(columns: &[String], column_name: &str) -> bool {
    for candidate in columns {
        if candidate == "*" || candidate.eq_ignore_ascii_case(column_name) {
            return true;
        }
    }
    false
}
3213
3214fn source_display_name(scope: &crate::scope::Scope, source_name: &str) -> String {
3215    scope
3216        .sources
3217        .get(source_name)
3218        .map(|source| match &source.expression {
3219            Expression::Table(table) => lower(&table_ref_display_name(table)),
3220            _ => lower(source_name),
3221        })
3222        .unwrap_or_else(|| lower(source_name))
3223}
3224
3225fn validate_select_columns_with_schema(
3226    select: &crate::expressions::Select,
3227    schema_map: &HashMap<String, TableSchemaEntry>,
3228    resolver_schema: &MappingSchema,
3229    strict: bool,
3230) -> Vec<ValidationError> {
3231    let mut errors = Vec::new();
3232    let select_expr = Expression::Select(Box::new(select.clone()));
3233    let scope = build_scope(&select_expr);
3234    let mut resolver = Resolver::new(&scope, resolver_schema, true);
3235    let source_names: Vec<String> = scope.sources.keys().cloned().collect();
3236
3237    for node in walk_in_scope(&select_expr, false) {
3238        let Expression::Column(column) = node else {
3239            continue;
3240        };
3241
3242        let col_name = lower(&column.name.name);
3243        if col_name.is_empty() {
3244            continue;
3245        }
3246
3247        if let Some(table) = &column.table {
3248            let Some(source_name) = resolve_scope_source_name(&scope, &table.name) else {
3249                // The table qualifier is not a declared alias or source in this scope
3250                errors.push(if strict {
3251                    ValidationError::error(
3252                        format!(
3253                            "Unknown table or alias '{}' referenced by column '{}'",
3254                            table.name, col_name
3255                        ),
3256                        validation_codes::E_UNRESOLVED_REFERENCE,
3257                    )
3258                } else {
3259                    ValidationError::warning(
3260                        format!(
3261                            "Unknown table or alias '{}' referenced by column '{}'",
3262                            table.name, col_name
3263                        ),
3264                        validation_codes::E_UNRESOLVED_REFERENCE,
3265                    )
3266                });
3267                continue;
3268            };
3269
3270            if let Ok(columns) = resolver.get_source_columns(&source_name) {
3271                if !columns.is_empty() && !source_has_column(&columns, &col_name) {
3272                    let table_name = source_display_name(&scope, &source_name);
3273                    errors.push(if strict {
3274                        ValidationError::error(
3275                            format!("Unknown column '{}' in table '{}'", col_name, table_name),
3276                            validation_codes::E_UNKNOWN_COLUMN,
3277                        )
3278                    } else {
3279                        ValidationError::warning(
3280                            format!("Unknown column '{}' in table '{}'", col_name, table_name),
3281                            validation_codes::E_UNKNOWN_COLUMN,
3282                        )
3283                    });
3284                }
3285            }
3286            continue;
3287        }
3288
3289        let matching_sources: Vec<String> = source_names
3290            .iter()
3291            .filter_map(|source_name| {
3292                resolver
3293                    .get_source_columns(source_name)
3294                    .ok()
3295                    .filter(|columns| !columns.is_empty() && source_has_column(columns, &col_name))
3296                    .map(|_| source_name.clone())
3297            })
3298            .collect();
3299
3300        if !matching_sources.is_empty() {
3301            continue;
3302        }
3303
3304        let known_sources: Vec<String> = source_names
3305            .iter()
3306            .filter_map(|source_name| {
3307                resolver
3308                    .get_source_columns(source_name)
3309                    .ok()
3310                    .filter(|columns| !columns.is_empty() && !columns.iter().any(|c| c == "*"))
3311                    .map(|_| source_name.clone())
3312            })
3313            .collect();
3314
3315        if known_sources.len() == 1 {
3316            let table_name = source_display_name(&scope, &known_sources[0]);
3317            errors.push(if strict {
3318                ValidationError::error(
3319                    format!("Unknown column '{}' in table '{}'", col_name, table_name),
3320                    validation_codes::E_UNKNOWN_COLUMN,
3321                )
3322            } else {
3323                ValidationError::warning(
3324                    format!("Unknown column '{}' in table '{}'", col_name, table_name),
3325                    validation_codes::E_UNKNOWN_COLUMN,
3326                )
3327            });
3328        } else if known_sources.len() > 1 {
3329            errors.push(if strict {
3330                ValidationError::error(
3331                    format!(
3332                        "Unknown column '{}' (not found in any referenced table)",
3333                        col_name
3334                    ),
3335                    validation_codes::E_UNKNOWN_COLUMN,
3336                )
3337            } else {
3338                ValidationError::warning(
3339                    format!(
3340                        "Unknown column '{}' (not found in any referenced table)",
3341                        col_name
3342                    ),
3343                    validation_codes::E_UNKNOWN_COLUMN,
3344                )
3345            });
3346        } else if !schema_map.is_empty() {
3347            let found = schema_map
3348                .values()
3349                .any(|table_schema| table_schema.columns.contains_key(&col_name));
3350            if !found {
3351                errors.push(if strict {
3352                    ValidationError::error(
3353                        format!("Unknown column '{}'", col_name),
3354                        validation_codes::E_UNKNOWN_COLUMN,
3355                    )
3356                } else {
3357                    ValidationError::warning(
3358                        format!("Unknown column '{}'", col_name),
3359                        validation_codes::E_UNKNOWN_COLUMN,
3360                    )
3361                });
3362            }
3363        }
3364    }
3365
3366    errors
3367}
3368
3369fn validate_statement_with_schema(
3370    stmt: &Expression,
3371    schema_map: &HashMap<String, TableSchemaEntry>,
3372    resolver_schema: &MappingSchema,
3373    strict: bool,
3374) -> Vec<ValidationError> {
3375    let mut errors = Vec::new();
3376    let cte_aliases = collect_cte_aliases(stmt);
3377    let mut seen_tables: HashSet<String> = HashSet::new();
3378
3379    // Table validation (E200)
3380    for node in stmt.find_all(|e| matches!(e, Expression::Table(_))) {
3381        let Expression::Table(table) = node else {
3382            continue;
3383        };
3384
3385        if cte_aliases.contains(&lower(&table.name.name)) {
3386            continue;
3387        }
3388
3389        let resolved_key = table_ref_candidates(table)
3390            .into_iter()
3391            .find(|k| schema_map.contains_key(k));
3392        let table_key = resolved_key
3393            .clone()
3394            .unwrap_or_else(|| lower(&table_ref_display_name(table)));
3395
3396        if !seen_tables.insert(table_key) {
3397            continue;
3398        }
3399
3400        if resolved_key.is_none() {
3401            errors.push(if strict {
3402                ValidationError::error(
3403                    format!("Unknown table '{}'", table_ref_display_name(table)),
3404                    validation_codes::E_UNKNOWN_TABLE,
3405                )
3406            } else {
3407                ValidationError::warning(
3408                    format!("Unknown table '{}'", table_ref_display_name(table)),
3409                    validation_codes::E_UNKNOWN_TABLE,
3410                )
3411            });
3412        }
3413    }
3414
3415    for node in stmt.dfs() {
3416        let Expression::Select(select) = node else {
3417            continue;
3418        };
3419        errors.extend(validate_select_columns_with_schema(
3420            select,
3421            schema_map,
3422            resolver_schema,
3423            strict,
3424        ));
3425    }
3426
3427    errors
3428}
3429
3430/// Validate SQL using syntax + schema-aware checks, with optional semantic warnings.
3431pub fn validate_with_schema(
3432    sql: &str,
3433    dialect: DialectType,
3434    schema: &ValidationSchema,
3435    options: &SchemaValidationOptions,
3436) -> ValidationResult {
3437    let strict = options.strict.unwrap_or(schema.strict.unwrap_or(true));
3438
3439    // Syntax validation first.
3440    let syntax_result = crate::validate_with_options(
3441        sql,
3442        dialect,
3443        &crate::ValidationOptions {
3444            strict_syntax: options.strict_syntax,
3445        },
3446    );
3447    if !syntax_result.valid {
3448        return syntax_result;
3449    }
3450
3451    let d = Dialect::get(dialect);
3452    let statements = match d.parse(sql) {
3453        Ok(exprs) => exprs,
3454        Err(e) => {
3455            return ValidationResult::with_errors(vec![ValidationError::error(
3456                e.to_string(),
3457                validation_codes::E_PARSE_OR_OPTIONS,
3458            )]);
3459        }
3460    };
3461
3462    let schema_map = build_schema_map(schema);
3463    let resolver_schema = build_resolver_schema(schema);
3464    let mut all_errors = syntax_result.errors;
3465    let embedded_function_catalog = if options.check_types && options.function_catalog.is_none() {
3466        default_embedded_function_catalog()
3467    } else {
3468        None
3469    };
3470    let effective_function_catalog = options
3471        .function_catalog
3472        .as_deref()
3473        .or_else(|| embedded_function_catalog.as_deref());
3474    let declared_relationships = if options.check_references {
3475        build_declared_relationships(schema, &schema_map)
3476    } else {
3477        Vec::new()
3478    };
3479
3480    if options.check_references {
3481        all_errors.extend(check_reference_integrity(schema, &schema_map, strict));
3482    }
3483
3484    for stmt in &statements {
3485        if options.semantic {
3486            all_errors.extend(check_semantics(stmt));
3487        }
3488        all_errors.extend(validate_statement_with_schema(
3489            stmt,
3490            &schema_map,
3491            &resolver_schema,
3492            strict,
3493        ));
3494        if options.check_types {
3495            all_errors.extend(check_types(
3496                stmt,
3497                dialect,
3498                &schema_map,
3499                effective_function_catalog,
3500                strict,
3501            ));
3502        }
3503        if options.check_references {
3504            all_errors.extend(check_query_reference_quality(
3505                stmt,
3506                &schema_map,
3507                &resolver_schema,
3508                strict,
3509                &declared_relationships,
3510            ));
3511        }
3512    }
3513
3514    ValidationResult::with_errors(all_errors)
3515}
3516
3517#[cfg(test)]
3518mod tests;