Skip to main content

a2a_rs_server/
task_store.rs

//! In-memory task store
//!
//! Provides thread-safe storage for A2A tasks.
4
5use a2a_rs_core::{ListTasksRequest, Task, TaskListResponse};
6use std::collections::HashMap;
7use std::sync::Arc;
8use tokio::sync::RwLock;
9
/// Thread-safe in-memory task store
///
/// Tasks live in a `HashMap` keyed by `String`, guarded by an async
/// `tokio::sync::RwLock` and shared through an `Arc`. Because the map sits
/// behind an `Arc`, cloning a `TaskStore` yields another handle to the same
/// underlying map rather than an independent copy of the tasks.
#[derive(Clone)]
pub struct TaskStore {
    // Shared, lock-protected map of stored tasks.
    tasks: Arc<RwLock<HashMap<String, Task>>>,
}
15
16impl Default for TaskStore {
17    fn default() -> Self {
18        Self::new()
19    }
20}
21
22impl TaskStore {
23    /// Create a new empty task store
24    pub fn new() -> Self {
25        Self {
26            tasks: Arc::new(RwLock::new(HashMap::new())),
27        }
28    }
29
30    /// Insert or update a task
31    pub async fn insert(&self, task: Task) {
32        let id = task.id.clone();
33        self.tasks.write().await.insert(id, task);
34    }
35
36    /// Get a task by ID
37    pub async fn get(&self, id: &str) -> Option<Task> {
38        self.tasks.read().await.get(id).cloned()
39    }
40
41    /// Get a task, trying multiple key formats
42    ///
43    /// Tries: exact match, with "tasks/" prefix, without "tasks/" prefix
44    pub async fn get_flexible(&self, id: &str) -> Option<Task> {
45        let guard = self.tasks.read().await;
46
47        // Try exact match
48        if let Some(task) = guard.get(id) {
49            return Some(task.clone());
50        }
51
52        // Try with "tasks/" prefix
53        let prefixed = format!("tasks/{}", id);
54        if let Some(task) = guard.get(&prefixed) {
55            return Some(task.clone());
56        }
57
58        // Try without "tasks/" prefix
59        if let Some(stripped) = id.strip_prefix("tasks/") {
60            if let Some(task) = guard.get(stripped) {
61                return Some(task.clone());
62            }
63        }
64
65        None
66    }
67
68    /// Update a task's state
69    pub async fn update<F>(&self, id: &str, f: F) -> Option<Task>
70    where
71        F: FnOnce(&mut Task),
72    {
73        let mut guard = self.tasks.write().await;
74        if let Some(task) = guard.get_mut(id) {
75            f(task);
76            Some(task.clone())
77        } else {
78            None
79        }
80    }
81
82    /// Update a task with a fallible closure, trying multiple key formats
83    ///
84    /// Returns:
85    /// - `None` if task not found
86    /// - `Some(Err(e))` if closure returned error
87    /// - `Some(Ok(task))` if update succeeded
88    pub async fn update_flexible<F, E>(&self, id: &str, f: F) -> Option<Result<Task, E>>
89    where
90        F: FnOnce(&mut Task) -> Result<(), E>,
91    {
92        let mut guard = self.tasks.write().await;
93
94        // Try to find the task with flexible key matching
95        let key = if guard.contains_key(id) {
96            Some(id.to_string())
97        } else {
98            let prefixed = format!("tasks/{}", id);
99            if guard.contains_key(&prefixed) {
100                Some(prefixed)
101            } else if let Some(stripped) = id.strip_prefix("tasks/") {
102                if guard.contains_key(stripped) {
103                    Some(stripped.to_string())
104                } else {
105                    None
106                }
107            } else {
108                None
109            }
110        };
111
112        let key = key?;
113        let task = guard.get_mut(&key)?;
114
115        match f(task) {
116            Ok(()) => Some(Ok(task.clone())),
117            Err(e) => Some(Err(e)),
118        }
119    }
120
121    /// Remove a task
122    pub async fn remove(&self, id: &str) -> Option<Task> {
123        self.tasks.write().await.remove(id)
124    }
125
126    /// List all tasks
127    pub async fn list(&self) -> Vec<Task> {
128        self.tasks.read().await.values().cloned().collect()
129    }
130
131    /// List tasks with filtering and pagination
132    ///
133    /// Returns a TaskListResponse with filtered tasks and pagination info.
134    pub async fn list_filtered(&self, params: &ListTasksRequest) -> TaskListResponse {
135        let guard = self.tasks.read().await;
136
137        // Apply filters
138        let mut filtered: Vec<_> = guard
139            .values()
140            .filter(|task| {
141                // Filter by context_id
142                if let Some(ref ctx) = params.context_id {
143                    if task.context_id != *ctx {
144                        return false;
145                    }
146                }
147                // Filter by status
148                if let Some(status) = params.status {
149                    if task.status.state != status {
150                        return false;
151                    }
152                }
153                // Filter by status_timestamp_after (milliseconds since epoch)
154                if let Some(after_ms) = params.status_timestamp_after {
155                    if let Some(ref ts) = task.status.timestamp {
156                        // Parse ISO8601 timestamp and compare
157                        if let Ok(dt) = chrono::DateTime::parse_from_rfc3339(ts) {
158                            if dt.timestamp_millis() <= after_ms {
159                                return false;
160                            }
161                        }
162                    } else {
163                        // No timestamp means we can't determine if it's after, exclude it
164                        return false;
165                    }
166                }
167                true
168            })
169            .cloned()
170            .collect();
171
172        let total_size = filtered.len() as u32;
173        let page_size = params.page_size.unwrap_or(50).min(100);
174
175        // Sort by task ID for consistent pagination (could be timestamp-based in production)
176        filtered.sort_by(|a, b| a.id.cmp(&b.id));
177
178        // Apply pagination using page_token as offset
179        let offset: usize = params
180            .page_token
181            .as_ref()
182            .and_then(|t| t.parse().ok())
183            .unwrap_or(0);
184
185        let paginated: Vec<_> = filtered
186            .into_iter()
187            .skip(offset)
188            .take(page_size as usize)
189            .map(|mut task| {
190                // Optionally trim history
191                if let Some(len) = params.history_length {
192                    if let Some(ref mut history) = task.history {
193                        let keep = len as usize;
194                        if history.len() > keep {
195                            *history = history.iter().rev().take(keep).cloned().collect();
196                            history.reverse();
197                        }
198                    }
199                }
200                // Optionally exclude artifacts
201                if params.include_artifacts == Some(false) {
202                    task.artifacts = None;
203                }
204                task
205            })
206            .collect();
207
208        let next_offset = offset + paginated.len();
209        let next_page_token = if next_offset < total_size as usize {
210            next_offset.to_string()
211        } else {
212            String::new()
213        };
214
215        TaskListResponse {
216            tasks: paginated,
217            next_page_token,
218            page_size,
219            total_size,
220        }
221    }
222
223    /// Get the number of stored tasks
224    pub async fn len(&self) -> usize {
225        self.tasks.read().await.len()
226    }
227
228    /// Check if the store is empty
229    pub async fn is_empty(&self) -> bool {
230        self.tasks.read().await.is_empty()
231    }
232}
233
#[cfg(test)]
mod tests {
    use super::*;
    use a2a_rs_core::{TaskState, TaskStatus};

    /// Build a minimal task in the `Working` state with the given id and a
    /// fixed context id of `"ctx"`.
    fn make_task(id: &str) -> Task {
        Task {
            kind: "task".to_string(),
            id: id.to_string(),
            context_id: "ctx".to_string(),
            status: TaskStatus {
                state: TaskState::Working,
                message: None,
                timestamp: None,
            },
            history: None,
            artifacts: None,
            metadata: None,
        }
    }

    #[tokio::test]
    async fn test_insert_and_get() {
        let store = TaskStore::new();
        assert!(store.is_empty().await);

        store.insert(make_task("task-1")).await;

        let retrieved = store.get("task-1").await.expect("task was just inserted");
        assert_eq!(retrieved.id, "task-1");
        assert_eq!(store.len().await, 1);
        assert!(!store.is_empty().await);
    }

    #[tokio::test]
    async fn test_get_missing_returns_none() {
        let store = TaskStore::new();
        store.insert(make_task("task-1")).await;

        // `get` is exact-match only.
        assert!(store.get("task-2").await.is_none());
        assert!(store.get("tasks/task-1").await.is_none());
    }

    #[tokio::test]
    async fn test_get_flexible() {
        let store = TaskStore::new();
        store.insert(make_task("tasks/abc-123")).await;

        // Exact match
        assert!(store.get_flexible("tasks/abc-123").await.is_some());

        // Without prefix
        assert!(store.get_flexible("abc-123").await.is_some());

        // Unknown id in any format
        assert!(store.get_flexible("nope").await.is_none());
        assert!(store.get_flexible("tasks/nope").await.is_none());
    }

    #[tokio::test]
    async fn test_get_flexible_strips_prefix() {
        let store = TaskStore::new();
        // Stored WITHOUT the prefix, requested WITH it.
        store.insert(make_task("abc-123")).await;

        let found = store
            .get_flexible("tasks/abc-123")
            .await
            .expect("prefixed id should resolve to the unprefixed key");
        assert_eq!(found.id, "abc-123");
    }

    #[tokio::test]
    async fn test_update() {
        let store = TaskStore::new();
        store.insert(make_task("task-1")).await;

        let updated = store
            .update("task-1", |t| {
                t.status.state = TaskState::Completed;
            })
            .await
            .expect("task exists");
        assert_eq!(updated.status.state, TaskState::Completed);

        // The change must be visible in the store, not just in the returned clone.
        let stored = store.get("task-1").await.expect("task exists");
        assert_eq!(stored.status.state, TaskState::Completed);
    }

    #[tokio::test]
    async fn test_update_missing_returns_none() {
        let store = TaskStore::new();

        let result = store
            .update("missing", |t| {
                t.status.state = TaskState::Completed;
            })
            .await;

        assert!(result.is_none());
        assert!(store.is_empty().await);
    }

    #[tokio::test]
    async fn test_update_flexible_not_found_and_error() {
        let store = TaskStore::new();
        store.insert(make_task("task-1")).await;

        // Not found: closure result is irrelevant, outer Option is None.
        let missing = store
            .update_flexible("missing", |_| -> Result<(), ()> { Ok(()) })
            .await;
        assert!(missing.is_none());

        // Found, but the closure fails: the error is passed through.
        let failed = store
            .update_flexible("task-1", |_| -> Result<(), &'static str> { Err("boom") })
            .await;
        assert!(matches!(failed, Some(Err("boom"))));

        // Found via the prefixed form of the id, closure succeeds.
        let ok = store
            .update_flexible("tasks/task-1", |t| -> Result<(), ()> {
                t.status.state = TaskState::Completed;
                Ok(())
            })
            .await;
        assert!(matches!(ok, Some(Ok(_))));
        let stored = store.get("task-1").await.expect("task exists");
        assert_eq!(stored.status.state, TaskState::Completed);
    }

    #[tokio::test]
    async fn test_remove() {
        let store = TaskStore::new();
        store.insert(make_task("task-1")).await;

        let removed = store.remove("task-1").await.expect("task exists");
        assert_eq!(removed.id, "task-1");

        assert!(store.get("task-1").await.is_none());
        assert!(store.remove("task-1").await.is_none());
        assert!(store.is_empty().await);
    }

    #[tokio::test]
    async fn test_concurrent_inserts() {
        // Clones of the store share the same underlying map.
        let store = TaskStore::new();

        // Spawn 100 concurrent inserts
        let handles: Vec<_> = (0..100)
            .map(|i| {
                let store = store.clone();
                tokio::spawn(async move {
                    store.insert(make_task(&format!("task-{}", i))).await;
                })
            })
            .collect();

        for h in handles {
            h.await.unwrap();
        }

        assert_eq!(store.len().await, 100);
    }

    #[tokio::test]
    async fn test_concurrent_reads_and_writes() {
        let store = TaskStore::new();

        // Pre-populate with some tasks
        for i in 0..10 {
            store.insert(make_task(&format!("task-{}", i))).await;
        }

        // Spawn concurrent readers and writers
        let mut handles = Vec::new();

        // Writers
        for i in 10..60 {
            let store = store.clone();
            handles.push(tokio::spawn(async move {
                store.insert(make_task(&format!("task-{}", i))).await;
            }));
        }

        // Readers
        for i in 0..10 {
            let store = store.clone();
            handles.push(tokio::spawn(async move {
                let _ = store.get(&format!("task-{}", i)).await;
            }));
        }

        // Updaters
        for i in 0..10 {
            let store = store.clone();
            handles.push(tokio::spawn(async move {
                store
                    .update(&format!("task-{}", i), |t| {
                        t.status.state = TaskState::Completed;
                    })
                    .await;
            }));
        }

        for h in handles {
            h.await.unwrap();
        }

        // Should have 60 tasks (10 original + 50 new)
        assert_eq!(store.len().await, 60);

        // All original tasks should be completed
        for i in 0..10 {
            let task = store.get(&format!("task-{}", i)).await.unwrap();
            assert_eq!(task.status.state, TaskState::Completed);
        }
    }

    #[tokio::test]
    async fn test_concurrent_update_flexible() {
        let store = TaskStore::new();
        store.insert(make_task("tasks/shared-task")).await;

        // Spawn concurrent updates on the same task
        let handles: Vec<_> = (0..50)
            .map(|_| {
                let store = store.clone();
                tokio::spawn(async move {
                    store
                        .update_flexible("shared-task", |t| -> Result<(), ()> {
                            t.context_id = "updated".to_string();
                            Ok(())
                        })
                        .await
                })
            })
            .collect();

        for h in handles {
            let result = h.await.unwrap();
            assert!(matches!(result, Some(Ok(_))));
        }

        // Task should exist and be updated
        let task = store.get("tasks/shared-task").await.unwrap();
        assert_eq!(task.context_id, "updated");
    }
}
402}