use std::any::TypeId;
use std::collections::HashMap;
use std::marker::PhantomData;
use std::sync::Arc;
use std::sync::atomic::{AtomicBool, Ordering};
use bevy::ecs::event::GlobalTrigger;
use bevy::prelude::*;
use serde::Serialize;
use serde::de::DeserializeOwned;
use ts_rs::TS;
use crate::bridge::{OutboundResource, OutboundSender};
use crate::protocol::{Outbound, ResponseResult};
use crate::registry::{NamedEntry, register_entry};
use crate::ts_codegen::TsCollector;
/// A request received over the request channel, before its JSON payload has
/// been decoded into a concrete [`ReactRequest`] type.
#[derive(Debug, Clone)]
pub struct RawRequest {
    /// Correlation id; echoed back as the `id` of the matching response.
    pub id: u64,
    /// Name used to look up the registered handler (compared against
    /// [`ReactRequest::NAME`]).
    pub name: String,
    /// Undecoded JSON payload; deserialized into the handler's request type
    /// during dispatch.
    pub value: serde_json::Value,
}
/// A typed request handled by the Bevy app.
///
/// Implementors are deserialized from the JSON payload of a [`RawRequest`]
/// and answered with a value of [`ReactRequest::Response`]; both sides also
/// carry `TS` so TypeScript bindings can be generated for them.
pub trait ReactRequest: DeserializeOwned + TS + Send + Sync + 'static {
    /// Value sent back to the caller on success (serialized to JSON).
    type Response: Serialize + TS + Send + Sync + 'static;

    /// Name this request type is registered and dispatched under.
    const NAME: &'static str;
}
pub struct Responder<R> {
id: u64,
tx: OutboundSender,
done: Arc<AtomicBool>,
_marker: PhantomData<fn() -> R>,
}
impl<R> Clone for Responder<R> {
fn clone(&self) -> Self {
Self {
id: self.id,
tx: self.tx.clone(),
done: self.done.clone(),
_marker: PhantomData,
}
}
}
impl<R: Serialize> Responder<R> {
fn new(id: u64, tx: OutboundSender) -> Self {
Self {
id,
tx,
done: Arc::new(AtomicBool::new(false)),
_marker: PhantomData,
}
}
pub fn respond(&self, value: R) {
if !self.claim() {
return;
}
let result = match serde_json::to_value(&value) {
Ok(value) => ResponseResult::Ok { value },
Err(e) => ResponseResult::Err {
message: format!("serialize response: {e}"),
},
};
let _ = self.tx.send(Outbound::Response {
id: self.id,
result,
});
}
pub fn respond_err(&self, message: impl Into<String>) {
if !self.claim() {
return;
}
let _ = self.tx.send(Outbound::Response {
id: self.id,
result: ResponseResult::Err {
message: message.into(),
},
});
}
fn claim(&self) -> bool {
if self.done.swap(true, Ordering::SeqCst) {
warn!(
"react request {} responded to more than once; ignoring",
self.id
);
false
} else {
true
}
}
}
/// A decoded request of type `T`, delivered to observers as a Bevy event.
///
/// Bundles the deserialized payload with the [`Responder`] used to answer it.
pub struct Request<T>
where
    T: ReactRequest,
{
    payload: T,
    responder: Responder<T::Response>,
}
impl<T: ReactRequest> Request<T> {
pub fn payload(&self) -> &T {
&self.payload
}
pub fn into_payload(self) -> T {
self.payload
}
pub fn responder(&self) -> Responder<T::Response> {
self.responder.clone()
}
pub fn respond(&self, value: T::Response) {
self.responder.respond(value);
}
pub fn respond_err(&self, message: impl Into<String>) {
self.responder.respond_err(message);
}
}
// Requests use Bevy's `GlobalTrigger`: they are triggered globally rather
// than targeted at a particular entity.
impl<T> Event for Request<T>
where
    T: ReactRequest,
{
    type Trigger<'a> = GlobalTrigger;
}
/// Maps an event type back to the [`ReactRequest`] it carries.
pub trait RequestEvent {
    /// The request type wrapped by the event.
    type Req: ReactRequest;
}

/// A [`Request<T>`] carries exactly the request type `T`.
impl<T> RequestEvent for Request<T>
where
    T: ReactRequest,
{
    type Req = T;
}
/// Type-erased per-request-type handler. It receives the raw request, the
/// outbound sender used for replies, and the `Commands` through which the
/// typed event is triggered.
type RequestHandler = Box<dyn Fn(RawRequest, &OutboundSender, &mut Commands) + Send + Sync>;
/// Everything recorded when a request type is registered: its dispatch
/// handler plus the hooks used by TypeScript code generation.
pub(crate) struct RequestRegistration {
    /// `TypeId` of the registered request type (reported via `NamedEntry`).
    type_id: TypeId,
    /// Decodes a `RawRequest` of this type and dispatches it.
    handler: RequestHandler,
    /// Returns the TypeScript name of the request type.
    pub(crate) ts_request_name: fn() -> String,
    /// Returns the TypeScript name of the response type.
    pub(crate) ts_response_name: fn() -> String,
    /// Whether the request type's inline TypeScript form is `null`
    /// (presumably a payload-less request).
    pub(crate) request_is_void: fn() -> bool,
    /// Adds the request type (unless void) and the response type to the
    /// collector.
    pub(crate) ts_collect: fn(&mut TsCollector),
}
/// Resource holding every registered request type, keyed by its
/// `ReactRequest::NAME`.
#[derive(Resource, Default)]
pub(crate) struct ReactRequestRegistry {
    /// Registrations keyed by request name; looked up by `dispatch`.
    pub(crate) handlers: HashMap<&'static str, RequestRegistration>,
}
impl NamedEntry for RequestRegistration {
    /// Reports the request type this registration was created for.
    fn type_id(&self) -> TypeId {
        let Self { type_id, .. } = self;
        *type_id
    }
}
impl ReactRequestRegistry {
pub(crate) fn register<T: ReactRequest>(&mut self) {
register_entry(
&mut self.handlers,
T::NAME,
"request",
RequestRegistration {
type_id: TypeId::of::<T>(),
handler: Box::new(|raw, tx, commands| {
let responder = Responder::<T::Response>::new(raw.id, tx.clone());
match serde_json::from_value::<T>(raw.value) {
Ok(payload) => commands.trigger(Request { payload, responder }),
Err(e) => {
responder.respond_err(format!("malformed request {:?}: {e}", T::NAME))
}
}
}),
ts_request_name: <T as TS>::name,
ts_response_name: <T::Response as TS>::name,
request_is_void: || <T as TS>::inline() == "null",
ts_collect: |c| {
if <T as TS>::inline() != "null" {
c.add::<T>();
}
c.add::<T::Response>();
},
},
);
}
pub(crate) fn dispatch(&self, raw: RawRequest, tx: &OutboundSender, commands: &mut Commands) {
match self.handlers.get(raw.name.as_str()) {
Some(reg) => (reg.handler)(raw, tx, commands),
None => {
let _ = tx.send(Outbound::Response {
id: raw.id,
result: ResponseResult::Err {
message: format!("no handler registered for request {:?}", raw.name),
},
});
}
}
}
}
/// Receiving end of the channel that carries incoming [`RawRequest`]s into
/// the app.
#[derive(Resource)]
pub(crate) struct RequestReceiver(pub(crate) crossbeam_channel::Receiver<RawRequest>);

/// System that drains every request currently queued on the channel,
/// without blocking, and hands each one to the [`ReactRequestRegistry`].
pub(crate) fn dispatch_react_requests(
    rx: Res<RequestReceiver>,
    registry: Res<ReactRequestRegistry>,
    out: Res<OutboundResource>,
    mut commands: Commands,
) {
    let sender = &out.0;
    for raw in rx.0.try_iter() {
        registry.dispatch(raw, sender, &mut commands);
    }
}
#[cfg(test)]
mod tests {
    // Round-trip tests for registering, dispatching and answering requests.
    use super::*;
    use crate::ReactAppExt;
    use bevy::ecs::world::CommandQueue;
    use tokio::sync::mpsc::{UnboundedReceiver, unbounded_channel};

    #[crate::react_request(name = "ping", response = Pong)]
    struct Ping {
        n: u32,
    }

    #[derive(serde::Serialize, ts_rs::TS)]
    struct Pong {
        n: u32,
    }

    /// Creates the outbound channel a test dispatches into and reads replies from.
    fn outbound_channel() -> (OutboundSender, UnboundedReceiver<Outbound>) {
        unbounded_channel()
    }

    /// Dispatches `raw` through the app's registry, then applies the queued
    /// commands so any triggered observers run.
    fn dispatch(app: &mut App, tx: &OutboundSender, raw: RawRequest) {
        app.world_mut()
            .resource_scope(|world, registry: Mut<ReactRequestRegistry>| {
                let mut queue = CommandQueue::default();
                {
                    let mut commands = Commands::new(&mut queue, world);
                    registry.dispatch(raw, tx, &mut commands);
                }
                queue.apply(world);
            });
    }

    /// Builds a [`RawRequest`] from its parts.
    fn raw(id: u64, name: &str, value: serde_json::Value) -> RawRequest {
        RawRequest {
            id,
            name: name.to_owned(),
            value,
        }
    }

    /// Asserts that the next outbound message is an error response to `expected_id`.
    fn assert_err_response(rx: &mut UnboundedReceiver<Outbound>, expected_id: u64) {
        match rx.try_recv() {
            Ok(Outbound::Response {
                id,
                result: ResponseResult::Err { .. },
            }) => assert_eq!(id, expected_id),
            other => panic!("expected Err response for request {expected_id}, got {other:?}"),
        }
    }

    #[test]
    fn dispatches_and_responds() {
        let mut app = App::new();
        app.add_react_request_handler(|req: On<Request<Ping>>| {
            req.respond(Pong {
                n: req.payload().n + 1,
            });
        });
        let (tx, mut rx) = outbound_channel();

        dispatch(&mut app, &tx, raw(7, "ping", serde_json::json!({ "n": 41 })));

        match rx.try_recv() {
            Ok(Outbound::Response {
                id,
                result: ResponseResult::Ok { value },
            }) => {
                assert_eq!(id, 7);
                assert_eq!(value, serde_json::json!({ "n": 42 }));
            }
            other => panic!("expected Ok response, got {other:?}"),
        }
    }

    #[test]
    fn unknown_name_replies_err() {
        let mut app = App::new();
        app.init_resource::<ReactRequestRegistry>();
        let (tx, mut rx) = outbound_channel();

        dispatch(&mut app, &tx, raw(1, "nope", serde_json::json!(null)));

        assert_err_response(&mut rx, 1);
    }

    #[test]
    fn malformed_payload_replies_err() {
        let mut app = App::new();
        app.add_react_request_handler(|req: On<Request<Ping>>| req.respond(Pong { n: 0 }));
        let (tx, mut rx) = outbound_channel();

        // `n` must be a number, so decoding fails before the handler runs.
        dispatch(&mut app, &tx, raw(2, "ping", serde_json::json!({ "n": "nope" })));

        assert_err_response(&mut rx, 2);
    }

    #[test]
    fn respond_twice_sends_once() {
        let mut app = App::new();
        app.add_react_request_handler(|req: On<Request<Ping>>| {
            req.respond(Pong { n: 1 });
            // Must be swallowed by the responder's one-shot guard.
            req.respond(Pong { n: 2 });
        });
        let (tx, mut rx) = outbound_channel();

        dispatch(&mut app, &tx, raw(3, "ping", serde_json::json!({ "n": 0 })));

        assert!(matches!(
            rx.try_recv(),
            Ok(Outbound::Response {
                result: ResponseResult::Ok { .. },
                ..
            })
        ));
        assert!(rx.try_recv().is_err(), "second respond must not send");
    }
}