1use crate::types::DatabaseValue;
2use crate::{DatabaseExecutor, Error, MigrationInfo, MigrationMeta, Query};
3use std::borrow::Cow;
4
/// Name of the migration bookkeeping table that `migrate` uses when the caller passes `None`
/// for `migration_table`.
const DEFAULT_MIGRATION_TABLE: &str = "_d1_migrations";
6
/// One versioned unit of schema change: a version number, a human-readable description, and
/// the ordered query steps that make it up.
///
/// Instances are consumed by `migrate`, which compares `version` against the versions already
/// recorded in the migration table to decide whether to run `steps`.
///
/// `Debug` is derived (bounded on `Q: Debug`) so callers can inspect or log migration lists.
#[derive(Clone, Debug)]
pub struct Migration<Q> {
    /// Version number recorded in the migration table once all steps have been run.
    version: u32,
    /// Human-readable label; `migrate` includes it in its log output.
    description: &'static str,
    /// Queries executed in order. `migrate` may skip a step whose metadata reports that the
    /// object it creates already exists.
    steps: Vec<Q>,
}

impl<Q> Migration<Q> {
    /// Build a migration from its version, description and ordered steps.
    ///
    /// `const` so migrations can be constructed in constant contexts where the `Vec` allows it.
    pub const fn new(version: u32, description: &'static str, steps: Vec<Q>) -> Self {
        Self {
            version,
            description,
            steps,
        }
    }
}
25
26struct CheckTableQuery<'a> {
29 table: &'a str,
30}
31
32impl<'a> Query for CheckTableQuery<'a> {
33 fn build(&self) -> Result<(Cow<'static, str>, Vec<DatabaseValue>), Error> {
34 Ok((
35 Cow::Borrowed("SELECT 1 FROM sqlite_master WHERE type='table' AND name = ?"),
36 vec![self.table.into()],
37 ))
38 }
39}
40
41struct CheckIndexQuery<'a> {
42 index: &'a str,
43}
44
45impl<'a> Query for CheckIndexQuery<'a> {
46 fn build(&self) -> Result<(Cow<'static, str>, Vec<DatabaseValue>), Error> {
47 Ok((
48 Cow::Borrowed("SELECT 1 FROM sqlite_master WHERE type='index' AND name = ?"),
49 vec![self.index.into()],
50 ))
51 }
52}
53
54struct CheckColumnQuery<'a> {
55 table: &'a str,
56 column: &'a str,
57}
58
59impl<'a> Query for CheckColumnQuery<'a> {
60 fn build(&self) -> Result<(Cow<'static, str>, Vec<DatabaseValue>), Error> {
61 Ok((
62 Cow::Borrowed("SELECT 1 FROM pragma_table_info(?) WHERE name = ?"),
63 vec![self.table.into(), self.column.into()],
64 ))
65 }
66}
67
68struct CreateMigrationTableQuery {
69 table_name: String,
70}
71
72impl Query for CreateMigrationTableQuery {
73 fn build(&self) -> Result<(Cow<'static, str>, Vec<DatabaseValue>), Error> {
74 let sql = format!(
75 "CREATE TABLE IF NOT EXISTS {} (version INTEGER PRIMARY KEY, applied_at INTEGER NOT NULL)",
76 self.table_name
77 );
78 Ok((Cow::Owned(sql), vec![]))
79 }
80}
81
82struct GetCurrentVersionQuery {
83 table_name: String,
84}
85
86impl Query for GetCurrentVersionQuery {
87 fn build(&self) -> Result<(Cow<'static, str>, Vec<DatabaseValue>), Error> {
88 let sql = format!("SELECT MAX(version) as ver FROM {}", self.table_name);
89 Ok((Cow::Owned(sql), vec![]))
90 }
91}
92
93struct InsertVersionQuery {
94 table_name: String,
95 version: u32,
96}
97
98impl Query for InsertVersionQuery {
99 fn build(&self) -> Result<(Cow<'static, str>, Vec<DatabaseValue>), Error> {
100 let sql = format!(
101 "INSERT INTO {} (version, applied_at) VALUES (?, strftime('%s', 'now'))",
102 self.table_name
103 );
104 Ok((Cow::Owned(sql), vec![self.version.into()]))
105 }
106}
107
/// Row shape read back from `GetCurrentVersionQuery`.
#[derive(serde::Deserialize)]
struct VersionResult {
    /// `MAX(version)` aliased as `ver`; `None` when the SQL value is NULL (e.g. no rows yet).
    ver: Option<u32>,
}
111}
112
113async fn check_table_exists<D>(db: &D, table: &str) -> Result<bool, Error>
116where
117 D: DatabaseExecutor,
118{
119 let q = CheckTableQuery { table };
120 let res: Option<serde_json::Value> = db.query_first(q).await?;
126 Ok(res.is_some())
127}
128
129async fn check_index_exists<D>(db: &D, index: &str) -> Result<bool, Error>
130where
131 D: DatabaseExecutor,
132{
133 let q = CheckIndexQuery { index };
134 let res: Option<serde_json::Value> = db.query_first(q).await?;
135 Ok(res.is_some())
136}
137
138async fn check_column_exists<D>(db: &D, table: &str, column: &str) -> Result<bool, Error>
139where
140 D: DatabaseExecutor,
141{
142 let q = CheckColumnQuery { table, column };
143 let res: Option<serde_json::Value> = db.query_first(q).await?;
144 Ok(res.is_some())
145}
146
147pub async fn migrate<D, Q, I, F>(
156 db: &D,
157 migrations: I,
158 migration_table: Option<&str>,
159 logger: Option<F>,
160) -> Result<(), Error>
161where
162 D: DatabaseExecutor,
163 Q: Query + MigrationMeta + Clone, I: IntoIterator<Item = Migration<Q>>,
165 F: Fn(&str),
166{
167 let table_name = migration_table
168 .unwrap_or(DEFAULT_MIGRATION_TABLE)
169 .to_string();
170
171 db.execute(CreateMigrationTableQuery {
173 table_name: table_name.clone(),
174 })
175 .await?;
176
177 let version_result: Option<VersionResult> = db
179 .query_first(GetCurrentVersionQuery {
180 table_name: table_name.clone(),
181 })
182 .await?;
183 let current_ver = version_result.and_then(|r| r.ver).unwrap_or(0);
184
185 for migration in migrations {
186 if migration.version <= current_ver {
187 continue;
188 }
189
190 if let Some(log) = &logger {
191 log(&format!(
192 "Applying migration v{}: {}",
193 migration.version, migration.description
194 ));
195 }
196
197 for step in migration.steps {
198 let info = step.migration_info();
199 let should_execute = match info {
200 Some(MigrationInfo::Table(name)) => !check_table_exists(db, name).await?,
201 Some(MigrationInfo::Index(name)) => !check_index_exists(db, name).await?,
202 Some(MigrationInfo::Column { table, column }) => {
203 !check_column_exists(db, table, column).await?
204 }
205 None => true, };
207
208 if should_execute {
209 if let Some(log) = &logger {
210 if let Some(info) = info {
211 log(&format!(" -> Executing step for {:?}", info));
212 } else {
213 log(" -> Executing raw step");
214 }
215 }
216 db.execute(step.clone()).await?;
217 } else if let Some(log) = &logger {
218 log(&format!(" -> Skipping step (already exists): {:?}", info));
219 }
220 }
221
222 db.execute(InsertVersionQuery {
224 table_name: table_name.clone(),
225 version: migration.version,
226 })
227 .await?;
228 }
229
230 Ok(())
231}