systemprompt_database/lifecycle/
migrations.rs1use crate::services::{DatabaseProvider, SqlExecutor};
5use std::collections::HashSet;
6use systemprompt_extension::{Extension, LoaderError, Migration};
7use tracing::{debug, info, warn};
8
/// One row of the `extension_migrations` table: a migration that has already
/// been recorded as applied for an extension.
#[derive(Clone, Debug)]
pub struct AppliedMigration {
    /// Identifier of the extension the migration belongs to.
    pub extension_id: String,
    /// Version number the migration was recorded under.
    pub version: u32,
    /// Name stored alongside the migration when it was recorded.
    pub name: String,
    /// Checksum stored at record time; compared against the current
    /// migration's checksum to detect SQL that changed after being applied.
    pub checksum: String,
}
16
/// Counters describing the outcome of one migration pass over an extension.
///
/// The `Default` value (both counters zero) is what is returned when an
/// extension defines no migrations.
#[derive(Clone, Copy, Debug, Default)]
pub struct MigrationResult {
    /// Number of migrations that were executed during the pass.
    pub migrations_run: usize,
    /// Number of migrations skipped because they were already recorded.
    pub migrations_skipped: usize,
}
22
23pub struct MigrationService<'a> {
24 db: &'a dyn DatabaseProvider,
25}
26
27impl std::fmt::Debug for MigrationService<'_> {
28 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
29 f.debug_struct("MigrationService").finish_non_exhaustive()
30 }
31}
32
33impl<'a> MigrationService<'a> {
34 pub fn new(db: &'a dyn DatabaseProvider) -> Self {
35 Self { db }
36 }
37
38 async fn ensure_migrations_table_exists(&self) -> Result<(), LoaderError> {
39 let sql = include_str!("../../schema/extension_migrations.sql");
40 SqlExecutor::execute_statements_parsed(self.db, sql)
41 .await
42 .map_err(|e| LoaderError::MigrationFailed {
43 extension: "database".to_string(),
44 message: format!("Failed to ensure migrations table exists: {e}"),
45 })
46 }
47
48 pub async fn get_applied_migrations(
49 &self,
50 extension_id: &str,
51 ) -> Result<Vec<AppliedMigration>, LoaderError> {
52 let result = self
53 .db
54 .query_raw_with(
55 &"SELECT extension_id, version, name, checksum FROM extension_migrations WHERE \
56 extension_id = $1 ORDER BY version",
57 vec![serde_json::Value::String(extension_id.to_string())],
58 )
59 .await
60 .map_err(|e| LoaderError::MigrationFailed {
61 extension: extension_id.to_string(),
62 message: format!("Failed to query applied migrations: {e}"),
63 })?;
64
65 let migrations = result
66 .rows
67 .iter()
68 .filter_map(|row| {
69 Some(AppliedMigration {
70 extension_id: row.get("extension_id")?.as_str()?.to_string(),
71 version: row.get("version")?.as_i64()? as u32,
72 name: row.get("name")?.as_str()?.to_string(),
73 checksum: row.get("checksum")?.as_str()?.to_string(),
74 })
75 })
76 .collect();
77
78 Ok(migrations)
79 }
80
81 pub async fn run_pending_migrations(
82 &self,
83 extension: &dyn Extension,
84 ) -> Result<MigrationResult, LoaderError> {
85 let ext_id = extension.metadata().id;
86 let migrations = extension.migrations();
87
88 if migrations.is_empty() {
89 return Ok(MigrationResult::default());
90 }
91
92 self.ensure_migrations_table_exists().await?;
93
94 let applied = self.get_applied_migrations(ext_id).await?;
95 let applied_versions: HashSet<u32> = applied.iter().map(|m| m.version).collect();
96 let applied_checksums: std::collections::HashMap<u32, &str> = applied
97 .iter()
98 .map(|m| (m.version, m.checksum.as_str()))
99 .collect();
100
101 let mut migrations_run = 0;
102 let mut migrations_skipped = 0;
103
104 for migration in &migrations {
105 if applied_versions.contains(&migration.version) {
106 let current_checksum = migration.checksum();
107 if let Some(&stored_checksum) = applied_checksums.get(&migration.version) {
108 if stored_checksum != current_checksum {
109 warn!(
110 extension = %ext_id,
111 version = migration.version,
112 name = %migration.name,
113 stored_checksum = %stored_checksum,
114 current_checksum = %current_checksum,
115 "Migration checksum mismatch - SQL has changed since it was applied"
116 );
117 }
118 }
119 migrations_skipped += 1;
120 debug!(
121 extension = %ext_id,
122 version = migration.version,
123 "Migration already applied, skipping"
124 );
125 continue;
126 }
127
128 self.execute_migration(ext_id, migration).await?;
129 migrations_run += 1;
130 }
131
132 if migrations_run > 0 {
133 info!(
134 extension = %ext_id,
135 migrations_run,
136 migrations_skipped,
137 "Migrations completed"
138 );
139 }
140
141 Ok(MigrationResult {
142 migrations_run,
143 migrations_skipped,
144 })
145 }
146
147 async fn execute_migration(
148 &self,
149 ext_id: &str,
150 migration: &Migration,
151 ) -> Result<(), LoaderError> {
152 info!(
153 extension = %ext_id,
154 version = migration.version,
155 name = %migration.name,
156 "Running migration"
157 );
158
159 SqlExecutor::execute_statements_parsed(self.db, migration.sql)
160 .await
161 .map_err(|e| LoaderError::MigrationFailed {
162 extension: ext_id.to_string(),
163 message: format!(
164 "Failed to execute migration {} ({}): {e}",
165 migration.version, migration.name
166 ),
167 })?;
168
169 self.record_migration(ext_id, migration).await?;
170
171 Ok(())
172 }
173
174 async fn record_migration(
175 &self,
176 ext_id: &str,
177 migration: &Migration,
178 ) -> Result<(), LoaderError> {
179 let id = format!("{}_{:03}", ext_id, migration.version);
180 let checksum = migration.checksum();
181 let name = migration.name.replace('\'', "''");
182
183 let sql = format!(
184 "INSERT INTO extension_migrations (id, extension_id, version, name, checksum) VALUES \
185 ('{}', '{}', {}, '{}', '{}')",
186 id, ext_id, migration.version, name, checksum
187 );
188
189 self.db
190 .execute_raw(&sql)
191 .await
192 .map_err(|e| LoaderError::MigrationFailed {
193 extension: ext_id.to_string(),
194 message: format!("Failed to record migration: {e}"),
195 })?;
196
197 Ok(())
198 }
199
200 pub async fn get_migration_status(
201 &self,
202 extension: &dyn Extension,
203 ) -> Result<MigrationStatus, LoaderError> {
204 self.ensure_migrations_table_exists().await?;
205
206 let ext_id = extension.metadata().id;
207 let defined_migrations = extension.migrations();
208 let applied = self.get_applied_migrations(ext_id).await?;
209
210 let applied_versions: HashSet<u32> = applied.iter().map(|m| m.version).collect();
211
212 let pending: Vec<_> = defined_migrations
213 .iter()
214 .filter(|m| !applied_versions.contains(&m.version))
215 .cloned()
216 .collect();
217
218 Ok(MigrationStatus {
219 extension_id: ext_id.to_string(),
220 total_defined: defined_migrations.len(),
221 total_applied: applied.len(),
222 pending_count: pending.len(),
223 pending,
224 applied,
225 })
226 }
227}
228
229#[derive(Debug)]
230pub struct MigrationStatus {
231 pub extension_id: String,
232 pub total_defined: usize,
233 pub total_applied: usize,
234 pub pending_count: usize,
235 pub pending: Vec<Migration>,
236 pub applied: Vec<AppliedMigration>,
237}