a2a_protocol_server/push/
sqlite_config_store.rs1use std::future::Future;
11use std::pin::Pin;
12
13use a2a_protocol_types::error::{A2aError, A2aResult};
14use a2a_protocol_types::push::TaskPushNotificationConfig;
15use sqlx::sqlite::SqlitePool;
16
17use super::config_store::PushConfigStore;
18
19#[derive(Debug, Clone)]
34pub struct SqlitePushConfigStore {
35 pool: SqlitePool,
36}
37
38async fn sqlite_pool(url: &str) -> Result<SqlitePool, sqlx::Error> {
40 use sqlx::sqlite::SqliteConnectOptions;
41 use std::str::FromStr;
42
43 let opts = SqliteConnectOptions::from_str(url)?
44 .pragma("journal_mode", "WAL")
45 .pragma("busy_timeout", "5000")
46 .pragma("synchronous", "NORMAL")
47 .pragma("foreign_keys", "ON")
48 .create_if_missing(true);
49
50 sqlx::sqlite::SqlitePoolOptions::new()
51 .max_connections(8)
52 .connect_with(opts)
53 .await
54}
55
56#[allow(clippy::needless_pass_by_value)]
58fn to_a2a_error(e: sqlx::Error) -> A2aError {
59 A2aError::internal(format!("sqlite error: {e}"))
60}
61
62impl SqlitePushConfigStore {
63 pub async fn new(url: &str) -> Result<Self, sqlx::Error> {
69 let pool = sqlite_pool(url).await?;
70 Self::from_pool(pool).await
71 }
72
73 pub async fn from_pool(pool: SqlitePool) -> Result<Self, sqlx::Error> {
79 sqlx::query(
80 "CREATE TABLE IF NOT EXISTS push_configs (
81 task_id TEXT NOT NULL,
82 id TEXT NOT NULL,
83 data TEXT NOT NULL,
84 PRIMARY KEY (task_id, id)
85 )",
86 )
87 .execute(&pool)
88 .await?;
89
90 Ok(Self { pool })
91 }
92}
93
94#[allow(clippy::manual_async_fn)]
95impl PushConfigStore for SqlitePushConfigStore {
96 fn set<'a>(
97 &'a self,
98 mut config: TaskPushNotificationConfig,
99 ) -> Pin<Box<dyn Future<Output = A2aResult<TaskPushNotificationConfig>> + Send + 'a>> {
100 Box::pin(async move {
101 let id = config
102 .id
103 .clone()
104 .unwrap_or_else(|| uuid::Uuid::new_v4().to_string());
105 config.id = Some(id.clone());
106
107 let data = serde_json::to_string(&config)
108 .map_err(|e| A2aError::internal(format!("serialize: {e}")))?;
109
110 sqlx::query(
111 "INSERT INTO push_configs (task_id, id, data)
112 VALUES (?1, ?2, ?3)
113 ON CONFLICT(task_id, id) DO UPDATE SET data = excluded.data",
114 )
115 .bind(&config.task_id)
116 .bind(&id)
117 .bind(&data)
118 .execute(&self.pool)
119 .await
120 .map_err(to_a2a_error)?;
121
122 Ok(config)
123 })
124 }
125
126 fn get<'a>(
127 &'a self,
128 task_id: &'a str,
129 id: &'a str,
130 ) -> Pin<Box<dyn Future<Output = A2aResult<Option<TaskPushNotificationConfig>>> + Send + 'a>>
131 {
132 Box::pin(async move {
133 let row: Option<(String,)> =
134 sqlx::query_as("SELECT data FROM push_configs WHERE task_id = ?1 AND id = ?2")
135 .bind(task_id)
136 .bind(id)
137 .fetch_optional(&self.pool)
138 .await
139 .map_err(to_a2a_error)?;
140
141 match row {
142 Some((data,)) => {
143 let config: TaskPushNotificationConfig = serde_json::from_str(&data)
144 .map_err(|e| A2aError::internal(format!("deserialize: {e}")))?;
145 Ok(Some(config))
146 }
147 None => Ok(None),
148 }
149 })
150 }
151
152 fn list<'a>(
153 &'a self,
154 task_id: &'a str,
155 ) -> Pin<Box<dyn Future<Output = A2aResult<Vec<TaskPushNotificationConfig>>> + Send + 'a>> {
156 Box::pin(async move {
157 let rows: Vec<(String,)> =
158 sqlx::query_as("SELECT data FROM push_configs WHERE task_id = ?1")
159 .bind(task_id)
160 .fetch_all(&self.pool)
161 .await
162 .map_err(to_a2a_error)?;
163
164 rows.into_iter()
165 .map(|(data,)| {
166 serde_json::from_str(&data)
167 .map_err(|e| A2aError::internal(format!("deserialize: {e}")))
168 })
169 .collect()
170 })
171 }
172
173 fn delete<'a>(
174 &'a self,
175 task_id: &'a str,
176 id: &'a str,
177 ) -> Pin<Box<dyn Future<Output = A2aResult<()>> + Send + 'a>> {
178 Box::pin(async move {
179 sqlx::query("DELETE FROM push_configs WHERE task_id = ?1 AND id = ?2")
180 .bind(task_id)
181 .bind(id)
182 .execute(&self.pool)
183 .await
184 .map_err(to_a2a_error)?;
185 Ok(())
186 })
187 }
188}
189
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a store backed by an in-memory SQLite database.
    async fn make_store() -> SqlitePushConfigStore {
        SqlitePushConfigStore::new("sqlite::memory:")
            .await
            .expect("failed to create in-memory push config store")
    }

    /// Builds a minimal config with only the task id, optional id and URL set.
    fn make_config(task_id: &str, id: Option<&str>, url: &str) -> TaskPushNotificationConfig {
        TaskPushNotificationConfig {
            tenant: None,
            id: id.map(String::from),
            task_id: task_id.to_string(),
            url: url.to_string(),
            token: None,
            authentication: None,
        }
    }

    #[tokio::test]
    async fn set_assigns_id_when_none() {
        let store = make_store().await;
        let config = make_config("task-1", None, "https://example.com/hook");
        let result = store.set(config).await.expect("set should succeed");
        assert!(
            result.id.is_some(),
            "set should assign an id when none is provided"
        );
    }

    #[tokio::test]
    async fn generated_id_is_persisted() {
        let store = make_store().await;
        let saved = store
            .set(make_config("task-1", None, "https://example.com/hook"))
            .await
            .expect("set should succeed");
        let id = saved.id.clone().expect("set should assign an id");

        let retrieved = store
            .get("task-1", &id)
            .await
            .expect("get should succeed")
            .expect("config should be stored under the generated id");
        assert_eq!(retrieved.id.as_deref(), Some(id.as_str()));
        assert_eq!(retrieved.url, "https://example.com/hook");
    }

    #[tokio::test]
    async fn set_without_id_creates_distinct_configs() {
        let store = make_store().await;
        let first = store
            .set(make_config("task-1", None, "https://a.com"))
            .await
            .unwrap();
        let second = store
            .set(make_config("task-1", None, "https://b.com"))
            .await
            .unwrap();

        assert_ne!(first.id, second.id, "generated ids should differ");
        assert_eq!(
            store.list("task-1").await.unwrap().len(),
            2,
            "both configs should be stored"
        );
    }

    #[tokio::test]
    async fn set_preserves_explicit_id() {
        let store = make_store().await;
        let config = make_config("task-1", Some("my-id"), "https://example.com/hook");
        let result = store.set(config).await.expect("set should succeed");
        assert_eq!(
            result.id.as_deref(),
            Some("my-id"),
            "set should preserve the explicit id"
        );
    }

    #[tokio::test]
    async fn set_then_get_round_trip() {
        let store = make_store().await;
        let config = make_config("task-1", Some("cfg-1"), "https://example.com/hook");
        store.set(config).await.unwrap();

        let retrieved = store.get("task-1", "cfg-1").await.unwrap();
        let retrieved = retrieved.expect("config should exist after set");
        assert_eq!(retrieved.task_id, "task-1");
        assert_eq!(retrieved.url, "https://example.com/hook");
        assert_eq!(retrieved.id.as_deref(), Some("cfg-1"));
    }

    #[tokio::test]
    async fn get_returns_none_for_missing_config() {
        let store = make_store().await;
        let result = store
            .get("no-task", "no-id")
            .await
            .expect("get should succeed");
        assert!(
            result.is_none(),
            "get should return None for a missing config"
        );
    }

    #[tokio::test]
    async fn overwrite_existing_config() {
        let store = make_store().await;
        store
            .set(make_config(
                "task-1",
                Some("cfg-1"),
                "https://example.com/v1",
            ))
            .await
            .unwrap();
        store
            .set(make_config(
                "task-1",
                Some("cfg-1"),
                "https://example.com/v2",
            ))
            .await
            .unwrap();

        let retrieved = store.get("task-1", "cfg-1").await.unwrap().unwrap();
        assert_eq!(
            retrieved.url, "https://example.com/v2",
            "overwrite should update the URL"
        );
    }

    #[tokio::test]
    async fn list_returns_empty_for_unknown_task() {
        let store = make_store().await;
        let configs = store.list("no-such-task").await.unwrap();
        assert!(
            configs.is_empty(),
            "list should return empty vec for unknown task"
        );
    }

    #[tokio::test]
    async fn list_returns_only_configs_for_given_task() {
        let store = make_store().await;
        store
            .set(make_config("task-a", Some("c1"), "https://a.com/1"))
            .await
            .unwrap();
        store
            .set(make_config("task-a", Some("c2"), "https://a.com/2"))
            .await
            .unwrap();
        store
            .set(make_config("task-b", Some("c3"), "https://b.com/1"))
            .await
            .unwrap();

        let a_configs = store.list("task-a").await.unwrap();
        assert_eq!(a_configs.len(), 2, "task-a should have exactly 2 configs");

        let b_configs = store.list("task-b").await.unwrap();
        assert_eq!(b_configs.len(), 1, "task-b should have exactly 1 config");
    }

    #[tokio::test]
    async fn delete_removes_config() {
        let store = make_store().await;
        store
            .set(make_config("task-1", Some("cfg-1"), "https://example.com"))
            .await
            .unwrap();

        store
            .delete("task-1", "cfg-1")
            .await
            .expect("delete should succeed");

        let result = store.get("task-1", "cfg-1").await.unwrap();
        assert!(result.is_none(), "config should be gone after delete");
    }

    #[tokio::test]
    async fn delete_nonexistent_is_ok() {
        let store = make_store().await;
        let result = store.delete("no-task", "no-id").await;
        assert!(
            result.is_ok(),
            "deleting a nonexistent config should not error"
        );
    }

    #[tokio::test]
    async fn delete_does_not_affect_other_configs() {
        let store = make_store().await;
        store
            .set(make_config("task-1", Some("c1"), "https://a.com"))
            .await
            .unwrap();
        store
            .set(make_config("task-1", Some("c2"), "https://b.com"))
            .await
            .unwrap();

        store.delete("task-1", "c1").await.unwrap();

        let remaining = store.list("task-1").await.unwrap();
        assert_eq!(
            remaining.len(),
            1,
            "only the deleted config should be removed"
        );
        assert_eq!(remaining[0].id.as_deref(), Some("c2"));
    }

    #[test]
    fn to_a2a_error_formats_message() {
        let sqlite_err = sqlx::Error::RowNotFound;
        let a2a_err = to_a2a_error(sqlite_err);
        let msg = format!("{a2a_err}");
        assert!(
            msg.contains("sqlite error"),
            "error message should contain 'sqlite error': {msg}"
        );
    }

    #[tokio::test]
    async fn multiple_tasks_independent_configs() {
        let store = make_store().await;
        store
            .set(make_config("task-a", Some("cfg-1"), "https://a.com"))
            .await
            .unwrap();
        store
            .set(make_config("task-b", Some("cfg-1"), "https://b.com"))
            .await
            .unwrap();

        let a = store.get("task-a", "cfg-1").await.unwrap().unwrap();
        assert_eq!(a.url, "https://a.com");

        let b = store.get("task-b", "cfg-1").await.unwrap().unwrap();
        assert_eq!(b.url, "https://b.com");
    }
}