Skip to main content

scheduler/
valkey_coordinated_store.rs

1use crate::coordinated_store::{
2    CoordinatedClaim, CoordinatedLeaseConfig, CoordinatedPendingTrigger, CoordinatedRuntimeState,
3    CoordinatedStateStore,
4};
5use crate::error::{ExecutionGuardErrorKind, StoreErrorKind};
6use crate::execution_guard::{ExecutionGuardRenewal, ExecutionGuardScope, ExecutionLease};
7use crate::model::JobState;
8use crate::valkey_execution_support::{
9    next_token, now_millis, occurrence_index_key, occurrence_lease_key, resource_lock_key,
10};
11use crate::valkey_store::ValkeyStoreError;
12use chrono::SecondsFormat;
13use chrono::{DateTime, Utc};
14use redis::{AsyncCommands, Client, Script, aio::ConnectionManager, cmd};
15use std::collections::HashMap;
16use std::sync::atomic::AtomicU64;
17
/// Key prefix used for per-job state when the store is built with [`ValkeyCoordinatedStateStore::new`].
const DEFAULT_STATE_KEY_PREFIX: &str = "scheduler:valkey:job-state:";
/// Older state-key prefix; only consulted for migration when the default prefix is in use.
const LEGACY_DEFAULT_STATE_KEY_PREFIX: &str = "scheduler:job-state:";
/// Key prefix handed to the lease/lock/index key helpers by default.
const DEFAULT_EXECUTION_KEY_PREFIX: &str = "scheduler:valkey:execution-lease:";

// Field names of the per-job state hash.
/// Revision counter used for compare-and-set updates.
const FIELD_VERSION: &str = "version";
/// JSON-serialized `JobState`.
const FIELD_STATE: &str = "state";
// The `inflight_*` fields are written together when a trigger is claimed.
const FIELD_INFLIGHT_SCHEDULED_AT: &str = "inflight_scheduled_at";
const FIELD_INFLIGHT_CATCH_UP: &str = "inflight_catch_up";
const FIELD_INFLIGHT_TRIGGER_COUNT: &str = "inflight_trigger_count";
const FIELD_INFLIGHT_RESOURCE_ID: &str = "inflight_resource_id";
const FIELD_INFLIGHT_SCOPE: &str = "inflight_scope";
const FIELD_INFLIGHT_TOKEN: &str = "inflight_token";
const FIELD_INFLIGHT_LEASE_KEY: &str = "inflight_lease_key";
const FIELD_INFLIGHT_LEASE_EXPIRES_AT: &str = "inflight_lease_expires_at";

/// Process-wide counter passed to `next_token` when minting lease tokens in this module.
static COORDINATED_TOKEN_COUNTER: AtomicU64 = AtomicU64::new(1);
34
/// Coordinated job-state store that keeps per-job state in a hash and
/// execution leases in string keys, all reached through one connection manager.
#[derive(Debug, Clone)]
pub struct ValkeyCoordinatedStateStore {
    /// Connection handle; cloned before every command or script invocation.
    connection: ConnectionManager,
    /// Prefix prepended to a job id to form its state key.
    state_key_prefix: String,
    /// Prefix passed to the lease, lock, and occurrence-index key helpers.
    execution_key_prefix: String,
}
41
42impl ValkeyCoordinatedStateStore {
43    pub async fn new(url: impl AsRef<str>) -> Result<Self, redis::RedisError> {
44        Self::with_prefixes(url, DEFAULT_STATE_KEY_PREFIX, DEFAULT_EXECUTION_KEY_PREFIX).await
45    }
46
47    pub async fn with_prefixes(
48        url: impl AsRef<str>,
49        state_key_prefix: impl Into<String>,
50        execution_key_prefix: impl Into<String>,
51    ) -> Result<Self, redis::RedisError> {
52        let client = Client::open(url.as_ref())?;
53        let connection = client.get_connection_manager().await?;
54        Ok(Self {
55            connection,
56            state_key_prefix: state_key_prefix.into(),
57            execution_key_prefix: execution_key_prefix.into(),
58        })
59    }
60
61    fn state_key(&self, job_id: &str) -> String {
62        format!("{}{}", self.state_key_prefix, job_id)
63    }
64
65    fn legacy_state_key(&self, job_id: &str) -> Option<String> {
66        if self.state_key_prefix == DEFAULT_STATE_KEY_PREFIX {
67            Some(format!("{LEGACY_DEFAULT_STATE_KEY_PREFIX}{job_id}"))
68        } else {
69            None
70        }
71    }
72
73    fn resource_lock_key(&self, resource_id: &str) -> String {
74        resource_lock_key(&self.execution_key_prefix, resource_id)
75    }
76
77    fn occurrence_index_key(&self, resource_id: &str) -> String {
78        occurrence_index_key(&self.execution_key_prefix, resource_id)
79    }
80
81    fn occurrence_lease_key(&self, resource_id: &str, scheduled_at: DateTime<Utc>) -> String {
82        occurrence_lease_key(&self.execution_key_prefix, resource_id, scheduled_at)
83    }
84
85    async fn key_type(&self, key: &str) -> Result<String, ValkeyStoreError> {
86        let mut connection = self.connection.clone();
87        cmd("TYPE")
88            .arg(key)
89            .query_async(&mut connection)
90            .await
91            .map_err(ValkeyStoreError::from)
92    }
93
94    async fn load_hash(
95        &self,
96        key: &str,
97    ) -> Result<Option<CoordinatedRuntimeState>, ValkeyStoreError> {
98        let mut connection = self.connection.clone();
99        let fields: HashMap<String, String> = connection
100            .hgetall(key)
101            .await
102            .map_err(ValkeyStoreError::from)?;
103        if fields.is_empty() {
104            return Ok(None);
105        }
106
107        Ok(Some(parse_runtime_state(&fields)?))
108    }
109
110    async fn migrate_string_state(
111        &self,
112        key: &str,
113        payload: String,
114    ) -> Result<CoordinatedRuntimeState, ValkeyStoreError> {
115        let state: JobState = serde_json::from_str(&payload).map_err(ValkeyStoreError::from)?;
116        let runtime = CoordinatedRuntimeState { state, revision: 0 };
117        self.write_runtime(key, &runtime).await?;
118        Ok(runtime)
119    }
120
121    async fn write_runtime(
122        &self,
123        key: &str,
124        runtime: &CoordinatedRuntimeState,
125    ) -> Result<(), ValkeyStoreError> {
126        let mut connection = self.connection.clone();
127        let payload = serde_json::to_string(&runtime.state).map_err(ValkeyStoreError::from)?;
128        let _: () = cmd("DEL")
129            .arg(key)
130            .query_async(&mut connection)
131            .await
132            .map_err(ValkeyStoreError::from)?;
133        let _: () = cmd("HSET")
134            .arg(key)
135            .arg(FIELD_VERSION)
136            .arg(runtime.revision)
137            .arg(FIELD_STATE)
138            .arg(payload)
139            .query_async(&mut connection)
140            .await
141            .map_err(ValkeyStoreError::from)?;
142        Ok(())
143    }
144
145    async fn load_payload_state(&self, key: &str) -> Result<Option<String>, ValkeyStoreError> {
146        let mut connection = self.connection.clone();
147        connection.get(key).await.map_err(ValkeyStoreError::from)
148    }
149}
150
151impl CoordinatedStateStore for ValkeyCoordinatedStateStore {
152    type Error = ValkeyStoreError;
153
154    async fn load_or_initialize(
155        &self,
156        job_id: &str,
157        initial_state: JobState,
158    ) -> Result<CoordinatedRuntimeState, Self::Error> {
159        let key = self.state_key(job_id);
160        match self.key_type(&key).await?.as_str() {
161            "hash" => {
162                if let Some(runtime) = self.load_hash(&key).await? {
163                    return Ok(runtime);
164                }
165            }
166            "string" => {
167                if let Some(payload) = self.load_payload_state(&key).await? {
168                    return self.migrate_string_state(&key, payload).await;
169                }
170            }
171            "none" => {}
172            _ => {}
173        }
174
175        if let Some(legacy_key) = self.legacy_state_key(job_id) {
176            if self.key_type(&legacy_key).await?.as_str() == "string" {
177                if let Some(payload) = self.load_payload_state(&legacy_key).await? {
178                    let runtime = self.migrate_string_state(&key, payload).await?;
179                    let mut connection = self.connection.clone();
180                    let _: () = cmd("DEL")
181                        .arg(legacy_key)
182                        .query_async(&mut connection)
183                        .await
184                        .map_err(ValkeyStoreError::from)?;
185                    return Ok(runtime);
186                }
187            }
188        }
189
190        let runtime = CoordinatedRuntimeState {
191            state: initial_state,
192            revision: 0,
193        };
194        self.write_runtime(&key, &runtime).await?;
195        Ok(runtime)
196    }
197
198    async fn save_state(
199        &self,
200        job_id: &str,
201        revision: u64,
202        state: &JobState,
203    ) -> Result<bool, Self::Error> {
204        let key = self.state_key(job_id);
205        let payload = serde_json::to_string(state).map_err(ValkeyStoreError::from)?;
206        let mut connection = self.connection.clone();
207        let updated: i32 = Script::new(
208            r"
209            local version = tonumber(redis.call('HGET', KEYS[1], ARGV[1]) or '-1')
210            local inflight = redis.call('HGET', KEYS[1], ARGV[3])
211            if inflight then
212                return 0
213            end
214            if version ~= tonumber(ARGV[2]) then
215                return 0
216            end
217            redis.call('HSET', KEYS[1], ARGV[1], version + 1, ARGV[4], ARGV[5])
218            return 1
219            ",
220        )
221        .key(key)
222        .arg(FIELD_VERSION)
223        .arg(revision)
224        .arg(FIELD_INFLIGHT_TOKEN)
225        .arg(FIELD_STATE)
226        .arg(payload)
227        .invoke_async(&mut connection)
228        .await
229        .map_err(ValkeyStoreError::from)?;
230        Ok(updated == 1)
231    }
232
233    async fn reclaim_inflight(
234        &self,
235        job_id: &str,
236        resource_id: &str,
237        lease_config: CoordinatedLeaseConfig,
238    ) -> Result<Option<CoordinatedClaim>, Self::Error> {
239        let key = self.state_key(job_id);
240        let lease_key = self.occurrence_lease_key(resource_id, Utc::now());
241        let token = next_token(&COORDINATED_TOKEN_COUNTER, "coord");
242        let ttl_millis = u64::try_from(lease_config.ttl.as_millis()).unwrap_or(u64::MAX);
243        let now_millis = now_millis();
244        let expires_at_millis = now_millis.saturating_add(ttl_millis);
245        let mut connection = self.connection.clone();
246        let result: Option<Vec<String>> = Script::new(
247            r"
248            local scheduled_at = redis.call('HGET', KEYS[1], ARGV[1])
249            local catch_up = redis.call('HGET', KEYS[1], ARGV[2])
250            local trigger_count = redis.call('HGET', KEYS[1], ARGV[3])
251            local inflight_resource_id = redis.call('HGET', KEYS[1], ARGV[4])
252            local inflight_scope = redis.call('HGET', KEYS[1], ARGV[5])
253            local inflight_expires_at = tonumber(redis.call('HGET', KEYS[1], ARGV[6]) or '0')
254            local state_payload = redis.call('HGET', KEYS[1], ARGV[7])
255            local version = tonumber(redis.call('HGET', KEYS[1], ARGV[8]) or '0')
256
257            if not scheduled_at or not inflight_resource_id or not inflight_scope then
258                return nil
259            end
260            if inflight_expires_at > tonumber(ARGV[9]) then
261                return nil
262            end
263            redis.call('ZREMRANGEBYSCORE', KEYS[4], '-inf', ARGV[9])
264            if redis.call('EXISTS', KEYS[2]) == 1 then
265                return nil
266            end
267            local new_lease_key = ARGV[10] .. scheduled_at
268            local ok = redis.call('SET', new_lease_key, ARGV[11], 'NX', 'PX', ARGV[12])
269            if not ok then
270                return nil
271            end
272            redis.call('ZADD', KEYS[4], ARGV[13], new_lease_key)
273            redis.call('HSET', KEYS[1],
274                ARGV[6], ARGV[13],
275                ARGV[14], ARGV[11],
276                ARGV[15], new_lease_key,
277                ARGV[8], version + 1
278            )
279            return { tostring(version + 1), state_payload, scheduled_at, catch_up, trigger_count, inflight_scope, new_lease_key, ARGV[11] }
280            ",
281        )
282        .key(key)
283        .key(self.resource_lock_key(resource_id))
284        .key(lease_key.clone())
285        .key(self.occurrence_index_key(resource_id))
286        .arg(FIELD_INFLIGHT_SCHEDULED_AT)
287        .arg(FIELD_INFLIGHT_CATCH_UP)
288        .arg(FIELD_INFLIGHT_TRIGGER_COUNT)
289        .arg(FIELD_INFLIGHT_RESOURCE_ID)
290        .arg(FIELD_INFLIGHT_SCOPE)
291        .arg(FIELD_INFLIGHT_LEASE_EXPIRES_AT)
292        .arg(FIELD_STATE)
293        .arg(FIELD_VERSION)
294        .arg(now_millis)
295        .arg(format!("{}{}:occurrence:", self.execution_key_prefix, resource_id))
296        .arg(&token)
297        .arg(ttl_millis)
298        .arg(expires_at_millis)
299        .arg(FIELD_INFLIGHT_TOKEN)
300        .arg(FIELD_INFLIGHT_LEASE_KEY)
301        .invoke_async(&mut connection)
302        .await
303        .map_err(ValkeyStoreError::from)?;
304
305        let Some(values) = result else {
306            return Ok(None);
307        };
308        if values.len() != 8 {
309            return Ok(None);
310        }
311        let revision = values[0].parse::<u64>().unwrap_or(0);
312        let state: JobState = serde_json::from_str(&values[1]).map_err(ValkeyStoreError::from)?;
313        let scheduled_at = DateTime::parse_from_rfc3339(&values[2])
314            .map_err(|error| {
315                ValkeyStoreError::Codec(serde_json::Error::io(std::io::Error::other(
316                    error.to_string(),
317                )))
318            })?
319            .with_timezone(&Utc);
320        let catch_up = values[3].parse::<bool>().unwrap_or(false);
321        let trigger_count = values[4].parse::<u32>().unwrap_or(0);
322        let scope = parse_scope(&values[5]);
323        Ok(Some(CoordinatedClaim {
324            state: CoordinatedRuntimeState { state, revision },
325            trigger: CoordinatedPendingTrigger {
326                scheduled_at,
327                catch_up,
328                trigger_count,
329            },
330            lease: ExecutionLease::new(
331                job_id.to_string(),
332                resource_id.to_string(),
333                scope,
334                Some(scheduled_at),
335                values[7].clone(),
336                values[6].clone(),
337            ),
338            replayed: true,
339        }))
340    }
341
342    async fn claim_trigger(
343        &self,
344        job_id: &str,
345        resource_id: &str,
346        revision: u64,
347        trigger: CoordinatedPendingTrigger,
348        next_state: &JobState,
349        lease_config: CoordinatedLeaseConfig,
350    ) -> Result<Option<CoordinatedClaim>, Self::Error> {
351        let key = self.state_key(job_id);
352        let lease_key = self.occurrence_lease_key(resource_id, trigger.scheduled_at);
353        let token = next_token(&COORDINATED_TOKEN_COUNTER, "coord");
354        let ttl_millis = u64::try_from(lease_config.ttl.as_millis()).unwrap_or(u64::MAX);
355        let now_millis = now_millis();
356        let expires_at_millis = now_millis.saturating_add(ttl_millis);
357        let next_state_payload =
358            serde_json::to_string(next_state).map_err(ValkeyStoreError::from)?;
359        let mut connection = self.connection.clone();
360        let new_revision: i64 = Script::new(
361            r"
362            local version = tonumber(redis.call('HGET', KEYS[1], ARGV[1]) or '-1')
363            local inflight = redis.call('HGET', KEYS[1], ARGV[2])
364            if inflight then
365                local inflight_expires_at = tonumber(redis.call('HGET', KEYS[1], ARGV[3]) or '0')
366                if inflight_expires_at > tonumber(ARGV[4]) then
367                    return 0
368                end
369                return 0
370            end
371            if version ~= tonumber(ARGV[5]) then
372                return 0
373            end
374            redis.call('ZREMRANGEBYSCORE', KEYS[4], '-inf', ARGV[4])
375            if redis.call('EXISTS', KEYS[2]) == 1 then
376                return 0
377            end
378            local ok = redis.call('SET', KEYS[3], ARGV[6], 'NX', 'PX', ARGV[7])
379            if not ok then
380                return 0
381            end
382            redis.call('ZADD', KEYS[4], ARGV[8], KEYS[3])
383            redis.call('HSET', KEYS[1],
384                ARGV[1], version + 1,
385                ARGV[9], ARGV[10],
386                ARGV[11], ARGV[12],
387                ARGV[13], ARGV[14],
388                ARGV[15], ARGV[16],
389                ARGV[17], ARGV[18],
390                ARGV[19], ARGV[20],
391                ARGV[21], ARGV[6],
392                ARGV[22], KEYS[3],
393                ARGV[3], ARGV[8]
394            )
395            return version + 1
396            ",
397        )
398        .key(key)
399        .key(self.resource_lock_key(resource_id))
400        .key(&lease_key)
401        .key(self.occurrence_index_key(resource_id))
402        .arg(FIELD_VERSION)
403        .arg(FIELD_INFLIGHT_TOKEN)
404        .arg(FIELD_INFLIGHT_LEASE_EXPIRES_AT)
405        .arg(now_millis)
406        .arg(revision)
407        .arg(&token)
408        .arg(ttl_millis)
409        .arg(expires_at_millis)
410        .arg(FIELD_STATE)
411        .arg(next_state_payload)
412        .arg(FIELD_INFLIGHT_SCHEDULED_AT)
413        .arg(
414            trigger
415                .scheduled_at
416                .to_rfc3339_opts(SecondsFormat::Nanos, true),
417        )
418        .arg(FIELD_INFLIGHT_CATCH_UP)
419        .arg(trigger.catch_up)
420        .arg(FIELD_INFLIGHT_TRIGGER_COUNT)
421        .arg(trigger.trigger_count)
422        .arg(FIELD_INFLIGHT_RESOURCE_ID)
423        .arg(resource_id)
424        .arg(FIELD_INFLIGHT_SCOPE)
425        .arg("occurrence")
426        .arg(FIELD_INFLIGHT_TOKEN)
427        .arg(FIELD_INFLIGHT_LEASE_KEY)
428        .invoke_async(&mut connection)
429        .await
430        .map_err(ValkeyStoreError::from)?;
431
432        if new_revision <= 0 {
433            return Ok(None);
434        }
435
436        Ok(Some(CoordinatedClaim {
437            state: CoordinatedRuntimeState {
438                state: next_state.clone(),
439                revision: new_revision as u64,
440            },
441            trigger: trigger.clone(),
442            lease: ExecutionLease::new(
443                job_id.to_string(),
444                resource_id.to_string(),
445                ExecutionGuardScope::Occurrence,
446                Some(trigger.scheduled_at),
447                token,
448                lease_key,
449            ),
450            replayed: false,
451        }))
452    }
453
454    async fn renew(
455        &self,
456        lease: &ExecutionLease,
457        lease_config: CoordinatedLeaseConfig,
458    ) -> Result<ExecutionGuardRenewal, Self::Error> {
459        let ttl_millis = u64::try_from(lease_config.ttl.as_millis()).unwrap_or(u64::MAX);
460        let expires_at_millis = now_millis().saturating_add(ttl_millis);
461        let mut connection = self.connection.clone();
462        let renewed: i32 = Script::new(
463            r"
464            if redis.call('GET', KEYS[1]) == ARGV[1] then
465                redis.call('PEXPIRE', KEYS[1], ARGV[2])
466                redis.call('ZADD', KEYS[2], ARGV[3], KEYS[1])
467                redis.call('HSET', KEYS[3], ARGV[4], ARGV[3])
468                return 1
469            end
470            redis.call('ZREM', KEYS[2], KEYS[1])
471            return 0
472            ",
473        )
474        .key(&lease.lease_key)
475        .key(self.occurrence_index_key(&lease.resource_id))
476        .key(self.state_key(&lease.job_id))
477        .arg(&lease.token)
478        .arg(ttl_millis)
479        .arg(expires_at_millis)
480        .arg(FIELD_INFLIGHT_LEASE_EXPIRES_AT)
481        .invoke_async(&mut connection)
482        .await
483        .map_err(ValkeyStoreError::from)?;
484        Ok(if renewed == 1 {
485            ExecutionGuardRenewal::Renewed
486        } else {
487            ExecutionGuardRenewal::Lost
488        })
489    }
490
491    async fn complete(
492        &self,
493        job_id: &str,
494        revision: u64,
495        lease: &ExecutionLease,
496        state: &JobState,
497    ) -> Result<bool, Self::Error> {
498        let key = self.state_key(job_id);
499        let payload = serde_json::to_string(state).map_err(ValkeyStoreError::from)?;
500        let mut connection = self.connection.clone();
501        let completed: i32 = Script::new(
502            r"
503            local version = tonumber(redis.call('HGET', KEYS[1], ARGV[1]) or '-1')
504            local token = redis.call('HGET', KEYS[1], ARGV[2])
505            if version ~= tonumber(ARGV[3]) then
506                return 0
507            end
508            if token ~= ARGV[4] then
509                return 0
510            end
511            redis.call('DEL', KEYS[2])
512            redis.call('ZREM', KEYS[3], KEYS[2])
513            redis.call('HSET', KEYS[1], ARGV[1], version + 1, ARGV[5], ARGV[6])
514            redis.call('HDEL', KEYS[1], ARGV[2], ARGV[7], ARGV[8], ARGV[9], ARGV[10], ARGV[11], ARGV[12])
515            return 1
516            ",
517        )
518        .key(key)
519        .key(&lease.lease_key)
520        .key(self.occurrence_index_key(&lease.resource_id))
521        .arg(FIELD_VERSION)
522        .arg(FIELD_INFLIGHT_TOKEN)
523        .arg(revision)
524        .arg(&lease.token)
525        .arg(FIELD_STATE)
526        .arg(payload)
527        .arg(FIELD_INFLIGHT_SCHEDULED_AT)
528        .arg(FIELD_INFLIGHT_CATCH_UP)
529        .arg(FIELD_INFLIGHT_TRIGGER_COUNT)
530        .arg(FIELD_INFLIGHT_RESOURCE_ID)
531        .arg(FIELD_INFLIGHT_SCOPE)
532        .arg(FIELD_INFLIGHT_LEASE_KEY)
533        .invoke_async(&mut connection)
534        .await
535        .map_err(ValkeyStoreError::from)?;
536        Ok(completed == 1)
537    }
538
539    async fn delete(&self, job_id: &str) -> Result<(), Self::Error> {
540        let key = self.state_key(job_id);
541        let mut connection = self.connection.clone();
542        let _: () = cmd("DEL")
543            .arg(key)
544            .query_async(&mut connection)
545            .await
546            .map_err(ValkeyStoreError::from)?;
547        Ok(())
548    }
549
550    fn classify_store_error(error: &Self::Error) -> StoreErrorKind
551    where
552        Self: Sized,
553    {
554        if matches!(error, ValkeyStoreError::Codec(_)) {
555            StoreErrorKind::Data
556        } else if error.is_connection_issue() {
557            StoreErrorKind::Connection
558        } else {
559            StoreErrorKind::Unknown
560        }
561    }
562
563    fn classify_guard_error(error: &Self::Error) -> ExecutionGuardErrorKind
564    where
565        Self: Sized,
566    {
567        if matches!(error, ValkeyStoreError::Codec(_)) {
568            ExecutionGuardErrorKind::Data
569        } else if error.is_connection_issue() {
570            ExecutionGuardErrorKind::Connection
571        } else {
572            ExecutionGuardErrorKind::Unknown
573        }
574    }
575}
576
577fn parse_runtime_state(
578    fields: &HashMap<String, String>,
579) -> Result<CoordinatedRuntimeState, ValkeyStoreError> {
580    let revision = fields
581        .get(FIELD_VERSION)
582        .and_then(|value| value.parse::<u64>().ok())
583        .unwrap_or(0);
584    let state = serde_json::from_str(fields.get(FIELD_STATE).map(String::as_str).unwrap_or("{}"))
585        .map_err(ValkeyStoreError::from)?;
586    Ok(CoordinatedRuntimeState { state, revision })
587}
588
589fn parse_scope(raw: &str) -> ExecutionGuardScope {
590    match raw {
591        "resource" => ExecutionGuardScope::Resource,
592        _ => ExecutionGuardScope::Occurrence,
593    }
594}