use log::{debug, info, warn};
use std::env;
use std::fs;
use std::io;
use std::path::{Path, PathBuf};
use std::sync::Arc;
use std::thread;
use std::time::{Duration, Instant};
use tokio::net::{UnixListener, UnixStream};
use tokio::select;
use tokio::signal::unix::{SignalKind, signal};
mod find;
/// Fallible result used across this module; the error is a human-readable message.
type Result<V> = std::result::Result<V, String>;
/// Restores a previously saved process umask when dropped.
///
/// Returned by `set_umask`. Bind it to a named variable (e.g. `_guard`, not `_`) for as
/// long as the temporary umask should remain in effect; dropping it earlier restores the
/// old mask immediately.
#[must_use = "dropping the guard immediately restores the previous umask"]
struct UmaskGuard {
    /// The umask that was in effect before `set_umask` replaced it.
    old_umask: libc::mode_t,
}
impl Drop for UmaskGuard {
    /// Reinstates the umask captured when this guard was created.
    fn drop(&mut self) {
        // SAFETY: umask(2) only swaps the process file-mode creation mask; it always
        // succeeds and has no memory-safety preconditions. The mask it returns (the
        // temporary one being replaced) is not needed.
        unsafe {
            libc::umask(self.old_umask);
        }
    }
}
/// Replaces the process umask with `umask` and returns a guard that puts the old one back
/// when dropped.
///
/// Note that umask(2) is process-wide: while the guard is alive, files created by any
/// thread of this process are subject to the temporary mask.
fn set_umask(umask: libc::mode_t) -> UmaskGuard {
    // SAFETY: umask(2) always succeeds and has no memory-safety preconditions.
    let previous = unsafe { libc::umask(umask) };
    UmaskGuard { old_umask: previous }
}
/// Binds a Unix-domain listener at `socket_path` that only the owner can access.
///
/// The umask is narrowed to 0o177 for the duration of the `bind` call, so the socket
/// node is created with at most 0o777 & !0o177 = 0o600 permissions.
///
/// # Errors
///
/// Returns a message naming `socket_path` if binding fails.
fn create_listener(socket_path: &Path) -> Result<UnixListener> {
    let restore_umask = set_umask(0o177);
    let bound = UnixListener::bind(socket_path);
    // Only the bind needs the restrictive mask; restore the previous one right away.
    drop(restore_umask);
    bound.map_err(|err| format!("Cannot listen on {}: {}", socket_path.display(), err))
}
/// Proxies one client connection to an agent socket chosen by `find::find_socket`.
///
/// The agent is looked up using `agents_dirs`, `home` and `uid`. Traffic is then
/// shuttled in both directions with `tokio::io::copy_bidirectional`.
///
/// # Errors
///
/// Returns an error if no agent socket is found or if copying fails.
async fn handle_connection(
    mut client: UnixStream,
    agents_dirs: &[PathBuf],
    home: Option<&Path>,
    uid: libc::uid_t,
) -> Result<()> {
    let mut agent = find::find_socket(agents_dirs, home, uid)
        .await
        .ok_or_else(|| "No agent found; cannot proxy request".to_owned())?;

    let copied = tokio::io::copy_bidirectional(&mut client, &mut agent).await;
    debug!("Closing client connection");
    copied.map(|_| ()).map_err(|e| e.to_string())
}
pub async fn run(socket_path: PathBuf, agents_dirs: &[PathBuf], pid_file: PathBuf) -> Result<()> {
let home = env::var("HOME").map(|v| Some(PathBuf::from(v))).unwrap_or(None);
let uid = unsafe { libc::getuid() };
let mut sighup = signal(SignalKind::hangup())
.map_err(|e| format!("Failed to install SIGHUP handler: {}", e))?;
let mut sigint = signal(SignalKind::interrupt())
.map_err(|e| format!("Failed to install SIGINT handler: {}", e))?;
let mut sigquit = signal(SignalKind::quit())
.map_err(|e| format!("Failed to install SIGQUIT handler: {}", e))?;
let mut sigterm = signal(SignalKind::terminate())
.map_err(|e| format!("Failed to install SIGTERM handler: {}", e))?;
let listener = create_listener(&socket_path)?;
debug!("Entering main loop");
let mut stop = None;
while stop.is_none() {
select! {
result = listener.accept() => match result {
Ok((socket, _addr)) => {
debug!("Connection accepted");
if let Err(e) = handle_connection(socket, agents_dirs, home.as_deref(), uid).await {
warn!("Dropping connection due to error: {}", e);
}
}
Err(e) => warn!("Failed to accept connection: {}", e),
},
_ = sighup.recv() => (),
_ = sigint.recv() => stop = Some("SIGINT"),
_ = sigquit.recv() => stop = Some("SIGQUIT"),
_ = sigterm.recv() => stop = Some("SIGTERM"),
}
}
debug!("Main loop exited");
let stop = stop.expect("Loop can only exit by setting stop");
info!("Shutting down due to {} and removing {}", stop, socket_path.display());
let _ = fs::remove_file(&socket_path);
let _ = fs::remove_file(&pid_file);
Ok(())
}
pub fn wait_for_file<P: AsRef<Path> + Copy, T>(
path: P,
mut pending_wait: Duration,
op: fn(P) -> io::Result<T>,
) -> Result<T> {
while pending_wait > Duration::ZERO {
match op(path) {
Ok(result) => {
return Ok(result);
}
Err(e) if e.kind() == io::ErrorKind::NotFound => {
thread::sleep(Duration::from_millis(1));
pending_wait -= Duration::from_millis(1);
}
Err(e) => {
return Err(e.to_string());
}
}
}
Err("File was not created on time".to_owned())
}