mcpkit_server/capability/
tasks.rs1use crate::context::CancellationToken;
10use crate::context::Context;
11use crate::handler::TaskHandler;
12use mcpkit_core::error::McpError;
13use mcpkit_core::types::task::{Task, TaskId, TaskStatus};
14use serde_json::Value;
15use std::collections::HashMap;
16use std::sync::{Arc, RwLock};
17use std::time::Instant;
18
19#[derive(Debug)]
21pub struct TaskState {
22 pub task: Task,
24 pub cancel_token: CancellationToken,
26 pub last_access: Instant,
28}
29
30impl TaskState {
31 fn new(task: Task) -> Self {
33 Self {
34 task,
35 cancel_token: CancellationToken::new(),
36 last_access: Instant::now(),
37 }
38 }
39
40 pub fn is_cancelled(&self) -> bool {
42 self.cancel_token.is_cancelled()
43 }
44}
45
46pub struct TaskHandle {
51 task_id: TaskId,
52 manager: Arc<TaskManager>,
53}
54
55impl TaskHandle {
56 pub fn id(&self) -> &TaskId {
58 &self.task_id
59 }
60
61 pub async fn running(&self) -> Result<(), McpError> {
63 self.manager.update_status(&self.task_id, TaskStatus::Running).await
64 }
65
66 pub async fn progress(
68 &self,
69 current: u64,
70 total: Option<u64>,
71 message: Option<&str>,
72 ) -> Result<(), McpError> {
73 self.manager
74 .update_progress(&self.task_id, current, total, message)
75 .await
76 }
77
78 pub async fn complete(&self, result: Value) -> Result<(), McpError> {
80 self.manager.complete_success(&self.task_id, result).await
81 }
82
83 pub async fn error(&self, message: impl Into<String>) -> Result<(), McpError> {
85 self.manager.complete_error(&self.task_id, message.into()).await
86 }
87
88 pub fn is_cancelled(&self) -> bool {
90 self.manager
91 .get(&self.task_id)
92 .map(|s| s.is_cancelled())
93 .unwrap_or(true)
94 }
95
96 pub async fn cancelled(&self) {
98 if let Some(state) = self.manager.get(&self.task_id) {
99 state.cancel_token.cancelled().await;
100 }
101 }
102}
103
104pub struct TaskManager {
109 tasks: RwLock<HashMap<TaskId, TaskState>>,
110}
111
112impl Default for TaskManager {
113 fn default() -> Self {
114 Self::new()
115 }
116}
117
118impl TaskManager {
119 pub fn new() -> Self {
121 Self {
122 tasks: RwLock::new(HashMap::new()),
123 }
124 }
125
126 pub fn create(self: &Arc<Self>, tool_name: Option<&str>) -> TaskHandle {
128 let mut task = Task::create();
129 task.tool = tool_name.map(String::from);
130
131 let task_id = task.id.clone();
132 let state = TaskState::new(task);
133
134 if let Ok(mut tasks) = self.tasks.write() {
135 tasks.insert(task_id.clone(), state);
136 }
137
138 TaskHandle {
139 task_id,
140 manager: Arc::clone(self),
141 }
142 }
143
144 pub fn get(&self, id: &TaskId) -> Option<TaskState> {
146 self.tasks.read().ok()?.get(id).map(|s| TaskState {
147 task: s.task.clone(),
148 cancel_token: s.cancel_token.clone(),
149 last_access: s.last_access,
150 })
151 }
152
153 pub fn list(&self) -> Vec<Task> {
155 self.tasks
156 .read()
157 .map(|tasks| tasks.values().map(|s| s.task.clone()).collect())
158 .unwrap_or_default()
159 }
160
161 pub fn cancel(&self, id: &TaskId) -> Result<(), McpError> {
163 let mut tasks = self.tasks.write().map_err(|_| {
164 McpError::internal("Failed to acquire task lock")
165 })?;
166
167 if let Some(state) = tasks.get_mut(id) {
168 state.cancel_token.cancel();
169 state.task.status = TaskStatus::Cancelled;
170 state.task.updated_at = chrono::Utc::now();
171 Ok(())
172 } else {
173 Err(McpError::invalid_params(
174 "tasks/cancel",
175 format!("Unknown task: {}", id.as_str()),
176 ))
177 }
178 }
179
180 async fn update_status(&self, id: &TaskId, status: TaskStatus) -> Result<(), McpError> {
182 let mut tasks = self.tasks.write().map_err(|_| {
183 McpError::internal("Failed to acquire task lock")
184 })?;
185
186 if let Some(state) = tasks.get_mut(id) {
187 state.task.status = status;
188 state.task.updated_at = chrono::Utc::now();
189 state.last_access = Instant::now();
190 Ok(())
191 } else {
192 Err(McpError::invalid_params(
193 "tasks/get",
194 format!("Unknown task: {}", id.as_str()),
195 ))
196 }
197 }
198
199 async fn update_progress(
201 &self,
202 id: &TaskId,
203 current: u64,
204 total: Option<u64>,
205 message: Option<&str>,
206 ) -> Result<(), McpError> {
207 let mut tasks = self.tasks.write().map_err(|_| {
208 McpError::internal("Failed to acquire task lock")
209 })?;
210
211 if let Some(state) = tasks.get_mut(id) {
212 state.task.progress = Some(mcpkit_core::types::task::TaskProgress {
213 current,
214 total,
215 message: message.map(String::from),
216 });
217 state.task.updated_at = chrono::Utc::now();
218 state.last_access = Instant::now();
219 Ok(())
220 } else {
221 Err(McpError::invalid_params(
222 "tasks/get",
223 format!("Unknown task: {}", id.as_str()),
224 ))
225 }
226 }
227
228 async fn complete_success(&self, id: &TaskId, result: Value) -> Result<(), McpError> {
230 let mut tasks = self.tasks.write().map_err(|_| {
231 McpError::internal("Failed to acquire task lock")
232 })?;
233
234 if let Some(state) = tasks.get_mut(id) {
235 state.task.status = TaskStatus::Completed;
236 state.task.result = Some(result);
237 state.task.updated_at = chrono::Utc::now();
238 state.last_access = Instant::now();
239 Ok(())
240 } else {
241 Err(McpError::invalid_params(
242 "tasks/get",
243 format!("Unknown task: {}", id.as_str()),
244 ))
245 }
246 }
247
248 async fn complete_error(&self, id: &TaskId, message: String) -> Result<(), McpError> {
250 let mut tasks = self.tasks.write().map_err(|_| {
251 McpError::internal("Failed to acquire task lock")
252 })?;
253
254 if let Some(state) = tasks.get_mut(id) {
255 state.task.status = TaskStatus::Failed;
256 state.task.error = Some(mcpkit_core::types::task::TaskError {
257 code: -1,
258 message,
259 data: None,
260 });
261 state.task.updated_at = chrono::Utc::now();
262 state.last_access = Instant::now();
263 Ok(())
264 } else {
265 Err(McpError::invalid_params(
266 "tasks/get",
267 format!("Unknown task: {}", id.as_str()),
268 ))
269 }
270 }
271
272 pub fn cleanup(&self, max_age: std::time::Duration) {
274 if let Ok(mut tasks) = self.tasks.write() {
275 tasks.retain(|_, state| {
276 let is_terminal = state.task.status.is_terminal();
277 !is_terminal || state.last_access.elapsed() < max_age
278 });
279 }
280 }
281}
282
283pub struct TaskService {
285 manager: Arc<TaskManager>,
286}
287
288impl Default for TaskService {
289 fn default() -> Self {
290 Self::new()
291 }
292}
293
294impl TaskService {
295 pub fn new() -> Self {
297 Self {
298 manager: Arc::new(TaskManager::new()),
299 }
300 }
301
302 pub fn manager(&self) -> &Arc<TaskManager> {
304 &self.manager
305 }
306
307 pub fn create(&self, tool_name: Option<&str>) -> TaskHandle {
309 self.manager.create(tool_name)
310 }
311}
312
313impl TaskHandler for TaskService {
314 async fn list_tasks(&self, _ctx: &Context<'_>) -> Result<Vec<Task>, McpError> {
315 Ok(self.manager.list())
316 }
317
318 async fn get_task(&self, task_id: &TaskId, _ctx: &Context<'_>) -> Result<Option<Task>, McpError> {
319 Ok(self.manager.get(task_id).map(|s| s.task))
320 }
321
322 async fn cancel_task(&self, task_id: &TaskId, _ctx: &Context<'_>) -> Result<bool, McpError> {
323 match self.manager.cancel(task_id) {
324 Ok(()) => Ok(true),
325 Err(_) => Ok(false),
326 }
327 }
328}
329
#[cfg(test)]
mod tests {
    use super::*;

    /// A freshly created task is listed, tagged with its tool, and pending.
    #[test]
    fn test_task_manager() {
        let manager = Arc::new(TaskManager::new());
        let handle = manager.create(Some("test-tool"));
        assert!(!handle.is_cancelled());

        let listed = manager.list();
        assert_eq!(listed.len(), 1);
        let only = &listed[0];
        assert_eq!(only.tool.as_deref(), Some("test-tool"));
        assert_eq!(only.status, TaskStatus::Pending);
    }

    /// Running -> progress -> completed, checked through manager snapshots.
    #[tokio::test]
    async fn test_task_lifecycle() {
        let manager = Arc::new(TaskManager::new());
        let handle = manager.create(Some("processor"));
        let id = handle.id().clone();

        handle.running().await.unwrap();
        let status = manager.get(&id).expect("task registered").task.status;
        assert_eq!(status, TaskStatus::Running);

        handle
            .progress(50, Some(100), Some("Halfway done"))
            .await
            .unwrap();
        let progress = manager.get(&id).expect("task registered").task.progress;
        assert_eq!(progress.map(|p| p.current), Some(50));

        handle
            .complete(serde_json::json!({"result": "success"}))
            .await
            .unwrap();
        let status = manager.get(&id).expect("task registered").task.status;
        assert_eq!(status, TaskStatus::Completed);
    }

    /// Cancelling through the manager is visible on the handle and in status.
    #[test]
    fn test_task_cancellation() {
        let manager = Arc::new(TaskManager::new());
        let handle = manager.create(None);
        let id = handle.id().clone();
        assert!(!handle.is_cancelled());

        manager.cancel(&id).unwrap();

        assert!(handle.is_cancelled());
        let status = manager.get(&id).expect("task registered").task.status;
        assert_eq!(status, TaskStatus::Cancelled);
    }

    /// Tasks created through the service land in its manager.
    #[tokio::test]
    async fn test_task_service() {
        let service = TaskService::new();
        service
            .create(Some("service-task"))
            .running()
            .await
            .unwrap();
        assert_eq!(service.manager().list().len(), 1);
    }
}