use tokio::sync::{mpsc, oneshot};
use crate::{
util,
Error,
Message,
PeerHandle,
ReceivedMessage,
SentRequestHandle,
};
use crate::request_tracker::RequestTracker;
use crate::util::{select, Either};
/// A command for the command loop of a [`Peer`], delivered over the peer's command channel.
///
/// The command loop exits when it receives [`Command::Stop`], or once the read handle has been
/// unregistered and the write handle count has dropped to zero.
pub enum Command<Body> {
/// Allocate a request ID, write the request and report the resulting handle (or error) back.
SendRequest(SendRequest<Body>),
/// Write a pre-built message to the transport and report the outcome back.
SendRawMessage(SendRawMessage<Body>),
/// Handle a message (or read error) that the read loop received from the transport.
ProcessReceivedMessage(ProcessReceivedMessage<Body>),
/// Stop the command loop immediately.
Stop,
/// Mark the read handle as dropped.
UnregisterReadHandle,
/// Increment the number of registered write handles.
RegisterWriteHandle,
/// Decrement the number of registered write handles.
UnregisterWriteHandle,
}
/// A peer wrapping a transport, plus the state used by the read loop and the command loop
/// that [`Peer::run`] drives.
pub struct Peer<Transport: crate::transport::Transport> {
/// The underlying transport; `run` splits it into a read half and a write half.
transport: Transport,
/// Bookkeeping for sent and received requests (constructed with a clone of the command sender).
request_tracker: RequestTracker<Transport::Body>,
/// A command sender owned by the peer itself, so the command channel stays open while `run`
/// is active; also used by `run` to send `Command::Stop` to the command loop.
command_tx: mpsc::UnboundedSender<Command<Transport::Body>>,
/// Receiving end of the command channel, drained by the command loop.
command_rx: mpsc::UnboundedReceiver<Command<Transport::Body>>,
/// Sends received messages and errors to the receiver that `new` hands to the `PeerHandle`.
incoming_tx: mpsc::UnboundedSender<Result<ReceivedMessage<Transport::Body>, Error>>,
/// Number of registered write handles. Starts at 1 in `new`; the command loop exits once this
/// is zero and the read handle has been dropped.
write_handles: usize,
}
impl<Transport: crate::transport::Transport> Peer<Transport> {
    /// Create a peer for `transport` together with the handle used to interact with it.
    ///
    /// The peer does nothing until the future returned by [`Peer::run`] is polled.
    /// [`Peer::spawn`] does that on a new tokio task.
    pub fn new(transport: Transport) -> (Self, PeerHandle<Transport::Body>) {
        let (command_tx, command_rx) = mpsc::unbounded_channel();
        let (incoming_tx, incoming_rx) = mpsc::unbounded_channel();
        let handle = PeerHandle::new(incoming_rx, command_tx.clone());
        let peer = Self {
            request_tracker: RequestTracker::new(command_tx.clone()),
            transport,
            command_tx,
            command_rx,
            incoming_tx,
            // The write handle count starts at one.
            // NOTE(review): presumably this accounts for the returned handle; confirm in PeerHandle.
            write_handles: 1,
        };
        (peer, handle)
    }

    /// Create a peer for `transport`, run it on a detached tokio task and return its handle.
    pub fn spawn(transport: Transport) -> PeerHandle<Transport::Body> {
        let (peer, handle) = Self::new(transport);
        // The JoinHandle is not kept, so the task runs detached.
        tokio::spawn(peer.run());
        handle
    }

    /// Connect a new transport to `address` and spawn a peer for it.
    ///
    /// Returns the peer handle together with the transport info, which is queried before the
    /// peer is spawned.
    ///
    /// # Errors
    /// Propagates any I/O error from connecting or from querying the transport info.
    pub async fn connect<'a, Address>(address: Address, config: Transport::Config) -> std::io::Result<(PeerHandle<Transport::Body>, Transport::Info)>
    where
        Address: 'a,
        Transport: util::Connect<'a, Address>,
    {
        let transport = Transport::connect(address, config).await?;
        let info = transport.info()?;
        let handle = Self::spawn(transport);
        Ok((handle, info))
    }

    /// Drive the peer: run the read loop and the command loop concurrently.
    ///
    /// If the read loop finishes first, the command loop is sent `Command::Stop` and awaited
    /// until it exits. If the command loop finishes first, the read loop is dropped.
    pub async fn run(mut self) {
        // Borrow the fields individually so each loop gets exactly what it needs.
        let (read_half, write_half) = self.transport.split();
        let mut reader = ReadLoop {
            read_half,
            command_tx: self.command_tx.clone(),
        };
        let mut read_handle_dropped = false;
        let mut commands = CommandLoop {
            write_half,
            request_tracker: &mut self.request_tracker,
            command_rx: &mut self.command_rx,
            incoming_tx: &mut self.incoming_tx,
            read_handle_dropped: &mut read_handle_dropped,
            write_handles: &mut self.write_handles,
        };

        let read_future = reader.run();
        let command_future = commands.run();
        tokio::pin!(read_future);
        tokio::pin!(command_future);

        // Only the "read loop finished first" case needs extra work.
        if let Either::Left(((), remaining_commands)) = select(read_future, command_future).await {
            self.command_tx
                .send(Command::Stop)
                .map_err(drop)
                .expect("command loop did not stop yet but command channel is closed");
            remaining_commands.await;
        }
    }

    /// Shared access to the wrapped transport.
    pub fn transport(&self) -> &Transport {
        &self.transport
    }

    /// Exclusive access to the wrapped transport.
    pub fn transport_mut(&mut self) -> &mut Transport {
        &mut self.transport
    }
}
/// Reads messages from the read half of a transport and forwards each one (or each read error)
/// to the command loop as a `ProcessReceivedMessage` command.
struct ReadLoop<R>
where
R: crate::transport::TransportReadHalf,
{
/// The read half of the split transport.
read_half: R,
/// Sender used to forward received messages to the command loop.
command_tx: mpsc::UnboundedSender<Command<R::Body>>,
}
impl<R> ReadLoop<R>
where
    R: crate::transport::TransportReadHalf,
{
    /// Forward everything read from the transport to the command loop.
    ///
    /// Read errors are forwarded too, after unwrapping them with `into_inner`. The loop ends
    /// right after forwarding a fatal error, or as soon as the command channel refuses a message.
    async fn run(&mut self) {
        loop {
            let read_result = self.read_half.read_msg().await;
            // Decide before consuming the result whether this is the last message we read.
            let is_fatal = match &read_result {
                Err(error) => error.is_fatal(),
                Ok(_) => false,
            };
            let command = crate::peer::ProcessReceivedMessage {
                message: read_result.map_err(|error| error.into_inner()),
            };
            let forwarded = self.command_tx.send(command.into()).is_ok();
            if !forwarded || is_fatal {
                break;
            }
        }
    }
}
/// Executes [`Command`]s received over the command channel, writing to the write half of the
/// transport and delivering received messages to the read handle.
///
/// All state except the write half is borrowed from the [`Peer`] (or its `run` method) for
/// the duration of the loop.
struct CommandLoop<'a, W>
where
W: crate::transport::TransportWriteHalf,
{
/// The write half of the split transport.
write_half: W,
/// Bookkeeping for sent and received requests.
request_tracker: &'a mut RequestTracker<W::Body>,
/// Receiving end of the command channel.
command_rx: &'a mut mpsc::UnboundedReceiver<Command<W::Body>>,
/// Delivers received messages and errors to the read handle.
incoming_tx: &'a mut mpsc::UnboundedSender<Result<ReceivedMessage<W::Body>, Error>>,
/// Set on `Command::UnregisterReadHandle`, or when delivering to the read handle fails.
read_handle_dropped: &'a mut bool,
/// Number of registered write handles.
write_handles: &'a mut usize,
}
impl<W> CommandLoop<'_, W>
where
W: crate::transport::TransportWriteHalf,
{
async fn run(&mut self) {
loop {
if *self.read_handle_dropped && *self.write_handles == 0 {
break;
}
let command = self
.command_rx
.recv()
.await
.expect("all command channels closed, but we keep one open ourselves");
let flow = match command {
Command::SendRequest(command) => self.send_request(command).await,
Command::SendRawMessage(command) => self.send_raw_message(command).await,
Command::ProcessReceivedMessage(command) => self.process_incoming_message(command).await,
Command::Stop => LoopFlow::Stop,
Command::UnregisterReadHandle => {
*self.read_handle_dropped = true;
LoopFlow::Continue
},
Command::RegisterWriteHandle => {
*self.write_handles += 1;
LoopFlow::Continue
},
Command::UnregisterWriteHandle => {
*self.write_handles -= 1;
LoopFlow::Continue
},
};
match flow {
LoopFlow::Stop => break,
LoopFlow::Continue => continue,
}
}
}
async fn send_request(&mut self, command: crate::peer::SendRequest<W::Body>) -> LoopFlow {
let request = match self.request_tracker.allocate_sent_request(command.service_id) {
Ok(x) => x,
Err(e) => {
let _: Result<_, _> = command.result_tx.send(Err(e));
return LoopFlow::Continue;
},
};
let request_id = request.request_id();
let message = Message::request(request.request_id(), request.service_id(), command.body);
if let Err((e, flow)) = self.write_message(&message).await {
let _: Result<_, _> = command.result_tx.send(Err(e));
let _: Result<_, _> = self.request_tracker.remove_sent_request(request_id);
return flow;
}
if command.result_tx.send(Ok(request)).is_err() {
let _: Result<_, _> = self.request_tracker.remove_sent_request(request_id);
}
LoopFlow::Continue
}
async fn send_raw_message(&mut self, command: crate::peer::SendRawMessage<W::Body>) -> LoopFlow {
if command.message.header.message_type.is_response() {
let _: Result<_, _> = self.request_tracker.remove_received_request(command.message.header.request_id);
}
if let Err((e, flow)) = self.write_message(&command.message).await {
let _: Result<_, _> = command.result_tx.send(Err(e));
return flow;
}
let _: Result<_, _> = command.result_tx.send(Ok(()));
LoopFlow::Continue
}
async fn process_incoming_message(&mut self, command: crate::peer::ProcessReceivedMessage<W::Body>) -> LoopFlow {
let message = match command.message {
Ok(x) => x,
Err(e) => {
let _: Result<_, _> = self.send_incoming(Err(e)).await;
return LoopFlow::Continue;
},
};
let incoming = match self.request_tracker.process_incoming_message(message).await {
Ok(None) => return LoopFlow::Continue,
Ok(Some(x)) => x,
Err(e) => {
let _: Result<_, _> = self.send_incoming(Err(e)).await;
return LoopFlow::Continue;
},
};
match self.incoming_tx.send(Ok(incoming)) {
Ok(()) => LoopFlow::Continue,
Err(mpsc::error::SendError(msg)) => match msg.unwrap() {
ReceivedMessage::Request(request, _body) => {
let error_msg = format!("unexpected request for service {}", request.service_id());
let response = Message::error_response(request.request_id(), &error_msg);
if self.write_message(&response).await.is_err() {
LoopFlow::Stop
} else {
LoopFlow::Continue
}
},
ReceivedMessage::Stream(_) => LoopFlow::Continue,
},
}
}
async fn send_incoming(&mut self, incoming: Result<ReceivedMessage<W::Body>, Error>) -> Result<(), ()> {
if self.incoming_tx.send(incoming).is_err() {
*self.read_handle_dropped = true;
Err(())
} else {
Ok(())
}
}
async fn write_message(&mut self, message: &Message<W::Body>) -> Result<(), (Error, LoopFlow)> {
match self.write_half.write_msg(&message.header, &message.body).await {
Ok(()) => Ok(()),
Err(e) => {
let flow = if e.is_fatal() {
LoopFlow::Stop
} else {
LoopFlow::Continue
};
Err((e.into_inner(), flow))
},
}
}
}
/// Whether the command loop should keep processing commands after handling one.
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
enum LoopFlow {
/// Keep processing commands.
Continue,
/// Exit the command loop.
Stop,
}
/// Arguments for [`Command::SendRequest`].
pub struct SendRequest<Body> {
/// Service ID used when allocating the request and building the request message.
pub service_id: i32,
/// Body of the request message.
pub body: Body,
/// Receives the handle for the sent request, or the error that prevented sending it.
pub result_tx: oneshot::Sender<Result<SentRequestHandle<Body>, Error>>,
}
/// Arguments for [`Command::SendRawMessage`].
pub struct SendRawMessage<Body> {
/// The message to write as-is. If it is a response, the command loop first removes the
/// matching received request from the request tracker.
pub message: Message<Body>,
/// Receives `Ok(())` after a successful write, or the write error.
pub result_tx: oneshot::Sender<Result<(), Error>>,
}
/// Arguments for [`Command::ProcessReceivedMessage`], sent by the read loop.
pub struct ProcessReceivedMessage<Body> {
/// The message read from the transport, or the (already unwrapped) read error.
pub message: Result<Message<Body>, Error>,
}
impl<Body> std::fmt::Debug for Command<Body> {
    /// Format as `Command { <Variant>: <payload> }`; variants without a payload show `()`.
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        let mut builder = f.debug_struct("Command");
        match self {
            Self::SendRequest(inner) => {
                builder.field("SendRequest", inner);
            },
            Self::SendRawMessage(inner) => {
                builder.field("SendRawMessage", inner);
            },
            Self::ProcessReceivedMessage(inner) => {
                builder.field("ProcessReceivedMessage", inner);
            },
            Self::Stop => {
                builder.field("Stop", &());
            },
            Self::UnregisterReadHandle => {
                builder.field("UnregisterReadHandle", &());
            },
            Self::RegisterWriteHandle => {
                builder.field("RegisterWriteHandle", &());
            },
            Self::UnregisterWriteHandle => {
                builder.field("UnregisterWriteHandle", &());
            },
        }
        builder.finish()
    }
}
impl<Body> std::fmt::Debug for SendRequest<Body> {
    /// Show only the service ID; the body and result channel are not required to be `Debug`.
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        let mut builder = f.debug_struct("SendRequest");
        builder.field("service_id", &self.service_id);
        builder.finish()
    }
}
impl<Body> std::fmt::Debug for SendRawMessage<Body> {
    /// Show the message being sent; the result channel is omitted.
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        let message = &self.message;
        let mut builder = f.debug_struct("SendRawMessage");
        builder.field("message", message);
        builder.finish()
    }
}
impl<Body> std::fmt::Debug for ProcessReceivedMessage<Body> {
    /// Show the received message or read error.
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        let received = &self.message;
        let mut builder = f.debug_struct("ProcessReceivedMessage");
        builder.field("message", received);
        builder.finish()
    }
}
impl<Body> From<SendRequest<Body>> for Command<Body> {
fn from(other: SendRequest<Body>) -> Self {
Self::SendRequest(other)
}
}
impl<Body> From<SendRawMessage<Body>> for Command<Body> {
fn from(other: SendRawMessage<Body>) -> Self {
Self::SendRawMessage(other)
}
}
impl<Body> From<ProcessReceivedMessage<Body>> for Command<Body> {
fn from(other: ProcessReceivedMessage<Body>) -> Self {
Self::ProcessReceivedMessage(other)
}
}
// Tests that run two peers against each other over a connected `UnixStream` pair.
#[cfg(test)]
mod test {
use super::*;
use assert2::assert;
use assert2::let_assert;
use crate::MessageHeader;
use crate::transport::StreamTransport;
use tokio::net::UnixStream;
// Full round trip: request, updates in both directions, response, then shutdown.
#[tokio::test]
async fn test_peer() {
let_assert!(Ok((peer_a, peer_b)) = UnixStream::pair());
let (peer_a, handle_a) = Peer::new(StreamTransport::new(peer_a, Default::default()));
let (peer_b, mut handle_b) = Peer::new(StreamTransport::new(peer_b, Default::default()));
let task_a = tokio::spawn(peer_a.run());
let task_b = tokio::spawn(peer_b.run());
// A sends a request; B receives it.
let_assert!(Ok(mut sent_request) = handle_a.send_request(1, &[2][..]).await);
let request_id = sent_request.request_id();
let_assert!(Ok(ReceivedMessage::Request(mut received_request, _body)) = handle_b.recv_message().await);
// Requester -> responder update.
let_assert!(Ok(()) = sent_request.send_update(3, &[4][..]).await);
let_assert!(Some(update) = received_request.recv_update().await);
assert!(update.header == MessageHeader::requester_update(request_id, 3));
assert!(update.body.as_ref() == &[4]);
// Responder -> requester update.
let_assert!(Ok(()) = received_request.send_update(5, &[6][..]).await);
let_assert!(Some(update) = sent_request.recv_update().await);
assert!(update.header == MessageHeader::responder_update(request_id, 5));
assert!(update.body.as_ref() == &[6]);
// Final response.
let_assert!(Ok(()) = received_request.send_response(7, &[8][..]).await);
let_assert!(Ok(response) = sent_request.recv_response().await);
assert!(response.header == MessageHeader::response(request_id, 7));
assert!(response.body.as_ref() == &[8]);
// Once all handles are dropped, both peer tasks are expected to finish without panicking.
drop(handle_a);
drop(handle_b);
drop(sent_request);
assert!(let Ok(()) = task_a.await);
assert!(let Ok(()) = task_b.await);
}
// Updates sent before the response are still delivered, followed by `None` and then the
// response itself; a second `recv_response` call fails.
#[tokio::test]
async fn peeked_response_is_not_gone() {
let_assert!(Ok((peer_a, peer_b)) = UnixStream::pair());
let handle_a = Peer::spawn(StreamTransport::new(peer_a, Default::default()));
let mut handle_b = Peer::spawn(StreamTransport::new(peer_b, Default::default()));
let_assert!(Ok(mut sent_request) = handle_a.send_request(1, &[2][..]).await);
let request_id = sent_request.request_id();
let_assert!(Ok(ReceivedMessage::Request(received_request, _body)) = handle_b.recv_message().await);
let_assert!(Ok(()) = received_request.send_update(5, &b"Hello world!"[..]).await);
let_assert!(Ok(()) = received_request.send_update(6, &b"Hello world!"[..]).await);
let_assert!(Ok(()) = received_request.send_response(7, &b"Goodbye!"[..]).await);
assert!(let Some(_) = sent_request.recv_update().await);
assert!(let Some(_) = sent_request.recv_update().await);
assert!(let None = sent_request.recv_update().await);
let_assert!(Ok(response) = sent_request.recv_response().await);
assert!(let Err(_) = sent_request.recv_response().await);
assert!(response.header == MessageHeader::response(request_id, 7));
assert!(response.body.as_ref() == b"Goodbye!");
}
// Calling `recv_response` while an update is still queued fails (asserted `Err` below), but the
// queued update must still be delivered afterwards, followed by the response.
#[tokio::test]
async fn peeked_update_is_not_gone() {
let_assert!(Ok((peer_a, peer_b)) = UnixStream::pair());
let handle_a = Peer::spawn(StreamTransport::new(peer_a, Default::default()));
let mut handle_b = Peer::spawn(StreamTransport::new(peer_b, Default::default()));
let_assert!(Ok(mut sent_request) = handle_a.send_request(1, &[2][..]).await);
let request_id = sent_request.request_id();
let_assert!(Ok(ReceivedMessage::Request(received_request, _body)) = handle_b.recv_message().await);
let_assert!(Ok(()) = received_request.send_update(5, &b"Hello world!"[..]).await);
let_assert!(Ok(()) = received_request.send_response(6, &b"Goodbye!"[..]).await);
assert!(let Err(_) = sent_request.recv_response().await);
let_assert!(Some(update) = sent_request.recv_update().await);
assert!(update.header == MessageHeader::responder_update(request_id, 5));
assert!(update.body.as_ref() == b"Hello world!");
assert!(let None = sent_request.recv_update().await);
let_assert!(Ok(response) = sent_request.recv_response().await);
assert!(response.header == MessageHeader::response(request_id, 6));
assert!(response.body.as_ref() == b"Goodbye!");
}
}