Skip to main content

authx_plugins/
redis_token_store.rs

1//! Redis-backed one-time token store — use with the `redis-tokens` feature.
2//!
3//! Drop-in replacement for [`OneTimeTokenStore`] for multi-instance deployments.
4//! Tokens are stored with Redis `SET … EX` and consumed atomically with a Lua
5//! GET+DEL script.
6#[cfg(feature = "redis-tokens")]
7mod inner {
8    use redis::{AsyncCommands, Client, Script, aio::MultiplexedConnection};
9    use uuid::Uuid;
10
11    use crate::one_time_token::TokenKind;
12    use authx_core::crypto::sha256_hex;
13    use authx_core::error::{AuthError, Result};
14
15    /// Serialization envelope stored in Redis.
16    #[derive(serde::Serialize, serde::Deserialize)]
17    struct RedisRecord {
18        kind: u8,
19        user_id: Uuid,
20    }
21
22    fn kind_byte(k: &TokenKind) -> u8 {
23        match k {
24            TokenKind::PasswordReset => 0,
25            TokenKind::MagicLink => 1,
26            TokenKind::EmailVerification => 2,
27            TokenKind::EmailOtp => 3,
28        }
29    }
30
31    fn kind_from_byte(b: u8) -> Option<TokenKind> {
32        match b {
33            0 => Some(TokenKind::PasswordReset),
34            1 => Some(TokenKind::MagicLink),
35            2 => Some(TokenKind::EmailVerification),
36            3 => Some(TokenKind::EmailOtp),
37            _ => None,
38        }
39    }
40
41    /// Redis-backed single-use token store.
42    ///
43    /// # Usage
44    /// ```rust,ignore
45    /// let store = RedisTokenStore::connect("redis://127.0.0.1/").await?;
46    /// let token = store.issue(user_id, TokenKind::MagicLink, 900).await?;
47    /// let uid   = store.consume(&token, TokenKind::MagicLink).await?;
48    /// ```
49    #[derive(Clone)]
50    pub struct RedisTokenStore {
51        client: Client,
52    }
53
54    impl RedisTokenStore {
55        pub async fn connect(redis_url: &str) -> Result<Self> {
56            let client = Client::open(redis_url)
57                .map_err(|e| AuthError::Internal(format!("redis connect: {e}")))?;
58            tracing::info!("redis token store connected");
59            Ok(Self { client })
60        }
61
62        async fn conn(&self) -> Result<MultiplexedConnection> {
63            self.client
64                .get_multiplexed_async_connection()
65                .await
66                .map_err(|e| AuthError::Internal(format!("redis connection: {e}")))
67        }
68
69        /// Issue a token with `ttl_seconds` expiry. Returns the raw (un-hashed) token.
70        pub async fn issue(
71            &self,
72            user_id: Uuid,
73            kind: TokenKind,
74            ttl_seconds: u64,
75        ) -> Result<String> {
76            let raw: [u8; 32] = rand::Rng::r#gen(&mut rand::thread_rng());
77            let token = hex::encode(raw);
78            let hash = sha256_hex(token.as_bytes());
79
80            let record = RedisRecord {
81                kind: kind_byte(&kind),
82                user_id,
83            };
84            let json = serde_json::to_string(&record)
85                .map_err(|e| AuthError::Internal(format!("redis token serialize: {e}")))?;
86
87            let mut conn = self.conn().await?;
88            let _: () = conn
89                .set_ex(&hash, json, ttl_seconds)
90                .await
91                .map_err(|e| AuthError::Internal(format!("redis SET: {e}")))?;
92
93            tracing::debug!(user_id = %user_id, "redis: one-time token issued");
94            Ok(token)
95        }
96
97        /// Consume a token atomically (Lua GET+DEL). Returns `None` if the token
98        /// is expired, not found, or the wrong kind.
99        pub async fn consume(
100            &self,
101            raw_token: &str,
102            expected_kind: TokenKind,
103        ) -> Result<Option<Uuid>> {
104            let hash = sha256_hex(raw_token.as_bytes());
105
106            // Atomic GET+DEL via Lua so no other replica can consume the same token.
107            let lua = Script::new(
108                r#"
109                local val = redis.call('GET', KEYS[1])
110                if val == false then return nil end
111                redis.call('DEL', KEYS[1])
112                return val
113                "#,
114            );
115
116            let mut conn = self.conn().await?;
117            let raw_json: Option<String> = lua
118                .key(&hash)
119                .invoke_async(&mut conn)
120                .await
121                .map_err(|e| AuthError::Internal(format!("redis lua: {e}")))?;
122
123            let json = match raw_json {
124                Some(j) => j,
125                None => {
126                    tracing::debug!("redis: token not found or expired");
127                    return Ok(None);
128                }
129            };
130
131            let record: RedisRecord = serde_json::from_str(&json)
132                .map_err(|e| AuthError::Internal(format!("redis token deserialize: {e}")))?;
133
134            if kind_from_byte(record.kind).as_ref() != Some(&expected_kind) {
135                tracing::debug!("redis: token kind mismatch");
136                return Ok(None);
137            }
138
139            tracing::debug!(user_id = %record.user_id, "redis: one-time token consumed");
140            Ok(Some(record.user_id))
141        }
142    }
143}
144
145#[cfg(feature = "redis-tokens")]
146pub use inner::RedisTokenStore;