oxana 2.0.0-rc.3

A simple & fast job queue system.
Documentation
use crate::OxanaError;
use deadpool_redis::redis;
use uuid::Uuid;

/// Sliding-window rate limiter whose usage log lives in a Redis sorted set.
///
/// Every unit of consumed cost is stored as one sorted-set member scored by its timestamp in
/// microseconds. Requests are admitted while fewer than `limit` entries fall inside the trailing
/// `window_ms` milliseconds.
pub struct Throttler {
    /// Pool from which a Redis connection is checked out for each operation.
    redis_pool: deadpool_redis::Pool,
    /// Fully-qualified (namespaced) Redis key of the sorted set.
    key: String,
    /// Number of entries inside the window at which further requests are refused.
    limit: u64,
    /// Length of the sliding window, in milliseconds.
    window_ms: i64,
}

/// Snapshot of a throttler's sliding window, as observed by a state read or a consume call.
///
/// All fields are plain values, so the snapshot is cheap to copy and compare.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct ThrottlerState {
    /// Number of entries inside the window, including any just recorded by a consume call.
    pub requests: u64,
    /// Whether the entry count was below the limit when the window was read.
    pub is_allowed: bool,
    /// When not allowed, the suggested wait in milliseconds before requests are admitted again;
    /// `None` when allowed or when there is no entry to wait on.
    pub throttled_for: Option<i64>,
}

impl Throttler {
    /// Creates a throttler that admits requests while fewer than `limit` entries fall inside the
    /// trailing `window_ms` milliseconds.
    ///
    /// The Redis key used is `key` prefixed with `oxana:throttler:`.
    pub fn new(redis_pool: deadpool_redis::Pool, key: &str, limit: u64, window_ms: i64) -> Self {
        Throttler {
            redis_pool,
            key: Self::build_key(key),
            limit,
            window_ms,
        }
    }

    /// Records `cost` units of usage (default 1) if the throttler currently admits requests.
    ///
    /// Returns the state observed before recording, with `requests` increased by the number of
    /// entries actually added. When requests are not admitted, or `cost` is `Some(0)`, nothing
    /// is recorded and the observed state is returned unchanged.
    ///
    /// Admission only checks that the count *before* this call is below `limit`, so a single
    /// call with a large `cost` can push the count past `limit`. The check and the write are
    /// separate Redis round trips, so concurrent callers on the same key can both be admitted.
    ///
    /// # Errors
    ///
    /// Returns an error if no Redis connection can be obtained, a Redis command fails, or the
    /// current timestamp is negative.
    pub async fn consume(&self, cost: Option<u64>) -> Result<ThrottlerState, OxanaError> {
        let mut redis = self.redis_pool.get().await?;
        let current_time = u64::try_from(chrono::Utc::now().timestamp_micros())?;
        let state = self.state_w_conn(&mut redis).await?;

        let effective_cost = cost.unwrap_or(1);
        if !state.is_allowed || effective_cost == 0 {
            return Ok(state);
        }

        // One member per unit of cost; random ids keep ZADD from merging same-timestamp entries.
        let members: Vec<(u64, String)> = (0..effective_cost)
            .map(|_| (current_time, Uuid::new_v4().to_string()))
            .collect();

        // Millisecond TTL: truncating the window to whole seconds for EXPIRE expired the key
        // early and, for windows under one second, produced a TTL of 0, which makes Redis delete
        // the key immediately and disables throttling altogether.
        let (added, _): (u64, ()) = redis::pipe()
            .zadd_multiple(&self.key, &members)
            .pexpire(&self.key, self.window_ms)
            .query_async(&mut redis)
            .await?;

        Ok(ThrottlerState {
            requests: state.requests + added,
            ..state
        })
    }

    /// Returns the current state without recording any usage.
    ///
    /// Entries that have fallen out of the window are removed from Redis as a side effect.
    ///
    /// # Errors
    ///
    /// Returns an error if no Redis connection can be obtained or a Redis command fails.
    pub async fn state(&self) -> Result<ThrottlerState, OxanaError> {
        let mut redis = self.redis_pool.get().await?;
        self.state_w_conn(&mut redis).await
    }

    /// Prunes expired entries and reads the window state using an existing connection.
    async fn state_w_conn(
        &self,
        redis: &mut deadpool_redis::Connection,
    ) -> Result<ThrottlerState, OxanaError> {
        let now = chrono::Utc::now().timestamp_micros();
        let window_start = now - self.window_micros();

        // Drop entries at or before the window start, then fetch the oldest survivor and the
        // number of survivors.
        let (_, oldest, request_count): ((), Vec<(String, f64)>, u64) = redis::pipe()
            .zrembyscore(&self.key, 0, window_start)
            .zrange_withscores(&self.key, 0, 0)
            .zcard(&self.key)
            .query_async(&mut *redis)
            .await?;

        let is_allowed = request_count < self.limit;

        // Capacity frees up once the oldest entry leaves the window, at `oldest + window`.
        // The remaining time is floored to milliseconds and padded by 1 ms, which is never less
        // than the true remaining time, so waiting that long always frees a slot.
        let throttled_for = if is_allowed {
            None
        } else {
            oldest.first().map(|&(_, score)| {
                let remaining_micros = (score as i64 + self.window_micros() - now).max(0);
                remaining_micros / 1000 + 1
            })
        };

        Ok(ThrottlerState {
            requests: request_count,
            is_allowed,
            throttled_for,
        })
    }

    /// Namespaces a user-supplied key under `oxana:throttler:`.
    fn build_key(key: &str) -> String {
        format!("oxana:throttler:{key}")
    }

    /// Window length in microseconds, the unit used for sorted-set scores.
    fn window_micros(&self) -> i64 {
        self.window_ms * 1000
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_helper::*;
    use testresult::TestResult;

    #[tokio::test]
    async fn test_consume() -> TestResult {
        let pool = redis_pool().await?;
        let key = random_string();
        let rate_limiter = Throttler::new(pool, &key, 2, 60000);
        assert!(rate_limiter.consume(None).await?.is_allowed);
        assert!(rate_limiter.consume(None).await?.is_allowed);
        for _ in 0..2 {
            let state = rate_limiter.consume(None).await?;
            assert!(!state.is_allowed);
            assert!(state.throttled_for.is_some());
            // The oldest entry was recorded moments ago, so the wait is close to the full
            // window and never exceeds it by more than the 1 ms padding.
            let throttled_for = state.throttled_for.unwrap_or(0);
            assert!((59_000..=60_001).contains(&throttled_for), "{throttled_for}");
        }

        Ok(())
    }

    #[tokio::test]
    async fn test_consume_sub_second_window() -> TestResult {
        let pool = redis_pool().await?;
        let key = random_string();
        let rate_limiter = Throttler::new(pool, &key, 1, 900);
        assert!(rate_limiter.consume(None).await?.is_allowed);
        let state = rate_limiter.consume(None).await?;
        assert!(!state.is_allowed);
        assert!(state.throttled_for.is_some());
        assert!(state.throttled_for.unwrap_or(0) <= 901);

        Ok(())
    }

    #[tokio::test]
    async fn test_consume_with_cost() -> TestResult {
        let pool = redis_pool().await?;
        let key = random_string();
        let rate_limiter = Throttler::new(pool, &key, 4, 60000);
        assert!(rate_limiter.consume(Some(2)).await?.is_allowed);
        assert!(rate_limiter.consume(Some(2)).await?.is_allowed);
        let state = rate_limiter.consume(Some(1)).await?;
        assert!(!state.is_allowed);

        Ok(())
    }
}