Skip to main content

neleus_db/
refs.rs

1use std::fs;
2use std::path::PathBuf;
3use std::time::Duration;
4
5use anyhow::{Context, Result, anyhow};
6
7use crate::atomic::write_atomic;
8use crate::hash::Hash;
9use crate::lock::acquire_lock;
10use crate::wal::{Wal, WalOp};
11
12#[derive(Clone, Debug)]
13pub struct RefsStore {
14    root: PathBuf,
15    wal: Wal,
16}
17
18impl RefsStore {
19    pub fn new(root: impl Into<PathBuf>, wal: Wal) -> Self {
20        Self {
21            root: root.into(),
22            wal,
23        }
24    }
25
26    pub fn ensure_dirs(&self) -> Result<()> {
27        fs::create_dir_all(self.heads_dir())?;
28        fs::create_dir_all(self.states_dir())?;
29        Ok(())
30    }
31
32    pub fn head_get(&self, name: &str) -> Result<Option<Hash>> {
33        validate_ref_name(name)?;
34        read_hash(self.head_path(name))
35    }
36
37    pub fn head_set(&self, name: &str, hash: Hash) -> Result<()> {
38        validate_ref_name(name)?;
39        self.ensure_dirs()?;
40
41        let _lock = acquire_lock(self.root.join(".refs.lock"), Duration::from_secs(10))?;
42        let entry = Wal::make_ref_entry(WalOp::RefHeadSet, name, hash);
43        let wal_path = self.wal.begin_entry(&entry)?;
44
45        write_atomic(&self.head_path(name), format!("{hash}\n").as_bytes())?;
46        self.wal.end(&wal_path)?;
47        Ok(())
48    }
49
50    pub fn state_get(&self, name: &str) -> Result<Option<Hash>> {
51        validate_ref_name(name)?;
52        read_hash(self.state_path(name))
53    }
54
55    pub fn state_set(&self, name: &str, hash: Hash) -> Result<()> {
56        validate_ref_name(name)?;
57        self.ensure_dirs()?;
58
59        let _lock = acquire_lock(self.root.join(".refs.lock"), Duration::from_secs(10))?;
60        let entry = Wal::make_ref_entry(WalOp::RefStateSet, name, hash);
61        let wal_path = self.wal.begin_entry(&entry)?;
62
63        write_atomic(&self.state_path(name), format!("{hash}\n").as_bytes())?;
64        self.wal.end(&wal_path)?;
65        Ok(())
66    }
67
68    pub fn root(&self) -> &PathBuf {
69        &self.root
70    }
71
72    fn heads_dir(&self) -> PathBuf {
73        self.root.join("heads")
74    }
75
76    fn states_dir(&self) -> PathBuf {
77        self.root.join("states")
78    }
79
80    fn head_path(&self, name: &str) -> PathBuf {
81        self.heads_dir().join(name)
82    }
83
84    fn state_path(&self, name: &str) -> PathBuf {
85        self.states_dir().join(name)
86    }
87}
88
89pub fn head_get(store: &RefsStore, name: &str) -> Result<Option<Hash>> {
90    store.head_get(name)
91}
92
93pub fn head_set(store: &RefsStore, name: &str, hash: Hash) -> Result<()> {
94    store.head_set(name, hash)
95}
96
97fn read_hash(path: PathBuf) -> Result<Option<Hash>> {
98    if !path.exists() {
99        return Ok(None);
100    }
101    let raw = fs::read_to_string(&path)
102        .with_context(|| format!("failed reading ref file {}", path.display()))?;
103    let h = raw.trim().parse::<Hash>()?;
104    Ok(Some(h))
105}
106
107fn validate_ref_name(name: &str) -> Result<()> {
108    if name.is_empty() {
109        return Err(anyhow!("reference name cannot be empty"));
110    }
111    if name.starts_with('/') || name.contains("..") || name.contains('\0') {
112        return Err(anyhow!("unsafe reference name: {name}"));
113    }
114    Ok(())
115}
116
#[cfg(test)]
mod tests {
    use tempfile::TempDir;

    use super::*;
    use crate::hash::hash_blob;

    /// Builds a `RefsStore` rooted at `<tmp>/refs` with its WAL at `<tmp>/wal`.
    ///
    /// The `TempDir` is returned alongside the store so the caller keeps the directory alive
    /// for the duration of the test.
    fn fresh_store() -> (TempDir, RefsStore) {
        let tmp = TempDir::new().unwrap();
        let refs = RefsStore::new(tmp.path().join("refs"), Wal::new(tmp.path().join("wal")));
        (tmp, refs)
    }

    /// A head that is written can be read back unchanged.
    #[test]
    fn head_set_get_roundtrip() {
        let (_tmp, refs) = fresh_store();
        let expected = hash_blob(b"x");
        refs.head_set("main", expected).unwrap();
        let got = refs.head_get("main").unwrap();
        assert_eq!(got, Some(expected));
    }

    /// Reading a head that was never written yields `None`, not an error.
    #[test]
    fn head_get_missing_returns_none() {
        let (_tmp, refs) = fresh_store();
        let got = refs.head_get("main").unwrap();
        assert_eq!(got, None);
    }

    /// A state that is written can be read back unchanged.
    #[test]
    fn state_set_get_roundtrip() {
        let (_tmp, refs) = fresh_store();
        let expected = hash_blob(b"state");
        refs.state_set("main", expected).unwrap();
        let got = refs.state_get("main").unwrap();
        assert_eq!(got, Some(expected));
    }

    /// Path-traversal names are refused.
    #[test]
    fn invalid_ref_name_rejected() {
        let (_tmp, refs) = fresh_store();
        let outcome = refs.head_set("../bad", hash_blob(b"x"));
        assert!(outcome.is_err());
    }

    /// The module-level `head_set` / `head_get` wrappers reach the same store.
    #[test]
    fn free_functions_delegate() {
        let (_tmp, refs) = fresh_store();
        let expected = hash_blob(b"x");
        super::head_set(&refs, "main", expected).unwrap();
        let got = super::head_get(&refs, "main").unwrap();
        assert_eq!(got, Some(expected));
    }

    /// A successful write leaves no pending WAL entries behind.
    #[test]
    fn wal_cleanup_after_ref_set() {
        let (tmp, refs) = fresh_store();
        refs.head_set("main", hash_blob(b"x")).unwrap();
        let pending = Wal::new(tmp.path().join("wal")).pending().unwrap();
        assert!(pending.is_empty());
    }

    /// A second write to the same head replaces the first.
    #[test]
    fn head_overwrite_works() {
        let (_tmp, refs) = fresh_store();
        let first = hash_blob(b"a");
        let second = hash_blob(b"b");
        refs.head_set("main", first).unwrap();
        refs.head_set("main", second).unwrap();
        let got = refs.head_get("main").unwrap();
        assert_eq!(got, Some(second));
    }

    /// Reading a state that was never written yields `None`.
    #[test]
    fn state_ref_missing_returns_none() {
        let (_tmp, refs) = fresh_store();
        let got = refs.state_get("dev").unwrap();
        assert_eq!(got, None);
    }
}