1#![allow(clippy::result_large_err)]
21
22use crate::policy::Policy;
23use indexmap::IndexMap;
24use lex_bytecode::vm::Vm;
25use lex_bytecode::{Program, Value};
26use std::net::TcpListener;
27use std::sync::atomic::{AtomicU64, Ordering};
28use std::sync::mpsc;
29use std::sync::{Arc, Mutex};
30use std::thread;
31use std::time::Duration;
32
/// One live WebSocket connection as tracked by `ChatRegistry`.
struct Conn {
    /// Room this connection joined; the server handlers derive it from the
    /// request path with leading `/` characters stripped.
    room: String,
    /// Sending half of this connection's outbound queue. The connection's
    /// own read/write loop drains the paired receiver and writes each string
    /// to the socket as a text frame.
    outbound: mpsc::Sender<String>,
}
41
42#[derive(Default)]
44pub struct ChatRegistry {
45 conns: Mutex<IndexMap<u64, Conn>>,
46}
47
48impl ChatRegistry {
49 fn register(&self, room: String, outbound: mpsc::Sender<String>) -> u64 {
50 static NEXT_ID: AtomicU64 = AtomicU64::new(1);
51 let id = NEXT_ID.fetch_add(1, Ordering::SeqCst);
52 self.conns.lock().unwrap().insert(id, Conn { room, outbound });
53 id
54 }
55 fn unregister(&self, id: u64) {
56 self.conns.lock().unwrap().shift_remove(&id);
57 }
58 fn broadcast(&self, room: &str, body: &str) {
59 let conns = self.conns.lock().unwrap();
60 for c in conns.values() {
61 if c.room == room {
62 let _ = c.outbound.send(body.to_string());
63 }
64 }
65 }
66 fn send_to(&self, id: u64, body: &str) -> bool {
67 if let Some(c) = self.conns.lock().unwrap().get(&id) {
68 let _ = c.outbound.send(body.to_string());
69 true
70 } else {
71 false
72 }
73 }
74}
75
76pub fn chat_broadcast(reg: &Arc<ChatRegistry>, room: &str, body: &str) {
79 reg.broadcast(room, body);
80}
81
82pub fn chat_send(reg: &Arc<ChatRegistry>, conn_id: u64, body: &str) -> bool {
83 reg.send_to(conn_id, body)
84}
85
86pub fn serve_ws(
89 port: u16,
90 handler_name: String,
91 program: Arc<Program>,
92 policy: Policy,
93 registry: Arc<ChatRegistry>,
94) -> Result<Value, String> {
95 let listener = TcpListener::bind(("127.0.0.1", port))
96 .map_err(|e| format!("net.serve_ws bind {port}: {e}"))?;
97 eprintln!("net.serve_ws: listening on ws://127.0.0.1:{port}");
98 for stream in listener.incoming() {
99 let stream = match stream {
100 Ok(s) => s,
101 Err(e) => { eprintln!("net.serve_ws accept: {e}"); continue; }
102 };
103 let program = Arc::clone(&program);
104 let policy = policy.clone();
105 let handler_name = handler_name.clone();
106 let registry = Arc::clone(®istry);
107 thread::spawn(move || {
108 if let Err(e) = handle_connection(stream, program, policy, handler_name, registry) {
109 eprintln!("net.serve_ws connection error: {e}");
110 }
111 });
112 }
113 Ok(Value::Unit)
114}
115
116fn handle_connection(
117 stream: std::net::TcpStream,
118 program: Arc<Program>,
119 policy: Policy,
120 handler_name: String,
121 registry: Arc<ChatRegistry>,
122) -> Result<(), String> {
123 use tungstenite::{accept_hdr, handshake::server::{Request, Response}};
124
125 let mut path = String::new();
127 let path_ref = &mut path;
128 let mut ws = accept_hdr(stream, |req: &Request, resp: Response| {
129 *path_ref = req.uri().path().to_string();
130 Ok(resp)
131 }).map_err(|e| format!("ws handshake: {e}"))?;
132
133 let room = path.trim_start_matches('/').to_string();
134
135 let (tx, rx) = mpsc::channel::<String>();
138 let conn_id = registry.register(room.clone(), tx);
139
140 let _ = ws.get_mut().set_read_timeout(Some(Duration::from_millis(50)));
144
145 let result = run_loop(&mut ws, &rx, conn_id, &room, &program, &policy, &handler_name, ®istry);
146 registry.unregister(conn_id);
147 let _ = ws.close(None);
148 result
149}
150
151#[allow(clippy::too_many_arguments)]
152fn run_loop(
153 ws: &mut tungstenite::WebSocket<std::net::TcpStream>,
154 rx: &mpsc::Receiver<String>,
155 conn_id: u64,
156 room: &str,
157 program: &Arc<Program>,
158 policy: &Policy,
159 handler_name: &str,
160 registry: &Arc<ChatRegistry>,
161) -> Result<(), String> {
162 use tungstenite::Message;
163 use std::io::ErrorKind;
164 loop {
165 match ws.read() {
167 Ok(Message::Text(body)) => {
168 let ev = build_ws_event(conn_id, room, &body);
169 let handler = crate::handler::DefaultHandler::new(policy.clone())
170 .with_program(Arc::clone(program))
171 .with_chat_registry(Arc::clone(registry));
172 let mut vm = Vm::with_handler(program, Box::new(handler));
173 if let Err(e) = vm.call(handler_name, vec![ev]) {
174 eprintln!("on_message {conn_id}: {e}");
175 }
176 }
177 Ok(Message::Binary(_)) => { }
178 Ok(Message::Close(_)) | Err(tungstenite::Error::ConnectionClosed) => break,
179 Ok(_) => {} Err(tungstenite::Error::Io(ref e)) if e.kind() == ErrorKind::WouldBlock
181 || e.kind() == ErrorKind::TimedOut => {}
182 Err(e) => return Err(format!("ws read: {e}")),
183 }
184 loop {
186 match rx.try_recv() {
187 Ok(msg) => {
188 if let Err(e) = ws.send(Message::Text(msg.into())) {
189 return Err(format!("ws send: {e}"));
190 }
191 }
192 Err(mpsc::TryRecvError::Empty) => break,
193 Err(mpsc::TryRecvError::Disconnected) => return Ok(()),
194 }
195 }
196 }
197 Ok(())
198}
199
200fn build_ws_event(conn_id: u64, room: &str, body: &str) -> Value {
201 let mut rec = IndexMap::new();
202 rec.insert("body".into(), Value::Str(body.to_string()));
203 rec.insert("conn_id".into(), Value::Int(conn_id as i64));
204 rec.insert("room".into(), Value::Str(room.to_string()));
205 Value::Record(rec)
206}
207
208fn build_ws_conn(conn_id: u64, path: &str, subprotocol: &str) -> Value {
212 let mut rec = IndexMap::new();
213 rec.insert("id".into(), Value::Str(conn_id.to_string()));
214 rec.insert("path".into(), Value::Str(path.to_string()));
215 rec.insert("subprotocol".into(), Value::Str(subprotocol.to_string()));
216 Value::Record(rec)
217}
218
219fn build_ws_message_text(body: &str) -> Value {
221 Value::Variant { name: "WsText".into(), args: vec![Value::Str(body.to_string())] }
222}
223
224fn build_ws_message_close() -> Value {
225 Value::Variant { name: "WsClose".into(), args: vec![] }
226}
227
228fn build_ws_message_ping() -> Value {
229 Value::Variant { name: "WsPing".into(), args: vec![] }
230}
231
232fn build_ws_message_binary(payload: &[u8]) -> Value {
233 let bytes = payload.iter().map(|b| Value::Int(*b as i64)).collect();
234 Value::Variant { name: "WsBinary".into(), args: vec![Value::List(bytes)] }
235}
236
237fn apply_ws_action<S: std::io::Read + std::io::Write>(
242 action: &Value,
243 ws: &mut tungstenite::WebSocket<S>,
244) -> Result<(), String> {
245 use tungstenite::Message;
246 match action {
247 Value::Variant { name, args } if name == "WsSend" => {
248 let text = match args.first() {
249 Some(Value::Str(s)) => s.clone(),
250 _ => return Err("WsSend payload must be Str".into()),
251 };
252 ws.send(Message::Text(text.into()))
253 .map_err(|e| format!("ws send: {e}"))
254 }
255 Value::Variant { name, args } if name == "WsSendBinary" => {
256 let bytes: Vec<u8> = match args.first() {
257 Some(Value::List(elems)) => elems
258 .iter()
259 .map(|v| match v {
260 Value::Int(n) => Ok(*n as u8),
261 _ => Err("WsSendBinary payload must be List[Int]".into()),
262 })
263 .collect::<Result<Vec<_>, String>>()?,
264 _ => return Err("WsSendBinary payload must be List[Int]".into()),
265 };
266 ws.send(Message::Binary(bytes.into()))
267 .map_err(|e| format!("ws send binary: {e}"))
268 }
269 Value::Variant { name, .. } if name == "WsNoOp" => Ok(()),
270 other => Err(format!("unexpected WsAction: {other:?}")),
271 }
272}
273
274pub fn serve_ws_fn(
276 port: u16,
277 subprotocol: String,
278 closure: Value,
279 program: Arc<Program>,
280 policy: Policy,
281 registry: Arc<ChatRegistry>,
282) -> Result<Value, String> {
283 let listener = TcpListener::bind(("127.0.0.1", port))
284 .map_err(|e| format!("net.serve_ws_fn bind {port}: {e}"))?;
285 eprintln!("net.serve_ws_fn: listening on ws://127.0.0.1:{port}");
286 for stream in listener.incoming() {
287 let stream = match stream {
288 Ok(s) => s,
289 Err(e) => { eprintln!("net.serve_ws_fn accept: {e}"); continue; }
290 };
291 let program = Arc::clone(&program);
292 let policy = policy.clone();
293 let closure = closure.clone();
294 let subprotocol = subprotocol.clone();
295 let registry = Arc::clone(®istry);
296 thread::spawn(move || {
297 if let Err(e) = handle_connection_fn(
298 stream, program, policy, closure, subprotocol, registry,
299 ) {
300 eprintln!("net.serve_ws_fn connection error: {e}");
301 }
302 });
303 }
304 Ok(Value::Unit)
305}
306
307fn handle_connection_fn(
308 stream: std::net::TcpStream,
309 program: Arc<Program>,
310 policy: Policy,
311 closure: Value,
312 subprotocol: String,
313 registry: Arc<ChatRegistry>,
314) -> Result<(), String> {
315 use tungstenite::{accept_hdr, handshake::server::{Request, Response}};
316
317 let mut path = String::new();
318 let path_ref = &mut path;
319 let mut ws = accept_hdr(stream, |req: &Request, resp: Response| {
320 *path_ref = req.uri().path().to_string();
321 Ok(resp)
322 }).map_err(|e| format!("ws handshake: {e}"))?;
323
324 let (tx, rx) = mpsc::channel::<String>();
325 let conn_id = registry.register(path.trim_start_matches('/').to_string(), tx);
326 let _ = ws.get_mut().set_read_timeout(Some(Duration::from_millis(50)));
327
328 let result = run_loop_fn(
329 &mut ws, &rx, conn_id, &path, &subprotocol,
330 &program, &policy, &closure, ®istry,
331 );
332 registry.unregister(conn_id);
333 let _ = ws.close(None);
334 result
335}
336
337#[allow(clippy::too_many_arguments)]
338fn run_loop_fn(
339 ws: &mut tungstenite::WebSocket<std::net::TcpStream>,
340 rx: &mpsc::Receiver<String>,
341 conn_id: u64,
342 path: &str,
343 subprotocol: &str,
344 program: &Arc<Program>,
345 policy: &Policy,
346 closure: &Value,
347 registry: &Arc<ChatRegistry>,
348) -> Result<(), String> {
349 use tungstenite::Message;
350 use std::io::ErrorKind;
351
352 let ws_conn = build_ws_conn(conn_id, path, subprotocol);
353
354 loop {
355 let ws_msg = match ws.read() {
356 Ok(Message::Text(body)) => Some(build_ws_message_text(&body)),
357 Ok(Message::Binary(_)) => None,
358 Ok(Message::Ping(_)) => Some(build_ws_message_ping()),
359 Ok(Message::Close(_)) | Err(tungstenite::Error::ConnectionClosed) => {
360 let handler = crate::handler::DefaultHandler::new(policy.clone())
362 .with_program(Arc::clone(program))
363 .with_chat_registry(Arc::clone(registry));
364 let mut vm = Vm::with_handler(program, Box::new(handler));
365 let _ = vm.invoke_closure_value(
366 closure.clone(),
367 vec![ws_conn.clone(), build_ws_message_close()],
368 );
369 break;
370 }
371 Ok(_) => None, Err(tungstenite::Error::Io(ref e))
373 if e.kind() == ErrorKind::WouldBlock || e.kind() == ErrorKind::TimedOut => None,
374 Err(e) => return Err(format!("ws read: {e}")),
375 };
376
377 if let Some(msg) = ws_msg {
378 let handler = crate::handler::DefaultHandler::new(policy.clone())
379 .with_program(Arc::clone(program))
380 .with_chat_registry(Arc::clone(registry));
381 let mut vm = Vm::with_handler(program, Box::new(handler));
382 match vm.invoke_closure_value(closure.clone(), vec![ws_conn.clone(), msg]) {
383 Ok(action) => {
384 if let Err(e) = apply_ws_action(&action, ws) {
385 eprintln!("ws action {conn_id}: {e}");
386 }
387 }
388 Err(e) => eprintln!("ws handler {conn_id}: {e}"),
389 }
390 }
391
392 loop {
394 match rx.try_recv() {
395 Ok(msg) => {
396 if let Err(e) = ws.send(Message::Text(msg.into())) {
397 return Err(format!("ws send: {e}"));
398 }
399 }
400 Err(mpsc::TryRecvError::Empty) => break,
401 Err(mpsc::TryRecvError::Disconnected) => return Ok(()),
402 }
403 }
404 }
405 Ok(())
406}
407
408fn build_dial_result(ok: Result<(), String>) -> Value {
433 match ok {
434 Ok(()) => Value::Variant {
435 name: "Ok".into(),
436 args: vec![Value::Unit],
437 },
438 Err(msg) => Value::Variant {
439 name: "Err".into(),
440 args: vec![Value::Str(msg)],
441 },
442 }
443}
444
445pub fn dial_ws(
450 url: String,
451 subprotocol: String,
452 on_open: Value,
453 on_message: Value,
454 program: Arc<Program>,
455 policy: Policy,
456) -> Result<Value, String> {
457 use tungstenite::client::IntoClientRequest;
458 use tungstenite::http::HeaderValue;
459
460 let mut req = match url.as_str().into_client_request() {
470 Ok(r) => r,
471 Err(e) => {
472 return Ok(build_dial_result(Err(format!(
473 "net.dial_ws: bad URL `{url}`: {e}"
474 ))));
475 }
476 };
477 if !subprotocol.is_empty() {
478 let header = match HeaderValue::from_str(&subprotocol) {
479 Ok(h) => h,
480 Err(e) => {
481 return Ok(build_dial_result(Err(format!(
482 "net.dial_ws: invalid subprotocol `{subprotocol}`: {e}"
483 ))));
484 }
485 };
486 req.headers_mut().insert("Sec-WebSocket-Protocol", header);
487 }
488
489 let (mut ws, _resp) = match tungstenite::connect(req) {
490 Ok(pair) => pair,
491 Err(e) => {
492 return Ok(build_dial_result(Err(format!(
493 "net.dial_ws: connect to `{url}`: {e}"
494 ))));
495 }
496 };
497
498 if let Some(stream) = stream_for(&mut ws) {
501 let _ = stream.set_read_timeout(Some(Duration::from_millis(50)));
502 }
503
504 {
506 let handler = crate::handler::DefaultHandler::new(policy.clone())
507 .with_program(Arc::clone(&program));
508 let mut vm = Vm::with_handler(&program, Box::new(handler));
509 match vm.invoke_closure_value(on_open.clone(), vec![]) {
510 Ok(action) => {
511 if let Err(e) = apply_ws_action(&action, &mut ws) {
512 return Ok(build_dial_result(Err(format!(
513 "net.dial_ws: on_open action: {e}"
514 ))));
515 }
516 }
517 Err(e) => {
518 return Ok(build_dial_result(Err(format!(
519 "net.dial_ws: on_open: {e}"
520 ))));
521 }
522 }
523 }
524
525 let loop_result = dial_run_loop(&mut ws, &on_message, &program, &policy);
527 let _ = ws.close(None);
528 Ok(build_dial_result(loop_result))
529}
530
531fn stream_for(
537 ws: &mut tungstenite::WebSocket<tungstenite::stream::MaybeTlsStream<std::net::TcpStream>>,
538) -> Option<&mut std::net::TcpStream> {
539 use tungstenite::stream::MaybeTlsStream;
540 match ws.get_mut() {
541 MaybeTlsStream::Plain(s) => Some(s),
542 MaybeTlsStream::Rustls(s) => Some(s.get_mut()),
543 _ => None,
544 }
545}
546
547fn dial_run_loop(
548 ws: &mut tungstenite::WebSocket<tungstenite::stream::MaybeTlsStream<std::net::TcpStream>>,
549 on_message: &Value,
550 program: &Arc<Program>,
551 policy: &Policy,
552) -> Result<(), String> {
553 use std::io::ErrorKind;
554 use tungstenite::Message;
555
556 loop {
557 let ws_msg = match ws.read() {
558 Ok(Message::Text(body)) => Some(build_ws_message_text(&body)),
559 Ok(Message::Binary(payload)) => Some(build_ws_message_binary(&payload)),
560 Ok(Message::Ping(_)) => Some(build_ws_message_ping()),
561 Ok(Message::Close(_)) | Err(tungstenite::Error::ConnectionClosed) => {
562 let handler = crate::handler::DefaultHandler::new(policy.clone())
564 .with_program(Arc::clone(program));
565 let mut vm = Vm::with_handler(program, Box::new(handler));
566 let _ = vm.invoke_closure_value(
567 on_message.clone(),
568 vec![build_ws_message_close()],
569 );
570 return Ok(());
571 }
572 Ok(_) => None, Err(tungstenite::Error::Io(ref e))
574 if e.kind() == ErrorKind::WouldBlock || e.kind() == ErrorKind::TimedOut =>
575 {
576 None
577 }
578 Err(e) => return Err(format!("net.dial_ws: read: {e}")),
579 };
580
581 if let Some(msg) = ws_msg {
582 let handler = crate::handler::DefaultHandler::new(policy.clone())
583 .with_program(Arc::clone(program));
584 let mut vm = Vm::with_handler(program, Box::new(handler));
585 match vm.invoke_closure_value(on_message.clone(), vec![msg]) {
586 Ok(action) => {
587 if let Err(e) = apply_ws_action(&action, ws) {
588 return Err(format!("net.dial_ws: action: {e}"));
589 }
590 }
591 Err(e) => return Err(format!("net.dial_ws: on_message: {e}")),
592 }
593 }
594 }
595}