sphereql_graphql/
subscription.rs1use async_graphql::futures_util::Stream;
2use async_graphql::{Context, Result, Subscription};
3use tokio::sync::broadcast;
4
5use sphereql_core::{Contains, SphericalPoint};
6
7use crate::types::{RegionInput, SphericalPointOutput};
8
9#[derive(async_graphql::Enum, Copy, Clone, Eq, PartialEq, Debug)]
10pub enum SpatialEventType {
11 Entered,
12 Left,
13 Moved,
14}
15
16#[derive(async_graphql::SimpleObject, Debug, Clone)]
17pub struct SpatialEvent {
18 pub event_type: SpatialEventType,
19 pub point: SphericalPointOutput,
20 pub item_id: String,
21 #[graphql(skip)]
26 pub(crate) core_point: SphericalPoint,
27}
28
29pub struct SpatialEventBus {
30 sender: broadcast::Sender<SpatialEvent>,
31}
32
33impl SpatialEvent {
34 pub fn new(event_type: SpatialEventType, point: SphericalPointOutput, item_id: String) -> Self {
39 let core_point = SphericalPoint::new_unchecked(point.r, point.theta, point.phi);
40 Self {
41 event_type,
42 point,
43 item_id,
44 core_point,
45 }
46 }
47}
48
49impl SpatialEventBus {
50 pub fn new(capacity: usize) -> Self {
51 let (sender, _) = broadcast::channel(capacity);
52 Self { sender }
53 }
54
55 pub fn publish(&self, event: SpatialEvent) {
56 if let Err(e) = self.sender.send(event) {
57 tracing::trace!(error = %e, "SpatialEventBus::publish: no subscribers");
61 }
62 }
63
64 pub fn subscribe(&self) -> broadcast::Receiver<SpatialEvent> {
65 self.sender.subscribe()
66 }
67}
68
69pub struct SphericalSubscriptionRoot;
70
71#[Subscription]
72impl SphericalSubscriptionRoot {
73 async fn item_entered_region(
74 &self,
75 ctx: &Context<'_>,
76 region: RegionInput,
77 ) -> Result<impl Stream<Item = SpatialEvent>> {
78 let bus = ctx.data::<SpatialEventBus>()?;
79 let mut rx = bus.subscribe();
80 let region = region.to_core()?;
81
82 let stream = async_graphql::async_stream::stream! {
83 loop {
84 match rx.recv().await {
85 Ok(event) => {
86 if event.event_type == SpatialEventType::Entered
87 && region.contains(&event.core_point)
88 {
89 yield event;
90 }
91 }
92 Err(broadcast::error::RecvError::Lagged(_)) => continue,
93 Err(broadcast::error::RecvError::Closed) => break,
94 }
95 }
96 };
97
98 Ok(stream)
99 }
100
101 async fn item_left_region(
102 &self,
103 ctx: &Context<'_>,
104 region: RegionInput,
105 ) -> Result<impl Stream<Item = SpatialEvent>> {
106 let bus = ctx.data::<SpatialEventBus>()?;
107 let mut rx = bus.subscribe();
108 let region = region.to_core()?;
109
110 let stream = async_graphql::async_stream::stream! {
111 loop {
112 match rx.recv().await {
113 Ok(event) => {
114 if event.event_type == SpatialEventType::Left
115 && region.contains(&event.core_point)
116 {
117 yield event;
118 }
119 }
120 Err(broadcast::error::RecvError::Lagged(_)) => continue,
121 Err(broadcast::error::RecvError::Closed) => break,
122 }
123 }
124 };
125
126 Ok(stream)
127 }
128
129 async fn spatial_events(&self, ctx: &Context<'_>) -> Result<impl Stream<Item = SpatialEvent>> {
130 let bus = ctx.data::<SpatialEventBus>()?;
131 let mut rx = bus.subscribe();
132
133 let stream = async_graphql::async_stream::stream! {
134 loop {
135 match rx.recv().await {
136 Ok(event) => { yield event; }
137 Err(broadcast::error::RecvError::Lagged(_)) => continue,
138 Err(broadcast::error::RecvError::Closed) => break,
139 }
140 }
141 };
142
143 Ok(stream)
144 }
145}
146
#[cfg(test)]
mod tests {
    use super::*;
    use std::f64::consts::FRAC_PI_4;
    use tokio::sync::broadcast::error::{RecvError, TryRecvError};

    /// Builds an event at `(r, theta, phi)` whose `item_id` encodes the coordinates.
    fn make_event(event_type: SpatialEventType, r: f64, theta: f64, phi: f64) -> SpatialEvent {
        SpatialEvent::new(
            event_type,
            SphericalPointOutput {
                r,
                theta,
                phi,
                theta_degrees: theta.to_degrees(),
                phi_degrees: phi.to_degrees(),
            },
            format!("item-{r}-{theta}-{phi}"),
        )
    }

    #[tokio::test]
    async fn event_bus_publish_subscribe() {
        let bus = SpatialEventBus::new(16);
        let mut rx = bus.subscribe();

        bus.publish(make_event(SpatialEventType::Entered, 1.0, 0.5, FRAC_PI_4));

        let received = rx.recv().await.unwrap();
        assert_eq!(received.item_id, "item-1-0.5-0.7853981633974483");
        assert_eq!(received.event_type, SpatialEventType::Entered);
        assert!((received.point.r - 1.0).abs() < 1e-12);
    }

    #[tokio::test]
    async fn multiple_subscribers_receive_events() {
        let bus = SpatialEventBus::new(16);
        let mut rx1 = bus.subscribe();
        let mut rx2 = bus.subscribe();

        bus.publish(make_event(SpatialEventType::Moved, 2.0, 1.0, 0.5));

        let r1 = rx1.recv().await.unwrap();
        let r2 = rx2.recv().await.unwrap();

        assert_eq!(r1.item_id, r2.item_id);
        assert_eq!(r1.event_type, SpatialEventType::Moved);
        assert_eq!(r2.event_type, SpatialEventType::Moved);
    }

    #[tokio::test]
    async fn event_type_filtering() {
        let bus = SpatialEventBus::new(16);
        let mut rx = bus.subscribe();

        bus.publish(make_event(SpatialEventType::Entered, 1.0, 0.5, 0.5));
        bus.publish(make_event(SpatialEventType::Left, 1.0, 0.6, 0.6));
        bus.publish(make_event(SpatialEventType::Moved, 1.0, 0.7, 0.7));
        bus.publish(make_event(SpatialEventType::Entered, 2.0, 0.8, 0.8));

        let mut entered = Vec::new();
        for _ in 0..4 {
            let event = rx.recv().await.unwrap();
            if event.event_type == SpatialEventType::Entered {
                entered.push(event);
            }
        }

        assert_eq!(entered.len(), 2);
        assert!((entered[0].point.r - 1.0).abs() < 1e-12);
        assert!((entered[1].point.r - 2.0).abs() < 1e-12);
    }

    #[test]
    fn publish_without_subscribers_is_dropped() {
        let bus = SpatialEventBus::new(4);
        // No receiver exists yet: the send fails internally and must not panic.
        bus.publish(make_event(SpatialEventType::Left, 1.0, 0.1, 0.2));

        // A receiver created afterwards only sees events published after it subscribed.
        let mut rx = bus.subscribe();
        assert!(matches!(rx.try_recv(), Err(TryRecvError::Empty)));
    }

    #[tokio::test]
    async fn lagged_receiver_skips_overwritten_events() {
        let bus = SpatialEventBus::new(2);
        let mut rx = bus.subscribe();
        for r in [1.0, 2.0, 3.0, 4.0] {
            bus.publish(make_event(SpatialEventType::Moved, r, 0.1, 0.1));
        }

        // Capacity 2 retains only the last two events; the first two are reported as lost,
        // which is the condition the subscription streams log and then skip past.
        match rx.recv().await {
            Err(RecvError::Lagged(skipped)) => assert_eq!(skipped, 2),
            other => panic!("expected Lagged, got {other:?}"),
        }
        let next = rx.recv().await.unwrap();
        assert!((next.point.r - 3.0).abs() < 1e-12);
    }
}