use std::collections::HashMap;
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use a2a_protocol_types::error::A2aResult;
use a2a_protocol_types::params::ListTasksParams;
use a2a_protocol_types::responses::TaskListResponse;
use a2a_protocol_types::task::{Task, TaskId};
use tokio::sync::RwLock;
use super::super::task_store::{InMemoryTaskStore, TaskStore, TaskStoreConfig};
use super::context::TenantContext;
/// Configuration for [`TenantAwareInMemoryTaskStore`].
#[derive(Debug, Clone)]
pub struct TenantStoreConfig {
    /// Configuration cloned into each tenant's `InMemoryTaskStore` when that
    /// tenant's partition is first created.
    pub per_tenant: TaskStoreConfig,
    /// Maximum number of tenant partitions that may exist at once (default 1000).
    /// Creating a new partition once this many exist fails with an internal error.
    pub max_tenants: usize,
}
impl Default for TenantStoreConfig {
    /// Default per-tenant store settings with room for up to 1000 tenants.
    fn default() -> Self {
        // Upper bound on the number of distinct tenant partitions.
        const DEFAULT_MAX_TENANTS: usize = 1000;

        let per_tenant = TaskStoreConfig::default();
        Self {
            per_tenant,
            max_tenants: DEFAULT_MAX_TENANTS,
        }
    }
}
/// A task store that keeps a separate [`InMemoryTaskStore`] per tenant.
///
/// The tenant key is taken from `TenantContext::current()` on every call.
/// A tenant's partition is created lazily by the first write (`save` /
/// `insert_if_absent`). Reads, lists, deletes and counts only ever look at
/// the calling tenant's partition.
#[derive(Debug)]
pub struct TenantAwareInMemoryTaskStore {
    /// Tenant key (from `TenantContext::current()`) -> that tenant's partition.
    stores: RwLock<HashMap<String, Arc<InMemoryTaskStore>>>,
    /// Per-partition settings and the cap on the number of partitions.
    config: TenantStoreConfig,
}
impl Default for TenantAwareInMemoryTaskStore {
fn default() -> Self {
Self::new()
}
}
impl TenantAwareInMemoryTaskStore {
#[must_use]
pub fn new() -> Self {
Self {
stores: RwLock::new(HashMap::new()),
config: TenantStoreConfig::default(),
}
}
#[must_use]
pub fn with_config(config: TenantStoreConfig) -> Self {
Self {
stores: RwLock::new(HashMap::new()),
config,
}
}
async fn get_store(&self) -> A2aResult<Arc<InMemoryTaskStore>> {
let tenant = TenantContext::current();
{
let stores = self.stores.read().await;
if let Some(store) = stores.get(&tenant) {
return Ok(Arc::clone(store));
}
}
let mut stores = self.stores.write().await;
if let Some(store) = stores.get(&tenant) {
return Ok(Arc::clone(store));
}
if stores.len() >= self.config.max_tenants {
return Err(a2a_protocol_types::error::A2aError::internal(format!(
"tenant limit exceeded: max {} tenants",
self.config.max_tenants
)));
}
let store = Arc::new(InMemoryTaskStore::with_config(
self.config.per_tenant.clone(),
));
stores.insert(tenant, Arc::clone(&store));
drop(stores);
Ok(store)
}
async fn get_existing_store(&self) -> Option<Arc<InMemoryTaskStore>> {
let tenant = TenantContext::current();
let stores = self.stores.read().await;
stores.get(&tenant).map(Arc::clone)
}
pub async fn tenant_count(&self) -> usize {
self.stores.read().await.len()
}
pub async fn run_eviction_all(&self) {
let stores = self.stores.read().await;
for store in stores.values() {
store.run_eviction().await;
}
}
pub async fn prune_empty_tenants(&self) {
let mut stores = self.stores.write().await;
let mut empty_tenants = Vec::new();
for (tenant, store) in stores.iter() {
if store.count().await.unwrap_or(0) == 0 {
empty_tenants.push(tenant.clone());
}
}
for tenant in empty_tenants {
stores.remove(&tenant);
}
}
}
// The `TaskStore` methods return boxed futures rather than being `async fn`s,
// so every body below is an explicit `Box::pin(async move { .. })`.
#[allow(clippy::manual_async_fn)]
impl TaskStore for TenantAwareInMemoryTaskStore {
    /// Saves `task` in the caller's tenant partition, creating the partition if needed.
    fn save<'a>(&'a self, task: Task) -> Pin<Box<dyn Future<Output = A2aResult<()>> + Send + 'a>> {
        Box::pin(async move { self.get_store().await?.save(task).await })
    }

    /// Looks `id` up in the caller's tenant partition only.
    fn get<'a>(
        &'a self,
        id: &'a TaskId,
    ) -> Pin<Box<dyn Future<Output = A2aResult<Option<Task>>> + Send + 'a>> {
        Box::pin(async move {
            // A tenant without a partition has never stored anything.
            let Some(partition) = self.get_existing_store().await else {
                return Ok(None);
            };
            partition.get(id).await
        })
    }

    /// Lists tasks from the caller's tenant partition only.
    fn list<'a>(
        &'a self,
        params: &'a ListTasksParams,
    ) -> Pin<Box<dyn Future<Output = A2aResult<TaskListResponse>> + Send + 'a>> {
        Box::pin(async move {
            let Some(partition) = self.get_existing_store().await else {
                return Ok(TaskListResponse::new(Vec::new()));
            };
            partition.list(params).await
        })
    }

    /// Inserts `task` into the caller's tenant partition unless that partition
    /// already has it, creating the partition if needed.
    fn insert_if_absent<'a>(
        &'a self,
        task: Task,
    ) -> Pin<Box<dyn Future<Output = A2aResult<bool>> + Send + 'a>> {
        Box::pin(async move { self.get_store().await?.insert_if_absent(task).await })
    }

    /// Deletes `id` from the caller's tenant partition; a no-op if the tenant has no partition.
    fn delete<'a>(
        &'a self,
        id: &'a TaskId,
    ) -> Pin<Box<dyn Future<Output = A2aResult<()>> + Send + 'a>> {
        Box::pin(async move {
            let Some(partition) = self.get_existing_store().await else {
                return Ok(());
            };
            partition.delete(id).await
        })
    }

    /// Counts tasks in the caller's tenant partition only.
    fn count<'a>(&'a self) -> Pin<Box<dyn Future<Output = A2aResult<u64>> + Send + 'a>> {
        Box::pin(async move {
            let Some(partition) = self.get_existing_store().await else {
                return Ok(0);
            };
            partition.count().await
        })
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use a2a_protocol_types::task::{ContextId, TaskState, TaskStatus};

    /// Builds a minimal task with the given id and state in a fixed context.
    fn make_task(id: &str, state: TaskState) -> Task {
        Task {
            id: TaskId::new(id),
            context_id: ContextId::new("ctx-default"),
            status: TaskStatus::new(state),
            history: None,
            artifacts: None,
            metadata: None,
        }
    }

    #[tokio::test]
    async fn tenant_context_default_is_empty_string() {
        let tenant = TenantContext::current();
        assert_eq!(tenant, "", "default tenant should be empty string");
    }

    #[tokio::test]
    async fn tenant_context_scope_sets_and_restores() {
        let before = TenantContext::current();
        assert_eq!(before, "");
        let inside = TenantContext::scope("acme", async { TenantContext::current() }).await;
        assert_eq!(inside, "acme", "scope should set the tenant");
        let after = TenantContext::current();
        assert_eq!(after, "", "tenant should revert after scope exits");
    }

    #[tokio::test]
    async fn tenant_context_nested_scopes() {
        TenantContext::scope("outer", async {
            assert_eq!(TenantContext::current(), "outer");
            TenantContext::scope("inner", async {
                assert_eq!(TenantContext::current(), "inner");
            })
            .await;
            assert_eq!(
                TenantContext::current(),
                "outer",
                "should restore outer tenant after inner scope"
            );
        })
        .await;
    }

    #[tokio::test]
    async fn tenant_isolation_save_and_get() {
        let store = TenantAwareInMemoryTaskStore::new();
        TenantContext::scope("tenant-a", async {
            store
                .save(make_task("t1", TaskState::Submitted))
                .await
                .unwrap();
        })
        .await;
        let found = TenantContext::scope("tenant-a", async {
            store.get(&TaskId::new("t1")).await.unwrap()
        })
        .await;
        assert!(found.is_some(), "tenant-a should see its own task");
        let not_found = TenantContext::scope("tenant-b", async {
            store.get(&TaskId::new("t1")).await.unwrap()
        })
        .await;
        assert!(
            not_found.is_none(),
            "tenant-b should not see tenant-a's task"
        );
    }

    #[tokio::test]
    async fn tenant_isolation_list() {
        let store = TenantAwareInMemoryTaskStore::new();
        TenantContext::scope("alpha", async {
            store
                .save(make_task("a1", TaskState::Submitted))
                .await
                .unwrap();
            store
                .save(make_task("a2", TaskState::Working))
                .await
                .unwrap();
        })
        .await;
        TenantContext::scope("beta", async {
            store
                .save(make_task("b1", TaskState::Submitted))
                .await
                .unwrap();
        })
        .await;
        let alpha_list = TenantContext::scope("alpha", async {
            let params = ListTasksParams::default();
            store.list(&params).await.unwrap()
        })
        .await;
        assert_eq!(
            alpha_list.tasks.len(),
            2,
            "alpha should see only its 2 tasks"
        );
        let beta_list = TenantContext::scope("beta", async {
            let params = ListTasksParams::default();
            store.list(&params).await.unwrap()
        })
        .await;
        assert_eq!(beta_list.tasks.len(), 1, "beta should see only its 1 task");
    }

    #[tokio::test]
    async fn tenant_isolation_delete() {
        let store = TenantAwareInMemoryTaskStore::new();
        TenantContext::scope("tenant-a", async {
            store
                .save(make_task("t1", TaskState::Submitted))
                .await
                .unwrap();
        })
        .await;
        TenantContext::scope("tenant-b", async {
            store.delete(&TaskId::new("t1")).await.unwrap();
        })
        .await;
        let still_exists = TenantContext::scope("tenant-a", async {
            store.get(&TaskId::new("t1")).await.unwrap()
        })
        .await;
        assert!(
            still_exists.is_some(),
            "tenant-a's task should survive tenant-b's delete"
        );
    }

    #[tokio::test]
    async fn tenant_isolation_insert_if_absent() {
        let store = TenantAwareInMemoryTaskStore::new();
        let inserted_a = TenantContext::scope("tenant-a", async {
            store
                .insert_if_absent(make_task("shared-id", TaskState::Submitted))
                .await
                .unwrap()
        })
        .await;
        assert!(inserted_a, "tenant-a insert should succeed");
        let inserted_b = TenantContext::scope("tenant-b", async {
            store
                .insert_if_absent(make_task("shared-id", TaskState::Working))
                .await
                .unwrap()
        })
        .await;
        assert!(
            inserted_b,
            "tenant-b insert of same ID should also succeed (different partition)"
        );
    }

    #[tokio::test]
    async fn tenant_isolation_count() {
        let store = TenantAwareInMemoryTaskStore::new();
        TenantContext::scope("x", async {
            store
                .save(make_task("t1", TaskState::Submitted))
                .await
                .unwrap();
            store
                .save(make_task("t2", TaskState::Submitted))
                .await
                .unwrap();
        })
        .await;
        TenantContext::scope("y", async {
            store
                .save(make_task("t3", TaskState::Submitted))
                .await
                .unwrap();
        })
        .await;
        let count_x = TenantContext::scope("x", async { store.count().await.unwrap() }).await;
        assert_eq!(count_x, 2, "tenant x should have 2 tasks");
        let count_y = TenantContext::scope("y", async { store.count().await.unwrap() }).await;
        assert_eq!(count_y, 1, "tenant y should have 1 task");
    }

    #[tokio::test]
    async fn tenant_count_reflects_active_tenants() {
        let store = TenantAwareInMemoryTaskStore::new();
        assert_eq!(store.tenant_count().await, 0);
        TenantContext::scope("a", async {
            store
                .save(make_task("t1", TaskState::Submitted))
                .await
                .unwrap();
        })
        .await;
        assert_eq!(store.tenant_count().await, 1);
        TenantContext::scope("b", async {
            store
                .save(make_task("t2", TaskState::Submitted))
                .await
                .unwrap();
        })
        .await;
        assert_eq!(store.tenant_count().await, 2);
    }

    #[tokio::test]
    async fn max_tenants_limit_enforced() {
        let config = TenantStoreConfig {
            per_tenant: TaskStoreConfig::default(),
            max_tenants: 2,
        };
        let store = TenantAwareInMemoryTaskStore::with_config(config);
        TenantContext::scope("t1", async {
            store
                .save(make_task("task-a", TaskState::Submitted))
                .await
                .unwrap();
        })
        .await;
        TenantContext::scope("t2", async {
            store
                .save(make_task("task-b", TaskState::Submitted))
                .await
                .unwrap();
        })
        .await;
        let result = TenantContext::scope("t3", async {
            store.save(make_task("task-c", TaskState::Submitted)).await
        })
        .await;
        assert!(
            result.is_err(),
            "exceeding max_tenants should return an error"
        );
    }

    #[tokio::test]
    async fn existing_tenant_does_not_count_against_limit() {
        let config = TenantStoreConfig {
            per_tenant: TaskStoreConfig::default(),
            max_tenants: 1,
        };
        let store = TenantAwareInMemoryTaskStore::with_config(config);
        TenantContext::scope("only", async {
            store
                .save(make_task("t1", TaskState::Submitted))
                .await
                .unwrap();
            store
                .save(make_task("t2", TaskState::Working))
                .await
                .unwrap();
        })
        .await;
        let count = TenantContext::scope("only", async { store.count().await.unwrap() }).await;
        assert_eq!(count, 2, "existing tenant can add more tasks");
    }

    #[tokio::test]
    async fn no_tenant_context_uses_default_partition() {
        let store = TenantAwareInMemoryTaskStore::new();
        store
            .save(make_task("default-task", TaskState::Submitted))
            .await
            .unwrap();
        let fetched = store.get(&TaskId::new("default-task")).await.unwrap();
        assert!(
            fetched.is_some(),
            "task saved without tenant context should be retrievable without context"
        );
        let not_found = TenantContext::scope("other", async {
            store.get(&TaskId::new("default-task")).await.unwrap()
        })
        .await;
        assert!(
            not_found.is_none(),
            "default partition task should not leak to named tenants"
        );
    }

    #[tokio::test]
    async fn prune_empty_tenants_removes_empty_partitions() {
        let store = TenantAwareInMemoryTaskStore::new();
        TenantContext::scope("keep", async {
            store
                .save(make_task("t1", TaskState::Submitted))
                .await
                .unwrap();
        })
        .await;
        TenantContext::scope("remove", async {
            store
                .save(make_task("t2", TaskState::Submitted))
                .await
                .unwrap();
        })
        .await;
        assert_eq!(store.tenant_count().await, 2);
        TenantContext::scope("remove", async {
            store.delete(&TaskId::new("t2")).await.unwrap();
        })
        .await;
        store.prune_empty_tenants().await;
        assert_eq!(
            store.tenant_count().await,
            1,
            "empty tenant partition should be pruned"
        );
    }

    #[test]
    fn default_creates_new_tenant_store() {
        let store = TenantAwareInMemoryTaskStore::default();
        let rt = tokio::runtime::Builder::new_current_thread()
            .enable_all()
            .build()
            .unwrap();
        let count = rt.block_on(store.tenant_count());
        assert_eq!(count, 0, "default store should have no tenants");
    }

    #[tokio::test]
    async fn run_eviction_all_runs_without_error() {
        let store = TenantAwareInMemoryTaskStore::new();
        TenantContext::scope("t1", async {
            store
                .save(make_task("task-a", TaskState::Completed))
                .await
                .unwrap();
        })
        .await;
        TenantContext::scope("t2", async {
            store
                .save(make_task("task-b", TaskState::Working))
                .await
                .unwrap();
        })
        .await;
        store.run_eviction_all().await;
    }

    #[tokio::test]
    async fn get_store_double_check_path() {
        let store = TenantAwareInMemoryTaskStore::new();
        TenantContext::scope("racer", async {
            store
                .save(make_task("t1", TaskState::Submitted))
                .await
                .unwrap();
            store
                .save(make_task("t2", TaskState::Working))
                .await
                .unwrap();
            let count = store.count().await.unwrap();
            assert_eq!(count, 2, "both tasks should be in same tenant store");
        })
        .await;
        assert_eq!(
            store.tenant_count().await,
            1,
            "should have exactly 1 tenant"
        );
    }

    #[test]
    fn default_tenant_store_config() {
        let cfg = TenantStoreConfig::default();
        assert_eq!(cfg.max_tenants, 1000);
    }
}