Skip to main content

lex_runtime/
ws.rs

//! WebSocket server + chat-broadcast registry.
//!
//! `net.serve_ws(port, on_message)` blocks on a TCP listener, upgrades
//! each incoming connection to WebSocket, and runs a per-connection
//! worker thread that polls both inbound (calls Lex's `on_message`)
//! and outbound (drains broadcasts from a channel into the socket).
//!
//! `chat.broadcast(room, body)` looks up every connection in `room`
//! and pushes `body` onto its outbound channel. `chat.send(conn_id,
//! body)` does the same for a single connection.
//!
//! The registry is shared as an `Arc<ChatRegistry>` (a `Mutex`-guarded
//! connection map) because Lex's immutability means shared mutable
//! state has to live in the host runtime. Lex code stays pure: it
//! receives an event, returns Nil, and any side effects go through
//! `chat.*`, which is gated by the policy.
16
// tungstenite's `accept_hdr` callback returns a `Result` whose error
// type (`ErrorResponse`) is large; we only ever return `Ok`, so the
// `result_large_err` lint is noise here.
20#![allow(clippy::result_large_err)]
21
22use crate::policy::Policy;
23use indexmap::IndexMap;
24use lex_bytecode::vm::Vm;
25use lex_bytecode::{Program, Value};
26use std::net::TcpListener;
27use std::sync::atomic::{AtomicU64, Ordering};
28use std::sync::mpsc;
29use std::sync::{Arc, Mutex};
30use std::thread;
31use std::time::Duration;
32
/// One live WebSocket connection as seen by the chat registry.
///
/// Broadcasts and direct sends never touch the socket themselves: they
/// enqueue onto `outbound`, and the connection's own worker thread
/// drains the paired receiver into the WebSocket.
struct Conn {
    /// Sending half of this connection's outbound message queue.
    outbound: mpsc::Sender<String>,
    /// Room this connection joined.
    room: String,
}
41
42/// Global chat registry. One per `net.serve_ws` invocation.
43#[derive(Default)]
44pub struct ChatRegistry {
45    conns: Mutex<IndexMap<u64, Conn>>,
46}
47
48impl ChatRegistry {
49    fn register(&self, room: String, outbound: mpsc::Sender<String>) -> u64 {
50        static NEXT_ID: AtomicU64 = AtomicU64::new(1);
51        let id = NEXT_ID.fetch_add(1, Ordering::SeqCst);
52        self.conns.lock().unwrap().insert(id, Conn { room, outbound });
53        id
54    }
55    fn unregister(&self, id: u64) {
56        self.conns.lock().unwrap().shift_remove(&id);
57    }
58    fn broadcast(&self, room: &str, body: &str) {
59        let conns = self.conns.lock().unwrap();
60        for c in conns.values() {
61            if c.room == room {
62                let _ = c.outbound.send(body.to_string());
63            }
64        }
65    }
66    fn send_to(&self, id: u64, body: &str) -> bool {
67        if let Some(c) = self.conns.lock().unwrap().get(&id) {
68            let _ = c.outbound.send(body.to_string());
69            true
70        } else {
71            false
72        }
73    }
74}
75
76/// `chat.broadcast(room, body)` — looked up at runtime by the
77/// effect handler; called from inside the Lex VM.
78pub fn chat_broadcast(reg: &Arc<ChatRegistry>, room: &str, body: &str) {
79    reg.broadcast(room, body);
80}
81
82pub fn chat_send(reg: &Arc<ChatRegistry>, conn_id: u64, body: &str) -> bool {
83    reg.send_to(conn_id, body)
84}
85
86/// Bind a WebSocket server. Blocks; returns Unit on shutdown (the
87/// process is normally killed before that).
88pub fn serve_ws(
89    port: u16,
90    handler_name: String,
91    program: Arc<Program>,
92    policy: Policy,
93    registry: Arc<ChatRegistry>,
94) -> Result<Value, String> {
95    let listener = TcpListener::bind(("127.0.0.1", port))
96        .map_err(|e| format!("net.serve_ws bind {port}: {e}"))?;
97    eprintln!("net.serve_ws: listening on ws://127.0.0.1:{port}");
98    for stream in listener.incoming() {
99        let stream = match stream {
100            Ok(s) => s,
101            Err(e) => { eprintln!("net.serve_ws accept: {e}"); continue; }
102        };
103        let program = Arc::clone(&program);
104        let policy = policy.clone();
105        let handler_name = handler_name.clone();
106        let registry = Arc::clone(&registry);
107        thread::spawn(move || {
108            if let Err(e) = handle_connection(stream, program, policy, handler_name, registry) {
109                eprintln!("net.serve_ws connection error: {e}");
110            }
111        });
112    }
113    Ok(Value::Unit)
114}
115
116fn handle_connection(
117    stream: std::net::TcpStream,
118    program: Arc<Program>,
119    policy: Policy,
120    handler_name: String,
121    registry: Arc<ChatRegistry>,
122) -> Result<(), String> {
123    use tungstenite::{accept_hdr, handshake::server::{Request, Response}};
124
125    // Capture the request path during the handshake — used as the room name.
126    let mut path = String::new();
127    let path_ref = &mut path;
128    let mut ws = accept_hdr(stream, |req: &Request, resp: Response| {
129        *path_ref = req.uri().path().to_string();
130        Ok(resp)
131    }).map_err(|e| format!("ws handshake: {e}"))?;
132
133    let room = path.trim_start_matches('/').to_string();
134
135    // Outbound channel: broadcast/send pushes here, this thread writes
136    // each message into the WebSocket.
137    let (tx, rx) = mpsc::channel::<String>();
138    let conn_id = registry.register(room.clone(), tx);
139
140    // Make WS reads non-blocking-ish so the same thread can also drain
141    // the outbound channel. tungstenite reads through the underlying
142    // TcpStream; setting a short read timeout lets us multiplex.
143    let _ = ws.get_mut().set_read_timeout(Some(Duration::from_millis(50)));
144
145    let result = run_loop(&mut ws, &rx, conn_id, &room, &program, &policy, &handler_name, &registry);
146    registry.unregister(conn_id);
147    let _ = ws.close(None);
148    result
149}
150
151#[allow(clippy::too_many_arguments)]
152fn run_loop(
153    ws: &mut tungstenite::WebSocket<std::net::TcpStream>,
154    rx: &mpsc::Receiver<String>,
155    conn_id: u64,
156    room: &str,
157    program: &Arc<Program>,
158    policy: &Policy,
159    handler_name: &str,
160    registry: &Arc<ChatRegistry>,
161) -> Result<(), String> {
162    use tungstenite::Message;
163    use std::io::ErrorKind;
164    loop {
165        // 1) Try to read one inbound message. WouldBlock = no data yet.
166        match ws.read() {
167            Ok(Message::Text(body)) => {
168                let ev = build_ws_event(conn_id, room, &body);
169                let handler = crate::handler::DefaultHandler::new(policy.clone())
170                    .with_program(Arc::clone(program))
171                    .with_chat_registry(Arc::clone(registry));
172                let mut vm = Vm::with_handler(program, Box::new(handler));
173                if let Err(e) = vm.call(handler_name, vec![ev]) {
174                    eprintln!("on_message {conn_id}: {e}");
175                }
176            }
177            Ok(Message::Binary(_)) => { /* binary frames ignored in v1 */ }
178            Ok(Message::Close(_)) | Err(tungstenite::Error::ConnectionClosed) => break,
179            Ok(_) => {} // ping/pong/frame
180            Err(tungstenite::Error::Io(ref e)) if e.kind() == ErrorKind::WouldBlock
181                || e.kind() == ErrorKind::TimedOut => {}
182            Err(e) => return Err(format!("ws read: {e}")),
183        }
184        // 2) Drain outbound channel. Doesn't block.
185        loop {
186            match rx.try_recv() {
187                Ok(msg) => {
188                    if let Err(e) = ws.send(Message::Text(msg.into())) {
189                        return Err(format!("ws send: {e}"));
190                    }
191                }
192                Err(mpsc::TryRecvError::Empty) => break,
193                Err(mpsc::TryRecvError::Disconnected) => return Ok(()),
194            }
195        }
196    }
197    Ok(())
198}
199
200fn build_ws_event(conn_id: u64, room: &str, body: &str) -> Value {
201    let mut rec = IndexMap::new();
202    rec.insert("body".into(), Value::Str(body.to_string()));
203    rec.insert("conn_id".into(), Value::Int(conn_id as i64));
204    rec.insert("room".into(), Value::Str(room.to_string()));
205    Value::Record(rec)
206}