a2a_protocol_server/store/
tenant_postgres_store.rs1use std::future::Future;
29use std::pin::Pin;
30
31use a2a_protocol_types::error::{A2aError, A2aResult};
32use a2a_protocol_types::params::ListTasksParams;
33use a2a_protocol_types::responses::TaskListResponse;
34use a2a_protocol_types::task::{Task, TaskId};
35use sqlx::postgres::{PgPool, PgPoolOptions};
36
37use super::task_store::TaskStore;
38use super::tenant::TenantContext;
39
40#[derive(Debug, Clone)]
46pub struct TenantAwarePostgresTaskStore {
47 pool: PgPool,
48}
49
50impl TenantAwarePostgresTaskStore {
51 pub async fn new(url: &str) -> Result<Self, sqlx::Error> {
57 let pool = PgPoolOptions::new()
58 .max_connections(10)
59 .connect(url)
60 .await?;
61 Self::from_pool(pool).await
62 }
63
64 pub async fn from_pool(pool: PgPool) -> Result<Self, sqlx::Error> {
70 sqlx::query(
71 "CREATE TABLE IF NOT EXISTS tenant_tasks (
72 tenant_id TEXT NOT NULL DEFAULT '',
73 id TEXT NOT NULL,
74 context_id TEXT NOT NULL,
75 state TEXT NOT NULL,
76 data JSONB NOT NULL,
77 created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
78 updated_at TIMESTAMPTZ NOT NULL DEFAULT now(),
79 PRIMARY KEY (tenant_id, id)
80 )",
81 )
82 .execute(&pool)
83 .await?;
84
85 sqlx::query(
86 "CREATE INDEX IF NOT EXISTS idx_tenant_tasks_ctx ON tenant_tasks(tenant_id, context_id)",
87 )
88 .execute(&pool)
89 .await?;
90
91 sqlx::query(
92 "CREATE INDEX IF NOT EXISTS idx_tenant_tasks_state ON tenant_tasks(tenant_id, state)",
93 )
94 .execute(&pool)
95 .await?;
96
97 Ok(Self { pool })
98 }
99}
100
101fn to_a2a_error(e: &sqlx::Error) -> A2aError {
102 A2aError::internal(format!("postgres error: {e}"))
103}
104
105#[allow(clippy::manual_async_fn)]
106impl TaskStore for TenantAwarePostgresTaskStore {
107 fn save<'a>(&'a self, task: Task) -> Pin<Box<dyn Future<Output = A2aResult<()>> + Send + 'a>> {
108 Box::pin(async move {
109 let tenant = TenantContext::current();
110 let id = task.id.0.as_str();
111 let context_id = task.context_id.0.as_str();
112 let state = task.status.state.to_string();
113 let data = serde_json::to_value(&task)
114 .map_err(|e| A2aError::internal(format!("failed to serialize task: {e}")))?;
115
116 sqlx::query(
117 "INSERT INTO tenant_tasks (tenant_id, id, context_id, state, data, updated_at)
118 VALUES ($1, $2, $3, $4, $5, now())
119 ON CONFLICT(tenant_id, id) DO UPDATE SET
120 context_id = EXCLUDED.context_id,
121 state = EXCLUDED.state,
122 data = EXCLUDED.data,
123 updated_at = now()",
124 )
125 .bind(&tenant)
126 .bind(id)
127 .bind(context_id)
128 .bind(&state)
129 .bind(&data)
130 .execute(&self.pool)
131 .await
132 .map_err(|e| to_a2a_error(&e))?;
133
134 Ok(())
135 })
136 }
137
138 fn get<'a>(
139 &'a self,
140 id: &'a TaskId,
141 ) -> Pin<Box<dyn Future<Output = A2aResult<Option<Task>>> + Send + 'a>> {
142 Box::pin(async move {
143 let tenant = TenantContext::current();
144 let row: Option<(serde_json::Value,)> =
145 sqlx::query_as("SELECT data FROM tenant_tasks WHERE tenant_id = $1 AND id = $2")
146 .bind(&tenant)
147 .bind(id.0.as_str())
148 .fetch_optional(&self.pool)
149 .await
150 .map_err(|e| to_a2a_error(&e))?;
151
152 match row {
153 Some((data,)) => {
154 let task: Task = serde_json::from_value(data)
155 .map_err(|e| A2aError::internal(format!("deserialize: {e}")))?;
156 Ok(Some(task))
157 }
158 None => Ok(None),
159 }
160 })
161 }
162
163 fn list<'a>(
164 &'a self,
165 params: &'a ListTasksParams,
166 ) -> Pin<Box<dyn Future<Output = A2aResult<TaskListResponse>> + Send + 'a>> {
167 Box::pin(async move {
168 let tenant = TenantContext::current();
169 let mut conditions = vec!["tenant_id = $1".to_string()];
170 let mut bind_values: Vec<String> = vec![tenant];
171
172 if let Some(ref ctx) = params.context_id {
173 bind_values.push(ctx.clone());
174 conditions.push(format!("context_id = ${}", bind_values.len()));
175 }
176 if let Some(ref status) = params.status {
177 bind_values.push(status.to_string());
178 conditions.push(format!("state = ${}", bind_values.len()));
179 }
180 if let Some(ref token) = params.page_token {
181 bind_values.push(token.clone());
182 conditions.push(format!("id > ${}", bind_values.len()));
183 }
184
185 let where_clause = format!("WHERE {}", conditions.join(" AND "));
186
187 let page_size = match params.page_size {
188 Some(0) | None => 50_u32,
189 Some(n) => n.min(1000),
190 };
191
192 let limit = page_size + 1;
193 let sql = format!(
194 "SELECT data FROM tenant_tasks {where_clause} ORDER BY id ASC LIMIT {limit}"
195 );
196
197 let mut query = sqlx::query_as::<_, (serde_json::Value,)>(&sql);
198 for val in &bind_values {
199 query = query.bind(val);
200 }
201
202 let rows: Vec<(serde_json::Value,)> = query
203 .fetch_all(&self.pool)
204 .await
205 .map_err(|e| to_a2a_error(&e))?;
206
207 let mut tasks: Vec<Task> = rows
208 .into_iter()
209 .map(|(data,)| {
210 serde_json::from_value(data)
211 .map_err(|e| A2aError::internal(format!("deserialize: {e}")))
212 })
213 .collect::<A2aResult<Vec<_>>>()?;
214
215 let next_page_token = if tasks.len() > page_size as usize {
216 tasks.truncate(page_size as usize);
217 tasks.last().map(|t| t.id.0.clone()).unwrap_or_default()
218 } else {
219 String::new()
220 };
221
222 #[allow(clippy::cast_possible_truncation)]
223 let page_len = tasks.len() as u32;
224 let mut response = TaskListResponse::new(tasks);
225 response.next_page_token = next_page_token;
226 response.page_size = page_len;
227 Ok(response)
228 })
229 }
230
231 fn insert_if_absent<'a>(
232 &'a self,
233 task: Task,
234 ) -> Pin<Box<dyn Future<Output = A2aResult<bool>> + Send + 'a>> {
235 Box::pin(async move {
236 let tenant = TenantContext::current();
237 let id = task.id.0.as_str();
238 let context_id = task.context_id.0.as_str();
239 let state = task.status.state.to_string();
240 let data = serde_json::to_value(&task)
241 .map_err(|e| A2aError::internal(format!("serialize: {e}")))?;
242
243 let result = sqlx::query(
244 "INSERT INTO tenant_tasks (tenant_id, id, context_id, state, data, updated_at)
245 VALUES ($1, $2, $3, $4, $5, now())
246 ON CONFLICT(tenant_id, id) DO NOTHING",
247 )
248 .bind(&tenant)
249 .bind(id)
250 .bind(context_id)
251 .bind(&state)
252 .bind(&data)
253 .execute(&self.pool)
254 .await
255 .map_err(|e| to_a2a_error(&e))?;
256
257 Ok(result.rows_affected() > 0)
258 })
259 }
260
261 fn delete<'a>(
262 &'a self,
263 id: &'a TaskId,
264 ) -> Pin<Box<dyn Future<Output = A2aResult<()>> + Send + 'a>> {
265 Box::pin(async move {
266 let tenant = TenantContext::current();
267 sqlx::query("DELETE FROM tenant_tasks WHERE tenant_id = $1 AND id = $2")
268 .bind(&tenant)
269 .bind(id.0.as_str())
270 .execute(&self.pool)
271 .await
272 .map_err(|e| to_a2a_error(&e))?;
273 Ok(())
274 })
275 }
276
277 fn count<'a>(&'a self) -> Pin<Box<dyn Future<Output = A2aResult<u64>> + Send + 'a>> {
278 Box::pin(async move {
279 let tenant = TenantContext::current();
280 let row: (i64,) =
281 sqlx::query_as("SELECT COUNT(*) FROM tenant_tasks WHERE tenant_id = $1")
282 .bind(&tenant)
283 .fetch_one(&self.pool)
284 .await
285 .map_err(|e| to_a2a_error(&e))?;
286 #[allow(clippy::cast_sign_loss)]
287 Ok(row.0 as u64)
288 })
289 }
290}
291
#[cfg(test)]
mod tests {
    use super::*;

    /// The converted error carries the fixed `postgres error` prefix.
    #[test]
    fn to_a2a_error_formats_message() {
        let pg_err = sqlx::Error::RowNotFound;
        let a2a_err = to_a2a_error(&pg_err);
        let msg = format!("{a2a_err}");
        assert!(
            msg.contains("postgres error"),
            "error message should contain 'postgres error': {msg}"
        );
    }

    /// The converted error also preserves the underlying sqlx error text, so
    /// the root cause is not lost when the error crosses the A2A boundary.
    #[test]
    fn to_a2a_error_preserves_cause() {
        let pg_err = sqlx::Error::RowNotFound;
        let cause = pg_err.to_string();
        let msg = to_a2a_error(&pg_err).to_string();
        assert!(
            msg.contains(&cause),
            "error message should contain the sqlx error text '{cause}': {msg}"
        );
    }
}