use std::collections::HashMap;
use std::collections::VecDeque;
use std::pin::Pin;
use std::sync::Arc;
use async_trait::async_trait;
use futures::{Stream, StreamExt};
use tokio::sync::Mutex;
use tokio::sync::broadcast;
use crate::domain::{A2AError, TaskArtifactUpdateEvent, TaskStatusUpdateEvent};
use crate::port::AsyncStreamingHandler;
use crate::port::streaming_handler::{SeqEvent, Subscriber, UpdateEvent};
/// Type-erased, thread-safe callback subscribers for events of type `E`.
type Subscribers<E> = Vec<Box<dyn Subscriber<E> + Send + Sync>>;
/// Callback subscribers registered for a task's status updates.
type StatusSubscribers = Subscribers<TaskStatusUpdateEvent>;
/// Callback subscribers registered for a task's artifact updates.
type ArtifactSubscribers = Subscribers<TaskArtifactUpdateEvent>;
/// Capacity of each task's live broadcast channel. A stream reader that falls more than
/// this many events behind observes a lag error.
const CHANNEL_CAPACITY: usize = 256;

/// Number of most recent events retained per task so a reconnecting reader can resume.
const RING_CAPACITY: usize = 256;
/// Per-task fan-out state: a broadcast channel for live stream readers, a bounded
/// buffer of recently published events for replay, and registered callback subscribers.
struct TaskChannel {
    /// Sender feeding live stream readers.
    sender: broadcast::Sender<SeqEvent>,
    /// Id assigned to the most recently published event; `0` until the first publish.
    next_id: u64,
    /// The last `RING_CAPACITY` published events, oldest first.
    buffer: VecDeque<SeqEvent>,
    /// Callback subscribers for status updates.
    status: StatusSubscribers,
    /// Callback subscribers for artifact updates.
    artifacts: ArtifactSubscribers,
}

impl TaskChannel {
    /// Creates a channel with no receivers, no subscribers and no history.
    fn new() -> Self {
        // The initial receiver is not needed; readers subscribe on demand.
        let sender = broadcast::channel(CHANNEL_CAPACITY).0;
        Self {
            status: Vec::new(),
            artifacts: Vec::new(),
            buffer: VecDeque::with_capacity(RING_CAPACITY),
            next_id: 0,
            sender,
        }
    }

    /// Assigns the next sequence id to `event`, records it in the replay buffer
    /// (evicting the oldest entry once the buffer is full), and forwards it to live
    /// receivers. Returns the sequenced event.
    fn publish(&mut self, event: UpdateEvent) -> SeqEvent {
        self.next_id += 1;
        let sequenced = SeqEvent::new(self.next_id, event);

        let buffer_full = self.buffer.len() == RING_CAPACITY;
        if buffer_full {
            self.buffer.pop_front();
        }
        self.buffer.push_back(sequenced.clone());

        // `send` only fails when nobody is listening, which is a normal state.
        let _ = self.sender.send(sequenced.clone());
        sequenced
    }

    /// Returns clones of the buffered events whose id is strictly greater than `from`,
    /// oldest first.
    fn replay_after(&self, from: u64) -> Vec<SeqEvent> {
        // `publish` appends with strictly increasing ids, so the matching events form a
        // suffix of the buffer.
        let first_newer = self.buffer.partition_point(|entry| entry.id <= from);
        self.buffer.range(first_newer..).cloned().collect()
    }
}
#[derive(Clone, Default)]
pub struct InMemoryStreamingHandler {
tasks: Arc<Mutex<HashMap<String, TaskChannel>>>,
}
impl InMemoryStreamingHandler {
pub fn new() -> Self {
Self::default()
}
}
#[async_trait]
impl AsyncStreamingHandler for InMemoryStreamingHandler {
    /// Registers a callback subscriber for status updates on `task_id`, creating the
    /// task's channel if needed.
    ///
    /// The returned id is informational only: this handler cannot remove a single
    /// subscription, only all of a task's state via `remove_task_subscribers`.
    async fn add_status_subscriber(
        &self,
        task_id: &str,
        subscriber: Box<dyn Subscriber<TaskStatusUpdateEvent> + Send + Sync>,
    ) -> Result<String, A2AError> {
        #[cfg(feature = "tracing")]
        tracing::info!(
            task_id = %task_id,
            "✅ Adding subscriber for status updates"
        );
        let mut guard = self.tasks.lock().await;
        guard
            .entry(task_id.to_string())
            .or_insert_with(TaskChannel::new)
            .status
            .push(subscriber);
        Ok(format!("status-{}-{}", task_id, uuid::Uuid::new_v4()))
    }

    /// Registers a callback subscriber for artifact updates on `task_id`, creating the
    /// task's channel if needed.
    ///
    /// The returned id is informational only: this handler cannot remove a single
    /// subscription, only all of a task's state via `remove_task_subscribers`.
    async fn add_artifact_subscriber(
        &self,
        task_id: &str,
        subscriber: Box<dyn Subscriber<TaskArtifactUpdateEvent> + Send + Sync>,
    ) -> Result<String, A2AError> {
        #[cfg(feature = "tracing")]
        tracing::info!(
            task_id = %task_id,
            "✅ Adding subscriber for artifact updates"
        );
        let mut guard = self.tasks.lock().await;
        guard
            .entry(task_id.to_string())
            .or_insert_with(TaskChannel::new)
            .artifacts
            .push(subscriber);
        Ok(format!("artifact-{}-{}", task_id, uuid::Uuid::new_v4()))
    }

    /// Always fails with `A2AError::UnsupportedOperation`; subscriptions are not tracked
    /// by id in this handler.
    async fn remove_subscription(&self, _subscription_id: &str) -> Result<(), A2AError> {
        Err(A2AError::UnsupportedOperation(
            "Subscription removal by ID is not supported by the in-memory streaming handler"
                .to_string(),
        ))
    }

    /// Drops everything held for `task_id`: callback subscribers, the replay buffer and
    /// the event id counter.
    ///
    /// Dropping the channel also drops its broadcast sender, so any open
    /// `combined_update_stream` for the task ends. A later publish starts ids again at 1.
    async fn remove_task_subscribers(&self, task_id: &str) -> Result<(), A2AError> {
        let mut guard = self.tasks.lock().await;
        guard.remove(task_id);
        Ok(())
    }

    /// Counts the status and artifact callback subscribers plus the currently open live
    /// stream receivers for `task_id`. Returns 0 for an unknown task.
    async fn get_subscriber_count(&self, task_id: &str) -> Result<usize, A2AError> {
        let guard = self.tasks.lock().await;
        Ok(guard
            .get(task_id)
            .map(|c| c.status.len() + c.artifacts.len() + c.sender.receiver_count())
            .unwrap_or(0))
    }

    /// Publishes a status update to the task's stream readers, then notifies its status
    /// callback subscribers in registration order.
    ///
    /// Subscriber failures are logged (with the `tracing` feature) and otherwise ignored.
    ///
    /// NOTE: callbacks are awaited while the handler-wide lock is held. This keeps
    /// delivery ordered, but a slow subscriber stalls every task, and a subscriber that
    /// calls back into this handler will deadlock.
    async fn broadcast_status_update(
        &self,
        task_id: &str,
        update: TaskStatusUpdateEvent,
    ) -> Result<(), A2AError> {
        #[cfg(feature = "tracing")]
        tracing::debug!(
            task_id = %task_id,
            state = ?update.status.state,
            "📡 Broadcasting status update to subscribers"
        );
        let mut guard = self.tasks.lock().await;
        let channel = guard
            .entry(task_id.to_string())
            .or_insert_with(TaskChannel::new);
        channel.publish(UpdateEvent::StatusUpdate(update.clone()));
        for subscriber in channel.status.iter() {
            if let Err(e) = subscriber.on_update(update.clone()).await {
                #[cfg(feature = "tracing")]
                tracing::error!(task_id = %task_id, error = %e, "❌ Failed to notify subscriber");
                #[cfg(not(feature = "tracing"))]
                let _ = e;
            }
        }
        Ok(())
    }

    /// Publishes an artifact update to the task's stream readers, then notifies its
    /// artifact callback subscribers in registration order.
    ///
    /// Subscriber failures are logged (with the `tracing` feature) and otherwise ignored.
    ///
    /// NOTE: callbacks are awaited while the handler-wide lock is held. A slow subscriber
    /// stalls every task, and a subscriber that calls back into this handler will
    /// deadlock.
    async fn broadcast_artifact_update(
        &self,
        task_id: &str,
        update: TaskArtifactUpdateEvent,
    ) -> Result<(), A2AError> {
        #[cfg(feature = "tracing")]
        tracing::debug!(
            task_id = %task_id,
            "📡 Broadcasting artifact update to subscribers"
        );
        let mut guard = self.tasks.lock().await;
        let channel = guard
            .entry(task_id.to_string())
            .or_insert_with(TaskChannel::new);
        channel.publish(UpdateEvent::ArtifactUpdate(update.clone()));
        for subscriber in channel.artifacts.iter() {
            if let Err(e) = subscriber.on_update(update.clone()).await {
                #[cfg(feature = "tracing")]
                tracing::error!(task_id = %task_id, error = %e, "❌ Failed to notify subscriber");
                #[cfg(not(feature = "tracing"))]
                let _ = e;
            }
        }
        Ok(())
    }

    /// Always fails with `A2AError::UnsupportedOperation`; use `combined_update_stream`.
    async fn status_update_stream(
        &self,
        _task_id: &str,
    ) -> Result<Pin<Box<dyn Stream<Item = Result<TaskStatusUpdateEvent, A2AError>> + Send>>, A2AError>
    {
        Err(A2AError::UnsupportedOperation(
            "Status-only update stream is not supported; use combined_update_stream".to_string(),
        ))
    }

    /// Always fails with `A2AError::UnsupportedOperation`; use `combined_update_stream`.
    async fn artifact_update_stream(
        &self,
        _task_id: &str,
    ) -> Result<
        Pin<Box<dyn Stream<Item = Result<TaskArtifactUpdateEvent, A2AError>> + Send>>,
        A2AError,
    > {
        Err(A2AError::UnsupportedOperation(
            "Artifact-only update stream is not supported; use combined_update_stream".to_string(),
        ))
    }

    /// Returns a stream of every update for `task_id`, each tagged with its sequence id.
    ///
    /// With `from_event_id = Some(from)`, the stream first replays the buffered events
    /// whose id is greater than `from`, then continues with live events.
    ///
    /// Errors are yielded as `A2AError::Internal` items and delivery then continues:
    /// - If some requested events were already evicted from the replay buffer, one error
    ///   item comes before the replay.
    /// - If a reader falls behind the live channel, an error item reports the dropped
    ///   events.
    ///
    /// The stream ends when the task's channel is dropped, for example by
    /// `remove_task_subscribers`.
    async fn combined_update_stream(
        &self,
        task_id: &str,
        from_event_id: Option<u64>,
    ) -> Result<Pin<Box<dyn Stream<Item = Result<SeqEvent, A2AError>> + Send>>, A2AError> {
        // Subscribe and snapshot the replay buffer under the same lock `publish` runs
        // under. The replayed tail and the live feed then neither overlap nor leave a
        // gap between them.
        let (receiver, prefix) = {
            let mut guard = self.tasks.lock().await;
            let channel = guard
                .entry(task_id.to_string())
                .or_insert_with(TaskChannel::new);
            let receiver = channel.sender.subscribe();
            let mut prefix: Vec<Result<SeqEvent, A2AError>> = Vec::new();
            if let Some(from) = from_event_id {
                // Ids in `(from, oldest)` were evicted from the ring buffer and cannot be
                // replayed. Report them like live lag instead of silently skipping them.
                let evicted = channel
                    .buffer
                    .front()
                    .map_or(0, |oldest| oldest.id.saturating_sub(from.saturating_add(1)));
                if evicted > 0 {
                    prefix.push(Err(A2AError::Internal(format!(
                        "resume point {from} is older than the replay buffer, dropped {evicted} events"
                    ))));
                }
                prefix.extend(channel.replay_after(from).into_iter().map(Ok));
            }
            (receiver, prefix)
        };
        let live = futures::stream::unfold(receiver, |mut rx| async move {
            match rx.recv().await {
                Ok(event) => Some((Ok(event), rx)),
                Err(broadcast::error::RecvError::Lagged(n)) => Some((
                    Err(A2AError::Internal(format!(
                        "streaming reader lagged, dropped {n} events"
                    ))),
                    rx,
                )),
                Err(broadcast::error::RecvError::Closed) => None,
            }
        });
        let stream = futures::stream::iter(prefix).chain(live);
        Ok(Box::pin(stream))
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::domain::{TaskState, TaskStatus, TaskStatusUpdateEvent};

    /// Builds a minimal status-update event for `task_id` in the given `state`.
    fn status_event(task_id: &str, state: TaskState) -> TaskStatusUpdateEvent {
        TaskStatusUpdateEvent {
            task_id: task_id.to_string(),
            context_id: "ctx".to_string(),
            kind: "status-update".to_string(),
            status: TaskStatus::new(state, None),
            metadata: None,
        }
    }

    /// Extracts the task state carried by a sequenced event.
    ///
    /// Panics if the event is not a status update.
    fn seq_state(seq: &SeqEvent) -> ::buffa::EnumValue<TaskState> {
        let UpdateEvent::StatusUpdate(status_update) = &seq.event else {
            panic!("expected status update");
        };
        status_update.status.state
    }

    /// Broadcasts a status update for task `t1` in `state`, failing the test on error.
    async fn push_state(handler: &InMemoryStreamingHandler, state: TaskState) {
        handler
            .broadcast_status_update("t1", status_event("t1", state))
            .await
            .unwrap();
    }

    #[tokio::test]
    async fn live_stream_delivers_in_order_with_ids() {
        let handler = InMemoryStreamingHandler::new();
        let mut stream = handler.combined_update_stream("t1", None).await.unwrap();

        for state in [TaskState::Working, TaskState::Completed] {
            push_state(&handler, state).await;
        }

        for (expected_id, expected_state) in [(1, TaskState::Working), (2, TaskState::Completed)] {
            let received = stream.next().await.unwrap().unwrap();
            assert_eq!(received.id, expected_id);
            assert_eq!(seq_state(&received), ::buffa::EnumValue::from(expected_state));
        }
    }

    #[tokio::test]
    async fn resume_replays_buffered_tail() {
        let handler = InMemoryStreamingHandler::new();
        push_state(&handler, TaskState::Working).await;
        push_state(&handler, TaskState::Completed).await;

        let mut resumed = handler.combined_update_stream("t1", Some(1)).await.unwrap();
        let replayed = resumed.next().await.unwrap().unwrap();

        assert_eq!(replayed.id, 2);
        assert_eq!(
            seq_state(&replayed),
            ::buffa::EnumValue::from(TaskState::Completed)
        );
    }

    #[tokio::test]
    async fn callback_subscriber_still_notified() {
        use std::sync::Mutex as StdMutex;

        /// Test subscriber that records every state it is notified about.
        #[derive(Default, Clone)]
        struct Recorder {
            seen: Arc<StdMutex<Vec<::buffa::EnumValue<TaskState>>>>,
        }

        #[async_trait]
        impl Subscriber<TaskStatusUpdateEvent> for Recorder {
            async fn on_update(&self, update: TaskStatusUpdateEvent) -> Result<(), A2AError> {
                self.seen.lock().unwrap().push(update.status.state);
                Ok(())
            }
        }

        let handler = InMemoryStreamingHandler::new();
        let recorder = Recorder::default();
        handler
            .add_status_subscriber("t1", Box::new(recorder.clone()))
            .await
            .unwrap();

        push_state(&handler, TaskState::Working).await;

        let recorded = recorder.seen.lock().unwrap();
        assert_eq!(
            *recorded,
            vec![::buffa::EnumValue::from(TaskState::Working)]
        );
    }
}