use async_trait::async_trait;
use bytes::Bytes;
use prost::Message;
use std::sync::Arc;
use crate::error::{Error, Result};
use crate::rpc::{ClientRpc, PacketWriter};
use crate::stream::{Context, Stream, StreamExt};
use crate::transport::create_packet_channel;
/// Receiving half of the packet channel belonging to one opened stream.
pub type PacketReceiver = tokio::sync::mpsc::Receiver<crate::proto::Packet>;

/// Opens the packet streams that an RPC client such as [`SrpcClient`] runs over.
#[async_trait]
pub trait OpenStream: Send + Sync {
    /// Opens a stream, returning the writer used for outgoing packets together with
    /// the receiver on which incoming packets are delivered.
    async fn open_stream(&self) -> Result<(Arc<dyn PacketWriter>, PacketReceiver)>;
}
/// RPC client able to perform unary calls and to open raw byte streams.
#[async_trait]
pub trait Client: Send + Sync {
    /// Calls `method` on `service` with the protobuf-encoded `input` and returns the
    /// decoded response of type `O`.
    ///
    /// The `Self: Sized` bound keeps this generic method out of the vtable. Without it the
    /// trait is not dyn-compatible and [`BoxClient`] could not be used at all. With it, a
    /// `Box<dyn Client>` still exposes `new_stream`. Calling `exec_call` requires a
    /// concrete client type.
    async fn exec_call<I, O>(&self, service: &str, method: &str, input: &I) -> Result<O>
    where
        Self: Sized,
        I: Message + Send + Sync,
        O: Message + Default;

    /// Opens a stream to `method` on `service`, optionally starting it with `first_msg`.
    async fn new_stream(
        &self,
        service: &str,
        method: &str,
        first_msg: Option<&[u8]>,
    ) -> Result<Box<dyn Stream>>;
}

/// Type-erased [`Client`].
///
/// Only `new_stream` is callable through this alias, because `exec_call` is generic and
/// therefore requires a `Sized` receiver.
pub type BoxClient = Box<dyn Client>;
pub struct SrpcClient<T: OpenStream> {
opener: T,
}
impl<T: OpenStream> SrpcClient<T> {
pub fn new(opener: T) -> Self {
Self { opener }
}
}
#[async_trait]
impl<T: OpenStream + 'static> Client for SrpcClient<T> {
async fn exec_call<I, O>(&self, service: &str, method: &str, input: &I) -> Result<O>
where
I: Message + Send + Sync,
O: Message + Default,
{
let input_data = input.encode_to_vec();
let (writer, mut receiver) = self.opener.open_stream().await?;
let ctx = Context::new();
let rpc = Arc::new(ClientRpc::new(
ctx.clone(),
service.to_string(),
method.to_string(),
writer,
));
rpc.start(Some(Bytes::from(input_data))).await?;
let rpc_clone = rpc.clone();
let packet_handler = tokio::spawn(async move {
while let Some(packet) = receiver.recv().await {
if rpc_clone.handle_packet(packet).await.is_err() {
break;
}
}
let _ = rpc_clone.handle_stream_close(None).await;
});
rpc.close_send().await?;
let output: O = rpc.msg_recv().await?;
let _ = rpc.wait().await;
let _ = rpc.close().await;
packet_handler.abort();
Ok(output)
}
async fn new_stream(
&self,
service: &str,
method: &str,
first_msg: Option<&[u8]>,
) -> Result<Box<dyn Stream>> {
let (writer, mut receiver) = self.opener.open_stream().await?;
let ctx = Context::new();
let rpc = Arc::new(ClientRpc::new(
ctx.clone(),
service.to_string(),
method.to_string(),
writer,
));
let first_data = first_msg.map(|d| Bytes::from(d.to_vec()));
rpc.start(first_data).await?;
let rpc_clone = rpc.clone();
let packet_handler = tokio::spawn(async move {
while let Some(packet) = receiver.recv().await {
if rpc_clone.handle_packet(packet).await.is_err() {
break;
}
}
let _ = rpc_clone.handle_stream_close(None).await;
});
Ok(Box::new(ClientStream {
rpc,
packet_handler: tokio::sync::Mutex::new(Some(packet_handler)),
}))
}
}
/// Stream handle returned by `SrpcClient::new_stream`.
///
/// Owns the RPC state and the background task that feeds inbound packets into it.
struct ClientStream {
    /// RPC state; the packet-handler task holds a second reference to it.
    rpc: Arc<ClientRpc>,
    /// Packet-handler task. It is taken and aborted by `close`, or aborted on drop.
    packet_handler: tokio::sync::Mutex<Option<tokio::task::JoinHandle<()>>>,
}

impl Drop for ClientStream {
    fn drop(&mut self) {
        // Dropping a `JoinHandle` only detaches the task. Without this, a stream dropped
        // without `close()` would leave the handler running (and holding `rpc`) until the
        // packet channel closes. `get_mut` needs no locking because we have `&mut self`.
        if let Some(handle) = self.packet_handler.get_mut().take() {
            handle.abort();
        }
    }
}
#[async_trait]
impl Stream for ClientStream {
    /// Returns the context of the underlying RPC.
    fn context(&self) -> &Context {
        self.rpc.context()
    }

    /// Forwards `data` to the RPC's `send_bytes`.
    async fn send_bytes(&self, data: Bytes) -> Result<()> {
        let rpc = &self.rpc;
        rpc.send_bytes(data).await
    }

    /// Returns the next message from the RPC's `recv_bytes`.
    async fn recv_bytes(&self) -> Result<Bytes> {
        let rpc = &self.rpc;
        rpc.recv_bytes().await
    }

    /// Half-closes the send side via the RPC's `close_send`.
    async fn close_send(&self) -> Result<()> {
        let rpc = &self.rpc;
        rpc.close_send().await
    }

    /// Closes the RPC, then aborts the packet-handler task if it is still owned here.
    ///
    /// Any error from closing the RPC is ignored, so this always returns `Ok(())`.
    async fn close(&self) -> Result<()> {
        let _ = self.rpc.close().await;
        let pending = self.packet_handler.lock().await.take();
        if let Some(task) = pending {
            task.abort();
        }
        Ok(())
    }
}
pub mod transport {
use super::*;
use std::sync::Mutex;
use tokio::io::{AsyncRead, AsyncWrite};
pub struct SingleStreamOpener<T> {
inner: Mutex<Option<T>>,
}
impl<T: AsyncRead + AsyncWrite + Send + Unpin + 'static> SingleStreamOpener<T> {
pub fn new(transport: T) -> Self {
Self {
inner: Mutex::new(Some(transport)),
}
}
}
#[async_trait]
impl<T: AsyncRead + AsyncWrite + Send + Unpin + 'static> OpenStream for SingleStreamOpener<T> {
async fn open_stream(&self) -> Result<(Arc<dyn PacketWriter>, PacketReceiver)> {
let transport = self
.inner
.lock()
.unwrap()
.take()
.ok_or(Error::StreamClosed)?;
let (read_half, write_half) = tokio::io::split(transport);
Ok(create_packet_channel(read_half, write_half))
}
}
pub use crate::transport::TransportPacketWriter;
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicBool, Ordering};
    use std::sync::Mutex;

    /// Packet writer that discards every packet and records whether `close` was called.
    struct MockWriter {
        closed: AtomicBool,
    }

    impl MockWriter {
        fn new() -> Self {
            MockWriter {
                closed: AtomicBool::new(false),
            }
        }
    }

    #[async_trait]
    impl PacketWriter for MockWriter {
        async fn write_packet(&self, _packet: crate::proto::Packet) -> Result<()> {
            Ok(())
        }

        async fn close(&self) -> Result<()> {
            self.closed.store(true, Ordering::SeqCst);
            Ok(())
        }
    }

    /// Opener that hands out one pre-built receiver and then fails with `StreamClosed`.
    struct MockOpener {
        writer: Arc<MockWriter>,
        receiver: Mutex<Option<PacketReceiver>>,
    }

    impl MockOpener {
        /// Returns the opener together with the sender feeding its receiver.
        fn new() -> (Self, tokio::sync::mpsc::Sender<crate::proto::Packet>) {
            let (sender, receiver) = tokio::sync::mpsc::channel(32);
            let opener = MockOpener {
                writer: Arc::new(MockWriter::new()),
                receiver: Mutex::new(Some(receiver)),
            };
            (opener, sender)
        }
    }

    #[async_trait]
    impl OpenStream for MockOpener {
        async fn open_stream(&self) -> Result<(Arc<dyn PacketWriter>, PacketReceiver)> {
            let taken = self.receiver.lock().unwrap().take();
            match taken {
                Some(rx) => {
                    let writer: Arc<dyn PacketWriter> = self.writer.clone();
                    Ok((writer, rx))
                }
                None => Err(Error::StreamClosed),
            }
        }
    }

    #[tokio::test]
    async fn test_client_new_stream() {
        // Bind the sender (not `_`) so the packet channel stays open for the whole test.
        let (opener, _sender) = MockOpener::new();
        let srpc = SrpcClient::new(opener);
        let opened = srpc
            .new_stream("test.Service", "TestMethod", Some(b"hello"))
            .await
            .unwrap();
        assert!(!opened.context().is_cancelled());
    }

    #[tokio::test]
    async fn test_single_stream_opener_only_once() {
        let (client_end, _server_end) = tokio::io::duplex(1024);
        let opener = transport::SingleStreamOpener::new(client_end);

        let first = opener.open_stream().await;
        assert!(first.is_ok());

        let second = opener.open_stream().await;
        assert!(matches!(second, Err(Error::StreamClosed)));
    }
}