Skip to main content

rain_engine_store_valkey/
lib.rs

1//! Valkey-backed coordination and scratchpad storage for RainEngine workers.
2//!
3//! Valkey is used for distributed claims, leases, and short-lived key/value
4//! state. Durable agent state remains in the ledger store.
5
6use async_trait::async_trait;
7use rain_engine_core::{
8    CoordinationClaim, CoordinationError, CoordinationStore, InMemoryCoordinationStore,
9};
10use redis::{Commands, Connection, cmd};
11use serde_json::Value;
12use std::time::Duration;
13
/// Storage backend behind a [`ValkeyCoordinationStore`].
#[derive(Clone)]
pub enum ValkeyBackend {
    /// A `redis::Client` built from the URL given at construction time.
    Redis(redis::Client),
    /// The `InMemoryCoordinationStore` from `rain_engine_core`, selected when
    /// the store is constructed with the `memory://` URL.
    InMemory(InMemoryCoordinationStore),
}
19
/// Coordination store that keeps claims and scratchpad values under a
/// namespaced key prefix, delegating to either backend in [`ValkeyBackend`].
#[derive(Clone)]
pub struct ValkeyCoordinationStore {
    // Prefix placed in front of every key this store builds.
    namespace: String,
    // Where operations are actually executed.
    backend: ValkeyBackend,
}
25
26impl ValkeyCoordinationStore {
27    pub fn connect(url: &str) -> Result<Self, CoordinationError> {
28        Self::connect_with_namespace(url, "rain_engine")
29    }
30
31    pub fn connect_with_namespace(
32        url: &str,
33        namespace: impl Into<String>,
34    ) -> Result<Self, CoordinationError> {
35        let namespace = namespace.into();
36        let backend = if url == "memory://" {
37            ValkeyBackend::InMemory(InMemoryCoordinationStore::new())
38        } else {
39            ValkeyBackend::Redis(
40                redis::Client::open(url).map_err(|err| CoordinationError::new(err.to_string()))?,
41            )
42        };
43        Ok(Self { namespace, backend })
44    }
45
46    fn claim_key(&self, trigger_key: &str) -> String {
47        format!("{}:claim:{trigger_key}", self.namespace)
48    }
49
50    fn reverse_claim_key(&self, claim_id: &str) -> String {
51        format!("{}:claim_by_id:{claim_id}", self.namespace)
52    }
53
54    fn scratchpad_key(&self, namespace: &str, key: &str) -> String {
55        format!("{}:scratchpad:{namespace}:{key}", self.namespace)
56    }
57
58    async fn with_connection<T, F>(&self, operation: F) -> Result<T, CoordinationError>
59    where
60        T: Send + 'static,
61        F: FnOnce(Connection) -> Result<T, CoordinationError> + Send + 'static,
62    {
63        let ValkeyBackend::Redis(client) = &self.backend else {
64            return Err(CoordinationError::new("redis backend not configured"));
65        };
66        let client = client.clone();
67        tokio::task::spawn_blocking(move || {
68            let connection = client
69                .get_connection()
70                .map_err(|err| CoordinationError::new(err.to_string()))?;
71            operation(connection)
72        })
73        .await
74        .map_err(|err| CoordinationError::new(err.to_string()))?
75    }
76}
77
78#[async_trait]
79impl CoordinationStore for ValkeyCoordinationStore {
80    async fn claim_trigger(
81        &self,
82        trigger_key: &str,
83        ttl: Duration,
84    ) -> Result<Option<CoordinationClaim>, CoordinationError> {
85        if let ValkeyBackend::InMemory(store) = &self.backend {
86            return store.claim_trigger(trigger_key, ttl).await;
87        }
88
89        let claim_key = self.claim_key(trigger_key);
90        let claim_id = uuid::Uuid::new_v4().to_string();
91        let reverse_key = self.reverse_claim_key(&claim_id);
92        let ttl_ms = ttl.as_millis() as usize;
93        let trigger_key_owned = trigger_key.to_string();
94
95        self.with_connection(move |mut connection| {
96            let acquired: Option<String> = cmd("SET")
97                .arg(&claim_key)
98                .arg(&claim_id)
99                .arg("NX")
100                .arg("PX")
101                .arg(ttl_ms)
102                .query(&mut connection)
103                .map_err(|err| CoordinationError::new(err.to_string()))?;
104            if acquired.is_none() {
105                return Ok(None);
106            }
107            let _: () = cmd("SET")
108                .arg(&reverse_key)
109                .arg(&trigger_key_owned)
110                .arg("PX")
111                .arg(ttl_ms)
112                .query(&mut connection)
113                .map_err(|err| CoordinationError::new(err.to_string()))?;
114            Ok(Some(CoordinationClaim {
115                claim_id,
116                trigger_key: trigger_key_owned,
117                expires_at: std::time::SystemTime::now() + ttl,
118            }))
119        })
120        .await
121    }
122
123    async fn renew_claim(
124        &self,
125        claim_id: &str,
126        ttl: Duration,
127    ) -> Result<Option<CoordinationClaim>, CoordinationError> {
128        if let ValkeyBackend::InMemory(store) = &self.backend {
129            return store.renew_claim(claim_id, ttl).await;
130        }
131
132        let reverse_key = self.reverse_claim_key(claim_id);
133        let claim_id = claim_id.to_string();
134        let ttl_ms = ttl.as_millis() as usize;
135        let namespace = self.namespace.clone();
136
137        self.with_connection(move |mut connection| {
138            let trigger_key: Option<String> = connection
139                .get(&reverse_key)
140                .map_err(|err| CoordinationError::new(err.to_string()))?;
141            let Some(trigger_key) = trigger_key else {
142                return Ok(None);
143            };
144            let claim_key = format!("{namespace}:claim:{trigger_key}");
145            let current_claim: Option<String> = connection
146                .get(&claim_key)
147                .map_err(|err| CoordinationError::new(err.to_string()))?;
148            if current_claim.as_deref() != Some(claim_id.as_str()) {
149                return Ok(None);
150            }
151            let _: bool = cmd("PEXPIRE")
152                .arg(&claim_key)
153                .arg(ttl_ms)
154                .query(&mut connection)
155                .map_err(|err| CoordinationError::new(err.to_string()))?;
156            let _: bool = cmd("PEXPIRE")
157                .arg(&reverse_key)
158                .arg(ttl_ms)
159                .query(&mut connection)
160                .map_err(|err| CoordinationError::new(err.to_string()))?;
161            Ok(Some(CoordinationClaim {
162                claim_id,
163                trigger_key,
164                expires_at: std::time::SystemTime::now() + ttl,
165            }))
166        })
167        .await
168    }
169
170    async fn release_claim(&self, claim_id: &str) -> Result<(), CoordinationError> {
171        if let ValkeyBackend::InMemory(store) = &self.backend {
172            return store.release_claim(claim_id).await;
173        }
174
175        let reverse_key = self.reverse_claim_key(claim_id);
176        let claim_id = claim_id.to_string();
177        let namespace = self.namespace.clone();
178        self.with_connection(move |mut connection| {
179            let trigger_key: Option<String> = connection
180                .get(&reverse_key)
181                .map_err(|err| CoordinationError::new(err.to_string()))?;
182            if let Some(trigger_key) = trigger_key {
183                let claim_key = format!("{namespace}:claim:{trigger_key}");
184                let current_claim: Option<String> = connection
185                    .get(&claim_key)
186                    .map_err(|err| CoordinationError::new(err.to_string()))?;
187                if current_claim.as_deref() == Some(claim_id.as_str()) {
188                    let _: usize = connection
189                        .del(&claim_key)
190                        .map_err(|err| CoordinationError::new(err.to_string()))?;
191                }
192            }
193            let _: usize = connection
194                .del(&reverse_key)
195                .map_err(|err| CoordinationError::new(err.to_string()))?;
196            Ok(())
197        })
198        .await
199    }
200
201    async fn scratchpad_get(
202        &self,
203        namespace: &str,
204        key: &str,
205    ) -> Result<Option<Value>, CoordinationError> {
206        if let ValkeyBackend::InMemory(store) = &self.backend {
207            return store.scratchpad_get(namespace, key).await;
208        }
209
210        let scratchpad_key = self.scratchpad_key(namespace, key);
211        self.with_connection(move |mut connection| {
212            let value: Option<String> = connection
213                .get(&scratchpad_key)
214                .map_err(|err| CoordinationError::new(err.to_string()))?;
215            value
216                .map(|value| {
217                    serde_json::from_str(&value)
218                        .map_err(|err| CoordinationError::new(err.to_string()))
219                })
220                .transpose()
221        })
222        .await
223    }
224
225    async fn scratchpad_set(
226        &self,
227        namespace: &str,
228        key: &str,
229        value: Value,
230        ttl: Duration,
231    ) -> Result<(), CoordinationError> {
232        if let ValkeyBackend::InMemory(store) = &self.backend {
233            return store.scratchpad_set(namespace, key, value, ttl).await;
234        }
235
236        let scratchpad_key = self.scratchpad_key(namespace, key);
237        let value =
238            serde_json::to_string(&value).map_err(|err| CoordinationError::new(err.to_string()))?;
239        let ttl_secs = ttl.as_secs().max(1);
240        self.with_connection(move |mut connection| {
241            let _: () = cmd("SETEX")
242                .arg(&scratchpad_key)
243                .arg(ttl_secs)
244                .arg(value)
245                .query(&mut connection)
246                .map_err(|err| CoordinationError::new(err.to_string()))?;
247            Ok(())
248        })
249        .await
250    }
251}
252
#[cfg(test)]
mod tests {
    use super::*;
    use rain_engine_core::CoordinationStore;

    /// Drives the `memory://` backend through claim, contended claim,
    /// scratchpad round trip, and release.
    #[tokio::test]
    async fn in_memory_backend_supports_claims_and_scratchpad() {
        let store = ValkeyCoordinationStore::connect("memory://").expect("store");
        let lease = Duration::from_secs(30);

        // The first claim on the trigger is expected to be granted.
        let first = store
            .claim_trigger("trigger-1", lease)
            .await
            .expect("claim")
            .expect("some claim");

        // While that claim is held, a second attempt must come back empty.
        let second = store
            .claim_trigger("trigger-1", lease)
            .await
            .expect("second claim");
        assert!(second.is_none());

        // A scratchpad value must read back exactly as written.
        let payload = serde_json::json!({"value": 1});
        store
            .scratchpad_set("ns", "key", payload.clone(), Duration::from_secs(60))
            .await
            .expect("set");
        let fetched = store.scratchpad_get("ns", "key").await.expect("get");
        assert_eq!(fetched, Some(payload));

        store.release_claim(&first.claim_id).await.expect("release");
    }
}