use std::{
cell::RefCell,
collections::{HashMap, VecDeque},
fmt::Debug,
io::{self, Error, ErrorKind, Read},
mem::swap,
rc::Rc,
str::FromStr,
sync::{Arc, Mutex},
time::Duration,
};
use chrono::{DateTime, Utc};
use hyperium_http::{
Response,
header::{CONNECTION, CONTENT_LENGTH, HOST, TRANSFER_ENCODING},
};
use tcp_stream::OwnedTLSConfig;
use crate::{
DriveOutcome, Flush, Publish, PublishOutcome, Receive, ReceiveOutcome, Session, SessionStatus,
buffer::GrowableCircleBuf,
dns::AddrResolver,
frame::{DeserializeFrame, FrameDuplex, SerializeFrame, SizedFrame},
tcp::TcpSession,
tls::{NativeTlsConnector, TlsConnector},
};
/// URI scheme understood by the client.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Scheme {
    /// Plain-text HTTP.
    Http,
    /// HTTP over a TLS-upgraded connection.
    Https,
}

impl Scheme {
    /// Well-known TCP port for this scheme: 80 for HTTP, 443 for HTTPS.
    pub fn default_port(&self) -> u16 {
        if let Scheme::Https = self {
            443
        } else {
            80
        }
    }
}
pub enum HttpRequest {
Request(hyperium_http::Request<Vec<u8>>),
Serialized(Vec<u8>),
}
impl<I: IntoBody> From<hyperium_http::Request<I>> for HttpRequest {
fn from(value: hyperium_http::Request<I>) -> Self {
let (parts, body) = value.into_parts();
HttpRequest::Request(hyperium_http::Request::from_parts(parts, body.into_body()))
}
}
/// HTTP/1.x client that opens connections to origin servers, or reuses them when a
/// pool is configured.
///
/// Optional collaborators are attached with the `with_*` builder methods: a TLS
/// connector, an address resolver and a connection pool. When one is absent, `None`
/// is passed through to `TcpSession::connect`.
pub struct HttpClient {
    tls_connector: Option<Arc<TlsConnector>>,
    addr_resolver: Option<Arc<AddrResolver>>,
    pool: Option<Arc<HttpClientSessionPool>>,
}

impl HttpClient {
    /// Creates a client without a TLS connector, custom resolver or connection pool.
    pub fn new() -> Self {
        Self {
            tls_connector: None,
            addr_resolver: None,
            pool: None,
        }
    }

    /// Enables connection pooling.
    ///
    /// * `max_connections_per_domain`: cap on pooled connections per
    ///   `(scheme, host, port)`.
    /// * `max_connections_total`: cap on pooled connections overall.
    /// * `default_keep_alive_timeout`: idle lifetime used when the server's
    ///   `Keep-Alive` header does not specify `timeout=`.
    pub fn with_connection_pool(
        mut self,
        max_connections_per_domain: usize,
        max_connections_total: usize,
        default_keep_alive_timeout: Duration,
    ) -> Self {
        self.pool = Some(Arc::new(HttpClientSessionPool::new(
            max_connections_per_domain,
            max_connections_total,
            default_keep_alive_timeout,
        )));
        self
    }

    /// Uses `addr_resolver` for host name resolution.
    pub fn with_addr_resolver(mut self, addr_resolver: Arc<AddrResolver>) -> Self {
        self.addr_resolver = Some(addr_resolver);
        self
    }

    /// Uses `tls_connector` for HTTPS connections.
    pub fn with_tls_connector(mut self, tls_connector: Arc<TlsConnector>) -> Self {
        self.tls_connector = Some(tls_connector);
        self
    }

    /// Builds a native TLS connector from `tls_config` and uses it for HTTPS.
    ///
    /// # Errors
    /// Propagates the error from creating the native connector.
    pub fn with_tls_config(mut self, tls_config: OwnedTLSConfig) -> Result<Self, Error> {
        self.tls_connector = Some(Arc::new(TlsConnector::Native(NativeTlsConnector::new(
            tls_config.as_ref(),
            false,
        )?)));
        Ok(self)
    }

    /// Returns a session to `host:port`.
    ///
    /// A pooled connection for the same `(scheme, host, port)` is reused when one is
    /// available. Otherwise a new TCP connection is opened and upgraded to TLS for
    /// [`Scheme::Https`].
    ///
    /// # Errors
    /// Any error from connecting or from the TLS upgrade.
    pub fn connect(
        &self,
        host: &str,
        port: u16,
        scheme: Scheme,
    ) -> Result<HttpClientSession, io::Error> {
        if let Some(pool) = self.pool.as_ref() {
            if let Some(conn) = pool.try_check_out(scheme, host.to_owned(), port) {
                return Ok(HttpClientSession::new_with_pool(
                    conn,
                    Some((Arc::clone(pool), (scheme, host.to_owned(), port))),
                    true,
                ));
            }
        }
        let mut conn = TcpSession::connect(
            format!("{host}:{port}"),
            self.addr_resolver.clone(),
            self.tls_connector.clone(),
        )?;
        if scheme == Scheme::Https {
            conn = conn.into_tls(&host)?;
        }
        let frames = FrameDuplex::new(
            conn,
            Http1ResponseDeserializer::new(),
            Http1RequestSerializer::new(),
            0,
        );
        let session_pool = self
            .pool
            .as_ref()
            .map(|pool| (Arc::clone(pool), (scheme, host.to_owned(), port)));
        Ok(HttpClientSession::new_with_pool(frames, session_pool, false))
    }

    /// Connects to the server named by `request`'s URI and queues `request` to be
    /// written as soon as the connection is established.
    ///
    /// The scheme defaults to `http` when the URI has none. The port defaults to the
    /// scheme's well-known port.
    ///
    /// # Errors
    /// * `InvalidData` for a scheme other than `http`/`https`.
    /// * `InvalidData` for a URI without a host.
    /// * Any error from [`HttpClient::connect`].
    pub fn request<I: IntoBody>(
        &self,
        request: hyperium_http::Request<I>,
    ) -> Result<HttpClientSession, io::Error> {
        let (parts, body) = request.into_parts();
        let request = hyperium_http::Request::from_parts(parts, body.into_body());
        let uri = request.uri();
        let scheme = match uri.scheme_str() {
            None | Some("http") => Scheme::Http,
            Some("https") => Scheme::Https,
            _ => {
                return Err(io::Error::new(
                    ErrorKind::InvalidData,
                    "bad http uri scheme",
                ));
            }
        };
        // Without a host there is nothing to connect to. Fail clearly instead of
        // trying to resolve ":<port>".
        let host = uri
            .host()
            .ok_or_else(|| io::Error::new(ErrorKind::InvalidData, "missing host"))?;
        let port = uri.port_u16().unwrap_or_else(|| scheme.default_port());
        let mut conn = self.connect(host, port, scheme)?;
        conn.pending_initial_request = Some(request.into());
        Ok(conn)
    }
}

impl Default for HttpClient {
    fn default() -> Self {
        Self::new()
    }
}
/// Mutex-protected connection pool, shared via `Arc` by an `HttpClient` and the
/// sessions it hands out.
///
/// Both operations are best-effort. If the mutex is poisoned, check-in drops the
/// offered connection and check-out reports that nothing is available.
struct HttpClientSessionPool {
    context: Mutex<HttpClientSessionPoolContext>,
}

impl HttpClientSessionPool {
    /// Creates an empty pool with the given limits and default keep-alive timeout.
    pub fn new(
        max_connections_per_domain: usize,
        max_connections_total: usize,
        default_keep_alive_timeout: Duration,
    ) -> Self {
        let context = HttpClientSessionPoolContext::new(
            max_connections_per_domain,
            max_connections_total,
            default_keep_alive_timeout,
        );
        Self {
            context: Mutex::new(context),
        }
    }

    /// Offers `session` back to the pool. The pool context decides whether to keep it.
    pub fn try_check_in(
        &self,
        scheme: Scheme,
        host: String,
        port: u16,
        session: FrameDuplex<TcpSession, Http1ResponseDeserializer, Http1RequestSerializer>,
        response_connection_header_value: String,
        response_keep_alive_header_value: Option<String>,
    ) {
        match self.context.lock() {
            Ok(mut guard) => guard.try_check_in(
                scheme,
                host,
                port,
                session,
                response_connection_header_value,
                response_keep_alive_header_value,
            ),
            Err(_) => {}
        }
    }

    /// Takes an idle connection for `(scheme, host, port)` out of the pool, if any.
    pub fn try_check_out(
        &self,
        scheme: Scheme,
        host: String,
        port: u16,
    ) -> Option<FrameDuplex<TcpSession, Http1ResponseDeserializer, Http1RequestSerializer>> {
        self.context
            .lock()
            .ok()
            .and_then(|mut guard| guard.try_check_out(scheme, host, port))
    }
}
struct PoolEntry {
session: FrameDuplex<TcpSession, Http1ResponseDeserializer, Http1RequestSerializer>,
expires: DateTime<Utc>,
}
struct HttpClientSessionPoolContext {
max_connections_per_domain: usize,
max_connections_total: usize,
default_keep_alive_timeout: Duration,
domain_sessions: HashMap<(Scheme, String, u16), VecDeque<PoolEntry>>,
cur_connections_total: usize,
}
impl HttpClientSessionPoolContext {
pub fn new(
max_connections_per_domain: usize,
max_connections_total: usize,
default_keep_alive_timeout: Duration,
) -> Self {
Self {
max_connections_per_domain,
max_connections_total,
default_keep_alive_timeout,
domain_sessions: HashMap::new(),
cur_connections_total: 0,
}
}
pub fn try_check_in(
&mut self,
scheme: Scheme,
host: String,
port: u16,
session: FrameDuplex<TcpSession, Http1ResponseDeserializer, Http1RequestSerializer>,
response_connection_header_value: String,
response_keep_alive_header_value: Option<String>,
) {
self.cleanup();
if !response_connection_header_value.eq_ignore_ascii_case("Keep-Alive") {
return;
}
let mut keep_alive_timeout = self.default_keep_alive_timeout;
let mut keep_alive_max = None;
if let Some(response_keep_alive_header_value) = response_keep_alive_header_value {
for part in response_keep_alive_header_value
.split(",")
.map(|x| x.trim())
{
if let [key, value] = part
.split("=")
.map(|x| x.trim())
.collect::<Vec<_>>()
.as_slice()
{
if key.eq_ignore_ascii_case("timeout") {
if let Ok(value) = value.parse::<u64>() {
keep_alive_timeout = Duration::from_secs(value)
}
} else if key.eq_ignore_ascii_case("max") {
if let Ok(value) = value.parse::<u64>() {
keep_alive_max = Some(value)
}
}
}
}
}
if keep_alive_max == Some(0) {
return;
}
let expires = Utc::now() + keep_alive_timeout - Duration::from_millis(250);
if self.cur_connections_total >= self.max_connections_total {
return;
}
let domain_sessions = self
.domain_sessions
.entry((scheme, host, port))
.or_default();
if domain_sessions.len() >= self.max_connections_per_domain {
return;
}
self.cur_connections_total += 1;
domain_sessions.push_back(PoolEntry { session, expires });
}
pub fn try_check_out(
&mut self,
scheme: Scheme,
host: String,
port: u16,
) -> Option<FrameDuplex<TcpSession, Http1ResponseDeserializer, Http1RequestSerializer>> {
self.cleanup();
let session = self
.domain_sessions
.get_mut(&(scheme, host, port))?
.pop_front()
.map(|x| x.session);
if session.is_some() {
self.cur_connections_total -= 1;
}
session
}
pub fn cleanup(&mut self) {
let now = Utc::now();
for (_, domain_sessions) in self.domain_sessions.iter_mut() {
domain_sessions.retain(|x| now < x.expires);
}
self.domain_sessions.retain(|_, x| !x.is_empty());
}
}
/// Opens a TCP connection to `host` and upgrades it to TLS for [`Scheme::Https`].
///
/// `port` defaults to the scheme's well-known port.
///
/// # Errors
/// * `InvalidData` when `host` is `None`.
/// * Errors from `TcpSession::connect`.
/// * A TLS upgrade failure, reported as `ConnectionRefused`.
pub(crate) fn connect_stream(
    scheme: Scheme,
    host: Option<&str>,
    port: Option<u16>,
    addr_resolver: Option<Arc<AddrResolver>>,
    tls_connector: Option<Arc<TlsConnector>>,
) -> Result<TcpSession, Error> {
    let host = host
        .map(str::to_owned)
        .ok_or_else(|| io::Error::new(ErrorKind::InvalidData, "missing host"))?;
    let port = port.unwrap_or_else(|| scheme.default_port());
    let stream = TcpSession::connect(format!("{host}:{port}"), addr_resolver, tls_connector)?;
    match scheme {
        Scheme::Https => stream
            .into_tls(&host)
            .map_err(|err| Error::new(ErrorKind::ConnectionRefused, err)),
        Scheme::Http => Ok(stream),
    }
}
/// One HTTP/1.x client connection, optionally tied to a connection pool.
///
/// For pooled sessions, the `Connection`/`Keep-Alive` header values of the last
/// received response are remembered. They are used to decide whether to return
/// the connection to the pool when the session is dropped.
pub struct HttpClientSession {
    session: Option<FrameDuplex<TcpSession, Http1ResponseDeserializer, Http1RequestSerializer>>,
    /// Request queued by `HttpClient::request`, written once the transport is established.
    pending_initial_request: Option<HttpRequest>,
    /// Pool handle plus the `(scheme, host, port)` key this connection belongs to.
    session_pool: Option<(Arc<HttpClientSessionPool>, (Scheme, String, u16))>,
    response_connection_header_value: Option<String>,
    response_keep_alive_header_value: Option<String>,
    /// `true` when the connection was taken from the pool rather than freshly opened.
    pooled: bool,
}

impl HttpClientSession {
    /// Wraps an already-built transport, with no pool and no queued request.
    pub fn new(
        session: FrameDuplex<TcpSession, Http1ResponseDeserializer, Http1RequestSerializer>,
    ) -> Self {
        Self::new_with_pool(session, None, false)
    }

    /// Wraps a transport, optionally associating it with a pool key.
    fn new_with_pool(
        session: FrameDuplex<TcpSession, Http1ResponseDeserializer, Http1RequestSerializer>,
        session_pool: Option<(Arc<HttpClientSessionPool>, (Scheme, String, u16))>,
        pooled: bool,
    ) -> Self {
        Self {
            pooled,
            session_pool,
            session: Some(session),
            pending_initial_request: None,
            response_connection_header_value: None,
            response_keep_alive_header_value: None,
        }
    }

    /// Whether this session reuses a connection checked out of the pool.
    pub fn is_pooled(&self) -> bool {
        self.pooled
    }
}
impl Drop for HttpClientSession {
fn drop(&mut self) {
if let (
Some(session),
Some((pool, (scheme, host, port))),
Some(response_connection_header_value),
) = (
self.session.take(),
self.session_pool.take(),
self.response_connection_header_value.take(),
) {
pool.try_check_in(
scheme,
host,
port,
session,
response_connection_header_value,
self.response_keep_alive_header_value.take(),
);
}
}
}
impl Session for HttpClientSession {
    /// Status of the underlying transport, or `Terminated` once it has been taken.
    fn status(&self) -> SessionStatus {
        self.session
            .as_ref()
            .map_or(SessionStatus::Terminated, |frames| frames.status())
    }

    /// Drives the transport. Once it is established, writes the queued initial
    /// request, reporting `Active` when that write completes.
    fn drive(&mut self) -> Result<DriveOutcome, Error> {
        let Some(frames) = self.session.as_mut() else {
            return Err(Error::new(ErrorKind::NotConnected, "FrameDuplex closed"));
        };
        let mut outcome = frames.drive()?;
        if frames.status() != SessionStatus::Established {
            return Ok(outcome);
        }
        if let Some(initial) = self.pending_initial_request.take() {
            match frames.publish(initial)? {
                PublishOutcome::Published => outcome = DriveOutcome::Active,
                // Not fully written yet: keep it queued for the next drive.
                PublishOutcome::Incomplete(unsent) => {
                    self.pending_initial_request = Some(unsent)
                }
            }
        }
        Ok(outcome)
    }
}
impl Receive for HttpClientSession {
    type ReceivePayload<'a> = hyperium_http::Response<Vec<u8>>;

    /// Drives the connection and tries to read one complete response.
    ///
    /// Reading only starts once any initial request has been written and the
    /// transport is established. For pooled sessions, the `Connection` and
    /// `Keep-Alive` header values of a received response are stored for later
    /// pool check-in.
    fn receive<'a>(&'a mut self) -> Result<crate::ReceiveOutcome<Self::ReceivePayload<'a>>, Error> {
        self.drive()?;
        let ready = self.pending_initial_request.is_none()
            && self.status() == SessionStatus::Established;
        if !ready {
            return Ok(crate::ReceiveOutcome::Idle);
        }
        let Some(frames) = self.session.as_mut() else {
            return Err(Error::new(ErrorKind::NotConnected, "FrameDuplex closed"));
        };
        let outcome = frames.receive()?;
        if let (Some(_), ReceiveOutcome::Payload(response)) = (&self.session_pool, &outcome) {
            let headers = response.headers();
            let connection = headers
                .get(CONNECTION)
                .and_then(|value| value.to_str().ok())
                .map(str::to_owned);
            let keep_alive = headers
                .get("Keep-Alive")
                .and_then(|value| value.to_str().ok())
                .map(str::to_owned);
            self.response_connection_header_value = connection;
            self.response_keep_alive_header_value = keep_alive;
        }
        Ok(outcome)
    }
}
impl Publish for HttpClientSession {
type PublishPayload<'a> = HttpRequest;
fn publish<'a>(
&mut self,
data: Self::PublishPayload<'a>,
) -> Result<PublishOutcome<Self::PublishPayload<'a>>, Error> {
let mut data = data;
if self.session_pool.is_some() {
if let HttpRequest::Request(req) = &mut data {
if let Ok(value) = "Keep-Alive".parse() {
req.headers_mut().insert(CONNECTION, value);
}
}
}
match self.session.as_mut() {
Some(x) => x.publish(data),
None => Err(Error::new(ErrorKind::NotConnected, "FrameDuplex closed")),
}
}
}
impl Flush for HttpClientSession {
    /// Flushes the underlying transport.
    ///
    /// Fails with `NotConnected` once the transport has been taken.
    fn flush(&mut self) -> Result<(), Error> {
        self.session
            .as_mut()
            .ok_or_else(|| Error::new(ErrorKind::NotConnected, "FrameDuplex closed"))?
            .flush()
    }
}

impl Debug for HttpClientSession {
    /// Shows only the underlying transport.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("HttpClientSession");
        out.field("session", &self.session);
        out.finish()
    }
}
/// State shared by a `PersistentHttpConnection` and the `PendingHttpResponse`s it issues.
///
/// Requests are numbered in issue order via `next_request_id`. Only the request
/// whose id equals `active_request_id` may use the connection, so request/response
/// exchanges are serialized.
struct PersistentHttpSessionContext {
    session: HttpClientSession,
    active_request_id: u64,
    next_request_id: u64,
    /// Set once a response indicated the server is closing the connection.
    closed: bool,
}

/// A single HTTP connection on which requests are issued one after another.
pub struct PersistentHttpConnection {
    context: Rc<RefCell<PersistentHttpSessionContext>>,
}

impl From<HttpClientSession> for PersistentHttpConnection {
    /// Wraps `session`; numbering of requests starts at 0.
    fn from(session: HttpClientSession) -> Self {
        let context = PersistentHttpSessionContext {
            session,
            active_request_id: 0,
            next_request_id: 0,
            closed: false,
        };
        Self {
            context: Rc::new(RefCell::new(context)),
        }
    }
}

impl Debug for PersistentHttpConnection {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("PersistentHttpConnection").finish()
    }
}

impl Session for PersistentHttpConnection {
    /// Status of the shared underlying session.
    fn status(&self) -> crate::SessionStatus {
        let context = self.context.borrow();
        context.session.status()
    }

    /// Drives the shared underlying session.
    fn drive(&mut self) -> Result<DriveOutcome, Error> {
        let mut context = self.context.borrow_mut();
        context.session.drive()
    }
}
impl PersistentHttpConnection {
pub fn request<I: IntoBody>(
&mut self,
mut request: hyperium_http::Request<I>,
) -> Result<PendingHttpResponse, Error> {
request
.headers_mut()
.insert("Connection", "keep-alive".parse().unwrap());
let mut context = self.context.borrow_mut();
if context.closed {
return Err(Error::new(
ErrorKind::ConnectionRefused,
"connection closed by server",
));
}
let request_id = context.next_request_id;
context.next_request_id += 1;
drop(context);
Ok(PendingHttpResponse {
context: Rc::clone(&self.context),
pending_request: Some(request.into()),
request_id,
})
}
}
/// Handle for one request issued on a `PersistentHttpConnection`.
pub struct PendingHttpResponse {
    context: Rc<RefCell<PersistentHttpSessionContext>>,
    /// The request, until it has been fully written to the connection.
    pending_request: Option<HttpRequest>,
    /// Position of this request in the connection's issue order.
    request_id: u64,
}

impl Debug for PendingHttpResponse {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("PendingHttpResponse");
        out.field("request_id", &self.request_id);
        out.finish()
    }
}

impl Session for PendingHttpResponse {
    /// Status of the shared underlying session.
    fn status(&self) -> crate::SessionStatus {
        let context = self.context.borrow();
        context.session.status()
    }

    /// Drives the shared session only while this request owns the connection.
    /// Otherwise reports `Idle`.
    fn drive(&mut self) -> Result<DriveOutcome, Error> {
        let mut context = self.context.borrow_mut();
        if context.active_request_id != self.request_id {
            return Ok(DriveOutcome::Idle);
        }
        context.session.drive()
    }
}
impl Receive for PendingHttpResponse {
type ReceivePayload<'a> = hyperium_http::Response<Vec<u8>>;
fn receive<'a>(&'a mut self) -> Result<ReceiveOutcome<Self::ReceivePayload<'a>>, Error> {
let mut context = self.context.borrow_mut();
if context.active_request_id != self.request_id {
return Ok(ReceiveOutcome::Idle);
}
let drive_outcome = context.session.drive()?;
if context.session.status() == SessionStatus::Establishing {
match drive_outcome {
DriveOutcome::Active => return Ok(ReceiveOutcome::Active),
DriveOutcome::Idle => return Ok(ReceiveOutcome::Idle),
}
}
if let Some(request) = self.pending_request.take() {
match context.session.publish(request)? {
PublishOutcome::Incomplete(request) => {
self.pending_request = Some(request);
return Ok(ReceiveOutcome::Idle);
}
PublishOutcome::Published => {
return Ok(ReceiveOutcome::Active);
}
}
}
match context.session.receive()? {
ReceiveOutcome::Payload(response) => {
context.active_request_id += 1;
if response
.headers()
.get("Connection")
.map(|x| x.to_str().unwrap())
== Some("close")
{
context.closed = true;
}
Ok(ReceiveOutcome::Payload(response))
}
ReceiveOutcome::Active => Ok(ReceiveOutcome::Active),
ReceiveOutcome::Idle => Ok(ReceiveOutcome::Idle),
}
}
}
/// Conversion of common string and byte containers into an owned HTTP body.
pub trait IntoBody {
    /// Consumes `self`, returning its contents as raw body bytes.
    fn into_body(self) -> Vec<u8>;
}

impl IntoBody for String {
    /// Reuses the string's UTF-8 buffer without copying.
    fn into_body(self) -> Vec<u8> {
        Vec::from(self)
    }
}

impl IntoBody for &str {
    /// Copies the UTF-8 bytes of the string slice.
    fn into_body(self) -> Vec<u8> {
        Vec::from(self)
    }
}

impl IntoBody for Vec<u8> {
    /// Already a body; returned unchanged.
    fn into_body(self) -> Vec<u8> {
        self
    }
}

impl IntoBody for &[u8] {
    /// Copies the borrowed bytes.
    fn into_body(self) -> Vec<u8> {
        self.to_owned()
    }
}

impl IntoBody for () {
    /// The unit value means "no body".
    fn into_body(self) -> Vec<u8> {
        Vec::default()
    }
}
/// How a response body is delimited on the wire.
enum BodyType {
    /// Exactly this many bytes follow the head.
    ContentLength(usize),
    /// The body uses chunked transfer coding.
    ChunkedTransfer,
    /// The body runs until the peer closes the connection.
    OnClose,
    /// The response has no body.
    None,
}

/// Where a response body starts in the receive buffer, and how it is delimited.
struct BodyInfo {
    /// Byte offset of the first body byte, i.e. the length of the response head.
    offset: usize,
    ty: BodyType,
}

impl BodyInfo {
    /// Pairs a body start offset with its delimiting rule.
    pub fn new(offset: usize, ty: BodyType) -> Self {
        BodyInfo { ty, offset }
    }
}
/// Incremental HTTP/1.x response parser, used as the read side of a `FrameDuplex`.
///
/// State carried between calls:
/// * the response under construction;
/// * how its body is delimited once the head has been parsed;
/// * the byte size of the last complete response.
pub struct Http1ResponseDeserializer {
    deserialized_response: Option<hyperium_http::Response<Vec<u8>>>,
    deserialized_size: usize,
    body_info: Option<BodyInfo>,
}

impl Http1ResponseDeserializer {
    /// Creates a deserializer with no response in progress.
    pub fn new() -> Self {
        Self {
            body_info: None,
            deserialized_size: 0,
            deserialized_response: None,
        }
    }
}
impl DeserializeFrame for Http1ResponseDeserializer {
type DeserializedFrame<'a> = hyperium_http::Response<Vec<u8>>;
fn check_deserialize_frame(&mut self, data: &[u8], eof: bool) -> Result<bool, Error> {
if self.deserialized_response.is_none() {
self.deserialized_response = Some(Response::new(Vec::new()));
}
let deserialized_response = self
.deserialized_response
.as_mut()
.expect("checked deserialized_response value");
let header_count: usize = count_max_headers(data);
let mut headers = Vec::new();
headers.resize(header_count, httparse::EMPTY_HEADER);
if self.body_info.is_none() {
let mut parsed = httparse::Response::new(&mut headers);
match parsed.parse(data).map_err(|err| {
Error::new(
ErrorKind::InvalidData,
format!("http response parse failed: {err:?}").as_str(),
)
})? {
httparse::Status::Complete(size) => {
if parse_is_chunked(&parsed.headers) {
self.body_info = Some(BodyInfo::new(size, BodyType::ChunkedTransfer));
} else if let Some(content_length) = parse_content_length(&parsed.headers)? {
self.body_info =
Some(BodyInfo::new(size, BodyType::ContentLength(content_length)));
} else if parsed.version.is_none() || parsed.version == Some(1) {
self.body_info = Some(BodyInfo::new(size, BodyType::OnClose));
} else {
self.body_info = Some(BodyInfo::new(size, BodyType::None));
}
parsed_into_response(parsed, deserialized_response)?;
}
httparse::Status::Partial => return Ok(false),
}
}
let (parsed_body, total_size) = match &self.body_info {
None => (None, 0),
Some(body_info) => {
match body_info.ty {
BodyType::ChunkedTransfer => {
if body_info.offset < data.len() && ends_with_ascii(data, "\r\n\r\n") {
let mut body = Vec::new();
let mut decoder =
chunked_transfer::Decoder::new(&data[body_info.offset..]);
decoder.read_to_end(&mut body)?;
let body_len = body.len();
match decoder.remaining_chunks_size() {
None => (Some(body), body_info.offset + body_len),
Some(_) => (None, 0),
}
} else {
(None, 0)
}
}
BodyType::ContentLength(content_length) => {
let total_length = body_info.offset + content_length;
if data.len() >= total_length {
(
Some(data[body_info.offset..total_length].to_vec()),
total_length,
)
} else {
(None, 0)
}
}
BodyType::OnClose => {
if eof {
(Some(data[body_info.offset..].to_vec()), data.len())
} else {
(None, 0)
}
}
BodyType::None => (Some(Vec::new()), body_info.offset),
}
}
};
match parsed_body {
None => {
if eof {
Err(Error::new(
ErrorKind::UnexpectedEof,
"http connection terminated before receiving full response",
))
} else {
Ok(false)
}
}
Some(mut body) => {
swap(deserialized_response.body_mut(), &mut body);
self.body_info = None;
self.deserialized_size = total_size;
Ok(true)
}
}
}
fn deserialize_frame<'a>(
&'a mut self,
_data: &'a [u8],
) -> Result<crate::frame::SizedFrame<Self::DeserializedFrame<'a>>, Error> {
Ok(SizedFrame::new(
self.deserialized_response
.take()
.ok_or_else(|| Error::new(ErrorKind::Other, "no deserialized frame"))?,
self.deserialized_size,
))
}
}
pub struct Http1RequestSerializer {}
impl Http1RequestSerializer {
pub fn new() -> Self {
Self {}
}
}
impl SerializeFrame for Http1RequestSerializer {
type SerializedFrame<'a> = HttpRequest;
fn serialize_frame<'a>(
&mut self,
request: Self::SerializedFrame<'a>,
buffer: &mut GrowableCircleBuf,
) -> Result<PublishOutcome<Self::SerializedFrame<'a>>, Error> {
let serialized_request = match request {
HttpRequest::Request(request) => {
match request.version() {
hyperium_http::Version::HTTP_10 | hyperium_http::Version::HTTP_11 => {}
version => {
return Err(Error::new(
ErrorKind::InvalidData,
format!("unsupported http request version {version:?}").as_str(),
));
}
}
let host = match request.uri().host() {
Some(x) => x.to_owned(),
None => return Err(io::Error::new(ErrorKind::InvalidData, "missing host")),
};
let body = request.body();
let content_length = body.len().to_string();
let mut serialized_request = Vec::new();
serialized_request.extend_from_slice(request.method().as_str().as_bytes());
serialized_request.extend_from_slice(" ".as_bytes());
serialized_request.extend_from_slice(request.uri().path().as_bytes());
if let Some(query) = request.uri().query() {
serialized_request.extend_from_slice("?".as_bytes());
serialized_request.extend_from_slice(query.as_bytes());
}
serialized_request
.extend_from_slice(format!(" {:?}", request.version()).as_bytes());
serialized_request.extend_from_slice(LINE_BREAK.as_bytes());
{
serialized_request.extend_from_slice(HOST.as_str().as_bytes());
serialized_request.extend_from_slice(": ".as_bytes());
serialized_request.extend_from_slice(host.as_bytes());
serialized_request.extend_from_slice(LINE_BREAK.as_bytes());
}
for (n, v) in request.headers().iter() {
serialized_request.extend_from_slice(n.as_str().as_bytes());
serialized_request.extend_from_slice(": ".as_bytes());
serialized_request.extend_from_slice(
v.to_str()
.map_err(|_| {
Error::new(
ErrorKind::InvalidData,
format!("could not convert header '{}' to string", n.as_str())
.as_str(),
)
})?
.as_bytes(),
);
serialized_request.extend_from_slice(LINE_BREAK.as_bytes());
}
if body.len() > 0 {
serialized_request.extend_from_slice(CONTENT_LENGTH.as_str().as_bytes());
serialized_request.extend_from_slice(": ".as_bytes());
serialized_request.extend_from_slice(content_length.as_bytes());
serialized_request.extend_from_slice(LINE_BREAK.as_bytes());
}
serialized_request.extend_from_slice(LINE_BREAK.as_bytes());
serialized_request.extend_from_slice(body);
serialized_request
}
HttpRequest::Serialized(serialized) => serialized,
};
if buffer.try_write(&vec![&serialized_request])? {
Ok(PublishOutcome::Published)
} else {
Ok(PublishOutcome::Incomplete(HttpRequest::Serialized(
serialized_request,
)))
}
}
}
/// Copies the status code, version and headers of a parsed response head into `resp`.
///
/// Header fields are appended rather than inserted, so repeated fields such as
/// several `Set-Cookie` lines all survive.
///
/// # Errors
/// `InvalidData` in any of these cases:
/// * an out-of-range status code;
/// * an HTTP version other than 1.0/1.1;
/// * a header name or value rejected by `hyperium_http`.
fn parsed_into_response(
    parsed: httparse::Response,
    resp: &mut hyperium_http::Response<Vec<u8>>,
) -> Result<(), Error> {
    if let Some(code) = parsed.code {
        let status_code = hyperium_http::StatusCode::try_from(code).map_err(|_| {
            Error::new(
                ErrorKind::InvalidData,
                format!("response invalid status code '{code}'").as_str(),
            )
        })?;
        *resp.status_mut() = status_code;
    }
    if let Some(version) = parsed.version {
        // httparse reports the HTTP/1.x minor version.
        *resp.version_mut() = match version {
            0 => hyperium_http::Version::HTTP_10,
            1 => hyperium_http::Version::HTTP_11,
            _ => {
                return Err(Error::new(
                    ErrorKind::InvalidData,
                    format!("response invalid version '{version}'").as_str(),
                ));
            }
        };
    }
    for h in parsed.headers.iter() {
        let name = hyperium_http::HeaderName::from_str(h.name).map_err(|_| {
            Error::new(
                ErrorKind::InvalidData,
                format!("response invalid header name '{}'", h.name).as_str(),
            )
        })?;
        let value = hyperium_http::HeaderValue::from_bytes(h.value).map_err(|_| {
            Error::new(
                ErrorKind::InvalidData,
                format!("response invalid header value '{:?}'", h.value).as_str(),
            )
        })?;
        resp.headers_mut().append(name, value);
    }
    Ok(())
}
/// HTTP/1.x line terminator.
const LINE_BREAK: &str = "\r\n";

/// Counts the CRLF pairs in `payload`.
///
/// Used as an upper bound on the number of header lines when sizing httparse's
/// header slot array.
fn count_max_headers(payload: &[u8]) -> usize {
    payload
        .windows(2)
        .filter(|pair| *pair == b"\r\n")
        .count()
}

/// Whether `buf` ends with the bytes of `ends_with`. An empty suffix always matches.
fn ends_with_ascii(buf: &[u8], ends_with: &str) -> bool {
    buf.ends_with(ends_with.as_bytes())
}
/// Returns `true` when the response body uses chunked transfer coding.
///
/// `Transfer-Encoding` lists codings in the order they were applied. The body is
/// chunk-delimited when `chunked` is the final coding, e.g. `gzip, chunked`
/// (RFC 9112 §6.3). Matching is case-insensitive and ignores surrounding whitespace.
fn parse_is_chunked(headers: &[httparse::Header]) -> bool {
    let Some(value) = find_header(headers, TRANSFER_ENCODING.as_str()) else {
        return false;
    };
    String::from_utf8_lossy(value)
        .rsplit(',')
        .next()
        .map_or(false, |last| last.trim().eq_ignore_ascii_case("chunked"))
}
/// Parses the `Content-Length` header, if present.
///
/// # Errors
/// `InvalidData` when the value is not a non-negative integer.
fn parse_content_length(headers: &[httparse::Header]) -> Result<Option<usize>, Error> {
    find_header(headers, CONTENT_LENGTH.as_str())
        .map(|raw| {
            String::from_utf8_lossy(raw)
                .parse::<usize>()
                .map_err(|_| Error::new(ErrorKind::InvalidData, "content-length not a number"))
        })
        .transpose()
}
/// Returns the value of the first header whose name equals `name`, ignoring ASCII case.
fn find_header<'a>(headers: &'a [httparse::Header], name: &str) -> Option<&'a [u8]> {
    headers
        .iter()
        .find(|header| header.name.eq_ignore_ascii_case(name))
        .map(|header| header.value)
}