a2a_protocol_server/store/tenant/
store.rs1use std::collections::HashMap;
9use std::future::Future;
10use std::pin::Pin;
11use std::sync::Arc;
12
13use a2a_protocol_types::error::A2aResult;
14use a2a_protocol_types::params::ListTasksParams;
15use a2a_protocol_types::responses::TaskListResponse;
16use a2a_protocol_types::task::{Task, TaskId};
17use tokio::sync::RwLock;
18
19use super::super::task_store::{InMemoryTaskStore, TaskStore, TaskStoreConfig};
20use super::context::TenantContext;
21
22#[derive(Debug, Clone)]
26pub struct TenantStoreConfig {
27 pub per_tenant: TaskStoreConfig,
29
30 pub max_tenants: usize,
33}
34
35impl Default for TenantStoreConfig {
36 fn default() -> Self {
37 Self {
38 per_tenant: TaskStoreConfig::default(),
39 max_tenants: 1000,
40 }
41 }
42}
43
44#[derive(Debug)]
81pub struct TenantAwareInMemoryTaskStore {
82 stores: RwLock<HashMap<String, Arc<InMemoryTaskStore>>>,
83 config: TenantStoreConfig,
84}
85
86impl Default for TenantAwareInMemoryTaskStore {
87 fn default() -> Self {
88 Self::new()
89 }
90}
91
92impl TenantAwareInMemoryTaskStore {
93 #[must_use]
95 pub fn new() -> Self {
96 Self {
97 stores: RwLock::new(HashMap::new()),
98 config: TenantStoreConfig::default(),
99 }
100 }
101
102 #[must_use]
104 pub fn with_config(config: TenantStoreConfig) -> Self {
105 Self {
106 stores: RwLock::new(HashMap::new()),
107 config,
108 }
109 }
110
111 async fn get_store(&self) -> A2aResult<Arc<InMemoryTaskStore>> {
113 let tenant = TenantContext::current();
114
115 {
117 let stores = self.stores.read().await;
118 if let Some(store) = stores.get(&tenant) {
119 return Ok(Arc::clone(store));
120 }
121 }
122
123 let mut stores = self.stores.write().await;
125 if let Some(store) = stores.get(&tenant) {
127 return Ok(Arc::clone(store));
128 }
129
130 if stores.len() >= self.config.max_tenants {
131 return Err(a2a_protocol_types::error::A2aError::internal(format!(
132 "tenant limit exceeded: max {} tenants",
133 self.config.max_tenants
134 )));
135 }
136
137 let store = Arc::new(InMemoryTaskStore::with_config(
138 self.config.per_tenant.clone(),
139 ));
140 stores.insert(tenant, Arc::clone(&store));
141 drop(stores);
142 Ok(store)
143 }
144
145 async fn get_existing_store(&self) -> Option<Arc<InMemoryTaskStore>> {
151 let tenant = TenantContext::current();
152 let stores = self.stores.read().await;
153 stores.get(&tenant).map(Arc::clone)
154 }
155
156 pub async fn tenant_count(&self) -> usize {
158 self.stores.read().await.len()
159 }
160
161 pub async fn run_eviction_all(&self) {
165 let stores = self.stores.read().await;
166 for store in stores.values() {
167 store.run_eviction().await;
168 }
169 }
170
171 pub async fn prune_empty_tenants(&self) {
175 let mut stores = self.stores.write().await;
176 let mut empty_tenants = Vec::new();
177 for (tenant, store) in stores.iter() {
178 if store.count().await.unwrap_or(0) == 0 {
179 empty_tenants.push(tenant.clone());
180 }
181 }
182 for tenant in empty_tenants {
183 stores.remove(&tenant);
184 }
185 }
186}
187
188#[allow(clippy::manual_async_fn)]
189impl TaskStore for TenantAwareInMemoryTaskStore {
190 fn save<'a>(
191 &'a self,
192 task: &'a Task,
193 ) -> Pin<Box<dyn Future<Output = A2aResult<()>> + Send + 'a>> {
194 Box::pin(async move {
195 let store = self.get_store().await?;
196 store.save(task).await
197 })
198 }
199
200 fn get<'a>(
201 &'a self,
202 id: &'a TaskId,
203 ) -> Pin<Box<dyn Future<Output = A2aResult<Option<Task>>> + Send + 'a>> {
204 Box::pin(async move {
205 match self.get_existing_store().await {
206 Some(store) => store.get(id).await,
207 None => Ok(None),
208 }
209 })
210 }
211
212 fn list<'a>(
213 &'a self,
214 params: &'a ListTasksParams,
215 ) -> Pin<Box<dyn Future<Output = A2aResult<TaskListResponse>> + Send + 'a>> {
216 Box::pin(async move {
217 match self.get_existing_store().await {
218 Some(store) => store.list(params).await,
219 None => Ok(TaskListResponse::new(Vec::new())),
220 }
221 })
222 }
223
224 fn insert_if_absent<'a>(
225 &'a self,
226 task: &'a Task,
227 ) -> Pin<Box<dyn Future<Output = A2aResult<bool>> + Send + 'a>> {
228 Box::pin(async move {
229 let store = self.get_store().await?;
230 store.insert_if_absent(task).await
231 })
232 }
233
234 fn delete<'a>(
235 &'a self,
236 id: &'a TaskId,
237 ) -> Pin<Box<dyn Future<Output = A2aResult<()>> + Send + 'a>> {
238 Box::pin(async move {
239 match self.get_existing_store().await {
240 Some(store) => store.delete(id).await,
241 None => Ok(()),
242 }
243 })
244 }
245
246 fn count<'a>(&'a self) -> Pin<Box<dyn Future<Output = A2aResult<u64>> + Send + 'a>> {
247 Box::pin(async move {
248 match self.get_existing_store().await {
249 Some(store) => store.count().await,
250 None => Ok(0),
251 }
252 })
253 }
254}
255
#[cfg(test)]
mod tests {
    use super::*;
    use a2a_protocol_types::task::{ContextId, TaskState, TaskStatus};

    /// Builds a minimal task with the given ID and state.
    fn make_task(id: &str, state: TaskState) -> Task {
        Task {
            id: TaskId::new(id),
            context_id: ContextId::new("ctx-default"),
            status: TaskStatus::new(state),
            history: None,
            artifacts: None,
            metadata: None,
        }
    }

    /// Outside any scope, the current tenant is the empty string.
    #[tokio::test]
    async fn tenant_context_default_is_empty_string() {
        let tenant = TenantContext::current();
        assert_eq!(tenant, "", "default tenant should be empty string");
    }

    #[tokio::test]
    async fn tenant_context_scope_sets_and_restores() {
        let before = TenantContext::current();
        assert_eq!(before, "");

        let inside = TenantContext::scope("acme", async { TenantContext::current() }).await;
        assert_eq!(inside, "acme", "scope should set the tenant");

        let after = TenantContext::current();
        assert_eq!(after, "", "tenant should revert after scope exits");
    }

    #[tokio::test]
    async fn tenant_context_nested_scopes() {
        TenantContext::scope("outer", async {
            assert_eq!(TenantContext::current(), "outer");
            TenantContext::scope("inner", async {
                assert_eq!(TenantContext::current(), "inner");
            })
            .await;
            assert_eq!(
                TenantContext::current(),
                "outer",
                "should restore outer tenant after inner scope"
            );
        })
        .await;
    }

    /// A task saved under one tenant is visible to that tenant only.
    #[tokio::test]
    async fn tenant_isolation_save_and_get() {
        let store = TenantAwareInMemoryTaskStore::new();

        TenantContext::scope("tenant-a", async {
            store
                .save(&make_task("t1", TaskState::Submitted))
                .await
                .unwrap();
        })
        .await;

        let found = TenantContext::scope("tenant-a", async {
            store.get(&TaskId::new("t1")).await.unwrap()
        })
        .await;
        assert!(found.is_some(), "tenant-a should see its own task");

        let not_found = TenantContext::scope("tenant-b", async {
            store.get(&TaskId::new("t1")).await.unwrap()
        })
        .await;
        assert!(
            not_found.is_none(),
            "tenant-b should not see tenant-a's task"
        );
    }

    #[tokio::test]
    async fn tenant_isolation_list() {
        let store = TenantAwareInMemoryTaskStore::new();

        TenantContext::scope("alpha", async {
            store
                .save(&make_task("a1", TaskState::Submitted))
                .await
                .unwrap();
            store
                .save(&make_task("a2", TaskState::Working))
                .await
                .unwrap();
        })
        .await;

        TenantContext::scope("beta", async {
            store
                .save(&make_task("b1", TaskState::Submitted))
                .await
                .unwrap();
        })
        .await;

        let alpha_list = TenantContext::scope("alpha", async {
            let params = ListTasksParams::default();
            store.list(&params).await.unwrap()
        })
        .await;
        assert_eq!(
            alpha_list.tasks.len(),
            2,
            "alpha should see only its 2 tasks"
        );

        let beta_list = TenantContext::scope("beta", async {
            let params = ListTasksParams::default();
            store.list(&params).await.unwrap()
        })
        .await;
        assert_eq!(beta_list.tasks.len(), 1, "beta should see only its 1 task");
    }

    #[tokio::test]
    async fn tenant_isolation_delete() {
        let store = TenantAwareInMemoryTaskStore::new();

        TenantContext::scope("tenant-a", async {
            store
                .save(&make_task("t1", TaskState::Submitted))
                .await
                .unwrap();
        })
        .await;

        // tenant-b has no partition, so this delete must not touch tenant-a.
        TenantContext::scope("tenant-b", async {
            store.delete(&TaskId::new("t1")).await.unwrap();
        })
        .await;

        let still_exists = TenantContext::scope("tenant-a", async {
            store.get(&TaskId::new("t1")).await.unwrap()
        })
        .await;
        assert!(
            still_exists.is_some(),
            "tenant-a's task should survive tenant-b's delete"
        );
    }

    #[tokio::test]
    async fn tenant_isolation_insert_if_absent() {
        let store = TenantAwareInMemoryTaskStore::new();

        let inserted_a = TenantContext::scope("tenant-a", async {
            store
                .insert_if_absent(&make_task("shared-id", TaskState::Submitted))
                .await
                .unwrap()
        })
        .await;
        assert!(inserted_a, "tenant-a insert should succeed");

        let inserted_b = TenantContext::scope("tenant-b", async {
            store
                .insert_if_absent(&make_task("shared-id", TaskState::Working))
                .await
                .unwrap()
        })
        .await;
        assert!(
            inserted_b,
            "tenant-b insert of same ID should also succeed (different partition)"
        );
    }

    #[tokio::test]
    async fn tenant_isolation_count() {
        let store = TenantAwareInMemoryTaskStore::new();

        TenantContext::scope("x", async {
            store
                .save(&make_task("t1", TaskState::Submitted))
                .await
                .unwrap();
            store
                .save(&make_task("t2", TaskState::Submitted))
                .await
                .unwrap();
        })
        .await;

        TenantContext::scope("y", async {
            store
                .save(&make_task("t3", TaskState::Submitted))
                .await
                .unwrap();
        })
        .await;

        let count_x = TenantContext::scope("x", async { store.count().await.unwrap() }).await;
        assert_eq!(count_x, 2, "tenant x should have 2 tasks");

        let count_y = TenantContext::scope("y", async { store.count().await.unwrap() }).await;
        assert_eq!(count_y, 1, "tenant y should have 1 task");
    }

    /// Each tenant's first write creates exactly one partition.
    #[tokio::test]
    async fn tenant_count_reflects_active_tenants() {
        let store = TenantAwareInMemoryTaskStore::new();
        assert_eq!(store.tenant_count().await, 0);

        TenantContext::scope("a", async {
            store
                .save(&make_task("t1", TaskState::Submitted))
                .await
                .unwrap();
        })
        .await;
        assert_eq!(store.tenant_count().await, 1);

        TenantContext::scope("b", async {
            store
                .save(&make_task("t2", TaskState::Submitted))
                .await
                .unwrap();
        })
        .await;
        assert_eq!(store.tenant_count().await, 2);
    }

    #[tokio::test]
    async fn max_tenants_limit_enforced() {
        let config = TenantStoreConfig {
            per_tenant: TaskStoreConfig::default(),
            max_tenants: 2,
        };
        let store = TenantAwareInMemoryTaskStore::with_config(config);

        TenantContext::scope("t1", async {
            store
                .save(&make_task("task-a", TaskState::Submitted))
                .await
                .unwrap();
        })
        .await;
        TenantContext::scope("t2", async {
            store
                .save(&make_task("task-b", TaskState::Submitted))
                .await
                .unwrap();
        })
        .await;

        let result = TenantContext::scope("t3", async {
            store.save(&make_task("task-c", TaskState::Submitted)).await
        })
        .await;
        assert!(
            result.is_err(),
            "exceeding max_tenants should return an error"
        );
    }

    #[tokio::test]
    async fn existing_tenant_does_not_count_against_limit() {
        let config = TenantStoreConfig {
            per_tenant: TaskStoreConfig::default(),
            max_tenants: 1,
        };
        let store = TenantAwareInMemoryTaskStore::with_config(config);

        TenantContext::scope("only", async {
            store
                .save(&make_task("t1", TaskState::Submitted))
                .await
                .unwrap();
            // Second write reuses the existing partition, so the limit of 1
            // must not reject it.
            store
                .save(&make_task("t2", TaskState::Working))
                .await
                .unwrap();
        })
        .await;

        let count = TenantContext::scope("only", async { store.count().await.unwrap() }).await;
        assert_eq!(count, 2, "existing tenant can add more tasks");
    }

    /// Calls made outside any tenant scope use the default ("") partition,
    /// which is isolated from named tenants.
    #[tokio::test]
    async fn no_tenant_context_uses_default_partition() {
        let store = TenantAwareInMemoryTaskStore::new();

        store
            .save(&make_task("default-task", TaskState::Submitted))
            .await
            .unwrap();

        let fetched = store.get(&TaskId::new("default-task")).await.unwrap();
        assert!(
            fetched.is_some(),
            "task saved without tenant context should be retrievable without context"
        );

        let not_found = TenantContext::scope("other", async {
            store.get(&TaskId::new("default-task")).await.unwrap()
        })
        .await;
        assert!(
            not_found.is_none(),
            "default partition task should not leak to named tenants"
        );
    }

    /// A partition whose last task was deleted is removed by pruning, while
    /// non-empty partitions are kept.
    #[tokio::test]
    async fn prune_empty_tenants_removes_empty_partitions() {
        let store = TenantAwareInMemoryTaskStore::new();

        TenantContext::scope("keep", async {
            store
                .save(&make_task("t1", TaskState::Submitted))
                .await
                .unwrap();
        })
        .await;
        TenantContext::scope("remove", async {
            store
                .save(&make_task("t2", TaskState::Submitted))
                .await
                .unwrap();
        })
        .await;
        assert_eq!(store.tenant_count().await, 2);

        TenantContext::scope("remove", async {
            store.delete(&TaskId::new("t2")).await.unwrap();
        })
        .await;

        store.prune_empty_tenants().await;
        assert_eq!(
            store.tenant_count().await,
            1,
            "empty tenant partition should be pruned"
        );
    }

    /// `Default` builds an empty store; checked with a hand-built runtime so
    /// the constructor itself runs outside any async context.
    #[test]
    fn default_creates_new_tenant_store() {
        let store = TenantAwareInMemoryTaskStore::default();
        let rt = tokio::runtime::Builder::new_current_thread()
            .enable_all()
            .build()
            .unwrap();
        let count = rt.block_on(store.tenant_count());
        assert_eq!(count, 0, "default store should have no tenants");
    }

    /// Eviction across all tenants completes and leaves the partitions in place.
    #[tokio::test]
    async fn run_eviction_all_runs_without_error() {
        let store = TenantAwareInMemoryTaskStore::new();

        TenantContext::scope("t1", async {
            store
                .save(&make_task("task-a", TaskState::Completed))
                .await
                .unwrap();
        })
        .await;
        TenantContext::scope("t2", async {
            store
                .save(&make_task("task-b", TaskState::Working))
                .await
                .unwrap();
        })
        .await;

        store.run_eviction_all().await;
        assert_eq!(
            store.tenant_count().await,
            2,
            "eviction must not remove tenant partitions"
        );
    }

    /// Repeated writes from one tenant reuse the same partition.
    #[tokio::test]
    async fn get_store_double_check_path() {
        let store = TenantAwareInMemoryTaskStore::new();

        TenantContext::scope("racer", async {
            store
                .save(&make_task("t1", TaskState::Submitted))
                .await
                .unwrap();
            store
                .save(&make_task("t2", TaskState::Working))
                .await
                .unwrap();

            let count = store.count().await.unwrap();
            assert_eq!(count, 2, "both tasks should be in same tenant store");
        })
        .await;

        assert_eq!(
            store.tenant_count().await,
            1,
            "should have exactly 1 tenant"
        );
    }

    #[test]
    fn default_tenant_store_config() {
        let cfg = TenantStoreConfig::default();
        assert_eq!(cfg.max_tenants, 1000);
    }
}