1use crate::ast_transforms::get_aggregate_functions;
8use crate::dialects::{Dialect, DialectType};
9use crate::error::{ValidationError, ValidationResult};
10use crate::expressions::{
11 Column, DataType, Expression, Function, Insert, JoinKind, TableRef, Update,
12};
13use crate::function_catalog::FunctionCatalog;
14#[cfg(any(
15 feature = "function-catalog-clickhouse",
16 feature = "function-catalog-duckdb",
17 feature = "function-catalog-all-dialects"
18))]
19use crate::function_catalog::{
20 FunctionNameCase as CoreFunctionNameCase, FunctionSignature as CoreFunctionSignature,
21 HashMapFunctionCatalog,
22};
23use crate::function_registry::canonical_typed_function_name_upper;
24use crate::optimizer::annotate_types;
25use crate::resolver::Resolver;
26use crate::schema::{MappingSchema, Schema as SqlSchema, SchemaError, SchemaResult, TABLE_PARTS};
27use crate::scope::{build_scope, walk_in_scope};
28use crate::traversal::ExpressionWalk;
29use serde::{Deserialize, Serialize};
30use std::collections::{HashMap, HashSet};
31use std::sync::Arc;
32
33#[cfg(any(
34 feature = "function-catalog-clickhouse",
35 feature = "function-catalog-duckdb",
36 feature = "function-catalog-all-dialects"
37))]
38use std::sync::LazyLock;
39
/// A single column declared in a [`SchemaTable`].
///
/// Deserialized from the user-supplied validation schema. Every field except
/// `name` may be omitted and then falls back to its `Default`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SchemaColumn {
    /// Column name; lookups in this module lower-case it.
    pub name: String,
    /// Declared SQL type as free text (JSON key `type`), classified by
    /// `canonical_type_family`. Empty when omitted.
    #[serde(default, rename = "type")]
    pub data_type: String,
    /// Declared nullability, or `None` when the schema does not say.
    #[serde(default)]
    pub nullable: Option<bool>,
    /// Column-level primary-key flag (JSON key `primaryKey`).
    #[serde(default, rename = "primaryKey")]
    pub primary_key: bool,
    /// Column-level unique flag.
    #[serde(default)]
    pub unique: bool,
    /// Optional column-level foreign-key target.
    #[serde(default)]
    pub references: Option<SchemaColumnReference>,
}
61
/// Column-level foreign-key target declared on a [`SchemaColumn`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SchemaColumnReference {
    /// Referenced table name; may itself be dotted (`schema.table`).
    pub table: String,
    /// Referenced column name.
    pub column: String,
    /// Explicit schema of the referenced table; when present it is the first
    /// candidate tried during resolution.
    #[serde(default)]
    pub schema: Option<String>,
}
73
/// Table-level (possibly multi-column) foreign key declared on a
/// [`SchemaTable`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SchemaForeignKey {
    /// Optional constraint name.
    #[serde(default)]
    pub name: Option<String>,
    /// Source columns, paired positionally with `references.columns`.
    pub columns: Vec<String>,
    /// Referenced table and its columns.
    pub references: SchemaTableReference,
}
85
/// Target side of a [`SchemaForeignKey`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SchemaTableReference {
    /// Referenced table name.
    pub table: String,
    /// Referenced columns, in the same order as the source columns.
    pub columns: Vec<String>,
    /// Explicit schema of the referenced table.
    #[serde(default)]
    pub schema: Option<String>,
}
97
/// A table declared in a [`ValidationSchema`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SchemaTable {
    /// Table name; lookups lower-case it.
    pub name: String,
    /// Optional schema (namespace). When set, the table is also reachable as
    /// `schema.name`.
    #[serde(default)]
    pub schema: Option<String>,
    /// Column definitions in declaration order.
    pub columns: Vec<SchemaColumn>,
    /// Additional names under which the table can be referenced.
    #[serde(default)]
    pub aliases: Vec<String>,
    /// Table-level primary-key columns (JSON key `primaryKey`).
    #[serde(default, rename = "primaryKey")]
    pub primary_key: Vec<String>,
    /// Unique-key column groups (JSON key `uniqueKeys`).
    #[serde(default, rename = "uniqueKeys")]
    pub unique_keys: Vec<Vec<String>>,
    /// Table-level foreign keys (JSON key `foreignKeys`).
    #[serde(default, rename = "foreignKeys")]
    pub foreign_keys: Vec<SchemaForeignKey>,
}
121
/// Root of a user-supplied validation schema: the known tables plus an
/// optional strictness preference.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ValidationSchema {
    /// Tables available to validated statements.
    pub tables: Vec<SchemaTable>,
    /// Optional strict-mode preference; `None` when omitted.
    /// NOTE(review): precedence relative to `SchemaValidationOptions::strict`
    /// is decided outside this chunk.
    #[serde(default)]
    pub strict: Option<bool>,
}
131
/// Options controlling which schema-validation passes run.
///
/// Every field falls back to off/unset when omitted from serialized input.
/// `Debug` is not derived, presumably because `dyn FunctionCatalog` is not
/// required to implement it (TODO confirm).
#[derive(Clone, Serialize, Deserialize, Default)]
pub struct SchemaValidationOptions {
    /// Presumably enables the type-compatibility checks (`E21x`/`W21x`);
    /// TODO confirm against the caller.
    #[serde(default)]
    pub check_types: bool,
    /// Presumably enables the foreign-key/reference checks (`E22x`/`W22x`);
    /// TODO confirm against the caller.
    #[serde(default)]
    pub check_references: bool,
    /// Strictness override; `None` when not specified.
    #[serde(default)]
    pub strict: Option<bool>,
    /// Semantic-check toggle (consumer not visible in this chunk).
    #[serde(default)]
    pub semantic: bool,
    /// Strict-syntax toggle (consumer not visible in this chunk).
    #[serde(default)]
    pub strict_syntax: bool,
    /// Function catalog to check function names/arity against. It is never
    /// (de)serialized. NOTE(review): presumably `None` falls back to the
    /// embedded default catalog; confirm.
    #[serde(skip, default)]
    pub function_catalog: Option<Arc<dyn FunctionCatalog>>,
}
154
/// Converts the embedded catalog crate's name-case flag into the core
/// `FunctionNameCase` enum (a one-to-one variant mapping).
#[cfg(any(
    feature = "function-catalog-clickhouse",
    feature = "function-catalog-duckdb",
    feature = "function-catalog-all-dialects"
))]
fn to_core_name_case(case: polyglot_sql_function_catalogs::FunctionNameCase) -> CoreFunctionNameCase {
    use polyglot_sql_function_catalogs::FunctionNameCase as EmbeddedNameCase;

    match case {
        EmbeddedNameCase::Sensitive => CoreFunctionNameCase::Sensitive,
        EmbeddedNameCase::Insensitive => CoreFunctionNameCase::Insensitive,
    }
}
168
/// Converts embedded-catalog signatures into core signatures, preserving
/// order and copying the `min_arity`/`max_arity` bounds.
#[cfg(any(
    feature = "function-catalog-clickhouse",
    feature = "function-catalog-duckdb",
    feature = "function-catalog-all-dialects"
))]
fn to_core_signatures(
    signatures: Vec<polyglot_sql_function_catalogs::FunctionSignature>,
) -> Vec<CoreFunctionSignature> {
    let mut converted = Vec::with_capacity(signatures.len());
    for embedded in signatures {
        converted.push(CoreFunctionSignature {
            min_arity: embedded.min_arity,
            max_arity: embedded.max_arity,
        });
    }
    converted
}
185
/// Adapter that receives registrations from
/// `polyglot_sql_function_catalogs::register_enabled_catalogs` and forwards
/// them into a core [`HashMapFunctionCatalog`].
#[cfg(any(
    feature = "function-catalog-clickhouse",
    feature = "function-catalog-duckdb",
    feature = "function-catalog-all-dialects"
))]
struct EmbeddedCatalogSink<'a> {
    /// Catalog being populated.
    catalog: &'a mut HashMapFunctionCatalog,
    /// Memoized dialect-name parses. `None` marks names that do not parse as
    /// a [`DialectType`]; registrations for those dialects are dropped.
    dialect_cache: HashMap<&'static str, Option<DialectType>>,
}
195
#[cfg(any(
    feature = "function-catalog-clickhouse",
    feature = "function-catalog-duckdb",
    feature = "function-catalog-all-dialects"
))]
impl EmbeddedCatalogSink<'_> {
    /// Maps an embedded-catalog dialect name to a core [`DialectType`].
    ///
    /// The parse result is memoized per name, including failures (`None`).
    fn resolve_dialect(&mut self, dialect: &'static str) -> Option<DialectType> {
        *self
            .dialect_cache
            .entry(dialect)
            .or_insert_with(|| dialect.parse::<DialectType>().ok())
    }
}
211
/// Forwards every embedded-catalog callback to the core catalog. Callbacks
/// for dialect names the core does not recognise are ignored.
#[cfg(any(
    feature = "function-catalog-clickhouse",
    feature = "function-catalog-duckdb",
    feature = "function-catalog-all-dialects"
))]
impl polyglot_sql_function_catalogs::CatalogSink for EmbeddedCatalogSink<'_> {
    fn set_dialect_name_case(
        &mut self,
        dialect: &'static str,
        name_case: polyglot_sql_function_catalogs::FunctionNameCase,
    ) {
        let Some(target) = self.resolve_dialect(dialect) else {
            return;
        };
        self.catalog
            .set_dialect_name_case(target, to_core_name_case(name_case));
    }

    fn set_function_name_case(
        &mut self,
        dialect: &'static str,
        function_name: &str,
        name_case: polyglot_sql_function_catalogs::FunctionNameCase,
    ) {
        let Some(target) = self.resolve_dialect(dialect) else {
            return;
        };
        let core_case = to_core_name_case(name_case);
        self.catalog
            .set_function_name_case(target, function_name, core_case);
    }

    fn register(
        &mut self,
        dialect: &'static str,
        function_name: &str,
        signatures: Vec<polyglot_sql_function_catalogs::FunctionSignature>,
    ) {
        let Some(target) = self.resolve_dialect(dialect) else {
            return;
        };
        let core_signatures = to_core_signatures(signatures);
        self.catalog.register(target, function_name, core_signatures);
    }
}
253
/// Returns the process-wide catalog assembled from every embedded function
/// catalog enabled at compile time.
///
/// The catalog is built lazily on the first call. Subsequent calls return
/// clones of the same `Arc`.
#[cfg(any(
    feature = "function-catalog-clickhouse",
    feature = "function-catalog-duckdb",
    feature = "function-catalog-all-dialects"
))]
fn embedded_function_catalog_arc() -> Arc<dyn FunctionCatalog> {
    static SHARED_CATALOG: LazyLock<Arc<HashMapFunctionCatalog>> = LazyLock::new(|| {
        let mut built = HashMapFunctionCatalog::default();
        polyglot_sql_function_catalogs::register_enabled_catalogs(&mut EmbeddedCatalogSink {
            catalog: &mut built,
            dialect_cache: HashMap::new(),
        });
        Arc::new(built)
    });

    SHARED_CATALOG.clone()
}
272
/// Returns the embedded function catalog compiled into this build.
///
/// It is always `Some` when at least one `function-catalog-*` feature is on.
#[cfg(any(
    feature = "function-catalog-clickhouse",
    feature = "function-catalog-duckdb",
    feature = "function-catalog-all-dialects"
))]
fn default_embedded_function_catalog() -> Option<Arc<dyn FunctionCatalog>> {
    let catalog = embedded_function_catalog_arc();
    Some(catalog)
}
281
282#[cfg(not(any(
283 feature = "function-catalog-clickhouse",
284 feature = "function-catalog-duckdb",
285 feature = "function-catalog-all-dialects"
286)))]
287fn default_embedded_function_catalog() -> Option<Arc<dyn FunctionCatalog>> {
288 None
289}
290
/// Stable diagnostic codes attached to validation errors and warnings.
///
/// `E…` codes are used for errors and `W…` codes for warnings. Several
/// checks choose between an `E…` and a `W…` code depending on strict mode
/// (see `type_issue` and `reference_issue`).
pub mod validation_codes {
    // Parsing / option handling.
    pub const E_PARSE_OR_OPTIONS: &str = "E000";

    // Name resolution: tables, columns and functions.
    pub const E_UNKNOWN_TABLE: &str = "E200";
    pub const E_UNKNOWN_COLUMN: &str = "E201";
    pub const E_UNKNOWN_FUNCTION: &str = "E202";
    pub const E_INVALID_FUNCTION_ARITY: &str = "E203";

    // Query-shape lint warnings.
    pub const W_SELECT_STAR: &str = "W001";
    pub const W_AGGREGATE_WITHOUT_GROUP_BY: &str = "W002";
    pub const W_DISTINCT_ORDER_BY: &str = "W003";
    pub const W_LIMIT_WITHOUT_ORDER_BY: &str = "W004";

    // Type-check errors.
    pub const E_TYPE_MISMATCH: &str = "E210";
    pub const E_INVALID_PREDICATE_TYPE: &str = "E211";
    pub const E_INVALID_ARITHMETIC_TYPE: &str = "E212";
    pub const E_INVALID_FUNCTION_ARGUMENT_TYPE: &str = "E213";
    pub const E_INVALID_ASSIGNMENT_TYPE: &str = "E214";
    pub const E_SETOP_TYPE_MISMATCH: &str = "E215";
    pub const E_SETOP_ARITY_MISMATCH: &str = "E216";
    pub const E_INCOMPATIBLE_COMPARISON_TYPES: &str = "E217";
    pub const E_INVALID_CAST: &str = "E218";
    pub const E_UNKNOWN_INFERRED_TYPE: &str = "E219";

    // Type-check warnings.
    pub const W_IMPLICIT_CAST_COMPARISON: &str = "W210";
    pub const W_IMPLICIT_CAST_ARITHMETIC: &str = "W211";
    pub const W_IMPLICIT_CAST_ASSIGNMENT: &str = "W212";
    pub const W_LOSSY_CAST: &str = "W213";
    pub const W_SETOP_IMPLICIT_COERCION: &str = "W214";
    pub const W_PREDICATE_NULLABILITY: &str = "W215";
    pub const W_FUNCTION_ARGUMENT_COERCION: &str = "W216";
    pub const W_AGGREGATE_TYPE_COERCION: &str = "W217";
    pub const W_POSSIBLE_OVERFLOW: &str = "W218";
    pub const W_POSSIBLE_TRUNCATION: &str = "W219";

    // Reference / foreign-key / resolution errors.
    pub const E_INVALID_FOREIGN_KEY_REFERENCE: &str = "E220";
    pub const E_AMBIGUOUS_COLUMN_REFERENCE: &str = "E221";
    pub const E_UNRESOLVED_REFERENCE: &str = "E222";
    pub const E_CTE_COLUMN_COUNT_MISMATCH: &str = "E223";
    pub const E_MISSING_REFERENCE_TARGET: &str = "E224";

    // Join / reference warnings.
    pub const W_CARTESIAN_JOIN: &str = "W220";
    pub const W_JOIN_NOT_USING_DECLARED_REFERENCE: &str = "W221";
    pub const W_WEAK_REFERENCE_INTEGRITY: &str = "W222";
}
339
/// Coarse classification of SQL types used by schema validation.
///
/// Serialized in `snake_case` (e.g. `"timestamp"`). `Unknown` is produced
/// whenever a type cannot be classified.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum TypeFamily {
    /// Unrecognised or unspecified type.
    Unknown,
    Boolean,
    /// Whole-number types (`INT`, `BIGINT`, `UInt64`, ...).
    Integer,
    /// Fixed-point and floating-point types (`DECIMAL`, `DOUBLE`, ...).
    Numeric,
    String,
    Binary,
    Date,
    Time,
    /// Date-plus-time types (`TIMESTAMP`, `DATETIME`, ...).
    Timestamp,
    Interval,
    Json,
    Uuid,
    Array,
    Map,
    /// Record-like types (`STRUCT`, `ROW`, `OBJECT`, ...).
    Struct,
}
360
361impl TypeFamily {
362 pub fn is_numeric(self) -> bool {
363 matches!(self, TypeFamily::Integer | TypeFamily::Numeric)
364 }
365
366 pub fn is_temporal(self) -> bool {
367 matches!(
368 self,
369 TypeFamily::Date | TypeFamily::Time | TypeFamily::Timestamp | TypeFamily::Interval
370 )
371 }
372}
373
/// Per-table column information, as indexed by `build_schema_map`.
#[derive(Debug, Clone)]
struct TableSchemaEntry {
    /// Lower-cased column name -> type family (last declaration wins).
    columns: HashMap<String, TypeFamily>,
    /// Lower-cased column names in declaration order (duplicates kept).
    column_order: Vec<String>,
}
379
/// Returns a lower-cased copy of `s` (Unicode-aware).
///
/// Used to normalise table, schema and column identifiers for lookups.
fn lower(s: &str) -> String {
    str::to_lowercase(s)
}
383
/// Splits a parameterised type such as `decimal(10, 2)` into its base name
/// and the text between the first `(` and the trailing `)`, both trimmed.
///
/// Returns `None` unless the input ends with `)` and contains a `(`.
fn split_type_args(data_type: &str) -> Option<(&str, &str)> {
    let without_close = data_type.strip_suffix(')')?;
    let (head, args) = without_close.split_once('(')?;
    Some((head.trim(), args.trim()))
}
393
394pub fn canonical_type_family(data_type: &str) -> TypeFamily {
396 let trimmed = data_type
397 .trim()
398 .trim_matches(|c| c == '"' || c == '\'' || c == '`');
399 if trimmed.is_empty() {
400 return TypeFamily::Unknown;
401 }
402
403 let normalized = trimmed
405 .split_whitespace()
406 .collect::<Vec<_>>()
407 .join(" ")
408 .to_lowercase();
409
410 if let Some((base, inner)) = split_type_args(&normalized) {
412 match base {
413 "nullable" | "lowcardinality" => return canonical_type_family(inner),
414 "array" | "list" => return TypeFamily::Array,
415 "map" => return TypeFamily::Map,
416 "struct" | "row" | "record" => return TypeFamily::Struct,
417 _ => {}
418 }
419 }
420
421 if normalized.starts_with("array<") || normalized.starts_with("list<") {
422 return TypeFamily::Array;
423 }
424 if normalized.starts_with("map<") {
425 return TypeFamily::Map;
426 }
427 if normalized.starts_with("struct<")
428 || normalized.starts_with("row<")
429 || normalized.starts_with("record<")
430 || normalized.starts_with("object<")
431 {
432 return TypeFamily::Struct;
433 }
434
435 if normalized.ends_with("[]") {
436 return TypeFamily::Array;
437 }
438
439 let mut base = normalized
441 .split('(')
442 .next()
443 .unwrap_or("")
444 .trim()
445 .to_string();
446 if base.is_empty() {
447 return TypeFamily::Unknown;
448 }
449
450 base = base.strip_prefix("unsigned ").unwrap_or(&base).to_string();
451 base = base.strip_suffix(" unsigned").unwrap_or(&base).to_string();
452
453 match base.as_str() {
454 "bool" | "boolean" => TypeFamily::Boolean,
455 "tinyint" | "smallint" | "int2" | "int" | "integer" | "int4" | "int8" | "bigint"
456 | "serial" | "smallserial" | "bigserial" | "utinyint" | "usmallint" | "uinteger"
457 | "ubigint" | "uint8" | "uint16" | "uint32" | "uint64" | "int16" | "int32" | "int64" => {
458 TypeFamily::Integer
459 }
460 "numeric" | "decimal" | "dec" | "number" | "float" | "float4" | "float8" | "real"
461 | "double" | "double precision" | "bfloat16" | "float16" | "float32" | "float64" => {
462 TypeFamily::Numeric
463 }
464 "char" | "character" | "varchar" | "character varying" | "nchar" | "nvarchar" | "text"
465 | "string" | "clob" => TypeFamily::String,
466 "binary" | "varbinary" | "blob" | "bytea" | "bytes" => TypeFamily::Binary,
467 "date" => TypeFamily::Date,
468 "time" => TypeFamily::Time,
469 "timestamp"
470 | "timestamptz"
471 | "datetime"
472 | "datetime2"
473 | "smalldatetime"
474 | "timestamp with time zone"
475 | "timestamp without time zone" => TypeFamily::Timestamp,
476 "interval" => TypeFamily::Interval,
477 "json" | "jsonb" | "variant" => TypeFamily::Json,
478 "uuid" | "uniqueidentifier" => TypeFamily::Uuid,
479 "array" | "list" => TypeFamily::Array,
480 "map" => TypeFamily::Map,
481 "struct" | "row" | "record" | "object" => TypeFamily::Struct,
482 _ => TypeFamily::Unknown,
483 }
484}
485
486fn build_schema_map(schema: &ValidationSchema) -> HashMap<String, TableSchemaEntry> {
487 let mut map = HashMap::new();
488
489 for table in &schema.tables {
490 let column_order: Vec<String> = table.columns.iter().map(|c| lower(&c.name)).collect();
491 let columns: HashMap<String, TypeFamily> = table
492 .columns
493 .iter()
494 .map(|c| (lower(&c.name), canonical_type_family(&c.data_type)))
495 .collect();
496 let entry = TableSchemaEntry {
497 columns,
498 column_order,
499 };
500
501 let simple_name = lower(&table.name);
502 map.insert(simple_name, entry.clone());
503
504 if let Some(table_schema) = &table.schema {
505 map.insert(
506 format!("{}.{}", lower(table_schema), lower(&table.name)),
507 entry.clone(),
508 );
509 }
510
511 for alias in &table.aliases {
512 map.insert(lower(alias), entry.clone());
513 }
514 }
515
516 map
517}
518
519fn type_family_to_data_type(family: TypeFamily) -> DataType {
520 match family {
521 TypeFamily::Unknown => DataType::Unknown,
522 TypeFamily::Boolean => DataType::Boolean,
523 TypeFamily::Integer => DataType::Int {
524 length: None,
525 integer_spelling: false,
526 },
527 TypeFamily::Numeric => DataType::Double {
528 precision: None,
529 scale: None,
530 },
531 TypeFamily::String => DataType::VarChar {
532 length: None,
533 parenthesized_length: false,
534 },
535 TypeFamily::Binary => DataType::VarBinary { length: None },
536 TypeFamily::Date => DataType::Date,
537 TypeFamily::Time => DataType::Time {
538 precision: None,
539 timezone: false,
540 },
541 TypeFamily::Timestamp => DataType::Timestamp {
542 precision: None,
543 timezone: false,
544 },
545 TypeFamily::Interval => DataType::Interval {
546 unit: None,
547 to: None,
548 },
549 TypeFamily::Json => DataType::Json,
550 TypeFamily::Uuid => DataType::Uuid,
551 TypeFamily::Array => DataType::Array {
552 element_type: Box::new(DataType::Unknown),
553 dimension: None,
554 },
555 TypeFamily::Map => DataType::Map {
556 key_type: Box::new(DataType::Unknown),
557 value_type: Box::new(DataType::Unknown),
558 },
559 TypeFamily::Struct => DataType::Struct {
560 fields: Vec::new(),
561 nested: false,
562 },
563 }
564}
565
566fn build_resolver_schema(schema: &ValidationSchema) -> MappingSchema {
567 let mut mapping = MappingSchema::new();
568
569 for table in &schema.tables {
570 let columns: Vec<(String, DataType)> = table
571 .columns
572 .iter()
573 .map(|column| {
574 (
575 lower(&column.name),
576 type_family_to_data_type(canonical_type_family(&column.data_type)),
577 )
578 })
579 .collect();
580
581 let mut table_names = Vec::new();
582 table_names.push(lower(&table.name));
583 if let Some(table_schema) = &table.schema {
584 table_names.push(format!("{}.{}", lower(table_schema), lower(&table.name)));
585 }
586 for alias in &table.aliases {
587 table_names.push(lower(alias));
588 }
589
590 let mut dedup = HashSet::new();
591 for table_name in table_names {
592 if dedup.insert(table_name.clone()) {
593 let _ = mapping.add_table(&table_name, &columns, None);
594 }
595 }
596 }
597
598 mapping
599}
600
601fn collect_cte_aliases(expr: &Expression) -> HashSet<String> {
602 let mut aliases = HashSet::new();
603
604 for node in expr.dfs() {
605 match node {
606 Expression::Select(select) => {
607 if let Some(with) = &select.with {
608 for cte in &with.ctes {
609 aliases.insert(lower(&cte.alias.name));
610 }
611 }
612 }
613 Expression::Insert(insert) => {
614 if let Some(with) = &insert.with {
615 for cte in &with.ctes {
616 aliases.insert(lower(&cte.alias.name));
617 }
618 }
619 }
620 Expression::Update(update) => {
621 if let Some(with) = &update.with {
622 for cte in &with.ctes {
623 aliases.insert(lower(&cte.alias.name));
624 }
625 }
626 }
627 Expression::Delete(delete) => {
628 if let Some(with) = &delete.with {
629 for cte in &with.ctes {
630 aliases.insert(lower(&cte.alias.name));
631 }
632 }
633 }
634 Expression::Union(union) => {
635 if let Some(with) = &union.with {
636 for cte in &with.ctes {
637 aliases.insert(lower(&cte.alias.name));
638 }
639 }
640 }
641 Expression::Intersect(intersect) => {
642 if let Some(with) = &intersect.with {
643 for cte in &with.ctes {
644 aliases.insert(lower(&cte.alias.name));
645 }
646 }
647 }
648 Expression::Except(except) => {
649 if let Some(with) = &except.with {
650 for cte in &with.ctes {
651 aliases.insert(lower(&cte.alias.name));
652 }
653 }
654 }
655 Expression::Merge(merge) => {
656 if let Some(with_) = &merge.with_ {
657 if let Expression::With(with_clause) = with_.as_ref() {
658 for cte in &with_clause.ctes {
659 aliases.insert(lower(&cte.alias.name));
660 }
661 }
662 }
663 }
664 _ => {}
665 }
666 }
667
668 aliases
669}
670
671fn table_ref_candidates(table: &TableRef) -> Vec<String> {
672 let name = lower(&table.name.name);
673 let schema = table.schema.as_ref().map(|s| lower(&s.name));
674 let catalog = table.catalog.as_ref().map(|c| lower(&c.name));
675
676 let mut candidates = Vec::new();
677 if let (Some(catalog), Some(schema)) = (&catalog, &schema) {
678 candidates.push(format!("{}.{}.{}", catalog, schema, name));
679 }
680 if let Some(schema) = &schema {
681 candidates.push(format!("{}.{}", schema, name));
682 }
683 candidates.push(name);
684 candidates
685}
686
687fn table_ref_display_name(table: &TableRef) -> String {
688 let mut parts = Vec::new();
689 if let Some(catalog) = &table.catalog {
690 parts.push(catalog.name.clone());
691 }
692 if let Some(schema) = &table.schema {
693 parts.push(schema.name.clone());
694 }
695 parts.push(table.name.name.clone());
696 parts.join(".")
697}
698
/// Tables touched by one statement, as gathered by
/// `collect_type_check_context`.
#[derive(Debug, Default, Clone)]
struct TypeCheckContext {
    /// Schema-map keys of every resolved table the statement references.
    referenced_tables: HashSet<String>,
    /// Lower-cased table name or alias -> schema-map key. Used to resolve
    /// qualified column references.
    table_aliases: HashMap<String, String>,
}
704
705fn type_family_name(family: TypeFamily) -> &'static str {
706 match family {
707 TypeFamily::Unknown => "unknown",
708 TypeFamily::Boolean => "boolean",
709 TypeFamily::Integer => "integer",
710 TypeFamily::Numeric => "numeric",
711 TypeFamily::String => "string",
712 TypeFamily::Binary => "binary",
713 TypeFamily::Date => "date",
714 TypeFamily::Time => "time",
715 TypeFamily::Timestamp => "timestamp",
716 TypeFamily::Interval => "interval",
717 TypeFamily::Json => "json",
718 TypeFamily::Uuid => "uuid",
719 TypeFamily::Array => "array",
720 TypeFamily::Map => "map",
721 TypeFamily::Struct => "struct",
722 }
723}
724
725fn is_string_like(family: TypeFamily) -> bool {
726 matches!(family, TypeFamily::String)
727}
728
729fn is_string_or_binary(family: TypeFamily) -> bool {
730 matches!(family, TypeFamily::String | TypeFamily::Binary)
731}
732
733fn type_issue(
734 strict: bool,
735 error_code: &str,
736 warning_code: &str,
737 message: impl Into<String>,
738) -> ValidationError {
739 if strict {
740 ValidationError::error(message.into(), error_code)
741 } else {
742 ValidationError::warning(message.into(), warning_code)
743 }
744}
745
746fn data_type_family(data_type: &DataType) -> TypeFamily {
747 match data_type {
748 DataType::Boolean => TypeFamily::Boolean,
749 DataType::TinyInt { .. }
750 | DataType::SmallInt { .. }
751 | DataType::Int { .. }
752 | DataType::BigInt { .. } => TypeFamily::Integer,
753 DataType::Float { .. } | DataType::Double { .. } | DataType::Decimal { .. } => {
754 TypeFamily::Numeric
755 }
756 DataType::Char { .. }
757 | DataType::VarChar { .. }
758 | DataType::String { .. }
759 | DataType::Text
760 | DataType::TextWithLength { .. }
761 | DataType::CharacterSet { .. } => TypeFamily::String,
762 DataType::Binary { .. } | DataType::VarBinary { .. } | DataType::Blob => TypeFamily::Binary,
763 DataType::Date => TypeFamily::Date,
764 DataType::Time { .. } => TypeFamily::Time,
765 DataType::Timestamp { .. } => TypeFamily::Timestamp,
766 DataType::Interval { .. } => TypeFamily::Interval,
767 DataType::Json | DataType::JsonB => TypeFamily::Json,
768 DataType::Uuid => TypeFamily::Uuid,
769 DataType::Array { .. } | DataType::List { .. } => TypeFamily::Array,
770 DataType::Map { .. } => TypeFamily::Map,
771 DataType::Struct { .. } | DataType::Object { .. } | DataType::Union { .. } => {
772 TypeFamily::Struct
773 }
774 DataType::Nullable { inner } => data_type_family(inner),
775 DataType::Custom { name } => canonical_type_family(name),
776 DataType::Unknown => TypeFamily::Unknown,
777 DataType::Bit { .. } | DataType::VarBit { .. } => TypeFamily::Binary,
778 DataType::Enum { .. } | DataType::Set { .. } => TypeFamily::String,
779 DataType::Vector { .. } => TypeFamily::Array,
780 DataType::Geometry { .. } | DataType::Geography { .. } => TypeFamily::Struct,
781 }
782}
783
784fn collect_type_check_context(
785 stmt: &Expression,
786 schema_map: &HashMap<String, TableSchemaEntry>,
787) -> TypeCheckContext {
788 fn add_table_to_context(
789 table: &TableRef,
790 schema_map: &HashMap<String, TableSchemaEntry>,
791 context: &mut TypeCheckContext,
792 ) {
793 let resolved_key = table_ref_candidates(table)
794 .into_iter()
795 .find(|k| schema_map.contains_key(k));
796
797 let Some(table_key) = resolved_key else {
798 return;
799 };
800
801 context.referenced_tables.insert(table_key.clone());
802 context
803 .table_aliases
804 .insert(lower(&table.name.name), table_key.clone());
805 if let Some(alias) = &table.alias {
806 context
807 .table_aliases
808 .insert(lower(&alias.name), table_key.clone());
809 }
810 }
811
812 let mut context = TypeCheckContext::default();
813 let cte_aliases = collect_cte_aliases(stmt);
814
815 for node in stmt.find_all(|e| matches!(e, Expression::Table(_))) {
816 let Expression::Table(table) = node else {
817 continue;
818 };
819
820 if cte_aliases.contains(&lower(&table.name.name)) {
821 continue;
822 }
823
824 add_table_to_context(table, schema_map, &mut context);
825 }
826
827 match stmt {
830 Expression::Insert(insert) => {
831 add_table_to_context(&insert.table, schema_map, &mut context);
832 }
833 Expression::Update(update) => {
834 add_table_to_context(&update.table, schema_map, &mut context);
835 for table in &update.extra_tables {
836 add_table_to_context(table, schema_map, &mut context);
837 }
838 }
839 Expression::Delete(delete) => {
840 add_table_to_context(&delete.table, schema_map, &mut context);
841 for table in &delete.using {
842 add_table_to_context(table, schema_map, &mut context);
843 }
844 for table in &delete.tables {
845 add_table_to_context(table, schema_map, &mut context);
846 }
847 }
848 _ => {}
849 }
850
851 context
852}
853
854fn resolve_table_schema_entry<'a>(
855 table: &TableRef,
856 schema_map: &'a HashMap<String, TableSchemaEntry>,
857) -> Option<(String, &'a TableSchemaEntry)> {
858 let key = table_ref_candidates(table)
859 .into_iter()
860 .find(|k| schema_map.contains_key(k))?;
861 let entry = schema_map.get(&key)?;
862 Some((key, entry))
863}
864
865fn reference_issue(strict: bool, message: impl Into<String>) -> ValidationError {
866 if strict {
867 ValidationError::error(
868 message.into(),
869 validation_codes::E_INVALID_FOREIGN_KEY_REFERENCE,
870 )
871 } else {
872 ValidationError::warning(message.into(), validation_codes::W_WEAK_REFERENCE_INTEGRITY)
873 }
874}
875
876fn reference_table_candidates(
877 table_name: &str,
878 explicit_schema: Option<&str>,
879 source_schema: Option<&str>,
880) -> Vec<String> {
881 let mut candidates = Vec::new();
882 let raw = lower(table_name);
883
884 if let Some(schema) = explicit_schema {
885 candidates.push(format!("{}.{}", lower(schema), raw));
886 }
887
888 if raw.contains('.') {
889 candidates.push(raw.clone());
890 if let Some(last) = raw.rsplit('.').next() {
891 candidates.push(last.to_string());
892 }
893 } else {
894 if let Some(schema) = source_schema {
895 candidates.push(format!("{}.{}", lower(schema), raw));
896 }
897 candidates.push(raw);
898 }
899
900 let mut dedup = HashSet::new();
901 candidates
902 .into_iter()
903 .filter(|c| dedup.insert(c.clone()))
904 .collect()
905}
906
907fn resolve_reference_table_key(
908 table_name: &str,
909 explicit_schema: Option<&str>,
910 source_schema: Option<&str>,
911 schema_map: &HashMap<String, TableSchemaEntry>,
912) -> Option<String> {
913 reference_table_candidates(table_name, explicit_schema, source_schema)
914 .into_iter()
915 .find(|candidate| schema_map.contains_key(candidate))
916}
917
918fn key_types_compatible(source: TypeFamily, target: TypeFamily) -> bool {
919 if source == TypeFamily::Unknown || target == TypeFamily::Unknown {
920 return true;
921 }
922 if source == target {
923 return true;
924 }
925 if source.is_numeric() && target.is_numeric() {
926 return true;
927 }
928 if source.is_temporal() && target.is_temporal() {
929 return true;
930 }
931 false
932}
933
934fn table_key_hints(table: &SchemaTable) -> HashSet<String> {
935 let mut hints = HashSet::new();
936 for column in &table.columns {
937 if column.primary_key || column.unique {
938 hints.insert(lower(&column.name));
939 }
940 }
941 for key_col in &table.primary_key {
942 hints.insert(lower(key_col));
943 }
944 for group in &table.unique_keys {
945 if group.len() == 1 {
946 if let Some(col) = group.first() {
947 hints.insert(lower(col));
948 }
949 }
950 }
951 hints
952}
953
954fn check_reference_integrity(
955 schema: &ValidationSchema,
956 schema_map: &HashMap<String, TableSchemaEntry>,
957 strict: bool,
958) -> Vec<ValidationError> {
959 let mut errors = Vec::new();
960
961 let mut key_hints_lookup: HashMap<String, HashSet<String>> = HashMap::new();
962 for table in &schema.tables {
963 let simple = lower(&table.name);
964 key_hints_lookup.insert(simple, table_key_hints(table));
965 if let Some(schema_name) = &table.schema {
966 let qualified = format!("{}.{}", lower(schema_name), lower(&table.name));
967 key_hints_lookup.insert(qualified, table_key_hints(table));
968 }
969 }
970
971 for table in &schema.tables {
972 let source_table_display = if let Some(schema_name) = &table.schema {
973 format!("{}.{}", schema_name, table.name)
974 } else {
975 table.name.clone()
976 };
977 let source_schema = table.schema.as_deref();
978 let source_columns: HashMap<String, TypeFamily> = table
979 .columns
980 .iter()
981 .map(|col| (lower(&col.name), canonical_type_family(&col.data_type)))
982 .collect();
983
984 for source_col in &table.columns {
985 let Some(reference) = &source_col.references else {
986 continue;
987 };
988 let source_type = canonical_type_family(&source_col.data_type);
989
990 let Some(target_key) = resolve_reference_table_key(
991 &reference.table,
992 reference.schema.as_deref(),
993 source_schema,
994 schema_map,
995 ) else {
996 errors.push(reference_issue(
997 strict,
998 format!(
999 "Foreign key reference '{}.{}' points to unknown table '{}'",
1000 source_table_display, source_col.name, reference.table
1001 ),
1002 ));
1003 continue;
1004 };
1005
1006 let target_column = lower(&reference.column);
1007 let Some(target_entry) = schema_map.get(&target_key) else {
1008 errors.push(reference_issue(
1009 strict,
1010 format!(
1011 "Foreign key reference '{}.{}' points to unknown table '{}'",
1012 source_table_display, source_col.name, reference.table
1013 ),
1014 ));
1015 continue;
1016 };
1017
1018 let Some(target_type) = target_entry.columns.get(&target_column).copied() else {
1019 errors.push(reference_issue(
1020 strict,
1021 format!(
1022 "Foreign key reference '{}.{}' points to unknown column '{}.{}'",
1023 source_table_display, source_col.name, target_key, reference.column
1024 ),
1025 ));
1026 continue;
1027 };
1028
1029 if !key_types_compatible(source_type, target_type) {
1030 errors.push(reference_issue(
1031 strict,
1032 format!(
1033 "Foreign key type mismatch for '{}.{}' -> '{}.{}': {} vs {}",
1034 source_table_display,
1035 source_col.name,
1036 target_key,
1037 reference.column,
1038 type_family_name(source_type),
1039 type_family_name(target_type)
1040 ),
1041 ));
1042 }
1043
1044 if let Some(target_key_hints) = key_hints_lookup.get(&target_key) {
1045 if !target_key_hints.contains(&target_column) {
1046 errors.push(ValidationError::warning(
1047 format!(
1048 "Referenced column '{}.{}' is not marked as primary/unique key",
1049 target_key, reference.column
1050 ),
1051 validation_codes::W_WEAK_REFERENCE_INTEGRITY,
1052 ));
1053 }
1054 }
1055 }
1056
1057 for foreign_key in &table.foreign_keys {
1058 if foreign_key.columns.is_empty() || foreign_key.references.columns.is_empty() {
1059 errors.push(reference_issue(
1060 strict,
1061 format!(
1062 "Table-level foreign key on '{}' has empty source or target column list",
1063 source_table_display
1064 ),
1065 ));
1066 continue;
1067 }
1068 if foreign_key.columns.len() != foreign_key.references.columns.len() {
1069 errors.push(reference_issue(
1070 strict,
1071 format!(
1072 "Table-level foreign key on '{}' has {} source columns but {} target columns",
1073 source_table_display,
1074 foreign_key.columns.len(),
1075 foreign_key.references.columns.len()
1076 ),
1077 ));
1078 continue;
1079 }
1080
1081 let Some(target_key) = resolve_reference_table_key(
1082 &foreign_key.references.table,
1083 foreign_key.references.schema.as_deref(),
1084 source_schema,
1085 schema_map,
1086 ) else {
1087 errors.push(reference_issue(
1088 strict,
1089 format!(
1090 "Table-level foreign key on '{}' points to unknown table '{}'",
1091 source_table_display, foreign_key.references.table
1092 ),
1093 ));
1094 continue;
1095 };
1096
1097 let Some(target_entry) = schema_map.get(&target_key) else {
1098 errors.push(reference_issue(
1099 strict,
1100 format!(
1101 "Table-level foreign key on '{}' points to unknown table '{}'",
1102 source_table_display, foreign_key.references.table
1103 ),
1104 ));
1105 continue;
1106 };
1107
1108 for (source_col, target_col) in foreign_key
1109 .columns
1110 .iter()
1111 .zip(foreign_key.references.columns.iter())
1112 {
1113 let source_col_name = lower(source_col);
1114 let target_col_name = lower(target_col);
1115
1116 let Some(source_type) = source_columns.get(&source_col_name).copied() else {
1117 errors.push(reference_issue(
1118 strict,
1119 format!(
1120 "Table-level foreign key on '{}' references unknown source column '{}'",
1121 source_table_display, source_col
1122 ),
1123 ));
1124 continue;
1125 };
1126
1127 let Some(target_type) = target_entry.columns.get(&target_col_name).copied() else {
1128 errors.push(reference_issue(
1129 strict,
1130 format!(
1131 "Table-level foreign key on '{}' references unknown target column '{}.{}'",
1132 source_table_display, target_key, target_col
1133 ),
1134 ));
1135 continue;
1136 };
1137
1138 if !key_types_compatible(source_type, target_type) {
1139 errors.push(reference_issue(
1140 strict,
1141 format!(
1142 "Table-level foreign key type mismatch '{}.{}' -> '{}.{}': {} vs {}",
1143 source_table_display,
1144 source_col,
1145 target_key,
1146 target_col,
1147 type_family_name(source_type),
1148 type_family_name(target_type)
1149 ),
1150 ));
1151 }
1152
1153 if let Some(target_key_hints) = key_hints_lookup.get(&target_key) {
1154 if !target_key_hints.contains(&target_col_name) {
1155 errors.push(ValidationError::warning(
1156 format!(
1157 "Referenced column '{}.{}' is not marked as primary/unique key",
1158 target_key, target_col
1159 ),
1160 validation_codes::W_WEAK_REFERENCE_INTEGRITY,
1161 ));
1162 }
1163 }
1164 }
1165 }
1166 }
1167
1168 errors
1169}
1170
1171fn resolve_unqualified_column_type(
1172 column_name: &str,
1173 schema_map: &HashMap<String, TableSchemaEntry>,
1174 context: &TypeCheckContext,
1175) -> TypeFamily {
1176 let candidate_tables: Vec<&String> = if !context.referenced_tables.is_empty() {
1177 context.referenced_tables.iter().collect()
1178 } else {
1179 schema_map.keys().collect()
1180 };
1181
1182 let mut families = HashSet::new();
1183 for table_name in candidate_tables {
1184 if let Some(table_schema) = schema_map.get(table_name) {
1185 if let Some(family) = table_schema.columns.get(column_name) {
1186 families.insert(*family);
1187 }
1188 }
1189 }
1190
1191 if families.len() == 1 {
1192 *families.iter().next().unwrap_or(&TypeFamily::Unknown)
1193 } else {
1194 TypeFamily::Unknown
1195 }
1196}
1197
1198fn resolve_column_type(
1199 column: &Column,
1200 schema_map: &HashMap<String, TableSchemaEntry>,
1201 context: &TypeCheckContext,
1202) -> TypeFamily {
1203 let column_name = lower(&column.name.name);
1204 if column_name.is_empty() {
1205 return TypeFamily::Unknown;
1206 }
1207
1208 if let Some(table) = &column.table {
1209 let mut table_key = lower(&table.name);
1210 if let Some(mapped) = context.table_aliases.get(&table_key) {
1211 table_key = mapped.clone();
1212 }
1213
1214 return schema_map
1215 .get(&table_key)
1216 .and_then(|t| t.columns.get(&column_name))
1217 .copied()
1218 .unwrap_or(TypeFamily::Unknown);
1219 }
1220
1221 resolve_unqualified_column_type(&column_name, schema_map, context)
1222}
1223
1224struct TypeInferenceSchema<'a> {
1225 schema_map: &'a HashMap<String, TableSchemaEntry>,
1226 context: &'a TypeCheckContext,
1227}
1228
1229impl TypeInferenceSchema<'_> {
1230 fn resolve_table_key(&self, table: &str) -> Option<String> {
1231 let mut table_key = lower(table);
1232 if let Some(mapped) = self.context.table_aliases.get(&table_key) {
1233 table_key = mapped.clone();
1234 }
1235 if self.schema_map.contains_key(&table_key) {
1236 Some(table_key)
1237 } else {
1238 None
1239 }
1240 }
1241}
1242
1243impl SqlSchema for TypeInferenceSchema<'_> {
1244 fn dialect(&self) -> Option<DialectType> {
1245 None
1246 }
1247
1248 fn add_table(
1249 &mut self,
1250 _table: &str,
1251 _columns: &[(String, DataType)],
1252 _dialect: Option<DialectType>,
1253 ) -> SchemaResult<()> {
1254 Err(SchemaError::InvalidStructure(
1255 "Type inference schema is read-only".to_string(),
1256 ))
1257 }
1258
1259 fn column_names(&self, table: &str) -> SchemaResult<Vec<String>> {
1260 let table_key = self
1261 .resolve_table_key(table)
1262 .ok_or_else(|| SchemaError::TableNotFound(table.to_string()))?;
1263 let entry = self
1264 .schema_map
1265 .get(&table_key)
1266 .ok_or_else(|| SchemaError::TableNotFound(table.to_string()))?;
1267 Ok(entry.column_order.clone())
1268 }
1269
1270 fn get_column_type(&self, table: &str, column: &str) -> SchemaResult<DataType> {
1271 let col_name = lower(column);
1272 if table.is_empty() {
1273 let family = resolve_unqualified_column_type(&col_name, self.schema_map, self.context);
1274 return if family == TypeFamily::Unknown {
1275 Err(SchemaError::ColumnNotFound {
1276 table: "<unqualified>".to_string(),
1277 column: column.to_string(),
1278 })
1279 } else {
1280 Ok(type_family_to_data_type(family))
1281 };
1282 }
1283
1284 let table_key = self
1285 .resolve_table_key(table)
1286 .ok_or_else(|| SchemaError::TableNotFound(table.to_string()))?;
1287 let entry = self
1288 .schema_map
1289 .get(&table_key)
1290 .ok_or_else(|| SchemaError::TableNotFound(table.to_string()))?;
1291 let family =
1292 entry
1293 .columns
1294 .get(&col_name)
1295 .copied()
1296 .ok_or_else(|| SchemaError::ColumnNotFound {
1297 table: table.to_string(),
1298 column: column.to_string(),
1299 })?;
1300 Ok(type_family_to_data_type(family))
1301 }
1302
1303 fn has_column(&self, table: &str, column: &str) -> bool {
1304 self.get_column_type(table, column).is_ok()
1305 }
1306
1307 fn supported_table_args(&self) -> &[&str] {
1308 TABLE_PARTS
1309 }
1310
1311 fn is_empty(&self) -> bool {
1312 self.schema_map.is_empty()
1313 }
1314
1315 fn depth(&self) -> usize {
1316 1
1317 }
1318}
1319
1320fn infer_expression_type_family(
1321 expr: &Expression,
1322 schema_map: &HashMap<String, TableSchemaEntry>,
1323 context: &TypeCheckContext,
1324) -> TypeFamily {
1325 let inference_schema = TypeInferenceSchema {
1326 schema_map,
1327 context,
1328 };
1329 if let Some(data_type) = annotate_types(expr, Some(&inference_schema), None) {
1330 let family = data_type_family(&data_type);
1331 if family != TypeFamily::Unknown {
1332 return family;
1333 }
1334 }
1335
1336 infer_expression_type_family_fallback(expr, schema_map, context)
1337}
1338
1339fn infer_expression_type_family_fallback(
1340 expr: &Expression,
1341 schema_map: &HashMap<String, TableSchemaEntry>,
1342 context: &TypeCheckContext,
1343) -> TypeFamily {
1344 match expr {
1345 Expression::Literal(literal) => match literal {
1346 crate::expressions::Literal::Number(value) => {
1347 if value.contains('.') || value.contains('e') || value.contains('E') {
1348 TypeFamily::Numeric
1349 } else {
1350 TypeFamily::Integer
1351 }
1352 }
1353 crate::expressions::Literal::HexNumber(_) => TypeFamily::Integer,
1354 crate::expressions::Literal::Date(_) => TypeFamily::Date,
1355 crate::expressions::Literal::Time(_) => TypeFamily::Time,
1356 crate::expressions::Literal::Timestamp(_)
1357 | crate::expressions::Literal::Datetime(_) => TypeFamily::Timestamp,
1358 crate::expressions::Literal::HexString(_)
1359 | crate::expressions::Literal::BitString(_)
1360 | crate::expressions::Literal::ByteString(_) => TypeFamily::Binary,
1361 _ => TypeFamily::String,
1362 },
1363 Expression::Boolean(_) => TypeFamily::Boolean,
1364 Expression::Null(_) => TypeFamily::Unknown,
1365 Expression::Column(column) => resolve_column_type(column, schema_map, context),
1366 Expression::Cast(cast) | Expression::TryCast(cast) | Expression::SafeCast(cast) => {
1367 data_type_family(&cast.to)
1368 }
1369 Expression::Alias(alias) => {
1370 infer_expression_type_family_fallback(&alias.this, schema_map, context)
1371 }
1372 Expression::Neg(unary) => {
1373 infer_expression_type_family_fallback(&unary.this, schema_map, context)
1374 }
1375 Expression::Add(op) | Expression::Sub(op) | Expression::Mul(op) => {
1376 let left = infer_expression_type_family_fallback(&op.left, schema_map, context);
1377 let right = infer_expression_type_family_fallback(&op.right, schema_map, context);
1378 if left == TypeFamily::Unknown || right == TypeFamily::Unknown {
1379 TypeFamily::Unknown
1380 } else if left == TypeFamily::Integer && right == TypeFamily::Integer {
1381 TypeFamily::Integer
1382 } else if left.is_numeric() && right.is_numeric() {
1383 TypeFamily::Numeric
1384 } else if left.is_temporal() || right.is_temporal() {
1385 left
1386 } else {
1387 TypeFamily::Unknown
1388 }
1389 }
1390 Expression::Div(_) | Expression::Mod(_) => TypeFamily::Numeric,
1391 Expression::Concat(_) => TypeFamily::String,
1392 Expression::Eq(_)
1393 | Expression::Neq(_)
1394 | Expression::Lt(_)
1395 | Expression::Lte(_)
1396 | Expression::Gt(_)
1397 | Expression::Gte(_)
1398 | Expression::Like(_)
1399 | Expression::ILike(_)
1400 | Expression::And(_)
1401 | Expression::Or(_)
1402 | Expression::Not(_)
1403 | Expression::Between(_)
1404 | Expression::In(_)
1405 | Expression::IsNull(_)
1406 | Expression::IsTrue(_)
1407 | Expression::IsFalse(_)
1408 | Expression::Is(_) => TypeFamily::Boolean,
1409 Expression::Length(_) => TypeFamily::Integer,
1410 Expression::Upper(_)
1411 | Expression::Lower(_)
1412 | Expression::Trim(_)
1413 | Expression::LTrim(_)
1414 | Expression::RTrim(_)
1415 | Expression::Replace(_)
1416 | Expression::Substring(_)
1417 | Expression::Left(_)
1418 | Expression::Right(_)
1419 | Expression::Repeat(_)
1420 | Expression::Lpad(_)
1421 | Expression::Rpad(_)
1422 | Expression::ConcatWs(_) => TypeFamily::String,
1423 Expression::Abs(_)
1424 | Expression::Round(_)
1425 | Expression::Floor(_)
1426 | Expression::Ceil(_)
1427 | Expression::Power(_)
1428 | Expression::Sqrt(_)
1429 | Expression::Cbrt(_)
1430 | Expression::Ln(_)
1431 | Expression::Log(_)
1432 | Expression::Exp(_) => TypeFamily::Numeric,
1433 Expression::DateAdd(_) | Expression::DateSub(_) | Expression::ToDate(_) => TypeFamily::Date,
1434 Expression::ToTimestamp(_) => TypeFamily::Timestamp,
1435 Expression::DateDiff(_) | Expression::Extract(_) => TypeFamily::Integer,
1436 Expression::CurrentDate(_) => TypeFamily::Date,
1437 Expression::CurrentTime(_) => TypeFamily::Time,
1438 Expression::CurrentTimestamp(_) | Expression::CurrentTimestampLTZ(_) => {
1439 TypeFamily::Timestamp
1440 }
1441 Expression::Interval(_) => TypeFamily::Interval,
1442 _ => TypeFamily::Unknown,
1443 }
1444}
1445
1446fn are_comparable(left: TypeFamily, right: TypeFamily) -> bool {
1447 if left == TypeFamily::Unknown || right == TypeFamily::Unknown {
1448 return true;
1449 }
1450 if left == right {
1451 return true;
1452 }
1453 if left.is_numeric() && right.is_numeric() {
1454 return true;
1455 }
1456 if left.is_temporal() && right.is_temporal() {
1457 return true;
1458 }
1459 false
1460}
1461
1462fn check_function_argument(
1463 errors: &mut Vec<ValidationError>,
1464 strict: bool,
1465 function_name: &str,
1466 arg_index: usize,
1467 family: TypeFamily,
1468 expected: &str,
1469 valid: bool,
1470) {
1471 if family == TypeFamily::Unknown || valid {
1472 return;
1473 }
1474
1475 errors.push(type_issue(
1476 strict,
1477 validation_codes::E_INVALID_FUNCTION_ARGUMENT_TYPE,
1478 validation_codes::W_FUNCTION_ARGUMENT_COERCION,
1479 format!(
1480 "Function '{}' argument {} expects {}, found {}",
1481 function_name,
1482 arg_index + 1,
1483 expected,
1484 type_family_name(family)
1485 ),
1486 ));
1487}
1488
1489fn function_dispatch_name(name: &str) -> String {
1490 let upper = name
1491 .rsplit('.')
1492 .next()
1493 .unwrap_or(name)
1494 .trim()
1495 .to_uppercase();
1496 lower(canonical_typed_function_name_upper(&upper))
1497}
1498
/// Returns the unqualified part of a possibly dotted function name.
///
/// This is the text after the last `.`, with surrounding whitespace trimmed.
fn function_base_name(name: &str) -> &str {
    let unqualified = match name.rsplit_once('.') {
        Some((_, tail)) => tail,
        None => name,
    };
    unqualified.trim()
}
1502
1503fn check_generic_function(
1504 function: &Function,
1505 schema_map: &HashMap<String, TableSchemaEntry>,
1506 context: &TypeCheckContext,
1507 strict: bool,
1508 errors: &mut Vec<ValidationError>,
1509) {
1510 let name = function_dispatch_name(&function.name);
1511
1512 let arg_family = |index: usize| -> Option<TypeFamily> {
1513 function
1514 .args
1515 .get(index)
1516 .map(|arg| infer_expression_type_family(arg, schema_map, context))
1517 };
1518
1519 match name.as_str() {
1520 "abs" | "sqrt" | "cbrt" | "ln" | "exp" => {
1521 if let Some(family) = arg_family(0) {
1522 check_function_argument(
1523 errors,
1524 strict,
1525 &name,
1526 0,
1527 family,
1528 "a numeric argument",
1529 family.is_numeric(),
1530 );
1531 }
1532 }
1533 "round" | "floor" | "ceil" | "ceiling" => {
1534 if let Some(family) = arg_family(0) {
1535 check_function_argument(
1536 errors,
1537 strict,
1538 &name,
1539 0,
1540 family,
1541 "a numeric argument",
1542 family.is_numeric(),
1543 );
1544 }
1545 if let Some(family) = arg_family(1) {
1546 check_function_argument(
1547 errors,
1548 strict,
1549 &name,
1550 1,
1551 family,
1552 "a numeric argument",
1553 family.is_numeric(),
1554 );
1555 }
1556 }
1557 "power" | "pow" => {
1558 for i in [0_usize, 1_usize] {
1559 if let Some(family) = arg_family(i) {
1560 check_function_argument(
1561 errors,
1562 strict,
1563 &name,
1564 i,
1565 family,
1566 "a numeric argument",
1567 family.is_numeric(),
1568 );
1569 }
1570 }
1571 }
1572 "length" | "char_length" | "character_length" => {
1573 if let Some(family) = arg_family(0) {
1574 check_function_argument(
1575 errors,
1576 strict,
1577 &name,
1578 0,
1579 family,
1580 "a string or binary argument",
1581 is_string_or_binary(family),
1582 );
1583 }
1584 }
1585 "upper" | "lower" | "trim" | "ltrim" | "rtrim" | "reverse" => {
1586 if let Some(family) = arg_family(0) {
1587 check_function_argument(
1588 errors,
1589 strict,
1590 &name,
1591 0,
1592 family,
1593 "a string argument",
1594 is_string_like(family),
1595 );
1596 }
1597 }
1598 "substring" | "substr" => {
1599 if let Some(family) = arg_family(0) {
1600 check_function_argument(
1601 errors,
1602 strict,
1603 &name,
1604 0,
1605 family,
1606 "a string argument",
1607 is_string_like(family),
1608 );
1609 }
1610 if let Some(family) = arg_family(1) {
1611 check_function_argument(
1612 errors,
1613 strict,
1614 &name,
1615 1,
1616 family,
1617 "a numeric argument",
1618 family.is_numeric(),
1619 );
1620 }
1621 if let Some(family) = arg_family(2) {
1622 check_function_argument(
1623 errors,
1624 strict,
1625 &name,
1626 2,
1627 family,
1628 "a numeric argument",
1629 family.is_numeric(),
1630 );
1631 }
1632 }
1633 "replace" => {
1634 for i in [0_usize, 1_usize, 2_usize] {
1635 if let Some(family) = arg_family(i) {
1636 check_function_argument(
1637 errors,
1638 strict,
1639 &name,
1640 i,
1641 family,
1642 "a string argument",
1643 is_string_like(family),
1644 );
1645 }
1646 }
1647 }
1648 "left" | "right" | "repeat" | "lpad" | "rpad" => {
1649 if let Some(family) = arg_family(0) {
1650 check_function_argument(
1651 errors,
1652 strict,
1653 &name,
1654 0,
1655 family,
1656 "a string argument",
1657 is_string_like(family),
1658 );
1659 }
1660 if let Some(family) = arg_family(1) {
1661 check_function_argument(
1662 errors,
1663 strict,
1664 &name,
1665 1,
1666 family,
1667 "a numeric argument",
1668 family.is_numeric(),
1669 );
1670 }
1671 if (name == "lpad" || name == "rpad") && function.args.len() > 2 {
1672 if let Some(family) = arg_family(2) {
1673 check_function_argument(
1674 errors,
1675 strict,
1676 &name,
1677 2,
1678 family,
1679 "a string argument",
1680 is_string_like(family),
1681 );
1682 }
1683 }
1684 }
1685 _ => {}
1686 }
1687}
1688
1689fn check_function_catalog(
1690 function: &Function,
1691 dialect: DialectType,
1692 function_catalog: Option<&dyn FunctionCatalog>,
1693 strict: bool,
1694 errors: &mut Vec<ValidationError>,
1695) {
1696 let Some(catalog) = function_catalog else {
1697 return;
1698 };
1699
1700 let raw_name = function_base_name(&function.name);
1701 let normalized_name = function_dispatch_name(&function.name);
1702 let arity = function.args.len();
1703 let Some(signatures) = catalog.lookup(dialect, raw_name, &normalized_name) else {
1704 errors.push(if strict {
1705 ValidationError::error(
1706 format!(
1707 "Unknown function '{}' for dialect {:?}",
1708 function.name, dialect
1709 ),
1710 validation_codes::E_UNKNOWN_FUNCTION,
1711 )
1712 } else {
1713 ValidationError::warning(
1714 format!(
1715 "Unknown function '{}' for dialect {:?}",
1716 function.name, dialect
1717 ),
1718 validation_codes::E_UNKNOWN_FUNCTION,
1719 )
1720 });
1721 return;
1722 };
1723
1724 if signatures.iter().any(|sig| sig.matches_arity(arity)) {
1725 return;
1726 }
1727
1728 let expected = signatures
1729 .iter()
1730 .map(|sig| sig.describe_arity())
1731 .collect::<Vec<_>>()
1732 .join(", ");
1733 errors.push(if strict {
1734 ValidationError::error(
1735 format!(
1736 "Invalid arity for function '{}': got {}, expected {}",
1737 function.name, arity, expected
1738 ),
1739 validation_codes::E_INVALID_FUNCTION_ARITY,
1740 )
1741 } else {
1742 ValidationError::warning(
1743 format!(
1744 "Invalid arity for function '{}': got {}, expected {}",
1745 function.name, arity, expected
1746 ),
1747 validation_codes::E_INVALID_FUNCTION_ARITY,
1748 )
1749 });
1750}
1751
/// A declared foreign-key edge between two columns, keyed by resolved table keys.
///
/// `build_declared_relationships` stores the column names lowercased.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
struct DeclaredRelationship {
    /// Table that owns the referencing column.
    source_table: String,
    /// The referencing (foreign-key) column.
    source_column: String,
    /// Table that owns the referenced column.
    target_table: String,
    /// The referenced column.
    target_column: String,
}
1759
1760fn build_declared_relationships(
1761 schema: &ValidationSchema,
1762 schema_map: &HashMap<String, TableSchemaEntry>,
1763) -> Vec<DeclaredRelationship> {
1764 let mut relationships = HashSet::new();
1765
1766 for table in &schema.tables {
1767 let Some(source_key) =
1768 resolve_reference_table_key(&table.name, table.schema.as_deref(), None, schema_map)
1769 else {
1770 continue;
1771 };
1772
1773 for column in &table.columns {
1774 let Some(reference) = &column.references else {
1775 continue;
1776 };
1777 let Some(target_key) = resolve_reference_table_key(
1778 &reference.table,
1779 reference.schema.as_deref(),
1780 table.schema.as_deref(),
1781 schema_map,
1782 ) else {
1783 continue;
1784 };
1785
1786 relationships.insert(DeclaredRelationship {
1787 source_table: source_key.clone(),
1788 source_column: lower(&column.name),
1789 target_table: target_key,
1790 target_column: lower(&reference.column),
1791 });
1792 }
1793
1794 for foreign_key in &table.foreign_keys {
1795 if foreign_key.columns.len() != foreign_key.references.columns.len() {
1796 continue;
1797 }
1798 let Some(target_key) = resolve_reference_table_key(
1799 &foreign_key.references.table,
1800 foreign_key.references.schema.as_deref(),
1801 table.schema.as_deref(),
1802 schema_map,
1803 ) else {
1804 continue;
1805 };
1806
1807 for (source_col, target_col) in foreign_key
1808 .columns
1809 .iter()
1810 .zip(foreign_key.references.columns.iter())
1811 {
1812 relationships.insert(DeclaredRelationship {
1813 source_table: source_key.clone(),
1814 source_column: lower(source_col),
1815 target_table: target_key.clone(),
1816 target_column: lower(target_col),
1817 });
1818 }
1819 }
1820 }
1821
1822 relationships.into_iter().collect()
1823}
1824
1825fn resolve_column_binding(
1826 column: &Column,
1827 schema_map: &HashMap<String, TableSchemaEntry>,
1828 context: &TypeCheckContext,
1829 resolver: &mut Resolver<'_>,
1830) -> Option<(String, String)> {
1831 let column_name = lower(&column.name.name);
1832 if column_name.is_empty() {
1833 return None;
1834 }
1835
1836 if let Some(table) = &column.table {
1837 let mut table_key = lower(&table.name);
1838 if let Some(mapped) = context.table_aliases.get(&table_key) {
1839 table_key = mapped.clone();
1840 }
1841 if schema_map.contains_key(&table_key) {
1842 return Some((table_key, column_name));
1843 }
1844 return None;
1845 }
1846
1847 if let Some(resolved_source) = resolver.get_table(&column_name) {
1848 let mut table_key = lower(&resolved_source);
1849 if let Some(mapped) = context.table_aliases.get(&table_key) {
1850 table_key = mapped.clone();
1851 }
1852 if schema_map.contains_key(&table_key) {
1853 return Some((table_key, column_name));
1854 }
1855 }
1856
1857 let candidates: Vec<String> = context
1858 .referenced_tables
1859 .iter()
1860 .filter_map(|table_name| {
1861 schema_map
1862 .get(table_name)
1863 .filter(|entry| entry.columns.contains_key(&column_name))
1864 .map(|_| table_name.clone())
1865 })
1866 .collect();
1867 if candidates.len() == 1 {
1868 return Some((candidates[0].clone(), column_name));
1869 }
1870 None
1871}
1872
1873fn extract_join_equality_pairs(
1874 expr: &Expression,
1875 schema_map: &HashMap<String, TableSchemaEntry>,
1876 context: &TypeCheckContext,
1877 resolver: &mut Resolver<'_>,
1878 pairs: &mut Vec<((String, String), (String, String))>,
1879) {
1880 match expr {
1881 Expression::And(op) => {
1882 extract_join_equality_pairs(&op.left, schema_map, context, resolver, pairs);
1883 extract_join_equality_pairs(&op.right, schema_map, context, resolver, pairs);
1884 }
1885 Expression::Paren(paren) => {
1886 extract_join_equality_pairs(&paren.this, schema_map, context, resolver, pairs);
1887 }
1888 Expression::Eq(op) => {
1889 let (Expression::Column(left_col), Expression::Column(right_col)) =
1890 (&op.left, &op.right)
1891 else {
1892 return;
1893 };
1894 let Some(left) = resolve_column_binding(left_col, schema_map, context, resolver) else {
1895 return;
1896 };
1897 let Some(right) = resolve_column_binding(right_col, schema_map, context, resolver)
1898 else {
1899 return;
1900 };
1901 pairs.push((left, right));
1902 }
1903 _ => {}
1904 }
1905}
1906
1907fn relationship_matches_pair(
1908 relationship: &DeclaredRelationship,
1909 left_table: &str,
1910 left_column: &str,
1911 right_table: &str,
1912 right_column: &str,
1913) -> bool {
1914 (relationship.source_table == left_table
1915 && relationship.source_column == left_column
1916 && relationship.target_table == right_table
1917 && relationship.target_column == right_column)
1918 || (relationship.source_table == right_table
1919 && relationship.source_column == right_column
1920 && relationship.target_table == left_table
1921 && relationship.target_column == left_column)
1922}
1923
1924fn resolved_table_key_from_expr(
1925 expr: &Expression,
1926 schema_map: &HashMap<String, TableSchemaEntry>,
1927) -> Option<String> {
1928 match expr {
1929 Expression::Table(table) => resolve_table_schema_entry(table, schema_map).map(|(k, _)| k),
1930 Expression::Alias(alias) => resolved_table_key_from_expr(&alias.this, schema_map),
1931 _ => None,
1932 }
1933}
1934
1935fn select_from_table_keys(
1936 select: &crate::expressions::Select,
1937 schema_map: &HashMap<String, TableSchemaEntry>,
1938) -> HashSet<String> {
1939 let mut keys = HashSet::new();
1940 if let Some(from_clause) = &select.from {
1941 for expr in &from_clause.expressions {
1942 if let Some(key) = resolved_table_key_from_expr(expr, schema_map) {
1943 keys.insert(key);
1944 }
1945 }
1946 }
1947 keys
1948}
1949
1950fn is_natural_or_implied_join(kind: JoinKind) -> bool {
1951 matches!(
1952 kind,
1953 JoinKind::Natural
1954 | JoinKind::NaturalLeft
1955 | JoinKind::NaturalRight
1956 | JoinKind::NaturalFull
1957 | JoinKind::CrossApply
1958 | JoinKind::OuterApply
1959 | JoinKind::AsOf
1960 | JoinKind::AsOfLeft
1961 | JoinKind::AsOfRight
1962 | JoinKind::Lateral
1963 | JoinKind::LeftLateral
1964 )
1965}
1966
/// Reports reference-quality diagnostics for every `SELECT` inside `stmt`.
///
/// Each `SELECT` node visited via `stmt.dfs()` gets a type-check context, a
/// scope and a `Resolver`. Three checks then run:
/// - Ambiguous columns: when more than one table is referenced, each distinct
///   unqualified column that the resolver considers ambiguous is flagged.
///   Columns named in any `JOIN ... USING` list are exempt. Strict mode emits
///   `E_AMBIGUOUS_COLUMN_REFERENCE` errors, otherwise
///   `W_WEAK_REFERENCE_INTEGRITY` warnings.
/// - Cartesian joins: joins to a known table get a `W_CARTESIAN_JOIN` warning
///   when their kind is cartesian-like, or when they have no ON/USING and are
///   not a natural/apply/as-of/lateral join.
/// - Declared foreign keys: `ON` joins to a known table get a
///   `W_JOIN_NOT_USING_DECLARED_REFERENCE` warning when a declared relationship
///   links the joined table to a table already on the left side, but no
///   column equality in the `ON` predicate matches any such relationship.
///
/// Diagnostics are returned in the order they are discovered.
fn check_query_reference_quality(
    stmt: &Expression,
    schema_map: &HashMap<String, TableSchemaEntry>,
    resolver_schema: &MappingSchema,
    strict: bool,
    relationships: &[DeclaredRelationship],
) -> Vec<ValidationError> {
    let mut errors = Vec::new();

    for node in stmt.dfs() {
        let Expression::Select(select) = node else {
            continue;
        };

        // The context/scope builders take an `&Expression`, so wrap a clone of this SELECT.
        let select_expr = Expression::Select(select.clone());
        let context = collect_type_check_context(&select_expr, schema_map);
        let scope = build_scope(&select_expr);
        let mut resolver = Resolver::new(&scope, resolver_schema, true);

        if context.referenced_tables.len() > 1 {
            // Columns named in JOIN ... USING are merged by the join, so they are not ambiguous.
            let using_columns: HashSet<String> = select
                .joins
                .iter()
                .flat_map(|join| join.using.iter().map(|id| lower(&id.name)))
                .collect();

            let mut seen = HashSet::new();
            // NOTE(review): `find_all` walks the whole cloned SELECT, so unqualified
            // columns inside nested subqueries appear to be checked against this
            // SELECT's resolver as well. Those subqueries are also visited on their
            // own by the outer `dfs` loop. Confirm this is intended.
            for column_expr in select_expr
                .find_all(|e| matches!(e, Expression::Column(Column { table: None, .. })))
            {
                let Expression::Column(column) = column_expr else {
                    continue;
                };

                // Skip empty names and USING columns, and report each name at most once.
                let col_name = lower(&column.name.name);
                if col_name.is_empty()
                    || using_columns.contains(&col_name)
                    || !seen.insert(col_name.clone())
                {
                    continue;
                }

                if resolver.is_ambiguous(&col_name) {
                    let source_count = resolver.sources_for_column(&col_name).len();
                    errors.push(if strict {
                        ValidationError::error(
                            format!(
                                "Ambiguous unqualified column '{}' found in {} referenced tables",
                                col_name, source_count
                            ),
                            validation_codes::E_AMBIGUOUS_COLUMN_REFERENCE,
                        )
                    } else {
                        ValidationError::warning(
                            format!(
                                "Ambiguous unqualified column '{}' found in {} referenced tables",
                                col_name, source_count
                            ),
                            validation_codes::W_WEAK_REFERENCE_INTEGRITY,
                        )
                    });
                }
            }
        }

        // Known tables on the left side of the next join. It starts with the FROM
        // items, and each joined table is added after its join is processed.
        let mut cumulative_left_tables = select_from_table_keys(select, schema_map);

        for join in &select.joins {
            // `None` when the joined item is not a known schema table (e.g. a subquery).
            let right_table_key = resolved_table_key_from_expr(&join.this, schema_map);
            let has_explicit_condition = join.on.is_some() || !join.using.is_empty();
            let cartesian_like_kind = matches!(
                join.kind,
                JoinKind::Cross
                    | JoinKind::Implicit
                    | JoinKind::Array
                    | JoinKind::LeftArray
                    | JoinKind::Paste
            );

            if right_table_key.is_some()
                && (cartesian_like_kind
                    || (!has_explicit_condition && !is_natural_or_implied_join(join.kind)))
            {
                errors.push(ValidationError::warning(
                    "Potential cartesian join: JOIN without ON/USING condition",
                    validation_codes::W_CARTESIAN_JOIN,
                ));
            }

            if let (Some(on_expr), Some(right_key)) = (&join.on, right_table_key.clone()) {
                if join.using.is_empty() {
                    // Only `column = column` equalities reachable through AND/parentheses are collected.
                    let mut eq_pairs = Vec::new();
                    extract_join_equality_pairs(
                        on_expr,
                        schema_map,
                        &context,
                        &mut resolver,
                        &mut eq_pairs,
                    );

                    // `&&` binds tighter than `||`. This keeps relationships linking a
                    // left-side table to the joined table, in either direction.
                    let relevant_relationships: Vec<&DeclaredRelationship> = relationships
                        .iter()
                        .filter(|rel| {
                            cumulative_left_tables.contains(&rel.source_table)
                                && rel.target_table == right_key
                                || (cumulative_left_tables.contains(&rel.target_table)
                                    && rel.source_table == right_key)
                        })
                        .collect();

                    if !relevant_relationships.is_empty() {
                        let uses_declared_fk = eq_pairs.iter().any(|((lt, lc), (rt, rc))| {
                            relevant_relationships
                                .iter()
                                .any(|rel| relationship_matches_pair(rel, lt, lc, rt, rc))
                        });
                        if !uses_declared_fk {
                            errors.push(ValidationError::warning(
                                "JOIN predicate does not use declared foreign-key relationship columns",
                                validation_codes::W_JOIN_NOT_USING_DECLARED_REFERENCE,
                            ));
                        }
                    }
                }
            }

            if let Some(right_key) = right_table_key {
                cumulative_left_tables.insert(right_key);
            }
        }
    }

    errors
}
2101
2102fn are_setop_compatible(left: TypeFamily, right: TypeFamily) -> bool {
2103 if left == TypeFamily::Unknown || right == TypeFamily::Unknown {
2104 return true;
2105 }
2106 if left == right {
2107 return true;
2108 }
2109 if left.is_numeric() && right.is_numeric() {
2110 return true;
2111 }
2112 if left.is_temporal() && right.is_temporal() {
2113 return true;
2114 }
2115 false
2116}
2117
2118fn merged_setop_family(left: TypeFamily, right: TypeFamily) -> TypeFamily {
2119 if left == TypeFamily::Unknown {
2120 return right;
2121 }
2122 if right == TypeFamily::Unknown {
2123 return left;
2124 }
2125 if left == right {
2126 return left;
2127 }
2128 if left.is_numeric() && right.is_numeric() {
2129 if left == TypeFamily::Numeric || right == TypeFamily::Numeric {
2130 return TypeFamily::Numeric;
2131 }
2132 return TypeFamily::Integer;
2133 }
2134 if left.is_temporal() && right.is_temporal() {
2135 if left == TypeFamily::Timestamp || right == TypeFamily::Timestamp {
2136 return TypeFamily::Timestamp;
2137 }
2138 if left == TypeFamily::Date || right == TypeFamily::Date {
2139 return TypeFamily::Date;
2140 }
2141 return TypeFamily::Time;
2142 }
2143 TypeFamily::Unknown
2144}
2145
2146fn are_assignment_compatible(target: TypeFamily, source: TypeFamily) -> bool {
2147 if target == TypeFamily::Unknown || source == TypeFamily::Unknown {
2148 return true;
2149 }
2150 if target == source {
2151 return true;
2152 }
2153
2154 match target {
2155 TypeFamily::Boolean => source == TypeFamily::Boolean,
2156 TypeFamily::Integer | TypeFamily::Numeric => source.is_numeric(),
2157 TypeFamily::Date | TypeFamily::Time | TypeFamily::Timestamp | TypeFamily::Interval => {
2158 source.is_temporal()
2159 }
2160 TypeFamily::String => true,
2161 TypeFamily::Binary => matches!(source, TypeFamily::Binary | TypeFamily::String),
2162 TypeFamily::Json => matches!(source, TypeFamily::Json | TypeFamily::String),
2163 TypeFamily::Uuid => matches!(source, TypeFamily::Uuid | TypeFamily::String),
2164 TypeFamily::Array => source == TypeFamily::Array,
2165 TypeFamily::Map => source == TypeFamily::Map,
2166 TypeFamily::Struct => source == TypeFamily::Struct,
2167 TypeFamily::Unknown => true,
2168 }
2169}
2170
2171fn projection_families(
2172 query_expr: &Expression,
2173 schema_map: &HashMap<String, TableSchemaEntry>,
2174) -> Option<Vec<TypeFamily>> {
2175 match query_expr {
2176 Expression::Select(select) => {
2177 if select
2178 .expressions
2179 .iter()
2180 .any(|e| matches!(e, Expression::Star(_) | Expression::BracedWildcard(_)))
2181 {
2182 return None;
2183 }
2184 let select_expr = Expression::Select(select.clone());
2185 let context = collect_type_check_context(&select_expr, schema_map);
2186 Some(
2187 select
2188 .expressions
2189 .iter()
2190 .map(|e| infer_expression_type_family(e, schema_map, &context))
2191 .collect(),
2192 )
2193 }
2194 Expression::Subquery(subquery) => projection_families(&subquery.this, schema_map),
2195 Expression::Union(union) => {
2196 let left = projection_families(&union.left, schema_map)?;
2197 let right = projection_families(&union.right, schema_map)?;
2198 if left.len() != right.len() {
2199 return None;
2200 }
2201 Some(
2202 left.into_iter()
2203 .zip(right)
2204 .map(|(l, r)| merged_setop_family(l, r))
2205 .collect(),
2206 )
2207 }
2208 Expression::Intersect(intersect) => {
2209 let left = projection_families(&intersect.left, schema_map)?;
2210 let right = projection_families(&intersect.right, schema_map)?;
2211 if left.len() != right.len() {
2212 return None;
2213 }
2214 Some(
2215 left.into_iter()
2216 .zip(right)
2217 .map(|(l, r)| merged_setop_family(l, r))
2218 .collect(),
2219 )
2220 }
2221 Expression::Except(except) => {
2222 let left = projection_families(&except.left, schema_map)?;
2223 let right = projection_families(&except.right, schema_map)?;
2224 if left.len() != right.len() {
2225 return None;
2226 }
2227 Some(
2228 left.into_iter()
2229 .zip(right)
2230 .map(|(l, r)| merged_setop_family(l, r))
2231 .collect(),
2232 )
2233 }
2234 Expression::Values(values) => {
2235 let first_row = values.expressions.first()?;
2236 let context = TypeCheckContext::default();
2237 Some(
2238 first_row
2239 .expressions
2240 .iter()
2241 .map(|e| infer_expression_type_family(e, schema_map, &context))
2242 .collect(),
2243 )
2244 }
2245 _ => None,
2246 }
2247}
2248
2249fn check_set_operation_compatibility(
2250 op_name: &str,
2251 left_expr: &Expression,
2252 right_expr: &Expression,
2253 schema_map: &HashMap<String, TableSchemaEntry>,
2254 strict: bool,
2255 errors: &mut Vec<ValidationError>,
2256) {
2257 let Some(left_projection) = projection_families(left_expr, schema_map) else {
2258 return;
2259 };
2260 let Some(right_projection) = projection_families(right_expr, schema_map) else {
2261 return;
2262 };
2263
2264 if left_projection.len() != right_projection.len() {
2265 errors.push(type_issue(
2266 strict,
2267 validation_codes::E_SETOP_ARITY_MISMATCH,
2268 validation_codes::W_SETOP_IMPLICIT_COERCION,
2269 format!(
2270 "{} operands return different column counts: left {}, right {}",
2271 op_name,
2272 left_projection.len(),
2273 right_projection.len()
2274 ),
2275 ));
2276 return;
2277 }
2278
2279 for (idx, (left, right)) in left_projection
2280 .into_iter()
2281 .zip(right_projection)
2282 .enumerate()
2283 {
2284 if !are_setop_compatible(left, right) {
2285 errors.push(type_issue(
2286 strict,
2287 validation_codes::E_SETOP_TYPE_MISMATCH,
2288 validation_codes::W_SETOP_IMPLICIT_COERCION,
2289 format!(
2290 "{} column {} has incompatible types: {} vs {}",
2291 op_name,
2292 idx + 1,
2293 type_family_name(left),
2294 type_family_name(right)
2295 ),
2296 ));
2297 }
2298 }
2299}
2300
2301fn check_insert_assignments(
2302 stmt: &Expression,
2303 insert: &Insert,
2304 schema_map: &HashMap<String, TableSchemaEntry>,
2305 strict: bool,
2306 errors: &mut Vec<ValidationError>,
2307) {
2308 let Some((target_table_key, table_schema)) =
2309 resolve_table_schema_entry(&insert.table, schema_map)
2310 else {
2311 return;
2312 };
2313
2314 let mut target_columns = Vec::new();
2315 if insert.columns.is_empty() {
2316 target_columns.extend(table_schema.column_order.iter().cloned());
2317 } else {
2318 for column in &insert.columns {
2319 let col_name = lower(&column.name);
2320 if table_schema.columns.contains_key(&col_name) {
2321 target_columns.push(col_name);
2322 } else {
2323 errors.push(if strict {
2324 ValidationError::error(
2325 format!(
2326 "Unknown column '{}' in table '{}'",
2327 column.name, target_table_key
2328 ),
2329 validation_codes::E_UNKNOWN_COLUMN,
2330 )
2331 } else {
2332 ValidationError::warning(
2333 format!(
2334 "Unknown column '{}' in table '{}'",
2335 column.name, target_table_key
2336 ),
2337 validation_codes::E_UNKNOWN_COLUMN,
2338 )
2339 });
2340 }
2341 }
2342 }
2343
2344 if target_columns.is_empty() {
2345 return;
2346 }
2347
2348 let context = collect_type_check_context(stmt, schema_map);
2349
2350 if !insert.default_values {
2351 for (row_idx, row) in insert.values.iter().enumerate() {
2352 if row.len() != target_columns.len() {
2353 errors.push(type_issue(
2354 strict,
2355 validation_codes::E_INVALID_ASSIGNMENT_TYPE,
2356 validation_codes::W_IMPLICIT_CAST_ASSIGNMENT,
2357 format!(
2358 "INSERT row {} has {} values but target has {} columns",
2359 row_idx + 1,
2360 row.len(),
2361 target_columns.len()
2362 ),
2363 ));
2364 continue;
2365 }
2366
2367 for (value, target_column) in row.iter().zip(target_columns.iter()) {
2368 let Some(target_family) = table_schema.columns.get(target_column).copied() else {
2369 continue;
2370 };
2371 let source_family = infer_expression_type_family(value, schema_map, &context);
2372 if !are_assignment_compatible(target_family, source_family) {
2373 errors.push(type_issue(
2374 strict,
2375 validation_codes::E_INVALID_ASSIGNMENT_TYPE,
2376 validation_codes::W_IMPLICIT_CAST_ASSIGNMENT,
2377 format!(
2378 "INSERT assignment type mismatch for '{}.{}': expected {}, found {}",
2379 target_table_key,
2380 target_column,
2381 type_family_name(target_family),
2382 type_family_name(source_family)
2383 ),
2384 ));
2385 }
2386 }
2387 }
2388 }
2389
2390 if let Some(query) = &insert.query {
2391 if insert.by_name {
2393 return;
2394 }
2395
2396 let Some(source_projection) = projection_families(query, schema_map) else {
2397 return;
2398 };
2399
2400 if source_projection.len() != target_columns.len() {
2401 errors.push(type_issue(
2402 strict,
2403 validation_codes::E_INVALID_ASSIGNMENT_TYPE,
2404 validation_codes::W_IMPLICIT_CAST_ASSIGNMENT,
2405 format!(
2406 "INSERT source query has {} columns but target has {} columns",
2407 source_projection.len(),
2408 target_columns.len()
2409 ),
2410 ));
2411 return;
2412 }
2413
2414 for (source_family, target_column) in
2415 source_projection.into_iter().zip(target_columns.iter())
2416 {
2417 let Some(target_family) = table_schema.columns.get(target_column).copied() else {
2418 continue;
2419 };
2420 if !are_assignment_compatible(target_family, source_family) {
2421 errors.push(type_issue(
2422 strict,
2423 validation_codes::E_INVALID_ASSIGNMENT_TYPE,
2424 validation_codes::W_IMPLICIT_CAST_ASSIGNMENT,
2425 format!(
2426 "INSERT assignment type mismatch for '{}.{}': expected {}, found {}",
2427 target_table_key,
2428 target_column,
2429 type_family_name(target_family),
2430 type_family_name(source_family)
2431 ),
2432 ));
2433 }
2434 }
2435 }
2436}
2437
2438fn check_update_assignments(
2439 stmt: &Expression,
2440 update: &Update,
2441 schema_map: &HashMap<String, TableSchemaEntry>,
2442 strict: bool,
2443 errors: &mut Vec<ValidationError>,
2444) {
2445 let Some((target_table_key, table_schema)) =
2446 resolve_table_schema_entry(&update.table, schema_map)
2447 else {
2448 return;
2449 };
2450
2451 let context = collect_type_check_context(stmt, schema_map);
2452
2453 for (column, value) in &update.set {
2454 let col_name = lower(&column.name);
2455 let Some(target_family) = table_schema.columns.get(&col_name).copied() else {
2456 errors.push(if strict {
2457 ValidationError::error(
2458 format!(
2459 "Unknown column '{}' in table '{}'",
2460 column.name, target_table_key
2461 ),
2462 validation_codes::E_UNKNOWN_COLUMN,
2463 )
2464 } else {
2465 ValidationError::warning(
2466 format!(
2467 "Unknown column '{}' in table '{}'",
2468 column.name, target_table_key
2469 ),
2470 validation_codes::E_UNKNOWN_COLUMN,
2471 )
2472 });
2473 continue;
2474 };
2475
2476 let source_family = infer_expression_type_family(value, schema_map, &context);
2477 if !are_assignment_compatible(target_family, source_family) {
2478 errors.push(type_issue(
2479 strict,
2480 validation_codes::E_INVALID_ASSIGNMENT_TYPE,
2481 validation_codes::W_IMPLICIT_CAST_ASSIGNMENT,
2482 format!(
2483 "UPDATE assignment type mismatch for '{}.{}': expected {}, found {}",
2484 target_table_key,
2485 col_name,
2486 type_family_name(target_family),
2487 type_family_name(source_family)
2488 ),
2489 ));
2490 }
2491 }
2492}
2493
2494fn check_types(
2495 stmt: &Expression,
2496 dialect: DialectType,
2497 schema_map: &HashMap<String, TableSchemaEntry>,
2498 function_catalog: Option<&dyn FunctionCatalog>,
2499 strict: bool,
2500) -> Vec<ValidationError> {
2501 let mut errors = Vec::new();
2502 let context = collect_type_check_context(stmt, schema_map);
2503
2504 for node in stmt.dfs() {
2505 match node {
2506 Expression::Insert(insert) => {
2507 check_insert_assignments(stmt, insert, schema_map, strict, &mut errors);
2508 }
2509 Expression::Update(update) => {
2510 check_update_assignments(stmt, update, schema_map, strict, &mut errors);
2511 }
2512 Expression::Union(union) => {
2513 check_set_operation_compatibility(
2514 "UNION",
2515 &union.left,
2516 &union.right,
2517 schema_map,
2518 strict,
2519 &mut errors,
2520 );
2521 }
2522 Expression::Intersect(intersect) => {
2523 check_set_operation_compatibility(
2524 "INTERSECT",
2525 &intersect.left,
2526 &intersect.right,
2527 schema_map,
2528 strict,
2529 &mut errors,
2530 );
2531 }
2532 Expression::Except(except) => {
2533 check_set_operation_compatibility(
2534 "EXCEPT",
2535 &except.left,
2536 &except.right,
2537 schema_map,
2538 strict,
2539 &mut errors,
2540 );
2541 }
2542 Expression::Select(select) => {
2543 if let Some(prewhere) = &select.prewhere {
2544 let family = infer_expression_type_family(prewhere, schema_map, &context);
2545 if family != TypeFamily::Unknown && family != TypeFamily::Boolean {
2546 errors.push(type_issue(
2547 strict,
2548 validation_codes::E_INVALID_PREDICATE_TYPE,
2549 validation_codes::W_PREDICATE_NULLABILITY,
2550 format!(
2551 "PREWHERE clause expects a boolean predicate, found {}",
2552 type_family_name(family)
2553 ),
2554 ));
2555 }
2556 }
2557
2558 if let Some(where_clause) = &select.where_clause {
2559 let family =
2560 infer_expression_type_family(&where_clause.this, schema_map, &context);
2561 if family != TypeFamily::Unknown && family != TypeFamily::Boolean {
2562 errors.push(type_issue(
2563 strict,
2564 validation_codes::E_INVALID_PREDICATE_TYPE,
2565 validation_codes::W_PREDICATE_NULLABILITY,
2566 format!(
2567 "WHERE clause expects a boolean predicate, found {}",
2568 type_family_name(family)
2569 ),
2570 ));
2571 }
2572 }
2573
2574 if let Some(having_clause) = &select.having {
2575 let family =
2576 infer_expression_type_family(&having_clause.this, schema_map, &context);
2577 if family != TypeFamily::Unknown && family != TypeFamily::Boolean {
2578 errors.push(type_issue(
2579 strict,
2580 validation_codes::E_INVALID_PREDICATE_TYPE,
2581 validation_codes::W_PREDICATE_NULLABILITY,
2582 format!(
2583 "HAVING clause expects a boolean predicate, found {}",
2584 type_family_name(family)
2585 ),
2586 ));
2587 }
2588 }
2589
2590 for join in &select.joins {
2591 if let Some(on) = &join.on {
2592 let family = infer_expression_type_family(on, schema_map, &context);
2593 if family != TypeFamily::Unknown && family != TypeFamily::Boolean {
2594 errors.push(type_issue(
2595 strict,
2596 validation_codes::E_INVALID_PREDICATE_TYPE,
2597 validation_codes::W_PREDICATE_NULLABILITY,
2598 format!(
2599 "JOIN ON expects a boolean predicate, found {}",
2600 type_family_name(family)
2601 ),
2602 ));
2603 }
2604 }
2605 if let Some(match_condition) = &join.match_condition {
2606 let family =
2607 infer_expression_type_family(match_condition, schema_map, &context);
2608 if family != TypeFamily::Unknown && family != TypeFamily::Boolean {
2609 errors.push(type_issue(
2610 strict,
2611 validation_codes::E_INVALID_PREDICATE_TYPE,
2612 validation_codes::W_PREDICATE_NULLABILITY,
2613 format!(
2614 "JOIN MATCH_CONDITION expects a boolean predicate, found {}",
2615 type_family_name(family)
2616 ),
2617 ));
2618 }
2619 }
2620 }
2621 }
2622 Expression::Where(where_clause) => {
2623 let family = infer_expression_type_family(&where_clause.this, schema_map, &context);
2624 if family != TypeFamily::Unknown && family != TypeFamily::Boolean {
2625 errors.push(type_issue(
2626 strict,
2627 validation_codes::E_INVALID_PREDICATE_TYPE,
2628 validation_codes::W_PREDICATE_NULLABILITY,
2629 format!(
2630 "WHERE clause expects a boolean predicate, found {}",
2631 type_family_name(family)
2632 ),
2633 ));
2634 }
2635 }
2636 Expression::Having(having_clause) => {
2637 let family =
2638 infer_expression_type_family(&having_clause.this, schema_map, &context);
2639 if family != TypeFamily::Unknown && family != TypeFamily::Boolean {
2640 errors.push(type_issue(
2641 strict,
2642 validation_codes::E_INVALID_PREDICATE_TYPE,
2643 validation_codes::W_PREDICATE_NULLABILITY,
2644 format!(
2645 "HAVING clause expects a boolean predicate, found {}",
2646 type_family_name(family)
2647 ),
2648 ));
2649 }
2650 }
2651 Expression::And(op) | Expression::Or(op) => {
2652 for (side, expr) in [("left", &op.left), ("right", &op.right)] {
2653 let family = infer_expression_type_family(expr, schema_map, &context);
2654 if family != TypeFamily::Unknown && family != TypeFamily::Boolean {
2655 errors.push(type_issue(
2656 strict,
2657 validation_codes::E_INVALID_PREDICATE_TYPE,
2658 validation_codes::W_PREDICATE_NULLABILITY,
2659 format!(
2660 "Logical {} operand expects boolean, found {}",
2661 side,
2662 type_family_name(family)
2663 ),
2664 ));
2665 }
2666 }
2667 }
2668 Expression::Not(unary) => {
2669 let family = infer_expression_type_family(&unary.this, schema_map, &context);
2670 if family != TypeFamily::Unknown && family != TypeFamily::Boolean {
2671 errors.push(type_issue(
2672 strict,
2673 validation_codes::E_INVALID_PREDICATE_TYPE,
2674 validation_codes::W_PREDICATE_NULLABILITY,
2675 format!("NOT expects boolean, found {}", type_family_name(family)),
2676 ));
2677 }
2678 }
2679 Expression::Eq(op)
2680 | Expression::Neq(op)
2681 | Expression::Lt(op)
2682 | Expression::Lte(op)
2683 | Expression::Gt(op)
2684 | Expression::Gte(op) => {
2685 let left = infer_expression_type_family(&op.left, schema_map, &context);
2686 let right = infer_expression_type_family(&op.right, schema_map, &context);
2687 if !are_comparable(left, right) {
2688 errors.push(type_issue(
2689 strict,
2690 validation_codes::E_INCOMPATIBLE_COMPARISON_TYPES,
2691 validation_codes::W_IMPLICIT_CAST_COMPARISON,
2692 format!(
2693 "Incompatible comparison between {} and {}",
2694 type_family_name(left),
2695 type_family_name(right)
2696 ),
2697 ));
2698 }
2699 }
2700 Expression::Like(op) | Expression::ILike(op) => {
2701 let left = infer_expression_type_family(&op.left, schema_map, &context);
2702 let right = infer_expression_type_family(&op.right, schema_map, &context);
2703 if left != TypeFamily::Unknown
2704 && right != TypeFamily::Unknown
2705 && (!is_string_like(left) || !is_string_like(right))
2706 {
2707 errors.push(type_issue(
2708 strict,
2709 validation_codes::E_INCOMPATIBLE_COMPARISON_TYPES,
2710 validation_codes::W_IMPLICIT_CAST_COMPARISON,
2711 format!(
2712 "LIKE/ILIKE expects string operands, found {} and {}",
2713 type_family_name(left),
2714 type_family_name(right)
2715 ),
2716 ));
2717 }
2718 }
2719 Expression::Between(between) => {
2720 let this_family = infer_expression_type_family(&between.this, schema_map, &context);
2721 let low_family = infer_expression_type_family(&between.low, schema_map, &context);
2722 let high_family = infer_expression_type_family(&between.high, schema_map, &context);
2723
2724 if !are_comparable(this_family, low_family)
2725 || !are_comparable(this_family, high_family)
2726 {
2727 errors.push(type_issue(
2728 strict,
2729 validation_codes::E_INCOMPATIBLE_COMPARISON_TYPES,
2730 validation_codes::W_IMPLICIT_CAST_COMPARISON,
2731 format!(
2732 "BETWEEN bounds are incompatible with {} (found {} and {})",
2733 type_family_name(this_family),
2734 type_family_name(low_family),
2735 type_family_name(high_family)
2736 ),
2737 ));
2738 }
2739 }
2740 Expression::In(in_expr) => {
2741 let this_family = infer_expression_type_family(&in_expr.this, schema_map, &context);
2742 for value in &in_expr.expressions {
2743 let item_family = infer_expression_type_family(value, schema_map, &context);
2744 if !are_comparable(this_family, item_family) {
2745 errors.push(type_issue(
2746 strict,
2747 validation_codes::E_INCOMPATIBLE_COMPARISON_TYPES,
2748 validation_codes::W_IMPLICIT_CAST_COMPARISON,
2749 format!(
2750 "IN item type {} is incompatible with {}",
2751 type_family_name(item_family),
2752 type_family_name(this_family)
2753 ),
2754 ));
2755 break;
2756 }
2757 }
2758 }
2759 Expression::Add(op)
2760 | Expression::Sub(op)
2761 | Expression::Mul(op)
2762 | Expression::Div(op)
2763 | Expression::Mod(op) => {
2764 let left = infer_expression_type_family(&op.left, schema_map, &context);
2765 let right = infer_expression_type_family(&op.right, schema_map, &context);
2766
2767 if left == TypeFamily::Unknown || right == TypeFamily::Unknown {
2768 continue;
2769 }
2770
2771 let temporal_ok = matches!(node, Expression::Add(_) | Expression::Sub(_))
2772 && ((left.is_temporal() && right.is_numeric())
2773 || (right.is_temporal() && left.is_numeric())
2774 || (matches!(node, Expression::Sub(_))
2775 && left.is_temporal()
2776 && right.is_temporal()));
2777
2778 if !(left.is_numeric() && right.is_numeric()) && !temporal_ok {
2779 errors.push(type_issue(
2780 strict,
2781 validation_codes::E_INVALID_ARITHMETIC_TYPE,
2782 validation_codes::W_IMPLICIT_CAST_ARITHMETIC,
2783 format!(
2784 "Arithmetic operation expects numeric-compatible operands, found {} and {}",
2785 type_family_name(left),
2786 type_family_name(right)
2787 ),
2788 ));
2789 }
2790 }
2791 Expression::Function(function) => {
2792 check_function_catalog(
2793 function,
2794 dialect,
2795 function_catalog,
2796 strict,
2797 &mut errors,
2798 );
2799 check_generic_function(function, schema_map, &context, strict, &mut errors);
2800 }
2801 Expression::Upper(func)
2802 | Expression::Lower(func)
2803 | Expression::LTrim(func)
2804 | Expression::RTrim(func)
2805 | Expression::Reverse(func) => {
2806 let family = infer_expression_type_family(&func.this, schema_map, &context);
2807 check_function_argument(
2808 &mut errors,
2809 strict,
2810 "string_function",
2811 0,
2812 family,
2813 "a string argument",
2814 is_string_like(family),
2815 );
2816 }
2817 Expression::Length(func) => {
2818 let family = infer_expression_type_family(&func.this, schema_map, &context);
2819 check_function_argument(
2820 &mut errors,
2821 strict,
2822 "length",
2823 0,
2824 family,
2825 "a string or binary argument",
2826 is_string_or_binary(family),
2827 );
2828 }
2829 Expression::Trim(func) => {
2830 let this_family = infer_expression_type_family(&func.this, schema_map, &context);
2831 check_function_argument(
2832 &mut errors,
2833 strict,
2834 "trim",
2835 0,
2836 this_family,
2837 "a string argument",
2838 is_string_like(this_family),
2839 );
2840 if let Some(chars) = &func.characters {
2841 let chars_family = infer_expression_type_family(chars, schema_map, &context);
2842 check_function_argument(
2843 &mut errors,
2844 strict,
2845 "trim",
2846 1,
2847 chars_family,
2848 "a string argument",
2849 is_string_like(chars_family),
2850 );
2851 }
2852 }
2853 Expression::Substring(func) => {
2854 let this_family = infer_expression_type_family(&func.this, schema_map, &context);
2855 check_function_argument(
2856 &mut errors,
2857 strict,
2858 "substring",
2859 0,
2860 this_family,
2861 "a string argument",
2862 is_string_like(this_family),
2863 );
2864
2865 let start_family = infer_expression_type_family(&func.start, schema_map, &context);
2866 check_function_argument(
2867 &mut errors,
2868 strict,
2869 "substring",
2870 1,
2871 start_family,
2872 "a numeric argument",
2873 start_family.is_numeric(),
2874 );
2875 if let Some(length) = &func.length {
2876 let length_family = infer_expression_type_family(length, schema_map, &context);
2877 check_function_argument(
2878 &mut errors,
2879 strict,
2880 "substring",
2881 2,
2882 length_family,
2883 "a numeric argument",
2884 length_family.is_numeric(),
2885 );
2886 }
2887 }
2888 Expression::Replace(func) => {
2889 for (arg, idx) in [
2890 (&func.this, 0_usize),
2891 (&func.old, 1_usize),
2892 (&func.new, 2_usize),
2893 ] {
2894 let family = infer_expression_type_family(arg, schema_map, &context);
2895 check_function_argument(
2896 &mut errors,
2897 strict,
2898 "replace",
2899 idx,
2900 family,
2901 "a string argument",
2902 is_string_like(family),
2903 );
2904 }
2905 }
2906 Expression::Left(func) | Expression::Right(func) => {
2907 let this_family = infer_expression_type_family(&func.this, schema_map, &context);
2908 check_function_argument(
2909 &mut errors,
2910 strict,
2911 "left_right",
2912 0,
2913 this_family,
2914 "a string argument",
2915 is_string_like(this_family),
2916 );
2917 let length_family =
2918 infer_expression_type_family(&func.length, schema_map, &context);
2919 check_function_argument(
2920 &mut errors,
2921 strict,
2922 "left_right",
2923 1,
2924 length_family,
2925 "a numeric argument",
2926 length_family.is_numeric(),
2927 );
2928 }
2929 Expression::Repeat(func) => {
2930 let this_family = infer_expression_type_family(&func.this, schema_map, &context);
2931 check_function_argument(
2932 &mut errors,
2933 strict,
2934 "repeat",
2935 0,
2936 this_family,
2937 "a string argument",
2938 is_string_like(this_family),
2939 );
2940 let times_family = infer_expression_type_family(&func.times, schema_map, &context);
2941 check_function_argument(
2942 &mut errors,
2943 strict,
2944 "repeat",
2945 1,
2946 times_family,
2947 "a numeric argument",
2948 times_family.is_numeric(),
2949 );
2950 }
2951 Expression::Lpad(func) | Expression::Rpad(func) => {
2952 let this_family = infer_expression_type_family(&func.this, schema_map, &context);
2953 check_function_argument(
2954 &mut errors,
2955 strict,
2956 "pad",
2957 0,
2958 this_family,
2959 "a string argument",
2960 is_string_like(this_family),
2961 );
2962 let length_family =
2963 infer_expression_type_family(&func.length, schema_map, &context);
2964 check_function_argument(
2965 &mut errors,
2966 strict,
2967 "pad",
2968 1,
2969 length_family,
2970 "a numeric argument",
2971 length_family.is_numeric(),
2972 );
2973 if let Some(fill) = &func.fill {
2974 let fill_family = infer_expression_type_family(fill, schema_map, &context);
2975 check_function_argument(
2976 &mut errors,
2977 strict,
2978 "pad",
2979 2,
2980 fill_family,
2981 "a string argument",
2982 is_string_like(fill_family),
2983 );
2984 }
2985 }
2986 Expression::Abs(func)
2987 | Expression::Sqrt(func)
2988 | Expression::Cbrt(func)
2989 | Expression::Ln(func)
2990 | Expression::Exp(func) => {
2991 let family = infer_expression_type_family(&func.this, schema_map, &context);
2992 check_function_argument(
2993 &mut errors,
2994 strict,
2995 "numeric_function",
2996 0,
2997 family,
2998 "a numeric argument",
2999 family.is_numeric(),
3000 );
3001 }
3002 Expression::Round(func) => {
3003 let this_family = infer_expression_type_family(&func.this, schema_map, &context);
3004 check_function_argument(
3005 &mut errors,
3006 strict,
3007 "round",
3008 0,
3009 this_family,
3010 "a numeric argument",
3011 this_family.is_numeric(),
3012 );
3013 if let Some(decimals) = &func.decimals {
3014 let decimals_family =
3015 infer_expression_type_family(decimals, schema_map, &context);
3016 check_function_argument(
3017 &mut errors,
3018 strict,
3019 "round",
3020 1,
3021 decimals_family,
3022 "a numeric argument",
3023 decimals_family.is_numeric(),
3024 );
3025 }
3026 }
3027 Expression::Floor(func) => {
3028 let this_family = infer_expression_type_family(&func.this, schema_map, &context);
3029 check_function_argument(
3030 &mut errors,
3031 strict,
3032 "floor",
3033 0,
3034 this_family,
3035 "a numeric argument",
3036 this_family.is_numeric(),
3037 );
3038 if let Some(scale) = &func.scale {
3039 let scale_family = infer_expression_type_family(scale, schema_map, &context);
3040 check_function_argument(
3041 &mut errors,
3042 strict,
3043 "floor",
3044 1,
3045 scale_family,
3046 "a numeric argument",
3047 scale_family.is_numeric(),
3048 );
3049 }
3050 }
3051 Expression::Ceil(func) => {
3052 let this_family = infer_expression_type_family(&func.this, schema_map, &context);
3053 check_function_argument(
3054 &mut errors,
3055 strict,
3056 "ceil",
3057 0,
3058 this_family,
3059 "a numeric argument",
3060 this_family.is_numeric(),
3061 );
3062 if let Some(decimals) = &func.decimals {
3063 let decimals_family =
3064 infer_expression_type_family(decimals, schema_map, &context);
3065 check_function_argument(
3066 &mut errors,
3067 strict,
3068 "ceil",
3069 1,
3070 decimals_family,
3071 "a numeric argument",
3072 decimals_family.is_numeric(),
3073 );
3074 }
3075 }
3076 Expression::Power(func) => {
3077 let left_family = infer_expression_type_family(&func.this, schema_map, &context);
3078 check_function_argument(
3079 &mut errors,
3080 strict,
3081 "power",
3082 0,
3083 left_family,
3084 "a numeric argument",
3085 left_family.is_numeric(),
3086 );
3087 let right_family =
3088 infer_expression_type_family(&func.expression, schema_map, &context);
3089 check_function_argument(
3090 &mut errors,
3091 strict,
3092 "power",
3093 1,
3094 right_family,
3095 "a numeric argument",
3096 right_family.is_numeric(),
3097 );
3098 }
3099 Expression::Log(func) => {
3100 let this_family = infer_expression_type_family(&func.this, schema_map, &context);
3101 check_function_argument(
3102 &mut errors,
3103 strict,
3104 "log",
3105 0,
3106 this_family,
3107 "a numeric argument",
3108 this_family.is_numeric(),
3109 );
3110 if let Some(base) = &func.base {
3111 let base_family = infer_expression_type_family(base, schema_map, &context);
3112 check_function_argument(
3113 &mut errors,
3114 strict,
3115 "log",
3116 1,
3117 base_family,
3118 "a numeric argument",
3119 base_family.is_numeric(),
3120 );
3121 }
3122 }
3123 _ => {}
3124 }
3125 }
3126
3127 errors
3128}
3129
3130fn check_semantics(stmt: &Expression) -> Vec<ValidationError> {
3131 let mut errors = Vec::new();
3132
3133 let Expression::Select(select) = stmt else {
3134 return errors;
3135 };
3136 let select_expr = Expression::Select(select.clone());
3137
3138 if !select_expr
3140 .find_all(|e| matches!(e, Expression::Star(_)))
3141 .is_empty()
3142 {
3143 errors.push(ValidationError::warning(
3144 "SELECT * is discouraged; specify columns explicitly for better performance and maintainability",
3145 validation_codes::W_SELECT_STAR,
3146 ));
3147 }
3148
3149 let aggregate_count = get_aggregate_functions(&select_expr).len();
3151 if aggregate_count > 0 && select.group_by.is_none() {
3152 let has_non_aggregate_column = select.expressions.iter().any(|expr| {
3153 matches!(expr, Expression::Column(_) | Expression::Identifier(_))
3154 && get_aggregate_functions(expr).is_empty()
3155 });
3156
3157 if has_non_aggregate_column {
3158 errors.push(ValidationError::warning(
3159 "Mixing aggregate functions with non-aggregated columns without GROUP BY may cause errors in strict SQL mode",
3160 validation_codes::W_AGGREGATE_WITHOUT_GROUP_BY,
3161 ));
3162 }
3163 }
3164
3165 if select.distinct && select.order_by.is_some() {
3167 errors.push(ValidationError::warning(
3168 "DISTINCT with ORDER BY: ensure ORDER BY columns are in SELECT list",
3169 validation_codes::W_DISTINCT_ORDER_BY,
3170 ));
3171 }
3172
3173 if select.limit.is_some() && select.order_by.is_none() {
3175 errors.push(ValidationError::warning(
3176 "LIMIT without ORDER BY produces non-deterministic results",
3177 validation_codes::W_LIMIT_WITHOUT_ORDER_BY,
3178 ));
3179 }
3180
3181 errors
3182}
3183
3184fn resolve_scope_source_name(scope: &crate::scope::Scope, name: &str) -> Option<String> {
3185 scope
3186 .sources
3187 .get_key_value(name)
3188 .map(|(k, _)| k.clone())
3189 .or_else(|| {
3190 scope
3191 .sources
3192 .keys()
3193 .find(|source| source.eq_ignore_ascii_case(name))
3194 .cloned()
3195 })
3196}
3197
/// Returns `true` when `columns` contains `column_name` (ASCII case-insensitive) or contains
/// the wildcard entry `"*"`, which matches any column.
fn source_has_column(columns: &[String], column_name: &str) -> bool {
    for known in columns {
        let is_wildcard = known.as_str() == "*";
        if is_wildcard || known.eq_ignore_ascii_case(column_name) {
            return true;
        }
    }
    false
}
3203
3204fn source_display_name(scope: &crate::scope::Scope, source_name: &str) -> String {
3205 scope
3206 .sources
3207 .get(source_name)
3208 .map(|source| match &source.expression {
3209 Expression::Table(table) => lower(&table_ref_display_name(table)),
3210 _ => lower(source_name),
3211 })
3212 .unwrap_or_else(|| lower(source_name))
3213}
3214
3215fn validate_select_columns_with_schema(
3216 select: &crate::expressions::Select,
3217 schema_map: &HashMap<String, TableSchemaEntry>,
3218 resolver_schema: &MappingSchema,
3219 strict: bool,
3220) -> Vec<ValidationError> {
3221 let mut errors = Vec::new();
3222 let select_expr = Expression::Select(Box::new(select.clone()));
3223 let scope = build_scope(&select_expr);
3224 let mut resolver = Resolver::new(&scope, resolver_schema, true);
3225 let source_names: Vec<String> = scope.sources.keys().cloned().collect();
3226
3227 for node in walk_in_scope(&select_expr, false) {
3228 let Expression::Column(column) = node else {
3229 continue;
3230 };
3231
3232 let col_name = lower(&column.name.name);
3233 if col_name.is_empty() {
3234 continue;
3235 }
3236
3237 if let Some(table) = &column.table {
3238 let Some(source_name) = resolve_scope_source_name(&scope, &table.name) else {
3239 errors.push(if strict {
3241 ValidationError::error(
3242 format!(
3243 "Unknown table or alias '{}' referenced by column '{}'",
3244 table.name, col_name
3245 ),
3246 validation_codes::E_UNRESOLVED_REFERENCE,
3247 )
3248 } else {
3249 ValidationError::warning(
3250 format!(
3251 "Unknown table or alias '{}' referenced by column '{}'",
3252 table.name, col_name
3253 ),
3254 validation_codes::E_UNRESOLVED_REFERENCE,
3255 )
3256 });
3257 continue;
3258 };
3259
3260 if let Ok(columns) = resolver.get_source_columns(&source_name) {
3261 if !columns.is_empty() && !source_has_column(&columns, &col_name) {
3262 let table_name = source_display_name(&scope, &source_name);
3263 errors.push(if strict {
3264 ValidationError::error(
3265 format!("Unknown column '{}' in table '{}'", col_name, table_name),
3266 validation_codes::E_UNKNOWN_COLUMN,
3267 )
3268 } else {
3269 ValidationError::warning(
3270 format!("Unknown column '{}' in table '{}'", col_name, table_name),
3271 validation_codes::E_UNKNOWN_COLUMN,
3272 )
3273 });
3274 }
3275 }
3276 continue;
3277 }
3278
3279 let matching_sources: Vec<String> = source_names
3280 .iter()
3281 .filter_map(|source_name| {
3282 resolver
3283 .get_source_columns(source_name)
3284 .ok()
3285 .filter(|columns| !columns.is_empty() && source_has_column(columns, &col_name))
3286 .map(|_| source_name.clone())
3287 })
3288 .collect();
3289
3290 if !matching_sources.is_empty() {
3291 continue;
3292 }
3293
3294 let known_sources: Vec<String> = source_names
3295 .iter()
3296 .filter_map(|source_name| {
3297 resolver
3298 .get_source_columns(source_name)
3299 .ok()
3300 .filter(|columns| !columns.is_empty() && !columns.iter().any(|c| c == "*"))
3301 .map(|_| source_name.clone())
3302 })
3303 .collect();
3304
3305 if known_sources.len() == 1 {
3306 let table_name = source_display_name(&scope, &known_sources[0]);
3307 errors.push(if strict {
3308 ValidationError::error(
3309 format!("Unknown column '{}' in table '{}'", col_name, table_name),
3310 validation_codes::E_UNKNOWN_COLUMN,
3311 )
3312 } else {
3313 ValidationError::warning(
3314 format!("Unknown column '{}' in table '{}'", col_name, table_name),
3315 validation_codes::E_UNKNOWN_COLUMN,
3316 )
3317 });
3318 } else if known_sources.len() > 1 {
3319 errors.push(if strict {
3320 ValidationError::error(
3321 format!(
3322 "Unknown column '{}' (not found in any referenced table)",
3323 col_name
3324 ),
3325 validation_codes::E_UNKNOWN_COLUMN,
3326 )
3327 } else {
3328 ValidationError::warning(
3329 format!(
3330 "Unknown column '{}' (not found in any referenced table)",
3331 col_name
3332 ),
3333 validation_codes::E_UNKNOWN_COLUMN,
3334 )
3335 });
3336 } else if !schema_map.is_empty() {
3337 let found = schema_map
3338 .values()
3339 .any(|table_schema| table_schema.columns.contains_key(&col_name));
3340 if !found {
3341 errors.push(if strict {
3342 ValidationError::error(
3343 format!("Unknown column '{}'", col_name),
3344 validation_codes::E_UNKNOWN_COLUMN,
3345 )
3346 } else {
3347 ValidationError::warning(
3348 format!("Unknown column '{}'", col_name),
3349 validation_codes::E_UNKNOWN_COLUMN,
3350 )
3351 });
3352 }
3353 }
3354 }
3355
3356 errors
3357}
3358
3359fn validate_statement_with_schema(
3360 stmt: &Expression,
3361 schema_map: &HashMap<String, TableSchemaEntry>,
3362 resolver_schema: &MappingSchema,
3363 strict: bool,
3364) -> Vec<ValidationError> {
3365 let mut errors = Vec::new();
3366 let cte_aliases = collect_cte_aliases(stmt);
3367 let mut seen_tables: HashSet<String> = HashSet::new();
3368
3369 for node in stmt.find_all(|e| matches!(e, Expression::Table(_))) {
3371 let Expression::Table(table) = node else {
3372 continue;
3373 };
3374
3375 if cte_aliases.contains(&lower(&table.name.name)) {
3376 continue;
3377 }
3378
3379 let resolved_key = table_ref_candidates(table)
3380 .into_iter()
3381 .find(|k| schema_map.contains_key(k));
3382 let table_key = resolved_key
3383 .clone()
3384 .unwrap_or_else(|| lower(&table_ref_display_name(table)));
3385
3386 if !seen_tables.insert(table_key) {
3387 continue;
3388 }
3389
3390 if resolved_key.is_none() {
3391 errors.push(if strict {
3392 ValidationError::error(
3393 format!("Unknown table '{}'", table_ref_display_name(table)),
3394 validation_codes::E_UNKNOWN_TABLE,
3395 )
3396 } else {
3397 ValidationError::warning(
3398 format!("Unknown table '{}'", table_ref_display_name(table)),
3399 validation_codes::E_UNKNOWN_TABLE,
3400 )
3401 });
3402 }
3403 }
3404
3405 for node in stmt.dfs() {
3406 let Expression::Select(select) = node else {
3407 continue;
3408 };
3409 errors.extend(validate_select_columns_with_schema(
3410 select,
3411 schema_map,
3412 resolver_schema,
3413 strict,
3414 ));
3415 }
3416
3417 errors
3418}
3419
3420pub fn validate_with_schema(
3422 sql: &str,
3423 dialect: DialectType,
3424 schema: &ValidationSchema,
3425 options: &SchemaValidationOptions,
3426) -> ValidationResult {
3427 let strict = options.strict.unwrap_or(schema.strict.unwrap_or(true));
3428
3429 let syntax_result = crate::validate_with_options(
3431 sql,
3432 dialect,
3433 &crate::ValidationOptions {
3434 strict_syntax: options.strict_syntax,
3435 },
3436 );
3437 if !syntax_result.valid {
3438 return syntax_result;
3439 }
3440
3441 let d = Dialect::get(dialect);
3442 let statements = match d.parse(sql) {
3443 Ok(exprs) => exprs,
3444 Err(e) => {
3445 return ValidationResult::with_errors(vec![ValidationError::error(
3446 e.to_string(),
3447 validation_codes::E_PARSE_OR_OPTIONS,
3448 )]);
3449 }
3450 };
3451
3452 let schema_map = build_schema_map(schema);
3453 let resolver_schema = build_resolver_schema(schema);
3454 let mut all_errors = syntax_result.errors;
3455 let embedded_function_catalog = if options.check_types && options.function_catalog.is_none() {
3456 default_embedded_function_catalog()
3457 } else {
3458 None
3459 };
3460 let effective_function_catalog = options
3461 .function_catalog
3462 .as_deref()
3463 .or_else(|| embedded_function_catalog.as_deref());
3464 let declared_relationships = if options.check_references {
3465 build_declared_relationships(schema, &schema_map)
3466 } else {
3467 Vec::new()
3468 };
3469
3470 if options.check_references {
3471 all_errors.extend(check_reference_integrity(schema, &schema_map, strict));
3472 }
3473
3474 for stmt in &statements {
3475 if options.semantic {
3476 all_errors.extend(check_semantics(stmt));
3477 }
3478 all_errors.extend(validate_statement_with_schema(
3479 stmt,
3480 &schema_map,
3481 &resolver_schema,
3482 strict,
3483 ));
3484 if options.check_types {
3485 all_errors.extend(check_types(
3486 stmt,
3487 dialect,
3488 &schema_map,
3489 effective_function_catalog,
3490 strict,
3491 ));
3492 }
3493 if options.check_references {
3494 all_errors.extend(check_query_reference_quality(
3495 stmt,
3496 &schema_map,
3497 &resolver_schema,
3498 strict,
3499 &declared_relationships,
3500 ));
3501 }
3502 }
3503
3504 ValidationResult::with_errors(all_errors)
3505}
3506
3507#[cfg(test)]
3508mod tests;