#![doc = include_str!("../README.md")]
#![deny(unsafe_code)]
use bytes::Bytes;
use http::header::{CONNECTION, UPGRADE};
pub use http::{
HeaderMap, HeaderName, HeaderValue, Method, Request, Response, StatusCode, header::HOST,
};
use http_body_util::{Empty, Full};
use hyper::body::Incoming;
use tokio::net::TcpStream;
mod client;
mod error;
pub mod http1;
pub mod http2;
pub use client::{Client, ClientExt};
pub use error::Error;
pub use http1::Http1;
pub use http2::Http2;
pub use hyper::upgrade::on as upgrade;
pub use sealed::ResponseExt;
/// HTTP body type with no content, e.g. for bodiless requests such as the
/// one built by [`make_upgrade_req`].
pub type EmptyBody = Empty<Bytes>;
/// HTTP body type carrying a single, already-buffered chunk of [`Bytes`].
pub type BytesBody = Full<Bytes>;
/// An upgraded HTTP connection wrapped in hyper-util's tokio I/O adapter;
/// this is what [`do_upgrade`] returns.
pub type Upgraded = hyper_util::rt::TokioIo<hyper::upgrade::Upgraded>;
/// Waits for the protocol upgrade negotiated on `resp` to complete and returns
/// the upgraded connection wrapped as an [`Upgraded`].
///
/// # Errors
///
/// Propagates the [`hyper::Error`] from [`hyper::upgrade::on`] if the upgrade
/// cannot be completed.
pub async fn do_upgrade(resp: Response<Incoming>) -> hyper::Result<Upgraded> {
    hyper::upgrade::on(resp)
        .await
        .map(hyper_util::rt::TokioIo::new)
}
mod sealed {
use futures::TryStreamExt;
use http_body_util::BodyExt;
use tokio::io::AsyncRead;
use tokio_util::io::StreamReader;
use crate::Error;
pub trait ResponseExt {
fn collect_bytes(self) -> impl Future<Output = Result<bytes::Bytes, Error>> + Send;
fn into_read(self) -> impl AsyncRead + Send + Unpin + 'static;
}
impl<B> ResponseExt for http::Response<B>
where
B: hyper::body::Body + Send + Unpin + 'static,
B::Data: Send + 'static,
B::Error: core::error::Error + Send + Sync + 'static,
{
async fn collect_bytes(self) -> Result<bytes::Bytes, Error> {
let buf = self
.into_body()
.collect()
.await
.map_err(|e| {
tracing::error!(error = %e, "collecting response body");
Error::Io
})?
.to_bytes();
Ok(buf)
}
fn into_read(self) -> impl AsyncRead + Send + Unpin + 'static {
StreamReader::new(
self.into_body()
.into_data_stream()
.map_err(tokio::io::Error::other),
)
}
}
}
pub fn make_upgrade_req(
u: &url::Url,
protocol: &str,
extra_headers: impl IntoIterator<Item = (HeaderName, HeaderValue)>,
) -> Result<Request<EmptyBody>, Error> {
let mut req = Request::post(u.as_str())
.header(HOST, u.host_str().ok_or(Error::InvalidInput)?)
.header(UPGRADE, protocol)
.header(CONNECTION, "Upgrade")
.body(EmptyBody::new())
.map_err(|e| {
tracing::error!(error = %e, "creating upgrade request");
Error::InvalidInput
})?;
req.headers_mut().extend(extra_headers);
Ok(req)
}
/// Builds a `Host` header for `u`, formatted as `host` or `host:port`.
///
/// The port is included only when `u.port()` reports one. The `url` crate
/// normalizes away ports equal to the scheme's default, so `https://h:443/`
/// yields just `h`.
///
/// Returns `None` in two cases:
/// - `u` has no host.
/// - The formatted authority is not a valid header value.
pub fn host_header(u: &url::Url) -> Option<(HeaderName, HeaderValue)> {
    let hostname = u.host_str()?;
    let authority = if let Some(port) = u.port() {
        format!("{hostname}:{port}")
    } else {
        hostname.to_owned()
    };
    let value = HeaderValue::from_str(&authority).ok()?;
    Some((HOST, value))
}
async fn dial_tcp(url: &url::Url) -> Result<TcpStream, Error> {
let conn = TcpStream::connect((
url.host_str().ok_or(Error::InvalidInput)?,
url.port_or_known_default()
.ok_or(Error::InvalidInput)
.inspect_err(|_err| tracing::error!("unknown url port"))?,
))
.await
.map_err(|e| {
tracing::error!(error = %e, %url, "dialing tcp");
Error::Io
})?;
Ok(conn)
}
async fn dial_tls(
url: &url::Url,
alpn: impl IntoIterator<Item = Vec<u8>>,
) -> Result<ts_tls_util::TlsStream<TcpStream>, Error> {
let server_name = ts_tls_util::server_name(url)
.ok_or_else(|| {
tracing::error!(%url, "parsing server name");
Error::InvalidInput
})?
.to_owned();
let conn = dial_tcp(url).await?;
ts_tls_util::connect_alpn(server_name, conn, alpn)
.await
.map_err(|e| {
tracing::error!(error = %e, "dialing tls connection");
Error::Io
})
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses a URL that the test knows to be valid.
    fn url(s: &str) -> url::Url {
        url::Url::parse(s).unwrap()
    }

    #[test]
    fn host_header_omits_default_https_port() {
        let (name, value) = host_header(&url("https://h/")).unwrap();
        assert_eq!(name, HOST);
        assert_eq!(value, "h");
    }

    #[test]
    fn host_header_omits_default_http_port() {
        let (name, value) = host_header(&url("http://h/")).unwrap();
        assert_eq!(name, HOST);
        assert_eq!(value, "h");
    }

    #[test]
    fn host_header_includes_non_default_port() {
        let (name, value) = host_header(&url("https://localhost:14000/")).unwrap();
        assert_eq!(name, HOST);
        assert_eq!(value, "localhost:14000");
    }

    #[test]
    fn host_header_drops_explicit_default_port() {
        // `url` normalizes a port equal to the scheme default away, so it must
        // not reappear in the header.
        let (_, value) = host_header(&url("https://h:443/")).unwrap();
        assert_eq!(value, "h");
    }

    #[test]
    fn host_header_keeps_ipv6_brackets() {
        let (_, value) = host_header(&url("http://[::1]:8080/")).unwrap();
        assert_eq!(value, "[::1]:8080");
    }

    #[test]
    fn host_header_none_without_host() {
        assert!(host_header(&url("mailto:user@example.com")).is_none());
    }

    #[test]
    fn upgrade_req_sets_upgrade_headers() {
        let extra = [(
            HeaderName::from_static("x-extra"),
            HeaderValue::from_static("1"),
        )];
        let Ok(req) = make_upgrade_req(&url("http://h/ts2021"), "proto", extra) else {
            panic!("building upgrade request failed");
        };
        assert_eq!(*req.method(), Method::POST);
        let headers = req.headers();
        assert_eq!(headers[HOST], "h");
        assert_eq!(headers[UPGRADE], "proto");
        assert_eq!(headers[CONNECTION], "Upgrade");
        assert_eq!(headers["x-extra"], "1");
    }
}