mod eviction;
use std::collections::BTreeMap;
use std::future::Future;
use std::pin::Pin;
use std::time::Instant;
use a2a_protocol_types::error::A2aResult;
use a2a_protocol_types::params::ListTasksParams;
use a2a_protocol_types::responses::TaskListResponse;
use a2a_protocol_types::task::{Task, TaskId};
use tokio::sync::RwLock;
use super::{TaskStore, TaskStoreConfig};
/// A stored task paired with the instant of its most recent write.
///
/// `last_updated` is stamped with `Instant::now()` each time the task is
/// saved or inserted. It is presumably what the eviction logic uses for TTL
/// and age ordering — TODO confirm against the `eviction` submodule.
#[derive(Clone, Debug)]
pub(super) struct TaskEntry {
    /// The task snapshot as last written.
    pub(super) task: Task,
    /// Monotonic timestamp of the last write to this entry.
    pub(super) last_updated: Instant,
}
/// A [`TaskStore`] that keeps every task in process memory.
///
/// Tasks live in a `BTreeMap` keyed by [`TaskId`], so iteration (and
/// therefore `list` pagination) follows task-ID order. The map is guarded by
/// an async `RwLock`.
#[derive(Debug)]
pub struct InMemoryTaskStore {
    /// Every stored task, keyed by ID.
    pub(super) entries: RwLock<BTreeMap<TaskId, TaskEntry>>,
    /// Capacity, TTL, eviction-interval and page-size settings.
    pub(super) config: TaskStoreConfig,
    /// Counter starting at 0. Presumably bumped per write so eviction can run
    /// every `eviction_interval` writes — TODO confirm in `eviction`.
    pub(super) write_count: std::sync::atomic::AtomicU64,
    /// Flag starting at `false`. Presumably guards against overlapping
    /// eviction passes — TODO confirm in `eviction`.
    pub(super) eviction_in_progress: std::sync::atomic::AtomicBool,
}
impl Default for InMemoryTaskStore {
fn default() -> Self {
Self::new()
}
}
impl InMemoryTaskStore {
    /// Creates an empty store using [`TaskStoreConfig::default()`].
    ///
    /// The crate's own tests pin those defaults: capacity 10 000, a one-hour
    /// TTL, an eviction interval of 64 and a max page size of 1000.
    #[must_use]
    pub fn new() -> Self {
        // Delegate so there is exactly one place that initializes the fields.
        Self::with_config(TaskStoreConfig::default())
    }

    /// Creates an empty store governed by `config`.
    ///
    /// The write counter starts at zero and no eviction is marked in progress.
    #[must_use]
    pub fn with_config(config: TaskStoreConfig) -> Self {
        Self {
            entries: RwLock::new(BTreeMap::new()),
            config,
            write_count: std::sync::atomic::AtomicU64::new(0),
            eviction_in_progress: std::sync::atomic::AtomicBool::new(false),
        }
    }
}
#[allow(clippy::manual_async_fn)]
impl TaskStore for InMemoryTaskStore {
    /// Inserts `task`, overwriting any task with the same ID, and stamps it
    /// with the current instant.
    ///
    /// After the write lock is released, `should_evict` is asked about the new
    /// store size. If it says so, an eviction pass runs before this future
    /// completes.
    fn save<'a>(&'a self, task: Task) -> Pin<Box<dyn Future<Output = A2aResult<()>> + Send + 'a>> {
        Box::pin(async move {
            trace_debug!(task_id = %task.id, state = ?task.status.state, "saving task");
            let len = {
                let mut store = self.entries.write().await;
                store.insert(
                    task.id.clone(),
                    TaskEntry {
                        task,
                        last_updated: Instant::now(),
                    },
                );
                store.len()
            };
            // The write guard is dropped above. The eviction pass presumably
            // takes the lock itself, and tokio's RwLock is not reentrant.
            if self.should_evict(len) {
                self.maybe_evict().await;
            }
            Ok(())
        })
    }

    /// Returns a clone of the task stored under `id`, or `None` if absent.
    fn get<'a>(
        &'a self,
        id: &'a TaskId,
    ) -> Pin<Box<dyn Future<Output = A2aResult<Option<Task>>> + Send + 'a>> {
        Box::pin(async move {
            trace_debug!(task_id = %id, "fetching task");
            let store = self.entries.read().await;
            Ok(store.get(id).map(|entry| entry.task.clone()))
        })
    }

    /// Returns one page of tasks in task-ID order, filtered by the optional
    /// `context_id` and `status` in `params`.
    ///
    /// - `page_size` of `None`/`0` means 50. Every size is clamped to
    ///   `config.max_page_size` and to at least 1.
    /// - `page_token` is the ID of the last task of the previous page; the page
    ///   starts strictly after it. An empty token means "first page".
    /// - An unknown token yields an empty response. This includes a token whose
    ///   task was deleted or evicted after the previous page was served.
    /// - `next_page_token` is empty when no further matching task exists.
    /// - `total_size` counts every task matching the filters, saturated at
    ///   `u32::MAX`.
    fn list<'a>(
        &'a self,
        params: &'a ListTasksParams,
    ) -> Pin<Box<dyn Future<Output = A2aResult<TaskListResponse>> + Send + 'a>> {
        Box::pin(async move {
            /// Page size used when the caller omits `page_size` or sends 0.
            const DEFAULT_PAGE_SIZE: usize = 50;

            let max_page_size = self.config.max_page_size;
            // Never below 1: a zero-sized page would still emit a cursor and
            // silently skip the first matching task on the next request.
            let page_size = match params.page_size {
                Some(0) | None => DEFAULT_PAGE_SIZE.min(max_page_size as usize),
                Some(n) => n.min(max_page_size) as usize,
            }
            .max(1);

            let store = self.entries.read().await;

            // This store emits "" as the "no more pages" token, so treat an
            // incoming empty token as "start from the beginning".
            let lower = match params.page_token.as_deref() {
                Some(token) if !token.is_empty() => {
                    let cursor = TaskId::new(token.to_owned());
                    if !store.contains_key(&cursor) {
                        return Ok(TaskListResponse::new(Vec::<Task>::new()));
                    }
                    std::ops::Bound::Excluded(cursor)
                }
                _ => std::ops::Bound::Unbounded,
            };

            // Fetch one extra match to learn whether another page follows,
            // but only clone the tasks that are actually returned.
            let mut page: Vec<&TaskEntry> = store
                .range((lower, std::ops::Bound::Unbounded))
                .map(|(_, entry)| entry)
                .filter(|entry| matches_list_filters(params, entry))
                .take(page_size.saturating_add(1))
                .collect();
            let has_next_page = page.len() > page_size;
            page.truncate(page_size);
            let next_page_token = if has_next_page {
                page.last()
                    .map(|entry| entry.task.id.0.clone())
                    .unwrap_or_default()
            } else {
                String::new()
            };
            let tasks: Vec<Task> = page.into_iter().map(|entry| entry.task.clone()).collect();

            let matching = if params.context_id.is_none() && params.status.is_none() {
                store.len()
            } else {
                store
                    .values()
                    .filter(|entry| matches_list_filters(params, entry))
                    .count()
            };
            drop(store);

            let mut response = TaskListResponse::new(tasks);
            response.next_page_token = next_page_token;
            response.page_size = u32::try_from(page_size).unwrap_or(u32::MAX);
            response.total_size = u32::try_from(matching).unwrap_or(u32::MAX);
            Ok(response)
        })
    }

    /// Inserts `task` only if no task with the same ID exists.
    ///
    /// Returns `Ok(false)` and leaves the stored task untouched when the ID is
    /// taken. Otherwise it inserts the task, runs the same eviction check as
    /// `save`, and returns `Ok(true)`.
    fn insert_if_absent<'a>(
        &'a self,
        task: Task,
    ) -> Pin<Box<dyn Future<Output = A2aResult<bool>> + Send + 'a>> {
        Box::pin(async move {
            let len = {
                let mut store = self.entries.write().await;
                // One lookup both checks for and claims the slot.
                match store.entry(task.id.clone()) {
                    std::collections::btree_map::Entry::Occupied(_) => return Ok(false),
                    std::collections::btree_map::Entry::Vacant(slot) => {
                        slot.insert(TaskEntry {
                            task,
                            last_updated: Instant::now(),
                        });
                    }
                }
                store.len()
            };
            if self.should_evict(len) {
                self.maybe_evict().await;
            }
            Ok(true)
        })
    }

    /// Removes the task stored under `id`. Removing an absent ID is not an error.
    fn delete<'a>(
        &'a self,
        id: &'a TaskId,
    ) -> Pin<Box<dyn Future<Output = A2aResult<()>> + Send + 'a>> {
        Box::pin(async move {
            self.entries.write().await.remove(id);
            Ok(())
        })
    }

    /// Returns the number of stored tasks (all states).
    fn count<'a>(&'a self) -> Pin<Box<dyn Future<Output = A2aResult<u64>> + Send + 'a>> {
        Box::pin(async move {
            let len = self.entries.read().await.len();
            Ok(len as u64)
        })
    }
}

/// Returns `true` when `entry` satisfies every filter set on `params`
/// (`context_id` and `status`). Unset filters match every entry.
fn matches_list_filters(params: &ListTasksParams, entry: &TaskEntry) -> bool {
    if let Some(ctx) = &params.context_id {
        if entry.task.context_id.0 != *ctx {
            return false;
        }
    }
    if let Some(status) = &params.status {
        if entry.task.status.state != *status {
            return false;
        }
    }
    true
}
#[cfg(test)]
mod tests {
    //! Unit tests for [`InMemoryTaskStore`]: CRUD, listing/pagination and
    //! TTL / capacity eviction.
    use super::*;
    use a2a_protocol_types::task::{ContextId, TaskState, TaskStatus};
    use std::time::Duration;

    /// Builds a bare task in the shared `ctx-default` context.
    fn make_task(id: &str, state: TaskState) -> Task {
        make_task_with_ctx(id, "ctx-default", state)
    }

    /// Builds a bare task (no history, artifacts or metadata) in context `ctx`.
    fn make_task_with_ctx(id: &str, ctx: &str, state: TaskState) -> Task {
        Task {
            id: TaskId::new(id),
            context_id: ContextId::new(ctx),
            status: TaskStatus::new(state),
            history: None,
            artifacts: None,
            metadata: None,
        }
    }

    #[tokio::test]
    async fn save_and_get_returns_task() {
        let store = InMemoryTaskStore::new();
        let task = make_task("t1", TaskState::Submitted);
        store.save(task.clone()).await.unwrap();
        let fetched = store.get(&TaskId::new("t1")).await.unwrap();
        assert!(fetched.is_some(), "saved task should be retrievable");
        assert_eq!(fetched.unwrap().id, task.id);
    }

    #[tokio::test]
    async fn get_nonexistent_returns_none() {
        let store = InMemoryTaskStore::new();
        let result = store.get(&TaskId::new("no-such-task")).await.unwrap();
        assert!(result.is_none(), "missing task should return None");
    }

    #[tokio::test]
    async fn save_overwrites_existing_task() {
        let store = InMemoryTaskStore::new();
        store
            .save(make_task("t1", TaskState::Submitted))
            .await
            .unwrap();
        store
            .save(make_task("t1", TaskState::Working))
            .await
            .unwrap();
        let fetched = store.get(&TaskId::new("t1")).await.unwrap().unwrap();
        assert_eq!(
            fetched.status.state,
            TaskState::Working,
            "save should overwrite existing task"
        );
    }

    #[tokio::test]
    async fn delete_removes_task() {
        let store = InMemoryTaskStore::new();
        store
            .save(make_task("t1", TaskState::Submitted))
            .await
            .unwrap();
        store.delete(&TaskId::new("t1")).await.unwrap();
        let result = store.get(&TaskId::new("t1")).await.unwrap();
        assert!(result.is_none(), "deleted task should no longer exist");
    }

    #[tokio::test]
    async fn delete_nonexistent_is_ok() {
        let store = InMemoryTaskStore::new();
        store.delete(&TaskId::new("ghost")).await.unwrap();
    }

    #[tokio::test]
    async fn insert_if_absent_inserts_new_task() {
        let store = InMemoryTaskStore::new();
        let inserted = store
            .insert_if_absent(make_task("t1", TaskState::Submitted))
            .await
            .unwrap();
        assert!(inserted, "first insert should succeed");
        let fetched = store.get(&TaskId::new("t1")).await.unwrap();
        assert!(fetched.is_some());
    }

    #[tokio::test]
    async fn insert_if_absent_rejects_duplicate() {
        let store = InMemoryTaskStore::new();
        store
            .insert_if_absent(make_task("t1", TaskState::Submitted))
            .await
            .unwrap();
        let second = store
            .insert_if_absent(make_task("t1", TaskState::Working))
            .await
            .unwrap();
        assert!(!second, "duplicate insert should return false");
        let fetched = store.get(&TaskId::new("t1")).await.unwrap().unwrap();
        assert_eq!(
            fetched.status.state,
            TaskState::Submitted,
            "original task should not be overwritten by insert_if_absent"
        );
    }

    #[tokio::test]
    async fn count_empty_store() {
        let store = InMemoryTaskStore::new();
        assert_eq!(store.count().await.unwrap(), 0);
    }

    #[tokio::test]
    async fn count_reflects_saves_and_deletes() {
        let store = InMemoryTaskStore::new();
        store
            .save(make_task("t1", TaskState::Submitted))
            .await
            .unwrap();
        store
            .save(make_task("t2", TaskState::Working))
            .await
            .unwrap();
        assert_eq!(store.count().await.unwrap(), 2);
        store.delete(&TaskId::new("t1")).await.unwrap();
        assert_eq!(store.count().await.unwrap(), 1);
    }

    #[tokio::test]
    async fn list_empty_store_returns_empty() {
        let store = InMemoryTaskStore::new();
        let params = ListTasksParams::default();
        let response = store.list(&params).await.unwrap();
        assert!(response.tasks.is_empty());
        assert!(response.next_page_token.is_empty());
    }

    #[tokio::test]
    async fn list_returns_all_tasks_sorted_by_id() {
        let store = InMemoryTaskStore::new();
        store
            .save(make_task("c", TaskState::Submitted))
            .await
            .unwrap();
        store
            .save(make_task("a", TaskState::Working))
            .await
            .unwrap();
        store
            .save(make_task("b", TaskState::Completed))
            .await
            .unwrap();
        let params = ListTasksParams::default();
        let response = store.list(&params).await.unwrap();
        let ids: Vec<&str> = response.tasks.iter().map(|t| t.id.0.as_str()).collect();
        assert_eq!(ids, vec!["a", "b", "c"], "tasks should be sorted by ID");
    }

    #[tokio::test]
    async fn list_filters_by_context_id() {
        let store = InMemoryTaskStore::new();
        store
            .save(make_task_with_ctx("t1", "ctx-a", TaskState::Submitted))
            .await
            .unwrap();
        store
            .save(make_task_with_ctx("t2", "ctx-b", TaskState::Submitted))
            .await
            .unwrap();
        store
            .save(make_task_with_ctx("t3", "ctx-a", TaskState::Working))
            .await
            .unwrap();
        let params = ListTasksParams {
            context_id: Some("ctx-a".to_string()),
            ..Default::default()
        };
        let response = store.list(&params).await.unwrap();
        assert_eq!(response.tasks.len(), 2);
        assert!(response.tasks.iter().all(|t| t.context_id.0 == "ctx-a"));
    }

    #[tokio::test]
    async fn list_filters_by_status() {
        let store = InMemoryTaskStore::new();
        store
            .save(make_task("t1", TaskState::Submitted))
            .await
            .unwrap();
        store
            .save(make_task("t2", TaskState::Working))
            .await
            .unwrap();
        store
            .save(make_task("t3", TaskState::Submitted))
            .await
            .unwrap();
        let params = ListTasksParams {
            status: Some(TaskState::Submitted),
            ..Default::default()
        };
        let response = store.list(&params).await.unwrap();
        assert_eq!(response.tasks.len(), 2);
    }

    #[tokio::test]
    async fn list_pagination_page_size() {
        let store = InMemoryTaskStore::new();
        for i in 0..5 {
            store
                .save(make_task(&format!("t{i:02}"), TaskState::Submitted))
                .await
                .unwrap();
        }
        let params = ListTasksParams {
            page_size: Some(2),
            ..Default::default()
        };
        let page1 = store.list(&params).await.unwrap();
        assert_eq!(page1.tasks.len(), 2, "first page should have 2 tasks");
        assert!(
            !page1.next_page_token.is_empty(),
            "should have next_page_token when more results exist"
        );
        let params2 = ListTasksParams {
            page_size: Some(2),
            page_token: Some(page1.next_page_token),
            ..Default::default()
        };
        let page2 = store.list(&params2).await.unwrap();
        assert_eq!(page2.tasks.len(), 2, "second page should have 2 tasks");
        let params3 = ListTasksParams {
            page_size: Some(2),
            page_token: Some(page2.next_page_token),
            ..Default::default()
        };
        let page3 = store.list(&params3).await.unwrap();
        assert_eq!(page3.tasks.len(), 1, "third page should have 1 task");
        assert!(
            page3.next_page_token.is_empty(),
            "no more pages after the last task"
        );
    }

    #[tokio::test]
    async fn list_invalid_page_token_returns_empty() {
        let store = InMemoryTaskStore::new();
        store
            .save(make_task("t1", TaskState::Submitted))
            .await
            .unwrap();
        let params = ListTasksParams {
            page_token: Some("nonexistent-cursor".to_string()),
            ..Default::default()
        };
        let response = store.list(&params).await.unwrap();
        assert!(
            response.tasks.is_empty(),
            "invalid page_token should yield empty results"
        );
    }

    #[tokio::test]
    async fn list_page_size_zero_uses_default() {
        let store = InMemoryTaskStore::new();
        for i in 0..60 {
            store
                .save(make_task(&format!("t{i:03}"), TaskState::Submitted))
                .await
                .unwrap();
        }
        let params = ListTasksParams {
            page_size: Some(0),
            ..Default::default()
        };
        let response = store.list(&params).await.unwrap();
        assert_eq!(
            response.tasks.len(),
            50,
            "page_size=0 should use the default of 50"
        );
    }

    #[tokio::test]
    async fn list_page_size_is_clamped_to_max_page_size() {
        let config = TaskStoreConfig {
            max_capacity: None,
            task_ttl: None,
            eviction_interval: 1,
            max_page_size: 2,
        };
        let store = InMemoryTaskStore::with_config(config);
        for i in 0..5 {
            store
                .save(make_task(&format!("t{i}"), TaskState::Submitted))
                .await
                .unwrap();
        }
        let params = ListTasksParams {
            page_size: Some(500),
            ..Default::default()
        };
        let response = store.list(&params).await.unwrap();
        assert_eq!(
            response.tasks.len(),
            2,
            "requested page_size should be clamped to max_page_size"
        );
        assert_eq!(response.page_size, 2);
    }

    #[tokio::test]
    async fn ttl_eviction_removes_expired_terminal_tasks() {
        let config = TaskStoreConfig {
            max_capacity: None,
            task_ttl: Some(Duration::from_millis(1)),
            eviction_interval: 1,
            max_page_size: 100,
        };
        let store = InMemoryTaskStore::with_config(config);
        store
            .save(make_task("terminal", TaskState::Completed))
            .await
            .unwrap();
        store
            .save(make_task("active", TaskState::Working))
            .await
            .unwrap();
        tokio::time::sleep(Duration::from_millis(10)).await;
        store.run_eviction().await;
        assert!(
            store.get(&TaskId::new("terminal")).await.unwrap().is_none(),
            "expired terminal task should be evicted"
        );
        assert!(
            store.get(&TaskId::new("active")).await.unwrap().is_some(),
            "non-terminal task should survive TTL eviction"
        );
    }

    #[tokio::test]
    async fn ttl_eviction_keeps_fresh_terminal_tasks() {
        let config = TaskStoreConfig {
            max_capacity: None,
            task_ttl: Some(Duration::from_secs(3600)),
            eviction_interval: 1,
            max_page_size: 100,
        };
        let store = InMemoryTaskStore::with_config(config);
        store
            .save(make_task("t1", TaskState::Completed))
            .await
            .unwrap();
        store.run_eviction().await;
        assert!(
            store.get(&TaskId::new("t1")).await.unwrap().is_some(),
            "fresh terminal task should not be evicted"
        );
    }

    #[tokio::test]
    async fn max_capacity_eviction_removes_oldest_terminal_tasks() {
        let config = TaskStoreConfig {
            max_capacity: Some(2),
            task_ttl: None,
            eviction_interval: 1,
            max_page_size: 100,
        };
        let store = InMemoryTaskStore::with_config(config);
        store
            .save(make_task("oldest", TaskState::Completed))
            .await
            .unwrap();
        tokio::time::sleep(Duration::from_millis(2)).await;
        store
            .save(make_task("middle", TaskState::Completed))
            .await
            .unwrap();
        tokio::time::sleep(Duration::from_millis(2)).await;
        store
            .save(make_task("newest", TaskState::Completed))
            .await
            .unwrap();
        tokio::time::sleep(Duration::from_millis(10)).await;
        assert!(
            store.get(&TaskId::new("oldest")).await.unwrap().is_none(),
            "oldest terminal task should be evicted when over capacity"
        );
        assert_eq!(
            store.count().await.unwrap(),
            2,
            "store should be back at max capacity"
        );
    }

    #[tokio::test]
    async fn capacity_eviction_prefers_terminal_tasks() {
        let config = TaskStoreConfig {
            max_capacity: Some(2),
            task_ttl: None,
            eviction_interval: 1,
            max_page_size: 100,
        };
        let store = InMemoryTaskStore::with_config(config);
        store
            .save(make_task("active", TaskState::Working))
            .await
            .unwrap();
        tokio::time::sleep(Duration::from_millis(2)).await;
        store
            .save(make_task("done", TaskState::Completed))
            .await
            .unwrap();
        tokio::time::sleep(Duration::from_millis(2)).await;
        store
            .save(make_task("new", TaskState::Submitted))
            .await
            .unwrap();
        tokio::time::sleep(Duration::from_millis(10)).await;
        assert!(
            store.get(&TaskId::new("active")).await.unwrap().is_some(),
            "non-terminal task should survive capacity eviction"
        );
        assert!(
            store.get(&TaskId::new("done")).await.unwrap().is_none(),
            "terminal task should be evicted first"
        );
    }

    #[tokio::test]
    async fn capacity_eviction_falls_back_to_non_terminal_when_needed() {
        let config = TaskStoreConfig {
            max_capacity: Some(2),
            task_ttl: None,
            eviction_interval: 1,
            max_page_size: 100,
        };
        let store = InMemoryTaskStore::with_config(config);
        store
            .save(make_task("oldest-active", TaskState::Working))
            .await
            .unwrap();
        tokio::time::sleep(Duration::from_millis(2)).await;
        store
            .save(make_task("middle-active", TaskState::Submitted))
            .await
            .unwrap();
        tokio::time::sleep(Duration::from_millis(2)).await;
        store
            .save(make_task("newest-active", TaskState::Working))
            .await
            .unwrap();
        tokio::time::sleep(Duration::from_millis(10)).await;
        assert!(
            store
                .get(&TaskId::new("oldest-active"))
                .await
                .unwrap()
                .is_none(),
            "oldest non-terminal task should be evicted as fallback"
        );
        assert_eq!(
            store.count().await.unwrap(),
            2,
            "store should be at max capacity after fallback eviction"
        );
    }

    #[test]
    fn default_creates_new_store() {
        let store = InMemoryTaskStore::default();
        let rt = tokio::runtime::Builder::new_current_thread()
            .enable_all()
            .build()
            .unwrap();
        let count = rt.block_on(store.count()).unwrap();
        assert_eq!(count, 0, "default store should be empty");
    }

    #[test]
    fn default_config_has_expected_values() {
        let cfg = TaskStoreConfig::default();
        assert_eq!(cfg.max_capacity, Some(10_000));
        assert_eq!(cfg.task_ttl, Some(Duration::from_secs(3600)));
        assert_eq!(cfg.eviction_interval, 64);
        assert_eq!(cfg.max_page_size, 1000);
    }
}