a2a_protocol_server/push/
tenant_sqlite_config_store.rs1use std::future::Future;
14use std::pin::Pin;
15
16use a2a_protocol_types::error::{A2aError, A2aResult};
17use a2a_protocol_types::push::TaskPushNotificationConfig;
18use sqlx::sqlite::SqlitePool;
19
20use super::config_store::PushConfigStore;
21use crate::store::tenant::TenantContext;
22
23#[derive(Debug, Clone)]
39pub struct TenantAwareSqlitePushConfigStore {
40 pool: SqlitePool,
41}
42
43async fn sqlite_pool(url: &str) -> Result<SqlitePool, sqlx::Error> {
45 use sqlx::sqlite::SqliteConnectOptions;
46 use std::str::FromStr;
47
48 let opts = SqliteConnectOptions::from_str(url)?
49 .pragma("journal_mode", "WAL")
50 .pragma("busy_timeout", "5000")
51 .pragma("synchronous", "NORMAL")
52 .pragma("foreign_keys", "ON")
53 .create_if_missing(true);
54
55 sqlx::sqlite::SqlitePoolOptions::new()
56 .max_connections(8)
57 .connect_with(opts)
58 .await
59}
60
61fn to_a2a_error(e: &sqlx::Error) -> A2aError {
62 A2aError::internal(format!("sqlite error: {e}"))
63}
64
65impl TenantAwareSqlitePushConfigStore {
66 pub async fn new(url: &str) -> Result<Self, sqlx::Error> {
72 let pool = sqlite_pool(url).await?;
73 Self::from_pool(pool).await
74 }
75
76 pub async fn from_pool(pool: SqlitePool) -> Result<Self, sqlx::Error> {
82 sqlx::query(
83 "CREATE TABLE IF NOT EXISTS tenant_push_configs (
84 tenant_id TEXT NOT NULL DEFAULT '',
85 task_id TEXT NOT NULL,
86 id TEXT NOT NULL,
87 data TEXT NOT NULL,
88 PRIMARY KEY (tenant_id, task_id, id)
89 )",
90 )
91 .execute(&pool)
92 .await?;
93
94 Ok(Self { pool })
95 }
96}
97
98#[allow(clippy::manual_async_fn)]
99impl PushConfigStore for TenantAwareSqlitePushConfigStore {
100 fn set<'a>(
101 &'a self,
102 mut config: TaskPushNotificationConfig,
103 ) -> Pin<Box<dyn Future<Output = A2aResult<TaskPushNotificationConfig>> + Send + 'a>> {
104 Box::pin(async move {
105 let tenant = TenantContext::current();
106 let id = config
107 .id
108 .clone()
109 .unwrap_or_else(|| uuid::Uuid::new_v4().to_string());
110 config.id = Some(id.clone());
111
112 let data = serde_json::to_string(&config)
113 .map_err(|e| A2aError::internal(format!("serialize: {e}")))?;
114
115 sqlx::query(
116 "INSERT INTO tenant_push_configs (tenant_id, task_id, id, data)
117 VALUES (?1, ?2, ?3, ?4)
118 ON CONFLICT(tenant_id, task_id, id) DO UPDATE SET data = excluded.data",
119 )
120 .bind(&tenant)
121 .bind(&config.task_id)
122 .bind(&id)
123 .bind(&data)
124 .execute(&self.pool)
125 .await
126 .map_err(|e| to_a2a_error(&e))?;
127
128 Ok(config)
129 })
130 }
131
132 fn get<'a>(
133 &'a self,
134 task_id: &'a str,
135 id: &'a str,
136 ) -> Pin<Box<dyn Future<Output = A2aResult<Option<TaskPushNotificationConfig>>> + Send + 'a>>
137 {
138 Box::pin(async move {
139 let tenant = TenantContext::current();
140 let row: Option<(String,)> = sqlx::query_as(
141 "SELECT data FROM tenant_push_configs WHERE tenant_id = ?1 AND task_id = ?2 AND id = ?3",
142 )
143 .bind(&tenant)
144 .bind(task_id)
145 .bind(id)
146 .fetch_optional(&self.pool)
147 .await
148 .map_err(|e| to_a2a_error(&e))?;
149
150 match row {
151 Some((data,)) => {
152 let config: TaskPushNotificationConfig = serde_json::from_str(&data)
153 .map_err(|e| A2aError::internal(format!("deserialize: {e}")))?;
154 Ok(Some(config))
155 }
156 None => Ok(None),
157 }
158 })
159 }
160
161 fn list<'a>(
162 &'a self,
163 task_id: &'a str,
164 ) -> Pin<Box<dyn Future<Output = A2aResult<Vec<TaskPushNotificationConfig>>> + Send + 'a>> {
165 Box::pin(async move {
166 let tenant = TenantContext::current();
167 let rows: Vec<(String,)> = sqlx::query_as(
168 "SELECT data FROM tenant_push_configs WHERE tenant_id = ?1 AND task_id = ?2",
169 )
170 .bind(&tenant)
171 .bind(task_id)
172 .fetch_all(&self.pool)
173 .await
174 .map_err(|e| to_a2a_error(&e))?;
175
176 rows.into_iter()
177 .map(|(data,)| {
178 serde_json::from_str(&data)
179 .map_err(|e| A2aError::internal(format!("deserialize: {e}")))
180 })
181 .collect()
182 })
183 }
184
185 fn delete<'a>(
186 &'a self,
187 task_id: &'a str,
188 id: &'a str,
189 ) -> Pin<Box<dyn Future<Output = A2aResult<()>> + Send + 'a>> {
190 Box::pin(async move {
191 let tenant = TenantContext::current();
192 sqlx::query(
193 "DELETE FROM tenant_push_configs WHERE tenant_id = ?1 AND task_id = ?2 AND id = ?3",
194 )
195 .bind(&tenant)
196 .bind(task_id)
197 .bind(id)
198 .execute(&self.pool)
199 .await
200 .map_err(|e| to_a2a_error(&e))?;
201 Ok(())
202 })
203 }
204}
205
#[cfg(test)]
mod tests {
    use super::*;
    use a2a_protocol_types::push::TaskPushNotificationConfig;

    /// Creates a store backed by an in-memory SQLite database.
    async fn make_store() -> TenantAwareSqlitePushConfigStore {
        TenantAwareSqlitePushConfigStore::new("sqlite::memory:")
            .await
            .expect("failed to create in-memory tenant push config store")
    }

    /// Builds a minimal config with only the task id, optional config id and URL set.
    fn make_config(task_id: &str, id: Option<&str>, url: &str) -> TaskPushNotificationConfig {
        TaskPushNotificationConfig {
            tenant: None,
            id: id.map(String::from),
            task_id: task_id.to_string(),
            url: url.to_string(),
            token: None,
            authentication: None,
        }
    }

    #[tokio::test]
    async fn set_and_get_within_tenant() {
        let store = make_store().await;
        TenantContext::scope("acme", async {
            store
                .set(make_config("task-1", Some("cfg-1"), "https://example.com"))
                .await
                .unwrap();
            let config = store
                .get("task-1", "cfg-1")
                .await
                .unwrap()
                .expect("config should be retrievable within its tenant");
            assert_eq!(config.url, "https://example.com");
        })
        .await;
    }

    #[tokio::test]
    async fn tenant_isolation_get() {
        let store = make_store().await;
        TenantContext::scope("tenant-a", async {
            store
                .set(make_config("task-1", Some("cfg-1"), "https://a.com"))
                .await
                .unwrap();
        })
        .await;

        TenantContext::scope("tenant-b", async {
            let result = store.get("task-1", "cfg-1").await.unwrap();
            assert!(
                result.is_none(),
                "tenant-b should not see tenant-a's config"
            );
        })
        .await;
    }

    #[tokio::test]
    async fn tenant_isolation_list() {
        let store = make_store().await;
        TenantContext::scope("tenant-a", async {
            store
                .set(make_config("task-1", Some("c1"), "https://a.com/1"))
                .await
                .unwrap();
            store
                .set(make_config("task-1", Some("c2"), "https://a.com/2"))
                .await
                .unwrap();
        })
        .await;

        TenantContext::scope("tenant-b", async {
            store
                .set(make_config("task-1", Some("c3"), "https://b.com/1"))
                .await
                .unwrap();
        })
        .await;

        TenantContext::scope("tenant-a", async {
            let configs = store.list("task-1").await.unwrap();
            assert_eq!(configs.len(), 2, "tenant-a should see only its 2 configs");
        })
        .await;

        TenantContext::scope("tenant-b", async {
            let configs = store.list("task-1").await.unwrap();
            assert_eq!(configs.len(), 1, "tenant-b should see only its 1 config");
        })
        .await;
    }

    #[tokio::test]
    async fn tenant_isolation_delete() {
        let store = make_store().await;
        TenantContext::scope("tenant-a", async {
            store
                .set(make_config("task-1", Some("cfg-1"), "https://a.com"))
                .await
                .unwrap();
        })
        .await;

        TenantContext::scope("tenant-b", async {
            // The key only exists under tenant-a, so this must be a no-op.
            store.delete("task-1", "cfg-1").await.unwrap();
        })
        .await;

        TenantContext::scope("tenant-a", async {
            let config = store.get("task-1", "cfg-1").await.unwrap();
            assert!(
                config.is_some(),
                "tenant-a's config should survive tenant-b's delete"
            );
        })
        .await;
    }

    #[tokio::test]
    async fn delete_within_tenant_removes_config() {
        let store = make_store().await;
        TenantContext::scope("acme", async {
            store
                .set(make_config("task-1", Some("cfg-1"), "https://example.com"))
                .await
                .unwrap();
            store.delete("task-1", "cfg-1").await.unwrap();

            let config = store.get("task-1", "cfg-1").await.unwrap();
            assert!(
                config.is_none(),
                "a deleted config should no longer be retrievable"
            );
            let configs = store.list("task-1").await.unwrap();
            assert!(
                configs.is_empty(),
                "a deleted config should no longer be listed"
            );
        })
        .await;
    }

    #[tokio::test]
    async fn same_keys_different_tenants() {
        let store = make_store().await;
        TenantContext::scope("tenant-a", async {
            store
                .set(make_config("task-1", Some("cfg-1"), "https://a.com"))
                .await
                .unwrap();
        })
        .await;

        TenantContext::scope("tenant-b", async {
            store
                .set(make_config("task-1", Some("cfg-1"), "https://b.com"))
                .await
                .unwrap();
        })
        .await;

        TenantContext::scope("tenant-a", async {
            let config = store.get("task-1", "cfg-1").await.unwrap().unwrap();
            assert_eq!(
                config.url, "https://a.com",
                "tenant-a should get its own config"
            );
        })
        .await;

        TenantContext::scope("tenant-b", async {
            let config = store.get("task-1", "cfg-1").await.unwrap().unwrap();
            assert_eq!(
                config.url, "https://b.com",
                "tenant-b should get its own config"
            );
        })
        .await;
    }

    #[tokio::test]
    async fn overwrite_within_tenant() {
        let store = make_store().await;
        TenantContext::scope("acme", async {
            store
                .set(make_config("task-1", Some("cfg-1"), "https://old.com"))
                .await
                .unwrap();
            store
                .set(make_config("task-1", Some("cfg-1"), "https://new.com"))
                .await
                .unwrap();

            let config = store.get("task-1", "cfg-1").await.unwrap().unwrap();
            assert_eq!(
                config.url, "https://new.com",
                "overwrite should update the URL"
            );
            let configs = store.list("task-1").await.unwrap();
            assert_eq!(configs.len(), 1, "overwrite should not create a second row");
        })
        .await;
    }

    #[tokio::test]
    async fn set_assigns_id_when_none() {
        let store = make_store().await;
        TenantContext::scope("acme", async {
            let config = make_config("task-1", None, "https://example.com");
            let result = store.set(config).await.unwrap();
            let id = result
                .id
                .expect("set should assign an id when none is provided");

            let stored = store.get("task-1", &id).await.unwrap();
            assert!(
                stored.is_some(),
                "the generated id should be usable to fetch the stored config"
            );
        })
        .await;
    }

    #[tokio::test]
    async fn delete_nonexistent_is_ok() {
        let store = make_store().await;
        TenantContext::scope("acme", async {
            let result = store.delete("no-task", "no-id").await;
            assert!(
                result.is_ok(),
                "deleting a nonexistent config should not error"
            );
        })
        .await;
    }

    #[test]
    fn to_a2a_error_formats_message() {
        let sqlite_err = sqlx::Error::RowNotFound;
        let a2a_err = to_a2a_error(&sqlite_err);
        let msg = format!("{a2a_err}");
        assert!(
            msg.contains("sqlite error"),
            "error message should contain 'sqlite error': {msg}"
        );
    }

    #[tokio::test]
    async fn default_tenant_context_uses_empty_string() {
        let store = make_store().await;
        // No `TenantContext::scope` here: this runs under the default tenant.
        store
            .set(make_config("task-1", Some("cfg-1"), "https://default.com"))
            .await
            .unwrap();
        let config = store.get("task-1", "cfg-1").await.unwrap();
        assert!(config.is_some(), "default (empty) tenant should work");

        TenantContext::scope("acme", async {
            let config = store.get("task-1", "cfg-1").await.unwrap();
            assert!(
                config.is_none(),
                "a named tenant should not see the default tenant's config"
            );
        })
        .await;
    }
}
440}