reqrio 0.2.0

A lightweight, high concurrency HTTP request library
Documentation
use std::io;
use crate::stream::config::{ClientConfig, Config, ServerConfig};
use crate::*;
use std::io::Error;
use std::pin::Pin;
use std::task::{Context, Poll};
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt, ReadBuf};
use crate::stream::TlsStreamHandle;

pub struct TlsStream<S> {
    conn: Connection,
    stream: S,
    handshake_finished: bool,
    read_buffer: Buffer,
    write_buffer: Buffer,
    shutdown_wrote: bool,
    wrote_len: usize,
    pending: Vec<usize>,
    client_hello: Vec<u8>,
}

impl<S: AsyncRead + AsyncWrite + Unpin> TlsStream<S> {
    async fn new(stream: S, conn: Connection, mut config: Config<'_>, buffer: Buffer) -> HlsResult<TlsStream<S>> {
        let mut stream = TlsStream {
            stream,
            conn,
            handshake_finished: false,
            read_buffer: Buffer::default(),
            write_buffer: buffer,
            shutdown_wrote: false,
            wrote_len: 0,
            pending: vec![],
            client_hello: vec![],
        };
        loop {
            let record_len = stream.read_packet().await?;
            let hello_done = stream.handle_message(Some(&mut config)).await?;
            stream.read_buffer.move_to(record_len..stream.read_buffer.len(), 0);
            if hello_done { break; }
        }
        Ok(stream)
    }
    pub async fn connect(mut stream: S, mut config: ClientConfig<'_>) -> HlsResult<TlsStream<S>> {
        let mut write_buffer = Buffer::default();
        let conn = Self::handle_client_hello(&mut config, &mut write_buffer)?;
        stream.write_all(write_buffer.filled()).await?;
        write_buffer.reset();
        TlsStream::new(stream, conn, Config::Client(config), write_buffer).await
    }

    pub async fn accept(stream: S, config: ServerConfig<'_>) -> HlsResult<TlsStream<S>> {
        TlsStream::new(stream, Connection::default(), Config::Server(config), Buffer::default()).await
    }

    pub async fn read_packet(&mut self) -> HlsResult<usize> {
        let record_len = match self.read_buffer.len() < 5 {
            true => {
                self.read_buffer.async_read(&mut self.stream).await?;
                u16::from_be_bytes([self.read_buffer[3], self.read_buffer[4]]) as usize
            }
            false => u16::from_be_bytes([self.read_buffer[3], self.read_buffer[4]]) as usize,
        } + 5;
        while self.read_buffer.len() < record_len {
            self.read_buffer.async_read(&mut self.stream).await?;
        }
        if !self.handshake_finished && self.read_buffer[0] == 22 { self.conn.update_session(&self.read_buffer[5..record_len])?; }
        Ok(record_len)
    }

    async fn handle_message(&mut self, mut config: Option<&mut Config<'_>>) -> HlsResult<bool> {
        let record = RecordLayer::from_bytes(self.read_buffer.filled_mut(), self.handshake_finished, Some(self.conn.cipher_suite()))?;
        match record.context_type {
            RecordType::CipherSpec => self.handshake_finished = true,
            RecordType::Alert => {
                let record_len = record.len as usize + 5;
                return Err(self.handle_by_alert(self.handshake_finished, record_len)?.into());
            }
            RecordType::HandShake => {
                for message in record.messages {
                    match message {
                        Message::ServerHello(v) => self.conn.set_by_server_hello(&v)?,
                        Message::Certificate(v) => {
                            let config = config.as_mut().ok_or("config can't be null")?;
                            let config = config.client_mut().ok_or("missing config")?;
                            self.conn.set_by_certificate(v, config.ca_certs, config.sni)?;
                        }
                        Message::ServerKeyExchange(v) => self.conn.set_by_server_exchange_key(v)?,
                        Message::ServerHelloDone(_) => {
                            self.handle_by_server_hello_done(config)?;
                            self.stream.write_all(self.write_buffer.filled()).await?;
                            self.write_buffer.reset();
                            return Ok(true);
                        }
                        Message::ClientHello(v) => {
                            let len = record.len as usize + 5;
                            let config = config.as_mut().ok_or("config can't be null")?;
                            let random = rand::random::<[u8; 32]>();
                            let server = config.server_mut().ok_or("missing config")?;
                            let mut record = self.conn.gen_server_hello(v, server.server_cert, server.cert_key, &random)?;
                            let session_id = rand::random::<[u8; 32]>();
                            record.messages[0].server_mut().ok_or(HlsError::NullPointer)?.set_session_id(&session_id);

                            record.write_to(&mut self.write_buffer, 1)?;
                            self.conn.update_session(&self.write_buffer.filled()[5..])?;
                            self.stream.write_all(self.write_buffer.filled()).await?;
                            self.client_hello.extend_from_slice(self.read_buffer[..len].as_ref());
                            self.write_buffer.reset();
                            break;
                        }
                        Message::ClientKeyExchange(v) => {
                            self.conn.set_by_client_exchange_key(v);
                            self.conn.make_cipher(true)?;
                        }
                        Message::Payload(_) => {
                            let record_len = record.len as usize + 5;
                            let mut out = vec![0; record_len];
                            let len = self.conn.read_message(&self.read_buffer[..record_len], &mut out)?;
                            self.conn.verify_finish(&out[..len], false)?;

                            let mut ticket = SessionTicket::default();
                            let tbs = rand::random::<[u8; 276]>();
                            ticket.tls_ticket_mut().set_value(&tbs);
                            self.write_buffer.write_slice(&[22, 3, 3]);
                            self.write_buffer.write_u16(ticket.len() as u16);
                            ticket.write_to(&mut self.write_buffer);
                            self.conn.update_session(&self.write_buffer.filled()[5..])?;
                            self.write_buffer.write_slice(&[20, 3, 3, 0, 1, 1]);
                            let record_len = self.conn.make_finish_message(self.write_buffer.unfilled_mut(), true)?;
                            self.write_buffer.add_len(record_len);
                            self.stream.write_all(self.write_buffer.filled()).await?;
                            self.write_buffer.reset();
                            return Ok(true);
                        }
                        Message::CertificateRequest(v) => {
                            let config = config.as_mut().ok_or("config can't be null")?;
                            let config = config.client_mut().ok_or("missing config")?;
                            self.conn.set_by_cert_req(v, config.client_cert.first_mut())?;
                        }
                        _ => {}
                    }
                }
            }
            RecordType::ApplicationData => {}
        }
        Ok(false)
    }

    pub fn alpn(&self) -> Option<&str> {
        Some(self.conn.alpn()?.value())
    }

    pub fn client_hello(&self) -> &[u8] { &self.client_hello }
}

impl<S> TlsStreamHandle for TlsStream<S> {
    fn conn_wbuf(&mut self) -> (&mut Connection, &mut Buffer) {
        (&mut self.conn, &mut self.write_buffer)
    }

    fn conn_rbuf(&mut self) -> (&mut Connection, &mut Buffer) {
        (&mut self.conn, &mut self.read_buffer)
    }
}

impl<S> TlsStream<S> {
    fn read_message(&mut self, buf: &mut ReadBuf<'_>) -> io::Result<usize> {
        let record = RecordLayer::from_bytes(self.read_buffer.filled_mut(), self.handshake_finished, None)?;
        let record_len = record.len as usize + 5;
        match record.context_type {
            RecordType::CipherSpec => {
                self.handshake_finished = true;
                self.read_buffer.move_to(record_len..self.read_buffer.len(), 0);
            }
            RecordType::Alert => return Err(self.handle_by_alert(self.handshake_finished, record_len)?.into()),
            RecordType::HandShake => {
                if self.handshake_finished {
                    let len = self.conn.read_message(&self.read_buffer[..record_len], buf.initialized_mut())?;
                    self.conn.verify_finish(&buf.initialized()[..len], true)?;
                } else {
                    self.conn.update_session(&self.read_buffer[5..record_len])?;
                }
                self.read_buffer.move_to(record_len..self.read_buffer.len(), 0);
            }
            RecordType::ApplicationData => {
                let len = self.conn.read_message(&self.read_buffer[..record_len], buf.initialized_mut())?;
                buf.set_filled(len);
                self.read_buffer.move_to(record_len..self.read_buffer.len(), 0);
                return Ok(len);
            }
        }
        Ok(0)
    }
}

impl<S: AsyncRead + Unpin> AsyncRead for TlsStream<S> {
    fn poll_read(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut ReadBuf<'_>) -> Poll<io::Result<()>> {
        if self.shutdown_wrote { return Poll::Ready(Ok(())); }
        let stream = self.get_mut();
        loop {
            let record_len = if stream.read_buffer.is_empty() { 0 } else { u16::from_be_bytes([stream.read_buffer[3], stream.read_buffer[4]]) as usize + 5 };
            if record_len != 0 && stream.read_buffer.len() >= record_len {
                match stream.read_message(buf) {
                    Ok(len) => if len > 0 { return Poll::Ready(Ok(())); } else { continue; }
                    Err(e) => return Poll::Ready(Err(e)),
                }
            }
            if stream.read_buffer.unfilled_mut().is_empty() { return Poll::Ready(Err(Error::other("buffer size  too small"))); }
            let mut rd = ReadBuf::new(stream.read_buffer.unfilled_mut());
            match Pin::new(&mut stream.stream).poll_read(cx, &mut rd) {
                Poll::Ready(Ok(_)) => {
                    let fl = rd.filled().len();
                    if fl == 0 { return Poll::Ready(Ok(())); }
                    let nl = stream.read_buffer.len() + fl;
                    stream.read_buffer.set_len(nl);
                }
                Poll::Ready(Err(e)) => return Poll::Ready(Err(e)),
                Poll::Pending => {
                    return Poll::Pending;
                }
            }
        }
    }
}

impl<S: AsyncWrite + Unpin> AsyncWrite for TlsStream<S> {
    fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<Result<usize, Error>> {
        let stream = self.get_mut();
        let chucks = buf.chunks(16384).collect::<Vec<_>>();
        if stream.pending.is_empty() {
            stream.wrote_len = 0;
            stream.pending = (0..chucks.len()).collect();
        }
        loop {
            if stream.pending.is_empty() { break; }
            if stream.write_buffer.is_empty() {
                let record_len = stream.conn.make_message(RecordType::ApplicationData, &mut stream.write_buffer[..], chucks[stream.pending[0]])?;
                stream.write_buffer.set_len(record_len);
                stream.wrote_len += chucks[stream.pending[0]].len();
            }
            match Pin::new(&mut stream.stream).poll_write(cx, stream.write_buffer.filled()) {
                Poll::Ready(Ok(len)) => {
                    if stream.write_buffer.used_empty(len) {
                        stream.pending.remove(0);
                        stream.write_buffer.reset();
                    }
                }
                Poll::Ready(Err(e)) => return Poll::Ready(Err(e)),
                Poll::Pending => return Poll::Pending,
            }
        }
        assert_eq!(stream.wrote_len, buf.len());
        Poll::Ready(Ok(stream.wrote_len))
    }

    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
        Pin::new(&mut self.stream).poll_flush(cx)
    }

    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
        let stream = self.get_mut();
        if stream.write_buffer.is_empty() {
            let len = stream.conn.make_message(RecordType::Alert, &mut stream.write_buffer[..], &Alert::close_notify().to_bytes())?;
            stream.write_buffer.set_len(len);
        }
        match stream.shutdown_wrote {
            true => Pin::new(&mut stream.stream).poll_shutdown(cx),
            false => match Pin::new(&mut stream.stream).poll_write(cx, stream.write_buffer.filled()) {
                Poll::Ready(Ok(_)) => {
                    stream.shutdown_wrote = true;
                    Pin::new(&mut stream.stream).poll_shutdown(cx)
                }
                Poll::Ready(Err(e)) => Poll::Ready(Err(e)),
                Poll::Pending => Poll::Pending,
            }
        }
    }
}