use std::{
io,
net::{SocketAddr, ToSocketAddrs},
path::{Path, PathBuf},
pin::Pin,
time::Duration,
};
use boring::ssl::{SslConnector, SslMethod};
use futures::{AsyncRead, AsyncWrite};
use http::{uri::Scheme, Request, Response};
use rasi::{net::TcpStream, time::TimeoutExt};
use crate::net::{
http::{parse::Responser, writer::RequestWriter},
tls::{connect, SslStream},
};
use super::parse::BodyReader;
/// Write side of a client connection returned by `HttpRequestSender::create`:
/// either a plain TCP stream (`http` scheme) or a TLS stream layered over TCP
/// (any other scheme).
pub enum HttpClientWrite {
/// Plain, unencrypted TCP connection.
TcpStream(TcpStream),
/// TLS session established over a TCP connection.
TlsStream(SslStream<TcpStream>),
}
/// Reuses the connection that carried the request as the read side for the
/// response: the wrapped stream is moved over unchanged, variant for variant.
///
/// Implemented as `From` (rather than a hand-written `Into`) so that both
/// `HttpClientRead::from(w)` and `w.into()` work via the std blanket impl.
impl From<HttpClientWrite> for HttpClientRead {
    fn from(stream: HttpClientWrite) -> Self {
        match stream {
            HttpClientWrite::TcpStream(stream) => HttpClientRead::TcpStream(stream),
            HttpClientWrite::TlsStream(stream) => HttpClientRead::TlsStream(stream),
        }
    }
}
impl AsyncWrite for HttpClientWrite {
fn poll_write(
mut self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
buf: &[u8],
) -> std::task::Poll<io::Result<usize>> {
match &mut *self {
Self::TcpStream(stream) => Pin::new(stream).poll_write(cx, buf),
Self::TlsStream(stream) => Pin::new(stream).poll_write(cx, buf),
}
}
fn poll_flush(
mut self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<io::Result<()>> {
match &mut *self {
Self::TcpStream(stream) => Pin::new(stream).poll_flush(cx),
Self::TlsStream(stream) => Pin::new(stream).poll_flush(cx),
}
}
fn poll_close(
mut self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<io::Result<()>> {
match &mut *self {
Self::TcpStream(stream) => Pin::new(stream).poll_close(cx),
Self::TlsStream(stream) => Pin::new(stream).poll_close(cx),
}
}
}
/// Read side of a client connection, obtained by converting an
/// `HttpClientWrite` after the request has been written; it wraps the same
/// underlying stream.
pub enum HttpClientRead {
/// Plain, unencrypted TCP connection.
TcpStream(TcpStream),
/// TLS session established over a TCP connection.
TlsStream(SslStream<TcpStream>),
}
/// Delegates reads to the wrapped TCP or TLS stream.
///
/// Both inner stream types are `Unpin` (they are pinned with `Pin::new`), so
/// `Pin::get_mut` can be used to reach the variant.
impl AsyncRead for HttpClientRead {
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
        buf: &mut [u8],
    ) -> std::task::Poll<io::Result<usize>> {
        let this = self.get_mut();
        match this {
            Self::TcpStream(tcp) => Pin::new(tcp).poll_read(cx, buf),
            Self::TlsStream(tls) => Pin::new(tls).poll_read(cx, buf),
        }
    }
}
/// Extension trait that turns a request value into an [`HttpRequestSender`],
/// whose builder methods configure the connection before sending.
pub trait HttpRequestSend {
/// Body type carried by the request.
type Body;
/// Wraps `self` in an [`HttpRequestSender`] for further configuration.
fn send(self) -> HttpRequestSender<Self::Body>;
}
/// Builder that holds a request plus connection options; the request is only
/// sent when one of its async sending methods is awaited.
#[must_use = "Must call response function to invoke real sending action."]
pub struct HttpRequestSender<T> {
// The request to send, or the error produced while building it; the error is
// surfaced as `InvalidInput` when the connection is created.
request: http::Result<Request<T>>,
// Time limit applied to writing the request and to reading the response
// header (default set in `new`).
timeout: Duration,
// Remote addresses set via `redirect`; `None` means "derive from the URI".
// Resolution errors are stored here and reported later.
raddrs: Option<io::Result<Vec<SocketAddr>>>,
// Optional TLS server-name override set via `with_server_name`.
server_name: Option<String>,
// Optional CA file handed to the TLS connector's `set_ca_file`.
ca_file: Option<PathBuf>,
}
impl<T> HttpRequestSender<T> {
    /// Creates a sender for `request`.
    ///
    /// Defaults: a 30-second timeout, a remote address derived from the
    /// request URI, no TLS server-name override and no extra CA file.
    pub fn new(request: http::Result<Request<T>>) -> Self {
        Self {
            request,
            timeout: Duration::from_secs(30),
            raddrs: None,
            server_name: None,
            ca_file: None,
        }
    }

    /// Sets the time limit applied to sending the request and, separately, to
    /// receiving the response header. Establishing the connection is not
    /// bounded by this timeout.
    pub fn with_timeout(mut self, timeout: Duration) -> Self {
        self.timeout = timeout;
        self
    }

    /// Connects to `raddrs` instead of the address derived from the request URI.
    ///
    /// `raddrs` is resolved immediately (which may block on a DNS lookup); a
    /// resolution error is kept and returned by [`create`](Self::create).
    pub fn redirect<R: ToSocketAddrs>(mut self, raddrs: R) -> Self {
        self.raddrs = Some(
            raddrs
                .to_socket_addrs()
                .map(|iter| iter.collect::<Vec<_>>()),
        );
        self
    }

    /// Overrides the server name passed to the TLS handshake (it defaults to
    /// the request URI host). Has no effect on plain `http` requests.
    pub fn with_server_name(mut self, server_name: &str) -> Self {
        self.server_name = Some(server_name.to_string());
        self
    }

    /// Loads trusted CA certificates from `ca_file` into the TLS connector.
    /// Has no effect on plain `http` requests.
    pub fn with_ca_file<P: AsRef<Path>>(mut self, ca_file: P) -> Self {
        self.ca_file = Some(ca_file.as_ref().to_path_buf());
        self
    }

    /// Validates the request and opens the connection to the server.
    ///
    /// The `http` scheme gets a plain TCP connection (default port 80); every
    /// other scheme gets TLS over TCP (default port 443).
    ///
    /// Returns the request, the connected write stream and the configured
    /// timeout.
    ///
    /// # Errors
    ///
    /// - `InvalidInput` if the request failed to build, has no scheme or host,
    ///   or the TLS connector cannot be configured;
    /// - any error from resolving the remote address or from the TCP connect;
    /// - `ConnectionRefused` if the TLS handshake fails.
    pub async fn create(self) -> io::Result<(Request<T>, HttpClientWrite, Duration)> {
        let request = self
            .request
            .map_err(|err| io::Error::new(io::ErrorKind::InvalidInput, err))?;

        let scheme = request.uri().scheme().ok_or_else(|| {
            io::Error::new(io::ErrorKind::InvalidInput, "Unspecified request scheme")
        })?;

        let host = request.uri().host().ok_or_else(|| {
            io::Error::new(io::ErrorKind::InvalidInput, "Unspecified request uri")
        })?;

        let is_plain_http = scheme == &Scheme::HTTP;

        // Anything that is not plain `http` is connected over TLS below, so it
        // defaults to the HTTPS port.
        let port = request
            .uri()
            .port_u16()
            .unwrap_or(if is_plain_http { 80 } else { 443 });

        let raddrs = match self.raddrs {
            Some(raddrs) => raddrs?,
            // `String: ToSocketAddrs` parses IP literals directly (including
            // bracketed IPv6 such as `[::1]:80`) and resolves host names through
            // the system resolver. NOTE: that resolution is a blocking call.
            None => format!("{}:{}", host, port).to_socket_addrs()?.collect(),
        };

        let stream = TcpStream::connect(raddrs.as_slice()).await?;

        let stream = if is_plain_http {
            HttpClientWrite::TcpStream(stream)
        } else {
            let mut builder = SslConnector::builder(SslMethod::tls())
                .map_err(|err| io::Error::new(io::ErrorKind::InvalidInput, err))?;

            if let Some(ca_file) = self.ca_file {
                builder
                    .set_ca_file(ca_file)
                    .map_err(|err| io::Error::new(io::ErrorKind::InvalidInput, err))?;
            }

            // `configure` is fallible; report it instead of panicking.
            let config = builder
                .build()
                .configure()
                .map_err(|err| io::Error::new(io::ErrorKind::InvalidInput, err))?;

            // Honour `with_server_name` (e.g. when `redirect` points at an
            // address whose name differs from the URI host).
            let server_name = self.server_name.as_deref().unwrap_or(host);

            let stream = connect(config, server_name, stream)
                .await
                .map_err(|err| io::Error::new(io::ErrorKind::ConnectionRefused, err))?;

            HttpClientWrite::TlsStream(stream)
        };

        Ok((request, stream, self.timeout))
    }

    /// Sends a request whose body is an in-memory byte buffer and waits for
    /// the response header.
    ///
    /// # Errors
    ///
    /// Any error from [`create`](Self::create), from writing the request or
    /// from parsing the response header; `TimedOut` if either the write or
    /// the header read exceeds the configured timeout.
    pub async fn response(self) -> io::Result<Response<BodyReader<HttpClientRead>>>
    where
        T: AsRef<[u8]>,
    {
        let (request, mut stream, timeout) = self.create().await?;

        let writer = RequestWriter::new(&mut stream);
        let sent = writer.write(request).timeout(timeout).await;
        Self::check_sent(sent)?;

        Self::recv_response(stream, timeout).await
    }

    /// Sends a request whose body is streamed from an `AsyncRead` source and
    /// waits for the response header.
    ///
    /// # Errors
    ///
    /// Same as [`response`](Self::response).
    pub async fn stream_response(self) -> io::Result<Response<BodyReader<HttpClientRead>>>
    where
        T: AsyncRead + Unpin,
    {
        let (request, mut stream, timeout) = self.create().await?;

        let writer = RequestWriter::new(&mut stream);
        let sent = writer
            .write_with_stream_body(request)
            .timeout(timeout)
            .await;
        Self::check_sent(sent)?;

        Self::recv_response(stream, timeout).await
    }

    /// Flattens the timed-out outcome of writing the request: `None` (the
    /// timeout elapsed) becomes a `TimedOut` error.
    fn check_sent<R>(sent: Option<io::Result<R>>) -> io::Result<()> {
        match sent {
            Some(Ok(_)) => Ok(()),
            Some(Err(err)) => Err(err),
            None => Err(io::Error::new(
                io::ErrorKind::TimedOut,
                "send http request timeout",
            )),
        }
    }

    /// Turns `stream` into its read side and parses the response header,
    /// bounded by `timeout`.
    async fn recv_response(
        stream: HttpClientWrite,
        timeout: Duration,
    ) -> io::Result<Response<BodyReader<HttpClientRead>>> {
        let stream: HttpClientRead = stream.into();

        match Responser::new(stream).parse().timeout(timeout).await {
            Some(Ok(response)) => Ok(response),
            Some(Err(err)) => Err(err.into()),
            None => Err(io::Error::new(
                io::ErrorKind::TimedOut,
                "recv http response header timeout",
            )),
        }
    }
}
/// Lets the output of an `http::Request` builder be sent directly with
/// `.send()`.
///
/// The body type is deliberately unbounded here. The sending methods on
/// [`HttpRequestSender`] carry their own bounds (`AsRef<[u8]>` for
/// `response`, `AsyncRead + Unpin` for `stream_response`), so streaming bodies
/// can use this entry point as well as in-memory ones.
impl<T> HttpRequestSend for http::Result<Request<T>> {
    type Body = T;

    fn send(self) -> HttpRequestSender<Self::Body> {
        HttpRequestSender::new(self)
    }
}