1use std::collections::HashMap;
27use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
28use std::sync::{Arc, RwLock};
29use std::time::{Duration, Instant};
30
31use crate::protocol::{CallToolResult, TaskInfo, TaskStatus};
32
/// Retention, in seconds, applied to a task after it reaches a terminal state when the
/// caller of `create_task` does not supply its own TTL (five minutes).
const DEFAULT_TTL_SECS: u64 = 5 * 60;

/// Polling interval, in seconds, reported to clients for every newly created task.
const DEFAULT_POLL_INTERVAL_SECS: u64 = 2;
38
39#[derive(Debug)]
41pub struct Task {
42 pub id: String,
44 pub tool_name: String,
46 pub arguments: serde_json::Value,
48 pub status: TaskStatus,
50 pub created_at: Instant,
52 pub created_at_str: String,
54 pub ttl: u64,
56 pub poll_interval: u64,
58 pub progress: Option<f64>,
60 pub message: Option<String>,
62 pub result: Option<CallToolResult>,
64 pub error: Option<String>,
66 pub cancellation_token: CancellationToken,
68 pub completed_at: Option<Instant>,
70}
71
72impl Task {
73 fn new(id: String, tool_name: String, arguments: serde_json::Value, ttl: Option<u64>) -> Self {
75 let cancelled = Arc::new(AtomicBool::new(false));
76 Self {
77 id,
78 tool_name,
79 arguments,
80 status: TaskStatus::Working,
81 created_at: Instant::now(),
82 created_at_str: chrono_now_iso8601(),
83 ttl: ttl.unwrap_or(DEFAULT_TTL_SECS),
84 poll_interval: DEFAULT_POLL_INTERVAL_SECS,
85 progress: None,
86 message: Some("Task started".to_string()),
87 result: None,
88 error: None,
89 cancellation_token: CancellationToken { cancelled },
90 completed_at: None,
91 }
92 }
93
94 pub fn to_info(&self) -> TaskInfo {
96 TaskInfo {
97 task_id: self.id.clone(),
98 status: self.status,
99 created_at: self.created_at_str.clone(),
100 ttl: Some(self.ttl),
101 poll_interval: Some(self.poll_interval),
102 progress: self.progress,
103 message: self.message.clone(),
104 }
105 }
106
107 pub fn is_expired(&self) -> bool {
109 if let Some(completed_at) = self.completed_at {
110 completed_at.elapsed() > Duration::from_secs(self.ttl)
111 } else {
112 false
113 }
114 }
115
116 pub fn is_cancelled(&self) -> bool {
118 self.cancellation_token.is_cancelled()
119 }
120}
121
/// Cheaply clonable, thread-safe cancellation flag shared between a task record and the
/// worker executing it.
///
/// All clones share one flag: after any clone calls [`cancel`](Self::cancel), every clone
/// reports [`is_cancelled`](Self::is_cancelled) as `true`. There is no way to reset it.
#[derive(Debug, Clone, Default)]
pub struct CancellationToken {
    cancelled: Arc<AtomicBool>,
}

impl CancellationToken {
    /// Creates a token that has not been cancelled.
    pub fn new() -> Self {
        Self::default()
    }

    /// Returns `true` once [`cancel`](Self::cancel) has been called on this token or a clone.
    ///
    /// The `Acquire` load pairs with the `Release` store in `cancel`, so anything the
    /// cancelling thread wrote before cancelling is visible to a thread that sees `true`.
    pub fn is_cancelled(&self) -> bool {
        self.cancelled.load(Ordering::Acquire)
    }

    /// Requests cancellation. Idempotent: cancelling an already-cancelled token is a no-op.
    pub fn cancel(&self) {
        self.cancelled.store(true, Ordering::Release);
    }
}
139
140#[derive(Debug, Clone)]
142pub struct TaskStore {
143 tasks: Arc<RwLock<HashMap<String, Task>>>,
144 next_id: Arc<AtomicU64>,
145}
146
147impl Default for TaskStore {
148 fn default() -> Self {
149 Self::new()
150 }
151}
152
153impl TaskStore {
154 pub fn new() -> Self {
156 Self {
157 tasks: Arc::new(RwLock::new(HashMap::new())),
158 next_id: Arc::new(AtomicU64::new(1)),
159 }
160 }
161
162 fn generate_id(&self) -> String {
164 let id = self.next_id.fetch_add(1, Ordering::Relaxed);
165 format!("task-{}", id)
166 }
167
168 pub fn create_task(
172 &self,
173 tool_name: &str,
174 arguments: serde_json::Value,
175 ttl: Option<u64>,
176 ) -> (String, CancellationToken) {
177 let id = self.generate_id();
178 let task = Task::new(id.clone(), tool_name.to_string(), arguments, ttl);
179 let token = task.cancellation_token.clone();
180
181 if let Ok(mut tasks) = self.tasks.write() {
182 tasks.insert(id.clone(), task);
183 }
184
185 (id, token)
186 }
187
188 pub fn get_task(&self, task_id: &str) -> Option<TaskInfo> {
190 if let Ok(tasks) = self.tasks.read() {
191 tasks.get(task_id).map(|t| t.to_info())
192 } else {
193 None
194 }
195 }
196
197 pub fn get_task_full(
199 &self,
200 task_id: &str,
201 ) -> Option<(TaskStatus, Option<CallToolResult>, Option<String>)> {
202 if let Ok(tasks) = self.tasks.read() {
203 tasks
204 .get(task_id)
205 .map(|t| (t.status, t.result.clone(), t.error.clone()))
206 } else {
207 None
208 }
209 }
210
211 pub fn list_tasks(&self, status_filter: Option<TaskStatus>) -> Vec<TaskInfo> {
213 if let Ok(tasks) = self.tasks.read() {
214 tasks
215 .values()
216 .filter(|t| status_filter.is_none() || status_filter == Some(t.status))
217 .map(|t| t.to_info())
218 .collect()
219 } else {
220 vec![]
221 }
222 }
223
224 pub fn update_progress(&self, task_id: &str, progress: f64, message: Option<String>) -> bool {
226 let Ok(mut tasks) = self.tasks.write() else {
227 return false;
228 };
229 let Some(task) = tasks.get_mut(task_id) else {
230 return false;
231 };
232 if task.status.is_terminal() {
233 return false;
234 }
235 task.progress = Some(progress);
236 if let Some(msg) = message {
237 task.message = Some(msg);
238 }
239 true
240 }
241
242 pub fn require_input(&self, task_id: &str, message: &str) -> bool {
244 let Ok(mut tasks) = self.tasks.write() else {
245 return false;
246 };
247 let Some(task) = tasks.get_mut(task_id) else {
248 return false;
249 };
250 if task.status.is_terminal() {
251 return false;
252 }
253 task.status = TaskStatus::InputRequired;
254 task.message = Some(message.to_string());
255 true
256 }
257
258 pub fn complete_task(&self, task_id: &str, result: CallToolResult) -> bool {
260 let Ok(mut tasks) = self.tasks.write() else {
261 return false;
262 };
263 let Some(task) = tasks.get_mut(task_id) else {
264 return false;
265 };
266 if task.status.is_terminal() {
267 return false;
268 }
269 task.status = TaskStatus::Completed;
270 task.progress = Some(100.0);
271 task.message = Some("Task completed".to_string());
272 task.result = Some(result);
273 task.completed_at = Some(Instant::now());
274 true
275 }
276
277 pub fn fail_task(&self, task_id: &str, error: &str) -> bool {
279 let Ok(mut tasks) = self.tasks.write() else {
280 return false;
281 };
282 let Some(task) = tasks.get_mut(task_id) else {
283 return false;
284 };
285 if task.status.is_terminal() {
286 return false;
287 }
288 task.status = TaskStatus::Failed;
289 task.message = Some(format!("Task failed: {}", error));
290 task.error = Some(error.to_string());
291 task.completed_at = Some(Instant::now());
292 true
293 }
294
295 pub fn cancel_task(&self, task_id: &str, reason: Option<&str>) -> Option<TaskStatus> {
297 let mut tasks = self.tasks.write().ok()?;
298 let task = tasks.get_mut(task_id)?;
299
300 task.cancellation_token.cancel();
302
303 if !task.status.is_terminal() {
305 task.status = TaskStatus::Cancelled;
306 task.message = Some(
307 reason
308 .map(|r| format!("Cancelled: {}", r))
309 .unwrap_or_else(|| "Task cancelled".to_string()),
310 );
311 task.completed_at = Some(Instant::now());
312 }
313 Some(task.status)
314 }
315
316 pub fn cleanup_expired(&self) -> usize {
318 if let Ok(mut tasks) = self.tasks.write() {
319 let before = tasks.len();
320 tasks.retain(|_, t| !t.is_expired());
321 before - tasks.len()
322 } else {
323 0
324 }
325 }
326
327 #[cfg(test)]
329 pub fn len(&self) -> usize {
330 if let Ok(tasks) = self.tasks.read() {
331 tasks.len()
332 } else {
333 0
334 }
335 }
336
337 #[cfg(test)]
339 pub fn is_empty(&self) -> bool {
340 self.len() == 0
341 }
342}
343
/// Returns the current wall-clock time as a UTC ISO 8601 timestamp with millisecond
/// precision, e.g. `2024-02-29T23:59:59.123Z` (always 24 characters for years 1970-9999).
///
/// A system clock set before the Unix epoch is reported as the epoch itself.
fn chrono_now_iso8601() -> String {
    use std::time::SystemTime;

    let since_epoch = SystemTime::now()
        .duration_since(SystemTime::UNIX_EPOCH)
        .unwrap_or_default();
    format_iso8601_utc(since_epoch)
}

/// Formats a duration since the Unix epoch as `YYYY-MM-DDTHH:MM:SS.mmmZ` in UTC using the
/// Gregorian calendar. Sub-millisecond precision is truncated.
fn format_iso8601_utc(since_epoch: Duration) -> String {
    const SECS_PER_DAY: u64 = 86_400;

    let total_secs = since_epoch.as_secs();
    let millis = since_epoch.subsec_millis();

    let secs_of_day = total_secs % SECS_PER_DAY;
    let hours = secs_of_day / 3_600;
    let minutes = secs_of_day % 3_600 / 60;
    let seconds = secs_of_day % 60;

    // Day arithmetic stays in u64 (no truncating cast). Peel off whole years from 1970, then
    // whole months; whatever remains is the zero-based day of the month.
    let mut remaining_days = total_secs / SECS_PER_DAY;
    let mut year = 1970i32;
    loop {
        let days_in_year = if is_leap_year(year) { 366 } else { 365 };
        if remaining_days < days_in_year {
            break;
        }
        remaining_days -= days_in_year;
        year += 1;
    }

    let february = if is_leap_year(year) { 29 } else { 28 };
    let days_in_months: [u64; 12] = [31, february, 31, 30, 31, 30, 31, 31, 30, 31, 30, 31];

    // `remaining_days` is below the year's length here, so this stops by December at latest.
    let mut month = 1;
    for &days_in_month in days_in_months.iter() {
        if remaining_days < days_in_month {
            break;
        }
        remaining_days -= days_in_month;
        month += 1;
    }
    let day = remaining_days + 1;

    format!(
        "{:04}-{:02}-{:02}T{:02}:{:02}:{:02}.{:03}Z",
        year, month, day, hours, minutes, seconds, millis
    )
}

/// Gregorian leap-year rule: divisible by 4, except centuries not divisible by 400.
fn is_leap_year(year: i32) -> bool {
    (year % 4 == 0 && year % 100 != 0) || (year % 400 == 0)
}
405
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_create_task() {
        let store = TaskStore::new();
        let (id, token) = store.create_task("test-tool", serde_json::json!({"a": 1}), None);

        assert!(id.starts_with("task-"));
        assert!(!token.is_cancelled());

        let info = store.get_task(&id).expect("task should exist");
        assert_eq!(info.task_id, id);
        assert_eq!(info.status, TaskStatus::Working);
    }

    #[test]
    fn test_task_lifecycle() {
        let store = TaskStore::new();
        let (id, _) = store.create_task("test-tool", serde_json::json!({}), None);

        assert!(store.update_progress(&id, 50.0, Some("Halfway".to_string())));

        let info = store.get_task(&id).unwrap();
        assert_eq!(info.progress, Some(50.0));
        assert_eq!(info.message.as_deref(), Some("Halfway"));

        assert!(store.complete_task(&id, CallToolResult::text("Done")));

        let info = store.get_task(&id).unwrap();
        assert_eq!(info.status, TaskStatus::Completed);
        assert_eq!(info.progress, Some(100.0));
    }

    #[test]
    fn test_task_cancellation() {
        let store = TaskStore::new();
        let (id, token) = store.create_task("test-tool", serde_json::json!({}), None);

        assert!(!token.is_cancelled());

        let status = store.cancel_task(&id, Some("User requested"));
        assert_eq!(status, Some(TaskStatus::Cancelled));
        assert!(token.is_cancelled());

        let info = store.get_task(&id).unwrap();
        assert_eq!(info.status, TaskStatus::Cancelled);
    }

    #[test]
    fn test_task_failure() {
        let store = TaskStore::new();
        let (id, _) = store.create_task("test-tool", serde_json::json!({}), None);

        assert!(store.fail_task(&id, "Something went wrong"));

        let info = store.get_task(&id).unwrap();
        assert_eq!(info.status, TaskStatus::Failed);
        assert!(info.message.as_ref().unwrap().contains("failed"));
    }

    #[test]
    fn test_list_tasks() {
        let store = TaskStore::new();
        store.create_task("tool1", serde_json::json!({}), None);
        store.create_task("tool2", serde_json::json!({}), None);
        let (id3, _) = store.create_task("tool3", serde_json::json!({}), None);

        store.complete_task(&id3, CallToolResult::text("Done"));

        let all = store.list_tasks(None);
        assert_eq!(all.len(), 3);

        let working = store.list_tasks(Some(TaskStatus::Working));
        assert_eq!(working.len(), 2);

        let completed = store.list_tasks(Some(TaskStatus::Completed));
        assert_eq!(completed.len(), 1);
    }

    #[test]
    fn test_terminal_state_immutable() {
        let store = TaskStore::new();
        let (id, _) = store.create_task("test-tool", serde_json::json!({}), None);

        store.complete_task(&id, CallToolResult::text("Done"));

        assert!(!store.update_progress(&id, 50.0, None));
        assert!(!store.fail_task(&id, "Error"));

        let info = store.get_task(&id).unwrap();
        assert_eq!(info.status, TaskStatus::Completed);
    }

    #[test]
    fn test_task_ids_unique() {
        let store = TaskStore::new();
        let (id1, _) = store.create_task("tool", serde_json::json!({}), None);
        let (id2, _) = store.create_task("tool", serde_json::json!({}), None);
        let (id3, _) = store.create_task("tool", serde_json::json!({}), None);

        assert_ne!(id1, id2);
        assert_ne!(id2, id3);
        assert_ne!(id1, id3);
    }

    #[test]
    fn test_get_task_full() {
        let store = TaskStore::new();
        let (id, _) = store.create_task("test-tool", serde_json::json!({}), None);

        let result = CallToolResult::text("The result");
        store.complete_task(&id, result);

        let (status, result, error) = store.get_task_full(&id).unwrap();
        assert_eq!(status, TaskStatus::Completed);
        assert!(result.is_some());
        assert!(error.is_none());
    }

    #[test]
    fn test_iso8601_timestamp() {
        let ts = chrono_now_iso8601();
        assert!(ts.ends_with('Z'));
        assert!(ts.contains('T'));
        // `YYYY-MM-DDTHH:MM:SS.mmmZ`
        assert_eq!(ts.len(), 24);
    }

    #[test]
    fn test_is_leap_year() {
        assert!(is_leap_year(2000));
        assert!(is_leap_year(2024));
        assert!(!is_leap_year(1900));
        assert!(!is_leap_year(2023));
    }

    #[test]
    fn test_task_status_display() {
        assert_eq!(TaskStatus::Working.to_string(), "working");
        assert_eq!(TaskStatus::InputRequired.to_string(), "input_required");
        assert_eq!(TaskStatus::Completed.to_string(), "completed");
        assert_eq!(TaskStatus::Failed.to_string(), "failed");
        assert_eq!(TaskStatus::Cancelled.to_string(), "cancelled");
    }

    #[test]
    fn test_task_status_is_terminal() {
        assert!(!TaskStatus::Working.is_terminal());
        assert!(!TaskStatus::InputRequired.is_terminal());
        assert!(TaskStatus::Completed.is_terminal());
        assert!(TaskStatus::Failed.is_terminal());
        assert!(TaskStatus::Cancelled.is_terminal());
    }

    #[test]
    fn test_require_input() {
        let store = TaskStore::new();
        let (id, _) = store.create_task("test-tool", serde_json::json!({}), None);

        assert!(store.require_input(&id, "Need confirmation"));
        let info = store.get_task(&id).unwrap();
        assert_eq!(info.status, TaskStatus::InputRequired);
        assert_eq!(info.message.as_deref(), Some("Need confirmation"));

        // InputRequired is not terminal, so the task can still finish...
        assert!(store.complete_task(&id, CallToolResult::text("Done")));
        // ...but a finished task cannot be moved back to InputRequired.
        assert!(!store.require_input(&id, "Too late"));
        assert_eq!(store.get_task(&id).unwrap().status, TaskStatus::Completed);
    }

    #[test]
    fn test_cancel_finished_task_keeps_status() {
        let store = TaskStore::new();
        let (id, _) = store.create_task("test-tool", serde_json::json!({}), None);

        assert!(store.complete_task(&id, CallToolResult::text("Done")));
        assert_eq!(store.cancel_task(&id, None), Some(TaskStatus::Completed));

        let (status, result, _) = store.get_task_full(&id).unwrap();
        assert_eq!(status, TaskStatus::Completed);
        assert!(result.is_some());
    }

    #[test]
    fn test_unknown_task_id() {
        let store = TaskStore::new();
        let missing = "task-missing";

        assert!(store.get_task(missing).is_none());
        assert!(store.get_task_full(missing).is_none());
        assert!(!store.update_progress(missing, 10.0, None));
        assert!(!store.require_input(missing, "?"));
        assert!(!store.complete_task(missing, CallToolResult::text("x")));
        assert!(!store.fail_task(missing, "x"));
        assert_eq!(store.cancel_task(missing, None), None);
        assert!(store.is_empty());
    }

    #[test]
    fn test_custom_ttl_reported() {
        let store = TaskStore::new();
        let (id, _) = store.create_task("test-tool", serde_json::json!({}), Some(42));

        let info = store.get_task(&id).unwrap();
        assert_eq!(info.ttl, Some(42));
        assert_eq!(info.poll_interval, Some(DEFAULT_POLL_INTERVAL_SECS));
    }

    #[test]
    fn test_cleanup_expired_only_removes_finished_tasks() {
        let store = TaskStore::new();
        let (finished, _) = store.create_task("tool", serde_json::json!({}), Some(0));
        let (running, _) = store.create_task("tool", serde_json::json!({}), Some(0));

        assert!(store.complete_task(&finished, CallToolResult::text("Done")));

        // Expiry uses a strict `>` on elapsed time, so even ttl = 0 needs a moment to pass.
        std::thread::sleep(Duration::from_millis(10));

        assert_eq!(store.cleanup_expired(), 1);
        assert!(store.get_task(&finished).is_none());
        assert!(store.get_task(&running).is_some());
        assert_eq!(store.len(), 1);
    }
}