a2a_protocol_server/store/tenant/
store.rs1use std::collections::HashMap;
9use std::future::Future;
10use std::pin::Pin;
11use std::sync::Arc;
12
13use a2a_protocol_types::error::A2aResult;
14use a2a_protocol_types::params::ListTasksParams;
15use a2a_protocol_types::responses::TaskListResponse;
16use a2a_protocol_types::task::{Task, TaskId};
17use tokio::sync::RwLock;
18
19use super::super::task_store::{InMemoryTaskStore, TaskStore, TaskStoreConfig};
20use super::context::TenantContext;
21
22#[derive(Debug, Clone)]
26pub struct TenantStoreConfig {
27 pub per_tenant: TaskStoreConfig,
29
30 pub max_tenants: usize,
33}
34
35impl Default for TenantStoreConfig {
36 fn default() -> Self {
37 Self {
38 per_tenant: TaskStoreConfig::default(),
39 max_tenants: 1000,
40 }
41 }
42}
43
44#[derive(Debug)]
81pub struct TenantAwareInMemoryTaskStore {
82 stores: RwLock<HashMap<String, Arc<InMemoryTaskStore>>>,
83 config: TenantStoreConfig,
84}
85
86impl Default for TenantAwareInMemoryTaskStore {
87 fn default() -> Self {
88 Self::new()
89 }
90}
91
92impl TenantAwareInMemoryTaskStore {
93 #[must_use]
95 pub fn new() -> Self {
96 Self {
97 stores: RwLock::new(HashMap::new()),
98 config: TenantStoreConfig::default(),
99 }
100 }
101
102 #[must_use]
104 pub fn with_config(config: TenantStoreConfig) -> Self {
105 Self {
106 stores: RwLock::new(HashMap::new()),
107 config,
108 }
109 }
110
111 async fn get_store(&self) -> A2aResult<Arc<InMemoryTaskStore>> {
113 let tenant = TenantContext::current();
114
115 {
117 let stores = self.stores.read().await;
118 if let Some(store) = stores.get(&tenant) {
119 return Ok(Arc::clone(store));
120 }
121 }
122
123 let mut stores = self.stores.write().await;
125 if let Some(store) = stores.get(&tenant) {
127 return Ok(Arc::clone(store));
128 }
129
130 if stores.len() >= self.config.max_tenants {
131 return Err(a2a_protocol_types::error::A2aError::internal(format!(
132 "tenant limit exceeded: max {} tenants",
133 self.config.max_tenants
134 )));
135 }
136
137 let store = Arc::new(InMemoryTaskStore::with_config(
138 self.config.per_tenant.clone(),
139 ));
140 stores.insert(tenant, Arc::clone(&store));
141 drop(stores);
142 Ok(store)
143 }
144
145 pub async fn tenant_count(&self) -> usize {
147 self.stores.read().await.len()
148 }
149
150 pub async fn run_eviction_all(&self) {
154 let stores = self.stores.read().await;
155 for store in stores.values() {
156 store.run_eviction().await;
157 }
158 }
159
160 pub async fn prune_empty_tenants(&self) {
164 let mut stores = self.stores.write().await;
165 let mut empty_tenants = Vec::new();
166 for (tenant, store) in stores.iter() {
167 if store.count().await.unwrap_or(0) == 0 {
168 empty_tenants.push(tenant.clone());
169 }
170 }
171 for tenant in empty_tenants {
172 stores.remove(&tenant);
173 }
174 }
175}
176
177#[allow(clippy::manual_async_fn)]
178impl TaskStore for TenantAwareInMemoryTaskStore {
179 fn save<'a>(&'a self, task: Task) -> Pin<Box<dyn Future<Output = A2aResult<()>> + Send + 'a>> {
180 Box::pin(async move {
181 let store = self.get_store().await?;
182 store.save(task).await
183 })
184 }
185
186 fn get<'a>(
187 &'a self,
188 id: &'a TaskId,
189 ) -> Pin<Box<dyn Future<Output = A2aResult<Option<Task>>> + Send + 'a>> {
190 Box::pin(async move {
191 let store = self.get_store().await?;
192 store.get(id).await
193 })
194 }
195
196 fn list<'a>(
197 &'a self,
198 params: &'a ListTasksParams,
199 ) -> Pin<Box<dyn Future<Output = A2aResult<TaskListResponse>> + Send + 'a>> {
200 Box::pin(async move {
201 let store = self.get_store().await?;
202 store.list(params).await
203 })
204 }
205
206 fn insert_if_absent<'a>(
207 &'a self,
208 task: Task,
209 ) -> Pin<Box<dyn Future<Output = A2aResult<bool>> + Send + 'a>> {
210 Box::pin(async move {
211 let store = self.get_store().await?;
212 store.insert_if_absent(task).await
213 })
214 }
215
216 fn delete<'a>(
217 &'a self,
218 id: &'a TaskId,
219 ) -> Pin<Box<dyn Future<Output = A2aResult<()>> + Send + 'a>> {
220 Box::pin(async move {
221 let store = self.get_store().await?;
222 store.delete(id).await
223 })
224 }
225
226 fn count<'a>(&'a self) -> Pin<Box<dyn Future<Output = A2aResult<u64>> + Send + 'a>> {
227 Box::pin(async move {
228 let store = self.get_store().await?;
229 store.count().await
230 })
231 }
232}
233
#[cfg(test)]
mod tests {
    use super::*;
    use a2a_protocol_types::task::{ContextId, TaskState, TaskStatus};

    /// Builds a minimal task with the given ID and state and no optional data.
    fn make_task(id: &str, state: TaskState) -> Task {
        Task {
            id: TaskId::new(id),
            context_id: ContextId::new("ctx-default"),
            status: TaskStatus::new(state),
            history: None,
            artifacts: None,
            metadata: None,
        }
    }

    // ── TenantContext ────────────────────────────────────────────────────

    /// Outside any scope the current tenant is the empty string.
    #[tokio::test]
    async fn tenant_context_default_is_empty_string() {
        let tenant = TenantContext::current();
        assert_eq!(tenant, "", "default tenant should be empty string");
    }

    /// A scope sets the tenant for its future and the previous value is
    /// visible again once the scope has finished.
    #[tokio::test]
    async fn tenant_context_scope_sets_and_restores() {
        let before = TenantContext::current();
        assert_eq!(before, "");

        let inside = TenantContext::scope("acme", async { TenantContext::current() }).await;
        assert_eq!(inside, "acme", "scope should set the tenant");

        let after = TenantContext::current();
        assert_eq!(after, "", "tenant should revert after scope exits");
    }

    /// An inner scope overrides the outer tenant only for its own duration.
    #[tokio::test]
    async fn tenant_context_nested_scopes() {
        TenantContext::scope("outer", async {
            assert_eq!(TenantContext::current(), "outer");
            TenantContext::scope("inner", async {
                assert_eq!(TenantContext::current(), "inner");
            })
            .await;
            assert_eq!(
                TenantContext::current(),
                "outer",
                "should restore outer tenant after inner scope"
            );
        })
        .await;
    }

    // ── Tenant isolation ─────────────────────────────────────────────────

    /// A task saved by one tenant is visible to that tenant only.
    #[tokio::test]
    async fn tenant_isolation_save_and_get() {
        let store = TenantAwareInMemoryTaskStore::new();

        TenantContext::scope("tenant-a", async {
            store
                .save(make_task("t1", TaskState::Submitted))
                .await
                .unwrap();
        })
        .await;

        let found = TenantContext::scope("tenant-a", async {
            store.get(&TaskId::new("t1")).await.unwrap()
        })
        .await;
        assert!(found.is_some(), "tenant-a should see its own task");

        let not_found = TenantContext::scope("tenant-b", async {
            store.get(&TaskId::new("t1")).await.unwrap()
        })
        .await;
        assert!(
            not_found.is_none(),
            "tenant-b should not see tenant-a's task"
        );
    }

    /// Listing returns only the calling tenant's tasks.
    #[tokio::test]
    async fn tenant_isolation_list() {
        let store = TenantAwareInMemoryTaskStore::new();

        TenantContext::scope("alpha", async {
            store
                .save(make_task("a1", TaskState::Submitted))
                .await
                .unwrap();
            store
                .save(make_task("a2", TaskState::Working))
                .await
                .unwrap();
        })
        .await;

        TenantContext::scope("beta", async {
            store
                .save(make_task("b1", TaskState::Submitted))
                .await
                .unwrap();
        })
        .await;

        let alpha_list = TenantContext::scope("alpha", async {
            let params = ListTasksParams::default();
            store.list(&params).await.unwrap()
        })
        .await;
        assert_eq!(
            alpha_list.tasks.len(),
            2,
            "alpha should see only its 2 tasks"
        );

        let beta_list = TenantContext::scope("beta", async {
            let params = ListTasksParams::default();
            store.list(&params).await.unwrap()
        })
        .await;
        assert_eq!(beta_list.tasks.len(), 1, "beta should see only its 1 task");
    }

    /// Deleting an ID in one tenant leaves the same ID in another tenant alone.
    #[tokio::test]
    async fn tenant_isolation_delete() {
        let store = TenantAwareInMemoryTaskStore::new();

        TenantContext::scope("tenant-a", async {
            store
                .save(make_task("t1", TaskState::Submitted))
                .await
                .unwrap();
        })
        .await;

        // tenant-b deletes the same ID; this must not reach tenant-a's partition.
        TenantContext::scope("tenant-b", async {
            store.delete(&TaskId::new("t1")).await.unwrap();
        })
        .await;

        let still_exists = TenantContext::scope("tenant-a", async {
            store.get(&TaskId::new("t1")).await.unwrap()
        })
        .await;
        assert!(
            still_exists.is_some(),
            "tenant-a's task should survive tenant-b's delete"
        );
    }

    /// The same task ID can be inserted independently by two tenants.
    #[tokio::test]
    async fn tenant_isolation_insert_if_absent() {
        let store = TenantAwareInMemoryTaskStore::new();

        let inserted_a = TenantContext::scope("tenant-a", async {
            store
                .insert_if_absent(make_task("shared-id", TaskState::Submitted))
                .await
                .unwrap()
        })
        .await;
        assert!(inserted_a, "tenant-a insert should succeed");

        let inserted_b = TenantContext::scope("tenant-b", async {
            store
                .insert_if_absent(make_task("shared-id", TaskState::Working))
                .await
                .unwrap()
        })
        .await;
        assert!(
            inserted_b,
            "tenant-b insert of same ID should also succeed (different partition)"
        );
    }

    /// `count` reports only the calling tenant's tasks.
    #[tokio::test]
    async fn tenant_isolation_count() {
        let store = TenantAwareInMemoryTaskStore::new();

        TenantContext::scope("x", async {
            store
                .save(make_task("t1", TaskState::Submitted))
                .await
                .unwrap();
            store
                .save(make_task("t2", TaskState::Submitted))
                .await
                .unwrap();
        })
        .await;

        TenantContext::scope("y", async {
            store
                .save(make_task("t3", TaskState::Submitted))
                .await
                .unwrap();
        })
        .await;

        let count_x = TenantContext::scope("x", async { store.count().await.unwrap() }).await;
        assert_eq!(count_x, 2, "tenant x should have 2 tasks");

        let count_y = TenantContext::scope("y", async { store.count().await.unwrap() }).await;
        assert_eq!(count_y, 1, "tenant y should have 1 task");
    }

    // ── Tenant count and limits ──────────────────────────────────────────

    /// `tenant_count` grows by one for each new tenant that touches the store.
    #[tokio::test]
    async fn tenant_count_reflects_active_tenants() {
        let store = TenantAwareInMemoryTaskStore::new();
        assert_eq!(store.tenant_count().await, 0);

        TenantContext::scope("a", async {
            store
                .save(make_task("t1", TaskState::Submitted))
                .await
                .unwrap();
        })
        .await;
        assert_eq!(store.tenant_count().await, 1);

        TenantContext::scope("b", async {
            store
                .save(make_task("t2", TaskState::Submitted))
                .await
                .unwrap();
        })
        .await;
        assert_eq!(store.tenant_count().await, 2);
    }

    /// A tenant beyond `max_tenants` is rejected with an error.
    #[tokio::test]
    async fn max_tenants_limit_enforced() {
        let config = TenantStoreConfig {
            per_tenant: TaskStoreConfig::default(),
            max_tenants: 2,
        };
        let store = TenantAwareInMemoryTaskStore::with_config(config);

        // Fill both available tenant slots.
        TenantContext::scope("t1", async {
            store
                .save(make_task("task-a", TaskState::Submitted))
                .await
                .unwrap();
        })
        .await;
        TenantContext::scope("t2", async {
            store
                .save(make_task("task-b", TaskState::Submitted))
                .await
                .unwrap();
        })
        .await;

        // A third tenant must be refused.
        let result = TenantContext::scope("t3", async {
            store.save(make_task("task-c", TaskState::Submitted)).await
        })
        .await;
        assert!(
            result.is_err(),
            "exceeding max_tenants should return an error"
        );
    }

    /// The limit applies to new partitions only; an existing tenant keeps working.
    #[tokio::test]
    async fn existing_tenant_does_not_count_against_limit() {
        let config = TenantStoreConfig {
            per_tenant: TaskStoreConfig::default(),
            max_tenants: 1,
        };
        let store = TenantAwareInMemoryTaskStore::with_config(config);

        TenantContext::scope("only", async {
            store
                .save(make_task("t1", TaskState::Submitted))
                .await
                .unwrap();
            // Second save reuses the existing partition, so the limit is not hit.
            store
                .save(make_task("t2", TaskState::Working))
                .await
                .unwrap();
        })
        .await;

        let count = TenantContext::scope("only", async { store.count().await.unwrap() }).await;
        assert_eq!(count, 2, "existing tenant can add more tasks");
    }

    // ── Default partition ────────────────────────────────────────────────

    /// Without a tenant scope, tasks live in their own partition that named
    /// tenants cannot see.
    #[tokio::test]
    async fn no_tenant_context_uses_default_partition() {
        let store = TenantAwareInMemoryTaskStore::new();

        // No TenantContext::scope here: the default tenant is used.
        store
            .save(make_task("default-task", TaskState::Submitted))
            .await
            .unwrap();

        let fetched = store.get(&TaskId::new("default-task")).await.unwrap();
        assert!(
            fetched.is_some(),
            "task saved without tenant context should be retrievable without context"
        );

        let not_found = TenantContext::scope("other", async {
            store.get(&TaskId::new("default-task")).await.unwrap()
        })
        .await;
        assert!(
            not_found.is_none(),
            "default partition task should not leak to named tenants"
        );
    }

    // ── Maintenance ──────────────────────────────────────────────────────

    /// Pruning drops a tenant whose last task was deleted and keeps the rest.
    #[tokio::test]
    async fn prune_empty_tenants_removes_empty_partitions() {
        let store = TenantAwareInMemoryTaskStore::new();

        TenantContext::scope("keep", async {
            store
                .save(make_task("t1", TaskState::Submitted))
                .await
                .unwrap();
        })
        .await;
        TenantContext::scope("remove", async {
            store
                .save(make_task("t2", TaskState::Submitted))
                .await
                .unwrap();
        })
        .await;
        assert_eq!(store.tenant_count().await, 2);

        // Empty the "remove" partition.
        TenantContext::scope("remove", async {
            store.delete(&TaskId::new("t2")).await.unwrap();
        })
        .await;

        store.prune_empty_tenants().await;
        assert_eq!(
            store.tenant_count().await,
            1,
            "empty tenant partition should be pruned"
        );
    }

    /// `Default` yields an empty store. Uses a hand-built runtime because the
    /// test itself is synchronous.
    #[test]
    fn default_creates_new_tenant_store() {
        let store = TenantAwareInMemoryTaskStore::default();
        let rt = tokio::runtime::Builder::new_current_thread()
            .enable_all()
            .build()
            .unwrap();
        let count = rt.block_on(store.tenant_count());
        assert_eq!(count, 0, "default store should have no tenants");
    }

    /// Smoke test: eviction across several tenants completes without panicking.
    #[tokio::test]
    async fn run_eviction_all_runs_without_error() {
        let store = TenantAwareInMemoryTaskStore::new();

        TenantContext::scope("t1", async {
            store
                .save(make_task("task-a", TaskState::Completed))
                .await
                .unwrap();
        })
        .await;
        TenantContext::scope("t2", async {
            store
                .save(make_task("task-b", TaskState::Working))
                .await
                .unwrap();
        })
        .await;

        store.run_eviction_all().await;
    }

    /// Repeated operations by one tenant reuse a single partition.
    ///
    /// NOTE(review): the calls here are sequential, so this checks partition
    /// reuse rather than a genuinely concurrent re-check under the write lock.
    #[tokio::test]
    async fn get_store_double_check_path() {
        let store = TenantAwareInMemoryTaskStore::new();

        TenantContext::scope("racer", async {
            // First save creates the partition.
            store
                .save(make_task("t1", TaskState::Submitted))
                .await
                .unwrap();
            // Second save finds the existing partition.
            store
                .save(make_task("t2", TaskState::Working))
                .await
                .unwrap();

            let count = store.count().await.unwrap();
            assert_eq!(count, 2, "both tasks should be in same tenant store");
        })
        .await;

        assert_eq!(
            store.tenant_count().await,
            1,
            "should have exactly 1 tenant"
        );
    }

    /// The default configuration allows 1000 tenants.
    #[test]
    fn default_tenant_store_config() {
        let cfg = TenantStoreConfig::default();
        assert_eq!(cfg.max_tenants, 1000);
    }
}