axum_tasks/
app_tasks.rs

use crate::{
    CachedJobResult, JobMetrics, TaskMetrics, TaskState, TaskStatus,
    types::{HealthStatus, MAX_QUEUE_SIZE, QueuedTask},
};
use chrono::{DateTime, Utc};
use error_stack::ResultExt;
use flume::{Receiver, Sender};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::RwLock;

#[derive(Clone)]
pub struct AppTasks {
    // Internal queue
    sender: Sender<QueuedTask>,
    receiver: Receiver<QueuedTask>,
    metrics: Arc<TaskMetrics>,
    task_states: Arc<RwLock<HashMap<String, TaskState>>>,
    results_cache: Arc<RwLock<HashMap<String, CachedJobResult>>>,
    persistence_callback: Option<Arc<dyn Fn(&HashMap<String, TaskState>) + Send + Sync>>,
    is_shutting_down: Arc<std::sync::atomic::AtomicBool>,
}

impl AppTasks {
    pub fn new() -> Self {
        let (sender, receiver) = flume::bounded(MAX_QUEUE_SIZE);

        Self {
            sender,
            receiver,
            metrics: Arc::new(TaskMetrics::new()),
            task_states: Arc::new(RwLock::new(HashMap::new())),
            results_cache: Arc::new(RwLock::new(HashMap::new())),
            persistence_callback: None,
            is_shutting_down: Arc::new(std::sync::atomic::AtomicBool::new(false)),
        }
    }

    /// Enable automatic persistence when task states change.
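    ///
    /// A minimal usage sketch; the callback runs synchronously while the
    /// state lock is held, so the snapshot is handed off to a background
    /// task. The JSON-to-file persistence and the `task_states.json` path
    /// are illustrative, not something this crate provides:
    ///
    /// ```ignore
    /// let tasks = AppTasks::new().with_auto_persist(|states| {
    ///     // Keep the callback cheap: clone the snapshot and persist it
    ///     // off the hot path.
    ///     let snapshot = states.clone();
    ///     tokio::spawn(async move {
    ///         if let Ok(json) = serde_json::to_string(&snapshot) {
    ///             let _ = tokio::fs::write("task_states.json", json).await;
    ///         }
    ///     });
    /// });
    /// ```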
    pub fn with_auto_persist<F>(mut self, callback: F) -> Self
    where
        F: Fn(&HashMap<String, TaskState>) + Send + Sync + 'static,
    {
        self.persistence_callback = Some(Arc::new(callback));
        self
    }

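    /// Queue a task for background processing.
    ///
    /// The task's state is recorded (and persisted, if an auto-persist
    /// callback is set) before the task is handed to the worker queue, so
    /// incomplete work can be requeued later by `load_state`. An error is
    /// returned if the system is shutting down or the queue is full.
    ///
    /// A minimal usage sketch, assuming some `MyTask` type that implements
    /// `TaskHandler` and `Serialize` (illustrative only):
    ///
    /// ```ignore
    /// let task_id = app_tasks.queue(MyTask { user_id: 42 }).await?;
    /// tracing::info!(%task_id, "task accepted");
    /// ```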
    pub async fn queue<T>(&self, task: T) -> Result<String, error_stack::Report<TaskQueueError>>
    where
        T: crate::TaskHandler + serde::Serialize + Send + Sync + 'static,
    {
        // Check if system is shutting down
        if self
            .is_shutting_down
            .load(std::sync::atomic::Ordering::Relaxed)
        {
            return Err(error_stack::report!(TaskQueueError)
                .attach_printable("System is shutting down")
                .attach_printable("No new tasks accepted during shutdown"));
        }

        // Check capacity
        let queue_depth = self.metrics.get_queue_depth();
        if queue_depth >= MAX_QUEUE_SIZE as u64 {
            return Err(error_stack::report!(TaskQueueError)
                .attach_printable("Queue is full")
                .attach_printable(format!("Current depth: {}/{}", queue_depth, MAX_QUEUE_SIZE)));
        }

        let task_id = uuid::Uuid::new_v4().to_string();
        let task_name = std::any::type_name::<T>()
            .split("::")
            .last()
            .unwrap_or("Unknown")
            .to_string();

        let task_data = serde_json::to_vec(&task)
            .change_context(TaskQueueError)
            .attach_printable("Failed to serialize task")?;

        // 1. FIRST: Persist state (crash-proof)
        let task_state = TaskState {
            id: task_id.clone(),
            task_name: task_name.clone(),
            task_data: serde_json::to_value(&task)
                .change_context(TaskQueueError)
                .attach_printable("Failed to serialize task for state")?,
            status: TaskStatus::Queued,
            retry_count: 0,
            created_at: Utc::now(),
            started_at: None,
            completed_at: None,
            duration_ms: None,
            error_message: None,
            worker_id: None,
        };

        // Add to persistent state
        {
            let mut states = self.task_states.write().await;
            states.insert(task_id.clone(), task_state);

            // Auto-persist if callback set
            if let Some(callback) = &self.persistence_callback {
                callback(&states);
            }
        }

        // 2. THEN: Queue for processing
        let queued_task = QueuedTask {
            id: task_id.clone(),
            task_name,
            task_data,
            retry_count: 0,
            created_at: std::time::Instant::now(),
        };

        match tokio::time::timeout(
            Duration::from_millis(100),
            self.sender.send_async(queued_task),
        )
        .await
        {
            Ok(Ok(_)) => {
                self.metrics.record_queued();
                Ok(task_id)
            }
            _ => {
                // Remove from state if queueing failed
                self.task_states.write().await.remove(&task_id);
                Err(error_stack::report!(TaskQueueError)
                    .attach_printable("Failed to send task to queue")
                    .attach_printable("Timeout or channel disconnected")
                    .attach_printable(format!("Task ID: {}", task_id)))
            }
        }
    }

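    /// Restore previously persisted task states and requeue incomplete work.
    ///
    /// A minimal startup sketch; reading the snapshot from a JSON file with
    /// `serde_json` is illustrative, not something this crate mandates:
    ///
    /// ```ignore
    /// let app_tasks = AppTasks::new();
    /// if let Ok(bytes) = tokio::fs::read("task_states.json").await {
    ///     if let Ok(states) = serde_json::from_slice(&bytes) {
    ///         app_tasks.load_state(states).await;
    ///     }
    /// }
    /// ```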
    pub async fn load_state(&self, states: HashMap<String, TaskState>) {
        let mut task_states = self.task_states.write().await;
        task_states.clear();
        task_states.extend(states);

        // Re-queue incomplete tasks for recovery
        let mut requeued = 0usize;
        for (task_id, task_state) in &*task_states {
            if matches!(
                task_state.status,
                TaskStatus::Queued | TaskStatus::InProgress
            ) {
                if let Ok(task_data) = serde_json::to_vec(&task_state.task_data) {
                    let queued_task = QueuedTask {
                        id: task_id.clone(),
                        task_name: task_state.task_name.clone(),
                        task_data,
                        retry_count: task_state.retry_count,
                        created_at: std::time::Instant::now(),
                    };

                    // Best-effort recovery - don't block startup if the queue is full
                    if self.sender.try_send(queued_task).is_ok() {
                        requeued += 1;
                    }
                }
            }
        }

        tracing::info!(
            "Loaded {} task states, {} incomplete tasks requeued",
            task_states.len(),
            requeued
        );
    }

    pub async fn get_state(&self) -> HashMap<String, TaskState> {
        self.task_states.read().await.clone()
    }

    pub async fn get_status(&self, job_id: &str) -> Option<TaskStatus> {
        let states = self.task_states.read().await;
        states.get(job_id).map(|state| state.status.clone())
    }

    pub async fn get_task(&self, task_id: &str) -> Option<TaskState> {
        self.task_states.read().await.get(task_id).cloned()
    }

    pub async fn get_result(&self, job_id: &str) -> Option<CachedJobResult> {
        let results = self.results_cache.read().await;
        results.get(job_id).cloned()
    }

    pub async fn get_job_metrics(&self, job_id: &str) -> Option<JobMetrics> {
        let states = self.task_states.read().await;
        states.get(job_id).map(JobMetrics::from)
    }

    pub async fn list_tasks(
        &self,
        status: Option<TaskStatus>,
        limit: Option<usize>,
    ) -> Vec<TaskState> {
        let states = self.task_states.read().await;
        let mut tasks: Vec<TaskState> = states
            .values()
            .filter(|task| status.as_ref().is_none_or(|s| &task.status == s))
            .cloned()
            .collect();

        tasks.sort_by(|a, b| b.created_at.cmp(&a.created_at));

        if let Some(limit) = limit {
            tasks.truncate(limit);
        }

        tasks
    }

    pub async fn get_tasks_by_status(&self, status: TaskStatus) -> Vec<TaskState> {
        self.task_states
            .read()
            .await
            .values()
            .filter(|task| task.status == status)
            .cloned()
            .collect()
    }

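    /// Cache a successful job result, optionally evicting it after a TTL.
    ///
    /// A minimal sketch of a worker recording a result and a handler reading
    /// it back; the payload shape and one-hour TTL are illustrative:
    ///
    /// ```ignore
    /// app_tasks
    ///     .store_success(
    ///         task_id.clone(),
    ///         serde_json::json!({ "rows_processed": 1234 }),
    ///         Some(Duration::from_secs(3600)),
    ///     )
    ///     .await;
    ///
    /// if let Some(result) = app_tasks.get_result(&task_id).await {
    ///     assert!(result.success);
    /// }
    /// ```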
    pub async fn store_success(
        &self,
        job_id: String,
        data: serde_json::Value,
        ttl: Option<Duration>,
    ) {
        let cached_result = CachedJobResult {
            job_id: job_id.clone(),
            completed_at: Utc::now(),
            success: true,
            data,
            error: None,
            ttl,
        };

        let mut results = self.results_cache.write().await;
        results.insert(job_id.clone(), cached_result);

        // Schedule best-effort eviction once the TTL elapses
        if let Some(ttl) = ttl {
            let cache = self.results_cache.clone();
            let id = job_id.clone();
            tokio::spawn(async move {
                tokio::time::sleep(ttl).await;
                let mut results = cache.write().await;
                results.remove(&id);
            });
        }
    }

    pub async fn store_failure(&self, job_id: String, error: String, ttl: Option<Duration>) {
        let cached_result = CachedJobResult {
            job_id: job_id.clone(),
            completed_at: Utc::now(),
            success: false,
            data: serde_json::json!({}),
            error: Some(error),
            ttl,
        };

        let mut results = self.results_cache.write().await;
        results.insert(job_id.clone(), cached_result);

        // Schedule best-effort eviction once the TTL elapses
        if let Some(ttl) = ttl {
            let cache = self.results_cache.clone();
            let id = job_id.clone();
            tokio::spawn(async move {
                tokio::time::sleep(ttl).await;
                let mut results = cache.write().await;
                results.remove(&id);
            });
        }
    }

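    /// Remove terminal (completed or failed) tasks older than the cutoff.
    ///
    /// A minimal sketch of a periodic cleanup loop; the one-hour interval and
    /// seven-day retention window are illustrative choices:
    ///
    /// ```ignore
    /// let tasks = app_tasks.clone();
    /// tokio::spawn(async move {
    ///     loop {
    ///         tokio::time::sleep(Duration::from_secs(3600)).await;
    ///         let cutoff = Utc::now() - chrono::Duration::days(7);
    ///         tasks.cleanup_old_tasks(cutoff).await;
    ///     }
    /// });
    /// ```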
    pub async fn cleanup_old_tasks(&self, older_than: DateTime<Utc>) -> usize {
        let mut states = self.task_states.write().await;
        let initial_count = states.len();

        states.retain(|_, task| {
            match task.status {
                TaskStatus::Completed | TaskStatus::Failed => task
                    .completed_at
                    .is_none_or(|completed| completed >= older_than),
                _ => true, // Keep in-progress tasks
            }
        });

        let removed = initial_count - states.len();

        // Auto-persist if tasks were removed
        if removed > 0 {
            if let Some(callback) = &self.persistence_callback {
                callback(&states);
            }
            tracing::info!("Cleaned up {} old tasks", removed);
        }

        removed
    }

    pub(crate) fn sender(&self) -> &Sender<QueuedTask> {
        &self.sender
    }

    pub(crate) fn receiver(&self) -> &Receiver<QueuedTask> {
        &self.receiver
    }

    pub fn get_task_metrics(&self) -> crate::metrics::MetricsSnapshot {
        self.metrics.snapshot()
    }

    pub fn queue_depth(&self) -> u64 {
        self.metrics.get_queue_depth()
    }

    pub fn is_healthy(&self) -> bool {
        let queue_depth = self.queue_depth();
        queue_depth < (MAX_QUEUE_SIZE as u64 / 2)
    }

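    /// Classify the current queue depth as healthy, degraded, or unhealthy.
    ///
    /// A minimal sketch of an axum-style health handler built on these
    /// checks; the route wiring and the bare status code are illustrative,
    /// and `health_status()` can supply a richer response body if the
    /// application chooses to expose it:
    ///
    /// ```ignore
    /// async fn health(State(tasks): State<AppTasks>) -> StatusCode {
    ///     if tasks.is_healthy() {
    ///         StatusCode::OK
    ///     } else {
    ///         StatusCode::SERVICE_UNAVAILABLE
    ///     }
    /// }
    /// ```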
    pub fn health_status(&self) -> HealthStatus {
        let queue_depth = self.queue_depth();

        if self.is_shutting_down() || queue_depth >= MAX_QUEUE_SIZE as u64 {
            HealthStatus::unhealthy(queue_depth)
        } else if queue_depth >= (MAX_QUEUE_SIZE as u64 * 3 / 4) {
            HealthStatus::degraded(queue_depth)
        } else {
            HealthStatus::healthy(queue_depth)
        }
    }

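    /// Stop accepting new tasks so in-flight work can drain.
    ///
    /// A minimal graceful-shutdown sketch; the 30-second deadline is
    /// illustrative and assumes workers keep draining the queue after
    /// `shutdown()` is called:
    ///
    /// ```ignore
    /// app_tasks.shutdown();
    /// let deadline = tokio::time::Instant::now() + Duration::from_secs(30);
    /// while app_tasks.queue_depth() > 0 && tokio::time::Instant::now() < deadline {
    ///     tokio::time::sleep(Duration::from_millis(100)).await;
    /// }
    /// ```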
    pub fn shutdown(&self) {
        self.is_shutting_down
            .store(true, std::sync::atomic::Ordering::Relaxed);
        tracing::info!("Task system shutdown initiated - no new tasks will be accepted");
    }

    pub fn is_shutting_down(&self) -> bool {
        self.is_shutting_down
            .load(std::sync::atomic::Ordering::Relaxed)
    }

    //=========================================================================
    // INTERNAL METHODS (used by worker system)
    //=========================================================================

    pub(crate) fn metrics_ref(&self) -> &Arc<TaskMetrics> {
        &self.metrics
    }

    pub(crate) async fn update_task_status(
        &self,
        task_id: &str,
        status: TaskStatus,
        worker_id: Option<usize>,
        duration_ms: Option<u64>,
        error_message: Option<String>,
    ) {
        let mut states = self.task_states.write().await;
        if let Some(task) = states.get_mut(task_id) {
            let old_status = task.status.clone();

            task.status = status.clone();
            task.worker_id = worker_id;
            task.error_message = error_message;

            if let Some(duration) = duration_ms {
                task.duration_ms = Some(duration);
                self.metrics.record_processing_time(duration);
            }

            match status {
                TaskStatus::InProgress => {
                    task.started_at = Some(Utc::now());
                }
                TaskStatus::Completed | TaskStatus::Failed => {
                    task.completed_at = Some(Utc::now());
                }
                TaskStatus::Retrying => {
                    task.retry_count += 1;
                    task.started_at = None; // Reset for retry
                    self.metrics.record_retried();
                }
                _ => {}
            }

            // Log status changes
            match (&old_status, &status) {
                (TaskStatus::Queued, TaskStatus::InProgress) => {
                    tracing::debug!(task_id = %task_id, worker_id = ?worker_id, "Task started");
                }
                (TaskStatus::InProgress, TaskStatus::Completed) => {
                    tracing::info!(
                        task_id = %task_id,
                        duration_ms = ?duration_ms,
                        "Task completed successfully"
                    );
                }
                (TaskStatus::InProgress, TaskStatus::Failed) => {
                    tracing::warn!(
                        task_id = %task_id,
                        error = ?task.error_message,
                        retry_count = task.retry_count,
                        "Task failed"
                    );
                }
                _ => {}
            }

            // Auto-persist if callback set
            if let Some(callback) = &self.persistence_callback {
                callback(&states);
            }
        }
    }
}

impl Default for AppTasks {
    fn default() -> Self {
        Self::new()
    }
}

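// The Serialize/Deserialize impls below snapshot only `task_states`; the
// channels, metrics, and results cache are rebuilt fresh on deserialization,
// which also requeues incomplete tasks via `load_state`. A minimal round-trip
// sketch (must run inside a multi-threaded Tokio runtime; the file path is
// illustrative):
//
//     let json = serde_json::to_string(&app_tasks)?;
//     tokio::fs::write("app_tasks.json", &json).await?;
//
//     let restored: AppTasks = serde_json::from_str(&json)?;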
impl Serialize for AppTasks {
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        // Only serialize the task states - channels can't be serialized
        #[derive(Serialize)]
        struct AppTasksSnapshot {
            task_states: HashMap<String, TaskState>,
        }

        // block_in_place + block_on requires a multi-threaded Tokio runtime;
        // calling this from a current-thread runtime will panic.
        let states = tokio::task::block_in_place(|| {
            tokio::runtime::Handle::current()
                .block_on(async { self.task_states.read().await.clone() })
        });

        let snapshot = AppTasksSnapshot {
            task_states: states,
        };
        snapshot.serialize(serializer)
    }
}

impl<'de> Deserialize<'de> for AppTasks {
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: serde::Deserializer<'de>,
    {
        #[derive(Deserialize)]
        struct AppTasksSnapshot {
            task_states: HashMap<String, TaskState>,
        }

        let snapshot = AppTasksSnapshot::deserialize(deserializer)?;

        let app_tasks = AppTasks::new();

        // Reload the snapshot in the background; tokio::spawn requires an
        // active Tokio runtime, so deserialization must happen inside one.
        let states = snapshot.task_states;
        let app_tasks_clone = app_tasks.clone();
        tokio::spawn(async move {
            app_tasks_clone.load_state(states).await;
        });

        Ok(app_tasks)
    }
}

#[derive(Debug)]
pub struct TaskQueueError;

impl std::fmt::Display for TaskQueueError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "Task queue operation failed")
    }
}

impl error_stack::Context for TaskQueueError {}