1use super::{ColumnData, ImportError, ImportResult, TableData};
15use crate::validation::input::{validate_column_name, validate_data_type, validate_table_name};
16use anyhow::Result;
17use sqlparser::ast::{ColumnDef, ColumnOption, ObjectName, Statement, TableConstraint};
18use sqlparser::dialect::{Dialect, GenericDialect, MySqlDialect, PostgreSqlDialect, SQLiteDialect};
19use sqlparser::parser::Parser;
20use std::collections::HashMap;
21
/// Custom sqlparser dialect for Databricks SQL.
///
/// Differs from the generic dialect by allowing ':' in identifiers (so
/// parameter markers such as `:var` survive tokenization) and by accepting
/// backtick- and bracket-delimited identifiers.
#[derive(Debug)]
struct DatabricksDialect;
29
30impl Dialect for DatabricksDialect {
31 fn is_identifier_start(&self, ch: char) -> bool {
32 ch.is_alphabetic() || ch == '_' || ch == ':'
34 }
35
36 fn is_identifier_part(&self, ch: char) -> bool {
37 ch.is_alphanumeric() || ch == '_' || ch == ':'
39 }
40
41 fn is_delimited_identifier_start(&self, ch: char) -> bool {
42 ch == '"' || ch == '`' || ch == '['
44 }
45}
46
/// Bookkeeping produced while preprocessing Databricks SQL, used to map
/// parser-friendly placeholders back to the constructs they replaced.
#[derive(Debug)]
struct PreprocessingState {
    // Maps generated placeholder table names (e.g. `__databricks_table_1__`)
    // back to the original `IDENTIFIER(...)` argument text.
    identifier_replacements: HashMap<String, String>,
    // Reserved for future variable substitutions; currently never read.
    #[allow(dead_code)] variable_replacements: Vec<(String, String)>,
}
56
57impl PreprocessingState {
58 fn new() -> Self {
59 Self {
60 identifier_replacements: HashMap::new(),
61 variable_replacements: Vec::new(),
62 }
63 }
64}
65
/// Imports table and view definitions from SQL DDL text.
pub struct SQLImporter {
    // Dialect name; recognized values are "postgres"/"postgresql", "mysql",
    // "sqlite", and "databricks" — anything else falls back to generic.
    pub dialect: String,
}
71
72impl Default for SQLImporter {
73 fn default() -> Self {
74 Self {
75 dialect: "generic".to_string(),
76 }
77 }
78}
79
80impl SQLImporter {
81 pub fn new(dialect: &str) -> Self {
111 Self {
112 dialect: dialect.to_string(),
113 }
114 }
115
116 fn preprocess_identifier_expressions(sql: &str, state: &mut PreprocessingState) -> String {
121 use regex::Regex;
122
123 let re = Regex::new(r"(?i)IDENTIFIER\s*\(\s*([^)]+)\s*\)").unwrap();
125 let mut counter = 0;
126
127 re.replace_all(sql, |caps: ®ex::Captures| {
128 let expr = caps.get(1).map(|m| m.as_str()).unwrap_or("");
129 counter += 1;
130 let placeholder = format!("__databricks_table_{}__", counter);
131
132 state
134 .identifier_replacements
135 .insert(placeholder.clone(), expr.to_string());
136
137 placeholder
138 })
139 .to_string()
140 }
141
142 fn extract_identifier_table_name(expr: &str) -> Option<String> {
147 use regex::Regex;
148
149 let literal_re = Regex::new(r#"(?:'([^']*)'|"([^"]*)")"#).unwrap();
151 let mut parts = Vec::new();
152
153 for cap in literal_re.captures_iter(expr) {
155 if let Some(m) = cap.get(1) {
156 parts.push(m.as_str().to_string());
157 } else if let Some(m) = cap.get(2) {
158 parts.push(m.as_str().to_string());
159 }
160 }
161
162 if parts.is_empty() {
163 return None;
165 }
166
167 let result = parts.join("");
169 Some(result.trim_matches('.').to_string())
170 }
171
172 #[allow(dead_code)] fn handle_identifier_variables(placeholder: &str, _expr: &str) -> String {
177 placeholder.to_string()
179 }
180
181 fn preprocess_materialized_views(sql: &str) -> String {
186 use regex::Regex;
187
188 let re = Regex::new(r"(?i)CREATE\s+MATERIALIZED\s+VIEW").unwrap();
190 re.replace_all(sql, "CREATE VIEW").to_string()
191 }
192
193 fn replace_variables_in_struct_types(sql: &str) -> String {
197 use regex::Regex;
198
199 let re = Regex::new(r":\s*:([a-zA-Z_][a-zA-Z0-9_]*)").unwrap();
203
204 re.replace_all(sql, |_caps: ®ex::Captures| {
205 ": STRING".to_string()
207 })
208 .to_string()
209 }
210
211 fn replace_variables_in_array_types(sql: &str) -> String {
215 use regex::Regex;
216
217 let re = Regex::new(r"ARRAY\s*<\s*:([a-zA-Z_][a-zA-Z0-9_]*)\s*>").unwrap();
221
222 re.replace_all(sql, |_caps: ®ex::Captures| "ARRAY<STRING>".to_string())
223 .to_string()
224 }
225
226 fn replace_variables_in_column_definitions(sql: &str) -> String {
231 use regex::Regex;
232
233 let re = Regex::new(r"(\w+)\s+:\w+\s+([A-Z][A-Z0-9_]*(?:<[^>]*>)?)").unwrap();
237
238 re.replace_all(sql, |caps: ®ex::Captures| {
239 let col_name = caps.get(1).map(|m| m.as_str()).unwrap_or("");
240 let type_name = caps.get(2).map(|m| m.as_str()).unwrap_or("");
241 format!("{} {}", col_name, type_name)
242 })
243 .to_string()
244 }
245
246 fn replace_nested_variables(sql: &str) -> String {
251 let mut result = sql.to_string();
252 let mut changed = true;
253 let mut iterations = 0;
254 const MAX_ITERATIONS: usize = 10; while changed && iterations < MAX_ITERATIONS {
258 let before = result.clone();
259
260 result = Self::replace_variables_in_struct_types(&result);
262
263 result = Self::replace_variables_in_array_types(&result);
265
266 changed = before != result;
268 iterations += 1;
269 }
270
271 result
272 }
273
274 fn extract_complex_type_columns(sql: &str) -> (String, Vec<(String, String)>) {
282 use regex::Regex;
283
284 let mut column_types = Vec::new();
285 let mut result = sql.to_string();
286
287 let re = Regex::new(r"(\w+)\s+(STRUCT<|ARRAY<)").unwrap();
290
291 let mut matches_to_replace: Vec<(usize, usize, String, String)> = Vec::new();
293
294 for cap in re.captures_iter(sql) {
295 let col_name = cap.get(1).map(|m| m.as_str()).unwrap_or("");
296 let type_start = cap.get(0).map(|m| m.start()).unwrap_or(0);
297 let struct_or_array = cap.get(2).map(|m| m.as_str()).unwrap_or("");
298
299 let bracket_start = type_start + col_name.len() + 1 + struct_or_array.len() - 1; let mut bracket_count = 0;
303 let mut type_end = bracket_start;
304
305 for (idx, ch) in sql[bracket_start..].char_indices() {
306 let pos = bracket_start + idx;
307 if ch == '<' {
308 bracket_count += 1;
309 } else if ch == '>' {
310 bracket_count -= 1;
311 if bracket_count == 0 {
312 type_end = pos + 1;
313 break;
314 }
315 }
316 }
317
318 if bracket_count == 0 && type_end > type_start {
319 let type_start_pos = type_start + col_name.len() + 1;
322 let full_type = sql[type_start_pos..type_end].trim().to_string();
323 matches_to_replace.push((
324 type_start_pos,
325 type_end,
326 col_name.to_string(),
327 full_type,
328 ));
329 }
330 }
331
332 for (start, end, col_name, full_type) in matches_to_replace.iter().rev() {
334 column_types.push((col_name.clone(), full_type.clone()));
335 result.replace_range(*start..*end, "STRING");
336 }
337
338 (result, column_types)
339 }
340
    /// Parses SQL DDL text into table/view definitions.
    ///
    /// For the "databricks" dialect the SQL is first preprocessed so that
    /// sqlparser can accept it; for every other dialect the text is parsed
    /// as-is. Parse failures are reported inside `ImportResult::errors`
    /// rather than as an `Err`, so callers always receive a result.
    pub fn parse(&self, sql: &str) -> Result<ImportResult> {
        let (preprocessed_sql, preprocessing_state, complex_types) = if self.dialect.to_lowercase()
            == "databricks"
        {
            let mut state = PreprocessingState::new();
            // Order matters: rewrite materialized views, replace IDENTIFIER()
            // calls (recording originals in `state`), strip variable markers
            // from column definitions, then from nested STRUCT/ARRAY types.
            let mut preprocessed = Self::preprocess_materialized_views(sql);
            preprocessed = Self::preprocess_identifier_expressions(&preprocessed, &mut state);
            preprocessed = Self::replace_variables_in_column_definitions(&preprocessed);
            preprocessed = Self::replace_nested_variables(&preprocessed);
            // Collapse to a single line so complex-type extraction can work
            // with byte offsets without line-break concerns.
            let normalized: String = preprocessed
                .lines()
                .map(|line| line.trim())
                .filter(|line| !line.is_empty())
                .collect::<Vec<_>>()
                .join(" ");
            // Temporarily simplify STRUCT<...>/ARRAY<...> column types to
            // STRING, remembering the originals for restoration after parsing.
            let (simplified_sql, complex_cols) = Self::extract_complex_type_columns(&normalized);
            (simplified_sql, state, complex_cols)
        } else {
            (sql.to_string(), PreprocessingState::new(), Vec::new())
        };

        let dialect = self.dialect_impl();
        let statements = match Parser::parse_sql(dialect.as_ref(), &preprocessed_sql) {
            Ok(stmts) => stmts,
            Err(e) => {
                // Surface the failure in the result instead of propagating.
                return Ok(ImportResult {
                    tables: Vec::new(),
                    tables_requiring_name: Vec::new(),
                    errors: vec![ImportError::ParseError(e.to_string())],
                    ai_suggestions: None,
                });
            }
        };

        let mut tables = Vec::new();
        let mut errors = Vec::new();
        let mut tables_requiring_name = Vec::new();

        for (idx, stmt) in statements.into_iter().enumerate() {
            match stmt {
                Statement::CreateTable(create) => {
                    match self.parse_create_table_with_preprocessing(
                        idx,
                        &create.name,
                        &create.columns,
                        &create.constraints,
                        &preprocessing_state,
                        &complex_types,
                    ) {
                        Ok((table, requires_name)) => {
                            // A table whose IDENTIFIER() expression could not
                            // be resolved still parses, but the user must
                            // supply a name for it.
                            if requires_name {
                                tables_requiring_name.push(super::TableRequiringName {
                                    table_index: idx,
                                    suggested_name: None,
                                });
                            }
                            tables.push(table);
                        }
                        Err(e) => errors.push(ImportError::ParseError(e)),
                    }
                }
                Statement::CreateView { name, .. } => {
                    match self.parse_create_view(idx, &name, &preprocessing_state) {
                        Ok((table, requires_name)) => {
                            if requires_name {
                                tables_requiring_name.push(super::TableRequiringName {
                                    table_index: idx,
                                    suggested_name: None,
                                });
                            }
                            tables.push(table);
                        }
                        Err(e) => errors.push(ImportError::ParseError(e)),
                    }
                }
                // Non-DDL statements (INSERT, SELECT, ...) are ignored.
                _ => {
                }
            }
        }

        Ok(ImportResult {
            tables,
            tables_requiring_name,
            errors,
            ai_suggestions: None,
        })
    }
480
481 pub fn parse_liquibase(&self, sql: &str) -> Result<ImportResult> {
508 let cleaned = sql
513 .lines()
514 .filter(|l| {
515 let t = l.trim_start();
516 if !t.starts_with("--") {
517 return true;
518 }
519 false
521 })
522 .collect::<Vec<_>>()
523 .join("\n");
524
525 self.parse(&cleaned)
526 }
527
528 fn dialect_impl(&self) -> Box<dyn Dialect + Send + Sync> {
529 match self.dialect.to_lowercase().as_str() {
530 "postgres" | "postgresql" => Box::new(PostgreSqlDialect {}),
531 "mysql" => Box::new(MySqlDialect {}),
532 "sqlite" => Box::new(SQLiteDialect {}),
533 "databricks" => Box::new(DatabricksDialect {}),
534 _ => Box::new(GenericDialect {}),
535 }
536 }
537
538 fn object_name_to_string(name: &ObjectName) -> String {
539 name.0
541 .last()
542 .map(|ident| ident.value.clone())
543 .unwrap_or_else(|| name.to_string())
544 }
545
    /// Converts a parsed CREATE TABLE into `TableData`.
    ///
    /// Resolves Databricks IDENTIFIER() placeholder names via
    /// `preprocessing_state`, restores original STRUCT/ARRAY column types
    /// from `complex_types`, and marks primary-key columns from both inline
    /// column options and table-level constraints. Returns the table plus a
    /// flag that is true when the caller must ask the user for a table name
    /// (the IDENTIFIER() expression contained no literals to extract one).
    ///
    /// Name/type validation failures are logged as warnings rather than
    /// returned as errors, keeping the import best-effort.
    fn parse_create_table_with_preprocessing(
        &self,
        table_index: usize,
        name: &ObjectName,
        columns: &[ColumnDef],
        constraints: &[TableConstraint],
        preprocessing_state: &PreprocessingState,
        complex_types: &[(String, String)],
    ) -> std::result::Result<(TableData, bool), String> {
        let mut table_name = Self::object_name_to_string(name);
        let mut requires_name = false;

        // Placeholder names were introduced by preprocess_identifier_expressions;
        // try to recover the real name from the recorded expression text.
        if table_name.starts_with("__databricks_table_")
            && let Some(original_expr) =
                preprocessing_state.identifier_replacements.get(&table_name)
        {
            if let Some(extracted_name) = Self::extract_identifier_table_name(original_expr) {
                table_name = extracted_name;
            } else {
                requires_name = true;
            }
        }

        if let Err(e) = validate_table_name(&table_name) {
            tracing::warn!("Table name validation warning: {}", e);
        }

        // Column names declared PRIMARY KEY via table-level constraints.
        let mut pk_cols = std::collections::HashSet::<String>::new();
        for c in constraints {
            if let TableConstraint::PrimaryKey { columns, .. } = c {
                for col in columns {
                    pk_cols.insert(col.value.clone());
                }
            }
        }

        let mut out_cols = Vec::new();
        for col in columns {
            // Columns are nullable unless NOT NULL appears in their options.
            let mut nullable = true;
            let mut is_pk = false;

            for opt_def in &col.options {
                match &opt_def.option {
                    ColumnOption::NotNull => nullable = false,
                    ColumnOption::Null => nullable = true,
                    ColumnOption::Unique { is_primary, .. } => {
                        if *is_primary {
                            is_pk = true;
                        }
                    }
                    _ => {}
                }
            }

            if pk_cols.contains(&col.name.value) {
                is_pk = true;
            }

            let col_name = col.name.value.clone();
            let mut data_type = col.data_type.to_string();

            // Restore the original STRUCT/ARRAY type that preprocessing
            // temporarily simplified to STRING.
            if let Some((_, original_type)) =
                complex_types.iter().find(|(name, _)| name == &col_name)
            {
                data_type = original_type.clone();
            }

            if let Err(e) = validate_column_name(&col_name) {
                tracing::warn!("Column name validation warning for '{}': {}", col_name, e);
            }
            if let Err(e) = validate_data_type(&data_type) {
                tracing::warn!("Data type validation warning for '{}': {}", data_type, e);
            }

            out_cols.push(ColumnData {
                name: col_name,
                data_type,
                nullable,
                primary_key: is_pk,
                description: None,
                quality: None,
                ref_path: None,
            });
        }

        Ok((
            TableData {
                table_index,
                name: Some(table_name),
                columns: out_cols,
            },
            requires_name,
        ))
    }
648
649 #[allow(dead_code)] fn parse_create_table(
651 &self,
652 table_index: usize,
653 name: &ObjectName,
654 columns: &[ColumnDef],
655 constraints: &[TableConstraint],
656 ) -> std::result::Result<TableData, String> {
657 let empty_state = PreprocessingState::new();
659 self.parse_create_table_with_preprocessing(
660 table_index,
661 name,
662 columns,
663 constraints,
664 &empty_state,
665 &[],
666 )
667 .map(|(table, _)| table)
668 }
669
670 fn parse_create_view(
675 &self,
676 view_index: usize,
677 name: &ObjectName,
678 preprocessing_state: &PreprocessingState,
679 ) -> std::result::Result<(TableData, bool), String> {
680 let mut view_name = Self::object_name_to_string(name);
681 let mut requires_name = false;
682
683 if view_name.starts_with("__databricks_table_")
685 && let Some(original_expr) = preprocessing_state.identifier_replacements.get(&view_name)
686 {
687 if let Some(extracted_name) = Self::extract_identifier_table_name(original_expr) {
689 view_name = extracted_name;
690 } else {
691 requires_name = true;
693 }
694 }
695
696 if let Err(e) = validate_table_name(&view_name) {
698 tracing::warn!("View name validation warning: {}", e);
699 }
700
701 Ok((
706 TableData {
707 table_index: view_index,
708 name: Some(view_name),
709 columns: Vec::new(), },
711 requires_name,
712 ))
713 }
714}
715
#[cfg(test)]
mod tests {
    //! Unit tests covering generic parsing, Liquibase comment stripping, and
    //! the Databricks-specific preprocessing paths (IDENTIFIER(), variables
    //! in STRUCT/ARRAY types, materialized views).
    use super::*;

    // --- Dialect-agnostic basics ------------------------------------------

    #[test]
    fn test_sql_importer_default() {
        let importer = SQLImporter::default();
        assert_eq!(importer.dialect, "generic");
    }

    #[test]
    fn test_sql_importer_parse_basic() {
        let importer = SQLImporter::new("postgres");
        let result = importer
            .parse("CREATE TABLE test (id INT PRIMARY KEY, name TEXT NOT NULL);")
            .unwrap();
        assert!(result.errors.is_empty());
        assert_eq!(result.tables.len(), 1);
        let t = &result.tables[0];
        assert_eq!(t.name.as_deref(), Some("test"));
        assert_eq!(t.columns.len(), 2);
        assert!(t.columns.iter().any(|c| c.name == "id" && c.primary_key));
        assert!(t.columns.iter().any(|c| c.name == "name" && !c.nullable));
    }

    #[test]
    fn test_sql_importer_parse_table_pk_constraint() {
        // Primary key declared via table constraint, not inline.
        let importer = SQLImporter::new("postgres");
        let result = importer
            .parse("CREATE TABLE t (id INT, name TEXT, CONSTRAINT pk PRIMARY KEY (id));")
            .unwrap();
        assert!(result.errors.is_empty());
        assert_eq!(result.tables.len(), 1);
        let t = &result.tables[0];
        assert!(t.columns.iter().any(|c| c.name == "id" && c.primary_key));
    }

    #[test]
    fn test_sql_importer_parse_liquibase_formatted_sql() {
        // `--` metadata lines must be stripped before parsing.
        let importer = SQLImporter::new("postgres");
        let result = importer
            .parse_liquibase(
                "--liquibase formatted sql\n--changeset user:1\nCREATE TABLE test (id INT);\n",
            )
            .unwrap();
        assert!(result.errors.is_empty());
        assert_eq!(result.tables.len(), 1);
    }

    // --- Databricks IDENTIFIER() handling ---------------------------------

    #[test]
    fn test_databricks_identifier_with_literal() {
        let importer = SQLImporter::new("databricks");
        let sql = "CREATE TABLE IDENTIFIER('test_table') (id STRING);";
        let result = importer.parse(sql).unwrap();
        assert!(result.errors.is_empty());
        assert_eq!(result.tables.len(), 1);
        assert_eq!(result.tables[0].name.as_deref(), Some("test_table"));
    }

    #[test]
    fn test_databricks_identifier_with_variable() {
        // A bare variable cannot be resolved: placeholder name survives and
        // the table is flagged as requiring a user-supplied name.
        let importer = SQLImporter::new("databricks");
        let sql = "CREATE TABLE IDENTIFIER(:table_name) (id STRING);";
        let result = importer.parse(sql).unwrap();
        assert_eq!(result.tables.len(), 1);
        assert!(
            result.tables[0]
                .name
                .as_deref()
                .unwrap()
                .starts_with("__databricks_table_")
        );
        assert_eq!(result.tables_requiring_name.len(), 1);
    }

    #[test]
    fn test_databricks_identifier_with_concatenation() {
        // Only the quoted literal portions contribute to the extracted name.
        let importer = SQLImporter::new("databricks");
        let sql = "CREATE TABLE IDENTIFIER(:catalog || '.schema.table') (id STRING);";
        let result = importer.parse(sql).unwrap();
        assert!(result.errors.is_empty());
        assert_eq!(result.tables.len(), 1);
        assert_eq!(result.tables[0].name.as_deref(), Some("schema.table"));
    }

    // --- Databricks variable substitution in types ------------------------

    #[test]
    fn test_databricks_variable_in_struct() {
        let importer = SQLImporter::new("databricks");
        let sql = "CREATE TABLE example (metadata STRUCT<key: STRING, value: :variable_type, timestamp: TIMESTAMP>);";
        let result = importer.parse(sql).unwrap();
        if !result.errors.is_empty() {
            eprintln!("Parse errors: {:?}", result.errors);
        }
        assert!(result.errors.is_empty());
        assert_eq!(result.tables.len(), 1);
        assert!(
            result.tables[0].columns[0]
                .data_type
                .contains("value: STRING")
        );
    }

    #[test]
    fn test_databricks_variable_in_array() {
        let importer = SQLImporter::new("databricks");
        let sql = "CREATE TABLE example (items ARRAY<:element_type>);";
        let result = importer.parse(sql).unwrap();
        assert!(result.errors.is_empty());
        assert_eq!(result.tables.len(), 1);
        assert_eq!(result.tables[0].columns[0].data_type, "ARRAY<STRING>");
    }

    #[test]
    fn test_databricks_nested_variables() {
        // Variable buried two STRUCT levels deep still gets replaced.
        let importer = SQLImporter::new("databricks");
        let sql = "CREATE TABLE example (rulesTriggered ARRAY<STRUCT<id: STRING, name: STRING, alertOperation: STRUCT<name: STRING, revert: :variable_type, timestamp: TIMESTAMP>>>);";
        let result = importer.parse(sql).unwrap();
        if !result.errors.is_empty() {
            eprintln!("Parse errors: {:?}", result.errors);
        }
        assert!(result.errors.is_empty());
        assert_eq!(result.tables.len(), 1);
        assert!(
            result.tables[0].columns[0]
                .data_type
                .contains("revert: STRING")
        );
    }

    // --- Databricks variables elsewhere in the statement ------------------

    #[test]
    fn test_databricks_comment_variable() {
        let importer = SQLImporter::new("databricks");
        let sql = "CREATE TABLE example (id STRING) COMMENT ':comment_variable';";
        let result = importer.parse(sql).unwrap();
        assert!(result.errors.is_empty());
        assert_eq!(result.tables.len(), 1);
    }

    #[test]
    fn test_databricks_tblproperties_variable() {
        let importer = SQLImporter::new("databricks");
        let sql = "CREATE TABLE example (id STRING) TBLPROPERTIES ('key1' = ':variable_value', 'key2' = 'static_value');";
        let result = importer.parse(sql).unwrap();
        assert!(result.errors.is_empty());
        assert_eq!(result.tables.len(), 1);
    }

    #[test]
    fn test_databricks_column_variable() {
        // Variable marker between a column name and its type is stripped.
        let importer = SQLImporter::new("databricks");
        let sql = "CREATE TABLE example (id :id_var STRING, name :name_var STRING);";
        let result = importer.parse(sql).unwrap();
        assert!(result.errors.is_empty());
        assert_eq!(result.tables.len(), 1);
        assert_eq!(result.tables[0].columns.len(), 2);
    }

    // --- Views ------------------------------------------------------------

    #[test]
    fn test_databricks_create_view() {
        let importer = SQLImporter::new("databricks");
        let sql = "CREATE VIEW example_view AS SELECT id, name FROM source_table;";
        let result = importer.parse(sql).unwrap();
        assert!(result.errors.is_empty());
        assert_eq!(result.tables.len(), 1);
        assert_eq!(result.tables[0].name.as_deref(), Some("example_view"));
    }

    #[test]
    fn test_databricks_view_with_identifier() {
        let importer = SQLImporter::new("databricks");
        let sql =
            "CREATE VIEW IDENTIFIER(:catalog || '.schema.view_name') AS SELECT * FROM table1;";
        let result = importer.parse(sql).unwrap();
        assert!(result.errors.is_empty());
        assert_eq!(result.tables.len(), 1);
        assert_eq!(result.tables[0].name.as_deref(), Some("schema.view_name"));
    }

    #[test]
    fn test_databricks_create_materialized_view() {
        // Materialized views are rewritten to plain views before parsing.
        let importer = SQLImporter::new("databricks");
        let sql = "CREATE MATERIALIZED VIEW mv_example AS SELECT id, name FROM source_table;";
        let result = importer.parse(sql).unwrap();
        assert!(result.errors.is_empty());
        assert_eq!(result.tables.len(), 1);
        assert_eq!(result.tables[0].name.as_deref(), Some("mv_example"));
    }
}
913}