1use super::*;
2
3pub async fn execute_schema_migration_plan<E>(
8 executor: &mut E,
9 plan: &SchemaMigrationPlan,
10) -> Result<(), OpenAuthError>
11where
12 E: SqlExecutor,
13{
14 for statement in &plan.statements {
15 executor
16 .execute(SqlStatement::new(statement.sql.clone()))
17 .await?;
18 }
19 Ok(())
20}
21
22#[derive(Debug, Clone, Default, PartialEq, Eq, Serialize, Deserialize)]
24pub struct SchemaMigrationPlan {
25 pub to_be_created: Vec<TableToCreate>,
26 pub to_be_added: Vec<ColumnToAdd>,
27 pub indexes_to_be_created: Vec<IndexToCreate>,
28 pub warnings: Vec<SchemaMigrationWarning>,
29 pub statements: Vec<MigrationStatement>,
30}
31
32impl SchemaMigrationPlan {
33 pub fn is_empty(&self) -> bool {
34 self.statements.is_empty()
35 }
36
37 pub fn has_warnings(&self) -> bool {
38 !self.warnings.is_empty()
39 }
40
41 pub fn compile(&self) -> String {
42 if self.statements.is_empty() {
43 return ";".to_owned();
44 }
45
46 format!(
47 "{};",
48 self.statements
49 .iter()
50 .map(|statement| statement.sql.as_str())
51 .collect::<Vec<_>>()
52 .join(";\n\n")
53 )
54 }
55}
56
/// A schema table that is absent from the snapshot and is created by the plan.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct TableToCreate {
    /// The key under which the schema lists this table.
    pub logical_name: String,
    /// The table's `name`, as used in generated SQL and snapshot lookups.
    pub table_name: String,
}
63
/// A field whose table exists in the snapshot but whose column does not; the
/// plan adds the column.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct ColumnToAdd {
    /// The key under which the schema lists the owning table.
    pub table_logical_name: String,
    /// The owning table's `name` as used in SQL.
    pub table_name: String,
    /// The key under which the table lists this field.
    pub field_logical_name: String,
    /// The field's `name`, i.e. the column name used in SQL.
    pub column_name: String,
}
72
/// An index the plan creates for a field marked `index` or `unique`.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct IndexToCreate {
    /// The key under which the schema lists the owning table.
    pub table_logical_name: String,
    /// The owning table's `name` as used in SQL.
    pub table_name: String,
    /// The key under which the table lists the indexed field.
    pub field_logical_name: String,
    /// The indexed column's name as used in SQL.
    pub column_name: String,
    /// `uidx_`/`idx_` + table name + `_` + field logical name, after the
    /// dialect's identifier sanitizing.
    pub index_name: String,
    /// `true` when the index enforces uniqueness (the field is marked `unique`).
    pub unique: bool,
}
83
/// A difference between the schema and an existing column that the plan
/// reports but does not fix.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
// Every variant ends in `Mismatch`, which clippy's `enum_variant_names` lint flags.
#[allow(clippy::enum_variant_names)]
pub enum SchemaMigrationWarning {
    /// The column's recorded data type does not match the field's type for the dialect.
    ColumnTypeMismatch {
        table_name: String,
        column_name: String,
        /// The SQL type the dialect would generate for the field.
        expected: String,
        /// The data type recorded in the snapshot.
        actual: String,
    },
    /// The column's recorded nullability differs from `!field.required`.
    ColumnNullabilityMismatch {
        table_name: String,
        column_name: String,
        expected_nullable: bool,
        actual_nullable: bool,
    },
    /// An id column is recorded as not being a primary key.
    PrimaryKeyMismatch {
        table_name: String,
        column_name: String,
    },
    /// An id column's recorded id generation differs from the field's.
    GeneratedIdMismatch {
        table_name: String,
        column_name: String,
        expected: IdGeneration,
        /// `None` when the snapshot records no id generation for the column.
        actual: Option<IdGeneration>,
    },
    /// The column's recorded foreign key differs from the field's.
    ForeignKeyMismatch {
        table_name: String,
        column_name: String,
        expected: ForeignKey,
        /// `None` when the snapshot records no foreign key for the column.
        actual: Option<ForeignKey>,
    },
}
117
/// One SQL statement of a migration plan, tagged with the kind of change it makes.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct MigrationStatement {
    /// What the statement does.
    pub kind: MigrationStatementKind,
    /// The SQL text produced by the dialect.
    pub sql: String,
}
124
/// The kind of change a `MigrationStatement` performs.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum MigrationStatementKind {
    /// Creates a table that is missing from the snapshot.
    CreateTable,
    /// Adds a missing column to an existing table.
    AddColumn,
    /// Creates an index for a field marked `index` or `unique`.
    CreateIndex,
}
132
133#[derive(Debug, Clone, Default, PartialEq, Eq, Serialize, Deserialize)]
135pub struct SqlSchemaSnapshot {
136 tables: IndexMap<String, SqlTableSnapshot>,
137}
138
139impl SqlSchemaSnapshot {
140 pub fn with_table(mut self, table: impl Into<String>) -> Self {
141 self.tables.entry(table.into()).or_default();
142 self
143 }
144
145 pub fn with_column(mut self, table: impl Into<String>, column: SqlColumnSnapshot) -> Self {
146 self.tables
147 .entry(table.into())
148 .or_default()
149 .columns
150 .insert(column.name.clone(), column);
151 self
152 }
153
154 pub fn with_index(mut self, table: impl Into<String>, index: impl Into<String>) -> Self {
155 self.tables
156 .entry(table.into())
157 .or_default()
158 .indexes
159 .insert(index.into());
160 self
161 }
162
163 pub fn with_unique_column(
164 mut self,
165 table: impl Into<String>,
166 column: impl Into<String>,
167 ) -> Self {
168 self.tables
169 .entry(table.into())
170 .or_default()
171 .unique_columns
172 .insert(column.into());
173 self
174 }
175
176 pub fn table_exists(&self, table: &str) -> bool {
177 self.tables.contains_key(table)
178 }
179
180 pub fn column_type(&self, table: &str, column: &str) -> Option<&str> {
181 self.column(table, column)
182 .map(|column| column.data_type.as_str())
183 }
184
185 pub fn column(&self, table: &str, column: &str) -> Option<&SqlColumnSnapshot> {
186 self.tables
187 .get(table)
188 .and_then(|table| table.columns.get(column))
189 }
190
191 pub fn index_exists(&self, table: &str, index: &str) -> bool {
192 self.tables
193 .get(table)
194 .is_some_and(|table| table.indexes.contains(index))
195 || self
196 .tables
197 .values()
198 .any(|table| table.indexes.contains(index))
199 }
200
201 pub fn unique_column_exists(&self, table: &str, column: &str) -> bool {
202 self.tables
203 .get(table)
204 .is_some_and(|table| table.unique_columns.contains(column))
205 }
206}
207
/// What a `SqlSchemaSnapshot` knows about one existing table.
#[derive(Debug, Clone, Default, PartialEq, Eq, Serialize, Deserialize)]
pub struct SqlTableSnapshot {
    /// Existing columns, keyed by column name.
    columns: IndexMap<String, SqlColumnSnapshot>,
    /// Names of existing indexes on this table.
    indexes: IndexSet<String>,
    /// Names of columns recorded as unique.
    unique_columns: IndexSet<String>,
}
215
216#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
218pub struct SqlColumnSnapshot {
219 pub name: String,
220 pub data_type: String,
221 pub nullable: Option<bool>,
222 pub primary_key: Option<bool>,
223 pub generated_id: Option<IdGeneration>,
224 pub foreign_key: Option<ForeignKey>,
225}
226
227impl SqlColumnSnapshot {
228 pub fn new(name: impl Into<String>, data_type: impl Into<String>) -> Self {
229 Self {
230 name: name.into(),
231 data_type: data_type.into(),
232 nullable: None,
233 primary_key: None,
234 generated_id: None,
235 foreign_key: None,
236 }
237 }
238
239 pub fn nullable(mut self, nullable: bool) -> Self {
240 self.nullable = Some(nullable);
241 self
242 }
243
244 pub fn primary_key(mut self, primary_key: bool) -> Self {
245 self.primary_key = Some(primary_key);
246 self
247 }
248
249 pub fn generated_id(mut self, generated_id: Option<IdGeneration>) -> Self {
250 self.generated_id = generated_id;
251 self
252 }
253
254 pub fn references(mut self, foreign_key: ForeignKey) -> Self {
255 self.foreign_key = Some(foreign_key);
256 self
257 }
258}
259
260pub fn plan_schema_migration(
262 dialect: SqlDialect,
263 schema: &DbSchema,
264 snapshot: &SqlSchemaSnapshot,
265) -> Result<SchemaMigrationPlan, OpenAuthError> {
266 let mut plan = SchemaMigrationPlan::default();
267 let mut tables = schema.tables().collect::<Vec<_>>();
268 tables.sort_by_key(|(_, table)| table.order.unwrap_or(u16::MAX));
269
270 for (table_logical_name, table) in &tables {
271 if snapshot.table_exists(&table.name) {
272 for (logical_name, field) in &table.fields {
273 if let Some(column) = snapshot.column(&table.name, &field.name) {
274 if !dialect.type_matches(&column.data_type, field) {
275 plan.warnings
276 .push(SchemaMigrationWarning::ColumnTypeMismatch {
277 table_name: table.name.clone(),
278 column_name: field.name.clone(),
279 expected: dialect.sql_type(logical_name, field),
280 actual: column.data_type.clone(),
281 });
282 }
283 push_constraint_warnings(&mut plan, table, logical_name, field, column);
284 } else {
285 plan.to_be_added.push(ColumnToAdd {
286 table_logical_name: (*table_logical_name).to_owned(),
287 table_name: table.name.clone(),
288 field_logical_name: logical_name.clone(),
289 column_name: field.name.clone(),
290 });
291 plan.statements.push(MigrationStatement {
292 kind: MigrationStatementKind::AddColumn,
293 sql: dialect.add_column_statement(&table.name, logical_name, field)?,
294 });
295 }
296 }
297 } else {
298 plan.to_be_created.push(TableToCreate {
299 logical_name: (*table_logical_name).to_owned(),
300 table_name: table.name.clone(),
301 });
302 plan.statements.push(MigrationStatement {
303 kind: MigrationStatementKind::CreateTable,
304 sql: dialect.create_table_statement(table)?,
305 });
306 }
307 }
308
309 for (table_logical_name, table) in tables {
310 let table_exists = snapshot.table_exists(&table.name);
311 for (logical_name, field) in &table.fields {
312 if field.index || field.unique {
313 if field.unique
314 && (!table_exists || snapshot.unique_column_exists(&table.name, &field.name))
315 {
316 continue;
317 }
318 let prefix = if field.unique { "uidx" } else { "idx" };
319 let index_name = dialect
320 .sanitize_identifier(&format!("{prefix}_{}_{}", table.name, logical_name))?;
321 if !snapshot.index_exists(&table.name, &index_name) {
322 plan.indexes_to_be_created.push(IndexToCreate {
323 table_logical_name: table_logical_name.to_owned(),
324 table_name: table.name.clone(),
325 field_logical_name: logical_name.clone(),
326 column_name: field.name.clone(),
327 index_name: index_name.clone(),
328 unique: field.unique,
329 });
330 plan.statements.push(MigrationStatement {
331 kind: MigrationStatementKind::CreateIndex,
332 sql: dialect.create_index_statement(
333 &table.name,
334 &field.name,
335 &index_name,
336 field.unique,
337 )?,
338 });
339 }
340 }
341 }
342 }
343
344 Ok(plan)
345}
346
347fn push_constraint_warnings(
348 plan: &mut SchemaMigrationPlan,
349 table: &DbTable,
350 logical_name: &str,
351 field: &DbField,
352 column: &SqlColumnSnapshot,
353) {
354 if logical_name == "id" || field.name == "id" {
355 if column.primary_key == Some(false) {
356 plan.warnings
357 .push(SchemaMigrationWarning::PrimaryKeyMismatch {
358 table_name: table.name.clone(),
359 column_name: field.name.clone(),
360 });
361 }
362 } else if let Some(actual_nullable) = column.nullable {
363 let expected_nullable = !field.required;
364 if expected_nullable != actual_nullable {
365 plan.warnings
366 .push(SchemaMigrationWarning::ColumnNullabilityMismatch {
367 table_name: table.name.clone(),
368 column_name: field.name.clone(),
369 expected_nullable,
370 actual_nullable,
371 });
372 }
373 }
374
375 if logical_name == "id" || field.name == "id" {
376 if let Some(expected) = field.generated_id {
377 if column.generated_id != Some(expected) {
378 plan.warnings
379 .push(SchemaMigrationWarning::GeneratedIdMismatch {
380 table_name: table.name.clone(),
381 column_name: field.name.clone(),
382 expected,
383 actual: column.generated_id,
384 });
385 }
386 }
387 }
388
389 if let Some(expected) = &field.foreign_key {
390 if column.foreign_key.as_ref() != Some(expected) {
391 plan.warnings
392 .push(SchemaMigrationWarning::ForeignKeyMismatch {
393 table_name: table.name.clone(),
394 column_name: field.name.clone(),
395 expected: expected.clone(),
396 actual: column.foreign_key.clone(),
397 });
398 }
399 }
400}