use std::{
io::{self, Error, ErrorKind, Read},
mem::swap,
str::FromStr,
};
use hyperium_http::header::{CONTENT_LENGTH, HOST, TRANSFER_ENCODING};
use tcp_stream::OwnedTLSConfig;
use crate::{
frame::{DeserializedFrame, FramingSession, FramingStrategy},
tcp::StreamingTcpSession,
Session, TlsSession, WriteStatus,
};
/// Connection type returned by `HttpClient::request`: a `FramingSession` over a
/// `StreamingTcpSession` whose traffic is framed by `Http1FramingStrategy`.
pub type HttpClientSession = FramingSession<StreamingTcpSession, Http1FramingStrategy>;
/// HTTP/1.x client that opens a fresh connection for every request.
pub struct HttpClient {
    /// TLS configuration used when a request targets an `https` URI.
    tls_config: OwnedTLSConfig,
}

impl HttpClient {
    /// Builds a client that will use `tls_config` for `https` requests.
    pub fn new(tls_config: OwnedTLSConfig) -> Self {
        HttpClient { tls_config }
    }

    /// Sends `request` on a new connection and returns the session from which
    /// the response can be read.
    ///
    /// The body is converted with [`IntoBody`]. The scheme decides how to
    /// connect:
    ///
    /// * `http`, or no scheme: plain text, default port 80.
    /// * `https`: TLS is negotiated for the host, default port 443.
    ///
    /// An explicit port in the URI overrides the default.
    ///
    /// # Errors
    ///
    /// * `InvalidData` if the scheme is neither `http` nor `https`, or if the
    ///   URI has no host.
    /// * Any error from connecting, switching to non-blocking mode, the TLS
    ///   upgrade, or writing the request.
    /// * `Other` if the write reports it is still pending. The request is
    ///   expected to be buffered in full.
    pub fn request<I: IntoBody>(
        &mut self,
        request: hyperium_http::Request<I>,
    ) -> Result<HttpClientSession, io::Error> {
        let request = request.map(I::into_body);
        let uri = request.uri();

        let https = match uri.scheme_str() {
            None | Some("http") => false,
            Some("https") => true,
            _ => return Err(io::Error::new(ErrorKind::InvalidData, "bad uri scheme")),
        };
        let host = uri
            .host()
            .map(str::to_owned)
            .ok_or_else(|| io::Error::new(ErrorKind::InvalidData, "missing host"))?;
        let default_port = if https { 443 } else { 80 };
        let port = uri.port_u16().unwrap_or(default_port);

        let address = format!("{host}:{port}");
        let tcp = StreamingTcpSession::connect(&address)?.with_nonblocking(true)?;
        let mut conn = FramingSession::new(tcp, Http1FramingStrategy::new(), 0);
        if https {
            conn.to_tls(&host, self.tls_config.as_ref())?;
        }

        // The session must accept the whole serialized request in a single write.
        if let WriteStatus::Pending(_) = conn.write(&request)? {
            return Err(Error::new(
                ErrorKind::Other,
                "http payload should have buffered",
            ));
        }
        Ok(conn)
    }
}
/// Conversion of a request body into the owned bytes written on the wire.
///
/// `HttpClient::request` accepts any `hyperium_http::Request<I>` whose body
/// type `I` implements this trait.
pub trait IntoBody {
    /// Consumes `self` and yields the raw body bytes.
    fn into_body(self) -> Vec<u8>;
}

/// An owned byte vector already is a body; it is moved through unchanged.
impl IntoBody for Vec<u8> {
    fn into_body(self) -> Vec<u8> {
        self
    }
}

/// Borrowed bytes are copied into a newly allocated vector.
impl IntoBody for &[u8] {
    fn into_body(self) -> Vec<u8> {
        Vec::from(self)
    }
}

/// The unit type stands for "no body" and yields an empty vector.
impl IntoBody for () {
    fn into_body(self) -> Vec<u8> {
        Vec::default()
    }
}
/// How the length of a response body is determined once its head is parsed.
enum BodyType {
    /// `Content-Length` framing: exactly this many body bytes follow the head.
    ContentLength(usize),
    /// `Transfer-Encoding: chunked` framing: the body is a sequence of chunks
    /// that must be decoded.
    ChunkedTransfer,
    /// No length information: the body runs until the connection reaches EOF.
    OnClose,
    /// The response carries no body; an empty body is produced.
    None,
}

/// Where the body of the response in progress starts and how it is framed.
struct BodyInfo {
    /// Byte offset of the first body byte, i.e. the size of the parsed head.
    offset: usize,
    /// Framing rule for the body that starts at `offset`.
    ty: BodyType,
}

impl BodyInfo {
    /// Pairs a body start `offset` with its framing rule `ty`.
    pub fn new(offset: usize, ty: BodyType) -> Self {
        BodyInfo { ty, offset }
    }
}
/// `FramingStrategy` for HTTP/1.0 and HTTP/1.1 client traffic: it writes
/// `hyperium_http::Request`s and reads back `hyperium_http::Response`s.
pub struct Http1FramingStrategy {
    /// Response being assembled from the read buffer. Its headers are filled
    /// in when the head parses; its body is swapped in when it completes.
    deserialized_response: hyperium_http::Response<Vec<u8>>,
    /// Byte count handed to `DeserializedFrame::new` alongside the response.
    deserialized_size: usize,
    /// Head-parse result for the response in progress; `None` until the
    /// status line and headers have been parsed.
    body_info: Option<BodyInfo>,
    /// Wire bytes of the most recently serialized request. The slice returned
    /// by `serialize_frame` borrows from this buffer.
    serialized_request: Vec<u8>,
}
impl Http1FramingStrategy {
    /// Creates a strategy with no serialized request and no response in progress.
    pub fn new() -> Self {
        Self {
            serialized_request: Vec::new(),
            deserialized_response: hyperium_http::Response::new(Vec::new()),
            deserialized_size: 0,
            body_info: None,
        }
    }
}

/// Same as [`Http1FramingStrategy::new`]. Provided because `new` takes no
/// arguments (clippy `new_without_default`).
impl Default for Http1FramingStrategy {
    fn default() -> Self {
        Self::new()
    }
}
impl FramingStrategy for Http1FramingStrategy {
type ReadFrame = hyperium_http::Response<Vec<u8>>;
type WriteFrame = hyperium_http::Request<Vec<u8>>;
fn check_deserialize_frame(&mut self, data: &[u8], eof: bool) -> Result<bool, Error> {
let header_count: usize = count_max_headers(data);
let mut headers = Vec::new();
headers.resize(header_count, httparse::EMPTY_HEADER);
if self.body_info.is_none() {
let mut parsed = httparse::Response::new(&mut headers);
match parsed.parse(data).map_err(|err| {
Error::new(
ErrorKind::InvalidData,
format!("http response parse failed: {err:?}").as_str(),
)
})? {
httparse::Status::Complete(size) => {
if parse_is_chunked(&parsed.headers) {
self.body_info = Some(BodyInfo::new(size, BodyType::ChunkedTransfer));
} else if let Some(content_length) = parse_content_length(&parsed.headers)? {
self.body_info =
Some(BodyInfo::new(size, BodyType::ContentLength(content_length)));
} else if parsed.version.is_none() || parsed.version == Some(1) {
self.body_info = Some(BodyInfo::new(size, BodyType::OnClose));
} else {
self.body_info = Some(BodyInfo::new(size, BodyType::None));
}
parsed_into_response(parsed, &mut self.deserialized_response)?;
}
httparse::Status::Partial => return Ok(false),
}
}
let parsed_body = match &self.body_info {
None => None,
Some(body_info) => {
match body_info.ty {
BodyType::ChunkedTransfer => {
if body_info.offset < data.len() && ends_with_ascii(data, "\r\n\r\n") {
let mut body = Vec::new();
let mut decoder =
chunked_transfer::Decoder::new(&data[body_info.offset..]);
decoder.read_to_end(&mut body)?;
match decoder.remaining_chunks_size() {
None => Some(body),
Some(_) => None,
}
} else {
None
}
}
BodyType::ContentLength(content_length) => {
let total_length = body_info.offset + content_length;
if data.len() >= total_length {
Some(data[body_info.offset..total_length].to_vec())
} else {
None
}
}
BodyType::OnClose => {
if eof {
Some(data[body_info.offset..].to_vec())
} else {
None
}
}
BodyType::None => Some(Vec::new()),
}
}
};
match parsed_body {
None => {
if eof {
Err(Error::new(
ErrorKind::UnexpectedEof,
"http connection terminated before receiving full response",
))
} else {
Ok(false)
}
}
Some(mut body) => {
swap(self.deserialized_response.body_mut(), &mut body);
Ok(true)
}
}
}
fn deserialize_frame<'a>(
&'a mut self,
_data: &'a [u8],
) -> Result<crate::frame::DeserializedFrame<'a, Self::ReadFrame>, Error> {
Ok(DeserializedFrame::new(
&self.deserialized_response,
self.deserialized_size,
))
}
fn serialize_frame<'a>(
&'a mut self,
request: &'a Self::WriteFrame,
) -> Result<Vec<&'a [u8]>, Error> {
match request.version() {
hyperium_http::Version::HTTP_10 | hyperium_http::Version::HTTP_11 => {}
version => {
return Err(Error::new(
ErrorKind::InvalidData,
format!("unsupported http request version {version:?}").as_str(),
))
}
}
let host = match request.uri().host() {
Some(x) => x.to_owned(),
None => return Err(io::Error::new(ErrorKind::InvalidData, "missing host")),
};
let body = request.body();
let content_length = body.len().to_string();
self.serialized_request = Vec::new();
self.serialized_request
.extend_from_slice(request.method().as_str().as_bytes());
self.serialized_request.extend_from_slice(" ".as_bytes());
self.serialized_request
.extend_from_slice(request.uri().path().as_bytes());
self.serialized_request
.extend_from_slice(format!(" {:?}", request.version()).as_bytes());
self.serialized_request
.extend_from_slice(LINE_BREAK.as_bytes());
{
self.serialized_request
.extend_from_slice(HOST.as_str().as_bytes());
self.serialized_request.extend_from_slice(": ".as_bytes());
self.serialized_request.extend_from_slice(host.as_bytes());
self.serialized_request
.extend_from_slice(LINE_BREAK.as_bytes());
}
for (n, v) in request.headers().iter() {
self.serialized_request
.extend_from_slice(n.as_str().as_bytes());
self.serialized_request.extend_from_slice(": ".as_bytes());
self.serialized_request.extend_from_slice(
v.to_str()
.map_err(|_| {
Error::new(
ErrorKind::InvalidData,
format!("could not convert header '{}' to string", n.as_str()).as_str(),
)
})?
.as_bytes(),
);
self.serialized_request
.extend_from_slice(LINE_BREAK.as_bytes());
}
if body.len() > 0 {
self.serialized_request
.extend_from_slice(CONTENT_LENGTH.as_str().as_bytes());
self.serialized_request.extend_from_slice(": ".as_bytes());
self.serialized_request
.extend_from_slice(content_length.as_bytes());
self.serialized_request
.extend_from_slice(LINE_BREAK.as_bytes());
}
self.serialized_request
.extend_from_slice(LINE_BREAK.as_bytes());
self.serialized_request.extend_from_slice(body);
Ok(vec![&self.serialized_request])
}
}
/// Copies the status code, version and headers of a parsed response head into `resp`.
///
/// The body of `resp` is left untouched. Repeated header fields (for example
/// `Set-Cookie`) are all kept, in wire order.
///
/// # Errors
///
/// `InvalidData` when the status code, version (anything but HTTP/1.0 or
/// HTTP/1.1), a header name or a header value cannot be represented.
fn parsed_into_response(
    parsed: httparse::Response,
    resp: &mut hyperium_http::Response<Vec<u8>>,
) -> Result<(), Error> {
    if let Some(code) = parsed.code {
        let status_code = hyperium_http::StatusCode::try_from(code).map_err(|_| {
            Error::new(
                ErrorKind::InvalidData,
                format!("response invalid status code '{code}'").as_str(),
            )
        })?;
        *resp.status_mut() = status_code;
    }
    if let Some(version) = parsed.version {
        *resp.version_mut() = match version {
            0 => hyperium_http::Version::HTTP_10,
            1 => hyperium_http::Version::HTTP_11,
            _ => {
                return Err(Error::new(
                    ErrorKind::InvalidData,
                    format!("response invalid version '{version}'").as_str(),
                ))
            }
        };
    }
    for h in parsed.headers.iter() {
        let name = hyperium_http::HeaderName::from_str(h.name).map_err(|_| {
            Error::new(
                ErrorKind::InvalidData,
                format!("response invalid header name '{}'", h.name).as_str(),
            )
        })?;
        // The value comes from the raw value bytes, which need not be UTF-8.
        let value = hyperium_http::HeaderValue::from_bytes(h.value).map_err(|_| {
            Error::new(
                ErrorKind::InvalidData,
                format!("response invalid header value for '{}'", h.name).as_str(),
            )
        })?;
        // `append` rather than `insert`: repeated fields must not overwrite each other.
        resp.headers_mut().append(name, value);
    }
    Ok(())
}
/// CRLF sequence that terminates the request line and each header line, and
/// that separates the header section from the body.
const LINE_BREAK: &str = "\r\n";
/// Counts CRLF pairs in `payload`.
///
/// Every header line ends in CRLF, so the result is an upper bound on the
/// number of header slots `httparse` needs for this buffer.
fn count_max_headers(payload: &[u8]) -> usize {
    payload
        .windows(2)
        .filter(|pair| *pair == b"\r\n")
        .count()
}
/// Returns `true` when the last bytes of `buf` equal the bytes of `ends_with`.
///
/// An empty `ends_with` always matches; a `buf` shorter than `ends_with`
/// never does.
fn ends_with_ascii(buf: &[u8], ends_with: &str) -> bool {
    let suffix = ends_with.as_bytes();
    buf.ends_with(suffix)
}
/// Reports whether the response body uses chunked transfer framing.
///
/// `Transfer-Encoding` is a comma-separated list of codings, and `chunked`,
/// when present, must be the last one (RFC 7230 §3.3.1). The final list
/// element therefore decides: `gzip, chunked` is chunked framing. The
/// comparison ignores ASCII case and surrounding whitespace.
fn parse_is_chunked(headers: &[httparse::Header]) -> bool {
    match find_header(headers, TRANSFER_ENCODING.as_str()) {
        Some(v) => String::from_utf8_lossy(v)
            .rsplit(',')
            .next()
            .map_or(false, |last| last.trim().eq_ignore_ascii_case("chunked")),
        None => false,
    }
}
/// Reads the `Content-Length` header, if present, as a byte count.
///
/// Returns `Ok(None)` when the header is absent.
///
/// # Errors
///
/// `InvalidData` when the header value does not parse as a `usize`.
fn parse_content_length(headers: &[httparse::Header]) -> Result<Option<usize>, Error> {
    find_header(headers, CONTENT_LENGTH.as_str())
        .map(|raw| {
            String::from_utf8_lossy(raw)
                .parse::<usize>()
                .map_err(|_| Error::new(ErrorKind::InvalidData, "content-length not a number"))
        })
        .transpose()
}
/// Returns the value of the first header whose name matches `name`,
/// ignoring ASCII case, or `None` if there is no such header.
fn find_header<'a>(headers: &'a [httparse::Header], name: &str) -> Option<&'a [u8]> {
    headers
        .iter()
        .find(|header| header.name.eq_ignore_ascii_case(name))
        .map(|header| header.value)
}
// Integration tests against public hosts (www.google.com, icanhazip.com);
// they need outbound network access to pass.
#[cfg(test)]
mod test {
    use std::{net::Ipv4Addr, str::FromStr};
    use hyperium_http::{Request, StatusCode};
    use tcp_stream::OwnedTLSConfig;
    use crate::{ReadStatus, Session};
    use super::HttpClient;

    // HTTPS GET to www.google.com; expects a 200 whose body text ends with "</html>".
    #[test]
    fn test_google_chunked_response() {
        let mut client = HttpClient::new(OwnedTLSConfig::default());
        let mut conn = client
            .request(Request::get("https://www.google.com").body(()).unwrap())
            .unwrap();
        // Busy-poll: drive the connection, then check whether a full response is available.
        loop {
            conn.drive().unwrap();
            if let ReadStatus::Data(r) = conn.read().unwrap() {
                assert_eq!(r.status(), StatusCode::OK);
                assert!(String::from_utf8_lossy(r.body()).ends_with("</html>"));
                break;
            }
        }
    }

    // Plain-HTTP GET to icanhazip.com; expects a 200 whose trimmed body parses as an IPv4 address.
    #[test]
    fn test_simple_response() {
        let mut client = HttpClient::new(OwnedTLSConfig::default());
        let mut conn = client
            .request(Request::get("http://icanhazip.com").body(()).unwrap())
            .unwrap();
        // Busy-poll: drive the connection, then check whether a full response is available.
        loop {
            conn.drive().unwrap();
            if let ReadStatus::Data(r) = conn.read().unwrap() {
                assert_eq!(r.status(), StatusCode::OK);
                let body = String::from_utf8_lossy(r.body());
                Ipv4Addr::from_str(body.trim()).expect("IP V4 address as body");
                break;
            }
        }
    }
}