1use crate::dialects::DialectType;
11use crate::expressions::DataType;
12use crate::trie::{Trie, TrieResult};
13use std::collections::{HashMap, HashSet};
14use thiserror::Error;
15
/// Errors reported by schema construction and lookup.
#[derive(Debug, Error, Clone)]
pub enum SchemaError {
    /// No table could be resolved for the requested name; the payload is the
    /// name (or dotted path) that was looked up.
    #[error("Table not found: {0}")]
    TableNotFound(String),

    /// The requested name is only a prefix of stored trie keys, so it does not
    /// identify a single table. `matches` is a free-form description.
    #[error("Ambiguous table: {table} matches multiple tables: {matches}")]
    AmbiguousTable { table: String, matches: String },

    /// The table was resolved but has no column with the requested name.
    #[error("Column not found: {column} in table {table}")]
    ColumnNotFound { table: String, column: String },

    /// Expected and actual nesting depths differ.
    /// NOTE(review): not constructed by any code visible in this part of the file.
    #[error("Schema nesting depth mismatch: expected {expected}, got {actual}")]
    DepthMismatch { expected: usize, actual: usize },

    /// The namespace/table tree is inconsistent with the requested path
    /// (for example a table found where a namespace was required).
    #[error("Invalid schema structure: {0}")]
    InvalidStructure(String),
}
34
/// Shorthand for a `Result` whose error type is [`SchemaError`].
pub type SchemaResult<T> = Result<T, SchemaError>;
37
/// Names of the table-reference arguments, in the order they are handed out by
/// `MappingSchema::supported_table_args` (a schema of depth `n` exposes the
/// first `n` entries).
/// NOTE(review): presumably ordered from the table itself outwards — confirm.
pub const TABLE_PARTS: &[&str] = &["this", "db", "catalog"];
40
/// Interface for registering tables and querying their columns.
pub trait Schema {
    /// Returns the dialect attached to the schema, or `None` if there is none.
    fn dialect(&self) -> Option<DialectType>;

    /// Registers `table` with the given `(column name, data type)` pairs.
    ///
    /// NOTE(review): the role of `dialect` is left to the implementor; the
    /// `MappingSchema` implementation in this file ignores it.
    fn add_table(
        &mut self,
        table: &str,
        columns: &[(String, DataType)],
        dialect: Option<DialectType>,
    ) -> SchemaResult<()>;

    /// Returns the column names of `table`, or an error if the table cannot be
    /// resolved.
    fn column_names(&self, table: &str) -> SchemaResult<Vec<String>>;

    /// Returns the data type of `column` in `table`, or an error if either the
    /// table or the column cannot be resolved.
    fn get_column_type(&self, table: &str, column: &str) -> SchemaResult<DataType>;

    /// Returns whether `column` can be resolved on `table`.
    fn has_column(&self, table: &str, column: &str) -> bool;

    /// Returns the table-reference argument names this schema supports.
    /// For `MappingSchema` this is a leading slice of `TABLE_PARTS`.
    fn supported_table_args(&self) -> &[&str];

    /// Returns whether the schema holds no entries.
    fn is_empty(&self) -> bool;

    /// Returns the nesting depth of table names; per the tests in this file,
    /// `MappingSchema` reports 1 for `table`, 2 for `db.table` and 3 for
    /// `catalog.db.table`.
    fn depth(&self) -> usize;
}
72
/// Per-column metadata stored inside a table node.
#[derive(Debug, Clone)]
pub struct ColumnInfo {
    /// Declared data type of the column.
    pub data_type: DataType,
    /// Visibility flag; `ColumnInfo::new` sets it to `true`.
    /// NOTE(review): no code visible in this part of the file reads this flag —
    /// `column_names` filters through `MappingSchema::visible` instead.
    pub visible: bool,
}
79
80impl ColumnInfo {
81 pub fn new(data_type: DataType) -> Self {
82 Self {
83 data_type,
84 visible: true,
85 }
86 }
87
88 pub fn with_visibility(data_type: DataType, visible: bool) -> Self {
89 Self { data_type, visible }
90 }
91}
92
/// In-memory schema backed by a nested map of namespaces and tables.
#[derive(Debug, Clone)]
pub struct MappingSchema {
    /// Root of the namespace/table tree, keyed by the first part of a dotted
    /// table name.
    mapping: HashMap<String, SchemaNode>,
    /// Index of registered tables; each key is the table's dotted path with
    /// its parts in reverse order (e.g. `users.mydb` for `mydb.users`).
    mapping_trie: Trie<()>,
    /// Dialect that selects the casing rule used when normalizing names.
    dialect: Option<DialectType>,
    /// When `false`, names are stored and looked up verbatim.
    normalize: bool,
    /// Optional per-table allow-list: normalized table name -> normalized
    /// column names that `column_names` may return.
    visible: HashMap<String, HashSet<String>>,
    /// Depth of `mapping`, recomputed after each successful `add_table`.
    cached_depth: usize,
}
114
/// One node of the schema tree.
#[derive(Debug, Clone)]
pub enum SchemaNode {
    /// Intermediate level holding further nodes keyed by name.
    Namespace(HashMap<String, SchemaNode>),
    /// Leaf level: a table's columns keyed by (normalized) column name.
    Table(HashMap<String, ColumnInfo>),
}
123
124impl Default for MappingSchema {
125 fn default() -> Self {
126 Self::new()
127 }
128}
129
130impl MappingSchema {
131 pub fn new() -> Self {
133 Self {
134 mapping: HashMap::new(),
135 mapping_trie: Trie::new(),
136 dialect: None,
137 normalize: true,
138 visible: HashMap::new(),
139 cached_depth: 0,
140 }
141 }
142
143 pub fn with_dialect(dialect: DialectType) -> Self {
145 Self {
146 dialect: Some(dialect),
147 ..Self::new()
148 }
149 }
150
151 pub fn without_normalization(mut self) -> Self {
153 self.normalize = false;
154 self
155 }
156
157 pub fn set_visible_columns(&mut self, table: &str, columns: &[&str]) {
159 let key = self.normalize_name(table, true);
160 let cols: HashSet<String> = columns
161 .iter()
162 .map(|c| self.normalize_name(c, false))
163 .collect();
164 self.visible.insert(key, cols);
165 }
166
167 fn normalize_name(&self, name: &str, is_table: bool) -> String {
169 if !self.normalize {
170 return name.to_string();
171 }
172
173 match self.dialect {
176 Some(DialectType::BigQuery) if is_table => {
177 name.to_string()
179 }
180 Some(DialectType::Snowflake) => {
181 name.to_uppercase()
183 }
184 _ => {
185 name.to_lowercase()
187 }
188 }
189 }
190
191 fn parse_table_parts(&self, table: &str) -> Vec<String> {
193 table
194 .split('.')
195 .map(|s| self.normalize_name(s.trim(), true))
196 .collect()
197 }
198
199 fn find_table(&self, table: &str) -> SchemaResult<&HashMap<String, ColumnInfo>> {
201 let parts = self.parse_table_parts(table);
202
203 let reversed_parts: Vec<_> = parts.iter().rev().map(|s| s.as_str()).collect();
205 let key: String = reversed_parts.join(".");
206
207 let (result, _) = self.mapping_trie.in_trie(&key);
208
209 match result {
210 TrieResult::Failed => Err(SchemaError::TableNotFound(table.to_string())),
211 TrieResult::Prefix => {
212 Err(SchemaError::AmbiguousTable {
214 table: table.to_string(),
215 matches: "multiple matches".to_string(),
216 })
217 }
218 TrieResult::Exists => {
219 self.navigate_to_table(&parts)
221 }
222 }
223 }
224
225 fn navigate_to_table(&self, parts: &[String]) -> SchemaResult<&HashMap<String, ColumnInfo>> {
227 let mut current = &self.mapping;
228
229 for (i, part) in parts.iter().enumerate() {
230 match current.get(part) {
231 Some(SchemaNode::Namespace(inner)) => {
232 current = inner;
233 }
234 Some(SchemaNode::Table(cols)) => {
235 if i == parts.len() - 1 {
236 return Ok(cols);
237 } else {
238 return Err(SchemaError::InvalidStructure(format!(
239 "Found table at {} but expected more levels",
240 parts[..=i].join(".")
241 )));
242 }
243 }
244 None => {
245 return Err(SchemaError::TableNotFound(parts.join(".")));
246 }
247 }
248 }
249
250 Err(SchemaError::TableNotFound(parts.join(".")))
252 }
253
254 fn add_table_internal(
256 &mut self,
257 parts: &[String],
258 columns: HashMap<String, ColumnInfo>,
259 ) -> SchemaResult<()> {
260 if parts.is_empty() {
261 return Err(SchemaError::InvalidStructure(
262 "Table name cannot be empty".to_string(),
263 ));
264 }
265
266 let trie_key: String = parts.iter().rev().cloned().collect::<Vec<_>>().join(".");
268 self.mapping_trie.insert(&trie_key, ());
269
270 let mut current = &mut self.mapping;
272
273 for (i, part) in parts.iter().enumerate() {
274 let is_last = i == parts.len() - 1;
275
276 if is_last {
277 current.insert(part.clone(), SchemaNode::Table(columns));
279 return Ok(());
280 } else {
281 let entry = current
283 .entry(part.clone())
284 .or_insert_with(|| SchemaNode::Namespace(HashMap::new()));
285
286 match entry {
287 SchemaNode::Namespace(inner) => {
288 current = inner;
289 }
290 SchemaNode::Table(_) => {
291 return Err(SchemaError::InvalidStructure(format!(
292 "Expected namespace at {} but found table",
293 parts[..=i].join(".")
294 )));
295 }
296 }
297 }
298 }
299
300 Ok(())
301 }
302
303 fn update_depth(&mut self) {
305 self.cached_depth = self.calculate_depth(&self.mapping);
306 }
307
308 fn calculate_depth(&self, mapping: &HashMap<String, SchemaNode>) -> usize {
309 if mapping.is_empty() {
310 return 0;
311 }
312
313 let mut max_depth = 1;
314 for node in mapping.values() {
315 match node {
316 SchemaNode::Namespace(inner) => {
317 let d = 1 + self.calculate_depth(inner);
318 if d > max_depth {
319 max_depth = d;
320 }
321 }
322 SchemaNode::Table(_) => {
323 }
325 }
326 }
327 max_depth
328 }
329}
330
331impl Schema for MappingSchema {
332 fn dialect(&self) -> Option<DialectType> {
333 self.dialect
334 }
335
336 fn add_table(
337 &mut self,
338 table: &str,
339 columns: &[(String, DataType)],
340 _dialect: Option<DialectType>,
341 ) -> SchemaResult<()> {
342 let parts = self.parse_table_parts(table);
343
344 let cols: HashMap<String, ColumnInfo> = columns
345 .iter()
346 .map(|(name, dtype)| {
347 let normalized_name = self.normalize_name(name, false);
348 (normalized_name, ColumnInfo::new(dtype.clone()))
349 })
350 .collect();
351
352 self.add_table_internal(&parts, cols)?;
353 self.update_depth();
354 Ok(())
355 }
356
357 fn column_names(&self, table: &str) -> SchemaResult<Vec<String>> {
358 let cols = self.find_table(table)?;
359 let table_key = self.normalize_name(table, true);
360
361 if let Some(visible_cols) = self.visible.get(&table_key) {
363 Ok(cols
364 .keys()
365 .filter(|k| visible_cols.contains(*k))
366 .cloned()
367 .collect())
368 } else {
369 Ok(cols.keys().cloned().collect())
370 }
371 }
372
373 fn get_column_type(&self, table: &str, column: &str) -> SchemaResult<DataType> {
374 let cols = self.find_table(table)?;
375 let normalized_col = self.normalize_name(column, false);
376
377 cols.get(&normalized_col)
378 .map(|info| info.data_type.clone())
379 .ok_or_else(|| SchemaError::ColumnNotFound {
380 table: table.to_string(),
381 column: column.to_string(),
382 })
383 }
384
385 fn has_column(&self, table: &str, column: &str) -> bool {
386 self.get_column_type(table, column).is_ok()
387 }
388
389 fn supported_table_args(&self) -> &[&str] {
390 let depth = self.depth();
391 if depth == 0 {
392 &[]
393 } else if depth <= 3 {
394 &TABLE_PARTS[..depth]
395 } else {
396 TABLE_PARTS
397 }
398 }
399
400 fn is_empty(&self) -> bool {
401 self.mapping.is_empty()
402 }
403
404 fn depth(&self) -> usize {
405 self.cached_depth
406 }
407}
408
409pub fn normalize_name(
411 name: &str,
412 dialect: Option<DialectType>,
413 is_table: bool,
414 normalize: bool,
415) -> String {
416 if !normalize {
417 return name.to_string();
418 }
419
420 match dialect {
421 Some(DialectType::BigQuery) if is_table => name.to_string(),
422 Some(DialectType::Snowflake) => name.to_uppercase(),
423 _ => name.to_lowercase(),
424 }
425}
426
427pub fn ensure_schema(schema: Option<MappingSchema>) -> MappingSchema {
429 schema.unwrap_or_default()
430}
431
432pub fn from_simple_map(tables: &[(&str, &[(&str, DataType)])]) -> MappingSchema {
448 let mut schema = MappingSchema::new();
449
450 for (table_name, columns) in tables {
451 let cols: Vec<(String, DataType)> = columns
452 .iter()
453 .map(|(name, dtype)| (name.to_string(), dtype.clone()))
454 .collect();
455
456 schema.add_table(table_name, &cols, None).ok();
457 }
458
459 schema
460}
461
462pub fn flatten_schema_paths(schema: &MappingSchema) -> Vec<Vec<String>> {
464 let mut paths = Vec::new();
465 flatten_schema_paths_recursive(&schema.mapping, Vec::new(), &mut paths);
466 paths
467}
468
469fn flatten_schema_paths_recursive(
470 mapping: &HashMap<String, SchemaNode>,
471 prefix: Vec<String>,
472 paths: &mut Vec<Vec<String>>,
473) {
474 for (key, node) in mapping {
475 let mut path = prefix.clone();
476 path.push(key.clone());
477
478 match node {
479 SchemaNode::Namespace(inner) => {
480 flatten_schema_paths_recursive(inner, path, paths);
481 }
482 SchemaNode::Table(_) => {
483 paths.push(path);
484 }
485 }
486 }
487}
488
/// Stores `value` at `map[keys[0]][keys[1]]`, creating the inner map if needed.
///
/// Fewer than two keys is a no-op; keys beyond the second are ignored. An
/// existing value at the same position is overwritten.
pub fn nested_set<V: Clone>(
    map: &mut HashMap<String, HashMap<String, V>>,
    keys: &[String],
    value: V,
) {
    if let [outer, inner, ..] = keys {
        map.entry(outer.clone())
            .or_default()
            .insert(inner.clone(), value);
    }
}
511
/// Looks up `map[keys[0]][keys[1]]`.
///
/// Returns `None` unless `keys` has exactly two elements and both levels
/// contain the respective key.
pub fn nested_get<'a, V>(
    map: &'a HashMap<String, HashMap<String, V>>,
    keys: &[String],
) -> Option<&'a V> {
    match keys {
        [outer, inner] => map.get(outer).and_then(|row| row.get(inner)),
        _ => None,
    }
}
523
#[cfg(test)]
mod tests {
    use super::*;

    /// `INT` with no length, as used by every fixture below.
    fn int_type() -> DataType {
        DataType::Int {
            length: None,
            integer_spelling: false,
        }
    }

    /// `VARCHAR` with no length.
    fn varchar_type() -> DataType {
        DataType::VarChar {
            length: None,
            parenthesized_length: false,
        }
    }

    /// `VARCHAR(255)`.
    fn varchar_255_type() -> DataType {
        DataType::VarChar {
            length: Some(255),
            parenthesized_length: false,
        }
    }

    /// A single `id INT` column list.
    fn id_only() -> Vec<(String, DataType)> {
        vec![("id".to_string(), int_type())]
    }

    #[test]
    fn test_empty_schema() {
        let schema = MappingSchema::new();
        assert!(schema.is_empty());
        assert_eq!(schema.depth(), 0);
    }

    #[test]
    fn test_add_table() {
        let mut schema = MappingSchema::new();
        let columns = vec![
            ("id".to_string(), int_type()),
            ("name".to_string(), varchar_255_type()),
        ];

        schema.add_table("users", &columns, None).unwrap();

        assert!(!schema.is_empty());
        assert_eq!(schema.depth(), 1);
        assert!(schema.has_column("users", "id"));
        assert!(schema.has_column("users", "name"));
        assert!(!schema.has_column("users", "email"));
    }

    #[test]
    fn test_qualified_table_names() {
        let mut schema = MappingSchema::new();
        let columns = id_only();

        schema.add_table("mydb.users", &columns, None).unwrap();

        assert!(schema.has_column("mydb.users", "id"));
        assert_eq!(schema.depth(), 2);
    }

    #[test]
    fn test_catalog_db_table() {
        let mut schema = MappingSchema::new();
        let columns = id_only();

        schema
            .add_table("catalog.mydb.users", &columns, None)
            .unwrap();

        assert!(schema.has_column("catalog.mydb.users", "id"));
        assert_eq!(schema.depth(), 3);
    }

    #[test]
    fn test_get_column_type() {
        let mut schema = MappingSchema::new();
        let columns = vec![
            ("id".to_string(), int_type()),
            ("name".to_string(), varchar_255_type()),
        ];

        schema.add_table("users", &columns, None).unwrap();

        let id_type = schema.get_column_type("users", "id").unwrap();
        assert!(matches!(id_type, DataType::Int { .. }));

        let name_type = schema.get_column_type("users", "name").unwrap();
        assert!(matches!(
            name_type,
            DataType::VarChar {
                length: Some(255),
                parenthesized_length: false
            }
        ));
    }

    #[test]
    fn test_column_names() {
        let mut schema = MappingSchema::new();
        let columns = vec![
            ("id".to_string(), int_type()),
            ("name".to_string(), varchar_type()),
        ];

        schema.add_table("users", &columns, None).unwrap();

        let names = schema.column_names("users").unwrap();
        assert_eq!(names.len(), 2);
        assert!(names.contains(&"id".to_string()));
        assert!(names.contains(&"name".to_string()));
    }

    #[test]
    fn test_table_not_found() {
        let schema = MappingSchema::new();
        let result = schema.column_names("nonexistent");
        assert!(matches!(result, Err(SchemaError::TableNotFound(_))));
    }

    #[test]
    fn test_column_not_found() {
        let mut schema = MappingSchema::new();
        let columns = id_only();
        schema.add_table("users", &columns, None).unwrap();

        let result = schema.get_column_type("users", "nonexistent");
        assert!(matches!(result, Err(SchemaError::ColumnNotFound { .. })));
    }

    #[test]
    fn test_normalize_name_default() {
        let name = normalize_name("MyTable", None, true, true);
        assert_eq!(name, "mytable");
    }

    #[test]
    fn test_normalize_name_snowflake() {
        let name = normalize_name("MyTable", Some(DialectType::Snowflake), true, true);
        assert_eq!(name, "MYTABLE");
    }

    #[test]
    fn test_normalize_disabled() {
        let name = normalize_name("MyTable", None, true, false);
        assert_eq!(name, "MyTable");
    }

    #[test]
    fn test_from_simple_map() {
        let schema = from_simple_map(&[
            ("users", &[("id", int_type()), ("name", varchar_type())]),
            ("orders", &[("id", int_type()), ("user_id", int_type())]),
        ]);

        assert!(schema.has_column("users", "id"));
        assert!(schema.has_column("users", "name"));
        assert!(schema.has_column("orders", "id"));
        assert!(schema.has_column("orders", "user_id"));
    }

    #[test]
    fn test_flatten_schema_paths() {
        let mut schema = MappingSchema::new();
        schema.add_table("db1.table1", &id_only(), None).unwrap();
        schema.add_table("db1.table2", &id_only(), None).unwrap();
        schema.add_table("db2.table1", &id_only(), None).unwrap();

        let paths = flatten_schema_paths(&schema);
        assert_eq!(paths.len(), 3);
    }

    #[test]
    fn test_visible_columns() {
        let mut schema = MappingSchema::new();
        let columns = vec![
            ("id".to_string(), int_type()),
            ("name".to_string(), varchar_type()),
            ("password".to_string(), varchar_type()),
        ];
        schema.add_table("users", &columns, None).unwrap();
        schema.set_visible_columns("users", &["id", "name"]);

        let names = schema.column_names("users").unwrap();
        assert_eq!(names.len(), 2);
        assert!(names.contains(&"id".to_string()));
        assert!(names.contains(&"name".to_string()));
        assert!(!names.contains(&"password".to_string()));
    }
}