use crate::{
    CachedJobResult, JobMetrics, TaskMetrics, TaskState, TaskStatus,
    types::{HealthStatus, MAX_QUEUE_SIZE, QueuedTask},
};
use chrono::{DateTime, Utc};
use error_stack::ResultExt;
use flume::{Receiver, Sender};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::RwLock;

const DEFAULT_TASK_TIMEOUT_SECS: u64 = 30 * 60;

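/// In-process task queue: a bounded flume channel plus shared state for
/// metrics, per-task status, cached results, and cooperative cancellation.
/// Cloning is cheap; all clones share the same underlying queue and maps.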
#[derive(Clone)]
pub struct AppTasks {
    sender: Sender<QueuedTask>,
    receiver: Receiver<QueuedTask>,
    metrics: Arc<TaskMetrics>,
    task_states: Arc<RwLock<HashMap<String, TaskState>>>,
    results_cache: Arc<RwLock<HashMap<String, CachedJobResult>>>,
    persistence_callback: Option<Arc<dyn Fn(&HashMap<String, TaskState>) + Send + Sync>>,
    is_shutting_down: Arc<std::sync::atomic::AtomicBool>,
    cancellation_tokens: Arc<RwLock<HashMap<String, tokio_util::sync::CancellationToken>>>,
    task_timeout_secs: Arc<std::sync::atomic::AtomicU64>,
}

impl AppTasks {
    pub fn new() -> Self {
        let (sender, receiver) = flume::bounded(MAX_QUEUE_SIZE);

        Self {
            sender,
            receiver,
            metrics: Arc::new(TaskMetrics::new()),
            task_states: Arc::new(RwLock::new(HashMap::new())),
            results_cache: Arc::new(RwLock::new(HashMap::new())),
            persistence_callback: None,
            is_shutting_down: Arc::new(std::sync::atomic::AtomicBool::new(false)),
            cancellation_tokens: Arc::new(RwLock::new(HashMap::new())),
            task_timeout_secs: Arc::new(std::sync::atomic::AtomicU64::new(
                DEFAULT_TASK_TIMEOUT_SECS,
            )),
        }
    }

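    /// Builder-style override for the per-task timeout. The value lives in an
    /// `AtomicU64`, so the override is visible to already-existing clones.
    ///
    /// A minimal sketch (ignored doctest):
    ///
    /// ```ignore
    /// let tasks = AppTasks::new().with_task_timeout(5 * 60); // 5 minutes
    /// assert_eq!(tasks.task_timeout(), std::time::Duration::from_secs(300));
    /// ```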
    pub fn with_task_timeout(self, timeout_secs: u64) -> Self {
        self.task_timeout_secs
            .store(timeout_secs, std::sync::atomic::Ordering::Relaxed);
        self
    }

    pub fn task_timeout(&self) -> Duration {
        Duration::from_secs(
            self.task_timeout_secs
                .load(std::sync::atomic::Ordering::Relaxed),
        )
    }

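    /// Registers a callback invoked with the full state map after every
    /// mutation (queueing, status updates, cleanup). The callback runs
    /// synchronously while the state lock is held, so it should be fast.
    ///
    /// A sketch that persists to a JSON file; `tasks.json` is an arbitrary
    /// path chosen for illustration (ignored doctest):
    ///
    /// ```ignore
    /// let tasks = AppTasks::new().with_auto_persist(|states| {
    ///     if let Ok(json) = serde_json::to_vec(states) {
    ///         let _ = std::fs::write("tasks.json", json);
    ///     }
    /// });
    /// ```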
    pub fn with_auto_persist<F>(mut self, callback: F) -> Self
    where
        F: Fn(&HashMap<String, TaskState>) + Send + Sync + 'static,
    {
        self.persistence_callback = Some(Arc::new(callback));
        self
    }

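    /// Queues a task for execution and returns its generated ID.
    ///
    /// Fails fast if the system is shutting down or the queue is full, and
    /// rolls back the recorded state if the bounded channel cannot accept the
    /// task within 100 ms.
    ///
    /// A usage sketch, assuming `MyTask` is a hypothetical type implementing
    /// `TaskHandler` and `Serialize` (ignored doctest):
    ///
    /// ```ignore
    /// let task_id = tasks.queue(MyTask { input: 42 }).await?;
    /// println!("queued as {task_id}");
    /// ```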
    pub async fn queue<T>(&self, task: T) -> Result<String, error_stack::Report<TaskQueueError>>
    where
        T: crate::TaskHandler + serde::Serialize + Send + Sync + 'static,
    {
        if self
            .is_shutting_down
            .load(std::sync::atomic::Ordering::Relaxed)
        {
            return Err(error_stack::report!(TaskQueueError)
                .attach_printable("System is shutting down")
                .attach_printable("No new tasks accepted during shutdown"));
        }

        let queue_depth = self.metrics.get_queue_depth();
        if queue_depth >= MAX_QUEUE_SIZE as u64 {
            return Err(error_stack::report!(TaskQueueError)
                .attach_printable("Queue is full")
                .attach_printable(format!(
                    "Current depth: {}/{}",
                    queue_depth, MAX_QUEUE_SIZE
                )));
        }

        let task_id = uuid::Uuid::new_v4().to_string();
        let task_name = std::any::type_name::<T>()
            .split("::")
            .last()
            .unwrap_or("Unknown")
            .to_string();

        let task_data = serde_json::to_vec(&task)
            .change_context(TaskQueueError)
            .attach_printable("Failed to serialize task")?;

        let task_state = TaskState {
            id: task_id.clone(),
            task_name: task_name.clone(),
            task_data: serde_json::to_value(&task)
                .change_context(TaskQueueError)
                .attach_printable("Failed to serialize task for state")?,
            status: TaskStatus::Queued,
            retry_count: 0,
            created_at: Utc::now(),
            started_at: None,
            completed_at: None,
            duration_ms: None,
            error_message: None,
            worker_id: None,
        };

        {
            let mut states = self.task_states.write().await;
            states.insert(task_id.clone(), task_state);

            if let Some(callback) = &self.persistence_callback {
                callback(&states);
            }
        }

        let queued_task = QueuedTask {
            id: task_id.clone(),
            task_name,
            task_data,
            retry_count: 0,
            created_at: std::time::Instant::now(),
        };

        match tokio::time::timeout(
            Duration::from_millis(100),
            self.sender.send_async(queued_task),
        )
        .await
        {
            Ok(Ok(_)) => {
                self.metrics.record_queued();
                Ok(task_id)
            }
            _ => {
                self.task_states.write().await.remove(&task_id);
                Err(error_stack::report!(TaskQueueError)
                    .attach_printable("Failed to send task to queue")
                    .attach_printable("Timeout or channel disconnected")
                    .attach_printable(format!("Task ID: {}", task_id)))
            }
        }
    }

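    /// Replaces the in-memory state map and requeues any task that was
    /// `Queued` or `InProgress` when the state was captured. Requeueing uses
    /// `try_send`, so tasks beyond the channel capacity are silently dropped
    /// from the queue (their state entries remain).
    ///
    /// A restore sketch pairing with the `with_auto_persist` example above;
    /// `tasks.json` is the same illustrative path (ignored doctest):
    ///
    /// ```ignore
    /// if let Ok(bytes) = std::fs::read("tasks.json") {
    ///     if let Ok(states) = serde_json::from_slice(&bytes) {
    ///         tasks.load_state(states).await;
    ///     }
    /// }
    /// ```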
    pub async fn load_state(&self, states: HashMap<String, TaskState>) {
        let mut task_states = self.task_states.write().await;
        task_states.clear();
        task_states.extend(states);

        for (task_id, task_state) in &*task_states {
            if matches!(
                task_state.status,
                TaskStatus::Queued | TaskStatus::InProgress
            ) {
                if let Ok(task_data) = serde_json::to_vec(&task_state.task_data) {
                    let queued_task = QueuedTask {
                        id: task_id.clone(),
                        task_name: task_state.task_name.clone(),
                        task_data,
                        retry_count: task_state.retry_count,
                        created_at: std::time::Instant::now(),
                    };

                    let _ = self.sender.try_send(queued_task);
                }
            }
        }

        tracing::info!(
            "Loaded {} task states, {} incomplete tasks requeued",
            task_states.len(),
            task_states.values().filter(|t| !t.is_terminal()).count()
        );
    }

    pub async fn get_state(&self) -> HashMap<String, TaskState> {
        self.task_states.read().await.clone()
    }

    pub async fn get_status(&self, job_id: &str) -> Option<TaskStatus> {
        let states = self.task_states.read().await;
        states.get(job_id).map(|state| state.status.clone())
    }

    pub async fn get_task(&self, task_id: &str) -> Option<TaskState> {
        self.task_states.read().await.get(task_id).cloned()
    }

    pub async fn get_result(&self, job_id: &str) -> Option<CachedJobResult> {
        let results = self.results_cache.read().await;
        results.get(job_id).cloned()
    }

    pub async fn get_job_metrics(&self, job_id: &str) -> Option<JobMetrics> {
        let states = self.task_states.read().await;
        states.get(job_id).map(JobMetrics::from)
    }

    pub async fn list_tasks(
        &self,
        status: Option<TaskStatus>,
        limit: Option<usize>,
    ) -> Vec<TaskState> {
        let states = self.task_states.read().await;
        let mut tasks: Vec<TaskState> = states
            .values()
            .filter(|task| status.as_ref().is_none_or(|s| &task.status == s))
            .cloned()
            .collect();

        tasks.sort_by(|a, b| b.created_at.cmp(&a.created_at));

        if let Some(limit) = limit {
            tasks.truncate(limit);
        }

        tasks
    }

    pub async fn get_tasks_by_status(&self, status: TaskStatus) -> Vec<TaskState> {
        self.task_states
            .read()
            .await
            .values()
            .filter(|task| task.status == status)
            .cloned()
            .collect()
    }

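    /// Caches a successful result for `job_id`. If `ttl` is given, a spawned
    /// background task evicts the entry after the TTL elapses; eviction is
    /// best-effort and does not survive a restart.
    ///
    /// ```ignore
    /// tasks
    ///     .store_success(
    ///         job_id,
    ///         serde_json::json!({ "rows": 10 }),
    ///         Some(std::time::Duration::from_secs(3600)), // keep for 1 hour
    ///     )
    ///     .await;
    /// ```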
    pub async fn store_success(
        &self,
        job_id: String,
        data: serde_json::Value,
        ttl: Option<Duration>,
    ) {
        let cached_result = CachedJobResult {
            job_id: job_id.clone(),
            completed_at: Utc::now(),
            success: true,
            data,
            error: None,
            ttl,
        };

        let mut results = self.results_cache.write().await;
        results.insert(job_id.clone(), cached_result);

        if let Some(ttl) = ttl {
            let cache = self.results_cache.clone();
            let id = job_id.clone();
            tokio::spawn(async move {
                tokio::time::sleep(ttl).await;
                let mut results = cache.write().await;
                results.remove(&id);
            });
        }
    }

    pub async fn store_failure(&self, job_id: String, error: String, ttl: Option<Duration>) {
        let cached_result = CachedJobResult {
            job_id: job_id.clone(),
            completed_at: Utc::now(),
            success: false,
            data: serde_json::json!({}),
            error: Some(error),
            ttl,
        };

        let mut results = self.results_cache.write().await;
        results.insert(job_id.clone(), cached_result);

        if let Some(ttl) = ttl {
            let cache = self.results_cache.clone();
            let id = job_id.clone();
            tokio::spawn(async move {
                tokio::time::sleep(ttl).await;
                let mut results = cache.write().await;
                results.remove(&id);
            });
        }
    }

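    /// Removes terminal tasks (`Completed`, `Failed`, `Cancelled`) whose
    /// `completed_at` is older than the cutoff; terminal tasks without a
    /// completion timestamp are kept. Returns the number of entries removed.
    ///
    /// A sketch dropping tasks older than one day (ignored doctest):
    ///
    /// ```ignore
    /// let cutoff = chrono::Utc::now() - chrono::Duration::days(1);
    /// let removed = tasks.cleanup_old_tasks(cutoff).await;
    /// ```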
    pub async fn cleanup_old_tasks(&self, older_than: DateTime<Utc>) -> usize {
        let mut states = self.task_states.write().await;
        let initial_count = states.len();

        states.retain(|_, task| match task.status {
            TaskStatus::Completed | TaskStatus::Failed | TaskStatus::Cancelled => task
                .completed_at
                .is_none_or(|completed| completed >= older_than),
            _ => true,
        });

        let removed = initial_count - states.len();

        if removed > 0 {
            if let Some(callback) = &self.persistence_callback {
                callback(&states);
            }
            tracing::info!("Cleaned up {} old tasks", removed);
        }

        removed
    }

    pub(crate) fn sender(&self) -> &Sender<QueuedTask> {
        &self.sender
    }

    pub(crate) fn receiver(&self) -> &Receiver<QueuedTask> {
        &self.receiver
    }

    pub fn get_task_metrics(&self) -> crate::metrics::MetricsSnapshot {
        self.metrics.snapshot()
    }

    pub fn queue_depth(&self) -> u64 {
        self.metrics.get_queue_depth()
    }

    pub fn is_healthy(&self) -> bool {
        let queue_depth = self.queue_depth();
        queue_depth < (MAX_QUEUE_SIZE as u64 / 2)
    }

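    /// Maps queue depth onto a coarse health scale: unhealthy when shutting
    /// down or at capacity, degraded at three quarters capacity or above,
    /// healthy otherwise.
    ///
    /// A sketch of wiring this into a health endpoint; the handler shape is
    /// illustrative (ignored doctest):
    ///
    /// ```ignore
    /// let status = tasks.health_status();
    /// let http_code = if tasks.is_healthy() { 200 } else { 503 };
    /// ```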
    pub fn health_status(&self) -> HealthStatus {
        let queue_depth = self.queue_depth();

        if self.is_shutting_down() || queue_depth >= MAX_QUEUE_SIZE as u64 {
            HealthStatus::unhealthy(queue_depth)
        } else if queue_depth >= (MAX_QUEUE_SIZE as u64 * 3 / 4) {
            HealthStatus::degraded(queue_depth)
        } else {
            HealthStatus::healthy(queue_depth)
        }
    }

    pub fn shutdown(&self) {
        self.is_shutting_down
            .store(true, std::sync::atomic::Ordering::Relaxed);
        tracing::info!("Task system shutdown initiated - no new tasks will be accepted");
    }

    pub fn is_shutting_down(&self) -> bool {
        self.is_shutting_down
            .load(std::sync::atomic::Ordering::Relaxed)
    }

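    /// Attempts to cancel a task. Queued tasks are marked `Cancelled`
    /// directly; in-progress tasks are signalled through their cancellation
    /// token, or force-cancelled if no token was registered. Returns `false`
    /// for unknown or already-terminal tasks.
    ///
    /// ```ignore
    /// if tasks.cancel_task(&task_id).await {
    ///     println!("cancellation requested for {task_id}");
    /// }
    /// ```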
    pub async fn cancel_task(&self, task_id: &str) -> bool {
        let states = self.task_states.read().await;
        let status = match states.get(task_id) {
            Some(task) => task.status.clone(),
            None => return false,
        };
        drop(states);

        match status {
            TaskStatus::Queued => {
                self.update_task_status(
                    task_id,
                    TaskStatus::Cancelled,
                    None,
                    None,
                    Some("Cancelled by user".to_string()),
                )
                .await;
                true
            }
            TaskStatus::InProgress => {
                let tokens = self.cancellation_tokens.read().await;
                if let Some(token) = tokens.get(task_id) {
                    token.cancel();
                    tracing::info!(task_id = %task_id, "Cancellation signalled for in-progress task");
                    true
                } else {
                    tracing::warn!(task_id = %task_id, "No cancellation token found, force-cancelling task");
                    self.update_task_status(
                        task_id,
                        TaskStatus::Cancelled,
                        None,
                        None,
                        Some("Force-cancelled (no token)".to_string()),
                    )
                    .await;
                    true
                }
            }
            _ => false,
        }
    }

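    /// Registers a fresh cancellation token for a task. Intended for worker
    /// code, which should race the handler against the token and remove the
    /// token afterwards. A worker-side sketch, where `run_handler` is a
    /// hypothetical future executing the task (ignored doctest):
    ///
    /// ```ignore
    /// let token = tasks.create_cancellation_token(&task_id).await;
    /// tokio::select! {
    ///     _ = token.cancelled() => { /* mark the task Cancelled */ }
    ///     result = run_handler(&task) => { /* record the result */ }
    /// }
    /// tasks.remove_cancellation_token(&task_id).await;
    /// ```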
    pub(crate) async fn create_cancellation_token(
        &self,
        task_id: &str,
    ) -> tokio_util::sync::CancellationToken {
        let token = tokio_util::sync::CancellationToken::new();
        let mut tokens = self.cancellation_tokens.write().await;
        tokens.insert(task_id.to_string(), token.clone());
        token
    }

    pub(crate) async fn remove_cancellation_token(&self, task_id: &str) {
        let mut tokens = self.cancellation_tokens.write().await;
        tokens.remove(task_id);
    }

    pub(crate) fn metrics_ref(&self) -> &Arc<TaskMetrics> {
        &self.metrics
    }

    pub(crate) async fn update_task_status(
        &self,
        task_id: &str,
        status: TaskStatus,
        worker_id: Option<usize>,
        duration_ms: Option<u64>,
        error_message: Option<String>,
    ) {
        let mut states = self.task_states.write().await;
        if let Some(task) = states.get_mut(task_id) {
            let old_status = task.status.clone();

            task.status = status.clone();
            task.worker_id = worker_id;
            task.error_message = error_message;

            if let Some(duration) = duration_ms {
                task.duration_ms = Some(duration);
                self.metrics.record_processing_time(duration);
            }

            match status {
                TaskStatus::InProgress => {
                    task.started_at = Some(Utc::now());
                }
                TaskStatus::Completed | TaskStatus::Failed | TaskStatus::Cancelled => {
                    task.completed_at = Some(Utc::now());
                }
                TaskStatus::Retrying => {
                    task.retry_count += 1;
                    task.started_at = None;
                    self.metrics.record_retried();
                }
                _ => {}
            }

            match (&old_status, &status) {
                (TaskStatus::Queued, TaskStatus::InProgress) => {
                    tracing::debug!(task_id = %task_id, worker_id = ?worker_id, "Task started");
                }
                (TaskStatus::InProgress, TaskStatus::Completed) => {
                    tracing::info!(
                        task_id = %task_id,
                        duration_ms = ?duration_ms,
                        "Task completed successfully"
                    );
                }
                (TaskStatus::InProgress, TaskStatus::Failed) => {
                    tracing::warn!(
                        task_id = %task_id,
                        error = ?task.error_message,
                        retry_count = task.retry_count,
                        "Task failed"
                    );
                }
                _ => {}
            }

            if let Some(callback) = &self.persistence_callback {
                callback(&states);
            }
        }
    }
}

impl Default for AppTasks {
    fn default() -> Self {
        Self::new()
    }
}

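// Serialization snapshots only `task_states`; channels, metrics, and tokens
// are rebuilt on deserialization. Note that `block_in_place` panics on a
// current-thread Tokio runtime, so serializing requires the multi-thread
// runtime flavor.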
impl Serialize for AppTasks {
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        #[derive(Serialize)]
        struct AppTasksSnapshot {
            task_states: HashMap<String, TaskState>,
        }

        let states = tokio::task::block_in_place(|| {
            tokio::runtime::Handle::current()
                .block_on(async { self.task_states.read().await.clone() })
        });

        let snapshot = AppTasksSnapshot {
            task_states: states,
        };
        snapshot.serialize(serializer)
    }
}

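// Deserialization must also run inside a Tokio runtime: restoring the
// snapshot spawns a background task that requeues incomplete work via
// `load_state`.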
impl<'de> Deserialize<'de> for AppTasks {
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: serde::Deserializer<'de>,
    {
        #[derive(Deserialize)]
        struct AppTasksSnapshot {
            task_states: HashMap<String, TaskState>,
        }

        let snapshot = AppTasksSnapshot::deserialize(deserializer)?;

        let app_tasks = AppTasks::new();

        let states = snapshot.task_states;
        let app_tasks_clone = app_tasks.clone();
        tokio::spawn(async move {
            app_tasks_clone.load_state(states).await;
        });

        Ok(app_tasks)
    }
}

#[derive(Debug)]
pub struct TaskQueueError;

impl std::fmt::Display for TaskQueueError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "Task queue operation failed")
    }
}

impl error_stack::Context for TaskQueueError {}