1use std::future::Future;
22use std::pin::Pin;
23
24use a2a_protocol_types::error::{A2aError, A2aResult};
25use a2a_protocol_types::params::ListTasksParams;
26use a2a_protocol_types::responses::TaskListResponse;
27use a2a_protocol_types::task::{Task, TaskId};
28use sqlx::sqlite::{SqlitePool, SqlitePoolOptions};
29
30use super::task_store::TaskStore;
31
32#[derive(Debug, Clone)]
51pub struct SqliteTaskStore {
52 pool: SqlitePool,
53}
54
55impl SqliteTaskStore {
56 pub async fn new(url: &str) -> Result<Self, sqlx::Error> {
62 let pool = sqlite_pool(url).await?;
63 Self::from_pool(pool).await
64 }
65
66 pub async fn with_migrations(url: &str) -> Result<Self, sqlx::Error> {
76 let pool = sqlite_pool(url).await?;
77
78 let runner = super::migration::MigrationRunner::new(pool.clone());
79 runner.run_pending().await?;
80
81 Ok(Self { pool })
82 }
83
84 pub async fn from_pool(pool: SqlitePool) -> Result<Self, sqlx::Error> {
90 sqlx::query(
91 "CREATE TABLE IF NOT EXISTS tasks (
92 id TEXT PRIMARY KEY,
93 context_id TEXT NOT NULL,
94 state TEXT NOT NULL,
95 data TEXT NOT NULL,
96 updated_at TEXT NOT NULL DEFAULT (datetime('now')),
97 created_at TEXT NOT NULL DEFAULT (datetime('now'))
98 )",
99 )
100 .execute(&pool)
101 .await?;
102
103 sqlx::query("CREATE INDEX IF NOT EXISTS idx_tasks_context_id ON tasks(context_id)")
104 .execute(&pool)
105 .await?;
106
107 sqlx::query("CREATE INDEX IF NOT EXISTS idx_tasks_state ON tasks(state)")
108 .execute(&pool)
109 .await?;
110
111 sqlx::query(
112 "CREATE INDEX IF NOT EXISTS idx_tasks_context_id_state ON tasks(context_id, state)",
113 )
114 .execute(&pool)
115 .await?;
116
117 Ok(Self { pool })
118 }
119}
120
121async fn sqlite_pool(url: &str) -> Result<SqlitePool, sqlx::Error> {
126 sqlite_pool_with_size(url, 8).await
127}
128
129async fn sqlite_pool_with_size(url: &str, max_connections: u32) -> Result<SqlitePool, sqlx::Error> {
131 use sqlx::sqlite::SqliteConnectOptions;
132 use std::str::FromStr;
133
134 let opts = SqliteConnectOptions::from_str(url)?
135 .pragma("journal_mode", "WAL")
136 .pragma("busy_timeout", "5000")
137 .pragma("synchronous", "NORMAL")
138 .pragma("foreign_keys", "ON")
139 .create_if_missing(true);
140
141 SqlitePoolOptions::new()
142 .max_connections(max_connections)
143 .connect_with(opts)
144 .await
145}
146
147#[allow(clippy::needless_pass_by_value)]
149fn to_a2a_error(e: sqlx::Error) -> A2aError {
150 A2aError::internal(format!("sqlite error: {e}"))
151}
152
153#[allow(clippy::manual_async_fn)]
154impl TaskStore for SqliteTaskStore {
155 fn save<'a>(&'a self, task: Task) -> Pin<Box<dyn Future<Output = A2aResult<()>> + Send + 'a>> {
156 Box::pin(async move {
157 let id = task.id.0.as_str();
158 let context_id = task.context_id.0.as_str();
159 let state = task.status.state.to_string();
160 let data = serde_json::to_string(&task)
161 .map_err(|e| A2aError::internal(format!("failed to serialize task: {e}")))?;
162
163 sqlx::query(
164 "INSERT INTO tasks (id, context_id, state, data, updated_at)
165 VALUES (?1, ?2, ?3, ?4, datetime('now'))
166 ON CONFLICT(id) DO UPDATE SET
167 context_id = excluded.context_id,
168 state = excluded.state,
169 data = excluded.data,
170 updated_at = datetime('now')",
171 )
172 .bind(id)
173 .bind(context_id)
174 .bind(&state)
175 .bind(&data)
176 .execute(&self.pool)
177 .await
178 .map_err(to_a2a_error)?;
179
180 Ok(())
181 })
182 }
183
184 fn get<'a>(
185 &'a self,
186 id: &'a TaskId,
187 ) -> Pin<Box<dyn Future<Output = A2aResult<Option<Task>>> + Send + 'a>> {
188 Box::pin(async move {
189 let row: Option<(String,)> = sqlx::query_as("SELECT data FROM tasks WHERE id = ?1")
190 .bind(id.0.as_str())
191 .fetch_optional(&self.pool)
192 .await
193 .map_err(to_a2a_error)?;
194
195 match row {
196 Some((data,)) => {
197 let task: Task = serde_json::from_str(&data).map_err(|e| {
198 A2aError::internal(format!("failed to deserialize task: {e}"))
199 })?;
200 Ok(Some(task))
201 }
202 None => Ok(None),
203 }
204 })
205 }
206
207 fn list<'a>(
208 &'a self,
209 params: &'a ListTasksParams,
210 ) -> Pin<Box<dyn Future<Output = A2aResult<TaskListResponse>> + Send + 'a>> {
211 Box::pin(async move {
212 let mut conditions = Vec::new();
214 let mut bind_values: Vec<String> = Vec::new();
215
216 if let Some(ref ctx) = params.context_id {
217 conditions.push(format!("context_id = ?{}", bind_values.len() + 1));
218 bind_values.push(ctx.clone());
219 }
220 if let Some(ref status) = params.status {
221 conditions.push(format!("state = ?{}", bind_values.len() + 1));
222 bind_values.push(status.to_string());
223 }
224 if let Some(ref token) = params.page_token {
225 conditions.push(format!("id > ?{}", bind_values.len() + 1));
226 bind_values.push(token.clone());
227 }
228
229 let where_clause = if conditions.is_empty() {
230 String::new()
231 } else {
232 format!("WHERE {}", conditions.join(" AND "))
233 };
234
235 let page_size = match params.page_size {
236 Some(0) | None => 50_u32,
237 Some(n) => n.min(1000),
238 };
239
240 let limit = page_size + 1;
244 let limit_param = bind_values.len() + 1;
245 let sql = format!(
246 "SELECT data FROM tasks {where_clause} ORDER BY id ASC LIMIT ?{limit_param}"
247 );
248
249 let mut query = sqlx::query_as::<_, (String,)>(&sql);
250 for val in &bind_values {
251 query = query.bind(val);
252 }
253 query = query.bind(limit);
254
255 let rows: Vec<(String,)> = query.fetch_all(&self.pool).await.map_err(to_a2a_error)?;
256
257 let mut tasks: Vec<Task> = rows
258 .into_iter()
259 .map(|(data,)| {
260 serde_json::from_str(&data)
261 .map_err(|e| A2aError::internal(format!("deserialize: {e}")))
262 })
263 .collect::<A2aResult<Vec<_>>>()?;
264
265 let next_page_token = if tasks.len() > page_size as usize {
266 tasks.truncate(page_size as usize);
267 tasks.last().map(|t| t.id.0.clone()).unwrap_or_default()
268 } else {
269 String::new()
270 };
271
272 #[allow(clippy::cast_possible_truncation)]
273 let page_len = tasks.len() as u32;
274 let mut response = TaskListResponse::new(tasks);
275 response.next_page_token = next_page_token;
276 response.page_size = page_len;
277 Ok(response)
278 })
279 }
280
281 fn insert_if_absent<'a>(
282 &'a self,
283 task: Task,
284 ) -> Pin<Box<dyn Future<Output = A2aResult<bool>> + Send + 'a>> {
285 Box::pin(async move {
286 let id = task.id.0.as_str();
287 let context_id = task.context_id.0.as_str();
288 let state = task.status.state.to_string();
289 let data = serde_json::to_string(&task)
290 .map_err(|e| A2aError::internal(format!("failed to serialize task: {e}")))?;
291
292 let result = sqlx::query(
293 "INSERT OR IGNORE INTO tasks (id, context_id, state, data, updated_at)
294 VALUES (?1, ?2, ?3, ?4, datetime('now'))",
295 )
296 .bind(id)
297 .bind(context_id)
298 .bind(&state)
299 .bind(&data)
300 .execute(&self.pool)
301 .await
302 .map_err(to_a2a_error)?;
303
304 Ok(result.rows_affected() > 0)
305 })
306 }
307
308 fn delete<'a>(
309 &'a self,
310 id: &'a TaskId,
311 ) -> Pin<Box<dyn Future<Output = A2aResult<()>> + Send + 'a>> {
312 Box::pin(async move {
313 sqlx::query("DELETE FROM tasks WHERE id = ?1")
314 .bind(id.0.as_str())
315 .execute(&self.pool)
316 .await
317 .map_err(to_a2a_error)?;
318 Ok(())
319 })
320 }
321
322 fn count<'a>(&'a self) -> Pin<Box<dyn Future<Output = A2aResult<u64>> + Send + 'a>> {
323 Box::pin(async move {
324 let row: (i64,) = sqlx::query_as("SELECT COUNT(*) FROM tasks")
325 .fetch_one(&self.pool)
326 .await
327 .map_err(to_a2a_error)?;
328 #[allow(clippy::cast_sign_loss)]
329 Ok(row.0 as u64)
330 })
331 }
332}
333
#[cfg(test)]
mod tests {
    use super::*;
    use a2a_protocol_types::task::{ContextId, Task, TaskId, TaskState, TaskStatus};

    /// Builds a store backed by an in-memory SQLite database.
    async fn make_store() -> SqliteTaskStore {
        SqliteTaskStore::new("sqlite::memory:")
            .await
            .expect("failed to create in-memory store")
    }

    /// Builds a minimal task with no history, artifacts or metadata.
    fn make_task(id: &str, ctx: &str, state: TaskState) -> Task {
        Task {
            id: TaskId::new(id),
            context_id: ContextId::new(ctx),
            status: TaskStatus::new(state),
            history: None,
            artifacts: None,
            metadata: None,
        }
    }

    #[tokio::test]
    async fn save_and_get_round_trip() {
        let store = make_store().await;
        let task = make_task("t1", "ctx1", TaskState::Submitted);
        store.save(task.clone()).await.expect("save should succeed");

        let retrieved = store
            .get(&TaskId::new("t1"))
            .await
            .expect("get should succeed");
        let retrieved = retrieved.expect("task should exist after save");
        assert_eq!(retrieved.id, TaskId::new("t1"), "task id should match");
        assert_eq!(
            retrieved.context_id,
            ContextId::new("ctx1"),
            "context_id should match"
        );
        assert_eq!(
            retrieved.status.state,
            TaskState::Submitted,
            "state should match"
        );
    }

    #[tokio::test]
    async fn get_returns_none_for_missing_task() {
        let store = make_store().await;
        let result = store
            .get(&TaskId::new("nonexistent"))
            .await
            .expect("get should succeed");
        assert!(
            result.is_none(),
            "get should return None for a missing task"
        );
    }

    #[tokio::test]
    async fn save_overwrites_existing_task() {
        let store = make_store().await;
        let task1 = make_task("t1", "ctx1", TaskState::Submitted);
        store.save(task1).await.expect("first save should succeed");

        let task2 = make_task("t1", "ctx1", TaskState::Working);
        store.save(task2).await.expect("second save should succeed");

        let retrieved = store.get(&TaskId::new("t1")).await.unwrap().unwrap();
        assert_eq!(
            retrieved.status.state,
            TaskState::Working,
            "state should be updated after overwrite"
        );
    }

    #[tokio::test]
    async fn insert_if_absent_returns_true_for_new_task() {
        let store = make_store().await;
        let task = make_task("t1", "ctx1", TaskState::Submitted);
        let inserted = store
            .insert_if_absent(task)
            .await
            .expect("insert_if_absent should succeed");
        assert!(
            inserted,
            "insert_if_absent should return true for a new task"
        );
    }

    #[tokio::test]
    async fn insert_if_absent_returns_false_for_existing_task() {
        let store = make_store().await;
        let task = make_task("t1", "ctx1", TaskState::Submitted);
        store.save(task.clone()).await.unwrap();

        let duplicate = make_task("t1", "ctx1", TaskState::Working);
        let inserted = store
            .insert_if_absent(duplicate)
            .await
            .expect("insert_if_absent should succeed");
        assert!(
            !inserted,
            "insert_if_absent should return false for an existing task"
        );

        let retrieved = store.get(&TaskId::new("t1")).await.unwrap().unwrap();
        assert_eq!(
            retrieved.status.state,
            TaskState::Submitted,
            "original state should be preserved"
        );
    }

    #[tokio::test]
    async fn delete_removes_task() {
        let store = make_store().await;
        store
            .save(make_task("t1", "ctx1", TaskState::Submitted))
            .await
            .unwrap();

        store
            .delete(&TaskId::new("t1"))
            .await
            .expect("delete should succeed");

        let result = store.get(&TaskId::new("t1")).await.unwrap();
        assert!(result.is_none(), "task should be gone after delete");
    }

    #[tokio::test]
    async fn delete_nonexistent_is_ok() {
        let store = make_store().await;
        let result = store.delete(&TaskId::new("nonexistent")).await;
        assert!(
            result.is_ok(),
            "deleting a nonexistent task should not error"
        );
    }

    #[tokio::test]
    async fn count_tracks_inserts_and_deletes() {
        let store = make_store().await;
        assert_eq!(
            store.count().await.unwrap(),
            0,
            "empty store should have count 0"
        );

        store
            .save(make_task("t1", "ctx1", TaskState::Submitted))
            .await
            .unwrap();
        store
            .save(make_task("t2", "ctx1", TaskState::Working))
            .await
            .unwrap();
        assert_eq!(
            store.count().await.unwrap(),
            2,
            "count should be 2 after two saves"
        );

        store.delete(&TaskId::new("t1")).await.unwrap();
        assert_eq!(
            store.count().await.unwrap(),
            1,
            "count should be 1 after one delete"
        );
    }

    #[tokio::test]
    async fn list_all_tasks() {
        let store = make_store().await;
        store
            .save(make_task("t1", "ctx1", TaskState::Submitted))
            .await
            .unwrap();
        store
            .save(make_task("t2", "ctx2", TaskState::Working))
            .await
            .unwrap();

        let params = ListTasksParams::default();
        let response = store.list(&params).await.expect("list should succeed");
        assert_eq!(response.tasks.len(), 2, "list should return all tasks");
    }

    #[tokio::test]
    async fn list_filter_by_context_id() {
        let store = make_store().await;
        store
            .save(make_task("t1", "ctx-a", TaskState::Submitted))
            .await
            .unwrap();
        store
            .save(make_task("t2", "ctx-b", TaskState::Submitted))
            .await
            .unwrap();
        store
            .save(make_task("t3", "ctx-a", TaskState::Working))
            .await
            .unwrap();

        let params = ListTasksParams {
            context_id: Some("ctx-a".to_string()),
            ..Default::default()
        };
        let response = store.list(&params).await.unwrap();
        assert_eq!(
            response.tasks.len(),
            2,
            "should return only tasks with context_id ctx-a"
        );
    }

    #[tokio::test]
    async fn list_filter_by_status() {
        let store = make_store().await;
        store
            .save(make_task("t1", "ctx1", TaskState::Submitted))
            .await
            .unwrap();
        store
            .save(make_task("t2", "ctx1", TaskState::Working))
            .await
            .unwrap();
        store
            .save(make_task("t3", "ctx1", TaskState::Working))
            .await
            .unwrap();

        let params = ListTasksParams {
            status: Some(TaskState::Working),
            ..Default::default()
        };
        let response = store.list(&params).await.unwrap();
        assert_eq!(response.tasks.len(), 2, "should return only Working tasks");
    }

    /// Both filters must apply together (joined with AND), not either one alone.
    #[tokio::test]
    async fn list_filter_by_context_id_and_status() {
        let store = make_store().await;
        store
            .save(make_task("t1", "ctx-a", TaskState::Submitted))
            .await
            .unwrap();
        store
            .save(make_task("t2", "ctx-a", TaskState::Working))
            .await
            .unwrap();
        store
            .save(make_task("t3", "ctx-b", TaskState::Working))
            .await
            .unwrap();

        let params = ListTasksParams {
            context_id: Some("ctx-a".to_string()),
            status: Some(TaskState::Working),
            ..Default::default()
        };
        let response = store.list(&params).await.unwrap();
        assert_eq!(
            response.tasks.len(),
            1,
            "only tasks matching both context_id and status should be returned"
        );
        assert_eq!(
            response.tasks[0].id,
            TaskId::new("t2"),
            "the task matching both filters should be t2"
        );
    }

    #[tokio::test]
    async fn list_pagination() {
        let store = make_store().await;
        for i in 0..5 {
            store
                .save(make_task(
                    &format!("task-{i:03}"),
                    "ctx1",
                    TaskState::Submitted,
                ))
                .await
                .unwrap();
        }

        let params = ListTasksParams {
            page_size: Some(2),
            ..Default::default()
        };
        let response = store.list(&params).await.unwrap();
        assert_eq!(response.tasks.len(), 2, "first page should have 2 tasks");
        assert!(
            !response.next_page_token.is_empty(),
            "should have a next page token"
        );

        let params2 = ListTasksParams {
            page_size: Some(2),
            page_token: Some(response.next_page_token),
            ..Default::default()
        };
        let response2 = store.list(&params2).await.unwrap();
        assert_eq!(response2.tasks.len(), 2, "second page should have 2 tasks");
        assert!(
            !response2.next_page_token.is_empty(),
            "should still have a next page token"
        );

        let params3 = ListTasksParams {
            page_size: Some(2),
            page_token: Some(response2.next_page_token),
            ..Default::default()
        };
        let response3 = store.list(&params3).await.unwrap();
        assert_eq!(response3.tasks.len(), 1, "last page should have 1 task");
        assert!(
            response3.next_page_token.is_empty(),
            "last page should have no next page token"
        );
    }

    /// Following page tokens to the end visits every task exactly once, in id order.
    #[tokio::test]
    async fn list_pagination_visits_every_task_once_in_id_order() {
        let store = make_store().await;
        for i in 0..5 {
            store
                .save(make_task(
                    &format!("task-{i:03}"),
                    "ctx1",
                    TaskState::Submitted,
                ))
                .await
                .unwrap();
        }

        let mut seen: Vec<String> = Vec::new();
        let mut token: Option<String> = None;
        // Bounded so a pagination bug fails the test instead of hanging it.
        for _ in 0..10 {
            let params = ListTasksParams {
                page_size: Some(2),
                page_token: token.take(),
                ..Default::default()
            };
            let response = store.list(&params).await.unwrap();
            seen.extend(response.tasks.iter().map(|t| t.id.0.clone()));
            if response.next_page_token.is_empty() {
                break;
            }
            token = Some(response.next_page_token);
        }

        let expected: Vec<String> = (0..5).map(|i| format!("task-{i:03}")).collect();
        assert_eq!(
            seen, expected,
            "pages should cover every task once, in ascending id order"
        );
    }

    #[test]
    fn to_a2a_error_formats_message() {
        let sqlite_err = sqlx::Error::RowNotFound;
        let a2a_err = to_a2a_error(sqlite_err);
        let msg = format!("{a2a_err}");
        assert!(
            msg.contains("sqlite error"),
            "error message should contain 'sqlite error': {msg}"
        );
    }

    #[tokio::test]
    async fn with_migrations_creates_store() {
        let result = SqliteTaskStore::with_migrations("sqlite::memory:").await;
        assert!(
            result.is_ok(),
            "with_migrations should succeed on a fresh database"
        );
        let store = result.unwrap();
        let count = store.count().await.unwrap();
        assert_eq!(count, 0, "freshly migrated store should be empty");
    }

    #[tokio::test]
    async fn list_empty_store() {
        let store = make_store().await;
        let params = ListTasksParams::default();
        let response = store.list(&params).await.unwrap();
        assert!(
            response.tasks.is_empty(),
            "list on empty store should return no tasks"
        );
        assert!(
            response.next_page_token.is_empty(),
            "no pagination token for empty results"
        );
    }
}