1use std::future::Future;
22use std::pin::Pin;
23
24use a2a_protocol_types::error::{A2aError, A2aResult};
25use a2a_protocol_types::params::ListTasksParams;
26use a2a_protocol_types::responses::TaskListResponse;
27use a2a_protocol_types::task::{Task, TaskId};
28use sqlx::sqlite::{SqlitePool, SqlitePoolOptions};
29
30use super::task_store::TaskStore;
31
/// [`TaskStore`] implementation backed by a SQLite database accessed through an
/// `sqlx` connection pool.
///
/// Tasks live in the `tasks` table. The whole task is stored as JSON in the `data`
/// column. `context_id` and `state` are duplicated into their own indexed columns
/// so that listing can filter on them without decoding the JSON.
#[derive(Debug, Clone)]
pub struct SqliteTaskStore {
    /// Pool every query of this store is executed against.
    pool: SqlitePool,
}
54
55impl SqliteTaskStore {
56 pub async fn new(url: &str) -> Result<Self, sqlx::Error> {
62 let pool = sqlite_pool(url).await?;
63 Self::from_pool(pool).await
64 }
65
66 pub async fn with_migrations(url: &str) -> Result<Self, sqlx::Error> {
76 let pool = sqlite_pool(url).await?;
77
78 let runner = super::migration::MigrationRunner::new(pool.clone());
79 runner.run_pending().await?;
80
81 Ok(Self { pool })
82 }
83
84 pub async fn from_pool(pool: SqlitePool) -> Result<Self, sqlx::Error> {
90 sqlx::query(
91 "CREATE TABLE IF NOT EXISTS tasks (
92 id TEXT PRIMARY KEY,
93 context_id TEXT NOT NULL,
94 state TEXT NOT NULL,
95 data TEXT NOT NULL,
96 updated_at TEXT NOT NULL DEFAULT (datetime('now')),
97 created_at TEXT NOT NULL DEFAULT (datetime('now'))
98 )",
99 )
100 .execute(&pool)
101 .await?;
102
103 sqlx::query("CREATE INDEX IF NOT EXISTS idx_tasks_context_id ON tasks(context_id)")
104 .execute(&pool)
105 .await?;
106
107 sqlx::query("CREATE INDEX IF NOT EXISTS idx_tasks_state ON tasks(state)")
108 .execute(&pool)
109 .await?;
110
111 sqlx::query(
112 "CREATE INDEX IF NOT EXISTS idx_tasks_context_id_state ON tasks(context_id, state)",
113 )
114 .execute(&pool)
115 .await?;
116
117 Ok(Self { pool })
118 }
119}
120
121async fn sqlite_pool(url: &str) -> Result<SqlitePool, sqlx::Error> {
126 sqlite_pool_with_size(url, 8).await
127}
128
129async fn sqlite_pool_with_size(url: &str, max_connections: u32) -> Result<SqlitePool, sqlx::Error> {
131 use sqlx::sqlite::SqliteConnectOptions;
132 use std::str::FromStr;
133
134 let opts = SqliteConnectOptions::from_str(url)?
135 .pragma("journal_mode", "WAL")
136 .pragma("busy_timeout", "5000")
137 .pragma("synchronous", "NORMAL")
138 .pragma("foreign_keys", "ON")
139 .create_if_missing(true);
140
141 SqlitePoolOptions::new()
142 .max_connections(max_connections)
143 .connect_with(opts)
144 .await
145}
146
147#[allow(clippy::needless_pass_by_value)]
149fn to_a2a_error(e: sqlx::Error) -> A2aError {
150 A2aError::internal(format!("sqlite error: {e}"))
151}
152
153#[allow(clippy::manual_async_fn)]
154impl TaskStore for SqliteTaskStore {
155 fn save<'a>(&'a self, task: Task) -> Pin<Box<dyn Future<Output = A2aResult<()>> + Send + 'a>> {
156 Box::pin(async move {
157 let id = task.id.0.as_str();
158 let context_id = task.context_id.0.as_str();
159 let state = task.status.state.to_string();
160 let data = serde_json::to_string(&task)
161 .map_err(|e| A2aError::internal(format!("failed to serialize task: {e}")))?;
162
163 sqlx::query(
164 "INSERT INTO tasks (id, context_id, state, data, updated_at)
165 VALUES (?1, ?2, ?3, ?4, datetime('now'))
166 ON CONFLICT(id) DO UPDATE SET
167 context_id = excluded.context_id,
168 state = excluded.state,
169 data = excluded.data,
170 updated_at = datetime('now')",
171 )
172 .bind(id)
173 .bind(context_id)
174 .bind(&state)
175 .bind(&data)
176 .execute(&self.pool)
177 .await
178 .map_err(to_a2a_error)?;
179
180 Ok(())
181 })
182 }
183
184 fn get<'a>(
185 &'a self,
186 id: &'a TaskId,
187 ) -> Pin<Box<dyn Future<Output = A2aResult<Option<Task>>> + Send + 'a>> {
188 Box::pin(async move {
189 let row: Option<(String,)> = sqlx::query_as("SELECT data FROM tasks WHERE id = ?1")
190 .bind(id.0.as_str())
191 .fetch_optional(&self.pool)
192 .await
193 .map_err(to_a2a_error)?;
194
195 match row {
196 Some((data,)) => {
197 let task: Task = serde_json::from_str(&data).map_err(|e| {
198 A2aError::internal(format!("failed to deserialize task: {e}"))
199 })?;
200 Ok(Some(task))
201 }
202 None => Ok(None),
203 }
204 })
205 }
206
207 fn list<'a>(
208 &'a self,
209 params: &'a ListTasksParams,
210 ) -> Pin<Box<dyn Future<Output = A2aResult<TaskListResponse>> + Send + 'a>> {
211 Box::pin(async move {
212 let mut conditions = Vec::new();
214 let mut bind_values: Vec<String> = Vec::new();
215
216 if let Some(ref ctx) = params.context_id {
217 conditions.push(format!("context_id = ?{}", bind_values.len() + 1));
218 bind_values.push(ctx.clone());
219 }
220 if let Some(ref status) = params.status {
221 conditions.push(format!("state = ?{}", bind_values.len() + 1));
222 bind_values.push(status.to_string());
223 }
224 if let Some(ref token) = params.page_token {
225 conditions.push(format!("id > ?{}", bind_values.len() + 1));
226 bind_values.push(token.clone());
227 }
228
229 let where_clause = if conditions.is_empty() {
230 String::new()
231 } else {
232 format!("WHERE {}", conditions.join(" AND "))
233 };
234
235 let page_size = match params.page_size {
236 Some(0) | None => 50_u32,
237 Some(n) => n.min(1000),
238 };
239
240 let limit = page_size + 1;
244 let limit_param = bind_values.len() + 1;
245 let sql = format!(
246 "SELECT data FROM tasks {where_clause} ORDER BY id ASC LIMIT ?{limit_param}"
247 );
248
249 let mut query = sqlx::query_as::<_, (String,)>(&sql);
250 for val in &bind_values {
251 query = query.bind(val);
252 }
253 query = query.bind(limit);
254
255 let rows: Vec<(String,)> = query.fetch_all(&self.pool).await.map_err(to_a2a_error)?;
256
257 let mut tasks: Vec<Task> = rows
258 .into_iter()
259 .map(|(data,)| {
260 serde_json::from_str(&data)
261 .map_err(|e| A2aError::internal(format!("deserialize: {e}")))
262 })
263 .collect::<A2aResult<Vec<_>>>()?;
264
265 let next_page_token = if tasks.len() > page_size as usize {
266 tasks.truncate(page_size as usize);
267 tasks.last().map(|t| t.id.0.clone())
268 } else {
269 None
270 };
271
272 let mut response = TaskListResponse::new(tasks);
273 response.next_page_token = next_page_token;
274 Ok(response)
275 })
276 }
277
278 fn insert_if_absent<'a>(
279 &'a self,
280 task: Task,
281 ) -> Pin<Box<dyn Future<Output = A2aResult<bool>> + Send + 'a>> {
282 Box::pin(async move {
283 let id = task.id.0.as_str();
284 let context_id = task.context_id.0.as_str();
285 let state = task.status.state.to_string();
286 let data = serde_json::to_string(&task)
287 .map_err(|e| A2aError::internal(format!("failed to serialize task: {e}")))?;
288
289 let result = sqlx::query(
290 "INSERT OR IGNORE INTO tasks (id, context_id, state, data, updated_at)
291 VALUES (?1, ?2, ?3, ?4, datetime('now'))",
292 )
293 .bind(id)
294 .bind(context_id)
295 .bind(&state)
296 .bind(&data)
297 .execute(&self.pool)
298 .await
299 .map_err(to_a2a_error)?;
300
301 Ok(result.rows_affected() > 0)
302 })
303 }
304
305 fn delete<'a>(
306 &'a self,
307 id: &'a TaskId,
308 ) -> Pin<Box<dyn Future<Output = A2aResult<()>> + Send + 'a>> {
309 Box::pin(async move {
310 sqlx::query("DELETE FROM tasks WHERE id = ?1")
311 .bind(id.0.as_str())
312 .execute(&self.pool)
313 .await
314 .map_err(to_a2a_error)?;
315 Ok(())
316 })
317 }
318
319 fn count<'a>(&'a self) -> Pin<Box<dyn Future<Output = A2aResult<u64>> + Send + 'a>> {
320 Box::pin(async move {
321 let row: (i64,) = sqlx::query_as("SELECT COUNT(*) FROM tasks")
322 .fetch_one(&self.pool)
323 .await
324 .map_err(to_a2a_error)?;
325 #[allow(clippy::cast_sign_loss)]
326 Ok(row.0 as u64)
327 })
328 }
329}
330
#[cfg(test)]
mod tests {
    use super::*;
    use a2a_protocol_types::task::{ContextId, Task, TaskId, TaskState, TaskStatus};

    /// Creates a store backed by a fresh in-memory database.
    async fn make_store() -> SqliteTaskStore {
        SqliteTaskStore::new("sqlite::memory:")
            .await
            .expect("failed to create in-memory store")
    }

    /// Builds a minimal task with no history, artifacts or metadata.
    fn make_task(id: &str, ctx: &str, state: TaskState) -> Task {
        Task {
            id: TaskId::new(id),
            context_id: ContextId::new(ctx),
            status: TaskStatus::new(state),
            history: None,
            artifacts: None,
            metadata: None,
        }
    }

    #[tokio::test]
    async fn save_and_get_round_trip() {
        let store = make_store().await;
        let task = make_task("t1", "ctx1", TaskState::Submitted);
        store.save(task.clone()).await.expect("save should succeed");

        let retrieved = store
            .get(&TaskId::new("t1"))
            .await
            .expect("get should succeed");
        let retrieved = retrieved.expect("task should exist after save");
        assert_eq!(retrieved.id, TaskId::new("t1"), "task id should match");
        assert_eq!(
            retrieved.context_id,
            ContextId::new("ctx1"),
            "context_id should match"
        );
        assert_eq!(
            retrieved.status.state,
            TaskState::Submitted,
            "state should match"
        );
    }

    #[tokio::test]
    async fn get_returns_none_for_missing_task() {
        let store = make_store().await;
        let result = store
            .get(&TaskId::new("nonexistent"))
            .await
            .expect("get should succeed");
        assert!(
            result.is_none(),
            "get should return None for a missing task"
        );
    }

    #[tokio::test]
    async fn save_overwrites_existing_task() {
        let store = make_store().await;
        let task1 = make_task("t1", "ctx1", TaskState::Submitted);
        store.save(task1).await.expect("first save should succeed");

        let task2 = make_task("t1", "ctx1", TaskState::Working);
        store.save(task2).await.expect("second save should succeed");

        let retrieved = store.get(&TaskId::new("t1")).await.unwrap().unwrap();
        assert_eq!(
            retrieved.status.state,
            TaskState::Working,
            "state should be updated after overwrite"
        );
    }

    #[tokio::test]
    async fn insert_if_absent_returns_true_for_new_task() {
        let store = make_store().await;
        let task = make_task("t1", "ctx1", TaskState::Submitted);
        let inserted = store
            .insert_if_absent(task)
            .await
            .expect("insert_if_absent should succeed");
        assert!(
            inserted,
            "insert_if_absent should return true for a new task"
        );
    }

    #[tokio::test]
    async fn insert_if_absent_returns_false_for_existing_task() {
        let store = make_store().await;
        let task = make_task("t1", "ctx1", TaskState::Submitted);
        store.save(task.clone()).await.unwrap();

        let duplicate = make_task("t1", "ctx1", TaskState::Working);
        let inserted = store
            .insert_if_absent(duplicate)
            .await
            .expect("insert_if_absent should succeed");
        assert!(
            !inserted,
            "insert_if_absent should return false for an existing task"
        );

        let retrieved = store.get(&TaskId::new("t1")).await.unwrap().unwrap();
        assert_eq!(
            retrieved.status.state,
            TaskState::Submitted,
            "original state should be preserved"
        );
    }

    #[tokio::test]
    async fn insert_if_absent_persists_new_task() {
        let store = make_store().await;
        let inserted = store
            .insert_if_absent(make_task("t1", "ctx1", TaskState::Working))
            .await
            .expect("insert_if_absent should succeed");
        assert!(inserted, "first insert should report success");

        let retrieved = store
            .get(&TaskId::new("t1"))
            .await
            .unwrap()
            .expect("inserted task should be retrievable");
        assert_eq!(
            retrieved.status.state,
            TaskState::Working,
            "inserted state should be stored"
        );
    }

    #[tokio::test]
    async fn delete_removes_task() {
        let store = make_store().await;
        store
            .save(make_task("t1", "ctx1", TaskState::Submitted))
            .await
            .unwrap();

        store
            .delete(&TaskId::new("t1"))
            .await
            .expect("delete should succeed");

        let result = store.get(&TaskId::new("t1")).await.unwrap();
        assert!(result.is_none(), "task should be gone after delete");
    }

    #[tokio::test]
    async fn delete_nonexistent_is_ok() {
        let store = make_store().await;
        let result = store.delete(&TaskId::new("nonexistent")).await;
        assert!(
            result.is_ok(),
            "deleting a nonexistent task should not error"
        );
    }

    #[tokio::test]
    async fn count_tracks_inserts_and_deletes() {
        let store = make_store().await;
        assert_eq!(
            store.count().await.unwrap(),
            0,
            "empty store should have count 0"
        );

        store
            .save(make_task("t1", "ctx1", TaskState::Submitted))
            .await
            .unwrap();
        store
            .save(make_task("t2", "ctx1", TaskState::Working))
            .await
            .unwrap();
        assert_eq!(
            store.count().await.unwrap(),
            2,
            "count should be 2 after two saves"
        );

        store.delete(&TaskId::new("t1")).await.unwrap();
        assert_eq!(
            store.count().await.unwrap(),
            1,
            "count should be 1 after one delete"
        );
    }

    #[tokio::test]
    async fn list_all_tasks() {
        let store = make_store().await;
        store
            .save(make_task("t1", "ctx1", TaskState::Submitted))
            .await
            .unwrap();
        store
            .save(make_task("t2", "ctx2", TaskState::Working))
            .await
            .unwrap();

        let params = ListTasksParams::default();
        let response = store.list(&params).await.expect("list should succeed");
        assert_eq!(response.tasks.len(), 2, "list should return all tasks");
    }

    #[tokio::test]
    async fn list_filter_by_context_id() {
        let store = make_store().await;
        store
            .save(make_task("t1", "ctx-a", TaskState::Submitted))
            .await
            .unwrap();
        store
            .save(make_task("t2", "ctx-b", TaskState::Submitted))
            .await
            .unwrap();
        store
            .save(make_task("t3", "ctx-a", TaskState::Working))
            .await
            .unwrap();

        let params = ListTasksParams {
            context_id: Some("ctx-a".to_string()),
            ..Default::default()
        };
        let response = store.list(&params).await.unwrap();
        assert_eq!(
            response.tasks.len(),
            2,
            "should return only tasks with context_id ctx-a"
        );
    }

    #[tokio::test]
    async fn list_filter_by_status() {
        let store = make_store().await;
        store
            .save(make_task("t1", "ctx1", TaskState::Submitted))
            .await
            .unwrap();
        store
            .save(make_task("t2", "ctx1", TaskState::Working))
            .await
            .unwrap();
        store
            .save(make_task("t3", "ctx1", TaskState::Working))
            .await
            .unwrap();

        let params = ListTasksParams {
            status: Some(TaskState::Working),
            ..Default::default()
        };
        let response = store.list(&params).await.unwrap();
        assert_eq!(response.tasks.len(), 2, "should return only Working tasks");
    }

    #[tokio::test]
    async fn list_pagination() {
        let store = make_store().await;
        // Zero-padded ids so lexical order matches numeric order.
        for i in 0..5 {
            store
                .save(make_task(
                    &format!("task-{i:03}"),
                    "ctx1",
                    TaskState::Submitted,
                ))
                .await
                .unwrap();
        }

        let params = ListTasksParams {
            page_size: Some(2),
            ..Default::default()
        };
        let response = store.list(&params).await.unwrap();
        assert_eq!(response.tasks.len(), 2, "first page should have 2 tasks");
        assert!(
            response.next_page_token.is_some(),
            "should have a next page token"
        );

        let params2 = ListTasksParams {
            page_size: Some(2),
            page_token: response.next_page_token,
            ..Default::default()
        };
        let response2 = store.list(&params2).await.unwrap();
        assert_eq!(response2.tasks.len(), 2, "second page should have 2 tasks");
        assert!(
            response2.next_page_token.is_some(),
            "should still have a next page token"
        );

        let params3 = ListTasksParams {
            page_size: Some(2),
            page_token: response2.next_page_token,
            ..Default::default()
        };
        let response3 = store.list(&params3).await.unwrap();
        assert_eq!(response3.tasks.len(), 1, "last page should have 1 task");
        assert!(
            response3.next_page_token.is_none(),
            "last page should have no next page token"
        );
    }

    #[tokio::test]
    async fn list_pagination_respects_context_filter() {
        let store = make_store().await;
        for (id, ctx) in [("t1", "ctx-a"), ("t2", "ctx-b"), ("t3", "ctx-a"), ("t4", "ctx-a")] {
            store
                .save(make_task(id, ctx, TaskState::Submitted))
                .await
                .unwrap();
        }

        let first = ListTasksParams {
            context_id: Some("ctx-a".to_string()),
            page_size: Some(2),
            ..Default::default()
        };
        let page1 = store.list(&first).await.unwrap();
        let ids: Vec<&str> = page1.tasks.iter().map(|t| t.id.0.as_str()).collect();
        assert_eq!(ids, ["t1", "t3"], "first page should hold the first two ctx-a tasks");
        assert_eq!(
            page1.next_page_token.as_deref(),
            Some("t3"),
            "token should be the last id on the page"
        );

        let second = ListTasksParams {
            context_id: Some("ctx-a".to_string()),
            page_size: Some(2),
            page_token: page1.next_page_token,
            ..Default::default()
        };
        let page2 = store.list(&second).await.unwrap();
        let ids: Vec<&str> = page2.tasks.iter().map(|t| t.id.0.as_str()).collect();
        assert_eq!(ids, ["t4"], "second page should hold the remaining ctx-a task");
        assert!(
            page2.next_page_token.is_none(),
            "final filtered page should have no next page token"
        );
    }

    #[tokio::test]
    async fn list_page_size_zero_uses_default() {
        let store = make_store().await;
        for i in 0..3 {
            store
                .save(make_task(&format!("t{i}"), "ctx1", TaskState::Submitted))
                .await
                .unwrap();
        }

        let params = ListTasksParams {
            page_size: Some(0),
            ..Default::default()
        };
        let response = store.list(&params).await.unwrap();
        assert_eq!(
            response.tasks.len(),
            3,
            "page_size 0 should fall back to the default page size"
        );
        assert!(
            response.next_page_token.is_none(),
            "all tasks fit in the default page"
        );
    }

    #[test]
    fn to_a2a_error_formats_message() {
        let sqlite_err = sqlx::Error::RowNotFound;
        let a2a_err = to_a2a_error(sqlite_err);
        let msg = format!("{a2a_err}");
        assert!(
            msg.contains("sqlite error"),
            "error message should contain 'sqlite error': {msg}"
        );
    }

    #[tokio::test]
    async fn with_migrations_creates_store() {
        let result = SqliteTaskStore::with_migrations("sqlite::memory:").await;
        assert!(
            result.is_ok(),
            "with_migrations should succeed on a fresh database"
        );
        let store = result.unwrap();
        let count = store.count().await.unwrap();
        assert_eq!(count, 0, "freshly migrated store should be empty");
    }

    #[tokio::test]
    async fn list_empty_store() {
        let store = make_store().await;
        let params = ListTasksParams::default();
        let response = store.list(&params).await.unwrap();
        assert!(
            response.tasks.is_empty(),
            "list on empty store should return no tasks"
        );
        assert!(
            response.next_page_token.is_none(),
            "no pagination token for empty results"
        );
    }
}