use ahash::AHashMap;
use serde::{Deserialize, Serialize};
use uuid::Uuid;
use crate::a2a::core::task_types::TaskId;
/// Errors produced by push-notification config operations.
#[derive(Debug, thiserror::Error)]
pub enum PushNotificationError {
    /// The caller supplied data that was rejected; `reason` says why.
    #[error("invalid input: {reason}")]
    InvalidInput { reason: String },
}
/// Opaque identifier of a single push-notification config.
///
/// Wraps a [`Uuid`]; fresh identifiers come from [`PushNotificationId::new`].
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct PushNotificationId(Uuid);

impl PushNotificationId {
    /// Generates a new random (UUID v4) identifier.
    pub fn new() -> Self {
        let raw = Uuid::new_v4();
        PushNotificationId(raw)
    }
}
impl Default for PushNotificationId {
fn default() -> Self {
Self::new()
}
}
/// Formats the identifier using the wrapped [`Uuid`]'s `Display` output.
impl std::fmt::Display for PushNotificationId {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Fully-qualified call: a bare `self.0.fmt(f)` only resolves when a
        // formatting trait is imported, and becomes ambiguous if both `Display`
        // and `Debug` are in scope. This always picks `Display`.
        std::fmt::Display::fmt(&self.0, f)
    }
}
impl std::str::FromStr for PushNotificationId {
type Err = uuid::Error;
fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
Ok(Self(s.parse()?))
}
}
/// Authentication material attached to a push-notification config.
///
/// `Debug` is implemented by hand so that `credentials` never ends up in logs,
/// panic messages or assertion output; only `scheme` is printed verbatim.
/// Serialization is unaffected and still includes both fields.
#[derive(Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct PushNotificationAuth {
    /// Authentication scheme name (not validated here).
    pub scheme: String,
    /// Secret credential value; redacted from `Debug` output.
    pub credentials: String,
}

impl std::fmt::Debug for PushNotificationAuth {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("PushNotificationAuth")
            .field("scheme", &self.scheme)
            .field("credentials", &"<redacted>")
            .finish()
    }
}
/// A webhook registration attached to one task.
///
/// `Debug` is implemented by hand so that secrets (the `token` and any
/// authentication credentials) are not written into logs or assertion
/// messages. Serialization is unaffected.
#[derive(Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct PushNotificationConfig {
    /// Identifier of this config.
    pub id: PushNotificationId,
    /// Task this config belongs to.
    pub task_id: TaskId,
    /// Webhook URL (validated by the store when the config is created).
    pub url: String,
    /// Optional token; an empty string means "not set" and is omitted when serialized.
    #[serde(default, skip_serializing_if = "String::is_empty")]
    pub token: String,
    /// Optional authentication details; omitted when serialized if absent.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub authentication: Option<PushNotificationAuth>,
}

impl std::fmt::Debug for PushNotificationConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Keep "was a token supplied?" visible while hiding its value.
        let token = if self.token.is_empty() { "" } else { "<redacted>" };
        // Only the scheme is shown so credentials are never printed, regardless
        // of how `PushNotificationAuth` itself implements `Debug`.
        let auth_scheme = self
            .authentication
            .as_ref()
            .map(|auth| auth.scheme.as_str());
        f.debug_struct("PushNotificationConfig")
            .field("id", &self.id)
            .field("task_id", &self.task_id)
            .field("url", &self.url)
            .field("token", &token)
            .field("authentication_scheme", &auth_scheme)
            .finish()
    }
}
/// Upper bound on how many push-notification configs a single task may hold.
/// `PushNotificationStore::create` rejects requests beyond it.
const MAX_PUSH_CONFIGS_PER_TASK: usize = 16;
/// In-memory registry of push-notification configs, bucketed by task.
///
/// A task's bucket is removed once its last config is deleted.
#[derive(Debug, Default)]
pub struct PushNotificationStore {
    /// Configs per task, kept in insertion order.
    configs: AHashMap<TaskId, Vec<PushNotificationConfig>>,
}
impl PushNotificationStore {
    /// Creates an empty store.
    pub fn new() -> Self {
        Self::default()
    }

    /// Registers a new push-notification config for `task_id` and returns a copy of it.
    ///
    /// The config receives a freshly generated [`PushNotificationId`].
    ///
    /// # Errors
    ///
    /// Returns [`PushNotificationError::InvalidInput`] when `url` fails webhook
    /// validation, or when the task already holds [`MAX_PUSH_CONFIGS_PER_TASK`]
    /// configs. In both cases no config is added.
    pub fn create(
        &mut self,
        task_id: TaskId,
        url: String,
        token: String,
        authentication: Option<PushNotificationAuth>,
    ) -> Result<PushNotificationConfig, PushNotificationError> {
        // Validate before touching the map so a rejected URL never creates a bucket.
        validate_webhook_url(&url)?;
        // One hash lookup serves both the cap check and the insertion. A vacant
        // entry becomes an empty bucket that is filled just below (the cap is
        // non-zero), so `delete`'s "no empty buckets" invariant still holds.
        let bucket = self.configs.entry(task_id).or_default();
        if bucket.len() >= MAX_PUSH_CONFIGS_PER_TASK {
            return Err(PushNotificationError::InvalidInput {
                reason: format!(
                    "task already has the maximum of {MAX_PUSH_CONFIGS_PER_TASK} push notification configs"
                ),
            });
        }
        let cfg = PushNotificationConfig {
            id: PushNotificationId::new(),
            task_id,
            url,
            token,
            authentication,
        };
        bucket.push(cfg.clone());
        Ok(cfg)
    }

    /// Looks up the config `id` belonging to `task_id`.
    ///
    /// Returns `None` if the task has no configs or none with that id.
    pub fn get(
        &self,
        task_id: &TaskId,
        id: &PushNotificationId,
    ) -> Option<&PushNotificationConfig> {
        self.configs.get(task_id)?.iter().find(|c| &c.id == id)
    }

    /// Returns every config registered for `task_id`, in insertion order.
    ///
    /// An unknown task yields an empty slice.
    pub fn list(&self, task_id: &TaskId) -> &[PushNotificationConfig] {
        self.configs.get(task_id).map_or(&[], Vec::as_slice)
    }

    /// Removes the config `id` from `task_id`.
    ///
    /// Returns `true` if a config was removed, `false` if nothing matched.
    /// When the task's last config is removed its bucket is dropped as well.
    pub fn delete(&mut self, task_id: &TaskId, id: &PushNotificationId) -> bool {
        let Some(v) = self.configs.get_mut(task_id) else {
            return false;
        };
        let len_before = v.len();
        v.retain(|c| &c.id != id);
        let removed = v.len() < len_before;
        // Drop empty buckets so the map does not grow with dead task ids.
        if v.is_empty() {
            self.configs.remove(task_id);
        }
        removed
    }
}
/// Adapts the shared SSRF webhook validator to this module's error type.
///
/// The validated target returned on success is discarded; a rejection's
/// `reason` is carried over into [`PushNotificationError::InvalidInput`].
fn validate_webhook_url(url: &str) -> Result<(), PushNotificationError> {
    use crate::a2a::core::ssrf::{self, SsrfRejected};
    match ssrf::validate_webhook_url(url) {
        Ok(_) => Ok(()),
        Err(SsrfRejected { reason }) => Err(PushNotificationError::InvalidInput { reason }),
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Fresh, unrelated task id for each call.
    fn task_id() -> TaskId {
        TaskId::new()
    }

    #[test]
    fn create_and_get_round_trip() {
        let mut store = PushNotificationStore::new();
        let tid = task_id();
        let cfg = store
            .create(
                tid,
                "https://example.com/webhook".to_owned(),
                "tok".to_owned(),
                None,
            )
            .expect("create must succeed");
        let fetched = store.get(&tid, &cfg.id).expect("must find created config");
        assert_eq!(fetched, &cfg, "round-trip must yield identical config");
    }

    #[test]
    fn get_is_scoped_to_task() {
        let mut store = PushNotificationStore::new();
        let tid = task_id();
        let cfg = store
            .create(tid, "https://example.com/webhook".to_owned(), String::new(), None)
            .expect("create must succeed");
        assert!(
            store.get(&task_id(), &cfg.id).is_none(),
            "a config must not be reachable through another task"
        );
    }

    #[test]
    fn create_rejects_non_http_url() {
        let mut store = PushNotificationStore::new();
        let err = store
            .create(
                task_id(),
                "ftp://example.com/x".to_owned(),
                String::new(),
                None,
            )
            .expect_err("non-http url must be rejected");
        assert!(
            matches!(err, PushNotificationError::InvalidInput { ref reason } if reason.contains("http")),
            "expected InvalidInput about http scheme, got: {err:?}"
        );
    }

    #[test]
    fn create_rejects_malformed_url() {
        let mut store = PushNotificationStore::new();
        let err = store
            .create(task_id(), "not a url".to_owned(), String::new(), None)
            .expect_err("invalid url must be rejected");
        assert!(matches!(err, PushNotificationError::InvalidInput { .. }));
    }

    #[test]
    fn list_returns_all_for_task() {
        let mut store = PushNotificationStore::new();
        let tid = task_id();
        store
            .create(tid, "https://a.example/".to_owned(), String::new(), None)
            .unwrap();
        store
            .create(tid, "https://b.example/".to_owned(), String::new(), None)
            .unwrap();
        store
            .create(
                task_id(),
                "https://c.example/".to_owned(),
                String::new(),
                None,
            )
            .unwrap();
        let listed = store.list(&tid);
        assert_eq!(listed.len(), 2, "must list exactly 2 configs for the task");
    }

    #[test]
    fn list_is_empty_for_unknown_task() {
        let store = PushNotificationStore::new();
        assert!(
            store.list(&task_id()).is_empty(),
            "an unknown task must list no configs"
        );
    }

    #[test]
    fn create_rejects_when_per_task_cap_reached() {
        let mut store = PushNotificationStore::new();
        let tid = task_id();
        for i in 0..MAX_PUSH_CONFIGS_PER_TASK {
            store
                .create(tid, format!("https://h{i}.example/"), String::new(), None)
                .expect("create within cap must succeed");
        }
        let err = store
            .create(
                tid,
                "https://overflow.example/".to_owned(),
                String::new(),
                None,
            )
            .expect_err("create past the cap must be rejected");
        assert!(
            matches!(err, PushNotificationError::InvalidInput { ref reason } if reason.contains("maximum")),
            "expected a cap InvalidInput, got: {err:?}"
        );
        store
            .create(
                task_id(),
                "https://other.example/".to_owned(),
                String::new(),
                None,
            )
            .expect("a different task must still accept configs");
    }

    #[test]
    fn cap_frees_up_after_delete() {
        let mut store = PushNotificationStore::new();
        let tid = task_id();
        let mut ids = Vec::with_capacity(MAX_PUSH_CONFIGS_PER_TASK);
        for i in 0..MAX_PUSH_CONFIGS_PER_TASK {
            let cfg = store
                .create(tid, format!("https://h{i}.example/"), String::new(), None)
                .expect("create within cap must succeed");
            ids.push(cfg.id);
        }
        assert!(store.delete(&tid, &ids[0]), "delete must report success");
        store
            .create(tid, "https://again.example/".to_owned(), String::new(), None)
            .expect("a freed slot must accept a new config");
        assert_eq!(store.list(&tid).len(), MAX_PUSH_CONFIGS_PER_TASK);
    }

    #[test]
    fn delete_removes_config_and_returns_true() {
        let mut store = PushNotificationStore::new();
        let tid = task_id();
        let cfg = store
            .create(tid, "https://x.example/".to_owned(), String::new(), None)
            .unwrap();
        assert!(store.delete(&tid, &cfg.id), "delete must report success");
        assert!(
            store.get(&tid, &cfg.id).is_none(),
            "config must be gone after delete"
        );
        assert!(
            store.list(&tid).is_empty(),
            "deleting the last config must leave the task with no configs"
        );
        assert!(
            !store.delete(&tid, &cfg.id),
            "second delete must report no-op"
        );
    }

    #[test]
    fn delete_on_unknown_task_is_noop() {
        let mut store = PushNotificationStore::new();
        assert!(
            !store.delete(&task_id(), &PushNotificationId::new()),
            "deleting from an unknown task must report no-op"
        );
    }

    #[test]
    fn id_display_round_trips_through_from_str() {
        let id = PushNotificationId::new();
        let parsed: PushNotificationId = id
            .to_string()
            .parse()
            .expect("Display output must parse back");
        assert_eq!(parsed, id, "round-trip must preserve the id");
        assert!(
            "not-a-uuid".parse::<PushNotificationId>().is_err(),
            "garbage must not parse as an id"
        );
    }
}