prs_lib/util/
git.rs

1use std::collections::HashMap;
2use std::env;
3#[cfg(unix)]
4use std::fs;
5#[cfg(unix)]
6use std::os::unix::fs::FileTypeExt;
7use std::path::{Path, PathBuf};
8use std::process::Command;
9use std::sync::Mutex;
10
11#[cfg(unix)]
12use crate::Store;
13use crate::git;
14
/// Name of the environment variable git reads to override the ssh command it invokes.
const GIT_ENV_SSH: &str = "GIT_SSH_COMMAND";

/// Custom ssh command for git.
///
/// This command turns on SSH connection persistence, so an already established session is reused
/// by later remote git operations on repositories with an SSH URL. This greatly improves prs sync
/// speeds.
///
/// The session socket is placed in `/tmp`, using the session file prefix followed by the
/// `%r@%h:%p` tokens that ssh expands itself. The location and prefix used here must stay in sync
/// with `SSH_PERSIST_SESSION_FILE_DIR` and `SSH_PERSIST_SESSION_FILE_PREFIX`. A connect timeout
/// of 10 seconds is set to quickly abort a connection attempt if the persistent connection fails.
const SSH_PERSIST_CMD: &str = concat!(
    "ssh",
    " -o 'ControlMaster auto'",
    " -o 'ControlPath /tmp/.prs-session--%r@%h:%p'",
    " -o 'ControlPersist 1h'",
    " -o 'ConnectTimeout 10'",
);

/// Directory in which SSH persistent session files are created.
#[cfg(unix)]
pub(crate) const SSH_PERSIST_SESSION_FILE_DIR: &str = "/tmp";

/// File name prefix shared by all SSH persistent session files.
#[cfg(unix)]
pub(crate) const SSH_PERSIST_SESSION_FILE_PREFIX: &str = ".prs-session--";

/// SSH hosts that are known to support connection persisting.
///
/// Entries are written in lowercase.
const SSH_PERSIST_HOST_WHITELIST: [&str; 2] = ["github.com", "gitlab.com"];
38
/// Cache for SSH connection persistence support guess.
///
/// Maps a repository path to the guess made for it earlier in this process. The map is created
/// lazily on first access and is guarded by a mutex.
static SSH_PERSIST_GUESS_CACHE: std::sync::LazyLock<Mutex<HashMap<PathBuf, bool>>> =
    std::sync::LazyLock::new(|| Mutex::new(HashMap::new()));
43
44/// Configure given git command to use SSH connection persisting.
45///
46/// `guess_ssh_connection_persist_support` should be used to guess whether this is supported.
47pub(crate) fn configure_ssh_persist(cmd: &mut Command) {
48    cmd.env(self::GIT_ENV_SSH, self::SSH_PERSIST_CMD);
49}
50
51/// Guess whether SSH connection persistence is supported.
52///
53/// This does a best effort to determine whether SSH connection persistence is supported. This is
54/// used to enable connection reuse. This internally caches the guess in the current process by
55/// repository path.
56///
57/// - Disabled on non-Unix
58/// - Disabled if user set `GIT_SSH_COMMAND`
59/// - Requires all repository SSH remote hosts to be whitelisted
60///
61/// Related: https://gitlab.com/timvisee/prs/-/issues/31
62/// Related: https://github.com/timvisee/prs/issues/5#issuecomment-803940880
63// TODO: make configurable, add current user ID to path
64pub(crate) fn guess_ssh_persist_support(repo: &Path) -> bool {
65    // We must be using Unix, unreliable on Windows (and others?)
66    if !cfg!(unix) {
67        return false;
68    }
69
70    // User must not have set GIT_SSH_COMMAND variable
71    if env::var_os(GIT_ENV_SSH).is_some() {
72        return false;
73    }
74
75    // Get cached result
76    if let Ok(guard) = (*SSH_PERSIST_GUESS_CACHE).lock()
77        && let Some(supported) = guard.get(repo)
78    {
79        return *supported;
80    }
81
82    // Gather git remotes, assume not supported if no remote or error
83    let remotes = match git::git_remote(repo) {
84        Ok(remotes) if remotes.is_empty() => return false,
85        Ok(remotes) => remotes,
86        Err(_) => return false,
87    };
88
89    // Get remote host bits, ensure we have all
90    let ssh_uris: Vec<_> = remotes
91        .iter()
92        .filter_map(|remote| git::git_remote_get_url(repo, remote).ok())
93        .filter(|uri| !remote_is_http(uri))
94        .collect();
95
96    // Ensure all SSH URI hosts are part of whitelist, assume incompatible on error
97    let supported = ssh_uris.iter().all(|uri| match ssh_uri_host(uri) {
98        Some(host) => SSH_PERSIST_HOST_WHITELIST.contains(&host.to_lowercase().as_str()),
99        None => false,
100    });
101
102    // Cache result
103    if let Ok(mut guard) = (*SSH_PERSIST_GUESS_CACHE).lock() {
104        guard.insert(repo.to_path_buf(), supported);
105    }
106
107    supported
108}
109
/// Tell whether the given git remote URI uses HTTP(S) rather than SSH.
///
/// Surrounding whitespace is ignored. The scheme is matched case-sensitively.
fn remote_is_http(url: &str) -> bool {
    let trimmed = url.trim();
    ["http://", "https://"]
        .iter()
        .any(|&scheme| trimmed.starts_with(scheme))
}
115
116/// Grab the host bit of an SSH URI.
117///
118/// This will do a best effort to grap the host bit of an SSH URI. If an HTTP(S) URL is given, or
119/// if the host bit could not be determined, `None` is returned. Note that this may not be very
120/// reliable.
121#[allow(clippy::manual_split_once, clippy::needless_splitn)]
122fn ssh_uri_host(mut uri: &str) -> Option<&str> {
123    // Must not be a HTTP(S) URL
124    if remote_is_http(uri) {
125        return None;
126    }
127
128    // Strip any ssh prefix
129    if let Some(stripped) = uri.strip_prefix("ssh://") {
130        uri = stripped;
131    }
132
133    // Strip the URI until we're left with the host
134    // TODO: this is potentially unreliable, improve this logic
135    let before_slash = uri.splitn(2, '/').next().unwrap();
136    let after_at = before_slash.splitn(2, '@').last().unwrap();
137    let before_collon = after_at.splitn(2, ':').next().unwrap();
138    let uri = before_collon.trim();
139
140    // Ensure the host is at least 3 characters long
141    if uri.len() >= 3 { Some(uri) } else { None }
142}
143
144/// Kill SSH clients that have an opened persistent session on a password store.
145///
146/// Closing these is required to close any open Tomb mount.
147#[cfg(unix)]
148pub fn kill_ssh_by_session(store: &Store) {
149    // If persistent SSH isn't used, we don't have to close sessions
150    if !guess_ssh_persist_support(&store.root) {
151        return;
152    }
153
154    // TODO: guess SSH session directory and file details from environment variable
155
156    // Find prs persistent SSH session files
157    let dir = match fs::read_dir(SSH_PERSIST_SESSION_FILE_DIR) {
158        Ok(dir) => dir,
159        Err(_) => return,
160    };
161    let session_files = dir
162        .flatten()
163        .filter(|e| e.file_type().map(|t| t.is_socket()).unwrap_or(false))
164        .filter(|e| {
165            e.file_name()
166                .to_str()
167                .map(|n| n.starts_with(SSH_PERSIST_SESSION_FILE_PREFIX))
168                .unwrap_or(false)
169        })
170        .map(|e| e.path());
171
172    // For each session file, kill attached SSH clients
173    #[cfg(any(target_os = "linux", target_os = "macos", target_os = "freebsd"))]
174    session_files.for_each(|path| {
175        use super::proc::{pids_with_file_open, cmdline};
176
177        // List PIDs having this session file open
178        let pids = match pids_with_file_open(&path) {
179            Ok(pids) => pids,
180            Err(_) => return,
181        };
182
183        pids.into_iter()
184            // PID must be in valid range
185            .filter(|pid| pid.as_raw() > 0 && pid.as_raw() < nix::libc::pid_t::MAX)
186            // Only kill commands starting with "ssh"
187            .filter(|pid| {
188                cmdline(*pid)
189                        .map(|cmdline| {
190                            let cmd = cmdline.split([' ', ':']).next().unwrap();
191                            cmd.starts_with("ssh")
192                        })
193                        .unwrap_or(true)
194            })
195            .for_each(|pid| {
196                if let Err(err) = nix::sys::signal::kill(
197                    pid,
198                    Some(nix::sys::signal::Signal::SIGTERM),
199                ) {
200                    eprintln!("Failed to kill persistent SSH client (pid: {pid}): {err}",);
201                }
202            });
203    });
204}