#![doc = concat!("actix-limitation = \"", env!("CARGO_PKG_VERSION_MAJOR"), ".", env!("CARGO_PKG_VERSION_MINOR"),"\"")]
#![forbid(unsafe_code)]
#![deny(rust_2018_idioms, nonstandard_style)]
#![warn(future_incompatible, missing_docs, missing_debug_implementations)]
#![doc(html_logo_url = "https://actix.rs/img/logo.png")]
#![doc(html_favicon_url = "https://actix.rs/favicon.ico")]
use std::{borrow::Cow, fmt, sync::Arc, time::Duration};
use actix_web::dev::ServiceRequest;
use redis::Client;
mod builder;
mod errors;
mod middleware;
mod status;
pub use self::{builder::Builder, errors::Error, middleware::RateLimiter, status::Status};
/// Default request limit; `Limiter::builder` uses it as the initial `limit`.
pub const DEFAULT_REQUEST_LIMIT: usize = 5000;
/// Default period, in seconds; `Limiter::builder` turns it into the initial `period`.
pub const DEFAULT_PERIOD_SECS: u64 = 3600;
/// Default cookie name; `Limiter::builder` uses it as the initial `cookie_name`.
pub const DEFAULT_COOKIE_NAME: &str = "sid";
/// Default session key; `Limiter::builder` uses it as the initial `session_key`.
///
/// Only available when the `session` feature is enabled.
#[cfg(feature = "session")]
pub const DEFAULT_SESSION_KEY: &str = "rate-api-id";
/// Local stand-in for `Fn(&ServiceRequest) -> Option<String>`.
///
/// A crate-local trait is required so that `fmt::Debug` can be implemented for the trait
/// object `GetKeyFn`; the orphan rules forbid implementing a foreign trait directly for
/// `dyn Fn(..)`.
trait GetKeyFnT: Fn(&ServiceRequest) -> Option<String> {}
// Blanket impl: anything callable with the matching signature is a `GetKeyFnT`.
impl<T> GetKeyFnT for T where T: Fn(&ServiceRequest) -> Option<String> {}
/// Trait-object form of the key callback, with `Send + Sync` bounds.
type GetKeyFn = dyn GetKeyFnT + Send + Sync;
// The callback is an opaque trait object, so `Debug` prints a fixed placeholder
// instead of any internal state.
impl fmt::Debug for GetKeyFn {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("GetKeyFn")
    }
}
/// Reference-counted handle to the key callback.
///
/// Despite the name, no `Box` is involved: the alias is `Arc<dyn GetKeyFnT + Send + Sync>`.
type GetArcBoxKeyFn = Arc<GetKeyFn>;
/// Rate limiter that keeps a per-key request counter in Redis.
///
/// Create one through [`Limiter::builder`]; call [`Limiter::count`] to record a request.
#[derive(Debug, Clone)]
pub struct Limiter {
/// Redis client from which a connection is obtained on every tracked request.
client: Client,
/// Highest counter value that `count` still accepts for a key.
limit: usize,
/// Expiry applied to a key's counter when it is first created (whole seconds are used).
period: Duration,
/// Shared callback mapping a request to an optional key string.
/// NOTE(review): not called in this file; presumably used by the middleware — confirm.
get_key_fn: GetArcBoxKeyFn,
}
impl Limiter {
    /// Starts a [`Builder`] for a limiter that talks to Redis at `redis_url`.
    ///
    /// The builder is pre-filled with the crate defaults: [`DEFAULT_REQUEST_LIMIT`],
    /// [`DEFAULT_PERIOD_SECS`] (as a [`Duration`]), [`DEFAULT_COOKIE_NAME`] and, when the
    /// `session` feature is enabled, `DEFAULT_SESSION_KEY`. No key callback is set.
    #[must_use]
    pub fn builder(redis_url: impl Into<String>) -> Builder {
        let redis_url = redis_url.into();
        let period = Duration::from_secs(DEFAULT_PERIOD_SECS);

        Builder {
            redis_url,
            limit: DEFAULT_REQUEST_LIMIT,
            period,
            get_key_fn: None,
            cookie_name: Cow::Borrowed(DEFAULT_COOKIE_NAME),
            #[cfg(feature = "session")]
            session_key: Cow::Borrowed(DEFAULT_SESSION_KEY),
        }
    }

    /// Records one request for `key` and returns the resulting rate-limit [`Status`].
    ///
    /// # Errors
    ///
    /// Returns [`Error::LimitExceeded`], carrying the status, when the counter for `key`
    /// is greater than the configured limit. Any error raised while updating the counter
    /// in Redis is propagated unchanged.
    pub async fn count(&self, key: impl Into<String>) -> Result<Status, Error> {
        let (hits, reset) = self.track(key).await?;
        let status = Status::new(hits, self.limit, reset);

        // A counter equal to the limit is still accepted; only going past it is an error.
        if hits <= self.limit {
            return Ok(status);
        }
        Err(Error::LimitExceeded(status))
    }

    /// Increments the Redis counter stored under `key` and returns `(count, reset)`.
    ///
    /// `count` is the reply of `INCR`; `reset` is whatever [`Status::epoch_utc_plus`]
    /// yields for the key's remaining time-to-live.
    async fn track(&self, key: impl Into<String>) -> Result<(usize, usize), Error> {
        let key = key.into();
        let window_secs = self.period.as_secs();
        let mut conn = self.client.get_tokio_connection().await?;

        // All three commands are queued on one pipeline in atomic mode.
        let mut pipeline = redis::pipe();

        // SET key 0 EX <secs> NX: create the counter with an expiry only if the key does
        // not exist yet. The reply is dropped from the pipeline result.
        pipeline
            .atomic()
            .cmd("SET")
            .arg(&key)
            .arg(0)
            .arg("EX")
            .arg(window_secs)
            .arg("NX")
            .ignore();
        // INCR key: the reply is the new counter value.
        pipeline.cmd("INCR").arg(&key);
        // TTL key: the reply is the number of seconds until the counter expires.
        pipeline.cmd("TTL").arg(&key);

        let (hits, ttl_secs): (usize, u64) = pipeline.query_async(&mut conn).await?;
        let reset = Status::epoch_utc_plus(Duration::from_secs(ttl_secs))?;

        Ok((hits, reset))
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A limiter built from the builder's defaults reports the default limit and period.
    #[test]
    fn test_create_limiter() {
        let outcome = Limiter::builder("redis://127.0.0.1:6379/1").build();
        assert!(outcome.is_ok());

        // Pull out just the two fields under test.
        let Limiter { limit, period, .. } = outcome.unwrap();
        assert_eq!(limit, 5000);
        assert_eq!(period, Duration::from_secs(3600));
    }
}