use std::future::Future;
use std::pin::Pin;
use a2a_protocol_types::error::{A2aError, A2aResult};
use a2a_protocol_types::events::StreamResponse;
use tokio::sync::{broadcast, mpsc};
use super::{EventQueueReader, EventQueueWriter};
/// A byte-counting `io::Write` sink that discards everything written to it.
///
/// `write` streams an event's JSON through this type to learn the serialized size
/// without allocating a buffer. The tuple field holds the running total of bytes
/// accepted so far.
struct CountingWriter(usize);

impl std::io::Write for CountingWriter {
    /// Accepts the whole of `buf`, adds its length to the running total, and
    /// reports that every byte was written.
    fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
        let accepted = buf.len();
        self.0 += accepted;
        Ok(accepted)
    }

    /// Nothing is ever buffered, so flushing always succeeds and leaves the count untouched.
    fn flush(&mut self) -> std::io::Result<()> {
        Ok(())
    }
}
#[derive(Debug, Clone)]
pub struct InMemoryQueueWriter {
tx: broadcast::Sender<A2aResult<StreamResponse>>,
persistence_tx: Option<mpsc::Sender<A2aResult<StreamResponse>>>,
max_event_size: usize,
#[allow(dead_code)]
write_timeout: std::time::Duration,
}
impl InMemoryQueueWriter {
pub(super) const fn new(
tx: broadcast::Sender<A2aResult<StreamResponse>>,
max_event_size: usize,
write_timeout: std::time::Duration,
) -> Self {
Self {
tx,
persistence_tx: None,
max_event_size,
write_timeout,
}
}
pub(super) const fn new_with_persistence(
tx: broadcast::Sender<A2aResult<StreamResponse>>,
persistence_tx: mpsc::Sender<A2aResult<StreamResponse>>,
max_event_size: usize,
write_timeout: std::time::Duration,
) -> Self {
Self {
tx,
persistence_tx: Some(persistence_tx),
max_event_size,
write_timeout,
}
}
#[must_use]
pub fn subscribe(&self) -> InMemoryQueueReader {
InMemoryQueueReader::new(self.tx.subscribe())
}
pub(crate) fn raw_subscribe(&self) -> broadcast::Receiver<A2aResult<StreamResponse>> {
self.tx.subscribe()
}
}
#[allow(clippy::manual_async_fn)]
impl EventQueueWriter for InMemoryQueueWriter {
    /// Validates, persists (best-effort) and broadcasts a single event.
    ///
    /// Steps, in order:
    /// 1. Serialize `event` to JSON only to measure it. Events larger than
    ///    `max_event_size` bytes are rejected before anything is sent.
    /// 2. If a persistence channel is configured, send a clone of the event on it.
    ///    A failed send is logged and otherwise ignored.
    /// 3. Broadcast the event to all current subscribers.
    ///
    /// # Errors
    ///
    /// Returns an internal error if:
    /// - serialization fails,
    /// - the event exceeds `max_event_size`, or
    /// - the broadcast channel has no active receivers.
    ///
    /// Persistence happens before the broadcast, so an event may already have
    /// been persisted when the "no active receivers" error is returned.
    ///
    /// NOTE(review): the persistence `send` waits for channel capacity with no
    /// upper bound; `write_timeout` is stored on the writer but not applied here.
    fn write<'a>(
        &'a self,
        event: StreamResponse,
    ) -> Pin<Box<dyn Future<Output = A2aResult<()>> + Send + 'a>> {
        Box::pin(async move {
            // Stream the JSON into a byte counter instead of building a buffer.
            let mut counter = CountingWriter(0);
            serde_json::to_writer(&mut counter, &event)
                .map_err(|e| A2aError::internal(format!("event serialization failed: {e}")))?;
            let serialized_size = counter.0;

            if serialized_size > self.max_event_size {
                return Err(A2aError::internal(format!(
                    "event size {serialized_size} bytes exceeds maximum {} bytes",
                    self.max_event_size
                )));
            }

            if let Some(persistence_tx) = &self.persistence_tx {
                let persisted = persistence_tx.send(Ok(event.clone())).await;
                if persisted.is_err() {
                    trace_warn!("persistence channel closed, event not persisted");
                }
            }

            match self.tx.send(Ok(event)) {
                Ok(_receiver_count) => Ok(()),
                Err(_) => Err(A2aError::internal("event queue: no active receivers")),
            }
        })
    }

    /// Does nothing and always succeeds.
    ///
    /// Readers get `None` from `read` once the broadcast channel reports
    /// `Closed`, for example after the writer has been dropped.
    fn close<'a>(&'a self) -> Pin<Box<dyn Future<Output = A2aResult<()>> + Send + 'a>> {
        Box::pin(async { Ok(()) })
    }
}
/// Read side of the in-memory event queue.
///
/// Wraps a broadcast receiver. It can also hold one pre-seeded event, which is
/// handed out before anything received from the channel.
#[derive(Debug)]
pub struct InMemoryQueueReader {
    /// Underlying broadcast receiver.
    rx: broadcast::Receiver<A2aResult<StreamResponse>>,
    /// Event returned by the next `read` before the channel is consulted, if any.
    pending_first: Option<A2aResult<StreamResponse>>,
}

impl InMemoryQueueReader {
    /// Wraps `rx` with no pre-seeded event.
    pub(crate) const fn new(rx: broadcast::Receiver<A2aResult<StreamResponse>>) -> Self {
        Self {
            pending_first: None,
            rx,
        }
    }

    /// Seeds `event` as the next item `read` returns.
    ///
    /// Any previously seeded event that has not been read yet is replaced.
    pub fn set_first_event(&mut self, event: StreamResponse) {
        let seeded = Ok(event);
        self.pending_first = Some(seeded);
    }

    /// Wraps `rx` and seeds `first` as the first item `read` returns.
    pub(crate) const fn with_first_event(
        rx: broadcast::Receiver<A2aResult<StreamResponse>>,
        first: StreamResponse,
    ) -> Self {
        Self {
            pending_first: Some(Ok(first)),
            rx,
        }
    }
}
impl EventQueueReader for InMemoryQueueReader {
fn read(
&mut self,
) -> Pin<Box<dyn Future<Output = Option<A2aResult<StreamResponse>>> + Send + '_>> {
Box::pin(async move {
if let Some(first) = self.pending_first.take() {
return Some(first);
}
loop {
match self.rx.recv().await {
Ok(event) => return Some(event),
Err(broadcast::error::RecvError::Lagged(_n)) => {
trace_warn!(
dropped_events = _n,
"event queue reader lagged, {_n} events skipped"
);
}
Err(broadcast::error::RecvError::Closed) => return None,
}
}
})
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::streaming::event_queue::{
        new_in_memory_queue, new_in_memory_queue_with_options, DEFAULT_MAX_EVENT_SIZE,
        DEFAULT_WRITE_TIMEOUT,
    };
    use a2a_protocol_types::events::{StreamResponse, TaskStatusUpdateEvent};
    use a2a_protocol_types::task::{ContextId, TaskId, TaskState, TaskStatus};
    use std::time::Duration;
    use tokio::sync::{broadcast, mpsc};

    /// Write timeout for writers built directly in these tests. The in-memory
    /// writer stores this value, but `write` never reads it.
    const TEST_WRITE_TIMEOUT: Duration = Duration::from_secs(1);

    /// Builds a status-update event for `task_id` in the given `state`.
    fn make_status_event(task_id: &str, state: TaskState) -> StreamResponse {
        StreamResponse::StatusUpdate(TaskStatusUpdateEvent {
            task_id: TaskId::new(task_id),
            context_id: ContextId::new("ctx-test"),
            status: TaskStatus {
                state,
                message: None,
                timestamp: None,
            },
            metadata: None,
        })
    }

    /// Builds a writer/reader pair directly from a broadcast channel of the given
    /// `capacity`, with no size limit and no persistence channel.
    fn direct_queue(capacity: usize) -> (InMemoryQueueWriter, InMemoryQueueReader) {
        let (tx, rx) = broadcast::channel(capacity);
        (
            InMemoryQueueWriter::new(tx, usize::MAX, TEST_WRITE_TIMEOUT),
            InMemoryQueueReader::new(rx),
        )
    }

    /// Unwraps the result of a `read()`-style call.
    ///
    /// Panics with `what` in the message on EOF or on an error item.
    fn expect_event(read: Option<A2aResult<StreamResponse>>, what: &str) -> StreamResponse {
        match read {
            Some(Ok(event)) => event,
            Some(Err(err)) => panic!("{what}: expected an Ok event, got error: {err}"),
            None => panic!("{what}: reader returned EOF"),
        }
    }

    /// Asserts that `event` is a status update carrying the `expected` state.
    fn assert_state(event: &StreamResponse, expected: TaskState, what: &str) {
        match event {
            StreamResponse::StatusUpdate(evt) => {
                assert_eq!(evt.status.state, expected, "{what}");
            }
            other => panic!("{what}: expected StatusUpdate, got: {other:?}"),
        }
    }

    #[tokio::test]
    async fn write_then_read_single_event() {
        let (writer, mut reader) = new_in_memory_queue();
        writer
            .write(make_status_event("t1", TaskState::Working))
            .await
            .expect("write should succeed");
        drop(writer);

        let event = expect_event(reader.read().await, "reader should return the written event");
        assert_state(&event, TaskState::Working, "should be Working event");
        assert!(
            reader.read().await.is_none(),
            "reader should return None after writer is dropped"
        );
    }

    #[tokio::test]
    async fn write_multiple_events_read_in_order() {
        let (writer, mut reader) = new_in_memory_queue();
        writer
            .write(make_status_event("t1", TaskState::Working))
            .await
            .expect("first write should succeed");
        writer
            .write(make_status_event("t1", TaskState::Completed))
            .await
            .expect("second write should succeed");
        drop(writer);

        let first = expect_event(reader.read().await, "should read first event");
        assert_state(&first, TaskState::Working, "first event should be Working");
        let second = expect_event(reader.read().await, "should read second event");
        assert_state(&second, TaskState::Completed, "second event should be Completed");
        assert!(
            reader.read().await.is_none(),
            "should be EOF after all events"
        );
    }

    #[tokio::test]
    async fn read_returns_none_on_empty_closed_queue() {
        let (writer, mut reader) = new_in_memory_queue();
        drop(writer);
        assert!(
            reader.read().await.is_none(),
            "reading from an empty closed queue should return None"
        );
    }

    #[tokio::test]
    async fn write_after_all_readers_dropped_returns_error() {
        let (writer, reader) = new_in_memory_queue();
        drop(reader);
        let err = writer
            .write(make_status_event("t1", TaskState::Working))
            .await
            .expect_err("writing with no active receivers should return an error");
        let msg = format!("{err}");
        assert!(
            msg.contains("no active receivers"),
            "error message should mention missing receivers, got: {msg}"
        );
    }

    #[tokio::test]
    async fn close_is_no_op_and_succeeds() {
        let (writer, _reader) = new_in_memory_queue();
        assert!(writer.close().await.is_ok(), "close() should succeed");
    }

    #[tokio::test]
    async fn subscribe_creates_independent_reader() {
        let (writer, mut reader1) = new_in_memory_queue();
        let mut reader2 = writer.subscribe();
        writer
            .write(make_status_event("t1", TaskState::Working))
            .await
            .expect("write should succeed");
        drop(writer);

        let r1 = expect_event(reader1.read().await, "reader1 should receive the event");
        assert_state(&r1, TaskState::Working, "reader1 should see Working");
        let r2 = expect_event(reader2.read().await, "reader2 should receive the event");
        assert_state(&r2, TaskState::Working, "reader2 should see Working");
        assert!(reader1.read().await.is_none(), "reader1 should see EOF");
        assert!(reader2.read().await.is_none(), "reader2 should see EOF");
    }

    #[tokio::test]
    async fn subscriber_only_sees_events_after_subscribe() {
        let (writer, mut reader1) = new_in_memory_queue();
        writer
            .write(make_status_event("t1", TaskState::Submitted))
            .await
            .expect("write should succeed");
        let mut reader2 = writer.subscribe();
        writer
            .write(make_status_event("t1", TaskState::Working))
            .await
            .expect("write should succeed");
        drop(writer);

        let r1a = expect_event(reader1.read().await, "reader1 should see first event");
        assert_state(&r1a, TaskState::Submitted, "reader1 first event should be Submitted");
        let r1b = expect_event(reader1.read().await, "reader1 should see second event");
        assert_state(&r1b, TaskState::Working, "reader1 second event should be Working");
        assert!(reader1.read().await.is_none(), "reader1 should see EOF");

        let r2a = expect_event(reader2.read().await, "reader2 should see second event");
        assert_state(&r2a, TaskState::Working, "reader2 should see Working event");
        assert!(
            reader2.read().await.is_none(),
            "reader2 should see EOF after the one event it received"
        );
    }

    #[tokio::test]
    async fn oversized_event_is_rejected() {
        let (writer, _reader) = new_in_memory_queue_with_options(16, 10, DEFAULT_WRITE_TIMEOUT);
        let err = writer
            .write(make_status_event("t1", TaskState::Working))
            .await
            .expect_err("event exceeding max_event_size should be rejected");
        let msg = format!("{err}");
        assert!(
            msg.contains("exceeds maximum"),
            "error message should mention size limit, got: {msg}"
        );
    }

    #[tokio::test]
    async fn size_limit_is_inclusive() {
        let event = make_status_event("t1", TaskState::Working);
        let exact = serde_json::to_vec(&event)
            .expect("test event should serialize")
            .len();

        let (writer, _reader) =
            new_in_memory_queue_with_options(16, exact, DEFAULT_WRITE_TIMEOUT);
        writer
            .write(event)
            .await
            .expect("event of exactly max_event_size bytes should be accepted");

        let (writer, _reader) =
            new_in_memory_queue_with_options(16, exact - 1, DEFAULT_WRITE_TIMEOUT);
        let err = writer
            .write(make_status_event("t1", TaskState::Working))
            .await
            .expect_err("event one byte over max_event_size should be rejected");
        let msg = format!("{err}");
        assert!(
            msg.contains("exceeds maximum"),
            "error message should mention size limit, got: {msg}"
        );
    }

    #[test]
    fn counting_writer_flush_is_noop() {
        use std::io::Write;
        let mut cw = super::CountingWriter(0);
        cw.write_all(b"hello").expect("CountingWriter never fails");
        assert_eq!(cw.0, 5);
        cw.flush().expect("CountingWriter never fails");
        assert_eq!(cw.0, 5, "flush should not change the count");
    }

    #[tokio::test]
    async fn event_within_size_limit_is_accepted() {
        let (writer, mut reader) =
            new_in_memory_queue_with_options(16, DEFAULT_MAX_EVENT_SIZE, DEFAULT_WRITE_TIMEOUT);
        writer
            .write(make_status_event("t1", TaskState::Working))
            .await
            .expect("event within size limit should be accepted");
        drop(writer);
        let event = expect_event(reader.read().await, "reader should receive the event");
        assert_state(&event, TaskState::Working, "reader should see Working");
    }

    #[tokio::test]
    async fn lagged_reader_resumes_with_retained_event() {
        // With capacity 1 only the newest event is retained, so the reader
        // lags past the first two writes and must skip them rather than fail.
        let (writer, mut reader) = direct_queue(1);
        for state in [TaskState::Submitted, TaskState::Working, TaskState::Completed] {
            writer
                .write(make_status_event("t1", state))
                .await
                .expect("write should succeed");
        }
        drop(writer);

        let event = expect_event(reader.read().await, "lagged reader should still yield an event");
        assert_state(&event, TaskState::Completed, "only the newest event is retained");
        assert!(reader.read().await.is_none(), "reader should then see EOF");
    }

    #[tokio::test]
    async fn set_first_event_is_read_before_channel_events() {
        let (writer, mut reader) = direct_queue(4);
        reader.set_first_event(make_status_event("t1", TaskState::Submitted));
        writer
            .write(make_status_event("t1", TaskState::Working))
            .await
            .expect("write should succeed");
        drop(writer);

        let first = expect_event(reader.read().await, "seeded event should come first");
        assert_state(&first, TaskState::Submitted, "seeded event should be Submitted");
        let second = expect_event(reader.read().await, "channel event should follow");
        assert_state(&second, TaskState::Working, "channel event should be Working");
        assert!(reader.read().await.is_none(), "reader should then see EOF");
    }

    #[tokio::test]
    async fn with_first_event_is_read_before_channel_events() {
        let (writer, _unused_reader) = direct_queue(4);
        let mut reader = InMemoryQueueReader::with_first_event(
            writer.raw_subscribe(),
            make_status_event("t1", TaskState::Submitted),
        );
        writer
            .write(make_status_event("t1", TaskState::Working))
            .await
            .expect("write should succeed");
        drop(writer);

        let first = expect_event(reader.read().await, "seeded event should come first");
        assert_state(&first, TaskState::Submitted, "seeded event should be Submitted");
        let second = expect_event(reader.read().await, "channel event should follow");
        assert_state(&second, TaskState::Working, "channel event should be Working");
        assert!(reader.read().await.is_none(), "reader should then see EOF");
    }

    #[tokio::test]
    async fn persistence_channel_receives_each_event() {
        let (tx, rx) = broadcast::channel(4);
        let (persistence_tx, mut persistence_rx) = mpsc::channel(4);
        let writer = InMemoryQueueWriter::new_with_persistence(
            tx,
            persistence_tx,
            usize::MAX,
            TEST_WRITE_TIMEOUT,
        );
        let mut reader = InMemoryQueueReader::new(rx);

        writer
            .write(make_status_event("t1", TaskState::Working))
            .await
            .expect("write should succeed");

        let persisted = expect_event(
            persistence_rx.recv().await,
            "persistence channel should receive the event",
        );
        assert_state(&persisted, TaskState::Working, "persisted event should be Working");
        let live = expect_event(reader.read().await, "subscriber should receive the event");
        assert_state(&live, TaskState::Working, "broadcast event should be Working");
    }

    #[tokio::test]
    async fn closed_persistence_channel_does_not_fail_write() {
        let (tx, rx) = broadcast::channel(4);
        let (persistence_tx, persistence_rx) = mpsc::channel(1);
        drop(persistence_rx);
        let writer = InMemoryQueueWriter::new_with_persistence(
            tx,
            persistence_tx,
            usize::MAX,
            TEST_WRITE_TIMEOUT,
        );
        let mut reader = InMemoryQueueReader::new(rx);

        writer
            .write(make_status_event("t1", TaskState::Working))
            .await
            .expect("persistence is best-effort; write should still succeed");

        let live = expect_event(reader.read().await, "subscriber should still receive the event");
        assert_state(&live, TaskState::Working, "broadcast event should be Working");
    }
}