Skip to main content

openauth_redis/
lib.rs

//! Redis-backed integrations for OpenAuth.
//!
//! The rate limit store talks to any RESP-compatible Redis or Valkey server through
//! `redis-rs` and its async `redis::aio::ConnectionManager`. Each consume decision is
//! made atomically by a Lua script, and the store only relies on core commands that
//! Redis and Valkey share.
7
8use std::borrow::Cow;
9
10use openauth_core::error::OpenAuthError;
11use openauth_core::options::{
12    RateLimitConsumeInput, RateLimitDecision, RateLimitFuture, RateLimitStore,
13};
14use redis::aio::ConnectionManager;
15use redis::Script;
16
/// Lua script that atomically applies one rate-limit "consume" to a single key.
///
/// Inputs (as passed by `RedisRateLimitStore::consume`):
/// - `KEYS[1]`: hash key holding the `count` and `last_request` fields.
/// - `ARGV[1]`: current time in milliseconds (`now`).
/// - `ARGV[2]`: window length in milliseconds.
/// - `ARGV[3]`: maximum number of requests allowed within the window.
///
/// Returns `{permitted, count, last_request}`, where `permitted` is `1` or `0`.
///
/// Behaviour:
/// - If the hash is missing, or more than `window` ms have elapsed since `last_request`,
///   the counter restarts at 1 and the request is permitted.
/// - Otherwise, a request arriving when `count >= max` is rejected without modifying the
///   hash fields.
/// - Any other request increments `count`.
///
/// Permitted requests store `now` as `last_request`. The window is therefore measured from
/// the most recent permitted request, not from the first request in the window. Every path
/// refreshes the key's TTL to `window` ms with `PEXPIRE`.
///
/// Writes use `HSET` (not the deprecated `HMSET`). The test module asserts this.
const RATE_LIMIT_SCRIPT: &str = r#"
local key = KEYS[1]
local now = tonumber(ARGV[1])
local window = tonumber(ARGV[2])
local max = tonumber(ARGV[3])

local data = redis.call("HMGET", key, "count", "last_request")
local count = tonumber(data[1])
local last_request = tonumber(data[2])

if count == nil or last_request == nil or (now - last_request) > window then
  redis.call("HSET", key, "count", 1, "last_request", now)
  redis.call("PEXPIRE", key, window)
  return {1, 1, now}
end

if count >= max then
  redis.call("PEXPIRE", key, window)
  return {0, count, last_request}
end

count = count + 1
redis.call("HSET", key, "count", count, "last_request", now)
redis.call("PEXPIRE", key, window)
return {1, count, now}
"#;
43
/// Tunables for [`RedisRateLimitStore`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RedisRateLimitOptions {
    /// Namespace prepended to every Redis key the store builds
    /// (the store appends `rate-limit:<key>` after it).
    pub key_prefix: String,
}

impl Default for RedisRateLimitOptions {
    /// Uses `openauth:` as the key prefix.
    fn default() -> Self {
        let key_prefix = String::from("openauth:");
        Self { key_prefix }
    }
}
56
57#[derive(Clone)]
58pub struct RedisRateLimitStore {
59    manager: ConnectionManager,
60    options: RedisRateLimitOptions,
61}
62
63impl RedisRateLimitStore {
64    pub async fn connect(redis_url: &str) -> Result<Self, OpenAuthError> {
65        let redis_url = normalize_redis_url(redis_url);
66        let client = redis::Client::open(redis_url.as_ref())
67            .map_err(|error| OpenAuthError::Adapter(error.to_string()))?;
68        let manager = ConnectionManager::new(client)
69            .await
70            .map_err(|error| OpenAuthError::Adapter(error.to_string()))?;
71        Ok(Self::new(manager, RedisRateLimitOptions::default()))
72    }
73
74    pub fn new(manager: ConnectionManager, options: RedisRateLimitOptions) -> Self {
75        Self { manager, options }
76    }
77
78    fn key(&self, key: &str) -> String {
79        format!("{}rate-limit:{key}", self.options.key_prefix)
80    }
81}
82
/// Rewrites Valkey URL schemes into the Redis schemes that `redis::Client::open` accepts.
///
/// `valkey://` becomes `redis://` and `valkeys://` becomes `rediss://`. Any other input is
/// returned unchanged as a borrowed `Cow`, without allocating.
fn normalize_redis_url(redis_url: &str) -> Cow<'_, str> {
    // (Valkey scheme, equivalent Redis scheme). No URL can match both prefixes, so the
    // lookup order does not affect the result.
    const SCHEME_ALIASES: [(&str, &str); 2] = [
        ("valkey://", "redis://"),
        ("valkeys://", "rediss://"),
    ];

    SCHEME_ALIASES
        .iter()
        .find_map(|&(valkey_scheme, redis_scheme)| {
            redis_url
                .strip_prefix(valkey_scheme)
                .map(|rest| Cow::Owned(format!("{redis_scheme}{rest}")))
        })
        .unwrap_or(Cow::Borrowed(redis_url))
}
92
93impl RateLimitStore for RedisRateLimitStore {
94    fn consume<'a>(&'a self, input: RateLimitConsumeInput) -> RateLimitFuture<'a> {
95        Box::pin(async move {
96            let redis_key = self.key(&input.key);
97            let window_ms = input.rule.window.saturating_mul(1000);
98            let mut manager = self.manager.clone();
99            let result: (i64, i64, i64) = Script::new(RATE_LIMIT_SCRIPT)
100                .key(redis_key)
101                .arg(input.now_ms)
102                .arg(window_ms as i64)
103                .arg(input.rule.max as i64)
104                .invoke_async(&mut manager)
105                .await
106                .map_err(|error| OpenAuthError::Adapter(error.to_string()))?;
107            let permitted = result.0 == 1;
108            let count = result.1.max(0) as u64;
109            let last_request = result.2;
110            let retry_ms = last_request
111                .saturating_add(window_ms as i64)
112                .saturating_sub(input.now_ms)
113                .max(0);
114            Ok(RateLimitDecision {
115                permitted,
116                retry_after: if permitted {
117                    0
118                } else {
119                    ceil_millis_to_seconds(retry_ms)
120                },
121                limit: input.rule.max,
122                remaining: input.rule.max.saturating_sub(count),
123                reset_after: ceil_millis_to_seconds(retry_ms),
124            })
125        })
126    }
127}
128
/// Converts a millisecond duration to whole seconds, rounding any fractional second up.
///
/// Zero and negative durations map to `0`.
fn ceil_millis_to_seconds(milliseconds: i64) -> u64 {
    match u64::try_from(milliseconds) {
        // Negative input fails the conversion; zero needs no rounding.
        Err(_) | Ok(0) => 0,
        // Add one extra second when there is a partial second left over.
        Ok(ms) => ms / 1000 + u64::from(ms % 1000 != 0),
    }
}
135
/// Version of this crate, taken from Cargo's `CARGO_PKG_VERSION` at compile time.
pub const VERSION: &str = env!("CARGO_PKG_VERSION");
138
#[cfg(test)]
mod tests {
    use super::*;

    /// Both Valkey schemes map onto their Redis counterparts, keeping host and port.
    #[test]
    fn normalizes_valkey_urls_to_redis_urls() {
        let cases = [
            ("valkey://localhost:6379", "redis://localhost:6379"),
            ("valkeys://localhost:6380", "rediss://localhost:6380"),
        ];
        for &(input, expected) in &cases {
            assert_eq!(normalize_redis_url(input).as_ref(), expected);
        }
    }

    /// URLs that do not use a Valkey scheme pass through untouched.
    #[test]
    fn leaves_non_valkey_urls_unchanged() {
        let urls = [
            "redis://localhost:6379",
            "rediss://localhost:6380",
            "unix:///tmp/redis.sock",
        ];
        for &url in &urls {
            assert_eq!(normalize_redis_url(url).as_ref(), url);
        }
    }

    /// Guards against regressing from `HSET` back to `HMSET` in the Lua script.
    #[test]
    fn rate_limit_script_uses_current_hash_set_command() {
        let script = RATE_LIMIT_SCRIPT;
        assert!(script.contains("HSET"));
        assert!(!script.contains("HMSET"));
    }
}