use std::future::Future;
use std::pin::Pin;
use a2a_protocol_types::error::{A2aError, A2aResult};
use a2a_protocol_types::params::ListTasksParams;
use a2a_protocol_types::responses::TaskListResponse;
use a2a_protocol_types::task::{Task, TaskId};
use sqlx::sqlite::{SqlitePool, SqlitePoolOptions};
use super::task_store::TaskStore;
#[derive(Debug, Clone)]
pub struct SqliteTaskStore {
pool: SqlitePool,
}
impl SqliteTaskStore {
pub async fn new(url: &str) -> Result<Self, sqlx::Error> {
let pool = sqlite_pool(url).await?;
Self::from_pool(pool).await
}
pub async fn with_migrations(url: &str) -> Result<Self, sqlx::Error> {
let pool = sqlite_pool(url).await?;
let runner = super::migration::MigrationRunner::new(pool.clone());
runner.run_pending().await?;
Ok(Self { pool })
}
pub async fn from_pool(pool: SqlitePool) -> Result<Self, sqlx::Error> {
sqlx::query(
"CREATE TABLE IF NOT EXISTS tasks (
id TEXT PRIMARY KEY,
context_id TEXT NOT NULL,
state TEXT NOT NULL,
data TEXT NOT NULL,
updated_at TEXT NOT NULL DEFAULT (datetime('now')),
created_at TEXT NOT NULL DEFAULT (datetime('now'))
)",
)
.execute(&pool)
.await?;
sqlx::query("CREATE INDEX IF NOT EXISTS idx_tasks_context_id ON tasks(context_id)")
.execute(&pool)
.await?;
sqlx::query("CREATE INDEX IF NOT EXISTS idx_tasks_state ON tasks(state)")
.execute(&pool)
.await?;
sqlx::query(
"CREATE INDEX IF NOT EXISTS idx_tasks_context_id_state ON tasks(context_id, state)",
)
.execute(&pool)
.await?;
Ok(Self { pool })
}
}
async fn sqlite_pool(url: &str) -> Result<SqlitePool, sqlx::Error> {
sqlite_pool_with_size(url, 8).await
}
async fn sqlite_pool_with_size(url: &str, max_connections: u32) -> Result<SqlitePool, sqlx::Error> {
use sqlx::sqlite::SqliteConnectOptions;
use std::str::FromStr;
let opts = SqliteConnectOptions::from_str(url)?
.pragma("journal_mode", "WAL")
.pragma("busy_timeout", "5000")
.pragma("synchronous", "NORMAL")
.pragma("foreign_keys", "ON")
.create_if_missing(true);
SqlitePoolOptions::new()
.max_connections(max_connections)
.connect_with(opts)
.await
}
#[allow(clippy::needless_pass_by_value)]
fn to_a2a_error(e: sqlx::Error) -> A2aError {
A2aError::internal(format!("sqlite error: {e}"))
}
#[allow(clippy::manual_async_fn)]
impl TaskStore for SqliteTaskStore {
fn save<'a>(&'a self, task: Task) -> Pin<Box<dyn Future<Output = A2aResult<()>> + Send + 'a>> {
Box::pin(async move {
let id = task.id.0.as_str();
let context_id = task.context_id.0.as_str();
let state = task.status.state.to_string();
let data = serde_json::to_string(&task)
.map_err(|e| A2aError::internal(format!("failed to serialize task: {e}")))?;
sqlx::query(
"INSERT INTO tasks (id, context_id, state, data, updated_at)
VALUES (?1, ?2, ?3, ?4, datetime('now'))
ON CONFLICT(id) DO UPDATE SET
context_id = excluded.context_id,
state = excluded.state,
data = excluded.data,
updated_at = datetime('now')",
)
.bind(id)
.bind(context_id)
.bind(&state)
.bind(&data)
.execute(&self.pool)
.await
.map_err(to_a2a_error)?;
Ok(())
})
}
fn get<'a>(
&'a self,
id: &'a TaskId,
) -> Pin<Box<dyn Future<Output = A2aResult<Option<Task>>> + Send + 'a>> {
Box::pin(async move {
let row: Option<(String,)> = sqlx::query_as("SELECT data FROM tasks WHERE id = ?1")
.bind(id.0.as_str())
.fetch_optional(&self.pool)
.await
.map_err(to_a2a_error)?;
match row {
Some((data,)) => {
let task: Task = serde_json::from_str(&data).map_err(|e| {
A2aError::internal(format!("failed to deserialize task: {e}"))
})?;
Ok(Some(task))
}
None => Ok(None),
}
})
}
fn list<'a>(
&'a self,
params: &'a ListTasksParams,
) -> Pin<Box<dyn Future<Output = A2aResult<TaskListResponse>> + Send + 'a>> {
Box::pin(async move {
let mut conditions = Vec::new();
let mut bind_values: Vec<String> = Vec::new();
if let Some(ref ctx) = params.context_id {
conditions.push(format!("context_id = ?{}", bind_values.len() + 1));
bind_values.push(ctx.clone());
}
if let Some(ref status) = params.status {
conditions.push(format!("state = ?{}", bind_values.len() + 1));
bind_values.push(status.to_string());
}
if let Some(ref token) = params.page_token {
conditions.push(format!("id > ?{}", bind_values.len() + 1));
bind_values.push(token.clone());
}
let where_clause = if conditions.is_empty() {
String::new()
} else {
format!("WHERE {}", conditions.join(" AND "))
};
let page_size = match params.page_size {
Some(0) | None => 50_u32,
Some(n) => n.min(1000),
};
let limit = page_size + 1;
let limit_param = bind_values.len() + 1;
let sql = format!(
"SELECT data FROM tasks {where_clause} ORDER BY id ASC LIMIT ?{limit_param}"
);
let mut query = sqlx::query_as::<_, (String,)>(&sql);
for val in &bind_values {
query = query.bind(val);
}
query = query.bind(limit);
let rows: Vec<(String,)> = query.fetch_all(&self.pool).await.map_err(to_a2a_error)?;
let mut tasks: Vec<Task> = rows
.into_iter()
.map(|(data,)| {
serde_json::from_str(&data)
.map_err(|e| A2aError::internal(format!("deserialize: {e}")))
})
.collect::<A2aResult<Vec<_>>>()?;
let next_page_token = if tasks.len() > page_size as usize {
tasks.truncate(page_size as usize);
tasks.last().map(|t| t.id.0.clone()).unwrap_or_default()
} else {
String::new()
};
#[allow(clippy::cast_possible_truncation)]
let page_len = tasks.len() as u32;
let mut response = TaskListResponse::new(tasks);
response.next_page_token = next_page_token;
response.page_size = page_len;
Ok(response)
})
}
fn insert_if_absent<'a>(
&'a self,
task: Task,
) -> Pin<Box<dyn Future<Output = A2aResult<bool>> + Send + 'a>> {
Box::pin(async move {
let id = task.id.0.as_str();
let context_id = task.context_id.0.as_str();
let state = task.status.state.to_string();
let data = serde_json::to_string(&task)
.map_err(|e| A2aError::internal(format!("failed to serialize task: {e}")))?;
let result = sqlx::query(
"INSERT OR IGNORE INTO tasks (id, context_id, state, data, updated_at)
VALUES (?1, ?2, ?3, ?4, datetime('now'))",
)
.bind(id)
.bind(context_id)
.bind(&state)
.bind(&data)
.execute(&self.pool)
.await
.map_err(to_a2a_error)?;
Ok(result.rows_affected() > 0)
})
}
fn delete<'a>(
&'a self,
id: &'a TaskId,
) -> Pin<Box<dyn Future<Output = A2aResult<()>> + Send + 'a>> {
Box::pin(async move {
sqlx::query("DELETE FROM tasks WHERE id = ?1")
.bind(id.0.as_str())
.execute(&self.pool)
.await
.map_err(to_a2a_error)?;
Ok(())
})
}
fn count<'a>(&'a self) -> Pin<Box<dyn Future<Output = A2aResult<u64>> + Send + 'a>> {
Box::pin(async move {
let row: (i64,) = sqlx::query_as("SELECT COUNT(*) FROM tasks")
.fetch_one(&self.pool)
.await
.map_err(to_a2a_error)?;
#[allow(clippy::cast_sign_loss)]
Ok(row.0 as u64)
})
}
}
#[cfg(test)]
mod tests {
use super::*;
use a2a_protocol_types::task::{ContextId, Task, TaskId, TaskState, TaskStatus};
async fn make_store() -> SqliteTaskStore {
SqliteTaskStore::new("sqlite::memory:")
.await
.expect("failed to create in-memory store")
}
fn make_task(id: &str, ctx: &str, state: TaskState) -> Task {
Task {
id: TaskId::new(id),
context_id: ContextId::new(ctx),
status: TaskStatus::new(state),
history: None,
artifacts: None,
metadata: None,
}
}
#[tokio::test]
async fn save_and_get_round_trip() {
let store = make_store().await;
let task = make_task("t1", "ctx1", TaskState::Submitted);
store.save(task.clone()).await.expect("save should succeed");
let retrieved = store
.get(&TaskId::new("t1"))
.await
.expect("get should succeed");
let retrieved = retrieved.expect("task should exist after save");
assert_eq!(retrieved.id, TaskId::new("t1"), "task id should match");
assert_eq!(
retrieved.context_id,
ContextId::new("ctx1"),
"context_id should match"
);
assert_eq!(
retrieved.status.state,
TaskState::Submitted,
"state should match"
);
}
#[tokio::test]
async fn get_returns_none_for_missing_task() {
let store = make_store().await;
let result = store
.get(&TaskId::new("nonexistent"))
.await
.expect("get should succeed");
assert!(
result.is_none(),
"get should return None for a missing task"
);
}
#[tokio::test]
async fn save_overwrites_existing_task() {
let store = make_store().await;
let task1 = make_task("t1", "ctx1", TaskState::Submitted);
store.save(task1).await.expect("first save should succeed");
let task2 = make_task("t1", "ctx1", TaskState::Working);
store.save(task2).await.expect("second save should succeed");
let retrieved = store.get(&TaskId::new("t1")).await.unwrap().unwrap();
assert_eq!(
retrieved.status.state,
TaskState::Working,
"state should be updated after overwrite"
);
}
#[tokio::test]
async fn insert_if_absent_returns_true_for_new_task() {
let store = make_store().await;
let task = make_task("t1", "ctx1", TaskState::Submitted);
let inserted = store
.insert_if_absent(task)
.await
.expect("insert_if_absent should succeed");
assert!(
inserted,
"insert_if_absent should return true for a new task"
);
}
#[tokio::test]
async fn insert_if_absent_returns_false_for_existing_task() {
let store = make_store().await;
let task = make_task("t1", "ctx1", TaskState::Submitted);
store.save(task.clone()).await.unwrap();
let duplicate = make_task("t1", "ctx1", TaskState::Working);
let inserted = store
.insert_if_absent(duplicate)
.await
.expect("insert_if_absent should succeed");
assert!(
!inserted,
"insert_if_absent should return false for an existing task"
);
let retrieved = store.get(&TaskId::new("t1")).await.unwrap().unwrap();
assert_eq!(
retrieved.status.state,
TaskState::Submitted,
"original state should be preserved"
);
}
#[tokio::test]
async fn delete_removes_task() {
let store = make_store().await;
store
.save(make_task("t1", "ctx1", TaskState::Submitted))
.await
.unwrap();
store
.delete(&TaskId::new("t1"))
.await
.expect("delete should succeed");
let result = store.get(&TaskId::new("t1")).await.unwrap();
assert!(result.is_none(), "task should be gone after delete");
}
#[tokio::test]
async fn delete_nonexistent_is_ok() {
let store = make_store().await;
let result = store.delete(&TaskId::new("nonexistent")).await;
assert!(
result.is_ok(),
"deleting a nonexistent task should not error"
);
}
#[tokio::test]
async fn count_tracks_inserts_and_deletes() {
let store = make_store().await;
assert_eq!(
store.count().await.unwrap(),
0,
"empty store should have count 0"
);
store
.save(make_task("t1", "ctx1", TaskState::Submitted))
.await
.unwrap();
store
.save(make_task("t2", "ctx1", TaskState::Working))
.await
.unwrap();
assert_eq!(
store.count().await.unwrap(),
2,
"count should be 2 after two saves"
);
store.delete(&TaskId::new("t1")).await.unwrap();
assert_eq!(
store.count().await.unwrap(),
1,
"count should be 1 after one delete"
);
}
#[tokio::test]
async fn list_all_tasks() {
let store = make_store().await;
store
.save(make_task("t1", "ctx1", TaskState::Submitted))
.await
.unwrap();
store
.save(make_task("t2", "ctx2", TaskState::Working))
.await
.unwrap();
let params = ListTasksParams::default();
let response = store.list(¶ms).await.expect("list should succeed");
assert_eq!(response.tasks.len(), 2, "list should return all tasks");
}
#[tokio::test]
async fn list_filter_by_context_id() {
let store = make_store().await;
store
.save(make_task("t1", "ctx-a", TaskState::Submitted))
.await
.unwrap();
store
.save(make_task("t2", "ctx-b", TaskState::Submitted))
.await
.unwrap();
store
.save(make_task("t3", "ctx-a", TaskState::Working))
.await
.unwrap();
let params = ListTasksParams {
context_id: Some("ctx-a".to_string()),
..Default::default()
};
let response = store.list(¶ms).await.unwrap();
assert_eq!(
response.tasks.len(),
2,
"should return only tasks with context_id ctx-a"
);
}
#[tokio::test]
async fn list_filter_by_status() {
let store = make_store().await;
store
.save(make_task("t1", "ctx1", TaskState::Submitted))
.await
.unwrap();
store
.save(make_task("t2", "ctx1", TaskState::Working))
.await
.unwrap();
store
.save(make_task("t3", "ctx1", TaskState::Working))
.await
.unwrap();
let params = ListTasksParams {
status: Some(TaskState::Working),
..Default::default()
};
let response = store.list(¶ms).await.unwrap();
assert_eq!(response.tasks.len(), 2, "should return only Working tasks");
}
#[tokio::test]
async fn list_pagination() {
let store = make_store().await;
for i in 0..5 {
store
.save(make_task(
&format!("task-{i:03}"),
"ctx1",
TaskState::Submitted,
))
.await
.unwrap();
}
let params = ListTasksParams {
page_size: Some(2),
..Default::default()
};
let response = store.list(¶ms).await.unwrap();
assert_eq!(response.tasks.len(), 2, "first page should have 2 tasks");
assert!(
!response.next_page_token.is_empty(),
"should have a next page token"
);
let params2 = ListTasksParams {
page_size: Some(2),
page_token: Some(response.next_page_token),
..Default::default()
};
let response2 = store.list(¶ms2).await.unwrap();
assert_eq!(response2.tasks.len(), 2, "second page should have 2 tasks");
assert!(
!response2.next_page_token.is_empty(),
"should still have a next page token"
);
let params3 = ListTasksParams {
page_size: Some(2),
page_token: Some(response2.next_page_token),
..Default::default()
};
let response3 = store.list(¶ms3).await.unwrap();
assert_eq!(response3.tasks.len(), 1, "last page should have 1 task");
assert!(
response3.next_page_token.is_empty(),
"last page should have no next page token"
);
}
#[test]
fn to_a2a_error_formats_message() {
let sqlite_err = sqlx::Error::RowNotFound;
let a2a_err = to_a2a_error(sqlite_err);
let msg = format!("{a2a_err}");
assert!(
msg.contains("sqlite error"),
"error message should contain 'sqlite error': {msg}"
);
}
#[tokio::test]
async fn with_migrations_creates_store() {
let result = SqliteTaskStore::with_migrations("sqlite::memory:").await;
assert!(
result.is_ok(),
"with_migrations should succeed on a fresh database"
);
let store = result.unwrap();
let count = store.count().await.unwrap();
assert_eq!(count, 0, "freshly migrated store should be empty");
}
#[tokio::test]
async fn list_empty_store() {
let store = make_store().await;
let params = ListTasksParams::default();
let response = store.list(¶ms).await.unwrap();
assert!(
response.tasks.is_empty(),
"list on empty store should return no tasks"
);
assert!(
response.next_page_token.is_empty(),
"no pagination token for empty results"
);
}
}