Skip to main content

polyglot_sql/
schema.rs

//! Schema management for SQL queries
//!
//! This module provides functionality for:
//! - Representing database schemas (tables, columns, types)
//! - Looking up column types for query optimization
//! - Normalizing identifiers per dialect
//!
//! Based on the Python implementation in `sqlglot/schema.py`.
9
10use crate::dialects::DialectType;
11use crate::expressions::DataType;
12use crate::trie::{Trie, TrieResult};
13use std::collections::{HashMap, HashSet};
14use thiserror::Error;
15
16/// Errors that can occur during schema operations
17#[derive(Debug, Error, Clone)]
18pub enum SchemaError {
19    #[error("Table not found: {0}")]
20    TableNotFound(String),
21
22    #[error("Ambiguous table: {table} matches multiple tables: {matches}")]
23    AmbiguousTable { table: String, matches: String },
24
25    #[error("Column not found: {column} in table {table}")]
26    ColumnNotFound { table: String, column: String },
27
28    #[error("Schema nesting depth mismatch: expected {expected}, got {actual}")]
29    DepthMismatch { expected: usize, actual: usize },
30
31    #[error("Invalid schema structure: {0}")]
32    InvalidStructure(String),
33}
34
35/// Result type for schema operations
36pub type SchemaResult<T> = Result<T, SchemaError>;
37
38/// Supported table argument names
39pub const TABLE_PARTS: &[&str] = &["this", "db", "catalog"];
40
41/// Abstract trait for database schemas
42pub trait Schema {
43    /// Get the dialect associated with this schema (if any)
44    fn dialect(&self) -> Option<DialectType>;
45
46    /// Add or update a table in the schema
47    fn add_table(
48        &mut self,
49        table: &str,
50        columns: &[(String, DataType)],
51        dialect: Option<DialectType>,
52    ) -> SchemaResult<()>;
53
54    /// Get column names for a table
55    fn column_names(&self, table: &str) -> SchemaResult<Vec<String>>;
56
57    /// Get the type of a column in a table
58    fn get_column_type(&self, table: &str, column: &str) -> SchemaResult<DataType>;
59
60    /// Check if a column exists in a table
61    fn has_column(&self, table: &str, column: &str) -> bool;
62
63    /// Get supported table argument levels
64    fn supported_table_args(&self) -> &[&str];
65
66    /// Check if the schema is empty
67    fn is_empty(&self) -> bool;
68
69    /// Get the nesting depth of the schema
70    fn depth(&self) -> usize;
71}
72
73/// A column with its type and visibility
74#[derive(Debug, Clone)]
75pub struct ColumnInfo {
76    pub data_type: DataType,
77    pub visible: bool,
78}
79
80impl ColumnInfo {
81    pub fn new(data_type: DataType) -> Self {
82        Self {
83            data_type,
84            visible: true,
85        }
86    }
87
88    pub fn with_visibility(data_type: DataType, visible: bool) -> Self {
89        Self { data_type, visible }
90    }
91}
92
93/// A mapping-based schema implementation
94///
95/// Supports nested schemas with different levels:
96/// - Level 1: `{table: {col: type}}`
97/// - Level 2: `{db: {table: {col: type}}}`
98/// - Level 3: `{catalog: {db: {table: {col: type}}}}`
99#[derive(Debug, Clone)]
100pub struct MappingSchema {
101    /// The actual schema data
102    mapping: HashMap<String, SchemaNode>,
103    /// Trie for efficient table lookup
104    mapping_trie: Trie<()>,
105    /// The dialect for this schema
106    dialect: Option<DialectType>,
107    /// Whether to normalize identifiers
108    normalize: bool,
109    /// Visible columns per table
110    visible: HashMap<String, HashSet<String>>,
111    /// Cached depth
112    cached_depth: usize,
113}
114
115/// A node in the schema tree
116#[derive(Debug, Clone)]
117pub enum SchemaNode {
118    /// Intermediate node (database or catalog)
119    Namespace(HashMap<String, SchemaNode>),
120    /// Leaf node (table with columns)
121    Table(HashMap<String, ColumnInfo>),
122}
123
124impl Default for MappingSchema {
125    fn default() -> Self {
126        Self::new()
127    }
128}
129
130impl MappingSchema {
131    /// Create a new empty schema
132    pub fn new() -> Self {
133        Self {
134            mapping: HashMap::new(),
135            mapping_trie: Trie::new(),
136            dialect: None,
137            normalize: true,
138            visible: HashMap::new(),
139            cached_depth: 0,
140        }
141    }
142
143    /// Create a schema with a specific dialect
144    pub fn with_dialect(dialect: DialectType) -> Self {
145        Self {
146            dialect: Some(dialect),
147            ..Self::new()
148        }
149    }
150
151    /// Create a schema with normalization disabled
152    pub fn without_normalization(mut self) -> Self {
153        self.normalize = false;
154        self
155    }
156
157    /// Set visibility for columns in a table
158    pub fn set_visible_columns(&mut self, table: &str, columns: &[&str]) {
159        let key = self.normalize_name(table, true);
160        let cols: HashSet<String> = columns
161            .iter()
162            .map(|c| self.normalize_name(c, false))
163            .collect();
164        self.visible.insert(key, cols);
165    }
166
167    /// Normalize an identifier name based on dialect
168    fn normalize_name(&self, name: &str, is_table: bool) -> String {
169        if !self.normalize {
170            return name.to_string();
171        }
172
173        // Default normalization: lowercase
174        // Different dialects may have different rules
175        match self.dialect {
176            Some(DialectType::BigQuery) if is_table => {
177                // BigQuery preserves case for tables
178                name.to_string()
179            }
180            Some(DialectType::Snowflake) => {
181                // Snowflake uppercases by default
182                name.to_uppercase()
183            }
184            _ => {
185                // Most dialects lowercase
186                name.to_lowercase()
187            }
188        }
189    }
190
191    /// Parse a qualified table name into parts
192    fn parse_table_parts(&self, table: &str) -> Vec<String> {
193        table
194            .split('.')
195            .map(|s| self.normalize_name(s.trim(), true))
196            .collect()
197    }
198
199    /// Get the column mapping for a table
200    fn find_table(&self, table: &str) -> SchemaResult<&HashMap<String, ColumnInfo>> {
201        let parts = self.parse_table_parts(table);
202
203        // Use trie to find table
204        let reversed_parts: Vec<_> = parts.iter().rev().map(|s| s.as_str()).collect();
205        let key: String = reversed_parts.join(".");
206
207        let (result, _) = self.mapping_trie.in_trie(&key);
208
209        match result {
210            TrieResult::Failed => Err(SchemaError::TableNotFound(table.to_string())),
211            TrieResult::Prefix => {
212                // Ambiguous - multiple tables match
213                Err(SchemaError::AmbiguousTable {
214                    table: table.to_string(),
215                    matches: "multiple matches".to_string(),
216                })
217            }
218            TrieResult::Exists => {
219                // Navigate to the table
220                self.navigate_to_table(&parts)
221            }
222        }
223    }
224
225    /// Navigate the schema tree to find a table's columns
226    fn navigate_to_table(&self, parts: &[String]) -> SchemaResult<&HashMap<String, ColumnInfo>> {
227        let mut current = &self.mapping;
228
229        for (i, part) in parts.iter().enumerate() {
230            match current.get(part) {
231                Some(SchemaNode::Namespace(inner)) => {
232                    current = inner;
233                }
234                Some(SchemaNode::Table(cols)) => {
235                    if i == parts.len() - 1 {
236                        return Ok(cols);
237                    } else {
238                        return Err(SchemaError::InvalidStructure(format!(
239                            "Found table at {} but expected more levels",
240                            parts[..=i].join(".")
241                        )));
242                    }
243                }
244                None => {
245                    return Err(SchemaError::TableNotFound(parts.join(".")));
246                }
247            }
248        }
249
250        // We've exhausted parts but didn't find a table
251        Err(SchemaError::TableNotFound(parts.join(".")))
252    }
253
254    /// Add a table to the schema
255    fn add_table_internal(
256        &mut self,
257        parts: &[String],
258        columns: HashMap<String, ColumnInfo>,
259    ) -> SchemaResult<()> {
260        if parts.is_empty() {
261            return Err(SchemaError::InvalidStructure(
262                "Table name cannot be empty".to_string(),
263            ));
264        }
265
266        // Build trie key (reversed parts)
267        let trie_key: String = parts.iter().rev().cloned().collect::<Vec<_>>().join(".");
268        self.mapping_trie.insert(&trie_key, ());
269
270        // Navigate/create path to table
271        let mut current = &mut self.mapping;
272
273        for (i, part) in parts.iter().enumerate() {
274            let is_last = i == parts.len() - 1;
275
276            if is_last {
277                // Insert table
278                current.insert(part.clone(), SchemaNode::Table(columns));
279                return Ok(());
280            } else {
281                // Navigate or create namespace
282                let entry = current
283                    .entry(part.clone())
284                    .or_insert_with(|| SchemaNode::Namespace(HashMap::new()));
285
286                match entry {
287                    SchemaNode::Namespace(inner) => {
288                        current = inner;
289                    }
290                    SchemaNode::Table(_) => {
291                        return Err(SchemaError::InvalidStructure(format!(
292                            "Expected namespace at {} but found table",
293                            parts[..=i].join(".")
294                        )));
295                    }
296                }
297            }
298        }
299
300        Ok(())
301    }
302
303    /// Update cached depth
304    fn update_depth(&mut self) {
305        self.cached_depth = self.calculate_depth(&self.mapping);
306    }
307
308    fn calculate_depth(&self, mapping: &HashMap<String, SchemaNode>) -> usize {
309        if mapping.is_empty() {
310            return 0;
311        }
312
313        let mut max_depth = 1;
314        for node in mapping.values() {
315            match node {
316                SchemaNode::Namespace(inner) => {
317                    let d = 1 + self.calculate_depth(inner);
318                    if d > max_depth {
319                        max_depth = d;
320                    }
321                }
322                SchemaNode::Table(_) => {
323                    // Tables don't add to depth beyond their level
324                }
325            }
326        }
327        max_depth
328    }
329}
330
331impl Schema for MappingSchema {
332    fn dialect(&self) -> Option<DialectType> {
333        self.dialect
334    }
335
336    fn add_table(
337        &mut self,
338        table: &str,
339        columns: &[(String, DataType)],
340        _dialect: Option<DialectType>,
341    ) -> SchemaResult<()> {
342        let parts = self.parse_table_parts(table);
343
344        let cols: HashMap<String, ColumnInfo> = columns
345            .iter()
346            .map(|(name, dtype)| {
347                let normalized_name = self.normalize_name(name, false);
348                (normalized_name, ColumnInfo::new(dtype.clone()))
349            })
350            .collect();
351
352        self.add_table_internal(&parts, cols)?;
353        self.update_depth();
354        Ok(())
355    }
356
357    fn column_names(&self, table: &str) -> SchemaResult<Vec<String>> {
358        let cols = self.find_table(table)?;
359        let table_key = self.normalize_name(table, true);
360
361        // Check visibility
362        if let Some(visible_cols) = self.visible.get(&table_key) {
363            Ok(cols
364                .keys()
365                .filter(|k| visible_cols.contains(*k))
366                .cloned()
367                .collect())
368        } else {
369            Ok(cols.keys().cloned().collect())
370        }
371    }
372
373    fn get_column_type(&self, table: &str, column: &str) -> SchemaResult<DataType> {
374        let cols = self.find_table(table)?;
375        let normalized_col = self.normalize_name(column, false);
376
377        cols.get(&normalized_col)
378            .map(|info| info.data_type.clone())
379            .ok_or_else(|| SchemaError::ColumnNotFound {
380                table: table.to_string(),
381                column: column.to_string(),
382            })
383    }
384
385    fn has_column(&self, table: &str, column: &str) -> bool {
386        self.get_column_type(table, column).is_ok()
387    }
388
389    fn supported_table_args(&self) -> &[&str] {
390        let depth = self.depth();
391        if depth == 0 {
392            &[]
393        } else if depth <= 3 {
394            &TABLE_PARTS[..depth]
395        } else {
396            TABLE_PARTS
397        }
398    }
399
400    fn is_empty(&self) -> bool {
401        self.mapping.is_empty()
402    }
403
404    fn depth(&self) -> usize {
405        self.cached_depth
406    }
407}
408
409/// Normalize a table or column name according to dialect rules
410pub fn normalize_name(
411    name: &str,
412    dialect: Option<DialectType>,
413    is_table: bool,
414    normalize: bool,
415) -> String {
416    if !normalize {
417        return name.to_string();
418    }
419
420    match dialect {
421        Some(DialectType::BigQuery) if is_table => name.to_string(),
422        Some(DialectType::Snowflake) => name.to_uppercase(),
423        _ => name.to_lowercase(),
424    }
425}
426
427/// Ensure we have a schema instance
428pub fn ensure_schema(schema: Option<MappingSchema>) -> MappingSchema {
429    schema.unwrap_or_default()
430}
431
432/// Helper to build a schema from a simple map
433///
434/// # Example
435///
436/// ```
437/// use polyglot_sql::schema::{MappingSchema, Schema, from_simple_map};
438/// use polyglot_sql::expressions::DataType;
439///
440/// let schema = from_simple_map(&[
441///     ("users", &[("id", DataType::Int { length: None, integer_spelling: false }), ("name", DataType::VarChar { length: Some(255), parenthesized_length: false })]),
442///     ("orders", &[("id", DataType::Int { length: None, integer_spelling: false }), ("user_id", DataType::Int { length: None, integer_spelling: false })]),
443/// ]);
444///
445/// assert_eq!(schema.column_names("users").unwrap().len(), 2);
446/// ```
447pub fn from_simple_map(tables: &[(&str, &[(&str, DataType)])]) -> MappingSchema {
448    let mut schema = MappingSchema::new();
449
450    for (table_name, columns) in tables {
451        let cols: Vec<(String, DataType)> = columns
452            .iter()
453            .map(|(name, dtype)| (name.to_string(), dtype.clone()))
454            .collect();
455
456        schema.add_table(table_name, &cols, None).ok();
457    }
458
459    schema
460}
461
462/// Flatten a nested schema to get all table paths
463pub fn flatten_schema_paths(schema: &MappingSchema) -> Vec<Vec<String>> {
464    let mut paths = Vec::new();
465    flatten_schema_paths_recursive(&schema.mapping, Vec::new(), &mut paths);
466    paths
467}
468
469fn flatten_schema_paths_recursive(
470    mapping: &HashMap<String, SchemaNode>,
471    prefix: Vec<String>,
472    paths: &mut Vec<Vec<String>>,
473) {
474    for (key, node) in mapping {
475        let mut path = prefix.clone();
476        path.push(key.clone());
477
478        match node {
479            SchemaNode::Namespace(inner) => {
480                flatten_schema_paths_recursive(inner, path, paths);
481            }
482            SchemaNode::Table(_) => {
483                paths.push(path);
484            }
485        }
486    }
487}
488
/// Insert `value` at `map[keys[0]][keys[1]]`, creating the inner map if needed.
///
/// Does nothing when fewer than two keys are given. Only the first two keys are
/// used; any further keys are ignored.
///
/// NOTE(review): `nested_get` returns `None` unless it is given exactly two keys,
/// so a value stored here with three or more keys cannot be read back through it.
pub fn nested_set<V>(
    map: &mut HashMap<String, HashMap<String, V>>,
    keys: &[String],
    value: V,
) {
    // `value` is moved into the map, so no `Clone` bound is required.
    if let [outer_key, inner_key, ..] = keys {
        map.entry(outer_key.clone())
            .or_default()
            .insert(inner_key.clone(), value);
    }
}
511
/// Look up `map[keys[0]][keys[1]]`.
///
/// Returns `None` unless exactly two keys are given, or when either level is missing.
pub fn nested_get<'a, V>(
    map: &'a HashMap<String, HashMap<String, V>>,
    keys: &[String],
) -> Option<&'a V> {
    match keys {
        [outer, inner] => map.get(outer).and_then(|level| level.get(inner)),
        _ => None,
    }
}
523
#[cfg(test)]
mod tests {
    use super::*;

    /// `INT` with no explicit length.
    fn int() -> DataType {
        DataType::Int {
            length: None,
            integer_spelling: false,
        }
    }

    /// `VARCHAR` with no explicit length.
    fn varchar() -> DataType {
        DataType::VarChar {
            length: None,
            parenthesized_length: false,
        }
    }

    /// `VARCHAR(255)`.
    fn varchar_255() -> DataType {
        DataType::VarChar {
            length: Some(255),
            parenthesized_length: false,
        }
    }

    /// Convert borrowed `(name, type)` pairs into the owned form `add_table` expects.
    fn columns(spec: &[(&str, DataType)]) -> Vec<(String, DataType)> {
        spec.iter()
            .map(|(name, ty)| (name.to_string(), ty.clone()))
            .collect()
    }

    #[test]
    fn test_empty_schema() {
        let empty = MappingSchema::new();
        assert!(empty.is_empty());
        assert_eq!(empty.depth(), 0);
    }

    #[test]
    fn test_add_table() {
        let mut schema = MappingSchema::new();
        schema
            .add_table("users", &columns(&[("id", int()), ("name", varchar_255())]), None)
            .unwrap();

        assert!(!schema.is_empty());
        assert_eq!(schema.depth(), 1);
        for present in ["id", "name"].iter() {
            assert!(schema.has_column("users", present));
        }
        assert!(!schema.has_column("users", "email"));
    }

    #[test]
    fn test_qualified_table_names() {
        let mut schema = MappingSchema::new();
        schema
            .add_table("mydb.users", &columns(&[("id", int())]), None)
            .unwrap();

        assert!(schema.has_column("mydb.users", "id"));
        assert_eq!(schema.depth(), 2);
    }

    #[test]
    fn test_catalog_db_table() {
        let mut schema = MappingSchema::new();
        schema
            .add_table("catalog.mydb.users", &columns(&[("id", int())]), None)
            .unwrap();

        assert!(schema.has_column("catalog.mydb.users", "id"));
        assert_eq!(schema.depth(), 3);
    }

    #[test]
    fn test_get_column_type() {
        let mut schema = MappingSchema::new();
        schema
            .add_table("users", &columns(&[("id", int()), ("name", varchar_255())]), None)
            .unwrap();

        let id_type = schema.get_column_type("users", "id").unwrap();
        assert!(matches!(id_type, DataType::Int { .. }));

        let name_type = schema.get_column_type("users", "name").unwrap();
        assert!(matches!(
            name_type,
            DataType::VarChar { length: Some(255), parenthesized_length: false }
        ));
    }

    #[test]
    fn test_column_names() {
        let mut schema = MappingSchema::new();
        schema
            .add_table("users", &columns(&[("id", int()), ("name", varchar())]), None)
            .unwrap();

        let names = schema.column_names("users").unwrap();
        assert_eq!(names.len(), 2);
        assert!(names.contains(&"id".to_string()));
        assert!(names.contains(&"name".to_string()));
    }

    #[test]
    fn test_table_not_found() {
        let schema = MappingSchema::new();
        assert!(matches!(
            schema.column_names("nonexistent"),
            Err(SchemaError::TableNotFound(_))
        ));
    }

    #[test]
    fn test_column_not_found() {
        let mut schema = MappingSchema::new();
        schema
            .add_table("users", &columns(&[("id", int())]), None)
            .unwrap();

        assert!(matches!(
            schema.get_column_type("users", "nonexistent"),
            Err(SchemaError::ColumnNotFound { .. })
        ));
    }

    #[test]
    fn test_normalize_name_default() {
        assert_eq!(normalize_name("MyTable", None, true, true), "mytable");
    }

    #[test]
    fn test_normalize_name_snowflake() {
        assert_eq!(
            normalize_name("MyTable", Some(DialectType::Snowflake), true, true),
            "MYTABLE"
        );
    }

    #[test]
    fn test_normalize_disabled() {
        assert_eq!(normalize_name("MyTable", None, true, false), "MyTable");
    }

    #[test]
    fn test_from_simple_map() {
        let schema = from_simple_map(&[
            ("users", &[("id", int()), ("name", varchar())]),
            ("orders", &[("id", int()), ("user_id", int())]),
        ]);

        let expected = [
            ("users", "id"),
            ("users", "name"),
            ("orders", "id"),
            ("orders", "user_id"),
        ];
        for (table, column) in expected.iter() {
            assert!(schema.has_column(table, column));
        }
    }

    #[test]
    fn test_flatten_schema_paths() {
        let mut schema = MappingSchema::new();
        for table in ["db1.table1", "db1.table2", "db2.table1"].iter() {
            schema
                .add_table(table, &columns(&[("id", int())]), None)
                .unwrap();
        }

        assert_eq!(flatten_schema_paths(&schema).len(), 3);
    }

    #[test]
    fn test_visible_columns() {
        let mut schema = MappingSchema::new();
        let spec = columns(&[("id", int()), ("name", varchar()), ("password", varchar())]);
        schema.add_table("users", &spec, None).unwrap();
        schema.set_visible_columns("users", &["id", "name"]);

        let names = schema.column_names("users").unwrap();
        assert_eq!(names.len(), 2);
        assert!(names.contains(&"id".to_string()));
        assert!(names.contains(&"name".to_string()));
        assert!(!names.contains(&"password".to_string()));
    }
}
709}