Skip to main content

a2a_rust/
store.rs

1use std::cmp::Reverse;
2use std::collections::BTreeMap;
3use std::time::{Duration, Instant};
4
5use async_trait::async_trait;
6use tokio::sync::RwLock;
7
8use crate::A2AError;
9use crate::types::{ListTasksRequest, ListTasksResponse, Task};
10
11/// Persistence abstraction for task state exposed by the A2A server.
12#[async_trait]
13pub trait TaskStore: Send + Sync + 'static {
14    /// Fetch a task by its identifier.
15    async fn get(&self, task_id: &str) -> Result<Option<Task>, A2AError>;
16
17    /// Insert or replace a task snapshot.
18    async fn put(&self, task: &Task) -> Result<(), A2AError>;
19
20    /// Implementations should reject invalid pagination inputs or delegate to
21    /// `ListTasksRequest::validate()` before applying query semantics.
22    async fn list(&self, req: &ListTasksRequest) -> Result<ListTasksResponse, A2AError>;
23
24    /// Delete a task by identifier.
25    async fn delete(&self, task_id: &str) -> Result<bool, A2AError>;
26}
27
/// Tuning knobs for [`InMemoryTaskStore`].
///
/// Both limits are optional; `None` (the `Default`) disables the corresponding behaviour.
#[derive(Clone, Copy, Debug, Default)]
pub struct InMemoryTaskStoreConfig {
    /// Time-to-live measured from an entry's most recent `put`.
    ///
    /// Expired entries are dropped at the start of each store operation. Reads do
    /// not extend an entry's lifetime.
    pub entry_ttl: Option<Duration>,
    /// Upper bound on the number of stored tasks. When a `put` pushes the count past
    /// this bound, the least-recently-accessed entries are evicted.
    pub max_entries: Option<usize>,
}
36
37#[derive(Debug, Clone)]
38struct StoredTask {
39    task: Task,
40    updated_at: Instant,
41    last_accessed_at: Instant,
42}
43
44/// In-process task store with TTL expiry and LRU capacity eviction.
45#[derive(Debug)]
46pub struct InMemoryTaskStore {
47    config: InMemoryTaskStoreConfig,
48    tasks: RwLock<BTreeMap<String, StoredTask>>,
49}
50
51impl Default for InMemoryTaskStore {
52    fn default() -> Self {
53        Self::with_config(InMemoryTaskStoreConfig::default())
54    }
55}
56
57impl InMemoryTaskStore {
58    /// Create a store with default configuration.
59    pub fn new() -> Self {
60        Self::default()
61    }
62
63    /// Create a store with explicit TTL and capacity settings.
64    pub fn with_config(config: InMemoryTaskStoreConfig) -> Self {
65        Self {
66            config,
67            tasks: RwLock::new(BTreeMap::new()),
68        }
69    }
70}
71
72#[async_trait]
73impl TaskStore for InMemoryTaskStore {
74    async fn get(&self, task_id: &str) -> Result<Option<Task>, A2AError> {
75        let mut tasks = self.tasks.write().await;
76        purge_expired(&mut tasks, self.config);
77
78        Ok(tasks.get_mut(task_id).map(|stored| {
79            stored.last_accessed_at = Instant::now();
80            stored.task.clone()
81        }))
82    }
83
84    async fn put(&self, task: &Task) -> Result<(), A2AError> {
85        let mut tasks = self.tasks.write().await;
86        purge_expired(&mut tasks, self.config);
87
88        let now = Instant::now();
89        tasks.insert(
90            task.id.clone(),
91            StoredTask {
92                task: task.clone(),
93                updated_at: now,
94                last_accessed_at: now,
95            },
96        );
97        enforce_capacity(&mut tasks, self.config.max_entries);
98        Ok(())
99    }
100
101    async fn list(&self, req: &ListTasksRequest) -> Result<ListTasksResponse, A2AError> {
102        req.validate()?;
103
104        let mut tasks = self.tasks.write().await;
105        purge_expired(&mut tasks, self.config);
106
107        let mut matching_tasks: Vec<Task> =
108            tasks.values().map(|stored| stored.task.clone()).collect();
109        matching_tasks.retain(|task| task_matches(task, req));
110        matching_tasks.sort_by_key(|task| Reverse(task_sort_key(task)));
111
112        // The in-memory store currently uses offset-style tokens for simplicity.
113        // Downstream stores should prefer stable cursors that do not shift under writes.
114        let start = req
115            .page_token
116            .as_deref()
117            .unwrap_or("0")
118            .parse::<usize>()
119            .map_err(|_| A2AError::InvalidRequest("invalid pageToken".to_owned()))?;
120        let requested_page_size = req.page_size.unwrap_or(50);
121        let page_size = requested_page_size.clamp(1, 100) as usize;
122        let total_size = matching_tasks.len() as i32;
123        let page = matching_tasks
124            .into_iter()
125            .skip(start)
126            .take(page_size)
127            .map(|mut task| {
128                apply_history_length(&mut task, req.history_length);
129                if req.include_artifacts != Some(true) {
130                    task.artifacts.clear();
131                }
132                task
133            })
134            .collect::<Vec<_>>();
135        let accessed_at = Instant::now();
136        for task in &page {
137            if let Some(stored) = tasks.get_mut(&task.id) {
138                stored.last_accessed_at = accessed_at;
139            }
140        }
141
142        let next_start = start + page.len();
143        let next_page_token = if next_start >= total_size as usize {
144            String::new()
145        } else {
146            next_start.to_string()
147        };
148
149        Ok(ListTasksResponse {
150            tasks: page,
151            next_page_token,
152            page_size: requested_page_size,
153            total_size,
154        })
155    }
156
157    async fn delete(&self, task_id: &str) -> Result<bool, A2AError> {
158        let mut tasks = self.tasks.write().await;
159        purge_expired(&mut tasks, self.config);
160
161        Ok(tasks.remove(task_id).is_some())
162    }
163}
164
165fn purge_expired(tasks: &mut BTreeMap<String, StoredTask>, config: InMemoryTaskStoreConfig) {
166    let Some(entry_ttl) = config.entry_ttl else {
167        return;
168    };
169
170    let now = Instant::now();
171    tasks.retain(|_, stored| now.duration_since(stored.updated_at) < entry_ttl);
172}
173
174fn enforce_capacity(tasks: &mut BTreeMap<String, StoredTask>, max_entries: Option<usize>) {
175    let Some(max_entries) = max_entries else {
176        return;
177    };
178
179    while tasks.len() > max_entries {
180        let Some(oldest_key) = tasks
181            .iter()
182            .min_by(|(left_id, left), (right_id, right)| {
183                left.last_accessed_at
184                    .cmp(&right.last_accessed_at)
185                    .then_with(|| left_id.cmp(right_id))
186            })
187            .map(|(task_id, _)| task_id.clone())
188        else {
189            break;
190        };
191
192        tasks.remove(&oldest_key);
193    }
194}
195
196fn task_matches(task: &Task, req: &ListTasksRequest) -> bool {
197    if let Some(context_id) = &req.context_id
198        && task.context_id.as_deref() != Some(context_id.as_str())
199    {
200        return false;
201    }
202
203    if let Some(status) = req.status
204        && task.status.state != status
205    {
206        return false;
207    }
208
209    if let Some(after) = &req.status_timestamp_after {
210        let Some(timestamp) = task.status.timestamp.as_ref() else {
211            return false;
212        };
213
214        if timestamp < after {
215            return false;
216        }
217    }
218
219    true
220}
221
222fn task_sort_key(task: &Task) -> (String, String) {
223    (
224        task.status.timestamp.clone().unwrap_or_default(),
225        task.id.clone(),
226    )
227}
228
229fn apply_history_length(task: &mut Task, history_length: Option<i32>) {
230    let Some(history_length) = history_length else {
231        return;
232    };
233
234    if history_length <= 0 {
235        task.history.clear();
236        return;
237    }
238
239    let keep = history_length as usize;
240    if task.history.len() > keep {
241        let start = task.history.len() - keep;
242        task.history = task.history.split_off(start);
243    }
244}
245
#[cfg(test)]
mod tests {
    use std::sync::Arc;
    use std::time::Duration;

    use tokio::time::sleep;

    use super::{InMemoryTaskStore, InMemoryTaskStoreConfig, TaskStore};
    use crate::types::{ListTasksRequest, Task, TaskState, TaskStatus};

    /// Build a task fixture with no artifacts, history, or metadata.
    fn sample_task(id: &str, context_id: &str, state: TaskState, timestamp: &str) -> Task {
        Task {
            id: id.to_owned(),
            context_id: Some(context_id.to_owned()),
            status: TaskStatus {
                state,
                message: None,
                timestamp: Some(timestamp.to_owned()),
            },
            artifacts: Vec::new(),
            history: Vec::new(),
            metadata: None,
        }
    }

    /// Build a list request with every filter and pagination field unset.
    fn unfiltered_request() -> ListTasksRequest {
        ListTasksRequest {
            tenant: None,
            context_id: None,
            status: None,
            page_size: None,
            page_token: None,
            history_length: None,
            status_timestamp_after: None,
            include_artifacts: None,
        }
    }

    #[tokio::test]
    async fn in_memory_task_store_lists_tasks_in_timestamp_order() {
        let store = InMemoryTaskStore::new();

        store
            .put(&sample_task(
                "task-1",
                "ctx-1",
                TaskState::Submitted,
                "2026-03-12T12:00:00Z",
            ))
            .await
            .expect("task should store");
        store
            .put(&sample_task(
                "task-2",
                "ctx-1",
                TaskState::Working,
                "2026-03-12T13:00:00Z",
            ))
            .await
            .expect("task should store");

        let mut request = unfiltered_request();
        request.context_id = Some("ctx-1".to_owned());
        request.page_size = Some(10);
        let response = store.list(&request).await.expect("tasks should list");

        assert_eq!(response.tasks.len(), 2);
        assert_eq!(response.tasks[0].id, "task-2");
        assert_eq!(response.tasks[1].id, "task-1");
        assert_eq!(response.next_page_token, "");
    }

    #[tokio::test]
    async fn in_memory_task_store_excludes_artifacts_by_default() {
        let store = InMemoryTaskStore::new();

        let mut task = sample_task(
            "task-1",
            "ctx-1",
            TaskState::Completed,
            "2026-03-12T12:00:00Z",
        );
        task.artifacts = vec![crate::types::Artifact {
            artifact_id: "artifact-1".to_owned(),
            name: None,
            description: None,
            parts: vec![crate::types::Part {
                text: Some("done".to_owned()),
                raw: None,
                url: None,
                data: None,
                metadata: None,
                filename: None,
                media_type: None,
            }],
            metadata: None,
            extensions: Vec::new(),
        }];
        store.put(&task).await.expect("task should store");

        let response = store
            .list(&unfiltered_request())
            .await
            .expect("tasks should list");

        assert_eq!(response.tasks.len(), 1);
        assert!(response.tasks[0].artifacts.is_empty());
        assert_eq!(response.page_size, 50);
    }

    #[tokio::test]
    async fn in_memory_task_store_paginates_with_offset_tokens() {
        let store = InMemoryTaskStore::new();
        for (index, id) in ["task-1", "task-2", "task-3"].into_iter().enumerate() {
            store
                .put(&sample_task(
                    id,
                    "ctx-1",
                    TaskState::Working,
                    &format!("2026-03-12T12:0{index}:00Z"),
                ))
                .await
                .expect("task should store");
        }

        let mut request = unfiltered_request();
        request.page_size = Some(2);
        let first = store.list(&request).await.expect("first page should list");
        let first_ids: Vec<&str> = first.tasks.iter().map(|task| task.id.as_str()).collect();
        assert_eq!(first_ids, ["task-3", "task-2"]);
        assert_eq!(first.next_page_token, "2");
        assert_eq!(first.total_size, 3);

        request.page_token = Some(first.next_page_token.clone());
        let second = store.list(&request).await.expect("second page should list");
        let second_ids: Vec<&str> = second.tasks.iter().map(|task| task.id.as_str()).collect();
        assert_eq!(second_ids, ["task-1"]);
        assert_eq!(second.next_page_token, "");
        assert_eq!(second.total_size, 3);
    }

    #[tokio::test]
    async fn in_memory_task_store_rejects_malformed_page_token() {
        let store = InMemoryTaskStore::new();

        let mut request = unfiltered_request();
        request.page_token = Some("not-a-number".to_owned());

        assert!(store.list(&request).await.is_err());
    }

    #[tokio::test]
    async fn in_memory_task_store_expires_entries_by_ttl() {
        let store = InMemoryTaskStore::with_config(InMemoryTaskStoreConfig {
            entry_ttl: Some(Duration::from_millis(5)),
            max_entries: None,
        });

        store
            .put(&sample_task(
                "task-1",
                "ctx-1",
                TaskState::Submitted,
                "2026-03-12T12:00:00Z",
            ))
            .await
            .expect("task should store");

        sleep(Duration::from_millis(10)).await;

        let task = store.get("task-1").await.expect("lookup should succeed");
        assert!(task.is_none());
    }

    #[tokio::test]
    async fn in_memory_task_store_evicts_least_recently_used_when_capacity_is_exceeded() {
        let store = InMemoryTaskStore::with_config(InMemoryTaskStoreConfig {
            entry_ttl: None,
            max_entries: Some(2),
        });

        store
            .put(&sample_task(
                "task-1",
                "ctx-1",
                TaskState::Submitted,
                "2026-03-12T12:00:00Z",
            ))
            .await
            .expect("task should store");
        sleep(Duration::from_millis(2)).await;

        store
            .put(&sample_task(
                "task-2",
                "ctx-2",
                TaskState::Working,
                "2026-03-12T12:01:00Z",
            ))
            .await
            .expect("task should store");
        sleep(Duration::from_millis(2)).await;

        // Touch task-1 so task-2 becomes the least recently used entry.
        assert!(
            store
                .get("task-1")
                .await
                .expect("lookup should succeed")
                .is_some()
        );
        sleep(Duration::from_millis(2)).await;

        store
            .put(&sample_task(
                "task-3",
                "ctx-3",
                TaskState::Completed,
                "2026-03-12T12:02:00Z",
            ))
            .await
            .expect("task should store");

        assert!(
            store
                .get("task-1")
                .await
                .expect("lookup should succeed")
                .is_some()
        );
        assert!(
            store
                .get("task-2")
                .await
                .expect("lookup should succeed")
                .is_none()
        );
        assert!(
            store
                .get("task-3")
                .await
                .expect("lookup should succeed")
                .is_some()
        );
    }

    #[tokio::test]
    async fn in_memory_task_store_supports_concurrent_reads_and_writes() {
        let store = Arc::new(InMemoryTaskStore::with_config(InMemoryTaskStoreConfig {
            entry_ttl: None,
            max_entries: None,
        }));

        let handles: Vec<_> = (0..16)
            .map(|index| {
                let store = Arc::clone(&store);
                tokio::spawn(async move {
                    let task_id = format!("task-{index}");
                    store
                        .put(&sample_task(
                            &task_id,
                            "ctx-1",
                            TaskState::Working,
                            &format!("2026-03-12T12:{index:02}:00Z"),
                        ))
                        .await
                        .expect("task should store");

                    let fetched = store.get(&task_id).await.expect("lookup should succeed");
                    assert!(fetched.is_some());
                })
            })
            .collect();

        for handle in handles {
            handle.await.expect("task should join");
        }

        let mut request = unfiltered_request();
        request.context_id = Some("ctx-1".to_owned());
        request.page_size = Some(100);
        request.include_artifacts = Some(true);
        let response = store.list(&request).await.expect("tasks should list");

        assert_eq!(response.tasks.len(), 16);
    }
}