Skip to main content

a2a_rs_server/
task_store.rs

//! In-memory task store.
//!
//! Provides thread-safe storage for A2A tasks, shared between clones of
//! [`TaskStore`]. Tasks live only in process memory and are not persisted.
4
5use a2a_rs_core::{Task, ListTasksRequest, TaskListResponse};
6use std::collections::HashMap;
7use std::sync::Arc;
8use tokio::sync::RwLock;
9
/// Thread-safe in-memory task store.
///
/// Tasks are held in a `HashMap` keyed by task ID behind a `tokio::sync::RwLock`.
/// Cloning a `TaskStore` is cheap: clones share the same underlying map through
/// an `Arc`, so a task inserted through one clone is visible through all others.
#[derive(Clone)]
pub struct TaskStore {
    // Shared, lock-protected map from task ID to task.
    tasks: Arc<RwLock<HashMap<String, Task>>>,
}
15
16impl Default for TaskStore {
17    fn default() -> Self {
18        Self::new()
19    }
20}
21
22impl TaskStore {
23    /// Create a new empty task store
24    pub fn new() -> Self {
25        Self {
26            tasks: Arc::new(RwLock::new(HashMap::new())),
27        }
28    }
29
30    /// Insert or update a task
31    pub async fn insert(&self, task: Task) {
32        let id = task.id.clone();
33        self.tasks.write().await.insert(id, task);
34    }
35
36    /// Get a task by ID
37    pub async fn get(&self, id: &str) -> Option<Task> {
38        self.tasks.read().await.get(id).cloned()
39    }
40
41    /// Get a task, trying multiple key formats
42    ///
43    /// Tries: exact match, with "tasks/" prefix, without "tasks/" prefix
44    pub async fn get_flexible(&self, id: &str) -> Option<Task> {
45        let guard = self.tasks.read().await;
46        
47        // Try exact match
48        if let Some(task) = guard.get(id) {
49            return Some(task.clone());
50        }
51
52        // Try with "tasks/" prefix
53        let prefixed = format!("tasks/{}", id);
54        if let Some(task) = guard.get(&prefixed) {
55            return Some(task.clone());
56        }
57
58        // Try without "tasks/" prefix
59        if let Some(stripped) = id.strip_prefix("tasks/") {
60            if let Some(task) = guard.get(stripped) {
61                return Some(task.clone());
62            }
63        }
64
65        None
66    }
67
68    /// Update a task's state
69    pub async fn update<F>(&self, id: &str, f: F) -> Option<Task>
70    where
71        F: FnOnce(&mut Task),
72    {
73        let mut guard = self.tasks.write().await;
74        if let Some(task) = guard.get_mut(id) {
75            f(task);
76            Some(task.clone())
77        } else {
78            None
79        }
80    }
81
82    /// Update a task with a fallible closure, trying multiple key formats
83    ///
84    /// Returns:
85    /// - `None` if task not found
86    /// - `Some(Err(e))` if closure returned error
87    /// - `Some(Ok(task))` if update succeeded
88    pub async fn update_flexible<F, E>(&self, id: &str, f: F) -> Option<Result<Task, E>>
89    where
90        F: FnOnce(&mut Task) -> Result<(), E>,
91    {
92        let mut guard = self.tasks.write().await;
93
94        // Try to find the task with flexible key matching
95        let key = if guard.contains_key(id) {
96            Some(id.to_string())
97        } else {
98            let prefixed = format!("tasks/{}", id);
99            if guard.contains_key(&prefixed) {
100                Some(prefixed)
101            } else if let Some(stripped) = id.strip_prefix("tasks/") {
102                if guard.contains_key(stripped) {
103                    Some(stripped.to_string())
104                } else {
105                    None
106                }
107            } else {
108                None
109            }
110        };
111
112        let key = key?;
113        let task = guard.get_mut(&key)?;
114
115        match f(task) {
116            Ok(()) => Some(Ok(task.clone())),
117            Err(e) => Some(Err(e)),
118        }
119    }
120
121    /// Remove a task
122    pub async fn remove(&self, id: &str) -> Option<Task> {
123        self.tasks.write().await.remove(id)
124    }
125
126    /// List all tasks
127    pub async fn list(&self) -> Vec<Task> {
128        self.tasks.read().await.values().cloned().collect()
129    }
130
131    /// List tasks with filtering and pagination
132    ///
133    /// Returns a TaskListResponse with filtered tasks and pagination info.
134    pub async fn list_filtered(&self, params: &ListTasksRequest) -> TaskListResponse {
135        let guard = self.tasks.read().await;
136
137        // Apply filters
138        let mut filtered: Vec<_> = guard
139            .values()
140            .filter(|task| {
141                // Filter by context_id
142                if let Some(ref ctx) = params.context_id {
143                    if task.context_id != *ctx {
144                        return false;
145                    }
146                }
147                // Filter by status
148                if let Some(status) = params.status {
149                    if task.status.state != status {
150                        return false;
151                    }
152                }
153                // Filter by status_timestamp_after (milliseconds since epoch)
154                if let Some(after_ms) = params.status_timestamp_after {
155                    if let Some(ref ts) = task.status.timestamp {
156                        // Parse ISO8601 timestamp and compare
157                        if let Ok(dt) = chrono::DateTime::parse_from_rfc3339(ts) {
158                            if dt.timestamp_millis() <= after_ms {
159                                return false;
160                            }
161                        }
162                    } else {
163                        // No timestamp means we can't determine if it's after, exclude it
164                        return false;
165                    }
166                }
167                true
168            })
169            .cloned()
170            .collect();
171
172        let total_size = filtered.len() as u32;
173        let page_size = params.page_size.unwrap_or(50).min(100);
174
175        // Sort by task ID for consistent pagination (could be timestamp-based in production)
176        filtered.sort_by(|a, b| a.id.cmp(&b.id));
177
178        // Apply pagination using page_token as offset
179        let offset: usize = params
180            .page_token
181            .as_ref()
182            .and_then(|t| t.parse().ok())
183            .unwrap_or(0);
184
185        let paginated: Vec<_> = filtered
186            .into_iter()
187            .skip(offset)
188            .take(page_size as usize)
189            .map(|mut task| {
190                // Optionally trim history
191                if let Some(len) = params.history_length {
192                    if let Some(ref mut history) = task.history {
193                        let keep = len as usize;
194                        if history.len() > keep {
195                            *history = history.iter().rev().take(keep).cloned().collect();
196                            history.reverse();
197                        }
198                    }
199                }
200                // Optionally exclude artifacts
201                if params.include_artifacts == Some(false) {
202                    task.artifacts = None;
203                }
204                task
205            })
206            .collect();
207
208        let next_offset = offset + paginated.len();
209        let next_page_token = if next_offset < total_size as usize {
210            next_offset.to_string()
211        } else {
212            String::new()
213        };
214
215        TaskListResponse {
216            tasks: paginated,
217            next_page_token,
218            page_size,
219            total_size,
220        }
221    }
222
223    /// Get the number of stored tasks
224    pub async fn len(&self) -> usize {
225        self.tasks.read().await.len()
226    }
227
228    /// Check if the store is empty
229    pub async fn is_empty(&self) -> bool {
230        self.tasks.read().await.is_empty()
231    }
232}
233
#[cfg(test)]
mod tests {
    use super::*;
    use a2a_rs_core::{TaskState, TaskStatus};

    /// Build a minimal `Working` task with the given ID in context `"ctx"`.
    fn make_task(id: &str) -> Task {
        Task {
            id: id.to_string(),
            context_id: "ctx".to_string(),
            status: TaskStatus {
                state: TaskState::Working,
                message: None,
                timestamp: None,
            },
            history: None,
            artifacts: None,
            metadata: None,
        }
    }

    #[tokio::test]
    async fn test_insert_and_get() {
        let store = TaskStore::new();
        let task = make_task("task-1");

        store.insert(task.clone()).await;

        let retrieved = store.get("task-1").await;
        assert!(retrieved.is_some());
        assert_eq!(retrieved.unwrap().id, "task-1");
    }

    #[tokio::test]
    async fn test_get_flexible() {
        let store = TaskStore::new();
        store.insert(make_task("tasks/abc-123")).await;

        // Exact match
        assert!(store.get_flexible("tasks/abc-123").await.is_some());

        // Bare ID resolves by adding the "tasks/" prefix
        assert!(store.get_flexible("abc-123").await.is_some());

        // Unknown IDs are not found in any form
        assert!(store.get_flexible("missing").await.is_none());
    }

    #[tokio::test]
    async fn test_get_flexible_strips_prefix() {
        let store = TaskStore::new();
        store.insert(make_task("abc-123")).await;

        // Prefixed ID resolves by removing the "tasks/" prefix
        let found = store.get_flexible("tasks/abc-123").await;
        assert_eq!(found.map(|t| t.id), Some("abc-123".to_string()));
    }

    #[tokio::test]
    async fn test_update() {
        let store = TaskStore::new();
        store.insert(make_task("task-1")).await;

        let updated = store
            .update("task-1", |t| {
                t.status.state = TaskState::Completed;
            })
            .await;

        assert!(updated.is_some());
        assert_eq!(updated.unwrap().status.state, TaskState::Completed);
    }

    #[tokio::test]
    async fn test_update_missing_returns_none() {
        let store = TaskStore::new();
        let updated = store
            .update("missing", |t| {
                t.status.state = TaskState::Completed;
            })
            .await;
        assert!(updated.is_none());
    }

    #[tokio::test]
    async fn test_update_flexible_not_found() {
        let store = TaskStore::new();
        let result = store
            .update_flexible("missing", |_| -> Result<(), ()> { Ok(()) })
            .await;
        assert!(result.is_none());
    }

    #[tokio::test]
    async fn test_update_flexible_propagates_error() {
        let store = TaskStore::new();
        store.insert(make_task("tasks/t-1")).await;

        let result = store
            .update_flexible("t-1", |_| Err::<(), &str>("rejected"))
            .await;
        assert!(matches!(result, Some(Err("rejected"))));
    }

    #[tokio::test]
    async fn test_remove_and_is_empty() {
        let store = TaskStore::new();
        assert!(store.is_empty().await);

        store.insert(make_task("task-1")).await;
        assert!(!store.is_empty().await);

        let removed = store.remove("task-1").await;
        assert_eq!(removed.map(|t| t.id), Some("task-1".to_string()));
        assert!(store.get("task-1").await.is_none());
        assert!(store.is_empty().await);
    }

    #[tokio::test]
    async fn test_clones_share_storage() {
        let store = TaskStore::new();
        let other = store.clone();

        other.insert(make_task("task-1")).await;

        assert!(store.get("task-1").await.is_some());
    }

    #[tokio::test]
    async fn test_concurrent_inserts() {
        // `TaskStore` clones share one map, so no extra `Arc` wrapper is needed.
        let store = TaskStore::new();

        // Spawn 100 concurrent inserts
        let handles: Vec<_> = (0..100)
            .map(|i| {
                let store = store.clone();
                tokio::spawn(async move {
                    store.insert(make_task(&format!("task-{}", i))).await;
                })
            })
            .collect();

        for h in handles {
            h.await.unwrap();
        }

        assert_eq!(store.len().await, 100);
    }

    #[tokio::test]
    async fn test_concurrent_reads_and_writes() {
        let store = TaskStore::new();

        // Pre-populate with some tasks
        for i in 0..10 {
            store.insert(make_task(&format!("task-{}", i))).await;
        }

        // Spawn concurrent readers and writers
        let mut handles = Vec::new();

        // Writers
        for i in 10..60 {
            let store = store.clone();
            handles.push(tokio::spawn(async move {
                store.insert(make_task(&format!("task-{}", i))).await;
            }));
        }

        // Readers
        for i in 0..10 {
            let store = store.clone();
            handles.push(tokio::spawn(async move {
                let _ = store.get(&format!("task-{}", i)).await;
            }));
        }

        // Updaters
        for i in 0..10 {
            let store = store.clone();
            handles.push(tokio::spawn(async move {
                store
                    .update(&format!("task-{}", i), |t| {
                        t.status.state = TaskState::Completed;
                    })
                    .await;
            }));
        }

        for h in handles {
            h.await.unwrap();
        }

        // Should have 60 tasks (10 original + 50 new)
        assert_eq!(store.len().await, 60);

        // All original tasks should be completed
        for i in 0..10 {
            let task = store.get(&format!("task-{}", i)).await.unwrap();
            assert_eq!(task.status.state, TaskState::Completed);
        }
    }

    #[tokio::test]
    async fn test_concurrent_update_flexible() {
        let store = TaskStore::new();
        store.insert(make_task("tasks/shared-task")).await;

        // Spawn concurrent updates on the same task
        let handles: Vec<_> = (0..50)
            .map(|_| {
                let store = store.clone();
                tokio::spawn(async move {
                    store
                        .update_flexible("shared-task", |t| -> Result<(), ()> {
                            t.context_id = "updated".to_string();
                            Ok(())
                        })
                        .await
                })
            })
            .collect();

        for h in handles {
            let result = h.await.unwrap();
            assert!(result.is_some());
            assert!(result.unwrap().is_ok());
        }

        // Task should exist and be updated
        let task = store.get("tasks/shared-task").await.unwrap();
        assert_eq!(task.context_id, "updated");
    }
}