// tower-redis-cell 0.3.0
//
// Rust bindings for the Redis Cell module
// Documentation
use crate::config;
use crate::error::Error;
use crate::rule;
use redis::{FromRedisValue, aio::ConnectionLike};
pub use redis_cell_rs as redis_cell;
use std::{pin::Pin, sync::Arc};

/// Rate-limiting middleware that wraps an inner service `S` and talks to Redis
/// through a cloneable connection `C`.
pub struct RateLimit<S, PR, ReqTy, RespTy, IntoRespTy, C> {
    /// The wrapped service that requests are forwarded to.
    inner: S,
    /// Shared rate-limit configuration; held in an `Arc` so clones of this
    /// service point at the same configuration value.
    config: Arc<config::RateLimitConfig<PR, ReqTy, RespTy, IntoRespTy>>,
    /// Redis connection handle; it is cloned for every request.
    connection: C,
}

/// Manual `Clone` so that only the inner service and the connection need to be
/// `Clone`; the configuration is shared by bumping the `Arc` reference count.
impl<S, PR, ReqTy, RespTy, IntoRespTy, C> Clone for RateLimit<S, PR, ReqTy, RespTy, IntoRespTy, C>
where
    S: Clone,
    C: Clone,
{
    fn clone(&self) -> Self {
        let Self {
            inner,
            config,
            connection,
        } = self;
        Self {
            inner: inner.clone(),
            config: Arc::clone(config),
            connection: connection.clone(),
        }
    }
}

impl<S, PR, ReqTy, RespTy, IntoRespTy, C> RateLimit<S, PR, ReqTy, RespTy, IntoRespTy, C> {
    /// Builds the middleware around `inner`.
    ///
    /// `config` may be anything convertible into an
    /// `Arc<config::RateLimitConfig<..>>`; `connection` is stored as given.
    pub fn new<RLC>(inner: S, config: RLC, connection: C) -> Self
    where
        RLC: Into<Arc<config::RateLimitConfig<PR, ReqTy, RespTy, IntoRespTy>>>,
    {
        let config: Arc<config::RateLimitConfig<PR, ReqTy, RespTy, IntoRespTy>> = config.into();
        Self {
            inner,
            config,
            connection,
        }
    }
}

impl<S, PR, ReqTy, RespTy, IntoRespTy, C> tower::Service<ReqTy>
    for RateLimit<S, PR, ReqTy, RespTy, IntoRespTy, C>
where
    S: tower::Service<ReqTy, Response = RespTy> + Clone + Send + 'static,
    S::Future: Send + 'static,
    S::Error: Send,
    S::Response: Send,
    PR: rule::ProvideRule<ReqTy> + Clone + Send + Sync + 'static,
    ReqTy: Send + 'static,
    IntoRespTy: Into<RespTy> + 'static,
    RespTy: 'static,
    C: ConnectionLike + Clone + Send + 'static,
{
    type Response = S::Response;
    type Error = S::Error;
    type Future = Pin<Box<dyn Future<Output = Result<S::Response, S::Error>> + Send>>;

    fn poll_ready(
        &mut self,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Result<(), Self::Error>> {
        self.inner.poll_ready(cx)
    }

    fn call(&mut self, req: ReqTy) -> Self::Future {
        let mut connection = self.connection.clone();
        let mut inner = self.inner.clone();
        let config = self.config.clone();

        Box::pin(async move {
            let maybe_rule = match config.rule_provider.provide(&req) {
                Ok(rule) => rule,
                Err(e) => {
                    let config::OnError::Sync(ref h) = config.on_error;
                    let resp = h(Error::ProvideRule(e), &req);
                    return Ok(resp.into());
                }
            };
            let rule = match maybe_rule {
                Some(rule) => rule,
                None => {
                    return inner
                        .call(req)
                        .await
                        .map(|mut resp| match &config.on_unruled {
                            config::OnUnruled::Noop => resp,
                            config::OnUnruled::Sync(h) => {
                                h(&mut resp);
                                resp
                            }
                        });
                }
            };
            let policy = rule.policy;
            let cmd = redis_cell::Cmd::new(&rule.key, &policy);

            let redis_response = match connection.req_packed_command(&cmd.into()).await {
                Ok(res) => res,
                Err(redis_err) => {
                    let config::OnError::Sync(ref h) = config.on_error;
                    let handled = h(redis_err.into(), &req);
                    return Ok(handled.into());
                }
            };
            let redis_cell_verdict = match redis_cell::Verdict::from_redis_value(redis_response) {
                Ok(verdict) => verdict,
                Err(parsing_error) => {
                    let config::OnError::Sync(ref h) = config.on_error;
                    let handled = h(Error::Redis(parsing_error.into()), &req);
                    return Ok(handled.into());
                }
            };
            match redis_cell_verdict {
                redis_cell::Verdict::Blocked(details) => {
                    let config::OnError::Sync(ref h) = config.on_error;
                    let handled = h(
                        Error::RateLimit(rule::RequestBlockedDetails { rule, details }),
                        &req,
                    );
                    Ok(handled.into())
                }
                redis_cell::Verdict::Allowed(details) => {
                    let policy = rule.policy;
                    let resource = rule.resource;
                    inner
                        .call(req)
                        .await
                        .map(|mut resp| match &config.on_success {
                            config::OnSuccess::Noop => resp,
                            config::OnSuccess::Sync(h) => {
                                let details = rule::RequestAllowedDetails {
                                    details,
                                    policy,
                                    resource,
                                };
                                h(details, &mut resp);
                                resp
                            }
                        })
                }
            }
        })
    }
}

/// [`tower::Layer`] that wraps services in [`RateLimit`], handing each one the
/// same configuration and a clone of the connection.
pub struct RateLimitLayer<PR, ReqTy, RespTy, IntoRespTy, C> {
    /// Configuration shared (via `Arc`) with every service this layer builds.
    config: Arc<config::RateLimitConfig<PR, ReqTy, RespTy, IntoRespTy>>,
    /// Redis connection handle; cloned for each service this layer builds.
    connection: C,
}

/// Manual `Clone` so that only the connection needs to be `Clone`; the
/// configuration is shared by bumping the `Arc` reference count.
impl<PR, ReqTy, RespTy, IntoRespTy, C> Clone for RateLimitLayer<PR, ReqTy, RespTy, IntoRespTy, C>
where
    C: Clone,
{
    fn clone(&self) -> Self {
        let Self { config, connection } = self;
        Self {
            config: Arc::clone(config),
            connection: connection.clone(),
        }
    }
}

impl<S, PR, ReqTy, RespTy, IntoRespTy, C> tower::Layer<S>
    for RateLimitLayer<PR, ReqTy, RespTy, IntoRespTy, C>
where
    C: Clone,
{
    type Service = RateLimit<S, PR, ReqTy, RespTy, IntoRespTy, C>;

    /// Wraps `inner` in a [`RateLimit`] that shares this layer's configuration
    /// and receives its own clone of the connection.
    fn layer(&self, inner: S) -> Self::Service {
        let Self { config, connection } = self;
        RateLimit::new(inner, Arc::clone(config), connection.clone())
    }
}

impl<PR, ReqTy, RespTy, IntoRespTy, C> RateLimitLayer<PR, ReqTy, RespTy, IntoRespTy, C> {
    /// Creates a layer from `config` (anything convertible into an
    /// `Arc<config::RateLimitConfig<..>>`) and a Redis `connection`.
    pub fn new<RLC>(config: RLC, connection: C) -> Self
    where
        RLC: Into<Arc<config::RateLimitConfig<PR, ReqTy, RespTy, IntoRespTy>>>,
    {
        let config: Arc<config::RateLimitConfig<PR, ReqTy, RespTy, IntoRespTy>> = config.into();
        Self { config, connection }
    }
}

#[cfg(feature = "deadpool")]
#[cfg_attr(docsrs, doc(cfg(feature = "deadpool")))]
pub mod deadpool {
    use crate::config;
    use crate::error::Error;
    use crate::rule;
    use redis::{FromRedisValue, aio::ConnectionLike};
    pub use redis_cell_rs as redis_cell;
    use redis_cell_rs::Verdict;
    use std::{pin::Pin, sync::Arc};

    /// Rate-limiting middleware that wraps an inner service `S` and obtains
    /// its Redis connections from a `deadpool_redis::Pool`.
    pub struct RateLimit<S, PR, ReqTy, RespTy, IntoRespTy> {
        /// The wrapped service that requests are forwarded to.
        inner: S,
        /// Shared rate-limit configuration; held in an `Arc` so clones of this
        /// service point at the same configuration value.
        config: Arc<config::RateLimitConfig<PR, ReqTy, RespTy, IntoRespTy>>,
        /// Connection pool handle; it is cloned for every request.
        pool: deadpool_redis::Pool,
    }

    /// Manual `Clone` so that only the inner service needs to be `Clone`; the
    /// configuration is shared by bumping the `Arc` reference count.
    impl<S, PR, ReqTy, RespTy, IntoRespTy> Clone for RateLimit<S, PR, ReqTy, RespTy, IntoRespTy>
    where
        S: Clone,
    {
        fn clone(&self) -> Self {
            let Self {
                inner,
                config,
                pool,
            } = self;
            Self {
                inner: inner.clone(),
                config: Arc::clone(config),
                pool: pool.clone(),
            }
        }
    }

    impl<S, PR, ReqTy, RespTy, IntoRespTy> RateLimit<S, PR, ReqTy, RespTy, IntoRespTy> {
        /// Builds the middleware around `inner`.
        ///
        /// `config` may be anything convertible into an
        /// `Arc<config::RateLimitConfig<..>>`; `pool` is stored as given.
        pub fn new<RLC>(inner: S, config: RLC, pool: deadpool_redis::Pool) -> Self
        where
            RLC: Into<Arc<config::RateLimitConfig<PR, ReqTy, RespTy, IntoRespTy>>>,
        {
            let config: Arc<config::RateLimitConfig<PR, ReqTy, RespTy, IntoRespTy>> =
                config.into();
            Self {
                inner,
                config,
                pool,
            }
        }
    }

    impl<S, PR, ReqTy, RespTy, IntoRespTy> tower::Service<ReqTy>
        for RateLimit<S, PR, ReqTy, RespTy, IntoRespTy>
    where
        S: tower::Service<ReqTy, Response = RespTy> + Clone + Send + 'static,
        S::Future: Send + 'static,
        S::Error: Send,
        S::Response: Send,
        PR: rule::ProvideRule<ReqTy> + Clone + Send + Sync + 'static,
        ReqTy: Send + 'static,
        IntoRespTy: Into<RespTy> + 'static,
        RespTy: 'static,
    {
        type Response = S::Response;
        type Error = S::Error;
        type Future = Pin<Box<dyn Future<Output = Result<S::Response, S::Error>> + Send>>;

        fn poll_ready(
            &mut self,
            cx: &mut std::task::Context<'_>,
        ) -> std::task::Poll<Result<(), Self::Error>> {
            self.inner.poll_ready(cx)
        }

        fn call(&mut self, req: ReqTy) -> Self::Future {
            let pool = self.pool.clone();
            let mut inner = self.inner.clone();
            let config = self.config.clone();

            Box::pin(async move {
                let maybe_rule = match config.rule_provider.provide(&req) {
                    Ok(rule) => rule,
                    Err(e) => {
                        let config::OnError::Sync(ref h) = config.on_error;
                        let resp = h(Error::ProvideRule(e), &req);
                        return Ok(resp.into());
                    }
                };
                let rule = match maybe_rule {
                    Some(rule) => rule,
                    None => {
                        return inner
                            .call(req)
                            .await
                            .map(|mut resp| match &config.on_unruled {
                                config::OnUnruled::Noop => resp,
                                config::OnUnruled::Sync(h) => {
                                    h(&mut resp);
                                    resp
                                }
                            });
                    }
                };
                let policy = rule.policy;
                let cmd = redis_cell::Cmd::new(&rule.key, &policy);

                let mut connection = match pool.get().await {
                    Ok(conn) => conn,
                    Err(deadpool_err) => {
                        let config::OnError::Sync(ref h) = config.on_error;
                        let handled = h(deadpool_err.into(), &req);
                        return Ok(handled.into());
                    }
                };
                let redis_response = match connection.req_packed_command(&cmd.into()).await {
                    Ok(res) => res,
                    Err(redis_err) => {
                        let config::OnError::Sync(ref h) = config.on_error;
                        let handled = h(redis_err.into(), &req);
                        return Ok(handled.into());
                    }
                };
                let redis_cell_verdict = match Verdict::from_redis_value(redis_response) {
                    Ok(verdict) => verdict,
                    Err(parsing_error) => {
                        let config::OnError::Sync(ref h) = config.on_error;
                        let handled = h(Error::Redis(parsing_error.into()), &req);
                        return Ok(handled.into());
                    }
                };
                match redis_cell_verdict {
                    redis_cell::Verdict::Blocked(details) => {
                        let config::OnError::Sync(ref h) = config.on_error;
                        let handled = h(
                            Error::RateLimit(rule::RequestBlockedDetails { rule, details }),
                            &req,
                        );
                        Ok(handled.into())
                    }
                    redis_cell::Verdict::Allowed(details) => {
                        let policy = rule.policy;
                        let resource = rule.resource;
                        inner
                            .call(req)
                            .await
                            .map(|mut resp| match &config.on_success {
                                config::OnSuccess::Noop => resp,
                                config::OnSuccess::Sync(h) => {
                                    let details = rule::RequestAllowedDetails {
                                        details,
                                        policy,
                                        resource,
                                    };
                                    h(details, &mut resp);
                                    resp
                                }
                            })
                    }
                }
            })
        }
    }

    /// [`tower::Layer`] that wraps services in [`RateLimit`], handing each one
    /// the same configuration and a clone of the connection pool handle.
    pub struct RateLimitLayer<PR, ReqTy, RespTy, IntoRespTy> {
        /// Configuration shared (via `Arc`) with every service this layer builds.
        config: Arc<config::RateLimitConfig<PR, ReqTy, RespTy, IntoRespTy>>,
        /// Connection pool handle; cloned for each service this layer builds.
        pool: deadpool_redis::Pool,
    }

    /// Manual `Clone` that places no `Clone` requirement on the type
    /// parameters; the configuration is shared by bumping the `Arc` reference
    /// count.
    impl<PR, ReqTy, RespTy, IntoRespTy> Clone for RateLimitLayer<PR, ReqTy, RespTy, IntoRespTy> {
        fn clone(&self) -> Self {
            let Self { config, pool } = self;
            Self {
                config: Arc::clone(config),
                pool: pool.clone(),
            }
        }
    }

    impl<S, PR, ReqTy, RespTy, IntoRespTy> tower::Layer<S>
        for RateLimitLayer<PR, ReqTy, RespTy, IntoRespTy>
    {
        type Service = RateLimit<S, PR, ReqTy, RespTy, IntoRespTy>;

        /// Wraps `inner` in a [`RateLimit`] that shares this layer's
        /// configuration and receives its own clone of the pool handle.
        fn layer(&self, inner: S) -> Self::Service {
            let Self { config, pool } = self;
            RateLimit::new(inner, Arc::clone(config), pool.clone())
        }
    }

    impl<PR, ReqTy, RespTy, IntoRespTy> RateLimitLayer<PR, ReqTy, RespTy, IntoRespTy> {
        /// Creates a layer from `config` (anything convertible into an
        /// `Arc<config::RateLimitConfig<..>>`) and a `deadpool_redis` pool.
        pub fn new<RLC>(config: RLC, pool: deadpool_redis::Pool) -> Self
        where
            RLC: Into<Arc<config::RateLimitConfig<PR, ReqTy, RespTy, IntoRespTy>>>,
        {
            let config: Arc<config::RateLimitConfig<PR, ReqTy, RespTy, IntoRespTy>> =
                config.into();
            Self { config, pool }
        }
    }
}