// punch_memory/migrations.rs

//! Database migration engine for the Punch memory substrate.
//!
//! Tracks applied migrations in a `_punch_migrations` table and supports
//! forward (up) and backward (down) migration with SHA-256 checksum
//! verification for integrity.

7use std::sync::Arc;
8
9use rusqlite::Connection;
10use sha2::{Digest, Sha256};
11use tracing::info;
12
13use punch_types::{PunchError, PunchResult};
14
// ---------------------------------------------------------------------------
// Public types
// ---------------------------------------------------------------------------

/// A single database migration with up and down SQL.
#[derive(Debug, Clone)]
pub struct Migration {
    /// Schema version this migration brings the database to; migrations are
    /// applied in ascending `version` order.
    pub version: u64,
    /// Human-readable label recorded in the `_punch_migrations` table.
    pub name: String,
    /// SQL batch executed when migrating up; its SHA-256 digest is stored as
    /// the migration's checksum.
    pub up_sql: String,
    /// SQL batch executed when rolling this migration back.
    pub down_sql: String,
}
27
28impl Migration {
29    /// Compute the SHA-256 hex digest of the `up_sql` content.
30    pub fn checksum(&self) -> String {
31        let mut hasher = Sha256::new();
32        hasher.update(self.up_sql.as_bytes());
33        format!("{:x}", hasher.finalize())
34    }
35}
36
/// Status of a single migration (applied or pending).
#[derive(Debug, Clone)]
pub struct MigrationStatus {
    /// Version of the migration this status describes.
    pub version: u64,
    /// Name copied from the corresponding [`Migration`].
    pub name: String,
    /// True when a row for this version exists in `_punch_migrations`.
    pub applied: bool,
    /// Timestamp recorded when the migration was applied; `None` if pending.
    pub applied_at: Option<String>,
}
45
/// The migration engine manages schema versioning for a SQLite database.
pub struct MigrationEngine {
    /// Shared database connection; the mutex is locked for the duration of
    /// each engine operation.
    conn: Arc<std::sync::Mutex<Connection>>,
}
51
52impl MigrationEngine {
53    /// Create a new engine and ensure the tracking table exists.
54    pub fn new(conn: Arc<std::sync::Mutex<Connection>>) -> PunchResult<Self> {
55        {
56            let c = conn
57                .lock()
58                .map_err(|e| PunchError::Memory(format!("failed to lock connection: {e}")))?;
59            c.execute_batch(
60                "CREATE TABLE IF NOT EXISTS _punch_migrations (
61                    id         INTEGER PRIMARY KEY,
62                    version    INTEGER NOT NULL UNIQUE,
63                    name       TEXT NOT NULL,
64                    applied_at TEXT NOT NULL,
65                    checksum   TEXT NOT NULL
66                );",
67            )
68            .map_err(|e| PunchError::Memory(format!("failed to create migrations table: {e}")))?;
69        }
70        Ok(Self { conn })
71    }
72
73    /// Return the highest applied migration version, or 0 if none.
74    pub fn current_version(&self) -> PunchResult<u64> {
75        let c = self
76            .conn
77            .lock()
78            .map_err(|e| PunchError::Memory(format!("failed to lock connection: {e}")))?;
79        let version: Option<u64> = c
80            .query_row("SELECT MAX(version) FROM _punch_migrations", [], |row| {
81                row.get(0)
82            })
83            .map_err(|e| PunchError::Memory(format!("failed to query current version: {e}")))?;
84        Ok(version.unwrap_or(0))
85    }
86
87    /// Return migrations from `all` that have not yet been applied, sorted by
88    /// version ascending.
89    pub fn pending_migrations<'a>(&self, all: &'a [Migration]) -> PunchResult<Vec<&'a Migration>> {
90        let applied = self.applied_versions()?;
91        let mut pending: Vec<&Migration> = all
92            .iter()
93            .filter(|m| !applied.contains(&m.version))
94            .collect();
95        pending.sort_by_key(|m| m.version);
96        Ok(pending)
97    }
98
99    /// Apply all pending migrations in order. Each migration runs inside its
100    /// own transaction. Returns the versions that were applied.
101    pub fn migrate_up(&self, migrations: &[Migration]) -> PunchResult<Vec<u64>> {
102        let pending = self.pending_migrations(migrations)?;
103        let mut applied = Vec::new();
104
105        for migration in pending {
106            let c = self
107                .conn
108                .lock()
109                .map_err(|e| PunchError::Memory(format!("failed to lock connection: {e}")))?;
110            let tx = c
111                .unchecked_transaction()
112                .map_err(|e| PunchError::Memory(format!("failed to begin transaction: {e}")))?;
113
114            tx.execute_batch(&migration.up_sql).map_err(|e| {
115                PunchError::Memory(format!(
116                    "migration v{} ({}) failed: {e}",
117                    migration.version, migration.name
118                ))
119            })?;
120
121            let now = chrono::Utc::now().format("%Y-%m-%dT%H:%M:%SZ").to_string();
122            tx.execute(
123                "INSERT INTO _punch_migrations (version, name, applied_at, checksum)
124                 VALUES (?1, ?2, ?3, ?4)",
125                rusqlite::params![migration.version, migration.name, now, migration.checksum(),],
126            )
127            .map_err(|e| {
128                PunchError::Memory(format!(
129                    "failed to record migration v{}: {e}",
130                    migration.version
131                ))
132            })?;
133
134            tx.commit().map_err(|e| {
135                PunchError::Memory(format!(
136                    "failed to commit migration v{}: {e}",
137                    migration.version
138                ))
139            })?;
140
141            info!(version = migration.version, name = %migration.name, "applied migration");
142            applied.push(migration.version);
143        }
144
145        Ok(applied)
146    }
147
148    /// Roll back applied migrations whose version is greater than
149    /// `target_version`, in reverse order. Returns the versions that were
150    /// rolled back.
151    pub fn migrate_down(
152        &self,
153        migrations: &[Migration],
154        target_version: u64,
155    ) -> PunchResult<Vec<u64>> {
156        let applied = self.applied_versions()?;
157
158        // Collect migrations that need to be rolled back, sorted descending.
159        let mut to_rollback: Vec<&Migration> = migrations
160            .iter()
161            .filter(|m| m.version > target_version && applied.contains(&m.version))
162            .collect();
163        to_rollback.sort_by(|a, b| b.version.cmp(&a.version));
164
165        let mut rolled_back = Vec::new();
166
167        for migration in to_rollback {
168            let c = self
169                .conn
170                .lock()
171                .map_err(|e| PunchError::Memory(format!("failed to lock connection: {e}")))?;
172            let tx = c
173                .unchecked_transaction()
174                .map_err(|e| PunchError::Memory(format!("failed to begin transaction: {e}")))?;
175
176            tx.execute_batch(&migration.down_sql).map_err(|e| {
177                PunchError::Memory(format!(
178                    "down migration v{} ({}) failed: {e}",
179                    migration.version, migration.name
180                ))
181            })?;
182
183            tx.execute(
184                "DELETE FROM _punch_migrations WHERE version = ?1",
185                [migration.version],
186            )
187            .map_err(|e| {
188                PunchError::Memory(format!(
189                    "failed to remove migration record v{}: {e}",
190                    migration.version
191                ))
192            })?;
193
194            tx.commit().map_err(|e| {
195                PunchError::Memory(format!(
196                    "failed to commit down migration v{}: {e}",
197                    migration.version
198                ))
199            })?;
200
201            info!(version = migration.version, name = %migration.name, "rolled back migration");
202            rolled_back.push(migration.version);
203        }
204
205        Ok(rolled_back)
206    }
207
208    /// Show the status (applied / pending) of every known migration.
209    pub fn migration_status(&self, migrations: &[Migration]) -> PunchResult<Vec<MigrationStatus>> {
210        let c = self
211            .conn
212            .lock()
213            .map_err(|e| PunchError::Memory(format!("failed to lock connection: {e}")))?;
214
215        let mut stmt = c
216            .prepare("SELECT version, applied_at FROM _punch_migrations")
217            .map_err(|e| PunchError::Memory(format!("failed to query migration status: {e}")))?;
218
219        let rows: Vec<(u64, String)> = stmt
220            .query_map([], |row| {
221                let version: u64 = row.get(0)?;
222                let applied_at: String = row.get(1)?;
223                Ok((version, applied_at))
224            })
225            .map_err(|e| PunchError::Memory(format!("failed to read migration rows: {e}")))?
226            .filter_map(|r| r.ok())
227            .collect();
228
229        let mut statuses: Vec<MigrationStatus> = migrations
230            .iter()
231            .map(|m| {
232                let applied_row = rows.iter().find(|(v, _)| *v == m.version);
233                MigrationStatus {
234                    version: m.version,
235                    name: m.name.clone(),
236                    applied: applied_row.is_some(),
237                    applied_at: applied_row.map(|(_, at)| at.clone()),
238                }
239            })
240            .collect();
241        statuses.sort_by_key(|s| s.version);
242        Ok(statuses)
243    }
244
245    /// Verify that every applied migration's stored checksum matches the
246    /// current `up_sql` content. Returns an error on the first mismatch.
247    pub fn verify_checksums(&self, migrations: &[Migration]) -> PunchResult<()> {
248        let c = self
249            .conn
250            .lock()
251            .map_err(|e| PunchError::Memory(format!("failed to lock connection: {e}")))?;
252
253        let mut stmt = c
254            .prepare("SELECT version, checksum FROM _punch_migrations")
255            .map_err(|e| PunchError::Memory(format!("failed to query checksums: {e}")))?;
256
257        let rows: Vec<(u64, String)> = stmt
258            .query_map([], |row| {
259                let version: u64 = row.get(0)?;
260                let checksum: String = row.get(1)?;
261                Ok((version, checksum))
262            })
263            .map_err(|e| PunchError::Memory(format!("failed to read checksum rows: {e}")))?
264            .filter_map(|r| r.ok())
265            .collect();
266
267        for (version, stored_checksum) in &rows {
268            if let Some(migration) = migrations.iter().find(|m| m.version == *version) {
269                let current_checksum = migration.checksum();
270                if *stored_checksum != current_checksum {
271                    return Err(PunchError::Memory(format!(
272                        "checksum mismatch for migration v{} ({}): stored={}, current={}",
273                        version, migration.name, stored_checksum, current_checksum
274                    )));
275                }
276            }
277        }
278
279        Ok(())
280    }
281
    /// Return the built-in migrations (versions 1 through 12) that define the
    /// Punch schema. Versions are applied in ascending order by
    /// [`MigrationEngine::migrate_up`].
    pub fn builtin_migrations() -> Vec<Migration> {
        vec![
            Migration {
                version: 1,
                name: "create_memories_table".into(),
                up_sql: "CREATE TABLE IF NOT EXISTS memories (
                    id          INTEGER PRIMARY KEY AUTOINCREMENT,
                    fighter_id  TEXT NOT NULL,
                    key         TEXT NOT NULL,
                    value       TEXT NOT NULL,
                    confidence  REAL NOT NULL DEFAULT 1.0,
                    created_at  TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%SZ', 'now')),
                    accessed_at TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%SZ', 'now')),
                    UNIQUE(fighter_id, key)
                );"
                .into(),
                down_sql: "DROP TABLE IF EXISTS memories;".into(),
            },
            Migration {
                version: 2,
                name: "create_entities_table".into(),
                up_sql: "CREATE TABLE IF NOT EXISTS knowledge_entities (
                    id          INTEGER PRIMARY KEY AUTOINCREMENT,
                    fighter_id  TEXT NOT NULL,
                    name        TEXT NOT NULL,
                    entity_type TEXT NOT NULL,
                    properties  TEXT NOT NULL DEFAULT '{}',
                    created_at  TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%SZ', 'now')),
                    UNIQUE(fighter_id, name, entity_type)
                );"
                .into(),
                down_sql: "DROP TABLE IF EXISTS knowledge_entities;".into(),
            },
            Migration {
                version: 3,
                name: "create_relations_table".into(),
                up_sql: "CREATE TABLE IF NOT EXISTS knowledge_relations (
                    id          INTEGER PRIMARY KEY AUTOINCREMENT,
                    fighter_id  TEXT NOT NULL,
                    from_entity TEXT NOT NULL,
                    relation    TEXT NOT NULL,
                    to_entity   TEXT NOT NULL,
                    properties  TEXT NOT NULL DEFAULT '{}',
                    created_at  TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%SZ', 'now')),
                    UNIQUE(fighter_id, from_entity, relation, to_entity)
                );"
                .into(),
                down_sql: "DROP TABLE IF EXISTS knowledge_relations;".into(),
            },
            Migration {
                version: 4,
                name: "create_bouts_table".into(),
                up_sql: "CREATE TABLE IF NOT EXISTS bouts (
                    id          TEXT PRIMARY KEY,
                    fighter_id  TEXT NOT NULL,
                    title       TEXT,
                    created_at  TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%SZ', 'now')),
                    updated_at  TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%SZ', 'now'))
                );"
                .into(),
                down_sql: "DROP TABLE IF EXISTS bouts;".into(),
            },
            Migration {
                version: 5,
                name: "create_bout_messages_table".into(),
                up_sql: "CREATE TABLE IF NOT EXISTS messages (
                    id          INTEGER PRIMARY KEY AUTOINCREMENT,
                    bout_id     TEXT NOT NULL,
                    role        TEXT NOT NULL,
                    content     TEXT NOT NULL DEFAULT '',
                    metadata    TEXT,
                    created_at  TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%SZ', 'now'))
                );"
                .into(),
                down_sql: "DROP TABLE IF EXISTS messages;".into(),
            },
            Migration {
                version: 6,
                name: "add_indexes".into(),
                up_sql: "
                    CREATE INDEX IF NOT EXISTS idx_memories_fighter ON memories(fighter_id);
                    CREATE INDEX IF NOT EXISTS idx_ke_fighter ON knowledge_entities(fighter_id);
                    CREATE INDEX IF NOT EXISTS idx_kr_fighter ON knowledge_relations(fighter_id);
                    CREATE INDEX IF NOT EXISTS idx_bouts_fighter ON bouts(fighter_id);
                    CREATE INDEX IF NOT EXISTS idx_messages_bout ON messages(bout_id);
                "
                .into(),
                down_sql: "
                    DROP INDEX IF EXISTS idx_memories_fighter;
                    DROP INDEX IF EXISTS idx_ke_fighter;
                    DROP INDEX IF EXISTS idx_kr_fighter;
                    DROP INDEX IF EXISTS idx_bouts_fighter;
                    DROP INDEX IF EXISTS idx_messages_bout;
                "
                .into(),
            },
            Migration {
                version: 7,
                name: "create_fighters_table".into(),
                up_sql: "CREATE TABLE IF NOT EXISTS fighters (
                    id          TEXT PRIMARY KEY,
                    name        TEXT NOT NULL,
                    manifest    TEXT NOT NULL,
                    status      TEXT NOT NULL DEFAULT 'idle',
                    created_at  TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%SZ', 'now')),
                    updated_at  TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%SZ', 'now'))
                );"
                .into(),
                down_sql: "DROP TABLE IF EXISTS fighters;".into(),
            },
            Migration {
                version: 8,
                name: "create_usage_events_table".into(),
                up_sql: "CREATE TABLE IF NOT EXISTS usage_events (
                    id              INTEGER PRIMARY KEY AUTOINCREMENT,
                    fighter_id      TEXT NOT NULL,
                    model           TEXT NOT NULL,
                    input_tokens    INTEGER NOT NULL,
                    output_tokens   INTEGER NOT NULL,
                    cost_usd        REAL NOT NULL DEFAULT 0.0,
                    created_at      TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%SZ', 'now'))
                );
                CREATE INDEX IF NOT EXISTS idx_usage_fighter ON usage_events(fighter_id);"
                    .into(),
                down_sql: "DROP INDEX IF EXISTS idx_usage_fighter;
                DROP TABLE IF EXISTS usage_events;"
                    .into(),
            },
            Migration {
                version: 9,
                name: "create_gorilla_state_table".into(),
                up_sql: "CREATE TABLE IF NOT EXISTS gorilla_state (
                    gorilla_id  TEXT PRIMARY KEY,
                    state       TEXT NOT NULL DEFAULT '{}',
                    updated_at  TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%SZ', 'now'))
                );"
                .into(),
                down_sql: "DROP TABLE IF EXISTS gorilla_state;".into(),
            },
            Migration {
                version: 10,
                name: "create_embeddings_table".into(),
                up_sql: "CREATE TABLE IF NOT EXISTS embeddings (
                    id         TEXT PRIMARY KEY,
                    text       TEXT NOT NULL,
                    vector     BLOB NOT NULL,
                    metadata   TEXT NOT NULL DEFAULT '{}',
                    created_at TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%SZ', 'now'))
                );"
                .into(),
                down_sql: "DROP TABLE IF EXISTS embeddings;".into(),
            },
            Migration {
                version: 11,
                name: "create_creeds_table".into(),
                up_sql: "CREATE TABLE IF NOT EXISTS creeds (
                    id          TEXT PRIMARY KEY,
                    fighter_name TEXT NOT NULL,
                    fighter_id  TEXT,
                    creed_data  TEXT NOT NULL,
                    version     INTEGER NOT NULL DEFAULT 1,
                    bout_count  INTEGER NOT NULL DEFAULT 0,
                    message_count INTEGER NOT NULL DEFAULT 0,
                    created_at  TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%SZ', 'now')),
                    updated_at  TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%SZ', 'now'))
                );
                CREATE UNIQUE INDEX IF NOT EXISTS idx_creeds_fighter_name ON creeds(fighter_name);
                CREATE INDEX IF NOT EXISTS idx_creeds_fighter_id ON creeds(fighter_id);"
                    .into(),
                down_sql: "DROP INDEX IF EXISTS idx_creeds_fighter_id;
                DROP INDEX IF EXISTS idx_creeds_fighter_name;
                DROP TABLE IF EXISTS creeds;"
                    .into(),
            },
            Migration {
                version: 12,
                name: "create_channels_table".into(),
                up_sql: "CREATE TABLE IF NOT EXISTS channels (
                    id              TEXT PRIMARY KEY,
                    name            TEXT NOT NULL UNIQUE,
                    platform        TEXT NOT NULL,
                    credentials     TEXT NOT NULL DEFAULT '{}',
                    settings        TEXT NOT NULL DEFAULT '{}',
                    status          TEXT NOT NULL DEFAULT 'disconnected',
                    validated_at    TEXT,
                    created_at      TEXT NOT NULL,
                    updated_at      TEXT NOT NULL
                );
                CREATE INDEX IF NOT EXISTS idx_channels_platform ON channels(platform);"
                    .into(),
                down_sql: "DROP INDEX IF EXISTS idx_channels_platform;
                DROP TABLE IF EXISTS channels;"
                    .into(),
            },
        ]
    }
479
480    // -----------------------------------------------------------------------
481    // Internal helpers
482    // -----------------------------------------------------------------------
483
484    fn applied_versions(&self) -> PunchResult<Vec<u64>> {
485        let c = self
486            .conn
487            .lock()
488            .map_err(|e| PunchError::Memory(format!("failed to lock connection: {e}")))?;
489        let mut stmt = c
490            .prepare("SELECT version FROM _punch_migrations ORDER BY version")
491            .map_err(|e| PunchError::Memory(format!("failed to query applied versions: {e}")))?;
492
493        let versions: Vec<u64> = stmt
494            .query_map([], |row| row.get(0))
495            .map_err(|e| PunchError::Memory(format!("failed to read version rows: {e}")))?
496            .filter_map(|r| r.ok())
497            .collect();
498
499        Ok(versions)
500    }
501}
502
// ---------------------------------------------------------------------------
// Legacy entry point — used by substrate.rs
// ---------------------------------------------------------------------------
506
507/// Run all pending built-in migrations against `conn`.
508///
509/// This is the entry point called from [`crate::substrate::MemorySubstrate`]
510/// during initialisation. It also handles migration from the old `_punch_meta`
511/// version-tracking table if present.
512pub fn migrate(conn: &Connection) -> PunchResult<()> {
513    // If the old _punch_meta table exists, drop it — the new engine tracks
514    // state in _punch_migrations.
515    conn.execute_batch("DROP TABLE IF EXISTS _punch_meta;")
516        .map_err(|e| PunchError::Memory(format!("failed to drop legacy meta table: {e}")))?;
517
518    // Create the tracking table.
519    conn.execute_batch(
520        "CREATE TABLE IF NOT EXISTS _punch_migrations (
521            id         INTEGER PRIMARY KEY,
522            version    INTEGER NOT NULL UNIQUE,
523            name       TEXT NOT NULL,
524            applied_at TEXT NOT NULL,
525            checksum   TEXT NOT NULL
526        );",
527    )
528    .map_err(|e| PunchError::Memory(format!("failed to create migrations table: {e}")))?;
529
530    // Determine which versions have already been applied.
531    let applied_versions = {
532        let mut stmt = conn
533            .prepare("SELECT version FROM _punch_migrations ORDER BY version")
534            .map_err(|e| PunchError::Memory(format!("failed to query applied versions: {e}")))?;
535        let versions: Vec<u64> = stmt
536            .query_map([], |row| row.get(0))
537            .map_err(|e| PunchError::Memory(format!("failed to read version rows: {e}")))?
538            .filter_map(|r| r.ok())
539            .collect();
540        versions
541    };
542
543    let builtins = MigrationEngine::builtin_migrations();
544    let mut count = 0usize;
545
546    for migration in &builtins {
547        if applied_versions.contains(&migration.version) {
548            continue;
549        }
550
551        let tx = conn
552            .unchecked_transaction()
553            .map_err(|e| PunchError::Memory(format!("failed to begin transaction: {e}")))?;
554
555        tx.execute_batch(&migration.up_sql).map_err(|e| {
556            PunchError::Memory(format!(
557                "migration v{} ({}) failed: {e}",
558                migration.version, migration.name
559            ))
560        })?;
561
562        let now = chrono::Utc::now().format("%Y-%m-%dT%H:%M:%SZ").to_string();
563        tx.execute(
564            "INSERT INTO _punch_migrations (version, name, applied_at, checksum)
565             VALUES (?1, ?2, ?3, ?4)",
566            rusqlite::params![migration.version, migration.name, now, migration.checksum(),],
567        )
568        .map_err(|e| {
569            PunchError::Memory(format!(
570                "failed to record migration v{}: {e}",
571                migration.version
572            ))
573        })?;
574
575        tx.commit().map_err(|e| {
576            PunchError::Memory(format!(
577                "failed to commit migration v{}: {e}",
578                migration.version
579            ))
580        })?;
581
582        info!(version = migration.version, name = %migration.name, "applied migration");
583        count += 1;
584    }
585
586    if count > 0 {
587        info!(count, "migrations applied");
588    }
589
590    Ok(())
591}
592
// ---------------------------------------------------------------------------
// Tests
// ---------------------------------------------------------------------------
596
597#[cfg(test)]
598mod tests {
599    use super::*;
600
601    fn test_engine() -> (MigrationEngine, Arc<std::sync::Mutex<Connection>>) {
602        let conn = Connection::open_in_memory().unwrap();
603        conn.execute_batch("PRAGMA foreign_keys = ON;").unwrap();
604        let arc = Arc::new(std::sync::Mutex::new(conn));
605        let engine = MigrationEngine::new(Arc::clone(&arc)).unwrap();
606        (engine, arc)
607    }
608
609    fn simple_migrations() -> Vec<Migration> {
610        vec![
611            Migration {
612                version: 1,
613                name: "create_alpha".into(),
614                up_sql: "CREATE TABLE alpha (id INTEGER PRIMARY KEY, name TEXT);".into(),
615                down_sql: "DROP TABLE IF EXISTS alpha;".into(),
616            },
617            Migration {
618                version: 2,
619                name: "create_beta".into(),
620                up_sql: "CREATE TABLE beta (id INTEGER PRIMARY KEY, value TEXT);".into(),
621                down_sql: "DROP TABLE IF EXISTS beta;".into(),
622            },
623            Migration {
624                version: 3,
625                name: "create_gamma".into(),
626                up_sql: "CREATE TABLE gamma (id INTEGER PRIMARY KEY, score REAL);".into(),
627                down_sql: "DROP TABLE IF EXISTS gamma;".into(),
628            },
629        ]
630    }
631
632    #[test]
633    fn test_migration_table_creation() {
634        let (engine, arc) = test_engine();
635        // The tracking table should exist after new().
636        {
637            let c = arc.lock().unwrap();
638            let count: i64 = c
639                .query_row(
640                    "SELECT COUNT(*) FROM sqlite_master WHERE type='table' AND name='_punch_migrations'",
641                    [],
642                    |row| row.get(0),
643                )
644                .unwrap();
645            assert_eq!(count, 1);
646        }
647        // No migrations applied yet.
648        assert_eq!(engine.current_version().unwrap(), 0);
649    }
650
651    #[test]
652    fn test_apply_single_migration() {
653        let (engine, arc) = test_engine();
654        let migrations = vec![simple_migrations().remove(0)];
655        let applied = engine.migrate_up(&migrations).unwrap();
656        assert_eq!(applied, vec![1]);
657
658        let c = arc.lock().unwrap();
659        let count: i64 = c
660            .query_row(
661                "SELECT COUNT(*) FROM sqlite_master WHERE type='table' AND name='alpha'",
662                [],
663                |row| row.get(0),
664            )
665            .unwrap();
666        assert_eq!(count, 1);
667    }
668
669    #[test]
670    fn test_apply_multiple_migrations_in_order() {
671        let (engine, _arc) = test_engine();
672        let migrations = simple_migrations();
673        let applied = engine.migrate_up(&migrations).unwrap();
674        assert_eq!(applied, vec![1, 2, 3]);
675        assert_eq!(engine.current_version().unwrap(), 3);
676    }
677
678    #[test]
679    fn test_skip_already_applied_migrations() {
680        let (engine, _arc) = test_engine();
681        let migrations = simple_migrations();
682
683        engine.migrate_up(&migrations).unwrap();
684        let applied_again = engine.migrate_up(&migrations).unwrap();
685        assert!(applied_again.is_empty());
686    }
687
688    #[test]
689    fn test_rollback_to_specific_version() {
690        let (engine, arc) = test_engine();
691        let migrations = simple_migrations();
692
693        engine.migrate_up(&migrations).unwrap();
694        assert_eq!(engine.current_version().unwrap(), 3);
695
696        let rolled_back = engine.migrate_down(&migrations, 1).unwrap();
697        assert_eq!(rolled_back, vec![3, 2]);
698        assert_eq!(engine.current_version().unwrap(), 1);
699
700        // gamma and beta tables should be gone.
701        let c = arc.lock().unwrap();
702        let tables: Vec<String> = c
703            .prepare("SELECT name FROM sqlite_master WHERE type='table' AND name IN ('alpha','beta','gamma')")
704            .unwrap()
705            .query_map([], |row| row.get(0))
706            .unwrap()
707            .filter_map(|r| r.ok())
708            .collect();
709        assert_eq!(tables, vec!["alpha".to_string()]);
710    }
711
712    #[test]
713    fn test_current_version_tracking() {
714        let (engine, _arc) = test_engine();
715        assert_eq!(engine.current_version().unwrap(), 0);
716
717        let migrations = simple_migrations();
718        engine.migrate_up(&migrations[..1]).unwrap();
719        assert_eq!(engine.current_version().unwrap(), 1);
720
721        engine.migrate_up(&migrations).unwrap();
722        assert_eq!(engine.current_version().unwrap(), 3);
723    }
724
725    #[test]
726    fn test_pending_migration_detection() {
727        let (engine, _arc) = test_engine();
728        let migrations = simple_migrations();
729
730        let pending = engine.pending_migrations(&migrations).unwrap();
731        assert_eq!(pending.len(), 3);
732
733        engine.migrate_up(&migrations[..2]).unwrap();
734
735        let pending = engine.pending_migrations(&migrations).unwrap();
736        assert_eq!(pending.len(), 1);
737        assert_eq!(pending[0].version, 3);
738    }
739
740    #[test]
741    fn test_checksum_verification_passes() {
742        let (engine, _arc) = test_engine();
743        let migrations = simple_migrations();
744        engine.migrate_up(&migrations).unwrap();
745
746        // Verify with the same migrations — should succeed.
747        engine.verify_checksums(&migrations).unwrap();
748    }
749
750    #[test]
751    fn test_checksum_verification_fails_for_tampered() {
752        let (engine, _arc) = test_engine();
753        let migrations = simple_migrations();
754        engine.migrate_up(&migrations).unwrap();
755
756        // Tamper with a migration's up_sql.
757        let mut tampered = simple_migrations();
758        tampered[0].up_sql =
759            "CREATE TABLE alpha (id INTEGER PRIMARY KEY, name TEXT, extra TEXT);".into();
760
761        let result = engine.verify_checksums(&tampered);
762        assert!(result.is_err());
763        let err_msg = format!("{}", result.unwrap_err());
764        assert!(err_msg.contains("checksum mismatch"));
765    }
766
767    #[test]
768    fn test_migration_status_listing() {
769        let (engine, _arc) = test_engine();
770        let migrations = simple_migrations();
771        engine.migrate_up(&migrations[..2]).unwrap();
772
773        let statuses = engine.migration_status(&migrations).unwrap();
774        assert_eq!(statuses.len(), 3);
775
776        assert!(statuses[0].applied);
777        assert!(statuses[0].applied_at.is_some());
778        assert_eq!(statuses[0].version, 1);
779
780        assert!(statuses[1].applied);
781        assert_eq!(statuses[1].version, 2);
782
783        assert!(!statuses[2].applied);
784        assert!(statuses[2].applied_at.is_none());
785        assert_eq!(statuses[2].version, 3);
786    }
787
788    #[test]
789    fn test_builtin_migrations_are_valid_sql() {
790        let (engine, _arc) = test_engine();
791        let builtins = MigrationEngine::builtin_migrations();
792
793        // All built-in migrations should apply without error.
794        let applied = engine.migrate_up(&builtins).unwrap();
795        assert_eq!(applied.len(), 12);
796        assert_eq!(engine.current_version().unwrap(), 12);
797    }
798
799    #[test]
800    fn test_idempotent_migrate_up() {
801        let (engine, _arc) = test_engine();
802        let migrations = simple_migrations();
803
804        let first = engine.migrate_up(&migrations).unwrap();
805        assert_eq!(first.len(), 3);
806
807        let second = engine.migrate_up(&migrations).unwrap();
808        assert!(second.is_empty());
809
810        // State unchanged.
811        assert_eq!(engine.current_version().unwrap(), 3);
812    }
813
814    #[test]
815    fn test_transaction_rollback_on_sql_error() {
816        let (engine, _arc) = test_engine();
817
818        let bad_migrations = vec![
819            Migration {
820                version: 1,
821                name: "good".into(),
822                up_sql: "CREATE TABLE good (id INTEGER PRIMARY KEY);".into(),
823                down_sql: "DROP TABLE IF EXISTS good;".into(),
824            },
825            Migration {
826                version: 2,
827                name: "bad".into(),
828                up_sql: "THIS IS NOT VALID SQL;".into(),
829                down_sql: "SELECT 1;".into(),
830            },
831        ];
832
833        // First migration succeeds, second fails.
834        let result = engine.migrate_up(&bad_migrations);
835        assert!(result.is_err());
836
837        // Only version 1 should be applied.
838        assert_eq!(engine.current_version().unwrap(), 1);
839    }
840
841    #[test]
842    fn test_down_migration_ordering_reverse() {
843        let (engine, _arc) = test_engine();
844        let migrations = simple_migrations();
845
846        engine.migrate_up(&migrations).unwrap();
847
848        // Rolling back to 0 should go 3, 2, 1.
849        let rolled = engine.migrate_down(&migrations, 0).unwrap();
850        assert_eq!(rolled, vec![3, 2, 1]);
851        assert_eq!(engine.current_version().unwrap(), 0);
852    }
853
854    #[test]
855    fn test_empty_migration_list_handling() {
856        let (engine, _arc) = test_engine();
857        let empty: Vec<Migration> = vec![];
858
859        let applied = engine.migrate_up(&empty).unwrap();
860        assert!(applied.is_empty());
861
862        let pending = engine.pending_migrations(&empty).unwrap();
863        assert!(pending.is_empty());
864
865        let statuses = engine.migration_status(&empty).unwrap();
866        assert!(statuses.is_empty());
867
868        let rolled = engine.migrate_down(&empty, 0).unwrap();
869        assert!(rolled.is_empty());
870
871        engine.verify_checksums(&empty).unwrap();
872    }
873
874    #[test]
875    fn test_checksum_deterministic() {
876        let m = Migration {
877            version: 1,
878            name: "test".into(),
879            up_sql: "CREATE TABLE test (id INTEGER);".into(),
880            down_sql: "DROP TABLE test;".into(),
881        };
882        let c1 = m.checksum();
883        let c2 = m.checksum();
884        assert_eq!(c1, c2);
885        assert_eq!(c1.len(), 64); // SHA-256 hex is 64 chars
886    }
887
888    #[test]
889    fn test_legacy_migrate_function() {
890        let conn = Connection::open_in_memory().unwrap();
891        conn.execute_batch("PRAGMA foreign_keys = ON;").unwrap();
892
893        // The legacy migrate() function should work.
894        migrate(&conn).unwrap();
895
896        // Core tables from built-in migrations should exist.
897        let tables: Vec<String> = conn
898            .prepare("SELECT name FROM sqlite_master WHERE type='table' ORDER BY name")
899            .unwrap()
900            .query_map([], |row| row.get(0))
901            .unwrap()
902            .filter_map(|r| r.ok())
903            .collect();
904
905        assert!(tables.contains(&"memories".to_string()));
906        assert!(tables.contains(&"knowledge_entities".to_string()));
907        assert!(tables.contains(&"knowledge_relations".to_string()));
908        assert!(tables.contains(&"bouts".to_string()));
909        assert!(tables.contains(&"messages".to_string()));
910        assert!(tables.contains(&"_punch_migrations".to_string()));
911    }
912
913    #[test]
914    fn test_legacy_migrate_idempotent() {
915        let conn = Connection::open_in_memory().unwrap();
916        conn.execute_batch("PRAGMA foreign_keys = ON;").unwrap();
917        migrate(&conn).unwrap();
918        migrate(&conn).unwrap();
919
920        // Should still be version 9.
921        let version: Option<u64> = conn
922            .query_row("SELECT MAX(version) FROM _punch_migrations", [], |row| {
923                row.get(0)
924            })
925            .unwrap();
926        assert_eq!(version.unwrap_or(0), 12);
927    }
928}