pipa-js 0.1.1

A fast, minimal ES2023 JavaScript runtime built in Rust.
Documentation
use std::collections::VecDeque;
use std::io::{ErrorKind, Read, Write};
use std::os::unix::io::RawFd;
use std::sync::mpsc;

use crate::http::conn::Connection;
use crate::http::headers::Headers;
use crate::http::status::HttpStatus;
use crate::http::url::Url;
use crate::http::ws::frame::{OpCode, WsFrame};
use crate::http::ws::handshake::WsHandshake;

/// Lifecycle stage of a [`WsConnection`], advanced by `WsConnection::try_advance`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum WsState {
    /// Waiting for the connect thread to deliver the established connection.
    Connecting,
    /// Sending the upgrade request and waiting for the server's response.
    Handshake,
    /// Handshake accepted; frames are exchanged in both directions.
    Open,
    /// A close frame is queued; remaining output is flushed before reporting the close.
    Closing,
    /// Finished; `try_advance` produces no further events.
    Closed,
}

/// Something that happened on a [`WsConnection`] and should be reported to the caller.
#[derive(Clone, Debug)]
pub enum WsEvent {
    /// The opening handshake was accepted.
    Open,
    /// A data message: the payload bytes and `true` when it arrived as a text frame.
    Message(Vec<u8>, bool),
    /// The connection closed with the given status code and reason text.
    Close(u16, String),
    /// An error, described as text.
    Error(String),
}

/// Client-side WebSocket connection driven as a non-blocking state machine.
///
/// Outgoing frames are queued with `send_text` / `send_binary` / `close` and
/// written by `try_advance`, which also reads and decodes incoming frames.
pub struct WsConnection {
    /// URL this connection targets.
    pub url: Url,
    /// Current lifecycle stage.
    pub state: WsState,
    /// Numeric ready state kept alongside `state`: 0 connecting, 1 open,
    /// 2 closing, 3 closed (presumably mirroring JS `WebSocket.readyState`).
    pub ready_state: u8,

    /// Channel on which the connect thread delivers the connection (or an error).
    connect_rx: Option<mpsc::Receiver<Result<Connection, String>>>,
    /// Established socket; `None` until the connect thread has delivered it.
    conn: Option<Connection>,
    /// Handshake key generated for this connection; checked against the
    /// server's accept value.
    key: String,

    /// Bytes currently being written to the socket.
    write_buf: Vec<u8>,
    /// How many bytes of `write_buf` have already been written.
    write_pos: usize,
    /// Frames waiting to be encoded into `write_buf`.
    pending_frames: VecDeque<WsFrame>,

    /// Scratch buffer for a single socket read.
    read_buf: [u8; 8192],
    /// Received bytes not yet consumed (handshake response or partial frames).
    read_data: Vec<u8>,

    /// Status code reported in the final `WsEvent::Close`.
    close_code: u16,
    /// Reason text reported in the final `WsEvent::Close`.
    close_reason: String,
}

impl WsConnection {
    pub fn new(url: Url) -> Self {
        WsConnection {
            url,
            state: WsState::Connecting,
            conn: None,
            connect_rx: None,
            write_buf: Vec::new(),
            write_pos: 0,
            read_buf: [0u8; 8192],
            read_data: Vec::new(),
            key: String::new(),
            pending_frames: VecDeque::new(),
            close_code: 0,
            close_reason: String::new(),
            ready_state: 0,
        }
    }

    /// Returns the raw file descriptor of the underlying socket, or `None`
    /// while no connection has been established yet.
    pub fn fd(&self) -> Option<RawFd> {
        let conn = self.conn.as_ref()?;
        Some(conn.raw_fd())
    }

    /// Installs the channel on which the connect thread delivers the
    /// established [`Connection`] (or an error message), and puts the state
    /// back to `Connecting` so that `try_advance` polls it.
    pub fn set_connect_rx(&mut self, rx: mpsc::Receiver<Result<Connection, String>>) {
        self.state = WsState::Connecting;
        self.connect_rx = Some(rx);
    }

    /// Drives the WebSocket state machine as far as it can go without blocking.
    ///
    /// Depending on the current [`WsState`] a call picks up the connection from
    /// the connect thread, sends the upgrade request, validates the server's
    /// response, exchanges frames, or finishes the closing sequence. It stops
    /// as soon as an event is ready or the socket reports `WouldBlock`.
    ///
    /// # Returns
    ///
    /// * `Ok(Some(event))` — an event for the caller (open, message, close).
    /// * `Ok(None)` — nothing to report yet; call again when the socket is
    ///   ready (see [`Self::wants_read`] / [`Self::wants_write`]).
    ///
    /// # Errors
    ///
    /// Returns a description when no connect receiver was installed, the
    /// connect thread fails or disconnects, the handshake response is rejected,
    /// frame parsing fails, or a socket read/write fails with anything other
    /// than `WouldBlock`.
    pub fn try_advance(&mut self) -> Result<Option<WsEvent>, String> {
        loop {
            match self.state {
                WsState::Connecting => {
                    let received = self
                        .connect_rx
                        .as_ref()
                        .ok_or_else(|| String::from("ws connect receiver not set"))?
                        .try_recv();
                    match received {
                        Ok(result) => {
                            let conn = result?;
                            conn.set_nonblocking(true)?;
                            self.conn = Some(conn);
                            self.key = WsHandshake::generate_key();
                            self.build_handshake_request();
                            self.state = WsState::Handshake;
                        }
                        Err(mpsc::TryRecvError::Empty) => return Ok(None),
                        Err(mpsc::TryRecvError::Disconnected) => {
                            return Err("connect thread disconnected".into());
                        }
                    }
                }

                WsState::Handshake => {
                    // Finish sending the upgrade request before reading the reply.
                    if self.write_pos < self.write_buf.len() {
                        let conn = self
                            .conn
                            .as_mut()
                            .expect("connection is set before entering Handshake");
                        match conn.write(&self.write_buf[self.write_pos..]) {
                            Ok(n) => {
                                self.write_pos += n;
                                if self.write_pos < self.write_buf.len() {
                                    return Ok(None);
                                }
                            }
                            Err(e) if e.kind() == ErrorKind::WouldBlock => return Ok(None),
                            Err(e) => return Err(format!("ws handshake write: {e}")),
                        }
                    }
                    self.write_buf.clear();
                    self.write_pos = 0;

                    let conn = self
                        .conn
                        .as_mut()
                        .expect("connection is set before entering Handshake");
                    match conn.read(&mut self.read_buf) {
                        Ok(0) => return Err("connection closed during handshake".into()),
                        Ok(n) => {
                            self.read_data.extend_from_slice(&self.read_buf[..n]);
                            if let Some(pos) =
                                self.read_data.windows(4).position(|w| w == b"\r\n\r\n")
                            {
                                let header_end = pos + 4;
                                let (headers, _) =
                                    Headers::from_bytes(&self.read_data[..header_end])?;
                                let status_line_end = self
                                    .read_data
                                    .windows(2)
                                    .position(|w| w == b"\r\n")
                                    .unwrap_or(0);
                                let status_line = &self.read_data[..status_line_end];
                                // Expects "HTTP/1.1 NNN ..."; the code is bytes 9..12.
                                let code = if status_line.len() >= 12 {
                                    String::from_utf8_lossy(&status_line[9..12])
                                        .parse::<u16>()
                                        .unwrap_or(0)
                                } else {
                                    0
                                };
                                let accept =
                                    WsHandshake::validate_response(HttpStatus(code), &headers)?;
                                if !WsHandshake::verify_accept(&self.key, &accept) {
                                    return Err("WebSocket accept mismatch".into());
                                }
                                // The server may send its first frames in the same
                                // segment as the 101 response: drop only the header
                                // bytes so those frames are handled in the Open state.
                                self.read_data.drain(..header_end);
                                self.state = WsState::Open;
                                self.ready_state = 1;
                                return Ok(Some(WsEvent::Open));
                            }
                        }
                        Err(e) if e.kind() == ErrorKind::WouldBlock => return Ok(None),
                        Err(e) => return Err(format!("ws handshake read: {e}")),
                    }
                }

                WsState::Open => {
                    // Deliver frames that are already buffered (e.g. several frames
                    // from one read, or data that followed the handshake) first.
                    if let Some(event) = self.next_buffered_event()? {
                        return Ok(Some(event));
                    }
                    if self.state != WsState::Open {
                        // A close frame was received: finish the closing sequence.
                        continue;
                    }

                    // A blocked write must not stop us from reading; the peer may
                    // be waiting for us to drain its data before it reads ours.
                    self.flush_writes()
                        .map_err(|e| format!("ws write error: {e}"))?;

                    let conn = self
                        .conn
                        .as_mut()
                        .expect("connection is set before entering Open");
                    match conn.read(&mut self.read_buf) {
                        Ok(0) => {
                            self.state = WsState::Closed;
                            self.ready_state = 3;
                            return Ok(Some(WsEvent::Close(1006, "connection closed".into())));
                        }
                        // Loop back so the new bytes are parsed into frames.
                        Ok(n) => self.read_data.extend_from_slice(&self.read_buf[..n]),
                        Err(e) if e.kind() == ErrorKind::WouldBlock => return Ok(None),
                        Err(e) => return Err(format!("ws read error: {e}")),
                    }
                }

                WsState::Closing => {
                    // Send everything still queued, including the close frame,
                    // before reporting the close. Write failures are tolerated
                    // (best effort): the connection is being torn down anyway.
                    if let Ok(false) = self.flush_writes() {
                        // Socket is full; retry once it becomes writable.
                        return Ok(None);
                    }
                    if let Some(conn) = self.conn.as_mut() {
                        let _ = conn.read(&mut self.read_buf);
                    }
                    self.state = WsState::Closed;
                    self.ready_state = 3;
                    return Ok(Some(WsEvent::Close(
                        self.close_code,
                        self.close_reason.clone(),
                    )));
                }

                WsState::Closed => return Ok(None),
            }
        }
    }

    /// Writes queued outbound bytes until everything is sent or the socket
    /// would block.
    ///
    /// The next frame is taken from `pending_frames` only after the previous
    /// frame's bytes were written completely, so a partial write is never
    /// overwritten. Returns `Ok(true)` when nothing is left to send and
    /// `Ok(false)` when the socket reported `WouldBlock`.
    fn flush_writes(&mut self) -> Result<bool, String> {
        loop {
            if self.write_pos >= self.write_buf.len() {
                self.write_buf.clear();
                self.write_pos = 0;
                match self.pending_frames.pop_front() {
                    Some(frame) => self.write_buf = frame.encode(),
                    None => return Ok(true),
                }
            }
            let conn = self
                .conn
                .as_mut()
                .ok_or_else(|| String::from("no connection"))?;
            match conn.write(&self.write_buf[self.write_pos..]) {
                // A zero-length write on a non-empty buffer would otherwise spin forever.
                Ok(0) => return Err("write returned 0 bytes".into()),
                Ok(n) => self.write_pos += n,
                Err(e) if e.kind() == ErrorKind::WouldBlock => return Ok(false),
                Err(e) => return Err(e.to_string()),
            }
        }
    }

    /// Consumes complete frames buffered in `read_data` until one of them
    /// produces an event for the caller.
    ///
    /// Only the bytes of frames that were actually handled are removed, so
    /// frames that arrive together are delivered one per call instead of
    /// being dropped. Pings are answered by queuing a pong. A close frame is
    /// echoed, records the close code/reason and moves the connection to
    /// `Closing`; in that case `Ok(None)` is returned.
    fn next_buffered_event(&mut self) -> Result<Option<WsEvent>, String> {
        if self.read_data.is_empty() {
            return Ok(None);
        }
        let frames = WsFrame::parse_all(&self.read_data)?;
        let mut consumed = 0;
        let mut event = None;
        for frame in frames {
            consumed += self.calculate_consumed(std::slice::from_ref(&frame));
            match frame.opcode {
                OpCode::Text | OpCode::Binary => {
                    let is_text = frame.opcode == OpCode::Text;
                    event = Some(WsEvent::Message(frame.payload, is_text));
                    break;
                }
                OpCode::Ping => {
                    self.pending_frames
                        .push_back(WsFrame::new_pong(frame.payload));
                }
                OpCode::Close => {
                    let (code, reason) = Self::parse_close_payload(&frame.payload);
                    // 1005 ("no status received") must not be sent in a Close
                    // frame (RFC 6455 §7.4.1), so echo a normal closure instead.
                    let echo_code = if code == 1005 { 1000 } else { code };
                    self.pending_frames
                        .push_back(WsFrame::new_close(echo_code, &reason));
                    self.close_code = code;
                    self.close_reason = reason;
                    self.state = WsState::Closing;
                    self.ready_state = 2;
                    break;
                }
                // NOTE(review): continuation frames are ignored, so fragmented
                // messages are not reassembled.
                OpCode::Pong | OpCode::Continuation => {}
            }
        }
        self.read_data.drain(..consumed);
        Ok(event)
    }

    pub fn send_text(&mut self, data: &str) {
        let frame = WsFrame::new_text(data.as_bytes().to_vec());
        self.pending_frames.push_back(frame);
    }

    pub fn send_binary(&mut self, data: &[u8]) {
        let frame = WsFrame::new_binary(data.to_vec());
        self.pending_frames.push_back(frame);
    }

    /// Starts a client-initiated close.
    ///
    /// Only has an effect while the connection is `Open`. In that case it
    /// queues a close frame with `code` and `reason`, moves to `Closing`
    /// (`ready_state` 2), and remembers the code/reason for the eventual
    /// `WsEvent::Close`. In every other state the call is ignored.
    pub fn close(&mut self, code: u16, reason: &str) {
        if self.state != WsState::Open {
            return;
        }
        self.close_code = code;
        self.close_reason = reason.to_string();
        self.state = WsState::Closing;
        self.ready_state = 2;
        self.pending_frames.push_back(WsFrame::new_close(code, reason));
    }

    /// Reports whether the connection currently wants read readiness.
    ///
    /// True while connecting, handshaking or open; false once closing has begun.
    pub fn wants_read(&self) -> bool {
        match self.state {
            WsState::Connecting | WsState::Handshake | WsState::Open => true,
            WsState::Closing | WsState::Closed => false,
        }
    }

    /// Reports whether there is outbound data that `try_advance` would write now.
    ///
    /// During the handshake only the upgrade request in `write_buf` is written.
    /// Frames queued by `send_*` wait in `pending_frames` until the connection
    /// is open, so they must not count in that state. Otherwise an event loop
    /// would be woken for writability over and over with nothing to write.
    pub fn wants_write(&self) -> bool {
        let has_unsent_bytes = !self.write_buf.is_empty();
        match self.state {
            WsState::Handshake => has_unsent_bytes,
            WsState::Open | WsState::Closing => {
                has_unsent_bytes || !self.pending_frames.is_empty()
            }
            WsState::Connecting | WsState::Closed => false,
        }
    }

    /// Serialises the HTTP/1.1 upgrade request into `write_buf` and resets
    /// `write_pos` so the whole request is sent from the start.
    ///
    /// The `Host` value carries an explicit port only when it is neither 80
    /// nor 443. A default `User-Agent` is added when the handshake headers
    /// do not already contain one.
    fn build_handshake_request(&mut self) {
        let port = self.url.port;
        // NOTE(review): 80/443 are omitted regardless of scheme; ws:// on 443
        // or wss:// on 80 would strictly need the explicit port.
        let host = if port == 80 || port == 443 {
            self.url.host.clone()
        } else {
            format!("{}:{}", self.url.host, port)
        };
        let path = self.url.request_target();

        let mut headers = WsHandshake::build_request(&host, &path, &self.key);
        if !headers.contains("user-agent") {
            headers.set("User-Agent", "pipa/0.1");
        }

        let mut buf = Vec::new();
        buf.extend_from_slice(b"GET ");
        buf.extend_from_slice(path.as_bytes());
        buf.extend_from_slice(b" HTTP/1.1\r\n");
        buf.extend_from_slice(headers.to_request_bytes().as_ref());
        // Blank line terminating the header section.
        buf.extend_from_slice(b"\r\n");

        self.write_buf = buf;
        self.write_pos = 0;
    }

    /// Splits a close-frame payload into its status code and reason text.
    ///
    /// The first two bytes are the big-endian status code. The remaining
    /// bytes are the reason, decoded lossily as UTF-8. Payloads shorter than
    /// two bytes yield `(1005, "")`, 1005 being the "no status received" code.
    fn parse_close_payload(payload: &[u8]) -> (u16, String) {
        match payload {
            [hi, lo, reason @ ..] => (
                u16::from_be_bytes([*hi, *lo]),
                String::from_utf8_lossy(reason).into_owned(),
            ),
            _ => (1005, String::new()),
        }
    }

    /// Reconstructs how many bytes of `read_data` the given frames occupied.
    ///
    /// Each frame is sized as a 2-byte base header, plus a 2-byte extended
    /// length for payloads of 126..=65535 bytes or an 8-byte one above that,
    /// plus 4 bytes when a mask is present, plus the payload itself. This only
    /// matches the wire size when the sender used the minimal length encoding.
    fn calculate_consumed(&self, frames: &[WsFrame]) -> usize {
        frames
            .iter()
            .map(|frame| {
                let body = frame.payload.len();
                let extended_len = match body {
                    0..=125 => 0,
                    126..=0xFFFF => 2,
                    _ => 8,
                };
                let mask_len = if frame.mask.is_some() { 4 } else { 0 };
                2 + extended_len + mask_len + body
            })
            .sum()
    }
}