1use std::future::Future;
28use std::pin::Pin;
29
30use a2a_protocol_types::error::{A2aError, A2aResult};
31use a2a_protocol_types::params::ListTasksParams;
32use a2a_protocol_types::responses::TaskListResponse;
33use a2a_protocol_types::task::{Task, TaskId};
34use sqlx::sqlite::{SqlitePool, SqlitePoolOptions};
35
36use super::task_store::TaskStore;
37use super::tenant::TenantContext;
38
39#[derive(Debug, Clone)]
61pub struct TenantAwareSqliteTaskStore {
62 pool: SqlitePool,
63}
64
65impl TenantAwareSqliteTaskStore {
66 pub async fn new(url: &str) -> Result<Self, sqlx::Error> {
72 let pool = sqlite_pool(url).await?;
73 Self::from_pool(pool).await
74 }
75
76 pub async fn from_pool(pool: SqlitePool) -> Result<Self, sqlx::Error> {
82 sqlx::query(
83 "CREATE TABLE IF NOT EXISTS tenant_tasks (
84 tenant_id TEXT NOT NULL DEFAULT '',
85 id TEXT NOT NULL,
86 context_id TEXT NOT NULL,
87 state TEXT NOT NULL,
88 data TEXT NOT NULL,
89 updated_at TEXT NOT NULL DEFAULT (datetime('now')),
90 created_at TEXT NOT NULL DEFAULT (datetime('now')),
91 PRIMARY KEY (tenant_id, id)
92 )",
93 )
94 .execute(&pool)
95 .await?;
96
97 sqlx::query(
98 "CREATE INDEX IF NOT EXISTS idx_tenant_tasks_ctx ON tenant_tasks(tenant_id, context_id)",
99 )
100 .execute(&pool)
101 .await?;
102
103 sqlx::query(
104 "CREATE INDEX IF NOT EXISTS idx_tenant_tasks_state ON tenant_tasks(tenant_id, state)",
105 )
106 .execute(&pool)
107 .await?;
108
109 sqlx::query(
110 "CREATE INDEX IF NOT EXISTS idx_tenant_tasks_ctx_state ON tenant_tasks(tenant_id, context_id, state)",
111 )
112 .execute(&pool)
113 .await?;
114
115 Ok(Self { pool })
116 }
117}
118
119async fn sqlite_pool(url: &str) -> Result<SqlitePool, sqlx::Error> {
121 use sqlx::sqlite::SqliteConnectOptions;
122 use std::str::FromStr;
123
124 let opts = SqliteConnectOptions::from_str(url)?
125 .pragma("journal_mode", "WAL")
126 .pragma("busy_timeout", "5000")
127 .pragma("synchronous", "NORMAL")
128 .pragma("foreign_keys", "ON")
129 .create_if_missing(true);
130
131 SqlitePoolOptions::new()
132 .max_connections(8)
133 .connect_with(opts)
134 .await
135}
136
137fn to_a2a_error(e: &sqlx::Error) -> A2aError {
138 A2aError::internal(format!("sqlite error: {e}"))
139}
140
141#[allow(clippy::manual_async_fn)]
142impl TaskStore for TenantAwareSqliteTaskStore {
143 fn save<'a>(
144 &'a self,
145 task: &'a Task,
146 ) -> Pin<Box<dyn Future<Output = A2aResult<()>> + Send + 'a>> {
147 Box::pin(async move {
148 let tenant = TenantContext::current();
149 let id = task.id.0.as_str();
150 let context_id = task.context_id.0.as_str();
151 let state = task.status.state.to_string();
152 let data = serde_json::to_string(task)
153 .map_err(|e| A2aError::internal(format!("failed to serialize task: {e}")))?;
154
155 sqlx::query(
156 "INSERT INTO tenant_tasks (tenant_id, id, context_id, state, data, updated_at)
157 VALUES (?1, ?2, ?3, ?4, ?5, datetime('now'))
158 ON CONFLICT(tenant_id, id) DO UPDATE SET
159 context_id = excluded.context_id,
160 state = excluded.state,
161 data = excluded.data,
162 updated_at = datetime('now')",
163 )
164 .bind(&tenant)
165 .bind(id)
166 .bind(context_id)
167 .bind(&state)
168 .bind(&data)
169 .execute(&self.pool)
170 .await
171 .map_err(|e| to_a2a_error(&e))?;
172
173 Ok(())
174 })
175 }
176
177 fn get<'a>(
178 &'a self,
179 id: &'a TaskId,
180 ) -> Pin<Box<dyn Future<Output = A2aResult<Option<Task>>> + Send + 'a>> {
181 Box::pin(async move {
182 let tenant = TenantContext::current();
183 let row: Option<(String,)> =
184 sqlx::query_as("SELECT data FROM tenant_tasks WHERE tenant_id = ?1 AND id = ?2")
185 .bind(&tenant)
186 .bind(id.0.as_str())
187 .fetch_optional(&self.pool)
188 .await
189 .map_err(|e| to_a2a_error(&e))?;
190
191 match row {
192 Some((data,)) => {
193 let task: Task = serde_json::from_str(&data)
194 .map_err(|e| A2aError::internal(format!("deserialize: {e}")))?;
195 Ok(Some(task))
196 }
197 None => Ok(None),
198 }
199 })
200 }
201
202 fn list<'a>(
203 &'a self,
204 params: &'a ListTasksParams,
205 ) -> Pin<Box<dyn Future<Output = A2aResult<TaskListResponse>> + Send + 'a>> {
206 Box::pin(async move {
207 let tenant = TenantContext::current();
208 let mut conditions = vec!["tenant_id = ?1".to_string()];
209 let mut bind_values: Vec<String> = vec![tenant];
210
211 if let Some(ref ctx) = params.context_id {
212 conditions.push(format!("context_id = ?{}", bind_values.len() + 1));
213 bind_values.push(ctx.clone());
214 }
215 if let Some(ref status) = params.status {
216 conditions.push(format!("state = ?{}", bind_values.len() + 1));
217 bind_values.push(status.to_string());
218 }
219 if let Some(ref token) = params.page_token {
220 conditions.push(format!("id > ?{}", bind_values.len() + 1));
221 bind_values.push(token.clone());
222 }
223
224 let where_clause = format!("WHERE {}", conditions.join(" AND "));
225
226 let page_size = match params.page_size {
227 Some(0) | None => 50_u32,
228 Some(n) => n.min(1000),
229 };
230
231 let limit = page_size + 1;
232 let sql = format!(
233 "SELECT data FROM tenant_tasks {where_clause} ORDER BY id ASC LIMIT {limit}"
234 );
235
236 let mut query = sqlx::query_as::<_, (String,)>(&sql);
237 for val in &bind_values {
238 query = query.bind(val);
239 }
240
241 let rows: Vec<(String,)> = query
242 .fetch_all(&self.pool)
243 .await
244 .map_err(|e| to_a2a_error(&e))?;
245
246 let mut tasks: Vec<Task> = rows
247 .into_iter()
248 .map(|(data,)| {
249 serde_json::from_str(&data)
250 .map_err(|e| A2aError::internal(format!("deserialize: {e}")))
251 })
252 .collect::<A2aResult<Vec<_>>>()?;
253
254 let next_page_token = if tasks.len() > page_size as usize {
255 tasks.truncate(page_size as usize);
256 tasks.last().map(|t| t.id.0.clone()).unwrap_or_default()
257 } else {
258 String::new()
259 };
260
261 #[allow(clippy::cast_possible_truncation)]
262 let page_len = tasks.len() as u32;
263 let mut response = TaskListResponse::new(tasks);
264 response.next_page_token = next_page_token;
265 response.page_size = page_len;
266 Ok(response)
267 })
268 }
269
270 fn insert_if_absent<'a>(
271 &'a self,
272 task: &'a Task,
273 ) -> Pin<Box<dyn Future<Output = A2aResult<bool>> + Send + 'a>> {
274 Box::pin(async move {
275 let tenant = TenantContext::current();
276 let id = task.id.0.as_str();
277 let context_id = task.context_id.0.as_str();
278 let state = task.status.state.to_string();
279 let data = serde_json::to_string(task)
280 .map_err(|e| A2aError::internal(format!("serialize: {e}")))?;
281
282 let result = sqlx::query(
283 "INSERT OR IGNORE INTO tenant_tasks (tenant_id, id, context_id, state, data, updated_at)
284 VALUES (?1, ?2, ?3, ?4, ?5, datetime('now'))",
285 )
286 .bind(&tenant)
287 .bind(id)
288 .bind(context_id)
289 .bind(&state)
290 .bind(&data)
291 .execute(&self.pool)
292 .await
293 .map_err(|e| to_a2a_error(&e))?;
294
295 Ok(result.rows_affected() > 0)
296 })
297 }
298
299 fn delete<'a>(
300 &'a self,
301 id: &'a TaskId,
302 ) -> Pin<Box<dyn Future<Output = A2aResult<()>> + Send + 'a>> {
303 Box::pin(async move {
304 let tenant = TenantContext::current();
305 sqlx::query("DELETE FROM tenant_tasks WHERE tenant_id = ?1 AND id = ?2")
306 .bind(&tenant)
307 .bind(id.0.as_str())
308 .execute(&self.pool)
309 .await
310 .map_err(|e| to_a2a_error(&e))?;
311 Ok(())
312 })
313 }
314
315 fn count<'a>(&'a self) -> Pin<Box<dyn Future<Output = A2aResult<u64>> + Send + 'a>> {
316 Box::pin(async move {
317 let tenant = TenantContext::current();
318 let row: (i64,) =
319 sqlx::query_as("SELECT COUNT(*) FROM tenant_tasks WHERE tenant_id = ?1")
320 .bind(&tenant)
321 .fetch_one(&self.pool)
322 .await
323 .map_err(|e| to_a2a_error(&e))?;
324 #[allow(clippy::cast_sign_loss)]
325 Ok(row.0 as u64)
326 })
327 }
328}
329
#[cfg(test)]
mod tests {
    use super::*;
    use a2a_protocol_types::task::{ContextId, Task, TaskId, TaskState, TaskStatus};

    /// Creates a fresh store backed by a private in-memory database.
    async fn make_store() -> TenantAwareSqliteTaskStore {
        TenantAwareSqliteTaskStore::new("sqlite::memory:")
            .await
            .expect("failed to create in-memory tenant store")
    }

    /// Builds a minimal task with the given id, context and state.
    fn make_task(id: &str, ctx: &str, state: TaskState) -> Task {
        Task {
            id: TaskId::new(id),
            context_id: ContextId::new(ctx),
            status: TaskStatus::new(state),
            history: None,
            artifacts: None,
            metadata: None,
        }
    }

    #[tokio::test]
    async fn save_and_get_within_tenant() {
        let store = make_store().await;
        TenantContext::scope("acme", async {
            store
                .save(&make_task("t1", "ctx1", TaskState::Submitted))
                .await
                .unwrap();
            let task = store.get(&TaskId::new("t1")).await.unwrap();
            assert!(
                task.is_some(),
                "task should be retrievable within its tenant"
            );
            assert_eq!(task.unwrap().id, TaskId::new("t1"));
        })
        .await;
    }

    #[tokio::test]
    async fn tenant_isolation_get() {
        let store = make_store().await;
        TenantContext::scope("tenant-a", async {
            store
                .save(&make_task("t1", "ctx1", TaskState::Submitted))
                .await
                .unwrap();
        })
        .await;

        TenantContext::scope("tenant-b", async {
            let result = store.get(&TaskId::new("t1")).await.unwrap();
            assert!(result.is_none(), "tenant-b should not see tenant-a's task");
        })
        .await;
    }

    #[tokio::test]
    async fn tenant_isolation_list() {
        let store = make_store().await;
        TenantContext::scope("tenant-a", async {
            store
                .save(&make_task("t1", "ctx1", TaskState::Submitted))
                .await
                .unwrap();
            store
                .save(&make_task("t2", "ctx1", TaskState::Working))
                .await
                .unwrap();
        })
        .await;

        TenantContext::scope("tenant-b", async {
            store
                .save(&make_task("t3", "ctx1", TaskState::Submitted))
                .await
                .unwrap();
        })
        .await;

        TenantContext::scope("tenant-a", async {
            let response = store.list(&ListTasksParams::default()).await.unwrap();
            assert_eq!(
                response.tasks.len(),
                2,
                "tenant-a should see only its 2 tasks"
            );
        })
        .await;

        TenantContext::scope("tenant-b", async {
            let response = store.list(&ListTasksParams::default()).await.unwrap();
            assert_eq!(
                response.tasks.len(),
                1,
                "tenant-b should see only its 1 task"
            );
        })
        .await;
    }

    #[tokio::test]
    async fn list_filters_by_context_id() {
        let store = make_store().await;
        TenantContext::scope("tenant-a", async {
            store
                .save(&make_task("t1", "ctx1", TaskState::Submitted))
                .await
                .unwrap();
            store
                .save(&make_task("t2", "ctx1", TaskState::Working))
                .await
                .unwrap();
            store
                .save(&make_task("t3", "ctx2", TaskState::Working))
                .await
                .unwrap();

            let params = ListTasksParams {
                context_id: Some("ctx1".to_string()),
                ..Default::default()
            };
            let response = store.list(&params).await.unwrap();
            assert_eq!(
                response.tasks.len(),
                2,
                "only tasks in ctx1 should be listed"
            );
            assert!(
                response
                    .tasks
                    .iter()
                    .all(|t| t.context_id == ContextId::new("ctx1")),
                "every listed task should belong to ctx1"
            );
        })
        .await;
    }

    #[tokio::test]
    async fn tenant_isolation_count() {
        let store = make_store().await;
        TenantContext::scope("tenant-a", async {
            store
                .save(&make_task("t1", "ctx1", TaskState::Submitted))
                .await
                .unwrap();
            store
                .save(&make_task("t2", "ctx1", TaskState::Working))
                .await
                .unwrap();
        })
        .await;

        TenantContext::scope("tenant-b", async {
            let count = store.count().await.unwrap();
            assert_eq!(count, 0, "tenant-b should have zero tasks");
        })
        .await;

        TenantContext::scope("tenant-a", async {
            let count = store.count().await.unwrap();
            assert_eq!(count, 2, "tenant-a should have 2 tasks");
        })
        .await;
    }

    #[tokio::test]
    async fn tenant_isolation_delete() {
        let store = make_store().await;
        TenantContext::scope("tenant-a", async {
            store
                .save(&make_task("t1", "ctx1", TaskState::Submitted))
                .await
                .unwrap();
        })
        .await;

        TenantContext::scope("tenant-b", async {
            // Deleting a task id that only exists under another tenant must be a no-op.
            store.delete(&TaskId::new("t1")).await.unwrap();
        })
        .await;

        TenantContext::scope("tenant-a", async {
            let task = store.get(&TaskId::new("t1")).await.unwrap();
            assert!(
                task.is_some(),
                "tenant-a's task should still exist after tenant-b's delete"
            );
        })
        .await;
    }

    #[tokio::test]
    async fn same_task_id_different_tenants() {
        let store = make_store().await;
        TenantContext::scope("tenant-a", async {
            store
                .save(&make_task("t1", "ctx-a", TaskState::Submitted))
                .await
                .unwrap();
        })
        .await;

        TenantContext::scope("tenant-b", async {
            store
                .save(&make_task("t1", "ctx-b", TaskState::Working))
                .await
                .unwrap();
        })
        .await;

        TenantContext::scope("tenant-a", async {
            let task = store.get(&TaskId::new("t1")).await.unwrap().unwrap();
            assert_eq!(
                task.context_id,
                ContextId::new("ctx-a"),
                "tenant-a should get its own version of t1"
            );
            assert_eq!(task.status.state, TaskState::Submitted);
        })
        .await;

        TenantContext::scope("tenant-b", async {
            let task = store.get(&TaskId::new("t1")).await.unwrap().unwrap();
            assert_eq!(
                task.context_id,
                ContextId::new("ctx-b"),
                "tenant-b should get its own version of t1"
            );
            assert_eq!(task.status.state, TaskState::Working);
        })
        .await;
    }

    #[tokio::test]
    async fn insert_if_absent_respects_tenant_scope() {
        let store = make_store().await;
        TenantContext::scope("tenant-a", async {
            let inserted = store
                .insert_if_absent(&make_task("t1", "ctx1", TaskState::Submitted))
                .await
                .unwrap();
            assert!(inserted, "first insert should succeed");

            let inserted = store
                .insert_if_absent(&make_task("t1", "ctx1", TaskState::Working))
                .await
                .unwrap();
            assert!(!inserted, "duplicate insert in same tenant should fail");

            let stored = store.get(&TaskId::new("t1")).await.unwrap().unwrap();
            assert_eq!(
                stored.status.state,
                TaskState::Submitted,
                "a rejected insert must not overwrite the existing task"
            );
        })
        .await;

        TenantContext::scope("tenant-b", async {
            // The same task id is free in a different tenant.
            let inserted = store
                .insert_if_absent(&make_task("t1", "ctx1", TaskState::Working))
                .await
                .unwrap();
            assert!(
                inserted,
                "insert of same task id in different tenant should succeed"
            );
        })
        .await;
    }

    #[tokio::test]
    async fn list_pagination_within_tenant() {
        let store = make_store().await;
        TenantContext::scope("tenant-a", async {
            for i in 0..5 {
                store
                    .save(&make_task(
                        &format!("task-{i:03}"),
                        "ctx1",
                        TaskState::Submitted,
                    ))
                    .await
                    .unwrap();
            }

            let params = ListTasksParams {
                page_size: Some(2),
                ..Default::default()
            };
            let response = store.list(&params).await.unwrap();
            assert_eq!(response.tasks.len(), 2, "first page should have 2 tasks");
            assert!(
                !response.next_page_token.is_empty(),
                "should have a next page token"
            );

            let params2 = ListTasksParams {
                page_size: Some(2),
                page_token: Some(response.next_page_token),
                ..Default::default()
            };
            let response2 = store.list(&params2).await.unwrap();
            assert_eq!(response2.tasks.len(), 2, "second page should have 2 tasks");
            assert!(
                !response2.next_page_token.is_empty(),
                "second page should still have a next page token"
            );

            let params3 = ListTasksParams {
                page_size: Some(2),
                page_token: Some(response2.next_page_token),
                ..Default::default()
            };
            let response3 = store.list(&params3).await.unwrap();
            assert_eq!(response3.tasks.len(), 1, "last page should have 1 task");
            assert!(
                response3.next_page_token.is_empty(),
                "last page should not have a next page token"
            );
        })
        .await;
    }

    #[test]
    fn to_a2a_error_formats_message() {
        let sqlite_err = sqlx::Error::RowNotFound;
        let a2a_err = to_a2a_error(&sqlite_err);
        let msg = format!("{a2a_err}");
        assert!(
            msg.contains("sqlite error"),
            "error message should contain 'sqlite error': {msg}"
        );
    }

    #[tokio::test]
    async fn default_tenant_context_uses_empty_string() {
        let store = make_store().await;
        // No `TenantContext::scope` here: operations run under the default tenant.
        store
            .save(&make_task("t1", "ctx1", TaskState::Submitted))
            .await
            .unwrap();
        let task = store.get(&TaskId::new("t1")).await.unwrap();
        assert!(task.is_some(), "default (empty) tenant should work");
    }
}