1use crate::ast_transforms::get_aggregate_functions;
8use crate::dialects::{Dialect, DialectType};
9use crate::error::{ValidationError, ValidationResult};
10use crate::expressions::{
11 Column, DataType, Expression, Function, Insert, JoinKind, TableRef, Update,
12};
13use crate::function_catalog::FunctionCatalog;
14#[cfg(any(
15 feature = "function-catalog-clickhouse",
16 feature = "function-catalog-duckdb",
17 feature = "function-catalog-all-dialects"
18))]
19use crate::function_catalog::{
20 FunctionNameCase as CoreFunctionNameCase, FunctionSignature as CoreFunctionSignature,
21 HashMapFunctionCatalog,
22};
23use crate::function_registry::canonical_typed_function_name_upper;
24use crate::optimizer::annotate_types;
25use crate::resolver::Resolver;
26use crate::schema::{MappingSchema, Schema as SqlSchema, SchemaError, SchemaResult, TABLE_PARTS};
27use crate::scope::{build_scope, walk_in_scope};
28use crate::traversal::ExpressionWalk;
29use serde::{Deserialize, Serialize};
30use std::collections::{HashMap, HashSet};
31use std::sync::Arc;
32
33#[cfg(any(
34 feature = "function-catalog-clickhouse",
35 feature = "function-catalog-duckdb",
36 feature = "function-catalog-all-dialects"
37))]
38use std::sync::LazyLock;
39
/// One column declared on a [`SchemaTable`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SchemaColumn {
    /// Column name.
    pub name: String,
    /// Declared type as free-form text; (de)serialized under the key `type`
    /// and defaulting to an empty string when the key is missing.
    #[serde(default, rename = "type")]
    pub data_type: String,
    /// Optional nullability flag; `None` when the key is missing.
    #[serde(default)]
    pub nullable: Option<bool>,
    /// Column-level primary-key marker; (de)serialized as `primaryKey`.
    #[serde(default, rename = "primaryKey")]
    pub primary_key: bool,
    /// Column-level uniqueness marker.
    #[serde(default)]
    pub unique: bool,
    /// Optional column-level foreign-key target.
    #[serde(default)]
    pub references: Option<SchemaColumnReference>,
}
61
/// Target of a column-level foreign key: a single column of another table.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SchemaColumnReference {
    /// Name of the referenced table.
    pub table: String,
    /// Name of the referenced column.
    pub column: String,
    /// Optional schema of the referenced table; `None` when the key is missing.
    #[serde(default)]
    pub schema: Option<String>,
}
73
/// Table-level foreign key mapping a list of source columns onto a target
/// table's columns.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SchemaForeignKey {
    /// Optional constraint name; `None` when the key is missing.
    #[serde(default)]
    pub name: Option<String>,
    /// Source column names on the declaring table.
    pub columns: Vec<String>,
    /// Referenced table and its columns.
    pub references: SchemaTableReference,
}
85
/// Target side of a table-level foreign key.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SchemaTableReference {
    /// Name of the referenced table.
    pub table: String,
    /// Referenced column names.
    pub columns: Vec<String>,
    /// Optional schema of the referenced table; `None` when the key is missing.
    #[serde(default)]
    pub schema: Option<String>,
}
97
/// Declaration of one table: its columns plus optional naming and key metadata.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SchemaTable {
    /// Table name.
    pub name: String,
    /// Optional schema (namespace) the table lives in.
    #[serde(default)]
    pub schema: Option<String>,
    /// Columns in declaration order.
    pub columns: Vec<SchemaColumn>,
    /// Additional names under which the table is registered.
    #[serde(default)]
    pub aliases: Vec<String>,
    /// Table-level primary-key column names; (de)serialized as `primaryKey`.
    #[serde(default, rename = "primaryKey")]
    pub primary_key: Vec<String>,
    /// Table-level unique keys, one column list per key; (de)serialized as
    /// `uniqueKeys`.
    #[serde(default, rename = "uniqueKeys")]
    pub unique_keys: Vec<Vec<String>>,
    /// Table-level foreign keys; (de)serialized as `foreignKeys`.
    #[serde(default, rename = "foreignKeys")]
    pub foreign_keys: Vec<SchemaForeignKey>,
}
121
/// Top-level schema description: the list of tables to validate against.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ValidationSchema {
    /// Declared tables.
    pub tables: Vec<SchemaTable>,
    /// Optional strictness setting; `None` when the key is missing.
    /// NOTE(review): how this interacts with `SchemaValidationOptions::strict`
    /// is decided outside this excerpt — confirm at the call site.
    #[serde(default)]
    pub strict: Option<bool>,
}
131
/// Options for schema-aware validation. All serialized fields default when
/// missing; `Default` yields every flag `false` and every option `None`.
#[derive(Clone, Serialize, Deserialize, Default)]
pub struct SchemaValidationOptions {
    /// Presumably toggles the type-checking pass — confirm at the call site.
    #[serde(default)]
    pub check_types: bool,
    /// Presumably toggles the foreign-key/reference checks — confirm at the
    /// call site.
    #[serde(default)]
    pub check_references: bool,
    /// Optional strictness override; `None` when not supplied.
    #[serde(default)]
    pub strict: Option<bool>,
    /// Presumably toggles semantic checks — confirm at the call site.
    #[serde(default)]
    pub semantic: bool,
    /// Presumably toggles strict syntax handling — confirm at the call site.
    #[serde(default)]
    pub strict_syntax: bool,
    /// Optional function catalog. Never (de)serialized (`serde(skip)`), so it
    /// is `None` after deserialization unless set programmatically.
    #[serde(skip, default)]
    pub function_catalog: Option<Arc<dyn FunctionCatalog>>,
}
154
/// Translates the catalog crate's name-case enum into this crate's
/// [`CoreFunctionNameCase`], variant for variant.
#[cfg(any(
    feature = "function-catalog-clickhouse",
    feature = "function-catalog-duckdb",
    feature = "function-catalog-all-dialects"
))]
fn to_core_name_case(
    case: polyglot_sql_function_catalogs::FunctionNameCase,
) -> CoreFunctionNameCase {
    let converted = match case {
        polyglot_sql_function_catalogs::FunctionNameCase::Sensitive => {
            CoreFunctionNameCase::Sensitive
        }
        polyglot_sql_function_catalogs::FunctionNameCase::Insensitive => {
            CoreFunctionNameCase::Insensitive
        }
    };
    converted
}
172
/// Converts catalog-crate signatures into [`CoreFunctionSignature`] values,
/// copying the arity bounds and preserving order.
#[cfg(any(
    feature = "function-catalog-clickhouse",
    feature = "function-catalog-duckdb",
    feature = "function-catalog-all-dialects"
))]
fn to_core_signatures(
    signatures: Vec<polyglot_sql_function_catalogs::FunctionSignature>,
) -> Vec<CoreFunctionSignature> {
    let mut converted = Vec::with_capacity(signatures.len());
    for external in signatures {
        converted.push(CoreFunctionSignature {
            min_arity: external.min_arity,
            max_arity: external.max_arity,
        });
    }
    converted
}
189
/// Adapter that forwards catalog registrations into a
/// [`HashMapFunctionCatalog`], translating dialect names to [`DialectType`].
#[cfg(any(
    feature = "function-catalog-clickhouse",
    feature = "function-catalog-duckdb",
    feature = "function-catalog-all-dialects"
))]
struct EmbeddedCatalogSink<'a> {
    /// Catalog being populated.
    catalog: &'a mut HashMapFunctionCatalog,
    /// Memoized dialect-name parses; `None` records a name that failed to parse.
    dialect_cache: HashMap<&'static str, Option<DialectType>>,
}
199
#[cfg(any(
    feature = "function-catalog-clickhouse",
    feature = "function-catalog-duckdb",
    feature = "function-catalog-all-dialects"
))]
impl<'a> EmbeddedCatalogSink<'a> {
    /// Parses `dialect` into a [`DialectType`], memoizing the outcome.
    ///
    /// Failed parses are cached as `None`, so each distinct name is parsed at
    /// most once per sink.
    fn resolve_dialect(&mut self, dialect: &'static str) -> Option<DialectType> {
        *self
            .dialect_cache
            .entry(dialect)
            .or_insert_with(|| dialect.parse::<DialectType>().ok())
    }
}
215
/// Forwards every sink callback to the wrapped catalog. Calls whose dialect
/// name does not parse into a [`DialectType`] are ignored.
#[cfg(any(
    feature = "function-catalog-clickhouse",
    feature = "function-catalog-duckdb",
    feature = "function-catalog-all-dialects"
))]
impl<'a> polyglot_sql_function_catalogs::CatalogSink for EmbeddedCatalogSink<'a> {
    /// Sets the dialect-wide name-case policy.
    fn set_dialect_name_case(
        &mut self,
        dialect: &'static str,
        name_case: polyglot_sql_function_catalogs::FunctionNameCase,
    ) {
        let Some(core_dialect) = self.resolve_dialect(dialect) else {
            return;
        };
        let core_case = to_core_name_case(name_case);
        self.catalog.set_dialect_name_case(core_dialect, core_case);
    }

    /// Sets the name-case policy for a single function of the dialect.
    fn set_function_name_case(
        &mut self,
        dialect: &'static str,
        function_name: &str,
        name_case: polyglot_sql_function_catalogs::FunctionNameCase,
    ) {
        let Some(core_dialect) = self.resolve_dialect(dialect) else {
            return;
        };
        let core_case = to_core_name_case(name_case);
        self.catalog
            .set_function_name_case(core_dialect, function_name, core_case);
    }

    /// Registers `function_name` with the converted signatures.
    fn register(
        &mut self,
        dialect: &'static str,
        function_name: &str,
        signatures: Vec<polyglot_sql_function_catalogs::FunctionSignature>,
    ) {
        let Some(core_dialect) = self.resolve_dialect(dialect) else {
            return;
        };
        let core_signatures = to_core_signatures(signatures);
        self.catalog
            .register(core_dialect, function_name, core_signatures);
    }
}
260
/// Returns a shared handle to the embedded function catalog.
///
/// The catalog is built on first use by running
/// `register_enabled_catalogs` against an [`EmbeddedCatalogSink`]; later calls
/// only clone the `Arc`.
#[cfg(any(
    feature = "function-catalog-clickhouse",
    feature = "function-catalog-duckdb",
    feature = "function-catalog-all-dialects"
))]
fn embedded_function_catalog_arc() -> Arc<dyn FunctionCatalog> {
    static EMBEDDED: LazyLock<Arc<HashMapFunctionCatalog>> = LazyLock::new(|| {
        let mut catalog = HashMapFunctionCatalog::default();
        {
            // The sink mutably borrows `catalog`; scope it so the borrow ends
            // before the catalog is moved into the `Arc`.
            let mut sink = EmbeddedCatalogSink {
                dialect_cache: HashMap::new(),
                catalog: &mut catalog,
            };
            polyglot_sql_function_catalogs::register_enabled_catalogs(&mut sink);
        }
        Arc::new(catalog)
    });

    // Clone at the concrete type; the unsizing to `dyn FunctionCatalog`
    // happens at the return.
    Arc::<HashMapFunctionCatalog>::clone(&EMBEDDED)
}
279
/// Default catalog when at least one embedded catalog feature is enabled:
/// always the shared embedded catalog.
#[cfg(any(
    feature = "function-catalog-clickhouse",
    feature = "function-catalog-duckdb",
    feature = "function-catalog-all-dialects"
))]
fn default_embedded_function_catalog() -> Option<Arc<dyn FunctionCatalog>> {
    let embedded = embedded_function_catalog_arc();
    Some(embedded)
}
288
/// Default catalog when no embedded catalog feature is enabled: there is
/// none, so this always returns `None`.
#[cfg(not(any(
    feature = "function-catalog-clickhouse",
    feature = "function-catalog-duckdb",
    feature = "function-catalog-all-dialects"
)))]
fn default_embedded_function_catalog() -> Option<Arc<dyn FunctionCatalog>> {
    None
}
297
/// String codes attached to validation diagnostics. Names starting with `E_`
/// carry `E…` codes and names starting with `W_` carry `W…` codes.
pub mod validation_codes {
    // E000: parse or options failure.
    pub const E_PARSE_OR_OPTIONS: &str = "E000";
    // E200–E203: unknown table / column / function, and function arity.
    pub const E_UNKNOWN_TABLE: &str = "E200";
    pub const E_UNKNOWN_COLUMN: &str = "E201";
    pub const E_UNKNOWN_FUNCTION: &str = "E202";
    pub const E_INVALID_FUNCTION_ARITY: &str = "E203";

    // W001–W004: query-shape warnings.
    pub const W_SELECT_STAR: &str = "W001";
    pub const W_AGGREGATE_WITHOUT_GROUP_BY: &str = "W002";
    pub const W_DISTINCT_ORDER_BY: &str = "W003";
    pub const W_LIMIT_WITHOUT_ORDER_BY: &str = "W004";

    // E210–E219: type-related errors.
    pub const E_TYPE_MISMATCH: &str = "E210";
    pub const E_INVALID_PREDICATE_TYPE: &str = "E211";
    pub const E_INVALID_ARITHMETIC_TYPE: &str = "E212";
    pub const E_INVALID_FUNCTION_ARGUMENT_TYPE: &str = "E213";
    pub const E_INVALID_ASSIGNMENT_TYPE: &str = "E214";
    pub const E_SETOP_TYPE_MISMATCH: &str = "E215";
    pub const E_SETOP_ARITY_MISMATCH: &str = "E216";
    pub const E_INCOMPATIBLE_COMPARISON_TYPES: &str = "E217";
    pub const E_INVALID_CAST: &str = "E218";
    pub const E_UNKNOWN_INFERRED_TYPE: &str = "E219";

    // W210–W219: type-related warnings.
    pub const W_IMPLICIT_CAST_COMPARISON: &str = "W210";
    pub const W_IMPLICIT_CAST_ARITHMETIC: &str = "W211";
    pub const W_IMPLICIT_CAST_ASSIGNMENT: &str = "W212";
    pub const W_LOSSY_CAST: &str = "W213";
    pub const W_SETOP_IMPLICIT_COERCION: &str = "W214";
    pub const W_PREDICATE_NULLABILITY: &str = "W215";
    pub const W_FUNCTION_ARGUMENT_COERCION: &str = "W216";
    pub const W_AGGREGATE_TYPE_COERCION: &str = "W217";
    pub const W_POSSIBLE_OVERFLOW: &str = "W218";
    pub const W_POSSIBLE_TRUNCATION: &str = "W219";

    // E220–E224: reference-related errors.
    pub const E_INVALID_FOREIGN_KEY_REFERENCE: &str = "E220";
    pub const E_AMBIGUOUS_COLUMN_REFERENCE: &str = "E221";
    pub const E_UNRESOLVED_REFERENCE: &str = "E222";
    pub const E_CTE_COLUMN_COUNT_MISMATCH: &str = "E223";
    pub const E_MISSING_REFERENCE_TARGET: &str = "E224";

    // W220–W222: reference-related warnings.
    pub const W_CARTESIAN_JOIN: &str = "W220";
    pub const W_JOIN_NOT_USING_DECLARED_REFERENCE: &str = "W221";
    pub const W_WEAK_REFERENCE_INTEGRITY: &str = "W222";
}
346
/// Coarse classification of SQL data types used when comparing column types.
///
/// Serialized in `snake_case` (for example `Timestamp` becomes `timestamp`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum TypeFamily {
    /// Type that could not be classified.
    Unknown,
    Boolean,
    /// Whole-number types.
    Integer,
    /// Non-integer numeric types (decimal and floating point).
    Numeric,
    String,
    Binary,
    Date,
    Time,
    Timestamp,
    Interval,
    Json,
    Uuid,
    Array,
    Map,
    Struct,
}
367
368impl TypeFamily {
369 pub fn is_numeric(self) -> bool {
370 matches!(self, TypeFamily::Integer | TypeFamily::Numeric)
371 }
372
373 pub fn is_temporal(self) -> bool {
374 matches!(
375 self,
376 TypeFamily::Date | TypeFamily::Time | TypeFamily::Timestamp | TypeFamily::Interval
377 )
378 }
379}
380
/// Per-table lookup data derived from a [`SchemaTable`].
#[derive(Debug, Clone)]
struct TableSchemaEntry {
    /// Lower-cased column name mapped to its type family.
    columns: HashMap<String, TypeFamily>,
    /// Lower-cased column names in declaration order.
    column_order: Vec<String>,
}
386
/// Returns the lowercase form of `s`; used to normalize identifiers before
/// they are compared or used as map keys.
fn lower(s: &str) -> String {
    str::to_lowercase(s)
}
390
/// Splits a type spelled `base(args)` into `(base, args)`, both trimmed.
///
/// The split is at the first `(`; the text must end with `)`. Returns `None`
/// when there is no `(` or the text does not end with `)`.
fn split_type_args(data_type: &str) -> Option<(&str, &str)> {
    let paren_at = data_type.find('(')?;
    let (head, tail) = data_type.split_at(paren_at);
    // `tail` starts at the first '('; drop it and the closing ')'.
    let args = tail.strip_prefix('(')?.strip_suffix(')')?;
    Some((head.trim(), args.trim()))
}
400
401pub fn canonical_type_family(data_type: &str) -> TypeFamily {
403 let trimmed = data_type
404 .trim()
405 .trim_matches(|c| c == '"' || c == '\'' || c == '`');
406 if trimmed.is_empty() {
407 return TypeFamily::Unknown;
408 }
409
410 let normalized = trimmed
412 .split_whitespace()
413 .collect::<Vec<_>>()
414 .join(" ")
415 .to_lowercase();
416
417 if let Some((base, inner)) = split_type_args(&normalized) {
419 match base {
420 "nullable" | "lowcardinality" => return canonical_type_family(inner),
421 "array" | "list" => return TypeFamily::Array,
422 "map" => return TypeFamily::Map,
423 "struct" | "row" | "record" => return TypeFamily::Struct,
424 _ => {}
425 }
426 }
427
428 if normalized.starts_with("array<") || normalized.starts_with("list<") {
429 return TypeFamily::Array;
430 }
431 if normalized.starts_with("map<") {
432 return TypeFamily::Map;
433 }
434 if normalized.starts_with("struct<")
435 || normalized.starts_with("row<")
436 || normalized.starts_with("record<")
437 || normalized.starts_with("object<")
438 {
439 return TypeFamily::Struct;
440 }
441
442 if normalized.ends_with("[]") {
443 return TypeFamily::Array;
444 }
445
446 let mut base = normalized
448 .split('(')
449 .next()
450 .unwrap_or("")
451 .trim()
452 .to_string();
453 if base.is_empty() {
454 return TypeFamily::Unknown;
455 }
456
457 base = base.strip_prefix("unsigned ").unwrap_or(&base).to_string();
458 base = base.strip_suffix(" unsigned").unwrap_or(&base).to_string();
459
460 match base.as_str() {
461 "bool" | "boolean" => TypeFamily::Boolean,
462 "tinyint" | "smallint" | "int2" | "int" | "integer" | "int4" | "int8" | "bigint"
463 | "serial" | "smallserial" | "bigserial" | "utinyint" | "usmallint" | "uinteger"
464 | "ubigint" | "uint8" | "uint16" | "uint32" | "uint64" | "int16" | "int32" | "int64" => {
465 TypeFamily::Integer
466 }
467 "numeric" | "decimal" | "dec" | "number" | "float" | "float4" | "float8" | "real"
468 | "double" | "double precision" | "bfloat16" | "float16" | "float32" | "float64" => {
469 TypeFamily::Numeric
470 }
471 "char" | "character" | "varchar" | "character varying" | "nchar" | "nvarchar" | "text"
472 | "string" | "clob" => TypeFamily::String,
473 "binary" | "varbinary" | "blob" | "bytea" | "bytes" => TypeFamily::Binary,
474 "date" => TypeFamily::Date,
475 "time" => TypeFamily::Time,
476 "timestamp"
477 | "timestamptz"
478 | "datetime"
479 | "datetime2"
480 | "smalldatetime"
481 | "timestamp with time zone"
482 | "timestamp without time zone" => TypeFamily::Timestamp,
483 "interval" => TypeFamily::Interval,
484 "json" | "jsonb" | "variant" => TypeFamily::Json,
485 "uuid" | "uniqueidentifier" => TypeFamily::Uuid,
486 "array" | "list" => TypeFamily::Array,
487 "map" => TypeFamily::Map,
488 "struct" | "row" | "record" | "object" => TypeFamily::Struct,
489 _ => TypeFamily::Unknown,
490 }
491}
492
493fn build_schema_map(schema: &ValidationSchema) -> HashMap<String, TableSchemaEntry> {
494 let mut map = HashMap::new();
495
496 for table in &schema.tables {
497 let column_order: Vec<String> = table.columns.iter().map(|c| lower(&c.name)).collect();
498 let columns: HashMap<String, TypeFamily> = table
499 .columns
500 .iter()
501 .map(|c| (lower(&c.name), canonical_type_family(&c.data_type)))
502 .collect();
503 let entry = TableSchemaEntry {
504 columns,
505 column_order,
506 };
507
508 let simple_name = lower(&table.name);
509 map.insert(simple_name, entry.clone());
510
511 if let Some(table_schema) = &table.schema {
512 map.insert(
513 format!("{}.{}", lower(table_schema), lower(&table.name)),
514 entry.clone(),
515 );
516 }
517
518 for alias in &table.aliases {
519 map.insert(lower(alias), entry.clone());
520 }
521 }
522
523 map
524}
525
526fn type_family_to_data_type(family: TypeFamily) -> DataType {
527 match family {
528 TypeFamily::Unknown => DataType::Unknown,
529 TypeFamily::Boolean => DataType::Boolean,
530 TypeFamily::Integer => DataType::Int {
531 length: None,
532 integer_spelling: false,
533 },
534 TypeFamily::Numeric => DataType::Double {
535 precision: None,
536 scale: None,
537 },
538 TypeFamily::String => DataType::VarChar {
539 length: None,
540 parenthesized_length: false,
541 },
542 TypeFamily::Binary => DataType::VarBinary { length: None },
543 TypeFamily::Date => DataType::Date,
544 TypeFamily::Time => DataType::Time {
545 precision: None,
546 timezone: false,
547 },
548 TypeFamily::Timestamp => DataType::Timestamp {
549 precision: None,
550 timezone: false,
551 },
552 TypeFamily::Interval => DataType::Interval {
553 unit: None,
554 to: None,
555 },
556 TypeFamily::Json => DataType::Json,
557 TypeFamily::Uuid => DataType::Uuid,
558 TypeFamily::Array => DataType::Array {
559 element_type: Box::new(DataType::Unknown),
560 dimension: None,
561 },
562 TypeFamily::Map => DataType::Map {
563 key_type: Box::new(DataType::Unknown),
564 value_type: Box::new(DataType::Unknown),
565 },
566 TypeFamily::Struct => DataType::Struct {
567 fields: Vec::new(),
568 nested: false,
569 },
570 }
571}
572
573fn build_resolver_schema(schema: &ValidationSchema) -> MappingSchema {
574 let mut mapping = MappingSchema::new();
575
576 for table in &schema.tables {
577 let columns: Vec<(String, DataType)> = table
578 .columns
579 .iter()
580 .map(|column| {
581 (
582 lower(&column.name),
583 type_family_to_data_type(canonical_type_family(&column.data_type)),
584 )
585 })
586 .collect();
587
588 let mut table_names = Vec::new();
589 table_names.push(lower(&table.name));
590 if let Some(table_schema) = &table.schema {
591 table_names.push(format!("{}.{}", lower(table_schema), lower(&table.name)));
592 }
593 for alias in &table.aliases {
594 table_names.push(lower(alias));
595 }
596
597 let mut dedup = HashSet::new();
598 for table_name in table_names {
599 if dedup.insert(table_name.clone()) {
600 let _ = mapping.add_table(&table_name, &columns, None);
601 }
602 }
603 }
604
605 mapping
606}
607
/// Converts a [`ValidationSchema`] into a [`MappingSchema`].
///
/// Public entry point that delegates to `build_resolver_schema`.
pub fn mapping_schema_from_validation_schema(schema: &ValidationSchema) -> MappingSchema {
    build_resolver_schema(schema)
}
616
617fn collect_cte_aliases(expr: &Expression) -> HashSet<String> {
618 let mut aliases = HashSet::new();
619
620 for node in expr.dfs() {
621 match node {
622 Expression::Select(select) => {
623 if let Some(with) = &select.with {
624 for cte in &with.ctes {
625 aliases.insert(lower(&cte.alias.name));
626 }
627 }
628 }
629 Expression::Insert(insert) => {
630 if let Some(with) = &insert.with {
631 for cte in &with.ctes {
632 aliases.insert(lower(&cte.alias.name));
633 }
634 }
635 }
636 Expression::Update(update) => {
637 if let Some(with) = &update.with {
638 for cte in &with.ctes {
639 aliases.insert(lower(&cte.alias.name));
640 }
641 }
642 }
643 Expression::Delete(delete) => {
644 if let Some(with) = &delete.with {
645 for cte in &with.ctes {
646 aliases.insert(lower(&cte.alias.name));
647 }
648 }
649 }
650 Expression::Union(union) => {
651 if let Some(with) = &union.with {
652 for cte in &with.ctes {
653 aliases.insert(lower(&cte.alias.name));
654 }
655 }
656 }
657 Expression::Intersect(intersect) => {
658 if let Some(with) = &intersect.with {
659 for cte in &with.ctes {
660 aliases.insert(lower(&cte.alias.name));
661 }
662 }
663 }
664 Expression::Except(except) => {
665 if let Some(with) = &except.with {
666 for cte in &with.ctes {
667 aliases.insert(lower(&cte.alias.name));
668 }
669 }
670 }
671 Expression::Merge(merge) => {
672 if let Some(with_) = &merge.with_ {
673 if let Expression::With(with_clause) = with_.as_ref() {
674 for cte in &with_clause.ctes {
675 aliases.insert(lower(&cte.alias.name));
676 }
677 }
678 }
679 }
680 _ => {}
681 }
682 }
683
684 aliases
685}
686
687fn table_ref_candidates(table: &TableRef) -> Vec<String> {
688 let name = lower(&table.name.name);
689 let schema = table.schema.as_ref().map(|s| lower(&s.name));
690 let catalog = table.catalog.as_ref().map(|c| lower(&c.name));
691
692 let mut candidates = Vec::new();
693 if let (Some(catalog), Some(schema)) = (&catalog, &schema) {
694 candidates.push(format!("{}.{}.{}", catalog, schema, name));
695 }
696 if let Some(schema) = &schema {
697 candidates.push(format!("{}.{}", schema, name));
698 }
699 candidates.push(name);
700 candidates
701}
702
703fn table_ref_display_name(table: &TableRef) -> String {
704 let mut parts = Vec::new();
705 if let Some(catalog) = &table.catalog {
706 parts.push(catalog.name.clone());
707 }
708 if let Some(schema) = &table.schema {
709 parts.push(schema.name.clone());
710 }
711 parts.push(table.name.name.clone());
712 parts.join(".")
713}
714
/// Tables referenced by a statement, as gathered by
/// `collect_type_check_context`.
#[derive(Debug, Default, Clone)]
struct TypeCheckContext {
    /// Schema-map keys of the referenced tables that were found in the schema.
    referenced_tables: HashSet<String>,
    /// Lower-cased table name or alias mapped to the resolved schema-map key.
    table_aliases: HashMap<String, String>,
}
720
721fn type_family_name(family: TypeFamily) -> &'static str {
722 match family {
723 TypeFamily::Unknown => "unknown",
724 TypeFamily::Boolean => "boolean",
725 TypeFamily::Integer => "integer",
726 TypeFamily::Numeric => "numeric",
727 TypeFamily::String => "string",
728 TypeFamily::Binary => "binary",
729 TypeFamily::Date => "date",
730 TypeFamily::Time => "time",
731 TypeFamily::Timestamp => "timestamp",
732 TypeFamily::Interval => "interval",
733 TypeFamily::Json => "json",
734 TypeFamily::Uuid => "uuid",
735 TypeFamily::Array => "array",
736 TypeFamily::Map => "map",
737 TypeFamily::Struct => "struct",
738 }
739}
740
741fn is_string_like(family: TypeFamily) -> bool {
742 matches!(family, TypeFamily::String)
743}
744
745fn is_string_or_binary(family: TypeFamily) -> bool {
746 matches!(family, TypeFamily::String | TypeFamily::Binary)
747}
748
749fn type_issue(
750 strict: bool,
751 error_code: &str,
752 warning_code: &str,
753 message: impl Into<String>,
754) -> ValidationError {
755 if strict {
756 ValidationError::error(message.into(), error_code)
757 } else {
758 ValidationError::warning(message.into(), warning_code)
759 }
760}
761
762fn data_type_family(data_type: &DataType) -> TypeFamily {
763 match data_type {
764 DataType::Boolean => TypeFamily::Boolean,
765 DataType::TinyInt { .. }
766 | DataType::SmallInt { .. }
767 | DataType::Int { .. }
768 | DataType::BigInt { .. } => TypeFamily::Integer,
769 DataType::Float { .. } | DataType::Double { .. } | DataType::Decimal { .. } => {
770 TypeFamily::Numeric
771 }
772 DataType::Char { .. }
773 | DataType::VarChar { .. }
774 | DataType::String { .. }
775 | DataType::Text
776 | DataType::TextWithLength { .. }
777 | DataType::CharacterSet { .. } => TypeFamily::String,
778 DataType::Binary { .. } | DataType::VarBinary { .. } | DataType::Blob => TypeFamily::Binary,
779 DataType::Date => TypeFamily::Date,
780 DataType::Time { .. } => TypeFamily::Time,
781 DataType::Timestamp { .. } => TypeFamily::Timestamp,
782 DataType::Interval { .. } => TypeFamily::Interval,
783 DataType::Json | DataType::JsonB => TypeFamily::Json,
784 DataType::Uuid => TypeFamily::Uuid,
785 DataType::Array { .. } | DataType::List { .. } => TypeFamily::Array,
786 DataType::Map { .. } => TypeFamily::Map,
787 DataType::Struct { .. } | DataType::Object { .. } | DataType::Union { .. } => {
788 TypeFamily::Struct
789 }
790 DataType::Nullable { inner } => data_type_family(inner),
791 DataType::Custom { name } => canonical_type_family(name),
792 DataType::Unknown => TypeFamily::Unknown,
793 DataType::Bit { .. } | DataType::VarBit { .. } => TypeFamily::Binary,
794 DataType::Enum { .. } | DataType::Set { .. } => TypeFamily::String,
795 DataType::Vector { .. } => TypeFamily::Array,
796 DataType::Geometry { .. } | DataType::Geography { .. } => TypeFamily::Struct,
797 }
798}
799
800fn collect_type_check_context(
801 stmt: &Expression,
802 schema_map: &HashMap<String, TableSchemaEntry>,
803) -> TypeCheckContext {
804 fn add_table_to_context(
805 table: &TableRef,
806 schema_map: &HashMap<String, TableSchemaEntry>,
807 context: &mut TypeCheckContext,
808 ) {
809 let resolved_key = table_ref_candidates(table)
810 .into_iter()
811 .find(|k| schema_map.contains_key(k));
812
813 let Some(table_key) = resolved_key else {
814 return;
815 };
816
817 context.referenced_tables.insert(table_key.clone());
818 context
819 .table_aliases
820 .insert(lower(&table.name.name), table_key.clone());
821 if let Some(alias) = &table.alias {
822 context
823 .table_aliases
824 .insert(lower(&alias.name), table_key.clone());
825 }
826 }
827
828 let mut context = TypeCheckContext::default();
829 let cte_aliases = collect_cte_aliases(stmt);
830
831 for node in stmt.find_all(|e| matches!(e, Expression::Table(_))) {
832 let Expression::Table(table) = node else {
833 continue;
834 };
835
836 if cte_aliases.contains(&lower(&table.name.name)) {
837 continue;
838 }
839
840 add_table_to_context(table, schema_map, &mut context);
841 }
842
843 match stmt {
846 Expression::Insert(insert) => {
847 add_table_to_context(&insert.table, schema_map, &mut context);
848 }
849 Expression::Update(update) => {
850 add_table_to_context(&update.table, schema_map, &mut context);
851 for table in &update.extra_tables {
852 add_table_to_context(table, schema_map, &mut context);
853 }
854 }
855 Expression::Delete(delete) => {
856 add_table_to_context(&delete.table, schema_map, &mut context);
857 for table in &delete.using {
858 add_table_to_context(table, schema_map, &mut context);
859 }
860 for table in &delete.tables {
861 add_table_to_context(table, schema_map, &mut context);
862 }
863 }
864 _ => {}
865 }
866
867 context
868}
869
870fn resolve_table_schema_entry<'a>(
871 table: &TableRef,
872 schema_map: &'a HashMap<String, TableSchemaEntry>,
873) -> Option<(String, &'a TableSchemaEntry)> {
874 let key = table_ref_candidates(table)
875 .into_iter()
876 .find(|k| schema_map.contains_key(k))?;
877 let entry = schema_map.get(&key)?;
878 Some((key, entry))
879}
880
881fn reference_issue(strict: bool, message: impl Into<String>) -> ValidationError {
882 if strict {
883 ValidationError::error(
884 message.into(),
885 validation_codes::E_INVALID_FOREIGN_KEY_REFERENCE,
886 )
887 } else {
888 ValidationError::warning(message.into(), validation_codes::W_WEAK_REFERENCE_INTEGRITY)
889 }
890}
891
892fn reference_table_candidates(
893 table_name: &str,
894 explicit_schema: Option<&str>,
895 source_schema: Option<&str>,
896) -> Vec<String> {
897 let mut candidates = Vec::new();
898 let raw = lower(table_name);
899
900 if let Some(schema) = explicit_schema {
901 candidates.push(format!("{}.{}", lower(schema), raw));
902 }
903
904 if raw.contains('.') {
905 candidates.push(raw.clone());
906 if let Some(last) = raw.rsplit('.').next() {
907 candidates.push(last.to_string());
908 }
909 } else {
910 if let Some(schema) = source_schema {
911 candidates.push(format!("{}.{}", lower(schema), raw));
912 }
913 candidates.push(raw);
914 }
915
916 let mut dedup = HashSet::new();
917 candidates
918 .into_iter()
919 .filter(|c| dedup.insert(c.clone()))
920 .collect()
921}
922
923fn resolve_reference_table_key(
924 table_name: &str,
925 explicit_schema: Option<&str>,
926 source_schema: Option<&str>,
927 schema_map: &HashMap<String, TableSchemaEntry>,
928) -> Option<String> {
929 reference_table_candidates(table_name, explicit_schema, source_schema)
930 .into_iter()
931 .find(|candidate| schema_map.contains_key(candidate))
932}
933
934fn key_types_compatible(source: TypeFamily, target: TypeFamily) -> bool {
935 if source == TypeFamily::Unknown || target == TypeFamily::Unknown {
936 return true;
937 }
938 if source == target {
939 return true;
940 }
941 if source.is_numeric() && target.is_numeric() {
942 return true;
943 }
944 if source.is_temporal() && target.is_temporal() {
945 return true;
946 }
947 false
948}
949
950fn table_key_hints(table: &SchemaTable) -> HashSet<String> {
951 let mut hints = HashSet::new();
952 for column in &table.columns {
953 if column.primary_key || column.unique {
954 hints.insert(lower(&column.name));
955 }
956 }
957 for key_col in &table.primary_key {
958 hints.insert(lower(key_col));
959 }
960 for group in &table.unique_keys {
961 if group.len() == 1 {
962 if let Some(col) = group.first() {
963 hints.insert(lower(col));
964 }
965 }
966 }
967 hints
968}
969
970fn check_reference_integrity(
971 schema: &ValidationSchema,
972 schema_map: &HashMap<String, TableSchemaEntry>,
973 strict: bool,
974) -> Vec<ValidationError> {
975 let mut errors = Vec::new();
976
977 let mut key_hints_lookup: HashMap<String, HashSet<String>> = HashMap::new();
978 for table in &schema.tables {
979 let simple = lower(&table.name);
980 key_hints_lookup.insert(simple, table_key_hints(table));
981 if let Some(schema_name) = &table.schema {
982 let qualified = format!("{}.{}", lower(schema_name), lower(&table.name));
983 key_hints_lookup.insert(qualified, table_key_hints(table));
984 }
985 }
986
987 for table in &schema.tables {
988 let source_table_display = if let Some(schema_name) = &table.schema {
989 format!("{}.{}", schema_name, table.name)
990 } else {
991 table.name.clone()
992 };
993 let source_schema = table.schema.as_deref();
994 let source_columns: HashMap<String, TypeFamily> = table
995 .columns
996 .iter()
997 .map(|col| (lower(&col.name), canonical_type_family(&col.data_type)))
998 .collect();
999
1000 for source_col in &table.columns {
1001 let Some(reference) = &source_col.references else {
1002 continue;
1003 };
1004 let source_type = canonical_type_family(&source_col.data_type);
1005
1006 let Some(target_key) = resolve_reference_table_key(
1007 &reference.table,
1008 reference.schema.as_deref(),
1009 source_schema,
1010 schema_map,
1011 ) else {
1012 errors.push(reference_issue(
1013 strict,
1014 format!(
1015 "Foreign key reference '{}.{}' points to unknown table '{}'",
1016 source_table_display, source_col.name, reference.table
1017 ),
1018 ));
1019 continue;
1020 };
1021
1022 let target_column = lower(&reference.column);
1023 let Some(target_entry) = schema_map.get(&target_key) else {
1024 errors.push(reference_issue(
1025 strict,
1026 format!(
1027 "Foreign key reference '{}.{}' points to unknown table '{}'",
1028 source_table_display, source_col.name, reference.table
1029 ),
1030 ));
1031 continue;
1032 };
1033
1034 let Some(target_type) = target_entry.columns.get(&target_column).copied() else {
1035 errors.push(reference_issue(
1036 strict,
1037 format!(
1038 "Foreign key reference '{}.{}' points to unknown column '{}.{}'",
1039 source_table_display, source_col.name, target_key, reference.column
1040 ),
1041 ));
1042 continue;
1043 };
1044
1045 if !key_types_compatible(source_type, target_type) {
1046 errors.push(reference_issue(
1047 strict,
1048 format!(
1049 "Foreign key type mismatch for '{}.{}' -> '{}.{}': {} vs {}",
1050 source_table_display,
1051 source_col.name,
1052 target_key,
1053 reference.column,
1054 type_family_name(source_type),
1055 type_family_name(target_type)
1056 ),
1057 ));
1058 }
1059
1060 if let Some(target_key_hints) = key_hints_lookup.get(&target_key) {
1061 if !target_key_hints.contains(&target_column) {
1062 errors.push(ValidationError::warning(
1063 format!(
1064 "Referenced column '{}.{}' is not marked as primary/unique key",
1065 target_key, reference.column
1066 ),
1067 validation_codes::W_WEAK_REFERENCE_INTEGRITY,
1068 ));
1069 }
1070 }
1071 }
1072
1073 for foreign_key in &table.foreign_keys {
1074 if foreign_key.columns.is_empty() || foreign_key.references.columns.is_empty() {
1075 errors.push(reference_issue(
1076 strict,
1077 format!(
1078 "Table-level foreign key on '{}' has empty source or target column list",
1079 source_table_display
1080 ),
1081 ));
1082 continue;
1083 }
1084 if foreign_key.columns.len() != foreign_key.references.columns.len() {
1085 errors.push(reference_issue(
1086 strict,
1087 format!(
1088 "Table-level foreign key on '{}' has {} source columns but {} target columns",
1089 source_table_display,
1090 foreign_key.columns.len(),
1091 foreign_key.references.columns.len()
1092 ),
1093 ));
1094 continue;
1095 }
1096
1097 let Some(target_key) = resolve_reference_table_key(
1098 &foreign_key.references.table,
1099 foreign_key.references.schema.as_deref(),
1100 source_schema,
1101 schema_map,
1102 ) else {
1103 errors.push(reference_issue(
1104 strict,
1105 format!(
1106 "Table-level foreign key on '{}' points to unknown table '{}'",
1107 source_table_display, foreign_key.references.table
1108 ),
1109 ));
1110 continue;
1111 };
1112
1113 let Some(target_entry) = schema_map.get(&target_key) else {
1114 errors.push(reference_issue(
1115 strict,
1116 format!(
1117 "Table-level foreign key on '{}' points to unknown table '{}'",
1118 source_table_display, foreign_key.references.table
1119 ),
1120 ));
1121 continue;
1122 };
1123
1124 for (source_col, target_col) in foreign_key
1125 .columns
1126 .iter()
1127 .zip(foreign_key.references.columns.iter())
1128 {
1129 let source_col_name = lower(source_col);
1130 let target_col_name = lower(target_col);
1131
1132 let Some(source_type) = source_columns.get(&source_col_name).copied() else {
1133 errors.push(reference_issue(
1134 strict,
1135 format!(
1136 "Table-level foreign key on '{}' references unknown source column '{}'",
1137 source_table_display, source_col
1138 ),
1139 ));
1140 continue;
1141 };
1142
1143 let Some(target_type) = target_entry.columns.get(&target_col_name).copied() else {
1144 errors.push(reference_issue(
1145 strict,
1146 format!(
1147 "Table-level foreign key on '{}' references unknown target column '{}.{}'",
1148 source_table_display, target_key, target_col
1149 ),
1150 ));
1151 continue;
1152 };
1153
1154 if !key_types_compatible(source_type, target_type) {
1155 errors.push(reference_issue(
1156 strict,
1157 format!(
1158 "Table-level foreign key type mismatch '{}.{}' -> '{}.{}': {} vs {}",
1159 source_table_display,
1160 source_col,
1161 target_key,
1162 target_col,
1163 type_family_name(source_type),
1164 type_family_name(target_type)
1165 ),
1166 ));
1167 }
1168
1169 if let Some(target_key_hints) = key_hints_lookup.get(&target_key) {
1170 if !target_key_hints.contains(&target_col_name) {
1171 errors.push(ValidationError::warning(
1172 format!(
1173 "Referenced column '{}.{}' is not marked as primary/unique key",
1174 target_key, target_col
1175 ),
1176 validation_codes::W_WEAK_REFERENCE_INTEGRITY,
1177 ));
1178 }
1179 }
1180 }
1181 }
1182 }
1183
1184 errors
1185}
1186
1187fn resolve_unqualified_column_type(
1188 column_name: &str,
1189 schema_map: &HashMap<String, TableSchemaEntry>,
1190 context: &TypeCheckContext,
1191) -> TypeFamily {
1192 let candidate_tables: Vec<&String> = if !context.referenced_tables.is_empty() {
1193 context.referenced_tables.iter().collect()
1194 } else {
1195 schema_map.keys().collect()
1196 };
1197
1198 let mut families = HashSet::new();
1199 for table_name in candidate_tables {
1200 if let Some(table_schema) = schema_map.get(table_name) {
1201 if let Some(family) = table_schema.columns.get(column_name) {
1202 families.insert(*family);
1203 }
1204 }
1205 }
1206
1207 if families.len() == 1 {
1208 *families.iter().next().unwrap_or(&TypeFamily::Unknown)
1209 } else {
1210 TypeFamily::Unknown
1211 }
1212}
1213
1214fn resolve_column_type(
1215 column: &Column,
1216 schema_map: &HashMap<String, TableSchemaEntry>,
1217 context: &TypeCheckContext,
1218) -> TypeFamily {
1219 let column_name = lower(&column.name.name);
1220 if column_name.is_empty() {
1221 return TypeFamily::Unknown;
1222 }
1223
1224 if let Some(table) = &column.table {
1225 let mut table_key = lower(&table.name);
1226 if let Some(mapped) = context.table_aliases.get(&table_key) {
1227 table_key = mapped.clone();
1228 }
1229
1230 return schema_map
1231 .get(&table_key)
1232 .and_then(|t| t.columns.get(&column_name))
1233 .copied()
1234 .unwrap_or(TypeFamily::Unknown);
1235 }
1236
1237 resolve_unqualified_column_type(&column_name, schema_map, context)
1238}
1239
1240struct TypeInferenceSchema<'a> {
1241 schema_map: &'a HashMap<String, TableSchemaEntry>,
1242 context: &'a TypeCheckContext,
1243}
1244
1245impl TypeInferenceSchema<'_> {
1246 fn resolve_table_key(&self, table: &str) -> Option<String> {
1247 let mut table_key = lower(table);
1248 if let Some(mapped) = self.context.table_aliases.get(&table_key) {
1249 table_key = mapped.clone();
1250 }
1251 if self.schema_map.contains_key(&table_key) {
1252 Some(table_key)
1253 } else {
1254 None
1255 }
1256 }
1257}
1258
1259impl SqlSchema for TypeInferenceSchema<'_> {
1260 fn dialect(&self) -> Option<DialectType> {
1261 None
1262 }
1263
1264 fn add_table(
1265 &mut self,
1266 _table: &str,
1267 _columns: &[(String, DataType)],
1268 _dialect: Option<DialectType>,
1269 ) -> SchemaResult<()> {
1270 Err(SchemaError::InvalidStructure(
1271 "Type inference schema is read-only".to_string(),
1272 ))
1273 }
1274
1275 fn column_names(&self, table: &str) -> SchemaResult<Vec<String>> {
1276 let table_key = self
1277 .resolve_table_key(table)
1278 .ok_or_else(|| SchemaError::TableNotFound(table.to_string()))?;
1279 let entry = self
1280 .schema_map
1281 .get(&table_key)
1282 .ok_or_else(|| SchemaError::TableNotFound(table.to_string()))?;
1283 Ok(entry.column_order.clone())
1284 }
1285
1286 fn get_column_type(&self, table: &str, column: &str) -> SchemaResult<DataType> {
1287 let col_name = lower(column);
1288 if table.is_empty() {
1289 let family = resolve_unqualified_column_type(&col_name, self.schema_map, self.context);
1290 return if family == TypeFamily::Unknown {
1291 Err(SchemaError::ColumnNotFound {
1292 table: "<unqualified>".to_string(),
1293 column: column.to_string(),
1294 })
1295 } else {
1296 Ok(type_family_to_data_type(family))
1297 };
1298 }
1299
1300 let table_key = self
1301 .resolve_table_key(table)
1302 .ok_or_else(|| SchemaError::TableNotFound(table.to_string()))?;
1303 let entry = self
1304 .schema_map
1305 .get(&table_key)
1306 .ok_or_else(|| SchemaError::TableNotFound(table.to_string()))?;
1307 let family =
1308 entry
1309 .columns
1310 .get(&col_name)
1311 .copied()
1312 .ok_or_else(|| SchemaError::ColumnNotFound {
1313 table: table.to_string(),
1314 column: column.to_string(),
1315 })?;
1316 Ok(type_family_to_data_type(family))
1317 }
1318
1319 fn has_column(&self, table: &str, column: &str) -> bool {
1320 self.get_column_type(table, column).is_ok()
1321 }
1322
1323 fn supported_table_args(&self) -> &[&str] {
1324 TABLE_PARTS
1325 }
1326
1327 fn is_empty(&self) -> bool {
1328 self.schema_map.is_empty()
1329 }
1330
1331 fn depth(&self) -> usize {
1332 1
1333 }
1334}
1335
1336fn infer_expression_type_family(
1337 expr: &Expression,
1338 schema_map: &HashMap<String, TableSchemaEntry>,
1339 context: &TypeCheckContext,
1340) -> TypeFamily {
1341 let inference_schema = TypeInferenceSchema {
1342 schema_map,
1343 context,
1344 };
1345 if let Some(data_type) = annotate_types(expr, Some(&inference_schema), None) {
1346 let family = data_type_family(&data_type);
1347 if family != TypeFamily::Unknown {
1348 return family;
1349 }
1350 }
1351
1352 infer_expression_type_family_fallback(expr, schema_map, context)
1353}
1354
1355fn infer_expression_type_family_fallback(
1356 expr: &Expression,
1357 schema_map: &HashMap<String, TableSchemaEntry>,
1358 context: &TypeCheckContext,
1359) -> TypeFamily {
1360 match expr {
1361 Expression::Literal(literal) => match literal {
1362 crate::expressions::Literal::Number(value) => {
1363 if value.contains('.') || value.contains('e') || value.contains('E') {
1364 TypeFamily::Numeric
1365 } else {
1366 TypeFamily::Integer
1367 }
1368 }
1369 crate::expressions::Literal::HexNumber(_) => TypeFamily::Integer,
1370 crate::expressions::Literal::Date(_) => TypeFamily::Date,
1371 crate::expressions::Literal::Time(_) => TypeFamily::Time,
1372 crate::expressions::Literal::Timestamp(_)
1373 | crate::expressions::Literal::Datetime(_) => TypeFamily::Timestamp,
1374 crate::expressions::Literal::HexString(_)
1375 | crate::expressions::Literal::BitString(_)
1376 | crate::expressions::Literal::ByteString(_) => TypeFamily::Binary,
1377 _ => TypeFamily::String,
1378 },
1379 Expression::Boolean(_) => TypeFamily::Boolean,
1380 Expression::Null(_) => TypeFamily::Unknown,
1381 Expression::Column(column) => resolve_column_type(column, schema_map, context),
1382 Expression::Cast(cast) | Expression::TryCast(cast) | Expression::SafeCast(cast) => {
1383 data_type_family(&cast.to)
1384 }
1385 Expression::Alias(alias) => {
1386 infer_expression_type_family_fallback(&alias.this, schema_map, context)
1387 }
1388 Expression::Neg(unary) => {
1389 infer_expression_type_family_fallback(&unary.this, schema_map, context)
1390 }
1391 Expression::Add(op) | Expression::Sub(op) | Expression::Mul(op) => {
1392 let left = infer_expression_type_family_fallback(&op.left, schema_map, context);
1393 let right = infer_expression_type_family_fallback(&op.right, schema_map, context);
1394 if left == TypeFamily::Unknown || right == TypeFamily::Unknown {
1395 TypeFamily::Unknown
1396 } else if left == TypeFamily::Integer && right == TypeFamily::Integer {
1397 TypeFamily::Integer
1398 } else if left.is_numeric() && right.is_numeric() {
1399 TypeFamily::Numeric
1400 } else if left.is_temporal() || right.is_temporal() {
1401 left
1402 } else {
1403 TypeFamily::Unknown
1404 }
1405 }
1406 Expression::Div(_) | Expression::Mod(_) => TypeFamily::Numeric,
1407 Expression::Concat(_) => TypeFamily::String,
1408 Expression::Eq(_)
1409 | Expression::Neq(_)
1410 | Expression::Lt(_)
1411 | Expression::Lte(_)
1412 | Expression::Gt(_)
1413 | Expression::Gte(_)
1414 | Expression::Like(_)
1415 | Expression::ILike(_)
1416 | Expression::And(_)
1417 | Expression::Or(_)
1418 | Expression::Not(_)
1419 | Expression::Between(_)
1420 | Expression::In(_)
1421 | Expression::IsNull(_)
1422 | Expression::IsTrue(_)
1423 | Expression::IsFalse(_)
1424 | Expression::Is(_) => TypeFamily::Boolean,
1425 Expression::Length(_) => TypeFamily::Integer,
1426 Expression::Upper(_)
1427 | Expression::Lower(_)
1428 | Expression::Trim(_)
1429 | Expression::LTrim(_)
1430 | Expression::RTrim(_)
1431 | Expression::Replace(_)
1432 | Expression::Substring(_)
1433 | Expression::Left(_)
1434 | Expression::Right(_)
1435 | Expression::Repeat(_)
1436 | Expression::Lpad(_)
1437 | Expression::Rpad(_)
1438 | Expression::ConcatWs(_) => TypeFamily::String,
1439 Expression::Abs(_)
1440 | Expression::Round(_)
1441 | Expression::Floor(_)
1442 | Expression::Ceil(_)
1443 | Expression::Power(_)
1444 | Expression::Sqrt(_)
1445 | Expression::Cbrt(_)
1446 | Expression::Ln(_)
1447 | Expression::Log(_)
1448 | Expression::Exp(_) => TypeFamily::Numeric,
1449 Expression::DateAdd(_) | Expression::DateSub(_) | Expression::ToDate(_) => TypeFamily::Date,
1450 Expression::ToTimestamp(_) => TypeFamily::Timestamp,
1451 Expression::DateDiff(_) | Expression::Extract(_) => TypeFamily::Integer,
1452 Expression::CurrentDate(_) => TypeFamily::Date,
1453 Expression::CurrentTime(_) => TypeFamily::Time,
1454 Expression::CurrentTimestamp(_) | Expression::CurrentTimestampLTZ(_) => {
1455 TypeFamily::Timestamp
1456 }
1457 Expression::Interval(_) => TypeFamily::Interval,
1458 _ => TypeFamily::Unknown,
1459 }
1460}
1461
1462fn are_comparable(left: TypeFamily, right: TypeFamily) -> bool {
1463 if left == TypeFamily::Unknown || right == TypeFamily::Unknown {
1464 return true;
1465 }
1466 if left == right {
1467 return true;
1468 }
1469 if left.is_numeric() && right.is_numeric() {
1470 return true;
1471 }
1472 if left.is_temporal() && right.is_temporal() {
1473 return true;
1474 }
1475 false
1476}
1477
1478fn check_function_argument(
1479 errors: &mut Vec<ValidationError>,
1480 strict: bool,
1481 function_name: &str,
1482 arg_index: usize,
1483 family: TypeFamily,
1484 expected: &str,
1485 valid: bool,
1486) {
1487 if family == TypeFamily::Unknown || valid {
1488 return;
1489 }
1490
1491 errors.push(type_issue(
1492 strict,
1493 validation_codes::E_INVALID_FUNCTION_ARGUMENT_TYPE,
1494 validation_codes::W_FUNCTION_ARGUMENT_COERCION,
1495 format!(
1496 "Function '{}' argument {} expects {}, found {}",
1497 function_name,
1498 arg_index + 1,
1499 expected,
1500 type_family_name(family)
1501 ),
1502 ));
1503}
1504
1505fn function_dispatch_name(name: &str) -> String {
1506 let upper = name
1507 .rsplit('.')
1508 .next()
1509 .unwrap_or(name)
1510 .trim()
1511 .to_uppercase();
1512 lower(canonical_typed_function_name_upper(&upper))
1513}
1514
/// Returns the part of `name` after its last `.`, with surrounding
/// whitespace trimmed. A name without a `.` is returned trimmed as a whole.
fn function_base_name(name: &str) -> &str {
    let unqualified = match name.rsplit_once('.') {
        Some((_, tail)) => tail,
        None => name,
    };
    unqualified.trim()
}
1518
1519fn check_generic_function(
1520 function: &Function,
1521 schema_map: &HashMap<String, TableSchemaEntry>,
1522 context: &TypeCheckContext,
1523 strict: bool,
1524 errors: &mut Vec<ValidationError>,
1525) {
1526 let name = function_dispatch_name(&function.name);
1527
1528 let arg_family = |index: usize| -> Option<TypeFamily> {
1529 function
1530 .args
1531 .get(index)
1532 .map(|arg| infer_expression_type_family(arg, schema_map, context))
1533 };
1534
1535 match name.as_str() {
1536 "abs" | "sqrt" | "cbrt" | "ln" | "exp" => {
1537 if let Some(family) = arg_family(0) {
1538 check_function_argument(
1539 errors,
1540 strict,
1541 &name,
1542 0,
1543 family,
1544 "a numeric argument",
1545 family.is_numeric(),
1546 );
1547 }
1548 }
1549 "round" | "floor" | "ceil" | "ceiling" => {
1550 if let Some(family) = arg_family(0) {
1551 check_function_argument(
1552 errors,
1553 strict,
1554 &name,
1555 0,
1556 family,
1557 "a numeric argument",
1558 family.is_numeric(),
1559 );
1560 }
1561 if let Some(family) = arg_family(1) {
1562 check_function_argument(
1563 errors,
1564 strict,
1565 &name,
1566 1,
1567 family,
1568 "a numeric argument",
1569 family.is_numeric(),
1570 );
1571 }
1572 }
1573 "power" | "pow" => {
1574 for i in [0_usize, 1_usize] {
1575 if let Some(family) = arg_family(i) {
1576 check_function_argument(
1577 errors,
1578 strict,
1579 &name,
1580 i,
1581 family,
1582 "a numeric argument",
1583 family.is_numeric(),
1584 );
1585 }
1586 }
1587 }
1588 "length" | "char_length" | "character_length" => {
1589 if let Some(family) = arg_family(0) {
1590 check_function_argument(
1591 errors,
1592 strict,
1593 &name,
1594 0,
1595 family,
1596 "a string or binary argument",
1597 is_string_or_binary(family),
1598 );
1599 }
1600 }
1601 "upper" | "lower" | "trim" | "ltrim" | "rtrim" | "reverse" => {
1602 if let Some(family) = arg_family(0) {
1603 check_function_argument(
1604 errors,
1605 strict,
1606 &name,
1607 0,
1608 family,
1609 "a string argument",
1610 is_string_like(family),
1611 );
1612 }
1613 }
1614 "substring" | "substr" => {
1615 if let Some(family) = arg_family(0) {
1616 check_function_argument(
1617 errors,
1618 strict,
1619 &name,
1620 0,
1621 family,
1622 "a string argument",
1623 is_string_like(family),
1624 );
1625 }
1626 if let Some(family) = arg_family(1) {
1627 check_function_argument(
1628 errors,
1629 strict,
1630 &name,
1631 1,
1632 family,
1633 "a numeric argument",
1634 family.is_numeric(),
1635 );
1636 }
1637 if let Some(family) = arg_family(2) {
1638 check_function_argument(
1639 errors,
1640 strict,
1641 &name,
1642 2,
1643 family,
1644 "a numeric argument",
1645 family.is_numeric(),
1646 );
1647 }
1648 }
1649 "replace" => {
1650 for i in [0_usize, 1_usize, 2_usize] {
1651 if let Some(family) = arg_family(i) {
1652 check_function_argument(
1653 errors,
1654 strict,
1655 &name,
1656 i,
1657 family,
1658 "a string argument",
1659 is_string_like(family),
1660 );
1661 }
1662 }
1663 }
1664 "left" | "right" | "repeat" | "lpad" | "rpad" => {
1665 if let Some(family) = arg_family(0) {
1666 check_function_argument(
1667 errors,
1668 strict,
1669 &name,
1670 0,
1671 family,
1672 "a string argument",
1673 is_string_like(family),
1674 );
1675 }
1676 if let Some(family) = arg_family(1) {
1677 check_function_argument(
1678 errors,
1679 strict,
1680 &name,
1681 1,
1682 family,
1683 "a numeric argument",
1684 family.is_numeric(),
1685 );
1686 }
1687 if (name == "lpad" || name == "rpad") && function.args.len() > 2 {
1688 if let Some(family) = arg_family(2) {
1689 check_function_argument(
1690 errors,
1691 strict,
1692 &name,
1693 2,
1694 family,
1695 "a string argument",
1696 is_string_like(family),
1697 );
1698 }
1699 }
1700 }
1701 _ => {}
1702 }
1703}
1704
1705fn check_function_catalog(
1706 function: &Function,
1707 dialect: DialectType,
1708 function_catalog: Option<&dyn FunctionCatalog>,
1709 strict: bool,
1710 errors: &mut Vec<ValidationError>,
1711) {
1712 let Some(catalog) = function_catalog else {
1713 return;
1714 };
1715
1716 let raw_name = function_base_name(&function.name);
1717 let normalized_name = function_dispatch_name(&function.name);
1718 let arity = function.args.len();
1719 let Some(signatures) = catalog.lookup(dialect, raw_name, &normalized_name) else {
1720 errors.push(if strict {
1721 ValidationError::error(
1722 format!(
1723 "Unknown function '{}' for dialect {:?}",
1724 function.name, dialect
1725 ),
1726 validation_codes::E_UNKNOWN_FUNCTION,
1727 )
1728 } else {
1729 ValidationError::warning(
1730 format!(
1731 "Unknown function '{}' for dialect {:?}",
1732 function.name, dialect
1733 ),
1734 validation_codes::E_UNKNOWN_FUNCTION,
1735 )
1736 });
1737 return;
1738 };
1739
1740 if signatures.iter().any(|sig| sig.matches_arity(arity)) {
1741 return;
1742 }
1743
1744 let expected = signatures
1745 .iter()
1746 .map(|sig| sig.describe_arity())
1747 .collect::<Vec<_>>()
1748 .join(", ");
1749 errors.push(if strict {
1750 ValidationError::error(
1751 format!(
1752 "Invalid arity for function '{}': got {}, expected {}",
1753 function.name, arity, expected
1754 ),
1755 validation_codes::E_INVALID_FUNCTION_ARITY,
1756 )
1757 } else {
1758 ValidationError::warning(
1759 format!(
1760 "Invalid arity for function '{}': got {}, expected {}",
1761 function.name, arity, expected
1762 ),
1763 validation_codes::E_INVALID_FUNCTION_ARITY,
1764 )
1765 });
1766}
1767
/// One column-to-column foreign-key link declared in the validation schema.
///
/// As produced by `build_declared_relationships`, the table fields hold
/// resolved schema-map keys and the column fields hold lower-cased names.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
struct DeclaredRelationship {
    /// Table that declares the reference.
    source_table: String,
    /// Referencing column in `source_table`.
    source_column: String,
    /// Table being referenced.
    target_table: String,
    /// Referenced column in `target_table`.
    target_column: String,
}
1775
1776fn build_declared_relationships(
1777 schema: &ValidationSchema,
1778 schema_map: &HashMap<String, TableSchemaEntry>,
1779) -> Vec<DeclaredRelationship> {
1780 let mut relationships = HashSet::new();
1781
1782 for table in &schema.tables {
1783 let Some(source_key) =
1784 resolve_reference_table_key(&table.name, table.schema.as_deref(), None, schema_map)
1785 else {
1786 continue;
1787 };
1788
1789 for column in &table.columns {
1790 let Some(reference) = &column.references else {
1791 continue;
1792 };
1793 let Some(target_key) = resolve_reference_table_key(
1794 &reference.table,
1795 reference.schema.as_deref(),
1796 table.schema.as_deref(),
1797 schema_map,
1798 ) else {
1799 continue;
1800 };
1801
1802 relationships.insert(DeclaredRelationship {
1803 source_table: source_key.clone(),
1804 source_column: lower(&column.name),
1805 target_table: target_key,
1806 target_column: lower(&reference.column),
1807 });
1808 }
1809
1810 for foreign_key in &table.foreign_keys {
1811 if foreign_key.columns.len() != foreign_key.references.columns.len() {
1812 continue;
1813 }
1814 let Some(target_key) = resolve_reference_table_key(
1815 &foreign_key.references.table,
1816 foreign_key.references.schema.as_deref(),
1817 table.schema.as_deref(),
1818 schema_map,
1819 ) else {
1820 continue;
1821 };
1822
1823 for (source_col, target_col) in foreign_key
1824 .columns
1825 .iter()
1826 .zip(foreign_key.references.columns.iter())
1827 {
1828 relationships.insert(DeclaredRelationship {
1829 source_table: source_key.clone(),
1830 source_column: lower(source_col),
1831 target_table: target_key.clone(),
1832 target_column: lower(target_col),
1833 });
1834 }
1835 }
1836 }
1837
1838 relationships.into_iter().collect()
1839}
1840
1841fn resolve_column_binding(
1842 column: &Column,
1843 schema_map: &HashMap<String, TableSchemaEntry>,
1844 context: &TypeCheckContext,
1845 resolver: &mut Resolver<'_>,
1846) -> Option<(String, String)> {
1847 let column_name = lower(&column.name.name);
1848 if column_name.is_empty() {
1849 return None;
1850 }
1851
1852 if let Some(table) = &column.table {
1853 let mut table_key = lower(&table.name);
1854 if let Some(mapped) = context.table_aliases.get(&table_key) {
1855 table_key = mapped.clone();
1856 }
1857 if schema_map.contains_key(&table_key) {
1858 return Some((table_key, column_name));
1859 }
1860 return None;
1861 }
1862
1863 if let Some(resolved_source) = resolver.get_table(&column_name) {
1864 let mut table_key = lower(&resolved_source);
1865 if let Some(mapped) = context.table_aliases.get(&table_key) {
1866 table_key = mapped.clone();
1867 }
1868 if schema_map.contains_key(&table_key) {
1869 return Some((table_key, column_name));
1870 }
1871 }
1872
1873 let candidates: Vec<String> = context
1874 .referenced_tables
1875 .iter()
1876 .filter_map(|table_name| {
1877 schema_map
1878 .get(table_name)
1879 .filter(|entry| entry.columns.contains_key(&column_name))
1880 .map(|_| table_name.clone())
1881 })
1882 .collect();
1883 if candidates.len() == 1 {
1884 return Some((candidates[0].clone(), column_name));
1885 }
1886 None
1887}
1888
1889fn extract_join_equality_pairs(
1890 expr: &Expression,
1891 schema_map: &HashMap<String, TableSchemaEntry>,
1892 context: &TypeCheckContext,
1893 resolver: &mut Resolver<'_>,
1894 pairs: &mut Vec<((String, String), (String, String))>,
1895) {
1896 match expr {
1897 Expression::And(op) => {
1898 extract_join_equality_pairs(&op.left, schema_map, context, resolver, pairs);
1899 extract_join_equality_pairs(&op.right, schema_map, context, resolver, pairs);
1900 }
1901 Expression::Paren(paren) => {
1902 extract_join_equality_pairs(&paren.this, schema_map, context, resolver, pairs);
1903 }
1904 Expression::Eq(op) => {
1905 let (Expression::Column(left_col), Expression::Column(right_col)) =
1906 (&op.left, &op.right)
1907 else {
1908 return;
1909 };
1910 let Some(left) = resolve_column_binding(left_col, schema_map, context, resolver) else {
1911 return;
1912 };
1913 let Some(right) = resolve_column_binding(right_col, schema_map, context, resolver)
1914 else {
1915 return;
1916 };
1917 pairs.push((left, right));
1918 }
1919 _ => {}
1920 }
1921}
1922
1923fn relationship_matches_pair(
1924 relationship: &DeclaredRelationship,
1925 left_table: &str,
1926 left_column: &str,
1927 right_table: &str,
1928 right_column: &str,
1929) -> bool {
1930 (relationship.source_table == left_table
1931 && relationship.source_column == left_column
1932 && relationship.target_table == right_table
1933 && relationship.target_column == right_column)
1934 || (relationship.source_table == right_table
1935 && relationship.source_column == right_column
1936 && relationship.target_table == left_table
1937 && relationship.target_column == left_column)
1938}
1939
1940fn resolved_table_key_from_expr(
1941 expr: &Expression,
1942 schema_map: &HashMap<String, TableSchemaEntry>,
1943) -> Option<String> {
1944 match expr {
1945 Expression::Table(table) => resolve_table_schema_entry(table, schema_map).map(|(k, _)| k),
1946 Expression::Alias(alias) => resolved_table_key_from_expr(&alias.this, schema_map),
1947 _ => None,
1948 }
1949}
1950
1951fn select_from_table_keys(
1952 select: &crate::expressions::Select,
1953 schema_map: &HashMap<String, TableSchemaEntry>,
1954) -> HashSet<String> {
1955 let mut keys = HashSet::new();
1956 if let Some(from_clause) = &select.from {
1957 for expr in &from_clause.expressions {
1958 if let Some(key) = resolved_table_key_from_expr(expr, schema_map) {
1959 keys.insert(key);
1960 }
1961 }
1962 }
1963 keys
1964}
1965
1966fn is_natural_or_implied_join(kind: JoinKind) -> bool {
1967 matches!(
1968 kind,
1969 JoinKind::Natural
1970 | JoinKind::NaturalLeft
1971 | JoinKind::NaturalRight
1972 | JoinKind::NaturalFull
1973 | JoinKind::CrossApply
1974 | JoinKind::OuterApply
1975 | JoinKind::AsOf
1976 | JoinKind::AsOfLeft
1977 | JoinKind::AsOfRight
1978 | JoinKind::Lateral
1979 | JoinKind::LeftLateral
1980 )
1981}
1982
1983fn check_query_reference_quality(
1984 stmt: &Expression,
1985 schema_map: &HashMap<String, TableSchemaEntry>,
1986 resolver_schema: &MappingSchema,
1987 strict: bool,
1988 relationships: &[DeclaredRelationship],
1989) -> Vec<ValidationError> {
1990 let mut errors = Vec::new();
1991
1992 for node in stmt.dfs() {
1993 let Expression::Select(select) = node else {
1994 continue;
1995 };
1996
1997 let select_expr = Expression::Select(select.clone());
1998 let context = collect_type_check_context(&select_expr, schema_map);
1999 let scope = build_scope(&select_expr);
2000 let mut resolver = Resolver::new(&scope, resolver_schema, true);
2001
2002 if context.referenced_tables.len() > 1 {
2003 let using_columns: HashSet<String> = select
2004 .joins
2005 .iter()
2006 .flat_map(|join| join.using.iter().map(|id| lower(&id.name)))
2007 .collect();
2008
2009 let mut seen = HashSet::new();
2010 for column_expr in select_expr
2011 .find_all(|e| matches!(e, Expression::Column(Column { table: None, .. })))
2012 {
2013 let Expression::Column(column) = column_expr else {
2014 continue;
2015 };
2016
2017 let col_name = lower(&column.name.name);
2018 if col_name.is_empty()
2019 || using_columns.contains(&col_name)
2020 || !seen.insert(col_name.clone())
2021 {
2022 continue;
2023 }
2024
2025 if resolver.is_ambiguous(&col_name) {
2026 let source_count = resolver.sources_for_column(&col_name).len();
2027 errors.push(if strict {
2028 ValidationError::error(
2029 format!(
2030 "Ambiguous unqualified column '{}' found in {} referenced tables",
2031 col_name, source_count
2032 ),
2033 validation_codes::E_AMBIGUOUS_COLUMN_REFERENCE,
2034 )
2035 } else {
2036 ValidationError::warning(
2037 format!(
2038 "Ambiguous unqualified column '{}' found in {} referenced tables",
2039 col_name, source_count
2040 ),
2041 validation_codes::W_WEAK_REFERENCE_INTEGRITY,
2042 )
2043 });
2044 }
2045 }
2046 }
2047
2048 let mut cumulative_left_tables = select_from_table_keys(select, schema_map);
2049
2050 for join in &select.joins {
2051 let right_table_key = resolved_table_key_from_expr(&join.this, schema_map);
2052 let has_explicit_condition = join.on.is_some() || !join.using.is_empty();
2053 let cartesian_like_kind = matches!(
2054 join.kind,
2055 JoinKind::Cross
2056 | JoinKind::Implicit
2057 | JoinKind::Array
2058 | JoinKind::LeftArray
2059 | JoinKind::Paste
2060 );
2061
2062 if right_table_key.is_some()
2063 && (cartesian_like_kind
2064 || (!has_explicit_condition && !is_natural_or_implied_join(join.kind)))
2065 {
2066 errors.push(ValidationError::warning(
2067 "Potential cartesian join: JOIN without ON/USING condition",
2068 validation_codes::W_CARTESIAN_JOIN,
2069 ));
2070 }
2071
2072 if let (Some(on_expr), Some(right_key)) = (&join.on, right_table_key.clone()) {
2073 if join.using.is_empty() {
2074 let mut eq_pairs = Vec::new();
2075 extract_join_equality_pairs(
2076 on_expr,
2077 schema_map,
2078 &context,
2079 &mut resolver,
2080 &mut eq_pairs,
2081 );
2082
2083 let relevant_relationships: Vec<&DeclaredRelationship> = relationships
2084 .iter()
2085 .filter(|rel| {
2086 cumulative_left_tables.contains(&rel.source_table)
2087 && rel.target_table == right_key
2088 || (cumulative_left_tables.contains(&rel.target_table)
2089 && rel.source_table == right_key)
2090 })
2091 .collect();
2092
2093 if !relevant_relationships.is_empty() {
2094 let uses_declared_fk = eq_pairs.iter().any(|((lt, lc), (rt, rc))| {
2095 relevant_relationships
2096 .iter()
2097 .any(|rel| relationship_matches_pair(rel, lt, lc, rt, rc))
2098 });
2099 if !uses_declared_fk {
2100 errors.push(ValidationError::warning(
2101 "JOIN predicate does not use declared foreign-key relationship columns",
2102 validation_codes::W_JOIN_NOT_USING_DECLARED_REFERENCE,
2103 ));
2104 }
2105 }
2106 }
2107 }
2108
2109 if let Some(right_key) = right_table_key {
2110 cumulative_left_tables.insert(right_key);
2111 }
2112 }
2113 }
2114
2115 errors
2116}
2117
2118fn are_setop_compatible(left: TypeFamily, right: TypeFamily) -> bool {
2119 if left == TypeFamily::Unknown || right == TypeFamily::Unknown {
2120 return true;
2121 }
2122 if left == right {
2123 return true;
2124 }
2125 if left.is_numeric() && right.is_numeric() {
2126 return true;
2127 }
2128 if left.is_temporal() && right.is_temporal() {
2129 return true;
2130 }
2131 false
2132}
2133
2134fn merged_setop_family(left: TypeFamily, right: TypeFamily) -> TypeFamily {
2135 if left == TypeFamily::Unknown {
2136 return right;
2137 }
2138 if right == TypeFamily::Unknown {
2139 return left;
2140 }
2141 if left == right {
2142 return left;
2143 }
2144 if left.is_numeric() && right.is_numeric() {
2145 if left == TypeFamily::Numeric || right == TypeFamily::Numeric {
2146 return TypeFamily::Numeric;
2147 }
2148 return TypeFamily::Integer;
2149 }
2150 if left.is_temporal() && right.is_temporal() {
2151 if left == TypeFamily::Timestamp || right == TypeFamily::Timestamp {
2152 return TypeFamily::Timestamp;
2153 }
2154 if left == TypeFamily::Date || right == TypeFamily::Date {
2155 return TypeFamily::Date;
2156 }
2157 return TypeFamily::Time;
2158 }
2159 TypeFamily::Unknown
2160}
2161
2162fn are_assignment_compatible(target: TypeFamily, source: TypeFamily) -> bool {
2163 if target == TypeFamily::Unknown || source == TypeFamily::Unknown {
2164 return true;
2165 }
2166 if target == source {
2167 return true;
2168 }
2169
2170 match target {
2171 TypeFamily::Boolean => source == TypeFamily::Boolean,
2172 TypeFamily::Integer | TypeFamily::Numeric => source.is_numeric(),
2173 TypeFamily::Date | TypeFamily::Time | TypeFamily::Timestamp | TypeFamily::Interval => {
2174 source.is_temporal()
2175 }
2176 TypeFamily::String => true,
2177 TypeFamily::Binary => matches!(source, TypeFamily::Binary | TypeFamily::String),
2178 TypeFamily::Json => matches!(source, TypeFamily::Json | TypeFamily::String),
2179 TypeFamily::Uuid => matches!(source, TypeFamily::Uuid | TypeFamily::String),
2180 TypeFamily::Array => source == TypeFamily::Array,
2181 TypeFamily::Map => source == TypeFamily::Map,
2182 TypeFamily::Struct => source == TypeFamily::Struct,
2183 TypeFamily::Unknown => true,
2184 }
2185}
2186
2187fn projection_families(
2188 query_expr: &Expression,
2189 schema_map: &HashMap<String, TableSchemaEntry>,
2190) -> Option<Vec<TypeFamily>> {
2191 match query_expr {
2192 Expression::Select(select) => {
2193 if select
2194 .expressions
2195 .iter()
2196 .any(|e| matches!(e, Expression::Star(_) | Expression::BracedWildcard(_)))
2197 {
2198 return None;
2199 }
2200 let select_expr = Expression::Select(select.clone());
2201 let context = collect_type_check_context(&select_expr, schema_map);
2202 Some(
2203 select
2204 .expressions
2205 .iter()
2206 .map(|e| infer_expression_type_family(e, schema_map, &context))
2207 .collect(),
2208 )
2209 }
2210 Expression::Subquery(subquery) => projection_families(&subquery.this, schema_map),
2211 Expression::Union(union) => {
2212 let left = projection_families(&union.left, schema_map)?;
2213 let right = projection_families(&union.right, schema_map)?;
2214 if left.len() != right.len() {
2215 return None;
2216 }
2217 Some(
2218 left.into_iter()
2219 .zip(right)
2220 .map(|(l, r)| merged_setop_family(l, r))
2221 .collect(),
2222 )
2223 }
2224 Expression::Intersect(intersect) => {
2225 let left = projection_families(&intersect.left, schema_map)?;
2226 let right = projection_families(&intersect.right, schema_map)?;
2227 if left.len() != right.len() {
2228 return None;
2229 }
2230 Some(
2231 left.into_iter()
2232 .zip(right)
2233 .map(|(l, r)| merged_setop_family(l, r))
2234 .collect(),
2235 )
2236 }
2237 Expression::Except(except) => {
2238 let left = projection_families(&except.left, schema_map)?;
2239 let right = projection_families(&except.right, schema_map)?;
2240 if left.len() != right.len() {
2241 return None;
2242 }
2243 Some(
2244 left.into_iter()
2245 .zip(right)
2246 .map(|(l, r)| merged_setop_family(l, r))
2247 .collect(),
2248 )
2249 }
2250 Expression::Values(values) => {
2251 let first_row = values.expressions.first()?;
2252 let context = TypeCheckContext::default();
2253 Some(
2254 first_row
2255 .expressions
2256 .iter()
2257 .map(|e| infer_expression_type_family(e, schema_map, &context))
2258 .collect(),
2259 )
2260 }
2261 _ => None,
2262 }
2263}
2264
2265fn check_set_operation_compatibility(
2266 op_name: &str,
2267 left_expr: &Expression,
2268 right_expr: &Expression,
2269 schema_map: &HashMap<String, TableSchemaEntry>,
2270 strict: bool,
2271 errors: &mut Vec<ValidationError>,
2272) {
2273 let Some(left_projection) = projection_families(left_expr, schema_map) else {
2274 return;
2275 };
2276 let Some(right_projection) = projection_families(right_expr, schema_map) else {
2277 return;
2278 };
2279
2280 if left_projection.len() != right_projection.len() {
2281 errors.push(type_issue(
2282 strict,
2283 validation_codes::E_SETOP_ARITY_MISMATCH,
2284 validation_codes::W_SETOP_IMPLICIT_COERCION,
2285 format!(
2286 "{} operands return different column counts: left {}, right {}",
2287 op_name,
2288 left_projection.len(),
2289 right_projection.len()
2290 ),
2291 ));
2292 return;
2293 }
2294
2295 for (idx, (left, right)) in left_projection
2296 .into_iter()
2297 .zip(right_projection)
2298 .enumerate()
2299 {
2300 if !are_setop_compatible(left, right) {
2301 errors.push(type_issue(
2302 strict,
2303 validation_codes::E_SETOP_TYPE_MISMATCH,
2304 validation_codes::W_SETOP_IMPLICIT_COERCION,
2305 format!(
2306 "{} column {} has incompatible types: {} vs {}",
2307 op_name,
2308 idx + 1,
2309 type_family_name(left),
2310 type_family_name(right)
2311 ),
2312 ));
2313 }
2314 }
2315}
2316
2317fn check_insert_assignments(
2318 stmt: &Expression,
2319 insert: &Insert,
2320 schema_map: &HashMap<String, TableSchemaEntry>,
2321 strict: bool,
2322 errors: &mut Vec<ValidationError>,
2323) {
2324 let Some((target_table_key, table_schema)) =
2325 resolve_table_schema_entry(&insert.table, schema_map)
2326 else {
2327 return;
2328 };
2329
2330 let mut target_columns = Vec::new();
2331 if insert.columns.is_empty() {
2332 target_columns.extend(table_schema.column_order.iter().cloned());
2333 } else {
2334 for column in &insert.columns {
2335 let col_name = lower(&column.name);
2336 if table_schema.columns.contains_key(&col_name) {
2337 target_columns.push(col_name);
2338 } else {
2339 errors.push(if strict {
2340 ValidationError::error(
2341 format!(
2342 "Unknown column '{}' in table '{}'",
2343 column.name, target_table_key
2344 ),
2345 validation_codes::E_UNKNOWN_COLUMN,
2346 )
2347 } else {
2348 ValidationError::warning(
2349 format!(
2350 "Unknown column '{}' in table '{}'",
2351 column.name, target_table_key
2352 ),
2353 validation_codes::E_UNKNOWN_COLUMN,
2354 )
2355 });
2356 }
2357 }
2358 }
2359
2360 if target_columns.is_empty() {
2361 return;
2362 }
2363
2364 let context = collect_type_check_context(stmt, schema_map);
2365
2366 if !insert.default_values {
2367 for (row_idx, row) in insert.values.iter().enumerate() {
2368 if row.len() != target_columns.len() {
2369 errors.push(type_issue(
2370 strict,
2371 validation_codes::E_INVALID_ASSIGNMENT_TYPE,
2372 validation_codes::W_IMPLICIT_CAST_ASSIGNMENT,
2373 format!(
2374 "INSERT row {} has {} values but target has {} columns",
2375 row_idx + 1,
2376 row.len(),
2377 target_columns.len()
2378 ),
2379 ));
2380 continue;
2381 }
2382
2383 for (value, target_column) in row.iter().zip(target_columns.iter()) {
2384 let Some(target_family) = table_schema.columns.get(target_column).copied() else {
2385 continue;
2386 };
2387 let source_family = infer_expression_type_family(value, schema_map, &context);
2388 if !are_assignment_compatible(target_family, source_family) {
2389 errors.push(type_issue(
2390 strict,
2391 validation_codes::E_INVALID_ASSIGNMENT_TYPE,
2392 validation_codes::W_IMPLICIT_CAST_ASSIGNMENT,
2393 format!(
2394 "INSERT assignment type mismatch for '{}.{}': expected {}, found {}",
2395 target_table_key,
2396 target_column,
2397 type_family_name(target_family),
2398 type_family_name(source_family)
2399 ),
2400 ));
2401 }
2402 }
2403 }
2404 }
2405
2406 if let Some(query) = &insert.query {
2407 if insert.by_name {
2409 return;
2410 }
2411
2412 let Some(source_projection) = projection_families(query, schema_map) else {
2413 return;
2414 };
2415
2416 if source_projection.len() != target_columns.len() {
2417 errors.push(type_issue(
2418 strict,
2419 validation_codes::E_INVALID_ASSIGNMENT_TYPE,
2420 validation_codes::W_IMPLICIT_CAST_ASSIGNMENT,
2421 format!(
2422 "INSERT source query has {} columns but target has {} columns",
2423 source_projection.len(),
2424 target_columns.len()
2425 ),
2426 ));
2427 return;
2428 }
2429
2430 for (source_family, target_column) in
2431 source_projection.into_iter().zip(target_columns.iter())
2432 {
2433 let Some(target_family) = table_schema.columns.get(target_column).copied() else {
2434 continue;
2435 };
2436 if !are_assignment_compatible(target_family, source_family) {
2437 errors.push(type_issue(
2438 strict,
2439 validation_codes::E_INVALID_ASSIGNMENT_TYPE,
2440 validation_codes::W_IMPLICIT_CAST_ASSIGNMENT,
2441 format!(
2442 "INSERT assignment type mismatch for '{}.{}': expected {}, found {}",
2443 target_table_key,
2444 target_column,
2445 type_family_name(target_family),
2446 type_family_name(source_family)
2447 ),
2448 ));
2449 }
2450 }
2451 }
2452}
2453
2454fn check_update_assignments(
2455 stmt: &Expression,
2456 update: &Update,
2457 schema_map: &HashMap<String, TableSchemaEntry>,
2458 strict: bool,
2459 errors: &mut Vec<ValidationError>,
2460) {
2461 let Some((target_table_key, table_schema)) =
2462 resolve_table_schema_entry(&update.table, schema_map)
2463 else {
2464 return;
2465 };
2466
2467 let context = collect_type_check_context(stmt, schema_map);
2468
2469 for (column, value) in &update.set {
2470 let col_name = lower(&column.name);
2471 let Some(target_family) = table_schema.columns.get(&col_name).copied() else {
2472 errors.push(if strict {
2473 ValidationError::error(
2474 format!(
2475 "Unknown column '{}' in table '{}'",
2476 column.name, target_table_key
2477 ),
2478 validation_codes::E_UNKNOWN_COLUMN,
2479 )
2480 } else {
2481 ValidationError::warning(
2482 format!(
2483 "Unknown column '{}' in table '{}'",
2484 column.name, target_table_key
2485 ),
2486 validation_codes::E_UNKNOWN_COLUMN,
2487 )
2488 });
2489 continue;
2490 };
2491
2492 let source_family = infer_expression_type_family(value, schema_map, &context);
2493 if !are_assignment_compatible(target_family, source_family) {
2494 errors.push(type_issue(
2495 strict,
2496 validation_codes::E_INVALID_ASSIGNMENT_TYPE,
2497 validation_codes::W_IMPLICIT_CAST_ASSIGNMENT,
2498 format!(
2499 "UPDATE assignment type mismatch for '{}.{}': expected {}, found {}",
2500 target_table_key,
2501 col_name,
2502 type_family_name(target_family),
2503 type_family_name(source_family)
2504 ),
2505 ));
2506 }
2507 }
2508}
2509
2510fn check_types(
2511 stmt: &Expression,
2512 dialect: DialectType,
2513 schema_map: &HashMap<String, TableSchemaEntry>,
2514 function_catalog: Option<&dyn FunctionCatalog>,
2515 strict: bool,
2516) -> Vec<ValidationError> {
2517 let mut errors = Vec::new();
2518 let context = collect_type_check_context(stmt, schema_map);
2519
2520 for node in stmt.dfs() {
2521 match node {
2522 Expression::Insert(insert) => {
2523 check_insert_assignments(stmt, insert, schema_map, strict, &mut errors);
2524 }
2525 Expression::Update(update) => {
2526 check_update_assignments(stmt, update, schema_map, strict, &mut errors);
2527 }
2528 Expression::Union(union) => {
2529 check_set_operation_compatibility(
2530 "UNION",
2531 &union.left,
2532 &union.right,
2533 schema_map,
2534 strict,
2535 &mut errors,
2536 );
2537 }
2538 Expression::Intersect(intersect) => {
2539 check_set_operation_compatibility(
2540 "INTERSECT",
2541 &intersect.left,
2542 &intersect.right,
2543 schema_map,
2544 strict,
2545 &mut errors,
2546 );
2547 }
2548 Expression::Except(except) => {
2549 check_set_operation_compatibility(
2550 "EXCEPT",
2551 &except.left,
2552 &except.right,
2553 schema_map,
2554 strict,
2555 &mut errors,
2556 );
2557 }
2558 Expression::Select(select) => {
2559 if let Some(prewhere) = &select.prewhere {
2560 let family = infer_expression_type_family(prewhere, schema_map, &context);
2561 if family != TypeFamily::Unknown && family != TypeFamily::Boolean {
2562 errors.push(type_issue(
2563 strict,
2564 validation_codes::E_INVALID_PREDICATE_TYPE,
2565 validation_codes::W_PREDICATE_NULLABILITY,
2566 format!(
2567 "PREWHERE clause expects a boolean predicate, found {}",
2568 type_family_name(family)
2569 ),
2570 ));
2571 }
2572 }
2573
2574 if let Some(where_clause) = &select.where_clause {
2575 let family =
2576 infer_expression_type_family(&where_clause.this, schema_map, &context);
2577 if family != TypeFamily::Unknown && family != TypeFamily::Boolean {
2578 errors.push(type_issue(
2579 strict,
2580 validation_codes::E_INVALID_PREDICATE_TYPE,
2581 validation_codes::W_PREDICATE_NULLABILITY,
2582 format!(
2583 "WHERE clause expects a boolean predicate, found {}",
2584 type_family_name(family)
2585 ),
2586 ));
2587 }
2588 }
2589
2590 if let Some(having_clause) = &select.having {
2591 let family =
2592 infer_expression_type_family(&having_clause.this, schema_map, &context);
2593 if family != TypeFamily::Unknown && family != TypeFamily::Boolean {
2594 errors.push(type_issue(
2595 strict,
2596 validation_codes::E_INVALID_PREDICATE_TYPE,
2597 validation_codes::W_PREDICATE_NULLABILITY,
2598 format!(
2599 "HAVING clause expects a boolean predicate, found {}",
2600 type_family_name(family)
2601 ),
2602 ));
2603 }
2604 }
2605
2606 for join in &select.joins {
2607 if let Some(on) = &join.on {
2608 let family = infer_expression_type_family(on, schema_map, &context);
2609 if family != TypeFamily::Unknown && family != TypeFamily::Boolean {
2610 errors.push(type_issue(
2611 strict,
2612 validation_codes::E_INVALID_PREDICATE_TYPE,
2613 validation_codes::W_PREDICATE_NULLABILITY,
2614 format!(
2615 "JOIN ON expects a boolean predicate, found {}",
2616 type_family_name(family)
2617 ),
2618 ));
2619 }
2620 }
2621 if let Some(match_condition) = &join.match_condition {
2622 let family =
2623 infer_expression_type_family(match_condition, schema_map, &context);
2624 if family != TypeFamily::Unknown && family != TypeFamily::Boolean {
2625 errors.push(type_issue(
2626 strict,
2627 validation_codes::E_INVALID_PREDICATE_TYPE,
2628 validation_codes::W_PREDICATE_NULLABILITY,
2629 format!(
2630 "JOIN MATCH_CONDITION expects a boolean predicate, found {}",
2631 type_family_name(family)
2632 ),
2633 ));
2634 }
2635 }
2636 }
2637 }
2638 Expression::Where(where_clause) => {
2639 let family = infer_expression_type_family(&where_clause.this, schema_map, &context);
2640 if family != TypeFamily::Unknown && family != TypeFamily::Boolean {
2641 errors.push(type_issue(
2642 strict,
2643 validation_codes::E_INVALID_PREDICATE_TYPE,
2644 validation_codes::W_PREDICATE_NULLABILITY,
2645 format!(
2646 "WHERE clause expects a boolean predicate, found {}",
2647 type_family_name(family)
2648 ),
2649 ));
2650 }
2651 }
2652 Expression::Having(having_clause) => {
2653 let family =
2654 infer_expression_type_family(&having_clause.this, schema_map, &context);
2655 if family != TypeFamily::Unknown && family != TypeFamily::Boolean {
2656 errors.push(type_issue(
2657 strict,
2658 validation_codes::E_INVALID_PREDICATE_TYPE,
2659 validation_codes::W_PREDICATE_NULLABILITY,
2660 format!(
2661 "HAVING clause expects a boolean predicate, found {}",
2662 type_family_name(family)
2663 ),
2664 ));
2665 }
2666 }
2667 Expression::And(op) | Expression::Or(op) => {
2668 for (side, expr) in [("left", &op.left), ("right", &op.right)] {
2669 let family = infer_expression_type_family(expr, schema_map, &context);
2670 if family != TypeFamily::Unknown && family != TypeFamily::Boolean {
2671 errors.push(type_issue(
2672 strict,
2673 validation_codes::E_INVALID_PREDICATE_TYPE,
2674 validation_codes::W_PREDICATE_NULLABILITY,
2675 format!(
2676 "Logical {} operand expects boolean, found {}",
2677 side,
2678 type_family_name(family)
2679 ),
2680 ));
2681 }
2682 }
2683 }
2684 Expression::Not(unary) => {
2685 let family = infer_expression_type_family(&unary.this, schema_map, &context);
2686 if family != TypeFamily::Unknown && family != TypeFamily::Boolean {
2687 errors.push(type_issue(
2688 strict,
2689 validation_codes::E_INVALID_PREDICATE_TYPE,
2690 validation_codes::W_PREDICATE_NULLABILITY,
2691 format!("NOT expects boolean, found {}", type_family_name(family)),
2692 ));
2693 }
2694 }
2695 Expression::Eq(op)
2696 | Expression::Neq(op)
2697 | Expression::Lt(op)
2698 | Expression::Lte(op)
2699 | Expression::Gt(op)
2700 | Expression::Gte(op) => {
2701 let left = infer_expression_type_family(&op.left, schema_map, &context);
2702 let right = infer_expression_type_family(&op.right, schema_map, &context);
2703 if !are_comparable(left, right) {
2704 errors.push(type_issue(
2705 strict,
2706 validation_codes::E_INCOMPATIBLE_COMPARISON_TYPES,
2707 validation_codes::W_IMPLICIT_CAST_COMPARISON,
2708 format!(
2709 "Incompatible comparison between {} and {}",
2710 type_family_name(left),
2711 type_family_name(right)
2712 ),
2713 ));
2714 }
2715 }
2716 Expression::Like(op) | Expression::ILike(op) => {
2717 let left = infer_expression_type_family(&op.left, schema_map, &context);
2718 let right = infer_expression_type_family(&op.right, schema_map, &context);
2719 if left != TypeFamily::Unknown
2720 && right != TypeFamily::Unknown
2721 && (!is_string_like(left) || !is_string_like(right))
2722 {
2723 errors.push(type_issue(
2724 strict,
2725 validation_codes::E_INCOMPATIBLE_COMPARISON_TYPES,
2726 validation_codes::W_IMPLICIT_CAST_COMPARISON,
2727 format!(
2728 "LIKE/ILIKE expects string operands, found {} and {}",
2729 type_family_name(left),
2730 type_family_name(right)
2731 ),
2732 ));
2733 }
2734 }
2735 Expression::Between(between) => {
2736 let this_family = infer_expression_type_family(&between.this, schema_map, &context);
2737 let low_family = infer_expression_type_family(&between.low, schema_map, &context);
2738 let high_family = infer_expression_type_family(&between.high, schema_map, &context);
2739
2740 if !are_comparable(this_family, low_family)
2741 || !are_comparable(this_family, high_family)
2742 {
2743 errors.push(type_issue(
2744 strict,
2745 validation_codes::E_INCOMPATIBLE_COMPARISON_TYPES,
2746 validation_codes::W_IMPLICIT_CAST_COMPARISON,
2747 format!(
2748 "BETWEEN bounds are incompatible with {} (found {} and {})",
2749 type_family_name(this_family),
2750 type_family_name(low_family),
2751 type_family_name(high_family)
2752 ),
2753 ));
2754 }
2755 }
2756 Expression::In(in_expr) => {
2757 let this_family = infer_expression_type_family(&in_expr.this, schema_map, &context);
2758 for value in &in_expr.expressions {
2759 let item_family = infer_expression_type_family(value, schema_map, &context);
2760 if !are_comparable(this_family, item_family) {
2761 errors.push(type_issue(
2762 strict,
2763 validation_codes::E_INCOMPATIBLE_COMPARISON_TYPES,
2764 validation_codes::W_IMPLICIT_CAST_COMPARISON,
2765 format!(
2766 "IN item type {} is incompatible with {}",
2767 type_family_name(item_family),
2768 type_family_name(this_family)
2769 ),
2770 ));
2771 break;
2772 }
2773 }
2774 }
2775 Expression::Add(op)
2776 | Expression::Sub(op)
2777 | Expression::Mul(op)
2778 | Expression::Div(op)
2779 | Expression::Mod(op) => {
2780 let left = infer_expression_type_family(&op.left, schema_map, &context);
2781 let right = infer_expression_type_family(&op.right, schema_map, &context);
2782
2783 if left == TypeFamily::Unknown || right == TypeFamily::Unknown {
2784 continue;
2785 }
2786
2787 let temporal_ok = matches!(node, Expression::Add(_) | Expression::Sub(_))
2788 && ((left.is_temporal() && right.is_numeric())
2789 || (right.is_temporal() && left.is_numeric())
2790 || (matches!(node, Expression::Sub(_))
2791 && left.is_temporal()
2792 && right.is_temporal()));
2793
2794 if !(left.is_numeric() && right.is_numeric()) && !temporal_ok {
2795 errors.push(type_issue(
2796 strict,
2797 validation_codes::E_INVALID_ARITHMETIC_TYPE,
2798 validation_codes::W_IMPLICIT_CAST_ARITHMETIC,
2799 format!(
2800 "Arithmetic operation expects numeric-compatible operands, found {} and {}",
2801 type_family_name(left),
2802 type_family_name(right)
2803 ),
2804 ));
2805 }
2806 }
2807 Expression::Function(function) => {
2808 check_function_catalog(function, dialect, function_catalog, strict, &mut errors);
2809 check_generic_function(function, schema_map, &context, strict, &mut errors);
2810 }
2811 Expression::Upper(func)
2812 | Expression::Lower(func)
2813 | Expression::LTrim(func)
2814 | Expression::RTrim(func)
2815 | Expression::Reverse(func) => {
2816 let family = infer_expression_type_family(&func.this, schema_map, &context);
2817 check_function_argument(
2818 &mut errors,
2819 strict,
2820 "string_function",
2821 0,
2822 family,
2823 "a string argument",
2824 is_string_like(family),
2825 );
2826 }
2827 Expression::Length(func) => {
2828 let family = infer_expression_type_family(&func.this, schema_map, &context);
2829 check_function_argument(
2830 &mut errors,
2831 strict,
2832 "length",
2833 0,
2834 family,
2835 "a string or binary argument",
2836 is_string_or_binary(family),
2837 );
2838 }
2839 Expression::Trim(func) => {
2840 let this_family = infer_expression_type_family(&func.this, schema_map, &context);
2841 check_function_argument(
2842 &mut errors,
2843 strict,
2844 "trim",
2845 0,
2846 this_family,
2847 "a string argument",
2848 is_string_like(this_family),
2849 );
2850 if let Some(chars) = &func.characters {
2851 let chars_family = infer_expression_type_family(chars, schema_map, &context);
2852 check_function_argument(
2853 &mut errors,
2854 strict,
2855 "trim",
2856 1,
2857 chars_family,
2858 "a string argument",
2859 is_string_like(chars_family),
2860 );
2861 }
2862 }
2863 Expression::Substring(func) => {
2864 let this_family = infer_expression_type_family(&func.this, schema_map, &context);
2865 check_function_argument(
2866 &mut errors,
2867 strict,
2868 "substring",
2869 0,
2870 this_family,
2871 "a string argument",
2872 is_string_like(this_family),
2873 );
2874
2875 let start_family = infer_expression_type_family(&func.start, schema_map, &context);
2876 check_function_argument(
2877 &mut errors,
2878 strict,
2879 "substring",
2880 1,
2881 start_family,
2882 "a numeric argument",
2883 start_family.is_numeric(),
2884 );
2885 if let Some(length) = &func.length {
2886 let length_family = infer_expression_type_family(length, schema_map, &context);
2887 check_function_argument(
2888 &mut errors,
2889 strict,
2890 "substring",
2891 2,
2892 length_family,
2893 "a numeric argument",
2894 length_family.is_numeric(),
2895 );
2896 }
2897 }
2898 Expression::Replace(func) => {
2899 for (arg, idx) in [
2900 (&func.this, 0_usize),
2901 (&func.old, 1_usize),
2902 (&func.new, 2_usize),
2903 ] {
2904 let family = infer_expression_type_family(arg, schema_map, &context);
2905 check_function_argument(
2906 &mut errors,
2907 strict,
2908 "replace",
2909 idx,
2910 family,
2911 "a string argument",
2912 is_string_like(family),
2913 );
2914 }
2915 }
2916 Expression::Left(func) | Expression::Right(func) => {
2917 let this_family = infer_expression_type_family(&func.this, schema_map, &context);
2918 check_function_argument(
2919 &mut errors,
2920 strict,
2921 "left_right",
2922 0,
2923 this_family,
2924 "a string argument",
2925 is_string_like(this_family),
2926 );
2927 let length_family =
2928 infer_expression_type_family(&func.length, schema_map, &context);
2929 check_function_argument(
2930 &mut errors,
2931 strict,
2932 "left_right",
2933 1,
2934 length_family,
2935 "a numeric argument",
2936 length_family.is_numeric(),
2937 );
2938 }
2939 Expression::Repeat(func) => {
2940 let this_family = infer_expression_type_family(&func.this, schema_map, &context);
2941 check_function_argument(
2942 &mut errors,
2943 strict,
2944 "repeat",
2945 0,
2946 this_family,
2947 "a string argument",
2948 is_string_like(this_family),
2949 );
2950 let times_family = infer_expression_type_family(&func.times, schema_map, &context);
2951 check_function_argument(
2952 &mut errors,
2953 strict,
2954 "repeat",
2955 1,
2956 times_family,
2957 "a numeric argument",
2958 times_family.is_numeric(),
2959 );
2960 }
2961 Expression::Lpad(func) | Expression::Rpad(func) => {
2962 let this_family = infer_expression_type_family(&func.this, schema_map, &context);
2963 check_function_argument(
2964 &mut errors,
2965 strict,
2966 "pad",
2967 0,
2968 this_family,
2969 "a string argument",
2970 is_string_like(this_family),
2971 );
2972 let length_family =
2973 infer_expression_type_family(&func.length, schema_map, &context);
2974 check_function_argument(
2975 &mut errors,
2976 strict,
2977 "pad",
2978 1,
2979 length_family,
2980 "a numeric argument",
2981 length_family.is_numeric(),
2982 );
2983 if let Some(fill) = &func.fill {
2984 let fill_family = infer_expression_type_family(fill, schema_map, &context);
2985 check_function_argument(
2986 &mut errors,
2987 strict,
2988 "pad",
2989 2,
2990 fill_family,
2991 "a string argument",
2992 is_string_like(fill_family),
2993 );
2994 }
2995 }
2996 Expression::Abs(func)
2997 | Expression::Sqrt(func)
2998 | Expression::Cbrt(func)
2999 | Expression::Ln(func)
3000 | Expression::Exp(func) => {
3001 let family = infer_expression_type_family(&func.this, schema_map, &context);
3002 check_function_argument(
3003 &mut errors,
3004 strict,
3005 "numeric_function",
3006 0,
3007 family,
3008 "a numeric argument",
3009 family.is_numeric(),
3010 );
3011 }
3012 Expression::Round(func) => {
3013 let this_family = infer_expression_type_family(&func.this, schema_map, &context);
3014 check_function_argument(
3015 &mut errors,
3016 strict,
3017 "round",
3018 0,
3019 this_family,
3020 "a numeric argument",
3021 this_family.is_numeric(),
3022 );
3023 if let Some(decimals) = &func.decimals {
3024 let decimals_family =
3025 infer_expression_type_family(decimals, schema_map, &context);
3026 check_function_argument(
3027 &mut errors,
3028 strict,
3029 "round",
3030 1,
3031 decimals_family,
3032 "a numeric argument",
3033 decimals_family.is_numeric(),
3034 );
3035 }
3036 }
3037 Expression::Floor(func) => {
3038 let this_family = infer_expression_type_family(&func.this, schema_map, &context);
3039 check_function_argument(
3040 &mut errors,
3041 strict,
3042 "floor",
3043 0,
3044 this_family,
3045 "a numeric argument",
3046 this_family.is_numeric(),
3047 );
3048 if let Some(scale) = &func.scale {
3049 let scale_family = infer_expression_type_family(scale, schema_map, &context);
3050 check_function_argument(
3051 &mut errors,
3052 strict,
3053 "floor",
3054 1,
3055 scale_family,
3056 "a numeric argument",
3057 scale_family.is_numeric(),
3058 );
3059 }
3060 }
3061 Expression::Ceil(func) => {
3062 let this_family = infer_expression_type_family(&func.this, schema_map, &context);
3063 check_function_argument(
3064 &mut errors,
3065 strict,
3066 "ceil",
3067 0,
3068 this_family,
3069 "a numeric argument",
3070 this_family.is_numeric(),
3071 );
3072 if let Some(decimals) = &func.decimals {
3073 let decimals_family =
3074 infer_expression_type_family(decimals, schema_map, &context);
3075 check_function_argument(
3076 &mut errors,
3077 strict,
3078 "ceil",
3079 1,
3080 decimals_family,
3081 "a numeric argument",
3082 decimals_family.is_numeric(),
3083 );
3084 }
3085 }
3086 Expression::Power(func) => {
3087 let left_family = infer_expression_type_family(&func.this, schema_map, &context);
3088 check_function_argument(
3089 &mut errors,
3090 strict,
3091 "power",
3092 0,
3093 left_family,
3094 "a numeric argument",
3095 left_family.is_numeric(),
3096 );
3097 let right_family =
3098 infer_expression_type_family(&func.expression, schema_map, &context);
3099 check_function_argument(
3100 &mut errors,
3101 strict,
3102 "power",
3103 1,
3104 right_family,
3105 "a numeric argument",
3106 right_family.is_numeric(),
3107 );
3108 }
3109 Expression::Log(func) => {
3110 let this_family = infer_expression_type_family(&func.this, schema_map, &context);
3111 check_function_argument(
3112 &mut errors,
3113 strict,
3114 "log",
3115 0,
3116 this_family,
3117 "a numeric argument",
3118 this_family.is_numeric(),
3119 );
3120 if let Some(base) = &func.base {
3121 let base_family = infer_expression_type_family(base, schema_map, &context);
3122 check_function_argument(
3123 &mut errors,
3124 strict,
3125 "log",
3126 1,
3127 base_family,
3128 "a numeric argument",
3129 base_family.is_numeric(),
3130 );
3131 }
3132 }
3133 _ => {}
3134 }
3135 }
3136
3137 errors
3138}
3139
3140fn check_semantics(stmt: &Expression) -> Vec<ValidationError> {
3141 let mut errors = Vec::new();
3142
3143 let Expression::Select(select) = stmt else {
3144 return errors;
3145 };
3146 let select_expr = Expression::Select(select.clone());
3147
3148 if !select_expr
3150 .find_all(|e| matches!(e, Expression::Star(_)))
3151 .is_empty()
3152 {
3153 errors.push(ValidationError::warning(
3154 "SELECT * is discouraged; specify columns explicitly for better performance and maintainability",
3155 validation_codes::W_SELECT_STAR,
3156 ));
3157 }
3158
3159 let aggregate_count = get_aggregate_functions(&select_expr).len();
3161 if aggregate_count > 0 && select.group_by.is_none() {
3162 let has_non_aggregate_column = select.expressions.iter().any(|expr| {
3163 matches!(expr, Expression::Column(_) | Expression::Identifier(_))
3164 && get_aggregate_functions(expr).is_empty()
3165 });
3166
3167 if has_non_aggregate_column {
3168 errors.push(ValidationError::warning(
3169 "Mixing aggregate functions with non-aggregated columns without GROUP BY may cause errors in strict SQL mode",
3170 validation_codes::W_AGGREGATE_WITHOUT_GROUP_BY,
3171 ));
3172 }
3173 }
3174
3175 if select.distinct && select.order_by.is_some() {
3177 errors.push(ValidationError::warning(
3178 "DISTINCT with ORDER BY: ensure ORDER BY columns are in SELECT list",
3179 validation_codes::W_DISTINCT_ORDER_BY,
3180 ));
3181 }
3182
3183 if select.limit.is_some() && select.order_by.is_none() {
3185 errors.push(ValidationError::warning(
3186 "LIMIT without ORDER BY produces non-deterministic results",
3187 validation_codes::W_LIMIT_WITHOUT_ORDER_BY,
3188 ));
3189 }
3190
3191 errors
3192}
3193
3194fn resolve_scope_source_name(scope: &crate::scope::Scope, name: &str) -> Option<String> {
3195 scope
3196 .sources
3197 .get_key_value(name)
3198 .map(|(k, _)| k.clone())
3199 .or_else(|| {
3200 scope
3201 .sources
3202 .keys()
3203 .find(|source| source.eq_ignore_ascii_case(name))
3204 .cloned()
3205 })
3206}
3207
/// Returns `true` when `columns` contains the wildcard entry `"*"` or an entry
/// equal to `column_name` ignoring ASCII case.
fn source_has_column(columns: &[String], column_name: &str) -> bool {
    for candidate in columns {
        if candidate == "*" || candidate.eq_ignore_ascii_case(column_name) {
            return true;
        }
    }
    false
}
3213
3214fn source_display_name(scope: &crate::scope::Scope, source_name: &str) -> String {
3215 scope
3216 .sources
3217 .get(source_name)
3218 .map(|source| match &source.expression {
3219 Expression::Table(table) => lower(&table_ref_display_name(table)),
3220 _ => lower(source_name),
3221 })
3222 .unwrap_or_else(|| lower(source_name))
3223}
3224
3225fn validate_select_columns_with_schema(
3226 select: &crate::expressions::Select,
3227 schema_map: &HashMap<String, TableSchemaEntry>,
3228 resolver_schema: &MappingSchema,
3229 strict: bool,
3230) -> Vec<ValidationError> {
3231 let mut errors = Vec::new();
3232 let select_expr = Expression::Select(Box::new(select.clone()));
3233 let scope = build_scope(&select_expr);
3234 let mut resolver = Resolver::new(&scope, resolver_schema, true);
3235 let source_names: Vec<String> = scope.sources.keys().cloned().collect();
3236
3237 for node in walk_in_scope(&select_expr, false) {
3238 let Expression::Column(column) = node else {
3239 continue;
3240 };
3241
3242 let col_name = lower(&column.name.name);
3243 if col_name.is_empty() {
3244 continue;
3245 }
3246
3247 if let Some(table) = &column.table {
3248 let Some(source_name) = resolve_scope_source_name(&scope, &table.name) else {
3249 errors.push(if strict {
3251 ValidationError::error(
3252 format!(
3253 "Unknown table or alias '{}' referenced by column '{}'",
3254 table.name, col_name
3255 ),
3256 validation_codes::E_UNRESOLVED_REFERENCE,
3257 )
3258 } else {
3259 ValidationError::warning(
3260 format!(
3261 "Unknown table or alias '{}' referenced by column '{}'",
3262 table.name, col_name
3263 ),
3264 validation_codes::E_UNRESOLVED_REFERENCE,
3265 )
3266 });
3267 continue;
3268 };
3269
3270 if let Ok(columns) = resolver.get_source_columns(&source_name) {
3271 if !columns.is_empty() && !source_has_column(&columns, &col_name) {
3272 let table_name = source_display_name(&scope, &source_name);
3273 errors.push(if strict {
3274 ValidationError::error(
3275 format!("Unknown column '{}' in table '{}'", col_name, table_name),
3276 validation_codes::E_UNKNOWN_COLUMN,
3277 )
3278 } else {
3279 ValidationError::warning(
3280 format!("Unknown column '{}' in table '{}'", col_name, table_name),
3281 validation_codes::E_UNKNOWN_COLUMN,
3282 )
3283 });
3284 }
3285 }
3286 continue;
3287 }
3288
3289 let matching_sources: Vec<String> = source_names
3290 .iter()
3291 .filter_map(|source_name| {
3292 resolver
3293 .get_source_columns(source_name)
3294 .ok()
3295 .filter(|columns| !columns.is_empty() && source_has_column(columns, &col_name))
3296 .map(|_| source_name.clone())
3297 })
3298 .collect();
3299
3300 if !matching_sources.is_empty() {
3301 continue;
3302 }
3303
3304 let known_sources: Vec<String> = source_names
3305 .iter()
3306 .filter_map(|source_name| {
3307 resolver
3308 .get_source_columns(source_name)
3309 .ok()
3310 .filter(|columns| !columns.is_empty() && !columns.iter().any(|c| c == "*"))
3311 .map(|_| source_name.clone())
3312 })
3313 .collect();
3314
3315 if known_sources.len() == 1 {
3316 let table_name = source_display_name(&scope, &known_sources[0]);
3317 errors.push(if strict {
3318 ValidationError::error(
3319 format!("Unknown column '{}' in table '{}'", col_name, table_name),
3320 validation_codes::E_UNKNOWN_COLUMN,
3321 )
3322 } else {
3323 ValidationError::warning(
3324 format!("Unknown column '{}' in table '{}'", col_name, table_name),
3325 validation_codes::E_UNKNOWN_COLUMN,
3326 )
3327 });
3328 } else if known_sources.len() > 1 {
3329 errors.push(if strict {
3330 ValidationError::error(
3331 format!(
3332 "Unknown column '{}' (not found in any referenced table)",
3333 col_name
3334 ),
3335 validation_codes::E_UNKNOWN_COLUMN,
3336 )
3337 } else {
3338 ValidationError::warning(
3339 format!(
3340 "Unknown column '{}' (not found in any referenced table)",
3341 col_name
3342 ),
3343 validation_codes::E_UNKNOWN_COLUMN,
3344 )
3345 });
3346 } else if !schema_map.is_empty() {
3347 let found = schema_map
3348 .values()
3349 .any(|table_schema| table_schema.columns.contains_key(&col_name));
3350 if !found {
3351 errors.push(if strict {
3352 ValidationError::error(
3353 format!("Unknown column '{}'", col_name),
3354 validation_codes::E_UNKNOWN_COLUMN,
3355 )
3356 } else {
3357 ValidationError::warning(
3358 format!("Unknown column '{}'", col_name),
3359 validation_codes::E_UNKNOWN_COLUMN,
3360 )
3361 });
3362 }
3363 }
3364 }
3365
3366 errors
3367}
3368
3369fn validate_statement_with_schema(
3370 stmt: &Expression,
3371 schema_map: &HashMap<String, TableSchemaEntry>,
3372 resolver_schema: &MappingSchema,
3373 strict: bool,
3374) -> Vec<ValidationError> {
3375 let mut errors = Vec::new();
3376 let cte_aliases = collect_cte_aliases(stmt);
3377 let mut seen_tables: HashSet<String> = HashSet::new();
3378
3379 for node in stmt.find_all(|e| matches!(e, Expression::Table(_))) {
3381 let Expression::Table(table) = node else {
3382 continue;
3383 };
3384
3385 if cte_aliases.contains(&lower(&table.name.name)) {
3386 continue;
3387 }
3388
3389 let resolved_key = table_ref_candidates(table)
3390 .into_iter()
3391 .find(|k| schema_map.contains_key(k));
3392 let table_key = resolved_key
3393 .clone()
3394 .unwrap_or_else(|| lower(&table_ref_display_name(table)));
3395
3396 if !seen_tables.insert(table_key) {
3397 continue;
3398 }
3399
3400 if resolved_key.is_none() {
3401 errors.push(if strict {
3402 ValidationError::error(
3403 format!("Unknown table '{}'", table_ref_display_name(table)),
3404 validation_codes::E_UNKNOWN_TABLE,
3405 )
3406 } else {
3407 ValidationError::warning(
3408 format!("Unknown table '{}'", table_ref_display_name(table)),
3409 validation_codes::E_UNKNOWN_TABLE,
3410 )
3411 });
3412 }
3413 }
3414
3415 for node in stmt.dfs() {
3416 let Expression::Select(select) = node else {
3417 continue;
3418 };
3419 errors.extend(validate_select_columns_with_schema(
3420 select,
3421 schema_map,
3422 resolver_schema,
3423 strict,
3424 ));
3425 }
3426
3427 errors
3428}
3429
3430pub fn validate_with_schema(
3432 sql: &str,
3433 dialect: DialectType,
3434 schema: &ValidationSchema,
3435 options: &SchemaValidationOptions,
3436) -> ValidationResult {
3437 let strict = options.strict.unwrap_or(schema.strict.unwrap_or(true));
3438
3439 let syntax_result = crate::validate_with_options(
3441 sql,
3442 dialect,
3443 &crate::ValidationOptions {
3444 strict_syntax: options.strict_syntax,
3445 },
3446 );
3447 if !syntax_result.valid {
3448 return syntax_result;
3449 }
3450
3451 let d = Dialect::get(dialect);
3452 let statements = match d.parse(sql) {
3453 Ok(exprs) => exprs,
3454 Err(e) => {
3455 return ValidationResult::with_errors(vec![ValidationError::error(
3456 e.to_string(),
3457 validation_codes::E_PARSE_OR_OPTIONS,
3458 )]);
3459 }
3460 };
3461
3462 let schema_map = build_schema_map(schema);
3463 let resolver_schema = build_resolver_schema(schema);
3464 let mut all_errors = syntax_result.errors;
3465 let embedded_function_catalog = if options.check_types && options.function_catalog.is_none() {
3466 default_embedded_function_catalog()
3467 } else {
3468 None
3469 };
3470 let effective_function_catalog = options
3471 .function_catalog
3472 .as_deref()
3473 .or_else(|| embedded_function_catalog.as_deref());
3474 let declared_relationships = if options.check_references {
3475 build_declared_relationships(schema, &schema_map)
3476 } else {
3477 Vec::new()
3478 };
3479
3480 if options.check_references {
3481 all_errors.extend(check_reference_integrity(schema, &schema_map, strict));
3482 }
3483
3484 for stmt in &statements {
3485 if options.semantic {
3486 all_errors.extend(check_semantics(stmt));
3487 }
3488 all_errors.extend(validate_statement_with_schema(
3489 stmt,
3490 &schema_map,
3491 &resolver_schema,
3492 strict,
3493 ));
3494 if options.check_types {
3495 all_errors.extend(check_types(
3496 stmt,
3497 dialect,
3498 &schema_map,
3499 effective_function_catalog,
3500 strict,
3501 ));
3502 }
3503 if options.check_references {
3504 all_errors.extend(check_query_reference_quality(
3505 stmt,
3506 &schema_map,
3507 &resolver_schema,
3508 strict,
3509 &declared_relationships,
3510 ));
3511 }
3512 }
3513
3514 ValidationResult::with_errors(all_errors)
3515}
3516
3517#[cfg(test)]
3518mod tests;