1use std::net::{IpAddr, Ipv4Addr, TcpListener, TcpStream};
2
3use bevy::{
4 prelude::*,
5 remote::{error_codes, BrpError, BrpRequest, BrpResponse},
6 tasks::IoTaskPool,
7};
8use futures_util::{
9 stream::{SplitSink, SplitStream},
10 SinkExt, StreamExt,
11};
12use http_body_util::Full;
13use hyper::{
14 body::{Bytes, Incoming},
15 header::{
16 HeaderValue, ACCESS_CONTROL_ALLOW_METHODS, ACCESS_CONTROL_ALLOW_ORIGIN,
17 ACCESS_CONTROL_MAX_AGE, ORIGIN,
18 },
19 server::conn::http1,
20 service, Method, Request, Response,
21};
22use hyper_tungstenite::{HyperWebsocket, HyperWebsocketStream};
23use serde_json::Value;
24use smol::{
25 channel::{self, Receiver, Sender},
26 Async,
27};
28use smol_hyper::rt::{FuturesIo, SmolTimer};
29use tungstenite::Message;
30
31use crate::{BrpStreamMessage, StreamClientId, StreamMessage, StreamMessageKind, StreamSender};
32
/// Port the WebSocket server listens on unless overridden with
/// [`RemoteStreamWebSocketPlugin::with_port`].
pub const DEFAULT_PORT: u16 = 3000;

/// Address the WebSocket server binds to unless overridden with
/// [`RemoteStreamWebSocketPlugin::with_address`]. This is the IPv4 loopback address,
/// so by default the server only accepts connections from the local machine.
pub const DEFAULT_ADDR: IpAddr = IpAddr::V4(Ipv4Addr::LOCALHOST);
/// Plugin that exposes BRP streams over a WebSocket server.
///
/// Construct it with [`Default`] (binds [`DEFAULT_ADDR`]:[`DEFAULT_PORT`]) and adjust
/// it with [`RemoteStreamWebSocketPlugin::with_address`] /
/// [`RemoteStreamWebSocketPlugin::with_port`].
#[derive(Debug, Clone)]
pub struct RemoteStreamWebSocketPlugin {
    /// The IP address the server binds to.
    address: IpAddr,

    /// The TCP port the server listens on.
    port: u16,
}
51
52impl RemoteStreamWebSocketPlugin {
53 #[must_use]
55 pub fn with_address(mut self, address: impl Into<IpAddr>) -> Self {
56 self.address = address.into();
57 self
58 }
59
60 #[must_use]
62 pub fn with_port(mut self, port: u16) -> Self {
63 self.port = port;
64 self
65 }
66}
67
68impl Default for RemoteStreamWebSocketPlugin {
69 fn default() -> Self {
70 Self {
71 address: DEFAULT_ADDR,
72 port: DEFAULT_PORT,
73 }
74 }
75}
76
77impl Plugin for RemoteStreamWebSocketPlugin {
78 fn build(&self, app: &mut App) {
79 app.insert_resource(HostAddress(self.address))
80 .insert_resource(HostPort(self.port))
81 .add_systems(Startup, start_server);
82 }
83}
84
85#[derive(Debug, Resource)]
86pub struct HostAddress(pub IpAddr);
87
88#[derive(Debug, Resource, Reflect)]
89pub struct HostPort(pub u16);
90
91fn start_server(sender: Res<StreamSender>, address: Res<HostAddress>, remote_port: Res<HostPort>) {
92 IoTaskPool::get()
93 .spawn(server_main(address.0, remote_port.0, sender.clone()))
94 .detach();
95}
96
97struct TcpClient {
98 id: StreamClientId,
99 stream: Async<TcpStream>,
100}
101
102async fn server_main(
103 address: IpAddr,
104 port: u16,
105 request_sender: Sender<StreamMessage>,
106) -> anyhow::Result<()> {
107 let listener = Async::<TcpListener>::bind((address, port))?;
108 let mut client_id: usize = 0;
109 loop {
110 let (stream, _) = listener.accept().await?;
111 client_id = client_id.wrapping_add(1);
112 let client = TcpClient {
113 id: StreamClientId(client_id),
114 stream,
115 };
116 let request_sender = request_sender.clone();
117 IoTaskPool::get()
118 .spawn(async move {
119 let _ = handle_client(client, request_sender).await;
120 })
121 .detach();
122 }
123}
124
125async fn handle_client(
126 client: TcpClient,
127 request_sender: Sender<StreamMessage>,
128) -> anyhow::Result<()> {
129 http1::Builder::new()
130 .keep_alive(true)
131 .timer(SmolTimer::new())
132 .serve_connection(
133 FuturesIo::new(client.stream),
134 service::service_fn(|request| process_request(request, &request_sender, client.id)),
135 )
136 .with_upgrades()
137 .await?;
138
139 Ok(())
140}
141
142async fn process_request(
143 mut request: Request<Incoming>,
144 request_sender: &Sender<StreamMessage>,
145 client_id: StreamClientId,
146) -> anyhow::Result<Response<Full<Bytes>>> {
147 let default_origin = HeaderValue::from_static("");
148 let origin = request.headers().get(ORIGIN).unwrap_or(&default_origin);
149
150 if request.method() == Method::OPTIONS {
151 let response = Response::builder()
152 .status(200)
153 .header(ACCESS_CONTROL_ALLOW_METHODS, "*")
154 .header(ACCESS_CONTROL_ALLOW_ORIGIN, origin)
155 .header(ACCESS_CONTROL_MAX_AGE, "86400")
156 .body(Full::new(Bytes::new()))?;
157
158 return Ok(response);
159 }
160
161 if hyper_tungstenite::is_upgrade_request(&request) {
162 let (response, websocket) = hyper_tungstenite::upgrade(&mut request, None)?;
163
164 let body = match validate_websocket_request(&request) {
165 Ok(body) => body,
166 Err(err) => {
167 let response = serde_json::to_string(&BrpError {
168 code: error_codes::INVALID_REQUEST,
169 message: format!("{err}"),
170 data: None,
171 })?;
172
173 return Ok(Response::new(Full::new(response.into_bytes().into())));
174 }
175 };
176
177 IoTaskPool::get()
178 .spawn(process_websocket_stream(
179 websocket,
180 request_sender.clone(),
181 body,
182 client_id,
183 ))
184 .detach();
185
186 return Ok(response);
187 }
188
189 let response_body = serde_json::to_string(&BrpError {
190 code: error_codes::INVALID_REQUEST,
191 message: "Invalid request".into(),
192 data: None,
193 })?;
194
195 let response = Response::builder()
196 .status(400)
197 .header(ACCESS_CONTROL_ALLOW_ORIGIN, origin)
198 .body(Full::new(response_body.into_bytes().into()))
199 .unwrap();
200
201 return Ok(response);
202}
203
204async fn process_websocket_stream(
205 ws: HyperWebsocket,
206 request_sender: Sender<StreamMessage>,
207 request: BrpRequest,
208 client_id: StreamClientId,
209) -> anyhow::Result<()> {
210 let ws = ws.await?;
211
212 let (write_stream, read_stream) = ws.split();
213
214 let (result_sender, result_receiver) = channel::bounded(32);
215
216 IoTaskPool::get()
217 .spawn(send_stream_response(write_stream, result_receiver))
218 .detach();
219
220 send_stream_message(
221 read_stream,
222 request_sender.clone(),
223 request,
224 result_sender,
225 client_id,
226 )
227 .await?;
228
229 Ok(())
230}
231
/// Name of the query-string parameter that carries the URL-encoded JSON-RPC request
/// in a WebSocket upgrade URL.
const QUERY_KEY: &str = "body";
233
234fn validate_websocket_request(request: &Request<Incoming>) -> anyhow::Result<BrpRequest> {
235 let body = request
236 .uri()
237 .query()
238 .and_then(|query| {
239 for pair in query.split('&') {
241 let mut it = pair.split('=').take(2);
242 match (it.next(), it.next()) {
243 (Some(k), Some(v)) if k == QUERY_KEY => return Some(v),
244 _ => {}
245 };
246 }
247 None
248 })
249 .ok_or_else(|| anyhow::anyhow!("Missing body"))?;
250
251 let body = urlencoding::decode(body)?.into_owned();
252
253 match serde_json::from_str::<BrpRequest>(&body) {
254 Ok(req) => {
255 if req.jsonrpc != "2.0" {
256 anyhow::bail!("JSON-RPC request requires `\"jsonrpc\": \"2.0\"`")
257 }
258
259 Ok(req)
260 }
261 Err(err) => anyhow::bail!(err),
262 }
263}
264
265async fn send_stream_message(
266 mut stream: SplitStream<HyperWebsocketStream>,
267 sender: Sender<StreamMessage>,
268 request: BrpRequest,
269 result_sender: Sender<BrpResponse>,
270 client_id: StreamClientId,
271) -> anyhow::Result<()> {
272 let _ = sender
273 .send(StreamMessage {
274 client_id,
275 kind: StreamMessageKind::Connect(
276 request.id,
277 BrpStreamMessage {
278 method: request.method,
279 params: request.params,
280 sender: result_sender,
281 },
282 ),
283 })
284 .await?;
285 while let Some(message) = stream.next().await {
286 match message {
287 Ok(Message::Text(text)) => {
288 let msg = serde_json::from_str::<Value>(&text)?;
289 let _ = sender
290 .send(StreamMessage {
291 client_id,
292 kind: StreamMessageKind::Data(msg),
293 })
294 .await?;
295 }
296 Ok(Message::Close(_)) | Err(_) => return Ok(()),
297 _ => {}
298 }
299 }
300 let _ = sender
301 .send(StreamMessage {
302 client_id,
303 kind: StreamMessageKind::Disconnect,
304 })
305 .await?;
306
307 Ok(())
308}
309
310async fn send_stream_response(
311 mut stream: SplitSink<HyperWebsocketStream, Message>,
312 result_receiver: Receiver<BrpResponse>,
313) -> anyhow::Result<()> {
314 while let Ok(response) = result_receiver.recv().await {
315 let response = serde_json::to_string(&response)?;
316 stream.send(Message::text(response)).await?;
317 }
318
319 Ok(())
320}