// authx_plugins/redis_token_store.rs
#[cfg(feature = "redis-tokens")]
mod inner {
    use std::sync::LazyLock;

    use redis::{AsyncCommands, Client, Script, aio::MultiplexedConnection};
    use uuid::Uuid;

    use crate::one_time_token::TokenKind;
    use authx_core::crypto::sha256_hex;
    use authx_core::error::{AuthError, Result};

    /// Server-side get-and-delete used by [`RedisTokenStore::consume`].
    ///
    /// The GET and DEL run inside one Lua script, so Redis executes them as a
    /// single atomic step. Concurrent `consume` calls for the same token cannot
    /// both observe the value.
    ///
    /// The script is built once and shared, because `Script::new` hashes the
    /// script source on every construction.
    static CONSUME_SCRIPT: LazyLock<Script> = LazyLock::new(|| {
        Script::new(
            r#"
            local val = redis.call('GET', KEYS[1])
            if val == false then return nil end
            redis.call('DEL', KEYS[1])
            return val
            "#,
        )
    });

    /// JSON payload stored in Redis under the SHA-256 hex of each token.
    #[derive(serde::Serialize, serde::Deserialize)]
    struct RedisRecord {
        /// Encoded [`TokenKind`] (see `kind_byte` / `kind_from_byte`).
        kind: u8,
        /// User the token was issued for.
        user_id: Uuid,
    }

    /// Encodes a [`TokenKind`] as the byte persisted in [`RedisRecord`].
    ///
    /// These values are written to Redis. Changing an existing mapping would
    /// make already-issued tokens decode as a different (or unknown) kind.
    fn kind_byte(k: &TokenKind) -> u8 {
        match k {
            TokenKind::PasswordReset => 0,
            TokenKind::MagicLink => 1,
            TokenKind::EmailVerification => 2,
            TokenKind::EmailOtp => 3,
        }
    }

    /// Inverse of [`kind_byte`]; returns `None` for bytes it does not produce.
    fn kind_from_byte(b: u8) -> Option<TokenKind> {
        match b {
            0 => Some(TokenKind::PasswordReset),
            1 => Some(TokenKind::MagicLink),
            2 => Some(TokenKind::EmailVerification),
            3 => Some(TokenKind::EmailOtp),
            _ => None,
        }
    }

    /// One-time token store backed by Redis.
    ///
    /// Only the SHA-256 hex of each token is stored, as the key. The value is a
    /// JSON [`RedisRecord`] and expiry is handled by the Redis key TTL.
    #[derive(Clone)]
    pub struct RedisTokenStore {
        client: Client,
    }

    impl RedisTokenStore {
        /// Creates a store for the Redis server at `redis_url` and verifies
        /// that the server is reachable.
        ///
        /// # Errors
        /// Returns [`AuthError::Internal`] if the URL is invalid or no
        /// connection to the server can be established.
        pub async fn connect(redis_url: &str) -> Result<Self> {
            let client = Client::open(redis_url)
                .map_err(|e| AuthError::Internal(format!("redis connect: {e}")))?;
            let store = Self { client };
            // `Client::open` only parses the URL. Open one real connection so
            // an unreachable or misconfigured server fails here instead of on
            // the first token operation, and so the log line below is true.
            let _ = store.conn().await?;
            tracing::info!("redis token store connected");
            Ok(store)
        }

        /// Opens a multiplexed async connection to the server.
        ///
        /// A new connection is opened on each call.
        // NOTE(review): a shared, auto-reconnecting connection (e.g. redis-rs
        // `ConnectionManager`) would avoid a handshake per token operation;
        // confirm the crate feature is available before switching.
        async fn conn(&self) -> Result<MultiplexedConnection> {
            self.client
                .get_multiplexed_async_connection()
                .await
                .map_err(|e| AuthError::Internal(format!("redis connection: {e}")))
        }

        /// Issues a one-time token of `kind` for `user_id`, valid for
        /// `ttl_seconds`.
        ///
        /// The raw token is 32 random bytes, hex-encoded (64 characters). It is
        /// returned to the caller and never stored; Redis only keeps its
        /// SHA-256 hex.
        ///
        /// # Errors
        /// Returns [`AuthError::Internal`] on serialization or Redis failures.
        /// This includes Redis rejecting the expiry (e.g. a zero TTL).
        pub async fn issue(
            &self,
            user_id: Uuid,
            kind: TokenKind,
            ttl_seconds: u64,
        ) -> Result<String> {
            let raw: [u8; 32] = rand::Rng::r#gen(&mut rand::thread_rng());
            let token = hex::encode(raw);
            let hash = sha256_hex(token.as_bytes());

            let record = RedisRecord {
                kind: kind_byte(&kind),
                user_id,
            };
            let json = serde_json::to_string(&record)
                .map_err(|e| AuthError::Internal(format!("redis token serialize: {e}")))?;

            let mut conn = self.conn().await?;
            let _: () = conn
                .set_ex(&hash, json, ttl_seconds)
                .await
                .map_err(|e| AuthError::Internal(format!("redis SET: {e}")))?;

            tracing::debug!(user_id = %user_id, "redis: one-time token issued");
            Ok(token)
        }

        /// Redeems `raw_token`, returning the user it was issued for.
        ///
        /// The token is deleted atomically as it is read, so it can succeed at
        /// most once.
        ///
        /// Returns `Ok(None)` when the token is unknown or expired, or when its
        /// stored kind differs from `expected_kind`. In the mismatch case the
        /// token has already been deleted and cannot be reused.
        ///
        /// # Errors
        /// Returns [`AuthError::Internal`] on Redis failures, or if the stored
        /// record cannot be decoded. In the decode case the key has already
        /// been deleted.
        pub async fn consume(
            &self,
            raw_token: &str,
            expected_kind: TokenKind,
        ) -> Result<Option<Uuid>> {
            let hash = sha256_hex(raw_token.as_bytes());

            let mut conn = self.conn().await?;
            let raw_json: Option<String> = CONSUME_SCRIPT
                .key(&hash)
                .invoke_async(&mut conn)
                .await
                .map_err(|e| AuthError::Internal(format!("redis lua: {e}")))?;

            let json = match raw_json {
                Some(j) => j,
                None => {
                    tracing::debug!("redis: token not found or expired");
                    return Ok(None);
                }
            };

            let record: RedisRecord = serde_json::from_str(&json)
                .map_err(|e| AuthError::Internal(format!("redis token deserialize: {e}")))?;

            if kind_from_byte(record.kind).as_ref() != Some(&expected_kind) {
                tracing::debug!("redis: token kind mismatch");
                return Ok(None);
            }

            tracing::debug!(user_id = %record.user_id, "redis: one-time token consumed");
            Ok(Some(record.user_id))
        }
    }
}
144
// Public entry point; only compiled when the `redis-tokens` feature is enabled.
#[cfg(feature = "redis-tokens")]
pub use inner::RedisTokenStore;