use crate::error::{Error, Result};
use crate::message::{JsonRpcMessage, JsonRpcNotification, MessageWithFds, get_fd_count};
use rustix::fd::AsFd;
use rustix::net::{
RecvAncillaryBuffer, RecvAncillaryMessage, RecvFlags, SendAncillaryBuffer,
SendAncillaryMessage, SendFlags,
};
use serde::Serialize;
use std::collections::VecDeque;
use std::io::{self, IoSlice, IoSliceMut};
use std::mem::MaybeUninit;
use std::num::NonZeroUsize;
use std::os::unix::io::OwnedFd;
use std::sync::Arc;
use tokio::io::Interest;
use tokio::net::UnixStream as TokioUnixStream;
use tracing::{debug, trace};
/// Default upper bound on the number of file descriptors attached to one `sendmsg` call.
///
/// [`Sender`] halves its batch size whenever the kernel rejects a batch with `EINVAL`, so the
/// effective limit may be lowered at runtime.
pub const DEFAULT_MAX_FDS_PER_SENDMSG: NonZeroUsize = match NonZeroUsize::new(500) {
    Some(limit) => limit,
    None => panic!("500 is non-zero"),
};
/// Capacity, in file descriptors, of the `SCM_RIGHTS` ancillary buffer used when receiving.
const MAX_FDS_PER_RECVMSG: usize = 512;
/// Number of payload bytes requested from the socket per read.
const READ_BUFFER_SIZE: usize = 4 * 1024;
/// A connected Unix-domain stream socket that exchanges JSON-RPC messages, optionally
/// carrying file descriptors.
///
/// Use [`UnixSocketTransport::split`] to obtain independent sending and receiving halves.
pub struct UnixSocketTransport {
    stream: TokioUnixStream,
}

impl UnixSocketTransport {
    /// Wraps an already-connected Tokio Unix stream.
    pub fn new(stream: TokioUnixStream) -> Self {
        UnixSocketTransport { stream }
    }

    /// Consumes the transport and returns a [`Sender`] / [`Receiver`] pair that share the
    /// same underlying socket through an `Arc`.
    ///
    /// The sender starts with compact JSON output and [`DEFAULT_MAX_FDS_PER_SENDMSG`]; the
    /// receiver starts with empty byte and FD buffers and no pending message.
    pub fn split(self) -> (Sender, Receiver) {
        let shared = Arc::new(self.stream);
        let receiver = Receiver {
            stream: Arc::clone(&shared),
            buffer: Vec::new(),
            fd_queue: VecDeque::new(),
            pending_message: None,
        };
        let sender = Sender {
            stream: shared,
            pretty: false,
            max_fds_per_sendmsg: DEFAULT_MAX_FDS_PER_SENDMSG,
        };
        (sender, receiver)
    }
}
/// Sending half of a [`UnixSocketTransport`].
///
/// Serializes JSON-RPC messages and writes them to the socket, attaching any file descriptors
/// as `SCM_RIGHTS` ancillary data.
pub struct Sender {
    stream: Arc<TokioUnixStream>,
    /// Emit pretty-printed JSON instead of compact JSON.
    pretty: bool,
    /// Upper bound on FDs attached to one `sendmsg` call; lowered by [`Sender::send`] when the
    /// kernel rejects a batch with `EINVAL`.
    max_fds_per_sendmsg: NonZeroUsize,
}

impl Sender {
    /// Enables or disables pretty-printed JSON for subsequently sent messages.
    pub fn set_pretty(&mut self, pretty: bool) {
        self.pretty = pretty;
    }

    /// Sets the maximum number of file descriptors attached to a single `sendmsg` call.
    ///
    /// Any non-zero value is accepted. [`Sender::send`] may lower it further after the kernel
    /// rejects a batch with `EINVAL`.
    pub fn set_max_fds_per_sendmsg(&mut self, max_fds: NonZeroUsize) {
        self.max_fds_per_sendmsg = max_fds;
    }

    /// Sends a JSON-RPC notification without file descriptors.
    ///
    /// # Errors
    /// Same as [`Sender::notify_with_fds`].
    pub async fn notify<P: Serialize>(&mut self, method: &str, params: P) -> Result<()> {
        self.notify_with_fds(method, params, Vec::new()).await
    }

    /// Sends a JSON-RPC notification that carries `fds`.
    ///
    /// `params` that serialize to JSON `null` are omitted from the notification.
    ///
    /// # Errors
    /// Returns an error if `params` cannot be serialized or if [`Sender::send`] fails.
    pub async fn notify_with_fds<P: Serialize>(
        &mut self,
        method: &str,
        params: P,
        fds: Vec<OwnedFd>,
    ) -> Result<()> {
        let params_value = serde_json::to_value(params)?;
        let params_opt = if params_value.is_null() {
            None
        } else {
            Some(params_value)
        };
        let notification = JsonRpcNotification::new(method.to_string(), params_opt);
        let message = JsonRpcMessage::Notification(notification);
        let message_with_fds = MessageWithFds::new(message, fds);
        self.send(message_with_fds).await
    }

    /// Serializes `message_with_fds` and writes it, together with its file descriptors, to the
    /// socket.
    ///
    /// The bytes and FDs go out over one or more `sendmsg`/`send` calls, with at most
    /// `max_fds_per_sendmsg` FDs per call. Once every JSON byte has been sent, any remaining FDs
    /// each ride on a single space byte, which is JSON whitespace. If the kernel rejects a
    /// multi-FD batch with `EINVAL`, the batch is halved and retried. After a successful send,
    /// the smaller size is kept for later calls.
    ///
    /// # Errors
    /// Returns an error if serialization fails, or if a socket write fails for any reason other
    /// than a retryable `EINVAL` on a batch of more than one FD.
    pub async fn send(&mut self, message_with_fds: MessageWithFds) -> Result<()> {
        let serialized = if self.pretty {
            message_with_fds.serialize_pretty()?
        } else {
            message_with_fds.serialize()?
        };
        let data = serialized.into_bytes();
        trace!(
            "Sending message: {} with {} FDs",
            String::from_utf8_lossy(&data).trim(),
            message_with_fds.file_descriptors.len()
        );
        let fds = message_with_fds.file_descriptors;
        let mut bytes_sent = 0usize;
        let mut fds_sent = 0usize;
        let mut current_max_fds = self.max_fds_per_sendmsg.get();
        while bytes_sent < data.len() || fds_sent < fds.len() {
            let remaining_data = &data[bytes_sent..];
            let remaining_fds = &fds[fds_sent..];
            let fds_batch = remaining_fds
                .get(..current_max_fds)
                .unwrap_or(remaining_fds);
            // Borrow the batch once per attempt, not on every readiness retry of the closure.
            let borrowed_fds: Vec<_> = fds_batch.iter().map(|fd| fd.as_fd()).collect();
            // Size the control buffer for exactly this batch. A fixed-size buffer makes `push`
            // fail whenever `max_fds_per_sendmsg` is configured above the buffer's capacity.
            let mut cmsg_space: Vec<MaybeUninit<u8>> = if borrowed_fds.is_empty() {
                Vec::new()
            } else {
                vec![MaybeUninit::uninit(); rustix::cmsg_space!(ScmRights(borrowed_fds.len()))]
            };
            let result = self
                .stream
                .async_io(Interest::WRITABLE, || {
                    let sockfd = self.stream.as_fd();
                    if !borrowed_fds.is_empty() {
                        let mut control = SendAncillaryBuffer::new(&mut cmsg_space[..]);
                        if !control.push(SendAncillaryMessage::ScmRights(&borrowed_fds)) {
                            return Err(io::Error::other(
                                "Failed to add file descriptors to control message",
                            ));
                        }
                        // Ancillary data must ride on at least one data byte. Once the JSON is
                        // fully sent, use a single space, which the JSON parser skips.
                        let iov = if remaining_data.is_empty() {
                            [IoSlice::new(b" ")]
                        } else {
                            [IoSlice::new(remaining_data)]
                        };
                        rustix::net::sendmsg(sockfd, &iov, &mut control, SendFlags::empty())
                            .map_err(|e| to_io_error(e, "sendmsg"))
                    } else if !remaining_data.is_empty() {
                        rustix::net::send(sockfd, remaining_data, SendFlags::empty())
                            .map_err(|e| to_io_error(e, "send"))
                    } else {
                        Ok(0)
                    }
                })
                .await;
            match result {
                Ok(sent) => {
                    // The dummy space byte is not part of `data`, so it is not counted.
                    if !remaining_data.is_empty() {
                        bytes_sent += sent;
                    }
                    // A successful `sendmsg` delivers the whole ancillary batch, even when only
                    // part of the data was written.
                    if !fds_batch.is_empty() {
                        fds_sent += fds_batch.len();
                        trace!(
                            "Sent {} FDs (total: {}/{}) with {} bytes",
                            fds_batch.len(),
                            fds_sent,
                            fds.len(),
                            sent
                        );
                    }
                    trace!(
                        "Progress: {}/{} bytes, {}/{} FDs",
                        bytes_sent,
                        data.len(),
                        fds_sent,
                        fds.len()
                    );
                }
                Err(e) if e.kind() == io::ErrorKind::InvalidInput && fds_batch.len() > 1 => {
                    // Treat EINVAL as "too many FDs for this kernel": halve and retry.
                    let new_max = fds_batch.len() / 2;
                    debug!(
                        "sendmsg returned EINVAL with {} FDs, reducing batch size to {}",
                        fds_batch.len(),
                        new_max
                    );
                    current_max_fds = new_max;
                    continue;
                }
                Err(e) => return Err(Error::Io(e)),
            }
        }
        if current_max_fds < self.max_fds_per_sendmsg.get() {
            debug!(
                "Learned kernel FD limit: reducing max_fds_per_sendmsg from {} to {}",
                self.max_fds_per_sendmsg, current_max_fds
            );
            self.max_fds_per_sendmsg =
                NonZeroUsize::new(current_max_fds).expect("current_max_fds should be >= 1");
        }
        Ok(())
    }
}
pub struct Receiver {
stream: Arc<TokioUnixStream>,
buffer: Vec<u8>,
fd_queue: VecDeque<OwnedFd>,
pending_message: Option<(serde_json::Value, usize)>,
}
impl Receiver {
pub async fn receive(&mut self) -> Result<MessageWithFds> {
loop {
if let Some(message) = self.try_parse_message()? {
return Ok(message);
}
if let Err(e) = self.read_more_data().await {
if matches!(e, Error::ConnectionClosed)
&& let Some((_, fd_count)) = self.pending_message.take()
{
return Err(Error::MismatchedCount {
expected: fd_count,
found: self.fd_queue.len(),
});
}
return Err(e);
}
}
}
pub async fn receive_opt(&mut self) -> Result<Option<MessageWithFds>> {
match self.receive().await {
Ok(msg) => Ok(Some(msg)),
Err(Error::ConnectionClosed) => Ok(None),
Err(e) => Err(e),
}
}
fn build_message(
fd_queue: &mut VecDeque<OwnedFd>,
value: serde_json::Value,
fd_count: usize,
) -> Result<MessageWithFds> {
let fds: Vec<OwnedFd> = fd_queue.drain(..fd_count).collect();
let message = JsonRpcMessage::from_json_value(value)?;
Ok(MessageWithFds::new(message, fds))
}
fn try_parse_message(&mut self) -> Result<Option<MessageWithFds>> {
if let Some((value, fd_count)) = self
.pending_message
.take_if(|(_, c)| self.fd_queue.len() >= *c)
{
return Ok(Some(Self::build_message(
&mut self.fd_queue,
value,
fd_count,
)?));
} else if let Some((_, fd_count)) = &self.pending_message {
if self.buffer.iter().any(|&b| !b.is_ascii_whitespace()) {
return Err(Error::MismatchedCount {
expected: *fd_count,
found: self.fd_queue.len(),
});
}
return Ok(None);
}
if self.buffer.is_empty() {
return Ok(None);
}
let mut stream =
serde_json::Deserializer::from_slice(&self.buffer).into_iter::<serde_json::Value>();
match stream.next() {
Some(Ok(value)) => {
let bytes_consumed = stream.byte_offset();
trace!("Parsed message ({} bytes): {:?}", bytes_consumed, value);
self.buffer.drain(..bytes_consumed);
let fd_count = get_fd_count(&value);
if fd_count > self.fd_queue.len() {
if self.buffer.iter().any(|&b| !b.is_ascii_whitespace()) {
return Err(Error::MismatchedCount {
expected: fd_count,
found: self.fd_queue.len(),
});
}
trace!(
"Message expects {} FDs but only {} available, waiting for more",
fd_count,
self.fd_queue.len()
);
self.pending_message = Some((value, fd_count));
return Ok(None);
}
Ok(Some(Self::build_message(
&mut self.fd_queue,
value,
fd_count,
)?))
}
Some(Err(e)) if e.is_eof() => {
Ok(None)
}
Some(Err(e)) => {
Err(Error::Json(e))
}
None => {
Ok(None)
}
}
}
async fn read_more_data(&mut self) -> Result<()> {
let mut data_buffer = [0u8; READ_BUFFER_SIZE];
let mut received_fds: Vec<OwnedFd> = Vec::new();
let bytes_read = self
.stream
.async_io(Interest::READABLE, || {
let sockfd = self.stream.as_fd();
let mut iov = [IoSliceMut::new(&mut data_buffer)];
let mut cmsg_space: [MaybeUninit<u8>;
rustix::cmsg_space!(ScmRights(MAX_FDS_PER_RECVMSG))] =
[MaybeUninit::uninit(); rustix::cmsg_space!(ScmRights(MAX_FDS_PER_RECVMSG))];
let mut cmsg_buffer = RecvAncillaryBuffer::new(&mut cmsg_space);
let result = rustix::net::recvmsg(
sockfd,
&mut iov,
&mut cmsg_buffer,
RecvFlags::CMSG_CLOEXEC,
)
.map_err(|e| to_io_error(e, "recvmsg"))?;
for msg in cmsg_buffer.drain() {
if let RecvAncillaryMessage::ScmRights(fds) = msg {
received_fds.extend(fds);
}
}
Ok(result.bytes)
})
.await
.map_err(Error::Io)?;
if bytes_read == 0 {
return Err(Error::ConnectionClosed);
}
self.buffer.extend_from_slice(&data_buffer[..bytes_read]);
self.fd_queue.extend(received_fds);
debug!(
"Read {} bytes, {} FDs in queue",
bytes_read,
self.fd_queue.len()
);
Ok(())
}
}
/// Converts a rustix `Errno` into an `io::Error` whose message names the failed `operation`.
///
/// `WouldBlock` is returned untouched so that callers inside Tokio's `async_io` still see a
/// plain would-block error. Every other error keeps its `ErrorKind` but its message gains a
/// `"<operation> failed: "` prefix.
fn to_io_error(e: rustix::io::Errno, operation: &str) -> io::Error {
    let base = io::Error::from(e);
    match base.kind() {
        io::ErrorKind::WouldBlock => base,
        kind => io::Error::new(kind, format!("{operation} failed: {base}")),
    }
}