Skip to main content

systemprompt_database/lifecycle/
migrations.rs

1use crate::services::{DatabaseProvider, SqlExecutor};
2use std::collections::HashSet;
3use systemprompt_extension::{Extension, LoaderError, Migration};
4use tracing::{debug, info, warn};
5
/// One row of the `extension_migrations` bookkeeping table: a migration that
/// has already been applied for a given extension.
#[derive(Clone, Debug)]
pub struct AppliedMigration {
    /// Identifier of the extension the migration belongs to.
    pub extension_id: String,
    /// Migration version number as stored in the table.
    pub version: u32,
    /// Migration name recorded when it was applied.
    pub name: String,
    /// Checksum of the migration SQL recorded when it was applied; compared
    /// against the current checksum to detect edited migrations.
    pub checksum: String,
}
13
/// Counters summarising a single `MigrationService::run_pending_migrations` call.
#[derive(Clone, Copy, Debug, Default)]
pub struct MigrationResult {
    /// Number of migrations executed and recorded during the call.
    pub migrations_run: usize,
    /// Number of declared migrations skipped because their version was already applied.
    pub migrations_skipped: usize,
}
19
20pub struct MigrationService<'a> {
21    db: &'a dyn DatabaseProvider,
22}
23
24impl std::fmt::Debug for MigrationService<'_> {
25    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
26        f.debug_struct("MigrationService").finish_non_exhaustive()
27    }
28}
29
30impl<'a> MigrationService<'a> {
31    pub fn new(db: &'a dyn DatabaseProvider) -> Self {
32        Self { db }
33    }
34
35    async fn ensure_migrations_table_exists(&self) -> Result<(), LoaderError> {
36        let sql = include_str!("../../schema/extension_migrations.sql");
37        SqlExecutor::execute_statements_parsed(self.db, sql)
38            .await
39            .map_err(|e| LoaderError::MigrationFailed {
40                extension: "database".to_string(),
41                message: format!("Failed to ensure migrations table exists: {e}"),
42            })
43    }
44
45    pub async fn get_applied_migrations(
46        &self,
47        extension_id: &str,
48    ) -> Result<Vec<AppliedMigration>, LoaderError> {
49        let result = self
50            .db
51            .query_raw_with(
52                &"SELECT extension_id, version, name, checksum FROM extension_migrations WHERE \
53                  extension_id = $1 ORDER BY version",
54                vec![serde_json::Value::String(extension_id.to_string())],
55            )
56            .await
57            .map_err(|e| LoaderError::MigrationFailed {
58                extension: extension_id.to_string(),
59                message: format!("Failed to query applied migrations: {e}"),
60            })?;
61
62        let migrations = result
63            .rows
64            .iter()
65            .filter_map(|row| {
66                Some(AppliedMigration {
67                    extension_id: row.get("extension_id")?.as_str()?.to_string(),
68                    version: row.get("version")?.as_i64()? as u32,
69                    name: row.get("name")?.as_str()?.to_string(),
70                    checksum: row.get("checksum")?.as_str()?.to_string(),
71                })
72            })
73            .collect();
74
75        Ok(migrations)
76    }
77
78    pub async fn run_pending_migrations(
79        &self,
80        extension: &dyn Extension,
81    ) -> Result<MigrationResult, LoaderError> {
82        let ext_id = extension.metadata().id;
83        let migrations = extension.migrations();
84
85        if migrations.is_empty() {
86            return Ok(MigrationResult::default());
87        }
88
89        self.ensure_migrations_table_exists().await?;
90
91        let applied = self.get_applied_migrations(ext_id).await?;
92        let applied_versions: HashSet<u32> = applied.iter().map(|m| m.version).collect();
93        let applied_checksums: std::collections::HashMap<u32, &str> = applied
94            .iter()
95            .map(|m| (m.version, m.checksum.as_str()))
96            .collect();
97
98        let mut migrations_run = 0;
99        let mut migrations_skipped = 0;
100
101        for migration in &migrations {
102            if applied_versions.contains(&migration.version) {
103                let current_checksum = migration.checksum();
104                if let Some(&stored_checksum) = applied_checksums.get(&migration.version) {
105                    if stored_checksum != current_checksum {
106                        warn!(
107                            extension = %ext_id,
108                            version = migration.version,
109                            name = %migration.name,
110                            stored_checksum = %stored_checksum,
111                            current_checksum = %current_checksum,
112                            "Migration checksum mismatch - SQL has changed since it was applied"
113                        );
114                    }
115                }
116                migrations_skipped += 1;
117                debug!(
118                    extension = %ext_id,
119                    version = migration.version,
120                    "Migration already applied, skipping"
121                );
122                continue;
123            }
124
125            self.execute_migration(ext_id, migration).await?;
126            migrations_run += 1;
127        }
128
129        if migrations_run > 0 {
130            info!(
131                extension = %ext_id,
132                migrations_run,
133                migrations_skipped,
134                "Migrations completed"
135            );
136        }
137
138        Ok(MigrationResult {
139            migrations_run,
140            migrations_skipped,
141        })
142    }
143
144    async fn execute_migration(
145        &self,
146        ext_id: &str,
147        migration: &Migration,
148    ) -> Result<(), LoaderError> {
149        info!(
150            extension = %ext_id,
151            version = migration.version,
152            name = %migration.name,
153            "Running migration"
154        );
155
156        SqlExecutor::execute_statements_parsed(self.db, migration.sql)
157            .await
158            .map_err(|e| LoaderError::MigrationFailed {
159                extension: ext_id.to_string(),
160                message: format!(
161                    "Failed to execute migration {} ({}): {e}",
162                    migration.version, migration.name
163                ),
164            })?;
165
166        self.record_migration(ext_id, migration).await?;
167
168        Ok(())
169    }
170
171    async fn record_migration(
172        &self,
173        ext_id: &str,
174        migration: &Migration,
175    ) -> Result<(), LoaderError> {
176        let id = format!("{}_{:03}", ext_id, migration.version);
177        let checksum = migration.checksum();
178        let name = migration.name.replace('\'', "''");
179
180        let sql = format!(
181            "INSERT INTO extension_migrations (id, extension_id, version, name, checksum) VALUES \
182             ('{}', '{}', {}, '{}', '{}')",
183            id, ext_id, migration.version, name, checksum
184        );
185
186        self.db
187            .execute_raw(&sql)
188            .await
189            .map_err(|e| LoaderError::MigrationFailed {
190                extension: ext_id.to_string(),
191                message: format!("Failed to record migration: {e}"),
192            })?;
193
194        Ok(())
195    }
196
197    pub async fn get_migration_status(
198        &self,
199        extension: &dyn Extension,
200    ) -> Result<MigrationStatus, LoaderError> {
201        self.ensure_migrations_table_exists().await?;
202
203        let ext_id = extension.metadata().id;
204        let defined_migrations = extension.migrations();
205        let applied = self.get_applied_migrations(ext_id).await?;
206
207        let applied_versions: HashSet<u32> = applied.iter().map(|m| m.version).collect();
208
209        let pending: Vec<_> = defined_migrations
210            .iter()
211            .filter(|m| !applied_versions.contains(&m.version))
212            .cloned()
213            .collect();
214
215        Ok(MigrationStatus {
216            extension_id: ext_id.to_string(),
217            total_defined: defined_migrations.len(),
218            total_applied: applied.len(),
219            pending_count: pending.len(),
220            pending,
221            applied,
222        })
223    }
224}
225
226#[derive(Debug)]
227pub struct MigrationStatus {
228    pub extension_id: String,
229    pub total_defined: usize,
230    pub total_applied: usize,
231    pub pending_count: usize,
232    pub pending: Vec<Migration>,
233    pub applied: Vec<AppliedMigration>,
234}