midas_core/
commander.rs

1use std::iter::Iterator;
2use std::thread;
3use std::time::Duration;
4
5use anyhow::{
6  Context,
7  Result as AnyhowResult,
8};
9use console::style;
10use indicatif::ProgressBar;
11use prettytable::format::consts;
12use prettytable::{
13  color,
14  row,
15  Attr,
16  Cell,
17  Row,
18  Table,
19};
20use rand::Rng;
21use url::Url;
22
23use crate::lookup::{
24  MigrationFiles,
25  VecStr,
26};
27use crate::sequel::{
28  Driver as SequelDriver,
29  VecSerial,
30};
31use crate::{
32  ensure_migration_state_dir_exists,
33  progress_style,
34};
35
/// Join the non-empty lines of a migration section into a single string.
///
/// `$content` must be an identifier bound to something that exposes `.iter()`
/// over the section's lines. Empty lines are dropped and the remaining lines
/// are joined with `\n`.
macro_rules! get_content_string {
  ($content: ident) => {{
    let kept_lines: VecStr = $content
      .iter()
      .filter(|line| !line.is_empty())
      .cloned()
      .collect();
    kept_lines.join("\n")
  }};
}
47
48/// The migrator struct
49pub struct Migrator<T: ?Sized> {
50  /// The executor instance
51  executor: Box<T>,
52
53  /// The migration files
54  migrations: MigrationFiles,
55}
56
57impl<T: SequelDriver + 'static + ?Sized> Migrator<T> {
58  /// Create a new migrator instance
59  pub fn new(executor: Box<T>, migrations: MigrationFiles) -> Self {
60    Self { executor, migrations }
61  }
62
63  /// Run the status command to show the current status of migrations
64  pub fn status(&mut self) -> AnyhowResult<()> {
65    // Get the completed migrations
66    let completed_migrations = self.executor.get_completed_migrations()?;
67    let available_migrations = self.migrations.keys().copied().collect::<VecSerial>();
68
69    // If there are no available migrations, print a message and return
70    if available_migrations.is_empty() {
71      println!("There are no available migration files.");
72      return Ok(());
73    }
74
75    // Create a new table instance
76    let mut table = Table::new();
77    table.set_titles(row![Fbb->"Migration No.", Fbb->"Status", Fbb->"Filename"]);
78    table.set_format(*consts::FORMAT_CLEAN);
79
80    // Iterate over the available migrations
81    available_migrations.iter().for_each(|it| {
82      // Set the color based on whether the migration is completed
83      let temp_color = if completed_migrations.contains(it) {
84        color::GREEN
85      } else {
86        color::RED
87      };
88
89      // Get the migration number and the migration file
90      let migration_no = format!("{it:013}");
91      if let Some(migration) = self.migrations.get(it) {
92        let filename = &migration.filename;
93
94        table.add_row(Row::new(vec![
95          Cell::new(&migration_no).with_style(Attr::Bold),
96          Cell::new(if temp_color == color::GREEN {
97            "Active"
98          } else {
99            "Inactive"
100          })
101          .with_style(Attr::ForegroundColor(temp_color)),
102          Cell::new(filename).with_style(Attr::ForegroundColor(temp_color)),
103        ]));
104      }
105    });
106
107    // Print the table
108    let msg = style("Available migrations:").bold().cyan();
109    println!();
110    println!("{msg}");
111    println!();
112    table.printstd();
113    println!();
114
115    // Print the completed migrations count and the available migrations count
116    let available_migrations_count = available_migrations.len();
117    let completed_migrations_count = completed_migrations.len();
118    let completed_migrations = style("Completed migrations:").bold().cyan();
119    let total_migrations = style("Total migrations:").bold().cyan();
120    println!("{completed_migrations}: {completed_migrations_count}");
121    println!("{total_migrations}: {available_migrations_count}");
122
123    Ok(())
124  }
125
126  /// Run up migrations
127  pub fn up(&mut self) -> AnyhowResult<()> {
128    // Ensure the migration state directory exists
129    ensure_migration_state_dir_exists()?;
130
131    // Get the completed migrations
132    let completed_migrations = self.executor.get_completed_migrations()?;
133    let available_migrations = self.migrations.keys().copied().collect::<VecSerial>();
134
135    // If there are no available migrations, print a message and return
136    if available_migrations.is_empty() {
137      println!("There are no available migration files.");
138      return Ok(());
139    }
140
141    // Filter the available migrations
142    let filtered: Vec<_> = available_migrations
143      .iter()
144      .filter(|s| !completed_migrations.contains(s))
145      .copied()
146      .collect();
147
148    // If there are no filtered migrations, print a message and return
149    if filtered.is_empty() {
150      println!("Migrations are all up-to-date.");
151      return Ok(());
152    }
153
154    // Create a new progress bar instance
155    let pb = ProgressBar::new(filtered.len() as u64);
156    let tick_interval = Duration::from_millis(80);
157    pb.set_style(progress_style()?);
158    pb.enable_steady_tick(tick_interval);
159    let mut rng = rand::thread_rng();
160
161    // Iterate over the filtered migrations
162    for it in &filtered {
163      // Sleep for a random duration between 40 and 300 milliseconds
164      // to simulate a delay and make the progress bar more interesting
165      thread::sleep(Duration::from_millis(rng.gen_range(40..300)));
166
167      // Set the progress bar prefix
168      pb.set_prefix(format!("{it:013}"));
169
170      // Get the migration file
171      let migration = self.migrations.get(it).context("Migration file not found")?;
172      let filename_parts: Vec<&str> = migration.filename.splitn(2, '_').collect();
173      let migration_name = filename_parts
174        .get(1)
175        .and_then(|s| s.strip_suffix(".sql"))
176        .context("Migration name not found")?;
177
178      // Set the progress bar message
179      pb.set_message(format!("Applying migration: {migration_name}"));
180
181      // Get the migration up content and convert it to a string
182      let content_up = migration
183        .content_up
184        .as_ref()
185        .context("Migration content not found")?;
186      let content_up = get_content_string!(content_up);
187
188      // Run the migration content
189      self.executor.migrate(&content_up, *it)?;
190
191      // Add the completed migration
192      self.executor.add_completed_migration(*it)?;
193      pb.inc(1);
194    }
195    pb.finish();
196
197    Ok(())
198  }
199
200  /// Run up migrations up to a specific migration number
201  pub fn upto(&mut self, migration_number: i64) -> AnyhowResult<()> {
202    // Ensure the migration state directory exists
203    ensure_migration_state_dir_exists()?;
204
205    // Get the completed migrations
206    let completed_migrations = self.executor.get_completed_migrations()?;
207    let available_migrations = self.migrations.keys().copied().collect::<VecSerial>();
208
209    // If there are no available migrations, print a message and return
210    if available_migrations.is_empty() {
211      println!("There are no available migration files.");
212      return Ok(());
213    }
214
215    // Filter the available migrations
216    let filtered: Vec<_> = available_migrations
217      .iter()
218      .filter(|s| !completed_migrations.contains(s))
219      .filter(|s| **s <= migration_number)
220      .copied()
221      .collect();
222
223    // If there are no filtered migrations, print a message and return
224    if filtered.is_empty() {
225      println!("Migrations are all up-to-date.");
226      return Ok(());
227    }
228
229    // Create a new progress bar instance
230    let pb = ProgressBar::new(filtered.len() as u64);
231    let tick_interval = Duration::from_millis(80);
232    pb.set_style(progress_style()?);
233    pb.enable_steady_tick(tick_interval);
234    let mut rng = rand::thread_rng();
235
236    // Iterate over the filtered migrations
237    for it in &filtered {
238      // Sleep for a random duration between 40 and 300 milliseconds
239      // to simulate a delay and make the progress bar more interesting
240      thread::sleep(Duration::from_millis(rng.gen_range(40..300)));
241      pb.set_prefix(format!("{it:013}"));
242
243      // Get the migration file
244      let migration = self.migrations.get(it).context("Migration file not found")?;
245      let filename_parts: Vec<&str> = migration.filename.splitn(2, '_').collect();
246      let migration_name = filename_parts
247        .get(1)
248        .and_then(|s| s.strip_suffix(".sql"))
249        .context("Migration name not found")?;
250
251      // Set the progress bar message
252      pb.set_message(format!("Applying migration: {migration_name}"));
253
254      // Get the migration up content and convert it to a string
255      let content_up = migration
256        .content_up
257        .as_ref()
258        .context("Migration content not found")?;
259      let content_up = get_content_string!(content_up);
260
261      // Run the migration content
262      self.executor.migrate(&content_up, *it)?;
263      self.executor.add_completed_migration(*it)?;
264      pb.inc(1);
265    }
266    pb.finish();
267
268    Ok(())
269  }
270
271  /// Run down migrations
272  pub fn down(&mut self) -> AnyhowResult<()> {
273    // Ensure the migration state directory exists
274    ensure_migration_state_dir_exists()?;
275
276    // Get the completed migrations
277    let completed_migrations = self.executor.get_completed_migrations()?;
278    if completed_migrations.is_empty() {
279      println!("Migrations table is empty. No need to run down migrations.");
280      return Ok(());
281    }
282
283    // Create a new progress bar instance
284    let pb = ProgressBar::new(completed_migrations.len() as u64);
285    let tick_interval = Duration::from_millis(80);
286    pb.set_style(progress_style()?);
287    pb.enable_steady_tick(tick_interval);
288    let mut rng = rand::thread_rng();
289
290    // Iterate over the completed migrations
291    for it in completed_migrations.iter().rev() {
292      // Sleep for a random duration between 40 and 300 milliseconds
293      // to simulate a delay and make the progress bar more interesting
294      thread::sleep(Duration::from_millis(rng.gen_range(40..300)));
295      pb.set_prefix(format!("{it:013}"));
296
297      // Get the migration file
298      let migration = self.migrations.get(it).context("Migration file not found")?;
299      let filename_parts: Vec<&str> = migration.filename.splitn(2, '_').collect();
300      let migration_name = filename_parts
301        .get(1)
302        .and_then(|s| s.strip_suffix(".sql"))
303        .context("Migration name not found")?;
304
305      // Set the progress bar message
306      pb.set_message(format!("Undoing migration: {migration_name}"));
307
308      // Get the migration down content and convert it to a string
309      let content_down = migration
310        .content_down
311        .as_ref()
312        .context("Migration content not found")?;
313      let content_down = get_content_string!(content_down);
314
315      // Run the migration content down
316      self.executor.migrate(&content_down, *it)?;
317      if std::env::var("MIGRATIONS_SKIP_LAST").is_err() || !completed_migrations.first().eq(&Some(it)) {
318        self.executor.delete_completed_migration(it.to_owned())?;
319      }
320      pb.inc(1);
321    }
322    pb.finish();
323
324    Ok(())
325  }
326
327  /// Redo the last migration
328  /// This is equivalent to running down and then up
329  /// on the last completed migration
330  /// If there are no completed migrations, this will run the first migration
331  pub fn redo(&mut self) -> AnyhowResult<()> {
332    // Get the last completed migration
333    let current = self.executor.get_last_completed_migration()?;
334    let current = if current == -1 { 0 } else { current };
335
336    // Get the migration file
337    let migration = self
338      .migrations
339      .get(&current)
340      .context("Migration file not found")?;
341
342    // Get the migration name
343    let filename_parts: Vec<&str> = migration.filename.splitn(2, '_').collect();
344    let migration_name = filename_parts
345      .get(1)
346      .and_then(|s| s.strip_suffix(".sql"))
347      .context("Migration name not found")?;
348
349    // Create a new progress bar instance
350    let pb = ProgressBar::new(1u64);
351    let tick_interval = Duration::from_millis(80);
352    pb.set_style(progress_style()?);
353    pb.enable_steady_tick(tick_interval);
354    pb.set_prefix(format!("{current:013}"));
355    pb.tick();
356
357    // If the current migration is not 0, run down
358    if current != 0 {
359      pb.set_message(format!("Undoing migration: {migration_name}"));
360
361      // Get the migration down content and convert it to a string
362      let content_down = migration
363        .content_down
364        .as_ref()
365        .context("Migration content not found")?;
366      let content_down = get_content_string!(content_down);
367
368      // Run the migration down
369      self.executor.migrate(&content_down, current)?;
370      self.executor.delete_completed_migration(current)?;
371    }
372
373    log::trace!("Running the method `redo` {:?}", migration);
374
375    // Set the progress bar message
376    pb.set_message(format!("Applying migration: {migration_name}"));
377
378    // Get the migration up content and convert it to a string
379    let content_up = migration
380      .content_up
381      .as_ref()
382      .context("Migration content not found")?;
383    let content_up = get_content_string!(content_up);
384
385    // Run the migration up
386    self.executor.migrate(&content_up, current)?;
387    self.executor.add_completed_migration(current)?;
388
389    pb.inc(1);
390    pb.finish();
391    Ok(())
392  }
393
394  /// Revert the last migration
395  /// This is equivalent to running down on the last completed migration
396  /// If there are no completed migrations, this will do nothing
397  pub fn revert(&mut self) -> AnyhowResult<()> {
398    // Get the migrations count
399    let migrations_count = self.executor.count_migrations()?;
400
401    // Get the last completed migration
402    let current = self.executor.get_last_completed_migration()?;
403
404    // If there are no completed migrations, do nothing
405    if current == -1 {
406      println!("Migrations table is empty. No need to run revert migrations.");
407      return Ok(());
408    }
409
410    // Get the migration file
411    let migration = self
412      .migrations
413      .get(&current)
414      .context("Migration file not found")?;
415
416    // Get the migration name
417    let filename_parts: Vec<&str> = migration.filename.splitn(2, '_').collect();
418    let migration_name = filename_parts
419      .get(1)
420      .and_then(|s| s.strip_suffix(".sql"))
421      .context("Migration name not found")?;
422
423    // Create a new progress bar instance
424    let pb = ProgressBar::new(1u64);
425    let tick_interval = Duration::from_millis(80);
426    pb.set_style(progress_style()?);
427    pb.enable_steady_tick(tick_interval);
428    pb.set_prefix(format!("{current:013}"));
429    pb.tick();
430    pb.set_message(format!("Reverting migration: {migration_name}"));
431
432    // Get the migration down content and convert it to a string
433    let content_down = migration
434      .content_down
435      .as_ref()
436      .context("Migration content not found")?;
437    let content_down = get_content_string!(content_down);
438
439    // Run the migration down
440    self.executor.migrate(&content_down, current)?;
441
442    // Delete the last completed migration
443    if migrations_count > 1 || std::env::var("MIGRATIONS_SKIP_LAST").is_err() {
444      self.executor.delete_last_completed_migration()?;
445    }
446
447    pb.inc(1);
448    pb.finish();
449    Ok(())
450  }
451
452  /// Drop the database
453  /// This will drop the database specified in the database URL
454  /// The database URL should be in the format `dialect://user:password@host:port/database`
455  /// For example, `postgres://user:password@localhost:5432/database`
456  pub fn drop(&mut self, db_url: &str) -> AnyhowResult<()> {
457    let db_url = Url::parse(db_url).ok();
458
459    // If the database URL is not found, return an error
460    if let Some(db_url) = db_url {
461      let db_name = db_url.path().trim_start_matches('/');
462      self.executor.drop_database(db_name)?;
463    }
464    Ok(())
465  }
466}
467
#[cfg(test)]
mod tests {
  /// Placeholder test: it makes no assertions and only checks that the test
  /// module builds.
  #[test]
  fn test_create() {
    // Intentionally empty
  }
}