use std::{
collections::HashMap,
future,
sync::Arc,
task::{Context, Poll},
};
use async_trait::async_trait;
use bytes::Bytes;
use futures::future::BoxFuture;
use tokio::{
sync::{Mutex, RwLock, mpsc},
task,
};
use tower::{Service, make::MakeService};
use crate::{
NodeIdentity,
PeerConnection,
PeerManager,
Substream,
connectivity::ConnectivitySelection,
peer_manager::{NodeId, Peer},
protocol::{
ProtocolEvent,
ProtocolId,
ProtocolNotification,
ProtocolNotificationTx,
rpc::{
Body,
NamedProtocolService,
Request,
Response,
RpcError,
RpcServer,
RpcStatus,
Streaming,
context::{RequestContext, RpcCommsBackend, RpcCommsProvider},
server::{PeerRpcServer, RpcServerError, handle::RpcServerRequest},
},
},
test_utils::mocks::{ConnectivityManagerMockState, create_connectivity_mock, create_peer_connection_mock_pair},
utils,
};
/// Test helper for building RPC [`Request`]s.
///
/// Requests can be built bare, or with a [`RequestContext`]. The context's comms backend wraps
/// the supplied [`PeerManager`] together with a spawned mock connectivity manager.
pub struct RpcRequestMock {
    comms_provider: RpcCommsBackend,
    // Not read anywhere in this file; presumably retained so the spawned connectivity mock's
    // shared state stays reachable for future use.
    #[allow(dead_code)]
    connectivity_mock_state: ConnectivityManagerMockState,
}

impl RpcRequestMock {
    /// Spawns a mock connectivity manager and combines it with `peer_manager` to form the comms
    /// backend used by [`RpcRequestMock::request_with_context`].
    pub fn new(peer_manager: Arc<PeerManager>) -> Self {
        let (connectivity, mock) = create_connectivity_mock();
        let mock_state = mock.get_shared_state();
        mock.spawn();

        let comms_provider = RpcCommsBackend::new(peer_manager, connectivity);
        Self {
            comms_provider,
            connectivity_mock_state: mock_state,
        }
    }

    /// Returns the peer manager held by the underlying comms backend.
    pub fn peer_manager(&self) -> &PeerManager {
        self.comms_provider.peer_manager()
    }

    /// Wraps `msg` in a request whose context is attributed to `node_id`.
    ///
    /// The context and request are built with fixed `0` values (context first argument and
    /// method respectively). The context carries a clone of this mock's comms backend.
    pub fn request_with_context<T>(&self, node_id: NodeId, msg: T) -> Request<T> {
        let ctx = RequestContext::new(0, node_id, Box::new(self.comms_provider.clone()));
        Request::with_context(ctx, 0.into(), msg)
    }

    /// Wraps `msg` in a request that carries no context (method `0`).
    pub fn request_no_context<T>(&self, msg: T) -> Request<T> {
        let method = 0.into();
        Request::new(method, msg)
    }
}
/// Default implementations for mocked RPC methods, driven by a [`RpcMockMethodState`].
///
/// Implementors typically forward each mocked RPC method to one of these helpers.
#[async_trait]
pub trait RpcMock {
    /// Records the request message in `method_state` and returns a clone of the configured
    /// unary response.
    ///
    /// # Errors
    /// Returns the configured `RpcStatus` if the mock's response has been set to an error.
    async fn request_response<TReq, TResp>(
        &self,
        request: Request<TReq>,
        method_state: &RpcMockMethodState<TReq, TResp>,
    ) -> Result<Response<TResp>, RpcStatus>
    where
        TReq: Send + Sync,
        TResp: Send + Sync + Clone,
    {
        method_state.requests.write().await.push(request.into_message());
        let resp = method_state.response.read().await.clone()?;
        Ok(Response::new(resp))
    }

    /// Records the request message in `method_state` and returns a stream that yields every item
    /// of the configured response vector, in order, and then ends.
    ///
    /// # Errors
    /// Returns the configured `RpcStatus` if the mock's response has been set to an error.
    ///
    /// # Panics
    /// Panics with "send error" if queuing the items into the stream's channel fails.
    async fn server_streaming<TReq, TResp>(
        &self,
        request: Request<TReq>,
        method_state: &RpcMockMethodState<TReq, Vec<TResp>>,
    ) -> Result<Streaming<TResp>, RpcStatus>
    where
        TReq: Send + Sync,
        TResp: Send + Sync + Clone,
    {
        method_state.requests.write().await.push(request.into_message());
        let resp = method_state.response.read().await.clone()?;
        // The channel is sized to hold every item, so they can all be queued before the
        // receiver is handed out. `tokio::sync::mpsc::channel` panics on a zero capacity, so an
        // empty response (the `Default` state) still needs a buffer of at least one.
        let (tx, rx) = mpsc::channel(resp.len().max(1));
        // Presumably a `match` rather than `expect` so the send error type need not be `Debug`.
        #[allow(clippy::match_wild_err_arm)]
        match utils::mpsc::send_all(&tx, resp.into_iter().map(Ok)).await {
            Ok(_) => {},
            Err(_) => panic!("send error"),
        }
        // `tx` is dropped on return, so the stream ends after the queued items.
        Ok(Streaming::new(rx))
    }
}
/// Shared state for a single mocked RPC method.
///
/// It holds the message of every request the method has received and the response it will
/// return. Clones share the same underlying state through `Arc`.
#[derive(Debug, Clone)]
pub struct RpcMockMethodState<TReq, TResp> {
    requests: Arc<RwLock<Vec<TReq>>>,
    response: Arc<RwLock<Result<TResp, RpcStatus>>>,
}

impl<TReq, TResp> RpcMockMethodState<TReq, TResp> {
    /// Returns the number of request messages recorded so far.
    pub async fn request_count(&self) -> usize {
        let recorded = self.requests.read().await;
        recorded.len()
    }

    /// Replaces the response that subsequent calls will return.
    pub async fn set_response(&self, response: Result<TResp, RpcStatus>) {
        let mut slot = self.response.write().await;
        *slot = response;
    }
}

impl<TReq, TResp> Default for RpcMockMethodState<TReq, TResp>
where
    TResp: Default,
{
    /// Starts with no recorded requests and an `Ok(TResp::default())` response.
    fn default() -> Self {
        let initial: Result<TResp, RpcStatus> = Ok(TResp::default());
        Self {
            requests: Arc::new(RwLock::new(Vec::new())),
            response: Arc::new(RwLock::new(initial)),
        }
    }
}
/// Placeholder [`RpcCommsProvider`] used by [`MockRpcServer`].
///
/// None of its methods are implemented. Calling any of them panics via `unimplemented!()`.
#[derive(Debug, Clone)]
pub struct MockCommsProvider;

#[async_trait]
impl RpcCommsProvider for MockCommsProvider {
    /// Not supported by this mock; always panics.
    async fn fetch_peer(&self, _node_id: &NodeId) -> Result<Peer, RpcError> {
        unimplemented!()
    }

    /// Not supported by this mock; always panics.
    async fn dial_peer(&mut self, _node_id: &NodeId) -> Result<PeerConnection, RpcError> {
        unimplemented!()
    }

    /// Not supported by this mock; always panics.
    async fn select_connections(&mut self, _selection: ConnectivitySelection) -> Result<Vec<PeerConnection>, RpcError> {
        unimplemented!()
    }
}
/// In-process RPC server for tests.
///
/// It wraps a [`PeerRpcServer`] for `TSvc` and feeds it inbound substreams from mock peer
/// connections created with [`MockRpcServer::create_connection`].
pub struct MockRpcServer<TSvc> {
    inner: Option<PeerRpcServer<TSvc, MockCommsProvider>>,
    protocol_tx: ProtocolNotificationTx<Substream>,
    our_node: Arc<NodeIdentity>,
    // Not used directly; presumably held so the server's request channel is never closed by
    // all senders being dropped.
    #[allow(dead_code)]
    request_tx: mpsc::Sender<RpcServerRequest>,
}

impl<TSvc> MockRpcServer<TSvc>
where
    TSvc: MakeService<
            ProtocolId,
            Request<Bytes>,
            MakeError = RpcServerError,
            Response = Response<Body>,
            Error = RpcStatus,
        > + Send
        + Sync
        + 'static,
    TSvc::Service: Send + 'static,
    <TSvc::Service as Service<Request<Bytes>>>::Future: Send + 'static,
    TSvc::Future: Send + 'static,
{
    /// Builds a server for `service` running as `our_node`.
    ///
    /// The server is not started; call [`MockRpcServer::serve`] to run it.
    pub fn new(service: TSvc, our_node: Arc<NodeIdentity>) -> Self {
        let (protocol_tx, protocol_rx) = mpsc::channel(1);
        let (request_tx, request_rx) = mpsc::channel(1);

        let server = PeerRpcServer::new(
            RpcServer::builder(),
            service,
            protocol_rx,
            MockCommsProvider,
            request_rx,
        );

        Self {
            inner: Some(server),
            protocol_tx,
            our_node,
            request_tx,
        }
    }

    /// Creates a mock connection pair between `peer` and our node, and returns the peer's side.
    ///
    /// A background task forwards every substream that arrives on our side of the connection
    /// to the server as a `NewInboundSubstream` notification for `protocol_id`. That task
    /// panics if the notification cannot be sent (e.g. the server's receiver has been dropped).
    pub async fn create_connection(&self, peer: Peer, protocol_id: ProtocolId) -> PeerConnection {
        let peer_node_id = peer.node_id.clone();
        let our_peer = self.our_node.to_peer();
        let (_, our_side, peer_side, _) = create_peer_connection_mock_pair(peer, our_peer).await;

        let notify_tx = self.protocol_tx.clone();
        task::spawn(async move {
            while let Some(incoming) = our_side.next_incoming_substream().await {
                let proto = protocol_id.clone();
                let event = ProtocolEvent::NewInboundSubstream(peer_node_id.clone(), incoming);
                notify_tx.send(ProtocolNotification::new(proto, event)).await.unwrap();
            }
        });

        peer_side
    }

    /// Spawns the wrapped server on the tokio runtime and returns its join handle.
    ///
    /// # Panics
    /// Panics if called more than once.
    pub fn serve(&mut self) -> task::JoinHandle<Result<(), RpcServerError>> {
        let server = self.inner.take().expect("can only call `serve` once");
        task::spawn(server.serve())
    }
}
impl MockRpcServer<MockRpcImpl> {
    /// Convenience wrapper around [`MockRpcServer::create_connection`] that passes
    /// `ProtocolId::new()` as the protocol.
    ///
    /// `MockRpcImpl` ignores the protocol id when it creates its service.
    pub async fn create_mockimpl_connection(&self, peer: Peer) -> PeerConnection {
        let protocol = ProtocolId::new();
        self.create_connection(peer, protocol).await
    }
}
/// Minimal RPC service for tests.
///
/// It answers a method with a canned response if that method id has an entry in its shared
/// state. Clones share the same state.
#[derive(Clone, Default)]
pub struct MockRpcImpl {
    state: Arc<Mutex<State>>,
}

/// State shared by every clone of a [`MockRpcImpl`].
#[derive(Default)]
struct State {
    /// Canned responses, keyed by RPC method id.
    accepted_calls: HashMap<u32, Response<Bytes>>,
}

impl MockRpcImpl {
    /// Creates a service with no accepted calls.
    pub fn new() -> Self {
        Self::default()
    }
}
impl Service<Request<Bytes>> for MockRpcImpl {
type Error = RpcStatus;
type Future = BoxFuture<'static, Result<Response<Body>, RpcStatus>>;
type Response = Response<Body>;
fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
Poll::Ready(Ok(()))
}
fn call(&mut self, req: Request<Bytes>) -> Self::Future {
let state = self.state.clone();
Box::pin(async move {
let method_id = req.method().id();
match state.lock().await.accepted_calls.get(&method_id) {
Some(resp) => Ok(resp.clone().map(Body::single)),
None => Err(RpcStatus::unsupported_method(&format!(
"Method identifier `{method_id}` is not recognised or supported"
))),
}
})
}
}
impl NamedProtocolService for MockRpcImpl {
    /// Protocol name advertised for the mock service.
    const PROTOCOL_NAME: &'static [u8] = b"mock-service";
}

/// Service factory: every protocol id yields a clone of this `MockRpcImpl`, sharing its state.
impl Service<ProtocolId> for MockRpcImpl {
    type Error = RpcServerError;
    type Future = future::Ready<Result<Self::Response, Self::Error>>;
    type Response = Self;

    /// Always ready.
    fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        Poll::Ready(Ok(()))
    }

    /// Ignores the protocol id and immediately resolves to a clone of `self`.
    fn call(&mut self, _protocol: ProtocolId) -> Self::Future {
        let svc = self.clone();
        future::ready(Ok(svc))
    }
}