1use std::collections::HashMap;
2use std::env;
3#[cfg(unix)]
4use std::fs;
5#[cfg(unix)]
6use std::os::unix::fs::FileTypeExt;
7use std::path::{Path, PathBuf};
8use std::process::Command;
9use std::sync::Mutex;
10
11#[cfg(unix)]
12use crate::Store;
13use crate::git;
14
/// Name of the environment variable used to hand git a custom SSH command.
///
/// `configure_ssh_persist` sets it on spawned commands, and
/// `guess_ssh_persist_support` backs off when it is already present in the environment.
const GIT_ENV_SSH: &str = "GIT_SSH_COMMAND";
17
/// SSH command line exported through `GIT_ENV_SSH` to enable connection sharing.
///
/// Options passed to `ssh`: `ControlMaster auto`, a `ControlPath` socket under `/tmp` named
/// `.prs-session--%r@%h:%p`, `ControlPersist 1h` and `ConnectTimeout 10`.
///
/// NOTE: the `ControlPath` directory and file prefix are spelled out literally here and must be
/// kept in sync with `SSH_PERSIST_SESSION_FILE_DIR` and `SSH_PERSIST_SESSION_FILE_PREFIX`.
const SSH_PERSIST_CMD: &str = "ssh -o 'ControlMaster auto' -o 'ControlPath /tmp/.prs-session--%r@%h:%p' -o 'ControlPersist 1h' -o 'ConnectTimeout 10'";
27
/// Directory scanned by `kill_ssh_by_session` for persistent session sockets.
///
/// Matches the directory part of the `ControlPath` in `SSH_PERSIST_CMD`.
#[cfg(unix)]
pub(crate) const SSH_PERSIST_SESSION_FILE_DIR: &str = "/tmp";
31
/// File name prefix that identifies a persistent session socket inside
/// `SSH_PERSIST_SESSION_FILE_DIR`.
///
/// Matches the file name prefix of the `ControlPath` in `SSH_PERSIST_CMD`.
#[cfg(unix)]
pub(crate) const SSH_PERSIST_SESSION_FILE_PREFIX: &str = ".prs-session--";
35
/// Hosts for which SSH connection persistence is considered supported.
///
/// `guess_ssh_persist_support` compares the lowercased host of every SSH remote against this
/// list, so entries must be lowercase.
const SSH_PERSIST_HOST_WHITELIST: [&str; 2] = ["github.com", "gitlab.com"];
38
// Cache of earlier `guess_ssh_persist_support` results, keyed by repository path and guarded
// by a mutex. Only the outcome of the host whitelist check is stored; that function's early
// `false` returns bypass the cache.
lazy_static! {
    static ref SSH_PERSIST_GUESS_CACHE: Mutex<HashMap<PathBuf, bool>> = Mutex::new(HashMap::new());
}
43
44pub(crate) fn configure_ssh_persist(cmd: &mut Command) {
48 cmd.env(self::GIT_ENV_SSH, self::SSH_PERSIST_CMD);
49}
50
51pub(crate) fn guess_ssh_persist_support(repo: &Path) -> bool {
65 if !cfg!(unix) {
67 return false;
68 }
69
70 if env::var_os(GIT_ENV_SSH).is_some() {
72 return false;
73 }
74
75 if let Ok(guard) = (*SSH_PERSIST_GUESS_CACHE).lock()
77 && let Some(supported) = guard.get(repo)
78 {
79 return *supported;
80 }
81
82 let remotes = match git::git_remote(repo) {
84 Ok(remotes) if remotes.is_empty() => return false,
85 Ok(remotes) => remotes,
86 Err(_) => return false,
87 };
88
89 let ssh_uris: Vec<_> = remotes
91 .iter()
92 .filter_map(|remote| git::git_remote_get_url(repo, remote).ok())
93 .filter(|uri| !remote_is_http(uri))
94 .collect();
95
96 let supported = ssh_uris.iter().all(|uri| match ssh_uri_host(uri) {
98 Some(host) => SSH_PERSIST_HOST_WHITELIST.contains(&host.to_lowercase().as_str()),
99 None => false,
100 });
101
102 if let Ok(mut guard) = (*SSH_PERSIST_GUESS_CACHE).lock() {
104 guard.insert(repo.to_path_buf(), supported);
105 }
106
107 supported
108}
109
/// Tell whether `url` is an HTTP or HTTPS URL.
///
/// Surrounding whitespace is ignored. The scheme comparison is case-sensitive.
fn remote_is_http(url: &str) -> bool {
    let trimmed = url.trim();
    ["http://", "https://"]
        .iter()
        .any(|scheme| trimmed.starts_with(*scheme))
}
115
116#[allow(clippy::manual_split_once, clippy::needless_splitn)]
122fn ssh_uri_host(mut uri: &str) -> Option<&str> {
123 if remote_is_http(uri) {
125 return None;
126 }
127
128 if let Some(stripped) = uri.strip_prefix("ssh://") {
130 uri = stripped;
131 }
132
133 let before_slash = uri.splitn(2, '/').next().unwrap();
136 let after_at = before_slash.splitn(2, '@').last().unwrap();
137 let before_collon = after_at.splitn(2, ':').next().unwrap();
138 let uri = before_collon.trim();
139
140 if uri.len() >= 3 { Some(uri) } else { None }
142}
143
144#[cfg(unix)]
148pub fn kill_ssh_by_session(store: &Store) {
149 if !guess_ssh_persist_support(&store.root) {
151 return;
152 }
153
154 let dir = match fs::read_dir(SSH_PERSIST_SESSION_FILE_DIR) {
158 Ok(dir) => dir,
159 Err(_) => return,
160 };
161 let session_files = dir
162 .flatten()
163 .filter(|e| e.file_type().map(|t| t.is_socket()).unwrap_or(false))
164 .filter(|e| {
165 e.file_name()
166 .to_str()
167 .map(|n| n.starts_with(SSH_PERSIST_SESSION_FILE_PREFIX))
168 .unwrap_or(false)
169 })
170 .map(|e| e.path());
171
172 #[cfg(any(target_os = "linux", target_os = "macos", target_os = "freebsd"))]
174 session_files.for_each(|path| {
175 use super::proc::{pids_with_file_open, cmdline};
176
177 let pids = match pids_with_file_open(&path) {
179 Ok(pids) => pids,
180 Err(_) => return,
181 };
182
183 pids.into_iter()
184 .filter(|pid| pid.as_raw() > 0 && pid.as_raw() < nix::libc::pid_t::MAX)
186 .filter(|pid| {
188 cmdline(*pid)
189 .map(|cmdline| {
190 let cmd = cmdline.split([' ', ':']).next().unwrap();
191 cmd.starts_with("ssh")
192 })
193 .unwrap_or(true)
194 })
195 .for_each(|pid| {
196 if let Err(err) = nix::sys::signal::kill(
197 pid,
198 Some(nix::sys::signal::Signal::SIGTERM),
199 ) {
200 eprintln!("Failed to kill persistent SSH client (pid: {pid}): {err}",);
201 }
202 });
203 });
204}