msrs 0.1.31

Micro Microservices framework for Rust. Supports different transports.
Documentation
use async_trait::async_trait;
use log::{debug, error, info};
use rand::{distributions::Alphanumeric, thread_rng, Rng};

use std::{
    collections::HashMap,
    sync::{Arc, RwLock},
};

use futures::Future;
use serde::{de::DeserializeOwned, Deserialize, Serialize};

use tokio;

use crate::{
    errors::msrs::MsRsError,
    transport::transport::{MsMessage, Transport},
};

/// Request/response client built on top of a shared [`Transport`].
///
/// Cloning a `Client` is cheap: clones share the same underlying transport.
#[derive(Clone)]
pub struct Client {
    /// Transport used for every send, consume and listen operation.
    transport: Arc<dyn Transport>,
    /// When `true`, `Client::listen` returns after the first message it
    /// successfully parses and handles, instead of looping forever.
    listen_once: bool,
}

/// Reports whether `message` is the reply to the request identified by `req_id`.
///
/// A message matches when it carries a `kafka_correlationId` header whose value
/// contains `req_id` as a substring. Messages without headers, or without that
/// header, never match.
fn is_response(req_id: String, message: MsMessage) -> bool {
    let correlation = message
        .0
        .as_ref()
        .and_then(|headers| headers.get("kafka_correlationId"));

    match correlation {
        Some(value) => value.contains(req_id.as_str()),
        None => false,
    }
}

impl Client {
    pub fn new(transport: Arc<dyn Transport>, listen_once: bool) -> Self {
        Self {
            transport,
            listen_once,
        }
    }

    pub async fn send_sync<Request: Serialize, Response: DeserializeOwned>(
        &self,
        event: &str,
        data: &mut Request,
    ) -> Result<Response, MsRsError> {
        let request_id: String = thread_rng()
            .sample_iter(&Alphanumeric)
            .take(30)
            .map(char::from)
            .collect();

        let predicate_req_id_copy = request_id.clone();
        info!("RequestId: {:?}", request_id);
        let predicate =
            Box::new(move |msg: MsMessage| is_response(predicate_req_id_copy.clone(), msg));
        let result_future = self.transport.listen_response(&event, predicate);

        let mut headers: HashMap<String, String> = HashMap::new();
        headers.insert("kafka_correlationId".to_string(), request_id);
        let _ = self
            .transport
            .send(
                event,
                MsMessage(
                    Some(headers),
                    serde_json::to_string::<Request>(&data).unwrap(),
                ),
            )
            .await;
        let result = result_future.await?;

        let result = result.1.clone();

        serde_json::from_str(&result).map_err(|err| {
            error!("Error while parsing result {}: {}", result, err);
            MsRsError::SendError("Response parse error".to_string())
        })
    }

    pub async fn listen<Request, Response, Fut>(
        &self,
        event: &str,
        f: Box<dyn Fn(Request) -> Fut + Sync + Send + 'static>,
    ) -> Result<(), MsRsError>
    where
        Fut: Future<Output = Result<Option<Response>, MsRsError>> + Send + 'static,
        Request: DeserializeOwned,
        Response: Serialize,
    {
        loop {
            let data = self.transport.consume(event).await?;
            let resp_topic = &format!("{}.reply", event);

            debug!("Received message {:?}", data.1);

            let req = serde_json::from_str::<Request>(&data.1);
            if req.is_err() {
                error!("Unable to parse request {}, skipping...", event);

                continue;
            }
            let req = req.unwrap();

            let result = f(req).await;
            match result {
                Ok(result) => {
                    if let Some(result) = result {
                        self.send(resp_topic, data.0, &result).await?;
                    }
                }
                Err(error) => {
                    info!("Get an error: {}", error.to_string());

                    let mut headers: HashMap<String, String> = HashMap::new();
                    headers.insert(
                        "kafka_nest-err".to_string(),
                        format!("{{\"type\":\"rpc\",\"message\":\"{}\"}}", error.to_string()),
                    );
                    headers.insert("msrs_err".to_string(), error.to_string());

                    let correlation_id_name = "kafka_correlationId".to_string();
                    let in_headers = data.0.unwrap();
                    let correlation_id = in_headers.get(&correlation_id_name);
                    if let Some(correlation_id) = correlation_id {
                        headers.insert(correlation_id_name, correlation_id.to_owned());
                    } else {
                        println!(
                            "No correlation_id header found, sending general message to {} topic",
                            resp_topic
                        );
                    }

                    self.send(resp_topic, Some(headers), {}).await?;
                }
            }

            if self.listen_once {
                break Ok(());
            }
        }
    }

    pub async fn send<Request: Serialize>(
        &self,
        event: &str,
        headers: Option<HashMap<String, String>>,
        data: Request,
    ) -> Result<(), MsRsError> {
        self.transport
            .send(
                event,
                MsMessage(headers, serde_json::to_string(&data).unwrap()),
            )
            .await
    }
}

/// In-memory [`Transport`] double that records what is sent through it.
///
/// `Default` yields a mock with both logs empty.
#[derive(Debug, Default)]
pub struct MockTransport {
    /// Body of every message passed to `send`, in call order.
    pub sent_data: RwLock<Vec<String>>,
    /// Bodies of messages sent to a topic whose name contains `.reply`.
    pub reply_results: RwLock<Vec<String>>,
}

impl MockTransport {
    fn _new() -> Self {
        MockTransport {
            sent_data: RwLock::new(Vec::new()),
            reply_results: RwLock::new(Vec::new()),
        }
    }
}

#[async_trait]
impl Transport for MockTransport {
    /// Appends the message body to `sent_data`. If `event` contains `.reply`,
    /// the body is also appended to `reply_results`.
    async fn send(&self, event: &str, data: MsMessage) -> Result<(), MsRsError> {
        let body = &data.1;

        let mut sent = self.sent_data.write().unwrap();
        sent.push(body.clone());

        if event.contains(".reply") {
            self.reply_results.write().unwrap().push(body.clone());
        }

        Ok(())
    }

    /// Ignores `_event` and `_predicate`, and returns the first body ever sent,
    /// without headers. Panics if nothing has been sent yet.
    async fn listen_response(
        &self,
        _event: &str,
        _predicate: Box<dyn Fn(MsMessage) -> bool + Send + Sync>,
    ) -> Result<MsMessage, MsRsError> {
        let first_body = self.sent_data.read().unwrap()[0].clone();
        Ok(MsMessage(None, first_body))
    }

    /// Ignores `_event` and replays the first body ever sent, without headers.
    /// Panics if nothing has been sent yet.
    async fn consume(&self, _event: &str) -> Result<MsMessage, MsRsError> {
        let first_body = self.sent_data.read().unwrap()[0].clone();
        Ok(MsMessage(None, first_body))
    }
}

/// Minimal JSON payload that `test_event_listen` round-trips through the client.
#[derive(Serialize, Deserialize, Debug)]
struct TestData {
    id: String,
}

/// `Client::send` should hand the JSON-encoded payload to the transport.
#[tokio::test]
async fn test_event_send() {
    let mock = Arc::new(MockTransport::_new());
    let client = Client::new(mock.clone(), true);

    client.send("test_event", None, "test_data").await.unwrap();

    let recorded = mock.sent_data.read().unwrap();
    assert_eq!(recorded.first().map(String::as_str), Some("\"test_data\""));
}

/// `Client::listen` should parse a consumed request, run the handler, and
/// publish the handler's result on the reply topic.
#[tokio::test]
async fn test_event_listen() {
    let mock = Arc::new(MockTransport::_new());
    let client = Client::new(mock.clone(), true);

    // Seed the mock: its `consume` replays the first body that was sent.
    let seed = MsMessage(None, "{\"id\": \"123\"}".to_string());
    mock.send("test_event", seed).await.unwrap();

    client
        .listen(
            "test_event",
            Box::new(|request: TestData| async move { Ok(Some(request)) }),
        )
        .await
        .unwrap();

    let reply_count = mock.reply_results.read().unwrap().len();
    assert_eq!(reply_count, 1);
}