use tokio::sync::mpsc;
use tokio::sync::oneshot;
use crate::error::private::connection_aborted;
use crate::peer::{Command, SendRawMessage, SendRequest};
use crate::{Error, Message, ReceivedMessage, SentRequestHandle};
/// Combined handle to a peer, made up of a read half and a write half.
///
/// The two halves can be used together through this type,
/// or separated with [`PeerHandle::split`].
pub struct PeerHandle<Body> {
/// Half used to receive incoming messages.
read_handle: PeerReadHandle<Body>,
/// Half used to send requests and stream messages.
write_handle: PeerWriteHandle<Body>,
}
/// Read half of a peer handle, used to receive incoming messages.
///
/// Dropping this handle sends `Command::UnregisterReadHandle` on the command channel.
pub struct PeerReadHandle<Body> {
/// Channel on which incoming messages (or errors) are delivered.
incoming_rx: mpsc::UnboundedReceiver<Result<ReceivedMessage<Body>, Error>>,
/// Channel for sending commands.
/// NOTE(review): presumably consumed by the peer's command loop; confirm in `crate::peer`.
command_tx: mpsc::UnboundedSender<Command<Body>>,
}
/// Write half of a peer handle, used to send requests and stream messages.
///
/// Cloning this handle sends `Command::RegisterWriteHandle` and dropping it sends
/// `Command::UnregisterWriteHandle` on the command channel.
pub struct PeerWriteHandle<Body> {
/// Channel for sending commands.
/// NOTE(review): presumably consumed by the peer's command loop; confirm in `crate::peer`.
command_tx: mpsc::UnboundedSender<Command<Body>>,
}
/// Handle whose only operation is sending `Command::Stop` on the command channel.
///
/// Obtained from the `close_handle()` method of the other handle types.
/// It can be cloned freely, regardless of whether `Body` is `Clone`.
pub struct PeerCloseHandle<Body> {
    /// Channel used to send the stop command.
    command_tx: mpsc::UnboundedSender<Command<Body>>,
}

// Implemented by hand: `#[derive(Clone)]` would add a `Body: Clone` bound,
// even though cloning only duplicates the channel sender.
impl<Body> Clone for PeerCloseHandle<Body> {
    fn clone(&self) -> Self {
        Self {
            command_tx: self.command_tx.clone(),
        }
    }
}

// Same opaque debug representation as the other handle types in this file.
impl<Body> std::fmt::Debug for PeerCloseHandle<Body> {
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        f.debug_struct(core::any::type_name::<Self>())
            .finish_non_exhaustive()
    }
}
impl<Body> PeerHandle<Body> {
    /// Build a combined handle from the incoming-message receiver and the command sender.
    ///
    /// The command sender is cloned so that each half owns its own copy.
    pub(crate) fn new(
        incoming_rx: mpsc::UnboundedReceiver<Result<ReceivedMessage<Body>, Error>>,
        command_tx: mpsc::UnboundedSender<Command<Body>>,
    ) -> Self {
        Self {
            read_handle: PeerReadHandle {
                incoming_rx,
                command_tx: command_tx.clone(),
            },
            write_handle: PeerWriteHandle { command_tx },
        }
    }

    /// Take the handle apart into its read half and its write half.
    pub fn split(self) -> (PeerReadHandle<Body>, PeerWriteHandle<Body>) {
        let Self { read_handle, write_handle } = self;
        (read_handle, write_handle)
    }

    /// Receive the next incoming message, delegating to the read half.
    pub async fn recv_message(&mut self) -> Result<ReceivedMessage<Body>, Error> {
        let reader = &mut self.read_handle;
        reader.recv_message().await
    }

    /// Send a request for the given service, delegating to the write half.
    pub async fn send_request(&self, service_id: i32, body: impl Into<Body>) -> Result<SentRequestHandle<Body>, Error> {
        let writer = &self.write_handle;
        writer.send_request(service_id, body).await
    }

    /// Send a stream message for the given service, delegating to the write half.
    pub async fn send_stream(&self, service_id: i32, body: impl Into<Body>) -> Result<(), Error> {
        let writer = &self.write_handle;
        writer.send_stream(service_id, body).await
    }

    /// Send the stop command and consume the handle.
    pub fn close(self) {
        // `self` is kept whole on purpose: both halves are dropped when this function
        // returns, read half first, which is when their `Drop` impls run.
        self.read_handle.close();
    }

    /// Create a handle that can only be used to send the stop command.
    pub fn close_handle(&self) -> PeerCloseHandle<Body> {
        let reader = &self.read_handle;
        reader.close_handle()
    }
}
impl<Body> PeerReadHandle<Body> {
    /// Receive the next incoming message.
    ///
    /// Returns the delivered message or error as-is. If the incoming channel yields
    /// nothing more, the error produced by `connection_aborted` is returned instead.
    pub async fn recv_message(&mut self) -> Result<ReceivedMessage<Body>, Error> {
        let next = self.incoming_rx.recv().await;
        // `None` means the channel is closed; the `?` unwraps the outer layer
        // and leaves the delivered `Result` as the return value.
        next.ok_or_else(connection_aborted)?
    }

    /// Send the stop command on the command channel.
    pub fn close(&self) {
        // A failed send is deliberately ignored.
        let _ = self.command_tx.send(Command::Stop);
    }

    /// Create a handle that can only be used to send the stop command.
    pub fn close_handle(&self) -> PeerCloseHandle<Body> {
        let command_tx = self.command_tx.clone();
        PeerCloseHandle { command_tx }
    }
}
impl<Body> Drop for PeerReadHandle<Body> {
    /// Announce on the command channel that this read handle is going away.
    fn drop(&mut self) {
        // A failed send is deliberately ignored.
        let _ = self.command_tx.send(Command::UnregisterReadHandle);
    }
}
impl<Body> PeerWriteHandle<Body> {
pub async fn send_request(&self, service_id: i32, body: impl Into<Body>) -> Result<SentRequestHandle<Body>, Error> {
let body = body.into();
let (result_tx, result_rx) = oneshot::channel();
self.command_tx
.send(SendRequest { service_id, body, result_tx }.into())
.map_err(|_| connection_aborted())?;
result_rx.await.map_err(|_| connection_aborted())?
}
pub async fn send_stream(&self, service_id: i32, body: impl Into<Body>) -> Result<(), Error> {
let body = body.into();
let (result_tx, result_rx) = oneshot::channel();
let message = Message::stream(0, service_id, body);
self.command_tx
.send(SendRawMessage { message, result_tx }.into())
.map_err(|_| connection_aborted())?;
result_rx.await.map_err(|_| connection_aborted())?
}
pub fn close(&self) {
let _: Result<_, _> = self.command_tx.send(Command::Stop);
}
pub fn close_handle(&self) -> PeerCloseHandle<Body> {
PeerCloseHandle {
command_tx: self.command_tx.clone(),
}
}
pub fn same_peer(&self, other: &Self) -> bool {
self.command_tx.same_channel(&other.command_tx)
}
}
impl<Body> Clone for PeerWriteHandle<Body> {
    /// Duplicate the handle and announce the new write handle on the command channel.
    fn clone(&self) -> Self {
        let duplicate = Self {
            command_tx: self.command_tx.clone(),
        };
        // A failed send is deliberately ignored.
        let _ = duplicate.command_tx.send(Command::RegisterWriteHandle);
        duplicate
    }
}
impl<Body> Drop for PeerWriteHandle<Body> {
    /// Announce on the command channel that this write handle is going away.
    fn drop(&mut self) {
        // A failed send is deliberately ignored.
        let _ = self.command_tx.send(Command::UnregisterWriteHandle);
    }
}
impl<Body> PeerCloseHandle<Body> {
    /// Send the stop command on the command channel.
    pub fn close(&self) {
        // A failed send is deliberately ignored.
        let _ = self.command_tx.send(Command::Stop);
    }
}
impl<Body> std::fmt::Debug for PeerHandle<Body> {
    /// Print only the type name; the fields are left out.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let type_name = core::any::type_name::<Self>();
        let mut builder = f.debug_struct(type_name);
        builder.finish_non_exhaustive()
    }
}
impl<Body> std::fmt::Debug for PeerReadHandle<Body> {
    /// Print only the type name; the fields are left out.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let type_name = core::any::type_name::<Self>();
        let mut builder = f.debug_struct(type_name);
        builder.finish_non_exhaustive()
    }
}
impl<Body> std::fmt::Debug for PeerWriteHandle<Body> {
    /// Print only the type name; the fields are left out.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let type_name = core::any::type_name::<Self>();
        let mut builder = f.debug_struct(type_name);
        builder.finish_non_exhaustive()
    }
}
#[cfg(test)]
mod test {
    use assert2::{assert, let_assert};
    use fizyr_rpc::UnixSeqpacketTransport;
    use tokio_seqpacket::UnixSeqpacket;

    /// `same_peer` must be true for a write handle and its clone,
    /// and false for write handles that belong to different peers.
    #[tokio::test]
    async fn test_same_peer() {
        let_assert!(Ok((socket_a, socket_b)) = UnixSeqpacket::pair());

        // First peer: a write handle and its clone share one command channel.
        let transport_a = UnixSeqpacketTransport::new(socket_a, Default::default());
        let (_, write_handle_a) = fizyr_rpc::UnixSeqpacketPeer::spawn(transport_a).split();
        let cloned_handle = write_handle_a.clone();
        assert!(write_handle_a.same_peer(&cloned_handle));

        // Second peer: spawned separately, so its write handle must not compare equal.
        let transport_b = UnixSeqpacketTransport::new(socket_b, Default::default());
        let (_, write_handle_b) = fizyr_rpc::UnixSeqpacketPeer::spawn(transport_b).split();
        assert!(!write_handle_a.same_peer(&write_handle_b));
    }
}