1use jamjet_a2a_types::*;
4use std::collections::HashMap;
5use tokio::sync::{broadcast, Mutex};
6use tracing::debug;
7
/// Storage backend for A2A tasks: task records, listing, and per-task event streams.
///
/// Every method takes `&self`, so implementations that mutate state need
/// interior mutability; the `Send + Sync` bound lets one store be shared
/// across threads.
#[async_trait::async_trait]
pub trait TaskStore: Send + Sync {
    /// Stores `task`.
    async fn insert(&self, task: Task) -> Result<(), A2aError>;

    /// Fetches the task with `task_id`.
    ///
    /// NOTE(review): the in-memory store returns `Ok(None)` for unknown ids;
    /// presumably other backends follow the same convention — confirm.
    async fn get(&self, task_id: &str) -> Result<Option<Task>, A2aError>;

    /// Replaces the status of the task with `task_id`.
    async fn update_status(&self, task_id: &str, status: TaskStatus) -> Result<(), A2aError>;

    /// Appends `artifact` to the task with `task_id`.
    async fn add_artifact(&self, task_id: &str, artifact: Artifact) -> Result<(), A2aError>;

    /// Returns one page of tasks selected by the filter and pagination fields of `req`.
    async fn list(&self, req: &ListTasksRequest) -> Result<ListTasksResponse, A2aError>;

    /// Appends `message` to the history of the task with `task_id`.
    async fn append_message(&self, task_id: &str, message: Message) -> Result<(), A2aError>;

    /// Requests cancellation of the task with `task_id`.
    ///
    /// (The in-memory store errors when the task is unknown or is already
    /// Completed, Failed or Canceled.)
    async fn cancel(&self, task_id: &str) -> Result<(), A2aError>;

    /// Opens a receiver for the stream events of the task with `task_id`.
    ///
    /// `None` means there is no event stream for that id.
    async fn subscribe(&self, task_id: &str) -> Option<broadcast::Receiver<StreamResponse>>;
}
43
/// Capacity of each task's broadcast channel (passed to `broadcast::channel`).
///
/// NOTE(review): with tokio broadcast channels, a subscriber that falls more
/// than this many events behind observes a lag error on its next receive —
/// confirm this bound suits expected event rates.
const CHANNEL_CAPACITY: usize = 64;
50
/// State guarded by the mutex inside `InMemoryTaskStore`.
struct InMemoryInner {
    /// Stored tasks, keyed by `Task::id`.
    tasks: HashMap<String, Task>,
    /// Task ids appended when tasks are inserted; `list` walks this in reverse
    /// so the most recently inserted tasks come first.
    order: Vec<String>,
    /// Per-task broadcast senders used to publish stream events; `subscribe`
    /// hands out receivers from these.
    channels: HashMap<String, broadcast::Sender<StreamResponse>>,
}
58
/// [`TaskStore`] that keeps every task in process memory behind one async mutex.
///
/// Nothing is persisted: all state lives in the fields below and is lost when
/// the store is dropped.
pub struct InMemoryTaskStore {
    /// All mutable state; each `TaskStore` method holds this lock while it runs.
    inner: Mutex<InMemoryInner>,
}
67
68impl InMemoryTaskStore {
69 pub fn new() -> Self {
71 Self {
72 inner: Mutex::new(InMemoryInner {
73 tasks: HashMap::new(),
74 order: Vec::new(),
75 channels: HashMap::new(),
76 }),
77 }
78 }
79}
80
81impl Default for InMemoryTaskStore {
82 fn default() -> Self {
83 Self::new()
84 }
85}
86
87#[async_trait::async_trait]
88impl TaskStore for InMemoryTaskStore {
89 async fn insert(&self, task: Task) -> Result<(), A2aError> {
90 let mut inner = self.inner.lock().await;
91 let task_id = task.id.clone();
92 let (tx, _) = broadcast::channel(CHANNEL_CAPACITY);
93 inner.channels.insert(task_id.clone(), tx);
94 inner.order.push(task_id.clone());
95 inner.tasks.insert(task_id, task);
96 Ok(())
97 }
98
99 async fn get(&self, task_id: &str) -> Result<Option<Task>, A2aError> {
100 let inner = self.inner.lock().await;
101 Ok(inner.tasks.get(task_id).cloned())
102 }
103
104 async fn update_status(&self, task_id: &str, status: TaskStatus) -> Result<(), A2aError> {
105 let mut inner = self.inner.lock().await;
106 let task = inner
107 .tasks
108 .get_mut(task_id)
109 .ok_or_else(|| A2aProtocolError::TaskNotFound {
110 task_id: task_id.to_string(),
111 })?;
112 task.status = status.clone();
113 let context_id = task.context_id.clone().unwrap_or_default();
114 let _ = task;
116
117 if let Some(tx) = inner.channels.get(task_id) {
119 let event = StreamResponse::StatusUpdate(TaskStatusUpdateEvent {
120 task_id: task_id.to_string(),
121 context_id,
122 status,
123 metadata: None,
124 });
125 let _ = tx.send(event);
127 }
128
129 debug!(task_id, "status updated");
130 Ok(())
131 }
132
133 async fn add_artifact(&self, task_id: &str, artifact: Artifact) -> Result<(), A2aError> {
134 let mut inner = self.inner.lock().await;
135 let task = inner
136 .tasks
137 .get_mut(task_id)
138 .ok_or_else(|| A2aProtocolError::TaskNotFound {
139 task_id: task_id.to_string(),
140 })?;
141 task.artifacts.push(artifact.clone());
142 let context_id = task.context_id.clone().unwrap_or_default();
143 let _ = task;
145
146 if let Some(tx) = inner.channels.get(task_id) {
148 let event = StreamResponse::ArtifactUpdate(TaskArtifactUpdateEvent {
149 task_id: task_id.to_string(),
150 context_id,
151 artifact,
152 append: None,
153 last_chunk: None,
154 metadata: None,
155 });
156 let _ = tx.send(event);
157 }
158
159 debug!(task_id, "artifact added");
160 Ok(())
161 }
162
163 async fn append_message(&self, task_id: &str, message: Message) -> Result<(), A2aError> {
164 let mut inner = self.inner.lock().await;
165 let task = inner
166 .tasks
167 .get_mut(task_id)
168 .ok_or_else(|| A2aProtocolError::TaskNotFound {
169 task_id: task_id.to_string(),
170 })?;
171 match task.history {
172 Some(ref mut hist) => hist.push(message),
173 None => task.history = Some(vec![message]),
174 }
175 debug!(task_id, "message appended to history");
176 Ok(())
177 }
178
179 async fn list(&self, req: &ListTasksRequest) -> Result<ListTasksResponse, A2aError> {
180 let inner = self.inner.lock().await;
181
182 let page_size = req.page_size.unwrap_or(50).max(1).min(100) as usize;
183 let history_length = req.history_length;
184 let include_artifacts = req.include_artifacts.unwrap_or(false);
185
186 let ts_after = req
188 .status_timestamp_after
189 .as_deref()
190 .and_then(|s| chrono::DateTime::parse_from_rfc3339(s).ok());
191
192 let mut all_matching: Vec<Task> = Vec::new();
195 for id in inner.order.iter().rev() {
196 if let Some(task) = inner.tasks.get(id) {
197 if let Some(ref ctx) = req.context_id {
199 if task.context_id.as_deref() != Some(ctx.as_str()) {
200 continue;
201 }
202 }
203 if let Some(ref status) = req.status {
205 if task.status.state != *status {
206 continue;
207 }
208 }
209 if let Some(ts_cutoff) = &ts_after {
211 let passes = task
212 .status
213 .timestamp
214 .as_deref()
215 .and_then(|t| chrono::DateTime::parse_from_rfc3339(t).ok())
216 .map(|t| t >= *ts_cutoff)
217 .unwrap_or(false);
218 if !passes {
219 continue;
220 }
221 }
222 all_matching.push(task.clone());
223 }
224 }
225
226 let total_size = all_matching.len() as i32;
227
228 let start_idx = if let Some(ref token) = req.page_token {
230 if token.is_empty() {
231 0
232 } else {
233 all_matching
235 .iter()
236 .position(|t| t.id == *token)
237 .map(|pos| pos + 1)
238 .unwrap_or(all_matching.len())
239 }
240 } else {
241 0
242 };
243
244 let page: Vec<Task> = all_matching
245 .into_iter()
246 .skip(start_idx)
247 .take(page_size)
248 .map(|mut task| {
249 if let Some(hl) = history_length {
251 if hl == 0 {
252 task.history = None;
253 } else if let Some(ref mut hist) = task.history {
254 let hl = hl as usize;
255 if hist.len() > hl {
256 let start = hist.len() - hl;
257 *hist = hist.split_off(start);
258 }
259 }
260 }
261 if !include_artifacts {
263 task.artifacts = vec![];
264 }
265 task
266 })
267 .collect();
268
269 let actual_count = page.len() as i32;
270
271 let next_page_token = if page.len() == page_size
273 && start_idx + page_size < total_size as usize
274 {
275 page.last().map(|t| t.id.clone()).unwrap_or_default()
276 } else {
277 String::new()
278 };
279
280 Ok(ListTasksResponse {
282 tasks: page,
283 next_page_token,
284 page_size: actual_count,
285 total_size,
286 })
287 }
288
289 async fn cancel(&self, task_id: &str) -> Result<(), A2aError> {
290 let mut inner = self.inner.lock().await;
291 let task = inner
292 .tasks
293 .get_mut(task_id)
294 .ok_or_else(|| A2aProtocolError::TaskNotFound {
295 task_id: task_id.to_string(),
296 })?;
297
298 match task.status.state {
300 TaskState::Completed | TaskState::Failed | TaskState::Canceled => {
301 return Err(A2aProtocolError::TaskNotCancelable {
302 task_id: task_id.to_string(),
303 }
304 .into());
305 }
306 _ => {}
307 }
308
309 let canceled_status = TaskStatus {
310 state: TaskState::Canceled,
311 message: None,
312 timestamp: Some(chrono::Utc::now().to_rfc3339()),
313 };
314 task.status = canceled_status.clone();
315 let context_id = task.context_id.clone().unwrap_or_default();
316 let _ = task;
318
319 if let Some(tx) = inner.channels.get(task_id) {
321 let event = StreamResponse::StatusUpdate(TaskStatusUpdateEvent {
322 task_id: task_id.to_string(),
323 context_id,
324 status: canceled_status,
325 metadata: None,
326 });
327 let _ = tx.send(event);
328 }
329
330 debug!(task_id, "task canceled");
331 Ok(())
332 }
333
334 async fn subscribe(&self, task_id: &str) -> Option<broadcast::Receiver<StreamResponse>> {
335 let inner = self.inner.lock().await;
336 inner.channels.get(task_id).map(|tx| tx.subscribe())
337 }
338}
339
#[cfg(test)]
mod tests {
    use super::*;

    /// A freshly submitted task with no context, artifacts, history or metadata.
    fn test_task(id: &str) -> Task {
        Task {
            id: id.into(),
            context_id: None,
            status: status_with(TaskState::Submitted),
            artifacts: vec![],
            history: None,
            metadata: None,
        }
    }

    /// A status in `state` with no message and no timestamp.
    fn status_with(state: TaskState) -> TaskStatus {
        TaskStatus {
            state,
            message: None,
            timestamp: None,
        }
    }

    /// A store pre-populated with one `test_task` per id, inserted in order.
    async fn store_with(ids: &[&str]) -> InMemoryTaskStore {
        let store = InMemoryTaskStore::new();
        for id in ids {
            store.insert(test_task(id)).await.unwrap();
        }
        store
    }

    #[tokio::test]
    async fn insert_and_get() {
        let store = store_with(&["t1"]).await;
        let fetched = store.get("t1").await.unwrap();
        let found = fetched.expect("t1 was inserted");
        assert_eq!(found.id, "t1");
    }

    #[tokio::test]
    async fn get_missing_returns_none() {
        let store = InMemoryTaskStore::new();
        let fetched = store.get("nope").await.unwrap();
        assert!(fetched.is_none());
    }

    #[tokio::test]
    async fn cancel_terminal_task_fails() {
        let store = store_with(&["t1"]).await;
        store
            .update_status("t1", status_with(TaskState::Completed))
            .await
            .unwrap();
        let outcome = store.cancel("t1").await;
        assert!(outcome.is_err());
    }

    #[tokio::test]
    async fn cancel_working_task_succeeds() {
        let store = store_with(&["t1"]).await;
        store
            .update_status("t1", status_with(TaskState::Working))
            .await
            .unwrap();
        store.cancel("t1").await.unwrap();
        let after = store.get("t1").await.unwrap().unwrap();
        assert_eq!(after.status.state, TaskState::Canceled);
    }

    #[tokio::test]
    async fn list_with_pagination() {
        let store = InMemoryTaskStore::new();
        for n in 0..5 {
            store.insert(test_task(&format!("t{n}"))).await.unwrap();
        }

        let first_page = store
            .list(&ListTasksRequest {
                tenant: None,
                context_id: None,
                status: None,
                page_size: Some(2),
                page_token: None,
                history_length: None,
                status_timestamp_after: None,
                include_artifacts: None,
            })
            .await
            .unwrap();
        assert_eq!(first_page.tasks.len(), 2);
        assert_eq!(first_page.total_size, 5);

        // Resume from the token handed back by the first page.
        let second_page = store
            .list(&ListTasksRequest {
                page_token: Some(first_page.next_page_token.clone()),
                page_size: Some(2),
                ..Default::default()
            })
            .await
            .unwrap();
        assert_eq!(second_page.tasks.len(), 2);
    }

    #[tokio::test]
    async fn list_filters_by_context_id() {
        let store = InMemoryTaskStore::new();
        for (id, ctx) in [("t1", "ctx-a"), ("t2", "ctx-b")] {
            let mut task = test_task(id);
            task.context_id = Some(ctx.into());
            store.insert(task).await.unwrap();
        }

        let resp = store
            .list(&ListTasksRequest {
                context_id: Some("ctx-a".into()),
                ..Default::default()
            })
            .await
            .unwrap();
        let ids: Vec<&str> = resp.tasks.iter().map(|t| t.id.as_str()).collect();
        assert_eq!(ids, ["t1"]);
    }

    #[tokio::test]
    async fn list_filters_by_status() {
        let store = store_with(&["t1", "t2"]).await;
        store
            .update_status("t2", status_with(TaskState::Working))
            .await
            .unwrap();

        let resp = store
            .list(&ListTasksRequest {
                status: Some(TaskState::Working),
                ..Default::default()
            })
            .await
            .unwrap();
        let ids: Vec<&str> = resp.tasks.iter().map(|t| t.id.as_str()).collect();
        assert_eq!(ids, ["t2"]);
    }

    #[tokio::test]
    async fn update_status_broadcasts_event() {
        let store = store_with(&["t1"]).await;
        let mut rx = store.subscribe("t1").await.unwrap();

        store
            .update_status("t1", status_with(TaskState::Working))
            .await
            .unwrap();

        let received = rx.recv().await.unwrap();
        if let StreamResponse::StatusUpdate(update) = received {
            assert_eq!(update.task_id, "t1");
            assert_eq!(update.status.state, TaskState::Working);
        } else {
            panic!("expected StatusUpdate event");
        }
    }

    #[tokio::test]
    async fn add_artifact_broadcasts_event() {
        let store = store_with(&["t1"]).await;
        let mut rx = store.subscribe("t1").await.unwrap();

        let artifact = Artifact {
            artifact_id: "a1".into(),
            name: Some("test".into()),
            description: None,
            parts: vec![],
            metadata: None,
            extensions: vec![],
        };
        store.add_artifact("t1", artifact).await.unwrap();

        let received = rx.recv().await.unwrap();
        if let StreamResponse::ArtifactUpdate(update) = received {
            assert_eq!(update.task_id, "t1");
            assert_eq!(update.artifact.artifact_id, "a1");
        } else {
            panic!("expected ArtifactUpdate event");
        }
    }

    #[tokio::test]
    async fn subscribe_missing_task_returns_none() {
        let store = InMemoryTaskStore::new();
        let receiver = store.subscribe("nope").await;
        assert!(receiver.is_none());
    }

    #[tokio::test]
    async fn cancel_missing_task_returns_error() {
        let store = InMemoryTaskStore::new();
        let outcome = store.cancel("nope").await;
        assert!(outcome.is_err());
    }
}
566}