use super::{DistributedScheduler, TickDecision};
use crate::beat::schedule::ScheduleDescriptor;
use crate::beat::scheduled_task::ScheduledTask;
use crate::error::BeatError;
use hostname::get as hostname_get;
use log::{info, warn};
use redis::{AsyncCommands, Client, Script};
use std::collections::{BinaryHeap, HashMap};
use std::future::Future;
use std::pin::Pin;
use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH};
use uuid::Uuid;

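// The Lua scripts below make lock renewal and release atomic: each one only
// touches the lock key if its current value still matches this instance's id,
// so a lock that has expired and been re-acquired by another instance is never
// renewed or deleted by mistake.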
const DEFAULT_KEY_PREFIX: &str = "celery_beat";
const LOCK_RENEW_SCRIPT: &str = "if redis.call('GET', KEYS[1]) == ARGV[1] then return redis.call('PEXPIRE', KEYS[1], ARGV[2]) else return 0 end";
const LOCK_RELEASE_SCRIPT: &str = "if redis.call('GET', KEYS[1]) == ARGV[1] then return redis.call('DEL', KEYS[1]) else return 0 end";

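/// Clamps a duration to at least one millisecond so it can safely be used as a
/// sleep hint without busy-looping on a zero value.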
fn ensure_min_duration(duration: Duration) -> Duration {
    if duration.is_zero() {
        Duration::from_millis(1)
    } else {
        duration
    }
}

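/// Computes how long the leader should sleep between ticks: half of the lock
/// renewal interval (never less than one millisecond), so a leader wakes up at
/// least twice per renewal interval and does not miss the renewal window.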
fn leader_sleep_hint(lock_renewal_interval: Duration) -> Duration {
    let half = lock_renewal_interval.as_secs_f64() / 2.0;
    if half < 0.001 {
        Duration::from_millis(1)
    } else {
        Duration::from_secs_f64(half)
    }
}

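/// Builds a unique instance identifier of the form `<prefix>:<hostname>:<uuid>`,
/// falling back to `unknown-host` when the hostname cannot be determined.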
fn generate_instance_id(prefix: &str) -> String {
    let host = hostname_get()
        .map(|s| s.to_string_lossy().into_owned())
        .unwrap_or_else(|_| "unknown-host".to_string());
    format!("{}:{}:{}", prefix, host, Uuid::new_v4())
}

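/// Converts a `SystemTime` to whole seconds since the Unix epoch, treating any
/// pre-epoch time as zero. Together with `epoch_to_system_time` this defines
/// the second-granularity format used for timestamps stored in Redis.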
fn system_time_to_epoch(time: SystemTime) -> u64 {
    time.duration_since(UNIX_EPOCH)
        .unwrap_or_else(|_| Duration::from_secs(0))
        .as_secs()
}

fn epoch_to_system_time(epoch: u64) -> SystemTime {
    UNIX_EPOCH + Duration::from_secs(epoch)
}

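/// Builder-style configuration for [`RedisSchedulerBackend`].
///
/// `new` fills in the defaults (key prefix `celery_beat`, 30 s lock timeout,
/// 10 s lock renewal interval, 5 s follower check and sync intervals, 750 ms
/// follower idle sleep, and an auto-generated instance id); the remaining
/// methods override individual values. Illustrative usage (shown without the
/// crate prefix, which depends on how this module is re-exported):
///
/// ```text
/// let config = RedisBackendConfig::new("redis://127.0.0.1:6379/0")
///     .key_prefix("my_app_beat")
///     .lock_timeout(Duration::from_secs(30));
/// let backend = RedisSchedulerBackend::new(config)?;
/// ```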
#[derive(Clone)]
pub struct RedisBackendConfig {
    redis_url: String,
    key_prefix: String,
    lock_timeout: Duration,
    lock_renewal_interval: Duration,
    follower_check_interval: Duration,
    sync_interval: Duration,
    follower_idle_sleep: Duration,
    instance_id: Option<String>,
}

impl RedisBackendConfig {
    pub fn new(redis_url: impl Into<String>) -> Self {
        Self {
            redis_url: redis_url.into(),
            key_prefix: DEFAULT_KEY_PREFIX.to_string(),
            lock_timeout: Duration::from_secs(30),
            lock_renewal_interval: Duration::from_secs(10),
            follower_check_interval: Duration::from_secs(5),
            sync_interval: Duration::from_secs(5),
            follower_idle_sleep: Duration::from_millis(750),
            instance_id: None,
        }
    }

    pub fn key_prefix(mut self, prefix: impl Into<String>) -> Self {
        self.key_prefix = prefix.into();
        self
    }

    pub fn lock_timeout(mut self, timeout: Duration) -> Self {
        self.lock_timeout = timeout;
        self
    }

    pub fn lock_renewal_interval(mut self, interval: Duration) -> Self {
        self.lock_renewal_interval = interval;
        self
    }

    pub fn follower_check_interval(mut self, interval: Duration) -> Self {
        self.follower_check_interval = interval;
        self
    }

    pub fn sync_interval(mut self, interval: Duration) -> Self {
        self.sync_interval = interval;
        self
    }

    pub fn follower_idle_sleep(mut self, interval: Duration) -> Self {
        self.follower_idle_sleep = interval;
        self
    }

    pub fn instance_id(mut self, id: impl Into<String>) -> Self {
        self.instance_id = Some(id.into());
        self
    }

    pub fn resolve(self) -> ResolvedRedisBackendConfig {
        let RedisBackendConfig {
            redis_url,
            key_prefix,
            lock_timeout,
            lock_renewal_interval,
            follower_check_interval,
            sync_interval,
            follower_idle_sleep,
            instance_id,
        } = self;

        let instance_id = instance_id.unwrap_or_else(|| generate_instance_id(&key_prefix));

        ResolvedRedisBackendConfig {
            redis_url,
            key_prefix: key_prefix.clone(),
            lock_key: format!("{}:lock", key_prefix),
            schedule_key: format!("{}:schedule", key_prefix),
            instance_id,
            lock_timeout,
            lock_renewal_interval,
            follower_check_interval,
            sync_interval,
            follower_idle_sleep,
        }
    }
}

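/// Fully resolved configuration produced by [`RedisBackendConfig::resolve`],
/// with the lock and schedule keys precomputed from the key prefix and a
/// concrete instance id.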
#[derive(Clone)]
pub struct ResolvedRedisBackendConfig {
    pub redis_url: String,
    pub key_prefix: String,
    pub lock_key: String,
    pub schedule_key: String,
    pub instance_id: String,
    pub lock_timeout: Duration,
    pub lock_renewal_interval: Duration,
    pub follower_check_interval: Duration,
    pub sync_interval: Duration,
    pub follower_idle_sleep: Duration,
}

impl ResolvedRedisBackendConfig {
    fn task_key(&self, name: &str) -> String {
        format!("{}:task:{}", self.key_prefix, name)
    }

    fn lock_ttl_millis(&self) -> usize {
        self.lock_timeout.as_millis() as usize
    }
}

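/// Scheduler backend that coordinates multiple beat instances through Redis.
///
/// Instances compete for a single lock key; the holder acts as leader, ticking
/// the schedule and periodically persisting task state to Redis, while the
/// other instances remain idle followers that retry the lock at a configurable
/// interval and take over if the leader stops renewing it.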
pub struct RedisSchedulerBackend {
    config: ResolvedRedisBackendConfig,
    client: Client,
    state: BackendState,
}

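/// Mutable runtime state: current leadership status, timing of the last lock
/// refresh, leadership attempt, and sync, plus the last snapshot written to
/// Redis (used to compute incremental updates).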
struct BackendState {
    is_leader: bool,
    last_lock_refresh: Option<Instant>,
    last_leader_attempt: Option<Instant>,
    last_sync: Option<Instant>,
    local_snapshot: HashMap<String, TaskState>,
    pending_full_refresh: bool,
}

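/// Per-task state that is persisted to Redis: the schedule descriptor, the next
/// and last run times, and the total run count.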
#[derive(Clone, Debug, PartialEq)]
struct TaskState {
    descriptor: ScheduleDescriptor,
    next_run_at: SystemTime,
    last_run_at: Option<SystemTime>,
    total_run_count: u32,
}

impl RedisSchedulerBackend {
    pub fn new(config: RedisBackendConfig) -> Result<Self, BeatError> {
        let resolved = config.resolve();
        let client = Client::open(resolved.redis_url.as_str())
            .map_err(|err| BeatError::RedisError(err.to_string()))?;

        Ok(Self {
            config: resolved,
            client,
            state: BackendState {
                is_leader: false,
                last_lock_refresh: None,
                last_leader_attempt: None,
                last_sync: None,
                local_snapshot: HashMap::new(),
                pending_full_refresh: false,
            },
        })
    }

    async fn get_connection(&self) -> Result<redis::aio::MultiplexedConnection, BeatError> {
        self.client
            .get_multiplexed_async_connection()
            .await
            .map_err(|err| BeatError::RedisError(err.to_string()))
    }

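    /// Attempts to become leader with a single `SET NX PX` call. On success the
    /// lock is stamped with this instance's id, leadership is recorded, and a
    /// full refresh of task state from Redis is scheduled for the next tick.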
    async fn try_acquire_lock(&mut self) -> Result<bool, BeatError> {
        let mut conn = self.get_connection().await?;
        let result: Option<String> = redis::cmd("SET")
            .arg(&self.config.lock_key)
            .arg(&self.config.instance_id)
            .arg("NX")
            .arg("PX")
            .arg(self.config.lock_ttl_millis())
            .query_async(&mut conn)
            .await
            .map_err(|err| BeatError::RedisError(err.to_string()))?;

        if result.is_some() {
            info!("Redis scheduler backend acquired leadership");
            self.state.last_lock_refresh = Some(Instant::now());
            self.state.is_leader = true;
            self.state.pending_full_refresh = true;
            Ok(true)
        } else {
            Ok(false)
        }
    }

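    /// Extends the lock's TTL via the compare-and-renew Lua script. If the
    /// script reports that the lock no longer belongs to this instance, an error
    /// is returned so the caller can step down from leadership.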
    async fn renew_lock(&mut self) -> Result<(), BeatError> {
        let mut conn = self.get_connection().await?;
        let script = Script::new(LOCK_RENEW_SCRIPT);
        let result: i32 = script
            .key(&self.config.lock_key)
            .arg(&self.config.instance_id)
            .arg(self.config.lock_ttl_millis())
            .invoke_async(&mut conn)
            .await
            .map_err(|err| BeatError::RedisError(err.to_string()))?;

        if result == 1 {
            self.state.last_lock_refresh = Some(Instant::now());
            Ok(())
        } else {
            Err(BeatError::RedisError("lost leadership".into()))
        }
    }

    async fn release_lock(&mut self) -> Result<(), BeatError> {
        let mut conn = self.get_connection().await?;
        let script = Script::new(LOCK_RELEASE_SCRIPT);
        let _: i32 = script
            .key(&self.config.lock_key)
            .arg(&self.config.instance_id)
            .invoke_async(&mut conn)
            .await
            .map_err(|err| BeatError::RedisError(err.to_string()))?;
        Ok(())
    }

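    /// Builds a `TaskState` snapshot from the in-memory schedule. Tasks whose
    /// schedule cannot produce a `ScheduleDescriptor` are excluded from the
    /// snapshot and their names are returned separately so callers can log them.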
    fn collect_task_state(
        &self,
        scheduled_tasks: &BinaryHeap<ScheduledTask>,
    ) -> (HashMap<String, TaskState>, Vec<String>) {
        let mut map = HashMap::new();
        let mut unsupported = Vec::new();

        for task in scheduled_tasks.iter() {
            let descriptor = match task.schedule.describe() {
                Some(desc) => desc,
                None => {
                    unsupported.push(task.name.clone());
                    continue;
                }
            };

            map.insert(
                task.name.clone(),
                TaskState {
                    descriptor,
                    next_run_at: task.next_call_at,
                    last_run_at: task.last_run_at,
                    total_run_count: task.total_run_count,
                },
            );
        }

        (map, unsupported)
    }

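    /// Overlays state stored in Redis onto the in-memory schedule; used after
    /// winning the lock so a new leader resumes from the previous leader's
    /// progress. Each task's hash is read and any stored `last_run_at`,
    /// `next_run_at`, and `total_run_count` fields replace the local values.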
    async fn apply_remote_state(
        &mut self,
        scheduled_tasks: &mut BinaryHeap<ScheduledTask>,
    ) -> Result<(), BeatError> {
        if scheduled_tasks.is_empty() {
            self.state.local_snapshot.clear();
            return Ok(());
        }

        let mut tasks = Vec::with_capacity(scheduled_tasks.len());
        while let Some(task) = scheduled_tasks.pop() {
            tasks.push(task);
        }

        let mut conn = self.get_connection().await?;
        for task in tasks.iter_mut() {
            let key = self.config.task_key(&task.name);
            let data: HashMap<String, String> = conn
                .hgetall(&key)
                .await
                .map_err(|err| BeatError::RedisError(err.to_string()))?;

            if data.is_empty() {
                continue;
            }

            if let Some(value) = data.get("last_run_at") {
                if let Ok(epoch) = value.parse::<u64>() {
                    task.last_run_at = Some(epoch_to_system_time(epoch));
                }
            }
            if let Some(value) = data.get("next_run_at") {
                if let Ok(epoch) = value.parse::<u64>() {
                    task.next_call_at = epoch_to_system_time(epoch);
                }
            }
            if let Some(value) = data.get("total_run_count") {
                if let Ok(count) = value.parse::<u32>() {
                    task.total_run_count = count;
                }
            }
        }

        for task in tasks.into_iter() {
            scheduled_tasks.push(task);
        }

        Ok(())
    }

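    /// Persists the given changes to Redis in a single pipeline: each upserted
    /// task gets its hash fields rewritten and its score in the schedule sorted
    /// set updated, while deleted tasks have both removed. Does nothing when
    /// there are no changes.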
    async fn write_updates(
        &mut self,
        upserts: &HashMap<String, TaskState>,
        deletes: &[String],
    ) -> Result<(), BeatError> {
        if upserts.is_empty() && deletes.is_empty() {
            return Ok(());
        }

        let mut conn = self.get_connection().await?;
        let mut pipe = redis::pipe();

        for (name, state) in upserts {
            let key = self.config.task_key(name);
            let descriptor = serde_json::to_string(&state.descriptor)
                .map_err(|err| BeatError::RedisError(err.to_string()))?;

            pipe.cmd("HSET")
                .arg(&key)
                .arg("descriptor")
                .arg(descriptor)
                .arg("task")
                .arg(name)
                .arg("total_run_count")
                .arg(state.total_run_count)
                .arg("next_run_at")
                .arg(system_time_to_epoch(state.next_run_at));

            if let Some(last_run) = state.last_run_at {
                pipe.cmd("HSET")
                    .arg(&key)
                    .arg("last_run_at")
                    .arg(system_time_to_epoch(last_run));
            }

            pipe.cmd("ZADD")
                .arg(&self.config.schedule_key)
                .arg(system_time_to_epoch(state.next_run_at))
                .arg(&key);
        }

        for name in deletes {
            let key = self.config.task_key(name);
            pipe.cmd("DEL").arg(&key);
            pipe.cmd("ZREM").arg(&self.config.schedule_key).arg(&key);
        }

        pipe.query_async::<()>(&mut conn)
            .await
            .map_err(|err| BeatError::RedisError(err.to_string()))?;

        Ok(())
    }
}

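// The blocking `SchedulerBackend` hooks are intentionally no-ops: all
// synchronization happens in the async `DistributedScheduler` hooks, which the
// scheduler reaches through `as_distributed`.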
impl super::SchedulerBackend for RedisSchedulerBackend {
    fn should_sync(&self) -> bool {
        false
    }

    fn sync(&mut self, _scheduled_tasks: &mut BinaryHeap<ScheduledTask>) -> Result<(), BeatError> {
        Ok(())
    }

    fn as_distributed(&mut self) -> Option<&mut dyn DistributedScheduler> {
        Some(self)
    }
}

impl DistributedScheduler for RedisSchedulerBackend {
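    /// Decides whether this instance should run the upcoming tick. Leaders renew
    /// the lock once the renewal interval has elapsed (stepping down and
    /// skipping the tick if renewal fails); followers periodically try to
    /// acquire the lock and otherwise skip the tick with a short sleep hint.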
    fn before_tick<'a>(
        &'a mut self,
    ) -> Pin<Box<dyn Future<Output = Result<TickDecision, BeatError>> + 'a>> {
        Box::pin(async move {
            let now = Instant::now();
            let leader_hint = leader_sleep_hint(self.config.lock_renewal_interval);
            let follower_hint = ensure_min_duration(std::cmp::min(
                self.config.follower_idle_sleep,
                self.config.follower_check_interval,
            ));

            if self.state.is_leader {
                if self
                    .state
                    .last_lock_refresh
                    .map(|instant| now.duration_since(instant) >= self.config.lock_renewal_interval)
                    .unwrap_or(true)
                {
                    if let Err(err) = self.renew_lock().await {
                        warn!("Redis scheduler backend failed to renew lock: {}", err);
                        self.state.is_leader = false;
                        return Ok(TickDecision::skip(follower_hint));
                    }
                }
                Ok(TickDecision::execute_with_hint(leader_hint))
            } else {
                if self
                    .state
                    .last_leader_attempt
                    .map(|instant| {
                        now.duration_since(instant) >= self.config.follower_check_interval
                    })
                    .unwrap_or(true)
                {
                    self.state.last_leader_attempt = Some(now);
                    if self.try_acquire_lock().await? {
                        return Ok(TickDecision::execute_with_hint(leader_hint));
                    }
                }
                Ok(TickDecision::skip(follower_hint))
            }
        })
    }

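    /// Runs on the leader after each tick: performs the pending full refresh
    /// from Redis on a freshly acquired lock, then, at most once per sync
    /// interval, diffs the current task state against the last written snapshot
    /// and writes only the changed and removed tasks.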
    fn after_tick<'a>(
        &'a mut self,
        scheduled_tasks: &'a mut BinaryHeap<ScheduledTask>,
    ) -> Pin<Box<dyn Future<Output = Result<(), BeatError>> + 'a>> {
        Box::pin(async move {
            if !self.state.is_leader {
                return Ok(());
            }

            if self.state.pending_full_refresh {
                self.apply_remote_state(scheduled_tasks).await?;
                self.state.pending_full_refresh = false;
            }

            if self
                .state
                .last_sync
                .map(|instant| instant.elapsed() < self.config.sync_interval)
                .unwrap_or(false)
            {
                return Ok(());
            }

            let (current_state, unsupported) = self.collect_task_state(scheduled_tasks);
            for name in unsupported {
                warn!(
                    "Redis scheduler backend cannot persist state for task '{}' (unsupported schedule)",
                    name
                );
            }

            let mut upserts = HashMap::new();
            for (name, state) in current_state.iter() {
                match self.state.local_snapshot.get(name) {
                    Some(existing) if existing == state => {}
                    _ => {
                        upserts.insert(name.clone(), state.clone());
                    }
                }
            }

            let mut deletes = Vec::new();
            for name in self.state.local_snapshot.keys() {
                if !current_state.contains_key(name) {
                    deletes.push(name.clone());
                }
            }

            self.write_updates(&upserts, &deletes).await?;
            self.state.local_snapshot = current_state;
            self.state.last_sync = Some(Instant::now());
            Ok(())
        })
    }

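    /// Releases the leadership lock (logging, but not failing, on errors) so
    /// another instance can take over promptly when this one shuts down.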
    fn shutdown<'a>(&'a mut self) -> Pin<Box<dyn Future<Output = Result<(), BeatError>> + 'a>> {
        Box::pin(async move {
            if self.state.is_leader {
                if let Err(err) = self.release_lock().await {
                    warn!("Redis scheduler backend failed to release lock: {}", err);
                }
                self.state.is_leader = false;
            }
            Ok(())
        })
    }
}

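// `lock_lifecycle_smoke` needs a reachable Redis instance (taken from the
// `REDIS_URL` environment variable, defaulting to redis://127.0.0.1:6379/0)
// and prints a skip message instead of failing when none is available.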
#[cfg(test)]
mod tests {
    use super::*;
    use uuid::Uuid;

    #[test]
    fn resolve_applies_defaults() {
        let config = RedisBackendConfig::new("redis://localhost:6379");
        let resolved = config.resolve();

        assert_eq!(resolved.key_prefix, DEFAULT_KEY_PREFIX);
        assert_eq!(resolved.lock_key, format!("{}:lock", DEFAULT_KEY_PREFIX));
        assert_eq!(
            resolved.schedule_key,
            format!("{}:schedule", DEFAULT_KEY_PREFIX)
        );
        assert_eq!(resolved.lock_timeout, Duration::from_secs(30));
        assert_eq!(resolved.lock_renewal_interval, Duration::from_secs(10));
        assert_eq!(resolved.follower_check_interval, Duration::from_secs(5));
        assert!(resolved.instance_id.starts_with(DEFAULT_KEY_PREFIX));
    }

    #[tokio::test]
    async fn lock_lifecycle_smoke() {
        let url =
            std::env::var("REDIS_URL").unwrap_or_else(|_| "redis://127.0.0.1:6379/0".to_string());
        let prefix = format!("test_lock_{}", Uuid::new_v4());
        let config = RedisBackendConfig::new(&url).key_prefix(&prefix);
        let mut backend = match RedisSchedulerBackend::new(config) {
            Ok(backend) => backend,
            Err(err) => {
                eprintln!("Skipping Redis lock test: {err}");
                return;
            }
        };

        match backend.try_acquire_lock().await {
            Ok(true) => {
                backend.renew_lock().await.expect("renew");
                backend.release_lock().await.expect("release");
            }
            Ok(false) => {
                eprintln!("Skipping Redis lock test: lock already held");
            }
            Err(err) => {
                eprintln!("Skipping Redis lock test: {err}");
            }
        }
    }
}