Skip to main content

d1_orm_migration/
runner.rs

1use std::collections::HashMap;
2use worker::D1Database;
3use crate::{error::MigrationError, migration::Migration, tracker};
4
/// Snapshot of one registered migration's state, as reported by
/// `MigrationRunner::status`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct MigrationStatus {
    /// Identifier of the migration, copied from the registered `Migration`.
    pub id: String,
    /// Name of the migration, copied from the registered `Migration`.
    pub name: String,
    /// `true` when the tracker has an entry for this migration's id.
    pub applied: bool,
    /// Value the tracker stored alongside the id, or `None` if not applied.
    /// NOTE(review): presumably an application timestamp; its format is defined
    /// by the `tracker` module — confirm.
    pub applied_at: Option<String>,
}
11
12pub struct MigrationRunner<'db> {
13    db: &'db D1Database,
14    migrations: Vec<Migration>,
15}
16
17impl<'db> MigrationRunner<'db> {
18    pub fn new(db: &'db D1Database) -> Self {
19        Self { db, migrations: vec![] }
20    }
21
22    pub fn register(mut self, m: Migration) -> Self {
23        self.migrations.push(m);
24        self
25    }
26
27    pub async fn run(&self) -> Result<usize, MigrationError> {
28        tracker::ensure_table(self.db).await?;
29        let mut applied = 0;
30        for m in &self.migrations {
31            if !tracker::is_applied(self.db, m.id).await? {
32                self.db.prepare(m.up)
33                    .run().await
34                    .map_err(|e| MigrationError::Sql(format!("{}: {}", m.id, e)))?;
35                tracker::record(self.db, m.id, m.name).await?;
36                applied += 1;
37            }
38        }
39        Ok(applied)
40    }
41
42    pub async fn rollback(&self, steps: usize) -> Result<usize, MigrationError> {
43        tracker::ensure_table(self.db).await?;
44        let entries = tracker::applied_entries(self.db).await?;
45        let to_roll: Vec<String> = entries.into_iter().rev().take(steps).map(|(id, _)| id).collect();
46        let mut count = 0;
47        for id in &to_roll {
48            let m = self.migrations.iter().find(|m| m.id == id.as_str())
49                .ok_or_else(|| MigrationError::NotFound(id.clone()))?;
50            let down = m.down.ok_or_else(|| MigrationError::MissingDown(id.clone()))?;
51            self.db.prepare(down)
52                .run().await
53                .map_err(|e| MigrationError::Sql(format!("{}: {}", id, e)))?;
54            tracker::remove(self.db, id).await?;
55            count += 1;
56        }
57        Ok(count)
58    }
59
60    pub async fn status(&self) -> Result<Vec<MigrationStatus>, MigrationError> {
61        tracker::ensure_table(self.db).await?;
62        let entries = tracker::applied_entries(self.db).await?;
63        let map: HashMap<String, String> = entries.into_iter().collect();
64        Ok(self.migrations.iter().map(|m| {
65            let at = map.get(m.id).cloned();
66            MigrationStatus {
67                id: m.id.to_string(),
68                name: m.name.to_string(),
69                applied: at.is_some(),
70                applied_at: at,
71            }
72        }).collect())
73    }
74}