use std::collections::HashMap;
#[cfg(unix)]
use std::os::unix::fs::FileTypeExt;
use std::path::{Path, PathBuf};
use std::process::Command;
use std::sync::Mutex;
use std::{env, fs};
use crate::{git, Store};
#[cfg(unix)]
use ofiles::opath;
/// Name of the environment variable used to tell git which SSH command to run.
///
/// It is set on spawned git commands by `configure_ssh_persist` and, when
/// already present in our own environment, disables persistence guessing.
const GIT_ENV_SSH: &str = "GIT_SSH_COMMAND";
/// SSH command line handed to git to enable connection persistence.
///
/// Requests `ControlMaster auto`, a control socket under `/tmp` named
/// `.prs-session--%r@%h:%p`, `ControlPersist 1h` and `ConnectTimeout 10`.
const SSH_PERSIST_CMD: &str = "ssh -o 'ControlMaster auto' -o 'ControlPath /tmp/.prs-session--%r@%h:%p' -o 'ControlPersist 1h' -o 'ConnectTimeout 10'";
/// Directory scanned for persistent session sockets.
///
/// Matches the directory part of the `ControlPath` in `SSH_PERSIST_CMD`.
#[cfg(unix)]
pub(crate) const SSH_PERSIST_SESSION_FILE_DIR: &str = "/tmp";
/// File name prefix identifying our persistent session sockets.
///
/// Matches the file name prefix of the `ControlPath` in `SSH_PERSIST_CMD`.
#[cfg(unix)]
pub(crate) const SSH_PERSIST_SESSION_FILE_PREFIX: &str = ".prs-session--";
/// Hosts for which SSH connection persistence is assumed to be supported.
///
/// Extracted remote hosts are lowercased before being compared to this list.
const SSH_PERSIST_HOST_WHITELIST: [&str; 2] = ["github.com", "gitlab.com"];
// Process-wide cache of persistence guesses, keyed by repository path, so the
// remotes of a repository only need to be inspected once per process.
lazy_static! {
static ref SSH_PERSIST_GUESS_CACHE: Mutex<HashMap<PathBuf, bool>> = Mutex::new(HashMap::new());
}
pub(crate) fn configure_ssh_persist(cmd: &mut Command) {
cmd.env(self::GIT_ENV_SSH, self::SSH_PERSIST_CMD);
}
/// Guesses whether SSH connection persistence can be used for `repo`.
///
/// Returns `false` when:
/// - not compiled for a Unix target,
/// - the `GIT_ENV_SSH` environment variable is already set,
/// - the repository has no remotes or listing them fails (not cached),
/// - any non-HTTP(S) remote URL has a host that is missing from
///   `SSH_PERSIST_HOST_WHITELIST` or cannot be extracted.
///
/// HTTP(S) remotes and remotes whose URL cannot be obtained are left out of
/// the host check, so a repository with only such remotes yields `true`.
/// The result of the host check is cached per repository path.
pub(crate) fn guess_ssh_persist_support(repo: &Path) -> bool {
    // Persistence is Unix-only, and a user-provided SSH command is never overridden.
    if !cfg!(unix) || env::var_os(GIT_ENV_SSH).is_some() {
        return false;
    }

    // Reuse an earlier guess when available; a poisoned lock counts as a miss.
    let cached = SSH_PERSIST_GUESS_CACHE
        .lock()
        .ok()
        .and_then(|cache| cache.get(repo).copied());
    if let Some(supported) = cached {
        return supported;
    }

    // No remotes, or failing to list them, means there is nothing to persist.
    let remotes = match git::git_remote(repo) {
        Ok(remotes) if !remotes.is_empty() => remotes,
        _ => return false,
    };

    // Resolve every remote URL up front, keeping only the non-HTTP(S) ones.
    let ssh_uris: Vec<_> = remotes
        .iter()
        .filter_map(|remote| git::git_remote_get_url(repo, remote).ok())
        .filter(|uri| !remote_is_http(uri))
        .collect();

    // Every remaining URI must resolve to a whitelisted host.
    let supported = ssh_uris.iter().all(|uri| {
        ssh_uri_host(uri)
            .map(|host| host.to_lowercase())
            .map_or(false, |host| {
                SSH_PERSIST_HOST_WHITELIST.contains(&host.as_str())
            })
    });

    // Remember the guess; failing to lock the cache is not fatal.
    if let Ok(mut cache) = SSH_PERSIST_GUESS_CACHE.lock() {
        cache.insert(repo.to_path_buf(), supported);
    }

    supported
}
/// Reports whether a remote URL uses the `http://` or `https://` scheme.
///
/// Surrounding whitespace is ignored. The scheme comparison is case-sensitive.
fn remote_is_http(url: &str) -> bool {
    let trimmed = url.trim();
    ["http://", "https://"]
        .iter()
        .any(|scheme| trimmed.starts_with(*scheme))
}
/// Extracts the host from an SSH-style git remote URI.
///
/// Handles both the `ssh://[user@]host[:port]/path` form and the scp-like
/// `[user@]host:path` form. Surrounding whitespace is ignored, consistent
/// with `remote_is_http`.
///
/// Returns `None` for HTTP(S) URIs and when the extracted host is shorter
/// than three bytes. The returned slice borrows from `uri`.
// `splitn` is used deliberately instead of `split_once`, hence the lint
// exceptions below.
#[allow(clippy::manual_split_once, clippy::needless_splitn)]
fn ssh_uri_host(uri: &str) -> Option<&str> {
    if remote_is_http(uri) {
        return None;
    }

    // Trim before stripping the scheme: otherwise leading whitespace hides the
    // `ssh://` prefix and the scheme itself would be reported as the host.
    let uri = uri.trim();
    let uri = uri.strip_prefix("ssh://").unwrap_or(uri);

    // Narrow down step by step: drop the path, then the user, then the port
    // (or, for the scp-like form, the path after the colon).
    let authority = uri.splitn(2, '/').next().unwrap_or(uri);
    let host_and_port = authority.splitn(2, '@').last().unwrap_or(authority);
    let host = host_and_port
        .splitn(2, ':')
        .next()
        .unwrap_or(host_and_port)
        .trim();

    // Anything shorter than three bytes is not treated as a usable host.
    if host.len() >= 3 {
        Some(host)
    } else {
        None
    }
}
/// Terminates lingering persistent SSH clients that belong to our sessions.
///
/// Does nothing unless [`guess_ssh_persist_support`] reports support for the
/// store root. Otherwise every socket in `SSH_PERSIST_SESSION_FILE_DIR` whose
/// name starts with `SSH_PERSIST_SESSION_FILE_PREFIX` is passed to `opath`,
/// and each process ID it returns is sent `SIGTERM` if its command line looks
/// like an SSH client.
///
/// This is best-effort: unreadable directories and sockets are skipped
/// silently, and a failed kill is only reported on stderr.
#[cfg(unix)]
pub fn kill_ssh_by_session(store: &Store) {
    if !guess_ssh_persist_support(&store.root) {
        return;
    }

    let entries = match fs::read_dir(SSH_PERSIST_SESSION_FILE_DIR) {
        Ok(entries) => entries,
        Err(_) => return,
    };

    for entry in entries.flatten() {
        // Only sockets carrying our session prefix are of interest.
        let is_socket = entry.file_type().map(|t| t.is_socket()).unwrap_or(false);
        if !is_socket {
            continue;
        }
        let is_session_file = entry
            .file_name()
            .to_str()
            .map(|name| name.starts_with(SSH_PERSIST_SESSION_FILE_PREFIX))
            .unwrap_or(false);
        if !is_session_file {
            continue;
        }

        let pids = match opath(entry.path()) {
            Ok(pids) => pids,
            Err(_) => continue,
        };

        for raw_pid in pids {
            // Keep only PIDs that are non-zero and fit in a positive `i32`.
            let pid: u32 = raw_pid.into();
            if pid == 0 || pid >= i32::MAX as u32 {
                continue;
            }

            // Only target SSH clients; when the command line cannot be read,
            // the process is still treated as one.
            let looks_like_ssh = match fs::read_to_string(format!("/proc/{}/cmdline", pid)) {
                Ok(cmdline) => cmdline
                    .split(|c| c == ' ' || c == ':')
                    .next()
                    .unwrap()
                    .starts_with("ssh"),
                Err(_) => true,
            };
            if !looks_like_ssh {
                continue;
            }

            let outcome = nix::sys::signal::kill(
                nix::unistd::Pid::from_raw(pid as i32),
                Some(nix::sys::signal::Signal::SIGTERM),
            );
            if let Err(err) = outcome {
                eprintln!(
                    "Failed to kill persistent SSH client (pid: {}): {}",
                    pid, err
                );
            }
        }
    }
}