Skip to main content

a2a_rs_server/
task_store.rs

//! In-memory task store
//!
//! Provides thread-safe, async-friendly storage for A2A tasks.
4
5use a2a_rs_core::{ListTasksRequest, Task, TaskListResponse};
6use std::collections::HashMap;
7use std::sync::Arc;
8use tokio::sync::RwLock;
9
10/// Thread-safe in-memory task store
11#[derive(Clone)]
12pub struct TaskStore {
13    tasks: Arc<RwLock<HashMap<String, Task>>>,
14}
15
16impl Default for TaskStore {
17    fn default() -> Self {
18        Self::new()
19    }
20}
21
22impl TaskStore {
23    /// Create a new empty task store
24    pub fn new() -> Self {
25        Self {
26            tasks: Arc::new(RwLock::new(HashMap::new())),
27        }
28    }
29
30    /// Insert or update a task
31    pub async fn insert(&self, task: Task) {
32        let id = task.id.clone();
33        self.tasks.write().await.insert(id, task);
34    }
35
36    /// Get a task by ID
37    pub async fn get(&self, id: &str) -> Option<Task> {
38        self.tasks.read().await.get(id).cloned()
39    }
40
41    /// Get a task, trying multiple key formats
42    ///
43    /// Tries: exact match, with "tasks/" prefix, without "tasks/" prefix
44    pub async fn get_flexible(&self, id: &str) -> Option<Task> {
45        let guard = self.tasks.read().await;
46
47        // Try exact match
48        if let Some(task) = guard.get(id) {
49            return Some(task.clone());
50        }
51
52        // Try with "tasks/" prefix
53        let prefixed = format!("tasks/{}", id);
54        if let Some(task) = guard.get(&prefixed) {
55            return Some(task.clone());
56        }
57
58        // Try without "tasks/" prefix
59        if let Some(stripped) = id.strip_prefix("tasks/") {
60            if let Some(task) = guard.get(stripped) {
61                return Some(task.clone());
62            }
63        }
64
65        None
66    }
67
68    /// Update a task's state
69    pub async fn update<F>(&self, id: &str, f: F) -> Option<Task>
70    where
71        F: FnOnce(&mut Task),
72    {
73        let mut guard = self.tasks.write().await;
74        if let Some(task) = guard.get_mut(id) {
75            f(task);
76            Some(task.clone())
77        } else {
78            None
79        }
80    }
81
82    /// Update a task with a fallible closure, trying multiple key formats
83    ///
84    /// Returns:
85    /// - `None` if task not found
86    /// - `Some(Err(e))` if closure returned error
87    /// - `Some(Ok(task))` if update succeeded
88    pub async fn update_flexible<F, E>(&self, id: &str, f: F) -> Option<Result<Task, E>>
89    where
90        F: FnOnce(&mut Task) -> Result<(), E>,
91    {
92        let mut guard = self.tasks.write().await;
93
94        // Try to find the task with flexible key matching
95        let key = if guard.contains_key(id) {
96            Some(id.to_string())
97        } else {
98            let prefixed = format!("tasks/{}", id);
99            if guard.contains_key(&prefixed) {
100                Some(prefixed)
101            } else if let Some(stripped) = id.strip_prefix("tasks/") {
102                if guard.contains_key(stripped) {
103                    Some(stripped.to_string())
104                } else {
105                    None
106                }
107            } else {
108                None
109            }
110        };
111
112        let key = key?;
113        let task = guard.get_mut(&key)?;
114
115        match f(task) {
116            Ok(()) => Some(Ok(task.clone())),
117            Err(e) => Some(Err(e)),
118        }
119    }
120
121    /// Remove a task
122    pub async fn remove(&self, id: &str) -> Option<Task> {
123        self.tasks.write().await.remove(id)
124    }
125
126    /// List all tasks
127    pub async fn list(&self) -> Vec<Task> {
128        self.tasks.read().await.values().cloned().collect()
129    }
130
131    /// List tasks with filtering and pagination
132    ///
133    /// Returns a TaskListResponse with filtered tasks and pagination info.
134    pub async fn list_filtered(&self, params: &ListTasksRequest) -> TaskListResponse {
135        let guard = self.tasks.read().await;
136
137        // Apply filters
138        let mut filtered: Vec<_> = guard
139            .values()
140            .filter(|task| {
141                // Filter by context_id
142                if let Some(ref ctx) = params.context_id {
143                    if task.context_id != *ctx {
144                        return false;
145                    }
146                }
147                // Filter by status
148                if let Some(status) = params.status {
149                    if task.status.state != status {
150                        return false;
151                    }
152                }
153                // Filter by status_timestamp_after (milliseconds since epoch)
154                if let Some(after_ms) = params.status_timestamp_after {
155                    if let Some(ref ts) = task.status.timestamp {
156                        // Parse ISO8601 timestamp and compare
157                        if let Ok(dt) = chrono::DateTime::parse_from_rfc3339(ts) {
158                            if dt.timestamp_millis() <= after_ms {
159                                return false;
160                            }
161                        }
162                    } else {
163                        // No timestamp means we can't determine if it's after, exclude it
164                        return false;
165                    }
166                }
167                true
168            })
169            .cloned()
170            .collect();
171
172        let total_size = filtered.len() as u32;
173        let page_size = params.page_size.unwrap_or(50).min(100);
174
175        // Sort by task ID for consistent pagination (could be timestamp-based in production)
176        filtered.sort_by(|a, b| a.id.cmp(&b.id));
177
178        // Apply pagination using page_token as offset
179        let offset: usize = params
180            .page_token
181            .as_ref()
182            .and_then(|t| t.parse().ok())
183            .unwrap_or(0);
184
185        let paginated: Vec<_> = filtered
186            .into_iter()
187            .skip(offset)
188            .take(page_size as usize)
189            .map(|mut task| {
190                // Optionally trim/omit history
191                match params.history_length {
192                    Some(0) => {
193                        task.history = None;
194                    }
195                    Some(n) => {
196                        if let Some(ref mut history) = task.history {
197                            let keep = n as usize;
198                            if history.len() > keep {
199                                *history = history.split_off(history.len() - keep);
200                            }
201                        }
202                    }
203                    None => {}
204                }
205                // Optionally exclude artifacts
206                if params.include_artifacts == Some(false) {
207                    task.artifacts = None;
208                }
209                task
210            })
211            .collect();
212
213        let next_offset = offset + paginated.len();
214        let next_page_token = if next_offset < total_size as usize {
215            next_offset.to_string()
216        } else {
217            String::new()
218        };
219
220        TaskListResponse {
221            tasks: paginated,
222            next_page_token,
223            page_size,
224            total_size,
225        }
226    }
227
228    /// Get the number of stored tasks
229    pub async fn len(&self) -> usize {
230        self.tasks.read().await.len()
231    }
232
233    /// Check if the store is empty
234    pub async fn is_empty(&self) -> bool {
235        self.tasks.read().await.is_empty()
236    }
237}
238
#[cfg(test)]
mod tests {
    use super::*;
    use a2a_rs_core::{TaskState, TaskStatus};

    /// Build a minimal `Working` task with the given id and context `"ctx"`.
    fn make_task(id: &str) -> Task {
        Task {
            kind: "task".to_string(),
            id: id.to_string(),
            context_id: "ctx".to_string(),
            status: TaskStatus {
                state: TaskState::Working,
                message: None,
                timestamp: None,
            },
            history: None,
            artifacts: None,
            metadata: None,
        }
    }

    #[tokio::test]
    async fn test_insert_and_get() {
        let store = TaskStore::new();
        store.insert(make_task("task-1")).await;

        let retrieved = store.get("task-1").await.expect("task should be stored");
        assert_eq!(retrieved.id, "task-1");
        assert!(store.get("missing").await.is_none());
    }

    #[tokio::test]
    async fn test_get_flexible() {
        let store = TaskStore::new();
        store.insert(make_task("tasks/abc-123")).await;
        store.insert(make_task("plain-456")).await;

        // Exact match
        assert!(store.get_flexible("tasks/abc-123").await.is_some());

        // Lookup adds the "tasks/" prefix
        assert!(store.get_flexible("abc-123").await.is_some());

        // Lookup strips the "tasks/" prefix
        assert_eq!(
            store.get_flexible("tasks/plain-456").await.map(|t| t.id),
            Some("plain-456".to_string())
        );

        // No candidate key exists
        assert!(store.get_flexible("nope").await.is_none());
    }

    #[tokio::test]
    async fn test_update() {
        let store = TaskStore::new();
        store.insert(make_task("task-1")).await;

        let updated = store
            .update("task-1", |t| {
                t.status.state = TaskState::Completed;
            })
            .await;

        assert!(updated.is_some());
        assert_eq!(updated.unwrap().status.state, TaskState::Completed);

        assert!(store.update("missing", |_| {}).await.is_none());
    }

    #[tokio::test]
    async fn test_update_flexible_error_and_missing() {
        let store = TaskStore::new();
        store.insert(make_task("tasks/t")).await;

        let failed = store
            .update_flexible("t", |_| Err::<(), _>("rejected"))
            .await;
        assert!(matches!(failed, Some(Err("rejected"))));

        let missing = store
            .update_flexible("absent", |_| Ok::<(), ()>(()))
            .await;
        assert!(missing.is_none());

        // The failing closure made no changes, so the task is untouched.
        let task = store.get("tasks/t").await.unwrap();
        assert_eq!(task.status.state, TaskState::Working);
    }

    #[tokio::test]
    async fn test_remove() {
        let store = TaskStore::new();
        store.insert(make_task("task-1")).await;

        assert_eq!(
            store.remove("task-1").await.map(|t| t.id),
            Some("task-1".to_string())
        );
        assert!(store.remove("task-1").await.is_none());
        assert!(store.is_empty().await);
    }

    #[tokio::test]
    async fn test_concurrent_inserts() {
        let store = Arc::new(TaskStore::new());

        // Spawn 100 concurrent inserts
        let handles: Vec<_> = (0..100)
            .map(|i| {
                let store = store.clone();
                tokio::spawn(async move {
                    store.insert(make_task(&format!("task-{}", i))).await;
                })
            })
            .collect();

        for h in handles {
            h.await.unwrap();
        }

        assert_eq!(store.len().await, 100);
    }

    #[tokio::test]
    async fn test_concurrent_reads_and_writes() {
        let store = Arc::new(TaskStore::new());

        // Pre-populate with some tasks
        for i in 0..10 {
            store.insert(make_task(&format!("task-{}", i))).await;
        }

        // Spawn concurrent readers and writers
        let mut handles = Vec::new();

        // Writers
        for i in 10..60 {
            let store = store.clone();
            handles.push(tokio::spawn(async move {
                store.insert(make_task(&format!("task-{}", i))).await;
            }));
        }

        // Readers
        for i in 0..10 {
            let store = store.clone();
            handles.push(tokio::spawn(async move {
                let _ = store.get(&format!("task-{}", i)).await;
            }));
        }

        // Updaters
        for i in 0..10 {
            let store = store.clone();
            handles.push(tokio::spawn(async move {
                store
                    .update(&format!("task-{}", i), |t| {
                        t.status.state = TaskState::Completed;
                    })
                    .await;
            }));
        }

        for h in handles {
            h.await.unwrap();
        }

        // Should have 60 tasks (10 original + 50 new)
        assert_eq!(store.len().await, 60);

        // All original tasks should be completed
        for i in 0..10 {
            let task = store.get(&format!("task-{}", i)).await.unwrap();
            assert_eq!(task.status.state, TaskState::Completed);
        }
    }

    #[tokio::test]
    async fn test_concurrent_update_flexible() {
        let store = Arc::new(TaskStore::new());
        store.insert(make_task("tasks/shared-task")).await;

        // Spawn concurrent updates on the same task
        let handles: Vec<_> = (0..50)
            .map(|_| {
                let store = store.clone();
                tokio::spawn(async move {
                    store
                        .update_flexible("shared-task", |t| -> Result<(), ()> {
                            t.context_id = "updated".to_string();
                            Ok(())
                        })
                        .await
                })
            })
            .collect();

        for h in handles {
            let result = h.await.unwrap();
            assert!(result.is_some());
            assert!(result.unwrap().is_ok());
        }

        // Task should exist and be updated
        let task = store.get("tasks/shared-task").await.unwrap();
        assert_eq!(task.context_id, "updated");
    }
}