use crate::HttpRequest;
use crate::error::Error;
use std::fmt::{Debug, Formatter};
use std::sync::Arc;
/// Derives a numeric rate-limiting key from an incoming request.
///
/// The `Send + Sync` supertraits let a key be stored as
/// `Arc<dyn RateLimitKey>` inside [`RateLimitBinding`] while the binding
/// itself remains `Send + Sync`.
pub trait RateLimitKey: Send + Sync {
    /// Extracts the key for `req`.
    ///
    /// # Errors
    ///
    /// Implementations may return an [`Error`] when no key can be derived;
    /// [`RateLimitBinding`] forwards such errors unchanged.
    fn extract(&self, req: &HttpRequest) -> Result<u64, Error>;
}
/// Extension trait for turning a [`RateLimitKey`] into a [`RateLimitBinding`].
///
/// NOTE(review): the `Sized + 'static` bounds are presumably there so an
/// implementor can be moved into the `Arc<dyn RateLimitKey>` held by
/// [`RateLimitBinding`] (a `dyn Trait` behind an `Arc` field is implicitly
/// `'static`); confirm against the implementations not shown here.
pub trait RateLimitKeyExt: Sized + RateLimitKey + 'static {
    /// Consumes `self` and returns a [`RateLimitBinding`].
    ///
    /// For `RateLimitBinding` itself this returns the binding unchanged.
    fn bind(self) -> RateLimitBinding;
}
/// A binding is itself a key: extraction is forwarded to the wrapped key.
impl RateLimitKey for RateLimitBinding {
    /// Delegates to the bound key, returning its value or error unchanged.
    /// The policy name plays no part in extraction.
    #[inline]
    fn extract(&self, req: &HttpRequest) -> Result<u64, Error> {
        let Self { key, .. } = self;
        key.extract(req)
    }
}
impl RateLimitKeyExt for RateLimitBinding {
    /// A binding is already bound, so binding it again is the identity:
    /// the same key and policy are returned without re-wrapping.
    #[inline]
    fn bind(self) -> Self {
        self
    }
}
/// Name identifying the policy a [`RateLimitBinding`] refers to.
///
/// An `Arc<str>`, so cloning a name (or a binding holding one) bumps a
/// reference count instead of copying the string.
pub type PolicyName = Arc<str>;

/// A type-erased [`RateLimitKey`] paired with an optional [`PolicyName`].
///
/// `Clone` is cheap: both fields are reference-counted, so a clone shares
/// the original's key and policy allocations. The binding itself implements
/// [`RateLimitKey`] by delegating to `key`.
#[derive(Clone)]
pub struct RateLimitBinding {
    /// Key extractor that requests are delegated to.
    pub(super) key: Arc<dyn RateLimitKey>,
    /// Optional name of the policy this binding is associated with.
    pub(super) policy: Option<PolicyName>,
}
impl Debug for RateLimitBinding {
    /// Formats as `RateLimitBinding { policy: .., .. }`.
    ///
    /// The policy name is shown because it is the only inspectable state.
    /// The key is a type-erased `dyn RateLimitKey` that is not required to
    /// implement `Debug`, so it is elided with a trailing `..` via
    /// `finish_non_exhaustive`.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("RateLimitBinding")
            .field("policy", &self.policy)
            .finish_non_exhaustive()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::HttpBody;
    use hyper::Request;
    use std::sync::Arc;

    /// Test key that always yields the same value.
    struct MockKey {
        value: u64,
    }

    impl RateLimitKey for MockKey {
        fn extract(&self, _req: &HttpRequest) -> Result<u64, Error> {
            Ok(self.value)
        }
    }

    /// Test key whose extraction always fails.
    struct ErrorKey;

    impl RateLimitKey for ErrorKey {
        fn extract(&self, _req: &HttpRequest) -> Result<u64, Error> {
            Err(Error::server_error("Mock error"))
        }
    }

    /// Builds a binding directly from its `pub(super)` fields.
    fn create_binding_with_policy(
        key: Arc<dyn RateLimitKey>,
        policy: Option<PolicyName>,
    ) -> RateLimitBinding {
        RateLimitBinding { key, policy }
    }

    /// Builds an empty `GET /` request; the test keys ignore its contents.
    fn create_request() -> HttpRequest {
        let (parts, body) = Request::get("/")
            .body(HttpBody::empty())
            .expect("a GET request to a static path is always valid")
            .into_parts();
        HttpRequest::from_parts(parts, body)
    }

    #[test]
    fn it_extracts_key_from_mock_implementation() {
        let key = MockKey { value: 42 };
        let req = create_request();
        assert_eq!(key.extract(&req).expect("MockKey never fails"), 42);
    }

    #[test]
    fn it_propagates_extraction_errors() {
        let req = create_request();
        assert!(ErrorKey.extract(&req).is_err());
    }

    #[test]
    fn it_creates_binding_without_policy_name() {
        let key = Arc::new(MockKey { value: 200 });
        let binding = create_binding_with_policy(key, None);
        assert!(binding.policy.is_none());
    }

    #[test]
    fn it_creates_binding_with_policy_name() {
        let key = Arc::new(MockKey { value: 100 });
        let policy_name: Arc<str> = Arc::from("burst");
        let binding = create_binding_with_policy(key, Some(policy_name));
        assert_eq!(binding.policy.as_deref(), Some("burst"));
    }

    #[test]
    fn it_extracts_key_through_binding() {
        let key = Arc::new(MockKey { value: 123 });
        let binding = create_binding_with_policy(key, Some(Arc::from("test_policy")));
        let req = create_request();
        assert_eq!(binding.extract(&req).expect("MockKey never fails"), 123);
    }

    #[test]
    fn it_propagates_errors_through_binding() {
        let key = Arc::new(ErrorKey);
        let binding = create_binding_with_policy(key, Some(Arc::from("test_policy")));
        let req = create_request();
        assert!(binding.extract(&req).is_err());
    }

    #[test]
    fn it_clones_binding_correctly() {
        let key = Arc::new(MockKey { value: 300 });
        let binding = create_binding_with_policy(key, Some(Arc::from("original_policy")));
        let cloned = binding.clone();
        assert_eq!(cloned.policy.as_deref(), Some("original_policy"));
        // The clone shares the original's key allocation rather than copying it.
        assert_eq!(Arc::strong_count(&binding.key), 2);
        let req = create_request();
        assert_eq!(cloned.extract(&req).expect("MockKey never fails"), 300);
    }

    #[test]
    fn it_clones_binding_with_shared_policy_reference() {
        let key = Arc::new(MockKey { value: 400 });
        let policy: Arc<str> = Arc::from("shared_policy");
        let binding = create_binding_with_policy(key, Some(policy));
        let cloned = binding.clone();
        assert!(Arc::ptr_eq(
            binding.policy.as_ref().expect("policy was set"),
            cloned.policy.as_ref().expect("clone keeps the policy"),
        ));
    }

    #[test]
    fn it_creates_multiple_independent_bindings() {
        let key1 = Arc::new(MockKey { value: 100 });
        let key2 = Arc::new(MockKey { value: 200 });
        let binding1 = create_binding_with_policy(key1, Some(Arc::from("policy1")));
        let binding2 = create_binding_with_policy(key2, Some(Arc::from("policy2")));
        let req = create_request();
        assert_eq!(binding1.extract(&req).expect("MockKey never fails"), 100);
        assert_eq!(binding2.extract(&req).expect("MockKey never fails"), 200);
        assert_eq!(binding1.policy.as_deref(), Some("policy1"));
        assert_eq!(binding2.policy.as_deref(), Some("policy2"));
    }

    #[test]
    fn it_binds_binding_to_itself() {
        let key = Arc::new(MockKey { value: 500 });
        let binding = create_binding_with_policy(key, Some(Arc::from("test_policy")));
        let bound = binding.clone().bind();
        assert_eq!(bound.policy.as_deref(), Some("test_policy"));
    }

    #[test]
    fn it_extracts_same_value_multiple_times() {
        let key = Arc::new(MockKey { value: 777 });
        let binding = create_binding_with_policy(key, Some(Arc::from("consistent")));
        let req = create_request();
        for _ in 0..2 {
            assert_eq!(binding.extract(&req).expect("MockKey never fails"), 777);
        }
    }

    #[test]
    fn it_handles_empty_policy_name() {
        let key = Arc::new(MockKey { value: 999 });
        let binding = create_binding_with_policy(key, Some(Arc::from("")));
        assert_eq!(binding.policy.as_deref(), Some(""));
    }

    #[test]
    fn it_handles_long_policy_name() {
        let key = Arc::new(MockKey { value: 111 });
        let long_name = "very_long_policy_name_that_might_be_used_in_real_scenarios";
        let binding = create_binding_with_policy(key, Some(Arc::from(long_name)));
        assert_eq!(binding.policy.as_deref(), Some(long_name));
    }

    #[test]
    fn it_formats_binding_debug_output() {
        let key = Arc::new(MockKey { value: 222 });
        let binding = create_binding_with_policy(key, Some(Arc::from("debug_test")));
        let debug_str = format!("{:?}", binding);
        assert!(debug_str.contains("RateLimitBinding"));
    }

    #[test]
    fn it_shares_key_across_multiple_bindings() {
        let key = Arc::new(MockKey { value: 333 });
        let binding1 = create_binding_with_policy(key.clone(), Some(Arc::from("policy1")));
        let binding2 = create_binding_with_policy(key.clone(), Some(Arc::from("policy2")));
        // `key` plus one reference per binding: all point at one allocation.
        assert_eq!(Arc::strong_count(&key), 3);
        let req = create_request();
        assert_eq!(binding1.extract(&req).expect("MockKey never fails"), 333);
        assert_eq!(binding2.extract(&req).expect("MockKey never fails"), 333);
    }

    #[test]
    fn it_allows_none_policy_in_binding_construction() {
        let key = Arc::new(MockKey { value: 444 });
        let binding = create_binding_with_policy(key, None);
        assert!(binding.policy.is_none());
        // A missing policy name must not affect key extraction.
        let req = create_request();
        assert_eq!(binding.extract(&req).expect("MockKey never fails"), 444);
    }

    #[test]
    fn it_implements_send_and_sync_for_mock_key() {
        fn assert_send_sync<T: Send + Sync>() {}
        assert_send_sync::<MockKey>();
    }

    #[test]
    fn it_implements_send_and_sync_for_binding() {
        fn assert_send_sync<T: Send + Sync>() {}
        assert_send_sync::<RateLimitBinding>();
    }

    #[test]
    fn it_extracts_key_after_bind() {
        let key = Arc::new(MockKey { value: 888 });
        let binding = create_binding_with_policy(key, Some(Arc::from("policy")));
        let req = create_request();
        let bound = binding.bind();
        assert_eq!(bound.extract(&req).expect("MockKey never fails"), 888);
    }

    #[test]
    fn it_preserves_policy_after_bind() {
        let key = Arc::new(MockKey { value: 555 });
        let policy_name: Arc<str> = Arc::from("preserved");
        let binding = create_binding_with_policy(key, Some(policy_name.clone()));
        let bound = binding.bind();
        assert!(Arc::ptr_eq(
            bound.policy.as_ref().expect("bind keeps the policy"),
            &policy_name
        ));
    }

    #[test]
    fn it_clones_binding_with_none_policy() {
        let key = Arc::new(MockKey { value: 666 });
        let binding = create_binding_with_policy(key, None);
        let cloned = binding.clone();
        assert!(cloned.policy.is_none());
    }

    #[test]
    fn it_handles_multiple_binds_on_same_binding() {
        let key = Arc::new(MockKey { value: 999 });
        let binding = create_binding_with_policy(key, Some(Arc::from("multi_bind")));
        let bound1 = binding.clone().bind();
        let bound2 = binding.clone().bind();
        assert_eq!(bound1.policy.as_deref(), Some("multi_bind"));
        assert_eq!(bound2.policy.as_deref(), Some("multi_bind"));
    }
}