Skip to main content

rab/builtin/
file_mutation_queue.rs

1use std::collections::HashMap;
2use std::path::Path;
3use std::sync::{Arc, LazyLock, Mutex};
4use tokio::sync::Notify;
5
6use crate::builtin;
7
8/// Per-file queue entries. Each entry is a `Notify` that the NEXT operation
9/// will wait on. Operations chain through these to serialize access.
10static FILE_QUEUES: LazyLock<Mutex<HashMap<String, Arc<Notify>>>> =
11    LazyLock::new(|| Mutex::new(HashMap::new()));
12
13/// Normalize a path for use as a queue key.
14fn normalize_path_key(path: &str, cwd: &Path) -> String {
15    builtin::resolve_path(path, cwd)
16        .to_string_lossy()
17        .replace('\\', "/")
18}
19
20/// Serialize file mutation operations targeting the same file.
21///
22/// Operations for different files still run in parallel. This mirrors pi's
23/// `withFileMutationQueue` in file-mutation-queue.ts.
24///
25/// The implementation:
26/// - Each file has a `Notify` stored in a global map, representing the
27///   "next operation" signal.
28/// - An operation registers by replacing the entry with its own `Notify`
29///   (for the operation after it), and picking up the previous `Notify`
30///   to wait on.
31/// - When the operation finishes, it signals its own `Notify` (which the
32///   next operation is waiting on) and, if it is still the latest entry,
33///   cleans up.
34pub async fn with_file_mutation_queue<T, E, F, Fut>(
35    file_path: &str,
36    cwd: &Path,
37    f: F,
38) -> Result<T, E>
39where
40    F: FnOnce() -> Fut,
41    Fut: std::future::Future<Output = Result<T, E>>,
42{
43    let key = normalize_path_key(file_path, cwd);
44
45    // ── Registration phase ─────────────────────────────────────
46    // Atomically: pick up the previous Notify (if any) and store ours.
47    let our_notify = Arc::new(Notify::new());
48    let prev_notify = {
49        let mut queues = FILE_QUEUES.lock().unwrap();
50        queues.insert(key.clone(), our_notify.clone())
51    };
52
53    // ── Wait for the previous operation to finish ──────────────
54    if let Some(prev) = &prev_notify {
55        prev.notified().await;
56    }
57
58    // ── Run the operation ──────────────────────────────────────
59    let result = f().await;
60
61    // ── Signal the next operation ──────────────────────────────
62    // Our Notify may have been picked up by the next operation as
63    // its prev_notify. Signal it so the next operation can proceed.
64    our_notify.notify_one();
65
66    // ── Clean up if we're still the latest entry ───────────────
67    let mut queues = FILE_QUEUES.lock().unwrap();
68    if let Some(current) = queues.get(&key)
69        && Arc::ptr_eq(current, &our_notify)
70    {
71        // No new operation registered after us — clean up.
72        queues.remove(&key);
73    }
74    // If a new operation registered, its own Notify is now in the
75    // map; we leave it there for the next cleanup cycle.
76
77    result
78}
79
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::time::Duration;

    /// A single operation on an idle file simply runs its closure.
    #[tokio::test]
    async fn runs_without_previous() {
        let base_dir = Path::new("/tmp");
        let mut closure_ran = false;
        let outcome = with_file_mutation_queue("/tmp/test_file_1.txt", base_dir, || async {
            closure_ran = true;
            Ok::<_, String>(42)
        })
        .await;
        outcome.unwrap();
        assert!(closure_ran);
    }

    /// Ten tasks targeting one file must never overlap inside the closure.
    #[tokio::test]
    async fn serializes_concurrent_access() {
        let base_dir = Path::new("/tmp");
        let in_flight = Arc::new(AtomicUsize::new(0));
        let peak = Arc::new(AtomicUsize::new(0));

        let tasks: Vec<_> = (0..10)
            .map(|_| {
                let in_flight = Arc::clone(&in_flight);
                let peak = Arc::clone(&peak);
                tokio::spawn(async move {
                    with_file_mutation_queue("/tmp/test_serial.txt", base_dir, || async {
                        // Number of closures running right now, this one included.
                        let now_running = in_flight.fetch_add(1, Ordering::SeqCst) + 1;
                        // Record the highest concurrency seen so far.
                        let earlier_peak = peak.fetch_max(now_running, Ordering::SeqCst);
                        // Hold the slot for a moment so overlaps would be visible.
                        tokio::time::sleep(Duration::from_millis(5)).await;
                        in_flight.fetch_sub(1, Ordering::SeqCst);
                        // More than one closure in flight means the queue failed.
                        if earlier_peak >= 1 && now_running > 1 {
                            panic!("concurrent access detected: v={}", now_running);
                        }
                        Ok::<_, String>(())
                    })
                    .await
                    .unwrap();
                })
            })
            .collect();

        for task in tasks {
            task.await.unwrap();
        }

        // Fully serialized execution never exceeds one closure at a time.
        assert_eq!(peak.load(Ordering::SeqCst), 1);
    }

    /// Operations on distinct files do not wait for one another.
    #[tokio::test]
    async fn different_files_run_in_parallel() {
        let base_dir = Path::new("/tmp");
        let started = std::time::Instant::now();

        let tasks: Vec<_> = (0..5)
            .map(|index| {
                tokio::spawn(async move {
                    let target = format!("/tmp/parallel_{}.txt", index);
                    with_file_mutation_queue(&target, base_dir, || async {
                        tokio::time::sleep(Duration::from_millis(50)).await;
                        Ok::<_, String>(index)
                    })
                    .await
                    .unwrap()
                })
            })
            .collect();

        for task in tasks {
            task.await.unwrap();
        }

        // Five overlapping 50ms sleeps should take ~50ms in total, not ~250ms.
        let elapsed = started.elapsed();
        assert!(
            elapsed < Duration::from_millis(150),
            "took too long: {:?} — files ran sequentially instead of in parallel",
            elapsed
        );
    }

    /// The closure's `Ok` value is handed back unchanged.
    #[tokio::test]
    async fn returns_value() {
        let base_dir = Path::new("/tmp");
        let outcome: Result<i32, String> =
            with_file_mutation_queue("/tmp/retval.txt", base_dir, || async { Ok(99) }).await;
        assert_eq!(outcome.unwrap(), 99);
    }

    /// The closure's `Err` value is handed back unchanged.
    #[tokio::test]
    async fn propagates_error() {
        let base_dir = Path::new("/tmp");
        let outcome: Result<i32, String> =
            with_file_mutation_queue("/tmp/error.txt", base_dir, || async {
                Err("oops".to_string())
            })
            .await;
        assert!(outcome.is_err());
        assert_eq!(outcome.unwrap_err(), "oops");
    }

    /// Three operations on one file finish in the order they were queued.
    ///
    /// Assumes the spawned tasks are first polled in spawn order, so that
    /// they register with the queue as 0, 1, 2.
    #[tokio::test]
    async fn chains_correctly() {
        let base_dir = Path::new("/tmp");
        let completion_log = Arc::new(std::sync::Mutex::new(Vec::new()));

        let tasks: Vec<_> = (0..3)
            .map(|index| {
                let log = Arc::clone(&completion_log);
                tokio::spawn(async move {
                    with_file_mutation_queue("/tmp/chaining.txt", base_dir, || async {
                        // Earlier tasks sleep longer (30ms, 20ms, 10ms).
                        tokio::time::sleep(Duration::from_millis(10 * (3 - index))).await;
                        log.lock().unwrap().push(index);
                        Ok::<_, String>(())
                    })
                    .await
                    .unwrap()
                })
            })
            .collect();

        for task in tasks {
            task.await.unwrap();
        }

        // Task 0 sleeps the longest, yet tasks 1 and 2 must still complete
        // after it because all three are serialized on the same file.
        let observed = completion_log.lock().unwrap();
        assert_eq!(*observed, vec![0, 1, 2], "operations executed out of order");
    }
}