1use async_graphql::{Context, ID, Subscription};
2use futures_util::Stream;
3use std::sync::Arc;
4use tokio::sync::RwLock;
5use tokio::sync::broadcast;
6use tracing::{info, warn};
7
/// Default number of events retained by the broadcast channel (256).
///
/// A subscriber that falls further behind than this misses events and
/// observes a lag instead of receiving them.
pub const DEFAULT_CHANNEL_CAPACITY: usize = 1 << 8;
10
/// A user-related change published through `EventBroadcaster`.
///
/// Every subscriber receives its own clone of each event; the GraphQL
/// subscription resolvers each keep only the variant they expose.
#[derive(Debug, Clone)]
pub enum UserEvent {
    /// A created user; carries the full `User` payload.
    Created(crate::schema::User),
    /// An updated user; carries the full `User` payload.
    Updated(crate::schema::User),
    /// A deleted user; carries only the user's ID.
    Deleted(ID),
}
18
19#[derive(Clone)]
25pub struct EventBroadcaster {
26 tx: Arc<RwLock<broadcast::Sender<UserEvent>>>,
27 capacity: usize,
28}
29
30impl EventBroadcaster {
31 pub fn new() -> Self {
41 Self::with_capacity(DEFAULT_CHANNEL_CAPACITY)
42 }
43
44 pub fn with_capacity(capacity: usize) -> Self {
61 assert!(capacity > 0, "Channel capacity must be greater than 0");
62 let (tx, _) = broadcast::channel(capacity);
63 Self {
64 tx: Arc::new(RwLock::new(tx)),
65 capacity,
66 }
67 }
68
69 pub fn capacity(&self) -> usize {
71 self.capacity
72 }
73
74 pub async fn broadcast(&self, event: UserEvent) -> usize {
80 let tx = self.tx.read().await;
81 match tx.send(event) {
82 Ok(receiver_count) => {
83 info!(receiver_count, "broadcast event sent to subscribers");
84 receiver_count
85 }
86 Err(_) => {
87 info!("broadcast event dropped: no active subscribers");
89 0
90 }
91 }
92 }
93
94 pub async fn subscribe(&self) -> broadcast::Receiver<UserEvent> {
98 self.tx.read().await.subscribe()
99 }
100}
101
102impl Default for EventBroadcaster {
103 fn default() -> Self {
104 Self::new()
105 }
106}
107
108fn receiver_to_stream(mut rx: broadcast::Receiver<UserEvent>) -> impl Stream<Item = UserEvent> {
114 async_stream::stream! {
115 loop {
116 match rx.recv().await {
117 Ok(event) => yield event,
118 Err(broadcast::error::RecvError::Lagged(skipped)) => {
119 warn!(
120 skipped,
121 "subscription receiver lagged, messages were dropped"
122 );
123 continue;
125 }
126 Err(broadcast::error::RecvError::Closed) => {
127 break;
129 }
130 }
131 }
132 }
133}
134
135pub struct SubscriptionRoot;
137
138#[Subscription]
139impl SubscriptionRoot {
140 async fn user_created<'ctx>(
141 &self,
142 ctx: &Context<'ctx>,
143 ) -> impl Stream<Item = crate::schema::User> + 'ctx {
144 let receiver = match ctx.data::<EventBroadcaster>() {
147 Ok(broadcaster) => Some(broadcaster.subscribe().await),
148 Err(_) => None,
149 };
150
151 let stream = receiver.map(receiver_to_stream);
152 async_stream::stream! {
153 if let Some(stream) = stream {
154 use futures_util::StreamExt;
155 let mut stream = std::pin::pin!(stream);
156 while let Some(event) = stream.next().await {
157 if let UserEvent::Created(user) = event {
158 yield user;
159 }
160 }
161 }
162 }
163 }
164
165 async fn user_updated<'ctx>(
166 &self,
167 ctx: &Context<'ctx>,
168 ) -> impl Stream<Item = crate::schema::User> + 'ctx {
169 let receiver = match ctx.data::<EventBroadcaster>() {
170 Ok(broadcaster) => Some(broadcaster.subscribe().await),
171 Err(_) => None,
172 };
173
174 let stream = receiver.map(receiver_to_stream);
175 async_stream::stream! {
176 if let Some(stream) = stream {
177 use futures_util::StreamExt;
178 let mut stream = std::pin::pin!(stream);
179 while let Some(event) = stream.next().await {
180 if let UserEvent::Updated(user) = event {
181 yield user;
182 }
183 }
184 }
185 }
186 }
187
188 async fn user_deleted<'ctx>(&self, ctx: &Context<'ctx>) -> impl Stream<Item = ID> + 'ctx {
189 let receiver = match ctx.data::<EventBroadcaster>() {
190 Ok(broadcaster) => Some(broadcaster.subscribe().await),
191 Err(_) => None,
192 };
193
194 let stream = receiver.map(receiver_to_stream);
195 async_stream::stream! {
196 if let Some(stream) = stream {
197 use futures_util::StreamExt;
198 let mut stream = std::pin::pin!(stream);
199 while let Some(event) = stream.next().await {
200 if let UserEvent::Deleted(id) = event {
201 yield id;
202 }
203 }
204 }
205 }
206 }
207}
208
#[cfg(test)]
mod tests {
    use super::*;
    use rstest::rstest;

    /// Builds an active `User` fixture whose email is derived from the lower-cased name.
    fn make_test_user(id: &str, name: &str) -> crate::schema::User {
        crate::schema::User {
            id: ID::from(id),
            name: name.to_string(),
            email: format!("{}@example.com", name.to_lowercase()),
            active: true,
        }
    }

    #[rstest]
    #[tokio::test]
    async fn test_broadcaster_default_capacity() {
        let broadcaster = EventBroadcaster::new();

        assert_eq!(broadcaster.capacity(), DEFAULT_CHANNEL_CAPACITY);
    }

    #[rstest]
    #[tokio::test]
    async fn test_broadcaster_custom_capacity() {
        let broadcaster = EventBroadcaster::with_capacity(512);

        assert_eq!(broadcaster.capacity(), 512);
    }

    #[rstest]
    #[tokio::test]
    #[should_panic(expected = "Channel capacity must be greater than 0")]
    async fn test_broadcaster_zero_capacity_panics() {
        EventBroadcaster::with_capacity(0);
    }

    #[rstest]
    #[tokio::test]
    async fn test_broadcaster_send_receive() {
        let broadcaster = EventBroadcaster::new();
        let mut rx = broadcaster.subscribe().await;
        let user = make_test_user("1", "Test");

        let receiver_count = broadcaster.broadcast(UserEvent::Created(user)).await;

        assert_eq!(receiver_count, 1);
        let event = rx.recv().await.unwrap();
        match event {
            UserEvent::Created(u) => assert_eq!(u.name, "Test"),
            _ => panic!("Expected Created event"),
        }
    }

    #[rstest]
    #[tokio::test]
    async fn test_broadcaster_no_subscribers_returns_zero() {
        let broadcaster = EventBroadcaster::new();
        let user = make_test_user("no-sub", "NoSub");

        let receiver_count = broadcaster.broadcast(UserEvent::Created(user)).await;

        assert_eq!(receiver_count, 0);
    }

    #[rstest]
    #[tokio::test]
    async fn test_broadcaster_multiple_subscribers() {
        let broadcaster = EventBroadcaster::new();
        let mut rx1 = broadcaster.subscribe().await;
        let mut rx2 = broadcaster.subscribe().await;
        let mut rx3 = broadcaster.subscribe().await;
        let user = make_test_user("multi-sub-1", "MultiSub");

        let receiver_count = broadcaster.broadcast(UserEvent::Created(user)).await;

        assert_eq!(receiver_count, 3);

        let event1 = rx1.recv().await.unwrap();
        let event2 = rx2.recv().await.unwrap();
        let event3 = rx3.recv().await.unwrap();

        match event1 {
            UserEvent::Created(u) => assert_eq!(u.name, "MultiSub"),
            _ => panic!("Expected Created event in rx1"),
        }
        match event2 {
            UserEvent::Created(u) => assert_eq!(u.name, "MultiSub"),
            _ => panic!("Expected Created event in rx2"),
        }
        match event3 {
            UserEvent::Created(u) => assert_eq!(u.name, "MultiSub"),
            _ => panic!("Expected Created event in rx3"),
        }
    }

    #[rstest]
    #[tokio::test]
    async fn test_event_created() {
        let broadcaster = EventBroadcaster::new();
        let mut rx = broadcaster.subscribe().await;
        let user = make_test_user("created-test", "CreatedUser");

        broadcaster.broadcast(UserEvent::Created(user)).await;
        let event = rx.recv().await.unwrap();

        match event {
            UserEvent::Created(u) => {
                assert_eq!(u.id.to_string(), "created-test");
                assert_eq!(u.name, "CreatedUser");
                assert_eq!(u.email, "createduser@example.com");
                assert!(u.active);
            }
            _ => panic!("Expected Created event"),
        }
    }

    #[rstest]
    #[tokio::test]
    async fn test_event_updated() {
        let broadcaster = EventBroadcaster::new();
        let mut rx = broadcaster.subscribe().await;
        let mut user = make_test_user("updated-test", "UpdatedUser");
        user.active = false;

        broadcaster.broadcast(UserEvent::Updated(user)).await;
        let event = rx.recv().await.unwrap();

        match event {
            UserEvent::Updated(u) => {
                assert_eq!(u.id.to_string(), "updated-test");
                assert_eq!(u.name, "UpdatedUser");
                assert!(!u.active);
            }
            _ => panic!("Expected Updated event"),
        }
    }

    #[rstest]
    #[tokio::test]
    async fn test_event_deleted() {
        let broadcaster = EventBroadcaster::new();
        let mut rx = broadcaster.subscribe().await;
        let deleted_id = ID::from("deleted-test");

        broadcaster.broadcast(UserEvent::Deleted(deleted_id)).await;
        let event = rx.recv().await.unwrap();

        match event {
            UserEvent::Deleted(id) => {
                assert_eq!(id.to_string(), "deleted-test");
            }
            _ => panic!("Expected Deleted event"),
        }
    }

    #[rstest]
    #[tokio::test]
    async fn test_subscription_filtering() {
        let broadcaster = EventBroadcaster::new();
        let mut rx = broadcaster.subscribe().await;
        let user1 = make_test_user("filter-1", "Filter1");
        let mut user2 = make_test_user("filter-2", "Filter2");
        user2.active = false;

        broadcaster.broadcast(UserEvent::Created(user1)).await;
        broadcaster.broadcast(UserEvent::Updated(user2)).await;
        broadcaster
            .broadcast(UserEvent::Deleted(ID::from("filter-3")))
            .await;

        let event1 = rx.recv().await.unwrap();
        let event2 = rx.recv().await.unwrap();
        let event3 = rx.recv().await.unwrap();

        assert!(matches!(event1, UserEvent::Created(_)));
        assert!(matches!(event2, UserEvent::Updated(_)));
        assert!(matches!(event3, UserEvent::Deleted(_)));
    }

    #[rstest]
    #[tokio::test]
    async fn test_backpressure_lagged_consumer() {
        let broadcaster = EventBroadcaster::with_capacity(2);
        let mut rx = broadcaster.subscribe().await;

        for i in 0..5 {
            let user = make_test_user(&format!("bp-{}", i), &format!("User{}", i));
            broadcaster.broadcast(UserEvent::Created(user)).await;
        }

        // Five sends into a two-slot channel overwrite the three oldest events,
        // so the first recv must report exactly that many skipped messages.
        match rx.recv().await {
            Err(broadcast::error::RecvError::Lagged(skipped)) => assert_eq!(skipped, 3),
            Ok(_) => panic!("Expected the receiver to lag behind the sender"),
            Err(broadcast::error::RecvError::Closed) => {
                panic!("Channel should not be closed");
            }
        }

        // After the lag is reported, delivery resumes at the oldest retained event.
        match rx.recv().await {
            Ok(UserEvent::Created(u)) => assert_eq!(u.name, "User3"),
            other => panic!("Expected Created(User3) after lag, got {:?}", other),
        }
    }

    #[rstest]
    #[tokio::test]
    async fn test_receiver_to_stream_handles_lagged() {
        use futures_util::StreamExt;

        let broadcaster = EventBroadcaster::with_capacity(2);
        let rx = broadcaster.subscribe().await;
        let mut stream = std::pin::pin!(receiver_to_stream(rx));

        for i in 0..5 {
            let user = make_test_user(&format!("stream-{}", i), &format!("StreamUser{}", i));
            broadcaster.broadcast(UserEvent::Created(user)).await;
        }

        let event = tokio::time::timeout(std::time::Duration::from_secs(1), stream.next())
            .await
            .expect("Stream should produce an event after lagging")
            .expect("Stream should not be terminated");

        // The lag is skipped, and the stream resumes at the oldest event still
        // buffered, which is the fourth of the five sent.
        match event {
            UserEvent::Created(u) => assert_eq!(u.name, "StreamUser3"),
            other => panic!("Expected Created(StreamUser3) after lag, got {:?}", other),
        }
    }

    #[rstest]
    #[tokio::test]
    async fn test_receiver_to_stream_closed() {
        use futures_util::StreamExt;

        let broadcaster = EventBroadcaster::with_capacity(4);
        let rx = broadcaster.subscribe().await;
        let mut stream = std::pin::pin!(receiver_to_stream(rx));

        drop(broadcaster);

        let event = tokio::time::timeout(std::time::Duration::from_secs(1), stream.next()).await;
        assert!(
            event.is_ok(),
            "Stream should resolve when channel is closed"
        );
        assert!(
            event.unwrap().is_none(),
            "Stream should yield None when closed"
        );
    }

    #[rstest]
    #[tokio::test]
    async fn test_bounded_channel_respects_capacity() {
        let capacity = 4;
        let broadcaster = EventBroadcaster::with_capacity(capacity);
        let _rx = broadcaster.subscribe().await;

        for i in 0..capacity {
            let user = make_test_user(&format!("cap-{}", i), &format!("CapUser{}", i));
            broadcaster.broadcast(UserEvent::Created(user)).await;
        }

        let user = make_test_user("cap-overflow", "CapOverflow");
        let count = broadcaster.broadcast(UserEvent::Created(user)).await;
        assert_eq!(count, 1);
    }

    #[tokio::test]
    async fn test_subscription_missing_broadcaster_does_not_panic() {
        use async_graphql::{EmptyMutation, Schema};
        use tokio_stream::StreamExt;

        let schema = Schema::build(crate::schema::Query, EmptyMutation, SubscriptionRoot)
            .data(crate::schema::UserStorage::new())
            .finish();

        let query = r#"subscription { userCreated { id name } }"#;
        let mut stream = schema.execute_stream(query);

        let result =
            tokio::time::timeout(std::time::Duration::from_millis(100), stream.next()).await;

        // Without an EventBroadcaster in the schema data, the resolver returns an
        // empty stream. The subscription must therefore end cleanly: no hang
        // (timeout), no panic, and no delivered payload.
        assert!(
            matches!(result, Ok(None)),
            "subscription without a broadcaster should terminate without yielding, got {:?}",
            result
        );
    }
}