use crate::bin_error::{self, ContextExt as _};
use tokio::io::{AsyncReadExt as _, AsyncWriteExt as _};
/// Largest payload, in bytes, accepted for a single framed message
/// (16 MiB). `Sock::send` refuses to write anything larger and
/// `Sock::recv` rejects any frame whose length prefix exceeds it.
const MAX_MESSAGE: u32 = 16 * 1024 * 1024;
/// Newtype around a tokio Unix stream; its methods exchange
/// length-prefixed `bwx::protocol` messages over that stream.
pub struct Sock(tokio::net::UnixStream);
impl Sock {
pub fn new(s: tokio::net::UnixStream) -> Self {
Self(s)
}
pub async fn send(
&mut self,
res: &bwx::protocol::Response,
) -> bin_error::Result<()> {
if let bwx::protocol::Response::Error { error } = res {
log::warn!("{error}");
}
let Self(sock) = self;
let payload =
rmp_serde::to_vec(res).context("failed to serialize message")?;
let len = u32::try_from(payload.len()).map_err(|_| {
bin_error::Error::msg(format!(
"outgoing message exceeds {MAX_MESSAGE}-byte cap"
))
})?;
if len > MAX_MESSAGE {
return Err(bin_error::Error::msg(format!(
"outgoing message exceeds {MAX_MESSAGE}-byte cap"
)));
}
sock.write_all(&len.to_be_bytes())
.await
.context("failed to write message to socket")?;
sock.write_all(&payload)
.await
.context("failed to write message to socket")?;
Ok(())
}
pub async fn recv(
&mut self,
) -> bin_error::Result<std::result::Result<bwx::protocol::Request, String>>
{
let Self(sock) = self;
let mut len_buf = [0u8; 4];
match sock.read_exact(&mut len_buf).await {
Ok(_) => {}
Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => {
return Ok(Err("connection closed".to_string()));
}
Err(e) => {
return Err(bin_error::Error::with_context(
e,
"failed to read message from socket",
));
}
}
let len = u32::from_be_bytes(len_buf);
if len > MAX_MESSAGE {
return Ok(Err(format!(
"message exceeds {MAX_MESSAGE}-byte cap"
)));
}
let mut payload = vec![
0u8;
usize::try_from(len)
.expect("16 MiB-capped u32 fits in usize")
];
sock.read_exact(&mut payload)
.await
.context("failed to read message from socket")?;
Ok(rmp_serde::from_slice(&payload)
.map_err(|e| format!("failed to parse message: {e}")))
}
}
/// Returns the process id of the peer connected to `stream`, as reported
/// by the platform-specific `peer_pid`, or `None` when that lookup yields
/// nothing.
pub fn peer_pid_of(stream: &tokio::net::UnixStream) -> Option<i32> {
    let fd = std::os::unix::io::AsRawFd::as_raw_fd(stream);
    peer_pid(fd)
}
pub fn check_peer_uid(
stream: &tokio::net::UnixStream,
) -> bin_error::Result<()> {
use std::os::unix::io::AsRawFd as _;
let fd = stream.as_raw_fd();
let peer_uid = peer_uid(fd).context("failed to read peer uid")?;
let self_uid = unsafe { libc::getuid() };
if peer_uid != self_uid {
return Err(bin_error::Error::msg(format!(
"peer uid {peer_uid} does not match agent uid {self_uid}; \
refusing connection"
)));
}
Ok(())
}
/// Fetches the peer credentials of the socket `fd` via
/// `getsockopt(SOL_SOCKET, SO_PEERCRED)`.
///
/// # Errors
///
/// Returns the OS error when `getsockopt` reports failure.
#[cfg(any(target_os = "linux", target_os = "android"))]
fn peer_ucred(fd: std::os::unix::io::RawFd) -> std::io::Result<libc::ucred> {
    // SAFETY: `ucred` is a plain C struct for which all-zero bytes are a
    // valid value.
    let mut creds: libc::ucred = unsafe { std::mem::zeroed() };
    let mut creds_len = u32::try_from(std::mem::size_of::<libc::ucred>())
        .expect("ucred size fits in socklen_t");

    // SAFETY: both pointers refer to live, writable locals, and
    // `creds_len` holds the exact size of the buffer behind the first.
    let status = unsafe {
        libc::getsockopt(
            fd,
            libc::SOL_SOCKET,
            libc::SO_PEERCRED,
            (&raw mut creds).cast(),
            &raw mut creds_len,
        )
    };

    if status == 0 {
        Ok(creds)
    } else {
        Err(std::io::Error::last_os_error())
    }
}
/// Returns the `uid` field of the peer credentials of socket `fd`
/// (Linux/Android variant, built on `peer_ucred`).
#[cfg(any(target_os = "linux", target_os = "android"))]
fn peer_uid(fd: std::os::unix::io::RawFd) -> std::io::Result<u32> {
    Ok(peer_ucred(fd)?.uid)
}
/// Returns the `pid` field of the peer credentials of socket `fd`
/// (Linux/Android variant). Any lookup error is collapsed to `None`.
#[cfg(any(target_os = "linux", target_os = "android"))]
pub fn peer_pid(fd: std::os::unix::io::RawFd) -> Option<i32> {
    match peer_ucred(fd) {
        Ok(creds) => Some(creds.pid),
        Err(_) => None,
    }
}
/// Returns the uid that `getpeereid` reports for the peer of socket `fd`
/// (macOS/iOS/BSD variant).
///
/// # Errors
///
/// Returns the OS error when `getpeereid` reports failure.
#[cfg(any(
    target_os = "macos",
    target_os = "ios",
    target_os = "freebsd",
    target_os = "openbsd",
    target_os = "netbsd",
    target_os = "dragonfly"
))]
fn peer_uid(fd: std::os::unix::io::RawFd) -> std::io::Result<u32> {
    // Both out-parameters start at `u32::MAX`; the gid is required by
    // the C signature but is not used afterwards.
    let mut euid: libc::uid_t = u32::MAX;
    let mut egid: libc::gid_t = u32::MAX;

    // SAFETY: both pointers refer to live, writable locals of the
    // expected types.
    match unsafe { libc::getpeereid(fd, &raw mut euid, &raw mut egid) } {
        0 => Ok(euid),
        _ => Err(std::io::Error::last_os_error()),
    }
}
/// Returns the peer process id of socket `fd` (macOS variant), queried
/// with `getsockopt(SOL_LOCAL, LOCAL_PEERPID)`. Any failure yields `None`.
#[cfg(target_os = "macos")]
pub fn peer_pid(fd: std::os::unix::io::RawFd) -> Option<i32> {
    // Option level and name are spelled out locally rather than taken
    // from `libc`.
    // NOTE(review): presumably these mirror the values in the system's
    // <sys/un.h>; confirm against the SDK headers.
    const SOL_LOCAL: libc::c_int = 0;
    const LOCAL_PEERPID: libc::c_int = 2;

    let mut peer: libc::pid_t = 0;
    let mut peer_len = u32::try_from(std::mem::size_of::<libc::pid_t>())
        .expect("pid_t fits in socklen_t");

    // SAFETY: both pointers refer to live, writable locals, and
    // `peer_len` holds the exact size of the buffer behind the first.
    let status = unsafe {
        libc::getsockopt(
            fd,
            SOL_LOCAL,
            LOCAL_PEERPID,
            (&raw mut peer).cast(),
            &raw mut peer_len,
        )
    };

    if status == 0 {
        Some(peer)
    } else {
        None
    }
}
/// Fallback `peer_pid` for targets other than Linux, Android and macOS:
/// no lookup is attempted and the result is always `None`.
#[cfg(not(any(
target_os = "linux",
target_os = "android",
target_os = "macos"
)))]
pub fn peer_pid(_fd: std::os::unix::io::RawFd) -> Option<i32> {
None
}
/// Binds a listener at the path returned by `bwx::dirs::socket_file()`
/// using `bind_atomic`, and logs that path at debug level on success.
///
/// # Errors
///
/// Fails if the socket cannot be bound.
pub fn listen() -> bin_error::Result<tokio::net::UnixListener> {
    let socket_path = bwx::dirs::socket_file();
    let listener =
        bind_atomic(&socket_path).context("failed to listen on socket")?;
    log::debug!("listening on socket {}", socket_path.to_string_lossy());
    Ok(listener)
}
/// Binds a Unix listener at `path`.
///
/// `bind_atomic_inner` (bind to a temporary name, then rename into place)
/// is tried first. If it fails, a warning is logged, any existing file at
/// `path` is removed on a best-effort basis, and the listener is bound
/// directly to `path`.
///
/// # Errors
///
/// Returns the error from the direct bind when the fallback fails too.
pub fn bind_atomic(
    path: &std::path::Path,
) -> std::io::Result<tokio::net::UnixListener> {
    bind_atomic_inner(path).or_else(|e| {
        log::warn!(
            "bind_atomic failed ({e}); falling back to unlink-then-bind \
             on {}. TOCTOU mitigation partially degraded; socket is \
             still protected by its 0o700 parent dir.",
            path.display()
        );
        // A failed removal is ignored; the bind below reports any
        // problem that actually matters.
        let _ = std::fs::remove_file(path);
        tokio::net::UnixListener::bind(path)
    })
}
/// Binds a listener to a randomly named temporary file beside `path`,
/// then renames that file onto `path`.
///
/// The temporary name is `.t` followed by eight lowercase hex digits
/// taken from four random bytes.
///
/// # Errors
///
/// Fails if `path` has no parent directory, if binding the temporary
/// socket fails, or if the rename fails. After a failed rename the
/// temporary file is removed on a best-effort basis.
fn bind_atomic_inner(
    path: &std::path::Path,
) -> std::io::Result<tokio::net::UnixListener> {
    use rand::RngCore as _;

    let Some(dir) = path.parent() else {
        return Err(std::io::Error::new(
            std::io::ErrorKind::InvalidInput,
            "socket path has no parent directory",
        ));
    };

    let mut nonce = [0u8; 4];
    rand::rng().fill_bytes(&mut nonce);
    // Reading the four bytes as a big-endian `u32` and printing it as
    // eight zero-padded hex digits equals printing each byte as two
    // hex digits in order.
    let staging = dir.join(format!(".t{:08x}", u32::from_be_bytes(nonce)));

    // Clear any leftover file with the same name; failure is ignored.
    let _ = std::fs::remove_file(&staging);
    let listener = tokio::net::UnixListener::bind(&staging)?;

    match std::fs::rename(&staging, path) {
        Ok(()) => Ok(listener),
        Err(e) => {
            let _ = std::fs::remove_file(&staging);
            Err(e)
        }
    }
}