use std::io::{ErrorKind, Read, Write};
use std::os::unix::net::UnixStream;
use std::path::PathBuf;
use std::sync::Arc;
use std::thread;
#[cfg(feature = "server")]
use std::os::unix::fs::PermissionsExt;
#[cfg(feature = "server")]
use std::os::unix::net::UnixListener;
#[cfg(feature = "server")]
use std::path::Path;
#[cfg(feature = "server")]
use std::sync::atomic::{AtomicBool, Ordering};
#[cfg(feature = "server")]
use std::thread::JoinHandle;
#[cfg(feature = "server")]
use std::time::Duration;
#[cfg(feature = "server")]
use purecrypto::rng::{OsRng, RngCore};
#[cfg(feature = "server")]
use crate::error::Result;
#[cfg(feature = "server")]
use crate::server::{AgentForwardContext, AgentForwardHandle, AgentForwardHandler};
use crate::stream::{ChannelEgress, ChannelStream};
#[cfg(feature = "server")]
/// How long the accept thread sleeps after a non-blocking `accept` reports
/// `WouldBlock` before it re-checks the stop flag and polls again.
const ACCEPT_POLL_INTERVAL: Duration = Duration::from_millis(100);
#[cfg(feature = "server")]
/// Bookkeeping for one bound agent-forwarding socket.
///
/// Dropping the binding signals the accept thread to stop, joins it, and
/// unlinks the socket file (see the `Drop` impl).
struct AgentBinding {
/// Stop flag shared with the accept thread; set to `true` on drop.
stop: Arc<AtomicBool>,
/// Join handle of the accept thread; taken and joined on drop.
handle: Option<JoinHandle<()>>,
/// Filesystem path of the listening socket; removed on drop.
socket_path: PathBuf,
}
#[cfg(feature = "server")]
impl Drop for AgentBinding {
    /// Tears the binding down: raises the stop flag, waits for the accept
    /// thread to finish, then unlinks the socket file.
    ///
    /// Join and unlink failures are ignored; there is nothing useful to do
    /// with them during drop.
    fn drop(&mut self) {
        // The accept loop polls this flag between `accept` attempts.
        self.stop.store(true, Ordering::SeqCst);

        // `take` leaves `None` behind so the thread is joined at most once.
        let joined = self.handle.take().map(|worker| worker.join());
        let _ = joined;

        let unlinked = std::fs::remove_file(&self.socket_path);
        let _ = unlinked;
    }
}
#[cfg(feature = "server")]
/// Agent-forwarding handler that exposes the forwarded agent as a Unix
/// socket on the local filesystem.
pub struct DefaultAgentForwardHandler {
/// Optional override for the directory the socket is created in. When
/// `None`, the directory is chosen by `resolve_parent`.
parent_dir: Option<PathBuf>,
}
#[cfg(feature = "server")]
impl Default for DefaultAgentForwardHandler {
    /// Builds a handler with no parent-directory override, exactly like
    /// `DefaultAgentForwardHandler::new`.
    fn default() -> Self {
        Self { parent_dir: None }
    }
}
#[cfg(feature = "server")]
impl DefaultAgentForwardHandler {
    /// Creates a handler that picks the socket directory automatically
    /// (see `resolve_parent`).
    pub fn new() -> Self {
        Self { parent_dir: None }
    }

    /// Creates a handler that always places its sockets inside `dir`.
    pub fn with_parent_dir(dir: PathBuf) -> Self {
        Self {
            parent_dir: Some(dir),
        }
    }

    /// Returns the directory new agent sockets are created in.
    ///
    /// Precedence:
    /// 1. the explicit override given to `with_parent_dir`;
    /// 2. a non-empty `XDG_RUNTIME_DIR` environment variable;
    /// 3. `/tmp` as the last resort.
    fn resolve_parent(&self) -> PathBuf {
        if let Some(dir) = &self.parent_dir {
            return dir.clone();
        }
        // `var_os` rather than `var`: Unix paths are arbitrary bytes, and a
        // non-UTF-8 runtime dir must not silently degrade to shared `/tmp`.
        match std::env::var_os("XDG_RUNTIME_DIR") {
            Some(dir) if !dir.is_empty() => PathBuf::from(dir),
            _ => PathBuf::from("/tmp"),
        }
    }
}
#[cfg(feature = "server")]
impl AgentForwardHandler for DefaultAgentForwardHandler {
    /// Binds a fresh Unix socket for agent forwarding and starts a thread
    /// that accepts connections on it.
    ///
    /// Every accepted connection is paired with a channel obtained from
    /// `ctx.open_auth_agent()` and spliced to it. If the channel cannot be
    /// opened, the connection is shut down immediately.
    ///
    /// The returned handle carries the socket path and a stopper; dropping
    /// the stopper stops the accept thread and unlinks the socket.
    ///
    /// # Errors
    ///
    /// Returns an error if the socket path cannot be minted, the socket
    /// cannot be bound, it cannot be switched to non-blocking mode, or its
    /// permissions cannot be restricted to `0600`. When a step after `bind`
    /// fails, the socket file is removed again before returning.
    fn setup(&self, _user: &str, ctx: AgentForwardContext) -> Result<AgentForwardHandle> {
        let parent = self.resolve_parent();
        let socket_path = mint_socket_path(&parent)?;

        // Best effort: clear a stale entry so `bind` does not trip over it.
        let _ = std::fs::remove_file(&socket_path);
        let listener = UnixListener::bind(&socket_path)?;

        // From here on the socket exists on disk, so any failure must unlink
        // it. A socket that could not be restricted to the owner must not be
        // handed out: it would expose the forwarded agent at whatever
        // permissions the process umask produced.
        // NOTE(review): between `bind` and `set_permissions` the socket still
        // carries umask-derived permissions; a private 0700 parent directory
        // would close that window — confirm whether callers need it.
        let prepared = listener.set_nonblocking(true).and_then(|()| {
            std::fs::set_permissions(&socket_path, std::fs::Permissions::from_mode(0o600))
        });
        if let Err(err) = prepared {
            let _ = std::fs::remove_file(&socket_path);
            return Err(err.into());
        }

        let stop = Arc::new(AtomicBool::new(false));
        let stop_thread = Arc::clone(&stop);
        let handle = thread::spawn(move || {
            while !stop_thread.load(Ordering::SeqCst) {
                match listener.accept() {
                    Ok((conn, _peer)) => {
                        // On BSD-derived systems an accepted socket inherits
                        // O_NONBLOCK from the listener. The splice threads use
                        // blocking reads and writes, so force blocking mode.
                        if conn.set_nonblocking(false).is_err() {
                            let _ = conn.shutdown(std::net::Shutdown::Both);
                            continue;
                        }
                        match ctx.open_auth_agent() {
                            Ok(channel_stream) => spawn_unix_splice(conn, channel_stream),
                            Err(_) => {
                                // No channel to forward to: hang up on the peer.
                                let _ = conn.shutdown(std::net::Shutdown::Both);
                            }
                        }
                    }
                    // Nothing pending: sleep, then re-check the stop flag.
                    Err(e) if e.kind() == ErrorKind::WouldBlock => {
                        thread::sleep(ACCEPT_POLL_INTERVAL);
                    }
                    Err(_) => break,
                }
            }
        });

        let binding = AgentBinding {
            stop,
            handle: Some(handle),
            socket_path: socket_path.clone(),
        };
        Ok(AgentForwardHandle {
            auth_sock_path: socket_path,
            stopper: Box::new(binding),
        })
    }
}
/// Pumps bytes between a Unix socket and a channel stream using background
/// threads, then closes the channel once both directions have finished.
///
/// * Socket → channel: each chunk read from the socket is sent as
///   `ChannelEgress::Data`; when reading ends, `ChannelEgress::Eof` is sent.
/// * Channel → socket: each received chunk is written to the socket; when
///   receiving or writing ends, the socket's read half is shut down.
/// * A third thread joins both pumps and then sends `ChannelEgress::Close`.
///
/// If the socket cannot be duplicated, `Eof` and `Close` are sent right away
/// and no threads are started.
fn spawn_unix_splice(uds: UnixStream, stream: ChannelStream) {
    let (from_channel, to_channel) = stream.into_raw();

    // Each direction needs its own handle on the socket.
    let mut reader = match uds.try_clone() {
        Ok(dup) => dup,
        Err(_) => {
            let _ = to_channel.send(ChannelEgress::Eof);
            let _ = to_channel.send(ChannelEgress::Close);
            return;
        }
    };
    let mut writer = uds;

    let uplink = to_channel.clone();
    let socket_to_channel = thread::spawn(move || {
        let mut scratch = [0u8; 32 * 1024];
        loop {
            let filled = match reader.read(&mut scratch) {
                Ok(0) => break,
                Ok(count) => count,
                Err(e) if e.kind() == ErrorKind::Interrupted => continue,
                Err(_) => break,
            };
            let payload = scratch[..filled].to_vec();
            if uplink.send(ChannelEgress::Data(payload)).is_err() {
                break;
            }
        }
        let _ = uplink.send(ChannelEgress::Eof);
    });

    let channel_to_socket = thread::spawn(move || {
        loop {
            match from_channel.recv() {
                Ok(Some(chunk)) => {
                    if writer.write_all(&chunk).is_err() {
                        break;
                    }
                }
                _ => break,
            }
        }
        let _ = writer.shutdown(std::net::Shutdown::Read);
    });

    // Close the channel only after both pumps are done.
    thread::spawn(move || {
        let _ = socket_to_channel.join();
        let _ = channel_to_socket.join();
        let _ = to_channel.send(ChannelEgress::Close);
    });
}
#[cfg(feature = "server")]
/// Builds a socket path inside `parent` of the form
/// `puressh-agent.<pid>.<16 hex digits>.sock`, where the hex digits encode
/// eight bytes drawn from `OsRng`.
///
/// Only the path is computed; nothing is created on disk.
fn mint_socket_path(parent: &Path) -> Result<PathBuf> {
    let suffix = {
        let mut nonce = [0u8; 8];
        OsRng.fill_bytes(&mut nonce);
        hex_encode(&nonce)
    };
    let pid = std::process::id();
    let file_name = format!("puressh-agent.{pid}.{suffix}.sock");
    Ok(parent.join(file_name))
}
/// Builds a callback that forwards each incoming channel stream to the Unix
/// socket at `path`.
///
/// Returns `None` when `path` does not exist at the time of this call.
/// Otherwise the returned callback connects to `path` for every stream it is
/// given and splices the two together. If the connection attempt fails, the
/// stream is ended by sending `ChannelEgress::Eof` followed by
/// `ChannelEgress::Close`.
pub fn splice_to_unix_socket_callback(
    path: PathBuf,
) -> Option<Arc<dyn Fn(ChannelStream) + Send + Sync + 'static>> {
    let socket_present = path.exists();

    let forward = move |stream: ChannelStream| {
        if let Ok(uds) = UnixStream::connect(&path) {
            spawn_unix_splice(uds, stream);
            return;
        }
        // Could not reach the socket: terminate the channel cleanly.
        let (_rx, tx) = stream.into_raw();
        let _ = tx.send(ChannelEgress::Eof);
        let _ = tx.send(ChannelEgress::Close);
    };

    if socket_present {
        Some(Arc::new(forward))
    } else {
        None
    }
}
/// Builds a forwarding callback that targets the local agent socket named by
/// the `SSH_AUTH_SOCK` environment variable.
///
/// Returns `None` when the variable is unset or empty, or when
/// [`splice_to_unix_socket_callback`] returns `None` for that path.
///
/// The variable is read with `var_os`, so socket paths that are not valid
/// UTF-8 (legal on Unix) are accepted rather than treated as "no agent".
pub fn splice_to_local_agent_callback() -> Option<Arc<dyn Fn(ChannelStream) + Send + Sync + 'static>>
{
    let raw = std::env::var_os("SSH_AUTH_SOCK").filter(|value| !value.is_empty())?;
    splice_to_unix_socket_callback(PathBuf::from(raw))
}
#[cfg(feature = "server")]
/// Renders `bytes` as lowercase hexadecimal, two digits per byte, high
/// nibble first. An empty slice yields an empty string.
fn hex_encode(bytes: &[u8]) -> String {
    const HEX: &[u8; 16] = b"0123456789abcdef";
    let mut encoded = String::with_capacity(bytes.len() * 2);
    encoded.extend(
        bytes
            .iter()
            .flat_map(|&byte| [byte >> 4, byte & 0x0f])
            .map(|nibble| char::from(HEX[usize::from(nibble)])),
    );
    encoded
}
#[cfg(all(test, feature = "server"))]
mod tests {
    use super::*;

    /// Scratch directory under the system temp dir, removed recursively on drop.
    struct TestTempDir {
        path: PathBuf,
    }

    impl TestTempDir {
        /// Creates a directory named from `prefix`, the process id, and the
        /// current time in nanoseconds.
        fn new(prefix: &str) -> Self {
            use std::time::{SystemTime, UNIX_EPOCH};
            let nanos = SystemTime::now()
                .duration_since(UNIX_EPOCH)
                .map(|d| d.as_nanos())
                .unwrap_or(0);
            let pid = std::process::id();
            let path =
                std::env::temp_dir().join(format!("puressh-agentfwd-{prefix}-{pid}-{nanos}"));
            std::fs::create_dir_all(&path).expect("create tempdir");
            Self { path }
        }

        /// Path of the scratch directory.
        fn path(&self) -> &Path {
            &self.path
        }
    }

    impl Drop for TestTempDir {
        fn drop(&mut self) {
            let _ = std::fs::remove_dir_all(&self.path);
        }
    }

    /// `setup` must create the socket on disk, and dropping the handle must
    /// unlink it again.
    #[test]
    fn setup_binds_and_drop_unlinks() {
        let dir = TestTempDir::new("setup");
        let h = DefaultAgentForwardHandler::with_parent_dir(dir.path().to_path_buf());
        let ctx = AgentForwardContext::for_test_no_opens();
        let handle = h.setup("u", ctx).expect("setup");
        let path = handle.auth_sock_path.clone();
        assert!(path.exists(), "socket should exist on disk after setup");
        assert!(path
            .to_string_lossy()
            .starts_with(&format!("{}/puressh-agent.", dir.path().display())));
        drop(handle);
        // Allow some slack in case the unlink is not observed immediately.
        for _ in 0..50 {
            if !path.exists() {
                break;
            }
            thread::sleep(Duration::from_millis(50));
        }
        assert!(
            !path.exists(),
            "socket should be unlinked when the handle is dropped (path={path:?})",
        );
    }

    /// When the channel open fails, the accepted connection must be closed so
    /// the connecting peer observes EOF instead of hanging.
    ///
    /// NOTE(review): assumes `for_test_no_opens` makes `open_auth_agent`
    /// fail — confirm against its definition.
    #[test]
    fn accepted_connection_is_closed_when_open_fails() {
        let dir = TestTempDir::new("closeonfail");
        let h = DefaultAgentForwardHandler::with_parent_dir(dir.path().to_path_buf());
        let ctx = AgentForwardContext::for_test_no_opens();
        let handle = h.setup("u", ctx).expect("setup");
        let mut peer = UnixStream::connect(&handle.auth_sock_path).expect("connect");
        // The timeout turns a connection that was never closed into a read
        // error rather than a hung test.
        peer.set_read_timeout(Some(Duration::from_secs(2)))
            .expect("read timeout");
        let mut buf = [0u8; 1];
        let outcome = peer.read(&mut buf);
        assert!(
            matches!(outcome, Ok(0)),
            "peer should observe EOF after the failed open, got {outcome:?}",
        );
    }
}