Skip to main content

agent_first_pay/store/
db.rs

1use crate::provider::PayError;
2use crate::store::MigrationLog;
3use redb::{Database, ReadableDatabase, TableDefinition};
4use std::path::Path;
5use std::sync::Mutex;
6
7const SCHEMA_TABLE: TableDefinition<&str, u64> = TableDefinition::new("_schema");
8const VERSION_KEY: &str = "version";
9
10pub type Migration<'a> = &'a dyn Fn(&Database) -> Result<(), PayError>;
11
12static MIGRATION_LOG: Mutex<Vec<MigrationLog>> = Mutex::new(Vec::new());
13
14/// Drain all migration log entries accumulated since last drain.
15pub fn drain_migration_log() -> Vec<MigrationLog> {
16    match MIGRATION_LOG.lock() {
17        Ok(mut log) => std::mem::take(&mut *log),
18        Err(_) => Vec::new(),
19    }
20}
21
22fn push_migration_log(entry: MigrationLog) {
23    if let Ok(mut log) = MIGRATION_LOG.lock() {
24        log.push(entry);
25    }
26}
27
28pub fn open_database(path: &Path) -> Result<Database, PayError> {
29    if let Some(parent) = path.parent() {
30        std::fs::create_dir_all(parent)
31            .map_err(|e| PayError::InternalError(format!("mkdir {}: {e}", parent.display())))?;
32    }
33    let db = if path.exists() {
34        Database::open(path)
35    } else {
36        Database::create(path)
37    }
38    .map_err(|e| PayError::InternalError(format!("open {}: {e}", path.display())))?;
39    Ok(db)
40}
41
42pub fn open_and_migrate(
43    path: &Path,
44    target_version: u64,
45    migrations: &[Migration<'_>],
46) -> Result<Database, PayError> {
47    let db = open_database(path)?;
48    let current = read_schema_version(&db)?;
49
50    if current < target_version {
51        if migrations.len() < target_version as usize {
52            return Err(PayError::InternalError(format!(
53                "schema: need {} migrations but only {} provided for {}",
54                target_version,
55                migrations.len(),
56                path.display()
57            )));
58        }
59        for v in current..target_version {
60            migrations[v as usize](&db)?;
61        }
62        write_schema_version(&db, target_version)?;
63        push_migration_log(MigrationLog {
64            database: path
65                .file_name()
66                .and_then(|n| n.to_str())
67                .unwrap_or("unknown")
68                .to_string(),
69            from_version: current,
70            to_version: target_version,
71        });
72    }
73
74    Ok(db)
75}
76
77fn read_schema_version(db: &Database) -> Result<u64, PayError> {
78    let read_txn = db
79        .begin_read()
80        .map_err(|e| PayError::InternalError(format!("schema begin_read: {e}")))?;
81    let Ok(table) = read_txn.open_table(SCHEMA_TABLE) else {
82        return Ok(0);
83    };
84    match table
85        .get(VERSION_KEY)
86        .map_err(|e| PayError::InternalError(format!("schema read version: {e}")))?
87    {
88        Some(v) => Ok(v.value()),
89        None => Ok(0),
90    }
91}
92
93fn write_schema_version(db: &Database, version: u64) -> Result<(), PayError> {
94    let write_txn = db
95        .begin_write()
96        .map_err(|e| PayError::InternalError(format!("schema begin_write: {e}")))?;
97    {
98        let mut table = write_txn
99            .open_table(SCHEMA_TABLE)
100            .map_err(|e| PayError::InternalError(format!("schema open _schema: {e}")))?;
101        table
102            .insert(VERSION_KEY, version)
103            .map_err(|e| PayError::InternalError(format!("schema write version: {e}")))?;
104    }
105    write_txn
106        .commit()
107        .map_err(|e| PayError::InternalError(format!("schema commit: {e}")))?;
108    Ok(())
109}
110
#[cfg(test)]
mod tests {
    use super::*;

    /// Scratch table written by the migrations in
    /// `open_and_migrate_runs_sequential_migrations`.
    const MARKER: TableDefinition<&str, u64> = TableDefinition::new("_test_marker");

    /// Serialises tests that assert on the process-global migration log.
    ///
    /// `drain_migration_log` empties the log for every caller. When the
    /// harness runs tests in parallel, one test's drain could otherwise
    /// swallow an entry another test is about to check.
    static LOG_LOCK: std::sync::Mutex<()> = std::sync::Mutex::new(());

    /// Acquire [`LOG_LOCK`], ignoring poisoning left behind by a failed test
    /// so that one failure does not cascade into the others.
    fn lock_log() -> std::sync::MutexGuard<'static, ()> {
        LOG_LOCK
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
    }

    /// Test migration step: record `key -> value` in [`MARKER`] in its own
    /// write transaction.
    fn write_marker(db: &Database, key: &str, value: u64) -> Result<(), PayError> {
        let w = db
            .begin_write()
            .map_err(|e| PayError::InternalError(e.to_string()))?;
        {
            let mut t = w
                .open_table(MARKER)
                .map_err(|e| PayError::InternalError(e.to_string()))?;
            t.insert(key, value)
                .map_err(|e| PayError::InternalError(e.to_string()))?;
        }
        w.commit()
            .map_err(|e| PayError::InternalError(e.to_string()))
    }

    #[test]
    fn new_database_has_version_zero() {
        let tmp = tempfile::tempdir().unwrap();
        let path = tmp.path().join("test.redb");
        let db = open_database(&path).unwrap();
        assert_eq!(read_schema_version(&db).unwrap(), 0);
    }

    #[test]
    fn open_and_migrate_stamps_version() {
        let _guard = lock_log();
        let tmp = tempfile::tempdir().unwrap();
        let path = tmp.path().join("stamps.redb");
        let _ = drain_migration_log();
        let db = open_and_migrate(&path, 1, &[&|_db| Ok(())]).unwrap();
        assert_eq!(read_schema_version(&db).unwrap(), 1);
        let log = drain_migration_log();
        let ours: Vec<_> = log.iter().filter(|e| e.database == "stamps.redb").collect();
        assert_eq!(ours.len(), 1);
        assert_eq!(ours[0].from_version, 0);
        assert_eq!(ours[0].to_version, 1);
    }

    #[test]
    fn open_and_migrate_skips_when_current() {
        let _guard = lock_log();
        let tmp = tempfile::tempdir().unwrap();
        let path = tmp.path().join("skip.redb");

        let db = open_and_migrate(&path, 1, &[&|_db| Ok(())]).unwrap();
        drop(db);
        let _ = drain_migration_log();

        // Second open: already at version 1, so no migration runs and nothing is logged.
        let db = open_and_migrate(&path, 1, &[&|_db| Ok(())]).unwrap();
        assert_eq!(read_schema_version(&db).unwrap(), 1);
        let log = drain_migration_log();
        let ours: Vec<_> = log.iter().filter(|e| e.database == "skip.redb").collect();
        assert!(ours.is_empty());
    }

    #[test]
    fn open_and_migrate_runs_sequential_migrations() {
        let tmp = tempfile::tempdir().unwrap();
        let path = tmp.path().join("test.redb");

        let db = open_and_migrate(
            &path,
            2,
            &[
                &|db| write_marker(db, "v0_to_v1", 1),
                &|db| write_marker(db, "v1_to_v2", 2),
            ],
        )
        .unwrap();

        assert_eq!(read_schema_version(&db).unwrap(), 2);

        let r = db.begin_read().unwrap();
        let t = r.open_table(MARKER).unwrap();
        assert_eq!(t.get("v0_to_v1").unwrap().unwrap().value(), 1);
        assert_eq!(t.get("v1_to_v2").unwrap().unwrap().value(), 2);
    }
}