Skip to main content

punch_memory/
migrations.rs

1//! Database migration engine for the Punch memory substrate.
2//!
3//! Tracks applied migrations in a `_punch_migrations` table and supports
4//! forward (up) and backward (down) migration with SHA-256 checksum
5//! verification for integrity.
6
7use std::sync::Arc;
8
9use rusqlite::Connection;
10use sha2::{Digest, Sha256};
11use tracing::info;
12
13use punch_types::{PunchError, PunchResult};
14
15// ---------------------------------------------------------------------------
16// Public types
17// ---------------------------------------------------------------------------
18
/// A single database migration with forward (`up_sql`) and backward
/// (`down_sql`) SQL.
///
/// Equality compares every field, so two migrations are equal only if their
/// version, name and both SQL texts match.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Migration {
    /// Version number; recorded as a UNIQUE key in `_punch_migrations` and
    /// used to order application (ascending) and rollback (descending).
    pub version: u64,
    /// Human-readable name, recorded alongside the version when applied.
    pub name: String,
    /// SQL batch executed when the migration is applied; also the input to
    /// [`Migration::checksum`].
    pub up_sql: String,
    /// SQL batch executed when the migration is rolled back.
    pub down_sql: String,
}
27
28impl Migration {
29    /// Compute the SHA-256 hex digest of the `up_sql` content.
30    pub fn checksum(&self) -> String {
31        let mut hasher = Sha256::new();
32        hasher.update(self.up_sql.as_bytes());
33        format!("{:x}", hasher.finalize())
34    }
35}
36
/// Status of a single migration (applied or pending).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct MigrationStatus {
    /// Version of the migration this status describes.
    pub version: u64,
    /// Name of the migration as supplied by the caller.
    pub name: String,
    /// Whether a row for this version exists in `_punch_migrations`.
    pub applied: bool,
    /// Timestamp string recorded when the migration was applied, or `None`
    /// if it is still pending.
    pub applied_at: Option<String>,
}
45
46/// The migration engine manages schema versioning for a SQLite database.
47pub struct MigrationEngine {
48    /// Database connection.
49    conn: Arc<std::sync::Mutex<Connection>>,
50}
51
52impl MigrationEngine {
53    /// Create a new engine and ensure the tracking table exists.
54    pub fn new(conn: Arc<std::sync::Mutex<Connection>>) -> PunchResult<Self> {
55        {
56            let c = conn
57                .lock()
58                .map_err(|e| PunchError::Memory(format!("failed to lock connection: {e}")))?;
59            c.execute_batch(
60                "CREATE TABLE IF NOT EXISTS _punch_migrations (
61                    id         INTEGER PRIMARY KEY,
62                    version    INTEGER NOT NULL UNIQUE,
63                    name       TEXT NOT NULL,
64                    applied_at TEXT NOT NULL,
65                    checksum   TEXT NOT NULL
66                );",
67            )
68            .map_err(|e| PunchError::Memory(format!("failed to create migrations table: {e}")))?;
69        }
70        Ok(Self { conn })
71    }
72
73    /// Return the highest applied migration version, or 0 if none.
74    pub fn current_version(&self) -> PunchResult<u64> {
75        let c = self
76            .conn
77            .lock()
78            .map_err(|e| PunchError::Memory(format!("failed to lock connection: {e}")))?;
79        let version: Option<u64> = c
80            .query_row("SELECT MAX(version) FROM _punch_migrations", [], |row| {
81                row.get(0)
82            })
83            .map_err(|e| PunchError::Memory(format!("failed to query current version: {e}")))?;
84        Ok(version.unwrap_or(0))
85    }
86
87    /// Return migrations from `all` that have not yet been applied, sorted by
88    /// version ascending.
89    pub fn pending_migrations<'a>(&self, all: &'a [Migration]) -> PunchResult<Vec<&'a Migration>> {
90        let applied = self.applied_versions()?;
91        let mut pending: Vec<&Migration> = all
92            .iter()
93            .filter(|m| !applied.contains(&m.version))
94            .collect();
95        pending.sort_by_key(|m| m.version);
96        Ok(pending)
97    }
98
99    /// Apply all pending migrations in order. Each migration runs inside its
100    /// own transaction. Returns the versions that were applied.
101    pub fn migrate_up(&self, migrations: &[Migration]) -> PunchResult<Vec<u64>> {
102        let pending = self.pending_migrations(migrations)?;
103        let mut applied = Vec::new();
104
105        for migration in pending {
106            let c = self
107                .conn
108                .lock()
109                .map_err(|e| PunchError::Memory(format!("failed to lock connection: {e}")))?;
110            let tx = c
111                .unchecked_transaction()
112                .map_err(|e| PunchError::Memory(format!("failed to begin transaction: {e}")))?;
113
114            tx.execute_batch(&migration.up_sql).map_err(|e| {
115                PunchError::Memory(format!(
116                    "migration v{} ({}) failed: {e}",
117                    migration.version, migration.name
118                ))
119            })?;
120
121            let now = chrono::Utc::now().format("%Y-%m-%dT%H:%M:%SZ").to_string();
122            tx.execute(
123                "INSERT INTO _punch_migrations (version, name, applied_at, checksum)
124                 VALUES (?1, ?2, ?3, ?4)",
125                rusqlite::params![migration.version, migration.name, now, migration.checksum(),],
126            )
127            .map_err(|e| {
128                PunchError::Memory(format!(
129                    "failed to record migration v{}: {e}",
130                    migration.version
131                ))
132            })?;
133
134            tx.commit().map_err(|e| {
135                PunchError::Memory(format!(
136                    "failed to commit migration v{}: {e}",
137                    migration.version
138                ))
139            })?;
140
141            info!(version = migration.version, name = %migration.name, "applied migration");
142            applied.push(migration.version);
143        }
144
145        Ok(applied)
146    }
147
148    /// Roll back applied migrations whose version is greater than
149    /// `target_version`, in reverse order. Returns the versions that were
150    /// rolled back.
151    pub fn migrate_down(
152        &self,
153        migrations: &[Migration],
154        target_version: u64,
155    ) -> PunchResult<Vec<u64>> {
156        let applied = self.applied_versions()?;
157
158        // Collect migrations that need to be rolled back, sorted descending.
159        let mut to_rollback: Vec<&Migration> = migrations
160            .iter()
161            .filter(|m| m.version > target_version && applied.contains(&m.version))
162            .collect();
163        to_rollback.sort_by(|a, b| b.version.cmp(&a.version));
164
165        let mut rolled_back = Vec::new();
166
167        for migration in to_rollback {
168            let c = self
169                .conn
170                .lock()
171                .map_err(|e| PunchError::Memory(format!("failed to lock connection: {e}")))?;
172            let tx = c
173                .unchecked_transaction()
174                .map_err(|e| PunchError::Memory(format!("failed to begin transaction: {e}")))?;
175
176            tx.execute_batch(&migration.down_sql).map_err(|e| {
177                PunchError::Memory(format!(
178                    "down migration v{} ({}) failed: {e}",
179                    migration.version, migration.name
180                ))
181            })?;
182
183            tx.execute(
184                "DELETE FROM _punch_migrations WHERE version = ?1",
185                [migration.version],
186            )
187            .map_err(|e| {
188                PunchError::Memory(format!(
189                    "failed to remove migration record v{}: {e}",
190                    migration.version
191                ))
192            })?;
193
194            tx.commit().map_err(|e| {
195                PunchError::Memory(format!(
196                    "failed to commit down migration v{}: {e}",
197                    migration.version
198                ))
199            })?;
200
201            info!(version = migration.version, name = %migration.name, "rolled back migration");
202            rolled_back.push(migration.version);
203        }
204
205        Ok(rolled_back)
206    }
207
208    /// Show the status (applied / pending) of every known migration.
209    pub fn migration_status(&self, migrations: &[Migration]) -> PunchResult<Vec<MigrationStatus>> {
210        let c = self
211            .conn
212            .lock()
213            .map_err(|e| PunchError::Memory(format!("failed to lock connection: {e}")))?;
214
215        let mut stmt = c
216            .prepare("SELECT version, applied_at FROM _punch_migrations")
217            .map_err(|e| PunchError::Memory(format!("failed to query migration status: {e}")))?;
218
219        let rows: Vec<(u64, String)> = stmt
220            .query_map([], |row| {
221                let version: u64 = row.get(0)?;
222                let applied_at: String = row.get(1)?;
223                Ok((version, applied_at))
224            })
225            .map_err(|e| PunchError::Memory(format!("failed to read migration rows: {e}")))?
226            .filter_map(|r| r.ok())
227            .collect();
228
229        let mut statuses: Vec<MigrationStatus> = migrations
230            .iter()
231            .map(|m| {
232                let applied_row = rows.iter().find(|(v, _)| *v == m.version);
233                MigrationStatus {
234                    version: m.version,
235                    name: m.name.clone(),
236                    applied: applied_row.is_some(),
237                    applied_at: applied_row.map(|(_, at)| at.clone()),
238                }
239            })
240            .collect();
241        statuses.sort_by_key(|s| s.version);
242        Ok(statuses)
243    }
244
245    /// Verify that every applied migration's stored checksum matches the
246    /// current `up_sql` content. Returns an error on the first mismatch.
247    pub fn verify_checksums(&self, migrations: &[Migration]) -> PunchResult<()> {
248        let c = self
249            .conn
250            .lock()
251            .map_err(|e| PunchError::Memory(format!("failed to lock connection: {e}")))?;
252
253        let mut stmt = c
254            .prepare("SELECT version, checksum FROM _punch_migrations")
255            .map_err(|e| PunchError::Memory(format!("failed to query checksums: {e}")))?;
256
257        let rows: Vec<(u64, String)> = stmt
258            .query_map([], |row| {
259                let version: u64 = row.get(0)?;
260                let checksum: String = row.get(1)?;
261                Ok((version, checksum))
262            })
263            .map_err(|e| PunchError::Memory(format!("failed to read checksum rows: {e}")))?
264            .filter_map(|r| r.ok())
265            .collect();
266
267        for (version, stored_checksum) in &rows {
268            if let Some(migration) = migrations.iter().find(|m| m.version == *version) {
269                let current_checksum = migration.checksum();
270                if *stored_checksum != current_checksum {
271                    return Err(PunchError::Memory(format!(
272                        "checksum mismatch for migration v{} ({}): stored={}, current={}",
273                        version, migration.name, stored_checksum, current_checksum
274                    )));
275                }
276            }
277        }
278
279        Ok(())
280    }
281
282    /// Return the 6 built-in migrations that define the Punch schema.
283    pub fn builtin_migrations() -> Vec<Migration> {
284        vec![
285            Migration {
286                version: 1,
287                name: "create_memories_table".into(),
288                up_sql: "CREATE TABLE IF NOT EXISTS memories (
289                    id          INTEGER PRIMARY KEY AUTOINCREMENT,
290                    fighter_id  TEXT NOT NULL,
291                    key         TEXT NOT NULL,
292                    value       TEXT NOT NULL,
293                    confidence  REAL NOT NULL DEFAULT 1.0,
294                    created_at  TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%SZ', 'now')),
295                    accessed_at TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%SZ', 'now')),
296                    UNIQUE(fighter_id, key)
297                );"
298                .into(),
299                down_sql: "DROP TABLE IF EXISTS memories;".into(),
300            },
301            Migration {
302                version: 2,
303                name: "create_entities_table".into(),
304                up_sql: "CREATE TABLE IF NOT EXISTS knowledge_entities (
305                    id          INTEGER PRIMARY KEY AUTOINCREMENT,
306                    fighter_id  TEXT NOT NULL,
307                    name        TEXT NOT NULL,
308                    entity_type TEXT NOT NULL,
309                    properties  TEXT NOT NULL DEFAULT '{}',
310                    created_at  TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%SZ', 'now')),
311                    UNIQUE(fighter_id, name, entity_type)
312                );"
313                .into(),
314                down_sql: "DROP TABLE IF EXISTS knowledge_entities;".into(),
315            },
316            Migration {
317                version: 3,
318                name: "create_relations_table".into(),
319                up_sql: "CREATE TABLE IF NOT EXISTS knowledge_relations (
320                    id          INTEGER PRIMARY KEY AUTOINCREMENT,
321                    fighter_id  TEXT NOT NULL,
322                    from_entity TEXT NOT NULL,
323                    relation    TEXT NOT NULL,
324                    to_entity   TEXT NOT NULL,
325                    properties  TEXT NOT NULL DEFAULT '{}',
326                    created_at  TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%SZ', 'now')),
327                    UNIQUE(fighter_id, from_entity, relation, to_entity)
328                );"
329                .into(),
330                down_sql: "DROP TABLE IF EXISTS knowledge_relations;".into(),
331            },
332            Migration {
333                version: 4,
334                name: "create_bouts_table".into(),
335                up_sql: "CREATE TABLE IF NOT EXISTS bouts (
336                    id          TEXT PRIMARY KEY,
337                    fighter_id  TEXT NOT NULL,
338                    title       TEXT,
339                    created_at  TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%SZ', 'now')),
340                    updated_at  TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%SZ', 'now'))
341                );"
342                .into(),
343                down_sql: "DROP TABLE IF EXISTS bouts;".into(),
344            },
345            Migration {
346                version: 5,
347                name: "create_bout_messages_table".into(),
348                up_sql: "CREATE TABLE IF NOT EXISTS messages (
349                    id          INTEGER PRIMARY KEY AUTOINCREMENT,
350                    bout_id     TEXT NOT NULL,
351                    role        TEXT NOT NULL,
352                    content     TEXT NOT NULL DEFAULT '',
353                    metadata    TEXT,
354                    created_at  TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%SZ', 'now'))
355                );"
356                .into(),
357                down_sql: "DROP TABLE IF EXISTS messages;".into(),
358            },
359            Migration {
360                version: 6,
361                name: "add_indexes".into(),
362                up_sql: "
363                    CREATE INDEX IF NOT EXISTS idx_memories_fighter ON memories(fighter_id);
364                    CREATE INDEX IF NOT EXISTS idx_ke_fighter ON knowledge_entities(fighter_id);
365                    CREATE INDEX IF NOT EXISTS idx_kr_fighter ON knowledge_relations(fighter_id);
366                    CREATE INDEX IF NOT EXISTS idx_bouts_fighter ON bouts(fighter_id);
367                    CREATE INDEX IF NOT EXISTS idx_messages_bout ON messages(bout_id);
368                "
369                .into(),
370                down_sql: "
371                    DROP INDEX IF EXISTS idx_memories_fighter;
372                    DROP INDEX IF EXISTS idx_ke_fighter;
373                    DROP INDEX IF EXISTS idx_kr_fighter;
374                    DROP INDEX IF EXISTS idx_bouts_fighter;
375                    DROP INDEX IF EXISTS idx_messages_bout;
376                "
377                .into(),
378            },
379            Migration {
380                version: 7,
381                name: "create_fighters_table".into(),
382                up_sql: "CREATE TABLE IF NOT EXISTS fighters (
383                    id          TEXT PRIMARY KEY,
384                    name        TEXT NOT NULL,
385                    manifest    TEXT NOT NULL,
386                    status      TEXT NOT NULL DEFAULT 'idle',
387                    created_at  TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%SZ', 'now')),
388                    updated_at  TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%SZ', 'now'))
389                );"
390                .into(),
391                down_sql: "DROP TABLE IF EXISTS fighters;".into(),
392            },
393            Migration {
394                version: 8,
395                name: "create_usage_events_table".into(),
396                up_sql: "CREATE TABLE IF NOT EXISTS usage_events (
397                    id              INTEGER PRIMARY KEY AUTOINCREMENT,
398                    fighter_id      TEXT NOT NULL,
399                    model           TEXT NOT NULL,
400                    input_tokens    INTEGER NOT NULL,
401                    output_tokens   INTEGER NOT NULL,
402                    cost_usd        REAL NOT NULL DEFAULT 0.0,
403                    created_at      TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%SZ', 'now'))
404                );
405                CREATE INDEX IF NOT EXISTS idx_usage_fighter ON usage_events(fighter_id);"
406                    .into(),
407                down_sql: "DROP INDEX IF EXISTS idx_usage_fighter;
408                DROP TABLE IF EXISTS usage_events;"
409                    .into(),
410            },
411            Migration {
412                version: 9,
413                name: "create_gorilla_state_table".into(),
414                up_sql: "CREATE TABLE IF NOT EXISTS gorilla_state (
415                    gorilla_id  TEXT PRIMARY KEY,
416                    state       TEXT NOT NULL DEFAULT '{}',
417                    updated_at  TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%SZ', 'now'))
418                );"
419                .into(),
420                down_sql: "DROP TABLE IF EXISTS gorilla_state;".into(),
421            },
422            Migration {
423                version: 10,
424                name: "create_embeddings_table".into(),
425                up_sql: "CREATE TABLE IF NOT EXISTS embeddings (
426                    id         TEXT PRIMARY KEY,
427                    text       TEXT NOT NULL,
428                    vector     BLOB NOT NULL,
429                    metadata   TEXT NOT NULL DEFAULT '{}',
430                    created_at TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%SZ', 'now'))
431                );"
432                .into(),
433                down_sql: "DROP TABLE IF EXISTS embeddings;".into(),
434            },
435            Migration {
436                version: 11,
437                name: "create_creeds_table".into(),
438                up_sql: "CREATE TABLE IF NOT EXISTS creeds (
439                    id          TEXT PRIMARY KEY,
440                    fighter_name TEXT NOT NULL,
441                    fighter_id  TEXT,
442                    creed_data  TEXT NOT NULL,
443                    version     INTEGER NOT NULL DEFAULT 1,
444                    bout_count  INTEGER NOT NULL DEFAULT 0,
445                    message_count INTEGER NOT NULL DEFAULT 0,
446                    created_at  TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%SZ', 'now')),
447                    updated_at  TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%SZ', 'now'))
448                );
449                CREATE UNIQUE INDEX IF NOT EXISTS idx_creeds_fighter_name ON creeds(fighter_name);
450                CREATE INDEX IF NOT EXISTS idx_creeds_fighter_id ON creeds(fighter_id);"
451                    .into(),
452                down_sql: "DROP INDEX IF EXISTS idx_creeds_fighter_id;
453                DROP INDEX IF EXISTS idx_creeds_fighter_name;
454                DROP TABLE IF EXISTS creeds;"
455                    .into(),
456            },
457        ]
458    }
459
460    // -----------------------------------------------------------------------
461    // Internal helpers
462    // -----------------------------------------------------------------------
463
464    fn applied_versions(&self) -> PunchResult<Vec<u64>> {
465        let c = self
466            .conn
467            .lock()
468            .map_err(|e| PunchError::Memory(format!("failed to lock connection: {e}")))?;
469        let mut stmt = c
470            .prepare("SELECT version FROM _punch_migrations ORDER BY version")
471            .map_err(|e| PunchError::Memory(format!("failed to query applied versions: {e}")))?;
472
473        let versions: Vec<u64> = stmt
474            .query_map([], |row| row.get(0))
475            .map_err(|e| PunchError::Memory(format!("failed to read version rows: {e}")))?
476            .filter_map(|r| r.ok())
477            .collect();
478
479        Ok(versions)
480    }
481}
482
483// ---------------------------------------------------------------------------
484// Legacy entry point — used by substrate.rs
485// ---------------------------------------------------------------------------
486
487/// Run all pending built-in migrations against `conn`.
488///
489/// This is the entry point called from [`crate::substrate::MemorySubstrate`]
490/// during initialisation. It also handles migration from the old `_punch_meta`
491/// version-tracking table if present.
492pub fn migrate(conn: &Connection) -> PunchResult<()> {
493    // If the old _punch_meta table exists, drop it — the new engine tracks
494    // state in _punch_migrations.
495    conn.execute_batch("DROP TABLE IF EXISTS _punch_meta;")
496        .map_err(|e| PunchError::Memory(format!("failed to drop legacy meta table: {e}")))?;
497
498    // Create the tracking table.
499    conn.execute_batch(
500        "CREATE TABLE IF NOT EXISTS _punch_migrations (
501            id         INTEGER PRIMARY KEY,
502            version    INTEGER NOT NULL UNIQUE,
503            name       TEXT NOT NULL,
504            applied_at TEXT NOT NULL,
505            checksum   TEXT NOT NULL
506        );",
507    )
508    .map_err(|e| PunchError::Memory(format!("failed to create migrations table: {e}")))?;
509
510    // Determine which versions have already been applied.
511    let applied_versions = {
512        let mut stmt = conn
513            .prepare("SELECT version FROM _punch_migrations ORDER BY version")
514            .map_err(|e| PunchError::Memory(format!("failed to query applied versions: {e}")))?;
515        let versions: Vec<u64> = stmt
516            .query_map([], |row| row.get(0))
517            .map_err(|e| PunchError::Memory(format!("failed to read version rows: {e}")))?
518            .filter_map(|r| r.ok())
519            .collect();
520        versions
521    };
522
523    let builtins = MigrationEngine::builtin_migrations();
524    let mut count = 0usize;
525
526    for migration in &builtins {
527        if applied_versions.contains(&migration.version) {
528            continue;
529        }
530
531        let tx = conn
532            .unchecked_transaction()
533            .map_err(|e| PunchError::Memory(format!("failed to begin transaction: {e}")))?;
534
535        tx.execute_batch(&migration.up_sql).map_err(|e| {
536            PunchError::Memory(format!(
537                "migration v{} ({}) failed: {e}",
538                migration.version, migration.name
539            ))
540        })?;
541
542        let now = chrono::Utc::now().format("%Y-%m-%dT%H:%M:%SZ").to_string();
543        tx.execute(
544            "INSERT INTO _punch_migrations (version, name, applied_at, checksum)
545             VALUES (?1, ?2, ?3, ?4)",
546            rusqlite::params![migration.version, migration.name, now, migration.checksum(),],
547        )
548        .map_err(|e| {
549            PunchError::Memory(format!(
550                "failed to record migration v{}: {e}",
551                migration.version
552            ))
553        })?;
554
555        tx.commit().map_err(|e| {
556            PunchError::Memory(format!(
557                "failed to commit migration v{}: {e}",
558                migration.version
559            ))
560        })?;
561
562        info!(version = migration.version, name = %migration.name, "applied migration");
563        count += 1;
564    }
565
566    if count > 0 {
567        info!(count, "migrations applied");
568    }
569
570    Ok(())
571}
572
573// ---------------------------------------------------------------------------
574// Tests
575// ---------------------------------------------------------------------------
576
577#[cfg(test)]
578mod tests {
579    use super::*;
580
581    fn test_engine() -> (MigrationEngine, Arc<std::sync::Mutex<Connection>>) {
582        let conn = Connection::open_in_memory().unwrap();
583        conn.execute_batch("PRAGMA foreign_keys = ON;").unwrap();
584        let arc = Arc::new(std::sync::Mutex::new(conn));
585        let engine = MigrationEngine::new(Arc::clone(&arc)).unwrap();
586        (engine, arc)
587    }
588
589    fn simple_migrations() -> Vec<Migration> {
590        vec![
591            Migration {
592                version: 1,
593                name: "create_alpha".into(),
594                up_sql: "CREATE TABLE alpha (id INTEGER PRIMARY KEY, name TEXT);".into(),
595                down_sql: "DROP TABLE IF EXISTS alpha;".into(),
596            },
597            Migration {
598                version: 2,
599                name: "create_beta".into(),
600                up_sql: "CREATE TABLE beta (id INTEGER PRIMARY KEY, value TEXT);".into(),
601                down_sql: "DROP TABLE IF EXISTS beta;".into(),
602            },
603            Migration {
604                version: 3,
605                name: "create_gamma".into(),
606                up_sql: "CREATE TABLE gamma (id INTEGER PRIMARY KEY, score REAL);".into(),
607                down_sql: "DROP TABLE IF EXISTS gamma;".into(),
608            },
609        ]
610    }
611
612    #[test]
613    fn test_migration_table_creation() {
614        let (engine, arc) = test_engine();
615        // The tracking table should exist after new().
616        {
617            let c = arc.lock().unwrap();
618            let count: i64 = c
619                .query_row(
620                    "SELECT COUNT(*) FROM sqlite_master WHERE type='table' AND name='_punch_migrations'",
621                    [],
622                    |row| row.get(0),
623                )
624                .unwrap();
625            assert_eq!(count, 1);
626        }
627        // No migrations applied yet.
628        assert_eq!(engine.current_version().unwrap(), 0);
629    }
630
631    #[test]
632    fn test_apply_single_migration() {
633        let (engine, arc) = test_engine();
634        let migrations = vec![simple_migrations().remove(0)];
635        let applied = engine.migrate_up(&migrations).unwrap();
636        assert_eq!(applied, vec![1]);
637
638        let c = arc.lock().unwrap();
639        let count: i64 = c
640            .query_row(
641                "SELECT COUNT(*) FROM sqlite_master WHERE type='table' AND name='alpha'",
642                [],
643                |row| row.get(0),
644            )
645            .unwrap();
646        assert_eq!(count, 1);
647    }
648
649    #[test]
650    fn test_apply_multiple_migrations_in_order() {
651        let (engine, _arc) = test_engine();
652        let migrations = simple_migrations();
653        let applied = engine.migrate_up(&migrations).unwrap();
654        assert_eq!(applied, vec![1, 2, 3]);
655        assert_eq!(engine.current_version().unwrap(), 3);
656    }
657
658    #[test]
659    fn test_skip_already_applied_migrations() {
660        let (engine, _arc) = test_engine();
661        let migrations = simple_migrations();
662
663        engine.migrate_up(&migrations).unwrap();
664        let applied_again = engine.migrate_up(&migrations).unwrap();
665        assert!(applied_again.is_empty());
666    }
667
668    #[test]
669    fn test_rollback_to_specific_version() {
670        let (engine, arc) = test_engine();
671        let migrations = simple_migrations();
672
673        engine.migrate_up(&migrations).unwrap();
674        assert_eq!(engine.current_version().unwrap(), 3);
675
676        let rolled_back = engine.migrate_down(&migrations, 1).unwrap();
677        assert_eq!(rolled_back, vec![3, 2]);
678        assert_eq!(engine.current_version().unwrap(), 1);
679
680        // gamma and beta tables should be gone.
681        let c = arc.lock().unwrap();
682        let tables: Vec<String> = c
683            .prepare("SELECT name FROM sqlite_master WHERE type='table' AND name IN ('alpha','beta','gamma')")
684            .unwrap()
685            .query_map([], |row| row.get(0))
686            .unwrap()
687            .filter_map(|r| r.ok())
688            .collect();
689        assert_eq!(tables, vec!["alpha".to_string()]);
690    }
691
692    #[test]
693    fn test_current_version_tracking() {
694        let (engine, _arc) = test_engine();
695        assert_eq!(engine.current_version().unwrap(), 0);
696
697        let migrations = simple_migrations();
698        engine.migrate_up(&migrations[..1]).unwrap();
699        assert_eq!(engine.current_version().unwrap(), 1);
700
701        engine.migrate_up(&migrations).unwrap();
702        assert_eq!(engine.current_version().unwrap(), 3);
703    }
704
705    #[test]
706    fn test_pending_migration_detection() {
707        let (engine, _arc) = test_engine();
708        let migrations = simple_migrations();
709
710        let pending = engine.pending_migrations(&migrations).unwrap();
711        assert_eq!(pending.len(), 3);
712
713        engine.migrate_up(&migrations[..2]).unwrap();
714
715        let pending = engine.pending_migrations(&migrations).unwrap();
716        assert_eq!(pending.len(), 1);
717        assert_eq!(pending[0].version, 3);
718    }
719
720    #[test]
721    fn test_checksum_verification_passes() {
722        let (engine, _arc) = test_engine();
723        let migrations = simple_migrations();
724        engine.migrate_up(&migrations).unwrap();
725
726        // Verify with the same migrations — should succeed.
727        engine.verify_checksums(&migrations).unwrap();
728    }
729
730    #[test]
731    fn test_checksum_verification_fails_for_tampered() {
732        let (engine, _arc) = test_engine();
733        let migrations = simple_migrations();
734        engine.migrate_up(&migrations).unwrap();
735
736        // Tamper with a migration's up_sql.
737        let mut tampered = simple_migrations();
738        tampered[0].up_sql =
739            "CREATE TABLE alpha (id INTEGER PRIMARY KEY, name TEXT, extra TEXT);".into();
740
741        let result = engine.verify_checksums(&tampered);
742        assert!(result.is_err());
743        let err_msg = format!("{}", result.unwrap_err());
744        assert!(err_msg.contains("checksum mismatch"));
745    }
746
747    #[test]
748    fn test_migration_status_listing() {
749        let (engine, _arc) = test_engine();
750        let migrations = simple_migrations();
751        engine.migrate_up(&migrations[..2]).unwrap();
752
753        let statuses = engine.migration_status(&migrations).unwrap();
754        assert_eq!(statuses.len(), 3);
755
756        assert!(statuses[0].applied);
757        assert!(statuses[0].applied_at.is_some());
758        assert_eq!(statuses[0].version, 1);
759
760        assert!(statuses[1].applied);
761        assert_eq!(statuses[1].version, 2);
762
763        assert!(!statuses[2].applied);
764        assert!(statuses[2].applied_at.is_none());
765        assert_eq!(statuses[2].version, 3);
766    }
767
768    #[test]
769    fn test_builtin_migrations_are_valid_sql() {
770        let (engine, _arc) = test_engine();
771        let builtins = MigrationEngine::builtin_migrations();
772
773        // All built-in migrations should apply without error.
774        let applied = engine.migrate_up(&builtins).unwrap();
775        assert_eq!(applied.len(), 11);
776        assert_eq!(engine.current_version().unwrap(), 11);
777    }
778
779    #[test]
780    fn test_idempotent_migrate_up() {
781        let (engine, _arc) = test_engine();
782        let migrations = simple_migrations();
783
784        let first = engine.migrate_up(&migrations).unwrap();
785        assert_eq!(first.len(), 3);
786
787        let second = engine.migrate_up(&migrations).unwrap();
788        assert!(second.is_empty());
789
790        // State unchanged.
791        assert_eq!(engine.current_version().unwrap(), 3);
792    }
793
794    #[test]
795    fn test_transaction_rollback_on_sql_error() {
796        let (engine, _arc) = test_engine();
797
798        let bad_migrations = vec![
799            Migration {
800                version: 1,
801                name: "good".into(),
802                up_sql: "CREATE TABLE good (id INTEGER PRIMARY KEY);".into(),
803                down_sql: "DROP TABLE IF EXISTS good;".into(),
804            },
805            Migration {
806                version: 2,
807                name: "bad".into(),
808                up_sql: "THIS IS NOT VALID SQL;".into(),
809                down_sql: "SELECT 1;".into(),
810            },
811        ];
812
813        // First migration succeeds, second fails.
814        let result = engine.migrate_up(&bad_migrations);
815        assert!(result.is_err());
816
817        // Only version 1 should be applied.
818        assert_eq!(engine.current_version().unwrap(), 1);
819    }
820
821    #[test]
822    fn test_down_migration_ordering_reverse() {
823        let (engine, _arc) = test_engine();
824        let migrations = simple_migrations();
825
826        engine.migrate_up(&migrations).unwrap();
827
828        // Rolling back to 0 should go 3, 2, 1.
829        let rolled = engine.migrate_down(&migrations, 0).unwrap();
830        assert_eq!(rolled, vec![3, 2, 1]);
831        assert_eq!(engine.current_version().unwrap(), 0);
832    }
833
834    #[test]
835    fn test_empty_migration_list_handling() {
836        let (engine, _arc) = test_engine();
837        let empty: Vec<Migration> = vec![];
838
839        let applied = engine.migrate_up(&empty).unwrap();
840        assert!(applied.is_empty());
841
842        let pending = engine.pending_migrations(&empty).unwrap();
843        assert!(pending.is_empty());
844
845        let statuses = engine.migration_status(&empty).unwrap();
846        assert!(statuses.is_empty());
847
848        let rolled = engine.migrate_down(&empty, 0).unwrap();
849        assert!(rolled.is_empty());
850
851        engine.verify_checksums(&empty).unwrap();
852    }
853
854    #[test]
855    fn test_checksum_deterministic() {
856        let m = Migration {
857            version: 1,
858            name: "test".into(),
859            up_sql: "CREATE TABLE test (id INTEGER);".into(),
860            down_sql: "DROP TABLE test;".into(),
861        };
862        let c1 = m.checksum();
863        let c2 = m.checksum();
864        assert_eq!(c1, c2);
865        assert_eq!(c1.len(), 64); // SHA-256 hex is 64 chars
866    }
867
868    #[test]
869    fn test_legacy_migrate_function() {
870        let conn = Connection::open_in_memory().unwrap();
871        conn.execute_batch("PRAGMA foreign_keys = ON;").unwrap();
872
873        // The legacy migrate() function should work.
874        migrate(&conn).unwrap();
875
876        // Core tables from built-in migrations should exist.
877        let tables: Vec<String> = conn
878            .prepare("SELECT name FROM sqlite_master WHERE type='table' ORDER BY name")
879            .unwrap()
880            .query_map([], |row| row.get(0))
881            .unwrap()
882            .filter_map(|r| r.ok())
883            .collect();
884
885        assert!(tables.contains(&"memories".to_string()));
886        assert!(tables.contains(&"knowledge_entities".to_string()));
887        assert!(tables.contains(&"knowledge_relations".to_string()));
888        assert!(tables.contains(&"bouts".to_string()));
889        assert!(tables.contains(&"messages".to_string()));
890        assert!(tables.contains(&"_punch_migrations".to_string()));
891    }
892
893    #[test]
894    fn test_legacy_migrate_idempotent() {
895        let conn = Connection::open_in_memory().unwrap();
896        conn.execute_batch("PRAGMA foreign_keys = ON;").unwrap();
897        migrate(&conn).unwrap();
898        migrate(&conn).unwrap();
899
900        // Should still be version 9.
901        let version: Option<u64> = conn
902            .query_row("SELECT MAX(version) FROM _punch_migrations", [], |row| {
903                row.get(0)
904            })
905            .unwrap();
906        assert_eq!(version.unwrap_or(0), 11);
907    }
908}