1use crate::ast_transforms::get_aggregate_functions;
8use crate::dialects::{Dialect, DialectType};
9use crate::error::{ValidationError, ValidationResult};
10use crate::expressions::{
11 Column, DataType, Expression, Function, Insert, JoinKind, TableRef, Update,
12};
13use crate::function_catalog::FunctionCatalog;
14#[cfg(any(
15 feature = "function-catalog-clickhouse",
16 feature = "function-catalog-duckdb",
17 feature = "function-catalog-all-dialects"
18))]
19use crate::function_catalog::{
20 FunctionNameCase as CoreFunctionNameCase, FunctionSignature as CoreFunctionSignature,
21 HashMapFunctionCatalog,
22};
23use crate::function_registry::canonical_typed_function_name_upper;
24use crate::optimizer::annotate_types;
25use crate::resolver::Resolver;
26use crate::schema::{MappingSchema, Schema as SqlSchema, SchemaError, SchemaResult, TABLE_PARTS};
27use crate::scope::{build_scope, walk_in_scope};
28use crate::traversal::ExpressionWalk;
29use serde::{Deserialize, Serialize};
30use std::collections::{HashMap, HashSet};
31use std::sync::Arc;
32
33#[cfg(any(
34 feature = "function-catalog-clickhouse",
35 feature = "function-catalog-duckdb",
36 feature = "function-catalog-all-dialects"
37))]
38use std::sync::LazyLock;
39
/// A single column declared on a [`SchemaTable`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SchemaColumn {
    /// Column name.
    pub name: String,
    /// Declared type as free-form text. Serialized under the key `type`; defaults to an empty
    /// string when the key is absent.
    #[serde(default, rename = "type")]
    pub data_type: String,
    /// Optional nullability flag; `None` when not provided.
    #[serde(default)]
    pub nullable: Option<bool>,
    /// Column-level primary-key marker. Serialized under the key `primaryKey`.
    #[serde(default, rename = "primaryKey")]
    pub primary_key: bool,
    /// Column-level uniqueness marker.
    #[serde(default)]
    pub unique: bool,
    /// Optional column-level foreign-key target.
    #[serde(default)]
    pub references: Option<SchemaColumnReference>,
}
61
/// Target of a column-level foreign key: one column of another table.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SchemaColumnReference {
    /// Name of the referenced table.
    pub table: String,
    /// Name of the referenced column.
    pub column: String,
    /// Optional schema of the referenced table; `None` when not provided.
    #[serde(default)]
    pub schema: Option<String>,
}
73
/// A table-level foreign key, possibly spanning several columns.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SchemaForeignKey {
    /// Optional constraint name; `None` when not provided.
    #[serde(default)]
    pub name: Option<String>,
    /// Source columns on the declaring table.
    pub columns: Vec<String>,
    /// Referenced table and its columns.
    pub references: SchemaTableReference,
}
85
/// Target of a table-level foreign key: a table and a list of its columns.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SchemaTableReference {
    /// Name of the referenced table.
    pub table: String,
    /// Referenced columns.
    pub columns: Vec<String>,
    /// Optional schema of the referenced table; `None` when not provided.
    #[serde(default)]
    pub schema: Option<String>,
}
97
/// Description of one table supplied for validation.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SchemaTable {
    /// Table name.
    pub name: String,
    /// Optional schema the table belongs to; `None` when not provided.
    #[serde(default)]
    pub schema: Option<String>,
    /// Columns of the table, in declaration order.
    pub columns: Vec<SchemaColumn>,
    /// Additional names under which the table can be looked up; empty when absent.
    #[serde(default)]
    pub aliases: Vec<String>,
    /// Table-level primary-key columns. Serialized under the key `primaryKey`.
    #[serde(default, rename = "primaryKey")]
    pub primary_key: Vec<String>,
    /// Table-level unique keys, one inner list per key. Serialized under the key `uniqueKeys`.
    #[serde(default, rename = "uniqueKeys")]
    pub unique_keys: Vec<Vec<String>>,
    /// Table-level foreign keys. Serialized under the key `foreignKeys`.
    #[serde(default, rename = "foreignKeys")]
    pub foreign_keys: Vec<SchemaForeignKey>,
}
121
/// The set of tables a statement is validated against.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ValidationSchema {
    /// All known tables.
    pub tables: Vec<SchemaTable>,
    /// Optional strictness flag carried with the schema; `None` when not provided.
    // NOTE(review): how this interacts with `SchemaValidationOptions::strict` is decided by
    // code outside this section — confirm before relying on precedence.
    #[serde(default)]
    pub strict: Option<bool>,
}
131
/// Switches controlling which schema-validation passes run.
///
/// Every field is optional when deserializing; `Default` leaves all flags off and the
/// catalog unset.
#[derive(Clone, Serialize, Deserialize, Default)]
pub struct SchemaValidationOptions {
    /// Flag for type checking (presumably enables the type-compatibility pass — confirm
    /// against the caller).
    #[serde(default)]
    pub check_types: bool,
    /// Flag for foreign-key/reference checking.
    #[serde(default)]
    pub check_references: bool,
    /// Optional strictness override; `None` when not provided.
    #[serde(default)]
    pub strict: Option<bool>,
    /// Flag for semantic checks.
    #[serde(default)]
    pub semantic: bool,
    /// Flag for strict syntax handling.
    #[serde(default)]
    pub strict_syntax: bool,
    /// Optional function catalog. Never serialized or deserialized (`serde(skip)`), so it must
    /// be set programmatically.
    #[serde(skip, default)]
    pub function_catalog: Option<Arc<dyn FunctionCatalog>>,
}
154
/// Converts the catalog crate's name-case setting into this crate's equivalent enum.
#[cfg(any(
    feature = "function-catalog-clickhouse",
    feature = "function-catalog-duckdb",
    feature = "function-catalog-all-dialects"
))]
fn to_core_name_case(
    case: polyglot_sql_function_catalogs::FunctionNameCase,
) -> CoreFunctionNameCase {
    use polyglot_sql_function_catalogs::FunctionNameCase as CatalogCase;

    match case {
        CatalogCase::Sensitive => CoreFunctionNameCase::Sensitive,
        CatalogCase::Insensitive => CoreFunctionNameCase::Insensitive,
    }
}
172
/// Converts catalog-crate signatures into this crate's signature type, copying the
/// minimum and maximum arity of each entry and preserving order.
#[cfg(any(
    feature = "function-catalog-clickhouse",
    feature = "function-catalog-duckdb",
    feature = "function-catalog-all-dialects"
))]
fn to_core_signatures(
    signatures: Vec<polyglot_sql_function_catalogs::FunctionSignature>,
) -> Vec<CoreFunctionSignature> {
    let mut converted = Vec::with_capacity(signatures.len());
    for source in signatures {
        converted.push(CoreFunctionSignature {
            min_arity: source.min_arity,
            max_arity: source.max_arity,
        });
    }
    converted
}
189
/// Adapter that receives registrations from the embedded catalog crate and forwards them
/// into a [`HashMapFunctionCatalog`].
#[cfg(any(
    feature = "function-catalog-clickhouse",
    feature = "function-catalog-duckdb",
    feature = "function-catalog-all-dialects"
))]
struct EmbeddedCatalogSink<'a> {
    /// Catalog being populated.
    catalog: &'a mut HashMapFunctionCatalog,
    /// Memoized result of parsing each dialect name; `None` records a name that failed to parse.
    dialect_cache: HashMap<&'static str, Option<DialectType>>,
}
199
#[cfg(any(
    feature = "function-catalog-clickhouse",
    feature = "function-catalog-duckdb",
    feature = "function-catalog-all-dialects"
))]
impl<'a> EmbeddedCatalogSink<'a> {
    /// Parses `dialect` into a [`DialectType`], memoizing the outcome per name.
    ///
    /// Returns `None` when the name does not parse; that failure is cached as well, so each
    /// distinct name is parsed at most once.
    fn resolve_dialect(&mut self, dialect: &'static str) -> Option<DialectType> {
        *self
            .dialect_cache
            .entry(dialect)
            .or_insert_with(|| dialect.parse::<DialectType>().ok())
    }
}
215
/// Forwards each sink callback to the wrapped catalog. Calls whose dialect name cannot be
/// resolved are dropped without any effect on the catalog.
#[cfg(any(
    feature = "function-catalog-clickhouse",
    feature = "function-catalog-duckdb",
    feature = "function-catalog-all-dialects"
))]
impl<'a> polyglot_sql_function_catalogs::CatalogSink for EmbeddedCatalogSink<'a> {
    /// Sets the dialect-wide name-case rule.
    fn set_dialect_name_case(
        &mut self,
        dialect: &'static str,
        name_case: polyglot_sql_function_catalogs::FunctionNameCase,
    ) {
        let Some(core_dialect) = self.resolve_dialect(dialect) else {
            return;
        };
        let core_case = to_core_name_case(name_case);
        self.catalog.set_dialect_name_case(core_dialect, core_case);
    }

    /// Sets the name-case rule for one function of a dialect.
    fn set_function_name_case(
        &mut self,
        dialect: &'static str,
        function_name: &str,
        name_case: polyglot_sql_function_catalogs::FunctionNameCase,
    ) {
        let Some(core_dialect) = self.resolve_dialect(dialect) else {
            return;
        };
        let core_case = to_core_name_case(name_case);
        self.catalog
            .set_function_name_case(core_dialect, function_name, core_case);
    }

    /// Registers the signatures of one function for a dialect.
    fn register(
        &mut self,
        dialect: &'static str,
        function_name: &str,
        signatures: Vec<polyglot_sql_function_catalogs::FunctionSignature>,
    ) {
        let Some(core_dialect) = self.resolve_dialect(dialect) else {
            return;
        };
        let core_signatures = to_core_signatures(signatures);
        self.catalog
            .register(core_dialect, function_name, core_signatures);
    }
}
260
/// Returns a shared handle to the embedded function catalog.
///
/// The catalog is built on first use by letting the catalog crate register every enabled
/// dialect through an [`EmbeddedCatalogSink`]; later calls clone the same `Arc`.
#[cfg(any(
    feature = "function-catalog-clickhouse",
    feature = "function-catalog-duckdb",
    feature = "function-catalog-all-dialects"
))]
fn embedded_function_catalog_arc() -> Arc<dyn FunctionCatalog> {
    static EMBEDDED: LazyLock<Arc<HashMapFunctionCatalog>> = LazyLock::new(|| {
        let mut catalog = HashMapFunctionCatalog::default();
        {
            // The sink borrows `catalog` mutably; scope it so the borrow ends before the move.
            let mut sink = EmbeddedCatalogSink {
                catalog: &mut catalog,
                dialect_cache: HashMap::new(),
            };
            polyglot_sql_function_catalogs::register_enabled_catalogs(&mut sink);
        }
        Arc::new(catalog)
    });

    let shared: &Arc<HashMapFunctionCatalog> = &EMBEDDED;
    let handle: Arc<HashMapFunctionCatalog> = Arc::clone(shared);
    handle
}
279
/// Default function catalog when at least one embedded catalog feature is enabled:
/// always the shared embedded catalog.
#[cfg(any(
    feature = "function-catalog-clickhouse",
    feature = "function-catalog-duckdb",
    feature = "function-catalog-all-dialects"
))]
fn default_embedded_function_catalog() -> Option<Arc<dyn FunctionCatalog>> {
    let catalog = embedded_function_catalog_arc();
    Some(catalog)
}
288
/// Default function catalog when no embedded catalog feature is enabled: there is none.
#[cfg(not(any(
    feature = "function-catalog-clickhouse",
    feature = "function-catalog-duckdb",
    feature = "function-catalog-all-dialects"
)))]
fn default_embedded_function_catalog() -> Option<Arc<dyn FunctionCatalog>> {
    None
}
297
/// String codes attached to validation diagnostics.
///
/// By naming convention, constants starting with `E_` hold `E…` codes and constants starting
/// with `W_` hold `W…` codes.
pub mod validation_codes {
    // General failure code (parsing or options, per the constant name).
    pub const E_PARSE_OR_OPTIONS: &str = "E000";
    // Unknown-name and function-arity errors (E200–E203).
    pub const E_UNKNOWN_TABLE: &str = "E200";
    pub const E_UNKNOWN_COLUMN: &str = "E201";
    pub const E_UNKNOWN_FUNCTION: &str = "E202";
    pub const E_INVALID_FUNCTION_ARITY: &str = "E203";

    // General query warnings (W001–W004).
    pub const W_SELECT_STAR: &str = "W001";
    pub const W_AGGREGATE_WITHOUT_GROUP_BY: &str = "W002";
    pub const W_DISTINCT_ORDER_BY: &str = "W003";
    pub const W_LIMIT_WITHOUT_ORDER_BY: &str = "W004";

    // Type-related errors (E210–E219).
    pub const E_TYPE_MISMATCH: &str = "E210";
    pub const E_INVALID_PREDICATE_TYPE: &str = "E211";
    pub const E_INVALID_ARITHMETIC_TYPE: &str = "E212";
    pub const E_INVALID_FUNCTION_ARGUMENT_TYPE: &str = "E213";
    pub const E_INVALID_ASSIGNMENT_TYPE: &str = "E214";
    pub const E_SETOP_TYPE_MISMATCH: &str = "E215";
    pub const E_SETOP_ARITY_MISMATCH: &str = "E216";
    pub const E_INCOMPATIBLE_COMPARISON_TYPES: &str = "E217";
    pub const E_INVALID_CAST: &str = "E218";
    pub const E_UNKNOWN_INFERRED_TYPE: &str = "E219";

    // Type-related warnings (W210–W219).
    pub const W_IMPLICIT_CAST_COMPARISON: &str = "W210";
    pub const W_IMPLICIT_CAST_ARITHMETIC: &str = "W211";
    pub const W_IMPLICIT_CAST_ASSIGNMENT: &str = "W212";
    pub const W_LOSSY_CAST: &str = "W213";
    pub const W_SETOP_IMPLICIT_COERCION: &str = "W214";
    pub const W_PREDICATE_NULLABILITY: &str = "W215";
    pub const W_FUNCTION_ARGUMENT_COERCION: &str = "W216";
    pub const W_AGGREGATE_TYPE_COERCION: &str = "W217";
    pub const W_POSSIBLE_OVERFLOW: &str = "W218";
    pub const W_POSSIBLE_TRUNCATION: &str = "W219";

    // Reference-related errors (E220–E224).
    pub const E_INVALID_FOREIGN_KEY_REFERENCE: &str = "E220";
    pub const E_AMBIGUOUS_COLUMN_REFERENCE: &str = "E221";
    pub const E_UNRESOLVED_REFERENCE: &str = "E222";
    pub const E_CTE_COLUMN_COUNT_MISMATCH: &str = "E223";
    pub const E_MISSING_REFERENCE_TARGET: &str = "E224";

    // Reference-related warnings (W220–W222).
    pub const W_CARTESIAN_JOIN: &str = "W220";
    pub const W_JOIN_NOT_USING_DECLARED_REFERENCE: &str = "W221";
    pub const W_WEAK_REFERENCE_INTEGRITY: &str = "W222";
}
346
/// Coarse classification of SQL data types used when comparing column and expression types.
///
/// Serialized in `snake_case` (for example `TypeFamily::Timestamp` becomes `"timestamp"`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum TypeFamily {
    /// The type could not be classified.
    Unknown,
    Boolean,
    Integer,
    /// Non-integer numeric types; see `canonical_type_family` for the spellings mapped here.
    Numeric,
    String,
    Binary,
    Date,
    Time,
    Timestamp,
    Interval,
    Json,
    Uuid,
    Array,
    Map,
    Struct,
}
367
368impl TypeFamily {
369 pub fn is_numeric(self) -> bool {
370 matches!(self, TypeFamily::Integer | TypeFamily::Numeric)
371 }
372
373 pub fn is_temporal(self) -> bool {
374 matches!(
375 self,
376 TypeFamily::Date | TypeFamily::Time | TypeFamily::Timestamp | TypeFamily::Interval
377 )
378 }
379}
380
/// Per-table lookup data derived from a [`SchemaTable`].
#[derive(Debug, Clone)]
struct TableSchemaEntry {
    /// Lowercased column name mapped to its type family.
    columns: HashMap<String, TypeFamily>,
    /// Lowercased column names in declaration order.
    column_order: Vec<String>,
}
386
/// Returns the lowercase form of `s`; used to normalize identifiers before lookups.
fn lower(s: &str) -> String {
    str::to_lowercase(s)
}
390
/// Splits a parameterized type such as `decimal(10, 2)` into its base name and the text
/// between the first `(` and the final `)`, both trimmed.
///
/// Returns `None` when the text contains no `(` or does not end with `)`.
fn split_type_args(data_type: &str) -> Option<(&str, &str)> {
    // Requiring the trailing ')' first guarantees the first '(' lies before it.
    let without_close = data_type.strip_suffix(')')?;
    let (base, inner) = without_close.split_once('(')?;
    Some((base.trim(), inner.trim()))
}
400
401pub fn canonical_type_family(data_type: &str) -> TypeFamily {
403 let trimmed = data_type
404 .trim()
405 .trim_matches(|c| c == '"' || c == '\'' || c == '`');
406 if trimmed.is_empty() {
407 return TypeFamily::Unknown;
408 }
409
410 let normalized = trimmed
412 .split_whitespace()
413 .collect::<Vec<_>>()
414 .join(" ")
415 .to_lowercase();
416
417 if let Some((base, inner)) = split_type_args(&normalized) {
419 match base {
420 "nullable" | "lowcardinality" => return canonical_type_family(inner),
421 "array" | "list" => return TypeFamily::Array,
422 "map" => return TypeFamily::Map,
423 "struct" | "row" | "record" => return TypeFamily::Struct,
424 _ => {}
425 }
426 }
427
428 if normalized.starts_with("array<") || normalized.starts_with("list<") {
429 return TypeFamily::Array;
430 }
431 if normalized.starts_with("map<") {
432 return TypeFamily::Map;
433 }
434 if normalized.starts_with("struct<")
435 || normalized.starts_with("row<")
436 || normalized.starts_with("record<")
437 || normalized.starts_with("object<")
438 {
439 return TypeFamily::Struct;
440 }
441
442 if normalized.ends_with("[]") {
443 return TypeFamily::Array;
444 }
445
446 let mut base = normalized
448 .split('(')
449 .next()
450 .unwrap_or("")
451 .trim()
452 .to_string();
453 if base.is_empty() {
454 return TypeFamily::Unknown;
455 }
456
457 base = base.strip_prefix("unsigned ").unwrap_or(&base).to_string();
458 base = base.strip_suffix(" unsigned").unwrap_or(&base).to_string();
459
460 match base.as_str() {
461 "bool" | "boolean" => TypeFamily::Boolean,
462 "tinyint" | "smallint" | "int2" | "int" | "integer" | "int4" | "int8" | "bigint"
463 | "serial" | "smallserial" | "bigserial" | "utinyint" | "usmallint" | "uinteger"
464 | "ubigint" | "uint8" | "uint16" | "uint32" | "uint64" | "int16" | "int32" | "int64" => {
465 TypeFamily::Integer
466 }
467 "numeric" | "decimal" | "dec" | "number" | "float" | "float4" | "float8" | "real"
468 | "double" | "double precision" | "bfloat16" | "float16" | "float32" | "float64" => {
469 TypeFamily::Numeric
470 }
471 "char" | "character" | "varchar" | "character varying" | "nchar" | "nvarchar" | "text"
472 | "string" | "clob" => TypeFamily::String,
473 "binary" | "varbinary" | "blob" | "bytea" | "bytes" => TypeFamily::Binary,
474 "date" => TypeFamily::Date,
475 "time" => TypeFamily::Time,
476 "timestamp"
477 | "timestamptz"
478 | "datetime"
479 | "datetime2"
480 | "smalldatetime"
481 | "timestamp with time zone"
482 | "timestamp without time zone" => TypeFamily::Timestamp,
483 "interval" => TypeFamily::Interval,
484 "json" | "jsonb" | "variant" => TypeFamily::Json,
485 "uuid" | "uniqueidentifier" => TypeFamily::Uuid,
486 "array" | "list" => TypeFamily::Array,
487 "map" => TypeFamily::Map,
488 "struct" | "row" | "record" | "object" => TypeFamily::Struct,
489 _ => TypeFamily::Unknown,
490 }
491}
492
493fn build_schema_map(schema: &ValidationSchema) -> HashMap<String, TableSchemaEntry> {
494 let mut map = HashMap::new();
495
496 for table in &schema.tables {
497 let column_order: Vec<String> = table.columns.iter().map(|c| lower(&c.name)).collect();
498 let columns: HashMap<String, TypeFamily> = table
499 .columns
500 .iter()
501 .map(|c| (lower(&c.name), canonical_type_family(&c.data_type)))
502 .collect();
503 let entry = TableSchemaEntry {
504 columns,
505 column_order,
506 };
507
508 let simple_name = lower(&table.name);
509 map.insert(simple_name, entry.clone());
510
511 if let Some(table_schema) = &table.schema {
512 map.insert(
513 format!("{}.{}", lower(table_schema), lower(&table.name)),
514 entry.clone(),
515 );
516 }
517
518 for alias in &table.aliases {
519 map.insert(lower(alias), entry.clone());
520 }
521 }
522
523 map
524}
525
526fn type_family_to_data_type(family: TypeFamily) -> DataType {
527 match family {
528 TypeFamily::Unknown => DataType::Unknown,
529 TypeFamily::Boolean => DataType::Boolean,
530 TypeFamily::Integer => DataType::Int {
531 length: None,
532 integer_spelling: false,
533 },
534 TypeFamily::Numeric => DataType::Double {
535 precision: None,
536 scale: None,
537 },
538 TypeFamily::String => DataType::VarChar {
539 length: None,
540 parenthesized_length: false,
541 },
542 TypeFamily::Binary => DataType::VarBinary { length: None },
543 TypeFamily::Date => DataType::Date,
544 TypeFamily::Time => DataType::Time {
545 precision: None,
546 timezone: false,
547 },
548 TypeFamily::Timestamp => DataType::Timestamp {
549 precision: None,
550 timezone: false,
551 },
552 TypeFamily::Interval => DataType::Interval {
553 unit: None,
554 to: None,
555 },
556 TypeFamily::Json => DataType::Json,
557 TypeFamily::Uuid => DataType::Uuid,
558 TypeFamily::Array => DataType::Array {
559 element_type: Box::new(DataType::Unknown),
560 dimension: None,
561 },
562 TypeFamily::Map => DataType::Map {
563 key_type: Box::new(DataType::Unknown),
564 value_type: Box::new(DataType::Unknown),
565 },
566 TypeFamily::Struct => DataType::Struct {
567 fields: Vec::new(),
568 nested: false,
569 },
570 }
571}
572
573fn build_resolver_schema(schema: &ValidationSchema) -> MappingSchema {
574 let mut mapping = MappingSchema::new();
575
576 for table in &schema.tables {
577 let columns: Vec<(String, DataType)> = table
578 .columns
579 .iter()
580 .map(|column| {
581 (
582 lower(&column.name),
583 type_family_to_data_type(canonical_type_family(&column.data_type)),
584 )
585 })
586 .collect();
587
588 let mut table_names = Vec::new();
589 table_names.push(lower(&table.name));
590 if let Some(table_schema) = &table.schema {
591 table_names.push(format!("{}.{}", lower(table_schema), lower(&table.name)));
592 }
593 for alias in &table.aliases {
594 table_names.push(lower(alias));
595 }
596
597 let mut dedup = HashSet::new();
598 for table_name in table_names {
599 if dedup.insert(table_name.clone()) {
600 let _ = mapping.add_table(&table_name, &columns, None);
601 }
602 }
603 }
604
605 mapping
606}
607
608fn collect_cte_aliases(expr: &Expression) -> HashSet<String> {
609 let mut aliases = HashSet::new();
610
611 for node in expr.dfs() {
612 match node {
613 Expression::Select(select) => {
614 if let Some(with) = &select.with {
615 for cte in &with.ctes {
616 aliases.insert(lower(&cte.alias.name));
617 }
618 }
619 }
620 Expression::Insert(insert) => {
621 if let Some(with) = &insert.with {
622 for cte in &with.ctes {
623 aliases.insert(lower(&cte.alias.name));
624 }
625 }
626 }
627 Expression::Update(update) => {
628 if let Some(with) = &update.with {
629 for cte in &with.ctes {
630 aliases.insert(lower(&cte.alias.name));
631 }
632 }
633 }
634 Expression::Delete(delete) => {
635 if let Some(with) = &delete.with {
636 for cte in &with.ctes {
637 aliases.insert(lower(&cte.alias.name));
638 }
639 }
640 }
641 Expression::Union(union) => {
642 if let Some(with) = &union.with {
643 for cte in &with.ctes {
644 aliases.insert(lower(&cte.alias.name));
645 }
646 }
647 }
648 Expression::Intersect(intersect) => {
649 if let Some(with) = &intersect.with {
650 for cte in &with.ctes {
651 aliases.insert(lower(&cte.alias.name));
652 }
653 }
654 }
655 Expression::Except(except) => {
656 if let Some(with) = &except.with {
657 for cte in &with.ctes {
658 aliases.insert(lower(&cte.alias.name));
659 }
660 }
661 }
662 Expression::Merge(merge) => {
663 if let Some(with_) = &merge.with_ {
664 if let Expression::With(with_clause) = with_.as_ref() {
665 for cte in &with_clause.ctes {
666 aliases.insert(lower(&cte.alias.name));
667 }
668 }
669 }
670 }
671 _ => {}
672 }
673 }
674
675 aliases
676}
677
678fn table_ref_candidates(table: &TableRef) -> Vec<String> {
679 let name = lower(&table.name.name);
680 let schema = table.schema.as_ref().map(|s| lower(&s.name));
681 let catalog = table.catalog.as_ref().map(|c| lower(&c.name));
682
683 let mut candidates = Vec::new();
684 if let (Some(catalog), Some(schema)) = (&catalog, &schema) {
685 candidates.push(format!("{}.{}.{}", catalog, schema, name));
686 }
687 if let Some(schema) = &schema {
688 candidates.push(format!("{}.{}", schema, name));
689 }
690 candidates.push(name);
691 candidates
692}
693
694fn table_ref_display_name(table: &TableRef) -> String {
695 let mut parts = Vec::new();
696 if let Some(catalog) = &table.catalog {
697 parts.push(catalog.name.clone());
698 }
699 if let Some(schema) = &table.schema {
700 parts.push(schema.name.clone());
701 }
702 parts.push(table.name.name.clone());
703 parts.join(".")
704}
705
/// Tables referenced by one statement, gathered before type checking.
#[derive(Debug, Default, Clone)]
struct TypeCheckContext {
    /// Schema-map keys of the tables the statement references.
    referenced_tables: HashSet<String>,
    /// Lowercased table name or alias mapped to the schema-map key it resolves to.
    table_aliases: HashMap<String, String>,
}
711
712fn type_family_name(family: TypeFamily) -> &'static str {
713 match family {
714 TypeFamily::Unknown => "unknown",
715 TypeFamily::Boolean => "boolean",
716 TypeFamily::Integer => "integer",
717 TypeFamily::Numeric => "numeric",
718 TypeFamily::String => "string",
719 TypeFamily::Binary => "binary",
720 TypeFamily::Date => "date",
721 TypeFamily::Time => "time",
722 TypeFamily::Timestamp => "timestamp",
723 TypeFamily::Interval => "interval",
724 TypeFamily::Json => "json",
725 TypeFamily::Uuid => "uuid",
726 TypeFamily::Array => "array",
727 TypeFamily::Map => "map",
728 TypeFamily::Struct => "struct",
729 }
730}
731
732fn is_string_like(family: TypeFamily) -> bool {
733 matches!(family, TypeFamily::String)
734}
735
736fn is_string_or_binary(family: TypeFamily) -> bool {
737 matches!(family, TypeFamily::String | TypeFamily::Binary)
738}
739
740fn type_issue(
741 strict: bool,
742 error_code: &str,
743 warning_code: &str,
744 message: impl Into<String>,
745) -> ValidationError {
746 if strict {
747 ValidationError::error(message.into(), error_code)
748 } else {
749 ValidationError::warning(message.into(), warning_code)
750 }
751}
752
753fn data_type_family(data_type: &DataType) -> TypeFamily {
754 match data_type {
755 DataType::Boolean => TypeFamily::Boolean,
756 DataType::TinyInt { .. }
757 | DataType::SmallInt { .. }
758 | DataType::Int { .. }
759 | DataType::BigInt { .. } => TypeFamily::Integer,
760 DataType::Float { .. } | DataType::Double { .. } | DataType::Decimal { .. } => {
761 TypeFamily::Numeric
762 }
763 DataType::Char { .. }
764 | DataType::VarChar { .. }
765 | DataType::String { .. }
766 | DataType::Text
767 | DataType::TextWithLength { .. }
768 | DataType::CharacterSet { .. } => TypeFamily::String,
769 DataType::Binary { .. } | DataType::VarBinary { .. } | DataType::Blob => TypeFamily::Binary,
770 DataType::Date => TypeFamily::Date,
771 DataType::Time { .. } => TypeFamily::Time,
772 DataType::Timestamp { .. } => TypeFamily::Timestamp,
773 DataType::Interval { .. } => TypeFamily::Interval,
774 DataType::Json | DataType::JsonB => TypeFamily::Json,
775 DataType::Uuid => TypeFamily::Uuid,
776 DataType::Array { .. } | DataType::List { .. } => TypeFamily::Array,
777 DataType::Map { .. } => TypeFamily::Map,
778 DataType::Struct { .. } | DataType::Object { .. } | DataType::Union { .. } => {
779 TypeFamily::Struct
780 }
781 DataType::Nullable { inner } => data_type_family(inner),
782 DataType::Custom { name } => canonical_type_family(name),
783 DataType::Unknown => TypeFamily::Unknown,
784 DataType::Bit { .. } | DataType::VarBit { .. } => TypeFamily::Binary,
785 DataType::Enum { .. } | DataType::Set { .. } => TypeFamily::String,
786 DataType::Vector { .. } => TypeFamily::Array,
787 DataType::Geometry { .. } | DataType::Geography { .. } => TypeFamily::Struct,
788 }
789}
790
791fn collect_type_check_context(
792 stmt: &Expression,
793 schema_map: &HashMap<String, TableSchemaEntry>,
794) -> TypeCheckContext {
795 fn add_table_to_context(
796 table: &TableRef,
797 schema_map: &HashMap<String, TableSchemaEntry>,
798 context: &mut TypeCheckContext,
799 ) {
800 let resolved_key = table_ref_candidates(table)
801 .into_iter()
802 .find(|k| schema_map.contains_key(k));
803
804 let Some(table_key) = resolved_key else {
805 return;
806 };
807
808 context.referenced_tables.insert(table_key.clone());
809 context
810 .table_aliases
811 .insert(lower(&table.name.name), table_key.clone());
812 if let Some(alias) = &table.alias {
813 context
814 .table_aliases
815 .insert(lower(&alias.name), table_key.clone());
816 }
817 }
818
819 let mut context = TypeCheckContext::default();
820 let cte_aliases = collect_cte_aliases(stmt);
821
822 for node in stmt.find_all(|e| matches!(e, Expression::Table(_))) {
823 let Expression::Table(table) = node else {
824 continue;
825 };
826
827 if cte_aliases.contains(&lower(&table.name.name)) {
828 continue;
829 }
830
831 add_table_to_context(table, schema_map, &mut context);
832 }
833
834 match stmt {
837 Expression::Insert(insert) => {
838 add_table_to_context(&insert.table, schema_map, &mut context);
839 }
840 Expression::Update(update) => {
841 add_table_to_context(&update.table, schema_map, &mut context);
842 for table in &update.extra_tables {
843 add_table_to_context(table, schema_map, &mut context);
844 }
845 }
846 Expression::Delete(delete) => {
847 add_table_to_context(&delete.table, schema_map, &mut context);
848 for table in &delete.using {
849 add_table_to_context(table, schema_map, &mut context);
850 }
851 for table in &delete.tables {
852 add_table_to_context(table, schema_map, &mut context);
853 }
854 }
855 _ => {}
856 }
857
858 context
859}
860
861fn resolve_table_schema_entry<'a>(
862 table: &TableRef,
863 schema_map: &'a HashMap<String, TableSchemaEntry>,
864) -> Option<(String, &'a TableSchemaEntry)> {
865 let key = table_ref_candidates(table)
866 .into_iter()
867 .find(|k| schema_map.contains_key(k))?;
868 let entry = schema_map.get(&key)?;
869 Some((key, entry))
870}
871
872fn reference_issue(strict: bool, message: impl Into<String>) -> ValidationError {
873 if strict {
874 ValidationError::error(
875 message.into(),
876 validation_codes::E_INVALID_FOREIGN_KEY_REFERENCE,
877 )
878 } else {
879 ValidationError::warning(message.into(), validation_codes::W_WEAK_REFERENCE_INTEGRITY)
880 }
881}
882
883fn reference_table_candidates(
884 table_name: &str,
885 explicit_schema: Option<&str>,
886 source_schema: Option<&str>,
887) -> Vec<String> {
888 let mut candidates = Vec::new();
889 let raw = lower(table_name);
890
891 if let Some(schema) = explicit_schema {
892 candidates.push(format!("{}.{}", lower(schema), raw));
893 }
894
895 if raw.contains('.') {
896 candidates.push(raw.clone());
897 if let Some(last) = raw.rsplit('.').next() {
898 candidates.push(last.to_string());
899 }
900 } else {
901 if let Some(schema) = source_schema {
902 candidates.push(format!("{}.{}", lower(schema), raw));
903 }
904 candidates.push(raw);
905 }
906
907 let mut dedup = HashSet::new();
908 candidates
909 .into_iter()
910 .filter(|c| dedup.insert(c.clone()))
911 .collect()
912}
913
914fn resolve_reference_table_key(
915 table_name: &str,
916 explicit_schema: Option<&str>,
917 source_schema: Option<&str>,
918 schema_map: &HashMap<String, TableSchemaEntry>,
919) -> Option<String> {
920 reference_table_candidates(table_name, explicit_schema, source_schema)
921 .into_iter()
922 .find(|candidate| schema_map.contains_key(candidate))
923}
924
925fn key_types_compatible(source: TypeFamily, target: TypeFamily) -> bool {
926 if source == TypeFamily::Unknown || target == TypeFamily::Unknown {
927 return true;
928 }
929 if source == target {
930 return true;
931 }
932 if source.is_numeric() && target.is_numeric() {
933 return true;
934 }
935 if source.is_temporal() && target.is_temporal() {
936 return true;
937 }
938 false
939}
940
941fn table_key_hints(table: &SchemaTable) -> HashSet<String> {
942 let mut hints = HashSet::new();
943 for column in &table.columns {
944 if column.primary_key || column.unique {
945 hints.insert(lower(&column.name));
946 }
947 }
948 for key_col in &table.primary_key {
949 hints.insert(lower(key_col));
950 }
951 for group in &table.unique_keys {
952 if group.len() == 1 {
953 if let Some(col) = group.first() {
954 hints.insert(lower(col));
955 }
956 }
957 }
958 hints
959}
960
961fn check_reference_integrity(
962 schema: &ValidationSchema,
963 schema_map: &HashMap<String, TableSchemaEntry>,
964 strict: bool,
965) -> Vec<ValidationError> {
966 let mut errors = Vec::new();
967
968 let mut key_hints_lookup: HashMap<String, HashSet<String>> = HashMap::new();
969 for table in &schema.tables {
970 let simple = lower(&table.name);
971 key_hints_lookup.insert(simple, table_key_hints(table));
972 if let Some(schema_name) = &table.schema {
973 let qualified = format!("{}.{}", lower(schema_name), lower(&table.name));
974 key_hints_lookup.insert(qualified, table_key_hints(table));
975 }
976 }
977
978 for table in &schema.tables {
979 let source_table_display = if let Some(schema_name) = &table.schema {
980 format!("{}.{}", schema_name, table.name)
981 } else {
982 table.name.clone()
983 };
984 let source_schema = table.schema.as_deref();
985 let source_columns: HashMap<String, TypeFamily> = table
986 .columns
987 .iter()
988 .map(|col| (lower(&col.name), canonical_type_family(&col.data_type)))
989 .collect();
990
991 for source_col in &table.columns {
992 let Some(reference) = &source_col.references else {
993 continue;
994 };
995 let source_type = canonical_type_family(&source_col.data_type);
996
997 let Some(target_key) = resolve_reference_table_key(
998 &reference.table,
999 reference.schema.as_deref(),
1000 source_schema,
1001 schema_map,
1002 ) else {
1003 errors.push(reference_issue(
1004 strict,
1005 format!(
1006 "Foreign key reference '{}.{}' points to unknown table '{}'",
1007 source_table_display, source_col.name, reference.table
1008 ),
1009 ));
1010 continue;
1011 };
1012
1013 let target_column = lower(&reference.column);
1014 let Some(target_entry) = schema_map.get(&target_key) else {
1015 errors.push(reference_issue(
1016 strict,
1017 format!(
1018 "Foreign key reference '{}.{}' points to unknown table '{}'",
1019 source_table_display, source_col.name, reference.table
1020 ),
1021 ));
1022 continue;
1023 };
1024
1025 let Some(target_type) = target_entry.columns.get(&target_column).copied() else {
1026 errors.push(reference_issue(
1027 strict,
1028 format!(
1029 "Foreign key reference '{}.{}' points to unknown column '{}.{}'",
1030 source_table_display, source_col.name, target_key, reference.column
1031 ),
1032 ));
1033 continue;
1034 };
1035
1036 if !key_types_compatible(source_type, target_type) {
1037 errors.push(reference_issue(
1038 strict,
1039 format!(
1040 "Foreign key type mismatch for '{}.{}' -> '{}.{}': {} vs {}",
1041 source_table_display,
1042 source_col.name,
1043 target_key,
1044 reference.column,
1045 type_family_name(source_type),
1046 type_family_name(target_type)
1047 ),
1048 ));
1049 }
1050
1051 if let Some(target_key_hints) = key_hints_lookup.get(&target_key) {
1052 if !target_key_hints.contains(&target_column) {
1053 errors.push(ValidationError::warning(
1054 format!(
1055 "Referenced column '{}.{}' is not marked as primary/unique key",
1056 target_key, reference.column
1057 ),
1058 validation_codes::W_WEAK_REFERENCE_INTEGRITY,
1059 ));
1060 }
1061 }
1062 }
1063
1064 for foreign_key in &table.foreign_keys {
1065 if foreign_key.columns.is_empty() || foreign_key.references.columns.is_empty() {
1066 errors.push(reference_issue(
1067 strict,
1068 format!(
1069 "Table-level foreign key on '{}' has empty source or target column list",
1070 source_table_display
1071 ),
1072 ));
1073 continue;
1074 }
1075 if foreign_key.columns.len() != foreign_key.references.columns.len() {
1076 errors.push(reference_issue(
1077 strict,
1078 format!(
1079 "Table-level foreign key on '{}' has {} source columns but {} target columns",
1080 source_table_display,
1081 foreign_key.columns.len(),
1082 foreign_key.references.columns.len()
1083 ),
1084 ));
1085 continue;
1086 }
1087
1088 let Some(target_key) = resolve_reference_table_key(
1089 &foreign_key.references.table,
1090 foreign_key.references.schema.as_deref(),
1091 source_schema,
1092 schema_map,
1093 ) else {
1094 errors.push(reference_issue(
1095 strict,
1096 format!(
1097 "Table-level foreign key on '{}' points to unknown table '{}'",
1098 source_table_display, foreign_key.references.table
1099 ),
1100 ));
1101 continue;
1102 };
1103
1104 let Some(target_entry) = schema_map.get(&target_key) else {
1105 errors.push(reference_issue(
1106 strict,
1107 format!(
1108 "Table-level foreign key on '{}' points to unknown table '{}'",
1109 source_table_display, foreign_key.references.table
1110 ),
1111 ));
1112 continue;
1113 };
1114
1115 for (source_col, target_col) in foreign_key
1116 .columns
1117 .iter()
1118 .zip(foreign_key.references.columns.iter())
1119 {
1120 let source_col_name = lower(source_col);
1121 let target_col_name = lower(target_col);
1122
1123 let Some(source_type) = source_columns.get(&source_col_name).copied() else {
1124 errors.push(reference_issue(
1125 strict,
1126 format!(
1127 "Table-level foreign key on '{}' references unknown source column '{}'",
1128 source_table_display, source_col
1129 ),
1130 ));
1131 continue;
1132 };
1133
1134 let Some(target_type) = target_entry.columns.get(&target_col_name).copied() else {
1135 errors.push(reference_issue(
1136 strict,
1137 format!(
1138 "Table-level foreign key on '{}' references unknown target column '{}.{}'",
1139 source_table_display, target_key, target_col
1140 ),
1141 ));
1142 continue;
1143 };
1144
1145 if !key_types_compatible(source_type, target_type) {
1146 errors.push(reference_issue(
1147 strict,
1148 format!(
1149 "Table-level foreign key type mismatch '{}.{}' -> '{}.{}': {} vs {}",
1150 source_table_display,
1151 source_col,
1152 target_key,
1153 target_col,
1154 type_family_name(source_type),
1155 type_family_name(target_type)
1156 ),
1157 ));
1158 }
1159
1160 if let Some(target_key_hints) = key_hints_lookup.get(&target_key) {
1161 if !target_key_hints.contains(&target_col_name) {
1162 errors.push(ValidationError::warning(
1163 format!(
1164 "Referenced column '{}.{}' is not marked as primary/unique key",
1165 target_key, target_col
1166 ),
1167 validation_codes::W_WEAK_REFERENCE_INTEGRITY,
1168 ));
1169 }
1170 }
1171 }
1172 }
1173 }
1174
1175 errors
1176}
1177
1178fn resolve_unqualified_column_type(
1179 column_name: &str,
1180 schema_map: &HashMap<String, TableSchemaEntry>,
1181 context: &TypeCheckContext,
1182) -> TypeFamily {
1183 let candidate_tables: Vec<&String> = if !context.referenced_tables.is_empty() {
1184 context.referenced_tables.iter().collect()
1185 } else {
1186 schema_map.keys().collect()
1187 };
1188
1189 let mut families = HashSet::new();
1190 for table_name in candidate_tables {
1191 if let Some(table_schema) = schema_map.get(table_name) {
1192 if let Some(family) = table_schema.columns.get(column_name) {
1193 families.insert(*family);
1194 }
1195 }
1196 }
1197
1198 if families.len() == 1 {
1199 *families.iter().next().unwrap_or(&TypeFamily::Unknown)
1200 } else {
1201 TypeFamily::Unknown
1202 }
1203}
1204
1205fn resolve_column_type(
1206 column: &Column,
1207 schema_map: &HashMap<String, TableSchemaEntry>,
1208 context: &TypeCheckContext,
1209) -> TypeFamily {
1210 let column_name = lower(&column.name.name);
1211 if column_name.is_empty() {
1212 return TypeFamily::Unknown;
1213 }
1214
1215 if let Some(table) = &column.table {
1216 let mut table_key = lower(&table.name);
1217 if let Some(mapped) = context.table_aliases.get(&table_key) {
1218 table_key = mapped.clone();
1219 }
1220
1221 return schema_map
1222 .get(&table_key)
1223 .and_then(|t| t.columns.get(&column_name))
1224 .copied()
1225 .unwrap_or(TypeFamily::Unknown);
1226 }
1227
1228 resolve_unqualified_column_type(&column_name, schema_map, context)
1229}
1230
/// Borrowed view over the validation schema that implements `SqlSchema`.
///
/// `infer_expression_type_family` constructs one of these and hands it to
/// `annotate_types`.
struct TypeInferenceSchema<'a> {
    /// Table entries keyed by table key; column families are read from here.
    schema_map: &'a HashMap<String, TableSchemaEntry>,
    /// Per-query context supplying `table_aliases` and `referenced_tables`.
    context: &'a TypeCheckContext,
}
1235
1236impl TypeInferenceSchema<'_> {
1237 fn resolve_table_key(&self, table: &str) -> Option<String> {
1238 let mut table_key = lower(table);
1239 if let Some(mapped) = self.context.table_aliases.get(&table_key) {
1240 table_key = mapped.clone();
1241 }
1242 if self.schema_map.contains_key(&table_key) {
1243 Some(table_key)
1244 } else {
1245 None
1246 }
1247 }
1248}
1249
1250impl SqlSchema for TypeInferenceSchema<'_> {
1251 fn dialect(&self) -> Option<DialectType> {
1252 None
1253 }
1254
1255 fn add_table(
1256 &mut self,
1257 _table: &str,
1258 _columns: &[(String, DataType)],
1259 _dialect: Option<DialectType>,
1260 ) -> SchemaResult<()> {
1261 Err(SchemaError::InvalidStructure(
1262 "Type inference schema is read-only".to_string(),
1263 ))
1264 }
1265
1266 fn column_names(&self, table: &str) -> SchemaResult<Vec<String>> {
1267 let table_key = self
1268 .resolve_table_key(table)
1269 .ok_or_else(|| SchemaError::TableNotFound(table.to_string()))?;
1270 let entry = self
1271 .schema_map
1272 .get(&table_key)
1273 .ok_or_else(|| SchemaError::TableNotFound(table.to_string()))?;
1274 Ok(entry.column_order.clone())
1275 }
1276
1277 fn get_column_type(&self, table: &str, column: &str) -> SchemaResult<DataType> {
1278 let col_name = lower(column);
1279 if table.is_empty() {
1280 let family = resolve_unqualified_column_type(&col_name, self.schema_map, self.context);
1281 return if family == TypeFamily::Unknown {
1282 Err(SchemaError::ColumnNotFound {
1283 table: "<unqualified>".to_string(),
1284 column: column.to_string(),
1285 })
1286 } else {
1287 Ok(type_family_to_data_type(family))
1288 };
1289 }
1290
1291 let table_key = self
1292 .resolve_table_key(table)
1293 .ok_or_else(|| SchemaError::TableNotFound(table.to_string()))?;
1294 let entry = self
1295 .schema_map
1296 .get(&table_key)
1297 .ok_or_else(|| SchemaError::TableNotFound(table.to_string()))?;
1298 let family =
1299 entry
1300 .columns
1301 .get(&col_name)
1302 .copied()
1303 .ok_or_else(|| SchemaError::ColumnNotFound {
1304 table: table.to_string(),
1305 column: column.to_string(),
1306 })?;
1307 Ok(type_family_to_data_type(family))
1308 }
1309
1310 fn has_column(&self, table: &str, column: &str) -> bool {
1311 self.get_column_type(table, column).is_ok()
1312 }
1313
1314 fn supported_table_args(&self) -> &[&str] {
1315 TABLE_PARTS
1316 }
1317
1318 fn is_empty(&self) -> bool {
1319 self.schema_map.is_empty()
1320 }
1321
1322 fn depth(&self) -> usize {
1323 1
1324 }
1325}
1326
1327fn infer_expression_type_family(
1328 expr: &Expression,
1329 schema_map: &HashMap<String, TableSchemaEntry>,
1330 context: &TypeCheckContext,
1331) -> TypeFamily {
1332 let inference_schema = TypeInferenceSchema {
1333 schema_map,
1334 context,
1335 };
1336 if let Some(data_type) = annotate_types(expr, Some(&inference_schema), None) {
1337 let family = data_type_family(&data_type);
1338 if family != TypeFamily::Unknown {
1339 return family;
1340 }
1341 }
1342
1343 infer_expression_type_family_fallback(expr, schema_map, context)
1344}
1345
1346fn infer_expression_type_family_fallback(
1347 expr: &Expression,
1348 schema_map: &HashMap<String, TableSchemaEntry>,
1349 context: &TypeCheckContext,
1350) -> TypeFamily {
1351 match expr {
1352 Expression::Literal(literal) => match literal {
1353 crate::expressions::Literal::Number(value) => {
1354 if value.contains('.') || value.contains('e') || value.contains('E') {
1355 TypeFamily::Numeric
1356 } else {
1357 TypeFamily::Integer
1358 }
1359 }
1360 crate::expressions::Literal::HexNumber(_) => TypeFamily::Integer,
1361 crate::expressions::Literal::Date(_) => TypeFamily::Date,
1362 crate::expressions::Literal::Time(_) => TypeFamily::Time,
1363 crate::expressions::Literal::Timestamp(_)
1364 | crate::expressions::Literal::Datetime(_) => TypeFamily::Timestamp,
1365 crate::expressions::Literal::HexString(_)
1366 | crate::expressions::Literal::BitString(_)
1367 | crate::expressions::Literal::ByteString(_) => TypeFamily::Binary,
1368 _ => TypeFamily::String,
1369 },
1370 Expression::Boolean(_) => TypeFamily::Boolean,
1371 Expression::Null(_) => TypeFamily::Unknown,
1372 Expression::Column(column) => resolve_column_type(column, schema_map, context),
1373 Expression::Cast(cast) | Expression::TryCast(cast) | Expression::SafeCast(cast) => {
1374 data_type_family(&cast.to)
1375 }
1376 Expression::Alias(alias) => {
1377 infer_expression_type_family_fallback(&alias.this, schema_map, context)
1378 }
1379 Expression::Neg(unary) => {
1380 infer_expression_type_family_fallback(&unary.this, schema_map, context)
1381 }
1382 Expression::Add(op) | Expression::Sub(op) | Expression::Mul(op) => {
1383 let left = infer_expression_type_family_fallback(&op.left, schema_map, context);
1384 let right = infer_expression_type_family_fallback(&op.right, schema_map, context);
1385 if left == TypeFamily::Unknown || right == TypeFamily::Unknown {
1386 TypeFamily::Unknown
1387 } else if left == TypeFamily::Integer && right == TypeFamily::Integer {
1388 TypeFamily::Integer
1389 } else if left.is_numeric() && right.is_numeric() {
1390 TypeFamily::Numeric
1391 } else if left.is_temporal() || right.is_temporal() {
1392 left
1393 } else {
1394 TypeFamily::Unknown
1395 }
1396 }
1397 Expression::Div(_) | Expression::Mod(_) => TypeFamily::Numeric,
1398 Expression::Concat(_) => TypeFamily::String,
1399 Expression::Eq(_)
1400 | Expression::Neq(_)
1401 | Expression::Lt(_)
1402 | Expression::Lte(_)
1403 | Expression::Gt(_)
1404 | Expression::Gte(_)
1405 | Expression::Like(_)
1406 | Expression::ILike(_)
1407 | Expression::And(_)
1408 | Expression::Or(_)
1409 | Expression::Not(_)
1410 | Expression::Between(_)
1411 | Expression::In(_)
1412 | Expression::IsNull(_)
1413 | Expression::IsTrue(_)
1414 | Expression::IsFalse(_)
1415 | Expression::Is(_) => TypeFamily::Boolean,
1416 Expression::Length(_) => TypeFamily::Integer,
1417 Expression::Upper(_)
1418 | Expression::Lower(_)
1419 | Expression::Trim(_)
1420 | Expression::LTrim(_)
1421 | Expression::RTrim(_)
1422 | Expression::Replace(_)
1423 | Expression::Substring(_)
1424 | Expression::Left(_)
1425 | Expression::Right(_)
1426 | Expression::Repeat(_)
1427 | Expression::Lpad(_)
1428 | Expression::Rpad(_)
1429 | Expression::ConcatWs(_) => TypeFamily::String,
1430 Expression::Abs(_)
1431 | Expression::Round(_)
1432 | Expression::Floor(_)
1433 | Expression::Ceil(_)
1434 | Expression::Power(_)
1435 | Expression::Sqrt(_)
1436 | Expression::Cbrt(_)
1437 | Expression::Ln(_)
1438 | Expression::Log(_)
1439 | Expression::Exp(_) => TypeFamily::Numeric,
1440 Expression::DateAdd(_) | Expression::DateSub(_) | Expression::ToDate(_) => TypeFamily::Date,
1441 Expression::ToTimestamp(_) => TypeFamily::Timestamp,
1442 Expression::DateDiff(_) | Expression::Extract(_) => TypeFamily::Integer,
1443 Expression::CurrentDate(_) => TypeFamily::Date,
1444 Expression::CurrentTime(_) => TypeFamily::Time,
1445 Expression::CurrentTimestamp(_) | Expression::CurrentTimestampLTZ(_) => {
1446 TypeFamily::Timestamp
1447 }
1448 Expression::Interval(_) => TypeFamily::Interval,
1449 _ => TypeFamily::Unknown,
1450 }
1451}
1452
1453fn are_comparable(left: TypeFamily, right: TypeFamily) -> bool {
1454 if left == TypeFamily::Unknown || right == TypeFamily::Unknown {
1455 return true;
1456 }
1457 if left == right {
1458 return true;
1459 }
1460 if left.is_numeric() && right.is_numeric() {
1461 return true;
1462 }
1463 if left.is_temporal() && right.is_temporal() {
1464 return true;
1465 }
1466 false
1467}
1468
1469fn check_function_argument(
1470 errors: &mut Vec<ValidationError>,
1471 strict: bool,
1472 function_name: &str,
1473 arg_index: usize,
1474 family: TypeFamily,
1475 expected: &str,
1476 valid: bool,
1477) {
1478 if family == TypeFamily::Unknown || valid {
1479 return;
1480 }
1481
1482 errors.push(type_issue(
1483 strict,
1484 validation_codes::E_INVALID_FUNCTION_ARGUMENT_TYPE,
1485 validation_codes::W_FUNCTION_ARGUMENT_COERCION,
1486 format!(
1487 "Function '{}' argument {} expects {}, found {}",
1488 function_name,
1489 arg_index + 1,
1490 expected,
1491 type_family_name(family)
1492 ),
1493 ));
1494}
1495
1496fn function_dispatch_name(name: &str) -> String {
1497 let upper = name
1498 .rsplit('.')
1499 .next()
1500 .unwrap_or(name)
1501 .trim()
1502 .to_uppercase();
1503 lower(canonical_typed_function_name_upper(&upper))
1504}
1505
/// Returns the trimmed text after the last `.` in `name`, or all of `name` (trimmed)
/// when it contains no `.`. Letter case is left untouched.
fn function_base_name(name: &str) -> &str {
    let unqualified = match name.rsplit_once('.') {
        Some((_, tail)) => tail,
        None => name,
    };
    unqualified.trim()
}
1509
1510fn check_generic_function(
1511 function: &Function,
1512 schema_map: &HashMap<String, TableSchemaEntry>,
1513 context: &TypeCheckContext,
1514 strict: bool,
1515 errors: &mut Vec<ValidationError>,
1516) {
1517 let name = function_dispatch_name(&function.name);
1518
1519 let arg_family = |index: usize| -> Option<TypeFamily> {
1520 function
1521 .args
1522 .get(index)
1523 .map(|arg| infer_expression_type_family(arg, schema_map, context))
1524 };
1525
1526 match name.as_str() {
1527 "abs" | "sqrt" | "cbrt" | "ln" | "exp" => {
1528 if let Some(family) = arg_family(0) {
1529 check_function_argument(
1530 errors,
1531 strict,
1532 &name,
1533 0,
1534 family,
1535 "a numeric argument",
1536 family.is_numeric(),
1537 );
1538 }
1539 }
1540 "round" | "floor" | "ceil" | "ceiling" => {
1541 if let Some(family) = arg_family(0) {
1542 check_function_argument(
1543 errors,
1544 strict,
1545 &name,
1546 0,
1547 family,
1548 "a numeric argument",
1549 family.is_numeric(),
1550 );
1551 }
1552 if let Some(family) = arg_family(1) {
1553 check_function_argument(
1554 errors,
1555 strict,
1556 &name,
1557 1,
1558 family,
1559 "a numeric argument",
1560 family.is_numeric(),
1561 );
1562 }
1563 }
1564 "power" | "pow" => {
1565 for i in [0_usize, 1_usize] {
1566 if let Some(family) = arg_family(i) {
1567 check_function_argument(
1568 errors,
1569 strict,
1570 &name,
1571 i,
1572 family,
1573 "a numeric argument",
1574 family.is_numeric(),
1575 );
1576 }
1577 }
1578 }
1579 "length" | "char_length" | "character_length" => {
1580 if let Some(family) = arg_family(0) {
1581 check_function_argument(
1582 errors,
1583 strict,
1584 &name,
1585 0,
1586 family,
1587 "a string or binary argument",
1588 is_string_or_binary(family),
1589 );
1590 }
1591 }
1592 "upper" | "lower" | "trim" | "ltrim" | "rtrim" | "reverse" => {
1593 if let Some(family) = arg_family(0) {
1594 check_function_argument(
1595 errors,
1596 strict,
1597 &name,
1598 0,
1599 family,
1600 "a string argument",
1601 is_string_like(family),
1602 );
1603 }
1604 }
1605 "substring" | "substr" => {
1606 if let Some(family) = arg_family(0) {
1607 check_function_argument(
1608 errors,
1609 strict,
1610 &name,
1611 0,
1612 family,
1613 "a string argument",
1614 is_string_like(family),
1615 );
1616 }
1617 if let Some(family) = arg_family(1) {
1618 check_function_argument(
1619 errors,
1620 strict,
1621 &name,
1622 1,
1623 family,
1624 "a numeric argument",
1625 family.is_numeric(),
1626 );
1627 }
1628 if let Some(family) = arg_family(2) {
1629 check_function_argument(
1630 errors,
1631 strict,
1632 &name,
1633 2,
1634 family,
1635 "a numeric argument",
1636 family.is_numeric(),
1637 );
1638 }
1639 }
1640 "replace" => {
1641 for i in [0_usize, 1_usize, 2_usize] {
1642 if let Some(family) = arg_family(i) {
1643 check_function_argument(
1644 errors,
1645 strict,
1646 &name,
1647 i,
1648 family,
1649 "a string argument",
1650 is_string_like(family),
1651 );
1652 }
1653 }
1654 }
1655 "left" | "right" | "repeat" | "lpad" | "rpad" => {
1656 if let Some(family) = arg_family(0) {
1657 check_function_argument(
1658 errors,
1659 strict,
1660 &name,
1661 0,
1662 family,
1663 "a string argument",
1664 is_string_like(family),
1665 );
1666 }
1667 if let Some(family) = arg_family(1) {
1668 check_function_argument(
1669 errors,
1670 strict,
1671 &name,
1672 1,
1673 family,
1674 "a numeric argument",
1675 family.is_numeric(),
1676 );
1677 }
1678 if (name == "lpad" || name == "rpad") && function.args.len() > 2 {
1679 if let Some(family) = arg_family(2) {
1680 check_function_argument(
1681 errors,
1682 strict,
1683 &name,
1684 2,
1685 family,
1686 "a string argument",
1687 is_string_like(family),
1688 );
1689 }
1690 }
1691 }
1692 _ => {}
1693 }
1694}
1695
1696fn check_function_catalog(
1697 function: &Function,
1698 dialect: DialectType,
1699 function_catalog: Option<&dyn FunctionCatalog>,
1700 strict: bool,
1701 errors: &mut Vec<ValidationError>,
1702) {
1703 let Some(catalog) = function_catalog else {
1704 return;
1705 };
1706
1707 let raw_name = function_base_name(&function.name);
1708 let normalized_name = function_dispatch_name(&function.name);
1709 let arity = function.args.len();
1710 let Some(signatures) = catalog.lookup(dialect, raw_name, &normalized_name) else {
1711 errors.push(if strict {
1712 ValidationError::error(
1713 format!(
1714 "Unknown function '{}' for dialect {:?}",
1715 function.name, dialect
1716 ),
1717 validation_codes::E_UNKNOWN_FUNCTION,
1718 )
1719 } else {
1720 ValidationError::warning(
1721 format!(
1722 "Unknown function '{}' for dialect {:?}",
1723 function.name, dialect
1724 ),
1725 validation_codes::E_UNKNOWN_FUNCTION,
1726 )
1727 });
1728 return;
1729 };
1730
1731 if signatures.iter().any(|sig| sig.matches_arity(arity)) {
1732 return;
1733 }
1734
1735 let expected = signatures
1736 .iter()
1737 .map(|sig| sig.describe_arity())
1738 .collect::<Vec<_>>()
1739 .join(", ");
1740 errors.push(if strict {
1741 ValidationError::error(
1742 format!(
1743 "Invalid arity for function '{}': got {}, expected {}",
1744 function.name, arity, expected
1745 ),
1746 validation_codes::E_INVALID_FUNCTION_ARITY,
1747 )
1748 } else {
1749 ValidationError::warning(
1750 format!(
1751 "Invalid arity for function '{}': got {}, expected {}",
1752 function.name, arity, expected
1753 ),
1754 validation_codes::E_INVALID_FUNCTION_ARITY,
1755 )
1756 });
1757}
1758
/// One column-to-column foreign-key link declared in the validation schema.
///
/// Values are produced by `build_declared_relationships`, which fills the table fields
/// with keys returned by `resolve_reference_table_key` and lowercases the column names.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
struct DeclaredRelationship {
    /// Resolved key of the table that declares the reference.
    source_table: String,
    /// Lowercased name of the referencing column.
    source_column: String,
    /// Resolved key of the referenced table.
    target_table: String,
    /// Lowercased name of the referenced column.
    target_column: String,
}
1766
1767fn build_declared_relationships(
1768 schema: &ValidationSchema,
1769 schema_map: &HashMap<String, TableSchemaEntry>,
1770) -> Vec<DeclaredRelationship> {
1771 let mut relationships = HashSet::new();
1772
1773 for table in &schema.tables {
1774 let Some(source_key) =
1775 resolve_reference_table_key(&table.name, table.schema.as_deref(), None, schema_map)
1776 else {
1777 continue;
1778 };
1779
1780 for column in &table.columns {
1781 let Some(reference) = &column.references else {
1782 continue;
1783 };
1784 let Some(target_key) = resolve_reference_table_key(
1785 &reference.table,
1786 reference.schema.as_deref(),
1787 table.schema.as_deref(),
1788 schema_map,
1789 ) else {
1790 continue;
1791 };
1792
1793 relationships.insert(DeclaredRelationship {
1794 source_table: source_key.clone(),
1795 source_column: lower(&column.name),
1796 target_table: target_key,
1797 target_column: lower(&reference.column),
1798 });
1799 }
1800
1801 for foreign_key in &table.foreign_keys {
1802 if foreign_key.columns.len() != foreign_key.references.columns.len() {
1803 continue;
1804 }
1805 let Some(target_key) = resolve_reference_table_key(
1806 &foreign_key.references.table,
1807 foreign_key.references.schema.as_deref(),
1808 table.schema.as_deref(),
1809 schema_map,
1810 ) else {
1811 continue;
1812 };
1813
1814 for (source_col, target_col) in foreign_key
1815 .columns
1816 .iter()
1817 .zip(foreign_key.references.columns.iter())
1818 {
1819 relationships.insert(DeclaredRelationship {
1820 source_table: source_key.clone(),
1821 source_column: lower(source_col),
1822 target_table: target_key.clone(),
1823 target_column: lower(target_col),
1824 });
1825 }
1826 }
1827 }
1828
1829 relationships.into_iter().collect()
1830}
1831
1832fn resolve_column_binding(
1833 column: &Column,
1834 schema_map: &HashMap<String, TableSchemaEntry>,
1835 context: &TypeCheckContext,
1836 resolver: &mut Resolver<'_>,
1837) -> Option<(String, String)> {
1838 let column_name = lower(&column.name.name);
1839 if column_name.is_empty() {
1840 return None;
1841 }
1842
1843 if let Some(table) = &column.table {
1844 let mut table_key = lower(&table.name);
1845 if let Some(mapped) = context.table_aliases.get(&table_key) {
1846 table_key = mapped.clone();
1847 }
1848 if schema_map.contains_key(&table_key) {
1849 return Some((table_key, column_name));
1850 }
1851 return None;
1852 }
1853
1854 if let Some(resolved_source) = resolver.get_table(&column_name) {
1855 let mut table_key = lower(&resolved_source);
1856 if let Some(mapped) = context.table_aliases.get(&table_key) {
1857 table_key = mapped.clone();
1858 }
1859 if schema_map.contains_key(&table_key) {
1860 return Some((table_key, column_name));
1861 }
1862 }
1863
1864 let candidates: Vec<String> = context
1865 .referenced_tables
1866 .iter()
1867 .filter_map(|table_name| {
1868 schema_map
1869 .get(table_name)
1870 .filter(|entry| entry.columns.contains_key(&column_name))
1871 .map(|_| table_name.clone())
1872 })
1873 .collect();
1874 if candidates.len() == 1 {
1875 return Some((candidates[0].clone(), column_name));
1876 }
1877 None
1878}
1879
1880fn extract_join_equality_pairs(
1881 expr: &Expression,
1882 schema_map: &HashMap<String, TableSchemaEntry>,
1883 context: &TypeCheckContext,
1884 resolver: &mut Resolver<'_>,
1885 pairs: &mut Vec<((String, String), (String, String))>,
1886) {
1887 match expr {
1888 Expression::And(op) => {
1889 extract_join_equality_pairs(&op.left, schema_map, context, resolver, pairs);
1890 extract_join_equality_pairs(&op.right, schema_map, context, resolver, pairs);
1891 }
1892 Expression::Paren(paren) => {
1893 extract_join_equality_pairs(&paren.this, schema_map, context, resolver, pairs);
1894 }
1895 Expression::Eq(op) => {
1896 let (Expression::Column(left_col), Expression::Column(right_col)) =
1897 (&op.left, &op.right)
1898 else {
1899 return;
1900 };
1901 let Some(left) = resolve_column_binding(left_col, schema_map, context, resolver) else {
1902 return;
1903 };
1904 let Some(right) = resolve_column_binding(right_col, schema_map, context, resolver)
1905 else {
1906 return;
1907 };
1908 pairs.push((left, right));
1909 }
1910 _ => {}
1911 }
1912}
1913
1914fn relationship_matches_pair(
1915 relationship: &DeclaredRelationship,
1916 left_table: &str,
1917 left_column: &str,
1918 right_table: &str,
1919 right_column: &str,
1920) -> bool {
1921 (relationship.source_table == left_table
1922 && relationship.source_column == left_column
1923 && relationship.target_table == right_table
1924 && relationship.target_column == right_column)
1925 || (relationship.source_table == right_table
1926 && relationship.source_column == right_column
1927 && relationship.target_table == left_table
1928 && relationship.target_column == left_column)
1929}
1930
1931fn resolved_table_key_from_expr(
1932 expr: &Expression,
1933 schema_map: &HashMap<String, TableSchemaEntry>,
1934) -> Option<String> {
1935 match expr {
1936 Expression::Table(table) => resolve_table_schema_entry(table, schema_map).map(|(k, _)| k),
1937 Expression::Alias(alias) => resolved_table_key_from_expr(&alias.this, schema_map),
1938 _ => None,
1939 }
1940}
1941
1942fn select_from_table_keys(
1943 select: &crate::expressions::Select,
1944 schema_map: &HashMap<String, TableSchemaEntry>,
1945) -> HashSet<String> {
1946 let mut keys = HashSet::new();
1947 if let Some(from_clause) = &select.from {
1948 for expr in &from_clause.expressions {
1949 if let Some(key) = resolved_table_key_from_expr(expr, schema_map) {
1950 keys.insert(key);
1951 }
1952 }
1953 }
1954 keys
1955}
1956
1957fn is_natural_or_implied_join(kind: JoinKind) -> bool {
1958 matches!(
1959 kind,
1960 JoinKind::Natural
1961 | JoinKind::NaturalLeft
1962 | JoinKind::NaturalRight
1963 | JoinKind::NaturalFull
1964 | JoinKind::CrossApply
1965 | JoinKind::OuterApply
1966 | JoinKind::AsOf
1967 | JoinKind::AsOfLeft
1968 | JoinKind::AsOfRight
1969 | JoinKind::Lateral
1970 | JoinKind::LeftLateral
1971 )
1972}
1973
/// Inspects every `SELECT` inside `stmt` for reference-quality problems.
///
/// Three checks run per select:
/// 1. Ambiguous unqualified columns (only when more than one table is referenced):
///    an error with `E_AMBIGUOUS_COLUMN_REFERENCE` in strict mode, otherwise a warning
///    with `W_WEAK_REFERENCE_INTEGRITY`.
/// 2. Potential cartesian joins: always a `W_CARTESIAN_JOIN` warning.
/// 3. `ON` predicates that ignore a declared foreign key between the joined tables:
///    always a `W_JOIN_NOT_USING_DECLARED_REFERENCE` warning.
///
/// `relationships` is the list of declared foreign-key links to compare join
/// predicates against. Returns all collected diagnostics.
fn check_query_reference_quality(
    stmt: &Expression,
    schema_map: &HashMap<String, TableSchemaEntry>,
    resolver_schema: &MappingSchema,
    strict: bool,
    relationships: &[DeclaredRelationship],
) -> Vec<ValidationError> {
    let mut errors = Vec::new();

    for node in stmt.dfs() {
        // Only SELECT nodes are analysed; every other node is skipped.
        let Expression::Select(select) = node else {
            continue;
        };

        // Context, scope and resolver are rebuilt for each select from a cloned copy.
        let select_expr = Expression::Select(select.clone());
        let context = collect_type_check_context(&select_expr, schema_map);
        let scope = build_scope(&select_expr);
        let mut resolver = Resolver::new(&scope, resolver_schema, true);

        if context.referenced_tables.len() > 1 {
            // Columns named in any USING clause are exempt from the ambiguity check.
            let using_columns: HashSet<String> = select
                .joins
                .iter()
                .flat_map(|join| join.using.iter().map(|id| lower(&id.name)))
                .collect();

            // `seen` ensures each unqualified column name is reported at most once.
            let mut seen = HashSet::new();
            for column_expr in select_expr
                .find_all(|e| matches!(e, Expression::Column(Column { table: None, .. })))
            {
                let Expression::Column(column) = column_expr else {
                    continue;
                };

                let col_name = lower(&column.name.name);
                if col_name.is_empty()
                    || using_columns.contains(&col_name)
                    || !seen.insert(col_name.clone())
                {
                    continue;
                }

                if resolver.is_ambiguous(&col_name) {
                    let source_count = resolver.sources_for_column(&col_name).len();
                    errors.push(if strict {
                        ValidationError::error(
                            format!(
                                "Ambiguous unqualified column '{}' found in {} referenced tables",
                                col_name, source_count
                            ),
                            validation_codes::E_AMBIGUOUS_COLUMN_REFERENCE,
                        )
                    } else {
                        ValidationError::warning(
                            format!(
                                "Ambiguous unqualified column '{}' found in {} referenced tables",
                                col_name, source_count
                            ),
                            validation_codes::W_WEAK_REFERENCE_INTEGRITY,
                        )
                    });
                }
            }
        }

        // Tables available on the left of the current join: starts with the FROM tables
        // and grows by one resolved table after each join is processed.
        let mut cumulative_left_tables = select_from_table_keys(select, schema_map);

        for join in &select.joins {
            let right_table_key = resolved_table_key_from_expr(&join.this, schema_map);
            let has_explicit_condition = join.on.is_some() || !join.using.is_empty();
            let cartesian_like_kind = matches!(
                join.kind,
                JoinKind::Cross
                    | JoinKind::Implicit
                    | JoinKind::Array
                    | JoinKind::LeftArray
                    | JoinKind::Paste
            );

            // Warn when the joined table is known and either the join kind is in the
            // cartesian-like list, or it has no ON/USING and is not an exempt kind.
            if right_table_key.is_some()
                && (cartesian_like_kind
                    || (!has_explicit_condition && !is_natural_or_implied_join(join.kind)))
            {
                errors.push(ValidationError::warning(
                    "Potential cartesian join: JOIN without ON/USING condition",
                    validation_codes::W_CARTESIAN_JOIN,
                ));
            }

            // The foreign-key check applies only to ON joins (no USING) on a known table.
            if let (Some(on_expr), Some(right_key)) = (&join.on, right_table_key.clone()) {
                if join.using.is_empty() {
                    let mut eq_pairs = Vec::new();
                    extract_join_equality_pairs(
                        on_expr,
                        schema_map,
                        &context,
                        &mut resolver,
                        &mut eq_pairs,
                    );

                    // Relationships linking any left-side table with the joined table, in
                    // either direction. `&&` binds tighter than `||`, so this reads as
                    // (left is source AND right is target) OR (left is target AND right
                    // is source).
                    let relevant_relationships: Vec<&DeclaredRelationship> = relationships
                        .iter()
                        .filter(|rel| {
                            cumulative_left_tables.contains(&rel.source_table)
                                && rel.target_table == right_key
                                || (cumulative_left_tables.contains(&rel.target_table)
                                    && rel.source_table == right_key)
                        })
                        .collect();

                    if !relevant_relationships.is_empty() {
                        // One matching equality pair is enough to count as using the key.
                        let uses_declared_fk = eq_pairs.iter().any(|((lt, lc), (rt, rc))| {
                            relevant_relationships
                                .iter()
                                .any(|rel| relationship_matches_pair(rel, lt, lc, rt, rc))
                        });
                        if !uses_declared_fk {
                            errors.push(ValidationError::warning(
                                "JOIN predicate does not use declared foreign-key relationship columns",
                                validation_codes::W_JOIN_NOT_USING_DECLARED_REFERENCE,
                            ));
                        }
                    }
                }
            }

            // The joined table becomes part of the left side for subsequent joins.
            if let Some(right_key) = right_table_key {
                cumulative_left_tables.insert(right_key);
            }
        }
    }

    errors
}
2108
2109fn are_setop_compatible(left: TypeFamily, right: TypeFamily) -> bool {
2110 if left == TypeFamily::Unknown || right == TypeFamily::Unknown {
2111 return true;
2112 }
2113 if left == right {
2114 return true;
2115 }
2116 if left.is_numeric() && right.is_numeric() {
2117 return true;
2118 }
2119 if left.is_temporal() && right.is_temporal() {
2120 return true;
2121 }
2122 false
2123}
2124
2125fn merged_setop_family(left: TypeFamily, right: TypeFamily) -> TypeFamily {
2126 if left == TypeFamily::Unknown {
2127 return right;
2128 }
2129 if right == TypeFamily::Unknown {
2130 return left;
2131 }
2132 if left == right {
2133 return left;
2134 }
2135 if left.is_numeric() && right.is_numeric() {
2136 if left == TypeFamily::Numeric || right == TypeFamily::Numeric {
2137 return TypeFamily::Numeric;
2138 }
2139 return TypeFamily::Integer;
2140 }
2141 if left.is_temporal() && right.is_temporal() {
2142 if left == TypeFamily::Timestamp || right == TypeFamily::Timestamp {
2143 return TypeFamily::Timestamp;
2144 }
2145 if left == TypeFamily::Date || right == TypeFamily::Date {
2146 return TypeFamily::Date;
2147 }
2148 return TypeFamily::Time;
2149 }
2150 TypeFamily::Unknown
2151}
2152
2153fn are_assignment_compatible(target: TypeFamily, source: TypeFamily) -> bool {
2154 if target == TypeFamily::Unknown || source == TypeFamily::Unknown {
2155 return true;
2156 }
2157 if target == source {
2158 return true;
2159 }
2160
2161 match target {
2162 TypeFamily::Boolean => source == TypeFamily::Boolean,
2163 TypeFamily::Integer | TypeFamily::Numeric => source.is_numeric(),
2164 TypeFamily::Date | TypeFamily::Time | TypeFamily::Timestamp | TypeFamily::Interval => {
2165 source.is_temporal()
2166 }
2167 TypeFamily::String => true,
2168 TypeFamily::Binary => matches!(source, TypeFamily::Binary | TypeFamily::String),
2169 TypeFamily::Json => matches!(source, TypeFamily::Json | TypeFamily::String),
2170 TypeFamily::Uuid => matches!(source, TypeFamily::Uuid | TypeFamily::String),
2171 TypeFamily::Array => source == TypeFamily::Array,
2172 TypeFamily::Map => source == TypeFamily::Map,
2173 TypeFamily::Struct => source == TypeFamily::Struct,
2174 TypeFamily::Unknown => true,
2175 }
2176}
2177
2178fn projection_families(
2179 query_expr: &Expression,
2180 schema_map: &HashMap<String, TableSchemaEntry>,
2181) -> Option<Vec<TypeFamily>> {
2182 match query_expr {
2183 Expression::Select(select) => {
2184 if select
2185 .expressions
2186 .iter()
2187 .any(|e| matches!(e, Expression::Star(_) | Expression::BracedWildcard(_)))
2188 {
2189 return None;
2190 }
2191 let select_expr = Expression::Select(select.clone());
2192 let context = collect_type_check_context(&select_expr, schema_map);
2193 Some(
2194 select
2195 .expressions
2196 .iter()
2197 .map(|e| infer_expression_type_family(e, schema_map, &context))
2198 .collect(),
2199 )
2200 }
2201 Expression::Subquery(subquery) => projection_families(&subquery.this, schema_map),
2202 Expression::Union(union) => {
2203 let left = projection_families(&union.left, schema_map)?;
2204 let right = projection_families(&union.right, schema_map)?;
2205 if left.len() != right.len() {
2206 return None;
2207 }
2208 Some(
2209 left.into_iter()
2210 .zip(right)
2211 .map(|(l, r)| merged_setop_family(l, r))
2212 .collect(),
2213 )
2214 }
2215 Expression::Intersect(intersect) => {
2216 let left = projection_families(&intersect.left, schema_map)?;
2217 let right = projection_families(&intersect.right, schema_map)?;
2218 if left.len() != right.len() {
2219 return None;
2220 }
2221 Some(
2222 left.into_iter()
2223 .zip(right)
2224 .map(|(l, r)| merged_setop_family(l, r))
2225 .collect(),
2226 )
2227 }
2228 Expression::Except(except) => {
2229 let left = projection_families(&except.left, schema_map)?;
2230 let right = projection_families(&except.right, schema_map)?;
2231 if left.len() != right.len() {
2232 return None;
2233 }
2234 Some(
2235 left.into_iter()
2236 .zip(right)
2237 .map(|(l, r)| merged_setop_family(l, r))
2238 .collect(),
2239 )
2240 }
2241 Expression::Values(values) => {
2242 let first_row = values.expressions.first()?;
2243 let context = TypeCheckContext::default();
2244 Some(
2245 first_row
2246 .expressions
2247 .iter()
2248 .map(|e| infer_expression_type_family(e, schema_map, &context))
2249 .collect(),
2250 )
2251 }
2252 _ => None,
2253 }
2254}
2255
2256fn check_set_operation_compatibility(
2257 op_name: &str,
2258 left_expr: &Expression,
2259 right_expr: &Expression,
2260 schema_map: &HashMap<String, TableSchemaEntry>,
2261 strict: bool,
2262 errors: &mut Vec<ValidationError>,
2263) {
2264 let Some(left_projection) = projection_families(left_expr, schema_map) else {
2265 return;
2266 };
2267 let Some(right_projection) = projection_families(right_expr, schema_map) else {
2268 return;
2269 };
2270
2271 if left_projection.len() != right_projection.len() {
2272 errors.push(type_issue(
2273 strict,
2274 validation_codes::E_SETOP_ARITY_MISMATCH,
2275 validation_codes::W_SETOP_IMPLICIT_COERCION,
2276 format!(
2277 "{} operands return different column counts: left {}, right {}",
2278 op_name,
2279 left_projection.len(),
2280 right_projection.len()
2281 ),
2282 ));
2283 return;
2284 }
2285
2286 for (idx, (left, right)) in left_projection
2287 .into_iter()
2288 .zip(right_projection)
2289 .enumerate()
2290 {
2291 if !are_setop_compatible(left, right) {
2292 errors.push(type_issue(
2293 strict,
2294 validation_codes::E_SETOP_TYPE_MISMATCH,
2295 validation_codes::W_SETOP_IMPLICIT_COERCION,
2296 format!(
2297 "{} column {} has incompatible types: {} vs {}",
2298 op_name,
2299 idx + 1,
2300 type_family_name(left),
2301 type_family_name(right)
2302 ),
2303 ));
2304 }
2305 }
2306}
2307
2308fn check_insert_assignments(
2309 stmt: &Expression,
2310 insert: &Insert,
2311 schema_map: &HashMap<String, TableSchemaEntry>,
2312 strict: bool,
2313 errors: &mut Vec<ValidationError>,
2314) {
2315 let Some((target_table_key, table_schema)) =
2316 resolve_table_schema_entry(&insert.table, schema_map)
2317 else {
2318 return;
2319 };
2320
2321 let mut target_columns = Vec::new();
2322 if insert.columns.is_empty() {
2323 target_columns.extend(table_schema.column_order.iter().cloned());
2324 } else {
2325 for column in &insert.columns {
2326 let col_name = lower(&column.name);
2327 if table_schema.columns.contains_key(&col_name) {
2328 target_columns.push(col_name);
2329 } else {
2330 errors.push(if strict {
2331 ValidationError::error(
2332 format!(
2333 "Unknown column '{}' in table '{}'",
2334 column.name, target_table_key
2335 ),
2336 validation_codes::E_UNKNOWN_COLUMN,
2337 )
2338 } else {
2339 ValidationError::warning(
2340 format!(
2341 "Unknown column '{}' in table '{}'",
2342 column.name, target_table_key
2343 ),
2344 validation_codes::E_UNKNOWN_COLUMN,
2345 )
2346 });
2347 }
2348 }
2349 }
2350
2351 if target_columns.is_empty() {
2352 return;
2353 }
2354
2355 let context = collect_type_check_context(stmt, schema_map);
2356
2357 if !insert.default_values {
2358 for (row_idx, row) in insert.values.iter().enumerate() {
2359 if row.len() != target_columns.len() {
2360 errors.push(type_issue(
2361 strict,
2362 validation_codes::E_INVALID_ASSIGNMENT_TYPE,
2363 validation_codes::W_IMPLICIT_CAST_ASSIGNMENT,
2364 format!(
2365 "INSERT row {} has {} values but target has {} columns",
2366 row_idx + 1,
2367 row.len(),
2368 target_columns.len()
2369 ),
2370 ));
2371 continue;
2372 }
2373
2374 for (value, target_column) in row.iter().zip(target_columns.iter()) {
2375 let Some(target_family) = table_schema.columns.get(target_column).copied() else {
2376 continue;
2377 };
2378 let source_family = infer_expression_type_family(value, schema_map, &context);
2379 if !are_assignment_compatible(target_family, source_family) {
2380 errors.push(type_issue(
2381 strict,
2382 validation_codes::E_INVALID_ASSIGNMENT_TYPE,
2383 validation_codes::W_IMPLICIT_CAST_ASSIGNMENT,
2384 format!(
2385 "INSERT assignment type mismatch for '{}.{}': expected {}, found {}",
2386 target_table_key,
2387 target_column,
2388 type_family_name(target_family),
2389 type_family_name(source_family)
2390 ),
2391 ));
2392 }
2393 }
2394 }
2395 }
2396
2397 if let Some(query) = &insert.query {
2398 if insert.by_name {
2400 return;
2401 }
2402
2403 let Some(source_projection) = projection_families(query, schema_map) else {
2404 return;
2405 };
2406
2407 if source_projection.len() != target_columns.len() {
2408 errors.push(type_issue(
2409 strict,
2410 validation_codes::E_INVALID_ASSIGNMENT_TYPE,
2411 validation_codes::W_IMPLICIT_CAST_ASSIGNMENT,
2412 format!(
2413 "INSERT source query has {} columns but target has {} columns",
2414 source_projection.len(),
2415 target_columns.len()
2416 ),
2417 ));
2418 return;
2419 }
2420
2421 for (source_family, target_column) in
2422 source_projection.into_iter().zip(target_columns.iter())
2423 {
2424 let Some(target_family) = table_schema.columns.get(target_column).copied() else {
2425 continue;
2426 };
2427 if !are_assignment_compatible(target_family, source_family) {
2428 errors.push(type_issue(
2429 strict,
2430 validation_codes::E_INVALID_ASSIGNMENT_TYPE,
2431 validation_codes::W_IMPLICIT_CAST_ASSIGNMENT,
2432 format!(
2433 "INSERT assignment type mismatch for '{}.{}': expected {}, found {}",
2434 target_table_key,
2435 target_column,
2436 type_family_name(target_family),
2437 type_family_name(source_family)
2438 ),
2439 ));
2440 }
2441 }
2442 }
2443}
2444
2445fn check_update_assignments(
2446 stmt: &Expression,
2447 update: &Update,
2448 schema_map: &HashMap<String, TableSchemaEntry>,
2449 strict: bool,
2450 errors: &mut Vec<ValidationError>,
2451) {
2452 let Some((target_table_key, table_schema)) =
2453 resolve_table_schema_entry(&update.table, schema_map)
2454 else {
2455 return;
2456 };
2457
2458 let context = collect_type_check_context(stmt, schema_map);
2459
2460 for (column, value) in &update.set {
2461 let col_name = lower(&column.name);
2462 let Some(target_family) = table_schema.columns.get(&col_name).copied() else {
2463 errors.push(if strict {
2464 ValidationError::error(
2465 format!(
2466 "Unknown column '{}' in table '{}'",
2467 column.name, target_table_key
2468 ),
2469 validation_codes::E_UNKNOWN_COLUMN,
2470 )
2471 } else {
2472 ValidationError::warning(
2473 format!(
2474 "Unknown column '{}' in table '{}'",
2475 column.name, target_table_key
2476 ),
2477 validation_codes::E_UNKNOWN_COLUMN,
2478 )
2479 });
2480 continue;
2481 };
2482
2483 let source_family = infer_expression_type_family(value, schema_map, &context);
2484 if !are_assignment_compatible(target_family, source_family) {
2485 errors.push(type_issue(
2486 strict,
2487 validation_codes::E_INVALID_ASSIGNMENT_TYPE,
2488 validation_codes::W_IMPLICIT_CAST_ASSIGNMENT,
2489 format!(
2490 "UPDATE assignment type mismatch for '{}.{}': expected {}, found {}",
2491 target_table_key,
2492 col_name,
2493 type_family_name(target_family),
2494 type_family_name(source_family)
2495 ),
2496 ));
2497 }
2498 }
2499}
2500
2501fn check_types(
2502 stmt: &Expression,
2503 dialect: DialectType,
2504 schema_map: &HashMap<String, TableSchemaEntry>,
2505 function_catalog: Option<&dyn FunctionCatalog>,
2506 strict: bool,
2507) -> Vec<ValidationError> {
2508 let mut errors = Vec::new();
2509 let context = collect_type_check_context(stmt, schema_map);
2510
2511 for node in stmt.dfs() {
2512 match node {
2513 Expression::Insert(insert) => {
2514 check_insert_assignments(stmt, insert, schema_map, strict, &mut errors);
2515 }
2516 Expression::Update(update) => {
2517 check_update_assignments(stmt, update, schema_map, strict, &mut errors);
2518 }
2519 Expression::Union(union) => {
2520 check_set_operation_compatibility(
2521 "UNION",
2522 &union.left,
2523 &union.right,
2524 schema_map,
2525 strict,
2526 &mut errors,
2527 );
2528 }
2529 Expression::Intersect(intersect) => {
2530 check_set_operation_compatibility(
2531 "INTERSECT",
2532 &intersect.left,
2533 &intersect.right,
2534 schema_map,
2535 strict,
2536 &mut errors,
2537 );
2538 }
2539 Expression::Except(except) => {
2540 check_set_operation_compatibility(
2541 "EXCEPT",
2542 &except.left,
2543 &except.right,
2544 schema_map,
2545 strict,
2546 &mut errors,
2547 );
2548 }
2549 Expression::Select(select) => {
2550 if let Some(prewhere) = &select.prewhere {
2551 let family = infer_expression_type_family(prewhere, schema_map, &context);
2552 if family != TypeFamily::Unknown && family != TypeFamily::Boolean {
2553 errors.push(type_issue(
2554 strict,
2555 validation_codes::E_INVALID_PREDICATE_TYPE,
2556 validation_codes::W_PREDICATE_NULLABILITY,
2557 format!(
2558 "PREWHERE clause expects a boolean predicate, found {}",
2559 type_family_name(family)
2560 ),
2561 ));
2562 }
2563 }
2564
2565 if let Some(where_clause) = &select.where_clause {
2566 let family =
2567 infer_expression_type_family(&where_clause.this, schema_map, &context);
2568 if family != TypeFamily::Unknown && family != TypeFamily::Boolean {
2569 errors.push(type_issue(
2570 strict,
2571 validation_codes::E_INVALID_PREDICATE_TYPE,
2572 validation_codes::W_PREDICATE_NULLABILITY,
2573 format!(
2574 "WHERE clause expects a boolean predicate, found {}",
2575 type_family_name(family)
2576 ),
2577 ));
2578 }
2579 }
2580
2581 if let Some(having_clause) = &select.having {
2582 let family =
2583 infer_expression_type_family(&having_clause.this, schema_map, &context);
2584 if family != TypeFamily::Unknown && family != TypeFamily::Boolean {
2585 errors.push(type_issue(
2586 strict,
2587 validation_codes::E_INVALID_PREDICATE_TYPE,
2588 validation_codes::W_PREDICATE_NULLABILITY,
2589 format!(
2590 "HAVING clause expects a boolean predicate, found {}",
2591 type_family_name(family)
2592 ),
2593 ));
2594 }
2595 }
2596
2597 for join in &select.joins {
2598 if let Some(on) = &join.on {
2599 let family = infer_expression_type_family(on, schema_map, &context);
2600 if family != TypeFamily::Unknown && family != TypeFamily::Boolean {
2601 errors.push(type_issue(
2602 strict,
2603 validation_codes::E_INVALID_PREDICATE_TYPE,
2604 validation_codes::W_PREDICATE_NULLABILITY,
2605 format!(
2606 "JOIN ON expects a boolean predicate, found {}",
2607 type_family_name(family)
2608 ),
2609 ));
2610 }
2611 }
2612 if let Some(match_condition) = &join.match_condition {
2613 let family =
2614 infer_expression_type_family(match_condition, schema_map, &context);
2615 if family != TypeFamily::Unknown && family != TypeFamily::Boolean {
2616 errors.push(type_issue(
2617 strict,
2618 validation_codes::E_INVALID_PREDICATE_TYPE,
2619 validation_codes::W_PREDICATE_NULLABILITY,
2620 format!(
2621 "JOIN MATCH_CONDITION expects a boolean predicate, found {}",
2622 type_family_name(family)
2623 ),
2624 ));
2625 }
2626 }
2627 }
2628 }
2629 Expression::Where(where_clause) => {
2630 let family = infer_expression_type_family(&where_clause.this, schema_map, &context);
2631 if family != TypeFamily::Unknown && family != TypeFamily::Boolean {
2632 errors.push(type_issue(
2633 strict,
2634 validation_codes::E_INVALID_PREDICATE_TYPE,
2635 validation_codes::W_PREDICATE_NULLABILITY,
2636 format!(
2637 "WHERE clause expects a boolean predicate, found {}",
2638 type_family_name(family)
2639 ),
2640 ));
2641 }
2642 }
2643 Expression::Having(having_clause) => {
2644 let family =
2645 infer_expression_type_family(&having_clause.this, schema_map, &context);
2646 if family != TypeFamily::Unknown && family != TypeFamily::Boolean {
2647 errors.push(type_issue(
2648 strict,
2649 validation_codes::E_INVALID_PREDICATE_TYPE,
2650 validation_codes::W_PREDICATE_NULLABILITY,
2651 format!(
2652 "HAVING clause expects a boolean predicate, found {}",
2653 type_family_name(family)
2654 ),
2655 ));
2656 }
2657 }
2658 Expression::And(op) | Expression::Or(op) => {
2659 for (side, expr) in [("left", &op.left), ("right", &op.right)] {
2660 let family = infer_expression_type_family(expr, schema_map, &context);
2661 if family != TypeFamily::Unknown && family != TypeFamily::Boolean {
2662 errors.push(type_issue(
2663 strict,
2664 validation_codes::E_INVALID_PREDICATE_TYPE,
2665 validation_codes::W_PREDICATE_NULLABILITY,
2666 format!(
2667 "Logical {} operand expects boolean, found {}",
2668 side,
2669 type_family_name(family)
2670 ),
2671 ));
2672 }
2673 }
2674 }
2675 Expression::Not(unary) => {
2676 let family = infer_expression_type_family(&unary.this, schema_map, &context);
2677 if family != TypeFamily::Unknown && family != TypeFamily::Boolean {
2678 errors.push(type_issue(
2679 strict,
2680 validation_codes::E_INVALID_PREDICATE_TYPE,
2681 validation_codes::W_PREDICATE_NULLABILITY,
2682 format!("NOT expects boolean, found {}", type_family_name(family)),
2683 ));
2684 }
2685 }
2686 Expression::Eq(op)
2687 | Expression::Neq(op)
2688 | Expression::Lt(op)
2689 | Expression::Lte(op)
2690 | Expression::Gt(op)
2691 | Expression::Gte(op) => {
2692 let left = infer_expression_type_family(&op.left, schema_map, &context);
2693 let right = infer_expression_type_family(&op.right, schema_map, &context);
2694 if !are_comparable(left, right) {
2695 errors.push(type_issue(
2696 strict,
2697 validation_codes::E_INCOMPATIBLE_COMPARISON_TYPES,
2698 validation_codes::W_IMPLICIT_CAST_COMPARISON,
2699 format!(
2700 "Incompatible comparison between {} and {}",
2701 type_family_name(left),
2702 type_family_name(right)
2703 ),
2704 ));
2705 }
2706 }
2707 Expression::Like(op) | Expression::ILike(op) => {
2708 let left = infer_expression_type_family(&op.left, schema_map, &context);
2709 let right = infer_expression_type_family(&op.right, schema_map, &context);
2710 if left != TypeFamily::Unknown
2711 && right != TypeFamily::Unknown
2712 && (!is_string_like(left) || !is_string_like(right))
2713 {
2714 errors.push(type_issue(
2715 strict,
2716 validation_codes::E_INCOMPATIBLE_COMPARISON_TYPES,
2717 validation_codes::W_IMPLICIT_CAST_COMPARISON,
2718 format!(
2719 "LIKE/ILIKE expects string operands, found {} and {}",
2720 type_family_name(left),
2721 type_family_name(right)
2722 ),
2723 ));
2724 }
2725 }
2726 Expression::Between(between) => {
2727 let this_family = infer_expression_type_family(&between.this, schema_map, &context);
2728 let low_family = infer_expression_type_family(&between.low, schema_map, &context);
2729 let high_family = infer_expression_type_family(&between.high, schema_map, &context);
2730
2731 if !are_comparable(this_family, low_family)
2732 || !are_comparable(this_family, high_family)
2733 {
2734 errors.push(type_issue(
2735 strict,
2736 validation_codes::E_INCOMPATIBLE_COMPARISON_TYPES,
2737 validation_codes::W_IMPLICIT_CAST_COMPARISON,
2738 format!(
2739 "BETWEEN bounds are incompatible with {} (found {} and {})",
2740 type_family_name(this_family),
2741 type_family_name(low_family),
2742 type_family_name(high_family)
2743 ),
2744 ));
2745 }
2746 }
2747 Expression::In(in_expr) => {
2748 let this_family = infer_expression_type_family(&in_expr.this, schema_map, &context);
2749 for value in &in_expr.expressions {
2750 let item_family = infer_expression_type_family(value, schema_map, &context);
2751 if !are_comparable(this_family, item_family) {
2752 errors.push(type_issue(
2753 strict,
2754 validation_codes::E_INCOMPATIBLE_COMPARISON_TYPES,
2755 validation_codes::W_IMPLICIT_CAST_COMPARISON,
2756 format!(
2757 "IN item type {} is incompatible with {}",
2758 type_family_name(item_family),
2759 type_family_name(this_family)
2760 ),
2761 ));
2762 break;
2763 }
2764 }
2765 }
2766 Expression::Add(op)
2767 | Expression::Sub(op)
2768 | Expression::Mul(op)
2769 | Expression::Div(op)
2770 | Expression::Mod(op) => {
2771 let left = infer_expression_type_family(&op.left, schema_map, &context);
2772 let right = infer_expression_type_family(&op.right, schema_map, &context);
2773
2774 if left == TypeFamily::Unknown || right == TypeFamily::Unknown {
2775 continue;
2776 }
2777
2778 let temporal_ok = matches!(node, Expression::Add(_) | Expression::Sub(_))
2779 && ((left.is_temporal() && right.is_numeric())
2780 || (right.is_temporal() && left.is_numeric())
2781 || (matches!(node, Expression::Sub(_))
2782 && left.is_temporal()
2783 && right.is_temporal()));
2784
2785 if !(left.is_numeric() && right.is_numeric()) && !temporal_ok {
2786 errors.push(type_issue(
2787 strict,
2788 validation_codes::E_INVALID_ARITHMETIC_TYPE,
2789 validation_codes::W_IMPLICIT_CAST_ARITHMETIC,
2790 format!(
2791 "Arithmetic operation expects numeric-compatible operands, found {} and {}",
2792 type_family_name(left),
2793 type_family_name(right)
2794 ),
2795 ));
2796 }
2797 }
2798 Expression::Function(function) => {
2799 check_function_catalog(function, dialect, function_catalog, strict, &mut errors);
2800 check_generic_function(function, schema_map, &context, strict, &mut errors);
2801 }
2802 Expression::Upper(func)
2803 | Expression::Lower(func)
2804 | Expression::LTrim(func)
2805 | Expression::RTrim(func)
2806 | Expression::Reverse(func) => {
2807 let family = infer_expression_type_family(&func.this, schema_map, &context);
2808 check_function_argument(
2809 &mut errors,
2810 strict,
2811 "string_function",
2812 0,
2813 family,
2814 "a string argument",
2815 is_string_like(family),
2816 );
2817 }
2818 Expression::Length(func) => {
2819 let family = infer_expression_type_family(&func.this, schema_map, &context);
2820 check_function_argument(
2821 &mut errors,
2822 strict,
2823 "length",
2824 0,
2825 family,
2826 "a string or binary argument",
2827 is_string_or_binary(family),
2828 );
2829 }
2830 Expression::Trim(func) => {
2831 let this_family = infer_expression_type_family(&func.this, schema_map, &context);
2832 check_function_argument(
2833 &mut errors,
2834 strict,
2835 "trim",
2836 0,
2837 this_family,
2838 "a string argument",
2839 is_string_like(this_family),
2840 );
2841 if let Some(chars) = &func.characters {
2842 let chars_family = infer_expression_type_family(chars, schema_map, &context);
2843 check_function_argument(
2844 &mut errors,
2845 strict,
2846 "trim",
2847 1,
2848 chars_family,
2849 "a string argument",
2850 is_string_like(chars_family),
2851 );
2852 }
2853 }
2854 Expression::Substring(func) => {
2855 let this_family = infer_expression_type_family(&func.this, schema_map, &context);
2856 check_function_argument(
2857 &mut errors,
2858 strict,
2859 "substring",
2860 0,
2861 this_family,
2862 "a string argument",
2863 is_string_like(this_family),
2864 );
2865
2866 let start_family = infer_expression_type_family(&func.start, schema_map, &context);
2867 check_function_argument(
2868 &mut errors,
2869 strict,
2870 "substring",
2871 1,
2872 start_family,
2873 "a numeric argument",
2874 start_family.is_numeric(),
2875 );
2876 if let Some(length) = &func.length {
2877 let length_family = infer_expression_type_family(length, schema_map, &context);
2878 check_function_argument(
2879 &mut errors,
2880 strict,
2881 "substring",
2882 2,
2883 length_family,
2884 "a numeric argument",
2885 length_family.is_numeric(),
2886 );
2887 }
2888 }
2889 Expression::Replace(func) => {
2890 for (arg, idx) in [
2891 (&func.this, 0_usize),
2892 (&func.old, 1_usize),
2893 (&func.new, 2_usize),
2894 ] {
2895 let family = infer_expression_type_family(arg, schema_map, &context);
2896 check_function_argument(
2897 &mut errors,
2898 strict,
2899 "replace",
2900 idx,
2901 family,
2902 "a string argument",
2903 is_string_like(family),
2904 );
2905 }
2906 }
2907 Expression::Left(func) | Expression::Right(func) => {
2908 let this_family = infer_expression_type_family(&func.this, schema_map, &context);
2909 check_function_argument(
2910 &mut errors,
2911 strict,
2912 "left_right",
2913 0,
2914 this_family,
2915 "a string argument",
2916 is_string_like(this_family),
2917 );
2918 let length_family =
2919 infer_expression_type_family(&func.length, schema_map, &context);
2920 check_function_argument(
2921 &mut errors,
2922 strict,
2923 "left_right",
2924 1,
2925 length_family,
2926 "a numeric argument",
2927 length_family.is_numeric(),
2928 );
2929 }
2930 Expression::Repeat(func) => {
2931 let this_family = infer_expression_type_family(&func.this, schema_map, &context);
2932 check_function_argument(
2933 &mut errors,
2934 strict,
2935 "repeat",
2936 0,
2937 this_family,
2938 "a string argument",
2939 is_string_like(this_family),
2940 );
2941 let times_family = infer_expression_type_family(&func.times, schema_map, &context);
2942 check_function_argument(
2943 &mut errors,
2944 strict,
2945 "repeat",
2946 1,
2947 times_family,
2948 "a numeric argument",
2949 times_family.is_numeric(),
2950 );
2951 }
2952 Expression::Lpad(func) | Expression::Rpad(func) => {
2953 let this_family = infer_expression_type_family(&func.this, schema_map, &context);
2954 check_function_argument(
2955 &mut errors,
2956 strict,
2957 "pad",
2958 0,
2959 this_family,
2960 "a string argument",
2961 is_string_like(this_family),
2962 );
2963 let length_family =
2964 infer_expression_type_family(&func.length, schema_map, &context);
2965 check_function_argument(
2966 &mut errors,
2967 strict,
2968 "pad",
2969 1,
2970 length_family,
2971 "a numeric argument",
2972 length_family.is_numeric(),
2973 );
2974 if let Some(fill) = &func.fill {
2975 let fill_family = infer_expression_type_family(fill, schema_map, &context);
2976 check_function_argument(
2977 &mut errors,
2978 strict,
2979 "pad",
2980 2,
2981 fill_family,
2982 "a string argument",
2983 is_string_like(fill_family),
2984 );
2985 }
2986 }
2987 Expression::Abs(func)
2988 | Expression::Sqrt(func)
2989 | Expression::Cbrt(func)
2990 | Expression::Ln(func)
2991 | Expression::Exp(func) => {
2992 let family = infer_expression_type_family(&func.this, schema_map, &context);
2993 check_function_argument(
2994 &mut errors,
2995 strict,
2996 "numeric_function",
2997 0,
2998 family,
2999 "a numeric argument",
3000 family.is_numeric(),
3001 );
3002 }
3003 Expression::Round(func) => {
3004 let this_family = infer_expression_type_family(&func.this, schema_map, &context);
3005 check_function_argument(
3006 &mut errors,
3007 strict,
3008 "round",
3009 0,
3010 this_family,
3011 "a numeric argument",
3012 this_family.is_numeric(),
3013 );
3014 if let Some(decimals) = &func.decimals {
3015 let decimals_family =
3016 infer_expression_type_family(decimals, schema_map, &context);
3017 check_function_argument(
3018 &mut errors,
3019 strict,
3020 "round",
3021 1,
3022 decimals_family,
3023 "a numeric argument",
3024 decimals_family.is_numeric(),
3025 );
3026 }
3027 }
3028 Expression::Floor(func) => {
3029 let this_family = infer_expression_type_family(&func.this, schema_map, &context);
3030 check_function_argument(
3031 &mut errors,
3032 strict,
3033 "floor",
3034 0,
3035 this_family,
3036 "a numeric argument",
3037 this_family.is_numeric(),
3038 );
3039 if let Some(scale) = &func.scale {
3040 let scale_family = infer_expression_type_family(scale, schema_map, &context);
3041 check_function_argument(
3042 &mut errors,
3043 strict,
3044 "floor",
3045 1,
3046 scale_family,
3047 "a numeric argument",
3048 scale_family.is_numeric(),
3049 );
3050 }
3051 }
3052 Expression::Ceil(func) => {
3053 let this_family = infer_expression_type_family(&func.this, schema_map, &context);
3054 check_function_argument(
3055 &mut errors,
3056 strict,
3057 "ceil",
3058 0,
3059 this_family,
3060 "a numeric argument",
3061 this_family.is_numeric(),
3062 );
3063 if let Some(decimals) = &func.decimals {
3064 let decimals_family =
3065 infer_expression_type_family(decimals, schema_map, &context);
3066 check_function_argument(
3067 &mut errors,
3068 strict,
3069 "ceil",
3070 1,
3071 decimals_family,
3072 "a numeric argument",
3073 decimals_family.is_numeric(),
3074 );
3075 }
3076 }
3077 Expression::Power(func) => {
3078 let left_family = infer_expression_type_family(&func.this, schema_map, &context);
3079 check_function_argument(
3080 &mut errors,
3081 strict,
3082 "power",
3083 0,
3084 left_family,
3085 "a numeric argument",
3086 left_family.is_numeric(),
3087 );
3088 let right_family =
3089 infer_expression_type_family(&func.expression, schema_map, &context);
3090 check_function_argument(
3091 &mut errors,
3092 strict,
3093 "power",
3094 1,
3095 right_family,
3096 "a numeric argument",
3097 right_family.is_numeric(),
3098 );
3099 }
3100 Expression::Log(func) => {
3101 let this_family = infer_expression_type_family(&func.this, schema_map, &context);
3102 check_function_argument(
3103 &mut errors,
3104 strict,
3105 "log",
3106 0,
3107 this_family,
3108 "a numeric argument",
3109 this_family.is_numeric(),
3110 );
3111 if let Some(base) = &func.base {
3112 let base_family = infer_expression_type_family(base, schema_map, &context);
3113 check_function_argument(
3114 &mut errors,
3115 strict,
3116 "log",
3117 1,
3118 base_family,
3119 "a numeric argument",
3120 base_family.is_numeric(),
3121 );
3122 }
3123 }
3124 _ => {}
3125 }
3126 }
3127
3128 errors
3129}
3130
3131fn check_semantics(stmt: &Expression) -> Vec<ValidationError> {
3132 let mut errors = Vec::new();
3133
3134 let Expression::Select(select) = stmt else {
3135 return errors;
3136 };
3137 let select_expr = Expression::Select(select.clone());
3138
3139 if !select_expr
3141 .find_all(|e| matches!(e, Expression::Star(_)))
3142 .is_empty()
3143 {
3144 errors.push(ValidationError::warning(
3145 "SELECT * is discouraged; specify columns explicitly for better performance and maintainability",
3146 validation_codes::W_SELECT_STAR,
3147 ));
3148 }
3149
3150 let aggregate_count = get_aggregate_functions(&select_expr).len();
3152 if aggregate_count > 0 && select.group_by.is_none() {
3153 let has_non_aggregate_column = select.expressions.iter().any(|expr| {
3154 matches!(expr, Expression::Column(_) | Expression::Identifier(_))
3155 && get_aggregate_functions(expr).is_empty()
3156 });
3157
3158 if has_non_aggregate_column {
3159 errors.push(ValidationError::warning(
3160 "Mixing aggregate functions with non-aggregated columns without GROUP BY may cause errors in strict SQL mode",
3161 validation_codes::W_AGGREGATE_WITHOUT_GROUP_BY,
3162 ));
3163 }
3164 }
3165
3166 if select.distinct && select.order_by.is_some() {
3168 errors.push(ValidationError::warning(
3169 "DISTINCT with ORDER BY: ensure ORDER BY columns are in SELECT list",
3170 validation_codes::W_DISTINCT_ORDER_BY,
3171 ));
3172 }
3173
3174 if select.limit.is_some() && select.order_by.is_none() {
3176 errors.push(ValidationError::warning(
3177 "LIMIT without ORDER BY produces non-deterministic results",
3178 validation_codes::W_LIMIT_WITHOUT_ORDER_BY,
3179 ));
3180 }
3181
3182 errors
3183}
3184
3185fn resolve_scope_source_name(scope: &crate::scope::Scope, name: &str) -> Option<String> {
3186 scope
3187 .sources
3188 .get_key_value(name)
3189 .map(|(k, _)| k.clone())
3190 .or_else(|| {
3191 scope
3192 .sources
3193 .keys()
3194 .find(|source| source.eq_ignore_ascii_case(name))
3195 .cloned()
3196 })
3197}
3198
/// Reports whether `columns` can provide `column_name`.
///
/// A `"*"` entry matches any column name; every other entry is compared
/// ignoring ASCII case.
fn source_has_column(columns: &[String], column_name: &str) -> bool {
    for candidate in columns {
        if candidate == "*" || candidate.eq_ignore_ascii_case(column_name) {
            return true;
        }
    }
    false
}
3204
3205fn source_display_name(scope: &crate::scope::Scope, source_name: &str) -> String {
3206 scope
3207 .sources
3208 .get(source_name)
3209 .map(|source| match &source.expression {
3210 Expression::Table(table) => lower(&table_ref_display_name(table)),
3211 _ => lower(source_name),
3212 })
3213 .unwrap_or_else(|| lower(source_name))
3214}
3215
3216fn validate_select_columns_with_schema(
3217 select: &crate::expressions::Select,
3218 schema_map: &HashMap<String, TableSchemaEntry>,
3219 resolver_schema: &MappingSchema,
3220 strict: bool,
3221) -> Vec<ValidationError> {
3222 let mut errors = Vec::new();
3223 let select_expr = Expression::Select(Box::new(select.clone()));
3224 let scope = build_scope(&select_expr);
3225 let mut resolver = Resolver::new(&scope, resolver_schema, true);
3226 let source_names: Vec<String> = scope.sources.keys().cloned().collect();
3227
3228 for node in walk_in_scope(&select_expr, false) {
3229 let Expression::Column(column) = node else {
3230 continue;
3231 };
3232
3233 let col_name = lower(&column.name.name);
3234 if col_name.is_empty() {
3235 continue;
3236 }
3237
3238 if let Some(table) = &column.table {
3239 let Some(source_name) = resolve_scope_source_name(&scope, &table.name) else {
3240 errors.push(if strict {
3242 ValidationError::error(
3243 format!(
3244 "Unknown table or alias '{}' referenced by column '{}'",
3245 table.name, col_name
3246 ),
3247 validation_codes::E_UNRESOLVED_REFERENCE,
3248 )
3249 } else {
3250 ValidationError::warning(
3251 format!(
3252 "Unknown table or alias '{}' referenced by column '{}'",
3253 table.name, col_name
3254 ),
3255 validation_codes::E_UNRESOLVED_REFERENCE,
3256 )
3257 });
3258 continue;
3259 };
3260
3261 if let Ok(columns) = resolver.get_source_columns(&source_name) {
3262 if !columns.is_empty() && !source_has_column(&columns, &col_name) {
3263 let table_name = source_display_name(&scope, &source_name);
3264 errors.push(if strict {
3265 ValidationError::error(
3266 format!("Unknown column '{}' in table '{}'", col_name, table_name),
3267 validation_codes::E_UNKNOWN_COLUMN,
3268 )
3269 } else {
3270 ValidationError::warning(
3271 format!("Unknown column '{}' in table '{}'", col_name, table_name),
3272 validation_codes::E_UNKNOWN_COLUMN,
3273 )
3274 });
3275 }
3276 }
3277 continue;
3278 }
3279
3280 let matching_sources: Vec<String> = source_names
3281 .iter()
3282 .filter_map(|source_name| {
3283 resolver
3284 .get_source_columns(source_name)
3285 .ok()
3286 .filter(|columns| !columns.is_empty() && source_has_column(columns, &col_name))
3287 .map(|_| source_name.clone())
3288 })
3289 .collect();
3290
3291 if !matching_sources.is_empty() {
3292 continue;
3293 }
3294
3295 let known_sources: Vec<String> = source_names
3296 .iter()
3297 .filter_map(|source_name| {
3298 resolver
3299 .get_source_columns(source_name)
3300 .ok()
3301 .filter(|columns| !columns.is_empty() && !columns.iter().any(|c| c == "*"))
3302 .map(|_| source_name.clone())
3303 })
3304 .collect();
3305
3306 if known_sources.len() == 1 {
3307 let table_name = source_display_name(&scope, &known_sources[0]);
3308 errors.push(if strict {
3309 ValidationError::error(
3310 format!("Unknown column '{}' in table '{}'", col_name, table_name),
3311 validation_codes::E_UNKNOWN_COLUMN,
3312 )
3313 } else {
3314 ValidationError::warning(
3315 format!("Unknown column '{}' in table '{}'", col_name, table_name),
3316 validation_codes::E_UNKNOWN_COLUMN,
3317 )
3318 });
3319 } else if known_sources.len() > 1 {
3320 errors.push(if strict {
3321 ValidationError::error(
3322 format!(
3323 "Unknown column '{}' (not found in any referenced table)",
3324 col_name
3325 ),
3326 validation_codes::E_UNKNOWN_COLUMN,
3327 )
3328 } else {
3329 ValidationError::warning(
3330 format!(
3331 "Unknown column '{}' (not found in any referenced table)",
3332 col_name
3333 ),
3334 validation_codes::E_UNKNOWN_COLUMN,
3335 )
3336 });
3337 } else if !schema_map.is_empty() {
3338 let found = schema_map
3339 .values()
3340 .any(|table_schema| table_schema.columns.contains_key(&col_name));
3341 if !found {
3342 errors.push(if strict {
3343 ValidationError::error(
3344 format!("Unknown column '{}'", col_name),
3345 validation_codes::E_UNKNOWN_COLUMN,
3346 )
3347 } else {
3348 ValidationError::warning(
3349 format!("Unknown column '{}'", col_name),
3350 validation_codes::E_UNKNOWN_COLUMN,
3351 )
3352 });
3353 }
3354 }
3355 }
3356
3357 errors
3358}
3359
3360fn validate_statement_with_schema(
3361 stmt: &Expression,
3362 schema_map: &HashMap<String, TableSchemaEntry>,
3363 resolver_schema: &MappingSchema,
3364 strict: bool,
3365) -> Vec<ValidationError> {
3366 let mut errors = Vec::new();
3367 let cte_aliases = collect_cte_aliases(stmt);
3368 let mut seen_tables: HashSet<String> = HashSet::new();
3369
3370 for node in stmt.find_all(|e| matches!(e, Expression::Table(_))) {
3372 let Expression::Table(table) = node else {
3373 continue;
3374 };
3375
3376 if cte_aliases.contains(&lower(&table.name.name)) {
3377 continue;
3378 }
3379
3380 let resolved_key = table_ref_candidates(table)
3381 .into_iter()
3382 .find(|k| schema_map.contains_key(k));
3383 let table_key = resolved_key
3384 .clone()
3385 .unwrap_or_else(|| lower(&table_ref_display_name(table)));
3386
3387 if !seen_tables.insert(table_key) {
3388 continue;
3389 }
3390
3391 if resolved_key.is_none() {
3392 errors.push(if strict {
3393 ValidationError::error(
3394 format!("Unknown table '{}'", table_ref_display_name(table)),
3395 validation_codes::E_UNKNOWN_TABLE,
3396 )
3397 } else {
3398 ValidationError::warning(
3399 format!("Unknown table '{}'", table_ref_display_name(table)),
3400 validation_codes::E_UNKNOWN_TABLE,
3401 )
3402 });
3403 }
3404 }
3405
3406 for node in stmt.dfs() {
3407 let Expression::Select(select) = node else {
3408 continue;
3409 };
3410 errors.extend(validate_select_columns_with_schema(
3411 select,
3412 schema_map,
3413 resolver_schema,
3414 strict,
3415 ));
3416 }
3417
3418 errors
3419}
3420
3421pub fn validate_with_schema(
3423 sql: &str,
3424 dialect: DialectType,
3425 schema: &ValidationSchema,
3426 options: &SchemaValidationOptions,
3427) -> ValidationResult {
3428 let strict = options.strict.unwrap_or(schema.strict.unwrap_or(true));
3429
3430 let syntax_result = crate::validate_with_options(
3432 sql,
3433 dialect,
3434 &crate::ValidationOptions {
3435 strict_syntax: options.strict_syntax,
3436 },
3437 );
3438 if !syntax_result.valid {
3439 return syntax_result;
3440 }
3441
3442 let d = Dialect::get(dialect);
3443 let statements = match d.parse(sql) {
3444 Ok(exprs) => exprs,
3445 Err(e) => {
3446 return ValidationResult::with_errors(vec![ValidationError::error(
3447 e.to_string(),
3448 validation_codes::E_PARSE_OR_OPTIONS,
3449 )]);
3450 }
3451 };
3452
3453 let schema_map = build_schema_map(schema);
3454 let resolver_schema = build_resolver_schema(schema);
3455 let mut all_errors = syntax_result.errors;
3456 let embedded_function_catalog = if options.check_types && options.function_catalog.is_none() {
3457 default_embedded_function_catalog()
3458 } else {
3459 None
3460 };
3461 let effective_function_catalog = options
3462 .function_catalog
3463 .as_deref()
3464 .or_else(|| embedded_function_catalog.as_deref());
3465 let declared_relationships = if options.check_references {
3466 build_declared_relationships(schema, &schema_map)
3467 } else {
3468 Vec::new()
3469 };
3470
3471 if options.check_references {
3472 all_errors.extend(check_reference_integrity(schema, &schema_map, strict));
3473 }
3474
3475 for stmt in &statements {
3476 if options.semantic {
3477 all_errors.extend(check_semantics(stmt));
3478 }
3479 all_errors.extend(validate_statement_with_schema(
3480 stmt,
3481 &schema_map,
3482 &resolver_schema,
3483 strict,
3484 ));
3485 if options.check_types {
3486 all_errors.extend(check_types(
3487 stmt,
3488 dialect,
3489 &schema_map,
3490 effective_function_catalog,
3491 strict,
3492 ));
3493 }
3494 if options.check_references {
3495 all_errors.extend(check_query_reference_quality(
3496 stmt,
3497 &schema_map,
3498 &resolver_schema,
3499 strict,
3500 &declared_relationships,
3501 ));
3502 }
3503 }
3504
3505 ValidationResult::with_errors(all_errors)
3506}
3507
3508#[cfg(test)]
3509mod tests;