1use std::future::Future;
22use std::pin::Pin;
23
24use a2a_protocol_types::error::{A2aError, A2aResult};
25use a2a_protocol_types::params::ListTasksParams;
26use a2a_protocol_types::responses::TaskListResponse;
27use a2a_protocol_types::task::{Task, TaskId};
28use sqlx::sqlite::{SqlitePool, SqlitePoolOptions};
29
30use super::task_store::TaskStore;
31
32#[derive(Debug, Clone)]
51pub struct SqliteTaskStore {
52 pool: SqlitePool,
53}
54
55impl SqliteTaskStore {
56 pub async fn new(url: &str) -> Result<Self, sqlx::Error> {
62 let pool = sqlite_pool(url).await?;
63 Self::from_pool(pool).await
64 }
65
66 pub async fn with_migrations(url: &str) -> Result<Self, sqlx::Error> {
76 let pool = sqlite_pool(url).await?;
77
78 let runner = super::migration::MigrationRunner::new(pool.clone());
79 runner.run_pending().await?;
80
81 Ok(Self { pool })
82 }
83
84 pub async fn from_pool(pool: SqlitePool) -> Result<Self, sqlx::Error> {
90 sqlx::query(
91 "CREATE TABLE IF NOT EXISTS tasks (
92 id TEXT PRIMARY KEY,
93 context_id TEXT NOT NULL,
94 state TEXT NOT NULL,
95 data TEXT NOT NULL,
96 updated_at TEXT NOT NULL DEFAULT (datetime('now')),
97 created_at TEXT NOT NULL DEFAULT (datetime('now'))
98 )",
99 )
100 .execute(&pool)
101 .await?;
102
103 sqlx::query("CREATE INDEX IF NOT EXISTS idx_tasks_context_id ON tasks(context_id)")
104 .execute(&pool)
105 .await?;
106
107 sqlx::query("CREATE INDEX IF NOT EXISTS idx_tasks_state ON tasks(state)")
108 .execute(&pool)
109 .await?;
110
111 sqlx::query(
112 "CREATE INDEX IF NOT EXISTS idx_tasks_context_id_state ON tasks(context_id, state)",
113 )
114 .execute(&pool)
115 .await?;
116
117 Ok(Self { pool })
118 }
119}
120
121async fn sqlite_pool(url: &str) -> Result<SqlitePool, sqlx::Error> {
126 sqlite_pool_with_size(url, 8).await
127}
128
129async fn sqlite_pool_with_size(url: &str, max_connections: u32) -> Result<SqlitePool, sqlx::Error> {
131 use sqlx::sqlite::SqliteConnectOptions;
132 use std::str::FromStr;
133
134 let opts = SqliteConnectOptions::from_str(url)?
135 .pragma("journal_mode", "WAL")
136 .pragma("busy_timeout", "5000")
137 .pragma("synchronous", "NORMAL")
138 .pragma("foreign_keys", "ON")
139 .create_if_missing(true);
140
141 SqlitePoolOptions::new()
142 .max_connections(max_connections)
143 .connect_with(opts)
144 .await
145}
146
147#[allow(clippy::needless_pass_by_value)]
149fn to_a2a_error(e: sqlx::Error) -> A2aError {
150 A2aError::internal(format!("sqlite error: {e}"))
151}
152
153#[allow(clippy::manual_async_fn)]
154impl TaskStore for SqliteTaskStore {
155 fn save<'a>(
156 &'a self,
157 task: &'a Task,
158 ) -> Pin<Box<dyn Future<Output = A2aResult<()>> + Send + 'a>> {
159 Box::pin(async move {
160 let id = task.id.0.as_str();
161 let context_id = task.context_id.0.as_str();
162 let state = task.status.state.to_string();
163 let data = serde_json::to_string(task)
164 .map_err(|e| A2aError::internal(format!("failed to serialize task: {e}")))?;
165
166 sqlx::query(
167 "INSERT INTO tasks (id, context_id, state, data, updated_at)
168 VALUES (?1, ?2, ?3, ?4, datetime('now'))
169 ON CONFLICT(id) DO UPDATE SET
170 context_id = excluded.context_id,
171 state = excluded.state,
172 data = excluded.data,
173 updated_at = datetime('now')",
174 )
175 .bind(id)
176 .bind(context_id)
177 .bind(&state)
178 .bind(&data)
179 .execute(&self.pool)
180 .await
181 .map_err(to_a2a_error)?;
182
183 Ok(())
184 })
185 }
186
187 fn get<'a>(
188 &'a self,
189 id: &'a TaskId,
190 ) -> Pin<Box<dyn Future<Output = A2aResult<Option<Task>>> + Send + 'a>> {
191 Box::pin(async move {
192 let row: Option<(String,)> = sqlx::query_as("SELECT data FROM tasks WHERE id = ?1")
193 .bind(id.0.as_str())
194 .fetch_optional(&self.pool)
195 .await
196 .map_err(to_a2a_error)?;
197
198 match row {
199 Some((data,)) => {
200 let task: Task = serde_json::from_str(&data).map_err(|e| {
201 A2aError::internal(format!("failed to deserialize task: {e}"))
202 })?;
203 Ok(Some(task))
204 }
205 None => Ok(None),
206 }
207 })
208 }
209
210 fn list<'a>(
211 &'a self,
212 params: &'a ListTasksParams,
213 ) -> Pin<Box<dyn Future<Output = A2aResult<TaskListResponse>> + Send + 'a>> {
214 Box::pin(async move {
215 let mut conditions = Vec::new();
217 let mut bind_values: Vec<String> = Vec::new();
218
219 if let Some(ref ctx) = params.context_id {
220 conditions.push(format!("context_id = ?{}", bind_values.len() + 1));
221 bind_values.push(ctx.clone());
222 }
223 if let Some(ref status) = params.status {
224 conditions.push(format!("state = ?{}", bind_values.len() + 1));
225 bind_values.push(status.to_string());
226 }
227 if let Some(ref token) = params.page_token {
228 conditions.push(format!("id > ?{}", bind_values.len() + 1));
229 bind_values.push(token.clone());
230 }
231
232 let where_clause = if conditions.is_empty() {
233 String::new()
234 } else {
235 format!("WHERE {}", conditions.join(" AND "))
236 };
237
238 let page_size = match params.page_size {
239 Some(0) | None => 50_u32,
240 Some(n) => n.min(1000),
241 };
242
243 let limit = page_size + 1;
247 let limit_param = bind_values.len() + 1;
248 let sql = format!(
249 "SELECT data FROM tasks {where_clause} ORDER BY id ASC LIMIT ?{limit_param}"
250 );
251
252 let mut query = sqlx::query_as::<_, (String,)>(&sql);
253 for val in &bind_values {
254 query = query.bind(val);
255 }
256 query = query.bind(limit);
257
258 let rows: Vec<(String,)> = query.fetch_all(&self.pool).await.map_err(to_a2a_error)?;
259
260 let mut tasks: Vec<Task> = rows
261 .into_iter()
262 .map(|(data,)| {
263 serde_json::from_str(&data)
264 .map_err(|e| A2aError::internal(format!("deserialize: {e}")))
265 })
266 .collect::<A2aResult<Vec<_>>>()?;
267
268 let next_page_token = if tasks.len() > page_size as usize {
269 tasks.truncate(page_size as usize);
270 tasks.last().map(|t| t.id.0.clone()).unwrap_or_default()
271 } else {
272 String::new()
273 };
274
275 #[allow(clippy::cast_possible_truncation)]
276 let page_len = tasks.len() as u32;
277 let mut response = TaskListResponse::new(tasks);
278 response.next_page_token = next_page_token;
279 response.page_size = page_len;
280 Ok(response)
281 })
282 }
283
284 fn insert_if_absent<'a>(
285 &'a self,
286 task: &'a Task,
287 ) -> Pin<Box<dyn Future<Output = A2aResult<bool>> + Send + 'a>> {
288 Box::pin(async move {
289 let id = task.id.0.as_str();
290 let context_id = task.context_id.0.as_str();
291 let state = task.status.state.to_string();
292 let data = serde_json::to_string(task)
293 .map_err(|e| A2aError::internal(format!("failed to serialize task: {e}")))?;
294
295 let result = sqlx::query(
296 "INSERT OR IGNORE INTO tasks (id, context_id, state, data, updated_at)
297 VALUES (?1, ?2, ?3, ?4, datetime('now'))",
298 )
299 .bind(id)
300 .bind(context_id)
301 .bind(&state)
302 .bind(&data)
303 .execute(&self.pool)
304 .await
305 .map_err(to_a2a_error)?;
306
307 Ok(result.rows_affected() > 0)
308 })
309 }
310
311 fn delete<'a>(
312 &'a self,
313 id: &'a TaskId,
314 ) -> Pin<Box<dyn Future<Output = A2aResult<()>> + Send + 'a>> {
315 Box::pin(async move {
316 sqlx::query("DELETE FROM tasks WHERE id = ?1")
317 .bind(id.0.as_str())
318 .execute(&self.pool)
319 .await
320 .map_err(to_a2a_error)?;
321 Ok(())
322 })
323 }
324
325 fn count<'a>(&'a self) -> Pin<Box<dyn Future<Output = A2aResult<u64>> + Send + 'a>> {
326 Box::pin(async move {
327 let row: (i64,) = sqlx::query_as("SELECT COUNT(*) FROM tasks")
328 .fetch_one(&self.pool)
329 .await
330 .map_err(to_a2a_error)?;
331 #[allow(clippy::cast_sign_loss)]
332 Ok(row.0 as u64)
333 })
334 }
335}
336
#[cfg(test)]
mod tests {
    use super::*;
    use a2a_protocol_types::task::{ContextId, Task, TaskId, TaskState, TaskStatus};

    /// Creates a fresh, empty in-memory store for a single test.
    async fn make_store() -> SqliteTaskStore {
        SqliteTaskStore::new("sqlite::memory:")
            .await
            .expect("failed to create in-memory store")
    }

    /// Builds a minimal task with no history, artifacts, or metadata.
    fn make_task(id: &str, ctx: &str, state: TaskState) -> Task {
        Task {
            id: TaskId::new(id),
            context_id: ContextId::new(ctx),
            status: TaskStatus::new(state),
            history: None,
            artifacts: None,
            metadata: None,
        }
    }

    #[tokio::test]
    async fn save_and_get_round_trip() {
        let store = make_store().await;
        let task = make_task("t1", "ctx1", TaskState::Submitted);
        store.save(&task).await.expect("save should succeed");

        let retrieved = store
            .get(&TaskId::new("t1"))
            .await
            .expect("get should succeed");
        let retrieved = retrieved.expect("task should exist after save");
        assert_eq!(retrieved.id, TaskId::new("t1"), "task id should match");
        assert_eq!(
            retrieved.context_id,
            ContextId::new("ctx1"),
            "context_id should match"
        );
        assert_eq!(
            retrieved.status.state,
            TaskState::Submitted,
            "state should match"
        );
    }

    #[tokio::test]
    async fn get_returns_none_for_missing_task() {
        let store = make_store().await;
        let result = store
            .get(&TaskId::new("nonexistent"))
            .await
            .expect("get should succeed");
        assert!(
            result.is_none(),
            "get should return None for a missing task"
        );
    }

    #[tokio::test]
    async fn save_overwrites_existing_task() {
        let store = make_store().await;
        let task1 = make_task("t1", "ctx1", TaskState::Submitted);
        store.save(&task1).await.expect("first save should succeed");

        let task2 = make_task("t1", "ctx1", TaskState::Working);
        store
            .save(&task2)
            .await
            .expect("second save should succeed");

        let retrieved = store.get(&TaskId::new("t1")).await.unwrap().unwrap();
        assert_eq!(
            retrieved.status.state,
            TaskState::Working,
            "state should be updated after overwrite"
        );
    }

    #[tokio::test]
    async fn insert_if_absent_returns_true_for_new_task() {
        let store = make_store().await;
        let task = make_task("t1", "ctx1", TaskState::Submitted);
        let inserted = store
            .insert_if_absent(&task)
            .await
            .expect("insert_if_absent should succeed");
        assert!(
            inserted,
            "insert_if_absent should return true for a new task"
        );

        let retrieved = store.get(&TaskId::new("t1")).await.unwrap();
        assert!(
            retrieved.is_some(),
            "task inserted via insert_if_absent should be retrievable"
        );
    }

    #[tokio::test]
    async fn insert_if_absent_returns_false_for_existing_task() {
        let store = make_store().await;
        let task = make_task("t1", "ctx1", TaskState::Submitted);
        store.save(&task).await.unwrap();

        let duplicate = make_task("t1", "ctx1", TaskState::Working);
        let inserted = store
            .insert_if_absent(&duplicate)
            .await
            .expect("insert_if_absent should succeed");
        assert!(
            !inserted,
            "insert_if_absent should return false for an existing task"
        );

        let retrieved = store.get(&TaskId::new("t1")).await.unwrap().unwrap();
        assert_eq!(
            retrieved.status.state,
            TaskState::Submitted,
            "original state should be preserved"
        );
    }

    #[tokio::test]
    async fn delete_removes_task() {
        let store = make_store().await;
        store
            .save(&make_task("t1", "ctx1", TaskState::Submitted))
            .await
            .unwrap();

        store
            .delete(&TaskId::new("t1"))
            .await
            .expect("delete should succeed");

        let result = store.get(&TaskId::new("t1")).await.unwrap();
        assert!(result.is_none(), "task should be gone after delete");
    }

    #[tokio::test]
    async fn delete_nonexistent_is_ok() {
        let store = make_store().await;
        let result = store.delete(&TaskId::new("nonexistent")).await;
        assert!(
            result.is_ok(),
            "deleting a nonexistent task should not error"
        );
    }

    #[tokio::test]
    async fn count_tracks_inserts_and_deletes() {
        let store = make_store().await;
        assert_eq!(
            store.count().await.unwrap(),
            0,
            "empty store should have count 0"
        );

        store
            .save(&make_task("t1", "ctx1", TaskState::Submitted))
            .await
            .unwrap();
        store
            .save(&make_task("t2", "ctx1", TaskState::Working))
            .await
            .unwrap();
        assert_eq!(
            store.count().await.unwrap(),
            2,
            "count should be 2 after two saves"
        );

        store.delete(&TaskId::new("t1")).await.unwrap();
        assert_eq!(
            store.count().await.unwrap(),
            1,
            "count should be 1 after one delete"
        );
    }

    #[tokio::test]
    async fn list_all_tasks() {
        let store = make_store().await;
        store
            .save(&make_task("t1", "ctx1", TaskState::Submitted))
            .await
            .unwrap();
        store
            .save(&make_task("t2", "ctx2", TaskState::Working))
            .await
            .unwrap();

        let params = ListTasksParams::default();
        let response = store.list(&params).await.expect("list should succeed");
        assert_eq!(response.tasks.len(), 2, "list should return all tasks");
    }

    #[tokio::test]
    async fn list_filter_by_context_id() {
        let store = make_store().await;
        store
            .save(&make_task("t1", "ctx-a", TaskState::Submitted))
            .await
            .unwrap();
        store
            .save(&make_task("t2", "ctx-b", TaskState::Submitted))
            .await
            .unwrap();
        store
            .save(&make_task("t3", "ctx-a", TaskState::Working))
            .await
            .unwrap();

        let params = ListTasksParams {
            context_id: Some("ctx-a".to_string()),
            ..Default::default()
        };
        let response = store.list(&params).await.unwrap();
        assert_eq!(
            response.tasks.len(),
            2,
            "should return only tasks with context_id ctx-a"
        );
    }

    #[tokio::test]
    async fn list_filter_by_status() {
        let store = make_store().await;
        store
            .save(&make_task("t1", "ctx1", TaskState::Submitted))
            .await
            .unwrap();
        store
            .save(&make_task("t2", "ctx1", TaskState::Working))
            .await
            .unwrap();
        store
            .save(&make_task("t3", "ctx1", TaskState::Working))
            .await
            .unwrap();

        let params = ListTasksParams {
            status: Some(TaskState::Working),
            ..Default::default()
        };
        let response = store.list(&params).await.unwrap();
        assert_eq!(response.tasks.len(), 2, "should return only Working tasks");
    }

    /// Exercises several bound placeholders at once (context, state, limit).
    #[tokio::test]
    async fn list_filter_by_context_id_and_status() {
        let store = make_store().await;
        store
            .save(&make_task("t1", "ctx-a", TaskState::Submitted))
            .await
            .unwrap();
        store
            .save(&make_task("t2", "ctx-a", TaskState::Working))
            .await
            .unwrap();
        store
            .save(&make_task("t3", "ctx-b", TaskState::Working))
            .await
            .unwrap();

        let params = ListTasksParams {
            context_id: Some("ctx-a".to_string()),
            status: Some(TaskState::Working),
            ..Default::default()
        };
        let response = store.list(&params).await.unwrap();
        assert_eq!(
            response.tasks.len(),
            1,
            "should return only Working tasks in ctx-a"
        );
        assert_eq!(
            response.tasks[0].id,
            TaskId::new("t2"),
            "the matching task should be t2"
        );
    }

    #[tokio::test]
    async fn list_pagination() {
        let store = make_store().await;
        // Zero-padded ids so lexical order matches insertion order.
        for i in 0..5 {
            store
                .save(&make_task(
                    &format!("task-{i:03}"),
                    "ctx1",
                    TaskState::Submitted,
                ))
                .await
                .unwrap();
        }

        let params = ListTasksParams {
            page_size: Some(2),
            ..Default::default()
        };
        let response = store.list(&params).await.unwrap();
        assert_eq!(response.tasks.len(), 2, "first page should have 2 tasks");
        assert_eq!(
            response.tasks[0].id,
            TaskId::new("task-000"),
            "first page should start at the smallest id"
        );
        assert!(
            !response.next_page_token.is_empty(),
            "should have a next page token"
        );

        let params2 = ListTasksParams {
            page_size: Some(2),
            page_token: Some(response.next_page_token),
            ..Default::default()
        };
        let response2 = store.list(&params2).await.unwrap();
        assert_eq!(response2.tasks.len(), 2, "second page should have 2 tasks");
        assert_eq!(
            response2.tasks[0].id,
            TaskId::new("task-002"),
            "second page should continue after the first page's cursor"
        );
        assert!(
            !response2.next_page_token.is_empty(),
            "should still have a next page token"
        );

        let params3 = ListTasksParams {
            page_size: Some(2),
            page_token: Some(response2.next_page_token),
            ..Default::default()
        };
        let response3 = store.list(&params3).await.unwrap();
        assert_eq!(response3.tasks.len(), 1, "last page should have 1 task");
        assert!(
            response3.next_page_token.is_empty(),
            "last page should have no next page token"
        );
    }

    #[test]
    fn to_a2a_error_formats_message() {
        let sqlite_err = sqlx::Error::RowNotFound;
        let a2a_err = to_a2a_error(sqlite_err);
        let msg = format!("{a2a_err}");
        assert!(
            msg.contains("sqlite error"),
            "error message should contain 'sqlite error': {msg}"
        );
    }

    #[tokio::test]
    async fn with_migrations_creates_store() {
        let result = SqliteTaskStore::with_migrations("sqlite::memory:").await;
        assert!(
            result.is_ok(),
            "with_migrations should succeed on a fresh database"
        );
        let store = result.unwrap();
        let count = store.count().await.unwrap();
        assert_eq!(count, 0, "freshly migrated store should be empty");
    }

    #[tokio::test]
    async fn list_empty_store() {
        let store = make_store().await;
        let params = ListTasksParams::default();
        let response = store.list(&params).await.unwrap();
        assert!(
            response.tasks.is_empty(),
            "list on empty store should return no tasks"
        );
        assert!(
            response.next_page_token.is_empty(),
            "no pagination token for empty results"
        );
    }
}