1use std::collections::HashMap;
2use std::env;
3#[cfg(unix)]
4use std::fs;
5#[cfg(unix)]
6use std::os::unix::fs::FileTypeExt;
7use std::path::{Path, PathBuf};
8use std::process::Command;
9use std::sync::Mutex;
10
11#[cfg(unix)]
12use crate::Store;
13use crate::git;
14
15#[cfg(unix)]
16use ofiles::opath;
17
/// Environment variable git reads to choose the command it uses to run SSH.
const GIT_ENV_SSH: &str = "GIT_SSH_COMMAND";

/// SSH command that multiplexes connections over a persistent control socket.
///
/// Options, in order:
/// - `ControlMaster auto`: use an existing master connection, or create one if none exists.
/// - `ControlPath ...`: location of the control socket, named after the remote user (`%r`),
///   host (`%h`) and port (`%p`). It must stay in sync with [`SSH_PERSIST_SESSION_FILE_DIR`]
///   and [`SSH_PERSIST_SESSION_FILE_PREFIX`], which are used to find these sockets again.
/// - `ControlPersist 1h`: keep the backgrounded master alive while idle for up to one hour.
/// - `ConnectTimeout 10`: connection timeout, in seconds.
///
/// NOTE(review): the socket lives in the shared `/tmp` and is keyed only by remote user, host
/// and port, so different local users reaching the same host would use the same path —
/// consider whether a per-local-user token is needed.
const SSH_PERSIST_CMD: &str = concat!(
    "ssh",
    " -o 'ControlMaster auto'",
    " -o 'ControlPath /tmp/.prs-session--%r@%h:%p'",
    " -o 'ControlPersist 1h'",
    " -o 'ConnectTimeout 10'"
);

/// Directory holding the SSH control sockets; must match the `ControlPath` in
/// [`SSH_PERSIST_CMD`].
#[cfg(unix)]
pub(crate) const SSH_PERSIST_SESSION_FILE_DIR: &str = "/tmp";

/// File name prefix of the SSH control sockets; must match the `ControlPath` in
/// [`SSH_PERSIST_CMD`].
#[cfg(unix)]
pub(crate) const SSH_PERSIST_SESSION_FILE_PREFIX: &str = ".prs-session--";

/// Hosts for which SSH session persistence may be enabled, listed in lowercase.
const SSH_PERSIST_HOST_WHITELIST: [&str; 2] = ["github.com", "gitlab.com"];
41
/// Process-wide cache of SSH persistence guesses, keyed by repository path.
///
/// Initialized on first access, because `HashMap::new` cannot be evaluated in a `const`
/// context.
static SSH_PERSIST_GUESS_CACHE: std::sync::LazyLock<Mutex<HashMap<PathBuf, bool>>> =
    std::sync::LazyLock::new(|| Mutex::new(HashMap::new()));
46
47pub(crate) fn configure_ssh_persist(cmd: &mut Command) {
51 cmd.env(self::GIT_ENV_SSH, self::SSH_PERSIST_CMD);
52}
53
54pub(crate) fn guess_ssh_persist_support(repo: &Path) -> bool {
68 if !cfg!(unix) {
70 return false;
71 }
72
73 if env::var_os(GIT_ENV_SSH).is_some() {
75 return false;
76 }
77
78 if let Ok(guard) = (*SSH_PERSIST_GUESS_CACHE).lock()
80 && let Some(supported) = guard.get(repo)
81 {
82 return *supported;
83 }
84
85 let remotes = match git::git_remote(repo) {
87 Ok(remotes) if remotes.is_empty() => return false,
88 Ok(remotes) => remotes,
89 Err(_) => return false,
90 };
91
92 let ssh_uris: Vec<_> = remotes
94 .iter()
95 .filter_map(|remote| git::git_remote_get_url(repo, remote).ok())
96 .filter(|uri| !remote_is_http(uri))
97 .collect();
98
99 let supported = ssh_uris.iter().all(|uri| match ssh_uri_host(uri) {
101 Some(host) => SSH_PERSIST_HOST_WHITELIST.contains(&host.to_lowercase().as_str()),
102 None => false,
103 });
104
105 if let Ok(mut guard) = (*SSH_PERSIST_GUESS_CACHE).lock() {
107 guard.insert(repo.to_path_buf(), supported);
108 }
109
110 supported
111}
112
/// Check whether a git remote URL uses the `http://` or `https://` scheme.
///
/// Surrounding whitespace is ignored. The scheme match is case-sensitive.
fn remote_is_http(url: &str) -> bool {
    let trimmed = url.trim();
    ["http://", "https://"]
        .iter()
        .any(|scheme| trimmed.starts_with(*scheme))
}
118
119#[allow(clippy::manual_split_once, clippy::needless_splitn)]
125fn ssh_uri_host(mut uri: &str) -> Option<&str> {
126 if remote_is_http(uri) {
128 return None;
129 }
130
131 if let Some(stripped) = uri.strip_prefix("ssh://") {
133 uri = stripped;
134 }
135
136 let before_slash = uri.splitn(2, '/').next().unwrap();
139 let after_at = before_slash.splitn(2, '@').last().unwrap();
140 let before_collon = after_at.splitn(2, ':').next().unwrap();
141 let uri = before_collon.trim();
142
143 if uri.len() >= 3 { Some(uri) } else { None }
145}
146
147#[cfg(unix)]
151pub fn kill_ssh_by_session(store: &Store) {
152 if !guess_ssh_persist_support(&store.root) {
154 return;
155 }
156
157 let dir = match fs::read_dir(SSH_PERSIST_SESSION_FILE_DIR) {
161 Ok(dir) => dir,
162 Err(_) => return,
163 };
164 let session_files = dir
165 .flatten()
166 .filter(|e| e.file_type().map(|t| t.is_socket()).unwrap_or(false))
167 .filter(|e| {
168 e.file_name()
169 .to_str()
170 .map(|n| n.starts_with(SSH_PERSIST_SESSION_FILE_PREFIX))
171 .unwrap_or(false)
172 })
173 .map(|e| e.path());
174
175 session_files.for_each(|p| {
177 let pids = match opath(p) {
179 Ok(pids) => pids,
180 Err(_) => return,
181 };
182
183 pids.into_iter()
184 .map(Into::into)
185 .filter(|pid: &u32| pid > &0 && pid < &(i32::MAX as u32))
186 .filter(|pid| {
187 fs::read_to_string(format!("/proc/{pid}/cmdline"))
189 .map(|cmdline| {
190 let cmd = cmdline.split([' ', ':']).next().unwrap();
191 cmd.starts_with("ssh")
192 })
193 .unwrap_or(true)
194 })
195 .for_each(|pid| {
196 if let Err(err) = nix::sys::signal::kill(
197 nix::unistd::Pid::from_raw(pid as i32),
198 Some(nix::sys::signal::Signal::SIGTERM),
199 ) {
200 eprintln!("Failed to kill persistent SSH client (pid: {pid}): {err}",);
201 }
202 });
203 });
204}