a2a_protocol_server/streaming/event_queue/
manager.rs1use std::collections::HashMap;
9use std::sync::Arc;
10
11use a2a_protocol_types::task::TaskId;
12use tokio::sync::RwLock;
13
14use a2a_protocol_types::error::A2aResult;
15use a2a_protocol_types::events::StreamResponse;
16
17use super::{
18 new_in_memory_queue_with_options, new_in_memory_queue_with_persistence, InMemoryQueueReader,
19 InMemoryQueueWriter, DEFAULT_MAX_EVENT_SIZE, DEFAULT_QUEUE_CAPACITY, DEFAULT_WRITE_TIMEOUT,
20};
21use crate::metrics::Metrics;
22
/// Registry of per-task in-memory event queues.
///
/// Each tracked task maps to a shared [`InMemoryQueueWriter`]. New queues are
/// created with the manager's configured capacity, maximum event size, and
/// write timeout. Cloning the manager clones the `Arc` around the writer map,
/// so all clones observe and mutate the same set of queues.
#[derive(Clone)]
pub struct EventQueueManager {
    /// Writers for the currently tracked queues, keyed by task ID.
    writers: Arc<RwLock<HashMap<TaskId, Arc<InMemoryQueueWriter>>>>,
    /// Capacity handed to each newly created queue.
    capacity: usize,
    /// Maximum event size handed to each newly created queue.
    max_event_size: usize,
    /// Write timeout handed to each newly created queue.
    write_timeout: std::time::Duration,
    /// Optional cap on how many queues may be stored in `writers` at once.
    max_concurrent_queues: Option<usize>,
    /// Optional metrics sink notified when the tracked queue count changes.
    metrics: Option<Arc<dyn Metrics>>,
}
44
45impl std::fmt::Debug for EventQueueManager {
46 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
47 f.debug_struct("EventQueueManager")
48 .field("writers", &"<RwLock<HashMap<...>>>")
49 .field("capacity", &self.capacity)
50 .field("max_event_size", &self.max_event_size)
51 .field("write_timeout", &self.write_timeout)
52 .field("max_concurrent_queues", &self.max_concurrent_queues)
53 .field("metrics", &self.metrics.is_some())
54 .finish()
55 }
56}
57
58impl Default for EventQueueManager {
59 fn default() -> Self {
60 Self {
61 writers: Arc::default(),
62 capacity: DEFAULT_QUEUE_CAPACITY,
63 max_event_size: DEFAULT_MAX_EVENT_SIZE,
64 write_timeout: DEFAULT_WRITE_TIMEOUT,
65 max_concurrent_queues: None,
66 metrics: None,
67 }
68 }
69}
70
71impl EventQueueManager {
72 #[must_use]
82 pub fn new() -> Self {
83 Self::default()
84 }
85
86 #[must_use]
88 pub fn with_capacity(capacity: usize) -> Self {
89 Self {
90 writers: Arc::default(),
91 capacity,
92 max_event_size: DEFAULT_MAX_EVENT_SIZE,
93 write_timeout: DEFAULT_WRITE_TIMEOUT,
94 max_concurrent_queues: None,
95 metrics: None,
96 }
97 }
98
99 #[must_use]
104 pub const fn with_write_timeout(mut self, timeout: std::time::Duration) -> Self {
105 self.write_timeout = timeout;
106 self
107 }
108
109 #[must_use]
114 pub const fn with_max_event_size(mut self, max_event_size: usize) -> Self {
115 self.max_event_size = max_event_size;
116 self
117 }
118
119 #[must_use]
121 pub fn with_metrics(mut self, metrics: Arc<dyn Metrics>) -> Self {
122 self.metrics = Some(metrics);
123 self
124 }
125
126 #[must_use]
131 pub const fn with_max_concurrent_queues(mut self, max: usize) -> Self {
132 self.max_concurrent_queues = Some(max);
133 self
134 }
135
136 pub async fn get_or_create(
147 &self,
148 task_id: &TaskId,
149 ) -> (Arc<InMemoryQueueWriter>, Option<InMemoryQueueReader>) {
150 let mut map = self.writers.write().await;
151 #[allow(clippy::option_if_let_else)]
152 let result = if let Some(existing) = map.get(task_id) {
153 (Arc::clone(existing), None)
154 } else if self
155 .max_concurrent_queues
156 .is_some_and(|max| map.len() >= max)
157 {
158 let (writer, _reader) = new_in_memory_queue_with_options(
161 self.capacity,
162 self.max_event_size,
163 self.write_timeout,
164 );
165 (Arc::new(writer), None)
166 } else {
167 let (writer, reader) = new_in_memory_queue_with_options(
168 self.capacity,
169 self.max_event_size,
170 self.write_timeout,
171 );
172 let writer = Arc::new(writer);
173 map.insert(task_id.clone(), Arc::clone(&writer));
174 (writer, Some(reader))
175 };
176 let queue_count = map.len();
177 drop(map);
178 if let Some(ref metrics) = self.metrics {
179 metrics.on_queue_depth_change(queue_count);
180 }
181 result
182 }
183
184 pub async fn get_or_create_with_persistence(
192 &self,
193 task_id: &TaskId,
194 ) -> (
195 Arc<InMemoryQueueWriter>,
196 Option<InMemoryQueueReader>,
197 Option<tokio::sync::mpsc::Receiver<A2aResult<StreamResponse>>>,
198 ) {
199 let mut map = self.writers.write().await;
200 #[allow(clippy::option_if_let_else)]
201 let result = if let Some(existing) = map.get(task_id) {
202 (Arc::clone(existing), None, None)
203 } else if self
204 .max_concurrent_queues
205 .is_some_and(|max| map.len() >= max)
206 {
207 let (writer, _reader) = new_in_memory_queue_with_options(
208 self.capacity,
209 self.max_event_size,
210 self.write_timeout,
211 );
212 (Arc::new(writer), None, None)
213 } else {
214 let (writer, reader, persistence_rx) = new_in_memory_queue_with_persistence(
215 self.capacity,
216 self.max_event_size,
217 self.write_timeout,
218 );
219 let writer = Arc::new(writer);
220 map.insert(task_id.clone(), Arc::clone(&writer));
221 (writer, Some(reader), Some(persistence_rx))
222 };
223 let queue_count = map.len();
224 drop(map);
225 if let Some(ref metrics) = self.metrics {
226 metrics.on_queue_depth_change(queue_count);
227 }
228 result
229 }
230
231 pub async fn subscribe(&self, task_id: &TaskId) -> Option<InMemoryQueueReader> {
239 let map = self.writers.read().await;
240 map.get(task_id).map(|writer| writer.subscribe())
241 }
242
243 pub async fn subscribe_with_snapshot(
252 &self,
253 task_id: &TaskId,
254 snapshot: StreamResponse,
255 ) -> Option<InMemoryQueueReader> {
256 let map = self.writers.read().await;
257 let writer = map.get(task_id)?;
258 let rx = writer.raw_subscribe();
262 drop(map);
263 Some(InMemoryQueueReader::with_first_event(rx, snapshot))
264 }
265
266 pub async fn destroy(&self, task_id: &TaskId) {
268 let mut map = self.writers.write().await;
269 map.remove(task_id);
270 let queue_count = map.len();
271 drop(map);
272 if let Some(ref metrics) = self.metrics {
273 metrics.on_queue_depth_change(queue_count);
274 }
275 }
276
277 pub async fn active_count(&self) -> usize {
279 let map = self.writers.read().await;
280 map.len()
281 }
282
283 pub async fn destroy_all(&self) {
285 let mut map = self.writers.write().await;
286 map.clear();
287 }
288}
289
#[cfg(test)]
mod tests {
    use super::*;
    use crate::streaming::event_queue::{EventQueueReader, EventQueueWriter};
    use a2a_protocol_types::events::{StreamResponse, TaskStatusUpdateEvent};
    use a2a_protocol_types::task::{ContextId, TaskState, TaskStatus};

    /// Builds a status-update event for `task_id` in the given `state`.
    fn make_status_event(task_id: &str, state: TaskState) -> StreamResponse {
        StreamResponse::StatusUpdate(TaskStatusUpdateEvent {
            task_id: TaskId::new(task_id),
            context_id: ContextId::new("ctx-test"),
            status: TaskStatus {
                state,
                message: None,
                timestamp: None,
            },
            metadata: None,
        })
    }

    #[tokio::test]
    async fn manager_get_or_create_new_task() {
        let manager = EventQueueManager::new();
        let task_id = TaskId::new("task-1");

        let (writer, reader) = manager.get_or_create(&task_id).await;
        assert!(
            reader.is_some(),
            "first get_or_create should return a reader"
        );

        writer
            .write(make_status_event("task-1", TaskState::Working))
            .await
            .expect("write through manager writer should succeed");

        assert_eq!(
            manager.active_count().await,
            1,
            "should have 1 active queue"
        );
    }

    #[tokio::test]
    async fn manager_get_or_create_existing_task_returns_no_reader() {
        let manager = EventQueueManager::new();
        let task_id = TaskId::new("task-1");

        let (_w1, r1) = manager.get_or_create(&task_id).await;
        assert!(r1.is_some(), "first call should return a reader");

        let (_w2, r2) = manager.get_or_create(&task_id).await;
        assert!(
            r2.is_none(),
            "second call for same task should return None reader"
        );

        assert_eq!(
            manager.active_count().await,
            1,
            "should still have only 1 active queue"
        );
    }

    #[tokio::test]
    async fn manager_get_or_create_with_persistence_new_and_existing() {
        let manager = EventQueueManager::new();
        let task_id = TaskId::new("task-1");

        let (_w1, r1, p1) = manager.get_or_create_with_persistence(&task_id).await;
        assert!(r1.is_some(), "first call should return a reader");
        assert!(
            p1.is_some(),
            "first call should return a persistence receiver"
        );

        let (_w2, r2, p2) = manager.get_or_create_with_persistence(&task_id).await;
        assert!(
            r2.is_none(),
            "second call for same task should return None reader"
        );
        assert!(
            p2.is_none(),
            "second call for same task should return None persistence receiver"
        );

        assert_eq!(
            manager.active_count().await,
            1,
            "should still have only 1 active queue"
        );
    }

    #[tokio::test]
    async fn manager_get_or_create_with_persistence_respects_limit() {
        let manager = EventQueueManager::new().with_max_concurrent_queues(1);

        let (_w1, r1, p1) = manager
            .get_or_create_with_persistence(&TaskId::new("t1"))
            .await;
        assert!(r1.is_some(), "first queue should return a reader");
        assert!(p1.is_some(), "first queue should return a persistence receiver");

        let (_w2, r2, p2) = manager
            .get_or_create_with_persistence(&TaskId::new("t2"))
            .await;
        assert!(r2.is_none(), "over-limit queue should return None reader");
        assert!(
            p2.is_none(),
            "over-limit queue should return None persistence receiver"
        );

        assert_eq!(
            manager.active_count().await,
            1,
            "over-limit queue should not be stored"
        );
    }

    #[tokio::test]
    async fn manager_subscribe_existing_task() {
        let manager = EventQueueManager::new();
        let task_id = TaskId::new("task-1");

        let (writer, _reader) = manager.get_or_create(&task_id).await;

        let mut sub_reader = manager
            .subscribe(&task_id)
            .await
            .expect("subscribe should return a reader for existing task");

        writer
            .write(make_status_event("task-1", TaskState::Working))
            .await
            .expect("write should succeed");
        drop(writer);

        let r = sub_reader.read().await;
        assert!(r.is_some(), "subscriber should receive the event");
    }

    #[tokio::test]
    async fn manager_subscribe_nonexistent_task_returns_none() {
        let manager = EventQueueManager::new();
        let task_id = TaskId::new("no-such-task");

        let sub = manager.subscribe(&task_id).await;
        assert!(
            sub.is_none(),
            "subscribe should return None for nonexistent task"
        );
    }

    #[tokio::test]
    async fn manager_subscribe_with_snapshot_nonexistent_task_returns_none() {
        let manager = EventQueueManager::new();
        let task_id = TaskId::new("no-such-task");

        let sub = manager
            .subscribe_with_snapshot(
                &task_id,
                make_status_event("no-such-task", TaskState::Working),
            )
            .await;
        assert!(
            sub.is_none(),
            "subscribe_with_snapshot should return None for nonexistent task"
        );
    }

    #[tokio::test]
    async fn manager_subscribe_with_snapshot_existing_task_yields_event() {
        let manager = EventQueueManager::new();
        let task_id = TaskId::new("task-1");

        let (writer, _reader) = manager.get_or_create(&task_id).await;

        let Some(mut sub_reader) = manager
            .subscribe_with_snapshot(&task_id, make_status_event("task-1", TaskState::Working))
            .await
        else {
            panic!("subscribe_with_snapshot should return a reader for existing task");
        };

        writer
            .write(make_status_event("task-1", TaskState::Working))
            .await
            .expect("write should succeed");
        drop(writer);

        let r = sub_reader.read().await;
        assert!(r.is_some(), "snapshot subscriber should receive an event");
    }

    #[tokio::test]
    async fn manager_destroy_removes_queue() {
        let manager = EventQueueManager::new();
        let task_id = TaskId::new("task-1");

        let (_writer, _reader) = manager.get_or_create(&task_id).await;
        assert_eq!(manager.active_count().await, 1);

        manager.destroy(&task_id).await;
        assert_eq!(
            manager.active_count().await,
            0,
            "destroy should remove the queue"
        );
    }

    #[tokio::test]
    async fn manager_destroy_all_clears_queues() {
        let manager = EventQueueManager::new();

        let _q1 = manager.get_or_create(&TaskId::new("t1")).await;
        let _q2 = manager.get_or_create(&TaskId::new("t2")).await;
        assert_eq!(manager.active_count().await, 2);

        manager.destroy_all().await;
        assert_eq!(
            manager.active_count().await,
            0,
            "destroy_all should clear all queues"
        );
    }

    #[tokio::test]
    async fn manager_max_concurrent_queues_enforced() {
        let manager = EventQueueManager::new().with_max_concurrent_queues(1);

        let (_w1, r1) = manager.get_or_create(&TaskId::new("t1")).await;
        assert!(r1.is_some(), "first queue should be created successfully");

        let (_w2, r2) = manager.get_or_create(&TaskId::new("t2")).await;
        assert!(
            r2.is_none(),
            "second queue should return None reader when limit is reached"
        );
        assert_eq!(
            manager.active_count().await,
            1,
            "should still have only 1 queue (second was not stored)"
        );
    }

    #[tokio::test]
    async fn manager_with_write_timeout() {
        let manager =
            EventQueueManager::new().with_write_timeout(std::time::Duration::from_secs(10));
        let task_id = TaskId::new("t1");

        let (writer, reader) = manager.get_or_create(&task_id).await;
        assert!(reader.is_some());
        writer
            .write(make_status_event("t1", TaskState::Working))
            .await
            .expect("write should succeed with custom write_timeout");
    }

    #[tokio::test]
    async fn manager_with_capacity_and_max_event_size() {
        let manager = EventQueueManager::with_capacity(4).with_max_event_size(10);
        let task_id = TaskId::new("t1");

        let (writer, _reader) = manager.get_or_create(&task_id).await;

        let event = make_status_event("t1", TaskState::Working);
        let result = writer.write(event).await;
        assert!(
            result.is_err(),
            "event should be rejected by the size limit configured on the manager"
        );
    }
}