// reqrio 0.3.0-beta1
//
// A lightweight, high-performance, fingerprint-based HTTP request library.
// Documentation
mod connect;

use super::ext::TimeoutRW;
use crate::error::HlsResult;
use crate::stream::{ConnParam, StreamParam};
use crate::{Buffer, ClientConfig, HlsError, ProxyStream, ServerConfig};
use connect::{Connecting, Handshake};
use reqtls::{rand, Alert, Config, Connection, RecordType, StreamHandle, WriteExt, ALPN};
use std::io::Error;
use std::pin::Pin;
use std::task::{Context, Poll};
use std::time::Duration;
use std::{io, mem};
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
use tokio::net::TcpStream;


/// A TLS session layered over an underlying transport `S`.
///
/// Inbound bytes are accumulated in `read_buffer` until a whole record is
/// available; outbound records are encoded into `write_buffer` and then
/// written to `stream`.
pub struct TlsStream<S> {
    /// TLS connection state; used to encode outgoing records and exposed
    /// through `connection()` / `alpn()`.
    conn: Connection,
    /// The wrapped transport that raw record bytes are read from / written to.
    stream: S,
    /// Lent to the record handler through `stream_param`; starts out `false`.
    /// NOTE(review): presumably set once traffic keys are active — confirm in the handler.
    encrypted_channel: bool,
    /// Lent to the record handler through `stream_param`; starts out `false`.
    handshake_finished: bool,
    /// Lent to the record handler through `stream_param`; starts out `false`.
    /// NOTE(review): presumably tracks a HelloRetryRequest round — confirm in the handler.
    hello_retrying: bool,
    /// Raw bytes received from `stream` that have not been processed yet.
    read_buffer: Buffer,
    /// Encoded record bytes waiting to be written to `stream`.
    write_buffer: Buffer,
    /// Set by `poll_shutdown` once the close_notify alert has been flushed;
    /// `poll_read` reports end-of-stream from then on.
    shutdown_wrote: bool,
    /// Plaintext bytes encoded so far for the `poll_write` call in progress.
    wrote_len: usize,
    /// Indices of the 16384-byte chunks of the current `poll_write` buffer
    /// that have not been fully sent yet (front of the vector is next).
    pending: Vec<usize>,
    /// Bytes returned by `client_hello()`; empty at construction.
    /// NOTE(review): presumably filled in by the handshake code — confirm.
    client_hello: Vec<u8>,
}


impl<S: AsyncRead + AsyncWrite + Unpin> TlsStream<S> {
    /// Builds a not-yet-handshaken [`TlsStream`] around `stream` and wraps it
    /// in a [`Connecting`] value carrying `config`.
    ///
    /// `buffer` becomes the stream's initial write buffer; every flag starts
    /// cleared and every other buffer starts empty.
    fn _connect(stream: S, conn: Connection, config: Config<'_>, buffer: Buffer) -> Connecting<'_, S> {
        let tls = Box::new(Self {
            conn,
            stream,
            encrypted_channel: false,
            handshake_finished: false,
            hello_retrying: false,
            read_buffer: Buffer::default(),
            write_buffer: buffer,
            shutdown_wrote: false,
            wrote_len: 0,
            pending: Vec::new(),
            client_hello: Vec::new(),
        });
        Connecting {
            handshake: Handshake::Handshaking(tls),
            config,
            sent_client_hello: false,
        }
    }

    /// Starts a client-side handshake over `stream`.
    ///
    /// The connection is seeded with a random value, the session from
    /// `config` (or a default one when none is set), the key log taken out of
    /// `config`, and the configured verification setting.
    #[inline]
    pub fn connect(stream: S, mut config: ClientConfig<'_>) -> Connecting<'_, S> {
        let session = config.session.as_ref().cloned().unwrap_or_default();
        let conn = Connection::from_client(rand::random(), session, mem::take(&mut config.key_log))
            .with_verify(config.verify);
        Self::_connect(stream, conn, Config::Client(config), Buffer::default())
    }

    /// Starts a server-side handshake over `stream` using a default
    /// [`Connection`] and an empty write buffer.
    #[inline]
    pub fn accept(stream: S, config: ServerConfig<'_>) -> Connecting<'_, S> {
        let conn = Connection::default();
        let server_config = Config::Server(config);
        Self::_connect(stream, conn, server_config, Buffer::default())
    }

    /// Returns the ALPN value reported by the underlying connection, if any.
    pub fn alpn(&self) -> Option<&ALPN> {
        let conn = &self.conn;
        conn.alpn()
    }

    /// Returns the stored ClientHello bytes (empty if none were recorded).
    pub fn client_hello(&self) -> &[u8] {
        self.client_hello.as_slice()
    }
}

impl<S> StreamHandle for TlsStream<S> {
    /// Splits the stream into its read buffer and a [`StreamParam`] that
    /// mutably borrows the handshake flags, the write buffer and the
    /// connection state.
    #[inline]
    fn stream_param(&mut self) -> (&mut Buffer, StreamParam<'_>) {
        let Self {
            conn,
            encrypted_channel,
            handshake_finished,
            hello_retrying,
            read_buffer,
            write_buffer,
            ..
        } = self;
        let param = StreamParam {
            handshake_finish: handshake_finished,
            encrypted_channel,
            hello_retrying,
            write_buffer,
            conn,
        };
        (read_buffer, param)
    }
}

impl<S> TlsStream<S> {
    /// Borrows the TLS connection state held by this stream.
    pub fn connection(&self) -> &Connection {
        let Self { conn, .. } = self;
        conn
    }
}

impl<S: AsyncRead + Unpin> TlsStream<S> {
    /// Reads from the transport until `read_buffer` holds at least `max_size`
    /// bytes.
    ///
    /// Returns `Pending` when the transport has no data yet, and
    /// `PeerClosedConnection` when a read completes without delivering bytes.
    fn read_size(&mut self, max_size: usize, cx: &mut Context<'_>) -> Poll<HlsResult<()>> {
        loop {
            if self.read_buffer.len() >= max_size {
                return Poll::Ready(Ok(()));
            }
            self.read_buffer.check_move(max_size)?;
            let transport = Pin::new(&mut self.stream);
            let mut dst = ReadBuf::new(self.read_buffer.unfilled());
            if transport.poll_read(cx, &mut dst)?.is_pending() {
                return Poll::Pending;
            }
            let received = dst.filled().len();
            if received == 0 {
                // A completed read with no bytes means the peer closed.
                return Poll::Ready(Err(HlsError::PeerClosedConnection));
            }
            self.read_buffer.add_len(received);
        }
    }

    /// Ensures one complete record (5-byte header plus body) is buffered and
    /// returns its total length, header included.
    fn read_next_record(&mut self, cx: &mut Context<'_>) -> Poll<HlsResult<usize>> {
        const HEADER_LEN: usize = 5;

        if self.read_buffer.len() < HEADER_LEN {
            if self.read_size(HEADER_LEN, cx)?.is_pending() {
                return Poll::Pending;
            }
        }
        // Header bytes 3 and 4 carry the body length, big-endian.
        let header = self.read_buffer.filled();
        let body_len = usize::from(u16::from_be_bytes([header[3], header[4]]));
        let record_len = body_len + HEADER_LEN;
        match self.read_size(record_len, cx)? {
            Poll::Pending => Poll::Pending,
            Poll::Ready(()) => Poll::Ready(Ok(record_len)),
        }
    }
}

impl<S: AsyncWrite + Unpin> TlsStream<S> {
    /// Writes the whole of `write_buffer` to the transport, then resets it.
    ///
    /// Returns `Pending` if the transport cannot accept more bytes right now
    /// (already-written bytes stay consumed), and `PeerClosedConnection` if
    /// the transport accepts zero bytes.
    #[inline]
    fn write_buffer(&mut self, cx: &mut Context<'_>) -> Poll<HlsResult<()>> {
        let mut drained = false;
        while !drained {
            let transport = Pin::new(&mut self.stream);
            let sent = match transport.poll_write(cx, self.write_buffer.filled())? {
                Poll::Pending => return Poll::Pending,
                Poll::Ready(count) => count,
            };
            if sent == 0 {
                return Poll::Ready(Err(HlsError::PeerClosedConnection));
            }
            drained = self.write_buffer.used_empty(sent);
        }
        self.write_buffer.reset();
        Poll::Ready(Ok(()))
    }
}

impl<S: AsyncRead + Unpin> AsyncRead for TlsStream<S> {
    /// Reads records from the transport until one yields application bytes,
    /// and appends those bytes to `buf`.
    ///
    /// Reports end-of-stream (no bytes added) once this side has sent its
    /// shutdown alert. Records that produce no output are consumed and the
    /// loop moves on to the next one.
    fn poll_read(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut ReadBuf<'_>) -> Poll<io::Result<()>> {
        let stream = self.get_mut();
        if stream.shutdown_wrote {
            return Poll::Ready(Ok(()));
        }
        loop {
            let record_len = match stream.read_next_record(cx)? {
                Poll::Ready(len) => len,
                Poll::Pending => return Poll::Pending,
            };
            // Decode into the *unfilled* tail of `buf`: bytes the caller has
            // already filled must not be overwritten, and a `ReadBuf` built
            // over uninitialised memory has an empty initialised region, so
            // the tail is initialised on demand.
            let len = stream.handle_record(record_len, None, buf.initialize_unfilled())?;
            if len == 0 {
                continue;
            }
            // Move the filled cursor forward rather than resetting it to `len`.
            buf.advance(len);
            return Poll::Ready(Ok(()));
        }
    }
}

impl<S: AsyncWrite + Unpin> AsyncWrite for TlsStream<S> {
    fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<Result<usize, Error>> {
        let stream = self.get_mut();
        let chucks = buf.chunks(16384).collect::<Vec<_>>();
        if stream.pending.is_empty() {
            stream.wrote_len = 0;
            stream.pending = (0..chucks.len()).collect();
        }
        loop {
            if stream.pending.is_empty() { break; }
            if stream.write_buffer.is_empty() {
                let record_len = stream.conn.make_message(RecordType::ApplicationData, stream.write_buffer.unfilled(), chucks[stream.pending[0]])?;
                stream.write_buffer.add_len(record_len);
                stream.wrote_len += chucks[stream.pending[0]].len();
            }
            match stream.write_buffer(cx)? {
                Poll::Ready(_) => stream.pending.remove(0),
                Poll::Pending => return Poll::Pending,
            };
        }
        assert_eq!(stream.wrote_len, buf.len());
        Poll::Ready(Ok(stream.wrote_len))
    }

    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
        Pin::new(&mut self.stream).poll_flush(cx)
    }

    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
        let stream = self.get_mut();
        if stream.write_buffer.is_empty() {
            let len = stream.conn.make_message(RecordType::Alert, &mut stream.write_buffer.unfilled(), &Alert::close_notify().to_bytes())?;
            stream.write_buffer.add_len(len);
        }
        match stream.shutdown_wrote {
            true => Pin::new(&mut stream.stream).poll_shutdown(cx),
            false => match stream.write_buffer(cx)? {
                Poll::Ready(_) => {
                    stream.shutdown_wrote = true;
                    Pin::new(&mut stream.stream).poll_shutdown(cx)
                }
                Poll::Pending => Poll::Pending,
            }
        }
    }
}

/// A client TLS stream over a (possibly proxied) TCP connection, bundled with
/// the read and write timeouts reported through `TimeoutRW`.
pub struct TlsStreamA {
    /// The TLS stream produced by the handshake.
    stream: TlsStream<ProxyStream<TcpStream>>,
    /// Read timeout returned by `TimeoutRW::read_timeout`.
    read_timeout: Option<Duration>,
    /// Write timeout returned by `TimeoutRW::write_timeout`.
    write_timeout: Option<Duration>,
}

impl TlsStreamA {
    /// Performs the client TLS handshake over `tcp`, bounded by the connect
    /// timeout taken from `param`.
    ///
    /// The read and write timeouts from `param` are stored on the returned
    /// stream.
    ///
    /// # Errors
    ///
    /// Fails if the connect timeout elapses first or if the handshake itself
    /// fails.
    pub async fn connect_timeout(param: ConnParam<'_>, tcp: ProxyStream<TcpStream>) -> HlsResult<TlsStreamA> {
        let connect_limit = param.timeout.connect();
        let read_limit = param.timeout.read();
        let write_limit = param.timeout.write();

        let handshake = TlsStream::connect(tcp, ClientConfig::from(param));
        let stream = tokio::time::timeout(connect_limit, handshake).await??;

        Ok(Self {
            stream,
            read_timeout: Some(read_limit),
            write_timeout: Some(write_limit),
        })
    }

    /// Returns the ALPN value reported by the wrapped stream, if any.
    pub fn alpn(&self) -> Option<&ALPN> {
        self.get_ref().alpn()
    }

    /// Borrows the wrapped TLS stream.
    pub fn get_ref(&self) -> &TlsStream<ProxyStream<TcpStream>> {
        let Self { stream, .. } = self;
        stream
    }
}

impl TimeoutRW<TlsStream<ProxyStream<TcpStream>>> for TlsStreamA {
    /// Mutable access to the wrapped TLS stream.
    fn stream(&mut self) -> &mut TlsStream<ProxyStream<TcpStream>> {
        let Self { stream, .. } = self;
        stream
    }

    /// The stored read timeout.
    fn read_timeout(&self) -> Option<Duration> {
        let Self { read_timeout, .. } = self;
        *read_timeout
    }

    /// The stored write timeout.
    fn write_timeout(&self) -> Option<Duration> {
        let Self { write_timeout, .. } = self;
        *write_timeout
    }
}