1use crate::dialects::DialectType;
11use crate::expressions::DataType;
12use crate::trie::{Trie, TrieResult};
13use std::collections::{HashMap, HashSet};
14use thiserror::Error;
15
16#[derive(Debug, Error, Clone)]
18pub enum SchemaError {
19 #[error("Table not found: {0}")]
20 TableNotFound(String),
21
22 #[error("Ambiguous table: {table} matches multiple tables: {matches}")]
23 AmbiguousTable { table: String, matches: String },
24
25 #[error("Column not found: {column} in table {table}")]
26 ColumnNotFound { table: String, column: String },
27
28 #[error("Schema nesting depth mismatch: expected {expected}, got {actual}")]
29 DepthMismatch { expected: usize, actual: usize },
30
31 #[error("Invalid schema structure: {0}")]
32 InvalidStructure(String),
33}
34
35pub type SchemaResult<T> = Result<T, SchemaError>;
37
38pub const TABLE_PARTS: &[&str] = &["this", "db", "catalog"];
40
41pub trait Schema {
43 fn dialect(&self) -> Option<DialectType>;
45
46 fn add_table(
48 &mut self,
49 table: &str,
50 columns: &[(String, DataType)],
51 dialect: Option<DialectType>,
52 ) -> SchemaResult<()>;
53
54 fn column_names(&self, table: &str) -> SchemaResult<Vec<String>>;
56
57 fn get_column_type(&self, table: &str, column: &str) -> SchemaResult<DataType>;
59
60 fn has_column(&self, table: &str, column: &str) -> bool;
62
63 fn supported_table_args(&self) -> &[&str];
65
66 fn is_empty(&self) -> bool;
68
69 fn depth(&self) -> usize;
71}
72
73#[derive(Debug, Clone)]
75pub struct ColumnInfo {
76 pub data_type: DataType,
77 pub visible: bool,
78}
79
80impl ColumnInfo {
81 pub fn new(data_type: DataType) -> Self {
82 Self {
83 data_type,
84 visible: true,
85 }
86 }
87
88 pub fn with_visibility(data_type: DataType, visible: bool) -> Self {
89 Self { data_type, visible }
90 }
91}
92
93#[derive(Debug, Clone)]
100pub struct MappingSchema {
101 mapping: HashMap<String, SchemaNode>,
103 mapping_trie: Trie<()>,
105 dialect: Option<DialectType>,
107 normalize: bool,
109 visible: HashMap<String, HashSet<String>>,
111 cached_depth: usize,
113}
114
115#[derive(Debug, Clone)]
117pub enum SchemaNode {
118 Namespace(HashMap<String, SchemaNode>),
120 Table(HashMap<String, ColumnInfo>),
122}
123
124impl Default for MappingSchema {
125 fn default() -> Self {
126 Self::new()
127 }
128}
129
130impl MappingSchema {
131 pub fn new() -> Self {
133 Self {
134 mapping: HashMap::new(),
135 mapping_trie: Trie::new(),
136 dialect: None,
137 normalize: true,
138 visible: HashMap::new(),
139 cached_depth: 0,
140 }
141 }
142
143 pub fn with_dialect(dialect: DialectType) -> Self {
145 Self {
146 dialect: Some(dialect),
147 ..Self::new()
148 }
149 }
150
151 pub fn without_normalization(mut self) -> Self {
153 self.normalize = false;
154 self
155 }
156
157 pub fn set_visible_columns(&mut self, table: &str, columns: &[&str]) {
159 let key = self.normalize_name(table, true);
160 let cols: HashSet<String> = columns
161 .iter()
162 .map(|c| self.normalize_name(c, false))
163 .collect();
164 self.visible.insert(key, cols);
165 }
166
167 fn normalize_name(&self, name: &str, is_table: bool) -> String {
169 if !self.normalize {
170 return name.to_string();
171 }
172
173 match self.dialect {
176 Some(DialectType::BigQuery) if is_table => {
177 name.to_string()
179 }
180 Some(DialectType::Snowflake) => {
181 name.to_uppercase()
183 }
184 _ => {
185 name.to_lowercase()
187 }
188 }
189 }
190
191 fn parse_table_parts(&self, table: &str) -> Vec<String> {
193 table
194 .split('.')
195 .map(|s| self.normalize_name(s.trim(), true))
196 .collect()
197 }
198
199 fn find_table(&self, table: &str) -> SchemaResult<&HashMap<String, ColumnInfo>> {
201 let parts = self.parse_table_parts(table);
202
203 let reversed_parts: Vec<_> = parts.iter().rev().map(|s| s.as_str()).collect();
205 let key: String = reversed_parts.join(".");
206
207 let (result, _) = self.mapping_trie.in_trie(&key);
208
209 match result {
210 TrieResult::Failed => Err(SchemaError::TableNotFound(table.to_string())),
211 TrieResult::Prefix => {
212 Err(SchemaError::AmbiguousTable {
214 table: table.to_string(),
215 matches: "multiple matches".to_string(),
216 })
217 }
218 TrieResult::Exists => {
219 self.navigate_to_table(&parts)
221 }
222 }
223 }
224
225 fn navigate_to_table(&self, parts: &[String]) -> SchemaResult<&HashMap<String, ColumnInfo>> {
227 let mut current = &self.mapping;
228
229 for (i, part) in parts.iter().enumerate() {
230 match current.get(part) {
231 Some(SchemaNode::Namespace(inner)) => {
232 current = inner;
233 }
234 Some(SchemaNode::Table(cols)) => {
235 if i == parts.len() - 1 {
236 return Ok(cols);
237 } else {
238 return Err(SchemaError::InvalidStructure(format!(
239 "Found table at {} but expected more levels",
240 parts[..=i].join(".")
241 )));
242 }
243 }
244 None => {
245 return Err(SchemaError::TableNotFound(parts.join(".")));
246 }
247 }
248 }
249
250 Err(SchemaError::TableNotFound(parts.join(".")))
252 }
253
254 fn add_table_internal(
256 &mut self,
257 parts: &[String],
258 columns: HashMap<String, ColumnInfo>,
259 ) -> SchemaResult<()> {
260 if parts.is_empty() {
261 return Err(SchemaError::InvalidStructure(
262 "Table name cannot be empty".to_string(),
263 ));
264 }
265
266 let trie_key: String = parts.iter().rev().cloned().collect::<Vec<_>>().join(".");
268 self.mapping_trie.insert(&trie_key, ());
269
270 let mut current = &mut self.mapping;
272
273 for (i, part) in parts.iter().enumerate() {
274 let is_last = i == parts.len() - 1;
275
276 if is_last {
277 current.insert(part.clone(), SchemaNode::Table(columns));
279 return Ok(());
280 } else {
281 let entry = current
283 .entry(part.clone())
284 .or_insert_with(|| SchemaNode::Namespace(HashMap::new()));
285
286 match entry {
287 SchemaNode::Namespace(inner) => {
288 current = inner;
289 }
290 SchemaNode::Table(_) => {
291 return Err(SchemaError::InvalidStructure(format!(
292 "Expected namespace at {} but found table",
293 parts[..=i].join(".")
294 )));
295 }
296 }
297 }
298 }
299
300 Ok(())
301 }
302
303 fn update_depth(&mut self) {
305 self.cached_depth = self.calculate_depth(&self.mapping);
306 }
307
308 fn calculate_depth(&self, mapping: &HashMap<String, SchemaNode>) -> usize {
309 if mapping.is_empty() {
310 return 0;
311 }
312
313 let mut max_depth = 1;
314 for node in mapping.values() {
315 match node {
316 SchemaNode::Namespace(inner) => {
317 let d = 1 + self.calculate_depth(inner);
318 if d > max_depth {
319 max_depth = d;
320 }
321 }
322 SchemaNode::Table(_) => {
323 }
325 }
326 }
327 max_depth
328 }
329}
330
331impl Schema for MappingSchema {
332 fn dialect(&self) -> Option<DialectType> {
333 self.dialect
334 }
335
336 fn add_table(
337 &mut self,
338 table: &str,
339 columns: &[(String, DataType)],
340 _dialect: Option<DialectType>,
341 ) -> SchemaResult<()> {
342 let parts = self.parse_table_parts(table);
343
344 let cols: HashMap<String, ColumnInfo> = columns
345 .iter()
346 .map(|(name, dtype)| {
347 let normalized_name = self.normalize_name(name, false);
348 (normalized_name, ColumnInfo::new(dtype.clone()))
349 })
350 .collect();
351
352 self.add_table_internal(&parts, cols)?;
353 self.update_depth();
354 Ok(())
355 }
356
357 fn column_names(&self, table: &str) -> SchemaResult<Vec<String>> {
358 let cols = self.find_table(table)?;
359 let table_key = self.normalize_name(table, true);
360
361 if let Some(visible_cols) = self.visible.get(&table_key) {
363 Ok(cols
364 .keys()
365 .filter(|k| visible_cols.contains(*k))
366 .cloned()
367 .collect())
368 } else {
369 Ok(cols.keys().cloned().collect())
370 }
371 }
372
373 fn get_column_type(&self, table: &str, column: &str) -> SchemaResult<DataType> {
374 let cols = self.find_table(table)?;
375 let normalized_col = self.normalize_name(column, false);
376
377 cols.get(&normalized_col)
378 .map(|info| info.data_type.clone())
379 .ok_or_else(|| SchemaError::ColumnNotFound {
380 table: table.to_string(),
381 column: column.to_string(),
382 })
383 }
384
385 fn has_column(&self, table: &str, column: &str) -> bool {
386 self.get_column_type(table, column).is_ok()
387 }
388
389 fn supported_table_args(&self) -> &[&str] {
390 let depth = self.depth();
391 if depth == 0 {
392 &[]
393 } else if depth <= 3 {
394 &TABLE_PARTS[..depth]
395 } else {
396 TABLE_PARTS
397 }
398 }
399
400 fn is_empty(&self) -> bool {
401 self.mapping.is_empty()
402 }
403
404 fn depth(&self) -> usize {
405 self.cached_depth
406 }
407}
408
409pub fn normalize_name(
411 name: &str,
412 dialect: Option<DialectType>,
413 is_table: bool,
414 normalize: bool,
415) -> String {
416 if !normalize {
417 return name.to_string();
418 }
419
420 match dialect {
421 Some(DialectType::BigQuery) if is_table => name.to_string(),
422 Some(DialectType::Snowflake) => name.to_uppercase(),
423 _ => name.to_lowercase(),
424 }
425}
426
427pub fn ensure_schema(schema: Option<MappingSchema>) -> MappingSchema {
429 schema.unwrap_or_default()
430}
431
432pub fn from_simple_map(tables: &[(&str, &[(&str, DataType)])]) -> MappingSchema {
448 let mut schema = MappingSchema::new();
449
450 for (table_name, columns) in tables {
451 let cols: Vec<(String, DataType)> = columns
452 .iter()
453 .map(|(name, dtype)| (name.to_string(), dtype.clone()))
454 .collect();
455
456 schema.add_table(table_name, &cols, None).ok();
457 }
458
459 schema
460}
461
462pub fn flatten_schema_paths(schema: &MappingSchema) -> Vec<Vec<String>> {
464 let mut paths = Vec::new();
465 flatten_schema_paths_recursive(&schema.mapping, Vec::new(), &mut paths);
466 paths
467}
468
469fn flatten_schema_paths_recursive(
470 mapping: &HashMap<String, SchemaNode>,
471 prefix: Vec<String>,
472 paths: &mut Vec<Vec<String>>,
473) {
474 for (key, node) in mapping {
475 let mut path = prefix.clone();
476 path.push(key.clone());
477
478 match node {
479 SchemaNode::Namespace(inner) => {
480 flatten_schema_paths_recursive(inner, path, paths);
481 }
482 SchemaNode::Table(_) => {
483 paths.push(path);
484 }
485 }
486 }
487}
488
/// Stores `value` at `map[keys[0]][keys[1]]`, creating the inner map if needed.
///
/// A two-level map can only address a path of exactly two keys. Any other length
/// is ignored. This matches `nested_get`, which returns `None` for such paths.
/// Ignoring them avoids silently overwriting `map[k0][k1]` when given a longer
/// path that `nested_get` could never read back.
pub fn nested_set<V>(map: &mut HashMap<String, HashMap<String, V>>, keys: &[String], value: V) {
    if let [outer, inner] = keys {
        map.entry(outer.clone())
            .or_default()
            .insert(inner.clone(), value);
    }
}
511
/// Reads `map[keys[0]][keys[1]]`.
///
/// Returns `None` when `keys` does not have exactly two elements or when either
/// level is missing.
pub fn nested_get<'a, V>(
    map: &'a HashMap<String, HashMap<String, V>>,
    keys: &[String],
) -> Option<&'a V> {
    match keys {
        [outer, inner] => map.get(outer).and_then(|row| row.get(inner)),
        _ => None,
    }
}
523
#[cfg(test)]
mod tests {
    use super::*;

    /// `INT` with no explicit length.
    fn int() -> DataType {
        DataType::Int {
            length: None,
            integer_spelling: false,
        }
    }

    /// `VARCHAR` with no explicit length.
    fn varchar() -> DataType {
        DataType::VarChar {
            length: None,
            parenthesized_length: false,
        }
    }

    /// `VARCHAR(255)`.
    fn varchar_255() -> DataType {
        DataType::VarChar {
            length: Some(255),
            parenthesized_length: false,
        }
    }

    /// Converts borrowed `(name, type)` pairs into the owned form `add_table` expects.
    fn columns(spec: &[(&str, DataType)]) -> Vec<(String, DataType)> {
        spec.iter()
            .map(|(name, ty)| (name.to_string(), ty.clone()))
            .collect()
    }

    /// Builds a schema containing exactly one table.
    fn schema_with(table: &str, spec: &[(&str, DataType)]) -> MappingSchema {
        let mut schema = MappingSchema::new();
        schema.add_table(table, &columns(spec), None).unwrap();
        schema
    }

    #[test]
    fn test_empty_schema() {
        let schema = MappingSchema::new();
        assert!(schema.is_empty());
        assert_eq!(schema.depth(), 0);
    }

    #[test]
    fn test_add_table() {
        let schema = schema_with("users", &[("id", int()), ("name", varchar_255())]);

        assert!(!schema.is_empty());
        assert_eq!(schema.depth(), 1);
        assert!(schema.has_column("users", "id"));
        assert!(schema.has_column("users", "name"));
        assert!(!schema.has_column("users", "email"));
    }

    #[test]
    fn test_qualified_table_names() {
        let schema = schema_with("mydb.users", &[("id", int())]);

        assert!(schema.has_column("mydb.users", "id"));
        assert_eq!(schema.depth(), 2);
    }

    #[test]
    fn test_catalog_db_table() {
        let schema = schema_with("catalog.mydb.users", &[("id", int())]);

        assert!(schema.has_column("catalog.mydb.users", "id"));
        assert_eq!(schema.depth(), 3);
    }

    #[test]
    fn test_get_column_type() {
        let schema = schema_with("users", &[("id", int()), ("name", varchar_255())]);

        let id_type = schema.get_column_type("users", "id").unwrap();
        assert!(matches!(id_type, DataType::Int { .. }));

        let name_type = schema.get_column_type("users", "name").unwrap();
        assert!(matches!(
            name_type,
            DataType::VarChar {
                length: Some(255),
                parenthesized_length: false
            }
        ));
    }

    #[test]
    fn test_column_names() {
        let schema = schema_with("users", &[("id", int()), ("name", varchar())]);

        let names = schema.column_names("users").unwrap();
        assert_eq!(names.len(), 2);
        for expected in &["id", "name"] {
            assert!(names.contains(&expected.to_string()));
        }
    }

    #[test]
    fn test_table_not_found() {
        let result = MappingSchema::new().column_names("nonexistent");
        assert!(matches!(result, Err(SchemaError::TableNotFound(_))));
    }

    #[test]
    fn test_column_not_found() {
        let schema = schema_with("users", &[("id", int())]);

        let result = schema.get_column_type("users", "nonexistent");
        assert!(matches!(result, Err(SchemaError::ColumnNotFound { .. })));
    }

    #[test]
    fn test_normalize_name_default() {
        assert_eq!(normalize_name("MyTable", None, true, true), "mytable");
    }

    #[test]
    fn test_normalize_name_snowflake() {
        assert_eq!(
            normalize_name("MyTable", Some(DialectType::Snowflake), true, true),
            "MYTABLE"
        );
    }

    #[test]
    fn test_normalize_disabled() {
        assert_eq!(normalize_name("MyTable", None, true, false), "MyTable");
    }

    #[test]
    fn test_from_simple_map() {
        let schema = from_simple_map(&[
            ("users", &[("id", int()), ("name", varchar())]),
            ("orders", &[("id", int()), ("user_id", int())]),
        ]);

        for &(table, column) in &[
            ("users", "id"),
            ("users", "name"),
            ("orders", "id"),
            ("orders", "user_id"),
        ] {
            assert!(schema.has_column(table, column));
        }
    }

    #[test]
    fn test_flatten_schema_paths() {
        let mut schema = MappingSchema::new();
        for table in &["db1.table1", "db1.table2", "db2.table1"] {
            schema
                .add_table(table, &columns(&[("id", int())]), None)
                .unwrap();
        }

        assert_eq!(flatten_schema_paths(&schema).len(), 3);
    }

    #[test]
    fn test_visible_columns() {
        let mut schema = schema_with(
            "users",
            &[("id", int()), ("name", varchar()), ("password", varchar())],
        );
        schema.set_visible_columns("users", &["id", "name"]);

        let names = schema.column_names("users").unwrap();
        assert_eq!(names.len(), 2);
        assert!(names.contains(&"id".to_string()));
        assert!(names.contains(&"name".to_string()));
        assert!(!names.contains(&"password".to_string()));
    }
}