use lex_ast::canonicalize_program;
use lex_bytecode::{compile_program, vm::Vm, Value};
use lex_runtime::{DefaultHandler, Policy};
use lex_syntax::parse_source;
use std::collections::BTreeSet;
use std::net::TcpListener;
use std::sync::{Arc, Mutex};
use std::thread;
use std::time::Duration;
use tungstenite::Message;
/// Starts a single-connection WebSocket server on an ephemeral loopback port.
///
/// The server runs on a detached background thread. After the handshake it
/// calls `setup` once, then records every inbound text frame in the shared log
/// and passes it to `on_each`. It stops reading once `max_frames` text frames
/// have been logged, the peer closes, or a read fails (reads time out after
/// five seconds), and then performs a best-effort close.
///
/// Returns the bound port and the shared log of received text frames.
fn spawn_test_server(
    setup: impl FnOnce(&mut tungstenite::WebSocket<std::net::TcpStream>) + Send + 'static,
    on_each: impl Fn(&str, &mut tungstenite::WebSocket<std::net::TcpStream>) + Send + 'static,
    max_frames: usize,
) -> (u16, Arc<Mutex<Vec<String>>>) {
    let listener = TcpListener::bind("127.0.0.1:0").expect("bind");
    let port = listener.local_addr().unwrap().port();

    let frames = Arc::new(Mutex::new(Vec::<String>::new()));
    let server_frames = Arc::clone(&frames);

    thread::spawn(move || {
        let (stream, _peer) = listener.accept().expect("server accept");
        // Bound the blocking reads so a stuck client cannot hang the thread.
        let _ = stream.set_read_timeout(Some(Duration::from_secs(5)));
        let mut socket = tungstenite::accept(stream).expect("ws handshake");
        setup(&mut socket);

        loop {
            // Check the frame budget before every read; the lock is released
            // at the end of this statement.
            let seen = server_frames.lock().unwrap().len();
            if seen >= max_frames {
                break;
            }

            let text = match socket.read() {
                Ok(Message::Text(body)) => body.to_string(),
                // A close frame or any read error ends the session.
                Ok(Message::Close(_)) | Err(_) => break,
                // Non-text, non-close frames are ignored.
                Ok(_) => continue,
            };

            server_frames.lock().unwrap().push(text.clone());
            on_each(&text, &mut socket);
        }

        // Best-effort close; the extra read lets the close exchange proceed.
        let _ = socket.close(None);
        let _ = socket.read();
    });

    // Short pause before handing the port to the caller (the listener itself
    // is already bound at this point).
    thread::sleep(Duration::from_millis(50));
    (port, frames)
}
/// Parses, type-checks, compiles and runs the Lex program `src`, returning the
/// value produced by calling `fn_name` with no arguments.
///
/// The VM runs under a policy that starts from `Policy::pure()` and allows only
/// the `net` effect.
///
/// # Panics
///
/// Panics if parsing fails, if type checking reports errors, or if the call
/// itself returns an error.
fn run_lex(src: &str, fn_name: &str) -> Value {
    let parsed = parse_source(src).expect("parse");
    let stages = canonicalize_program(&parsed);
    match lex_types::check_program(&stages) {
        Ok(_) => {}
        Err(errs) => panic!("type errors:\n{errs:#?}"),
    }

    let bytecode = Arc::new(compile_program(&stages));

    // Start from the pure policy and grant only the `net` effect.
    let mut policy = Policy::pure();
    policy.allow_effects = BTreeSet::from(["net".to_string()]);

    let handler = DefaultHandler::new(policy).with_program(Arc::clone(&bytecode));
    let mut vm = Vm::with_handler(&bytecode, Box::new(handler));

    match vm.call(fn_name, vec![]) {
        Ok(value) => value,
        Err(e) => panic!("call {fn_name}: {e}"),
    }
}
/// Asserts that `v` is the variant `Ok` carrying exactly one `Unit` argument.
///
/// # Panics
///
/// Panics if `v` is not a one-argument `Ok` variant or if its payload is not
/// `Value::Unit`.
fn assert_ok_unit(v: &Value) {
    // First confirm the outer shape and pull out the single payload.
    let payload = match v {
        Value::Variant { name, args } if name == "Ok" && args.len() == 1 => &args[0],
        other => panic!("expected Ok(Unit), got {other:?}"),
    };
    assert!(
        matches!(payload, Value::Unit),
        "Ok payload not Unit: {:?}",
        payload
    );
}
/// Asserts that `v` is a one-argument `Err` variant whose string payload
/// contains `needle`.
///
/// # Panics
///
/// Panics if `v` is not a one-argument `Err` variant, if the payload is not a
/// `Value::Str`, or if the string does not contain `needle`.
fn assert_err_contains(v: &Value, needle: &str) {
    // Step 1: the outer value must be `Err(<one payload>)`.
    let payload = match v {
        Value::Variant { name, args } if name == "Err" && args.len() == 1 => &args[0],
        other => panic!("expected Err(_), got {other:?}"),
    };

    // Step 2: the payload must be a string mentioning `needle`.
    match payload {
        Value::Str(s) => {
            assert!(
                s.contains(needle),
                "expected Err containing `{needle}`, got `{s}`"
            );
        }
        other => panic!("Err payload not Str: {other:?}"),
    }
}
/// Happy path: the server sends "ping" after the handshake; the Lex program
/// returns `WsSend("boot")` from its open callback and `WsSend("pong")` for any
/// inbound text. The server is expected to log exactly those two frames.
#[test]
fn dial_ws_runs_on_open_then_replies_to_inbound_text() {
    // The server stops reading after two text frames have been logged.
    let (port, received) = spawn_test_server(
        |socket| {
            socket
                .send(Message::Text("ping".into()))
                .expect("server ping");
        },
        |_frame, _socket| {},
        2,
    );

    // The raw string below is the Lex program under test; `{port}` is filled
    // in by `format!`.
    let src = format!(
        r#"
import "std.net" as net
fn main() -> [net] Result[Unit, Str] {{
net.dial_ws(
"ws://127.0.0.1:{port}",
"",
fn () -> WsAction {{ WsSend("boot") }},
fn (msg :: WsMessage) -> WsAction {{
match msg {{
WsText(_) => WsSend("pong"),
WsBinary(_) => WsNoOp,
WsPing => WsNoOp,
WsClose => WsNoOp,
}}
}},
)
}}
"#
    );

    let outcome = run_lex(&src, "main");
    assert_ok_unit(&outcome);

    let expected = vec![String::from("boot"), String::from("pong")];
    let frames = received.lock().unwrap().clone();
    assert_eq!(
        frames, expected,
        "server should have seen boot frame from on_open and pong reply from on_message",
    );
}
/// Dialing `ws://127.0.0.1:1` (no test server is started for it) must yield an
/// `Err` whose message mentions "connect".
#[test]
fn dial_ws_returns_err_on_connect_failure() {
    let program = r#"
import "std.net" as net
fn main() -> [net] Result[Unit, Str] {
net.dial_ws(
"ws://127.0.0.1:1",
"",
fn () -> WsAction { WsNoOp },
fn (_msg :: WsMessage) -> WsAction { WsNoOp },
)
}
"#;

    let outcome = run_lex(program, "main");
    assert_err_contains(&outcome, "connect");
}
/// A string that is not a URL must make `net.dial_ws` return an `Err` variant.
/// Only the variant name is checked; the payload is not inspected.
#[test]
fn dial_ws_returns_err_on_bad_url() {
    let program = r#"
import "std.net" as net
fn main() -> [net] Result[Unit, Str] {
net.dial_ws(
"not a url",
"",
fn () -> WsAction { WsNoOp },
fn (_msg :: WsMessage) -> WsAction { WsNoOp },
)
}
"#;

    let result = run_lex(program, "main");
    assert!(
        matches!(&result, Value::Variant { name, .. } if name == "Err"),
        "expected Err(_), got {result:?}"
    );
}
/// A non-empty subprotocol argument must reach the server as a
/// `Sec-WebSocket-Protocol` request header. The server records the header it
/// saw, echoes it back in the handshake response, and then closes.
#[test]
// NOTE(review): presumably needed because the handshake callback below returns
// a `Result` whose error type is large — confirm against tungstenite.
#[allow(clippy::result_large_err)]
fn dial_ws_subprotocol_header_is_sent_when_non_empty() {
    let listener = TcpListener::bind("127.0.0.1:0").expect("bind");
    let port = listener.local_addr().unwrap().port();

    let seen_subproto: Arc<Mutex<Option<String>>> = Arc::new(Mutex::new(None));
    let seen_handle = Arc::clone(&seen_subproto);

    thread::spawn(move || {
        let (stream, _peer) = listener.accept().expect("accept");
        let _ = stream.set_read_timeout(Some(Duration::from_secs(5)));

        let mut socket = tungstenite::accept_hdr(
            stream,
            |req: &tungstenite::handshake::server::Request,
             mut resp: tungstenite::handshake::server::Response| {
                // Only act when the header is present and converts to `&str`.
                let offered = req
                    .headers()
                    .get("Sec-WebSocket-Protocol")
                    .and_then(|value| value.to_str().ok());
                if let Some(proto) = offered {
                    *seen_handle.lock().unwrap() = Some(proto.to_string());
                    // Echo the offered protocol back in the response.
                    resp.headers_mut().insert(
                        "Sec-WebSocket-Protocol",
                        tungstenite::http::HeaderValue::from_str(proto).unwrap(),
                    );
                }
                Ok(resp)
            },
        )
        .expect("handshake");

        // Nothing to exchange: close right after the handshake (best effort).
        let _ = socket.close(None);
        let _ = socket.read();
    });

    thread::sleep(Duration::from_millis(50));

    // The raw string below is the Lex program under test; `{port}` is filled
    // in by `format!`.
    let src = format!(
        r#"
import "std.net" as net
fn main() -> [net] Result[Unit, Str] {{
net.dial_ws(
"ws://127.0.0.1:{port}",
"ocpp1.6",
fn () -> WsAction {{ WsNoOp }},
fn (_msg :: WsMessage) -> WsAction {{ WsNoOp }},
)
}}
"#
    );

    let outcome = run_lex(&src, "main");
    assert_ok_unit(&outcome);

    let observed = seen_subproto.lock().unwrap().clone();
    assert_eq!(
        observed.as_deref(),
        Some("ocpp1.6"),
        "server should have received Sec-WebSocket-Protocol: ocpp1.6"
    );
}