rsincronlib/
watch.rs

1use std::{
2    ffi::OsString,
3    io,
4    path::{Path, PathBuf},
5    process::ExitStatus,
6    str::FromStr,
7};
8
9use crate::{
10    events::MaskWrapper,
11    parser::WatchOption,
12    parser::{parse_command, parse_masks, parse_path},
13};
14use inotify::{Event, WatchMask};
15use tracing::{event, Level};
16use winnow::{combinator::cut_err, Parser};
17
18#[derive(Clone, Debug, PartialEq, Eq)]
19pub struct Command {
20    pub program: String,
21    pub argv: Vec<String>,
22}
23
24impl Command {
25    pub async fn execute(
26        &self,
27        path: &Path,
28        event: &Event<OsString>,
29    ) -> Result<ExitStatus, io::Error> {
30        tokio::process::Command::new(&self.program)
31            .args(self.argv.iter().map(|arg| {
32                let mut formatted = String::new();
33                let mut parsing_dollar = false;
34
35                for c in arg.chars() {
36                    if c == '$' {
37                        if parsing_dollar {
38                            formatted.push(c);
39                        }
40                        parsing_dollar = !parsing_dollar;
41                    } else if parsing_dollar {
42                        match c {
43                            '#' => formatted.push_str(
44                                event
45                                    .name
46                                    .as_deref()
47                                    .map(|s| s.to_str().unwrap_or_default())
48                                    .unwrap_or_default(),
49                            ),
50                            '@' => formatted.push_str(path.to_str().unwrap_or_default()),
51                            '%' => formatted.push_str(&format!("\"{:?}\"", event.mask)),
52                            '&' => formatted.push_str(&event.mask.bits().to_string()),
53                            _ => formatted.push(c),
54                        }
55                        parsing_dollar = false;
56                    } else {
57                        formatted.push(c);
58                    }
59                }
60                formatted
61            }))
62            .status()
63            .await
64    }
65}
66
/// Reasons a single watch-table line fails to parse into a `WatchData`.
#[derive(Debug, PartialEq, Eq)]
pub enum ParseWatchError {
    /// A mask name in the line could not be parsed into a known inotify mask.
    InvalidMask,
    /// The line, after trimming, starts with `#`.
    IsComment,
    /// The line does not match the path / masks / command grammar.
    CorruptInput,
}

impl std::fmt::Display for ParseWatchError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let message = match self {
            Self::InvalidMask => "invalid watch mask",
            Self::IsComment => "line is a comment",
            Self::CorruptInput => "malformed watch line",
        };
        f.write_str(message)
    }
}

// Lets callers propagate this with `?` into `Box<dyn Error>` and similar.
impl std::error::Error for ParseWatchError {}
73
/// Boolean flags attached to a single watch entry.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct WatchDataAttributes {
    /// Defaults to `true`.
    /// NOTE(review): not assigned by the table-line parser in this module;
    /// confirm where it is set and consumed.
    pub starting: bool,
    /// Set from a `recursive` attribute in a table line; defaults to `false`.
    pub recursive: bool,
}

impl Default for WatchDataAttributes {
    /// Returns `starting` enabled and `recursive` disabled.
    fn default() -> Self {
        let starting = true;
        let recursive = false;
        WatchDataAttributes {
            starting,
            recursive,
        }
    }
}
88
89#[derive(Debug, Clone, PartialEq, Eq)]
90pub struct WatchData {
91    pub path: PathBuf,
92    pub masks: WatchMask,
93    pub command: Command,
94    pub attributes: WatchDataAttributes,
95}
96
97impl FromStr for WatchData {
98    type Err = ParseWatchError;
99
100    #[tracing::instrument]
101    fn from_str(s: &str) -> Result<Self, Self::Err> {
102        let s = s.trim();
103        if s.starts_with('#') {
104            return Err(ParseWatchError::IsComment);
105        };
106
107        (
108            cut_err(parse_path),
109            cut_err(parse_masks),
110            cut_err(parse_command),
111        )
112            .map(|(path, watch_options, command)| {
113                let mut masks = WatchMask::empty();
114                let mut attributes = WatchDataAttributes::default();
115
116                for option in watch_options {
117                    match option {
118                        WatchOption::Mask(mask) => {
119                            let mask = match mask.parse::<MaskWrapper>() {
120                                Ok(m) => m,
121                                Err(_) => {
122                                    event!(Level::ERROR, mask, "invalid mask");
123                                    return Err(ParseWatchError::InvalidMask);
124                                }
125                            };
126
127                            masks = masks.union(mask.0);
128                        }
129                        WatchOption::Attribute(flag, value) => match flag.as_str() {
130                            "recursive" => attributes.recursive = value,
131                            _ => continue,
132                        },
133                    }
134                }
135
136                Ok(WatchData {
137                    path,
138                    command,
139                    masks,
140                    attributes,
141                })
142            })
143            .parse(s)
144            .map_err(|error| {
145                event!(Level::ERROR, ?error);
146                ParseWatchError::CorruptInput
147            })?
148    }
149}
150
#[cfg(test)]
mod tests {
    use std::path::PathBuf;

    use inotify::WatchMask;

    use crate::watch::{Command, ParseWatchError, WatchData, WatchDataAttributes};

    /// A single well-formed table line.
    const SINGLE_LINE: &str = include_str!("../assets/test/test-line");
    /// A table mixing valid, invalid-mask, comment and corrupt lines.
    const TABLE: &str = include_str!("../assets/test/test-table");

    /// The `WatchData` that every valid fixture line should parse into.
    fn expected_watch() -> WatchData {
        let argv: Vec<String> = ["$@", "$#", "&>", "/dev/null"]
            .iter()
            .map(|arg| arg.to_string())
            .collect();

        WatchData {
            path: PathBuf::from("/var/tmp"),
            masks: WatchMask::CREATE | WatchMask::DELETE,
            attributes: WatchDataAttributes {
                starting: true,
                recursive: true,
            },
            command: Command {
                program: "echo".to_string(),
                argv,
            },
        }
    }

    #[test]
    fn test_parse_line() {
        let parsed: WatchData = SINGLE_LINE.parse().unwrap();
        assert_eq!(parsed, expected_watch());
    }

    #[test]
    fn test_parse_table() {
        let parsed: Vec<Result<WatchData, ParseWatchError>> =
            TABLE.lines().map(str::parse::<WatchData>).collect();

        let expected = vec![
            Ok(expected_watch()),
            Ok(expected_watch()),
            Err(ParseWatchError::InvalidMask),
            Err(ParseWatchError::IsComment),
            Err(ParseWatchError::CorruptInput),
        ];

        assert_eq!(parsed, expected)
    }
}