Skip to main content

polyglot_sql/
validation.rs

1//! Schema-aware and semantic SQL validation.
2//!
3//! This module extends syntax validation with:
4//! - schema checks (unknown tables/columns)
5//! - optional semantic warnings (SELECT *, LIMIT without ORDER BY, etc.)
6
7use crate::ast_transforms::get_aggregate_functions;
8use crate::dialects::{Dialect, DialectType};
9use crate::error::{ValidationError, ValidationResult};
10use crate::expressions::{
11    Column, DataType, Expression, Function, Insert, JoinKind, TableRef, Update,
12};
13use crate::function_catalog::FunctionCatalog;
14#[cfg(any(
15    feature = "function-catalog-clickhouse",
16    feature = "function-catalog-duckdb",
17    feature = "function-catalog-all-dialects"
18))]
19use crate::function_catalog::{
20    FunctionNameCase as CoreFunctionNameCase, FunctionSignature as CoreFunctionSignature,
21    HashMapFunctionCatalog,
22};
23use crate::function_registry::canonical_typed_function_name_upper;
24use crate::optimizer::annotate_types::annotate_types;
25use crate::resolver::Resolver;
26use crate::schema::{MappingSchema, Schema as SqlSchema, SchemaError, SchemaResult, TABLE_PARTS};
27use crate::scope::{build_scope, walk_in_scope};
28use crate::traversal::ExpressionWalk;
29use serde::{Deserialize, Serialize};
30use std::collections::{HashMap, HashSet};
31use std::sync::Arc;
32
33#[cfg(any(
34    feature = "function-catalog-clickhouse",
35    feature = "function-catalog-duckdb",
36    feature = "function-catalog-all-dialects"
37))]
38use std::sync::LazyLock;
39
/// Column definition used for schema-aware validation.
///
/// Only `name` is required when deserializing; every other field carries a
/// serde default and may be omitted from the payload.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SchemaColumn {
    /// Column name.
    pub name: String,
    /// Column data type as written in the schema (serialized as `type`).
    ///
    /// Defaults to an empty string when omitted. The string is canonicalized
    /// with `canonical_type_family`, which maps empty or unrecognized values
    /// to `TypeFamily::Unknown`.
    #[serde(default, rename = "type")]
    pub data_type: String,
    /// Whether the column allows NULL values (`None` when not specified).
    #[serde(default)]
    pub nullable: Option<bool>,
    /// Whether this column is part of a primary key (serialized as `primaryKey`).
    #[serde(default, rename = "primaryKey")]
    pub primary_key: bool,
    /// Whether this column has a uniqueness constraint.
    #[serde(default)]
    pub unique: bool,
    /// Optional column-level foreign key reference.
    #[serde(default)]
    pub references: Option<SchemaColumnReference>,
}
61
/// Column-level foreign key reference metadata.
///
/// Identifies a single target column; `table` and `column` are required,
/// `schema` is optional.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SchemaColumnReference {
    /// Referenced table name.
    pub table: String,
    /// Referenced column name.
    pub column: String,
    /// Optional schema/namespace of referenced table (`None` when omitted).
    #[serde(default)]
    pub schema: Option<String>,
}
73
/// Table-level foreign key reference metadata.
///
/// Unlike a column-level reference, this form carries lists of source and
/// target columns, so it can describe multi-column keys.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SchemaForeignKey {
    /// Optional FK name (`None` when omitted).
    #[serde(default)]
    pub name: Option<String>,
    /// Source columns in the current table.
    pub columns: Vec<String>,
    /// Referenced target table + columns.
    pub references: SchemaTableReference,
}
85
/// Target of a table-level foreign key.
///
/// Used as the `references` field of `SchemaForeignKey`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SchemaTableReference {
    /// Referenced table name.
    pub table: String,
    /// Referenced target columns.
    pub columns: Vec<String>,
    /// Optional schema/namespace of referenced table (`None` when omitted).
    #[serde(default)]
    pub schema: Option<String>,
}
97
/// Table definition used for schema-aware validation.
///
/// `name` and `columns` are required; all other fields default to
/// `None`/empty when omitted from the payload.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SchemaTable {
    /// Table name.
    pub name: String,
    /// Optional schema/namespace name. When present the table is also
    /// registered under the qualified key `schema.name`.
    #[serde(default)]
    pub schema: Option<String>,
    /// Column definitions, in declaration order.
    pub columns: Vec<SchemaColumn>,
    /// Optional aliases that should resolve to this table.
    #[serde(default)]
    pub aliases: Vec<String>,
    /// Optional primary key column list (serialized as `primaryKey`).
    #[serde(default, rename = "primaryKey")]
    pub primary_key: Vec<String>,
    /// Optional unique key groups; each inner list is one key
    /// (serialized as `uniqueKeys`).
    #[serde(default, rename = "uniqueKeys")]
    pub unique_keys: Vec<Vec<String>>,
    /// Optional table-level foreign keys (serialized as `foreignKeys`).
    #[serde(default, rename = "foreignKeys")]
    pub foreign_keys: Vec<SchemaForeignKey>,
}
121
/// Schema payload used for schema-aware validation.
///
/// This is the top-level object handed to the validator; it can also be
/// converted to a resolver schema with `mapping_schema_from_validation_schema`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ValidationSchema {
    /// Known tables.
    pub tables: Vec<SchemaTable>,
    /// Default strict mode for unknown identifiers (`None` when omitted).
    /// `SchemaValidationOptions::strict` is documented as overriding it.
    #[serde(default)]
    pub strict: Option<bool>,
}
131
/// Options for schema-aware validation.
///
/// All flags default to `false`/`None`, so `SchemaValidationOptions::default()`
/// enables none of the optional checks.
#[derive(Clone, Serialize, Deserialize, Default)]
pub struct SchemaValidationOptions {
    /// Enables type compatibility checks for expressions, DML assignments, and set operations.
    #[serde(default)]
    pub check_types: bool,
    /// Enables FK/reference integrity checks and query-level reference quality checks.
    #[serde(default)]
    pub check_references: bool,
    /// If true/false, overrides schema.strict.
    #[serde(default)]
    pub strict: Option<bool>,
    /// Enables semantic warnings (W001..W004).
    #[serde(default)]
    pub semantic: bool,
    /// Enables strict syntax checks (e.g. rejects trailing commas before clause boundaries).
    #[serde(default)]
    pub strict_syntax: bool,
    /// Optional external function catalog plugin for dialect-specific function validation.
    ///
    /// Marked `serde(skip)`: it is never serialized and is always `None`
    /// after deserialization, so it must be set programmatically.
    #[serde(skip, default)]
    pub function_catalog: Option<Arc<dyn FunctionCatalog>>,
}
154
/// Translates the external catalog crate's name-case setting into the
/// equivalent core `FunctionNameCase` variant.
#[cfg(any(
    feature = "function-catalog-clickhouse",
    feature = "function-catalog-duckdb",
    feature = "function-catalog-all-dialects"
))]
fn to_core_name_case(
    case: polyglot_sql_function_catalogs::FunctionNameCase,
) -> CoreFunctionNameCase {
    use polyglot_sql_function_catalogs::FunctionNameCase as ExternalCase;

    match case {
        ExternalCase::Sensitive => CoreFunctionNameCase::Sensitive,
        ExternalCase::Insensitive => CoreFunctionNameCase::Insensitive,
    }
}
172
/// Converts external catalog signatures into core signatures, copying the
/// minimum and maximum arity of each entry and preserving their order.
#[cfg(any(
    feature = "function-catalog-clickhouse",
    feature = "function-catalog-duckdb",
    feature = "function-catalog-all-dialects"
))]
fn to_core_signatures(
    signatures: Vec<polyglot_sql_function_catalogs::FunctionSignature>,
) -> Vec<CoreFunctionSignature> {
    let mut converted = Vec::with_capacity(signatures.len());
    for external in signatures {
        converted.push(CoreFunctionSignature {
            min_arity: external.min_arity,
            max_arity: external.max_arity,
        });
    }
    converted
}
189
/// Adapter that receives registrations from the external
/// `polyglot_sql_function_catalogs` crate and forwards them into a core
/// `HashMapFunctionCatalog`.
#[cfg(any(
    feature = "function-catalog-clickhouse",
    feature = "function-catalog-duckdb",
    feature = "function-catalog-all-dialects"
))]
struct EmbeddedCatalogSink<'a> {
    /// Catalog being populated.
    catalog: &'a mut HashMapFunctionCatalog,
    /// Memoized result of parsing each dialect name into a `DialectType`;
    /// `None` records a name that failed to parse.
    dialect_cache: HashMap<&'static str, Option<DialectType>>,
}
199
#[cfg(any(
    feature = "function-catalog-clickhouse",
    feature = "function-catalog-duckdb",
    feature = "function-catalog-all-dialects"
))]
impl<'a> EmbeddedCatalogSink<'a> {
    /// Resolves a dialect name to a `DialectType`, memoizing the outcome.
    ///
    /// Returns `None` when the name does not parse as a `DialectType`; that
    /// negative result is cached as well, so each distinct name is parsed at
    /// most once per sink.
    fn resolve_dialect(&mut self, dialect: &'static str) -> Option<DialectType> {
        // The entry API performs a single hash lookup for both the hit and
        // the miss path instead of a `get` followed by an `insert`.
        *self
            .dialect_cache
            .entry(dialect)
            .or_insert_with(|| dialect.parse::<DialectType>().ok())
    }
}
215
/// Forwards every catalog callback to the wrapped core catalog. Callbacks for
/// dialect names that do not resolve to a `DialectType` are ignored.
#[cfg(any(
    feature = "function-catalog-clickhouse",
    feature = "function-catalog-duckdb",
    feature = "function-catalog-all-dialects"
))]
impl polyglot_sql_function_catalogs::CatalogSink for EmbeddedCatalogSink<'_> {
    /// Sets the default function-name case handling for a dialect.
    fn set_dialect_name_case(
        &mut self,
        dialect: &'static str,
        name_case: polyglot_sql_function_catalogs::FunctionNameCase,
    ) {
        let Some(core_dialect) = self.resolve_dialect(dialect) else {
            return;
        };
        let core_case = to_core_name_case(name_case);
        self.catalog.set_dialect_name_case(core_dialect, core_case);
    }

    /// Sets the name case handling for one function of a dialect.
    fn set_function_name_case(
        &mut self,
        dialect: &'static str,
        function_name: &str,
        name_case: polyglot_sql_function_catalogs::FunctionNameCase,
    ) {
        let Some(core_dialect) = self.resolve_dialect(dialect) else {
            return;
        };
        let core_case = to_core_name_case(name_case);
        self.catalog
            .set_function_name_case(core_dialect, function_name, core_case);
    }

    /// Registers the signatures of a function for a dialect.
    fn register(
        &mut self,
        dialect: &'static str,
        function_name: &str,
        signatures: Vec<polyglot_sql_function_catalogs::FunctionSignature>,
    ) {
        let Some(core_dialect) = self.resolve_dialect(dialect) else {
            return;
        };
        let core_signatures = to_core_signatures(signatures);
        self.catalog
            .register(core_dialect, function_name, core_signatures);
    }
}
260
/// Returns a shared handle to the embedded function catalog.
///
/// The catalog is built on first use by letting the external catalog crate
/// register its enabled catalogs through an `EmbeddedCatalogSink`; later calls
/// clone the same `Arc`.
#[cfg(any(
    feature = "function-catalog-clickhouse",
    feature = "function-catalog-duckdb",
    feature = "function-catalog-all-dialects"
))]
fn embedded_function_catalog_arc() -> Arc<dyn FunctionCatalog> {
    /// Populates a fresh catalog from the enabled external catalogs.
    fn build_catalog() -> Arc<HashMapFunctionCatalog> {
        let mut catalog = HashMapFunctionCatalog::default();
        {
            // Scope the sink so its mutable borrow ends before the catalog moves.
            let mut sink = EmbeddedCatalogSink {
                catalog: &mut catalog,
                dialect_cache: HashMap::new(),
            };
            polyglot_sql_function_catalogs::register_enabled_catalogs(&mut sink);
        }
        Arc::new(catalog)
    }

    static EMBEDDED: LazyLock<Arc<HashMapFunctionCatalog>> = LazyLock::new(build_catalog);

    // Clone with the concrete type first; the unsized coercion to the trait
    // object happens on return.
    let shared: Arc<HashMapFunctionCatalog> = Arc::clone(&EMBEDDED);
    shared
}
279
/// Default function catalog when at least one `function-catalog-*` feature is
/// enabled: always `Some`, wrapping the shared embedded catalog.
#[cfg(any(
    feature = "function-catalog-clickhouse",
    feature = "function-catalog-duckdb",
    feature = "function-catalog-all-dialects"
))]
fn default_embedded_function_catalog() -> Option<Arc<dyn FunctionCatalog>> {
    Some(embedded_function_catalog_arc())
}
288
/// Default function catalog when no `function-catalog-*` feature is enabled:
/// there is no embedded catalog, so this always returns `None`.
#[cfg(not(any(
    feature = "function-catalog-clickhouse",
    feature = "function-catalog-duckdb",
    feature = "function-catalog-all-dialects"
)))]
fn default_embedded_function_catalog() -> Option<Arc<dyn FunctionCatalog>> {
    None
}
297
/// Validation error/warning codes used by schema-aware validation.
///
/// Codes prefixed `E` name errors and codes prefixed `W` name warnings. The
/// numeric ranges are grouped by check family, as noted on each section below.
pub mod validation_codes {
    // Existing schema and semantic checks.
    // E000: parse/options failures; E200-E203: unknown identifiers and
    // function arity.
    pub const E_PARSE_OR_OPTIONS: &str = "E000";
    pub const E_UNKNOWN_TABLE: &str = "E200";
    pub const E_UNKNOWN_COLUMN: &str = "E201";
    pub const E_UNKNOWN_FUNCTION: &str = "E202";
    pub const E_INVALID_FUNCTION_ARITY: &str = "E203";

    // Semantic warnings (W001-W004).
    pub const W_SELECT_STAR: &str = "W001";
    pub const W_AGGREGATE_WITHOUT_GROUP_BY: &str = "W002";
    pub const W_DISTINCT_ORDER_BY: &str = "W003";
    pub const W_LIMIT_WITHOUT_ORDER_BY: &str = "W004";

    // Phase 2 (type checks): E210-E219, W210-W219.
    pub const E_TYPE_MISMATCH: &str = "E210";
    pub const E_INVALID_PREDICATE_TYPE: &str = "E211";
    pub const E_INVALID_ARITHMETIC_TYPE: &str = "E212";
    pub const E_INVALID_FUNCTION_ARGUMENT_TYPE: &str = "E213";
    pub const E_INVALID_ASSIGNMENT_TYPE: &str = "E214";
    pub const E_SETOP_TYPE_MISMATCH: &str = "E215";
    pub const E_SETOP_ARITY_MISMATCH: &str = "E216";
    pub const E_INCOMPATIBLE_COMPARISON_TYPES: &str = "E217";
    pub const E_INVALID_CAST: &str = "E218";
    pub const E_UNKNOWN_INFERRED_TYPE: &str = "E219";

    pub const W_IMPLICIT_CAST_COMPARISON: &str = "W210";
    pub const W_IMPLICIT_CAST_ARITHMETIC: &str = "W211";
    pub const W_IMPLICIT_CAST_ASSIGNMENT: &str = "W212";
    pub const W_LOSSY_CAST: &str = "W213";
    pub const W_SETOP_IMPLICIT_COERCION: &str = "W214";
    pub const W_PREDICATE_NULLABILITY: &str = "W215";
    pub const W_FUNCTION_ARGUMENT_COERCION: &str = "W216";
    pub const W_AGGREGATE_TYPE_COERCION: &str = "W217";
    pub const W_POSSIBLE_OVERFLOW: &str = "W218";
    pub const W_POSSIBLE_TRUNCATION: &str = "W219";

    // Phase 2 (reference checks): E220-E229, W220-W229.
    pub const E_INVALID_FOREIGN_KEY_REFERENCE: &str = "E220";
    pub const E_AMBIGUOUS_COLUMN_REFERENCE: &str = "E221";
    pub const E_UNRESOLVED_REFERENCE: &str = "E222";
    pub const E_CTE_COLUMN_COUNT_MISMATCH: &str = "E223";
    pub const E_MISSING_REFERENCE_TARGET: &str = "E224";

    pub const W_CARTESIAN_JOIN: &str = "W220";
    pub const W_JOIN_NOT_USING_DECLARED_REFERENCE: &str = "W221";
    pub const W_WEAK_REFERENCE_INTEGRITY: &str = "W222";
}
346
/// Canonical type family used by schema/type checks.
///
/// Concrete SQL type spellings are collapsed into these coarse families by
/// `canonical_type_family`. Variants serialize in `snake_case`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum TypeFamily {
    /// Empty or unrecognized type string.
    Unknown,
    /// `bool` / `boolean`.
    Boolean,
    /// Integer types (e.g. `tinyint`..`bigint`, serial types, sized and
    /// unsigned integer spellings).
    Integer,
    /// Decimal and floating-point types (e.g. `numeric`, `decimal`, `float`,
    /// `double`, `real`).
    Numeric,
    /// Character types (e.g. `char`, `varchar`, `text`, `string`, `clob`).
    String,
    /// Byte-string types (e.g. `binary`, `varbinary`, `blob`, `bytea`).
    Binary,
    /// `date`.
    Date,
    /// `time`.
    Time,
    /// Timestamp types (e.g. `timestamp`, `timestamptz`, `datetime`).
    Timestamp,
    /// `interval`.
    Interval,
    /// `json`, `jsonb`, `variant`.
    Json,
    /// `uuid`, `uniqueidentifier`.
    Uuid,
    /// Array/list types, including the `T[]` and `array<T>` spellings.
    Array,
    /// Map types.
    Map,
    /// Struct-like types (`struct`, `row`, `record`, `object`).
    Struct,
}
367
368impl TypeFamily {
369    pub fn is_numeric(self) -> bool {
370        matches!(self, TypeFamily::Integer | TypeFamily::Numeric)
371    }
372
373    pub fn is_temporal(self) -> bool {
374        matches!(
375            self,
376            TypeFamily::Date | TypeFamily::Time | TypeFamily::Timestamp | TypeFamily::Interval
377        )
378    }
379}
380
/// Per-table lookup data derived from a `SchemaTable`.
#[derive(Debug, Clone)]
struct TableSchemaEntry {
    /// Lowercased column name -> canonical type family.
    columns: HashMap<String, TypeFamily>,
    /// Lowercased column names in declaration order.
    column_order: Vec<String>,
}
386
/// Returns the lowercase form of `s`; used to normalize identifiers before
/// they are compared or used as map keys.
fn lower(s: &str) -> String {
    str::to_lowercase(s)
}
390
/// Splits a parameterized type such as `base(inner)` into its trimmed base
/// name and trimmed argument text.
///
/// The split happens at the first `(`, and the input must end with `)`;
/// otherwise `None` is returned. Nested parentheses stay inside `inner`.
fn split_type_args(data_type: &str) -> Option<(&str, &str)> {
    let without_close = data_type.strip_suffix(')')?;
    let (base, inner) = without_close.split_once('(')?;
    Some((base.trim(), inner.trim()))
}
400
401/// Canonicalize a schema type string into a stable `TypeFamily`.
402pub fn canonical_type_family(data_type: &str) -> TypeFamily {
403    let trimmed = data_type
404        .trim()
405        .trim_matches(|c| c == '"' || c == '\'' || c == '`');
406    if trimmed.is_empty() {
407        return TypeFamily::Unknown;
408    }
409
410    // Normalize whitespace and lowercase for matching.
411    let normalized = trimmed
412        .split_whitespace()
413        .collect::<Vec<_>>()
414        .join(" ")
415        .to_lowercase();
416
417    // Strip common wrappers first.
418    if let Some((base, inner)) = split_type_args(&normalized) {
419        match base {
420            "nullable" | "lowcardinality" => return canonical_type_family(inner),
421            "array" | "list" => return TypeFamily::Array,
422            "map" => return TypeFamily::Map,
423            "struct" | "row" | "record" => return TypeFamily::Struct,
424            _ => {}
425        }
426    }
427
428    if normalized.starts_with("array<") || normalized.starts_with("list<") {
429        return TypeFamily::Array;
430    }
431    if normalized.starts_with("map<") {
432        return TypeFamily::Map;
433    }
434    if normalized.starts_with("struct<")
435        || normalized.starts_with("row<")
436        || normalized.starts_with("record<")
437        || normalized.starts_with("object<")
438    {
439        return TypeFamily::Struct;
440    }
441
442    if normalized.ends_with("[]") {
443        return TypeFamily::Array;
444    }
445
446    // Remove parameter list if present, e.g. VARCHAR(255), DECIMAL(10,2).
447    let mut base = normalized
448        .split('(')
449        .next()
450        .unwrap_or("")
451        .trim()
452        .to_string();
453    if base.is_empty() {
454        return TypeFamily::Unknown;
455    }
456
457    base = base.strip_prefix("unsigned ").unwrap_or(&base).to_string();
458    base = base.strip_suffix(" unsigned").unwrap_or(&base).to_string();
459
460    match base.as_str() {
461        "bool" | "boolean" => TypeFamily::Boolean,
462        "tinyint" | "smallint" | "int2" | "int" | "integer" | "int4" | "int8" | "bigint"
463        | "serial" | "smallserial" | "bigserial" | "utinyint" | "usmallint" | "uinteger"
464        | "ubigint" | "uint8" | "uint16" | "uint32" | "uint64" | "int16" | "int32" | "int64" => {
465            TypeFamily::Integer
466        }
467        "numeric" | "decimal" | "dec" | "number" | "float" | "float4" | "float8" | "real"
468        | "double" | "double precision" | "bfloat16" | "float16" | "float32" | "float64" => {
469            TypeFamily::Numeric
470        }
471        "char" | "character" | "varchar" | "character varying" | "nchar" | "nvarchar" | "text"
472        | "string" | "clob" => TypeFamily::String,
473        "binary" | "varbinary" | "blob" | "bytea" | "bytes" => TypeFamily::Binary,
474        "date" => TypeFamily::Date,
475        "time" => TypeFamily::Time,
476        "timestamp"
477        | "timestamptz"
478        | "datetime"
479        | "datetime2"
480        | "smalldatetime"
481        | "timestamp with time zone"
482        | "timestamp without time zone" => TypeFamily::Timestamp,
483        "interval" => TypeFamily::Interval,
484        "json" | "jsonb" | "variant" => TypeFamily::Json,
485        "uuid" | "uniqueidentifier" => TypeFamily::Uuid,
486        "array" | "list" => TypeFamily::Array,
487        "map" => TypeFamily::Map,
488        "struct" | "row" | "record" | "object" => TypeFamily::Struct,
489        _ => TypeFamily::Unknown,
490    }
491}
492
493fn build_schema_map(schema: &ValidationSchema) -> HashMap<String, TableSchemaEntry> {
494    let mut map = HashMap::new();
495
496    for table in &schema.tables {
497        let column_order: Vec<String> = table.columns.iter().map(|c| lower(&c.name)).collect();
498        let columns: HashMap<String, TypeFamily> = table
499            .columns
500            .iter()
501            .map(|c| (lower(&c.name), canonical_type_family(&c.data_type)))
502            .collect();
503        let entry = TableSchemaEntry {
504            columns,
505            column_order,
506        };
507
508        let simple_name = lower(&table.name);
509        map.insert(simple_name, entry.clone());
510
511        if let Some(table_schema) = &table.schema {
512            map.insert(
513                format!("{}.{}", lower(table_schema), lower(&table.name)),
514                entry.clone(),
515            );
516        }
517
518        for alias in &table.aliases {
519            map.insert(lower(alias), entry.clone());
520        }
521    }
522
523    map
524}
525
526fn type_family_to_data_type(family: TypeFamily) -> DataType {
527    match family {
528        TypeFamily::Unknown => DataType::Unknown,
529        TypeFamily::Boolean => DataType::Boolean,
530        TypeFamily::Integer => DataType::Int {
531            length: None,
532            integer_spelling: false,
533        },
534        TypeFamily::Numeric => DataType::Double {
535            precision: None,
536            scale: None,
537        },
538        TypeFamily::String => DataType::VarChar {
539            length: None,
540            parenthesized_length: false,
541        },
542        TypeFamily::Binary => DataType::VarBinary { length: None },
543        TypeFamily::Date => DataType::Date,
544        TypeFamily::Time => DataType::Time {
545            precision: None,
546            timezone: false,
547        },
548        TypeFamily::Timestamp => DataType::Timestamp {
549            precision: None,
550            timezone: false,
551        },
552        TypeFamily::Interval => DataType::Interval {
553            unit: None,
554            to: None,
555        },
556        TypeFamily::Json => DataType::Json,
557        TypeFamily::Uuid => DataType::Uuid,
558        TypeFamily::Array => DataType::Array {
559            element_type: Box::new(DataType::Unknown),
560            dimension: None,
561        },
562        TypeFamily::Map => DataType::Map {
563            key_type: Box::new(DataType::Unknown),
564            value_type: Box::new(DataType::Unknown),
565        },
566        TypeFamily::Struct => DataType::Struct {
567            fields: Vec::new(),
568            nested: false,
569        },
570    }
571}
572
573fn build_resolver_schema(schema: &ValidationSchema) -> MappingSchema {
574    let mut mapping = MappingSchema::new();
575
576    for table in &schema.tables {
577        let columns: Vec<(String, DataType)> = table
578            .columns
579            .iter()
580            .map(|column| {
581                (
582                    lower(&column.name),
583                    type_family_to_data_type(canonical_type_family(&column.data_type)),
584                )
585            })
586            .collect();
587
588        let mut table_names = Vec::new();
589        table_names.push(lower(&table.name));
590        if let Some(table_schema) = &table.schema {
591            table_names.push(format!("{}.{}", lower(table_schema), lower(&table.name)));
592        }
593        for alias in &table.aliases {
594            table_names.push(lower(alias));
595        }
596
597        let mut dedup = HashSet::new();
598        for table_name in table_names {
599            if dedup.insert(table_name.clone()) {
600                let _ = mapping.add_table(&table_name, &columns, None);
601            }
602        }
603    }
604
605    mapping
606}
607
/// Build a `MappingSchema` from a `ValidationSchema` payload.
///
/// This is useful for APIs that already accept `ValidationSchema`-shaped input
/// (for example JSON wrappers) and need to run schema-aware lineage or other
/// resolver-based analysis.
///
/// This is the public entry point for the same conversion the validator uses
/// internally (`build_resolver_schema`): table, schema-qualified and alias
/// names are lowercased, and column types are reduced to one representative
/// `DataType` per type family.
pub fn mapping_schema_from_validation_schema(schema: &ValidationSchema) -> MappingSchema {
    build_resolver_schema(schema)
}
616
617fn collect_cte_aliases(expr: &Expression) -> HashSet<String> {
618    let mut aliases = HashSet::new();
619
620    for node in expr.dfs() {
621        match node {
622            Expression::Select(select) => {
623                if let Some(with) = &select.with {
624                    for cte in &with.ctes {
625                        aliases.insert(lower(&cte.alias.name));
626                    }
627                }
628            }
629            Expression::Insert(insert) => {
630                if let Some(with) = &insert.with {
631                    for cte in &with.ctes {
632                        aliases.insert(lower(&cte.alias.name));
633                    }
634                }
635            }
636            Expression::Update(update) => {
637                if let Some(with) = &update.with {
638                    for cte in &with.ctes {
639                        aliases.insert(lower(&cte.alias.name));
640                    }
641                }
642            }
643            Expression::Delete(delete) => {
644                if let Some(with) = &delete.with {
645                    for cte in &with.ctes {
646                        aliases.insert(lower(&cte.alias.name));
647                    }
648                }
649            }
650            Expression::Union(union) => {
651                if let Some(with) = &union.with {
652                    for cte in &with.ctes {
653                        aliases.insert(lower(&cte.alias.name));
654                    }
655                }
656            }
657            Expression::Intersect(intersect) => {
658                if let Some(with) = &intersect.with {
659                    for cte in &with.ctes {
660                        aliases.insert(lower(&cte.alias.name));
661                    }
662                }
663            }
664            Expression::Except(except) => {
665                if let Some(with) = &except.with {
666                    for cte in &with.ctes {
667                        aliases.insert(lower(&cte.alias.name));
668                    }
669                }
670            }
671            Expression::Merge(merge) => {
672                if let Some(with_) = &merge.with_ {
673                    if let Expression::With(with_clause) = with_.as_ref() {
674                        for cte in &with_clause.ctes {
675                            aliases.insert(lower(&cte.alias.name));
676                        }
677                    }
678                }
679            }
680            _ => {}
681        }
682    }
683
684    aliases
685}
686
687fn table_ref_candidates(table: &TableRef) -> Vec<String> {
688    let name = lower(&table.name.name);
689    let schema = table.schema.as_ref().map(|s| lower(&s.name));
690    let catalog = table.catalog.as_ref().map(|c| lower(&c.name));
691
692    let mut candidates = Vec::new();
693    if let (Some(catalog), Some(schema)) = (&catalog, &schema) {
694        candidates.push(format!("{}.{}.{}", catalog, schema, name));
695    }
696    if let Some(schema) = &schema {
697        candidates.push(format!("{}.{}", schema, name));
698    }
699    candidates.push(name);
700    candidates
701}
702
703fn table_ref_display_name(table: &TableRef) -> String {
704    let mut parts = Vec::new();
705    if let Some(catalog) = &table.catalog {
706        parts.push(catalog.name.clone());
707    }
708    if let Some(schema) = &table.schema {
709        parts.push(schema.name.clone());
710    }
711    parts.push(table.name.name.clone());
712    parts.join(".")
713}
714
/// Tables a statement refers to, as gathered by `collect_type_check_context`.
#[derive(Debug, Default, Clone)]
struct TypeCheckContext {
    /// Schema-map keys of the referenced tables that resolved against the schema.
    referenced_tables: HashSet<String>,
    /// Lowercased table name or alias -> resolved schema-map key.
    table_aliases: HashMap<String, String>,
}
720
721fn type_family_name(family: TypeFamily) -> &'static str {
722    match family {
723        TypeFamily::Unknown => "unknown",
724        TypeFamily::Boolean => "boolean",
725        TypeFamily::Integer => "integer",
726        TypeFamily::Numeric => "numeric",
727        TypeFamily::String => "string",
728        TypeFamily::Binary => "binary",
729        TypeFamily::Date => "date",
730        TypeFamily::Time => "time",
731        TypeFamily::Timestamp => "timestamp",
732        TypeFamily::Interval => "interval",
733        TypeFamily::Json => "json",
734        TypeFamily::Uuid => "uuid",
735        TypeFamily::Array => "array",
736        TypeFamily::Map => "map",
737        TypeFamily::Struct => "struct",
738    }
739}
740
741fn is_string_like(family: TypeFamily) -> bool {
742    matches!(family, TypeFamily::String)
743}
744
745fn is_string_or_binary(family: TypeFamily) -> bool {
746    matches!(family, TypeFamily::String | TypeFamily::Binary)
747}
748
749fn type_issue(
750    strict: bool,
751    error_code: &str,
752    warning_code: &str,
753    message: impl Into<String>,
754) -> ValidationError {
755    if strict {
756        ValidationError::error(message.into(), error_code)
757    } else {
758        ValidationError::warning(message.into(), warning_code)
759    }
760}
761
762fn data_type_family(data_type: &DataType) -> TypeFamily {
763    match data_type {
764        DataType::Boolean => TypeFamily::Boolean,
765        DataType::TinyInt { .. }
766        | DataType::SmallInt { .. }
767        | DataType::Int { .. }
768        | DataType::BigInt { .. } => TypeFamily::Integer,
769        DataType::Float { .. } | DataType::Double { .. } | DataType::Decimal { .. } => {
770            TypeFamily::Numeric
771        }
772        DataType::Char { .. }
773        | DataType::VarChar { .. }
774        | DataType::String { .. }
775        | DataType::Text
776        | DataType::TextWithLength { .. }
777        | DataType::CharacterSet { .. } => TypeFamily::String,
778        DataType::Binary { .. } | DataType::VarBinary { .. } | DataType::Blob => TypeFamily::Binary,
779        DataType::Date => TypeFamily::Date,
780        DataType::Time { .. } => TypeFamily::Time,
781        DataType::Timestamp { .. } => TypeFamily::Timestamp,
782        DataType::Interval { .. } => TypeFamily::Interval,
783        DataType::Json | DataType::JsonB => TypeFamily::Json,
784        DataType::Uuid => TypeFamily::Uuid,
785        DataType::Array { .. } | DataType::List { .. } => TypeFamily::Array,
786        DataType::Map { .. } => TypeFamily::Map,
787        DataType::Struct { .. } | DataType::Object { .. } | DataType::Union { .. } => {
788            TypeFamily::Struct
789        }
790        DataType::Nullable { inner } => data_type_family(inner),
791        DataType::Custom { name } => canonical_type_family(name),
792        DataType::Unknown => TypeFamily::Unknown,
793        DataType::Bit { .. } | DataType::VarBit { .. } => TypeFamily::Binary,
794        DataType::Enum { .. } | DataType::Set { .. } => TypeFamily::String,
795        DataType::Vector { .. } => TypeFamily::Array,
796        DataType::Geometry { .. } | DataType::Geography { .. } => TypeFamily::Struct,
797    }
798}
799
800fn collect_type_check_context(
801    stmt: &Expression,
802    schema_map: &HashMap<String, TableSchemaEntry>,
803) -> TypeCheckContext {
804    fn add_table_to_context(
805        table: &TableRef,
806        schema_map: &HashMap<String, TableSchemaEntry>,
807        context: &mut TypeCheckContext,
808    ) {
809        let resolved_key = table_ref_candidates(table)
810            .into_iter()
811            .find(|k| schema_map.contains_key(k));
812
813        let Some(table_key) = resolved_key else {
814            return;
815        };
816
817        context.referenced_tables.insert(table_key.clone());
818        context
819            .table_aliases
820            .insert(lower(&table.name.name), table_key.clone());
821        if let Some(alias) = &table.alias {
822            context
823                .table_aliases
824                .insert(lower(&alias.name), table_key.clone());
825        }
826    }
827
828    let mut context = TypeCheckContext::default();
829    let cte_aliases = collect_cte_aliases(stmt);
830
831    for node in stmt.find_all(|e| matches!(e, Expression::Table(_))) {
832        let Expression::Table(table) = node else {
833            continue;
834        };
835
836        if cte_aliases.contains(&lower(&table.name.name)) {
837            continue;
838        }
839
840        add_table_to_context(table, schema_map, &mut context);
841    }
842
843    // Seed DML target tables explicitly because they are struct fields and may
844    // not appear as standalone Expression::Table nodes in traversal output.
845    match stmt {
846        Expression::Insert(insert) => {
847            add_table_to_context(&insert.table, schema_map, &mut context);
848        }
849        Expression::Update(update) => {
850            add_table_to_context(&update.table, schema_map, &mut context);
851            for table in &update.extra_tables {
852                add_table_to_context(table, schema_map, &mut context);
853            }
854        }
855        Expression::Delete(delete) => {
856            add_table_to_context(&delete.table, schema_map, &mut context);
857            for table in &delete.using {
858                add_table_to_context(table, schema_map, &mut context);
859            }
860            for table in &delete.tables {
861                add_table_to_context(table, schema_map, &mut context);
862            }
863        }
864        _ => {}
865    }
866
867    context
868}
869
870fn resolve_table_schema_entry<'a>(
871    table: &TableRef,
872    schema_map: &'a HashMap<String, TableSchemaEntry>,
873) -> Option<(String, &'a TableSchemaEntry)> {
874    let key = table_ref_candidates(table)
875        .into_iter()
876        .find(|k| schema_map.contains_key(k))?;
877    let entry = schema_map.get(&key)?;
878    Some((key, entry))
879}
880
881fn reference_issue(strict: bool, message: impl Into<String>) -> ValidationError {
882    if strict {
883        ValidationError::error(
884            message.into(),
885            validation_codes::E_INVALID_FOREIGN_KEY_REFERENCE,
886        )
887    } else {
888        ValidationError::warning(message.into(), validation_codes::W_WEAK_REFERENCE_INTEGRITY)
889    }
890}
891
892fn reference_table_candidates(
893    table_name: &str,
894    explicit_schema: Option<&str>,
895    source_schema: Option<&str>,
896) -> Vec<String> {
897    let mut candidates = Vec::new();
898    let raw = lower(table_name);
899
900    if let Some(schema) = explicit_schema {
901        candidates.push(format!("{}.{}", lower(schema), raw));
902    }
903
904    if raw.contains('.') {
905        candidates.push(raw.clone());
906        if let Some(last) = raw.rsplit('.').next() {
907            candidates.push(last.to_string());
908        }
909    } else {
910        if let Some(schema) = source_schema {
911            candidates.push(format!("{}.{}", lower(schema), raw));
912        }
913        candidates.push(raw);
914    }
915
916    let mut dedup = HashSet::new();
917    candidates
918        .into_iter()
919        .filter(|c| dedup.insert(c.clone()))
920        .collect()
921}
922
923fn resolve_reference_table_key(
924    table_name: &str,
925    explicit_schema: Option<&str>,
926    source_schema: Option<&str>,
927    schema_map: &HashMap<String, TableSchemaEntry>,
928) -> Option<String> {
929    reference_table_candidates(table_name, explicit_schema, source_schema)
930        .into_iter()
931        .find(|candidate| schema_map.contains_key(candidate))
932}
933
934fn key_types_compatible(source: TypeFamily, target: TypeFamily) -> bool {
935    if source == TypeFamily::Unknown || target == TypeFamily::Unknown {
936        return true;
937    }
938    if source == target {
939        return true;
940    }
941    if source.is_numeric() && target.is_numeric() {
942        return true;
943    }
944    if source.is_temporal() && target.is_temporal() {
945        return true;
946    }
947    false
948}
949
950fn table_key_hints(table: &SchemaTable) -> HashSet<String> {
951    let mut hints = HashSet::new();
952    for column in &table.columns {
953        if column.primary_key || column.unique {
954            hints.insert(lower(&column.name));
955        }
956    }
957    for key_col in &table.primary_key {
958        hints.insert(lower(key_col));
959    }
960    for group in &table.unique_keys {
961        if group.len() == 1 {
962            if let Some(col) = group.first() {
963                hints.insert(lower(col));
964            }
965        }
966    }
967    hints
968}
969
970fn check_reference_integrity(
971    schema: &ValidationSchema,
972    schema_map: &HashMap<String, TableSchemaEntry>,
973    strict: bool,
974) -> Vec<ValidationError> {
975    let mut errors = Vec::new();
976
977    let mut key_hints_lookup: HashMap<String, HashSet<String>> = HashMap::new();
978    for table in &schema.tables {
979        let simple = lower(&table.name);
980        key_hints_lookup.insert(simple, table_key_hints(table));
981        if let Some(schema_name) = &table.schema {
982            let qualified = format!("{}.{}", lower(schema_name), lower(&table.name));
983            key_hints_lookup.insert(qualified, table_key_hints(table));
984        }
985    }
986
987    for table in &schema.tables {
988        let source_table_display = if let Some(schema_name) = &table.schema {
989            format!("{}.{}", schema_name, table.name)
990        } else {
991            table.name.clone()
992        };
993        let source_schema = table.schema.as_deref();
994        let source_columns: HashMap<String, TypeFamily> = table
995            .columns
996            .iter()
997            .map(|col| (lower(&col.name), canonical_type_family(&col.data_type)))
998            .collect();
999
1000        for source_col in &table.columns {
1001            let Some(reference) = &source_col.references else {
1002                continue;
1003            };
1004            let source_type = canonical_type_family(&source_col.data_type);
1005
1006            let Some(target_key) = resolve_reference_table_key(
1007                &reference.table,
1008                reference.schema.as_deref(),
1009                source_schema,
1010                schema_map,
1011            ) else {
1012                errors.push(reference_issue(
1013                    strict,
1014                    format!(
1015                        "Foreign key reference '{}.{}' points to unknown table '{}'",
1016                        source_table_display, source_col.name, reference.table
1017                    ),
1018                ));
1019                continue;
1020            };
1021
1022            let target_column = lower(&reference.column);
1023            let Some(target_entry) = schema_map.get(&target_key) else {
1024                errors.push(reference_issue(
1025                    strict,
1026                    format!(
1027                        "Foreign key reference '{}.{}' points to unknown table '{}'",
1028                        source_table_display, source_col.name, reference.table
1029                    ),
1030                ));
1031                continue;
1032            };
1033
1034            let Some(target_type) = target_entry.columns.get(&target_column).copied() else {
1035                errors.push(reference_issue(
1036                    strict,
1037                    format!(
1038                        "Foreign key reference '{}.{}' points to unknown column '{}.{}'",
1039                        source_table_display, source_col.name, target_key, reference.column
1040                    ),
1041                ));
1042                continue;
1043            };
1044
1045            if !key_types_compatible(source_type, target_type) {
1046                errors.push(reference_issue(
1047                    strict,
1048                    format!(
1049                        "Foreign key type mismatch for '{}.{}' -> '{}.{}': {} vs {}",
1050                        source_table_display,
1051                        source_col.name,
1052                        target_key,
1053                        reference.column,
1054                        type_family_name(source_type),
1055                        type_family_name(target_type)
1056                    ),
1057                ));
1058            }
1059
1060            if let Some(target_key_hints) = key_hints_lookup.get(&target_key) {
1061                if !target_key_hints.contains(&target_column) {
1062                    errors.push(ValidationError::warning(
1063                        format!(
1064                            "Referenced column '{}.{}' is not marked as primary/unique key",
1065                            target_key, reference.column
1066                        ),
1067                        validation_codes::W_WEAK_REFERENCE_INTEGRITY,
1068                    ));
1069                }
1070            }
1071        }
1072
1073        for foreign_key in &table.foreign_keys {
1074            if foreign_key.columns.is_empty() || foreign_key.references.columns.is_empty() {
1075                errors.push(reference_issue(
1076                    strict,
1077                    format!(
1078                        "Table-level foreign key on '{}' has empty source or target column list",
1079                        source_table_display
1080                    ),
1081                ));
1082                continue;
1083            }
1084            if foreign_key.columns.len() != foreign_key.references.columns.len() {
1085                errors.push(reference_issue(
1086                    strict,
1087                    format!(
1088                        "Table-level foreign key on '{}' has {} source columns but {} target columns",
1089                        source_table_display,
1090                        foreign_key.columns.len(),
1091                        foreign_key.references.columns.len()
1092                    ),
1093                ));
1094                continue;
1095            }
1096
1097            let Some(target_key) = resolve_reference_table_key(
1098                &foreign_key.references.table,
1099                foreign_key.references.schema.as_deref(),
1100                source_schema,
1101                schema_map,
1102            ) else {
1103                errors.push(reference_issue(
1104                    strict,
1105                    format!(
1106                        "Table-level foreign key on '{}' points to unknown table '{}'",
1107                        source_table_display, foreign_key.references.table
1108                    ),
1109                ));
1110                continue;
1111            };
1112
1113            let Some(target_entry) = schema_map.get(&target_key) else {
1114                errors.push(reference_issue(
1115                    strict,
1116                    format!(
1117                        "Table-level foreign key on '{}' points to unknown table '{}'",
1118                        source_table_display, foreign_key.references.table
1119                    ),
1120                ));
1121                continue;
1122            };
1123
1124            for (source_col, target_col) in foreign_key
1125                .columns
1126                .iter()
1127                .zip(foreign_key.references.columns.iter())
1128            {
1129                let source_col_name = lower(source_col);
1130                let target_col_name = lower(target_col);
1131
1132                let Some(source_type) = source_columns.get(&source_col_name).copied() else {
1133                    errors.push(reference_issue(
1134                        strict,
1135                        format!(
1136                            "Table-level foreign key on '{}' references unknown source column '{}'",
1137                            source_table_display, source_col
1138                        ),
1139                    ));
1140                    continue;
1141                };
1142
1143                let Some(target_type) = target_entry.columns.get(&target_col_name).copied() else {
1144                    errors.push(reference_issue(
1145                        strict,
1146                        format!(
1147                            "Table-level foreign key on '{}' references unknown target column '{}.{}'",
1148                            source_table_display, target_key, target_col
1149                        ),
1150                    ));
1151                    continue;
1152                };
1153
1154                if !key_types_compatible(source_type, target_type) {
1155                    errors.push(reference_issue(
1156                        strict,
1157                        format!(
1158                            "Table-level foreign key type mismatch '{}.{}' -> '{}.{}': {} vs {}",
1159                            source_table_display,
1160                            source_col,
1161                            target_key,
1162                            target_col,
1163                            type_family_name(source_type),
1164                            type_family_name(target_type)
1165                        ),
1166                    ));
1167                }
1168
1169                if let Some(target_key_hints) = key_hints_lookup.get(&target_key) {
1170                    if !target_key_hints.contains(&target_col_name) {
1171                        errors.push(ValidationError::warning(
1172                            format!(
1173                                "Referenced column '{}.{}' is not marked as primary/unique key",
1174                                target_key, target_col
1175                            ),
1176                            validation_codes::W_WEAK_REFERENCE_INTEGRITY,
1177                        ));
1178                    }
1179                }
1180            }
1181        }
1182    }
1183
1184    errors
1185}
1186
1187fn resolve_unqualified_column_type(
1188    column_name: &str,
1189    schema_map: &HashMap<String, TableSchemaEntry>,
1190    context: &TypeCheckContext,
1191) -> TypeFamily {
1192    let candidate_tables: Vec<&String> = if !context.referenced_tables.is_empty() {
1193        context.referenced_tables.iter().collect()
1194    } else {
1195        schema_map.keys().collect()
1196    };
1197
1198    let mut families = HashSet::new();
1199    for table_name in candidate_tables {
1200        if let Some(table_schema) = schema_map.get(table_name) {
1201            if let Some(family) = table_schema.columns.get(column_name) {
1202                families.insert(*family);
1203            }
1204        }
1205    }
1206
1207    if families.len() == 1 {
1208        *families.iter().next().unwrap_or(&TypeFamily::Unknown)
1209    } else {
1210        TypeFamily::Unknown
1211    }
1212}
1213
1214fn resolve_column_type(
1215    column: &Column,
1216    schema_map: &HashMap<String, TableSchemaEntry>,
1217    context: &TypeCheckContext,
1218) -> TypeFamily {
1219    let column_name = lower(&column.name.name);
1220    if column_name.is_empty() {
1221        return TypeFamily::Unknown;
1222    }
1223
1224    if let Some(table) = &column.table {
1225        let mut table_key = lower(&table.name);
1226        if let Some(mapped) = context.table_aliases.get(&table_key) {
1227            table_key = mapped.clone();
1228        }
1229
1230        return schema_map
1231            .get(&table_key)
1232            .and_then(|t| t.columns.get(&column_name))
1233            .copied()
1234            .unwrap_or(TypeFamily::Unknown);
1235    }
1236
1237    resolve_unqualified_column_type(&column_name, schema_map, context)
1238}
1239
1240struct TypeInferenceSchema<'a> {
1241    schema_map: &'a HashMap<String, TableSchemaEntry>,
1242    context: &'a TypeCheckContext,
1243}
1244
1245impl TypeInferenceSchema<'_> {
1246    fn resolve_table_key(&self, table: &str) -> Option<String> {
1247        let mut table_key = lower(table);
1248        if let Some(mapped) = self.context.table_aliases.get(&table_key) {
1249            table_key = mapped.clone();
1250        }
1251        if self.schema_map.contains_key(&table_key) {
1252            Some(table_key)
1253        } else {
1254            None
1255        }
1256    }
1257}
1258
1259impl SqlSchema for TypeInferenceSchema<'_> {
1260    fn dialect(&self) -> Option<DialectType> {
1261        None
1262    }
1263
1264    fn add_table(
1265        &mut self,
1266        _table: &str,
1267        _columns: &[(String, DataType)],
1268        _dialect: Option<DialectType>,
1269    ) -> SchemaResult<()> {
1270        Err(SchemaError::InvalidStructure(
1271            "Type inference schema is read-only".to_string(),
1272        ))
1273    }
1274
1275    fn column_names(&self, table: &str) -> SchemaResult<Vec<String>> {
1276        let table_key = self
1277            .resolve_table_key(table)
1278            .ok_or_else(|| SchemaError::TableNotFound(table.to_string()))?;
1279        let entry = self
1280            .schema_map
1281            .get(&table_key)
1282            .ok_or_else(|| SchemaError::TableNotFound(table.to_string()))?;
1283        Ok(entry.column_order.clone())
1284    }
1285
1286    fn get_column_type(&self, table: &str, column: &str) -> SchemaResult<DataType> {
1287        let col_name = lower(column);
1288        if table.is_empty() {
1289            let family = resolve_unqualified_column_type(&col_name, self.schema_map, self.context);
1290            return if family == TypeFamily::Unknown {
1291                Err(SchemaError::ColumnNotFound {
1292                    table: "<unqualified>".to_string(),
1293                    column: column.to_string(),
1294                })
1295            } else {
1296                Ok(type_family_to_data_type(family))
1297            };
1298        }
1299
1300        let table_key = self
1301            .resolve_table_key(table)
1302            .ok_or_else(|| SchemaError::TableNotFound(table.to_string()))?;
1303        let entry = self
1304            .schema_map
1305            .get(&table_key)
1306            .ok_or_else(|| SchemaError::TableNotFound(table.to_string()))?;
1307        let family =
1308            entry
1309                .columns
1310                .get(&col_name)
1311                .copied()
1312                .ok_or_else(|| SchemaError::ColumnNotFound {
1313                    table: table.to_string(),
1314                    column: column.to_string(),
1315                })?;
1316        Ok(type_family_to_data_type(family))
1317    }
1318
1319    fn has_column(&self, table: &str, column: &str) -> bool {
1320        self.get_column_type(table, column).is_ok()
1321    }
1322
1323    fn supported_table_args(&self) -> &[&str] {
1324        TABLE_PARTS
1325    }
1326
1327    fn is_empty(&self) -> bool {
1328        self.schema_map.is_empty()
1329    }
1330
1331    fn depth(&self) -> usize {
1332        1
1333    }
1334
1335    fn find_tables_for_column(&self, column: &str) -> Vec<String> {
1336        let col_name = column.to_lowercase();
1337        self.schema_map
1338            .iter()
1339            .filter(|(_, entry)| {
1340                entry
1341                    .column_order
1342                    .iter()
1343                    .any(|c| c.to_lowercase() == col_name)
1344            })
1345            .map(|(table, _)| table.clone())
1346            .collect()
1347    }
1348}
1349
1350fn infer_expression_type_family(
1351    expr: &Expression,
1352    schema_map: &HashMap<String, TableSchemaEntry>,
1353    context: &TypeCheckContext,
1354) -> TypeFamily {
1355    let inference_schema = TypeInferenceSchema {
1356        schema_map,
1357        context,
1358    };
1359    let mut expr_clone = expr.clone();
1360    annotate_types(&mut expr_clone, Some(&inference_schema), None);
1361    if let Some(data_type) = expr_clone.inferred_type() {
1362        let family = data_type_family(&data_type);
1363        if family != TypeFamily::Unknown {
1364            return family;
1365        }
1366    }
1367
1368    infer_expression_type_family_fallback(expr, schema_map, context)
1369}
1370
1371fn infer_expression_type_family_fallback(
1372    expr: &Expression,
1373    schema_map: &HashMap<String, TableSchemaEntry>,
1374    context: &TypeCheckContext,
1375) -> TypeFamily {
1376    match expr {
1377        Expression::Literal(literal) => match literal.as_ref() {
1378            crate::expressions::Literal::Number(value) => {
1379                if value.contains('.') || value.contains('e') || value.contains('E') {
1380                    TypeFamily::Numeric
1381                } else {
1382                    TypeFamily::Integer
1383                }
1384            }
1385            crate::expressions::Literal::HexNumber(_) => TypeFamily::Integer,
1386            crate::expressions::Literal::Date(_) => TypeFamily::Date,
1387            crate::expressions::Literal::Time(_) => TypeFamily::Time,
1388            crate::expressions::Literal::Timestamp(_)
1389            | crate::expressions::Literal::Datetime(_) => TypeFamily::Timestamp,
1390            crate::expressions::Literal::HexString(_)
1391            | crate::expressions::Literal::BitString(_)
1392            | crate::expressions::Literal::ByteString(_) => TypeFamily::Binary,
1393            _ => TypeFamily::String,
1394        },
1395        Expression::Boolean(_) => TypeFamily::Boolean,
1396        Expression::Null(_) => TypeFamily::Unknown,
1397        Expression::Column(column) => resolve_column_type(column, schema_map, context),
1398        Expression::Cast(cast) | Expression::TryCast(cast) | Expression::SafeCast(cast) => {
1399            data_type_family(&cast.to)
1400        }
1401        Expression::Alias(alias) => {
1402            infer_expression_type_family_fallback(&alias.this, schema_map, context)
1403        }
1404        Expression::Neg(unary) => {
1405            infer_expression_type_family_fallback(&unary.this, schema_map, context)
1406        }
1407        Expression::Add(op) | Expression::Sub(op) | Expression::Mul(op) => {
1408            let left = infer_expression_type_family_fallback(&op.left, schema_map, context);
1409            let right = infer_expression_type_family_fallback(&op.right, schema_map, context);
1410            if left == TypeFamily::Unknown || right == TypeFamily::Unknown {
1411                TypeFamily::Unknown
1412            } else if left == TypeFamily::Integer && right == TypeFamily::Integer {
1413                TypeFamily::Integer
1414            } else if left.is_numeric() && right.is_numeric() {
1415                TypeFamily::Numeric
1416            } else if left.is_temporal() || right.is_temporal() {
1417                left
1418            } else {
1419                TypeFamily::Unknown
1420            }
1421        }
1422        Expression::Div(_) | Expression::Mod(_) => TypeFamily::Numeric,
1423        Expression::Concat(_) => TypeFamily::String,
1424        Expression::Eq(_)
1425        | Expression::Neq(_)
1426        | Expression::Lt(_)
1427        | Expression::Lte(_)
1428        | Expression::Gt(_)
1429        | Expression::Gte(_)
1430        | Expression::Like(_)
1431        | Expression::ILike(_)
1432        | Expression::And(_)
1433        | Expression::Or(_)
1434        | Expression::Not(_)
1435        | Expression::Between(_)
1436        | Expression::In(_)
1437        | Expression::IsNull(_)
1438        | Expression::IsTrue(_)
1439        | Expression::IsFalse(_)
1440        | Expression::Is(_) => TypeFamily::Boolean,
1441        Expression::Length(_) => TypeFamily::Integer,
1442        Expression::Upper(_)
1443        | Expression::Lower(_)
1444        | Expression::Trim(_)
1445        | Expression::LTrim(_)
1446        | Expression::RTrim(_)
1447        | Expression::Replace(_)
1448        | Expression::Substring(_)
1449        | Expression::Left(_)
1450        | Expression::Right(_)
1451        | Expression::Repeat(_)
1452        | Expression::Lpad(_)
1453        | Expression::Rpad(_)
1454        | Expression::ConcatWs(_) => TypeFamily::String,
1455        Expression::Abs(_)
1456        | Expression::Round(_)
1457        | Expression::Floor(_)
1458        | Expression::Ceil(_)
1459        | Expression::Power(_)
1460        | Expression::Sqrt(_)
1461        | Expression::Cbrt(_)
1462        | Expression::Ln(_)
1463        | Expression::Log(_)
1464        | Expression::Exp(_) => TypeFamily::Numeric,
1465        Expression::DateAdd(_) | Expression::DateSub(_) | Expression::ToDate(_) => TypeFamily::Date,
1466        Expression::ToTimestamp(_) => TypeFamily::Timestamp,
1467        Expression::DateDiff(_) | Expression::Extract(_) => TypeFamily::Integer,
1468        Expression::CurrentDate(_) => TypeFamily::Date,
1469        Expression::CurrentTime(_) => TypeFamily::Time,
1470        Expression::CurrentTimestamp(_) | Expression::CurrentTimestampLTZ(_) => {
1471            TypeFamily::Timestamp
1472        }
1473        Expression::Interval(_) => TypeFamily::Interval,
1474        _ => TypeFamily::Unknown,
1475    }
1476}
1477
1478fn are_comparable(left: TypeFamily, right: TypeFamily) -> bool {
1479    if left == TypeFamily::Unknown || right == TypeFamily::Unknown {
1480        return true;
1481    }
1482    if left == right {
1483        return true;
1484    }
1485    if left.is_numeric() && right.is_numeric() {
1486        return true;
1487    }
1488    if left.is_temporal() && right.is_temporal() {
1489        return true;
1490    }
1491    false
1492}
1493
1494fn check_function_argument(
1495    errors: &mut Vec<ValidationError>,
1496    strict: bool,
1497    function_name: &str,
1498    arg_index: usize,
1499    family: TypeFamily,
1500    expected: &str,
1501    valid: bool,
1502) {
1503    if family == TypeFamily::Unknown || valid {
1504        return;
1505    }
1506
1507    errors.push(type_issue(
1508        strict,
1509        validation_codes::E_INVALID_FUNCTION_ARGUMENT_TYPE,
1510        validation_codes::W_FUNCTION_ARGUMENT_COERCION,
1511        format!(
1512            "Function '{}' argument {} expects {}, found {}",
1513            function_name,
1514            arg_index + 1,
1515            expected,
1516            type_family_name(family)
1517        ),
1518    ));
1519}
1520
1521fn function_dispatch_name(name: &str) -> String {
1522    let upper = name
1523        .rsplit('.')
1524        .next()
1525        .unwrap_or(name)
1526        .trim()
1527        .to_uppercase();
1528    lower(canonical_typed_function_name_upper(&upper))
1529}
1530
/// Returns the unqualified function name: the text after the last `.`
/// (or the whole input when there is no `.`), with surrounding whitespace trimmed.
/// Case is left unchanged.
fn function_base_name(name: &str) -> &str {
    match name.rsplit_once('.') {
        Some((_, tail)) => tail.trim(),
        None => name.trim(),
    }
}
1534
1535fn check_generic_function(
1536    function: &Function,
1537    schema_map: &HashMap<String, TableSchemaEntry>,
1538    context: &TypeCheckContext,
1539    strict: bool,
1540    errors: &mut Vec<ValidationError>,
1541) {
1542    let name = function_dispatch_name(&function.name);
1543
1544    let arg_family = |index: usize| -> Option<TypeFamily> {
1545        function
1546            .args
1547            .get(index)
1548            .map(|arg| infer_expression_type_family(arg, schema_map, context))
1549    };
1550
1551    match name.as_str() {
1552        "abs" | "sqrt" | "cbrt" | "ln" | "exp" => {
1553            if let Some(family) = arg_family(0) {
1554                check_function_argument(
1555                    errors,
1556                    strict,
1557                    &name,
1558                    0,
1559                    family,
1560                    "a numeric argument",
1561                    family.is_numeric(),
1562                );
1563            }
1564        }
1565        "round" | "floor" | "ceil" | "ceiling" => {
1566            if let Some(family) = arg_family(0) {
1567                check_function_argument(
1568                    errors,
1569                    strict,
1570                    &name,
1571                    0,
1572                    family,
1573                    "a numeric argument",
1574                    family.is_numeric(),
1575                );
1576            }
1577            if let Some(family) = arg_family(1) {
1578                check_function_argument(
1579                    errors,
1580                    strict,
1581                    &name,
1582                    1,
1583                    family,
1584                    "a numeric argument",
1585                    family.is_numeric(),
1586                );
1587            }
1588        }
1589        "power" | "pow" => {
1590            for i in [0_usize, 1_usize] {
1591                if let Some(family) = arg_family(i) {
1592                    check_function_argument(
1593                        errors,
1594                        strict,
1595                        &name,
1596                        i,
1597                        family,
1598                        "a numeric argument",
1599                        family.is_numeric(),
1600                    );
1601                }
1602            }
1603        }
1604        "length" | "char_length" | "character_length" => {
1605            if let Some(family) = arg_family(0) {
1606                check_function_argument(
1607                    errors,
1608                    strict,
1609                    &name,
1610                    0,
1611                    family,
1612                    "a string or binary argument",
1613                    is_string_or_binary(family),
1614                );
1615            }
1616        }
1617        "upper" | "lower" | "trim" | "ltrim" | "rtrim" | "reverse" => {
1618            if let Some(family) = arg_family(0) {
1619                check_function_argument(
1620                    errors,
1621                    strict,
1622                    &name,
1623                    0,
1624                    family,
1625                    "a string argument",
1626                    is_string_like(family),
1627                );
1628            }
1629        }
1630        "substring" | "substr" => {
1631            if let Some(family) = arg_family(0) {
1632                check_function_argument(
1633                    errors,
1634                    strict,
1635                    &name,
1636                    0,
1637                    family,
1638                    "a string argument",
1639                    is_string_like(family),
1640                );
1641            }
1642            if let Some(family) = arg_family(1) {
1643                check_function_argument(
1644                    errors,
1645                    strict,
1646                    &name,
1647                    1,
1648                    family,
1649                    "a numeric argument",
1650                    family.is_numeric(),
1651                );
1652            }
1653            if let Some(family) = arg_family(2) {
1654                check_function_argument(
1655                    errors,
1656                    strict,
1657                    &name,
1658                    2,
1659                    family,
1660                    "a numeric argument",
1661                    family.is_numeric(),
1662                );
1663            }
1664        }
1665        "replace" => {
1666            for i in [0_usize, 1_usize, 2_usize] {
1667                if let Some(family) = arg_family(i) {
1668                    check_function_argument(
1669                        errors,
1670                        strict,
1671                        &name,
1672                        i,
1673                        family,
1674                        "a string argument",
1675                        is_string_like(family),
1676                    );
1677                }
1678            }
1679        }
1680        "left" | "right" | "repeat" | "lpad" | "rpad" => {
1681            if let Some(family) = arg_family(0) {
1682                check_function_argument(
1683                    errors,
1684                    strict,
1685                    &name,
1686                    0,
1687                    family,
1688                    "a string argument",
1689                    is_string_like(family),
1690                );
1691            }
1692            if let Some(family) = arg_family(1) {
1693                check_function_argument(
1694                    errors,
1695                    strict,
1696                    &name,
1697                    1,
1698                    family,
1699                    "a numeric argument",
1700                    family.is_numeric(),
1701                );
1702            }
1703            if (name == "lpad" || name == "rpad") && function.args.len() > 2 {
1704                if let Some(family) = arg_family(2) {
1705                    check_function_argument(
1706                        errors,
1707                        strict,
1708                        &name,
1709                        2,
1710                        family,
1711                        "a string argument",
1712                        is_string_like(family),
1713                    );
1714                }
1715            }
1716        }
1717        _ => {}
1718    }
1719}
1720
1721fn check_function_catalog(
1722    function: &Function,
1723    dialect: DialectType,
1724    function_catalog: Option<&dyn FunctionCatalog>,
1725    strict: bool,
1726    errors: &mut Vec<ValidationError>,
1727) {
1728    let Some(catalog) = function_catalog else {
1729        return;
1730    };
1731
1732    let raw_name = function_base_name(&function.name);
1733    let normalized_name = function_dispatch_name(&function.name);
1734    let arity = function.args.len();
1735    let Some(signatures) = catalog.lookup(dialect, raw_name, &normalized_name) else {
1736        errors.push(if strict {
1737            ValidationError::error(
1738                format!(
1739                    "Unknown function '{}' for dialect {:?}",
1740                    function.name, dialect
1741                ),
1742                validation_codes::E_UNKNOWN_FUNCTION,
1743            )
1744        } else {
1745            ValidationError::warning(
1746                format!(
1747                    "Unknown function '{}' for dialect {:?}",
1748                    function.name, dialect
1749                ),
1750                validation_codes::E_UNKNOWN_FUNCTION,
1751            )
1752        });
1753        return;
1754    };
1755
1756    if signatures.iter().any(|sig| sig.matches_arity(arity)) {
1757        return;
1758    }
1759
1760    let expected = signatures
1761        .iter()
1762        .map(|sig| sig.describe_arity())
1763        .collect::<Vec<_>>()
1764        .join(", ");
1765    errors.push(if strict {
1766        ValidationError::error(
1767            format!(
1768                "Invalid arity for function '{}': got {}, expected {}",
1769                function.name, arity, expected
1770            ),
1771            validation_codes::E_INVALID_FUNCTION_ARITY,
1772        )
1773    } else {
1774        ValidationError::warning(
1775            format!(
1776                "Invalid arity for function '{}': got {}, expected {}",
1777                function.name, arity, expected
1778            ),
1779            validation_codes::E_INVALID_FUNCTION_ARITY,
1780        )
1781    });
1782}
1783
/// A single column-to-column reference declared in the validation schema.
///
/// Table fields hold resolved schema-map keys and column fields hold
/// lower-cased column names, as produced by `build_declared_relationships`.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
struct DeclaredRelationship {
    /// Schema-map key of the table that declares the reference.
    source_table: String,
    /// Lower-cased name of the referencing column.
    source_column: String,
    /// Schema-map key of the referenced table.
    target_table: String,
    /// Lower-cased name of the referenced column.
    target_column: String,
}
1791
1792fn build_declared_relationships(
1793    schema: &ValidationSchema,
1794    schema_map: &HashMap<String, TableSchemaEntry>,
1795) -> Vec<DeclaredRelationship> {
1796    let mut relationships = HashSet::new();
1797
1798    for table in &schema.tables {
1799        let Some(source_key) =
1800            resolve_reference_table_key(&table.name, table.schema.as_deref(), None, schema_map)
1801        else {
1802            continue;
1803        };
1804
1805        for column in &table.columns {
1806            let Some(reference) = &column.references else {
1807                continue;
1808            };
1809            let Some(target_key) = resolve_reference_table_key(
1810                &reference.table,
1811                reference.schema.as_deref(),
1812                table.schema.as_deref(),
1813                schema_map,
1814            ) else {
1815                continue;
1816            };
1817
1818            relationships.insert(DeclaredRelationship {
1819                source_table: source_key.clone(),
1820                source_column: lower(&column.name),
1821                target_table: target_key,
1822                target_column: lower(&reference.column),
1823            });
1824        }
1825
1826        for foreign_key in &table.foreign_keys {
1827            if foreign_key.columns.len() != foreign_key.references.columns.len() {
1828                continue;
1829            }
1830            let Some(target_key) = resolve_reference_table_key(
1831                &foreign_key.references.table,
1832                foreign_key.references.schema.as_deref(),
1833                table.schema.as_deref(),
1834                schema_map,
1835            ) else {
1836                continue;
1837            };
1838
1839            for (source_col, target_col) in foreign_key
1840                .columns
1841                .iter()
1842                .zip(foreign_key.references.columns.iter())
1843            {
1844                relationships.insert(DeclaredRelationship {
1845                    source_table: source_key.clone(),
1846                    source_column: lower(source_col),
1847                    target_table: target_key.clone(),
1848                    target_column: lower(target_col),
1849                });
1850            }
1851        }
1852    }
1853
1854    relationships.into_iter().collect()
1855}
1856
1857fn resolve_column_binding(
1858    column: &Column,
1859    schema_map: &HashMap<String, TableSchemaEntry>,
1860    context: &TypeCheckContext,
1861    resolver: &mut Resolver<'_>,
1862) -> Option<(String, String)> {
1863    let column_name = lower(&column.name.name);
1864    if column_name.is_empty() {
1865        return None;
1866    }
1867
1868    if let Some(table) = &column.table {
1869        let mut table_key = lower(&table.name);
1870        if let Some(mapped) = context.table_aliases.get(&table_key) {
1871            table_key = mapped.clone();
1872        }
1873        if schema_map.contains_key(&table_key) {
1874            return Some((table_key, column_name));
1875        }
1876        return None;
1877    }
1878
1879    if let Some(resolved_source) = resolver.get_table(&column_name) {
1880        let mut table_key = lower(&resolved_source);
1881        if let Some(mapped) = context.table_aliases.get(&table_key) {
1882            table_key = mapped.clone();
1883        }
1884        if schema_map.contains_key(&table_key) {
1885            return Some((table_key, column_name));
1886        }
1887    }
1888
1889    let candidates: Vec<String> = context
1890        .referenced_tables
1891        .iter()
1892        .filter_map(|table_name| {
1893            schema_map
1894                .get(table_name)
1895                .filter(|entry| entry.columns.contains_key(&column_name))
1896                .map(|_| table_name.clone())
1897        })
1898        .collect();
1899    if candidates.len() == 1 {
1900        return Some((candidates[0].clone(), column_name));
1901    }
1902    None
1903}
1904
1905fn extract_join_equality_pairs(
1906    expr: &Expression,
1907    schema_map: &HashMap<String, TableSchemaEntry>,
1908    context: &TypeCheckContext,
1909    resolver: &mut Resolver<'_>,
1910    pairs: &mut Vec<((String, String), (String, String))>,
1911) {
1912    match expr {
1913        Expression::And(op) => {
1914            extract_join_equality_pairs(&op.left, schema_map, context, resolver, pairs);
1915            extract_join_equality_pairs(&op.right, schema_map, context, resolver, pairs);
1916        }
1917        Expression::Paren(paren) => {
1918            extract_join_equality_pairs(&paren.this, schema_map, context, resolver, pairs);
1919        }
1920        Expression::Eq(op) => {
1921            let (Expression::Column(left_col), Expression::Column(right_col)) =
1922                (&op.left, &op.right)
1923            else {
1924                return;
1925            };
1926            let Some(left) = resolve_column_binding(left_col, schema_map, context, resolver) else {
1927                return;
1928            };
1929            let Some(right) = resolve_column_binding(right_col, schema_map, context, resolver)
1930            else {
1931                return;
1932            };
1933            pairs.push((left, right));
1934        }
1935        _ => {}
1936    }
1937}
1938
1939fn relationship_matches_pair(
1940    relationship: &DeclaredRelationship,
1941    left_table: &str,
1942    left_column: &str,
1943    right_table: &str,
1944    right_column: &str,
1945) -> bool {
1946    (relationship.source_table == left_table
1947        && relationship.source_column == left_column
1948        && relationship.target_table == right_table
1949        && relationship.target_column == right_column)
1950        || (relationship.source_table == right_table
1951            && relationship.source_column == right_column
1952            && relationship.target_table == left_table
1953            && relationship.target_column == left_column)
1954}
1955
1956fn resolved_table_key_from_expr(
1957    expr: &Expression,
1958    schema_map: &HashMap<String, TableSchemaEntry>,
1959) -> Option<String> {
1960    match expr {
1961        Expression::Table(table) => resolve_table_schema_entry(table, schema_map).map(|(k, _)| k),
1962        Expression::Alias(alias) => resolved_table_key_from_expr(&alias.this, schema_map),
1963        _ => None,
1964    }
1965}
1966
1967fn select_from_table_keys(
1968    select: &crate::expressions::Select,
1969    schema_map: &HashMap<String, TableSchemaEntry>,
1970) -> HashSet<String> {
1971    let mut keys = HashSet::new();
1972    if let Some(from_clause) = &select.from {
1973        for expr in &from_clause.expressions {
1974            if let Some(key) = resolved_table_key_from_expr(expr, schema_map) {
1975                keys.insert(key);
1976            }
1977        }
1978    }
1979    keys
1980}
1981
1982fn is_natural_or_implied_join(kind: JoinKind) -> bool {
1983    matches!(
1984        kind,
1985        JoinKind::Natural
1986            | JoinKind::NaturalLeft
1987            | JoinKind::NaturalRight
1988            | JoinKind::NaturalFull
1989            | JoinKind::CrossApply
1990            | JoinKind::OuterApply
1991            | JoinKind::AsOf
1992            | JoinKind::AsOfLeft
1993            | JoinKind::AsOfRight
1994            | JoinKind::Lateral
1995            | JoinKind::LeftLateral
1996    )
1997}
1998
1999fn check_query_reference_quality(
2000    stmt: &Expression,
2001    schema_map: &HashMap<String, TableSchemaEntry>,
2002    resolver_schema: &MappingSchema,
2003    strict: bool,
2004    relationships: &[DeclaredRelationship],
2005) -> Vec<ValidationError> {
2006    let mut errors = Vec::new();
2007
2008    for node in stmt.dfs() {
2009        let Expression::Select(select) = node else {
2010            continue;
2011        };
2012
2013        let select_expr = Expression::Select(select.clone());
2014        let context = collect_type_check_context(&select_expr, schema_map);
2015        let scope = build_scope(&select_expr);
2016        let mut resolver = Resolver::new(&scope, resolver_schema, true);
2017
2018        if context.referenced_tables.len() > 1 {
2019            let using_columns: HashSet<String> = select
2020                .joins
2021                .iter()
2022                .flat_map(|join| join.using.iter().map(|id| lower(&id.name)))
2023                .collect();
2024
2025            let mut seen = HashSet::new();
2026            for column_expr in select_expr
2027                .find_all(|e| matches!(e, Expression::Column(col) if col.table.is_none()))
2028            {
2029                let Expression::Column(column) = column_expr else {
2030                    continue;
2031                };
2032
2033                let col_name = lower(&column.name.name);
2034                if col_name.is_empty()
2035                    || using_columns.contains(&col_name)
2036                    || !seen.insert(col_name.clone())
2037                {
2038                    continue;
2039                }
2040
2041                if resolver.is_ambiguous(&col_name) {
2042                    let source_count = resolver.sources_for_column(&col_name).len();
2043                    errors.push(if strict {
2044                        ValidationError::error(
2045                            format!(
2046                                "Ambiguous unqualified column '{}' found in {} referenced tables",
2047                                col_name, source_count
2048                            ),
2049                            validation_codes::E_AMBIGUOUS_COLUMN_REFERENCE,
2050                        )
2051                    } else {
2052                        ValidationError::warning(
2053                            format!(
2054                                "Ambiguous unqualified column '{}' found in {} referenced tables",
2055                                col_name, source_count
2056                            ),
2057                            validation_codes::W_WEAK_REFERENCE_INTEGRITY,
2058                        )
2059                    });
2060                }
2061            }
2062        }
2063
2064        let mut cumulative_left_tables = select_from_table_keys(select, schema_map);
2065
2066        for join in &select.joins {
2067            let right_table_key = resolved_table_key_from_expr(&join.this, schema_map);
2068            let has_explicit_condition = join.on.is_some() || !join.using.is_empty();
2069            let cartesian_like_kind = matches!(
2070                join.kind,
2071                JoinKind::Cross
2072                    | JoinKind::Implicit
2073                    | JoinKind::Array
2074                    | JoinKind::LeftArray
2075                    | JoinKind::Paste
2076            );
2077
2078            if right_table_key.is_some()
2079                && (cartesian_like_kind
2080                    || (!has_explicit_condition && !is_natural_or_implied_join(join.kind)))
2081            {
2082                errors.push(ValidationError::warning(
2083                    "Potential cartesian join: JOIN without ON/USING condition",
2084                    validation_codes::W_CARTESIAN_JOIN,
2085                ));
2086            }
2087
2088            if let (Some(on_expr), Some(right_key)) = (&join.on, right_table_key.clone()) {
2089                if join.using.is_empty() {
2090                    let mut eq_pairs = Vec::new();
2091                    extract_join_equality_pairs(
2092                        on_expr,
2093                        schema_map,
2094                        &context,
2095                        &mut resolver,
2096                        &mut eq_pairs,
2097                    );
2098
2099                    let relevant_relationships: Vec<&DeclaredRelationship> = relationships
2100                        .iter()
2101                        .filter(|rel| {
2102                            cumulative_left_tables.contains(&rel.source_table)
2103                                && rel.target_table == right_key
2104                                || (cumulative_left_tables.contains(&rel.target_table)
2105                                    && rel.source_table == right_key)
2106                        })
2107                        .collect();
2108
2109                    if !relevant_relationships.is_empty() {
2110                        let uses_declared_fk = eq_pairs.iter().any(|((lt, lc), (rt, rc))| {
2111                            relevant_relationships
2112                                .iter()
2113                                .any(|rel| relationship_matches_pair(rel, lt, lc, rt, rc))
2114                        });
2115                        if !uses_declared_fk {
2116                            errors.push(ValidationError::warning(
2117                                "JOIN predicate does not use declared foreign-key relationship columns",
2118                                validation_codes::W_JOIN_NOT_USING_DECLARED_REFERENCE,
2119                            ));
2120                        }
2121                    }
2122                }
2123            }
2124
2125            if let Some(right_key) = right_table_key {
2126                cumulative_left_tables.insert(right_key);
2127            }
2128        }
2129    }
2130
2131    errors
2132}
2133
2134fn are_setop_compatible(left: TypeFamily, right: TypeFamily) -> bool {
2135    if left == TypeFamily::Unknown || right == TypeFamily::Unknown {
2136        return true;
2137    }
2138    if left == right {
2139        return true;
2140    }
2141    if left.is_numeric() && right.is_numeric() {
2142        return true;
2143    }
2144    if left.is_temporal() && right.is_temporal() {
2145        return true;
2146    }
2147    false
2148}
2149
2150fn merged_setop_family(left: TypeFamily, right: TypeFamily) -> TypeFamily {
2151    if left == TypeFamily::Unknown {
2152        return right;
2153    }
2154    if right == TypeFamily::Unknown {
2155        return left;
2156    }
2157    if left == right {
2158        return left;
2159    }
2160    if left.is_numeric() && right.is_numeric() {
2161        if left == TypeFamily::Numeric || right == TypeFamily::Numeric {
2162            return TypeFamily::Numeric;
2163        }
2164        return TypeFamily::Integer;
2165    }
2166    if left.is_temporal() && right.is_temporal() {
2167        if left == TypeFamily::Timestamp || right == TypeFamily::Timestamp {
2168            return TypeFamily::Timestamp;
2169        }
2170        if left == TypeFamily::Date || right == TypeFamily::Date {
2171            return TypeFamily::Date;
2172        }
2173        return TypeFamily::Time;
2174    }
2175    TypeFamily::Unknown
2176}
2177
2178fn are_assignment_compatible(target: TypeFamily, source: TypeFamily) -> bool {
2179    if target == TypeFamily::Unknown || source == TypeFamily::Unknown {
2180        return true;
2181    }
2182    if target == source {
2183        return true;
2184    }
2185
2186    match target {
2187        TypeFamily::Boolean => source == TypeFamily::Boolean,
2188        TypeFamily::Integer | TypeFamily::Numeric => source.is_numeric(),
2189        TypeFamily::Date | TypeFamily::Time | TypeFamily::Timestamp | TypeFamily::Interval => {
2190            source.is_temporal()
2191        }
2192        TypeFamily::String => true,
2193        TypeFamily::Binary => matches!(source, TypeFamily::Binary | TypeFamily::String),
2194        TypeFamily::Json => matches!(source, TypeFamily::Json | TypeFamily::String),
2195        TypeFamily::Uuid => matches!(source, TypeFamily::Uuid | TypeFamily::String),
2196        TypeFamily::Array => source == TypeFamily::Array,
2197        TypeFamily::Map => source == TypeFamily::Map,
2198        TypeFamily::Struct => source == TypeFamily::Struct,
2199        TypeFamily::Unknown => true,
2200    }
2201}
2202
2203fn projection_families(
2204    query_expr: &Expression,
2205    schema_map: &HashMap<String, TableSchemaEntry>,
2206) -> Option<Vec<TypeFamily>> {
2207    match query_expr {
2208        Expression::Select(select) => {
2209            if select
2210                .expressions
2211                .iter()
2212                .any(|e| matches!(e, Expression::Star(_) | Expression::BracedWildcard(_)))
2213            {
2214                return None;
2215            }
2216            let select_expr = Expression::Select(select.clone());
2217            let context = collect_type_check_context(&select_expr, schema_map);
2218            Some(
2219                select
2220                    .expressions
2221                    .iter()
2222                    .map(|e| infer_expression_type_family(e, schema_map, &context))
2223                    .collect(),
2224            )
2225        }
2226        Expression::Subquery(subquery) => projection_families(&subquery.this, schema_map),
2227        Expression::Union(union) => {
2228            let left = projection_families(&union.left, schema_map)?;
2229            let right = projection_families(&union.right, schema_map)?;
2230            if left.len() != right.len() {
2231                return None;
2232            }
2233            Some(
2234                left.into_iter()
2235                    .zip(right)
2236                    .map(|(l, r)| merged_setop_family(l, r))
2237                    .collect(),
2238            )
2239        }
2240        Expression::Intersect(intersect) => {
2241            let left = projection_families(&intersect.left, schema_map)?;
2242            let right = projection_families(&intersect.right, schema_map)?;
2243            if left.len() != right.len() {
2244                return None;
2245            }
2246            Some(
2247                left.into_iter()
2248                    .zip(right)
2249                    .map(|(l, r)| merged_setop_family(l, r))
2250                    .collect(),
2251            )
2252        }
2253        Expression::Except(except) => {
2254            let left = projection_families(&except.left, schema_map)?;
2255            let right = projection_families(&except.right, schema_map)?;
2256            if left.len() != right.len() {
2257                return None;
2258            }
2259            Some(
2260                left.into_iter()
2261                    .zip(right)
2262                    .map(|(l, r)| merged_setop_family(l, r))
2263                    .collect(),
2264            )
2265        }
2266        Expression::Values(values) => {
2267            let first_row = values.expressions.first()?;
2268            let context = TypeCheckContext::default();
2269            Some(
2270                first_row
2271                    .expressions
2272                    .iter()
2273                    .map(|e| infer_expression_type_family(e, schema_map, &context))
2274                    .collect(),
2275            )
2276        }
2277        _ => None,
2278    }
2279}
2280
2281fn check_set_operation_compatibility(
2282    op_name: &str,
2283    left_expr: &Expression,
2284    right_expr: &Expression,
2285    schema_map: &HashMap<String, TableSchemaEntry>,
2286    strict: bool,
2287    errors: &mut Vec<ValidationError>,
2288) {
2289    let Some(left_projection) = projection_families(left_expr, schema_map) else {
2290        return;
2291    };
2292    let Some(right_projection) = projection_families(right_expr, schema_map) else {
2293        return;
2294    };
2295
2296    if left_projection.len() != right_projection.len() {
2297        errors.push(type_issue(
2298            strict,
2299            validation_codes::E_SETOP_ARITY_MISMATCH,
2300            validation_codes::W_SETOP_IMPLICIT_COERCION,
2301            format!(
2302                "{} operands return different column counts: left {}, right {}",
2303                op_name,
2304                left_projection.len(),
2305                right_projection.len()
2306            ),
2307        ));
2308        return;
2309    }
2310
2311    for (idx, (left, right)) in left_projection
2312        .into_iter()
2313        .zip(right_projection)
2314        .enumerate()
2315    {
2316        if !are_setop_compatible(left, right) {
2317            errors.push(type_issue(
2318                strict,
2319                validation_codes::E_SETOP_TYPE_MISMATCH,
2320                validation_codes::W_SETOP_IMPLICIT_COERCION,
2321                format!(
2322                    "{} column {} has incompatible types: {} vs {}",
2323                    op_name,
2324                    idx + 1,
2325                    type_family_name(left),
2326                    type_family_name(right)
2327                ),
2328            ));
2329        }
2330    }
2331}
2332
2333fn check_insert_assignments(
2334    stmt: &Expression,
2335    insert: &Insert,
2336    schema_map: &HashMap<String, TableSchemaEntry>,
2337    strict: bool,
2338    errors: &mut Vec<ValidationError>,
2339) {
2340    let Some((target_table_key, table_schema)) =
2341        resolve_table_schema_entry(&insert.table, schema_map)
2342    else {
2343        return;
2344    };
2345
2346    let mut target_columns = Vec::new();
2347    if insert.columns.is_empty() {
2348        target_columns.extend(table_schema.column_order.iter().cloned());
2349    } else {
2350        for column in &insert.columns {
2351            let col_name = lower(&column.name);
2352            if table_schema.columns.contains_key(&col_name) {
2353                target_columns.push(col_name);
2354            } else {
2355                errors.push(if strict {
2356                    ValidationError::error(
2357                        format!(
2358                            "Unknown column '{}' in table '{}'",
2359                            column.name, target_table_key
2360                        ),
2361                        validation_codes::E_UNKNOWN_COLUMN,
2362                    )
2363                } else {
2364                    ValidationError::warning(
2365                        format!(
2366                            "Unknown column '{}' in table '{}'",
2367                            column.name, target_table_key
2368                        ),
2369                        validation_codes::E_UNKNOWN_COLUMN,
2370                    )
2371                });
2372            }
2373        }
2374    }
2375
2376    if target_columns.is_empty() {
2377        return;
2378    }
2379
2380    let context = collect_type_check_context(stmt, schema_map);
2381
2382    if !insert.default_values {
2383        for (row_idx, row) in insert.values.iter().enumerate() {
2384            if row.len() != target_columns.len() {
2385                errors.push(type_issue(
2386                    strict,
2387                    validation_codes::E_INVALID_ASSIGNMENT_TYPE,
2388                    validation_codes::W_IMPLICIT_CAST_ASSIGNMENT,
2389                    format!(
2390                        "INSERT row {} has {} values but target has {} columns",
2391                        row_idx + 1,
2392                        row.len(),
2393                        target_columns.len()
2394                    ),
2395                ));
2396                continue;
2397            }
2398
2399            for (value, target_column) in row.iter().zip(target_columns.iter()) {
2400                let Some(target_family) = table_schema.columns.get(target_column).copied() else {
2401                    continue;
2402                };
2403                let source_family = infer_expression_type_family(value, schema_map, &context);
2404                if !are_assignment_compatible(target_family, source_family) {
2405                    errors.push(type_issue(
2406                        strict,
2407                        validation_codes::E_INVALID_ASSIGNMENT_TYPE,
2408                        validation_codes::W_IMPLICIT_CAST_ASSIGNMENT,
2409                        format!(
2410                            "INSERT assignment type mismatch for '{}.{}': expected {}, found {}",
2411                            target_table_key,
2412                            target_column,
2413                            type_family_name(target_family),
2414                            type_family_name(source_family)
2415                        ),
2416                    ));
2417                }
2418            }
2419        }
2420    }
2421
2422    if let Some(query) = &insert.query {
2423        // DuckDB BY NAME maps source columns by name, not position.
2424        if insert.by_name {
2425            return;
2426        }
2427
2428        let Some(source_projection) = projection_families(query, schema_map) else {
2429            return;
2430        };
2431
2432        if source_projection.len() != target_columns.len() {
2433            errors.push(type_issue(
2434                strict,
2435                validation_codes::E_INVALID_ASSIGNMENT_TYPE,
2436                validation_codes::W_IMPLICIT_CAST_ASSIGNMENT,
2437                format!(
2438                    "INSERT source query has {} columns but target has {} columns",
2439                    source_projection.len(),
2440                    target_columns.len()
2441                ),
2442            ));
2443            return;
2444        }
2445
2446        for (source_family, target_column) in
2447            source_projection.into_iter().zip(target_columns.iter())
2448        {
2449            let Some(target_family) = table_schema.columns.get(target_column).copied() else {
2450                continue;
2451            };
2452            if !are_assignment_compatible(target_family, source_family) {
2453                errors.push(type_issue(
2454                    strict,
2455                    validation_codes::E_INVALID_ASSIGNMENT_TYPE,
2456                    validation_codes::W_IMPLICIT_CAST_ASSIGNMENT,
2457                    format!(
2458                        "INSERT assignment type mismatch for '{}.{}': expected {}, found {}",
2459                        target_table_key,
2460                        target_column,
2461                        type_family_name(target_family),
2462                        type_family_name(source_family)
2463                    ),
2464                ));
2465            }
2466        }
2467    }
2468}
2469
2470fn check_update_assignments(
2471    stmt: &Expression,
2472    update: &Update,
2473    schema_map: &HashMap<String, TableSchemaEntry>,
2474    strict: bool,
2475    errors: &mut Vec<ValidationError>,
2476) {
2477    let Some((target_table_key, table_schema)) =
2478        resolve_table_schema_entry(&update.table, schema_map)
2479    else {
2480        return;
2481    };
2482
2483    let context = collect_type_check_context(stmt, schema_map);
2484
2485    for (column, value) in &update.set {
2486        let col_name = lower(&column.name);
2487        let Some(target_family) = table_schema.columns.get(&col_name).copied() else {
2488            errors.push(if strict {
2489                ValidationError::error(
2490                    format!(
2491                        "Unknown column '{}' in table '{}'",
2492                        column.name, target_table_key
2493                    ),
2494                    validation_codes::E_UNKNOWN_COLUMN,
2495                )
2496            } else {
2497                ValidationError::warning(
2498                    format!(
2499                        "Unknown column '{}' in table '{}'",
2500                        column.name, target_table_key
2501                    ),
2502                    validation_codes::E_UNKNOWN_COLUMN,
2503                )
2504            });
2505            continue;
2506        };
2507
2508        let source_family = infer_expression_type_family(value, schema_map, &context);
2509        if !are_assignment_compatible(target_family, source_family) {
2510            errors.push(type_issue(
2511                strict,
2512                validation_codes::E_INVALID_ASSIGNMENT_TYPE,
2513                validation_codes::W_IMPLICIT_CAST_ASSIGNMENT,
2514                format!(
2515                    "UPDATE assignment type mismatch for '{}.{}': expected {}, found {}",
2516                    target_table_key,
2517                    col_name,
2518                    type_family_name(target_family),
2519                    type_family_name(source_family)
2520                ),
2521            ));
2522        }
2523    }
2524}
2525
2526fn check_types(
2527    stmt: &Expression,
2528    dialect: DialectType,
2529    schema_map: &HashMap<String, TableSchemaEntry>,
2530    function_catalog: Option<&dyn FunctionCatalog>,
2531    strict: bool,
2532) -> Vec<ValidationError> {
2533    let mut errors = Vec::new();
2534    let context = collect_type_check_context(stmt, schema_map);
2535
2536    for node in stmt.dfs() {
2537        match node {
2538            Expression::Insert(insert) => {
2539                check_insert_assignments(stmt, insert, schema_map, strict, &mut errors);
2540            }
2541            Expression::Update(update) => {
2542                check_update_assignments(stmt, update, schema_map, strict, &mut errors);
2543            }
2544            Expression::Union(union) => {
2545                check_set_operation_compatibility(
2546                    "UNION",
2547                    &union.left,
2548                    &union.right,
2549                    schema_map,
2550                    strict,
2551                    &mut errors,
2552                );
2553            }
2554            Expression::Intersect(intersect) => {
2555                check_set_operation_compatibility(
2556                    "INTERSECT",
2557                    &intersect.left,
2558                    &intersect.right,
2559                    schema_map,
2560                    strict,
2561                    &mut errors,
2562                );
2563            }
2564            Expression::Except(except) => {
2565                check_set_operation_compatibility(
2566                    "EXCEPT",
2567                    &except.left,
2568                    &except.right,
2569                    schema_map,
2570                    strict,
2571                    &mut errors,
2572                );
2573            }
2574            Expression::Select(select) => {
2575                if let Some(prewhere) = &select.prewhere {
2576                    let family = infer_expression_type_family(prewhere, schema_map, &context);
2577                    if family != TypeFamily::Unknown && family != TypeFamily::Boolean {
2578                        errors.push(type_issue(
2579                            strict,
2580                            validation_codes::E_INVALID_PREDICATE_TYPE,
2581                            validation_codes::W_PREDICATE_NULLABILITY,
2582                            format!(
2583                                "PREWHERE clause expects a boolean predicate, found {}",
2584                                type_family_name(family)
2585                            ),
2586                        ));
2587                    }
2588                }
2589
2590                if let Some(where_clause) = &select.where_clause {
2591                    let family =
2592                        infer_expression_type_family(&where_clause.this, schema_map, &context);
2593                    if family != TypeFamily::Unknown && family != TypeFamily::Boolean {
2594                        errors.push(type_issue(
2595                            strict,
2596                            validation_codes::E_INVALID_PREDICATE_TYPE,
2597                            validation_codes::W_PREDICATE_NULLABILITY,
2598                            format!(
2599                                "WHERE clause expects a boolean predicate, found {}",
2600                                type_family_name(family)
2601                            ),
2602                        ));
2603                    }
2604                }
2605
2606                if let Some(having_clause) = &select.having {
2607                    let family =
2608                        infer_expression_type_family(&having_clause.this, schema_map, &context);
2609                    if family != TypeFamily::Unknown && family != TypeFamily::Boolean {
2610                        errors.push(type_issue(
2611                            strict,
2612                            validation_codes::E_INVALID_PREDICATE_TYPE,
2613                            validation_codes::W_PREDICATE_NULLABILITY,
2614                            format!(
2615                                "HAVING clause expects a boolean predicate, found {}",
2616                                type_family_name(family)
2617                            ),
2618                        ));
2619                    }
2620                }
2621
2622                for join in &select.joins {
2623                    if let Some(on) = &join.on {
2624                        let family = infer_expression_type_family(on, schema_map, &context);
2625                        if family != TypeFamily::Unknown && family != TypeFamily::Boolean {
2626                            errors.push(type_issue(
2627                                strict,
2628                                validation_codes::E_INVALID_PREDICATE_TYPE,
2629                                validation_codes::W_PREDICATE_NULLABILITY,
2630                                format!(
2631                                    "JOIN ON expects a boolean predicate, found {}",
2632                                    type_family_name(family)
2633                                ),
2634                            ));
2635                        }
2636                    }
2637                    if let Some(match_condition) = &join.match_condition {
2638                        let family =
2639                            infer_expression_type_family(match_condition, schema_map, &context);
2640                        if family != TypeFamily::Unknown && family != TypeFamily::Boolean {
2641                            errors.push(type_issue(
2642                                strict,
2643                                validation_codes::E_INVALID_PREDICATE_TYPE,
2644                                validation_codes::W_PREDICATE_NULLABILITY,
2645                                format!(
2646                                    "JOIN MATCH_CONDITION expects a boolean predicate, found {}",
2647                                    type_family_name(family)
2648                                ),
2649                            ));
2650                        }
2651                    }
2652                }
2653            }
2654            Expression::Where(where_clause) => {
2655                let family = infer_expression_type_family(&where_clause.this, schema_map, &context);
2656                if family != TypeFamily::Unknown && family != TypeFamily::Boolean {
2657                    errors.push(type_issue(
2658                        strict,
2659                        validation_codes::E_INVALID_PREDICATE_TYPE,
2660                        validation_codes::W_PREDICATE_NULLABILITY,
2661                        format!(
2662                            "WHERE clause expects a boolean predicate, found {}",
2663                            type_family_name(family)
2664                        ),
2665                    ));
2666                }
2667            }
2668            Expression::Having(having_clause) => {
2669                let family =
2670                    infer_expression_type_family(&having_clause.this, schema_map, &context);
2671                if family != TypeFamily::Unknown && family != TypeFamily::Boolean {
2672                    errors.push(type_issue(
2673                        strict,
2674                        validation_codes::E_INVALID_PREDICATE_TYPE,
2675                        validation_codes::W_PREDICATE_NULLABILITY,
2676                        format!(
2677                            "HAVING clause expects a boolean predicate, found {}",
2678                            type_family_name(family)
2679                        ),
2680                    ));
2681                }
2682            }
2683            Expression::And(op) | Expression::Or(op) => {
2684                for (side, expr) in [("left", &op.left), ("right", &op.right)] {
2685                    let family = infer_expression_type_family(expr, schema_map, &context);
2686                    if family != TypeFamily::Unknown && family != TypeFamily::Boolean {
2687                        errors.push(type_issue(
2688                            strict,
2689                            validation_codes::E_INVALID_PREDICATE_TYPE,
2690                            validation_codes::W_PREDICATE_NULLABILITY,
2691                            format!(
2692                                "Logical {} operand expects boolean, found {}",
2693                                side,
2694                                type_family_name(family)
2695                            ),
2696                        ));
2697                    }
2698                }
2699            }
2700            Expression::Not(unary) => {
2701                let family = infer_expression_type_family(&unary.this, schema_map, &context);
2702                if family != TypeFamily::Unknown && family != TypeFamily::Boolean {
2703                    errors.push(type_issue(
2704                        strict,
2705                        validation_codes::E_INVALID_PREDICATE_TYPE,
2706                        validation_codes::W_PREDICATE_NULLABILITY,
2707                        format!("NOT expects boolean, found {}", type_family_name(family)),
2708                    ));
2709                }
2710            }
2711            Expression::Eq(op)
2712            | Expression::Neq(op)
2713            | Expression::Lt(op)
2714            | Expression::Lte(op)
2715            | Expression::Gt(op)
2716            | Expression::Gte(op) => {
2717                let left = infer_expression_type_family(&op.left, schema_map, &context);
2718                let right = infer_expression_type_family(&op.right, schema_map, &context);
2719                if !are_comparable(left, right) {
2720                    errors.push(type_issue(
2721                        strict,
2722                        validation_codes::E_INCOMPATIBLE_COMPARISON_TYPES,
2723                        validation_codes::W_IMPLICIT_CAST_COMPARISON,
2724                        format!(
2725                            "Incompatible comparison between {} and {}",
2726                            type_family_name(left),
2727                            type_family_name(right)
2728                        ),
2729                    ));
2730                }
2731            }
2732            Expression::Like(op) | Expression::ILike(op) => {
2733                let left = infer_expression_type_family(&op.left, schema_map, &context);
2734                let right = infer_expression_type_family(&op.right, schema_map, &context);
2735                if left != TypeFamily::Unknown
2736                    && right != TypeFamily::Unknown
2737                    && (!is_string_like(left) || !is_string_like(right))
2738                {
2739                    errors.push(type_issue(
2740                        strict,
2741                        validation_codes::E_INCOMPATIBLE_COMPARISON_TYPES,
2742                        validation_codes::W_IMPLICIT_CAST_COMPARISON,
2743                        format!(
2744                            "LIKE/ILIKE expects string operands, found {} and {}",
2745                            type_family_name(left),
2746                            type_family_name(right)
2747                        ),
2748                    ));
2749                }
2750            }
2751            Expression::Between(between) => {
2752                let this_family = infer_expression_type_family(&between.this, schema_map, &context);
2753                let low_family = infer_expression_type_family(&between.low, schema_map, &context);
2754                let high_family = infer_expression_type_family(&between.high, schema_map, &context);
2755
2756                if !are_comparable(this_family, low_family)
2757                    || !are_comparable(this_family, high_family)
2758                {
2759                    errors.push(type_issue(
2760                        strict,
2761                        validation_codes::E_INCOMPATIBLE_COMPARISON_TYPES,
2762                        validation_codes::W_IMPLICIT_CAST_COMPARISON,
2763                        format!(
2764                            "BETWEEN bounds are incompatible with {} (found {} and {})",
2765                            type_family_name(this_family),
2766                            type_family_name(low_family),
2767                            type_family_name(high_family)
2768                        ),
2769                    ));
2770                }
2771            }
2772            Expression::In(in_expr) => {
2773                let this_family = infer_expression_type_family(&in_expr.this, schema_map, &context);
2774                for value in &in_expr.expressions {
2775                    let item_family = infer_expression_type_family(value, schema_map, &context);
2776                    if !are_comparable(this_family, item_family) {
2777                        errors.push(type_issue(
2778                            strict,
2779                            validation_codes::E_INCOMPATIBLE_COMPARISON_TYPES,
2780                            validation_codes::W_IMPLICIT_CAST_COMPARISON,
2781                            format!(
2782                                "IN item type {} is incompatible with {}",
2783                                type_family_name(item_family),
2784                                type_family_name(this_family)
2785                            ),
2786                        ));
2787                        break;
2788                    }
2789                }
2790            }
2791            Expression::Add(op)
2792            | Expression::Sub(op)
2793            | Expression::Mul(op)
2794            | Expression::Div(op)
2795            | Expression::Mod(op) => {
2796                let left = infer_expression_type_family(&op.left, schema_map, &context);
2797                let right = infer_expression_type_family(&op.right, schema_map, &context);
2798
2799                if left == TypeFamily::Unknown || right == TypeFamily::Unknown {
2800                    continue;
2801                }
2802
2803                let temporal_ok = matches!(node, Expression::Add(_) | Expression::Sub(_))
2804                    && ((left.is_temporal() && right.is_numeric())
2805                        || (right.is_temporal() && left.is_numeric())
2806                        || (matches!(node, Expression::Sub(_))
2807                            && left.is_temporal()
2808                            && right.is_temporal()));
2809
2810                if !(left.is_numeric() && right.is_numeric()) && !temporal_ok {
2811                    errors.push(type_issue(
2812                        strict,
2813                        validation_codes::E_INVALID_ARITHMETIC_TYPE,
2814                        validation_codes::W_IMPLICIT_CAST_ARITHMETIC,
2815                        format!(
2816                            "Arithmetic operation expects numeric-compatible operands, found {} and {}",
2817                            type_family_name(left),
2818                            type_family_name(right)
2819                        ),
2820                    ));
2821                }
2822            }
2823            Expression::Function(function) => {
2824                check_function_catalog(function, dialect, function_catalog, strict, &mut errors);
2825                check_generic_function(function, schema_map, &context, strict, &mut errors);
2826            }
2827            Expression::Upper(func)
2828            | Expression::Lower(func)
2829            | Expression::LTrim(func)
2830            | Expression::RTrim(func)
2831            | Expression::Reverse(func) => {
2832                let family = infer_expression_type_family(&func.this, schema_map, &context);
2833                check_function_argument(
2834                    &mut errors,
2835                    strict,
2836                    "string_function",
2837                    0,
2838                    family,
2839                    "a string argument",
2840                    is_string_like(family),
2841                );
2842            }
2843            Expression::Length(func) => {
2844                let family = infer_expression_type_family(&func.this, schema_map, &context);
2845                check_function_argument(
2846                    &mut errors,
2847                    strict,
2848                    "length",
2849                    0,
2850                    family,
2851                    "a string or binary argument",
2852                    is_string_or_binary(family),
2853                );
2854            }
2855            Expression::Trim(func) => {
2856                let this_family = infer_expression_type_family(&func.this, schema_map, &context);
2857                check_function_argument(
2858                    &mut errors,
2859                    strict,
2860                    "trim",
2861                    0,
2862                    this_family,
2863                    "a string argument",
2864                    is_string_like(this_family),
2865                );
2866                if let Some(chars) = &func.characters {
2867                    let chars_family = infer_expression_type_family(chars, schema_map, &context);
2868                    check_function_argument(
2869                        &mut errors,
2870                        strict,
2871                        "trim",
2872                        1,
2873                        chars_family,
2874                        "a string argument",
2875                        is_string_like(chars_family),
2876                    );
2877                }
2878            }
2879            Expression::Substring(func) => {
2880                let this_family = infer_expression_type_family(&func.this, schema_map, &context);
2881                check_function_argument(
2882                    &mut errors,
2883                    strict,
2884                    "substring",
2885                    0,
2886                    this_family,
2887                    "a string argument",
2888                    is_string_like(this_family),
2889                );
2890
2891                let start_family = infer_expression_type_family(&func.start, schema_map, &context);
2892                check_function_argument(
2893                    &mut errors,
2894                    strict,
2895                    "substring",
2896                    1,
2897                    start_family,
2898                    "a numeric argument",
2899                    start_family.is_numeric(),
2900                );
2901                if let Some(length) = &func.length {
2902                    let length_family = infer_expression_type_family(length, schema_map, &context);
2903                    check_function_argument(
2904                        &mut errors,
2905                        strict,
2906                        "substring",
2907                        2,
2908                        length_family,
2909                        "a numeric argument",
2910                        length_family.is_numeric(),
2911                    );
2912                }
2913            }
2914            Expression::Replace(func) => {
2915                for (arg, idx) in [
2916                    (&func.this, 0_usize),
2917                    (&func.old, 1_usize),
2918                    (&func.new, 2_usize),
2919                ] {
2920                    let family = infer_expression_type_family(arg, schema_map, &context);
2921                    check_function_argument(
2922                        &mut errors,
2923                        strict,
2924                        "replace",
2925                        idx,
2926                        family,
2927                        "a string argument",
2928                        is_string_like(family),
2929                    );
2930                }
2931            }
2932            Expression::Left(func) | Expression::Right(func) => {
2933                let this_family = infer_expression_type_family(&func.this, schema_map, &context);
2934                check_function_argument(
2935                    &mut errors,
2936                    strict,
2937                    "left_right",
2938                    0,
2939                    this_family,
2940                    "a string argument",
2941                    is_string_like(this_family),
2942                );
2943                let length_family =
2944                    infer_expression_type_family(&func.length, schema_map, &context);
2945                check_function_argument(
2946                    &mut errors,
2947                    strict,
2948                    "left_right",
2949                    1,
2950                    length_family,
2951                    "a numeric argument",
2952                    length_family.is_numeric(),
2953                );
2954            }
2955            Expression::Repeat(func) => {
2956                let this_family = infer_expression_type_family(&func.this, schema_map, &context);
2957                check_function_argument(
2958                    &mut errors,
2959                    strict,
2960                    "repeat",
2961                    0,
2962                    this_family,
2963                    "a string argument",
2964                    is_string_like(this_family),
2965                );
2966                let times_family = infer_expression_type_family(&func.times, schema_map, &context);
2967                check_function_argument(
2968                    &mut errors,
2969                    strict,
2970                    "repeat",
2971                    1,
2972                    times_family,
2973                    "a numeric argument",
2974                    times_family.is_numeric(),
2975                );
2976            }
2977            Expression::Lpad(func) | Expression::Rpad(func) => {
2978                let this_family = infer_expression_type_family(&func.this, schema_map, &context);
2979                check_function_argument(
2980                    &mut errors,
2981                    strict,
2982                    "pad",
2983                    0,
2984                    this_family,
2985                    "a string argument",
2986                    is_string_like(this_family),
2987                );
2988                let length_family =
2989                    infer_expression_type_family(&func.length, schema_map, &context);
2990                check_function_argument(
2991                    &mut errors,
2992                    strict,
2993                    "pad",
2994                    1,
2995                    length_family,
2996                    "a numeric argument",
2997                    length_family.is_numeric(),
2998                );
2999                if let Some(fill) = &func.fill {
3000                    let fill_family = infer_expression_type_family(fill, schema_map, &context);
3001                    check_function_argument(
3002                        &mut errors,
3003                        strict,
3004                        "pad",
3005                        2,
3006                        fill_family,
3007                        "a string argument",
3008                        is_string_like(fill_family),
3009                    );
3010                }
3011            }
3012            Expression::Abs(func)
3013            | Expression::Sqrt(func)
3014            | Expression::Cbrt(func)
3015            | Expression::Ln(func)
3016            | Expression::Exp(func) => {
3017                let family = infer_expression_type_family(&func.this, schema_map, &context);
3018                check_function_argument(
3019                    &mut errors,
3020                    strict,
3021                    "numeric_function",
3022                    0,
3023                    family,
3024                    "a numeric argument",
3025                    family.is_numeric(),
3026                );
3027            }
3028            Expression::Round(func) => {
3029                let this_family = infer_expression_type_family(&func.this, schema_map, &context);
3030                check_function_argument(
3031                    &mut errors,
3032                    strict,
3033                    "round",
3034                    0,
3035                    this_family,
3036                    "a numeric argument",
3037                    this_family.is_numeric(),
3038                );
3039                if let Some(decimals) = &func.decimals {
3040                    let decimals_family =
3041                        infer_expression_type_family(decimals, schema_map, &context);
3042                    check_function_argument(
3043                        &mut errors,
3044                        strict,
3045                        "round",
3046                        1,
3047                        decimals_family,
3048                        "a numeric argument",
3049                        decimals_family.is_numeric(),
3050                    );
3051                }
3052            }
3053            Expression::Floor(func) => {
3054                let this_family = infer_expression_type_family(&func.this, schema_map, &context);
3055                check_function_argument(
3056                    &mut errors,
3057                    strict,
3058                    "floor",
3059                    0,
3060                    this_family,
3061                    "a numeric argument",
3062                    this_family.is_numeric(),
3063                );
3064                if let Some(scale) = &func.scale {
3065                    let scale_family = infer_expression_type_family(scale, schema_map, &context);
3066                    check_function_argument(
3067                        &mut errors,
3068                        strict,
3069                        "floor",
3070                        1,
3071                        scale_family,
3072                        "a numeric argument",
3073                        scale_family.is_numeric(),
3074                    );
3075                }
3076            }
3077            Expression::Ceil(func) => {
3078                let this_family = infer_expression_type_family(&func.this, schema_map, &context);
3079                check_function_argument(
3080                    &mut errors,
3081                    strict,
3082                    "ceil",
3083                    0,
3084                    this_family,
3085                    "a numeric argument",
3086                    this_family.is_numeric(),
3087                );
3088                if let Some(decimals) = &func.decimals {
3089                    let decimals_family =
3090                        infer_expression_type_family(decimals, schema_map, &context);
3091                    check_function_argument(
3092                        &mut errors,
3093                        strict,
3094                        "ceil",
3095                        1,
3096                        decimals_family,
3097                        "a numeric argument",
3098                        decimals_family.is_numeric(),
3099                    );
3100                }
3101            }
3102            Expression::Power(func) => {
3103                let left_family = infer_expression_type_family(&func.this, schema_map, &context);
3104                check_function_argument(
3105                    &mut errors,
3106                    strict,
3107                    "power",
3108                    0,
3109                    left_family,
3110                    "a numeric argument",
3111                    left_family.is_numeric(),
3112                );
3113                let right_family =
3114                    infer_expression_type_family(&func.expression, schema_map, &context);
3115                check_function_argument(
3116                    &mut errors,
3117                    strict,
3118                    "power",
3119                    1,
3120                    right_family,
3121                    "a numeric argument",
3122                    right_family.is_numeric(),
3123                );
3124            }
3125            Expression::Log(func) => {
3126                let this_family = infer_expression_type_family(&func.this, schema_map, &context);
3127                check_function_argument(
3128                    &mut errors,
3129                    strict,
3130                    "log",
3131                    0,
3132                    this_family,
3133                    "a numeric argument",
3134                    this_family.is_numeric(),
3135                );
3136                if let Some(base) = &func.base {
3137                    let base_family = infer_expression_type_family(base, schema_map, &context);
3138                    check_function_argument(
3139                        &mut errors,
3140                        strict,
3141                        "log",
3142                        1,
3143                        base_family,
3144                        "a numeric argument",
3145                        base_family.is_numeric(),
3146                    );
3147                }
3148            }
3149            _ => {}
3150        }
3151    }
3152
3153    errors
3154}
3155
3156fn check_semantics(stmt: &Expression) -> Vec<ValidationError> {
3157    let mut errors = Vec::new();
3158
3159    let Expression::Select(select) = stmt else {
3160        return errors;
3161    };
3162    let select_expr = Expression::Select(select.clone());
3163
3164    // W001: SELECT * is discouraged
3165    if !select_expr
3166        .find_all(|e| matches!(e, Expression::Star(_)))
3167        .is_empty()
3168    {
3169        errors.push(ValidationError::warning(
3170            "SELECT * is discouraged; specify columns explicitly for better performance and maintainability",
3171            validation_codes::W_SELECT_STAR,
3172        ));
3173    }
3174
3175    // W002: aggregate + non-aggregate columns without GROUP BY
3176    let aggregate_count = get_aggregate_functions(&select_expr).len();
3177    if aggregate_count > 0 && select.group_by.is_none() {
3178        let has_non_aggregate_column = select.expressions.iter().any(|expr| {
3179            matches!(expr, Expression::Column(_) | Expression::Identifier(_))
3180                && get_aggregate_functions(expr).is_empty()
3181        });
3182
3183        if has_non_aggregate_column {
3184            errors.push(ValidationError::warning(
3185                "Mixing aggregate functions with non-aggregated columns without GROUP BY may cause errors in strict SQL mode",
3186                validation_codes::W_AGGREGATE_WITHOUT_GROUP_BY,
3187            ));
3188        }
3189    }
3190
3191    // W003: DISTINCT with ORDER BY
3192    if select.distinct && select.order_by.is_some() {
3193        errors.push(ValidationError::warning(
3194            "DISTINCT with ORDER BY: ensure ORDER BY columns are in SELECT list",
3195            validation_codes::W_DISTINCT_ORDER_BY,
3196        ));
3197    }
3198
3199    // W004: LIMIT without ORDER BY
3200    if select.limit.is_some() && select.order_by.is_none() {
3201        errors.push(ValidationError::warning(
3202            "LIMIT without ORDER BY produces non-deterministic results",
3203            validation_codes::W_LIMIT_WITHOUT_ORDER_BY,
3204        ));
3205    }
3206
3207    errors
3208}
3209
3210fn resolve_scope_source_name(scope: &crate::scope::Scope, name: &str) -> Option<String> {
3211    scope
3212        .sources
3213        .get_key_value(name)
3214        .map(|(k, _)| k.clone())
3215        .or_else(|| {
3216            scope
3217                .sources
3218                .keys()
3219                .find(|source| source.eq_ignore_ascii_case(name))
3220                .cloned()
3221        })
3222}
3223
/// Reports whether `column_name` is covered by the given column list.
///
/// A literal `"*"` entry matches every column name; all other entries are
/// compared against `column_name` ignoring ASCII case.
fn source_has_column(columns: &[String], column_name: &str) -> bool {
    for candidate in columns {
        let is_wildcard = candidate == "*";
        if is_wildcard || candidate.eq_ignore_ascii_case(column_name) {
            return true;
        }
    }
    false
}
3229
3230fn source_display_name(scope: &crate::scope::Scope, source_name: &str) -> String {
3231    scope
3232        .sources
3233        .get(source_name)
3234        .map(|source| match &source.expression {
3235            Expression::Table(table) => lower(&table_ref_display_name(table)),
3236            _ => lower(source_name),
3237        })
3238        .unwrap_or_else(|| lower(source_name))
3239}
3240
3241fn validate_select_columns_with_schema(
3242    select: &crate::expressions::Select,
3243    schema_map: &HashMap<String, TableSchemaEntry>,
3244    resolver_schema: &MappingSchema,
3245    strict: bool,
3246) -> Vec<ValidationError> {
3247    let mut errors = Vec::new();
3248    let select_expr = Expression::Select(Box::new(select.clone()));
3249    let scope = build_scope(&select_expr);
3250    let mut resolver = Resolver::new(&scope, resolver_schema, true);
3251    let source_names: Vec<String> = scope.sources.keys().cloned().collect();
3252
3253    for node in walk_in_scope(&select_expr, false) {
3254        let Expression::Column(column) = node else {
3255            continue;
3256        };
3257
3258        let col_name = lower(&column.name.name);
3259        if col_name.is_empty() {
3260            continue;
3261        }
3262
3263        if let Some(table) = &column.table {
3264            let Some(source_name) = resolve_scope_source_name(&scope, &table.name) else {
3265                // The table qualifier is not a declared alias or source in this scope
3266                errors.push(if strict {
3267                    ValidationError::error(
3268                        format!(
3269                            "Unknown table or alias '{}' referenced by column '{}'",
3270                            table.name, col_name
3271                        ),
3272                        validation_codes::E_UNRESOLVED_REFERENCE,
3273                    )
3274                } else {
3275                    ValidationError::warning(
3276                        format!(
3277                            "Unknown table or alias '{}' referenced by column '{}'",
3278                            table.name, col_name
3279                        ),
3280                        validation_codes::E_UNRESOLVED_REFERENCE,
3281                    )
3282                });
3283                continue;
3284            };
3285
3286            if let Ok(columns) = resolver.get_source_columns(&source_name) {
3287                if !columns.is_empty() && !source_has_column(&columns, &col_name) {
3288                    let table_name = source_display_name(&scope, &source_name);
3289                    errors.push(if strict {
3290                        ValidationError::error(
3291                            format!("Unknown column '{}' in table '{}'", col_name, table_name),
3292                            validation_codes::E_UNKNOWN_COLUMN,
3293                        )
3294                    } else {
3295                        ValidationError::warning(
3296                            format!("Unknown column '{}' in table '{}'", col_name, table_name),
3297                            validation_codes::E_UNKNOWN_COLUMN,
3298                        )
3299                    });
3300                }
3301            }
3302            continue;
3303        }
3304
3305        let matching_sources: Vec<String> = source_names
3306            .iter()
3307            .filter_map(|source_name| {
3308                resolver
3309                    .get_source_columns(source_name)
3310                    .ok()
3311                    .filter(|columns| !columns.is_empty() && source_has_column(columns, &col_name))
3312                    .map(|_| source_name.clone())
3313            })
3314            .collect();
3315
3316        if !matching_sources.is_empty() {
3317            continue;
3318        }
3319
3320        let known_sources: Vec<String> = source_names
3321            .iter()
3322            .filter_map(|source_name| {
3323                resolver
3324                    .get_source_columns(source_name)
3325                    .ok()
3326                    .filter(|columns| !columns.is_empty() && !columns.iter().any(|c| c == "*"))
3327                    .map(|_| source_name.clone())
3328            })
3329            .collect();
3330
3331        if known_sources.len() == 1 {
3332            let table_name = source_display_name(&scope, &known_sources[0]);
3333            errors.push(if strict {
3334                ValidationError::error(
3335                    format!("Unknown column '{}' in table '{}'", col_name, table_name),
3336                    validation_codes::E_UNKNOWN_COLUMN,
3337                )
3338            } else {
3339                ValidationError::warning(
3340                    format!("Unknown column '{}' in table '{}'", col_name, table_name),
3341                    validation_codes::E_UNKNOWN_COLUMN,
3342                )
3343            });
3344        } else if known_sources.len() > 1 {
3345            errors.push(if strict {
3346                ValidationError::error(
3347                    format!(
3348                        "Unknown column '{}' (not found in any referenced table)",
3349                        col_name
3350                    ),
3351                    validation_codes::E_UNKNOWN_COLUMN,
3352                )
3353            } else {
3354                ValidationError::warning(
3355                    format!(
3356                        "Unknown column '{}' (not found in any referenced table)",
3357                        col_name
3358                    ),
3359                    validation_codes::E_UNKNOWN_COLUMN,
3360                )
3361            });
3362        } else if !schema_map.is_empty() {
3363            let found = schema_map
3364                .values()
3365                .any(|table_schema| table_schema.columns.contains_key(&col_name));
3366            if !found {
3367                errors.push(if strict {
3368                    ValidationError::error(
3369                        format!("Unknown column '{}'", col_name),
3370                        validation_codes::E_UNKNOWN_COLUMN,
3371                    )
3372                } else {
3373                    ValidationError::warning(
3374                        format!("Unknown column '{}'", col_name),
3375                        validation_codes::E_UNKNOWN_COLUMN,
3376                    )
3377                });
3378            }
3379        }
3380    }
3381
3382    errors
3383}
3384
3385fn validate_statement_with_schema(
3386    stmt: &Expression,
3387    schema_map: &HashMap<String, TableSchemaEntry>,
3388    resolver_schema: &MappingSchema,
3389    strict: bool,
3390) -> Vec<ValidationError> {
3391    let mut errors = Vec::new();
3392    let cte_aliases = collect_cte_aliases(stmt);
3393    let mut seen_tables: HashSet<String> = HashSet::new();
3394
3395    // Table validation (E200)
3396    for node in stmt.find_all(|e| matches!(e, Expression::Table(_))) {
3397        let Expression::Table(table) = node else {
3398            continue;
3399        };
3400
3401        if cte_aliases.contains(&lower(&table.name.name)) {
3402            continue;
3403        }
3404
3405        let resolved_key = table_ref_candidates(table)
3406            .into_iter()
3407            .find(|k| schema_map.contains_key(k));
3408        let table_key = resolved_key
3409            .clone()
3410            .unwrap_or_else(|| lower(&table_ref_display_name(table)));
3411
3412        if !seen_tables.insert(table_key) {
3413            continue;
3414        }
3415
3416        if resolved_key.is_none() {
3417            errors.push(if strict {
3418                ValidationError::error(
3419                    format!("Unknown table '{}'", table_ref_display_name(table)),
3420                    validation_codes::E_UNKNOWN_TABLE,
3421                )
3422            } else {
3423                ValidationError::warning(
3424                    format!("Unknown table '{}'", table_ref_display_name(table)),
3425                    validation_codes::E_UNKNOWN_TABLE,
3426                )
3427            });
3428        }
3429    }
3430
3431    for node in stmt.dfs() {
3432        let Expression::Select(select) = node else {
3433            continue;
3434        };
3435        errors.extend(validate_select_columns_with_schema(
3436            select,
3437            schema_map,
3438            resolver_schema,
3439            strict,
3440        ));
3441    }
3442
3443    errors
3444}
3445
3446/// Validate SQL using syntax + schema-aware checks, with optional semantic warnings.
3447pub fn validate_with_schema(
3448    sql: &str,
3449    dialect: DialectType,
3450    schema: &ValidationSchema,
3451    options: &SchemaValidationOptions,
3452) -> ValidationResult {
3453    let strict = options.strict.unwrap_or(schema.strict.unwrap_or(true));
3454
3455    // Syntax validation first.
3456    let syntax_result = crate::validate_with_options(
3457        sql,
3458        dialect,
3459        &crate::ValidationOptions {
3460            strict_syntax: options.strict_syntax,
3461        },
3462    );
3463    if !syntax_result.valid {
3464        return syntax_result;
3465    }
3466
3467    let d = Dialect::get(dialect);
3468    let statements = match d.parse(sql) {
3469        Ok(exprs) => exprs,
3470        Err(e) => {
3471            return ValidationResult::with_errors(vec![ValidationError::error(
3472                e.to_string(),
3473                validation_codes::E_PARSE_OR_OPTIONS,
3474            )]);
3475        }
3476    };
3477
3478    let schema_map = build_schema_map(schema);
3479    let resolver_schema = build_resolver_schema(schema);
3480    let mut all_errors = syntax_result.errors;
3481    let embedded_function_catalog = if options.check_types && options.function_catalog.is_none() {
3482        default_embedded_function_catalog()
3483    } else {
3484        None
3485    };
3486    let effective_function_catalog = options
3487        .function_catalog
3488        .as_deref()
3489        .or_else(|| embedded_function_catalog.as_deref());
3490    let declared_relationships = if options.check_references {
3491        build_declared_relationships(schema, &schema_map)
3492    } else {
3493        Vec::new()
3494    };
3495
3496    if options.check_references {
3497        all_errors.extend(check_reference_integrity(schema, &schema_map, strict));
3498    }
3499
3500    for stmt in &statements {
3501        if options.semantic {
3502            all_errors.extend(check_semantics(stmt));
3503        }
3504        all_errors.extend(validate_statement_with_schema(
3505            stmt,
3506            &schema_map,
3507            &resolver_schema,
3508            strict,
3509        ));
3510        if options.check_types {
3511            all_errors.extend(check_types(
3512                stmt,
3513                dialect,
3514                &schema_map,
3515                effective_function_catalog,
3516                strict,
3517            ));
3518        }
3519        if options.check_references {
3520            all_errors.extend(check_query_reference_quality(
3521                stmt,
3522                &schema_map,
3523                &resolver_schema,
3524                strict,
3525                &declared_relationships,
3526            ));
3527        }
3528    }
3529
3530    ValidationResult::with_errors(all_errors)
3531}
3532
3533#[cfg(test)]
3534mod tests;