1use crate::errors::{ClaudeError, Result};
59use serde::{Deserialize, Serialize};
60use std::collections::HashMap;
61use std::sync::Arc;
62use tokio::sync::RwLock;
63use uuid::Uuid;
64
/// Identifier of a task; tasks created by `TaskManager` get a random UUID v4 string.
pub type TaskId = String;

/// Address of a task, formed by `TaskManager` as `<base_uri>/<task id>`.
pub type TaskUri = String;
70
71#[derive(Debug, Clone, Serialize, Deserialize, Default)]
73pub struct TaskRequest {
74 #[serde(default)]
76 pub method: String,
77 #[serde(default)]
79 pub params: serde_json::Value,
80 #[serde(skip_serializing_if = "Option::is_none")]
82 pub task_hint: Option<TaskHint>,
83 #[serde(skip_serializing_if = "Option::is_none")]
85 pub priority: Option<TaskPriority>,
86}
87
88#[derive(Debug, Clone, Serialize, Deserialize)]
90pub struct TaskHint {
91 #[serde(skip_serializing_if = "Option::is_none")]
93 pub estimated_duration_secs: Option<u64>,
94 #[serde(default)]
96 pub supports_progress: bool,
97 #[serde(default)]
99 pub cancellable: bool,
100}
101
102impl Default for TaskHint {
103 fn default() -> Self {
104 Self {
105 estimated_duration_secs: None,
106 supports_progress: false,
107 cancellable: true,
108 }
109 }
110}
111
112#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq, PartialOrd, Ord)]
114#[serde(rename_all = "lowercase")]
115pub enum TaskPriority {
116 Low,
117 Normal,
118 High,
119 Urgent,
120}
121
122impl Default for TaskPriority {
123 fn default() -> Self {
124 Self::Normal
125 }
126}
127
128#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
130#[serde(rename_all = "lowercase")]
131pub enum TaskState {
132 Queued,
133 Working,
134 InputRequired,
135 Completed,
136 Failed,
137 Cancelled,
138}
139
140impl TaskState {
141 pub fn is_terminal(&self) -> bool {
143 matches!(self, Self::Completed | Self::Failed | Self::Cancelled)
144 }
145
146 pub fn is_active(&self) -> bool {
148 matches!(self, Self::Queued | Self::Working | Self::InputRequired)
149 }
150}
151
152#[derive(Debug, Clone, Serialize, Deserialize)]
154pub struct TaskProgress {
155 pub value: f64,
157 #[serde(skip_serializing_if = "Option::is_none")]
159 pub message: Option<String>,
160}
161
162impl TaskProgress {
163 pub fn new(value: f64) -> Self {
165 assert!(
166 (0.0..=1.0).contains(&value),
167 "Progress must be between 0.0 and 1.0"
168 );
169 Self {
170 value,
171 message: None,
172 }
173 }
174
175 pub fn with_message(mut self, message: impl Into<String>) -> Self {
177 self.message = Some(message.into());
178 self
179 }
180}
181
182#[derive(Debug, Clone, Serialize, Deserialize)]
184pub struct TaskStatus {
185 pub id: TaskId,
187 pub state: TaskState,
189 #[serde(skip_serializing_if = "Option::is_none")]
191 pub progress: Option<TaskProgress>,
192 #[serde(skip_serializing_if = "Option::is_none")]
194 pub error: Option<String>,
195 pub created_at: chrono::DateTime<chrono::Utc>,
197 pub updated_at: chrono::DateTime<chrono::Utc>,
199 #[serde(skip_serializing_if = "Option::is_none")]
201 pub completed_at: Option<chrono::DateTime<chrono::Utc>>,
202}
203
204impl TaskStatus {
205 pub fn is_terminal(&self) -> bool {
207 self.state.is_terminal()
208 }
209
210 pub fn is_active(&self) -> bool {
212 self.state.is_active()
213 }
214}
215
216#[derive(Debug, Clone, Serialize, Deserialize)]
218pub struct TaskResult {
219 pub id: TaskId,
221 pub data: serde_json::Value,
223 pub completed_at: chrono::DateTime<chrono::Utc>,
225}
226
227#[derive(Debug, Clone, Serialize, Deserialize)]
229pub struct TaskHandle {
230 pub id: TaskId,
232 pub uri: TaskUri,
234 pub status: TaskStatus,
236}
237
238#[derive(Debug, Clone)]
240struct Task {
241 id: TaskId,
242 request: TaskRequest,
243 state: TaskState,
244 progress: Option<TaskProgress>,
245 result: Option<serde_json::Value>,
246 error: Option<String>,
247 created_at: chrono::DateTime<chrono::Utc>,
248 updated_at: chrono::DateTime<chrono::Utc>,
249 completed_at: Option<chrono::DateTime<chrono::Utc>>,
250}
251
252impl Task {
253 fn new(request: TaskRequest) -> Self {
254 let now = chrono::Utc::now();
255 Self {
256 id: Uuid::new_v4().to_string(),
257 request,
258 state: TaskState::Queued,
259 progress: None,
260 result: None,
261 error: None,
262 created_at: now,
263 updated_at: now,
264 completed_at: None,
265 }
266 }
267
268 fn to_status(&self) -> TaskStatus {
269 TaskStatus {
270 id: self.id.clone(),
271 state: self.state.clone(),
272 progress: self.progress.clone(),
273 error: self.error.clone(),
274 created_at: self.created_at,
275 updated_at: self.updated_at,
276 completed_at: self.completed_at,
277 }
278 }
279}
280
281#[derive(Clone)]
286pub struct TaskManager {
287 tasks: Arc<RwLock<HashMap<TaskId, Task>>>,
288 base_uri: String,
289}
290
291impl TaskManager {
292 pub fn new() -> Self {
294 Self::with_base_uri("mcp://tasks".to_string())
295 }
296
297 pub fn with_base_uri(base_uri: impl Into<String>) -> Self {
299 Self {
300 tasks: Arc::new(RwLock::new(HashMap::new())),
301 base_uri: base_uri.into(),
302 }
303 }
304
305 pub async fn create_task(&self, request: TaskRequest) -> Result<TaskHandle> {
309 let task = Task::new(request);
310 let task_id = task.id.clone();
311 let uri = format!("{}/{}", self.base_uri, task_id);
312 let status = task.to_status();
313
314 let mut tasks = self.tasks.write().await;
316 tasks.insert(task_id.clone(), task);
317
318 Ok(TaskHandle {
319 id: task_id,
320 uri,
321 status,
322 })
323 }
324
325 pub async fn get_task_status(&self, task_id: &TaskId) -> Result<TaskStatus> {
327 let tasks = self.tasks.read().await;
328 let task = tasks
329 .get(task_id)
330 .ok_or_else(|| ClaudeError::NotFound(format!("Task not found: {}", task_id)))?;
331
332 Ok(task.to_status())
333 }
334
335 pub async fn get_task_result(&self, task_id: &TaskId) -> Result<TaskResult> {
339 let tasks = self.tasks.read().await;
340 let task = tasks
341 .get(task_id)
342 .ok_or_else(|| ClaudeError::NotFound(format!("Task not found: {}", task_id)))?;
343
344 if task.state != TaskState::Completed {
345 return Err(ClaudeError::InvalidInput(format!(
346 "Task is not completed. Current state: {:?}",
347 task.state
348 )));
349 }
350
351 let result = task.result.as_ref().ok_or_else(|| {
352 ClaudeError::InternalError("Completed task has no result".to_string())
353 })?;
354
355 Ok(TaskResult {
356 id: task_id.clone(),
357 data: result.clone(),
358 completed_at: task.completed_at.unwrap(),
359 })
360 }
361
362 pub async fn update_progress(&self, task_id: &TaskId, progress: TaskProgress) -> Result<()> {
366 let mut tasks = self.tasks.write().await;
367 let task = tasks
368 .get_mut(task_id)
369 .ok_or_else(|| ClaudeError::NotFound(format!("Task not found: {}", task_id)))?;
370
371 if task.state.is_terminal() {
372 return Err(ClaudeError::InvalidInput(
373 "Cannot update progress for terminal task".to_string(),
374 ));
375 }
376
377 task.progress = Some(progress);
378 task.updated_at = chrono::Utc::now();
379
380 Ok(())
381 }
382
383 pub async fn mark_working(&self, task_id: &TaskId) -> Result<()> {
385 let mut tasks = self.tasks.write().await;
386 let task = tasks
387 .get_mut(task_id)
388 .ok_or_else(|| ClaudeError::NotFound(format!("Task not found: {}", task_id)))?;
389
390 if task.state.is_terminal() {
391 return Err(ClaudeError::InvalidInput(
392 "Cannot transition terminal task".to_string(),
393 ));
394 }
395
396 task.state = TaskState::Working;
397 task.updated_at = chrono::Utc::now();
398
399 Ok(())
400 }
401
402 pub async fn mark_completed(&self, task_id: &TaskId, result: serde_json::Value) -> Result<()> {
404 let mut tasks = self.tasks.write().await;
405 let task = tasks
406 .get_mut(task_id)
407 .ok_or_else(|| ClaudeError::NotFound(format!("Task not found: {}", task_id)))?;
408
409 if task.state.is_terminal() {
410 return Err(ClaudeError::InvalidInput(
411 "Cannot transition terminal task".to_string(),
412 ));
413 }
414
415 let now = chrono::Utc::now();
416 task.state = TaskState::Completed;
417 task.result = Some(result);
418 task.updated_at = now;
419 task.completed_at = Some(now);
420
421 Ok(())
422 }
423
424 pub async fn mark_failed(&self, task_id: &TaskId, error: impl Into<String>) -> Result<()> {
426 let mut tasks = self.tasks.write().await;
427 let task = tasks
428 .get_mut(task_id)
429 .ok_or_else(|| ClaudeError::NotFound(format!("Task not found: {}", task_id)))?;
430
431 if task.state.is_terminal() {
432 return Err(ClaudeError::InvalidInput(
433 "Cannot transition terminal task".to_string(),
434 ));
435 }
436
437 let now = chrono::Utc::now();
438 task.state = TaskState::Failed;
439 task.error = Some(error.into());
440 task.updated_at = now;
441 task.completed_at = Some(now);
442
443 Ok(())
444 }
445
446 pub async fn mark_cancelled(&self, task_id: &TaskId) -> Result<()> {
448 let mut tasks = self.tasks.write().await;
449 let task = tasks
450 .get_mut(task_id)
451 .ok_or_else(|| ClaudeError::NotFound(format!("Task not found: {}", task_id)))?;
452
453 if task.state.is_terminal() {
454 return Err(ClaudeError::InvalidInput(
455 "Cannot transition terminal task".to_string(),
456 ));
457 }
458
459 let now = chrono::Utc::now();
460 task.state = TaskState::Cancelled;
461 task.updated_at = now;
462 task.completed_at = Some(now);
463
464 Ok(())
465 }
466
467 pub async fn mark_input_required(&self, task_id: &TaskId) -> Result<()> {
469 let mut tasks = self.tasks.write().await;
470 let task = tasks
471 .get_mut(task_id)
472 .ok_or_else(|| ClaudeError::NotFound(format!("Task not found: {}", task_id)))?;
473
474 if task.state.is_terminal() {
475 return Err(ClaudeError::InvalidInput(
476 "Cannot transition terminal task".to_string(),
477 ));
478 }
479
480 task.state = TaskState::InputRequired;
481 task.updated_at = chrono::Utc::now();
482
483 Ok(())
484 }
485
486 pub async fn list_tasks(&self) -> Result<Vec<TaskStatus>> {
488 let tasks = self.tasks.read().await;
489 Ok(tasks.values().map(|t| t.to_status()).collect())
490 }
491
492 pub async fn cancel_task(&self, task_id: &TaskId) -> Result<()> {
497 let mut tasks = self.tasks.write().await;
498 let task = tasks
499 .get_mut(task_id)
500 .ok_or_else(|| ClaudeError::NotFound(format!("Task not found: {}", task_id)))?;
501
502 if task.state.is_terminal() {
503 return Err(ClaudeError::InvalidInput(format!(
504 "Cannot cancel task in state: {:?}",
505 task.state
506 )));
507 }
508
509 if let Some(hint) = &task.request.task_hint {
511 if !hint.cancellable {
512 return Err(ClaudeError::InvalidInput(
513 "Task is not cancellable".to_string(),
514 ));
515 }
516 }
517
518 let now = chrono::Utc::now();
519 task.state = TaskState::Cancelled;
520 task.updated_at = now;
521 task.completed_at = Some(now);
522
523 Ok(())
524 }
525
526 pub async fn cleanup_old_tasks(&self, older_than: chrono::Duration) -> Result<usize> {
530 let mut tasks = self.tasks.write().await;
531 let cutoff = chrono::Utc::now() - older_than;
532
533 let initial_count = tasks.len();
534 tasks.retain(|_, task| {
535 if let Some(completed_at) = task.completed_at {
536 completed_at > cutoff
537 } else {
538 true }
540 });
541
542 Ok(initial_count - tasks.len())
543 }
544}
545
546impl Default for TaskManager {
547 fn default() -> Self {
548 Self::new()
549 }
550}
551
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Builds a `tools/call` request with the given params and no hint or priority.
    fn request(params: serde_json::Value) -> TaskRequest {
        TaskRequest {
            method: "tools/call".to_string(),
            params,
            ..Default::default()
        }
    }

    /// Builds a `tools/call` request whose hint only sets `cancellable`.
    fn hinted_request(cancellable: bool) -> TaskRequest {
        TaskRequest {
            task_hint: Some(TaskHint {
                cancellable,
                ..Default::default()
            }),
            ..request(json!({}))
        }
    }

    #[tokio::test]
    async fn test_task_creation() {
        let manager = TaskManager::new();

        let handle = manager
            .create_task(request(json!({"name": "test"})))
            .await
            .unwrap();
        assert!(!handle.id.is_empty());
        assert_eq!(handle.uri, format!("mcp://tasks/{}", handle.id));
        assert_eq!(handle.status.state, TaskState::Queued);
    }

    #[tokio::test]
    async fn test_custom_base_uri() {
        let manager = TaskManager::with_base_uri("app://jobs");
        let handle = manager.create_task(request(json!({}))).await.unwrap();
        assert_eq!(handle.uri, format!("app://jobs/{}", handle.id));
    }

    #[tokio::test]
    async fn test_task_status() {
        let manager = TaskManager::new();

        let handle = manager.create_task(request(json!({}))).await.unwrap();
        let status = manager.get_task_status(&handle.id).await.unwrap();

        assert_eq!(status.id, handle.id);
        assert_eq!(status.state, TaskState::Queued);
        assert!(status.is_active());
        assert!(!status.is_terminal());
    }

    #[tokio::test]
    async fn test_task_lifecycle() {
        let manager = TaskManager::new();
        let handle = manager.create_task(request(json!({}))).await.unwrap();

        manager.mark_working(&handle.id).await.unwrap();
        let status = manager.get_task_status(&handle.id).await.unwrap();
        assert_eq!(status.state, TaskState::Working);

        let progress = TaskProgress::new(0.5).with_message("Half done");
        manager.update_progress(&handle.id, progress).await.unwrap();
        let status = manager.get_task_status(&handle.id).await.unwrap();
        let reported = status.progress.as_ref().unwrap();
        assert_eq!(reported.value, 0.5);
        assert_eq!(reported.message.as_deref(), Some("Half done"));

        manager
            .mark_completed(&handle.id, json!({"output": "success"}))
            .await
            .unwrap();
        let status = manager.get_task_status(&handle.id).await.unwrap();
        assert_eq!(status.state, TaskState::Completed);
        assert!(status.is_terminal());
        assert!(!status.is_active());
        assert!(status.completed_at.is_some());

        let task_result = manager.get_task_result(&handle.id).await.unwrap();
        assert_eq!(task_result.id, handle.id);
        assert_eq!(task_result.data, json!({"output": "success"}));
        assert_eq!(Some(task_result.completed_at), status.completed_at);
    }

    #[tokio::test]
    async fn test_input_required_then_resume() {
        let manager = TaskManager::new();
        let handle = manager.create_task(request(json!({}))).await.unwrap();

        manager.mark_input_required(&handle.id).await.unwrap();
        let status = manager.get_task_status(&handle.id).await.unwrap();
        assert_eq!(status.state, TaskState::InputRequired);
        assert!(status.is_active());

        manager.mark_working(&handle.id).await.unwrap();
        let status = manager.get_task_status(&handle.id).await.unwrap();
        assert_eq!(status.state, TaskState::Working);
    }

    #[tokio::test]
    async fn test_task_failure() {
        let manager = TaskManager::new();
        let handle = manager.create_task(request(json!({}))).await.unwrap();

        manager
            .mark_failed(&handle.id, "Something went wrong")
            .await
            .unwrap();

        let status = manager.get_task_status(&handle.id).await.unwrap();
        assert_eq!(status.state, TaskState::Failed);
        assert!(status.is_terminal());
        assert_eq!(status.error.as_deref(), Some("Something went wrong"));
        assert!(manager.get_task_result(&handle.id).await.is_err());
    }

    #[tokio::test]
    async fn test_result_unavailable_before_completion() {
        let manager = TaskManager::new();
        let handle = manager.create_task(request(json!({}))).await.unwrap();

        assert!(manager.get_task_result(&handle.id).await.is_err());
        manager.mark_working(&handle.id).await.unwrap();
        assert!(manager.get_task_result(&handle.id).await.is_err());
    }

    #[tokio::test]
    async fn test_task_cancellation() {
        let manager = TaskManager::new();

        let handle = manager.create_task(hinted_request(true)).await.unwrap();
        manager.cancel_task(&handle.id).await.unwrap();

        let status = manager.get_task_status(&handle.id).await.unwrap();
        assert_eq!(status.state, TaskState::Cancelled);
        assert!(status.is_terminal());
        assert!(status.completed_at.is_some());
    }

    #[tokio::test]
    async fn test_cancel_without_hint() {
        let manager = TaskManager::new();
        let handle = manager.create_task(request(json!({}))).await.unwrap();

        manager.cancel_task(&handle.id).await.unwrap();
        let status = manager.get_task_status(&handle.id).await.unwrap();
        assert_eq!(status.state, TaskState::Cancelled);
    }

    #[tokio::test]
    async fn test_non_cancellable_task() {
        let manager = TaskManager::new();

        let handle = manager.create_task(hinted_request(false)).await.unwrap();
        assert!(manager.cancel_task(&handle.id).await.is_err());

        // A rejected cancellation leaves the task untouched.
        let status = manager.get_task_status(&handle.id).await.unwrap();
        assert_eq!(status.state, TaskState::Queued);
    }

    #[tokio::test]
    async fn test_mark_cancelled_ignores_hint() {
        let manager = TaskManager::new();
        let handle = manager.create_task(hinted_request(false)).await.unwrap();

        manager.mark_cancelled(&handle.id).await.unwrap();
        let status = manager.get_task_status(&handle.id).await.unwrap();
        assert_eq!(status.state, TaskState::Cancelled);
    }

    #[tokio::test]
    async fn test_terminal_state_transitions() {
        let manager = TaskManager::new();
        let handle = manager.create_task(request(json!({}))).await.unwrap();

        manager.mark_completed(&handle.id, json!({})).await.unwrap();

        assert!(manager.mark_working(&handle.id).await.is_err());
        assert!(manager
            .update_progress(&handle.id, TaskProgress::new(0.5))
            .await
            .is_err());
        assert!(manager.mark_input_required(&handle.id).await.is_err());
        assert!(manager.mark_completed(&handle.id, json!({})).await.is_err());
        assert!(manager.mark_failed(&handle.id, "late").await.is_err());
        assert!(manager.mark_cancelled(&handle.id).await.is_err());
        assert!(manager.cancel_task(&handle.id).await.is_err());

        let status = manager.get_task_status(&handle.id).await.unwrap();
        assert_eq!(status.state, TaskState::Completed);
    }

    #[tokio::test]
    async fn test_unknown_task() {
        let manager = TaskManager::new();
        let missing: TaskId = "missing".to_string();

        assert!(manager.get_task_status(&missing).await.is_err());
        assert!(manager.get_task_result(&missing).await.is_err());
        assert!(manager.mark_working(&missing).await.is_err());
        assert!(manager
            .update_progress(&missing, TaskProgress::new(0.1))
            .await
            .is_err());
        assert!(manager.cancel_task(&missing).await.is_err());
    }

    #[tokio::test]
    async fn test_list_tasks() {
        let manager = TaskManager::new();

        let _task1 = manager.create_task(request(json!({}))).await.unwrap();
        let _task2 = manager.create_task(request(json!({}))).await.unwrap();

        let tasks = manager.list_tasks().await.unwrap();
        assert_eq!(tasks.len(), 2);
    }

    #[test]
    fn test_progress_bounds() {
        assert_eq!(TaskProgress::new(0.0).value, 0.0);
        assert_eq!(TaskProgress::new(0.5).value, 0.5);
        assert_eq!(TaskProgress::new(1.0).value, 1.0);
        assert!(TaskProgress::new(0.5).message.is_none());
    }

    #[test]
    #[should_panic(expected = "Progress must be between 0.0 and 1.0")]
    fn test_progress_above_range_panics() {
        let _ = TaskProgress::new(1.5);
    }

    #[test]
    #[should_panic(expected = "Progress must be between 0.0 and 1.0")]
    fn test_progress_nan_panics() {
        let _ = TaskProgress::new(f64::NAN);
    }

    #[test]
    fn test_priority_ordering() {
        assert!(TaskPriority::Low < TaskPriority::Normal);
        assert!(TaskPriority::Normal < TaskPriority::High);
        assert!(TaskPriority::High < TaskPriority::Urgent);
        assert_eq!(TaskPriority::default(), TaskPriority::Normal);
    }

    #[test]
    fn test_hint_defaults() {
        let hint = TaskHint::default();
        assert!(hint.cancellable);
        assert!(!hint.supports_progress);
        assert!(hint.estimated_duration_secs.is_none());
    }

    #[tokio::test]
    async fn test_cleanup_old_tasks() {
        let manager = TaskManager::new();
        let handle = manager.create_task(request(json!({}))).await.unwrap();
        manager.mark_completed(&handle.id, json!({})).await.unwrap();

        // Completed moments ago: not older than one second.
        let cleaned = manager
            .cleanup_old_tasks(chrono::Duration::seconds(1))
            .await
            .unwrap();
        assert_eq!(cleaned, 0);

        // A zero-length window removes everything already finished.
        let cleaned = manager
            .cleanup_old_tasks(chrono::Duration::seconds(0))
            .await
            .unwrap();
        assert_eq!(cleaned, 1);
    }

    #[tokio::test]
    async fn test_cleanup_keeps_unfinished_tasks() {
        let manager = TaskManager::new();
        let active = manager.create_task(request(json!({}))).await.unwrap();
        let done = manager.create_task(request(json!({}))).await.unwrap();
        manager.mark_completed(&done.id, json!({})).await.unwrap();

        let cleaned = manager
            .cleanup_old_tasks(chrono::Duration::seconds(0))
            .await
            .unwrap();
        assert_eq!(cleaned, 1);
        assert!(manager.get_task_status(&active.id).await.is_ok());
        assert!(manager.get_task_status(&done.id).await.is_err());
    }
}