use crate::{
    CachedJobResult, JobMetrics, TaskMetrics, TaskState, TaskStatus,
    types::{HealthStatus, MAX_QUEUE_SIZE, QueuedTask},
};
use chrono::{DateTime, Utc};
use error_stack::ResultExt;
use flume::{Receiver, Sender};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::RwLock;

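/// In-memory task system: a bounded flume channel for queued work plus
/// per-task state, a results cache, metrics, and an optional persistence
/// callback invoked whenever task state changes.
///
/// Minimal usage sketch (illustrative only; `MyTask` is a hypothetical type
/// assumed to implement `TaskHandler` and `Serialize`):
///
/// ```rust,ignore
/// let tasks = AppTasks::new();
/// let task_id = tasks.queue(MyTask { payload: 42 }).await?;
/// let status = tasks.get_status(&task_id).await; // likely Some(TaskStatus::Queued)
/// ```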
#[derive(Clone)]
pub struct AppTasks {
    sender: Sender<QueuedTask>,
    receiver: Receiver<QueuedTask>,
    metrics: Arc<TaskMetrics>,
    task_states: Arc<RwLock<HashMap<String, TaskState>>>,
    results_cache: Arc<RwLock<HashMap<String, CachedJobResult>>>,
    persistence_callback: Option<Arc<dyn Fn(&HashMap<String, TaskState>) + Send + Sync>>,
    is_shutting_down: Arc<std::sync::atomic::AtomicBool>,
}

impl AppTasks {
    pub fn new() -> Self {
        let (sender, receiver) = flume::bounded(MAX_QUEUE_SIZE);

        Self {
            sender,
            receiver,
            metrics: Arc::new(TaskMetrics::new()),
            task_states: Arc::new(RwLock::new(HashMap::new())),
            results_cache: Arc::new(RwLock::new(HashMap::new())),
            persistence_callback: None,
            is_shutting_down: Arc::new(std::sync::atomic::AtomicBool::new(false)),
        }
    }

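    /// Registers a callback that receives a snapshot of all task states
    /// whenever they change, so callers can persist them (e.g. to disk).
    ///
    /// A hedged sketch of wiring this up (the path and serialization format
    /// are illustrative, not part of this crate):
    ///
    /// ```rust,ignore
    /// let tasks = AppTasks::new().with_auto_persist(|states| {
    ///     if let Ok(json) = serde_json::to_string(states) {
    ///         let _ = std::fs::write("/tmp/task_states.json", json);
    ///     }
    /// });
    /// ```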
    pub fn with_auto_persist<F>(mut self, callback: F) -> Self
    where
        F: Fn(&HashMap<String, TaskState>) + Send + Sync + 'static,
    {
        self.persistence_callback = Some(Arc::new(callback));
        self
    }

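    /// Serializes the task, records a `Queued` state, and pushes it onto the
    /// bounded channel. Returns the generated task ID on success.
    ///
    /// Fails if the system is shutting down, the queue is at capacity, the
    /// task cannot be serialized, or the channel send does not complete
    /// within 100ms.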
    pub async fn queue<T>(&self, task: T) -> Result<String, error_stack::Report<TaskQueueError>>
    where
        T: crate::TaskHandler + serde::Serialize + Send + Sync + 'static,
    {
        if self
            .is_shutting_down
            .load(std::sync::atomic::Ordering::Relaxed)
        {
            return Err(error_stack::report!(TaskQueueError)
                .attach_printable("System is shutting down")
                .attach_printable("No new tasks accepted during shutdown"));
        }

        let queue_depth = self.metrics.get_queue_depth();
        if queue_depth >= MAX_QUEUE_SIZE as u64 {
            return Err(error_stack::report!(TaskQueueError)
                .attach_printable("Queue is full")
                .attach_printable(format!("Current depth: {}/{}", queue_depth, MAX_QUEUE_SIZE)));
        }

        let task_id = uuid::Uuid::new_v4().to_string();
        let task_name = std::any::type_name::<T>()
            .split("::")
            .last()
            .unwrap_or("Unknown")
            .to_string();

        let task_data = serde_json::to_vec(&task)
            .change_context(TaskQueueError)
            .attach_printable("Failed to serialize task")?;

        // Keep a JSON copy of the task in the state map so it can be
        // persisted and requeued after a restart (see `load_state`).
        let task_state = TaskState {
            id: task_id.clone(),
            task_name: task_name.clone(),
            task_data: serde_json::to_value(&task)
                .change_context(TaskQueueError)
                .attach_printable("Failed to serialize task for state")?,
            status: TaskStatus::Queued,
            retry_count: 0,
            created_at: Utc::now(),
            started_at: None,
            completed_at: None,
            duration_ms: None,
            error_message: None,
            worker_id: None,
        };

        {
            let mut states = self.task_states.write().await;
            states.insert(task_id.clone(), task_state);

            if let Some(callback) = &self.persistence_callback {
                callback(&states);
            }
        }

        let queued_task = QueuedTask {
            id: task_id.clone(),
            task_name,
            task_data,
            retry_count: 0,
            created_at: std::time::Instant::now(),
        };

        match tokio::time::timeout(
            Duration::from_millis(100),
            self.sender.send_async(queued_task),
        )
        .await
        {
            Ok(Ok(_)) => {
                self.metrics.record_queued();
                Ok(task_id)
            }
            _ => {
                // Roll back the state entry and re-persist so a restart does
                // not requeue a task that never made it onto the channel.
                let mut states = self.task_states.write().await;
                states.remove(&task_id);
                if let Some(callback) = &self.persistence_callback {
                    callback(&states);
                }
                Err(error_stack::report!(TaskQueueError)
                    .attach_printable("Failed to send task to queue")
                    .attach_printable("Timeout or channel disconnected")
                    .attach_printable(format!("Task ID: {}", task_id)))
            }
        }
    }

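    /// Replaces the in-memory task state map with a previously persisted
    /// snapshot and requeues any task that was still `Queued` or
    /// `InProgress` when the snapshot was taken.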
    pub async fn load_state(&self, states: HashMap<String, TaskState>) {
        let mut task_states = self.task_states.write().await;
        task_states.clear();
        task_states.extend(states);

        for (task_id, task_state) in &*task_states {
            if matches!(
                task_state.status,
                TaskStatus::Queued | TaskStatus::InProgress
            ) {
                if let Ok(task_data) = serde_json::to_vec(&task_state.task_data) {
                    let queued_task = QueuedTask {
                        id: task_id.clone(),
                        task_name: task_state.task_name.clone(),
                        task_data,
                        retry_count: task_state.retry_count,
                        created_at: std::time::Instant::now(),
                    };

                    // Best-effort requeue: if the bounded channel is already
                    // full, the task is dropped from the queue but its state
                    // entry is kept.
                    let _ = self.sender.try_send(queued_task);
                }
            }
        }

        tracing::info!(
            "Loaded {} task states, {} incomplete tasks requeued",
            task_states.len(),
            task_states.values().filter(|t| !t.is_terminal()).count()
        );
    }

    pub async fn get_state(&self) -> HashMap<String, TaskState> {
        self.task_states.read().await.clone()
    }

    pub async fn get_status(&self, job_id: &str) -> Option<TaskStatus> {
        let states = self.task_states.read().await;
        states.get(job_id).map(|state| state.status.clone())
    }

    pub async fn get_task(&self, task_id: &str) -> Option<TaskState> {
        self.task_states.read().await.get(task_id).cloned()
    }

    pub async fn get_result(&self, job_id: &str) -> Option<CachedJobResult> {
        let results = self.results_cache.read().await;
        results.get(job_id).cloned()
    }

    pub async fn get_job_metrics(&self, job_id: &str) -> Option<JobMetrics> {
        let states = self.task_states.read().await;
        states.get(job_id).map(JobMetrics::from)
    }

    pub async fn list_tasks(
        &self,
        status: Option<TaskStatus>,
        limit: Option<usize>,
    ) -> Vec<TaskState> {
        let states = self.task_states.read().await;
        let mut tasks: Vec<TaskState> = states
            .values()
            .filter(|task| status.as_ref().is_none_or(|s| &task.status == s))
            .cloned()
            .collect();

        tasks.sort_by(|a, b| b.created_at.cmp(&a.created_at));

        if let Some(limit) = limit {
            tasks.truncate(limit);
        }

        tasks
    }

    pub async fn get_tasks_by_status(&self, status: TaskStatus) -> Vec<TaskState> {
        self.task_states
            .read()
            .await
            .values()
            .filter(|task| task.status == status)
            .cloned()
            .collect()
    }

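    /// Caches a successful job result. If a TTL is provided, a background
    /// task evicts the entry from the results cache after it expires.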
    pub async fn store_success(
        &self,
        job_id: String,
        data: serde_json::Value,
        ttl: Option<Duration>,
    ) {
        let cached_result = CachedJobResult {
            job_id: job_id.clone(),
            completed_at: Utc::now(),
            success: true,
            data,
            error: None,
            ttl,
        };

        let mut results = self.results_cache.write().await;
        results.insert(job_id.clone(), cached_result);

        if let Some(ttl) = ttl {
            let cache = self.results_cache.clone();
            let id = job_id.clone();
            tokio::spawn(async move {
                tokio::time::sleep(ttl).await;
                let mut results = cache.write().await;
                results.remove(&id);
            });
        }
    }

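    /// Caches a failed job result with its error message; TTL expiry works
    /// the same way as in `store_success`.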
    pub async fn store_failure(&self, job_id: String, error: String, ttl: Option<Duration>) {
        let cached_result = CachedJobResult {
            job_id: job_id.clone(),
            completed_at: Utc::now(),
            success: false,
            data: serde_json::json!({}),
            error: Some(error),
            ttl,
        };

        let mut results = self.results_cache.write().await;
        results.insert(job_id.clone(), cached_result);

        if let Some(ttl) = ttl {
            let cache = self.results_cache.clone();
            let id = job_id.clone();
            tokio::spawn(async move {
                tokio::time::sleep(ttl).await;
                let mut results = cache.write().await;
                results.remove(&id);
            });
        }
    }

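    /// Removes completed and failed tasks whose `completed_at` is older than
    /// the given cutoff, persists the remaining states, and returns the
    /// number of tasks removed.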
    pub async fn cleanup_old_tasks(&self, older_than: DateTime<Utc>) -> usize {
        let mut states = self.task_states.write().await;
        let initial_count = states.len();

        states.retain(|_, task| match task.status {
            TaskStatus::Completed | TaskStatus::Failed => task
                .completed_at
                .is_none_or(|completed| completed >= older_than),
            _ => true,
        });

        let removed = initial_count - states.len();

        if removed > 0 {
            if let Some(callback) = &self.persistence_callback {
                callback(&states);
            }
            tracing::info!("Cleaned up {} old tasks", removed);
        }

        removed
    }

    pub(crate) fn sender(&self) -> &Sender<QueuedTask> {
        &self.sender
    }

    pub(crate) fn receiver(&self) -> &Receiver<QueuedTask> {
        &self.receiver
    }

    pub fn get_task_metrics(&self) -> crate::metrics::MetricsSnapshot {
        self.metrics.snapshot()
    }

    pub fn queue_depth(&self) -> u64 {
        self.metrics.get_queue_depth()
    }

    pub fn is_healthy(&self) -> bool {
        let queue_depth = self.queue_depth();
        queue_depth < (MAX_QUEUE_SIZE as u64 / 2)
    }

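    /// Reports queue health: unhealthy when shutting down or at capacity,
    /// degraded at 75% of capacity or more, healthy otherwise.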
    pub fn health_status(&self) -> HealthStatus {
        let queue_depth = self.queue_depth();

        if self.is_shutting_down() || queue_depth >= MAX_QUEUE_SIZE as u64 {
            HealthStatus::unhealthy(queue_depth)
        } else if queue_depth >= (MAX_QUEUE_SIZE as u64 * 3 / 4) {
            HealthStatus::degraded(queue_depth)
        } else {
            HealthStatus::healthy(queue_depth)
        }
    }

    pub fn shutdown(&self) {
        self.is_shutting_down
            .store(true, std::sync::atomic::Ordering::Relaxed);
        tracing::info!("Task system shutdown initiated - no new tasks will be accepted");
    }

    pub fn is_shutting_down(&self) -> bool {
        self.is_shutting_down
            .load(std::sync::atomic::Ordering::Relaxed)
    }

    pub(crate) fn metrics_ref(&self) -> &Arc<TaskMetrics> {
        &self.metrics
    }

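    /// Applies a status transition to a tracked task: updates timestamps,
    /// retry count, duration and worker metadata, records metrics, logs the
    /// transition, and invokes the persistence callback.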
    pub(crate) async fn update_task_status(
        &self,
        task_id: &str,
        status: TaskStatus,
        worker_id: Option<usize>,
        duration_ms: Option<u64>,
        error_message: Option<String>,
    ) {
        let mut states = self.task_states.write().await;
        if let Some(task) = states.get_mut(task_id) {
            let old_status = task.status.clone();

            task.status = status.clone();
            task.worker_id = worker_id;
            task.error_message = error_message;

            if let Some(duration) = duration_ms {
                task.duration_ms = Some(duration);
                self.metrics.record_processing_time(duration);
            }

            match status {
                TaskStatus::InProgress => {
                    task.started_at = Some(Utc::now());
                }
                TaskStatus::Completed | TaskStatus::Failed => {
                    task.completed_at = Some(Utc::now());
                }
                TaskStatus::Retrying => {
                    task.retry_count += 1;
                    // Clear the start time so the next attempt records a fresh one.
                    task.started_at = None;
                    self.metrics.record_retried();
                }
                _ => {}
            }

            match (&old_status, &status) {
                (TaskStatus::Queued, TaskStatus::InProgress) => {
                    tracing::debug!(task_id = %task_id, worker_id = ?worker_id, "Task started");
                }
                (TaskStatus::InProgress, TaskStatus::Completed) => {
                    tracing::info!(
                        task_id = %task_id,
                        duration_ms = ?duration_ms,
                        "Task completed successfully"
                    );
                }
                (TaskStatus::InProgress, TaskStatus::Failed) => {
                    tracing::warn!(
                        task_id = %task_id,
                        error = ?task.error_message,
                        retry_count = task.retry_count,
                        "Task failed"
                    );
                }
                _ => {}
            }

            if let Some(callback) = &self.persistence_callback {
                callback(&states);
            }
        }
    }
}

impl Default for AppTasks {
    fn default() -> Self {
        Self::new()
    }
}

impl Serialize for AppTasks {
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        #[derive(Serialize)]
        struct AppTasksSnapshot {
            task_states: HashMap<String, TaskState>,
        }

        // Serde is synchronous, so the async read lock is bridged with
        // `block_in_place`; this requires serializing from within a
        // multi-threaded Tokio runtime.
        let states = tokio::task::block_in_place(|| {
            tokio::runtime::Handle::current()
                .block_on(async { self.task_states.read().await.clone() })
        });

        let snapshot = AppTasksSnapshot {
            task_states: states,
        };
        snapshot.serialize(serializer)
    }
}

impl<'de> Deserialize<'de> for AppTasks {
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: serde::Deserializer<'de>,
    {
        #[derive(Deserialize)]
        struct AppTasksSnapshot {
            task_states: HashMap<String, TaskState>,
        }

        let snapshot = AppTasksSnapshot::deserialize(deserializer)?;

        let app_tasks = AppTasks::new();

        // Loading the snapshot needs the async lock, so it runs on a spawned
        // task; this requires deserializing inside a Tokio runtime, and the
        // returned instance may not see the loaded state immediately.
        let states = snapshot.task_states;
        let app_tasks_clone = app_tasks.clone();
        tokio::spawn(async move {
            app_tasks_clone.load_state(states).await;
        });

        Ok(app_tasks)
    }
}

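/// Error context for task queue operations, used with `error_stack::Report`.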
#[derive(Debug)]
pub struct TaskQueueError;

impl std::fmt::Display for TaskQueueError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "Task queue operation failed")
    }
}

impl error_stack::Context for TaskQueueError {}