use std::collections::HashMap;
use std::sync::Arc;
use a2a_protocol_types::task::TaskId;
use tokio::sync::RwLock;
use a2a_protocol_types::error::A2aResult;
use a2a_protocol_types::events::StreamResponse;
use super::{
new_in_memory_queue_with_options, new_in_memory_queue_with_persistence, InMemoryQueueReader,
InMemoryQueueWriter, DEFAULT_MAX_EVENT_SIZE, DEFAULT_QUEUE_CAPACITY, DEFAULT_WRITE_TIMEOUT,
};
use crate::metrics::Metrics;
#[derive(Clone)]
pub struct EventQueueManager {
writers: Arc<RwLock<HashMap<TaskId, Arc<InMemoryQueueWriter>>>>,
capacity: usize,
max_event_size: usize,
write_timeout: std::time::Duration,
max_concurrent_queues: Option<usize>,
metrics: Option<Arc<dyn Metrics>>,
}
impl std::fmt::Debug for EventQueueManager {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("EventQueueManager")
.field("writers", &"<RwLock<HashMap<...>>>")
.field("capacity", &self.capacity)
.field("max_event_size", &self.max_event_size)
.field("write_timeout", &self.write_timeout)
.field("max_concurrent_queues", &self.max_concurrent_queues)
.field("metrics", &self.metrics.is_some())
.finish()
}
}
impl Default for EventQueueManager {
fn default() -> Self {
Self {
writers: Arc::default(),
capacity: DEFAULT_QUEUE_CAPACITY,
max_event_size: DEFAULT_MAX_EVENT_SIZE,
write_timeout: DEFAULT_WRITE_TIMEOUT,
max_concurrent_queues: None,
metrics: None,
}
}
}
impl EventQueueManager {
#[must_use]
pub fn new() -> Self {
Self::default()
}
#[must_use]
pub fn with_capacity(capacity: usize) -> Self {
Self {
writers: Arc::default(),
capacity,
max_event_size: DEFAULT_MAX_EVENT_SIZE,
write_timeout: DEFAULT_WRITE_TIMEOUT,
max_concurrent_queues: None,
metrics: None,
}
}
#[must_use]
pub const fn with_write_timeout(mut self, timeout: std::time::Duration) -> Self {
self.write_timeout = timeout;
self
}
#[must_use]
pub const fn with_max_event_size(mut self, max_event_size: usize) -> Self {
self.max_event_size = max_event_size;
self
}
#[must_use]
pub fn with_metrics(mut self, metrics: Arc<dyn Metrics>) -> Self {
self.metrics = Some(metrics);
self
}
#[must_use]
pub const fn with_max_concurrent_queues(mut self, max: usize) -> Self {
self.max_concurrent_queues = Some(max);
self
}
pub async fn get_or_create(
&self,
task_id: &TaskId,
) -> (Arc<InMemoryQueueWriter>, Option<InMemoryQueueReader>) {
let mut map = self.writers.write().await;
#[allow(clippy::option_if_let_else)]
let result = if let Some(existing) = map.get(task_id) {
(Arc::clone(existing), None)
} else if self
.max_concurrent_queues
.is_some_and(|max| map.len() >= max)
{
let (writer, _reader) = new_in_memory_queue_with_options(
self.capacity,
self.max_event_size,
self.write_timeout,
);
(Arc::new(writer), None)
} else {
let (writer, reader) = new_in_memory_queue_with_options(
self.capacity,
self.max_event_size,
self.write_timeout,
);
let writer = Arc::new(writer);
map.insert(task_id.clone(), Arc::clone(&writer));
(writer, Some(reader))
};
let queue_count = map.len();
drop(map);
if let Some(ref metrics) = self.metrics {
metrics.on_queue_depth_change(queue_count);
}
result
}
pub async fn get_or_create_with_persistence(
&self,
task_id: &TaskId,
) -> (
Arc<InMemoryQueueWriter>,
Option<InMemoryQueueReader>,
Option<tokio::sync::mpsc::Receiver<A2aResult<StreamResponse>>>,
) {
let mut map = self.writers.write().await;
#[allow(clippy::option_if_let_else)]
let result = if let Some(existing) = map.get(task_id) {
(Arc::clone(existing), None, None)
} else if self
.max_concurrent_queues
.is_some_and(|max| map.len() >= max)
{
let (writer, _reader) = new_in_memory_queue_with_options(
self.capacity,
self.max_event_size,
self.write_timeout,
);
(Arc::new(writer), None, None)
} else {
let (writer, reader, persistence_rx) = new_in_memory_queue_with_persistence(
self.capacity,
self.max_event_size,
self.write_timeout,
);
let writer = Arc::new(writer);
map.insert(task_id.clone(), Arc::clone(&writer));
(writer, Some(reader), Some(persistence_rx))
};
let queue_count = map.len();
drop(map);
if let Some(ref metrics) = self.metrics {
metrics.on_queue_depth_change(queue_count);
}
result
}
pub async fn subscribe(&self, task_id: &TaskId) -> Option<InMemoryQueueReader> {
let map = self.writers.read().await;
map.get(task_id).map(|writer| writer.subscribe())
}
pub async fn subscribe_with_snapshot(
&self,
task_id: &TaskId,
snapshot: StreamResponse,
) -> Option<InMemoryQueueReader> {
let map = self.writers.read().await;
let writer = map.get(task_id)?;
let rx = writer.raw_subscribe();
drop(map);
Some(InMemoryQueueReader::with_first_event(rx, snapshot))
}
pub async fn destroy(&self, task_id: &TaskId) {
let mut map = self.writers.write().await;
map.remove(task_id);
let queue_count = map.len();
drop(map);
if let Some(ref metrics) = self.metrics {
metrics.on_queue_depth_change(queue_count);
}
}
pub async fn active_count(&self) -> usize {
let map = self.writers.read().await;
map.len()
}
pub async fn destroy_all(&self) {
let mut map = self.writers.write().await;
map.clear();
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::streaming::event_queue::EventQueueWriter;
use a2a_protocol_types::events::{StreamResponse, TaskStatusUpdateEvent};
use a2a_protocol_types::task::{ContextId, TaskState, TaskStatus};
fn make_status_event(task_id: &str, state: TaskState) -> StreamResponse {
StreamResponse::StatusUpdate(TaskStatusUpdateEvent {
task_id: TaskId::new(task_id),
context_id: ContextId::new("ctx-test"),
status: TaskStatus {
state,
message: None,
timestamp: None,
},
metadata: None,
})
}
#[tokio::test]
async fn manager_get_or_create_new_task() {
let manager = EventQueueManager::new();
let task_id = TaskId::new("task-1");
let (writer, reader) = manager.get_or_create(&task_id).await;
assert!(
reader.is_some(),
"first get_or_create should return a reader"
);
writer
.write(make_status_event("task-1", TaskState::Working))
.await
.expect("write through manager writer should succeed");
assert_eq!(
manager.active_count().await,
1,
"should have 1 active queue"
);
}
#[tokio::test]
async fn manager_get_or_create_existing_task_returns_no_reader() {
let manager = EventQueueManager::new();
let task_id = TaskId::new("task-1");
let (_w1, r1) = manager.get_or_create(&task_id).await;
assert!(r1.is_some(), "first call should return a reader");
let (_w2, r2) = manager.get_or_create(&task_id).await;
assert!(
r2.is_none(),
"second call for same task should return None reader"
);
assert_eq!(
manager.active_count().await,
1,
"should still have only 1 active queue"
);
}
#[tokio::test]
async fn manager_subscribe_existing_task() {
use crate::streaming::event_queue::EventQueueReader;
let manager = EventQueueManager::new();
let task_id = TaskId::new("task-1");
let (writer, _reader) = manager.get_or_create(&task_id).await;
let sub = manager.subscribe(&task_id).await;
assert!(
sub.is_some(),
"subscribe should return a reader for existing task"
);
let mut sub_reader = sub.unwrap();
writer
.write(make_status_event("task-1", TaskState::Working))
.await
.expect("write should succeed");
drop(writer);
let r = sub_reader.read().await;
assert!(r.is_some(), "subscriber should receive the event");
}
#[tokio::test]
async fn manager_subscribe_nonexistent_task_returns_none() {
let manager = EventQueueManager::new();
let task_id = TaskId::new("no-such-task");
let sub = manager.subscribe(&task_id).await;
assert!(
sub.is_none(),
"subscribe should return None for nonexistent task"
);
}
#[tokio::test]
async fn manager_destroy_removes_queue() {
let manager = EventQueueManager::new();
let task_id = TaskId::new("task-1");
let (_writer, _reader) = manager.get_or_create(&task_id).await;
assert_eq!(manager.active_count().await, 1);
manager.destroy(&task_id).await;
assert_eq!(
manager.active_count().await,
0,
"destroy should remove the queue"
);
}
#[tokio::test]
async fn manager_destroy_all_clears_queues() {
let manager = EventQueueManager::new();
let _q1 = manager.get_or_create(&TaskId::new("t1")).await;
let _q2 = manager.get_or_create(&TaskId::new("t2")).await;
assert_eq!(manager.active_count().await, 2);
manager.destroy_all().await;
assert_eq!(
manager.active_count().await,
0,
"destroy_all should clear all queues"
);
}
#[tokio::test]
async fn manager_max_concurrent_queues_enforced() {
let manager = EventQueueManager::new().with_max_concurrent_queues(1);
let (_w1, r1) = manager.get_or_create(&TaskId::new("t1")).await;
assert!(r1.is_some(), "first queue should be created successfully");
let (_w2, r2) = manager.get_or_create(&TaskId::new("t2")).await;
assert!(
r2.is_none(),
"second queue should return None reader when limit is reached"
);
assert_eq!(
manager.active_count().await,
1,
"should still have only 1 queue (second was not stored)"
);
}
#[tokio::test]
async fn manager_with_write_timeout() {
let manager =
EventQueueManager::new().with_write_timeout(std::time::Duration::from_secs(10));
let task_id = TaskId::new("t1");
let (writer, reader) = manager.get_or_create(&task_id).await;
assert!(reader.is_some());
writer
.write(make_status_event("t1", TaskState::Working))
.await
.expect("write should succeed with custom write_timeout");
}
#[tokio::test]
async fn manager_with_capacity_and_max_event_size() {
let manager = EventQueueManager::with_capacity(4).with_max_event_size(10);
let task_id = TaskId::new("t1");
let (writer, _reader) = manager.get_or_create(&task_id).await;
let event = make_status_event("t1", TaskState::Working);
let result = writer.write(event).await;
assert!(
result.is_err(),
"event should be rejected by the size limit configured on the manager"
);
}
}