Skip to main content

slim_auth/
file_watcher.rs

1// Copyright AGNTCY Contributors (https://github.com/agntcy)
2// SPDX-License-Identifier: Apache-2.0
3
4use display_error_chain::ErrorChainExt;
5use notify::event::ModifyKind;
6use notify::{Event, RecommendedWatcher, RecursiveMode, Watcher};
7use std::path::Path;
8use tokio::sync::mpsc;
9use tokio_util::sync::CancellationToken;
10use tracing::debug;
11
12use thiserror::Error;
13
14#[derive(Error, Debug)]
15pub enum FileWatcherError {
16    #[error("watch error")]
17    WatchError(#[from] notify::Error),
18}
19
20#[derive(Debug)]
21pub struct FileWatcher {
22    watcher: RecommendedWatcher,
23    cancellation_token: CancellationToken,
24}
25
26impl Drop for FileWatcher {
27    fn drop(&mut self) {
28        self.stop_watcher();
29    }
30}
31
32impl FileWatcher {
33    pub fn create_watcher<F>(callback: F) -> Self
34    where
35        F: Fn(&str) + Send + 'static,
36    {
37        let (tx, mut rx) = mpsc::channel::<notify::Result<Event>>(10);
38        let watcher = notify::recommended_watcher(move |res| {
39            // Send file system events to the event channel
40            let _ = tx.blocking_send(res);
41        })
42        .expect("error creating the watcher");
43        let fw = FileWatcher {
44            watcher,
45            cancellation_token: CancellationToken::new(),
46        };
47
48        let c_token = fw.cancellation_token.clone();
49        tokio::spawn(async move {
50            debug!("starting new watcher");
51            loop {
52                tokio::select! {
53                    next = rx.recv() => {
54                        match next {
55                            Some(res) => {
56                                match res {
57                                    Ok(event) => {
58                                        if let notify::EventKind::Modify(ModifyKind::Data(_)) = event.kind {
59                                            if event.paths.is_empty() {
60                                                // skip this event, we don't know the associated file
61                                                continue;
62                                            }
63                                            if let Some(p) = event.paths.first().and_then(|p| p.to_str()) {
64                                                debug!(event = ?event, "detected event");
65                                                callback(p);
66                                            }
67                                        }
68                                    }
69                                    Err(e) => tracing::error!(error = %e.chain(), "watch error"),
70                                }
71                            }
72                            None => {
73                                debug!("channel closed, stop watcher");
74                                break;
75                            }
76                        }
77                    }
78                    _ = c_token.cancelled() => {
79                        debug!("cancellation token signaled, stop watcher");
80                        break;
81                    }
82                }
83            }
84        });
85
86        // return the FileWatcher
87        fw
88    }
89
90    pub fn add_file(&mut self, file_name: &str) -> Result<(), FileWatcherError> {
91        self.watcher
92            .watch(Path::new(file_name), RecursiveMode::NonRecursive)?;
93
94        debug!(%file_name, "start watching file");
95
96        Ok(())
97    }
98
99    pub fn stop_watcher(&self) {
100        self.cancellation_token.cancel();
101    }
102}
103
#[cfg(test)]
mod tests {
    use super::*;
    use tokio::time;
    use tracing::info;
    use tracing_test::traced_test;

    use parking_lot::RwLock;
    use std::collections::HashMap;
    use std::fs::{File, OpenOptions};
    use std::io::{Seek, SeekFrom, Write};
    use std::path::{Path, PathBuf};
    use std::sync::Arc;
    use std::time::Duration;
    use std::{env, fs};

    /// Number of modification callbacks received, keyed by file path.
    type ModificationCounts = Arc<RwLock<HashMap<String, u32>>>;

    /// Time given to the watcher to deliver an event after a file change.
    const SETTLE_TIME: Duration = Duration::from_millis(100);

    /// Creates (or truncates) `file_path` and writes `content` into it.
    fn create_file(file_path: impl AsRef<Path>, content: &str) -> std::io::Result<()> {
        let mut file = File::create(file_path)?;
        file.write_all(content.as_bytes())
    }

    /// Overwrites the beginning of an existing file with `new_content`.
    fn modify_file(file_path: &str, new_content: &str) -> std::io::Result<()> {
        let mut file = OpenOptions::new().write(true).open(file_path)?;
        file.seek(SeekFrom::Start(0))?;
        file.write_all(new_content.as_bytes())
    }

    /// Removes `file_path` from the file system.
    fn delete_file(file_path: &str) -> std::io::Result<()> {
        fs::remove_file(file_path)
    }

    /// Test fixture file that is removed again when the guard is dropped, so
    /// a failing assertion does not leave files behind.
    struct TempFile {
        path: PathBuf,
    }

    impl TempFile {
        /// Creates `name` (prefixed with the process id to avoid collisions
        /// between concurrent test runs) in the system temp directory.
        fn create(name: &str, content: &str) -> std::io::Result<Self> {
            let file_name = format!("{}_{}", std::process::id(), name);
            let mut guard = Self {
                path: env::temp_dir().join(file_name),
            };
            create_file(&guard.path, content)?;
            // The temp directory may be reached through a symlink; resolve it
            // so the path used for watching and for lookups is the canonical
            // one.
            guard.path = fs::canonicalize(&guard.path)?;
            Ok(guard)
        }

        /// Returns the file path as a string slice.
        fn as_str(&self) -> &str {
            self.path.to_str().expect("temp path is not valid UTF-8")
        }
    }

    impl Drop for TempFile {
        fn drop(&mut self) {
            // Best effort: the file may already have been deleted by the test.
            let _ = fs::remove_file(&self.path);
        }
    }

    /// Returns how many modifications were recorded for `file` (0 if none).
    fn count_of(counts: &ModificationCounts, file: &str) -> u32 {
        counts.read().get(file).copied().unwrap_or(0)
    }

    /// Modifies `file` and waits for the watcher to process the change.
    async fn modify_and_settle(file: &str, content: &str) {
        modify_file(file, content).expect("Failed to modify file");
        time::sleep(SETTLE_TIME).await;
    }

    #[tokio::test]
    #[traced_test]
    async fn test_watcher() {
        let counts: ModificationCounts = Arc::new(RwLock::new(HashMap::new()));
        let sink = Arc::clone(&counts);

        // create the watcher
        let mut w = FileWatcher::create_watcher(move |file: &str| {
            info!(%file, "modification detected");
            *sink.write().entry(file.to_owned()).or_insert(0) += 1;
        });

        // create a new file and add it to the watcher
        let first =
            TempFile::create("test_file_watcher.txt", "CONFIG 1").expect("Failed to create file");
        w.add_file(first.as_str()).expect("Failed to watch file");

        // every modification of the file is reported exactly once
        modify_and_settle(first.as_str(), "CONFIG 2").await;
        assert_eq!(count_of(&counts, first.as_str()), 1);

        modify_and_settle(first.as_str(), "CONFIG 3").await;
        assert_eq!(count_of(&counts, first.as_str()), 2);

        // add other file to watch
        let second = TempFile::create("test_file_watcher_2.txt", "CONFIG 1")
            .expect("Failed to create file");
        w.add_file(second.as_str()).expect("Failed to watch file");

        // modifications of one file must not affect the other file's counter
        modify_and_settle(second.as_str(), "CONFIG 2").await;
        assert_eq!(count_of(&counts, first.as_str()), 2);
        assert_eq!(count_of(&counts, second.as_str()), 1);

        modify_and_settle(second.as_str(), "CONFIG 3").await;
        assert_eq!(count_of(&counts, first.as_str()), 2);
        assert_eq!(count_of(&counts, second.as_str()), 2);

        modify_and_settle(first.as_str(), "CONFIG 4").await;
        assert_eq!(count_of(&counts, first.as_str()), 3);
        assert_eq!(count_of(&counts, second.as_str()), 2);

        w.stop_watcher();

        delete_file(first.as_str()).expect("error deleting file");
        delete_file(second.as_str()).expect("error deleting file");

        time::sleep(SETTLE_TIME).await;
    }
}