Skip to main content

dbkit_rs/
initialization.rs

1use crate::base_handler::{BaseHandler, FetchMode, WriteOp};
2use crate::DbkitError;
3use deadpool_postgres::Pool;
4use std::collections::hash_map::DefaultHasher;
5use std::hash::{Hash, Hasher};
6use std::sync::Arc;
7use tracing::{error, info};
8
9/// Batch DDL migration executor with tracking.
10///
11/// Maintains a `_dbkit_migrations` table so already-applied migrations
12/// are skipped on subsequent runs. Each migration is identified by a
13/// user-provided name and a content hash.
14pub struct InitializationHandler {
15    handler: BaseHandler,
16}
17
18impl InitializationHandler {
19    pub fn new(pool: Arc<Pool>) -> Self {
20        Self {
21            handler: BaseHandler::new(pool),
22        }
23    }
24
25    /// Ensure the migrations tracking table exists.
26    async fn ensure_tracking_table(&self) -> Result<(), DbkitError> {
27        self.handler
28            .execute_write(WriteOp::BatchDDL {
29                queries: &[
30                    "CREATE TABLE IF NOT EXISTS _dbkit_migrations (
31                        id SERIAL PRIMARY KEY,
32                        name TEXT NOT NULL UNIQUE,
33                        hash TEXT NOT NULL,
34                        applied_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
35                    )",
36                ],
37            })
38            .await?;
39        Ok(())
40    }
41
42    /// Compute a hash of the SQL content for change detection.
43    fn hash_sql(sql: &str) -> String {
44        let mut hasher = DefaultHasher::new();
45        sql.hash(&mut hasher);
46        format!("{:016x}", hasher.finish())
47    }
48
49    /// Run a named migration. Skips if already applied with the same content hash.
50    ///
51    /// If the migration name exists but the hash differs, it returns an error
52    /// (content changed after being applied).
53    pub async fn run_named_migration(&self, name: &str, sql: &str) -> Result<(), DbkitError> {
54        self.ensure_tracking_table().await?;
55
56        let hash = Self::hash_sql(sql);
57
58        // Check if already applied
59        let result = self
60            .handler
61            .execute_write(WriteOp::Single {
62                query: "SELECT hash FROM _dbkit_migrations WHERE name = $1",
63                params: &[&name],
64                mode: FetchMode::Optional,
65            })
66            .await?;
67
68        if let Some(row) = result.optional()? {
69            let existing_hash: String = row.get(0);
70            if existing_hash == hash {
71                info!("migration '{}' already applied, skipping", name);
72                return Ok(());
73            } else {
74                return Err(DbkitError::Migration(format!(
75                    "migration '{}' was already applied but content has changed (hash {} → {})",
76                    name, existing_hash, hash
77                )));
78            }
79        }
80
81        // Run the migration
82        info!("applying migration '{}'...", name);
83        let queries: Vec<String> = sql
84            .split(';')
85            .map(|s| s.trim().to_string())
86            .filter(|s| !s.is_empty())
87            .collect();
88
89        let query_refs: Vec<&str> = queries.iter().map(|s| s.as_str()).collect();
90
91        match self
92            .handler
93            .execute_write(WriteOp::BatchDDL {
94                queries: &query_refs,
95            })
96            .await
97        {
98            Ok(_) => {
99                info!(
100                    "migration '{}': {} DDL statements executed",
101                    name,
102                    query_refs.len()
103                );
104            }
105            Err(e) => {
106                error!("migration '{}' failed: {:?}", name, e);
107                return Err(DbkitError::Migration(e.to_string()));
108            }
109        }
110
111        // Record the migration
112        self.handler
113            .execute_write(WriteOp::Single {
114                query: "INSERT INTO _dbkit_migrations (name, hash) VALUES ($1, $2)",
115                params: &[&name, &hash],
116                mode: FetchMode::None,
117            })
118            .await?;
119
120        info!("migration '{}' recorded", name);
121        Ok(())
122    }
123
124    /// Run migrations from a SQL string (semicolon-separated DDL statements).
125    ///
126    /// This is the simple/legacy API — it runs all statements unconditionally
127    /// without tracking. Use [`run_named_migration`] for tracked migrations.
128    pub async fn run_migrations(&self, sql: &str) -> Result<(), DbkitError> {
129        info!("running database migrations...");
130
131        let queries: Vec<String> = sql
132            .split(';')
133            .map(|s| s.trim().to_string())
134            .filter(|s| !s.is_empty())
135            .collect();
136
137        let query_refs: Vec<&str> = queries.iter().map(|s| s.as_str()).collect();
138
139        match self
140            .handler
141            .execute_write(WriteOp::BatchDDL {
142                queries: &query_refs,
143            })
144            .await
145        {
146            Ok(_) => {
147                info!("{} DDL statements executed", query_refs.len());
148            }
149            Err(e) => {
150                error!("migration failed: {:?}", e);
151                return Err(DbkitError::Migration(e.to_string()));
152            }
153        }
154
155        Ok(())
156    }
157
158    /// List all applied migrations (name, hash, applied_at).
159    pub async fn applied_migrations(&self) -> Result<Vec<(String, String, String)>, DbkitError> {
160        self.ensure_tracking_table().await?;
161
162        let result = self
163            .handler
164            .execute_write(WriteOp::Single {
165                query: "SELECT name, hash, applied_at::TEXT FROM _dbkit_migrations ORDER BY id",
166                params: &[],
167                mode: FetchMode::All,
168            })
169            .await?;
170
171        let rows = result.all()?;
172        Ok(rows
173            .iter()
174            .map(|row| {
175                let name: String = row.get(0);
176                let hash: String = row.get(1);
177                let applied_at: String = row.get(2);
178                (name, hash, applied_at)
179            })
180            .collect())
181    }
182}