1use crate::ast_transforms::get_aggregate_functions;
8use crate::dialects::{Dialect, DialectType};
9use crate::error::{ValidationError, ValidationResult};
10use crate::expressions::{
11 Column, DataType, Expression, Function, Insert, JoinKind, TableRef, Update,
12};
13use crate::function_catalog::FunctionCatalog;
14#[cfg(any(
15 feature = "function-catalog-clickhouse",
16 feature = "function-catalog-duckdb",
17 feature = "function-catalog-all-dialects"
18))]
19use crate::function_catalog::{
20 FunctionNameCase as CoreFunctionNameCase, FunctionSignature as CoreFunctionSignature,
21 HashMapFunctionCatalog,
22};
23use crate::function_registry::canonical_typed_function_name_upper;
24use crate::optimizer::annotate_types::annotate_types;
25use crate::resolver::Resolver;
26use crate::schema::{MappingSchema, Schema as SqlSchema, SchemaError, SchemaResult, TABLE_PARTS};
27use crate::scope::{build_scope, walk_in_scope};
28use crate::traversal::ExpressionWalk;
29use serde::{Deserialize, Serialize};
30use std::collections::{HashMap, HashSet};
31use std::sync::Arc;
32
33#[cfg(any(
34 feature = "function-catalog-clickhouse",
35 feature = "function-catalog-duckdb",
36 feature = "function-catalog-all-dialects"
37))]
38use std::sync::LazyLock;
39
/// A single column of a [`SchemaTable`] as supplied to schema-aware validation.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SchemaColumn {
    /// Column name (lower-cased by the lookup builders in this file before use).
    pub name: String,
    /// Declared type text, serialized under the key `type`; an empty string when omitted.
    #[serde(default, rename = "type")]
    pub data_type: String,
    /// Optional nullability flag; `None` when the input does not state it.
    #[serde(default)]
    pub nullable: Option<bool>,
    /// Column-level primary-key marker, serialized as `primaryKey`; `false` when omitted.
    #[serde(default, rename = "primaryKey")]
    pub primary_key: bool,
    /// Column-level uniqueness marker; `false` when omitted.
    #[serde(default)]
    pub unique: bool,
    /// Optional column-level foreign-key style reference to a column of another table.
    #[serde(default)]
    pub references: Option<SchemaColumnReference>,
}
61
/// Target of a column-level reference declared on a [`SchemaColumn`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SchemaColumnReference {
    /// Name of the referenced table.
    pub table: String,
    /// Name of the referenced column.
    pub column: String,
    /// Optional schema qualifier of the referenced table; `None` when omitted.
    #[serde(default)]
    pub schema: Option<String>,
}
73
/// A table-level foreign key: a list of source columns paired with a referenced
/// table and its target columns.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SchemaForeignKey {
    /// Optional constraint name; `None` when omitted.
    #[serde(default)]
    pub name: Option<String>,
    /// Source columns on the declaring table.
    pub columns: Vec<String>,
    /// Referenced table and columns; paired positionally with `columns` by the
    /// reference-integrity check in this file.
    pub references: SchemaTableReference,
}
85
/// Target side of a table-level foreign key ([`SchemaForeignKey`]).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SchemaTableReference {
    /// Name of the referenced table.
    pub table: String,
    /// Referenced columns.
    pub columns: Vec<String>,
    /// Optional schema qualifier of the referenced table; `None` when omitted.
    #[serde(default)]
    pub schema: Option<String>,
}
97
/// Description of one table made available to schema-aware validation.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SchemaTable {
    /// Table name.
    pub name: String,
    /// Optional schema qualifier; when present the table is also registered as
    /// `schema.name` by the lookup builders in this file.
    #[serde(default)]
    pub schema: Option<String>,
    /// Columns of the table, in declaration order.
    pub columns: Vec<SchemaColumn>,
    /// Additional names under which the table is registered; empty when omitted.
    #[serde(default)]
    pub aliases: Vec<String>,
    /// Table-level primary-key column names, serialized as `primaryKey`.
    #[serde(default, rename = "primaryKey")]
    pub primary_key: Vec<String>,
    /// Table-level unique keys (each inner list is one key), serialized as `uniqueKeys`.
    #[serde(default, rename = "uniqueKeys")]
    pub unique_keys: Vec<Vec<String>>,
    /// Table-level foreign keys, serialized as `foreignKeys`.
    #[serde(default, rename = "foreignKeys")]
    pub foreign_keys: Vec<SchemaForeignKey>,
}
121
/// The full schema description handed to schema-aware validation.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ValidationSchema {
    /// Tables known to the validator.
    pub tables: Vec<SchemaTable>,
    /// Optional strict-mode flag; `None` when omitted.
    // NOTE(review): how this interacts with `SchemaValidationOptions::strict`
    // is decided outside this section — confirm at the call site.
    #[serde(default)]
    pub strict: Option<bool>,
}
131
/// Switches controlling schema-aware validation.
///
/// Every serialized field falls back to its `Default` value when omitted.
// NOTE(review): the consumers of these flags are outside this section; the
// per-field descriptions below follow the field names — confirm against usage.
#[derive(Clone, Serialize, Deserialize, Default)]
pub struct SchemaValidationOptions {
    /// Presumably enables type checking.
    #[serde(default)]
    pub check_types: bool,
    /// Presumably enables reference (foreign-key) checking.
    #[serde(default)]
    pub check_references: bool,
    /// Optional strict-mode override; `None` when omitted.
    #[serde(default)]
    pub strict: Option<bool>,
    /// Presumably enables semantic checks.
    #[serde(default)]
    pub semantic: bool,
    /// Presumably enables strict syntax checking.
    #[serde(default)]
    pub strict_syntax: bool,
    /// Optional function catalog. Never serialized (`serde(skip)`); it is `None`
    /// after deserialization.
    #[serde(skip, default)]
    pub function_catalog: Option<Arc<dyn FunctionCatalog>>,
}
154
/// Translates the catalog crate's name-case setting into the core crate's
/// equivalent variant.
#[cfg(any(
    feature = "function-catalog-clickhouse",
    feature = "function-catalog-duckdb",
    feature = "function-catalog-all-dialects"
))]
fn to_core_name_case(
    case: polyglot_sql_function_catalogs::FunctionNameCase,
) -> CoreFunctionNameCase {
    use polyglot_sql_function_catalogs::FunctionNameCase as ExternalCase;

    match case {
        ExternalCase::Sensitive => CoreFunctionNameCase::Sensitive,
        ExternalCase::Insensitive => CoreFunctionNameCase::Insensitive,
    }
}
172
/// Converts catalog-crate signatures into core signatures, copying the
/// `min_arity` / `max_arity` bounds one-to-one and preserving order.
#[cfg(any(
    feature = "function-catalog-clickhouse",
    feature = "function-catalog-duckdb",
    feature = "function-catalog-all-dialects"
))]
fn to_core_signatures(
    signatures: Vec<polyglot_sql_function_catalogs::FunctionSignature>,
) -> Vec<CoreFunctionSignature> {
    let mut converted = Vec::with_capacity(signatures.len());
    for external in signatures {
        converted.push(CoreFunctionSignature {
            min_arity: external.min_arity,
            max_arity: external.max_arity,
        });
    }
    converted
}
189
/// Adapter that receives registrations from the embedded catalog crate and
/// forwards them into a [`HashMapFunctionCatalog`].
#[cfg(any(
    feature = "function-catalog-clickhouse",
    feature = "function-catalog-duckdb",
    feature = "function-catalog-all-dialects"
))]
struct EmbeddedCatalogSink<'a> {
    /// Catalog being populated.
    catalog: &'a mut HashMapFunctionCatalog,
    /// Memoized result of parsing each dialect name; `None` records a name that
    /// failed to parse so it is not re-parsed.
    dialect_cache: HashMap<&'static str, Option<DialectType>>,
}
199
#[cfg(any(
    feature = "function-catalog-clickhouse",
    feature = "function-catalog-duckdb",
    feature = "function-catalog-all-dialects"
))]
impl<'a> EmbeddedCatalogSink<'a> {
    /// Parses `dialect` into a [`DialectType`], memoizing the outcome.
    ///
    /// Both successful and failed parses are cached, so each distinct dialect
    /// name is parsed at most once per sink.
    fn resolve_dialect(&mut self, dialect: &'static str) -> Option<DialectType> {
        *self
            .dialect_cache
            .entry(dialect)
            .or_insert_with(|| dialect.parse::<DialectType>().ok())
    }
}
215
/// Forwards catalog-crate callbacks into the wrapped core catalog.
///
/// Every callback is silently ignored when the dialect name cannot be parsed
/// into a [`DialectType`].
#[cfg(any(
    feature = "function-catalog-clickhouse",
    feature = "function-catalog-duckdb",
    feature = "function-catalog-all-dialects"
))]
impl<'a> polyglot_sql_function_catalogs::CatalogSink for EmbeddedCatalogSink<'a> {
    /// Sets the default function-name case rule for a dialect.
    fn set_dialect_name_case(
        &mut self,
        dialect: &'static str,
        name_case: polyglot_sql_function_catalogs::FunctionNameCase,
    ) {
        let Some(core_dialect) = self.resolve_dialect(dialect) else {
            return;
        };
        let core_case = to_core_name_case(name_case);
        self.catalog.set_dialect_name_case(core_dialect, core_case);
    }

    /// Sets the name-case rule for one function of a dialect.
    fn set_function_name_case(
        &mut self,
        dialect: &'static str,
        function_name: &str,
        name_case: polyglot_sql_function_catalogs::FunctionNameCase,
    ) {
        let Some(core_dialect) = self.resolve_dialect(dialect) else {
            return;
        };
        let core_case = to_core_name_case(name_case);
        self.catalog
            .set_function_name_case(core_dialect, function_name, core_case);
    }

    /// Registers the signatures of one function for a dialect.
    fn register(
        &mut self,
        dialect: &'static str,
        function_name: &str,
        signatures: Vec<polyglot_sql_function_catalogs::FunctionSignature>,
    ) {
        let Some(core_dialect) = self.resolve_dialect(dialect) else {
            return;
        };
        let core_signatures = to_core_signatures(signatures);
        self.catalog
            .register(core_dialect, function_name, core_signatures);
    }
}
260
/// Returns a shared handle to the embedded function catalog.
///
/// The catalog is built on first use by letting the catalog crate register
/// every enabled dialect into a fresh [`HashMapFunctionCatalog`]; later calls
/// only clone the `Arc`.
#[cfg(any(
    feature = "function-catalog-clickhouse",
    feature = "function-catalog-duckdb",
    feature = "function-catalog-all-dialects"
))]
fn embedded_function_catalog_arc() -> Arc<dyn FunctionCatalog> {
    static EMBEDDED: LazyLock<Arc<HashMapFunctionCatalog>> = LazyLock::new(|| {
        let mut catalog = HashMapFunctionCatalog::default();
        // The sink only lives for the duration of the registration call, which
        // releases the mutable borrow before the catalog is moved into the Arc.
        polyglot_sql_function_catalogs::register_enabled_catalogs(&mut EmbeddedCatalogSink {
            catalog: &mut catalog,
            dialect_cache: HashMap::new(),
        });
        Arc::new(catalog)
    });

    // The concrete type is spelled out so the clone happens on the concrete Arc
    // and is only then coerced to the trait object.
    Arc::<HashMapFunctionCatalog>::clone(&EMBEDDED)
}
279
/// Default function catalog when at least one embedded catalog feature is
/// enabled: always `Some`, wrapping the shared embedded catalog.
#[cfg(any(
    feature = "function-catalog-clickhouse",
    feature = "function-catalog-duckdb",
    feature = "function-catalog-all-dialects"
))]
fn default_embedded_function_catalog() -> Option<Arc<dyn FunctionCatalog>> {
    Some(embedded_function_catalog_arc())
}
288
/// Default function catalog when no embedded catalog feature is enabled:
/// there is nothing to offer, so this is always `None`.
#[cfg(not(any(
    feature = "function-catalog-clickhouse",
    feature = "function-catalog-duckdb",
    feature = "function-catalog-all-dialects"
)))]
fn default_embedded_function_catalog() -> Option<Arc<dyn FunctionCatalog>> {
    None
}
297
/// String codes attached to validation diagnostics.
///
/// Constants prefixed `E_` carry `E…` codes and those prefixed `W_` carry `W…`
/// codes; within this file `E_` codes are passed to `ValidationError::error`
/// and `W_` codes to `ValidationError::warning`.
pub mod validation_codes {
    // Parsing / option handling.
    pub const E_PARSE_OR_OPTIONS: &str = "E000";
    // Unknown identifiers and function arity.
    pub const E_UNKNOWN_TABLE: &str = "E200";
    pub const E_UNKNOWN_COLUMN: &str = "E201";
    pub const E_UNKNOWN_FUNCTION: &str = "E202";
    pub const E_INVALID_FUNCTION_ARITY: &str = "E203";

    // Query-shape warnings.
    pub const W_SELECT_STAR: &str = "W001";
    pub const W_AGGREGATE_WITHOUT_GROUP_BY: &str = "W002";
    pub const W_DISTINCT_ORDER_BY: &str = "W003";
    pub const W_LIMIT_WITHOUT_ORDER_BY: &str = "W004";

    // Type-related errors.
    pub const E_TYPE_MISMATCH: &str = "E210";
    pub const E_INVALID_PREDICATE_TYPE: &str = "E211";
    pub const E_INVALID_ARITHMETIC_TYPE: &str = "E212";
    pub const E_INVALID_FUNCTION_ARGUMENT_TYPE: &str = "E213";
    pub const E_INVALID_ASSIGNMENT_TYPE: &str = "E214";
    pub const E_SETOP_TYPE_MISMATCH: &str = "E215";
    pub const E_SETOP_ARITY_MISMATCH: &str = "E216";
    pub const E_INCOMPATIBLE_COMPARISON_TYPES: &str = "E217";
    pub const E_INVALID_CAST: &str = "E218";
    pub const E_UNKNOWN_INFERRED_TYPE: &str = "E219";

    // Type-related warnings.
    pub const W_IMPLICIT_CAST_COMPARISON: &str = "W210";
    pub const W_IMPLICIT_CAST_ARITHMETIC: &str = "W211";
    pub const W_IMPLICIT_CAST_ASSIGNMENT: &str = "W212";
    pub const W_LOSSY_CAST: &str = "W213";
    pub const W_SETOP_IMPLICIT_COERCION: &str = "W214";
    pub const W_PREDICATE_NULLABILITY: &str = "W215";
    pub const W_FUNCTION_ARGUMENT_COERCION: &str = "W216";
    pub const W_AGGREGATE_TYPE_COERCION: &str = "W217";
    pub const W_POSSIBLE_OVERFLOW: &str = "W218";
    pub const W_POSSIBLE_TRUNCATION: &str = "W219";

    // Reference-related errors.
    pub const E_INVALID_FOREIGN_KEY_REFERENCE: &str = "E220";
    pub const E_AMBIGUOUS_COLUMN_REFERENCE: &str = "E221";
    pub const E_UNRESOLVED_REFERENCE: &str = "E222";
    pub const E_CTE_COLUMN_COUNT_MISMATCH: &str = "E223";
    pub const E_MISSING_REFERENCE_TARGET: &str = "E224";

    // Reference-related warnings.
    pub const W_CARTESIAN_JOIN: &str = "W220";
    pub const W_JOIN_NOT_USING_DECLARED_REFERENCE: &str = "W221";
    pub const W_WEAK_REFERENCE_INTEGRITY: &str = "W222";
}
346
/// Coarse classification of SQL data types used for compatibility checks.
///
/// Serialized in `snake_case` (for example `Timestamp` becomes `timestamp`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum TypeFamily {
    /// The type could not be classified.
    Unknown,
    Boolean,
    /// Whole-number types.
    Integer,
    /// Non-integer numeric types (decimal and floating point).
    Numeric,
    /// Character types.
    String,
    /// Byte-string types.
    Binary,
    Date,
    Time,
    Timestamp,
    Interval,
    Json,
    Uuid,
    Array,
    Map,
    /// Structured / record-like types.
    Struct,
}
367
368impl TypeFamily {
369 pub fn is_numeric(self) -> bool {
370 matches!(self, TypeFamily::Integer | TypeFamily::Numeric)
371 }
372
373 pub fn is_temporal(self) -> bool {
374 matches!(
375 self,
376 TypeFamily::Date | TypeFamily::Time | TypeFamily::Timestamp | TypeFamily::Interval
377 )
378 }
379}
380
/// Per-table lookup data derived from a [`SchemaTable`].
#[derive(Debug, Clone)]
struct TableSchemaEntry {
    /// Lower-cased column name mapped to its type family.
    columns: HashMap<String, TypeFamily>,
    /// Lower-cased column names in declaration order.
    column_order: Vec<String>,
}
386
/// Returns the lowercase form of `s`; used to normalize identifiers before
/// they are stored in or looked up from the schema maps.
fn lower(s: &str) -> String {
    str::to_lowercase(s)
}
390
/// Splits a parameterized type such as `decimal(10, 2)` into its base name and
/// the text between the first `(` and the final `)`, both trimmed.
///
/// Returns `None` when there is no `(` or the text does not end with `)`.
fn split_type_args(data_type: &str) -> Option<(&str, &str)> {
    let (base, after_open) = data_type.split_once('(')?;
    let inner = after_open.strip_suffix(')')?;
    Some((base.trim(), inner.trim()))
}
400
401pub fn canonical_type_family(data_type: &str) -> TypeFamily {
403 let trimmed = data_type
404 .trim()
405 .trim_matches(|c| c == '"' || c == '\'' || c == '`');
406 if trimmed.is_empty() {
407 return TypeFamily::Unknown;
408 }
409
410 let normalized = trimmed
412 .split_whitespace()
413 .collect::<Vec<_>>()
414 .join(" ")
415 .to_lowercase();
416
417 if let Some((base, inner)) = split_type_args(&normalized) {
419 match base {
420 "nullable" | "lowcardinality" => return canonical_type_family(inner),
421 "array" | "list" => return TypeFamily::Array,
422 "map" => return TypeFamily::Map,
423 "struct" | "row" | "record" => return TypeFamily::Struct,
424 _ => {}
425 }
426 }
427
428 if normalized.starts_with("array<") || normalized.starts_with("list<") {
429 return TypeFamily::Array;
430 }
431 if normalized.starts_with("map<") {
432 return TypeFamily::Map;
433 }
434 if normalized.starts_with("struct<")
435 || normalized.starts_with("row<")
436 || normalized.starts_with("record<")
437 || normalized.starts_with("object<")
438 {
439 return TypeFamily::Struct;
440 }
441
442 if normalized.ends_with("[]") {
443 return TypeFamily::Array;
444 }
445
446 let mut base = normalized
448 .split('(')
449 .next()
450 .unwrap_or("")
451 .trim()
452 .to_string();
453 if base.is_empty() {
454 return TypeFamily::Unknown;
455 }
456
457 base = base.strip_prefix("unsigned ").unwrap_or(&base).to_string();
458 base = base.strip_suffix(" unsigned").unwrap_or(&base).to_string();
459
460 match base.as_str() {
461 "bool" | "boolean" => TypeFamily::Boolean,
462 "tinyint" | "smallint" | "int2" | "int" | "integer" | "int4" | "int8" | "bigint"
463 | "serial" | "smallserial" | "bigserial" | "utinyint" | "usmallint" | "uinteger"
464 | "ubigint" | "uint8" | "uint16" | "uint32" | "uint64" | "int16" | "int32" | "int64" => {
465 TypeFamily::Integer
466 }
467 "numeric" | "decimal" | "dec" | "number" | "float" | "float4" | "float8" | "real"
468 | "double" | "double precision" | "bfloat16" | "float16" | "float32" | "float64" => {
469 TypeFamily::Numeric
470 }
471 "char" | "character" | "varchar" | "character varying" | "nchar" | "nvarchar" | "text"
472 | "string" | "clob" => TypeFamily::String,
473 "binary" | "varbinary" | "blob" | "bytea" | "bytes" => TypeFamily::Binary,
474 "date" => TypeFamily::Date,
475 "time" => TypeFamily::Time,
476 "timestamp"
477 | "timestamptz"
478 | "datetime"
479 | "datetime2"
480 | "smalldatetime"
481 | "timestamp with time zone"
482 | "timestamp without time zone" => TypeFamily::Timestamp,
483 "interval" => TypeFamily::Interval,
484 "json" | "jsonb" | "variant" => TypeFamily::Json,
485 "uuid" | "uniqueidentifier" => TypeFamily::Uuid,
486 "array" | "list" => TypeFamily::Array,
487 "map" => TypeFamily::Map,
488 "struct" | "row" | "record" | "object" => TypeFamily::Struct,
489 _ => TypeFamily::Unknown,
490 }
491}
492
493fn build_schema_map(schema: &ValidationSchema) -> HashMap<String, TableSchemaEntry> {
494 let mut map = HashMap::new();
495
496 for table in &schema.tables {
497 let column_order: Vec<String> = table.columns.iter().map(|c| lower(&c.name)).collect();
498 let columns: HashMap<String, TypeFamily> = table
499 .columns
500 .iter()
501 .map(|c| (lower(&c.name), canonical_type_family(&c.data_type)))
502 .collect();
503 let entry = TableSchemaEntry {
504 columns,
505 column_order,
506 };
507
508 let simple_name = lower(&table.name);
509 map.insert(simple_name, entry.clone());
510
511 if let Some(table_schema) = &table.schema {
512 map.insert(
513 format!("{}.{}", lower(table_schema), lower(&table.name)),
514 entry.clone(),
515 );
516 }
517
518 for alias in &table.aliases {
519 map.insert(lower(alias), entry.clone());
520 }
521 }
522
523 map
524}
525
526fn type_family_to_data_type(family: TypeFamily) -> DataType {
527 match family {
528 TypeFamily::Unknown => DataType::Unknown,
529 TypeFamily::Boolean => DataType::Boolean,
530 TypeFamily::Integer => DataType::Int {
531 length: None,
532 integer_spelling: false,
533 },
534 TypeFamily::Numeric => DataType::Double {
535 precision: None,
536 scale: None,
537 },
538 TypeFamily::String => DataType::VarChar {
539 length: None,
540 parenthesized_length: false,
541 },
542 TypeFamily::Binary => DataType::VarBinary { length: None },
543 TypeFamily::Date => DataType::Date,
544 TypeFamily::Time => DataType::Time {
545 precision: None,
546 timezone: false,
547 },
548 TypeFamily::Timestamp => DataType::Timestamp {
549 precision: None,
550 timezone: false,
551 },
552 TypeFamily::Interval => DataType::Interval {
553 unit: None,
554 to: None,
555 },
556 TypeFamily::Json => DataType::Json,
557 TypeFamily::Uuid => DataType::Uuid,
558 TypeFamily::Array => DataType::Array {
559 element_type: Box::new(DataType::Unknown),
560 dimension: None,
561 },
562 TypeFamily::Map => DataType::Map {
563 key_type: Box::new(DataType::Unknown),
564 value_type: Box::new(DataType::Unknown),
565 },
566 TypeFamily::Struct => DataType::Struct {
567 fields: Vec::new(),
568 nested: false,
569 },
570 }
571}
572
573fn build_resolver_schema(schema: &ValidationSchema) -> MappingSchema {
574 let mut mapping = MappingSchema::new();
575
576 for table in &schema.tables {
577 let columns: Vec<(String, DataType)> = table
578 .columns
579 .iter()
580 .map(|column| {
581 (
582 lower(&column.name),
583 type_family_to_data_type(canonical_type_family(&column.data_type)),
584 )
585 })
586 .collect();
587
588 let mut table_names = Vec::new();
589 table_names.push(lower(&table.name));
590 if let Some(table_schema) = &table.schema {
591 table_names.push(format!("{}.{}", lower(table_schema), lower(&table.name)));
592 }
593 for alias in &table.aliases {
594 table_names.push(lower(alias));
595 }
596
597 let mut dedup = HashSet::new();
598 for table_name in table_names {
599 if dedup.insert(table_name.clone()) {
600 let _ = mapping.add_table(&table_name, &columns, None);
601 }
602 }
603 }
604
605 mapping
606}
607
/// Builds a [`MappingSchema`] from a [`ValidationSchema`].
///
/// Public entry point that delegates to `build_resolver_schema`: tables are
/// registered under their lower-cased name, `schema.name` form and aliases.
pub fn mapping_schema_from_validation_schema(schema: &ValidationSchema) -> MappingSchema {
    build_resolver_schema(schema)
}
616
617fn collect_cte_aliases(expr: &Expression) -> HashSet<String> {
618 let mut aliases = HashSet::new();
619
620 for node in expr.dfs() {
621 match node {
622 Expression::Select(select) => {
623 if let Some(with) = &select.with {
624 for cte in &with.ctes {
625 aliases.insert(lower(&cte.alias.name));
626 }
627 }
628 }
629 Expression::Insert(insert) => {
630 if let Some(with) = &insert.with {
631 for cte in &with.ctes {
632 aliases.insert(lower(&cte.alias.name));
633 }
634 }
635 }
636 Expression::Update(update) => {
637 if let Some(with) = &update.with {
638 for cte in &with.ctes {
639 aliases.insert(lower(&cte.alias.name));
640 }
641 }
642 }
643 Expression::Delete(delete) => {
644 if let Some(with) = &delete.with {
645 for cte in &with.ctes {
646 aliases.insert(lower(&cte.alias.name));
647 }
648 }
649 }
650 Expression::Union(union) => {
651 if let Some(with) = &union.with {
652 for cte in &with.ctes {
653 aliases.insert(lower(&cte.alias.name));
654 }
655 }
656 }
657 Expression::Intersect(intersect) => {
658 if let Some(with) = &intersect.with {
659 for cte in &with.ctes {
660 aliases.insert(lower(&cte.alias.name));
661 }
662 }
663 }
664 Expression::Except(except) => {
665 if let Some(with) = &except.with {
666 for cte in &with.ctes {
667 aliases.insert(lower(&cte.alias.name));
668 }
669 }
670 }
671 Expression::Merge(merge) => {
672 if let Some(with_) = &merge.with_ {
673 if let Expression::With(with_clause) = with_.as_ref() {
674 for cte in &with_clause.ctes {
675 aliases.insert(lower(&cte.alias.name));
676 }
677 }
678 }
679 }
680 _ => {}
681 }
682 }
683
684 aliases
685}
686
687fn table_ref_candidates(table: &TableRef) -> Vec<String> {
688 let name = lower(&table.name.name);
689 let schema = table.schema.as_ref().map(|s| lower(&s.name));
690 let catalog = table.catalog.as_ref().map(|c| lower(&c.name));
691
692 let mut candidates = Vec::new();
693 if let (Some(catalog), Some(schema)) = (&catalog, &schema) {
694 candidates.push(format!("{}.{}.{}", catalog, schema, name));
695 }
696 if let Some(schema) = &schema {
697 candidates.push(format!("{}.{}", schema, name));
698 }
699 candidates.push(name);
700 candidates
701}
702
703fn table_ref_display_name(table: &TableRef) -> String {
704 let mut parts = Vec::new();
705 if let Some(catalog) = &table.catalog {
706 parts.push(catalog.name.clone());
707 }
708 if let Some(schema) = &table.schema {
709 parts.push(schema.name.clone());
710 }
711 parts.push(table.name.name.clone());
712 parts.join(".")
713}
714
/// Tables a statement refers to, as resolved against the schema map.
#[derive(Debug, Default, Clone)]
struct TypeCheckContext {
    /// Schema-map keys of the tables referenced by the statement.
    referenced_tables: HashSet<String>,
    /// Lower-cased table name or alias mapped to the schema-map key it resolved to.
    table_aliases: HashMap<String, String>,
}
720
721fn type_family_name(family: TypeFamily) -> &'static str {
722 match family {
723 TypeFamily::Unknown => "unknown",
724 TypeFamily::Boolean => "boolean",
725 TypeFamily::Integer => "integer",
726 TypeFamily::Numeric => "numeric",
727 TypeFamily::String => "string",
728 TypeFamily::Binary => "binary",
729 TypeFamily::Date => "date",
730 TypeFamily::Time => "time",
731 TypeFamily::Timestamp => "timestamp",
732 TypeFamily::Interval => "interval",
733 TypeFamily::Json => "json",
734 TypeFamily::Uuid => "uuid",
735 TypeFamily::Array => "array",
736 TypeFamily::Map => "map",
737 TypeFamily::Struct => "struct",
738 }
739}
740
741fn is_string_like(family: TypeFamily) -> bool {
742 matches!(family, TypeFamily::String)
743}
744
745fn is_string_or_binary(family: TypeFamily) -> bool {
746 matches!(family, TypeFamily::String | TypeFamily::Binary)
747}
748
749fn type_issue(
750 strict: bool,
751 error_code: &str,
752 warning_code: &str,
753 message: impl Into<String>,
754) -> ValidationError {
755 if strict {
756 ValidationError::error(message.into(), error_code)
757 } else {
758 ValidationError::warning(message.into(), warning_code)
759 }
760}
761
762fn data_type_family(data_type: &DataType) -> TypeFamily {
763 match data_type {
764 DataType::Boolean => TypeFamily::Boolean,
765 DataType::TinyInt { .. }
766 | DataType::SmallInt { .. }
767 | DataType::Int { .. }
768 | DataType::BigInt { .. } => TypeFamily::Integer,
769 DataType::Float { .. } | DataType::Double { .. } | DataType::Decimal { .. } => {
770 TypeFamily::Numeric
771 }
772 DataType::Char { .. }
773 | DataType::VarChar { .. }
774 | DataType::String { .. }
775 | DataType::Text
776 | DataType::TextWithLength { .. }
777 | DataType::CharacterSet { .. } => TypeFamily::String,
778 DataType::Binary { .. } | DataType::VarBinary { .. } | DataType::Blob => TypeFamily::Binary,
779 DataType::Date => TypeFamily::Date,
780 DataType::Time { .. } => TypeFamily::Time,
781 DataType::Timestamp { .. } => TypeFamily::Timestamp,
782 DataType::Interval { .. } => TypeFamily::Interval,
783 DataType::Json | DataType::JsonB => TypeFamily::Json,
784 DataType::Uuid => TypeFamily::Uuid,
785 DataType::Array { .. } | DataType::List { .. } => TypeFamily::Array,
786 DataType::Map { .. } => TypeFamily::Map,
787 DataType::Struct { .. } | DataType::Object { .. } | DataType::Union { .. } => {
788 TypeFamily::Struct
789 }
790 DataType::Nullable { inner } => data_type_family(inner),
791 DataType::Custom { name } => canonical_type_family(name),
792 DataType::Unknown => TypeFamily::Unknown,
793 DataType::Bit { .. } | DataType::VarBit { .. } => TypeFamily::Binary,
794 DataType::Enum { .. } | DataType::Set { .. } => TypeFamily::String,
795 DataType::Vector { .. } => TypeFamily::Array,
796 DataType::Geometry { .. } | DataType::Geography { .. } => TypeFamily::Struct,
797 }
798}
799
800fn collect_type_check_context(
801 stmt: &Expression,
802 schema_map: &HashMap<String, TableSchemaEntry>,
803) -> TypeCheckContext {
804 fn add_table_to_context(
805 table: &TableRef,
806 schema_map: &HashMap<String, TableSchemaEntry>,
807 context: &mut TypeCheckContext,
808 ) {
809 let resolved_key = table_ref_candidates(table)
810 .into_iter()
811 .find(|k| schema_map.contains_key(k));
812
813 let Some(table_key) = resolved_key else {
814 return;
815 };
816
817 context.referenced_tables.insert(table_key.clone());
818 context
819 .table_aliases
820 .insert(lower(&table.name.name), table_key.clone());
821 if let Some(alias) = &table.alias {
822 context
823 .table_aliases
824 .insert(lower(&alias.name), table_key.clone());
825 }
826 }
827
828 let mut context = TypeCheckContext::default();
829 let cte_aliases = collect_cte_aliases(stmt);
830
831 for node in stmt.find_all(|e| matches!(e, Expression::Table(_))) {
832 let Expression::Table(table) = node else {
833 continue;
834 };
835
836 if cte_aliases.contains(&lower(&table.name.name)) {
837 continue;
838 }
839
840 add_table_to_context(table, schema_map, &mut context);
841 }
842
843 match stmt {
846 Expression::Insert(insert) => {
847 add_table_to_context(&insert.table, schema_map, &mut context);
848 }
849 Expression::Update(update) => {
850 add_table_to_context(&update.table, schema_map, &mut context);
851 for table in &update.extra_tables {
852 add_table_to_context(table, schema_map, &mut context);
853 }
854 }
855 Expression::Delete(delete) => {
856 add_table_to_context(&delete.table, schema_map, &mut context);
857 for table in &delete.using {
858 add_table_to_context(table, schema_map, &mut context);
859 }
860 for table in &delete.tables {
861 add_table_to_context(table, schema_map, &mut context);
862 }
863 }
864 _ => {}
865 }
866
867 context
868}
869
870fn resolve_table_schema_entry<'a>(
871 table: &TableRef,
872 schema_map: &'a HashMap<String, TableSchemaEntry>,
873) -> Option<(String, &'a TableSchemaEntry)> {
874 let key = table_ref_candidates(table)
875 .into_iter()
876 .find(|k| schema_map.contains_key(k))?;
877 let entry = schema_map.get(&key)?;
878 Some((key, entry))
879}
880
881fn reference_issue(strict: bool, message: impl Into<String>) -> ValidationError {
882 if strict {
883 ValidationError::error(
884 message.into(),
885 validation_codes::E_INVALID_FOREIGN_KEY_REFERENCE,
886 )
887 } else {
888 ValidationError::warning(message.into(), validation_codes::W_WEAK_REFERENCE_INTEGRITY)
889 }
890}
891
892fn reference_table_candidates(
893 table_name: &str,
894 explicit_schema: Option<&str>,
895 source_schema: Option<&str>,
896) -> Vec<String> {
897 let mut candidates = Vec::new();
898 let raw = lower(table_name);
899
900 if let Some(schema) = explicit_schema {
901 candidates.push(format!("{}.{}", lower(schema), raw));
902 }
903
904 if raw.contains('.') {
905 candidates.push(raw.clone());
906 if let Some(last) = raw.rsplit('.').next() {
907 candidates.push(last.to_string());
908 }
909 } else {
910 if let Some(schema) = source_schema {
911 candidates.push(format!("{}.{}", lower(schema), raw));
912 }
913 candidates.push(raw);
914 }
915
916 let mut dedup = HashSet::new();
917 candidates
918 .into_iter()
919 .filter(|c| dedup.insert(c.clone()))
920 .collect()
921}
922
923fn resolve_reference_table_key(
924 table_name: &str,
925 explicit_schema: Option<&str>,
926 source_schema: Option<&str>,
927 schema_map: &HashMap<String, TableSchemaEntry>,
928) -> Option<String> {
929 reference_table_candidates(table_name, explicit_schema, source_schema)
930 .into_iter()
931 .find(|candidate| schema_map.contains_key(candidate))
932}
933
934fn key_types_compatible(source: TypeFamily, target: TypeFamily) -> bool {
935 if source == TypeFamily::Unknown || target == TypeFamily::Unknown {
936 return true;
937 }
938 if source == target {
939 return true;
940 }
941 if source.is_numeric() && target.is_numeric() {
942 return true;
943 }
944 if source.is_temporal() && target.is_temporal() {
945 return true;
946 }
947 false
948}
949
950fn table_key_hints(table: &SchemaTable) -> HashSet<String> {
951 let mut hints = HashSet::new();
952 for column in &table.columns {
953 if column.primary_key || column.unique {
954 hints.insert(lower(&column.name));
955 }
956 }
957 for key_col in &table.primary_key {
958 hints.insert(lower(key_col));
959 }
960 for group in &table.unique_keys {
961 if group.len() == 1 {
962 if let Some(col) = group.first() {
963 hints.insert(lower(col));
964 }
965 }
966 }
967 hints
968}
969
/// Validates the foreign-key style references declared in `schema`.
///
/// Both column-level references (`SchemaColumn::references`) and table-level
/// foreign keys (`SchemaTable::foreign_keys`) are checked for:
///
/// * an unknown target table or column,
/// * an empty or length-mismatched column list (table-level only),
/// * an unknown source column (table-level only),
/// * incompatible source/target type families,
/// * a target column that is not hinted as a primary/unique key.
///
/// Structural and type problems are reported through `reference_issue`, so
/// they are errors when `strict` is true and warnings otherwise. The
/// "not marked as primary/unique key" finding is always a warning.
///
/// Returns the diagnostics in discovery order: tables in schema order, and for
/// each table its column-level references before its table-level foreign keys.
fn check_reference_integrity(
    schema: &ValidationSchema,
    schema_map: &HashMap<String, TableSchemaEntry>,
    strict: bool,
) -> Vec<ValidationError> {
    let mut errors = Vec::new();

    // Key hints per table, keyed by simple name and by `schema.name`.
    // NOTE(review): aliases are not registered here although `build_schema_map`
    // registers them, so a target resolved through an alias key skips the
    // key-hint check below — confirm whether that is intended.
    let mut key_hints_lookup: HashMap<String, HashSet<String>> = HashMap::new();
    for table in &schema.tables {
        let simple = lower(&table.name);
        key_hints_lookup.insert(simple, table_key_hints(table));
        if let Some(schema_name) = &table.schema {
            let qualified = format!("{}.{}", lower(schema_name), lower(&table.name));
            key_hints_lookup.insert(qualified, table_key_hints(table));
        }
    }

    for table in &schema.tables {
        // Display name keeps the original casing for messages.
        let source_table_display = if let Some(schema_name) = &table.schema {
            format!("{}.{}", schema_name, table.name)
        } else {
            table.name.clone()
        };
        let source_schema = table.schema.as_deref();
        let source_columns: HashMap<String, TypeFamily> = table
            .columns
            .iter()
            .map(|col| (lower(&col.name), canonical_type_family(&col.data_type)))
            .collect();

        // Column-level references.
        for source_col in &table.columns {
            let Some(reference) = &source_col.references else {
                continue;
            };
            let source_type = canonical_type_family(&source_col.data_type);

            let Some(target_key) = resolve_reference_table_key(
                &reference.table,
                reference.schema.as_deref(),
                source_schema,
                schema_map,
            ) else {
                errors.push(reference_issue(
                    strict,
                    format!(
                        "Foreign key reference '{}.{}' points to unknown table '{}'",
                        source_table_display, source_col.name, reference.table
                    ),
                ));
                continue;
            };

            let target_column = lower(&reference.column);
            // Defensive: `resolve_reference_table_key` only returns keys that
            // are present in `schema_map`, so this branch should not trigger.
            let Some(target_entry) = schema_map.get(&target_key) else {
                errors.push(reference_issue(
                    strict,
                    format!(
                        "Foreign key reference '{}.{}' points to unknown table '{}'",
                        source_table_display, source_col.name, reference.table
                    ),
                ));
                continue;
            };

            let Some(target_type) = target_entry.columns.get(&target_column).copied() else {
                errors.push(reference_issue(
                    strict,
                    format!(
                        "Foreign key reference '{}.{}' points to unknown column '{}.{}'",
                        source_table_display, source_col.name, target_key, reference.column
                    ),
                ));
                continue;
            };

            if !key_types_compatible(source_type, target_type) {
                errors.push(reference_issue(
                    strict,
                    format!(
                        "Foreign key type mismatch for '{}.{}' -> '{}.{}': {} vs {}",
                        source_table_display,
                        source_col.name,
                        target_key,
                        reference.column,
                        type_family_name(source_type),
                        type_family_name(target_type)
                    ),
                ));
            }

            // Always a warning, independent of `strict`.
            if let Some(target_key_hints) = key_hints_lookup.get(&target_key) {
                if !target_key_hints.contains(&target_column) {
                    errors.push(ValidationError::warning(
                        format!(
                            "Referenced column '{}.{}' is not marked as primary/unique key",
                            target_key, reference.column
                        ),
                        validation_codes::W_WEAK_REFERENCE_INTEGRITY,
                    ));
                }
            }
        }

        // Table-level foreign keys.
        for foreign_key in &table.foreign_keys {
            if foreign_key.columns.is_empty() || foreign_key.references.columns.is_empty() {
                errors.push(reference_issue(
                    strict,
                    format!(
                        "Table-level foreign key on '{}' has empty source or target column list",
                        source_table_display
                    ),
                ));
                continue;
            }
            if foreign_key.columns.len() != foreign_key.references.columns.len() {
                errors.push(reference_issue(
                    strict,
                    format!(
                        "Table-level foreign key on '{}' has {} source columns but {} target columns",
                        source_table_display,
                        foreign_key.columns.len(),
                        foreign_key.references.columns.len()
                    ),
                ));
                continue;
            }

            let Some(target_key) = resolve_reference_table_key(
                &foreign_key.references.table,
                foreign_key.references.schema.as_deref(),
                source_schema,
                schema_map,
            ) else {
                errors.push(reference_issue(
                    strict,
                    format!(
                        "Table-level foreign key on '{}' points to unknown table '{}'",
                        source_table_display, foreign_key.references.table
                    ),
                ));
                continue;
            };

            // Defensive, as above: the resolved key is expected to be present.
            let Some(target_entry) = schema_map.get(&target_key) else {
                errors.push(reference_issue(
                    strict,
                    format!(
                        "Table-level foreign key on '{}' points to unknown table '{}'",
                        source_table_display, foreign_key.references.table
                    ),
                ));
                continue;
            };

            // Source and target columns are paired positionally; the lengths
            // were verified to be equal above.
            for (source_col, target_col) in foreign_key
                .columns
                .iter()
                .zip(foreign_key.references.columns.iter())
            {
                let source_col_name = lower(source_col);
                let target_col_name = lower(target_col);

                let Some(source_type) = source_columns.get(&source_col_name).copied() else {
                    errors.push(reference_issue(
                        strict,
                        format!(
                            "Table-level foreign key on '{}' references unknown source column '{}'",
                            source_table_display, source_col
                        ),
                    ));
                    continue;
                };

                let Some(target_type) = target_entry.columns.get(&target_col_name).copied() else {
                    errors.push(reference_issue(
                        strict,
                        format!(
                            "Table-level foreign key on '{}' references unknown target column '{}.{}'",
                            source_table_display, target_key, target_col
                        ),
                    ));
                    continue;
                };

                if !key_types_compatible(source_type, target_type) {
                    errors.push(reference_issue(
                        strict,
                        format!(
                            "Table-level foreign key type mismatch '{}.{}' -> '{}.{}': {} vs {}",
                            source_table_display,
                            source_col,
                            target_key,
                            target_col,
                            type_family_name(source_type),
                            type_family_name(target_type)
                        ),
                    ));
                }

                // Checked per column. `table_key_hints` only records
                // single-column unique keys, so every column of a multi-column
                // unique key referenced here is reported by this warning.
                if let Some(target_key_hints) = key_hints_lookup.get(&target_key) {
                    if !target_key_hints.contains(&target_col_name) {
                        errors.push(ValidationError::warning(
                            format!(
                                "Referenced column '{}.{}' is not marked as primary/unique key",
                                target_key, target_col
                            ),
                            validation_codes::W_WEAK_REFERENCE_INTEGRITY,
                        ));
                    }
                }
            }
        }
    }

    errors
}
1186
1187fn resolve_unqualified_column_type(
1188 column_name: &str,
1189 schema_map: &HashMap<String, TableSchemaEntry>,
1190 context: &TypeCheckContext,
1191) -> TypeFamily {
1192 let candidate_tables: Vec<&String> = if !context.referenced_tables.is_empty() {
1193 context.referenced_tables.iter().collect()
1194 } else {
1195 schema_map.keys().collect()
1196 };
1197
1198 let mut families = HashSet::new();
1199 for table_name in candidate_tables {
1200 if let Some(table_schema) = schema_map.get(table_name) {
1201 if let Some(family) = table_schema.columns.get(column_name) {
1202 families.insert(*family);
1203 }
1204 }
1205 }
1206
1207 if families.len() == 1 {
1208 *families.iter().next().unwrap_or(&TypeFamily::Unknown)
1209 } else {
1210 TypeFamily::Unknown
1211 }
1212}
1213
1214fn resolve_column_type(
1215 column: &Column,
1216 schema_map: &HashMap<String, TableSchemaEntry>,
1217 context: &TypeCheckContext,
1218) -> TypeFamily {
1219 let column_name = lower(&column.name.name);
1220 if column_name.is_empty() {
1221 return TypeFamily::Unknown;
1222 }
1223
1224 if let Some(table) = &column.table {
1225 let mut table_key = lower(&table.name);
1226 if let Some(mapped) = context.table_aliases.get(&table_key) {
1227 table_key = mapped.clone();
1228 }
1229
1230 return schema_map
1231 .get(&table_key)
1232 .and_then(|t| t.columns.get(&column_name))
1233 .copied()
1234 .unwrap_or(TypeFamily::Unknown);
1235 }
1236
1237 resolve_unqualified_column_type(&column_name, schema_map, context)
1238}
1239
1240struct TypeInferenceSchema<'a> {
1241 schema_map: &'a HashMap<String, TableSchemaEntry>,
1242 context: &'a TypeCheckContext,
1243}
1244
1245impl TypeInferenceSchema<'_> {
1246 fn resolve_table_key(&self, table: &str) -> Option<String> {
1247 let mut table_key = lower(table);
1248 if let Some(mapped) = self.context.table_aliases.get(&table_key) {
1249 table_key = mapped.clone();
1250 }
1251 if self.schema_map.contains_key(&table_key) {
1252 Some(table_key)
1253 } else {
1254 None
1255 }
1256 }
1257}
1258
1259impl SqlSchema for TypeInferenceSchema<'_> {
1260 fn dialect(&self) -> Option<DialectType> {
1261 None
1262 }
1263
1264 fn add_table(
1265 &mut self,
1266 _table: &str,
1267 _columns: &[(String, DataType)],
1268 _dialect: Option<DialectType>,
1269 ) -> SchemaResult<()> {
1270 Err(SchemaError::InvalidStructure(
1271 "Type inference schema is read-only".to_string(),
1272 ))
1273 }
1274
1275 fn column_names(&self, table: &str) -> SchemaResult<Vec<String>> {
1276 let table_key = self
1277 .resolve_table_key(table)
1278 .ok_or_else(|| SchemaError::TableNotFound(table.to_string()))?;
1279 let entry = self
1280 .schema_map
1281 .get(&table_key)
1282 .ok_or_else(|| SchemaError::TableNotFound(table.to_string()))?;
1283 Ok(entry.column_order.clone())
1284 }
1285
1286 fn get_column_type(&self, table: &str, column: &str) -> SchemaResult<DataType> {
1287 let col_name = lower(column);
1288 if table.is_empty() {
1289 let family = resolve_unqualified_column_type(&col_name, self.schema_map, self.context);
1290 return if family == TypeFamily::Unknown {
1291 Err(SchemaError::ColumnNotFound {
1292 table: "<unqualified>".to_string(),
1293 column: column.to_string(),
1294 })
1295 } else {
1296 Ok(type_family_to_data_type(family))
1297 };
1298 }
1299
1300 let table_key = self
1301 .resolve_table_key(table)
1302 .ok_or_else(|| SchemaError::TableNotFound(table.to_string()))?;
1303 let entry = self
1304 .schema_map
1305 .get(&table_key)
1306 .ok_or_else(|| SchemaError::TableNotFound(table.to_string()))?;
1307 let family =
1308 entry
1309 .columns
1310 .get(&col_name)
1311 .copied()
1312 .ok_or_else(|| SchemaError::ColumnNotFound {
1313 table: table.to_string(),
1314 column: column.to_string(),
1315 })?;
1316 Ok(type_family_to_data_type(family))
1317 }
1318
1319 fn has_column(&self, table: &str, column: &str) -> bool {
1320 self.get_column_type(table, column).is_ok()
1321 }
1322
1323 fn supported_table_args(&self) -> &[&str] {
1324 TABLE_PARTS
1325 }
1326
1327 fn is_empty(&self) -> bool {
1328 self.schema_map.is_empty()
1329 }
1330
1331 fn depth(&self) -> usize {
1332 1
1333 }
1334}
1335
1336fn infer_expression_type_family(
1337 expr: &Expression,
1338 schema_map: &HashMap<String, TableSchemaEntry>,
1339 context: &TypeCheckContext,
1340) -> TypeFamily {
1341 let inference_schema = TypeInferenceSchema {
1342 schema_map,
1343 context,
1344 };
1345 let mut expr_clone = expr.clone();
1346 annotate_types(&mut expr_clone, Some(&inference_schema), None);
1347 if let Some(data_type) = expr_clone.inferred_type() {
1348 let family = data_type_family(&data_type);
1349 if family != TypeFamily::Unknown {
1350 return family;
1351 }
1352 }
1353
1354 infer_expression_type_family_fallback(expr, schema_map, context)
1355}
1356
1357fn infer_expression_type_family_fallback(
1358 expr: &Expression,
1359 schema_map: &HashMap<String, TableSchemaEntry>,
1360 context: &TypeCheckContext,
1361) -> TypeFamily {
1362 match expr {
1363 Expression::Literal(literal) => match literal {
1364 crate::expressions::Literal::Number(value) => {
1365 if value.contains('.') || value.contains('e') || value.contains('E') {
1366 TypeFamily::Numeric
1367 } else {
1368 TypeFamily::Integer
1369 }
1370 }
1371 crate::expressions::Literal::HexNumber(_) => TypeFamily::Integer,
1372 crate::expressions::Literal::Date(_) => TypeFamily::Date,
1373 crate::expressions::Literal::Time(_) => TypeFamily::Time,
1374 crate::expressions::Literal::Timestamp(_)
1375 | crate::expressions::Literal::Datetime(_) => TypeFamily::Timestamp,
1376 crate::expressions::Literal::HexString(_)
1377 | crate::expressions::Literal::BitString(_)
1378 | crate::expressions::Literal::ByteString(_) => TypeFamily::Binary,
1379 _ => TypeFamily::String,
1380 },
1381 Expression::Boolean(_) => TypeFamily::Boolean,
1382 Expression::Null(_) => TypeFamily::Unknown,
1383 Expression::Column(column) => resolve_column_type(column, schema_map, context),
1384 Expression::Cast(cast) | Expression::TryCast(cast) | Expression::SafeCast(cast) => {
1385 data_type_family(&cast.to)
1386 }
1387 Expression::Alias(alias) => {
1388 infer_expression_type_family_fallback(&alias.this, schema_map, context)
1389 }
1390 Expression::Neg(unary) => {
1391 infer_expression_type_family_fallback(&unary.this, schema_map, context)
1392 }
1393 Expression::Add(op) | Expression::Sub(op) | Expression::Mul(op) => {
1394 let left = infer_expression_type_family_fallback(&op.left, schema_map, context);
1395 let right = infer_expression_type_family_fallback(&op.right, schema_map, context);
1396 if left == TypeFamily::Unknown || right == TypeFamily::Unknown {
1397 TypeFamily::Unknown
1398 } else if left == TypeFamily::Integer && right == TypeFamily::Integer {
1399 TypeFamily::Integer
1400 } else if left.is_numeric() && right.is_numeric() {
1401 TypeFamily::Numeric
1402 } else if left.is_temporal() || right.is_temporal() {
1403 left
1404 } else {
1405 TypeFamily::Unknown
1406 }
1407 }
1408 Expression::Div(_) | Expression::Mod(_) => TypeFamily::Numeric,
1409 Expression::Concat(_) => TypeFamily::String,
1410 Expression::Eq(_)
1411 | Expression::Neq(_)
1412 | Expression::Lt(_)
1413 | Expression::Lte(_)
1414 | Expression::Gt(_)
1415 | Expression::Gte(_)
1416 | Expression::Like(_)
1417 | Expression::ILike(_)
1418 | Expression::And(_)
1419 | Expression::Or(_)
1420 | Expression::Not(_)
1421 | Expression::Between(_)
1422 | Expression::In(_)
1423 | Expression::IsNull(_)
1424 | Expression::IsTrue(_)
1425 | Expression::IsFalse(_)
1426 | Expression::Is(_) => TypeFamily::Boolean,
1427 Expression::Length(_) => TypeFamily::Integer,
1428 Expression::Upper(_)
1429 | Expression::Lower(_)
1430 | Expression::Trim(_)
1431 | Expression::LTrim(_)
1432 | Expression::RTrim(_)
1433 | Expression::Replace(_)
1434 | Expression::Substring(_)
1435 | Expression::Left(_)
1436 | Expression::Right(_)
1437 | Expression::Repeat(_)
1438 | Expression::Lpad(_)
1439 | Expression::Rpad(_)
1440 | Expression::ConcatWs(_) => TypeFamily::String,
1441 Expression::Abs(_)
1442 | Expression::Round(_)
1443 | Expression::Floor(_)
1444 | Expression::Ceil(_)
1445 | Expression::Power(_)
1446 | Expression::Sqrt(_)
1447 | Expression::Cbrt(_)
1448 | Expression::Ln(_)
1449 | Expression::Log(_)
1450 | Expression::Exp(_) => TypeFamily::Numeric,
1451 Expression::DateAdd(_) | Expression::DateSub(_) | Expression::ToDate(_) => TypeFamily::Date,
1452 Expression::ToTimestamp(_) => TypeFamily::Timestamp,
1453 Expression::DateDiff(_) | Expression::Extract(_) => TypeFamily::Integer,
1454 Expression::CurrentDate(_) => TypeFamily::Date,
1455 Expression::CurrentTime(_) => TypeFamily::Time,
1456 Expression::CurrentTimestamp(_) | Expression::CurrentTimestampLTZ(_) => {
1457 TypeFamily::Timestamp
1458 }
1459 Expression::Interval(_) => TypeFamily::Interval,
1460 _ => TypeFamily::Unknown,
1461 }
1462}
1463
1464fn are_comparable(left: TypeFamily, right: TypeFamily) -> bool {
1465 if left == TypeFamily::Unknown || right == TypeFamily::Unknown {
1466 return true;
1467 }
1468 if left == right {
1469 return true;
1470 }
1471 if left.is_numeric() && right.is_numeric() {
1472 return true;
1473 }
1474 if left.is_temporal() && right.is_temporal() {
1475 return true;
1476 }
1477 false
1478}
1479
1480fn check_function_argument(
1481 errors: &mut Vec<ValidationError>,
1482 strict: bool,
1483 function_name: &str,
1484 arg_index: usize,
1485 family: TypeFamily,
1486 expected: &str,
1487 valid: bool,
1488) {
1489 if family == TypeFamily::Unknown || valid {
1490 return;
1491 }
1492
1493 errors.push(type_issue(
1494 strict,
1495 validation_codes::E_INVALID_FUNCTION_ARGUMENT_TYPE,
1496 validation_codes::W_FUNCTION_ARGUMENT_COERCION,
1497 format!(
1498 "Function '{}' argument {} expects {}, found {}",
1499 function_name,
1500 arg_index + 1,
1501 expected,
1502 type_family_name(family)
1503 ),
1504 ));
1505}
1506
1507fn function_dispatch_name(name: &str) -> String {
1508 let upper = name
1509 .rsplit('.')
1510 .next()
1511 .unwrap_or(name)
1512 .trim()
1513 .to_uppercase();
1514 lower(canonical_typed_function_name_upper(&upper))
1515}
1516
/// Returns the unqualified function name: the text after the last `.`, trimmed.
///
/// A name without any `.` is returned trimmed but otherwise unchanged.
fn function_base_name(name: &str) -> &str {
    let unqualified = name
        .rsplit_once('.')
        .map_or(name, |(_qualifier, base)| base);
    unqualified.trim()
}
1520
1521fn check_generic_function(
1522 function: &Function,
1523 schema_map: &HashMap<String, TableSchemaEntry>,
1524 context: &TypeCheckContext,
1525 strict: bool,
1526 errors: &mut Vec<ValidationError>,
1527) {
1528 let name = function_dispatch_name(&function.name);
1529
1530 let arg_family = |index: usize| -> Option<TypeFamily> {
1531 function
1532 .args
1533 .get(index)
1534 .map(|arg| infer_expression_type_family(arg, schema_map, context))
1535 };
1536
1537 match name.as_str() {
1538 "abs" | "sqrt" | "cbrt" | "ln" | "exp" => {
1539 if let Some(family) = arg_family(0) {
1540 check_function_argument(
1541 errors,
1542 strict,
1543 &name,
1544 0,
1545 family,
1546 "a numeric argument",
1547 family.is_numeric(),
1548 );
1549 }
1550 }
1551 "round" | "floor" | "ceil" | "ceiling" => {
1552 if let Some(family) = arg_family(0) {
1553 check_function_argument(
1554 errors,
1555 strict,
1556 &name,
1557 0,
1558 family,
1559 "a numeric argument",
1560 family.is_numeric(),
1561 );
1562 }
1563 if let Some(family) = arg_family(1) {
1564 check_function_argument(
1565 errors,
1566 strict,
1567 &name,
1568 1,
1569 family,
1570 "a numeric argument",
1571 family.is_numeric(),
1572 );
1573 }
1574 }
1575 "power" | "pow" => {
1576 for i in [0_usize, 1_usize] {
1577 if let Some(family) = arg_family(i) {
1578 check_function_argument(
1579 errors,
1580 strict,
1581 &name,
1582 i,
1583 family,
1584 "a numeric argument",
1585 family.is_numeric(),
1586 );
1587 }
1588 }
1589 }
1590 "length" | "char_length" | "character_length" => {
1591 if let Some(family) = arg_family(0) {
1592 check_function_argument(
1593 errors,
1594 strict,
1595 &name,
1596 0,
1597 family,
1598 "a string or binary argument",
1599 is_string_or_binary(family),
1600 );
1601 }
1602 }
1603 "upper" | "lower" | "trim" | "ltrim" | "rtrim" | "reverse" => {
1604 if let Some(family) = arg_family(0) {
1605 check_function_argument(
1606 errors,
1607 strict,
1608 &name,
1609 0,
1610 family,
1611 "a string argument",
1612 is_string_like(family),
1613 );
1614 }
1615 }
1616 "substring" | "substr" => {
1617 if let Some(family) = arg_family(0) {
1618 check_function_argument(
1619 errors,
1620 strict,
1621 &name,
1622 0,
1623 family,
1624 "a string argument",
1625 is_string_like(family),
1626 );
1627 }
1628 if let Some(family) = arg_family(1) {
1629 check_function_argument(
1630 errors,
1631 strict,
1632 &name,
1633 1,
1634 family,
1635 "a numeric argument",
1636 family.is_numeric(),
1637 );
1638 }
1639 if let Some(family) = arg_family(2) {
1640 check_function_argument(
1641 errors,
1642 strict,
1643 &name,
1644 2,
1645 family,
1646 "a numeric argument",
1647 family.is_numeric(),
1648 );
1649 }
1650 }
1651 "replace" => {
1652 for i in [0_usize, 1_usize, 2_usize] {
1653 if let Some(family) = arg_family(i) {
1654 check_function_argument(
1655 errors,
1656 strict,
1657 &name,
1658 i,
1659 family,
1660 "a string argument",
1661 is_string_like(family),
1662 );
1663 }
1664 }
1665 }
1666 "left" | "right" | "repeat" | "lpad" | "rpad" => {
1667 if let Some(family) = arg_family(0) {
1668 check_function_argument(
1669 errors,
1670 strict,
1671 &name,
1672 0,
1673 family,
1674 "a string argument",
1675 is_string_like(family),
1676 );
1677 }
1678 if let Some(family) = arg_family(1) {
1679 check_function_argument(
1680 errors,
1681 strict,
1682 &name,
1683 1,
1684 family,
1685 "a numeric argument",
1686 family.is_numeric(),
1687 );
1688 }
1689 if (name == "lpad" || name == "rpad") && function.args.len() > 2 {
1690 if let Some(family) = arg_family(2) {
1691 check_function_argument(
1692 errors,
1693 strict,
1694 &name,
1695 2,
1696 family,
1697 "a string argument",
1698 is_string_like(family),
1699 );
1700 }
1701 }
1702 }
1703 _ => {}
1704 }
1705}
1706
1707fn check_function_catalog(
1708 function: &Function,
1709 dialect: DialectType,
1710 function_catalog: Option<&dyn FunctionCatalog>,
1711 strict: bool,
1712 errors: &mut Vec<ValidationError>,
1713) {
1714 let Some(catalog) = function_catalog else {
1715 return;
1716 };
1717
1718 let raw_name = function_base_name(&function.name);
1719 let normalized_name = function_dispatch_name(&function.name);
1720 let arity = function.args.len();
1721 let Some(signatures) = catalog.lookup(dialect, raw_name, &normalized_name) else {
1722 errors.push(if strict {
1723 ValidationError::error(
1724 format!(
1725 "Unknown function '{}' for dialect {:?}",
1726 function.name, dialect
1727 ),
1728 validation_codes::E_UNKNOWN_FUNCTION,
1729 )
1730 } else {
1731 ValidationError::warning(
1732 format!(
1733 "Unknown function '{}' for dialect {:?}",
1734 function.name, dialect
1735 ),
1736 validation_codes::E_UNKNOWN_FUNCTION,
1737 )
1738 });
1739 return;
1740 };
1741
1742 if signatures.iter().any(|sig| sig.matches_arity(arity)) {
1743 return;
1744 }
1745
1746 let expected = signatures
1747 .iter()
1748 .map(|sig| sig.describe_arity())
1749 .collect::<Vec<_>>()
1750 .join(", ");
1751 errors.push(if strict {
1752 ValidationError::error(
1753 format!(
1754 "Invalid arity for function '{}': got {}, expected {}",
1755 function.name, arity, expected
1756 ),
1757 validation_codes::E_INVALID_FUNCTION_ARITY,
1758 )
1759 } else {
1760 ValidationError::warning(
1761 format!(
1762 "Invalid arity for function '{}': got {}, expected {}",
1763 function.name, arity, expected
1764 ),
1765 validation_codes::E_INVALID_FUNCTION_ARITY,
1766 )
1767 });
1768}
1769
/// A single-column foreign-key relationship declared in the validation schema.
///
/// Instances are produced by `build_declared_relationships`, which stores resolved
/// schema-map keys in the table fields and lowercased names in the column fields. A
/// multi-column foreign key is represented as one value per column pair.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
struct DeclaredRelationship {
    /// Schema-map key of the table that declares the reference.
    source_table: String,
    /// Lowercased name of the referencing column.
    source_column: String,
    /// Schema-map key of the referenced table.
    target_table: String,
    /// Lowercased name of the referenced column.
    target_column: String,
}
1777
1778fn build_declared_relationships(
1779 schema: &ValidationSchema,
1780 schema_map: &HashMap<String, TableSchemaEntry>,
1781) -> Vec<DeclaredRelationship> {
1782 let mut relationships = HashSet::new();
1783
1784 for table in &schema.tables {
1785 let Some(source_key) =
1786 resolve_reference_table_key(&table.name, table.schema.as_deref(), None, schema_map)
1787 else {
1788 continue;
1789 };
1790
1791 for column in &table.columns {
1792 let Some(reference) = &column.references else {
1793 continue;
1794 };
1795 let Some(target_key) = resolve_reference_table_key(
1796 &reference.table,
1797 reference.schema.as_deref(),
1798 table.schema.as_deref(),
1799 schema_map,
1800 ) else {
1801 continue;
1802 };
1803
1804 relationships.insert(DeclaredRelationship {
1805 source_table: source_key.clone(),
1806 source_column: lower(&column.name),
1807 target_table: target_key,
1808 target_column: lower(&reference.column),
1809 });
1810 }
1811
1812 for foreign_key in &table.foreign_keys {
1813 if foreign_key.columns.len() != foreign_key.references.columns.len() {
1814 continue;
1815 }
1816 let Some(target_key) = resolve_reference_table_key(
1817 &foreign_key.references.table,
1818 foreign_key.references.schema.as_deref(),
1819 table.schema.as_deref(),
1820 schema_map,
1821 ) else {
1822 continue;
1823 };
1824
1825 for (source_col, target_col) in foreign_key
1826 .columns
1827 .iter()
1828 .zip(foreign_key.references.columns.iter())
1829 {
1830 relationships.insert(DeclaredRelationship {
1831 source_table: source_key.clone(),
1832 source_column: lower(source_col),
1833 target_table: target_key.clone(),
1834 target_column: lower(target_col),
1835 });
1836 }
1837 }
1838 }
1839
1840 relationships.into_iter().collect()
1841}
1842
1843fn resolve_column_binding(
1844 column: &Column,
1845 schema_map: &HashMap<String, TableSchemaEntry>,
1846 context: &TypeCheckContext,
1847 resolver: &mut Resolver<'_>,
1848) -> Option<(String, String)> {
1849 let column_name = lower(&column.name.name);
1850 if column_name.is_empty() {
1851 return None;
1852 }
1853
1854 if let Some(table) = &column.table {
1855 let mut table_key = lower(&table.name);
1856 if let Some(mapped) = context.table_aliases.get(&table_key) {
1857 table_key = mapped.clone();
1858 }
1859 if schema_map.contains_key(&table_key) {
1860 return Some((table_key, column_name));
1861 }
1862 return None;
1863 }
1864
1865 if let Some(resolved_source) = resolver.get_table(&column_name) {
1866 let mut table_key = lower(&resolved_source);
1867 if let Some(mapped) = context.table_aliases.get(&table_key) {
1868 table_key = mapped.clone();
1869 }
1870 if schema_map.contains_key(&table_key) {
1871 return Some((table_key, column_name));
1872 }
1873 }
1874
1875 let candidates: Vec<String> = context
1876 .referenced_tables
1877 .iter()
1878 .filter_map(|table_name| {
1879 schema_map
1880 .get(table_name)
1881 .filter(|entry| entry.columns.contains_key(&column_name))
1882 .map(|_| table_name.clone())
1883 })
1884 .collect();
1885 if candidates.len() == 1 {
1886 return Some((candidates[0].clone(), column_name));
1887 }
1888 None
1889}
1890
1891fn extract_join_equality_pairs(
1892 expr: &Expression,
1893 schema_map: &HashMap<String, TableSchemaEntry>,
1894 context: &TypeCheckContext,
1895 resolver: &mut Resolver<'_>,
1896 pairs: &mut Vec<((String, String), (String, String))>,
1897) {
1898 match expr {
1899 Expression::And(op) => {
1900 extract_join_equality_pairs(&op.left, schema_map, context, resolver, pairs);
1901 extract_join_equality_pairs(&op.right, schema_map, context, resolver, pairs);
1902 }
1903 Expression::Paren(paren) => {
1904 extract_join_equality_pairs(&paren.this, schema_map, context, resolver, pairs);
1905 }
1906 Expression::Eq(op) => {
1907 let (Expression::Column(left_col), Expression::Column(right_col)) =
1908 (&op.left, &op.right)
1909 else {
1910 return;
1911 };
1912 let Some(left) = resolve_column_binding(left_col, schema_map, context, resolver) else {
1913 return;
1914 };
1915 let Some(right) = resolve_column_binding(right_col, schema_map, context, resolver)
1916 else {
1917 return;
1918 };
1919 pairs.push((left, right));
1920 }
1921 _ => {}
1922 }
1923}
1924
1925fn relationship_matches_pair(
1926 relationship: &DeclaredRelationship,
1927 left_table: &str,
1928 left_column: &str,
1929 right_table: &str,
1930 right_column: &str,
1931) -> bool {
1932 (relationship.source_table == left_table
1933 && relationship.source_column == left_column
1934 && relationship.target_table == right_table
1935 && relationship.target_column == right_column)
1936 || (relationship.source_table == right_table
1937 && relationship.source_column == right_column
1938 && relationship.target_table == left_table
1939 && relationship.target_column == left_column)
1940}
1941
1942fn resolved_table_key_from_expr(
1943 expr: &Expression,
1944 schema_map: &HashMap<String, TableSchemaEntry>,
1945) -> Option<String> {
1946 match expr {
1947 Expression::Table(table) => resolve_table_schema_entry(table, schema_map).map(|(k, _)| k),
1948 Expression::Alias(alias) => resolved_table_key_from_expr(&alias.this, schema_map),
1949 _ => None,
1950 }
1951}
1952
1953fn select_from_table_keys(
1954 select: &crate::expressions::Select,
1955 schema_map: &HashMap<String, TableSchemaEntry>,
1956) -> HashSet<String> {
1957 let mut keys = HashSet::new();
1958 if let Some(from_clause) = &select.from {
1959 for expr in &from_clause.expressions {
1960 if let Some(key) = resolved_table_key_from_expr(expr, schema_map) {
1961 keys.insert(key);
1962 }
1963 }
1964 }
1965 keys
1966}
1967
1968fn is_natural_or_implied_join(kind: JoinKind) -> bool {
1969 matches!(
1970 kind,
1971 JoinKind::Natural
1972 | JoinKind::NaturalLeft
1973 | JoinKind::NaturalRight
1974 | JoinKind::NaturalFull
1975 | JoinKind::CrossApply
1976 | JoinKind::OuterApply
1977 | JoinKind::AsOf
1978 | JoinKind::AsOfLeft
1979 | JoinKind::AsOfRight
1980 | JoinKind::Lateral
1981 | JoinKind::LeftLateral
1982 )
1983}
1984
1985fn check_query_reference_quality(
1986 stmt: &Expression,
1987 schema_map: &HashMap<String, TableSchemaEntry>,
1988 resolver_schema: &MappingSchema,
1989 strict: bool,
1990 relationships: &[DeclaredRelationship],
1991) -> Vec<ValidationError> {
1992 let mut errors = Vec::new();
1993
1994 for node in stmt.dfs() {
1995 let Expression::Select(select) = node else {
1996 continue;
1997 };
1998
1999 let select_expr = Expression::Select(select.clone());
2000 let context = collect_type_check_context(&select_expr, schema_map);
2001 let scope = build_scope(&select_expr);
2002 let mut resolver = Resolver::new(&scope, resolver_schema, true);
2003
2004 if context.referenced_tables.len() > 1 {
2005 let using_columns: HashSet<String> = select
2006 .joins
2007 .iter()
2008 .flat_map(|join| join.using.iter().map(|id| lower(&id.name)))
2009 .collect();
2010
2011 let mut seen = HashSet::new();
2012 for column_expr in select_expr
2013 .find_all(|e| matches!(e, Expression::Column(Column { table: None, .. })))
2014 {
2015 let Expression::Column(column) = column_expr else {
2016 continue;
2017 };
2018
2019 let col_name = lower(&column.name.name);
2020 if col_name.is_empty()
2021 || using_columns.contains(&col_name)
2022 || !seen.insert(col_name.clone())
2023 {
2024 continue;
2025 }
2026
2027 if resolver.is_ambiguous(&col_name) {
2028 let source_count = resolver.sources_for_column(&col_name).len();
2029 errors.push(if strict {
2030 ValidationError::error(
2031 format!(
2032 "Ambiguous unqualified column '{}' found in {} referenced tables",
2033 col_name, source_count
2034 ),
2035 validation_codes::E_AMBIGUOUS_COLUMN_REFERENCE,
2036 )
2037 } else {
2038 ValidationError::warning(
2039 format!(
2040 "Ambiguous unqualified column '{}' found in {} referenced tables",
2041 col_name, source_count
2042 ),
2043 validation_codes::W_WEAK_REFERENCE_INTEGRITY,
2044 )
2045 });
2046 }
2047 }
2048 }
2049
2050 let mut cumulative_left_tables = select_from_table_keys(select, schema_map);
2051
2052 for join in &select.joins {
2053 let right_table_key = resolved_table_key_from_expr(&join.this, schema_map);
2054 let has_explicit_condition = join.on.is_some() || !join.using.is_empty();
2055 let cartesian_like_kind = matches!(
2056 join.kind,
2057 JoinKind::Cross
2058 | JoinKind::Implicit
2059 | JoinKind::Array
2060 | JoinKind::LeftArray
2061 | JoinKind::Paste
2062 );
2063
2064 if right_table_key.is_some()
2065 && (cartesian_like_kind
2066 || (!has_explicit_condition && !is_natural_or_implied_join(join.kind)))
2067 {
2068 errors.push(ValidationError::warning(
2069 "Potential cartesian join: JOIN without ON/USING condition",
2070 validation_codes::W_CARTESIAN_JOIN,
2071 ));
2072 }
2073
2074 if let (Some(on_expr), Some(right_key)) = (&join.on, right_table_key.clone()) {
2075 if join.using.is_empty() {
2076 let mut eq_pairs = Vec::new();
2077 extract_join_equality_pairs(
2078 on_expr,
2079 schema_map,
2080 &context,
2081 &mut resolver,
2082 &mut eq_pairs,
2083 );
2084
2085 let relevant_relationships: Vec<&DeclaredRelationship> = relationships
2086 .iter()
2087 .filter(|rel| {
2088 cumulative_left_tables.contains(&rel.source_table)
2089 && rel.target_table == right_key
2090 || (cumulative_left_tables.contains(&rel.target_table)
2091 && rel.source_table == right_key)
2092 })
2093 .collect();
2094
2095 if !relevant_relationships.is_empty() {
2096 let uses_declared_fk = eq_pairs.iter().any(|((lt, lc), (rt, rc))| {
2097 relevant_relationships
2098 .iter()
2099 .any(|rel| relationship_matches_pair(rel, lt, lc, rt, rc))
2100 });
2101 if !uses_declared_fk {
2102 errors.push(ValidationError::warning(
2103 "JOIN predicate does not use declared foreign-key relationship columns",
2104 validation_codes::W_JOIN_NOT_USING_DECLARED_REFERENCE,
2105 ));
2106 }
2107 }
2108 }
2109 }
2110
2111 if let Some(right_key) = right_table_key {
2112 cumulative_left_tables.insert(right_key);
2113 }
2114 }
2115 }
2116
2117 errors
2118}
2119
2120fn are_setop_compatible(left: TypeFamily, right: TypeFamily) -> bool {
2121 if left == TypeFamily::Unknown || right == TypeFamily::Unknown {
2122 return true;
2123 }
2124 if left == right {
2125 return true;
2126 }
2127 if left.is_numeric() && right.is_numeric() {
2128 return true;
2129 }
2130 if left.is_temporal() && right.is_temporal() {
2131 return true;
2132 }
2133 false
2134}
2135
2136fn merged_setop_family(left: TypeFamily, right: TypeFamily) -> TypeFamily {
2137 if left == TypeFamily::Unknown {
2138 return right;
2139 }
2140 if right == TypeFamily::Unknown {
2141 return left;
2142 }
2143 if left == right {
2144 return left;
2145 }
2146 if left.is_numeric() && right.is_numeric() {
2147 if left == TypeFamily::Numeric || right == TypeFamily::Numeric {
2148 return TypeFamily::Numeric;
2149 }
2150 return TypeFamily::Integer;
2151 }
2152 if left.is_temporal() && right.is_temporal() {
2153 if left == TypeFamily::Timestamp || right == TypeFamily::Timestamp {
2154 return TypeFamily::Timestamp;
2155 }
2156 if left == TypeFamily::Date || right == TypeFamily::Date {
2157 return TypeFamily::Date;
2158 }
2159 return TypeFamily::Time;
2160 }
2161 TypeFamily::Unknown
2162}
2163
2164fn are_assignment_compatible(target: TypeFamily, source: TypeFamily) -> bool {
2165 if target == TypeFamily::Unknown || source == TypeFamily::Unknown {
2166 return true;
2167 }
2168 if target == source {
2169 return true;
2170 }
2171
2172 match target {
2173 TypeFamily::Boolean => source == TypeFamily::Boolean,
2174 TypeFamily::Integer | TypeFamily::Numeric => source.is_numeric(),
2175 TypeFamily::Date | TypeFamily::Time | TypeFamily::Timestamp | TypeFamily::Interval => {
2176 source.is_temporal()
2177 }
2178 TypeFamily::String => true,
2179 TypeFamily::Binary => matches!(source, TypeFamily::Binary | TypeFamily::String),
2180 TypeFamily::Json => matches!(source, TypeFamily::Json | TypeFamily::String),
2181 TypeFamily::Uuid => matches!(source, TypeFamily::Uuid | TypeFamily::String),
2182 TypeFamily::Array => source == TypeFamily::Array,
2183 TypeFamily::Map => source == TypeFamily::Map,
2184 TypeFamily::Struct => source == TypeFamily::Struct,
2185 TypeFamily::Unknown => true,
2186 }
2187}
2188
2189fn projection_families(
2190 query_expr: &Expression,
2191 schema_map: &HashMap<String, TableSchemaEntry>,
2192) -> Option<Vec<TypeFamily>> {
2193 match query_expr {
2194 Expression::Select(select) => {
2195 if select
2196 .expressions
2197 .iter()
2198 .any(|e| matches!(e, Expression::Star(_) | Expression::BracedWildcard(_)))
2199 {
2200 return None;
2201 }
2202 let select_expr = Expression::Select(select.clone());
2203 let context = collect_type_check_context(&select_expr, schema_map);
2204 Some(
2205 select
2206 .expressions
2207 .iter()
2208 .map(|e| infer_expression_type_family(e, schema_map, &context))
2209 .collect(),
2210 )
2211 }
2212 Expression::Subquery(subquery) => projection_families(&subquery.this, schema_map),
2213 Expression::Union(union) => {
2214 let left = projection_families(&union.left, schema_map)?;
2215 let right = projection_families(&union.right, schema_map)?;
2216 if left.len() != right.len() {
2217 return None;
2218 }
2219 Some(
2220 left.into_iter()
2221 .zip(right)
2222 .map(|(l, r)| merged_setop_family(l, r))
2223 .collect(),
2224 )
2225 }
2226 Expression::Intersect(intersect) => {
2227 let left = projection_families(&intersect.left, schema_map)?;
2228 let right = projection_families(&intersect.right, schema_map)?;
2229 if left.len() != right.len() {
2230 return None;
2231 }
2232 Some(
2233 left.into_iter()
2234 .zip(right)
2235 .map(|(l, r)| merged_setop_family(l, r))
2236 .collect(),
2237 )
2238 }
2239 Expression::Except(except) => {
2240 let left = projection_families(&except.left, schema_map)?;
2241 let right = projection_families(&except.right, schema_map)?;
2242 if left.len() != right.len() {
2243 return None;
2244 }
2245 Some(
2246 left.into_iter()
2247 .zip(right)
2248 .map(|(l, r)| merged_setop_family(l, r))
2249 .collect(),
2250 )
2251 }
2252 Expression::Values(values) => {
2253 let first_row = values.expressions.first()?;
2254 let context = TypeCheckContext::default();
2255 Some(
2256 first_row
2257 .expressions
2258 .iter()
2259 .map(|e| infer_expression_type_family(e, schema_map, &context))
2260 .collect(),
2261 )
2262 }
2263 _ => None,
2264 }
2265}
2266
2267fn check_set_operation_compatibility(
2268 op_name: &str,
2269 left_expr: &Expression,
2270 right_expr: &Expression,
2271 schema_map: &HashMap<String, TableSchemaEntry>,
2272 strict: bool,
2273 errors: &mut Vec<ValidationError>,
2274) {
2275 let Some(left_projection) = projection_families(left_expr, schema_map) else {
2276 return;
2277 };
2278 let Some(right_projection) = projection_families(right_expr, schema_map) else {
2279 return;
2280 };
2281
2282 if left_projection.len() != right_projection.len() {
2283 errors.push(type_issue(
2284 strict,
2285 validation_codes::E_SETOP_ARITY_MISMATCH,
2286 validation_codes::W_SETOP_IMPLICIT_COERCION,
2287 format!(
2288 "{} operands return different column counts: left {}, right {}",
2289 op_name,
2290 left_projection.len(),
2291 right_projection.len()
2292 ),
2293 ));
2294 return;
2295 }
2296
2297 for (idx, (left, right)) in left_projection
2298 .into_iter()
2299 .zip(right_projection)
2300 .enumerate()
2301 {
2302 if !are_setop_compatible(left, right) {
2303 errors.push(type_issue(
2304 strict,
2305 validation_codes::E_SETOP_TYPE_MISMATCH,
2306 validation_codes::W_SETOP_IMPLICIT_COERCION,
2307 format!(
2308 "{} column {} has incompatible types: {} vs {}",
2309 op_name,
2310 idx + 1,
2311 type_family_name(left),
2312 type_family_name(right)
2313 ),
2314 ));
2315 }
2316 }
2317}
2318
2319fn check_insert_assignments(
2320 stmt: &Expression,
2321 insert: &Insert,
2322 schema_map: &HashMap<String, TableSchemaEntry>,
2323 strict: bool,
2324 errors: &mut Vec<ValidationError>,
2325) {
2326 let Some((target_table_key, table_schema)) =
2327 resolve_table_schema_entry(&insert.table, schema_map)
2328 else {
2329 return;
2330 };
2331
2332 let mut target_columns = Vec::new();
2333 if insert.columns.is_empty() {
2334 target_columns.extend(table_schema.column_order.iter().cloned());
2335 } else {
2336 for column in &insert.columns {
2337 let col_name = lower(&column.name);
2338 if table_schema.columns.contains_key(&col_name) {
2339 target_columns.push(col_name);
2340 } else {
2341 errors.push(if strict {
2342 ValidationError::error(
2343 format!(
2344 "Unknown column '{}' in table '{}'",
2345 column.name, target_table_key
2346 ),
2347 validation_codes::E_UNKNOWN_COLUMN,
2348 )
2349 } else {
2350 ValidationError::warning(
2351 format!(
2352 "Unknown column '{}' in table '{}'",
2353 column.name, target_table_key
2354 ),
2355 validation_codes::E_UNKNOWN_COLUMN,
2356 )
2357 });
2358 }
2359 }
2360 }
2361
2362 if target_columns.is_empty() {
2363 return;
2364 }
2365
2366 let context = collect_type_check_context(stmt, schema_map);
2367
2368 if !insert.default_values {
2369 for (row_idx, row) in insert.values.iter().enumerate() {
2370 if row.len() != target_columns.len() {
2371 errors.push(type_issue(
2372 strict,
2373 validation_codes::E_INVALID_ASSIGNMENT_TYPE,
2374 validation_codes::W_IMPLICIT_CAST_ASSIGNMENT,
2375 format!(
2376 "INSERT row {} has {} values but target has {} columns",
2377 row_idx + 1,
2378 row.len(),
2379 target_columns.len()
2380 ),
2381 ));
2382 continue;
2383 }
2384
2385 for (value, target_column) in row.iter().zip(target_columns.iter()) {
2386 let Some(target_family) = table_schema.columns.get(target_column).copied() else {
2387 continue;
2388 };
2389 let source_family = infer_expression_type_family(value, schema_map, &context);
2390 if !are_assignment_compatible(target_family, source_family) {
2391 errors.push(type_issue(
2392 strict,
2393 validation_codes::E_INVALID_ASSIGNMENT_TYPE,
2394 validation_codes::W_IMPLICIT_CAST_ASSIGNMENT,
2395 format!(
2396 "INSERT assignment type mismatch for '{}.{}': expected {}, found {}",
2397 target_table_key,
2398 target_column,
2399 type_family_name(target_family),
2400 type_family_name(source_family)
2401 ),
2402 ));
2403 }
2404 }
2405 }
2406 }
2407
2408 if let Some(query) = &insert.query {
2409 if insert.by_name {
2411 return;
2412 }
2413
2414 let Some(source_projection) = projection_families(query, schema_map) else {
2415 return;
2416 };
2417
2418 if source_projection.len() != target_columns.len() {
2419 errors.push(type_issue(
2420 strict,
2421 validation_codes::E_INVALID_ASSIGNMENT_TYPE,
2422 validation_codes::W_IMPLICIT_CAST_ASSIGNMENT,
2423 format!(
2424 "INSERT source query has {} columns but target has {} columns",
2425 source_projection.len(),
2426 target_columns.len()
2427 ),
2428 ));
2429 return;
2430 }
2431
2432 for (source_family, target_column) in
2433 source_projection.into_iter().zip(target_columns.iter())
2434 {
2435 let Some(target_family) = table_schema.columns.get(target_column).copied() else {
2436 continue;
2437 };
2438 if !are_assignment_compatible(target_family, source_family) {
2439 errors.push(type_issue(
2440 strict,
2441 validation_codes::E_INVALID_ASSIGNMENT_TYPE,
2442 validation_codes::W_IMPLICIT_CAST_ASSIGNMENT,
2443 format!(
2444 "INSERT assignment type mismatch for '{}.{}': expected {}, found {}",
2445 target_table_key,
2446 target_column,
2447 type_family_name(target_family),
2448 type_family_name(source_family)
2449 ),
2450 ));
2451 }
2452 }
2453 }
2454}
2455
2456fn check_update_assignments(
2457 stmt: &Expression,
2458 update: &Update,
2459 schema_map: &HashMap<String, TableSchemaEntry>,
2460 strict: bool,
2461 errors: &mut Vec<ValidationError>,
2462) {
2463 let Some((target_table_key, table_schema)) =
2464 resolve_table_schema_entry(&update.table, schema_map)
2465 else {
2466 return;
2467 };
2468
2469 let context = collect_type_check_context(stmt, schema_map);
2470
2471 for (column, value) in &update.set {
2472 let col_name = lower(&column.name);
2473 let Some(target_family) = table_schema.columns.get(&col_name).copied() else {
2474 errors.push(if strict {
2475 ValidationError::error(
2476 format!(
2477 "Unknown column '{}' in table '{}'",
2478 column.name, target_table_key
2479 ),
2480 validation_codes::E_UNKNOWN_COLUMN,
2481 )
2482 } else {
2483 ValidationError::warning(
2484 format!(
2485 "Unknown column '{}' in table '{}'",
2486 column.name, target_table_key
2487 ),
2488 validation_codes::E_UNKNOWN_COLUMN,
2489 )
2490 });
2491 continue;
2492 };
2493
2494 let source_family = infer_expression_type_family(value, schema_map, &context);
2495 if !are_assignment_compatible(target_family, source_family) {
2496 errors.push(type_issue(
2497 strict,
2498 validation_codes::E_INVALID_ASSIGNMENT_TYPE,
2499 validation_codes::W_IMPLICIT_CAST_ASSIGNMENT,
2500 format!(
2501 "UPDATE assignment type mismatch for '{}.{}': expected {}, found {}",
2502 target_table_key,
2503 col_name,
2504 type_family_name(target_family),
2505 type_family_name(source_family)
2506 ),
2507 ));
2508 }
2509 }
2510}
2511
2512fn check_types(
2513 stmt: &Expression,
2514 dialect: DialectType,
2515 schema_map: &HashMap<String, TableSchemaEntry>,
2516 function_catalog: Option<&dyn FunctionCatalog>,
2517 strict: bool,
2518) -> Vec<ValidationError> {
2519 let mut errors = Vec::new();
2520 let context = collect_type_check_context(stmt, schema_map);
2521
2522 for node in stmt.dfs() {
2523 match node {
2524 Expression::Insert(insert) => {
2525 check_insert_assignments(stmt, insert, schema_map, strict, &mut errors);
2526 }
2527 Expression::Update(update) => {
2528 check_update_assignments(stmt, update, schema_map, strict, &mut errors);
2529 }
2530 Expression::Union(union) => {
2531 check_set_operation_compatibility(
2532 "UNION",
2533 &union.left,
2534 &union.right,
2535 schema_map,
2536 strict,
2537 &mut errors,
2538 );
2539 }
2540 Expression::Intersect(intersect) => {
2541 check_set_operation_compatibility(
2542 "INTERSECT",
2543 &intersect.left,
2544 &intersect.right,
2545 schema_map,
2546 strict,
2547 &mut errors,
2548 );
2549 }
2550 Expression::Except(except) => {
2551 check_set_operation_compatibility(
2552 "EXCEPT",
2553 &except.left,
2554 &except.right,
2555 schema_map,
2556 strict,
2557 &mut errors,
2558 );
2559 }
2560 Expression::Select(select) => {
2561 if let Some(prewhere) = &select.prewhere {
2562 let family = infer_expression_type_family(prewhere, schema_map, &context);
2563 if family != TypeFamily::Unknown && family != TypeFamily::Boolean {
2564 errors.push(type_issue(
2565 strict,
2566 validation_codes::E_INVALID_PREDICATE_TYPE,
2567 validation_codes::W_PREDICATE_NULLABILITY,
2568 format!(
2569 "PREWHERE clause expects a boolean predicate, found {}",
2570 type_family_name(family)
2571 ),
2572 ));
2573 }
2574 }
2575
2576 if let Some(where_clause) = &select.where_clause {
2577 let family =
2578 infer_expression_type_family(&where_clause.this, schema_map, &context);
2579 if family != TypeFamily::Unknown && family != TypeFamily::Boolean {
2580 errors.push(type_issue(
2581 strict,
2582 validation_codes::E_INVALID_PREDICATE_TYPE,
2583 validation_codes::W_PREDICATE_NULLABILITY,
2584 format!(
2585 "WHERE clause expects a boolean predicate, found {}",
2586 type_family_name(family)
2587 ),
2588 ));
2589 }
2590 }
2591
2592 if let Some(having_clause) = &select.having {
2593 let family =
2594 infer_expression_type_family(&having_clause.this, schema_map, &context);
2595 if family != TypeFamily::Unknown && family != TypeFamily::Boolean {
2596 errors.push(type_issue(
2597 strict,
2598 validation_codes::E_INVALID_PREDICATE_TYPE,
2599 validation_codes::W_PREDICATE_NULLABILITY,
2600 format!(
2601 "HAVING clause expects a boolean predicate, found {}",
2602 type_family_name(family)
2603 ),
2604 ));
2605 }
2606 }
2607
2608 for join in &select.joins {
2609 if let Some(on) = &join.on {
2610 let family = infer_expression_type_family(on, schema_map, &context);
2611 if family != TypeFamily::Unknown && family != TypeFamily::Boolean {
2612 errors.push(type_issue(
2613 strict,
2614 validation_codes::E_INVALID_PREDICATE_TYPE,
2615 validation_codes::W_PREDICATE_NULLABILITY,
2616 format!(
2617 "JOIN ON expects a boolean predicate, found {}",
2618 type_family_name(family)
2619 ),
2620 ));
2621 }
2622 }
2623 if let Some(match_condition) = &join.match_condition {
2624 let family =
2625 infer_expression_type_family(match_condition, schema_map, &context);
2626 if family != TypeFamily::Unknown && family != TypeFamily::Boolean {
2627 errors.push(type_issue(
2628 strict,
2629 validation_codes::E_INVALID_PREDICATE_TYPE,
2630 validation_codes::W_PREDICATE_NULLABILITY,
2631 format!(
2632 "JOIN MATCH_CONDITION expects a boolean predicate, found {}",
2633 type_family_name(family)
2634 ),
2635 ));
2636 }
2637 }
2638 }
2639 }
2640 Expression::Where(where_clause) => {
2641 let family = infer_expression_type_family(&where_clause.this, schema_map, &context);
2642 if family != TypeFamily::Unknown && family != TypeFamily::Boolean {
2643 errors.push(type_issue(
2644 strict,
2645 validation_codes::E_INVALID_PREDICATE_TYPE,
2646 validation_codes::W_PREDICATE_NULLABILITY,
2647 format!(
2648 "WHERE clause expects a boolean predicate, found {}",
2649 type_family_name(family)
2650 ),
2651 ));
2652 }
2653 }
2654 Expression::Having(having_clause) => {
2655 let family =
2656 infer_expression_type_family(&having_clause.this, schema_map, &context);
2657 if family != TypeFamily::Unknown && family != TypeFamily::Boolean {
2658 errors.push(type_issue(
2659 strict,
2660 validation_codes::E_INVALID_PREDICATE_TYPE,
2661 validation_codes::W_PREDICATE_NULLABILITY,
2662 format!(
2663 "HAVING clause expects a boolean predicate, found {}",
2664 type_family_name(family)
2665 ),
2666 ));
2667 }
2668 }
2669 Expression::And(op) | Expression::Or(op) => {
2670 for (side, expr) in [("left", &op.left), ("right", &op.right)] {
2671 let family = infer_expression_type_family(expr, schema_map, &context);
2672 if family != TypeFamily::Unknown && family != TypeFamily::Boolean {
2673 errors.push(type_issue(
2674 strict,
2675 validation_codes::E_INVALID_PREDICATE_TYPE,
2676 validation_codes::W_PREDICATE_NULLABILITY,
2677 format!(
2678 "Logical {} operand expects boolean, found {}",
2679 side,
2680 type_family_name(family)
2681 ),
2682 ));
2683 }
2684 }
2685 }
2686 Expression::Not(unary) => {
2687 let family = infer_expression_type_family(&unary.this, schema_map, &context);
2688 if family != TypeFamily::Unknown && family != TypeFamily::Boolean {
2689 errors.push(type_issue(
2690 strict,
2691 validation_codes::E_INVALID_PREDICATE_TYPE,
2692 validation_codes::W_PREDICATE_NULLABILITY,
2693 format!("NOT expects boolean, found {}", type_family_name(family)),
2694 ));
2695 }
2696 }
2697 Expression::Eq(op)
2698 | Expression::Neq(op)
2699 | Expression::Lt(op)
2700 | Expression::Lte(op)
2701 | Expression::Gt(op)
2702 | Expression::Gte(op) => {
2703 let left = infer_expression_type_family(&op.left, schema_map, &context);
2704 let right = infer_expression_type_family(&op.right, schema_map, &context);
2705 if !are_comparable(left, right) {
2706 errors.push(type_issue(
2707 strict,
2708 validation_codes::E_INCOMPATIBLE_COMPARISON_TYPES,
2709 validation_codes::W_IMPLICIT_CAST_COMPARISON,
2710 format!(
2711 "Incompatible comparison between {} and {}",
2712 type_family_name(left),
2713 type_family_name(right)
2714 ),
2715 ));
2716 }
2717 }
2718 Expression::Like(op) | Expression::ILike(op) => {
2719 let left = infer_expression_type_family(&op.left, schema_map, &context);
2720 let right = infer_expression_type_family(&op.right, schema_map, &context);
2721 if left != TypeFamily::Unknown
2722 && right != TypeFamily::Unknown
2723 && (!is_string_like(left) || !is_string_like(right))
2724 {
2725 errors.push(type_issue(
2726 strict,
2727 validation_codes::E_INCOMPATIBLE_COMPARISON_TYPES,
2728 validation_codes::W_IMPLICIT_CAST_COMPARISON,
2729 format!(
2730 "LIKE/ILIKE expects string operands, found {} and {}",
2731 type_family_name(left),
2732 type_family_name(right)
2733 ),
2734 ));
2735 }
2736 }
2737 Expression::Between(between) => {
2738 let this_family = infer_expression_type_family(&between.this, schema_map, &context);
2739 let low_family = infer_expression_type_family(&between.low, schema_map, &context);
2740 let high_family = infer_expression_type_family(&between.high, schema_map, &context);
2741
2742 if !are_comparable(this_family, low_family)
2743 || !are_comparable(this_family, high_family)
2744 {
2745 errors.push(type_issue(
2746 strict,
2747 validation_codes::E_INCOMPATIBLE_COMPARISON_TYPES,
2748 validation_codes::W_IMPLICIT_CAST_COMPARISON,
2749 format!(
2750 "BETWEEN bounds are incompatible with {} (found {} and {})",
2751 type_family_name(this_family),
2752 type_family_name(low_family),
2753 type_family_name(high_family)
2754 ),
2755 ));
2756 }
2757 }
2758 Expression::In(in_expr) => {
2759 let this_family = infer_expression_type_family(&in_expr.this, schema_map, &context);
2760 for value in &in_expr.expressions {
2761 let item_family = infer_expression_type_family(value, schema_map, &context);
2762 if !are_comparable(this_family, item_family) {
2763 errors.push(type_issue(
2764 strict,
2765 validation_codes::E_INCOMPATIBLE_COMPARISON_TYPES,
2766 validation_codes::W_IMPLICIT_CAST_COMPARISON,
2767 format!(
2768 "IN item type {} is incompatible with {}",
2769 type_family_name(item_family),
2770 type_family_name(this_family)
2771 ),
2772 ));
2773 break;
2774 }
2775 }
2776 }
2777 Expression::Add(op)
2778 | Expression::Sub(op)
2779 | Expression::Mul(op)
2780 | Expression::Div(op)
2781 | Expression::Mod(op) => {
2782 let left = infer_expression_type_family(&op.left, schema_map, &context);
2783 let right = infer_expression_type_family(&op.right, schema_map, &context);
2784
2785 if left == TypeFamily::Unknown || right == TypeFamily::Unknown {
2786 continue;
2787 }
2788
2789 let temporal_ok = matches!(node, Expression::Add(_) | Expression::Sub(_))
2790 && ((left.is_temporal() && right.is_numeric())
2791 || (right.is_temporal() && left.is_numeric())
2792 || (matches!(node, Expression::Sub(_))
2793 && left.is_temporal()
2794 && right.is_temporal()));
2795
2796 if !(left.is_numeric() && right.is_numeric()) && !temporal_ok {
2797 errors.push(type_issue(
2798 strict,
2799 validation_codes::E_INVALID_ARITHMETIC_TYPE,
2800 validation_codes::W_IMPLICIT_CAST_ARITHMETIC,
2801 format!(
2802 "Arithmetic operation expects numeric-compatible operands, found {} and {}",
2803 type_family_name(left),
2804 type_family_name(right)
2805 ),
2806 ));
2807 }
2808 }
2809 Expression::Function(function) => {
2810 check_function_catalog(function, dialect, function_catalog, strict, &mut errors);
2811 check_generic_function(function, schema_map, &context, strict, &mut errors);
2812 }
2813 Expression::Upper(func)
2814 | Expression::Lower(func)
2815 | Expression::LTrim(func)
2816 | Expression::RTrim(func)
2817 | Expression::Reverse(func) => {
2818 let family = infer_expression_type_family(&func.this, schema_map, &context);
2819 check_function_argument(
2820 &mut errors,
2821 strict,
2822 "string_function",
2823 0,
2824 family,
2825 "a string argument",
2826 is_string_like(family),
2827 );
2828 }
2829 Expression::Length(func) => {
2830 let family = infer_expression_type_family(&func.this, schema_map, &context);
2831 check_function_argument(
2832 &mut errors,
2833 strict,
2834 "length",
2835 0,
2836 family,
2837 "a string or binary argument",
2838 is_string_or_binary(family),
2839 );
2840 }
2841 Expression::Trim(func) => {
2842 let this_family = infer_expression_type_family(&func.this, schema_map, &context);
2843 check_function_argument(
2844 &mut errors,
2845 strict,
2846 "trim",
2847 0,
2848 this_family,
2849 "a string argument",
2850 is_string_like(this_family),
2851 );
2852 if let Some(chars) = &func.characters {
2853 let chars_family = infer_expression_type_family(chars, schema_map, &context);
2854 check_function_argument(
2855 &mut errors,
2856 strict,
2857 "trim",
2858 1,
2859 chars_family,
2860 "a string argument",
2861 is_string_like(chars_family),
2862 );
2863 }
2864 }
2865 Expression::Substring(func) => {
2866 let this_family = infer_expression_type_family(&func.this, schema_map, &context);
2867 check_function_argument(
2868 &mut errors,
2869 strict,
2870 "substring",
2871 0,
2872 this_family,
2873 "a string argument",
2874 is_string_like(this_family),
2875 );
2876
2877 let start_family = infer_expression_type_family(&func.start, schema_map, &context);
2878 check_function_argument(
2879 &mut errors,
2880 strict,
2881 "substring",
2882 1,
2883 start_family,
2884 "a numeric argument",
2885 start_family.is_numeric(),
2886 );
2887 if let Some(length) = &func.length {
2888 let length_family = infer_expression_type_family(length, schema_map, &context);
2889 check_function_argument(
2890 &mut errors,
2891 strict,
2892 "substring",
2893 2,
2894 length_family,
2895 "a numeric argument",
2896 length_family.is_numeric(),
2897 );
2898 }
2899 }
2900 Expression::Replace(func) => {
2901 for (arg, idx) in [
2902 (&func.this, 0_usize),
2903 (&func.old, 1_usize),
2904 (&func.new, 2_usize),
2905 ] {
2906 let family = infer_expression_type_family(arg, schema_map, &context);
2907 check_function_argument(
2908 &mut errors,
2909 strict,
2910 "replace",
2911 idx,
2912 family,
2913 "a string argument",
2914 is_string_like(family),
2915 );
2916 }
2917 }
2918 Expression::Left(func) | Expression::Right(func) => {
2919 let this_family = infer_expression_type_family(&func.this, schema_map, &context);
2920 check_function_argument(
2921 &mut errors,
2922 strict,
2923 "left_right",
2924 0,
2925 this_family,
2926 "a string argument",
2927 is_string_like(this_family),
2928 );
2929 let length_family =
2930 infer_expression_type_family(&func.length, schema_map, &context);
2931 check_function_argument(
2932 &mut errors,
2933 strict,
2934 "left_right",
2935 1,
2936 length_family,
2937 "a numeric argument",
2938 length_family.is_numeric(),
2939 );
2940 }
2941 Expression::Repeat(func) => {
2942 let this_family = infer_expression_type_family(&func.this, schema_map, &context);
2943 check_function_argument(
2944 &mut errors,
2945 strict,
2946 "repeat",
2947 0,
2948 this_family,
2949 "a string argument",
2950 is_string_like(this_family),
2951 );
2952 let times_family = infer_expression_type_family(&func.times, schema_map, &context);
2953 check_function_argument(
2954 &mut errors,
2955 strict,
2956 "repeat",
2957 1,
2958 times_family,
2959 "a numeric argument",
2960 times_family.is_numeric(),
2961 );
2962 }
2963 Expression::Lpad(func) | Expression::Rpad(func) => {
2964 let this_family = infer_expression_type_family(&func.this, schema_map, &context);
2965 check_function_argument(
2966 &mut errors,
2967 strict,
2968 "pad",
2969 0,
2970 this_family,
2971 "a string argument",
2972 is_string_like(this_family),
2973 );
2974 let length_family =
2975 infer_expression_type_family(&func.length, schema_map, &context);
2976 check_function_argument(
2977 &mut errors,
2978 strict,
2979 "pad",
2980 1,
2981 length_family,
2982 "a numeric argument",
2983 length_family.is_numeric(),
2984 );
2985 if let Some(fill) = &func.fill {
2986 let fill_family = infer_expression_type_family(fill, schema_map, &context);
2987 check_function_argument(
2988 &mut errors,
2989 strict,
2990 "pad",
2991 2,
2992 fill_family,
2993 "a string argument",
2994 is_string_like(fill_family),
2995 );
2996 }
2997 }
2998 Expression::Abs(func)
2999 | Expression::Sqrt(func)
3000 | Expression::Cbrt(func)
3001 | Expression::Ln(func)
3002 | Expression::Exp(func) => {
3003 let family = infer_expression_type_family(&func.this, schema_map, &context);
3004 check_function_argument(
3005 &mut errors,
3006 strict,
3007 "numeric_function",
3008 0,
3009 family,
3010 "a numeric argument",
3011 family.is_numeric(),
3012 );
3013 }
3014 Expression::Round(func) => {
3015 let this_family = infer_expression_type_family(&func.this, schema_map, &context);
3016 check_function_argument(
3017 &mut errors,
3018 strict,
3019 "round",
3020 0,
3021 this_family,
3022 "a numeric argument",
3023 this_family.is_numeric(),
3024 );
3025 if let Some(decimals) = &func.decimals {
3026 let decimals_family =
3027 infer_expression_type_family(decimals, schema_map, &context);
3028 check_function_argument(
3029 &mut errors,
3030 strict,
3031 "round",
3032 1,
3033 decimals_family,
3034 "a numeric argument",
3035 decimals_family.is_numeric(),
3036 );
3037 }
3038 }
3039 Expression::Floor(func) => {
3040 let this_family = infer_expression_type_family(&func.this, schema_map, &context);
3041 check_function_argument(
3042 &mut errors,
3043 strict,
3044 "floor",
3045 0,
3046 this_family,
3047 "a numeric argument",
3048 this_family.is_numeric(),
3049 );
3050 if let Some(scale) = &func.scale {
3051 let scale_family = infer_expression_type_family(scale, schema_map, &context);
3052 check_function_argument(
3053 &mut errors,
3054 strict,
3055 "floor",
3056 1,
3057 scale_family,
3058 "a numeric argument",
3059 scale_family.is_numeric(),
3060 );
3061 }
3062 }
3063 Expression::Ceil(func) => {
3064 let this_family = infer_expression_type_family(&func.this, schema_map, &context);
3065 check_function_argument(
3066 &mut errors,
3067 strict,
3068 "ceil",
3069 0,
3070 this_family,
3071 "a numeric argument",
3072 this_family.is_numeric(),
3073 );
3074 if let Some(decimals) = &func.decimals {
3075 let decimals_family =
3076 infer_expression_type_family(decimals, schema_map, &context);
3077 check_function_argument(
3078 &mut errors,
3079 strict,
3080 "ceil",
3081 1,
3082 decimals_family,
3083 "a numeric argument",
3084 decimals_family.is_numeric(),
3085 );
3086 }
3087 }
3088 Expression::Power(func) => {
3089 let left_family = infer_expression_type_family(&func.this, schema_map, &context);
3090 check_function_argument(
3091 &mut errors,
3092 strict,
3093 "power",
3094 0,
3095 left_family,
3096 "a numeric argument",
3097 left_family.is_numeric(),
3098 );
3099 let right_family =
3100 infer_expression_type_family(&func.expression, schema_map, &context);
3101 check_function_argument(
3102 &mut errors,
3103 strict,
3104 "power",
3105 1,
3106 right_family,
3107 "a numeric argument",
3108 right_family.is_numeric(),
3109 );
3110 }
3111 Expression::Log(func) => {
3112 let this_family = infer_expression_type_family(&func.this, schema_map, &context);
3113 check_function_argument(
3114 &mut errors,
3115 strict,
3116 "log",
3117 0,
3118 this_family,
3119 "a numeric argument",
3120 this_family.is_numeric(),
3121 );
3122 if let Some(base) = &func.base {
3123 let base_family = infer_expression_type_family(base, schema_map, &context);
3124 check_function_argument(
3125 &mut errors,
3126 strict,
3127 "log",
3128 1,
3129 base_family,
3130 "a numeric argument",
3131 base_family.is_numeric(),
3132 );
3133 }
3134 }
3135 _ => {}
3136 }
3137 }
3138
3139 errors
3140}
3141
3142fn check_semantics(stmt: &Expression) -> Vec<ValidationError> {
3143 let mut errors = Vec::new();
3144
3145 let Expression::Select(select) = stmt else {
3146 return errors;
3147 };
3148 let select_expr = Expression::Select(select.clone());
3149
3150 if !select_expr
3152 .find_all(|e| matches!(e, Expression::Star(_)))
3153 .is_empty()
3154 {
3155 errors.push(ValidationError::warning(
3156 "SELECT * is discouraged; specify columns explicitly for better performance and maintainability",
3157 validation_codes::W_SELECT_STAR,
3158 ));
3159 }
3160
3161 let aggregate_count = get_aggregate_functions(&select_expr).len();
3163 if aggregate_count > 0 && select.group_by.is_none() {
3164 let has_non_aggregate_column = select.expressions.iter().any(|expr| {
3165 matches!(expr, Expression::Column(_) | Expression::Identifier(_))
3166 && get_aggregate_functions(expr).is_empty()
3167 });
3168
3169 if has_non_aggregate_column {
3170 errors.push(ValidationError::warning(
3171 "Mixing aggregate functions with non-aggregated columns without GROUP BY may cause errors in strict SQL mode",
3172 validation_codes::W_AGGREGATE_WITHOUT_GROUP_BY,
3173 ));
3174 }
3175 }
3176
3177 if select.distinct && select.order_by.is_some() {
3179 errors.push(ValidationError::warning(
3180 "DISTINCT with ORDER BY: ensure ORDER BY columns are in SELECT list",
3181 validation_codes::W_DISTINCT_ORDER_BY,
3182 ));
3183 }
3184
3185 if select.limit.is_some() && select.order_by.is_none() {
3187 errors.push(ValidationError::warning(
3188 "LIMIT without ORDER BY produces non-deterministic results",
3189 validation_codes::W_LIMIT_WITHOUT_ORDER_BY,
3190 ));
3191 }
3192
3193 errors
3194}
3195
3196fn resolve_scope_source_name(scope: &crate::scope::Scope, name: &str) -> Option<String> {
3197 scope
3198 .sources
3199 .get_key_value(name)
3200 .map(|(k, _)| k.clone())
3201 .or_else(|| {
3202 scope
3203 .sources
3204 .keys()
3205 .find(|source| source.eq_ignore_ascii_case(name))
3206 .cloned()
3207 })
3208}
3209
/// Returns `true` when `columns` can provide `column_name`.
///
/// A `"*"` entry matches every name; any other entry matches when it equals
/// `column_name` ignoring ASCII case. An empty list never matches.
fn source_has_column(columns: &[String], column_name: &str) -> bool {
    for candidate in columns {
        if candidate == "*" || candidate.eq_ignore_ascii_case(column_name) {
            return true;
        }
    }
    false
}
3215
3216fn source_display_name(scope: &crate::scope::Scope, source_name: &str) -> String {
3217 scope
3218 .sources
3219 .get(source_name)
3220 .map(|source| match &source.expression {
3221 Expression::Table(table) => lower(&table_ref_display_name(table)),
3222 _ => lower(source_name),
3223 })
3224 .unwrap_or_else(|| lower(source_name))
3225}
3226
3227fn validate_select_columns_with_schema(
3228 select: &crate::expressions::Select,
3229 schema_map: &HashMap<String, TableSchemaEntry>,
3230 resolver_schema: &MappingSchema,
3231 strict: bool,
3232) -> Vec<ValidationError> {
3233 let mut errors = Vec::new();
3234 let select_expr = Expression::Select(Box::new(select.clone()));
3235 let scope = build_scope(&select_expr);
3236 let mut resolver = Resolver::new(&scope, resolver_schema, true);
3237 let source_names: Vec<String> = scope.sources.keys().cloned().collect();
3238
3239 for node in walk_in_scope(&select_expr, false) {
3240 let Expression::Column(column) = node else {
3241 continue;
3242 };
3243
3244 let col_name = lower(&column.name.name);
3245 if col_name.is_empty() {
3246 continue;
3247 }
3248
3249 if let Some(table) = &column.table {
3250 let Some(source_name) = resolve_scope_source_name(&scope, &table.name) else {
3251 errors.push(if strict {
3253 ValidationError::error(
3254 format!(
3255 "Unknown table or alias '{}' referenced by column '{}'",
3256 table.name, col_name
3257 ),
3258 validation_codes::E_UNRESOLVED_REFERENCE,
3259 )
3260 } else {
3261 ValidationError::warning(
3262 format!(
3263 "Unknown table or alias '{}' referenced by column '{}'",
3264 table.name, col_name
3265 ),
3266 validation_codes::E_UNRESOLVED_REFERENCE,
3267 )
3268 });
3269 continue;
3270 };
3271
3272 if let Ok(columns) = resolver.get_source_columns(&source_name) {
3273 if !columns.is_empty() && !source_has_column(&columns, &col_name) {
3274 let table_name = source_display_name(&scope, &source_name);
3275 errors.push(if strict {
3276 ValidationError::error(
3277 format!("Unknown column '{}' in table '{}'", col_name, table_name),
3278 validation_codes::E_UNKNOWN_COLUMN,
3279 )
3280 } else {
3281 ValidationError::warning(
3282 format!("Unknown column '{}' in table '{}'", col_name, table_name),
3283 validation_codes::E_UNKNOWN_COLUMN,
3284 )
3285 });
3286 }
3287 }
3288 continue;
3289 }
3290
3291 let matching_sources: Vec<String> = source_names
3292 .iter()
3293 .filter_map(|source_name| {
3294 resolver
3295 .get_source_columns(source_name)
3296 .ok()
3297 .filter(|columns| !columns.is_empty() && source_has_column(columns, &col_name))
3298 .map(|_| source_name.clone())
3299 })
3300 .collect();
3301
3302 if !matching_sources.is_empty() {
3303 continue;
3304 }
3305
3306 let known_sources: Vec<String> = source_names
3307 .iter()
3308 .filter_map(|source_name| {
3309 resolver
3310 .get_source_columns(source_name)
3311 .ok()
3312 .filter(|columns| !columns.is_empty() && !columns.iter().any(|c| c == "*"))
3313 .map(|_| source_name.clone())
3314 })
3315 .collect();
3316
3317 if known_sources.len() == 1 {
3318 let table_name = source_display_name(&scope, &known_sources[0]);
3319 errors.push(if strict {
3320 ValidationError::error(
3321 format!("Unknown column '{}' in table '{}'", col_name, table_name),
3322 validation_codes::E_UNKNOWN_COLUMN,
3323 )
3324 } else {
3325 ValidationError::warning(
3326 format!("Unknown column '{}' in table '{}'", col_name, table_name),
3327 validation_codes::E_UNKNOWN_COLUMN,
3328 )
3329 });
3330 } else if known_sources.len() > 1 {
3331 errors.push(if strict {
3332 ValidationError::error(
3333 format!(
3334 "Unknown column '{}' (not found in any referenced table)",
3335 col_name
3336 ),
3337 validation_codes::E_UNKNOWN_COLUMN,
3338 )
3339 } else {
3340 ValidationError::warning(
3341 format!(
3342 "Unknown column '{}' (not found in any referenced table)",
3343 col_name
3344 ),
3345 validation_codes::E_UNKNOWN_COLUMN,
3346 )
3347 });
3348 } else if !schema_map.is_empty() {
3349 let found = schema_map
3350 .values()
3351 .any(|table_schema| table_schema.columns.contains_key(&col_name));
3352 if !found {
3353 errors.push(if strict {
3354 ValidationError::error(
3355 format!("Unknown column '{}'", col_name),
3356 validation_codes::E_UNKNOWN_COLUMN,
3357 )
3358 } else {
3359 ValidationError::warning(
3360 format!("Unknown column '{}'", col_name),
3361 validation_codes::E_UNKNOWN_COLUMN,
3362 )
3363 });
3364 }
3365 }
3366 }
3367
3368 errors
3369}
3370
3371fn validate_statement_with_schema(
3372 stmt: &Expression,
3373 schema_map: &HashMap<String, TableSchemaEntry>,
3374 resolver_schema: &MappingSchema,
3375 strict: bool,
3376) -> Vec<ValidationError> {
3377 let mut errors = Vec::new();
3378 let cte_aliases = collect_cte_aliases(stmt);
3379 let mut seen_tables: HashSet<String> = HashSet::new();
3380
3381 for node in stmt.find_all(|e| matches!(e, Expression::Table(_))) {
3383 let Expression::Table(table) = node else {
3384 continue;
3385 };
3386
3387 if cte_aliases.contains(&lower(&table.name.name)) {
3388 continue;
3389 }
3390
3391 let resolved_key = table_ref_candidates(table)
3392 .into_iter()
3393 .find(|k| schema_map.contains_key(k));
3394 let table_key = resolved_key
3395 .clone()
3396 .unwrap_or_else(|| lower(&table_ref_display_name(table)));
3397
3398 if !seen_tables.insert(table_key) {
3399 continue;
3400 }
3401
3402 if resolved_key.is_none() {
3403 errors.push(if strict {
3404 ValidationError::error(
3405 format!("Unknown table '{}'", table_ref_display_name(table)),
3406 validation_codes::E_UNKNOWN_TABLE,
3407 )
3408 } else {
3409 ValidationError::warning(
3410 format!("Unknown table '{}'", table_ref_display_name(table)),
3411 validation_codes::E_UNKNOWN_TABLE,
3412 )
3413 });
3414 }
3415 }
3416
3417 for node in stmt.dfs() {
3418 let Expression::Select(select) = node else {
3419 continue;
3420 };
3421 errors.extend(validate_select_columns_with_schema(
3422 select,
3423 schema_map,
3424 resolver_schema,
3425 strict,
3426 ));
3427 }
3428
3429 errors
3430}
3431
3432pub fn validate_with_schema(
3434 sql: &str,
3435 dialect: DialectType,
3436 schema: &ValidationSchema,
3437 options: &SchemaValidationOptions,
3438) -> ValidationResult {
3439 let strict = options.strict.unwrap_or(schema.strict.unwrap_or(true));
3440
3441 let syntax_result = crate::validate_with_options(
3443 sql,
3444 dialect,
3445 &crate::ValidationOptions {
3446 strict_syntax: options.strict_syntax,
3447 },
3448 );
3449 if !syntax_result.valid {
3450 return syntax_result;
3451 }
3452
3453 let d = Dialect::get(dialect);
3454 let statements = match d.parse(sql) {
3455 Ok(exprs) => exprs,
3456 Err(e) => {
3457 return ValidationResult::with_errors(vec![ValidationError::error(
3458 e.to_string(),
3459 validation_codes::E_PARSE_OR_OPTIONS,
3460 )]);
3461 }
3462 };
3463
3464 let schema_map = build_schema_map(schema);
3465 let resolver_schema = build_resolver_schema(schema);
3466 let mut all_errors = syntax_result.errors;
3467 let embedded_function_catalog = if options.check_types && options.function_catalog.is_none() {
3468 default_embedded_function_catalog()
3469 } else {
3470 None
3471 };
3472 let effective_function_catalog = options
3473 .function_catalog
3474 .as_deref()
3475 .or_else(|| embedded_function_catalog.as_deref());
3476 let declared_relationships = if options.check_references {
3477 build_declared_relationships(schema, &schema_map)
3478 } else {
3479 Vec::new()
3480 };
3481
3482 if options.check_references {
3483 all_errors.extend(check_reference_integrity(schema, &schema_map, strict));
3484 }
3485
3486 for stmt in &statements {
3487 if options.semantic {
3488 all_errors.extend(check_semantics(stmt));
3489 }
3490 all_errors.extend(validate_statement_with_schema(
3491 stmt,
3492 &schema_map,
3493 &resolver_schema,
3494 strict,
3495 ));
3496 if options.check_types {
3497 all_errors.extend(check_types(
3498 stmt,
3499 dialect,
3500 &schema_map,
3501 effective_function_catalog,
3502 strict,
3503 ));
3504 }
3505 if options.check_references {
3506 all_errors.extend(check_query_reference_quality(
3507 stmt,
3508 &schema_map,
3509 &resolver_schema,
3510 strict,
3511 &declared_relationships,
3512 ));
3513 }
3514 }
3515
3516 ValidationResult::with_errors(all_errors)
3517}
3518
3519#[cfg(test)]
3520mod tests;