prs_lib/util/
git.rs

1use std::collections::HashMap;
2use std::env;
3#[cfg(unix)]
4use std::fs;
5#[cfg(unix)]
6use std::os::unix::fs::FileTypeExt;
7use std::path::{Path, PathBuf};
8use std::process::Command;
9use std::sync::Mutex;
10
11#[cfg(unix)]
12use crate::Store;
13use crate::git;
14
15#[cfg(unix)]
16use ofiles::opath;
17
/// Name of the environment variable through which git accepts a custom SSH command.
const GIT_ENV_SSH: &str = "GIT_SSH_COMMAND";
20
/// Custom SSH command for git.
///
/// This command enables SSH connection persistence so sessions are reused, making remote git
/// operations much quicker for repositories using an SSH URL. This greatly improves prs sync
/// speeds.
///
/// Session sockets are created in the shared `/tmp` directory (not a per-user directory), named
/// `.prs-session--` followed by ssh's `%r@%h:%p` tokens. This path must stay in sync with
/// `SSH_PERSIST_SESSION_FILE_DIR` and `SSH_PERSIST_SESSION_FILE_PREFIX`. The master connection is
/// kept alive for one hour (`ControlPersist 1h`), and a connect timeout of 10 seconds quickly aborts
/// a connection attempt if the persistent connection fails.
const SSH_PERSIST_CMD: &str = concat!(
    "ssh",
    " -o 'ControlMaster auto'",
    " -o 'ControlPath /tmp/.prs-session--%r@%h:%p'",
    " -o 'ControlPersist 1h'",
    " -o 'ConnectTimeout 10'"
);
30
/// Directory in which persistent SSH session sockets are created.
///
/// Must match the directory used in the `ControlPath` of `SSH_PERSIST_CMD`.
#[cfg(unix)]
pub(crate) const SSH_PERSIST_SESSION_FILE_DIR: &str = "/tmp";
34
/// File name prefix shared by all prs persistent SSH session sockets.
///
/// Must match the file name prefix used in the `ControlPath` of `SSH_PERSIST_CMD`.
#[cfg(unix)]
pub(crate) const SSH_PERSIST_SESSION_FILE_PREFIX: &str = ".prs-session--";
38
/// SSH hosts known to work with connection persisting.
///
/// Entries are lowercase; hosts are lowercased before being compared against this list.
const SSH_PERSIST_HOST_WHITELIST: [&str; 2] = [
    "github.com",
    "gitlab.com",
];
41
/// Cache for the SSH connection persistence support guess, keyed by repository path.
///
/// Lazily initialized on first access; uses the standard library `LazyLock` rather than the
/// `lazy_static!` macro. Dereferences to the inner `Mutex`, so `SSH_PERSIST_GUESS_CACHE.lock()`
/// works as before.
static SSH_PERSIST_GUESS_CACHE: std::sync::LazyLock<Mutex<HashMap<PathBuf, bool>>> =
    std::sync::LazyLock::new(|| Mutex::new(HashMap::new()));
46
47/// Configure given git command to use SSH connection persisting.
48///
49/// `guess_ssh_connection_persist_support` should be used to guess whether this is supported.
50pub(crate) fn configure_ssh_persist(cmd: &mut Command) {
51    cmd.env(self::GIT_ENV_SSH, self::SSH_PERSIST_CMD);
52}
53
54/// Guess whether SSH connection persistence is supported.
55///
56/// This does a best effort to determine whether SSH connection persistence is supported. This is
57/// used to enable connection reuse. This internally caches the guess in the current process by
58/// repository path.
59///
60/// - Disabled on non-Unix
61/// - Disabled if user set `GIT_SSH_COMMAND`
62/// - Requires all repository SSH remote hosts to be whitelisted
63///
64/// Related: https://gitlab.com/timvisee/prs/-/issues/31
65/// Related: https://github.com/timvisee/prs/issues/5#issuecomment-803940880
66// TODO: make configurable, add current user ID to path
67pub(crate) fn guess_ssh_persist_support(repo: &Path) -> bool {
68    // We must be using Unix, unreliable on Windows (and others?)
69    if !cfg!(unix) {
70        return false;
71    }
72
73    // User must not have set GIT_SSH_COMMAND variable
74    if env::var_os(GIT_ENV_SSH).is_some() {
75        return false;
76    }
77
78    // Get cached result
79    if let Ok(guard) = (*SSH_PERSIST_GUESS_CACHE).lock()
80        && let Some(supported) = guard.get(repo)
81    {
82        return *supported;
83    }
84
85    // Gather git remotes, assume not supported if no remote or error
86    let remotes = match git::git_remote(repo) {
87        Ok(remotes) if remotes.is_empty() => return false,
88        Ok(remotes) => remotes,
89        Err(_) => return false,
90    };
91
92    // Get remote host bits, ensure we have all
93    let ssh_uris: Vec<_> = remotes
94        .iter()
95        .filter_map(|remote| git::git_remote_get_url(repo, remote).ok())
96        .filter(|uri| !remote_is_http(uri))
97        .collect();
98
99    // Ensure all SSH URI hosts are part of whitelist, assume incompatible on error
100    let supported = ssh_uris.iter().all(|uri| match ssh_uri_host(uri) {
101        Some(host) => SSH_PERSIST_HOST_WHITELIST.contains(&host.to_lowercase().as_str()),
102        None => false,
103    });
104
105    // Cache result
106    if let Ok(mut guard) = (*SSH_PERSIST_GUESS_CACHE).lock() {
107        guard.insert(repo.to_path_buf(), supported);
108    }
109
110    supported
111}
112
/// Tell whether a git remote URL uses HTTP(S) instead of SSH.
///
/// Leading and trailing whitespace is ignored. The scheme comparison is case-sensitive.
fn remote_is_http(url: &str) -> bool {
    let url = url.trim();
    ["http://", "https://"]
        .into_iter()
        .any(|scheme| url.starts_with(scheme))
}
118
119/// Grab the host bit of an SSH URI.
120///
121/// This will do a best effort to grap the host bit of an SSH URI. If an HTTP(S) URL is given, or
122/// if the host bit could not be determined, `None` is returned. Note that this may not be very
123/// reliable.
124#[allow(clippy::manual_split_once, clippy::needless_splitn)]
125fn ssh_uri_host(mut uri: &str) -> Option<&str> {
126    // Must not be a HTTP(S) URL
127    if remote_is_http(uri) {
128        return None;
129    }
130
131    // Strip any ssh prefix
132    if let Some(stripped) = uri.strip_prefix("ssh://") {
133        uri = stripped;
134    }
135
136    // Strip the URI until we're left with the host
137    // TODO: this is potentially unreliable, improve this logic
138    let before_slash = uri.splitn(2, '/').next().unwrap();
139    let after_at = before_slash.splitn(2, '@').last().unwrap();
140    let before_collon = after_at.splitn(2, ':').next().unwrap();
141    let uri = before_collon.trim();
142
143    // Ensure the host is at least 3 characters long
144    if uri.len() >= 3 { Some(uri) } else { None }
145}
146
147/// Kill SSH clients that have an opened persistent session on a password store.
148///
149/// Closing these is required to close any open Tomb mount.
150#[cfg(unix)]
151pub fn kill_ssh_by_session(store: &Store) {
152    // If persistent SSH isn't used, we don't have to close sessions
153    if !guess_ssh_persist_support(&store.root) {
154        return;
155    }
156
157    // TODO: guess SSH session directory and file details from environment variable
158
159    // Find prs persistent SSH session files
160    let dir = match fs::read_dir(SSH_PERSIST_SESSION_FILE_DIR) {
161        Ok(dir) => dir,
162        Err(_) => return,
163    };
164    let session_files = dir
165        .flatten()
166        .filter(|e| e.file_type().map(|t| t.is_socket()).unwrap_or(false))
167        .filter(|e| {
168            e.file_name()
169                .to_str()
170                .map(|n| n.starts_with(SSH_PERSIST_SESSION_FILE_PREFIX))
171                .unwrap_or(false)
172        })
173        .map(|e| e.path());
174
175    // For each session file, kill attached SSH clients
176    session_files.for_each(|p| {
177        // List PIDs having this session file open
178        let pids = match opath(p) {
179            Ok(pids) => pids,
180            Err(_) => return,
181        };
182
183        pids.into_iter()
184            .map(Into::into)
185            .filter(|pid: &u32| pid > &0 && pid < &(i32::MAX as u32))
186            .filter(|pid| {
187                // Only handle ssh clients
188                fs::read_to_string(format!("/proc/{pid}/cmdline"))
189                    .map(|cmdline| {
190                        let cmd = cmdline.split([' ', ':']).next().unwrap();
191                        cmd.starts_with("ssh")
192                    })
193                    .unwrap_or(true)
194            })
195            .for_each(|pid| {
196                if let Err(err) = nix::sys::signal::kill(
197                    nix::unistd::Pid::from_raw(pid as i32),
198                    Some(nix::sys::signal::Signal::SIGTERM),
199                ) {
200                    eprintln!("Failed to kill persistent SSH client (pid: {pid}): {err}",);
201                }
202            });
203    });
204}