1use async_trait::async_trait;
4use futures::Stream;
5use std::pin::Pin;
6
7use crate::domain::core::task::TaskStateExt;
8use crate::domain::{A2AError, TaskArtifactUpdateEvent, TaskStatusUpdateEvent};
9
10#[async_trait]
12pub trait Subscriber<T>: Send + Sync {
13 async fn on_update(&self, update: T) -> Result<(), A2AError>;
15
16 async fn on_error(&self, error: A2AError) -> Result<(), A2AError> {
18 eprintln!("Subscription error: {}", error);
20 Ok(())
21 }
22
23 async fn on_complete(&self) -> Result<(), A2AError> {
25 Ok(())
27 }
28}
29
30#[async_trait]
31pub trait AsyncStreamingHandler: Send + Sync {
33 async fn add_status_subscriber(
35 &self,
36 task_id: &str,
37 subscriber: Box<dyn Subscriber<TaskStatusUpdateEvent> + Send + Sync>,
38 ) -> Result<String, A2AError>; async fn add_artifact_subscriber(
42 &self,
43 task_id: &str,
44 subscriber: Box<dyn Subscriber<TaskArtifactUpdateEvent> + Send + Sync>,
45 ) -> Result<String, A2AError>; async fn remove_subscription(&self, subscription_id: &str) -> Result<(), A2AError>;
49
50 async fn remove_task_subscribers(&self, task_id: &str) -> Result<(), A2AError>;
52
53 async fn get_subscriber_count(&self, task_id: &str) -> Result<usize, A2AError>;
55
56 async fn has_subscribers(&self, task_id: &str) -> Result<bool, A2AError> {
58 let count = self.get_subscriber_count(task_id).await?;
59 Ok(count > 0)
60 }
61
62 async fn broadcast_status_update(
64 &self,
65 task_id: &str,
66 update: TaskStatusUpdateEvent,
67 ) -> Result<(), A2AError>;
68
69 async fn broadcast_artifact_update(
71 &self,
72 task_id: &str,
73 update: TaskArtifactUpdateEvent,
74 ) -> Result<(), A2AError>;
75
76 async fn status_update_stream(
78 &self,
79 task_id: &str,
80 ) -> Result<Pin<Box<dyn Stream<Item = Result<TaskStatusUpdateEvent, A2AError>> + Send>>, A2AError>;
81
82 async fn artifact_update_stream(
84 &self,
85 task_id: &str,
86 ) -> Result<
87 Pin<Box<dyn Stream<Item = Result<TaskArtifactUpdateEvent, A2AError>> + Send>>,
88 A2AError,
89 >;
90
91 async fn combined_update_stream(
98 &self,
99 task_id: &str,
100 from_event_id: Option<u64>,
101 ) -> Result<Pin<Box<dyn Stream<Item = Result<SeqEvent, A2AError>> + Send>>, A2AError>;
102
103 async fn validate_streaming_params(&self, task_id: &str) -> Result<(), A2AError> {
105 if task_id.trim().is_empty() {
106 return Err(A2AError::ValidationError {
107 field: "task_id".to_string(),
108 message: "Task ID cannot be empty for streaming".to_string(),
109 });
110 }
111 Ok(())
112 }
113
114 async fn start_task_streaming(
121 &self,
122 task_id: &str,
123 from_event_id: Option<u64>,
124 ) -> Result<Pin<Box<dyn Stream<Item = Result<SeqEvent, A2AError>> + Send>>, A2AError> {
125 self.validate_streaming_params(task_id).await?;
126 self.combined_update_stream(task_id, from_event_id).await
127 }
128
129 async fn stop_task_streaming(&self, task_id: &str) -> Result<(), A2AError> {
131 self.remove_task_subscribers(task_id).await
132 }
133}
134
135#[async_trait]
142impl AsyncStreamingHandler for std::sync::Arc<dyn AsyncStreamingHandler> {
143 async fn add_status_subscriber(
144 &self,
145 task_id: &str,
146 subscriber: Box<dyn Subscriber<TaskStatusUpdateEvent> + Send + Sync>,
147 ) -> Result<String, A2AError> {
148 (**self).add_status_subscriber(task_id, subscriber).await
149 }
150
151 async fn add_artifact_subscriber(
152 &self,
153 task_id: &str,
154 subscriber: Box<dyn Subscriber<TaskArtifactUpdateEvent> + Send + Sync>,
155 ) -> Result<String, A2AError> {
156 (**self).add_artifact_subscriber(task_id, subscriber).await
157 }
158
159 async fn remove_subscription(&self, subscription_id: &str) -> Result<(), A2AError> {
160 (**self).remove_subscription(subscription_id).await
161 }
162
163 async fn remove_task_subscribers(&self, task_id: &str) -> Result<(), A2AError> {
164 (**self).remove_task_subscribers(task_id).await
165 }
166
167 async fn get_subscriber_count(&self, task_id: &str) -> Result<usize, A2AError> {
168 (**self).get_subscriber_count(task_id).await
169 }
170
171 async fn broadcast_status_update(
172 &self,
173 task_id: &str,
174 update: TaskStatusUpdateEvent,
175 ) -> Result<(), A2AError> {
176 (**self).broadcast_status_update(task_id, update).await
177 }
178
179 async fn broadcast_artifact_update(
180 &self,
181 task_id: &str,
182 update: TaskArtifactUpdateEvent,
183 ) -> Result<(), A2AError> {
184 (**self).broadcast_artifact_update(task_id, update).await
185 }
186
187 async fn status_update_stream(
188 &self,
189 task_id: &str,
190 ) -> Result<Pin<Box<dyn Stream<Item = Result<TaskStatusUpdateEvent, A2AError>> + Send>>, A2AError>
191 {
192 (**self).status_update_stream(task_id).await
193 }
194
195 async fn artifact_update_stream(
196 &self,
197 task_id: &str,
198 ) -> Result<
199 Pin<Box<dyn Stream<Item = Result<TaskArtifactUpdateEvent, A2AError>> + Send>>,
200 A2AError,
201 > {
202 (**self).artifact_update_stream(task_id).await
203 }
204
205 async fn combined_update_stream(
206 &self,
207 task_id: &str,
208 from_event_id: Option<u64>,
209 ) -> Result<Pin<Box<dyn Stream<Item = Result<SeqEvent, A2AError>> + Send>>, A2AError> {
210 (**self)
211 .combined_update_stream(task_id, from_event_id)
212 .await
213 }
214}
215
216#[derive(Debug, Clone)]
229pub struct SeqEvent {
230 pub id: u64,
233 pub event: UpdateEvent,
235}
236
237impl SeqEvent {
238 #[inline]
240 pub fn new(id: u64, event: UpdateEvent) -> Self {
241 Self { id, event }
242 }
243}
244
245#[derive(Debug, Clone)]
247pub enum UpdateEvent {
248 StatusUpdate(TaskStatusUpdateEvent),
249 ArtifactUpdate(TaskArtifactUpdateEvent),
250}
251
252impl UpdateEvent {
253 #[inline]
255 pub fn task_id(&self) -> &str {
256 match self {
257 UpdateEvent::StatusUpdate(event) => &event.task_id,
258 UpdateEvent::ArtifactUpdate(event) => &event.task_id,
259 }
260 }
261
262 #[inline]
264 pub fn context_id(&self) -> &str {
265 match self {
266 UpdateEvent::StatusUpdate(event) => &event.context_id,
267 UpdateEvent::ArtifactUpdate(event) => &event.context_id,
268 }
269 }
270
271 #[inline]
273 pub fn is_final(&self) -> bool {
274 match self {
275 UpdateEvent::StatusUpdate(event) => event.status.state.is_terminal(),
276 UpdateEvent::ArtifactUpdate(event) => event.last_chunk.unwrap_or(false),
277 }
278 }
279}