1use crate::coordinated_store::{
2 CoordinatedClaim, CoordinatedLeaseConfig, CoordinatedPendingTrigger, CoordinatedRuntimeState,
3 CoordinatedStateStore,
4};
5use crate::error::{ExecutionGuardErrorKind, StoreErrorKind};
6use crate::execution_guard::{ExecutionGuardRenewal, ExecutionGuardScope, ExecutionLease};
7use crate::model::JobState;
8use crate::valkey_execution_support::{
9 next_token, now_millis, occurrence_index_key, occurrence_lease_key, resource_lock_key,
10};
11use crate::valkey_store::ValkeyStoreError;
12use chrono::SecondsFormat;
13use chrono::{DateTime, Utc};
14use redis::{AsyncCommands, Client, Script, aio::ConnectionManager, cmd};
15use std::collections::HashMap;
16use std::sync::atomic::AtomicU64;
17
/// Default prefix prepended to a job id to form the key holding that job's runtime state.
const DEFAULT_STATE_KEY_PREFIX: &str = "scheduler:valkey:job-state:";
/// Older state-key prefix; only consulted (for migration of string payloads) when the store
/// is configured with `DEFAULT_STATE_KEY_PREFIX`.
const LEGACY_DEFAULT_STATE_KEY_PREFIX: &str = "scheduler:job-state:";
/// Default prefix handed to the resource-lock, occurrence-index and occurrence-lease key helpers.
const DEFAULT_EXECUTION_KEY_PREFIX: &str = "scheduler:valkey:execution-lease:";
// Lua sources embedded at compile time. Each one is wrapped in `Script::new` and invoked by the
// store method of the corresponding name (`renew_lease.lua` is used by `renew`).
const SAVE_STATE_LUA: &str = include_str!("lua/valkey_coordinated_store/save_state.lua");
const RECLAIM_INFLIGHT_LUA: &str =
    include_str!("lua/valkey_coordinated_store/reclaim_inflight.lua");
const CLAIM_TRIGGER_LUA: &str = include_str!("lua/valkey_coordinated_store/claim_trigger.lua");
const RENEW_LEASE_LUA: &str = include_str!("lua/valkey_coordinated_store/renew_lease.lua");
const COMPLETE_LUA: &str = include_str!("lua/valkey_coordinated_store/complete.lua");
const PAUSE_LUA: &str = include_str!("lua/valkey_coordinated_store/pause.lua");
const RESUME_LUA: &str = include_str!("lua/valkey_coordinated_store/resume.lua");

// Field names of the per-job state hash. `write_runtime` stores the revision, the JSON-encoded
// job state and the paused flag ("1"/"0") under the first three.
const FIELD_VERSION: &str = "version";
const FIELD_STATE: &str = "state";
const FIELD_PAUSED: &str = "paused";
// The `inflight_*` field names are never written directly from Rust; they are passed to the Lua
// scripts as arguments. Presumably they describe the currently claimed trigger and its lease —
// confirm against the scripts.
const FIELD_INFLIGHT_SCHEDULED_AT: &str = "inflight_scheduled_at";
const FIELD_INFLIGHT_CATCH_UP: &str = "inflight_catch_up";
const FIELD_INFLIGHT_TRIGGER_COUNT: &str = "inflight_trigger_count";
const FIELD_INFLIGHT_RESOURCE_ID: &str = "inflight_resource_id";
const FIELD_INFLIGHT_SCOPE: &str = "inflight_scope";
const FIELD_INFLIGHT_TOKEN: &str = "inflight_token";
const FIELD_INFLIGHT_LEASE_KEY: &str = "inflight_lease_key";
const FIELD_INFLIGHT_LEASE_EXPIRES_AT: &str = "inflight_lease_expires_at";

/// Counter (starting at 1) passed to `next_token` together with the `"coord"` label whenever
/// `claim_trigger` or `reclaim_inflight` needs a fresh lease token.
static COORDINATED_TOKEN_COUNTER: AtomicU64 = AtomicU64::new(1);
43
/// Coordinated job-state store that issues its commands through a `redis` `ConnectionManager`.
#[derive(Debug, Clone)]
pub struct ValkeyCoordinatedStateStore {
    /// Connection handle; each method clones it before sending commands.
    connection: ConnectionManager,
    /// Prefix prepended to a job id to build that job's state key.
    state_key_prefix: String,
    /// Prefix passed to the resource-lock, occurrence-index and occurrence-lease key helpers.
    execution_key_prefix: String,
}
50
51impl ValkeyCoordinatedStateStore {
52 pub async fn new(url: impl AsRef<str>) -> Result<Self, redis::RedisError> {
53 Self::with_prefixes(url, DEFAULT_STATE_KEY_PREFIX, DEFAULT_EXECUTION_KEY_PREFIX).await
54 }
55
56 pub async fn with_prefixes(
57 url: impl AsRef<str>,
58 state_key_prefix: impl Into<String>,
59 execution_key_prefix: impl Into<String>,
60 ) -> Result<Self, redis::RedisError> {
61 let client = Client::open(url.as_ref())?;
62 let connection = client.get_connection_manager().await?;
63 Ok(Self {
64 connection,
65 state_key_prefix: state_key_prefix.into(),
66 execution_key_prefix: execution_key_prefix.into(),
67 })
68 }
69
70 fn state_key(&self, job_id: &str) -> String {
71 format!("{}{}", self.state_key_prefix, job_id)
72 }
73
74 fn legacy_state_key(&self, job_id: &str) -> Option<String> {
75 if self.state_key_prefix == DEFAULT_STATE_KEY_PREFIX {
76 Some(format!("{LEGACY_DEFAULT_STATE_KEY_PREFIX}{job_id}"))
77 } else {
78 None
79 }
80 }
81
82 fn resource_lock_key(&self, resource_id: &str) -> String {
83 resource_lock_key(&self.execution_key_prefix, resource_id)
84 }
85
86 fn occurrence_index_key(&self, resource_id: &str) -> String {
87 occurrence_index_key(&self.execution_key_prefix, resource_id)
88 }
89
90 fn occurrence_lease_key(&self, resource_id: &str, scheduled_at: DateTime<Utc>) -> String {
91 occurrence_lease_key(&self.execution_key_prefix, resource_id, scheduled_at)
92 }
93
94 async fn key_type(&self, key: &str) -> Result<String, ValkeyStoreError> {
95 let mut connection = self.connection.clone();
96 cmd("TYPE")
97 .arg(key)
98 .query_async(&mut connection)
99 .await
100 .map_err(ValkeyStoreError::from)
101 }
102
103 async fn load_hash(
104 &self,
105 key: &str,
106 ) -> Result<Option<CoordinatedRuntimeState>, ValkeyStoreError> {
107 let mut connection = self.connection.clone();
108 let fields: HashMap<String, String> = connection
109 .hgetall(key)
110 .await
111 .map_err(ValkeyStoreError::from)?;
112 if fields.is_empty() {
113 return Ok(None);
114 }
115
116 Ok(Some(parse_runtime_state(&fields)?))
117 }
118
119 async fn migrate_string_state(
120 &self,
121 key: &str,
122 payload: String,
123 ) -> Result<CoordinatedRuntimeState, ValkeyStoreError> {
124 let state: JobState = serde_json::from_str(&payload).map_err(ValkeyStoreError::from)?;
125 let runtime = CoordinatedRuntimeState {
126 state,
127 revision: 0,
128 paused: false,
129 };
130 self.write_runtime(key, &runtime).await?;
131 Ok(runtime)
132 }
133
134 async fn write_runtime(
135 &self,
136 key: &str,
137 runtime: &CoordinatedRuntimeState,
138 ) -> Result<(), ValkeyStoreError> {
139 let mut connection = self.connection.clone();
140 let payload = serde_json::to_string(&runtime.state).map_err(ValkeyStoreError::from)?;
141 let _: () = cmd("DEL")
142 .arg(key)
143 .query_async(&mut connection)
144 .await
145 .map_err(ValkeyStoreError::from)?;
146 let _: () = cmd("HSET")
147 .arg(key)
148 .arg(FIELD_VERSION)
149 .arg(runtime.revision)
150 .arg(FIELD_STATE)
151 .arg(payload)
152 .arg(FIELD_PAUSED)
153 .arg(if runtime.paused { "1" } else { "0" })
154 .query_async(&mut connection)
155 .await
156 .map_err(ValkeyStoreError::from)?;
157 Ok(())
158 }
159
160 async fn load_payload_state(&self, key: &str) -> Result<Option<String>, ValkeyStoreError> {
161 let mut connection = self.connection.clone();
162 connection.get(key).await.map_err(ValkeyStoreError::from)
163 }
164}
165
166impl CoordinatedStateStore for ValkeyCoordinatedStateStore {
167 type Error = ValkeyStoreError;
168
169 async fn load_or_initialize(
170 &self,
171 job_id: &str,
172 initial_state: JobState,
173 ) -> Result<CoordinatedRuntimeState, Self::Error> {
174 let key = self.state_key(job_id);
175 match self.key_type(&key).await?.as_str() {
176 "hash" => {
177 if let Some(runtime) = self.load_hash(&key).await? {
178 return Ok(runtime);
179 }
180 }
181 "string" => {
182 if let Some(payload) = self.load_payload_state(&key).await? {
183 return self.migrate_string_state(&key, payload).await;
184 }
185 }
186 "none" => {}
187 _ => {}
188 }
189
190 if let Some(legacy_key) = self.legacy_state_key(job_id) {
191 if self.key_type(&legacy_key).await?.as_str() == "string" {
192 if let Some(payload) = self.load_payload_state(&legacy_key).await? {
193 let runtime = self.migrate_string_state(&key, payload).await?;
194 let mut connection = self.connection.clone();
195 let _: () = cmd("DEL")
196 .arg(legacy_key)
197 .query_async(&mut connection)
198 .await
199 .map_err(ValkeyStoreError::from)?;
200 return Ok(runtime);
201 }
202 }
203 }
204
205 let runtime = CoordinatedRuntimeState {
206 state: initial_state,
207 revision: 0,
208 paused: false,
209 };
210 self.write_runtime(&key, &runtime).await?;
211 Ok(runtime)
212 }
213
214 async fn save_state(
215 &self,
216 job_id: &str,
217 revision: u64,
218 state: &JobState,
219 ) -> Result<bool, Self::Error> {
220 let key = self.state_key(job_id);
221 let payload = serde_json::to_string(state).map_err(ValkeyStoreError::from)?;
222 let mut connection = self.connection.clone();
223 let updated: i32 = Script::new(SAVE_STATE_LUA)
224 .key(key)
225 .arg(FIELD_VERSION)
226 .arg(revision)
227 .arg(FIELD_INFLIGHT_TOKEN)
228 .arg(FIELD_STATE)
229 .arg(payload)
230 .invoke_async(&mut connection)
231 .await
232 .map_err(ValkeyStoreError::from)?;
233 Ok(updated == 1)
234 }
235
236 async fn reclaim_inflight(
237 &self,
238 job_id: &str,
239 resource_id: &str,
240 lease_config: CoordinatedLeaseConfig,
241 ) -> Result<Option<CoordinatedClaim>, Self::Error> {
242 let key = self.state_key(job_id);
243 let lease_key = self.occurrence_lease_key(resource_id, Utc::now());
244 let token = next_token(&COORDINATED_TOKEN_COUNTER, "coord");
245 let ttl_millis = u64::try_from(lease_config.ttl.as_millis()).unwrap_or(u64::MAX);
246 let now_millis = now_millis();
247 let expires_at_millis = now_millis.saturating_add(ttl_millis);
248 let mut connection = self.connection.clone();
249 let result: Option<Vec<String>> = Script::new(RECLAIM_INFLIGHT_LUA)
250 .key(key)
251 .key(self.resource_lock_key(resource_id))
252 .key(lease_key.clone())
253 .key(self.occurrence_index_key(resource_id))
254 .arg(FIELD_INFLIGHT_SCHEDULED_AT)
259 .arg(FIELD_INFLIGHT_CATCH_UP)
260 .arg(FIELD_INFLIGHT_TRIGGER_COUNT)
261 .arg(FIELD_INFLIGHT_RESOURCE_ID)
262 .arg(FIELD_INFLIGHT_SCOPE)
263 .arg(FIELD_INFLIGHT_LEASE_EXPIRES_AT)
264 .arg(FIELD_STATE)
265 .arg(FIELD_VERSION)
266 .arg(FIELD_PAUSED)
267 .arg(now_millis)
268 .arg(format!(
269 "{}{}:occurrence:",
270 self.execution_key_prefix, resource_id
271 ))
272 .arg(&token)
273 .arg(ttl_millis)
274 .arg(expires_at_millis)
275 .arg(FIELD_INFLIGHT_TOKEN)
276 .arg(FIELD_INFLIGHT_LEASE_KEY)
277 .invoke_async(&mut connection)
278 .await
279 .map_err(ValkeyStoreError::from)?;
280
281 let Some(values) = result else {
282 return Ok(None);
283 };
284 if values.len() != 8 {
285 return Ok(None);
286 }
287 let revision = values[0].parse::<u64>().unwrap_or(0);
288 let state: JobState = serde_json::from_str(&values[1]).map_err(ValkeyStoreError::from)?;
289 let scheduled_at = DateTime::parse_from_rfc3339(&values[2])
290 .map_err(|error| {
291 ValkeyStoreError::Codec(serde_json::Error::io(std::io::Error::other(
292 error.to_string(),
293 )))
294 })?
295 .with_timezone(&Utc);
296 let catch_up = values[3].parse::<bool>().unwrap_or(false);
297 let trigger_count = values[4].parse::<u32>().unwrap_or(0);
298 let scope = parse_scope(&values[5]);
299 Ok(Some(CoordinatedClaim {
300 state: CoordinatedRuntimeState {
301 state,
302 revision,
303 paused: false,
304 },
305 trigger: CoordinatedPendingTrigger {
306 scheduled_at,
307 catch_up,
308 trigger_count,
309 },
310 lease: ExecutionLease::new(
311 job_id.to_string(),
312 resource_id.to_string(),
313 scope,
314 Some(scheduled_at),
315 values[7].clone(),
316 values[6].clone(),
317 ),
318 replayed: true,
319 }))
320 }
321
322 async fn claim_trigger(
323 &self,
324 job_id: &str,
325 resource_id: &str,
326 revision: u64,
327 trigger: CoordinatedPendingTrigger,
328 next_state: &JobState,
329 lease_config: CoordinatedLeaseConfig,
330 ) -> Result<Option<CoordinatedClaim>, Self::Error> {
331 let key = self.state_key(job_id);
332 let lease_key = self.occurrence_lease_key(resource_id, trigger.scheduled_at);
333 let token = next_token(&COORDINATED_TOKEN_COUNTER, "coord");
334 let ttl_millis = u64::try_from(lease_config.ttl.as_millis()).unwrap_or(u64::MAX);
335 let now_millis = now_millis();
336 let expires_at_millis = now_millis.saturating_add(ttl_millis);
337 let next_state_payload =
338 serde_json::to_string(next_state).map_err(ValkeyStoreError::from)?;
339 let mut connection = self.connection.clone();
340 let new_revision: i64 = Script::new(CLAIM_TRIGGER_LUA)
341 .key(key)
342 .key(self.resource_lock_key(resource_id))
343 .key(&lease_key)
344 .key(self.occurrence_index_key(resource_id))
345 .arg(FIELD_VERSION)
346 .arg(FIELD_INFLIGHT_TOKEN)
347 .arg(FIELD_INFLIGHT_LEASE_EXPIRES_AT)
348 .arg(FIELD_PAUSED)
349 .arg(now_millis)
350 .arg(revision)
351 .arg(&token)
352 .arg(ttl_millis)
353 .arg(expires_at_millis)
354 .arg(FIELD_STATE)
355 .arg(next_state_payload)
356 .arg(FIELD_INFLIGHT_SCHEDULED_AT)
357 .arg(
358 trigger
359 .scheduled_at
360 .to_rfc3339_opts(SecondsFormat::Nanos, true),
361 )
362 .arg(FIELD_INFLIGHT_CATCH_UP)
363 .arg(trigger.catch_up)
364 .arg(FIELD_INFLIGHT_TRIGGER_COUNT)
365 .arg(trigger.trigger_count)
366 .arg(FIELD_INFLIGHT_RESOURCE_ID)
367 .arg(resource_id)
368 .arg(FIELD_INFLIGHT_SCOPE)
369 .arg("occurrence")
370 .arg(FIELD_INFLIGHT_TOKEN)
371 .arg(FIELD_INFLIGHT_LEASE_KEY)
372 .invoke_async(&mut connection)
373 .await
374 .map_err(ValkeyStoreError::from)?;
375
376 if new_revision <= 0 {
377 return Ok(None);
378 }
379
380 Ok(Some(CoordinatedClaim {
381 state: CoordinatedRuntimeState {
382 state: next_state.clone(),
383 revision: new_revision as u64,
384 paused: false,
385 },
386 trigger: trigger.clone(),
387 lease: ExecutionLease::new(
388 job_id.to_string(),
389 resource_id.to_string(),
390 ExecutionGuardScope::Occurrence,
391 Some(trigger.scheduled_at),
392 token,
393 lease_key,
394 ),
395 replayed: false,
396 }))
397 }
398
399 async fn renew(
400 &self,
401 lease: &ExecutionLease,
402 lease_config: CoordinatedLeaseConfig,
403 ) -> Result<ExecutionGuardRenewal, Self::Error> {
404 let ttl_millis = u64::try_from(lease_config.ttl.as_millis()).unwrap_or(u64::MAX);
405 let expires_at_millis = now_millis().saturating_add(ttl_millis);
406 let mut connection = self.connection.clone();
407 let renewed: i32 = Script::new(RENEW_LEASE_LUA)
408 .key(&lease.lease_key)
409 .key(self.occurrence_index_key(&lease.resource_id))
410 .key(self.state_key(&lease.job_id))
411 .arg(&lease.token)
412 .arg(ttl_millis)
413 .arg(expires_at_millis)
414 .arg(FIELD_INFLIGHT_LEASE_EXPIRES_AT)
415 .invoke_async(&mut connection)
416 .await
417 .map_err(ValkeyStoreError::from)?;
418 Ok(if renewed == 1 {
419 ExecutionGuardRenewal::Renewed
420 } else {
421 ExecutionGuardRenewal::Lost
422 })
423 }
424
425 async fn complete(
426 &self,
427 job_id: &str,
428 revision: u64,
429 lease: &ExecutionLease,
430 state: &JobState,
431 ) -> Result<bool, Self::Error> {
432 let key = self.state_key(job_id);
433 let payload = serde_json::to_string(state).map_err(ValkeyStoreError::from)?;
434 let mut connection = self.connection.clone();
435 let completed: i32 = Script::new(COMPLETE_LUA)
436 .key(key)
437 .key(&lease.lease_key)
438 .key(self.occurrence_index_key(&lease.resource_id))
439 .arg(FIELD_VERSION)
440 .arg(FIELD_INFLIGHT_TOKEN)
441 .arg(revision)
442 .arg(&lease.token)
443 .arg(FIELD_STATE)
444 .arg(payload)
445 .arg(FIELD_INFLIGHT_SCHEDULED_AT)
446 .arg(FIELD_INFLIGHT_CATCH_UP)
447 .arg(FIELD_INFLIGHT_TRIGGER_COUNT)
448 .arg(FIELD_INFLIGHT_RESOURCE_ID)
449 .arg(FIELD_INFLIGHT_SCOPE)
450 .arg(FIELD_INFLIGHT_LEASE_KEY)
451 .invoke_async(&mut connection)
452 .await
453 .map_err(ValkeyStoreError::from)?;
454 Ok(completed == 1)
455 }
456
457 async fn delete(&self, job_id: &str) -> Result<(), Self::Error> {
458 let key = self.state_key(job_id);
459 let mut connection = self.connection.clone();
460 let _: () = cmd("DEL")
461 .arg(key)
462 .query_async(&mut connection)
463 .await
464 .map_err(ValkeyStoreError::from)?;
465 Ok(())
466 }
467
468 async fn pause(&self, job_id: &str) -> Result<bool, Self::Error> {
469 let key = self.state_key(job_id);
470 let mut connection = self.connection.clone();
471 let changed: i32 = Script::new(PAUSE_LUA)
472 .key(key)
473 .arg(FIELD_PAUSED)
474 .invoke_async(&mut connection)
475 .await
476 .map_err(ValkeyStoreError::from)?;
477 Ok(changed == 1)
478 }
479
480 async fn resume(&self, job_id: &str) -> Result<bool, Self::Error> {
481 let key = self.state_key(job_id);
482 let mut connection = self.connection.clone();
483 let changed: i32 = Script::new(RESUME_LUA)
484 .key(key)
485 .arg(FIELD_PAUSED)
486 .invoke_async(&mut connection)
487 .await
488 .map_err(ValkeyStoreError::from)?;
489 Ok(changed == 1)
490 }
491
492 fn classify_store_error(error: &Self::Error) -> StoreErrorKind
493 where
494 Self: Sized,
495 {
496 if matches!(error, ValkeyStoreError::Codec(_)) {
497 StoreErrorKind::Data
498 } else if error.is_connection_issue() {
499 StoreErrorKind::Connection
500 } else {
501 StoreErrorKind::Unknown
502 }
503 }
504
505 fn classify_guard_error(error: &Self::Error) -> ExecutionGuardErrorKind
506 where
507 Self: Sized,
508 {
509 if matches!(error, ValkeyStoreError::Codec(_)) {
510 ExecutionGuardErrorKind::Data
511 } else if error.is_connection_issue() {
512 ExecutionGuardErrorKind::Connection
513 } else {
514 ExecutionGuardErrorKind::Unknown
515 }
516 }
517}
518
519fn parse_runtime_state(
520 fields: &HashMap<String, String>,
521) -> Result<CoordinatedRuntimeState, ValkeyStoreError> {
522 let revision = fields
523 .get(FIELD_VERSION)
524 .and_then(|value| value.parse::<u64>().ok())
525 .unwrap_or(0);
526 let paused = fields
527 .get(FIELD_PAUSED)
528 .map(|value| value == "1" || value.eq_ignore_ascii_case("true"))
529 .unwrap_or(false);
530 let state = serde_json::from_str(fields.get(FIELD_STATE).map(String::as_str).unwrap_or("{}"))
531 .map_err(ValkeyStoreError::from)?;
532 Ok(CoordinatedRuntimeState {
533 state,
534 revision,
535 paused,
536 })
537}
538
539fn parse_scope(raw: &str) -> ExecutionGuardScope {
540 match raw {
541 "resource" => ExecutionGuardScope::Resource,
542 _ => ExecutionGuardScope::Occurrence,
543 }
544}