Skip to main content

agent_first_pay/store/
db.rs

1use crate::provider::PayError;
2use crate::store::MigrationLog;
3use redb::{Database, ReadableDatabase, TableDefinition};
4use std::path::Path;
5use std::sync::Mutex;
6
7const SCHEMA_TABLE: TableDefinition<&str, u64> = TableDefinition::new("_schema");
8const VERSION_KEY: &str = "version";
9
10pub type Migration<'a> = &'a dyn Fn(&Database) -> Result<(), PayError>;
11
12static MIGRATION_LOG: Mutex<Vec<MigrationLog>> = Mutex::new(Vec::new());
13
14/// Drain all migration log entries accumulated since last drain.
15pub fn drain_migration_log() -> Vec<MigrationLog> {
16    match MIGRATION_LOG.lock() {
17        Ok(mut log) => std::mem::take(&mut *log),
18        Err(_) => Vec::new(),
19    }
20}
21
22fn push_migration_log(entry: MigrationLog) {
23    if let Ok(mut log) = MIGRATION_LOG.lock() {
24        log.push(entry);
25    }
26}
27
28pub fn open_database(path: &Path) -> Result<Database, PayError> {
29    if let Some(parent) = path.parent() {
30        std::fs::create_dir_all(parent)
31            .map_err(|e| PayError::InternalError(format!("mkdir {}: {e}", parent.display())))?;
32        set_private_dir_permissions(parent)?;
33    }
34    let db = if path.exists() {
35        Database::open(path)
36    } else {
37        Database::create(path)
38    }
39    .map_err(|e| PayError::InternalError(format!("open {}: {e}", path.display())))?;
40    set_private_file_permissions(path)?;
41    Ok(db)
42}
43
44#[cfg(unix)]
45fn set_private_dir_permissions(path: &Path) -> Result<(), PayError> {
46    use std::os::unix::fs::PermissionsExt;
47
48    std::fs::set_permissions(path, std::fs::Permissions::from_mode(0o700))
49        .map_err(|e| PayError::InternalError(format!("chmod 700 {}: {e}", path.display())))
50}
51
52#[cfg(not(unix))]
53fn set_private_dir_permissions(_path: &Path) -> Result<(), PayError> {
54    Ok(())
55}
56
57#[cfg(unix)]
58fn set_private_file_permissions(path: &Path) -> Result<(), PayError> {
59    use std::os::unix::fs::PermissionsExt;
60
61    std::fs::set_permissions(path, std::fs::Permissions::from_mode(0o600))
62        .map_err(|e| PayError::InternalError(format!("chmod 600 {}: {e}", path.display())))
63}
64
65#[cfg(not(unix))]
66fn set_private_file_permissions(_path: &Path) -> Result<(), PayError> {
67    Ok(())
68}
69
70pub fn open_and_migrate(
71    path: &Path,
72    target_version: u64,
73    migrations: &[Migration<'_>],
74) -> Result<Database, PayError> {
75    let db = open_database(path)?;
76    let current = read_schema_version(&db)?;
77
78    if current < target_version {
79        if migrations.len() < target_version as usize {
80            return Err(PayError::InternalError(format!(
81                "schema: need {} migrations but only {} provided for {}",
82                target_version,
83                migrations.len(),
84                path.display()
85            )));
86        }
87        for v in current..target_version {
88            migrations[v as usize](&db)?;
89        }
90        write_schema_version(&db, target_version)?;
91        push_migration_log(MigrationLog {
92            database: path
93                .file_name()
94                .and_then(|n| n.to_str())
95                .unwrap_or("unknown")
96                .to_string(),
97            from_version: current,
98            to_version: target_version,
99        });
100    }
101
102    Ok(db)
103}
104
105fn read_schema_version(db: &Database) -> Result<u64, PayError> {
106    let read_txn = db
107        .begin_read()
108        .map_err(|e| PayError::InternalError(format!("schema begin_read: {e}")))?;
109    let Ok(table) = read_txn.open_table(SCHEMA_TABLE) else {
110        return Ok(0);
111    };
112    match table
113        .get(VERSION_KEY)
114        .map_err(|e| PayError::InternalError(format!("schema read version: {e}")))?
115    {
116        Some(v) => Ok(v.value()),
117        None => Ok(0),
118    }
119}
120
121fn write_schema_version(db: &Database, version: u64) -> Result<(), PayError> {
122    let write_txn = db
123        .begin_write()
124        .map_err(|e| PayError::InternalError(format!("schema begin_write: {e}")))?;
125    {
126        let mut table = write_txn
127            .open_table(SCHEMA_TABLE)
128            .map_err(|e| PayError::InternalError(format!("schema open _schema: {e}")))?;
129        table
130            .insert(VERSION_KEY, version)
131            .map_err(|e| PayError::InternalError(format!("schema write version: {e}")))?;
132    }
133    write_txn
134        .commit()
135        .map_err(|e| PayError::InternalError(format!("schema commit: {e}")))?;
136    Ok(())
137}
138
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::expect_used, clippy::panic)]
mod tests {
    use super::*;

    /// Table the sequential-migration test writes its per-step markers into.
    const MARKER: TableDefinition<&str, u64> = TableDefinition::new("_test_marker");

    /// Serialises tests that drain the process-wide migration log.
    ///
    /// The test harness runs tests on parallel threads, and draining is
    /// destructive. Without this lock, one test's `drain_migration_log()` can
    /// swallow the entry another test is about to assert on.
    static LOG_LOCK: Mutex<()> = Mutex::new(());

    /// Acquire `LOG_LOCK`. Poison is recovered so one failing test does not
    /// cascade into the others.
    fn lock_log() -> std::sync::MutexGuard<'static, ()> {
        LOG_LOCK
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
    }

    /// Migration body for the sequential test: record `key = value` in `MARKER`.
    fn write_marker(db: &Database, key: &str, value: u64) -> Result<(), PayError> {
        let w = db
            .begin_write()
            .map_err(|e| PayError::InternalError(e.to_string()))?;
        {
            let mut t = w
                .open_table(MARKER)
                .map_err(|e| PayError::InternalError(e.to_string()))?;
            t.insert(key, value)
                .map_err(|e| PayError::InternalError(e.to_string()))?;
        }
        w.commit()
            .map_err(|e| PayError::InternalError(e.to_string()))?;
        Ok(())
    }

    #[test]
    fn new_database_has_version_zero() {
        let tmp = tempfile::tempdir().unwrap();
        let path = tmp.path().join("test.redb");
        let db = open_database(&path).unwrap();
        assert_eq!(read_schema_version(&db).unwrap(), 0);
    }

    #[test]
    fn open_and_migrate_stamps_version() {
        let _log = lock_log();
        let tmp = tempfile::tempdir().unwrap();
        let path = tmp.path().join("stamps.redb");
        let _ = drain_migration_log();
        let db = open_and_migrate(&path, 1, &[&|_db| Ok(())]).unwrap();
        assert_eq!(read_schema_version(&db).unwrap(), 1);
        let log = drain_migration_log();
        let ours: Vec<_> = log.iter().filter(|e| e.database == "stamps.redb").collect();
        assert_eq!(ours.len(), 1);
        assert_eq!(ours[0].from_version, 0);
        assert_eq!(ours[0].to_version, 1);
    }

    #[test]
    fn open_and_migrate_skips_when_current() {
        let _log = lock_log();
        let tmp = tempfile::tempdir().unwrap();
        let path = tmp.path().join("skip.redb");

        let first = open_and_migrate(&path, 1, &[&|_db| Ok(())]).unwrap();
        // Close the first handle before reopening the same file.
        drop(first);
        let _ = drain_migration_log();

        // Second open at the same target: no migration runs, so nothing is logged.
        let db = open_and_migrate(&path, 1, &[&|_db| Ok(())]).unwrap();
        assert_eq!(read_schema_version(&db).unwrap(), 1);
        let log = drain_migration_log();
        assert!(log.iter().all(|e| e.database != "skip.redb"));
    }

    #[test]
    fn open_and_migrate_runs_sequential_migrations() {
        let tmp = tempfile::tempdir().unwrap();
        let path = tmp.path().join("test.redb");

        let db = open_and_migrate(
            &path,
            2,
            &[
                &|db| write_marker(db, "v0_to_v1", 1),
                &|db| write_marker(db, "v1_to_v2", 2),
            ],
        )
        .unwrap();

        assert_eq!(read_schema_version(&db).unwrap(), 2);

        let r = db.begin_read().unwrap();
        let t = r.open_table(MARKER).unwrap();
        assert_eq!(t.get("v0_to_v1").unwrap().unwrap().value(), 1);
        assert_eq!(t.get("v1_to_v2").unwrap().unwrap().value(), 2);
    }

    #[test]
    fn open_and_migrate_rejects_too_few_migrations() {
        let tmp = tempfile::tempdir().unwrap();
        let path = tmp.path().join("short.redb");

        let result = open_and_migrate(&path, 2, &[&|_db| Ok(())]);
        assert!(matches!(result, Err(PayError::InternalError(_))));

        // The length check happens before any migration runs, so the schema
        // version is still unstamped.
        let db = open_database(&path).unwrap();
        assert_eq!(read_schema_version(&db).unwrap(), 0);
    }
}