Skip to main content

lex_runtime/
ws.rs

1//! WebSocket server + chat-broadcast registry.
2//!
3//! `net.serve_ws(port, on_message)` blocks on a TCP listener, upgrades
4//! each incoming connection to WebSocket, and runs a per-connection
5//! worker thread that polls both inbound (calls Lex's `on_message`)
6//! and outbound (drains broadcasts from a channel into the socket).
7//!
8//! `chat.broadcast(room, body)` looks up every connection in `room`
9//! and pushes `body` onto its outbound channel. `chat.send(conn_id,
10//! body)` is the same but to a single connection.
11//!
//! The registry (shared as an `Arc<ChatRegistry>` that wraps a
//! `Mutex<IndexMap<…>>`) exists because Lex's immutability means
//! shared mutable state has to live in the host runtime. Lex code
14//! stays pure: it receives an event, returns Nil, and any side
15//! effects go through `chat.*` which is gated by the policy.
16
17// tungstenite's `accept_hdr` callback takes/returns a tungstenite
18// `ErrorResponse` which is large; we only ever return Ok so the
19// large-Err warning is noise.
20#![allow(clippy::result_large_err)]
21
22use crate::policy::Policy;
23use indexmap::IndexMap;
24use lex_bytecode::vm::Vm;
25use lex_bytecode::{Program, Value};
26use std::net::TcpListener;
27use std::sync::atomic::{AtomicU64, Ordering};
28use std::sync::mpsc;
29use std::sync::{Arc, Mutex};
30use std::thread;
31use std::time::Duration;
32
/// One live WebSocket connection as tracked by `ChatRegistry`.
struct Conn {
    /// Chat room this connection belongs to; `broadcast` matches on it.
    room: String,
    /// Sending half of the connection's outbound queue. `chat.broadcast`
    /// and `chat.send` push text here, and the connection's worker thread
    /// owns the matching `Receiver` and forwards each string to the socket.
    outbound: mpsc::Sender<String>,
}
41
42/// Global chat registry. One per `net.serve_ws` invocation.
43#[derive(Default)]
44pub struct ChatRegistry {
45    conns: Mutex<IndexMap<u64, Conn>>,
46}
47
48impl ChatRegistry {
49    fn register(&self, room: String, outbound: mpsc::Sender<String>) -> u64 {
50        static NEXT_ID: AtomicU64 = AtomicU64::new(1);
51        let id = NEXT_ID.fetch_add(1, Ordering::SeqCst);
52        self.conns.lock().unwrap().insert(id, Conn { room, outbound });
53        id
54    }
55    fn unregister(&self, id: u64) {
56        self.conns.lock().unwrap().shift_remove(&id);
57    }
58    fn broadcast(&self, room: &str, body: &str) {
59        let conns = self.conns.lock().unwrap();
60        for c in conns.values() {
61            if c.room == room {
62                let _ = c.outbound.send(body.to_string());
63            }
64        }
65    }
66    fn send_to(&self, id: u64, body: &str) -> bool {
67        if let Some(c) = self.conns.lock().unwrap().get(&id) {
68            let _ = c.outbound.send(body.to_string());
69            true
70        } else {
71            false
72        }
73    }
74}
75
76/// `chat.broadcast(room, body)` — looked up at runtime by the
77/// effect handler; called from inside the Lex VM.
78pub fn chat_broadcast(reg: &Arc<ChatRegistry>, room: &str, body: &str) {
79    reg.broadcast(room, body);
80}
81
82pub fn chat_send(reg: &Arc<ChatRegistry>, conn_id: u64, body: &str) -> bool {
83    reg.send_to(conn_id, body)
84}
85
86/// Bind a WebSocket server. Blocks; returns Unit on shutdown (the
87/// process is normally killed before that).
88pub fn serve_ws(
89    port: u16,
90    handler_name: String,
91    program: Arc<Program>,
92    policy: Policy,
93    registry: Arc<ChatRegistry>,
94) -> Result<Value, String> {
95    let listener = TcpListener::bind(("127.0.0.1", port))
96        .map_err(|e| format!("net.serve_ws bind {port}: {e}"))?;
97    eprintln!("net.serve_ws: listening on ws://127.0.0.1:{port}");
98    for stream in listener.incoming() {
99        let stream = match stream {
100            Ok(s) => s,
101            Err(e) => { eprintln!("net.serve_ws accept: {e}"); continue; }
102        };
103        let program = Arc::clone(&program);
104        let policy = policy.clone();
105        let handler_name = handler_name.clone();
106        let registry = Arc::clone(&registry);
107        thread::spawn(move || {
108            if let Err(e) = handle_connection(stream, program, policy, handler_name, registry) {
109                eprintln!("net.serve_ws connection error: {e}");
110            }
111        });
112    }
113    Ok(Value::Unit)
114}
115
116fn handle_connection(
117    stream: std::net::TcpStream,
118    program: Arc<Program>,
119    policy: Policy,
120    handler_name: String,
121    registry: Arc<ChatRegistry>,
122) -> Result<(), String> {
123    use tungstenite::{accept_hdr, handshake::server::{Request, Response}};
124
125    // Capture the request path during the handshake — used as the room name.
126    let mut path = String::new();
127    let path_ref = &mut path;
128    let mut ws = accept_hdr(stream, |req: &Request, resp: Response| {
129        *path_ref = req.uri().path().to_string();
130        Ok(resp)
131    }).map_err(|e| format!("ws handshake: {e}"))?;
132
133    let room = path.trim_start_matches('/').to_string();
134
135    // Outbound channel: broadcast/send pushes here, this thread writes
136    // each message into the WebSocket.
137    let (tx, rx) = mpsc::channel::<String>();
138    let conn_id = registry.register(room.clone(), tx);
139
140    // Make WS reads non-blocking-ish so the same thread can also drain
141    // the outbound channel. tungstenite reads through the underlying
142    // TcpStream; setting a short read timeout lets us multiplex.
143    let _ = ws.get_mut().set_read_timeout(Some(Duration::from_millis(50)));
144
145    let result = run_loop(&mut ws, &rx, conn_id, &room, &program, &policy, &handler_name, &registry);
146    registry.unregister(conn_id);
147    let _ = ws.close(None);
148    result
149}
150
151#[allow(clippy::too_many_arguments)]
152fn run_loop(
153    ws: &mut tungstenite::WebSocket<std::net::TcpStream>,
154    rx: &mpsc::Receiver<String>,
155    conn_id: u64,
156    room: &str,
157    program: &Arc<Program>,
158    policy: &Policy,
159    handler_name: &str,
160    registry: &Arc<ChatRegistry>,
161) -> Result<(), String> {
162    use tungstenite::Message;
163    use std::io::ErrorKind;
164    loop {
165        // 1) Try to read one inbound message. WouldBlock = no data yet.
166        match ws.read() {
167            Ok(Message::Text(body)) => {
168                let ev = build_ws_event(conn_id, room, &body);
169                let handler = crate::handler::DefaultHandler::new(policy.clone())
170                    .with_program(Arc::clone(program))
171                    .with_chat_registry(Arc::clone(registry));
172                let mut vm = Vm::with_handler(program, Box::new(handler));
173                if let Err(e) = vm.call(handler_name, vec![ev]) {
174                    eprintln!("on_message {conn_id}: {e}");
175                }
176            }
177            Ok(Message::Binary(_)) => { /* binary frames ignored in v1 */ }
178            Ok(Message::Close(_)) | Err(tungstenite::Error::ConnectionClosed) => break,
179            Ok(_) => {} // ping/pong/frame
180            Err(tungstenite::Error::Io(ref e)) if e.kind() == ErrorKind::WouldBlock
181                || e.kind() == ErrorKind::TimedOut => {}
182            Err(e) => return Err(format!("ws read: {e}")),
183        }
184        // 2) Drain outbound channel. Doesn't block.
185        loop {
186            match rx.try_recv() {
187                Ok(msg) => {
188                    if let Err(e) = ws.send(Message::Text(msg.into())) {
189                        return Err(format!("ws send: {e}"));
190                    }
191                }
192                Err(mpsc::TryRecvError::Empty) => break,
193                Err(mpsc::TryRecvError::Disconnected) => return Ok(()),
194            }
195        }
196    }
197    Ok(())
198}
199
200fn build_ws_event(conn_id: u64, room: &str, body: &str) -> Value {
201    let mut rec = IndexMap::new();
202    rec.insert("body".into(), Value::Str(body.to_string()));
203    rec.insert("conn_id".into(), Value::Int(conn_id as i64));
204    rec.insert("room".into(), Value::Str(room.to_string()));
205    Value::Record(rec)
206}
207
208// ── Closure-based WebSocket server (#359) ────────────────────────────────────
209
210/// Build a `WsConn` record value for the typed closure-based handler.
211fn build_ws_conn(conn_id: u64, path: &str, subprotocol: &str) -> Value {
212    let mut rec = IndexMap::new();
213    rec.insert("id".into(), Value::Str(conn_id.to_string()));
214    rec.insert("path".into(), Value::Str(path.to_string()));
215    rec.insert("subprotocol".into(), Value::Str(subprotocol.to_string()));
216    Value::Record(rec)
217}
218
219/// Build a `WsMessage` variant value.
220fn build_ws_message_text(body: &str) -> Value {
221    Value::Variant { name: "WsText".into(), args: vec![Value::Str(body.to_string())] }
222}
223
224fn build_ws_message_close() -> Value {
225    Value::Variant { name: "WsClose".into(), args: vec![] }
226}
227
228fn build_ws_message_ping() -> Value {
229    Value::Variant { name: "WsPing".into(), args: vec![] }
230}
231
232fn build_ws_message_binary(payload: &[u8]) -> Value {
233    let bytes = payload.iter().map(|b| Value::Int(*b as i64)).collect();
234    Value::Variant { name: "WsBinary".into(), args: vec![Value::List(bytes)] }
235}
236
237/// Interpret a `WsAction` variant and send the appropriate frame.
238/// Generic over the stream so this serves both the plaintext-only
239/// server path (`TcpStream`) and the dial path that may sit on top
240/// of a TLS-wrapped stream (`MaybeTlsStream<TcpStream>`).
241fn apply_ws_action<S: std::io::Read + std::io::Write>(
242    action: &Value,
243    ws: &mut tungstenite::WebSocket<S>,
244) -> Result<(), String> {
245    use tungstenite::Message;
246    match action {
247        Value::Variant { name, args } if name == "WsSend" => {
248            let text = match args.first() {
249                Some(Value::Str(s)) => s.clone(),
250                _ => return Err("WsSend payload must be Str".into()),
251            };
252            ws.send(Message::Text(text.into()))
253                .map_err(|e| format!("ws send: {e}"))
254        }
255        Value::Variant { name, args } if name == "WsSendBinary" => {
256            let bytes: Vec<u8> = match args.first() {
257                Some(Value::List(elems)) => elems
258                    .iter()
259                    .map(|v| match v {
260                        Value::Int(n) => Ok(*n as u8),
261                        _ => Err("WsSendBinary payload must be List[Int]".into()),
262                    })
263                    .collect::<Result<Vec<_>, String>>()?,
264                _ => return Err("WsSendBinary payload must be List[Int]".into()),
265            };
266            ws.send(Message::Binary(bytes.into()))
267                .map_err(|e| format!("ws send binary: {e}"))
268        }
269        Value::Variant { name, .. } if name == "WsNoOp" => Ok(()),
270        other => Err(format!("unexpected WsAction: {other:?}")),
271    }
272}
273
274/// Closure-based WebSocket server. Accepts a `Value::Closure` as the handler.
275pub fn serve_ws_fn(
276    port: u16,
277    subprotocol: String,
278    closure: Value,
279    program: Arc<Program>,
280    policy: Policy,
281    registry: Arc<ChatRegistry>,
282) -> Result<Value, String> {
283    let listener = TcpListener::bind(("127.0.0.1", port))
284        .map_err(|e| format!("net.serve_ws_fn bind {port}: {e}"))?;
285    eprintln!("net.serve_ws_fn: listening on ws://127.0.0.1:{port}");
286    for stream in listener.incoming() {
287        let stream = match stream {
288            Ok(s) => s,
289            Err(e) => { eprintln!("net.serve_ws_fn accept: {e}"); continue; }
290        };
291        let program = Arc::clone(&program);
292        let policy = policy.clone();
293        let closure = closure.clone();
294        let subprotocol = subprotocol.clone();
295        let registry = Arc::clone(&registry);
296        thread::spawn(move || {
297            if let Err(e) = handle_connection_fn(
298                stream, program, policy, closure, subprotocol, registry,
299            ) {
300                eprintln!("net.serve_ws_fn connection error: {e}");
301            }
302        });
303    }
304    Ok(Value::Unit)
305}
306
307fn handle_connection_fn(
308    stream: std::net::TcpStream,
309    program: Arc<Program>,
310    policy: Policy,
311    closure: Value,
312    subprotocol: String,
313    registry: Arc<ChatRegistry>,
314) -> Result<(), String> {
315    use tungstenite::{accept_hdr, handshake::server::{Request, Response}};
316
317    let mut path = String::new();
318    let path_ref = &mut path;
319    let mut ws = accept_hdr(stream, |req: &Request, resp: Response| {
320        *path_ref = req.uri().path().to_string();
321        Ok(resp)
322    }).map_err(|e| format!("ws handshake: {e}"))?;
323
324    let (tx, rx) = mpsc::channel::<String>();
325    let conn_id = registry.register(path.trim_start_matches('/').to_string(), tx);
326    let _ = ws.get_mut().set_read_timeout(Some(Duration::from_millis(50)));
327
328    let result = run_loop_fn(
329        &mut ws, &rx, conn_id, &path, &subprotocol,
330        &program, &policy, &closure, &registry,
331    );
332    registry.unregister(conn_id);
333    let _ = ws.close(None);
334    result
335}
336
337#[allow(clippy::too_many_arguments)]
338fn run_loop_fn(
339    ws: &mut tungstenite::WebSocket<std::net::TcpStream>,
340    rx: &mpsc::Receiver<String>,
341    conn_id: u64,
342    path: &str,
343    subprotocol: &str,
344    program: &Arc<Program>,
345    policy: &Policy,
346    closure: &Value,
347    registry: &Arc<ChatRegistry>,
348) -> Result<(), String> {
349    use tungstenite::Message;
350    use std::io::ErrorKind;
351
352    let ws_conn = build_ws_conn(conn_id, path, subprotocol);
353
354    loop {
355        let ws_msg = match ws.read() {
356            Ok(Message::Text(body)) => Some(build_ws_message_text(&body)),
357            Ok(Message::Binary(_)) => None,
358            Ok(Message::Ping(_)) => Some(build_ws_message_ping()),
359            Ok(Message::Close(_)) | Err(tungstenite::Error::ConnectionClosed) => {
360                // Notify handler then exit.
361                let handler = crate::handler::DefaultHandler::new(policy.clone())
362                    .with_program(Arc::clone(program))
363                    .with_chat_registry(Arc::clone(registry));
364                let mut vm = Vm::with_handler(program, Box::new(handler));
365                let _ = vm.invoke_closure_value(
366                    closure.clone(),
367                    vec![ws_conn.clone(), build_ws_message_close()],
368                );
369                break;
370            }
371            Ok(_) => None, // pong / frame
372            Err(tungstenite::Error::Io(ref e))
373                if e.kind() == ErrorKind::WouldBlock || e.kind() == ErrorKind::TimedOut => None,
374            Err(e) => return Err(format!("ws read: {e}")),
375        };
376
377        if let Some(msg) = ws_msg {
378            let handler = crate::handler::DefaultHandler::new(policy.clone())
379                .with_program(Arc::clone(program))
380                .with_chat_registry(Arc::clone(registry));
381            let mut vm = Vm::with_handler(program, Box::new(handler));
382            match vm.invoke_closure_value(closure.clone(), vec![ws_conn.clone(), msg]) {
383                Ok(action) => {
384                    if let Err(e) = apply_ws_action(&action, ws) {
385                        eprintln!("ws action {conn_id}: {e}");
386                    }
387                }
388                Err(e) => eprintln!("ws handler {conn_id}: {e}"),
389            }
390        }
391
392        // Drain broadcast/send outbound channel.
393        loop {
394            match rx.try_recv() {
395                Ok(msg) => {
396                    if let Err(e) = ws.send(Message::Text(msg.into())) {
397                        return Err(format!("ws send: {e}"));
398                    }
399                }
400                Err(mpsc::TryRecvError::Empty) => break,
401                Err(mpsc::TryRecvError::Disconnected) => return Ok(()),
402            }
403        }
404    }
405    Ok(())
406}
407
408// ── Closure-based WebSocket client (#390) ────────────────────────────────────
409//
410// Inverse of `serve_ws_fn`: open a connection to a remote WS server and
411// run two Lex callbacks against it.
412//
413// - `on_open : () -> [E] WsAction` is invoked once after the handshake
414//   completes. The returned `WsAction` (typically `WsSend(boot_frame)`)
415//   is applied to the socket immediately. This is the hook for
416//   protocols like OCPP where the client sends a `BootNotification`
417//   the moment it connects.
418// - `on_message : (WsMessage) -> [E] WsAction` is invoked for every
419//   inbound frame. Same `WsAction` semantics as the server-side
420//   handler. A `WsClose` message is delivered once before the loop
421//   exits so handlers can run shutdown logic.
422//
423// Multi-frame sends from `on_open` (e.g. a charger that wants to
424// also kick off a heartbeat scheduler at connect-time) aren't
425// expressible in v1 — the issue's `send :: (Str) -> [net]
426// Result[Unit, Str]` closure would let users push outbound frames
427// from arbitrary `[net]` code, but that requires representing
428// Rust-native closures as Lex `Value`s, which is a separate
429// runtime change. v1 covers the BootNotification + reactive reply
430// pattern that motivates the issue.
431
432fn build_dial_result(ok: Result<(), String>) -> Value {
433    match ok {
434        Ok(()) => Value::Variant {
435            name: "Ok".into(),
436            args: vec![Value::Unit],
437        },
438        Err(msg) => Value::Variant {
439            name: "Err".into(),
440            args: vec![Value::Str(msg)],
441        },
442    }
443}
444
445/// `net.dial_ws(url, subprotocol, on_open, on_message) -> [net, E]
446/// Result[Unit, Str]`. Blocks for the lifetime of the connection;
447/// returns `Ok(())` on a clean close from the server, `Err(reason)`
448/// on dial failure, handshake failure, read error, or write error.
449pub fn dial_ws(
450    url: String,
451    subprotocol: String,
452    on_open: Value,
453    on_message: Value,
454    program: Arc<Program>,
455    policy: Policy,
456) -> Result<Value, String> {
457    use tungstenite::client::IntoClientRequest;
458    use tungstenite::http::HeaderValue;
459
460    // Build the request — when `subprotocol` is non-empty, attach the
461    // Sec-WebSocket-Protocol header so the server's accept-handler
462    // can match on it. Empty subprotocol → header omitted (the same
463    // contract as `serve_ws_fn`'s subprotocol arg).
464    //
465    // Caller-controlled inputs (URL syntax, subprotocol header value)
466    // surface as a Lex `Err(reason)`, not a Rust panic / handler
467    // error, so `match net.dial_ws(...) { Err(_) => ..., Ok(_) => ... }`
468    // works at the Lex level.
469    let mut req = match url.as_str().into_client_request() {
470        Ok(r) => r,
471        Err(e) => {
472            return Ok(build_dial_result(Err(format!(
473                "net.dial_ws: bad URL `{url}`: {e}"
474            ))));
475        }
476    };
477    if !subprotocol.is_empty() {
478        let header = match HeaderValue::from_str(&subprotocol) {
479            Ok(h) => h,
480            Err(e) => {
481                return Ok(build_dial_result(Err(format!(
482                    "net.dial_ws: invalid subprotocol `{subprotocol}`: {e}"
483                ))));
484            }
485        };
486        req.headers_mut().insert("Sec-WebSocket-Protocol", header);
487    }
488
489    let (mut ws, _resp) = match tungstenite::connect(req) {
490        Ok(pair) => pair,
491        Err(e) => {
492            return Ok(build_dial_result(Err(format!(
493                "net.dial_ws: connect to `{url}`: {e}"
494            ))));
495        }
496    };
497
498    // Non-blocking-ish reads so we don't tie up the thread on an idle
499    // socket, mirroring the server's read-timeout multiplexing.
500    if let Some(stream) = stream_for(&mut ws) {
501        let _ = stream.set_read_timeout(Some(Duration::from_millis(50)));
502    }
503
504    // 1. Fire on_open once and apply its action.
505    {
506        let handler = crate::handler::DefaultHandler::new(policy.clone())
507            .with_program(Arc::clone(&program));
508        let mut vm = Vm::with_handler(&program, Box::new(handler));
509        match vm.invoke_closure_value(on_open.clone(), vec![]) {
510            Ok(action) => {
511                if let Err(e) = apply_ws_action(&action, &mut ws) {
512                    return Ok(build_dial_result(Err(format!(
513                        "net.dial_ws: on_open action: {e}"
514                    ))));
515                }
516            }
517            Err(e) => {
518                return Ok(build_dial_result(Err(format!(
519                    "net.dial_ws: on_open: {e}"
520                ))));
521            }
522        }
523    }
524
525    // 2. Run the read loop, dispatching each inbound frame to on_message.
526    let loop_result = dial_run_loop(&mut ws, &on_message, &program, &policy);
527    let _ = ws.close(None);
528    Ok(build_dial_result(loop_result))
529}
530
531/// Pull the underlying TCP stream out of a `MaybeTlsStream` so we can
532/// set a read timeout. For plaintext connections this is the
533/// `TcpStream` directly; for `rustls`-wrapped streams it's the inner
534/// socket. Returns `None` if the wrapping is some other variant —
535/// in that case we just skip the timeout and rely on blocking reads.
536fn stream_for(
537    ws: &mut tungstenite::WebSocket<tungstenite::stream::MaybeTlsStream<std::net::TcpStream>>,
538) -> Option<&mut std::net::TcpStream> {
539    use tungstenite::stream::MaybeTlsStream;
540    match ws.get_mut() {
541        MaybeTlsStream::Plain(s) => Some(s),
542        MaybeTlsStream::Rustls(s) => Some(s.get_mut()),
543        _ => None,
544    }
545}
546
547fn dial_run_loop(
548    ws: &mut tungstenite::WebSocket<tungstenite::stream::MaybeTlsStream<std::net::TcpStream>>,
549    on_message: &Value,
550    program: &Arc<Program>,
551    policy: &Policy,
552) -> Result<(), String> {
553    use std::io::ErrorKind;
554    use tungstenite::Message;
555
556    loop {
557        let ws_msg = match ws.read() {
558            Ok(Message::Text(body)) => Some(build_ws_message_text(&body)),
559            Ok(Message::Binary(payload)) => Some(build_ws_message_binary(&payload)),
560            Ok(Message::Ping(_)) => Some(build_ws_message_ping()),
561            Ok(Message::Close(_)) | Err(tungstenite::Error::ConnectionClosed) => {
562                // Deliver WsClose so the handler can do shutdown work.
563                let handler = crate::handler::DefaultHandler::new(policy.clone())
564                    .with_program(Arc::clone(program));
565                let mut vm = Vm::with_handler(program, Box::new(handler));
566                let _ = vm.invoke_closure_value(
567                    on_message.clone(),
568                    vec![build_ws_message_close()],
569                );
570                return Ok(());
571            }
572            Ok(_) => None, // pong / raw frame
573            Err(tungstenite::Error::Io(ref e))
574                if e.kind() == ErrorKind::WouldBlock || e.kind() == ErrorKind::TimedOut =>
575            {
576                None
577            }
578            Err(e) => return Err(format!("net.dial_ws: read: {e}")),
579        };
580
581        if let Some(msg) = ws_msg {
582            let handler = crate::handler::DefaultHandler::new(policy.clone())
583                .with_program(Arc::clone(program));
584            let mut vm = Vm::with_handler(program, Box::new(handler));
585            match vm.invoke_closure_value(on_message.clone(), vec![msg]) {
586                Ok(action) => {
587                    if let Err(e) = apply_ws_action(&action, ws) {
588                        return Err(format!("net.dial_ws: action: {e}"));
589                    }
590                }
591                Err(e) => return Err(format!("net.dial_ws: on_message: {e}")),
592            }
593        }
594    }
595}