1use async_trait::async_trait;
2use chrono::{DateTime, Utc};
3use std::{
4 collections::BTreeMap,
5 sync::{
6 Arc,
7 atomic::{AtomicBool, Ordering},
8 },
9 time::Duration,
10};
11use tokio::sync::{RwLock, broadcast, mpsc};
12use tracing::{debug, info};
13use uuid::Uuid;
14
15use crate::task::{TaskHandle, TaskId, TaskState};
16use wae_types::{WaeError, WaeResult};
17
18#[derive(Debug, Clone)]
20pub struct DelayedTask<T: Send + Sync + Clone + 'static> {
21 pub id: TaskId,
23 pub name: String,
25 pub execute_at: DateTime<Utc>,
27 pub priority: u32,
29 pub data: T,
31 pub created_at: DateTime<Utc>,
33}
34
35impl<T: Send + Sync + Clone + 'static> DelayedTask<T> {
36 pub fn new(name: String, execute_at: DateTime<Utc>, data: T) -> Self {
38 Self { id: Uuid::new_v4().to_string(), name, execute_at, priority: 0, data, created_at: Utc::now() }
39 }
40
41 pub fn with_priority(mut self, priority: u32) -> Self {
43 self.priority = priority;
44 self
45 }
46
47 pub fn is_due(&self) -> bool {
49 Utc::now() >= self.execute_at
50 }
51
52 pub fn remaining(&self) -> Duration {
54 let now = Utc::now();
55 if now >= self.execute_at { Duration::ZERO } else { (self.execute_at - now).to_std().unwrap_or(Duration::ZERO) }
56 }
57}
58
59#[async_trait]
61pub trait DelayedTaskExecutor<T: Send + Sync + Clone + 'static>: Send + Sync {
62 async fn execute(&self, task: DelayedTask<T>) -> WaeResult<()>;
64}
65
/// Tuning parameters for a delayed queue.
#[derive(Clone, Debug)]
pub struct DelayedQueueConfig {
    /// Upper bound on queued submissions.
    pub max_queue_size: usize,
    /// Pause between two consecutive scans for due tasks.
    pub poll_interval: Duration,
    /// Intended cap on simultaneously running tasks.
    pub max_concurrent_executions: usize,
}

impl Default for DelayedQueueConfig {
    /// Returns the stock configuration: 10 000 queued tasks, a 100 ms poll
    /// interval and 10 concurrent executions.
    fn default() -> Self {
        const MAX_QUEUE_SIZE: usize = 10_000;
        const POLL_INTERVAL: Duration = Duration::from_millis(100);
        const MAX_CONCURRENT_EXECUTIONS: usize = 10;

        Self {
            max_queue_size: MAX_QUEUE_SIZE,
            poll_interval: POLL_INTERVAL,
            max_concurrent_executions: MAX_CONCURRENT_EXECUTIONS,
        }
    }
}
82
83pub struct DelayedQueue<T: Send + Sync + Clone + 'static> {
87 #[allow(dead_code)]
89 config: DelayedQueueConfig,
90 queue: Arc<RwLock<Vec<DelayedTask<T>>>>,
92 handles: Arc<RwLock<BTreeMap<TaskId, TaskHandle>>>,
94 task_tx: mpsc::Sender<DelayedTask<T>>,
96 shutdown_tx: broadcast::Sender<()>,
98 is_shutdown: Arc<AtomicBool>,
100}
101
102impl<T: Send + Sync + Clone + 'static> DelayedQueue<T> {
103 pub fn new<E: DelayedTaskExecutor<T> + 'static>(config: DelayedQueueConfig, executor: Arc<E>) -> Self {
105 let (task_tx, mut task_rx) = mpsc::channel(config.max_queue_size);
106 let (shutdown_tx, _) = broadcast::channel(1);
107 let queue: Arc<RwLock<Vec<DelayedTask<T>>>> = Arc::new(RwLock::new(Vec::new()));
108 let handles: Arc<RwLock<BTreeMap<TaskId, TaskHandle>>> = Arc::new(RwLock::new(BTreeMap::new()));
109 let is_shutdown = Arc::new(AtomicBool::new(false));
110
111 let queue_clone = queue.clone();
112 let handles_clone = handles.clone();
113 let is_shutdown_clone = is_shutdown.clone();
114 let mut shutdown_rx = shutdown_tx.subscribe();
115
116 tokio::spawn(async move {
117 loop {
118 tokio::select! {
119 _ = shutdown_rx.recv() => {
120 debug!("Delayed queue received shutdown signal");
121 break;
122 }
123 _ = tokio::time::sleep(config.poll_interval) => {
124 if is_shutdown_clone.load(Ordering::SeqCst) {
125 break;
126 }
127
128 let now = Utc::now();
129 let mut queue_guard = queue_clone.write().await;
130
131 let due_tasks: Vec<_> = queue_guard
132 .iter()
133 .filter(|t| t.execute_at <= now)
134 .cloned()
135 .collect();
136
137 queue_guard.retain(|t| t.execute_at > now);
138
139 drop(queue_guard);
140
141 for task in due_tasks {
142 let executor_clone = executor.clone();
143 let handles_ref = handles_clone.clone();
144
145 tokio::spawn(async move {
146 if let Some(handle) = handles_ref.read().await.get(&task.id) {
147 handle.set_state(TaskState::Running).await;
148 }
149
150 let result = executor_clone.execute(task.clone()).await;
151
152 if let Some(handle) = handles_ref.read().await.get(&task.id) {
153 match result {
154 Ok(()) => {
155 handle.record_execution().await;
156 handle.set_state(TaskState::Completed).await;
157 }
158 Err(e) => {
159 handle.record_error(e.to_string()).await;
160 handle.set_state(TaskState::Failed).await;
161 }
162 }
163 }
164 });
165 }
166 }
167 }
168 }
169 });
170
171 let queue_clone = queue.clone();
172 tokio::spawn(async move {
173 while let Some(task) = task_rx.recv().await {
174 let mut queue_guard = queue_clone.write().await;
175 queue_guard.push(task);
176 queue_guard.sort_by(|a, b| a.execute_at.cmp(&b.execute_at).then_with(|| a.priority.cmp(&b.priority)));
177 }
178 });
179
180 Self { config, queue, handles, task_tx, shutdown_tx, is_shutdown }
181 }
182
183 pub async fn schedule_delayed(&self, task: DelayedTask<T>) -> WaeResult<TaskHandle> {
193 if self.is_shutdown.load(Ordering::SeqCst) {
194 return Err(WaeError::scheduler_shutdown());
195 }
196
197 let handle = TaskHandle::new(task.id.clone(), task.name.clone());
198
199 {
200 let mut handles = self.handles.write().await;
201 handles.insert(task.id.clone(), handle.clone());
202 }
203
204 self.task_tx.send(task).await.map_err(|e| WaeError::internal(format!("Failed to send task: {}", e)))?;
205
206 info!("Scheduled delayed task: {}", handle.name);
207 Ok(handle)
208 }
209
210 pub async fn queue_size(&self) -> usize {
212 self.queue.read().await.len()
213 }
214
215 pub async fn cancel_task(&self, task_id: &str) -> WaeResult<bool> {
217 let mut queue = self.queue.write().await;
218 let initial_len = queue.len();
219 queue.retain(|t| t.id != task_id);
220
221 if queue.len() < initial_len {
222 if let Some(handle) = self.handles.read().await.get(task_id) {
223 handle.cancel();
224 handle.set_state(TaskState::Cancelled).await;
225 }
226 info!("Cancelled delayed task: {}", task_id);
227 Ok(true)
228 }
229 else {
230 Err(WaeError::task_not_found(task_id))
231 }
232 }
233
234 pub async fn get_handle(&self, task_id: &str) -> Option<TaskHandle> {
236 self.handles.read().await.get(task_id).cloned()
237 }
238
239 pub fn shutdown(&self) {
241 self.is_shutdown.store(true, Ordering::SeqCst);
242 let _ = self.shutdown_tx.send(());
243 info!("Delayed queue shutdown initiated");
244 }
245
246 pub fn is_shutdown(&self) -> bool {
248 self.is_shutdown.load(Ordering::SeqCst)
249 }
250}