a2a_protocol_server/streaming/event_queue/
manager.rs1use std::collections::HashMap;
9use std::sync::Arc;
10
11use a2a_protocol_types::task::TaskId;
12use tokio::sync::RwLock;
13
14use a2a_protocol_types::error::A2aResult;
15use a2a_protocol_types::events::StreamResponse;
16
17use super::{
18 new_in_memory_queue_with_options, new_in_memory_queue_with_persistence, InMemoryQueueReader,
19 InMemoryQueueWriter, DEFAULT_MAX_EVENT_SIZE, DEFAULT_QUEUE_CAPACITY, DEFAULT_WRITE_TIMEOUT,
20};
21use crate::metrics::Metrics;
22
23#[derive(Clone)]
31pub struct EventQueueManager {
32 writers: Arc<RwLock<HashMap<TaskId, Arc<InMemoryQueueWriter>>>>,
33 capacity: usize,
35 max_event_size: usize,
37 write_timeout: std::time::Duration,
39 max_concurrent_queues: Option<usize>,
41 metrics: Option<Arc<dyn Metrics>>,
43}
44
45impl std::fmt::Debug for EventQueueManager {
46 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
47 f.debug_struct("EventQueueManager")
48 .field("writers", &"<RwLock<HashMap<...>>>")
49 .field("capacity", &self.capacity)
50 .field("max_event_size", &self.max_event_size)
51 .field("write_timeout", &self.write_timeout)
52 .field("max_concurrent_queues", &self.max_concurrent_queues)
53 .field("metrics", &self.metrics.is_some())
54 .finish()
55 }
56}
57
58impl Default for EventQueueManager {
59 fn default() -> Self {
60 Self {
61 writers: Arc::default(),
62 capacity: DEFAULT_QUEUE_CAPACITY,
63 max_event_size: DEFAULT_MAX_EVENT_SIZE,
64 write_timeout: DEFAULT_WRITE_TIMEOUT,
65 max_concurrent_queues: None,
66 metrics: None,
67 }
68 }
69}
70
71impl EventQueueManager {
72 #[must_use]
82 pub fn new() -> Self {
83 Self::default()
84 }
85
86 #[must_use]
88 pub fn with_capacity(capacity: usize) -> Self {
89 Self {
90 writers: Arc::default(),
91 capacity,
92 max_event_size: DEFAULT_MAX_EVENT_SIZE,
93 write_timeout: DEFAULT_WRITE_TIMEOUT,
94 max_concurrent_queues: None,
95 metrics: None,
96 }
97 }
98
99 #[must_use]
104 pub const fn with_write_timeout(mut self, timeout: std::time::Duration) -> Self {
105 self.write_timeout = timeout;
106 self
107 }
108
109 #[must_use]
114 pub const fn with_max_event_size(mut self, max_event_size: usize) -> Self {
115 self.max_event_size = max_event_size;
116 self
117 }
118
119 #[must_use]
121 pub fn with_metrics(mut self, metrics: Arc<dyn Metrics>) -> Self {
122 self.metrics = Some(metrics);
123 self
124 }
125
126 #[must_use]
131 pub const fn with_max_concurrent_queues(mut self, max: usize) -> Self {
132 self.max_concurrent_queues = Some(max);
133 self
134 }
135
136 pub async fn get_or_create(
147 &self,
148 task_id: &TaskId,
149 ) -> (Arc<InMemoryQueueWriter>, Option<InMemoryQueueReader>) {
150 let mut map = self.writers.write().await;
151 #[allow(clippy::option_if_let_else)]
152 let result = if let Some(existing) = map.get(task_id) {
153 (Arc::clone(existing), None)
154 } else if self
155 .max_concurrent_queues
156 .is_some_and(|max| map.len() >= max)
157 {
158 let (writer, _reader) = new_in_memory_queue_with_options(
161 self.capacity,
162 self.max_event_size,
163 self.write_timeout,
164 );
165 (Arc::new(writer), None)
166 } else {
167 let (writer, reader) = new_in_memory_queue_with_options(
168 self.capacity,
169 self.max_event_size,
170 self.write_timeout,
171 );
172 let writer = Arc::new(writer);
173 map.insert(task_id.clone(), Arc::clone(&writer));
174 (writer, Some(reader))
175 };
176 let queue_count = map.len();
177 drop(map);
178 if let Some(ref metrics) = self.metrics {
179 metrics.on_queue_depth_change(queue_count);
180 }
181 result
182 }
183
184 pub async fn get_or_create_with_persistence(
192 &self,
193 task_id: &TaskId,
194 ) -> (
195 Arc<InMemoryQueueWriter>,
196 Option<InMemoryQueueReader>,
197 Option<tokio::sync::mpsc::Receiver<A2aResult<StreamResponse>>>,
198 ) {
199 let mut map = self.writers.write().await;
200 #[allow(clippy::option_if_let_else)]
201 let result = if let Some(existing) = map.get(task_id) {
202 (Arc::clone(existing), None, None)
203 } else if self
204 .max_concurrent_queues
205 .is_some_and(|max| map.len() >= max)
206 {
207 let (writer, _reader) = new_in_memory_queue_with_options(
208 self.capacity,
209 self.max_event_size,
210 self.write_timeout,
211 );
212 (Arc::new(writer), None, None)
213 } else {
214 let (writer, reader, persistence_rx) = new_in_memory_queue_with_persistence(
215 self.capacity,
216 self.max_event_size,
217 self.write_timeout,
218 );
219 let writer = Arc::new(writer);
220 map.insert(task_id.clone(), Arc::clone(&writer));
221 (writer, Some(reader), Some(persistence_rx))
222 };
223 let queue_count = map.len();
224 drop(map);
225 if let Some(ref metrics) = self.metrics {
226 metrics.on_queue_depth_change(queue_count);
227 }
228 result
229 }
230
231 pub async fn subscribe(&self, task_id: &TaskId) -> Option<InMemoryQueueReader> {
239 let map = self.writers.read().await;
240 map.get(task_id).map(|writer| writer.subscribe())
241 }
242
243 pub async fn destroy(&self, task_id: &TaskId) {
245 let mut map = self.writers.write().await;
246 map.remove(task_id);
247 let queue_count = map.len();
248 drop(map);
249 if let Some(ref metrics) = self.metrics {
250 metrics.on_queue_depth_change(queue_count);
251 }
252 }
253
254 pub async fn active_count(&self) -> usize {
256 let map = self.writers.read().await;
257 map.len()
258 }
259
260 pub async fn destroy_all(&self) {
262 let mut map = self.writers.write().await;
263 map.clear();
264 }
265}
266
#[cfg(test)]
mod tests {
    use super::*;
    use crate::streaming::event_queue::EventQueueWriter;
    use a2a_protocol_types::events::{StreamResponse, TaskStatusUpdateEvent};
    use a2a_protocol_types::task::{ContextId, TaskState, TaskStatus};

    /// Builds a `StatusUpdate` event for `task_id` with the given `state` and a
    /// fixed context id.
    fn make_status_event(task_id: &str, state: TaskState) -> StreamResponse {
        StreamResponse::StatusUpdate(TaskStatusUpdateEvent {
            task_id: TaskId::new(task_id),
            context_id: ContextId::new("ctx-test"),
            status: TaskStatus {
                state,
                message: None,
                timestamp: None,
            },
            metadata: None,
        })
    }

    #[tokio::test]
    async fn manager_get_or_create_new_task() {
        let manager = EventQueueManager::new();
        let task_id = TaskId::new("task-1");

        let (writer, reader) = manager.get_or_create(&task_id).await;
        assert!(
            reader.is_some(),
            "first get_or_create should return a reader"
        );

        writer
            .write(make_status_event("task-1", TaskState::Working))
            .await
            .expect("write through manager writer should succeed");

        assert_eq!(
            manager.active_count().await,
            1,
            "should have 1 active queue"
        );
    }

    #[tokio::test]
    async fn manager_get_or_create_existing_task_returns_no_reader() {
        let manager = EventQueueManager::new();
        let task_id = TaskId::new("task-1");

        let (_w1, r1) = manager.get_or_create(&task_id).await;
        assert!(r1.is_some(), "first call should return a reader");

        let (_w2, r2) = manager.get_or_create(&task_id).await;
        assert!(
            r2.is_none(),
            "second call for same task should return None reader"
        );

        assert_eq!(
            manager.active_count().await,
            1,
            "should still have only 1 active queue"
        );
    }

    #[tokio::test]
    async fn manager_subscribe_existing_task() {
        use crate::streaming::event_queue::EventQueueReader;

        let manager = EventQueueManager::new();
        let task_id = TaskId::new("task-1");

        let (writer, _reader) = manager.get_or_create(&task_id).await;

        let mut sub_reader = manager
            .subscribe(&task_id)
            .await
            .expect("subscribe should return a reader for existing task");

        writer
            .write(make_status_event("task-1", TaskState::Working))
            .await
            .expect("write should succeed");
        drop(writer);

        let r = sub_reader.read().await;
        assert!(r.is_some(), "subscriber should receive the event");
    }

    #[tokio::test]
    async fn manager_subscribe_nonexistent_task_returns_none() {
        let manager = EventQueueManager::new();
        let task_id = TaskId::new("no-such-task");

        let sub = manager.subscribe(&task_id).await;
        assert!(
            sub.is_none(),
            "subscribe should return None for nonexistent task"
        );
    }

    #[tokio::test]
    async fn manager_destroy_removes_queue() {
        let manager = EventQueueManager::new();
        let task_id = TaskId::new("task-1");

        let (_writer, _reader) = manager.get_or_create(&task_id).await;
        assert_eq!(manager.active_count().await, 1);

        manager.destroy(&task_id).await;
        assert_eq!(
            manager.active_count().await,
            0,
            "destroy should remove the queue"
        );
    }

    #[tokio::test]
    async fn manager_destroy_nonexistent_task_is_noop() {
        let manager = EventQueueManager::new();
        let (_writer, _reader) = manager.get_or_create(&TaskId::new("t1")).await;

        manager.destroy(&TaskId::new("no-such-task")).await;
        assert_eq!(
            manager.active_count().await,
            1,
            "destroying an unknown task should leave other queues untouched"
        );
    }

    #[tokio::test]
    async fn manager_destroy_all_clears_queues() {
        let manager = EventQueueManager::new();

        let _q1 = manager.get_or_create(&TaskId::new("t1")).await;
        let _q2 = manager.get_or_create(&TaskId::new("t2")).await;
        assert_eq!(manager.active_count().await, 2);

        manager.destroy_all().await;
        assert_eq!(
            manager.active_count().await,
            0,
            "destroy_all should clear all queues"
        );
    }

    #[tokio::test]
    async fn manager_max_concurrent_queues_enforced() {
        let manager = EventQueueManager::new().with_max_concurrent_queues(1);

        let (_w1, r1) = manager.get_or_create(&TaskId::new("t1")).await;
        assert!(r1.is_some(), "first queue should be created successfully");

        let (_w2, r2) = manager.get_or_create(&TaskId::new("t2")).await;
        assert!(
            r2.is_none(),
            "second queue should return None reader when limit is reached"
        );
        assert_eq!(
            manager.active_count().await,
            1,
            "should still have only 1 queue (second was not stored)"
        );
    }

    #[tokio::test]
    async fn manager_get_or_create_with_persistence_new_and_existing() {
        let manager = EventQueueManager::new();
        let task_id = TaskId::new("task-1");

        let (_w1, r1, p1) = manager.get_or_create_with_persistence(&task_id).await;
        assert!(r1.is_some(), "first call should return a reader");
        assert!(
            p1.is_some(),
            "first call should return a persistence receiver"
        );

        let (_w2, r2, p2) = manager.get_or_create_with_persistence(&task_id).await;
        assert!(
            r2.is_none(),
            "second call for same task should return None reader"
        );
        assert!(
            p2.is_none(),
            "second call for same task should return None persistence receiver"
        );

        assert_eq!(
            manager.active_count().await,
            1,
            "should still have only 1 active queue"
        );
    }

    #[tokio::test]
    async fn manager_get_or_create_with_persistence_respects_limit() {
        let manager = EventQueueManager::new().with_max_concurrent_queues(1);

        let (_w1, r1, p1) = manager
            .get_or_create_with_persistence(&TaskId::new("t1"))
            .await;
        assert!(r1.is_some(), "first queue should be created successfully");
        assert!(
            p1.is_some(),
            "first queue should come with a persistence receiver"
        );

        let (_w2, r2, p2) = manager
            .get_or_create_with_persistence(&TaskId::new("t2"))
            .await;
        assert!(
            r2.is_none(),
            "no reader should be returned once the queue limit is reached"
        );
        assert!(
            p2.is_none(),
            "no persistence receiver should be returned once the queue limit is reached"
        );
        assert_eq!(
            manager.active_count().await,
            1,
            "the over-limit queue must not be stored"
        );
    }

    #[tokio::test]
    async fn manager_with_write_timeout() {
        let manager =
            EventQueueManager::new().with_write_timeout(std::time::Duration::from_secs(10));
        let task_id = TaskId::new("t1");

        let (writer, reader) = manager.get_or_create(&task_id).await;
        assert!(reader.is_some());
        writer
            .write(make_status_event("t1", TaskState::Working))
            .await
            .expect("write should succeed with custom write_timeout");
    }

    #[tokio::test]
    async fn manager_with_capacity_and_max_event_size() {
        let manager = EventQueueManager::with_capacity(4).with_max_event_size(10);
        let task_id = TaskId::new("t1");

        let (writer, _reader) = manager.get_or_create(&task_id).await;

        let event = make_status_event("t1", TaskState::Working);
        let result = writer.write(event).await;
        assert!(
            result.is_err(),
            "event should be rejected by the size limit configured on the manager"
        );
    }
}