Skip to main content

adk_session/
migration.rs

1//! Lightweight, embedded migration runner for SQL-backed session services.
2//!
3//! This module provides shared types and free functions that track applied
4//! schema versions in a per-backend registry table and execute only unapplied
5//! forward-only migration steps.
6//!
//! The types ([`MigrationStep`], [`AppliedMigration`], [`MigrationError`]) are
//! always compiled. The SQL runner functions (`run_sql_migrations`,
//! `sql_schema_version`) are generated into the `sqlite_runner` and
//! `pg_runner` modules, which require the `sqlite` and `postgres` features
//! respectively.
10
11use chrono::{DateTime, Utc};
12
/// A single forward-only migration step.
///
/// The struct intentionally does not contain the SQL itself — each backend
/// defines its own step list as `&[(i64, &str, &str)]` tuples of
/// `(version, description, sql)`.
///
/// The type is a small `Copy` value; it also implements `PartialEq`, `Eq`
/// and `Hash` so steps can be compared, deduplicated and asserted on.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct MigrationStep {
    /// Monotonically increasing version number, starting at 1.
    pub version: i64,
    /// Human-readable description of what this step does.
    pub description: &'static str,
}
25
26/// Record of an applied migration stored in the registry table.
27#[derive(Debug, Clone)]
28pub struct AppliedMigration {
29    /// The applied version number.
30    pub version: i64,
31    /// Description recorded at apply time.
32    pub description: String,
33    /// UTC timestamp of application.
34    pub applied_at: DateTime<Utc>,
35}
36
/// Error context for a failed migration step.
///
/// All fields are plain owned data, so the error can be cloned and compared
/// (`Clone`, `PartialEq`, `Eq`), which keeps it easy to assert on.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct MigrationError {
    /// The version that failed.
    pub version: i64,
    /// Description of the failed step.
    pub description: String,
    /// Underlying cause (database error message, etc.), stored as text.
    pub cause: String,
}
47
48impl std::fmt::Display for MigrationError {
49    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
50        write!(f, "migration v{} ({}) failed: {}", self.version, self.description, self.cause)
51    }
52}
53
// Empty impl: every `std::error::Error` method keeps its default. The
// underlying cause is held as a `String` field, so no `source()` is exposed.
impl std::error::Error for MigrationError {}
55
56// ---------------------------------------------------------------------------
57// SQL runner — macro generates concrete implementations per database backend
58// ---------------------------------------------------------------------------
59
/// Generates `run_sql_migrations` and `sql_schema_version` for a concrete
/// sqlx pool type. Each SQL backend (`sqlite`, `postgres`) gets its own
/// monomorphised copy, avoiding complex generic trait bounds.
///
/// Macro arguments:
///
/// * `$mod_name` — name of the public module that receives the functions.
/// * `$pool_ty` — concrete pool type the functions operate on.
/// * `$int_type` — SQL type (as a string expression) used for the registry
///   table's `version` column.
#[cfg(any(feature = "sqlite", feature = "postgres"))]
macro_rules! impl_sql_migration_runner {
    ($mod_name:ident, $pool_ty:ty, $int_type:expr) => {
        pub mod $mod_name {
            use super::MigrationError;
            use chrono::Utc;
            use sqlx::Row;
            use std::future::Future;

            /// Wraps a step failure in a [`MigrationError`] and converts its
            /// rendered message into a session error.
            fn step_error(
                version: i64,
                description: &str,
                cause: impl std::fmt::Display,
            ) -> adk_core::AdkError {
                let err = MigrationError {
                    version,
                    description: description.to_string(),
                    cause: cause.to_string(),
                };
                adk_core::AdkError::session(err.to_string())
            }

            /// Builds the parameterised INSERT used to record a step in the
            /// registry. The table name is an identifier and cannot be
            /// bound, so it is interpolated; the values are bound as
            /// `$1..$3` by the caller.
            fn registry_insert_sql(registry_table: &str) -> String {
                format!(
                    "INSERT INTO {registry_table} \
                     (version, description, applied_at) \
                     VALUES ($1, $2, $3)"
                )
            }

            /// Run all pending migrations for a SQL backend.
            ///
            /// 1. Creates the registry table if it does not exist.
            /// 2. Reads the maximum applied version from the registry
            ///    (0 when the registry is empty).
            /// 3. If the registry is empty, calls `detect_existing`. When it
            ///    reports that schema tables already exist, the first entry
            ///    of `steps` is recorded as already applied (baseline)
            ///    without executing its SQL.
            /// 4. If the database version exceeds the highest version in
            ///    `steps`, returns a version-mismatch error.
            /// 5. Executes each unapplied step inside its own transaction
            ///    and records it in the registry in that same transaction.
            ///
            /// # Parameters
            ///
            /// * `pool` — connection pool for the backend.
            /// * `registry_table` — name of the registry table. It is
            ///   interpolated into SQL as an identifier, so it must be a
            ///   trusted, valid table name.
            /// * `steps` — `(version, description, sql)` tuples, processed in
            ///   slice order.
            /// * `detect_existing` — invoked at most once, only when the
            ///   registry is empty.
            ///
            /// # Errors
            ///
            /// Returns a session error if the registry cannot be created or
            /// read, if `detect_existing` fails, on version mismatch, or if
            /// any step fails (the message then carries the step's version
            /// and description).
            pub async fn run_sql_migrations<F, Fut>(
                pool: &$pool_ty,
                registry_table: &str,
                steps: &[(i64, &str, &str)],
                detect_existing: F,
            ) -> Result<(), adk_core::AdkError>
            where
                F: FnOnce() -> Fut,
                Fut: Future<Output = Result<bool, adk_core::AdkError>>,
            {
                // Step 1: Create registry table if missing
                let create_sql = format!(
                    "CREATE TABLE IF NOT EXISTS {registry_table} (\
                        version {} PRIMARY KEY, \
                        description TEXT NOT NULL, \
                        applied_at TEXT NOT NULL\
                    )",
                    $int_type
                );
                sqlx::query(&create_sql).execute(pool).await.map_err(|e| {
                    adk_core::AdkError::session(format!("migration registry creation failed: {e}"))
                })?;

                // Step 2: Read current max applied version
                let max_sql =
                    format!("SELECT COALESCE(MAX(version), 0) AS max_v FROM {registry_table}");
                let row = sqlx::query(&max_sql).fetch_one(pool).await.map_err(|e| {
                    adk_core::AdkError::session(format!("migration registry read failed: {e}"))
                })?;
                let mut max_applied: i64 = row.try_get("max_v").map_err(|e| {
                    adk_core::AdkError::session(format!("migration registry read failed: {e}"))
                })?;

                // Values are bound rather than spliced into the statement so
                // that a description containing a quote cannot break it.
                let insert_sql = registry_insert_sql(registry_table);

                // Step 3: Baseline detection — if registry is empty but
                // tables already exist, record the first step as applied.
                if max_applied == 0 && detect_existing().await? {
                    if let Some(&(v, desc, _)) = steps.first() {
                        sqlx::query(&insert_sql)
                            .bind(v)
                            .bind(desc)
                            .bind(Utc::now().to_rfc3339())
                            .execute(pool)
                            .await
                            .map_err(|e| step_error(v, desc, e))?;
                        max_applied = v;
                    }
                }

                // Step 4: Compiled-in max version. Taken as the maximum over
                // all steps so the check does not depend on slice ordering.
                let max_compiled = steps.iter().map(|step| step.0).max().unwrap_or(0);

                // Step 5: Version mismatch check
                if max_applied > max_compiled {
                    return Err(adk_core::AdkError::session(format!(
                        "schema version mismatch: database is at v{max_applied} \
                         but code only knows up to v{max_compiled}. \
                         Upgrade your ADK version."
                    )));
                }

                // Step 6: Execute unapplied steps in transactions
                for &(version, description, sql) in steps {
                    if version <= max_applied {
                        continue;
                    }

                    let mut tx = pool.begin().await.map_err(|e| {
                        step_error(version, description, format!("transaction begin failed: {e}"))
                    })?;

                    // Execute the migration SQL (raw_sql supports multiple
                    // semicolon-separated statements in a single call).
                    sqlx::raw_sql(sql)
                        .execute(&mut *tx)
                        .await
                        .map_err(|e| step_error(version, description, e))?;

                    // Record the step in the registry within the same
                    // transaction, so the schema change and its bookkeeping
                    // commit (or roll back) together.
                    sqlx::query(&insert_sql)
                        .bind(version)
                        .bind(description)
                        .bind(Utc::now().to_rfc3339())
                        .execute(&mut *tx)
                        .await
                        .map_err(|e| {
                            step_error(
                                version,
                                description,
                                format!("registry record failed: {e}"),
                            )
                        })?;

                    tx.commit().await.map_err(|e| {
                        step_error(version, description, format!("transaction commit failed: {e}"))
                    })?;
                }

                Ok(())
            }

            /// Returns the highest applied migration version, or 0 if no
            /// registry table exists or the registry is empty.
            ///
            /// This is a best-effort probe: any failure to run the query or
            /// to decode the result is reported as version 0 rather than as
            /// an error, so the returned `Result` is always `Ok`.
            pub async fn sql_schema_version(
                pool: &$pool_ty,
                registry_table: &str,
            ) -> Result<i64, adk_core::AdkError> {
                let sql =
                    format!("SELECT COALESCE(MAX(version), 0) AS max_v FROM {registry_table}");
                let version: i64 = match sqlx::query(&sql).fetch_one(pool).await {
                    Ok(row) => row.try_get("max_v").unwrap_or(0),
                    // Query failed (e.g. the registry table is missing).
                    Err(_) => 0,
                };
                Ok(version)
            }
        }
    };
}
237
// SQLite backend: expands to `pub mod sqlite_runner` operating on
// `sqlx::SqlitePool`, with `INTEGER` as the registry `version` column type.
#[cfg(feature = "sqlite")]
impl_sql_migration_runner!(sqlite_runner, sqlx::SqlitePool, "INTEGER");

// PostgreSQL backend: expands to `pub mod pg_runner` operating on
// `sqlx::PgPool`, with `BIGINT` as the registry `version` column type.
#[cfg(feature = "postgres")]
impl_sql_migration_runner!(pg_runner, sqlx::PgPool, "BIGINT");