Skip to main content

d1_orm/
migration.rs

1use crate::types::DatabaseValue;
2use crate::{DatabaseExecutor, Error, MigrationInfo, MigrationMeta, Query};
3use std::borrow::Cow;
4
/// Tracking-table name used by [`migrate`] when the caller passes `None` for
/// `migration_table`.
const DEFAULT_MIGRATION_TABLE: &str = "_d1_migrations";
6
/// Represents a single database migration version.
///
/// A migration bundles a version number, a human-readable description (shown in the
/// progress log emitted by [`migrate`]), and the ordered list of steps (queries) that
/// make it up.
#[derive(Clone, Debug)]
pub struct Migration<Q> {
    /// Version number, recorded in the tracking table once the migration is applied.
    version: u32,
    /// Short description shown in migration progress logs.
    description: &'static str,
    /// Queries executed in order when the migration is applied.
    steps: Vec<Q>,
}

impl<Q> Migration<Q> {
    /// Creates a new migration with the given version, description, and steps.
    pub const fn new(version: u32, description: &'static str, steps: Vec<Q>) -> Self {
        Self {
            version,
            description,
            steps,
        }
    }

    /// Returns the version number of this migration.
    #[must_use]
    pub const fn version(&self) -> u32 {
        self.version
    }

    /// Returns the human-readable description of this migration.
    #[must_use]
    pub const fn description(&self) -> &'static str {
        self.description
    }

    /// Returns the steps of this migration, in execution order.
    #[must_use]
    pub fn steps(&self) -> &[Q] {
        &self.steps
    }
}
25
// --- Internal Queries ---
27
28struct CheckTableQuery<'a> {
29    table: &'a str,
30}
31
32impl<'a> Query for CheckTableQuery<'a> {
33    fn build(&self) -> Result<(Cow<'static, str>, Vec<DatabaseValue>), Error> {
34        Ok((
35            Cow::Borrowed("SELECT 1 FROM sqlite_master WHERE type='table' AND name = ?"),
36            vec![self.table.into()],
37        ))
38    }
39}
40
41struct CheckIndexQuery<'a> {
42    index: &'a str,
43}
44
45impl<'a> Query for CheckIndexQuery<'a> {
46    fn build(&self) -> Result<(Cow<'static, str>, Vec<DatabaseValue>), Error> {
47        Ok((
48            Cow::Borrowed("SELECT 1 FROM sqlite_master WHERE type='index' AND name = ?"),
49            vec![self.index.into()],
50        ))
51    }
52}
53
54struct CheckColumnQuery<'a> {
55    table: &'a str,
56    column: &'a str,
57}
58
59impl<'a> Query for CheckColumnQuery<'a> {
60    fn build(&self) -> Result<(Cow<'static, str>, Vec<DatabaseValue>), Error> {
61        Ok((
62            Cow::Borrowed("SELECT 1 FROM pragma_table_info(?) WHERE name = ?"),
63            vec![self.table.into(), self.column.into()],
64        ))
65    }
66}
67
68struct CreateMigrationTableQuery {
69    table_name: String,
70}
71
72impl Query for CreateMigrationTableQuery {
73    fn build(&self) -> Result<(Cow<'static, str>, Vec<DatabaseValue>), Error> {
74        let sql = format!(
75            "CREATE TABLE IF NOT EXISTS {} (version INTEGER PRIMARY KEY, applied_at INTEGER NOT NULL)",
76            self.table_name
77        );
78        Ok((Cow::Owned(sql), vec![]))
79    }
80}
81
82struct GetCurrentVersionQuery {
83    table_name: String,
84}
85
86impl Query for GetCurrentVersionQuery {
87    fn build(&self) -> Result<(Cow<'static, str>, Vec<DatabaseValue>), Error> {
88        let sql = format!("SELECT MAX(version) as ver FROM {}", self.table_name);
89        Ok((Cow::Owned(sql), vec![]))
90    }
91}
92
93struct InsertVersionQuery {
94    table_name: String,
95    version: u32,
96}
97
98impl Query for InsertVersionQuery {
99    fn build(&self) -> Result<(Cow<'static, str>, Vec<DatabaseValue>), Error> {
100        let sql = format!(
101            "INSERT INTO {} (version, applied_at) VALUES (?, strftime('%s', 'now'))",
102            self.table_name
103        );
104        Ok((Cow::Owned(sql), vec![self.version.into()]))
105    }
106}
107
108#[derive(serde::Deserialize)]
109struct VersionResult {
110    ver: Option<u32>,
111}
112
// --- Helpers ---
114
115async fn check_table_exists<D>(db: &D, table: &str) -> Result<bool, Error>
116where
117    D: DatabaseExecutor,
118{
119    let q = CheckTableQuery { table };
120    // We try to deserialize into a generic JSON Value because check queries return diverse shapes
121    // (scalar 1 or object { "1": 1 }) depending on the backend implementation detail.
122    // For D1/SqliteExecutor, query_first usually returns an Option<T>.
123    // If the query returns "SELECT 1 ...", D1 might return { "1": 1 } or just 1.
124    // Safest bet is to check if *any* row is returned.
125    let res: Option<serde_json::Value> = db.query_first(q).await?;
126    Ok(res.is_some())
127}
128
129async fn check_index_exists<D>(db: &D, index: &str) -> Result<bool, Error>
130where
131    D: DatabaseExecutor,
132{
133    let q = CheckIndexQuery { index };
134    let res: Option<serde_json::Value> = db.query_first(q).await?;
135    Ok(res.is_some())
136}
137
138async fn check_column_exists<D>(db: &D, table: &str, column: &str) -> Result<bool, Error>
139where
140    D: DatabaseExecutor,
141{
142    let q = CheckColumnQuery { table, column };
143    let res: Option<serde_json::Value> = db.query_first(q).await?;
144    Ok(res.is_some())
145}
146
147/// Executes database migrations.
148///
149/// # Arguments
150///
151/// * `db` - The database executor.
152/// * `migrations` - A list of migrations to apply.
153/// * `migration_table` - Optional custom name for the migration tracking table. Defaults to `_d1_migrations`.
154/// * `logger` - Optional callback for logging migration progress.
155pub async fn migrate<D, Q, I, F>(
156    db: &D,
157    migrations: I,
158    migration_table: Option<&str>,
159    logger: Option<F>,
160) -> Result<(), Error>
161where
162    D: DatabaseExecutor,
163    Q: Query + MigrationMeta + Clone, // Clone needed for iterating steps
164    I: IntoIterator<Item = Migration<Q>>,
165    F: Fn(&str),
166{
167    let table_name = migration_table
168        .unwrap_or(DEFAULT_MIGRATION_TABLE)
169        .to_string();
170
171    // 1. Ensure migration table exists
172    db.execute(CreateMigrationTableQuery {
173        table_name: table_name.clone(),
174    })
175    .await?;
176
177    // 2. Get current version
178    let version_result: Option<VersionResult> = db
179        .query_first(GetCurrentVersionQuery {
180            table_name: table_name.clone(),
181        })
182        .await?;
183    let current_ver = version_result.and_then(|r| r.ver).unwrap_or(0);
184
185    for migration in migrations {
186        if migration.version <= current_ver {
187            continue;
188        }
189
190        if let Some(log) = &logger {
191            log(&format!(
192                "Applying migration v{}: {}",
193                migration.version, migration.description
194            ));
195        }
196
197        for step in migration.steps {
198            let info = step.migration_info();
199            let should_execute = match info {
200                Some(MigrationInfo::Table(name)) => !check_table_exists(db, name).await?,
201                Some(MigrationInfo::Index(name)) => !check_index_exists(db, name).await?,
202                Some(MigrationInfo::Column { table, column }) => {
203                    !check_column_exists(db, table, column).await?
204                }
205                None => true, // Always execute if no metadata provided
206            };
207
208            if should_execute {
209                if let Some(log) = &logger {
210                    if let Some(info) = info {
211                        log(&format!("  -> Executing step for {:?}", info));
212                    } else {
213                        log("  -> Executing raw step");
214                    }
215                }
216                db.execute(step.clone()).await?;
217            } else if let Some(log) = &logger {
218                log(&format!("  -> Skipping step (already exists): {:?}", info));
219            }
220        }
221
222        // 3. Update version
223        db.execute(InsertVersionQuery {
224            table_name: table_name.clone(),
225            version: migration.version,
226        })
227        .await?;
228    }
229
230    Ok(())
231}