1pub use std::collections::HashMap;
2
3use include_dir::Dir;
4use sha2::{Digest, Sha256};
5use ydb::Client;
6use ydb::TableClient;
7
8use crate::{query, select, update};
9
10#[derive(Clone, Default)]
15#[allow(unused)]
16pub struct Migrator {
17 migrations: Vec<Migration>,
18}
19
/// Errors produced while running migrations with a `Migrator`.
///
/// Every variant carries a human-readable detail string: the underlying
/// database error text and/or the name of the offending migration.
#[derive(Debug)]
pub enum MigrationError {
    /// The `migrations` bookkeeping table could not be created.
    FailCreateMigrationTable(String),
    /// Reading the already-applied migrations failed.
    FailGetAppliedMigrations(String),
    /// Executing a migration script failed.
    FailApplyMigration(String),
    /// An already-applied migration's script no longer matches its recorded checksum.
    ChangeAppliedMigration(String),
    /// A migration script ran but could not be recorded as applied.
    FailMarkMigrationAsApplied(String),
}

impl std::fmt::Display for MigrationError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            MigrationError::FailCreateMigrationTable(detail) => {
                write!(f, "failed to create migrations table: {}", detail)
            }
            MigrationError::FailGetAppliedMigrations(detail) => {
                write!(f, "failed to read applied migrations: {}", detail)
            }
            MigrationError::FailApplyMigration(detail) => {
                write!(f, "failed to apply migration: {}", detail)
            }
            MigrationError::ChangeAppliedMigration(detail) => {
                write!(f, "already applied migration was changed: {}", detail)
            }
            MigrationError::FailMarkMigrationAsApplied(detail) => {
                write!(f, "failed to mark migration as applied: {}", detail)
            }
        }
    }
}

// Lets callers propagate `MigrationError` with `?` into `Box<dyn Error>`-style results.
impl std::error::Error for MigrationError {}
28
29impl Migrator {
30 #[allow(unused)]
31 pub fn new_from_dir(dir: &Dir) -> Self {
32 let mut migration = Self::default();
33 dir.files().for_each(|f| {
34 let file_name = f.path().file_name().unwrap().to_str().unwrap().to_string();
35 if file_name.ends_with(".sql") {
36 let content = f.contents_utf8().unwrap().to_owned();
37 migration.add_migration(file_name, content);
38 }
39 });
40 migration
41 }
42
43 #[allow(unused)]
44 pub fn add_migration(&mut self, name: String, sql: String) {
45 let migration = Migration::new(name, sql);
46 self.migrations.push(migration);
47 }
48
49 #[allow(unused)]
50 pub async fn migrate(&mut self, client: &Client) -> Result<(), MigrationError> {
51 let table_client = client.table_client();
52 Self::create_migrated_table(&table_client).await;
53 let applied_migrations = Self::get_applied_migrations(&table_client).await?;
54 self.migrations.sort_by_key(|m| m.name.clone());
55 self.verify_migrations(&applied_migrations)?;
56 self.apply_migrate(table_client, applied_migrations).await
57 }
58
59 async fn get_applied_migrations(
60 table_client: &TableClient,
61 ) -> Result<HashMap<String, AppliedMigration>, MigrationError> {
62 let result: Vec<(String, Vec<u8>)> =
63 select!(table_client, query!("select * from migrations"),
64 name=> String,
65 checksum=> ydb::Bytes)
66 .await
67 .map_err(|e| MigrationError::FailGetAppliedMigrations(format!("{}", e)))?;
68
69 Ok(result
70 .into_iter()
71 .map(|(name, checksum)| (name.clone(), AppliedMigration { name, checksum }))
72 .collect())
73 }
74 fn verify_migrations(
75 &self,
76 applied_migrations: &HashMap<String, AppliedMigration>,
77 ) -> Result<(), MigrationError> {
78 for migration in self.migrations.iter() {
79 if let Some(applied_migration) = applied_migrations.get(&migration.name) {
80 if applied_migration.checksum != migration.checksum {
81 return Err(MigrationError::ChangeAppliedMigration(format!(
82 "name {}",
83 migration.name
84 )));
85 }
86 }
87 }
88 Ok(())
89 }
90
91 async fn apply_migrate(
92 &mut self,
93 table_client: TableClient,
94 applied_migrations: HashMap<String, AppliedMigration>,
95 ) -> Result<(), MigrationError> {
96 for migration in &self.migrations {
97 if applied_migrations.contains_key(&migration.name) {
98 continue;
99 }
100 Self::apply_migration(&table_client, migration).await?;
101 Self::mark_migration_as_applied(&table_client, migration).await?;
102 }
103 Ok(())
104 }
105
106 async fn mark_migration_as_applied(
107 table_client: &TableClient,
108 migration: &Migration,
109 ) -> Result<(), MigrationError> {
110 update!(
111 table_client,
112 query!(
113 "insert into migrations (name, checksum) values($name, $checksum)",
114 name => migration.name,
115 checksum => migration.checksum
116 )
117 )
118 .await
119 .map_err(|e| MigrationError::FailMarkMigrationAsApplied(format!("{}", e)))
120 }
121
122 async fn apply_migration(
123 table_client: &TableClient,
124 migration: &Migration,
125 ) -> Result<(), MigrationError> {
126 table_client
127 .retry_execute_scheme_query(migration.sql.clone())
128 .await
129 .map_err(|e| MigrationError::FailApplyMigration(format!("{}", e)))
130 }
131
132 async fn create_migrated_table(table_client: &TableClient) -> Result<(), MigrationError> {
133 table_client
134 .retry_execute_scheme_query(
135 "create table migrations(name Utf8, checksum string,PRIMARY KEY(name));",
136 )
137 .await
138 .map_err(|e| MigrationError::FailCreateMigrationTable(format!("{}", e)))
139 }
140}
141
/// A row of the `migrations` table: a migration that has already been applied.
// `name` is kept for completeness but never read, hence the allow.
#[allow(unused)]
struct AppliedMigration {
    /// Checksum recorded when the migration was applied.
    checksum: Vec<u8>,
    /// Migration name (also the key of the map it is stored in).
    name: String,
}
147
/// A locally registered migration script.
#[derive(Clone)]
struct Migration {
    /// Name used for ordering and as the key in the `migrations` table.
    name: String,
    /// SHA-256 digest of `sql`, compared against the recorded checksum.
    checksum: Vec<u8>,
    /// SQL executed when the migration is applied.
    sql: String,
}
154
155impl Migration {
156 fn new(name: String, sql: String) -> Self {
157 let mut hasher = Sha256::default();
158 hasher.update(sql.clone());
159 let checksum = hasher.finalize().to_vec();
160 Self {
161 name,
162 sql,
163 checksum,
164 }
165 }
166}
167
#[cfg(test)]
mod tests {
    use include_dir::include_dir;
    use ydb::Query;

    use crate::migration::{MigrationError, Migrator};
    use crate::test_container::get_or_create_ydb_instance;

    // Migrations registered out of order must still run in name order:
    // 002 alters the table that 001 creates.
    #[tokio::test]
    async fn should_migration() {
        let (_node, mut client) = get_or_create_ydb_instance("should_migration").await;
        let mut migrator = Migrator::default();

        migrator.add_migration(
            "002.sql".to_owned(),
            "alter table a ADD COLUMN some_flag Bool;".to_owned(),
        );
        migrator.add_migration(
            "001.sql".to_owned(),
            "create table a(id int, PRIMARY KEY(id));".to_owned(),
        );
        migrator.migrate(&mut client).await.unwrap();
        client
            .table_client()
            .retry_transaction(|mut t| async move {
                t.query(Query::new(
                    "insert into a (id, some_flag) values (1, false)",
                ))
                .await?;
                Ok(())
            })
            .await
            .unwrap();
    }

    // A second run must skip the already-applied migration instead of
    // re-creating the table.
    #[tokio::test]
    async fn should_not_migration_if_migrated() {
        let (_node, mut client) =
            get_or_create_ydb_instance("should_not_migration_if_migrated").await;
        let mut migrator = Migrator::default();
        migrator.add_migration(
            "001.sql".to_owned(),
            "create table a(id int, PRIMARY KEY(id));".to_owned(),
        );
        migrator.migrate(&mut client).await.unwrap();
        migrator.migrate(&mut client).await.unwrap();
    }

    // Editing an already-applied script must be rejected specifically as a
    // checksum mismatch, not just as any error.
    #[tokio::test]
    async fn should_not_migration_when_already_migrated_script_changed() {
        let (_node, mut client) =
            get_or_create_ydb_instance("should_not_migration_when_already_migrated_script_changed")
                .await;
        {
            let mut migrator = Migrator::default();
            migrator.add_migration(
                "001.sql".to_owned(),
                "create table a(id int, PRIMARY KEY(id));".to_owned(),
            );
            migrator.migrate(&mut client).await.unwrap();
        }

        let mut migrator = Migrator::default();
        migrator.add_migration(
            "001.sql".to_owned(),
            "create table b(id int, PRIMARY KEY(id));".to_owned(),
        );
        let result = migrator.migrate(&mut client).await;
        assert!(
            matches!(result, Err(MigrationError::ChangeAppliedMigration(_))),
            "expected ChangeAppliedMigration, got {:?}",
            result
        );
    }

    // Scripts embedded from the `test-migration` directory create tables `a` and `b`.
    #[tokio::test]
    async fn should_migration_from_directory() {
        let (_node, mut client) =
            get_or_create_ydb_instance("should_migration_from_directory").await;
        let mut migrator =
            Migrator::new_from_dir(&include_dir!("$CARGO_MANIFEST_DIR/test-migration"));
        migrator.migrate(&mut client).await.unwrap();
        client
            .table_client()
            .retry_transaction(|mut t| async move {
                t.query(Query::new("insert into a (id) values (1)")).await?;
                t.query(Query::new("insert into b (id) values (1)")).await?;
                Ok(())
            })
            .await
            .unwrap();
    }
}