// mcpkit_server/capability/tasks.rs
use crate::context::CancellationToken;
use crate::context::Context;
use crate::handler::TaskHandler;
use mcpkit_core::error::McpError;
use mcpkit_core::types::task::{Task, TaskId, TaskStatus};
use serde_json::Value;
use std::collections::HashMap;
use std::sync::{Arc, PoisonError, RwLock, RwLockReadGuard, RwLockWriteGuard};
use std::time::Instant;

18#[derive(Debug)]
20pub struct TaskState {
21 pub task: Task,
23 pub cancel_token: CancellationToken,
25 pub last_access: Instant,
27}
28
29impl TaskState {
30 fn new(task: Task) -> Self {
32 Self {
33 task,
34 cancel_token: CancellationToken::new(),
35 last_access: Instant::now(),
36 }
37 }
38
39 #[must_use]
41 pub fn is_cancelled(&self) -> bool {
42 self.cancel_token.is_cancelled()
43 }
44}
45
46pub struct TaskHandle {
51 task_id: TaskId,
52 manager: Arc<TaskManager>,
53}
54
55impl TaskHandle {
56 #[must_use]
58 pub const fn id(&self) -> &TaskId {
59 &self.task_id
60 }
61
62 pub async fn running(&self) -> Result<(), McpError> {
64 self.manager
65 .update_status(&self.task_id, TaskStatus::Running)
66 .await
67 }
68
69 pub async fn progress(
71 &self,
72 current: u64,
73 total: Option<u64>,
74 message: Option<&str>,
75 ) -> Result<(), McpError> {
76 self.manager
77 .update_progress(&self.task_id, current, total, message)
78 .await
79 }
80
81 pub async fn complete(&self, result: Value) -> Result<(), McpError> {
83 self.manager.complete_success(&self.task_id, result).await
84 }
85
86 pub async fn error(&self, message: impl Into<String>) -> Result<(), McpError> {
88 self.manager
89 .complete_error(&self.task_id, message.into())
90 .await
91 }
92
93 #[must_use]
95 pub fn is_cancelled(&self) -> bool {
96 self.manager
97 .get(&self.task_id)
98 .is_none_or(|s| s.is_cancelled())
99 }
100
101 pub async fn cancelled(&self) {
103 if let Some(state) = self.manager.get(&self.task_id) {
104 state.cancel_token.cancelled().await;
105 }
106 }
107}
108
109pub struct TaskManager {
114 tasks: RwLock<HashMap<TaskId, TaskState>>,
115}
116
117impl Default for TaskManager {
118 fn default() -> Self {
119 Self::new()
120 }
121}
122
123impl TaskManager {
124 #[must_use]
126 pub fn new() -> Self {
127 Self {
128 tasks: RwLock::new(HashMap::new()),
129 }
130 }
131
132 pub fn create(self: &Arc<Self>, tool_name: Option<&str>) -> TaskHandle {
134 let mut task = Task::create();
135 task.tool = tool_name.map(String::from);
136
137 let task_id = task.id.clone();
138 let state = TaskState::new(task);
139
140 if let Ok(mut tasks) = self.tasks.write() {
141 tasks.insert(task_id.clone(), state);
142 }
143
144 TaskHandle {
145 task_id,
146 manager: Arc::clone(self),
147 }
148 }
149
150 pub fn get(&self, id: &TaskId) -> Option<TaskState> {
152 self.tasks.read().ok()?.get(id).map(|s| TaskState {
153 task: s.task.clone(),
154 cancel_token: s.cancel_token.clone(),
155 last_access: s.last_access,
156 })
157 }
158
159 pub fn list(&self) -> Vec<Task> {
161 self.tasks
162 .read()
163 .map(|tasks| tasks.values().map(|s| s.task.clone()).collect())
164 .unwrap_or_default()
165 }
166
167 pub fn cancel(&self, id: &TaskId) -> Result<(), McpError> {
169 let mut tasks = self
170 .tasks
171 .write()
172 .map_err(|_| McpError::internal("Failed to acquire task lock"))?;
173
174 if let Some(state) = tasks.get_mut(id) {
175 state.cancel_token.cancel();
176 state.task.status = TaskStatus::Cancelled;
177 state.task.updated_at = chrono::Utc::now();
178 Ok(())
179 } else {
180 Err(McpError::invalid_params(
181 "tasks/cancel",
182 format!("Unknown task: {}", id.as_str()),
183 ))
184 }
185 }
186
187 async fn update_status(&self, id: &TaskId, status: TaskStatus) -> Result<(), McpError> {
189 let mut tasks = self
190 .tasks
191 .write()
192 .map_err(|_| McpError::internal("Failed to acquire task lock"))?;
193
194 if let Some(state) = tasks.get_mut(id) {
195 state.task.status = status;
196 state.task.updated_at = chrono::Utc::now();
197 state.last_access = Instant::now();
198 Ok(())
199 } else {
200 Err(McpError::invalid_params(
201 "tasks/get",
202 format!("Unknown task: {}", id.as_str()),
203 ))
204 }
205 }
206
207 async fn update_progress(
209 &self,
210 id: &TaskId,
211 current: u64,
212 total: Option<u64>,
213 message: Option<&str>,
214 ) -> Result<(), McpError> {
215 let mut tasks = self
216 .tasks
217 .write()
218 .map_err(|_| McpError::internal("Failed to acquire task lock"))?;
219
220 if let Some(state) = tasks.get_mut(id) {
221 state.task.progress = Some(mcpkit_core::types::task::TaskProgress {
222 current,
223 total,
224 message: message.map(String::from),
225 });
226 state.task.updated_at = chrono::Utc::now();
227 state.last_access = Instant::now();
228 Ok(())
229 } else {
230 Err(McpError::invalid_params(
231 "tasks/get",
232 format!("Unknown task: {}", id.as_str()),
233 ))
234 }
235 }
236
237 async fn complete_success(&self, id: &TaskId, result: Value) -> Result<(), McpError> {
239 let mut tasks = self
240 .tasks
241 .write()
242 .map_err(|_| McpError::internal("Failed to acquire task lock"))?;
243
244 if let Some(state) = tasks.get_mut(id) {
245 state.task.status = TaskStatus::Completed;
246 state.task.result = Some(result);
247 state.task.updated_at = chrono::Utc::now();
248 state.last_access = Instant::now();
249 Ok(())
250 } else {
251 Err(McpError::invalid_params(
252 "tasks/get",
253 format!("Unknown task: {}", id.as_str()),
254 ))
255 }
256 }
257
258 async fn complete_error(&self, id: &TaskId, message: String) -> Result<(), McpError> {
260 let mut tasks = self
261 .tasks
262 .write()
263 .map_err(|_| McpError::internal("Failed to acquire task lock"))?;
264
265 if let Some(state) = tasks.get_mut(id) {
266 state.task.status = TaskStatus::Failed;
267 state.task.error = Some(mcpkit_core::types::task::TaskError {
268 code: -1,
269 message,
270 data: None,
271 });
272 state.task.updated_at = chrono::Utc::now();
273 state.last_access = Instant::now();
274 Ok(())
275 } else {
276 Err(McpError::invalid_params(
277 "tasks/get",
278 format!("Unknown task: {}", id.as_str()),
279 ))
280 }
281 }
282
283 pub fn cleanup(&self, max_age: std::time::Duration) {
285 if let Ok(mut tasks) = self.tasks.write() {
286 tasks.retain(|_, state| {
287 let is_terminal = state.task.status.is_terminal();
288 !is_terminal || state.last_access.elapsed() < max_age
289 });
290 }
291 }
292}
293
294pub struct TaskService {
296 manager: Arc<TaskManager>,
297}
298
299impl Default for TaskService {
300 fn default() -> Self {
301 Self::new()
302 }
303}
304
305impl TaskService {
306 #[must_use]
308 pub fn new() -> Self {
309 Self {
310 manager: Arc::new(TaskManager::new()),
311 }
312 }
313
314 #[must_use]
316 pub const fn manager(&self) -> &Arc<TaskManager> {
317 &self.manager
318 }
319
320 #[must_use]
322 pub fn create(&self, tool_name: Option<&str>) -> TaskHandle {
323 self.manager.create(tool_name)
324 }
325}
326
327impl TaskHandler for TaskService {
328 async fn list_tasks(&self, _ctx: &Context<'_>) -> Result<Vec<Task>, McpError> {
329 Ok(self.manager.list())
330 }
331
332 async fn get_task(
333 &self,
334 task_id: &TaskId,
335 _ctx: &Context<'_>,
336 ) -> Result<Option<Task>, McpError> {
337 Ok(self.manager.get(task_id).map(|s| s.task))
338 }
339
340 async fn cancel_task(&self, task_id: &TaskId, _ctx: &Context<'_>) -> Result<bool, McpError> {
341 match self.manager.cancel(task_id) {
342 Ok(()) => Ok(true),
343 Err(_) => Ok(false),
344 }
345 }
346}
347
#[cfg(test)]
mod tests {
    use super::*;

    /// Fetches a task snapshot, turning a missing task into a test error.
    fn snapshot(
        manager: &TaskManager,
        id: &TaskId,
    ) -> Result<TaskState, Box<dyn std::error::Error>> {
        manager.get(id).ok_or_else(|| "Task not found".into())
    }

    #[test]
    fn test_task_manager() {
        let manager = Arc::new(TaskManager::new());

        let handle = manager.create(Some("test-tool"));
        assert!(!handle.is_cancelled());

        let tasks = manager.list();
        assert_eq!(tasks.len(), 1);
        let only = &tasks[0];
        assert_eq!(only.tool.as_deref(), Some("test-tool"));
        assert_eq!(only.status, TaskStatus::Pending);
    }

    #[tokio::test]
    async fn test_task_lifecycle() -> Result<(), Box<dyn std::error::Error>> {
        let manager = Arc::new(TaskManager::new());
        let handle = manager.create(Some("processor"));
        let id = handle.id().clone();

        handle.running().await?;
        assert_eq!(snapshot(&manager, &id)?.task.status, TaskStatus::Running);

        handle.progress(50, Some(100), Some("Halfway done")).await?;
        let current = snapshot(&manager, &id)?
            .task
            .progress
            .as_ref()
            .map(|p| p.current);
        assert_eq!(current, Some(50));

        handle
            .complete(serde_json::json!({"result": "success"}))
            .await?;
        assert_eq!(snapshot(&manager, &id)?.task.status, TaskStatus::Completed);

        Ok(())
    }

    #[test]
    fn test_task_cancellation() -> Result<(), Box<dyn std::error::Error>> {
        let manager = Arc::new(TaskManager::new());
        let handle = manager.create(None);
        let id = handle.id().clone();

        assert!(!handle.is_cancelled());
        manager.cancel(&id)?;
        assert!(handle.is_cancelled());
        assert_eq!(snapshot(&manager, &id)?.task.status, TaskStatus::Cancelled);

        Ok(())
    }

    #[tokio::test]
    async fn test_task_service() -> Result<(), Box<dyn std::error::Error>> {
        let service = TaskService::new();

        service.create(Some("service-task")).running().await?;

        assert_eq!(service.manager.list().len(), 1);
        Ok(())
    }
}