1use crate::coordinated_store::{
2 CoordinatedClaim, CoordinatedLeaseConfig, CoordinatedPendingTrigger, CoordinatedRuntimeState,
3 CoordinatedStateStore,
4};
5use crate::error::{ExecutionGuardErrorKind, StoreErrorKind};
6use crate::execution_guard::{ExecutionGuardRenewal, ExecutionGuardScope, ExecutionLease};
7use crate::model::JobState;
8use crate::valkey_execution_support::{
9 next_token, now_millis, occurrence_index_key, occurrence_lease_key, resource_lock_key,
10};
11use crate::valkey_store::ValkeyStoreError;
12use chrono::SecondsFormat;
13use chrono::{DateTime, Utc};
14use redis::{AsyncCommands, Client, Script, aio::ConnectionManager, cmd};
15use std::collections::HashMap;
16use std::sync::atomic::AtomicU64;
17
/// Default prefix prepended to a job id to form its state key.
const DEFAULT_STATE_KEY_PREFIX: &str = "scheduler:valkey:job-state:";
/// Older default prefix; `legacy_state_key` only consults it when the store
/// is configured with [`DEFAULT_STATE_KEY_PREFIX`].
const LEGACY_DEFAULT_STATE_KEY_PREFIX: &str = "scheduler:job-state:";
/// Default prefix handed to the execution-lease key helpers.
const DEFAULT_EXECUTION_KEY_PREFIX: &str = "scheduler:valkey:execution-lease:";

// Field names of the per-job state hash.
/// Revision counter used for compare-and-set updates.
const FIELD_VERSION: &str = "version";
/// JSON-serialised `JobState` payload.
const FIELD_STATE: &str = "state";
/// RFC 3339 timestamp of the trigger that is currently in flight.
const FIELD_INFLIGHT_SCHEDULED_AT: &str = "inflight_scheduled_at";
/// Catch-up flag of the in-flight trigger.
const FIELD_INFLIGHT_CATCH_UP: &str = "inflight_catch_up";
/// Trigger count of the in-flight trigger.
const FIELD_INFLIGHT_TRIGGER_COUNT: &str = "inflight_trigger_count";
/// Resource id the in-flight trigger was claimed for.
const FIELD_INFLIGHT_RESOURCE_ID: &str = "inflight_resource_id";
/// Guard scope recorded for the in-flight trigger.
const FIELD_INFLIGHT_SCOPE: &str = "inflight_scope";
/// Lease token of the in-flight trigger; its presence marks the job as in flight.
const FIELD_INFLIGHT_TOKEN: &str = "inflight_token";
/// Key of the lease guarding the in-flight trigger.
const FIELD_INFLIGHT_LEASE_KEY: &str = "inflight_lease_key";
/// Millisecond timestamp at which the in-flight lease is considered expired.
const FIELD_INFLIGHT_LEASE_EXPIRES_AT: &str = "inflight_lease_expires_at";

/// Process-wide counter passed to `next_token` when minting lease tokens.
static COORDINATED_TOKEN_COUNTER: AtomicU64 = AtomicU64::new(1);
34
35#[derive(Debug, Clone)]
36pub struct ValkeyCoordinatedStateStore {
37 connection: ConnectionManager,
38 state_key_prefix: String,
39 execution_key_prefix: String,
40}
41
42impl ValkeyCoordinatedStateStore {
43 pub async fn new(url: impl AsRef<str>) -> Result<Self, redis::RedisError> {
44 Self::with_prefixes(url, DEFAULT_STATE_KEY_PREFIX, DEFAULT_EXECUTION_KEY_PREFIX).await
45 }
46
47 pub async fn with_prefixes(
48 url: impl AsRef<str>,
49 state_key_prefix: impl Into<String>,
50 execution_key_prefix: impl Into<String>,
51 ) -> Result<Self, redis::RedisError> {
52 let client = Client::open(url.as_ref())?;
53 let connection = client.get_connection_manager().await?;
54 Ok(Self {
55 connection,
56 state_key_prefix: state_key_prefix.into(),
57 execution_key_prefix: execution_key_prefix.into(),
58 })
59 }
60
61 fn state_key(&self, job_id: &str) -> String {
62 format!("{}{}", self.state_key_prefix, job_id)
63 }
64
65 fn legacy_state_key(&self, job_id: &str) -> Option<String> {
66 if self.state_key_prefix == DEFAULT_STATE_KEY_PREFIX {
67 Some(format!("{LEGACY_DEFAULT_STATE_KEY_PREFIX}{job_id}"))
68 } else {
69 None
70 }
71 }
72
73 fn resource_lock_key(&self, resource_id: &str) -> String {
74 resource_lock_key(&self.execution_key_prefix, resource_id)
75 }
76
77 fn occurrence_index_key(&self, resource_id: &str) -> String {
78 occurrence_index_key(&self.execution_key_prefix, resource_id)
79 }
80
81 fn occurrence_lease_key(&self, resource_id: &str, scheduled_at: DateTime<Utc>) -> String {
82 occurrence_lease_key(&self.execution_key_prefix, resource_id, scheduled_at)
83 }
84
85 async fn key_type(&self, key: &str) -> Result<String, ValkeyStoreError> {
86 let mut connection = self.connection.clone();
87 cmd("TYPE")
88 .arg(key)
89 .query_async(&mut connection)
90 .await
91 .map_err(ValkeyStoreError::from)
92 }
93
94 async fn load_hash(
95 &self,
96 key: &str,
97 ) -> Result<Option<CoordinatedRuntimeState>, ValkeyStoreError> {
98 let mut connection = self.connection.clone();
99 let fields: HashMap<String, String> = connection
100 .hgetall(key)
101 .await
102 .map_err(ValkeyStoreError::from)?;
103 if fields.is_empty() {
104 return Ok(None);
105 }
106
107 Ok(Some(parse_runtime_state(&fields)?))
108 }
109
110 async fn migrate_string_state(
111 &self,
112 key: &str,
113 payload: String,
114 ) -> Result<CoordinatedRuntimeState, ValkeyStoreError> {
115 let state: JobState = serde_json::from_str(&payload).map_err(ValkeyStoreError::from)?;
116 let runtime = CoordinatedRuntimeState { state, revision: 0 };
117 self.write_runtime(key, &runtime).await?;
118 Ok(runtime)
119 }
120
121 async fn write_runtime(
122 &self,
123 key: &str,
124 runtime: &CoordinatedRuntimeState,
125 ) -> Result<(), ValkeyStoreError> {
126 let mut connection = self.connection.clone();
127 let payload = serde_json::to_string(&runtime.state).map_err(ValkeyStoreError::from)?;
128 let _: () = cmd("DEL")
129 .arg(key)
130 .query_async(&mut connection)
131 .await
132 .map_err(ValkeyStoreError::from)?;
133 let _: () = cmd("HSET")
134 .arg(key)
135 .arg(FIELD_VERSION)
136 .arg(runtime.revision)
137 .arg(FIELD_STATE)
138 .arg(payload)
139 .query_async(&mut connection)
140 .await
141 .map_err(ValkeyStoreError::from)?;
142 Ok(())
143 }
144
145 async fn load_payload_state(&self, key: &str) -> Result<Option<String>, ValkeyStoreError> {
146 let mut connection = self.connection.clone();
147 connection.get(key).await.map_err(ValkeyStoreError::from)
148 }
149}
150
151impl CoordinatedStateStore for ValkeyCoordinatedStateStore {
152 type Error = ValkeyStoreError;
153
154 async fn load_or_initialize(
155 &self,
156 job_id: &str,
157 initial_state: JobState,
158 ) -> Result<CoordinatedRuntimeState, Self::Error> {
159 let key = self.state_key(job_id);
160 match self.key_type(&key).await?.as_str() {
161 "hash" => {
162 if let Some(runtime) = self.load_hash(&key).await? {
163 return Ok(runtime);
164 }
165 }
166 "string" => {
167 if let Some(payload) = self.load_payload_state(&key).await? {
168 return self.migrate_string_state(&key, payload).await;
169 }
170 }
171 "none" => {}
172 _ => {}
173 }
174
175 if let Some(legacy_key) = self.legacy_state_key(job_id) {
176 if self.key_type(&legacy_key).await?.as_str() == "string" {
177 if let Some(payload) = self.load_payload_state(&legacy_key).await? {
178 let runtime = self.migrate_string_state(&key, payload).await?;
179 let mut connection = self.connection.clone();
180 let _: () = cmd("DEL")
181 .arg(legacy_key)
182 .query_async(&mut connection)
183 .await
184 .map_err(ValkeyStoreError::from)?;
185 return Ok(runtime);
186 }
187 }
188 }
189
190 let runtime = CoordinatedRuntimeState {
191 state: initial_state,
192 revision: 0,
193 };
194 self.write_runtime(&key, &runtime).await?;
195 Ok(runtime)
196 }
197
198 async fn save_state(
199 &self,
200 job_id: &str,
201 revision: u64,
202 state: &JobState,
203 ) -> Result<bool, Self::Error> {
204 let key = self.state_key(job_id);
205 let payload = serde_json::to_string(state).map_err(ValkeyStoreError::from)?;
206 let mut connection = self.connection.clone();
207 let updated: i32 = Script::new(
208 r"
209 local version = tonumber(redis.call('HGET', KEYS[1], ARGV[1]) or '-1')
210 local inflight = redis.call('HGET', KEYS[1], ARGV[3])
211 if inflight then
212 return 0
213 end
214 if version ~= tonumber(ARGV[2]) then
215 return 0
216 end
217 redis.call('HSET', KEYS[1], ARGV[1], version + 1, ARGV[4], ARGV[5])
218 return 1
219 ",
220 )
221 .key(key)
222 .arg(FIELD_VERSION)
223 .arg(revision)
224 .arg(FIELD_INFLIGHT_TOKEN)
225 .arg(FIELD_STATE)
226 .arg(payload)
227 .invoke_async(&mut connection)
228 .await
229 .map_err(ValkeyStoreError::from)?;
230 Ok(updated == 1)
231 }
232
233 async fn reclaim_inflight(
234 &self,
235 job_id: &str,
236 resource_id: &str,
237 lease_config: CoordinatedLeaseConfig,
238 ) -> Result<Option<CoordinatedClaim>, Self::Error> {
239 let key = self.state_key(job_id);
240 let lease_key = self.occurrence_lease_key(resource_id, Utc::now());
241 let token = next_token(&COORDINATED_TOKEN_COUNTER, "coord");
242 let ttl_millis = u64::try_from(lease_config.ttl.as_millis()).unwrap_or(u64::MAX);
243 let now_millis = now_millis();
244 let expires_at_millis = now_millis.saturating_add(ttl_millis);
245 let mut connection = self.connection.clone();
246 let result: Option<Vec<String>> = Script::new(
247 r"
248 local scheduled_at = redis.call('HGET', KEYS[1], ARGV[1])
249 local catch_up = redis.call('HGET', KEYS[1], ARGV[2])
250 local trigger_count = redis.call('HGET', KEYS[1], ARGV[3])
251 local inflight_resource_id = redis.call('HGET', KEYS[1], ARGV[4])
252 local inflight_scope = redis.call('HGET', KEYS[1], ARGV[5])
253 local inflight_expires_at = tonumber(redis.call('HGET', KEYS[1], ARGV[6]) or '0')
254 local state_payload = redis.call('HGET', KEYS[1], ARGV[7])
255 local version = tonumber(redis.call('HGET', KEYS[1], ARGV[8]) or '0')
256
257 if not scheduled_at or not inflight_resource_id or not inflight_scope then
258 return nil
259 end
260 if inflight_expires_at > tonumber(ARGV[9]) then
261 return nil
262 end
263 redis.call('ZREMRANGEBYSCORE', KEYS[4], '-inf', ARGV[9])
264 if redis.call('EXISTS', KEYS[2]) == 1 then
265 return nil
266 end
267 local new_lease_key = ARGV[10] .. scheduled_at
268 local ok = redis.call('SET', new_lease_key, ARGV[11], 'NX', 'PX', ARGV[12])
269 if not ok then
270 return nil
271 end
272 redis.call('ZADD', KEYS[4], ARGV[13], new_lease_key)
273 redis.call('HSET', KEYS[1],
274 ARGV[6], ARGV[13],
275 ARGV[14], ARGV[11],
276 ARGV[15], new_lease_key,
277 ARGV[8], version + 1
278 )
279 return { tostring(version + 1), state_payload, scheduled_at, catch_up, trigger_count, inflight_scope, new_lease_key, ARGV[11] }
280 ",
281 )
282 .key(key)
283 .key(self.resource_lock_key(resource_id))
284 .key(lease_key.clone())
285 .key(self.occurrence_index_key(resource_id))
286 .arg(FIELD_INFLIGHT_SCHEDULED_AT)
287 .arg(FIELD_INFLIGHT_CATCH_UP)
288 .arg(FIELD_INFLIGHT_TRIGGER_COUNT)
289 .arg(FIELD_INFLIGHT_RESOURCE_ID)
290 .arg(FIELD_INFLIGHT_SCOPE)
291 .arg(FIELD_INFLIGHT_LEASE_EXPIRES_AT)
292 .arg(FIELD_STATE)
293 .arg(FIELD_VERSION)
294 .arg(now_millis)
295 .arg(format!("{}{}:occurrence:", self.execution_key_prefix, resource_id))
296 .arg(&token)
297 .arg(ttl_millis)
298 .arg(expires_at_millis)
299 .arg(FIELD_INFLIGHT_TOKEN)
300 .arg(FIELD_INFLIGHT_LEASE_KEY)
301 .invoke_async(&mut connection)
302 .await
303 .map_err(ValkeyStoreError::from)?;
304
305 let Some(values) = result else {
306 return Ok(None);
307 };
308 if values.len() != 8 {
309 return Ok(None);
310 }
311 let revision = values[0].parse::<u64>().unwrap_or(0);
312 let state: JobState = serde_json::from_str(&values[1]).map_err(ValkeyStoreError::from)?;
313 let scheduled_at = DateTime::parse_from_rfc3339(&values[2])
314 .map_err(|error| {
315 ValkeyStoreError::Codec(serde_json::Error::io(std::io::Error::other(
316 error.to_string(),
317 )))
318 })?
319 .with_timezone(&Utc);
320 let catch_up = values[3].parse::<bool>().unwrap_or(false);
321 let trigger_count = values[4].parse::<u32>().unwrap_or(0);
322 let scope = parse_scope(&values[5]);
323 Ok(Some(CoordinatedClaim {
324 state: CoordinatedRuntimeState { state, revision },
325 trigger: CoordinatedPendingTrigger {
326 scheduled_at,
327 catch_up,
328 trigger_count,
329 },
330 lease: ExecutionLease::new(
331 job_id.to_string(),
332 resource_id.to_string(),
333 scope,
334 Some(scheduled_at),
335 values[7].clone(),
336 values[6].clone(),
337 ),
338 replayed: true,
339 }))
340 }
341
342 async fn claim_trigger(
343 &self,
344 job_id: &str,
345 resource_id: &str,
346 revision: u64,
347 trigger: CoordinatedPendingTrigger,
348 next_state: &JobState,
349 lease_config: CoordinatedLeaseConfig,
350 ) -> Result<Option<CoordinatedClaim>, Self::Error> {
351 let key = self.state_key(job_id);
352 let lease_key = self.occurrence_lease_key(resource_id, trigger.scheduled_at);
353 let token = next_token(&COORDINATED_TOKEN_COUNTER, "coord");
354 let ttl_millis = u64::try_from(lease_config.ttl.as_millis()).unwrap_or(u64::MAX);
355 let now_millis = now_millis();
356 let expires_at_millis = now_millis.saturating_add(ttl_millis);
357 let next_state_payload =
358 serde_json::to_string(next_state).map_err(ValkeyStoreError::from)?;
359 let mut connection = self.connection.clone();
360 let new_revision: i64 = Script::new(
361 r"
362 local version = tonumber(redis.call('HGET', KEYS[1], ARGV[1]) or '-1')
363 local inflight = redis.call('HGET', KEYS[1], ARGV[2])
364 if inflight then
365 local inflight_expires_at = tonumber(redis.call('HGET', KEYS[1], ARGV[3]) or '0')
366 if inflight_expires_at > tonumber(ARGV[4]) then
367 return 0
368 end
369 return 0
370 end
371 if version ~= tonumber(ARGV[5]) then
372 return 0
373 end
374 redis.call('ZREMRANGEBYSCORE', KEYS[4], '-inf', ARGV[4])
375 if redis.call('EXISTS', KEYS[2]) == 1 then
376 return 0
377 end
378 local ok = redis.call('SET', KEYS[3], ARGV[6], 'NX', 'PX', ARGV[7])
379 if not ok then
380 return 0
381 end
382 redis.call('ZADD', KEYS[4], ARGV[8], KEYS[3])
383 redis.call('HSET', KEYS[1],
384 ARGV[1], version + 1,
385 ARGV[9], ARGV[10],
386 ARGV[11], ARGV[12],
387 ARGV[13], ARGV[14],
388 ARGV[15], ARGV[16],
389 ARGV[17], ARGV[18],
390 ARGV[19], ARGV[20],
391 ARGV[21], ARGV[6],
392 ARGV[22], KEYS[3],
393 ARGV[3], ARGV[8]
394 )
395 return version + 1
396 ",
397 )
398 .key(key)
399 .key(self.resource_lock_key(resource_id))
400 .key(&lease_key)
401 .key(self.occurrence_index_key(resource_id))
402 .arg(FIELD_VERSION)
403 .arg(FIELD_INFLIGHT_TOKEN)
404 .arg(FIELD_INFLIGHT_LEASE_EXPIRES_AT)
405 .arg(now_millis)
406 .arg(revision)
407 .arg(&token)
408 .arg(ttl_millis)
409 .arg(expires_at_millis)
410 .arg(FIELD_STATE)
411 .arg(next_state_payload)
412 .arg(FIELD_INFLIGHT_SCHEDULED_AT)
413 .arg(
414 trigger
415 .scheduled_at
416 .to_rfc3339_opts(SecondsFormat::Nanos, true),
417 )
418 .arg(FIELD_INFLIGHT_CATCH_UP)
419 .arg(trigger.catch_up)
420 .arg(FIELD_INFLIGHT_TRIGGER_COUNT)
421 .arg(trigger.trigger_count)
422 .arg(FIELD_INFLIGHT_RESOURCE_ID)
423 .arg(resource_id)
424 .arg(FIELD_INFLIGHT_SCOPE)
425 .arg("occurrence")
426 .arg(FIELD_INFLIGHT_TOKEN)
427 .arg(FIELD_INFLIGHT_LEASE_KEY)
428 .invoke_async(&mut connection)
429 .await
430 .map_err(ValkeyStoreError::from)?;
431
432 if new_revision <= 0 {
433 return Ok(None);
434 }
435
436 Ok(Some(CoordinatedClaim {
437 state: CoordinatedRuntimeState {
438 state: next_state.clone(),
439 revision: new_revision as u64,
440 },
441 trigger: trigger.clone(),
442 lease: ExecutionLease::new(
443 job_id.to_string(),
444 resource_id.to_string(),
445 ExecutionGuardScope::Occurrence,
446 Some(trigger.scheduled_at),
447 token,
448 lease_key,
449 ),
450 replayed: false,
451 }))
452 }
453
454 async fn renew(
455 &self,
456 lease: &ExecutionLease,
457 lease_config: CoordinatedLeaseConfig,
458 ) -> Result<ExecutionGuardRenewal, Self::Error> {
459 let ttl_millis = u64::try_from(lease_config.ttl.as_millis()).unwrap_or(u64::MAX);
460 let expires_at_millis = now_millis().saturating_add(ttl_millis);
461 let mut connection = self.connection.clone();
462 let renewed: i32 = Script::new(
463 r"
464 if redis.call('GET', KEYS[1]) == ARGV[1] then
465 redis.call('PEXPIRE', KEYS[1], ARGV[2])
466 redis.call('ZADD', KEYS[2], ARGV[3], KEYS[1])
467 redis.call('HSET', KEYS[3], ARGV[4], ARGV[3])
468 return 1
469 end
470 redis.call('ZREM', KEYS[2], KEYS[1])
471 return 0
472 ",
473 )
474 .key(&lease.lease_key)
475 .key(self.occurrence_index_key(&lease.resource_id))
476 .key(self.state_key(&lease.job_id))
477 .arg(&lease.token)
478 .arg(ttl_millis)
479 .arg(expires_at_millis)
480 .arg(FIELD_INFLIGHT_LEASE_EXPIRES_AT)
481 .invoke_async(&mut connection)
482 .await
483 .map_err(ValkeyStoreError::from)?;
484 Ok(if renewed == 1 {
485 ExecutionGuardRenewal::Renewed
486 } else {
487 ExecutionGuardRenewal::Lost
488 })
489 }
490
491 async fn complete(
492 &self,
493 job_id: &str,
494 revision: u64,
495 lease: &ExecutionLease,
496 state: &JobState,
497 ) -> Result<bool, Self::Error> {
498 let key = self.state_key(job_id);
499 let payload = serde_json::to_string(state).map_err(ValkeyStoreError::from)?;
500 let mut connection = self.connection.clone();
501 let completed: i32 = Script::new(
502 r"
503 local version = tonumber(redis.call('HGET', KEYS[1], ARGV[1]) or '-1')
504 local token = redis.call('HGET', KEYS[1], ARGV[2])
505 if version ~= tonumber(ARGV[3]) then
506 return 0
507 end
508 if token ~= ARGV[4] then
509 return 0
510 end
511 redis.call('DEL', KEYS[2])
512 redis.call('ZREM', KEYS[3], KEYS[2])
513 redis.call('HSET', KEYS[1], ARGV[1], version + 1, ARGV[5], ARGV[6])
514 redis.call('HDEL', KEYS[1], ARGV[2], ARGV[7], ARGV[8], ARGV[9], ARGV[10], ARGV[11], ARGV[12])
515 return 1
516 ",
517 )
518 .key(key)
519 .key(&lease.lease_key)
520 .key(self.occurrence_index_key(&lease.resource_id))
521 .arg(FIELD_VERSION)
522 .arg(FIELD_INFLIGHT_TOKEN)
523 .arg(revision)
524 .arg(&lease.token)
525 .arg(FIELD_STATE)
526 .arg(payload)
527 .arg(FIELD_INFLIGHT_SCHEDULED_AT)
528 .arg(FIELD_INFLIGHT_CATCH_UP)
529 .arg(FIELD_INFLIGHT_TRIGGER_COUNT)
530 .arg(FIELD_INFLIGHT_RESOURCE_ID)
531 .arg(FIELD_INFLIGHT_SCOPE)
532 .arg(FIELD_INFLIGHT_LEASE_KEY)
533 .invoke_async(&mut connection)
534 .await
535 .map_err(ValkeyStoreError::from)?;
536 Ok(completed == 1)
537 }
538
539 async fn delete(&self, job_id: &str) -> Result<(), Self::Error> {
540 let key = self.state_key(job_id);
541 let mut connection = self.connection.clone();
542 let _: () = cmd("DEL")
543 .arg(key)
544 .query_async(&mut connection)
545 .await
546 .map_err(ValkeyStoreError::from)?;
547 Ok(())
548 }
549
550 fn classify_store_error(error: &Self::Error) -> StoreErrorKind
551 where
552 Self: Sized,
553 {
554 if matches!(error, ValkeyStoreError::Codec(_)) {
555 StoreErrorKind::Data
556 } else if error.is_connection_issue() {
557 StoreErrorKind::Connection
558 } else {
559 StoreErrorKind::Unknown
560 }
561 }
562
563 fn classify_guard_error(error: &Self::Error) -> ExecutionGuardErrorKind
564 where
565 Self: Sized,
566 {
567 if matches!(error, ValkeyStoreError::Codec(_)) {
568 ExecutionGuardErrorKind::Data
569 } else if error.is_connection_issue() {
570 ExecutionGuardErrorKind::Connection
571 } else {
572 ExecutionGuardErrorKind::Unknown
573 }
574 }
575}
576
577fn parse_runtime_state(
578 fields: &HashMap<String, String>,
579) -> Result<CoordinatedRuntimeState, ValkeyStoreError> {
580 let revision = fields
581 .get(FIELD_VERSION)
582 .and_then(|value| value.parse::<u64>().ok())
583 .unwrap_or(0);
584 let state = serde_json::from_str(fields.get(FIELD_STATE).map(String::as_str).unwrap_or("{}"))
585 .map_err(ValkeyStoreError::from)?;
586 Ok(CoordinatedRuntimeState { state, revision })
587}
588
589fn parse_scope(raw: &str) -> ExecutionGuardScope {
590 match raw {
591 "resource" => ExecutionGuardScope::Resource,
592 _ => ExecutionGuardScope::Occurrence,
593 }
594}