use std::{
pin::Pin,
time::Duration,
{convert::Infallible, sync::Arc},
};
use bytes::Bytes;
use flume::{bounded, Receiver, Sender};
use futures::Future;
use http::header::HeaderMap;
use thiserror::Error as ThisError;
use tokio::task::JoinHandle;
use tracing::{error, info, trace, Instrument};
use warp::{filters::cors::Builder, path::FullPath, Filter};
use wasmbus_rpc::{core::LinkDefinition, error::RpcResult};
use wasmcloud_interface_httpserver::{HttpRequest, HttpResponse};
mod settings;
pub use settings::{load_settings, ServiceSettings, CONTENT_LEN_LIMIT, DEFAULT_MAX_CONTENT_LEN};
mod hashmap_ci;
pub(crate) use hashmap_ci::make_case_insensitive;
/// Bindings generated by `smithy_bindgen!` from the httpserver Smithy model
/// (`httpserver/httpserver.smithy`, namespace `org.wasmcloud.interface.httpserver`).
///
/// This file uses `HttpRequest`, `HttpResponse` and `HeaderMap` from here.
pub mod wasmcloud_interface_httpserver {
    smithy_bindgen::smithy_bindgen!(
        "httpserver/httpserver.smithy",
        "org.wasmcloud.interface.httpserver"
    );
}
/// Errors produced while configuring or starting the HTTP server.
#[derive(ThisError, Debug)]
pub enum Error {
    /// A configuration value was rejected, e.g. an unrecognized CORS method.
    #[error("invalid parameter: {0}")]
    InvalidParameter(String),
    /// Settings are unusable, e.g. a malformed `max_content_len` value or a
    /// listen address that could not be bound.
    #[error("problem reading settings: {0}")]
    Settings(String),
    /// Failure during provider startup.
    /// NOTE(review): not constructed in the code shown here; presumably raised elsewhere.
    #[error("provider startup: {0}")]
    Init(String),
    /// Wraps an error reported by warp.
    /// NOTE(review): not constructed in the code shown here; presumably raised elsewhere.
    #[error("warp error: {0}")]
    Warp(warp::Error),
    /// Failure deserializing TOML settings.
    /// NOTE(review): presumably raised by the `settings` module — confirm.
    #[error("deserializing settings: {0}")]
    SettingsToml(toml::de::Error),
}
/// Type-erased callback that forwards an [`HttpRequest`] to an actor.
///
/// Arguments, in order:
/// 1. the lattice id;
/// 2. the link definition the listener was started for;
/// 3. the converted request;
/// 4. an optional call timeout.
///
/// The returned boxed future resolves to the actor's [`HttpResponse`] or an RPC error.
pub type AsyncCallActorFn = Box<
    dyn Fn(
            String,
            Arc<LinkDefinition>,
            HttpRequest,
            Option<Duration>,
        ) -> Pin<Box<dyn Future<Output = RpcResult<HttpResponse>> + Send + 'static>>
        + Send
        + Sync,
>;
/// Newtype around [`AsyncCallActorFn`]; requests are dispatched through its `call` method.
struct CallActorFn(AsyncCallActorFn);
/// Shared state behind an [`HttpServerCore`], reachable through its `Deref` impl.
pub struct Inner {
    // Listener address, TLS, CORS and timeout configuration.
    settings: ServiceSettings,
    // Lattice id passed as the first argument to the actor callback on every request.
    lattice_id: String,
    // Sending half of the shutdown signal; written by `begin_shutdown` and `Drop`.
    shutdown_tx: Sender<bool>,
    // Receiving half; each listener started by `start` awaits a clone of it.
    shutdown_rx: Receiver<bool>,
    // Callback that forwards converted requests to the actor.
    call_actor: CallActorFn,
}
/// Cloneable handle to an HTTP server that forwards incoming requests to actors.
///
/// All clones share one [`Inner`] through an `Arc`: the settings, lattice id,
/// shutdown channel and actor callback.
#[derive(Clone)]
pub struct HttpServerCore {
    inner: Arc<Inner>,
}
/// Exposes the shared [`Inner`] state directly on the handle, so its fields
/// can be reached as e.g. `self.settings` or `self.shutdown_tx`.
impl std::ops::Deref for HttpServerCore {
    type Target = Inner;

    fn deref(&self) -> &Inner {
        // `&Arc<Inner>` coerces to `&Inner` at the return site.
        &self.inner
    }
}
impl HttpServerCore {
pub fn new<F, Fut>(settings: ServiceSettings, lattice_id: String, call_actor_fn: F) -> Self
where
F: Fn(String, Arc<LinkDefinition>, HttpRequest, Option<Duration>) -> Fut
+ Send
+ Sync
+ 'static,
Fut: Future<Output = RpcResult<HttpResponse>> + 'static + Send,
{
let (shutdown_tx, shutdown_rx) = bounded(1);
let call_actor_fn = Arc::new(call_actor_fn);
Self {
inner: Arc::new(Inner {
settings,
lattice_id,
shutdown_tx,
shutdown_rx,
call_actor: CallActorFn(Box::new(
move |lattice: String,
ld: Arc<LinkDefinition>,
req: HttpRequest,
timeout: Option<Duration>| {
let call_actor_fn = call_actor_fn.clone();
Box::pin(call_actor_fn(lattice, ld, req, timeout))
},
)),
}),
}
}
pub fn begin_shutdown(&self) {
let _ = self.shutdown_tx.try_send(true);
}
pub async fn start(&self, ld: &LinkDefinition) -> Result<JoinHandle<()>, Error> {
let timeout = self
.inner
.settings
.timeout_ms
.map(std::time::Duration::from_millis);
let ld = Arc::new(ld.clone());
let linkdefs = ld.clone();
let trace_ld = ld.clone();
let arc_inner = self.inner.clone();
let route = warp::any()
.and(warp::header::headers_cloned())
.and(warp::method())
.and(warp::body::bytes())
.and(warp::path::full())
.and(opt_raw_query())
.and_then(
move |
headers: HeaderMap,
method: http::method::Method,
body: Bytes,
path: FullPath,
query: String| {
let span = tracing::debug_span!("http request", %method, path = %path.as_str(), %query);
let ld = linkdefs.clone();
let arc_inner = arc_inner.clone();
async move {
let hmap = convert_request_headers(&headers);
let req = HttpRequest {
body: Vec::from(body),
header: hmap,
method: method.as_str().to_ascii_uppercase(),
path: path.as_str().to_string(),
query_string: query,
};
trace!(
?req,
"httpserver calling actor"
);
let response = match arc_inner.call_actor.call(arc_inner.lattice_id.clone(), ld, req, timeout).in_current_span().await {
Ok(resp) => resp,
Err(e) => {
error!(
error = %e,
"Error sending HttpRequest to actor"
);
HttpResponse {
status_code: http::StatusCode::INTERNAL_SERVER_ERROR.as_u16(),
body: Default::default(),
header: Default::default(),
}
}
};
let mut http_response = http::response::Response::new(response.body);
let status = match http::StatusCode::from_u16(response.status_code) {
Ok(status_code) => status_code,
Err(e) => {
error!(
status_code = %response.status_code,
error = %e,
"invalid response status code, changing to 500"
);
http::StatusCode::INTERNAL_SERVER_ERROR
}
};
*http_response.status_mut() = status;
convert_response_headers(response.header, http_response.headers_mut());
Ok::<_, warp::Rejection>(http_response)
}.instrument(span)
},
).with(warp::trace(move |req_info| {
let actor_id = &trace_ld.actor_id;
let span = tracing::debug_span!("request", method = %req_info.method(), path = %req_info.path(), query = tracing::field::Empty, %actor_id);
if let Some(remote_addr) = req_info.remote_addr() {
span.record("remote_addr", &tracing::field::display(remote_addr));
}
span
}));
let addr = self.settings.address.unwrap();
info!(
%addr,
actor_id = %ld.actor_id,
"httpserver starting listener for actor",
);
let cors = cors_filter(&self.settings)?;
let server = warp::serve(route.with(cors));
let handle = tokio::runtime::Handle::current();
let shutdown_rx = self.shutdown_rx.clone();
let join = if self.settings.tls.is_set() {
let (_, fut) = server
.tls()
.key_path(self.settings.tls.priv_key_file.as_ref().unwrap())
.cert_path(self.settings.tls.cert_file.as_ref().unwrap())
.bind_with_graceful_shutdown(addr, async move {
if let Err(e) = shutdown_rx.recv_async().await {
error!(error = %e, "shutting down httpserver listener");
}
});
handle.spawn(fut)
} else {
let (_, fut) = server
.try_bind_with_graceful_shutdown(addr, async move {
if let Err(e) = shutdown_rx.recv_async().await {
error!(error = %e, "shutting down httpserver listener");
}
})
.map_err(|e| {
Error::Settings(format!(
"failed binding to address '{}' reason: {}",
&addr.to_string(),
e
))
})?;
handle.spawn(fut)
};
Ok(join)
}
}
/// Requests shutdown of the listener when a handle is dropped.
///
/// NOTE(review): `HttpServerCore` is `Clone`, and all clones share one shutdown
/// channel. Dropping *any* clone, not only the last one, therefore signals
/// shutdown. Confirm that callers never drop a clone while the server should
/// keep running.
impl Drop for HttpServerCore {
    fn drop(&mut self) {
        self.begin_shutdown();
    }
}
/// Converts an incoming `http` header map into the actor-facing header map.
///
/// All values of a repeated header are collected under a single key, in the
/// order `http` yields them. Values that `HeaderValue::to_str` rejects are
/// skipped. A header whose values were all skipped is left out entirely.
fn convert_request_headers(headers: &http::HeaderMap) -> wasmcloud_interface_httpserver::HeaderMap {
    let mut converted = wasmcloud_interface_httpserver::HeaderMap::new();
    for name in headers.keys() {
        let mut values = Vec::new();
        for raw in headers.get_all(name).iter() {
            if let Ok(text) = raw.to_str() {
                values.push(text.to_owned());
            }
        }
        if values.is_empty() {
            continue;
        }
        converted.insert(name.to_string(), values);
    }
    converted
}
/// Copies the actor's response headers into the outgoing `http` header map.
///
/// Every value of a multi-valued header is appended, so repeated headers are
/// preserved. If a header name is not a valid HTTP header name, the whole
/// header is dropped. If an individual value is not a valid header value, only
/// that value is dropped and the header's remaining values are still sent.
/// Both cases are logged at error level together with the header name.
fn convert_response_headers(
    header: wasmcloud_interface_httpserver::HeaderMap,
    headers_mut: &mut http::header::HeaderMap,
) {
    for (k, vals) in header.into_iter() {
        let name = match http::header::HeaderName::from_bytes(k.as_bytes()) {
            Ok(name) => name,
            Err(e) => {
                error!(
                    header_name = %k,
                    error = %e,
                    "invalid response header name, sending without this header"
                );
                continue;
            }
        };
        for val in vals {
            // `HeaderValue` rejects control characters such as CR/LF (which also
            // blocks header injection); non-ASCII bytes are accepted. Only the
            // offending value is skipped.
            match http::header::HeaderValue::try_from(val) {
                Ok(value) => {
                    headers_mut.append(&name, value);
                }
                Err(e) => {
                    error!(
                        header_name = %name,
                        error = %e,
                        "invalid response header value, skipping this value"
                    );
                }
            }
        }
    }
}
/// Extracts the raw query string, or an empty `String` when the request has
/// none. The `Infallible` error type means this filter never rejects.
fn opt_raw_query() -> impl Filter<Extract = (String,), Error = Infallible> + Copy {
    let raw_query = warp::filters::query::raw();
    let no_query = warp::any().map(String::default);
    warp::any().and(raw_query.or(no_query).unify())
}
fn cors_filter(settings: &settings::ServiceSettings) -> Result<warp::filters::cors::Cors, Error> {
let mut cors: Builder = warp::cors();
match settings.cors.allowed_origins {
Some(ref allowed_origins) if !allowed_origins.is_empty() => {
cors = cors.allow_origins(allowed_origins.iter().map(AsRef::as_ref));
}
_ => {
cors = cors.allow_any_origin();
}
}
if let Some(ref allowed_headers) = settings.cors.allowed_headers {
cors = cors.allow_headers(allowed_headers.iter());
}
if let Some(ref allowed_methods) = settings.cors.allowed_methods {
for m in allowed_methods.iter() {
match http::method::Method::try_from(m.as_str()) {
Err(_) => return Err(Error::InvalidParameter(format!("method: '{}'", m))),
Ok(method) => {
cors = cors.allow_method(method);
}
}
}
}
if let Some(ref exposed_headers) = settings.cors.exposed_headers {
cors = cors.expose_headers(exposed_headers.iter());
}
if let Some(max_age) = settings.cors.max_age_secs {
cors = cors.max_age(std::time::Duration::from_secs(max_age));
}
Ok(cors.build())
}
pub fn convert_human_size(value: &str) -> Result<u64, Error> {
let value = value.trim();
let mut limit = None;
if value.is_empty() {
limit = Some(DEFAULT_MAX_CONTENT_LEN);
} else if let Ok(num) = value.parse::<u64>() {
limit = Some(num);
} else {
let (num, units) = value.split_at(value.len() - 1);
if let Ok(base_value) = num.trim().parse::<u64>() {
match units {
"k" | "K" => {
limit = Some(base_value * 1024);
}
"m" | "M" => {
limit = Some(base_value * 1024 * 1024);
}
"g" | "G" => {
limit = Some(base_value * 1024 * 1024 * 1024);
}
_ => {}
}
}
}
match limit {
Some(x) if x > 0 && x <= CONTENT_LEN_LIMIT => Ok(x),
Some(_) => {
Err(Error::Settings(format!("Invalid size in max_content_len '{}': value must be >0 and <= {}", value, settings::CONTENT_LEN_LIMIT)))
}
None => {
Err(Error::Settings(format!("Invalid size in max_content_len: '{}'. Should be a number, optionally followed by 'K', 'M', or 'G'. Example: '10M'. Value must be <= i32::MAX", value)))
}
}
}
impl CallActorFn {
    /// Invokes the wrapped callback with the given lattice id, link definition,
    /// request and optional timeout, and returns its boxed future.
    ///
    /// The callback already returns a `Pin<Box<dyn Future ...>>`, so that value
    /// is returned as-is. Wrapping it in another `Box::pin` would add a second
    /// heap allocation and an extra indirection on every request.
    fn call(
        &self,
        lattice_id: String,
        ld: Arc<LinkDefinition>,
        req: HttpRequest,
        timeout: Option<Duration>,
    ) -> Pin<Box<dyn Future<Output = RpcResult<HttpResponse>> + Send + 'static>> {
        (self.0)(lattice_id, ld, req, timeout)
    }
}
#[test]
fn parse_max_content_len() {
    // (input, expected byte count) pairs that must be accepted.
    let accepted: [(&str, u64); 9] = [
        ("", DEFAULT_MAX_CONTENT_LEN),
        ("4", 4),
        ("12345678", 12345678),
        ("2k", 2 * 1024),
        ("2K", 2 * 1024),
        ("10m", 10 * 1024 * 1024),
        ("10M", 10 * 1024 * 1024),
        ("10 M", 10 * 1024 * 1024),
        (" 5 k ", 5 * 1024),
    ];
    for &(input, expected) in accepted.iter() {
        assert_eq!(convert_human_size(input).unwrap(), expected, "input: {:?}", input);
    }

    // Inputs that must be rejected: a bare unit, zero, an unknown suffix, and
    // values at or above i32::MAX.
    let rejected = [
        String::from("k"),
        String::from("0"),
        String::from("1mb"),
        i32::MAX.to_string(),
        (i32::MAX as u64 + 1).to_string(),
    ];
    for input in rejected.iter() {
        assert!(convert_human_size(input).is_err(), "input: {:?}", input);
    }
}