Skip to main content

polyglot_sql/
schema.rs

1//! Schema management for SQL queries
2//!
3//! This module provides functionality for:
4//! - Representing database schemas (tables, columns, types)
5//! - Looking up column types for query optimization
6//! - Normalizing identifiers per dialect
7//!
8//! Based on the Python implementation in `sqlglot/schema.py`.
9
10use crate::dialects::DialectType;
11use crate::expressions::DataType;
12use crate::trie::{Trie, TrieResult};
13use std::collections::{HashMap, HashSet};
14use thiserror::Error;
15
/// Errors that can occur during schema operations.
///
/// The `#[error(...)]` attributes supply the `Display` text for each variant.
#[derive(Debug, Error, Clone)]
pub enum SchemaError {
    /// No table matched the requested name; carries the name that was looked up.
    #[error("Table not found: {0}")]
    TableNotFound(String),

    /// The requested name matched only a prefix of stored table keys, so it
    /// cannot be resolved to a single table.
    #[error("Ambiguous table: {table} matches multiple tables: {matches}")]
    AmbiguousTable { table: String, matches: String },

    /// The table exists but has no column with the requested name.
    #[error("Column not found: {column} in table {table}")]
    ColumnNotFound { table: String, column: String },

    /// A table name had a different number of parts than the schema expects.
    // NOTE(review): nothing in the visible part of this file constructs this
    // variant — confirm where (or whether) it is raised.
    #[error("Schema nesting depth mismatch: expected {expected}, got {actual}")]
    DepthMismatch { expected: usize, actual: usize },

    /// The schema tree does not have the shape the operation requires, e.g. a
    /// table sits where a namespace was expected (or vice versa).
    #[error("Invalid schema structure: {0}")]
    InvalidStructure(String),
}
34
/// Result type for schema operations; the error is always a [`SchemaError`].
pub type SchemaResult<T> = Result<T, SchemaError>;

/// Supported table argument names.
///
/// `MappingSchema::supported_table_args` returns the first `depth` entries of
/// this slice, so the order here is significant.
// NOTE(review): presumably "this" names the table itself, followed by its
// database and catalog — confirm against the callers that consume these names.
pub const TABLE_PARTS: &[&str] = &["this", "db", "catalog"];
40
/// Abstract trait for database schemas.
///
/// Table names are passed as strings; fallible operations report problems
/// through [`SchemaResult`].
pub trait Schema {
    /// Get the dialect associated with this schema (if any).
    fn dialect(&self) -> Option<DialectType>;

    /// Add or update a table in the schema.
    ///
    /// `columns` is a list of `(column name, data type)` pairs. `dialect`
    /// optionally names the dialect the identifiers are written in.
    fn add_table(
        &mut self,
        table: &str,
        columns: &[(String, DataType)],
        dialect: Option<DialectType>,
    ) -> SchemaResult<()>;

    /// Get column names for a table.
    fn column_names(&self, table: &str) -> SchemaResult<Vec<String>>;

    /// Get the type of a column in a table.
    fn get_column_type(&self, table: &str, column: &str) -> SchemaResult<DataType>;

    /// Check if a column exists in a table.
    ///
    /// Returns a plain `bool`, so lookup failures cannot be distinguished
    /// from a missing column through this method.
    fn has_column(&self, table: &str, column: &str) -> bool;

    /// Get supported table argument levels.
    fn supported_table_args(&self) -> &[&str];

    /// Check if the schema is empty.
    fn is_empty(&self) -> bool;

    /// Get the nesting depth of the schema.
    fn depth(&self) -> usize;

    /// Find which table(s) contain the given column name.
    /// Returns table names that have the column. Used for correlated subquery resolution.
    fn find_tables_for_column(&self, column: &str) -> Vec<String>;
}
76
/// A column with its type and visibility.
#[derive(Debug, Clone)]
pub struct ColumnInfo {
    /// Declared data type of the column.
    pub data_type: DataType,
    /// Visibility flag; `ColumnInfo::new` sets it to `true`.
    pub visible: bool,
}
83
84impl ColumnInfo {
85    pub fn new(data_type: DataType) -> Self {
86        Self {
87            data_type,
88            visible: true,
89        }
90    }
91
92    pub fn with_visibility(data_type: DataType, visible: bool) -> Self {
93        Self { data_type, visible }
94    }
95}
96
/// A mapping-based schema implementation
///
/// Supports nested schemas with different levels:
/// - Level 1: `{table: {col: type}}`
/// - Level 2: `{db: {table: {col: type}}}`
/// - Level 3: `{catalog: {db: {table: {col: type}}}}`
#[derive(Debug, Clone)]
pub struct MappingSchema {
    /// The actual schema data, keyed by the outermost part of each table name
    mapping: HashMap<String, SchemaNode>,
    /// Trie for efficient table lookup; keys are the dotted table name with
    /// its parts in reversed order
    mapping_trie: Trie<()>,
    /// The dialect for this schema
    dialect: Option<DialectType>,
    /// Whether to normalize identifiers (enabled by `new`)
    normalize: bool,
    /// Visible columns per table, keyed by the normalized table name; a table
    /// without an entry exposes all of its columns
    visible: HashMap<String, HashSet<String>>,
    /// Cached depth, recomputed whenever a table is added
    cached_depth: usize,
}
118
/// A node in the schema tree.
///
/// The tree is keyed by identifier at every level: namespaces map names to
/// child nodes, tables map column names to their [`ColumnInfo`].
#[derive(Debug, Clone)]
pub enum SchemaNode {
    /// Intermediate node (database or catalog)
    Namespace(HashMap<String, SchemaNode>),
    /// Leaf node (table with columns)
    Table(HashMap<String, ColumnInfo>),
}
127
128impl Default for MappingSchema {
129    fn default() -> Self {
130        Self::new()
131    }
132}
133
134impl MappingSchema {
135    /// Create a new empty schema
136    pub fn new() -> Self {
137        Self {
138            mapping: HashMap::new(),
139            mapping_trie: Trie::new(),
140            dialect: None,
141            normalize: true,
142            visible: HashMap::new(),
143            cached_depth: 0,
144        }
145    }
146
147    /// Create a schema with a specific dialect
148    pub fn with_dialect(dialect: DialectType) -> Self {
149        Self {
150            dialect: Some(dialect),
151            ..Self::new()
152        }
153    }
154
155    /// Create a schema with normalization disabled
156    pub fn without_normalization(mut self) -> Self {
157        self.normalize = false;
158        self
159    }
160
161    /// Set visibility for columns in a table
162    pub fn set_visible_columns(&mut self, table: &str, columns: &[&str]) {
163        let key = self.normalize_name(table, true);
164        let cols: HashSet<String> = columns
165            .iter()
166            .map(|c| self.normalize_name(c, false))
167            .collect();
168        self.visible.insert(key, cols);
169    }
170
171    /// Normalize an identifier name based on dialect
172    fn normalize_name(&self, name: &str, is_table: bool) -> String {
173        if !self.normalize {
174            return name.to_string();
175        }
176
177        // Default normalization: lowercase
178        // Different dialects may have different rules
179        match self.dialect {
180            Some(DialectType::BigQuery) if is_table => {
181                // BigQuery preserves case for tables
182                name.to_string()
183            }
184            Some(DialectType::Snowflake) => {
185                // Snowflake uppercases by default
186                name.to_uppercase()
187            }
188            _ => {
189                // Most dialects lowercase
190                name.to_lowercase()
191            }
192        }
193    }
194
195    /// Parse a qualified table name into parts
196    fn parse_table_parts(&self, table: &str) -> Vec<String> {
197        table
198            .split('.')
199            .map(|s| self.normalize_name(s.trim(), true))
200            .collect()
201    }
202
203    /// Get the column mapping for a table
204    fn find_table(&self, table: &str) -> SchemaResult<&HashMap<String, ColumnInfo>> {
205        let parts = self.parse_table_parts(table);
206
207        // Use trie to find table
208        let reversed_parts: Vec<_> = parts.iter().rev().map(|s| s.as_str()).collect();
209        let key: String = reversed_parts.join(".");
210
211        let (result, _) = self.mapping_trie.in_trie(&key);
212
213        match result {
214            TrieResult::Failed => Err(SchemaError::TableNotFound(table.to_string())),
215            TrieResult::Prefix => {
216                // Ambiguous - multiple tables match
217                Err(SchemaError::AmbiguousTable {
218                    table: table.to_string(),
219                    matches: "multiple matches".to_string(),
220                })
221            }
222            TrieResult::Exists => {
223                // Navigate to the table
224                self.navigate_to_table(&parts)
225            }
226        }
227    }
228
229    /// Navigate the schema tree to find a table's columns
230    fn navigate_to_table(&self, parts: &[String]) -> SchemaResult<&HashMap<String, ColumnInfo>> {
231        let mut current = &self.mapping;
232
233        for (i, part) in parts.iter().enumerate() {
234            match current.get(part) {
235                Some(SchemaNode::Namespace(inner)) => {
236                    current = inner;
237                }
238                Some(SchemaNode::Table(cols)) => {
239                    if i == parts.len() - 1 {
240                        return Ok(cols);
241                    } else {
242                        return Err(SchemaError::InvalidStructure(format!(
243                            "Found table at {} but expected more levels",
244                            parts[..=i].join(".")
245                        )));
246                    }
247                }
248                None => {
249                    return Err(SchemaError::TableNotFound(parts.join(".")));
250                }
251            }
252        }
253
254        // We've exhausted parts but didn't find a table
255        Err(SchemaError::TableNotFound(parts.join(".")))
256    }
257
258    /// Add a table to the schema
259    fn add_table_internal(
260        &mut self,
261        parts: &[String],
262        columns: HashMap<String, ColumnInfo>,
263    ) -> SchemaResult<()> {
264        if parts.is_empty() {
265            return Err(SchemaError::InvalidStructure(
266                "Table name cannot be empty".to_string(),
267            ));
268        }
269
270        // Build trie key (reversed parts)
271        let trie_key: String = parts.iter().rev().cloned().collect::<Vec<_>>().join(".");
272        self.mapping_trie.insert(&trie_key, ());
273
274        // Navigate/create path to table
275        let mut current = &mut self.mapping;
276
277        for (i, part) in parts.iter().enumerate() {
278            let is_last = i == parts.len() - 1;
279
280            if is_last {
281                // Insert table
282                current.insert(part.clone(), SchemaNode::Table(columns));
283                return Ok(());
284            } else {
285                // Navigate or create namespace
286                let entry = current
287                    .entry(part.clone())
288                    .or_insert_with(|| SchemaNode::Namespace(HashMap::new()));
289
290                match entry {
291                    SchemaNode::Namespace(inner) => {
292                        current = inner;
293                    }
294                    SchemaNode::Table(_) => {
295                        return Err(SchemaError::InvalidStructure(format!(
296                            "Expected namespace at {} but found table",
297                            parts[..=i].join(".")
298                        )));
299                    }
300                }
301            }
302        }
303
304        Ok(())
305    }
306
307    /// Update cached depth
308    fn update_depth(&mut self) {
309        self.cached_depth = self.calculate_depth(&self.mapping);
310    }
311
312    fn calculate_depth(&self, mapping: &HashMap<String, SchemaNode>) -> usize {
313        if mapping.is_empty() {
314            return 0;
315        }
316
317        let mut max_depth = 1;
318        for node in mapping.values() {
319            match node {
320                SchemaNode::Namespace(inner) => {
321                    let d = 1 + self.calculate_depth(inner);
322                    if d > max_depth {
323                        max_depth = d;
324                    }
325                }
326                SchemaNode::Table(_) => {
327                    // Tables don't add to depth beyond their level
328                }
329            }
330        }
331        max_depth
332    }
333}
334
335impl Schema for MappingSchema {
336    fn dialect(&self) -> Option<DialectType> {
337        self.dialect
338    }
339
340    fn add_table(
341        &mut self,
342        table: &str,
343        columns: &[(String, DataType)],
344        _dialect: Option<DialectType>,
345    ) -> SchemaResult<()> {
346        let parts = self.parse_table_parts(table);
347
348        let cols: HashMap<String, ColumnInfo> = columns
349            .iter()
350            .map(|(name, dtype)| {
351                let normalized_name = self.normalize_name(name, false);
352                (normalized_name, ColumnInfo::new(dtype.clone()))
353            })
354            .collect();
355
356        self.add_table_internal(&parts, cols)?;
357        self.update_depth();
358        Ok(())
359    }
360
361    fn column_names(&self, table: &str) -> SchemaResult<Vec<String>> {
362        let cols = self.find_table(table)?;
363        let table_key = self.normalize_name(table, true);
364
365        // Check visibility
366        if let Some(visible_cols) = self.visible.get(&table_key) {
367            Ok(cols
368                .keys()
369                .filter(|k| visible_cols.contains(*k))
370                .cloned()
371                .collect())
372        } else {
373            Ok(cols.keys().cloned().collect())
374        }
375    }
376
377    fn get_column_type(&self, table: &str, column: &str) -> SchemaResult<DataType> {
378        let cols = self.find_table(table)?;
379        let normalized_col = self.normalize_name(column, false);
380
381        cols.get(&normalized_col)
382            .map(|info| info.data_type.clone())
383            .ok_or_else(|| SchemaError::ColumnNotFound {
384                table: table.to_string(),
385                column: column.to_string(),
386            })
387    }
388
389    fn has_column(&self, table: &str, column: &str) -> bool {
390        self.get_column_type(table, column).is_ok()
391    }
392
393    fn supported_table_args(&self) -> &[&str] {
394        let depth = self.depth();
395        if depth == 0 {
396            &[]
397        } else if depth <= 3 {
398            &TABLE_PARTS[..depth]
399        } else {
400            TABLE_PARTS
401        }
402    }
403
404    fn is_empty(&self) -> bool {
405        self.mapping.is_empty()
406    }
407
408    fn depth(&self) -> usize {
409        self.cached_depth
410    }
411
412    fn find_tables_for_column(&self, column: &str) -> Vec<String> {
413        let normalized = normalize_name(column, self.dialect, false, self.normalize);
414        let mut result = Vec::new();
415        for table_name in self.mapping.keys() {
416            if self.has_column(table_name, &normalized) {
417                result.push(table_name.clone());
418            }
419        }
420        result
421    }
422}
423
424/// Normalize a table or column name according to dialect rules
425pub fn normalize_name(
426    name: &str,
427    dialect: Option<DialectType>,
428    is_table: bool,
429    normalize: bool,
430) -> String {
431    if !normalize {
432        return name.to_string();
433    }
434
435    match dialect {
436        Some(DialectType::BigQuery) if is_table => name.to_string(),
437        Some(DialectType::Snowflake) => name.to_uppercase(),
438        _ => name.to_lowercase(),
439    }
440}
441
442/// Ensure we have a schema instance
443pub fn ensure_schema(schema: Option<MappingSchema>) -> MappingSchema {
444    schema.unwrap_or_default()
445}
446
447/// Helper to build a schema from a simple map
448///
449/// # Example
450///
451/// ```
452/// use polyglot_sql::schema::{MappingSchema, Schema, from_simple_map};
453/// use polyglot_sql::expressions::DataType;
454///
455/// let schema = from_simple_map(&[
456///     ("users", &[("id", DataType::Int { length: None, integer_spelling: false }), ("name", DataType::VarChar { length: Some(255), parenthesized_length: false })]),
457///     ("orders", &[("id", DataType::Int { length: None, integer_spelling: false }), ("user_id", DataType::Int { length: None, integer_spelling: false })]),
458/// ]);
459///
460/// assert_eq!(schema.column_names("users").unwrap().len(), 2);
461/// ```
462pub fn from_simple_map(tables: &[(&str, &[(&str, DataType)])]) -> MappingSchema {
463    let mut schema = MappingSchema::new();
464
465    for (table_name, columns) in tables {
466        let cols: Vec<(String, DataType)> = columns
467            .iter()
468            .map(|(name, dtype)| (name.to_string(), dtype.clone()))
469            .collect();
470
471        schema.add_table(table_name, &cols, None).ok();
472    }
473
474    schema
475}
476
477/// Flatten a nested schema to get all table paths
478pub fn flatten_schema_paths(schema: &MappingSchema) -> Vec<Vec<String>> {
479    let mut paths = Vec::new();
480    flatten_schema_paths_recursive(&schema.mapping, Vec::new(), &mut paths);
481    paths
482}
483
484fn flatten_schema_paths_recursive(
485    mapping: &HashMap<String, SchemaNode>,
486    prefix: Vec<String>,
487    paths: &mut Vec<Vec<String>>,
488) {
489    for (key, node) in mapping {
490        let mut path = prefix.clone();
491        path.push(key.clone());
492
493        match node {
494            SchemaNode::Namespace(inner) => {
495                flatten_schema_paths_recursive(inner, path, paths);
496            }
497            SchemaNode::Table(_) => {
498                paths.push(path);
499            }
500        }
501    }
502}
503
/// Set a value in a two-level nested map.
///
/// `keys[0]` selects (or creates) the outer entry and `keys[1]` is the key the
/// value is stored under inside it, replacing any previous value. Keys beyond
/// the second are ignored. With fewer than two keys the map is left untouched
/// and `value` is dropped.
pub fn nested_set<V>(
    map: &mut HashMap<String, HashMap<String, V>>,
    keys: &[String],
    value: V,
) {
    // The slice pattern covers both "too few keys" cases (empty and single).
    if let [outer_key, inner_key, ..] = keys {
        map.entry(outer_key.clone())
            .or_default()
            .insert(inner_key.clone(), value);
    }
}
526
/// Look up a value in a two-level nested map.
///
/// Returns `None` unless `keys` holds exactly two entries and both levels
/// contain the corresponding key.
pub fn nested_get<'a, V>(
    map: &'a HashMap<String, HashMap<String, V>>,
    keys: &[String],
) -> Option<&'a V> {
    match keys {
        [outer, inner] => map.get(outer).and_then(|inner_map| inner_map.get(inner)),
        _ => None,
    }
}
538
#[cfg(test)]
mod tests {
    use super::*;

    /// Plain `INT` type used by most fixtures.
    fn int_type() -> DataType {
        DataType::Int {
            length: None,
            integer_spelling: false,
        }
    }

    /// `VARCHAR` without a length.
    fn varchar_type() -> DataType {
        DataType::VarChar {
            length: None,
            parenthesized_length: false,
        }
    }

    /// `VARCHAR(255)`.
    fn varchar_255() -> DataType {
        DataType::VarChar {
            length: Some(255),
            parenthesized_length: false,
        }
    }

    /// Shorthand for an owned `(name, type)` column entry.
    fn col(name: &str, data_type: DataType) -> (String, DataType) {
        (name.to_string(), data_type)
    }

    #[test]
    fn test_empty_schema() {
        let schema = MappingSchema::new();
        assert!(schema.is_empty());
        assert_eq!(schema.depth(), 0);
    }

    #[test]
    fn test_add_table() {
        let mut schema = MappingSchema::new();
        let columns = vec![col("id", int_type()), col("name", varchar_255())];

        schema.add_table("users", &columns, None).unwrap();

        assert!(!schema.is_empty());
        assert_eq!(schema.depth(), 1);
        assert!(schema.has_column("users", "id"));
        assert!(schema.has_column("users", "name"));
        assert!(!schema.has_column("users", "email"));
    }

    #[test]
    fn test_qualified_table_names() {
        let mut schema = MappingSchema::new();
        let columns = vec![col("id", int_type())];

        schema.add_table("mydb.users", &columns, None).unwrap();

        assert!(schema.has_column("mydb.users", "id"));
        assert_eq!(schema.depth(), 2);
    }

    #[test]
    fn test_catalog_db_table() {
        let mut schema = MappingSchema::new();
        let columns = vec![col("id", int_type())];

        schema
            .add_table("catalog.mydb.users", &columns, None)
            .unwrap();

        assert!(schema.has_column("catalog.mydb.users", "id"));
        assert_eq!(schema.depth(), 3);
    }

    #[test]
    fn test_get_column_type() {
        let mut schema = MappingSchema::new();
        let columns = vec![col("id", int_type()), col("name", varchar_255())];

        schema.add_table("users", &columns, None).unwrap();

        let id_type = schema.get_column_type("users", "id").unwrap();
        assert!(matches!(id_type, DataType::Int { .. }));

        let name_type = schema.get_column_type("users", "name").unwrap();
        assert!(matches!(
            name_type,
            DataType::VarChar {
                length: Some(255),
                parenthesized_length: false
            }
        ));
    }

    #[test]
    fn test_column_names() {
        let mut schema = MappingSchema::new();
        let columns = vec![col("id", int_type()), col("name", varchar_type())];

        schema.add_table("users", &columns, None).unwrap();

        let names = schema.column_names("users").unwrap();
        assert_eq!(names.len(), 2);
        assert!(names.contains(&"id".to_string()));
        assert!(names.contains(&"name".to_string()));
    }

    #[test]
    fn test_table_not_found() {
        let schema = MappingSchema::new();
        let result = schema.column_names("nonexistent");
        assert!(matches!(result, Err(SchemaError::TableNotFound(_))));
    }

    #[test]
    fn test_column_not_found() {
        let mut schema = MappingSchema::new();
        let columns = vec![col("id", int_type())];
        schema.add_table("users", &columns, None).unwrap();

        let result = schema.get_column_type("users", "nonexistent");
        assert!(matches!(result, Err(SchemaError::ColumnNotFound { .. })));
    }

    #[test]
    fn test_normalize_name_default() {
        let name = normalize_name("MyTable", None, true, true);
        assert_eq!(name, "mytable");
    }

    #[test]
    fn test_normalize_name_snowflake() {
        let name = normalize_name("MyTable", Some(DialectType::Snowflake), true, true);
        assert_eq!(name, "MYTABLE");
    }

    #[test]
    fn test_normalize_disabled() {
        let name = normalize_name("MyTable", None, true, false);
        assert_eq!(name, "MyTable");
    }

    #[test]
    fn test_from_simple_map() {
        let schema = from_simple_map(&[
            ("users", &[("id", int_type()), ("name", varchar_type())]),
            ("orders", &[("id", int_type()), ("user_id", int_type())]),
        ]);

        assert!(schema.has_column("users", "id"));
        assert!(schema.has_column("users", "name"));
        assert!(schema.has_column("orders", "id"));
        assert!(schema.has_column("orders", "user_id"));
    }

    #[test]
    fn test_flatten_schema_paths() {
        let mut schema = MappingSchema::new();
        for table in ["db1.table1", "db1.table2", "db2.table1"] {
            schema
                .add_table(table, &[col("id", int_type())], None)
                .unwrap();
        }

        let paths = flatten_schema_paths(&schema);
        assert_eq!(paths.len(), 3);
    }

    #[test]
    fn test_visible_columns() {
        let mut schema = MappingSchema::new();
        let columns = vec![
            col("id", int_type()),
            col("name", varchar_type()),
            col("password", varchar_type()),
        ];
        schema.add_table("users", &columns, None).unwrap();
        schema.set_visible_columns("users", &["id", "name"]);

        let names = schema.column_names("users").unwrap();
        assert_eq!(names.len(), 2);
        assert!(names.contains(&"id".to_string()));
        assert!(names.contains(&"name".to_string()));
        assert!(!names.contains(&"password".to_string()));
    }
}