a2a_protocol_server/store/
sqlite_store.rs1use std::future::Future;
22use std::pin::Pin;
23
24use a2a_protocol_types::error::{A2aError, A2aResult};
25use a2a_protocol_types::params::ListTasksParams;
26use a2a_protocol_types::responses::TaskListResponse;
27use a2a_protocol_types::task::{Task, TaskId};
28use sqlx::sqlite::{SqlitePool, SqlitePoolOptions};
29
30use super::task_store::TaskStore;
31
32#[derive(Debug, Clone)]
51pub struct SqliteTaskStore {
52 pool: SqlitePool,
53}
54
55impl SqliteTaskStore {
56 pub async fn new(url: &str) -> Result<Self, sqlx::Error> {
62 let pool = sqlite_pool(url).await?;
63 Self::from_pool(pool).await
64 }
65
66 pub async fn with_migrations(url: &str) -> Result<Self, sqlx::Error> {
76 let pool = sqlite_pool(url).await?;
77
78 let runner = super::migration::MigrationRunner::new(pool.clone());
79 runner.run_pending().await?;
80
81 Ok(Self { pool })
82 }
83
84 pub async fn from_pool(pool: SqlitePool) -> Result<Self, sqlx::Error> {
90 sqlx::query(
91 "CREATE TABLE IF NOT EXISTS tasks (
92 id TEXT PRIMARY KEY,
93 context_id TEXT NOT NULL,
94 state TEXT NOT NULL,
95 data TEXT NOT NULL,
96 updated_at TEXT NOT NULL DEFAULT (datetime('now'))
97 )",
98 )
99 .execute(&pool)
100 .await?;
101
102 sqlx::query("CREATE INDEX IF NOT EXISTS idx_tasks_context_id ON tasks(context_id)")
103 .execute(&pool)
104 .await?;
105
106 sqlx::query("CREATE INDEX IF NOT EXISTS idx_tasks_state ON tasks(state)")
107 .execute(&pool)
108 .await?;
109
110 Ok(Self { pool })
111 }
112}
113
114async fn sqlite_pool(url: &str) -> Result<SqlitePool, sqlx::Error> {
119 sqlite_pool_with_size(url, 8).await
120}
121
122async fn sqlite_pool_with_size(url: &str, max_connections: u32) -> Result<SqlitePool, sqlx::Error> {
124 use sqlx::sqlite::SqliteConnectOptions;
125 use std::str::FromStr;
126
127 let opts = SqliteConnectOptions::from_str(url)?
128 .pragma("journal_mode", "WAL")
129 .pragma("busy_timeout", "5000")
130 .pragma("synchronous", "NORMAL")
131 .pragma("foreign_keys", "ON")
132 .create_if_missing(true);
133
134 SqlitePoolOptions::new()
135 .max_connections(max_connections)
136 .connect_with(opts)
137 .await
138}
139
140#[allow(clippy::needless_pass_by_value)]
142fn to_a2a_error(e: sqlx::Error) -> A2aError {
143 A2aError::internal(format!("sqlite error: {e}"))
144}
145
146#[allow(clippy::manual_async_fn)]
147impl TaskStore for SqliteTaskStore {
148 fn save<'a>(&'a self, task: Task) -> Pin<Box<dyn Future<Output = A2aResult<()>> + Send + 'a>> {
149 Box::pin(async move {
150 let id = task.id.0.as_str();
151 let context_id = task.context_id.0.as_str();
152 let state = task.status.state.to_string();
153 let data = serde_json::to_string(&task)
154 .map_err(|e| A2aError::internal(format!("failed to serialize task: {e}")))?;
155
156 sqlx::query(
157 "INSERT INTO tasks (id, context_id, state, data, updated_at)
158 VALUES (?1, ?2, ?3, ?4, datetime('now'))
159 ON CONFLICT(id) DO UPDATE SET
160 context_id = excluded.context_id,
161 state = excluded.state,
162 data = excluded.data,
163 updated_at = datetime('now')",
164 )
165 .bind(id)
166 .bind(context_id)
167 .bind(&state)
168 .bind(&data)
169 .execute(&self.pool)
170 .await
171 .map_err(to_a2a_error)?;
172
173 Ok(())
174 })
175 }
176
177 fn get<'a>(
178 &'a self,
179 id: &'a TaskId,
180 ) -> Pin<Box<dyn Future<Output = A2aResult<Option<Task>>> + Send + 'a>> {
181 Box::pin(async move {
182 let row: Option<(String,)> = sqlx::query_as("SELECT data FROM tasks WHERE id = ?1")
183 .bind(id.0.as_str())
184 .fetch_optional(&self.pool)
185 .await
186 .map_err(to_a2a_error)?;
187
188 match row {
189 Some((data,)) => {
190 let task: Task = serde_json::from_str(&data).map_err(|e| {
191 A2aError::internal(format!("failed to deserialize task: {e}"))
192 })?;
193 Ok(Some(task))
194 }
195 None => Ok(None),
196 }
197 })
198 }
199
200 fn list<'a>(
201 &'a self,
202 params: &'a ListTasksParams,
203 ) -> Pin<Box<dyn Future<Output = A2aResult<TaskListResponse>> + Send + 'a>> {
204 Box::pin(async move {
205 let mut conditions = Vec::new();
207 let mut bind_values: Vec<String> = Vec::new();
208
209 if let Some(ref ctx) = params.context_id {
210 conditions.push(format!("context_id = ?{}", bind_values.len() + 1));
211 bind_values.push(ctx.clone());
212 }
213 if let Some(ref status) = params.status {
214 conditions.push(format!("state = ?{}", bind_values.len() + 1));
215 bind_values.push(status.to_string());
216 }
217 if let Some(ref token) = params.page_token {
218 conditions.push(format!("id > ?{}", bind_values.len() + 1));
219 bind_values.push(token.clone());
220 }
221
222 let where_clause = if conditions.is_empty() {
223 String::new()
224 } else {
225 format!("WHERE {}", conditions.join(" AND "))
226 };
227
228 let page_size = match params.page_size {
229 Some(0) | None => 50_u32,
230 Some(n) => n.min(1000),
231 };
232
233 let limit = page_size + 1;
237 let limit_param = bind_values.len() + 1;
238 let sql = format!(
239 "SELECT data FROM tasks {where_clause} ORDER BY id ASC LIMIT ?{limit_param}"
240 );
241
242 let mut query = sqlx::query_as::<_, (String,)>(&sql);
243 for val in &bind_values {
244 query = query.bind(val);
245 }
246 query = query.bind(limit);
247
248 let rows: Vec<(String,)> = query.fetch_all(&self.pool).await.map_err(to_a2a_error)?;
249
250 let mut tasks: Vec<Task> = rows
251 .into_iter()
252 .map(|(data,)| {
253 serde_json::from_str(&data)
254 .map_err(|e| A2aError::internal(format!("deserialize: {e}")))
255 })
256 .collect::<A2aResult<Vec<_>>>()?;
257
258 let next_page_token = if tasks.len() > page_size as usize {
259 tasks.truncate(page_size as usize);
260 tasks.last().map(|t| t.id.0.clone())
261 } else {
262 None
263 };
264
265 let mut response = TaskListResponse::new(tasks);
266 response.next_page_token = next_page_token;
267 Ok(response)
268 })
269 }
270
271 fn insert_if_absent<'a>(
272 &'a self,
273 task: Task,
274 ) -> Pin<Box<dyn Future<Output = A2aResult<bool>> + Send + 'a>> {
275 Box::pin(async move {
276 let id = task.id.0.as_str();
277 let context_id = task.context_id.0.as_str();
278 let state = task.status.state.to_string();
279 let data = serde_json::to_string(&task)
280 .map_err(|e| A2aError::internal(format!("failed to serialize task: {e}")))?;
281
282 let result = sqlx::query(
283 "INSERT OR IGNORE INTO tasks (id, context_id, state, data, updated_at)
284 VALUES (?1, ?2, ?3, ?4, datetime('now'))",
285 )
286 .bind(id)
287 .bind(context_id)
288 .bind(&state)
289 .bind(&data)
290 .execute(&self.pool)
291 .await
292 .map_err(to_a2a_error)?;
293
294 Ok(result.rows_affected() > 0)
295 })
296 }
297
298 fn delete<'a>(
299 &'a self,
300 id: &'a TaskId,
301 ) -> Pin<Box<dyn Future<Output = A2aResult<()>> + Send + 'a>> {
302 Box::pin(async move {
303 sqlx::query("DELETE FROM tasks WHERE id = ?1")
304 .bind(id.0.as_str())
305 .execute(&self.pool)
306 .await
307 .map_err(to_a2a_error)?;
308 Ok(())
309 })
310 }
311
312 fn count<'a>(&'a self) -> Pin<Box<dyn Future<Output = A2aResult<u64>> + Send + 'a>> {
313 Box::pin(async move {
314 let row: (i64,) = sqlx::query_as("SELECT COUNT(*) FROM tasks")
315 .fetch_one(&self.pool)
316 .await
317 .map_err(to_a2a_error)?;
318 #[allow(clippy::cast_sign_loss)]
319 Ok(row.0 as u64)
320 })
321 }
322}
323
#[cfg(test)]
mod tests {
    //! Behavioral tests for [`SqliteTaskStore`] against an in-memory database.

    use super::*;
    use a2a_protocol_types::task::{ContextId, Task, TaskId, TaskState, TaskStatus};

    /// Creates a fresh store backed by an in-memory SQLite database.
    async fn make_store() -> SqliteTaskStore {
        SqliteTaskStore::new("sqlite::memory:")
            .await
            .expect("failed to create in-memory store")
    }

    /// Builds a minimal task with the given id, context id and state.
    fn make_task(id: &str, ctx: &str, state: TaskState) -> Task {
        Task {
            id: TaskId::new(id),
            context_id: ContextId::new(ctx),
            status: TaskStatus::new(state),
            history: None,
            artifacts: None,
            metadata: None,
        }
    }

    #[tokio::test]
    async fn save_and_get_round_trip() {
        let store = make_store().await;
        let task = make_task("t1", "ctx1", TaskState::Submitted);
        store.save(task.clone()).await.expect("save should succeed");

        let retrieved = store
            .get(&TaskId::new("t1"))
            .await
            .expect("get should succeed");
        let retrieved = retrieved.expect("task should exist after save");
        assert_eq!(retrieved.id, TaskId::new("t1"), "task id should match");
        assert_eq!(
            retrieved.context_id,
            ContextId::new("ctx1"),
            "context_id should match"
        );
        assert_eq!(
            retrieved.status.state,
            TaskState::Submitted,
            "state should match"
        );
    }

    #[tokio::test]
    async fn get_returns_none_for_missing_task() {
        let store = make_store().await;
        let result = store
            .get(&TaskId::new("nonexistent"))
            .await
            .expect("get should succeed");
        assert!(
            result.is_none(),
            "get should return None for a missing task"
        );
    }

    #[tokio::test]
    async fn save_overwrites_existing_task() {
        let store = make_store().await;
        let task1 = make_task("t1", "ctx1", TaskState::Submitted);
        store.save(task1).await.expect("first save should succeed");

        let task2 = make_task("t1", "ctx1", TaskState::Working);
        store.save(task2).await.expect("second save should succeed");

        let retrieved = store.get(&TaskId::new("t1")).await.unwrap().unwrap();
        assert_eq!(
            retrieved.status.state,
            TaskState::Working,
            "state should be updated after overwrite"
        );
    }

    #[tokio::test]
    async fn insert_if_absent_returns_true_for_new_task() {
        let store = make_store().await;
        let task = make_task("t1", "ctx1", TaskState::Submitted);
        let inserted = store
            .insert_if_absent(task)
            .await
            .expect("insert_if_absent should succeed");
        assert!(
            inserted,
            "insert_if_absent should return true for a new task"
        );
    }

    #[tokio::test]
    async fn insert_if_absent_returns_false_for_existing_task() {
        let store = make_store().await;
        let task = make_task("t1", "ctx1", TaskState::Submitted);
        store.save(task.clone()).await.unwrap();

        let duplicate = make_task("t1", "ctx1", TaskState::Working);
        let inserted = store
            .insert_if_absent(duplicate)
            .await
            .expect("insert_if_absent should succeed");
        assert!(
            !inserted,
            "insert_if_absent should return false for an existing task"
        );

        let retrieved = store.get(&TaskId::new("t1")).await.unwrap().unwrap();
        assert_eq!(
            retrieved.status.state,
            TaskState::Submitted,
            "original state should be preserved"
        );
    }

    #[tokio::test]
    async fn delete_removes_task() {
        let store = make_store().await;
        store
            .save(make_task("t1", "ctx1", TaskState::Submitted))
            .await
            .unwrap();

        store
            .delete(&TaskId::new("t1"))
            .await
            .expect("delete should succeed");

        let result = store.get(&TaskId::new("t1")).await.unwrap();
        assert!(result.is_none(), "task should be gone after delete");
    }

    #[tokio::test]
    async fn delete_nonexistent_is_ok() {
        let store = make_store().await;
        let result = store.delete(&TaskId::new("nonexistent")).await;
        assert!(
            result.is_ok(),
            "deleting a nonexistent task should not error"
        );
    }

    #[tokio::test]
    async fn count_tracks_inserts_and_deletes() {
        let store = make_store().await;
        assert_eq!(
            store.count().await.unwrap(),
            0,
            "empty store should have count 0"
        );

        store
            .save(make_task("t1", "ctx1", TaskState::Submitted))
            .await
            .unwrap();
        store
            .save(make_task("t2", "ctx1", TaskState::Working))
            .await
            .unwrap();
        assert_eq!(
            store.count().await.unwrap(),
            2,
            "count should be 2 after two saves"
        );

        store.delete(&TaskId::new("t1")).await.unwrap();
        assert_eq!(
            store.count().await.unwrap(),
            1,
            "count should be 1 after one delete"
        );
    }

    #[tokio::test]
    async fn list_all_tasks() {
        let store = make_store().await;
        store
            .save(make_task("t1", "ctx1", TaskState::Submitted))
            .await
            .unwrap();
        store
            .save(make_task("t2", "ctx2", TaskState::Working))
            .await
            .unwrap();

        let params = ListTasksParams::default();
        let response = store.list(&params).await.expect("list should succeed");
        assert_eq!(response.tasks.len(), 2, "list should return all tasks");
    }

    #[tokio::test]
    async fn list_filter_by_context_id() {
        let store = make_store().await;
        store
            .save(make_task("t1", "ctx-a", TaskState::Submitted))
            .await
            .unwrap();
        store
            .save(make_task("t2", "ctx-b", TaskState::Submitted))
            .await
            .unwrap();
        store
            .save(make_task("t3", "ctx-a", TaskState::Working))
            .await
            .unwrap();

        let params = ListTasksParams {
            context_id: Some("ctx-a".to_string()),
            ..Default::default()
        };
        let response = store.list(&params).await.unwrap();
        assert_eq!(
            response.tasks.len(),
            2,
            "should return only tasks with context_id ctx-a"
        );
    }

    #[tokio::test]
    async fn list_filter_by_status() {
        let store = make_store().await;
        store
            .save(make_task("t1", "ctx1", TaskState::Submitted))
            .await
            .unwrap();
        store
            .save(make_task("t2", "ctx1", TaskState::Working))
            .await
            .unwrap();
        store
            .save(make_task("t3", "ctx1", TaskState::Working))
            .await
            .unwrap();

        let params = ListTasksParams {
            status: Some(TaskState::Working),
            ..Default::default()
        };
        let response = store.list(&params).await.unwrap();
        assert_eq!(response.tasks.len(), 2, "should return only Working tasks");
    }

    #[tokio::test]
    async fn list_pagination() {
        let store = make_store().await;
        for i in 0..5 {
            // Zero-padded ids so lexicographic order matches numeric order.
            store
                .save(make_task(
                    &format!("task-{i:03}"),
                    "ctx1",
                    TaskState::Submitted,
                ))
                .await
                .unwrap();
        }

        let params = ListTasksParams {
            page_size: Some(2),
            ..Default::default()
        };
        let response = store.list(&params).await.unwrap();
        assert_eq!(response.tasks.len(), 2, "first page should have 2 tasks");
        assert!(
            response.next_page_token.is_some(),
            "should have a next page token"
        );

        let params2 = ListTasksParams {
            page_size: Some(2),
            page_token: response.next_page_token,
            ..Default::default()
        };
        let response2 = store.list(&params2).await.unwrap();
        assert_eq!(response2.tasks.len(), 2, "second page should have 2 tasks");
        assert!(
            response2.next_page_token.is_some(),
            "should still have a next page token"
        );

        let params3 = ListTasksParams {
            page_size: Some(2),
            page_token: response2.next_page_token,
            ..Default::default()
        };
        let response3 = store.list(&params3).await.unwrap();
        assert_eq!(response3.tasks.len(), 1, "last page should have 1 task");
        assert!(
            response3.next_page_token.is_none(),
            "last page should have no next page token"
        );
    }

    #[tokio::test]
    async fn list_filters_combine_with_pagination() {
        let store = make_store().await;
        for (id, ctx) in [("a1", "ctx-a"), ("b1", "ctx-b"), ("a2", "ctx-a"), ("a3", "ctx-a")] {
            store
                .save(make_task(id, ctx, TaskState::Working))
                .await
                .unwrap();
        }

        let first = store
            .list(&ListTasksParams {
                context_id: Some("ctx-a".to_string()),
                page_size: Some(2),
                ..Default::default()
            })
            .await
            .unwrap();
        let first_ids: Vec<TaskId> = first.tasks.iter().map(|t| t.id.clone()).collect();
        assert_eq!(
            first_ids,
            vec![TaskId::new("a1"), TaskId::new("a2")],
            "first filtered page should hold the two lowest ctx-a ids"
        );
        assert_eq!(
            first.next_page_token.as_deref(),
            Some("a2"),
            "token should be the last id on the page"
        );

        let second = store
            .list(&ListTasksParams {
                context_id: Some("ctx-a".to_string()),
                page_size: Some(2),
                page_token: first.next_page_token,
                ..Default::default()
            })
            .await
            .unwrap();
        let second_ids: Vec<TaskId> = second.tasks.iter().map(|t| t.id.clone()).collect();
        assert_eq!(
            second_ids,
            vec![TaskId::new("a3")],
            "second page should skip the ctx-b task"
        );
        assert!(
            second.next_page_token.is_none(),
            "filtered results are exhausted"
        );
    }

    #[test]
    fn to_a2a_error_formats_message() {
        let sqlite_err = sqlx::Error::RowNotFound;
        let a2a_err = to_a2a_error(sqlite_err);
        let msg = format!("{a2a_err}");
        assert!(
            msg.contains("sqlite error"),
            "error message should contain 'sqlite error': {msg}"
        );
    }

    #[tokio::test]
    async fn with_migrations_creates_store() {
        let result = SqliteTaskStore::with_migrations("sqlite::memory:").await;
        assert!(
            result.is_ok(),
            "with_migrations should succeed on a fresh database"
        );
        let store = result.unwrap();
        let count = store.count().await.unwrap();
        assert_eq!(count, 0, "freshly migrated store should be empty");
    }

    #[tokio::test]
    async fn list_empty_store() {
        let store = make_store().await;
        let params = ListTasksParams::default();
        let response = store.list(&params).await.unwrap();
        assert!(
            response.tasks.is_empty(),
            "list on empty store should return no tasks"
        );
        assert!(
            response.next_page_token.is_none(),
            "no pagination token for empty results"
        );
    }
}