a2a_protocol_server/store/
postgres_store.rs1use std::future::Future;
22use std::pin::Pin;
23
24use a2a_protocol_types::error::{A2aError, A2aResult};
25use a2a_protocol_types::params::ListTasksParams;
26use a2a_protocol_types::responses::TaskListResponse;
27use a2a_protocol_types::task::{Task, TaskId};
28use sqlx::postgres::{PgPool, PgPoolOptions};
29
30use super::task_store::TaskStore;
31
32#[derive(Debug, Clone)]
52pub struct PostgresTaskStore {
53 pool: PgPool,
54}
55
56impl PostgresTaskStore {
57 pub async fn new(url: &str) -> Result<Self, sqlx::Error> {
63 let pool = pg_pool(url).await?;
64 Self::from_pool(pool).await
65 }
66
67 pub async fn with_migrations(url: &str) -> Result<Self, sqlx::Error> {
77 let pool = pg_pool(url).await?;
78
79 let runner = super::pg_migration::PgMigrationRunner::new(pool.clone());
80 runner.run_pending().await?;
81
82 Ok(Self { pool })
83 }
84
85 pub async fn from_pool(pool: PgPool) -> Result<Self, sqlx::Error> {
91 sqlx::query(
92 "CREATE TABLE IF NOT EXISTS tasks (
93 id TEXT PRIMARY KEY,
94 context_id TEXT NOT NULL,
95 state TEXT NOT NULL,
96 data JSONB NOT NULL,
97 created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
98 updated_at TIMESTAMPTZ NOT NULL DEFAULT now()
99 )",
100 )
101 .execute(&pool)
102 .await?;
103
104 sqlx::query("CREATE INDEX IF NOT EXISTS idx_tasks_context_id ON tasks(context_id)")
105 .execute(&pool)
106 .await?;
107
108 sqlx::query("CREATE INDEX IF NOT EXISTS idx_tasks_state ON tasks(state)")
109 .execute(&pool)
110 .await?;
111
112 sqlx::query(
113 "CREATE INDEX IF NOT EXISTS idx_tasks_context_id_state ON tasks(context_id, state)",
114 )
115 .execute(&pool)
116 .await?;
117
118 Ok(Self { pool })
119 }
120}
121
122async fn pg_pool(url: &str) -> Result<PgPool, sqlx::Error> {
124 pg_pool_with_size(url, 10).await
125}
126
127async fn pg_pool_with_size(url: &str, max_connections: u32) -> Result<PgPool, sqlx::Error> {
129 PgPoolOptions::new()
130 .max_connections(max_connections)
131 .connect(url)
132 .await
133}
134
135#[allow(clippy::needless_pass_by_value)]
137fn to_a2a_error(e: sqlx::Error) -> A2aError {
138 A2aError::internal(format!("postgres error: {e}"))
139}
140
141#[allow(clippy::manual_async_fn)]
142impl TaskStore for PostgresTaskStore {
143 fn save<'a>(&'a self, task: Task) -> Pin<Box<dyn Future<Output = A2aResult<()>> + Send + 'a>> {
144 Box::pin(async move {
145 let id = task.id.0.as_str();
146 let context_id = task.context_id.0.as_str();
147 let state = task.status.state.to_string();
148 let data = serde_json::to_value(&task)
149 .map_err(|e| A2aError::internal(format!("failed to serialize task: {e}")))?;
150
151 sqlx::query(
152 "INSERT INTO tasks (id, context_id, state, data, updated_at)
153 VALUES ($1, $2, $3, $4, now())
154 ON CONFLICT(id) DO UPDATE SET
155 context_id = EXCLUDED.context_id,
156 state = EXCLUDED.state,
157 data = EXCLUDED.data,
158 updated_at = now()",
159 )
160 .bind(id)
161 .bind(context_id)
162 .bind(&state)
163 .bind(&data)
164 .execute(&self.pool)
165 .await
166 .map_err(to_a2a_error)?;
167
168 Ok(())
169 })
170 }
171
172 fn get<'a>(
173 &'a self,
174 id: &'a TaskId,
175 ) -> Pin<Box<dyn Future<Output = A2aResult<Option<Task>>> + Send + 'a>> {
176 Box::pin(async move {
177 let row: Option<(serde_json::Value,)> =
178 sqlx::query_as("SELECT data FROM tasks WHERE id = $1")
179 .bind(id.0.as_str())
180 .fetch_optional(&self.pool)
181 .await
182 .map_err(to_a2a_error)?;
183
184 match row {
185 Some((data,)) => {
186 let task: Task = serde_json::from_value(data).map_err(|e| {
187 A2aError::internal(format!("failed to deserialize task: {e}"))
188 })?;
189 Ok(Some(task))
190 }
191 None => Ok(None),
192 }
193 })
194 }
195
196 fn list<'a>(
197 &'a self,
198 params: &'a ListTasksParams,
199 ) -> Pin<Box<dyn Future<Output = A2aResult<TaskListResponse>> + Send + 'a>> {
200 Box::pin(async move {
201 let mut conditions = Vec::new();
203 let mut bind_values: Vec<String> = Vec::new();
204
205 if let Some(ref ctx) = params.context_id {
206 bind_values.push(ctx.clone());
207 conditions.push(format!("context_id = ${}", bind_values.len()));
208 }
209 if let Some(ref status) = params.status {
210 bind_values.push(status.to_string());
211 conditions.push(format!("state = ${}", bind_values.len()));
212 }
213 if let Some(ref token) = params.page_token {
214 bind_values.push(token.clone());
215 conditions.push(format!("id > ${}", bind_values.len()));
216 }
217
218 let where_clause = if conditions.is_empty() {
219 String::new()
220 } else {
221 format!("WHERE {}", conditions.join(" AND "))
222 };
223
224 let page_size = match params.page_size {
225 Some(0) | None => 50_u32,
226 Some(n) => n.min(1000),
227 };
228
229 let limit = page_size + 1;
231 let sql =
232 format!("SELECT data FROM tasks {where_clause} ORDER BY id ASC LIMIT {limit}");
233
234 let mut query = sqlx::query_as::<_, (serde_json::Value,)>(&sql);
235 for val in &bind_values {
236 query = query.bind(val);
237 }
238
239 let rows: Vec<(serde_json::Value,)> =
240 query.fetch_all(&self.pool).await.map_err(to_a2a_error)?;
241
242 let mut tasks: Vec<Task> = rows
243 .into_iter()
244 .map(|(data,)| {
245 serde_json::from_value(data)
246 .map_err(|e| A2aError::internal(format!("deserialize: {e}")))
247 })
248 .collect::<A2aResult<Vec<_>>>()?;
249
250 let next_page_token = if tasks.len() > page_size as usize {
251 tasks.truncate(page_size as usize);
252 tasks.last().map(|t| t.id.0.clone())
253 } else {
254 None
255 };
256
257 let mut response = TaskListResponse::new(tasks);
258 response.next_page_token = next_page_token;
259 Ok(response)
260 })
261 }
262
263 fn insert_if_absent<'a>(
264 &'a self,
265 task: Task,
266 ) -> Pin<Box<dyn Future<Output = A2aResult<bool>> + Send + 'a>> {
267 Box::pin(async move {
268 let id = task.id.0.as_str();
269 let context_id = task.context_id.0.as_str();
270 let state = task.status.state.to_string();
271 let data = serde_json::to_value(&task)
272 .map_err(|e| A2aError::internal(format!("failed to serialize task: {e}")))?;
273
274 let result = sqlx::query(
275 "INSERT INTO tasks (id, context_id, state, data, updated_at)
276 VALUES ($1, $2, $3, $4, now())
277 ON CONFLICT(id) DO NOTHING",
278 )
279 .bind(id)
280 .bind(context_id)
281 .bind(&state)
282 .bind(&data)
283 .execute(&self.pool)
284 .await
285 .map_err(to_a2a_error)?;
286
287 Ok(result.rows_affected() > 0)
288 })
289 }
290
291 fn delete<'a>(
292 &'a self,
293 id: &'a TaskId,
294 ) -> Pin<Box<dyn Future<Output = A2aResult<()>> + Send + 'a>> {
295 Box::pin(async move {
296 sqlx::query("DELETE FROM tasks WHERE id = $1")
297 .bind(id.0.as_str())
298 .execute(&self.pool)
299 .await
300 .map_err(to_a2a_error)?;
301 Ok(())
302 })
303 }
304
305 fn count<'a>(&'a self) -> Pin<Box<dyn Future<Output = A2aResult<u64>> + Send + 'a>> {
306 Box::pin(async move {
307 let row: (i64,) = sqlx::query_as("SELECT COUNT(*) FROM tasks")
308 .fetch_one(&self.pool)
309 .await
310 .map_err(to_a2a_error)?;
311 #[allow(clippy::cast_sign_loss)]
312 Ok(row.0 as u64)
313 })
314 }
315}
316
#[cfg(test)]
mod tests {
    use super::*;

    /// Database errors are surfaced with a `postgres error` prefix.
    #[test]
    fn to_a2a_error_formats_message() {
        let msg = to_a2a_error(sqlx::Error::RowNotFound).to_string();
        assert!(
            msg.contains("postgres error"),
            "error message should contain 'postgres error': {msg}"
        );
    }
}