Skip to main content

watch_path/
ssh.rs

1use std::collections::HashMap;
2use std::io::Read;
3use std::net::TcpStream;
4use std::time::{Duration, Instant};
5
6use ssh2::Session;
7
8use crate::url::WatchTarget;
9use crate::watcher::{
10    ConnectionState, PathWatcher, WatchError, WatchEvent, WatchEventKind, WatchOptions,
11};
12
13enum WatchMode {
14    InotifyPush {
15        channel: ssh2::Channel,
16        buf: Vec<u8>,
17    },
18    StatPoll {
19        known_mtimes: HashMap<String, i64>,
20        last_poll: Instant,
21    },
22}
23
24pub struct SshWatcher {
25    session: Session,
26    target: WatchTarget,
27    mode: WatchMode,
28    pending: Vec<WatchEvent>,
29    poll_interval: Duration,
30    loss_timeout: Duration,
31    last_success: Instant,
32}
33
34impl SshWatcher {
35    pub fn connect(target: WatchTarget, options: &WatchOptions) -> Result<Self, WatchError> {
36        let host = target
37            .host
38            .as_deref()
39            .ok_or_else(|| WatchError::InvalidUrl("SSH requires a host".to_string()))?;
40        let port = target.port.unwrap_or(22);
41
42        let tcp = TcpStream::connect(format!("{host}:{port}"))
43            .map_err(|e| WatchError::Connection(e.to_string()))?;
44
45        let mut session = Session::new().map_err(|e| WatchError::Ssh(e.to_string()))?;
46        session.set_tcp_stream(tcp);
47        session
48            .handshake()
49            .map_err(|e| WatchError::Ssh(e.to_string()))?;
50
51        let user = target.user.as_deref().unwrap_or("root");
52        authenticate(&session, user, options)?;
53
54        let mode = try_inotifywait(&session, &target.path).unwrap_or_else(|| WatchMode::StatPoll {
55            known_mtimes: HashMap::new(),
56            last_poll: Instant::now() - options.poll_interval,
57        });
58
59        Ok(Self {
60            session,
61            target,
62            mode,
63            pending: Vec::new(),
64            poll_interval: options.poll_interval,
65            loss_timeout: options.loss_timeout,
66            last_success: Instant::now(),
67        })
68    }
69}
70
71fn authenticate(session: &Session, user: &str, options: &WatchOptions) -> Result<(), WatchError> {
72    if let Some(key_path) = &options.key_path {
73        session
74            .userauth_pubkey_file(user, None, key_path, options.password.as_deref())
75            .map_err(|e| WatchError::Ssh(format!("key auth failed: {e}")))?;
76    } else if let Some(password) = &options.password {
77        session
78            .userauth_password(user, password)
79            .map_err(|e| WatchError::Ssh(format!("password auth failed: {e}")))?;
80    } else {
81        session
82            .userauth_agent(user)
83            .map_err(|e| WatchError::Ssh(format!("agent auth failed: {e}")))?;
84    }
85    Ok(())
86}
87
88fn try_inotifywait(session: &Session, path: &str) -> Option<WatchMode> {
89    let mut check = session.channel_session().ok()?;
90    check.exec("which inotifywait").ok()?;
91    let mut output = String::new();
92    check.read_to_string(&mut output).ok()?;
93    check.wait_close().ok()?;
94    if check.exit_status().ok()? != 0 {
95        return None;
96    }
97
98    let mut channel = session.channel_session().ok()?;
99    let quoted_path = shlex::try_quote(path).ok()?;
100    let cmd = format!("inotifywait -m -r --format '%w%f %e' {quoted_path}");
101    channel.exec(&cmd).ok()?;
102
103    Some(WatchMode::InotifyPush {
104        channel,
105        buf: Vec::new(),
106    })
107}
108
109impl PathWatcher for SshWatcher {
110    fn poll(&mut self) -> Result<Vec<WatchEvent>, WatchError> {
111        match &mut self.mode {
112            WatchMode::InotifyPush { channel, buf } => {
113                self.session.set_blocking(false);
114                let mut tmp = [0u8; 4096];
115                loop {
116                    match channel.read(&mut tmp) {
117                        Ok(0) => break,
118                        Ok(n) => buf.extend_from_slice(&tmp[..n]),
119                        Err(ref e) if e.kind() == std::io::ErrorKind::WouldBlock => break,
120                        Err(e) => {
121                            self.session.set_blocking(true);
122                            return Err(WatchError::Ssh(e.to_string()));
123                        }
124                    }
125                }
126                self.session.set_blocking(true);
127
128                while let Some(pos) = buf.iter().position(|&b| b == b'\n') {
129                    let line = String::from_utf8_lossy(&buf[..pos]).to_string();
130                    buf.drain(..=pos);
131                    if let Some(event) = parse_inotify_line(&line) {
132                        self.pending.push(event);
133                    }
134                }
135
136                if !self.pending.is_empty() {
137                    self.last_success = Instant::now();
138                }
139            }
140            WatchMode::StatPoll {
141                known_mtimes,
142                last_poll,
143            } => {
144                if last_poll.elapsed() < self.poll_interval {
145                    return Ok(Vec::new());
146                }
147                *last_poll = Instant::now();
148
149                let path = self.target.path.clone();
150                let mut channel = self
151                    .session
152                    .channel_session()
153                    .map_err(|e| WatchError::Ssh(e.to_string()))?;
154
155                let quoted_path = shlex::try_quote(&path).map_err(|_| {
156                    WatchError::InvalidUrl(format!("path contains invalid characters: {path}"))
157                })?;
158                let cmd = format!("find {quoted_path} -type f -printf '%p %T@\\n'");
159                channel
160                    .exec(&cmd)
161                    .map_err(|e| WatchError::Ssh(e.to_string()))?;
162
163                let mut output = String::new();
164                channel
165                    .read_to_string(&mut output)
166                    .map_err(|e| WatchError::Ssh(e.to_string()))?;
167                let _ = channel.wait_close();
168
169                self.last_success = Instant::now();
170
171                let mut current_mtimes: HashMap<String, i64> = HashMap::new();
172                for line in output.lines() {
173                    let parts: Vec<&str> = line.rsplitn(2, ' ').collect();
174                    if parts.len() != 2 {
175                        continue;
176                    }
177                    let mtime_str = parts[0];
178                    let file_path = parts[1];
179                    if let Ok(mtime) = mtime_str.parse::<f64>() {
180                        current_mtimes.insert(file_path.to_string(), mtime as i64);
181                    }
182                }
183
184                for (file_path, mtime) in &current_mtimes {
185                    let changed = match known_mtimes.get(file_path) {
186                        Some(old_mtime) => *mtime != *old_mtime,
187                        None => true,
188                    };
189                    if changed {
190                        let kind = if known_mtimes.contains_key(file_path) {
191                            WatchEventKind::Modified
192                        } else {
193                            WatchEventKind::Created
194                        };
195                        self.pending.push(WatchEvent {
196                            path: file_path.clone(),
197                            kind,
198                        });
199                    }
200                }
201
202                *known_mtimes = current_mtimes;
203            }
204        }
205
206        Ok(std::mem::take(&mut self.pending))
207    }
208
209    fn read(&mut self, path: &str) -> Result<Vec<u8>, WatchError> {
210        let sftp = self
211            .session
212            .sftp()
213            .map_err(|e| WatchError::Ssh(e.to_string()))?;
214
215        let mut file = sftp
216            .open(std::path::Path::new(path))
217            .map_err(|e| WatchError::Ssh(e.to_string()))?;
218
219        let mut buf = Vec::new();
220        file.read_to_end(&mut buf)
221            .map_err(|e| WatchError::Ssh(e.to_string()))?;
222
223        self.last_success = Instant::now();
224        Ok(buf)
225    }
226
227    fn has_pending(&self) -> bool {
228        !self.pending.is_empty()
229    }
230
231    fn connection_state(&self) -> ConnectionState {
232        let elapsed = self.last_success.elapsed();
233        if elapsed < self.poll_interval * 2 {
234            ConnectionState::Connected
235        } else if elapsed < self.loss_timeout {
236            ConnectionState::Degraded
237        } else {
238            ConnectionState::Lost
239        }
240    }
241}
242
243fn parse_inotify_line(line: &str) -> Option<WatchEvent> {
244    let parts: Vec<&str> = line.splitn(2, ' ').collect();
245    if parts.len() != 2 {
246        return None;
247    }
248
249    let path = parts[0].to_string();
250    let events_str = parts[1];
251
252    let kind = if events_str.contains("CREATE") {
253        WatchEventKind::Created
254    } else if events_str.contains("MODIFY") || events_str.contains("CLOSE_WRITE") {
255        WatchEventKind::Modified
256    } else {
257        return None;
258    };
259
260    Some(WatchEvent { path, kind })
261}