// opendev_runtime/task_scheduler.rs
use std::collections::HashMap;
use std::future::Future;
use std::pin::Pin;
use std::sync::{Arc, MutexGuard, PoisonError};
use std::sync::atomic::{AtomicU64, Ordering};
use std::time::Duration;

use tokio::sync::Mutex;
use tokio::task::JoinHandle;
use tracing::debug;

/// Identifier handed out by [`TaskScheduler`] for each scheduled task.
///
/// Ids come from a per-scheduler counter, so they are unique within one
/// scheduler but may repeat across different schedulers.
pub type TaskId = u64;

20struct SchedulerInner {
22 next_id: AtomicU64,
23 tasks: Mutex<HashMap<TaskId, TaskEntry>>,
24}
25
26struct TaskEntry {
27 label: String,
28 handle: JoinHandle<()>,
29}
30
31#[derive(Clone)]
36pub struct TaskScheduler {
37 inner: Arc<SchedulerInner>,
38}
39
40impl TaskScheduler {
41 pub fn new() -> Self {
43 Self {
44 inner: Arc::new(SchedulerInner {
45 next_id: AtomicU64::new(1),
46 tasks: Mutex::new(HashMap::new()),
47 }),
48 }
49 }
50
51 pub fn schedule_once<F, Fut>(&self, delay: Duration, label: impl Into<String>, f: F) -> TaskId
55 where
56 F: FnOnce() -> Fut + Send + 'static,
57 Fut: Future<Output = ()> + Send + 'static,
58 {
59 let id = self.inner.next_id.fetch_add(1, Ordering::Relaxed);
60 let label_str = label.into();
61 let inner = Arc::clone(&self.inner);
62 let task_label = label_str.clone();
63
64 let handle = tokio::spawn(async move {
65 tokio::time::sleep(delay).await;
66 debug!("Running one-shot task {id} ({task_label})");
67 f().await;
68 inner.tasks.lock().await.remove(&id);
70 });
71
72 {
73 let inner = Arc::clone(&self.inner);
74 let label_str = label_str.clone();
75 tokio::spawn(async move {
76 inner.tasks.lock().await.insert(
77 id,
78 TaskEntry {
79 label: label_str,
80 handle,
81 },
82 );
83 });
84 }
85
86 id
87 }
88
89 pub fn schedule_periodic<F, Fut>(
96 &self,
97 interval: Duration,
98 label: impl Into<String>,
99 f: F,
100 ) -> TaskId
101 where
102 F: Fn(u64) -> Fut + Send + Sync + 'static,
103 Fut: Future<Output = ()> + Send + 'static,
104 {
105 let id = self.inner.next_id.fetch_add(1, Ordering::Relaxed);
106 let label_str = label.into();
107 let task_label = label_str.clone();
108
109 let handle = tokio::spawn(async move {
110 let mut ticker = tokio::time::interval(interval);
111 ticker.tick().await;
114
115 let mut tick_count: u64 = 0;
116 loop {
117 ticker.tick().await;
118 tick_count += 1;
119 debug!("Periodic task {id} ({task_label}) tick {tick_count}");
120 f(tick_count).await;
121 }
122 });
123
124 let inner = Arc::clone(&self.inner);
125 let label_owned = label_str;
126 tokio::spawn(async move {
127 inner.tasks.lock().await.insert(
128 id,
129 TaskEntry {
130 label: label_owned,
131 handle,
132 },
133 );
134 });
135
136 id
137 }
138
139 pub async fn cancel(&self, id: TaskId) -> bool {
143 if let Some(entry) = self.inner.tasks.lock().await.remove(&id) {
144 entry.handle.abort();
145 debug!("Cancelled task {id} ({})", entry.label);
146 true
147 } else {
148 false
149 }
150 }
151
152 pub async fn active_count(&self) -> usize {
154 self.inner.tasks.lock().await.len()
155 }
156
157 pub async fn shutdown(&self) {
159 let mut tasks = self.inner.tasks.lock().await;
160 for (id, entry) in tasks.drain() {
161 entry.handle.abort();
162 debug!("Shutdown: cancelled task {id} ({})", entry.label);
163 }
164 }
165}
166
167impl Default for TaskScheduler {
168 fn default() -> Self {
169 Self::new()
170 }
171}
172
173impl std::fmt::Debug for TaskScheduler {
174 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
175 f.debug_struct("TaskScheduler").finish()
176 }
177}
178
/// Type-erases a `Send + 'static` unit future into a pinned, boxed trait object.
///
/// Useful when futures of different concrete types need a single common type,
/// e.g. to be stored together or returned from the same closure.
pub fn boxed_task<F>(f: F) -> Pin<Box<dyn Future<Output = ()> + Send + 'static>>
where
    F: Future<Output = ()> + Send + 'static,
{
    // Heap-allocate first, then pin; the result unsizes to the trait object.
    Box::into_pin(Box::new(f))
}

// Unit tests live in `task_scheduler_tests.rs` next to this file and are
// compiled only for test builds.
#[cfg(test)]
#[path = "task_scheduler_tests.rs"]
mod tests;