1use crate::dialects::DialectType;
11use crate::expressions::DataType;
12use crate::trie::{Trie, TrieResult};
13use std::collections::{HashMap, HashSet};
14use thiserror::Error;
15
16#[derive(Debug, Error, Clone)]
18pub enum SchemaError {
19 #[error("Table not found: {0}")]
20 TableNotFound(String),
21
22 #[error("Ambiguous table: {table} matches multiple tables: {matches}")]
23 AmbiguousTable { table: String, matches: String },
24
25 #[error("Column not found: {column} in table {table}")]
26 ColumnNotFound { table: String, column: String },
27
28 #[error("Schema nesting depth mismatch: expected {expected}, got {actual}")]
29 DepthMismatch { expected: usize, actual: usize },
30
31 #[error("Invalid schema structure: {0}")]
32 InvalidStructure(String),
33}
34
35pub type SchemaResult<T> = Result<T, SchemaError>;
37
38pub const TABLE_PARTS: &[&str] = &["this", "db", "catalog"];
40
41pub trait Schema {
43 fn dialect(&self) -> Option<DialectType>;
45
46 fn add_table(
48 &mut self,
49 table: &str,
50 columns: &[(String, DataType)],
51 dialect: Option<DialectType>,
52 ) -> SchemaResult<()>;
53
54 fn column_names(&self, table: &str) -> SchemaResult<Vec<String>>;
56
57 fn get_column_type(&self, table: &str, column: &str) -> SchemaResult<DataType>;
59
60 fn has_column(&self, table: &str, column: &str) -> bool;
62
63 fn supported_table_args(&self) -> &[&str];
65
66 fn is_empty(&self) -> bool;
68
69 fn depth(&self) -> usize;
71
72 fn find_tables_for_column(&self, column: &str) -> Vec<String>;
75}
76
77#[derive(Debug, Clone)]
79pub struct ColumnInfo {
80 pub data_type: DataType,
81 pub visible: bool,
82}
83
84impl ColumnInfo {
85 pub fn new(data_type: DataType) -> Self {
86 Self {
87 data_type,
88 visible: true,
89 }
90 }
91
92 pub fn with_visibility(data_type: DataType, visible: bool) -> Self {
93 Self { data_type, visible }
94 }
95}
96
97#[derive(Debug, Clone)]
104pub struct MappingSchema {
105 mapping: HashMap<String, SchemaNode>,
107 mapping_trie: Trie<()>,
109 dialect: Option<DialectType>,
111 normalize: bool,
113 visible: HashMap<String, HashSet<String>>,
115 cached_depth: usize,
117}
118
119#[derive(Debug, Clone)]
121pub enum SchemaNode {
122 Namespace(HashMap<String, SchemaNode>),
124 Table(HashMap<String, ColumnInfo>),
126}
127
128impl Default for MappingSchema {
129 fn default() -> Self {
130 Self::new()
131 }
132}
133
134impl MappingSchema {
135 pub fn new() -> Self {
137 Self {
138 mapping: HashMap::new(),
139 mapping_trie: Trie::new(),
140 dialect: None,
141 normalize: true,
142 visible: HashMap::new(),
143 cached_depth: 0,
144 }
145 }
146
147 pub fn with_dialect(dialect: DialectType) -> Self {
149 Self {
150 dialect: Some(dialect),
151 ..Self::new()
152 }
153 }
154
155 pub fn without_normalization(mut self) -> Self {
157 self.normalize = false;
158 self
159 }
160
161 pub fn set_visible_columns(&mut self, table: &str, columns: &[&str]) {
163 let key = self.normalize_name(table, true);
164 let cols: HashSet<String> = columns
165 .iter()
166 .map(|c| self.normalize_name(c, false))
167 .collect();
168 self.visible.insert(key, cols);
169 }
170
171 fn normalize_name(&self, name: &str, is_table: bool) -> String {
173 if !self.normalize {
174 return name.to_string();
175 }
176
177 match self.dialect {
180 Some(DialectType::BigQuery) if is_table => {
181 name.to_string()
183 }
184 Some(DialectType::Snowflake) => {
185 name.to_uppercase()
187 }
188 _ => {
189 name.to_lowercase()
191 }
192 }
193 }
194
195 fn parse_table_parts(&self, table: &str) -> Vec<String> {
197 table
198 .split('.')
199 .map(|s| self.normalize_name(s.trim(), true))
200 .collect()
201 }
202
203 fn find_table(&self, table: &str) -> SchemaResult<&HashMap<String, ColumnInfo>> {
205 let parts = self.parse_table_parts(table);
206
207 let reversed_parts: Vec<_> = parts.iter().rev().map(|s| s.as_str()).collect();
209 let key: String = reversed_parts.join(".");
210
211 let (result, _) = self.mapping_trie.in_trie(&key);
212
213 match result {
214 TrieResult::Failed => Err(SchemaError::TableNotFound(table.to_string())),
215 TrieResult::Prefix => {
216 Err(SchemaError::AmbiguousTable {
218 table: table.to_string(),
219 matches: "multiple matches".to_string(),
220 })
221 }
222 TrieResult::Exists => {
223 self.navigate_to_table(&parts)
225 }
226 }
227 }
228
229 fn navigate_to_table(&self, parts: &[String]) -> SchemaResult<&HashMap<String, ColumnInfo>> {
231 let mut current = &self.mapping;
232
233 for (i, part) in parts.iter().enumerate() {
234 match current.get(part) {
235 Some(SchemaNode::Namespace(inner)) => {
236 current = inner;
237 }
238 Some(SchemaNode::Table(cols)) => {
239 if i == parts.len() - 1 {
240 return Ok(cols);
241 } else {
242 return Err(SchemaError::InvalidStructure(format!(
243 "Found table at {} but expected more levels",
244 parts[..=i].join(".")
245 )));
246 }
247 }
248 None => {
249 return Err(SchemaError::TableNotFound(parts.join(".")));
250 }
251 }
252 }
253
254 Err(SchemaError::TableNotFound(parts.join(".")))
256 }
257
258 fn add_table_internal(
260 &mut self,
261 parts: &[String],
262 columns: HashMap<String, ColumnInfo>,
263 ) -> SchemaResult<()> {
264 if parts.is_empty() {
265 return Err(SchemaError::InvalidStructure(
266 "Table name cannot be empty".to_string(),
267 ));
268 }
269
270 let trie_key: String = parts.iter().rev().cloned().collect::<Vec<_>>().join(".");
272 self.mapping_trie.insert(&trie_key, ());
273
274 let mut current = &mut self.mapping;
276
277 for (i, part) in parts.iter().enumerate() {
278 let is_last = i == parts.len() - 1;
279
280 if is_last {
281 current.insert(part.clone(), SchemaNode::Table(columns));
283 return Ok(());
284 } else {
285 let entry = current
287 .entry(part.clone())
288 .or_insert_with(|| SchemaNode::Namespace(HashMap::new()));
289
290 match entry {
291 SchemaNode::Namespace(inner) => {
292 current = inner;
293 }
294 SchemaNode::Table(_) => {
295 return Err(SchemaError::InvalidStructure(format!(
296 "Expected namespace at {} but found table",
297 parts[..=i].join(".")
298 )));
299 }
300 }
301 }
302 }
303
304 Ok(())
305 }
306
307 fn update_depth(&mut self) {
309 self.cached_depth = self.calculate_depth(&self.mapping);
310 }
311
312 fn calculate_depth(&self, mapping: &HashMap<String, SchemaNode>) -> usize {
313 if mapping.is_empty() {
314 return 0;
315 }
316
317 let mut max_depth = 1;
318 for node in mapping.values() {
319 match node {
320 SchemaNode::Namespace(inner) => {
321 let d = 1 + self.calculate_depth(inner);
322 if d > max_depth {
323 max_depth = d;
324 }
325 }
326 SchemaNode::Table(_) => {
327 }
329 }
330 }
331 max_depth
332 }
333}
334
335impl Schema for MappingSchema {
336 fn dialect(&self) -> Option<DialectType> {
337 self.dialect
338 }
339
340 fn add_table(
341 &mut self,
342 table: &str,
343 columns: &[(String, DataType)],
344 _dialect: Option<DialectType>,
345 ) -> SchemaResult<()> {
346 let parts = self.parse_table_parts(table);
347
348 let cols: HashMap<String, ColumnInfo> = columns
349 .iter()
350 .map(|(name, dtype)| {
351 let normalized_name = self.normalize_name(name, false);
352 (normalized_name, ColumnInfo::new(dtype.clone()))
353 })
354 .collect();
355
356 self.add_table_internal(&parts, cols)?;
357 self.update_depth();
358 Ok(())
359 }
360
361 fn column_names(&self, table: &str) -> SchemaResult<Vec<String>> {
362 let cols = self.find_table(table)?;
363 let table_key = self.normalize_name(table, true);
364
365 if let Some(visible_cols) = self.visible.get(&table_key) {
367 Ok(cols
368 .keys()
369 .filter(|k| visible_cols.contains(*k))
370 .cloned()
371 .collect())
372 } else {
373 Ok(cols.keys().cloned().collect())
374 }
375 }
376
377 fn get_column_type(&self, table: &str, column: &str) -> SchemaResult<DataType> {
378 let cols = self.find_table(table)?;
379 let normalized_col = self.normalize_name(column, false);
380
381 cols.get(&normalized_col)
382 .map(|info| info.data_type.clone())
383 .ok_or_else(|| SchemaError::ColumnNotFound {
384 table: table.to_string(),
385 column: column.to_string(),
386 })
387 }
388
389 fn has_column(&self, table: &str, column: &str) -> bool {
390 self.get_column_type(table, column).is_ok()
391 }
392
393 fn supported_table_args(&self) -> &[&str] {
394 let depth = self.depth();
395 if depth == 0 {
396 &[]
397 } else if depth <= 3 {
398 &TABLE_PARTS[..depth]
399 } else {
400 TABLE_PARTS
401 }
402 }
403
404 fn is_empty(&self) -> bool {
405 self.mapping.is_empty()
406 }
407
408 fn depth(&self) -> usize {
409 self.cached_depth
410 }
411
412 fn find_tables_for_column(&self, column: &str) -> Vec<String> {
413 let normalized = normalize_name(column, self.dialect, false, self.normalize);
414 let mut result = Vec::new();
415 for table_name in self.mapping.keys() {
416 if self.has_column(table_name, &normalized) {
417 result.push(table_name.clone());
418 }
419 }
420 result
421 }
422}
423
424pub fn normalize_name(
426 name: &str,
427 dialect: Option<DialectType>,
428 is_table: bool,
429 normalize: bool,
430) -> String {
431 if !normalize {
432 return name.to_string();
433 }
434
435 match dialect {
436 Some(DialectType::BigQuery) if is_table => name.to_string(),
437 Some(DialectType::Snowflake) => name.to_uppercase(),
438 _ => name.to_lowercase(),
439 }
440}
441
442pub fn ensure_schema(schema: Option<MappingSchema>) -> MappingSchema {
444 schema.unwrap_or_default()
445}
446
447pub fn from_simple_map(tables: &[(&str, &[(&str, DataType)])]) -> MappingSchema {
463 let mut schema = MappingSchema::new();
464
465 for (table_name, columns) in tables {
466 let cols: Vec<(String, DataType)> = columns
467 .iter()
468 .map(|(name, dtype)| (name.to_string(), dtype.clone()))
469 .collect();
470
471 schema.add_table(table_name, &cols, None).ok();
472 }
473
474 schema
475}
476
477pub fn flatten_schema_paths(schema: &MappingSchema) -> Vec<Vec<String>> {
479 let mut paths = Vec::new();
480 flatten_schema_paths_recursive(&schema.mapping, Vec::new(), &mut paths);
481 paths
482}
483
484fn flatten_schema_paths_recursive(
485 mapping: &HashMap<String, SchemaNode>,
486 prefix: Vec<String>,
487 paths: &mut Vec<Vec<String>>,
488) {
489 for (key, node) in mapping {
490 let mut path = prefix.clone();
491 path.push(key.clone());
492
493 match node {
494 SchemaNode::Namespace(inner) => {
495 flatten_schema_paths_recursive(inner, path, paths);
496 }
497 SchemaNode::Table(_) => {
498 paths.push(path);
499 }
500 }
501 }
502}
503
/// Stores `value` at `map[keys[0]][keys[1]]`, creating the inner map if needed.
///
/// With fewer than two keys this is a no-op (the value is dropped). With more than two,
/// only the first two are used.
///
/// `value` is moved into the map, so no `Clone` bound is required on `V`.
pub fn nested_set<V>(
    map: &mut HashMap<String, HashMap<String, V>>,
    keys: &[String],
    value: V,
) {
    if let [outer, inner, ..] = keys {
        map.entry(outer.clone())
            .or_default()
            .insert(inner.clone(), value);
    }
}
526
/// Looks up `map[keys[0]][keys[1]]`.
///
/// Returns `None` unless `keys` has exactly two elements and both levels are present.
pub fn nested_get<'a, V>(
    map: &'a HashMap<String, HashMap<String, V>>,
    keys: &[String],
) -> Option<&'a V> {
    match keys {
        [outer, inner] => map.get(outer).and_then(|row| row.get(inner)),
        _ => None,
    }
}
538
#[cfg(test)]
mod tests {
    use super::*;

    /// `INT` without a display length.
    fn int_type() -> DataType {
        DataType::Int {
            length: None,
            integer_spelling: false,
        }
    }

    /// `VARCHAR` without a length.
    fn varchar_type() -> DataType {
        DataType::VarChar {
            length: None,
            parenthesized_length: false,
        }
    }

    /// `VARCHAR(255)`.
    fn varchar_255() -> DataType {
        DataType::VarChar {
            length: Some(255),
            parenthesized_length: false,
        }
    }

    /// Converts borrowed `(name, type)` pairs into the owned form `add_table` takes.
    fn owned(spec: &[(&str, DataType)]) -> Vec<(String, DataType)> {
        spec.iter()
            .map(|(name, data_type)| (name.to_string(), data_type.clone()))
            .collect()
    }

    #[test]
    fn test_empty_schema() {
        let schema = MappingSchema::new();
        assert!(schema.is_empty());
        assert_eq!(schema.depth(), 0);
    }

    #[test]
    fn test_add_table() {
        let mut schema = MappingSchema::new();
        let columns = owned(&[("id", int_type()), ("name", varchar_255())]);
        schema.add_table("users", &columns, None).unwrap();

        assert!(!schema.is_empty());
        assert_eq!(schema.depth(), 1);
        assert!(schema.has_column("users", "id"));
        assert!(schema.has_column("users", "name"));
        assert!(!schema.has_column("users", "email"));
    }

    #[test]
    fn test_qualified_table_names() {
        let mut schema = MappingSchema::new();
        let columns = owned(&[("id", int_type())]);
        schema.add_table("mydb.users", &columns, None).unwrap();

        assert!(schema.has_column("mydb.users", "id"));
        assert_eq!(schema.depth(), 2);
    }

    #[test]
    fn test_catalog_db_table() {
        let mut schema = MappingSchema::new();
        let columns = owned(&[("id", int_type())]);
        schema
            .add_table("catalog.mydb.users", &columns, None)
            .unwrap();

        assert!(schema.has_column("catalog.mydb.users", "id"));
        assert_eq!(schema.depth(), 3);
    }

    #[test]
    fn test_get_column_type() {
        let mut schema = MappingSchema::new();
        let columns = owned(&[("id", int_type()), ("name", varchar_255())]);
        schema.add_table("users", &columns, None).unwrap();

        let id_type = schema.get_column_type("users", "id").unwrap();
        assert!(matches!(id_type, DataType::Int { .. }));

        let name_type = schema.get_column_type("users", "name").unwrap();
        assert!(matches!(
            name_type,
            DataType::VarChar {
                length: Some(255),
                parenthesized_length: false
            }
        ));
    }

    #[test]
    fn test_column_names() {
        let mut schema = MappingSchema::new();
        let columns = owned(&[("id", int_type()), ("name", varchar_type())]);
        schema.add_table("users", &columns, None).unwrap();

        let names = schema.column_names("users").unwrap();
        assert_eq!(names.len(), 2);
        assert!(names.contains(&"id".to_string()));
        assert!(names.contains(&"name".to_string()));
    }

    #[test]
    fn test_table_not_found() {
        let schema = MappingSchema::new();
        let result = schema.column_names("nonexistent");
        assert!(matches!(result, Err(SchemaError::TableNotFound(_))));
    }

    #[test]
    fn test_column_not_found() {
        let mut schema = MappingSchema::new();
        let columns = owned(&[("id", int_type())]);
        schema.add_table("users", &columns, None).unwrap();

        let result = schema.get_column_type("users", "nonexistent");
        assert!(matches!(result, Err(SchemaError::ColumnNotFound { .. })));
    }

    #[test]
    fn test_normalize_name_default() {
        assert_eq!(normalize_name("MyTable", None, true, true), "mytable");
    }

    #[test]
    fn test_normalize_name_snowflake() {
        assert_eq!(
            normalize_name("MyTable", Some(DialectType::Snowflake), true, true),
            "MYTABLE"
        );
    }

    #[test]
    fn test_normalize_disabled() {
        assert_eq!(normalize_name("MyTable", None, true, false), "MyTable");
    }

    #[test]
    fn test_from_simple_map() {
        let schema = from_simple_map(&[
            ("users", &[("id", int_type()), ("name", varchar_type())]),
            ("orders", &[("id", int_type()), ("user_id", int_type())]),
        ]);

        assert!(schema.has_column("users", "id"));
        assert!(schema.has_column("users", "name"));
        assert!(schema.has_column("orders", "id"));
        assert!(schema.has_column("orders", "user_id"));
    }

    #[test]
    fn test_flatten_schema_paths() {
        let mut schema = MappingSchema::new();
        for table in ["db1.table1", "db1.table2", "db2.table1"].iter() {
            schema
                .add_table(table, &owned(&[("id", int_type())]), None)
                .unwrap();
        }

        let paths = flatten_schema_paths(&schema);
        assert_eq!(paths.len(), 3);
    }

    #[test]
    fn test_visible_columns() {
        let mut schema = MappingSchema::new();
        let columns = owned(&[
            ("id", int_type()),
            ("name", varchar_type()),
            ("password", varchar_type()),
        ]);
        schema.add_table("users", &columns, None).unwrap();
        schema.set_visible_columns("users", &["id", "name"]);

        let names = schema.column_names("users").unwrap();
        assert_eq!(names.len(), 2);
        assert!(names.contains(&"id".to_string()));
        assert!(names.contains(&"name".to_string()));
        assert!(!names.contains(&"password".to_string()));
    }
}