rain_engine_store_valkey/
lib.rs1use async_trait::async_trait;
7use rain_engine_core::{
8 CoordinationClaim, CoordinationError, CoordinationStore, InMemoryCoordinationStore,
9};
10use redis::{Commands, Connection, cmd};
11use serde_json::Value;
12use std::time::Duration;
13
14#[derive(Clone)]
15pub enum ValkeyBackend {
16 Redis(redis::Client),
17 InMemory(InMemoryCoordinationStore),
18}
19
20#[derive(Clone)]
21pub struct ValkeyCoordinationStore {
22 namespace: String,
23 backend: ValkeyBackend,
24}
25
26impl ValkeyCoordinationStore {
27 pub fn connect(url: &str) -> Result<Self, CoordinationError> {
28 Self::connect_with_namespace(url, "rain_engine")
29 }
30
31 pub fn connect_with_namespace(
32 url: &str,
33 namespace: impl Into<String>,
34 ) -> Result<Self, CoordinationError> {
35 let namespace = namespace.into();
36 let backend = if url == "memory://" {
37 ValkeyBackend::InMemory(InMemoryCoordinationStore::new())
38 } else {
39 ValkeyBackend::Redis(
40 redis::Client::open(url).map_err(|err| CoordinationError::new(err.to_string()))?,
41 )
42 };
43 Ok(Self { namespace, backend })
44 }
45
46 fn claim_key(&self, trigger_key: &str) -> String {
47 format!("{}:claim:{trigger_key}", self.namespace)
48 }
49
50 fn reverse_claim_key(&self, claim_id: &str) -> String {
51 format!("{}:claim_by_id:{claim_id}", self.namespace)
52 }
53
54 fn scratchpad_key(&self, namespace: &str, key: &str) -> String {
55 format!("{}:scratchpad:{namespace}:{key}", self.namespace)
56 }
57
58 async fn with_connection<T, F>(&self, operation: F) -> Result<T, CoordinationError>
59 where
60 T: Send + 'static,
61 F: FnOnce(Connection) -> Result<T, CoordinationError> + Send + 'static,
62 {
63 let ValkeyBackend::Redis(client) = &self.backend else {
64 return Err(CoordinationError::new("redis backend not configured"));
65 };
66 let client = client.clone();
67 tokio::task::spawn_blocking(move || {
68 let connection = client
69 .get_connection()
70 .map_err(|err| CoordinationError::new(err.to_string()))?;
71 operation(connection)
72 })
73 .await
74 .map_err(|err| CoordinationError::new(err.to_string()))?
75 }
76}
77
78#[async_trait]
79impl CoordinationStore for ValkeyCoordinationStore {
80 async fn claim_trigger(
81 &self,
82 trigger_key: &str,
83 ttl: Duration,
84 ) -> Result<Option<CoordinationClaim>, CoordinationError> {
85 if let ValkeyBackend::InMemory(store) = &self.backend {
86 return store.claim_trigger(trigger_key, ttl).await;
87 }
88
89 let claim_key = self.claim_key(trigger_key);
90 let claim_id = uuid::Uuid::new_v4().to_string();
91 let reverse_key = self.reverse_claim_key(&claim_id);
92 let ttl_ms = ttl.as_millis() as usize;
93 let trigger_key_owned = trigger_key.to_string();
94
95 self.with_connection(move |mut connection| {
96 let acquired: Option<String> = cmd("SET")
97 .arg(&claim_key)
98 .arg(&claim_id)
99 .arg("NX")
100 .arg("PX")
101 .arg(ttl_ms)
102 .query(&mut connection)
103 .map_err(|err| CoordinationError::new(err.to_string()))?;
104 if acquired.is_none() {
105 return Ok(None);
106 }
107 let _: () = cmd("SET")
108 .arg(&reverse_key)
109 .arg(&trigger_key_owned)
110 .arg("PX")
111 .arg(ttl_ms)
112 .query(&mut connection)
113 .map_err(|err| CoordinationError::new(err.to_string()))?;
114 Ok(Some(CoordinationClaim {
115 claim_id,
116 trigger_key: trigger_key_owned,
117 expires_at: std::time::SystemTime::now() + ttl,
118 }))
119 })
120 .await
121 }
122
123 async fn renew_claim(
124 &self,
125 claim_id: &str,
126 ttl: Duration,
127 ) -> Result<Option<CoordinationClaim>, CoordinationError> {
128 if let ValkeyBackend::InMemory(store) = &self.backend {
129 return store.renew_claim(claim_id, ttl).await;
130 }
131
132 let reverse_key = self.reverse_claim_key(claim_id);
133 let claim_id = claim_id.to_string();
134 let ttl_ms = ttl.as_millis() as usize;
135 let namespace = self.namespace.clone();
136
137 self.with_connection(move |mut connection| {
138 let trigger_key: Option<String> = connection
139 .get(&reverse_key)
140 .map_err(|err| CoordinationError::new(err.to_string()))?;
141 let Some(trigger_key) = trigger_key else {
142 return Ok(None);
143 };
144 let claim_key = format!("{namespace}:claim:{trigger_key}");
145 let current_claim: Option<String> = connection
146 .get(&claim_key)
147 .map_err(|err| CoordinationError::new(err.to_string()))?;
148 if current_claim.as_deref() != Some(claim_id.as_str()) {
149 return Ok(None);
150 }
151 let _: bool = cmd("PEXPIRE")
152 .arg(&claim_key)
153 .arg(ttl_ms)
154 .query(&mut connection)
155 .map_err(|err| CoordinationError::new(err.to_string()))?;
156 let _: bool = cmd("PEXPIRE")
157 .arg(&reverse_key)
158 .arg(ttl_ms)
159 .query(&mut connection)
160 .map_err(|err| CoordinationError::new(err.to_string()))?;
161 Ok(Some(CoordinationClaim {
162 claim_id,
163 trigger_key,
164 expires_at: std::time::SystemTime::now() + ttl,
165 }))
166 })
167 .await
168 }
169
170 async fn release_claim(&self, claim_id: &str) -> Result<(), CoordinationError> {
171 if let ValkeyBackend::InMemory(store) = &self.backend {
172 return store.release_claim(claim_id).await;
173 }
174
175 let reverse_key = self.reverse_claim_key(claim_id);
176 let claim_id = claim_id.to_string();
177 let namespace = self.namespace.clone();
178 self.with_connection(move |mut connection| {
179 let trigger_key: Option<String> = connection
180 .get(&reverse_key)
181 .map_err(|err| CoordinationError::new(err.to_string()))?;
182 if let Some(trigger_key) = trigger_key {
183 let claim_key = format!("{namespace}:claim:{trigger_key}");
184 let current_claim: Option<String> = connection
185 .get(&claim_key)
186 .map_err(|err| CoordinationError::new(err.to_string()))?;
187 if current_claim.as_deref() == Some(claim_id.as_str()) {
188 let _: usize = connection
189 .del(&claim_key)
190 .map_err(|err| CoordinationError::new(err.to_string()))?;
191 }
192 }
193 let _: usize = connection
194 .del(&reverse_key)
195 .map_err(|err| CoordinationError::new(err.to_string()))?;
196 Ok(())
197 })
198 .await
199 }
200
201 async fn scratchpad_get(
202 &self,
203 namespace: &str,
204 key: &str,
205 ) -> Result<Option<Value>, CoordinationError> {
206 if let ValkeyBackend::InMemory(store) = &self.backend {
207 return store.scratchpad_get(namespace, key).await;
208 }
209
210 let scratchpad_key = self.scratchpad_key(namespace, key);
211 self.with_connection(move |mut connection| {
212 let value: Option<String> = connection
213 .get(&scratchpad_key)
214 .map_err(|err| CoordinationError::new(err.to_string()))?;
215 value
216 .map(|value| {
217 serde_json::from_str(&value)
218 .map_err(|err| CoordinationError::new(err.to_string()))
219 })
220 .transpose()
221 })
222 .await
223 }
224
225 async fn scratchpad_set(
226 &self,
227 namespace: &str,
228 key: &str,
229 value: Value,
230 ttl: Duration,
231 ) -> Result<(), CoordinationError> {
232 if let ValkeyBackend::InMemory(store) = &self.backend {
233 return store.scratchpad_set(namespace, key, value, ttl).await;
234 }
235
236 let scratchpad_key = self.scratchpad_key(namespace, key);
237 let value =
238 serde_json::to_string(&value).map_err(|err| CoordinationError::new(err.to_string()))?;
239 let ttl_secs = ttl.as_secs().max(1);
240 self.with_connection(move |mut connection| {
241 let _: () = cmd("SETEX")
242 .arg(&scratchpad_key)
243 .arg(ttl_secs)
244 .arg(value)
245 .query(&mut connection)
246 .map_err(|err| CoordinationError::new(err.to_string()))?;
247 Ok(())
248 })
249 .await
250 }
251}
252
253#[cfg(test)]
254mod tests {
255 use super::*;
256 use rain_engine_core::CoordinationStore;
257
258 #[tokio::test]
259 async fn in_memory_backend_supports_claims_and_scratchpad() {
260 let store = ValkeyCoordinationStore::connect("memory://").expect("store");
261 let claim = store
262 .claim_trigger("trigger-1", Duration::from_secs(30))
263 .await
264 .expect("claim")
265 .expect("some claim");
266 assert!(
267 store
268 .claim_trigger("trigger-1", Duration::from_secs(30))
269 .await
270 .expect("second claim")
271 .is_none()
272 );
273 store
274 .scratchpad_set(
275 "ns",
276 "key",
277 serde_json::json!({"value": 1}),
278 Duration::from_secs(60),
279 )
280 .await
281 .expect("set");
282 assert_eq!(
283 store.scratchpad_get("ns", "key").await.expect("get"),
284 Some(serde_json::json!({"value": 1}))
285 );
286 store.release_claim(&claim.claim_id).await.expect("release");
287 }
288}