oxi_agent/tools/
file_mutation_queue.rs1use std::collections::HashMap;
5use std::path::{Path, PathBuf};
6use std::sync::Arc;
7use tokio::fs;
8use tokio::sync::Mutex;
9
10static QUEUE: std::sync::OnceLock<FileMutationQueue> = std::sync::OnceLock::new();
12
13pub fn global_mutation_queue() -> &'static FileMutationQueue {
15 QUEUE.get_or_init(FileMutationQueue::new)
16}
17
18#[derive(Debug)]
20pub struct FileMutationQueue {
21 queues: Arc<Mutex<HashMap<PathBuf, Arc<Mutex<()>>>>>,
23}
24
25impl FileMutationQueue {
26 pub fn new() -> Self {
28 Self {
29 queues: Arc::new(Mutex::new(HashMap::new())),
30 }
31 }
32
33 pub async fn with_queue<F, Fut, T>(&self, path: &Path, f: F) -> T
36 where
37 F: FnOnce() -> Fut,
38 Fut: std::future::Future<Output = T>,
39 {
40 let canonical = fs::canonicalize(path)
41 .await
42 .unwrap_or_else(|_| path.to_path_buf());
43
44 let mutex = {
46 let mut queues = self.queues.lock().await;
47 queues
48 .entry(canonical)
49 .or_insert_with(|| Arc::new(Mutex::new(())))
50 .clone()
51 };
52
53 let _guard = mutex.lock().await;
55
56 f().await
58 }
59
60 pub async fn cleanup(&self, path: &Path) {
62 let canonical = fs::canonicalize(path)
63 .await
64 .unwrap_or_else(|_| path.to_path_buf());
65 let mut queues = self.queues.lock().await;
66 queues.remove(&canonical);
67 }
68}
69
70impl Default for FileMutationQueue {
71 fn default() -> Self {
72 Self::new()
73 }
74}
75
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::time::Duration;
    use tokio::sync::Barrier;

    /// Ten tasks mutate the same path.
    ///
    /// At no point may two of them be inside the critical section at once,
    /// and all of them must complete.
    #[tokio::test]
    async fn test_same_file_serialized() {
        let queue = Arc::new(FileMutationQueue::new());
        let active = Arc::new(AtomicUsize::new(0));
        let max_active = Arc::new(AtomicUsize::new(0));
        let completed = Arc::new(AtomicUsize::new(0));
        let path = std::env::temp_dir().join("test_mutation_queue_file");

        let handles: Vec<_> = (0..10)
            .map(|_| {
                let queue = Arc::clone(&queue);
                let active = Arc::clone(&active);
                let max_active = Arc::clone(&max_active);
                let completed = Arc::clone(&completed);
                let path = path.clone();
                tokio::spawn(async move {
                    queue
                        .with_queue(&path, || async {
                            let now = active.fetch_add(1, Ordering::SeqCst) + 1;
                            max_active.fetch_max(now, Ordering::SeqCst);
                            // Yield inside the critical section: without serialization
                            // other tasks would enter here and raise `max_active`.
                            tokio::time::sleep(Duration::from_millis(1)).await;
                            active.fetch_sub(1, Ordering::SeqCst);
                            completed.fetch_add(1, Ordering::SeqCst);
                        })
                        .await;
                })
            })
            .collect();

        for handle in handles {
            handle.await.unwrap();
        }

        assert_eq!(completed.load(Ordering::SeqCst), 10);
        assert_eq!(max_active.load(Ordering::SeqCst), 1);
    }

    /// Mutations of different paths must not block each other.
    ///
    /// Both closures wait on a shared two-party barrier, which is only released
    /// if they run concurrently.
    #[tokio::test]
    async fn test_different_files_parallel() {
        let queue = Arc::new(FileMutationQueue::new());
        let barrier = Arc::new(Barrier::new(2));
        let dir = std::env::temp_dir();

        let spawn_for = |name: &str| {
            let queue = Arc::clone(&queue);
            let barrier = Arc::clone(&barrier);
            let path = dir.join(name);
            tokio::spawn(async move {
                queue
                    .with_queue(&path, || async {
                        barrier.wait().await;
                    })
                    .await;
            })
        };

        let h1 = spawn_for("test_file_1");
        let h2 = spawn_for("test_file_2");

        let both = async {
            h1.await.unwrap();
            h2.await.unwrap();
        };
        tokio::time::timeout(Duration::from_secs(1), both)
            .await
            .expect("operations on different files should not serialize against each other");
    }
}