systemprompt_database/lifecycle/
migrations.rs1use crate::services::{DatabaseProvider, SqlExecutor};
2use std::collections::HashSet;
3use systemprompt_extension::{Extension, LoaderError, Migration};
4use tracing::{debug, info, warn};
5
/// One row of the `extension_migrations` bookkeeping table, as read back by
/// `MigrationService::get_applied_migrations`.
#[derive(Clone, Debug)]
pub struct AppliedMigration {
    /// Identifier of the extension that owns this migration.
    pub extension_id: String,
    /// Version number of the migration within its extension.
    pub version: u32,
    /// Human-readable migration name.
    pub name: String,
    /// Checksum of the migration SQL, stored when the migration was recorded.
    pub checksum: String,
}
13
/// Counters reported by `MigrationService::run_pending_migrations` for a
/// single extension. `Default` yields zero for both counters.
#[derive(Clone, Copy, Debug, Default)]
pub struct MigrationResult {
    /// Migrations executed (and recorded) during the call.
    pub migrations_run: usize,
    /// Migrations skipped because their version was already recorded.
    pub migrations_skipped: usize,
}
19
20pub struct MigrationService<'a> {
21 db: &'a dyn DatabaseProvider,
22}
23
24impl std::fmt::Debug for MigrationService<'_> {
25 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
26 f.debug_struct("MigrationService").finish_non_exhaustive()
27 }
28}
29
30impl<'a> MigrationService<'a> {
31 pub fn new(db: &'a dyn DatabaseProvider) -> Self {
32 Self { db }
33 }
34
35 async fn ensure_migrations_table_exists(&self) -> Result<(), LoaderError> {
36 let sql = include_str!("../../schema/extension_migrations.sql");
37 SqlExecutor::execute_statements_parsed(self.db, sql)
38 .await
39 .map_err(|e| LoaderError::MigrationFailed {
40 extension: "database".to_string(),
41 message: format!("Failed to ensure migrations table exists: {e}"),
42 })
43 }
44
45 pub async fn get_applied_migrations(
46 &self,
47 extension_id: &str,
48 ) -> Result<Vec<AppliedMigration>, LoaderError> {
49 let result = self
50 .db
51 .query_raw_with(
52 &"SELECT extension_id, version, name, checksum FROM extension_migrations WHERE \
53 extension_id = $1 ORDER BY version",
54 vec![serde_json::Value::String(extension_id.to_string())],
55 )
56 .await
57 .map_err(|e| LoaderError::MigrationFailed {
58 extension: extension_id.to_string(),
59 message: format!("Failed to query applied migrations: {e}"),
60 })?;
61
62 let migrations = result
63 .rows
64 .iter()
65 .filter_map(|row| {
66 Some(AppliedMigration {
67 extension_id: row.get("extension_id")?.as_str()?.to_string(),
68 version: row.get("version")?.as_i64()? as u32,
69 name: row.get("name")?.as_str()?.to_string(),
70 checksum: row.get("checksum")?.as_str()?.to_string(),
71 })
72 })
73 .collect();
74
75 Ok(migrations)
76 }
77
78 pub async fn run_pending_migrations(
79 &self,
80 extension: &dyn Extension,
81 ) -> Result<MigrationResult, LoaderError> {
82 let ext_id = extension.metadata().id;
83 let migrations = extension.migrations();
84
85 if migrations.is_empty() {
86 return Ok(MigrationResult::default());
87 }
88
89 self.ensure_migrations_table_exists().await?;
90
91 let applied = self.get_applied_migrations(ext_id).await?;
92 let applied_versions: HashSet<u32> = applied.iter().map(|m| m.version).collect();
93 let applied_checksums: std::collections::HashMap<u32, &str> = applied
94 .iter()
95 .map(|m| (m.version, m.checksum.as_str()))
96 .collect();
97
98 let mut migrations_run = 0;
99 let mut migrations_skipped = 0;
100
101 for migration in &migrations {
102 if applied_versions.contains(&migration.version) {
103 let current_checksum = migration.checksum();
104 if let Some(&stored_checksum) = applied_checksums.get(&migration.version) {
105 if stored_checksum != current_checksum {
106 warn!(
107 extension = %ext_id,
108 version = migration.version,
109 name = %migration.name,
110 stored_checksum = %stored_checksum,
111 current_checksum = %current_checksum,
112 "Migration checksum mismatch - SQL has changed since it was applied"
113 );
114 }
115 }
116 migrations_skipped += 1;
117 debug!(
118 extension = %ext_id,
119 version = migration.version,
120 "Migration already applied, skipping"
121 );
122 continue;
123 }
124
125 self.execute_migration(ext_id, migration).await?;
126 migrations_run += 1;
127 }
128
129 if migrations_run > 0 {
130 info!(
131 extension = %ext_id,
132 migrations_run,
133 migrations_skipped,
134 "Migrations completed"
135 );
136 }
137
138 Ok(MigrationResult {
139 migrations_run,
140 migrations_skipped,
141 })
142 }
143
144 async fn execute_migration(
145 &self,
146 ext_id: &str,
147 migration: &Migration,
148 ) -> Result<(), LoaderError> {
149 info!(
150 extension = %ext_id,
151 version = migration.version,
152 name = %migration.name,
153 "Running migration"
154 );
155
156 SqlExecutor::execute_statements_parsed(self.db, migration.sql)
157 .await
158 .map_err(|e| LoaderError::MigrationFailed {
159 extension: ext_id.to_string(),
160 message: format!(
161 "Failed to execute migration {} ({}): {e}",
162 migration.version, migration.name
163 ),
164 })?;
165
166 self.record_migration(ext_id, migration).await?;
167
168 Ok(())
169 }
170
171 async fn record_migration(
172 &self,
173 ext_id: &str,
174 migration: &Migration,
175 ) -> Result<(), LoaderError> {
176 let id = format!("{}_{:03}", ext_id, migration.version);
177 let checksum = migration.checksum();
178 let name = migration.name.replace('\'', "''");
179
180 let sql = format!(
181 "INSERT INTO extension_migrations (id, extension_id, version, name, checksum) VALUES \
182 ('{}', '{}', {}, '{}', '{}')",
183 id, ext_id, migration.version, name, checksum
184 );
185
186 self.db
187 .execute_raw(&sql)
188 .await
189 .map_err(|e| LoaderError::MigrationFailed {
190 extension: ext_id.to_string(),
191 message: format!("Failed to record migration: {e}"),
192 })?;
193
194 Ok(())
195 }
196
197 pub async fn get_migration_status(
198 &self,
199 extension: &dyn Extension,
200 ) -> Result<MigrationStatus, LoaderError> {
201 self.ensure_migrations_table_exists().await?;
202
203 let ext_id = extension.metadata().id;
204 let defined_migrations = extension.migrations();
205 let applied = self.get_applied_migrations(ext_id).await?;
206
207 let applied_versions: HashSet<u32> = applied.iter().map(|m| m.version).collect();
208
209 let pending: Vec<_> = defined_migrations
210 .iter()
211 .filter(|m| !applied_versions.contains(&m.version))
212 .cloned()
213 .collect();
214
215 Ok(MigrationStatus {
216 extension_id: ext_id.to_string(),
217 total_defined: defined_migrations.len(),
218 total_applied: applied.len(),
219 pending_count: pending.len(),
220 pending,
221 applied,
222 })
223 }
224}
225
226#[derive(Debug)]
227pub struct MigrationStatus {
228 pub extension_id: String,
229 pub total_defined: usize,
230 pub total_applied: usize,
231 pub pending_count: usize,
232 pub pending: Vec<Migration>,
233 pub applied: Vec<AppliedMigration>,
234}