// bevy_remote_stream/websocket.rs

1use std::net::{IpAddr, Ipv4Addr, TcpListener, TcpStream};
2
3use bevy::{
4    prelude::*,
5    remote::{error_codes, BrpError, BrpRequest, BrpResponse},
6    tasks::IoTaskPool,
7};
8use futures_util::{
9    stream::{SplitSink, SplitStream},
10    SinkExt, StreamExt,
11};
12use http_body_util::Full;
13use hyper::{
14    body::{Bytes, Incoming},
15    header::{
16        HeaderValue, ACCESS_CONTROL_ALLOW_METHODS, ACCESS_CONTROL_ALLOW_ORIGIN,
17        ACCESS_CONTROL_MAX_AGE, ORIGIN,
18    },
19    server::conn::http1,
20    service, Method, Request, Response,
21};
22use hyper_tungstenite::{HyperWebsocket, HyperWebsocketStream};
23use serde_json::Value;
24use smol::{
25    channel::{self, Receiver, Sender},
26    Async,
27};
28use smol_hyper::rt::{FuturesIo, SmolTimer};
29use tungstenite::Message;
30
31use crate::{BrpStreamMessage, StreamClientId, StreamMessage, StreamMessageKind, StreamSender};
32
/// Port the WebSocket server listens on unless configured otherwise.
pub const DEFAULT_PORT: u16 = 3_000;

/// Address the WebSocket server binds to unless configured otherwise (IPv4 loopback,
/// i.e. `127.0.0.1`, so only local clients can connect by default).
pub const DEFAULT_ADDR: IpAddr = IpAddr::V4(Ipv4Addr::LOCALHOST);

/// Add this plugin to your [`App`] to serve BRP streaming requests over a WebSocket server.
///
/// A client opens a WebSocket connection and passes its JSON-RPC request, URL-encoded, in the
/// `body` query parameter of the upgrade request. Responses are written back on the same
/// socket as JSON text frames, and JSON text frames sent by the client are forwarded to the app.
///
/// This plugin does not insert the [`StreamSender`] resource itself; it must already exist
/// when the server is started during [`Startup`].
///
/// The defaults are:
/// - [`DEFAULT_ADDR`] : 127.0.0.1.
/// - [`DEFAULT_PORT`] : 3000.
///
/// Use [`RemoteStreamWebSocketPlugin::with_address`] and
/// [`RemoteStreamWebSocketPlugin::with_port`] to override them.
#[derive(Debug, Clone)]
pub struct RemoteStreamWebSocketPlugin {
    /// The address that the WebSocket server will bind to.
    address: IpAddr,

    /// The port that the WebSocket server will listen on.
    port: u16,
}
51
52impl RemoteStreamWebSocketPlugin {
53    /// Set the IP address that the server will use.
54    #[must_use]
55    pub fn with_address(mut self, address: impl Into<IpAddr>) -> Self {
56        self.address = address.into();
57        self
58    }
59
60    /// Set the remote port that the server will listen on.
61    #[must_use]
62    pub fn with_port(mut self, port: u16) -> Self {
63        self.port = port;
64        self
65    }
66}
67
68impl Default for RemoteStreamWebSocketPlugin {
69    fn default() -> Self {
70        Self {
71            address: DEFAULT_ADDR,
72            port: DEFAULT_PORT,
73        }
74    }
75}
76
77impl Plugin for RemoteStreamWebSocketPlugin {
78    fn build(&self, app: &mut App) {
79        app.insert_resource(HostAddress(self.address))
80            .insert_resource(HostPort(self.port))
81            .add_systems(Startup, start_server);
82    }
83}
84
85#[derive(Debug, Resource)]
86pub struct HostAddress(pub IpAddr);
87
88#[derive(Debug, Resource, Reflect)]
89pub struct HostPort(pub u16);
90
91fn start_server(sender: Res<StreamSender>, address: Res<HostAddress>, remote_port: Res<HostPort>) {
92    IoTaskPool::get()
93        .spawn(server_main(address.0, remote_port.0, sender.clone()))
94        .detach();
95}
96
97struct TcpClient {
98    id: StreamClientId,
99    stream: Async<TcpStream>,
100}
101
102async fn server_main(
103    address: IpAddr,
104    port: u16,
105    request_sender: Sender<StreamMessage>,
106) -> anyhow::Result<()> {
107    let listener = Async::<TcpListener>::bind((address, port))?;
108    let mut client_id: usize = 0;
109    loop {
110        let (stream, _) = listener.accept().await?;
111        client_id = client_id.wrapping_add(1);
112        let client = TcpClient {
113            id: StreamClientId(client_id),
114            stream,
115        };
116        let request_sender = request_sender.clone();
117        IoTaskPool::get()
118            .spawn(async move {
119                let _ = handle_client(client, request_sender).await;
120            })
121            .detach();
122    }
123}
124
125async fn handle_client(
126    client: TcpClient,
127    request_sender: Sender<StreamMessage>,
128) -> anyhow::Result<()> {
129    http1::Builder::new()
130        .keep_alive(true)
131        .timer(SmolTimer::new())
132        .serve_connection(
133            FuturesIo::new(client.stream),
134            service::service_fn(|request| process_request(request, &request_sender, client.id)),
135        )
136        .with_upgrades()
137        .await?;
138
139    Ok(())
140}
141
142async fn process_request(
143    mut request: Request<Incoming>,
144    request_sender: &Sender<StreamMessage>,
145    client_id: StreamClientId,
146) -> anyhow::Result<Response<Full<Bytes>>> {
147    let default_origin = HeaderValue::from_static("");
148    let origin = request.headers().get(ORIGIN).unwrap_or(&default_origin);
149
150    if request.method() == Method::OPTIONS {
151        let response = Response::builder()
152            .status(200)
153            .header(ACCESS_CONTROL_ALLOW_METHODS, "*")
154            .header(ACCESS_CONTROL_ALLOW_ORIGIN, origin)
155            .header(ACCESS_CONTROL_MAX_AGE, "86400")
156            .body(Full::new(Bytes::new()))?;
157
158        return Ok(response);
159    }
160
161    if hyper_tungstenite::is_upgrade_request(&request) {
162        let (response, websocket) = hyper_tungstenite::upgrade(&mut request, None)?;
163
164        let body = match validate_websocket_request(&request) {
165            Ok(body) => body,
166            Err(err) => {
167                let response = serde_json::to_string(&BrpError {
168                    code: error_codes::INVALID_REQUEST,
169                    message: format!("{err}"),
170                    data: None,
171                })?;
172
173                return Ok(Response::new(Full::new(response.into_bytes().into())));
174            }
175        };
176
177        IoTaskPool::get()
178            .spawn(process_websocket_stream(
179                websocket,
180                request_sender.clone(),
181                body,
182                client_id,
183            ))
184            .detach();
185
186        return Ok(response);
187    }
188
189    let response_body = serde_json::to_string(&BrpError {
190        code: error_codes::INVALID_REQUEST,
191        message: "Invalid request".into(),
192        data: None,
193    })?;
194
195    let response = Response::builder()
196        .status(400)
197        .header(ACCESS_CONTROL_ALLOW_ORIGIN, origin)
198        .body(Full::new(response_body.into_bytes().into()))
199        .unwrap();
200
201    return Ok(response);
202}
203
204async fn process_websocket_stream(
205    ws: HyperWebsocket,
206    request_sender: Sender<StreamMessage>,
207    request: BrpRequest,
208    client_id: StreamClientId,
209) -> anyhow::Result<()> {
210    let ws = ws.await?;
211
212    let (write_stream, read_stream) = ws.split();
213
214    let (result_sender, result_receiver) = channel::bounded(32);
215
216    IoTaskPool::get()
217        .spawn(send_stream_response(write_stream, result_receiver))
218        .detach();
219
220    send_stream_message(
221        read_stream,
222        request_sender.clone(),
223        request,
224        result_sender,
225        client_id,
226    )
227    .await?;
228
229    Ok(())
230}
231
232const QUERY_KEY: &str = "body";
233
234fn validate_websocket_request(request: &Request<Incoming>) -> anyhow::Result<BrpRequest> {
235    let body = request
236        .uri()
237        .query()
238        .and_then(|query| {
239            // Simple query string parsing
240            for pair in query.split('&') {
241                let mut it = pair.split('=').take(2);
242                match (it.next(), it.next()) {
243                    (Some(k), Some(v)) if k == QUERY_KEY => return Some(v),
244                    _ => {}
245                };
246            }
247            None
248        })
249        .ok_or_else(|| anyhow::anyhow!("Missing body"))?;
250
251    let body = urlencoding::decode(body)?.into_owned();
252
253    match serde_json::from_str::<BrpRequest>(&body) {
254        Ok(req) => {
255            if req.jsonrpc != "2.0" {
256                anyhow::bail!("JSON-RPC request requires `\"jsonrpc\": \"2.0\"`")
257            }
258
259            Ok(req)
260        }
261        Err(err) => anyhow::bail!(err),
262    }
263}
264
265async fn send_stream_message(
266    mut stream: SplitStream<HyperWebsocketStream>,
267    sender: Sender<StreamMessage>,
268    request: BrpRequest,
269    result_sender: Sender<BrpResponse>,
270    client_id: StreamClientId,
271) -> anyhow::Result<()> {
272    let _ = sender
273        .send(StreamMessage {
274            client_id,
275            kind: StreamMessageKind::Connect(
276                request.id,
277                BrpStreamMessage {
278                    method: request.method,
279                    params: request.params,
280                    sender: result_sender,
281                },
282            ),
283        })
284        .await?;
285    while let Some(message) = stream.next().await {
286        match message {
287            Ok(Message::Text(text)) => {
288                let msg = serde_json::from_str::<Value>(&text)?;
289                let _ = sender
290                    .send(StreamMessage {
291                        client_id,
292                        kind: StreamMessageKind::Data(msg),
293                    })
294                    .await?;
295            }
296            Ok(Message::Close(_)) | Err(_) => return Ok(()),
297            _ => {}
298        }
299    }
300    let _ = sender
301        .send(StreamMessage {
302            client_id,
303            kind: StreamMessageKind::Disconnect,
304        })
305        .await?;
306
307    Ok(())
308}
309
310async fn send_stream_response(
311    mut stream: SplitSink<HyperWebsocketStream, Message>,
312    result_receiver: Receiver<BrpResponse>,
313) -> anyhow::Result<()> {
314    while let Ok(response) = result_receiver.recv().await {
315        let response = serde_json::to_string(&response)?;
316        stream.send(Message::text(response)).await?;
317    }
318
319    Ok(())
320}