a2a_protocol_server/store/
postgres_store.rs1use std::future::Future;
22use std::pin::Pin;
23
24use a2a_protocol_types::error::{A2aError, A2aResult};
25use a2a_protocol_types::params::ListTasksParams;
26use a2a_protocol_types::responses::TaskListResponse;
27use a2a_protocol_types::task::{Task, TaskId};
28use sqlx::postgres::{PgPool, PgPoolOptions};
29
30use super::task_store::TaskStore;
31
32#[derive(Debug, Clone)]
52pub struct PostgresTaskStore {
53 pool: PgPool,
54}
55
56impl PostgresTaskStore {
57 pub async fn new(url: &str) -> Result<Self, sqlx::Error> {
63 let pool = pg_pool(url).await?;
64 Self::from_pool(pool).await
65 }
66
67 pub async fn with_migrations(url: &str) -> Result<Self, sqlx::Error> {
77 let pool = pg_pool(url).await?;
78
79 let runner = super::pg_migration::PgMigrationRunner::new(pool.clone());
80 runner.run_pending().await?;
81
82 Ok(Self { pool })
83 }
84
85 pub async fn from_pool(pool: PgPool) -> Result<Self, sqlx::Error> {
91 sqlx::query(
92 "CREATE TABLE IF NOT EXISTS tasks (
93 id TEXT PRIMARY KEY,
94 context_id TEXT NOT NULL,
95 state TEXT NOT NULL,
96 data JSONB NOT NULL,
97 created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
98 updated_at TIMESTAMPTZ NOT NULL DEFAULT now()
99 )",
100 )
101 .execute(&pool)
102 .await?;
103
104 sqlx::query("CREATE INDEX IF NOT EXISTS idx_tasks_context_id ON tasks(context_id)")
105 .execute(&pool)
106 .await?;
107
108 sqlx::query("CREATE INDEX IF NOT EXISTS idx_tasks_state ON tasks(state)")
109 .execute(&pool)
110 .await?;
111
112 sqlx::query(
113 "CREATE INDEX IF NOT EXISTS idx_tasks_context_id_state ON tasks(context_id, state)",
114 )
115 .execute(&pool)
116 .await?;
117
118 Ok(Self { pool })
119 }
120}
121
122async fn pg_pool(url: &str) -> Result<PgPool, sqlx::Error> {
124 pg_pool_with_size(url, 10).await
125}
126
127async fn pg_pool_with_size(url: &str, max_connections: u32) -> Result<PgPool, sqlx::Error> {
129 PgPoolOptions::new()
130 .max_connections(max_connections)
131 .connect(url)
132 .await
133}
134
135#[allow(clippy::needless_pass_by_value)]
137fn to_a2a_error(e: sqlx::Error) -> A2aError {
138 A2aError::internal(format!("postgres error: {e}"))
139}
140
141#[allow(clippy::manual_async_fn)]
142impl TaskStore for PostgresTaskStore {
143 fn save<'a>(
144 &'a self,
145 task: &'a Task,
146 ) -> Pin<Box<dyn Future<Output = A2aResult<()>> + Send + 'a>> {
147 Box::pin(async move {
148 let id = task.id.0.as_str();
149 let context_id = task.context_id.0.as_str();
150 let state = task.status.state.to_string();
151 let data = serde_json::to_value(task)
152 .map_err(|e| A2aError::internal(format!("failed to serialize task: {e}")))?;
153
154 sqlx::query(
155 "INSERT INTO tasks (id, context_id, state, data, updated_at)
156 VALUES ($1, $2, $3, $4, now())
157 ON CONFLICT(id) DO UPDATE SET
158 context_id = EXCLUDED.context_id,
159 state = EXCLUDED.state,
160 data = EXCLUDED.data,
161 updated_at = now()",
162 )
163 .bind(id)
164 .bind(context_id)
165 .bind(&state)
166 .bind(&data)
167 .execute(&self.pool)
168 .await
169 .map_err(to_a2a_error)?;
170
171 Ok(())
172 })
173 }
174
175 fn get<'a>(
176 &'a self,
177 id: &'a TaskId,
178 ) -> Pin<Box<dyn Future<Output = A2aResult<Option<Task>>> + Send + 'a>> {
179 Box::pin(async move {
180 let row: Option<(serde_json::Value,)> =
181 sqlx::query_as("SELECT data FROM tasks WHERE id = $1")
182 .bind(id.0.as_str())
183 .fetch_optional(&self.pool)
184 .await
185 .map_err(to_a2a_error)?;
186
187 match row {
188 Some((data,)) => {
189 let task: Task = serde_json::from_value(data).map_err(|e| {
190 A2aError::internal(format!("failed to deserialize task: {e}"))
191 })?;
192 Ok(Some(task))
193 }
194 None => Ok(None),
195 }
196 })
197 }
198
199 fn list<'a>(
200 &'a self,
201 params: &'a ListTasksParams,
202 ) -> Pin<Box<dyn Future<Output = A2aResult<TaskListResponse>> + Send + 'a>> {
203 Box::pin(async move {
204 let mut conditions = Vec::new();
206 let mut bind_values: Vec<String> = Vec::new();
207
208 if let Some(ref ctx) = params.context_id {
209 bind_values.push(ctx.clone());
210 conditions.push(format!("context_id = ${}", bind_values.len()));
211 }
212 if let Some(ref status) = params.status {
213 bind_values.push(status.to_string());
214 conditions.push(format!("state = ${}", bind_values.len()));
215 }
216 if let Some(ref token) = params.page_token {
217 bind_values.push(token.clone());
218 conditions.push(format!("id > ${}", bind_values.len()));
219 }
220
221 let where_clause = if conditions.is_empty() {
222 String::new()
223 } else {
224 format!("WHERE {}", conditions.join(" AND "))
225 };
226
227 let page_size = match params.page_size {
228 Some(0) | None => 50_u32,
229 Some(n) => n.min(1000),
230 };
231
232 let limit = page_size + 1;
234 let sql =
235 format!("SELECT data FROM tasks {where_clause} ORDER BY id ASC LIMIT {limit}");
236
237 let mut query = sqlx::query_as::<_, (serde_json::Value,)>(&sql);
238 for val in &bind_values {
239 query = query.bind(val);
240 }
241
242 let rows: Vec<(serde_json::Value,)> =
243 query.fetch_all(&self.pool).await.map_err(to_a2a_error)?;
244
245 let mut tasks: Vec<Task> = rows
246 .into_iter()
247 .map(|(data,)| {
248 serde_json::from_value(data)
249 .map_err(|e| A2aError::internal(format!("deserialize: {e}")))
250 })
251 .collect::<A2aResult<Vec<_>>>()?;
252
253 let next_page_token = if tasks.len() > page_size as usize {
254 tasks.truncate(page_size as usize);
255 tasks.last().map(|t| t.id.0.clone()).unwrap_or_default()
256 } else {
257 String::new()
258 };
259
260 #[allow(clippy::cast_possible_truncation)]
261 let page_len = tasks.len() as u32;
262 let mut response = TaskListResponse::new(tasks);
263 response.next_page_token = next_page_token;
264 response.page_size = page_len;
265 Ok(response)
266 })
267 }
268
269 fn insert_if_absent<'a>(
270 &'a self,
271 task: &'a Task,
272 ) -> Pin<Box<dyn Future<Output = A2aResult<bool>> + Send + 'a>> {
273 Box::pin(async move {
274 let id = task.id.0.as_str();
275 let context_id = task.context_id.0.as_str();
276 let state = task.status.state.to_string();
277 let data = serde_json::to_value(task)
278 .map_err(|e| A2aError::internal(format!("failed to serialize task: {e}")))?;
279
280 let result = sqlx::query(
281 "INSERT INTO tasks (id, context_id, state, data, updated_at)
282 VALUES ($1, $2, $3, $4, now())
283 ON CONFLICT(id) DO NOTHING",
284 )
285 .bind(id)
286 .bind(context_id)
287 .bind(&state)
288 .bind(&data)
289 .execute(&self.pool)
290 .await
291 .map_err(to_a2a_error)?;
292
293 Ok(result.rows_affected() > 0)
294 })
295 }
296
297 fn delete<'a>(
298 &'a self,
299 id: &'a TaskId,
300 ) -> Pin<Box<dyn Future<Output = A2aResult<()>> + Send + 'a>> {
301 Box::pin(async move {
302 sqlx::query("DELETE FROM tasks WHERE id = $1")
303 .bind(id.0.as_str())
304 .execute(&self.pool)
305 .await
306 .map_err(to_a2a_error)?;
307 Ok(())
308 })
309 }
310
311 fn count<'a>(&'a self) -> Pin<Box<dyn Future<Output = A2aResult<u64>> + Send + 'a>> {
312 Box::pin(async move {
313 let row: (i64,) = sqlx::query_as("SELECT COUNT(*) FROM tasks")
314 .fetch_one(&self.pool)
315 .await
316 .map_err(to_a2a_error)?;
317 #[allow(clippy::cast_sign_loss)]
318 Ok(row.0 as u64)
319 })
320 }
321}
322
#[cfg(test)]
mod tests {
    use super::*;

    /// `to_a2a_error` must keep the `postgres error` prefix in the rendered message.
    #[test]
    fn to_a2a_error_formats_message() {
        let msg = to_a2a_error(sqlx::Error::RowNotFound).to_string();
        assert!(
            msg.contains("postgres error"),
            "error message should contain 'postgres error': {msg}"
        );
    }
}