Skip to main content

scheduler/
valkey_coordinated_store.rs

1use crate::coordinated_store::{
2    CoordinatedClaim, CoordinatedLeaseConfig, CoordinatedPendingTrigger, CoordinatedRuntimeState,
3    CoordinatedStateStore,
4};
5use crate::error::{ExecutionGuardErrorKind, StoreErrorKind};
6use crate::execution_guard::{ExecutionGuardRenewal, ExecutionGuardScope, ExecutionLease};
7use crate::model::JobState;
8use crate::valkey_execution_support::{
9    next_token, now_millis, occurrence_index_key, occurrence_lease_key, resource_lock_key,
10};
11use crate::valkey_store::ValkeyStoreError;
12use chrono::SecondsFormat;
13use chrono::{DateTime, Utc};
14use redis::{AsyncCommands, Client, Script, aio::ConnectionManager, cmd};
15use std::collections::HashMap;
16use std::sync::atomic::AtomicU64;
17
18const DEFAULT_STATE_KEY_PREFIX: &str = "scheduler:valkey:job-state:";
19const LEGACY_DEFAULT_STATE_KEY_PREFIX: &str = "scheduler:job-state:";
20const DEFAULT_EXECUTION_KEY_PREFIX: &str = "scheduler:valkey:execution-lease:";
21const SAVE_STATE_LUA: &str = include_str!("lua/valkey_coordinated_store/save_state.lua");
22const RECLAIM_INFLIGHT_LUA: &str =
23    include_str!("lua/valkey_coordinated_store/reclaim_inflight.lua");
24const CLAIM_TRIGGER_LUA: &str = include_str!("lua/valkey_coordinated_store/claim_trigger.lua");
25const RENEW_LEASE_LUA: &str = include_str!("lua/valkey_coordinated_store/renew_lease.lua");
26const COMPLETE_LUA: &str = include_str!("lua/valkey_coordinated_store/complete.lua");
27const PAUSE_LUA: &str = include_str!("lua/valkey_coordinated_store/pause.lua");
28const RESUME_LUA: &str = include_str!("lua/valkey_coordinated_store/resume.lua");
29
30const FIELD_VERSION: &str = "version";
31const FIELD_STATE: &str = "state";
32const FIELD_PAUSED: &str = "paused";
33const FIELD_INFLIGHT_SCHEDULED_AT: &str = "inflight_scheduled_at";
34const FIELD_INFLIGHT_CATCH_UP: &str = "inflight_catch_up";
35const FIELD_INFLIGHT_TRIGGER_COUNT: &str = "inflight_trigger_count";
36const FIELD_INFLIGHT_RESOURCE_ID: &str = "inflight_resource_id";
37const FIELD_INFLIGHT_SCOPE: &str = "inflight_scope";
38const FIELD_INFLIGHT_TOKEN: &str = "inflight_token";
39const FIELD_INFLIGHT_LEASE_KEY: &str = "inflight_lease_key";
40const FIELD_INFLIGHT_LEASE_EXPIRES_AT: &str = "inflight_lease_expires_at";
41
42static COORDINATED_TOKEN_COUNTER: AtomicU64 = AtomicU64::new(1);
43
44#[derive(Debug, Clone)]
45pub struct ValkeyCoordinatedStateStore {
46    connection: ConnectionManager,
47    state_key_prefix: String,
48    execution_key_prefix: String,
49}
50
51impl ValkeyCoordinatedStateStore {
52    pub async fn new(url: impl AsRef<str>) -> Result<Self, redis::RedisError> {
53        Self::with_prefixes(url, DEFAULT_STATE_KEY_PREFIX, DEFAULT_EXECUTION_KEY_PREFIX).await
54    }
55
56    pub async fn with_prefixes(
57        url: impl AsRef<str>,
58        state_key_prefix: impl Into<String>,
59        execution_key_prefix: impl Into<String>,
60    ) -> Result<Self, redis::RedisError> {
61        let client = Client::open(url.as_ref())?;
62        let connection = client.get_connection_manager().await?;
63        Ok(Self {
64            connection,
65            state_key_prefix: state_key_prefix.into(),
66            execution_key_prefix: execution_key_prefix.into(),
67        })
68    }
69
70    fn state_key(&self, job_id: &str) -> String {
71        format!("{}{}", self.state_key_prefix, job_id)
72    }
73
74    fn legacy_state_key(&self, job_id: &str) -> Option<String> {
75        if self.state_key_prefix == DEFAULT_STATE_KEY_PREFIX {
76            Some(format!("{LEGACY_DEFAULT_STATE_KEY_PREFIX}{job_id}"))
77        } else {
78            None
79        }
80    }
81
82    fn resource_lock_key(&self, resource_id: &str) -> String {
83        resource_lock_key(&self.execution_key_prefix, resource_id)
84    }
85
86    fn occurrence_index_key(&self, resource_id: &str) -> String {
87        occurrence_index_key(&self.execution_key_prefix, resource_id)
88    }
89
90    fn occurrence_lease_key(&self, resource_id: &str, scheduled_at: DateTime<Utc>) -> String {
91        occurrence_lease_key(&self.execution_key_prefix, resource_id, scheduled_at)
92    }
93
94    async fn key_type(&self, key: &str) -> Result<String, ValkeyStoreError> {
95        let mut connection = self.connection.clone();
96        cmd("TYPE")
97            .arg(key)
98            .query_async(&mut connection)
99            .await
100            .map_err(ValkeyStoreError::from)
101    }
102
103    async fn load_hash(
104        &self,
105        key: &str,
106    ) -> Result<Option<CoordinatedRuntimeState>, ValkeyStoreError> {
107        let mut connection = self.connection.clone();
108        let fields: HashMap<String, String> = connection
109            .hgetall(key)
110            .await
111            .map_err(ValkeyStoreError::from)?;
112        if fields.is_empty() {
113            return Ok(None);
114        }
115
116        Ok(Some(parse_runtime_state(&fields)?))
117    }
118
119    async fn migrate_string_state(
120        &self,
121        key: &str,
122        payload: String,
123    ) -> Result<CoordinatedRuntimeState, ValkeyStoreError> {
124        let state: JobState = serde_json::from_str(&payload).map_err(ValkeyStoreError::from)?;
125        let runtime = CoordinatedRuntimeState {
126            state,
127            revision: 0,
128            paused: false,
129        };
130        self.write_runtime(key, &runtime).await?;
131        Ok(runtime)
132    }
133
134    async fn write_runtime(
135        &self,
136        key: &str,
137        runtime: &CoordinatedRuntimeState,
138    ) -> Result<(), ValkeyStoreError> {
139        let mut connection = self.connection.clone();
140        let payload = serde_json::to_string(&runtime.state).map_err(ValkeyStoreError::from)?;
141        let _: () = cmd("DEL")
142            .arg(key)
143            .query_async(&mut connection)
144            .await
145            .map_err(ValkeyStoreError::from)?;
146        let _: () = cmd("HSET")
147            .arg(key)
148            .arg(FIELD_VERSION)
149            .arg(runtime.revision)
150            .arg(FIELD_STATE)
151            .arg(payload)
152            .arg(FIELD_PAUSED)
153            .arg(if runtime.paused { "1" } else { "0" })
154            .query_async(&mut connection)
155            .await
156            .map_err(ValkeyStoreError::from)?;
157        Ok(())
158    }
159
160    async fn load_payload_state(&self, key: &str) -> Result<Option<String>, ValkeyStoreError> {
161        let mut connection = self.connection.clone();
162        connection.get(key).await.map_err(ValkeyStoreError::from)
163    }
164}
165
166impl CoordinatedStateStore for ValkeyCoordinatedStateStore {
167    type Error = ValkeyStoreError;
168
169    async fn load_or_initialize(
170        &self,
171        job_id: &str,
172        initial_state: JobState,
173    ) -> Result<CoordinatedRuntimeState, Self::Error> {
174        let key = self.state_key(job_id);
175        match self.key_type(&key).await?.as_str() {
176            "hash" => {
177                if let Some(runtime) = self.load_hash(&key).await? {
178                    return Ok(runtime);
179                }
180            }
181            "string" => {
182                if let Some(payload) = self.load_payload_state(&key).await? {
183                    return self.migrate_string_state(&key, payload).await;
184                }
185            }
186            "none" => {}
187            _ => {}
188        }
189
190        if let Some(legacy_key) = self.legacy_state_key(job_id) {
191            if self.key_type(&legacy_key).await?.as_str() == "string" {
192                if let Some(payload) = self.load_payload_state(&legacy_key).await? {
193                    let runtime = self.migrate_string_state(&key, payload).await?;
194                    let mut connection = self.connection.clone();
195                    let _: () = cmd("DEL")
196                        .arg(legacy_key)
197                        .query_async(&mut connection)
198                        .await
199                        .map_err(ValkeyStoreError::from)?;
200                    return Ok(runtime);
201                }
202            }
203        }
204
205        let runtime = CoordinatedRuntimeState {
206            state: initial_state,
207            revision: 0,
208            paused: false,
209        };
210        self.write_runtime(&key, &runtime).await?;
211        Ok(runtime)
212    }
213
214    async fn save_state(
215        &self,
216        job_id: &str,
217        revision: u64,
218        state: &JobState,
219    ) -> Result<bool, Self::Error> {
220        let key = self.state_key(job_id);
221        let payload = serde_json::to_string(state).map_err(ValkeyStoreError::from)?;
222        let mut connection = self.connection.clone();
223        let updated: i32 = Script::new(SAVE_STATE_LUA)
224            .key(key)
225            .arg(FIELD_VERSION)
226            .arg(revision)
227            .arg(FIELD_INFLIGHT_TOKEN)
228            .arg(FIELD_STATE)
229            .arg(payload)
230            .invoke_async(&mut connection)
231            .await
232            .map_err(ValkeyStoreError::from)?;
233        Ok(updated == 1)
234    }
235
236    async fn reclaim_inflight(
237        &self,
238        job_id: &str,
239        resource_id: &str,
240        lease_config: CoordinatedLeaseConfig,
241    ) -> Result<Option<CoordinatedClaim>, Self::Error> {
242        let key = self.state_key(job_id);
243        let lease_key = self.occurrence_lease_key(resource_id, Utc::now());
244        let token = next_token(&COORDINATED_TOKEN_COUNTER, "coord");
245        let ttl_millis = u64::try_from(lease_config.ttl.as_millis()).unwrap_or(u64::MAX);
246        let now_millis = now_millis();
247        let expires_at_millis = now_millis.saturating_add(ttl_millis);
248        let mut connection = self.connection.clone();
249        let result: Option<Vec<String>> = Script::new(RECLAIM_INFLIGHT_LUA)
250            .key(key)
251            .key(self.resource_lock_key(resource_id))
252            .key(lease_key.clone())
253            .key(self.occurrence_index_key(resource_id))
254            // ARGV layout:
255            //  1..9  = hash field names read from KEYS[1]
256            // 10..14 = reclaim timing + new lease payload
257            // 15..16 = hash field names updated on successful reclaim
258            .arg(FIELD_INFLIGHT_SCHEDULED_AT)
259            .arg(FIELD_INFLIGHT_CATCH_UP)
260            .arg(FIELD_INFLIGHT_TRIGGER_COUNT)
261            .arg(FIELD_INFLIGHT_RESOURCE_ID)
262            .arg(FIELD_INFLIGHT_SCOPE)
263            .arg(FIELD_INFLIGHT_LEASE_EXPIRES_AT)
264            .arg(FIELD_STATE)
265            .arg(FIELD_VERSION)
266            .arg(FIELD_PAUSED)
267            .arg(now_millis)
268            .arg(format!(
269                "{}{}:occurrence:",
270                self.execution_key_prefix, resource_id
271            ))
272            .arg(&token)
273            .arg(ttl_millis)
274            .arg(expires_at_millis)
275            .arg(FIELD_INFLIGHT_TOKEN)
276            .arg(FIELD_INFLIGHT_LEASE_KEY)
277            .invoke_async(&mut connection)
278            .await
279            .map_err(ValkeyStoreError::from)?;
280
281        let Some(values) = result else {
282            return Ok(None);
283        };
284        if values.len() != 8 {
285            return Ok(None);
286        }
287        let revision = values[0].parse::<u64>().unwrap_or(0);
288        let state: JobState = serde_json::from_str(&values[1]).map_err(ValkeyStoreError::from)?;
289        let scheduled_at = DateTime::parse_from_rfc3339(&values[2])
290            .map_err(|error| {
291                ValkeyStoreError::Codec(serde_json::Error::io(std::io::Error::other(
292                    error.to_string(),
293                )))
294            })?
295            .with_timezone(&Utc);
296        let catch_up = values[3].parse::<bool>().unwrap_or(false);
297        let trigger_count = values[4].parse::<u32>().unwrap_or(0);
298        let scope = parse_scope(&values[5]);
299        Ok(Some(CoordinatedClaim {
300            state: CoordinatedRuntimeState {
301                state,
302                revision,
303                paused: false,
304            },
305            trigger: CoordinatedPendingTrigger {
306                scheduled_at,
307                catch_up,
308                trigger_count,
309            },
310            lease: ExecutionLease::new(
311                job_id.to_string(),
312                resource_id.to_string(),
313                scope,
314                Some(scheduled_at),
315                values[7].clone(),
316                values[6].clone(),
317            ),
318            replayed: true,
319        }))
320    }
321
322    async fn claim_trigger(
323        &self,
324        job_id: &str,
325        resource_id: &str,
326        revision: u64,
327        trigger: CoordinatedPendingTrigger,
328        next_state: &JobState,
329        lease_config: CoordinatedLeaseConfig,
330    ) -> Result<Option<CoordinatedClaim>, Self::Error> {
331        let key = self.state_key(job_id);
332        let lease_key = self.occurrence_lease_key(resource_id, trigger.scheduled_at);
333        let token = next_token(&COORDINATED_TOKEN_COUNTER, "coord");
334        let ttl_millis = u64::try_from(lease_config.ttl.as_millis()).unwrap_or(u64::MAX);
335        let now_millis = now_millis();
336        let expires_at_millis = now_millis.saturating_add(ttl_millis);
337        let next_state_payload =
338            serde_json::to_string(next_state).map_err(ValkeyStoreError::from)?;
339        let mut connection = self.connection.clone();
340        let new_revision: i64 = Script::new(CLAIM_TRIGGER_LUA)
341            .key(key)
342            .key(self.resource_lock_key(resource_id))
343            .key(&lease_key)
344            .key(self.occurrence_index_key(resource_id))
345            .arg(FIELD_VERSION)
346            .arg(FIELD_INFLIGHT_TOKEN)
347            .arg(FIELD_INFLIGHT_LEASE_EXPIRES_AT)
348            .arg(FIELD_PAUSED)
349            .arg(now_millis)
350            .arg(revision)
351            .arg(&token)
352            .arg(ttl_millis)
353            .arg(expires_at_millis)
354            .arg(FIELD_STATE)
355            .arg(next_state_payload)
356            .arg(FIELD_INFLIGHT_SCHEDULED_AT)
357            .arg(
358                trigger
359                    .scheduled_at
360                    .to_rfc3339_opts(SecondsFormat::Nanos, true),
361            )
362            .arg(FIELD_INFLIGHT_CATCH_UP)
363            .arg(trigger.catch_up)
364            .arg(FIELD_INFLIGHT_TRIGGER_COUNT)
365            .arg(trigger.trigger_count)
366            .arg(FIELD_INFLIGHT_RESOURCE_ID)
367            .arg(resource_id)
368            .arg(FIELD_INFLIGHT_SCOPE)
369            .arg("occurrence")
370            .arg(FIELD_INFLIGHT_TOKEN)
371            .arg(FIELD_INFLIGHT_LEASE_KEY)
372            .invoke_async(&mut connection)
373            .await
374            .map_err(ValkeyStoreError::from)?;
375
376        if new_revision <= 0 {
377            return Ok(None);
378        }
379
380        Ok(Some(CoordinatedClaim {
381            state: CoordinatedRuntimeState {
382                state: next_state.clone(),
383                revision: new_revision as u64,
384                paused: false,
385            },
386            trigger: trigger.clone(),
387            lease: ExecutionLease::new(
388                job_id.to_string(),
389                resource_id.to_string(),
390                ExecutionGuardScope::Occurrence,
391                Some(trigger.scheduled_at),
392                token,
393                lease_key,
394            ),
395            replayed: false,
396        }))
397    }
398
399    async fn renew(
400        &self,
401        lease: &ExecutionLease,
402        lease_config: CoordinatedLeaseConfig,
403    ) -> Result<ExecutionGuardRenewal, Self::Error> {
404        let ttl_millis = u64::try_from(lease_config.ttl.as_millis()).unwrap_or(u64::MAX);
405        let expires_at_millis = now_millis().saturating_add(ttl_millis);
406        let mut connection = self.connection.clone();
407        let renewed: i32 = Script::new(RENEW_LEASE_LUA)
408            .key(&lease.lease_key)
409            .key(self.occurrence_index_key(&lease.resource_id))
410            .key(self.state_key(&lease.job_id))
411            .arg(&lease.token)
412            .arg(ttl_millis)
413            .arg(expires_at_millis)
414            .arg(FIELD_INFLIGHT_LEASE_EXPIRES_AT)
415            .invoke_async(&mut connection)
416            .await
417            .map_err(ValkeyStoreError::from)?;
418        Ok(if renewed == 1 {
419            ExecutionGuardRenewal::Renewed
420        } else {
421            ExecutionGuardRenewal::Lost
422        })
423    }
424
425    async fn complete(
426        &self,
427        job_id: &str,
428        revision: u64,
429        lease: &ExecutionLease,
430        state: &JobState,
431    ) -> Result<bool, Self::Error> {
432        let key = self.state_key(job_id);
433        let payload = serde_json::to_string(state).map_err(ValkeyStoreError::from)?;
434        let mut connection = self.connection.clone();
435        let completed: i32 = Script::new(COMPLETE_LUA)
436            .key(key)
437            .key(&lease.lease_key)
438            .key(self.occurrence_index_key(&lease.resource_id))
439            .arg(FIELD_VERSION)
440            .arg(FIELD_INFLIGHT_TOKEN)
441            .arg(revision)
442            .arg(&lease.token)
443            .arg(FIELD_STATE)
444            .arg(payload)
445            .arg(FIELD_INFLIGHT_SCHEDULED_AT)
446            .arg(FIELD_INFLIGHT_CATCH_UP)
447            .arg(FIELD_INFLIGHT_TRIGGER_COUNT)
448            .arg(FIELD_INFLIGHT_RESOURCE_ID)
449            .arg(FIELD_INFLIGHT_SCOPE)
450            .arg(FIELD_INFLIGHT_LEASE_KEY)
451            .invoke_async(&mut connection)
452            .await
453            .map_err(ValkeyStoreError::from)?;
454        Ok(completed == 1)
455    }
456
457    async fn delete(&self, job_id: &str) -> Result<(), Self::Error> {
458        let key = self.state_key(job_id);
459        let mut connection = self.connection.clone();
460        let _: () = cmd("DEL")
461            .arg(key)
462            .query_async(&mut connection)
463            .await
464            .map_err(ValkeyStoreError::from)?;
465        Ok(())
466    }
467
468    async fn pause(&self, job_id: &str) -> Result<bool, Self::Error> {
469        let key = self.state_key(job_id);
470        let mut connection = self.connection.clone();
471        let changed: i32 = Script::new(PAUSE_LUA)
472            .key(key)
473            .arg(FIELD_PAUSED)
474            .invoke_async(&mut connection)
475            .await
476            .map_err(ValkeyStoreError::from)?;
477        Ok(changed == 1)
478    }
479
480    async fn resume(&self, job_id: &str) -> Result<bool, Self::Error> {
481        let key = self.state_key(job_id);
482        let mut connection = self.connection.clone();
483        let changed: i32 = Script::new(RESUME_LUA)
484            .key(key)
485            .arg(FIELD_PAUSED)
486            .invoke_async(&mut connection)
487            .await
488            .map_err(ValkeyStoreError::from)?;
489        Ok(changed == 1)
490    }
491
492    fn classify_store_error(error: &Self::Error) -> StoreErrorKind
493    where
494        Self: Sized,
495    {
496        if matches!(error, ValkeyStoreError::Codec(_)) {
497            StoreErrorKind::Data
498        } else if error.is_connection_issue() {
499            StoreErrorKind::Connection
500        } else {
501            StoreErrorKind::Unknown
502        }
503    }
504
505    fn classify_guard_error(error: &Self::Error) -> ExecutionGuardErrorKind
506    where
507        Self: Sized,
508    {
509        if matches!(error, ValkeyStoreError::Codec(_)) {
510            ExecutionGuardErrorKind::Data
511        } else if error.is_connection_issue() {
512            ExecutionGuardErrorKind::Connection
513        } else {
514            ExecutionGuardErrorKind::Unknown
515        }
516    }
517}
518
519fn parse_runtime_state(
520    fields: &HashMap<String, String>,
521) -> Result<CoordinatedRuntimeState, ValkeyStoreError> {
522    let revision = fields
523        .get(FIELD_VERSION)
524        .and_then(|value| value.parse::<u64>().ok())
525        .unwrap_or(0);
526    let paused = fields
527        .get(FIELD_PAUSED)
528        .map(|value| value == "1" || value.eq_ignore_ascii_case("true"))
529        .unwrap_or(false);
530    let state = serde_json::from_str(fields.get(FIELD_STATE).map(String::as_str).unwrap_or("{}"))
531        .map_err(ValkeyStoreError::from)?;
532    Ok(CoordinatedRuntimeState {
533        state,
534        revision,
535        paused,
536    })
537}
538
539fn parse_scope(raw: &str) -> ExecutionGuardScope {
540    match raw {
541        "resource" => ExecutionGuardScope::Resource,
542        _ => ExecutionGuardScope::Occurrence,
543    }
544}