use async_trait::async_trait;
use log::{debug, error, info};
use rand::{distributions::Alphanumeric, thread_rng, Rng};
use std::{
collections::HashMap,
sync::{Arc, RwLock},
};
use futures::Future;
use serde::{de::DeserializeOwned, Deserialize, Serialize};
use tokio;
use crate::{
errors::msrs::MsRsError,
transport::transport::{MsMessage, Transport},
};
/// Request/response messaging client built on a shared [`Transport`].
///
/// Cloning is cheap: every clone shares the same transport through the `Arc`.
#[derive(Clone)]
pub struct Client {
    /// Transport used to send messages, consume requests and await replies.
    transport: Arc<dyn Transport>,
    /// When `true`, `listen` returns after the first successfully parsed message
    /// instead of looping until an error occurs.
    listen_once: bool,
}
/// Returns `true` when `message` carries a `kafka_correlationId` header whose value
/// contains `req_id`.
///
/// A message without headers, or without that particular header, is never a response.
fn is_response(req_id: String, message: MsMessage) -> bool {
    message
        .0
        .as_ref()
        .and_then(|headers| headers.get("kafka_correlationId"))
        .map_or(false, |correlation| correlation.contains(req_id.as_str()))
}
impl Client {
pub fn new(transport: Arc<dyn Transport>, listen_once: bool) -> Self {
Self {
transport,
listen_once,
}
}
pub async fn send_sync<Request: Serialize, Response: DeserializeOwned>(
&self,
event: &str,
data: &mut Request,
) -> Result<Response, MsRsError> {
let request_id: String = thread_rng()
.sample_iter(&Alphanumeric)
.take(30)
.map(char::from)
.collect();
let predicate_req_id_copy = request_id.clone();
info!("RequestId: {:?}", request_id);
let predicate =
Box::new(move |msg: MsMessage| is_response(predicate_req_id_copy.clone(), msg));
let result_future = self.transport.listen_response(&event, predicate);
let mut headers: HashMap<String, String> = HashMap::new();
headers.insert("kafka_correlationId".to_string(), request_id);
let _ = self
.transport
.send(
event,
MsMessage(
Some(headers),
serde_json::to_string::<Request>(&data).unwrap(),
),
)
.await;
let result = result_future.await?;
let result = result.1.clone();
serde_json::from_str(&result).map_err(|err| {
error!("Error while parsing result {}: {}", result, err);
MsRsError::SendError("Response parse error".to_string())
})
}
pub async fn listen<Request, Response, Fut>(
&self,
event: &str,
f: Box<dyn Fn(Request) -> Fut + Sync + Send + 'static>,
) -> Result<(), MsRsError>
where
Fut: Future<Output = Result<Option<Response>, MsRsError>> + Send + 'static,
Request: DeserializeOwned,
Response: Serialize,
{
loop {
let data = self.transport.consume(event).await?;
let resp_topic = &format!("{}.reply", event);
debug!("Received message {:?}", data.1);
let req = serde_json::from_str::<Request>(&data.1);
if req.is_err() {
error!("Unable to parse request {}, skipping...", event);
continue;
}
let req = req.unwrap();
let result = f(req).await;
match result {
Ok(result) => {
if let Some(result) = result {
self.send(resp_topic, data.0, &result).await?;
}
}
Err(error) => {
info!("Get an error: {}", error.to_string());
let mut headers: HashMap<String, String> = HashMap::new();
headers.insert(
"kafka_nest-err".to_string(),
format!("{{\"type\":\"rpc\",\"message\":\"{}\"}}", error.to_string()),
);
headers.insert("msrs_err".to_string(), error.to_string());
let correlation_id_name = "kafka_correlationId".to_string();
let in_headers = data.0.unwrap();
let correlation_id = in_headers.get(&correlation_id_name);
if let Some(correlation_id) = correlation_id {
headers.insert(correlation_id_name, correlation_id.to_owned());
} else {
println!(
"No correlation_id header found, sending general message to {} topic",
resp_topic
);
}
self.send(resp_topic, Some(headers), {}).await?;
}
}
if self.listen_once {
break Ok(());
}
}
}
pub async fn send<Request: Serialize>(
&self,
event: &str,
headers: Option<HashMap<String, String>>,
data: Request,
) -> Result<(), MsRsError> {
self.transport
.send(
event,
MsMessage(headers, serde_json::to_string(&data).unwrap()),
)
.await
}
}
/// In-memory stand-in for a real transport, used by the tests in this file.
///
/// It records sent payloads instead of talking to a broker.
#[derive(Debug, Default)]
pub struct MockTransport {
    /// JSON payloads of every message passed to `send`, in call order.
    pub sent_data: RwLock<Vec<String>>,
    /// JSON payloads sent to topics whose name contains `.reply`.
    pub reply_results: RwLock<Vec<String>>,
}
impl MockTransport {
fn _new() -> Self {
MockTransport {
sent_data: RwLock::new(Vec::new()),
reply_results: RwLock::new(Vec::new()),
}
}
}
#[async_trait]
impl Transport for MockTransport {
async fn send(&self, event: &str, data: MsMessage) -> Result<(), MsRsError> {
let mut data_write = self.sent_data.write().unwrap();
data_write.push(data.1.to_string());
if event.contains(".reply") {
let mut reply_results_write = self.reply_results.write().unwrap();
reply_results_write.push(data.1.to_string());
}
Ok(())
}
async fn listen_response(
&self,
_event: &str,
_predicate: Box<dyn Fn(MsMessage) -> bool + Send + Sync>,
) -> Result<MsMessage, MsRsError> {
let data_read = self.sent_data.read().unwrap();
Ok(MsMessage(None, data_read[0].clone()))
}
async fn consume(&self, _event: &str) -> Result<MsMessage, MsRsError> {
let data_read = self.sent_data.read().unwrap();
Ok(MsMessage(None, data_read[0].clone()))
}
}
/// Minimal `{"id": ...}` JSON payload echoed back by the listen test.
#[derive(Serialize, Deserialize, Debug)]
struct TestData {
    id: String,
}
/// `Client::send` serializes its payload as JSON and hands it to the transport.
#[tokio::test]
async fn test_event_send() {
    let mock = Arc::new(MockTransport::_new());
    let client = Client::new(mock.clone(), true);
    client
        .send("test_event", None, "test_data")
        .await
        .unwrap();
    let recorded = mock.sent_data.read().unwrap();
    assert_eq!(recorded[0], "\"test_data\"");
}
/// A one-shot listener whose handler echoes the request publishes exactly one reply.
#[tokio::test]
async fn test_event_listen() {
    let mock = Arc::new(MockTransport::_new());
    let client = Client::new(mock.clone(), true);
    // Seed the mock: its `consume` replays the first payload ever sent.
    let seed = MsMessage(None, String::from("{\"id\": \"123\"}"));
    mock.send("test_event", seed).await.unwrap();
    client
        .listen(
            "test_event",
            Box::new(|req: TestData| async move { Ok(Some(req)) }),
        )
        .await
        .unwrap();
    let reply_count = mock.reply_results.read().unwrap().len();
    assert_eq!(reply_count, 1)
}