a2a_protocol_server/store/
migration.rs1use sqlx::sqlite::SqlitePool;
46use sqlx::Row;
47
/// A single versioned schema change.
///
/// [`MigrationRunner`] applies migrations whose `version` is greater than the
/// highest version already recorded in `schema_versions`; a recorded version is
/// never executed again.
#[derive(Debug, Clone)]
pub struct Migration {
    /// Schema version this migration brings the database to. Must be unique and
    /// increasing across the migration list.
    pub version: u32,
    /// Short human-readable summary, stored in `schema_versions` when applied.
    pub description: &'static str,
    /// SQL to execute. Several statements may be separated by `;`; the runner
    /// splits on every `;`, so a semicolon must not appear inside a string
    /// literal or trigger body.
    pub sql: &'static str,
}
63
64pub static BUILTIN_MIGRATIONS: &[Migration] = &[
68 Migration {
69 version: 1,
70 description: "Initial schema: tasks table with context_id and state indexes",
71 sql: "\
72CREATE TABLE IF NOT EXISTS tasks (
73 id TEXT PRIMARY KEY,
74 context_id TEXT NOT NULL,
75 state TEXT NOT NULL,
76 data TEXT NOT NULL,
77 updated_at TEXT NOT NULL DEFAULT (datetime('now'))
78);
79CREATE INDEX IF NOT EXISTS idx_tasks_context_id ON tasks(context_id);
80CREATE INDEX IF NOT EXISTS idx_tasks_state ON tasks(state);",
81 },
82 Migration {
83 version: 2,
84 description: "Add created_at column to tasks table",
85 sql: "ALTER TABLE tasks ADD COLUMN created_at TEXT NOT NULL DEFAULT (datetime('now'));",
86 },
87 Migration {
88 version: 3,
89 description: "Add composite index on (context_id, state) for combined filter queries",
90 sql: "CREATE INDEX IF NOT EXISTS idx_tasks_context_id_state ON tasks(context_id, state);",
91 },
92];
93
94#[derive(Debug, Clone)]
107pub struct MigrationRunner {
108 pool: SqlitePool,
109 migrations: &'static [Migration],
110}
111
112impl MigrationRunner {
113 #[must_use]
115 pub fn new(pool: SqlitePool) -> Self {
116 Self {
117 pool,
118 migrations: BUILTIN_MIGRATIONS,
119 }
120 }
121
122 #[must_use]
126 pub const fn with_migrations(pool: SqlitePool, migrations: &'static [Migration]) -> Self {
127 Self { pool, migrations }
128 }
129
130 async fn ensure_version_table(&self) -> Result<(), sqlx::Error> {
132 sqlx::query(
133 "CREATE TABLE IF NOT EXISTS schema_versions (
134 version INTEGER PRIMARY KEY,
135 description TEXT NOT NULL,
136 applied_at TEXT NOT NULL DEFAULT (datetime('now'))
137 )",
138 )
139 .execute(&self.pool)
140 .await?;
141 Ok(())
142 }
143
144 pub async fn current_version(&self) -> Result<u32, sqlx::Error> {
151 self.ensure_version_table().await?;
152 let row = sqlx::query("SELECT COALESCE(MAX(version), 0) AS v FROM schema_versions")
153 .fetch_one(&self.pool)
154 .await?;
155 let version: i32 = row.get("v");
156 #[allow(clippy::cast_sign_loss)]
157 Ok(version as u32)
158 }
159
160 pub async fn pending_migrations(&self) -> Result<Vec<&Migration>, sqlx::Error> {
166 let current = self.current_version().await?;
167 Ok(self
168 .migrations
169 .iter()
170 .filter(|m| m.version > current)
171 .collect())
172 }
173
174 pub async fn run_pending(&self) -> Result<Vec<u32>, sqlx::Error> {
186 self.ensure_version_table().await?;
187
188 let mut applied = Vec::new();
189
190 for migration in self.migrations {
191 let mut conn = self.pool.acquire().await?;
196 sqlx::query("BEGIN EXCLUSIVE").execute(&mut *conn).await?;
197
198 let row = sqlx::query("SELECT COALESCE(MAX(version), 0) AS v FROM schema_versions")
201 .fetch_one(&mut *conn)
202 .await?;
203 let current: i32 = row.get("v");
204 #[allow(clippy::cast_sign_loss)]
205 let current = current as u32;
206
207 if migration.version <= current {
208 sqlx::query("ROLLBACK").execute(&mut *conn).await?;
210 continue;
211 }
212
213 for statement in migration.sql.split(';') {
217 let trimmed = statement.trim();
218 if trimmed.is_empty() {
219 continue;
220 }
221 sqlx::query(trimmed).execute(&mut *conn).await?;
222 }
223
224 sqlx::query("INSERT INTO schema_versions (version, description) VALUES (?1, ?2)")
226 .bind(migration.version)
227 .bind(migration.description)
228 .execute(&mut *conn)
229 .await?;
230
231 sqlx::query("COMMIT").execute(&mut *conn).await?;
232 applied.push(migration.version);
233 }
234
235 Ok(applied)
236 }
237}
238
#[cfg(test)]
mod tests {
    use super::*;
    use sqlx::sqlite::SqlitePoolOptions;

    /// Opens an in-memory SQLite pool limited to a single connection.
    async fn memory_pool() -> SqlitePool {
        SqlitePoolOptions::new()
            .max_connections(1)
            .connect("sqlite::memory:")
            .await
            .expect("failed to open in-memory sqlite")
    }

    #[tokio::test]
    async fn current_version_starts_at_zero() {
        let runner = MigrationRunner::new(memory_pool().await);
        let version = runner.current_version().await.unwrap();
        assert_eq!(version, 0);
    }

    #[tokio::test]
    async fn run_pending_applies_all_builtin_migrations() {
        let pool = memory_pool().await;
        let runner = MigrationRunner::new(pool.clone());

        let applied_versions = runner.run_pending().await.unwrap();
        assert_eq!(applied_versions, vec![1, 2, 3]);
        assert_eq!(runner.current_version().await.unwrap(), 3);

        let column_names: Vec<String> = sqlx::query("PRAGMA table_info(tasks)")
            .fetch_all(&pool)
            .await
            .unwrap()
            .iter()
            .map(|info| info.get::<String, _>("name"))
            .collect();

        for expected in ["id", "context_id", "state", "data", "updated_at", "created_at"] {
            assert!(column_names.iter().any(|name| name == expected));
        }
    }

    #[tokio::test]
    async fn run_pending_is_idempotent() {
        let runner = MigrationRunner::new(memory_pool().await);

        let first_pass = runner.run_pending().await.unwrap();
        assert_eq!(first_pass, vec![1, 2, 3]);

        let second_pass = runner.run_pending().await.unwrap();
        assert!(second_pass.is_empty());

        assert_eq!(runner.current_version().await.unwrap(), 3);
    }

    #[tokio::test]
    async fn pending_migrations_returns_unapplied() {
        let runner = MigrationRunner::new(memory_pool().await);

        let before: Vec<u32> = runner
            .pending_migrations()
            .await
            .unwrap()
            .into_iter()
            .map(|m| m.version)
            .collect();
        assert_eq!(before, vec![1, 2, 3]);

        runner.run_pending().await.unwrap();

        let after = runner.pending_migrations().await.unwrap();
        assert!(after.is_empty());
    }

    #[tokio::test]
    async fn partial_application_tracks_correctly() {
        let pool = memory_pool().await;

        // First bring the database to v1 only.
        let v1_runner = MigrationRunner::with_migrations(pool.clone(), &BUILTIN_MIGRATIONS[..1]);
        assert_eq!(v1_runner.run_pending().await.unwrap(), vec![1]);
        assert_eq!(v1_runner.current_version().await.unwrap(), 1);

        // Then the full runner should see, and apply, exactly v2 and v3.
        let full_runner = MigrationRunner::new(pool);
        let remaining: Vec<u32> = full_runner
            .pending_migrations()
            .await
            .unwrap()
            .into_iter()
            .map(|m| m.version)
            .collect();
        assert_eq!(remaining, vec![2, 3]);

        assert_eq!(full_runner.run_pending().await.unwrap(), vec![2, 3]);
        assert_eq!(full_runner.current_version().await.unwrap(), 3);
    }

    #[tokio::test]
    async fn schema_versions_table_records_metadata() {
        let pool = memory_pool().await;
        MigrationRunner::new(pool.clone()).run_pending().await.unwrap();

        let records = sqlx::query(
            "SELECT version, description, applied_at FROM schema_versions ORDER BY version",
        )
        .fetch_all(&pool)
        .await
        .unwrap();

        assert_eq!(records.len(), 3);
        let first = &records[0];
        assert_eq!(first.get::<i32, _>("version"), 1);
        assert!(!first.get::<String, _>("description").is_empty());
        assert!(!first.get::<String, _>("applied_at").is_empty());
    }

    #[tokio::test]
    async fn composite_index_exists_after_v3() {
        let pool = memory_pool().await;
        MigrationRunner::new(pool.clone()).run_pending().await.unwrap();

        let matches = sqlx::query(
            "SELECT name FROM sqlite_master WHERE type='index' AND name='idx_tasks_context_id_state'",
        )
        .fetch_all(&pool)
        .await
        .unwrap();

        assert_eq!(matches.len(), 1);
    }
}