1use std::collections::HashMap;
36
37use crate::ast::DataType;
38use crate::dialects::Dialect;
39use crate::errors::SqlglotError;
40
/// Errors reported by schema registration and lookup operations.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum SchemaError {
    /// No table could be resolved for the given path. The payload is the
    /// dot-joined path as supplied by the caller.
    TableNotFound(String),
    /// The table was resolved, but it has no column with the requested name.
    ColumnNotFound {
        /// Dot-joined table path as supplied by the caller.
        table: String,
        /// Column name as supplied by the caller.
        column: String,
    },
    /// A table is already registered under the given path.
    DuplicateTable(String),
}
51
52impl std::fmt::Display for SchemaError {
53 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
54 match self {
55 SchemaError::TableNotFound(t) => write!(f, "Table not found: {t}"),
56 SchemaError::ColumnNotFound { table, column } => {
57 write!(f, "Column '{column}' not found in table '{table}'")
58 }
59 SchemaError::DuplicateTable(t) => write!(f, "Table already exists: {t}"),
60 }
61 }
62}
63
64impl std::error::Error for SchemaError {}
65
66impl From<SchemaError> for SqlglotError {
67 fn from(e: SchemaError) -> Self {
68 SqlglotError::Internal(e.to_string())
69 }
70}
71
72pub trait Schema {
77 fn add_table(
87 &mut self,
88 table_path: &[&str],
89 columns: Vec<(String, DataType)>,
90 ) -> Result<(), SchemaError>;
91
92 fn column_names(&self, table_path: &[&str]) -> Result<Vec<String>, SchemaError>;
98
99 fn get_column_type(&self, table_path: &[&str], column: &str) -> Result<DataType, SchemaError>;
105
106 fn has_column(&self, table_path: &[&str], column: &str) -> bool;
108
109 fn dialect(&self) -> Dialect;
111}
112
113#[derive(Debug, Clone, PartialEq)]
115struct ColumnInfo {
116 columns: Vec<(String, DataType)>,
118 index: HashMap<String, usize>,
120}
121
122impl ColumnInfo {
123 fn new(columns: Vec<(String, DataType)>, dialect: Dialect) -> Self {
124 let index = columns
125 .iter()
126 .enumerate()
127 .map(|(i, (name, _))| (normalize_identifier(name, dialect), i))
128 .collect();
129 Self { columns, index }
130 }
131
132 fn column_names(&self) -> Vec<String> {
133 self.columns.iter().map(|(n, _)| n.clone()).collect()
134 }
135
136 fn get_type(&self, column: &str, dialect: Dialect) -> Option<&DataType> {
137 let key = normalize_identifier(column, dialect);
138 self.index.get(&key).map(|&i| &self.columns[i].1)
139 }
140
141 fn has_column(&self, column: &str, dialect: Dialect) -> bool {
142 let key = normalize_identifier(column, dialect);
143 self.index.contains_key(&key)
144 }
145}
146
147#[derive(Debug, Clone)]
157pub struct MappingSchema {
158 dialect: Dialect,
159 tables: HashMap<String, HashMap<String, HashMap<String, ColumnInfo>>>,
161 udf_types: HashMap<String, DataType>,
163}
164
165impl MappingSchema {
166 #[must_use]
168 pub fn new(dialect: Dialect) -> Self {
169 Self {
170 dialect,
171 tables: HashMap::new(),
172 udf_types: HashMap::new(),
173 }
174 }
175
176 pub fn replace_table(
178 &mut self,
179 table_path: &[&str],
180 columns: Vec<(String, DataType)>,
181 ) -> Result<(), SchemaError> {
182 let (catalog, database, table) = self.resolve_path(table_path)?;
183 let info = ColumnInfo::new(columns, self.dialect);
184 self.tables
185 .entry(catalog)
186 .or_default()
187 .entry(database)
188 .or_default()
189 .insert(table, info);
190 Ok(())
191 }
192
193 pub fn remove_table(&mut self, table_path: &[&str]) -> Result<bool, SchemaError> {
195 let (catalog, database, table) = self.resolve_path(table_path)?;
196 let removed = self
197 .tables
198 .get_mut(&catalog)
199 .and_then(|dbs| dbs.get_mut(&database))
200 .map(|tbls| tbls.remove(&table).is_some())
201 .unwrap_or(false);
202 Ok(removed)
203 }
204
205 pub fn add_udf(&mut self, name: &str, return_type: DataType) {
207 let key = normalize_identifier(name, self.dialect);
208 self.udf_types.insert(key, return_type);
209 }
210
211 #[must_use]
213 pub fn get_udf_type(&self, name: &str) -> Option<&DataType> {
214 let key = normalize_identifier(name, self.dialect);
215 self.udf_types.get(&key)
216 }
217
218 #[must_use]
220 pub fn table_names(&self) -> Vec<(String, String, String)> {
221 let mut result = Vec::new();
222 for (catalog, dbs) in &self.tables {
223 for (database, tbls) in dbs {
224 for table in tbls.keys() {
225 result.push((catalog.clone(), database.clone(), table.clone()));
226 }
227 }
228 }
229 result
230 }
231
232 fn find_table(&self, table_path: &[&str]) -> Option<&ColumnInfo> {
235 let (catalog, database, table) = match self.resolve_path(table_path) {
236 Ok(parts) => parts,
237 Err(_) => return None,
238 };
239
240 if let Some(info) = self
242 .tables
243 .get(&catalog)
244 .and_then(|dbs| dbs.get(&database))
245 .and_then(|tbls| tbls.get(&table))
246 {
247 return Some(info);
248 }
249
250 if table_path.len() == 1 {
252 let norm_name = normalize_identifier(table_path[0], self.dialect);
253 for dbs in self.tables.values() {
254 for tbls in dbs.values() {
255 if let Some(info) = tbls.get(&norm_name) {
256 return Some(info);
257 }
258 }
259 }
260 }
261
262 if table_path.len() == 2 {
264 let norm_db = normalize_identifier(table_path[0], self.dialect);
265 let norm_tbl = normalize_identifier(table_path[1], self.dialect);
266 for dbs in self.tables.values() {
267 if let Some(info) = dbs.get(&norm_db).and_then(|tbls| tbls.get(&norm_tbl)) {
268 return Some(info);
269 }
270 }
271 }
272
273 None
274 }
275
276 fn resolve_path(&self, table_path: &[&str]) -> Result<(String, String, String), SchemaError> {
279 match table_path.len() {
280 1 => Ok((
281 String::new(),
282 String::new(),
283 normalize_identifier(table_path[0], self.dialect),
284 )),
285 2 => Ok((
286 String::new(),
287 normalize_identifier(table_path[0], self.dialect),
288 normalize_identifier(table_path[1], self.dialect),
289 )),
290 3 => Ok((
291 normalize_identifier(table_path[0], self.dialect),
292 normalize_identifier(table_path[1], self.dialect),
293 normalize_identifier(table_path[2], self.dialect),
294 )),
295 _ => Err(SchemaError::TableNotFound(table_path.join("."))),
296 }
297 }
298
299 fn format_table_path(table_path: &[&str]) -> String {
300 table_path.join(".")
301 }
302}
303
304impl Schema for MappingSchema {
305 fn add_table(
306 &mut self,
307 table_path: &[&str],
308 columns: Vec<(String, DataType)>,
309 ) -> Result<(), SchemaError> {
310 let (catalog, database, table) = self.resolve_path(table_path)?;
311 let entry = self
312 .tables
313 .entry(catalog)
314 .or_default()
315 .entry(database)
316 .or_default();
317
318 if entry.contains_key(&table) {
319 return Err(SchemaError::DuplicateTable(Self::format_table_path(
320 table_path,
321 )));
322 }
323
324 let info = ColumnInfo::new(columns, self.dialect);
325 entry.insert(table, info);
326 Ok(())
327 }
328
329 fn column_names(&self, table_path: &[&str]) -> Result<Vec<String>, SchemaError> {
330 self.find_table(table_path)
331 .map(|info| info.column_names())
332 .ok_or_else(|| SchemaError::TableNotFound(Self::format_table_path(table_path)))
333 }
334
335 fn get_column_type(&self, table_path: &[&str], column: &str) -> Result<DataType, SchemaError> {
336 let table_str = Self::format_table_path(table_path);
337 let info = self
338 .find_table(table_path)
339 .ok_or_else(|| SchemaError::TableNotFound(table_str.clone()))?;
340
341 info.get_type(column, self.dialect)
342 .cloned()
343 .ok_or(SchemaError::ColumnNotFound {
344 table: table_str,
345 column: column.to_string(),
346 })
347 }
348
349 fn has_column(&self, table_path: &[&str], column: &str) -> bool {
350 self.find_table(table_path)
351 .is_some_and(|info| info.has_column(column, self.dialect))
352 }
353
354 fn dialect(&self) -> Dialect {
355 self.dialect
356 }
357}
358
359#[must_use]
370pub fn normalize_identifier(name: &str, dialect: Dialect) -> String {
371 if is_case_sensitive_dialect(dialect) {
372 name.to_string()
373 } else {
374 name.to_lowercase()
375 }
376}
377
378#[must_use]
380pub fn is_case_sensitive_dialect(dialect: Dialect) -> bool {
381 matches!(
382 dialect,
383 Dialect::BigQuery | Dialect::Hive | Dialect::Spark | Dialect::Databricks
384 )
385}
386
387pub fn ensure_schema(
414 tables: HashMap<String, HashMap<String, DataType>>,
415 dialect: Dialect,
416) -> MappingSchema {
417 let mut schema = MappingSchema::new(dialect);
418 for (table_name, columns) in tables {
419 let col_vec: Vec<(String, DataType)> = columns.into_iter().collect();
420 let _ = schema.replace_table(&[&table_name], col_vec);
422 }
423 schema
424}
425
426pub type CatalogMap = HashMap<String, HashMap<String, HashMap<String, HashMap<String, DataType>>>>;
429
430pub fn ensure_schema_nested(catalog_map: CatalogMap, dialect: Dialect) -> MappingSchema {
433 let mut schema = MappingSchema::new(dialect);
434 for (catalog, databases) in catalog_map {
435 for (database, tables) in databases {
436 for (table, columns) in tables {
437 let col_vec: Vec<(String, DataType)> = columns.into_iter().collect();
438 let _ = schema.replace_table(&[&catalog, &database, &table], col_vec);
439 }
440 }
441 }
442 schema
443}
444
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an owned `(name, type)` column entry from a string literal.
    fn col(name: &str, data_type: DataType) -> (String, DataType) {
        (name.to_string(), data_type)
    }

    /// A registered table can be queried through every lookup method.
    #[test]
    fn test_add_and_query_table() {
        let mut schema = MappingSchema::new(Dialect::Ansi);
        let columns = vec![
            col("id", DataType::Int),
            col("name", DataType::Varchar(Some(255))),
            col("email", DataType::Text),
        ];
        schema.add_table(&["users"], columns).unwrap();

        let names = schema.column_names(&["users"]).unwrap();
        assert_eq!(names, vec!["id", "name", "email"]);

        let id_type = schema.get_column_type(&["users"], "id").unwrap();
        assert_eq!(id_type, DataType::Int);

        let name_type = schema.get_column_type(&["users"], "name").unwrap();
        assert_eq!(name_type, DataType::Varchar(Some(255)));

        assert!(schema.has_column(&["users"], "id"));
        assert!(schema.has_column(&["users"], "email"));
        assert!(!schema.has_column(&["users"], "nonexistent"));
    }

    /// Adding the same table twice is rejected.
    #[test]
    fn test_duplicate_table_error() {
        let mut schema = MappingSchema::new(Dialect::Ansi);
        schema
            .add_table(&["t"], vec![col("a", DataType::Int)])
            .unwrap();

        let second = schema.add_table(&["t"], vec![col("b", DataType::Text)]);
        assert!(matches!(
            second.unwrap_err(),
            SchemaError::DuplicateTable(_)
        ));
    }

    /// `replace_table` overwrites an existing definition.
    #[test]
    fn test_replace_table() {
        let mut schema = MappingSchema::new(Dialect::Ansi);
        schema
            .add_table(&["t"], vec![col("a", DataType::Int)])
            .unwrap();
        schema
            .replace_table(&["t"], vec![col("b", DataType::Text)])
            .unwrap();

        let names = schema.column_names(&["t"]).unwrap();
        assert_eq!(names, vec!["b"]);
        let b_type = schema.get_column_type(&["t"], "b").unwrap();
        assert_eq!(b_type, DataType::Text);
    }

    /// Removal reports whether a table was present and makes it unresolvable.
    #[test]
    fn test_remove_table() {
        let mut schema = MappingSchema::new(Dialect::Ansi);
        schema
            .add_table(&["t"], vec![col("a", DataType::Int)])
            .unwrap();

        let first = schema.remove_table(&["t"]).unwrap();
        assert!(first);
        let second = schema.remove_table(&["t"]).unwrap();
        assert!(!second);
        assert!(schema.column_names(&["t"]).is_err());
    }

    /// Looking up an unknown table yields `TableNotFound`.
    #[test]
    fn test_table_not_found() {
        let schema = MappingSchema::new(Dialect::Ansi);
        let result = schema.column_names(&["nonexistent"]);
        assert!(matches!(
            result.unwrap_err(),
            SchemaError::TableNotFound(_)
        ));
    }

    /// Looking up an unknown column yields `ColumnNotFound`.
    #[test]
    fn test_column_not_found() {
        let mut schema = MappingSchema::new(Dialect::Ansi);
        schema
            .add_table(&["t"], vec![col("a", DataType::Int)])
            .unwrap();

        let result = schema.get_column_type(&["t"], "z");
        assert!(matches!(
            result.unwrap_err(),
            SchemaError::ColumnNotFound { .. }
        ));
    }

    /// Fully qualified catalog.database.table paths round-trip.
    #[test]
    fn test_three_level_path() {
        let path = ["my_catalog", "my_db", "orders"];
        let total_type = DataType::Decimal {
            precision: Some(10),
            scale: Some(2),
        };

        let mut schema = MappingSchema::new(Dialect::Ansi);
        schema
            .add_table(
                &path,
                vec![col("order_id", DataType::BigInt), col("total", total_type)],
            )
            .unwrap();

        let names = schema.column_names(&path).unwrap();
        assert_eq!(names, vec!["order_id", "total"]);
        assert!(schema.has_column(&path, "order_id"));
    }

    /// database.table paths round-trip.
    #[test]
    fn test_two_level_path() {
        let path = ["public", "users"];
        let mut schema = MappingSchema::new(Dialect::Ansi);
        schema
            .add_table(&path, vec![col("id", DataType::Int)])
            .unwrap();

        let names = schema.column_names(&path).unwrap();
        assert_eq!(names, vec!["id"]);
    }

    /// Short paths fall back to searching all catalogs / databases.
    #[test]
    fn test_short_path_searches_all() {
        let mut schema = MappingSchema::new(Dialect::Ansi);
        schema
            .add_table(&["catalog", "db", "orders"], vec![col("id", DataType::Int)])
            .unwrap();

        assert!(schema.has_column(&["orders"], "id"));
        let names = schema.column_names(&["orders"]).unwrap();
        assert_eq!(names, vec!["id"]);

        assert!(schema.has_column(&["db", "orders"], "id"));
    }

    /// Postgres lookups ignore identifier case.
    #[test]
    fn test_case_insensitive_dialect() {
        let mut schema = MappingSchema::new(Dialect::Postgres);
        schema
            .add_table(&["Users"], vec![col("ID", DataType::Int)])
            .unwrap();

        for (table, column) in [("users", "id"), ("USERS", "ID"), ("Users", "Id")] {
            assert!(schema.has_column(&[table], column));
        }
        let id_type = schema.get_column_type(&["users"], "id").unwrap();
        assert_eq!(id_type, DataType::Int);
    }

    /// BigQuery lookups respect identifier case.
    #[test]
    fn test_case_sensitive_dialect() {
        let mut schema = MappingSchema::new(Dialect::BigQuery);
        schema
            .add_table(&["Users"], vec![col("ID", DataType::Int)])
            .unwrap();

        assert!(schema.has_column(&["Users"], "ID"));
        assert!(!schema.has_column(&["users"], "ID"));
        assert!(!schema.has_column(&["Users"], "id"));
    }

    /// Hive lookups respect identifier case.
    #[test]
    fn test_hive_case_sensitive() {
        let mut schema = MappingSchema::new(Dialect::Hive);
        schema
            .add_table(&["MyTable"], vec![col("Col1", DataType::Text)])
            .unwrap();

        assert!(schema.has_column(&["MyTable"], "Col1"));
        assert!(!schema.has_column(&["mytable"], "col1"));
    }

    /// UDF return types are stored and looked up case-insensitively on Ansi.
    #[test]
    fn test_udf_return_type() {
        let mut schema = MappingSchema::new(Dialect::Ansi);
        schema.add_udf("my_custom_fn", DataType::Int);

        for name in ["my_custom_fn", "MY_CUSTOM_FN"] {
            assert_eq!(schema.get_udf_type(name).unwrap(), &DataType::Int);
        }
        assert!(schema.get_udf_type("nonexistent").is_none());
    }

    /// UDF names respect case on BigQuery.
    #[test]
    fn test_udf_case_sensitive() {
        let mut schema = MappingSchema::new(Dialect::BigQuery);
        schema.add_udf("myFunc", DataType::Boolean);

        assert!(schema.get_udf_type("myFunc").is_some());
        assert!(schema.get_udf_type("MYFUNC").is_none());
    }

    /// `ensure_schema` registers every table of a flat map.
    #[test]
    fn test_ensure_schema() {
        let cols = HashMap::from([
            ("id".to_string(), DataType::Int),
            ("name".to_string(), DataType::Text),
        ]);
        let tables = HashMap::from([("users".to_string(), cols)]);

        let schema = ensure_schema(tables, Dialect::Postgres);
        for column in ["id", "name"] {
            assert!(schema.has_column(&["users"], column));
        }
    }

    /// `ensure_schema_nested` registers tables under their full path.
    #[test]
    fn test_ensure_schema_nested() {
        let cols = HashMap::from([("order_id".to_string(), DataType::BigInt)]);
        let tables = HashMap::from([("orders".to_string(), cols)]);
        let databases = HashMap::from([("sales".to_string(), tables)]);
        let catalogs = HashMap::from([("warehouse".to_string(), databases)]);

        let schema = ensure_schema_nested(catalogs, Dialect::Ansi);
        assert!(schema.has_column(&["warehouse", "sales", "orders"], "order_id"));
        assert!(schema.has_column(&["orders"], "order_id"));
    }

    /// `table_names` lists every registered table triple.
    #[test]
    fn test_table_names() {
        let mut schema = MappingSchema::new(Dialect::Ansi);
        schema
            .add_table(&["cat", "db", "t1"], vec![col("a", DataType::Int)])
            .unwrap();
        schema
            .add_table(&["cat", "db", "t2"], vec![col("b", DataType::Int)])
            .unwrap();

        let mut names = schema.table_names();
        names.sort();
        assert_eq!(names.len(), 2);
        for expected in ["t1", "t2"] {
            let found = names
                .iter()
                .any(|(c, d, t)| c == "cat" && d == "db" && t == expected);
            assert!(found);
        }
    }

    /// Paths with more than three components are rejected.
    #[test]
    fn test_invalid_path_too_many_parts() {
        let mut schema = MappingSchema::new(Dialect::Ansi);
        let result = schema.add_table(&["a", "b", "c", "d"], vec![col("x", DataType::Int)]);
        assert!(matches!(
            result.unwrap_err(),
            SchemaError::TableNotFound(_)
        ));
    }

    /// An empty schema reports no columns for any table.
    #[test]
    fn test_empty_schema_has_no_columns() {
        let schema = MappingSchema::new(Dialect::Ansi);
        assert!(!schema.has_column(&["any_table"], "any_col"));
    }

    /// Each error variant renders its expected message.
    #[test]
    fn test_schema_error_display() {
        let cases = [
            (
                SchemaError::TableNotFound("users".to_string()),
                "Table not found: users",
            ),
            (
                SchemaError::ColumnNotFound {
                    table: "users".to_string(),
                    column: "age".to_string(),
                },
                "Column 'age' not found in table 'users'",
            ),
            (
                SchemaError::DuplicateTable("users".to_string()),
                "Table already exists: users",
            ),
        ];
        for (error, expected) in cases {
            assert_eq!(error.to_string(), expected);
        }
    }

    /// Schema errors convert into the crate-wide internal error.
    #[test]
    fn test_schema_error_into_sqlglot_error() {
        let converted = SqlglotError::from(SchemaError::TableNotFound("t".to_string()));
        assert!(matches!(converted, SqlglotError::Internal(_)));
    }

    /// Normalization differs between case-insensitive and case-sensitive dialects.
    #[test]
    fn test_multiple_dialects_normalization() {
        for dialect in [Dialect::Postgres, Dialect::Mysql] {
            let mut schema = MappingSchema::new(dialect);
            schema
                .add_table(&["T"], vec![col("C", DataType::Int)])
                .unwrap();
            assert!(schema.has_column(&["t"], "c"));
        }

        let mut spark = MappingSchema::new(Dialect::Spark);
        spark
            .add_table(&["T"], vec![col("C", DataType::Int)])
            .unwrap();
        assert!(!spark.has_column(&["t"], "c"));
        assert!(spark.has_column(&["T"], "C"));
    }

    /// Nested / parameterized data types are stored and returned intact.
    #[test]
    fn test_complex_data_types() {
        let tags_type = DataType::Array(Some(Box::new(DataType::Text)));
        let coords_type = DataType::Struct(vec![
            col("lat", DataType::Double),
            col("lng", DataType::Double),
        ]);
        let lookup_type = DataType::Map {
            key: Box::new(DataType::Text),
            value: Box::new(DataType::Int),
        };

        let mut schema = MappingSchema::new(Dialect::Ansi);
        schema
            .add_table(
                &["complex_table"],
                vec![
                    col("tags", tags_type),
                    col("metadata", DataType::Json),
                    col("coords", coords_type),
                    col("lookup", lookup_type),
                ],
            )
            .unwrap();

        let tags = schema.get_column_type(&["complex_table"], "tags").unwrap();
        assert_eq!(tags, DataType::Array(Some(Box::new(DataType::Text))));

        let metadata = schema
            .get_column_type(&["complex_table"], "metadata")
            .unwrap();
        assert_eq!(metadata, DataType::Json);
    }

    /// The schema reports the dialect it was created with.
    #[test]
    fn test_schema_dialect() {
        let schema = MappingSchema::new(Dialect::Snowflake);
        assert_eq!(schema.dialect(), Dialect::Snowflake);
    }
}