a2a_protocol_server/store/tenant/
store.rs1use std::collections::HashMap;
9use std::future::Future;
10use std::pin::Pin;
11use std::sync::Arc;
12
13use a2a_protocol_types::error::A2aResult;
14use a2a_protocol_types::params::ListTasksParams;
15use a2a_protocol_types::responses::TaskListResponse;
16use a2a_protocol_types::task::{Task, TaskId};
17use tokio::sync::RwLock;
18
19use super::super::task_store::{InMemoryTaskStore, TaskStore, TaskStoreConfig};
20use super::context::TenantContext;
21
22#[derive(Debug, Clone)]
26pub struct TenantStoreConfig {
27 pub per_tenant: TaskStoreConfig,
29
30 pub max_tenants: usize,
33}
34
35impl Default for TenantStoreConfig {
36 fn default() -> Self {
37 Self {
38 per_tenant: TaskStoreConfig::default(),
39 max_tenants: 1000,
40 }
41 }
42}
43
44#[derive(Debug)]
81pub struct TenantAwareInMemoryTaskStore {
82 stores: RwLock<HashMap<String, Arc<InMemoryTaskStore>>>,
83 config: TenantStoreConfig,
84}
85
86impl Default for TenantAwareInMemoryTaskStore {
87 fn default() -> Self {
88 Self::new()
89 }
90}
91
92impl TenantAwareInMemoryTaskStore {
93 #[must_use]
95 pub fn new() -> Self {
96 Self {
97 stores: RwLock::new(HashMap::new()),
98 config: TenantStoreConfig::default(),
99 }
100 }
101
102 #[must_use]
104 pub fn with_config(config: TenantStoreConfig) -> Self {
105 Self {
106 stores: RwLock::new(HashMap::new()),
107 config,
108 }
109 }
110
111 async fn get_store(&self) -> A2aResult<Arc<InMemoryTaskStore>> {
113 let tenant = TenantContext::current();
114
115 {
117 let stores = self.stores.read().await;
118 if let Some(store) = stores.get(&tenant) {
119 return Ok(Arc::clone(store));
120 }
121 }
122
123 let mut stores = self.stores.write().await;
125 if let Some(store) = stores.get(&tenant) {
127 return Ok(Arc::clone(store));
128 }
129
130 if stores.len() >= self.config.max_tenants {
131 return Err(a2a_protocol_types::error::A2aError::internal(format!(
132 "tenant limit exceeded: max {} tenants",
133 self.config.max_tenants
134 )));
135 }
136
137 let store = Arc::new(InMemoryTaskStore::with_config(
138 self.config.per_tenant.clone(),
139 ));
140 stores.insert(tenant, Arc::clone(&store));
141 drop(stores);
142 Ok(store)
143 }
144
145 async fn get_existing_store(&self) -> Option<Arc<InMemoryTaskStore>> {
151 let tenant = TenantContext::current();
152 let stores = self.stores.read().await;
153 stores.get(&tenant).map(Arc::clone)
154 }
155
156 pub async fn tenant_count(&self) -> usize {
158 self.stores.read().await.len()
159 }
160
161 pub async fn run_eviction_all(&self) {
165 let stores = self.stores.read().await;
166 for store in stores.values() {
167 store.run_eviction().await;
168 }
169 }
170
171 pub async fn prune_empty_tenants(&self) {
175 let mut stores = self.stores.write().await;
176 let mut empty_tenants = Vec::new();
177 for (tenant, store) in stores.iter() {
178 if store.count().await.unwrap_or(0) == 0 {
179 empty_tenants.push(tenant.clone());
180 }
181 }
182 for tenant in empty_tenants {
183 stores.remove(&tenant);
184 }
185 }
186}
187
188#[allow(clippy::manual_async_fn)]
189impl TaskStore for TenantAwareInMemoryTaskStore {
190 fn save<'a>(&'a self, task: Task) -> Pin<Box<dyn Future<Output = A2aResult<()>> + Send + 'a>> {
191 Box::pin(async move {
192 let store = self.get_store().await?;
193 store.save(task).await
194 })
195 }
196
197 fn get<'a>(
198 &'a self,
199 id: &'a TaskId,
200 ) -> Pin<Box<dyn Future<Output = A2aResult<Option<Task>>> + Send + 'a>> {
201 Box::pin(async move {
202 match self.get_existing_store().await {
203 Some(store) => store.get(id).await,
204 None => Ok(None),
205 }
206 })
207 }
208
209 fn list<'a>(
210 &'a self,
211 params: &'a ListTasksParams,
212 ) -> Pin<Box<dyn Future<Output = A2aResult<TaskListResponse>> + Send + 'a>> {
213 Box::pin(async move {
214 match self.get_existing_store().await {
215 Some(store) => store.list(params).await,
216 None => Ok(TaskListResponse::new(Vec::new())),
217 }
218 })
219 }
220
221 fn insert_if_absent<'a>(
222 &'a self,
223 task: Task,
224 ) -> Pin<Box<dyn Future<Output = A2aResult<bool>> + Send + 'a>> {
225 Box::pin(async move {
226 let store = self.get_store().await?;
227 store.insert_if_absent(task).await
228 })
229 }
230
231 fn delete<'a>(
232 &'a self,
233 id: &'a TaskId,
234 ) -> Pin<Box<dyn Future<Output = A2aResult<()>> + Send + 'a>> {
235 Box::pin(async move {
236 match self.get_existing_store().await {
237 Some(store) => store.delete(id).await,
238 None => Ok(()),
239 }
240 })
241 }
242
243 fn count<'a>(&'a self) -> Pin<Box<dyn Future<Output = A2aResult<u64>> + Send + 'a>> {
244 Box::pin(async move {
245 match self.get_existing_store().await {
246 Some(store) => store.count().await,
247 None => Ok(0),
248 }
249 })
250 }
251}
252
#[cfg(test)]
mod tests {
    use super::*;
    use a2a_protocol_types::task::{ContextId, TaskState, TaskStatus};

    /// Builds a minimal task with the given ID and state.
    fn make_task(id: &str, state: TaskState) -> Task {
        Task {
            id: TaskId::new(id),
            context_id: ContextId::new("ctx-default"),
            status: TaskStatus::new(state),
            history: None,
            artifacts: None,
            metadata: None,
        }
    }

    /// Outside any scope the current tenant is the empty string.
    #[tokio::test]
    async fn tenant_context_default_is_empty_string() {
        let tenant = TenantContext::current();
        assert_eq!(tenant, "", "default tenant should be empty string");
    }

    /// A scope sets the tenant for its future and the previous value is back
    /// once the scope has finished.
    #[tokio::test]
    async fn tenant_context_scope_sets_and_restores() {
        let before = TenantContext::current();
        assert_eq!(before, "");

        let inside = TenantContext::scope("acme", async { TenantContext::current() }).await;
        assert_eq!(inside, "acme", "scope should set the tenant");

        let after = TenantContext::current();
        assert_eq!(after, "", "tenant should revert after scope exits");
    }

    /// An inner scope overrides the outer tenant only for its own duration.
    #[tokio::test]
    async fn tenant_context_nested_scopes() {
        TenantContext::scope("outer", async {
            assert_eq!(TenantContext::current(), "outer");
            TenantContext::scope("inner", async {
                assert_eq!(TenantContext::current(), "inner");
            })
            .await;
            assert_eq!(
                TenantContext::current(),
                "outer",
                "should restore outer tenant after inner scope"
            );
        })
        .await;
    }

    /// A task saved by one tenant is readable by that tenant only.
    #[tokio::test]
    async fn tenant_isolation_save_and_get() {
        let store = TenantAwareInMemoryTaskStore::new();

        TenantContext::scope("tenant-a", async {
            store
                .save(make_task("t1", TaskState::Submitted))
                .await
                .unwrap();
        })
        .await;

        let found = TenantContext::scope("tenant-a", async {
            store.get(&TaskId::new("t1")).await.unwrap()
        })
        .await;
        assert!(found.is_some(), "tenant-a should see its own task");

        let not_found = TenantContext::scope("tenant-b", async {
            store.get(&TaskId::new("t1")).await.unwrap()
        })
        .await;
        assert!(
            not_found.is_none(),
            "tenant-b should not see tenant-a's task"
        );
    }

    /// Listing returns only the tasks of the calling tenant.
    #[tokio::test]
    async fn tenant_isolation_list() {
        let store = TenantAwareInMemoryTaskStore::new();

        TenantContext::scope("alpha", async {
            store
                .save(make_task("a1", TaskState::Submitted))
                .await
                .unwrap();
            store
                .save(make_task("a2", TaskState::Working))
                .await
                .unwrap();
        })
        .await;

        TenantContext::scope("beta", async {
            store
                .save(make_task("b1", TaskState::Submitted))
                .await
                .unwrap();
        })
        .await;

        let alpha_list = TenantContext::scope("alpha", async {
            let params = ListTasksParams::default();
            store.list(&params).await.unwrap()
        })
        .await;
        assert_eq!(
            alpha_list.tasks.len(),
            2,
            "alpha should see only its 2 tasks"
        );

        let beta_list = TenantContext::scope("beta", async {
            let params = ListTasksParams::default();
            store.list(&params).await.unwrap()
        })
        .await;
        assert_eq!(beta_list.tasks.len(), 1, "beta should see only its 1 task");
    }

    /// Deleting an ID under one tenant leaves another tenant's task with the
    /// same ID untouched.
    #[tokio::test]
    async fn tenant_isolation_delete() {
        let store = TenantAwareInMemoryTaskStore::new();

        TenantContext::scope("tenant-a", async {
            store
                .save(make_task("t1", TaskState::Submitted))
                .await
                .unwrap();
        })
        .await;

        TenantContext::scope("tenant-b", async {
            store.delete(&TaskId::new("t1")).await.unwrap();
        })
        .await;

        let still_exists = TenantContext::scope("tenant-a", async {
            store.get(&TaskId::new("t1")).await.unwrap()
        })
        .await;
        assert!(
            still_exists.is_some(),
            "tenant-a's task should survive tenant-b's delete"
        );
    }

    /// The same task ID can be inserted once per tenant.
    #[tokio::test]
    async fn tenant_isolation_insert_if_absent() {
        let store = TenantAwareInMemoryTaskStore::new();

        let inserted_a = TenantContext::scope("tenant-a", async {
            store
                .insert_if_absent(make_task("shared-id", TaskState::Submitted))
                .await
                .unwrap()
        })
        .await;
        assert!(inserted_a, "tenant-a insert should succeed");

        let inserted_b = TenantContext::scope("tenant-b", async {
            store
                .insert_if_absent(make_task("shared-id", TaskState::Working))
                .await
                .unwrap()
        })
        .await;
        assert!(
            inserted_b,
            "tenant-b insert of same ID should also succeed (different partition)"
        );
    }

    /// `count` reports only the calling tenant's tasks.
    #[tokio::test]
    async fn tenant_isolation_count() {
        let store = TenantAwareInMemoryTaskStore::new();

        TenantContext::scope("x", async {
            store
                .save(make_task("t1", TaskState::Submitted))
                .await
                .unwrap();
            store
                .save(make_task("t2", TaskState::Submitted))
                .await
                .unwrap();
        })
        .await;

        TenantContext::scope("y", async {
            store
                .save(make_task("t3", TaskState::Submitted))
                .await
                .unwrap();
        })
        .await;

        let count_x = TenantContext::scope("x", async { store.count().await.unwrap() }).await;
        assert_eq!(count_x, 2, "tenant x should have 2 tasks");

        let count_y = TenantContext::scope("y", async { store.count().await.unwrap() }).await;
        assert_eq!(count_y, 1, "tenant y should have 1 task");
    }

    /// `tenant_count` grows by one for each tenant that saves a task.
    #[tokio::test]
    async fn tenant_count_reflects_active_tenants() {
        let store = TenantAwareInMemoryTaskStore::new();
        assert_eq!(store.tenant_count().await, 0);

        TenantContext::scope("a", async {
            store
                .save(make_task("t1", TaskState::Submitted))
                .await
                .unwrap();
        })
        .await;
        assert_eq!(store.tenant_count().await, 1);

        TenantContext::scope("b", async {
            store
                .save(make_task("t2", TaskState::Submitted))
                .await
                .unwrap();
        })
        .await;
        assert_eq!(store.tenant_count().await, 2);
    }

    /// Saving under a new tenant fails once `max_tenants` partitions exist.
    #[tokio::test]
    async fn max_tenants_limit_enforced() {
        let config = TenantStoreConfig {
            per_tenant: TaskStoreConfig::default(),
            max_tenants: 2,
        };
        let store = TenantAwareInMemoryTaskStore::with_config(config);

        TenantContext::scope("t1", async {
            store
                .save(make_task("task-a", TaskState::Submitted))
                .await
                .unwrap();
        })
        .await;
        TenantContext::scope("t2", async {
            store
                .save(make_task("task-b", TaskState::Submitted))
                .await
                .unwrap();
        })
        .await;

        let result = TenantContext::scope("t3", async {
            store.save(make_task("task-c", TaskState::Submitted)).await
        })
        .await;
        assert!(
            result.is_err(),
            "exceeding max_tenants should return an error"
        );
    }

    /// A tenant that already has a partition can keep saving even when the
    /// tenant limit has been reached.
    #[tokio::test]
    async fn existing_tenant_does_not_count_against_limit() {
        let config = TenantStoreConfig {
            per_tenant: TaskStoreConfig::default(),
            max_tenants: 1,
        };
        let store = TenantAwareInMemoryTaskStore::with_config(config);

        TenantContext::scope("only", async {
            store
                .save(make_task("t1", TaskState::Submitted))
                .await
                .unwrap();
            store
                .save(make_task("t2", TaskState::Working))
                .await
                .unwrap();
        })
        .await;

        let count = TenantContext::scope("only", async { store.count().await.unwrap() }).await;
        assert_eq!(count, 2, "existing tenant can add more tasks");
    }

    /// Operations outside any scope share one partition that named tenants
    /// cannot see.
    #[tokio::test]
    async fn no_tenant_context_uses_default_partition() {
        let store = TenantAwareInMemoryTaskStore::new();

        store
            .save(make_task("default-task", TaskState::Submitted))
            .await
            .unwrap();

        let fetched = store.get(&TaskId::new("default-task")).await.unwrap();
        assert!(
            fetched.is_some(),
            "task saved without tenant context should be retrievable without context"
        );

        let not_found = TenantContext::scope("other", async {
            store.get(&TaskId::new("default-task")).await.unwrap()
        })
        .await;
        assert!(
            not_found.is_none(),
            "default partition task should not leak to named tenants"
        );
    }

    /// Pruning drops a partition whose last task was deleted and keeps the
    /// partition that still holds a task.
    #[tokio::test]
    async fn prune_empty_tenants_removes_empty_partitions() {
        let store = TenantAwareInMemoryTaskStore::new();

        TenantContext::scope("keep", async {
            store
                .save(make_task("t1", TaskState::Submitted))
                .await
                .unwrap();
        })
        .await;
        TenantContext::scope("remove", async {
            store
                .save(make_task("t2", TaskState::Submitted))
                .await
                .unwrap();
        })
        .await;
        assert_eq!(store.tenant_count().await, 2);

        TenantContext::scope("remove", async {
            store.delete(&TaskId::new("t2")).await.unwrap();
        })
        .await;

        store.prune_empty_tenants().await;
        assert_eq!(
            store.tenant_count().await,
            1,
            "empty tenant partition should be pruned"
        );
    }

    /// `Default` yields a store with no tenants. The store is built outside a
    /// runtime and queried on a manually built current-thread runtime.
    #[test]
    fn default_creates_new_tenant_store() {
        let store = TenantAwareInMemoryTaskStore::default();
        let rt = tokio::runtime::Builder::new_current_thread()
            .enable_all()
            .build()
            .unwrap();
        let count = rt.block_on(store.tenant_count());
        assert_eq!(count, 0, "default store should have no tenants");
    }

    /// Smoke test: eviction across several tenants completes without
    /// panicking.
    #[tokio::test]
    async fn run_eviction_all_runs_without_error() {
        let store = TenantAwareInMemoryTaskStore::new();

        TenantContext::scope("t1", async {
            store
                .save(make_task("task-a", TaskState::Completed))
                .await
                .unwrap();
        })
        .await;
        TenantContext::scope("t2", async {
            store
                .save(make_task("task-b", TaskState::Working))
                .await
                .unwrap();
        })
        .await;

        store.run_eviction_all().await;
    }

    /// Two sequential saves under one tenant land in the same partition: the
    /// second save reuses the partition created by the first.
    #[tokio::test]
    async fn get_store_double_check_path() {
        let store = TenantAwareInMemoryTaskStore::new();

        TenantContext::scope("racer", async {
            store
                .save(make_task("t1", TaskState::Submitted))
                .await
                .unwrap();
            store
                .save(make_task("t2", TaskState::Working))
                .await
                .unwrap();

            let count = store.count().await.unwrap();
            assert_eq!(count, 2, "both tasks should be in same tenant store");
        })
        .await;

        assert_eq!(
            store.tenant_count().await,
            1,
            "should have exactly 1 tenant"
        );
    }

    /// The default configuration allows 1000 tenants.
    #[test]
    fn default_tenant_store_config() {
        let cfg = TenantStoreConfig::default();
        assert_eq!(cfg.max_tenants, 1000);
    }
}