Skip to main content

ply_engine/net/mod.rs

1mod http;
2mod websocket;
3
4pub use http::{HttpConfig, Request, Response};
5pub use websocket::{WebSocket, WsConfig, WsMessage};
6
7use http::HttpRequestState;
8use websocket::WebSocketState;
9
10use rustc_hash::FxHashMap;
11use std::sync::{LazyLock, Mutex};
12
13/// Global manager for all network operations.
14#[allow(private_interfaces)]
15pub static NET_MANAGER: LazyLock<Mutex<NetManager>> =
16    LazyLock::new(|| Mutex::new(NetManager::new()));
17
/// Pairs a piece of network state with an idle-frame counter.
///
/// The counter is reset to 0 whenever a caller looks the entry up and is bumped by
/// `NetManager::clean`; once it exceeds the manager's threshold the entry may be evicted.
pub(crate) struct Tracked<T> {
    /// The wrapped request/socket state.
    pub(crate) state: T,
    /// Frames elapsed since this entry was last looked up.
    pub(crate) frames_not_accessed: usize,
}
23
24pub(crate) struct NetManager {
25    pub(crate) http_requests: FxHashMap<u64, Tracked<HttpRequestState>>,
26    pub(crate) websockets: FxHashMap<u64, Tracked<WebSocketState>>,
27    /// Number of frames after which an unused closed response is evicted.
28    pub max_frames_not_used: usize,
29}
30
31impl NetManager {
32    fn new() -> Self {
33        Self {
34            http_requests: FxHashMap::default(),
35            websockets: FxHashMap::default(),
36            max_frames_not_used: 60,
37        }
38    }
39
40    /// Frame-based eviction. Called once per frame from `eval()`.
41    pub fn clean(&mut self) {
42        self.http_requests.retain(|_, entry| {
43            // Try to receive if still pending
44            if let HttpRequestState::Pending(pending) = &mut entry.state {
45                if let Some(result) = pending.try_recv() {
46                    match result {
47                        Ok(resp) => {
48                            entry.state =
49                                HttpRequestState::Done(std::sync::Arc::new(resp));
50                        }
51                        Err(e) => {
52                            entry.state = HttpRequestState::Error(e);
53                        }
54                    }
55                }
56            }
57
58            match &entry.state {
59                HttpRequestState::Pending(_) => true, // never evict
60                _ => {
61                    entry.frames_not_accessed += 1;
62                    entry.frames_not_accessed <= self.max_frames_not_used
63                }
64            }
65        });
66
67        self.websockets.retain(|_, entry| {
68            entry.frames_not_accessed += 1;
69            let disconnected = entry.state.is_disconnected();
70            !(disconnected && entry.frames_not_accessed > self.max_frames_not_used)
71        });
72    }
73}
74
75fn hash_id(id: &str) -> u64 {
76    use std::hash::{Hash, Hasher};
77    let mut hasher = rustc_hash::FxHasher::default();
78    id.hash(&mut hasher);
79    hasher.finish()
80}
81
82fn fire_http(
83    method: &str,
84    id: &str,
85    url: &str,
86    f: impl FnOnce(&mut HttpConfig) -> &mut HttpConfig,
87) {
88    let key = hash_id(id);
89    let mut mgr = NET_MANAGER.lock().unwrap();
90
91    // Idempotent: don't re-fire if a request with this ID already exists
92    if mgr.http_requests.contains_key(&key) {
93        return;
94    }
95
96    let mut config = HttpConfig::new();
97    f(&mut config);
98
99    #[cfg(not(target_arch = "wasm32"))]
100    {
101        use http::PendingHttp;
102
103        let method = method.to_owned();
104        let url = url.to_owned();
105        let (tx, rx) = std::sync::mpsc::channel();
106
107        std::thread::spawn(move || {
108            let result: Result<Response, String> = (|| {
109                let agent = ureq::Agent::new_with_defaults();
110
111                macro_rules! apply_headers {
112                    ($req:expr, $headers:expr) => {{
113                        let mut r = $req;
114                        for (key, value) in $headers {
115                            r = r.header(key.as_str(), value.as_str());
116                        }
117                        r
118                    }};
119                }
120
121                let send_result = match method.as_str() {
122                    "GET" => {
123                        let req = apply_headers!(agent.get(&url), &config.headers);
124                        req.call()
125                    }
126                    "DELETE" => {
127                        let req = apply_headers!(agent.delete(&url), &config.headers);
128                        req.call()
129                    }
130                    "POST" => {
131                        let req = apply_headers!(agent.post(&url), &config.headers);
132                        if let Some(body) = &config.body {
133                            req.content_type("application/octet-stream").send(body)
134                        } else {
135                            req.send_empty()
136                        }
137                    }
138                    "PUT" => {
139                        let req = apply_headers!(agent.put(&url), &config.headers);
140                        if let Some(body) = &config.body {
141                            req.content_type("application/octet-stream").send(body)
142                        } else {
143                            req.send_empty()
144                        }
145                    }
146                    _ => return Err(format!("Unsupported HTTP method: {method}")),
147                };
148
149                match send_result {
150                    Ok(resp) => {
151                        let status: u16 = resp.status().into();
152                        let body = resp
153                            .into_body()
154                            .read_to_vec()
155                            .map_err(|e| e.to_string())?;
156                        Ok(Response::new(status, body))
157                    }
158                    Err(e) => Err(e.to_string()),
159                }
160            })();
161
162            let _ = tx.send(result);
163        });
164
165        mgr.http_requests.insert(
166            key,
167            Tracked {
168                frames_not_accessed: 0,
169                state: HttpRequestState::Pending(PendingHttp::new(rx)),
170            },
171        );
172    }
173
174    #[cfg(target_arch = "wasm32")]
175    {
176        use http::PendingHttp;
177        use sapp_jsutils::JsObject;
178
179        let scheme: i32 = match method {
180            "GET" => 0,
181            "POST" => 1,
182            "PUT" => 2,
183            "DELETE" => 3,
184            _ => return,
185        };
186
187        let headers_obj = JsObject::object();
188        for (key, value) in &config.headers {
189            headers_obj.set_field_string(key, value);
190        }
191
192        let body_str = config
193            .body
194            .as_ref()
195            .map(|b| String::from_utf8_lossy(b).to_string())
196            .unwrap_or_default();
197
198        let cid = unsafe {
199            http::ply_net_http_make_request(
200                scheme,
201                JsObject::string(url),
202                JsObject::string(&body_str),
203                headers_obj,
204            )
205        };
206
207        mgr.http_requests.insert(
208            key,
209            Tracked {
210                frames_not_accessed: 0,
211                state: HttpRequestState::Pending(PendingHttp::new(cid)),
212            },
213        );
214    }
215}
216
217/// Fire a GET request. Idempotent: won't re-fire if a request with this ID exists.
218pub fn get(id: &str, url: &str, f: impl FnOnce(&mut HttpConfig) -> &mut HttpConfig) {
219    fire_http("GET", id, url, f);
220}
221
222/// Fire a POST request. Idempotent: won't re-fire if a request with this ID exists.
223pub fn post(id: &str, url: &str, f: impl FnOnce(&mut HttpConfig) -> &mut HttpConfig) {
224    fire_http("POST", id, url, f);
225}
226
227/// Fire a PUT request. Idempotent: won't re-fire if a request with this ID exists.
228pub fn put(id: &str, url: &str, f: impl FnOnce(&mut HttpConfig) -> &mut HttpConfig) {
229    fire_http("PUT", id, url, f);
230}
231
232/// Fire a DELETE request. Idempotent: won't re-fire if a request with this ID exists.
233pub fn delete(id: &str, url: &str, f: impl FnOnce(&mut HttpConfig) -> &mut HttpConfig) {
234    fire_http("DELETE", id, url, f);
235}
236
237/// Get a handle to an existing HTTP request. Returns `None` if no such ID.
238pub fn request(id: &str) -> Option<Request> {
239    let key = hash_id(id);
240    let mut mgr = NET_MANAGER.lock().unwrap();
241    let entry = mgr.http_requests.get_mut(&key)?;
242    entry.frames_not_accessed = 0;
243    Some(Request { id: key })
244}
245
246/// Connect a WebSocket. Idempotent: won't reconnect if already open.
247pub fn ws_connect(
248    id: &str,
249    url: &str,
250    f: impl FnOnce(&mut WsConfig) -> &mut WsConfig,
251) {
252    let key = hash_id(id);
253    let mut mgr = NET_MANAGER.lock().unwrap();
254
255    if mgr.websockets.contains_key(&key) {
256        return;
257    }
258
259    let mut config = WsConfig::new();
260    f(&mut config);
261
262    #[cfg(not(target_arch = "wasm32"))]
263    {
264        let url = url.to_owned();
265        let (incoming_tx, incoming_rx) = std::sync::mpsc::channel();
266        let (outgoing_tx, mut outgoing_rx) = tokio::sync::mpsc::unbounded_channel();
267
268        let runtime = tokio::runtime::Builder::new_multi_thread()
269            .enable_all()
270            .build()
271            .expect("Failed to create tokio runtime for WebSocket");
272
273        let insecure = config.insecure;
274        let headers = config.headers;
275
276        runtime.spawn(async move {
277            use futures::{SinkExt, StreamExt};
278            use tokio_tungstenite::tungstenite;
279            use tungstenite::client::IntoClientRequest;
280
281            // Build handshake request: start from URL to get proper WS headers,
282            // then add custom headers on top.
283            let mut ws_request = match url.into_client_request() {
284                Ok(r) => r,
285                Err(e) => {
286                    let _ = incoming_tx.send(WsMessage::Error(e.to_string()));
287                    return;
288                }
289            };
290            for (key, value) in &headers {
291                if let (Ok(name), Ok(val)) = (
292                    tungstenite::http::header::HeaderName::from_bytes(key.as_bytes()),
293                    tungstenite::http::header::HeaderValue::from_str(value),
294                ) {
295                    ws_request.headers_mut().insert(name, val);
296                }
297            }
298
299            let socket = if insecure {
300                let tls_config = {
301                    let config = rustls::ClientConfig::builder()
302                        .dangerous()
303                        .with_custom_certificate_verifier(std::sync::Arc::new(
304                            NoCertificateVerification {},
305                        ))
306                        .with_no_client_auth();
307                    std::sync::Arc::new(config)
308                };
309                let connector = tokio_tungstenite::Connector::Rustls(tls_config);
310                tokio_tungstenite::connect_async_tls_with_config(
311                    ws_request,
312                    None,
313                    true,
314                    Some(connector),
315                )
316                .await
317            } else {
318                tokio_tungstenite::connect_async(ws_request).await
319            };
320
321            let (ws_stream, _response) = match socket {
322                Ok(s) => s,
323                Err(e) => {
324                    let _ = incoming_tx.send(WsMessage::Error(e.to_string()));
325                    return;
326                }
327            };
328
329            let _ = incoming_tx.send(WsMessage::Connected);
330
331            let (mut write_half, mut read_half) = ws_stream.split();
332
333            // Read task
334            let incoming_tx_read = incoming_tx.clone();
335            tokio::spawn(async move {
336                while let Some(msg) = read_half.next().await {
337                    match msg {
338                        Ok(tungstenite::Message::Binary(data)) => {
339                            if incoming_tx_read
340                                .send(WsMessage::Binary(data.into()))
341                                .is_err()
342                            {
343                                break;
344                            }
345                        }
346                        Ok(tungstenite::Message::Text(text)) => {
347                            if incoming_tx_read
348                                .send(WsMessage::Text(text.to_string()))
349                                .is_err()
350                            {
351                                break;
352                            }
353                        }
354                        Ok(tungstenite::Message::Close(_)) => {
355                            let _ = incoming_tx_read.send(WsMessage::Closed);
356                            break;
357                        }
358                        Err(e) => {
359                            let _ =
360                                incoming_tx_read.send(WsMessage::Error(e.to_string()));
361                            break;
362                        }
363                        _ => {}
364                    }
365                }
366            });
367
368            // Write task
369            let incoming_tx_write = incoming_tx.clone();
370            tokio::spawn(async move {
371                use websocket::OutgoingWsMessage;
372                while let Some(msg) = outgoing_rx.recv().await {
373                    match msg {
374                        OutgoingWsMessage::Text(text) => {
375                            if let Err(e) = write_half
376                                .send(tungstenite::Message::Text(text.into()))
377                                .await
378                            {
379                                let _ = incoming_tx_write
380                                    .send(WsMessage::Error(e.to_string()));
381                                break;
382                            }
383                        }
384                        OutgoingWsMessage::Binary(data) => {
385                            if let Err(e) = write_half
386                                .send(tungstenite::Message::Binary(data.into()))
387                                .await
388                            {
389                                let _ = incoming_tx_write
390                                    .send(WsMessage::Error(e.to_string()));
391                                break;
392                            }
393                        }
394                        OutgoingWsMessage::Close => {
395                            let _ = incoming_tx_write.send(WsMessage::Closed);
396                            let _ = write_half
397                                .send(tungstenite::Message::Close(None))
398                                .await;
399                            break;
400                        }
401                    }
402                }
403            });
404        });
405
406        mgr.websockets.insert(
407            key,
408            Tracked {
409                frames_not_accessed: 0,
410                state: WebSocketState {
411                    tx: outgoing_tx,
412                    rx: incoming_rx,
413                    _runtime: runtime,
414                },
415            },
416        );
417    }
418
419    #[cfg(target_arch = "wasm32")]
420    {
421        use sapp_jsutils::JsObject;
422
423        // JS bridge uses an integer socket ID; we use the lower 32 bits of the hash
424        let socket_id = key as i32;
425
426        unsafe {
427            websocket::ply_net_ws_connect(socket_id, JsObject::string(url));
428        }
429
430        mgr.websockets.insert(
431            key,
432            Tracked {
433                frames_not_accessed: 0,
434                state: WebSocketState { socket_id },
435            },
436        );
437    }
438}
439
440/// Get a handle to an existing WebSocket. Returns `None` if no such ID.
441pub fn ws(id: &str) -> Option<WebSocket> {
442    let key = hash_id(id);
443    let mut mgr = NET_MANAGER.lock().unwrap();
444    let entry = mgr.websockets.get_mut(&key)?;
445    entry.frames_not_accessed = 0;
446    Some(WebSocket { id: key })
447}
448
449#[cfg(not(target_arch = "wasm32"))]
450#[derive(Debug)]
451struct NoCertificateVerification;
452
453#[cfg(not(target_arch = "wasm32"))]
454impl rustls::client::danger::ServerCertVerifier for NoCertificateVerification {
455    fn verify_server_cert(
456        &self,
457        _end_entity: &rustls::pki_types::CertificateDer<'_>,
458        _intermediates: &[rustls::pki_types::CertificateDer<'_>],
459        _server_name: &rustls::pki_types::ServerName<'_>,
460        _ocsp_response: &[u8],
461        _now: rustls::pki_types::UnixTime,
462    ) -> Result<rustls::client::danger::ServerCertVerified, rustls::Error> {
463        Ok(rustls::client::danger::ServerCertVerified::assertion())
464    }
465
466    fn verify_tls12_signature(
467        &self,
468        _message: &[u8],
469        _cert: &rustls::pki_types::CertificateDer<'_>,
470        _dss: &rustls::DigitallySignedStruct,
471    ) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
472        Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
473    }
474
475    fn verify_tls13_signature(
476        &self,
477        _message: &[u8],
478        _cert: &rustls::pki_types::CertificateDer<'_>,
479        _dss: &rustls::DigitallySignedStruct,
480    ) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
481        Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
482    }
483
484    fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
485        vec![
486            rustls::SignatureScheme::RSA_PKCS1_SHA256,
487            rustls::SignatureScheme::RSA_PKCS1_SHA384,
488            rustls::SignatureScheme::RSA_PKCS1_SHA512,
489            rustls::SignatureScheme::RSA_PSS_SHA256,
490            rustls::SignatureScheme::RSA_PSS_SHA384,
491            rustls::SignatureScheme::RSA_PSS_SHA512,
492            rustls::SignatureScheme::ECDSA_NISTP256_SHA256,
493            rustls::SignatureScheme::ECDSA_NISTP384_SHA384,
494            rustls::SignatureScheme::ECDSA_NISTP521_SHA512,
495            rustls::SignatureScheme::ED25519,
496            rustls::SignatureScheme::ED448,
497        ]
498    }
499}