use std::{path::PathBuf, time::Duration};
use tokio::{
io::AsyncWriteExt, net::unix::OwnedWriteHalf, net::UnixDatagram, net::UnixStream, sync::Mutex,
};
use logfence_proto::syslog::SyslogMessage;
use crate::error::ClientError;
/// Reports whether `e` is a transient "buffer full, try again later" send error rather
/// than a hard failure.
///
/// `ErrorKind::WouldBlock` covers EAGAIN/EWOULDBLOCK. The raw codes 105 and 55 are the
/// values of `ENOBUFS` on Linux and on macOS/BSD respectively; both are accepted on every
/// platform.
fn is_buffer_full(e: &std::io::Error) -> bool {
    if e.kind() == std::io::ErrorKind::WouldBlock {
        return true;
    }
    matches!(e.raw_os_error(), Some(55) | Some(105))
}
/// Back-off to wait before datagram send attempt number `attempt`.
///
/// Attempts are counted from 1, so the first retry is attempt 2. Attempts up to 2 wait
/// 100 µs, and each later attempt doubles the previous wait. The result is capped at one
/// second, and the arithmetic saturates instead of overflowing for very large `attempt`.
fn dgram_attempt_delay(attempt: u32) -> Duration {
    const BASE_MICROS: u64 = 100;
    const CEILING: Duration = Duration::from_secs(1);

    let exponent = attempt.saturating_sub(2);
    let micros = match 1u64.checked_shl(exponent) {
        Some(factor) => BASE_MICROS.saturating_mul(factor),
        // A shift of 64 bits or more cannot be represented; treat it as "huge".
        None => u64::MAX,
    };
    Duration::from_micros(micros).min(CEILING)
}
/// A destination that syslog messages can be delivered to.
///
/// NOTE: `send` is a plain `async fn`, so the trait itself does not promise that the
/// returned future is `Send` (this is what the allowed `async_fn_in_trait` lint warns
/// about). Code that is generic over `T: Transport` cannot rely on it being `Send`.
#[allow(
async_fn_in_trait,
reason = "Transport is only implemented within this crate; \
the implementation produces Send futures due to its Send-safe state"
)]
pub trait Transport: Send + Sync {
/// Delivers one message.
///
/// # Errors
///
/// Returns a [`ClientError`] if the message cannot be delivered. Which variants can
/// occur depends on the implementation.
async fn send(&self, msg: &SyslogMessage) -> Result<(), ClientError>;
}
/// Syslog transport over a Unix stream socket, framing each message with an octet
/// count (`"<len> <message>"`).
///
/// The connection is opened lazily by the first send and reused by later sends. It is
/// discarded after a failed write, so the following send reconnects.
pub struct UnixTransport {
/// Filesystem path of the Unix stream socket to connect to.
path: PathBuf,
/// Largest accepted serialized message, in bytes (the length prefix is not counted).
max_size: usize,
/// Write half of the cached connection; `None` before the first connect or after a
/// failed write.
stream: Mutex<Option<OwnedWriteHalf>>,
}
impl UnixTransport {
    /// Creates a transport that will connect to the Unix stream socket at `path`.
    ///
    /// Nothing is opened here; the connection is established by the first send.
    /// Messages whose serialized form is longer than `max_size` bytes are rejected.
    #[must_use]
    pub fn new(path: impl Into<PathBuf>, max_size: usize) -> Self {
        let path: PathBuf = path.into();
        let stream = Mutex::new(None);
        Self {
            path,
            max_size,
            stream,
        }
    }
}
impl Transport for UnixTransport {
    /// Sends `msg` as one octet-counted frame, `"<len> <wire>"`, where `<len>` is the
    /// byte length of the serialized message.
    ///
    /// The stream connection is opened on first use, with its read side shut down because
    /// nothing is ever read from it, and is reused by later calls. The connection mutex is
    /// held for the whole write, so concurrent calls never interleave frames.
    ///
    /// # Errors
    ///
    /// - [`ClientError::MessageTooLarge`] if the serialized message is longer than
    ///   `max_size` bytes (the length prefix is not counted).
    /// - An I/O error (converted via `From<std::io::Error>`) if connecting fails, or
    ///   [`ClientError::Io`] if the write fails. In both cases no connection is cached,
    ///   so the next call reconnects.
    ///
    /// # Cancellation
    ///
    /// `write_all` can leave a frame partially written if this future is dropped. The
    /// connection is therefore taken out of the cache while writing and stored back only
    /// after the whole frame has been written. A cancelled send drops the connection
    /// instead of letting the next frame be appended to a truncated one.
    async fn send(&self, msg: &SyslogMessage) -> Result<(), ClientError> {
        let wire = msg.to_string();
        if wire.len() > self.max_size {
            return Err(ClientError::MessageTooLarge {
                max: self.max_size,
                got: wire.len(),
            });
        }
        let frame = format!("{} {wire}", wire.len());

        let mut guard = self.stream.lock().await;
        let mut stream = match guard.take() {
            Some(stream) => stream,
            None => {
                let conn = UnixStream::connect(&self.path).await?;
                // Round-trip through the std socket to shut down the read side.
                let std_conn = conn.into_std()?;
                std_conn.shutdown(std::net::Shutdown::Read)?;
                let (_, write_half) = UnixStream::from_std(std_conn)?.into_split();
                write_half
            }
        };

        // On failure `stream` is dropped here (closing it) and the cache stays empty.
        stream
            .write_all(frame.as_bytes())
            .await
            .map_err(ClientError::Io)?;
        *guard = Some(stream);
        Ok(())
    }
}
/// Syslog transport over a Unix datagram socket; each message is sent as one datagram.
///
/// Sends that fail because the receiver's buffer is full are retried with a growing
/// delay, up to a configurable number of attempts.
pub struct UnixDatagramTransport {
/// Filesystem path of the receiving Unix datagram socket.
path: PathBuf,
/// Largest accepted serialized message, in bytes.
max_size: usize,
/// Total number of send attempts per message, counting the first; `0` means retry
/// without limit.
max_attempts: u32,
/// Lazily created, unbound sending socket; `None` before first use or after an error.
socket: Mutex<Option<UnixDatagram>>,
}
impl UnixDatagramTransport {
    /// Creates a datagram transport targeting the Unix datagram socket at `path`.
    ///
    /// No socket is created here. Messages whose serialized form is longer than
    /// `max_size` bytes are rejected. A full receive buffer is retried for up to four
    /// attempts in total unless changed with [`Self::max_attempts`].
    #[must_use]
    pub fn new(path: impl Into<PathBuf>, max_size: usize) -> Self {
        let path: PathBuf = path.into();
        Self {
            socket: Mutex::new(None),
            max_attempts: 4,
            max_size,
            path,
        }
    }

    /// Sets how many send attempts are made per message while the receiver's buffer is
    /// full, counting the first attempt. `0` removes the limit and retries indefinitely.
    #[must_use]
    pub fn max_attempts(self, n: u32) -> Self {
        Self {
            max_attempts: n,
            ..self
        }
    }
}
impl Transport for UnixDatagramTransport {
    /// Sends `msg` as one datagram carrying the bare serialized message (no length
    /// prefix).
    ///
    /// An unbound socket is created lazily and cached. The first attempt is made
    /// immediately. While an attempt fails with a buffer-full error (see
    /// `is_buffer_full`), the send is retried after the back-off from
    /// `dgram_attempt_delay`. There are at most `max_attempts` attempts in total, or no
    /// limit when `max_attempts` is `0`. The socket mutex is held for the whole call, so
    /// concurrent sends are serialized.
    ///
    /// # Errors
    ///
    /// - [`ClientError::MessageTooLarge`] if the serialized message is longer than
    ///   `max_size` bytes.
    /// - An I/O error if the socket cannot be set up.
    /// - [`ClientError::Io`] if an attempt fails with an error that is not buffer-full.
    /// - [`ClientError::Io`] carrying the last buffer-full error once the attempts are
    ///   exhausted.
    ///
    /// In every error case the cached socket is discarded, so the next call starts from a
    /// fresh one.
    async fn send(&self, msg: &SyslogMessage) -> Result<(), ClientError> {
        let wire = msg.to_string();
        if wire.len() > self.max_size {
            return Err(ClientError::MessageTooLarge {
                max: self.max_size,
                got: wire.len(),
            });
        }
        let payload = wire.as_bytes();

        let mut guard = self.socket.lock().await;
        // The socket leaves the cache for the duration of the send and is put back only
        // on success, so every error return (and cancellation) discards it.
        let sock = match guard.take() {
            Some(sock) => sock,
            None => {
                let sock = UnixDatagram::unbound()?;
                // The socket is only ever written to. A `NotConnected` error from shutting
                // down the (unconnected) socket is tolerated.
                if let Err(e) = sock.shutdown(std::net::Shutdown::Read) {
                    if e.kind() != std::io::ErrorKind::NotConnected {
                        return Err(ClientError::Io(e));
                    }
                }
                // Tokio's `try_*` methods report `WouldBlock` without issuing the syscall
                // until the runtime has observed the socket as writable, and a freshly
                // registered socket has not been observed yet. Wait for that once, so the
                // first attempt really reaches sendto(2) instead of failing spuriously.
                // With `max_attempts == 1` that spurious failure would hit every send.
                sock.writable().await?;
                sock
            }
        };

        let mut attempt = 1u32;
        loop {
            let err = match sock.try_send_to(payload, &self.path) {
                Ok(_) => {
                    *guard = Some(sock);
                    return Ok(());
                }
                Err(e) => e,
            };
            let next = attempt.saturating_add(1);
            let exhausted = self.max_attempts != 0 && next > self.max_attempts;
            if exhausted || !is_buffer_full(&err) {
                return Err(ClientError::Io(err));
            }
            tokio::time::sleep(dgram_attempt_delay(next)).await;
            attempt = next;
        }
    }
}
#[cfg(test)]
#[allow(
    clippy::unwrap_used,
    reason = "unwrap is appropriate in test assertions"
)]
mod tests {
    use std::time::Duration;
    use tokio::io::AsyncReadExt;
    use tokio::net::UnixListener;
    use logfence_proto::syslog::{Facility, Priority, Severity};
    use super::*;

    /// Fixed message shared by the tests: Local0/Info from app `test` with a JSON body.
    fn sample_msg() -> SyslogMessage {
        SyslogMessage {
            priority: Priority {
                facility: Facility::Local0,
                severity: Severity::Info,
            },
            timestamp: None,
            hostname: None,
            app_name: Some("test".into()),
            proc_id: None,
            msg_id: None,
            structured_data: "-".into(),
            msg: r#"{"k":"v"}"#.into(),
        }
    }

    #[test]
    fn is_buffer_full_accepts_would_block_and_enobufs_codes() {
        assert!(is_buffer_full(&std::io::Error::from(
            std::io::ErrorKind::WouldBlock
        )));
        assert!(is_buffer_full(&std::io::Error::from_raw_os_error(105)));
        assert!(is_buffer_full(&std::io::Error::from_raw_os_error(55)));
        assert!(!is_buffer_full(&std::io::Error::from(
            std::io::ErrorKind::NotFound
        )));
    }

    #[test]
    fn dgram_attempt_delay_doubles_from_100us_and_caps_at_1s() {
        assert_eq!(dgram_attempt_delay(0), Duration::from_micros(100));
        assert_eq!(dgram_attempt_delay(2), Duration::from_micros(100));
        assert_eq!(dgram_attempt_delay(3), Duration::from_micros(200));
        assert_eq!(dgram_attempt_delay(4), Duration::from_micros(400));
        assert_eq!(dgram_attempt_delay(15), Duration::from_micros(819_200));
        assert_eq!(dgram_attempt_delay(16), Duration::from_secs(1));
        assert_eq!(dgram_attempt_delay(u32::MAX), Duration::from_secs(1));
    }

    #[tokio::test]
    async fn unix_transport_sends_octet_count_frame() {
        let dir = tempfile::tempdir().unwrap();
        let sock_path = dir.path().join("test.sock");
        let listener = UnixListener::bind(&sock_path).unwrap();
        let transport = UnixTransport::new(&sock_path, 65536);
        let msg = sample_msg();
        let expected_wire = msg.to_string();
        let send_task = tokio::spawn(async move { transport.send(&msg).await.unwrap() });
        let (mut conn, _) = tokio::time::timeout(Duration::from_secs(1), listener.accept())
            .await
            .unwrap()
            .unwrap();
        let mut buf = vec![0u8; 4096];
        let n = tokio::time::timeout(Duration::from_secs(1), conn.read(&mut buf))
            .await
            .unwrap()
            .unwrap();
        let received = std::str::from_utf8(&buf[..n]).unwrap();
        let (count_str, body) = received.split_once(' ').unwrap();
        assert_eq!(count_str.parse::<usize>().unwrap(), expected_wire.len());
        assert_eq!(body, expected_wire);
        send_task.await.unwrap();
    }

    #[tokio::test]
    async fn unix_transport_reconnects_after_error() {
        let dir = tempfile::tempdir().unwrap();
        let sock_path = dir.path().join("reconnect.sock");
        let transport = UnixTransport::new(&sock_path, 65536);
        let msg = sample_msg();
        // Nothing is listening yet, so the first send must fail...
        assert!(transport.send(&msg).await.is_err());
        // ...and once a listener exists, the next send must connect afresh.
        let listener = UnixListener::bind(&sock_path).unwrap();
        let send_task = tokio::spawn({
            let msg = msg.clone();
            async move { transport.send(&msg).await }
        });
        // Require a successful accept (not just "no timeout") and keep the accepted
        // connection alive until the send has completed.
        let (_conn, _) = tokio::time::timeout(Duration::from_secs(1), listener.accept())
            .await
            .unwrap()
            .unwrap();
        assert!(send_task.await.unwrap().is_ok());
    }

    #[tokio::test]
    async fn unix_transport_rejects_oversized_message() {
        let dir = tempfile::tempdir().unwrap();
        let sock_path = dir.path().join("oversize.sock");
        let transport = UnixTransport::new(&sock_path, 10);
        let err = transport.send(&sample_msg()).await.unwrap_err();
        assert!(matches!(err, ClientError::MessageTooLarge { .. }));
    }

    #[tokio::test]
    async fn unix_datagram_transport_sends_raw_wire() {
        let dir = tempfile::tempdir().unwrap();
        let sock_path = dir.path().join("dgram.sock");
        let receiver = UnixDatagram::bind(&sock_path).unwrap();
        let transport = UnixDatagramTransport::new(&sock_path, 65536);
        let msg = sample_msg();
        let expected_wire = msg.to_string();
        transport.send(&msg).await.unwrap();
        let mut buf = vec![0u8; 4096];
        let n = tokio::time::timeout(Duration::from_secs(1), receiver.recv(&mut buf))
            .await
            .unwrap()
            .unwrap();
        let received = std::str::from_utf8(&buf[..n]).unwrap();
        assert_eq!(received, expected_wire);
    }

    #[tokio::test]
    async fn unix_datagram_transport_rejects_oversized_message() {
        let dir = tempfile::tempdir().unwrap();
        let sock_path = dir.path().join("oversize_dgram.sock");
        let transport = UnixDatagramTransport::new(&sock_path, 10);
        let err = transport.send(&sample_msg()).await.unwrap_err();
        assert!(matches!(err, ClientError::MessageTooLarge { .. }));
    }

    #[tokio::test]
    async fn unix_datagram_retries_on_buffer_full() {
        let dir = tempfile::tempdir().unwrap();
        let sock_path = dir.path().join("dgram_retry.sock");
        let receiver = std::os::unix::net::UnixDatagram::bind(&sock_path).unwrap();
        socket2::SockRef::from(&receiver)
            .set_recv_buffer_size(4096)
            .unwrap();
        let filler = std::os::unix::net::UnixDatagram::unbound().unwrap();
        filler.set_nonblocking(true).unwrap();
        let mut fill_count = 0usize;
        loop {
            match filler.send_to(&[0u8], &sock_path) {
                Ok(_) => {
                    fill_count += 1;
                    assert!(fill_count < 100_000, "socket buffer never filled");
                }
                // The only acceptable way out is a buffer-full error; anything else fails.
                Err(e) => {
                    assert!(is_buffer_full(&e), "unexpected fill error: {e}");
                    break;
                }
            }
        }
        assert!(fill_count > 0);
        let drainer = receiver.try_clone().unwrap();
        std::thread::spawn(move || {
            std::thread::sleep(Duration::from_micros(200));
            let mut buf = vec![0u8; 65_536];
            drainer.set_nonblocking(true).unwrap();
            while drainer.recv(&mut buf).is_ok() {}
        });
        let transport = UnixDatagramTransport::new(&sock_path, 65_536);
        tokio::time::timeout(Duration::from_millis(200), transport.send(&sample_msg()))
            .await
            .unwrap()
            .unwrap();
    }
}