a2a_rs/adapter/streaming/
in_memory.rs1use std::collections::HashMap;
24use std::collections::VecDeque;
25use std::pin::Pin;
26use std::sync::Arc;
27
28use async_trait::async_trait;
29use futures::{Stream, StreamExt};
30use tokio::sync::Mutex;
31use tokio::sync::broadcast;
32
33use crate::domain::{A2AError, TaskArtifactUpdateEvent, TaskStatusUpdateEvent};
34use crate::port::AsyncStreamingHandler;
35use crate::port::streaming_handler::{SeqEvent, Subscriber, UpdateEvent};
36
37type StatusSubscribers = Vec<Box<dyn Subscriber<TaskStatusUpdateEvent> + Send + Sync>>;
38type ArtifactSubscribers = Vec<Box<dyn Subscriber<TaskArtifactUpdateEvent> + Send + Sync>>;
39
/// Maximum number of recent events each task retains for resuming streams.
const RING_CAPACITY: usize = 256;
/// Per-task broadcast channel capacity; live readers further behind than this lag.
const CHANNEL_CAPACITY: usize = 256;
43
/// Per-task streaming state: live fan-out, a bounded replay log, and callbacks.
struct TaskChannel {
    /// Broadcast sender feeding every live combined stream opened for this task.
    sender: broadcast::Sender<SeqEvent>,
    /// Id of the most recently published event (`0` before the first publish);
    /// `publish` increments it before stamping, so ids start at 1.
    next_id: u64,
    /// Most recent published events, oldest first, capped at `RING_CAPACITY`;
    /// used to replay the tail when a stream resumes from an earlier id.
    buffer: VecDeque<SeqEvent>,
    /// Callback subscribers for status updates.
    status: StatusSubscribers,
    /// Callback subscribers for artifact updates.
    artifacts: ArtifactSubscribers,
}
54
55impl TaskChannel {
56 fn new() -> Self {
57 let (sender, _) = broadcast::channel(CHANNEL_CAPACITY);
58 Self {
59 sender,
60 next_id: 0,
61 buffer: VecDeque::with_capacity(RING_CAPACITY),
62 status: Vec::new(),
63 artifacts: Vec::new(),
64 }
65 }
66
67 fn publish(&mut self, event: UpdateEvent) -> SeqEvent {
70 self.next_id += 1;
71 let seq = SeqEvent::new(self.next_id, event);
72 if self.buffer.len() == RING_CAPACITY {
73 self.buffer.pop_front();
74 }
75 self.buffer.push_back(seq.clone());
76 let _ = self.sender.send(seq.clone());
79 seq
80 }
81
82 fn replay_after(&self, from: u64) -> Vec<SeqEvent> {
84 self.buffer
85 .iter()
86 .filter(|e| e.id > from)
87 .cloned()
88 .collect()
89 }
90}
91
92#[derive(Clone, Default)]
98pub struct InMemoryStreamingHandler {
99 tasks: Arc<Mutex<HashMap<String, TaskChannel>>>,
100}
101
102impl InMemoryStreamingHandler {
103 pub fn new() -> Self {
105 Self::default()
106 }
107}
108
109#[async_trait]
110impl AsyncStreamingHandler for InMemoryStreamingHandler {
111 async fn add_status_subscriber(
112 &self,
113 task_id: &str,
114 subscriber: Box<dyn Subscriber<TaskStatusUpdateEvent> + Send + Sync>,
115 ) -> Result<String, A2AError> {
116 #[cfg(feature = "tracing")]
117 tracing::info!(
118 task_id = %task_id,
119 "✅ Adding subscriber for status updates"
120 );
121
122 let mut guard = self.tasks.lock().await;
123 guard
124 .entry(task_id.to_string())
125 .or_insert_with(TaskChannel::new)
126 .status
127 .push(subscriber);
128
129 Ok(format!("status-{}-{}", task_id, uuid::Uuid::new_v4()))
130 }
131
132 async fn add_artifact_subscriber(
133 &self,
134 task_id: &str,
135 subscriber: Box<dyn Subscriber<TaskArtifactUpdateEvent> + Send + Sync>,
136 ) -> Result<String, A2AError> {
137 let mut guard = self.tasks.lock().await;
138 guard
139 .entry(task_id.to_string())
140 .or_insert_with(TaskChannel::new)
141 .artifacts
142 .push(subscriber);
143
144 Ok(format!("artifact-{}-{}", task_id, uuid::Uuid::new_v4()))
145 }
146
147 async fn remove_subscription(&self, _subscription_id: &str) -> Result<(), A2AError> {
148 Err(A2AError::UnsupportedOperation(
149 "Subscription removal by ID is not supported by the in-memory streaming handler"
150 .to_string(),
151 ))
152 }
153
154 async fn remove_task_subscribers(&self, task_id: &str) -> Result<(), A2AError> {
155 let mut guard = self.tasks.lock().await;
156 guard.remove(task_id);
157 Ok(())
158 }
159
160 async fn get_subscriber_count(&self, task_id: &str) -> Result<usize, A2AError> {
161 let guard = self.tasks.lock().await;
162 Ok(guard
163 .get(task_id)
164 .map(|c| c.status.len() + c.artifacts.len() + c.sender.receiver_count())
165 .unwrap_or(0))
166 }
167
168 async fn broadcast_status_update(
169 &self,
170 task_id: &str,
171 update: TaskStatusUpdateEvent,
172 ) -> Result<(), A2AError> {
173 #[cfg(feature = "tracing")]
174 tracing::debug!(
175 task_id = %task_id,
176 state = ?update.status.state,
177 "📡 Broadcasting status update to subscribers"
178 );
179
180 let mut guard = self.tasks.lock().await;
181 let channel = guard
182 .entry(task_id.to_string())
183 .or_insert_with(TaskChannel::new);
184 channel.publish(UpdateEvent::StatusUpdate(update.clone()));
185 for subscriber in channel.status.iter() {
186 if let Err(e) = subscriber.on_update(update.clone()).await {
187 #[cfg(feature = "tracing")]
188 tracing::error!(task_id = %task_id, error = %e, "❌ Failed to notify subscriber");
189 #[cfg(not(feature = "tracing"))]
190 let _ = e;
191 }
192 }
193 Ok(())
194 }
195
196 async fn broadcast_artifact_update(
197 &self,
198 task_id: &str,
199 update: TaskArtifactUpdateEvent,
200 ) -> Result<(), A2AError> {
201 let mut guard = self.tasks.lock().await;
202 let channel = guard
203 .entry(task_id.to_string())
204 .or_insert_with(TaskChannel::new);
205 channel.publish(UpdateEvent::ArtifactUpdate(update.clone()));
206 for subscriber in channel.artifacts.iter() {
207 if let Err(e) = subscriber.on_update(update.clone()).await {
208 #[cfg(feature = "tracing")]
209 tracing::error!(task_id = %task_id, error = %e, "❌ Failed to notify subscriber");
210 #[cfg(not(feature = "tracing"))]
211 let _ = e;
212 }
213 }
214 Ok(())
215 }
216
217 async fn status_update_stream(
218 &self,
219 _task_id: &str,
220 ) -> Result<Pin<Box<dyn Stream<Item = Result<TaskStatusUpdateEvent, A2AError>> + Send>>, A2AError>
221 {
222 Err(A2AError::UnsupportedOperation(
223 "Status-only update stream is not supported; use combined_update_stream".to_string(),
224 ))
225 }
226
227 async fn artifact_update_stream(
228 &self,
229 _task_id: &str,
230 ) -> Result<
231 Pin<Box<dyn Stream<Item = Result<TaskArtifactUpdateEvent, A2AError>> + Send>>,
232 A2AError,
233 > {
234 Err(A2AError::UnsupportedOperation(
235 "Artifact-only update stream is not supported; use combined_update_stream".to_string(),
236 ))
237 }
238
239 async fn combined_update_stream(
240 &self,
241 task_id: &str,
242 from_event_id: Option<u64>,
243 ) -> Result<Pin<Box<dyn Stream<Item = Result<SeqEvent, A2AError>> + Send>>, A2AError> {
244 let mut guard = self.tasks.lock().await;
245 let channel = guard
246 .entry(task_id.to_string())
247 .or_insert_with(TaskChannel::new);
248 let receiver = channel.sender.subscribe();
249 let replay = from_event_id
250 .map(|from| channel.replay_after(from))
251 .unwrap_or_default();
252 drop(guard);
253
254 let live = futures::stream::unfold(receiver, |mut rx| async move {
255 match rx.recv().await {
256 Ok(event) => Some((Ok(event), rx)),
257 Err(broadcast::error::RecvError::Lagged(n)) => Some((
260 Err(A2AError::Internal(format!(
261 "streaming reader lagged, dropped {n} events"
262 ))),
263 rx,
264 )),
265 Err(broadcast::error::RecvError::Closed) => None,
266 }
267 });
268
269 let stream = futures::stream::iter(replay.into_iter().map(Ok)).chain(live);
270 Ok(Box::pin(stream))
271 }
272}
273
#[cfg(test)]
mod tests {
    use super::*;
    use crate::domain::{TaskState, TaskStatus, TaskStatusUpdateEvent};

    /// Builds a status-update event for `task_id` in the given `state`.
    fn status_event(task_id: &str, state: TaskState) -> TaskStatusUpdateEvent {
        let status = TaskStatus::new(state, None);
        TaskStatusUpdateEvent {
            task_id: task_id.to_string(),
            context_id: "ctx".to_string(),
            kind: "status-update".to_string(),
            status,
            metadata: None,
        }
    }

    /// Returns the task state carried by `seq`; panics on artifact updates.
    fn seq_state(seq: &SeqEvent) -> ::buffa::EnumValue<TaskState> {
        if let UpdateEvent::StatusUpdate(update) = &seq.event {
            update.status.state
        } else {
            panic!("expected status update")
        }
    }

    /// A stream opened before publishing sees every event, in order, with ids 1, 2, ...
    #[tokio::test]
    async fn live_stream_delivers_in_order_with_ids() {
        let handler = InMemoryStreamingHandler::new();
        let mut stream = handler.combined_update_stream("t1", None).await.unwrap();

        for state in [TaskState::Working, TaskState::Completed] {
            handler
                .broadcast_status_update("t1", status_event("t1", state))
                .await
                .unwrap();
        }

        let expected = [(1, TaskState::Working), (2, TaskState::Completed)];
        for (want_id, want_state) in expected {
            let got = stream.next().await.unwrap().unwrap();
            assert_eq!(got.id, want_id);
            assert_eq!(seq_state(&got), ::buffa::EnumValue::from(want_state));
        }
    }

    /// Resuming after id 1 replays the buffered event with id 2 first.
    #[tokio::test]
    async fn resume_replays_buffered_tail() {
        let handler = InMemoryStreamingHandler::new();
        for state in [TaskState::Working, TaskState::Completed] {
            handler
                .broadcast_status_update("t1", status_event("t1", state))
                .await
                .unwrap();
        }

        let mut stream = handler.combined_update_stream("t1", Some(1)).await.unwrap();
        let replayed = stream.next().await.unwrap().unwrap();

        assert_eq!(replayed.id, 2);
        assert_eq!(
            seq_state(&replayed),
            ::buffa::EnumValue::from(TaskState::Completed)
        );
    }

    /// Callback subscribers keep receiving updates alongside the stream path.
    #[tokio::test]
    async fn callback_subscriber_still_notified() {
        use std::sync::Mutex as StdMutex;

        /// Test subscriber that records every state it is handed.
        #[derive(Default, Clone)]
        struct Recorder {
            seen: Arc<StdMutex<Vec<::buffa::EnumValue<TaskState>>>>,
        }

        #[async_trait]
        impl Subscriber<TaskStatusUpdateEvent> for Recorder {
            async fn on_update(&self, update: TaskStatusUpdateEvent) -> Result<(), A2AError> {
                self.seen.lock().unwrap().push(update.status.state);
                Ok(())
            }
        }

        let handler = InMemoryStreamingHandler::new();
        let recorder = Recorder::default();
        handler
            .add_status_subscriber("t1", Box::new(recorder.clone()))
            .await
            .unwrap();
        handler
            .broadcast_status_update("t1", status_event("t1", TaskState::Working))
            .await
            .unwrap();

        let seen = recorder.seen.lock().unwrap().clone();
        assert_eq!(seen, vec![::buffa::EnumValue::from(TaskState::Working)]);
    }
}
385}