use super::error::HttpError;
use super::error_status_code::ErrorStatusCode;
use super::extractor::RequestExtractor;
use super::http_util::CONTENT_TYPE_JSON;
use super::http_util::CONTENT_TYPE_OCTET_STREAM;
use super::server::DropshotState;
use super::server::ServerContext;
use crate::api_description::ApiEndpointBodyContentType;
use crate::api_description::ApiEndpointHeader;
use crate::api_description::ApiEndpointResponse;
use crate::api_description::ApiSchemaGenerator;
use crate::api_description::StubContext;
use crate::body::Body;
use crate::pagination::PaginationParams;
use crate::router::VariableSet;
use crate::schema_util::make_subschema_for;
use crate::schema_util::schema2struct;
use crate::schema_util::ReferenceVisitor;
use crate::to_map::to_map;
use async_trait::async_trait;
use http::HeaderMap;
use http::StatusCode;
use hyper::Response;
use schemars::JsonSchema;
use serde::de::DeserializeOwned;
use serde::Serialize;
use slog::Logger;
use std::cmp::min;
use std::convert::TryFrom;
use std::fmt::Debug;
use std::fmt::Formatter;
use std::fmt::Result as FmtResult;
use std::future::Future;
use std::marker::PhantomData;
use std::num::NonZeroU32;
use std::sync::Arc;
/// Outcome of turning an endpoint's output into an HTTP response: either the
/// finished `Response<Body>` or the `HttpError` explaining why that failed.
pub type HttpHandlerResult = Result<Response<Body>, HttpError>;
/// Per-request state passed to an endpoint handler as its first argument.
///
/// `#[non_exhaustive]` keeps code outside this crate from constructing the
/// struct with a literal, so fields can be added compatibly.
#[derive(Debug)]
#[non_exhaustive]
pub struct RequestContext<Context: ServerContext> {
/// Shared server state: the user-supplied private context (returned by
/// `RequestContext::context`) and the server configuration.
pub server: Arc<DropshotState<Context>>,
/// Metadata about the endpoint this request was routed to.
pub endpoint: RequestEndpointMetadata,
/// Identifier assigned to this request.
pub request_id: String,
/// Logger for this request.
pub log: Logger,
/// Owned snapshot of the incoming request's method, URI, version, headers,
/// and peer address.
pub request: RequestInfo,
}
/// Owned copy of the request metadata a handler may inspect: method, URI,
/// HTTP version, headers, and the peer's socket address. The request body is
/// not stored here.
#[derive(Debug)]
pub struct RequestInfo {
method: http::Method,
uri: http::Uri,
version: http::Version,
headers: http::HeaderMap<http::HeaderValue>,
remote_addr: std::net::SocketAddr,
}
impl RequestInfo {
    /// Snapshots the method, URI, version, and headers of `request` (cloning
    /// them so the original request is left untouched) and pairs them with
    /// the peer address `remote_addr`. The body type `B` is never inspected.
    pub fn new<B>(
        request: &hyper::Request<B>,
        remote_addr: std::net::SocketAddr,
    ) -> Self {
        let method = request.method().clone();
        let uri = request.uri().clone();
        let version = request.version();
        let headers = request.headers().clone();
        Self { method, uri, version, headers, remote_addr }
    }
}
impl RequestInfo {
/// HTTP method of the request.
pub fn method(&self) -> &http::Method {
&self.method
}
/// Request URI.
pub fn uri(&self) -> &http::Uri {
&self.uri
}
/// HTTP protocol version of the request.
pub fn version(&self) -> http::Version {
self.version
}
/// Request headers.
pub fn headers(&self) -> &http::HeaderMap<http::HeaderValue> {
&self.headers
}
/// Socket address of the peer that sent the request.
pub fn remote_addr(&self) -> std::net::SocketAddr {
self.remote_addr
}
/// Deprecated compatibility shim: it just returns `self`, so callers should
/// read `rqctx.request` directly instead of going through `lock().await`.
#[deprecated(
since = "0.9.0",
note = "use `rqctx.request` directly instead of \
`rqctx.request.lock().await`"
)]
pub async fn lock(&self) -> &Self {
self
}
}
impl<Context: ServerContext> RequestContext<Context> {
    /// Borrow the user-supplied server context.
    pub fn context(&self) -> &Context {
        &self.server.private
    }

    /// Maximum request body size, in bytes, allowed for this request.
    ///
    /// Uses the endpoint's own limit when one is configured and otherwise
    /// falls back to the server-wide `default_request_body_max_bytes`.
    pub fn request_body_max_bytes(&self) -> usize {
        match self.endpoint.request_body_max_bytes {
            Some(endpoint_limit) => endpoint_limit,
            None => self.server.config.default_request_body_max_bytes,
        }
    }

    /// Number of items to return for a paginated request.
    ///
    /// A client-requested limit is capped at the server's `page_max_nitems`.
    /// If the client gave no limit, the server's `page_default_nitems` is
    /// used. This implementation always returns `Ok`.
    pub fn page_limit<ScanParams, PageSelector>(
        &self,
        pag_params: &PaginationParams<ScanParams, PageSelector>,
    ) -> Result<NonZeroU32, HttpError>
    where
        ScanParams: DeserializeOwned,
        PageSelector: DeserializeOwned + Serialize,
    {
        let config = &self.server.config;
        let effective = match pag_params.limit {
            Some(requested) => min(requested, config.page_max_nitems),
            None => config.page_default_nitems,
        };
        Ok(effective)
    }
}
/// Information about the endpoint that a request was routed to.
#[derive(Debug)]
pub struct RequestEndpointMetadata {
/// The endpoint's operation ID.
pub operation_id: String,
/// Variables from the matched route (presumably the path parameters —
/// confirm against `crate::router`).
pub variables: VariableSet,
/// Content type the endpoint declares for its request body.
pub body_content_type: ApiEndpointBodyContentType,
/// Per-endpoint request body size limit. When `None`,
/// `RequestContext::request_body_max_bytes` falls back to the server's
/// `default_request_body_max_bytes`.
pub request_body_max_bytes: Option<usize>,
}
/// Lets generic code name the server context type `T` of a
/// `RequestContext<T>` through the associated type `Context`.
pub trait RequestContextArgument {
type Context;
}
impl<T: 'static + ServerContext> RequestContextArgument for RequestContext<T> {
type Context = T;
}
/// A callable endpoint handler.
///
/// `FuncParams` is the set of already-extracted request parameters and
/// `ResponseType` is the handler's success type. Implemented for async
/// functions by the `impl_HttpHandlerFunc_for_func_with_params!` macro below.
#[async_trait]
pub trait HttpHandlerFunc<Context, FuncParams, ResponseType>:
Send + Sync + 'static
where
Context: ServerContext,
FuncParams: RequestExtractor,
ResponseType: HttpResponse + Send + Sync + 'static,
{
/// Error type the handler returns. Extraction failures (`HttpError`) are
/// converted into it via its `From<HttpError>` bound.
type Error: HttpResponseError;
/// Invoke the handler and render its result into an HTTP response.
async fn handle_request(
&self,
rqctx: RequestContext<Context>,
params: FuncParams,
) -> Result<Response<Body>, HandlerError>;
}
/// Error produced while handling a request.
pub enum HandlerError {
/// A handler's own error type rendered itself into a response. `message`
/// is that error's `Display` text.
Handler { message: String, rsp: Response<Body> },
/// A Dropshot `HttpError`, used when the handler's error could not render
/// itself into a response.
Dropshot(HttpError),
}
impl HandlerError {
    /// HTTP status this error will be reported with.
    pub(crate) fn status_code(&self) -> StatusCode {
        match self {
            Self::Dropshot(e) => e.status_code.as_status(),
            Self::Handler { rsp, .. } => rsp.status(),
        }
    }

    /// Message intended for logs rather than for the client.
    pub(crate) fn internal_message(&self) -> &String {
        match self {
            Self::Dropshot(e) => &e.internal_message,
            Self::Handler { message, .. } => message,
        }
    }

    /// Message exposed to the client. Only Dropshot errors carry one; custom
    /// handler errors have already rendered their own response body.
    pub(crate) fn external_message(&self) -> Option<&String> {
        if let Self::Dropshot(e) = self {
            Some(&e.external_message)
        } else {
            None
        }
    }

    /// Convert into the final HTTP response and stamp it with `request_id`.
    ///
    /// # Panics
    ///
    /// For a handler-rendered response, panics if `request_id` is not a valid
    /// header value; this path is treated as unreachable.
    pub(crate) fn into_response(self, request_id: &str) -> Response<Body> {
        let mut rsp = match self {
            Self::Dropshot(e) => return e.into_response(request_id),
            Self::Handler { rsp, .. } => rsp,
        };
        let header =
            http::HeaderValue::from_str(request_id).unwrap_or_else(|e| {
                unreachable!(
                    "request ID {request_id:?} is not a valid \
                    HeaderValue: {e}",
                )
            });
        rsp.headers_mut().insert(crate::HEADER_REQUEST_ID, header);
        rsp
    }
}
impl<E> From<E> for HandlerError
where
E: HttpResponseError,
{
fn from(e: E) -> Self {
let message = e.to_string();
let status = e.status_code();
match e.to_response(Response::builder().status(status.as_status())) {
Ok(rsp) => Self::Handler { message, rsp },
Err(e) => Self::Dropshot(e),
}
}
}
/// Error types an endpoint handler may return.
///
/// The error must be able to render itself as response content
/// (`HttpResponseContent`), be constructible from a Dropshot `HttpError`
/// (used for extraction and response-building failures), and be displayable
/// (its text becomes the internal log message).
#[diagnostic::on_unimplemented(
note = "consider using `dropshot::HttpError`, unless custom error \
presentation is needed"
)]
pub trait HttpResponseError:
HttpResponseContent + From<HttpError> + std::fmt::Display
{
/// Status code the error response is sent with.
fn status_code(&self) -> ErrorStatusCode;
}
// Implements `HttpHandlerFunc` for every function or closure that takes a
// `RequestContext<Context>` followed by the listed parameter types and
// returns a future resolving to `Result<ResponseType, ErrorType>`.
//
// Each `($i, $T)` pair is a tuple index plus a type parameter name. The
// extracted parameters arrive as a single tuple, and the expansion passes
// its elements to the function positionally by index.
macro_rules! impl_HttpHandlerFunc_for_func_with_params {
($(($i:tt, $T:tt)),*) => {
#[async_trait]
impl<Context, FuncType, FutureType, ResponseType, ErrorType, $($T,)*>
HttpHandlerFunc<Context, ($($T,)*), ResponseType> for FuncType
where
Context: ServerContext,
FuncType: Fn(RequestContext<Context>, $($T,)*)
-> FutureType + Send + Sync + 'static,
FutureType: Future<Output = Result<ResponseType, ErrorType>>
+ Send + 'static,
ResponseType: HttpResponse + Send + Sync + 'static,
ErrorType: HttpResponseError + Send + Sync + 'static,
($($T,)*): RequestExtractor,
$($T: Send + Sync + 'static,)*
{
type Error = ErrorType;
async fn handle_request(
&self,
rqctx: RequestContext<Context>,
_param_tuple: ($($T,)*),
) -> Result<Response<Body>, HandlerError>
{
// A handler `Err(ErrorType)` becomes a `HandlerError` through `?`
// (via the blanket `From<E: HttpResponseError>` impl).
let response: ResponseType = (self)(rqctx, $(_param_tuple.$i,)*).await?;
response.to_result().map_err(|error| {
// If turning the endpoint's response into a
// `http::Response<Body>` failed, try to convert the `HttpError` into
// the endpoint's error type.
let error = ErrorType::from(error);
HandlerError::from(error)
})
}
}
}}
// Handlers taking zero through three extracted parameters.
impl_HttpHandlerFunc_for_func_with_params!();
impl_HttpHandlerFunc_for_func_with_params!((0, T0));
impl_HttpHandlerFunc_for_func_with_params!((0, T1), (1, T2));
impl_HttpHandlerFunc_for_func_with_params!((0, T1), (1, T2), (2, T3));
/// Type-erased request handler registered with the router. It takes a raw
/// hyper request and produces the HTTP response.
#[async_trait]
pub trait RouteHandler<Context: ServerContext>: Debug + Send + Sync {
/// Human-readable label for this handler.
fn label(&self) -> &str;
/// Handle one request: extract parameters from `request` and run the
/// endpoint.
async fn handle_request(
&self,
rqctx: RequestContext<Context>,
request: hyper::Request<crate::Body>,
) -> Result<Response<Body>, HandlerError>;
}
/// Adapts a typed `HttpHandlerFunc` into a type-erased `RouteHandler`.
pub struct HttpRouteHandler<Context, HandlerType, FuncParams, ResponseType>
where
Context: ServerContext,
HandlerType: HttpHandlerFunc<Context, FuncParams, ResponseType>,
FuncParams: RequestExtractor,
ResponseType: HttpResponse + Send + Sync + 'static,
{
/// The wrapped endpoint function.
handler: HandlerType,
/// Label returned by `RouteHandler::label` and shown by `Debug`.
label: String,
/// Ties the otherwise-unused type parameters to this struct.
phantom: PhantomData<(FuncParams, ResponseType, Context)>,
}
impl<Context, HandlerType, FuncParams, ResponseType> Debug
    for HttpRouteHandler<Context, HandlerType, FuncParams, ResponseType>
where
    Context: ServerContext,
    HandlerType: HttpHandlerFunc<Context, FuncParams, ResponseType>,
    FuncParams: RequestExtractor,
    ResponseType: HttpResponse + Send + Sync + 'static,
{
    /// Renders as `handler: <label>`; the wrapped function itself is not
    /// printed.
    fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult {
        f.write_str("handler: ")?;
        f.write_str(&self.label)
    }
}
#[async_trait]
impl<Context, HandlerType, FuncParams, ResponseType> RouteHandler<Context>
for HttpRouteHandler<Context, HandlerType, FuncParams, ResponseType>
where
Context: ServerContext,
HandlerType: HttpHandlerFunc<Context, FuncParams, ResponseType>,
FuncParams: RequestExtractor + 'static,
ResponseType: HttpResponse + Send + Sync + 'static,
{
fn label(&self) -> &str {
&self.label
}
async fn handle_request(
&self,
rqctx: RequestContext<Context>,
request: hyper::Request<crate::Body>,
) -> Result<Response<Body>, HandlerError> {
let funcparams = RequestExtractor::from_request(&rqctx, request)
.await
.map_err(<HandlerType::Error>::from)?;
let future = self.handler.handle_request(rqctx, funcparams);
future.await
}
}
impl<Context, HandlerType, FuncParams, ResponseType>
    HttpRouteHandler<Context, HandlerType, FuncParams, ResponseType>
where
    Context: ServerContext,
    HandlerType: HttpHandlerFunc<Context, FuncParams, ResponseType>,
    FuncParams: RequestExtractor + 'static,
    ResponseType: HttpResponse + Send + Sync + 'static,
{
    /// Wrap `handler` under the placeholder label `<unlabeled handler>`.
    pub fn new(handler: HandlerType) -> Arc<dyn RouteHandler<Context>> {
        Self::new_with_name(handler, "<unlabeled handler>")
    }

    /// Wrap `handler` under the given `label` and return it as a shared,
    /// type-erased `RouteHandler`.
    pub fn new_with_name(
        handler: HandlerType,
        label: &str,
    ) -> Arc<dyn RouteHandler<Context>> {
        let route = Self {
            handler,
            label: String::from(label),
            phantom: PhantomData,
        };
        Arc::new(route)
    }
}
/// Placeholder `RouteHandler` for `StubContext` that carries only a label.
/// Calling it panics.
#[derive(Debug)]
pub(crate) struct StubRouteHandler {
label: String,
}
#[async_trait]
impl RouteHandler<StubContext> for StubRouteHandler {
    fn label(&self) -> &str {
        self.label.as_str()
    }

    /// Stub handlers have no implementation; invoking one always panics
    /// with a message naming the stub's label.
    async fn handle_request(
        &self,
        _rqctx: RequestContext<StubContext>,
        _request: hyper::Request<crate::Body>,
    ) -> Result<Response<Body>, HandlerError> {
        let label = &self.label;
        unimplemented!("stub handler called, not implemented: {}", label)
    }
}
impl StubRouteHandler {
    /// Build a stub handler with the given `label`, returned as a shared,
    /// type-erased `RouteHandler`.
    pub(crate) fn new_with_name(
        label: &str,
    ) -> Arc<dyn RouteHandler<StubContext>> {
        let stub = Self { label: label.to_owned() };
        Arc::new(stub)
    }
}
/// A value an endpoint can return on success.
pub trait HttpResponse {
/// Render into an HTTP response, or the `HttpError` explaining why that
/// failed.
fn to_result(self) -> HttpHandlerResult;
/// Static description of this response type for API metadata.
fn response_metadata() -> ApiEndpointResponse;
/// Status code this response carries.
fn status_code(&self) -> StatusCode;
}
impl HttpResponse for Response<Body> {
fn to_result(self) -> HttpHandlerResult {
Ok(self)
}
fn response_metadata() -> ApiEndpointResponse {
ApiEndpointResponse::default()
}
fn status_code(&self) -> StatusCode {
self.status()
}
}
/// Response content made of arbitrary bytes. Its `HttpResponseContent` impl
/// sends it with the octet-stream content type and no schema.
pub struct FreeformBody(pub Body);
impl From<Body> for FreeformBody {
    fn from(body: Body) -> Self {
        FreeformBody(body)
    }
}
/// Marker content type for responses with no body.
#[doc(hidden)]
pub struct Empty;
/// A value that can be written as the body of an HTTP response.
pub trait HttpResponseContent {
/// Finish `builder` (status already set by the caller) using `self` as the
/// body.
fn to_response(self, builder: http::response::Builder)
-> HttpHandlerResult;
/// Schema information for the body, or `None` when there is none to
/// describe.
fn content_metadata() -> Option<ApiSchemaGenerator>;
}
impl HttpResponseContent for FreeformBody {
fn to_response(
self,
builder: http::response::Builder,
) -> HttpHandlerResult {
Ok(builder
.header(http::header::CONTENT_TYPE, CONTENT_TYPE_OCTET_STREAM)
.body(self.0)?)
}
fn content_metadata() -> Option<ApiSchemaGenerator> {
None
}
}
impl HttpResponseContent for Empty {
    /// Finish the response with an empty body.
    fn to_response(
        self,
        builder: http::response::Builder,
    ) -> HttpHandlerResult {
        let rsp = builder.body(Body::empty())?;
        Ok(rsp)
    }

    /// Describe the body with the schema `false` (no value is valid) and no
    /// dependencies.
    fn content_metadata() -> Option<ApiSchemaGenerator> {
        let nothing = schemars::schema::Schema::Bool(false);
        Some(ApiSchemaGenerator::Static {
            schema: Box::new(nothing),
            dependencies: Default::default(),
        })
    }
}
impl<T> HttpResponseContent for T
where
    T: JsonSchema + Serialize + Send + Sync + 'static,
{
    /// Serialize `self` to JSON and send it with the JSON `Content-Type`.
    ///
    /// A serialization failure becomes an internal `HttpError` carrying the
    /// serde error text.
    fn to_response(
        self,
        builder: http::response::Builder,
    ) -> HttpHandlerResult {
        let json = match serde_json::to_string(&self) {
            Ok(json) => json,
            Err(e) => {
                return Err(HttpError::for_internal_error(e.to_string()))
            }
        };
        let rsp = builder
            .header(http::header::CONTENT_TYPE, CONTENT_TYPE_JSON)
            .body(json.into())?;
        Ok(rsp)
    }

    /// Schema is generated lazily from `T`'s `JsonSchema` implementation.
    fn content_metadata() -> Option<ApiSchemaGenerator> {
        let generator = ApiSchemaGenerator::Gen {
            name: Self::schema_name,
            schema: make_subschema_for::<Self>,
        };
        Some(generator)
    }
}
impl HttpResponseContent for HttpError {
fn to_response(self, _: http::response::Builder) -> HttpHandlerResult {
Err(self)
}
fn content_metadata() -> Option<ApiSchemaGenerator> {
use crate::error::HttpErrorResponseBody;
Some(ApiSchemaGenerator::Gen {
name: HttpErrorResponseBody::schema_name,
schema: make_subschema_for::<HttpErrorResponseBody>,
})
}
}
impl HttpResponseError for HttpError {
fn status_code(&self) -> ErrorStatusCode {
self.status_code
}
}
/// A success response type with a fixed status code and a typed body.
pub trait HttpCodedResponse:
    Into<HttpHandlerResult> + Send + Sync + 'static
{
    /// Content type of the response body.
    type Body: HttpResponseContent;
    /// Status code sent with every response of this type.
    const STATUS_CODE: StatusCode;
    /// Human-readable description used in the response metadata.
    const DESCRIPTION: &'static str;

    /// Render `body` into a response carrying `Self::STATUS_CODE`.
    fn for_object(body: Self::Body) -> HttpHandlerResult {
        let builder = Response::builder().status(Self::STATUS_CODE);
        body.to_response(builder)
    }
}
impl<T> HttpResponse for T
where
    T: HttpCodedResponse,
{
    fn to_result(self) -> HttpHandlerResult {
        let result: HttpHandlerResult = self.into();
        result
    }

    /// Metadata is built from the body schema, the fixed status code, and
    /// the description; every other field takes its default.
    fn response_metadata() -> ApiEndpointResponse {
        let description = T::DESCRIPTION.to_string();
        ApiEndpointResponse {
            description: Some(description),
            success: Some(T::STATUS_CODE),
            schema: T::Body::content_metadata(),
            ..Default::default()
        }
    }

    fn status_code(&self) -> StatusCode {
        T::STATUS_CODE
    }
}
/// `201 Created` response whose body is `T`.
pub struct HttpResponseCreated<T: HttpResponseContent + Send + Sync + 'static>(
    pub T,
);
impl<T: HttpResponseContent + Send + Sync + 'static> HttpCodedResponse
    for HttpResponseCreated<T>
{
    type Body = T;
    const STATUS_CODE: StatusCode = StatusCode::CREATED;
    const DESCRIPTION: &'static str = "successful creation";
}
impl<T: HttpResponseContent + Send + Sync + 'static>
    From<HttpResponseCreated<T>> for HttpHandlerResult
{
    fn from(response: HttpResponseCreated<T>) -> HttpHandlerResult {
        let HttpResponseCreated(body) = response;
        HttpResponseCreated::<T>::for_object(body)
    }
}
/// `202 Accepted` response whose body is `T`.
pub struct HttpResponseAccepted<T: HttpResponseContent + Send + Sync + 'static>(
    pub T,
);
impl<T: HttpResponseContent + Send + Sync + 'static> HttpCodedResponse
    for HttpResponseAccepted<T>
{
    type Body = T;
    const STATUS_CODE: StatusCode = StatusCode::ACCEPTED;
    const DESCRIPTION: &'static str = "successfully enqueued operation";
}
impl<T: HttpResponseContent + Send + Sync + 'static>
    From<HttpResponseAccepted<T>> for HttpHandlerResult
{
    fn from(response: HttpResponseAccepted<T>) -> HttpHandlerResult {
        let HttpResponseAccepted(body) = response;
        HttpResponseAccepted::<T>::for_object(body)
    }
}
/// `200 OK` response whose body is `T`.
pub struct HttpResponseOk<T: HttpResponseContent + Send + Sync + 'static>(
    pub T,
);
impl<T: HttpResponseContent + Send + Sync + 'static> HttpCodedResponse
    for HttpResponseOk<T>
{
    type Body = T;
    const STATUS_CODE: StatusCode = StatusCode::OK;
    const DESCRIPTION: &'static str = "successful operation";
}
impl<T: HttpResponseContent + Send + Sync + 'static> From<HttpResponseOk<T>>
    for HttpHandlerResult
{
    fn from(response: HttpResponseOk<T>) -> HttpHandlerResult {
        let HttpResponseOk(body) = response;
        HttpResponseOk::<T>::for_object(body)
    }
}
/// `204 No Content` response reporting a successful deletion; no body.
pub struct HttpResponseDeleted();
impl HttpCodedResponse for HttpResponseDeleted {
    type Body = Empty;
    const STATUS_CODE: StatusCode = StatusCode::NO_CONTENT;
    const DESCRIPTION: &'static str = "successful deletion";
}
impl From<HttpResponseDeleted> for HttpHandlerResult {
    fn from(_deleted: HttpResponseDeleted) -> HttpHandlerResult {
        <HttpResponseDeleted as HttpCodedResponse>::for_object(Empty)
    }
}
/// `204 No Content` response reporting a successful update; no body.
pub struct HttpResponseUpdatedNoContent();
impl HttpCodedResponse for HttpResponseUpdatedNoContent {
    type Body = Empty;
    const STATUS_CODE: StatusCode = StatusCode::NO_CONTENT;
    const DESCRIPTION: &'static str = "resource updated";
}
impl From<HttpResponseUpdatedNoContent> for HttpHandlerResult {
    fn from(_updated: HttpResponseUpdatedNoContent) -> HttpHandlerResult {
        <HttpResponseUpdatedNoContent as HttpCodedResponse>::for_object(Empty)
    }
}
// Typed headers shared by the redirect responses; `location` is the redirect
// target. Plain `//` comments are used on purpose: this type derives
// `JsonSchema`, and `///` doc comments would become schema descriptions in
// the generated API metadata.
#[derive(JsonSchema, Serialize)]
#[doc(hidden)]
pub struct RedirectHeaders {
location: String,
}
/// `302 Found` redirect: an empty-bodied response whose target URL is carried
/// in `RedirectHeaders`.
pub type HttpResponseFound =
    HttpResponseHeaders<HttpResponseFoundStatus, RedirectHeaders>;

/// Build a `302 Found` redirect to `location`.
///
/// # Errors
///
/// Returns an internal `HttpError` if `location` is not a valid HTTP header
/// value.
pub fn http_response_found(
    location: String,
) -> Result<HttpResponseFound, HttpError> {
    // Validate now so a bad URL is reported here rather than when the
    // response is rendered.
    if let Err(e) = http::HeaderValue::from_str(&location) {
        return Err(http_redirect_error(e, &location));
    }
    let headers = RedirectHeaders { location };
    Ok(HttpResponseHeaders::new(HttpResponseFoundStatus, headers))
}
/// Internal error reported when a redirect target cannot be used as a header
/// value. The message includes the quoted URL and the underlying error.
fn http_redirect_error(
    error: http::header::InvalidHeaderValue,
    location: &str,
) -> HttpError {
    let message =
        format!("error encoding redirect URL {location:?}: {error:#}");
    HttpError::for_internal_error(message)
}
/// Status part of `HttpResponseFound`: `302 Found` with an empty body.
#[doc(hidden)]
pub struct HttpResponseFoundStatus;
impl HttpCodedResponse for HttpResponseFoundStatus {
    type Body = Empty;
    const STATUS_CODE: StatusCode = StatusCode::FOUND;
    const DESCRIPTION: &'static str = "redirect (found)";
}
impl From<HttpResponseFoundStatus> for HttpHandlerResult {
    fn from(_status: HttpResponseFoundStatus) -> HttpHandlerResult {
        <HttpResponseFoundStatus as HttpCodedResponse>::for_object(Empty)
    }
}
/// `303 See Other` redirect: an empty-bodied response whose target URL is
/// carried in `RedirectHeaders`.
pub type HttpResponseSeeOther =
    HttpResponseHeaders<HttpResponseSeeOtherStatus, RedirectHeaders>;

/// Build a `303 See Other` redirect to `location`.
///
/// # Errors
///
/// Returns an internal `HttpError` if `location` is not a valid HTTP header
/// value.
pub fn http_response_see_other(
    location: String,
) -> Result<HttpResponseSeeOther, HttpError> {
    // Validate now so a bad URL is reported here rather than when the
    // response is rendered.
    if let Err(e) = http::HeaderValue::from_str(&location) {
        return Err(http_redirect_error(e, &location));
    }
    let headers = RedirectHeaders { location };
    Ok(HttpResponseHeaders::new(HttpResponseSeeOtherStatus, headers))
}
/// Status part of `HttpResponseSeeOther`: `303 See Other` with an empty body.
#[doc(hidden)]
pub struct HttpResponseSeeOtherStatus;
impl HttpCodedResponse for HttpResponseSeeOtherStatus {
    type Body = Empty;
    const STATUS_CODE: StatusCode = StatusCode::SEE_OTHER;
    const DESCRIPTION: &'static str = "redirect (see other)";
}
impl From<HttpResponseSeeOtherStatus> for HttpHandlerResult {
    fn from(_status: HttpResponseSeeOtherStatus) -> HttpHandlerResult {
        <HttpResponseSeeOtherStatus as HttpCodedResponse>::for_object(Empty)
    }
}
/// `307 Temporary Redirect`: an empty-bodied response whose target URL is
/// carried in `RedirectHeaders`.
pub type HttpResponseTemporaryRedirect =
    HttpResponseHeaders<HttpResponseTemporaryRedirectStatus, RedirectHeaders>;

/// Build a `307 Temporary Redirect` to `location`.
///
/// # Errors
///
/// Returns an internal `HttpError` if `location` is not a valid HTTP header
/// value.
pub fn http_response_temporary_redirect(
    location: String,
) -> Result<HttpResponseTemporaryRedirect, HttpError> {
    // Validate now so a bad URL is reported here rather than when the
    // response is rendered.
    if let Err(e) = http::HeaderValue::from_str(&location) {
        return Err(http_redirect_error(e, &location));
    }
    let headers = RedirectHeaders { location };
    Ok(HttpResponseHeaders::new(
        HttpResponseTemporaryRedirectStatus,
        headers,
    ))
}
/// Status part of `HttpResponseTemporaryRedirect`: `307 Temporary Redirect`
/// with an empty body.
#[doc(hidden)]
pub struct HttpResponseTemporaryRedirectStatus;
impl HttpCodedResponse for HttpResponseTemporaryRedirectStatus {
    type Body = Empty;
    const STATUS_CODE: StatusCode = StatusCode::TEMPORARY_REDIRECT;
    const DESCRIPTION: &'static str = "redirect (temporary redirect)";
}
impl From<HttpResponseTemporaryRedirectStatus> for HttpHandlerResult {
    fn from(_status: HttpResponseTemporaryRedirectStatus) -> HttpHandlerResult {
        <HttpResponseTemporaryRedirectStatus as HttpCodedResponse>::for_object(
            Empty,
        )
    }
}
// Default typed-header set for `HttpResponseHeaders`: no fields, so no typed
// headers. Plain `//` comments are used because `///` on a `JsonSchema`
// type would become a schema description.
#[derive(Serialize, JsonSchema)]
pub struct NoHeaders {}
/// Wraps a coded response `T` and adds response headers.
///
/// Two kinds of headers are supported: typed headers described by `H`
/// (serialized to a name/value map and included in the API metadata), and
/// arbitrary extra headers added through `headers_mut`.
pub struct HttpResponseHeaders<
T: HttpCodedResponse,
H: JsonSchema + Serialize + Send + Sync + 'static = NoHeaders,
> {
/// The wrapped response.
body: T,
/// Typed headers, serialized into header name/value pairs when rendering.
structured_headers: H,
/// Extra untyped headers, applied after the typed ones.
other_headers: HeaderMap,
}
impl<T: HttpCodedResponse> HttpResponseHeaders<T, NoHeaders> {
    /// Wrap `body` with no typed headers and an empty set of extra headers.
    pub fn new_unnamed(body: T) -> Self {
        Self::new(body, NoHeaders {})
    }
}
impl<
        T: HttpCodedResponse,
        H: JsonSchema + Serialize + Send + Sync + 'static,
    > HttpResponseHeaders<T, H>
{
    /// Wrap `body` with the typed headers `headers` and no extra headers.
    pub fn new(body: T, headers: H) -> Self {
        let other_headers = HeaderMap::default();
        Self { body, structured_headers: headers, other_headers }
    }

    /// Mutable access to the extra, untyped headers. They are applied after
    /// the typed headers when the response is rendered.
    pub fn headers_mut(&mut self) -> &mut HeaderMap {
        &mut self.other_headers
    }
}
impl<
        T: HttpCodedResponse,
        H: JsonSchema + Serialize + Send + Sync + 'static,
    > HttpResponse for HttpResponseHeaders<T, H>
{
    /// Render the wrapped response, then add headers in two passes.
    ///
    /// The typed headers are inserted first; each replaces any same-named
    /// header already present. The extra untyped headers are merged in
    /// after that.
    ///
    /// Fails with an internal `HttpError` if the typed headers cannot be
    /// converted to a map, or if any name or value is not a valid header.
    fn to_result(self) -> HttpHandlerResult {
        let Self { body, structured_headers, other_headers } = self;
        let mut response = body.into()?;
        let typed = to_map(&structured_headers).map_err(|e| {
            HttpError::for_internal_error(format!(
                "error processing headers: {}",
                e.0
            ))
        })?;
        let response_headers = response.headers_mut();
        for (raw_name, raw_value) in typed {
            let name = http::header::HeaderName::try_from(raw_name)
                .map_err(|e| HttpError::for_internal_error(e.to_string()))?;
            let value = http::header::HeaderValue::try_from(raw_value)
                .map_err(|e| HttpError::for_internal_error(e.to_string()))?;
            response_headers.insert(name, value);
        }
        response_headers.extend(other_headers);
        Ok(response)
    }

    /// Start from the wrapped response's metadata and replace its header
    /// list with one entry per member of `H`'s schema.
    ///
    /// References inside each header schema are collected as that header's
    /// dependencies.
    fn response_metadata() -> ApiEndpointResponse {
        let mut metadata = T::response_metadata();
        let mut generator = schemars::gen::SchemaGenerator::new(
            schemars::gen::SchemaSettings::openapi3(),
        );
        let root = generator.root_schema_for::<H>().schema.into();
        let members =
            schema2struct(&H::schema_name(), "headers", &root, &generator, true);
        let mut headers = Vec::new();
        for member in members {
            let mut member_schema = member.schema;
            let mut visitor = ReferenceVisitor::new(&generator);
            schemars::visit::visit_schema(&mut visitor, &mut member_schema);
            headers.push(ApiEndpointHeader {
                name: member.name,
                description: member.description,
                schema: ApiSchemaGenerator::Static {
                    schema: Box::new(member_schema),
                    dependencies: visitor.dependencies(),
                },
                required: member.required,
            });
        }
        metadata.headers = headers;
        metadata
    }

    fn status_code(&self) -> StatusCode {
        T::STATUS_CODE
    }
}