use std::collections::VecDeque;
use std::sync::Arc;

use async_trait::async_trait;
use tokio::io::{duplex, DuplexStream};
use tokio::sync::Mutex;

use crate::client::{OpenStream, PacketReceiver};
use crate::error::{Error, Result};
use crate::rpc::PacketWriter;
use crate::transport::create_packet_channel;
pub fn create_pipe(buffer_size: usize) -> (DuplexStream, DuplexStream) {
duplex(buffer_size)
}
/// Creates an in-memory pipe pair with a 64 KiB buffer.
///
/// Convenience wrapper around [`create_pipe`].
pub fn create_pipe_default() -> (DuplexStream, DuplexStream) {
    const DEFAULT_BUFFER_SIZE: usize = 64 * 1024;
    create_pipe(DEFAULT_BUFFER_SIZE)
}
pub struct InMemoryOpener {
streams: Arc<Mutex<Vec<DuplexStream>>>,
}
impl InMemoryOpener {
pub fn new(streams: Vec<DuplexStream>) -> Self {
Self {
streams: Arc::new(Mutex::new(streams)),
}
}
pub fn single(stream: DuplexStream) -> Self {
Self::new(vec![stream])
}
pub async fn add_stream(&self, stream: DuplexStream) {
self.streams.lock().await.push(stream);
}
}
#[async_trait]
impl OpenStream for InMemoryOpener {
async fn open_stream(&self) -> Result<(Arc<dyn PacketWriter>, PacketReceiver)> {
let mut streams = self.streams.lock().await;
let stream = streams.pop().ok_or(Error::StreamClosed)?;
let (read_half, write_half) = tokio::io::split(stream);
Ok(create_packet_channel(read_half, write_half))
}
}
pub struct SingleInMemoryOpener {
stream: Mutex<Option<DuplexStream>>,
}
impl SingleInMemoryOpener {
pub fn new(stream: DuplexStream) -> Self {
Self {
stream: Mutex::new(Some(stream)),
}
}
}
#[async_trait]
impl OpenStream for SingleInMemoryOpener {
async fn open_stream(&self) -> Result<(Arc<dyn PacketWriter>, PacketReceiver)> {
let stream = self
.stream
.lock()
.await
.take()
.ok_or(Error::StreamClosed)?;
let (read_half, write_half) = tokio::io::split(stream);
Ok(create_packet_channel(read_half, write_half))
}
}
/// Builds a connected client/server pair backed by an in-memory pipe.
///
/// The client end is created by [`create_pipe_default`] and wrapped in a
/// [`SingleInMemoryOpener`]. The server end is returned as a raw
/// [`DuplexStream`], ready to be handed to a server.
pub fn create_test_pair() -> (SingleInMemoryOpener, DuplexStream) {
    let (client_end, server_end) = create_pipe_default();
    let opener = SingleInMemoryOpener::new(client_end);
    (opener, server_end)
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::handler::Handler;
    use crate::invoker::Invoker;
    use crate::mux::Mux;
    use crate::server::Server;
    use crate::stream::Stream;
    use crate::{Client, SrpcClient};
    use std::time::Duration;

    /// Upper bound on how long a test waits for an echoed response, so a
    /// broken transport fails the test instead of hanging it.
    const RECV_TIMEOUT: Duration = Duration::from_secs(5);

    /// Test service: for method `Echo`, reads one message and sends it back.
    struct EchoHandler;

    #[async_trait]
    impl Invoker for EchoHandler {
        async fn invoke_method(
            &self,
            _service_id: &str,
            method_id: &str,
            stream: Box<dyn Stream>,
        ) -> (bool, Result<()>) {
            if method_id != "Echo" {
                return (false, Err(Error::Unimplemented));
            }
            let request = match stream.recv_bytes().await {
                Ok(b) => b,
                Err(e) => return (true, Err(e)),
            };
            if let Err(e) = stream.send_bytes(request).await {
                return (true, Err(e));
            }
            (true, Ok(()))
        }
    }

    impl Handler for EchoHandler {
        fn service_id(&self) -> &'static str {
            "test.Echo"
        }
        fn method_ids(&self) -> &'static [&'static str] {
            &["Echo"]
        }
    }

    /// Starts a server with [`EchoHandler`] registered, serving `server_stream`
    /// on a background task. Callers should abort the returned handle when done.
    fn spawn_echo_server(server_stream: DuplexStream) -> tokio::task::JoinHandle<()> {
        let mux = Arc::new(Mux::new());
        mux.register(Arc::new(EchoHandler)).unwrap();
        let server = Server::with_arc(mux);
        tokio::spawn(async move {
            let _ = server.handle_stream(server_stream).await;
        })
    }

    #[tokio::test]
    async fn test_in_memory_echo() {
        let (client_stream, server_stream) = create_pipe_default();
        let server_handle = spawn_echo_server(server_stream);
        let opener = SingleInMemoryOpener::new(client_stream);
        let client = SrpcClient::new(opener);
        let stream = client
            .new_stream("test.Echo", "Echo", Some(b"hello"))
            .await
            .unwrap();
        stream.close_send().await.unwrap();
        let response = tokio::time::timeout(RECV_TIMEOUT, stream.recv_bytes())
            .await
            .expect("timeout")
            .expect("recv_bytes failed");
        assert_eq!(&response[..], b"hello");
        // The server task may keep running while the client is alive; stop it
        // explicitly instead of waiting a fixed delay for it to exit.
        server_handle.abort();
    }

    #[tokio::test]
    async fn test_create_test_pair() {
        let (opener, server_stream) = create_test_pair();
        let server_handle = spawn_echo_server(server_stream);
        let client = SrpcClient::new(opener);
        let stream = client
            .new_stream("test.Echo", "Echo", Some(b"test data"))
            .await
            .unwrap();
        stream.close_send().await.unwrap();
        // Bounded wait so a transport bug fails the test rather than hanging it.
        let response = tokio::time::timeout(RECV_TIMEOUT, stream.recv_bytes())
            .await
            .expect("timeout")
            .expect("recv_bytes failed");
        assert_eq!(&response[..], b"test data");
        server_handle.abort();
    }

    /// The multi-stream opener yields one channel per queued stream, then
    /// reports `StreamClosed` once exhausted.
    #[tokio::test]
    async fn test_multi_stream_opener() {
        let (stream1, _) = create_pipe_default();
        let (stream2, _) = create_pipe_default();
        let opener = InMemoryOpener::new(vec![stream1, stream2]);
        let result1 = opener.open_stream().await;
        assert!(result1.is_ok());
        let result2 = opener.open_stream().await;
        assert!(result2.is_ok());
        let result3 = opener.open_stream().await;
        assert!(matches!(result3, Err(Error::StreamClosed)));
    }
}