#[cfg(feature = "std")]
use crate::connection;
use crate::connection::socket::{ReadHalf, Socket, WriteHalf};
use alloc::vec::Vec;
#[cfg(feature = "std")]
use core::cell::RefCell;
#[cfg(feature = "std")]
use rustix::fd::{BorrowedFd, OwnedFd};
/// In-memory [`Socket`] used by tests: replays a fixed list of messages to its reader.
///
/// Build it with [`MockSocket::new`] or [`MockSocket::with_responses`], then call
/// [`Socket::split`] to obtain a [`MockReadHalf`] / [`MockWriteHalf`] pair.
#[derive(Debug)]
#[doc(hidden)]
pub struct MockSocket {
    /// Encoded messages, one entry per response (NUL-terminated bytes).
    messages: Vec<Vec<u8>>,
    /// Descriptors to deliver with each message: `fds[i]` belongs to `messages[i]`.
    #[cfg(feature = "std")]
    fds: Vec<Vec<OwnedFd>>,
}
impl MockSocket {
    /// Creates a mock socket whose reader will replay `responses`, one message per entry.
    ///
    /// Every response is stored as its UTF-8 bytes followed by a NUL (`\0`) terminator.
    /// The final message gets one additional trailing NUL.
    /// NOTE(review): the purpose of that extra NUL is not visible in this file. Confirm it
    /// against the reader/parser before relying on it.
    ///
    /// With the `std` feature, `fds[i]` is the batch of descriptors delivered together with
    /// message `i`. Messages without a corresponding entry carry no descriptors.
    pub fn new(responses: &[&str], #[cfg(feature = "std")] fds: Vec<Vec<OwnedFd>>) -> Self {
        let final_index = responses.len().checked_sub(1);
        let mut messages = Vec::with_capacity(responses.len());
        for (index, response) in responses.iter().enumerate() {
            let mut encoded = Vec::from(response.as_bytes());
            encoded.push(b'\0');
            if Some(index) == final_index {
                encoded.push(b'\0');
            }
            messages.push(encoded);
        }
        Self {
            messages,
            #[cfg(feature = "std")]
            fds,
        }
    }

    /// Creates a mock socket from `responses` with no file descriptors attached.
    pub fn with_responses(responses: &[&str]) -> Self {
        #[cfg(feature = "std")]
        {
            Self::new(responses, Vec::new())
        }
        #[cfg(not(feature = "std"))]
        {
            Self::new(responses)
        }
    }
}
impl Socket for MockSocket {
type ReadHalf = MockReadHalf;
type WriteHalf = MockWriteHalf;
fn split(self) -> (Self::ReadHalf, Self::WriteHalf) {
(
MockReadHalf {
messages: self.messages,
msg_index: 0,
pos_in_msg: 0,
#[cfg(feature = "std")]
fds: self.fds,
},
MockWriteHalf {
written: Vec::new(),
#[cfg(feature = "std")]
fds_written: RefCell::new(Vec::new()),
},
)
}
}
/// Read half of [`MockSocket`]: hands out the canned messages chunk by chunk.
#[derive(Debug)]
#[doc(hidden)]
pub struct MockReadHalf {
    /// Encoded messages to replay, in order.
    messages: Vec<Vec<u8>>,
    /// Index of the message currently being read (equal to the number fully read so far).
    msg_index: usize,
    /// Byte offset of the next unread byte within `messages[msg_index]`.
    pos_in_msg: usize,
    /// Per-message fd batches; `fds[i]` is taken out when message `i` has been fully read.
    #[cfg(feature = "std")]
    fds: Vec<Vec<OwnedFd>>,
}
impl MockReadHalf {
    /// Number of messages that have not yet been read to completion.
    ///
    /// A partially read message still counts as remaining.
    pub fn messages_remaining(&self) -> usize {
        self.messages
            .get(self.msg_index..)
            .map_or(0, |rest| rest.len())
    }

    /// Number of messages read to completion so far.
    ///
    /// Each completed message releases its fd batch (when `fds` has an entry for it). This is
    /// therefore also the number of fd batches handed out, provided every completed message had
    /// an entry.
    #[cfg(feature = "std")]
    pub fn fds_consumed(&self) -> usize {
        let completed_messages = self.msg_index;
        completed_messages
    }
}
impl MockReadHalf {
    /// Copies the next chunk of the current message into `buf` and advances the read cursor.
    ///
    /// Both `read` variants share this helper so the cursor logic lives in one place. It
    /// copies `min(bytes left in the current message, buf.len())` bytes.
    ///
    /// Returns the number of bytes copied, plus the index of the message completed by this
    /// call (if its final byte was copied). Returns `(0, None)` once every message has been
    /// read.
    fn read_chunk(&mut self, buf: &mut [u8]) -> (usize, Option<usize>) {
        let Some(msg) = self.messages.get(self.msg_index) else {
            return (0, None);
        };
        let rest = &msg[self.pos_in_msg..];
        let to_read = rest.len().min(buf.len());
        buf[..to_read].copy_from_slice(&rest[..to_read]);
        self.pos_in_msg += to_read;
        if self.pos_in_msg < msg.len() {
            return (to_read, None);
        }
        let completed = self.msg_index;
        self.msg_index += 1;
        self.pos_in_msg = 0;
        (to_read, Some(completed))
    }
}

impl ReadHalf for MockReadHalf {
    /// Reads the next chunk of the current message into `buf`.
    ///
    /// Returns `(0, [])` once every message has been consumed. The descriptors attached to a
    /// message are returned only by the call that copies that message's final byte. All
    /// other calls return an empty fd list.
    #[cfg(feature = "std")]
    async fn read(
        &mut self,
        buf: &mut [u8],
    ) -> crate::Result<(usize, alloc::vec::Vec<std::os::fd::OwnedFd>)> {
        let (copied, completed) = self.read_chunk(buf);
        let fds = match completed {
            // A message with no entry in `self.fds` simply delivers no descriptors.
            Some(index) => self
                .fds
                .get_mut(index)
                .map(core::mem::take)
                .unwrap_or_default(),
            None => Vec::new(),
        };
        Ok((copied, fds))
    }

    /// Reads the next chunk of the current message into `buf`.
    ///
    /// Returns 0 once every message has been consumed.
    #[cfg(not(feature = "std"))]
    async fn read(&mut self, buf: &mut [u8]) -> crate::Result<usize> {
        let (copied, _) = self.read_chunk(buf);
        Ok(copied)
    }
}
#[cfg(feature = "std")]
impl connection::socket::FetchPeerCredentials for MockReadHalf {
    /// Reports the current process's own uid, gid and pid as the "peer" credentials of the
    /// mock connection.
    ///
    /// # Errors
    /// On Linux, returns the error from `pidfd_open` if opening a pidfd for the current
    /// process fails. The non-Linux branch performs no fallible call.
    async fn fetch_peer_credentials(&self) -> std::io::Result<connection::Credentials> {
        let uid = rustix::process::getuid();
        let pid = rustix::process::getpid();
        let gid = rustix::process::getgid();
        #[cfg(target_os = "linux")]
        {
            use rustix::process::PidfdFlags;
            // On Linux the credentials additionally carry a pidfd for this process.
            let process_fd = rustix::process::pidfd_open(pid, PidfdFlags::empty())?;
            // NOTE(review): the empty third argument presumably means "no supplementary
            // groups". Confirm against `Credentials::new`.
            Ok(connection::Credentials::new(
                uid,
                gid,
                vec![],
                pid,
                process_fd,
            ))
        }
        #[cfg(not(target_os = "linux"))]
        {
            Ok(connection::Credentials::new(uid, gid, pid))
        }
    }
}
/// Write half of [`MockSocket`]: records everything written for later inspection.
#[derive(Debug)]
#[doc(hidden)]
pub struct MockWriteHalf {
    /// All bytes passed to `write`, concatenated in call order.
    written: Vec<u8>,
    /// One entry per `write` call that passed at least one fd, holding duplicates of those fds.
    #[cfg(feature = "std")]
    fds_written: RefCell<Vec<Vec<OwnedFd>>>,
}
impl MockWriteHalf {
    /// Creates a writer with nothing recorded yet.
    #[cfg(feature = "std")]
    pub fn new() -> Self {
        let fds_written = RefCell::new(Vec::new());
        Self {
            fds_written,
            written: Vec::new(),
        }
    }

    /// Returns every byte written so far, concatenated in write order.
    pub fn written_data(&self) -> &[u8] {
        self.written.as_slice()
    }

    /// Borrows the captured fd batches.
    ///
    /// There is one entry per `write` call that passed at least one fd. Each entry holds
    /// duplicates of the descriptors passed to that call.
    ///
    /// # Panics
    /// Panics if the list is currently mutably borrowed.
    #[cfg(feature = "std")]
    pub fn fds_written(&self) -> core::cell::Ref<'_, Vec<Vec<OwnedFd>>> {
        self.fds_written.borrow()
    }

    /// Number of `write` calls that passed at least one fd.
    #[cfg(feature = "std")]
    pub fn fd_write_count(&self) -> usize {
        let batches = self.fds_written.borrow();
        batches.len()
    }
}
#[cfg(feature = "std")]
impl Default for MockWriteHalf {
    /// Equivalent to [`MockWriteHalf::new`]: an empty recorder.
    fn default() -> Self {
        MockWriteHalf::new()
    }
}
impl WriteHalf for MockWriteHalf {
    /// Records `buf` and, with the `std` feature, duplicates of the passed fds.
    ///
    /// The bytes are appended to the data returned by `written_data`. If `fds` is non-empty,
    /// each descriptor is duplicated with `F_DUPFD_CLOEXEC` and the batch is pushed as a single
    /// entry of `fds_written`. A write without fds adds no entry.
    ///
    /// # Errors
    /// Returns `crate::Error::Io` if duplicating any descriptor fails. In that case the bytes
    /// of `buf` have already been recorded and no fd batch is pushed.
    async fn write(
        &mut self,
        buf: &[u8],
        #[cfg(feature = "std")] fds: &[impl std::os::fd::AsFd],
    ) -> crate::Result<()> {
        self.written.extend_from_slice(buf);
        #[cfg(feature = "std")]
        {
            // Duplicate straight from the caller's slice; no intermediate Vec of borrowed fds.
            if !fds.is_empty() {
                let duplicated = fds
                    .iter()
                    .map(|f| {
                        let fd: BorrowedFd<'_> = f.as_fd();
                        rustix::io::fcntl_dupfd_cloexec(fd, 0)
                            .map_err(|e| crate::Error::Io(e.into()))
                    })
                    .collect::<crate::Result<Vec<OwnedFd>>>()?;
                self.fds_written.borrow_mut().push(duplicated);
            }
        }
        Ok(())
    }
}
/// Write half that panics (via `assert_eq!`) unless each write matches the configured
/// expectations.
#[derive(Debug)]
#[doc(hidden)]
pub struct TestWriteHalf {
    /// Exact byte length every write must have.
    expected_len: usize,
    /// `Some(n)`: every write must pass exactly `n` fds; `None`: writes must pass no fds.
    #[cfg(feature = "std")]
    expected_fd_count: Option<usize>,
    /// Number of writes that passed the assertions.
    #[cfg(feature = "std")]
    write_count: usize,
}
impl TestWriteHalf {
    /// Creates a writer that asserts every write is exactly `expected_len` bytes long.
    ///
    /// With the `std` feature, writes are also required to pass no fds.
    pub fn new(expected_len: usize) -> Self {
        Self {
            expected_len,
            #[cfg(feature = "std")]
            expected_fd_count: None,
            #[cfg(feature = "std")]
            write_count: 0,
        }
    }

    /// Like [`TestWriteHalf::new`], but every write must pass exactly `expected_fd_count` fds.
    #[cfg(feature = "std")]
    pub fn new_with_fds(expected_len: usize, expected_fd_count: usize) -> Self {
        Self {
            expected_fd_count: Some(expected_fd_count),
            ..Self::new(expected_len)
        }
    }

    /// Number of writes that have passed all assertions so far.
    #[cfg(feature = "std")]
    pub fn write_count(&self) -> usize {
        self.write_count
    }
}
impl WriteHalf for TestWriteHalf {
    /// Checks the write against the configured expectations; never returns an error.
    ///
    /// # Panics
    /// Panics if `buf.len()` differs from the expected length. With `std`, it also panics if
    /// the number of fds differs from the expected count (zero when none was configured).
    async fn write(
        &mut self,
        buf: &[u8],
        #[cfg(feature = "std")] fds: &[impl std::os::fd::AsFd],
    ) -> crate::Result<()> {
        assert_eq!(buf.len(), self.expected_len);
        #[cfg(feature = "std")]
        {
            match self.expected_fd_count {
                Some(expected_count) => assert_eq!(fds.len(), expected_count),
                None => assert_eq!(fds.len(), 0, "Expected no FDs to be passed"),
            }
            self.write_count += 1;
        }
        Ok(())
    }
}
/// Write half that discards data and only counts how many times `write` was called.
#[derive(Debug)]
#[doc(hidden)]
pub struct CountingWriteHalf {
    /// Number of `write` calls so far.
    count: usize,
}
impl Default for CountingWriteHalf {
fn default() -> Self {
Self::new()
}
}
impl CountingWriteHalf {
pub fn new() -> Self {
Self { count: 0 }
}
pub fn count(&self) -> usize {
self.count
}
}
impl WriteHalf for CountingWriteHalf {
    /// Bumps the call counter and ignores the payload (and any fds); always succeeds.
    async fn write(
        &mut self,
        _data: &[u8],
        #[cfg(feature = "std")] _passed_fds: &[impl std::os::fd::AsFd],
    ) -> crate::Result<()> {
        self.count += 1;
        Ok(())
    }
}
#[cfg(all(test, feature = "std"))]
mod tests {
    use super::*;
    use std::os::fd::AsFd;

    #[tokio::test]
    async fn mock_socket_with_fds_basic() {
        use std::os::unix::net::UnixStream;
        let (r1, _w1) = UnixStream::pair().unwrap();
        let (r2, _w2) = UnixStream::pair().unwrap();
        let fds = vec![vec![r1.into()], vec![r2.into()]];
        let socket = MockSocket::new(&["test1", "test2"], fds);
        let (mut read, _write) = socket.split();
        let mut buf = [0u8; 10];

        let (bytes, fds1) = read.read(&mut buf).await.unwrap();
        assert_eq!(&buf[..bytes], b"test1\0");
        assert_eq!(fds1.len(), 1);

        // The final message carries an extra trailing NUL.
        let (bytes, fds2) = read.read(&mut buf).await.unwrap();
        assert_eq!(&buf[..bytes], b"test2\0\0");
        assert_eq!(fds2.len(), 1);

        // Everything consumed: further reads report EOF with no fds.
        let (bytes, fds3) = read.read(&mut buf).await.unwrap();
        assert_eq!(bytes, 0);
        assert!(fds3.is_empty());
        assert_eq!(read.messages_remaining(), 0);
        assert_eq!(read.fds_consumed(), 2);
    }

    #[tokio::test]
    async fn mock_write_half_captures_fds() {
        use std::os::fd::{AsRawFd, BorrowedFd};
        use std::os::unix::net::UnixStream;
        let mut write = MockWriteHalf::new();
        let (r1, _w1) = UnixStream::pair().unwrap();
        let borrowed = r1.as_fd();
        write.write(b"test", &[borrowed]).await.unwrap();
        assert_eq!(write.written_data(), b"test");
        assert_eq!(write.fd_write_count(), 1);
        assert_eq!(write.fds_written().len(), 1);
        assert_eq!(write.fds_written()[0].len(), 1);
        // The captured fd is a duplicate, not the caller's descriptor.
        assert_ne!(write.fds_written()[0][0].as_raw_fd(), r1.as_raw_fd());

        // A write without fds records its bytes but adds no fd batch.
        let no_fds: &[BorrowedFd<'_>] = &[];
        write.write(b"more", no_fds).await.unwrap();
        assert_eq!(write.written_data(), b"testmore");
        assert_eq!(write.fd_write_count(), 1);
    }

    #[tokio::test]
    async fn test_write_half_validates_fd_count() {
        use std::os::unix::net::UnixStream;
        let mut write = TestWriteHalf::new_with_fds(4, 2);
        let (r1, _w1) = UnixStream::pair().unwrap();
        let (r2, _w2) = UnixStream::pair().unwrap();
        let borrowed = [r1.as_fd(), r2.as_fd()];
        write.write(b"test", &borrowed).await.unwrap();
        assert_eq!(write.write_count(), 1);
    }

    #[tokio::test]
    #[should_panic(expected = "assertion `left == right` failed")]
    async fn test_write_half_panics_on_wrong_fd_count() {
        use std::os::unix::net::UnixStream;
        let mut write = TestWriteHalf::new_with_fds(4, 2);
        let (r1, _w1) = UnixStream::pair().unwrap();
        let borrowed = [r1.as_fd()];
        write.write(b"test", &borrowed).await.unwrap();
    }

    #[tokio::test]
    async fn mock_read_half_multiple_fds_per_read() {
        use std::os::unix::net::UnixStream;
        let (r1, _w1) = UnixStream::pair().unwrap();
        let (r2, _w2) = UnixStream::pair().unwrap();
        let (r3, _w3) = UnixStream::pair().unwrap();
        let fds = vec![vec![r1.into(), r2.into(), r3.into()]];
        let socket = MockSocket::new(&["test"], fds);
        let (mut read, _write) = socket.split();
        let mut buf = [0u8; 1024];
        let (bytes, fds) = read.read(&mut buf).await.unwrap();
        assert_eq!(&buf[..bytes], b"test\0\0");
        assert_eq!(fds.len(), 3);
    }

    #[tokio::test]
    async fn mock_read_half_mixed_fd_and_no_fd_reads() {
        use std::os::unix::net::UnixStream;
        let (r1, _w1) = UnixStream::pair().unwrap();
        let fds = vec![vec![r1.into()], vec![]];
        let socket = MockSocket::new(&["test1", "test2"], fds);
        let (mut read, _write) = socket.split();
        let mut buf = [0u8; 10];
        let (bytes, fds1) = read.read(&mut buf).await.unwrap();
        assert_eq!(&buf[..bytes], b"test1\0");
        assert_eq!(fds1.len(), 1);
        let (bytes, fds2) = read.read(&mut buf).await.unwrap();
        assert_eq!(&buf[..bytes], b"test2\0\0");
        assert!(fds2.is_empty());
    }

    #[tokio::test]
    async fn mock_read_half_delivers_fds_with_final_chunk() {
        use std::os::unix::net::UnixStream;
        let (r1, _w1) = UnixStream::pair().unwrap();
        // "abcd" is encoded as "abcd\0\0" (6 bytes); a 4-byte buffer needs two reads.
        let socket = MockSocket::new(&["abcd"], vec![vec![r1.into()]]);
        let (mut read, _write) = socket.split();
        let mut buf = [0u8; 4];

        let (bytes, fds) = read.read(&mut buf).await.unwrap();
        assert_eq!(&buf[..bytes], b"abcd");
        assert!(fds.is_empty(), "fds must not be released before the message ends");
        assert_eq!(read.messages_remaining(), 1);

        let (bytes, fds) = read.read(&mut buf).await.unwrap();
        assert_eq!(&buf[..bytes], b"\0\0");
        assert_eq!(fds.len(), 1);
        assert_eq!(read.messages_remaining(), 0);

        let (bytes, fds) = read.read(&mut buf).await.unwrap();
        assert_eq!(bytes, 0);
        assert!(fds.is_empty());
    }
}