1use std::collections::HashMap;
35use std::sync::atomic::{AtomicU64, Ordering};
36use std::sync::{Arc, RwLock};
37
38use asupersync::runtime::{RuntimeBuilder, RuntimeHandle};
39use asupersync::{Budget, CancelKind, Cx};
40use fastmcp_core::logging::{debug, info, targets, warn};
41use fastmcp_core::{McpError, McpResult};
42use fastmcp_protocol::{
43 JsonRpcRequest, TaskId, TaskInfo, TaskResult, TaskStatus, TaskStatusNotificationParams,
44};
45
46pub type TaskNotificationSender = Arc<dyn Fn(JsonRpcRequest) + Send + Sync>;
48
49pub type TaskHandler = Box<dyn Fn(&Cx, serde_json::Value) -> TaskFuture + Send + Sync + 'static>;
53
54pub type TaskFuture = std::pin::Pin<
56 Box<dyn std::future::Future<Output = McpResult<serde_json::Value>> + Send + 'static>,
57>;
58
59struct TaskState {
61 info: TaskInfo,
63 cancel_requested: bool,
65 result: Option<TaskResult>,
67 cx: Cx,
69}
70
71fn can_transition(from: TaskStatus, to: TaskStatus) -> bool {
72 matches!(
73 (from, to),
74 (
75 TaskStatus::Pending,
76 TaskStatus::Running | TaskStatus::Cancelled
77 ) | (
78 TaskStatus::Running,
79 TaskStatus::Completed | TaskStatus::Failed | TaskStatus::Cancelled
80 )
81 )
82}
83
84fn transition_state(state: &mut TaskState, to: TaskStatus) -> bool {
85 let from = state.info.status;
86 if from == to {
87 return true;
88 }
89 if !can_transition(from, to) {
90 warn!(
91 target: targets::SERVER,
92 "task {} invalid transition {:?} -> {:?}",
93 state.info.id,
94 from,
95 to
96 );
97 return false;
98 }
99
100 state.info.status = to;
101 let now = chrono::Utc::now().to_rfc3339();
102 match to {
103 TaskStatus::Running => {
104 state.info.started_at = Some(now.clone());
105 }
106 TaskStatus::Completed | TaskStatus::Failed | TaskStatus::Cancelled => {
107 state.info.completed_at = Some(now.clone());
108 }
109 TaskStatus::Pending => {}
110 }
111
112 info!(
113 target: targets::SERVER,
114 "task {} status {:?} -> {:?} at {}",
115 state.info.id,
116 from,
117 to,
118 now
119 );
120 true
121}
122
123pub struct TaskManager {
128 tasks: Arc<RwLock<HashMap<TaskId, TaskState>>>,
130 handlers: Arc<RwLock<HashMap<String, TaskHandler>>>,
132 task_counter: AtomicU64,
134 list_changed_notifications: bool,
136 runtime: RuntimeHandle,
138 auto_execute: bool,
140 notification_sender: Arc<RwLock<Option<TaskNotificationSender>>>,
142}
143
144impl TaskManager {
145 #[must_use]
147 pub fn new() -> Self {
148 let runtime = RuntimeBuilder::multi_thread()
149 .build()
150 .expect("failed to build background task runtime")
151 .handle();
152 Self {
153 tasks: Arc::new(RwLock::new(HashMap::new())),
154 handlers: Arc::new(RwLock::new(HashMap::new())),
155 task_counter: AtomicU64::new(0),
156 list_changed_notifications: false,
157 runtime,
158 auto_execute: true,
159 notification_sender: Arc::new(RwLock::new(None)),
160 }
161 }
162
163 #[must_use]
165 pub fn with_list_changed_notifications() -> Self {
166 Self {
167 list_changed_notifications: true,
168 ..Self::new()
169 }
170 }
171
172 #[must_use]
176 pub fn new_for_testing() -> Self {
177 let mut manager = Self::new();
178 manager.auto_execute = false;
179 manager
180 }
181
182 #[must_use]
184 pub fn into_shared(self) -> SharedTaskManager {
185 Arc::new(self)
186 }
187
188 #[must_use]
190 pub fn has_list_changed_notifications(&self) -> bool {
191 self.list_changed_notifications
192 }
193
194 pub fn set_notification_sender(&self, sender: TaskNotificationSender) {
196 let mut guard = self.notification_sender.write().unwrap_or_else(|poisoned| {
197 warn!(target: targets::SERVER, "notification sender lock poisoned, recovering");
198 poisoned.into_inner()
199 });
200 *guard = Some(sender);
201 }
202
203 pub fn register_handler<F, Fut>(&self, task_type: impl Into<String>, handler: F)
207 where
208 F: Fn(&Cx, serde_json::Value) -> Fut + Send + Sync + 'static,
209 Fut: std::future::Future<Output = McpResult<serde_json::Value>> + Send + 'static,
210 {
211 let task_type = task_type.into();
212 let boxed_handler: TaskHandler = Box::new(move |cx, params| Box::pin(handler(cx, params)));
213
214 let mut handlers = self.handlers.write().unwrap_or_else(|poisoned| {
215 warn!(target: targets::SERVER, "handlers lock poisoned, recovering");
216 poisoned.into_inner()
217 });
218 handlers.insert(task_type, boxed_handler);
219 }
220
221 pub fn submit(
226 &self,
227 _cx: &Cx,
228 task_type: impl Into<String>,
229 params: Option<serde_json::Value>,
230 ) -> McpResult<TaskId> {
231 let task_type = task_type.into();
232
233 {
235 let handlers = self.handlers.read().unwrap_or_else(|poisoned| {
236 warn!(target: targets::SERVER, "handlers lock poisoned, recovering");
237 poisoned.into_inner()
238 });
239 if !handlers.contains_key(&task_type) {
240 return Err(McpError::invalid_params(format!(
241 "Unknown task type: {task_type}"
242 )));
243 }
244 }
245
246 let counter = self.task_counter.fetch_add(1, Ordering::SeqCst);
248 let task_id = TaskId::from_string(format!("task-{counter:08x}"));
249
250 let now = chrono::Utc::now().to_rfc3339();
252 let task_cx = Cx::for_request_with_budget(Budget::INFINITE);
253 let info = TaskInfo {
254 id: task_id.clone(),
255 task_type: task_type.clone(),
256 status: TaskStatus::Pending,
257 progress: None,
258 message: None,
259 created_at: now,
260 started_at: None,
261 completed_at: None,
262 error: None,
263 };
264
265 let info_snapshot = info.clone();
266
267 let state = TaskState {
269 info,
270 cancel_requested: false,
271 result: None,
272 cx: task_cx.clone(),
273 };
274
275 {
276 let mut tasks = self.tasks.write().unwrap_or_else(|poisoned| {
277 warn!(target: targets::SERVER, "tasks lock poisoned, recovering");
278 poisoned.into_inner()
279 });
280 tasks.insert(task_id.clone(), state);
281 }
282
283 self.notify_status(info_snapshot, None);
284
285 if self.auto_execute {
286 let params = params.unwrap_or_else(|| serde_json::json!({}));
287 self.spawn_task(task_id.clone(), task_type, task_cx, params);
288 }
289
290 Ok(task_id)
291 }
292
293 #[allow(clippy::too_many_lines)]
294 fn spawn_task(
295 &self,
296 task_id: TaskId,
297 task_type: String,
298 task_cx: Cx,
299 params: serde_json::Value,
300 ) {
301 let tasks = Arc::clone(&self.tasks);
302 let handlers = Arc::clone(&self.handlers);
303 let notification_sender = Arc::clone(&self.notification_sender);
304
305 self.runtime.spawn(async move {
306 let running_snapshot = {
307 let mut tasks_guard = tasks.write().unwrap_or_else(|poisoned| {
308 warn!(target: targets::SERVER, "tasks lock poisoned in spawn_task, recovering");
309 poisoned.into_inner()
310 });
311 match tasks_guard.get_mut(&task_id) {
312 Some(state) => {
313 if state.cancel_requested || !transition_state(state, TaskStatus::Running) {
314 None
315 } else {
316 Some(TaskStatusSnapshot::from(state))
317 }
318 }
319 None => None,
320 }
321 };
322
323 notify_snapshot(¬ification_sender, running_snapshot);
324
325 let task_future = {
326 let handlers_guard = handlers.read().unwrap_or_else(|poisoned| {
327 warn!(target: targets::SERVER, "handlers lock poisoned in spawn_task, recovering");
328 poisoned.into_inner()
329 });
330 let Some(handler) = handlers_guard.get(&task_type) else {
331 let failure_snapshot = {
332 let mut tasks_guard = tasks.write().unwrap_or_else(|poisoned| {
333 warn!(target: targets::SERVER, "tasks lock poisoned in spawn_task failure, recovering");
334 poisoned.into_inner()
335 });
336 match tasks_guard.get_mut(&task_id) {
337 Some(state) => {
338 if !state.cancel_requested {
339 let error_msg = format!("Unknown task type: {task_type}");
340 state.info.status = TaskStatus::Failed;
341 state.info.completed_at = Some(chrono::Utc::now().to_rfc3339());
342 state.info.error = Some(error_msg.clone());
343 state.result = Some(TaskResult {
344 id: task_id.clone(),
345 success: false,
346 data: None,
347 error: Some(error_msg),
348 });
349 Some(TaskStatusSnapshot::from(state))
350 } else {
351 None
352 }
353 }
354 None => None,
355 }
356 };
357 notify_snapshot(¬ification_sender, failure_snapshot);
358 return;
359 };
360 (handler)(&task_cx, params)
361 };
362
363 let result = task_future.await;
364
365 let completion_snapshot = {
366 let mut tasks_guard = tasks.write().unwrap_or_else(|poisoned| {
367 warn!(target: targets::SERVER, "tasks lock poisoned in spawn_task completion, recovering");
368 poisoned.into_inner()
369 });
370 match tasks_guard.get_mut(&task_id) {
371 Some(state) => {
372 if state.cancel_requested {
373 None
374 } else {
375 let mut snapshot = None;
376 match result {
377 Ok(data) => {
378 if transition_state(state, TaskStatus::Completed) {
379 state.info.progress = Some(1.0);
380 state.result = Some(TaskResult {
381 id: task_id.clone(),
382 success: true,
383 data: Some(data),
384 error: None,
385 });
386 snapshot = Some(TaskStatusSnapshot::from(state));
387 }
388 }
389 Err(err) => {
390 let error_msg = err.message;
391 if transition_state(state, TaskStatus::Failed) {
392 state.info.error = Some(error_msg.clone());
393 state.result = Some(TaskResult {
394 id: task_id.clone(),
395 success: false,
396 data: None,
397 error: Some(error_msg),
398 });
399 snapshot = Some(TaskStatusSnapshot::from(state));
400 }
401 }
402 }
403 snapshot
404 }
405 }
406 None => None,
407 }
408 };
409
410 notify_snapshot(¬ification_sender, completion_snapshot);
411 });
412 }
413
414 pub fn start_task(&self, task_id: &TaskId) -> McpResult<()> {
418 let snapshot = {
419 let mut tasks = self.tasks.write().unwrap_or_else(|poisoned| {
420 warn!(target: targets::SERVER, "tasks lock poisoned in start_task, recovering");
421 poisoned.into_inner()
422 });
423 let state = tasks
424 .get_mut(task_id)
425 .ok_or_else(|| McpError::invalid_params(format!("Task not found: {task_id}")))?;
426
427 if state.info.status != TaskStatus::Pending {
428 return Err(McpError::invalid_params(format!(
429 "Task {task_id} is not pending"
430 )));
431 }
432
433 if !transition_state(state, TaskStatus::Running) {
434 return Err(McpError::invalid_params(format!(
435 "Task {task_id} cannot transition to running"
436 )));
437 }
438 Some(TaskStatusSnapshot::from(state))
439 };
440
441 self.notify_snapshot(snapshot);
442 Ok(())
443 }
444
445 pub fn update_progress(&self, task_id: &TaskId, progress: f64, message: Option<String>) {
447 let snapshot = {
448 let mut tasks = self.tasks.write().unwrap_or_else(|poisoned| {
449 warn!(target: targets::SERVER, "tasks lock poisoned in update_progress, recovering");
450 poisoned.into_inner()
451 });
452 if let Some(state) = tasks.get_mut(task_id) {
453 if state.info.status != TaskStatus::Running {
454 debug!(
455 target: targets::SERVER,
456 "task {} progress update ignored in state {:?}",
457 task_id,
458 state.info.status
459 );
460 return;
461 }
462 state.info.progress = Some(progress.clamp(0.0, 1.0));
463 state.info.message = message;
464 Some(TaskStatusSnapshot::from(state))
465 } else {
466 None
467 }
468 };
469
470 self.notify_snapshot(snapshot);
471 }
472
473 pub fn complete_task(&self, task_id: &TaskId, data: serde_json::Value) {
475 let snapshot = {
476 let mut tasks = self.tasks.write().unwrap_or_else(|poisoned| {
477 warn!(target: targets::SERVER, "tasks lock poisoned in complete_task, recovering");
478 poisoned.into_inner()
479 });
480 if let Some(state) = tasks.get_mut(task_id) {
481 if !transition_state(state, TaskStatus::Completed) {
482 return;
483 }
484 state.info.progress = Some(1.0);
485 state.result = Some(TaskResult {
486 id: task_id.clone(),
487 success: true,
488 data: Some(data),
489 error: None,
490 });
491 Some(TaskStatusSnapshot::from(state))
492 } else {
493 None
494 }
495 };
496
497 self.notify_snapshot(snapshot);
498 }
499
500 pub fn fail_task(&self, task_id: &TaskId, error: impl Into<String>) {
502 let error = error.into();
503 let snapshot = {
504 let mut tasks = self.tasks.write().unwrap_or_else(|poisoned| {
505 warn!(target: targets::SERVER, "tasks lock poisoned in fail_task, recovering");
506 poisoned.into_inner()
507 });
508 if let Some(state) = tasks.get_mut(task_id) {
509 if !transition_state(state, TaskStatus::Failed) {
510 return;
511 }
512 state.info.error = Some(error.clone());
513 state.result = Some(TaskResult {
514 id: task_id.clone(),
515 success: false,
516 data: None,
517 error: Some(error),
518 });
519 Some(TaskStatusSnapshot::from(state))
520 } else {
521 None
522 }
523 };
524
525 self.notify_snapshot(snapshot);
526 }
527
528 #[must_use]
530 pub fn get_info(&self, task_id: &TaskId) -> Option<TaskInfo> {
531 let tasks = self.tasks.read().unwrap_or_else(|poisoned| {
532 warn!(target: targets::SERVER, "tasks lock poisoned in get_info, recovering");
533 poisoned.into_inner()
534 });
535 tasks.get(task_id).map(|s| s.info.clone())
536 }
537
538 #[must_use]
540 pub fn get_result(&self, task_id: &TaskId) -> Option<TaskResult> {
541 let tasks = self.tasks.read().unwrap_or_else(|poisoned| {
542 warn!(target: targets::SERVER, "tasks lock poisoned in get_result, recovering");
543 poisoned.into_inner()
544 });
545 tasks.get(task_id).and_then(|s| s.result.clone())
546 }
547
548 #[must_use]
550 pub fn list_tasks(&self, status_filter: Option<TaskStatus>) -> Vec<TaskInfo> {
551 let tasks = self.tasks.read().unwrap_or_else(|poisoned| {
552 warn!(target: targets::SERVER, "tasks lock poisoned in list_tasks, recovering");
553 poisoned.into_inner()
554 });
555 tasks
556 .values()
557 .filter(|s| status_filter.is_none_or(|f| s.info.status == f))
558 .map(|s| s.info.clone())
559 .collect()
560 }
561
562 pub fn cancel(&self, task_id: &TaskId, reason: Option<String>) -> McpResult<TaskInfo> {
567 let snapshot = {
568 let mut tasks = self.tasks.write().unwrap_or_else(|poisoned| {
569 warn!(target: targets::SERVER, "tasks lock poisoned in cancel, recovering");
570 poisoned.into_inner()
571 });
572 let state = tasks
573 .get_mut(task_id)
574 .ok_or_else(|| McpError::invalid_params(format!("Task not found: {task_id}")))?;
575
576 if state.info.status.is_terminal() {
578 return Err(McpError::invalid_params(format!(
579 "Task {task_id} is already in terminal state: {:?}",
580 state.info.status
581 )));
582 }
583
584 if !transition_state(state, TaskStatus::Cancelled) {
585 return Err(McpError::invalid_params(format!(
586 "Task {task_id} cannot be cancelled from {:?}",
587 state.info.status
588 )));
589 }
590
591 state.cancel_requested = true;
592
593 state.cx.cancel_with(CancelKind::User, None);
594 if !state.cx.is_cancel_requested() {
595 warn!(
596 target: targets::SERVER,
597 "task {} cancel signal not observed on context",
598 task_id
599 );
600 }
601
602 let error_msg = reason.unwrap_or_else(|| "Cancelled by request".to_string());
603 state.info.error = Some(error_msg.clone());
604 state.result = Some(TaskResult {
605 id: task_id.clone(),
606 success: false,
607 data: None,
608 error: Some(error_msg),
609 });
610
611 let snapshot = TaskStatusSnapshot::from(state);
612 (snapshot, state.info.clone())
613 };
614
615 let (snapshot, info) = snapshot;
616 self.notify_snapshot(Some(snapshot));
617 Ok(info)
618 }
619
620 #[must_use]
622 pub fn is_cancel_requested(&self, task_id: &TaskId) -> bool {
623 let tasks = self.tasks.read().unwrap_or_else(|poisoned| {
624 warn!(target: targets::SERVER, "tasks lock poisoned in is_cancel_requested, recovering");
625 poisoned.into_inner()
626 });
627 tasks.get(task_id).is_some_and(|s| s.cancel_requested)
628 }
629
630 #[must_use]
632 pub fn active_count(&self) -> usize {
633 let tasks = self.tasks.read().unwrap_or_else(|poisoned| {
634 warn!(target: targets::SERVER, "tasks lock poisoned in active_count, recovering");
635 poisoned.into_inner()
636 });
637 tasks.values().filter(|s| s.info.status.is_active()).count()
638 }
639
640 #[must_use]
642 pub fn total_count(&self) -> usize {
643 let tasks = self.tasks.read().unwrap_or_else(|poisoned| {
644 warn!(target: targets::SERVER, "tasks lock poisoned in total_count, recovering");
645 poisoned.into_inner()
646 });
647 tasks.len()
648 }
649
650 pub fn cleanup_completed(&self, max_age: std::time::Duration) {
654 let cutoff = chrono::Utc::now() - chrono::Duration::from_std(max_age).unwrap_or_default();
655
656 let mut tasks = self.tasks.write().unwrap_or_else(|poisoned| {
657 warn!(target: targets::SERVER, "tasks lock poisoned in cleanup_completed, recovering");
658 poisoned.into_inner()
659 });
660 tasks.retain(|_, state| {
661 if state.info.status.is_active() {
663 return true;
664 }
665
666 if let Some(ref completed) = state.info.completed_at {
668 if let Ok(parsed) = chrono::DateTime::parse_from_rfc3339(completed) {
669 return parsed.with_timezone(&chrono::Utc) > cutoff;
670 }
671 return true;
672 }
673
674 true
675 });
676 }
677
678 fn notify_snapshot(&self, snapshot: Option<TaskStatusSnapshot>) {
679 if let Some(snapshot) = snapshot {
680 self.notify_status(snapshot.info, snapshot.result);
681 }
682 }
683
684 fn notify_status(&self, info: TaskInfo, result: Option<TaskResult>) {
685 let sender = {
686 let guard = self.notification_sender.read().unwrap_or_else(|poisoned| {
687 warn!(target: targets::SERVER, "notification sender lock poisoned in notify_status, recovering");
688 poisoned.into_inner()
689 });
690 guard.clone()
691 };
692 let Some(sender) = sender else {
693 return;
694 };
695
696 let params = TaskStatusNotificationParams {
697 id: info.id.clone(),
698 status: info.status,
699 progress: info.progress,
700 message: info.message.clone(),
701 error: info.error.clone(),
702 result,
703 };
704 let payload = match serde_json::to_value(params) {
705 Ok(value) => value,
706 Err(err) => {
707 warn!(
708 target: targets::SERVER,
709 "failed to serialize task status notification: {}",
710 err
711 );
712 return;
713 }
714 };
715 sender(JsonRpcRequest::notification(
716 "notifications/tasks/status",
717 Some(payload),
718 ));
719 }
720}
721
722#[derive(Debug, Clone)]
723struct TaskStatusSnapshot {
724 info: TaskInfo,
725 result: Option<TaskResult>,
726}
727
728impl TaskStatusSnapshot {
729 fn from(state: &TaskState) -> Self {
730 Self {
731 info: state.info.clone(),
732 result: state.result.clone(),
733 }
734 }
735}
736
737fn notify_snapshot(
738 sender: &Arc<RwLock<Option<TaskNotificationSender>>>,
739 snapshot: Option<TaskStatusSnapshot>,
740) {
741 let Some(snapshot) = snapshot else {
742 return;
743 };
744 let sender = {
745 let guard = sender.read().unwrap_or_else(|poisoned| {
746 warn!(target: targets::SERVER, "notification sender lock poisoned in notify_snapshot, recovering");
747 poisoned.into_inner()
748 });
749 guard.clone()
750 };
751 let Some(sender) = sender else {
752 return;
753 };
754 let params = TaskStatusNotificationParams {
755 id: snapshot.info.id.clone(),
756 status: snapshot.info.status,
757 progress: snapshot.info.progress,
758 message: snapshot.info.message.clone(),
759 error: snapshot.info.error.clone(),
760 result: snapshot.result,
761 };
762 let payload = match serde_json::to_value(params) {
763 Ok(value) => value,
764 Err(err) => {
765 warn!(
766 target: targets::SERVER,
767 "failed to serialize task status notification: {}",
768 err
769 );
770 return;
771 }
772 };
773 sender(JsonRpcRequest::notification(
774 "notifications/tasks/status",
775 Some(payload),
776 ));
777}
778
779impl Default for TaskManager {
780 fn default() -> Self {
781 Self::new()
782 }
783}
784
785impl std::fmt::Debug for TaskManager {
786 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
787 let task_count = self
789 .tasks
790 .read()
791 .map(|g| g.len())
792 .unwrap_or_else(|poisoned| poisoned.into_inner().len());
793 let handler_count = self
794 .handlers
795 .read()
796 .map(|g| g.len())
797 .unwrap_or_else(|poisoned| poisoned.into_inner().len());
798 f.debug_struct("TaskManager")
799 .field("task_count", &task_count)
800 .field("handler_count", &handler_count)
801 .field("task_counter", &self.task_counter.load(Ordering::SeqCst))
802 .field(
803 "list_changed_notifications",
804 &self.list_changed_notifications,
805 )
806 .field("auto_execute", &self.auto_execute)
807 .finish_non_exhaustive()
808 }
809}
810
811pub type SharedTaskManager = Arc<TaskManager>;
813
814#[cfg(test)]
815mod tests {
816 use super::*;
817 use std::sync::Arc;
818 use std::thread;
819
820 #[test]
821 fn test_task_manager_creation() {
822 let manager = TaskManager::new();
823 assert_eq!(manager.total_count(), 0);
824 assert_eq!(manager.active_count(), 0);
825 assert!(!manager.has_list_changed_notifications());
826 }
827
828 #[test]
829 fn test_task_manager_with_notifications() {
830 let manager = TaskManager::with_list_changed_notifications();
831 assert!(manager.has_list_changed_notifications());
832 }
833
834 #[test]
835 fn test_register_handler() {
836 let manager = TaskManager::new();
837
838 manager.register_handler("test_task", |_cx, _params| async {
839 Ok(serde_json::json!({}))
840 });
841
842 let cx = Cx::for_testing();
844 let result = manager.submit(&cx, "test_task", None);
845 assert!(result.is_ok());
846 }
847
848 #[test]
849 fn test_submit_unknown_task_type() {
850 let manager = TaskManager::new();
851 let cx = Cx::for_testing();
852
853 let result = manager.submit(&cx, "unknown_task", None);
854 assert!(result.is_err());
855 }
856
857 #[test]
858 fn test_task_lifecycle() {
859 let manager = TaskManager::new_for_testing();
860 let cx = Cx::for_testing();
861
862 manager.register_handler("test", |_cx, _params| async {
863 Ok(serde_json::json!({"done": true}))
864 });
865
866 let task_id = manager.submit(&cx, "test", None).unwrap();
868
869 let info = manager.get_info(&task_id).unwrap();
871 assert_eq!(info.status, TaskStatus::Pending);
872 assert!(info.started_at.is_none());
873
874 manager.start_task(&task_id).unwrap();
876 let info = manager.get_info(&task_id).unwrap();
877 assert_eq!(info.status, TaskStatus::Running);
878 assert!(info.started_at.is_some());
879
880 manager.update_progress(&task_id, 0.5, Some("Halfway done".into()));
882 let info = manager.get_info(&task_id).unwrap();
883 assert_eq!(info.progress, Some(0.5));
884 assert_eq!(info.message, Some("Halfway done".into()));
885
886 manager.complete_task(&task_id, serde_json::json!({"result": 42}));
888 let info = manager.get_info(&task_id).unwrap();
889 assert_eq!(info.status, TaskStatus::Completed);
890 assert!(info.completed_at.is_some());
891
892 let result = manager.get_result(&task_id).unwrap();
894 assert!(result.success);
895 assert_eq!(result.data, Some(serde_json::json!({"result": 42})));
896 }
897
898 #[test]
899 fn test_task_failure() {
900 let manager = TaskManager::new_for_testing();
901 let cx = Cx::for_testing();
902
903 manager.register_handler("fail_test", |_cx, _params| async {
904 Ok(serde_json::json!({}))
905 });
906
907 let task_id = manager.submit(&cx, "fail_test", None).unwrap();
908 manager.start_task(&task_id).unwrap();
909 manager.fail_task(&task_id, "Something went wrong");
910
911 let info = manager.get_info(&task_id).unwrap();
912 assert_eq!(info.status, TaskStatus::Failed);
913 assert_eq!(info.error, Some("Something went wrong".into()));
914
915 let result = manager.get_result(&task_id).unwrap();
916 assert!(!result.success);
917 assert_eq!(result.error, Some("Something went wrong".into()));
918 }
919
920 #[test]
921 fn test_task_cancellation() {
922 let manager = TaskManager::new_for_testing();
923 let cx = Cx::for_testing();
924
925 manager.register_handler("cancel_test", |_cx, _params| async {
926 Ok(serde_json::json!({}))
927 });
928
929 let task_id = manager.submit(&cx, "cancel_test", None).unwrap();
930 manager.start_task(&task_id).unwrap();
931
932 let info = manager
934 .cancel(&task_id, Some("User cancelled".into()))
935 .unwrap();
936 assert_eq!(info.status, TaskStatus::Cancelled);
937
938 assert!(manager.is_cancel_requested(&task_id));
940
941 let result = manager.cancel(&task_id, None);
943 assert!(result.is_err());
944 }
945
946 #[test]
947 fn test_list_tasks() {
948 let manager = TaskManager::new_for_testing();
949 let cx = Cx::for_testing();
950
951 manager.register_handler("list_test", |_cx, _params| async {
952 Ok(serde_json::json!({}))
953 });
954
955 let task1 = manager.submit(&cx, "list_test", None).unwrap();
956 let task2 = manager.submit(&cx, "list_test", None).unwrap();
957 let _task3 = manager.submit(&cx, "list_test", None).unwrap();
958
959 assert_eq!(manager.list_tasks(Some(TaskStatus::Pending)).len(), 3);
961 assert_eq!(manager.list_tasks(Some(TaskStatus::Running)).len(), 0);
962
963 manager.start_task(&task1).unwrap();
965 assert_eq!(manager.list_tasks(Some(TaskStatus::Pending)).len(), 2);
966 assert_eq!(manager.list_tasks(Some(TaskStatus::Running)).len(), 1);
967
968 manager.start_task(&task2).unwrap();
970 manager.complete_task(&task2, serde_json::json!({}));
971 assert_eq!(manager.list_tasks(Some(TaskStatus::Completed)).len(), 1);
972
973 assert_eq!(manager.list_tasks(None).len(), 3);
975 }
976
977 #[test]
978 fn test_active_count() {
979 let manager = TaskManager::new_for_testing();
980 let cx = Cx::for_testing();
981
982 manager.register_handler("count_test", |_cx, _params| async {
983 Ok(serde_json::json!({}))
984 });
985
986 let task1 = manager.submit(&cx, "count_test", None).unwrap();
987 let task2 = manager.submit(&cx, "count_test", None).unwrap();
988
989 assert_eq!(manager.active_count(), 2);
990 assert_eq!(manager.total_count(), 2);
991
992 manager.start_task(&task1).unwrap();
993 assert_eq!(manager.active_count(), 2);
994
995 manager.complete_task(&task1, serde_json::json!({}));
996 assert_eq!(manager.active_count(), 1);
997
998 manager.cancel(&task2, None).unwrap();
999 assert_eq!(manager.active_count(), 0);
1000 assert_eq!(manager.total_count(), 2);
1001 }
1002
1003 #[test]
1004 fn test_progress_clamping() {
1005 let manager = TaskManager::new_for_testing();
1006 let cx = Cx::for_testing();
1007
1008 manager.register_handler("clamp_test", |_cx, _params| async {
1009 Ok(serde_json::json!({}))
1010 });
1011
1012 let task_id = manager.submit(&cx, "clamp_test", None).unwrap();
1013 manager.start_task(&task_id).unwrap();
1014
1015 manager.update_progress(&task_id, -0.5, None);
1017 assert_eq!(manager.get_info(&task_id).unwrap().progress, Some(0.0));
1018
1019 manager.update_progress(&task_id, 1.5, None);
1020 assert_eq!(manager.get_info(&task_id).unwrap().progress, Some(1.0));
1021
1022 manager.update_progress(&task_id, 0.75, None);
1023 assert_eq!(manager.get_info(&task_id).unwrap().progress, Some(0.75));
1024 }
1025
1026 #[test]
1027 fn test_invalid_transition_rejected() {
1028 let manager = TaskManager::new_for_testing();
1029 let cx = Cx::for_testing();
1030
1031 manager.register_handler("transition_test", |_cx, _params| async {
1032 Ok(serde_json::json!({}))
1033 });
1034
1035 let task_id = manager.submit(&cx, "transition_test", None).unwrap();
1036
1037 manager.complete_task(&task_id, serde_json::json!({"result": "noop"}));
1039 let info = manager.get_info(&task_id).unwrap();
1040 assert_eq!(info.status, TaskStatus::Pending);
1041
1042 manager.start_task(&task_id).unwrap();
1043 manager.complete_task(&task_id, serde_json::json!({"result": "ok"}));
1044 let info = manager.get_info(&task_id).unwrap();
1045 assert_eq!(info.status, TaskStatus::Completed);
1046
1047 let result = manager.start_task(&task_id);
1049 assert!(result.is_err());
1050 }
1051
1052 #[test]
1053 fn test_concurrent_submissions() {
1054 let manager = Arc::new(TaskManager::new_for_testing());
1055 manager.register_handler("concurrent_test", |_cx, _params| async {
1056 Ok(serde_json::json!({}))
1057 });
1058
1059 let mut handles = Vec::new();
1060 for _ in 0..4 {
1061 let manager = Arc::clone(&manager);
1062 handles.push(thread::spawn(move || {
1063 let cx = Cx::for_testing();
1064 for _ in 0..10 {
1065 let _ = manager.submit(&cx, "concurrent_test", None).unwrap();
1066 }
1067 }));
1068 }
1069
1070 for handle in handles {
1071 handle.join().expect("thread join failed");
1072 }
1073
1074 assert_eq!(manager.total_count(), 40);
1075 assert_eq!(manager.list_tasks(Some(TaskStatus::Pending)).len(), 40);
1076 }
1077
1078 #[test]
1079 fn test_task_status_notifications() {
1080 let manager = TaskManager::new_for_testing();
1081 manager.register_handler("notify_test", |_cx, _params| async {
1082 Ok(serde_json::json!({"ok": true}))
1083 });
1084
1085 let events: Arc<std::sync::Mutex<Vec<TaskStatusNotificationParams>>> =
1086 Arc::new(std::sync::Mutex::new(Vec::new()));
1087 let sender_events = Arc::clone(&events);
1088 let sender: TaskNotificationSender = Arc::new(move |request| {
1089 if request.method != "notifications/tasks/status" {
1090 return;
1091 }
1092 let params = request
1093 .params
1094 .as_ref()
1095 .and_then(|value| serde_json::from_value(value.clone()).ok())
1096 .expect("task status params");
1097 sender_events
1098 .lock()
1099 .expect("events lock poisoned")
1100 .push(params);
1101 });
1102 manager.set_notification_sender(sender);
1103
1104 let cx = Cx::for_testing();
1105 let task_id = manager.submit(&cx, "notify_test", None).unwrap();
1106 manager.start_task(&task_id).unwrap();
1107 manager.update_progress(&task_id, 0.5, Some("half".to_string()));
1108 manager.complete_task(&task_id, serde_json::json!({"result": 1}));
1109
1110 let recorded = events.lock().expect("events lock poisoned").clone();
1111 assert!(!recorded.is_empty(), "expected task status notifications");
1112 assert_eq!(recorded[0].id, task_id);
1113 assert_eq!(recorded[0].status, TaskStatus::Pending);
1114 assert_eq!(recorded[1].status, TaskStatus::Running);
1115 assert_eq!(recorded[2].progress, Some(0.5));
1116 assert_eq!(recorded.last().expect("last").status, TaskStatus::Completed);
1117 }
1118}