use awaken_protocol_a2a::{ListPushNotificationConfigsResponse, PushNotificationConfig};
use awaken_server_contract::thread::Thread;
use axum::Json;
use axum::extract::{Path, Query, State};
use axum::http::{HeaderMap, StatusCode};
use axum::response::{IntoResponse, Response};
use uuid::Uuid;
use crate::app::{ProtocolModuleState, ProtocolRoutesState};
use super::common::{
ensure_supported_version, load_thread_metadata_projection, parse_page_token,
persist_thread_metadata, trim_to_option,
};
use super::error::A2aError;
use super::push_outbox::enqueue_push_notification;
use super::stream_projector::{InitialStreamEvent, TaskStreamProjector};
use super::task::{
ensure_task_visible, load_task_snapshot, resolve_task, submitted_task, task_context_id,
};
use super::types::{
BLOCKING_POLL_INTERVAL, DEFAULT_PAGE_SIZE, ListPushConfigsQuery, MAX_PAGE_SIZE,
PUSH_CONFIGS_METADATA_KEY, StoredPushConfigs, TaskSnapshot,
};
pub(super) async fn a2a_create_push_config_default(
State(st): State<ProtocolRoutesState>,
Path(task_id): Path<String>,
headers: HeaderMap,
Json(payload): Json<PushNotificationConfig>,
) -> Result<Response, A2aError> {
create_push_config(st, headers, None, task_id, payload)
.await
.map(IntoResponse::into_response)
}
pub(super) async fn a2a_create_push_config_tenant(
State(st): State<ProtocolRoutesState>,
Path((tenant, task_id)): Path<(String, String)>,
headers: HeaderMap,
Json(payload): Json<PushNotificationConfig>,
) -> Result<Response, A2aError> {
create_push_config(st, headers, Some(tenant), task_id, payload)
.await
.map(IntoResponse::into_response)
}
pub(super) async fn a2a_list_push_configs_default(
State(st): State<ProtocolRoutesState>,
Path(task_id): Path<String>,
headers: HeaderMap,
Query(query): Query<ListPushConfigsQuery>,
) -> Result<Json<ListPushNotificationConfigsResponse>, A2aError> {
list_push_configs(st, headers, None, task_id, query).await
}
pub(super) async fn a2a_list_push_configs_tenant(
State(st): State<ProtocolRoutesState>,
Path((tenant, task_id)): Path<(String, String)>,
headers: HeaderMap,
Query(query): Query<ListPushConfigsQuery>,
) -> Result<Json<ListPushNotificationConfigsResponse>, A2aError> {
list_push_configs(st, headers, Some(tenant), task_id, query).await
}
pub(super) async fn a2a_get_push_config_default(
State(st): State<ProtocolRoutesState>,
Path((task_id, config_id)): Path<(String, String)>,
headers: HeaderMap,
) -> Result<Response, A2aError> {
get_push_config(st, headers, None, task_id, config_id)
.await
.map(IntoResponse::into_response)
}
pub(super) async fn a2a_get_push_config_tenant(
State(st): State<ProtocolRoutesState>,
Path((tenant, task_id, config_id)): Path<(String, String, String)>,
headers: HeaderMap,
) -> Result<Response, A2aError> {
get_push_config(st, headers, Some(tenant), task_id, config_id)
.await
.map(IntoResponse::into_response)
}
pub(super) async fn a2a_delete_push_config_default(
State(st): State<ProtocolRoutesState>,
Path((task_id, config_id)): Path<(String, String)>,
headers: HeaderMap,
) -> Result<Response, A2aError> {
delete_push_config(st, headers, None, task_id, config_id).await
}
pub(super) async fn a2a_delete_push_config_tenant(
State(st): State<ProtocolRoutesState>,
Path((tenant, task_id, config_id)): Path<(String, String, String)>,
headers: HeaderMap,
) -> Result<Response, A2aError> {
delete_push_config(st, headers, Some(tenant), task_id, config_id).await
}
async fn create_push_config(
st: ProtocolRoutesState,
headers: HeaderMap,
tenant: Option<String>,
task_id: String,
payload: PushNotificationConfig,
) -> Result<Json<PushNotificationConfig>, A2aError> {
ensure_supported_version(&headers)?;
ensure_task_visible(&st, &task_id, tenant.as_deref()).await?;
let config = normalize_push_config(payload, tenant.as_deref(), &task_id)?;
upsert_push_notification_config(&st, &task_id, tenant.as_deref(), config.clone()).await?;
spawn_push_notification_driver(st, task_id, tenant, config.clone());
Ok(Json(config))
}
async fn list_push_configs(
st: ProtocolRoutesState,
headers: HeaderMap,
tenant: Option<String>,
task_id: String,
query: ListPushConfigsQuery,
) -> Result<Json<ListPushNotificationConfigsResponse>, A2aError> {
ensure_supported_version(&headers)?;
ensure_task_visible(&st, &task_id, tenant.as_deref()).await?;
let page_size = query
.page_size
.unwrap_or(DEFAULT_PAGE_SIZE)
.clamp(1, MAX_PAGE_SIZE);
let offset = parse_page_token(query.page_token.as_deref())?;
let configs = load_push_notification_configs(&st, &task_id, tenant.as_deref()).await?;
let total = configs.len();
let items = configs
.into_iter()
.skip(offset)
.take(page_size)
.collect::<Vec<_>>();
let next_offset = offset + items.len();
Ok(Json(ListPushNotificationConfigsResponse {
configs: items,
next_page_token: if next_offset < total {
next_offset.to_string()
} else {
String::new()
},
}))
}
async fn get_push_config(
st: ProtocolRoutesState,
headers: HeaderMap,
tenant: Option<String>,
task_id: String,
config_id: String,
) -> Result<Json<PushNotificationConfig>, A2aError> {
ensure_supported_version(&headers)?;
ensure_task_visible(&st, &task_id, tenant.as_deref()).await?;
let config = find_push_notification_config(&st, &task_id, tenant.as_deref(), &config_id)
.await?
.ok_or_else(|| A2aError::push_config_not_found(task_id.clone(), config_id.clone()))?;
Ok(Json(config))
}
async fn delete_push_config(
st: ProtocolRoutesState,
headers: HeaderMap,
tenant: Option<String>,
task_id: String,
config_id: String,
) -> Result<Response, A2aError> {
ensure_supported_version(&headers)?;
ensure_task_visible(&st, &task_id, tenant.as_deref()).await?;
let mut configs = load_push_notification_configs(&st, &task_id, None).await?;
let before = configs.len();
configs.retain(|config| {
!(config.id.as_deref() == Some(config_id.as_str())
&& config.agent_id.as_deref() == tenant.as_deref())
});
if configs.len() == before {
return Err(A2aError::push_config_not_found(task_id, config_id));
}
save_push_notification_configs(&st, &task_id, configs).await?;
Ok(StatusCode::NO_CONTENT.into_response())
}
pub(super) fn normalize_push_config(
mut config: PushNotificationConfig,
tenant: Option<&str>,
task_id: &str,
) -> Result<PushNotificationConfig, A2aError> {
let parsed_url = reqwest::Url::parse(&config.url)
.map_err(|err| A2aError::invalid("pushNotificationConfig.url", err.to_string()))?;
if !matches!(parsed_url.scheme(), "http" | "https") {
return Err(A2aError::invalid(
"pushNotificationConfig.url",
"push notification URL must use http or https",
));
}
if let Some(existing_task_id) = trim_to_option(config.task_id.as_deref())
&& existing_task_id != task_id
{
return Err(A2aError::invalid(
"pushNotificationConfig.taskId",
"push notification taskId must match the enclosing task",
));
}
if let Some(existing_tenant) = trim_to_option(config.agent_id.as_deref())
&& tenant != Some(existing_tenant.as_str())
{
return Err(A2aError::invalid(
"pushNotificationConfig.tenant",
"push notification tenant must match the enclosing task tenant",
));
}
if let Some(authentication) = config.authentication.as_ref()
&& authentication.scheme.trim().is_empty()
{
return Err(A2aError::invalid(
"pushNotificationConfig.authentication.scheme",
"authentication scheme must not be empty",
));
}
config.id.get_or_insert_with(|| Uuid::now_v7().to_string());
config.task_id = Some(task_id.to_string());
config.agent_id = tenant.map(ToOwned::to_owned);
Ok(config)
}
pub(super) async fn load_push_notification_configs(
st: &ProtocolRoutesState,
task_id: &str,
tenant: Option<&str>,
) -> Result<Vec<PushNotificationConfig>, A2aError> {
let Some(task) = resolve_task(st, task_id).await? else {
return Ok(Vec::new());
};
let Some(thread) = st
.run
.store()
.load_thread(&task.thread_id)
.await
.map_err(|e| A2aError::Internal(e.to_string()))?
else {
return Ok(Vec::new());
};
let mut configs = load_thread_push_notification_configs(&thread, task_id)?;
if let Some(tenant) = tenant {
configs.retain(|config| config.agent_id.as_deref() == Some(tenant));
}
configs.sort_by(|left, right| left.id.cmp(&right.id));
Ok(configs)
}
async fn find_push_notification_config(
st: &ProtocolRoutesState,
task_id: &str,
tenant: Option<&str>,
config_id: &str,
) -> Result<Option<PushNotificationConfig>, A2aError> {
Ok(load_push_notification_configs(st, task_id, tenant)
.await?
.into_iter()
.find(|config| config.id.as_deref() == Some(config_id)))
}
async fn save_push_notification_configs(
st: &ProtocolRoutesState,
task_id: &str,
configs: Vec<PushNotificationConfig>,
) -> Result<(), A2aError> {
let Some(task) = resolve_task(st, task_id).await? else {
return Err(A2aError::task_not_found(task_id.to_string()));
};
let thread_id = task.thread_id;
let (exists, thread) = load_thread_metadata_projection(st, &thread_id).await?;
save_thread_push_notification_configs(st, &thread_id, exists, thread, task_id, configs).await
}
fn upsert_into(
configs: &mut Vec<PushNotificationConfig>,
tenant: Option<&str>,
config: PushNotificationConfig,
) {
if let Some(position) = configs
.iter()
.position(|existing| existing.id == config.id && existing.agent_id.as_deref() == tenant)
{
configs[position] = config;
} else {
configs.push(config);
}
}
async fn upsert_push_notification_config(
st: &ProtocolRoutesState,
task_id: &str,
tenant: Option<&str>,
config: PushNotificationConfig,
) -> Result<(), A2aError> {
let mut configs = load_push_notification_configs(st, task_id, None).await?;
upsert_into(&mut configs, tenant, config);
save_push_notification_configs(st, task_id, configs).await
}
pub(super) async fn upsert_push_notification_config_for_thread(
st: &ProtocolRoutesState,
thread_id: &str,
task_id: &str,
tenant: Option<&str>,
config: PushNotificationConfig,
) -> Result<(), A2aError> {
let (exists, thread) = load_thread_metadata_projection(st, thread_id).await?;
let mut configs = load_thread_push_notification_configs(&thread, task_id)?;
upsert_into(&mut configs, tenant, config);
save_thread_push_notification_configs(st, thread_id, exists, thread, task_id, configs).await
}
fn load_thread_push_notification_configs(
thread: &Thread,
task_id: &str,
) -> Result<Vec<PushNotificationConfig>, A2aError> {
let Some(value) = thread.metadata.custom.get(PUSH_CONFIGS_METADATA_KEY) else {
return Ok(Vec::new());
};
let stored = decode_push_configs_metadata(value.clone(), task_id)?;
Ok(stored.tasks.get(task_id).cloned().unwrap_or_default())
}
async fn save_thread_push_notification_configs(
st: &ProtocolRoutesState,
thread_id: &str,
exists: bool,
mut thread: Thread,
task_id: &str,
configs: Vec<PushNotificationConfig>,
) -> Result<(), A2aError> {
let mut stored = match thread.metadata.custom.remove(PUSH_CONFIGS_METADATA_KEY) {
Some(value) => decode_push_configs_metadata(value, task_id)?,
None => StoredPushConfigs::default(),
};
if configs.is_empty() {
stored.tasks.remove(task_id);
} else {
stored.tasks.insert(task_id.to_string(), configs);
}
if stored.tasks.is_empty() {
thread.metadata.custom.remove(PUSH_CONFIGS_METADATA_KEY);
} else {
thread.metadata.custom.insert(
PUSH_CONFIGS_METADATA_KEY.to_string(),
serde_json::to_value(stored).map_err(|e| A2aError::Internal(e.to_string()))?,
);
}
persist_thread_metadata(st, thread_id, exists, thread).await?;
Ok(())
}
fn decode_push_configs_metadata(
value: serde_json::Value,
task_id: &str,
) -> Result<StoredPushConfigs, A2aError> {
match serde_json::from_value::<StoredPushConfigs>(value.clone()) {
Ok(stored) => Ok(stored),
Err(stored_error) => {
let legacy_configs = serde_json::from_value::<Vec<PushNotificationConfig>>(value)
.map_err(|legacy_error| {
A2aError::Internal(format!(
"corrupt A2A push config metadata at {PUSH_CONFIGS_METADATA_KEY}: \
stored format error: {stored_error}; legacy format error: {legacy_error}"
))
})?;
let mut stored = StoredPushConfigs::default();
if !legacy_configs.is_empty() {
stored.tasks.insert(task_id.to_string(), legacy_configs);
}
Ok(stored)
}
}
}
struct PushDriverGuard {
protocol: ProtocolModuleState,
task_id: String,
tenant: Option<String>,
config_id: String,
}
impl Drop for PushDriverGuard {
fn drop(&mut self) {
self.protocol.unregister_a2a_push_driver(
&self.task_id,
self.tenant.as_deref(),
&self.config_id,
);
}
}
pub(super) fn spawn_push_notification_driver(
st: ProtocolRoutesState,
task_id: String,
tenant: Option<String>,
config: PushNotificationConfig,
) {
let config_id = config.id.clone().unwrap_or_default();
if !st
.protocol
.register_a2a_push_driver(&task_id, tenant.as_deref(), &config_id)
{
return;
}
tokio::spawn(async move {
let _guard = PushDriverGuard {
protocol: st.protocol.clone(),
task_id: task_id.clone(),
tenant: tenant.clone(),
config_id: config_id.clone(),
};
if let Err(err) = drive_push_notification(st, task_id, tenant, config_id).await {
tracing::warn!(error = ?err, "A2A push notification driver stopped with error");
}
});
}
async fn drive_push_notification(
st: ProtocolRoutesState,
task_id: String,
tenant: Option<String>,
config_id: String,
) -> Result<(), A2aError> {
let outbox = crate::protocol_replay_state::a2a_push_webhook_outbox_for_buffers(
&st.protocol.replay_buffers,
)
.ok_or_else(|| {
A2aError::Internal("A2A push notification outbox relay is not configured".to_string())
})?;
let mut projector = TaskStreamProjector::new(InitialStreamEvent::StatusUpdate);
loop {
let Some(config) =
find_push_notification_config(&st, &task_id, tenant.as_deref(), &config_id).await?
else {
break;
};
let snapshot = load_task_snapshot(&st, &task_id, tenant.as_deref(), usize::MAX, true)
.await?
.unwrap_or(TaskSnapshot {
task: submitted_task(
&task_id,
&task_context_id(&st, &task_id)
.await
.unwrap_or_else(|_| task_id.clone()),
tenant.as_deref(),
),
updated_at_ms: 0,
current_agent_id: tenant.clone(),
});
for response in projector.project(&snapshot) {
enqueue_push_notification(outbox.as_ref(), &config, &response)
.await
.map_err(|error| A2aError::Internal(error.to_string()))?;
if let Err(error) =
crate::protocol_replay_state::tick_a2a_push_webhook_outbox_for_buffers(
&st.protocol.replay_buffers,
)
.await
{
tracing::warn!(
error = %error,
"A2A push notification outbox relay tick failed"
);
}
}
if snapshot.task.status.state.is_terminal() || snapshot.task.status.state.is_interrupted() {
break;
}
tokio::time::sleep(BLOCKING_POLL_INTERVAL).await;
}
Ok(())
}
#[cfg(test)]
mod tests {
use serde_json::json;
use super::*;
fn config(id: &str, tenant: Option<&str>, url: &str) -> PushNotificationConfig {
PushNotificationConfig {
agent_id: tenant.map(ToOwned::to_owned),
id: Some(id.to_string()),
task_id: Some("task-1".to_string()),
url: url.to_string(),
token: None,
authentication: None,
}
}
#[test]
fn upsert_into_is_scoped_per_tenant() {
let mut configs = Vec::new();
upsert_into(
&mut configs,
Some("a"),
config("shared", Some("a"), "http://a/1"),
);
upsert_into(
&mut configs,
Some("b"),
config("shared", Some("b"), "http://b/1"),
);
assert_eq!(configs.len(), 2);
upsert_into(
&mut configs,
Some("a"),
config("shared", Some("a"), "http://a/2"),
);
assert_eq!(configs.len(), 2);
let a = configs
.iter()
.find(|c| c.agent_id.as_deref() == Some("a"))
.unwrap();
assert_eq!(a.url, "http://a/2");
let b = configs
.iter()
.find(|c| c.agent_id.as_deref() == Some("b"))
.unwrap();
assert_eq!(b.url, "http://b/1");
}
#[test]
fn decode_push_configs_metadata_rejects_malformed_state() {
let error = decode_push_configs_metadata(json!({"tasks": []}), "task-1")
.expect_err("malformed push config metadata must fail closed");
match error {
A2aError::Internal(message) => {
assert!(message.contains(PUSH_CONFIGS_METADATA_KEY));
assert!(message.contains("stored format error"));
assert!(message.contains("legacy format error"));
}
other => panic!("expected internal corruption error, got {other:?}"),
}
}
#[test]
fn decode_push_configs_metadata_preserves_legacy_task_list() {
let stored =
decode_push_configs_metadata(json!([config("cfg-1", None, "http://a/1")]), "task-1")
.expect("legacy push config list must remain readable");
assert_eq!(stored.tasks["task-1"][0].id.as_deref(), Some("cfg-1"));
}
#[test]
fn push_driver_guard_releases_key_on_drop() {
let protocol = ProtocolModuleState::new();
assert!(protocol.register_a2a_push_driver("task-1", Some("a"), "cfg-1"));
assert!(!protocol.register_a2a_push_driver("task-1", Some("a"), "cfg-1"));
let guard = PushDriverGuard {
protocol: protocol.clone(),
task_id: "task-1".to_string(),
tenant: Some("a".to_string()),
config_id: "cfg-1".to_string(),
};
drop(guard);
assert!(protocol.register_a2a_push_driver("task-1", Some("a"), "cfg-1"));
}
}