mq-bridge 0.2.15

An asynchronous message bridging library connecting Kafka, MQTT, AMQP, NATS, MongoDB, HTTP, and more.
Documentation
use crate::models::LimiterMiddleware;
use crate::traits::{
    BoxFuture, ConsumerError, MessageConsumer, MessagePublisher, PublisherError, Received,
    ReceivedBatch, Sent, SentBatch,
};
use crate::CanonicalMessage;
use async_trait::async_trait;
use std::any::Any;
use std::sync::Mutex;
use tokio::time::{Duration, Instant};

/// Pacing schedule shared by the limiter middlewares.
///
/// It only remembers the earliest instant at which the next reservation may
/// start; every reservation pushes that instant further into the future.
#[derive(Debug)]
struct RateState {
    /// Earliest instant at which the next reserved slot may begin.
    next_allowed_at: Instant,
}

impl RateState {
    /// Creates a schedule whose first reservation can start right away.
    fn new() -> Self {
        Self {
            next_allowed_at: Instant::now(),
        }
    }

    /// Reserves `count` consecutive slots spaced `per_message` apart.
    ///
    /// Returns how long the caller has to wait before the reserved slots
    /// begin (`Duration::ZERO` when they can start immediately). A `count` of
    /// zero reserves nothing and leaves the schedule untouched.
    ///
    /// A single reservation advances the schedule by at most one hour, and
    /// the computation never panics, even for extreme `count` / `per_message`
    /// combinations.
    fn reserve(&mut self, count: usize, per_message: Duration) -> Duration {
        const MAX_DELAY: Duration = Duration::from_secs(3600);

        if count == 0 {
            return Duration::ZERO;
        }

        let now = Instant::now();
        // Slots that were already reserved by earlier calls come first.
        let start_at = self.next_allowed_at.max(now);

        // Work in f64 seconds and clamp *before* converting back: converting an
        // oversized product into a `Duration` (which is what `mul_f64` does)
        // panics, so clamping afterwards would be too late.
        let requested_secs = per_message.as_secs_f64() * count as f64;
        let additional = if requested_secs >= MAX_DELAY.as_secs_f64() {
            MAX_DELAY
        } else {
            Duration::from_secs_f64(requested_secs)
        };

        // If the instant cannot represent the new deadline, keep the schedule at
        // `start_at` instead of panicking on an overflowing addition.
        self.next_allowed_at = start_at.checked_add(additional).unwrap_or(start_at);
        start_at.saturating_duration_since(now)
    }
}

/// Consumer middleware that paces the messages pulled from an inner consumer.
pub struct LimiterConsumer {
    /// Wrapped consumer that actually receives the messages.
    inner: Box<dyn MessageConsumer>,
    /// Spacing between two messages, derived from `messages_per_second`.
    per_message: Duration,
    /// Pacing schedule; mutated directly because throttling goes through `&mut self`.
    state: RateState,
}

impl LimiterConsumer {
    /// Wraps `inner` so that it yields at most `config.messages_per_second`
    /// messages per second.
    ///
    /// # Errors
    ///
    /// Returns an error when `messages_per_second` is not a finite value
    /// greater than zero, or when it is so small that the interval between two
    /// messages cannot be represented as a [`Duration`].
    pub fn new(
        inner: Box<dyn MessageConsumer>,
        config: &LimiterMiddleware,
    ) -> anyhow::Result<Self> {
        let rate = config.messages_per_second;
        if !(rate.is_finite() && rate > 0.0) {
            return Err(anyhow::anyhow!(
                "Limiter messages_per_second must be a finite value greater than zero"
            ));
        }
        // `Duration::from_secs_f64` panics when the interval overflows (tiny but
        // positive rates such as 1e-20), so use the fallible conversion and
        // report the problem as a configuration error instead.
        let per_message = Duration::try_from_secs_f64(1.0 / rate).map_err(|_| {
            anyhow::anyhow!(
                "Limiter messages_per_second is too small to be expressed as a message interval"
            )
        })?;
        Ok(Self {
            inner,
            per_message,
            state: RateState::new(),
        })
    }

    /// Reserves `count` slots in the schedule and sleeps until they begin.
    async fn wait_for(&mut self, count: usize) {
        let delay = self.state.reserve(count, self.per_message);
        if !delay.is_zero() {
            tokio::time::sleep(delay).await;
        }
    }
}

#[async_trait]
impl MessageConsumer for LimiterConsumer {
    /// Delegates to the connect hook of the wrapped consumer.
    fn on_connect_hook(&self) -> Option<BoxFuture<'_, anyhow::Result<()>>> {
        self.inner.on_connect_hook()
    }

    /// Delegates to the disconnect hook of the wrapped consumer.
    fn on_disconnect_hook(&self) -> Option<BoxFuture<'_, anyhow::Result<()>>> {
        self.inner.on_disconnect_hook()
    }

    /// Receives one message from the wrapped consumer, then waits for a
    /// single rate-limit slot before handing the message to the caller.
    ///
    /// Errors from the wrapped consumer are returned immediately, without
    /// consuming a slot.
    async fn receive(&mut self) -> Result<Received, ConsumerError> {
        let message = self.inner.receive().await?;
        // Throttling happens after the message has arrived.
        self.wait_for(1).await;
        Ok(message)
    }

    /// Receives a batch from the wrapped consumer, then waits for one
    /// rate-limit slot per message contained in that batch.
    async fn receive_batch(&mut self, max_messages: usize) -> Result<ReceivedBatch, ConsumerError> {
        let received = self.inner.receive_batch(max_messages).await?;
        // The cost of a batch is the number of messages actually delivered,
        // not the requested maximum.
        let delivered = received.messages.len();
        self.wait_for(delivered).await;
        Ok(received)
    }

    /// Exposes the limiter itself for downcasting.
    fn as_any(&self) -> &dyn Any {
        self
    }
}

/// Publisher middleware that paces the messages handed to an inner publisher.
pub struct LimiterPublisher {
    /// Wrapped publisher that actually sends the messages.
    inner: Box<dyn MessagePublisher>,
    /// Spacing between two messages, derived from `messages_per_second`.
    per_message: Duration,
    /// Pacing schedule; behind a mutex because throttling goes through `&self`.
    state: Mutex<RateState>,
}

impl LimiterPublisher {
    /// Wraps `inner` so that it is handed at most
    /// `config.messages_per_second` messages per second.
    ///
    /// # Errors
    ///
    /// Returns an error when `messages_per_second` is not a finite value
    /// greater than zero, or when it is so small that the interval between two
    /// messages cannot be represented as a [`Duration`].
    pub fn new(
        inner: Box<dyn MessagePublisher>,
        config: &LimiterMiddleware,
    ) -> anyhow::Result<Self> {
        let rate = config.messages_per_second;
        if !(rate.is_finite() && rate > 0.0) {
            return Err(anyhow::anyhow!(
                "Limiter messages_per_second must be a finite value greater than zero"
            ));
        }
        // `Duration::from_secs_f64` panics when the interval overflows (tiny but
        // positive rates such as 1e-20), so use the fallible conversion and
        // report the problem as a configuration error instead.
        let per_message = Duration::try_from_secs_f64(1.0 / rate).map_err(|_| {
            anyhow::anyhow!(
                "Limiter messages_per_second is too small to be expressed as a message interval"
            )
        })?;
        Ok(Self {
            inner,
            per_message,
            state: Mutex::new(RateState::new()),
        })
    }

    /// Reserves `count` slots in the schedule and sleeps until they begin.
    ///
    /// The mutex guard is a temporary of the `let` statement, so the lock is
    /// released before the task goes to sleep.
    ///
    /// # Panics
    ///
    /// Panics if the schedule mutex has been poisoned.
    async fn wait_for(&self, count: usize) {
        let delay = self
            .state
            .lock()
            .expect("Limiter mutex poisoned")
            .reserve(count, self.per_message);
        if !delay.is_zero() {
            tokio::time::sleep(delay).await;
        }
    }
}

#[async_trait]
impl MessagePublisher for LimiterPublisher {
    /// Delegates to the connect hook of the wrapped publisher.
    fn on_connect_hook(&self) -> Option<BoxFuture<'_, anyhow::Result<()>>> {
        self.inner.on_connect_hook()
    }

    /// Delegates to the disconnect hook of the wrapped publisher.
    fn on_disconnect_hook(&self) -> Option<BoxFuture<'_, anyhow::Result<()>>> {
        self.inner.on_disconnect_hook()
    }

    /// Waits for a single rate-limit slot, then forwards `message` to the
    /// wrapped publisher and returns its result unchanged.
    async fn send(&self, message: CanonicalMessage) -> Result<Sent, PublisherError> {
        // Throttling happens before the message is handed over.
        self.wait_for(1).await;
        let outcome = self.inner.send(message).await;
        outcome
    }

    /// Waits for one rate-limit slot per message in `messages`, then forwards
    /// the whole batch to the wrapped publisher and returns its result
    /// unchanged.
    async fn send_batch(
        &self,
        messages: Vec<CanonicalMessage>,
    ) -> Result<SentBatch, PublisherError> {
        let pending = messages.len();
        self.wait_for(pending).await;
        let outcome = self.inner.send_batch(messages).await;
        outcome
    }

    /// Exposes the limiter itself for downcasting.
    fn as_any(&self) -> &dyn Any {
        self
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::traits::MessagePublisher;
    use crate::CanonicalMessage;
    use async_trait::async_trait;
    use std::collections::VecDeque;
    use std::sync::{Arc, Mutex as StdMutex};

    /// Builds a commit callback that acknowledges a batch without doing any work.
    fn ack_commit() -> crate::traits::BatchCommitFunc {
        Box::new(|_| Box::pin(async { Ok(()) }))
    }

    /// Consumer stub that hands out pre-built batches in order.
    struct MockConsumer {
        batches: VecDeque<Vec<CanonicalMessage>>,
    }

    #[async_trait]
    impl MessageConsumer for MockConsumer {
        /// Returns the next queued batch; panics once the queue is exhausted.
        async fn receive_batch(
            &mut self,
            _max_messages: usize,
        ) -> Result<ReceivedBatch, ConsumerError> {
            let messages = self.batches.pop_front().expect("batch already consumed");
            Ok(ReceivedBatch {
                messages,
                commit: ack_commit(),
            })
        }

        fn as_any(&self) -> &dyn Any {
            self
        }
    }

    /// Publisher stub that records every message it is asked to send.
    #[derive(Clone)]
    struct MockPublisher {
        sent: Arc<StdMutex<Vec<CanonicalMessage>>>,
    }

    #[async_trait]
    impl MessagePublisher for MockPublisher {
        /// Records a single message and acknowledges it.
        async fn send(&self, message: CanonicalMessage) -> Result<Sent, PublisherError> {
            let mut recorded = self.sent.lock().unwrap();
            recorded.push(message);
            Ok(Sent::Ack)
        }

        /// Records a whole batch and acknowledges it.
        async fn send_batch(
            &self,
            messages: Vec<CanonicalMessage>,
        ) -> Result<SentBatch, PublisherError> {
            let mut recorded = self.sent.lock().unwrap();
            recorded.extend(messages);
            Ok(SentBatch::Ack)
        }

        fn as_any(&self) -> &dyn Any {
            self
        }
    }

    /// At 20 msg/s a first batch of two messages reserves 100 ms, so the
    /// second batch must be held back for roughly that long.
    #[tokio::test]
    async fn test_limiter_consumer_delays_batch_by_message_count() {
        let queued = VecDeque::from([
            vec![CanonicalMessage::from("one"), CanonicalMessage::from("two")],
            vec![
                CanonicalMessage::from("three"),
                CanonicalMessage::from("four"),
            ],
        ]);
        let limits = LimiterMiddleware {
            messages_per_second: 20.0,
        };
        let source = MockConsumer { batches: queued };
        let mut limited = LimiterConsumer::new(Box::new(source), &limits).unwrap();

        let head = limited.receive_batch(10).await.unwrap();
        let timer = Instant::now();
        let tail = limited.receive_batch(10).await.unwrap();
        let waited = timer.elapsed();

        assert_eq!(head.messages.len(), 2);
        assert_eq!(tail.messages.len(), 2);
        assert!(waited >= Duration::from_millis(90));
    }

    /// At 20 msg/s the second of two back-to-back sends must be held back
    /// for roughly 50 ms.
    #[tokio::test]
    async fn test_limiter_publisher_delays_consecutive_sends() {
        let limits = LimiterMiddleware {
            messages_per_second: 20.0,
        };
        let outbox = Arc::new(StdMutex::new(Vec::new()));
        let sink = MockPublisher {
            sent: outbox.clone(),
        };
        let limited = LimiterPublisher::new(Box::new(sink), &limits).unwrap();

        let timer = Instant::now();
        limited.send(CanonicalMessage::from("one")).await.unwrap();
        limited.send(CanonicalMessage::from("two")).await.unwrap();
        let waited = timer.elapsed();

        assert_eq!(outbox.lock().unwrap().len(), 2);
        assert!(waited >= Duration::from_millis(45));
    }
}