Skip to main content

oxi_agent/tools/
file_mutation_queue.rs

//! File mutation queue: serializes concurrent writes to the same file.
//!
//! Prevents race conditions when multiple edit operations target the same file.
//! Operations on *different* files run in parallel; operations on the *same* file
//! are serialized.
4use std::collections::HashMap;
5use std::path::{Path, PathBuf};
6use std::sync::Arc;
7use tokio::fs;
8use tokio::sync::Mutex;
9
10/// Global file mutation queue
11static QUEUE: std::sync::OnceLock<FileMutationQueue> = std::sync::OnceLock::new();
12
13/// Get the global file mutation queue
14pub fn global_mutation_queue() -> &'static FileMutationQueue {
15    QUEUE.get_or_init(FileMutationQueue::new)
16}
17
18/// Serializes file mutation operations per canonical path.
19#[derive(Debug)]
20pub struct FileMutationQueue {
21    /// Map from canonical path to a mutex that serializes operations
22    queues: Arc<Mutex<HashMap<PathBuf, Arc<Mutex<()>>>>>,
23}
24
25impl FileMutationQueue {
26    /// TODO.
27    pub fn new() -> Self {
28        Self {
29            queues: Arc::new(Mutex::new(HashMap::new())),
30        }
31    }
32
33    /// Execute a mutation operation on a file, serialized per canonical path.
34    /// If the file doesn't exist yet, uses the path as-is.
35    pub async fn with_queue<F, Fut, T>(&self, path: &Path, f: F) -> T
36    where
37        F: FnOnce() -> Fut,
38        Fut: std::future::Future<Output = T>,
39    {
40        let canonical = fs::canonicalize(path)
41            .await
42            .unwrap_or_else(|_| path.to_path_buf());
43
44        // Get or create a mutex for this file
45        let mutex = {
46            let mut queues = self.queues.lock().await;
47            queues
48                .entry(canonical)
49                .or_insert_with(|| Arc::new(Mutex::new(())))
50                .clone()
51        };
52
53        // Lock the per-file mutex
54        let _guard = mutex.lock().await;
55
56        // Execute the operation
57        f().await
58    }
59
60    /// Clean up entries for files that no longer need serialization.
61    pub async fn cleanup(&self, path: &Path) {
62        let canonical = fs::canonicalize(path)
63            .await
64            .unwrap_or_else(|_| path.to_path_buf());
65        let mut queues = self.queues.lock().await;
66        queues.remove(&canonical);
67    }
68}
69
70impl Default for FileMutationQueue {
71    fn default() -> Self {
72        Self::new()
73    }
74}
75
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::time::Duration;
    use tokio::sync::Barrier;

    /// Builds a per-test path under the OS temp directory (portable, unlike a
    /// hard-coded `/tmp`). The tests never create the file.
    fn scratch_path(name: &str) -> PathBuf {
        std::env::temp_dir().join(format!("oxi_file_mutation_queue_{name}"))
    }

    /// Concurrent operations on one path must never overlap.
    ///
    /// Each operation records how many operations are inside the queue at once
    /// and yields mid-way; without serialization the recorded maximum exceeds 1.
    #[tokio::test]
    async fn test_same_file_serialized() {
        let queue = Arc::new(FileMutationQueue::new());
        let completed = Arc::new(AtomicUsize::new(0));
        let in_flight = Arc::new(AtomicUsize::new(0));
        let max_in_flight = Arc::new(AtomicUsize::new(0));
        let path = scratch_path("same_file");

        let handles: Vec<_> = (0..10)
            .map(|_| {
                let queue = Arc::clone(&queue);
                let completed = Arc::clone(&completed);
                let in_flight = Arc::clone(&in_flight);
                let max_in_flight = Arc::clone(&max_in_flight);
                let path = path.clone();
                tokio::spawn(async move {
                    queue
                        .with_queue(&path, || async {
                            let now = in_flight.fetch_add(1, Ordering::SeqCst) + 1;
                            max_in_flight.fetch_max(now, Ordering::SeqCst);
                            // Yield so an unserialized queue would let others interleave.
                            tokio::time::sleep(Duration::from_millis(1)).await;
                            in_flight.fetch_sub(1, Ordering::SeqCst);
                            completed.fetch_add(1, Ordering::SeqCst);
                        })
                        .await;
                })
            })
            .collect();

        for handle in handles {
            handle.await.unwrap();
        }

        assert_eq!(completed.load(Ordering::SeqCst), 10);
        assert_eq!(max_in_flight.load(Ordering::SeqCst), 1);
    }

    /// Operations on different paths must be able to run at the same time.
    ///
    /// Both operations wait on a shared two-party barrier while inside the queue,
    /// so they only finish if both are in flight simultaneously; if the queue
    /// serialized them, the first would wait forever and the timeout would fire.
    #[tokio::test]
    async fn test_different_files_parallel() {
        let queue = Arc::new(FileMutationQueue::new());
        let barrier = Arc::new(Barrier::new(2));
        let counter = Arc::new(AtomicUsize::new(0));

        let spawn_op = |path: PathBuf| {
            let queue = Arc::clone(&queue);
            let barrier = Arc::clone(&barrier);
            let counter = Arc::clone(&counter);
            tokio::spawn(async move {
                queue
                    .with_queue(&path, || async {
                        barrier.wait().await;
                        counter.fetch_add(1, Ordering::SeqCst)
                    })
                    .await
            })
        };

        let h1 = spawn_op(scratch_path("parallel_1"));
        let h2 = spawn_op(scratch_path("parallel_2"));

        let both = async {
            h1.await.unwrap();
            h2.await.unwrap();
        };
        tokio::time::timeout(Duration::from_secs(5), both)
            .await
            .expect("operations on different files should not block each other");

        assert_eq!(counter.load(Ordering::SeqCst), 2);
    }
}