Skip to main content

opendev_runtime/
task_scheduler.rs

1//! Background task scheduler for deferred and periodic work (#95).
2//!
3//! Provides [`TaskScheduler`] which manages one-shot and recurring async tasks
4//! using `tokio::spawn` and `tokio::time`.
5
6use std::collections::HashMap;
7use std::future::Future;
8use std::pin::Pin;
9use std::sync::Arc;
10use std::sync::atomic::{AtomicU64, Ordering};
11use std::time::Duration;
12
13use tokio::sync::Mutex;
14use tokio::task::JoinHandle;
15use tracing::debug;
16
/// Identifier handed out by a [`TaskScheduler`] for each task it schedules.
///
/// Pass it back to [`TaskScheduler::cancel`] to stop that task.
pub type TaskId = u64;
19
20/// Internal state shared across clones.
21struct SchedulerInner {
22    next_id: AtomicU64,
23    tasks: Mutex<HashMap<TaskId, TaskEntry>>,
24}
25
26struct TaskEntry {
27    label: String,
28    handle: JoinHandle<()>,
29}
30
31/// A scheduler for one-shot and periodic background tasks.
32///
33/// All tasks are cancelled when [`TaskScheduler::shutdown`] is called or when
34/// the scheduler is dropped.
35#[derive(Clone)]
36pub struct TaskScheduler {
37    inner: Arc<SchedulerInner>,
38}
39
40impl TaskScheduler {
41    /// Create a new task scheduler.
42    pub fn new() -> Self {
43        Self {
44            inner: Arc::new(SchedulerInner {
45                next_id: AtomicU64::new(1),
46                tasks: Mutex::new(HashMap::new()),
47            }),
48        }
49    }
50
51    /// Schedule a one-shot task that executes after `delay`.
52    ///
53    /// Returns a [`TaskId`] that can be used to cancel the task.
54    pub fn schedule_once<F, Fut>(&self, delay: Duration, label: impl Into<String>, f: F) -> TaskId
55    where
56        F: FnOnce() -> Fut + Send + 'static,
57        Fut: Future<Output = ()> + Send + 'static,
58    {
59        let id = self.inner.next_id.fetch_add(1, Ordering::Relaxed);
60        let label_str = label.into();
61        let inner = Arc::clone(&self.inner);
62        let task_label = label_str.clone();
63
64        let handle = tokio::spawn(async move {
65            tokio::time::sleep(delay).await;
66            debug!("Running one-shot task {id} ({task_label})");
67            f().await;
68            // Remove self from the map after completion.
69            inner.tasks.lock().await.remove(&id);
70        });
71
72        {
73            let inner = Arc::clone(&self.inner);
74            let label_str = label_str.clone();
75            tokio::spawn(async move {
76                inner.tasks.lock().await.insert(
77                    id,
78                    TaskEntry {
79                        label: label_str,
80                        handle,
81                    },
82                );
83            });
84        }
85
86        id
87    }
88
89    /// Schedule a periodic task that runs every `interval`.
90    ///
91    /// The task function receives the current tick count (starting at 1).
92    /// The first execution happens after `interval` elapses.
93    ///
94    /// Returns a [`TaskId`] that can be used to cancel the task.
95    pub fn schedule_periodic<F, Fut>(
96        &self,
97        interval: Duration,
98        label: impl Into<String>,
99        f: F,
100    ) -> TaskId
101    where
102        F: Fn(u64) -> Fut + Send + Sync + 'static,
103        Fut: Future<Output = ()> + Send + 'static,
104    {
105        let id = self.inner.next_id.fetch_add(1, Ordering::Relaxed);
106        let label_str = label.into();
107        let task_label = label_str.clone();
108
109        let handle = tokio::spawn(async move {
110            let mut ticker = tokio::time::interval(interval);
111            // First tick fires immediately — skip it so the first execution
112            // happens after one full interval.
113            ticker.tick().await;
114
115            let mut tick_count: u64 = 0;
116            loop {
117                ticker.tick().await;
118                tick_count += 1;
119                debug!("Periodic task {id} ({task_label}) tick {tick_count}");
120                f(tick_count).await;
121            }
122        });
123
124        let inner = Arc::clone(&self.inner);
125        let label_owned = label_str;
126        tokio::spawn(async move {
127            inner.tasks.lock().await.insert(
128                id,
129                TaskEntry {
130                    label: label_owned,
131                    handle,
132                },
133            );
134        });
135
136        id
137    }
138
139    /// Cancel a previously scheduled task.
140    ///
141    /// Returns `true` if the task was found and cancelled.
142    pub async fn cancel(&self, id: TaskId) -> bool {
143        if let Some(entry) = self.inner.tasks.lock().await.remove(&id) {
144            entry.handle.abort();
145            debug!("Cancelled task {id} ({})", entry.label);
146            true
147        } else {
148            false
149        }
150    }
151
152    /// Return the number of active (not yet completed / cancelled) tasks.
153    pub async fn active_count(&self) -> usize {
154        self.inner.tasks.lock().await.len()
155    }
156
157    /// Cancel all tasks and shut down the scheduler.
158    pub async fn shutdown(&self) {
159        let mut tasks = self.inner.tasks.lock().await;
160        for (id, entry) in tasks.drain() {
161            entry.handle.abort();
162            debug!("Shutdown: cancelled task {id} ({})", entry.label);
163        }
164    }
165}
166
167impl Default for TaskScheduler {
168    fn default() -> Self {
169        Self::new()
170    }
171}
172
173impl std::fmt::Debug for TaskScheduler {
174    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
175        f.debug_struct("TaskScheduler").finish()
176    }
177}
178
/// Pin and box `f` into a type-erased, `Send` future.
///
/// This does not schedule or run anything by itself; it only erases the
/// concrete future type. That helps when a closure passed to
/// [`TaskScheduler::schedule_once`] or [`TaskScheduler::schedule_periodic`]
/// must return a single future type from several branches.
pub fn boxed_task<F>(f: F) -> Pin<Box<dyn Future<Output = ()> + Send + 'static>>
where
    F: Future<Output = ()> + Send + 'static,
{
    Box::pin(f)
}
186
// Unit tests live in a sibling file and are compiled only under `cfg(test)`.
#[path = "task_scheduler_tests.rs"]
#[cfg(test)]
mod tests;