1use std::future::Future;
28use std::pin::Pin;
29
30use a2a_protocol_types::error::{A2aError, A2aResult};
31use a2a_protocol_types::params::ListTasksParams;
32use a2a_protocol_types::responses::TaskListResponse;
33use a2a_protocol_types::task::{Task, TaskId};
34use sqlx::sqlite::{SqlitePool, SqlitePoolOptions};
35
36use super::task_store::TaskStore;
37use super::tenant::TenantContext;
38
39#[derive(Debug, Clone)]
61pub struct TenantAwareSqliteTaskStore {
62 pool: SqlitePool,
63}
64
65impl TenantAwareSqliteTaskStore {
66 pub async fn new(url: &str) -> Result<Self, sqlx::Error> {
72 let pool = sqlite_pool(url).await?;
73 Self::from_pool(pool).await
74 }
75
76 pub async fn from_pool(pool: SqlitePool) -> Result<Self, sqlx::Error> {
82 sqlx::query(
83 "CREATE TABLE IF NOT EXISTS tenant_tasks (
84 tenant_id TEXT NOT NULL DEFAULT '',
85 id TEXT NOT NULL,
86 context_id TEXT NOT NULL,
87 state TEXT NOT NULL,
88 data TEXT NOT NULL,
89 updated_at TEXT NOT NULL DEFAULT (datetime('now')),
90 created_at TEXT NOT NULL DEFAULT (datetime('now')),
91 PRIMARY KEY (tenant_id, id)
92 )",
93 )
94 .execute(&pool)
95 .await?;
96
97 sqlx::query(
98 "CREATE INDEX IF NOT EXISTS idx_tenant_tasks_ctx ON tenant_tasks(tenant_id, context_id)",
99 )
100 .execute(&pool)
101 .await?;
102
103 sqlx::query(
104 "CREATE INDEX IF NOT EXISTS idx_tenant_tasks_state ON tenant_tasks(tenant_id, state)",
105 )
106 .execute(&pool)
107 .await?;
108
109 sqlx::query(
110 "CREATE INDEX IF NOT EXISTS idx_tenant_tasks_ctx_state ON tenant_tasks(tenant_id, context_id, state)",
111 )
112 .execute(&pool)
113 .await?;
114
115 Ok(Self { pool })
116 }
117}
118
119async fn sqlite_pool(url: &str) -> Result<SqlitePool, sqlx::Error> {
121 use sqlx::sqlite::SqliteConnectOptions;
122 use std::str::FromStr;
123
124 let opts = SqliteConnectOptions::from_str(url)?
125 .pragma("journal_mode", "WAL")
126 .pragma("busy_timeout", "5000")
127 .pragma("synchronous", "NORMAL")
128 .pragma("foreign_keys", "ON")
129 .create_if_missing(true);
130
131 SqlitePoolOptions::new()
132 .max_connections(8)
133 .connect_with(opts)
134 .await
135}
136
137fn to_a2a_error(e: &sqlx::Error) -> A2aError {
138 A2aError::internal(format!("sqlite error: {e}"))
139}
140
141#[allow(clippy::manual_async_fn)]
142impl TaskStore for TenantAwareSqliteTaskStore {
143 fn save<'a>(&'a self, task: Task) -> Pin<Box<dyn Future<Output = A2aResult<()>> + Send + 'a>> {
144 Box::pin(async move {
145 let tenant = TenantContext::current();
146 let id = task.id.0.as_str();
147 let context_id = task.context_id.0.as_str();
148 let state = task.status.state.to_string();
149 let data = serde_json::to_string(&task)
150 .map_err(|e| A2aError::internal(format!("failed to serialize task: {e}")))?;
151
152 sqlx::query(
153 "INSERT INTO tenant_tasks (tenant_id, id, context_id, state, data, updated_at)
154 VALUES (?1, ?2, ?3, ?4, ?5, datetime('now'))
155 ON CONFLICT(tenant_id, id) DO UPDATE SET
156 context_id = excluded.context_id,
157 state = excluded.state,
158 data = excluded.data,
159 updated_at = datetime('now')",
160 )
161 .bind(&tenant)
162 .bind(id)
163 .bind(context_id)
164 .bind(&state)
165 .bind(&data)
166 .execute(&self.pool)
167 .await
168 .map_err(|e| to_a2a_error(&e))?;
169
170 Ok(())
171 })
172 }
173
174 fn get<'a>(
175 &'a self,
176 id: &'a TaskId,
177 ) -> Pin<Box<dyn Future<Output = A2aResult<Option<Task>>> + Send + 'a>> {
178 Box::pin(async move {
179 let tenant = TenantContext::current();
180 let row: Option<(String,)> =
181 sqlx::query_as("SELECT data FROM tenant_tasks WHERE tenant_id = ?1 AND id = ?2")
182 .bind(&tenant)
183 .bind(id.0.as_str())
184 .fetch_optional(&self.pool)
185 .await
186 .map_err(|e| to_a2a_error(&e))?;
187
188 match row {
189 Some((data,)) => {
190 let task: Task = serde_json::from_str(&data)
191 .map_err(|e| A2aError::internal(format!("deserialize: {e}")))?;
192 Ok(Some(task))
193 }
194 None => Ok(None),
195 }
196 })
197 }
198
199 fn list<'a>(
200 &'a self,
201 params: &'a ListTasksParams,
202 ) -> Pin<Box<dyn Future<Output = A2aResult<TaskListResponse>> + Send + 'a>> {
203 Box::pin(async move {
204 let tenant = TenantContext::current();
205 let mut conditions = vec!["tenant_id = ?1".to_string()];
206 let mut bind_values: Vec<String> = vec![tenant];
207
208 if let Some(ref ctx) = params.context_id {
209 conditions.push(format!("context_id = ?{}", bind_values.len() + 1));
210 bind_values.push(ctx.clone());
211 }
212 if let Some(ref status) = params.status {
213 conditions.push(format!("state = ?{}", bind_values.len() + 1));
214 bind_values.push(status.to_string());
215 }
216 if let Some(ref token) = params.page_token {
217 conditions.push(format!("id > ?{}", bind_values.len() + 1));
218 bind_values.push(token.clone());
219 }
220
221 let where_clause = format!("WHERE {}", conditions.join(" AND "));
222
223 let page_size = match params.page_size {
224 Some(0) | None => 50_u32,
225 Some(n) => n.min(1000),
226 };
227
228 let limit = page_size + 1;
229 let sql = format!(
230 "SELECT data FROM tenant_tasks {where_clause} ORDER BY id ASC LIMIT {limit}"
231 );
232
233 let mut query = sqlx::query_as::<_, (String,)>(&sql);
234 for val in &bind_values {
235 query = query.bind(val);
236 }
237
238 let rows: Vec<(String,)> = query
239 .fetch_all(&self.pool)
240 .await
241 .map_err(|e| to_a2a_error(&e))?;
242
243 let mut tasks: Vec<Task> = rows
244 .into_iter()
245 .map(|(data,)| {
246 serde_json::from_str(&data)
247 .map_err(|e| A2aError::internal(format!("deserialize: {e}")))
248 })
249 .collect::<A2aResult<Vec<_>>>()?;
250
251 let next_page_token = if tasks.len() > page_size as usize {
252 tasks.truncate(page_size as usize);
253 tasks.last().map(|t| t.id.0.clone()).unwrap_or_default()
254 } else {
255 String::new()
256 };
257
258 #[allow(clippy::cast_possible_truncation)]
259 let page_len = tasks.len() as u32;
260 let mut response = TaskListResponse::new(tasks);
261 response.next_page_token = next_page_token;
262 response.page_size = page_len;
263 Ok(response)
264 })
265 }
266
267 fn insert_if_absent<'a>(
268 &'a self,
269 task: Task,
270 ) -> Pin<Box<dyn Future<Output = A2aResult<bool>> + Send + 'a>> {
271 Box::pin(async move {
272 let tenant = TenantContext::current();
273 let id = task.id.0.as_str();
274 let context_id = task.context_id.0.as_str();
275 let state = task.status.state.to_string();
276 let data = serde_json::to_string(&task)
277 .map_err(|e| A2aError::internal(format!("serialize: {e}")))?;
278
279 let result = sqlx::query(
280 "INSERT OR IGNORE INTO tenant_tasks (tenant_id, id, context_id, state, data, updated_at)
281 VALUES (?1, ?2, ?3, ?4, ?5, datetime('now'))",
282 )
283 .bind(&tenant)
284 .bind(id)
285 .bind(context_id)
286 .bind(&state)
287 .bind(&data)
288 .execute(&self.pool)
289 .await
290 .map_err(|e| to_a2a_error(&e))?;
291
292 Ok(result.rows_affected() > 0)
293 })
294 }
295
296 fn delete<'a>(
297 &'a self,
298 id: &'a TaskId,
299 ) -> Pin<Box<dyn Future<Output = A2aResult<()>> + Send + 'a>> {
300 Box::pin(async move {
301 let tenant = TenantContext::current();
302 sqlx::query("DELETE FROM tenant_tasks WHERE tenant_id = ?1 AND id = ?2")
303 .bind(&tenant)
304 .bind(id.0.as_str())
305 .execute(&self.pool)
306 .await
307 .map_err(|e| to_a2a_error(&e))?;
308 Ok(())
309 })
310 }
311
312 fn count<'a>(&'a self) -> Pin<Box<dyn Future<Output = A2aResult<u64>> + Send + 'a>> {
313 Box::pin(async move {
314 let tenant = TenantContext::current();
315 let row: (i64,) =
316 sqlx::query_as("SELECT COUNT(*) FROM tenant_tasks WHERE tenant_id = ?1")
317 .bind(&tenant)
318 .fetch_one(&self.pool)
319 .await
320 .map_err(|e| to_a2a_error(&e))?;
321 #[allow(clippy::cast_sign_loss)]
322 Ok(row.0 as u64)
323 })
324 }
325}
326
#[cfg(test)]
mod tests {
    use super::*;
    use a2a_protocol_types::task::{ContextId, Task, TaskId, TaskState, TaskStatus};

    /// Builds a store backed by a fresh in-memory database.
    async fn make_store() -> TenantAwareSqliteTaskStore {
        TenantAwareSqliteTaskStore::new("sqlite::memory:")
            .await
            .expect("failed to create in-memory tenant store")
    }

    /// Builds a minimal task with no history, artifacts, or metadata.
    fn make_task(id: &str, ctx: &str, state: TaskState) -> Task {
        Task {
            id: TaskId::new(id),
            context_id: ContextId::new(ctx),
            status: TaskStatus::new(state),
            history: None,
            artifacts: None,
            metadata: None,
        }
    }

    /// Collects the task ids of a list response, in response order.
    fn task_ids(response: &TaskListResponse) -> Vec<String> {
        response.tasks.iter().map(|t| t.id.0.clone()).collect()
    }

    #[tokio::test]
    async fn save_and_get_within_tenant() {
        let store = make_store().await;
        TenantContext::scope("acme", async {
            store
                .save(make_task("t1", "ctx1", TaskState::Submitted))
                .await
                .unwrap();
            let task = store.get(&TaskId::new("t1")).await.unwrap();
            assert!(
                task.is_some(),
                "task should be retrievable within its tenant"
            );
            assert_eq!(task.unwrap().id, TaskId::new("t1"));
        })
        .await;
    }

    #[tokio::test]
    async fn tenant_isolation_get() {
        let store = make_store().await;
        TenantContext::scope("tenant-a", async {
            store
                .save(make_task("t1", "ctx1", TaskState::Submitted))
                .await
                .unwrap();
        })
        .await;

        TenantContext::scope("tenant-b", async {
            let result = store.get(&TaskId::new("t1")).await.unwrap();
            assert!(result.is_none(), "tenant-b should not see tenant-a's task");
        })
        .await;
    }

    #[tokio::test]
    async fn tenant_isolation_list() {
        let store = make_store().await;
        TenantContext::scope("tenant-a", async {
            store
                .save(make_task("t1", "ctx1", TaskState::Submitted))
                .await
                .unwrap();
            store
                .save(make_task("t2", "ctx1", TaskState::Working))
                .await
                .unwrap();
        })
        .await;

        TenantContext::scope("tenant-b", async {
            store
                .save(make_task("t3", "ctx1", TaskState::Submitted))
                .await
                .unwrap();
        })
        .await;

        TenantContext::scope("tenant-a", async {
            let response = store.list(&ListTasksParams::default()).await.unwrap();
            assert_eq!(
                response.tasks.len(),
                2,
                "tenant-a should see only its 2 tasks"
            );
        })
        .await;

        TenantContext::scope("tenant-b", async {
            let response = store.list(&ListTasksParams::default()).await.unwrap();
            assert_eq!(
                response.tasks.len(),
                1,
                "tenant-b should see only its 1 task"
            );
        })
        .await;
    }

    #[tokio::test]
    async fn tenant_isolation_count() {
        let store = make_store().await;
        TenantContext::scope("tenant-a", async {
            store
                .save(make_task("t1", "ctx1", TaskState::Submitted))
                .await
                .unwrap();
            store
                .save(make_task("t2", "ctx1", TaskState::Working))
                .await
                .unwrap();
        })
        .await;

        TenantContext::scope("tenant-b", async {
            let count = store.count().await.unwrap();
            assert_eq!(count, 0, "tenant-b should have zero tasks");
        })
        .await;

        TenantContext::scope("tenant-a", async {
            let count = store.count().await.unwrap();
            assert_eq!(count, 2, "tenant-a should have 2 tasks");
        })
        .await;
    }

    #[tokio::test]
    async fn tenant_isolation_delete() {
        let store = make_store().await;
        TenantContext::scope("tenant-a", async {
            store
                .save(make_task("t1", "ctx1", TaskState::Submitted))
                .await
                .unwrap();
        })
        .await;

        TenantContext::scope("tenant-b", async {
            // tenant-b does not own "t1": the delete must succeed as a no-op.
            store.delete(&TaskId::new("t1")).await.unwrap();
        })
        .await;

        TenantContext::scope("tenant-a", async {
            let task = store.get(&TaskId::new("t1")).await.unwrap();
            assert!(
                task.is_some(),
                "tenant-a's task should still exist after tenant-b's delete"
            );
        })
        .await;
    }

    #[tokio::test]
    async fn same_task_id_different_tenants() {
        let store = make_store().await;
        TenantContext::scope("tenant-a", async {
            store
                .save(make_task("t1", "ctx-a", TaskState::Submitted))
                .await
                .unwrap();
        })
        .await;

        TenantContext::scope("tenant-b", async {
            store
                .save(make_task("t1", "ctx-b", TaskState::Working))
                .await
                .unwrap();
        })
        .await;

        TenantContext::scope("tenant-a", async {
            let task = store.get(&TaskId::new("t1")).await.unwrap().unwrap();
            assert_eq!(
                task.context_id,
                ContextId::new("ctx-a"),
                "tenant-a should get its own version of t1"
            );
            assert_eq!(task.status.state, TaskState::Submitted);
        })
        .await;

        TenantContext::scope("tenant-b", async {
            let task = store.get(&TaskId::new("t1")).await.unwrap().unwrap();
            assert_eq!(
                task.context_id,
                ContextId::new("ctx-b"),
                "tenant-b should get its own version of t1"
            );
            assert_eq!(task.status.state, TaskState::Working);
        })
        .await;
    }

    #[tokio::test]
    async fn insert_if_absent_respects_tenant_scope() {
        let store = make_store().await;
        TenantContext::scope("tenant-a", async {
            let inserted = store
                .insert_if_absent(make_task("t1", "ctx1", TaskState::Submitted))
                .await
                .unwrap();
            assert!(inserted, "first insert should succeed");

            let inserted = store
                .insert_if_absent(make_task("t1", "ctx1", TaskState::Working))
                .await
                .unwrap();
            assert!(!inserted, "duplicate insert in same tenant should fail");
        })
        .await;

        TenantContext::scope("tenant-b", async {
            // The same id is free in a different tenant.
            let inserted = store
                .insert_if_absent(make_task("t1", "ctx1", TaskState::Working))
                .await
                .unwrap();
            assert!(
                inserted,
                "insert of same task id in different tenant should succeed"
            );
        })
        .await;
    }

    #[tokio::test]
    async fn list_pagination_within_tenant() {
        let store = make_store().await;
        TenantContext::scope("tenant-a", async {
            for i in 0..5 {
                store
                    .save(make_task(
                        &format!("task-{i:03}"),
                        "ctx1",
                        TaskState::Submitted,
                    ))
                    .await
                    .unwrap();
            }

            let params = ListTasksParams {
                page_size: Some(2),
                ..Default::default()
            };
            let response = store.list(&params).await.unwrap();
            assert_eq!(
                task_ids(&response),
                ["task-000", "task-001"],
                "first page should hold the two smallest ids"
            );
            assert_eq!(response.page_size, 2);
            assert_eq!(
                response.next_page_token, "task-001",
                "token should be the last id on the page"
            );

            let params2 = ListTasksParams {
                page_size: Some(2),
                page_token: Some(response.next_page_token.clone()),
                ..Default::default()
            };
            let response2 = store.list(&params2).await.unwrap();
            assert_eq!(
                task_ids(&response2),
                ["task-002", "task-003"],
                "second page should continue after the token without overlap"
            );
            assert_eq!(response2.next_page_token, "task-003");

            let params3 = ListTasksParams {
                page_size: Some(2),
                page_token: Some(response2.next_page_token.clone()),
                ..Default::default()
            };
            let response3 = store.list(&params3).await.unwrap();
            assert_eq!(
                task_ids(&response3),
                ["task-004"],
                "last page should hold the remaining task"
            );
            assert!(
                response3.next_page_token.is_empty(),
                "last page should not advertise another page"
            );
        })
        .await;
    }

    #[test]
    fn to_a2a_error_formats_message() {
        let sqlite_err = sqlx::Error::RowNotFound;
        let a2a_err = to_a2a_error(&sqlite_err);
        let msg = format!("{a2a_err}");
        assert!(
            msg.contains("sqlite error"),
            "error message should contain 'sqlite error': {msg}"
        );
    }

    #[tokio::test]
    async fn default_tenant_context_uses_empty_string() {
        let store = make_store().await;
        // Called outside any `TenantContext::scope`.
        store
            .save(make_task("t1", "ctx1", TaskState::Submitted))
            .await
            .unwrap();
        let task = store.get(&TaskId::new("t1")).await.unwrap();
        assert!(task.is_some(), "default (empty) tenant should work");
    }
}