// schema_installer/installer.rs

use crate::config::SchemaInstallerConfig;
use crate::connection::AnyPool;
use crate::error::SchemaInstallerError;
use schema_parser::parse_database_xml;
use schema_sql_generator::common::generate_options::GenerateOptions;
use schema_sql_generator::common::generator_type::GeneratorType;
use schema_sql_generator::common::output_mode::OutputMode;
use schema_sql_generator::common::print_writer::PrintWriter;
use std::cell::RefCell;
use std::fs;
use std::path::PathBuf;
use std::rc::Rc;

13pub struct SchemaInstaller;
14
15impl SchemaInstaller {
16    pub async fn install(config: &SchemaInstallerConfig) -> Result<(), SchemaInstallerError> {
17        // Connect to database
18        let pool = AnyPool::connect(&config.database_type, &config.connection_string).await?;
19
20        // Create tracking tables if they don't exist
21        Self::ensure_tracking_tables(&pool, &config.database_type).await?;
22
23        // Check if already installed
24        if Self::check_if_installed(&pool).await? {
25            println!("Schema is already installed. Skipping installation.");
26            return Ok(());
27        }
28
29        // Parse schema
30        let schema_file = config.schema_file.as_ref()
31            .ok_or_else(|| SchemaInstallerError::InvalidConfiguration("schema_file required for install command".to_string()))?;
32        let schema_file_str = schema_file.to_str()
33            .ok_or_else(|| SchemaInstallerError::SchemaFileNotFound("Invalid path".to_string()))?;
34        let schema_content = fs::read_to_string(schema_file_str)
35            .map_err(|e| SchemaInstallerError::Io(e))?;
36        let database_model = parse_database_xml(&schema_content)
37            .map_err(|e| SchemaInstallerError::Parse(e))?;
38
39        // Get schema version
40        let version = database_model.version()
41            .map(|v| format!("{}.{}.{}", v.major_version(), v.minor_version(), v.patch_version()))
42            .unwrap_or_else(|| "1.0.0".to_string());
43
44        // Generate SQL by writing to temp file
45        // (PrintWriter's BufWriter makes it difficult to extract bytes in memory)
46        let nanos = std::time::SystemTime::now()
47            .duration_since(std::time::UNIX_EPOCH)
48            .map(|d| d.subsec_nanos())
49            .unwrap_or(0);
50        let temp_file = std::env::temp_dir().join(format!("schema_install_temp_{}.sql", nanos));
51        let file = std::fs::File::create(&temp_file)
52            .map_err(|e| SchemaInstallerError::Io(e))?;
53
54        let writer_temp = PrintWriter::new(Box::new(file));
55        let generate_options = GenerateOptions {
56            database_model: Rc::new(database_model),
57            writer: Rc::new(RefCell::new(writer_temp)),
58            boolean_mode: config.boolean_mode.clone(),
59            foreign_key_mode: config.foreign_key_mode.clone(),
60            output_mode: OutputMode::All,
61            target_postgres_version: 17,
62        };
63
64        (&config.database_type).generate(generate_options);
65
66        let sql = std::fs::read_to_string(&temp_file)
67            .map_err(|e| SchemaInstallerError::Io(e))?;
68
69        let _ = std::fs::remove_file(&temp_file);
70
71        // Record migration
72        let script_name = format!("V{}__install_schema.sql", version);
73        let checksum = crate::migration::compute_checksum(&sql);
74        let tool_version = env!("CARGO_PKG_VERSION");
75
76        let migration_id = pool
77            .insert_migration(&version, &script_name, &checksum, 0, "pending", tool_version)
78            .await?;
79
80        // Execute SQL statements
81        let start = std::time::Instant::now();
82        match Self::execute_sql_script(&pool, &config.database_type, &sql).await {
83            Ok(_) => {
84                let elapsed_ms = start.elapsed().as_millis() as i64;
85                pool.update_migration_status(migration_id, "success", elapsed_ms)
86                    .await?;
87                println!("Schema installed successfully. Version: {}", version);
88                Ok(())
89            }
90            Err(e) => {
91                let elapsed_ms = start.elapsed().as_millis() as i64;
92                pool.update_migration_status(migration_id, "failed", elapsed_ms)
93                    .await?;
94                Err(e)
95            }
96        }
97    }
98
99    pub async fn is_installed(config: &SchemaInstallerConfig) -> Result<bool, SchemaInstallerError> {
100        let pool = AnyPool::connect(&config.database_type, &config.connection_string).await?;
101        Self::check_if_installed(&pool).await
102    }
103
104    pub async fn get_installed_version(config: &SchemaInstallerConfig) -> Result<Option<String>, SchemaInstallerError> {
105        let pool = AnyPool::connect(&config.database_type, &config.connection_string).await?;
106        match pool.get_applied_migrations().await {
107            Ok(migrations) => {
108                let latest = migrations
109                    .iter()
110                    .filter(|m| m.status == "success")
111                    .max_by(|a, b| {
112                        crate::migration::compare_versions(&a.version, &b.version)
113                    });
114                Ok(latest.map(|m| m.version.clone()))
115            }
116            Err(e) => {
117                // Table might not exist yet, which is fine
118                if e.to_string().contains("does not exist") || e.to_string().contains("no such table") {
119                    Ok(None)
120                } else {
121                    Err(e)
122                }
123            }
124        }
125    }
126
127    async fn check_if_installed(pool: &AnyPool) -> Result<bool, SchemaInstallerError> {
128        match pool.get_applied_migrations().await {
129            Ok(migrations) => Ok(migrations.iter().any(|m| m.status == "success")),
130            Err(e) => {
131                // Table might not exist yet, which is fine
132                if e.to_string().contains("does not exist") || e.to_string().contains("no such table") {
133                    Ok(false)
134                } else {
135                    Err(e)
136                }
137            }
138        }
139    }
140
141    async fn ensure_tracking_tables(pool: &AnyPool, database_type: &GeneratorType) -> Result<(), SchemaInstallerError> {
142        pool.ensure_migration_table(database_type).await?;
143        Ok(())
144    }
145
146    async fn execute_sql_script(
147        pool: &AnyPool,
148        database_type: &GeneratorType,
149        sql: &str,
150    ) -> Result<(), SchemaInstallerError> {
151        // Split SQL statements based on database type
152        let delimiter = match database_type {
153            GeneratorType::SqlServer => "GO",
154            _ => ";",
155        };
156
157        let statements: Vec<&str> = sql.split(delimiter)
158            .map(|s| s.trim())
159            .filter(|s| !s.is_empty())
160            .collect();
161
162        for statement in statements {
163            pool.execute_sql(statement).await?;
164        }
165
166        Ok(())
167    }
168}
#[cfg(test)]
mod tests {
    // These checks only split strings, so they run as plain `#[test]`s and need no
    // async runtime.

    /// Splits `sql` on `delimiter`, trimming each piece and dropping empty ones.
    fn split_trimmed<'a>(sql: &'a str, delimiter: &str) -> Vec<&'a str> {
        sql.split(delimiter)
            .map(str::trim)
            .filter(|s| !s.is_empty())
            .collect()
    }

    #[test]
    fn test_sql_script_splitting() {
        // PostgreSQL-style `;` terminators.
        let sql_pg = "CREATE TABLE t1 (id INT); CREATE TABLE t2 (id INT);";
        assert_eq!(
            split_trimmed(sql_pg, ";"),
            ["CREATE TABLE t1 (id INT)", "CREATE TABLE t2 (id INT)"]
        );

        // SQL Server-style `GO` batch separators, each on its own line.
        let sql_mssql = "CREATE TABLE t1 (id INT)\nGO\nCREATE TABLE t2 (id INT)\nGO";
        assert_eq!(
            split_trimmed(sql_mssql, "GO"),
            ["CREATE TABLE t1 (id INT)", "CREATE TABLE t2 (id INT)"]
        );
    }
}