1use std::future::Future;
28use std::pin::Pin;
29
30use a2a_protocol_types::error::{A2aError, A2aResult};
31use a2a_protocol_types::params::ListTasksParams;
32use a2a_protocol_types::responses::TaskListResponse;
33use a2a_protocol_types::task::{Task, TaskId};
34use sqlx::sqlite::{SqlitePool, SqlitePoolOptions};
35
36use super::task_store::TaskStore;
37use super::tenant::TenantContext;
38
39#[derive(Debug, Clone)]
61pub struct TenantAwareSqliteTaskStore {
62 pool: SqlitePool,
63}
64
65impl TenantAwareSqliteTaskStore {
66 pub async fn new(url: &str) -> Result<Self, sqlx::Error> {
72 let pool = sqlite_pool(url).await?;
73 Self::from_pool(pool).await
74 }
75
76 pub async fn from_pool(pool: SqlitePool) -> Result<Self, sqlx::Error> {
82 sqlx::query(
83 "CREATE TABLE IF NOT EXISTS tenant_tasks (
84 tenant_id TEXT NOT NULL DEFAULT '',
85 id TEXT NOT NULL,
86 context_id TEXT NOT NULL,
87 state TEXT NOT NULL,
88 data TEXT NOT NULL,
89 updated_at TEXT NOT NULL DEFAULT (datetime('now')),
90 PRIMARY KEY (tenant_id, id)
91 )",
92 )
93 .execute(&pool)
94 .await?;
95
96 sqlx::query(
97 "CREATE INDEX IF NOT EXISTS idx_tenant_tasks_ctx ON tenant_tasks(tenant_id, context_id)",
98 )
99 .execute(&pool)
100 .await?;
101
102 sqlx::query(
103 "CREATE INDEX IF NOT EXISTS idx_tenant_tasks_state ON tenant_tasks(tenant_id, state)",
104 )
105 .execute(&pool)
106 .await?;
107
108 Ok(Self { pool })
109 }
110}
111
112async fn sqlite_pool(url: &str) -> Result<SqlitePool, sqlx::Error> {
114 use sqlx::sqlite::SqliteConnectOptions;
115 use std::str::FromStr;
116
117 let opts = SqliteConnectOptions::from_str(url)?
118 .pragma("journal_mode", "WAL")
119 .pragma("busy_timeout", "5000")
120 .pragma("synchronous", "NORMAL")
121 .pragma("foreign_keys", "ON")
122 .create_if_missing(true);
123
124 SqlitePoolOptions::new()
125 .max_connections(8)
126 .connect_with(opts)
127 .await
128}
129
130fn to_a2a_error(e: &sqlx::Error) -> A2aError {
131 A2aError::internal(format!("sqlite error: {e}"))
132}
133
134#[allow(clippy::manual_async_fn)]
135impl TaskStore for TenantAwareSqliteTaskStore {
136 fn save<'a>(&'a self, task: Task) -> Pin<Box<dyn Future<Output = A2aResult<()>> + Send + 'a>> {
137 Box::pin(async move {
138 let tenant = TenantContext::current();
139 let id = task.id.0.as_str();
140 let context_id = task.context_id.0.as_str();
141 let state = task.status.state.to_string();
142 let data = serde_json::to_string(&task)
143 .map_err(|e| A2aError::internal(format!("failed to serialize task: {e}")))?;
144
145 sqlx::query(
146 "INSERT INTO tenant_tasks (tenant_id, id, context_id, state, data, updated_at)
147 VALUES (?1, ?2, ?3, ?4, ?5, datetime('now'))
148 ON CONFLICT(tenant_id, id) DO UPDATE SET
149 context_id = excluded.context_id,
150 state = excluded.state,
151 data = excluded.data,
152 updated_at = datetime('now')",
153 )
154 .bind(&tenant)
155 .bind(id)
156 .bind(context_id)
157 .bind(&state)
158 .bind(&data)
159 .execute(&self.pool)
160 .await
161 .map_err(|e| to_a2a_error(&e))?;
162
163 Ok(())
164 })
165 }
166
167 fn get<'a>(
168 &'a self,
169 id: &'a TaskId,
170 ) -> Pin<Box<dyn Future<Output = A2aResult<Option<Task>>> + Send + 'a>> {
171 Box::pin(async move {
172 let tenant = TenantContext::current();
173 let row: Option<(String,)> =
174 sqlx::query_as("SELECT data FROM tenant_tasks WHERE tenant_id = ?1 AND id = ?2")
175 .bind(&tenant)
176 .bind(id.0.as_str())
177 .fetch_optional(&self.pool)
178 .await
179 .map_err(|e| to_a2a_error(&e))?;
180
181 match row {
182 Some((data,)) => {
183 let task: Task = serde_json::from_str(&data)
184 .map_err(|e| A2aError::internal(format!("deserialize: {e}")))?;
185 Ok(Some(task))
186 }
187 None => Ok(None),
188 }
189 })
190 }
191
192 fn list<'a>(
193 &'a self,
194 params: &'a ListTasksParams,
195 ) -> Pin<Box<dyn Future<Output = A2aResult<TaskListResponse>> + Send + 'a>> {
196 Box::pin(async move {
197 let tenant = TenantContext::current();
198 let mut conditions = vec!["tenant_id = ?1".to_string()];
199 let mut bind_values: Vec<String> = vec![tenant];
200
201 if let Some(ref ctx) = params.context_id {
202 conditions.push(format!("context_id = ?{}", bind_values.len() + 1));
203 bind_values.push(ctx.clone());
204 }
205 if let Some(ref status) = params.status {
206 conditions.push(format!("state = ?{}", bind_values.len() + 1));
207 bind_values.push(status.to_string());
208 }
209 if let Some(ref token) = params.page_token {
210 conditions.push(format!("id > ?{}", bind_values.len() + 1));
211 bind_values.push(token.clone());
212 }
213
214 let where_clause = format!("WHERE {}", conditions.join(" AND "));
215
216 let page_size = match params.page_size {
217 Some(0) | None => 50_u32,
218 Some(n) => n.min(1000),
219 };
220
221 let limit = page_size + 1;
222 let sql = format!(
223 "SELECT data FROM tenant_tasks {where_clause} ORDER BY id ASC LIMIT {limit}"
224 );
225
226 let mut query = sqlx::query_as::<_, (String,)>(&sql);
227 for val in &bind_values {
228 query = query.bind(val);
229 }
230
231 let rows: Vec<(String,)> = query
232 .fetch_all(&self.pool)
233 .await
234 .map_err(|e| to_a2a_error(&e))?;
235
236 let mut tasks: Vec<Task> = rows
237 .into_iter()
238 .map(|(data,)| {
239 serde_json::from_str(&data)
240 .map_err(|e| A2aError::internal(format!("deserialize: {e}")))
241 })
242 .collect::<A2aResult<Vec<_>>>()?;
243
244 let next_page_token = if tasks.len() > page_size as usize {
245 tasks.truncate(page_size as usize);
246 tasks.last().map(|t| t.id.0.clone())
247 } else {
248 None
249 };
250
251 let mut response = TaskListResponse::new(tasks);
252 response.next_page_token = next_page_token;
253 Ok(response)
254 })
255 }
256
257 fn insert_if_absent<'a>(
258 &'a self,
259 task: Task,
260 ) -> Pin<Box<dyn Future<Output = A2aResult<bool>> + Send + 'a>> {
261 Box::pin(async move {
262 let tenant = TenantContext::current();
263 let id = task.id.0.as_str();
264 let context_id = task.context_id.0.as_str();
265 let state = task.status.state.to_string();
266 let data = serde_json::to_string(&task)
267 .map_err(|e| A2aError::internal(format!("serialize: {e}")))?;
268
269 let result = sqlx::query(
270 "INSERT OR IGNORE INTO tenant_tasks (tenant_id, id, context_id, state, data, updated_at)
271 VALUES (?1, ?2, ?3, ?4, ?5, datetime('now'))",
272 )
273 .bind(&tenant)
274 .bind(id)
275 .bind(context_id)
276 .bind(&state)
277 .bind(&data)
278 .execute(&self.pool)
279 .await
280 .map_err(|e| to_a2a_error(&e))?;
281
282 Ok(result.rows_affected() > 0)
283 })
284 }
285
286 fn delete<'a>(
287 &'a self,
288 id: &'a TaskId,
289 ) -> Pin<Box<dyn Future<Output = A2aResult<()>> + Send + 'a>> {
290 Box::pin(async move {
291 let tenant = TenantContext::current();
292 sqlx::query("DELETE FROM tenant_tasks WHERE tenant_id = ?1 AND id = ?2")
293 .bind(&tenant)
294 .bind(id.0.as_str())
295 .execute(&self.pool)
296 .await
297 .map_err(|e| to_a2a_error(&e))?;
298 Ok(())
299 })
300 }
301
302 fn count<'a>(&'a self) -> Pin<Box<dyn Future<Output = A2aResult<u64>> + Send + 'a>> {
303 Box::pin(async move {
304 let tenant = TenantContext::current();
305 let row: (i64,) =
306 sqlx::query_as("SELECT COUNT(*) FROM tenant_tasks WHERE tenant_id = ?1")
307 .bind(&tenant)
308 .fetch_one(&self.pool)
309 .await
310 .map_err(|e| to_a2a_error(&e))?;
311 #[allow(clippy::cast_sign_loss)]
312 Ok(row.0 as u64)
313 })
314 }
315}
316
#[cfg(test)]
mod tests {
    use super::*;
    use a2a_protocol_types::task::{ContextId, Task, TaskId, TaskState, TaskStatus};

    /// Builds a store backed by an in-memory SQLite database.
    async fn make_store() -> TenantAwareSqliteTaskStore {
        TenantAwareSqliteTaskStore::new("sqlite::memory:")
            .await
            .expect("failed to create in-memory tenant store")
    }

    /// Builds a minimal task with no history, artifacts or metadata.
    fn make_task(id: &str, ctx: &str, state: TaskState) -> Task {
        Task {
            id: TaskId::new(id),
            context_id: ContextId::new(ctx),
            status: TaskStatus::new(state),
            history: None,
            artifacts: None,
            metadata: None,
        }
    }

    #[tokio::test]
    async fn save_and_get_within_tenant() {
        let store = make_store().await;
        TenantContext::scope("acme", async {
            store
                .save(make_task("t1", "ctx1", TaskState::Submitted))
                .await
                .unwrap();
            let task = store.get(&TaskId::new("t1")).await.unwrap();
            assert!(
                task.is_some(),
                "task should be retrievable within its tenant"
            );
            assert_eq!(task.unwrap().id, TaskId::new("t1"));
        })
        .await;
    }

    #[tokio::test]
    async fn tenant_isolation_get() {
        let store = make_store().await;
        TenantContext::scope("tenant-a", async {
            store
                .save(make_task("t1", "ctx1", TaskState::Submitted))
                .await
                .unwrap();
        })
        .await;

        TenantContext::scope("tenant-b", async {
            let result = store.get(&TaskId::new("t1")).await.unwrap();
            assert!(result.is_none(), "tenant-b should not see tenant-a's task");
        })
        .await;
    }

    #[tokio::test]
    async fn tenant_isolation_list() {
        let store = make_store().await;
        TenantContext::scope("tenant-a", async {
            store
                .save(make_task("t1", "ctx1", TaskState::Submitted))
                .await
                .unwrap();
            store
                .save(make_task("t2", "ctx1", TaskState::Working))
                .await
                .unwrap();
        })
        .await;

        TenantContext::scope("tenant-b", async {
            store
                .save(make_task("t3", "ctx1", TaskState::Submitted))
                .await
                .unwrap();
        })
        .await;

        TenantContext::scope("tenant-a", async {
            let response = store.list(&ListTasksParams::default()).await.unwrap();
            assert_eq!(
                response.tasks.len(),
                2,
                "tenant-a should see only its 2 tasks"
            );
        })
        .await;

        TenantContext::scope("tenant-b", async {
            let response = store.list(&ListTasksParams::default()).await.unwrap();
            assert_eq!(
                response.tasks.len(),
                1,
                "tenant-b should see only its 1 task"
            );
        })
        .await;
    }

    #[tokio::test]
    async fn list_filters_by_context_within_tenant() {
        let store = make_store().await;
        TenantContext::scope("tenant-a", async {
            store
                .save(make_task("t1", "ctx-a", TaskState::Submitted))
                .await
                .unwrap();
            store
                .save(make_task("t2", "ctx-b", TaskState::Submitted))
                .await
                .unwrap();
        })
        .await;

        TenantContext::scope("tenant-b", async {
            // Same context id in another tenant must not leak into tenant-a's results.
            store
                .save(make_task("t3", "ctx-a", TaskState::Submitted))
                .await
                .unwrap();
        })
        .await;

        TenantContext::scope("tenant-a", async {
            let params = ListTasksParams {
                context_id: Some("ctx-a".to_string()),
                ..Default::default()
            };
            let response = store.list(&params).await.unwrap();
            assert_eq!(
                response.tasks.len(),
                1,
                "only tenant-a's ctx-a task should match"
            );
            assert_eq!(response.tasks[0].id, TaskId::new("t1"));
        })
        .await;
    }

    #[tokio::test]
    async fn tenant_isolation_count() {
        let store = make_store().await;
        TenantContext::scope("tenant-a", async {
            store
                .save(make_task("t1", "ctx1", TaskState::Submitted))
                .await
                .unwrap();
            store
                .save(make_task("t2", "ctx1", TaskState::Working))
                .await
                .unwrap();
        })
        .await;

        TenantContext::scope("tenant-b", async {
            let count = store.count().await.unwrap();
            assert_eq!(count, 0, "tenant-b should have zero tasks");
        })
        .await;

        TenantContext::scope("tenant-a", async {
            let count = store.count().await.unwrap();
            assert_eq!(count, 2, "tenant-a should have 2 tasks");
        })
        .await;
    }

    #[tokio::test]
    async fn tenant_isolation_delete() {
        let store = make_store().await;
        TenantContext::scope("tenant-a", async {
            store
                .save(make_task("t1", "ctx1", TaskState::Submitted))
                .await
                .unwrap();
        })
        .await;

        TenantContext::scope("tenant-b", async {
            // Deleting an id this tenant does not own succeeds but must be a no-op.
            store.delete(&TaskId::new("t1")).await.unwrap();
        })
        .await;

        TenantContext::scope("tenant-a", async {
            let task = store.get(&TaskId::new("t1")).await.unwrap();
            assert!(
                task.is_some(),
                "tenant-a's task should still exist after tenant-b's delete"
            );
        })
        .await;
    }

    #[tokio::test]
    async fn same_task_id_different_tenants() {
        let store = make_store().await;
        TenantContext::scope("tenant-a", async {
            store
                .save(make_task("t1", "ctx-a", TaskState::Submitted))
                .await
                .unwrap();
        })
        .await;

        TenantContext::scope("tenant-b", async {
            store
                .save(make_task("t1", "ctx-b", TaskState::Working))
                .await
                .unwrap();
        })
        .await;

        TenantContext::scope("tenant-a", async {
            let task = store.get(&TaskId::new("t1")).await.unwrap().unwrap();
            assert_eq!(
                task.context_id,
                ContextId::new("ctx-a"),
                "tenant-a should get its own version of t1"
            );
            assert_eq!(task.status.state, TaskState::Submitted);
        })
        .await;

        TenantContext::scope("tenant-b", async {
            let task = store.get(&TaskId::new("t1")).await.unwrap().unwrap();
            assert_eq!(
                task.context_id,
                ContextId::new("ctx-b"),
                "tenant-b should get its own version of t1"
            );
            assert_eq!(task.status.state, TaskState::Working);
        })
        .await;
    }

    #[tokio::test]
    async fn insert_if_absent_respects_tenant_scope() {
        let store = make_store().await;
        TenantContext::scope("tenant-a", async {
            let inserted = store
                .insert_if_absent(make_task("t1", "ctx1", TaskState::Submitted))
                .await
                .unwrap();
            assert!(inserted, "first insert should succeed");

            let inserted = store
                .insert_if_absent(make_task("t1", "ctx1", TaskState::Working))
                .await
                .unwrap();
            assert!(!inserted, "duplicate insert in same tenant should fail");

            let stored = store.get(&TaskId::new("t1")).await.unwrap().unwrap();
            assert_eq!(
                stored.status.state,
                TaskState::Submitted,
                "a rejected insert must not overwrite the existing task"
            );
        })
        .await;

        TenantContext::scope("tenant-b", async {
            // The same id is free in a different tenant.
            let inserted = store
                .insert_if_absent(make_task("t1", "ctx1", TaskState::Working))
                .await
                .unwrap();
            assert!(
                inserted,
                "insert of same task id in different tenant should succeed"
            );
        })
        .await;
    }

    #[tokio::test]
    async fn list_pagination_within_tenant() {
        let store = make_store().await;
        TenantContext::scope("tenant-a", async {
            for i in 0..5 {
                store
                    .save(make_task(
                        &format!("task-{i:03}"),
                        "ctx1",
                        TaskState::Submitted,
                    ))
                    .await
                    .unwrap();
            }

            let params = ListTasksParams {
                page_size: Some(2),
                ..Default::default()
            };
            let response = store.list(&params).await.unwrap();
            assert_eq!(response.tasks.len(), 2, "first page should have 2 tasks");
            assert!(
                response.next_page_token.is_some(),
                "should have a next page token"
            );

            let params2 = ListTasksParams {
                page_size: Some(2),
                page_token: response.next_page_token,
                ..Default::default()
            };
            let response2 = store.list(&params2).await.unwrap();
            assert_eq!(response2.tasks.len(), 2, "second page should have 2 tasks");

            let params3 = ListTasksParams {
                page_size: Some(2),
                page_token: response2.next_page_token,
                ..Default::default()
            };
            let response3 = store.list(&params3).await.unwrap();
            assert_eq!(
                response3.tasks.len(),
                1,
                "last page should hold the remaining task"
            );
            assert!(
                response3.next_page_token.is_none(),
                "last page should not have a next page token"
            );
        })
        .await;
    }

    #[test]
    fn to_a2a_error_formats_message() {
        let sqlite_err = sqlx::Error::RowNotFound;
        let a2a_err = to_a2a_error(&sqlite_err);
        let msg = format!("{a2a_err}");
        assert!(
            msg.contains("sqlite error"),
            "error message should contain 'sqlite error': {msg}"
        );
    }

    #[tokio::test]
    async fn default_tenant_context_uses_empty_string() {
        let store = make_store().await;
        // No `TenantContext::scope` here: exercises the default tenant.
        store
            .save(make_task("t1", "ctx1", TaskState::Submitted))
            .await
            .unwrap();
        let task = store.get(&TaskId::new("t1")).await.unwrap();
        assert!(task.is_some(), "default (empty) tenant should work");
    }
}