use std::io::{Read as _, Write as _};
use std::sync::Mutex;
use crate::bin_error::{self, ContextExt as _};
/// Largest frame payload, in bytes (16 MiB), accepted in either direction.
const MAX_MESSAGE: u32 = 16 << 20;
/// A connection to the bwx agent over a Unix-domain stream socket.
pub struct Sock(std::os::unix::net::UnixStream);

/// Process-wide slot for the connection that `request` reuses between calls.
/// Empty until a request succeeds on a fresh connection, and emptied again
/// when the cached connection fails or is explicitly invalidated.
static CACHED: Mutex<Option<Sock>> = Mutex::new(None);
impl Sock {
    /// Opens a new connection to the agent socket at
    /// `bwx::dirs::socket_file()`.
    ///
    /// # Errors
    ///
    /// Returns the underlying I/O error if the connection cannot be made.
    pub fn connect() -> std::io::Result<Self> {
        std::os::unix::net::UnixStream::connect(bwx::dirs::socket_file())
            .map(Self)
    }

    /// Drops the process-wide cached connection, if any, so that the next
    /// `request` call opens a fresh one.
    pub fn invalidate_cached() {
        let mut guard = CACHED
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner);
        *guard = None;
    }

    /// Serializes `msg` as MessagePack and writes it as a single frame.
    ///
    /// The frame is a big-endian `u32` byte length followed by the payload.
    ///
    /// # Errors
    ///
    /// Fails if serialization fails, if the payload is larger than
    /// `MAX_MESSAGE` bytes, or if writing to the socket fails.
    pub fn send(
        &mut self,
        msg: &bwx::protocol::Request,
    ) -> bin_error::Result<()> {
        let Self(sock) = self;
        let payload = rmp_serde::to_vec(msg)
            .context("failed to serialize message to agent")?;
        // One check covers both "does not fit in the u32 length prefix" and
        // "exceeds the cap". Because MAX_MESSAGE is itself a u32, the cap is
        // always the binding limit, so a single error message suffices.
        let len = u32::try_from(payload.len())
            .ok()
            .filter(|&len| len <= MAX_MESSAGE)
            .ok_or_else(|| {
                bin_error::Error::msg(format!(
                    "outgoing message exceeds {MAX_MESSAGE}-byte cap"
                ))
            })?;
        sock.write_all(&len.to_be_bytes())
            .context("failed to send message to agent")?;
        sock.write_all(&payload)
            .context("failed to send message to agent")?;
        Ok(())
    }

    /// Reads one length-prefixed frame from the agent and decodes it as a
    /// MessagePack `Response`.
    ///
    /// # Errors
    ///
    /// Fails in any of these cases:
    /// - reading from the socket fails;
    /// - the stream ends before a full frame has arrived;
    /// - the advertised length exceeds `MAX_MESSAGE`;
    /// - the payload does not decode as a `Response`, including an empty
    ///   payload.
    pub fn recv(&mut self) -> bin_error::Result<bwx::protocol::Response> {
        let Self(sock) = self;
        let mut len_buf = [0u8; 4];
        sock.read_exact(&mut len_buf)
            .context("failed to read message from agent")?;
        let len = u32::from_be_bytes(len_buf);
        // Reject oversized frames before allocating, so that a bogus header
        // cannot trigger a huge buffer allocation.
        if len > MAX_MESSAGE {
            return Err(bin_error::Error::msg(format!(
                "agent response exceeds {MAX_MESSAGE}-byte cap"
            )));
        }
        let mut payload = vec![
            0u8;
            usize::try_from(len)
                .expect("16 MiB-capped u32 fits in usize")
        ];
        sock.read_exact(&mut payload)
            .context("failed to read message from agent")?;
        rmp_serde::from_slice(&payload)
            .context("failed to parse message from agent")
    }
}
/// Test-only probe: reports whether a connection is currently held in the
/// process-wide cache.
#[cfg(test)]
pub fn cached_is_some() -> bool {
    let slot = CACHED
        .lock()
        .unwrap_or_else(std::sync::PoisonError::into_inner);
    matches!(*slot, Some(_))
}
/// Sends `msg` to the agent and returns the agent's response.
///
/// If the process-wide cache holds a connection, that connection is tried
/// first. If any step on it fails, the connection is discarded and the
/// request is sent once more on a freshly opened connection. That fresh
/// connection is cached for later calls once it produces a response. The
/// cache lock is held for the whole exchange, so concurrent callers in this
/// process are serialized.
///
/// NOTE(review): a failure on the cached connection after `send` has already
/// succeeded (e.g. inside `recv`) still causes a re-send. The agent may
/// therefore see the same request twice; confirm every action tolerates
/// that.
///
/// # Errors
///
/// Fails if a fresh connection to the agent cannot be opened, or if sending
/// or receiving on that fresh connection fails. Errors on the cached
/// connection are not reported; they only trigger the reconnect.
pub fn request(
    msg: &bwx::protocol::Request,
) -> bin_error::Result<bwx::protocol::Response> {
    let mut slot = CACHED
        .lock()
        .unwrap_or_else(std::sync::PoisonError::into_inner);

    if let Some(cached) = slot.as_mut() {
        if let Ok(response) = cached.send(msg).and_then(|()| cached.recv()) {
            return Ok(response);
        }
        // Treat the cached connection as stale: forget it and fall through
        // to opening a new one.
        *slot = None;
    }

    let mut fresh = Sock::connect().with_context(|| {
        let stderr_log = bwx::dirs::agent_stderr_file();
        format!(
            "failed to connect to bwx-agent \
             (this often means that the agent failed to start; \
             check {} for agent logs)",
            stderr_log.display()
        )
    })?;
    fresh.send(msg)?;
    let response = fresh.recv()?;
    *slot = Some(fresh);
    Ok(response)
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::{Read as _, Write as _};
    use std::os::unix::net::UnixStream;

    /// Creates a connected socket pair. Returns the client end wrapped in a
    /// `Sock` and the raw other end, which plays the role of the agent.
    fn sock_and_peer() -> (Sock, UnixStream) {
        let (client, peer) = UnixStream::pair().unwrap();
        (Sock(client), peer)
    }

    /// Writes a big-endian `u32` frame-length header to `peer`.
    fn write_len(peer: &mut UnixStream, len: u32) {
        peer.write_all(&len.to_be_bytes()).unwrap();
    }

    #[test]
    fn framed_send_writes_length_prefix_then_msgpack() {
        let (mut sock, mut peer) = sock_and_peer();
        let request = bwx::protocol::Request::new(
            bwx::protocol::Environment::default(),
            bwx::protocol::Action::Version,
        );
        sock.send(&request).unwrap();

        let mut header = [0u8; 4];
        peer.read_exact(&mut header).unwrap();
        let frame_len = u32::from_be_bytes(header);
        assert!((1..=MAX_MESSAGE).contains(&frame_len));

        let mut body = vec![0u8; usize::try_from(frame_len).unwrap()];
        peer.read_exact(&mut body).unwrap();
        let decoded: bwx::protocol::Request =
            rmp_serde::from_slice(&body).unwrap();
        let (action, ..) = decoded.into_parts();
        assert!(matches!(action, bwx::protocol::Action::Version));
    }

    #[test]
    fn framed_recv_rejects_oversized_length() {
        let (mut sock, mut peer) = sock_and_peer();
        write_len(&mut peer, MAX_MESSAGE + 1);
        let err = sock.recv().unwrap_err();
        assert!(err.to_string().contains("cap"), "got: {err}");
    }

    #[test]
    fn framed_recv_rejects_truncated_payload() {
        let (mut sock, mut peer) = sock_and_peer();
        write_len(&mut peer, 64);
        peer.write_all(&[0xc0, 0xc1, 0xc2, 0xc3]).unwrap();
        // Closing the peer makes the 64-byte read hit EOF early.
        drop(peer);
        let rendered = sock.recv().unwrap_err().to_string();
        assert!(
            rendered.contains("read message"),
            "expected read error, got: {rendered}"
        );
    }

    #[test]
    fn framed_recv_rejects_malformed_msgpack() {
        let (mut sock, mut peer) = sock_and_peer();
        let garbage: &[u8] = b"\xc1\xc1\xc1\xc1";
        write_len(&mut peer, u32::try_from(garbage.len()).unwrap());
        peer.write_all(garbage).unwrap();
        let rendered = sock.recv().unwrap_err().to_string();
        assert!(
            rendered.contains("parse"),
            "expected parse error, got: {rendered}"
        );
    }

    #[test]
    fn framed_recv_rejects_zero_length_frame() {
        let (mut sock, mut peer) = sock_and_peer();
        write_len(&mut peer, 0);
        let rendered = sock.recv().unwrap_err().to_string();
        assert!(
            rendered.contains("parse"),
            "expected parse error, got: {rendered}"
        );
    }

    #[test]
    fn framed_send_recv_roundtrip_via_sock_pair() {
        let (mut sock, mut peer) = sock_and_peer();
        let encoded = rmp_serde::to_vec(
            &bwx::protocol::Response::Version { version: 42 },
        )
        .unwrap();
        write_len(&mut peer, u32::try_from(encoded.len()).unwrap());
        peer.write_all(&encoded).unwrap();
        match sock.recv().unwrap() {
            bwx::protocol::Response::Version { version } => {
                assert_eq!(version, 42);
            }
            other => panic!("unexpected variant: {other:?}"),
        }
    }

    #[test]
    fn invalidate_cached_clears_slot() {
        // Keep the peer alive so the cached socket stays connected.
        let (sock, _peer) = sock_and_peer();
        *CACHED
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner) = Some(sock);
        assert!(cached_is_some());
        Sock::invalidate_cached();
        assert!(!cached_is_some());
    }
}