use crate::{
bindings::http::types::{self, ErrorCode, Method, Scheme},
body::{HostIncomingBody, HyperIncomingBody, HyperOutgoingBody},
};
use bytes::Bytes;
use http::header::{HeaderMap, HeaderName, HeaderValue};
use http_body_util::BodyExt;
use hyper::body::Body;
use std::any::Any;
use std::fmt;
use std::time::Duration;
use wasmtime::component::{Resource, ResourceTable};
use wasmtime::{Result, bail};
use wasmtime_wasi::p2::Pollable;
use wasmtime_wasi::runtime::AbortOnDropJoinHandle;
#[cfg(feature = "default-send-request")]
use {
crate::io::TokioIo,
crate::{error::dns_error, hyper_request_error},
tokio::net::TcpStream,
tokio::time::timeout,
};
/// Field size limit installed by [`WasiHttpCtx::new`]: `128 * 1024`.
const DEFAULT_FIELD_SIZE_LIMIT: usize = 128 * 1024;

/// Configuration state consulted by the wasi-http host implementation.
#[derive(Debug)]
pub struct WasiHttpCtx {
    /// Limit handed to the incoming body and to the `FieldMap` of every
    /// request created through `WasiHttpView::new_incoming_request`.
    pub(crate) field_size_limit: usize,
}

impl WasiHttpCtx {
    /// Creates a context whose field size limit is the built-in default.
    pub fn new() -> Self {
        Self {
            field_size_limit: DEFAULT_FIELD_SIZE_LIMIT,
        }
    }

    /// Overrides the field size limit used for requests created afterwards.
    pub fn set_field_size_limit(&mut self, limit: usize) {
        self.field_size_limit = limit;
    }
}

impl Default for WasiHttpCtx {
    /// Equivalent to [`WasiHttpCtx::new`].
    fn default() -> Self {
        Self::new()
    }
}
/// Hooks the embedder implements so the wasi-http host code can reach its
/// state and customise request handling.
pub trait WasiHttpView {
    /// Returns the [`WasiHttpCtx`] associated with this view.
    fn ctx(&mut self) -> &mut WasiHttpCtx;

    /// Returns the [`ResourceTable`] that new resources are pushed into.
    fn table(&mut self) -> &mut ResourceTable;

    /// Wraps a `hyper` request as a [`HostIncomingRequest`] and pushes it into
    /// the resource table.
    ///
    /// # Errors
    ///
    /// Fails when [`HostIncomingRequest::new`] rejects the request or when the
    /// table push fails.
    fn new_incoming_request<B>(
        &mut self,
        scheme: Scheme,
        req: hyper::Request<B>,
    ) -> wasmtime::Result<Resource<HostIncomingRequest>>
    where
        B: Body<Data = Bytes> + Send + 'static,
        B::Error: Into<ErrorCode>,
        Self: Sized,
    {
        let limit = self.ctx().field_size_limit;
        let (head, raw_body) = req.into_parts();
        let incoming_body = HostIncomingBody::new(
            raw_body.map_err(Into::into).boxed_unsync(),
            // Ten minutes. NOTE(review): presumably the between-bytes timeout
            // of the incoming body — confirm against `HostIncomingBody::new`.
            Duration::from_secs(600),
            limit,
        );
        let request = HostIncomingRequest::new(self, head, scheme, Some(incoming_body), limit)?;
        let handle = self.table().push(request)?;
        Ok(handle)
    }

    /// Pushes a [`HostResponseOutparam`] wrapping `result` into the table.
    fn new_response_outparam(
        &mut self,
        result: tokio::sync::oneshot::Sender<
            Result<hyper::Response<HyperOutgoingBody>, types::ErrorCode>,
        >,
    ) -> wasmtime::Result<Resource<HostResponseOutparam>> {
        Ok(self.table().push(HostResponseOutparam { result })?)
    }

    /// Sends an outgoing request; the default delegates to
    /// [`default_send_request`].
    #[cfg(feature = "default-send-request")]
    fn send_request(
        &mut self,
        request: hyper::Request<HyperOutgoingBody>,
        config: OutgoingRequestConfig,
    ) -> crate::HttpResult<HostFutureIncomingResponse> {
        let pending = default_send_request(request, config);
        Ok(pending)
    }

    /// Sends an outgoing request. Without the `default-send-request` feature
    /// there is no default, so implementors must provide this method.
    #[cfg(not(feature = "default-send-request"))]
    fn send_request(
        &mut self,
        request: hyper::Request<HyperOutgoingBody>,
        config: OutgoingRequestConfig,
    ) -> crate::HttpResult<HostFutureIncomingResponse>;

    /// Reports whether `name` is forbidden; the default checks membership in
    /// [`DEFAULT_FORBIDDEN_HEADERS`].
    fn is_forbidden_header(&mut self, name: &HeaderName) -> bool {
        DEFAULT_FORBIDDEN_HEADERS.contains(name)
    }

    /// Defaults to [`DEFAULT_OUTGOING_BODY_BUFFER_CHUNKS`].
    fn outgoing_body_buffer_chunks(&mut self) -> usize {
        DEFAULT_OUTGOING_BODY_BUFFER_CHUNKS
    }

    /// Defaults to [`DEFAULT_OUTGOING_BODY_CHUNK_SIZE`].
    fn outgoing_body_chunk_size(&mut self) -> usize {
        DEFAULT_OUTGOING_BODY_CHUNK_SIZE
    }
}
/// Value returned by the default `WasiHttpView::outgoing_body_buffer_chunks`.
pub const DEFAULT_OUTGOING_BODY_BUFFER_CHUNKS: usize = 1;
/// Value returned by the default `WasiHttpView::outgoing_body_chunk_size`
/// (`1024 * 1024`; presumably a byte count — confirm against the body code).
pub const DEFAULT_OUTGOING_BODY_CHUNK_SIZE: usize = 1024 * 1024;
/// Lets a mutable reference to a view act as a view by forwarding every
/// overridable method to the referenced implementation.
impl<T> WasiHttpView for &mut T
where
    T: ?Sized + WasiHttpView,
{
    fn ctx(&mut self) -> &mut WasiHttpCtx {
        (**self).ctx()
    }

    fn table(&mut self) -> &mut ResourceTable {
        (**self).table()
    }

    fn new_response_outparam(
        &mut self,
        result: tokio::sync::oneshot::Sender<
            Result<hyper::Response<HyperOutgoingBody>, types::ErrorCode>,
        >,
    ) -> wasmtime::Result<Resource<HostResponseOutparam>> {
        (**self).new_response_outparam(result)
    }

    fn send_request(
        &mut self,
        request: hyper::Request<HyperOutgoingBody>,
        config: OutgoingRequestConfig,
    ) -> crate::HttpResult<HostFutureIncomingResponse> {
        (**self).send_request(request, config)
    }

    fn is_forbidden_header(&mut self, name: &HeaderName) -> bool {
        (**self).is_forbidden_header(name)
    }

    fn outgoing_body_buffer_chunks(&mut self) -> usize {
        (**self).outgoing_body_buffer_chunks()
    }

    fn outgoing_body_chunk_size(&mut self) -> usize {
        (**self).outgoing_body_chunk_size()
    }
}
/// Lets a boxed view act as a view by forwarding every overridable method to
/// the boxed implementation.
impl<T> WasiHttpView for Box<T>
where
    T: ?Sized + WasiHttpView,
{
    fn ctx(&mut self) -> &mut WasiHttpCtx {
        (**self).ctx()
    }

    fn table(&mut self) -> &mut ResourceTable {
        (**self).table()
    }

    fn new_response_outparam(
        &mut self,
        result: tokio::sync::oneshot::Sender<
            Result<hyper::Response<HyperOutgoingBody>, types::ErrorCode>,
        >,
    ) -> wasmtime::Result<Resource<HostResponseOutparam>> {
        (**self).new_response_outparam(result)
    }

    fn send_request(
        &mut self,
        request: hyper::Request<HyperOutgoingBody>,
        config: OutgoingRequestConfig,
    ) -> crate::HttpResult<HostFutureIncomingResponse> {
        (**self).send_request(request, config)
    }

    fn is_forbidden_header(&mut self, name: &HeaderName) -> bool {
        (**self).is_forbidden_header(name)
    }

    fn outgoing_body_buffer_chunks(&mut self) -> usize {
        (**self).outgoing_body_buffer_chunks()
    }

    fn outgoing_body_chunk_size(&mut self) -> usize {
        (**self).outgoing_body_chunk_size()
    }
}
/// Transparent newtype around a [`WasiHttpView`] implementation.
#[repr(transparent)]
pub struct WasiHttpImpl<T>(pub T);

/// Forwards every overridable [`WasiHttpView`] method to the wrapped value.
impl<T> WasiHttpView for WasiHttpImpl<T>
where
    T: WasiHttpView,
{
    fn ctx(&mut self) -> &mut WasiHttpCtx {
        T::ctx(&mut self.0)
    }

    fn table(&mut self) -> &mut ResourceTable {
        T::table(&mut self.0)
    }

    fn new_response_outparam(
        &mut self,
        result: tokio::sync::oneshot::Sender<
            Result<hyper::Response<HyperOutgoingBody>, types::ErrorCode>,
        >,
    ) -> wasmtime::Result<Resource<HostResponseOutparam>> {
        T::new_response_outparam(&mut self.0, result)
    }

    fn send_request(
        &mut self,
        request: hyper::Request<HyperOutgoingBody>,
        config: OutgoingRequestConfig,
    ) -> crate::HttpResult<HostFutureIncomingResponse> {
        T::send_request(&mut self.0, request, config)
    }

    fn is_forbidden_header(&mut self, name: &HeaderName) -> bool {
        T::is_forbidden_header(&mut self.0, name)
    }

    fn outgoing_body_buffer_chunks(&mut self) -> usize {
        T::outgoing_body_buffer_chunks(&mut self.0)
    }

    fn outgoing_body_chunk_size(&mut self) -> usize {
        T::outgoing_body_chunk_size(&mut self.0)
    }
}
/// Header names that the default `WasiHttpView::is_forbidden_header`
/// implementation reports as forbidden.
pub const DEFAULT_FORBIDDEN_HEADERS: [http::header::HeaderName; 9] = [
hyper::header::CONNECTION,
HeaderName::from_static("keep-alive"),
hyper::header::PROXY_AUTHENTICATE,
hyper::header::PROXY_AUTHORIZATION,
HeaderName::from_static("proxy-connection"),
hyper::header::TRANSFER_ENCODING,
hyper::header::UPGRADE,
hyper::header::HOST,
HeaderName::from_static("http2-settings"),
];
pub(crate) fn remove_forbidden_headers(view: &mut dyn WasiHttpView, headers: &mut FieldMap) {
let forbidden_keys = Vec::from_iter(headers.as_ref().keys().filter_map(|name| {
if view.is_forbidden_header(name) {
Some(name.clone())
} else {
None
}
}));
for name in forbidden_keys {
headers.remove_all(&name);
}
}
/// Connection settings passed along with an outgoing request to
/// `WasiHttpView::send_request`.
pub struct OutgoingRequestConfig {
/// Whether the default handler wraps the TCP stream in TLS; it also selects
/// the fallback port (443 with TLS, 80 without) when the URI has none.
pub use_tls: bool,
/// Timeout the default handler applies to the TCP connect and, separately,
/// to the HTTP/1 handshake.
pub connect_timeout: Duration,
/// Timeout the default handler applies while waiting for the response to
/// the sent request.
pub first_byte_timeout: Duration,
/// Copied by the default handler into `IncomingResponse::between_bytes_timeout`.
pub between_bytes_timeout: Duration,
}
/// Spawns [`default_send_request_handler`] on the wasmtime-wasi runtime and
/// returns a pending future response wrapping the task handle.
#[cfg(feature = "default-send-request")]
pub fn default_send_request(
    request: hyper::Request<HyperOutgoingBody>,
    config: OutgoingRequestConfig,
) -> HostFutureIncomingResponse {
    let task = async move {
        // The handler's own `Result` is the payload; the outer `Ok` marks
        // that the task itself ran to completion.
        let outcome = default_send_request_handler(request, config).await;
        Ok(outcome)
    };
    HostFutureIncomingResponse::pending(wasmtime_wasi::runtime::spawn(task))
}
/// Performs an outgoing HTTP/1 request over a fresh TCP (optionally TLS)
/// connection and returns the response head with a streaming body.
///
/// The connection is driven by a spawned worker task whose handle is returned
/// inside the [`IncomingResponse`].
///
/// # Errors
///
/// * `HttpRequestUriInvalid` when the request URI has no authority.
/// * `ConnectionTimeout` when the TCP connect or the HTTP/1 handshake exceeds
///   `connect_timeout`.
/// * A DNS error or `ConnectionRefused` when the TCP connect fails.
/// * A DNS error when the host is not a valid TLS server name, and
///   `TlsProtocolError` when the TLS handshake fails.
/// * `ConnectionReadTimeout` when no response arrives within
///   `first_byte_timeout`.
/// * Whatever `hyper_request_error` maps a hyper failure to.
#[cfg(feature = "default-send-request")]
pub async fn default_send_request_handler(
    mut request: hyper::Request<HyperOutgoingBody>,
    OutgoingRequestConfig {
        use_tls,
        connect_timeout,
        first_byte_timeout,
        between_bytes_timeout,
    }: OutgoingRequestConfig,
) -> Result<IncomingResponse, types::ErrorCode> {
    // Address to connect to: the URI authority, with the scheme's default
    // port appended when the URI does not carry one.
    let authority = if let Some(authority) = request.uri().authority() {
        if authority.port().is_some() {
            authority.to_string()
        } else {
            let port = if use_tls { 443 } else { 80 };
            format!("{authority}:{port}")
        }
    } else {
        return Err(types::ErrorCode::HttpRequestUriInvalid);
    };

    let tcp_stream = timeout(connect_timeout, TcpStream::connect(&authority))
        .await
        .map_err(|_| types::ErrorCode::ConnectionTimeout)?
        .map_err(|e| match e.kind() {
            std::io::ErrorKind::AddrNotAvailable => {
                dns_error("address not available".to_string(), 0)
            }
            _ => {
                // Resolution failures have no dedicated `ErrorKind`, so they
                // are recognised by their message.
                if e.to_string()
                    .starts_with("failed to lookup address information")
                {
                    dns_error("address not available".to_string(), 0)
                } else {
                    types::ErrorCode::ConnectionRefused
                }
            }
        })?;

    let (mut sender, worker) = if use_tls {
        use rustls::pki_types::ServerName;

        let root_cert_store = rustls::RootCertStore {
            roots: webpki_roots::TLS_SERVER_ROOTS.into(),
        };
        let config = rustls::ClientConfig::builder()
            .with_root_certificates(root_cert_store)
            .with_no_client_auth();
        let connector = tokio_rustls::TlsConnector::from(std::sync::Arc::new(config));

        let domain = {
            // Use the URI's parsed host rather than splitting the authority
            // on ':': a bracketed IPv6 literal such as `[::1]:443` contains
            // colons of its own, and userinfo may precede the host. The
            // brackets are removed because the server name is parsed as a
            // bare DNS name or IP address.
            let host = request.uri().host().unwrap_or(authority.as_str());
            let host = host
                .strip_prefix('[')
                .and_then(|inner| inner.strip_suffix(']'))
                .unwrap_or(host);
            ServerName::try_from(host)
                .map_err(|e| {
                    tracing::warn!("dns lookup error: {e:?}");
                    dns_error("invalid dns name".to_string(), 0)
                })?
                .to_owned()
        };

        // NOTE(review): unlike the TCP connect and the HTTP handshake, the
        // TLS handshake is not bounded by `connect_timeout` — confirm whether
        // that is intentional.
        let stream = connector.connect(domain, tcp_stream).await.map_err(|e| {
            tracing::warn!("tls protocol error: {e:?}");
            types::ErrorCode::TlsProtocolError
        })?;
        let stream = TokioIo::new(stream);

        let (sender, conn) = timeout(
            connect_timeout,
            hyper::client::conn::http1::handshake(stream),
        )
        .await
        .map_err(|_| types::ErrorCode::ConnectionTimeout)?
        .map_err(hyper_request_error)?;

        // The connection future must be polled for the request to progress.
        let worker = wasmtime_wasi::runtime::spawn(async move {
            match conn.await {
                Ok(()) => {}
                Err(e) => tracing::warn!("dropping error {e}"),
            }
        });

        (sender, worker)
    } else {
        let tcp_stream = TokioIo::new(tcp_stream);
        let (sender, conn) = timeout(
            connect_timeout,
            hyper::client::conn::http1::handshake(tcp_stream),
        )
        .await
        .map_err(|_| types::ErrorCode::ConnectionTimeout)?
        .map_err(hyper_request_error)?;

        // The connection future must be polled for the request to progress.
        let worker = wasmtime_wasi::runtime::spawn(async move {
            match conn.await {
                Ok(()) => {}
                Err(e) => tracing::warn!("dropping error {e}"),
            }
        });

        (sender, worker)
    };

    // Reduce the request URI to its path and query (defaulting to "/") now
    // that the scheme and authority have been used to set up the connection.
    *request.uri_mut() = http::Uri::builder()
        .path_and_query(
            request
                .uri()
                .path_and_query()
                .map(|p| p.as_str())
                .unwrap_or("/"),
        )
        .build()
        .expect("comes from valid request");

    let resp = timeout(first_byte_timeout, sender.send_request(request))
        .await
        .map_err(|_| types::ErrorCode::ConnectionReadTimeout)?
        .map_err(hyper_request_error)?
        .map(|body| body.map_err(hyper_request_error).boxed_unsync());

    Ok(IncomingResponse {
        resp,
        worker: Some(worker),
        between_bytes_timeout,
    })
}
/// Maps the standard `http` methods onto their dedicated variants and every
/// other method onto `Other` carrying the method's string form.
impl From<http::Method> for types::Method {
    fn from(method: http::Method) -> Self {
        match method {
            http::Method::GET => Self::Get,
            http::Method::HEAD => Self::Head,
            http::Method::POST => Self::Post,
            http::Method::PUT => Self::Put,
            http::Method::DELETE => Self::Delete,
            http::Method::CONNECT => Self::Connect,
            http::Method::OPTIONS => Self::Options,
            http::Method::TRACE => Self::Trace,
            http::Method::PATCH => Self::Patch,
            extension => Self::Other(extension.to_string()),
        }
    }
}
/// Converts a WASI method into an `http` method; only the `Other` variant can
/// fail, when its string is not a valid method token.
impl TryInto<http::Method> for types::Method {
    type Error = http::method::InvalidMethod;

    fn try_into(self) -> Result<http::Method, Self::Error> {
        let known = match self {
            // The only fallible case: parse the custom method name.
            Method::Other(name) => return http::Method::from_bytes(name.as_bytes()),
            Method::Get => http::Method::GET,
            Method::Head => http::Method::HEAD,
            Method::Post => http::Method::POST,
            Method::Put => http::Method::PUT,
            Method::Delete => http::Method::DELETE,
            Method::Connect => http::Method::CONNECT,
            Method::Options => http::Method::OPTIONS,
            Method::Trace => http::Method::TRACE,
            Method::Patch => http::Method::PATCH,
        };
        Ok(known)
    }
}
/// Host-side representation of an incoming HTTP request.
#[derive(Debug)]
pub struct HostIncomingRequest {
    /// Request method, taken from the request head.
    pub(crate) method: http::method::Method,
    /// Request URI, taken from the request head.
    pub(crate) uri: http::uri::Uri,
    /// Request headers with forbidden names already removed.
    pub(crate) headers: FieldMap,
    /// Scheme supplied by the caller of [`HostIncomingRequest::new`].
    pub(crate) scheme: Scheme,
    /// URI authority, or the `Host` header value when the URI has none.
    pub(crate) authority: String,
    /// Request body, if any.
    pub body: Option<HostIncomingBody>,
}

impl HostIncomingRequest {
    /// Builds an incoming request from a request head and an optional body.
    ///
    /// Headers that `view` reports as forbidden are removed from the stored
    /// header map.
    ///
    /// # Errors
    ///
    /// Fails when the URI has no authority and the `Host` header is either
    /// missing or cannot be converted with `to_str`.
    pub fn new(
        view: &mut dyn WasiHttpView,
        parts: http::request::Parts,
        scheme: Scheme,
        body: Option<HostIncomingBody>,
        field_size_limit: usize,
    ) -> wasmtime::Result<Self> {
        // Resolve the authority before filtering headers: `Host` may be among
        // the names that get removed below.
        let authority = if let Some(from_uri) = parts.uri.authority() {
            from_uri.to_string()
        } else if let Some(host) = parts.headers.get(http::header::HOST) {
            host.to_str()?.to_string()
        } else {
            bail!("invalid HTTP request missing authority in URI and host header")
        };

        let mut headers = FieldMap::new(parts.headers, field_size_limit);
        remove_forbidden_headers(view, &mut headers);

        let request = Self {
            method: parts.method,
            uri: parts.uri,
            headers,
            authority,
            scheme,
            body,
        };
        Ok(request)
    }
}
/// Resource wrapping the channel through which a response is handed back.
pub struct HostResponseOutparam {
/// One-shot sender carrying either the `hyper` response or an error code.
pub result:
tokio::sync::oneshot::Sender<Result<hyper::Response<HyperOutgoingBody>, types::ErrorCode>>,
}
pub struct HostOutgoingResponse {
pub status: http::StatusCode,
pub headers: FieldMap,
pub body: Option<HyperOutgoingBody>,
}
impl TryFrom<HostOutgoingResponse> for hyper::Response<HyperOutgoingBody> {
type Error = http::Error;
fn try_from(
resp: HostOutgoingResponse,
) -> Result<hyper::Response<HyperOutgoingBody>, Self::Error> {
use http_body_util::Empty;
let mut builder = hyper::Response::builder().status(resp.status);
*builder.headers_mut().unwrap() = resp.headers.map;
match resp.body {
Some(body) => builder.body(body),
None => builder.body(
Empty::<bytes::Bytes>::new()
.map_err(|_| unreachable!("Infallible error"))
.boxed_unsync(),
),
}
}
}
/// Host-side representation of an outgoing HTTP request. Scheme, authority
/// and path are optional at this level.
#[derive(Debug)]
pub struct HostOutgoingRequest {
/// Request method.
pub method: Method,
/// Request scheme, if set.
pub scheme: Option<Scheme>,
/// Request authority, if set.
pub authority: Option<String>,
/// Request path including the query string, if set.
pub path_with_query: Option<String>,
/// Request headers.
pub headers: FieldMap,
/// Request body, if any.
pub body: Option<HyperOutgoingBody>,
}
/// Optional per-request timeouts; every field defaults to `None`.
#[derive(Debug, Default)]
pub struct HostRequestOptions {
/// Connect timeout, if set.
pub connect_timeout: Option<std::time::Duration>,
/// First-byte timeout, if set.
pub first_byte_timeout: Option<std::time::Duration>,
/// Between-bytes timeout, if set.
pub between_bytes_timeout: Option<std::time::Duration>,
}
/// Host-side representation of an incoming HTTP response.
#[derive(Debug)]
pub struct HostIncomingResponse {
/// Response status code as a plain integer.
pub status: u16,
/// Response headers.
pub headers: FieldMap,
/// Response body, if any.
pub body: Option<HostIncomingBody>,
}
/// A set of header fields that is either owned or reached through another
/// value.
#[derive(Debug)]
pub enum HostFields {
/// Fields that live inside a parent value.
Ref {
// NOTE(review): presumably the resource-table index of the parent — confirm.
parent: u32,
// Projects the type-erased parent onto the `FieldMap` it contains.
get_fields: for<'a> fn(elem: &'a mut (dyn Any + 'static)) -> &'a mut FieldMap,
},
/// Fields stored directly in this value.
Owned {
fields: FieldMap,
},
}
/// A `HeaderMap` paired with a running size total and the limit that
/// `FieldMap::append` enforces on that total.
#[derive(Debug, Clone)]
pub struct FieldMap {
// The wrapped header map.
map: HeaderMap,
// Largest accounted size that `append` will accept.
limit: usize,
// Current accounted size of `map` (see `FieldMap::content_size`).
size: usize,
}
/// Error describing a field map whose accounted size would exceed its limit.
#[derive(Debug)]
pub struct FieldSizeLimitError {
    /// The size that was rejected.
    pub(crate) size: usize,
    /// The limit that `size` exceeded.
    pub(crate) limit: usize,
}

impl fmt::Display for FieldSizeLimitError {
    /// Formats the error with the limit first, followed by the offending size.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Self { size, limit } = self;
        write!(f, "Field size limit {} exceeded: {}", limit, size)
    }
}

impl std::error::Error for FieldSizeLimitError {}
impl FieldMap {
pub fn new(map: HeaderMap, limit: usize) -> Self {
let size = Self::content_size(&map);
Self { map, size, limit }
}
pub fn empty(limit: usize) -> Self {
Self {
map: HeaderMap::new(),
size: 0,
limit,
}
}
pub fn into_inner(self) -> HeaderMap {
self.map
}
pub(crate) fn content_size(map: &HeaderMap) -> usize {
let mut sum = 0;
for key in map.keys() {
sum += header_name_size(key);
}
for value in map.values() {
sum += header_value_size(value);
}
sum
}
pub fn remove_all(&mut self, key: &HeaderName) -> Vec<HeaderValue> {
use http::header::Entry;
match self.map.try_entry(key) {
Ok(Entry::Vacant { .. }) | Err(_) => Vec::new(),
Ok(Entry::Occupied(e)) => {
let (name, value_drain) = e.remove_entry_mult();
let mut removed = header_name_size(&name);
let values = value_drain.collect::<Vec<_>>();
for v in values.iter() {
removed += header_value_size(v);
}
self.size -= removed;
values
}
}
}
pub fn append(&mut self, key: &HeaderName, value: HeaderValue) -> Result<bool> {
let key_size = header_name_size(key);
let val_size = header_value_size(&value);
let new_size = if !self.map.contains_key(key) {
self.size + key_size + val_size
} else {
self.size + val_size
};
if new_size > self.limit {
bail!(FieldSizeLimitError {
limit: self.limit,
size: new_size
})
}
self.size = new_size;
Ok(self.map.try_append(key, value)?)
}
}
/// Accounted size of a header name: the length of its string form plus the
/// in-memory size of the `HeaderName` type.
fn header_name_size(name: &HeaderName) -> usize {
    let text_len = name.as_str().len();
    text_len + size_of::<HeaderName>()
}
/// Accounted size of a header value: its length plus the in-memory size of
/// the `HeaderValue` type.
fn header_value_size(value: &HeaderValue) -> usize {
    let payload_len = value.len();
    payload_len + size_of::<HeaderValue>()
}
/// Grants read-only access to the wrapped header map.
impl AsRef<HeaderMap> for FieldMap {
    fn as_ref(&self) -> &HeaderMap {
        let Self { map, .. } = self;
        map
    }
}
/// Handle to the spawned task that produces an incoming response. The outer
/// `wasmtime::Result` wraps the inner `Result` that carries either the
/// response or an `ErrorCode`.
pub type FutureIncomingResponseHandle =
AbortOnDropJoinHandle<wasmtime::Result<Result<IncomingResponse, types::ErrorCode>>>;
/// A received response together with the state needed to keep reading it.
#[derive(Debug)]
pub struct IncomingResponse {
/// The response head and its streaming body.
pub resp: hyper::Response<HyperIncomingBody>,
/// Handle to a background task tied to this response, if any; the default
/// send handler stores the task that drives the connection here.
pub worker: Option<AbortOnDropJoinHandle<()>>,
/// Between-bytes timeout carried over from the request configuration.
pub between_bytes_timeout: std::time::Duration,
}
/// State of a response that may still be in flight.
#[derive(Debug)]
pub enum HostFutureIncomingResponse {
    /// The producing task has not been awaited to completion yet.
    Pending(FutureIncomingResponseHandle),
    /// The result is available.
    Ready(wasmtime::Result<Result<IncomingResponse, types::ErrorCode>>),
    /// Marker state holding neither a handle nor a result.
    Consumed,
}

impl HostFutureIncomingResponse {
    /// Creates a response in the `Pending` state from a task handle.
    pub fn pending(handle: FutureIncomingResponseHandle) -> Self {
        Self::Pending(handle)
    }

    /// Creates a response that is already in the `Ready` state.
    pub fn ready(result: wasmtime::Result<Result<IncomingResponse, types::ErrorCode>>) -> Self {
        Self::Ready(result)
    }

    /// Returns `true` only for the `Ready` state.
    pub fn is_ready(&self) -> bool {
        match self {
            Self::Ready(_) => true,
            Self::Pending(_) | Self::Consumed => false,
        }
    }

    /// Extracts the stored result.
    ///
    /// # Panics
    ///
    /// Panics when the response is in any state other than `Ready`.
    pub fn unwrap_ready(self) -> wasmtime::Result<Result<IncomingResponse, types::ErrorCode>> {
        let Self::Ready(result) = self else {
            panic!("unwrap_ready called on a pending HostFutureIncomingResponse")
        };
        result
    }
}
// Readiness for a future response means the producing task has finished and
// its result has been stored in `self`.
#[async_trait::async_trait]
impl Pollable for HostFutureIncomingResponse {
/// If the response is `Pending`, awaits its handle and replaces `self` with
/// `Ready` holding the task's result. `Ready` and `Consumed` are left as-is.
async fn ready(&mut self) {
if let Self::Pending(handle) = self {
*self = Self::Ready(handle.await);
}
}
}