1use crate::{
2 config::Config,
3 watch::{ParseWatchError, WatchData, WatchDataAttributes},
4 SocketMessage, SOCKET,
5};
6use inotify::{Inotify, WatchDescriptor, WatchMask};
7use tokio::sync::mpsc::{self, UnboundedReceiver, UnboundedSender};
8use tracing::{event, span, Level};
9
10use std::{
11 collections::HashMap,
12 fs,
13 io::Read,
14 os::unix::net::UnixListener,
15 path::Path,
16 str::FromStr,
17 sync::{Arc, Mutex, MutexGuard},
18 thread,
19};
20
21#[tracing::instrument(skip_all)]
22fn setup_socket(tx: UnboundedSender<SocketMessage>) -> bool {
23 let Ok(ref socket) = *SOCKET else {
24 event!(Level::WARN, error = ?SOCKET.as_deref().unwrap_err(), "failed to get socket path");
25 return false;
26 };
27
28 if Path::new(&socket).exists() {
29 if let Err(error) = fs::remove_file(&socket) {
30 event!(Level::WARN, ?error, "failed to remove existing socket");
31 return false;
32 }
33 }
34
35 let listener = match UnixListener::bind(socket) {
36 Ok(l) => l,
37 Err(error) => {
38 event!(Level::WARN, ?error, "failed to bind to socket");
39 return false;
40 }
41 };
42
43 thread::spawn(move || {
44 for mut stream in listener.incoming().flatten() {
45 let mut buffer = [0; 100];
46 if stream.read(&mut buffer).is_err() {
47 return;
48 }
49
50 let Ok(SocketMessage::UpdateWatches) = bincode::deserialize(&buffer) else {
51 return;
52 };
53
54 if let Err(error) = tx.send(SocketMessage::UpdateWatches) {
55 event!(
56 Level::WARN,
57 ?error,
58 "failed to send update message through channel"
59 );
60 }
61 }
62 });
63
64 true
65}
66
/// Registered watches, keyed by the inotify descriptor each was added under.
type Watches = HashMap<WatchDescriptor, WatchData>;
68
/// Mutable watcher state: registered inotify watches, watches waiting to be
/// re-registered, and the receiving end of the control-socket channel.
pub struct State {
    /// Watches waiting to be re-registered; `recover_watches` retries them and
    /// keeps the ones inotify still rejects.
    pub failed_watches: Vec<WatchData>,
    /// Result of `setup_socket` at construction; can later be cleared externally.
    pub has_socket: bool,
    /// Receives messages forwarded by the socket listener thread.
    pub rx: UnboundedReceiver<SocketMessage>,

    /// Configuration; `watch_table_file` is read by `reload_watches`.
    config: Config,
    /// Handle used to add watches to the `Inotify` instance passed to `new`.
    inotify_watches: inotify::Watches,
    /// Currently registered watches keyed by descriptor.
    watches: Watches,

    /// Parent span for the instrumented `State` methods.
    span: tracing::Span,
}
80
81impl State {
82 pub fn new(inotify: &mut Inotify, config: Config) -> Self {
83 let (tx, rx) = mpsc::unbounded_channel();
84
85 Self {
86 rx,
87 config,
88 watches: HashMap::new(),
89 failed_watches: Vec::new(),
90 has_socket: setup_socket(tx),
91 inotify_watches: inotify.watches(),
92 span: span!(Level::INFO, "state"),
93 }
94 }
95
96 #[tracing::instrument(skip_all, parent = &self.span)]
97 pub fn reload_watches(&mut self) {
98 self.watches.clear();
99 event!(Level::INFO, table = ?self.config.watch_table_file, "RELOAD");
100 event!(Level::DEBUG, ?self.watches);
101 let table_content = match fs::read_to_string(&self.config.watch_table_file) {
102 Ok(table_content) => table_content,
103 _ => {
104 event!(Level::ERROR, filename = ?self.config.watch_table_file, "failed to read file");
105 panic!("failed to read watch table file");
106 }
107 };
108
109 for line in table_content.lines() {
110 let watch = match WatchData::from_str(line) {
111 Ok(w) => w,
112 Err(error) => {
113 if error != ParseWatchError::IsComment {
114 event!(Level::WARN, ?error, line, "failed to parse line");
115 }
116
117 continue;
118 }
119 };
120
121 self.add_watch(watch);
122 }
123 }
124
125 #[tracing::instrument(skip_all, parent = &self.span)]
126 pub fn recover_watches(&mut self) {
127 self.failed_watches.retain(|watch| {
128 let Ok(descriptor) = self.inotify_watches.add(watch.path.clone(), watch.masks) else {
129 return true;
130 };
131
132 event!(
133 Level::INFO,
134 id = descriptor.get_watch_descriptor_id(),
135 ?watch.path,
136 ?watch.masks,
137 "ADD"
138 );
139 self.watches.insert(descriptor, watch.clone());
140 false
141 });
142 }
143
144 pub fn get_watch(&self, wd: &WatchDescriptor) -> Option<&WatchData> {
145 self.watches.get(wd)
146 }
147
148 pub fn remove_watch(&mut self, wd: &WatchDescriptor) -> Option<WatchData> {
149 self.watches.remove(wd)
150 }
151
152 #[tracing::instrument(skip_all, parent = &self.span)]
153 fn add_watch(&mut self, watch: WatchData) {
154 let Ok(descriptor) = self.inotify_watches.add(&watch.path, watch.masks) else {
155 event!(Level::WARN, "failed to add watch");
156 return;
157 };
158
159 event!(
160 Level::INFO,
161 id = descriptor.get_watch_descriptor_id(),
162 ?watch.path,
163 ?watch.masks,
164 "ADD"
165 );
166
167 if watch.attributes.recursive && watch.masks.contains(WatchMask::CREATE) {
168 for entry in fs::read_dir(&watch.path).unwrap() {
169 let Ok(entry) = entry else {
170 continue;
171 };
172
173 let Ok(metadata) = entry.metadata() else {
174 continue;
175 };
176
177 if !metadata.is_dir() {
178 continue;
179 }
180
181 let watch = WatchData {
182 path: watch.path.join(entry.file_name()),
183 attributes: WatchDataAttributes {
184 starting: false,
185 recursive: true,
186 },
187 ..watch.clone()
188 };
189
190 self.add_watch(watch)
191 }
192 };
193
194 self.watches.insert(descriptor, watch);
195 }
196}
197
/// Mutex-guarded wrapper around [`State`]; the methods in `impl Shared` each
/// lock it for a single operation.
pub struct Shared {
    /// The guarded state.
    pub state: Mutex<State>,
}
201
202impl Shared {
203 fn with_lock(&self) -> MutexGuard<'_, State> {
204 self.state.lock().unwrap()
205 }
206 pub fn reload_watches(&self) {
207 self.with_lock().reload_watches();
208 }
209
210 pub fn recover_watches(&self) {
211 self.with_lock().recover_watches();
212 }
213
214 pub fn get_watch(&self, wd: &WatchDescriptor) -> Option<WatchData> {
215 self.with_lock().get_watch(wd).cloned()
216 }
217
218 pub fn remove_watch(&self, wd: &WatchDescriptor) -> Option<WatchData> {
219 self.with_lock().remove_watch(wd)
220 }
221
222 pub fn has_socket(&self) -> bool {
223 self.with_lock().has_socket
224 }
225
226 pub fn rx_try_recv(&self) -> Result<SocketMessage, mpsc::error::TryRecvError> {
227 self.with_lock().rx.try_recv()
228 }
229
230 pub fn push_failed_watch(&self, watch: WatchData) {
231 self.with_lock().failed_watches.push(watch)
232 }
233
234 pub fn unset_socket(&self) {
235 self.with_lock().has_socket = false;
236 }
237}
238
/// Reference-counted handle to [`Shared`].
pub type ArcShared = Arc<Shared>;