schema_installer/
installer.rs1use crate::config::SchemaInstallerConfig;
2use crate::connection::AnyPool;
3use crate::error::SchemaInstallerError;
4use schema_parser::parse_database_xml;
5use schema_sql_generator::common::generate_options::GenerateOptions;
6use schema_sql_generator::common::generator_type::GeneratorType;
7use schema_sql_generator::common::output_mode::OutputMode;
8use schema_sql_generator::common::print_writer::PrintWriter;
9use std::cell::RefCell;
10use std::fs;
11use std::rc::Rc;
12
13pub struct SchemaInstaller;
14
15impl SchemaInstaller {
16 pub async fn install(config: &SchemaInstallerConfig) -> Result<(), SchemaInstallerError> {
17 let pool = AnyPool::connect(&config.database_type, &config.connection_string).await?;
19
20 Self::ensure_tracking_tables(&pool, &config.database_type).await?;
22
23 if Self::check_if_installed(&pool).await? {
25 println!("Schema is already installed. Skipping installation.");
26 return Ok(());
27 }
28
29 let schema_file = config.schema_file.as_ref()
31 .ok_or_else(|| SchemaInstallerError::InvalidConfiguration("schema_file required for install command".to_string()))?;
32 let schema_file_str = schema_file.to_str()
33 .ok_or_else(|| SchemaInstallerError::SchemaFileNotFound("Invalid path".to_string()))?;
34 let schema_content = fs::read_to_string(schema_file_str)
35 .map_err(|e| SchemaInstallerError::Io(e))?;
36 let database_model = parse_database_xml(&schema_content)
37 .map_err(|e| SchemaInstallerError::Parse(e))?;
38
39 let version = database_model.version()
41 .map(|v| format!("{}.{}.{}", v.major_version(), v.minor_version(), v.patch_version()))
42 .unwrap_or_else(|| "1.0.0".to_string());
43
44 let nanos = std::time::SystemTime::now()
47 .duration_since(std::time::UNIX_EPOCH)
48 .map(|d| d.subsec_nanos())
49 .unwrap_or(0);
50 let temp_file = std::env::temp_dir().join(format!("schema_install_temp_{}.sql", nanos));
51 let file = std::fs::File::create(&temp_file)
52 .map_err(|e| SchemaInstallerError::Io(e))?;
53
54 let writer_temp = PrintWriter::new(Box::new(file));
55 let generate_options = GenerateOptions {
56 database_model: Rc::new(database_model),
57 writer: Rc::new(RefCell::new(writer_temp)),
58 boolean_mode: config.boolean_mode.clone(),
59 foreign_key_mode: config.foreign_key_mode.clone(),
60 output_mode: OutputMode::All,
61 target_postgres_version: 17,
62 };
63
64 (&config.database_type).generate(generate_options);
65
66 let sql = std::fs::read_to_string(&temp_file)
67 .map_err(|e| SchemaInstallerError::Io(e))?;
68
69 let _ = std::fs::remove_file(&temp_file);
70
71 let script_name = format!("V{}__install_schema.sql", version);
73 let checksum = crate::migration::compute_checksum(&sql);
74 let tool_version = env!("CARGO_PKG_VERSION");
75
76 let migration_id = pool
77 .insert_migration(&version, &script_name, &checksum, 0, "pending", tool_version)
78 .await?;
79
80 let start = std::time::Instant::now();
82 match Self::execute_sql_script(&pool, &config.database_type, &sql).await {
83 Ok(_) => {
84 let elapsed_ms = start.elapsed().as_millis() as i64;
85 pool.update_migration_status(migration_id, "success", elapsed_ms)
86 .await?;
87 println!("Schema installed successfully. Version: {}", version);
88 Ok(())
89 }
90 Err(e) => {
91 let elapsed_ms = start.elapsed().as_millis() as i64;
92 pool.update_migration_status(migration_id, "failed", elapsed_ms)
93 .await?;
94 Err(e)
95 }
96 }
97 }
98
99 pub async fn is_installed(config: &SchemaInstallerConfig) -> Result<bool, SchemaInstallerError> {
100 let pool = AnyPool::connect(&config.database_type, &config.connection_string).await?;
101 Self::check_if_installed(&pool).await
102 }
103
104 pub async fn get_installed_version(config: &SchemaInstallerConfig) -> Result<Option<String>, SchemaInstallerError> {
105 let pool = AnyPool::connect(&config.database_type, &config.connection_string).await?;
106 match pool.get_applied_migrations().await {
107 Ok(migrations) => {
108 let latest = migrations
109 .iter()
110 .filter(|m| m.status == "success")
111 .max_by(|a, b| {
112 crate::migration::compare_versions(&a.version, &b.version)
113 });
114 Ok(latest.map(|m| m.version.clone()))
115 }
116 Err(e) => {
117 if e.to_string().contains("does not exist") || e.to_string().contains("no such table") {
119 Ok(None)
120 } else {
121 Err(e)
122 }
123 }
124 }
125 }
126
127 async fn check_if_installed(pool: &AnyPool) -> Result<bool, SchemaInstallerError> {
128 match pool.get_applied_migrations().await {
129 Ok(migrations) => Ok(migrations.iter().any(|m| m.status == "success")),
130 Err(e) => {
131 if e.to_string().contains("does not exist") || e.to_string().contains("no such table") {
133 Ok(false)
134 } else {
135 Err(e)
136 }
137 }
138 }
139 }
140
141 async fn ensure_tracking_tables(pool: &AnyPool, database_type: &GeneratorType) -> Result<(), SchemaInstallerError> {
142 pool.ensure_migration_table(database_type).await?;
143 Ok(())
144 }
145
146 async fn execute_sql_script(
147 pool: &AnyPool,
148 database_type: &GeneratorType,
149 sql: &str,
150 ) -> Result<(), SchemaInstallerError> {
151 let delimiter = match database_type {
153 GeneratorType::SqlServer => "GO",
154 _ => ";",
155 };
156
157 let statements: Vec<&str> = sql.split(delimiter)
158 .map(|s| s.trim())
159 .filter(|s| !s.is_empty())
160 .collect();
161
162 for statement in statements {
163 pool.execute_sql(statement).await?;
164 }
165
166 Ok(())
167 }
168}
169
#[cfg(test)]
mod tests {
    /// Pins the split-and-trim pipeline on well-formed input: one statement per
    /// `;` for most dialects and one batch per `GO` separator for SQL Server.
    ///
    /// Nothing here awaits, so a plain `#[test]` suffices; an async runtime is
    /// not needed.
    #[test]
    fn test_sql_script_splitting() {
        fn split_trimmed<'a>(sql: &'a str, delimiter: &str) -> Vec<&'a str> {
            sql.split(delimiter)
                .map(str::trim)
                .filter(|s| !s.is_empty())
                .collect()
        }

        let sql_pg = "CREATE TABLE t1 (id INT); CREATE TABLE t2 (id INT);";
        assert_eq!(
            split_trimmed(sql_pg, ";"),
            ["CREATE TABLE t1 (id INT)", "CREATE TABLE t2 (id INT)"]
        );

        let sql_mssql = "CREATE TABLE t1 (id INT)\nGO\nCREATE TABLE t2 (id INT)\nGO";
        assert_eq!(
            split_trimmed(sql_mssql, "GO"),
            ["CREATE TABLE t1 (id INT)", "CREATE TABLE t2 (id INT)"]
        );
    }
}