1use super::*;
2
3pub async fn execute_schema_migration_plan<E>(
8 executor: &mut E,
9 plan: &SchemaMigrationPlan,
10) -> Result<(), OpenAuthError>
11where
12 E: SqlExecutor,
13{
14 for statement in &plan.statements {
15 executor
16 .execute(SqlStatement::new(statement.sql.clone()))
17 .await?;
18 }
19 Ok(())
20}
21
22#[derive(Debug, Clone, Default, PartialEq, Eq, Serialize, Deserialize)]
24pub struct SchemaMigrationPlan {
25 pub to_be_created: Vec<TableToCreate>,
26 pub to_be_added: Vec<ColumnToAdd>,
27 pub indexes_to_be_created: Vec<IndexToCreate>,
28 pub warnings: Vec<SchemaMigrationWarning>,
29 pub statements: Vec<MigrationStatement>,
30}
31
32impl SchemaMigrationPlan {
33 pub fn is_empty(&self) -> bool {
34 self.statements.is_empty()
35 }
36
37 pub fn compile(&self) -> String {
38 if self.statements.is_empty() {
39 return ";".to_owned();
40 }
41
42 format!(
43 "{};",
44 self.statements
45 .iter()
46 .map(|statement| statement.sql.as_str())
47 .collect::<Vec<_>>()
48 .join(";\n\n")
49 )
50 }
51}
52
53#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
55pub struct TableToCreate {
56 pub logical_name: String,
57 pub table_name: String,
58}
59
60#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
62pub struct ColumnToAdd {
63 pub table_logical_name: String,
64 pub table_name: String,
65 pub field_logical_name: String,
66 pub column_name: String,
67}
68
69#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
71pub struct IndexToCreate {
72 pub table_logical_name: String,
73 pub table_name: String,
74 pub field_logical_name: String,
75 pub column_name: String,
76 pub index_name: String,
77}
78
79#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
81pub enum SchemaMigrationWarning {
82 ColumnTypeMismatch {
83 table_name: String,
84 column_name: String,
85 expected: String,
86 actual: String,
87 },
88}
89
90#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
92pub struct MigrationStatement {
93 pub kind: MigrationStatementKind,
94 pub sql: String,
95}
96
97#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
99pub enum MigrationStatementKind {
100 CreateTable,
101 AddColumn,
102 CreateIndex,
103}
104
105#[derive(Debug, Clone, Default, PartialEq, Eq, Serialize, Deserialize)]
107pub struct SqlSchemaSnapshot {
108 tables: IndexMap<String, SqlTableSnapshot>,
109}
110
111impl SqlSchemaSnapshot {
112 pub fn with_table(mut self, table: impl Into<String>) -> Self {
113 self.tables.entry(table.into()).or_default();
114 self
115 }
116
117 pub fn with_column(mut self, table: impl Into<String>, column: SqlColumnSnapshot) -> Self {
118 self.tables
119 .entry(table.into())
120 .or_default()
121 .columns
122 .insert(column.name.clone(), column);
123 self
124 }
125
126 pub fn with_index(mut self, table: impl Into<String>, index: impl Into<String>) -> Self {
127 self.tables
128 .entry(table.into())
129 .or_default()
130 .indexes
131 .insert(index.into());
132 self
133 }
134
135 pub fn table_exists(&self, table: &str) -> bool {
136 self.tables.contains_key(table)
137 }
138
139 pub fn column_type(&self, table: &str, column: &str) -> Option<&str> {
140 self.tables
141 .get(table)
142 .and_then(|table| table.columns.get(column))
143 .map(|column| column.data_type.as_str())
144 }
145
146 pub fn index_exists(&self, table: &str, index: &str) -> bool {
147 self.tables
148 .get(table)
149 .is_some_and(|table| table.indexes.contains(index))
150 || self
151 .tables
152 .values()
153 .any(|table| table.indexes.contains(index))
154 }
155}
156
157#[derive(Debug, Clone, Default, PartialEq, Eq, Serialize, Deserialize)]
159pub struct SqlTableSnapshot {
160 columns: IndexMap<String, SqlColumnSnapshot>,
161 indexes: IndexSet<String>,
162}
163
164#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
166pub struct SqlColumnSnapshot {
167 pub name: String,
168 pub data_type: String,
169}
170
171impl SqlColumnSnapshot {
172 pub fn new(name: impl Into<String>, data_type: impl Into<String>) -> Self {
173 Self {
174 name: name.into(),
175 data_type: data_type.into(),
176 }
177 }
178}
179
180pub fn plan_schema_migration(
182 dialect: SqlDialect,
183 schema: &DbSchema,
184 snapshot: &SqlSchemaSnapshot,
185) -> Result<SchemaMigrationPlan, OpenAuthError> {
186 let mut plan = SchemaMigrationPlan::default();
187 let mut tables = schema.tables().collect::<Vec<_>>();
188 tables.sort_by_key(|(_, table)| table.order.unwrap_or(u16::MAX));
189
190 for (table_logical_name, table) in &tables {
191 if snapshot.table_exists(&table.name) {
192 for (logical_name, field) in &table.fields {
193 if let Some(actual_type) = snapshot.column_type(&table.name, &field.name) {
194 if !dialect.type_matches(actual_type, field) {
195 plan.warnings
196 .push(SchemaMigrationWarning::ColumnTypeMismatch {
197 table_name: table.name.clone(),
198 column_name: field.name.clone(),
199 expected: dialect.sql_type(logical_name, field),
200 actual: actual_type.to_owned(),
201 });
202 }
203 } else {
204 plan.to_be_added.push(ColumnToAdd {
205 table_logical_name: (*table_logical_name).to_owned(),
206 table_name: table.name.clone(),
207 field_logical_name: logical_name.clone(),
208 column_name: field.name.clone(),
209 });
210 plan.statements.push(MigrationStatement {
211 kind: MigrationStatementKind::AddColumn,
212 sql: dialect.add_column_statement(&table.name, logical_name, field)?,
213 });
214 }
215 }
216 } else {
217 plan.to_be_created.push(TableToCreate {
218 logical_name: (*table_logical_name).to_owned(),
219 table_name: table.name.clone(),
220 });
221 plan.statements.push(MigrationStatement {
222 kind: MigrationStatementKind::CreateTable,
223 sql: dialect.create_table_statement(table)?,
224 });
225 }
226 }
227
228 for (table_logical_name, table) in tables {
229 for (logical_name, field) in &table.fields {
230 if field.index && !field.unique {
231 let index_name =
232 dialect.sanitize_identifier(&format!("idx_{}_{}", table.name, logical_name))?;
233 if !snapshot.index_exists(&table.name, &index_name) {
234 plan.indexes_to_be_created.push(IndexToCreate {
235 table_logical_name: table_logical_name.to_owned(),
236 table_name: table.name.clone(),
237 field_logical_name: logical_name.clone(),
238 column_name: field.name.clone(),
239 index_name: index_name.clone(),
240 });
241 plan.statements.push(MigrationStatement {
242 kind: MigrationStatementKind::CreateIndex,
243 sql: dialect.create_index_statement(
244 &table.name,
245 &field.name,
246 &index_name,
247 )?,
248 });
249 }
250 }
251 }
252 }
253
254 Ok(plan)
255}