use log::{debug, info, trace};
use std::io::{ErrorKind, Result};
use std::os::unix::fs::MetadataExt;
use std::path::Path;
use std::{fs, path::PathBuf};
use tokio::net::UnixStream;
/// Builds a [`std::io::Error`] with the given [`std::io::ErrorKind`].
///
/// Two forms are accepted:
/// - `error!(kind, message)` passes `message` unchanged to `io::Error::new`, so it may be
///   any value convertible into a boxed error (e.g. `&str` or `String`).
/// - `error!(kind, "format {}", args...)` builds the message with `format!`.
///
/// A trailing comma is accepted in both forms. The expansion uses absolute `::std` paths so
/// the exported macro keeps working even where a local item named `std` is in scope.
#[macro_export]
macro_rules! error {
    ( $kind:expr, $text:expr $(,)? ) => {
        ::std::io::Error::new($kind, $text)
    };
    ( $kind:expr, $fmt:literal $(, $args:expr)+ $(,)? ) => {
        ::std::io::Error::new($kind, ::std::format!($fmt $(, $args)+))
    };
}
async fn try_open(path: &Path) -> Result<UnixStream> {
let name = path.file_name().expect(
"The path comes from joining a directory to one of its entries, so it must have a name",
);
let name = match name.to_str() {
Some(name) => name,
None => return Err(error!(ErrorKind::InvalidInput, "Invalid socket path")),
};
let is_pre_openssh_10_1 = name.starts_with("agent.");
let is_openssh_10_1 = name.contains(".sshd.");
if !is_pre_openssh_10_1 && !is_openssh_10_1 {
return Err(error!(
ErrorKind::InvalidInput,
"Socket name in does not start with 'agent.' or does not contain '.sshd.'"
));
}
let metadata =
fs::metadata(path).map_err(|e| error!(e.kind(), "Failed to get metadata: {}", e))?;
if (metadata.mode() & libc::S_IFSOCK as u32) == 0 {
return Err(error!(ErrorKind::InvalidInput, "Path is not a socket"));
}
let socket = UnixStream::connect(&path)
.await
.map_err(|e| error!(e.kind(), "Cannot connect to socket: {}", e))?;
Ok(socket)
}
async fn find_in_subdir(dir: &Path) -> Option<UnixStream> {
let entries = match fs::read_dir(dir) {
Ok(entries) => entries,
Err(e) => {
debug!("Failed to read directory entries in {}: {}", dir.display(), e);
return None;
}
};
let mut candidates = vec![];
for entry in entries {
let entry = match entry {
Ok(entry) => entry,
Err(e) => {
debug!("Failed to read directory entry in {}: {}", dir.display(), e);
continue;
}
};
let candidate = entry.path();
candidates.push(candidate);
}
candidates.sort();
for candidate in candidates {
let socket = match try_open(&candidate).await {
Ok(socket) => socket,
Err(e) => {
trace!("Ignoring candidate socket {}: {}", candidate.display(), e);
continue;
}
};
info!("Successfully opened socket at {}", candidate.display());
return Some(socket);
}
debug!("No socket in directory {}", dir.display());
None
}
async fn try_shared_subdir(dir: &Path, uid: libc::uid_t) -> Result<UnixStream> {
let name = dir.file_name().expect(
"The candidate path comes from joining a directory to one of its entries, so it must have a name");
let name = match name.to_str() {
Some(name) => name,
None => return Err(error!(ErrorKind::InvalidInput, "Invalid file name")),
};
if !name.starts_with("ssh-") {
return Err(error!(ErrorKind::InvalidInput, "Basename does not start with 'ssh-'"));
}
let metadata = fs::metadata(dir).map_err(|e| error!(e.kind(), "Stat failed: {}", e))?;
if metadata.uid() != uid {
return Err(error!(
ErrorKind::InvalidInput,
"{} is owned by {}, not the current user {}",
dir.display(),
metadata.uid(),
uid
));
}
match find_in_subdir(dir).await {
Some(socket) => Ok(socket),
None => Err(error!(ErrorKind::NotFound, "No socket in subdirectory")),
}
}
async fn find_in_shared_dir(dir: &Path, our_uid: libc::uid_t) -> Option<UnixStream> {
let entries = match fs::read_dir(dir) {
Ok(entries) => entries,
Err(e) => {
debug!("Failed to read directory entries in {}: {}", dir.display(), e);
return None;
}
};
let mut subdirs = vec![];
for entry in entries {
let entry = match entry {
Ok(entry) => entry,
Err(e) => {
debug!("Failed to read directory entry in {}: {}", dir.display(), e);
continue;
}
};
let path = entry.path();
match entry.file_type() {
Ok(file_type) if file_type.is_dir() => (),
Ok(_file_type) => {
trace!("Ignoring {}: not a directory", path.display());
continue;
}
Err(e) => {
trace!("Ignoring {}: {}", path.display(), e);
continue;
}
};
subdirs.push(path);
}
subdirs.sort();
for subdir in subdirs {
let socket = match try_shared_subdir(&subdir, our_uid).await {
Ok(socket) => socket,
Err(e) => {
trace!("Ignoring {}: {}", subdir.display(), e);
continue;
}
};
return Some(socket);
}
debug!("No socket in directory: {}", dir.display());
None
}
/// Searches `dirs`, in order, for a connectable SSH agent socket.
///
/// Each directory is handled in up to two steps:
/// 1. If it lies under `home`, its entries are first tried directly as sockets (the
///    HOME naming scheme).
/// 2. If step 1 was skipped or found nothing, its `ssh-*` subdirectories owned by
///    `uid` are searched.
///
/// Returns the first socket that was successfully connected to, or `None` if every
/// directory was exhausted.
pub(super) async fn find_socket(
    dirs: &[PathBuf],
    home: Option<&Path>,
    uid: libc::uid_t,
) -> Option<UnixStream> {
    for search_dir in dirs {
        let uses_home_scheme = home.is_some_and(|home_dir| search_dir.starts_with(home_dir));
        if uses_home_scheme {
            debug!(
                "Looking for an agent socket in {} with HOME naming scheme",
                search_dir.display()
            );
            if let Some(stream) = find_in_subdir(search_dir).await {
                return Some(stream);
            }
        }

        debug!("Looking for an agent socket in {} subdirs", search_dir.display());
        if let Some(stream) = find_in_shared_dir(search_dir, uid).await {
            return Some(stream);
        }
    }
    None
}