// a2a_protocol_server/store/tenant_postgres_store.rs

use std::future::Future;
use std::pin::Pin;

use a2a_protocol_types::error::{A2aError, A2aResult};
use a2a_protocol_types::params::ListTasksParams;
use a2a_protocol_types::responses::TaskListResponse;
use a2a_protocol_types::task::{Task, TaskId};
use sqlx::postgres::{PgPool, PgPoolOptions};

use super::task_store::TaskStore;
use super::tenant::TenantContext;

40#[derive(Debug, Clone)]
46pub struct TenantAwarePostgresTaskStore {
47 pool: PgPool,
48}
49
50impl TenantAwarePostgresTaskStore {
51 pub async fn new(url: &str) -> Result<Self, sqlx::Error> {
57 let pool = PgPoolOptions::new()
58 .max_connections(10)
59 .connect(url)
60 .await?;
61 Self::from_pool(pool).await
62 }
63
64 pub async fn from_pool(pool: PgPool) -> Result<Self, sqlx::Error> {
70 sqlx::query(
71 "CREATE TABLE IF NOT EXISTS tenant_tasks (
72 tenant_id TEXT NOT NULL DEFAULT '',
73 id TEXT NOT NULL,
74 context_id TEXT NOT NULL,
75 state TEXT NOT NULL,
76 data JSONB NOT NULL,
77 created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
78 updated_at TIMESTAMPTZ NOT NULL DEFAULT now(),
79 PRIMARY KEY (tenant_id, id)
80 )",
81 )
82 .execute(&pool)
83 .await?;
84
85 sqlx::query(
86 "CREATE INDEX IF NOT EXISTS idx_tenant_tasks_ctx ON tenant_tasks(tenant_id, context_id)",
87 )
88 .execute(&pool)
89 .await?;
90
91 sqlx::query(
92 "CREATE INDEX IF NOT EXISTS idx_tenant_tasks_state ON tenant_tasks(tenant_id, state)",
93 )
94 .execute(&pool)
95 .await?;
96
97 Ok(Self { pool })
98 }
99}
100
101fn to_a2a_error(e: &sqlx::Error) -> A2aError {
102 A2aError::internal(format!("postgres error: {e}"))
103}
104
105#[allow(clippy::manual_async_fn)]
106impl TaskStore for TenantAwarePostgresTaskStore {
107 fn save<'a>(
108 &'a self,
109 task: &'a Task,
110 ) -> Pin<Box<dyn Future<Output = A2aResult<()>> + Send + 'a>> {
111 Box::pin(async move {
112 let tenant = TenantContext::current();
113 let id = task.id.0.as_str();
114 let context_id = task.context_id.0.as_str();
115 let state = task.status.state.to_string();
116 let data = serde_json::to_value(task)
117 .map_err(|e| A2aError::internal(format!("failed to serialize task: {e}")))?;
118
119 sqlx::query(
120 "INSERT INTO tenant_tasks (tenant_id, id, context_id, state, data, updated_at)
121 VALUES ($1, $2, $3, $4, $5, now())
122 ON CONFLICT(tenant_id, id) DO UPDATE SET
123 context_id = EXCLUDED.context_id,
124 state = EXCLUDED.state,
125 data = EXCLUDED.data,
126 updated_at = now()",
127 )
128 .bind(&tenant)
129 .bind(id)
130 .bind(context_id)
131 .bind(&state)
132 .bind(&data)
133 .execute(&self.pool)
134 .await
135 .map_err(|e| to_a2a_error(&e))?;
136
137 Ok(())
138 })
139 }
140
141 fn get<'a>(
142 &'a self,
143 id: &'a TaskId,
144 ) -> Pin<Box<dyn Future<Output = A2aResult<Option<Task>>> + Send + 'a>> {
145 Box::pin(async move {
146 let tenant = TenantContext::current();
147 let row: Option<(serde_json::Value,)> =
148 sqlx::query_as("SELECT data FROM tenant_tasks WHERE tenant_id = $1 AND id = $2")
149 .bind(&tenant)
150 .bind(id.0.as_str())
151 .fetch_optional(&self.pool)
152 .await
153 .map_err(|e| to_a2a_error(&e))?;
154
155 match row {
156 Some((data,)) => {
157 let task: Task = serde_json::from_value(data)
158 .map_err(|e| A2aError::internal(format!("deserialize: {e}")))?;
159 Ok(Some(task))
160 }
161 None => Ok(None),
162 }
163 })
164 }
165
166 fn list<'a>(
167 &'a self,
168 params: &'a ListTasksParams,
169 ) -> Pin<Box<dyn Future<Output = A2aResult<TaskListResponse>> + Send + 'a>> {
170 Box::pin(async move {
171 let tenant = TenantContext::current();
172 let mut conditions = vec!["tenant_id = $1".to_string()];
173 let mut bind_values: Vec<String> = vec![tenant];
174
175 if let Some(ref ctx) = params.context_id {
176 bind_values.push(ctx.clone());
177 conditions.push(format!("context_id = ${}", bind_values.len()));
178 }
179 if let Some(ref status) = params.status {
180 bind_values.push(status.to_string());
181 conditions.push(format!("state = ${}", bind_values.len()));
182 }
183 if let Some(ref token) = params.page_token {
184 bind_values.push(token.clone());
185 conditions.push(format!("id > ${}", bind_values.len()));
186 }
187
188 let where_clause = format!("WHERE {}", conditions.join(" AND "));
189
190 let page_size = match params.page_size {
191 Some(0) | None => 50_u32,
192 Some(n) => n.min(1000),
193 };
194
195 let limit = page_size + 1;
196 let sql = format!(
197 "SELECT data FROM tenant_tasks {where_clause} ORDER BY id ASC LIMIT {limit}"
198 );
199
200 let mut query = sqlx::query_as::<_, (serde_json::Value,)>(&sql);
201 for val in &bind_values {
202 query = query.bind(val);
203 }
204
205 let rows: Vec<(serde_json::Value,)> = query
206 .fetch_all(&self.pool)
207 .await
208 .map_err(|e| to_a2a_error(&e))?;
209
210 let mut tasks: Vec<Task> = rows
211 .into_iter()
212 .map(|(data,)| {
213 serde_json::from_value(data)
214 .map_err(|e| A2aError::internal(format!("deserialize: {e}")))
215 })
216 .collect::<A2aResult<Vec<_>>>()?;
217
218 let next_page_token = if tasks.len() > page_size as usize {
219 tasks.truncate(page_size as usize);
220 tasks.last().map(|t| t.id.0.clone()).unwrap_or_default()
221 } else {
222 String::new()
223 };
224
225 #[allow(clippy::cast_possible_truncation)]
226 let page_len = tasks.len() as u32;
227 let mut response = TaskListResponse::new(tasks);
228 response.next_page_token = next_page_token;
229 response.page_size = page_len;
230 Ok(response)
231 })
232 }
233
234 fn insert_if_absent<'a>(
235 &'a self,
236 task: &'a Task,
237 ) -> Pin<Box<dyn Future<Output = A2aResult<bool>> + Send + 'a>> {
238 Box::pin(async move {
239 let tenant = TenantContext::current();
240 let id = task.id.0.as_str();
241 let context_id = task.context_id.0.as_str();
242 let state = task.status.state.to_string();
243 let data = serde_json::to_value(task)
244 .map_err(|e| A2aError::internal(format!("serialize: {e}")))?;
245
246 let result = sqlx::query(
247 "INSERT INTO tenant_tasks (tenant_id, id, context_id, state, data, updated_at)
248 VALUES ($1, $2, $3, $4, $5, now())
249 ON CONFLICT(tenant_id, id) DO NOTHING",
250 )
251 .bind(&tenant)
252 .bind(id)
253 .bind(context_id)
254 .bind(&state)
255 .bind(&data)
256 .execute(&self.pool)
257 .await
258 .map_err(|e| to_a2a_error(&e))?;
259
260 Ok(result.rows_affected() > 0)
261 })
262 }
263
264 fn delete<'a>(
265 &'a self,
266 id: &'a TaskId,
267 ) -> Pin<Box<dyn Future<Output = A2aResult<()>> + Send + 'a>> {
268 Box::pin(async move {
269 let tenant = TenantContext::current();
270 sqlx::query("DELETE FROM tenant_tasks WHERE tenant_id = $1 AND id = $2")
271 .bind(&tenant)
272 .bind(id.0.as_str())
273 .execute(&self.pool)
274 .await
275 .map_err(|e| to_a2a_error(&e))?;
276 Ok(())
277 })
278 }
279
280 fn count<'a>(&'a self) -> Pin<Box<dyn Future<Output = A2aResult<u64>> + Send + 'a>> {
281 Box::pin(async move {
282 let tenant = TenantContext::current();
283 let row: (i64,) =
284 sqlx::query_as("SELECT COUNT(*) FROM tenant_tasks WHERE tenant_id = $1")
285 .bind(&tenant)
286 .fetch_one(&self.pool)
287 .await
288 .map_err(|e| to_a2a_error(&e))?;
289 #[allow(clippy::cast_sign_loss)]
290 Ok(row.0 as u64)
291 })
292 }
293}
294
#[cfg(test)]
mod tests {
    use super::*;

    /// A converted database error must carry the `postgres error` prefix.
    #[test]
    fn to_a2a_error_formats_message() {
        let rendered = to_a2a_error(&sqlx::Error::RowNotFound).to_string();
        assert!(
            rendered.contains("postgres error"),
            "error message should contain 'postgres error': {rendered}"
        );
    }
}