1use std::collections::HashMap;
36
37use crate::ast::DataType;
38use crate::dialects::Dialect;
39use crate::errors::SqlglotError;
40
/// Errors produced by schema registration and lookup.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum SchemaError {
    /// No table could be resolved for a path; the payload is the path text.
    TableNotFound(String),
    /// The table exists but has no column with the requested name.
    ColumnNotFound { table: String, column: String },
    /// A table is already registered under the path carried in the payload.
    DuplicateTable(String),
}

impl std::fmt::Display for SchemaError {
    /// Renders the human-readable message for each variant.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::DuplicateTable(path) => write!(f, "Table already exists: {path}"),
            Self::ColumnNotFound { table, column } => {
                write!(f, "Column '{column}' not found in table '{table}'")
            }
            Self::TableNotFound(path) => write!(f, "Table not found: {path}"),
        }
    }
}

// No underlying source error is carried, so the default trait methods suffice.
impl std::error::Error for SchemaError {}
65
66impl From<SchemaError> for SqlglotError {
67 fn from(e: SchemaError) -> Self {
68 SqlglotError::Internal(e.to_string())
69 }
70}
71
72pub trait Schema {
77 fn add_table(
87 &mut self,
88 table_path: &[&str],
89 columns: Vec<(String, DataType)>,
90 ) -> Result<(), SchemaError>;
91
92 fn column_names(&self, table_path: &[&str]) -> Result<Vec<String>, SchemaError>;
98
99 fn get_column_type(&self, table_path: &[&str], column: &str) -> Result<DataType, SchemaError>;
105
106 fn has_column(&self, table_path: &[&str], column: &str) -> bool;
108
109 fn dialect(&self) -> Dialect;
111
112 fn get_udf_type(&self, _name: &str) -> Option<&DataType> {
117 None
118 }
119}
120
121#[derive(Debug, Clone, PartialEq)]
123struct ColumnInfo {
124 columns: Vec<(String, DataType)>,
126 index: HashMap<String, usize>,
128}
129
130impl ColumnInfo {
131 fn new(columns: Vec<(String, DataType)>, dialect: Dialect) -> Self {
132 let index = columns
133 .iter()
134 .enumerate()
135 .map(|(i, (name, _))| (normalize_identifier(name, dialect), i))
136 .collect();
137 Self { columns, index }
138 }
139
140 fn column_names(&self) -> Vec<String> {
141 self.columns.iter().map(|(n, _)| n.clone()).collect()
142 }
143
144 fn get_type(&self, column: &str, dialect: Dialect) -> Option<&DataType> {
145 let key = normalize_identifier(column, dialect);
146 self.index.get(&key).map(|&i| &self.columns[i].1)
147 }
148
149 fn has_column(&self, column: &str, dialect: Dialect) -> bool {
150 let key = normalize_identifier(column, dialect);
151 self.index.contains_key(&key)
152 }
153}
154
155#[derive(Debug, Clone)]
165pub struct MappingSchema {
166 dialect: Dialect,
167 tables: HashMap<String, HashMap<String, HashMap<String, ColumnInfo>>>,
169 udf_types: HashMap<String, DataType>,
171}
172
173impl MappingSchema {
174 #[must_use]
176 pub fn new(dialect: Dialect) -> Self {
177 Self {
178 dialect,
179 tables: HashMap::new(),
180 udf_types: HashMap::new(),
181 }
182 }
183
184 pub fn replace_table(
186 &mut self,
187 table_path: &[&str],
188 columns: Vec<(String, DataType)>,
189 ) -> Result<(), SchemaError> {
190 let (catalog, database, table) = self.resolve_path(table_path)?;
191 let info = ColumnInfo::new(columns, self.dialect);
192 self.tables
193 .entry(catalog)
194 .or_default()
195 .entry(database)
196 .or_default()
197 .insert(table, info);
198 Ok(())
199 }
200
201 pub fn remove_table(&mut self, table_path: &[&str]) -> Result<bool, SchemaError> {
203 let (catalog, database, table) = self.resolve_path(table_path)?;
204 let removed = self
205 .tables
206 .get_mut(&catalog)
207 .and_then(|dbs| dbs.get_mut(&database))
208 .map(|tbls| tbls.remove(&table).is_some())
209 .unwrap_or(false);
210 Ok(removed)
211 }
212
213 pub fn add_udf(&mut self, name: &str, return_type: DataType) {
215 let key = normalize_identifier(name, self.dialect);
216 self.udf_types.insert(key, return_type);
217 }
218
219 #[must_use]
221 pub fn get_udf_type(&self, name: &str) -> Option<&DataType> {
222 let key = normalize_identifier(name, self.dialect);
223 self.udf_types.get(&key)
224 }
225
226 #[must_use]
228 pub fn table_names(&self) -> Vec<(String, String, String)> {
229 let mut result = Vec::new();
230 for (catalog, dbs) in &self.tables {
231 for (database, tbls) in dbs {
232 for table in tbls.keys() {
233 result.push((catalog.clone(), database.clone(), table.clone()));
234 }
235 }
236 }
237 result
238 }
239
240 fn find_table(&self, table_path: &[&str]) -> Option<&ColumnInfo> {
243 let (catalog, database, table) = match self.resolve_path(table_path) {
244 Ok(parts) => parts,
245 Err(_) => return None,
246 };
247
248 if let Some(info) = self
250 .tables
251 .get(&catalog)
252 .and_then(|dbs| dbs.get(&database))
253 .and_then(|tbls| tbls.get(&table))
254 {
255 return Some(info);
256 }
257
258 if table_path.len() == 1 {
260 let norm_name = normalize_identifier(table_path[0], self.dialect);
261 for dbs in self.tables.values() {
262 for tbls in dbs.values() {
263 if let Some(info) = tbls.get(&norm_name) {
264 return Some(info);
265 }
266 }
267 }
268 }
269
270 if table_path.len() == 2 {
272 let norm_db = normalize_identifier(table_path[0], self.dialect);
273 let norm_tbl = normalize_identifier(table_path[1], self.dialect);
274 for dbs in self.tables.values() {
275 if let Some(info) = dbs.get(&norm_db).and_then(|tbls| tbls.get(&norm_tbl)) {
276 return Some(info);
277 }
278 }
279 }
280
281 None
282 }
283
284 fn resolve_path(&self, table_path: &[&str]) -> Result<(String, String, String), SchemaError> {
287 match table_path.len() {
288 1 => Ok((
289 String::new(),
290 String::new(),
291 normalize_identifier(table_path[0], self.dialect),
292 )),
293 2 => Ok((
294 String::new(),
295 normalize_identifier(table_path[0], self.dialect),
296 normalize_identifier(table_path[1], self.dialect),
297 )),
298 3 => Ok((
299 normalize_identifier(table_path[0], self.dialect),
300 normalize_identifier(table_path[1], self.dialect),
301 normalize_identifier(table_path[2], self.dialect),
302 )),
303 _ => Err(SchemaError::TableNotFound(table_path.join("."))),
304 }
305 }
306
307 fn format_table_path(table_path: &[&str]) -> String {
308 table_path.join(".")
309 }
310}
311
312impl Schema for MappingSchema {
313 fn add_table(
314 &mut self,
315 table_path: &[&str],
316 columns: Vec<(String, DataType)>,
317 ) -> Result<(), SchemaError> {
318 let (catalog, database, table) = self.resolve_path(table_path)?;
319 let entry = self
320 .tables
321 .entry(catalog)
322 .or_default()
323 .entry(database)
324 .or_default();
325
326 if entry.contains_key(&table) {
327 return Err(SchemaError::DuplicateTable(Self::format_table_path(
328 table_path,
329 )));
330 }
331
332 let info = ColumnInfo::new(columns, self.dialect);
333 entry.insert(table, info);
334 Ok(())
335 }
336
337 fn column_names(&self, table_path: &[&str]) -> Result<Vec<String>, SchemaError> {
338 self.find_table(table_path)
339 .map(|info| info.column_names())
340 .ok_or_else(|| SchemaError::TableNotFound(Self::format_table_path(table_path)))
341 }
342
343 fn get_column_type(&self, table_path: &[&str], column: &str) -> Result<DataType, SchemaError> {
344 let table_str = Self::format_table_path(table_path);
345 let info = self
346 .find_table(table_path)
347 .ok_or_else(|| SchemaError::TableNotFound(table_str.clone()))?;
348
349 info.get_type(column, self.dialect)
350 .cloned()
351 .ok_or(SchemaError::ColumnNotFound {
352 table: table_str,
353 column: column.to_string(),
354 })
355 }
356
357 fn has_column(&self, table_path: &[&str], column: &str) -> bool {
358 self.find_table(table_path)
359 .is_some_and(|info| info.has_column(column, self.dialect))
360 }
361
362 fn dialect(&self) -> Dialect {
363 self.dialect
364 }
365
366 fn get_udf_type(&self, name: &str) -> Option<&DataType> {
367 let key = normalize_identifier(name, self.dialect);
368 self.udf_types.get(&key)
369 }
370}
371
372#[must_use]
383pub fn normalize_identifier(name: &str, dialect: Dialect) -> String {
384 if is_case_sensitive_dialect(dialect) {
385 name.to_string()
386 } else {
387 name.to_lowercase()
388 }
389}
390
391#[must_use]
393pub fn is_case_sensitive_dialect(dialect: Dialect) -> bool {
394 matches!(
395 dialect,
396 Dialect::BigQuery | Dialect::Hive | Dialect::Spark | Dialect::Databricks
397 )
398}
399
400pub fn ensure_schema(
427 tables: HashMap<String, HashMap<String, DataType>>,
428 dialect: Dialect,
429) -> MappingSchema {
430 let mut schema = MappingSchema::new(dialect);
431 for (table_name, columns) in tables {
432 let col_vec: Vec<(String, DataType)> = columns.into_iter().collect();
433 let _ = schema.replace_table(&[&table_name], col_vec);
435 }
436 schema
437}
438
439pub type CatalogMap = HashMap<String, HashMap<String, HashMap<String, HashMap<String, DataType>>>>;
442
443pub fn ensure_schema_nested(catalog_map: CatalogMap, dialect: Dialect) -> MappingSchema {
446 let mut schema = MappingSchema::new(dialect);
447 for (catalog, databases) in catalog_map {
448 for (database, tables) in databases {
449 for (table, columns) in tables {
450 let col_vec: Vec<(String, DataType)> = columns.into_iter().collect();
451 let _ = schema.replace_table(&[&catalog, &database, &table], col_vec);
452 }
453 }
454 }
455 schema
456}
457
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an owned `(name, type)` column pair from a string slice.
    fn col(name: &str, data_type: DataType) -> (String, DataType) {
        (name.to_string(), data_type)
    }

    /// Creates a schema for `dialect` with one table added at `path`.
    fn schema_with(
        dialect: Dialect,
        path: &[&str],
        columns: Vec<(String, DataType)>,
    ) -> MappingSchema {
        let mut schema = MappingSchema::new(dialect);
        schema.add_table(path, columns).unwrap();
        schema
    }

    // Basic registration and lookup.

    #[test]
    fn test_add_and_query_table() {
        let users = ["users"];
        let schema = schema_with(
            Dialect::Ansi,
            &users,
            vec![
                col("id", DataType::Int),
                col("name", DataType::Varchar(Some(255))),
                col("email", DataType::Text),
            ],
        );

        assert_eq!(
            schema.column_names(&users).unwrap(),
            vec!["id", "name", "email"]
        );
        assert_eq!(schema.get_column_type(&users, "id").unwrap(), DataType::Int);
        assert_eq!(
            schema.get_column_type(&users, "name").unwrap(),
            DataType::Varchar(Some(255))
        );
        for present in ["id", "email"] {
            assert!(schema.has_column(&users, present));
        }
        assert!(!schema.has_column(&users, "nonexistent"));
    }

    #[test]
    fn test_duplicate_table_error() {
        let mut schema = schema_with(Dialect::Ansi, &["t"], vec![col("a", DataType::Int)]);

        let second = schema.add_table(&["t"], vec![col("b", DataType::Text)]);
        assert!(matches!(second, Err(SchemaError::DuplicateTable(_))));
    }

    #[test]
    fn test_replace_table() {
        let mut schema = schema_with(Dialect::Ansi, &["t"], vec![col("a", DataType::Int)]);

        schema
            .replace_table(&["t"], vec![col("b", DataType::Text)])
            .unwrap();

        assert_eq!(schema.column_names(&["t"]).unwrap(), vec!["b"]);
        assert_eq!(schema.get_column_type(&["t"], "b").unwrap(), DataType::Text);
    }

    #[test]
    fn test_remove_table() {
        let mut schema = schema_with(Dialect::Ansi, &["t"], vec![col("a", DataType::Int)]);

        let first_removal = schema.remove_table(&["t"]).unwrap();
        let second_removal = schema.remove_table(&["t"]).unwrap();
        assert!(first_removal);
        assert!(!second_removal);
        assert!(schema.column_names(&["t"]).is_err());
    }

    #[test]
    fn test_table_not_found() {
        let schema = MappingSchema::new(Dialect::Ansi);
        let lookup = schema.column_names(&["nonexistent"]);
        assert!(matches!(lookup, Err(SchemaError::TableNotFound(_))));
    }

    #[test]
    fn test_column_not_found() {
        let schema = schema_with(Dialect::Ansi, &["t"], vec![col("a", DataType::Int)]);

        let lookup = schema.get_column_type(&["t"], "z");
        assert!(matches!(lookup, Err(SchemaError::ColumnNotFound { .. })));
    }

    // Multi-part table paths.

    #[test]
    fn test_three_level_path() {
        let path = ["my_catalog", "my_db", "orders"];
        let total_type = DataType::Decimal {
            precision: Some(10),
            scale: Some(2),
        };
        let schema = schema_with(
            Dialect::Ansi,
            &path,
            vec![col("order_id", DataType::BigInt), col("total", total_type)],
        );

        assert_eq!(
            schema.column_names(&path).unwrap(),
            vec!["order_id", "total"]
        );
        assert!(schema.has_column(&path, "order_id"));
    }

    #[test]
    fn test_two_level_path() {
        let path = ["public", "users"];
        let schema = schema_with(Dialect::Ansi, &path, vec![col("id", DataType::Int)]);

        assert_eq!(schema.column_names(&path).unwrap(), vec!["id"]);
    }

    #[test]
    fn test_short_path_searches_all() {
        let schema = schema_with(
            Dialect::Ansi,
            &["catalog", "db", "orders"],
            vec![col("id", DataType::Int)],
        );

        // A bare table name is found across all catalogs and databases.
        assert!(schema.has_column(&["orders"], "id"));
        assert_eq!(schema.column_names(&["orders"]).unwrap(), vec!["id"]);

        // A database-qualified name is found across all catalogs.
        assert!(schema.has_column(&["db", "orders"], "id"));
    }

    // Identifier normalization per dialect.

    #[test]
    fn test_case_insensitive_dialect() {
        let schema = schema_with(Dialect::Postgres, &["Users"], vec![col("ID", DataType::Int)]);

        for (table, column) in [("users", "id"), ("USERS", "ID"), ("Users", "Id")] {
            assert!(schema.has_column(&[table], column));
        }
        assert_eq!(
            schema.get_column_type(&["users"], "id").unwrap(),
            DataType::Int
        );
    }

    #[test]
    fn test_case_sensitive_dialect() {
        let schema = schema_with(Dialect::BigQuery, &["Users"], vec![col("ID", DataType::Int)]);

        assert!(schema.has_column(&["Users"], "ID"));
        for (table, column) in [("users", "ID"), ("Users", "id")] {
            assert!(!schema.has_column(&[table], column));
        }
    }

    #[test]
    fn test_hive_case_sensitive() {
        let schema = schema_with(Dialect::Hive, &["MyTable"], vec![col("Col1", DataType::Text)]);

        assert!(schema.has_column(&["MyTable"], "Col1"));
        assert!(!schema.has_column(&["mytable"], "col1"));
    }

    // User-defined function return types.

    #[test]
    fn test_udf_return_type() {
        let mut schema = MappingSchema::new(Dialect::Ansi);
        schema.add_udf("my_custom_fn", DataType::Int);

        for spelling in ["my_custom_fn", "MY_CUSTOM_FN"] {
            assert_eq!(schema.get_udf_type(spelling).unwrap(), &DataType::Int);
        }
        assert!(schema.get_udf_type("nonexistent").is_none());
    }

    #[test]
    fn test_udf_case_sensitive() {
        let mut schema = MappingSchema::new(Dialect::BigQuery);
        schema.add_udf("myFunc", DataType::Boolean);

        assert!(schema.get_udf_type("myFunc").is_some());
        assert!(schema.get_udf_type("MYFUNC").is_none());
    }

    // Construction from nested maps.

    #[test]
    fn test_ensure_schema() {
        let cols = HashMap::from([
            ("id".to_string(), DataType::Int),
            ("name".to_string(), DataType::Text),
        ]);
        let tables = HashMap::from([("users".to_string(), cols)]);

        let schema = ensure_schema(tables, Dialect::Postgres);
        for column in ["id", "name"] {
            assert!(schema.has_column(&["users"], column));
        }
    }

    #[test]
    fn test_ensure_schema_nested() {
        let cols = HashMap::from([("order_id".to_string(), DataType::BigInt)]);
        let tables = HashMap::from([("orders".to_string(), cols)]);
        let databases = HashMap::from([("sales".to_string(), tables)]);
        let catalogs: CatalogMap = HashMap::from([("warehouse".to_string(), databases)]);

        let schema = ensure_schema_nested(catalogs, Dialect::Ansi);
        assert!(schema.has_column(&["warehouse", "sales", "orders"], "order_id"));
        // The short path reaches the same table through the fallback search.
        assert!(schema.has_column(&["orders"], "order_id"));
    }

    // Table listing.

    #[test]
    fn test_table_names() {
        let mut schema = MappingSchema::new(Dialect::Ansi);
        for (table, column) in [("t1", "a"), ("t2", "b")] {
            schema
                .add_table(&["cat", "db", table], vec![col(column, DataType::Int)])
                .unwrap();
        }

        let mut names = schema.table_names();
        names.sort();
        assert_eq!(names.len(), 2);
        for table in ["t1", "t2"] {
            let expected = ("cat".to_string(), "db".to_string(), table.to_string());
            assert!(names.contains(&expected));
        }
    }

    // Invalid input and empty schemas.

    #[test]
    fn test_invalid_path_too_many_parts() {
        let mut schema = MappingSchema::new(Dialect::Ansi);
        let outcome = schema.add_table(&["a", "b", "c", "d"], vec![col("x", DataType::Int)]);
        assert!(matches!(outcome, Err(SchemaError::TableNotFound(_))));
    }

    #[test]
    fn test_empty_schema_has_no_columns() {
        let schema = MappingSchema::new(Dialect::Ansi);
        assert!(!schema.has_column(&["any_table"], "any_col"));
    }

    // Error formatting and conversion.

    #[test]
    fn test_schema_error_display() {
        let cases = [
            (
                SchemaError::TableNotFound("users".to_string()),
                "Table not found: users",
            ),
            (
                SchemaError::ColumnNotFound {
                    table: "users".to_string(),
                    column: "age".to_string(),
                },
                "Column 'age' not found in table 'users'",
            ),
            (
                SchemaError::DuplicateTable("users".to_string()),
                "Table already exists: users",
            ),
        ];

        for (error, expected) in cases {
            assert_eq!(error.to_string(), expected);
        }
    }

    #[test]
    fn test_schema_error_into_sqlglot_error() {
        let converted = SqlglotError::from(SchemaError::TableNotFound("t".to_string()));
        assert!(matches!(converted, SqlglotError::Internal(_)));
    }

    // Normalization across several dialects.

    #[test]
    fn test_multiple_dialects_normalization() {
        // Postgres and MySQL fold identifiers to lowercase.
        for dialect in [Dialect::Postgres, Dialect::Mysql] {
            let schema = schema_with(dialect, &["T"], vec![col("C", DataType::Int)]);
            assert!(schema.has_column(&["t"], "c"));
        }

        // Spark keeps identifiers as written.
        let spark = schema_with(Dialect::Spark, &["T"], vec![col("C", DataType::Int)]);
        assert!(!spark.has_column(&["t"], "c"));
        assert!(spark.has_column(&["T"], "C"));
    }

    // Nested and composite column types.

    #[test]
    fn test_complex_data_types() {
        let table = ["complex_table"];
        let coords = DataType::Struct(vec![
            col("lat", DataType::Double),
            col("lng", DataType::Double),
        ]);
        let lookup = DataType::Map {
            key: Box::new(DataType::Text),
            value: Box::new(DataType::Int),
        };
        let schema = schema_with(
            Dialect::Ansi,
            &table,
            vec![
                col("tags", DataType::Array(Some(Box::new(DataType::Text)))),
                col("metadata", DataType::Json),
                col("coords", coords),
                col("lookup", lookup),
            ],
        );

        assert_eq!(
            schema.get_column_type(&table, "tags").unwrap(),
            DataType::Array(Some(Box::new(DataType::Text)))
        );
        assert_eq!(
            schema.get_column_type(&table, "metadata").unwrap(),
            DataType::Json
        );
    }

    // Dialect accessor.

    #[test]
    fn test_schema_dialect() {
        let schema = MappingSchema::new(Dialect::Snowflake);
        assert_eq!(schema.dialect(), Dialect::Snowflake);
    }
}