use alloc::collections::BTreeMap as HashMap;
use async_trait::async_trait;
use crate::dialect::{v1, Dialect};
use crate::{
client::{
subscription::SubscriptionTx,
sync::{unbounded, ChannelRx, ChannelTx},
transport::router::SubscriptionRouter,
Client,
},
event::Event,
prelude::*,
query::Query,
request::SimpleRequest,
utils::uuid_str,
Error, Method, Request, Response, Subscription, SubscriptionClient,
};
/// A mock RPC client that answers requests from a [`MockRequestMatcher`]
/// and forwards subscription-related commands to a paired [`MockClientDriver`].
#[derive(Debug)]
pub struct MockClient<M>
where
    M: MockRequestMatcher,
{
    /// Source of canned responses for [`Client::perform`].
    matcher: M,
    /// Command channel to the driver created alongside this client.
    driver_tx: ChannelTx<DriverCommand>,
}
#[async_trait]
impl<M: MockRequestMatcher> Client for MockClient<M> {
async fn perform<R>(&self, request: R) -> Result<R::Output, Error>
where
R: SimpleRequest<v1::Dialect>,
{
self.matcher
.response_for(request)
.ok_or_else(Error::mismatch_response)?
.map(Into::into)
}
}
impl<M: MockRequestMatcher> MockClient<M> {
    /// Creates a mock client backed by `matcher`, together with the driver that
    /// services its subscription commands.
    ///
    /// The returned driver only processes commands while its
    /// [`MockClientDriver::run`] future is being polled (e.g. after spawning it
    /// on a runtime).
    pub fn new(matcher: M) -> (Self, MockClientDriver) {
        let (driver_tx, driver_rx) = unbounded();
        (
            Self { matcher, driver_tx },
            MockClientDriver::new(driver_rx),
        )
    }

    /// Sends a clone of `ev` to the driver for delivery to subscriptions.
    ///
    /// # Panics
    ///
    /// Panics if the driver is no longer receiving commands. This is a test
    /// helper, so a silently dropped event would hide a broken test setup.
    pub fn publish(&self, ev: &Event) {
        self.driver_tx
            .send(DriverCommand::Publish(Box::new(ev.clone())))
            .expect("mock client driver is not running; cannot publish event");
    }

    /// Asks the driver to terminate and consumes the client.
    pub fn close(self) {
        // If the driver has already stopped (or was dropped), there is nothing
        // left to terminate, so a failed send is not worth panicking over.
        let _ = self.driver_tx.send(DriverCommand::Terminate);
    }
}
#[async_trait]
impl<M: MockRequestMatcher> SubscriptionClient for MockClient<M> {
async fn subscribe(&self, query: Query) -> Result<Subscription, Error> {
let id = uuid_str();
let (subs_tx, subs_rx) = unbounded();
let (result_tx, mut result_rx) = unbounded();
self.driver_tx.send(DriverCommand::Subscribe {
id: id.clone(),
query: query.clone(),
subscription_tx: subs_tx,
result_tx,
})?;
result_rx.recv().await.unwrap()?;
Ok(Subscription::new(id, query, subs_rx))
}
async fn unsubscribe(&self, query: Query) -> Result<(), Error> {
let (result_tx, mut result_rx) = unbounded();
self.driver_tx
.send(DriverCommand::Unsubscribe { query, result_tx })?;
result_rx.recv().await.unwrap()
}
fn close(self) -> Result<(), Error> {
Ok(())
}
}
/// Commands sent from a [`MockClient`] to its [`MockClientDriver`].
#[derive(Debug)]
pub enum DriverCommand {
    /// Register a subscription under `id` for `query`. Events for it are sent
    /// through `subscription_tx`; the outcome of the registration is reported
    /// on `result_tx`.
    Subscribe {
        id: String,
        query: Query,
        subscription_tx: SubscriptionTx,
        result_tx: ChannelTx<Result<(), Error>>,
    },
    /// Remove the subscriptions registered for `query`; the outcome is
    /// reported on `result_tx`.
    Unsubscribe {
        query: Query,
        result_tx: ChannelTx<Result<(), Error>>,
    },
    /// Hand an event to the driver's subscription router for delivery.
    // NOTE(review): boxed presumably to keep the enum small — confirm.
    Publish(Box<Event>),
    /// Stop the driver's run loop.
    Terminate,
}
/// Background task state that executes the [`DriverCommand`]s sent by a
/// [`MockClient`].
#[derive(Debug)]
pub struct MockClientDriver {
    /// Tracks registered subscriptions and routes published events to them.
    router: SubscriptionRouter,
    /// Incoming commands from the paired client.
    rx: ChannelRx<DriverCommand>,
}
impl MockClientDriver {
pub fn new(rx: ChannelRx<DriverCommand>) -> Self {
Self {
router: SubscriptionRouter::default(),
rx,
}
}
pub async fn run(mut self) -> Result<(), Error> {
loop {
tokio::select! {
Some(cmd) = self.rx.recv() => match cmd {
DriverCommand::Subscribe { id, query, subscription_tx, result_tx } => {
self.subscribe(id, query, subscription_tx, result_tx);
}
DriverCommand::Unsubscribe { query, result_tx } => {
self.unsubscribe(query, result_tx);
}
DriverCommand::Publish(event) => self.publish(*event),
DriverCommand::Terminate => return Ok(()),
}
}
}
}
fn subscribe(
&mut self,
id: String,
query: Query,
subscription_tx: SubscriptionTx,
result_tx: ChannelTx<Result<(), Error>>,
) {
self.router.add(id, query, subscription_tx);
result_tx.send(Ok(())).unwrap();
}
fn unsubscribe(&mut self, query: Query, result_tx: ChannelTx<Result<(), Error>>) {
self.router.remove_by_query(query);
result_tx.send(Ok(())).unwrap();
}
fn publish(&mut self, event: Event) {
self.router.publish_event(event);
}
}
/// Supplies canned responses for requests issued through a [`MockClient`].
pub trait MockRequestMatcher: Send + Sync {
    /// Returns the response configured for `request`.
    ///
    /// `None` means no response is configured; [`MockClient`] reports that to
    /// its caller as a mismatched-response error. `Some(Err(_))` is passed
    /// through to the caller as-is.
    fn response_for<R, S>(&self, request: R) -> Option<Result<R::Response, Error>>
    where
        R: Request<S>,
        S: Dialect;
}
/// A [`MockRequestMatcher`] that picks a response purely by the request's
/// [`Method`], ignoring its parameters.
#[derive(Debug, Default)]
pub struct MockRequestMethodMatcher {
    /// Per-method canned result: either a JSON string to be parsed into the
    /// request's response type, or an error returned (cloned) to the caller.
    mappings: HashMap<Method, Result<String, Error>>,
}
impl MockRequestMatcher for MockRequestMethodMatcher {
    /// Looks up the canned result registered for `request`'s method.
    ///
    /// Returns `None` when nothing is registered for that method. A registered
    /// JSON string is parsed into `R::Response`; a registered error is cloned.
    fn response_for<R, S>(&self, request: R) -> Option<Result<R::Response, Error>>
    where
        R: Request<S>,
        S: Dialect,
    {
        let method = request.method();
        let canned = self.mappings.get(&method)?;
        let response = match canned {
            Ok(json) => R::Response::from_string(json),
            Err(e) => Err(e.clone()),
        };
        Some(response)
    }
}
impl MockRequestMethodMatcher {
#[allow(dead_code)]
pub fn map(mut self, method: Method, response: Result<String, Error>) -> Self {
self.mappings.insert(method, response);
self
}
}
#[cfg(test)]
mod test {
    use std::path::PathBuf;

    use cometbft::{block::Height, chain::Id};
    use futures::StreamExt;
    use tokio::fs;

    use super::*;
    use crate::query::EventType;

    /// Reads `./tests/kvstore_fixtures/<version>/incoming/<name>.json`.
    async fn read_json_fixture(version: &str, name: &str) -> String {
        fs::read_to_string(
            PathBuf::from("./tests/kvstore_fixtures")
                .join(version)
                .join("incoming")
                .join(name.to_owned() + ".json"),
        )
        .await
        .unwrap()
    }

    /// Drives request/response mocking with the `version` fixtures and checks
    /// the decoded ABCI info and block.
    async fn check_mock_client(version: &str, expected_abci_data: &str) {
        let abci_info_fixture = read_json_fixture(version, "abci_info").await;
        let block_fixture = read_json_fixture(version, "block_at_height_10").await;
        let matcher = MockRequestMethodMatcher::default()
            .map(Method::AbciInfo, Ok(abci_info_fixture))
            .map(Method::Block, Ok(block_fixture));
        let (client, driver) = MockClient::new(matcher);
        let driver_hdl = tokio::spawn(async move { driver.run().await });

        let abci_info = client.abci_info().await.unwrap();
        assert_eq!(abci_info.data, expected_abci_data);

        let block = client.block(Height::from(10_u32)).await.unwrap().block;
        assert_eq!(Height::from(10_u32), block.header.height);
        assert_eq!("dockerchain".parse::<Id>().unwrap(), block.header.chain_id);

        client.close();
        driver_hdl.await.unwrap().unwrap();
    }

    /// Opens two `NewBlock` subscriptions, publishes `events`, and checks that
    /// BOTH subscriptions receive exactly those events, in order.
    async fn check_mock_subscription_client(events: &[Event]) {
        let (client, driver) = MockClient::new(MockRequestMethodMatcher::default());
        let driver_hdl = tokio::spawn(async move { driver.run().await });

        let subs1 = client.subscribe(EventType::NewBlock.into()).await.unwrap();
        let subs2 = client.subscribe(EventType::NewBlock.into()).await.unwrap();
        assert_ne!(subs1.id().to_string(), subs2.id().to_string());

        let subs1_events = subs1.take(events.len());
        let subs2_events = subs2.take(events.len());
        for ev in events {
            client.publish(ev);
        }

        let subs1_events = subs1_events.collect::<Vec<Result<Event, Error>>>().await;
        let subs2_events = subs2_events.collect::<Vec<Result<Event, Error>>>().await;
        // Previously only the first subscription's contents were checked.
        for received in [&subs1_events, &subs2_events] {
            assert_eq!(events.len(), received.len());
            for (expected, actual) in events.iter().zip(received.iter()) {
                assert_eq!(expected, actual.as_ref().unwrap());
            }
        }

        client.close();
        driver_hdl.await.unwrap().unwrap();
    }

    mod v0_34 {
        use super::*;
        use crate::event::v0_34::DeEvent;

        async fn read_event(name: &str) -> Event {
            let msg = DeEvent::from_string(read_json_fixture("v0_34", name).await).unwrap();
            msg.into()
        }

        #[tokio::test]
        async fn mock_client() {
            check_mock_client("v0_34", "{\"size\":0}").await;
        }

        #[tokio::test]
        async fn mock_subscription_client() {
            let events = vec![
                read_event("subscribe_newblock_0").await,
                read_event("subscribe_newblock_1").await,
                read_event("subscribe_newblock_2").await,
            ];
            check_mock_subscription_client(&events).await;
        }
    }

    mod v0_37 {
        use super::*;
        use crate::event::v0_37::DeEvent;

        async fn read_event(name: &str) -> Event {
            let msg = DeEvent::from_string(read_json_fixture("v0_37", name).await).unwrap();
            msg.into()
        }

        #[tokio::test]
        async fn mock_client() {
            check_mock_client("v0_37", "{\"size\":9}").await;
        }

        #[tokio::test]
        async fn mock_subscription_client() {
            let events = vec![
                read_event("subscribe_newblock_0").await,
                read_event("subscribe_newblock_1").await,
                read_event("subscribe_newblock_2").await,
            ];
            check_mock_subscription_client(&events).await;
        }
    }
}