1use std::{
2 ffi::OsString,
3 io,
4 path::{Path, PathBuf},
5 process::ExitStatus,
6 str::FromStr,
7};
8
9use crate::{
10 events::MaskWrapper,
11 parser::WatchOption,
12 parser::{parse_command, parse_masks, parse_path},
13};
14use inotify::{Event, WatchMask};
15use tracing::{event, Level};
16use winnow::{combinator::cut_err, Parser};
17
/// A program plus its argument templates, run in response to a watch event.
///
/// The strings in `argv` may contain `$` placeholders that are expanded per event by
/// `Command::execute`.
#[derive(Debug, Clone, Eq, PartialEq)]
pub struct Command {
    /// Program name or path handed to the process spawner.
    pub program: String,
    /// Argument templates, expanded before each invocation.
    pub argv: Vec<String>,
}
23
24impl Command {
25 pub async fn execute(
26 &self,
27 path: &Path,
28 event: &Event<OsString>,
29 ) -> Result<ExitStatus, io::Error> {
30 tokio::process::Command::new(&self.program)
31 .args(self.argv.iter().map(|arg| {
32 let mut formatted = String::new();
33 let mut parsing_dollar = false;
34
35 for c in arg.chars() {
36 if c == '$' {
37 if parsing_dollar {
38 formatted.push(c);
39 }
40 parsing_dollar = !parsing_dollar;
41 } else if parsing_dollar {
42 match c {
43 '#' => formatted.push_str(
44 event
45 .name
46 .as_deref()
47 .map(|s| s.to_str().unwrap_or_default())
48 .unwrap_or_default(),
49 ),
50 '@' => formatted.push_str(path.to_str().unwrap_or_default()),
51 '%' => formatted.push_str(&format!("\"{:?}\"", event.mask)),
52 '&' => formatted.push_str(&event.mask.bits().to_string()),
53 _ => formatted.push(c),
54 }
55 parsing_dollar = false;
56 } else {
57 formatted.push(c);
58 }
59 }
60 formatted
61 }))
62 .status()
63 .await
64 }
65}
66
/// Reasons a watch-table line fails to parse into a `WatchData`.
#[derive(Debug, PartialEq, Eq)]
pub enum ParseWatchError {
    /// A mask option was not accepted as an event mask.
    InvalidMask,
    /// The line is a comment (starts with `#` after trimming) and holds no watch.
    IsComment,
    /// The line could not be parsed into path, masks/options and command.
    CorruptInput,
}

impl std::fmt::Display for ParseWatchError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let message = match self {
            Self::InvalidMask => "invalid event mask",
            Self::IsComment => "line is a comment",
            Self::CorruptInput => "corrupt watch line",
        };
        f.write_str(message)
    }
}

/// Lets the error be boxed as `dyn Error` and propagated with `?` by generic error handlers.
impl std::error::Error for ParseWatchError {}
73
/// Per-watch flags that are not inotify event masks.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct WatchDataAttributes {
    /// NOTE(review): meaning not visible here; the `Default` impl sets it to `true` and the
    /// line parser never changes it — confirm semantics with the consumers.
    pub starting: bool,
    /// Set from a line's `recursive` attribute; defaults to `false`.
    pub recursive: bool,
}
79
80impl Default for WatchDataAttributes {
81 fn default() -> Self {
82 Self {
83 starting: true,
84 recursive: false,
85 }
86 }
87}
88
89#[derive(Debug, Clone, PartialEq, Eq)]
90pub struct WatchData {
91 pub path: PathBuf,
92 pub masks: WatchMask,
93 pub command: Command,
94 pub attributes: WatchDataAttributes,
95}
96
97impl FromStr for WatchData {
98 type Err = ParseWatchError;
99
100 #[tracing::instrument]
101 fn from_str(s: &str) -> Result<Self, Self::Err> {
102 let s = s.trim();
103 if s.starts_with('#') {
104 return Err(ParseWatchError::IsComment);
105 };
106
107 (
108 cut_err(parse_path),
109 cut_err(parse_masks),
110 cut_err(parse_command),
111 )
112 .map(|(path, watch_options, command)| {
113 let mut masks = WatchMask::empty();
114 let mut attributes = WatchDataAttributes::default();
115
116 for option in watch_options {
117 match option {
118 WatchOption::Mask(mask) => {
119 let mask = match mask.parse::<MaskWrapper>() {
120 Ok(m) => m,
121 Err(_) => {
122 event!(Level::ERROR, mask, "invalid mask");
123 return Err(ParseWatchError::InvalidMask);
124 }
125 };
126
127 masks = masks.union(mask.0);
128 }
129 WatchOption::Attribute(flag, value) => match flag.as_str() {
130 "recursive" => attributes.recursive = value,
131 _ => continue,
132 },
133 }
134 }
135
136 Ok(WatchData {
137 path,
138 command,
139 masks,
140 attributes,
141 })
142 })
143 .parse(s)
144 .map_err(|error| {
145 event!(Level::ERROR, ?error);
146 ParseWatchError::CorruptInput
147 })?
148 }
149}
150
#[cfg(test)]
mod tests {
    use std::path::PathBuf;

    use inotify::WatchMask;

    use crate::watch::{Command, ParseWatchError, WatchData, WatchDataAttributes};

    /// A single well-formed watch line.
    const LINE_DATA: &str = include_str!("../assets/test/test-line");
    /// A table with two valid lines, one invalid mask, one comment and one corrupt line.
    const DATA: &str = include_str!("../assets/test/test-table");

    /// The entry every valid fixture line is expected to parse into.
    fn get_test_watch() -> WatchData {
        let argv: Vec<String> = ["$@", "$#", "&>", "/dev/null"]
            .iter()
            .copied()
            .map(String::from)
            .collect();

        WatchData {
            path: PathBuf::from("/var/tmp"),
            masks: WatchMask::CREATE | WatchMask::DELETE,
            attributes: WatchDataAttributes {
                starting: true,
                recursive: true,
            },
            command: Command {
                program: String::from("echo"),
                argv,
            },
        }
    }

    #[test]
    fn test_parse_line() {
        let parsed: WatchData = LINE_DATA.parse().unwrap();
        assert_eq!(parsed, get_test_watch());
    }

    #[test]
    fn test_parse_table() {
        let parsed: Vec<Result<WatchData, ParseWatchError>> =
            DATA.lines().map(str::parse::<WatchData>).collect();

        let expected = vec![
            Ok(get_test_watch()),
            Ok(get_test_watch()),
            Err(ParseWatchError::InvalidMask),
            Err(ParseWatchError::IsComment),
            Err(ParseWatchError::CorruptInput),
        ];

        assert_eq!(parsed, expected);
    }
}