// sim-lib-server 0.1.0-rc.1
//
// SIM workspace package for the sim lib server.
// Documentation
use std::{
    io::{BufReader, ErrorKind, Write},
    net::{Shutdown, TcpListener, TcpStream},
    sync::Arc,
    time::Duration,
};

use sim_kernel::{Cx, Error, Result, Symbol};

use crate::{
    EvalSite, FrameKind, ServerAddress, ServerFrame, ServerRuntime, StreamSink,
    http::{
        HttpRequest, HttpResponse, ParsedUrl, base64_decode, base64_encode, format_url,
        header_value, parse_url, read_request, read_response, read_sse_event, write_request,
        write_response,
    },
};

use super::{
    ConnectionTransport, HTTP_TRANSPORT_PATH, SSE_TRANSPORT_PATH, ServerTransport,
    decode_transport_frame, encode_transport_frame, io_to_host, update_negotiated_codec_from_reply,
};

/// Listening side of the SSE transport.
///
/// Built by `SseServerTransport::bind`; each accepted TCP stream is wrapped in its own
/// connection transport.
pub struct SseServerTransport {
    /// Address returned by `address()`; `bind` rebuilds it with the listener's local port.
    address: ServerAddress,
    /// Listening socket; `bind` switches it to non-blocking mode.
    listener: TcpListener,
    /// URL path taken from the parsed address and handed to every accepted connection.
    path: String,
}

impl SseServerTransport {
    /// Binds a non-blocking TCP listener for an SSE server address.
    ///
    /// The URL is parsed with the `http` scheme and `SSE_TRANSPORT_PATH` passed to
    /// `parse_url`. The address stored on the transport is re-formatted with the port
    /// reported by the bound listener rather than the port written in the URL.
    ///
    /// # Errors
    ///
    /// Returns `Error::Eval` when `address` is not `ServerAddress::Sse`, propagates
    /// `parse_url` failures, and maps socket errors through `io_to_host`.
    pub fn bind(address: ServerAddress) -> Result<Self> {
        let parsed = match &address {
            ServerAddress::Sse { url } => parse_url(url, "http", SSE_TRANSPORT_PATH)?,
            _ => {
                return Err(Error::Eval(
                    "sse transport requires an sse address".to_owned(),
                ));
            }
        };

        let bind_target = (parsed.host.as_str(), parsed.port);
        let listener = TcpListener::bind(bind_target).map_err(io_to_host)?;
        listener.set_nonblocking(true).map_err(io_to_host)?;
        let bound_port = listener.local_addr().map_err(io_to_host)?.port();

        // Keep the path for request matching, then reuse the remaining parsed fields for
        // the advertised URL.
        let path = parsed.path.clone();
        let advertised = ParsedUrl {
            port: bound_port,
            ..parsed
        };

        Ok(Self {
            address: ServerAddress::Sse {
                url: format_url(&advertised),
            },
            listener,
            path,
        })
    }
}

impl ServerTransport for SseServerTransport {
    fn address(&self) -> &ServerAddress {
        &self.address
    }

    fn accept(&self, cx: &mut Cx) -> Result<Box<dyn ConnectionTransport>> {
        loop {
            if let Some(connection) = self.accept_timeout(cx, Duration::from_millis(25))? {
                return Ok(connection);
            }
        }
    }

    fn shutdown(&self, _cx: &mut Cx) -> Result<()> {
        Ok(())
    }

    fn accept_timeout(
        &self,
        _cx: &mut Cx,
        _timeout: Duration,
    ) -> Result<Option<Box<dyn ConnectionTransport>>> {
        match self.listener.accept() {
            Ok((stream, _peer)) => {
                stream.set_nodelay(true).map_err(io_to_host)?;
                Ok(Some(Box::new(SseServerConnectionTransport::new(
                    stream,
                    self.path.clone(),
                ))))
            }
            Err(error) if error.kind() == ErrorKind::WouldBlock => Ok(None),
            Err(error) => Err(io_to_host(error)),
        }
    }
}

/// Client-side connection transport for SSE addresses.
///
/// Frame traffic is delegated to an HTTP connection transport that is created on first
/// use.
pub struct SseConnectionTransport {
    /// HTTP address derived from the SSE URL in `connect`.
    address: ServerAddress,
    /// Underlying HTTP transport; `None` until the first send or receive.
    inner: Option<super::HttpConnectionTransport>,
}

impl SseConnectionTransport {
    pub fn connect(address: &ServerAddress) -> Result<Self> {
        let ServerAddress::Sse { url } = address else {
            return Err(Error::Eval(
                "sse connect requires an sse address".to_owned(),
            ));
        };
        let parsed = parse_url(url, "http", SSE_TRANSPORT_PATH)?;
        Ok(Self {
            address: ServerAddress::Http {
                url: format_url(&ParsedUrl {
                    path: HTTP_TRANSPORT_PATH.to_owned(),
                    ..parsed
                }),
            },
            inner: None,
        })
    }

    fn inner_mut(&mut self) -> Result<&mut super::HttpConnectionTransport> {
        if self.inner.is_none() {
            self.inner = Some(super::HttpConnectionTransport::connect(&self.address)?);
        }
        self.inner.as_mut().ok_or_else(|| {
            Error::HostError("sse http fallback transport was not initialized".to_owned())
        })
    }
}

impl ConnectionTransport for SseConnectionTransport {
    /// Sends `frame` through the inner HTTP transport, connecting it if necessary.
    fn send_frame(&mut self, cx: &mut Cx, frame: ServerFrame) -> Result<()> {
        let transport = self.inner_mut()?;
        transport.send_frame(cx, frame)
    }

    /// Receives a frame through the inner HTTP transport, connecting it if necessary.
    fn recv_frame(
        &mut self,
        cx: &mut Cx,
        timeout: Option<Duration>,
    ) -> Result<Option<ServerFrame>> {
        let transport = self.inner_mut()?;
        transport.recv_frame(cx, timeout)
    }

    /// Closes the inner transport if one was ever created; otherwise does nothing.
    fn close(&mut self, cx: &mut Cx) -> Result<()> {
        let Some(transport) = self.inner.as_mut() else {
            return Ok(());
        };
        transport.close(cx)?;
        Ok(())
    }

    /// Exposes the concrete type for downcasting.
    fn as_any(&self) -> &dyn std::any::Any {
        self
    }
}

/// Server side of a single accepted SSE connection.
struct SseServerConnectionTransport {
    /// Accepted TCP stream the request is read from and events are written to.
    stream: TcpStream,
    /// Request path this connection must be addressed to; other paths get a 404.
    path: String,
}

impl SseServerConnectionTransport {
    /// Wraps an accepted stream together with the path requests must target.
    fn new(stream: TcpStream, path: String) -> Self {
        Self { stream, path }
    }

    /// Serves one SSE request on this connection.
    ///
    /// Opens a session, reads a single HTTP request, validates it (`GET`, matching path,
    /// `X-Sim-Frame` header holding a base64-encoded transport frame), answers with a
    /// `text/event-stream` response and streams the frames produced by `site.stream` as
    /// SSE events, followed by an `end` event.
    ///
    /// Invalid requests are answered with 405/404/400 and reported as `Ok(())`; a
    /// connection that yields no request is also `Ok(())`.
    ///
    /// The session is closed on every exit path, including error returns, so a failed
    /// read, decode or write does not leak it.
    ///
    /// # Errors
    ///
    /// Propagates failures from opening the session, reading the request, writing
    /// responses, decoding the frame header, codec negotiation and `site.stream`.
    fn serve(&mut self, runtime: &Arc<ServerRuntime>, site: &Arc<dyn EvalSite>) -> Result<()> {
        let session_id = runtime.open_session(
            Symbol::qualified("codec", "binary"),
            runtime.session_isolation().clone(),
        )?;

        // Run the fallible request handling in a closure so that `?` returns here
        // rather than out of `serve`, letting the session be closed exactly once below.
        let outcome = (|| -> Result<()> {
            let Some(request) = read_request(&mut self.stream)? else {
                return Ok(());
            };
            if request.method != "GET" {
                super::http_transport::write_http_error(
                    &mut self.stream,
                    405,
                    "method not allowed",
                )?;
                return Ok(());
            }
            if request.path != self.path {
                super::http_transport::write_http_error(&mut self.stream, 404, "not found")?;
                return Ok(());
            }
            let Some(frame_header) = header_value(&request.headers, "X-Sim-Frame") else {
                super::http_transport::write_http_error(
                    &mut self.stream,
                    400,
                    "missing x-sim-frame header",
                )?;
                return Ok(());
            };
            let frame = decode_transport_frame(&base64_decode(frame_header)?)?;
            runtime.note_message_received();
            write_response(
                &mut self.stream,
                &HttpResponse {
                    status: 200,
                    headers: vec![
                        ("Content-Type".to_owned(), "text/event-stream".to_owned()),
                        ("Cache-Control".to_owned(), "no-cache".to_owned()),
                    ],
                    body: Vec::new(),
                },
            )?;
            let mut sink = SseStreamSink {
                stream: &mut self.stream,
                sent_end: false,
                runtime,
            };
            // NOTE(review): the request frame is passed as both trailing arguments;
            // confirm this is the intended negotiation input.
            update_negotiated_codec_from_reply(runtime, session_id, &frame, &frame)?;
            let streamed = runtime.with_cx(|cx| site.stream(cx, frame, &mut sink));
            // Best effort: make sure the client sees an `end` event even when the site
            // did not finish the stream itself.
            let _ = sink.end_without_cx();
            streamed
        })();

        // Closing is best effort; the handling outcome is what callers care about.
        let _ = runtime.close_session(session_id);
        outcome
    }
}

impl ConnectionTransport for SseServerConnectionTransport {
    fn send_frame(&mut self, _cx: &mut Cx, _frame: ServerFrame) -> Result<()> {
        Err(Error::Eval(
            "sse server connection transport is receive-only".to_owned(),
        ))
    }

    fn recv_frame(
        &mut self,
        _cx: &mut Cx,
        _timeout: Option<Duration>,
    ) -> Result<Option<ServerFrame>> {
        Err(Error::Eval(
            "sse server connection transport does not expose raw frames".to_owned(),
        ))
    }

    fn close(&mut self, _cx: &mut Cx) -> Result<()> {
        let _ = self.stream.shutdown(Shutdown::Both);
        Ok(())
    }

    fn as_any(&self) -> &dyn std::any::Any {
        self
    }

    fn serve_connection(
        &mut self,
        runtime: &Arc<ServerRuntime>,
        site: &Arc<dyn EvalSite>,
    ) -> Result<()> {
        self.serve(runtime, site)
    }
}

/// Stream sink that writes frames to a TCP stream as server-sent events.
struct SseStreamSink<'a> {
    /// Connection the events are written to.
    stream: &'a mut TcpStream,
    /// Set once the `end` event has been written, so it is not written twice.
    sent_end: bool,
    /// Runtime notified after each frame is sent.
    runtime: &'a Arc<ServerRuntime>,
}

impl SseStreamSink<'_> {
    /// Writes one SSE event (`event:` line, `data:` line, blank line) and flushes.
    ///
    /// The event is formatted into a single buffer and sent with one `write_all`.
    /// Formatting straight into the unbuffered socket would issue a separate write per
    /// fragment, which can split one event across several small packets.
    ///
    /// Neither `event` nor `data` is escaped, so both must be free of line breaks.
    ///
    /// # Errors
    ///
    /// I/O failures are mapped through `io_to_host`.
    fn write_event(&mut self, event: &str, data: &str) -> Result<()> {
        let message = format!("event: {event}\r\ndata: {data}\r\n\r\n");
        self.stream
            .write_all(message.as_bytes())
            .map_err(io_to_host)?;
        self.stream.flush().map_err(io_to_host)
    }

    /// Writes the terminating `end` event unless it has already been written.
    ///
    /// `sent_end` is only set after a successful write, so a failed attempt can be
    /// retried by a later call.
    fn end_without_cx(&mut self) -> Result<()> {
        if self.sent_end {
            return Ok(());
        }
        self.write_event("end", "")?;
        self.sent_end = true;
        Ok(())
    }
}

/// Maps a frame kind to the SSE event name used on the wire.
///
/// The three stream kinds get dedicated names; every other kind is sent as `chunk`.
fn sse_event_name_for_frame(kind: &FrameKind) -> &'static str {
    if matches!(kind, FrameKind::StreamStart) {
        return "stream-start";
    }
    if matches!(kind, FrameKind::StreamChunk) {
        return "stream-chunk";
    }
    if matches!(kind, FrameKind::StreamEnd) {
        return "stream-end";
    }
    "chunk"
}

impl StreamSink for SseStreamSink<'_> {
    /// Encodes `frame`, sends it as a base64 SSE event and records the sent message.
    ///
    /// The runtime is only notified after the event was written successfully.
    fn chunk(&mut self, _cx: &mut Cx, frame: ServerFrame) -> Result<()> {
        let encoded = base64_encode(&encode_transport_frame(&frame)?);
        let event_name = sse_event_name_for_frame(&frame.kind);
        self.write_event(event_name, &encoded)?;
        self.runtime.note_message_sent();
        Ok(())
    }

    /// Ends the stream by emitting the `end` event (at most once).
    fn end(&mut self, _cx: &mut Cx) -> Result<()> {
        Self::end_without_cx(self)
    }
}

pub(super) fn sse_stream_request(
    cx: &mut Cx,
    address: &ServerAddress,
    frame: ServerFrame,
    sink: &mut dyn StreamSink,
) -> Result<()> {
    let ServerAddress::Sse { url } = address else {
        return Err(Error::Eval("sse stream requires an sse address".to_owned()));
    };
    let parsed = parse_url(url, "http", SSE_TRANSPORT_PATH)?;
    let mut stream = TcpStream::connect((parsed.host.as_str(), parsed.port)).map_err(io_to_host)?;
    stream.set_nodelay(true).map_err(io_to_host)?;
    let request_frame = base64_encode(&encode_transport_frame(&frame)?);
    write_request(
        &mut stream,
        &HttpRequest {
            method: "GET".to_owned(),
            path: parsed.path,
            headers: vec![
                ("Host".to_owned(), "sim-server".to_owned()),
                ("Accept".to_owned(), "text/event-stream".to_owned()),
                ("X-Sim-Frame".to_owned(), request_frame),
            ],
            body: Vec::new(),
        },
    )?;
    let response = read_response(&mut stream)?;
    if response.status != 200 {
        return Err(Error::Eval(format!("sse status {}", response.status)));
    }
    let mut reader = BufReader::new(stream);
    while let Some((event, data)) = read_sse_event(&mut reader)? {
        match event.as_str() {
            "chunk" | "stream-start" | "stream-chunk" | "stream-end" => {
                let payload = base64_decode(&data)?;
                let frame = decode_transport_frame(&payload)?;
                sink.chunk(cx, frame)?;
            }
            "end" => {
                sink.end(cx)?;
                return Ok(());
            }
            _ => {}
        }
    }
    sink.end(cx)
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Stream frame kinds keep their dedicated event names; other kinds map to `chunk`.
    #[test]
    fn sse_event_names_preserve_stream_frame_kinds() {
        let expectations = [
            (FrameKind::StreamStart, "stream-start"),
            (FrameKind::StreamChunk, "stream-chunk"),
            (FrameKind::StreamEnd, "stream-end"),
            (FrameKind::Response, "chunk"),
        ];
        for (kind, expected) in &expectations {
            assert_eq!(sse_event_name_for_frame(kind), *expected);
        }
    }
}