use tokio::sync::mpsc;
use tokio::sync::oneshot;
use std::sync::{Arc, atomic::{AtomicBool, Ordering}};
use crate::error::private::{
connection_aborted,
InnerError,
UnexpectedMessageType,
};
use crate::peer::Command;
use crate::{Error, Message};
/// Command delivered to a request handle through its incoming channel.
pub(crate) enum RequestHandleCommand<Body> {
    /// No more messages will follow: a handle that receives this closes its incoming
    /// channel and reports the end of the message stream (`None`).
    Close,

    /// A message that belongs to the request.
    Message(Message<Body>),
}
/// Handle to an outgoing (sent) request.
///
/// Through this handle the caller can receive responder updates and the final response,
/// and send requester updates. A send-only handle can be obtained with `write_handle()`.
pub struct SentRequestHandle<Body> {
    /// Sending half; holds the request/service IDs, the shared closed flag and the
    /// command channel.
    write_handle: SentRequestWriteHandle<Body>,

    /// Channel on which messages for this request arrive.
    incoming_rx: mpsc::UnboundedReceiver<RequestHandleCommand<Body>>,

    /// A message that was read from `incoming_rx` but not handed out because it was not
    /// of the kind the caller asked for; the next receive call returns it first.
    peek_buffer: Option<Message<Body>>,
}
/// Send-only half of a sent request.
///
/// It can only send requester updates; receiving goes through `SentRequestHandle`.
pub struct SentRequestWriteHandle<Body> {
    /// ID of the request this handle belongs to.
    request_id: u32,

    /// Service ID the handle was created with.
    service_id: i32,

    /// Shared flag; while it is set, sending fails with a `RequestClosed` error.
    /// This file only reads the flag — it is set elsewhere.
    closed: Arc<AtomicBool>,

    /// Channel used to submit send commands (`crate::peer::Command`).
    command_tx: mpsc::UnboundedSender<Command<Body>>,
}
/// Handle to an incoming (received) request.
///
/// Through this handle the caller can receive requester updates and send responder
/// updates, the response, or an error response. A send-only handle can be obtained with
/// `write_handle()`.
pub struct ReceivedRequestHandle<Body> {
    /// Sending half; holds the request/service IDs, the shared closed flag and the
    /// command channel.
    write_handle: ReceivedRequestWriteHandle<Body>,

    /// Channel on which messages for this request arrive.
    incoming_rx: mpsc::UnboundedReceiver<RequestHandleCommand<Body>>,
}
/// Send-only half of a received request.
///
/// It can send responder updates, the response and error responses.
pub struct ReceivedRequestWriteHandle<Body> {
    /// ID of the request this handle belongs to.
    request_id: u32,

    /// Service ID the handle was created with.
    service_id: i32,

    /// Shared flag; while it is set, sending fails with a `RequestClosed` error.
    /// This file only reads the flag — it is set elsewhere.
    closed: Arc<AtomicBool>,

    /// Channel used to submit send commands (`crate::peer::Command`).
    command_tx: mpsc::UnboundedSender<Command<Body>>,
}
/// A message received by a peer: either a new request or a stream message.
pub enum ReceivedMessage<Body> {
    /// A new incoming request: the handle used to answer it, plus the request body.
    Request(ReceivedRequestHandle<Body>, Body),

    /// A stream message.
    Stream(Message<Body>),
}
impl<Body> SentRequestHandle<Body> {
    /// Create a handle for a sent request.
    ///
    /// `incoming_rx` delivers the messages for the request, and `command_tx` is used to
    /// submit outgoing updates. The `closed` flag is shared (via `Arc`) with every write
    /// handle cloned from this one.
    pub(crate) fn new(
        request_id: u32,
        service_id: i32,
        closed: Arc<AtomicBool>,
        incoming_rx: mpsc::UnboundedReceiver<RequestHandleCommand<Body>>,
        command_tx: mpsc::UnboundedSender<Command<Body>>,
    ) -> Self {
        Self {
            write_handle: SentRequestWriteHandle {
                request_id,
                service_id,
                closed,
                command_tx,
            },
            incoming_rx,
            peek_buffer: None,
        }
    }

    /// ID of the request.
    pub fn request_id(&self) -> u32 {
        self.write_handle.request_id
    }

    /// Service ID the handle was created with.
    pub fn service_id(&self) -> i32 {
        self.write_handle.service_id
    }

    /// Get a clone of the send-only handle for this request.
    pub fn write_handle(&self) -> SentRequestWriteHandle<Body> {
        Clone::clone(&self.write_handle)
    }

    /// Receive the next responder update.
    ///
    /// Returns `None` when no more messages will arrive, or when the next message is not
    /// a responder update. In the second case the message is kept, and the next receive
    /// call returns it (so a response can still be read with `recv_response`).
    pub async fn recv_update(&mut self) -> Option<Message<Body>> {
        let message = self.recv_message().await?;
        if !message.header.message_type.is_responder_update() {
            self.peek_buffer = Some(message);
            return None;
        }
        Some(message)
    }

    /// Receive the final response.
    ///
    /// # Errors
    /// - A connection-aborted error when no more messages will arrive.
    /// - An `UnexpectedMessageType` error when the next message is not a response. The
    ///   message is kept, and the next receive call returns it.
    pub async fn recv_response(&mut self) -> Result<Message<Body>, Error> {
        let message = self.recv_message().await.ok_or_else(connection_aborted)?;
        let kind = message.header.message_type;
        if !kind.is_response() {
            self.peek_buffer = Some(message);
            let unexpected = UnexpectedMessageType {
                value: kind,
                expected: crate::MessageType::Response,
            };
            return Err(InnerError::from(unexpected).into());
        }
        Ok(message)
    }

    /// Take the next message for the request, preferring a previously peeked one.
    ///
    /// The incoming channel is closed after a response is read, or when a `Close` command
    /// arrives. `None` means no further messages will be produced.
    async fn recv_message(&mut self) -> Option<Message<Body>> {
        if let Some(peeked) = self.peek_buffer.take() {
            return Some(peeked);
        }

        let command = self.incoming_rx.recv().await?;
        let message = match command {
            RequestHandleCommand::Message(message) => message,
            RequestHandleCommand::Close => {
                self.incoming_rx.close();
                return None;
            }
        };

        // A response finishes the request, so stop accepting further messages.
        if message.header.message_type.is_response() {
            self.incoming_rx.close();
        }
        Some(message)
    }

    /// Send a requester update; see `SentRequestWriteHandle::send_update`.
    pub async fn send_update(&self, service_id: i32, body: impl Into<Body>) -> Result<(), Error> {
        self.write_handle.send_update(service_id, body).await
    }
}
impl<Body> SentRequestWriteHandle<Body> {
    /// ID of the request this handle belongs to.
    pub fn request_id(&self) -> u32 {
        self.request_id
    }

    /// Service ID the handle was created with.
    pub fn service_id(&self) -> i32 {
        self.service_id
    }

    /// Send a requester update for the request.
    ///
    /// The update is submitted as a `SendRawMessage` command, and this call waits for the
    /// result reported back on a oneshot channel.
    ///
    /// # Errors
    /// - `RequestClosed` if the closed flag is set (checked before `body` is converted).
    /// - A connection-aborted error if the command channel or the result channel is gone.
    /// - Any error reported back for the send itself.
    pub async fn send_update(&self, service_id: i32, body: impl Into<Body>) -> Result<(), Error> {
        use crate::peer::SendRawMessage;

        let already_closed = self.closed.load(Ordering::Acquire);
        if already_closed {
            return Err(InnerError::RequestClosed.into());
        }

        let payload = body.into();
        let (result_tx, result_rx) = oneshot::channel();
        let update = Message::requester_update(self.request_id, service_id, payload);
        let command = SendRawMessage { message: update, result_tx };
        self.command_tx
            .send(command.into())
            .map_err(|_| connection_aborted())?;

        let outcome = result_rx.await.map_err(|_| connection_aborted())?;
        outcome?;
        Ok(())
    }
}
impl<Body> ReceivedRequestHandle<Body> {
    /// Create a handle for a received request.
    ///
    /// `incoming_rx` delivers the messages for the request, and `command_tx` is used to
    /// submit outgoing messages. The `closed` flag is shared (via `Arc`) with every write
    /// handle cloned from this one.
    pub(crate) fn new(
        request_id: u32,
        service_id: i32,
        closed: Arc<AtomicBool>,
        incoming_rx: mpsc::UnboundedReceiver<RequestHandleCommand<Body>>,
        command_tx: mpsc::UnboundedSender<Command<Body>>,
    ) -> Self {
        Self {
            write_handle: ReceivedRequestWriteHandle {
                request_id,
                service_id,
                closed,
                command_tx,
            },
            incoming_rx,
        }
    }

    /// ID of the request.
    pub fn request_id(&self) -> u32 {
        self.write_handle.request_id
    }

    /// Service ID the handle was created with.
    pub fn service_id(&self) -> i32 {
        self.write_handle.service_id
    }

    /// Get a clone of the send-only handle for this request.
    pub fn write_handle(&self) -> ReceivedRequestWriteHandle<Body> {
        Clone::clone(&self.write_handle)
    }

    /// Receive the next message (update) for the request.
    ///
    /// Returns `None` once the channel is exhausted or a `Close` command arrives. A
    /// `Close` command also closes the incoming channel.
    pub async fn recv_update(&mut self) -> Option<Message<Body>> {
        match self.incoming_rx.recv().await {
            Some(RequestHandleCommand::Message(message)) => Some(message),
            Some(RequestHandleCommand::Close) => {
                self.incoming_rx.close();
                None
            }
            None => None,
        }
    }

    /// Send a responder update; see `ReceivedRequestWriteHandle::send_update`.
    pub async fn send_update(&self, service_id: i32, body: impl Into<Body>) -> Result<(), Error> {
        self.write_handle.send_update(service_id, body).await
    }

    /// Send the response; see `ReceivedRequestWriteHandle::send_response`.
    pub async fn send_response(&self, service_id: i32, body: impl Into<Body>) -> Result<(), Error> {
        self.write_handle.send_response(service_id, body).await
    }

    /// Send an error response; see `ReceivedRequestWriteHandle::send_error_response`.
    pub async fn send_error_response(&self, message: &str) -> Result<(), Error>
    where
        Body: crate::Body,
    {
        self.write_handle.send_error_response(message).await
    }
}
impl<Body> ReceivedRequestWriteHandle<Body> {
    /// ID of the request this handle belongs to.
    pub fn request_id(&self) -> u32 {
        self.request_id
    }

    /// Service ID the handle was created with.
    pub fn service_id(&self) -> i32 {
        self.service_id
    }

    /// Send a responder update for the request.
    pub async fn send_update(&self, service_id: i32, body: impl Into<Body>) -> Result<(), Error> {
        let update = Message::responder_update(self.request_id, service_id, body.into());
        self.send_raw_message(update).await
    }

    /// Send the response for the request.
    pub async fn send_response(&self, service_id: i32, body: impl Into<Body>) -> Result<(), Error> {
        let response = Message::response(self.request_id, service_id, body.into());
        self.send_raw_message(response).await
    }

    /// Send an error response carrying the text `message`.
    pub async fn send_error_response(&self, message: &str) -> Result<(), Error>
    where
        Body: crate::Body,
    {
        let error_response = Message::error_response(self.request_id, message);
        self.send_raw_message(error_response).await
    }

    /// Submit `message` as a `SendRawMessage` command and wait for the reported result.
    ///
    /// # Errors
    /// - `RequestClosed` if the closed flag is set.
    /// - A connection-aborted error if the command channel or the result channel is gone.
    /// - Any error reported back for the send itself.
    async fn send_raw_message(&self, message: Message<Body>) -> Result<(), Error> {
        use crate::peer::SendRawMessage;

        let already_closed = self.closed.load(Ordering::Acquire);
        if already_closed {
            return Err(InnerError::RequestClosed.into());
        }

        let (result_tx, result_rx) = oneshot::channel();
        let command = SendRawMessage { message, result_tx };
        self.command_tx
            .send(command.into())
            .map_err(|_| connection_aborted())?;

        let outcome = result_rx.await.map_err(|_| connection_aborted())?;
        outcome?;
        Ok(())
    }
}
impl<Body> std::fmt::Debug for SentRequestHandle<Body> {
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        // `Body` has no `Debug` bound, so only the identifying numbers are printed.
        let mut out = f.debug_struct("SentRequestHandle");
        out.field("request_id", &self.write_handle.request_id);
        out.field("service_id", &self.write_handle.service_id);
        out.finish_non_exhaustive()
    }
}
impl<Body> std::fmt::Debug for SentRequestWriteHandle<Body> {
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        // `Body` has no `Debug` bound, so only the identifying numbers are printed.
        let mut out = f.debug_struct("SentRequestWriteHandle");
        out.field("request_id", &self.request_id);
        out.field("service_id", &self.service_id);
        out.finish_non_exhaustive()
    }
}
impl<Body> std::fmt::Debug for ReceivedRequestHandle<Body> {
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        // `Body` has no `Debug` bound, so only the identifying numbers are printed.
        let mut out = f.debug_struct("ReceivedRequestHandle");
        out.field("request_id", &self.write_handle.request_id);
        out.field("service_id", &self.write_handle.service_id);
        out.finish_non_exhaustive()
    }
}
impl<Body> std::fmt::Debug for ReceivedRequestWriteHandle<Body> {
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        // `Body` has no `Debug` bound, so only the identifying numbers are printed.
        let mut out = f.debug_struct("ReceivedRequestWriteHandle");
        out.field("request_id", &self.request_id);
        out.field("service_id", &self.service_id);
        out.finish_non_exhaustive()
    }
}
impl<Body> std::fmt::Debug for ReceivedMessage<Body> {
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        // The request body is skipped because `Body` is not required to implement `Debug`.
        match self {
            Self::Request(handle, _) => f.write_fmt(format_args!("Request({:?})", handle)),
            Self::Stream(message) => f.write_fmt(format_args!("Stream({:?})", message)),
        }
    }
}
impl<Body> Clone for SentRequestWriteHandle<Body> {
fn clone(&self) -> Self {
Self {
request_id: self.request_id,
service_id: self.service_id,
closed: self.closed.clone(),
command_tx: self.command_tx.clone(),
}
}
}
impl<Body> Clone for ReceivedRequestWriteHandle<Body> {
fn clone(&self) -> Self {
Self {
request_id: self.request_id,
service_id: self.service_id,
closed: self.closed.clone(),
command_tx: self.command_tx.clone(),
}
}
}
#[cfg(test)]
mod test {
    use super::*;
    use crate::{Peer, UnixStreamTransport};
    use tokio::net::UnixStream;
    use assert2::{assert, let_assert};

    /// Updates flow both ways until the response is sent. After that, neither side of
    /// the request may send anything more, and both peer tasks shut down cleanly once
    /// all handles are dropped.
    #[tokio::test]
    async fn test_closed_after_response() {
        let_assert!(Ok((peer_a, peer_b)) = UnixStream::pair());
        let (peer_a, handle_a) = Peer::new(UnixStreamTransport::new(peer_a, Default::default()));
        let (peer_b, mut handle_b) = Peer::new(UnixStreamTransport::new(peer_b, Default::default()));
        let task_a = tokio::spawn(peer_a.run());
        let task_b = tokio::spawn(peer_b.run());

        let_assert!(Ok(mut sent_request) = handle_a.send_request(1, &[2][..]).await);
        let_assert!(Ok(ReceivedMessage::Request(mut received_request, _body)) = handle_b.recv_message().await);

        // Updates are delivered in both directions while the request is open.
        assert!(let Ok(()) = sent_request.send_update(1, vec![]).await);
        assert!(let Some(_) = received_request.recv_update().await);
        assert!(let Ok(()) = received_request.send_update(1, vec![]).await);
        assert!(let Some(_) = sent_request.recv_update().await);

        // Once the response is sent, all further sends fail on both sides.
        assert!(let Ok(()) = received_request.send_response(1, vec![]).await);
        assert!(let Err(_) = received_request.send_update(1, vec![]).await);
        assert!(let Err(_) = received_request.send_response(1, vec![]).await);
        assert!(let Ok(_) = sent_request.recv_response().await);
        assert!(let Err(_) = sent_request.send_update(1, vec![]).await);

        // Dropping every handle lets both peer tasks finish.
        drop(handle_a);
        drop(handle_b);
        drop(sent_request);
        drop(received_request);
        assert!(let Ok(()) = task_a.await);
        assert!(let Ok(()) = task_b.await);
    }
}