use std::{
sync::{
Arc,
RwLock,
atomic::{AtomicUsize, Ordering},
},
task::{Context, Poll},
};
use bytes::Bytes;
use futures::future;
use tower::{Service, util::BoxService};
use crate::{
message::MessageExt,
peer_manager::{PeerFeatures, create_test_peer},
protocol::{
ProtocolId,
rpc::{
Request,
Response,
RpcError,
RpcStatus,
body::{Body, ClientStreaming},
client::RpcClient,
context::RpcCommsBackend,
message::RpcMethod,
server::{NamedProtocolService, RpcServerError},
},
},
test_utils::{
build_peer_manager,
mocks::{ConnectivityManagerMockState, create_connectivity_mock},
},
};
#[derive(Clone, Default)]
pub struct MockRpcService {
state: MockRpcServiceState,
}
impl NamedProtocolService for MockRpcService {
const PROTOCOL_NAME: &'static [u8] = b"rpc-mock";
}
impl MockRpcService {
pub fn new() -> Self {
Default::default()
}
pub fn shared_state(&self) -> MockRpcServiceState {
self.state.clone()
}
}
impl Service<ProtocolId> for MockRpcService {
    type Error = RpcServerError;
    type Future = future::Ready<Result<Self::Response, Self::Error>>;
    type Response = BoxService<Request<Bytes>, Response<Body>, RpcStatus>;

    /// The mock never applies backpressure; it is always ready to hand out a handler.
    fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        Poll::Ready(Ok(()))
    }

    /// Produces a per-protocol request handler, ignoring which protocol was requested.
    ///
    /// Each request bumps the shared call counter and immediately resolves to the currently
    /// configured response.
    fn call(&mut self, _protocol: ProtocolId) -> Self::Future {
        let shared = self.state.clone();
        let handler = move |_req: Request<Bytes>| {
            shared.inc_call_count();
            future::ready(shared.get_response())
        };
        let boxed = BoxService::new(tower::service_fn(handler));
        future::ready(Ok(boxed))
    }
}
/// State shared between a [`MockRpcService`] and the test that configures it.
///
/// Both fields live behind an `Arc`. Cloning this value yields another handle to the same counter
/// and the same configured response.
#[derive(Clone, Debug)]
pub struct MockRpcServiceState {
    /// Number of requests the mock service has handled.
    call_count: Arc<AtomicUsize>,
    /// Response returned for every request; replaced through the `set_response*` methods.
    response: Arc<RwLock<Result<Response<Bytes>, RpcStatus>>>,
}

impl Default for MockRpcServiceState {
    /// Starts with a call count of zero and a "not implemented" error as the response. A test
    /// that never configures a response therefore receives an error status.
    fn default() -> Self {
        let unconfigured = RpcStatus::not_implemented("Mock service not implemented");
        Self {
            call_count: Arc::new(AtomicUsize::default()),
            response: Arc::new(RwLock::new(Err(unconfigured))),
        }
    }
}
impl MockRpcServiceState {
    /// Records one more handled request.
    ///
    /// Returns the count as it was *before* this increment, following `AtomicUsize::fetch_add`.
    fn inc_call_count(&self) -> usize {
        self.call_count.fetch_add(1, Ordering::SeqCst)
    }

    /// Returns how many requests the mock service has handled so far.
    pub fn call_count(&self) -> usize {
        self.call_count.load(Ordering::SeqCst)
    }

    /// Returns a copy of the configured response, wrapping an `Ok` payload as a single-message
    /// [`Body`].
    ///
    /// The read guard is a temporary, so it is released before the conversion runs.
    fn get_response(&self) -> Result<Response<Body>, RpcStatus> {
        let snapshot = self.response.read().unwrap().clone();
        snapshot.map(|resp| resp.map(Body::single))
    }

    /// Replaces the response returned for all subsequent requests.
    pub fn set_response(&self, response: Result<Response<Bytes>, RpcStatus>) {
        let mut slot = self.response.write().unwrap();
        *slot = response;
    }

    /// Configures a successful response whose payload is `response` encoded via
    /// `MessageExt::to_encoded_bytes`.
    pub fn set_response_ok<T: prost::Message>(&self, response: &T) {
        let encoded: Bytes = response.to_encoded_bytes().into();
        self.set_response(Ok(Response::new(encoded)));
    }

    /// Configures every subsequent request to fail with `err`.
    pub fn set_response_err(&self, err: RpcStatus) {
        self.set_response(Err(err));
    }
}
/// Test client for the mock RPC protocol: a thin wrapper that forwards every call to an
/// [`RpcClient`].
pub struct MockRpcClient {
    inner: RpcClient,
}

impl NamedProtocolService for MockRpcClient {
    // Same name as `MockRpcService`; presumably required so this client negotiates the mock's
    // protocol — confirm if the two ever diverge.
    const PROTOCOL_NAME: &'static [u8] = b"rpc-mock";
}

impl MockRpcClient {
    /// Forwards `request` for `method` to [`RpcClient::request_response`].
    pub async fn request_response<T, R>(&mut self, request: T, method: RpcMethod) -> Result<R, RpcError>
    where
        T: prost::Message,
        R: prost::Message + Default,
    {
        self.inner.request_response(request, method).await
    }

    /// Forwards `request` for `method` to [`RpcClient::server_streaming`].
    // Allowed to be unused: presumably not every build of the tests exercises streaming.
    #[allow(dead_code)]
    pub async fn server_streaming<T, R>(
        &mut self,
        request: T,
        method: RpcMethod,
    ) -> Result<ClientStreaming<R>, RpcError>
    where
        T: prost::Message,
        R: prost::Message + Default,
    {
        self.inner.server_streaming(request, method).await
    }
}

impl From<RpcClient> for MockRpcClient {
    /// Wraps an existing [`RpcClient`] without altering it.
    fn from(client: RpcClient) -> Self {
        MockRpcClient { inner: client }
    }
}
/// Builds an [`RpcCommsBackend`] backed by a mocked connectivity manager and a peer manager
/// created for a fresh test peer.
///
/// Returns the backend together with the connectivity mock's shared state, so tests can
/// interact with the mock.
///
/// NOTE(review): `mock.spawn()` presumably spawns the mock onto an async runtime, so this helper
/// likely has to be called from within one — confirm.
///
/// # Panics
///
/// Panics if the test peer manager cannot be built.
pub(crate) fn create_mocked_rpc_context() -> (RpcCommsBackend, ConnectivityManagerMockState) {
    let (connectivity, mock) = create_connectivity_mock();
    // Grab the shared state before `spawn`, which consumes/starts the mock.
    let mock_state = mock.get_shared_state();
    mock.spawn();
    let test_peer = create_test_peer(false, PeerFeatures::COMMUNICATION_NODE);
    // A bare `unwrap` here produced a context-free panic; name the failing step instead.
    let peer_manager = build_peer_manager(&test_peer)
        .expect("create_mocked_rpc_context: failed to build test peer manager");
    (RpcCommsBackend::new(peer_manager, connectivity), mock_state)
}