mock-http-connector 0.5.0

Mock hyper HTTPConnector
Documentation
use std::{
    cmp::min,
    future::Future,
    io,
    pin::Pin,
    sync::Arc,
    task::{ready, Context, Poll, Waker},
};

use crate::hyper::{Connected, Connection, Response, Uri};
use httparse::{Request, Status};
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};

use crate::{connector::InnerConnector, response::ResponseFuture, Error};

pub struct MockStream {
    res: ResponseState,
    req_data: Vec<u8>,
    waker: Option<Waker>,

    uri: Uri,

    connector: Arc<InnerConnector>,
}

impl MockStream {
    pub(crate) fn new(connector: Arc<InnerConnector>, uri: Uri) -> Self {
        Self {
            res: ResponseState::New,
            req_data: Vec::new(),
            waker: None,
            uri,
            connector,
        }
    }
}

impl Connection for MockStream {
    fn connected(&self) -> Connected {
        Connected::new()
    }
}

impl AsyncRead for MockStream {
    fn poll_read(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut ReadBuf<'_>,
    ) -> Poll<io::Result<()>> {
        let (data, mut pos) = match &mut self.res {
            ResponseState::New => {
                self.waker = Some(cx.waker().clone());
                return Poll::Pending;
            }
            ResponseState::Fut(fut) => {
                let res = ready!(Pin::new(fut).poll(cx))
                    .map_err(|err| into_connect_error(Error::Runtime(err)))?;
                (into_data(res)?, 0)
            }
            ResponseState::Data(data, pos) => (data.clone(), *pos),
        };

        let size = min(buf.remaining(), data.len() - pos);
        buf.put_slice(&data[pos..pos + size]);
        pos += size;

        self.res = ResponseState::Data(data, pos);

        self.waker = Some(cx.waker().clone());

        Poll::Ready(Ok(()))
    }
}

#[cfg(feature = "hyper_1")]
impl hyper_1::rt::Read for MockStream {
    fn poll_read(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        mut buf: hyper_1::rt::ReadBufCursor<'_>,
    ) -> Poll<Result<(), std::io::Error>> {
        let (data, mut pos) = match &mut self.res {
            ResponseState::New => {
                self.waker = Some(cx.waker().clone());
                return Poll::Pending;
            }
            ResponseState::Fut(fut) => {
                let res = ready!(Pin::new(fut).poll(cx))
                    .map_err(|err| into_connect_error(Error::Runtime(err)))?;
                (into_data(res)?, 0)
            }
            ResponseState::Data(data, pos) => (data.clone(), *pos),
        };

        let size = min(buf.remaining(), data.len() - pos);
        buf.put_slice(&data[pos..pos + size]);
        pos += size;

        self.res = ResponseState::Data(data, pos);

        self.waker = Some(cx.waker().clone());

        Poll::Ready(Ok(()))
    }
}

impl AsyncWrite for MockStream {
    fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
        Poll::Ready(Ok(()))
    }

    fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
        Poll::Ready(Ok(()))
    }

    fn poll_write(
        mut self: Pin<&mut Self>,
        _cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<Result<usize, io::Error>> {
        let mut headers = [httparse::EMPTY_HEADER; 64];
        let mut req = Request::new(&mut headers);
        self.req_data.extend(buf);

        let status = req
            .parse(&self.req_data)
            .map_err(|err| into_connect_error(err.into()))?;

        let body = match status {
            Status::Complete(body_pos) => &self.req_data[body_pos..],
            Status::Partial => &[],
        };

        self.res = ResponseState::Fut(
            self.connector
                .matches_raw(req, body, &self.uri)
                .map_err(into_connect_error)?,
        );

        if let Some(w) = self.waker.take() {
            w.wake()
        }

        Poll::Ready(Ok(buf.len()))
    }
}

#[cfg(feature = "hyper_1")]
impl hyper_1::rt::Write for MockStream {
    fn poll_write(
        mut self: Pin<&mut Self>,
        _cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<Result<usize, std::io::Error>> {
        let mut headers = [httparse::EMPTY_HEADER; 64];
        let mut req = Request::new(&mut headers);
        self.req_data.extend(buf);

        let status = req
            .parse(&self.req_data)
            .map_err(|err| into_connect_error(err.into()))?;

        let body = match status {
            Status::Complete(body_pos) => &self.req_data[body_pos..],
            Status::Partial => &[],
        };

        self.res = ResponseState::Fut(
            self.connector
                .matches_raw(req, body, &self.uri)
                .map_err(into_connect_error)?,
        );

        if let Some(w) = self.waker.take() {
            w.wake()
        }

        Poll::Ready(Ok(buf.len()))
    }

    fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), std::io::Error>> {
        Poll::Ready(Ok(()))
    }

    fn poll_shutdown(
        self: Pin<&mut Self>,
        _cx: &mut Context<'_>,
    ) -> Poll<Result<(), std::io::Error>> {
        Poll::Ready(Ok(()))
    }
}

#[derive(Default)]
enum ResponseState {
    #[default]
    New,
    Fut(ResponseFuture),
    Data(Vec<u8>, usize),
}

fn into_data(res: Response<String>) -> Result<Vec<u8>, io::Error> {
    let mut data = String::new();
    let status = res.status();
    data.push_str(&format!(
        "HTTP/1.1 {} {}\r\n",
        status.as_u16(),
        status.as_str()
    ));

    for (name, value) in res.headers() {
        data.push_str(&format!(
            "{name}: {}\r\n",
            value
                .to_str()
                .map_err(|err| io::Error::new(io::ErrorKind::BrokenPipe, err))?
        ));
    }

    data.push_str("\r\n");
    data.push_str(res.body());

    Ok(data.into_bytes())
}

fn into_connect_error(err: Error) -> io::Error {
    io::Error::new(io::ErrorKind::ConnectionRefused, err)
}