use std::collections::HashMap;
use std::future::Future;
use std::pin::Pin;
use tokio::sync::RwLock;
use crate::error::{A2AError, Result};
use crate::types::{PushNotificationConfig, Task};
/// Asynchronous storage interface for per-task push notification configs.
///
/// Every method returns a boxed `Send` future that borrows `self` and the
/// method arguments for the lifetime `'a`.
pub trait PushNotificationConfigStore: Send + Sync {
/// Stores `config` for `task_id` and resolves to a `PushNotificationConfig`.
///
/// The in-memory store in this file validates the config, fills in a
/// generated id when none was supplied, and resolves to the config as stored.
fn save<'a>(
&'a self,
task_id: &'a str,
config: &'a PushNotificationConfig,
) -> Pin<Box<dyn Future<Output = Result<PushNotificationConfig>> + Send + 'a>>;
/// Looks up the config identified by `config_id` under `task_id`.
///
/// The in-memory store in this file resolves to an `InvalidParams` error
/// when no such config exists.
fn get<'a>(
&'a self,
task_id: &'a str,
config_id: &'a str,
) -> Pin<Box<dyn Future<Output = Result<PushNotificationConfig>> + Send + 'a>>;
/// Resolves to all configs stored for `task_id`.
///
/// The in-memory store in this file resolves to an empty vector for a task
/// it has no configs for.
fn list<'a>(
&'a self,
task_id: &'a str,
) -> Pin<Box<dyn Future<Output = Result<Vec<PushNotificationConfig>>> + Send + 'a>>;
/// Removes the config identified by `config_id` under `task_id`.
///
/// The in-memory store in this file succeeds even when nothing was removed.
fn delete<'a>(
&'a self,
task_id: &'a str,
config_id: &'a str,
) -> Pin<Box<dyn Future<Output = Result<()>> + Send + 'a>>;
/// Removes every config stored for `task_id`.
fn delete_all<'a>(
&'a self,
task_id: &'a str,
) -> Pin<Box<dyn Future<Output = Result<()>> + Send + 'a>>;
}
/// [`PushNotificationConfigStore`] implementation backed by nested hash maps
/// behind a tokio `RwLock`.
#[derive(Debug, Default)]
pub struct InMemoryPushNotificationConfigStore {
// Outer key: task id. Inner key: config id. Inner value: the stored config.
configs: RwLock<HashMap<String, HashMap<String, PushNotificationConfig>>>,
}
impl InMemoryPushNotificationConfigStore {
#[must_use]
pub fn new() -> Self {
Self::default()
}
}
/// In-memory [`PushNotificationConfigStore`], keyed by task id and then by
/// config id.
///
/// The string arguments are borrowed for the lifetime of the returned future
/// rather than copied; an owned key is only allocated in `save`, where the map
/// has to take ownership of it.
impl PushNotificationConfigStore for InMemoryPushNotificationConfigStore {
    /// Validates `config`, assigns a freshly generated UUID v4 id when the id
    /// is missing or empty, stores a copy under `task_id`, and resolves to the
    /// stored value.
    ///
    /// An existing entry with the same id under the same task is overwritten.
    ///
    /// # Errors
    ///
    /// Propagates the error from `validate_push_config`; nothing is stored in
    /// that case.
    fn save<'a>(
        &'a self,
        task_id: &'a str,
        config: &'a PushNotificationConfig,
    ) -> Pin<Box<dyn Future<Output = Result<PushNotificationConfig>> + Send + 'a>> {
        Box::pin(async move {
            // Validate the borrowed config first so a rejected config costs
            // neither a clone nor a key allocation.
            validate_push_config(config)?;

            let mut to_save = config.clone();
            if to_save.id.as_deref().unwrap_or("").is_empty() {
                to_save.id = Some(uuid::Uuid::new_v4().to_string());
            }
            let config_id = to_save.id.clone().unwrap_or_default();
            // Allocate the owned map key before taking the write lock.
            let task_key = task_id.to_owned();

            self.configs
                .write()
                .await
                .entry(task_key)
                .or_default()
                .insert(config_id, to_save.clone());
            Ok(to_save)
        })
    }

    /// Resolves to a clone of the config stored under `task_id` / `config_id`.
    ///
    /// # Errors
    ///
    /// Resolves to [`A2AError::InvalidParams`] when the task or the config id
    /// is unknown.
    fn get<'a>(
        &'a self,
        task_id: &'a str,
        config_id: &'a str,
    ) -> Pin<Box<dyn Future<Output = Result<PushNotificationConfig>> + Send + 'a>> {
        Box::pin(async move {
            let store = self.configs.read().await;
            store
                .get(task_id)
                .and_then(|per_task| per_task.get(config_id))
                .cloned()
                .ok_or_else(|| A2AError::InvalidParams("push config not found".into()))
        })
    }

    /// Resolves to clones of all configs stored for `task_id`.
    ///
    /// An unknown task yields an empty vector. The order follows `HashMap`
    /// iteration order and is therefore unspecified.
    fn list<'a>(
        &'a self,
        task_id: &'a str,
    ) -> Pin<Box<dyn Future<Output = Result<Vec<PushNotificationConfig>>> + Send + 'a>> {
        Box::pin(async move {
            let store = self.configs.read().await;
            Ok(store
                .get(task_id)
                .map(|per_task| per_task.values().cloned().collect())
                .unwrap_or_default())
        })
    }

    /// Removes the config `config_id` of `task_id`; unknown ids are ignored.
    ///
    /// When the last config of a task is removed, the now-empty per-task map
    /// is dropped as well so the outer map does not accumulate empty entries.
    fn delete<'a>(
        &'a self,
        task_id: &'a str,
        config_id: &'a str,
    ) -> Pin<Box<dyn Future<Output = Result<()>> + Send + 'a>> {
        Box::pin(async move {
            let mut store = self.configs.write().await;
            if let Some(per_task) = store.get_mut(task_id) {
                per_task.remove(config_id);
                if per_task.is_empty() {
                    store.remove(task_id);
                }
            }
            Ok(())
        })
    }

    /// Removes every config stored for `task_id`; an unknown task is ignored.
    fn delete_all<'a>(
        &'a self,
        task_id: &'a str,
    ) -> Pin<Box<dyn Future<Output = Result<()>> + Send + 'a>> {
        Box::pin(async move {
            self.configs.write().await.remove(task_id);
            Ok(())
        })
    }
}
/// Checks that a push notification config is acceptable for storage.
///
/// The URL must be non-empty and start with `http://` or `https://`. The
/// scheme is compared ASCII case-insensitively because URI schemes are
/// case-insensitive (RFC 3986, section 3.1), so `HTTPS://example.com/hook`
/// is accepted just like its lowercase form.
///
/// # Errors
///
/// Returns [`A2AError::InvalidParams`] when the URL is empty or does not use
/// the `http` or `https` scheme.
fn validate_push_config(config: &PushNotificationConfig) -> Result<()> {
    // Returns `true` when `url` begins with `prefix`, ignoring ASCII case.
    fn has_prefix_ignore_case(url: &str, prefix: &str) -> bool {
        // `get` yields `None` when the URL is shorter than the prefix or when
        // the cut would fall inside a multi-byte character; neither can match.
        url.get(..prefix.len())
            .is_some_and(|head| head.eq_ignore_ascii_case(prefix))
    }

    let url: &str = &config.url;
    if url.is_empty() {
        return Err(A2AError::InvalidParams(
            "push config URL cannot be empty".into(),
        ));
    }
    if !has_prefix_ignore_case(url, "http://") && !has_prefix_ignore_case(url, "https://") {
        return Err(A2AError::InvalidParams(
            "push config URL must be http or https".into(),
        ));
    }
    Ok(())
}
/// Asynchronous interface for sending a push notification about a task.
pub trait PushSender: Send + Sync {
/// Sends `task` to the destination described by `config`.
///
/// The returned boxed `Send` future borrows `self`, `config` and `task` for
/// the lifetime `'a`.
fn send_push<'a>(
&'a self,
config: &'a PushNotificationConfig,
task: &'a Task,
) -> Pin<Box<dyn Future<Output = Result<()>> + Send + 'a>>;
}
/// [`PushSender`] that delivers tasks with HTTP POST requests through a
/// `reqwest` client.
#[derive(Debug)]
pub struct HttpPushSender {
// Client used for every outgoing request; built in `with_config`.
client: reqwest::Client,
// When `true`, delivery failures are returned as `A2AError::ServerError`;
// when `false`, they are logged with `tracing::error!` and reported as `Ok`.
fail_on_error: bool,
}
/// Options used to build an [`HttpPushSender`].
#[derive(Debug, Clone, Copy)]
pub struct HttpPushSenderConfig {
    /// Timeout handed to the HTTP client builder.
    pub timeout: std::time::Duration,
    /// When `true`, delivery failures are returned to the caller as errors;
    /// when `false`, they are only logged.
    pub fail_on_error: bool,
}

impl Default for HttpPushSenderConfig {
    /// Defaults to a 30-second timeout with failures logged, not returned.
    fn default() -> Self {
        const DEFAULT_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(30);

        Self {
            fail_on_error: false,
            timeout: DEFAULT_TIMEOUT,
        }
    }
}
impl HttpPushSender {
#[must_use]
pub fn new() -> Self {
Self::with_config(HttpPushSenderConfig::default())
}
#[must_use]
pub fn with_config(config: HttpPushSenderConfig) -> Self {
let client = reqwest::Client::builder()
.timeout(config.timeout)
.build()
.unwrap_or_default();
Self {
client,
fail_on_error: config.fail_on_error,
}
}
}
impl Default for HttpPushSender {
fn default() -> Self {
Self::new()
}
}
impl PushSender for HttpPushSender {
fn send_push<'a>(
&'a self,
config: &'a PushNotificationConfig,
task: &'a Task,
) -> Pin<Box<dyn Future<Output = Result<()>> + Send + 'a>> {
Box::pin(self.send_push(config, task))
}
}
impl HttpPushSender {
    /// Serializes `task` to JSON and POSTs it to `config.url`.
    ///
    /// A non-empty `config.token` is sent in the `X-A2A-Notification-Token`
    /// header. When `config.authentication` carries non-empty credentials, an
    /// `Authorization` header is added through `apply_auth`.
    ///
    /// # Errors
    ///
    /// Returns [`A2AError::ServerError`] when the task cannot be serialized.
    /// Delivery failures are returned only when `fail_on_error` is set;
    /// otherwise they are logged and `Ok(())` is returned.
    async fn send_push(&self, config: &PushNotificationConfig, task: &Task) -> Result<()> {
        let payload = serde_json::to_vec(task)
            .map_err(|e| A2AError::ServerError(format!("failed to serialize task: {e}")))?;

        let request = self
            .client
            .post(&config.url)
            .header("Content-Type", "application/json")
            .body(payload);

        // Optional notification token header.
        let request = match &config.token {
            Some(token) if !token.is_empty() => request.header("X-A2A-Notification-Token", token),
            _ => request,
        };

        // Optional Authorization header, only when credentials are present.
        let request = match &config.authentication {
            Some(auth) => match &auth.credentials {
                Some(credentials) if !credentials.is_empty() => {
                    Self::apply_auth(request, &auth.scheme, credentials)
                }
                _ => request,
            },
            None => request,
        };

        let outcome = request.send().await;
        self.handle_push_result(outcome)
    }

    /// Adds an `Authorization` header for the `bearer` and `basic` schemes.
    ///
    /// The scheme is lowercased before matching. The credentials are inserted
    /// verbatim after the scheme prefix. Any other scheme leaves the request
    /// unchanged.
    fn apply_auth(
        req: reqwest::RequestBuilder,
        scheme: &str,
        credentials: &str,
    ) -> reqwest::RequestBuilder {
        let prefix = match scheme.to_lowercase().as_str() {
            "bearer" => "Bearer",
            "basic" => "Basic",
            _ => return req,
        };
        req.header("Authorization", format!("{prefix} {credentials}"))
    }

    /// Converts the outcome of the HTTP call into this sender's result.
    ///
    /// A response with a success status is `Ok(())`. A non-success status or a
    /// transport error is turned into a message and passed to `maybe_fail`.
    fn handle_push_result(
        &self,
        result: std::result::Result<reqwest::Response, reqwest::Error>,
    ) -> Result<()> {
        let failure = match result {
            Ok(resp) => {
                let status = resp.status();
                if status.is_success() {
                    return Ok(());
                }
                format!("push notification endpoint returned non-success status: {status}")
            }
            Err(e) => format!("failed to send push notification: {e}"),
        };
        self.maybe_fail(failure)
    }

    /// Applies the `fail_on_error` policy to a delivery failure message.
    ///
    /// Returns the message as [`A2AError::ServerError`] when `fail_on_error`
    /// is set; otherwise logs it at error level and returns `Ok(())`.
    fn maybe_fail(&self, msg: String) -> Result<()> {
        if !self.fail_on_error {
            tracing::error!("{msg}");
            return Ok(());
        }
        Err(A2AError::ServerError(msg))
    }
}