klauthed_data/locks/
redis.rs1use async_trait::async_trait;
21use klauthed_core::time::Duration;
22use redis::aio::ConnectionManager;
23use redis::{ExistenceCheck, SetExpiry, SetOptions};
24
25use crate::error::DataError;
26use crate::locks::{LockGuard, LockManager, LockToken};
27
/// Lua script that deletes a lock key only when it still holds the caller's token.
///
/// The script reads the value stored at `KEYS[1]`. If it equals `ARGV[1]` the key
/// is removed and the result of `DEL` is returned; otherwise the script returns
/// `0` and leaves the key untouched.
const RELEASE_SCRIPT: &str = "\
if redis.call('GET', KEYS[1]) == ARGV[1] then
 return redis.call('DEL', KEYS[1])
else
 return 0
end";
36
/// Lock manager that stores its locks in Redis.
///
/// The type is `Clone`, so a copy of the manager can be handed to every guard it
/// creates.
#[derive(Clone)]
pub struct RedisLockManager {
    /// Redis connection handle; cloned before each command is sent.
    conn: ConnectionManager,
}
44
45impl RedisLockManager {
46 pub fn new(conn: ConnectionManager) -> Self {
48 Self { conn }
49 }
50
51 pub async fn release_token(&self, key: &str, token: LockToken) -> Result<bool, DataError> {
61 let mut conn = self.conn.clone();
62 let deleted: i64 = redis::Script::new(RELEASE_SCRIPT)
63 .key(key)
64 .arg(token.to_string())
65 .invoke_async(&mut conn)
66 .await?;
67 Ok(deleted == 1)
68 }
69}
70
71#[async_trait]
72impl LockManager for RedisLockManager {
73 async fn acquire(&self, key: &str, ttl: Duration) -> Result<Option<LockGuard>, DataError> {
74 let ttl_ms: u64 = ttl.whole_milliseconds().try_into().map_err(|_| {
75 DataError::LockHeld(format!("invalid (non-positive) TTL for lock '{key}'"))
76 })?;
77 if ttl_ms == 0 {
78 return Err(DataError::LockHeld(format!("invalid (zero) TTL for lock '{key}'")));
79 }
80
81 let token = LockToken::new();
82 let options = SetOptions::default()
83 .conditional_set(ExistenceCheck::NX)
84 .with_expiration(SetExpiry::PX(ttl_ms));
85
86 let mut conn = self.conn.clone();
87 let outcome: Option<String> = redis::cmd("SET")
90 .arg(key)
91 .arg(token.to_string())
92 .arg(&options)
93 .query_async(&mut conn)
94 .await?;
95
96 match outcome {
97 Some(_) => Ok(Some(LockGuard::redis(key.to_owned(), token, self.clone()))),
98 None => Ok(None),
99 }
100 }
101}
102
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a manager connected to the Redis at `REDIS_URL`, falling back to
    /// `redis://127.0.0.1/` when the variable cannot be read.
    async fn live_manager() -> RedisLockManager {
        let url = match std::env::var("REDIS_URL") {
            Ok(value) => value,
            Err(_) => "redis://127.0.0.1/".to_owned(),
        };
        let client = redis::Client::open(url).expect("open redis client");
        let conn = ConnectionManager::new(client).await.expect("connect redis");
        RedisLockManager::new(conn)
    }

    /// A held lock rejects a second acquire until the first guard is released.
    #[tokio::test]
    #[ignore = "requires a live Redis at REDIS_URL"]
    async fn acquire_blocks_until_released() {
        let manager = live_manager().await;
        // A fresh token in the key keeps separate runs from colliding.
        let lock_key = format!("klauthed:test:lock:{}", LockToken::new());

        let holder = manager
            .acquire(&lock_key, Duration::seconds(30))
            .await
            .unwrap()
            .expect("first acquire wins");

        let contended = manager.acquire(&lock_key, Duration::seconds(30)).await.unwrap();
        assert!(contended.is_none());

        holder.release().await.unwrap();

        assert!(manager.acquire(&lock_key, Duration::seconds(30)).await.unwrap().is_some());
    }

    /// Releasing with an expired holder's token must not delete a newer holder's lock.
    #[tokio::test]
    #[ignore = "requires a live Redis at REDIS_URL"]
    async fn stale_token_release_does_not_steal() {
        let manager = live_manager().await;
        let lock_key = format!("klauthed:test:lock:{}", LockToken::new());

        // Take the lock with a 50 ms TTL, then wait 120 ms so it has expired.
        let stale = manager.acquire(&lock_key, Duration::milliseconds(50)).await.unwrap().unwrap();
        let stale_token = stale.token();
        tokio::time::sleep(std::time::Duration::from_millis(120)).await;

        let _fresh = manager
            .acquire(&lock_key, Duration::seconds(30))
            .await
            .unwrap()
            .expect("fresh acquire after expiry");

        let freed = manager.release_token(&lock_key, stale_token).await.unwrap();
        assert!(!freed);
        // Skip the stale guard's destructor; the lock it stood for is already gone.
        std::mem::forget(stale);
    }
}