ssh_agent_switcher/
lib.rs

1// Copyright 2025 Julio Merino.
2// All rights reserved.
3//
4// Redistribution and use in source and binary forms, with or without modification, are permitted
5// provided that the following conditions are met:
6//
7// * Redistributions of source code must retain the above copyright notice, this list of conditions
8//   and the following disclaimer.
9// * Redistributions in binary form must reproduce the above copyright notice, this list of
10//   conditions and the following disclaimer in the documentation and/or other materials provided with
11//   the distribution.
12// * Neither the name of ssh-agent-switcher nor the names of its contributors may be used to endorse
13//   or promote products derived from this software without specific prior written permission.
14//
15// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
16// IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
17// FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
18// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
19// DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
20// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY,
21// WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY
22// WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
23
24//! Serves a Unix domain socket that proxies connections to any valid SSH agent provided by sshd.
25
26use log::{debug, info, warn};
27use std::env;
28use std::fs;
29use std::io;
30use std::path::{Path, PathBuf};
31use std::time::Duration;
32use tokio::net::{UnixListener, UnixStream};
33use tokio::select;
34use tokio::signal::unix::{SignalKind, signal};
35
36mod find;
37
/// Crate-wide result type: every failure is reported as a human-readable error string.
type Result<T> = core::result::Result<T, String>;
40
41/// A scope guard to restore the previous umask.
42struct UmaskGuard {
43    old_umask: libc::mode_t,
44}
45
46impl Drop for UmaskGuard {
47    fn drop(&mut self) {
48        let _ = unsafe { libc::umask(self.old_umask) };
49    }
50}
51
52/// Sets the umask and returns a guard to restore it on drop.
53fn set_umask(umask: libc::mode_t) -> UmaskGuard {
54    UmaskGuard { old_umask: unsafe { libc::umask(umask) } }
55}
56
57/// Creates the agent socket to listen on.
58///
59/// This makes sure that the socket is only accessible by the current user.
60fn create_listener(socket_path: &Path) -> Result<UnixListener> {
61    // Ensure the socket is not group nor world readable so that we don't expose the real socket
62    // indirectly to other users.
63    let _guard = set_umask(0o177);
64
65    UnixListener::bind(socket_path)
66        .map_err(|e| format!("Cannot listen on {}: {}", socket_path.display(), e))
67}
68
69/// Handles one incoming connection on `client`.
70async fn handle_connection(
71    mut client: UnixStream,
72    agents_dirs: &[PathBuf],
73    home: Option<&Path>,
74    uid: libc::uid_t,
75) -> Result<()> {
76    let mut agent = match find::find_socket(agents_dirs, home, uid).await {
77        Some(socket) => socket,
78        None => {
79            return Err("No agent found; cannot proxy request".to_owned());
80        }
81    };
82    let result = tokio::io::copy_bidirectional(&mut client, &mut agent)
83        .await
84        .and_then(|_| Ok(()))
85        .map_err(|e| format!("{}", e));
86    debug!("Closing client connection");
87    result
88}
89
90/// Runs the core logic of the app.
91///
92/// This serves the SSH agent socket on `socket_path` and looks for sshd sockets in `agents_dirs`.
93///
94/// The `pid_file` needs to be passed in for cleanup purposes.
95pub async fn run(socket_path: PathBuf, agents_dirs: &[PathBuf], pid_file: PathBuf) -> Result<()> {
96    let home = env::var("HOME").map(|v| Some(PathBuf::from(v))).unwrap_or(None);
97    let uid = unsafe { libc::getuid() };
98
99    let mut sighup = signal(SignalKind::hangup())
100        .map_err(|e| format!("Failed to install SIGHUP handler: {}", e))?;
101    let mut sigint = signal(SignalKind::interrupt())
102        .map_err(|e| format!("Failed to install SIGINT handler: {}", e))?;
103    let mut sigquit = signal(SignalKind::quit())
104        .map_err(|e| format!("Failed to install SIGQUIT handler: {}", e))?;
105    let mut sigterm = signal(SignalKind::terminate())
106        .map_err(|e| format!("Failed to install SIGTERM handler: {}", e))?;
107
108    let listener = create_listener(&socket_path)?;
109
110    debug!("Entering main loop");
111    let mut stop = None;
112    while stop.is_none() {
113        select! {
114            result = listener.accept() => match result {
115                Ok((socket, _addr)) => {
116                    debug!("Connection accepted");
117                    // TODO(jmmv): Connections are handled sequentially.  This is... fine.
118                    if let Err(e) = handle_connection(socket, agents_dirs, home.as_deref(), uid).await {
119                        warn!("Dropping connection due to error: {}", e);
120                    }
121                }
122                Err(e) => warn!("Failed to accept connection: {}", e),
123            },
124
125            _ = sighup.recv() => (),
126            _ = sigint.recv() => stop = Some("SIGINT"),
127            _ = sigquit.recv() => stop = Some("SIGQUIT"),
128            _ = sigterm.recv() => stop = Some("SIGTERM"),
129        }
130    }
131    debug!("Main loop exited");
132
133    let stop = stop.expect("Loop can only exit by setting stop");
134    info!("Shutting down due to {} and removing {}", stop, socket_path.display());
135
136    let _ = fs::remove_file(&socket_path);
137    // Because we catch signals, daemonize doesn't properly clean up the PID file so we have
138    // to do it ourselves.
139    let _ = fs::remove_file(&pid_file);
140
141    Ok(())
142}
143
144/// Waits for `path` to exist for a maximum period of time using operation `op`.
145/// Returns the result of `op` on success.
146pub fn wait_for_file<P: AsRef<Path> + Copy, T>(
147    path: P,
148    mut pending_wait: Duration,
149    op: fn(P) -> io::Result<T>,
150) -> Result<T> {
151    while pending_wait > Duration::ZERO {
152        match op(path) {
153            Ok(result) => {
154                return Ok(result);
155            }
156            Err(e) if e.kind() == io::ErrorKind::NotFound => {
157                pending_wait -= Duration::from_millis(1);
158            }
159            Err(e) => {
160                return Err(e.to_string());
161            }
162        }
163    }
164    Err("File was not created on time".to_owned())
165}