Skip to main content

jamjet_a2a/
store.rs

1//! TaskStore trait and in-memory implementation for A2A v1.0 servers.
2
3use jamjet_a2a_types::*;
4use std::collections::HashMap;
5use tokio::sync::{broadcast, Mutex};
6use tracing::debug;
7
8// ────────────────────────────────────────────────────────────────────────────
9// TaskStore trait
10// ────────────────────────────────────────────────────────────────────────────
11
12/// Async storage backend for A2A tasks.
13///
14/// Implementations must be `Send + Sync` so they can be shared across Axum
15/// handlers via `Arc<dyn TaskStore>`.
16#[async_trait::async_trait]
17pub trait TaskStore: Send + Sync {
18    /// Insert a new task into the store.
19    async fn insert(&self, task: Task) -> Result<(), A2aError>;
20
21    /// Retrieve a task by ID, returning `None` if it does not exist.
22    async fn get(&self, task_id: &str) -> Result<Option<Task>, A2aError>;
23
24    /// Update a task's status and broadcast a [`TaskStatusUpdateEvent`].
25    async fn update_status(&self, task_id: &str, status: TaskStatus) -> Result<(), A2aError>;
26
27    /// Append an artifact to a task and broadcast a [`TaskArtifactUpdateEvent`].
28    async fn add_artifact(&self, task_id: &str, artifact: Artifact) -> Result<(), A2aError>;
29
30    /// List tasks with cursor-based pagination and optional filters.
31    async fn list(&self, req: &ListTasksRequest) -> Result<ListTasksResponse, A2aError>;
32
33    /// Append a message to an existing task's history.
34    async fn append_message(&self, task_id: &str, message: Message) -> Result<(), A2aError>;
35
36    /// Cancel a task. Returns an error if the task is in a terminal state.
37    async fn cancel(&self, task_id: &str) -> Result<(), A2aError>;
38
39    /// Subscribe to streaming events for a task. Returns `None` if the task
40    /// does not exist or has no broadcast channel.
41    async fn subscribe(&self, task_id: &str) -> Option<broadcast::Receiver<StreamResponse>>;
42}
43
44// ────────────────────────────────────────────────────────────────────────────
45// InMemoryTaskStore
46// ────────────────────────────────────────────────────────────────────────────
47
48/// Broadcast channel capacity for per-task event streams.
49const CHANNEL_CAPACITY: usize = 64;
50
51struct InMemoryInner {
52    tasks: HashMap<String, Task>,
53    /// Insertion-order list of task IDs (used for cursor-based pagination).
54    order: Vec<String>,
55    /// Per-task broadcast channels for streaming events.
56    channels: HashMap<String, broadcast::Sender<StreamResponse>>,
57}
58
59/// A simple in-memory [`TaskStore`] backed by a `tokio::sync::Mutex`.
60///
61/// Suitable for development, testing, and single-node deployments. For
62/// production use with multiple server instances, swap in a database-backed
63/// implementation.
64pub struct InMemoryTaskStore {
65    inner: Mutex<InMemoryInner>,
66}
67
68impl InMemoryTaskStore {
69    /// Create a new, empty in-memory store.
70    pub fn new() -> Self {
71        Self {
72            inner: Mutex::new(InMemoryInner {
73                tasks: HashMap::new(),
74                order: Vec::new(),
75                channels: HashMap::new(),
76            }),
77        }
78    }
79}
80
81impl Default for InMemoryTaskStore {
82    fn default() -> Self {
83        Self::new()
84    }
85}
86
87#[async_trait::async_trait]
88impl TaskStore for InMemoryTaskStore {
89    async fn insert(&self, task: Task) -> Result<(), A2aError> {
90        let mut inner = self.inner.lock().await;
91        let task_id = task.id.clone();
92        let (tx, _) = broadcast::channel(CHANNEL_CAPACITY);
93        inner.channels.insert(task_id.clone(), tx);
94        inner.order.push(task_id.clone());
95        inner.tasks.insert(task_id, task);
96        Ok(())
97    }
98
99    async fn get(&self, task_id: &str) -> Result<Option<Task>, A2aError> {
100        let inner = self.inner.lock().await;
101        Ok(inner.tasks.get(task_id).cloned())
102    }
103
104    async fn update_status(&self, task_id: &str, status: TaskStatus) -> Result<(), A2aError> {
105        let mut inner = self.inner.lock().await;
106        let task = inner
107            .tasks
108            .get_mut(task_id)
109            .ok_or_else(|| A2aProtocolError::TaskNotFound {
110                task_id: task_id.to_string(),
111            })?;
112        task.status = status.clone();
113        let context_id = task.context_id.clone().unwrap_or_default();
114        // Release the mutable borrow on `tasks` before accessing `channels`.
115        let _ = task;
116
117        // Broadcast status update event.
118        if let Some(tx) = inner.channels.get(task_id) {
119            let event = StreamResponse::StatusUpdate(TaskStatusUpdateEvent {
120                task_id: task_id.to_string(),
121                context_id,
122                status,
123                metadata: None,
124            });
125            // Ignore send errors (no active receivers).
126            let _ = tx.send(event);
127        }
128
129        debug!(task_id, "status updated");
130        Ok(())
131    }
132
133    async fn add_artifact(&self, task_id: &str, artifact: Artifact) -> Result<(), A2aError> {
134        let mut inner = self.inner.lock().await;
135        let task = inner
136            .tasks
137            .get_mut(task_id)
138            .ok_or_else(|| A2aProtocolError::TaskNotFound {
139                task_id: task_id.to_string(),
140            })?;
141        task.artifacts.push(artifact.clone());
142        let context_id = task.context_id.clone().unwrap_or_default();
143        // Release the mutable borrow on `tasks` before accessing `channels`.
144        let _ = task;
145
146        // Broadcast artifact update event.
147        if let Some(tx) = inner.channels.get(task_id) {
148            let event = StreamResponse::ArtifactUpdate(TaskArtifactUpdateEvent {
149                task_id: task_id.to_string(),
150                context_id,
151                artifact,
152                append: None,
153                last_chunk: None,
154                metadata: None,
155            });
156            let _ = tx.send(event);
157        }
158
159        debug!(task_id, "artifact added");
160        Ok(())
161    }
162
163    async fn append_message(&self, task_id: &str, message: Message) -> Result<(), A2aError> {
164        let mut inner = self.inner.lock().await;
165        let task = inner
166            .tasks
167            .get_mut(task_id)
168            .ok_or_else(|| A2aProtocolError::TaskNotFound {
169                task_id: task_id.to_string(),
170            })?;
171        match task.history {
172            Some(ref mut hist) => hist.push(message),
173            None => task.history = Some(vec![message]),
174        }
175        debug!(task_id, "message appended to history");
176        Ok(())
177    }
178
179    async fn list(&self, req: &ListTasksRequest) -> Result<ListTasksResponse, A2aError> {
180        let inner = self.inner.lock().await;
181
182        let page_size = req.page_size.unwrap_or(50).max(1).min(100) as usize;
183        let history_length = req.history_length;
184        let include_artifacts = req.include_artifacts.unwrap_or(false);
185
186        // Parse statusTimestampAfter filter.
187        let ts_after = req
188            .status_timestamp_after
189            .as_deref()
190            .and_then(|s| chrono::DateTime::parse_from_rfc3339(s).ok());
191
192        // Step 1: Collect ALL matching tasks (applying filters) in reverse insertion
193        // order (descending — most recently inserted first, per A2A spec).
194        let mut all_matching: Vec<Task> = Vec::new();
195        for id in inner.order.iter().rev() {
196            if let Some(task) = inner.tasks.get(id) {
197                // Filter by context_id.
198                if let Some(ref ctx) = req.context_id {
199                    if task.context_id.as_deref() != Some(ctx.as_str()) {
200                        continue;
201                    }
202                }
203                // Filter by status.
204                if let Some(ref status) = req.status {
205                    if task.status.state != *status {
206                        continue;
207                    }
208                }
209                // Filter by statusTimestampAfter.
210                if let Some(ts_cutoff) = &ts_after {
211                    let passes = task
212                        .status
213                        .timestamp
214                        .as_deref()
215                        .and_then(|t| chrono::DateTime::parse_from_rfc3339(t).ok())
216                        .map(|t| t >= *ts_cutoff)
217                        .unwrap_or(false);
218                    if !passes {
219                        continue;
220                    }
221                }
222                all_matching.push(task.clone());
223            }
224        }
225
226        let total_size = all_matching.len() as i32;
227
228        // Step 2: Cursor-based pagination on the filtered set.
229        let start_idx = if let Some(ref token) = req.page_token {
230            if token.is_empty() {
231                0
232            } else {
233                // Find the cursor in the filtered list and start after it.
234                all_matching
235                    .iter()
236                    .position(|t| t.id == *token)
237                    .map(|pos| pos + 1)
238                    .unwrap_or(all_matching.len())
239            }
240        } else {
241            0
242        };
243
244        let page: Vec<Task> = all_matching
245            .into_iter()
246            .skip(start_idx)
247            .take(page_size)
248            .map(|mut task| {
249                // Apply historyLength limiting.
250                if let Some(hl) = history_length {
251                    if hl == 0 {
252                        task.history = None;
253                    } else if let Some(ref mut hist) = task.history {
254                        let hl = hl as usize;
255                        if hist.len() > hl {
256                            let start = hist.len() - hl;
257                            *hist = hist.split_off(start);
258                        }
259                    }
260                }
261                // Exclude artifacts unless explicitly requested.
262                if !include_artifacts {
263                    task.artifacts = vec![];
264                }
265                task
266            })
267            .collect();
268
269        let actual_count = page.len() as i32;
270
271        // Determine next_page_token: empty string means no more results.
272        let next_page_token = if page.len() == page_size
273            && start_idx + page_size < total_size as usize
274        {
275            page.last().map(|t| t.id.clone()).unwrap_or_default()
276        } else {
277            String::new()
278        };
279
280        // pageSize in response = actual number of tasks returned (capped by request).
281        Ok(ListTasksResponse {
282            tasks: page,
283            next_page_token,
284            page_size: actual_count,
285            total_size,
286        })
287    }
288
289    async fn cancel(&self, task_id: &str) -> Result<(), A2aError> {
290        let mut inner = self.inner.lock().await;
291        let task = inner
292            .tasks
293            .get_mut(task_id)
294            .ok_or_else(|| A2aProtocolError::TaskNotFound {
295                task_id: task_id.to_string(),
296            })?;
297
298        // Terminal states cannot be canceled.
299        match task.status.state {
300            TaskState::Completed | TaskState::Failed | TaskState::Canceled => {
301                return Err(A2aProtocolError::TaskNotCancelable {
302                    task_id: task_id.to_string(),
303                }
304                .into());
305            }
306            _ => {}
307        }
308
309        let canceled_status = TaskStatus {
310            state: TaskState::Canceled,
311            message: None,
312            timestamp: Some(chrono::Utc::now().to_rfc3339()),
313        };
314        task.status = canceled_status.clone();
315        let context_id = task.context_id.clone().unwrap_or_default();
316        // Release the mutable borrow on `tasks` before accessing `channels`.
317        let _ = task;
318
319        // Broadcast cancellation event.
320        if let Some(tx) = inner.channels.get(task_id) {
321            let event = StreamResponse::StatusUpdate(TaskStatusUpdateEvent {
322                task_id: task_id.to_string(),
323                context_id,
324                status: canceled_status,
325                metadata: None,
326            });
327            let _ = tx.send(event);
328        }
329
330        debug!(task_id, "task canceled");
331        Ok(())
332    }
333
334    async fn subscribe(&self, task_id: &str) -> Option<broadcast::Receiver<StreamResponse>> {
335        let inner = self.inner.lock().await;
336        inner.channels.get(task_id).map(|tx| tx.subscribe())
337    }
338}
339
340// ────────────────────────────────────────────────────────────────────────────
341// Tests
342// ────────────────────────────────────────────────────────────────────────────
343
#[cfg(test)]
mod tests {
    use super::*;

    /// A fresh `Submitted` task with no context, artifacts, or history.
    fn test_task(id: &str) -> Task {
        Task {
            id: id.into(),
            context_id: None,
            status: status(TaskState::Submitted),
            artifacts: vec![],
            history: None,
            metadata: None,
        }
    }

    /// A status with the given state and no message or timestamp.
    fn status(state: TaskState) -> TaskStatus {
        TaskStatus {
            state,
            message: None,
            timestamp: None,
        }
    }

    /// A minimal artifact with the given ID.
    fn test_artifact(id: &str) -> Artifact {
        Artifact {
            artifact_id: id.into(),
            name: Some("test".into()),
            description: None,
            parts: vec![],
            metadata: None,
            extensions: vec![],
        }
    }

    /// IDs of the tasks in a list response, in response order.
    fn ids(resp: &ListTasksResponse) -> Vec<&str> {
        resp.tasks.iter().map(|t| t.id.as_str()).collect()
    }

    #[tokio::test]
    async fn insert_and_get() {
        let store = InMemoryTaskStore::new();
        store.insert(test_task("t1")).await.unwrap();
        let task = store.get("t1").await.unwrap();
        assert!(task.is_some());
        assert_eq!(task.unwrap().id, "t1");
    }

    #[tokio::test]
    async fn get_missing_returns_none() {
        let store = InMemoryTaskStore::new();
        assert!(store.get("nope").await.unwrap().is_none());
    }

    #[tokio::test]
    async fn cancel_terminal_task_fails() {
        let store = InMemoryTaskStore::new();
        store.insert(test_task("t1")).await.unwrap();
        store
            .update_status("t1", status(TaskState::Completed))
            .await
            .unwrap();
        let result = store.cancel("t1").await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn cancel_working_task_succeeds() {
        let store = InMemoryTaskStore::new();
        store.insert(test_task("t1")).await.unwrap();
        store
            .update_status("t1", status(TaskState::Working))
            .await
            .unwrap();
        store.cancel("t1").await.unwrap();
        let task = store.get("t1").await.unwrap().unwrap();
        assert_eq!(task.status.state, TaskState::Canceled);
    }

    #[tokio::test]
    async fn cancel_twice_fails() {
        let store = InMemoryTaskStore::new();
        store.insert(test_task("t1")).await.unwrap();
        store.cancel("t1").await.unwrap();
        assert!(store.cancel("t1").await.is_err());
    }

    #[tokio::test]
    async fn cancel_broadcasts_event() {
        let store = InMemoryTaskStore::new();
        store.insert(test_task("t1")).await.unwrap();
        let mut rx = store.subscribe("t1").await.unwrap();

        store.cancel("t1").await.unwrap();

        match rx.recv().await.unwrap() {
            StreamResponse::StatusUpdate(e) => {
                assert_eq!(e.task_id, "t1");
                assert_eq!(e.status.state, TaskState::Canceled);
                assert!(e.status.timestamp.is_some());
            }
            _ => panic!("expected StatusUpdate event"),
        }
    }

    #[tokio::test]
    async fn list_with_pagination() {
        let store = InMemoryTaskStore::new();
        for i in 0..5 {
            store.insert(test_task(&format!("t{i}"))).await.unwrap();
        }
        let resp = store
            .list(&ListTasksRequest {
                tenant: None,
                context_id: None,
                status: None,
                page_size: Some(2),
                page_token: None,
                history_length: None,
                status_timestamp_after: None,
                include_artifacts: None,
            })
            .await
            .unwrap();
        // Most recently inserted first; the cursor is the last returned ID.
        assert_eq!(ids(&resp), ["t4", "t3"]);
        assert_eq!(resp.total_size, 5);
        assert_eq!(resp.page_size, 2);
        assert_eq!(resp.next_page_token, "t3");

        // Next page using the cursor.
        let resp2 = store
            .list(&ListTasksRequest {
                page_token: Some(resp.next_page_token.clone()),
                page_size: Some(2),
                ..Default::default()
            })
            .await
            .unwrap();
        assert_eq!(ids(&resp2), ["t2", "t1"]);
        assert_eq!(resp2.next_page_token, "t1");

        // Final, short page: no further cursor.
        let resp3 = store
            .list(&ListTasksRequest {
                page_token: Some(resp2.next_page_token.clone()),
                page_size: Some(2),
                ..Default::default()
            })
            .await
            .unwrap();
        assert_eq!(ids(&resp3), ["t0"]);
        assert_eq!(resp3.page_size, 1);
        assert!(resp3.next_page_token.is_empty());
    }

    #[tokio::test]
    async fn list_with_unknown_page_token_returns_empty_page() {
        let store = InMemoryTaskStore::new();
        store.insert(test_task("t1")).await.unwrap();

        let resp = store
            .list(&ListTasksRequest {
                page_token: Some("missing".into()),
                ..Default::default()
            })
            .await
            .unwrap();
        assert!(resp.tasks.is_empty());
        assert!(resp.next_page_token.is_empty());
        assert_eq!(resp.total_size, 1);
    }

    #[tokio::test]
    async fn list_filters_by_context_id() {
        let store = InMemoryTaskStore::new();
        let mut task_a = test_task("t1");
        task_a.context_id = Some("ctx-a".into());
        let mut task_b = test_task("t2");
        task_b.context_id = Some("ctx-b".into());
        store.insert(task_a).await.unwrap();
        store.insert(task_b).await.unwrap();

        let resp = store
            .list(&ListTasksRequest {
                context_id: Some("ctx-a".into()),
                ..Default::default()
            })
            .await
            .unwrap();
        assert_eq!(resp.tasks.len(), 1);
        assert_eq!(resp.tasks[0].id, "t1");
    }

    #[tokio::test]
    async fn list_filters_by_status() {
        let store = InMemoryTaskStore::new();
        store.insert(test_task("t1")).await.unwrap();
        store.insert(test_task("t2")).await.unwrap();
        store
            .update_status("t2", status(TaskState::Working))
            .await
            .unwrap();

        let resp = store
            .list(&ListTasksRequest {
                status: Some(TaskState::Working),
                ..Default::default()
            })
            .await
            .unwrap();
        assert_eq!(resp.tasks.len(), 1);
        assert_eq!(resp.tasks[0].id, "t2");
    }

    #[tokio::test]
    async fn list_filters_by_status_timestamp_after() {
        let store = InMemoryTaskStore::new();
        store.insert(test_task("t1")).await.unwrap();
        store.insert(test_task("t2")).await.unwrap();
        // `cancel` stamps the new status with the current time; `t1` keeps a
        // status without any timestamp and must therefore be filtered out.
        store.cancel("t2").await.unwrap();

        let resp = store
            .list(&ListTasksRequest {
                status_timestamp_after: Some("2000-01-01T00:00:00Z".into()),
                ..Default::default()
            })
            .await
            .unwrap();
        assert_eq!(ids(&resp), ["t2"]);
    }

    #[tokio::test]
    async fn list_omits_artifacts_unless_requested() {
        let store = InMemoryTaskStore::new();
        store.insert(test_task("t1")).await.unwrap();
        store.add_artifact("t1", test_artifact("a1")).await.unwrap();

        let stripped = store.list(&ListTasksRequest::default()).await.unwrap();
        assert!(stripped.tasks[0].artifacts.is_empty());

        let full = store
            .list(&ListTasksRequest {
                include_artifacts: Some(true),
                ..Default::default()
            })
            .await
            .unwrap();
        assert_eq!(full.tasks[0].artifacts.len(), 1);

        // Listing projects copies; the stored task keeps its artifacts.
        let stored = store.get("t1").await.unwrap().unwrap();
        assert_eq!(stored.artifacts.len(), 1);
    }

    #[tokio::test]
    async fn update_status_broadcasts_event() {
        let store = InMemoryTaskStore::new();
        store.insert(test_task("t1")).await.unwrap();
        let mut rx = store.subscribe("t1").await.unwrap();

        store
            .update_status("t1", status(TaskState::Working))
            .await
            .unwrap();

        let event = rx.recv().await.unwrap();
        match event {
            StreamResponse::StatusUpdate(e) => {
                assert_eq!(e.task_id, "t1");
                assert_eq!(e.status.state, TaskState::Working);
            }
            _ => panic!("expected StatusUpdate event"),
        }
    }

    #[tokio::test]
    async fn add_artifact_broadcasts_event() {
        let store = InMemoryTaskStore::new();
        store.insert(test_task("t1")).await.unwrap();
        let mut rx = store.subscribe("t1").await.unwrap();

        store.add_artifact("t1", test_artifact("a1")).await.unwrap();

        let event = rx.recv().await.unwrap();
        match event {
            StreamResponse::ArtifactUpdate(e) => {
                assert_eq!(e.task_id, "t1");
                assert_eq!(e.artifact.artifact_id, "a1");
            }
            _ => panic!("expected ArtifactUpdate event"),
        }
    }

    #[tokio::test]
    async fn subscribe_missing_task_returns_none() {
        let store = InMemoryTaskStore::new();
        assert!(store.subscribe("nope").await.is_none());
    }

    #[tokio::test]
    async fn cancel_missing_task_returns_error() {
        let store = InMemoryTaskStore::new();
        let result = store.cancel("nope").await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn mutating_missing_task_returns_error() {
        let store = InMemoryTaskStore::new();
        assert!(store
            .update_status("nope", status(TaskState::Working))
            .await
            .is_err());
        assert!(store
            .add_artifact("nope", test_artifact("a1"))
            .await
            .is_err());
    }
}