1use std::any::TypeId;
23use std::collections::HashMap;
24use std::marker::PhantomData;
25use std::sync::Arc;
26use std::sync::atomic::{AtomicBool, Ordering};
27
28use bevy::ecs::event::GlobalTrigger;
29use bevy::prelude::*;
30use serde::Serialize;
31use serde::de::DeserializeOwned;
32use ts_rs::TS;
33
34use crate::bridge::{OutboundResource, OutboundSender};
35use crate::protocol::{Outbound, ResponseResult};
36use crate::registry::{NamedEntry, register_entry};
37use crate::ts_codegen::TsCollector;
38
39pub struct RawRequest {
42 pub id: u64,
43 pub name: String,
44 pub value: serde_json::Value,
45}
46
/// A typed request that can be registered with the request registry.
///
/// Implementors are deserialized from the JSON payload of an incoming raw
/// request, and expose TypeScript type information through `TS`.
pub trait ReactRequest: DeserializeOwned + TS + Send + Sync + 'static {
    /// Name the request is registered under and that incoming raw requests
    /// are matched against.
    const NAME: &'static str;
    /// Value sent back on success; it is serialized to JSON when responding.
    type Response: Serialize + TS + Send + Sync + 'static;
}
61
/// Handle used to answer a single request; `R` is the type of a successful
/// reply.
///
/// Clones share one `done` flag, so at most one response is sent for a given
/// request regardless of how many handles exist.
pub struct Responder<R> {
    /// Id of the request being answered; copied into the outgoing response.
    id: u64,
    /// Sender the response is pushed onto.
    tx: OutboundSender,
    /// Set by the first answer (success or error); shared between clones.
    done: Arc<AtomicBool>,
    /// Ties the handle to `R` without storing an `R`. The `fn() -> R` form
    /// means the marker never makes `Send`/`Sync` depend on `R`.
    _marker: PhantomData<fn() -> R>,
}
73
74impl<R> Clone for Responder<R> {
75 fn clone(&self) -> Self {
76 Self {
77 id: self.id,
78 tx: self.tx.clone(),
79 done: self.done.clone(),
80 _marker: PhantomData,
81 }
82 }
83}
84
85impl<R: Serialize> Responder<R> {
86 fn new(id: u64, tx: OutboundSender) -> Self {
87 Self {
88 id,
89 tx,
90 done: Arc::new(AtomicBool::new(false)),
91 _marker: PhantomData,
92 }
93 }
94
95 pub fn respond(&self, value: R) {
97 if !self.claim() {
98 return;
99 }
100 let result = match serde_json::to_value(&value) {
101 Ok(value) => ResponseResult::Ok { value },
102 Err(e) => ResponseResult::Err {
103 message: format!("serialize response: {e}"),
104 },
105 };
106 let _ = self.tx.send(Outbound::Response {
107 id: self.id,
108 result,
109 });
110 }
111
112 pub fn respond_err(&self, message: impl Into<String>) {
114 if !self.claim() {
115 return;
116 }
117 let _ = self.tx.send(Outbound::Response {
118 id: self.id,
119 result: ResponseResult::Err {
120 message: message.into(),
121 },
122 });
123 }
124
125 fn claim(&self) -> bool {
128 if self.done.swap(true, Ordering::SeqCst) {
129 warn!(
130 "react request {} responded to more than once; ignoring",
131 self.id
132 );
133 false
134 } else {
135 true
136 }
137 }
138}
139
/// A decoded request together with the handle used to answer it.
///
/// This is the value triggered as an event when a registered request is
/// dispatched successfully.
pub struct Request<T: ReactRequest> {
    /// The deserialized request payload.
    payload: T,
    /// Handle for sending a `T::Response` (or an error) back.
    responder: Responder<T::Response>,
}
146
147impl<T: ReactRequest> Request<T> {
148 pub fn payload(&self) -> &T {
150 &self.payload
151 }
152
153 pub fn into_payload(self) -> T {
155 self.payload
156 }
157
158 pub fn responder(&self) -> Responder<T::Response> {
160 self.responder.clone()
161 }
162
163 pub fn respond(&self, value: T::Response) {
165 self.responder.respond(value);
166 }
167
168 pub fn respond_err(&self, message: impl Into<String>) {
170 self.responder.respond_err(message);
171 }
172}
173
174impl<T: ReactRequest> Event for Request<T> {
175 type Trigger<'a> = GlobalTrigger;
176}
177
/// Maps a request event type back to the request type it carries.
///
/// NOTE(review): presumably this lets generic code name `T` when it only has
/// the event type `Request<T>` in hand; confirm against the call sites.
pub trait RequestEvent {
    /// The request type carried by the event.
    type Req: ReactRequest;
}
184
185impl<T: ReactRequest> RequestEvent for Request<T> {
186 type Req = T;
187}
188
/// Type-erased handler stored per request name.
///
/// It receives the raw request, the sender to reply on, and the `Commands`
/// used to trigger the typed request event.
type RequestHandler = Box<dyn Fn(RawRequest, &OutboundSender, &mut Commands) + Send + Sync>;
191
/// Everything the registry stores for one registered request type.
pub(crate) struct RequestRegistration {
    /// `TypeId` of the registered request type; exposed through `NamedEntry`.
    type_id: TypeId,
    /// Decodes a raw request and either triggers the typed event or replies
    /// with an error.
    handler: RequestHandler,
    /// Returns the TypeScript name of the request type.
    pub(crate) ts_request_name: fn() -> String,
    /// Returns the TypeScript name of the response type.
    pub(crate) ts_response_name: fn() -> String,
    /// Returns `true` when the request type's inline TypeScript form is `null`.
    pub(crate) request_is_void: fn() -> bool,
    /// Adds the response type, and the request type unless it is void, to a
    /// `TsCollector`.
    pub(crate) ts_collect: fn(&mut TsCollector),
}
206
/// Resource holding every registered request type, keyed by request name.
#[derive(Resource, Default)]
pub(crate) struct ReactRequestRegistry {
    /// Registration for each request name.
    pub(crate) handlers: HashMap<&'static str, RequestRegistration>,
}
213
214impl NamedEntry for RequestRegistration {
215 fn type_id(&self) -> TypeId {
216 self.type_id
217 }
218}
219
220impl ReactRequestRegistry {
221 pub(crate) fn register<T: ReactRequest>(&mut self) {
224 register_entry(
225 &mut self.handlers,
226 T::NAME,
227 "request",
228 RequestRegistration {
229 type_id: TypeId::of::<T>(),
230 handler: Box::new(|raw, tx, commands| {
231 let responder = Responder::<T::Response>::new(raw.id, tx.clone());
233 match serde_json::from_value::<T>(raw.value) {
234 Ok(payload) => commands.trigger(Request { payload, responder }),
235 Err(e) => {
236 responder.respond_err(format!("malformed request {:?}: {e}", T::NAME))
237 }
238 }
239 }),
240 ts_request_name: <T as TS>::name,
241 ts_response_name: <T::Response as TS>::name,
242 request_is_void: || <T as TS>::inline() == "null",
244 ts_collect: |c| {
245 if <T as TS>::inline() != "null" {
248 c.add::<T>();
249 }
250 c.add::<T::Response>();
251 },
252 },
253 );
254 }
255
256 pub(crate) fn dispatch(&self, raw: RawRequest, tx: &OutboundSender, commands: &mut Commands) {
259 match self.handlers.get(raw.name.as_str()) {
260 Some(reg) => (reg.handler)(raw, tx, commands),
261 None => {
262 let _ = tx.send(Outbound::Response {
263 id: raw.id,
264 result: ResponseResult::Err {
265 message: format!("no handler registered for request {:?}", raw.name),
266 },
267 });
268 }
269 }
270 }
271}
272
/// Resource wrapping the receiving end of the channel that carries incoming
/// raw requests.
#[derive(Resource)]
pub(crate) struct RequestReceiver(pub(crate) crossbeam_channel::Receiver<RawRequest>);
276
277pub(crate) fn dispatch_react_requests(
281 rx: Res<RequestReceiver>,
282 registry: Res<ReactRequestRegistry>,
283 out: Res<OutboundResource>,
284 mut commands: Commands,
285) {
286 while let Ok(raw) = rx.0.try_recv() {
287 registry.dispatch(raw, &out.0, &mut commands);
288 }
289}
290
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ReactAppExt;
    use bevy::ecs::world::CommandQueue;
    use tokio::sync::mpsc::{UnboundedReceiver, unbounded_channel};

    /// Request fixture registered under the name `"ping"`.
    #[crate::react_request(name = "ping", response = Pong)]
    struct Ping {
        n: u32,
    }

    /// Response fixture for [`Ping`].
    #[derive(serde::Serialize, ts_rs::TS)]
    struct Pong {
        n: u32,
    }

    /// Creates the outbound channel pair the tests observe responses on.
    fn outbound_channel() -> (OutboundSender, UnboundedReceiver<Outbound>) {
        unbounded_channel()
    }

    /// Runs one raw request through the app's registry and applies the
    /// commands it queued, so triggered handlers run before this returns.
    fn dispatch(app: &mut App, tx: &OutboundSender, raw: RawRequest) {
        let world = app.world_mut();
        world.resource_scope(|world, registry: Mut<ReactRequestRegistry>| {
            let mut queue = CommandQueue::default();
            let mut commands = Commands::new(&mut queue, world);
            registry.dispatch(raw, tx, &mut commands);
            queue.apply(world);
        });
    }

    /// Builds a `RawRequest` from its three parts.
    fn raw(id: u64, name: &str, value: serde_json::Value) -> RawRequest {
        let name = name.into();
        RawRequest { id, name, value }
    }

    /// A well-formed request reaches its handler and the handler's reply
    /// comes back under the same id.
    #[test]
    fn dispatches_and_responds() {
        let mut app = App::new();
        app.add_react_request_handler(|req: On<Request<Ping>>| {
            let next = req.payload().n + 1;
            req.respond(Pong { n: next });
        });
        let (tx, mut rx) = outbound_channel();

        let request = raw(7, "ping", serde_json::json!({ "n": 41 }));
        dispatch(&mut app, &tx, request);

        match rx.try_recv() {
            Ok(Outbound::Response {
                id,
                result: ResponseResult::Ok { value },
            }) => {
                assert_eq!(id, 7);
                assert_eq!(value, serde_json::json!({ "n": 42 }));
            }
            other => panic!("expected Ok response, got {other:?}"),
        }
    }

    /// A request whose name has no registration is answered with an error.
    #[test]
    fn unknown_name_replies_err() {
        let mut app = App::new();
        app.init_resource::<ReactRequestRegistry>();
        let (tx, mut rx) = outbound_channel();

        let request = raw(1, "nope", serde_json::json!(null));
        dispatch(&mut app, &tx, request);

        let reply = rx.try_recv();
        assert!(matches!(
            reply,
            Ok(Outbound::Response {
                id: 1,
                result: ResponseResult::Err { .. },
            })
        ));
    }

    /// A payload that does not deserialize into the request type is answered
    /// with an error.
    #[test]
    fn malformed_payload_replies_err() {
        let mut app = App::new();
        app.add_react_request_handler(|req: On<Request<Ping>>| {
            req.respond(Pong { n: 0 });
        });
        let (tx, mut rx) = outbound_channel();

        let request = raw(2, "ping", serde_json::json!({ "n": "nope" }));
        dispatch(&mut app, &tx, request);

        let reply = rx.try_recv();
        assert!(matches!(
            reply,
            Ok(Outbound::Response {
                id: 2,
                result: ResponseResult::Err { .. },
            })
        ));
    }

    /// Only the first of two `respond` calls produces an outbound message.
    #[test]
    fn respond_twice_sends_once() {
        let mut app = App::new();
        app.add_react_request_handler(|req: On<Request<Ping>>| {
            req.respond(Pong { n: 1 });
            // The second answer must be dropped.
            req.respond(Pong { n: 2 });
        });
        let (tx, mut rx) = outbound_channel();

        let request = raw(3, "ping", serde_json::json!({ "n": 0 }));
        dispatch(&mut app, &tx, request);

        let first = rx.try_recv();
        assert!(matches!(
            first,
            Ok(Outbound::Response {
                result: ResponseResult::Ok { .. },
                ..
            })
        ));
        let second = rx.try_recv();
        assert!(second.is_err(), "second respond must not send");
    }
}