a2a_protocol_server/store/
tenant_postgres_store.rs1use std::future::Future;
29use std::pin::Pin;
30
31use a2a_protocol_types::error::{A2aError, A2aResult};
32use a2a_protocol_types::params::ListTasksParams;
33use a2a_protocol_types::responses::TaskListResponse;
34use a2a_protocol_types::task::{Task, TaskId};
35use sqlx::postgres::{PgPool, PgPoolOptions};
36
37use super::task_store::TaskStore;
38use super::tenant::TenantContext;
39
40#[derive(Debug, Clone)]
46pub struct TenantAwarePostgresTaskStore {
47 pool: PgPool,
48}
49
50impl TenantAwarePostgresTaskStore {
51 pub async fn new(url: &str) -> Result<Self, sqlx::Error> {
57 let pool = PgPoolOptions::new()
58 .max_connections(10)
59 .connect(url)
60 .await?;
61 Self::from_pool(pool).await
62 }
63
64 pub async fn from_pool(pool: PgPool) -> Result<Self, sqlx::Error> {
70 sqlx::query(
71 "CREATE TABLE IF NOT EXISTS tenant_tasks (
72 tenant_id TEXT NOT NULL DEFAULT '',
73 id TEXT NOT NULL,
74 context_id TEXT NOT NULL,
75 state TEXT NOT NULL,
76 data JSONB NOT NULL,
77 created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
78 updated_at TIMESTAMPTZ NOT NULL DEFAULT now(),
79 PRIMARY KEY (tenant_id, id)
80 )",
81 )
82 .execute(&pool)
83 .await?;
84
85 sqlx::query(
86 "CREATE INDEX IF NOT EXISTS idx_tenant_tasks_ctx ON tenant_tasks(tenant_id, context_id)",
87 )
88 .execute(&pool)
89 .await?;
90
91 sqlx::query(
92 "CREATE INDEX IF NOT EXISTS idx_tenant_tasks_state ON tenant_tasks(tenant_id, state)",
93 )
94 .execute(&pool)
95 .await?;
96
97 Ok(Self { pool })
98 }
99}
100
101fn to_a2a_error(e: &sqlx::Error) -> A2aError {
102 A2aError::internal(format!("postgres error: {e}"))
103}
104
105#[allow(clippy::manual_async_fn)]
106impl TaskStore for TenantAwarePostgresTaskStore {
107 fn save<'a>(&'a self, task: Task) -> Pin<Box<dyn Future<Output = A2aResult<()>> + Send + 'a>> {
108 Box::pin(async move {
109 let tenant = TenantContext::current();
110 let id = task.id.0.as_str();
111 let context_id = task.context_id.0.as_str();
112 let state = task.status.state.to_string();
113 let data = serde_json::to_value(&task)
114 .map_err(|e| A2aError::internal(format!("failed to serialize task: {e}")))?;
115
116 sqlx::query(
117 "INSERT INTO tenant_tasks (tenant_id, id, context_id, state, data, updated_at)
118 VALUES ($1, $2, $3, $4, $5, now())
119 ON CONFLICT(tenant_id, id) DO UPDATE SET
120 context_id = EXCLUDED.context_id,
121 state = EXCLUDED.state,
122 data = EXCLUDED.data,
123 updated_at = now()",
124 )
125 .bind(&tenant)
126 .bind(id)
127 .bind(context_id)
128 .bind(&state)
129 .bind(&data)
130 .execute(&self.pool)
131 .await
132 .map_err(|e| to_a2a_error(&e))?;
133
134 Ok(())
135 })
136 }
137
138 fn get<'a>(
139 &'a self,
140 id: &'a TaskId,
141 ) -> Pin<Box<dyn Future<Output = A2aResult<Option<Task>>> + Send + 'a>> {
142 Box::pin(async move {
143 let tenant = TenantContext::current();
144 let row: Option<(serde_json::Value,)> =
145 sqlx::query_as("SELECT data FROM tenant_tasks WHERE tenant_id = $1 AND id = $2")
146 .bind(&tenant)
147 .bind(id.0.as_str())
148 .fetch_optional(&self.pool)
149 .await
150 .map_err(|e| to_a2a_error(&e))?;
151
152 match row {
153 Some((data,)) => {
154 let task: Task = serde_json::from_value(data)
155 .map_err(|e| A2aError::internal(format!("deserialize: {e}")))?;
156 Ok(Some(task))
157 }
158 None => Ok(None),
159 }
160 })
161 }
162
163 fn list<'a>(
164 &'a self,
165 params: &'a ListTasksParams,
166 ) -> Pin<Box<dyn Future<Output = A2aResult<TaskListResponse>> + Send + 'a>> {
167 Box::pin(async move {
168 let tenant = TenantContext::current();
169 let mut conditions = vec!["tenant_id = $1".to_string()];
170 let mut bind_values: Vec<String> = vec![tenant];
171
172 if let Some(ref ctx) = params.context_id {
173 bind_values.push(ctx.clone());
174 conditions.push(format!("context_id = ${}", bind_values.len()));
175 }
176 if let Some(ref status) = params.status {
177 bind_values.push(status.to_string());
178 conditions.push(format!("state = ${}", bind_values.len()));
179 }
180 if let Some(ref token) = params.page_token {
181 bind_values.push(token.clone());
182 conditions.push(format!("id > ${}", bind_values.len()));
183 }
184
185 let where_clause = format!("WHERE {}", conditions.join(" AND "));
186
187 let page_size = match params.page_size {
188 Some(0) | None => 50_u32,
189 Some(n) => n.min(1000),
190 };
191
192 let limit = page_size + 1;
193 let sql = format!(
194 "SELECT data FROM tenant_tasks {where_clause} ORDER BY id ASC LIMIT {limit}"
195 );
196
197 let mut query = sqlx::query_as::<_, (serde_json::Value,)>(&sql);
198 for val in &bind_values {
199 query = query.bind(val);
200 }
201
202 let rows: Vec<(serde_json::Value,)> = query
203 .fetch_all(&self.pool)
204 .await
205 .map_err(|e| to_a2a_error(&e))?;
206
207 let mut tasks: Vec<Task> = rows
208 .into_iter()
209 .map(|(data,)| {
210 serde_json::from_value(data)
211 .map_err(|e| A2aError::internal(format!("deserialize: {e}")))
212 })
213 .collect::<A2aResult<Vec<_>>>()?;
214
215 let next_page_token = if tasks.len() > page_size as usize {
216 tasks.truncate(page_size as usize);
217 tasks.last().map(|t| t.id.0.clone())
218 } else {
219 None
220 };
221
222 let mut response = TaskListResponse::new(tasks);
223 response.next_page_token = next_page_token;
224 Ok(response)
225 })
226 }
227
228 fn insert_if_absent<'a>(
229 &'a self,
230 task: Task,
231 ) -> Pin<Box<dyn Future<Output = A2aResult<bool>> + Send + 'a>> {
232 Box::pin(async move {
233 let tenant = TenantContext::current();
234 let id = task.id.0.as_str();
235 let context_id = task.context_id.0.as_str();
236 let state = task.status.state.to_string();
237 let data = serde_json::to_value(&task)
238 .map_err(|e| A2aError::internal(format!("serialize: {e}")))?;
239
240 let result = sqlx::query(
241 "INSERT INTO tenant_tasks (tenant_id, id, context_id, state, data, updated_at)
242 VALUES ($1, $2, $3, $4, $5, now())
243 ON CONFLICT(tenant_id, id) DO NOTHING",
244 )
245 .bind(&tenant)
246 .bind(id)
247 .bind(context_id)
248 .bind(&state)
249 .bind(&data)
250 .execute(&self.pool)
251 .await
252 .map_err(|e| to_a2a_error(&e))?;
253
254 Ok(result.rows_affected() > 0)
255 })
256 }
257
258 fn delete<'a>(
259 &'a self,
260 id: &'a TaskId,
261 ) -> Pin<Box<dyn Future<Output = A2aResult<()>> + Send + 'a>> {
262 Box::pin(async move {
263 let tenant = TenantContext::current();
264 sqlx::query("DELETE FROM tenant_tasks WHERE tenant_id = $1 AND id = $2")
265 .bind(&tenant)
266 .bind(id.0.as_str())
267 .execute(&self.pool)
268 .await
269 .map_err(|e| to_a2a_error(&e))?;
270 Ok(())
271 })
272 }
273
274 fn count<'a>(&'a self) -> Pin<Box<dyn Future<Output = A2aResult<u64>> + Send + 'a>> {
275 Box::pin(async move {
276 let tenant = TenantContext::current();
277 let row: (i64,) =
278 sqlx::query_as("SELECT COUNT(*) FROM tenant_tasks WHERE tenant_id = $1")
279 .bind(&tenant)
280 .fetch_one(&self.pool)
281 .await
282 .map_err(|e| to_a2a_error(&e))?;
283 #[allow(clippy::cast_sign_loss)]
284 Ok(row.0 as u64)
285 })
286 }
287}
288
#[cfg(test)]
mod tests {
    use super::*;

    /// The converted error's display text must carry the `postgres error` prefix.
    #[test]
    fn to_a2a_error_formats_message() {
        let source = sqlx::Error::RowNotFound;
        let msg = to_a2a_error(&source).to_string();
        assert!(
            msg.contains("postgres error"),
            "error message should contain 'postgres error': {msg}"
        );
    }
}