1use std::collections::HashSet;
6use std::sync::Arc;
7use std::time::{Duration, Instant};
8
9use chrono::Utc;
10use dashmap::DashMap;
11use serde_json::Value;
12use thiserror::Error;
13use tokio::sync::{oneshot, RwLock};
14use uuid::Uuid;
15
16use crate::protocol::types::{Task, TaskStatus, JsonRpcRequest};
17
18#[derive(Debug, Error)]
20pub enum TaskError {
21 #[error("Task not found: {0}")]
23 NotFound(String),
24
25 #[error("Task timeout after {0:?}")]
27 Timeout(Duration),
28
29 #[error("Task failed: {0}")]
31 Failed(String),
32
33 #[error("Task already in terminal status: {0:?}")]
35 AlreadyTerminal(TaskStatus),
36
37 #[error("Internal error: {0}")]
39 Internal(String),
40}
41
42struct TaskEntry {
44 task: Arc<RwLock<Task>>,
46 session_id: String,
48 #[allow(dead_code)]
50 original_request: JsonRpcRequest,
51 result_tx: Option<oneshot::Sender<Value>>,
53 result: Arc<RwLock<Option<Value>>>,
55 expires_at: Option<Instant>,
57}
58
59pub struct TaskStore {
61 tasks: DashMap<String, TaskEntry>,
63 by_session: DashMap<String, HashSet<String>>,
65 default_ttl: Duration,
67 default_poll_interval: Duration,
69}
70
71impl TaskStore {
72 pub fn new(default_ttl: Duration, default_poll_interval: Duration) -> Self {
74 Self {
75 tasks: DashMap::new(),
76 by_session: DashMap::new(),
77 default_ttl,
78 default_poll_interval,
79 }
80 }
81
82 pub fn create_task(
84 &self,
85 session_id: &str,
86 original_request: JsonRpcRequest,
87 requested_ttl: Option<u64>,
88 ) -> (Task, oneshot::Receiver<Value>) {
89 let task_id = Uuid::new_v4().to_string();
90 let now = Utc::now().to_rfc3339();
91
92 let ttl_duration = requested_ttl
94 .map(|ms| Duration::from_millis(ms))
95 .unwrap_or(self.default_ttl);
96 let ttl_ms = if ttl_duration == Duration::ZERO {
97 None
98 } else {
99 Some(ttl_duration.as_millis() as u64)
100 };
101
102 let task = Task {
103 task_id: task_id.clone(),
104 status: TaskStatus::Working,
105 status_message: None,
106 created_at: now.clone(),
107 last_updated_at: now,
108 ttl: ttl_ms,
109 poll_interval: Some(self.default_poll_interval.as_millis() as u64),
110 };
111
112 let (result_tx, result_rx) = oneshot::channel();
114
115 let entry = TaskEntry {
116 task: Arc::new(RwLock::new(task.clone())),
117 session_id: session_id.to_string(),
118 original_request,
119 result_tx: Some(result_tx),
120 result: Arc::new(RwLock::new(None)),
121 expires_at: ttl_ms.map(|ms| Instant::now() + Duration::from_millis(ms)),
122 };
123
124 self.tasks.insert(task_id.clone(), entry);
126
127 self.by_session
129 .entry(session_id.to_string())
130 .or_insert_with(HashSet::new)
131 .insert(task_id);
132
133 (task, result_rx)
134 }
135
136 pub async fn get_task(&self, task_id: &str) -> Option<Task> {
138 let entry = self.tasks.get(task_id)?;
139 Some(entry.task.read().await.clone())
140 }
141
142 pub async fn get_task_for_session(
144 &self,
145 task_id: &str,
146 session_id: &str,
147 ) -> Option<Task> {
148 let entry = self.tasks.get(task_id)?;
149 if entry.session_id == session_id {
150 Some(entry.task.read().await.clone())
151 } else {
152 None
153 }
154 }
155
156 pub async fn list_tasks(
158 &self,
159 session_id: &str,
160 cursor: Option<&str>,
161 limit: usize,
162 ) -> (Vec<Task>, Option<String>) {
163 let task_ids = match self.by_session.get(session_id) {
164 Some(ids) => ids.clone(),
165 None => return (vec![], None),
166 };
167
168 let mut task_ids: Vec<_> = task_ids.into_iter().collect();
169 task_ids.sort(); let start = if let Some(c) = cursor {
173 task_ids.iter().position(|id| id == c).map(|p| p + 1).unwrap_or(0)
174 } else {
175 0
176 };
177
178 let end = (start + limit).min(task_ids.len());
179 let page_ids = &task_ids[start..end];
180
181 let mut tasks = Vec::new();
183 for id in page_ids {
184 if let Some(task) = self.get_task(id).await {
185 tasks.push(task);
186 }
187 }
188
189 let next_cursor = if end < task_ids.len() {
191 task_ids.get(end).cloned()
192 } else {
193 None
194 };
195
196 (tasks, next_cursor)
197 }
198
199 pub async fn update_status(
201 &self,
202 task_id: &str,
203 new_status: TaskStatus,
204 status_message: Option<String>,
205 ) -> Result<(), TaskError> {
206 let entry = self.tasks.get(task_id)
207 .ok_or_else(|| TaskError::NotFound(task_id.to_string()))?;
208
209 let mut task = entry.task.write().await;
210
211 match (&task.status, &new_status) {
213 (TaskStatus::Completed | TaskStatus::Failed | TaskStatus::Cancelled, _) => {
214 return Err(TaskError::AlreadyTerminal(task.status));
215 }
216 _ => {}
217 }
218
219 task.status = new_status;
220 task.status_message = status_message;
221 task.last_updated_at = Utc::now().to_rfc3339();
222
223 Ok(())
224 }
225
226 pub async fn store_result(&self, task_id: &str, result: Value) -> Result<(), TaskError> {
228 let entry = self.tasks.get(task_id)
229 .ok_or_else(|| TaskError::NotFound(task_id.to_string()))?;
230
231 *entry.result.write().await = Some(result.clone());
233
234 drop(entry); if let Some(mut entry_mut) = self.tasks.get_mut(task_id) {
237 if let Some(tx) = entry_mut.result_tx.take() {
238 let _ = tx.send(result);
239 }
240 }
241
242 Ok(())
243 }
244
245 pub async fn get_result(&self, task_id: &str) -> Option<Value> {
247 let entry = self.tasks.get(task_id)?;
248 entry.result.read().await.clone()
249 }
250
251 pub async fn wait_for_result(
253 &self,
254 task_id: &str,
255 timeout: Duration,
256 ) -> Result<Value, TaskError> {
257 let entry = self.tasks.get(task_id)
258 .ok_or_else(|| TaskError::NotFound(task_id.to_string()))?;
259
260 let task = entry.task.read().await.clone();
261
262 if matches!(
264 task.status,
265 TaskStatus::Completed | TaskStatus::Failed | TaskStatus::Cancelled
266 ) {
267 drop(task);
268 if let Some(result) = self.get_result(task_id).await {
269 return Ok(result);
270 } else {
271 return Err(TaskError::Failed(
272 "Task in terminal state but no result stored".to_string(),
273 ));
274 }
275 }
276
277 drop(entry);
279
280 let start = Instant::now();
284 loop {
285 if start.elapsed() > timeout {
286 return Err(TaskError::Timeout(timeout));
287 }
288
289 let task = self.get_task(task_id).await
290 .ok_or_else(|| TaskError::NotFound(task_id.to_string()))?;
291
292 if matches!(
293 task.status,
294 TaskStatus::Completed | TaskStatus::Failed | TaskStatus::Cancelled
295 ) {
296 if let Some(result) = self.get_result(task_id).await {
297 return Ok(result);
298 } else {
299 return Err(TaskError::Failed(
300 "Task completed but no result stored".to_string(),
301 ));
302 }
303 }
304
305 tokio::time::sleep(Duration::from_millis(100)).await;
306 }
307 }
308
309 pub async fn cancel_task(
311 &self,
312 task_id: &str,
313 session_id: &str,
314 ) -> Result<Task, TaskError> {
315 let entry = self.tasks.get(task_id)
317 .ok_or_else(|| TaskError::NotFound(task_id.to_string()))?;
318
319 if entry.session_id != session_id {
320 return Err(TaskError::NotFound(task_id.to_string()));
321 }
322
323 let mut task = entry.task.write().await;
324
325 if matches!(
327 task.status,
328 TaskStatus::Completed | TaskStatus::Failed | TaskStatus::Cancelled
329 ) {
330 return Err(TaskError::AlreadyTerminal(task.status));
331 }
332
333 task.status = TaskStatus::Cancelled;
335 task.status_message = Some("Cancelled by request".to_string());
336 task.last_updated_at = Utc::now().to_rfc3339();
337
338 Ok(task.clone())
339 }
340
341 pub async fn cleanup_expired(&self) {
343 let now = Instant::now();
344 let mut to_remove = Vec::new();
345
346 for entry in self.tasks.iter() {
347 if let Some(expires_at) = entry.expires_at {
348 if now >= expires_at {
349 to_remove.push(entry.key().clone());
350 }
351 }
352 }
353
354 for task_id in to_remove {
355 if let Some((_, entry)) = self.tasks.remove(&task_id) {
356 if let Some(mut ids) = self.by_session.get_mut(&entry.session_id) {
358 ids.remove(&task_id);
359 }
360 }
361 }
362 }
363
364 pub fn spawn_cleanup_task(self: Arc<Self>, interval: Duration) {
366 tokio::spawn(async move {
367 let mut interval = tokio::time::interval(interval);
368 loop {
369 interval.tick().await;
370 self.cleanup_expired().await;
371 }
372 });
373 }
374}
375
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Store with a 60 s default TTL and a 5 s poll interval.
    fn new_store() -> TaskStore {
        TaskStore::new(Duration::from_secs(60), Duration::from_secs(5))
    }

    /// Minimal request used to create tasks.
    fn new_request() -> JsonRpcRequest {
        JsonRpcRequest::new(json!(1), "tools/call", Some(json!({})))
    }

    #[tokio::test]
    async fn test_create_task() {
        let store = new_store();
        let (task, _rx) = store.create_task("session1", new_request(), None);

        assert_eq!(task.status, TaskStatus::Working);
        assert!(!task.task_id.is_empty());
        assert_eq!(task.ttl, Some(60000));
    }

    #[tokio::test]
    async fn test_zero_ttl_means_no_expiry() {
        let store = new_store();
        let (task, _rx) = store.create_task("session1", new_request(), Some(0));

        assert_eq!(task.ttl, None);
    }

    #[tokio::test]
    async fn test_get_task() {
        let store = new_store();
        let (task, _rx) = store.create_task("session1", new_request(), None);
        let retrieved = store.get_task(&task.task_id).await.unwrap();

        assert_eq!(retrieved.task_id, task.task_id);
        assert_eq!(retrieved.status, TaskStatus::Working);
    }

    #[tokio::test]
    async fn test_session_isolation() {
        let store = new_store();
        let (task, _rx) = store.create_task("session1", new_request(), None);

        assert!(store.get_task_for_session(&task.task_id, "session2").await.is_none());
        assert!(store.get_task_for_session(&task.task_id, "session1").await.is_some());
    }

    #[tokio::test]
    async fn test_update_status() {
        let store = new_store();
        let (task, _rx) = store.create_task("session1", new_request(), None);

        store
            .update_status(&task.task_id, TaskStatus::Completed, Some("Done".to_string()))
            .await
            .unwrap();

        let updated = store.get_task(&task.task_id).await.unwrap();
        assert_eq!(updated.status, TaskStatus::Completed);
        assert_eq!(updated.status_message, Some("Done".to_string()));
    }

    #[tokio::test]
    async fn test_update_status_rejects_terminal_task() {
        let store = new_store();
        let (task, _rx) = store.create_task("session1", new_request(), None);
        store.cancel_task(&task.task_id, "session1").await.unwrap();

        let outcome = store
            .update_status(&task.task_id, TaskStatus::Working, None)
            .await;
        assert!(matches!(
            outcome,
            Err(TaskError::AlreadyTerminal(TaskStatus::Cancelled))
        ));
    }

    #[tokio::test]
    async fn test_store_and_get_result() {
        let store = new_store();
        let (task, _rx) = store.create_task("session1", new_request(), None);

        let result = json!({"answer": 42});
        store.store_result(&task.task_id, result.clone()).await.unwrap();

        let retrieved = store.get_result(&task.task_id).await.unwrap();
        assert_eq!(retrieved, result);
    }

    #[tokio::test]
    async fn test_store_result_notifies_receiver() {
        let store = new_store();
        let (task, rx) = store.create_task("session1", new_request(), None);

        let result = json!({"answer": 42});
        store.store_result(&task.task_id, result.clone()).await.unwrap();

        assert_eq!(rx.await.unwrap(), result);
    }

    #[tokio::test]
    async fn test_wait_for_result_returns_stored_result() {
        let store = new_store();
        let (task, _rx) = store.create_task("session1", new_request(), None);

        let result = json!({"answer": 42});
        store.store_result(&task.task_id, result.clone()).await.unwrap();
        store
            .update_status(&task.task_id, TaskStatus::Completed, None)
            .await
            .unwrap();

        let waited = store
            .wait_for_result(&task.task_id, Duration::from_secs(1))
            .await
            .unwrap();
        assert_eq!(waited, result);
    }

    #[tokio::test]
    async fn test_wait_for_result_times_out() {
        let store = new_store();
        let (task, _rx) = store.create_task("session1", new_request(), None);

        let outcome = store
            .wait_for_result(&task.task_id, Duration::from_millis(10))
            .await;
        assert!(matches!(outcome, Err(TaskError::Timeout(_))));
    }

    #[tokio::test]
    async fn test_wait_for_unknown_task_is_not_found() {
        let store = new_store();

        let outcome = store
            .wait_for_result("missing", Duration::from_millis(10))
            .await;
        assert!(matches!(outcome, Err(TaskError::NotFound(_))));
    }

    #[tokio::test]
    async fn test_cancel_task() {
        let store = new_store();
        let (task, _rx) = store.create_task("session1", new_request(), None);

        let cancelled = store.cancel_task(&task.task_id, "session1").await.unwrap();
        assert_eq!(cancelled.status, TaskStatus::Cancelled);
    }

    #[tokio::test]
    async fn test_cancel_task_from_other_session_is_not_found() {
        let store = new_store();
        let (task, _rx) = store.create_task("session1", new_request(), None);

        let outcome = store.cancel_task(&task.task_id, "session2").await;
        assert!(matches!(outcome, Err(TaskError::NotFound(_))));

        let unchanged = store.get_task(&task.task_id).await.unwrap();
        assert_eq!(unchanged.status, TaskStatus::Working);
    }

    #[tokio::test]
    async fn test_cleanup_expired_removes_only_expired_tasks() {
        let store = new_store();
        let (expiring, _rx1) = store.create_task("session1", new_request(), Some(1));
        let (lasting, _rx2) = store.create_task("session1", new_request(), None);

        tokio::time::sleep(Duration::from_millis(5)).await;
        store.cleanup_expired().await;

        assert!(store.get_task(&expiring.task_id).await.is_none());
        assert!(store.get_task(&lasting.task_id).await.is_some());

        let (tasks, _) = store.list_tasks("session1", None, 10).await;
        assert_eq!(tasks.len(), 1);
    }

    #[tokio::test]
    async fn test_list_tasks() {
        let store = new_store();
        for _ in 0..5 {
            store.create_task("session1", new_request(), None);
        }

        let (tasks, cursor) = store.list_tasks("session1", None, 10).await;
        assert_eq!(tasks.len(), 5);
        assert!(cursor.is_none());
    }

    #[tokio::test]
    async fn test_list_tasks_first_page() {
        let store = new_store();
        for _ in 0..5 {
            store.create_task("session1", new_request(), None);
        }

        let (tasks, cursor) = store.list_tasks("session1", None, 2).await;
        assert_eq!(tasks.len(), 2);
        assert!(cursor.is_some());
    }

    #[tokio::test]
    async fn test_list_tasks_unknown_session() {
        let store = new_store();

        let (tasks, cursor) = store.list_tasks("nobody", None, 10).await;
        assert!(tasks.is_empty());
        assert!(cursor.is_none());
    }
}