rsincronlib/
state.rs

1use crate::{
2    config::Config,
3    watch::{ParseWatchError, WatchData, WatchDataAttributes},
4    SocketMessage, SOCKET,
5};
6use inotify::{Inotify, WatchDescriptor, WatchMask};
7use tokio::sync::mpsc::{self, UnboundedReceiver, UnboundedSender};
8use tracing::{event, span, Level};
9
10use std::{
11    collections::HashMap,
12    fs,
13    io::Read,
14    os::unix::net::UnixListener,
15    path::Path,
16    str::FromStr,
17    sync::{Arc, Mutex, MutexGuard},
18    thread,
19};
20
21#[tracing::instrument(skip_all)]
22fn setup_socket(tx: UnboundedSender<SocketMessage>) -> bool {
23    let Ok(ref socket) = *SOCKET else {
24        event!(Level::WARN, error = ?SOCKET.as_deref().unwrap_err(), "failed to get socket path");
25        return false;
26    };
27
28    if Path::new(&socket).exists() {
29        if let Err(error) = fs::remove_file(&socket) {
30            event!(Level::WARN, ?error, "failed to remove existing socket");
31            return false;
32        }
33    }
34
35    let listener = match UnixListener::bind(socket) {
36        Ok(l) => l,
37        Err(error) => {
38            event!(Level::WARN, ?error, "failed to bind to socket");
39            return false;
40        }
41    };
42
43    thread::spawn(move || {
44        for mut stream in listener.incoming().flatten() {
45            let mut buffer = [0; 100];
46            if stream.read(&mut buffer).is_err() {
47                return;
48            }
49
50            let Ok(SocketMessage::UpdateWatches) = bincode::deserialize(&buffer) else {
51                return;
52            };
53
54            if let Err(error) = tx.send(SocketMessage::UpdateWatches) {
55                event!(
56                    Level::WARN,
57                    ?error,
58                    "failed to send update message through channel"
59                );
60            }
61        }
62    });
63
64    true
65}
66
/// Active watches, keyed by the inotify descriptor they were registered under.
type Watches = HashMap<WatchDescriptor, WatchData>;

/// Mutable runtime state: the active and failed watches plus the control-socket channel.
pub struct State {
    /// Watches waiting to be re-registered with inotify; `recover_watches` retries them.
    pub failed_watches: Vec<WatchData>,
    /// Whether the control socket is treated as available (set from `setup_socket`,
    /// cleared via `Shared::unset_socket`).
    pub has_socket: bool,
    /// Receives messages forwarded by the control-socket listener thread.
    pub rx: UnboundedReceiver<SocketMessage>,

    /// Configuration; `reload_watches` reads the watch table path from it.
    config: Config,
    /// Handle used to register watches with the inotify instance.
    inotify_watches: inotify::Watches,
    /// Currently active watches.
    watches: Watches,

    /// Parent span for this state's instrumented methods.
    span: tracing::Span,
}
80
81impl State {
82    pub fn new(inotify: &mut Inotify, config: Config) -> Self {
83        let (tx, rx) = mpsc::unbounded_channel();
84
85        Self {
86            rx,
87            config,
88            watches: HashMap::new(),
89            failed_watches: Vec::new(),
90            has_socket: setup_socket(tx),
91            inotify_watches: inotify.watches(),
92            span: span!(Level::INFO, "state"),
93        }
94    }
95
96    #[tracing::instrument(skip_all, parent = &self.span)]
97    pub fn reload_watches(&mut self) {
98        self.watches.clear();
99        event!(Level::INFO, table = ?self.config.watch_table_file, "RELOAD");
100        event!(Level::DEBUG, ?self.watches);
101        let table_content = match fs::read_to_string(&self.config.watch_table_file) {
102            Ok(table_content) => table_content,
103            _ => {
104                event!(Level::ERROR, filename = ?self.config.watch_table_file, "failed to read file");
105                panic!("failed to read watch table file");
106            }
107        };
108
109        for line in table_content.lines() {
110            let watch = match WatchData::from_str(line) {
111                Ok(w) => w,
112                Err(error) => {
113                    if error != ParseWatchError::IsComment {
114                        event!(Level::WARN, ?error, line, "failed to parse line");
115                    }
116
117                    continue;
118                }
119            };
120
121            self.add_watch(watch);
122        }
123    }
124
125    #[tracing::instrument(skip_all, parent = &self.span)]
126    pub fn recover_watches(&mut self) {
127        self.failed_watches.retain(|watch| {
128            let Ok(descriptor) = self.inotify_watches.add(watch.path.clone(), watch.masks) else {
129                return true;
130            };
131
132            event!(
133                Level::INFO,
134                id = descriptor.get_watch_descriptor_id(),
135                ?watch.path,
136                ?watch.masks,
137                "ADD"
138            );
139            self.watches.insert(descriptor, watch.clone());
140            false
141        });
142    }
143
144    pub fn get_watch(&self, wd: &WatchDescriptor) -> Option<&WatchData> {
145        self.watches.get(wd)
146    }
147
148    pub fn remove_watch(&mut self, wd: &WatchDescriptor) -> Option<WatchData> {
149        self.watches.remove(wd)
150    }
151
152    #[tracing::instrument(skip_all, parent = &self.span)]
153    fn add_watch(&mut self, watch: WatchData) {
154        let Ok(descriptor) = self.inotify_watches.add(&watch.path, watch.masks) else {
155            event!(Level::WARN, "failed to add watch");
156            return;
157        };
158
159        event!(
160            Level::INFO,
161            id = descriptor.get_watch_descriptor_id(),
162            ?watch.path,
163            ?watch.masks,
164            "ADD"
165        );
166
167        if watch.attributes.recursive && watch.masks.contains(WatchMask::CREATE) {
168            for entry in fs::read_dir(&watch.path).unwrap() {
169                let Ok(entry) = entry else {
170                    continue;
171                };
172
173                let Ok(metadata) = entry.metadata() else {
174                    continue;
175                };
176
177                if !metadata.is_dir() {
178                    continue;
179                }
180
181                let watch = WatchData {
182                    path: watch.path.join(entry.file_name()),
183                    attributes: WatchDataAttributes {
184                        starting: false,
185                        recursive: true,
186                    },
187                    ..watch.clone()
188                };
189
190                self.add_watch(watch)
191            }
192        };
193
194        self.watches.insert(descriptor, watch);
195    }
196}
197
/// Thread-safe wrapper around [`State`]; every accessor locks the mutex for one operation.
pub struct Shared {
    /// The guarded state.
    pub state: Mutex<State>,
}
201
202impl Shared {
203    fn with_lock(&self) -> MutexGuard<'_, State> {
204        self.state.lock().unwrap()
205    }
206    pub fn reload_watches(&self) {
207        self.with_lock().reload_watches();
208    }
209
210    pub fn recover_watches(&self) {
211        self.with_lock().recover_watches();
212    }
213
214    pub fn get_watch(&self, wd: &WatchDescriptor) -> Option<WatchData> {
215        self.with_lock().get_watch(wd).cloned()
216    }
217
218    pub fn remove_watch(&self, wd: &WatchDescriptor) -> Option<WatchData> {
219        self.with_lock().remove_watch(wd)
220    }
221
222    pub fn has_socket(&self) -> bool {
223        self.with_lock().has_socket
224    }
225
226    pub fn rx_try_recv(&self) -> Result<SocketMessage, mpsc::error::TryRecvError> {
227        self.with_lock().rx.try_recv()
228    }
229
230    pub fn push_failed_watch(&self, watch: WatchData) {
231        self.with_lock().failed_watches.push(watch)
232    }
233
234    pub fn unset_socket(&self) {
235        self.with_lock().has_socket = false;
236    }
237}
238
/// Reference-counted handle to [`Shared`] state.
pub type ArcShared = Arc<Shared>;