ssh_agent_switcher/
use log::{debug, info, warn};
use std::env;
use std::fs;
use std::io;
use std::path::{Path, PathBuf};
use std::sync::Arc;
use std::thread;
use std::time::Duration;
use tokio::net::{UnixListener, UnixStream};
use tokio::select;
use tokio::signal::unix::{SignalKind, signal};
35
36mod find;
37
/// Crate-local result type whose error variant is a human-readable message.
type Result<T> = core::result::Result<T, String>;
40
41struct UmaskGuard {
43 old_umask: libc::mode_t,
44}
45
46impl Drop for UmaskGuard {
47 fn drop(&mut self) {
48 let _ = unsafe { libc::umask(self.old_umask) };
49 }
50}
51
52fn set_umask(umask: libc::mode_t) -> UmaskGuard {
54 UmaskGuard { old_umask: unsafe { libc::umask(umask) } }
55}
56
57fn create_listener(socket_path: &Path) -> Result<UnixListener> {
61 let _guard = set_umask(0o177);
64
65 UnixListener::bind(socket_path)
66 .map_err(|e| format!("Cannot listen on {}: {}", socket_path.display(), e))
67}
68
69async fn handle_connection(
71 mut client: UnixStream,
72 agents_dirs: &[PathBuf],
73 home: Option<&Path>,
74 uid: libc::uid_t,
75) -> Result<()> {
76 let mut agent = match find::find_socket(agents_dirs, home, uid).await {
77 Some(socket) => socket,
78 None => {
79 return Err("No agent found; cannot proxy request".to_owned());
80 }
81 };
82 let result = tokio::io::copy_bidirectional(&mut client, &mut agent)
83 .await.map(|_| ())
84 .map_err(|e| format!("{}", e));
85 debug!("Closing client connection");
86 result
87}
88
89pub async fn run(socket_path: PathBuf, agents_dirs: &[PathBuf], pid_file: PathBuf) -> Result<()> {
95 let home = env::var("HOME").map(|v| Some(PathBuf::from(v))).unwrap_or(None);
96 let uid = unsafe { libc::getuid() };
97
98 let mut sighup = signal(SignalKind::hangup())
99 .map_err(|e| format!("Failed to install SIGHUP handler: {}", e))?;
100 let mut sigint = signal(SignalKind::interrupt())
101 .map_err(|e| format!("Failed to install SIGINT handler: {}", e))?;
102 let mut sigquit = signal(SignalKind::quit())
103 .map_err(|e| format!("Failed to install SIGQUIT handler: {}", e))?;
104 let mut sigterm = signal(SignalKind::terminate())
105 .map_err(|e| format!("Failed to install SIGTERM handler: {}", e))?;
106
107 let listener = create_listener(&socket_path)?;
108
109 debug!("Entering main loop");
110 let mut stop = None;
111 while stop.is_none() {
112 select! {
113 result = listener.accept() => match result {
114 Ok((socket, _addr)) => {
115 debug!("Connection accepted");
116 if let Err(e) = handle_connection(socket, agents_dirs, home.as_deref(), uid).await {
118 warn!("Dropping connection due to error: {}", e);
119 }
120 }
121 Err(e) => warn!("Failed to accept connection: {}", e),
122 },
123
124 _ = sighup.recv() => (),
125 _ = sigint.recv() => stop = Some("SIGINT"),
126 _ = sigquit.recv() => stop = Some("SIGQUIT"),
127 _ = sigterm.recv() => stop = Some("SIGTERM"),
128 }
129 }
130 debug!("Main loop exited");
131
132 let stop = stop.expect("Loop can only exit by setting stop");
133 info!("Shutting down due to {} and removing {}", stop, socket_path.display());
134
135 let _ = fs::remove_file(&socket_path);
136 let _ = fs::remove_file(&pid_file);
139
140 Ok(())
141}
142
143pub fn wait_for_file<P: AsRef<Path> + Copy, T>(
146 path: P,
147 mut pending_wait: Duration,
148 op: fn(P) -> io::Result<T>,
149) -> Result<T> {
150 while pending_wait > Duration::ZERO {
151 match op(path) {
152 Ok(result) => {
153 return Ok(result);
154 }
155 Err(e) if e.kind() == io::ErrorKind::NotFound => {
156 pending_wait -= Duration::from_millis(1);
157 }
158 Err(e) => {
159 return Err(e.to_string());
160 }
161 }
162 }
163 Err("File was not created on time".to_owned())
164}