a2a_protocol_server/push/
postgres_config_store.rs1use std::future::Future;
11use std::pin::Pin;
12
13use a2a_protocol_types::error::{A2aError, A2aResult};
14use a2a_protocol_types::push::TaskPushNotificationConfig;
15use sqlx::postgres::PgPool;
16
17use super::config_store::PushConfigStore;
18
19#[derive(Debug, Clone)]
34pub struct PostgresPushConfigStore {
35 pool: PgPool,
36}
37
38#[allow(clippy::needless_pass_by_value)]
40fn to_a2a_error(e: sqlx::Error) -> A2aError {
41 A2aError::internal(format!("postgres error: {e}"))
42}
43
44impl PostgresPushConfigStore {
45 pub async fn new(url: &str) -> Result<Self, sqlx::Error> {
51 let pool = sqlx::postgres::PgPoolOptions::new()
52 .max_connections(10)
53 .connect(url)
54 .await?;
55 Self::from_pool(pool).await
56 }
57
58 pub async fn from_pool(pool: PgPool) -> Result<Self, sqlx::Error> {
64 sqlx::query(
65 "CREATE TABLE IF NOT EXISTS push_configs (
66 task_id TEXT NOT NULL,
67 id TEXT NOT NULL,
68 data JSONB NOT NULL,
69 PRIMARY KEY (task_id, id)
70 )",
71 )
72 .execute(&pool)
73 .await?;
74
75 Ok(Self { pool })
76 }
77}
78
79#[allow(clippy::manual_async_fn)]
80impl PushConfigStore for PostgresPushConfigStore {
81 fn set<'a>(
82 &'a self,
83 mut config: TaskPushNotificationConfig,
84 ) -> Pin<Box<dyn Future<Output = A2aResult<TaskPushNotificationConfig>> + Send + 'a>> {
85 Box::pin(async move {
86 let id = config
87 .id
88 .clone()
89 .unwrap_or_else(|| uuid::Uuid::new_v4().to_string());
90 config.id = Some(id.clone());
91
92 let data = serde_json::to_value(&config)
93 .map_err(|e| A2aError::internal(format!("serialize: {e}")))?;
94
95 sqlx::query(
96 "INSERT INTO push_configs (task_id, id, data)
97 VALUES ($1, $2, $3)
98 ON CONFLICT(task_id, id) DO UPDATE SET data = EXCLUDED.data",
99 )
100 .bind(&config.task_id)
101 .bind(&id)
102 .bind(&data)
103 .execute(&self.pool)
104 .await
105 .map_err(to_a2a_error)?;
106
107 Ok(config)
108 })
109 }
110
111 fn get<'a>(
112 &'a self,
113 task_id: &'a str,
114 id: &'a str,
115 ) -> Pin<Box<dyn Future<Output = A2aResult<Option<TaskPushNotificationConfig>>> + Send + 'a>>
116 {
117 Box::pin(async move {
118 let row: Option<(serde_json::Value,)> =
119 sqlx::query_as("SELECT data FROM push_configs WHERE task_id = $1 AND id = $2")
120 .bind(task_id)
121 .bind(id)
122 .fetch_optional(&self.pool)
123 .await
124 .map_err(to_a2a_error)?;
125
126 match row {
127 Some((data,)) => {
128 let config: TaskPushNotificationConfig = serde_json::from_value(data)
129 .map_err(|e| A2aError::internal(format!("deserialize: {e}")))?;
130 Ok(Some(config))
131 }
132 None => Ok(None),
133 }
134 })
135 }
136
137 fn list<'a>(
138 &'a self,
139 task_id: &'a str,
140 ) -> Pin<Box<dyn Future<Output = A2aResult<Vec<TaskPushNotificationConfig>>> + Send + 'a>> {
141 Box::pin(async move {
142 let rows: Vec<(serde_json::Value,)> =
143 sqlx::query_as("SELECT data FROM push_configs WHERE task_id = $1")
144 .bind(task_id)
145 .fetch_all(&self.pool)
146 .await
147 .map_err(to_a2a_error)?;
148
149 rows.into_iter()
150 .map(|(data,)| {
151 serde_json::from_value(data)
152 .map_err(|e| A2aError::internal(format!("deserialize: {e}")))
153 })
154 .collect()
155 })
156 }
157
158 fn delete<'a>(
159 &'a self,
160 task_id: &'a str,
161 id: &'a str,
162 ) -> Pin<Box<dyn Future<Output = A2aResult<()>> + Send + 'a>> {
163 Box::pin(async move {
164 sqlx::query("DELETE FROM push_configs WHERE task_id = $1 AND id = $2")
165 .bind(task_id)
166 .bind(id)
167 .execute(&self.pool)
168 .await
169 .map_err(to_a2a_error)?;
170 Ok(())
171 })
172 }
173}
174
#[cfg(test)]
mod tests {
    use super::*;

    /// A converted sqlx error must carry the `postgres error` prefix in its text.
    #[test]
    fn to_a2a_error_formats_message() {
        let msg = to_a2a_error(sqlx::Error::RowNotFound).to_string();
        assert!(
            msg.contains("postgres error"),
            "error message should contain 'postgres error': {msg}"
        );
    }
}