sim-lib-server 0.1.0-rc.1

SIM workspace package for the sim lib server.
Documentation
use std::{
    io::ErrorKind,
    net::{Shutdown, TcpListener, TcpStream},
    sync::Arc,
    time::Duration,
};

use sim_kernel::{Cx, Error, Result, Symbol};

use crate::{
    EvalSite, ServerAddress, ServerFrame, ServerRuntime,
    http::{
        HttpRequest, HttpResponse, ParsedUrl, WsMessage, base64_encode, format_url, header_value,
        parse_url, read_request, read_response, read_ws_message, websocket_accept_value,
        write_request, write_response, write_ws_binary, write_ws_close,
    },
};

use super::{
    ConnectionTransport, SERVER_CONNECTION_IO_TIMEOUT_MS, ServerTransport, WS_TRANSPORT_PATH,
    answer_or_negotiate, decode_transport_frame, encode_transport_frame, error_frame_from_error,
    io_to_host, is_timeout, update_negotiated_codec_from_reply,
};

pub struct WsServerTransport {
    address: ServerAddress,
    listener: TcpListener,
    path: String,
}

impl WsServerTransport {
    pub fn bind(address: ServerAddress) -> Result<Self> {
        let ServerAddress::Ws { url } = &address else {
            return Err(Error::Eval("ws transport requires a ws address".to_owned()));
        };
        let parsed = parse_url(url, "ws", WS_TRANSPORT_PATH)?;
        let listener =
            TcpListener::bind((parsed.host.as_str(), parsed.port)).map_err(io_to_host)?;
        listener.set_nonblocking(true).map_err(io_to_host)?;
        let local_addr = listener.local_addr().map_err(io_to_host)?;
        let address = ServerAddress::Ws {
            url: format_url(&ParsedUrl {
                port: local_addr.port(),
                ..parsed.clone()
            }),
        };
        Ok(Self {
            address,
            listener,
            path: parsed.path,
        })
    }
}

impl ServerTransport for WsServerTransport {
    fn address(&self) -> &ServerAddress {
        &self.address
    }

    fn accept(&self, cx: &mut Cx) -> Result<Box<dyn ConnectionTransport>> {
        loop {
            if let Some(connection) = self.accept_timeout(cx, Duration::from_millis(25))? {
                return Ok(connection);
            }
        }
    }

    fn shutdown(&self, _cx: &mut Cx) -> Result<()> {
        Ok(())
    }

    fn accept_timeout(
        &self,
        _cx: &mut Cx,
        _timeout: Duration,
    ) -> Result<Option<Box<dyn ConnectionTransport>>> {
        match self.listener.accept() {
            Ok((stream, _peer)) => {
                stream.set_nodelay(true).map_err(io_to_host)?;
                Ok(Some(Box::new(WsServerConnectionTransport::new(
                    stream,
                    self.path.clone(),
                ))))
            }
            Err(error) if error.kind() == ErrorKind::WouldBlock => Ok(None),
            Err(error) => Err(io_to_host(error)),
        }
    }
}

pub struct WsConnectionTransport {
    stream: TcpStream,
}

impl WsConnectionTransport {
    pub fn connect(address: &ServerAddress) -> Result<Self> {
        let ServerAddress::Ws { url } = address else {
            return Err(Error::Eval("ws connect requires a ws address".to_owned()));
        };
        let parsed = parse_url(url, "ws", WS_TRANSPORT_PATH)?;
        let mut stream =
            TcpStream::connect((parsed.host.as_str(), parsed.port)).map_err(io_to_host)?;
        stream.set_nodelay(true).map_err(io_to_host)?;
        let client_key = base64_encode(b"sim-say-websocket");
        write_request(
            &mut stream,
            &HttpRequest {
                method: "GET".to_owned(),
                path: parsed.path,
                headers: vec![
                    ("Host".to_owned(), "sim-server".to_owned()),
                    ("Upgrade".to_owned(), "websocket".to_owned()),
                    ("Connection".to_owned(), "Upgrade".to_owned()),
                    ("Sec-WebSocket-Version".to_owned(), "13".to_owned()),
                    ("Sec-WebSocket-Key".to_owned(), client_key.clone()),
                ],
                body: Vec::new(),
            },
        )?;
        let response = read_response(&mut stream)?;
        if response.status != 101 {
            return Err(Error::Eval(format!("websocket status {}", response.status)));
        }
        let Some(accept) = header_value(&response.headers, "Sec-WebSocket-Accept") else {
            return Err(Error::HostError(
                "websocket handshake missing sec-websocket-accept".to_owned(),
            ));
        };
        if accept != websocket_accept_value(&client_key) {
            return Err(Error::HostError(
                "websocket handshake returned the wrong accept value".to_owned(),
            ));
        }
        Ok(Self { stream })
    }
}

impl ConnectionTransport for WsConnectionTransport {
    /// Encodes `frame` and sends it as one binary WebSocket message.
    ///
    /// The `true` flag is the client-side variant of `write_ws_binary`
    /// (presumably payload masking — confirm against the helper).
    fn send_frame(&mut self, _cx: &mut Cx, frame: ServerFrame) -> Result<()> {
        write_ws_binary(&mut self.stream, &encode_transport_frame(&frame)?, true)
    }

    /// Reads the next message, waiting at most `timeout` (`None` blocks).
    ///
    /// Returns `Ok(Some(frame))` for a binary message and `Ok(None)` when the
    /// peer sent Close or the stream ended.
    ///
    /// # Errors
    /// Socket errors (including read timeouts) and frame decode errors.
    fn recv_frame(
        &mut self,
        _cx: &mut Cx,
        timeout: Option<Duration>,
    ) -> Result<Option<ServerFrame>> {
        // `TcpStream::set_read_timeout` rejects `Some(Duration::ZERO)`; treat a
        // zero timeout as the shortest practical wait instead of failing.
        let timeout = timeout.map(|limit| limit.max(Duration::from_millis(1)));
        self.stream.set_read_timeout(timeout).map_err(io_to_host)?;
        match read_ws_message(&mut self.stream)? {
            Some(WsMessage::Binary(payload)) => decode_transport_frame(&payload).map(Some),
            Some(WsMessage::Close) | None => Ok(None),
        }
    }

    /// Best-effort close: sends a Close frame and shuts the socket down,
    /// ignoring failures because the peer may already be gone.
    fn close(&mut self, _cx: &mut Cx) -> Result<()> {
        let _ = write_ws_close(&mut self.stream, true);
        let _ = self.stream.shutdown(Shutdown::Both);
        Ok(())
    }

    fn as_any(&self) -> &dyn std::any::Any {
        self
    }
}

struct WsServerConnectionTransport {
    stream: TcpStream,
    path: String,
}

impl WsServerConnectionTransport {
    fn new(stream: TcpStream, path: String) -> Self {
        Self { stream, path }
    }

    fn serve(&mut self, runtime: &Arc<ServerRuntime>, site: &Arc<dyn EvalSite>) -> Result<()> {
        let handshake = match read_request(&mut self.stream)? {
            Some(request) => request,
            None => return Ok(()),
        };
        if handshake.method != "GET" {
            super::http_transport::write_http_error(&mut self.stream, 405, "method not allowed")?;
            return Ok(());
        }
        if handshake.path != self.path {
            super::http_transport::write_http_error(&mut self.stream, 404, "not found")?;
            return Ok(());
        }
        let Some(key) = header_value(&handshake.headers, "Sec-WebSocket-Key") else {
            super::http_transport::write_http_error(
                &mut self.stream,
                400,
                "missing sec-websocket-key",
            )?;
            return Ok(());
        };
        write_response(
            &mut self.stream,
            &HttpResponse {
                status: 101,
                headers: vec![
                    ("Upgrade".to_owned(), "websocket".to_owned()),
                    ("Connection".to_owned(), "Upgrade".to_owned()),
                    (
                        "Sec-WebSocket-Accept".to_owned(),
                        websocket_accept_value(key),
                    ),
                ],
                body: Vec::new(),
            },
        )?;

        let session_id = runtime.open_session(
            Symbol::qualified("codec", "binary"),
            runtime.session_isolation().clone(),
        )?;
        loop {
            if runtime.is_stopping() {
                let _ = runtime.close_session(session_id);
                return Ok(());
            }
            self.stream
                .set_read_timeout(Some(Duration::from_millis(SERVER_CONNECTION_IO_TIMEOUT_MS)))
                .map_err(io_to_host)?;
            let message = match read_ws_message(&mut self.stream) {
                Ok(message) => message,
                Err(error) if is_timeout(&error) => continue,
                Err(error) => {
                    let _ = runtime.close_session(session_id);
                    return Err(error);
                }
            };
            let Some(message) = message else {
                let _ = runtime.close_session(session_id);
                return Ok(());
            };
            let WsMessage::Binary(payload) = message else {
                let _ = write_ws_close(&mut self.stream, false);
                let _ = runtime.close_session(session_id);
                return Ok(());
            };
            let frame = decode_transport_frame(&payload)?;
            runtime.note_message_received();
            let reply = match runtime.with_cx(|cx| answer_or_negotiate(cx, site, frame.clone())) {
                Ok(reply) => {
                    update_negotiated_codec_from_reply(runtime, session_id, &frame, &reply)?;
                    reply
                }
                Err(error) => runtime.with_cx(|cx| error_frame_from_error(cx, &frame, &error))?,
            };
            write_ws_binary(&mut self.stream, &encode_transport_frame(&reply)?, false)?;
            runtime.note_message_sent();
        }
    }
}

impl ConnectionTransport for WsServerConnectionTransport {
    fn send_frame(&mut self, _cx: &mut Cx, _frame: ServerFrame) -> Result<()> {
        Err(Error::Eval(
            "ws server connection transport is receive-only".to_owned(),
        ))
    }

    fn recv_frame(
        &mut self,
        _cx: &mut Cx,
        _timeout: Option<Duration>,
    ) -> Result<Option<ServerFrame>> {
        Err(Error::Eval(
            "ws server connection transport does not expose raw frames".to_owned(),
        ))
    }

    fn close(&mut self, _cx: &mut Cx) -> Result<()> {
        let _ = write_ws_close(&mut self.stream, false);
        let _ = self.stream.shutdown(Shutdown::Both);
        Ok(())
    }

    fn as_any(&self) -> &dyn std::any::Any {
        self
    }

    fn serve_connection(
        &mut self,
        runtime: &Arc<ServerRuntime>,
        site: &Arc<dyn EvalSite>,
    ) -> Result<()> {
        self.serve(runtime, site)
    }
}