1use std::collections::HashSet;
6use std::sync::Arc;
7use std::time::{Duration, Instant};
8
9use chrono::Utc;
10use dashmap::DashMap;
11use serde_json::Value;
12use thiserror::Error;
13use tokio::sync::{RwLock, oneshot};
14use uuid::Uuid;
15
16use crate::protocol::types::{JsonRpcRequest, Task, TaskStatus};
17
18#[derive(Debug, Error)]
20pub enum TaskError {
21 #[error("Task not found: {0}")]
23 NotFound(String),
24
25 #[error("Task timeout after {0:?}")]
27 Timeout(Duration),
28
29 #[error("Task failed: {0}")]
31 Failed(String),
32
33 #[error("Task already in terminal status: {0:?}")]
35 AlreadyTerminal(TaskStatus),
36
37 #[error("Internal error: {0}")]
39 Internal(String),
40}
41
42struct TaskEntry {
44 task: Arc<RwLock<Task>>,
46 session_id: String,
48 #[allow(dead_code)]
50 original_request: JsonRpcRequest,
51 result_tx: Option<oneshot::Sender<Value>>,
53 result: Arc<RwLock<Option<Value>>>,
55 expires_at: Option<Instant>,
57}
58
59pub struct TaskStore {
61 tasks: DashMap<String, TaskEntry>,
63 by_session: DashMap<String, HashSet<String>>,
65 default_ttl: Duration,
67 default_poll_interval: Duration,
69}
70
71impl TaskStore {
72 pub fn new(default_ttl: Duration, default_poll_interval: Duration) -> Self {
74 Self {
75 tasks: DashMap::new(),
76 by_session: DashMap::new(),
77 default_ttl,
78 default_poll_interval,
79 }
80 }
81
82 pub fn create_task(
84 &self,
85 session_id: &str,
86 original_request: JsonRpcRequest,
87 requested_ttl: Option<u64>,
88 ) -> (Task, oneshot::Receiver<Value>) {
89 let task_id = Uuid::new_v4().to_string();
90 let now = Utc::now().to_rfc3339();
91
92 let ttl_duration = requested_ttl
94 .map(Duration::from_millis)
95 .unwrap_or(self.default_ttl);
96 let ttl_ms = if ttl_duration == Duration::ZERO {
97 None
98 } else {
99 Some(ttl_duration.as_millis() as u64)
100 };
101
102 let task = Task {
103 task_id: task_id.clone(),
104 status: TaskStatus::Working,
105 status_message: None,
106 created_at: now.clone(),
107 last_updated_at: now,
108 ttl: ttl_ms,
109 poll_interval: Some(self.default_poll_interval.as_millis() as u64),
110 };
111
112 let (result_tx, result_rx) = oneshot::channel();
114
115 let entry = TaskEntry {
116 task: Arc::new(RwLock::new(task.clone())),
117 session_id: session_id.to_string(),
118 original_request,
119 result_tx: Some(result_tx),
120 result: Arc::new(RwLock::new(None)),
121 expires_at: ttl_ms.map(|ms| Instant::now() + Duration::from_millis(ms)),
122 };
123
124 self.tasks.insert(task_id.clone(), entry);
126
127 self.by_session
129 .entry(session_id.to_string())
130 .or_default()
131 .insert(task_id);
132
133 (task, result_rx)
134 }
135
136 pub async fn get_task(&self, task_id: &str) -> Option<Task> {
138 let entry = self.tasks.get(task_id)?;
139 Some(entry.task.read().await.clone())
140 }
141
142 pub async fn get_task_for_session(&self, task_id: &str, session_id: &str) -> Option<Task> {
144 let entry = self.tasks.get(task_id)?;
145 if entry.session_id == session_id {
146 Some(entry.task.read().await.clone())
147 } else {
148 None
149 }
150 }
151
152 pub async fn list_tasks(
154 &self,
155 session_id: &str,
156 cursor: Option<&str>,
157 limit: usize,
158 ) -> (Vec<Task>, Option<String>) {
159 let task_ids = match self.by_session.get(session_id) {
160 Some(ids) => ids.clone(),
161 None => return (vec![], None),
162 };
163
164 let mut task_ids: Vec<_> = task_ids.into_iter().collect();
165 task_ids.sort(); let start = if let Some(c) = cursor {
169 task_ids
170 .iter()
171 .position(|id| id == c)
172 .map(|p| p + 1)
173 .unwrap_or(0)
174 } else {
175 0
176 };
177
178 let end = (start + limit).min(task_ids.len());
179 let page_ids = &task_ids[start..end];
180
181 let mut tasks = Vec::new();
183 for id in page_ids {
184 if let Some(task) = self.get_task(id).await {
185 tasks.push(task);
186 }
187 }
188
189 let next_cursor = if end < task_ids.len() {
191 task_ids.get(end).cloned()
192 } else {
193 None
194 };
195
196 (tasks, next_cursor)
197 }
198
199 pub async fn update_status(
201 &self,
202 task_id: &str,
203 new_status: TaskStatus,
204 status_message: Option<String>,
205 ) -> Result<(), TaskError> {
206 let entry = self
207 .tasks
208 .get(task_id)
209 .ok_or_else(|| TaskError::NotFound(task_id.to_string()))?;
210
211 let mut task = entry.task.write().await;
212
213 if let (TaskStatus::Completed | TaskStatus::Failed | TaskStatus::Cancelled, _) =
215 (&task.status, &new_status)
216 {
217 return Err(TaskError::AlreadyTerminal(task.status));
218 }
219
220 task.status = new_status;
221 task.status_message = status_message;
222 task.last_updated_at = Utc::now().to_rfc3339();
223
224 Ok(())
225 }
226
227 pub async fn store_result(&self, task_id: &str, result: Value) -> Result<(), TaskError> {
229 let entry = self
230 .tasks
231 .get(task_id)
232 .ok_or_else(|| TaskError::NotFound(task_id.to_string()))?;
233
234 *entry.result.write().await = Some(result.clone());
236
237 drop(entry); if let Some(mut entry_mut) = self.tasks.get_mut(task_id)
240 && let Some(tx) = entry_mut.result_tx.take()
241 {
242 let _ = tx.send(result);
243 }
244
245 Ok(())
246 }
247
248 pub async fn get_result(&self, task_id: &str) -> Option<Value> {
250 let entry = self.tasks.get(task_id)?;
251 entry.result.read().await.clone()
252 }
253
254 pub async fn wait_for_result(
256 &self,
257 task_id: &str,
258 timeout: Duration,
259 ) -> Result<Value, TaskError> {
260 let entry = self
261 .tasks
262 .get(task_id)
263 .ok_or_else(|| TaskError::NotFound(task_id.to_string()))?;
264
265 let task = entry.task.read().await.clone();
266
267 if matches!(
269 task.status,
270 TaskStatus::Completed | TaskStatus::Failed | TaskStatus::Cancelled
271 ) {
272 drop(task);
273 if let Some(result) = self.get_result(task_id).await {
274 return Ok(result);
275 } else {
276 return Err(TaskError::Failed(
277 "Task in terminal state but no result stored".to_string(),
278 ));
279 }
280 }
281
282 drop(entry);
284
285 let start = Instant::now();
289 loop {
290 if start.elapsed() > timeout {
291 return Err(TaskError::Timeout(timeout));
292 }
293
294 let task = self
295 .get_task(task_id)
296 .await
297 .ok_or_else(|| TaskError::NotFound(task_id.to_string()))?;
298
299 if matches!(
300 task.status,
301 TaskStatus::Completed | TaskStatus::Failed | TaskStatus::Cancelled
302 ) {
303 if let Some(result) = self.get_result(task_id).await {
304 return Ok(result);
305 } else {
306 return Err(TaskError::Failed(
307 "Task completed but no result stored".to_string(),
308 ));
309 }
310 }
311
312 tokio::time::sleep(Duration::from_millis(100)).await;
313 }
314 }
315
316 pub async fn cancel_task(&self, task_id: &str, session_id: &str) -> Result<Task, TaskError> {
318 let entry = self
320 .tasks
321 .get(task_id)
322 .ok_or_else(|| TaskError::NotFound(task_id.to_string()))?;
323
324 if entry.session_id != session_id {
325 return Err(TaskError::NotFound(task_id.to_string()));
326 }
327
328 let mut task = entry.task.write().await;
329
330 if matches!(
332 task.status,
333 TaskStatus::Completed | TaskStatus::Failed | TaskStatus::Cancelled
334 ) {
335 return Err(TaskError::AlreadyTerminal(task.status));
336 }
337
338 task.status = TaskStatus::Cancelled;
340 task.status_message = Some("Cancelled by request".to_string());
341 task.last_updated_at = Utc::now().to_rfc3339();
342
343 Ok(task.clone())
344 }
345
346 pub async fn cleanup_expired(&self) {
348 let now = Instant::now();
349 let mut to_remove = Vec::new();
350
351 for entry in self.tasks.iter() {
352 if let Some(expires_at) = entry.expires_at
353 && now >= expires_at
354 {
355 to_remove.push(entry.key().clone());
356 }
357 }
358
359 for task_id in to_remove {
360 if let Some((_, entry)) = self.tasks.remove(&task_id) {
361 if let Some(mut ids) = self.by_session.get_mut(&entry.session_id) {
363 ids.remove(&task_id);
364 }
365 }
366 }
367 }
368
369 pub fn spawn_cleanup_task(self: Arc<Self>, interval: Duration) {
371 tokio::spawn(async move {
372 let mut interval = tokio::time::interval(interval);
373 loop {
374 interval.tick().await;
375 self.cleanup_expired().await;
376 }
377 });
378 }
379}
380
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Store with a 60 s default TTL and a 5 s advertised poll interval.
    fn new_store() -> TaskStore {
        TaskStore::new(Duration::from_secs(60), Duration::from_secs(5))
    }

    /// Minimal `tools/call` request used to create tasks.
    fn tools_call() -> JsonRpcRequest {
        JsonRpcRequest::new(json!(1), "tools/call", Some(json!({})))
    }

    #[tokio::test]
    async fn test_create_task() {
        let store = new_store();
        let (task, _rx) = store.create_task("session1", tools_call(), None);

        assert_eq!(task.status, TaskStatus::Working);
        assert!(!task.task_id.is_empty());
        assert_eq!(task.ttl, Some(60000));
    }

    #[tokio::test]
    async fn test_zero_ttl_disables_expiry() {
        let store = new_store();
        let (task, _rx) = store.create_task("session1", tools_call(), Some(0));

        assert!(task.ttl.is_none());
    }

    #[tokio::test]
    async fn test_get_task() {
        let store = new_store();
        let (task, _rx) = store.create_task("session1", tools_call(), None);
        let retrieved = store.get_task(&task.task_id).await.unwrap();

        assert_eq!(retrieved.task_id, task.task_id);
        assert_eq!(retrieved.status, TaskStatus::Working);
        assert!(store.get_task("missing").await.is_none());
    }

    #[tokio::test]
    async fn test_session_isolation() {
        let store = new_store();
        let (task, _rx) = store.create_task("session1", tools_call(), None);

        assert!(
            store
                .get_task_for_session(&task.task_id, "session2")
                .await
                .is_none()
        );
        assert!(
            store
                .get_task_for_session(&task.task_id, "session1")
                .await
                .is_some()
        );
    }

    #[tokio::test]
    async fn test_update_status() {
        let store = new_store();
        let (task, _rx) = store.create_task("session1", tools_call(), None);

        store
            .update_status(
                &task.task_id,
                TaskStatus::Completed,
                Some("Done".to_string()),
            )
            .await
            .unwrap();

        let updated = store.get_task(&task.task_id).await.unwrap();
        assert_eq!(updated.status, TaskStatus::Completed);
        assert_eq!(updated.status_message, Some("Done".to_string()));
    }

    #[tokio::test]
    async fn test_update_status_rejects_terminal_task() {
        let store = new_store();
        let (task, _rx) = store.create_task("session1", tools_call(), None);

        store
            .update_status(&task.task_id, TaskStatus::Completed, None)
            .await
            .unwrap();
        let second = store
            .update_status(&task.task_id, TaskStatus::Working, None)
            .await;

        assert!(matches!(
            second,
            Err(TaskError::AlreadyTerminal(TaskStatus::Completed))
        ));
    }

    #[tokio::test]
    async fn test_store_and_get_result() {
        let store = new_store();
        let (task, _rx) = store.create_task("session1", tools_call(), None);

        let result = json!({"answer": 42});
        store
            .store_result(&task.task_id, result.clone())
            .await
            .unwrap();

        let retrieved = store.get_result(&task.task_id).await.unwrap();
        assert_eq!(retrieved, result);
    }

    #[tokio::test]
    async fn test_store_result_notifies_receiver() {
        let store = new_store();
        let (task, rx) = store.create_task("session1", tools_call(), None);

        let result = json!({"answer": 42});
        store
            .store_result(&task.task_id, result.clone())
            .await
            .unwrap();

        assert_eq!(rx.await.unwrap(), result);
    }

    #[tokio::test]
    async fn test_wait_for_result_returns_stored_result() {
        let store = new_store();
        let (task, _rx) = store.create_task("session1", tools_call(), None);

        let result = json!({"ok": true});
        store
            .store_result(&task.task_id, result.clone())
            .await
            .unwrap();
        store
            .update_status(&task.task_id, TaskStatus::Completed, None)
            .await
            .unwrap();

        let waited = store
            .wait_for_result(&task.task_id, Duration::from_secs(1))
            .await
            .unwrap();
        assert_eq!(waited, result);
    }

    #[tokio::test]
    async fn test_wait_for_result_times_out() {
        let store = new_store();
        let (task, _rx) = store.create_task("session1", tools_call(), None);

        let outcome = store
            .wait_for_result(&task.task_id, Duration::from_millis(10))
            .await;
        assert!(matches!(outcome, Err(TaskError::Timeout(_))));
    }

    #[tokio::test]
    async fn test_cancel_task() {
        let store = new_store();
        let (task, _rx) = store.create_task("session1", tools_call(), None);

        let cancelled = store.cancel_task(&task.task_id, "session1").await.unwrap();
        assert_eq!(cancelled.status, TaskStatus::Cancelled);
    }

    #[tokio::test]
    async fn test_cancel_task_wrong_session() {
        let store = new_store();
        let (task, _rx) = store.create_task("session1", tools_call(), None);

        assert!(matches!(
            store.cancel_task(&task.task_id, "session2").await,
            Err(TaskError::NotFound(_))
        ));
        let current = store.get_task(&task.task_id).await.unwrap();
        assert_eq!(current.status, TaskStatus::Working);
    }

    #[tokio::test]
    async fn test_cleanup_expired_removes_task() {
        let store = new_store();
        let (task, _rx) = store.create_task("session1", tools_call(), Some(1));

        tokio::time::sleep(Duration::from_millis(5)).await;
        store.cleanup_expired().await;

        assert!(store.get_task(&task.task_id).await.is_none());
        let (tasks, cursor) = store.list_tasks("session1", None, 10).await;
        assert!(tasks.is_empty());
        assert!(cursor.is_none());
    }

    #[tokio::test]
    async fn test_list_tasks() {
        let store = new_store();
        for _ in 0..5 {
            store.create_task("session1", tools_call(), None);
        }

        let (tasks, cursor) = store.list_tasks("session1", None, 10).await;
        assert_eq!(tasks.len(), 5);
        assert!(cursor.is_none());

        let (tasks, cursor) = store.list_tasks("unknown-session", None, 10).await;
        assert!(tasks.is_empty());
        assert!(cursor.is_none());
    }

    #[tokio::test]
    async fn test_list_tasks_first_page_has_cursor() {
        let store = new_store();
        for _ in 0..5 {
            store.create_task("session1", tools_call(), None);
        }

        let (tasks, cursor) = store.list_tasks("session1", None, 2).await;
        assert_eq!(tasks.len(), 2);
        assert!(cursor.is_some());
    }
}