axum_tasks/app_tasks.rs

use crate::{
    CachedJobResult, JobMetrics, TaskMetrics, TaskState, TaskStatus,
    types::{HealthStatus, MAX_QUEUE_SIZE, QueuedTask},
};
use chrono::{DateTime, Utc};
use error_stack::ResultExt;
use flume::{Receiver, Sender};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::RwLock;

/// Default task timeout: 30 minutes.
const DEFAULT_TASK_TIMEOUT_SECS: u64 = 30 * 60;

#[derive(Clone)]
pub struct AppTasks {
    // Internal queue
    sender: Sender<QueuedTask>,
    receiver: Receiver<QueuedTask>,
    metrics: Arc<TaskMetrics>,
    task_states: Arc<RwLock<HashMap<String, TaskState>>>,
    results_cache: Arc<RwLock<HashMap<String, CachedJobResult>>>,
    persistence_callback: Option<Arc<dyn Fn(&HashMap<String, TaskState>) + Send + Sync>>,
    is_shutting_down: Arc<std::sync::atomic::AtomicBool>,
    /// Per-task cancellation tokens. Workers check these during execution.
    cancellation_tokens: Arc<RwLock<HashMap<String, tokio_util::sync::CancellationToken>>>,
    /// Task timeout in seconds. Tasks running longer than this are auto-cancelled.
    task_timeout_secs: Arc<std::sync::atomic::AtomicU64>,
}

impl AppTasks {
    pub fn new() -> Self {
        let (sender, receiver) = flume::bounded(MAX_QUEUE_SIZE);

        Self {
            sender,
            receiver,
            metrics: Arc::new(TaskMetrics::new()),
            task_states: Arc::new(RwLock::new(HashMap::new())),
            results_cache: Arc::new(RwLock::new(HashMap::new())),
            persistence_callback: None,
            is_shutting_down: Arc::new(std::sync::atomic::AtomicBool::new(false)),
            cancellation_tokens: Arc::new(RwLock::new(HashMap::new())),
            task_timeout_secs: Arc::new(std::sync::atomic::AtomicU64::new(DEFAULT_TASK_TIMEOUT_SECS)),
        }
    }

    /// Set the task timeout in seconds. Tasks running longer than this are auto-failed.
    /// Default: 30 minutes (1800 seconds).
    pub fn with_task_timeout(self, timeout_secs: u64) -> Self {
        self.task_timeout_secs.store(timeout_secs, std::sync::atomic::Ordering::Relaxed);
        self
    }

    /// Get the current task timeout.
    pub fn task_timeout(&self) -> Duration {
        Duration::from_secs(self.task_timeout_secs.load(std::sync::atomic::Ordering::Relaxed))
    }

    /// Enable automatic persistence when task states change.
    pub fn with_auto_persist<F>(mut self, callback: F) -> Self
    where
        F: Fn(&HashMap<String, TaskState>) + Send + Sync + 'static,
    {
        self.persistence_callback = Some(Arc::new(callback));
        self
    }

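    /// Queue a task for background processing. The task state is persisted
    /// before the task is handed to the worker queue, so a crash between the
    /// two steps leaves a recoverable record rather than a lost task.
    ///
    /// Illustrative sketch only (`SendEmail` is a hypothetical `TaskHandler`
    /// implementation, not part of this crate):
    ///
    /// ```ignore
    /// let task_id = app_tasks.queue(SendEmail { to: "user@example.com".into() }).await?;
    /// println!("queued as {task_id}");
    /// ```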
    pub async fn queue<T>(&self, task: T) -> Result<String, error_stack::Report<TaskQueueError>>
    where
        T: crate::TaskHandler + serde::Serialize + Send + Sync + 'static,
    {
        // Check if system is shutting down
        if self
            .is_shutting_down
            .load(std::sync::atomic::Ordering::Relaxed)
        {
            return Err(error_stack::report!(TaskQueueError)
                .attach_printable("System is shutting down")
                .attach_printable("No new tasks accepted during shutdown"));
        }

        // Check capacity
        let queue_depth = self.metrics.get_queue_depth();
        if queue_depth >= MAX_QUEUE_SIZE as u64 {
            return Err(error_stack::report!(TaskQueueError)
                .attach_printable("Queue is full")
                .attach_printable(format!("Current depth: {}/{}", queue_depth, MAX_QUEUE_SIZE)));
        }

        let task_id = uuid::Uuid::new_v4().to_string();
        let task_name = std::any::type_name::<T>()
            .split("::")
            .last()
            .unwrap_or("Unknown")
            .to_string();

        let task_data = serde_json::to_vec(&task)
            .change_context(TaskQueueError)
            .attach_printable("Failed to serialize task")?;

        // 1. FIRST: Persist state (crash-proof)
        let task_state = TaskState {
            id: task_id.clone(),
            task_name: task_name.clone(),
            task_data: serde_json::to_value(&task)
                .change_context(TaskQueueError)
                .attach_printable("Failed to serialize task for state")?,
            status: TaskStatus::Queued,
            retry_count: 0,
            created_at: Utc::now(),
            started_at: None,
            completed_at: None,
            duration_ms: None,
            error_message: None,
            worker_id: None,
        };

        // Add to persistent state
        {
            let mut states = self.task_states.write().await;
            states.insert(task_id.clone(), task_state);

            // Auto-persist if callback set
            if let Some(callback) = &self.persistence_callback {
                callback(&states);
            }
        }

        // 2. THEN: Queue for processing
        let queued_task = QueuedTask {
            id: task_id.clone(),
            task_name,
            task_data,
            retry_count: 0,
            created_at: std::time::Instant::now(),
        };

        match tokio::time::timeout(
            Duration::from_millis(100),
            self.sender.send_async(queued_task),
        )
        .await
        {
            Ok(Ok(_)) => {
                self.metrics.record_queued();
                Ok(task_id)
            }
            _ => {
                // Remove from state if queueing failed
                self.task_states.write().await.remove(&task_id);
                Err(error_stack::report!(TaskQueueError)
                    .attach_printable("Failed to send task to queue")
                    .attach_printable("Timeout or channel disconnected")
                    .attach_printable(format!("Task ID: {}", task_id)))
            }
        }
    }

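    /// Restore previously persisted task states and re-queue any task that is
    /// not yet terminal (`Queued` or `InProgress`). Intended to be called once
    /// at startup, before workers begin pulling from the queue.
    ///
    /// Illustrative sketch only (`read_snapshot_from_disk` is a hypothetical
    /// helper returning `HashMap<String, TaskState>`):
    ///
    /// ```ignore
    /// let states = read_snapshot_from_disk("tasks.json")?;
    /// app_tasks.load_state(states).await;
    /// ```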
    pub async fn load_state(&self, states: HashMap<String, TaskState>) {
        let mut task_states = self.task_states.write().await;
        task_states.clear();
        task_states.extend(states);

        // Re-queue incomplete tasks for recovery
        for (task_id, task_state) in &*task_states {
            if matches!(
                task_state.status,
                TaskStatus::Queued | TaskStatus::InProgress
            ) {
                if let Ok(task_data) = serde_json::to_vec(&task_state.task_data) {
                    let queued_task = QueuedTask {
                        id: task_id.clone(),
                        task_name: task_state.task_name.clone(),
                        task_data,
                        retry_count: task_state.retry_count,
                        created_at: std::time::Instant::now(),
                    };

                    // Best effort recovery - don't block startup if queue is full
                    let _ = self.sender.try_send(queued_task);
                }
            }
        }

        tracing::info!(
            "Loaded {} task states, {} incomplete tasks requeued",
            task_states.len(),
            task_states.values().filter(|t| !t.is_terminal()).count()
        );
    }

    pub async fn get_state(&self) -> HashMap<String, TaskState> {
        self.task_states.read().await.clone()
    }

    pub async fn get_status(&self, job_id: &str) -> Option<TaskStatus> {
        let states = self.task_states.read().await;
        states.get(job_id).map(|state| state.status.clone())
    }

    pub async fn get_task(&self, task_id: &str) -> Option<TaskState> {
        self.task_states.read().await.get(task_id).cloned()
    }

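    /// Look up the cached result of a completed job, if one was stored and has
    /// not yet expired.
    ///
    /// Illustrative polling sketch only (`task_id` comes from a prior `queue` call):
    ///
    /// ```ignore
    /// if let Some(result) = app_tasks.get_result(&task_id).await {
    ///     if result.success {
    ///         println!("output: {}", result.data);
    ///     }
    /// }
    /// ```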
    pub async fn get_result(&self, job_id: &str) -> Option<CachedJobResult> {
        let results = self.results_cache.read().await;
        results.get(job_id).cloned()
    }

    pub async fn get_job_metrics(&self, job_id: &str) -> Option<JobMetrics> {
        let states = self.task_states.read().await;
        states.get(job_id).map(JobMetrics::from)
    }

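    /// List tasks, optionally filtered by status and capped at `limit` entries,
    /// newest first.
    ///
    /// Illustrative sketch only:
    ///
    /// ```ignore
    /// // The ten most recent failures, e.g. for an admin view.
    /// let recent_failures = app_tasks.list_tasks(Some(TaskStatus::Failed), Some(10)).await;
    /// ```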
    pub async fn list_tasks(
        &self,
        status: Option<TaskStatus>,
        limit: Option<usize>,
    ) -> Vec<TaskState> {
        let states = self.task_states.read().await;
        let mut tasks: Vec<TaskState> = states
            .values()
            .filter(|task| status.as_ref().is_none_or(|s| &task.status == s))
            .cloned()
            .collect();

        tasks.sort_by(|a, b| b.created_at.cmp(&a.created_at));

        if let Some(limit) = limit {
            tasks.truncate(limit);
        }

        tasks
    }

    pub async fn get_tasks_by_status(&self, status: TaskStatus) -> Vec<TaskState> {
        self.task_states
            .read()
            .await
            .values()
            .filter(|task| task.status == status)
            .cloned()
            .collect()
    }

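    /// Cache a successful job result so callers can fetch it via `get_result`.
    /// If a TTL is given, the entry is evicted by a background task once it
    /// expires.
    ///
    /// Illustrative sketch only (typically called from worker code after a
    /// handler finishes):
    ///
    /// ```ignore
    /// app_tasks
    ///     .store_success(task_id, serde_json::json!({ "rows": 42 }), Some(Duration::from_secs(300)))
    ///     .await;
    /// ```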
    pub async fn store_success(
        &self,
        job_id: String,
        data: serde_json::Value,
        ttl: Option<Duration>,
    ) {
        let cached_result = CachedJobResult {
            job_id: job_id.clone(),
            completed_at: Utc::now(),
            success: true,
            data,
            error: None,
            ttl,
        };

        let mut results = self.results_cache.write().await;
        results.insert(job_id.clone(), cached_result);

        // Schedule eviction once the TTL elapses
        if let Some(ttl) = ttl {
            let cache = self.results_cache.clone();
            let id = job_id.clone();
            tokio::spawn(async move {
                tokio::time::sleep(ttl).await;
                let mut results = cache.write().await;
                results.remove(&id);
            });
        }
    }

    pub async fn store_failure(&self, job_id: String, error: String, ttl: Option<Duration>) {
        let cached_result = CachedJobResult {
            job_id: job_id.clone(),
            completed_at: Utc::now(),
            success: false,
            data: serde_json::json!({}),
            error: Some(error),
            ttl,
        };

        let mut results = self.results_cache.write().await;
        results.insert(job_id.clone(), cached_result);

        // Schedule eviction once the TTL elapses
        if let Some(ttl) = ttl {
            let cache = self.results_cache.clone();
            let id = job_id.clone();
            tokio::spawn(async move {
                tokio::time::sleep(ttl).await;
                let mut results = cache.write().await;
                results.remove(&id);
            });
        }
    }

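    /// Remove terminal tasks (completed, failed, or cancelled) whose
    /// `completed_at` is older than the given cutoff. Queued and in-progress
    /// tasks are always kept. Returns the number of entries removed.
    ///
    /// Illustrative sketch only:
    ///
    /// ```ignore
    /// // Drop finished tasks older than 24 hours.
    /// let removed = app_tasks
    ///     .cleanup_old_tasks(Utc::now() - chrono::Duration::hours(24))
    ///     .await;
    /// ```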
    pub async fn cleanup_old_tasks(&self, older_than: DateTime<Utc>) -> usize {
        let mut states = self.task_states.write().await;
        let initial_count = states.len();

        states.retain(|_, task| {
            match task.status {
                TaskStatus::Completed | TaskStatus::Failed | TaskStatus::Cancelled => task
                    .completed_at
                    .is_none_or(|completed| completed >= older_than),
                _ => true, // Keep in-progress tasks
            }
        });

        let removed = initial_count - states.len();

        // Auto-persist if tasks were removed
        if removed > 0 {
            if let Some(callback) = &self.persistence_callback {
                callback(&states);
            }
            tracing::info!("Cleaned up {} old tasks", removed);
        }

        removed
    }

    pub(crate) fn sender(&self) -> &Sender<QueuedTask> {
        &self.sender
    }

    pub(crate) fn receiver(&self) -> &Receiver<QueuedTask> {
        &self.receiver
    }

    pub fn get_task_metrics(&self) -> crate::metrics::MetricsSnapshot {
        self.metrics.snapshot()
    }

    pub fn queue_depth(&self) -> u64 {
        self.metrics.get_queue_depth()
    }

    pub fn is_healthy(&self) -> bool {
        let queue_depth = self.queue_depth();
        queue_depth < (MAX_QUEUE_SIZE as u64 / 2)
    }

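    /// Report coarse health from queue depth: unhealthy when shutting down or
    /// at capacity, degraded at 75% of capacity or more, healthy otherwise.
    ///
    /// Illustrative sketch only, assuming a hypothetical axum handler and that
    /// `HealthStatus` is serializable:
    ///
    /// ```ignore
    /// async fn health(State(tasks): State<AppTasks>) -> impl IntoResponse {
    ///     Json(tasks.health_status())
    /// }
    /// ```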
    pub fn health_status(&self) -> HealthStatus {
        let queue_depth = self.queue_depth();

        if self.is_shutting_down() || queue_depth >= MAX_QUEUE_SIZE as u64 {
            HealthStatus::unhealthy(queue_depth)
        } else if queue_depth >= (MAX_QUEUE_SIZE as u64 * 3 / 4) {
            HealthStatus::degraded(queue_depth)
        } else {
            HealthStatus::healthy(queue_depth)
        }
    }

    pub fn shutdown(&self) {
        self.is_shutting_down
            .store(true, std::sync::atomic::Ordering::Relaxed);
        tracing::info!("Task system shutdown initiated - no new tasks will be accepted");
    }

    pub fn is_shutting_down(&self) -> bool {
        self.is_shutting_down
            .load(std::sync::atomic::Ordering::Relaxed)
    }

    /// Cancel a task by ID. If the task is in-progress, signals the worker
    /// to abort it. If queued, marks it as cancelled immediately.
    /// Returns true if the task was found and cancellation was initiated.
    pub async fn cancel_task(&self, task_id: &str) -> bool {
        let states = self.task_states.read().await;
        let status = match states.get(task_id) {
            Some(task) => task.status.clone(),
            None => return false,
        };
        drop(states);

        match status {
            TaskStatus::Queued => {
                // Mark as cancelled directly
                self.update_task_status(
                    task_id,
                    TaskStatus::Cancelled,
                    None,
                    None,
                    Some("Cancelled by user".to_string()),
                ).await;
                true
            }
            TaskStatus::InProgress => {
                // Signal the cancellation token
                let tokens = self.cancellation_tokens.read().await;
                if let Some(token) = tokens.get(task_id) {
                    token.cancel();
                    tracing::info!(task_id = %task_id, "Cancellation signalled for in-progress task");
                    true
                } else {
                    // Token not found - force-fail the task
                    tracing::warn!(task_id = %task_id, "No cancellation token found, force-failing task");
                    self.update_task_status(
                        task_id,
                        TaskStatus::Cancelled,
                        None,
                        None,
                        Some("Force-cancelled (no token)".to_string()),
                    ).await;
                    true
                }
            }
            // Already terminal
            _ => false,
        }
    }

    /// Create a cancellation token for a task. Called by worker before execution.
    pub(crate) async fn create_cancellation_token(&self, task_id: &str) -> tokio_util::sync::CancellationToken {
        let token = tokio_util::sync::CancellationToken::new();
        let mut tokens = self.cancellation_tokens.write().await;
        tokens.insert(task_id.to_string(), token.clone());
        token
    }

    /// Remove a cancellation token after task completes. Called by worker.
    pub(crate) async fn remove_cancellation_token(&self, task_id: &str) {
        let mut tokens = self.cancellation_tokens.write().await;
        tokens.remove(task_id);
    }

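    // Worker-side usage sketch (illustrative only; the real worker loop lives
    // elsewhere in the crate, and `handler.run(task)` stands in for the actual
    // task execution): take a token before running the handler, race the
    // handler against it, and always remove the token afterwards.
    //
    //     let token = tasks.create_cancellation_token(&task.id).await;
    //     tokio::select! {
    //         _ = token.cancelled() => { /* mark the task cancelled */ }
    //         result = handler.run(task) => { /* record the result */ }
    //     }
    //     tasks.remove_cancellation_token(&task.id).await;
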
    //=========================================================================
    // INTERNAL METHODS (used by worker system)
    //=========================================================================

    pub(crate) fn metrics_ref(&self) -> &Arc<TaskMetrics> {
        &self.metrics
    }

    pub(crate) async fn update_task_status(
        &self,
        task_id: &str,
        status: TaskStatus,
        worker_id: Option<usize>,
        duration_ms: Option<u64>,
        error_message: Option<String>,
    ) {
        let mut states = self.task_states.write().await;
        if let Some(task) = states.get_mut(task_id) {
            let old_status = task.status.clone();

            task.status = status.clone();
            task.worker_id = worker_id;
            task.error_message = error_message;

            if let Some(duration) = duration_ms {
                task.duration_ms = Some(duration);
                self.metrics.record_processing_time(duration);
            }

            match status {
                TaskStatus::InProgress => {
                    task.started_at = Some(Utc::now());
                }
                TaskStatus::Completed | TaskStatus::Failed | TaskStatus::Cancelled => {
                    task.completed_at = Some(Utc::now());
                }
                TaskStatus::Retrying => {
                    task.retry_count += 1;
                    task.started_at = None; // Reset for retry
                    self.metrics.record_retried();
                }
                _ => {}
            }

            // Log status changes
            match (&old_status, &status) {
                (TaskStatus::Queued, TaskStatus::InProgress) => {
                    tracing::debug!(task_id = %task_id, worker_id = ?worker_id, "Task started");
                }
                (TaskStatus::InProgress, TaskStatus::Completed) => {
                    tracing::info!(
                        task_id = %task_id,
                        duration_ms = ?duration_ms,
                        "Task completed successfully"
                    );
                }
                (TaskStatus::InProgress, TaskStatus::Failed) => {
                    tracing::warn!(
                        task_id = %task_id,
                        error = ?task.error_message,
                        retry_count = task.retry_count,
                        "Task failed"
                    );
                }
                _ => {}
            }

            // Auto-persist if callback set
            if let Some(callback) = &self.persistence_callback {
                callback(&states);
            }
        }
    }
}

impl Default for AppTasks {
    fn default() -> Self {
        Self::new()
    }
}

impl Serialize for AppTasks {
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        // Only serialize the task states - channels can't be serialized
        #[derive(Serialize)]
        struct AppTasksSnapshot {
            task_states: HashMap<String, TaskState>,
        }

        // Note: `block_in_place` panics on the current-thread runtime flavor,
        // so serialization must run on a multi-threaded Tokio runtime.
        let states = tokio::task::block_in_place(|| {
            tokio::runtime::Handle::current()
                .block_on(async { self.task_states.read().await.clone() })
        });

        let snapshot = AppTasksSnapshot {
            task_states: states,
        };
        snapshot.serialize(serializer)
    }
}

impl<'de> Deserialize<'de> for AppTasks {
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: serde::Deserializer<'de>,
    {
        #[derive(Deserialize)]
        struct AppTasksSnapshot {
            task_states: HashMap<String, TaskState>,
        }

        let snapshot = AppTasksSnapshot::deserialize(deserializer)?;

        let app_tasks = AppTasks::new();

        // Restoring state re-queues incomplete tasks, which requires `.await`;
        // spawn it so deserialization itself stays synchronous.
        let states = snapshot.task_states;
        let app_tasks_clone = app_tasks.clone();
        tokio::spawn(async move {
            app_tasks_clone.load_state(states).await;
        });

        Ok(app_tasks)
    }
}

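/// Error context for queueing operations. Returned inside an
/// `error_stack::Report`, with human-readable details attached as printables.
///
/// Illustrative handling sketch only:
///
/// ```ignore
/// if let Err(report) = app_tasks.queue(task).await {
///     tracing::error!("failed to queue task: {report:?}");
/// }
/// ```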
#[derive(Debug)]
pub struct TaskQueueError;

impl std::fmt::Display for TaskQueueError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "Task queue operation failed")
    }
}

impl error_stack::Context for TaskQueueError {}