Skip to main content

a2a_rs/port/
streaming_handler.rs

1//! Streaming and real-time update handling port definitions
2
3use async_trait::async_trait;
4use futures::Stream;
5use std::pin::Pin;
6
7use crate::domain::core::task::TaskStateExt;
8use crate::domain::{A2AError, TaskArtifactUpdateEvent, TaskStatusUpdateEvent};
9
10/// A trait for subscribing to real-time updates
11#[async_trait]
12pub trait Subscriber<T>: Send + Sync {
13    /// Handle an update
14    async fn on_update(&self, update: T) -> Result<(), A2AError>;
15
16    /// Handle subscription errors
17    async fn on_error(&self, error: A2AError) -> Result<(), A2AError> {
18        // Default implementation - log error but don't propagate
19        eprintln!("Subscription error: {}", error);
20        Ok(())
21    }
22
23    /// Handle subscription completion
24    async fn on_complete(&self) -> Result<(), A2AError> {
25        // Default implementation - no-op
26        Ok(())
27    }
28}
29
30#[async_trait]
31/// An async trait for managing streaming connections and real-time updates
32pub trait AsyncStreamingHandler: Send + Sync {
33    /// Add a status update subscriber for a task
34    async fn add_status_subscriber(
35        &self,
36        task_id: &str,
37        subscriber: Box<dyn Subscriber<TaskStatusUpdateEvent> + Send + Sync>,
38    ) -> Result<String, A2AError>; // Returns subscription ID
39
40    /// Add an artifact update subscriber for a task
41    async fn add_artifact_subscriber(
42        &self,
43        task_id: &str,
44        subscriber: Box<dyn Subscriber<TaskArtifactUpdateEvent> + Send + Sync>,
45    ) -> Result<String, A2AError>; // Returns subscription ID
46
47    /// Remove a specific subscription
48    async fn remove_subscription(&self, subscription_id: &str) -> Result<(), A2AError>;
49
50    /// Remove all subscribers for a task
51    async fn remove_task_subscribers(&self, task_id: &str) -> Result<(), A2AError>;
52
53    /// Get the number of active subscribers for a task
54    async fn get_subscriber_count(&self, task_id: &str) -> Result<usize, A2AError>;
55
56    /// Check if a task has any active subscribers
57    async fn has_subscribers(&self, task_id: &str) -> Result<bool, A2AError> {
58        let count = self.get_subscriber_count(task_id).await?;
59        Ok(count > 0)
60    }
61
62    /// Broadcast a status update to all subscribers of a task
63    async fn broadcast_status_update(
64        &self,
65        task_id: &str,
66        update: TaskStatusUpdateEvent,
67    ) -> Result<(), A2AError>;
68
69    /// Broadcast an artifact update to all subscribers of a task
70    async fn broadcast_artifact_update(
71        &self,
72        task_id: &str,
73        update: TaskArtifactUpdateEvent,
74    ) -> Result<(), A2AError>;
75
76    /// Create a stream of status updates for a task
77    async fn status_update_stream(
78        &self,
79        task_id: &str,
80    ) -> Result<Pin<Box<dyn Stream<Item = Result<TaskStatusUpdateEvent, A2AError>> + Send>>, A2AError>;
81
82    /// Create a stream of artifact updates for a task
83    async fn artifact_update_stream(
84        &self,
85        task_id: &str,
86    ) -> Result<
87        Pin<Box<dyn Stream<Item = Result<TaskArtifactUpdateEvent, A2AError>> + Send>>,
88        A2AError,
89    >;
90
91    /// Create a combined stream of all updates for a task.
92    ///
93    /// Each yielded [`SeqEvent`] carries a per-task monotonic id so a client can
94    /// resume after a disconnect. When `from_event_id` is `Some(n)`, the handler
95    /// first replays any buffered events with id `> n` (best-effort, bounded by
96    /// the handler's replay buffer) before streaming live updates.
97    async fn combined_update_stream(
98        &self,
99        task_id: &str,
100        from_event_id: Option<u64>,
101    ) -> Result<Pin<Box<dyn Stream<Item = Result<SeqEvent, A2AError>> + Send>>, A2AError>;
102
103    /// Validate streaming parameters
104    async fn validate_streaming_params(&self, task_id: &str) -> Result<(), A2AError> {
105        if task_id.trim().is_empty() {
106            return Err(A2AError::ValidationError {
107                field: "task_id".to_string(),
108                message: "Task ID cannot be empty for streaming".to_string(),
109            });
110        }
111        Ok(())
112    }
113
114    /// Start streaming for a task with automatic cleanup.
115    ///
116    /// `from_event_id` is forwarded to [`combined_update_stream`] for
117    /// Last-Event-ID resumption.
118    ///
119    /// [`combined_update_stream`]: AsyncStreamingHandler::combined_update_stream
120    async fn start_task_streaming(
121        &self,
122        task_id: &str,
123        from_event_id: Option<u64>,
124    ) -> Result<Pin<Box<dyn Stream<Item = Result<SeqEvent, A2AError>> + Send>>, A2AError> {
125        self.validate_streaming_params(task_id).await?;
126        self.combined_update_stream(task_id, from_event_id).await
127    }
128
129    /// Stop all streaming for a task
130    async fn stop_task_streaming(&self, task_id: &str) -> Result<(), A2AError> {
131        self.remove_task_subscribers(task_id).await
132    }
133}
134
135/// Forwarding impl so a type-erased `Arc<dyn AsyncStreamingHandler>` can itself
136/// be passed wherever an `impl AsyncStreamingHandler` is expected (e.g.
137/// `TaskService::with_streaming_handler`). This lets a single shared streaming
138/// backend be injected into both a message handler and a transport adapter
139/// without naming its concrete type. Only the required methods are forwarded;
140/// the trait's default methods ride along on top of them.
141#[async_trait]
142impl AsyncStreamingHandler for std::sync::Arc<dyn AsyncStreamingHandler> {
143    async fn add_status_subscriber(
144        &self,
145        task_id: &str,
146        subscriber: Box<dyn Subscriber<TaskStatusUpdateEvent> + Send + Sync>,
147    ) -> Result<String, A2AError> {
148        (**self).add_status_subscriber(task_id, subscriber).await
149    }
150
151    async fn add_artifact_subscriber(
152        &self,
153        task_id: &str,
154        subscriber: Box<dyn Subscriber<TaskArtifactUpdateEvent> + Send + Sync>,
155    ) -> Result<String, A2AError> {
156        (**self).add_artifact_subscriber(task_id, subscriber).await
157    }
158
159    async fn remove_subscription(&self, subscription_id: &str) -> Result<(), A2AError> {
160        (**self).remove_subscription(subscription_id).await
161    }
162
163    async fn remove_task_subscribers(&self, task_id: &str) -> Result<(), A2AError> {
164        (**self).remove_task_subscribers(task_id).await
165    }
166
167    async fn get_subscriber_count(&self, task_id: &str) -> Result<usize, A2AError> {
168        (**self).get_subscriber_count(task_id).await
169    }
170
171    async fn broadcast_status_update(
172        &self,
173        task_id: &str,
174        update: TaskStatusUpdateEvent,
175    ) -> Result<(), A2AError> {
176        (**self).broadcast_status_update(task_id, update).await
177    }
178
179    async fn broadcast_artifact_update(
180        &self,
181        task_id: &str,
182        update: TaskArtifactUpdateEvent,
183    ) -> Result<(), A2AError> {
184        (**self).broadcast_artifact_update(task_id, update).await
185    }
186
187    async fn status_update_stream(
188        &self,
189        task_id: &str,
190    ) -> Result<Pin<Box<dyn Stream<Item = Result<TaskStatusUpdateEvent, A2AError>> + Send>>, A2AError>
191    {
192        (**self).status_update_stream(task_id).await
193    }
194
195    async fn artifact_update_stream(
196        &self,
197        task_id: &str,
198    ) -> Result<
199        Pin<Box<dyn Stream<Item = Result<TaskArtifactUpdateEvent, A2AError>> + Send>>,
200        A2AError,
201    > {
202        (**self).artifact_update_stream(task_id).await
203    }
204
205    async fn combined_update_stream(
206        &self,
207        task_id: &str,
208        from_event_id: Option<u64>,
209    ) -> Result<Pin<Box<dyn Stream<Item = Result<SeqEvent, A2AError>> + Send>>, A2AError> {
210        (**self)
211            .combined_update_stream(task_id, from_event_id)
212            .await
213    }
214}
215
216/// A streamed [`UpdateEvent`] tagged with a per-task monotonic id.
217///
218/// The id is assigned by the streaming handler when the event is broadcast and
219/// is surfaced to clients as the SSE `id:` field. On reconnect a client echoes
220/// the last id it saw via `Last-Event-ID`, and the handler replays buffered
221/// events with a greater id (see
222/// [`combined_update_stream`](AsyncStreamingHandler::combined_update_stream)).
223///
224/// This id/`Last-Event-ID` resumption is an a2a-rs enhancement on top of the
225/// W3C SSE standard, **not** part of the A2A v1.0 spec. Emitting the `id:` field
226/// is inert for spec clients (they read only the event payload), so it does not
227/// affect interop.
228#[derive(Debug, Clone)]
229pub struct SeqEvent {
230    /// Per-task monotonic event id (starts at 1; `0` is reserved for the
231    /// initial task snapshot, which carries no replayable id).
232    pub id: u64,
233    /// The update payload.
234    pub event: UpdateEvent,
235}
236
237impl SeqEvent {
238    /// Construct a sequenced event.
239    #[inline]
240    pub fn new(id: u64, event: UpdateEvent) -> Self {
241        Self { id, event }
242    }
243}
244
245/// Union type for different kinds of updates that can be streamed
246#[derive(Debug, Clone)]
247pub enum UpdateEvent {
248    StatusUpdate(TaskStatusUpdateEvent),
249    ArtifactUpdate(TaskArtifactUpdateEvent),
250}
251
252impl UpdateEvent {
253    /// Get the task ID from the update event
254    #[inline]
255    pub fn task_id(&self) -> &str {
256        match self {
257            UpdateEvent::StatusUpdate(event) => &event.task_id,
258            UpdateEvent::ArtifactUpdate(event) => &event.task_id,
259        }
260    }
261
262    /// Get the context ID from the update event
263    #[inline]
264    pub fn context_id(&self) -> &str {
265        match self {
266            UpdateEvent::StatusUpdate(event) => &event.context_id,
267            UpdateEvent::ArtifactUpdate(event) => &event.context_id,
268        }
269    }
270
271    /// Check if this is a final update
272    #[inline]
273    pub fn is_final(&self) -> bool {
274        match self {
275            UpdateEvent::StatusUpdate(event) => event.status.state.is_terminal(),
276            UpdateEvent::ArtifactUpdate(event) => event.last_chunk.unwrap_or(false),
277        }
278    }
279}