Skip to main content

fakecloud_persistence/
atomic.rs

1use std::fs::{File, OpenOptions};
2use std::io::{self, Write};
3use std::path::{Path, PathBuf};
4
5use serde::Serialize;
6
/// Returns a fresh temp-file path next to `path`, shaped
/// `<path>.<pid>.<seq>.tmp`.
///
/// Every call yields a different name. A single fixed `<path>.tmp` allowed two
/// concurrent writers to the same path (e.g. the KMS snapshot_lock-guarded save
/// and the lock-free auto-provision snapshot hook firing from another worker)
/// to truncate and write the SAME temp file, interleaving their bytes into a
/// corrupt blob that fails to parse on restart -> KMS keys + all ciphertext
/// permanently lost (bug-audit 2026-05-28, 4.1). Combining the process id with
/// a monotonic counter keeps every in-flight temp distinct, while the final
/// rename stays atomic.
fn tmp_path(path: &Path) -> PathBuf {
    use std::sync::atomic::{AtomicU64, Ordering};

    // Process-wide sequence number; only uniqueness matters, so a relaxed
    // increment is sufficient.
    static SEQ: AtomicU64 = AtomicU64::new(0);

    let n = SEQ.fetch_add(1, Ordering::Relaxed);
    let pid = std::process::id();
    let suffix = format!(".{pid}.{n}.tmp");

    // Append to the raw OS string so non-UTF-8 paths survive untouched.
    let mut name = path.as_os_str().to_os_string();
    name.push(suffix);
    name.into()
}
22
/// Opens the directory containing `path` and calls `sync_all` on it.
///
/// Does nothing and returns `Ok(())` when `path` has no parent or the parent
/// component is empty (a bare relative file name). Any error from opening or
/// syncing the directory is returned to the caller.
///
/// Every caller in this file invokes it right after a rename, presumably so
/// the directory entry itself is persisted.
fn fsync_parent(path: &Path) -> io::Result<()> {
    match path.parent() {
        Some(dir) if !dir.as_os_str().is_empty() => File::open(dir)?.sync_all(),
        _ => Ok(()),
    }
}
32
33fn write_atomic_bytes_inner(tmp: &Path, path: &Path, bytes: &[u8]) -> io::Result<()> {
34    {
35        let mut f = OpenOptions::new()
36            .write(true)
37            .create(true)
38            .truncate(true)
39            .open(tmp)?;
40        f.write_all(bytes)?;
41        f.sync_all()?;
42    }
43    std::fs::rename(tmp, path)?;
44    fsync_parent(path)?;
45    Ok(())
46}
47
48pub fn write_atomic_bytes(path: &Path, bytes: &[u8]) -> io::Result<()> {
49    let tmp = tmp_path(path);
50    match write_atomic_bytes_inner(&tmp, path, bytes) {
51        Ok(()) => Ok(()),
52        Err(e) => {
53            let _ = std::fs::remove_file(&tmp);
54            Err(e)
55        }
56    }
57}
58
59pub fn write_atomic_toml<T: Serialize>(path: &Path, value: &T) -> io::Result<()> {
60    let text = toml::to_string_pretty(value)
61        .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e.to_string()))?;
62    write_atomic_bytes(path, text.as_bytes())
63}
64
65fn write_atomic_from_file_inner(src: &Path, dst: &Path) -> io::Result<()> {
66    {
67        let f = File::open(src)?;
68        f.sync_all()?;
69    }
70    std::fs::rename(src, dst)?;
71    fsync_parent(dst)?;
72    Ok(())
73}
74
75pub fn write_atomic_from_file(src: &Path, dst: &Path) -> io::Result<()> {
76    match write_atomic_from_file_inner(src, dst) {
77        Ok(()) => Ok(()),
78        Err(e) => {
79            // Best-effort cleanup: remove any stray tmp the caller might see.
80            let tmp = tmp_path(dst);
81            let _ = std::fs::remove_file(&tmp);
82            Err(e)
83        }
84    }
85}
86
87fn write_atomic_copy_from_file_inner(tmp: &Path, src: &Path, dst: &Path) -> io::Result<()> {
88    {
89        let mut input = File::open(src)?;
90        let mut out = OpenOptions::new()
91            .write(true)
92            .create(true)
93            .truncate(true)
94            .open(tmp)?;
95        io::copy(&mut input, &mut out)?;
96        out.sync_all()?;
97    }
98    std::fs::rename(tmp, dst)?;
99    fsync_parent(dst)?;
100    Ok(())
101}
102
103/// Copy `src` into `dst` atomically, leaving `src` untouched. Used by the
104/// S3 store to replicate disk-backed object bodies without round-tripping
105/// through RAM.
106pub fn write_atomic_copy_from_file(src: &Path, dst: &Path) -> io::Result<()> {
107    let tmp = tmp_path(dst);
108    match write_atomic_copy_from_file_inner(&tmp, src, dst) {
109        Ok(()) => Ok(()),
110        Err(e) => {
111            let _ = std::fs::remove_file(&tmp);
112            Err(e)
113        }
114    }
115}
116
#[cfg(test)]
mod tests {
    use super::*;

    /// Names of all entries directly inside `dir` that end in `.tmp`.
    ///
    /// Temp names are unique per call, so a leftover can only be detected by
    /// scanning the directory, never by regenerating the name.
    fn tmp_leftovers(dir: &std::path::Path) -> Vec<std::ffi::OsString> {
        std::fs::read_dir(dir)
            .unwrap()
            .filter_map(|e| e.ok())
            .map(|e| e.file_name())
            .filter(|name| name.to_string_lossy().ends_with(".tmp"))
            .collect()
    }

    #[test]
    fn failed_write_leaves_no_tmp() {
        // Writing into a non-existent parent directory should fail without
        // creating anything. Use a tempdir so the test is hermetic.
        let tmp = tempfile::tempdir().unwrap();
        let bogus = tmp.path().join("does/not/exist/target.bin");
        write_atomic_bytes(&bogus, b"hello").unwrap_err();
        // Calling `tmp_path(&bogus)` here would yield a brand-new name that
        // can never exist, so check the filesystem itself instead.
        assert!(!bogus.exists());
        assert_eq!(
            std::fs::read_dir(tmp.path()).unwrap().count(),
            0,
            "failed write left entries behind"
        );
    }

    #[test]
    fn failed_rename_removes_tmp() {
        // The parent exists, so the temp file IS created and written; the
        // failure comes from the rename, because a regular file cannot be
        // renamed over a directory. This exercises the cleanup path.
        let dir = tempfile::tempdir().unwrap();
        let target = dir.path().join("occupied");
        std::fs::create_dir(&target).unwrap();
        std::fs::write(target.join("keep"), b"x").unwrap();

        write_atomic_bytes(&target, b"hello").unwrap_err();

        assert!(target.is_dir(), "target directory must be untouched");
        let leftover = tmp_leftovers(dir.path());
        assert!(leftover.is_empty(), "leftover temp files: {leftover:?}");
    }

    #[test]
    fn write_atomic_bytes_round_trip() {
        let tmp = tempfile::tempdir().unwrap();
        let path = tmp.path().join("out.bin");
        write_atomic_bytes(&path, b"hello world").unwrap();
        assert_eq!(std::fs::read(&path).unwrap(), b"hello world");
    }

    // bug-audit 2026-05-28, 4.1: many concurrent writers to one path must never
    // produce a corrupt (interleaved) file. With a fixed `.tmp` suffix they
    // raced on the same temp file; unique temp names + the atomic rename
    // guarantee the final file equals exactly one writer's payload.
    #[test]
    fn concurrent_writes_never_corrupt() {
        use std::sync::Arc;
        let dir = tempfile::tempdir().unwrap();
        let path = Arc::new(dir.path().join("snap.bin"));
        let payloads: Vec<Vec<u8>> = (0..16).map(|i| vec![b'A' + i as u8; 8192]).collect();
        let handles: Vec<_> = payloads
            .iter()
            .cloned()
            .map(|p| {
                let path = Arc::clone(&path);
                std::thread::spawn(move || write_atomic_bytes(&path, &p).unwrap())
            })
            .collect();
        for h in handles {
            h.join().unwrap();
        }
        let got = std::fs::read(&*path).unwrap();
        assert!(
            payloads.contains(&got),
            "persisted file is not any single writer's payload (corrupt interleave)"
        );
        let leftover = tmp_leftovers(dir.path());
        assert!(leftover.is_empty(), "leftover temp files: {leftover:?}");
    }

    #[test]
    fn write_atomic_bytes_overwrites() {
        let tmp = tempfile::tempdir().unwrap();
        let path = tmp.path().join("out.bin");
        write_atomic_bytes(&path, b"v1").unwrap();
        write_atomic_bytes(&path, b"v2").unwrap();
        assert_eq!(std::fs::read(&path).unwrap(), b"v2");
    }

    #[test]
    fn write_atomic_toml_round_trip() {
        #[derive(serde::Serialize, serde::Deserialize, PartialEq, Debug)]
        struct Config {
            name: String,
            count: i64,
        }
        let tmp = tempfile::tempdir().unwrap();
        let path = tmp.path().join("cfg.toml");
        let cfg = Config {
            name: "test".to_string(),
            count: 42,
        };
        write_atomic_toml(&path, &cfg).unwrap();
        let content = std::fs::read_to_string(&path).unwrap();
        assert!(content.contains("name"));
        assert!(content.contains("test"));
    }

    #[test]
    fn write_atomic_from_file_moves_src() {
        let dir = tempfile::tempdir().unwrap();
        let src = dir.path().join("staged.bin");
        let dst = dir.path().join("final.bin");
        std::fs::write(&src, b"payload").unwrap();
        write_atomic_from_file(&src, &dst).unwrap();
        assert!(!src.exists(), "src should have been renamed away");
        assert_eq!(std::fs::read(&dst).unwrap(), b"payload");
    }

    #[test]
    fn write_atomic_copy_from_file_keeps_src() {
        let dir = tempfile::tempdir().unwrap();
        let src = dir.path().join("orig.bin");
        let dst = dir.path().join("copy.bin");
        std::fs::write(&src, b"payload").unwrap();
        write_atomic_copy_from_file(&src, &dst).unwrap();
        assert_eq!(std::fs::read(&src).unwrap(), b"payload");
        assert_eq!(std::fs::read(&dst).unwrap(), b"payload");
        let leftover = tmp_leftovers(dir.path());
        assert!(leftover.is_empty(), "leftover temp files: {leftover:?}");
    }
}