1use std::fmt;
2
3use bon::bon;
4use sqlparser::ast::{
5 AlterColumnOperation, AlterTable, AlterTableOperation, AlterType, AlterTypeAddValuePosition,
6 AlterTypeOperation, ColumnOption, ColumnOptionDef, CreateExtension, CreateTable, DropExtension,
7 GeneratedAs, ObjectName, ObjectNamePart, ObjectType, Statement, UserDefinedTypeRepresentation,
8};
9use thiserror::Error;
10
11#[derive(Error, Debug)]
12pub struct MigrateError {
13 kind: MigrateErrorKind,
14 statement_a: Option<Box<Statement>>,
15 statement_b: Option<Box<Statement>>,
16}
17
18impl fmt::Display for MigrateError {
19 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
20 write!(
21 f,
22 "Oops, we couldn't migrate that: {reason}",
23 reason = self.kind
24 )?;
25 if let Some(statement_a) = &self.statement_a {
26 write!(f, "\n\nSubject:\n{statement_a}")?;
27 }
28 if let Some(statement_b) = &self.statement_b {
29 write!(f, "\n\nMigration:\n{statement_b}")?;
30 }
31 Ok(())
32 }
33}
34
35#[bon]
36impl MigrateError {
37 #[builder]
38 fn new(
39 kind: MigrateErrorKind,
40 statement_a: Option<Statement>,
41 statement_b: Option<Statement>,
42 ) -> Self {
43 Self {
44 kind,
45 statement_a: statement_a.map(Box::new),
46 statement_b: statement_b.map(Box::new),
47 }
48 }
49}
50
51#[derive(Error, Debug)]
52#[non_exhaustive]
53enum MigrateErrorKind {
54 #[error("can't migrate unnamed index")]
55 UnnamedIndex,
56 #[error("ALTER TABLE operation \"{0}\" not yet supported")]
57 AlterTableOpNotImplemented(Box<AlterTableOperation>),
58 #[error("invalid ALTER TYPE operation \"{0}\"")]
59 AlterTypeInvalidOp(Box<AlterTypeOperation>),
60 #[error("not yet supported")]
61 NotImplemented,
62}
63
64pub(crate) trait Migrate: Sized {
65 fn migrate(self, other: &Self) -> Result<Option<Self>, MigrateError>;
66}
67
68impl Migrate for Vec<Statement> {
69 fn migrate(self, other: &Self) -> Result<Option<Self>, MigrateError> {
70 let next: Self = self
71 .into_iter()
72 .filter_map(|sa| {
74 let orig = sa.clone();
75 match &sa {
76 Statement::CreateTable(ca) => other
77 .iter()
78 .find(|sb| match sb {
79 Statement::AlterTable(AlterTable { name, .. }) => *name == ca.name,
80 Statement::Drop {
81 object_type, names, ..
82 } => {
83 *object_type == ObjectType::Table
84 && names.len() == 1
85 && names[0] == ca.name
86 }
87 _ => false,
88 })
89 .map_or(Some(Ok(orig)), |sb| sa.migrate(sb).transpose()),
90 Statement::CreateIndex(a) => other
91 .iter()
92 .find(|sb| match sb {
93 Statement::Drop {
94 object_type, names, ..
95 } => {
96 *object_type == ObjectType::Index
97 && names.len() == 1
98 && Some(&names[0]) == a.name.as_ref()
99 }
100 _ => false,
101 })
102 .map_or(Some(Ok(orig)), |sb| sa.migrate(sb).transpose()),
103 Statement::CreateType { name, .. } => other
104 .iter()
105 .find(|sb| match sb {
106 Statement::AlterType(b) => *name == b.name,
107 Statement::Drop {
108 object_type, names, ..
109 } => {
110 *object_type == ObjectType::Type
111 && names.len() == 1
112 && names[0] == *name
113 }
114 _ => false,
115 })
116 .map_or(Some(Ok(orig)), |sb| sa.migrate(sb).transpose()),
117 Statement::CreateExtension(CreateExtension { name, .. }) => other
118 .iter()
119 .find(|sb| match sb {
120 Statement::DropExtension(DropExtension { names, .. }) => {
121 names.contains(name)
122 }
123 _ => false,
124 })
125 .map_or(Some(Ok(orig)), |sb| sa.migrate(sb).transpose()),
126 Statement::CreateDomain(a) => other
127 .iter()
128 .find(|sb| match sb {
129 Statement::DropDomain(b) => a.name == b.name,
130 _ => false,
131 })
132 .map_or(Some(Ok(orig)), |sb| sa.migrate(sb).transpose()),
133 _ => Some(Err(MigrateError::builder()
134 .kind(MigrateErrorKind::NotImplemented)
135 .statement_a(sa.clone())
136 .build())),
137 }
138 })
139 .chain(other.iter().filter_map(|sb| match sb {
141 Statement::CreateTable(_)
142 | Statement::CreateIndex { .. }
143 | Statement::CreateType { .. }
144 | Statement::CreateExtension { .. }
145 | Statement::CreateDomain(..) => Some(Ok(sb.clone())),
146 _ => None,
147 }))
148 .collect::<Result<_, _>>()?;
149 Ok(Some(next))
150 }
151}
152
153impl Migrate for Statement {
154 fn migrate(self, other: &Self) -> Result<Option<Self>, MigrateError> {
155 match self {
156 Self::CreateTable(ca) => match other {
157 Self::AlterTable(AlterTable {
158 name, operations, ..
159 }) => {
160 if *name == ca.name {
161 Ok(Some(Self::CreateTable(migrate_alter_table(
162 ca, operations,
163 )?)))
164 } else {
165 Ok(Some(Self::CreateTable(ca)))
167 }
168 }
169 Self::Drop {
170 object_type, names, ..
171 } => {
172 if *object_type == ObjectType::Table && names.contains(&ca.name) {
173 Ok(None)
174 } else {
175 Ok(Some(Self::CreateTable(ca)))
177 }
178 }
179 _ => Err(MigrateError::builder()
180 .kind(MigrateErrorKind::NotImplemented)
181 .statement_a(Self::CreateTable(ca))
182 .statement_b(other.clone())
183 .build()),
184 },
185 Self::CreateIndex(a) => match other {
186 Self::Drop {
187 object_type, names, ..
188 } => {
189 let name = a.name.clone().ok_or_else(|| {
190 MigrateError::builder()
191 .kind(MigrateErrorKind::UnnamedIndex)
192 .statement_a(Self::CreateIndex(a.clone()))
193 .statement_b(other.clone())
194 .build()
195 })?;
196 if *object_type == ObjectType::Index && names.contains(&name) {
197 Ok(None)
198 } else {
199 Ok(Some(Self::CreateIndex(a)))
201 }
202 }
203 _ => Err(MigrateError::builder()
204 .kind(MigrateErrorKind::NotImplemented)
205 .statement_a(Self::CreateIndex(a))
206 .statement_b(other.clone())
207 .build()),
208 },
209 Self::CreateType {
210 name,
211 representation,
212 } => match other {
213 Self::AlterType(ba) => {
214 if name == ba.name {
215 let (name, representation) =
216 migrate_alter_type(name.clone(), representation.clone(), ba)?;
217 Ok(Some(Self::CreateType {
218 name,
219 representation,
220 }))
221 } else {
222 Ok(Some(Self::CreateType {
224 name,
225 representation,
226 }))
227 }
228 }
229 Self::Drop {
230 object_type, names, ..
231 } => {
232 if *object_type == ObjectType::Type && names.contains(&name) {
233 Ok(None)
234 } else {
235 Ok(Some(Self::CreateType {
237 name,
238 representation,
239 }))
240 }
241 }
242 _ => Err(MigrateError::builder()
243 .kind(MigrateErrorKind::NotImplemented)
244 .statement_a(Self::CreateType {
245 name,
246 representation,
247 })
248 .statement_b(other.clone())
249 .build()),
250 },
251 _ => Err(MigrateError::builder()
252 .kind(MigrateErrorKind::NotImplemented)
253 .statement_a(self)
254 .statement_b(other.clone())
255 .build()),
256 }
257 }
258}
259
260fn migrate_alter_table(
261 mut t: CreateTable,
262 ops: &[AlterTableOperation],
263) -> Result<CreateTable, MigrateError> {
264 for op in ops.iter() {
265 match op {
266 AlterTableOperation::AddColumn { column_def, .. } => {
267 t.columns.push(column_def.clone());
268 }
269 AlterTableOperation::DropColumn { column_names, .. } => {
270 t.columns.retain(|c| {
271 !column_names
272 .iter().any(|name| c.name.value == name.value)
273 });
274 }
275 AlterTableOperation::AlterColumn { column_name, op } => {
276 t.columns.iter_mut().for_each(|c| {
277 if c.name != *column_name {
278 return;
279 }
280 match op {
281 AlterColumnOperation::SetNotNull => {
282 c.options.push(ColumnOptionDef {
283 name: None,
284 option: ColumnOption::NotNull,
285 });
286 }
287 AlterColumnOperation::DropNotNull => {
288 c.options
289 .retain(|o| !matches!(o.option, ColumnOption::NotNull));
290 }
291 AlterColumnOperation::SetDefault { value } => {
292 c.options
293 .retain(|o| !matches!(o.option, ColumnOption::Default(_)));
294 c.options.push(ColumnOptionDef {
295 name: None,
296 option: ColumnOption::Default(value.clone()),
297 });
298 }
299 AlterColumnOperation::DropDefault => {
300 c.options
301 .retain(|o| !matches!(o.option, ColumnOption::Default(_)));
302 }
303 AlterColumnOperation::SetDataType {
304 data_type,
305 using: _, had_set: _, } => {
308 c.data_type = data_type.clone();
309 }
310 AlterColumnOperation::AddGenerated {
311 generated_as,
312 sequence_options,
313 } => {
314 c.options
315 .retain(|o| !matches!(o.option, ColumnOption::Generated { .. }));
316 c.options.push(ColumnOptionDef {
317 name: None,
318 option: ColumnOption::Generated {
319 generated_as: (*generated_as)
320 .unwrap_or(GeneratedAs::Always),
321 sequence_options: sequence_options.clone(),
322 generation_expr: None,
323 generation_expr_mode: None,
324 generated_keyword: true,
325 },
326 });
327 }
328 }
329 });
330 }
331 op => {
332 return Err(MigrateError::builder()
333 .kind(MigrateErrorKind::AlterTableOpNotImplemented(Box::new(
334 op.clone(),
335 )))
336 .statement_a(Statement::CreateTable(t.clone()))
337 .build())
338 }
339 }
340 }
341
342 Ok(t)
343}
344
345fn migrate_alter_type(
346 name: ObjectName,
347 representation: Option<UserDefinedTypeRepresentation>,
348 other: &AlterType,
349) -> Result<(ObjectName, Option<UserDefinedTypeRepresentation>), MigrateError> {
350 match &other.operation {
351 AlterTypeOperation::Rename(r) => {
352 let mut parts = name.0;
353 parts.pop();
354 parts.push(ObjectNamePart::Identifier(r.new_name.clone()));
355 let name = ObjectName(parts);
356
357 Ok((name, representation))
358 }
359 AlterTypeOperation::AddValue(a) => match representation {
360 Some(UserDefinedTypeRepresentation::Enum { mut labels }) => {
361 match &a.position {
362 Some(AlterTypeAddValuePosition::Before(before_name)) => {
363 let index = labels
364 .iter()
365 .enumerate()
366 .find(|(_, l)| *l == before_name)
367 .map(|(i, _)| i)
368 .unwrap_or(0);
370 labels.insert(index, a.value.clone());
371 }
372 Some(AlterTypeAddValuePosition::After(after_name)) => {
373 let index = labels
374 .iter()
375 .enumerate()
376 .find(|(_, l)| *l == after_name)
377 .map(|(i, _)| i + 1);
378 match index {
379 Some(index) => labels.insert(index, a.value.clone()),
380 None => labels.push(a.value.clone()),
382 }
383 }
384 None => labels.push(a.value.clone()),
385 }
386
387 Ok((name, Some(UserDefinedTypeRepresentation::Enum { labels })))
388 }
389 Some(_) | None => Err(MigrateError::builder()
390 .kind(MigrateErrorKind::AlterTypeInvalidOp(Box::new(
391 other.operation.clone(),
392 )))
393 .statement_a(Statement::CreateType {
394 name,
395 representation,
396 })
397 .statement_b(Statement::AlterType(other.clone()))
398 .build()),
399 },
400 AlterTypeOperation::RenameValue(rv) => match representation {
401 Some(UserDefinedTypeRepresentation::Enum { labels }) => {
402 let labels = labels
403 .into_iter()
404 .map(|l| if l == rv.from { rv.to.clone() } else { l })
405 .collect::<Vec<_>>();
406
407 Ok((name, Some(UserDefinedTypeRepresentation::Enum { labels })))
408 }
409 Some(_) | None => Err(MigrateError::builder()
410 .kind(MigrateErrorKind::AlterTypeInvalidOp(Box::new(
411 other.operation.clone(),
412 )))
413 .statement_a(Statement::CreateType {
414 name,
415 representation,
416 })
417 .statement_b(Statement::AlterType(other.clone()))
418 .build()),
419 },
420 }
421}