celery/beat/backend/redis.rs

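//! Redis-backed scheduler backend for celery beat.
//!
//! Multiple beat instances compete for a Redis lock (`SET ... NX PX`, with
//! Lua-guarded renew and release), so only one leader ticks the schedule at a time.
//! The leader periodically persists per-task state (`next_run_at`, `last_run_at`,
//! `total_run_count`) to Redis hashes and a sorted set keyed by next run time, and
//! a newly elected leader rehydrates its in-memory schedule from that state before
//! it starts syncing.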
use super::{DistributedScheduler, TickDecision};
use crate::beat::schedule::ScheduleDescriptor;
use crate::beat::scheduled_task::ScheduledTask;
use crate::error::BeatError;
use hostname::get as hostname_get;
use log::{info, warn};
use redis::{AsyncCommands, Client, Script};
use std::collections::{BinaryHeap, HashMap};
use std::future::Future;
use std::pin::Pin;
use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH};
use uuid::Uuid;

const DEFAULT_KEY_PREFIX: &str = "celery_beat";
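// Both scripts check the lock value against this instance's id before acting, so a
// stale or former leader cannot renew or delete a lock now held by another instance.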
const LOCK_RENEW_SCRIPT: &str = "if redis.call('GET', KEYS[1]) == ARGV[1] then return redis.call('PEXPIRE', KEYS[1], ARGV[2]) else return 0 end";
const LOCK_RELEASE_SCRIPT: &str = "if redis.call('GET', KEYS[1]) == ARGV[1] then return redis.call('DEL', KEYS[1]) else return 0 end";

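/// Clamps a duration to at least one millisecond so sleep hints never collapse to zero.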
fn ensure_min_duration(duration: Duration) -> Duration {
    if duration.is_zero() {
        Duration::from_millis(1)
    } else {
        duration
    }
}

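/// Sleep hint for the leader: half the renewal interval, so the renewal check in
/// `before_tick` runs well before the lock is due for renewal.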
fn leader_sleep_hint(lock_renewal_interval: Duration) -> Duration {
    let half = lock_renewal_interval.as_secs_f64() / 2.0;
    if half < 0.001 {
        Duration::from_millis(1)
    } else {
        Duration::from_secs_f64(half)
    }
}

fn generate_instance_id(prefix: &str) -> String {
    let host = hostname_get()
        .map(|s| s.to_string_lossy().into_owned())
        .unwrap_or_else(|_| "unknown-host".to_string());
    format!("{}:{}:{}", prefix, host, Uuid::new_v4())
}

fn system_time_to_epoch(time: SystemTime) -> u64 {
    time.duration_since(UNIX_EPOCH)
        .unwrap_or_else(|_| Duration::from_secs(0))
        .as_secs()
}

fn epoch_to_system_time(epoch: u64) -> SystemTime {
    UNIX_EPOCH + Duration::from_secs(epoch)
}

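/// Builder-style configuration for [`RedisSchedulerBackend`].
///
/// A minimal usage sketch; the URL and prefix are placeholders and the import path
/// depends on how this module is re-exported:
///
/// ```ignore
/// use std::time::Duration;
///
/// let config = RedisBackendConfig::new("redis://127.0.0.1:6379/0")
///     .key_prefix("my_app_beat")
///     .lock_timeout(Duration::from_secs(30))
///     .lock_renewal_interval(Duration::from_secs(10));
/// let backend = RedisSchedulerBackend::new(config)?;
/// ```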
#[derive(Clone)]
pub struct RedisBackendConfig {
    redis_url: String,
    key_prefix: String,
    lock_timeout: Duration,
    lock_renewal_interval: Duration,
    follower_check_interval: Duration,
    sync_interval: Duration,
    follower_idle_sleep: Duration,
    instance_id: Option<String>,
}

impl RedisBackendConfig {
    pub fn new(redis_url: impl Into<String>) -> Self {
        Self {
            redis_url: redis_url.into(),
            key_prefix: DEFAULT_KEY_PREFIX.to_string(),
            lock_timeout: Duration::from_secs(30),
            lock_renewal_interval: Duration::from_secs(10),
            follower_check_interval: Duration::from_secs(5),
            sync_interval: Duration::from_secs(5),
            follower_idle_sleep: Duration::from_millis(750),
            instance_id: None,
        }
    }

    pub fn key_prefix(mut self, prefix: impl Into<String>) -> Self {
        self.key_prefix = prefix.into();
        self
    }

    pub fn lock_timeout(mut self, timeout: Duration) -> Self {
        self.lock_timeout = timeout;
        self
    }

    pub fn lock_renewal_interval(mut self, interval: Duration) -> Self {
        self.lock_renewal_interval = interval;
        self
    }

    pub fn follower_check_interval(mut self, interval: Duration) -> Self {
        self.follower_check_interval = interval;
        self
    }

    pub fn sync_interval(mut self, interval: Duration) -> Self {
        self.sync_interval = interval;
        self
    }

    pub fn follower_idle_sleep(mut self, interval: Duration) -> Self {
        self.follower_idle_sleep = interval;
        self
    }

    pub fn instance_id(mut self, id: impl Into<String>) -> Self {
        self.instance_id = Some(id.into());
        self
    }

    pub fn resolve(self) -> ResolvedRedisBackendConfig {
        let RedisBackendConfig {
            redis_url,
            key_prefix,
            lock_timeout,
            lock_renewal_interval,
            follower_check_interval,
            sync_interval,
            follower_idle_sleep,
            instance_id,
        } = self;

        let instance_id = instance_id.unwrap_or_else(|| generate_instance_id(&key_prefix));

        ResolvedRedisBackendConfig {
            redis_url,
            key_prefix: key_prefix.clone(),
            lock_key: format!("{}:lock", key_prefix),
            schedule_key: format!("{}:schedule", key_prefix),
            instance_id,
            lock_timeout,
            lock_renewal_interval,
            follower_check_interval,
            sync_interval,
            follower_idle_sleep,
        }
    }
}

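/// Output of [`RedisBackendConfig::resolve`]: the raw settings plus the derived
/// Redis keys and a concrete instance id.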
#[derive(Clone)]
pub struct ResolvedRedisBackendConfig {
    pub redis_url: String,
    pub key_prefix: String,
    pub lock_key: String,
    pub schedule_key: String,
    pub instance_id: String,
    pub lock_timeout: Duration,
    pub lock_renewal_interval: Duration,
    pub follower_check_interval: Duration,
    pub sync_interval: Duration,
    pub follower_idle_sleep: Duration,
}

impl ResolvedRedisBackendConfig {
    fn task_key(&self, name: &str) -> String {
        format!("{}:task:{}", self.key_prefix, name)
    }

    fn lock_ttl_millis(&self) -> usize {
        self.lock_timeout.as_millis() as usize
    }
}

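/// A scheduler backend that coordinates multiple beat instances through Redis.
///
/// Followers periodically try to take the lock; the current leader renews it,
/// executes ticks, and syncs schedule state back to Redis.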
pub struct RedisSchedulerBackend {
    config: ResolvedRedisBackendConfig,
    client: Client,
    state: BackendState,
}

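/// Mutable runtime state: the leadership flag, bookkeeping timestamps, and the
/// last snapshot written to Redis (used to compute incremental updates).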
struct BackendState {
    is_leader: bool,
    last_lock_refresh: Option<Instant>,
    last_leader_attempt: Option<Instant>,
    last_sync: Option<Instant>,
    local_snapshot: HashMap<String, TaskState>,
    pending_full_refresh: bool,
}

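/// Per-task state persisted to Redis; comparing against the previous snapshot
/// decides whether a task needs to be rewritten during a sync.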
#[derive(Clone, Debug, PartialEq)]
struct TaskState {
    descriptor: ScheduleDescriptor,
    next_run_at: SystemTime,
    last_run_at: Option<SystemTime>,
    total_run_count: u32,
}

impl RedisSchedulerBackend {
    pub fn new(config: RedisBackendConfig) -> Result<Self, BeatError> {
        let resolved = config.resolve();
        let client = Client::open(resolved.redis_url.as_str())
            .map_err(|err| BeatError::RedisError(err.to_string()))?;

        Ok(Self {
            config: resolved,
            client,
            state: BackendState {
                is_leader: false,
                last_lock_refresh: None,
                last_leader_attempt: None,
                last_sync: None,
                local_snapshot: HashMap::new(),
                pending_full_refresh: false,
            },
        })
    }

    async fn get_connection(&self) -> Result<redis::aio::MultiplexedConnection, BeatError> {
        self.client
            .get_multiplexed_async_connection()
            .await
            .map_err(|err| BeatError::RedisError(err.to_string()))
    }

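    /// Attempts to take leadership with `SET lock_key instance_id NX PX <ttl>`.
    /// On success, a full refresh from Redis is scheduled before the first sync.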
    async fn try_acquire_lock(&mut self) -> Result<bool, BeatError> {
        let mut conn = self.get_connection().await?;
        let result: Option<String> = redis::cmd("SET")
            .arg(&self.config.lock_key)
            .arg(&self.config.instance_id)
            .arg("NX")
            .arg("PX")
            .arg(self.config.lock_ttl_millis())
            .query_async(&mut conn)
            .await
            .map_err(|err| BeatError::RedisError(err.to_string()))?;

        if result.is_some() {
            info!("Redis scheduler backend acquired leadership");
            self.state.last_lock_refresh = Some(Instant::now());
            self.state.is_leader = true;
            self.state.pending_full_refresh = true;
            Ok(true)
        } else {
            Ok(false)
        }
    }

    async fn renew_lock(&mut self) -> Result<(), BeatError> {
        let mut conn = self.get_connection().await?;
        let script = Script::new(LOCK_RENEW_SCRIPT);
        let result: i32 = script
            .key(&self.config.lock_key)
            .arg(&self.config.instance_id)
            .arg(self.config.lock_ttl_millis())
            .invoke_async(&mut conn)
            .await
            .map_err(|err| BeatError::RedisError(err.to_string()))?;

        if result == 1 {
            self.state.last_lock_refresh = Some(Instant::now());
            Ok(())
        } else {
            Err(BeatError::RedisError("lost leadership".into()))
        }
    }

    async fn release_lock(&mut self) -> Result<(), BeatError> {
        let mut conn = self.get_connection().await?;
        let script = Script::new(LOCK_RELEASE_SCRIPT);
        let _: i32 = script
            .key(&self.config.lock_key)
            .arg(&self.config.instance_id)
            .invoke_async(&mut conn)
            .await
            .map_err(|err| BeatError::RedisError(err.to_string()))?;
        Ok(())
    }

    fn collect_task_state(
        &self,
        scheduled_tasks: &BinaryHeap<ScheduledTask>,
    ) -> (HashMap<String, TaskState>, Vec<String>) {
        let mut map = HashMap::new();
        let mut unsupported = Vec::new();

        for task in scheduled_tasks.iter() {
            let descriptor = match task.schedule.describe() {
                Some(desc) => desc,
                None => {
                    unsupported.push(task.name.clone());
                    continue;
                }
            };

            map.insert(
                task.name.clone(),
                TaskState {
                    descriptor,
                    next_run_at: task.next_call_at,
                    last_run_at: task.last_run_at,
                    total_run_count: task.total_run_count,
                },
            );
        }

        (map, unsupported)
    }

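    /// Rehydrates the in-memory schedule from Redis after this instance becomes
    /// leader, adopting any persisted `last_run_at`, `next_run_at`, and
    /// `total_run_count` values.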
    async fn apply_remote_state(
        &mut self,
        scheduled_tasks: &mut BinaryHeap<ScheduledTask>,
    ) -> Result<(), BeatError> {
        if scheduled_tasks.is_empty() {
            self.state.local_snapshot.clear();
            return Ok(());
        }

        let mut tasks = Vec::with_capacity(scheduled_tasks.len());
        while let Some(task) = scheduled_tasks.pop() {
            tasks.push(task);
        }

        let mut conn = self.get_connection().await?;
        for task in tasks.iter_mut() {
            let key = self.config.task_key(&task.name);
            let data: HashMap<String, String> = conn
                .hgetall(&key)
                .await
                .map_err(|err| BeatError::RedisError(err.to_string()))?;

            if data.is_empty() {
                continue;
            }

            if let Some(value) = data.get("last_run_at") {
                if let Ok(epoch) = value.parse::<u64>() {
                    task.last_run_at = Some(epoch_to_system_time(epoch));
                }
            }
            if let Some(value) = data.get("next_run_at") {
                if let Ok(epoch) = value.parse::<u64>() {
                    task.next_call_at = epoch_to_system_time(epoch);
                }
            }
            if let Some(value) = data.get("total_run_count") {
                if let Ok(count) = value.parse::<u32>() {
                    task.total_run_count = count;
                }
            }
        }

        for task in tasks.into_iter() {
            scheduled_tasks.push(task);
        }

        Ok(())
    }

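    /// Writes changed tasks (HSET plus ZADD on the schedule index) and removes
    /// deleted ones (DEL plus ZREM) in a single pipelined round trip.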
    async fn write_updates(
        &mut self,
        upserts: &HashMap<String, TaskState>,
        deletes: &[String],
    ) -> Result<(), BeatError> {
        if upserts.is_empty() && deletes.is_empty() {
            return Ok(());
        }

        let mut conn = self.get_connection().await?;
        let mut pipe = redis::pipe();

        for (name, state) in upserts {
            let key = self.config.task_key(name);
            let descriptor = serde_json::to_string(&state.descriptor)
                .map_err(|err| BeatError::RedisError(err.to_string()))?;

            pipe.cmd("HSET")
                .arg(&key)
                .arg("descriptor")
                .arg(descriptor)
                .arg("task")
                .arg(name)
                .arg("total_run_count")
                .arg(state.total_run_count)
                .arg("next_run_at")
                .arg(system_time_to_epoch(state.next_run_at));

            if let Some(last_run) = state.last_run_at {
                pipe.cmd("HSET")
                    .arg(&key)
                    .arg("last_run_at")
                    .arg(system_time_to_epoch(last_run));
            }

            pipe.cmd("ZADD")
                .arg(&self.config.schedule_key)
                .arg(system_time_to_epoch(state.next_run_at))
                .arg(&key);
        }

        for name in deletes {
            let key = self.config.task_key(name);
            pipe.cmd("DEL").arg(&key);
            pipe.cmd("ZREM").arg(&self.config.schedule_key).arg(&key);
        }

        pipe.query_async::<()>(&mut conn)
            .await
            .map_err(|err| BeatError::RedisError(err.to_string()))?;

        Ok(())
    }
}

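// The synchronous `SchedulerBackend` hooks are no-ops here: persistence happens
// through the async `DistributedScheduler` hooks below.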
impl super::SchedulerBackend for RedisSchedulerBackend {
    fn should_sync(&self) -> bool {
        false
    }

    fn sync(&mut self, _scheduled_tasks: &mut BinaryHeap<ScheduledTask>) -> Result<(), BeatError> {
        Ok(())
    }

    fn as_distributed(&mut self) -> Option<&mut dyn DistributedScheduler> {
        Some(self)
    }
}

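// Tick protocol: `before_tick` decides whether this instance may run the tick
// (leader) or should sleep (follower), `after_tick` persists schedule changes while
// leading, and `shutdown` releases the lock so another instance can take over.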
impl DistributedScheduler for RedisSchedulerBackend {
    fn before_tick<'a>(
        &'a mut self,
    ) -> Pin<Box<dyn Future<Output = Result<TickDecision, BeatError>> + 'a>> {
        Box::pin(async move {
            let now = Instant::now();
            let leader_hint = leader_sleep_hint(self.config.lock_renewal_interval);
            let follower_hint = ensure_min_duration(std::cmp::min(
                self.config.follower_idle_sleep,
                self.config.follower_check_interval,
            ));

            if self.state.is_leader {
                if self
                    .state
                    .last_lock_refresh
                    .map(|instant| now.duration_since(instant) >= self.config.lock_renewal_interval)
                    .unwrap_or(true)
                {
                    if let Err(err) = self.renew_lock().await {
                        warn!("Redis scheduler backend failed to renew lock: {}", err);
                        self.state.is_leader = false;
                        return Ok(TickDecision::skip(follower_hint));
                    }
                }
                Ok(TickDecision::execute_with_hint(leader_hint))
            } else {
                if self
                    .state
                    .last_leader_attempt
                    .map(|instant| {
                        now.duration_since(instant) >= self.config.follower_check_interval
                    })
                    .unwrap_or(true)
                {
                    self.state.last_leader_attempt = Some(now);
                    if self.try_acquire_lock().await? {
                        return Ok(TickDecision::execute_with_hint(leader_hint));
                    }
                }
                Ok(TickDecision::skip(follower_hint))
            }
        })
    }

    fn after_tick<'a>(
        &'a mut self,
        scheduled_tasks: &'a mut BinaryHeap<ScheduledTask>,
    ) -> Pin<Box<dyn Future<Output = Result<(), BeatError>> + 'a>> {
        Box::pin(async move {
            if !self.state.is_leader {
                return Ok(());
            }

            if self.state.pending_full_refresh {
                self.apply_remote_state(scheduled_tasks).await?;
                self.state.pending_full_refresh = false;
            }

            if self
                .state
                .last_sync
                .map(|instant| instant.elapsed() < self.config.sync_interval)
                .unwrap_or(false)
            {
                return Ok(());
            }

            let (current_state, unsupported) = self.collect_task_state(scheduled_tasks);
            for name in unsupported {
                warn!(
                    "Redis scheduler backend skipping task '{}' (unsupported schedule)",
                    name
                );
            }

            let mut upserts = HashMap::new();
            for (name, state) in current_state.iter() {
                match self.state.local_snapshot.get(name) {
                    Some(existing) if existing == state => {}
                    _ => {
                        upserts.insert(name.clone(), state.clone());
                    }
                }
            }

            let mut deletes = Vec::new();
            for name in self.state.local_snapshot.keys() {
                if !current_state.contains_key(name) {
                    deletes.push(name.clone());
                }
            }

            self.write_updates(&upserts, &deletes).await?;
            self.state.local_snapshot = current_state;
            self.state.last_sync = Some(Instant::now());
            Ok(())
        })
    }

    fn shutdown<'a>(&'a mut self) -> Pin<Box<dyn Future<Output = Result<(), BeatError>> + 'a>> {
        Box::pin(async move {
            if self.state.is_leader {
                if let Err(err) = self.release_lock().await {
                    warn!("Redis scheduler backend failed to release lock: {}", err);
                }
                self.state.is_leader = false;
            }
            Ok(())
        })
    }
}

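// `lock_lifecycle_smoke` talks to a real Redis server (`REDIS_URL`, defaulting to
// localhost) and skips itself with a message when no server is reachable or the
// lock is already held.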
#[cfg(test)]
mod tests {
    use super::*;
    use uuid::Uuid;

    #[test]
    fn resolve_applies_defaults() {
        let config = RedisBackendConfig::new("redis://localhost:6379");
        let resolved = config.resolve();

        assert_eq!(resolved.key_prefix, DEFAULT_KEY_PREFIX);
        assert_eq!(resolved.lock_key, format!("{}:lock", DEFAULT_KEY_PREFIX));
        assert_eq!(
            resolved.schedule_key,
            format!("{}:schedule", DEFAULT_KEY_PREFIX)
        );
        assert_eq!(resolved.lock_timeout, Duration::from_secs(30));
        assert_eq!(resolved.lock_renewal_interval, Duration::from_secs(10));
        assert_eq!(resolved.follower_check_interval, Duration::from_secs(5));
        assert!(resolved.instance_id.starts_with(DEFAULT_KEY_PREFIX));
    }

    #[tokio::test]
    async fn lock_lifecycle_smoke() {
        let url =
            std::env::var("REDIS_URL").unwrap_or_else(|_| "redis://127.0.0.1:6379/0".to_string());
        let prefix = format!("test_lock_{}", Uuid::new_v4());
        let config = RedisBackendConfig::new(&url).key_prefix(&prefix);
        let mut backend = match RedisSchedulerBackend::new(config) {
            Ok(backend) => backend,
            Err(err) => {
                eprintln!("Skipping Redis lock test: {err}");
                return;
            }
        };

        match backend.try_acquire_lock().await {
            Ok(true) => {
                backend.renew_lock().await.expect("renew");
                backend.release_lock().await.expect("release");
            }
            Ok(false) => {
                eprintln!("Skipping Redis lock test: lock already held");
            }
            Err(err) => {
                eprintln!("Skipping Redis lock test: {err}");
            }
        }
    }
}
581}