1use std::future::Future;
28use std::pin::Pin;
29
30use a2a_protocol_types::error::{A2aError, A2aResult};
31use a2a_protocol_types::params::ListTasksParams;
32use a2a_protocol_types::responses::TaskListResponse;
33use a2a_protocol_types::task::{Task, TaskId};
34use sqlx::sqlite::{SqlitePool, SqlitePoolOptions};
35
36use super::task_store::TaskStore;
37use super::tenant::TenantContext;
38
39#[derive(Debug, Clone)]
61pub struct TenantAwareSqliteTaskStore {
62 pool: SqlitePool,
63}
64
65impl TenantAwareSqliteTaskStore {
66 pub async fn new(url: &str) -> Result<Self, sqlx::Error> {
72 let pool = sqlite_pool(url).await?;
73 Self::from_pool(pool).await
74 }
75
76 pub async fn from_pool(pool: SqlitePool) -> Result<Self, sqlx::Error> {
82 sqlx::query(
83 "CREATE TABLE IF NOT EXISTS tenant_tasks (
84 tenant_id TEXT NOT NULL DEFAULT '',
85 id TEXT NOT NULL,
86 context_id TEXT NOT NULL,
87 state TEXT NOT NULL,
88 data TEXT NOT NULL,
89 updated_at TEXT NOT NULL DEFAULT (datetime('now')),
90 created_at TEXT NOT NULL DEFAULT (datetime('now')),
91 PRIMARY KEY (tenant_id, id)
92 )",
93 )
94 .execute(&pool)
95 .await?;
96
97 sqlx::query(
98 "CREATE INDEX IF NOT EXISTS idx_tenant_tasks_ctx ON tenant_tasks(tenant_id, context_id)",
99 )
100 .execute(&pool)
101 .await?;
102
103 sqlx::query(
104 "CREATE INDEX IF NOT EXISTS idx_tenant_tasks_state ON tenant_tasks(tenant_id, state)",
105 )
106 .execute(&pool)
107 .await?;
108
109 sqlx::query(
110 "CREATE INDEX IF NOT EXISTS idx_tenant_tasks_ctx_state ON tenant_tasks(tenant_id, context_id, state)",
111 )
112 .execute(&pool)
113 .await?;
114
115 Ok(Self { pool })
116 }
117}
118
119async fn sqlite_pool(url: &str) -> Result<SqlitePool, sqlx::Error> {
121 use sqlx::sqlite::SqliteConnectOptions;
122 use std::str::FromStr;
123
124 let opts = SqliteConnectOptions::from_str(url)?
125 .pragma("journal_mode", "WAL")
126 .pragma("busy_timeout", "5000")
127 .pragma("synchronous", "NORMAL")
128 .pragma("foreign_keys", "ON")
129 .create_if_missing(true);
130
131 SqlitePoolOptions::new()
132 .max_connections(8)
133 .connect_with(opts)
134 .await
135}
136
137fn to_a2a_error(e: &sqlx::Error) -> A2aError {
138 A2aError::internal(format!("sqlite error: {e}"))
139}
140
141#[allow(clippy::manual_async_fn)]
142impl TaskStore for TenantAwareSqliteTaskStore {
143 fn save<'a>(&'a self, task: Task) -> Pin<Box<dyn Future<Output = A2aResult<()>> + Send + 'a>> {
144 Box::pin(async move {
145 let tenant = TenantContext::current();
146 let id = task.id.0.as_str();
147 let context_id = task.context_id.0.as_str();
148 let state = task.status.state.to_string();
149 let data = serde_json::to_string(&task)
150 .map_err(|e| A2aError::internal(format!("failed to serialize task: {e}")))?;
151
152 sqlx::query(
153 "INSERT INTO tenant_tasks (tenant_id, id, context_id, state, data, updated_at)
154 VALUES (?1, ?2, ?3, ?4, ?5, datetime('now'))
155 ON CONFLICT(tenant_id, id) DO UPDATE SET
156 context_id = excluded.context_id,
157 state = excluded.state,
158 data = excluded.data,
159 updated_at = datetime('now')",
160 )
161 .bind(&tenant)
162 .bind(id)
163 .bind(context_id)
164 .bind(&state)
165 .bind(&data)
166 .execute(&self.pool)
167 .await
168 .map_err(|e| to_a2a_error(&e))?;
169
170 Ok(())
171 })
172 }
173
174 fn get<'a>(
175 &'a self,
176 id: &'a TaskId,
177 ) -> Pin<Box<dyn Future<Output = A2aResult<Option<Task>>> + Send + 'a>> {
178 Box::pin(async move {
179 let tenant = TenantContext::current();
180 let row: Option<(String,)> =
181 sqlx::query_as("SELECT data FROM tenant_tasks WHERE tenant_id = ?1 AND id = ?2")
182 .bind(&tenant)
183 .bind(id.0.as_str())
184 .fetch_optional(&self.pool)
185 .await
186 .map_err(|e| to_a2a_error(&e))?;
187
188 match row {
189 Some((data,)) => {
190 let task: Task = serde_json::from_str(&data)
191 .map_err(|e| A2aError::internal(format!("deserialize: {e}")))?;
192 Ok(Some(task))
193 }
194 None => Ok(None),
195 }
196 })
197 }
198
199 fn list<'a>(
200 &'a self,
201 params: &'a ListTasksParams,
202 ) -> Pin<Box<dyn Future<Output = A2aResult<TaskListResponse>> + Send + 'a>> {
203 Box::pin(async move {
204 let tenant = TenantContext::current();
205 let mut conditions = vec!["tenant_id = ?1".to_string()];
206 let mut bind_values: Vec<String> = vec![tenant];
207
208 if let Some(ref ctx) = params.context_id {
209 conditions.push(format!("context_id = ?{}", bind_values.len() + 1));
210 bind_values.push(ctx.clone());
211 }
212 if let Some(ref status) = params.status {
213 conditions.push(format!("state = ?{}", bind_values.len() + 1));
214 bind_values.push(status.to_string());
215 }
216 if let Some(ref token) = params.page_token {
217 conditions.push(format!("id > ?{}", bind_values.len() + 1));
218 bind_values.push(token.clone());
219 }
220
221 let where_clause = format!("WHERE {}", conditions.join(" AND "));
222
223 let page_size = match params.page_size {
224 Some(0) | None => 50_u32,
225 Some(n) => n.min(1000),
226 };
227
228 let limit = page_size + 1;
229 let sql = format!(
230 "SELECT data FROM tenant_tasks {where_clause} ORDER BY id ASC LIMIT {limit}"
231 );
232
233 let mut query = sqlx::query_as::<_, (String,)>(&sql);
234 for val in &bind_values {
235 query = query.bind(val);
236 }
237
238 let rows: Vec<(String,)> = query
239 .fetch_all(&self.pool)
240 .await
241 .map_err(|e| to_a2a_error(&e))?;
242
243 let mut tasks: Vec<Task> = rows
244 .into_iter()
245 .map(|(data,)| {
246 serde_json::from_str(&data)
247 .map_err(|e| A2aError::internal(format!("deserialize: {e}")))
248 })
249 .collect::<A2aResult<Vec<_>>>()?;
250
251 let next_page_token = if tasks.len() > page_size as usize {
252 tasks.truncate(page_size as usize);
253 tasks.last().map(|t| t.id.0.clone())
254 } else {
255 None
256 };
257
258 let mut response = TaskListResponse::new(tasks);
259 response.next_page_token = next_page_token;
260 Ok(response)
261 })
262 }
263
264 fn insert_if_absent<'a>(
265 &'a self,
266 task: Task,
267 ) -> Pin<Box<dyn Future<Output = A2aResult<bool>> + Send + 'a>> {
268 Box::pin(async move {
269 let tenant = TenantContext::current();
270 let id = task.id.0.as_str();
271 let context_id = task.context_id.0.as_str();
272 let state = task.status.state.to_string();
273 let data = serde_json::to_string(&task)
274 .map_err(|e| A2aError::internal(format!("serialize: {e}")))?;
275
276 let result = sqlx::query(
277 "INSERT OR IGNORE INTO tenant_tasks (tenant_id, id, context_id, state, data, updated_at)
278 VALUES (?1, ?2, ?3, ?4, ?5, datetime('now'))",
279 )
280 .bind(&tenant)
281 .bind(id)
282 .bind(context_id)
283 .bind(&state)
284 .bind(&data)
285 .execute(&self.pool)
286 .await
287 .map_err(|e| to_a2a_error(&e))?;
288
289 Ok(result.rows_affected() > 0)
290 })
291 }
292
293 fn delete<'a>(
294 &'a self,
295 id: &'a TaskId,
296 ) -> Pin<Box<dyn Future<Output = A2aResult<()>> + Send + 'a>> {
297 Box::pin(async move {
298 let tenant = TenantContext::current();
299 sqlx::query("DELETE FROM tenant_tasks WHERE tenant_id = ?1 AND id = ?2")
300 .bind(&tenant)
301 .bind(id.0.as_str())
302 .execute(&self.pool)
303 .await
304 .map_err(|e| to_a2a_error(&e))?;
305 Ok(())
306 })
307 }
308
309 fn count<'a>(&'a self) -> Pin<Box<dyn Future<Output = A2aResult<u64>> + Send + 'a>> {
310 Box::pin(async move {
311 let tenant = TenantContext::current();
312 let row: (i64,) =
313 sqlx::query_as("SELECT COUNT(*) FROM tenant_tasks WHERE tenant_id = ?1")
314 .bind(&tenant)
315 .fetch_one(&self.pool)
316 .await
317 .map_err(|e| to_a2a_error(&e))?;
318 #[allow(clippy::cast_sign_loss)]
319 Ok(row.0 as u64)
320 })
321 }
322}
323
#[cfg(test)]
mod tests {
    use super::*;
    use a2a_protocol_types::task::{ContextId, Task, TaskId, TaskState, TaskStatus};

    /// Creates a store backed by an in-memory SQLite database.
    async fn make_store() -> TenantAwareSqliteTaskStore {
        TenantAwareSqliteTaskStore::new("sqlite::memory:")
            .await
            .expect("failed to create in-memory tenant store")
    }

    /// Builds a minimal task with no history, artifacts or metadata.
    fn make_task(id: &str, ctx: &str, state: TaskState) -> Task {
        Task {
            id: TaskId::new(id),
            context_id: ContextId::new(ctx),
            status: TaskStatus::new(state),
            history: None,
            artifacts: None,
            metadata: None,
        }
    }

    #[tokio::test]
    async fn save_and_get_within_tenant() {
        let store = make_store().await;
        TenantContext::scope("acme", async {
            store
                .save(make_task("t1", "ctx1", TaskState::Submitted))
                .await
                .unwrap();
            let task = store.get(&TaskId::new("t1")).await.unwrap();
            assert!(
                task.is_some(),
                "task should be retrievable within its tenant"
            );
            assert_eq!(task.unwrap().id, TaskId::new("t1"));
        })
        .await;
    }

    #[tokio::test]
    async fn tenant_isolation_get() {
        let store = make_store().await;
        TenantContext::scope("tenant-a", async {
            store
                .save(make_task("t1", "ctx1", TaskState::Submitted))
                .await
                .unwrap();
        })
        .await;

        TenantContext::scope("tenant-b", async {
            let result = store.get(&TaskId::new("t1")).await.unwrap();
            assert!(result.is_none(), "tenant-b should not see tenant-a's task");
        })
        .await;
    }

    #[tokio::test]
    async fn tenant_isolation_list() {
        let store = make_store().await;
        TenantContext::scope("tenant-a", async {
            store
                .save(make_task("t1", "ctx1", TaskState::Submitted))
                .await
                .unwrap();
            store
                .save(make_task("t2", "ctx1", TaskState::Working))
                .await
                .unwrap();
        })
        .await;

        TenantContext::scope("tenant-b", async {
            store
                .save(make_task("t3", "ctx1", TaskState::Submitted))
                .await
                .unwrap();
        })
        .await;

        TenantContext::scope("tenant-a", async {
            let response = store.list(&ListTasksParams::default()).await.unwrap();
            assert_eq!(
                response.tasks.len(),
                2,
                "tenant-a should see only its 2 tasks"
            );
        })
        .await;

        TenantContext::scope("tenant-b", async {
            let response = store.list(&ListTasksParams::default()).await.unwrap();
            assert_eq!(
                response.tasks.len(),
                1,
                "tenant-b should see only its 1 task"
            );
        })
        .await;
    }

    #[tokio::test]
    async fn list_filters_by_context_within_tenant() {
        let store = make_store().await;
        TenantContext::scope("tenant-a", async {
            store
                .save(make_task("t1", "ctx1", TaskState::Submitted))
                .await
                .unwrap();
            store
                .save(make_task("t2", "ctx2", TaskState::Submitted))
                .await
                .unwrap();
        })
        .await;

        // Same context id under another tenant must not leak into tenant-a's listing.
        TenantContext::scope("tenant-b", async {
            store
                .save(make_task("t3", "ctx1", TaskState::Submitted))
                .await
                .unwrap();
        })
        .await;

        TenantContext::scope("tenant-a", async {
            let params = ListTasksParams {
                context_id: Some("ctx1".to_owned()),
                ..Default::default()
            };
            let response = store.list(&params).await.unwrap();
            assert_eq!(
                response.tasks.len(),
                1,
                "only tenant-a's ctx1 task should match"
            );
            assert_eq!(response.tasks[0].id, TaskId::new("t1"));
        })
        .await;
    }

    #[tokio::test]
    async fn tenant_isolation_count() {
        let store = make_store().await;
        TenantContext::scope("tenant-a", async {
            store
                .save(make_task("t1", "ctx1", TaskState::Submitted))
                .await
                .unwrap();
            store
                .save(make_task("t2", "ctx1", TaskState::Working))
                .await
                .unwrap();
        })
        .await;

        TenantContext::scope("tenant-b", async {
            let count = store.count().await.unwrap();
            assert_eq!(count, 0, "tenant-b should have zero tasks");
        })
        .await;

        TenantContext::scope("tenant-a", async {
            let count = store.count().await.unwrap();
            assert_eq!(count, 2, "tenant-a should have 2 tasks");
        })
        .await;
    }

    #[tokio::test]
    async fn tenant_isolation_delete() {
        let store = make_store().await;
        TenantContext::scope("tenant-a", async {
            store
                .save(make_task("t1", "ctx1", TaskState::Submitted))
                .await
                .unwrap();
        })
        .await;

        TenantContext::scope("tenant-b", async {
            // tenant-b has no t1 of its own; this delete must not touch tenant-a's row.
            store.delete(&TaskId::new("t1")).await.unwrap();
        })
        .await;

        TenantContext::scope("tenant-a", async {
            let task = store.get(&TaskId::new("t1")).await.unwrap();
            assert!(
                task.is_some(),
                "tenant-a's task should still exist after tenant-b's delete"
            );
        })
        .await;
    }

    #[tokio::test]
    async fn same_task_id_different_tenants() {
        let store = make_store().await;
        TenantContext::scope("tenant-a", async {
            store
                .save(make_task("t1", "ctx-a", TaskState::Submitted))
                .await
                .unwrap();
        })
        .await;

        TenantContext::scope("tenant-b", async {
            store
                .save(make_task("t1", "ctx-b", TaskState::Working))
                .await
                .unwrap();
        })
        .await;

        TenantContext::scope("tenant-a", async {
            let task = store.get(&TaskId::new("t1")).await.unwrap().unwrap();
            assert_eq!(
                task.context_id,
                ContextId::new("ctx-a"),
                "tenant-a should get its own version of t1"
            );
            assert_eq!(task.status.state, TaskState::Submitted);
        })
        .await;

        TenantContext::scope("tenant-b", async {
            let task = store.get(&TaskId::new("t1")).await.unwrap().unwrap();
            assert_eq!(
                task.context_id,
                ContextId::new("ctx-b"),
                "tenant-b should get its own version of t1"
            );
            assert_eq!(task.status.state, TaskState::Working);
        })
        .await;
    }

    #[tokio::test]
    async fn insert_if_absent_respects_tenant_scope() {
        let store = make_store().await;
        TenantContext::scope("tenant-a", async {
            let inserted = store
                .insert_if_absent(make_task("t1", "ctx1", TaskState::Submitted))
                .await
                .unwrap();
            assert!(inserted, "first insert should succeed");

            let inserted = store
                .insert_if_absent(make_task("t1", "ctx1", TaskState::Working))
                .await
                .unwrap();
            assert!(!inserted, "duplicate insert in same tenant should fail");
        })
        .await;

        TenantContext::scope("tenant-b", async {
            // The (tenant_id, id) key means tenant-a's t1 does not block tenant-b's.
            let inserted = store
                .insert_if_absent(make_task("t1", "ctx1", TaskState::Working))
                .await
                .unwrap();
            assert!(
                inserted,
                "insert of same task id in different tenant should succeed"
            );
        })
        .await;
    }

    #[tokio::test]
    async fn list_pagination_within_tenant() {
        let store = make_store().await;
        TenantContext::scope("tenant-a", async {
            for i in 0..5 {
                store
                    .save(make_task(
                        &format!("task-{i:03}"),
                        "ctx1",
                        TaskState::Submitted,
                    ))
                    .await
                    .unwrap();
            }

            let params = ListTasksParams {
                page_size: Some(2),
                ..Default::default()
            };
            let response = store.list(&params).await.unwrap();
            assert_eq!(response.tasks.len(), 2, "first page should have 2 tasks");
            assert!(
                response.next_page_token.is_some(),
                "should have a next page token"
            );

            let params2 = ListTasksParams {
                page_size: Some(2),
                page_token: response.next_page_token,
                ..Default::default()
            };
            let response2 = store.list(&params2).await.unwrap();
            assert_eq!(response2.tasks.len(), 2, "second page should have 2 tasks");
            assert!(
                response2.next_page_token.is_some(),
                "second page should have a next page token"
            );

            let params3 = ListTasksParams {
                page_size: Some(2),
                page_token: response2.next_page_token,
                ..Default::default()
            };
            let response3 = store.list(&params3).await.unwrap();
            assert_eq!(
                response3.tasks.len(),
                1,
                "last page should hold the remaining task"
            );
            assert!(
                response3.next_page_token.is_none(),
                "last page should not have a next page token"
            );
        })
        .await;
    }

    #[test]
    fn to_a2a_error_formats_message() {
        let sqlite_err = sqlx::Error::RowNotFound;
        let a2a_err = to_a2a_error(&sqlite_err);
        let msg = format!("{a2a_err}");
        assert!(
            msg.contains("sqlite error"),
            "error message should contain 'sqlite error': {msg}"
        );
    }

    #[tokio::test]
    async fn default_tenant_context_uses_empty_string() {
        let store = make_store().await;
        // No `TenantContext::scope` here: the store uses whatever tenant
        // `TenantContext::current()` reports by default (expected: the empty tenant).
        store
            .save(make_task("t1", "ctx1", TaskState::Submitted))
            .await
            .unwrap();
        let task = store.get(&TaskId::new("t1")).await.unwrap();
        assert!(task.is_some(), "default (empty) tenant should work");
    }
}