mysql_connector/migrator/migrator_inner.rs

1use {
2    super::{MigrationList, Version},
3    crate::{
4        error::Error, migrator::model::MigrationModel, pool::AsyncPoolTrait, types::Value,
5        Connection,
6    },
7    std::collections::HashMap,
8};
9
10pub struct Migrator<'a> {
11    pool: &'a dyn AsyncPoolTrait<Connection>,
12    migrations: &'a [MigrationList],
13    applied: HashMap<Version, Vec<String>>,
14}
15
16impl<'a> Migrator<'a> {
17    pub async fn new(
18        pool: &'a dyn AsyncPoolTrait<Connection>,
19        migrations: &'a [MigrationList],
20    ) -> Result<Self, Error> {
21        debug_assert!(MigrationList::ordered(migrations));
22
23        let mut conn = pool.get().await?;
24        let mut migrations_table = conn.query::<Vec<Value>>("select 1 from `information_schema`.`PARTITIONS` where `TABLE_NAME` = \"migrations\" and `TABLE_SCHEMA` = DATABASE()").await?;
25        if migrations_table.collect().await?.is_empty() {
26            conn.execute_query(
27                "create table `migrations` (
28                    `version_0` smallint unsigned not null,
29                    `version_1` smallint unsigned not null,
30                    `version_2` smallint unsigned not null,
31                    `name` varchar(255) not null,
32                    `applied_at` datetime not null default current_timestamp,
33                    unique (`version_0`, `version_1`, `version_2`, `name`)
34                )",
35            )
36            .await?;
37        }
38
39        let mut query = conn
40            .query::<MigrationModel>(
41                "select `version_0`, `version_1`, `version_2`, `name` from `migrations`",
42            )
43            .await?;
44        let mut applied: HashMap<Version, Vec<String>> = HashMap::new();
45        while let Some(row) = query.next().await? {
46            let mut found = false;
47            'outer: for migration_list in migrations {
48                if migration_list.version == row.version {
49                    for migration in migration_list.migrations {
50                        if migration.name() == row.name {
51                            found = true;
52                            break 'outer;
53                        }
54                    }
55                    break 'outer;
56                }
57            }
58            if !found {
59                panic!("unknown migration: {}: \"{}\"", row.version, row.name)
60            }
61            Self::insert_applied(&mut applied, row.version, row.name);
62        }
63        Ok(Self {
64            pool,
65            migrations,
66            applied,
67        })
68    }
69
70    pub async fn up(&mut self) -> Result<(), Error> {
71        self.up_to_version(None).await
72    }
73
74    fn insert_applied(applied: &mut HashMap<Version, Vec<String>>, version: Version, name: String) {
75        match applied.get_mut(&version) {
76            Some(list) => list.push(name),
77            None => {
78                applied.insert(version, vec![name]);
79            }
80        };
81    }
82
83    pub fn get_applied<'b>(
84        applied: &'b mut HashMap<Version, Vec<String>>,
85        version: &Version,
86        name: &str,
87    ) -> Option<(&'b mut Vec<String>, usize)> {
88        if let Some(migrations) = applied.get_mut(version) {
89            return migrations
90                .iter()
91                .position(|x| x == name)
92                .map(|pos| (migrations, pos));
93        }
94        None
95    }
96
97    pub async fn up_to_version(&mut self, version: Option<Version>) -> Result<(), Error> {
98        for migration_list in self.migrations {
99            match &version {
100                Some(version) if migration_list.version > *version => (),
101                _ => {
102                    for migration in migration_list.migrations {
103                        if Self::get_applied(
104                            &mut self.applied,
105                            &migration_list.version,
106                            migration.name(),
107                        )
108                        .is_none()
109                        {
110                            migration.up(self.pool).await?;
111                            Self::insert_applied(
112                                &mut self.applied,
113                                migration_list.version,
114                                migration.name().to_owned(),
115                            );
116                            self.pool.get().await?.execute_query(&format!("insert into `migrations` (`version_0`, `version_1`, `version_2`, `name`) values ({}, {}, {}, \"{}\")", migration_list.version.0, migration_list.version.1, migration_list.version.2, migration.name())).await?;
117                        }
118                    }
119                }
120            }
121        }
122        Ok(())
123    }
124
125    pub async fn down_to_version(&mut self, version: Version) -> Result<(), Error> {
126        for migration_list in self.migrations.iter().rev() {
127            if migration_list.version > version {
128                for migration in migration_list.migrations.iter().rev() {
129                    if let Some((applied, index)) = Self::get_applied(
130                        &mut self.applied,
131                        &migration_list.version,
132                        migration.name(),
133                    ) {
134                        migration.down(self.pool).await?;
135                        applied.swap_remove(index);
136                        self.pool.get().await?.execute_query(&format!("delete from `migrations` where `version_0` = {} and `version_1` = {} and `version_2` = {} and `name` = \"{}\"", migration_list.version.0, migration_list.version.1, migration_list.version.2, migration.name())).await?;
137                    }
138                }
139            }
140        }
141        Ok(())
142    }
143
144    pub async fn to_version(&mut self, version: Version) -> Result<(), Error> {
145        self.up_to_version(Some(version)).await?;
146        self.down_to_version(version).await
147    }
148
149    #[cfg(debug_assertions)]
150    pub async fn one_down(&mut self) -> Result<bool, Error> {
151        for migration_list in self.migrations.iter().rev() {
152            for migration in migration_list.migrations.iter().rev() {
153                if let Some((applied, index)) =
154                    Self::get_applied(&mut self.applied, &migration_list.version, migration.name())
155                {
156                    migration.down(self.pool).await?;
157                    applied.swap_remove(index);
158                    self.pool.get().await?.execute_query(&format!("delete from `migrations` where `version_0` = {} and `version_1` = {} and `version_2` = {} and `name` = \"{}\"", migration_list.version.0, migration_list.version.1, migration_list.version.2, migration.name())).await?;
159                    return Ok(true);
160                }
161            }
162        }
163        Ok(false)
164    }
165}