Skip to main content

cp_graph/
migrations.rs

1//! Database migration system for CP
2//!
3//! Per CP-002: Provides versioned schema migrations with automatic upgrade.
4//! Migrations are embedded in the binary and run in order.
5
6use cp_core::{CPError, Result};
7use rusqlite::Connection;
8use tracing::info;
9
/// A single schema migration embedded in the binary.
///
/// Instances live in the `MIGRATIONS` table and are applied by
/// `run_migrations` when their `version` is greater than the version
/// recorded in the database.
struct Migration {
    /// Schema version this migration produces; stored in `schema_version.version`.
    version: u32,
    /// Short human-readable label; stored in `schema_version.name` and used in logs.
    name: &'static str,
    /// SQL text executed as a batch when the migration is applied.
    sql: &'static str,
}
16
17/// All migrations in order
18const MIGRATIONS: &[Migration] = &[
19    Migration {
20        version: 1,
21        name: "initial_schema",
22        sql: include_str!("migrations/001_initial.sql"),
23    },
24    Migration {
25        version: 2,
26        name: "add_timestamps",
27        sql: include_str!("migrations/002_add_timestamps.sql"),
28    },
29    Migration {
30        version: 3,
31        name: "add_l2_norm",
32        sql: include_str!("migrations/003_add_l2_norm.sql"),
33    },
34    Migration {
35        version: 4,
36        name: "add_path_id_and_embedding_version",
37        sql: include_str!("migrations/004_add_path_id_and_embedding_version.sql"),
38    },
39    Migration {
40        version: 5,
41        name: "add_arweave_tx",
42        sql: include_str!("migrations/005_add_arweave_tx.sql"),
43    },
44];
45
46/// Run all pending migrations on the database
47pub fn run_migrations(conn: &Connection) -> Result<()> {
48    // Create schema_version table if it doesn't exist
49    conn.execute_batch(
50        r#"
51        CREATE TABLE IF NOT EXISTS schema_version (
52            version INTEGER PRIMARY KEY,
53            name TEXT NOT NULL,
54            applied_at INTEGER NOT NULL DEFAULT (strftime('%s', 'now'))
55        );
56        "#,
57    )
58    .map_err(|e| CPError::Database(format!("Failed to create schema_version table: {}", e)))?;
59
60    // Get current version
61    let current_version: u32 = conn
62        .query_row(
63            "SELECT COALESCE(MAX(version), 0) FROM schema_version",
64            [],
65            |row| row.get(0),
66        )
67        .map_err(|e| CPError::Database(format!("Failed to get schema version: {}", e)))?;
68
69    info!("Current schema version: {}", current_version);
70
71    // Run pending migrations
72    for migration in MIGRATIONS {
73        if migration.version > current_version {
74            info!(
75                "Running migration {}: {}",
76                migration.version, migration.name
77            );
78
79            // Run migration in a transaction
80            let tx = conn
81                .unchecked_transaction()
82                .map_err(|e| CPError::Database(format!("Failed to start transaction: {}", e)))?;
83
84            tx.execute_batch(migration.sql)
85                .map_err(|e| {
86                    CPError::Database(format!(
87                        "Migration {} ({}) failed: {}",
88                        migration.version, migration.name, e
89                    ))
90                })?;
91
92            // Record migration
93            tx.execute(
94                "INSERT INTO schema_version (version, name) VALUES (?1, ?2)",
95                rusqlite::params![migration.version, migration.name],
96            )
97            .map_err(|e| {
98                CPError::Database(format!("Failed to record migration: {}", e))
99            })?;
100
101            tx.commit()
102                .map_err(|e| CPError::Database(format!("Failed to commit migration: {}", e)))?;
103
104            info!("Migration {} complete", migration.version);
105        }
106    }
107
108    info!("All migrations complete");
109    Ok(())
110}
111
112/// Get the current schema version
113pub fn get_schema_version(conn: &Connection) -> Result<u32> {
114    // Check if schema_version table exists
115    let table_exists: bool = conn
116        .query_row(
117            "SELECT EXISTS(SELECT 1 FROM sqlite_master WHERE type='table' AND name='schema_version')",
118            [],
119            |row| row.get(0),
120        )
121        .map_err(|e| CPError::Database(e.to_string()))?;
122
123    if !table_exists {
124        return Ok(0);
125    }
126
127    conn.query_row(
128        "SELECT COALESCE(MAX(version), 0) FROM schema_version",
129        [],
130        |row| row.get(0),
131    )
132    .map_err(|e| CPError::Database(e.to_string()))
133}
134
135/// Check if the database needs migration
136pub fn needs_migration(conn: &Connection) -> Result<bool> {
137    let current = get_schema_version(conn)?;
138    let latest = MIGRATIONS.last().map(|m| m.version).unwrap_or(0);
139    Ok(current < latest)
140}
141
#[cfg(test)]
mod tests {
    use super::*;

    /// Schema version expected after all embedded migrations have run.
    const EXPECTED_VERSION: u32 = 5;

    // ========== Shared helpers ==========

    /// Opens a fresh in-memory database and applies every migration to it.
    fn migrated_db() -> Connection {
        let conn = Connection::open_in_memory().unwrap();
        run_migrations(&conn).unwrap();
        conn
    }

    /// Runs a parameterless query whose single result cell is a boolean.
    fn query_flag(conn: &Connection, sql: &str) -> bool {
        conn.query_row(sql, [], |row| row.get(0)).unwrap()
    }

    /// Reports whether `table` has a column called `column`.
    fn has_column(conn: &Connection, table: &str, column: &str) -> bool {
        let sql = format!(
            "SELECT COUNT(*) > 0 FROM pragma_table_info('{}') WHERE name = '{}'",
            table, column
        );
        query_flag(conn, &sql)
    }

    /// Counts `sqlite_master` entries of the given `kind` (`table`, `index`,
    /// `trigger`) that are called `name`.
    fn schema_object_count(conn: &Connection, kind: &str, name: &str) -> i64 {
        let sql = format!(
            "SELECT COUNT(*) FROM sqlite_master WHERE type='{}' AND name='{}'",
            kind, name
        );
        conn.query_row(&sql, [], |row| row.get(0)).unwrap()
    }

    /// Reports whether `table` declares at least one foreign key.
    fn has_foreign_keys(conn: &Connection, table: &str) -> bool {
        let sql = format!(
            "SELECT COUNT(*) > 0 FROM pragma_foreign_key_list('{}')",
            table
        );
        query_flag(conn, &sql)
    }

    /// Inserts a placeholder document row and returns its freshly generated id.
    fn insert_document(conn: &Connection) -> uuid::Uuid {
        let doc_id = uuid::Uuid::new_v4();
        conn.execute(
            "INSERT INTO documents (id, path, hash, hierarchical_hash, mtime, size, mime_type) VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7)",
            rusqlite::params![
                doc_id.as_bytes(),
                "test.md",
                [0u8; 32].as_slice(),
                [0u8; 32].as_slice(),
                0i64,
                0i64,
                "text/markdown"
            ],
        )
        .unwrap();
        doc_id
    }

    // ========== Runner behaviour ==========

    #[test]
    fn test_migrations_run_idempotent() {
        let conn = migrated_db();

        // A second run must be a no-op rather than an error.
        run_migrations(&conn).unwrap();

        assert_eq!(get_schema_version(&conn).unwrap(), EXPECTED_VERSION);
    }

    #[test]
    fn test_schema_version_tracking() {
        let conn = Connection::open_in_memory().unwrap();

        assert_eq!(get_schema_version(&conn).unwrap(), 0);
        assert!(needs_migration(&conn).unwrap());

        run_migrations(&conn).unwrap();

        assert_eq!(get_schema_version(&conn).unwrap(), EXPECTED_VERSION);
        assert!(!needs_migration(&conn).unwrap());
    }

    #[test]
    fn test_timestamps_exist() {
        let conn = migrated_db();

        for column in &["created_at", "updated_at"] {
            assert!(
                has_column(&conn, "documents", column),
                "documents.{} is missing",
                column
            );
        }
    }

    #[test]
    fn test_l2_norm_column_exists() {
        let conn = migrated_db();

        assert!(
            has_column(&conn, "embeddings", "l2_norm"),
            "embeddings.l2_norm is missing"
        );
    }

    // ========== Additional Migration Tests ==========

    #[test]
    fn test_migration_runner_initial_schema() {
        let conn = migrated_db();

        for table in &["documents", "chunks", "embeddings", "edges", "state_roots"] {
            assert_eq!(
                schema_object_count(&conn, "table", table),
                1,
                "table {} should exist exactly once",
                table
            );
        }
    }

    #[test]
    fn test_migration_already_applied() {
        let conn = migrated_db();
        assert_eq!(get_schema_version(&conn).unwrap(), EXPECTED_VERSION);

        // Second run should skip everything and leave the version untouched.
        run_migrations(&conn).unwrap();
        assert_eq!(get_schema_version(&conn).unwrap(), EXPECTED_VERSION);
    }

    #[test]
    fn test_migration_001_documents_table() {
        let conn = migrated_db();

        for column in &["id", "path", "hash", "mtime", "size", "mime_type"] {
            assert!(
                has_column(&conn, "documents", column),
                "documents.{} is missing",
                column
            );
        }
    }

    #[test]
    fn test_migration_002_timestamps() {
        let conn = migrated_db();

        let expected = [
            ("documents", "created_at"),
            ("documents", "updated_at"),
            ("chunks", "created_at"),
            ("embeddings", "created_at"),
        ];
        for &(table, column) in &expected {
            assert!(
                has_column(&conn, table, column),
                "{}.{} is missing",
                table,
                column
            );
        }
    }

    #[test]
    fn test_migration_003_l2_norm() {
        let conn = migrated_db();

        assert!(
            has_column(&conn, "embeddings", "l2_norm"),
            "embeddings.l2_norm is missing"
        );
    }

    #[test]
    fn test_migration_004_path_id_embedding_version() {
        let conn = migrated_db();

        assert!(
            has_column(&conn, "documents", "path_id"),
            "documents.path_id is missing"
        );
        assert!(
            has_column(&conn, "embeddings", "embedding_version"),
            "embeddings.embedding_version is missing"
        );
    }

    #[test]
    fn test_migration_foreign_keys() {
        let conn = migrated_db();

        // Enable foreign keys
        conn.execute("PRAGMA foreign_keys = ON", []).unwrap();

        // A plain document insert must still succeed with enforcement on.
        insert_document(&conn);

        assert!(
            has_foreign_keys(&conn, "chunks"),
            "chunks should declare a foreign key"
        );
        assert!(
            has_foreign_keys(&conn, "embeddings"),
            "embeddings should declare a foreign key"
        );
    }

    #[test]
    fn test_migration_fts_triggers() {
        let conn = migrated_db();

        assert!(
            schema_object_count(&conn, "table", "fts_chunks") > 0,
            "FTS table fts_chunks is missing"
        );

        // Insert, delete and update triggers on chunks.
        for trigger in &["chunks_ai", "chunks_ad", "chunks_au"] {
            assert!(
                schema_object_count(&conn, "trigger", trigger) > 0,
                "trigger {} is missing",
                trigger
            );
        }
    }

    #[test]
    fn test_migration_fts_content_sync() {
        let conn = migrated_db();
        let doc_id = insert_document(&conn);

        // Insert a chunk belonging to that document.
        let chunk_id = uuid::Uuid::new_v4();
        conn.execute(
            "INSERT INTO chunks (id, doc_id, text, byte_offset, byte_length, sequence, text_hash) VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7)",
            rusqlite::params![
                chunk_id.as_bytes(),
                doc_id.as_bytes(),
                "test content for search",
                0i64,
                0i64,
                0u32,
                [0u8; 32].as_slice()
            ],
        )
        .unwrap();

        // The chunk text must now be searchable through the FTS table.
        let fts_count: i64 = conn
            .query_row(
                "SELECT COUNT(*) FROM fts_chunks WHERE fts_chunks MATCH 'test'",
                [],
                |row| row.get(0),
            )
            .unwrap();
        assert!(fts_count > 0, "inserted chunk was not indexed by fts_chunks");
    }

    #[test]
    fn test_migration_indexes() {
        let conn = migrated_db();

        let expected = [
            "idx_chunks_doc_id",
            "idx_embeddings_chunk_id",
            "idx_edges_source",
            "idx_edges_target",
        ];
        for index in &expected {
            assert!(
                schema_object_count(&conn, "index", index) > 0,
                "index {} is missing",
                index
            );
        }
    }

    #[test]
    fn test_schema_version_table_structure() {
        let conn = Connection::open_in_memory().unwrap();

        // Create the schema_version table manually to test structure.
        // This is a hand-written copy of the DDL used by `run_migrations`.
        conn.execute_batch(
            r#"
            CREATE TABLE IF NOT EXISTS schema_version (
                version INTEGER PRIMARY KEY,
                name TEXT NOT NULL,
                applied_at INTEGER NOT NULL DEFAULT (strftime('%s', 'now'))
            );
            "#,
        )
        .unwrap();

        // Insert a test migration record
        conn.execute(
            "INSERT INTO schema_version (version, name) VALUES (1, 'test_migration')",
            [],
        )
        .unwrap();

        // The row must be retrievable by name and by version.
        let version: i64 = conn
            .query_row(
                "SELECT version FROM schema_version WHERE name = 'test_migration'",
                [],
                |row| row.get(0),
            )
            .unwrap();
        assert_eq!(version, 1);

        let name: String = conn
            .query_row(
                "SELECT name FROM schema_version WHERE version = 1",
                [],
                |row| row.get(0),
            )
            .unwrap();
        assert_eq!(name, "test_migration");
    }
}