use async_trait::async_trait;
use futures::Stream;
use std::pin::Pin;
use crate::domain::core::task::TaskStateExt;
use crate::domain::{A2AError, TaskArtifactUpdateEvent, TaskStatusUpdateEvent};
/// Receiver of updates of type `T`.
///
/// Implementors must be `Send + Sync`. Only [`Subscriber::on_update`] is
/// required; the other two callbacks have default bodies.
#[async_trait]
pub trait Subscriber<T>: Send + Sync {
/// Handles a single update of type `T`.
///
/// # Errors
/// Returns an [`A2AError`] when the implementor fails to process the update.
async fn on_update(&self, update: T) -> Result<(), A2AError>;
/// Handles an error reported to this subscriber.
///
/// The default body writes `Subscription error: <error>` to stderr and
/// returns `Ok(())`, i.e. the error is reported but not propagated.
async fn on_error(&self, error: A2AError) -> Result<(), A2AError> {
eprintln!("Subscription error: {}", error);
Ok(())
}
/// Completion callback; the default body does nothing and returns `Ok(())`.
///
/// NOTE(review): presumably invoked once when the subscription ends —
/// confirm against the handler implementation that drives subscribers.
async fn on_complete(&self) -> Result<(), A2AError> {
Ok(())
}
}
#[async_trait]
pub trait AsyncStreamingHandler: Send + Sync {
async fn add_status_subscriber(
&self,
task_id: &str,
subscriber: Box<dyn Subscriber<TaskStatusUpdateEvent> + Send + Sync>,
) -> Result<String, A2AError>;
async fn add_artifact_subscriber(
&self,
task_id: &str,
subscriber: Box<dyn Subscriber<TaskArtifactUpdateEvent> + Send + Sync>,
) -> Result<String, A2AError>;
async fn remove_subscription(&self, subscription_id: &str) -> Result<(), A2AError>;
async fn remove_task_subscribers(&self, task_id: &str) -> Result<(), A2AError>;
async fn get_subscriber_count(&self, task_id: &str) -> Result<usize, A2AError>;
async fn has_subscribers(&self, task_id: &str) -> Result<bool, A2AError> {
let count = self.get_subscriber_count(task_id).await?;
Ok(count > 0)
}
async fn broadcast_status_update(
&self,
task_id: &str,
update: TaskStatusUpdateEvent,
) -> Result<(), A2AError>;
async fn broadcast_artifact_update(
&self,
task_id: &str,
update: TaskArtifactUpdateEvent,
) -> Result<(), A2AError>;
async fn status_update_stream(
&self,
task_id: &str,
) -> Result<Pin<Box<dyn Stream<Item = Result<TaskStatusUpdateEvent, A2AError>> + Send>>, A2AError>;
async fn artifact_update_stream(
&self,
task_id: &str,
) -> Result<
Pin<Box<dyn Stream<Item = Result<TaskArtifactUpdateEvent, A2AError>> + Send>>,
A2AError,
>;
async fn combined_update_stream(
&self,
task_id: &str,
from_event_id: Option<u64>,
) -> Result<Pin<Box<dyn Stream<Item = Result<SeqEvent, A2AError>> + Send>>, A2AError>;
async fn validate_streaming_params(&self, task_id: &str) -> Result<(), A2AError> {
if task_id.trim().is_empty() {
return Err(A2AError::ValidationError {
field: "task_id".to_string(),
message: "Task ID cannot be empty for streaming".to_string(),
});
}
Ok(())
}
async fn start_task_streaming(
&self,
task_id: &str,
from_event_id: Option<u64>,
) -> Result<Pin<Box<dyn Stream<Item = Result<SeqEvent, A2AError>> + Send>>, A2AError> {
self.validate_streaming_params(task_id).await?;
self.combined_update_stream(task_id, from_event_id).await
}
async fn stop_task_streaming(&self, task_id: &str) -> Result<(), A2AError> {
self.remove_task_subscribers(task_id).await
}
}
/// Forwards every [`AsyncStreamingHandler`] method to the handler behind the
/// `Arc`, so a shared `Arc<dyn AsyncStreamingHandler>` can itself be used
/// wherever the trait is required.
///
/// Provided (default) methods are forwarded as well. If they were left out,
/// calls made through the `Arc` would run the trait's default bodies and
/// bypass any override supplied by the wrapped handler.
#[async_trait]
impl AsyncStreamingHandler for std::sync::Arc<dyn AsyncStreamingHandler> {
    /// Delegates to the wrapped handler's `add_status_subscriber`.
    async fn add_status_subscriber(
        &self,
        task_id: &str,
        subscriber: Box<dyn Subscriber<TaskStatusUpdateEvent> + Send + Sync>,
    ) -> Result<String, A2AError> {
        (**self).add_status_subscriber(task_id, subscriber).await
    }

    /// Delegates to the wrapped handler's `add_artifact_subscriber`.
    async fn add_artifact_subscriber(
        &self,
        task_id: &str,
        subscriber: Box<dyn Subscriber<TaskArtifactUpdateEvent> + Send + Sync>,
    ) -> Result<String, A2AError> {
        (**self).add_artifact_subscriber(task_id, subscriber).await
    }

    /// Delegates to the wrapped handler's `remove_subscription`.
    async fn remove_subscription(&self, subscription_id: &str) -> Result<(), A2AError> {
        (**self).remove_subscription(subscription_id).await
    }

    /// Delegates to the wrapped handler's `remove_task_subscribers`.
    async fn remove_task_subscribers(&self, task_id: &str) -> Result<(), A2AError> {
        (**self).remove_task_subscribers(task_id).await
    }

    /// Delegates to the wrapped handler's `get_subscriber_count`.
    async fn get_subscriber_count(&self, task_id: &str) -> Result<usize, A2AError> {
        (**self).get_subscriber_count(task_id).await
    }

    /// Delegates to the wrapped handler's `has_subscribers`, honoring any
    /// override it provides.
    async fn has_subscribers(&self, task_id: &str) -> Result<bool, A2AError> {
        (**self).has_subscribers(task_id).await
    }

    /// Delegates to the wrapped handler's `broadcast_status_update`.
    async fn broadcast_status_update(
        &self,
        task_id: &str,
        update: TaskStatusUpdateEvent,
    ) -> Result<(), A2AError> {
        (**self).broadcast_status_update(task_id, update).await
    }

    /// Delegates to the wrapped handler's `broadcast_artifact_update`.
    async fn broadcast_artifact_update(
        &self,
        task_id: &str,
        update: TaskArtifactUpdateEvent,
    ) -> Result<(), A2AError> {
        (**self).broadcast_artifact_update(task_id, update).await
    }

    /// Delegates to the wrapped handler's `status_update_stream`.
    async fn status_update_stream(
        &self,
        task_id: &str,
    ) -> Result<Pin<Box<dyn Stream<Item = Result<TaskStatusUpdateEvent, A2AError>> + Send>>, A2AError>
    {
        (**self).status_update_stream(task_id).await
    }

    /// Delegates to the wrapped handler's `artifact_update_stream`.
    async fn artifact_update_stream(
        &self,
        task_id: &str,
    ) -> Result<
        Pin<Box<dyn Stream<Item = Result<TaskArtifactUpdateEvent, A2AError>> + Send>>,
        A2AError,
    > {
        (**self).artifact_update_stream(task_id).await
    }

    /// Delegates to the wrapped handler's `combined_update_stream`.
    async fn combined_update_stream(
        &self,
        task_id: &str,
        from_event_id: Option<u64>,
    ) -> Result<Pin<Box<dyn Stream<Item = Result<SeqEvent, A2AError>> + Send>>, A2AError> {
        (**self)
            .combined_update_stream(task_id, from_event_id)
            .await
    }

    /// Delegates to the wrapped handler's `validate_streaming_params`,
    /// honoring any override it provides.
    async fn validate_streaming_params(&self, task_id: &str) -> Result<(), A2AError> {
        (**self).validate_streaming_params(task_id).await
    }

    /// Delegates to the wrapped handler's `start_task_streaming`, honoring
    /// any override it provides.
    async fn start_task_streaming(
        &self,
        task_id: &str,
        from_event_id: Option<u64>,
    ) -> Result<Pin<Box<dyn Stream<Item = Result<SeqEvent, A2AError>> + Send>>, A2AError> {
        (**self).start_task_streaming(task_id, from_event_id).await
    }

    /// Delegates to the wrapped handler's `stop_task_streaming`, honoring
    /// any override it provides.
    async fn stop_task_streaming(&self, task_id: &str) -> Result<(), A2AError> {
        (**self).stop_task_streaming(task_id).await
    }
}
/// An [`UpdateEvent`] paired with a numeric id.
///
/// NOTE(review): the id is presumably the event's sequence number within a
/// stream (see `from_event_id` on the streaming methods) — confirm with the
/// producer of these events.
#[derive(Debug, Clone)]
pub struct SeqEvent {
    /// Numeric id attached to the event.
    pub id: u64,
    /// The wrapped status or artifact update.
    pub event: UpdateEvent,
}

impl SeqEvent {
    /// Builds a `SeqEvent` from an id and the event it labels.
    #[inline]
    pub fn new(id: u64, event: UpdateEvent) -> Self {
        SeqEvent { id, event }
    }
}
/// Either kind of task update that can travel on a combined stream.
#[derive(Debug, Clone)]
pub enum UpdateEvent {
    /// A task status update.
    StatusUpdate(TaskStatusUpdateEvent),
    /// A task artifact update.
    ArtifactUpdate(TaskArtifactUpdateEvent),
}

impl UpdateEvent {
    /// Returns the `task_id` of the wrapped event, whichever variant it is.
    #[inline]
    pub fn task_id(&self) -> &str {
        match self {
            Self::StatusUpdate(status) => &status.task_id,
            Self::ArtifactUpdate(artifact) => &artifact.task_id,
        }
    }

    /// Returns the `context_id` of the wrapped event, whichever variant it is.
    #[inline]
    pub fn context_id(&self) -> &str {
        match self {
            Self::StatusUpdate(status) => &status.context_id,
            Self::ArtifactUpdate(artifact) => &artifact.context_id,
        }
    }

    /// Reports whether this event is considered final.
    ///
    /// * Status update: final when `status.state.is_terminal()` is `true`.
    /// * Artifact update: final only when `last_chunk` is `Some(true)`; a
    ///   missing flag counts as not final.
    #[inline]
    pub fn is_final(&self) -> bool {
        match self {
            Self::StatusUpdate(status) => status.status.state.is_terminal(),
            Self::ArtifactUpdate(artifact) => matches!(artifact.last_chunk, Some(true)),
        }
    }
}