use std::sync::{atomic::AtomicBool, Arc};
use tokio::sync::oneshot;
use super::reactor::completion_reactor::{DoNothingMessageHandler, RpcCompletionReactor};
use super::reactor::completion_registry::{Completion, CompletionGuard};
use super::reactor::{
completion_streaming::StreamingCompletion, completion_unary::UnaryCompletion,
};
use crate::client::reactor::completion_reactor::{CompletableRpc, RpcNotification};
use crate::Message;
#[derive(Debug)]
pub struct RpcClient<Request, Response>
where
Request: Message,
Response: Message,
{
submission_queue: spillway::Sender<RpcNotification<Response, Request>>,
is_alive: Arc<AtomicBool>,
}
impl<Request, Response> Clone for RpcClient<Request, Response>
where
Request: Message,
Response: Message,
{
fn clone(&self) -> Self {
Self {
submission_queue: self.submission_queue.clone(),
is_alive: self.is_alive.clone(),
}
}
}
impl<Request, Response> RpcClient<Request, Response>
where
Request: Message,
Response: Message,
{
pub(crate) fn new(
submission_queue: spillway::Sender<RpcNotification<Response, Request>>,
message_reactor: &RpcCompletionReactor<
Response,
Request,
DoNothingMessageHandler<Response>,
>,
) -> Self {
Self {
submission_queue,
is_alive: message_reactor.alive_handle(),
}
}
pub fn is_alive(&self) -> bool {
self.is_alive.load(std::sync::atomic::Ordering::Relaxed)
}
#[must_use = "You must await the completion to get the response. If you drop the completion, the request will be cancelled."]
pub fn send_streaming(
&self,
request: Request,
) -> crate::Result<StreamingCompletion<Response, Request>> {
let (sender, completion) = spillway::channel_with_concurrency(1);
let completion_guard = self.send_message(Completion::RemoteStreaming(sender), request)?;
let completion = StreamingCompletion::new(completion, completion_guard);
Ok(completion)
}
#[must_use = "You must await the completion to get the response. If you drop the completion, the request will be cancelled."]
pub fn send_unary(
&self,
request: Request,
) -> crate::Result<UnaryCompletion<Response, Request>> {
let (completor, completion) = oneshot::channel();
let completion_guard = self.send_message(Completion::Unary(completor), request)?;
let completion = UnaryCompletion::new(completion, completion_guard);
Ok(completion)
}
fn send_message(
&self,
completion: Completion<Response>,
request: Request,
) -> crate::Result<CompletionGuard<Response, Request>> {
if !self.is_alive.load(std::sync::atomic::Ordering::Relaxed) {
return Err(crate::Error::ConnectionIsClosed);
}
let message_id = request.message_id();
let completion_guard = CompletionGuard::new(message_id, self.submission_queue.clone());
self.submission_queue
.send(RpcNotification::New(CompletableRpc {
message_id,
completion,
request,
}))
.map_err(|_e| crate::Error::ConnectionIsClosed)
.map(|_| completion_guard)
}
}
#[cfg(test)]
mod test {
use std::future::Future;
use std::pin::pin;
use std::sync::atomic::AtomicBool;
use std::sync::Arc;
use std::sync::Mutex;
use std::task::Context;
use std::task::Poll;
use futures::task::noop_waker_ref;
use crate::client::connection_pool::ClientConnector;
use crate::client::connection_pool::ConnectionPool;
use crate::client::reactor::completion_reactor::CompletableRpc;
use crate::client::reactor::completion_reactor::DoNothingMessageHandler;
use crate::client::reactor::completion_reactor::RpcCompletionReactor;
use crate::client::reactor::completion_reactor::RpcNotification;
use crate::Message;
use super::RpcClient;
impl Message for u64 {
fn message_id(&self) -> u64 {
*self & 0xffffffff
}
fn control_code(&self) -> crate::ProtosocketControlCode {
match *self >> 32 {
0 => crate::ProtosocketControlCode::Normal,
1 => crate::ProtosocketControlCode::Cancel,
2 => crate::ProtosocketControlCode::End,
_ => unreachable!("invalid control code"),
}
}
fn set_message_id(&mut self, message_id: u64) {
*self = (*self & 0xf00000000) | message_id;
}
fn cancelled(message_id: u64) -> Self {
(1_u64 << 32) | message_id
}
fn ended(message_id: u64) -> Self {
(2 << 32) | message_id
}
}
fn drive_future<F: Future>(f: F) -> F::Output {
let mut f = pin!(f);
loop {
let next = f.as_mut().poll(&mut Context::from_waker(noop_waker_ref()));
if let Poll::Ready(result) = next {
break result;
}
}
}
#[allow(clippy::type_complexity)]
fn get_client() -> (
spillway::Receiver<RpcNotification<u64, u64>>,
RpcClient<u64, u64>,
RpcCompletionReactor<u64, u64, DoNothingMessageHandler<u64>>,
) {
let (sender, remote_end) = spillway::channel();
let rpc_reactor =
RpcCompletionReactor::<u64, u64, _>::new(DoNothingMessageHandler::default());
let client = RpcClient::new(sender, &rpc_reactor);
(remote_end, client, rpc_reactor)
}
#[test]
fn unary_drop_cancel() {
let (mut remote_end, client, _reactor) = get_client();
let response = client.send_unary(4).expect("can send");
let notification = drive_future(remote_end.next()).expect("a request is sent");
assert!(matches!(
notification,
RpcNotification::New(CompletableRpc { message_id: 4, .. })
));
drop(response);
let cancellation = drive_future(remote_end.next()).expect("a cancel is sent");
assert!(matches!(cancellation, RpcNotification::Cancel(4)));
}
#[test]
fn streaming_drop_cancel() {
let (mut remote_end, client, _reactor) = get_client();
let response = client.send_streaming(4).expect("can send");
let notification = drive_future(remote_end.next()).expect("a request is sent");
assert!(matches!(
notification,
RpcNotification::New(CompletableRpc { message_id: 4, .. })
));
drop(response);
let cancellation = drive_future(remote_end.next()).expect("a cancel is sent");
assert!(matches!(cancellation, RpcNotification::Cancel(4)));
}
#[allow(clippy::type_complexity)]
#[derive(Default)]
struct TestConnector {
clients: Mutex<
Vec<(
spillway::Receiver<RpcNotification<u64, u64>>,
RpcClient<u64, u64>,
RpcCompletionReactor<u64, u64, DoNothingMessageHandler<u64>>,
)>,
>,
fail_connections: AtomicBool,
}
impl ClientConnector for Arc<TestConnector> {
type Request = u64;
type Response = u64;
async fn connect(self) -> crate::Result<RpcClient<Self::Request, Self::Response>> {
if self
.fail_connections
.load(std::sync::atomic::Ordering::Relaxed)
{
return Err(crate::Error::IoFailure(Arc::new(std::io::Error::other(
"simulated connection failure",
))));
}
let (remote_end, client, reactor) = get_client();
self.clients
.lock()
.expect("mutex works")
.push((remote_end, client.clone(), reactor));
Ok(client)
}
}
#[tokio::test]
async fn connection_pool() {
let connector = Arc::new(TestConnector::default());
let pool = ConnectionPool::new(connector.clone(), 1);
let rpc_client_a = pool
.get_connection()
.await
.expect("can get a connection from the pool");
assert_eq!(
1,
connector.clients.lock().expect("mutex works").len(),
"one connection created"
);
let rpc_client_b = pool
.get_connection()
.await
.expect("can get a connection from the pool");
assert_eq!(
1,
connector.clients.lock().expect("mutex works").len(),
"still one connection created"
);
assert!(
Arc::ptr_eq(&rpc_client_a.is_alive, &rpc_client_b.is_alive),
"same connection shared"
);
let _reply_a = rpc_client_a.send_unary(42).expect("can send");
let _reply_b = rpc_client_b.send_unary(43).expect("can send");
let (mut remote_end, _client, _reactor) = {
let mut clients = connector.clients.lock().expect("mutex works");
clients.pop().expect("one client exists")
};
assert!(matches!(
remote_end.next().await.expect("request a is received"),
RpcNotification::New(CompletableRpc { message_id: 42, .. })
));
assert!(matches!(
remote_end.next().await.expect("request b is received"),
RpcNotification::New(CompletableRpc { message_id: 43, .. })
));
}
#[tokio::test]
async fn connection_pool_reconnect() {
let connector = Arc::new(TestConnector::default());
let pool = ConnectionPool::new(connector.clone(), 1);
let rpc_client_a = pool
.get_connection()
.await
.expect("can get a connection from the pool");
assert_eq!(
1,
connector.clients.lock().expect("mutex works").len(),
"one connection created"
);
rpc_client_a
.is_alive
.store(false, std::sync::atomic::Ordering::Relaxed);
let rpc_client_b = pool.get_connection().await.expect("can get a connection from the pool even when the previous connection is dead, as long as the connection attempt succeeds");
assert_eq!(
2,
connector.clients.lock().expect("mutex works").len(),
"a new connection was created, so the connector was asked to make a new connection"
);
assert!(
!Arc::ptr_eq(&rpc_client_a.is_alive, &rpc_client_b.is_alive),
"new connection created"
);
}
#[tokio::test]
async fn connection_pool_failure() {
let connector = Arc::new(TestConnector::default());
let pool = ConnectionPool::new(connector.clone(), 1);
connector
.fail_connections
.store(true, std::sync::atomic::Ordering::Relaxed);
pool.get_connection().await.expect_err("connection attempt fails, and the calling code gets the error. It does not try forever without surfacing errors.");
}
#[tokio::test]
async fn connection_pool_reconnect_failure_recovery() {
let connector = Arc::new(TestConnector::default());
let pool = ConnectionPool::new(connector.clone(), 1);
let rpc_client_a = pool
.get_connection()
.await
.expect("can get a connection from the pool");
rpc_client_a
.is_alive
.store(false, std::sync::atomic::Ordering::Relaxed);
connector
.fail_connections
.store(true, std::sync::atomic::Ordering::Relaxed);
pool.get_connection().await.expect_err("connection attempt fails, and the calling code gets the error. It does not try forever without surfacing errors.");
connector
.fail_connections
.store(false, std::sync::atomic::Ordering::Relaxed);
let rpc_client_b = pool
.get_connection()
.await
.expect("can get a connection from the pool now");
assert!(
!Arc::ptr_eq(&rpc_client_a.is_alive, &rpc_client_b.is_alive),
"new connection created"
);
}
}