Skip to main content

fastmcp_server/
tasks.rs

//! Background task manager (Docket/SEP-1686).
//!
//! Provides support for long-running background tasks that outlive individual
//! request lifecycles. Tasks are managed in a dedicated region that survives
//! until server shutdown.
//!
//! # Architecture
//!
//! ```text
//! Server Region (root)
//! ├── Session Region (per connection)
//! │   └── Request Regions (tools/call, etc.)
//! └── Background Task Region (managed by TaskManager)
//!     ├── Task 1
//!     ├── Task 2
//!     └── ...
//! ```
//!
//! # Usage
//!
//! ```ignore
//! let task_manager = TaskManager::new();
//!
//! // Register a handler first; submitting an unknown task type is an error
//! task_manager.register_handler("long_analysis", |_cx, params| async move { Ok(params) });
//!
//! // Submit a background task
//! let task_id = task_manager.submit(&cx, "long_analysis", Some(json!({"data": ...})))?;
//!
//! // Check status
//! let info = task_manager.get_info(&task_id);
//!
//! // Cancel if needed (the reason is an owned `String`)
//! task_manager.cancel(&task_id, Some("User requested".to_string()))?;
//! ```
33
34use std::collections::HashMap;
35use std::sync::atomic::{AtomicU64, Ordering};
36use std::sync::{Arc, RwLock};
37
38use asupersync::runtime::{RuntimeBuilder, RuntimeHandle};
39use asupersync::{Budget, CancelKind, Cx};
40use fastmcp_core::logging::{debug, info, targets, warn};
41use fastmcp_core::{McpError, McpResult};
42use fastmcp_protocol::{
43    JsonRpcRequest, TaskId, TaskInfo, TaskResult, TaskStatus, TaskStatusNotificationParams,
44};
45
46/// Notification sender used for task status updates.
47pub type TaskNotificationSender = Arc<dyn Fn(JsonRpcRequest) + Send + Sync>;
48
49/// Callback type for task execution.
50///
51/// Task handlers receive the context and parameters, and return a result.
52pub type TaskHandler = Box<dyn Fn(&Cx, serde_json::Value) -> TaskFuture + Send + Sync + 'static>;
53
54/// Future type for task execution.
55pub type TaskFuture = std::pin::Pin<
56    Box<dyn std::future::Future<Output = McpResult<serde_json::Value>> + Send + 'static>,
57>;
58
59/// Internal state for a running task.
60struct TaskState {
61    /// Task information.
62    info: TaskInfo,
63    /// Whether cancellation has been requested.
64    cancel_requested: bool,
65    /// Task result once completed.
66    result: Option<TaskResult>,
67    /// Task-scoped cancellation context.
68    cx: Cx,
69}
70
71fn can_transition(from: TaskStatus, to: TaskStatus) -> bool {
72    matches!(
73        (from, to),
74        (
75            TaskStatus::Pending,
76            TaskStatus::Running | TaskStatus::Cancelled
77        ) | (
78            TaskStatus::Running,
79            TaskStatus::Completed | TaskStatus::Failed | TaskStatus::Cancelled
80        )
81    )
82}
83
84fn transition_state(state: &mut TaskState, to: TaskStatus) -> bool {
85    let from = state.info.status;
86    if from == to {
87        return true;
88    }
89    if !can_transition(from, to) {
90        warn!(
91            target: targets::SERVER,
92            "task {} invalid transition {:?} -> {:?}",
93            state.info.id,
94            from,
95            to
96        );
97        return false;
98    }
99
100    state.info.status = to;
101    let now = chrono::Utc::now().to_rfc3339();
102    match to {
103        TaskStatus::Running => {
104            state.info.started_at = Some(now.clone());
105        }
106        TaskStatus::Completed | TaskStatus::Failed | TaskStatus::Cancelled => {
107            state.info.completed_at = Some(now.clone());
108        }
109        TaskStatus::Pending => {}
110    }
111
112    info!(
113        target: targets::SERVER,
114        "task {} status {:?} -> {:?} at {}",
115        state.info.id,
116        from,
117        to,
118        now
119    );
120    true
121}
122
123/// Background task manager.
124///
125/// Manages the lifecycle of background tasks including submission, status
126/// tracking, and cancellation.
127pub struct TaskManager {
128    /// Active and completed tasks by ID.
129    tasks: Arc<RwLock<HashMap<TaskId, TaskState>>>,
130    /// Registered task handlers by type.
131    handlers: Arc<RwLock<HashMap<String, TaskHandler>>>,
132    /// Counter for generating unique task IDs.
133    task_counter: AtomicU64,
134    /// Whether task list changes should trigger notifications.
135    list_changed_notifications: bool,
136    /// Background runtime handle for executing tasks.
137    runtime: RuntimeHandle,
138    /// Whether submitted tasks should execute immediately.
139    auto_execute: bool,
140    /// Optional notification sender for task status updates.
141    notification_sender: Arc<RwLock<Option<TaskNotificationSender>>>,
142}
143
144impl TaskManager {
145    /// Creates a new task manager.
146    #[must_use]
147    pub fn new() -> Self {
148        let runtime = RuntimeBuilder::multi_thread()
149            .build()
150            .expect("failed to build background task runtime")
151            .handle();
152        Self {
153            tasks: Arc::new(RwLock::new(HashMap::new())),
154            handlers: Arc::new(RwLock::new(HashMap::new())),
155            task_counter: AtomicU64::new(0),
156            list_changed_notifications: false,
157            runtime,
158            auto_execute: true,
159            notification_sender: Arc::new(RwLock::new(None)),
160        }
161    }
162
163    /// Creates a new task manager with list change notifications enabled.
164    #[must_use]
165    pub fn with_list_changed_notifications() -> Self {
166        Self {
167            list_changed_notifications: true,
168            ..Self::new()
169        }
170    }
171
172    /// Creates a task manager configured for deterministic tests.
173    ///
174    /// Tasks are not executed automatically; tests can drive state manually.
175    #[must_use]
176    pub fn new_for_testing() -> Self {
177        let mut manager = Self::new();
178        manager.auto_execute = false;
179        manager
180    }
181
182    /// Converts this manager into a shared handle.
183    #[must_use]
184    pub fn into_shared(self) -> SharedTaskManager {
185        Arc::new(self)
186    }
187
188    /// Returns whether list change notifications are enabled.
189    #[must_use]
190    pub fn has_list_changed_notifications(&self) -> bool {
191        self.list_changed_notifications
192    }
193
194    /// Sets the notification sender for task status updates.
195    pub fn set_notification_sender(&self, sender: TaskNotificationSender) {
196        let mut guard = self.notification_sender.write().unwrap_or_else(|poisoned| {
197            warn!(target: targets::SERVER, "notification sender lock poisoned, recovering");
198            poisoned.into_inner()
199        });
200        *guard = Some(sender);
201    }
202
203    /// Registers a task handler for a specific task type.
204    ///
205    /// The handler will be invoked when a task of this type is submitted.
206    pub fn register_handler<F, Fut>(&self, task_type: impl Into<String>, handler: F)
207    where
208        F: Fn(&Cx, serde_json::Value) -> Fut + Send + Sync + 'static,
209        Fut: std::future::Future<Output = McpResult<serde_json::Value>> + Send + 'static,
210    {
211        let task_type = task_type.into();
212        let boxed_handler: TaskHandler = Box::new(move |cx, params| Box::pin(handler(cx, params)));
213
214        let mut handlers = self.handlers.write().unwrap_or_else(|poisoned| {
215            warn!(target: targets::SERVER, "handlers lock poisoned, recovering");
216            poisoned.into_inner()
217        });
218        handlers.insert(task_type, boxed_handler);
219    }
220
221    /// Submits a new background task.
222    ///
223    /// Returns the task ID for tracking. The task runs asynchronously in the
224    /// background region.
225    pub fn submit(
226        &self,
227        _cx: &Cx,
228        task_type: impl Into<String>,
229        params: Option<serde_json::Value>,
230    ) -> McpResult<TaskId> {
231        let task_type = task_type.into();
232
233        // Check if handler exists
234        {
235            let handlers = self.handlers.read().unwrap_or_else(|poisoned| {
236                warn!(target: targets::SERVER, "handlers lock poisoned, recovering");
237                poisoned.into_inner()
238            });
239            if !handlers.contains_key(&task_type) {
240                return Err(McpError::invalid_params(format!(
241                    "Unknown task type: {task_type}"
242                )));
243            }
244        }
245
246        // Generate unique task ID
247        let counter = self.task_counter.fetch_add(1, Ordering::SeqCst);
248        let task_id = TaskId::from_string(format!("task-{counter:08x}"));
249
250        // Create task info
251        let now = chrono::Utc::now().to_rfc3339();
252        let task_cx = Cx::for_request_with_budget(Budget::INFINITE);
253        let info = TaskInfo {
254            id: task_id.clone(),
255            task_type: task_type.clone(),
256            status: TaskStatus::Pending,
257            progress: None,
258            message: None,
259            created_at: now,
260            started_at: None,
261            completed_at: None,
262            error: None,
263        };
264
265        let info_snapshot = info.clone();
266
267        // Store task state
268        let state = TaskState {
269            info,
270            cancel_requested: false,
271            result: None,
272            cx: task_cx.clone(),
273        };
274
275        {
276            let mut tasks = self.tasks.write().unwrap_or_else(|poisoned| {
277                warn!(target: targets::SERVER, "tasks lock poisoned, recovering");
278                poisoned.into_inner()
279            });
280            tasks.insert(task_id.clone(), state);
281        }
282
283        self.notify_status(info_snapshot, None);
284
285        if self.auto_execute {
286            let params = params.unwrap_or_else(|| serde_json::json!({}));
287            self.spawn_task(task_id.clone(), task_type, task_cx, params);
288        }
289
290        Ok(task_id)
291    }
292
293    #[allow(clippy::too_many_lines)]
294    fn spawn_task(
295        &self,
296        task_id: TaskId,
297        task_type: String,
298        task_cx: Cx,
299        params: serde_json::Value,
300    ) {
301        let tasks = Arc::clone(&self.tasks);
302        let handlers = Arc::clone(&self.handlers);
303        let notification_sender = Arc::clone(&self.notification_sender);
304
305        self.runtime.spawn(async move {
306            let running_snapshot = {
307                let mut tasks_guard = tasks.write().unwrap_or_else(|poisoned| {
308                    warn!(target: targets::SERVER, "tasks lock poisoned in spawn_task, recovering");
309                    poisoned.into_inner()
310                });
311                match tasks_guard.get_mut(&task_id) {
312                    Some(state) => {
313                        if state.cancel_requested || !transition_state(state, TaskStatus::Running) {
314                            None
315                        } else {
316                            Some(TaskStatusSnapshot::from(state))
317                        }
318                    }
319                    None => None,
320                }
321            };
322
323            notify_snapshot(&notification_sender, running_snapshot);
324
325            let task_future = {
326                let handlers_guard = handlers.read().unwrap_or_else(|poisoned| {
327                    warn!(target: targets::SERVER, "handlers lock poisoned in spawn_task, recovering");
328                    poisoned.into_inner()
329                });
330                let Some(handler) = handlers_guard.get(&task_type) else {
331                    let failure_snapshot = {
332                        let mut tasks_guard = tasks.write().unwrap_or_else(|poisoned| {
333                            warn!(target: targets::SERVER, "tasks lock poisoned in spawn_task failure, recovering");
334                            poisoned.into_inner()
335                        });
336                        match tasks_guard.get_mut(&task_id) {
337                            Some(state) => {
338                                if !state.cancel_requested {
339                                    let error_msg = format!("Unknown task type: {task_type}");
340                                    state.info.status = TaskStatus::Failed;
341                                    state.info.completed_at = Some(chrono::Utc::now().to_rfc3339());
342                                    state.info.error = Some(error_msg.clone());
343                                    state.result = Some(TaskResult {
344                                        id: task_id.clone(),
345                                        success: false,
346                                        data: None,
347                                        error: Some(error_msg),
348                                    });
349                                    Some(TaskStatusSnapshot::from(state))
350                                } else {
351                                    None
352                                }
353                            }
354                            None => None,
355                        }
356                    };
357                    notify_snapshot(&notification_sender, failure_snapshot);
358                    return;
359                };
360                (handler)(&task_cx, params)
361            };
362
363            let result = task_future.await;
364
365            let completion_snapshot = {
366                let mut tasks_guard = tasks.write().unwrap_or_else(|poisoned| {
367                    warn!(target: targets::SERVER, "tasks lock poisoned in spawn_task completion, recovering");
368                    poisoned.into_inner()
369                });
370                match tasks_guard.get_mut(&task_id) {
371                    Some(state) => {
372                        if state.cancel_requested {
373                            None
374                        } else {
375                            let mut snapshot = None;
376                            match result {
377                                Ok(data) => {
378                                    if transition_state(state, TaskStatus::Completed) {
379                                        state.info.progress = Some(1.0);
380                                        state.result = Some(TaskResult {
381                                            id: task_id.clone(),
382                                            success: true,
383                                            data: Some(data),
384                                            error: None,
385                                        });
386                                        snapshot = Some(TaskStatusSnapshot::from(state));
387                                    }
388                                }
389                                Err(err) => {
390                                    let error_msg = err.message;
391                                    if transition_state(state, TaskStatus::Failed) {
392                                        state.info.error = Some(error_msg.clone());
393                                        state.result = Some(TaskResult {
394                                            id: task_id.clone(),
395                                            success: false,
396                                            data: None,
397                                            error: Some(error_msg),
398                                        });
399                                        snapshot = Some(TaskStatusSnapshot::from(state));
400                                    }
401                                }
402                            }
403                            snapshot
404                        }
405                    }
406                    None => None,
407                }
408            };
409
410            notify_snapshot(&notification_sender, completion_snapshot);
411        });
412    }
413
414    /// Starts execution of a pending task.
415    ///
416    /// This is called internally to transition a task from Pending to Running.
417    pub fn start_task(&self, task_id: &TaskId) -> McpResult<()> {
418        let snapshot = {
419            let mut tasks = self.tasks.write().unwrap_or_else(|poisoned| {
420                warn!(target: targets::SERVER, "tasks lock poisoned in start_task, recovering");
421                poisoned.into_inner()
422            });
423            let state = tasks
424                .get_mut(task_id)
425                .ok_or_else(|| McpError::invalid_params(format!("Task not found: {task_id}")))?;
426
427            if state.info.status != TaskStatus::Pending {
428                return Err(McpError::invalid_params(format!(
429                    "Task {task_id} is not pending"
430                )));
431            }
432
433            if !transition_state(state, TaskStatus::Running) {
434                return Err(McpError::invalid_params(format!(
435                    "Task {task_id} cannot transition to running"
436                )));
437            }
438            Some(TaskStatusSnapshot::from(state))
439        };
440
441        self.notify_snapshot(snapshot);
442        Ok(())
443    }
444
445    /// Updates progress for a running task.
446    pub fn update_progress(&self, task_id: &TaskId, progress: f64, message: Option<String>) {
447        let snapshot = {
448            let mut tasks = self.tasks.write().unwrap_or_else(|poisoned| {
449                warn!(target: targets::SERVER, "tasks lock poisoned in update_progress, recovering");
450                poisoned.into_inner()
451            });
452            if let Some(state) = tasks.get_mut(task_id) {
453                if state.info.status != TaskStatus::Running {
454                    debug!(
455                        target: targets::SERVER,
456                        "task {} progress update ignored in state {:?}",
457                        task_id,
458                        state.info.status
459                    );
460                    return;
461                }
462                state.info.progress = Some(progress.clamp(0.0, 1.0));
463                state.info.message = message;
464                Some(TaskStatusSnapshot::from(state))
465            } else {
466                None
467            }
468        };
469
470        self.notify_snapshot(snapshot);
471    }
472
473    /// Completes a task with a successful result.
474    pub fn complete_task(&self, task_id: &TaskId, data: serde_json::Value) {
475        let snapshot = {
476            let mut tasks = self.tasks.write().unwrap_or_else(|poisoned| {
477                warn!(target: targets::SERVER, "tasks lock poisoned in complete_task, recovering");
478                poisoned.into_inner()
479            });
480            if let Some(state) = tasks.get_mut(task_id) {
481                if !transition_state(state, TaskStatus::Completed) {
482                    return;
483                }
484                state.info.progress = Some(1.0);
485                state.result = Some(TaskResult {
486                    id: task_id.clone(),
487                    success: true,
488                    data: Some(data),
489                    error: None,
490                });
491                Some(TaskStatusSnapshot::from(state))
492            } else {
493                None
494            }
495        };
496
497        self.notify_snapshot(snapshot);
498    }
499
500    /// Fails a task with an error.
501    pub fn fail_task(&self, task_id: &TaskId, error: impl Into<String>) {
502        let error = error.into();
503        let snapshot = {
504            let mut tasks = self.tasks.write().unwrap_or_else(|poisoned| {
505                warn!(target: targets::SERVER, "tasks lock poisoned in fail_task, recovering");
506                poisoned.into_inner()
507            });
508            if let Some(state) = tasks.get_mut(task_id) {
509                if !transition_state(state, TaskStatus::Failed) {
510                    return;
511                }
512                state.info.error = Some(error.clone());
513                state.result = Some(TaskResult {
514                    id: task_id.clone(),
515                    success: false,
516                    data: None,
517                    error: Some(error),
518                });
519                Some(TaskStatusSnapshot::from(state))
520            } else {
521                None
522            }
523        };
524
525        self.notify_snapshot(snapshot);
526    }
527
528    /// Gets information about a task.
529    #[must_use]
530    pub fn get_info(&self, task_id: &TaskId) -> Option<TaskInfo> {
531        let tasks = self.tasks.read().unwrap_or_else(|poisoned| {
532            warn!(target: targets::SERVER, "tasks lock poisoned in get_info, recovering");
533            poisoned.into_inner()
534        });
535        tasks.get(task_id).map(|s| s.info.clone())
536    }
537
538    /// Gets the result of a completed task.
539    #[must_use]
540    pub fn get_result(&self, task_id: &TaskId) -> Option<TaskResult> {
541        let tasks = self.tasks.read().unwrap_or_else(|poisoned| {
542            warn!(target: targets::SERVER, "tasks lock poisoned in get_result, recovering");
543            poisoned.into_inner()
544        });
545        tasks.get(task_id).and_then(|s| s.result.clone())
546    }
547
548    /// Lists all tasks, optionally filtered by status.
549    #[must_use]
550    pub fn list_tasks(&self, status_filter: Option<TaskStatus>) -> Vec<TaskInfo> {
551        let tasks = self.tasks.read().unwrap_or_else(|poisoned| {
552            warn!(target: targets::SERVER, "tasks lock poisoned in list_tasks, recovering");
553            poisoned.into_inner()
554        });
555        tasks
556            .values()
557            .filter(|s| status_filter.is_none_or(|f| s.info.status == f))
558            .map(|s| s.info.clone())
559            .collect()
560    }
561
562    /// Requests cancellation of a task.
563    ///
564    /// Returns true if the task exists and cancellation was requested.
565    /// The task may still be running until it checks for cancellation.
566    pub fn cancel(&self, task_id: &TaskId, reason: Option<String>) -> McpResult<TaskInfo> {
567        let snapshot = {
568            let mut tasks = self.tasks.write().unwrap_or_else(|poisoned| {
569                warn!(target: targets::SERVER, "tasks lock poisoned in cancel, recovering");
570                poisoned.into_inner()
571            });
572            let state = tasks
573                .get_mut(task_id)
574                .ok_or_else(|| McpError::invalid_params(format!("Task not found: {task_id}")))?;
575
576            // Can only cancel pending or running tasks
577            if state.info.status.is_terminal() {
578                return Err(McpError::invalid_params(format!(
579                    "Task {task_id} is already in terminal state: {:?}",
580                    state.info.status
581                )));
582            }
583
584            if !transition_state(state, TaskStatus::Cancelled) {
585                return Err(McpError::invalid_params(format!(
586                    "Task {task_id} cannot be cancelled from {:?}",
587                    state.info.status
588                )));
589            }
590
591            state.cancel_requested = true;
592
593            state.cx.cancel_with(CancelKind::User, None);
594            if !state.cx.is_cancel_requested() {
595                warn!(
596                    target: targets::SERVER,
597                    "task {} cancel signal not observed on context",
598                    task_id
599                );
600            }
601
602            let error_msg = reason.unwrap_or_else(|| "Cancelled by request".to_string());
603            state.info.error = Some(error_msg.clone());
604            state.result = Some(TaskResult {
605                id: task_id.clone(),
606                success: false,
607                data: None,
608                error: Some(error_msg),
609            });
610
611            let snapshot = TaskStatusSnapshot::from(state);
612            (snapshot, state.info.clone())
613        };
614
615        let (snapshot, info) = snapshot;
616        self.notify_snapshot(Some(snapshot));
617        Ok(info)
618    }
619
620    /// Checks if cancellation has been requested for a task.
621    #[must_use]
622    pub fn is_cancel_requested(&self, task_id: &TaskId) -> bool {
623        let tasks = self.tasks.read().unwrap_or_else(|poisoned| {
624            warn!(target: targets::SERVER, "tasks lock poisoned in is_cancel_requested, recovering");
625            poisoned.into_inner()
626        });
627        tasks.get(task_id).is_some_and(|s| s.cancel_requested)
628    }
629
630    /// Returns the number of active (non-terminal) tasks.
631    #[must_use]
632    pub fn active_count(&self) -> usize {
633        let tasks = self.tasks.read().unwrap_or_else(|poisoned| {
634            warn!(target: targets::SERVER, "tasks lock poisoned in active_count, recovering");
635            poisoned.into_inner()
636        });
637        tasks.values().filter(|s| s.info.status.is_active()).count()
638    }
639
640    /// Returns the total number of tasks.
641    #[must_use]
642    pub fn total_count(&self) -> usize {
643        let tasks = self.tasks.read().unwrap_or_else(|poisoned| {
644            warn!(target: targets::SERVER, "tasks lock poisoned in total_count, recovering");
645            poisoned.into_inner()
646        });
647        tasks.len()
648    }
649
650    /// Removes completed tasks older than the specified duration.
651    ///
652    /// This is useful for preventing unbounded memory growth from completed tasks.
653    pub fn cleanup_completed(&self, max_age: std::time::Duration) {
654        let cutoff = chrono::Utc::now() - chrono::Duration::from_std(max_age).unwrap_or_default();
655
656        let mut tasks = self.tasks.write().unwrap_or_else(|poisoned| {
657            warn!(target: targets::SERVER, "tasks lock poisoned in cleanup_completed, recovering");
658            poisoned.into_inner()
659        });
660        tasks.retain(|_, state| {
661            // Keep active tasks
662            if state.info.status.is_active() {
663                return true;
664            }
665
666            // Keep recent completed tasks
667            if let Some(ref completed) = state.info.completed_at {
668                if let Ok(parsed) = chrono::DateTime::parse_from_rfc3339(completed) {
669                    return parsed.with_timezone(&chrono::Utc) > cutoff;
670                }
671                return true;
672            }
673
674            true
675        });
676    }
677
678    fn notify_snapshot(&self, snapshot: Option<TaskStatusSnapshot>) {
679        if let Some(snapshot) = snapshot {
680            self.notify_status(snapshot.info, snapshot.result);
681        }
682    }
683
684    fn notify_status(&self, info: TaskInfo, result: Option<TaskResult>) {
685        let sender = {
686            let guard = self.notification_sender.read().unwrap_or_else(|poisoned| {
687                warn!(target: targets::SERVER, "notification sender lock poisoned in notify_status, recovering");
688                poisoned.into_inner()
689            });
690            guard.clone()
691        };
692        let Some(sender) = sender else {
693            return;
694        };
695
696        let params = TaskStatusNotificationParams {
697            id: info.id.clone(),
698            status: info.status,
699            progress: info.progress,
700            message: info.message.clone(),
701            error: info.error.clone(),
702            result,
703        };
704        let payload = match serde_json::to_value(params) {
705            Ok(value) => value,
706            Err(err) => {
707                warn!(
708                    target: targets::SERVER,
709                    "failed to serialize task status notification: {}",
710                    err
711                );
712                return;
713            }
714        };
715        sender(JsonRpcRequest::notification(
716            "notifications/tasks/status",
717            Some(payload),
718        ));
719    }
720}
721
722#[derive(Debug, Clone)]
723struct TaskStatusSnapshot {
724    info: TaskInfo,
725    result: Option<TaskResult>,
726}
727
728impl TaskStatusSnapshot {
729    fn from(state: &TaskState) -> Self {
730        Self {
731            info: state.info.clone(),
732            result: state.result.clone(),
733        }
734    }
735}
736
737fn notify_snapshot(
738    sender: &Arc<RwLock<Option<TaskNotificationSender>>>,
739    snapshot: Option<TaskStatusSnapshot>,
740) {
741    let Some(snapshot) = snapshot else {
742        return;
743    };
744    let sender = {
745        let guard = sender.read().unwrap_or_else(|poisoned| {
746            warn!(target: targets::SERVER, "notification sender lock poisoned in notify_snapshot, recovering");
747            poisoned.into_inner()
748        });
749        guard.clone()
750    };
751    let Some(sender) = sender else {
752        return;
753    };
754    let params = TaskStatusNotificationParams {
755        id: snapshot.info.id.clone(),
756        status: snapshot.info.status,
757        progress: snapshot.info.progress,
758        message: snapshot.info.message.clone(),
759        error: snapshot.info.error.clone(),
760        result: snapshot.result,
761    };
762    let payload = match serde_json::to_value(params) {
763        Ok(value) => value,
764        Err(err) => {
765            warn!(
766                target: targets::SERVER,
767                "failed to serialize task status notification: {}",
768                err
769            );
770            return;
771        }
772    };
773    sender(JsonRpcRequest::notification(
774        "notifications/tasks/status",
775        Some(payload),
776    ));
777}
778
779impl Default for TaskManager {
780    fn default() -> Self {
781        Self::new()
782    }
783}
784
785impl std::fmt::Debug for TaskManager {
786    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
787        // Use poison recovery to avoid panic during Debug formatting
788        let task_count = self
789            .tasks
790            .read()
791            .map(|g| g.len())
792            .unwrap_or_else(|poisoned| poisoned.into_inner().len());
793        let handler_count = self
794            .handlers
795            .read()
796            .map(|g| g.len())
797            .unwrap_or_else(|poisoned| poisoned.into_inner().len());
798        f.debug_struct("TaskManager")
799            .field("task_count", &task_count)
800            .field("handler_count", &handler_count)
801            .field("task_counter", &self.task_counter.load(Ordering::SeqCst))
802            .field(
803                "list_changed_notifications",
804                &self.list_changed_notifications,
805            )
806            .field("auto_execute", &self.auto_execute)
807            .finish_non_exhaustive()
808    }
809}
810
811/// Thread-safe handle to a TaskManager.
812pub type SharedTaskManager = Arc<TaskManager>;
813
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;
    use std::thread;

    use serde_json::json;

    /// A manager built with `new()` starts empty and without list-changed notifications.
    #[test]
    fn test_task_manager_creation() {
        let fresh = TaskManager::new();
        assert_eq!(fresh.total_count(), 0);
        assert_eq!(fresh.active_count(), 0);
        assert!(!fresh.has_list_changed_notifications());
    }

    /// The notification-enabled constructor reports the flag as set.
    #[test]
    fn test_task_manager_with_notifications() {
        let notifying = TaskManager::with_list_changed_notifications();
        assert!(notifying.has_list_changed_notifications());
    }

    /// Submitting a task type succeeds once a handler has been registered for it.
    #[test]
    fn test_register_handler() {
        let mgr = TaskManager::new();
        mgr.register_handler("test_task", |_ctx, _args| async { Ok(json!({})) });

        let ctx = Cx::for_testing();
        assert!(mgr.submit(&ctx, "test_task", None).is_ok());
    }

    /// Submitting a task type with no registered handler is an error.
    #[test]
    fn test_submit_unknown_task_type() {
        let mgr = TaskManager::new();
        let ctx = Cx::for_testing();

        assert!(mgr.submit(&ctx, "unknown_task", None).is_err());
    }

    /// Walks one task through submit -> start -> progress -> complete and checks
    /// the reported info and the final result at every step.
    #[test]
    fn test_task_lifecycle() {
        let mgr = TaskManager::new_for_testing();
        let ctx = Cx::for_testing();

        mgr.register_handler("test", |_ctx, _args| async { Ok(json!({"done": true})) });

        // Freshly submitted: pending, never started.
        let id = mgr.submit(&ctx, "test", None).unwrap();
        let pending = mgr.get_info(&id).unwrap();
        assert_eq!(pending.status, TaskStatus::Pending);
        assert!(pending.started_at.is_none());

        // Started: running, with a start timestamp.
        mgr.start_task(&id).unwrap();
        let running = mgr.get_info(&id).unwrap();
        assert_eq!(running.status, TaskStatus::Running);
        assert!(running.started_at.is_some());

        // Progress and message are stored as given.
        mgr.update_progress(&id, 0.5, Some(String::from("Halfway done")));
        let halfway = mgr.get_info(&id).unwrap();
        assert_eq!(halfway.progress, Some(0.5));
        assert_eq!(halfway.message, Some("Halfway done".into()));

        // Completed: terminal status plus a completion timestamp.
        mgr.complete_task(&id, json!({"result": 42}));
        let finished = mgr.get_info(&id).unwrap();
        assert_eq!(finished.status, TaskStatus::Completed);
        assert!(finished.completed_at.is_some());

        // The stored result carries the payload passed to `complete_task`.
        let outcome = mgr.get_result(&id).unwrap();
        assert!(outcome.success);
        assert_eq!(outcome.data, Some(json!({"result": 42})));
    }

    /// Failing a running task records the error on both the info and the result.
    #[test]
    fn test_task_failure() {
        let mgr = TaskManager::new_for_testing();
        let ctx = Cx::for_testing();

        mgr.register_handler("fail_test", |_ctx, _args| async { Ok(json!({})) });

        let id = mgr.submit(&ctx, "fail_test", None).unwrap();
        mgr.start_task(&id).unwrap();
        mgr.fail_task(&id, "Something went wrong");

        let failed = mgr.get_info(&id).unwrap();
        assert_eq!(failed.status, TaskStatus::Failed);
        assert_eq!(failed.error, Some("Something went wrong".into()));

        let outcome = mgr.get_result(&id).unwrap();
        assert!(!outcome.success);
        assert_eq!(outcome.error, Some("Something went wrong".into()));
    }

    /// Cancelling a running task marks it cancelled, sets the cancel flag, and
    /// rejects a second cancellation.
    #[test]
    fn test_task_cancellation() {
        let mgr = TaskManager::new_for_testing();
        let ctx = Cx::for_testing();

        mgr.register_handler("cancel_test", |_ctx, _args| async { Ok(json!({})) });

        let id = mgr.submit(&ctx, "cancel_test", None).unwrap();
        mgr.start_task(&id).unwrap();

        // First cancellation succeeds and reports the new status.
        let cancelled = mgr.cancel(&id, Some("User cancelled".into())).unwrap();
        assert_eq!(cancelled.status, TaskStatus::Cancelled);

        // The cancel flag is observable afterwards.
        assert!(mgr.is_cancel_requested(&id));

        // A second cancellation of the same task is refused.
        assert!(mgr.cancel(&id, None).is_err());
    }

    /// `list_tasks` filters by status and returns everything when given `None`.
    #[test]
    fn test_list_tasks() {
        let mgr = TaskManager::new_for_testing();
        let ctx = Cx::for_testing();

        mgr.register_handler("list_test", |_ctx, _args| async { Ok(json!({})) });

        let first = mgr.submit(&ctx, "list_test", None).unwrap();
        let second = mgr.submit(&ctx, "list_test", None).unwrap();
        let _third = mgr.submit(&ctx, "list_test", None).unwrap();

        let count_with = |status: Option<TaskStatus>| mgr.list_tasks(status).len();

        // Nothing has been started yet.
        assert_eq!(count_with(Some(TaskStatus::Pending)), 3);
        assert_eq!(count_with(Some(TaskStatus::Running)), 0);

        // One task moves from pending to running.
        mgr.start_task(&first).unwrap();
        assert_eq!(count_with(Some(TaskStatus::Pending)), 2);
        assert_eq!(count_with(Some(TaskStatus::Running)), 1);

        // Another task is started and completed.
        mgr.start_task(&second).unwrap();
        mgr.complete_task(&second, json!({}));
        assert_eq!(count_with(Some(TaskStatus::Completed)), 1);

        // No filter: every task is listed.
        assert_eq!(count_with(None), 3);
    }

    /// `active_count` drops as tasks complete or are cancelled, while
    /// `total_count` keeps counting them.
    #[test]
    fn test_active_count() {
        let mgr = TaskManager::new_for_testing();
        let ctx = Cx::for_testing();

        mgr.register_handler("count_test", |_ctx, _args| async { Ok(json!({})) });

        let first = mgr.submit(&ctx, "count_test", None).unwrap();
        let second = mgr.submit(&ctx, "count_test", None).unwrap();

        assert_eq!(mgr.active_count(), 2);
        assert_eq!(mgr.total_count(), 2);

        // Starting a task does not change the active count.
        mgr.start_task(&first).unwrap();
        assert_eq!(mgr.active_count(), 2);

        mgr.complete_task(&first, json!({}));
        assert_eq!(mgr.active_count(), 1);

        mgr.cancel(&second, None).unwrap();
        assert_eq!(mgr.active_count(), 0);
        assert_eq!(mgr.total_count(), 2);
    }

    /// Progress values outside `[0.0, 1.0]` are clamped; values inside are kept.
    #[test]
    fn test_progress_clamping() {
        let mgr = TaskManager::new_for_testing();
        let ctx = Cx::for_testing();

        mgr.register_handler("clamp_test", |_ctx, _args| async { Ok(json!({})) });

        let id = mgr.submit(&ctx, "clamp_test", None).unwrap();
        mgr.start_task(&id).unwrap();

        // (value passed to update_progress, value expected back from get_info)
        for (requested, stored) in [(-0.5, 0.0), (1.5, 1.0), (0.75, 0.75)] {
            mgr.update_progress(&id, requested, None);
            assert_eq!(mgr.get_info(&id).unwrap().progress, Some(stored));
        }
    }

    /// Out-of-order state transitions are either ignored or rejected.
    #[test]
    fn test_invalid_transition_rejected() {
        let mgr = TaskManager::new_for_testing();
        let ctx = Cx::for_testing();

        mgr.register_handler("transition_test", |_ctx, _args| async { Ok(json!({})) });

        let id = mgr.submit(&ctx, "transition_test", None).unwrap();

        // Completing a task that was never started leaves it pending.
        mgr.complete_task(&id, json!({"result": "noop"}));
        assert_eq!(mgr.get_info(&id).unwrap().status, TaskStatus::Pending);

        // The regular start -> complete path still works afterwards.
        mgr.start_task(&id).unwrap();
        mgr.complete_task(&id, json!({"result": "ok"}));
        assert_eq!(mgr.get_info(&id).unwrap().status, TaskStatus::Completed);

        // A completed task cannot be started again.
        assert!(mgr.start_task(&id).is_err());
    }

    /// Four threads submitting ten tasks each end up with forty pending tasks.
    #[test]
    fn test_concurrent_submissions() {
        let shared = Arc::new(TaskManager::new_for_testing());
        shared.register_handler("concurrent_test", |_ctx, _args| async { Ok(json!({})) });

        // Spawn every worker before joining any of them.
        let workers: Vec<_> = (0..4)
            .map(|_| {
                let mgr = Arc::clone(&shared);
                thread::spawn(move || {
                    let ctx = Cx::for_testing();
                    for _ in 0..10 {
                        let _ = mgr.submit(&ctx, "concurrent_test", None).unwrap();
                    }
                })
            })
            .collect();

        for worker in workers {
            worker.join().expect("thread join failed");
        }

        assert_eq!(shared.total_count(), 40);
        assert_eq!(shared.list_tasks(Some(TaskStatus::Pending)).len(), 40);
    }

    /// Status notifications are emitted through the configured sender as a task
    /// moves through its lifecycle.
    #[test]
    fn test_task_status_notifications() {
        let mgr = TaskManager::new_for_testing();
        mgr.register_handler("notify_test", |_ctx, _args| async { Ok(json!({"ok": true})) });

        // Every task-status notification the sender receives is decoded and stored here.
        let events = Arc::new(std::sync::Mutex::new(
            Vec::<TaskStatusNotificationParams>::new(),
        ));
        let sink = Arc::clone(&events);
        let sender: TaskNotificationSender = Arc::new(move |request: JsonRpcRequest| {
            if request.method == "notifications/tasks/status" {
                // Decode before taking the lock so a decode panic cannot poison it.
                let decoded: TaskStatusNotificationParams = request
                    .params
                    .as_ref()
                    .and_then(|raw| serde_json::from_value(raw.clone()).ok())
                    .expect("task status params");
                sink.lock().expect("events lock poisoned").push(decoded);
            }
        });
        mgr.set_notification_sender(sender);

        let ctx = Cx::for_testing();
        let id = mgr.submit(&ctx, "notify_test", None).unwrap();
        mgr.start_task(&id).unwrap();
        mgr.update_progress(&id, 0.5, Some("half".to_string()));
        mgr.complete_task(&id, json!({"result": 1}));

        let seen = events.lock().expect("events lock poisoned").clone();
        assert!(!seen.is_empty(), "expected task status notifications");
        assert_eq!(seen[0].id, id);
        assert_eq!(seen[0].status, TaskStatus::Pending);
        assert_eq!(seen[1].status, TaskStatus::Running);
        assert_eq!(seen[2].progress, Some(0.5));
        assert_eq!(seen.last().expect("last").status, TaskStatus::Completed);
    }
}