use async_trait::async_trait;
use crate::domain::{
A2AError, Message, Task, TaskArtifactUpdateEvent, TaskId, TaskState, TaskStatusUpdateEvent,
};
use crate::port::{AsyncPushNotifier, AsyncStreamingHandler, AsyncTaskLifecycle};
/// Accessor trait for types that can hand out an [`AsyncTaskLifecycle`] port.
///
/// It is one of the supertraits of [`TaskStatusBroadcast`], whose default
/// methods call `lifecycle()` to perform the task mutation.
pub trait HasTaskLifecycle {
/// Returns the task lifecycle port as a trait object borrowed from `self`.
fn lifecycle(&self) -> &dyn AsyncTaskLifecycle;
}
/// Accessor trait for types that can hand out an [`AsyncStreamingHandler`] port.
///
/// It is one of the supertraits of [`TaskStatusBroadcast`], whose default
/// methods call `streaming()` to broadcast status and artifact updates.
pub trait HasStreaming {
/// Returns the streaming handler as a trait object borrowed from `self`.
fn streaming(&self) -> &dyn AsyncStreamingHandler;
}
/// Accessor trait for types that can hand out an [`AsyncPushNotifier`] port.
///
/// It is one of the supertraits of [`TaskStatusBroadcast`], whose default
/// methods call `push_notifier()` after a streaming broadcast has succeeded.
pub trait HasPushNotifier {
/// Returns the push notifier as a trait object borrowed from `self`.
fn push_notifier(&self) -> &dyn AsyncPushNotifier;
}
/// Mixin that pairs a task mutation with the notifications announcing it.
///
/// Every method has a default body built solely from the three accessor
/// supertraits, so a type only has to expose its ports to gain this behaviour.
///
/// Failure handling is asymmetric: an error from the lifecycle port or from
/// the streaming handler is returned to the caller, while an error from the
/// push notifier is swallowed (and logged when the `tracing` feature is on).
#[async_trait]
pub trait TaskStatusBroadcast:
    HasTaskLifecycle + HasStreaming + HasPushNotifier + Send + Sync
{
    /// Moves task `id` to `state` (attaching the optional `message`) through
    /// the lifecycle port, then announces the resulting status.
    ///
    /// Returns the task as reported by the lifecycle port.
    ///
    /// # Errors
    ///
    /// Propagates the error from `update_status` or from the streaming
    /// broadcast. In the second case the lifecycle call has already
    /// succeeded and nothing in this method undoes it.
    async fn update_and_broadcast(
        &self,
        id: &TaskId,
        state: TaskState,
        message: Option<Message>,
    ) -> Result<Task, A2AError> {
        let updated = self.lifecycle().update_status(id, state, message).await?;
        self.broadcast_current_status(id, &updated).await?;
        Ok(updated)
    }

    /// Cancels task `id` through the lifecycle port, then announces the
    /// resulting status.
    ///
    /// # Errors
    ///
    /// Propagates the error from `cancel` or from the streaming broadcast;
    /// a streaming failure does not undo the cancellation.
    async fn cancel_and_broadcast(&self, id: &TaskId) -> Result<Task, A2AError> {
        let canceled = self.lifecycle().cancel(id).await?;
        self.broadcast_current_status(id, &canceled).await?;
        Ok(canceled)
    }

    /// Sends an artifact update for task `id` to the streaming handler and
    /// then, best-effort, to the push notifier.
    ///
    /// # Errors
    ///
    /// Returns the streaming handler's error; in that case the push notifier
    /// is not invoked. Push failures never surface here.
    async fn broadcast_artifact(
        &self,
        id: &TaskId,
        event: TaskArtifactUpdateEvent,
    ) -> Result<(), A2AError> {
        // The streaming handler takes the event by value while the push
        // notifier borrows it afterwards, hence the copy.
        let streamed_copy = event.clone();
        self.streaming()
            .broadcast_artifact_update(id.as_str(), streamed_copy)
            .await?;

        self.notify_push_artifact(id, &event).await;
        Ok(())
    }

    /// Builds a `status-update` event from `task` and announces it: first to
    /// the streaming handler (errors propagate), then to the push notifier
    /// (errors are swallowed).
    ///
    /// The streaming/push channel is keyed by `id`, whereas the event payload
    /// carries `task.id` and `task.context_id`.
    #[doc(hidden)]
    async fn broadcast_current_status(&self, id: &TaskId, task: &Task) -> Result<(), A2AError> {
        // A task without a status set yields the default status value.
        let status = task.status.clone().into_option().unwrap_or_default();
        let event = TaskStatusUpdateEvent {
            task_id: task.id.clone(),
            context_id: task.context_id.clone(),
            kind: String::from("status-update"),
            status,
            metadata: None,
        };

        let streamed_copy = event.clone();
        self.streaming()
            .broadcast_status_update(id.as_str(), streamed_copy)
            .await?;

        self.notify_push_status(id, &event).await;
        Ok(())
    }

    /// Forwards a status event to the push notifier, never failing: an error
    /// is dropped, with a warning logged when the `tracing` feature is on.
    #[doc(hidden)]
    async fn notify_push_status(&self, id: &TaskId, event: &TaskStatusUpdateEvent) {
        let delivery = self.push_notifier().notify_status(id.as_str(), event).await;
        // The underscore keeps the binding warning-free when `tracing` is off.
        if let Err(_cause) = delivery {
            #[cfg(feature = "tracing")]
            tracing::warn!(task_id = %id.as_str(), error = %_cause, "push status notification failed");
        }
    }

    /// Forwards an artifact event to the push notifier, never failing: an
    /// error is dropped, with a warning logged when the `tracing` feature is on.
    #[doc(hidden)]
    async fn notify_push_artifact(&self, id: &TaskId, event: &TaskArtifactUpdateEvent) {
        let delivery = self
            .push_notifier()
            .notify_artifact(id.as_str(), event)
            .await;
        // The underscore keeps the binding warning-free when `tracing` is off.
        if let Err(_cause) = delivery {
            #[cfg(feature = "tracing")]
            tracing::warn!(task_id = %id.as_str(), error = %_cause, "push artifact notification failed");
        }
    }
}
// Blanket implementation: any `Send + Sync` type exposing the three accessor
// traits receives `TaskStatusBroadcast` with its default method bodies.
// `?Sized` lets unsized types qualify as well.
impl<T: ?Sized> TaskStatusBroadcast for T
where
    T: HasTaskLifecycle + HasStreaming + HasPushNotifier + Send + Sync,
{
}
#[cfg(test)]
mod tests {
    //! Tests for the `TaskStatusBroadcast` mixin, run against the in-memory
    //! storage and streaming adapters with a no-op push notifier.

    use super::*;
    use crate::adapter::storage::InMemoryTaskStorage;
    use crate::adapter::streaming::InMemoryStreamingHandler;
    use crate::port::NoopPushNotifier;
    use crate::port::streaming_handler::Subscriber;
    use std::sync::{Arc, Mutex};

    /// Minimal host type wiring the three ports the mixin needs.
    struct BroadcastRig {
        store: Arc<InMemoryTaskStorage>,
        streaming: InMemoryStreamingHandler,
        push: NoopPushNotifier,
    }

    impl HasTaskLifecycle for BroadcastRig {
        fn lifecycle(&self) -> &dyn AsyncTaskLifecycle {
            self.store.as_ref()
        }
    }

    impl HasStreaming for BroadcastRig {
        fn streaming(&self) -> &dyn AsyncStreamingHandler {
            &self.streaming
        }
    }

    impl HasPushNotifier for BroadcastRig {
        fn push_notifier(&self) -> &dyn AsyncPushNotifier {
            &self.push
        }
    }

    /// Status subscriber that records the state of every event it receives.
    /// Clones share the same underlying list.
    #[derive(Clone, Default)]
    struct Recorder {
        states: Arc<Mutex<Vec<::buffa::EnumValue<TaskState>>>>,
    }

    #[async_trait]
    impl Subscriber<TaskStatusUpdateEvent> for Recorder {
        async fn on_update(&self, update: TaskStatusUpdateEvent) -> Result<(), A2AError> {
            self.states.lock().unwrap().push(update.status.state);
            Ok(())
        }
    }

    /// Builds a rig around `store` with a fresh streaming handler.
    fn rig(store: Arc<InMemoryTaskStorage>) -> BroadcastRig {
        BroadcastRig {
            store,
            streaming: InMemoryStreamingHandler::new(),
            push: NoopPushNotifier,
        }
    }

    #[tokio::test]
    async fn update_and_broadcast_persists_then_announces() {
        let store = Arc::new(InMemoryTaskStorage::new());
        let id = TaskId::try_from("task-1").unwrap();
        let ctx = crate::domain::ContextId::try_from("ctx-1").unwrap();
        store.create(&id, &ctx).await.unwrap();
        // Bare storage call: puts the task in `Working` without any broadcast.
        store
            .update_status(&id, TaskState::Working, None)
            .await
            .unwrap();
        let rig = rig(store);
        // Subscribe before the mutation so the announcement half of this
        // test's contract is actually observed.
        let recorder = Recorder::default();
        rig.streaming
            .add_status_subscriber(id.as_str(), Box::new(recorder.clone()))
            .await
            .unwrap();
        let task = rig
            .update_and_broadcast(&id, TaskState::Completed, None)
            .await
            .unwrap();
        assert_eq!(task.status.state, TaskState::Completed);
        assert_eq!(
            *recorder.states.lock().unwrap(),
            vec![::buffa::EnumValue::from(TaskState::Completed)],
            "update_and_broadcast must announce the new status exactly once"
        );
    }

    #[tokio::test]
    async fn bare_update_status_does_not_broadcast() {
        let store = Arc::new(InMemoryTaskStorage::new());
        let id = TaskId::try_from("task-1").unwrap();
        let ctx = crate::domain::ContextId::try_from("ctx-1").unwrap();
        // Standalone handler: the storage is constructed without it, so any
        // event reaching the recorder would mean storage broadcasts on its own.
        let streaming = InMemoryStreamingHandler::new();
        let recorder = Recorder::default();
        streaming
            .add_status_subscriber(id.as_str(), Box::new(recorder.clone()))
            .await
            .unwrap();
        store.create(&id, &ctx).await.unwrap();
        store
            .update_status(&id, TaskState::Working, None)
            .await
            .unwrap();
        store.cancel(&id).await.unwrap();
        assert!(
            recorder.states.lock().unwrap().is_empty(),
            "storage mutators must not self-broadcast"
        );
    }

    #[tokio::test]
    async fn mixin_announces_each_mutation_once() {
        let store = Arc::new(InMemoryTaskStorage::new());
        let id = TaskId::try_from("task-1").unwrap();
        let ctx = crate::domain::ContextId::try_from("ctx-1").unwrap();
        store.create(&id, &ctx).await.unwrap();
        let rig = rig(store);
        let recorder = Recorder::default();
        rig.streaming
            .add_status_subscriber(id.as_str(), Box::new(recorder.clone()))
            .await
            .unwrap();
        rig.update_and_broadcast(&id, TaskState::Working, None)
            .await
            .unwrap();
        rig.cancel_and_broadcast(&id).await.unwrap();
        // One event per mixin call, in call order.
        assert_eq!(
            *recorder.states.lock().unwrap(),
            vec![
                ::buffa::EnumValue::from(TaskState::Working),
                ::buffa::EnumValue::from(TaskState::Canceled),
            ],
        );
    }
}