1use cp_core::{CPError, Result};
7use rusqlite::Connection;
8use tracing::info;
9
/// One step in the schema upgrade path.
///
/// Instances live only in the static [`MIGRATIONS`] table; their SQL text is
/// supplied there (via `include_str!`), not read at runtime.
struct Migration {
    /// Schema version the database is at after this step has been applied.
    /// Stored in `schema_version.version`.
    version: u32,
    /// Short identifier stored in `schema_version.name` and used in log output.
    name: &'static str,
    /// SQL script run through `execute_batch` when the step is applied.
    sql: &'static str,
}
16
/// Every schema migration known to this binary, in application order.
///
/// Invariant: versions must be unique and strictly increasing.
/// - `run_migrations` walks this slice front to back, applying each entry
///   whose version exceeds the highest recorded one.
/// - `needs_migration` treats the *last* entry as the latest version.
///
/// Append new migrations at the end with the next version number. Editing an
/// entry that a database has already applied has no effect on that database,
/// because applied versions are skipped.
///
/// SQL files are embedded at compile time; paths are relative to this source
/// file.
const MIGRATIONS: &[Migration] = &[
    // Base tables.
    Migration {
        version: 1,
        name: "initial_schema",
        sql: include_str!("migrations/001_initial.sql"),
    },
    Migration {
        version: 2,
        name: "add_timestamps",
        sql: include_str!("migrations/002_add_timestamps.sql"),
    },
    Migration {
        version: 3,
        name: "add_l2_norm",
        sql: include_str!("migrations/003_add_l2_norm.sql"),
    },
    Migration {
        version: 4,
        name: "add_path_id_and_embedding_version",
        sql: include_str!("migrations/004_add_path_id_and_embedding_version.sql"),
    },
    Migration {
        version: 5,
        name: "add_arweave_tx",
        sql: include_str!("migrations/005_add_arweave_tx.sql"),
    },
];
45
46pub fn run_migrations(conn: &Connection) -> Result<()> {
48 conn.execute_batch(
50 r#"
51 CREATE TABLE IF NOT EXISTS schema_version (
52 version INTEGER PRIMARY KEY,
53 name TEXT NOT NULL,
54 applied_at INTEGER NOT NULL DEFAULT (strftime('%s', 'now'))
55 );
56 "#,
57 )
58 .map_err(|e| CPError::Database(format!("Failed to create schema_version table: {}", e)))?;
59
60 let current_version: u32 = conn
62 .query_row(
63 "SELECT COALESCE(MAX(version), 0) FROM schema_version",
64 [],
65 |row| row.get(0),
66 )
67 .map_err(|e| CPError::Database(format!("Failed to get schema version: {}", e)))?;
68
69 info!("Current schema version: {}", current_version);
70
71 for migration in MIGRATIONS {
73 if migration.version > current_version {
74 info!(
75 "Running migration {}: {}",
76 migration.version, migration.name
77 );
78
79 let tx = conn
81 .unchecked_transaction()
82 .map_err(|e| CPError::Database(format!("Failed to start transaction: {}", e)))?;
83
84 tx.execute_batch(migration.sql)
85 .map_err(|e| {
86 CPError::Database(format!(
87 "Migration {} ({}) failed: {}",
88 migration.version, migration.name, e
89 ))
90 })?;
91
92 tx.execute(
94 "INSERT INTO schema_version (version, name) VALUES (?1, ?2)",
95 rusqlite::params![migration.version, migration.name],
96 )
97 .map_err(|e| {
98 CPError::Database(format!("Failed to record migration: {}", e))
99 })?;
100
101 tx.commit()
102 .map_err(|e| CPError::Database(format!("Failed to commit migration: {}", e)))?;
103
104 info!("Migration {} complete", migration.version);
105 }
106 }
107
108 info!("All migrations complete");
109 Ok(())
110}
111
112pub fn get_schema_version(conn: &Connection) -> Result<u32> {
114 let table_exists: bool = conn
116 .query_row(
117 "SELECT EXISTS(SELECT 1 FROM sqlite_master WHERE type='table' AND name='schema_version')",
118 [],
119 |row| row.get(0),
120 )
121 .map_err(|e| CPError::Database(e.to_string()))?;
122
123 if !table_exists {
124 return Ok(0);
125 }
126
127 conn.query_row(
128 "SELECT COALESCE(MAX(version), 0) FROM schema_version",
129 [],
130 |row| row.get(0),
131 )
132 .map_err(|e| CPError::Database(e.to_string()))
133}
134
135pub fn needs_migration(conn: &Connection) -> Result<bool> {
137 let current = get_schema_version(conn)?;
138 let latest = MIGRATIONS.last().map(|m| m.version).unwrap_or(0);
139 Ok(current < latest)
140}
141
#[cfg(test)]
mod tests {
    use super::*;

    /// Version a fully migrated database should report. Derived from the
    /// table so adding a migration does not require editing every test.
    fn latest_version() -> u32 {
        MIGRATIONS.last().map_or(0, |m| m.version)
    }

    /// Opens a fresh in-memory database with every migration applied.
    fn migrated_conn() -> Connection {
        let conn = Connection::open_in_memory().unwrap();
        run_migrations(&conn).unwrap();
        conn
    }

    /// Whether `table` has a column named `column`.
    fn has_column(conn: &Connection, table: &str, column: &str) -> bool {
        conn.query_row(
            "SELECT COUNT(*) > 0 FROM pragma_table_info(?1) WHERE name = ?2",
            rusqlite::params![table, column],
            |row| row.get(0),
        )
        .unwrap()
    }

    /// Whether `sqlite_master` lists an object of type `kind` (table, index,
    /// trigger, ...) named `name`.
    fn has_schema_object(conn: &Connection, kind: &str, name: &str) -> bool {
        conn.query_row(
            "SELECT COUNT(*) > 0 FROM sqlite_master WHERE type = ?1 AND name = ?2",
            rusqlite::params![kind, name],
            |row| row.get(0),
        )
        .unwrap()
    }

    /// Whether `table` declares at least one foreign key.
    fn has_foreign_keys(conn: &Connection, table: &str) -> bool {
        conn.query_row(
            "SELECT COUNT(*) > 0 FROM pragma_foreign_key_list(?1)",
            rusqlite::params![table],
            |row| row.get(0),
        )
        .unwrap()
    }

    /// Inserts a minimal `documents` row with the given id.
    fn insert_document(conn: &Connection, id: &uuid::Uuid) {
        conn.execute(
            "INSERT INTO documents (id, path, hash, hierarchical_hash, mtime, size, mime_type) VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7)",
            rusqlite::params![
                id.as_bytes(),
                "test.md",
                [0u8; 32].as_slice(),
                [0u8; 32].as_slice(),
                0i64,
                0i64,
                "text/markdown"
            ],
        )
        .unwrap();
    }

    #[test]
    fn test_migration_versions_strictly_increasing() {
        // run_migrations applies entries in slice order and needs_migration
        // trusts `last()`, so the table must be sorted with unique versions.
        assert!(!MIGRATIONS.is_empty());
        for pair in MIGRATIONS.windows(2) {
            assert!(
                pair[0].version < pair[1].version,
                "migration {} ({}) must precede migration {} ({})",
                pair[0].version,
                pair[0].name,
                pair[1].version,
                pair[1].name
            );
        }
    }

    #[test]
    fn test_migrations_run_idempotent() {
        let conn = Connection::open_in_memory().unwrap();

        run_migrations(&conn).unwrap();
        run_migrations(&conn).unwrap();

        assert_eq!(get_schema_version(&conn).unwrap(), latest_version());
    }

    #[test]
    fn test_schema_version_tracking() {
        let conn = Connection::open_in_memory().unwrap();

        assert_eq!(get_schema_version(&conn).unwrap(), 0);
        assert!(needs_migration(&conn).unwrap());

        run_migrations(&conn).unwrap();

        assert_eq!(get_schema_version(&conn).unwrap(), latest_version());
        assert!(!needs_migration(&conn).unwrap());
    }

    #[test]
    fn test_timestamps_exist() {
        let conn = migrated_conn();
        assert!(has_column(&conn, "documents", "created_at"));
        assert!(has_column(&conn, "documents", "updated_at"));
    }

    #[test]
    fn test_l2_norm_column_exists() {
        let conn = migrated_conn();
        assert!(has_column(&conn, "embeddings", "l2_norm"));
    }

    #[test]
    fn test_migration_runner_initial_schema() {
        let conn = migrated_conn();
        for table in ["documents", "chunks", "embeddings", "edges", "state_roots"] {
            assert!(
                has_schema_object(&conn, "table", table),
                "table {} missing",
                table
            );
        }
    }

    #[test]
    fn test_migration_already_applied() {
        let conn = Connection::open_in_memory().unwrap();

        run_migrations(&conn).unwrap();
        assert_eq!(get_schema_version(&conn).unwrap(), latest_version());

        run_migrations(&conn).unwrap();
        assert_eq!(get_schema_version(&conn).unwrap(), latest_version());
    }

    #[test]
    fn test_migration_001_documents_table() {
        let conn = migrated_conn();
        for column in ["id", "path", "hash", "mtime", "size", "mime_type"] {
            assert!(
                has_column(&conn, "documents", column),
                "documents.{} missing",
                column
            );
        }
    }

    #[test]
    fn test_migration_002_timestamps() {
        let conn = migrated_conn();
        for (table, column) in [
            ("documents", "created_at"),
            ("documents", "updated_at"),
            ("chunks", "created_at"),
            ("embeddings", "created_at"),
        ] {
            assert!(
                has_column(&conn, table, column),
                "{}.{} missing",
                table,
                column
            );
        }
    }

    #[test]
    fn test_migration_003_l2_norm() {
        let conn = migrated_conn();
        assert!(has_column(&conn, "embeddings", "l2_norm"));
    }

    #[test]
    fn test_migration_004_path_id_embedding_version() {
        let conn = migrated_conn();
        assert!(has_column(&conn, "documents", "path_id"));
        assert!(has_column(&conn, "embeddings", "embedding_version"));
    }

    #[test]
    fn test_migration_foreign_keys() {
        let conn = migrated_conn();

        conn.execute("PRAGMA foreign_keys = ON", []).unwrap();

        // A plain document insert must still succeed with enforcement enabled.
        insert_document(&conn, &uuid::Uuid::new_v4());

        assert!(has_foreign_keys(&conn, "chunks"));
        assert!(has_foreign_keys(&conn, "embeddings"));
    }

    #[test]
    fn test_migration_fts_triggers() {
        let conn = migrated_conn();

        assert!(has_schema_object(&conn, "table", "fts_chunks"));
        for trigger in ["chunks_ai", "chunks_ad", "chunks_au"] {
            assert!(
                has_schema_object(&conn, "trigger", trigger),
                "trigger {} missing",
                trigger
            );
        }
    }

    #[test]
    fn test_migration_fts_content_sync() {
        let conn = migrated_conn();

        let doc_id = uuid::Uuid::new_v4();
        insert_document(&conn, &doc_id);

        let chunk_id = uuid::Uuid::new_v4();
        conn.execute(
            "INSERT INTO chunks (id, doc_id, text, byte_offset, byte_length, sequence, text_hash) VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7)",
            rusqlite::params![
                chunk_id.as_bytes(),
                doc_id.as_bytes(),
                "test content for search",
                0i64,
                0i64,
                0u32,
                [0u8; 32].as_slice()
            ],
        )
        .unwrap();

        // The insert trigger should have mirrored the chunk text into the FTS table.
        let fts_count: i64 = conn
            .query_row(
                "SELECT COUNT(*) FROM fts_chunks WHERE fts_chunks MATCH 'test'",
                [],
                |row| row.get(0),
            )
            .unwrap();
        assert!(fts_count > 0);
    }

    #[test]
    fn test_migration_indexes() {
        let conn = migrated_conn();
        for index in [
            "idx_chunks_doc_id",
            "idx_embeddings_chunk_id",
            "idx_edges_source",
            "idx_edges_target",
        ] {
            assert!(
                has_schema_object(&conn, "index", index),
                "index {} missing",
                index
            );
        }
    }

    #[test]
    fn test_schema_version_table_structure() {
        let conn = Connection::open_in_memory().unwrap();

        conn.execute_batch(
            r#"
            CREATE TABLE IF NOT EXISTS schema_version (
                version INTEGER PRIMARY KEY,
                name TEXT NOT NULL,
                applied_at INTEGER NOT NULL DEFAULT (strftime('%s', 'now'))
            );
            "#,
        )
        .unwrap();

        conn.execute(
            "INSERT INTO schema_version (version, name) VALUES (1, 'test_migration')",
            [],
        )
        .unwrap();

        let version: i64 = conn
            .query_row(
                "SELECT version FROM schema_version WHERE name = 'test_migration'",
                [],
                |row| row.get(0),
            )
            .unwrap();
        assert_eq!(version, 1);

        let name: String = conn
            .query_row(
                "SELECT name FROM schema_version WHERE version = 1",
                [],
                |row| row.get(0),
            )
            .unwrap();
        assert_eq!(name, "test_migration");
    }
}