a2a_protocol_server/push/
tenant_postgres_config_store.rs1use std::future::Future;
14use std::pin::Pin;
15
16use a2a_protocol_types::error::{A2aError, A2aResult};
17use a2a_protocol_types::push::TaskPushNotificationConfig;
18use sqlx::postgres::PgPool;
19
20use super::config_store::PushConfigStore;
21use crate::store::tenant::TenantContext;
22
23#[derive(Debug, Clone)]
39pub struct TenantAwarePostgresPushConfigStore {
40 pool: PgPool,
41}
42
43fn to_a2a_error(e: &sqlx::Error) -> A2aError {
44 A2aError::internal(format!("postgres error: {e}"))
45}
46
47impl TenantAwarePostgresPushConfigStore {
48 pub async fn new(url: &str) -> Result<Self, sqlx::Error> {
54 let pool = sqlx::postgres::PgPoolOptions::new()
55 .max_connections(10)
56 .connect(url)
57 .await?;
58 Self::from_pool(pool).await
59 }
60
61 pub async fn from_pool(pool: PgPool) -> Result<Self, sqlx::Error> {
67 sqlx::query(
68 "CREATE TABLE IF NOT EXISTS tenant_push_configs (
69 tenant_id TEXT NOT NULL DEFAULT '',
70 task_id TEXT NOT NULL,
71 id TEXT NOT NULL,
72 data JSONB NOT NULL,
73 PRIMARY KEY (tenant_id, task_id, id)
74 )",
75 )
76 .execute(&pool)
77 .await?;
78
79 Ok(Self { pool })
80 }
81}
82
83#[allow(clippy::manual_async_fn)]
84impl PushConfigStore for TenantAwarePostgresPushConfigStore {
85 fn set<'a>(
86 &'a self,
87 mut config: TaskPushNotificationConfig,
88 ) -> Pin<Box<dyn Future<Output = A2aResult<TaskPushNotificationConfig>> + Send + 'a>> {
89 Box::pin(async move {
90 let tenant = TenantContext::current();
91 let id = config
92 .id
93 .clone()
94 .unwrap_or_else(|| uuid::Uuid::new_v4().to_string());
95 config.id = Some(id.clone());
96
97 let data = serde_json::to_value(&config)
98 .map_err(|e| A2aError::internal(format!("serialize: {e}")))?;
99
100 sqlx::query(
101 "INSERT INTO tenant_push_configs (tenant_id, task_id, id, data)
102 VALUES ($1, $2, $3, $4)
103 ON CONFLICT(tenant_id, task_id, id) DO UPDATE SET data = EXCLUDED.data",
104 )
105 .bind(&tenant)
106 .bind(&config.task_id)
107 .bind(&id)
108 .bind(&data)
109 .execute(&self.pool)
110 .await
111 .map_err(|e| to_a2a_error(&e))?;
112
113 Ok(config)
114 })
115 }
116
117 fn get<'a>(
118 &'a self,
119 task_id: &'a str,
120 id: &'a str,
121 ) -> Pin<Box<dyn Future<Output = A2aResult<Option<TaskPushNotificationConfig>>> + Send + 'a>>
122 {
123 Box::pin(async move {
124 let tenant = TenantContext::current();
125 let row: Option<(serde_json::Value,)> = sqlx::query_as(
126 "SELECT data FROM tenant_push_configs WHERE tenant_id = $1 AND task_id = $2 AND id = $3",
127 )
128 .bind(&tenant)
129 .bind(task_id)
130 .bind(id)
131 .fetch_optional(&self.pool)
132 .await
133 .map_err(|e| to_a2a_error(&e))?;
134
135 match row {
136 Some((data,)) => {
137 let config: TaskPushNotificationConfig = serde_json::from_value(data)
138 .map_err(|e| A2aError::internal(format!("deserialize: {e}")))?;
139 Ok(Some(config))
140 }
141 None => Ok(None),
142 }
143 })
144 }
145
146 fn list<'a>(
147 &'a self,
148 task_id: &'a str,
149 ) -> Pin<Box<dyn Future<Output = A2aResult<Vec<TaskPushNotificationConfig>>> + Send + 'a>> {
150 Box::pin(async move {
151 let tenant = TenantContext::current();
152 let rows: Vec<(serde_json::Value,)> = sqlx::query_as(
153 "SELECT data FROM tenant_push_configs WHERE tenant_id = $1 AND task_id = $2",
154 )
155 .bind(&tenant)
156 .bind(task_id)
157 .fetch_all(&self.pool)
158 .await
159 .map_err(|e| to_a2a_error(&e))?;
160
161 rows.into_iter()
162 .map(|(data,)| {
163 serde_json::from_value(data)
164 .map_err(|e| A2aError::internal(format!("deserialize: {e}")))
165 })
166 .collect()
167 })
168 }
169
170 fn delete<'a>(
171 &'a self,
172 task_id: &'a str,
173 id: &'a str,
174 ) -> Pin<Box<dyn Future<Output = A2aResult<()>> + Send + 'a>> {
175 Box::pin(async move {
176 let tenant = TenantContext::current();
177 sqlx::query(
178 "DELETE FROM tenant_push_configs WHERE tenant_id = $1 AND task_id = $2 AND id = $3",
179 )
180 .bind(&tenant)
181 .bind(task_id)
182 .bind(id)
183 .execute(&self.pool)
184 .await
185 .map_err(|e| to_a2a_error(&e))?;
186 Ok(())
187 })
188 }
189}
190
#[cfg(test)]
mod tests {
    use super::*;

    /// The converted error must carry the `postgres error` prefix.
    #[test]
    fn to_a2a_error_formats_message() {
        let rendered = to_a2a_error(&sqlx::Error::RowNotFound).to_string();
        assert!(
            rendered.contains("postgres error"),
            "error message should contain 'postgres error': {rendered}"
        );
    }
}