use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::{RwLock, broadcast};
use crate::error::{A2AError, Result};
pub(super) use crate::types::StreamResponse;
/// Alias for [`StreamResponse`], the item type carried by [`EventQueue`].
pub type Event = StreamResponse;
/// A broadcast channel that fans [`StreamResponse`] events out to every subscriber.
///
/// Each subscriber receives its own copy of every event sent after it subscribed.
/// Events sent while nobody is subscribed are not retained. A subscriber that falls
/// more than `capacity` events behind misses the oldest ones; tokio reports this to
/// the receiver as `RecvError::Lagged`.
#[derive(Debug)]
pub struct EventQueue {
    /// Sending half of the tokio broadcast channel. Receivers are created on demand
    /// via [`EventQueue::subscribe`].
    sender: broadcast::Sender<StreamResponse>,
}

impl EventQueue {
    /// Creates a queue whose underlying broadcast channel holds up to `capacity` events.
    ///
    /// A `capacity` of `0` is treated as `1`, because tokio's broadcast channel cannot
    /// be created with zero capacity.
    ///
    /// # Panics
    ///
    /// Panics if `capacity` is larger than `usize::MAX / 2` (tokio's upper limit).
    #[must_use]
    pub fn new(capacity: usize) -> Self {
        // `broadcast::channel(0)` panics; clamp so that `new(0)` yields a usable queue.
        // The initial receiver is discarded: subscribers attach later via `subscribe`.
        let (sender, _) = broadcast::channel(capacity.max(1));
        Self { sender }
    }

    /// Broadcasts `event` to every current subscriber.
    ///
    /// # Errors
    ///
    /// Returns [`A2AError::Other`] if the underlying channel rejects the event. For a
    /// tokio broadcast channel, that happens when there are no active subscribers; the
    /// event is then dropped.
    pub fn send(&self, event: StreamResponse) -> Result<()> {
        self.sender
            .send(event)
            .map_err(|e| A2AError::Other(format!("Failed to send event: {e}")))?;
        Ok(())
    }

    /// Returns a new receiver that observes every event sent after this call.
    #[must_use]
    pub fn subscribe(&self) -> broadcast::Receiver<StreamResponse> {
        self.sender.subscribe()
    }

    /// Returns the number of live receivers attached to this queue.
    #[must_use]
    pub fn subscriber_count(&self) -> usize {
        self.sender.receiver_count()
    }
}
impl Default for EventQueue {
fn default() -> Self {
Self::new(100)
}
}
/// Registry of per-task [`EventQueue`]s, keyed by task id.
///
/// Every queue created by one manager uses the same channel capacity.
#[derive(Debug)]
pub struct QueueManager {
    /// Task id → queue. Guarded by an async `RwLock` so that lookups can run concurrently.
    queues: Arc<RwLock<HashMap<String, Arc<EventQueue>>>>,
    /// Channel capacity passed to every [`EventQueue`] this manager creates.
    /// `with_capacity` clamps it to at least 1.
    capacity: usize,
}

impl Default for QueueManager {
    /// Equivalent to [`QueueManager::new`].
    ///
    /// Written by hand because a derived `Default` would set `capacity` to `0`. Tokio's
    /// `broadcast::channel` panics on a zero capacity, so the first queue creation would panic.
    fn default() -> Self {
        Self::new()
    }
}

impl QueueManager {
    /// Creates a manager whose queues each have a channel capacity of 100 events.
    #[must_use]
    pub fn new() -> Self {
        Self::with_capacity(100)
    }

    /// Creates a manager whose queues each have a channel capacity of `capacity` events.
    ///
    /// A `capacity` of `0` is treated as `1`, since a zero-capacity broadcast channel
    /// cannot be created.
    #[must_use]
    pub fn with_capacity(capacity: usize) -> Self {
        Self {
            queues: Arc::new(RwLock::new(HashMap::new())),
            capacity: capacity.max(1),
        }
    }

    /// Creates and registers a new queue for `task_id`.
    ///
    /// Returns `None`, leaving the existing queue untouched, if `task_id` already has one.
    pub async fn create_queue(&self, task_id: &str) -> Option<Arc<EventQueue>> {
        let mut queues = self.queues.write().await;
        if queues.contains_key(task_id) {
            return None;
        }
        let queue = Arc::new(EventQueue::new(self.capacity));
        queues.insert(task_id.to_owned(), Arc::clone(&queue));
        Some(queue)
    }

    /// Returns the queue registered for `task_id`, if any.
    pub async fn get_queue(&self, task_id: &str) -> Option<Arc<EventQueue>> {
        self.queues.read().await.get(task_id).cloned()
    }

    /// Returns the queue for `task_id`, creating and registering one if none exists.
    pub async fn get_or_create_queue(&self, task_id: &str) -> Arc<EventQueue> {
        // Fast path: a shared read lock, so concurrent lookups of existing queues
        // do not serialize behind one another.
        if let Some(queue) = self.get_queue(task_id).await {
            return queue;
        }
        // Slow path: another task may have inserted the queue between releasing the
        // read lock and acquiring the write lock. `entry` re-checks under the exclusive
        // lock, so at most one queue is ever created per task id.
        let mut queues = self.queues.write().await;
        let queue = queues
            .entry(task_id.to_owned())
            .or_insert_with(|| Arc::new(EventQueue::new(self.capacity)));
        Arc::clone(queue)
    }

    /// Unregisters and returns the queue for `task_id`, if any.
    ///
    /// Clones of the returned `Arc` held elsewhere stay usable. Existing subscribers
    /// keep their receivers until every clone of the queue is dropped.
    pub async fn remove_queue(&self, task_id: &str) -> Option<Arc<EventQueue>> {
        self.queues.write().await.remove(task_id)
    }

    /// Returns the number of queues currently registered.
    pub async fn queue_count(&self) -> usize {
        self.queues.read().await.len()
    }

    /// Broadcasts `event` on the queue registered for `task_id`.
    ///
    /// # Errors
    ///
    /// Returns [`A2AError::Other`] in either of these cases:
    /// - no queue is registered for `task_id`;
    /// - [`EventQueue::send`] fails, for example because the queue has no subscribers.
    pub async fn send_event(&self, task_id: &str, event: StreamResponse) -> Result<()> {
        let queue = self
            .get_queue(task_id)
            .await
            .ok_or_else(|| A2AError::Other(format!("No queue for task {task_id}")))?;
        queue.send(event)
    }
}