use std::{
cell::RefCell,
fmt::Debug,
io::{self, Error, ErrorKind, Read},
mem::swap,
rc::Rc,
str::FromStr,
sync::Arc,
};
use hyperium_http::{
header::{CONTENT_LENGTH, HOST, TRANSFER_ENCODING},
Response,
};
use tcp_stream::OwnedTLSConfig;
use crate::{
buffer::GrowableCircleBuf,
dns::AddrResolver,
frame::{DeserializeFrame, FrameDuplex, SerializeFrame, SizedFrame},
tcp::TcpSession,
tls::{NativeTlsConnector, TlsConnector},
DriveOutcome, Flush, Publish, PublishOutcome, Receive, ReceiveOutcome, Session, SessionStatus,
};
/// URI scheme understood by the HTTP client.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Scheme {
    /// Plain `http`.
    Http,
    /// `https`; sessions for this scheme are upgraded to TLS after connecting.
    Https,
}

impl Scheme {
    /// Returns the port used when a URI does not name one explicitly:
    /// 80 for [`Scheme::Http`] and 443 for [`Scheme::Https`].
    pub fn default_port(&self) -> u16 {
        match *self {
            Scheme::Http => 80,
            Scheme::Https => 443,
        }
    }
}
/// A request to publish on an HTTP session: either a structured request or
/// bytes that have already been serialized.
pub enum HttpRequest {
/// Structured request whose body is a byte vector.
Request(hyperium_http::Request<Vec<u8>>),
/// Already-serialized request bytes; the HTTP/1 serializer in this file
/// writes them to the output buffer unchanged.
Serialized(Vec<u8>),
}
/// Converts any request whose body implements [`IntoBody`] into an
/// [`HttpRequest::Request`] carrying that body as raw bytes.
impl<I: IntoBody> From<hyperium_http::Request<I>> for HttpRequest {
    fn from(value: hyperium_http::Request<I>) -> Self {
        let (head, payload) = value.into_parts();
        let bytes = payload.into_body();
        Self::Request(hyperium_http::Request::from_parts(head, bytes))
    }
}
/// Factory for HTTP client sessions, optionally configured with a shared TLS
/// connector and address resolver.
pub struct HttpClient {
/// TLS connector handed to every TCP session this client opens, if any.
tls_connector: Option<Arc<TlsConnector>>,
/// Address resolver handed to every TCP session this client opens, if any.
addr_resolver: Option<Arc<AddrResolver>>,
}
impl HttpClient {
pub fn new() -> Self {
Self {
tls_connector: None,
addr_resolver: None,
}
}
pub fn with_addr_resolver(mut self, addr_resolver: Arc<AddrResolver>) -> Self {
self.addr_resolver = Some(addr_resolver);
self
}
pub fn with_tls_connector(mut self, tls_connector: Arc<TlsConnector>) -> Self {
self.tls_connector = Some(tls_connector);
self
}
pub fn with_tls_config(mut self, tls_config: OwnedTLSConfig) -> Result<Self, Error> {
self.tls_connector = Some(Arc::new(TlsConnector::Native(NativeTlsConnector::new(
tls_config.as_ref(),
false,
)?)));
Ok(self)
}
pub fn connect(
&mut self,
host: &str,
port: u16,
scheme: Scheme,
) -> Result<HttpClientSession, io::Error> {
let mut conn = TcpSession::connect(
format!("{host}:{port}"),
self.addr_resolver.as_ref().map(|x| Arc::clone(x)),
self.tls_connector.as_ref().map(|x| Arc::clone(&x)),
)?;
if scheme == Scheme::Https {
conn = conn.into_tls(&host)?;
}
Ok(HttpClientSession::new(FrameDuplex::new(
conn,
Http1ResponseDeserializer::new(),
Http1RequestSerializer::new(),
0,
)))
}
pub fn request<I: IntoBody>(
&mut self,
request: hyperium_http::Request<I>,
) -> Result<HttpClientSession, io::Error> {
let (parts, body) = request.into_parts();
let request = hyperium_http::Request::from_parts(parts, body.into_body());
let scheme = match request.uri().scheme_str() {
None => Scheme::Http,
Some("http") => Scheme::Http,
Some("https") => Scheme::Https,
_ => {
return Err(io::Error::new(
ErrorKind::InvalidData,
"bad http uri scheme",
))
}
};
let session = connect_stream(
scheme,
request.uri().host(),
request.uri().port().map(|x| x.as_u16()),
self.addr_resolver.as_ref().map(|x| Arc::clone(x)),
self.tls_connector.as_ref().map(|x| Arc::clone(&x)),
)?;
let mut conn = HttpClientSession::new(FrameDuplex::new(
session,
Http1ResponseDeserializer::new(),
Http1RequestSerializer::new(),
0,
));
conn.pending_initial_request = Some(request.into());
Ok(conn)
}
}
impl Default for HttpClient {
fn default() -> Self {
Self::new()
}
}
/// Opens a TCP session to `host`, upgrading it to TLS for [`Scheme::Https`].
///
/// When `port` is `None` the scheme's default port is used.
///
/// # Errors
/// * `ErrorKind::InvalidData` when `host` is `None`.
/// * Whatever `TcpSession::connect` returns on failure.
/// * `ErrorKind::ConnectionRefused` wrapping the cause when the TLS upgrade fails.
pub(crate) fn connect_stream(
    scheme: Scheme,
    host: Option<&str>,
    port: Option<u16>,
    addr_resolver: Option<Arc<AddrResolver>>,
    tls_connector: Option<Arc<TlsConnector>>,
) -> Result<TcpSession, Error> {
    let host = host
        .map(str::to_owned)
        .ok_or_else(|| io::Error::new(ErrorKind::InvalidData, "missing host"))?;
    let port = port.unwrap_or_else(|| scheme.default_port());

    let plain = TcpSession::connect(format!("{host}:{port}"), addr_resolver, tls_connector)?;
    match scheme {
        Scheme::Http => Ok(plain),
        Scheme::Https => plain
            .into_tls(&host)
            .map_err(|err| Error::new(ErrorKind::ConnectionRefused, err)),
    }
}
/// A single HTTP/1 client session layered over a framed TCP session.
pub struct HttpClientSession {
/// Framed transport that serializes requests and deserializes responses.
session: FrameDuplex<TcpSession, Http1ResponseDeserializer, Http1RequestSerializer>,
/// Request queued before the session was established; `drive` publishes it
/// once the session reports `SessionStatus::Established`.
pending_initial_request: Option<HttpRequest>,
}
impl HttpClientSession {
pub fn new(
session: FrameDuplex<TcpSession, Http1ResponseDeserializer, Http1RequestSerializer>,
) -> Self {
Self {
session,
pending_initial_request: None,
}
}
}
impl Session for HttpClientSession {
    /// Reports the status of the underlying framed session.
    fn status(&self) -> crate::SessionStatus {
        self.session.status()
    }

    /// Drives the underlying session and, once it is established, tries to
    /// publish the queued initial request.
    ///
    /// Returns `DriveOutcome::Active` when the initial request was fully
    /// published during this call; otherwise returns the outcome reported by
    /// the underlying session.
    fn drive(&mut self) -> Result<DriveOutcome, Error> {
        let outcome: crate::DriveOutcome = self.session.drive()?;
        if self.session.status() != SessionStatus::Established {
            return Ok(outcome);
        }
        let Some(request) = self.pending_initial_request.take() else {
            return Ok(outcome);
        };
        match self.session.publish(request)? {
            PublishOutcome::Published => Ok(DriveOutcome::Active),
            PublishOutcome::Incomplete(unsent) => {
                // Keep the request around so the next drive can retry it.
                self.pending_initial_request = Some(unsent);
                Ok(outcome)
            }
        }
    }
}
impl Receive for HttpClientSession {
    type ReceivePayload<'a> = hyperium_http::Response<Vec<u8>>;

    /// Drives the session, then polls for a response.
    ///
    /// Reports `ReceiveOutcome::Idle` without polling while the initial
    /// request is still unpublished or the session is not established.
    fn receive<'a>(&'a mut self) -> Result<crate::ReceiveOutcome<Self::ReceivePayload<'a>>, Error> {
        self.drive()?;
        let request_outstanding = self.pending_initial_request.is_some();
        if request_outstanding || self.status() != SessionStatus::Established {
            return Ok(crate::ReceiveOutcome::Idle);
        }
        self.session.receive()
    }
}
impl Publish for HttpClientSession {
    type PublishPayload<'a> = HttpRequest;

    /// Hands `data` straight to the underlying framed session and returns
    /// its outcome.
    fn publish<'a>(
        &mut self,
        data: Self::PublishPayload<'a>,
    ) -> Result<PublishOutcome<Self::PublishPayload<'a>>, Error> {
        let framed = &mut self.session;
        framed.publish(data)
    }
}
impl Flush for HttpClientSession {
    /// Flushes the underlying framed session.
    fn flush(&mut self) -> Result<(), Error> {
        let framed = &mut self.session;
        framed.flush()
    }
}
impl Debug for HttpClientSession {
    /// Formats the session showing only its `session` field; the queued
    /// initial request is not printed.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("HttpClientSession");
        out.field("session", &self.session);
        out.finish()
    }
}
/// State shared between a `PersistentHttpConnection` and the
/// `PendingHttpResponse` handles it hands out.
struct PersistentHttpSessionContext {
/// The underlying HTTP session all requests are sent over.
session: HttpClientSession,
/// Id of the request currently allowed to use the session; incremented
/// each time a response payload is received.
active_request_id: u64,
/// Id that will be assigned to the next request.
next_request_id: u64,
/// Set once a response asks for the connection to be closed; new requests
/// are refused afterwards.
closed: bool,
}
/// A keep-alive HTTP connection that hands out one `PendingHttpResponse` per
/// request.
pub struct PersistentHttpConnection {
/// Shared session state, also held by every pending response.
context: Rc<RefCell<PersistentHttpSessionContext>>,
}
impl From<HttpClientSession> for PersistentHttpConnection {
fn from(session: HttpClientSession) -> Self {
Self {
context: Rc::new(RefCell::new(PersistentHttpSessionContext {
session,
active_request_id: 0,
next_request_id: 0,
closed: false,
})),
}
}
}
impl Debug for PersistentHttpConnection {
    /// Prints only the type name; the shared context is not formatted.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("PersistentHttpConnection");
        out.finish()
    }
}
impl Session for PersistentHttpConnection {
    /// Reports the status of the shared underlying session.
    fn status(&self) -> crate::SessionStatus {
        let context = self.context.borrow();
        context.session.status()
    }

    /// Drives the shared underlying session.
    fn drive(&mut self) -> Result<DriveOutcome, Error> {
        let mut context = self.context.borrow_mut();
        context.session.drive()
    }
}
impl PersistentHttpConnection {
    /// Registers `request` on this connection and returns a handle used to
    /// send it and receive its response.
    ///
    /// The request's `Connection` header is set to `keep-alive`. The request
    /// is not written here; it is held by the returned handle.
    ///
    /// # Errors
    /// Returns `ErrorKind::ConnectionRefused` when the connection has been
    /// marked closed.
    pub fn request<I: IntoBody>(
        &mut self,
        mut request: hyperium_http::Request<I>,
    ) -> Result<PendingHttpResponse, Error> {
        // Reserve an id inside a scope so the borrow ends before the handle is built.
        let request_id = {
            let mut context = self.context.borrow_mut();
            if context.closed {
                return Err(Error::new(
                    ErrorKind::ConnectionRefused,
                    "connection closed by server",
                ));
            }
            let id = context.next_request_id;
            context.next_request_id = id + 1;
            id
        };

        request
            .headers_mut()
            .insert("Connection", "keep-alive".parse().unwrap());

        Ok(PendingHttpResponse {
            request_id,
            pending_request: Some(request.into()),
            context: Rc::clone(&self.context),
        })
    }
}
/// Handle for one request issued on a `PersistentHttpConnection`.
pub struct PendingHttpResponse {
/// State shared with the connection and its other pending responses.
context: Rc<RefCell<PersistentHttpSessionContext>>,
/// The request, held until it has been published to the session.
pending_request: Option<HttpRequest>,
/// Sequence number assigned when the request was created.
request_id: u64,
}
impl Debug for PendingHttpResponse {
    /// Formats the handle showing only its `request_id`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("PendingHttpResponse");
        out.field("request_id", &self.request_id);
        out.finish()
    }
}
impl Session for PendingHttpResponse {
    /// Reports the status of the shared underlying session.
    fn status(&self) -> crate::SessionStatus {
        let context = self.context.borrow();
        context.session.status()
    }

    /// Drives the shared session, but only while this handle's request is
    /// the active one; otherwise reports `DriveOutcome::Idle`.
    fn drive(&mut self) -> Result<DriveOutcome, Error> {
        let mut context = self.context.borrow_mut();
        if context.active_request_id != self.request_id {
            return Ok(DriveOutcome::Idle);
        }
        context.session.drive()
    }
}
impl Receive for PendingHttpResponse {
    type ReceivePayload<'a> = hyperium_http::Response<Vec<u8>>;

    /// Sends this handle's request (if not yet sent) and polls for its
    /// response.
    ///
    /// * Returns `ReceiveOutcome::Idle` while another request is active or
    ///   the request could not be fully published yet.
    /// * Returns `ReceiveOutcome::Active` on the call that publishes the
    ///   request.
    /// * Returns `ReceiveOutcome::Payload` once the response is available;
    ///   this also activates the next request id and, if the response's
    ///   `Connection` header contains the `close` token, marks the
    ///   connection closed.
    ///
    /// # Errors
    /// Propagates errors from publishing to or receiving from the session.
    fn receive<'a>(&'a mut self) -> Result<ReceiveOutcome<Self::ReceivePayload<'a>>, Error> {
        let mut context = self.context.borrow_mut();
        if context.active_request_id != self.request_id {
            return Ok(ReceiveOutcome::Idle);
        }
        if let Some(request) = self.pending_request.take() {
            return match context.session.publish(request)? {
                PublishOutcome::Incomplete(request) => {
                    // Retry on the next call.
                    self.pending_request = Some(request);
                    Ok(ReceiveOutcome::Idle)
                }
                PublishOutcome::Published => Ok(ReceiveOutcome::Active),
            };
        }
        match context.session.receive()? {
            ReceiveOutcome::Payload(response) => {
                context.active_request_id += 1;
                // The header comes from the server, so it must never panic:
                // a value that is not visible ASCII is treated as "not close".
                // The token is matched case-insensitively within a
                // comma-separated list.
                let server_closing = response
                    .headers()
                    .get("Connection")
                    .and_then(|value| value.to_str().ok())
                    .map_or(false, |value| {
                        value
                            .split(',')
                            .any(|token| token.trim().eq_ignore_ascii_case("close"))
                    });
                if server_closing {
                    context.closed = true;
                }
                Ok(ReceiveOutcome::Payload(response))
            }
            ReceiveOutcome::Active => Ok(ReceiveOutcome::Active),
            ReceiveOutcome::Idle => Ok(ReceiveOutcome::Idle),
        }
    }
}
/// Conversion of a request body value into raw bytes.
pub trait IntoBody {
    /// Consumes `self` and returns the body bytes.
    fn into_body(self) -> Vec<u8>;
}

/// The string's UTF-8 bytes, reusing its allocation.
impl IntoBody for String {
    fn into_body(self) -> Vec<u8> {
        Vec::from(self)
    }
}

/// A copy of the string slice's UTF-8 bytes.
impl IntoBody for &str {
    fn into_body(self) -> Vec<u8> {
        self.as_bytes().to_owned()
    }
}

/// The vector itself, unchanged.
impl IntoBody for Vec<u8> {
    fn into_body(self) -> Vec<u8> {
        self
    }
}

/// A copy of the byte slice.
impl IntoBody for &[u8] {
    fn into_body(self) -> Vec<u8> {
        Vec::from(self)
    }
}

/// An empty body.
impl IntoBody for () {
    fn into_body(self) -> Vec<u8> {
        Vec::default()
    }
}
/// How the end of a response body is determined.
enum BodyType {
/// The body is exactly this many bytes, taken from `Content-Length`.
ContentLength(usize),
/// The body uses chunked transfer encoding.
ChunkedTransfer,
/// The body runs until the end of the stream (eof).
OnClose,
/// The response has no body.
None,
}
/// Where a response body starts within the buffered data and how its end
/// is determined.
struct BodyInfo {
    /// Byte offset of the first body byte (the size of the parsed head).
    offset: usize,
    /// Framing used to find the end of the body.
    ty: BodyType,
}

impl BodyInfo {
    /// Bundles a body start offset with its framing type.
    pub fn new(offset: usize, ty: BodyType) -> Self {
        BodyInfo { ty, offset }
    }
}
/// Incremental HTTP/1 response parser used as the frame deserializer of an
/// HTTP client session.
pub struct Http1ResponseDeserializer {
    /// Response being assembled; taken by `deserialize_frame`.
    deserialized_response: Option<hyperium_http::Response<Vec<u8>>>,
    /// Number of buffered bytes the completed response occupies.
    deserialized_size: usize,
    /// Body framing of the response whose head has been parsed, if any.
    body_info: Option<BodyInfo>,
}

impl Http1ResponseDeserializer {
    /// Creates a deserializer with no response in progress.
    pub fn new() -> Self {
        Http1ResponseDeserializer {
            body_info: None,
            deserialized_size: 0,
            deserialized_response: None,
        }
    }
}
impl DeserializeFrame for Http1ResponseDeserializer {
    type DeserializedFrame<'a> = hyperium_http::Response<Vec<u8>>;

    /// Checks whether `data` holds one complete HTTP/1 response.
    ///
    /// The head is parsed once and remembered in `body_info`; later calls
    /// only look for the end of the body. When the response is complete it
    /// is stored for [`deserialize_frame`](Self::deserialize_frame) together
    /// with the number of bytes it occupies.
    ///
    /// * `data` - all bytes buffered so far, starting at the response head.
    /// * `eof` - whether the stream has ended.
    ///
    /// Returns `Ok(true)` when a full response is ready and `Ok(false)` when
    /// more data is needed.
    ///
    /// # Errors
    /// * `ErrorKind::InvalidData` for an unparsable head, an invalid
    ///   `Content-Length`, or a length that does not fit in `usize`.
    /// * `ErrorKind::UnexpectedEof` when the stream ended after the head
    ///   but before the body was complete.
    /// * Any I/O error reported by the chunked decoder.
    fn check_deserialize_frame(&mut self, data: &[u8], eof: bool) -> Result<bool, Error> {
        let deserialized_response = self
            .deserialized_response
            .get_or_insert_with(|| Response::new(Vec::new()));

        if self.body_info.is_none() {
            // The header scratch space is only needed while the head is unparsed.
            let mut headers = vec![httparse::EMPTY_HEADER; count_max_headers(data)];
            let mut parsed = httparse::Response::new(&mut headers);
            let status = parsed.parse(data).map_err(|err| {
                Error::new(
                    ErrorKind::InvalidData,
                    format!("http response parse failed: {err:?}").as_str(),
                )
            })?;
            let size = match status {
                httparse::Status::Complete(size) => size,
                httparse::Status::Partial => return Ok(false),
            };

            // 1xx, 204 and 304 responses never carry a body, regardless of
            // their framing headers; waiting for one would stall the session.
            let bodiless_status = matches!(parsed.code, Some(100..=199 | 204 | 304));
            let ty = if bodiless_status {
                BodyType::None
            } else if parse_is_chunked(&parsed.headers) {
                BodyType::ChunkedTransfer
            } else if let Some(content_length) = parse_content_length(&parsed.headers)? {
                BodyType::ContentLength(content_length)
            } else if parsed.version.is_none() || parsed.version == Some(1) {
                BodyType::OnClose
            } else {
                // NOTE(review): HTTP/1.0 responses without a length are treated
                // as bodiless here — confirm this is intended.
                BodyType::None
            };

            // Record the framing only after the head converted cleanly.
            parsed_into_response(parsed, deserialized_response)?;
            self.body_info = Some(BodyInfo::new(size, ty));
        }

        // `Some((body, consumed))` once the whole body is available.
        let complete: Option<(Vec<u8>, usize)> = match &self.body_info {
            None => None,
            Some(body_info) => match body_info.ty {
                BodyType::ChunkedTransfer => {
                    if body_info.offset < data.len() && ends_with_ascii(data, "\r\n\r\n") {
                        let mut unread = &data[body_info.offset..];
                        let mut body = Vec::new();
                        let finished = {
                            let mut decoder = chunked_transfer::Decoder::new(&mut unread);
                            decoder.read_to_end(&mut body)?;
                            decoder.remaining_chunks_size().is_none()
                        };
                        if finished {
                            // Consumed size is the ENCODED length (chunk sizes
                            // and terminators included), not the decoded body
                            // length.
                            Some((body, data.len() - unread.len()))
                        } else {
                            None
                        }
                    } else {
                        None
                    }
                }
                BodyType::ContentLength(content_length) => {
                    // The length is server-controlled; refuse values that overflow.
                    let total_length =
                        body_info.offset.checked_add(content_length).ok_or_else(|| {
                            Error::new(ErrorKind::InvalidData, "content-length too large")
                        })?;
                    if data.len() >= total_length {
                        Some((data[body_info.offset..total_length].to_vec(), total_length))
                    } else {
                        None
                    }
                }
                BodyType::OnClose => {
                    if eof {
                        Some((data[body_info.offset..].to_vec(), data.len()))
                    } else {
                        None
                    }
                }
                BodyType::None => Some((Vec::new(), body_info.offset)),
            },
        };

        match complete {
            None if eof => Err(Error::new(
                ErrorKind::UnexpectedEof,
                "http connection terminated before receiving full response",
            )),
            None => Ok(false),
            Some((mut body, total_size)) => {
                swap(deserialized_response.body_mut(), &mut body);
                self.body_info = None;
                self.deserialized_size = total_size;
                Ok(true)
            }
        }
    }

    /// Hands out the response prepared by the last successful
    /// `check_deserialize_frame`, paired with the number of bytes it used.
    ///
    /// # Errors
    /// Returns `ErrorKind::Other` when no response has been prepared.
    fn deserialize_frame<'a>(
        &'a mut self,
        _data: &'a [u8],
    ) -> Result<crate::frame::SizedFrame<Self::DeserializedFrame<'a>>, Error> {
        let response = self
            .deserialized_response
            .take()
            .ok_or_else(|| Error::new(ErrorKind::Other, "no deserialized frame"))?;
        Ok(SizedFrame::new(response, self.deserialized_size))
    }
}
/// Stateless frame serializer that turns [`HttpRequest`] values into
/// HTTP/1.x bytes.
pub struct Http1RequestSerializer {}

impl Http1RequestSerializer {
    /// Creates a serializer; it holds no state.
    pub fn new() -> Self {
        Http1RequestSerializer {}
    }
}
impl SerializeFrame for Http1RequestSerializer {
type SerializedFrame<'a> = HttpRequest;
fn serialize_frame<'a>(
&mut self,
request: Self::SerializedFrame<'a>,
buffer: &mut GrowableCircleBuf,
) -> Result<PublishOutcome<Self::SerializedFrame<'a>>, Error> {
let serialized_request = match request {
HttpRequest::Request(request) => {
match request.version() {
hyperium_http::Version::HTTP_10 | hyperium_http::Version::HTTP_11 => {}
version => {
return Err(Error::new(
ErrorKind::InvalidData,
format!("unsupported http request version {version:?}").as_str(),
))
}
}
let host = match request.uri().host() {
Some(x) => x.to_owned(),
None => return Err(io::Error::new(ErrorKind::InvalidData, "missing host")),
};
let body = request.body();
let content_length = body.len().to_string();
let mut serialized_request = Vec::new();
serialized_request.extend_from_slice(request.method().as_str().as_bytes());
serialized_request.extend_from_slice(" ".as_bytes());
serialized_request.extend_from_slice(request.uri().path().as_bytes());
if let Some(query) = request.uri().query() {
serialized_request.extend_from_slice("?".as_bytes());
serialized_request.extend_from_slice(query.as_bytes());
}
serialized_request
.extend_from_slice(format!(" {:?}", request.version()).as_bytes());
serialized_request.extend_from_slice(LINE_BREAK.as_bytes());
{
serialized_request.extend_from_slice(HOST.as_str().as_bytes());
serialized_request.extend_from_slice(": ".as_bytes());
serialized_request.extend_from_slice(host.as_bytes());
serialized_request.extend_from_slice(LINE_BREAK.as_bytes());
}
for (n, v) in request.headers().iter() {
serialized_request.extend_from_slice(n.as_str().as_bytes());
serialized_request.extend_from_slice(": ".as_bytes());
serialized_request.extend_from_slice(
v.to_str()
.map_err(|_| {
Error::new(
ErrorKind::InvalidData,
format!("could not convert header '{}' to string", n.as_str())
.as_str(),
)
})?
.as_bytes(),
);
serialized_request.extend_from_slice(LINE_BREAK.as_bytes());
}
if body.len() > 0 {
serialized_request.extend_from_slice(CONTENT_LENGTH.as_str().as_bytes());
serialized_request.extend_from_slice(": ".as_bytes());
serialized_request.extend_from_slice(content_length.as_bytes());
serialized_request.extend_from_slice(LINE_BREAK.as_bytes());
}
serialized_request.extend_from_slice(LINE_BREAK.as_bytes());
serialized_request.extend_from_slice(body);
serialized_request
}
HttpRequest::Serialized(serialized) => serialized,
};
if buffer.try_write(&vec![&serialized_request])? {
Ok(PublishOutcome::Published)
} else {
Ok(PublishOutcome::Incomplete(HttpRequest::Serialized(
serialized_request,
)))
}
}
}
/// Copies the status code, version and headers of a parsed `httparse`
/// response head into `resp`.
///
/// Existing headers in `resp` are cleared first, so calling this again for
/// the same head gives the same result. Repeated header names (for example
/// several `Set-Cookie` lines) are all kept rather than the last one
/// overwriting the others.
///
/// # Errors
/// Returns `ErrorKind::InvalidData` for an invalid status code, an HTTP
/// minor version other than 0 or 1, or an invalid header name or value.
fn parsed_into_response(
    parsed: httparse::Response,
    resp: &mut hyperium_http::Response<Vec<u8>>,
) -> Result<(), Error> {
    if let Some(code) = parsed.code {
        let status_code = hyperium_http::StatusCode::try_from(code).map_err(|_| {
            Error::new(
                ErrorKind::InvalidData,
                format!("response invalid status code '{code}'").as_str(),
            )
        })?;
        *resp.status_mut() = status_code;
    }
    if let Some(version) = parsed.version {
        *resp.version_mut() = match version {
            0 => hyperium_http::Version::HTTP_10,
            1 => hyperium_http::Version::HTTP_11,
            _ => {
                return Err(Error::new(
                    ErrorKind::InvalidData,
                    format!("response invalid version '{version}'").as_str(),
                ))
            }
        };
    }
    resp.headers_mut().clear();
    for h in parsed.headers.iter() {
        let name = hyperium_http::header::HeaderName::from_str(h.name).map_err(|_| {
            Error::new(
                ErrorKind::InvalidData,
                format!("response invalid header name '{}'", h.name).as_str(),
            )
        })?;
        let value = hyperium_http::header::HeaderValue::from_bytes(h.value).map_err(|_| {
            Error::new(
                ErrorKind::InvalidData,
                format!("response invalid header value '{:?}'", h.value).as_str(),
            )
        })?;
        // `append` keeps every value of a repeated header; `insert` would
        // silently drop all but the last.
        resp.headers_mut().append(name, value);
    }
    Ok(())
}
/// Line terminator written after the request line and after each header line.
const LINE_BREAK: &str = "\r\n";
/// Returns an upper bound on the number of header lines in `payload`, used
/// to size the header scratch buffer before parsing.
///
/// Every line feed is counted, so heads terminated by bare `\n` are sized
/// correctly as well as those using `\r\n`. Counting only `\r\n` pairs would
/// give zero slots for a bare-LF head and make parsing fail. The count covers
/// the whole payload, so it may exceed the real header count.
fn count_max_headers(payload: &[u8]) -> usize {
    payload.iter().filter(|&&byte| byte == b'\n').count()
}
/// Returns `true` when the last bytes of `buf` equal the bytes of
/// `ends_with`. An empty `ends_with` always matches.
fn ends_with_ascii(buf: &[u8], ends_with: &str) -> bool {
    buf.ends_with(ends_with.as_bytes())
}
/// Reports whether the response uses chunked transfer encoding.
///
/// `Transfer-Encoding` is a comma-separated list whose final coding decides
/// the framing, so only the last element is inspected. It is compared
/// case-insensitively after trimming whitespace. This recognises values such
/// as `gzip, chunked` or ` Chunked `, which an exact match would miss.
fn parse_is_chunked(headers: &[httparse::Header]) -> bool {
    match find_header(headers, TRANSFER_ENCODING.as_str()) {
        Some(raw) => {
            let text = String::from_utf8_lossy(raw);
            let final_coding = text.rsplit(',').next().unwrap_or("");
            final_coding.trim().eq_ignore_ascii_case("chunked")
        }
        None => false,
    }
}
/// Extracts the `Content-Length` value from `headers`.
///
/// Whitespace around the value is ignored so that a padded value does not
/// fail to parse.
///
/// Returns `Ok(None)` when the header is absent.
///
/// # Errors
/// Returns `ErrorKind::InvalidData` when the value is not a number that
/// fits in `usize`.
fn parse_content_length(headers: &[httparse::Header]) -> Result<Option<usize>, Error> {
    let raw = match find_header(headers, CONTENT_LENGTH.as_str()) {
        Some(raw) => raw,
        None => return Ok(None),
    };
    let text = String::from_utf8_lossy(raw);
    let length = text
        .trim()
        .parse::<usize>()
        .map_err(|_| Error::new(ErrorKind::InvalidData, "content-length not a number"))?;
    Ok(Some(length))
}
/// Returns the value of the first header whose name matches `name`,
/// ignoring ASCII case, or `None` when no header matches.
fn find_header<'a>(headers: &'a [httparse::Header], name: &str) -> Option<&'a [u8]> {
    headers
        .iter()
        .find(|header| header.name.eq_ignore_ascii_case(name))
        .map(|header| header.value)
}