// a2a_rs/adapter/streaming/in_memory.rs

//! In-memory streaming fan-out adapter.
//!
//! `InMemoryStreamingHandler` is the in-memory [`AsyncStreamingHandler`]
//! adapter. It owns **only** the per-task fan-out state — a broadcast channel,
//! a bounded replay buffer, and an optional set of push-style callback
//! subscribers (awaited one after another on every broadcast) — and fans
//! broadcast events out to live `combined_update_stream` readers and to those
//! subscribers. It deliberately does *not*:
//!
//! - touch the task store (so it cannot replay current task state on
//!   subscribe — the initial `Task` snapshot is delivered by the application
//!   service before stream items, which is spec-compliant), nor
//! - fire push-webhook notifications (that is the [`AsyncPushNotifier`] port's
//!   job, orchestrated by the
//!   [`TaskStatusBroadcast`](crate::application::TaskStatusBroadcast) mixin).
//!
//! Each broadcast event is assigned a per-task monotonic id (starting at 1) and
//! retained in a bounded ring buffer, so a reconnecting client can resume after
//! a disconnect by passing the last id it observed (`from_event_id`). The
//! handler replays the buffered events with a greater id before switching to
//! live updates. Once the buffer is full, its oldest events are evicted and
//! can no longer be replayed.
//!
//! [`AsyncPushNotifier`]: crate::port::AsyncPushNotifier
22
23use std::collections::HashMap;
24use std::collections::VecDeque;
25use std::pin::Pin;
26use std::sync::Arc;
27
28use async_trait::async_trait;
29use futures::{Stream, StreamExt};
30use tokio::sync::Mutex;
31use tokio::sync::broadcast;
32
33use crate::domain::{A2AError, TaskArtifactUpdateEvent, TaskStatusUpdateEvent};
34use crate::port::AsyncStreamingHandler;
35use crate::port::streaming_handler::{SeqEvent, Subscriber, UpdateEvent};
36
37type StatusSubscribers = Vec<Box<dyn Subscriber<TaskStatusUpdateEvent> + Send + Sync>>;
38type ArtifactSubscribers = Vec<Box<dyn Subscriber<TaskArtifactUpdateEvent> + Send + Sync>>;
39
/// Capacity of each per-task `tokio::sync::broadcast` channel.
///
/// A live reader that falls more than this many events behind the newest one
/// has the oldest events discarded from under it, and its stream yields a
/// "lagged" error.
const CHANNEL_CAPACITY: usize = 256;

/// Number of most-recent events retained per task for `from_event_id` resume.
///
/// When full, the oldest event is evicted first. A client can therefore only
/// resume losslessly from an id within the last `RING_CAPACITY` published
/// events.
///
/// NOTE(review): this equals `CHANNEL_CAPACITY`. Events that a lagged live
/// reader lost from the broadcast channel have therefore already been evicted
/// from the ring as well, so resuming cannot recover them. A larger ring would
/// allow that; confirm the intended sizing.
const RING_CAPACITY: usize = 256;

// Both capacities must be non-zero. `broadcast::channel` panics on a zero
// capacity, and the ring's evict-then-push logic only bounds the buffer when
// `RING_CAPACITY > 0`.
const _: () = assert!(CHANNEL_CAPACITY > 0 && RING_CAPACITY > 0);
43
44/// Per-task fan-out state: a broadcast channel for live readers, a bounded
45/// replay buffer keyed by monotonic id, and any synchronous callback
46/// subscribers.
47struct TaskChannel {
48    sender: broadcast::Sender<SeqEvent>,
49    next_id: u64,
50    buffer: VecDeque<SeqEvent>,
51    status: StatusSubscribers,
52    artifacts: ArtifactSubscribers,
53}
54
55impl TaskChannel {
56    fn new() -> Self {
57        let (sender, _) = broadcast::channel(CHANNEL_CAPACITY);
58        Self {
59            sender,
60            next_id: 0,
61            buffer: VecDeque::with_capacity(RING_CAPACITY),
62            status: Vec::new(),
63            artifacts: Vec::new(),
64        }
65    }
66
67    /// Assign the next id, retain the event for replay, and publish it to live
68    /// readers. Returns the sequenced event for any further fan-out.
69    fn publish(&mut self, event: UpdateEvent) -> SeqEvent {
70        self.next_id += 1;
71        let seq = SeqEvent::new(self.next_id, event);
72        if self.buffer.len() == RING_CAPACITY {
73            self.buffer.pop_front();
74        }
75        self.buffer.push_back(seq.clone());
76        // A send error just means there are no live receivers; the buffer still
77        // retains the event for a later resume, so the error is ignored.
78        let _ = self.sender.send(seq.clone());
79        seq
80    }
81
82    /// Buffered events with an id strictly greater than `from`, in order.
83    fn replay_after(&self, from: u64) -> Vec<SeqEvent> {
84        self.buffer
85            .iter()
86            .filter(|e| e.id > from)
87            .cloned()
88            .collect()
89    }
90}
91
92/// In-memory [`AsyncStreamingHandler`]: per-task broadcast fan-out with a
93/// bounded replay buffer for Last-Event-ID resumption.
94///
95/// Cloning shares the underlying per-task state (an `Arc<Mutex<…>>`), so a clone
96/// observes the same channels and subscribers.
97#[derive(Clone, Default)]
98pub struct InMemoryStreamingHandler {
99    tasks: Arc<Mutex<HashMap<String, TaskChannel>>>,
100}
101
102impl InMemoryStreamingHandler {
103    /// Create an empty streaming handler.
104    pub fn new() -> Self {
105        Self::default()
106    }
107}
108
109#[async_trait]
110impl AsyncStreamingHandler for InMemoryStreamingHandler {
111    async fn add_status_subscriber(
112        &self,
113        task_id: &str,
114        subscriber: Box<dyn Subscriber<TaskStatusUpdateEvent> + Send + Sync>,
115    ) -> Result<String, A2AError> {
116        #[cfg(feature = "tracing")]
117        tracing::info!(
118            task_id = %task_id,
119            "✅ Adding subscriber for status updates"
120        );
121
122        let mut guard = self.tasks.lock().await;
123        guard
124            .entry(task_id.to_string())
125            .or_insert_with(TaskChannel::new)
126            .status
127            .push(subscriber);
128
129        Ok(format!("status-{}-{}", task_id, uuid::Uuid::new_v4()))
130    }
131
132    async fn add_artifact_subscriber(
133        &self,
134        task_id: &str,
135        subscriber: Box<dyn Subscriber<TaskArtifactUpdateEvent> + Send + Sync>,
136    ) -> Result<String, A2AError> {
137        let mut guard = self.tasks.lock().await;
138        guard
139            .entry(task_id.to_string())
140            .or_insert_with(TaskChannel::new)
141            .artifacts
142            .push(subscriber);
143
144        Ok(format!("artifact-{}-{}", task_id, uuid::Uuid::new_v4()))
145    }
146
147    async fn remove_subscription(&self, _subscription_id: &str) -> Result<(), A2AError> {
148        Err(A2AError::UnsupportedOperation(
149            "Subscription removal by ID is not supported by the in-memory streaming handler"
150                .to_string(),
151        ))
152    }
153
154    async fn remove_task_subscribers(&self, task_id: &str) -> Result<(), A2AError> {
155        let mut guard = self.tasks.lock().await;
156        guard.remove(task_id);
157        Ok(())
158    }
159
160    async fn get_subscriber_count(&self, task_id: &str) -> Result<usize, A2AError> {
161        let guard = self.tasks.lock().await;
162        Ok(guard
163            .get(task_id)
164            .map(|c| c.status.len() + c.artifacts.len() + c.sender.receiver_count())
165            .unwrap_or(0))
166    }
167
168    async fn broadcast_status_update(
169        &self,
170        task_id: &str,
171        update: TaskStatusUpdateEvent,
172    ) -> Result<(), A2AError> {
173        #[cfg(feature = "tracing")]
174        tracing::debug!(
175            task_id = %task_id,
176            state = ?update.status.state,
177            "📡 Broadcasting status update to subscribers"
178        );
179
180        let mut guard = self.tasks.lock().await;
181        let channel = guard
182            .entry(task_id.to_string())
183            .or_insert_with(TaskChannel::new);
184        channel.publish(UpdateEvent::StatusUpdate(update.clone()));
185        for subscriber in channel.status.iter() {
186            if let Err(e) = subscriber.on_update(update.clone()).await {
187                #[cfg(feature = "tracing")]
188                tracing::error!(task_id = %task_id, error = %e, "❌ Failed to notify subscriber");
189                #[cfg(not(feature = "tracing"))]
190                let _ = e;
191            }
192        }
193        Ok(())
194    }
195
196    async fn broadcast_artifact_update(
197        &self,
198        task_id: &str,
199        update: TaskArtifactUpdateEvent,
200    ) -> Result<(), A2AError> {
201        let mut guard = self.tasks.lock().await;
202        let channel = guard
203            .entry(task_id.to_string())
204            .or_insert_with(TaskChannel::new);
205        channel.publish(UpdateEvent::ArtifactUpdate(update.clone()));
206        for subscriber in channel.artifacts.iter() {
207            if let Err(e) = subscriber.on_update(update.clone()).await {
208                #[cfg(feature = "tracing")]
209                tracing::error!(task_id = %task_id, error = %e, "❌ Failed to notify subscriber");
210                #[cfg(not(feature = "tracing"))]
211                let _ = e;
212            }
213        }
214        Ok(())
215    }
216
217    async fn status_update_stream(
218        &self,
219        _task_id: &str,
220    ) -> Result<Pin<Box<dyn Stream<Item = Result<TaskStatusUpdateEvent, A2AError>> + Send>>, A2AError>
221    {
222        Err(A2AError::UnsupportedOperation(
223            "Status-only update stream is not supported; use combined_update_stream".to_string(),
224        ))
225    }
226
227    async fn artifact_update_stream(
228        &self,
229        _task_id: &str,
230    ) -> Result<
231        Pin<Box<dyn Stream<Item = Result<TaskArtifactUpdateEvent, A2AError>> + Send>>,
232        A2AError,
233    > {
234        Err(A2AError::UnsupportedOperation(
235            "Artifact-only update stream is not supported; use combined_update_stream".to_string(),
236        ))
237    }
238
239    async fn combined_update_stream(
240        &self,
241        task_id: &str,
242        from_event_id: Option<u64>,
243    ) -> Result<Pin<Box<dyn Stream<Item = Result<SeqEvent, A2AError>> + Send>>, A2AError> {
244        let mut guard = self.tasks.lock().await;
245        let channel = guard
246            .entry(task_id.to_string())
247            .or_insert_with(TaskChannel::new);
248        let receiver = channel.sender.subscribe();
249        let replay = from_event_id
250            .map(|from| channel.replay_after(from))
251            .unwrap_or_default();
252        drop(guard);
253
254        let live = futures::stream::unfold(receiver, |mut rx| async move {
255            match rx.recv().await {
256                Ok(event) => Some((Ok(event), rx)),
257                // Reader fell behind the ring buffer: surface an error so a
258                // resilient client reconnects and resumes from its last id.
259                Err(broadcast::error::RecvError::Lagged(n)) => Some((
260                    Err(A2AError::Internal(format!(
261                        "streaming reader lagged, dropped {n} events"
262                    ))),
263                    rx,
264                )),
265                Err(broadcast::error::RecvError::Closed) => None,
266            }
267        });
268
269        let stream = futures::stream::iter(replay.into_iter().map(Ok)).chain(live);
270        Ok(Box::pin(stream))
271    }
272}
273
#[cfg(test)]
mod tests {
    use super::*;
    use crate::domain::{TaskState, TaskStatus, TaskStatusUpdateEvent};

    /// Build a minimal status-update event for `task_id` in `state`.
    fn status_event(task_id: &str, state: TaskState) -> TaskStatusUpdateEvent {
        TaskStatusUpdateEvent {
            task_id: task_id.to_string(),
            context_id: "ctx".to_string(),
            kind: "status-update".to_string(),
            status: TaskStatus::new(state, None),
            metadata: None,
        }
    }

    /// Extract the task state from a sequenced status update; panics on an
    /// artifact update.
    fn seq_state(seq: &SeqEvent) -> ::buffa::EnumValue<TaskState> {
        match &seq.event {
            UpdateEvent::StatusUpdate(e) => e.status.state,
            UpdateEvent::ArtifactUpdate(_) => panic!("expected status update"),
        }
    }

    /// A live `combined_update_stream` reader receives broadcasts in order, each
    /// tagged with a monotonic id starting at 1.
    #[tokio::test]
    async fn live_stream_delivers_in_order_with_ids() {
        let handler = InMemoryStreamingHandler::new();
        let mut stream = handler.combined_update_stream("t1", None).await.unwrap();

        handler
            .broadcast_status_update("t1", status_event("t1", TaskState::Working))
            .await
            .unwrap();
        handler
            .broadcast_status_update("t1", status_event("t1", TaskState::Completed))
            .await
            .unwrap();

        let first = stream.next().await.unwrap().unwrap();
        let second = stream.next().await.unwrap().unwrap();
        assert_eq!(first.id, 1);
        assert_eq!(
            seq_state(&first),
            ::buffa::EnumValue::from(TaskState::Working)
        );
        assert_eq!(second.id, 2);
        assert_eq!(
            seq_state(&second),
            ::buffa::EnumValue::from(TaskState::Completed)
        );
    }

    /// Subscribing with `from_event_id` replays the buffered tail with a greater
    /// id before any live updates.
    #[tokio::test]
    async fn resume_replays_buffered_tail() {
        let handler = InMemoryStreamingHandler::new();
        // Emit two events with no live reader; they are retained in the buffer.
        handler
            .broadcast_status_update("t1", status_event("t1", TaskState::Working))
            .await
            .unwrap();
        handler
            .broadcast_status_update("t1", status_event("t1", TaskState::Completed))
            .await
            .unwrap();

        // Resume from id 1: only event 2 should replay.
        let mut stream = handler.combined_update_stream("t1", Some(1)).await.unwrap();
        let replayed = stream.next().await.unwrap().unwrap();
        assert_eq!(replayed.id, 2);
        assert_eq!(
            seq_state(&replayed),
            ::buffa::EnumValue::from(TaskState::Completed)
        );
    }

    /// A push-style callback subscriber still receives broadcasts (the push API
    /// rides alongside the broadcast channel).
    #[tokio::test]
    async fn callback_subscriber_still_notified() {
        use std::sync::Mutex as StdMutex;

        #[derive(Default, Clone)]
        struct Recorder {
            seen: Arc<StdMutex<Vec<::buffa::EnumValue<TaskState>>>>,
        }
        #[async_trait]
        impl Subscriber<TaskStatusUpdateEvent> for Recorder {
            async fn on_update(&self, update: TaskStatusUpdateEvent) -> Result<(), A2AError> {
                self.seen.lock().unwrap().push(update.status.state);
                Ok(())
            }
        }

        let handler = InMemoryStreamingHandler::new();
        let recorder = Recorder::default();
        handler
            .add_status_subscriber("t1", Box::new(recorder.clone()))
            .await
            .unwrap();
        handler
            .broadcast_status_update("t1", status_event("t1", TaskState::Working))
            .await
            .unwrap();

        assert_eq!(
            *recorder.seen.lock().unwrap(),
            vec![::buffa::EnumValue::from(TaskState::Working)]
        );
    }

    /// Publishing more than `RING_CAPACITY` events evicts the oldest ones, so a
    /// resume from id 0 replays only the newest `RING_CAPACITY` events.
    #[tokio::test]
    async fn ring_buffer_evicts_oldest_events() {
        let handler = InMemoryStreamingHandler::new();
        // One more event than the ring holds: ids 1..=RING_CAPACITY + 1.
        for _ in 0..=RING_CAPACITY {
            handler
                .broadcast_status_update("t1", status_event("t1", TaskState::Working))
                .await
                .unwrap();
        }

        let stream = handler.combined_update_stream("t1", Some(0)).await.unwrap();
        // Take exactly the replayed tail so the test never waits on live updates.
        let ids: Vec<u64> = stream
            .take(RING_CAPACITY)
            .map(|item| item.unwrap().id)
            .collect()
            .await;
        assert_eq!(ids.len(), RING_CAPACITY);
        assert_eq!(ids.first(), Some(&2));
        assert_eq!(ids.last(), Some(&(RING_CAPACITY as u64 + 1)));
    }

    /// Removing a task's state drops its broadcast sender, so a live reader's
    /// stream terminates instead of hanging.
    #[tokio::test]
    async fn removing_task_ends_live_stream() {
        let handler = InMemoryStreamingHandler::new();
        let mut stream = handler.combined_update_stream("t1", None).await.unwrap();

        handler.remove_task_subscribers("t1").await.unwrap();

        assert!(stream.next().await.is_none());
    }

    /// `get_subscriber_count` counts open live readers and stops counting a
    /// reader once its stream is dropped.
    #[tokio::test]
    async fn subscriber_count_tracks_live_readers() {
        let handler = InMemoryStreamingHandler::new();
        assert_eq!(handler.get_subscriber_count("t1").await.unwrap(), 0);

        let stream = handler.combined_update_stream("t1", None).await.unwrap();
        assert_eq!(handler.get_subscriber_count("t1").await.unwrap(), 1);

        drop(stream);
        assert_eq!(handler.get_subscriber_count("t1").await.unwrap(), 0);
    }
}