//! pipa-js 0.1.3
//!
//! A fast, minimal ES2023 JavaScript runtime built in Rust.
use std::io::{self, Read, Write};
use std::net::TcpStream;
use std::os::unix::io::{AsRawFd, RawFd};
use std::sync::mpsc;
use std::thread;
use std::time::Duration;

/// A client socket: either a plain TCP stream or a TCP stream wrapped in a
/// rustls client session.
///
/// Both variants own the underlying [`TcpStream`]; the TLS variant reaches it
/// through the `sock` field of rustls's `StreamOwned`.
pub enum Connection {
    /// Unencrypted TCP stream; reads and writes go straight to the socket.
    Plain(TcpStream),
    /// TLS client stream; reads and writes are decrypted/encrypted by rustls.
    Tls {
        /// Owned rustls client session together with the socket it drives.
        tls: rustls::StreamOwned<rustls::ClientConnection, TcpStream>,
    },
}

impl Connection {
    pub fn connect_async(
        host: String,
        port: u16,
        use_tls: bool,
    ) -> Result<mpsc::Receiver<Result<Self, String>>, String> {
        Self::connect_async_with_roots(host, port, use_tls, Vec::new())
    }

    pub fn connect_async_with_roots(
        host: String,
        port: u16,
        use_tls: bool,
        extra_roots: Vec<Vec<u8>>,
    ) -> Result<mpsc::Receiver<Result<Self, String>>, String> {
        let (tx, rx) = mpsc::channel();
        thread::spawn(move || {
            let result = Self::connect_blocking(&host, port, use_tls, &extra_roots);
            let _ = tx.send(result);
        });
        Ok(rx)
    }

    fn connect_blocking(
        host: &str,
        port: u16,
        use_tls: bool,
        extra_roots: &[Vec<u8>],
    ) -> Result<Self, String> {
        let addr = format!("{host}:{port}");
        let stream =
            TcpStream::connect(&addr).map_err(|e| format!("connect to {addr} failed: {e}"))?;
        stream
            .set_read_timeout(None)
            .map_err(|e| format!("set_read_timeout failed: {e}"))?;
        stream
            .set_write_timeout(None)
            .map_err(|e| format!("set_write_timeout failed: {e}"))?;
        stream
            .set_nonblocking(true)
            .map_err(|e| format!("set_nonblocking failed: {e}"))?;

        if use_tls {
            Self::wrap_tls(host, stream, extra_roots)
        } else {
            Ok(Connection::Plain(stream))
        }
    }

    fn wrap_tls(host: &str, stream: TcpStream, extra_roots: &[Vec<u8>]) -> Result<Self, String> {
        let mut root_certs =
            rustls::RootCertStore::from_iter(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
        for cert_der in extra_roots {
            root_certs
                .add(cert_der.clone().into())
                .map_err(|e| format!("add root cert failed: {e}"))?;
        }
        let config = rustls::ClientConfig::builder()
            .with_root_certificates(root_certs)
            .with_no_client_auth();
        let server_name = rustls::pki_types::ServerName::try_from(host)
            .map_err(|e| format!("invalid server name: {e}"))?
            .to_owned();
        let tls_conn = rustls::ClientConnection::new(std::sync::Arc::new(config), server_name)
            .map_err(|e| format!("tls handshake failed: {e}"))?;
        let tls = rustls::StreamOwned::new(tls_conn, stream);
        Ok(Connection::Tls { tls })
    }

    pub fn set_nonblocking(&self, nonblocking: bool) -> Result<(), String> {
        match self {
            Connection::Plain(stream) => stream
                .set_nonblocking(nonblocking)
                .map_err(|e| format!("set_nonblocking failed: {e}")),
            Connection::Tls { tls } => tls
                .sock
                .set_nonblocking(nonblocking)
                .map_err(|e| format!("set_nonblocking failed: {e}")),
        }
    }

    pub fn raw_fd(&self) -> RawFd {
        match self {
            Connection::Plain(stream) => stream.as_raw_fd(),
            Connection::Tls { tls } => tls.sock.as_raw_fd(),
        }
    }

    pub fn is_tls(&self) -> bool {
        matches!(self, Connection::Tls { .. })
    }

    pub fn set_read_timeout(&self, dur: Option<Duration>) -> Result<(), String> {
        match self {
            Connection::Plain(stream) => stream
                .set_read_timeout(dur)
                .map_err(|e| format!("set_read_timeout failed: {e}")),
            Connection::Tls { tls } => tls
                .sock
                .set_read_timeout(dur)
                .map_err(|e| format!("set_read_timeout failed: {e}")),
        }
    }

    pub fn try_clone(&self) -> Result<Self, String> {
        match self {
            Connection::Plain(stream) => stream
                .try_clone()
                .map(Connection::Plain)
                .map_err(|e| format!("try_clone failed: {e}")),
            Connection::Tls { .. } => Err("cannot clone TLS connection".into()),
        }
    }
}

impl Read for Connection {
    /// Reads bytes from the connection.
    ///
    /// Plain connections read straight from the socket. TLS connections read
    /// decrypted plaintext through rustls.
    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
        let source: &mut dyn Read = match self {
            Connection::Plain(stream) => stream,
            Connection::Tls { tls } => tls,
        };
        source.read(buf)
    }
}

impl Write for Connection {
    /// Writes bytes to the connection.
    ///
    /// For TLS, the bytes are encrypted and buffered by rustls, which then makes
    /// a best-effort attempt to push the resulting records to the socket.
    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
        match self {
            Connection::Plain(stream) => stream.write(buf),
            Connection::Tls { tls } => tls.write(buf),
        }
    }

    /// Pushes any buffered output to the socket.
    ///
    /// The TLS variant must flush through the rustls stream, not the raw socket.
    /// rustls can hold encrypted records inside the `ClientConnection` that
    /// `write` could not send yet. `TcpStream::flush` is a no-op, so flushing
    /// only `tls.sock` would leave that ciphertext stuck until some later I/O.
    /// On a non-blocking socket this may return `ErrorKind::WouldBlock`, which
    /// means "retry once the descriptor is writable".
    fn flush(&mut self) -> io::Result<()> {
        match self {
            Connection::Plain(stream) => stream.flush(),
            Connection::Tls { tls } => tls.flush(),
        }
    }
}

impl AsRawFd for Connection {
    fn as_raw_fd(&self) -> RawFd {
        self.raw_fd()
    }
}