Skip to main content

atd_runtime/
tracker.rs

//! Per-connection record of which files have been Read, together with the
//! mtime/size observed at Read time. Edit uses this to enforce "you must Read
//! before Edit, and the file mustn't have changed since then."
4
5use std::collections::HashMap;
6use std::path::{Path, PathBuf};
7use std::sync::Mutex;
8use std::time::SystemTime;
9
10use thiserror::Error;
11
/// Snapshot of a file's stat taken when it was Read.
///
/// `ReadTracker::check` compares a later stat against this snapshot; a
/// difference in either field is reported as a modification.
#[derive(Clone, Copy, Debug)]
struct ReadRecord {
    /// Modification time observed at Read time.
    mtime: SystemTime,
    /// File size observed at Read time.
    size: u64,
}
17
18#[derive(Debug, Error)]
19pub enum ReadTrackerError {
20    #[error("file has not been read in this session: {path}")]
21    NotRead { path: PathBuf },
22
23    #[error("file modified since it was read: {path}")]
24    Modified { path: PathBuf },
25}
26
27pub struct ReadTracker {
28    entries: Mutex<HashMap<PathBuf, ReadRecord>>,
29}
30
31impl ReadTracker {
32    pub fn new() -> Self {
33        Self {
34            entries: Mutex::new(HashMap::new()),
35        }
36    }
37
38    /// Record a successful read. `path` should already be canonicalized by
39    /// the caller.
40    pub fn record(&self, path: PathBuf, mtime: SystemTime, size: u64) {
41        let mut g = self.entries.lock().expect("tracker mutex poisoned");
42        g.insert(path, ReadRecord { mtime, size });
43    }
44
45    /// Verify that `path` has been read in this session AND its current
46    /// mtime + size match what was recorded.
47    ///
48    /// Caller passes the current stat to avoid racing a syscall inside the
49    /// lock. `path` must already be canonicalized.
50    pub fn check(
51        &self,
52        path: &Path,
53        current_mtime: SystemTime,
54        current_size: u64,
55    ) -> Result<(), ReadTrackerError> {
56        let g = self.entries.lock().expect("tracker mutex poisoned");
57        match g.get(path) {
58            None => Err(ReadTrackerError::NotRead {
59                path: path.to_path_buf(),
60            }),
61            Some(rec) => {
62                if rec.mtime != current_mtime || rec.size != current_size {
63                    Err(ReadTrackerError::Modified {
64                        path: path.to_path_buf(),
65                    })
66                } else {
67                    Ok(())
68                }
69            }
70        }
71    }
72}
73
74impl Default for ReadTracker {
75    fn default() -> Self {
76        Self::new()
77    }
78}
79
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;

    /// A deterministic timestamp `secs` seconds after the Unix epoch.
    fn t(secs: u64) -> SystemTime {
        SystemTime::UNIX_EPOCH + Duration::from_secs(secs)
    }

    #[test]
    fn check_unrecorded_path_returns_not_read() {
        let tr = ReadTracker::new();
        let err = tr.check(Path::new("/tmp/nope"), t(1), 10).unwrap_err();
        assert!(matches!(err, ReadTrackerError::NotRead { .. }));
    }

    #[test]
    fn record_then_check_same_stat_is_ok() {
        let tr = ReadTracker::new();
        let p = PathBuf::from("/tmp/f");
        tr.record(p.clone(), t(100), 42);
        tr.check(&p, t(100), 42).unwrap();
    }

    #[test]
    fn check_returns_modified_when_mtime_changed() {
        let tr = ReadTracker::new();
        let p = PathBuf::from("/tmp/f");
        tr.record(p.clone(), t(100), 42);
        let err = tr.check(&p, t(200), 42).unwrap_err();
        assert!(matches!(err, ReadTrackerError::Modified { .. }));
    }

    #[test]
    fn check_returns_modified_when_size_changed() {
        let tr = ReadTracker::new();
        let p = PathBuf::from("/tmp/f");
        tr.record(p.clone(), t(100), 42);
        let err = tr.check(&p, t(100), 100).unwrap_err();
        assert!(matches!(err, ReadTrackerError::Modified { .. }));
    }

    #[test]
    fn record_overwrites_prior_entry() {
        let tr = ReadTracker::new();
        let p = PathBuf::from("/tmp/f");
        tr.record(p.clone(), t(100), 42);
        tr.record(p.clone(), t(200), 84);
        // The new record is the one that matches.
        tr.check(&p, t(200), 84).unwrap();
        // The old one doesn't.
        let err = tr.check(&p, t(100), 42).unwrap_err();
        assert!(matches!(err, ReadTrackerError::Modified { .. }));
    }

    #[test]
    fn paths_are_tracked_independently() {
        let tr = ReadTracker::new();
        let a = PathBuf::from("/tmp/a");
        let b = PathBuf::from("/tmp/b");
        tr.record(a.clone(), t(100), 42);
        tr.check(&a, t(100), 42).unwrap();
        // Recording `a` must not make `b` look read, even with an identical stat.
        let err = tr.check(&b, t(100), 42).unwrap_err();
        assert!(matches!(err, ReadTrackerError::NotRead { .. }));
    }

    #[test]
    fn default_tracker_is_empty() {
        let tr = ReadTracker::default();
        let err = tr.check(Path::new("/tmp/f"), t(0), 0).unwrap_err();
        assert!(matches!(err, ReadTrackerError::NotRead { .. }));
    }

    #[test]
    fn errors_carry_the_checked_path() {
        let tr = ReadTracker::new();
        let read = PathBuf::from("/tmp/read");
        let unread = PathBuf::from("/tmp/unread");
        tr.record(read.clone(), t(100), 42);

        match tr.check(&unread, t(100), 42).unwrap_err() {
            ReadTrackerError::NotRead { path } => assert_eq!(path, unread),
            other => panic!("expected NotRead, got {other:?}"),
        }
        match tr.check(&read, t(101), 42).unwrap_err() {
            ReadTrackerError::Modified { path } => assert_eq!(path, read),
            other => panic!("expected Modified, got {other:?}"),
        }
    }
}