specters 3.0.0

HTTP client with full TLS, HTTP/2, and HTTP/3 fingerprint control
Documentation
use std::future::Future;
use std::time::Duration;

use bytes::{Bytes, BytesMut};
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::time::timeout as tokio_timeout;
use url::Url;

use crate::transport::connector::MaybeHttpsStream;
use crate::websocket::error::{WebSocketError, WebSocketResult};
use crate::websocket::frame::{decode_frame, encode_frame, FrameConfig, FrameDecoder, OpCode};
use crate::websocket::message::{CloseFrame, Message};
use crate::websocket::WebSocketConfig;

#[derive(Debug)]
pub struct WebSocket {
    stream: MaybeHttpsStream,
    url: Url,
    protocol: Option<String>,
    read_buffer: BytesMut,
    frame_config: FrameConfig,
    read_timeout: Option<Duration>,
    write_timeout: Option<Duration>,
    decoder: FrameDecoder,
    close_sent: bool,
    close_received: bool,
}

impl WebSocket {
    pub(crate) fn new(
        stream: MaybeHttpsStream,
        url: Url,
        protocol: Option<String>,
        config: WebSocketConfig,
        initial_read_buffer: Bytes,
    ) -> Self {
        Self {
            stream,
            url,
            protocol,
            read_buffer: BytesMut::from(&initial_read_buffer[..]),
            frame_config: FrameConfig::new(config.max_frame_size, config.max_message_size),
            read_timeout: config.read_timeout,
            write_timeout: config.write_timeout,
            decoder: FrameDecoder::new(),
            close_sent: false,
            close_received: false,
        }
    }

    pub fn url(&self) -> &Url {
        &self.url
    }

    pub fn protocol(&self) -> Option<&str> {
        self.protocol.as_deref()
    }

    pub async fn send(&mut self, msg: Message) -> WebSocketResult<()> {
        if self.close_sent && !matches!(msg, Message::Close(_)) {
            return Err(WebSocketError::protocol(
                &self.url,
                "cannot send data after close frame",
            ));
        }

        match msg {
            Message::Text(text) => self.write_frame(OpCode::Text, text.as_bytes()).await,
            Message::Binary(bytes) => self.write_frame(OpCode::Binary, &bytes).await,
            Message::Ping(bytes) => self.write_control(OpCode::Ping, &bytes).await,
            Message::Pong(bytes) => self.write_control(OpCode::Pong, &bytes).await,
            Message::Close(frame) => self.close(frame).await,
        }
    }

    pub async fn send_text(&mut self, text: impl Into<String>) -> WebSocketResult<()> {
        self.send(Message::Text(text.into())).await
    }

    pub async fn send_binary(&mut self, bytes: impl Into<Bytes>) -> WebSocketResult<()> {
        self.send(Message::Binary(bytes.into())).await
    }

    pub async fn next(&mut self) -> WebSocketResult<Option<Message>> {
        loop {
            let frame = match decode_frame(&self.url, &mut self.read_buffer, self.frame_config) {
                Ok(frame) => frame,
                Err(error) => return Err(self.best_effort_close_for_error(error).await),
            };

            if let Some(frame) = frame {
                let message = match self
                    .decoder
                    .decode_message(&self.url, frame, self.frame_config)
                {
                    Ok(message) => message,
                    Err(error) => return Err(self.best_effort_close_for_error(error).await),
                };

                match message {
                    Some(Message::Ping(payload)) => {
                        if !self.close_received {
                            self.write_control(OpCode::Pong, &payload).await?;
                        }
                        return Ok(Some(Message::Ping(payload)));
                    }
                    Some(Message::Close(frame)) => {
                        self.close_received = true;
                        if !self.close_sent {
                            self.send_close_raw(frame.clone()).await?;
                        }
                        return Ok(None);
                    }
                    Some(other) => return Ok(Some(other)),
                    None => {}
                }
            } else {
                let mut scratch = [0_u8; 8192];
                let n = Self::io_with_timeout(
                    &self.url,
                    self.read_timeout,
                    "read",
                    self.stream.read(&mut scratch),
                )
                .await?;
                if n == 0 {
                    return if self.close_sent || self.close_received {
                        Ok(None)
                    } else {
                        Err(WebSocketError::connection_closed(&self.url))
                    };
                }
                self.read_buffer.extend_from_slice(&scratch[..n]);
            }
        }
    }

    pub async fn close(&mut self, frame: Option<CloseFrame>) -> WebSocketResult<()> {
        if !self.close_sent {
            self.send_close_raw(frame).await?;
        }
        Ok(())
    }

    async fn write_frame(&mut self, opcode: OpCode, payload: &[u8]) -> WebSocketResult<()> {
        if payload.len() > self.frame_config.max_frame_size {
            return Err(WebSocketError::limit_exceeded(
                &self.url,
                format!("frame exceeds {} bytes", self.frame_config.max_frame_size),
            ));
        }
        if matches!(opcode, OpCode::Text | OpCode::Binary)
            && payload.len() > self.frame_config.max_message_size
        {
            return Err(WebSocketError::limit_exceeded(
                &self.url,
                format!(
                    "message exceeds {} bytes",
                    self.frame_config.max_message_size
                ),
            ));
        }
        let bytes = encode_frame(opcode, payload, true)?;
        Self::io_with_timeout(
            &self.url,
            self.write_timeout,
            "write",
            self.stream.write_all(&bytes),
        )
        .await?;
        Self::io_with_timeout(&self.url, self.write_timeout, "flush", self.stream.flush()).await
    }

    async fn write_control(&mut self, opcode: OpCode, payload: &[u8]) -> WebSocketResult<()> {
        if payload.len() > 125 {
            return Err(WebSocketError::protocol(
                &self.url,
                "control frame payload exceeds 125 bytes",
            ));
        }
        self.write_frame(opcode, payload).await
    }

    async fn send_close_raw(&mut self, frame: Option<CloseFrame>) -> WebSocketResult<()> {
        let payload = match frame {
            Some(frame) => frame.encode(&self.url)?,
            None => Vec::new(),
        };
        self.write_control(OpCode::Close, &payload).await?;
        self.close_sent = true;
        Ok(())
    }

    async fn best_effort_close_for_error(&mut self, error: WebSocketError) -> WebSocketError {
        if let Some(code) = error.close_code() {
            if !self.close_sent {
                let frame = CloseFrame {
                    code,
                    reason: String::new(),
                };
                let _ = self.send_close_raw(Some(frame)).await;
            }
        }
        error
    }

    async fn io_with_timeout<T, F>(
        url: &Url,
        timeout: Option<Duration>,
        operation: &'static str,
        future: F,
    ) -> WebSocketResult<T>
    where
        F: Future<Output = std::io::Result<T>>,
    {
        let result = match timeout {
            Some(duration) => {
                tokio_timeout(duration, future)
                    .await
                    .map_err(|_| WebSocketError::Timeout {
                        url: url.to_string(),
                        operation: format!("{operation} after {:?}", duration),
                    })?
            }
            None => future.await,
        };

        result.map_err(|error| WebSocketError::io(url, error))
    }
}