1use std::cmp::Reverse;
2use std::collections::BTreeMap;
3use std::time::{Duration, Instant};
4
5use async_trait::async_trait;
6use tokio::sync::RwLock;
7
8use crate::A2AError;
9use crate::types::{ListTasksRequest, ListTasksResponse, Task};
10
/// Storage backend for A2A tasks.
///
/// The `Send + Sync + 'static` bounds allow a single store to be shared across
/// threads and spawned async tasks (for example behind an `Arc`).
#[async_trait]
pub trait TaskStore: Send + Sync + 'static {
    /// Looks up the task stored under `task_id`.
    ///
    /// Returns `Ok(None)` when no task with that id is available
    /// (`InMemoryTaskStore` also reports expired tasks as absent).
    async fn get(&self, task_id: &str) -> Result<Option<Task>, A2AError>;

    /// Stores `task` keyed by its `id`.
    ///
    /// `InMemoryTaskStore` replaces any task previously stored under the same id.
    async fn put(&self, task: &Task) -> Result<(), A2AError>;

    /// Returns one page of the tasks matching the filters and paging options in `req`.
    async fn list(&self, req: &ListTasksRequest) -> Result<ListTasksResponse, A2AError>;

    /// Removes the task stored under `task_id`.
    ///
    /// Returns whether a task was removed (`InMemoryTaskStore` returns `false`
    /// when no task with that id was stored).
    async fn delete(&self, task_id: &str) -> Result<bool, A2AError>;
}
27
/// Tuning knobs for [`InMemoryTaskStore`]. The `Default` value leaves both
/// fields `None`, i.e. no expiry and no size limit.
#[derive(Debug, Clone, Copy, Default)]
pub struct InMemoryTaskStoreConfig {
    /// Maximum age of an entry, measured from the last `put` of that task.
    /// Entries at least this old are dropped lazily at the start of every store
    /// operation (`get`, `put`, `list`, `delete`). `None` disables expiry.
    pub entry_ttl: Option<Duration>,
    /// Maximum number of stored tasks. When a `put` pushes the store over this
    /// limit, the least recently accessed entries are evicted. `None` means unbounded.
    pub max_entries: Option<usize>,
}
36
/// A stored task plus the timestamps used for TTL expiry and LRU eviction.
#[derive(Debug, Clone)]
struct StoredTask {
    task: Task,
    /// Set when the task is written by `put`; compared against `entry_ttl`.
    updated_at: Instant,
    /// Set on `put` and whenever the task is returned by `get` or `list`; the
    /// entry with the oldest value is evicted first when `max_entries` is exceeded.
    last_accessed_at: Instant,
}
43
/// A [`TaskStore`] that keeps tasks in process memory, keyed by task id.
///
/// Every operation takes the write lock, because even reads update access
/// timestamps, so operations on one store are serialized.
#[derive(Debug)]
pub struct InMemoryTaskStore {
    config: InMemoryTaskStoreConfig,
    // Task id -> stored task, guarded by an async (tokio) `RwLock`.
    tasks: RwLock<BTreeMap<String, StoredTask>>,
}
50
51impl Default for InMemoryTaskStore {
52 fn default() -> Self {
53 Self::with_config(InMemoryTaskStoreConfig::default())
54 }
55}
56
57impl InMemoryTaskStore {
58 pub fn new() -> Self {
60 Self::default()
61 }
62
63 pub fn with_config(config: InMemoryTaskStoreConfig) -> Self {
65 Self {
66 config,
67 tasks: RwLock::new(BTreeMap::new()),
68 }
69 }
70}
71
72#[async_trait]
73impl TaskStore for InMemoryTaskStore {
74 async fn get(&self, task_id: &str) -> Result<Option<Task>, A2AError> {
75 let mut tasks = self.tasks.write().await;
76 purge_expired(&mut tasks, self.config);
77
78 Ok(tasks.get_mut(task_id).map(|stored| {
79 stored.last_accessed_at = Instant::now();
80 stored.task.clone()
81 }))
82 }
83
84 async fn put(&self, task: &Task) -> Result<(), A2AError> {
85 let mut tasks = self.tasks.write().await;
86 purge_expired(&mut tasks, self.config);
87
88 let now = Instant::now();
89 tasks.insert(
90 task.id.clone(),
91 StoredTask {
92 task: task.clone(),
93 updated_at: now,
94 last_accessed_at: now,
95 },
96 );
97 enforce_capacity(&mut tasks, self.config.max_entries);
98 Ok(())
99 }
100
101 async fn list(&self, req: &ListTasksRequest) -> Result<ListTasksResponse, A2AError> {
102 req.validate()?;
103
104 let mut tasks = self.tasks.write().await;
105 purge_expired(&mut tasks, self.config);
106
107 let mut matching_tasks: Vec<Task> =
108 tasks.values().map(|stored| stored.task.clone()).collect();
109 matching_tasks.retain(|task| task_matches(task, req));
110 matching_tasks.sort_by_key(|task| Reverse(task_sort_key(task)));
111
112 let start = req
115 .page_token
116 .as_deref()
117 .unwrap_or("0")
118 .parse::<usize>()
119 .map_err(|_| A2AError::InvalidRequest("invalid pageToken".to_owned()))?;
120 let requested_page_size = req.page_size.unwrap_or(50);
121 let page_size = requested_page_size.clamp(1, 100) as usize;
122 let total_size = matching_tasks.len() as i32;
123 let page = matching_tasks
124 .into_iter()
125 .skip(start)
126 .take(page_size)
127 .map(|mut task| {
128 apply_history_length(&mut task, req.history_length);
129 if req.include_artifacts != Some(true) {
130 task.artifacts.clear();
131 }
132 task
133 })
134 .collect::<Vec<_>>();
135 let accessed_at = Instant::now();
136 for task in &page {
137 if let Some(stored) = tasks.get_mut(&task.id) {
138 stored.last_accessed_at = accessed_at;
139 }
140 }
141
142 let next_start = start + page.len();
143 let next_page_token = if next_start >= total_size as usize {
144 String::new()
145 } else {
146 next_start.to_string()
147 };
148
149 Ok(ListTasksResponse {
150 tasks: page,
151 next_page_token,
152 page_size: requested_page_size,
153 total_size,
154 })
155 }
156
157 async fn delete(&self, task_id: &str) -> Result<bool, A2AError> {
158 let mut tasks = self.tasks.write().await;
159 purge_expired(&mut tasks, self.config);
160
161 Ok(tasks.remove(task_id).is_some())
162 }
163}
164
165fn purge_expired(tasks: &mut BTreeMap<String, StoredTask>, config: InMemoryTaskStoreConfig) {
166 let Some(entry_ttl) = config.entry_ttl else {
167 return;
168 };
169
170 let now = Instant::now();
171 tasks.retain(|_, stored| now.duration_since(stored.updated_at) < entry_ttl);
172}
173
174fn enforce_capacity(tasks: &mut BTreeMap<String, StoredTask>, max_entries: Option<usize>) {
175 let Some(max_entries) = max_entries else {
176 return;
177 };
178
179 while tasks.len() > max_entries {
180 let Some(oldest_key) = tasks
181 .iter()
182 .min_by(|(left_id, left), (right_id, right)| {
183 left.last_accessed_at
184 .cmp(&right.last_accessed_at)
185 .then_with(|| left_id.cmp(right_id))
186 })
187 .map(|(task_id, _)| task_id.clone())
188 else {
189 break;
190 };
191
192 tasks.remove(&oldest_key);
193 }
194}
195
196fn task_matches(task: &Task, req: &ListTasksRequest) -> bool {
197 if let Some(context_id) = &req.context_id
198 && task.context_id.as_deref() != Some(context_id.as_str())
199 {
200 return false;
201 }
202
203 if let Some(status) = req.status
204 && task.status.state != status
205 {
206 return false;
207 }
208
209 if let Some(after) = &req.status_timestamp_after {
210 let Some(timestamp) = task.status.timestamp.as_ref() else {
211 return false;
212 };
213
214 if timestamp < after {
215 return false;
216 }
217 }
218
219 true
220}
221
222fn task_sort_key(task: &Task) -> (String, String) {
223 (
224 task.status.timestamp.clone().unwrap_or_default(),
225 task.id.clone(),
226 )
227}
228
229fn apply_history_length(task: &mut Task, history_length: Option<i32>) {
230 let Some(history_length) = history_length else {
231 return;
232 };
233
234 if history_length <= 0 {
235 task.history.clear();
236 return;
237 }
238
239 let keep = history_length as usize;
240 if task.history.len() > keep {
241 let start = task.history.len() - keep;
242 task.history = task.history.split_off(start);
243 }
244}
245
#[cfg(test)]
mod tests {
    use std::sync::Arc;
    use std::time::Duration;

    use tokio::time::sleep;

    use super::{InMemoryTaskStore, InMemoryTaskStoreConfig, TaskStore};
    use crate::types::{Artifact, ListTasksRequest, Part, Task, TaskState, TaskStatus};

    /// Builds a task with no artifacts, history or metadata.
    fn sample_task(id: &str, context_id: &str, state: TaskState, timestamp: &str) -> Task {
        Task {
            id: id.to_owned(),
            context_id: Some(context_id.to_owned()),
            status: TaskStatus {
                state,
                message: None,
                timestamp: Some(timestamp.to_owned()),
            },
            artifacts: Vec::new(),
            history: Vec::new(),
            metadata: None,
        }
    }

    /// Writes `task` into `store`, failing the test if the write errors.
    async fn store_task(store: &InMemoryTaskStore, task: Task) {
        store.put(&task).await.expect("task should store");
    }

    /// Reports whether `task_id` can currently be read back from `store`.
    async fn is_present(store: &InMemoryTaskStore, task_id: &str) -> bool {
        store
            .get(task_id)
            .await
            .expect("lookup should succeed")
            .is_some()
    }

    #[tokio::test]
    async fn in_memory_task_store_lists_tasks_in_timestamp_order() {
        let store = InMemoryTaskStore::new();
        store_task(
            &store,
            sample_task("task-1", "ctx-1", TaskState::Submitted, "2026-03-12T12:00:00Z"),
        )
        .await;
        store_task(
            &store,
            sample_task("task-2", "ctx-1", TaskState::Working, "2026-03-12T13:00:00Z"),
        )
        .await;

        let response = store
            .list(&ListTasksRequest {
                tenant: None,
                context_id: Some("ctx-1".to_owned()),
                status: None,
                page_size: Some(10),
                page_token: None,
                history_length: None,
                status_timestamp_after: None,
                include_artifacts: None,
            })
            .await
            .expect("tasks should list");

        // Newest status timestamp first.
        let ids: Vec<&str> = response.tasks.iter().map(|task| task.id.as_str()).collect();
        assert_eq!(ids, ["task-2", "task-1"]);
        assert_eq!(response.next_page_token, "");
    }

    #[tokio::test]
    async fn in_memory_task_store_excludes_artifacts_by_default() {
        let store = InMemoryTaskStore::new();

        let mut task = sample_task("task-1", "ctx-1", TaskState::Completed, "2026-03-12T12:00:00Z");
        task.artifacts = vec![Artifact {
            artifact_id: "artifact-1".to_owned(),
            name: None,
            description: None,
            parts: vec![Part {
                text: Some("done".to_owned()),
                raw: None,
                url: None,
                data: None,
                metadata: None,
                filename: None,
                media_type: None,
            }],
            metadata: None,
            extensions: Vec::new(),
        }];
        store_task(&store, task).await;

        let response = store
            .list(&ListTasksRequest {
                tenant: None,
                context_id: None,
                status: None,
                page_size: None,
                page_token: None,
                history_length: None,
                status_timestamp_after: None,
                include_artifacts: None,
            })
            .await
            .expect("tasks should list");

        assert_eq!(response.tasks.len(), 1);
        assert!(response.tasks[0].artifacts.is_empty());
        // An omitted page size is reported as the default of 50.
        assert_eq!(response.page_size, 50);
    }

    #[tokio::test]
    async fn in_memory_task_store_expires_entries_by_ttl() {
        let store = InMemoryTaskStore::with_config(InMemoryTaskStoreConfig {
            entry_ttl: Some(Duration::from_millis(5)),
            max_entries: None,
        });
        store_task(
            &store,
            sample_task("task-1", "ctx-1", TaskState::Submitted, "2026-03-12T12:00:00Z"),
        )
        .await;

        // Outlive the 5 ms TTL.
        sleep(Duration::from_millis(10)).await;

        assert!(!is_present(&store, "task-1").await);
    }

    #[tokio::test]
    async fn in_memory_task_store_evicts_least_recently_used_when_capacity_is_exceeded() {
        let store = InMemoryTaskStore::with_config(InMemoryTaskStoreConfig {
            entry_ttl: None,
            max_entries: Some(2),
        });

        // Short sleeps keep the recorded access instants distinct.
        store_task(
            &store,
            sample_task("task-1", "ctx-1", TaskState::Submitted, "2026-03-12T12:00:00Z"),
        )
        .await;
        sleep(Duration::from_millis(2)).await;

        store_task(
            &store,
            sample_task("task-2", "ctx-2", TaskState::Working, "2026-03-12T12:01:00Z"),
        )
        .await;
        sleep(Duration::from_millis(2)).await;

        // Touch task-1 so task-2 becomes the least recently used entry.
        assert!(is_present(&store, "task-1").await);
        sleep(Duration::from_millis(2)).await;

        store_task(
            &store,
            sample_task("task-3", "ctx-3", TaskState::Completed, "2026-03-12T12:02:00Z"),
        )
        .await;

        assert!(is_present(&store, "task-1").await);
        assert!(!is_present(&store, "task-2").await);
        assert!(is_present(&store, "task-3").await);
    }

    #[tokio::test]
    async fn in_memory_task_store_supports_concurrent_reads_and_writes() {
        let store = Arc::new(InMemoryTaskStore::with_config(InMemoryTaskStoreConfig {
            entry_ttl: None,
            max_entries: None,
        }));

        let handles: Vec<_> = (0..16)
            .map(|index| {
                let store = Arc::clone(&store);
                tokio::spawn(async move {
                    let task_id = format!("task-{index}");
                    let timestamp = format!("2026-03-12T12:{index:02}:00Z");
                    store_task(
                        &store,
                        sample_task(&task_id, "ctx-1", TaskState::Working, &timestamp),
                    )
                    .await;

                    assert!(is_present(&store, &task_id).await);
                })
            })
            .collect();

        for handle in handles {
            handle.await.expect("task should join");
        }

        let response = store
            .list(&ListTasksRequest {
                tenant: None,
                context_id: Some("ctx-1".to_owned()),
                status: None,
                page_size: Some(100),
                page_token: None,
                history_length: None,
                status_timestamp_after: None,
                include_artifacts: Some(true),
            })
            .await
            .expect("tasks should list");

        assert_eq!(response.tasks.len(), 16);
    }
}