a2a_rs/port/
streaming_handler.rs1#[cfg(feature = "server")]
4use async_trait::async_trait;
5use futures::Stream;
6use std::pin::Pin;
7
8use crate::domain::{A2AError, TaskArtifactUpdateEvent, TaskStatusUpdateEvent};
9
/// Receiver of streaming updates carrying a payload of type `T`.
///
/// Only [`Subscriber::on_update`] has to be implemented; the error and
/// completion callbacks ship with default implementations.
#[cfg(feature = "server")]
#[async_trait]
pub trait Subscriber<T>: Send + Sync {
    /// Handles a single `update` delivered to this subscriber.
    ///
    /// # Errors
    ///
    /// Implementations return an [`A2AError`] when the update cannot be handled.
    async fn on_update(&self, update: T) -> Result<(), A2AError>;

    /// Handles an `error` reported to this subscriber.
    ///
    /// The default implementation prints the error to standard error and
    /// reports success, so the error is not propagated any further.
    async fn on_error(&self, error: A2AError) -> Result<(), A2AError> {
        eprintln!("Subscription error: {}", error);
        Ok(())
    }

    /// Completion callback (presumably invoked once no further updates will
    /// arrive — confirm against the implementors).
    ///
    /// The default implementation does nothing and returns `Ok(())`.
    async fn on_complete(&self) -> Result<(), A2AError> {
        Ok(())
    }
}
30
31pub trait StreamingHandler {
33 fn add_status_subscriber(
35 &self,
36 task_id: &str,
37 subscriber: Box<dyn Subscriber<TaskStatusUpdateEvent> + Send + Sync>,
38 ) -> Result<String, A2AError>; fn add_artifact_subscriber(
42 &self,
43 task_id: &str,
44 subscriber: Box<dyn Subscriber<TaskArtifactUpdateEvent> + Send + Sync>,
45 ) -> Result<String, A2AError>; fn remove_subscription(&self, subscription_id: &str) -> Result<(), A2AError>;
49
50 fn remove_task_subscribers(&self, task_id: &str) -> Result<(), A2AError>;
52
53 fn get_subscriber_count(&self, task_id: &str) -> Result<usize, A2AError>;
55
56 fn has_subscribers(&self, task_id: &str) -> Result<bool, A2AError> {
58 let count = self.get_subscriber_count(task_id)?;
59 Ok(count > 0)
60 }
61}
62
/// Asynchronous port for managing task subscribers, broadcasting updates and
/// exposing them as streams.
#[cfg(feature = "server")]
#[async_trait]
pub trait AsyncStreamingHandler: Send + Sync {
    /// Registers `subscriber` for status updates of the task `task_id`.
    ///
    /// Returns a `String` on success (presumably the subscription id accepted
    /// by [`AsyncStreamingHandler::remove_subscription`] — confirm with implementors).
    async fn add_status_subscriber(
        &self,
        task_id: &str,
        subscriber: Box<dyn Subscriber<TaskStatusUpdateEvent> + Send + Sync>,
    ) -> Result<String, A2AError>;

    /// Registers `subscriber` for artifact updates of the task `task_id`.
    ///
    /// Returns a `String` on success (presumably the subscription id accepted
    /// by [`AsyncStreamingHandler::remove_subscription`] — confirm with implementors).
    async fn add_artifact_subscriber(
        &self,
        task_id: &str,
        subscriber: Box<dyn Subscriber<TaskArtifactUpdateEvent> + Send + Sync>,
    ) -> Result<String, A2AError>;

    /// Removes the subscription identified by `subscription_id`.
    async fn remove_subscription(&self, subscription_id: &str) -> Result<(), A2AError>;

    /// Removes the subscribers attached to the task `task_id`.
    async fn remove_task_subscribers(&self, task_id: &str) -> Result<(), A2AError>;

    /// Reports how many subscribers are attached to the task `task_id`.
    async fn get_subscriber_count(&self, task_id: &str) -> Result<usize, A2AError>;

    /// Reports whether the task `task_id` has at least one subscriber.
    ///
    /// The default implementation is derived from
    /// [`AsyncStreamingHandler::get_subscriber_count`] and forwards its error unchanged.
    async fn has_subscribers(&self, task_id: &str) -> Result<bool, A2AError> {
        self.get_subscriber_count(task_id)
            .await
            .map(|subscribers| subscribers > 0)
    }

    /// Broadcasts the status `update` for the task `task_id`.
    async fn broadcast_status_update(
        &self,
        task_id: &str,
        update: TaskStatusUpdateEvent,
    ) -> Result<(), A2AError>;

    /// Broadcasts the artifact `update` for the task `task_id`.
    async fn broadcast_artifact_update(
        &self,
        task_id: &str,
        update: TaskArtifactUpdateEvent,
    ) -> Result<(), A2AError>;

    /// Opens a stream of status update results for the task `task_id`.
    async fn status_update_stream(
        &self,
        task_id: &str,
    ) -> Result<
        Pin<Box<dyn Stream<Item = Result<TaskStatusUpdateEvent, A2AError>> + Send>>,
        A2AError,
    >;

    /// Opens a stream of artifact update results for the task `task_id`.
    async fn artifact_update_stream(
        &self,
        task_id: &str,
    ) -> Result<
        Pin<Box<dyn Stream<Item = Result<TaskArtifactUpdateEvent, A2AError>> + Send>>,
        A2AError,
    >;

    /// Opens a stream of [`UpdateEvent`] results for the task `task_id`, i.e. a
    /// single stream whose items may be either status or artifact updates.
    async fn combined_update_stream(
        &self,
        task_id: &str,
    ) -> Result<Pin<Box<dyn Stream<Item = Result<UpdateEvent, A2AError>> + Send>>, A2AError>;

    /// Checks that `task_id` is usable for streaming.
    ///
    /// # Errors
    ///
    /// The default implementation returns [`A2AError::ValidationError`] for the
    /// `task_id` field when the id is empty or consists only of whitespace.
    async fn validate_streaming_params(&self, task_id: &str) -> Result<(), A2AError> {
        let has_content = !task_id.trim().is_empty();
        if has_content {
            return Ok(());
        }

        Err(A2AError::ValidationError {
            field: String::from("task_id"),
            message: String::from("Task ID cannot be empty for streaming"),
        })
    }

    /// Validates `task_id` and then opens the combined update stream for it.
    ///
    /// # Errors
    ///
    /// Forwards any error from
    /// [`AsyncStreamingHandler::validate_streaming_params`] (in which case no
    /// stream is opened) or from
    /// [`AsyncStreamingHandler::combined_update_stream`].
    async fn start_task_streaming(
        &self,
        task_id: &str,
    ) -> Result<Pin<Box<dyn Stream<Item = Result<UpdateEvent, A2AError>> + Send>>, A2AError> {
        // Validation runs first so an invalid id never reaches the stream setup.
        self.validate_streaming_params(task_id).await?;

        let stream = self.combined_update_stream(task_id).await;
        stream
    }

    /// Stops streaming for the task `task_id`.
    ///
    /// The default implementation delegates entirely to
    /// [`AsyncStreamingHandler::remove_task_subscribers`].
    async fn stop_task_streaming(&self, task_id: &str) -> Result<(), A2AError> {
        let removal = self.remove_task_subscribers(task_id).await;
        removal
    }
}
156
157#[derive(Debug, Clone)]
159pub enum UpdateEvent {
160 StatusUpdate(TaskStatusUpdateEvent),
161 ArtifactUpdate(TaskArtifactUpdateEvent),
162}
163
164impl UpdateEvent {
165 #[inline]
167 pub fn task_id(&self) -> &str {
168 match self {
169 UpdateEvent::StatusUpdate(event) => &event.task_id,
170 UpdateEvent::ArtifactUpdate(event) => &event.task_id,
171 }
172 }
173
174 #[inline]
176 pub fn context_id(&self) -> &str {
177 match self {
178 UpdateEvent::StatusUpdate(event) => &event.context_id,
179 UpdateEvent::ArtifactUpdate(event) => &event.context_id,
180 }
181 }
182
183 #[inline]
185 pub fn is_final(&self) -> bool {
186 match self {
187 UpdateEvent::StatusUpdate(event) => event.final_,
188 UpdateEvent::ArtifactUpdate(event) => event.last_chunk.unwrap_or(false),
189 }
190 }
191}