use std::{
collections::VecDeque,
io,
sync::Arc,
task::{Context, Poll},
};
use event_listener::{Event, EventListener};
#[cfg(unix)]
use crate::OwnedFd;
use crate::{
message_header::{MAX_MESSAGE_SIZE, MIN_MESSAGE_SIZE},
raw::Socket,
utils::padding_for_8_bytes,
Message, MessagePrimaryHeader,
};
use futures_core::ready;
/// A low-level message connection layered over a [`Socket`].
///
/// It keeps the partially-received input and a queue of outgoing messages, so the
/// non-blocking `try_*` methods can be polled repeatedly until they make progress.
#[derive(derivative::Derivative)]
#[derivative(Debug)]
pub struct Connection<S> {
    /// The underlying transport; excluded from the `Debug` output.
    #[derivative(Debug = "ignore")]
    socket: S,
    /// Notified whenever a flush, receive or close is attempted; see `monitor_activity`.
    event: Event,
    /// Bytes of the message currently being received.
    raw_in_buffer: Vec<u8>,
    /// File descriptors received so far for the message currently being received.
    #[cfg(unix)]
    raw_in_fds: Vec<OwnedFd>,
    /// Number of bytes at the start of `raw_in_buffer` that have already been filled.
    raw_in_pos: usize,
    /// Number of bytes of the front message in `out_msgs` already written to the socket.
    out_pos: usize,
    /// Outgoing messages, written front to back.
    out_msgs: VecDeque<Arc<Message>>,
    /// Sequence number given to the most recently received message (0 before the first).
    prev_seq: u64,
}
impl<S: Socket> Connection<S> {
    /// Creates a connection over `socket`.
    ///
    /// `raw_in_buffer` contains bytes already read from the socket before this connection was
    /// created. They are treated as the beginning of the incoming message stream.
    /// NOTE(review): presumably these are leftovers from the authentication handshake; confirm
    /// with callers.
    pub(crate) fn new(socket: S, raw_in_buffer: Vec<u8>) -> Connection<S> {
        Connection {
            socket,
            event: Event::new(),
            raw_in_pos: raw_in_buffer.len(),
            raw_in_buffer,
            #[cfg(unix)]
            raw_in_fds: vec![],
            out_pos: 0,
            out_msgs: VecDeque::new(),
            prev_seq: 0,
        }
    }

    /// Attempts to write every queued outgoing message to the socket.
    ///
    /// Returns `Poll::Ready(Ok(()))` once the outgoing queue is empty. Returns `Poll::Pending`
    /// when the socket cannot accept more data yet. Progress is remembered, so polling again
    /// resumes where the previous call stopped. On Unix, a message's file descriptors are sent
    /// with the first chunk of that message only.
    ///
    /// # Errors
    ///
    /// Returns any error reported by the socket. Returns an error of kind
    /// [`io::ErrorKind::WriteZero`] if the socket accepts zero bytes of a non-empty write.
    pub fn try_flush(&mut self, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        self.event.notify(usize::MAX);
        while let Some(msg) = self.out_msgs.front() {
            loop {
                let data = &msg.as_bytes()[self.out_pos..];
                if data.is_empty() {
                    // The front message has been fully written; move on to the next one.
                    self.out_pos = 0;
                    self.out_msgs.pop_front();
                    break;
                }
                // FDs must accompany only the first chunk of a message, never a continuation.
                #[cfg(unix)]
                let fds = if self.out_pos == 0 { msg.fds() } else { vec![] };
                let written = ready!(self.socket.poll_sendmsg(
                    cx,
                    data,
                    #[cfg(unix)]
                    &fds,
                ))?;
                if written == 0 {
                    // Without this check, a socket that accepts nothing would make us spin forever.
                    return Poll::Ready(Err(io::Error::new(
                        io::ErrorKind::WriteZero,
                        "failed to send message",
                    )));
                }
                self.out_pos += written;
            }
        }
        Poll::Ready(Ok(()))
    }

    /// Queues `msg` for sending. Nothing is written until [`Self::try_flush`] is polled.
    pub fn enqueue_message(&mut self, msg: Arc<Message>) {
        self.out_msgs.push_back(msg);
    }

    /// Attempts to read one complete message from the socket.
    ///
    /// Reading is resumable. If the socket is not ready, `Poll::Pending` is returned, and the
    /// bytes (and, on Unix, file descriptors) gathered so far are kept for the next poll. Each
    /// returned message gets the next sequence number, starting at 1.
    ///
    /// # Errors
    ///
    /// - `Error::InputOutput` wrapping an `UnexpectedEof` I/O error if the stream ends before a
    ///   whole message has been read.
    /// - `Error::ExcessData` if the header announces a message larger than `MAX_MESSAGE_SIZE`.
    /// - Any error from the socket, from `MessagePrimaryHeader::read` or from
    ///   `Message::from_raw_parts`.
    pub fn try_receive_message(&mut self, cx: &mut Context<'_>) -> Poll<crate::Result<Message>> {
        self.event.notify(usize::MAX);

        // First make sure the fixed-size part of the header is available.
        if self.raw_in_pos < MIN_MESSAGE_SIZE {
            self.raw_in_buffer.resize(MIN_MESSAGE_SIZE, 0);
            ready!(self.poll_fill_in_buffer(cx))?;
        }

        let (primary_header, fields_len) = MessagePrimaryHeader::read(&self.raw_in_buffer)?;
        // Saturating arithmetic: a wrapped sum (possible on 32-bit targets) must not slip
        // under the size limit below.
        let header_len = MIN_MESSAGE_SIZE.saturating_add(fields_len as usize);
        let body_padding = padding_for_8_bytes(header_len);
        let body_len = primary_header.body_len() as usize;
        let total_len = header_len
            .saturating_add(body_padding)
            .saturating_add(body_len);
        if total_len > MAX_MESSAGE_SIZE {
            return Poll::Ready(Err(crate::Error::ExcessData));
        }

        // The bytes handed to `new` may already extend past this message. Keep that excess
        // for the next call instead of truncating it away. Otherwise read the rest of the
        // message from the socket.
        let excess = if self.raw_in_pos > total_len {
            self.raw_in_buffer.split_off(total_len)
        } else {
            self.raw_in_buffer.resize(total_len, 0);
            ready!(self.poll_fill_in_buffer(cx))?;
            Vec::new()
        };

        let bytes = std::mem::replace(&mut self.raw_in_buffer, excess);
        self.raw_in_pos = self.raw_in_buffer.len();
        #[cfg(unix)]
        let fds = std::mem::take(&mut self.raw_in_fds);
        let seq = self.prev_seq + 1;
        self.prev_seq = seq;
        Poll::Ready(Message::from_raw_parts(
            bytes,
            #[cfg(unix)]
            fds,
            seq,
        ))
    }

    /// Reads from the socket into `raw_in_buffer[raw_in_pos..]` until the buffer is full.
    /// On Unix, received file descriptors are appended to `raw_in_fds`.
    ///
    /// Returns an `UnexpectedEof` error if the socket reports end of stream first.
    fn poll_fill_in_buffer(&mut self, cx: &mut Context<'_>) -> Poll<crate::Result<()>> {
        while self.raw_in_pos < self.raw_in_buffer.len() {
            let res = ready!(self
                .socket
                .poll_recvmsg(cx, &mut self.raw_in_buffer[self.raw_in_pos..]))?;
            let read = {
                #[cfg(unix)]
                {
                    let (read, fds) = res;
                    self.raw_in_fds.extend(fds);
                    read
                }
                #[cfg(not(unix))]
                {
                    res
                }
            };
            if read == 0 {
                // End of stream: the same zero-length read would otherwise repeat forever.
                return Poll::Ready(Err(crate::Error::InputOutput(
                    io::Error::new(io::ErrorKind::UnexpectedEof, "failed to receive message")
                        .into(),
                )));
            }
            self.raw_in_pos += read;
        }
        Poll::Ready(Ok(()))
    }

    /// Closes the underlying socket. Activity listeners are notified first.
    pub fn close(&self) -> crate::Result<()> {
        self.event.notify(usize::MAX);
        self.socket().close().map_err(Into::into)
    }

    /// Returns a reference to the underlying socket.
    pub fn socket(&self) -> &S {
        &self.socket
    }

    /// Returns a listener that is notified the next time `try_flush`, `try_receive_message`
    /// or `close` is called on this connection.
    pub(crate) fn monitor_activity(&self) -> EventListener {
        self.event.listen()
    }
}
impl Connection<Box<dyn Socket>> {
pub(crate) fn flush(&mut self, cx: &mut Context<'_>) -> Poll<crate::Result<()>> {
self.try_flush(cx).map_err(Into::into)
}
}
#[cfg(unix)]
#[cfg(test)]
mod tests {
    use super::{Arc, Connection};
    use crate::message::Message;
    use futures_util::future::poll_fn;
    use test_log::test;

    /// Round-trips a method call through a connected socket pair.
    #[test]
    fn raw_send_receive() {
        crate::block_on(raw_send_receive_async());
    }

    async fn raw_send_receive_async() {
        #[cfg(not(feature = "tokio"))]
        let (left, right) = {
            let (a, b) = std::os::unix::net::UnixStream::pair().unwrap();
            (
                async_io::Async::new(a).unwrap(),
                async_io::Async::new(b).unwrap(),
            )
        };
        #[cfg(feature = "tokio")]
        let (left, right) = tokio::net::UnixStream::pair().unwrap();

        let mut sender = Connection::new(left, vec![]);
        let mut receiver = Connection::new(right, vec![]);

        let call = Message::method(None::<()>, None::<()>, "/", Some("org.zbus.p2p"), "Test", &())
            .unwrap();
        sender.enqueue_message(Arc::new(call));
        poll_fn(|cx| sender.try_flush(cx)).await.unwrap();

        let received = poll_fn(|cx| receiver.try_receive_message(cx))
            .await
            .unwrap();
        assert_eq!(received.to_string(), "Method call Test");
    }
}