1use std::collections::HashMap;
27use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
28use std::sync::{Arc, RwLock};
29use std::time::{Duration, Instant};
30
31use crate::protocol::{CallToolResult, TaskInfo, TaskStatus};
32
/// Default retention, in seconds, for a task once it has reached a terminal
/// state (used when `create_task` is called with `ttl: None`). The TTL clock
/// starts at `completed_at`, so unfinished tasks are not affected.
const DEFAULT_TTL_SECS: u64 = 300;

/// Poll interval, in seconds, reported for every task via `TaskInfo::poll_interval`.
const DEFAULT_POLL_INTERVAL_SECS: u64 = 2;
38
/// A single tool invocation tracked by [`TaskStore`], together with its
/// lifecycle state, latest progress report and final outcome.
#[derive(Debug)]
pub struct Task {
    /// Store-assigned identifier of the form `task-N`.
    pub id: String,
    /// Name of the tool this task was created for.
    pub tool_name: String,
    /// JSON arguments supplied when the task was created.
    pub arguments: serde_json::Value,
    /// Current lifecycle status; new tasks start as `TaskStatus::Working`.
    pub status: TaskStatus,
    /// Monotonic creation time (not reported to clients).
    pub created_at: Instant,
    /// Creation time as an ISO-8601 UTC string, reported in `TaskInfo::created_at`.
    pub created_at_str: String,
    /// Retention in seconds after the task reaches a terminal state; see `is_expired`.
    pub ttl: u64,
    /// Poll interval in seconds reported through `TaskInfo`.
    pub poll_interval: u64,
    /// Most recent progress value. It is set to `100.0` on completion, so it is
    /// presumably a percentage — TODO confirm with callers.
    pub progress: Option<f64>,
    /// Most recent human-readable status message.
    pub message: Option<String>,
    /// Tool output; populated by `TaskStore::complete_task`.
    pub result: Option<CallToolResult>,
    /// Failure description; populated by `TaskStore::fail_task`.
    pub error: Option<String>,
    /// Cancellation flag; a clone is handed back to the creator of the task.
    pub cancellation_token: CancellationToken,
    /// When the task reached a terminal state; `None` while it is still active.
    pub completed_at: Option<Instant>,
}
71
72impl Task {
73 fn new(id: String, tool_name: String, arguments: serde_json::Value, ttl: Option<u64>) -> Self {
75 let cancelled = Arc::new(AtomicBool::new(false));
76 Self {
77 id,
78 tool_name,
79 arguments,
80 status: TaskStatus::Working,
81 created_at: Instant::now(),
82 created_at_str: chrono_now_iso8601(),
83 ttl: ttl.unwrap_or(DEFAULT_TTL_SECS),
84 poll_interval: DEFAULT_POLL_INTERVAL_SECS,
85 progress: None,
86 message: Some("Task started".to_string()),
87 result: None,
88 error: None,
89 cancellation_token: CancellationToken { cancelled },
90 completed_at: None,
91 }
92 }
93
94 pub fn to_info(&self) -> TaskInfo {
96 TaskInfo {
97 task_id: self.id.clone(),
98 status: self.status,
99 created_at: self.created_at_str.clone(),
100 ttl: Some(self.ttl),
101 poll_interval: Some(self.poll_interval),
102 progress: self.progress,
103 message: self.message.clone(),
104 }
105 }
106
107 pub fn is_expired(&self) -> bool {
109 if let Some(completed_at) = self.completed_at {
110 completed_at.elapsed() > Duration::from_secs(self.ttl)
111 } else {
112 false
113 }
114 }
115
116 pub fn is_cancelled(&self) -> bool {
118 self.cancellation_token.is_cancelled()
119 }
120}
121
/// Shared, clonable cancellation flag for a task.
///
/// Every clone points at the same `AtomicBool`, so calling
/// [`cancel`](Self::cancel) on any clone is observed by all of them.
#[derive(Debug, Clone)]
pub struct CancellationToken {
    cancelled: Arc<AtomicBool>,
}

impl CancellationToken {
    /// Returns `true` once [`cancel`](Self::cancel) has been called on any clone.
    pub fn is_cancelled(&self) -> bool {
        // Acquire pairs with the Release store in `cancel`: a worker that observes
        // the flag also observes every write the canceller made before cancelling.
        self.cancelled.load(Ordering::Acquire)
    }

    /// Requests cancellation. Calling it repeatedly is harmless; nothing resets the flag.
    pub fn cancel(&self) {
        self.cancelled.store(true, Ordering::Release);
    }
}
139
140#[derive(Debug, Clone)]
142pub struct TaskStore {
143 tasks: Arc<RwLock<HashMap<String, Task>>>,
144 next_id: Arc<AtomicU64>,
145}
146
147impl Default for TaskStore {
148 fn default() -> Self {
149 Self::new()
150 }
151}
152
153impl TaskStore {
154 pub fn new() -> Self {
156 Self {
157 tasks: Arc::new(RwLock::new(HashMap::new())),
158 next_id: Arc::new(AtomicU64::new(1)),
159 }
160 }
161
162 fn generate_id(&self) -> String {
164 let id = self.next_id.fetch_add(1, Ordering::Relaxed);
165 format!("task-{}", id)
166 }
167
168 pub fn create_task(
172 &self,
173 tool_name: &str,
174 arguments: serde_json::Value,
175 ttl: Option<u64>,
176 ) -> (String, CancellationToken) {
177 let id = self.generate_id();
178 let task = Task::new(id.clone(), tool_name.to_string(), arguments, ttl);
179 let token = task.cancellation_token.clone();
180
181 if let Ok(mut tasks) = self.tasks.write() {
182 tasks.insert(id.clone(), task);
183 }
184
185 (id, token)
186 }
187
188 pub fn get_task(&self, task_id: &str) -> Option<TaskInfo> {
190 if let Ok(tasks) = self.tasks.read() {
191 tasks.get(task_id).map(|t| t.to_info())
192 } else {
193 None
194 }
195 }
196
197 pub fn get_task_full(
199 &self,
200 task_id: &str,
201 ) -> Option<(TaskStatus, Option<CallToolResult>, Option<String>)> {
202 if let Ok(tasks) = self.tasks.read() {
203 tasks
204 .get(task_id)
205 .map(|t| (t.status, t.result.clone(), t.error.clone()))
206 } else {
207 None
208 }
209 }
210
211 pub fn list_tasks(&self, status_filter: Option<TaskStatus>) -> Vec<TaskInfo> {
213 if let Ok(tasks) = self.tasks.read() {
214 tasks
215 .values()
216 .filter(|t| status_filter.is_none() || status_filter == Some(t.status))
217 .map(|t| t.to_info())
218 .collect()
219 } else {
220 vec![]
221 }
222 }
223
224 pub fn update_progress(&self, task_id: &str, progress: f64, message: Option<String>) -> bool {
226 if let Ok(mut tasks) = self.tasks.write()
227 && let Some(task) = tasks.get_mut(task_id)
228 && !task.status.is_terminal()
229 {
230 task.progress = Some(progress);
231 if let Some(msg) = message {
232 task.message = Some(msg);
233 }
234 return true;
235 }
236 false
237 }
238
239 pub fn require_input(&self, task_id: &str, message: &str) -> bool {
241 if let Ok(mut tasks) = self.tasks.write()
242 && let Some(task) = tasks.get_mut(task_id)
243 && !task.status.is_terminal()
244 {
245 task.status = TaskStatus::InputRequired;
246 task.message = Some(message.to_string());
247 return true;
248 }
249 false
250 }
251
252 pub fn complete_task(&self, task_id: &str, result: CallToolResult) -> bool {
254 if let Ok(mut tasks) = self.tasks.write()
255 && let Some(task) = tasks.get_mut(task_id)
256 && !task.status.is_terminal()
257 {
258 task.status = TaskStatus::Completed;
259 task.progress = Some(100.0);
260 task.message = Some("Task completed".to_string());
261 task.result = Some(result);
262 task.completed_at = Some(Instant::now());
263 return true;
264 }
265 false
266 }
267
268 pub fn fail_task(&self, task_id: &str, error: &str) -> bool {
270 if let Ok(mut tasks) = self.tasks.write()
271 && let Some(task) = tasks.get_mut(task_id)
272 && !task.status.is_terminal()
273 {
274 task.status = TaskStatus::Failed;
275 task.message = Some(format!("Task failed: {}", error));
276 task.error = Some(error.to_string());
277 task.completed_at = Some(Instant::now());
278 return true;
279 }
280 false
281 }
282
283 pub fn cancel_task(&self, task_id: &str, reason: Option<&str>) -> Option<TaskStatus> {
285 if let Ok(mut tasks) = self.tasks.write()
286 && let Some(task) = tasks.get_mut(task_id)
287 {
288 task.cancellation_token.cancel();
290
291 if !task.status.is_terminal() {
293 task.status = TaskStatus::Cancelled;
294 task.message = Some(
295 reason
296 .map(|r| format!("Cancelled: {}", r))
297 .unwrap_or_else(|| "Task cancelled".to_string()),
298 );
299 task.completed_at = Some(Instant::now());
300 }
301 return Some(task.status);
302 }
303 None
304 }
305
306 pub fn cleanup_expired(&self) -> usize {
308 if let Ok(mut tasks) = self.tasks.write() {
309 let before = tasks.len();
310 tasks.retain(|_, t| !t.is_expired());
311 before - tasks.len()
312 } else {
313 0
314 }
315 }
316
317 #[cfg(test)]
319 pub fn len(&self) -> usize {
320 if let Ok(tasks) = self.tasks.read() {
321 tasks.len()
322 } else {
323 0
324 }
325 }
326
327 #[cfg(test)]
329 pub fn is_empty(&self) -> bool {
330 self.len() == 0
331 }
332}
333
/// Returns the current UTC wall-clock time as an ISO-8601 string with
/// millisecond precision, e.g. `2023-11-14T22:13:20.000Z`.
///
/// A system clock set before the Unix epoch is reported as the epoch itself.
fn chrono_now_iso8601() -> String {
    use std::time::SystemTime;

    let since_epoch = SystemTime::now()
        .duration_since(SystemTime::UNIX_EPOCH)
        .unwrap_or_default();
    format_unix_duration_iso8601(since_epoch)
}

/// Formats a duration since the Unix epoch as `YYYY-MM-DDTHH:MM:SS.mmmZ` (UTC).
///
/// Kept separate from [`chrono_now_iso8601`] so the calendar arithmetic can be
/// tested against fixed instants.
fn format_unix_duration_iso8601(since_epoch: Duration) -> String {
    let secs = since_epoch.as_secs();
    let millis = since_epoch.subsec_millis();

    let days = secs / 86_400;
    let secs_of_day = secs % 86_400;
    let hours = secs_of_day / 3600;
    let minutes = secs_of_day % 3600 / 60;
    let seconds = secs_of_day % 60;

    let (year, month, day) = civil_date_from_days(days);

    format!(
        "{:04}-{:02}-{:02}T{:02}:{:02}:{:02}.{:03}Z",
        year, month, day, hours, minutes, seconds, millis
    )
}

/// Converts a count of days since 1970-01-01 into a proleptic Gregorian
/// `(year, month, day)` triple; `month` and `day` are 1-based.
///
/// Day counts stay in `u64` throughout, so no value is truncated by a narrowing cast.
fn civil_date_from_days(days: u64) -> (i32, u32, u64) {
    let mut year = 1970i32;
    let mut remaining_days = days;

    loop {
        let days_in_year: u64 = if is_leap_year(year) { 366 } else { 365 };
        if remaining_days < days_in_year {
            break;
        }
        remaining_days -= days_in_year;
        year += 1;
    }

    let february = if is_leap_year(year) { 29 } else { 28 };
    let days_in_months: [u64; 12] = [31, february, 31, 30, 31, 30, 31, 31, 30, 31, 30, 31];

    let mut month = 1u32;
    for &days_in_month in &days_in_months {
        if remaining_days < days_in_month {
            break;
        }
        remaining_days -= days_in_month;
        month += 1;
    }

    (year, month, remaining_days + 1)
}

/// Gregorian leap-year rule: divisible by 4, except centuries not divisible by 400.
fn is_leap_year(year: i32) -> bool {
    (year % 4 == 0 && year % 100 != 0) || (year % 400 == 0)
}
395
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_create_task() {
        let store = TaskStore::new();
        let (id, token) = store.create_task("test-tool", serde_json::json!({"a": 1}), None);

        assert!(id.starts_with("task-"));
        assert!(!token.is_cancelled());

        let info = store.get_task(&id).expect("task should exist");
        assert_eq!(info.task_id, id);
        assert_eq!(info.status, TaskStatus::Working);
    }

    #[test]
    fn test_task_lifecycle() {
        let store = TaskStore::new();
        let (id, _) = store.create_task("test-tool", serde_json::json!({}), None);

        assert!(store.update_progress(&id, 50.0, Some("Halfway".to_string())));

        let info = store.get_task(&id).unwrap();
        assert_eq!(info.progress, Some(50.0));
        assert_eq!(info.message.as_deref(), Some("Halfway"));

        assert!(store.complete_task(&id, CallToolResult::text("Done")));

        let info = store.get_task(&id).unwrap();
        assert_eq!(info.status, TaskStatus::Completed);
        assert_eq!(info.progress, Some(100.0));
    }

    #[test]
    fn test_task_cancellation() {
        let store = TaskStore::new();
        let (id, token) = store.create_task("test-tool", serde_json::json!({}), None);

        assert!(!token.is_cancelled());

        let status = store.cancel_task(&id, Some("User requested"));
        assert_eq!(status, Some(TaskStatus::Cancelled));
        assert!(token.is_cancelled());

        let info = store.get_task(&id).unwrap();
        assert_eq!(info.status, TaskStatus::Cancelled);
    }

    #[test]
    fn test_cancel_terminal_task_keeps_status() {
        let store = TaskStore::new();
        let (id, token) = store.create_task("test-tool", serde_json::json!({}), None);
        assert!(store.complete_task(&id, CallToolResult::text("Done")));

        // A finished task keeps its status, but the token is still tripped.
        assert_eq!(store.cancel_task(&id, None), Some(TaskStatus::Completed));
        assert!(token.is_cancelled());
    }

    #[test]
    fn test_task_failure() {
        let store = TaskStore::new();
        let (id, _) = store.create_task("test-tool", serde_json::json!({}), None);

        assert!(store.fail_task(&id, "Something went wrong"));

        let info = store.get_task(&id).unwrap();
        assert_eq!(info.status, TaskStatus::Failed);
        assert!(info.message.as_ref().unwrap().contains("failed"));
    }

    #[test]
    fn test_require_input() {
        let store = TaskStore::new();
        let (id, _) = store.create_task("test-tool", serde_json::json!({}), None);

        assert!(store.require_input(&id, "Need a value"));

        let info = store.get_task(&id).unwrap();
        assert_eq!(info.status, TaskStatus::InputRequired);
        assert_eq!(info.message.as_deref(), Some("Need a value"));
    }

    #[test]
    fn test_unknown_task() {
        let store = TaskStore::new();

        assert!(store.get_task("task-missing").is_none());
        assert!(store.get_task_full("task-missing").is_none());
        assert_eq!(store.cancel_task("task-missing", None), None);
        assert!(!store.update_progress("task-missing", 1.0, None));
        assert!(!store.require_input("task-missing", "?"));
        assert!(!store.complete_task("task-missing", CallToolResult::text("x")));
        assert!(!store.fail_task("task-missing", "x"));
    }

    #[test]
    fn test_list_tasks() {
        let store = TaskStore::new();
        store.create_task("tool1", serde_json::json!({}), None);
        store.create_task("tool2", serde_json::json!({}), None);
        let (id3, _) = store.create_task("tool3", serde_json::json!({}), None);

        store.complete_task(&id3, CallToolResult::text("Done"));

        let all = store.list_tasks(None);
        assert_eq!(all.len(), 3);

        let working = store.list_tasks(Some(TaskStatus::Working));
        assert_eq!(working.len(), 2);

        let completed = store.list_tasks(Some(TaskStatus::Completed));
        assert_eq!(completed.len(), 1);
    }

    #[test]
    fn test_terminal_state_immutable() {
        let store = TaskStore::new();
        let (id, _) = store.create_task("test-tool", serde_json::json!({}), None);

        store.complete_task(&id, CallToolResult::text("Done"));

        assert!(!store.update_progress(&id, 50.0, None));
        assert!(!store.fail_task(&id, "Error"));

        let info = store.get_task(&id).unwrap();
        assert_eq!(info.status, TaskStatus::Completed);
    }

    #[test]
    fn test_cleanup_expired() {
        let store = TaskStore::new();
        let (active, _) = store.create_task("tool", serde_json::json!({}), None);
        let (done, _) = store.create_task("tool", serde_json::json!({}), Some(0));
        assert!(store.complete_task(&done, CallToolResult::text("Done")));

        // A zero TTL expires as soon as any time has elapsed since completion.
        std::thread::sleep(Duration::from_millis(5));

        assert_eq!(store.cleanup_expired(), 1);
        assert_eq!(store.len(), 1);
        assert!(!store.is_empty());
        assert!(store.get_task(&active).is_some());
        assert!(store.get_task(&done).is_none());
    }

    #[test]
    fn test_task_ids_unique() {
        let store = TaskStore::new();
        let (id1, _) = store.create_task("tool", serde_json::json!({}), None);
        let (id2, _) = store.create_task("tool", serde_json::json!({}), None);
        let (id3, _) = store.create_task("tool", serde_json::json!({}), None);

        assert_ne!(id1, id2);
        assert_ne!(id2, id3);
        assert_ne!(id1, id3);
    }

    #[test]
    fn test_get_task_full() {
        let store = TaskStore::new();
        let (id, _) = store.create_task("test-tool", serde_json::json!({}), None);

        let result = CallToolResult::text("The result");
        store.complete_task(&id, result);

        let (status, result, error) = store.get_task_full(&id).unwrap();
        assert_eq!(status, TaskStatus::Completed);
        assert!(result.is_some());
        assert!(error.is_none());
    }

    #[test]
    fn test_iso8601_timestamp() {
        let ts = chrono_now_iso8601();
        assert!(ts.ends_with('Z'));
        assert!(ts.contains('T'));
        assert_eq!(ts.len(), 24);
    }

    #[test]
    fn test_task_status_display() {
        assert_eq!(TaskStatus::Working.to_string(), "working");
        assert_eq!(TaskStatus::InputRequired.to_string(), "input_required");
        assert_eq!(TaskStatus::Completed.to_string(), "completed");
        assert_eq!(TaskStatus::Failed.to_string(), "failed");
        assert_eq!(TaskStatus::Cancelled.to_string(), "cancelled");
    }

    #[test]
    fn test_task_status_is_terminal() {
        assert!(!TaskStatus::Working.is_terminal());
        assert!(!TaskStatus::InputRequired.is_terminal());
        assert!(TaskStatus::Completed.is_terminal());
        assert!(TaskStatus::Failed.is_terminal());
        assert!(TaskStatus::Cancelled.is_terminal());
    }
}
553}