ssh_agent_switcher/
lib.rs1use log::{debug, info, warn};
27use std::env;
28use std::fs;
29use std::io;
30use std::path::{Path, PathBuf};
31use std::time::Duration;
32use tokio::net::{UnixListener, UnixStream};
33use tokio::select;
34use tokio::signal::unix::{SignalKind, signal};
35
36mod find;
37
/// Result type used by the functions in this module; errors are plain `String` messages
/// meant to be shown to the user or written to the log.
type Result<T> = std::result::Result<T, String>;
40
41struct UmaskGuard {
43 old_umask: libc::mode_t,
44}
45
46impl Drop for UmaskGuard {
47 fn drop(&mut self) {
48 let _ = unsafe { libc::umask(self.old_umask) };
49 }
50}
51
52fn set_umask(umask: libc::mode_t) -> UmaskGuard {
54 UmaskGuard { old_umask: unsafe { libc::umask(umask) } }
55}
56
57fn create_listener(socket_path: &Path) -> Result<UnixListener> {
61 let _guard = set_umask(0o177);
64
65 UnixListener::bind(socket_path)
66 .map_err(|e| format!("Cannot listen on {}: {}", socket_path.display(), e))
67}
68
69async fn handle_connection(
71 mut client: UnixStream,
72 agents_dirs: &[PathBuf],
73 home: Option<&Path>,
74 uid: libc::uid_t,
75) -> Result<()> {
76 let mut agent = match find::find_socket(agents_dirs, home, uid).await {
77 Some(socket) => socket,
78 None => {
79 return Err("No agent found; cannot proxy request".to_owned());
80 }
81 };
82 let result = tokio::io::copy_bidirectional(&mut client, &mut agent)
83 .await
84 .and_then(|_| Ok(()))
85 .map_err(|e| format!("{}", e));
86 debug!("Closing client connection");
87 result
88}
89
90pub async fn run(socket_path: PathBuf, agents_dirs: &[PathBuf], pid_file: PathBuf) -> Result<()> {
96 let home = env::var("HOME").map(|v| Some(PathBuf::from(v))).unwrap_or(None);
97 let uid = unsafe { libc::getuid() };
98
99 let mut sighup = signal(SignalKind::hangup())
100 .map_err(|e| format!("Failed to install SIGHUP handler: {}", e))?;
101 let mut sigint = signal(SignalKind::interrupt())
102 .map_err(|e| format!("Failed to install SIGINT handler: {}", e))?;
103 let mut sigquit = signal(SignalKind::quit())
104 .map_err(|e| format!("Failed to install SIGQUIT handler: {}", e))?;
105 let mut sigterm = signal(SignalKind::terminate())
106 .map_err(|e| format!("Failed to install SIGTERM handler: {}", e))?;
107
108 let listener = create_listener(&socket_path)?;
109
110 debug!("Entering main loop");
111 let mut stop = None;
112 while stop.is_none() {
113 select! {
114 result = listener.accept() => match result {
115 Ok((socket, _addr)) => {
116 debug!("Connection accepted");
117 if let Err(e) = handle_connection(socket, agents_dirs, home.as_deref(), uid).await {
119 warn!("Dropping connection due to error: {}", e);
120 }
121 }
122 Err(e) => warn!("Failed to accept connection: {}", e),
123 },
124
125 _ = sighup.recv() => (),
126 _ = sigint.recv() => stop = Some("SIGINT"),
127 _ = sigquit.recv() => stop = Some("SIGQUIT"),
128 _ = sigterm.recv() => stop = Some("SIGTERM"),
129 }
130 }
131 debug!("Main loop exited");
132
133 let stop = stop.expect("Loop can only exit by setting stop");
134 info!("Shutting down due to {} and removing {}", stop, socket_path.display());
135
136 let _ = fs::remove_file(&socket_path);
137 let _ = fs::remove_file(&pid_file);
140
141 Ok(())
142}
143
144pub fn wait_for_file<P: AsRef<Path> + Copy, T>(
147 path: P,
148 mut pending_wait: Duration,
149 op: fn(P) -> io::Result<T>,
150) -> Result<T> {
151 while pending_wait > Duration::ZERO {
152 match op(path) {
153 Ok(result) => {
154 return Ok(result);
155 }
156 Err(e) if e.kind() == io::ErrorKind::NotFound => {
157 pending_wait -= Duration::from_millis(1);
158 }
159 Err(e) => {
160 return Err(e.to_string());
161 }
162 }
163 }
164 Err("File was not created on time".to_owned())
165}