// cp_graph/migrations.rs
//! Database migration system for CP
//!
//! Per CP-002: Provides versioned schema migrations with automatic upgrade.
//! Migrations are embedded in the binary (via `include_str!`) and applied in
//! ascending version order, each in its own transaction.

use cp_core::{CPError, Result};
use rusqlite::Connection;
use tracing::{info, warn};
9
/// One schema migration embedded in the binary.
#[derive(Debug, Clone, Copy)]
struct Migration {
    /// Version stored in `schema_version` once this migration is applied.
    /// Migrations whose version exceeds the database's current one are run.
    version: u32,
    /// Human-readable label; logged while running and stored next to the version.
    name: &'static str,
    /// SQL script executed as a single batch inside the migration's transaction.
    sql: &'static str,
}
16
17/// All migrations in order
18const MIGRATIONS: &[Migration] = &[
19    Migration {
20        version: 1,
21        name: "initial_schema",
22        sql: include_str!("migrations/001_initial.sql"),
23    },
24    Migration {
25        version: 2,
26        name: "add_timestamps",
27        sql: include_str!("migrations/002_add_timestamps.sql"),
28    },
29    Migration {
30        version: 3,
31        name: "add_l2_norm",
32        sql: include_str!("migrations/003_add_l2_norm.sql"),
33    },
34    Migration {
35        version: 4,
36        name: "add_path_id_and_embedding_version",
37        sql: include_str!("migrations/004_add_path_id_and_embedding_version.sql"),
38    },
39    Migration {
40        version: 5,
41        name: "add_arweave_tx",
42        sql: include_str!("migrations/005_add_arweave_tx.sql"),
43    },
44    Migration {
45        version: 6,
46        name: "spec_compliance",
47        sql: include_str!("migrations/006_spec_compliance.sql"),
48    },
49];
50
51/// Run all pending migrations on the database
52pub fn run_migrations(conn: &Connection) -> Result<()> {
53    // Create schema_version table if it doesn't exist
54    conn.execute_batch(
55        r"
56        CREATE TABLE IF NOT EXISTS schema_version (
57            version INTEGER PRIMARY KEY,
58            name TEXT NOT NULL,
59            applied_at INTEGER NOT NULL DEFAULT (strftime('%s', 'now'))
60        );
61        ",
62    )
63    .map_err(|e| CPError::Database(format!("Failed to create schema_version table: {e}")))?;
64
65    // Get current version
66    let current_version: u32 = conn
67        .query_row(
68            "SELECT COALESCE(MAX(version), 0) FROM schema_version",
69            [],
70            |row| row.get(0),
71        )
72        .map_err(|e| CPError::Database(format!("Failed to get schema version: {e}")))?;
73
74    info!("Current schema version: {}", current_version);
75
76    // Run pending migrations
77    for migration in MIGRATIONS {
78        if migration.version > current_version {
79            info!(
80                "Running migration {}: {}",
81                migration.version, migration.name
82            );
83
84            // Run migration in a transaction
85            let tx = conn
86                .unchecked_transaction()
87                .map_err(|e| CPError::Database(format!("Failed to start transaction: {e}")))?;
88
89            tx.execute_batch(migration.sql).map_err(|e| {
90                CPError::Database(format!(
91                    "Migration {} ({}) failed: {}",
92                    migration.version, migration.name, e
93                ))
94            })?;
95
96            // Record migration
97            tx.execute(
98                "INSERT INTO schema_version (version, name) VALUES (?1, ?2)",
99                rusqlite::params![migration.version, migration.name],
100            )
101            .map_err(|e| CPError::Database(format!("Failed to record migration: {e}")))?;
102
103            tx.commit()
104                .map_err(|e| CPError::Database(format!("Failed to commit migration: {e}")))?;
105
106            info!("Migration {} complete", migration.version);
107        }
108    }
109
110    info!("All migrations complete");
111    Ok(())
112}
113
114/// Get the current schema version
115pub fn get_schema_version(conn: &Connection) -> Result<u32> {
116    // Check if schema_version table exists
117    let table_exists: bool = conn
118        .query_row(
119            "SELECT EXISTS(SELECT 1 FROM sqlite_master WHERE type='table' AND name='schema_version')",
120            [],
121            |row| row.get(0),
122        )
123        .map_err(|e| CPError::Database(e.to_string()))?;
124
125    if !table_exists {
126        return Ok(0);
127    }
128
129    conn.query_row(
130        "SELECT COALESCE(MAX(version), 0) FROM schema_version",
131        [],
132        |row| row.get(0),
133    )
134    .map_err(|e| CPError::Database(e.to_string()))
135}
136
137/// Check if the database needs migration
138pub fn needs_migration(conn: &Connection) -> Result<bool> {
139    let current = get_schema_version(conn)?;
140    let latest = MIGRATIONS.last().map_or(0, |m| m.version);
141    Ok(current < latest)
142}
143
#[cfg(test)]
mod tests {
    use super::*;

    /// Version of the last embedded migration: what a fully migrated
    /// database should report. Avoids hard-coding the count in every test.
    fn latest_version() -> u32 {
        MIGRATIONS.last().map_or(0, |m| m.version)
    }

    /// A fresh in-memory database with every migration applied.
    fn migrated_conn() -> Connection {
        let conn = Connection::open_in_memory().unwrap();
        run_migrations(&conn).unwrap();
        conn
    }

    /// Whether `table` has a column named `column`.
    fn has_column(conn: &Connection, table: &str, column: &str) -> bool {
        conn.query_row(
            "SELECT COUNT(*) > 0 FROM pragma_table_info(?1) WHERE name = ?2",
            rusqlite::params![table, column],
            |row| row.get(0),
        )
        .unwrap()
    }

    /// Whether `sqlite_master` lists an object of the given `kind`
    /// (`table`, `index`, `trigger`, ...) named `name`.
    fn has_schema_object(conn: &Connection, kind: &str, name: &str) -> bool {
        conn.query_row(
            "SELECT COUNT(*) > 0 FROM sqlite_master WHERE type = ?1 AND name = ?2",
            rusqlite::params![kind, name],
            |row| row.get(0),
        )
        .unwrap()
    }

    /// Whether `table` declares at least one foreign key.
    fn has_foreign_key(conn: &Connection, table: &str) -> bool {
        conn.query_row(
            "SELECT COUNT(*) > 0 FROM pragma_foreign_key_list(?1)",
            [table],
            |row| row.get(0),
        )
        .unwrap()
    }

    /// Insert a minimal `documents` row with the given id.
    fn insert_document(conn: &Connection, doc_id: &uuid::Uuid) {
        conn.execute(
            "INSERT INTO documents (id, path, hash, hierarchical_hash, mtime, size, mime_type) VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7)",
            rusqlite::params![
                doc_id.as_bytes(),
                "test.md",
                [0u8; 32].as_slice(),
                [0u8; 32].as_slice(),
                0i64,
                0i64,
                "text/markdown"
            ],
        )
        .unwrap();
    }

    #[test]
    fn test_migration_versions_strictly_increasing() {
        assert!(MIGRATIONS
            .windows(2)
            .all(|pair| pair[0].version < pair[1].version));
    }

    #[test]
    fn test_migrations_run_idempotent() {
        let conn = Connection::open_in_memory().unwrap();

        // Running twice must not fail or re-apply anything.
        run_migrations(&conn).unwrap();
        run_migrations(&conn).unwrap();

        assert_eq!(get_schema_version(&conn).unwrap(), latest_version());
    }

    #[test]
    fn test_schema_version_tracking() {
        let conn = Connection::open_in_memory().unwrap();

        assert_eq!(get_schema_version(&conn).unwrap(), 0);
        assert!(needs_migration(&conn).unwrap());

        run_migrations(&conn).unwrap();

        assert_eq!(get_schema_version(&conn).unwrap(), latest_version());
        assert!(!needs_migration(&conn).unwrap());
    }

    #[test]
    fn test_newer_database_version_is_left_alone() {
        let conn = migrated_conn();
        let future = latest_version() + 1;
        conn.execute(
            "INSERT INTO schema_version (version, name) VALUES (?1, 'from_the_future')",
            [future],
        )
        .unwrap();

        assert!(!needs_migration(&conn).unwrap());
        run_migrations(&conn).unwrap();
        assert_eq!(get_schema_version(&conn).unwrap(), future);
    }

    #[test]
    fn test_timestamps_exist() {
        let conn = migrated_conn();
        assert!(has_column(&conn, "documents", "created_at"));
        assert!(has_column(&conn, "documents", "updated_at"));
    }

    #[test]
    fn test_l2_norm_column_exists() {
        let conn = migrated_conn();
        assert!(has_column(&conn, "embeddings", "l2_norm"));
    }

    #[test]
    fn test_migration_runner_initial_schema() {
        let conn = migrated_conn();
        for table in ["documents", "chunks", "embeddings", "edges", "state_roots"] {
            assert!(
                has_schema_object(&conn, "table", table),
                "missing table {table}"
            );
        }
    }

    #[test]
    fn test_migration_already_applied() {
        let conn = migrated_conn();
        assert_eq!(get_schema_version(&conn).unwrap(), latest_version());

        // Second run must skip everything already applied.
        run_migrations(&conn).unwrap();
        assert_eq!(get_schema_version(&conn).unwrap(), latest_version());
    }

    #[test]
    fn test_migration_001_documents_table() {
        let conn = migrated_conn();
        for column in ["id", "path", "hash", "mtime", "size", "mime_type"] {
            assert!(
                has_column(&conn, "documents", column),
                "documents is missing column {column}"
            );
        }
    }

    #[test]
    fn test_migration_002_timestamps() {
        let conn = migrated_conn();
        for (table, column) in [
            ("documents", "created_at"),
            ("documents", "updated_at"),
            ("chunks", "created_at"),
            ("embeddings", "created_at"),
        ] {
            assert!(
                has_column(&conn, table, column),
                "{table} is missing column {column}"
            );
        }
    }

    #[test]
    fn test_migration_003_l2_norm() {
        let conn = migrated_conn();
        assert!(has_column(&conn, "embeddings", "l2_norm"));
    }

    #[test]
    fn test_migration_004_path_id_embedding_version() {
        let conn = migrated_conn();
        assert!(has_column(&conn, "documents", "path_id"));
        assert!(has_column(&conn, "embeddings", "embedding_version"));
    }

    #[test]
    fn test_migration_foreign_keys() {
        let conn = migrated_conn();
        conn.execute("PRAGMA foreign_keys = ON", []).unwrap();

        // A plain document insert must still succeed with enforcement on.
        insert_document(&conn, &uuid::Uuid::new_v4());

        assert!(has_foreign_key(&conn, "chunks"));
        assert!(has_foreign_key(&conn, "embeddings"));
    }

    #[test]
    fn test_migration_fts_triggers() {
        let conn = migrated_conn();

        // Virtual tables are listed in sqlite_master with type 'table'.
        assert!(has_schema_object(&conn, "table", "fts_chunks"));
        for trigger in ["chunks_ai", "chunks_ad", "chunks_au"] {
            assert!(
                has_schema_object(&conn, "trigger", trigger),
                "missing trigger {trigger}"
            );
        }
    }

    #[test]
    fn test_migration_fts_content_sync() {
        let conn = migrated_conn();

        let doc_id = uuid::Uuid::new_v4();
        insert_document(&conn, &doc_id);

        let chunk_id = uuid::Uuid::new_v4();
        conn.execute(
            "INSERT INTO chunks (id, doc_id, text, byte_offset, byte_length, sequence, text_hash) VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7)",
            rusqlite::params![
                chunk_id.as_bytes(),
                doc_id.as_bytes(),
                "test content for search",
                0i64,
                0i64,
                0u32,
                [0u8; 32].as_slice()
            ],
        )
        .unwrap();

        // The insert trigger should have mirrored the chunk text into FTS.
        let fts_count: i64 = conn
            .query_row(
                "SELECT COUNT(*) FROM fts_chunks WHERE fts_chunks MATCH 'test'",
                [],
                |row| row.get(0),
            )
            .unwrap();
        assert!(fts_count > 0);
    }

    #[test]
    fn test_migration_indexes() {
        let conn = migrated_conn();
        for index in [
            "idx_chunks_doc_id",
            "idx_embeddings_chunk_id",
            "idx_edges_source",
            "idx_edges_target",
        ] {
            assert!(
                has_schema_object(&conn, "index", index),
                "missing index {index}"
            );
        }
    }

    #[test]
    fn test_schema_version_table_structure() {
        // Exercise the table exactly as `run_migrations` creates and fills it,
        // rather than a hand-copied version of its DDL.
        let conn = migrated_conn();

        let mut stmt = conn
            .prepare("SELECT version, name, applied_at FROM schema_version ORDER BY version")
            .unwrap();
        let rows: Vec<(u32, String, i64)> = stmt
            .query_map([], |row| Ok((row.get(0)?, row.get(1)?, row.get(2)?)))
            .unwrap()
            .collect::<rusqlite::Result<_>>()
            .unwrap();

        assert_eq!(rows.len(), MIGRATIONS.len());
        for ((version, name, applied_at), migration) in rows.iter().zip(MIGRATIONS) {
            assert_eq!(*version, migration.version);
            assert_eq!(name.as_str(), migration.name);
            assert!(
                *applied_at > 0,
                "applied_at should default to the current Unix time"
            );
        }
    }
}