1use crate::ast_transforms::get_aggregate_functions;
8use crate::dialects::{Dialect, DialectType};
9use crate::error::{ValidationError, ValidationResult};
10use crate::expressions::{
11 Column, DataType, Expression, Function, Insert, JoinKind, TableRef, Update,
12};
13use crate::function_catalog::FunctionCatalog;
14#[cfg(any(
15 feature = "function-catalog-clickhouse",
16 feature = "function-catalog-duckdb",
17 feature = "function-catalog-all-dialects"
18))]
19use crate::function_catalog::{
20 FunctionNameCase as CoreFunctionNameCase, FunctionSignature as CoreFunctionSignature,
21 HashMapFunctionCatalog,
22};
23use crate::function_registry::canonical_typed_function_name_upper;
24use crate::optimizer::annotate_types::annotate_types;
25use crate::resolver::Resolver;
26use crate::schema::{MappingSchema, Schema as SqlSchema, SchemaError, SchemaResult, TABLE_PARTS};
27use crate::scope::{build_scope, walk_in_scope};
28use crate::traversal::ExpressionWalk;
29use serde::{Deserialize, Serialize};
30use std::collections::{HashMap, HashSet};
31use std::sync::Arc;
32
33#[cfg(any(
34 feature = "function-catalog-clickhouse",
35 feature = "function-catalog-duckdb",
36 feature = "function-catalog-all-dialects"
37))]
38use std::sync::LazyLock;
39
40#[derive(Debug, Clone, Serialize, Deserialize)]
42pub struct SchemaColumn {
43 pub name: String,
45 #[serde(default, rename = "type")]
47 pub data_type: String,
48 #[serde(default)]
50 pub nullable: Option<bool>,
51 #[serde(default, rename = "primaryKey")]
53 pub primary_key: bool,
54 #[serde(default)]
56 pub unique: bool,
57 #[serde(default)]
59 pub references: Option<SchemaColumnReference>,
60}
61
62#[derive(Debug, Clone, Serialize, Deserialize)]
64pub struct SchemaColumnReference {
65 pub table: String,
67 pub column: String,
69 #[serde(default)]
71 pub schema: Option<String>,
72}
73
74#[derive(Debug, Clone, Serialize, Deserialize)]
76pub struct SchemaForeignKey {
77 #[serde(default)]
79 pub name: Option<String>,
80 pub columns: Vec<String>,
82 pub references: SchemaTableReference,
84}
85
86#[derive(Debug, Clone, Serialize, Deserialize)]
88pub struct SchemaTableReference {
89 pub table: String,
91 pub columns: Vec<String>,
93 #[serde(default)]
95 pub schema: Option<String>,
96}
97
98#[derive(Debug, Clone, Serialize, Deserialize)]
100pub struct SchemaTable {
101 pub name: String,
103 #[serde(default)]
105 pub schema: Option<String>,
106 pub columns: Vec<SchemaColumn>,
108 #[serde(default)]
110 pub aliases: Vec<String>,
111 #[serde(default, rename = "primaryKey")]
113 pub primary_key: Vec<String>,
114 #[serde(default, rename = "uniqueKeys")]
116 pub unique_keys: Vec<Vec<String>>,
117 #[serde(default, rename = "foreignKeys")]
119 pub foreign_keys: Vec<SchemaForeignKey>,
120}
121
122#[derive(Debug, Clone, Serialize, Deserialize)]
124pub struct ValidationSchema {
125 pub tables: Vec<SchemaTable>,
127 #[serde(default)]
129 pub strict: Option<bool>,
130}
131
132#[derive(Clone, Serialize, Deserialize, Default)]
134pub struct SchemaValidationOptions {
135 #[serde(default)]
137 pub check_types: bool,
138 #[serde(default)]
140 pub check_references: bool,
141 #[serde(default)]
143 pub strict: Option<bool>,
144 #[serde(default)]
146 pub semantic: bool,
147 #[serde(default)]
149 pub strict_syntax: bool,
150 #[serde(skip, default)]
152 pub function_catalog: Option<Arc<dyn FunctionCatalog>>,
153}
154
#[cfg(any(
    feature = "function-catalog-clickhouse",
    feature = "function-catalog-duckdb",
    feature = "function-catalog-all-dialects"
))]
/// Translates the embedded-catalog crate's name-case flag into the core enum.
fn to_core_name_case(
    case: polyglot_sql_function_catalogs::FunctionNameCase,
) -> CoreFunctionNameCase {
    use polyglot_sql_function_catalogs::FunctionNameCase as External;

    match case {
        External::Sensitive => CoreFunctionNameCase::Sensitive,
        External::Insensitive => CoreFunctionNameCase::Insensitive,
    }
}
172
#[cfg(any(
    feature = "function-catalog-clickhouse",
    feature = "function-catalog-duckdb",
    feature = "function-catalog-all-dialects"
))]
/// Converts embedded-catalog signatures into core signatures, preserving order
/// and copying only the arity bounds.
fn to_core_signatures(
    signatures: Vec<polyglot_sql_function_catalogs::FunctionSignature>,
) -> Vec<CoreFunctionSignature> {
    let mut converted = Vec::with_capacity(signatures.len());
    for sig in signatures {
        converted.push(CoreFunctionSignature {
            min_arity: sig.min_arity,
            max_arity: sig.max_arity,
        });
    }
    converted
}
189
/// Sink that receives embedded catalog entries and writes them into a
/// `HashMapFunctionCatalog`, memoizing dialect-name parsing.
#[cfg(any(
    feature = "function-catalog-clickhouse",
    feature = "function-catalog-duckdb",
    feature = "function-catalog-all-dialects"
))]
struct EmbeddedCatalogSink<'a> {
    /// Catalog being populated.
    catalog: &'a mut HashMapFunctionCatalog,
    /// Parse result per dialect name; `None` marks a name that did not parse.
    dialect_cache: HashMap<&'static str, Option<DialectType>>,
}
199
#[cfg(any(
    feature = "function-catalog-clickhouse",
    feature = "function-catalog-duckdb",
    feature = "function-catalog-all-dialects"
))]
impl<'a> EmbeddedCatalogSink<'a> {
    /// Parses `dialect` into a `DialectType`, caching the outcome (failures
    /// included) so each distinct name is parsed at most once per sink.
    fn resolve_dialect(&mut self, dialect: &'static str) -> Option<DialectType> {
        *self
            .dialect_cache
            .entry(dialect)
            .or_insert_with(|| dialect.parse::<DialectType>().ok())
    }
}
215
#[cfg(any(
    feature = "function-catalog-clickhouse",
    feature = "function-catalog-duckdb",
    feature = "function-catalog-all-dialects"
))]
impl<'a> polyglot_sql_function_catalogs::CatalogSink for EmbeddedCatalogSink<'a> {
    /// Forwards a dialect-wide name-case setting; entries for unrecognized dialects are dropped.
    fn set_dialect_name_case(
        &mut self,
        dialect: &'static str,
        name_case: polyglot_sql_function_catalogs::FunctionNameCase,
    ) {
        let Some(core_dialect) = self.resolve_dialect(dialect) else {
            return;
        };
        let case = to_core_name_case(name_case);
        self.catalog.set_dialect_name_case(core_dialect, case);
    }

    /// Forwards a per-function name-case override; entries for unrecognized dialects are dropped.
    fn set_function_name_case(
        &mut self,
        dialect: &'static str,
        function_name: &str,
        name_case: polyglot_sql_function_catalogs::FunctionNameCase,
    ) {
        let Some(core_dialect) = self.resolve_dialect(dialect) else {
            return;
        };
        let case = to_core_name_case(name_case);
        self.catalog
            .set_function_name_case(core_dialect, function_name, case);
    }

    /// Registers a function's signatures; entries for unrecognized dialects are dropped.
    fn register(
        &mut self,
        dialect: &'static str,
        function_name: &str,
        signatures: Vec<polyglot_sql_function_catalogs::FunctionSignature>,
    ) {
        let Some(core_dialect) = self.resolve_dialect(dialect) else {
            return;
        };
        let converted = to_core_signatures(signatures);
        self.catalog.register(core_dialect, function_name, converted);
    }
}
260
/// Returns the shared catalog assembled from every enabled embedded catalog.
///
/// The catalog is built on first use and cached in a static; later calls only
/// clone the `Arc`.
#[cfg(any(
    feature = "function-catalog-clickhouse",
    feature = "function-catalog-duckdb",
    feature = "function-catalog-all-dialects"
))]
fn embedded_function_catalog_arc() -> Arc<dyn FunctionCatalog> {
    static EMBEDDED: LazyLock<Arc<HashMapFunctionCatalog>> = LazyLock::new(|| {
        let mut built = HashMapFunctionCatalog::default();
        polyglot_sql_function_catalogs::register_enabled_catalogs(&mut EmbeddedCatalogSink {
            catalog: &mut built,
            dialect_cache: HashMap::new(),
        });
        Arc::new(built)
    });

    Arc::<HashMapFunctionCatalog>::clone(&EMBEDDED)
}
279
/// Default catalog when at least one embedded catalog feature is enabled.
#[cfg(any(
    feature = "function-catalog-clickhouse",
    feature = "function-catalog-duckdb",
    feature = "function-catalog-all-dialects"
))]
fn default_embedded_function_catalog() -> Option<Arc<dyn FunctionCatalog>> {
    let catalog = embedded_function_catalog_arc();
    Some(catalog)
}
288
289#[cfg(not(any(
290 feature = "function-catalog-clickhouse",
291 feature = "function-catalog-duckdb",
292 feature = "function-catalog-all-dialects"
293)))]
294fn default_embedded_function_catalog() -> Option<Arc<dyn FunctionCatalog>> {
295 None
296}
297
/// Stable diagnostic codes attached to `ValidationError`s.
///
/// `E…` codes are used with errors and `W…` codes with warnings. Several checks
/// come in error/warning pairs chosen by strictness (for example
/// `E_INVALID_FOREIGN_KEY_REFERENCE` / `W_WEAK_REFERENCE_INTEGRITY`).
pub mod validation_codes {
    // E000
    pub const E_PARSE_OR_OPTIONS: &str = "E000";

    // E20x: name resolution and function calls
    pub const E_UNKNOWN_TABLE: &str = "E200";
    pub const E_UNKNOWN_COLUMN: &str = "E201";
    pub const E_UNKNOWN_FUNCTION: &str = "E202";
    pub const E_INVALID_FUNCTION_ARITY: &str = "E203";

    // W00x: query-shape warnings
    pub const W_SELECT_STAR: &str = "W001";
    pub const W_AGGREGATE_WITHOUT_GROUP_BY: &str = "W002";
    pub const W_DISTINCT_ORDER_BY: &str = "W003";
    pub const W_LIMIT_WITHOUT_ORDER_BY: &str = "W004";

    // E21x: type errors
    pub const E_TYPE_MISMATCH: &str = "E210";
    pub const E_INVALID_PREDICATE_TYPE: &str = "E211";
    pub const E_INVALID_ARITHMETIC_TYPE: &str = "E212";
    pub const E_INVALID_FUNCTION_ARGUMENT_TYPE: &str = "E213";
    pub const E_INVALID_ASSIGNMENT_TYPE: &str = "E214";
    pub const E_SETOP_TYPE_MISMATCH: &str = "E215";
    pub const E_SETOP_ARITY_MISMATCH: &str = "E216";
    pub const E_INCOMPATIBLE_COMPARISON_TYPES: &str = "E217";
    pub const E_INVALID_CAST: &str = "E218";
    pub const E_UNKNOWN_INFERRED_TYPE: &str = "E219";

    // W21x: type warnings
    pub const W_IMPLICIT_CAST_COMPARISON: &str = "W210";
    pub const W_IMPLICIT_CAST_ARITHMETIC: &str = "W211";
    pub const W_IMPLICIT_CAST_ASSIGNMENT: &str = "W212";
    pub const W_LOSSY_CAST: &str = "W213";
    pub const W_SETOP_IMPLICIT_COERCION: &str = "W214";
    pub const W_PREDICATE_NULLABILITY: &str = "W215";
    pub const W_FUNCTION_ARGUMENT_COERCION: &str = "W216";
    pub const W_AGGREGATE_TYPE_COERCION: &str = "W217";
    pub const W_POSSIBLE_OVERFLOW: &str = "W218";
    pub const W_POSSIBLE_TRUNCATION: &str = "W219";

    // E22x: reference errors
    pub const E_INVALID_FOREIGN_KEY_REFERENCE: &str = "E220";
    pub const E_AMBIGUOUS_COLUMN_REFERENCE: &str = "E221";
    pub const E_UNRESOLVED_REFERENCE: &str = "E222";
    pub const E_CTE_COLUMN_COUNT_MISMATCH: &str = "E223";
    pub const E_MISSING_REFERENCE_TARGET: &str = "E224";

    // W22x: reference warnings
    pub const W_CARTESIAN_JOIN: &str = "W220";
    pub const W_JOIN_NOT_USING_DECLARED_REFERENCE: &str = "W221";
    pub const W_WEAK_REFERENCE_INTEGRITY: &str = "W222";
}
346
347#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
349#[serde(rename_all = "snake_case")]
350pub enum TypeFamily {
351 Unknown,
352 Boolean,
353 Integer,
354 Numeric,
355 String,
356 Binary,
357 Date,
358 Time,
359 Timestamp,
360 Interval,
361 Json,
362 Uuid,
363 Array,
364 Map,
365 Struct,
366}
367
368impl TypeFamily {
369 pub fn is_numeric(self) -> bool {
370 matches!(self, TypeFamily::Integer | TypeFamily::Numeric)
371 }
372
373 pub fn is_temporal(self) -> bool {
374 matches!(
375 self,
376 TypeFamily::Date | TypeFamily::Time | TypeFamily::Timestamp | TypeFamily::Interval
377 )
378 }
379}
380
381#[derive(Debug, Clone)]
382struct TableSchemaEntry {
383 columns: HashMap<String, TypeFamily>,
384 column_order: Vec<String>,
385}
386
/// Unicode-aware lowercase copy of `s`; the normalization used for
/// case-insensitive name matching throughout this module.
fn lower(s: &str) -> String {
    str::to_lowercase(s)
}
390
/// Splits `base(args)` into `(base, args)`, both trimmed.
///
/// The split happens at the first `(`; the input must end with `)`. Returns
/// `None` when there is no `(` or the string does not end with `)`.
fn split_type_args(data_type: &str) -> Option<(&str, &str)> {
    let paren = data_type.find('(')?;
    let without_close = data_type.strip_suffix(')')?;
    if paren + 1 >= data_type.len() {
        return None;
    }
    let head = data_type[..paren].trim();
    let args = without_close[paren + 1..].trim();
    Some((head, args))
}
400
401pub fn canonical_type_family(data_type: &str) -> TypeFamily {
403 let trimmed = data_type
404 .trim()
405 .trim_matches(|c| c == '"' || c == '\'' || c == '`');
406 if trimmed.is_empty() {
407 return TypeFamily::Unknown;
408 }
409
410 let normalized = trimmed
412 .split_whitespace()
413 .collect::<Vec<_>>()
414 .join(" ")
415 .to_lowercase();
416
417 if let Some((base, inner)) = split_type_args(&normalized) {
419 match base {
420 "nullable" | "lowcardinality" => return canonical_type_family(inner),
421 "array" | "list" => return TypeFamily::Array,
422 "map" => return TypeFamily::Map,
423 "struct" | "row" | "record" => return TypeFamily::Struct,
424 _ => {}
425 }
426 }
427
428 if normalized.starts_with("array<") || normalized.starts_with("list<") {
429 return TypeFamily::Array;
430 }
431 if normalized.starts_with("map<") {
432 return TypeFamily::Map;
433 }
434 if normalized.starts_with("struct<")
435 || normalized.starts_with("row<")
436 || normalized.starts_with("record<")
437 || normalized.starts_with("object<")
438 {
439 return TypeFamily::Struct;
440 }
441
442 if normalized.ends_with("[]") {
443 return TypeFamily::Array;
444 }
445
446 let mut base = normalized
448 .split('(')
449 .next()
450 .unwrap_or("")
451 .trim()
452 .to_string();
453 if base.is_empty() {
454 return TypeFamily::Unknown;
455 }
456
457 base = base.strip_prefix("unsigned ").unwrap_or(&base).to_string();
458 base = base.strip_suffix(" unsigned").unwrap_or(&base).to_string();
459
460 match base.as_str() {
461 "bool" | "boolean" => TypeFamily::Boolean,
462 "tinyint" | "smallint" | "int2" | "int" | "integer" | "int4" | "int8" | "bigint"
463 | "serial" | "smallserial" | "bigserial" | "utinyint" | "usmallint" | "uinteger"
464 | "ubigint" | "uint8" | "uint16" | "uint32" | "uint64" | "int16" | "int32" | "int64" => {
465 TypeFamily::Integer
466 }
467 "numeric" | "decimal" | "dec" | "number" | "float" | "float4" | "float8" | "real"
468 | "double" | "double precision" | "bfloat16" | "float16" | "float32" | "float64" => {
469 TypeFamily::Numeric
470 }
471 "char" | "character" | "varchar" | "character varying" | "nchar" | "nvarchar" | "text"
472 | "string" | "clob" => TypeFamily::String,
473 "binary" | "varbinary" | "blob" | "bytea" | "bytes" => TypeFamily::Binary,
474 "date" => TypeFamily::Date,
475 "time" => TypeFamily::Time,
476 "timestamp"
477 | "timestamptz"
478 | "datetime"
479 | "datetime2"
480 | "smalldatetime"
481 | "timestamp with time zone"
482 | "timestamp without time zone" => TypeFamily::Timestamp,
483 "interval" => TypeFamily::Interval,
484 "json" | "jsonb" | "variant" => TypeFamily::Json,
485 "uuid" | "uniqueidentifier" => TypeFamily::Uuid,
486 "array" | "list" => TypeFamily::Array,
487 "map" => TypeFamily::Map,
488 "struct" | "row" | "record" | "object" => TypeFamily::Struct,
489 _ => TypeFamily::Unknown,
490 }
491}
492
493fn build_schema_map(schema: &ValidationSchema) -> HashMap<String, TableSchemaEntry> {
494 let mut map = HashMap::new();
495
496 for table in &schema.tables {
497 let column_order: Vec<String> = table.columns.iter().map(|c| lower(&c.name)).collect();
498 let columns: HashMap<String, TypeFamily> = table
499 .columns
500 .iter()
501 .map(|c| (lower(&c.name), canonical_type_family(&c.data_type)))
502 .collect();
503 let entry = TableSchemaEntry {
504 columns,
505 column_order,
506 };
507
508 let simple_name = lower(&table.name);
509 map.insert(simple_name, entry.clone());
510
511 if let Some(table_schema) = &table.schema {
512 map.insert(
513 format!("{}.{}", lower(table_schema), lower(&table.name)),
514 entry.clone(),
515 );
516 }
517
518 for alias in &table.aliases {
519 map.insert(lower(alias), entry.clone());
520 }
521 }
522
523 map
524}
525
526fn type_family_to_data_type(family: TypeFamily) -> DataType {
527 match family {
528 TypeFamily::Unknown => DataType::Unknown,
529 TypeFamily::Boolean => DataType::Boolean,
530 TypeFamily::Integer => DataType::Int {
531 length: None,
532 integer_spelling: false,
533 },
534 TypeFamily::Numeric => DataType::Double {
535 precision: None,
536 scale: None,
537 },
538 TypeFamily::String => DataType::VarChar {
539 length: None,
540 parenthesized_length: false,
541 },
542 TypeFamily::Binary => DataType::VarBinary { length: None },
543 TypeFamily::Date => DataType::Date,
544 TypeFamily::Time => DataType::Time {
545 precision: None,
546 timezone: false,
547 },
548 TypeFamily::Timestamp => DataType::Timestamp {
549 precision: None,
550 timezone: false,
551 },
552 TypeFamily::Interval => DataType::Interval {
553 unit: None,
554 to: None,
555 },
556 TypeFamily::Json => DataType::Json,
557 TypeFamily::Uuid => DataType::Uuid,
558 TypeFamily::Array => DataType::Array {
559 element_type: Box::new(DataType::Unknown),
560 dimension: None,
561 },
562 TypeFamily::Map => DataType::Map {
563 key_type: Box::new(DataType::Unknown),
564 value_type: Box::new(DataType::Unknown),
565 },
566 TypeFamily::Struct => DataType::Struct {
567 fields: Vec::new(),
568 nested: false,
569 },
570 }
571}
572
573fn build_resolver_schema(schema: &ValidationSchema) -> MappingSchema {
574 let mut mapping = MappingSchema::new();
575
576 for table in &schema.tables {
577 let columns: Vec<(String, DataType)> = table
578 .columns
579 .iter()
580 .map(|column| {
581 (
582 lower(&column.name),
583 type_family_to_data_type(canonical_type_family(&column.data_type)),
584 )
585 })
586 .collect();
587
588 let mut table_names = Vec::new();
589 table_names.push(lower(&table.name));
590 if let Some(table_schema) = &table.schema {
591 table_names.push(format!("{}.{}", lower(table_schema), lower(&table.name)));
592 }
593 for alias in &table.aliases {
594 table_names.push(lower(alias));
595 }
596
597 let mut dedup = HashSet::new();
598 for table_name in table_names {
599 if dedup.insert(table_name.clone()) {
600 let _ = mapping.add_table(&table_name, &columns, None);
601 }
602 }
603 }
604
605 mapping
606}
607
608pub fn mapping_schema_from_validation_schema(schema: &ValidationSchema) -> MappingSchema {
614 build_resolver_schema(schema)
615}
616
617fn collect_cte_aliases(expr: &Expression) -> HashSet<String> {
618 let mut aliases = HashSet::new();
619
620 for node in expr.dfs() {
621 match node {
622 Expression::Select(select) => {
623 if let Some(with) = &select.with {
624 for cte in &with.ctes {
625 aliases.insert(lower(&cte.alias.name));
626 }
627 }
628 }
629 Expression::Insert(insert) => {
630 if let Some(with) = &insert.with {
631 for cte in &with.ctes {
632 aliases.insert(lower(&cte.alias.name));
633 }
634 }
635 }
636 Expression::Update(update) => {
637 if let Some(with) = &update.with {
638 for cte in &with.ctes {
639 aliases.insert(lower(&cte.alias.name));
640 }
641 }
642 }
643 Expression::Delete(delete) => {
644 if let Some(with) = &delete.with {
645 for cte in &with.ctes {
646 aliases.insert(lower(&cte.alias.name));
647 }
648 }
649 }
650 Expression::Union(union) => {
651 if let Some(with) = &union.with {
652 for cte in &with.ctes {
653 aliases.insert(lower(&cte.alias.name));
654 }
655 }
656 }
657 Expression::Intersect(intersect) => {
658 if let Some(with) = &intersect.with {
659 for cte in &with.ctes {
660 aliases.insert(lower(&cte.alias.name));
661 }
662 }
663 }
664 Expression::Except(except) => {
665 if let Some(with) = &except.with {
666 for cte in &with.ctes {
667 aliases.insert(lower(&cte.alias.name));
668 }
669 }
670 }
671 Expression::Merge(merge) => {
672 if let Some(with_) = &merge.with_ {
673 if let Expression::With(with_clause) = with_.as_ref() {
674 for cte in &with_clause.ctes {
675 aliases.insert(lower(&cte.alias.name));
676 }
677 }
678 }
679 }
680 _ => {}
681 }
682 }
683
684 aliases
685}
686
687fn table_ref_candidates(table: &TableRef) -> Vec<String> {
688 let name = lower(&table.name.name);
689 let schema = table.schema.as_ref().map(|s| lower(&s.name));
690 let catalog = table.catalog.as_ref().map(|c| lower(&c.name));
691
692 let mut candidates = Vec::new();
693 if let (Some(catalog), Some(schema)) = (&catalog, &schema) {
694 candidates.push(format!("{}.{}.{}", catalog, schema, name));
695 }
696 if let Some(schema) = &schema {
697 candidates.push(format!("{}.{}", schema, name));
698 }
699 candidates.push(name);
700 candidates
701}
702
703fn table_ref_display_name(table: &TableRef) -> String {
704 let mut parts = Vec::new();
705 if let Some(catalog) = &table.catalog {
706 parts.push(catalog.name.clone());
707 }
708 if let Some(schema) = &table.schema {
709 parts.push(schema.name.clone());
710 }
711 parts.push(table.name.name.clone());
712 parts.join(".")
713}
714
/// Tables visible to one statement, used to resolve column types.
#[derive(Clone, Debug, Default)]
struct TypeCheckContext {
    /// Schema-map keys of every known table the statement references.
    referenced_tables: HashSet<String>,
    /// Lowercased table name or alias -> schema-map key of the table it denotes.
    table_aliases: HashMap<String, String>,
}
720
721fn type_family_name(family: TypeFamily) -> &'static str {
722 match family {
723 TypeFamily::Unknown => "unknown",
724 TypeFamily::Boolean => "boolean",
725 TypeFamily::Integer => "integer",
726 TypeFamily::Numeric => "numeric",
727 TypeFamily::String => "string",
728 TypeFamily::Binary => "binary",
729 TypeFamily::Date => "date",
730 TypeFamily::Time => "time",
731 TypeFamily::Timestamp => "timestamp",
732 TypeFamily::Interval => "interval",
733 TypeFamily::Json => "json",
734 TypeFamily::Uuid => "uuid",
735 TypeFamily::Array => "array",
736 TypeFamily::Map => "map",
737 TypeFamily::Struct => "struct",
738 }
739}
740
741fn is_string_like(family: TypeFamily) -> bool {
742 matches!(family, TypeFamily::String)
743}
744
745fn is_string_or_binary(family: TypeFamily) -> bool {
746 matches!(family, TypeFamily::String | TypeFamily::Binary)
747}
748
749fn type_issue(
750 strict: bool,
751 error_code: &str,
752 warning_code: &str,
753 message: impl Into<String>,
754) -> ValidationError {
755 if strict {
756 ValidationError::error(message.into(), error_code)
757 } else {
758 ValidationError::warning(message.into(), warning_code)
759 }
760}
761
762fn data_type_family(data_type: &DataType) -> TypeFamily {
763 match data_type {
764 DataType::Boolean => TypeFamily::Boolean,
765 DataType::TinyInt { .. }
766 | DataType::SmallInt { .. }
767 | DataType::Int { .. }
768 | DataType::BigInt { .. } => TypeFamily::Integer,
769 DataType::Float { .. } | DataType::Double { .. } | DataType::Decimal { .. } => {
770 TypeFamily::Numeric
771 }
772 DataType::Char { .. }
773 | DataType::VarChar { .. }
774 | DataType::String { .. }
775 | DataType::Text
776 | DataType::TextWithLength { .. }
777 | DataType::CharacterSet { .. } => TypeFamily::String,
778 DataType::Binary { .. } | DataType::VarBinary { .. } | DataType::Blob => TypeFamily::Binary,
779 DataType::Date => TypeFamily::Date,
780 DataType::Time { .. } => TypeFamily::Time,
781 DataType::Timestamp { .. } => TypeFamily::Timestamp,
782 DataType::Interval { .. } => TypeFamily::Interval,
783 DataType::Json | DataType::JsonB => TypeFamily::Json,
784 DataType::Uuid => TypeFamily::Uuid,
785 DataType::Array { .. } | DataType::List { .. } => TypeFamily::Array,
786 DataType::Map { .. } => TypeFamily::Map,
787 DataType::Struct { .. } | DataType::Object { .. } | DataType::Union { .. } => {
788 TypeFamily::Struct
789 }
790 DataType::Nullable { inner } => data_type_family(inner),
791 DataType::Custom { name } => canonical_type_family(name),
792 DataType::Unknown => TypeFamily::Unknown,
793 DataType::Bit { .. } | DataType::VarBit { .. } => TypeFamily::Binary,
794 DataType::Enum { .. } | DataType::Set { .. } => TypeFamily::String,
795 DataType::Vector { .. } => TypeFamily::Array,
796 DataType::Geometry { .. } | DataType::Geography { .. } => TypeFamily::Struct,
797 }
798}
799
800fn collect_type_check_context(
801 stmt: &Expression,
802 schema_map: &HashMap<String, TableSchemaEntry>,
803) -> TypeCheckContext {
804 fn add_table_to_context(
805 table: &TableRef,
806 schema_map: &HashMap<String, TableSchemaEntry>,
807 context: &mut TypeCheckContext,
808 ) {
809 let resolved_key = table_ref_candidates(table)
810 .into_iter()
811 .find(|k| schema_map.contains_key(k));
812
813 let Some(table_key) = resolved_key else {
814 return;
815 };
816
817 context.referenced_tables.insert(table_key.clone());
818 context
819 .table_aliases
820 .insert(lower(&table.name.name), table_key.clone());
821 if let Some(alias) = &table.alias {
822 context
823 .table_aliases
824 .insert(lower(&alias.name), table_key.clone());
825 }
826 }
827
828 let mut context = TypeCheckContext::default();
829 let cte_aliases = collect_cte_aliases(stmt);
830
831 for node in stmt.find_all(|e| matches!(e, Expression::Table(_))) {
832 let Expression::Table(table) = node else {
833 continue;
834 };
835
836 if cte_aliases.contains(&lower(&table.name.name)) {
837 continue;
838 }
839
840 add_table_to_context(table, schema_map, &mut context);
841 }
842
843 match stmt {
846 Expression::Insert(insert) => {
847 add_table_to_context(&insert.table, schema_map, &mut context);
848 }
849 Expression::Update(update) => {
850 add_table_to_context(&update.table, schema_map, &mut context);
851 for table in &update.extra_tables {
852 add_table_to_context(table, schema_map, &mut context);
853 }
854 }
855 Expression::Delete(delete) => {
856 add_table_to_context(&delete.table, schema_map, &mut context);
857 for table in &delete.using {
858 add_table_to_context(table, schema_map, &mut context);
859 }
860 for table in &delete.tables {
861 add_table_to_context(table, schema_map, &mut context);
862 }
863 }
864 _ => {}
865 }
866
867 context
868}
869
870fn resolve_table_schema_entry<'a>(
871 table: &TableRef,
872 schema_map: &'a HashMap<String, TableSchemaEntry>,
873) -> Option<(String, &'a TableSchemaEntry)> {
874 let key = table_ref_candidates(table)
875 .into_iter()
876 .find(|k| schema_map.contains_key(k))?;
877 let entry = schema_map.get(&key)?;
878 Some((key, entry))
879}
880
881fn reference_issue(strict: bool, message: impl Into<String>) -> ValidationError {
882 if strict {
883 ValidationError::error(
884 message.into(),
885 validation_codes::E_INVALID_FOREIGN_KEY_REFERENCE,
886 )
887 } else {
888 ValidationError::warning(message.into(), validation_codes::W_WEAK_REFERENCE_INTEGRITY)
889 }
890}
891
892fn reference_table_candidates(
893 table_name: &str,
894 explicit_schema: Option<&str>,
895 source_schema: Option<&str>,
896) -> Vec<String> {
897 let mut candidates = Vec::new();
898 let raw = lower(table_name);
899
900 if let Some(schema) = explicit_schema {
901 candidates.push(format!("{}.{}", lower(schema), raw));
902 }
903
904 if raw.contains('.') {
905 candidates.push(raw.clone());
906 if let Some(last) = raw.rsplit('.').next() {
907 candidates.push(last.to_string());
908 }
909 } else {
910 if let Some(schema) = source_schema {
911 candidates.push(format!("{}.{}", lower(schema), raw));
912 }
913 candidates.push(raw);
914 }
915
916 let mut dedup = HashSet::new();
917 candidates
918 .into_iter()
919 .filter(|c| dedup.insert(c.clone()))
920 .collect()
921}
922
923fn resolve_reference_table_key(
924 table_name: &str,
925 explicit_schema: Option<&str>,
926 source_schema: Option<&str>,
927 schema_map: &HashMap<String, TableSchemaEntry>,
928) -> Option<String> {
929 reference_table_candidates(table_name, explicit_schema, source_schema)
930 .into_iter()
931 .find(|candidate| schema_map.contains_key(candidate))
932}
933
934fn key_types_compatible(source: TypeFamily, target: TypeFamily) -> bool {
935 if source == TypeFamily::Unknown || target == TypeFamily::Unknown {
936 return true;
937 }
938 if source == target {
939 return true;
940 }
941 if source.is_numeric() && target.is_numeric() {
942 return true;
943 }
944 if source.is_temporal() && target.is_temporal() {
945 return true;
946 }
947 false
948}
949
950fn table_key_hints(table: &SchemaTable) -> HashSet<String> {
951 let mut hints = HashSet::new();
952 for column in &table.columns {
953 if column.primary_key || column.unique {
954 hints.insert(lower(&column.name));
955 }
956 }
957 for key_col in &table.primary_key {
958 hints.insert(lower(key_col));
959 }
960 for group in &table.unique_keys {
961 if group.len() == 1 {
962 if let Some(col) = group.first() {
963 hints.insert(lower(col));
964 }
965 }
966 }
967 hints
968}
969
970fn check_reference_integrity(
971 schema: &ValidationSchema,
972 schema_map: &HashMap<String, TableSchemaEntry>,
973 strict: bool,
974) -> Vec<ValidationError> {
975 let mut errors = Vec::new();
976
977 let mut key_hints_lookup: HashMap<String, HashSet<String>> = HashMap::new();
978 for table in &schema.tables {
979 let simple = lower(&table.name);
980 key_hints_lookup.insert(simple, table_key_hints(table));
981 if let Some(schema_name) = &table.schema {
982 let qualified = format!("{}.{}", lower(schema_name), lower(&table.name));
983 key_hints_lookup.insert(qualified, table_key_hints(table));
984 }
985 }
986
987 for table in &schema.tables {
988 let source_table_display = if let Some(schema_name) = &table.schema {
989 format!("{}.{}", schema_name, table.name)
990 } else {
991 table.name.clone()
992 };
993 let source_schema = table.schema.as_deref();
994 let source_columns: HashMap<String, TypeFamily> = table
995 .columns
996 .iter()
997 .map(|col| (lower(&col.name), canonical_type_family(&col.data_type)))
998 .collect();
999
1000 for source_col in &table.columns {
1001 let Some(reference) = &source_col.references else {
1002 continue;
1003 };
1004 let source_type = canonical_type_family(&source_col.data_type);
1005
1006 let Some(target_key) = resolve_reference_table_key(
1007 &reference.table,
1008 reference.schema.as_deref(),
1009 source_schema,
1010 schema_map,
1011 ) else {
1012 errors.push(reference_issue(
1013 strict,
1014 format!(
1015 "Foreign key reference '{}.{}' points to unknown table '{}'",
1016 source_table_display, source_col.name, reference.table
1017 ),
1018 ));
1019 continue;
1020 };
1021
1022 let target_column = lower(&reference.column);
1023 let Some(target_entry) = schema_map.get(&target_key) else {
1024 errors.push(reference_issue(
1025 strict,
1026 format!(
1027 "Foreign key reference '{}.{}' points to unknown table '{}'",
1028 source_table_display, source_col.name, reference.table
1029 ),
1030 ));
1031 continue;
1032 };
1033
1034 let Some(target_type) = target_entry.columns.get(&target_column).copied() else {
1035 errors.push(reference_issue(
1036 strict,
1037 format!(
1038 "Foreign key reference '{}.{}' points to unknown column '{}.{}'",
1039 source_table_display, source_col.name, target_key, reference.column
1040 ),
1041 ));
1042 continue;
1043 };
1044
1045 if !key_types_compatible(source_type, target_type) {
1046 errors.push(reference_issue(
1047 strict,
1048 format!(
1049 "Foreign key type mismatch for '{}.{}' -> '{}.{}': {} vs {}",
1050 source_table_display,
1051 source_col.name,
1052 target_key,
1053 reference.column,
1054 type_family_name(source_type),
1055 type_family_name(target_type)
1056 ),
1057 ));
1058 }
1059
1060 if let Some(target_key_hints) = key_hints_lookup.get(&target_key) {
1061 if !target_key_hints.contains(&target_column) {
1062 errors.push(ValidationError::warning(
1063 format!(
1064 "Referenced column '{}.{}' is not marked as primary/unique key",
1065 target_key, reference.column
1066 ),
1067 validation_codes::W_WEAK_REFERENCE_INTEGRITY,
1068 ));
1069 }
1070 }
1071 }
1072
1073 for foreign_key in &table.foreign_keys {
1074 if foreign_key.columns.is_empty() || foreign_key.references.columns.is_empty() {
1075 errors.push(reference_issue(
1076 strict,
1077 format!(
1078 "Table-level foreign key on '{}' has empty source or target column list",
1079 source_table_display
1080 ),
1081 ));
1082 continue;
1083 }
1084 if foreign_key.columns.len() != foreign_key.references.columns.len() {
1085 errors.push(reference_issue(
1086 strict,
1087 format!(
1088 "Table-level foreign key on '{}' has {} source columns but {} target columns",
1089 source_table_display,
1090 foreign_key.columns.len(),
1091 foreign_key.references.columns.len()
1092 ),
1093 ));
1094 continue;
1095 }
1096
1097 let Some(target_key) = resolve_reference_table_key(
1098 &foreign_key.references.table,
1099 foreign_key.references.schema.as_deref(),
1100 source_schema,
1101 schema_map,
1102 ) else {
1103 errors.push(reference_issue(
1104 strict,
1105 format!(
1106 "Table-level foreign key on '{}' points to unknown table '{}'",
1107 source_table_display, foreign_key.references.table
1108 ),
1109 ));
1110 continue;
1111 };
1112
1113 let Some(target_entry) = schema_map.get(&target_key) else {
1114 errors.push(reference_issue(
1115 strict,
1116 format!(
1117 "Table-level foreign key on '{}' points to unknown table '{}'",
1118 source_table_display, foreign_key.references.table
1119 ),
1120 ));
1121 continue;
1122 };
1123
1124 for (source_col, target_col) in foreign_key
1125 .columns
1126 .iter()
1127 .zip(foreign_key.references.columns.iter())
1128 {
1129 let source_col_name = lower(source_col);
1130 let target_col_name = lower(target_col);
1131
1132 let Some(source_type) = source_columns.get(&source_col_name).copied() else {
1133 errors.push(reference_issue(
1134 strict,
1135 format!(
1136 "Table-level foreign key on '{}' references unknown source column '{}'",
1137 source_table_display, source_col
1138 ),
1139 ));
1140 continue;
1141 };
1142
1143 let Some(target_type) = target_entry.columns.get(&target_col_name).copied() else {
1144 errors.push(reference_issue(
1145 strict,
1146 format!(
1147 "Table-level foreign key on '{}' references unknown target column '{}.{}'",
1148 source_table_display, target_key, target_col
1149 ),
1150 ));
1151 continue;
1152 };
1153
1154 if !key_types_compatible(source_type, target_type) {
1155 errors.push(reference_issue(
1156 strict,
1157 format!(
1158 "Table-level foreign key type mismatch '{}.{}' -> '{}.{}': {} vs {}",
1159 source_table_display,
1160 source_col,
1161 target_key,
1162 target_col,
1163 type_family_name(source_type),
1164 type_family_name(target_type)
1165 ),
1166 ));
1167 }
1168
1169 if let Some(target_key_hints) = key_hints_lookup.get(&target_key) {
1170 if !target_key_hints.contains(&target_col_name) {
1171 errors.push(ValidationError::warning(
1172 format!(
1173 "Referenced column '{}.{}' is not marked as primary/unique key",
1174 target_key, target_col
1175 ),
1176 validation_codes::W_WEAK_REFERENCE_INTEGRITY,
1177 ));
1178 }
1179 }
1180 }
1181 }
1182 }
1183
1184 errors
1185}
1186
1187fn resolve_unqualified_column_type(
1188 column_name: &str,
1189 schema_map: &HashMap<String, TableSchemaEntry>,
1190 context: &TypeCheckContext,
1191) -> TypeFamily {
1192 let candidate_tables: Vec<&String> = if !context.referenced_tables.is_empty() {
1193 context.referenced_tables.iter().collect()
1194 } else {
1195 schema_map.keys().collect()
1196 };
1197
1198 let mut families = HashSet::new();
1199 for table_name in candidate_tables {
1200 if let Some(table_schema) = schema_map.get(table_name) {
1201 if let Some(family) = table_schema.columns.get(column_name) {
1202 families.insert(*family);
1203 }
1204 }
1205 }
1206
1207 if families.len() == 1 {
1208 *families.iter().next().unwrap_or(&TypeFamily::Unknown)
1209 } else {
1210 TypeFamily::Unknown
1211 }
1212}
1213
1214fn resolve_column_type(
1215 column: &Column,
1216 schema_map: &HashMap<String, TableSchemaEntry>,
1217 context: &TypeCheckContext,
1218) -> TypeFamily {
1219 let column_name = lower(&column.name.name);
1220 if column_name.is_empty() {
1221 return TypeFamily::Unknown;
1222 }
1223
1224 if let Some(table) = &column.table {
1225 let mut table_key = lower(&table.name);
1226 if let Some(mapped) = context.table_aliases.get(&table_key) {
1227 table_key = mapped.clone();
1228 }
1229
1230 return schema_map
1231 .get(&table_key)
1232 .and_then(|t| t.columns.get(&column_name))
1233 .copied()
1234 .unwrap_or(TypeFamily::Unknown);
1235 }
1236
1237 resolve_unqualified_column_type(&column_name, schema_map, context)
1238}
1239
/// Read-only `SqlSchema` view over the validator's type-family schema map.
///
/// Passed to `annotate_types` by `infer_expression_type_family` so that column types can
/// be resolved (including through table aliases) during expression type inference.
struct TypeInferenceSchema<'a> {
    /// Table entries (column families and column order), keyed by table key;
    /// presumably lowercased keys, since lookups go through `lower` — TODO confirm.
    schema_map: &'a HashMap<String, TableSchemaEntry>,
    /// Per-query context supplying table aliases and the referenced-table list.
    context: &'a TypeCheckContext,
}
1244
1245impl TypeInferenceSchema<'_> {
1246 fn resolve_table_key(&self, table: &str) -> Option<String> {
1247 let mut table_key = lower(table);
1248 if let Some(mapped) = self.context.table_aliases.get(&table_key) {
1249 table_key = mapped.clone();
1250 }
1251 if self.schema_map.contains_key(&table_key) {
1252 Some(table_key)
1253 } else {
1254 None
1255 }
1256 }
1257}
1258
1259impl SqlSchema for TypeInferenceSchema<'_> {
1260 fn dialect(&self) -> Option<DialectType> {
1261 None
1262 }
1263
1264 fn add_table(
1265 &mut self,
1266 _table: &str,
1267 _columns: &[(String, DataType)],
1268 _dialect: Option<DialectType>,
1269 ) -> SchemaResult<()> {
1270 Err(SchemaError::InvalidStructure(
1271 "Type inference schema is read-only".to_string(),
1272 ))
1273 }
1274
1275 fn column_names(&self, table: &str) -> SchemaResult<Vec<String>> {
1276 let table_key = self
1277 .resolve_table_key(table)
1278 .ok_or_else(|| SchemaError::TableNotFound(table.to_string()))?;
1279 let entry = self
1280 .schema_map
1281 .get(&table_key)
1282 .ok_or_else(|| SchemaError::TableNotFound(table.to_string()))?;
1283 Ok(entry.column_order.clone())
1284 }
1285
1286 fn get_column_type(&self, table: &str, column: &str) -> SchemaResult<DataType> {
1287 let col_name = lower(column);
1288 if table.is_empty() {
1289 let family = resolve_unqualified_column_type(&col_name, self.schema_map, self.context);
1290 return if family == TypeFamily::Unknown {
1291 Err(SchemaError::ColumnNotFound {
1292 table: "<unqualified>".to_string(),
1293 column: column.to_string(),
1294 })
1295 } else {
1296 Ok(type_family_to_data_type(family))
1297 };
1298 }
1299
1300 let table_key = self
1301 .resolve_table_key(table)
1302 .ok_or_else(|| SchemaError::TableNotFound(table.to_string()))?;
1303 let entry = self
1304 .schema_map
1305 .get(&table_key)
1306 .ok_or_else(|| SchemaError::TableNotFound(table.to_string()))?;
1307 let family =
1308 entry
1309 .columns
1310 .get(&col_name)
1311 .copied()
1312 .ok_or_else(|| SchemaError::ColumnNotFound {
1313 table: table.to_string(),
1314 column: column.to_string(),
1315 })?;
1316 Ok(type_family_to_data_type(family))
1317 }
1318
1319 fn has_column(&self, table: &str, column: &str) -> bool {
1320 self.get_column_type(table, column).is_ok()
1321 }
1322
1323 fn supported_table_args(&self) -> &[&str] {
1324 TABLE_PARTS
1325 }
1326
1327 fn is_empty(&self) -> bool {
1328 self.schema_map.is_empty()
1329 }
1330
1331 fn depth(&self) -> usize {
1332 1
1333 }
1334
1335 fn find_tables_for_column(&self, column: &str) -> Vec<String> {
1336 let col_name = column.to_lowercase();
1337 self.schema_map
1338 .iter()
1339 .filter(|(_, entry)| {
1340 entry
1341 .column_order
1342 .iter()
1343 .any(|c| c.to_lowercase() == col_name)
1344 })
1345 .map(|(table, _)| table.clone())
1346 .collect()
1347 }
1348}
1349
1350fn infer_expression_type_family(
1351 expr: &Expression,
1352 schema_map: &HashMap<String, TableSchemaEntry>,
1353 context: &TypeCheckContext,
1354) -> TypeFamily {
1355 let inference_schema = TypeInferenceSchema {
1356 schema_map,
1357 context,
1358 };
1359 let mut expr_clone = expr.clone();
1360 annotate_types(&mut expr_clone, Some(&inference_schema), None);
1361 if let Some(data_type) = expr_clone.inferred_type() {
1362 let family = data_type_family(&data_type);
1363 if family != TypeFamily::Unknown {
1364 return family;
1365 }
1366 }
1367
1368 infer_expression_type_family_fallback(expr, schema_map, context)
1369}
1370
1371fn infer_expression_type_family_fallback(
1372 expr: &Expression,
1373 schema_map: &HashMap<String, TableSchemaEntry>,
1374 context: &TypeCheckContext,
1375) -> TypeFamily {
1376 match expr {
1377 Expression::Literal(literal) => match literal.as_ref() {
1378 crate::expressions::Literal::Number(value) => {
1379 if value.contains('.') || value.contains('e') || value.contains('E') {
1380 TypeFamily::Numeric
1381 } else {
1382 TypeFamily::Integer
1383 }
1384 }
1385 crate::expressions::Literal::HexNumber(_) => TypeFamily::Integer,
1386 crate::expressions::Literal::Date(_) => TypeFamily::Date,
1387 crate::expressions::Literal::Time(_) => TypeFamily::Time,
1388 crate::expressions::Literal::Timestamp(_)
1389 | crate::expressions::Literal::Datetime(_) => TypeFamily::Timestamp,
1390 crate::expressions::Literal::HexString(_)
1391 | crate::expressions::Literal::BitString(_)
1392 | crate::expressions::Literal::ByteString(_) => TypeFamily::Binary,
1393 _ => TypeFamily::String,
1394 },
1395 Expression::Boolean(_) => TypeFamily::Boolean,
1396 Expression::Null(_) => TypeFamily::Unknown,
1397 Expression::Column(column) => resolve_column_type(column, schema_map, context),
1398 Expression::Cast(cast) | Expression::TryCast(cast) | Expression::SafeCast(cast) => {
1399 data_type_family(&cast.to)
1400 }
1401 Expression::Alias(alias) => {
1402 infer_expression_type_family_fallback(&alias.this, schema_map, context)
1403 }
1404 Expression::Neg(unary) => {
1405 infer_expression_type_family_fallback(&unary.this, schema_map, context)
1406 }
1407 Expression::Add(op) | Expression::Sub(op) | Expression::Mul(op) => {
1408 let left = infer_expression_type_family_fallback(&op.left, schema_map, context);
1409 let right = infer_expression_type_family_fallback(&op.right, schema_map, context);
1410 if left == TypeFamily::Unknown || right == TypeFamily::Unknown {
1411 TypeFamily::Unknown
1412 } else if left == TypeFamily::Integer && right == TypeFamily::Integer {
1413 TypeFamily::Integer
1414 } else if left.is_numeric() && right.is_numeric() {
1415 TypeFamily::Numeric
1416 } else if left.is_temporal() || right.is_temporal() {
1417 left
1418 } else {
1419 TypeFamily::Unknown
1420 }
1421 }
1422 Expression::Div(_) | Expression::Mod(_) => TypeFamily::Numeric,
1423 Expression::Concat(_) => TypeFamily::String,
1424 Expression::Eq(_)
1425 | Expression::Neq(_)
1426 | Expression::Lt(_)
1427 | Expression::Lte(_)
1428 | Expression::Gt(_)
1429 | Expression::Gte(_)
1430 | Expression::Like(_)
1431 | Expression::ILike(_)
1432 | Expression::And(_)
1433 | Expression::Or(_)
1434 | Expression::Not(_)
1435 | Expression::Between(_)
1436 | Expression::In(_)
1437 | Expression::IsNull(_)
1438 | Expression::IsTrue(_)
1439 | Expression::IsFalse(_)
1440 | Expression::Is(_) => TypeFamily::Boolean,
1441 Expression::Length(_) => TypeFamily::Integer,
1442 Expression::Upper(_)
1443 | Expression::Lower(_)
1444 | Expression::Trim(_)
1445 | Expression::LTrim(_)
1446 | Expression::RTrim(_)
1447 | Expression::Replace(_)
1448 | Expression::Substring(_)
1449 | Expression::Left(_)
1450 | Expression::Right(_)
1451 | Expression::Repeat(_)
1452 | Expression::Lpad(_)
1453 | Expression::Rpad(_)
1454 | Expression::ConcatWs(_) => TypeFamily::String,
1455 Expression::Abs(_)
1456 | Expression::Round(_)
1457 | Expression::Floor(_)
1458 | Expression::Ceil(_)
1459 | Expression::Power(_)
1460 | Expression::Sqrt(_)
1461 | Expression::Cbrt(_)
1462 | Expression::Ln(_)
1463 | Expression::Log(_)
1464 | Expression::Exp(_) => TypeFamily::Numeric,
1465 Expression::DateAdd(_) | Expression::DateSub(_) | Expression::ToDate(_) => TypeFamily::Date,
1466 Expression::ToTimestamp(_) => TypeFamily::Timestamp,
1467 Expression::DateDiff(_) | Expression::Extract(_) => TypeFamily::Integer,
1468 Expression::CurrentDate(_) => TypeFamily::Date,
1469 Expression::CurrentTime(_) => TypeFamily::Time,
1470 Expression::CurrentTimestamp(_) | Expression::CurrentTimestampLTZ(_) => {
1471 TypeFamily::Timestamp
1472 }
1473 Expression::Interval(_) => TypeFamily::Interval,
1474 _ => TypeFamily::Unknown,
1475 }
1476}
1477
1478fn are_comparable(left: TypeFamily, right: TypeFamily) -> bool {
1479 if left == TypeFamily::Unknown || right == TypeFamily::Unknown {
1480 return true;
1481 }
1482 if left == right {
1483 return true;
1484 }
1485 if left.is_numeric() && right.is_numeric() {
1486 return true;
1487 }
1488 if left.is_temporal() && right.is_temporal() {
1489 return true;
1490 }
1491 false
1492}
1493
1494fn check_function_argument(
1495 errors: &mut Vec<ValidationError>,
1496 strict: bool,
1497 function_name: &str,
1498 arg_index: usize,
1499 family: TypeFamily,
1500 expected: &str,
1501 valid: bool,
1502) {
1503 if family == TypeFamily::Unknown || valid {
1504 return;
1505 }
1506
1507 errors.push(type_issue(
1508 strict,
1509 validation_codes::E_INVALID_FUNCTION_ARGUMENT_TYPE,
1510 validation_codes::W_FUNCTION_ARGUMENT_COERCION,
1511 format!(
1512 "Function '{}' argument {} expects {}, found {}",
1513 function_name,
1514 arg_index + 1,
1515 expected,
1516 type_family_name(family)
1517 ),
1518 ));
1519}
1520
1521fn function_dispatch_name(name: &str) -> String {
1522 let upper = name
1523 .rsplit('.')
1524 .next()
1525 .unwrap_or(name)
1526 .trim()
1527 .to_uppercase();
1528 lower(canonical_typed_function_name_upper(&upper))
1529}
1530
/// Returns the trimmed text after the last `.` in `name` (the whole trimmed name if no dot).
fn function_base_name(name: &str) -> &str {
    name.rsplit_once('.')
        .map_or(name, |(_, tail)| tail)
        .trim()
}
1534
1535fn check_generic_function(
1536 function: &Function,
1537 schema_map: &HashMap<String, TableSchemaEntry>,
1538 context: &TypeCheckContext,
1539 strict: bool,
1540 errors: &mut Vec<ValidationError>,
1541) {
1542 let name = function_dispatch_name(&function.name);
1543
1544 let arg_family = |index: usize| -> Option<TypeFamily> {
1545 function
1546 .args
1547 .get(index)
1548 .map(|arg| infer_expression_type_family(arg, schema_map, context))
1549 };
1550
1551 match name.as_str() {
1552 "abs" | "sqrt" | "cbrt" | "ln" | "exp" => {
1553 if let Some(family) = arg_family(0) {
1554 check_function_argument(
1555 errors,
1556 strict,
1557 &name,
1558 0,
1559 family,
1560 "a numeric argument",
1561 family.is_numeric(),
1562 );
1563 }
1564 }
1565 "round" | "floor" | "ceil" | "ceiling" => {
1566 if let Some(family) = arg_family(0) {
1567 check_function_argument(
1568 errors,
1569 strict,
1570 &name,
1571 0,
1572 family,
1573 "a numeric argument",
1574 family.is_numeric(),
1575 );
1576 }
1577 if let Some(family) = arg_family(1) {
1578 check_function_argument(
1579 errors,
1580 strict,
1581 &name,
1582 1,
1583 family,
1584 "a numeric argument",
1585 family.is_numeric(),
1586 );
1587 }
1588 }
1589 "power" | "pow" => {
1590 for i in [0_usize, 1_usize] {
1591 if let Some(family) = arg_family(i) {
1592 check_function_argument(
1593 errors,
1594 strict,
1595 &name,
1596 i,
1597 family,
1598 "a numeric argument",
1599 family.is_numeric(),
1600 );
1601 }
1602 }
1603 }
1604 "length" | "char_length" | "character_length" => {
1605 if let Some(family) = arg_family(0) {
1606 check_function_argument(
1607 errors,
1608 strict,
1609 &name,
1610 0,
1611 family,
1612 "a string or binary argument",
1613 is_string_or_binary(family),
1614 );
1615 }
1616 }
1617 "upper" | "lower" | "trim" | "ltrim" | "rtrim" | "reverse" => {
1618 if let Some(family) = arg_family(0) {
1619 check_function_argument(
1620 errors,
1621 strict,
1622 &name,
1623 0,
1624 family,
1625 "a string argument",
1626 is_string_like(family),
1627 );
1628 }
1629 }
1630 "substring" | "substr" => {
1631 if let Some(family) = arg_family(0) {
1632 check_function_argument(
1633 errors,
1634 strict,
1635 &name,
1636 0,
1637 family,
1638 "a string argument",
1639 is_string_like(family),
1640 );
1641 }
1642 if let Some(family) = arg_family(1) {
1643 check_function_argument(
1644 errors,
1645 strict,
1646 &name,
1647 1,
1648 family,
1649 "a numeric argument",
1650 family.is_numeric(),
1651 );
1652 }
1653 if let Some(family) = arg_family(2) {
1654 check_function_argument(
1655 errors,
1656 strict,
1657 &name,
1658 2,
1659 family,
1660 "a numeric argument",
1661 family.is_numeric(),
1662 );
1663 }
1664 }
1665 "replace" => {
1666 for i in [0_usize, 1_usize, 2_usize] {
1667 if let Some(family) = arg_family(i) {
1668 check_function_argument(
1669 errors,
1670 strict,
1671 &name,
1672 i,
1673 family,
1674 "a string argument",
1675 is_string_like(family),
1676 );
1677 }
1678 }
1679 }
1680 "left" | "right" | "repeat" | "lpad" | "rpad" => {
1681 if let Some(family) = arg_family(0) {
1682 check_function_argument(
1683 errors,
1684 strict,
1685 &name,
1686 0,
1687 family,
1688 "a string argument",
1689 is_string_like(family),
1690 );
1691 }
1692 if let Some(family) = arg_family(1) {
1693 check_function_argument(
1694 errors,
1695 strict,
1696 &name,
1697 1,
1698 family,
1699 "a numeric argument",
1700 family.is_numeric(),
1701 );
1702 }
1703 if (name == "lpad" || name == "rpad") && function.args.len() > 2 {
1704 if let Some(family) = arg_family(2) {
1705 check_function_argument(
1706 errors,
1707 strict,
1708 &name,
1709 2,
1710 family,
1711 "a string argument",
1712 is_string_like(family),
1713 );
1714 }
1715 }
1716 }
1717 _ => {}
1718 }
1719}
1720
1721fn check_function_catalog(
1722 function: &Function,
1723 dialect: DialectType,
1724 function_catalog: Option<&dyn FunctionCatalog>,
1725 strict: bool,
1726 errors: &mut Vec<ValidationError>,
1727) {
1728 let Some(catalog) = function_catalog else {
1729 return;
1730 };
1731
1732 let raw_name = function_base_name(&function.name);
1733 let normalized_name = function_dispatch_name(&function.name);
1734 let arity = function.args.len();
1735 let Some(signatures) = catalog.lookup(dialect, raw_name, &normalized_name) else {
1736 errors.push(if strict {
1737 ValidationError::error(
1738 format!(
1739 "Unknown function '{}' for dialect {:?}",
1740 function.name, dialect
1741 ),
1742 validation_codes::E_UNKNOWN_FUNCTION,
1743 )
1744 } else {
1745 ValidationError::warning(
1746 format!(
1747 "Unknown function '{}' for dialect {:?}",
1748 function.name, dialect
1749 ),
1750 validation_codes::E_UNKNOWN_FUNCTION,
1751 )
1752 });
1753 return;
1754 };
1755
1756 if signatures.iter().any(|sig| sig.matches_arity(arity)) {
1757 return;
1758 }
1759
1760 let expected = signatures
1761 .iter()
1762 .map(|sig| sig.describe_arity())
1763 .collect::<Vec<_>>()
1764 .join(", ");
1765 errors.push(if strict {
1766 ValidationError::error(
1767 format!(
1768 "Invalid arity for function '{}': got {}, expected {}",
1769 function.name, arity, expected
1770 ),
1771 validation_codes::E_INVALID_FUNCTION_ARITY,
1772 )
1773 } else {
1774 ValidationError::warning(
1775 format!(
1776 "Invalid arity for function '{}': got {}, expected {}",
1777 function.name, arity, expected
1778 ),
1779 validation_codes::E_INVALID_FUNCTION_ARITY,
1780 )
1781 });
1782}
1783
/// One declared column-to-column foreign-key link between two schema tables.
///
/// Produced by `build_declared_relationships` from column-level `references` and table-level
/// foreign keys. That builder stores resolved schema-map keys for the tables and `lower`-ed
/// column names.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
struct DeclaredRelationship {
    /// Schema-map key of the referencing table (the one declaring the key).
    source_table: String,
    /// Referencing column.
    source_column: String,
    /// Schema-map key of the referenced table.
    target_table: String,
    /// Referenced column.
    target_column: String,
}
1791
1792fn build_declared_relationships(
1793 schema: &ValidationSchema,
1794 schema_map: &HashMap<String, TableSchemaEntry>,
1795) -> Vec<DeclaredRelationship> {
1796 let mut relationships = HashSet::new();
1797
1798 for table in &schema.tables {
1799 let Some(source_key) =
1800 resolve_reference_table_key(&table.name, table.schema.as_deref(), None, schema_map)
1801 else {
1802 continue;
1803 };
1804
1805 for column in &table.columns {
1806 let Some(reference) = &column.references else {
1807 continue;
1808 };
1809 let Some(target_key) = resolve_reference_table_key(
1810 &reference.table,
1811 reference.schema.as_deref(),
1812 table.schema.as_deref(),
1813 schema_map,
1814 ) else {
1815 continue;
1816 };
1817
1818 relationships.insert(DeclaredRelationship {
1819 source_table: source_key.clone(),
1820 source_column: lower(&column.name),
1821 target_table: target_key,
1822 target_column: lower(&reference.column),
1823 });
1824 }
1825
1826 for foreign_key in &table.foreign_keys {
1827 if foreign_key.columns.len() != foreign_key.references.columns.len() {
1828 continue;
1829 }
1830 let Some(target_key) = resolve_reference_table_key(
1831 &foreign_key.references.table,
1832 foreign_key.references.schema.as_deref(),
1833 table.schema.as_deref(),
1834 schema_map,
1835 ) else {
1836 continue;
1837 };
1838
1839 for (source_col, target_col) in foreign_key
1840 .columns
1841 .iter()
1842 .zip(foreign_key.references.columns.iter())
1843 {
1844 relationships.insert(DeclaredRelationship {
1845 source_table: source_key.clone(),
1846 source_column: lower(source_col),
1847 target_table: target_key.clone(),
1848 target_column: lower(target_col),
1849 });
1850 }
1851 }
1852 }
1853
1854 relationships.into_iter().collect()
1855}
1856
1857fn resolve_column_binding(
1858 column: &Column,
1859 schema_map: &HashMap<String, TableSchemaEntry>,
1860 context: &TypeCheckContext,
1861 resolver: &mut Resolver<'_>,
1862) -> Option<(String, String)> {
1863 let column_name = lower(&column.name.name);
1864 if column_name.is_empty() {
1865 return None;
1866 }
1867
1868 if let Some(table) = &column.table {
1869 let mut table_key = lower(&table.name);
1870 if let Some(mapped) = context.table_aliases.get(&table_key) {
1871 table_key = mapped.clone();
1872 }
1873 if schema_map.contains_key(&table_key) {
1874 return Some((table_key, column_name));
1875 }
1876 return None;
1877 }
1878
1879 if let Some(resolved_source) = resolver.get_table(&column_name) {
1880 let mut table_key = lower(&resolved_source);
1881 if let Some(mapped) = context.table_aliases.get(&table_key) {
1882 table_key = mapped.clone();
1883 }
1884 if schema_map.contains_key(&table_key) {
1885 return Some((table_key, column_name));
1886 }
1887 }
1888
1889 let candidates: Vec<String> = context
1890 .referenced_tables
1891 .iter()
1892 .filter_map(|table_name| {
1893 schema_map
1894 .get(table_name)
1895 .filter(|entry| entry.columns.contains_key(&column_name))
1896 .map(|_| table_name.clone())
1897 })
1898 .collect();
1899 if candidates.len() == 1 {
1900 return Some((candidates[0].clone(), column_name));
1901 }
1902 None
1903}
1904
1905fn extract_join_equality_pairs(
1906 expr: &Expression,
1907 schema_map: &HashMap<String, TableSchemaEntry>,
1908 context: &TypeCheckContext,
1909 resolver: &mut Resolver<'_>,
1910 pairs: &mut Vec<((String, String), (String, String))>,
1911) {
1912 match expr {
1913 Expression::And(op) => {
1914 extract_join_equality_pairs(&op.left, schema_map, context, resolver, pairs);
1915 extract_join_equality_pairs(&op.right, schema_map, context, resolver, pairs);
1916 }
1917 Expression::Paren(paren) => {
1918 extract_join_equality_pairs(&paren.this, schema_map, context, resolver, pairs);
1919 }
1920 Expression::Eq(op) => {
1921 let (Expression::Column(left_col), Expression::Column(right_col)) =
1922 (&op.left, &op.right)
1923 else {
1924 return;
1925 };
1926 let Some(left) = resolve_column_binding(left_col, schema_map, context, resolver) else {
1927 return;
1928 };
1929 let Some(right) = resolve_column_binding(right_col, schema_map, context, resolver)
1930 else {
1931 return;
1932 };
1933 pairs.push((left, right));
1934 }
1935 _ => {}
1936 }
1937}
1938
1939fn relationship_matches_pair(
1940 relationship: &DeclaredRelationship,
1941 left_table: &str,
1942 left_column: &str,
1943 right_table: &str,
1944 right_column: &str,
1945) -> bool {
1946 (relationship.source_table == left_table
1947 && relationship.source_column == left_column
1948 && relationship.target_table == right_table
1949 && relationship.target_column == right_column)
1950 || (relationship.source_table == right_table
1951 && relationship.source_column == right_column
1952 && relationship.target_table == left_table
1953 && relationship.target_column == left_column)
1954}
1955
1956fn resolved_table_key_from_expr(
1957 expr: &Expression,
1958 schema_map: &HashMap<String, TableSchemaEntry>,
1959) -> Option<String> {
1960 match expr {
1961 Expression::Table(table) => resolve_table_schema_entry(table, schema_map).map(|(k, _)| k),
1962 Expression::Alias(alias) => resolved_table_key_from_expr(&alias.this, schema_map),
1963 _ => None,
1964 }
1965}
1966
1967fn select_from_table_keys(
1968 select: &crate::expressions::Select,
1969 schema_map: &HashMap<String, TableSchemaEntry>,
1970) -> HashSet<String> {
1971 let mut keys = HashSet::new();
1972 if let Some(from_clause) = &select.from {
1973 for expr in &from_clause.expressions {
1974 if let Some(key) = resolved_table_key_from_expr(expr, schema_map) {
1975 keys.insert(key);
1976 }
1977 }
1978 }
1979 keys
1980}
1981
1982fn is_natural_or_implied_join(kind: JoinKind) -> bool {
1983 matches!(
1984 kind,
1985 JoinKind::Natural
1986 | JoinKind::NaturalLeft
1987 | JoinKind::NaturalRight
1988 | JoinKind::NaturalFull
1989 | JoinKind::CrossApply
1990 | JoinKind::OuterApply
1991 | JoinKind::AsOf
1992 | JoinKind::AsOfLeft
1993 | JoinKind::AsOfRight
1994 | JoinKind::Lateral
1995 | JoinKind::LeftLateral
1996 )
1997}
1998
1999fn check_query_reference_quality(
2000 stmt: &Expression,
2001 schema_map: &HashMap<String, TableSchemaEntry>,
2002 resolver_schema: &MappingSchema,
2003 strict: bool,
2004 relationships: &[DeclaredRelationship],
2005) -> Vec<ValidationError> {
2006 let mut errors = Vec::new();
2007
2008 for node in stmt.dfs() {
2009 let Expression::Select(select) = node else {
2010 continue;
2011 };
2012
2013 let select_expr = Expression::Select(select.clone());
2014 let context = collect_type_check_context(&select_expr, schema_map);
2015 let scope = build_scope(&select_expr);
2016 let mut resolver = Resolver::new(&scope, resolver_schema, true);
2017
2018 if context.referenced_tables.len() > 1 {
2019 let using_columns: HashSet<String> = select
2020 .joins
2021 .iter()
2022 .flat_map(|join| join.using.iter().map(|id| lower(&id.name)))
2023 .collect();
2024
2025 let mut seen = HashSet::new();
2026 for column_expr in select_expr
2027 .find_all(|e| matches!(e, Expression::Column(col) if col.table.is_none()))
2028 {
2029 let Expression::Column(column) = column_expr else {
2030 continue;
2031 };
2032
2033 let col_name = lower(&column.name.name);
2034 if col_name.is_empty()
2035 || using_columns.contains(&col_name)
2036 || !seen.insert(col_name.clone())
2037 {
2038 continue;
2039 }
2040
2041 if resolver.is_ambiguous(&col_name) {
2042 let source_count = resolver.sources_for_column(&col_name).len();
2043 errors.push(if strict {
2044 ValidationError::error(
2045 format!(
2046 "Ambiguous unqualified column '{}' found in {} referenced tables",
2047 col_name, source_count
2048 ),
2049 validation_codes::E_AMBIGUOUS_COLUMN_REFERENCE,
2050 )
2051 } else {
2052 ValidationError::warning(
2053 format!(
2054 "Ambiguous unqualified column '{}' found in {} referenced tables",
2055 col_name, source_count
2056 ),
2057 validation_codes::W_WEAK_REFERENCE_INTEGRITY,
2058 )
2059 });
2060 }
2061 }
2062 }
2063
2064 let mut cumulative_left_tables = select_from_table_keys(select, schema_map);
2065
2066 for join in &select.joins {
2067 let right_table_key = resolved_table_key_from_expr(&join.this, schema_map);
2068 let has_explicit_condition = join.on.is_some() || !join.using.is_empty();
2069 let cartesian_like_kind = matches!(
2070 join.kind,
2071 JoinKind::Cross
2072 | JoinKind::Implicit
2073 | JoinKind::Array
2074 | JoinKind::LeftArray
2075 | JoinKind::Paste
2076 );
2077
2078 if right_table_key.is_some()
2079 && (cartesian_like_kind
2080 || (!has_explicit_condition && !is_natural_or_implied_join(join.kind)))
2081 {
2082 errors.push(ValidationError::warning(
2083 "Potential cartesian join: JOIN without ON/USING condition",
2084 validation_codes::W_CARTESIAN_JOIN,
2085 ));
2086 }
2087
2088 if let (Some(on_expr), Some(right_key)) = (&join.on, right_table_key.clone()) {
2089 if join.using.is_empty() {
2090 let mut eq_pairs = Vec::new();
2091 extract_join_equality_pairs(
2092 on_expr,
2093 schema_map,
2094 &context,
2095 &mut resolver,
2096 &mut eq_pairs,
2097 );
2098
2099 let relevant_relationships: Vec<&DeclaredRelationship> = relationships
2100 .iter()
2101 .filter(|rel| {
2102 cumulative_left_tables.contains(&rel.source_table)
2103 && rel.target_table == right_key
2104 || (cumulative_left_tables.contains(&rel.target_table)
2105 && rel.source_table == right_key)
2106 })
2107 .collect();
2108
2109 if !relevant_relationships.is_empty() {
2110 let uses_declared_fk = eq_pairs.iter().any(|((lt, lc), (rt, rc))| {
2111 relevant_relationships
2112 .iter()
2113 .any(|rel| relationship_matches_pair(rel, lt, lc, rt, rc))
2114 });
2115 if !uses_declared_fk {
2116 errors.push(ValidationError::warning(
2117 "JOIN predicate does not use declared foreign-key relationship columns",
2118 validation_codes::W_JOIN_NOT_USING_DECLARED_REFERENCE,
2119 ));
2120 }
2121 }
2122 }
2123 }
2124
2125 if let Some(right_key) = right_table_key {
2126 cumulative_left_tables.insert(right_key);
2127 }
2128 }
2129 }
2130
2131 errors
2132}
2133
2134fn are_setop_compatible(left: TypeFamily, right: TypeFamily) -> bool {
2135 if left == TypeFamily::Unknown || right == TypeFamily::Unknown {
2136 return true;
2137 }
2138 if left == right {
2139 return true;
2140 }
2141 if left.is_numeric() && right.is_numeric() {
2142 return true;
2143 }
2144 if left.is_temporal() && right.is_temporal() {
2145 return true;
2146 }
2147 false
2148}
2149
2150fn merged_setop_family(left: TypeFamily, right: TypeFamily) -> TypeFamily {
2151 if left == TypeFamily::Unknown {
2152 return right;
2153 }
2154 if right == TypeFamily::Unknown {
2155 return left;
2156 }
2157 if left == right {
2158 return left;
2159 }
2160 if left.is_numeric() && right.is_numeric() {
2161 if left == TypeFamily::Numeric || right == TypeFamily::Numeric {
2162 return TypeFamily::Numeric;
2163 }
2164 return TypeFamily::Integer;
2165 }
2166 if left.is_temporal() && right.is_temporal() {
2167 if left == TypeFamily::Timestamp || right == TypeFamily::Timestamp {
2168 return TypeFamily::Timestamp;
2169 }
2170 if left == TypeFamily::Date || right == TypeFamily::Date {
2171 return TypeFamily::Date;
2172 }
2173 return TypeFamily::Time;
2174 }
2175 TypeFamily::Unknown
2176}
2177
2178fn are_assignment_compatible(target: TypeFamily, source: TypeFamily) -> bool {
2179 if target == TypeFamily::Unknown || source == TypeFamily::Unknown {
2180 return true;
2181 }
2182 if target == source {
2183 return true;
2184 }
2185
2186 match target {
2187 TypeFamily::Boolean => source == TypeFamily::Boolean,
2188 TypeFamily::Integer | TypeFamily::Numeric => source.is_numeric(),
2189 TypeFamily::Date | TypeFamily::Time | TypeFamily::Timestamp | TypeFamily::Interval => {
2190 source.is_temporal()
2191 }
2192 TypeFamily::String => true,
2193 TypeFamily::Binary => matches!(source, TypeFamily::Binary | TypeFamily::String),
2194 TypeFamily::Json => matches!(source, TypeFamily::Json | TypeFamily::String),
2195 TypeFamily::Uuid => matches!(source, TypeFamily::Uuid | TypeFamily::String),
2196 TypeFamily::Array => source == TypeFamily::Array,
2197 TypeFamily::Map => source == TypeFamily::Map,
2198 TypeFamily::Struct => source == TypeFamily::Struct,
2199 TypeFamily::Unknown => true,
2200 }
2201}
2202
2203fn projection_families(
2204 query_expr: &Expression,
2205 schema_map: &HashMap<String, TableSchemaEntry>,
2206) -> Option<Vec<TypeFamily>> {
2207 match query_expr {
2208 Expression::Select(select) => {
2209 if select
2210 .expressions
2211 .iter()
2212 .any(|e| matches!(e, Expression::Star(_) | Expression::BracedWildcard(_)))
2213 {
2214 return None;
2215 }
2216 let select_expr = Expression::Select(select.clone());
2217 let context = collect_type_check_context(&select_expr, schema_map);
2218 Some(
2219 select
2220 .expressions
2221 .iter()
2222 .map(|e| infer_expression_type_family(e, schema_map, &context))
2223 .collect(),
2224 )
2225 }
2226 Expression::Subquery(subquery) => projection_families(&subquery.this, schema_map),
2227 Expression::Union(union) => {
2228 let left = projection_families(&union.left, schema_map)?;
2229 let right = projection_families(&union.right, schema_map)?;
2230 if left.len() != right.len() {
2231 return None;
2232 }
2233 Some(
2234 left.into_iter()
2235 .zip(right)
2236 .map(|(l, r)| merged_setop_family(l, r))
2237 .collect(),
2238 )
2239 }
2240 Expression::Intersect(intersect) => {
2241 let left = projection_families(&intersect.left, schema_map)?;
2242 let right = projection_families(&intersect.right, schema_map)?;
2243 if left.len() != right.len() {
2244 return None;
2245 }
2246 Some(
2247 left.into_iter()
2248 .zip(right)
2249 .map(|(l, r)| merged_setop_family(l, r))
2250 .collect(),
2251 )
2252 }
2253 Expression::Except(except) => {
2254 let left = projection_families(&except.left, schema_map)?;
2255 let right = projection_families(&except.right, schema_map)?;
2256 if left.len() != right.len() {
2257 return None;
2258 }
2259 Some(
2260 left.into_iter()
2261 .zip(right)
2262 .map(|(l, r)| merged_setop_family(l, r))
2263 .collect(),
2264 )
2265 }
2266 Expression::Values(values) => {
2267 let first_row = values.expressions.first()?;
2268 let context = TypeCheckContext::default();
2269 Some(
2270 first_row
2271 .expressions
2272 .iter()
2273 .map(|e| infer_expression_type_family(e, schema_map, &context))
2274 .collect(),
2275 )
2276 }
2277 _ => None,
2278 }
2279}
2280
2281fn check_set_operation_compatibility(
2282 op_name: &str,
2283 left_expr: &Expression,
2284 right_expr: &Expression,
2285 schema_map: &HashMap<String, TableSchemaEntry>,
2286 strict: bool,
2287 errors: &mut Vec<ValidationError>,
2288) {
2289 let Some(left_projection) = projection_families(left_expr, schema_map) else {
2290 return;
2291 };
2292 let Some(right_projection) = projection_families(right_expr, schema_map) else {
2293 return;
2294 };
2295
2296 if left_projection.len() != right_projection.len() {
2297 errors.push(type_issue(
2298 strict,
2299 validation_codes::E_SETOP_ARITY_MISMATCH,
2300 validation_codes::W_SETOP_IMPLICIT_COERCION,
2301 format!(
2302 "{} operands return different column counts: left {}, right {}",
2303 op_name,
2304 left_projection.len(),
2305 right_projection.len()
2306 ),
2307 ));
2308 return;
2309 }
2310
2311 for (idx, (left, right)) in left_projection
2312 .into_iter()
2313 .zip(right_projection)
2314 .enumerate()
2315 {
2316 if !are_setop_compatible(left, right) {
2317 errors.push(type_issue(
2318 strict,
2319 validation_codes::E_SETOP_TYPE_MISMATCH,
2320 validation_codes::W_SETOP_IMPLICIT_COERCION,
2321 format!(
2322 "{} column {} has incompatible types: {} vs {}",
2323 op_name,
2324 idx + 1,
2325 type_family_name(left),
2326 type_family_name(right)
2327 ),
2328 ));
2329 }
2330 }
2331}
2332
2333fn check_insert_assignments(
2334 stmt: &Expression,
2335 insert: &Insert,
2336 schema_map: &HashMap<String, TableSchemaEntry>,
2337 strict: bool,
2338 errors: &mut Vec<ValidationError>,
2339) {
2340 let Some((target_table_key, table_schema)) =
2341 resolve_table_schema_entry(&insert.table, schema_map)
2342 else {
2343 return;
2344 };
2345
2346 let mut target_columns = Vec::new();
2347 if insert.columns.is_empty() {
2348 target_columns.extend(table_schema.column_order.iter().cloned());
2349 } else {
2350 for column in &insert.columns {
2351 let col_name = lower(&column.name);
2352 if table_schema.columns.contains_key(&col_name) {
2353 target_columns.push(col_name);
2354 } else {
2355 errors.push(if strict {
2356 ValidationError::error(
2357 format!(
2358 "Unknown column '{}' in table '{}'",
2359 column.name, target_table_key
2360 ),
2361 validation_codes::E_UNKNOWN_COLUMN,
2362 )
2363 } else {
2364 ValidationError::warning(
2365 format!(
2366 "Unknown column '{}' in table '{}'",
2367 column.name, target_table_key
2368 ),
2369 validation_codes::E_UNKNOWN_COLUMN,
2370 )
2371 });
2372 }
2373 }
2374 }
2375
2376 if target_columns.is_empty() {
2377 return;
2378 }
2379
2380 let context = collect_type_check_context(stmt, schema_map);
2381
2382 if !insert.default_values {
2383 for (row_idx, row) in insert.values.iter().enumerate() {
2384 if row.len() != target_columns.len() {
2385 errors.push(type_issue(
2386 strict,
2387 validation_codes::E_INVALID_ASSIGNMENT_TYPE,
2388 validation_codes::W_IMPLICIT_CAST_ASSIGNMENT,
2389 format!(
2390 "INSERT row {} has {} values but target has {} columns",
2391 row_idx + 1,
2392 row.len(),
2393 target_columns.len()
2394 ),
2395 ));
2396 continue;
2397 }
2398
2399 for (value, target_column) in row.iter().zip(target_columns.iter()) {
2400 let Some(target_family) = table_schema.columns.get(target_column).copied() else {
2401 continue;
2402 };
2403 let source_family = infer_expression_type_family(value, schema_map, &context);
2404 if !are_assignment_compatible(target_family, source_family) {
2405 errors.push(type_issue(
2406 strict,
2407 validation_codes::E_INVALID_ASSIGNMENT_TYPE,
2408 validation_codes::W_IMPLICIT_CAST_ASSIGNMENT,
2409 format!(
2410 "INSERT assignment type mismatch for '{}.{}': expected {}, found {}",
2411 target_table_key,
2412 target_column,
2413 type_family_name(target_family),
2414 type_family_name(source_family)
2415 ),
2416 ));
2417 }
2418 }
2419 }
2420 }
2421
2422 if let Some(query) = &insert.query {
2423 if insert.by_name {
2425 return;
2426 }
2427
2428 let Some(source_projection) = projection_families(query, schema_map) else {
2429 return;
2430 };
2431
2432 if source_projection.len() != target_columns.len() {
2433 errors.push(type_issue(
2434 strict,
2435 validation_codes::E_INVALID_ASSIGNMENT_TYPE,
2436 validation_codes::W_IMPLICIT_CAST_ASSIGNMENT,
2437 format!(
2438 "INSERT source query has {} columns but target has {} columns",
2439 source_projection.len(),
2440 target_columns.len()
2441 ),
2442 ));
2443 return;
2444 }
2445
2446 for (source_family, target_column) in
2447 source_projection.into_iter().zip(target_columns.iter())
2448 {
2449 let Some(target_family) = table_schema.columns.get(target_column).copied() else {
2450 continue;
2451 };
2452 if !are_assignment_compatible(target_family, source_family) {
2453 errors.push(type_issue(
2454 strict,
2455 validation_codes::E_INVALID_ASSIGNMENT_TYPE,
2456 validation_codes::W_IMPLICIT_CAST_ASSIGNMENT,
2457 format!(
2458 "INSERT assignment type mismatch for '{}.{}': expected {}, found {}",
2459 target_table_key,
2460 target_column,
2461 type_family_name(target_family),
2462 type_family_name(source_family)
2463 ),
2464 ));
2465 }
2466 }
2467 }
2468}
2469
2470fn check_update_assignments(
2471 stmt: &Expression,
2472 update: &Update,
2473 schema_map: &HashMap<String, TableSchemaEntry>,
2474 strict: bool,
2475 errors: &mut Vec<ValidationError>,
2476) {
2477 let Some((target_table_key, table_schema)) =
2478 resolve_table_schema_entry(&update.table, schema_map)
2479 else {
2480 return;
2481 };
2482
2483 let context = collect_type_check_context(stmt, schema_map);
2484
2485 for (column, value) in &update.set {
2486 let col_name = lower(&column.name);
2487 let Some(target_family) = table_schema.columns.get(&col_name).copied() else {
2488 errors.push(if strict {
2489 ValidationError::error(
2490 format!(
2491 "Unknown column '{}' in table '{}'",
2492 column.name, target_table_key
2493 ),
2494 validation_codes::E_UNKNOWN_COLUMN,
2495 )
2496 } else {
2497 ValidationError::warning(
2498 format!(
2499 "Unknown column '{}' in table '{}'",
2500 column.name, target_table_key
2501 ),
2502 validation_codes::E_UNKNOWN_COLUMN,
2503 )
2504 });
2505 continue;
2506 };
2507
2508 let source_family = infer_expression_type_family(value, schema_map, &context);
2509 if !are_assignment_compatible(target_family, source_family) {
2510 errors.push(type_issue(
2511 strict,
2512 validation_codes::E_INVALID_ASSIGNMENT_TYPE,
2513 validation_codes::W_IMPLICIT_CAST_ASSIGNMENT,
2514 format!(
2515 "UPDATE assignment type mismatch for '{}.{}': expected {}, found {}",
2516 target_table_key,
2517 col_name,
2518 type_family_name(target_family),
2519 type_family_name(source_family)
2520 ),
2521 ));
2522 }
2523 }
2524}
2525
2526fn check_types(
2527 stmt: &Expression,
2528 dialect: DialectType,
2529 schema_map: &HashMap<String, TableSchemaEntry>,
2530 function_catalog: Option<&dyn FunctionCatalog>,
2531 strict: bool,
2532) -> Vec<ValidationError> {
2533 let mut errors = Vec::new();
2534 let context = collect_type_check_context(stmt, schema_map);
2535
2536 for node in stmt.dfs() {
2537 match node {
2538 Expression::Insert(insert) => {
2539 check_insert_assignments(stmt, insert, schema_map, strict, &mut errors);
2540 }
2541 Expression::Update(update) => {
2542 check_update_assignments(stmt, update, schema_map, strict, &mut errors);
2543 }
2544 Expression::Union(union) => {
2545 check_set_operation_compatibility(
2546 "UNION",
2547 &union.left,
2548 &union.right,
2549 schema_map,
2550 strict,
2551 &mut errors,
2552 );
2553 }
2554 Expression::Intersect(intersect) => {
2555 check_set_operation_compatibility(
2556 "INTERSECT",
2557 &intersect.left,
2558 &intersect.right,
2559 schema_map,
2560 strict,
2561 &mut errors,
2562 );
2563 }
2564 Expression::Except(except) => {
2565 check_set_operation_compatibility(
2566 "EXCEPT",
2567 &except.left,
2568 &except.right,
2569 schema_map,
2570 strict,
2571 &mut errors,
2572 );
2573 }
2574 Expression::Select(select) => {
2575 if let Some(prewhere) = &select.prewhere {
2576 let family = infer_expression_type_family(prewhere, schema_map, &context);
2577 if family != TypeFamily::Unknown && family != TypeFamily::Boolean {
2578 errors.push(type_issue(
2579 strict,
2580 validation_codes::E_INVALID_PREDICATE_TYPE,
2581 validation_codes::W_PREDICATE_NULLABILITY,
2582 format!(
2583 "PREWHERE clause expects a boolean predicate, found {}",
2584 type_family_name(family)
2585 ),
2586 ));
2587 }
2588 }
2589
2590 if let Some(where_clause) = &select.where_clause {
2591 let family =
2592 infer_expression_type_family(&where_clause.this, schema_map, &context);
2593 if family != TypeFamily::Unknown && family != TypeFamily::Boolean {
2594 errors.push(type_issue(
2595 strict,
2596 validation_codes::E_INVALID_PREDICATE_TYPE,
2597 validation_codes::W_PREDICATE_NULLABILITY,
2598 format!(
2599 "WHERE clause expects a boolean predicate, found {}",
2600 type_family_name(family)
2601 ),
2602 ));
2603 }
2604 }
2605
2606 if let Some(having_clause) = &select.having {
2607 let family =
2608 infer_expression_type_family(&having_clause.this, schema_map, &context);
2609 if family != TypeFamily::Unknown && family != TypeFamily::Boolean {
2610 errors.push(type_issue(
2611 strict,
2612 validation_codes::E_INVALID_PREDICATE_TYPE,
2613 validation_codes::W_PREDICATE_NULLABILITY,
2614 format!(
2615 "HAVING clause expects a boolean predicate, found {}",
2616 type_family_name(family)
2617 ),
2618 ));
2619 }
2620 }
2621
2622 for join in &select.joins {
2623 if let Some(on) = &join.on {
2624 let family = infer_expression_type_family(on, schema_map, &context);
2625 if family != TypeFamily::Unknown && family != TypeFamily::Boolean {
2626 errors.push(type_issue(
2627 strict,
2628 validation_codes::E_INVALID_PREDICATE_TYPE,
2629 validation_codes::W_PREDICATE_NULLABILITY,
2630 format!(
2631 "JOIN ON expects a boolean predicate, found {}",
2632 type_family_name(family)
2633 ),
2634 ));
2635 }
2636 }
2637 if let Some(match_condition) = &join.match_condition {
2638 let family =
2639 infer_expression_type_family(match_condition, schema_map, &context);
2640 if family != TypeFamily::Unknown && family != TypeFamily::Boolean {
2641 errors.push(type_issue(
2642 strict,
2643 validation_codes::E_INVALID_PREDICATE_TYPE,
2644 validation_codes::W_PREDICATE_NULLABILITY,
2645 format!(
2646 "JOIN MATCH_CONDITION expects a boolean predicate, found {}",
2647 type_family_name(family)
2648 ),
2649 ));
2650 }
2651 }
2652 }
2653 }
2654 Expression::Where(where_clause) => {
2655 let family = infer_expression_type_family(&where_clause.this, schema_map, &context);
2656 if family != TypeFamily::Unknown && family != TypeFamily::Boolean {
2657 errors.push(type_issue(
2658 strict,
2659 validation_codes::E_INVALID_PREDICATE_TYPE,
2660 validation_codes::W_PREDICATE_NULLABILITY,
2661 format!(
2662 "WHERE clause expects a boolean predicate, found {}",
2663 type_family_name(family)
2664 ),
2665 ));
2666 }
2667 }
2668 Expression::Having(having_clause) => {
2669 let family =
2670 infer_expression_type_family(&having_clause.this, schema_map, &context);
2671 if family != TypeFamily::Unknown && family != TypeFamily::Boolean {
2672 errors.push(type_issue(
2673 strict,
2674 validation_codes::E_INVALID_PREDICATE_TYPE,
2675 validation_codes::W_PREDICATE_NULLABILITY,
2676 format!(
2677 "HAVING clause expects a boolean predicate, found {}",
2678 type_family_name(family)
2679 ),
2680 ));
2681 }
2682 }
2683 Expression::And(op) | Expression::Or(op) => {
2684 for (side, expr) in [("left", &op.left), ("right", &op.right)] {
2685 let family = infer_expression_type_family(expr, schema_map, &context);
2686 if family != TypeFamily::Unknown && family != TypeFamily::Boolean {
2687 errors.push(type_issue(
2688 strict,
2689 validation_codes::E_INVALID_PREDICATE_TYPE,
2690 validation_codes::W_PREDICATE_NULLABILITY,
2691 format!(
2692 "Logical {} operand expects boolean, found {}",
2693 side,
2694 type_family_name(family)
2695 ),
2696 ));
2697 }
2698 }
2699 }
2700 Expression::Not(unary) => {
2701 let family = infer_expression_type_family(&unary.this, schema_map, &context);
2702 if family != TypeFamily::Unknown && family != TypeFamily::Boolean {
2703 errors.push(type_issue(
2704 strict,
2705 validation_codes::E_INVALID_PREDICATE_TYPE,
2706 validation_codes::W_PREDICATE_NULLABILITY,
2707 format!("NOT expects boolean, found {}", type_family_name(family)),
2708 ));
2709 }
2710 }
2711 Expression::Eq(op)
2712 | Expression::Neq(op)
2713 | Expression::Lt(op)
2714 | Expression::Lte(op)
2715 | Expression::Gt(op)
2716 | Expression::Gte(op) => {
2717 let left = infer_expression_type_family(&op.left, schema_map, &context);
2718 let right = infer_expression_type_family(&op.right, schema_map, &context);
2719 if !are_comparable(left, right) {
2720 errors.push(type_issue(
2721 strict,
2722 validation_codes::E_INCOMPATIBLE_COMPARISON_TYPES,
2723 validation_codes::W_IMPLICIT_CAST_COMPARISON,
2724 format!(
2725 "Incompatible comparison between {} and {}",
2726 type_family_name(left),
2727 type_family_name(right)
2728 ),
2729 ));
2730 }
2731 }
2732 Expression::Like(op) | Expression::ILike(op) => {
2733 let left = infer_expression_type_family(&op.left, schema_map, &context);
2734 let right = infer_expression_type_family(&op.right, schema_map, &context);
2735 if left != TypeFamily::Unknown
2736 && right != TypeFamily::Unknown
2737 && (!is_string_like(left) || !is_string_like(right))
2738 {
2739 errors.push(type_issue(
2740 strict,
2741 validation_codes::E_INCOMPATIBLE_COMPARISON_TYPES,
2742 validation_codes::W_IMPLICIT_CAST_COMPARISON,
2743 format!(
2744 "LIKE/ILIKE expects string operands, found {} and {}",
2745 type_family_name(left),
2746 type_family_name(right)
2747 ),
2748 ));
2749 }
2750 }
2751 Expression::Between(between) => {
2752 let this_family = infer_expression_type_family(&between.this, schema_map, &context);
2753 let low_family = infer_expression_type_family(&between.low, schema_map, &context);
2754 let high_family = infer_expression_type_family(&between.high, schema_map, &context);
2755
2756 if !are_comparable(this_family, low_family)
2757 || !are_comparable(this_family, high_family)
2758 {
2759 errors.push(type_issue(
2760 strict,
2761 validation_codes::E_INCOMPATIBLE_COMPARISON_TYPES,
2762 validation_codes::W_IMPLICIT_CAST_COMPARISON,
2763 format!(
2764 "BETWEEN bounds are incompatible with {} (found {} and {})",
2765 type_family_name(this_family),
2766 type_family_name(low_family),
2767 type_family_name(high_family)
2768 ),
2769 ));
2770 }
2771 }
2772 Expression::In(in_expr) => {
2773 let this_family = infer_expression_type_family(&in_expr.this, schema_map, &context);
2774 for value in &in_expr.expressions {
2775 let item_family = infer_expression_type_family(value, schema_map, &context);
2776 if !are_comparable(this_family, item_family) {
2777 errors.push(type_issue(
2778 strict,
2779 validation_codes::E_INCOMPATIBLE_COMPARISON_TYPES,
2780 validation_codes::W_IMPLICIT_CAST_COMPARISON,
2781 format!(
2782 "IN item type {} is incompatible with {}",
2783 type_family_name(item_family),
2784 type_family_name(this_family)
2785 ),
2786 ));
2787 break;
2788 }
2789 }
2790 }
2791 Expression::Add(op)
2792 | Expression::Sub(op)
2793 | Expression::Mul(op)
2794 | Expression::Div(op)
2795 | Expression::Mod(op) => {
2796 let left = infer_expression_type_family(&op.left, schema_map, &context);
2797 let right = infer_expression_type_family(&op.right, schema_map, &context);
2798
2799 if left == TypeFamily::Unknown || right == TypeFamily::Unknown {
2800 continue;
2801 }
2802
2803 let temporal_ok = matches!(node, Expression::Add(_) | Expression::Sub(_))
2804 && ((left.is_temporal() && right.is_numeric())
2805 || (right.is_temporal() && left.is_numeric())
2806 || (matches!(node, Expression::Sub(_))
2807 && left.is_temporal()
2808 && right.is_temporal()));
2809
2810 if !(left.is_numeric() && right.is_numeric()) && !temporal_ok {
2811 errors.push(type_issue(
2812 strict,
2813 validation_codes::E_INVALID_ARITHMETIC_TYPE,
2814 validation_codes::W_IMPLICIT_CAST_ARITHMETIC,
2815 format!(
2816 "Arithmetic operation expects numeric-compatible operands, found {} and {}",
2817 type_family_name(left),
2818 type_family_name(right)
2819 ),
2820 ));
2821 }
2822 }
2823 Expression::Function(function) => {
2824 check_function_catalog(function, dialect, function_catalog, strict, &mut errors);
2825 check_generic_function(function, schema_map, &context, strict, &mut errors);
2826 }
2827 Expression::Upper(func)
2828 | Expression::Lower(func)
2829 | Expression::LTrim(func)
2830 | Expression::RTrim(func)
2831 | Expression::Reverse(func) => {
2832 let family = infer_expression_type_family(&func.this, schema_map, &context);
2833 check_function_argument(
2834 &mut errors,
2835 strict,
2836 "string_function",
2837 0,
2838 family,
2839 "a string argument",
2840 is_string_like(family),
2841 );
2842 }
2843 Expression::Length(func) => {
2844 let family = infer_expression_type_family(&func.this, schema_map, &context);
2845 check_function_argument(
2846 &mut errors,
2847 strict,
2848 "length",
2849 0,
2850 family,
2851 "a string or binary argument",
2852 is_string_or_binary(family),
2853 );
2854 }
2855 Expression::Trim(func) => {
2856 let this_family = infer_expression_type_family(&func.this, schema_map, &context);
2857 check_function_argument(
2858 &mut errors,
2859 strict,
2860 "trim",
2861 0,
2862 this_family,
2863 "a string argument",
2864 is_string_like(this_family),
2865 );
2866 if let Some(chars) = &func.characters {
2867 let chars_family = infer_expression_type_family(chars, schema_map, &context);
2868 check_function_argument(
2869 &mut errors,
2870 strict,
2871 "trim",
2872 1,
2873 chars_family,
2874 "a string argument",
2875 is_string_like(chars_family),
2876 );
2877 }
2878 }
2879 Expression::Substring(func) => {
2880 let this_family = infer_expression_type_family(&func.this, schema_map, &context);
2881 check_function_argument(
2882 &mut errors,
2883 strict,
2884 "substring",
2885 0,
2886 this_family,
2887 "a string argument",
2888 is_string_like(this_family),
2889 );
2890
2891 let start_family = infer_expression_type_family(&func.start, schema_map, &context);
2892 check_function_argument(
2893 &mut errors,
2894 strict,
2895 "substring",
2896 1,
2897 start_family,
2898 "a numeric argument",
2899 start_family.is_numeric(),
2900 );
2901 if let Some(length) = &func.length {
2902 let length_family = infer_expression_type_family(length, schema_map, &context);
2903 check_function_argument(
2904 &mut errors,
2905 strict,
2906 "substring",
2907 2,
2908 length_family,
2909 "a numeric argument",
2910 length_family.is_numeric(),
2911 );
2912 }
2913 }
2914 Expression::Replace(func) => {
2915 for (arg, idx) in [
2916 (&func.this, 0_usize),
2917 (&func.old, 1_usize),
2918 (&func.new, 2_usize),
2919 ] {
2920 let family = infer_expression_type_family(arg, schema_map, &context);
2921 check_function_argument(
2922 &mut errors,
2923 strict,
2924 "replace",
2925 idx,
2926 family,
2927 "a string argument",
2928 is_string_like(family),
2929 );
2930 }
2931 }
2932 Expression::Left(func) | Expression::Right(func) => {
2933 let this_family = infer_expression_type_family(&func.this, schema_map, &context);
2934 check_function_argument(
2935 &mut errors,
2936 strict,
2937 "left_right",
2938 0,
2939 this_family,
2940 "a string argument",
2941 is_string_like(this_family),
2942 );
2943 let length_family =
2944 infer_expression_type_family(&func.length, schema_map, &context);
2945 check_function_argument(
2946 &mut errors,
2947 strict,
2948 "left_right",
2949 1,
2950 length_family,
2951 "a numeric argument",
2952 length_family.is_numeric(),
2953 );
2954 }
2955 Expression::Repeat(func) => {
2956 let this_family = infer_expression_type_family(&func.this, schema_map, &context);
2957 check_function_argument(
2958 &mut errors,
2959 strict,
2960 "repeat",
2961 0,
2962 this_family,
2963 "a string argument",
2964 is_string_like(this_family),
2965 );
2966 let times_family = infer_expression_type_family(&func.times, schema_map, &context);
2967 check_function_argument(
2968 &mut errors,
2969 strict,
2970 "repeat",
2971 1,
2972 times_family,
2973 "a numeric argument",
2974 times_family.is_numeric(),
2975 );
2976 }
2977 Expression::Lpad(func) | Expression::Rpad(func) => {
2978 let this_family = infer_expression_type_family(&func.this, schema_map, &context);
2979 check_function_argument(
2980 &mut errors,
2981 strict,
2982 "pad",
2983 0,
2984 this_family,
2985 "a string argument",
2986 is_string_like(this_family),
2987 );
2988 let length_family =
2989 infer_expression_type_family(&func.length, schema_map, &context);
2990 check_function_argument(
2991 &mut errors,
2992 strict,
2993 "pad",
2994 1,
2995 length_family,
2996 "a numeric argument",
2997 length_family.is_numeric(),
2998 );
2999 if let Some(fill) = &func.fill {
3000 let fill_family = infer_expression_type_family(fill, schema_map, &context);
3001 check_function_argument(
3002 &mut errors,
3003 strict,
3004 "pad",
3005 2,
3006 fill_family,
3007 "a string argument",
3008 is_string_like(fill_family),
3009 );
3010 }
3011 }
3012 Expression::Abs(func)
3013 | Expression::Sqrt(func)
3014 | Expression::Cbrt(func)
3015 | Expression::Ln(func)
3016 | Expression::Exp(func) => {
3017 let family = infer_expression_type_family(&func.this, schema_map, &context);
3018 check_function_argument(
3019 &mut errors,
3020 strict,
3021 "numeric_function",
3022 0,
3023 family,
3024 "a numeric argument",
3025 family.is_numeric(),
3026 );
3027 }
3028 Expression::Round(func) => {
3029 let this_family = infer_expression_type_family(&func.this, schema_map, &context);
3030 check_function_argument(
3031 &mut errors,
3032 strict,
3033 "round",
3034 0,
3035 this_family,
3036 "a numeric argument",
3037 this_family.is_numeric(),
3038 );
3039 if let Some(decimals) = &func.decimals {
3040 let decimals_family =
3041 infer_expression_type_family(decimals, schema_map, &context);
3042 check_function_argument(
3043 &mut errors,
3044 strict,
3045 "round",
3046 1,
3047 decimals_family,
3048 "a numeric argument",
3049 decimals_family.is_numeric(),
3050 );
3051 }
3052 }
3053 Expression::Floor(func) => {
3054 let this_family = infer_expression_type_family(&func.this, schema_map, &context);
3055 check_function_argument(
3056 &mut errors,
3057 strict,
3058 "floor",
3059 0,
3060 this_family,
3061 "a numeric argument",
3062 this_family.is_numeric(),
3063 );
3064 if let Some(scale) = &func.scale {
3065 let scale_family = infer_expression_type_family(scale, schema_map, &context);
3066 check_function_argument(
3067 &mut errors,
3068 strict,
3069 "floor",
3070 1,
3071 scale_family,
3072 "a numeric argument",
3073 scale_family.is_numeric(),
3074 );
3075 }
3076 }
3077 Expression::Ceil(func) => {
3078 let this_family = infer_expression_type_family(&func.this, schema_map, &context);
3079 check_function_argument(
3080 &mut errors,
3081 strict,
3082 "ceil",
3083 0,
3084 this_family,
3085 "a numeric argument",
3086 this_family.is_numeric(),
3087 );
3088 if let Some(decimals) = &func.decimals {
3089 let decimals_family =
3090 infer_expression_type_family(decimals, schema_map, &context);
3091 check_function_argument(
3092 &mut errors,
3093 strict,
3094 "ceil",
3095 1,
3096 decimals_family,
3097 "a numeric argument",
3098 decimals_family.is_numeric(),
3099 );
3100 }
3101 }
3102 Expression::Power(func) => {
3103 let left_family = infer_expression_type_family(&func.this, schema_map, &context);
3104 check_function_argument(
3105 &mut errors,
3106 strict,
3107 "power",
3108 0,
3109 left_family,
3110 "a numeric argument",
3111 left_family.is_numeric(),
3112 );
3113 let right_family =
3114 infer_expression_type_family(&func.expression, schema_map, &context);
3115 check_function_argument(
3116 &mut errors,
3117 strict,
3118 "power",
3119 1,
3120 right_family,
3121 "a numeric argument",
3122 right_family.is_numeric(),
3123 );
3124 }
3125 Expression::Log(func) => {
3126 let this_family = infer_expression_type_family(&func.this, schema_map, &context);
3127 check_function_argument(
3128 &mut errors,
3129 strict,
3130 "log",
3131 0,
3132 this_family,
3133 "a numeric argument",
3134 this_family.is_numeric(),
3135 );
3136 if let Some(base) = &func.base {
3137 let base_family = infer_expression_type_family(base, schema_map, &context);
3138 check_function_argument(
3139 &mut errors,
3140 strict,
3141 "log",
3142 1,
3143 base_family,
3144 "a numeric argument",
3145 base_family.is_numeric(),
3146 );
3147 }
3148 }
3149 _ => {}
3150 }
3151 }
3152
3153 errors
3154}
3155
3156fn check_semantics(stmt: &Expression) -> Vec<ValidationError> {
3157 let mut errors = Vec::new();
3158
3159 let Expression::Select(select) = stmt else {
3160 return errors;
3161 };
3162 let select_expr = Expression::Select(select.clone());
3163
3164 if !select_expr
3166 .find_all(|e| matches!(e, Expression::Star(_)))
3167 .is_empty()
3168 {
3169 errors.push(ValidationError::warning(
3170 "SELECT * is discouraged; specify columns explicitly for better performance and maintainability",
3171 validation_codes::W_SELECT_STAR,
3172 ));
3173 }
3174
3175 let aggregate_count = get_aggregate_functions(&select_expr).len();
3177 if aggregate_count > 0 && select.group_by.is_none() {
3178 let has_non_aggregate_column = select.expressions.iter().any(|expr| {
3179 matches!(expr, Expression::Column(_) | Expression::Identifier(_))
3180 && get_aggregate_functions(expr).is_empty()
3181 });
3182
3183 if has_non_aggregate_column {
3184 errors.push(ValidationError::warning(
3185 "Mixing aggregate functions with non-aggregated columns without GROUP BY may cause errors in strict SQL mode",
3186 validation_codes::W_AGGREGATE_WITHOUT_GROUP_BY,
3187 ));
3188 }
3189 }
3190
3191 if select.distinct && select.order_by.is_some() {
3193 errors.push(ValidationError::warning(
3194 "DISTINCT with ORDER BY: ensure ORDER BY columns are in SELECT list",
3195 validation_codes::W_DISTINCT_ORDER_BY,
3196 ));
3197 }
3198
3199 if select.limit.is_some() && select.order_by.is_none() {
3201 errors.push(ValidationError::warning(
3202 "LIMIT without ORDER BY produces non-deterministic results",
3203 validation_codes::W_LIMIT_WITHOUT_ORDER_BY,
3204 ));
3205 }
3206
3207 errors
3208}
3209
3210fn resolve_scope_source_name(scope: &crate::scope::Scope, name: &str) -> Option<String> {
3211 scope
3212 .sources
3213 .get_key_value(name)
3214 .map(|(k, _)| k.clone())
3215 .or_else(|| {
3216 scope
3217 .sources
3218 .keys()
3219 .find(|source| source.eq_ignore_ascii_case(name))
3220 .cloned()
3221 })
3222}
3223
/// Returns `true` if `columns` contains `column_name` (ASCII case-insensitive).
///
/// A `*` entry in `columns` is treated as matching every column name.
fn source_has_column(columns: &[String], column_name: &str) -> bool {
    for candidate in columns {
        if candidate == "*" || candidate.eq_ignore_ascii_case(column_name) {
            return true;
        }
    }
    false
}
3229
3230fn source_display_name(scope: &crate::scope::Scope, source_name: &str) -> String {
3231 scope
3232 .sources
3233 .get(source_name)
3234 .map(|source| match &source.expression {
3235 Expression::Table(table) => lower(&table_ref_display_name(table)),
3236 _ => lower(source_name),
3237 })
3238 .unwrap_or_else(|| lower(source_name))
3239}
3240
3241fn validate_select_columns_with_schema(
3242 select: &crate::expressions::Select,
3243 schema_map: &HashMap<String, TableSchemaEntry>,
3244 resolver_schema: &MappingSchema,
3245 strict: bool,
3246) -> Vec<ValidationError> {
3247 let mut errors = Vec::new();
3248 let select_expr = Expression::Select(Box::new(select.clone()));
3249 let scope = build_scope(&select_expr);
3250 let mut resolver = Resolver::new(&scope, resolver_schema, true);
3251 let source_names: Vec<String> = scope.sources.keys().cloned().collect();
3252
3253 for node in walk_in_scope(&select_expr, false) {
3254 let Expression::Column(column) = node else {
3255 continue;
3256 };
3257
3258 let col_name = lower(&column.name.name);
3259 if col_name.is_empty() {
3260 continue;
3261 }
3262
3263 if let Some(table) = &column.table {
3264 let Some(source_name) = resolve_scope_source_name(&scope, &table.name) else {
3265 errors.push(if strict {
3267 ValidationError::error(
3268 format!(
3269 "Unknown table or alias '{}' referenced by column '{}'",
3270 table.name, col_name
3271 ),
3272 validation_codes::E_UNRESOLVED_REFERENCE,
3273 )
3274 } else {
3275 ValidationError::warning(
3276 format!(
3277 "Unknown table or alias '{}' referenced by column '{}'",
3278 table.name, col_name
3279 ),
3280 validation_codes::E_UNRESOLVED_REFERENCE,
3281 )
3282 });
3283 continue;
3284 };
3285
3286 if let Ok(columns) = resolver.get_source_columns(&source_name) {
3287 if !columns.is_empty() && !source_has_column(&columns, &col_name) {
3288 let table_name = source_display_name(&scope, &source_name);
3289 errors.push(if strict {
3290 ValidationError::error(
3291 format!("Unknown column '{}' in table '{}'", col_name, table_name),
3292 validation_codes::E_UNKNOWN_COLUMN,
3293 )
3294 } else {
3295 ValidationError::warning(
3296 format!("Unknown column '{}' in table '{}'", col_name, table_name),
3297 validation_codes::E_UNKNOWN_COLUMN,
3298 )
3299 });
3300 }
3301 }
3302 continue;
3303 }
3304
3305 let matching_sources: Vec<String> = source_names
3306 .iter()
3307 .filter_map(|source_name| {
3308 resolver
3309 .get_source_columns(source_name)
3310 .ok()
3311 .filter(|columns| !columns.is_empty() && source_has_column(columns, &col_name))
3312 .map(|_| source_name.clone())
3313 })
3314 .collect();
3315
3316 if !matching_sources.is_empty() {
3317 continue;
3318 }
3319
3320 let known_sources: Vec<String> = source_names
3321 .iter()
3322 .filter_map(|source_name| {
3323 resolver
3324 .get_source_columns(source_name)
3325 .ok()
3326 .filter(|columns| !columns.is_empty() && !columns.iter().any(|c| c == "*"))
3327 .map(|_| source_name.clone())
3328 })
3329 .collect();
3330
3331 if known_sources.len() == 1 {
3332 let table_name = source_display_name(&scope, &known_sources[0]);
3333 errors.push(if strict {
3334 ValidationError::error(
3335 format!("Unknown column '{}' in table '{}'", col_name, table_name),
3336 validation_codes::E_UNKNOWN_COLUMN,
3337 )
3338 } else {
3339 ValidationError::warning(
3340 format!("Unknown column '{}' in table '{}'", col_name, table_name),
3341 validation_codes::E_UNKNOWN_COLUMN,
3342 )
3343 });
3344 } else if known_sources.len() > 1 {
3345 errors.push(if strict {
3346 ValidationError::error(
3347 format!(
3348 "Unknown column '{}' (not found in any referenced table)",
3349 col_name
3350 ),
3351 validation_codes::E_UNKNOWN_COLUMN,
3352 )
3353 } else {
3354 ValidationError::warning(
3355 format!(
3356 "Unknown column '{}' (not found in any referenced table)",
3357 col_name
3358 ),
3359 validation_codes::E_UNKNOWN_COLUMN,
3360 )
3361 });
3362 } else if !schema_map.is_empty() {
3363 let found = schema_map
3364 .values()
3365 .any(|table_schema| table_schema.columns.contains_key(&col_name));
3366 if !found {
3367 errors.push(if strict {
3368 ValidationError::error(
3369 format!("Unknown column '{}'", col_name),
3370 validation_codes::E_UNKNOWN_COLUMN,
3371 )
3372 } else {
3373 ValidationError::warning(
3374 format!("Unknown column '{}'", col_name),
3375 validation_codes::E_UNKNOWN_COLUMN,
3376 )
3377 });
3378 }
3379 }
3380 }
3381
3382 errors
3383}
3384
3385fn validate_statement_with_schema(
3386 stmt: &Expression,
3387 schema_map: &HashMap<String, TableSchemaEntry>,
3388 resolver_schema: &MappingSchema,
3389 strict: bool,
3390) -> Vec<ValidationError> {
3391 let mut errors = Vec::new();
3392 let cte_aliases = collect_cte_aliases(stmt);
3393 let mut seen_tables: HashSet<String> = HashSet::new();
3394
3395 for node in stmt.find_all(|e| matches!(e, Expression::Table(_))) {
3397 let Expression::Table(table) = node else {
3398 continue;
3399 };
3400
3401 if cte_aliases.contains(&lower(&table.name.name)) {
3402 continue;
3403 }
3404
3405 let resolved_key = table_ref_candidates(table)
3406 .into_iter()
3407 .find(|k| schema_map.contains_key(k));
3408 let table_key = resolved_key
3409 .clone()
3410 .unwrap_or_else(|| lower(&table_ref_display_name(table)));
3411
3412 if !seen_tables.insert(table_key) {
3413 continue;
3414 }
3415
3416 if resolved_key.is_none() {
3417 errors.push(if strict {
3418 ValidationError::error(
3419 format!("Unknown table '{}'", table_ref_display_name(table)),
3420 validation_codes::E_UNKNOWN_TABLE,
3421 )
3422 } else {
3423 ValidationError::warning(
3424 format!("Unknown table '{}'", table_ref_display_name(table)),
3425 validation_codes::E_UNKNOWN_TABLE,
3426 )
3427 });
3428 }
3429 }
3430
3431 for node in stmt.dfs() {
3432 let Expression::Select(select) = node else {
3433 continue;
3434 };
3435 errors.extend(validate_select_columns_with_schema(
3436 select,
3437 schema_map,
3438 resolver_schema,
3439 strict,
3440 ));
3441 }
3442
3443 errors
3444}
3445
3446pub fn validate_with_schema(
3448 sql: &str,
3449 dialect: DialectType,
3450 schema: &ValidationSchema,
3451 options: &SchemaValidationOptions,
3452) -> ValidationResult {
3453 let strict = options.strict.unwrap_or(schema.strict.unwrap_or(true));
3454
3455 let syntax_result = crate::validate_with_options(
3457 sql,
3458 dialect,
3459 &crate::ValidationOptions {
3460 strict_syntax: options.strict_syntax,
3461 },
3462 );
3463 if !syntax_result.valid {
3464 return syntax_result;
3465 }
3466
3467 let d = Dialect::get(dialect);
3468 let statements = match d.parse(sql) {
3469 Ok(exprs) => exprs,
3470 Err(e) => {
3471 return ValidationResult::with_errors(vec![ValidationError::error(
3472 e.to_string(),
3473 validation_codes::E_PARSE_OR_OPTIONS,
3474 )]);
3475 }
3476 };
3477
3478 let schema_map = build_schema_map(schema);
3479 let resolver_schema = build_resolver_schema(schema);
3480 let mut all_errors = syntax_result.errors;
3481 let embedded_function_catalog = if options.check_types && options.function_catalog.is_none() {
3482 default_embedded_function_catalog()
3483 } else {
3484 None
3485 };
3486 let effective_function_catalog = options
3487 .function_catalog
3488 .as_deref()
3489 .or_else(|| embedded_function_catalog.as_deref());
3490 let declared_relationships = if options.check_references {
3491 build_declared_relationships(schema, &schema_map)
3492 } else {
3493 Vec::new()
3494 };
3495
3496 if options.check_references {
3497 all_errors.extend(check_reference_integrity(schema, &schema_map, strict));
3498 }
3499
3500 for stmt in &statements {
3501 if options.semantic {
3502 all_errors.extend(check_semantics(stmt));
3503 }
3504 all_errors.extend(validate_statement_with_schema(
3505 stmt,
3506 &schema_map,
3507 &resolver_schema,
3508 strict,
3509 ));
3510 if options.check_types {
3511 all_errors.extend(check_types(
3512 stmt,
3513 dialect,
3514 &schema_map,
3515 effective_function_catalog,
3516 strict,
3517 ));
3518 }
3519 if options.check_references {
3520 all_errors.extend(check_query_reference_quality(
3521 stmt,
3522 &schema_map,
3523 &resolver_schema,
3524 strict,
3525 &declared_relationships,
3526 ));
3527 }
3528 }
3529
3530 ValidationResult::with_errors(all_errors)
3531}
3532
3533#[cfg(test)]
3534mod tests;