Skip to main content

polyglot_sql/
schema.rs

1//! Schema management for SQL queries
2//!
3//! This module provides functionality for:
4//! - Representing database schemas (tables, columns, types)
5//! - Looking up column types for query optimization
6//! - Normalizing identifiers per dialect
7//!
8//! Based on the Python implementation in `sqlglot/schema.py`.
9
10use crate::dialects::DialectType;
11use crate::expressions::DataType;
12use crate::trie::{Trie, TrieResult};
13use std::collections::{HashMap, HashSet};
14use thiserror::Error;
15
16/// Errors that can occur during schema operations
17#[derive(Debug, Error, Clone)]
18pub enum SchemaError {
19    #[error("Table not found: {0}")]
20    TableNotFound(String),
21
22    #[error("Ambiguous table: {table} matches multiple tables: {matches}")]
23    AmbiguousTable { table: String, matches: String },
24
25    #[error("Column not found: {column} in table {table}")]
26    ColumnNotFound { table: String, column: String },
27
28    #[error("Schema nesting depth mismatch: expected {expected}, got {actual}")]
29    DepthMismatch { expected: usize, actual: usize },
30
31    #[error("Invalid schema structure: {0}")]
32    InvalidStructure(String),
33}
34
/// Result type for schema operations
///
/// Shorthand for `Result<T, SchemaError>`; every fallible function in this
/// module returns it.
pub type SchemaResult<T> = Result<T, SchemaError>;

/// Supported table argument names
///
/// `Schema::supported_table_args` returns a leading slice of this list whose
/// length equals the schema depth, so the order matters: depth 1 exposes only
/// `"this"`, depth 2 adds `"db"`, depth 3 adds `"catalog"`.
/// NOTE(review): presumably `"this"` denotes the table name itself, as in
/// sqlglot — confirm against the expression definitions.
pub const TABLE_PARTS: &[&str] = &["this", "db", "catalog"];
40
/// Abstract trait for database schemas
///
/// Implementors expose table/column metadata keyed by (possibly qualified)
/// table names given as plain strings. `MappingSchema` is the implementation
/// provided in this module.
pub trait Schema {
    /// Get the dialect associated with this schema (if any)
    fn dialect(&self) -> Option<DialectType>;

    /// Add or update a table in the schema
    ///
    /// * `table` - table name, optionally qualified with `.` separators.
    /// * `columns` - `(column name, column type)` pairs for the table.
    /// * `dialect` - optional dialect for this call. How it is used is up to
    ///   the implementor; `MappingSchema` currently ignores it.
    ///
    /// Returns an error if the table cannot be stored.
    fn add_table(
        &mut self,
        table: &str,
        columns: &[(String, DataType)],
        dialect: Option<DialectType>,
    ) -> SchemaResult<()>;

    /// Get column names for a table
    ///
    /// Returns an error if the table cannot be resolved.
    fn column_names(&self, table: &str) -> SchemaResult<Vec<String>>;

    /// Get the type of a column in a table
    ///
    /// Returns an error if either the table or the column cannot be resolved.
    fn get_column_type(&self, table: &str, column: &str) -> SchemaResult<DataType>;

    /// Check if a column exists in a table
    ///
    /// Infallible counterpart of `get_column_type`: lookup failures are
    /// reported as `false` rather than as an error.
    fn has_column(&self, table: &str, column: &str) -> bool;

    /// Get supported table argument levels
    fn supported_table_args(&self) -> &[&str];

    /// Check if the schema is empty
    fn is_empty(&self) -> bool;

    /// Get the nesting depth of the schema
    fn depth(&self) -> usize;
}
72
73/// A column with its type and visibility
74#[derive(Debug, Clone)]
75pub struct ColumnInfo {
76    pub data_type: DataType,
77    pub visible: bool,
78}
79
80impl ColumnInfo {
81    pub fn new(data_type: DataType) -> Self {
82        Self {
83            data_type,
84            visible: true,
85        }
86    }
87
88    pub fn with_visibility(data_type: DataType, visible: bool) -> Self {
89        Self { data_type, visible }
90    }
91}
92
/// A mapping-based schema implementation
///
/// Supports nested schemas with different levels:
/// - Level 1: `{table: {col: type}}`
/// - Level 2: `{db: {table: {col: type}}}`
/// - Level 3: `{catalog: {db: {table: {col: type}}}}`
///
/// Table names are split on `.` and each part is normalized (see
/// `normalize_name`) before being stored or looked up.
#[derive(Debug, Clone)]
pub struct MappingSchema {
    /// The actual schema data
    ///
    /// Root of the tree; keys are normalized name parts, outermost level first.
    mapping: HashMap<String, SchemaNode>,
    /// Trie for efficient table lookup
    ///
    /// Keys are the table's name parts in reverse order (table name first),
    /// joined with `.`.
    mapping_trie: Trie<()>,
    /// The dialect for this schema
    ///
    /// Drives the identifier normalization rules; `None` means the default
    /// (lowercasing) rules apply.
    dialect: Option<DialectType>,
    /// Whether to normalize identifiers
    ///
    /// When `false`, names are stored and looked up exactly as given.
    normalize: bool,
    /// Visible columns per table
    ///
    /// Keyed by the normalized table name as passed to `set_visible_columns`;
    /// when an entry exists, `column_names` only returns columns in the set.
    visible: HashMap<String, HashSet<String>>,
    /// Cached depth
    ///
    /// Recomputed after every successful `add_table`; `0` for an empty schema.
    cached_depth: usize,
}
114
/// A node in the schema tree
///
/// The tree is a map of name part -> node; a path through `Namespace` nodes
/// ends at a `Table` node holding that table's columns.
#[derive(Debug, Clone)]
pub enum SchemaNode {
    /// Intermediate node (database or catalog)
    ///
    /// Maps the next name part to a child node.
    Namespace(HashMap<String, SchemaNode>),
    /// Leaf node (table with columns)
    ///
    /// Maps normalized column names to their metadata.
    Table(HashMap<String, ColumnInfo>),
}
123
/// The default schema is the empty schema produced by `MappingSchema::new`:
/// no tables, no dialect, normalization enabled.
impl Default for MappingSchema {
    fn default() -> Self {
        Self::new()
    }
}
129
130impl MappingSchema {
    /// Create a new empty schema
    ///
    /// The result has no tables, no dialect, identifier normalization
    /// enabled, no visibility restrictions and a depth of `0`.
    pub fn new() -> Self {
        Self {
            mapping: HashMap::new(),
            mapping_trie: Trie::new(),
            dialect: None,
            normalize: true,
            visible: HashMap::new(),
            cached_depth: 0,
        }
    }
142
143    /// Create a schema with a specific dialect
144    pub fn with_dialect(dialect: DialectType) -> Self {
145        Self {
146            dialect: Some(dialect),
147            ..Self::new()
148        }
149    }
150
151    /// Create a schema with normalization disabled
152    pub fn without_normalization(mut self) -> Self {
153        self.normalize = false;
154        self
155    }
156
157    /// Set visibility for columns in a table
158    pub fn set_visible_columns(&mut self, table: &str, columns: &[&str]) {
159        let key = self.normalize_name(table, true);
160        let cols: HashSet<String> = columns
161            .iter()
162            .map(|c| self.normalize_name(c, false))
163            .collect();
164        self.visible.insert(key, cols);
165    }
166
167    /// Normalize an identifier name based on dialect
168    fn normalize_name(&self, name: &str, is_table: bool) -> String {
169        if !self.normalize {
170            return name.to_string();
171        }
172
173        // Default normalization: lowercase
174        // Different dialects may have different rules
175        match self.dialect {
176            Some(DialectType::BigQuery) if is_table => {
177                // BigQuery preserves case for tables
178                name.to_string()
179            }
180            Some(DialectType::Snowflake) => {
181                // Snowflake uppercases by default
182                name.to_uppercase()
183            }
184            _ => {
185                // Most dialects lowercase
186                name.to_lowercase()
187            }
188        }
189    }
190
191    /// Parse a qualified table name into parts
192    fn parse_table_parts(&self, table: &str) -> Vec<String> {
193        table
194            .split('.')
195            .map(|s| self.normalize_name(s.trim(), true))
196            .collect()
197    }
198
199    /// Get the column mapping for a table
200    fn find_table(&self, table: &str) -> SchemaResult<&HashMap<String, ColumnInfo>> {
201        let parts = self.parse_table_parts(table);
202
203        // Use trie to find table
204        let reversed_parts: Vec<_> = parts.iter().rev().map(|s| s.as_str()).collect();
205        let key: String = reversed_parts.join(".");
206
207        let (result, _) = self.mapping_trie.in_trie(&key);
208
209        match result {
210            TrieResult::Failed => Err(SchemaError::TableNotFound(table.to_string())),
211            TrieResult::Prefix => {
212                // Ambiguous - multiple tables match
213                Err(SchemaError::AmbiguousTable {
214                    table: table.to_string(),
215                    matches: "multiple matches".to_string(),
216                })
217            }
218            TrieResult::Exists => {
219                // Navigate to the table
220                self.navigate_to_table(&parts)
221            }
222        }
223    }
224
225    /// Navigate the schema tree to find a table's columns
226    fn navigate_to_table(&self, parts: &[String]) -> SchemaResult<&HashMap<String, ColumnInfo>> {
227        let mut current = &self.mapping;
228
229        for (i, part) in parts.iter().enumerate() {
230            match current.get(part) {
231                Some(SchemaNode::Namespace(inner)) => {
232                    current = inner;
233                }
234                Some(SchemaNode::Table(cols)) => {
235                    if i == parts.len() - 1 {
236                        return Ok(cols);
237                    } else {
238                        return Err(SchemaError::InvalidStructure(format!(
239                            "Found table at {} but expected more levels",
240                            parts[..=i].join(".")
241                        )));
242                    }
243                }
244                None => {
245                    return Err(SchemaError::TableNotFound(parts.join(".")));
246                }
247            }
248        }
249
250        // We've exhausted parts but didn't find a table
251        Err(SchemaError::TableNotFound(parts.join(".")))
252    }
253
254    /// Add a table to the schema
255    fn add_table_internal(
256        &mut self,
257        parts: &[String],
258        columns: HashMap<String, ColumnInfo>,
259    ) -> SchemaResult<()> {
260        if parts.is_empty() {
261            return Err(SchemaError::InvalidStructure(
262                "Table name cannot be empty".to_string(),
263            ));
264        }
265
266        // Build trie key (reversed parts)
267        let trie_key: String = parts.iter().rev().cloned().collect::<Vec<_>>().join(".");
268        self.mapping_trie.insert(&trie_key, ());
269
270        // Navigate/create path to table
271        let mut current = &mut self.mapping;
272
273        for (i, part) in parts.iter().enumerate() {
274            let is_last = i == parts.len() - 1;
275
276            if is_last {
277                // Insert table
278                current.insert(part.clone(), SchemaNode::Table(columns));
279                return Ok(());
280            } else {
281                // Navigate or create namespace
282                let entry = current
283                    .entry(part.clone())
284                    .or_insert_with(|| SchemaNode::Namespace(HashMap::new()));
285
286                match entry {
287                    SchemaNode::Namespace(inner) => {
288                        current = inner;
289                    }
290                    SchemaNode::Table(_) => {
291                        return Err(SchemaError::InvalidStructure(format!(
292                            "Expected namespace at {} but found table",
293                            parts[..=i].join(".")
294                        )));
295                    }
296                }
297            }
298        }
299
300        Ok(())
301    }
302
303    /// Update cached depth
304    fn update_depth(&mut self) {
305        self.cached_depth = self.calculate_depth(&self.mapping);
306    }
307
308    fn calculate_depth(&self, mapping: &HashMap<String, SchemaNode>) -> usize {
309        if mapping.is_empty() {
310            return 0;
311        }
312
313        let mut max_depth = 1;
314        for node in mapping.values() {
315            match node {
316                SchemaNode::Namespace(inner) => {
317                    let d = 1 + self.calculate_depth(inner);
318                    if d > max_depth {
319                        max_depth = d;
320                    }
321                }
322                SchemaNode::Table(_) => {
323                    // Tables don't add to depth beyond their level
324                }
325            }
326        }
327        max_depth
328    }
329}
330
331impl Schema for MappingSchema {
332    fn dialect(&self) -> Option<DialectType> {
333        self.dialect
334    }
335
336    fn add_table(
337        &mut self,
338        table: &str,
339        columns: &[(String, DataType)],
340        _dialect: Option<DialectType>,
341    ) -> SchemaResult<()> {
342        let parts = self.parse_table_parts(table);
343
344        let cols: HashMap<String, ColumnInfo> = columns
345            .iter()
346            .map(|(name, dtype)| {
347                let normalized_name = self.normalize_name(name, false);
348                (normalized_name, ColumnInfo::new(dtype.clone()))
349            })
350            .collect();
351
352        self.add_table_internal(&parts, cols)?;
353        self.update_depth();
354        Ok(())
355    }
356
357    fn column_names(&self, table: &str) -> SchemaResult<Vec<String>> {
358        let cols = self.find_table(table)?;
359        let table_key = self.normalize_name(table, true);
360
361        // Check visibility
362        if let Some(visible_cols) = self.visible.get(&table_key) {
363            Ok(cols
364                .keys()
365                .filter(|k| visible_cols.contains(*k))
366                .cloned()
367                .collect())
368        } else {
369            Ok(cols.keys().cloned().collect())
370        }
371    }
372
373    fn get_column_type(&self, table: &str, column: &str) -> SchemaResult<DataType> {
374        let cols = self.find_table(table)?;
375        let normalized_col = self.normalize_name(column, false);
376
377        cols.get(&normalized_col)
378            .map(|info| info.data_type.clone())
379            .ok_or_else(|| SchemaError::ColumnNotFound {
380                table: table.to_string(),
381                column: column.to_string(),
382            })
383    }
384
385    fn has_column(&self, table: &str, column: &str) -> bool {
386        self.get_column_type(table, column).is_ok()
387    }
388
389    fn supported_table_args(&self) -> &[&str] {
390        let depth = self.depth();
391        if depth == 0 {
392            &[]
393        } else if depth <= 3 {
394            &TABLE_PARTS[..depth]
395        } else {
396            TABLE_PARTS
397        }
398    }
399
400    fn is_empty(&self) -> bool {
401        self.mapping.is_empty()
402    }
403
404    fn depth(&self) -> usize {
405        self.cached_depth
406    }
407}
408
409/// Normalize a table or column name according to dialect rules
410pub fn normalize_name(
411    name: &str,
412    dialect: Option<DialectType>,
413    is_table: bool,
414    normalize: bool,
415) -> String {
416    if !normalize {
417        return name.to_string();
418    }
419
420    match dialect {
421        Some(DialectType::BigQuery) if is_table => name.to_string(),
422        Some(DialectType::Snowflake) => name.to_uppercase(),
423        _ => name.to_lowercase(),
424    }
425}
426
427/// Ensure we have a schema instance
428pub fn ensure_schema(schema: Option<MappingSchema>) -> MappingSchema {
429    schema.unwrap_or_default()
430}
431
432/// Helper to build a schema from a simple map
433///
434/// # Example
435///
436/// ```
437/// use polyglot_sql::schema::{MappingSchema, Schema, from_simple_map};
438/// use polyglot_sql::expressions::DataType;
439///
440/// let schema = from_simple_map(&[
441///     ("users", &[("id", DataType::Int { length: None, integer_spelling: false }), ("name", DataType::VarChar { length: Some(255), parenthesized_length: false })]),
442///     ("orders", &[("id", DataType::Int { length: None, integer_spelling: false }), ("user_id", DataType::Int { length: None, integer_spelling: false })]),
443/// ]);
444///
445/// assert_eq!(schema.column_names("users").unwrap().len(), 2);
446/// ```
447pub fn from_simple_map(tables: &[(&str, &[(&str, DataType)])]) -> MappingSchema {
448    let mut schema = MappingSchema::new();
449
450    for (table_name, columns) in tables {
451        let cols: Vec<(String, DataType)> = columns
452            .iter()
453            .map(|(name, dtype)| (name.to_string(), dtype.clone()))
454            .collect();
455
456        schema.add_table(table_name, &cols, None).ok();
457    }
458
459    schema
460}
461
462/// Flatten a nested schema to get all table paths
463pub fn flatten_schema_paths(schema: &MappingSchema) -> Vec<Vec<String>> {
464    let mut paths = Vec::new();
465    flatten_schema_paths_recursive(&schema.mapping, Vec::new(), &mut paths);
466    paths
467}
468
469fn flatten_schema_paths_recursive(
470    mapping: &HashMap<String, SchemaNode>,
471    prefix: Vec<String>,
472    paths: &mut Vec<Vec<String>>,
473) {
474    for (key, node) in mapping {
475        let mut path = prefix.clone();
476        path.push(key.clone());
477
478        match node {
479            SchemaNode::Namespace(inner) => {
480                flatten_schema_paths_recursive(inner, path, paths);
481            }
482            SchemaNode::Table(_) => {
483                paths.push(path);
484            }
485        }
486    }
487}
488
/// Set a value in a two-level nested map.
///
/// `keys[0]` selects (and creates, if missing) the outer entry and `keys[1]`
/// is the key inside it; an existing value at that position is replaced.
///
/// The call is a no-op when fewer than two keys are given, because there is
/// no inner key to store the value under. Keys beyond the second are ignored.
///
/// `value` is moved into the map, so `V` does not need to implement `Clone`.
pub fn nested_set<V>(
    map: &mut HashMap<String, HashMap<String, V>>,
    keys: &[String],
    value: V,
) {
    if let [outer_key, inner_key, ..] = keys {
        map.entry(outer_key.clone())
            .or_default()
            .insert(inner_key.clone(), value);
    }
}
511
/// Look up a value in a two-level nested map.
///
/// `keys` must contain exactly two entries: the outer key followed by the
/// inner key. Any other length, or a missing key at either level, yields
/// `None`.
pub fn nested_get<'a, V>(
    map: &'a HashMap<String, HashMap<String, V>>,
    keys: &[String],
) -> Option<&'a V> {
    match keys {
        [outer_key, inner_key] => map
            .get(outer_key)
            .and_then(|inner| inner.get(inner_key)),
        _ => None,
    }
}
523
#[cfg(test)]
mod tests {
    use super::*;

    // ---- shared fixtures -------------------------------------------------

    /// Plain `INT` type used by most fixtures.
    fn int_type() -> DataType {
        DataType::Int {
            length: None,
            integer_spelling: false,
        }
    }

    /// `VARCHAR` without a length.
    fn varchar_type() -> DataType {
        DataType::VarChar {
            length: None,
            parenthesized_length: false,
        }
    }

    /// `VARCHAR(255)`.
    fn varchar_255_type() -> DataType {
        DataType::VarChar {
            length: Some(255),
            parenthesized_length: false,
        }
    }

    /// Owned `(name, type)` pair as expected by `Schema::add_table`.
    fn col(name: &str, data_type: DataType) -> (String, DataType) {
        (name.to_string(), data_type)
    }

    /// Fresh default schema containing exactly one table.
    fn schema_with(table: &str, columns: &[(String, DataType)]) -> MappingSchema {
        let mut schema = MappingSchema::new();
        schema.add_table(table, columns, None).unwrap();
        schema
    }

    // ---- tests -----------------------------------------------------------

    #[test]
    fn test_empty_schema() {
        let schema = MappingSchema::new();
        assert!(schema.is_empty());
        assert_eq!(schema.depth(), 0);
    }

    #[test]
    fn test_add_table() {
        let schema = schema_with(
            "users",
            &[col("id", int_type()), col("name", varchar_255_type())],
        );

        assert!(!schema.is_empty());
        assert_eq!(schema.depth(), 1);
        assert!(schema.has_column("users", "id"));
        assert!(schema.has_column("users", "name"));
        assert!(!schema.has_column("users", "email"));
    }

    #[test]
    fn test_qualified_table_names() {
        let schema = schema_with("mydb.users", &[col("id", int_type())]);

        assert!(schema.has_column("mydb.users", "id"));
        assert_eq!(schema.depth(), 2);
    }

    #[test]
    fn test_catalog_db_table() {
        let schema = schema_with("catalog.mydb.users", &[col("id", int_type())]);

        assert!(schema.has_column("catalog.mydb.users", "id"));
        assert_eq!(schema.depth(), 3);
    }

    #[test]
    fn test_get_column_type() {
        let schema = schema_with(
            "users",
            &[col("id", int_type()), col("name", varchar_255_type())],
        );

        let id_type = schema.get_column_type("users", "id").unwrap();
        assert!(matches!(id_type, DataType::Int { .. }));

        let name_type = schema.get_column_type("users", "name").unwrap();
        assert!(matches!(
            name_type,
            DataType::VarChar {
                length: Some(255),
                parenthesized_length: false
            }
        ));
    }

    #[test]
    fn test_column_names() {
        let schema = schema_with(
            "users",
            &[col("id", int_type()), col("name", varchar_type())],
        );

        let names = schema.column_names("users").unwrap();
        assert_eq!(names.len(), 2);
        assert!(names.contains(&"id".to_string()));
        assert!(names.contains(&"name".to_string()));
    }

    #[test]
    fn test_table_not_found() {
        let schema = MappingSchema::new();
        let result = schema.column_names("nonexistent");
        assert!(matches!(result, Err(SchemaError::TableNotFound(_))));
    }

    #[test]
    fn test_column_not_found() {
        let schema = schema_with("users", &[col("id", int_type())]);

        let result = schema.get_column_type("users", "nonexistent");
        assert!(matches!(result, Err(SchemaError::ColumnNotFound { .. })));
    }

    #[test]
    fn test_normalize_name_default() {
        let name = normalize_name("MyTable", None, true, true);
        assert_eq!(name, "mytable");
    }

    #[test]
    fn test_normalize_name_snowflake() {
        let name = normalize_name("MyTable", Some(DialectType::Snowflake), true, true);
        assert_eq!(name, "MYTABLE");
    }

    #[test]
    fn test_normalize_disabled() {
        let name = normalize_name("MyTable", None, true, false);
        assert_eq!(name, "MyTable");
    }

    #[test]
    fn test_from_simple_map() {
        let schema = from_simple_map(&[
            (
                "users",
                &[("id", int_type()), ("name", varchar_type())],
            ),
            (
                "orders",
                &[("id", int_type()), ("user_id", int_type())],
            ),
        ]);

        assert!(schema.has_column("users", "id"));
        assert!(schema.has_column("users", "name"));
        assert!(schema.has_column("orders", "id"));
        assert!(schema.has_column("orders", "user_id"));
    }

    #[test]
    fn test_flatten_schema_paths() {
        let mut schema = MappingSchema::new();
        for table in ["db1.table1", "db1.table2", "db2.table1"].iter() {
            schema
                .add_table(table, &[col("id", int_type())], None)
                .unwrap();
        }

        let paths = flatten_schema_paths(&schema);
        assert_eq!(paths.len(), 3);
    }

    #[test]
    fn test_visible_columns() {
        let mut schema = schema_with(
            "users",
            &[
                col("id", int_type()),
                col("name", varchar_type()),
                col("password", varchar_type()),
            ],
        );
        schema.set_visible_columns("users", &["id", "name"]);

        let names = schema.column_names("users").unwrap();
        assert_eq!(names.len(), 2);
        assert!(names.contains(&"id".to_string()));
        assert!(names.contains(&"name".to_string()));
        assert!(!names.contains(&"password".to_string()));
    }
}