Skip to main content

polyglot_sql/
validation.rs

1//! Schema-aware and semantic SQL validation.
2//!
3//! This module extends syntax validation with:
4//! - schema checks (unknown tables/columns)
5//! - optional semantic warnings (SELECT *, LIMIT without ORDER BY, etc.)
6
7use crate::ast_transforms::get_aggregate_functions;
8use crate::dialects::{Dialect, DialectType};
9use crate::error::{ValidationError, ValidationResult};
10use crate::expressions::{
11    Column, DataType, Expression, Function, Insert, JoinKind, TableRef, Update,
12};
13use crate::function_catalog::FunctionCatalog;
14#[cfg(any(
15    feature = "function-catalog-clickhouse",
16    feature = "function-catalog-duckdb",
17    feature = "function-catalog-all-dialects"
18))]
19use crate::function_catalog::{
20    FunctionNameCase as CoreFunctionNameCase, FunctionSignature as CoreFunctionSignature,
21    HashMapFunctionCatalog,
22};
23use crate::function_registry::canonical_typed_function_name_upper;
24use crate::optimizer::annotate_types::annotate_types;
25use crate::resolver::Resolver;
26use crate::schema::{MappingSchema, Schema as SqlSchema, SchemaError, SchemaResult, TABLE_PARTS};
27use crate::scope::{build_scope, walk_in_scope};
28use crate::traversal::ExpressionWalk;
29use serde::{Deserialize, Serialize};
30use std::collections::{HashMap, HashSet};
31use std::sync::Arc;
32
33#[cfg(any(
34    feature = "function-catalog-clickhouse",
35    feature = "function-catalog-duckdb",
36    feature = "function-catalog-all-dialects"
37))]
38use std::sync::LazyLock;
39
/// Column definition used for schema-aware validation.
///
/// In the serialized payload `data_type` is spelled `type` and `primary_key`
/// is spelled `primaryKey`; every field except `name` may be omitted.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SchemaColumn {
    /// Column name.
    pub name: String,
    /// Column data type as written in the schema payload (empty when omitted).
    /// It is canonicalized through `canonical_type_family` when the validation
    /// lookup structures are built.
    #[serde(default, rename = "type")]
    pub data_type: String,
    /// Whether the column allows NULL values.
    #[serde(default)]
    pub nullable: Option<bool>,
    /// Whether this column is part of a primary key.
    #[serde(default, rename = "primaryKey")]
    pub primary_key: bool,
    /// Whether this column has a uniqueness constraint.
    #[serde(default)]
    pub unique: bool,
    /// Optional column-level foreign key reference.
    #[serde(default)]
    pub references: Option<SchemaColumnReference>,
}
61
/// Column-level foreign key reference metadata.
///
/// Names a single target column; `schema` may be omitted from the payload.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SchemaColumnReference {
    /// Referenced table name.
    pub table: String,
    /// Referenced column name.
    pub column: String,
    /// Optional schema/namespace of referenced table.
    #[serde(default)]
    pub schema: Option<String>,
}
73
/// Table-level foreign key reference metadata.
///
/// Unlike a column-level reference, this form carries a list of source
/// columns and a list of target columns.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SchemaForeignKey {
    /// Optional FK name.
    #[serde(default)]
    pub name: Option<String>,
    /// Source columns in the current table.
    pub columns: Vec<String>,
    /// Referenced target table + columns.
    pub references: SchemaTableReference,
}
85
/// Target of a table-level foreign key.
///
/// `schema` may be omitted from the payload.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SchemaTableReference {
    /// Referenced table name.
    pub table: String,
    /// Referenced target columns.
    pub columns: Vec<String>,
    /// Optional schema/namespace of referenced table.
    #[serde(default)]
    pub schema: Option<String>,
}
97
/// Table definition used for schema-aware validation.
///
/// Only `name` and `columns` are required in the payload. The key lists are
/// spelled `primaryKey`, `uniqueKeys` and `foreignKeys` when serialized.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SchemaTable {
    /// Table name.
    pub name: String,
    /// Optional schema/namespace name.
    #[serde(default)]
    pub schema: Option<String>,
    /// Column definitions.
    pub columns: Vec<SchemaColumn>,
    /// Optional aliases that should resolve to this table.
    #[serde(default)]
    pub aliases: Vec<String>,
    /// Optional primary key column list.
    #[serde(default, rename = "primaryKey")]
    pub primary_key: Vec<String>,
    /// Optional unique key groups (each inner list is one group of columns).
    #[serde(default, rename = "uniqueKeys")]
    pub unique_keys: Vec<Vec<String>>,
    /// Optional table-level foreign keys.
    #[serde(default, rename = "foreignKeys")]
    pub foreign_keys: Vec<SchemaForeignKey>,
}
121
/// Schema payload used for schema-aware validation.
///
/// This is the top-level value handed to the validation entry points.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ValidationSchema {
    /// Known tables.
    pub tables: Vec<SchemaTable>,
    /// Default strict mode for unknown identifiers (`None` when omitted).
    #[serde(default)]
    pub strict: Option<bool>,
}
131
/// Options for schema-aware validation.
///
/// Every field is optional in the serialized form, and `Default` yields all
/// checks disabled with no strictness override and no function catalog.
#[derive(Clone, Serialize, Deserialize, Default)]
pub struct SchemaValidationOptions {
    /// Enables type compatibility checks for expressions, DML assignments, and set operations.
    #[serde(default)]
    pub check_types: bool,
    /// Enables FK/reference integrity checks and query-level reference quality checks.
    #[serde(default)]
    pub check_references: bool,
    /// If true/false, overrides schema.strict.
    #[serde(default)]
    pub strict: Option<bool>,
    /// Enables semantic warnings (W001..W004).
    #[serde(default)]
    pub semantic: bool,
    /// Enables strict syntax checks (e.g. rejects trailing commas before clause boundaries).
    #[serde(default)]
    pub strict_syntax: bool,
    /// Optional external function catalog plugin for dialect-specific function validation.
    /// Skipped by serde: it is never serialized and is `None` after deserialization.
    #[serde(skip, default)]
    pub function_catalog: Option<Arc<dyn FunctionCatalog>>,
}
154
#[cfg(any(
    feature = "function-catalog-clickhouse",
    feature = "function-catalog-duckdb",
    feature = "function-catalog-all-dialects"
))]
/// Translate the catalog crate's name-case setting into the core equivalent.
///
/// The two enums have the same variants; this is a one-to-one mapping.
fn to_core_name_case(
    case: polyglot_sql_function_catalogs::FunctionNameCase,
) -> CoreFunctionNameCase {
    use polyglot_sql_function_catalogs::FunctionNameCase as ExternalCase;

    match case {
        ExternalCase::Insensitive => CoreFunctionNameCase::Insensitive,
        ExternalCase::Sensitive => CoreFunctionNameCase::Sensitive,
    }
}
172
#[cfg(any(
    feature = "function-catalog-clickhouse",
    feature = "function-catalog-duckdb",
    feature = "function-catalog-all-dialects"
))]
/// Convert catalog-crate signatures into core signatures, preserving order.
///
/// Only the arity bounds are carried over.
fn to_core_signatures(
    signatures: Vec<polyglot_sql_function_catalogs::FunctionSignature>,
) -> Vec<CoreFunctionSignature> {
    let mut converted = Vec::with_capacity(signatures.len());
    for external in signatures {
        converted.push(CoreFunctionSignature {
            min_arity: external.min_arity,
            max_arity: external.max_arity,
        });
    }
    converted
}
189
#[cfg(any(
    feature = "function-catalog-clickhouse",
    feature = "function-catalog-duckdb",
    feature = "function-catalog-all-dialects"
))]
/// Sink that receives registrations from `polyglot_sql_function_catalogs`
/// and forwards them into a `HashMapFunctionCatalog`.
struct EmbeddedCatalogSink<'a> {
    /// Catalog being populated.
    catalog: &'a mut HashMapFunctionCatalog,
    /// Memoized result of parsing each dialect name; `None` records a name
    /// that did not parse as a `DialectType`.
    dialect_cache: HashMap<&'static str, Option<DialectType>>,
}
199
#[cfg(any(
    feature = "function-catalog-clickhouse",
    feature = "function-catalog-duckdb",
    feature = "function-catalog-all-dialects"
))]
impl<'a> EmbeddedCatalogSink<'a> {
    /// Resolve a dialect name to a `DialectType`, parsing each name at most once.
    ///
    /// Returns `None` when the name does not parse; that outcome is cached too,
    /// so an unparseable name is not parsed again.
    fn resolve_dialect(&mut self, dialect: &'static str) -> Option<DialectType> {
        *self
            .dialect_cache
            .entry(dialect)
            .or_insert_with(|| dialect.parse::<DialectType>().ok())
    }
}
215
#[cfg(any(
    feature = "function-catalog-clickhouse",
    feature = "function-catalog-duckdb",
    feature = "function-catalog-all-dialects"
))]
impl<'a> polyglot_sql_function_catalogs::CatalogSink for EmbeddedCatalogSink<'a> {
    /// Set the default function-name case rule for a dialect.
    ///
    /// Does nothing when the dialect name does not resolve.
    fn set_dialect_name_case(
        &mut self,
        dialect: &'static str,
        name_case: polyglot_sql_function_catalogs::FunctionNameCase,
    ) {
        let Some(core_dialect) = self.resolve_dialect(dialect) else {
            return;
        };
        let core_case = to_core_name_case(name_case);
        self.catalog.set_dialect_name_case(core_dialect, core_case);
    }

    /// Set the name case rule for one function of a dialect.
    ///
    /// Does nothing when the dialect name does not resolve.
    fn set_function_name_case(
        &mut self,
        dialect: &'static str,
        function_name: &str,
        name_case: polyglot_sql_function_catalogs::FunctionNameCase,
    ) {
        let Some(core_dialect) = self.resolve_dialect(dialect) else {
            return;
        };
        let core_case = to_core_name_case(name_case);
        self.catalog
            .set_function_name_case(core_dialect, function_name, core_case);
    }

    /// Register a function and its signatures for a dialect.
    ///
    /// Does nothing when the dialect name does not resolve.
    fn register(
        &mut self,
        dialect: &'static str,
        function_name: &str,
        signatures: Vec<polyglot_sql_function_catalogs::FunctionSignature>,
    ) {
        let Some(core_dialect) = self.resolve_dialect(dialect) else {
            return;
        };
        let core_signatures = to_core_signatures(signatures);
        self.catalog
            .register(core_dialect, function_name, core_signatures);
    }
}
260
#[cfg(any(
    feature = "function-catalog-clickhouse",
    feature = "function-catalog-duckdb",
    feature = "function-catalog-all-dialects"
))]
/// Return a shared handle to the function catalog embedded at build time.
///
/// The catalog is populated lazily on first access and stored in a
/// function-local static; every call returns a clone of the same `Arc`.
fn embedded_function_catalog_arc() -> Arc<dyn FunctionCatalog> {
    /// Populate a fresh catalog from every enabled embedded catalog.
    fn build_catalog() -> Arc<HashMapFunctionCatalog> {
        let mut catalog = HashMapFunctionCatalog::default();
        {
            // The sink holds a mutable borrow of `catalog`; keep it scoped so
            // the catalog can be moved into the `Arc` afterwards.
            let mut sink = EmbeddedCatalogSink {
                catalog: &mut catalog,
                dialect_cache: HashMap::new(),
            };
            polyglot_sql_function_catalogs::register_enabled_catalogs(&mut sink);
        }
        Arc::new(catalog)
    }

    static EMBEDDED: LazyLock<Arc<HashMapFunctionCatalog>> = LazyLock::new(build_catalog);

    let shared: Arc<HashMapFunctionCatalog> = Arc::clone(&*EMBEDDED);
    shared
}
279
#[cfg(any(
    feature = "function-catalog-clickhouse",
    feature = "function-catalog-duckdb",
    feature = "function-catalog-all-dialects"
))]
/// Default function catalog when at least one embedded catalog feature is
/// enabled: always the shared embedded catalog.
fn default_embedded_function_catalog() -> Option<Arc<dyn FunctionCatalog>> {
    let catalog = embedded_function_catalog_arc();
    Some(catalog)
}
288
#[cfg(not(any(
    feature = "function-catalog-clickhouse",
    feature = "function-catalog-duckdb",
    feature = "function-catalog-all-dialects"
)))]
/// Default function catalog when no embedded catalog feature is enabled:
/// there is none, so this always returns `None`.
fn default_embedded_function_catalog() -> Option<Arc<dyn FunctionCatalog>> {
    None
}
297
/// Validation error/warning codes used by schema-aware validation.
///
/// Constants prefixed `E_` hold `E###` codes and constants prefixed `W_` hold
/// `W###` codes. The numeric ranges are grouped by check family, as noted on
/// each section below.
pub mod validation_codes {
    // Existing schema and semantic checks.
    pub const E_PARSE_OR_OPTIONS: &str = "E000";
    pub const E_UNKNOWN_TABLE: &str = "E200";
    pub const E_UNKNOWN_COLUMN: &str = "E201";
    pub const E_UNKNOWN_FUNCTION: &str = "E202";
    pub const E_INVALID_FUNCTION_ARITY: &str = "E203";

    // Semantic warnings (W001..W004).
    pub const W_SELECT_STAR: &str = "W001";
    pub const W_AGGREGATE_WITHOUT_GROUP_BY: &str = "W002";
    pub const W_DISTINCT_ORDER_BY: &str = "W003";
    pub const W_LIMIT_WITHOUT_ORDER_BY: &str = "W004";

    // Phase 2 (type checks): E210-E219, W210-W219.
    pub const E_TYPE_MISMATCH: &str = "E210";
    pub const E_INVALID_PREDICATE_TYPE: &str = "E211";
    pub const E_INVALID_ARITHMETIC_TYPE: &str = "E212";
    pub const E_INVALID_FUNCTION_ARGUMENT_TYPE: &str = "E213";
    pub const E_INVALID_ASSIGNMENT_TYPE: &str = "E214";
    pub const E_SETOP_TYPE_MISMATCH: &str = "E215";
    pub const E_SETOP_ARITY_MISMATCH: &str = "E216";
    pub const E_INCOMPATIBLE_COMPARISON_TYPES: &str = "E217";
    pub const E_INVALID_CAST: &str = "E218";
    pub const E_UNKNOWN_INFERRED_TYPE: &str = "E219";

    pub const W_IMPLICIT_CAST_COMPARISON: &str = "W210";
    pub const W_IMPLICIT_CAST_ARITHMETIC: &str = "W211";
    pub const W_IMPLICIT_CAST_ASSIGNMENT: &str = "W212";
    pub const W_LOSSY_CAST: &str = "W213";
    pub const W_SETOP_IMPLICIT_COERCION: &str = "W214";
    pub const W_PREDICATE_NULLABILITY: &str = "W215";
    pub const W_FUNCTION_ARGUMENT_COERCION: &str = "W216";
    pub const W_AGGREGATE_TYPE_COERCION: &str = "W217";
    pub const W_POSSIBLE_OVERFLOW: &str = "W218";
    pub const W_POSSIBLE_TRUNCATION: &str = "W219";

    // Phase 2 (reference checks): E220-E229, W220-W229.
    pub const E_INVALID_FOREIGN_KEY_REFERENCE: &str = "E220";
    pub const E_AMBIGUOUS_COLUMN_REFERENCE: &str = "E221";
    pub const E_UNRESOLVED_REFERENCE: &str = "E222";
    pub const E_CTE_COLUMN_COUNT_MISMATCH: &str = "E223";
    pub const E_MISSING_REFERENCE_TARGET: &str = "E224";

    pub const W_CARTESIAN_JOIN: &str = "W220";
    pub const W_JOIN_NOT_USING_DECLARED_REFERENCE: &str = "W221";
    pub const W_WEAK_REFERENCE_INTEGRITY: &str = "W222";
}
346
/// Canonical type family used by schema/type checks.
///
/// Concrete SQL type spellings are collapsed into these coarse families.
/// Variants serialize in snake case (e.g. `"timestamp"`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum TypeFamily {
    /// Empty or unrecognized type.
    Unknown,
    /// `bool` / `boolean`.
    Boolean,
    /// Integer types such as `int`, `bigint`, `serial`, `uint64`.
    Integer,
    /// Non-integer numbers such as `decimal`, `float`, `double precision`.
    Numeric,
    /// Character types such as `varchar`, `text`, `string`.
    String,
    /// Byte-string types such as `varbinary`, `blob`, `bytea`.
    Binary,
    /// Calendar date without a time of day.
    Date,
    /// Time of day.
    Time,
    /// Date and time, e.g. `timestamp`, `datetime`, `timestamptz`.
    Timestamp,
    /// Interval / duration.
    Interval,
    /// `json`, `jsonb`, `variant`.
    Json,
    /// `uuid` / `uniqueidentifier`.
    Uuid,
    /// Array / list types.
    Array,
    /// Map types.
    Map,
    /// Struct / row / record / object types.
    Struct,
}
367
368impl TypeFamily {
369    pub fn is_numeric(self) -> bool {
370        matches!(self, TypeFamily::Integer | TypeFamily::Numeric)
371    }
372
373    pub fn is_temporal(self) -> bool {
374        matches!(
375            self,
376            TypeFamily::Date | TypeFamily::Time | TypeFamily::Timestamp | TypeFamily::Interval
377        )
378    }
379}
380
/// Per-table lookup data derived from a `SchemaTable`.
#[derive(Debug, Clone)]
struct TableSchemaEntry {
    /// Lowercased column name -> canonical type family.
    columns: HashMap<String, TypeFamily>,
    /// Lowercased column names in the order the schema declares them.
    column_order: Vec<String>,
}
386
/// Lowercase an identifier for case-insensitive lookups.
fn lower(s: &str) -> String {
    str::to_lowercase(s)
}
390
/// Split `base(inner)` into its trimmed `base` and `inner` parts.
///
/// The split happens at the first `(`, and the string must end with `)`;
/// otherwise `None` is returned. Nested parentheses stay inside `inner`.
fn split_type_args(data_type: &str) -> Option<(&str, &str)> {
    let (head, tail) = data_type.split_once('(')?;
    // `tail` is everything after the first '('; it must close with ')'.
    let args = tail.strip_suffix(')')?;
    Some((head.trim(), args.trim()))
}
400
401/// Canonicalize a schema type string into a stable `TypeFamily`.
402pub fn canonical_type_family(data_type: &str) -> TypeFamily {
403    let trimmed = data_type
404        .trim()
405        .trim_matches(|c| c == '"' || c == '\'' || c == '`');
406    if trimmed.is_empty() {
407        return TypeFamily::Unknown;
408    }
409
410    // Normalize whitespace and lowercase for matching.
411    let normalized = trimmed
412        .split_whitespace()
413        .collect::<Vec<_>>()
414        .join(" ")
415        .to_lowercase();
416
417    // Strip common wrappers first.
418    if let Some((base, inner)) = split_type_args(&normalized) {
419        match base {
420            "nullable" | "lowcardinality" => return canonical_type_family(inner),
421            "array" | "list" => return TypeFamily::Array,
422            "map" => return TypeFamily::Map,
423            "struct" | "row" | "record" => return TypeFamily::Struct,
424            _ => {}
425        }
426    }
427
428    if normalized.starts_with("array<") || normalized.starts_with("list<") {
429        return TypeFamily::Array;
430    }
431    if normalized.starts_with("map<") {
432        return TypeFamily::Map;
433    }
434    if normalized.starts_with("struct<")
435        || normalized.starts_with("row<")
436        || normalized.starts_with("record<")
437        || normalized.starts_with("object<")
438    {
439        return TypeFamily::Struct;
440    }
441
442    if normalized.ends_with("[]") {
443        return TypeFamily::Array;
444    }
445
446    // Remove parameter list if present, e.g. VARCHAR(255), DECIMAL(10,2).
447    let mut base = normalized
448        .split('(')
449        .next()
450        .unwrap_or("")
451        .trim()
452        .to_string();
453    if base.is_empty() {
454        return TypeFamily::Unknown;
455    }
456
457    base = base.strip_prefix("unsigned ").unwrap_or(&base).to_string();
458    base = base.strip_suffix(" unsigned").unwrap_or(&base).to_string();
459
460    match base.as_str() {
461        "bool" | "boolean" => TypeFamily::Boolean,
462        "tinyint" | "smallint" | "int2" | "int" | "integer" | "int4" | "int8" | "bigint"
463        | "serial" | "smallserial" | "bigserial" | "utinyint" | "usmallint" | "uinteger"
464        | "ubigint" | "uint8" | "uint16" | "uint32" | "uint64" | "int16" | "int32" | "int64" => {
465            TypeFamily::Integer
466        }
467        "numeric" | "decimal" | "dec" | "number" | "float" | "float4" | "float8" | "real"
468        | "double" | "double precision" | "bfloat16" | "float16" | "float32" | "float64" => {
469            TypeFamily::Numeric
470        }
471        "char" | "character" | "varchar" | "character varying" | "nchar" | "nvarchar" | "text"
472        | "string" | "clob" => TypeFamily::String,
473        "binary" | "varbinary" | "blob" | "bytea" | "bytes" => TypeFamily::Binary,
474        "date" => TypeFamily::Date,
475        "time" => TypeFamily::Time,
476        "timestamp"
477        | "timestamptz"
478        | "datetime"
479        | "datetime2"
480        | "smalldatetime"
481        | "timestamp with time zone"
482        | "timestamp without time zone" => TypeFamily::Timestamp,
483        "interval" => TypeFamily::Interval,
484        "json" | "jsonb" | "variant" => TypeFamily::Json,
485        "uuid" | "uniqueidentifier" => TypeFamily::Uuid,
486        "array" | "list" => TypeFamily::Array,
487        "map" => TypeFamily::Map,
488        "struct" | "row" | "record" | "object" => TypeFamily::Struct,
489        _ => TypeFamily::Unknown,
490    }
491}
492
493fn build_schema_map(schema: &ValidationSchema) -> HashMap<String, TableSchemaEntry> {
494    let mut map = HashMap::new();
495
496    for table in &schema.tables {
497        let column_order: Vec<String> = table.columns.iter().map(|c| lower(&c.name)).collect();
498        let columns: HashMap<String, TypeFamily> = table
499            .columns
500            .iter()
501            .map(|c| (lower(&c.name), canonical_type_family(&c.data_type)))
502            .collect();
503        let entry = TableSchemaEntry {
504            columns,
505            column_order,
506        };
507
508        let simple_name = lower(&table.name);
509        map.insert(simple_name, entry.clone());
510
511        if let Some(table_schema) = &table.schema {
512            map.insert(
513                format!("{}.{}", lower(table_schema), lower(&table.name)),
514                entry.clone(),
515            );
516        }
517
518        for alias in &table.aliases {
519            map.insert(lower(alias), entry.clone());
520        }
521    }
522
523    map
524}
525
526fn type_family_to_data_type(family: TypeFamily) -> DataType {
527    match family {
528        TypeFamily::Unknown => DataType::Unknown,
529        TypeFamily::Boolean => DataType::Boolean,
530        TypeFamily::Integer => DataType::Int {
531            length: None,
532            integer_spelling: false,
533        },
534        TypeFamily::Numeric => DataType::Double {
535            precision: None,
536            scale: None,
537        },
538        TypeFamily::String => DataType::VarChar {
539            length: None,
540            parenthesized_length: false,
541        },
542        TypeFamily::Binary => DataType::VarBinary { length: None },
543        TypeFamily::Date => DataType::Date,
544        TypeFamily::Time => DataType::Time {
545            precision: None,
546            timezone: false,
547        },
548        TypeFamily::Timestamp => DataType::Timestamp {
549            precision: None,
550            timezone: false,
551        },
552        TypeFamily::Interval => DataType::Interval {
553            unit: None,
554            to: None,
555        },
556        TypeFamily::Json => DataType::Json,
557        TypeFamily::Uuid => DataType::Uuid,
558        TypeFamily::Array => DataType::Array {
559            element_type: Box::new(DataType::Unknown),
560            dimension: None,
561        },
562        TypeFamily::Map => DataType::Map {
563            key_type: Box::new(DataType::Unknown),
564            value_type: Box::new(DataType::Unknown),
565        },
566        TypeFamily::Struct => DataType::Struct {
567            fields: Vec::new(),
568            nested: false,
569        },
570    }
571}
572
573fn build_resolver_schema(schema: &ValidationSchema) -> MappingSchema {
574    let mut mapping = MappingSchema::new();
575
576    for table in &schema.tables {
577        let columns: Vec<(String, DataType)> = table
578            .columns
579            .iter()
580            .map(|column| {
581                (
582                    lower(&column.name),
583                    type_family_to_data_type(canonical_type_family(&column.data_type)),
584                )
585            })
586            .collect();
587
588        let mut table_names = Vec::new();
589        table_names.push(lower(&table.name));
590        if let Some(table_schema) = &table.schema {
591            table_names.push(format!("{}.{}", lower(table_schema), lower(&table.name)));
592        }
593        for alias in &table.aliases {
594            table_names.push(lower(alias));
595        }
596
597        let mut dedup = HashSet::new();
598        for table_name in table_names {
599            if dedup.insert(table_name.clone()) {
600                let _ = mapping.add_table(&table_name, &columns, None);
601            }
602        }
603    }
604
605    mapping
606}
607
/// Build a `MappingSchema` from a `ValidationSchema` payload.
///
/// This is useful for APIs that already accept `ValidationSchema`-shaped input
/// (for example JSON wrappers) and need to run schema-aware lineage or other
/// resolver-based analysis.
///
/// This is the public entry point for `build_resolver_schema`; table and
/// column names in the result are lowercased.
pub fn mapping_schema_from_validation_schema(schema: &ValidationSchema) -> MappingSchema {
    build_resolver_schema(schema)
}
616
617fn collect_cte_aliases(expr: &Expression) -> HashSet<String> {
618    let mut aliases = HashSet::new();
619
620    for node in expr.dfs() {
621        match node {
622            Expression::Select(select) => {
623                if let Some(with) = &select.with {
624                    for cte in &with.ctes {
625                        aliases.insert(lower(&cte.alias.name));
626                    }
627                }
628            }
629            Expression::Insert(insert) => {
630                if let Some(with) = &insert.with {
631                    for cte in &with.ctes {
632                        aliases.insert(lower(&cte.alias.name));
633                    }
634                }
635            }
636            Expression::Update(update) => {
637                if let Some(with) = &update.with {
638                    for cte in &with.ctes {
639                        aliases.insert(lower(&cte.alias.name));
640                    }
641                }
642            }
643            Expression::Delete(delete) => {
644                if let Some(with) = &delete.with {
645                    for cte in &with.ctes {
646                        aliases.insert(lower(&cte.alias.name));
647                    }
648                }
649            }
650            Expression::Union(union) => {
651                if let Some(with) = &union.with {
652                    for cte in &with.ctes {
653                        aliases.insert(lower(&cte.alias.name));
654                    }
655                }
656            }
657            Expression::Intersect(intersect) => {
658                if let Some(with) = &intersect.with {
659                    for cte in &with.ctes {
660                        aliases.insert(lower(&cte.alias.name));
661                    }
662                }
663            }
664            Expression::Except(except) => {
665                if let Some(with) = &except.with {
666                    for cte in &with.ctes {
667                        aliases.insert(lower(&cte.alias.name));
668                    }
669                }
670            }
671            Expression::Merge(merge) => {
672                if let Some(with_) = &merge.with_ {
673                    if let Expression::With(with_clause) = with_.as_ref() {
674                        for cte in &with_clause.ctes {
675                            aliases.insert(lower(&cte.alias.name));
676                        }
677                    }
678                }
679            }
680            _ => {}
681        }
682    }
683
684    aliases
685}
686
687fn table_ref_candidates(table: &TableRef) -> Vec<String> {
688    let name = lower(&table.name.name);
689    let schema = table.schema.as_ref().map(|s| lower(&s.name));
690    let catalog = table.catalog.as_ref().map(|c| lower(&c.name));
691
692    let mut candidates = Vec::new();
693    if let (Some(catalog), Some(schema)) = (&catalog, &schema) {
694        candidates.push(format!("{}.{}.{}", catalog, schema, name));
695    }
696    if let Some(schema) = &schema {
697        candidates.push(format!("{}.{}", schema, name));
698    }
699    candidates.push(name);
700    candidates
701}
702
703fn table_ref_display_name(table: &TableRef) -> String {
704    let mut parts = Vec::new();
705    if let Some(catalog) = &table.catalog {
706        parts.push(catalog.name.clone());
707    }
708    if let Some(schema) = &table.schema {
709        parts.push(schema.name.clone());
710    }
711    parts.push(table.name.name.clone());
712    parts.join(".")
713}
714
/// Tables visible to type checks for one statement.
#[derive(Debug, Default, Clone)]
struct TypeCheckContext {
    /// Schema-map keys of the tables the statement references.
    referenced_tables: HashSet<String>,
    /// Lowercased table name or alias -> schema-map key of the table it denotes.
    table_aliases: HashMap<String, String>,
}
720
/// Lowercase display name of a type family.
///
/// The match is exhaustive on purpose: adding a `TypeFamily` variant forces a
/// name to be chosen here.
fn type_family_name(family: TypeFamily) -> &'static str {
    match family {
        TypeFamily::Unknown => "unknown",
        TypeFamily::Boolean => "boolean",
        TypeFamily::Integer => "integer",
        TypeFamily::Numeric => "numeric",
        TypeFamily::String => "string",
        TypeFamily::Binary => "binary",
        TypeFamily::Date => "date",
        TypeFamily::Time => "time",
        TypeFamily::Timestamp => "timestamp",
        TypeFamily::Interval => "interval",
        TypeFamily::Json => "json",
        TypeFamily::Uuid => "uuid",
        TypeFamily::Array => "array",
        TypeFamily::Map => "map",
        TypeFamily::Struct => "struct",
    }
}
740
741fn is_string_like(family: TypeFamily) -> bool {
742    matches!(family, TypeFamily::String)
743}
744
745fn is_string_or_binary(family: TypeFamily) -> bool {
746    matches!(family, TypeFamily::String | TypeFamily::Binary)
747}
748
749fn type_issue(
750    strict: bool,
751    error_code: &str,
752    warning_code: &str,
753    message: impl Into<String>,
754) -> ValidationError {
755    if strict {
756        ValidationError::error(message.into(), error_code)
757    } else {
758        ValidationError::warning(message.into(), warning_code)
759    }
760}
761
762fn data_type_family(data_type: &DataType) -> TypeFamily {
763    match data_type {
764        DataType::Boolean => TypeFamily::Boolean,
765        DataType::TinyInt { .. }
766        | DataType::SmallInt { .. }
767        | DataType::Int { .. }
768        | DataType::BigInt { .. } => TypeFamily::Integer,
769        DataType::Float { .. } | DataType::Double { .. } | DataType::Decimal { .. } => {
770            TypeFamily::Numeric
771        }
772        DataType::Char { .. }
773        | DataType::VarChar { .. }
774        | DataType::String { .. }
775        | DataType::Text
776        | DataType::TextWithLength { .. }
777        | DataType::CharacterSet { .. } => TypeFamily::String,
778        DataType::Binary { .. } | DataType::VarBinary { .. } | DataType::Blob => TypeFamily::Binary,
779        DataType::Date => TypeFamily::Date,
780        DataType::Time { .. } => TypeFamily::Time,
781        DataType::Timestamp { .. } => TypeFamily::Timestamp,
782        DataType::Interval { .. } => TypeFamily::Interval,
783        DataType::Json | DataType::JsonB => TypeFamily::Json,
784        DataType::Uuid => TypeFamily::Uuid,
785        DataType::Array { .. } | DataType::List { .. } => TypeFamily::Array,
786        DataType::Map { .. } => TypeFamily::Map,
787        DataType::Struct { .. } | DataType::Object { .. } | DataType::Union { .. } => {
788            TypeFamily::Struct
789        }
790        DataType::Nullable { inner } => data_type_family(inner),
791        DataType::Custom { name } => canonical_type_family(name),
792        DataType::Unknown => TypeFamily::Unknown,
793        DataType::Bit { .. } | DataType::VarBit { .. } => TypeFamily::Binary,
794        DataType::Enum { .. } | DataType::Set { .. } => TypeFamily::String,
795        DataType::Vector { .. } => TypeFamily::Array,
796        DataType::Geometry { .. } | DataType::Geography { .. } => TypeFamily::Struct,
797    }
798}
799
800fn collect_type_check_context(
801    stmt: &Expression,
802    schema_map: &HashMap<String, TableSchemaEntry>,
803) -> TypeCheckContext {
804    fn add_table_to_context(
805        table: &TableRef,
806        schema_map: &HashMap<String, TableSchemaEntry>,
807        context: &mut TypeCheckContext,
808    ) {
809        let resolved_key = table_ref_candidates(table)
810            .into_iter()
811            .find(|k| schema_map.contains_key(k));
812
813        let Some(table_key) = resolved_key else {
814            return;
815        };
816
817        context.referenced_tables.insert(table_key.clone());
818        context
819            .table_aliases
820            .insert(lower(&table.name.name), table_key.clone());
821        if let Some(alias) = &table.alias {
822            context
823                .table_aliases
824                .insert(lower(&alias.name), table_key.clone());
825        }
826    }
827
828    let mut context = TypeCheckContext::default();
829    let cte_aliases = collect_cte_aliases(stmt);
830
831    for node in stmt.find_all(|e| matches!(e, Expression::Table(_))) {
832        let Expression::Table(table) = node else {
833            continue;
834        };
835
836        if cte_aliases.contains(&lower(&table.name.name)) {
837            continue;
838        }
839
840        add_table_to_context(table, schema_map, &mut context);
841    }
842
843    // Seed DML target tables explicitly because they are struct fields and may
844    // not appear as standalone Expression::Table nodes in traversal output.
845    match stmt {
846        Expression::Insert(insert) => {
847            add_table_to_context(&insert.table, schema_map, &mut context);
848        }
849        Expression::Update(update) => {
850            add_table_to_context(&update.table, schema_map, &mut context);
851            for table in &update.extra_tables {
852                add_table_to_context(table, schema_map, &mut context);
853            }
854        }
855        Expression::Delete(delete) => {
856            add_table_to_context(&delete.table, schema_map, &mut context);
857            for table in &delete.using {
858                add_table_to_context(table, schema_map, &mut context);
859            }
860            for table in &delete.tables {
861                add_table_to_context(table, schema_map, &mut context);
862            }
863        }
864        _ => {}
865    }
866
867    context
868}
869
870fn resolve_table_schema_entry<'a>(
871    table: &TableRef,
872    schema_map: &'a HashMap<String, TableSchemaEntry>,
873) -> Option<(String, &'a TableSchemaEntry)> {
874    let key = table_ref_candidates(table)
875        .into_iter()
876        .find(|k| schema_map.contains_key(k))?;
877    let entry = schema_map.get(&key)?;
878    Some((key, entry))
879}
880
881fn reference_issue(strict: bool, message: impl Into<String>) -> ValidationError {
882    if strict {
883        ValidationError::error(
884            message.into(),
885            validation_codes::E_INVALID_FOREIGN_KEY_REFERENCE,
886        )
887    } else {
888        ValidationError::warning(message.into(), validation_codes::W_WEAK_REFERENCE_INTEGRITY)
889    }
890}
891
892fn reference_table_candidates(
893    table_name: &str,
894    explicit_schema: Option<&str>,
895    source_schema: Option<&str>,
896) -> Vec<String> {
897    let mut candidates = Vec::new();
898    let raw = lower(table_name);
899
900    if let Some(schema) = explicit_schema {
901        candidates.push(format!("{}.{}", lower(schema), raw));
902    }
903
904    if raw.contains('.') {
905        candidates.push(raw.clone());
906        if let Some(last) = raw.rsplit('.').next() {
907            candidates.push(last.to_string());
908        }
909    } else {
910        if let Some(schema) = source_schema {
911            candidates.push(format!("{}.{}", lower(schema), raw));
912        }
913        candidates.push(raw);
914    }
915
916    let mut dedup = HashSet::new();
917    candidates
918        .into_iter()
919        .filter(|c| dedup.insert(c.clone()))
920        .collect()
921}
922
923fn resolve_reference_table_key(
924    table_name: &str,
925    explicit_schema: Option<&str>,
926    source_schema: Option<&str>,
927    schema_map: &HashMap<String, TableSchemaEntry>,
928) -> Option<String> {
929    reference_table_candidates(table_name, explicit_schema, source_schema)
930        .into_iter()
931        .find(|candidate| schema_map.contains_key(candidate))
932}
933
934fn key_types_compatible(source: TypeFamily, target: TypeFamily) -> bool {
935    if source == TypeFamily::Unknown || target == TypeFamily::Unknown {
936        return true;
937    }
938    if source == target {
939        return true;
940    }
941    if source.is_numeric() && target.is_numeric() {
942        return true;
943    }
944    if source.is_temporal() && target.is_temporal() {
945        return true;
946    }
947    false
948}
949
950fn table_key_hints(table: &SchemaTable) -> HashSet<String> {
951    let mut hints = HashSet::new();
952    for column in &table.columns {
953        if column.primary_key || column.unique {
954            hints.insert(lower(&column.name));
955        }
956    }
957    for key_col in &table.primary_key {
958        hints.insert(lower(key_col));
959    }
960    for group in &table.unique_keys {
961        if group.len() == 1 {
962            if let Some(col) = group.first() {
963                hints.insert(lower(col));
964            }
965        }
966    }
967    hints
968}
969
/// Validates every foreign-key declaration in `schema` against `schema_map`.
///
/// Both column-level `references` and table-level `foreign_keys` are checked.
/// Unknown target tables or columns, empty or unequal column lists, unknown
/// source columns and incompatible type families are reported through
/// `reference_issue`, i.e. as errors when `strict` is set and as warnings
/// otherwise. A referenced column that is not among the target table's key
/// hints is always reported as a `W_WEAK_REFERENCE_INTEGRITY` warning,
/// independent of `strict`.
///
/// Returns the diagnostics in the order they were discovered.
970fn check_reference_integrity(
971    schema: &ValidationSchema,
972    schema_map: &HashMap<String, TableSchemaEntry>,
973    strict: bool,
974) -> Vec<ValidationError> {
975    let mut errors = Vec::new();
976
    // Key hints per table, registered under the lowercased bare table name
    // and, when the table declares a schema, under `schema.table` as well.
977    let mut key_hints_lookup: HashMap<String, HashSet<String>> = HashMap::new();
978    for table in &schema.tables {
979        let simple = lower(&table.name);
980        key_hints_lookup.insert(simple, table_key_hints(table));
981        if let Some(schema_name) = &table.schema {
982            let qualified = format!("{}.{}", lower(schema_name), lower(&table.name));
983            key_hints_lookup.insert(qualified, table_key_hints(table));
984        }
985    }
986
987    for table in &schema.tables {
        // Display name keeps the casing from the schema definition; it is
        // only used inside diagnostic messages.
988        let source_table_display = if let Some(schema_name) = &table.schema {
989            format!("{}.{}", schema_name, table.name)
990        } else {
991            table.name.clone()
992        };
993        let source_schema = table.schema.as_deref();
        // Lowercased column name -> type family, consulted by the table-level
        // foreign-key checks below.
994        let source_columns: HashMap<String, TypeFamily> = table
995            .columns
996            .iter()
997            .map(|col| (lower(&col.name), canonical_type_family(&col.data_type)))
998            .collect();
999
        // Column-level references (`column.references`).
1000        for source_col in &table.columns {
1001            let Some(reference) = &source_col.references else {
1002                continue;
1003            };
1004            let source_type = canonical_type_family(&source_col.data_type);
1005
1006            let Some(target_key) = resolve_reference_table_key(
1007                &reference.table,
1008                reference.schema.as_deref(),
1009                source_schema,
1010                schema_map,
1011            ) else {
1012                errors.push(reference_issue(
1013                    strict,
1014                    format!(
1015                        "Foreign key reference '{}.{}' points to unknown table '{}'",
1016                        source_table_display, source_col.name, reference.table
1017                    ),
1018                ));
1019                continue;
1020            };
1021
1022            let target_column = lower(&reference.column);
            // `target_key` was just found in `schema_map`, so this lookup is
            // expected to succeed; the fallback repeats the unknown-table
            // diagnostic.
1023            let Some(target_entry) = schema_map.get(&target_key) else {
1024                errors.push(reference_issue(
1025                    strict,
1026                    format!(
1027                        "Foreign key reference '{}.{}' points to unknown table '{}'",
1028                        source_table_display, source_col.name, reference.table
1029                    ),
1030                ));
1031                continue;
1032            };
1033
1034            let Some(target_type) = target_entry.columns.get(&target_column).copied() else {
1035                errors.push(reference_issue(
1036                    strict,
1037                    format!(
1038                        "Foreign key reference '{}.{}' points to unknown column '{}.{}'",
1039                        source_table_display, source_col.name, target_key, reference.column
1040                    ),
1041                ));
1042                continue;
1043            };
1044
1045            if !key_types_compatible(source_type, target_type) {
1046                errors.push(reference_issue(
1047                    strict,
1048                    format!(
1049                        "Foreign key type mismatch for '{}.{}' -> '{}.{}': {} vs {}",
1050                        source_table_display,
1051                        source_col.name,
1052                        target_key,
1053                        reference.column,
1054                        type_family_name(source_type),
1055                        type_family_name(target_type)
1056                    ),
1057                ));
1058            }
1059
            // Weak-key check: always a warning, even when `strict` is set.
            // Skipped silently when no hints are registered under `target_key`.
1060            if let Some(target_key_hints) = key_hints_lookup.get(&target_key) {
1061                if !target_key_hints.contains(&target_column) {
1062                    errors.push(ValidationError::warning(
1063                        format!(
1064                            "Referenced column '{}.{}' is not marked as primary/unique key",
1065                            target_key, reference.column
1066                        ),
1067                        validation_codes::W_WEAK_REFERENCE_INTEGRITY,
1068                    ));
1069                }
1070            }
1071        }
1072
        // Table-level foreign keys (`table.foreign_keys`), possibly composite.
1073        for foreign_key in &table.foreign_keys {
1074            if foreign_key.columns.is_empty() || foreign_key.references.columns.is_empty() {
1075                errors.push(reference_issue(
1076                    strict,
1077                    format!(
1078                        "Table-level foreign key on '{}' has empty source or target column list",
1079                        source_table_display
1080                    ),
1081                ));
1082                continue;
1083            }
1084            if foreign_key.columns.len() != foreign_key.references.columns.len() {
1085                errors.push(reference_issue(
1086                    strict,
1087                    format!(
1088                        "Table-level foreign key on '{}' has {} source columns but {} target columns",
1089                        source_table_display,
1090                        foreign_key.columns.len(),
1091                        foreign_key.references.columns.len()
1092                    ),
1093                ));
1094                continue;
1095            }
1096
1097            let Some(target_key) = resolve_reference_table_key(
1098                &foreign_key.references.table,
1099                foreign_key.references.schema.as_deref(),
1100                source_schema,
1101                schema_map,
1102            ) else {
1103                errors.push(reference_issue(
1104                    strict,
1105                    format!(
1106                        "Table-level foreign key on '{}' points to unknown table '{}'",
1107                        source_table_display, foreign_key.references.table
1108                    ),
1109                ));
1110                continue;
1111            };
1112
            // Same situation as for column-level references: the key was
            // resolved from `schema_map`, so the fallback is defensive.
1113            let Some(target_entry) = schema_map.get(&target_key) else {
1114                errors.push(reference_issue(
1115                    strict,
1116                    format!(
1117                        "Table-level foreign key on '{}' points to unknown table '{}'",
1118                        source_table_display, foreign_key.references.table
1119                    ),
1120                ));
1121                continue;
1122            };
1123
            // The two column lists were checked above to have equal length,
            // so `zip` pairs every source column with its target column.
1124            for (source_col, target_col) in foreign_key
1125                .columns
1126                .iter()
1127                .zip(foreign_key.references.columns.iter())
1128            {
1129                let source_col_name = lower(source_col);
1130                let target_col_name = lower(target_col);
1131
1132                let Some(source_type) = source_columns.get(&source_col_name).copied() else {
1133                    errors.push(reference_issue(
1134                        strict,
1135                        format!(
1136                            "Table-level foreign key on '{}' references unknown source column '{}'",
1137                            source_table_display, source_col
1138                        ),
1139                    ));
1140                    continue;
1141                };
1142
1143                let Some(target_type) = target_entry.columns.get(&target_col_name).copied() else {
1144                    errors.push(reference_issue(
1145                        strict,
1146                        format!(
1147                            "Table-level foreign key on '{}' references unknown target column '{}.{}'",
1148                            source_table_display, target_key, target_col
1149                        ),
1150                    ));
1151                    continue;
1152                };
1153
1154                if !key_types_compatible(source_type, target_type) {
1155                    errors.push(reference_issue(
1156                        strict,
1157                        format!(
1158                            "Table-level foreign key type mismatch '{}.{}' -> '{}.{}': {} vs {}",
1159                            source_table_display,
1160                            source_col,
1161                            target_key,
1162                            target_col,
1163                            type_family_name(source_type),
1164                            type_family_name(target_type)
1165                        ),
1166                    ));
1167                }
1168
                // Weak-key check per column pair; always a warning.
1169                if let Some(target_key_hints) = key_hints_lookup.get(&target_key) {
1170                    if !target_key_hints.contains(&target_col_name) {
1171                        errors.push(ValidationError::warning(
1172                            format!(
1173                                "Referenced column '{}.{}' is not marked as primary/unique key",
1174                                target_key, target_col
1175                            ),
1176                            validation_codes::W_WEAK_REFERENCE_INTEGRITY,
1177                        ));
1178                    }
1179                }
1180            }
1181        }
1182    }
1183
1184    errors
1185}
1186
1187fn resolve_unqualified_column_type(
1188    column_name: &str,
1189    schema_map: &HashMap<String, TableSchemaEntry>,
1190    context: &TypeCheckContext,
1191) -> TypeFamily {
1192    let candidate_tables: Vec<&String> = if !context.referenced_tables.is_empty() {
1193        context.referenced_tables.iter().collect()
1194    } else {
1195        schema_map.keys().collect()
1196    };
1197
1198    let mut families = HashSet::new();
1199    for table_name in candidate_tables {
1200        if let Some(table_schema) = schema_map.get(table_name) {
1201            if let Some(family) = table_schema.columns.get(column_name) {
1202                families.insert(*family);
1203            }
1204        }
1205    }
1206
1207    if families.len() == 1 {
1208        *families.iter().next().unwrap_or(&TypeFamily::Unknown)
1209    } else {
1210        TypeFamily::Unknown
1211    }
1212}
1213
1214fn resolve_column_type(
1215    column: &Column,
1216    schema_map: &HashMap<String, TableSchemaEntry>,
1217    context: &TypeCheckContext,
1218) -> TypeFamily {
1219    let column_name = lower(&column.name.name);
1220    if column_name.is_empty() {
1221        return TypeFamily::Unknown;
1222    }
1223
1224    if let Some(table) = &column.table {
1225        let mut table_key = lower(&table.name);
1226        if let Some(mapped) = context.table_aliases.get(&table_key) {
1227            table_key = mapped.clone();
1228        }
1229
1230        return schema_map
1231            .get(&table_key)
1232            .and_then(|t| t.columns.get(&column_name))
1233            .copied()
1234            .unwrap_or(TypeFamily::Unknown);
1235    }
1236
1237    resolve_unqualified_column_type(&column_name, schema_map, context)
1238}
1239
/// Borrowed view over the validation schema map and the per-statement type
/// check context, bundled so both can be handed to code that expects a single
/// schema object.
1240struct TypeInferenceSchema<'a> {
    /// Table entries keyed by table key.
1241    schema_map: &'a HashMap<String, TableSchemaEntry>,
    /// Per-statement context (table aliases, referenced tables).
1242    context: &'a TypeCheckContext,
1243}
1244
1245impl TypeInferenceSchema<'_> {
1246    fn resolve_table_key(&self, table: &str) -> Option<String> {
1247        let mut table_key = lower(table);
1248        if let Some(mapped) = self.context.table_aliases.get(&table_key) {
1249            table_key = mapped.clone();
1250        }
1251        if self.schema_map.contains_key(&table_key) {
1252            Some(table_key)
1253        } else {
1254            None
1255        }
1256    }
1257}
1258
1259impl SqlSchema for TypeInferenceSchema<'_> {
1260    fn dialect(&self) -> Option<DialectType> {
1261        None
1262    }
1263
1264    fn add_table(
1265        &mut self,
1266        _table: &str,
1267        _columns: &[(String, DataType)],
1268        _dialect: Option<DialectType>,
1269    ) -> SchemaResult<()> {
1270        Err(SchemaError::InvalidStructure(
1271            "Type inference schema is read-only".to_string(),
1272        ))
1273    }
1274
1275    fn column_names(&self, table: &str) -> SchemaResult<Vec<String>> {
1276        let table_key = self
1277            .resolve_table_key(table)
1278            .ok_or_else(|| SchemaError::TableNotFound(table.to_string()))?;
1279        let entry = self
1280            .schema_map
1281            .get(&table_key)
1282            .ok_or_else(|| SchemaError::TableNotFound(table.to_string()))?;
1283        Ok(entry.column_order.clone())
1284    }
1285
1286    fn get_column_type(&self, table: &str, column: &str) -> SchemaResult<DataType> {
1287        let col_name = lower(column);
1288        if table.is_empty() {
1289            let family = resolve_unqualified_column_type(&col_name, self.schema_map, self.context);
1290            return if family == TypeFamily::Unknown {
1291                Err(SchemaError::ColumnNotFound {
1292                    table: "<unqualified>".to_string(),
1293                    column: column.to_string(),
1294                })
1295            } else {
1296                Ok(type_family_to_data_type(family))
1297            };
1298        }
1299
1300        let table_key = self
1301            .resolve_table_key(table)
1302            .ok_or_else(|| SchemaError::TableNotFound(table.to_string()))?;
1303        let entry = self
1304            .schema_map
1305            .get(&table_key)
1306            .ok_or_else(|| SchemaError::TableNotFound(table.to_string()))?;
1307        let family =
1308            entry
1309                .columns
1310                .get(&col_name)
1311                .copied()
1312                .ok_or_else(|| SchemaError::ColumnNotFound {
1313                    table: table.to_string(),
1314                    column: column.to_string(),
1315                })?;
1316        Ok(type_family_to_data_type(family))
1317    }
1318
1319    fn has_column(&self, table: &str, column: &str) -> bool {
1320        self.get_column_type(table, column).is_ok()
1321    }
1322
1323    fn supported_table_args(&self) -> &[&str] {
1324        TABLE_PARTS
1325    }
1326
1327    fn is_empty(&self) -> bool {
1328        self.schema_map.is_empty()
1329    }
1330
1331    fn depth(&self) -> usize {
1332        1
1333    }
1334}
1335
1336fn infer_expression_type_family(
1337    expr: &Expression,
1338    schema_map: &HashMap<String, TableSchemaEntry>,
1339    context: &TypeCheckContext,
1340) -> TypeFamily {
1341    let inference_schema = TypeInferenceSchema {
1342        schema_map,
1343        context,
1344    };
1345    let mut expr_clone = expr.clone();
1346    annotate_types(&mut expr_clone, Some(&inference_schema), None);
1347    if let Some(data_type) = expr_clone.inferred_type() {
1348        let family = data_type_family(&data_type);
1349        if family != TypeFamily::Unknown {
1350            return family;
1351        }
1352    }
1353
1354    infer_expression_type_family_fallback(expr, schema_map, context)
1355}
1356
1357fn infer_expression_type_family_fallback(
1358    expr: &Expression,
1359    schema_map: &HashMap<String, TableSchemaEntry>,
1360    context: &TypeCheckContext,
1361) -> TypeFamily {
1362    match expr {
1363        Expression::Literal(literal) => match literal {
1364            crate::expressions::Literal::Number(value) => {
1365                if value.contains('.') || value.contains('e') || value.contains('E') {
1366                    TypeFamily::Numeric
1367                } else {
1368                    TypeFamily::Integer
1369                }
1370            }
1371            crate::expressions::Literal::HexNumber(_) => TypeFamily::Integer,
1372            crate::expressions::Literal::Date(_) => TypeFamily::Date,
1373            crate::expressions::Literal::Time(_) => TypeFamily::Time,
1374            crate::expressions::Literal::Timestamp(_)
1375            | crate::expressions::Literal::Datetime(_) => TypeFamily::Timestamp,
1376            crate::expressions::Literal::HexString(_)
1377            | crate::expressions::Literal::BitString(_)
1378            | crate::expressions::Literal::ByteString(_) => TypeFamily::Binary,
1379            _ => TypeFamily::String,
1380        },
1381        Expression::Boolean(_) => TypeFamily::Boolean,
1382        Expression::Null(_) => TypeFamily::Unknown,
1383        Expression::Column(column) => resolve_column_type(column, schema_map, context),
1384        Expression::Cast(cast) | Expression::TryCast(cast) | Expression::SafeCast(cast) => {
1385            data_type_family(&cast.to)
1386        }
1387        Expression::Alias(alias) => {
1388            infer_expression_type_family_fallback(&alias.this, schema_map, context)
1389        }
1390        Expression::Neg(unary) => {
1391            infer_expression_type_family_fallback(&unary.this, schema_map, context)
1392        }
1393        Expression::Add(op) | Expression::Sub(op) | Expression::Mul(op) => {
1394            let left = infer_expression_type_family_fallback(&op.left, schema_map, context);
1395            let right = infer_expression_type_family_fallback(&op.right, schema_map, context);
1396            if left == TypeFamily::Unknown || right == TypeFamily::Unknown {
1397                TypeFamily::Unknown
1398            } else if left == TypeFamily::Integer && right == TypeFamily::Integer {
1399                TypeFamily::Integer
1400            } else if left.is_numeric() && right.is_numeric() {
1401                TypeFamily::Numeric
1402            } else if left.is_temporal() || right.is_temporal() {
1403                left
1404            } else {
1405                TypeFamily::Unknown
1406            }
1407        }
1408        Expression::Div(_) | Expression::Mod(_) => TypeFamily::Numeric,
1409        Expression::Concat(_) => TypeFamily::String,
1410        Expression::Eq(_)
1411        | Expression::Neq(_)
1412        | Expression::Lt(_)
1413        | Expression::Lte(_)
1414        | Expression::Gt(_)
1415        | Expression::Gte(_)
1416        | Expression::Like(_)
1417        | Expression::ILike(_)
1418        | Expression::And(_)
1419        | Expression::Or(_)
1420        | Expression::Not(_)
1421        | Expression::Between(_)
1422        | Expression::In(_)
1423        | Expression::IsNull(_)
1424        | Expression::IsTrue(_)
1425        | Expression::IsFalse(_)
1426        | Expression::Is(_) => TypeFamily::Boolean,
1427        Expression::Length(_) => TypeFamily::Integer,
1428        Expression::Upper(_)
1429        | Expression::Lower(_)
1430        | Expression::Trim(_)
1431        | Expression::LTrim(_)
1432        | Expression::RTrim(_)
1433        | Expression::Replace(_)
1434        | Expression::Substring(_)
1435        | Expression::Left(_)
1436        | Expression::Right(_)
1437        | Expression::Repeat(_)
1438        | Expression::Lpad(_)
1439        | Expression::Rpad(_)
1440        | Expression::ConcatWs(_) => TypeFamily::String,
1441        Expression::Abs(_)
1442        | Expression::Round(_)
1443        | Expression::Floor(_)
1444        | Expression::Ceil(_)
1445        | Expression::Power(_)
1446        | Expression::Sqrt(_)
1447        | Expression::Cbrt(_)
1448        | Expression::Ln(_)
1449        | Expression::Log(_)
1450        | Expression::Exp(_) => TypeFamily::Numeric,
1451        Expression::DateAdd(_) | Expression::DateSub(_) | Expression::ToDate(_) => TypeFamily::Date,
1452        Expression::ToTimestamp(_) => TypeFamily::Timestamp,
1453        Expression::DateDiff(_) | Expression::Extract(_) => TypeFamily::Integer,
1454        Expression::CurrentDate(_) => TypeFamily::Date,
1455        Expression::CurrentTime(_) => TypeFamily::Time,
1456        Expression::CurrentTimestamp(_) | Expression::CurrentTimestampLTZ(_) => {
1457            TypeFamily::Timestamp
1458        }
1459        Expression::Interval(_) => TypeFamily::Interval,
1460        _ => TypeFamily::Unknown,
1461    }
1462}
1463
1464fn are_comparable(left: TypeFamily, right: TypeFamily) -> bool {
1465    if left == TypeFamily::Unknown || right == TypeFamily::Unknown {
1466        return true;
1467    }
1468    if left == right {
1469        return true;
1470    }
1471    if left.is_numeric() && right.is_numeric() {
1472        return true;
1473    }
1474    if left.is_temporal() && right.is_temporal() {
1475        return true;
1476    }
1477    false
1478}
1479
1480fn check_function_argument(
1481    errors: &mut Vec<ValidationError>,
1482    strict: bool,
1483    function_name: &str,
1484    arg_index: usize,
1485    family: TypeFamily,
1486    expected: &str,
1487    valid: bool,
1488) {
1489    if family == TypeFamily::Unknown || valid {
1490        return;
1491    }
1492
1493    errors.push(type_issue(
1494        strict,
1495        validation_codes::E_INVALID_FUNCTION_ARGUMENT_TYPE,
1496        validation_codes::W_FUNCTION_ARGUMENT_COERCION,
1497        format!(
1498            "Function '{}' argument {} expects {}, found {}",
1499            function_name,
1500            arg_index + 1,
1501            expected,
1502            type_family_name(family)
1503        ),
1504    ));
1505}
1506
1507fn function_dispatch_name(name: &str) -> String {
1508    let upper = name
1509        .rsplit('.')
1510        .next()
1511        .unwrap_or(name)
1512        .trim()
1513        .to_uppercase();
1514    lower(canonical_typed_function_name_upper(&upper))
1515}
1516
/// Returns the unqualified part of a function name.
///
/// Everything up to and including the last `.` is dropped, and surrounding
/// whitespace is trimmed from what remains. A name without `.` is only
/// trimmed.
fn function_base_name(name: &str) -> &str {
    name.rsplit_once('.')
        .map_or(name, |(_, unqualified)| unqualified)
        .trim()
}
1520
1521fn check_generic_function(
1522    function: &Function,
1523    schema_map: &HashMap<String, TableSchemaEntry>,
1524    context: &TypeCheckContext,
1525    strict: bool,
1526    errors: &mut Vec<ValidationError>,
1527) {
1528    let name = function_dispatch_name(&function.name);
1529
1530    let arg_family = |index: usize| -> Option<TypeFamily> {
1531        function
1532            .args
1533            .get(index)
1534            .map(|arg| infer_expression_type_family(arg, schema_map, context))
1535    };
1536
1537    match name.as_str() {
1538        "abs" | "sqrt" | "cbrt" | "ln" | "exp" => {
1539            if let Some(family) = arg_family(0) {
1540                check_function_argument(
1541                    errors,
1542                    strict,
1543                    &name,
1544                    0,
1545                    family,
1546                    "a numeric argument",
1547                    family.is_numeric(),
1548                );
1549            }
1550        }
1551        "round" | "floor" | "ceil" | "ceiling" => {
1552            if let Some(family) = arg_family(0) {
1553                check_function_argument(
1554                    errors,
1555                    strict,
1556                    &name,
1557                    0,
1558                    family,
1559                    "a numeric argument",
1560                    family.is_numeric(),
1561                );
1562            }
1563            if let Some(family) = arg_family(1) {
1564                check_function_argument(
1565                    errors,
1566                    strict,
1567                    &name,
1568                    1,
1569                    family,
1570                    "a numeric argument",
1571                    family.is_numeric(),
1572                );
1573            }
1574        }
1575        "power" | "pow" => {
1576            for i in [0_usize, 1_usize] {
1577                if let Some(family) = arg_family(i) {
1578                    check_function_argument(
1579                        errors,
1580                        strict,
1581                        &name,
1582                        i,
1583                        family,
1584                        "a numeric argument",
1585                        family.is_numeric(),
1586                    );
1587                }
1588            }
1589        }
1590        "length" | "char_length" | "character_length" => {
1591            if let Some(family) = arg_family(0) {
1592                check_function_argument(
1593                    errors,
1594                    strict,
1595                    &name,
1596                    0,
1597                    family,
1598                    "a string or binary argument",
1599                    is_string_or_binary(family),
1600                );
1601            }
1602        }
1603        "upper" | "lower" | "trim" | "ltrim" | "rtrim" | "reverse" => {
1604            if let Some(family) = arg_family(0) {
1605                check_function_argument(
1606                    errors,
1607                    strict,
1608                    &name,
1609                    0,
1610                    family,
1611                    "a string argument",
1612                    is_string_like(family),
1613                );
1614            }
1615        }
1616        "substring" | "substr" => {
1617            if let Some(family) = arg_family(0) {
1618                check_function_argument(
1619                    errors,
1620                    strict,
1621                    &name,
1622                    0,
1623                    family,
1624                    "a string argument",
1625                    is_string_like(family),
1626                );
1627            }
1628            if let Some(family) = arg_family(1) {
1629                check_function_argument(
1630                    errors,
1631                    strict,
1632                    &name,
1633                    1,
1634                    family,
1635                    "a numeric argument",
1636                    family.is_numeric(),
1637                );
1638            }
1639            if let Some(family) = arg_family(2) {
1640                check_function_argument(
1641                    errors,
1642                    strict,
1643                    &name,
1644                    2,
1645                    family,
1646                    "a numeric argument",
1647                    family.is_numeric(),
1648                );
1649            }
1650        }
1651        "replace" => {
1652            for i in [0_usize, 1_usize, 2_usize] {
1653                if let Some(family) = arg_family(i) {
1654                    check_function_argument(
1655                        errors,
1656                        strict,
1657                        &name,
1658                        i,
1659                        family,
1660                        "a string argument",
1661                        is_string_like(family),
1662                    );
1663                }
1664            }
1665        }
1666        "left" | "right" | "repeat" | "lpad" | "rpad" => {
1667            if let Some(family) = arg_family(0) {
1668                check_function_argument(
1669                    errors,
1670                    strict,
1671                    &name,
1672                    0,
1673                    family,
1674                    "a string argument",
1675                    is_string_like(family),
1676                );
1677            }
1678            if let Some(family) = arg_family(1) {
1679                check_function_argument(
1680                    errors,
1681                    strict,
1682                    &name,
1683                    1,
1684                    family,
1685                    "a numeric argument",
1686                    family.is_numeric(),
1687                );
1688            }
1689            if (name == "lpad" || name == "rpad") && function.args.len() > 2 {
1690                if let Some(family) = arg_family(2) {
1691                    check_function_argument(
1692                        errors,
1693                        strict,
1694                        &name,
1695                        2,
1696                        family,
1697                        "a string argument",
1698                        is_string_like(family),
1699                    );
1700                }
1701            }
1702        }
1703        _ => {}
1704    }
1705}
1706
1707fn check_function_catalog(
1708    function: &Function,
1709    dialect: DialectType,
1710    function_catalog: Option<&dyn FunctionCatalog>,
1711    strict: bool,
1712    errors: &mut Vec<ValidationError>,
1713) {
1714    let Some(catalog) = function_catalog else {
1715        return;
1716    };
1717
1718    let raw_name = function_base_name(&function.name);
1719    let normalized_name = function_dispatch_name(&function.name);
1720    let arity = function.args.len();
1721    let Some(signatures) = catalog.lookup(dialect, raw_name, &normalized_name) else {
1722        errors.push(if strict {
1723            ValidationError::error(
1724                format!(
1725                    "Unknown function '{}' for dialect {:?}",
1726                    function.name, dialect
1727                ),
1728                validation_codes::E_UNKNOWN_FUNCTION,
1729            )
1730        } else {
1731            ValidationError::warning(
1732                format!(
1733                    "Unknown function '{}' for dialect {:?}",
1734                    function.name, dialect
1735                ),
1736                validation_codes::E_UNKNOWN_FUNCTION,
1737            )
1738        });
1739        return;
1740    };
1741
1742    if signatures.iter().any(|sig| sig.matches_arity(arity)) {
1743        return;
1744    }
1745
1746    let expected = signatures
1747        .iter()
1748        .map(|sig| sig.describe_arity())
1749        .collect::<Vec<_>>()
1750        .join(", ");
1751    errors.push(if strict {
1752        ValidationError::error(
1753            format!(
1754                "Invalid arity for function '{}': got {}, expected {}",
1755                function.name, arity, expected
1756            ),
1757            validation_codes::E_INVALID_FUNCTION_ARITY,
1758        )
1759    } else {
1760        ValidationError::warning(
1761            format!(
1762                "Invalid arity for function '{}': got {}, expected {}",
1763                function.name, arity, expected
1764            ),
1765            validation_codes::E_INVALID_FUNCTION_ARITY,
1766        )
1767    });
1768}
1769
/// A single column-to-column reference declared in the validation schema.
///
/// Multi-column foreign keys are represented as one value per column pair.
/// Equality and hashing cover all four fields, so identical declarations
/// collapse when stored in a set.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
struct DeclaredRelationship {
    /// Key of the table that declares the reference.
    source_table: String,
    /// Referencing column in `source_table`.
    source_column: String,
    /// Key of the table being referenced.
    target_table: String,
    /// Referenced column in `target_table`.
    target_column: String,
}
1777
1778fn build_declared_relationships(
1779    schema: &ValidationSchema,
1780    schema_map: &HashMap<String, TableSchemaEntry>,
1781) -> Vec<DeclaredRelationship> {
1782    let mut relationships = HashSet::new();
1783
1784    for table in &schema.tables {
1785        let Some(source_key) =
1786            resolve_reference_table_key(&table.name, table.schema.as_deref(), None, schema_map)
1787        else {
1788            continue;
1789        };
1790
1791        for column in &table.columns {
1792            let Some(reference) = &column.references else {
1793                continue;
1794            };
1795            let Some(target_key) = resolve_reference_table_key(
1796                &reference.table,
1797                reference.schema.as_deref(),
1798                table.schema.as_deref(),
1799                schema_map,
1800            ) else {
1801                continue;
1802            };
1803
1804            relationships.insert(DeclaredRelationship {
1805                source_table: source_key.clone(),
1806                source_column: lower(&column.name),
1807                target_table: target_key,
1808                target_column: lower(&reference.column),
1809            });
1810        }
1811
1812        for foreign_key in &table.foreign_keys {
1813            if foreign_key.columns.len() != foreign_key.references.columns.len() {
1814                continue;
1815            }
1816            let Some(target_key) = resolve_reference_table_key(
1817                &foreign_key.references.table,
1818                foreign_key.references.schema.as_deref(),
1819                table.schema.as_deref(),
1820                schema_map,
1821            ) else {
1822                continue;
1823            };
1824
1825            for (source_col, target_col) in foreign_key
1826                .columns
1827                .iter()
1828                .zip(foreign_key.references.columns.iter())
1829            {
1830                relationships.insert(DeclaredRelationship {
1831                    source_table: source_key.clone(),
1832                    source_column: lower(source_col),
1833                    target_table: target_key.clone(),
1834                    target_column: lower(target_col),
1835                });
1836            }
1837        }
1838    }
1839
1840    relationships.into_iter().collect()
1841}
1842
1843fn resolve_column_binding(
1844    column: &Column,
1845    schema_map: &HashMap<String, TableSchemaEntry>,
1846    context: &TypeCheckContext,
1847    resolver: &mut Resolver<'_>,
1848) -> Option<(String, String)> {
1849    let column_name = lower(&column.name.name);
1850    if column_name.is_empty() {
1851        return None;
1852    }
1853
1854    if let Some(table) = &column.table {
1855        let mut table_key = lower(&table.name);
1856        if let Some(mapped) = context.table_aliases.get(&table_key) {
1857            table_key = mapped.clone();
1858        }
1859        if schema_map.contains_key(&table_key) {
1860            return Some((table_key, column_name));
1861        }
1862        return None;
1863    }
1864
1865    if let Some(resolved_source) = resolver.get_table(&column_name) {
1866        let mut table_key = lower(&resolved_source);
1867        if let Some(mapped) = context.table_aliases.get(&table_key) {
1868            table_key = mapped.clone();
1869        }
1870        if schema_map.contains_key(&table_key) {
1871            return Some((table_key, column_name));
1872        }
1873    }
1874
1875    let candidates: Vec<String> = context
1876        .referenced_tables
1877        .iter()
1878        .filter_map(|table_name| {
1879            schema_map
1880                .get(table_name)
1881                .filter(|entry| entry.columns.contains_key(&column_name))
1882                .map(|_| table_name.clone())
1883        })
1884        .collect();
1885    if candidates.len() == 1 {
1886        return Some((candidates[0].clone(), column_name));
1887    }
1888    None
1889}
1890
1891fn extract_join_equality_pairs(
1892    expr: &Expression,
1893    schema_map: &HashMap<String, TableSchemaEntry>,
1894    context: &TypeCheckContext,
1895    resolver: &mut Resolver<'_>,
1896    pairs: &mut Vec<((String, String), (String, String))>,
1897) {
1898    match expr {
1899        Expression::And(op) => {
1900            extract_join_equality_pairs(&op.left, schema_map, context, resolver, pairs);
1901            extract_join_equality_pairs(&op.right, schema_map, context, resolver, pairs);
1902        }
1903        Expression::Paren(paren) => {
1904            extract_join_equality_pairs(&paren.this, schema_map, context, resolver, pairs);
1905        }
1906        Expression::Eq(op) => {
1907            let (Expression::Column(left_col), Expression::Column(right_col)) =
1908                (&op.left, &op.right)
1909            else {
1910                return;
1911            };
1912            let Some(left) = resolve_column_binding(left_col, schema_map, context, resolver) else {
1913                return;
1914            };
1915            let Some(right) = resolve_column_binding(right_col, schema_map, context, resolver)
1916            else {
1917                return;
1918            };
1919            pairs.push((left, right));
1920        }
1921        _ => {}
1922    }
1923}
1924
1925fn relationship_matches_pair(
1926    relationship: &DeclaredRelationship,
1927    left_table: &str,
1928    left_column: &str,
1929    right_table: &str,
1930    right_column: &str,
1931) -> bool {
1932    (relationship.source_table == left_table
1933        && relationship.source_column == left_column
1934        && relationship.target_table == right_table
1935        && relationship.target_column == right_column)
1936        || (relationship.source_table == right_table
1937            && relationship.source_column == right_column
1938            && relationship.target_table == left_table
1939            && relationship.target_column == left_column)
1940}
1941
1942fn resolved_table_key_from_expr(
1943    expr: &Expression,
1944    schema_map: &HashMap<String, TableSchemaEntry>,
1945) -> Option<String> {
1946    match expr {
1947        Expression::Table(table) => resolve_table_schema_entry(table, schema_map).map(|(k, _)| k),
1948        Expression::Alias(alias) => resolved_table_key_from_expr(&alias.this, schema_map),
1949        _ => None,
1950    }
1951}
1952
1953fn select_from_table_keys(
1954    select: &crate::expressions::Select,
1955    schema_map: &HashMap<String, TableSchemaEntry>,
1956) -> HashSet<String> {
1957    let mut keys = HashSet::new();
1958    if let Some(from_clause) = &select.from {
1959        for expr in &from_clause.expressions {
1960            if let Some(key) = resolved_table_key_from_expr(expr, schema_map) {
1961                keys.insert(key);
1962            }
1963        }
1964    }
1965    keys
1966}
1967
1968fn is_natural_or_implied_join(kind: JoinKind) -> bool {
1969    matches!(
1970        kind,
1971        JoinKind::Natural
1972            | JoinKind::NaturalLeft
1973            | JoinKind::NaturalRight
1974            | JoinKind::NaturalFull
1975            | JoinKind::CrossApply
1976            | JoinKind::OuterApply
1977            | JoinKind::AsOf
1978            | JoinKind::AsOfLeft
1979            | JoinKind::AsOfRight
1980            | JoinKind::Lateral
1981            | JoinKind::LeftLateral
1982    )
1983}
1984
1985fn check_query_reference_quality(
1986    stmt: &Expression,
1987    schema_map: &HashMap<String, TableSchemaEntry>,
1988    resolver_schema: &MappingSchema,
1989    strict: bool,
1990    relationships: &[DeclaredRelationship],
1991) -> Vec<ValidationError> {
1992    let mut errors = Vec::new();
1993
1994    for node in stmt.dfs() {
1995        let Expression::Select(select) = node else {
1996            continue;
1997        };
1998
1999        let select_expr = Expression::Select(select.clone());
2000        let context = collect_type_check_context(&select_expr, schema_map);
2001        let scope = build_scope(&select_expr);
2002        let mut resolver = Resolver::new(&scope, resolver_schema, true);
2003
2004        if context.referenced_tables.len() > 1 {
2005            let using_columns: HashSet<String> = select
2006                .joins
2007                .iter()
2008                .flat_map(|join| join.using.iter().map(|id| lower(&id.name)))
2009                .collect();
2010
2011            let mut seen = HashSet::new();
2012            for column_expr in select_expr
2013                .find_all(|e| matches!(e, Expression::Column(Column { table: None, .. })))
2014            {
2015                let Expression::Column(column) = column_expr else {
2016                    continue;
2017                };
2018
2019                let col_name = lower(&column.name.name);
2020                if col_name.is_empty()
2021                    || using_columns.contains(&col_name)
2022                    || !seen.insert(col_name.clone())
2023                {
2024                    continue;
2025                }
2026
2027                if resolver.is_ambiguous(&col_name) {
2028                    let source_count = resolver.sources_for_column(&col_name).len();
2029                    errors.push(if strict {
2030                        ValidationError::error(
2031                            format!(
2032                                "Ambiguous unqualified column '{}' found in {} referenced tables",
2033                                col_name, source_count
2034                            ),
2035                            validation_codes::E_AMBIGUOUS_COLUMN_REFERENCE,
2036                        )
2037                    } else {
2038                        ValidationError::warning(
2039                            format!(
2040                                "Ambiguous unqualified column '{}' found in {} referenced tables",
2041                                col_name, source_count
2042                            ),
2043                            validation_codes::W_WEAK_REFERENCE_INTEGRITY,
2044                        )
2045                    });
2046                }
2047            }
2048        }
2049
2050        let mut cumulative_left_tables = select_from_table_keys(select, schema_map);
2051
2052        for join in &select.joins {
2053            let right_table_key = resolved_table_key_from_expr(&join.this, schema_map);
2054            let has_explicit_condition = join.on.is_some() || !join.using.is_empty();
2055            let cartesian_like_kind = matches!(
2056                join.kind,
2057                JoinKind::Cross
2058                    | JoinKind::Implicit
2059                    | JoinKind::Array
2060                    | JoinKind::LeftArray
2061                    | JoinKind::Paste
2062            );
2063
2064            if right_table_key.is_some()
2065                && (cartesian_like_kind
2066                    || (!has_explicit_condition && !is_natural_or_implied_join(join.kind)))
2067            {
2068                errors.push(ValidationError::warning(
2069                    "Potential cartesian join: JOIN without ON/USING condition",
2070                    validation_codes::W_CARTESIAN_JOIN,
2071                ));
2072            }
2073
2074            if let (Some(on_expr), Some(right_key)) = (&join.on, right_table_key.clone()) {
2075                if join.using.is_empty() {
2076                    let mut eq_pairs = Vec::new();
2077                    extract_join_equality_pairs(
2078                        on_expr,
2079                        schema_map,
2080                        &context,
2081                        &mut resolver,
2082                        &mut eq_pairs,
2083                    );
2084
2085                    let relevant_relationships: Vec<&DeclaredRelationship> = relationships
2086                        .iter()
2087                        .filter(|rel| {
2088                            cumulative_left_tables.contains(&rel.source_table)
2089                                && rel.target_table == right_key
2090                                || (cumulative_left_tables.contains(&rel.target_table)
2091                                    && rel.source_table == right_key)
2092                        })
2093                        .collect();
2094
2095                    if !relevant_relationships.is_empty() {
2096                        let uses_declared_fk = eq_pairs.iter().any(|((lt, lc), (rt, rc))| {
2097                            relevant_relationships
2098                                .iter()
2099                                .any(|rel| relationship_matches_pair(rel, lt, lc, rt, rc))
2100                        });
2101                        if !uses_declared_fk {
2102                            errors.push(ValidationError::warning(
2103                                "JOIN predicate does not use declared foreign-key relationship columns",
2104                                validation_codes::W_JOIN_NOT_USING_DECLARED_REFERENCE,
2105                            ));
2106                        }
2107                    }
2108                }
2109            }
2110
2111            if let Some(right_key) = right_table_key {
2112                cumulative_left_tables.insert(right_key);
2113            }
2114        }
2115    }
2116
2117    errors
2118}
2119
2120fn are_setop_compatible(left: TypeFamily, right: TypeFamily) -> bool {
2121    if left == TypeFamily::Unknown || right == TypeFamily::Unknown {
2122        return true;
2123    }
2124    if left == right {
2125        return true;
2126    }
2127    if left.is_numeric() && right.is_numeric() {
2128        return true;
2129    }
2130    if left.is_temporal() && right.is_temporal() {
2131        return true;
2132    }
2133    false
2134}
2135
2136fn merged_setop_family(left: TypeFamily, right: TypeFamily) -> TypeFamily {
2137    if left == TypeFamily::Unknown {
2138        return right;
2139    }
2140    if right == TypeFamily::Unknown {
2141        return left;
2142    }
2143    if left == right {
2144        return left;
2145    }
2146    if left.is_numeric() && right.is_numeric() {
2147        if left == TypeFamily::Numeric || right == TypeFamily::Numeric {
2148            return TypeFamily::Numeric;
2149        }
2150        return TypeFamily::Integer;
2151    }
2152    if left.is_temporal() && right.is_temporal() {
2153        if left == TypeFamily::Timestamp || right == TypeFamily::Timestamp {
2154            return TypeFamily::Timestamp;
2155        }
2156        if left == TypeFamily::Date || right == TypeFamily::Date {
2157            return TypeFamily::Date;
2158        }
2159        return TypeFamily::Time;
2160    }
2161    TypeFamily::Unknown
2162}
2163
2164fn are_assignment_compatible(target: TypeFamily, source: TypeFamily) -> bool {
2165    if target == TypeFamily::Unknown || source == TypeFamily::Unknown {
2166        return true;
2167    }
2168    if target == source {
2169        return true;
2170    }
2171
2172    match target {
2173        TypeFamily::Boolean => source == TypeFamily::Boolean,
2174        TypeFamily::Integer | TypeFamily::Numeric => source.is_numeric(),
2175        TypeFamily::Date | TypeFamily::Time | TypeFamily::Timestamp | TypeFamily::Interval => {
2176            source.is_temporal()
2177        }
2178        TypeFamily::String => true,
2179        TypeFamily::Binary => matches!(source, TypeFamily::Binary | TypeFamily::String),
2180        TypeFamily::Json => matches!(source, TypeFamily::Json | TypeFamily::String),
2181        TypeFamily::Uuid => matches!(source, TypeFamily::Uuid | TypeFamily::String),
2182        TypeFamily::Array => source == TypeFamily::Array,
2183        TypeFamily::Map => source == TypeFamily::Map,
2184        TypeFamily::Struct => source == TypeFamily::Struct,
2185        TypeFamily::Unknown => true,
2186    }
2187}
2188
2189fn projection_families(
2190    query_expr: &Expression,
2191    schema_map: &HashMap<String, TableSchemaEntry>,
2192) -> Option<Vec<TypeFamily>> {
2193    match query_expr {
2194        Expression::Select(select) => {
2195            if select
2196                .expressions
2197                .iter()
2198                .any(|e| matches!(e, Expression::Star(_) | Expression::BracedWildcard(_)))
2199            {
2200                return None;
2201            }
2202            let select_expr = Expression::Select(select.clone());
2203            let context = collect_type_check_context(&select_expr, schema_map);
2204            Some(
2205                select
2206                    .expressions
2207                    .iter()
2208                    .map(|e| infer_expression_type_family(e, schema_map, &context))
2209                    .collect(),
2210            )
2211        }
2212        Expression::Subquery(subquery) => projection_families(&subquery.this, schema_map),
2213        Expression::Union(union) => {
2214            let left = projection_families(&union.left, schema_map)?;
2215            let right = projection_families(&union.right, schema_map)?;
2216            if left.len() != right.len() {
2217                return None;
2218            }
2219            Some(
2220                left.into_iter()
2221                    .zip(right)
2222                    .map(|(l, r)| merged_setop_family(l, r))
2223                    .collect(),
2224            )
2225        }
2226        Expression::Intersect(intersect) => {
2227            let left = projection_families(&intersect.left, schema_map)?;
2228            let right = projection_families(&intersect.right, schema_map)?;
2229            if left.len() != right.len() {
2230                return None;
2231            }
2232            Some(
2233                left.into_iter()
2234                    .zip(right)
2235                    .map(|(l, r)| merged_setop_family(l, r))
2236                    .collect(),
2237            )
2238        }
2239        Expression::Except(except) => {
2240            let left = projection_families(&except.left, schema_map)?;
2241            let right = projection_families(&except.right, schema_map)?;
2242            if left.len() != right.len() {
2243                return None;
2244            }
2245            Some(
2246                left.into_iter()
2247                    .zip(right)
2248                    .map(|(l, r)| merged_setop_family(l, r))
2249                    .collect(),
2250            )
2251        }
2252        Expression::Values(values) => {
2253            let first_row = values.expressions.first()?;
2254            let context = TypeCheckContext::default();
2255            Some(
2256                first_row
2257                    .expressions
2258                    .iter()
2259                    .map(|e| infer_expression_type_family(e, schema_map, &context))
2260                    .collect(),
2261            )
2262        }
2263        _ => None,
2264    }
2265}
2266
2267fn check_set_operation_compatibility(
2268    op_name: &str,
2269    left_expr: &Expression,
2270    right_expr: &Expression,
2271    schema_map: &HashMap<String, TableSchemaEntry>,
2272    strict: bool,
2273    errors: &mut Vec<ValidationError>,
2274) {
2275    let Some(left_projection) = projection_families(left_expr, schema_map) else {
2276        return;
2277    };
2278    let Some(right_projection) = projection_families(right_expr, schema_map) else {
2279        return;
2280    };
2281
2282    if left_projection.len() != right_projection.len() {
2283        errors.push(type_issue(
2284            strict,
2285            validation_codes::E_SETOP_ARITY_MISMATCH,
2286            validation_codes::W_SETOP_IMPLICIT_COERCION,
2287            format!(
2288                "{} operands return different column counts: left {}, right {}",
2289                op_name,
2290                left_projection.len(),
2291                right_projection.len()
2292            ),
2293        ));
2294        return;
2295    }
2296
2297    for (idx, (left, right)) in left_projection
2298        .into_iter()
2299        .zip(right_projection)
2300        .enumerate()
2301    {
2302        if !are_setop_compatible(left, right) {
2303            errors.push(type_issue(
2304                strict,
2305                validation_codes::E_SETOP_TYPE_MISMATCH,
2306                validation_codes::W_SETOP_IMPLICIT_COERCION,
2307                format!(
2308                    "{} column {} has incompatible types: {} vs {}",
2309                    op_name,
2310                    idx + 1,
2311                    type_family_name(left),
2312                    type_family_name(right)
2313                ),
2314            ));
2315        }
2316    }
2317}
2318
2319fn check_insert_assignments(
2320    stmt: &Expression,
2321    insert: &Insert,
2322    schema_map: &HashMap<String, TableSchemaEntry>,
2323    strict: bool,
2324    errors: &mut Vec<ValidationError>,
2325) {
2326    let Some((target_table_key, table_schema)) =
2327        resolve_table_schema_entry(&insert.table, schema_map)
2328    else {
2329        return;
2330    };
2331
2332    let mut target_columns = Vec::new();
2333    if insert.columns.is_empty() {
2334        target_columns.extend(table_schema.column_order.iter().cloned());
2335    } else {
2336        for column in &insert.columns {
2337            let col_name = lower(&column.name);
2338            if table_schema.columns.contains_key(&col_name) {
2339                target_columns.push(col_name);
2340            } else {
2341                errors.push(if strict {
2342                    ValidationError::error(
2343                        format!(
2344                            "Unknown column '{}' in table '{}'",
2345                            column.name, target_table_key
2346                        ),
2347                        validation_codes::E_UNKNOWN_COLUMN,
2348                    )
2349                } else {
2350                    ValidationError::warning(
2351                        format!(
2352                            "Unknown column '{}' in table '{}'",
2353                            column.name, target_table_key
2354                        ),
2355                        validation_codes::E_UNKNOWN_COLUMN,
2356                    )
2357                });
2358            }
2359        }
2360    }
2361
2362    if target_columns.is_empty() {
2363        return;
2364    }
2365
2366    let context = collect_type_check_context(stmt, schema_map);
2367
2368    if !insert.default_values {
2369        for (row_idx, row) in insert.values.iter().enumerate() {
2370            if row.len() != target_columns.len() {
2371                errors.push(type_issue(
2372                    strict,
2373                    validation_codes::E_INVALID_ASSIGNMENT_TYPE,
2374                    validation_codes::W_IMPLICIT_CAST_ASSIGNMENT,
2375                    format!(
2376                        "INSERT row {} has {} values but target has {} columns",
2377                        row_idx + 1,
2378                        row.len(),
2379                        target_columns.len()
2380                    ),
2381                ));
2382                continue;
2383            }
2384
2385            for (value, target_column) in row.iter().zip(target_columns.iter()) {
2386                let Some(target_family) = table_schema.columns.get(target_column).copied() else {
2387                    continue;
2388                };
2389                let source_family = infer_expression_type_family(value, schema_map, &context);
2390                if !are_assignment_compatible(target_family, source_family) {
2391                    errors.push(type_issue(
2392                        strict,
2393                        validation_codes::E_INVALID_ASSIGNMENT_TYPE,
2394                        validation_codes::W_IMPLICIT_CAST_ASSIGNMENT,
2395                        format!(
2396                            "INSERT assignment type mismatch for '{}.{}': expected {}, found {}",
2397                            target_table_key,
2398                            target_column,
2399                            type_family_name(target_family),
2400                            type_family_name(source_family)
2401                        ),
2402                    ));
2403                }
2404            }
2405        }
2406    }
2407
2408    if let Some(query) = &insert.query {
2409        // DuckDB BY NAME maps source columns by name, not position.
2410        if insert.by_name {
2411            return;
2412        }
2413
2414        let Some(source_projection) = projection_families(query, schema_map) else {
2415            return;
2416        };
2417
2418        if source_projection.len() != target_columns.len() {
2419            errors.push(type_issue(
2420                strict,
2421                validation_codes::E_INVALID_ASSIGNMENT_TYPE,
2422                validation_codes::W_IMPLICIT_CAST_ASSIGNMENT,
2423                format!(
2424                    "INSERT source query has {} columns but target has {} columns",
2425                    source_projection.len(),
2426                    target_columns.len()
2427                ),
2428            ));
2429            return;
2430        }
2431
2432        for (source_family, target_column) in
2433            source_projection.into_iter().zip(target_columns.iter())
2434        {
2435            let Some(target_family) = table_schema.columns.get(target_column).copied() else {
2436                continue;
2437            };
2438            if !are_assignment_compatible(target_family, source_family) {
2439                errors.push(type_issue(
2440                    strict,
2441                    validation_codes::E_INVALID_ASSIGNMENT_TYPE,
2442                    validation_codes::W_IMPLICIT_CAST_ASSIGNMENT,
2443                    format!(
2444                        "INSERT assignment type mismatch for '{}.{}': expected {}, found {}",
2445                        target_table_key,
2446                        target_column,
2447                        type_family_name(target_family),
2448                        type_family_name(source_family)
2449                    ),
2450                ));
2451            }
2452        }
2453    }
2454}
2455
2456fn check_update_assignments(
2457    stmt: &Expression,
2458    update: &Update,
2459    schema_map: &HashMap<String, TableSchemaEntry>,
2460    strict: bool,
2461    errors: &mut Vec<ValidationError>,
2462) {
2463    let Some((target_table_key, table_schema)) =
2464        resolve_table_schema_entry(&update.table, schema_map)
2465    else {
2466        return;
2467    };
2468
2469    let context = collect_type_check_context(stmt, schema_map);
2470
2471    for (column, value) in &update.set {
2472        let col_name = lower(&column.name);
2473        let Some(target_family) = table_schema.columns.get(&col_name).copied() else {
2474            errors.push(if strict {
2475                ValidationError::error(
2476                    format!(
2477                        "Unknown column '{}' in table '{}'",
2478                        column.name, target_table_key
2479                    ),
2480                    validation_codes::E_UNKNOWN_COLUMN,
2481                )
2482            } else {
2483                ValidationError::warning(
2484                    format!(
2485                        "Unknown column '{}' in table '{}'",
2486                        column.name, target_table_key
2487                    ),
2488                    validation_codes::E_UNKNOWN_COLUMN,
2489                )
2490            });
2491            continue;
2492        };
2493
2494        let source_family = infer_expression_type_family(value, schema_map, &context);
2495        if !are_assignment_compatible(target_family, source_family) {
2496            errors.push(type_issue(
2497                strict,
2498                validation_codes::E_INVALID_ASSIGNMENT_TYPE,
2499                validation_codes::W_IMPLICIT_CAST_ASSIGNMENT,
2500                format!(
2501                    "UPDATE assignment type mismatch for '{}.{}': expected {}, found {}",
2502                    target_table_key,
2503                    col_name,
2504                    type_family_name(target_family),
2505                    type_family_name(source_family)
2506                ),
2507            ));
2508        }
2509    }
2510}
2511
2512fn check_types(
2513    stmt: &Expression,
2514    dialect: DialectType,
2515    schema_map: &HashMap<String, TableSchemaEntry>,
2516    function_catalog: Option<&dyn FunctionCatalog>,
2517    strict: bool,
2518) -> Vec<ValidationError> {
2519    let mut errors = Vec::new();
2520    let context = collect_type_check_context(stmt, schema_map);
2521
2522    for node in stmt.dfs() {
2523        match node {
2524            Expression::Insert(insert) => {
2525                check_insert_assignments(stmt, insert, schema_map, strict, &mut errors);
2526            }
2527            Expression::Update(update) => {
2528                check_update_assignments(stmt, update, schema_map, strict, &mut errors);
2529            }
2530            Expression::Union(union) => {
2531                check_set_operation_compatibility(
2532                    "UNION",
2533                    &union.left,
2534                    &union.right,
2535                    schema_map,
2536                    strict,
2537                    &mut errors,
2538                );
2539            }
2540            Expression::Intersect(intersect) => {
2541                check_set_operation_compatibility(
2542                    "INTERSECT",
2543                    &intersect.left,
2544                    &intersect.right,
2545                    schema_map,
2546                    strict,
2547                    &mut errors,
2548                );
2549            }
2550            Expression::Except(except) => {
2551                check_set_operation_compatibility(
2552                    "EXCEPT",
2553                    &except.left,
2554                    &except.right,
2555                    schema_map,
2556                    strict,
2557                    &mut errors,
2558                );
2559            }
2560            Expression::Select(select) => {
2561                if let Some(prewhere) = &select.prewhere {
2562                    let family = infer_expression_type_family(prewhere, schema_map, &context);
2563                    if family != TypeFamily::Unknown && family != TypeFamily::Boolean {
2564                        errors.push(type_issue(
2565                            strict,
2566                            validation_codes::E_INVALID_PREDICATE_TYPE,
2567                            validation_codes::W_PREDICATE_NULLABILITY,
2568                            format!(
2569                                "PREWHERE clause expects a boolean predicate, found {}",
2570                                type_family_name(family)
2571                            ),
2572                        ));
2573                    }
2574                }
2575
2576                if let Some(where_clause) = &select.where_clause {
2577                    let family =
2578                        infer_expression_type_family(&where_clause.this, schema_map, &context);
2579                    if family != TypeFamily::Unknown && family != TypeFamily::Boolean {
2580                        errors.push(type_issue(
2581                            strict,
2582                            validation_codes::E_INVALID_PREDICATE_TYPE,
2583                            validation_codes::W_PREDICATE_NULLABILITY,
2584                            format!(
2585                                "WHERE clause expects a boolean predicate, found {}",
2586                                type_family_name(family)
2587                            ),
2588                        ));
2589                    }
2590                }
2591
2592                if let Some(having_clause) = &select.having {
2593                    let family =
2594                        infer_expression_type_family(&having_clause.this, schema_map, &context);
2595                    if family != TypeFamily::Unknown && family != TypeFamily::Boolean {
2596                        errors.push(type_issue(
2597                            strict,
2598                            validation_codes::E_INVALID_PREDICATE_TYPE,
2599                            validation_codes::W_PREDICATE_NULLABILITY,
2600                            format!(
2601                                "HAVING clause expects a boolean predicate, found {}",
2602                                type_family_name(family)
2603                            ),
2604                        ));
2605                    }
2606                }
2607
2608                for join in &select.joins {
2609                    if let Some(on) = &join.on {
2610                        let family = infer_expression_type_family(on, schema_map, &context);
2611                        if family != TypeFamily::Unknown && family != TypeFamily::Boolean {
2612                            errors.push(type_issue(
2613                                strict,
2614                                validation_codes::E_INVALID_PREDICATE_TYPE,
2615                                validation_codes::W_PREDICATE_NULLABILITY,
2616                                format!(
2617                                    "JOIN ON expects a boolean predicate, found {}",
2618                                    type_family_name(family)
2619                                ),
2620                            ));
2621                        }
2622                    }
2623                    if let Some(match_condition) = &join.match_condition {
2624                        let family =
2625                            infer_expression_type_family(match_condition, schema_map, &context);
2626                        if family != TypeFamily::Unknown && family != TypeFamily::Boolean {
2627                            errors.push(type_issue(
2628                                strict,
2629                                validation_codes::E_INVALID_PREDICATE_TYPE,
2630                                validation_codes::W_PREDICATE_NULLABILITY,
2631                                format!(
2632                                    "JOIN MATCH_CONDITION expects a boolean predicate, found {}",
2633                                    type_family_name(family)
2634                                ),
2635                            ));
2636                        }
2637                    }
2638                }
2639            }
2640            Expression::Where(where_clause) => {
2641                let family = infer_expression_type_family(&where_clause.this, schema_map, &context);
2642                if family != TypeFamily::Unknown && family != TypeFamily::Boolean {
2643                    errors.push(type_issue(
2644                        strict,
2645                        validation_codes::E_INVALID_PREDICATE_TYPE,
2646                        validation_codes::W_PREDICATE_NULLABILITY,
2647                        format!(
2648                            "WHERE clause expects a boolean predicate, found {}",
2649                            type_family_name(family)
2650                        ),
2651                    ));
2652                }
2653            }
2654            Expression::Having(having_clause) => {
2655                let family =
2656                    infer_expression_type_family(&having_clause.this, schema_map, &context);
2657                if family != TypeFamily::Unknown && family != TypeFamily::Boolean {
2658                    errors.push(type_issue(
2659                        strict,
2660                        validation_codes::E_INVALID_PREDICATE_TYPE,
2661                        validation_codes::W_PREDICATE_NULLABILITY,
2662                        format!(
2663                            "HAVING clause expects a boolean predicate, found {}",
2664                            type_family_name(family)
2665                        ),
2666                    ));
2667                }
2668            }
2669            Expression::And(op) | Expression::Or(op) => {
2670                for (side, expr) in [("left", &op.left), ("right", &op.right)] {
2671                    let family = infer_expression_type_family(expr, schema_map, &context);
2672                    if family != TypeFamily::Unknown && family != TypeFamily::Boolean {
2673                        errors.push(type_issue(
2674                            strict,
2675                            validation_codes::E_INVALID_PREDICATE_TYPE,
2676                            validation_codes::W_PREDICATE_NULLABILITY,
2677                            format!(
2678                                "Logical {} operand expects boolean, found {}",
2679                                side,
2680                                type_family_name(family)
2681                            ),
2682                        ));
2683                    }
2684                }
2685            }
2686            Expression::Not(unary) => {
2687                let family = infer_expression_type_family(&unary.this, schema_map, &context);
2688                if family != TypeFamily::Unknown && family != TypeFamily::Boolean {
2689                    errors.push(type_issue(
2690                        strict,
2691                        validation_codes::E_INVALID_PREDICATE_TYPE,
2692                        validation_codes::W_PREDICATE_NULLABILITY,
2693                        format!("NOT expects boolean, found {}", type_family_name(family)),
2694                    ));
2695                }
2696            }
2697            Expression::Eq(op)
2698            | Expression::Neq(op)
2699            | Expression::Lt(op)
2700            | Expression::Lte(op)
2701            | Expression::Gt(op)
2702            | Expression::Gte(op) => {
2703                let left = infer_expression_type_family(&op.left, schema_map, &context);
2704                let right = infer_expression_type_family(&op.right, schema_map, &context);
2705                if !are_comparable(left, right) {
2706                    errors.push(type_issue(
2707                        strict,
2708                        validation_codes::E_INCOMPATIBLE_COMPARISON_TYPES,
2709                        validation_codes::W_IMPLICIT_CAST_COMPARISON,
2710                        format!(
2711                            "Incompatible comparison between {} and {}",
2712                            type_family_name(left),
2713                            type_family_name(right)
2714                        ),
2715                    ));
2716                }
2717            }
2718            Expression::Like(op) | Expression::ILike(op) => {
2719                let left = infer_expression_type_family(&op.left, schema_map, &context);
2720                let right = infer_expression_type_family(&op.right, schema_map, &context);
2721                if left != TypeFamily::Unknown
2722                    && right != TypeFamily::Unknown
2723                    && (!is_string_like(left) || !is_string_like(right))
2724                {
2725                    errors.push(type_issue(
2726                        strict,
2727                        validation_codes::E_INCOMPATIBLE_COMPARISON_TYPES,
2728                        validation_codes::W_IMPLICIT_CAST_COMPARISON,
2729                        format!(
2730                            "LIKE/ILIKE expects string operands, found {} and {}",
2731                            type_family_name(left),
2732                            type_family_name(right)
2733                        ),
2734                    ));
2735                }
2736            }
2737            Expression::Between(between) => {
2738                let this_family = infer_expression_type_family(&between.this, schema_map, &context);
2739                let low_family = infer_expression_type_family(&between.low, schema_map, &context);
2740                let high_family = infer_expression_type_family(&between.high, schema_map, &context);
2741
2742                if !are_comparable(this_family, low_family)
2743                    || !are_comparable(this_family, high_family)
2744                {
2745                    errors.push(type_issue(
2746                        strict,
2747                        validation_codes::E_INCOMPATIBLE_COMPARISON_TYPES,
2748                        validation_codes::W_IMPLICIT_CAST_COMPARISON,
2749                        format!(
2750                            "BETWEEN bounds are incompatible with {} (found {} and {})",
2751                            type_family_name(this_family),
2752                            type_family_name(low_family),
2753                            type_family_name(high_family)
2754                        ),
2755                    ));
2756                }
2757            }
2758            Expression::In(in_expr) => {
2759                let this_family = infer_expression_type_family(&in_expr.this, schema_map, &context);
2760                for value in &in_expr.expressions {
2761                    let item_family = infer_expression_type_family(value, schema_map, &context);
2762                    if !are_comparable(this_family, item_family) {
2763                        errors.push(type_issue(
2764                            strict,
2765                            validation_codes::E_INCOMPATIBLE_COMPARISON_TYPES,
2766                            validation_codes::W_IMPLICIT_CAST_COMPARISON,
2767                            format!(
2768                                "IN item type {} is incompatible with {}",
2769                                type_family_name(item_family),
2770                                type_family_name(this_family)
2771                            ),
2772                        ));
2773                        break;
2774                    }
2775                }
2776            }
2777            Expression::Add(op)
2778            | Expression::Sub(op)
2779            | Expression::Mul(op)
2780            | Expression::Div(op)
2781            | Expression::Mod(op) => {
2782                let left = infer_expression_type_family(&op.left, schema_map, &context);
2783                let right = infer_expression_type_family(&op.right, schema_map, &context);
2784
2785                if left == TypeFamily::Unknown || right == TypeFamily::Unknown {
2786                    continue;
2787                }
2788
2789                let temporal_ok = matches!(node, Expression::Add(_) | Expression::Sub(_))
2790                    && ((left.is_temporal() && right.is_numeric())
2791                        || (right.is_temporal() && left.is_numeric())
2792                        || (matches!(node, Expression::Sub(_))
2793                            && left.is_temporal()
2794                            && right.is_temporal()));
2795
2796                if !(left.is_numeric() && right.is_numeric()) && !temporal_ok {
2797                    errors.push(type_issue(
2798                        strict,
2799                        validation_codes::E_INVALID_ARITHMETIC_TYPE,
2800                        validation_codes::W_IMPLICIT_CAST_ARITHMETIC,
2801                        format!(
2802                            "Arithmetic operation expects numeric-compatible operands, found {} and {}",
2803                            type_family_name(left),
2804                            type_family_name(right)
2805                        ),
2806                    ));
2807                }
2808            }
2809            Expression::Function(function) => {
2810                check_function_catalog(function, dialect, function_catalog, strict, &mut errors);
2811                check_generic_function(function, schema_map, &context, strict, &mut errors);
2812            }
2813            Expression::Upper(func)
2814            | Expression::Lower(func)
2815            | Expression::LTrim(func)
2816            | Expression::RTrim(func)
2817            | Expression::Reverse(func) => {
2818                let family = infer_expression_type_family(&func.this, schema_map, &context);
2819                check_function_argument(
2820                    &mut errors,
2821                    strict,
2822                    "string_function",
2823                    0,
2824                    family,
2825                    "a string argument",
2826                    is_string_like(family),
2827                );
2828            }
2829            Expression::Length(func) => {
2830                let family = infer_expression_type_family(&func.this, schema_map, &context);
2831                check_function_argument(
2832                    &mut errors,
2833                    strict,
2834                    "length",
2835                    0,
2836                    family,
2837                    "a string or binary argument",
2838                    is_string_or_binary(family),
2839                );
2840            }
2841            Expression::Trim(func) => {
2842                let this_family = infer_expression_type_family(&func.this, schema_map, &context);
2843                check_function_argument(
2844                    &mut errors,
2845                    strict,
2846                    "trim",
2847                    0,
2848                    this_family,
2849                    "a string argument",
2850                    is_string_like(this_family),
2851                );
2852                if let Some(chars) = &func.characters {
2853                    let chars_family = infer_expression_type_family(chars, schema_map, &context);
2854                    check_function_argument(
2855                        &mut errors,
2856                        strict,
2857                        "trim",
2858                        1,
2859                        chars_family,
2860                        "a string argument",
2861                        is_string_like(chars_family),
2862                    );
2863                }
2864            }
2865            Expression::Substring(func) => {
2866                let this_family = infer_expression_type_family(&func.this, schema_map, &context);
2867                check_function_argument(
2868                    &mut errors,
2869                    strict,
2870                    "substring",
2871                    0,
2872                    this_family,
2873                    "a string argument",
2874                    is_string_like(this_family),
2875                );
2876
2877                let start_family = infer_expression_type_family(&func.start, schema_map, &context);
2878                check_function_argument(
2879                    &mut errors,
2880                    strict,
2881                    "substring",
2882                    1,
2883                    start_family,
2884                    "a numeric argument",
2885                    start_family.is_numeric(),
2886                );
2887                if let Some(length) = &func.length {
2888                    let length_family = infer_expression_type_family(length, schema_map, &context);
2889                    check_function_argument(
2890                        &mut errors,
2891                        strict,
2892                        "substring",
2893                        2,
2894                        length_family,
2895                        "a numeric argument",
2896                        length_family.is_numeric(),
2897                    );
2898                }
2899            }
2900            Expression::Replace(func) => {
2901                for (arg, idx) in [
2902                    (&func.this, 0_usize),
2903                    (&func.old, 1_usize),
2904                    (&func.new, 2_usize),
2905                ] {
2906                    let family = infer_expression_type_family(arg, schema_map, &context);
2907                    check_function_argument(
2908                        &mut errors,
2909                        strict,
2910                        "replace",
2911                        idx,
2912                        family,
2913                        "a string argument",
2914                        is_string_like(family),
2915                    );
2916                }
2917            }
2918            Expression::Left(func) | Expression::Right(func) => {
2919                let this_family = infer_expression_type_family(&func.this, schema_map, &context);
2920                check_function_argument(
2921                    &mut errors,
2922                    strict,
2923                    "left_right",
2924                    0,
2925                    this_family,
2926                    "a string argument",
2927                    is_string_like(this_family),
2928                );
2929                let length_family =
2930                    infer_expression_type_family(&func.length, schema_map, &context);
2931                check_function_argument(
2932                    &mut errors,
2933                    strict,
2934                    "left_right",
2935                    1,
2936                    length_family,
2937                    "a numeric argument",
2938                    length_family.is_numeric(),
2939                );
2940            }
2941            Expression::Repeat(func) => {
2942                let this_family = infer_expression_type_family(&func.this, schema_map, &context);
2943                check_function_argument(
2944                    &mut errors,
2945                    strict,
2946                    "repeat",
2947                    0,
2948                    this_family,
2949                    "a string argument",
2950                    is_string_like(this_family),
2951                );
2952                let times_family = infer_expression_type_family(&func.times, schema_map, &context);
2953                check_function_argument(
2954                    &mut errors,
2955                    strict,
2956                    "repeat",
2957                    1,
2958                    times_family,
2959                    "a numeric argument",
2960                    times_family.is_numeric(),
2961                );
2962            }
2963            Expression::Lpad(func) | Expression::Rpad(func) => {
2964                let this_family = infer_expression_type_family(&func.this, schema_map, &context);
2965                check_function_argument(
2966                    &mut errors,
2967                    strict,
2968                    "pad",
2969                    0,
2970                    this_family,
2971                    "a string argument",
2972                    is_string_like(this_family),
2973                );
2974                let length_family =
2975                    infer_expression_type_family(&func.length, schema_map, &context);
2976                check_function_argument(
2977                    &mut errors,
2978                    strict,
2979                    "pad",
2980                    1,
2981                    length_family,
2982                    "a numeric argument",
2983                    length_family.is_numeric(),
2984                );
2985                if let Some(fill) = &func.fill {
2986                    let fill_family = infer_expression_type_family(fill, schema_map, &context);
2987                    check_function_argument(
2988                        &mut errors,
2989                        strict,
2990                        "pad",
2991                        2,
2992                        fill_family,
2993                        "a string argument",
2994                        is_string_like(fill_family),
2995                    );
2996                }
2997            }
2998            Expression::Abs(func)
2999            | Expression::Sqrt(func)
3000            | Expression::Cbrt(func)
3001            | Expression::Ln(func)
3002            | Expression::Exp(func) => {
3003                let family = infer_expression_type_family(&func.this, schema_map, &context);
3004                check_function_argument(
3005                    &mut errors,
3006                    strict,
3007                    "numeric_function",
3008                    0,
3009                    family,
3010                    "a numeric argument",
3011                    family.is_numeric(),
3012                );
3013            }
3014            Expression::Round(func) => {
3015                let this_family = infer_expression_type_family(&func.this, schema_map, &context);
3016                check_function_argument(
3017                    &mut errors,
3018                    strict,
3019                    "round",
3020                    0,
3021                    this_family,
3022                    "a numeric argument",
3023                    this_family.is_numeric(),
3024                );
3025                if let Some(decimals) = &func.decimals {
3026                    let decimals_family =
3027                        infer_expression_type_family(decimals, schema_map, &context);
3028                    check_function_argument(
3029                        &mut errors,
3030                        strict,
3031                        "round",
3032                        1,
3033                        decimals_family,
3034                        "a numeric argument",
3035                        decimals_family.is_numeric(),
3036                    );
3037                }
3038            }
3039            Expression::Floor(func) => {
3040                let this_family = infer_expression_type_family(&func.this, schema_map, &context);
3041                check_function_argument(
3042                    &mut errors,
3043                    strict,
3044                    "floor",
3045                    0,
3046                    this_family,
3047                    "a numeric argument",
3048                    this_family.is_numeric(),
3049                );
3050                if let Some(scale) = &func.scale {
3051                    let scale_family = infer_expression_type_family(scale, schema_map, &context);
3052                    check_function_argument(
3053                        &mut errors,
3054                        strict,
3055                        "floor",
3056                        1,
3057                        scale_family,
3058                        "a numeric argument",
3059                        scale_family.is_numeric(),
3060                    );
3061                }
3062            }
3063            Expression::Ceil(func) => {
3064                let this_family = infer_expression_type_family(&func.this, schema_map, &context);
3065                check_function_argument(
3066                    &mut errors,
3067                    strict,
3068                    "ceil",
3069                    0,
3070                    this_family,
3071                    "a numeric argument",
3072                    this_family.is_numeric(),
3073                );
3074                if let Some(decimals) = &func.decimals {
3075                    let decimals_family =
3076                        infer_expression_type_family(decimals, schema_map, &context);
3077                    check_function_argument(
3078                        &mut errors,
3079                        strict,
3080                        "ceil",
3081                        1,
3082                        decimals_family,
3083                        "a numeric argument",
3084                        decimals_family.is_numeric(),
3085                    );
3086                }
3087            }
3088            Expression::Power(func) => {
3089                let left_family = infer_expression_type_family(&func.this, schema_map, &context);
3090                check_function_argument(
3091                    &mut errors,
3092                    strict,
3093                    "power",
3094                    0,
3095                    left_family,
3096                    "a numeric argument",
3097                    left_family.is_numeric(),
3098                );
3099                let right_family =
3100                    infer_expression_type_family(&func.expression, schema_map, &context);
3101                check_function_argument(
3102                    &mut errors,
3103                    strict,
3104                    "power",
3105                    1,
3106                    right_family,
3107                    "a numeric argument",
3108                    right_family.is_numeric(),
3109                );
3110            }
3111            Expression::Log(func) => {
3112                let this_family = infer_expression_type_family(&func.this, schema_map, &context);
3113                check_function_argument(
3114                    &mut errors,
3115                    strict,
3116                    "log",
3117                    0,
3118                    this_family,
3119                    "a numeric argument",
3120                    this_family.is_numeric(),
3121                );
3122                if let Some(base) = &func.base {
3123                    let base_family = infer_expression_type_family(base, schema_map, &context);
3124                    check_function_argument(
3125                        &mut errors,
3126                        strict,
3127                        "log",
3128                        1,
3129                        base_family,
3130                        "a numeric argument",
3131                        base_family.is_numeric(),
3132                    );
3133                }
3134            }
3135            _ => {}
3136        }
3137    }
3138
3139    errors
3140}
3141
3142fn check_semantics(stmt: &Expression) -> Vec<ValidationError> {
3143    let mut errors = Vec::new();
3144
3145    let Expression::Select(select) = stmt else {
3146        return errors;
3147    };
3148    let select_expr = Expression::Select(select.clone());
3149
3150    // W001: SELECT * is discouraged
3151    if !select_expr
3152        .find_all(|e| matches!(e, Expression::Star(_)))
3153        .is_empty()
3154    {
3155        errors.push(ValidationError::warning(
3156            "SELECT * is discouraged; specify columns explicitly for better performance and maintainability",
3157            validation_codes::W_SELECT_STAR,
3158        ));
3159    }
3160
3161    // W002: aggregate + non-aggregate columns without GROUP BY
3162    let aggregate_count = get_aggregate_functions(&select_expr).len();
3163    if aggregate_count > 0 && select.group_by.is_none() {
3164        let has_non_aggregate_column = select.expressions.iter().any(|expr| {
3165            matches!(expr, Expression::Column(_) | Expression::Identifier(_))
3166                && get_aggregate_functions(expr).is_empty()
3167        });
3168
3169        if has_non_aggregate_column {
3170            errors.push(ValidationError::warning(
3171                "Mixing aggregate functions with non-aggregated columns without GROUP BY may cause errors in strict SQL mode",
3172                validation_codes::W_AGGREGATE_WITHOUT_GROUP_BY,
3173            ));
3174        }
3175    }
3176
3177    // W003: DISTINCT with ORDER BY
3178    if select.distinct && select.order_by.is_some() {
3179        errors.push(ValidationError::warning(
3180            "DISTINCT with ORDER BY: ensure ORDER BY columns are in SELECT list",
3181            validation_codes::W_DISTINCT_ORDER_BY,
3182        ));
3183    }
3184
3185    // W004: LIMIT without ORDER BY
3186    if select.limit.is_some() && select.order_by.is_none() {
3187        errors.push(ValidationError::warning(
3188            "LIMIT without ORDER BY produces non-deterministic results",
3189            validation_codes::W_LIMIT_WITHOUT_ORDER_BY,
3190        ));
3191    }
3192
3193    errors
3194}
3195
3196fn resolve_scope_source_name(scope: &crate::scope::Scope, name: &str) -> Option<String> {
3197    scope
3198        .sources
3199        .get_key_value(name)
3200        .map(|(k, _)| k.clone())
3201        .or_else(|| {
3202            scope
3203                .sources
3204                .keys()
3205                .find(|source| source.eq_ignore_ascii_case(name))
3206                .cloned()
3207        })
3208}
3209
/// Reports whether `columns` can provide `column_name`.
///
/// A literal `"*"` entry matches every column name; other entries are
/// compared ignoring ASCII case. An empty slice never matches.
fn source_has_column(columns: &[String], column_name: &str) -> bool {
    for candidate in columns {
        if candidate == "*" || candidate.eq_ignore_ascii_case(column_name) {
            return true;
        }
    }
    false
}
3215
3216fn source_display_name(scope: &crate::scope::Scope, source_name: &str) -> String {
3217    scope
3218        .sources
3219        .get(source_name)
3220        .map(|source| match &source.expression {
3221            Expression::Table(table) => lower(&table_ref_display_name(table)),
3222            _ => lower(source_name),
3223        })
3224        .unwrap_or_else(|| lower(source_name))
3225}
3226
3227fn validate_select_columns_with_schema(
3228    select: &crate::expressions::Select,
3229    schema_map: &HashMap<String, TableSchemaEntry>,
3230    resolver_schema: &MappingSchema,
3231    strict: bool,
3232) -> Vec<ValidationError> {
3233    let mut errors = Vec::new();
3234    let select_expr = Expression::Select(Box::new(select.clone()));
3235    let scope = build_scope(&select_expr);
3236    let mut resolver = Resolver::new(&scope, resolver_schema, true);
3237    let source_names: Vec<String> = scope.sources.keys().cloned().collect();
3238
3239    for node in walk_in_scope(&select_expr, false) {
3240        let Expression::Column(column) = node else {
3241            continue;
3242        };
3243
3244        let col_name = lower(&column.name.name);
3245        if col_name.is_empty() {
3246            continue;
3247        }
3248
3249        if let Some(table) = &column.table {
3250            let Some(source_name) = resolve_scope_source_name(&scope, &table.name) else {
3251                // The table qualifier is not a declared alias or source in this scope
3252                errors.push(if strict {
3253                    ValidationError::error(
3254                        format!(
3255                            "Unknown table or alias '{}' referenced by column '{}'",
3256                            table.name, col_name
3257                        ),
3258                        validation_codes::E_UNRESOLVED_REFERENCE,
3259                    )
3260                } else {
3261                    ValidationError::warning(
3262                        format!(
3263                            "Unknown table or alias '{}' referenced by column '{}'",
3264                            table.name, col_name
3265                        ),
3266                        validation_codes::E_UNRESOLVED_REFERENCE,
3267                    )
3268                });
3269                continue;
3270            };
3271
3272            if let Ok(columns) = resolver.get_source_columns(&source_name) {
3273                if !columns.is_empty() && !source_has_column(&columns, &col_name) {
3274                    let table_name = source_display_name(&scope, &source_name);
3275                    errors.push(if strict {
3276                        ValidationError::error(
3277                            format!("Unknown column '{}' in table '{}'", col_name, table_name),
3278                            validation_codes::E_UNKNOWN_COLUMN,
3279                        )
3280                    } else {
3281                        ValidationError::warning(
3282                            format!("Unknown column '{}' in table '{}'", col_name, table_name),
3283                            validation_codes::E_UNKNOWN_COLUMN,
3284                        )
3285                    });
3286                }
3287            }
3288            continue;
3289        }
3290
3291        let matching_sources: Vec<String> = source_names
3292            .iter()
3293            .filter_map(|source_name| {
3294                resolver
3295                    .get_source_columns(source_name)
3296                    .ok()
3297                    .filter(|columns| !columns.is_empty() && source_has_column(columns, &col_name))
3298                    .map(|_| source_name.clone())
3299            })
3300            .collect();
3301
3302        if !matching_sources.is_empty() {
3303            continue;
3304        }
3305
3306        let known_sources: Vec<String> = source_names
3307            .iter()
3308            .filter_map(|source_name| {
3309                resolver
3310                    .get_source_columns(source_name)
3311                    .ok()
3312                    .filter(|columns| !columns.is_empty() && !columns.iter().any(|c| c == "*"))
3313                    .map(|_| source_name.clone())
3314            })
3315            .collect();
3316
3317        if known_sources.len() == 1 {
3318            let table_name = source_display_name(&scope, &known_sources[0]);
3319            errors.push(if strict {
3320                ValidationError::error(
3321                    format!("Unknown column '{}' in table '{}'", col_name, table_name),
3322                    validation_codes::E_UNKNOWN_COLUMN,
3323                )
3324            } else {
3325                ValidationError::warning(
3326                    format!("Unknown column '{}' in table '{}'", col_name, table_name),
3327                    validation_codes::E_UNKNOWN_COLUMN,
3328                )
3329            });
3330        } else if known_sources.len() > 1 {
3331            errors.push(if strict {
3332                ValidationError::error(
3333                    format!(
3334                        "Unknown column '{}' (not found in any referenced table)",
3335                        col_name
3336                    ),
3337                    validation_codes::E_UNKNOWN_COLUMN,
3338                )
3339            } else {
3340                ValidationError::warning(
3341                    format!(
3342                        "Unknown column '{}' (not found in any referenced table)",
3343                        col_name
3344                    ),
3345                    validation_codes::E_UNKNOWN_COLUMN,
3346                )
3347            });
3348        } else if !schema_map.is_empty() {
3349            let found = schema_map
3350                .values()
3351                .any(|table_schema| table_schema.columns.contains_key(&col_name));
3352            if !found {
3353                errors.push(if strict {
3354                    ValidationError::error(
3355                        format!("Unknown column '{}'", col_name),
3356                        validation_codes::E_UNKNOWN_COLUMN,
3357                    )
3358                } else {
3359                    ValidationError::warning(
3360                        format!("Unknown column '{}'", col_name),
3361                        validation_codes::E_UNKNOWN_COLUMN,
3362                    )
3363                });
3364            }
3365        }
3366    }
3367
3368    errors
3369}
3370
3371fn validate_statement_with_schema(
3372    stmt: &Expression,
3373    schema_map: &HashMap<String, TableSchemaEntry>,
3374    resolver_schema: &MappingSchema,
3375    strict: bool,
3376) -> Vec<ValidationError> {
3377    let mut errors = Vec::new();
3378    let cte_aliases = collect_cte_aliases(stmt);
3379    let mut seen_tables: HashSet<String> = HashSet::new();
3380
3381    // Table validation (E200)
3382    for node in stmt.find_all(|e| matches!(e, Expression::Table(_))) {
3383        let Expression::Table(table) = node else {
3384            continue;
3385        };
3386
3387        if cte_aliases.contains(&lower(&table.name.name)) {
3388            continue;
3389        }
3390
3391        let resolved_key = table_ref_candidates(table)
3392            .into_iter()
3393            .find(|k| schema_map.contains_key(k));
3394        let table_key = resolved_key
3395            .clone()
3396            .unwrap_or_else(|| lower(&table_ref_display_name(table)));
3397
3398        if !seen_tables.insert(table_key) {
3399            continue;
3400        }
3401
3402        if resolved_key.is_none() {
3403            errors.push(if strict {
3404                ValidationError::error(
3405                    format!("Unknown table '{}'", table_ref_display_name(table)),
3406                    validation_codes::E_UNKNOWN_TABLE,
3407                )
3408            } else {
3409                ValidationError::warning(
3410                    format!("Unknown table '{}'", table_ref_display_name(table)),
3411                    validation_codes::E_UNKNOWN_TABLE,
3412                )
3413            });
3414        }
3415    }
3416
3417    for node in stmt.dfs() {
3418        let Expression::Select(select) = node else {
3419            continue;
3420        };
3421        errors.extend(validate_select_columns_with_schema(
3422            select,
3423            schema_map,
3424            resolver_schema,
3425            strict,
3426        ));
3427    }
3428
3429    errors
3430}
3431
3432/// Validate SQL using syntax + schema-aware checks, with optional semantic warnings.
3433pub fn validate_with_schema(
3434    sql: &str,
3435    dialect: DialectType,
3436    schema: &ValidationSchema,
3437    options: &SchemaValidationOptions,
3438) -> ValidationResult {
3439    let strict = options.strict.unwrap_or(schema.strict.unwrap_or(true));
3440
3441    // Syntax validation first.
3442    let syntax_result = crate::validate_with_options(
3443        sql,
3444        dialect,
3445        &crate::ValidationOptions {
3446            strict_syntax: options.strict_syntax,
3447        },
3448    );
3449    if !syntax_result.valid {
3450        return syntax_result;
3451    }
3452
3453    let d = Dialect::get(dialect);
3454    let statements = match d.parse(sql) {
3455        Ok(exprs) => exprs,
3456        Err(e) => {
3457            return ValidationResult::with_errors(vec![ValidationError::error(
3458                e.to_string(),
3459                validation_codes::E_PARSE_OR_OPTIONS,
3460            )]);
3461        }
3462    };
3463
3464    let schema_map = build_schema_map(schema);
3465    let resolver_schema = build_resolver_schema(schema);
3466    let mut all_errors = syntax_result.errors;
3467    let embedded_function_catalog = if options.check_types && options.function_catalog.is_none() {
3468        default_embedded_function_catalog()
3469    } else {
3470        None
3471    };
3472    let effective_function_catalog = options
3473        .function_catalog
3474        .as_deref()
3475        .or_else(|| embedded_function_catalog.as_deref());
3476    let declared_relationships = if options.check_references {
3477        build_declared_relationships(schema, &schema_map)
3478    } else {
3479        Vec::new()
3480    };
3481
3482    if options.check_references {
3483        all_errors.extend(check_reference_integrity(schema, &schema_map, strict));
3484    }
3485
3486    for stmt in &statements {
3487        if options.semantic {
3488            all_errors.extend(check_semantics(stmt));
3489        }
3490        all_errors.extend(validate_statement_with_schema(
3491            stmt,
3492            &schema_map,
3493            &resolver_schema,
3494            strict,
3495        ));
3496        if options.check_types {
3497            all_errors.extend(check_types(
3498                stmt,
3499                dialect,
3500                &schema_map,
3501                effective_function_catalog,
3502                strict,
3503            ));
3504        }
3505        if options.check_references {
3506            all_errors.extend(check_query_reference_quality(
3507                stmt,
3508                &schema_map,
3509                &resolver_schema,
3510                strict,
3511                &declared_relationships,
3512            ));
3513        }
3514    }
3515
3516    ValidationResult::with_errors(all_errors)
3517}
3518
3519#[cfg(test)]
3520mod tests;