use std::error::Error;
use std::fmt::{self, Display, Formatter};
use std::future::ready;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
use futures::{future, Future, TryFutureExt};
use hyper::client::HttpConnector;
use hyper::header::{HeaderMap, ALLOW, CONTENT_LENGTH, CONTENT_TYPE};
use hyper::http::{self, HeaderValue};
use hyper::service::Service;
use hyper::{Body, Client, Method, Request, Response, StatusCode, Uri, Version};
use prost::{DecodeError, EncodeError, Message};
/// Boxed, sendable future that resolves to a decoded [`ServiceResponse`] or a
/// [`ProstTwirpError`]; returned by [`HyperClient::go`].
pub type PTRes<O> = Pin<
    Box<dyn Future<Output = Result<ServiceResponse<O>, ProstTwirpError>> + Send + 'static>,
>;

/// Content type attached to JSON Twirp error responses.
const JSON_CONTENT_TYPE: &str = "application/json";
/// Content type required on, and attached to, protobuf request/response bodies.
const PROTOBUF_CONTENT_TYPE: &str = "application/protobuf";
#[derive(Debug)]
pub struct ServiceRequest<T: Message> {
pub uri: Uri,
pub method: Method,
pub version: Version,
pub headers: HeaderMap,
pub input: T,
}
impl<T: Message> ServiceRequest<T> {
pub fn new(input: T) -> ServiceRequest<T> {
let mut headers = HeaderMap::new();
headers.insert(
CONTENT_TYPE,
HeaderValue::from_static(PROTOBUF_CONTENT_TYPE),
);
ServiceRequest {
uri: Default::default(),
method: Method::POST,
version: Version::default(),
headers,
input,
}
}
pub fn clone_with_input(&self, input: T) -> ServiceRequest<T> {
ServiceRequest {
uri: self.uri.clone(),
method: self.method.clone(),
version: self.version,
headers: self.headers.clone(),
input,
}
}
}
impl<T: Message + Default + 'static> From<T> for ServiceRequest<T> {
fn from(v: T) -> ServiceRequest<T> {
ServiceRequest::new(v)
}
}
impl<T: Message + Default + 'static> ServiceRequest<T> {
pub fn to_hyper_request(&self) -> Result<Request<Body>, ProstTwirpError> {
let mut body = Vec::new();
self.input
.encode(&mut body)
.map_err(ProstTwirpError::ProstEncodeError)?;
let mut builder = Request::post(self.uri.clone());
builder.headers_mut().unwrap().clone_from(&self.headers);
builder
.header(CONTENT_LENGTH, body.len() as u64)
.body(Body::from(body))
.map_err(ProstTwirpError::from)
}
pub async fn from_hyper_request(
req: Request<Body>,
) -> Result<ServiceRequest<T>, ProstTwirpError> {
if req.method() != Method::POST {
return Err(ProstTwirpError::InvalidMethod);
} else if req
.headers()
.get(CONTENT_TYPE)
.map_or(true, |v| v != PROTOBUF_CONTENT_TYPE)
{
return Err(ProstTwirpError::InvalidContentType);
}
let uri = req.uri().clone();
let method = req.method().clone();
let version = req.version();
let headers = req.headers().clone();
let body_bytes = hyper::body::to_bytes(req.into_body()).await?;
match T::decode(body_bytes.clone()) {
Ok(input) => Ok(ServiceRequest {
uri,
method,
version,
headers,
input,
}),
Err(err) => Err(ProstTwirpError::AfterBodyError {
status: None,
method: Some(method),
version,
headers,
err: Box::new(ProstTwirpError::ProstDecodeError(err)),
body: body_bytes.to_vec(),
}),
}
}
}
/// A Twirp response: HTTP metadata together with the protobuf output message.
#[derive(Debug)]
pub struct ServiceResponse<M: Message> {
    /// HTTP version.
    pub version: Version,
    /// Response headers.
    pub headers: HeaderMap,
    /// HTTP status code.
    pub status: StatusCode,
    /// The protobuf output message.
    pub output: M,
}

impl<M: Message> ServiceResponse<M> {
    /// Wraps `output` in a `200 OK` response with `Version::default()` and a single
    /// `Content-Type: application/protobuf` header.
    pub fn new(output: M) -> ServiceResponse<M> {
        let mut headers = HeaderMap::new();
        headers.insert(CONTENT_TYPE, HeaderValue::from_static(PROTOBUF_CONTENT_TYPE));
        ServiceResponse {
            output,
            status: StatusCode::OK,
            headers,
            version: Version::default(),
        }
    }

    /// Returns a copy of this response's version, headers and status paired with a new
    /// `output` message.
    pub fn clone_with_output(&self, output: M) -> ServiceResponse<M> {
        let ServiceResponse {
            version,
            headers,
            status,
            ..
        } = self;
        ServiceResponse {
            version: *version,
            headers: headers.clone(),
            status: *status,
            output,
        }
    }
}

impl<M: Message + Default + 'static> From<M> for ServiceResponse<M> {
    /// Same as [`ServiceResponse::new`].
    fn from(output: M) -> ServiceResponse<M> {
        Self::new(output)
    }
}
impl<M: Message + Default> ServiceResponse<M> {
pub async fn from_hyper_response(resp: Response<Body>) -> Result<Self, ProstTwirpError> {
let version = resp.version();
let headers = resp.headers().clone();
let status = resp.status();
let body_bytes = hyper::body::to_bytes(resp.into_body()).await?;
let err = if status.is_success() {
match M::decode(&*body_bytes) {
Ok(output) => {
return Ok(ServiceResponse {
version,
headers,
status,
output,
})
}
Err(err) => ProstTwirpError::ProstDecodeError(err),
}
} else {
match TwirpError::from_json_bytes(status, &body_bytes) {
Ok(err) => ProstTwirpError::TwirpError(err),
Err(err) => ProstTwirpError::JsonDecodeError(err),
}
};
Err(ProstTwirpError::AfterBodyError {
body: body_bytes.to_vec(),
method: None,
version,
headers,
status: Some(status),
err: Box::new(err),
})
}
pub fn to_hyper_response(&self) -> Result<Response<Body>, ProstTwirpError> {
let body_bytes = self.output.encode_to_vec();
let mut builder = Response::builder().status(self.status);
builder.headers_mut().unwrap().clone_from(&self.headers);
builder
.header(CONTENT_LENGTH, body_bytes.len() as u64)
.body(body_bytes.into())
.map_err(ProstTwirpError::from)
}
}
/// A Twirp error: an HTTP status plus the `error_type`, `msg` and optional `meta`
/// fields exchanged as a JSON body.
#[derive(Debug, Clone)]
pub struct TwirpError {
    /// HTTP status the error is sent with (or was received with).
    pub status: StatusCode,
    /// Error code, serialized under the JSON key `"error_type"`.
    pub error_type: String,
    /// Human-readable message, serialized under `"msg"`.
    pub msg: String,
    /// Extra data, serialized under `"meta"` when present.
    pub meta: Option<serde_json::Value>,
}

impl TwirpError {
    /// Creates an error without metadata.
    pub fn new(status: StatusCode, error_type: &str, msg: &str) -> TwirpError {
        Self::new_meta(status, error_type, msg, None)
    }

    /// Creates an error with optional metadata.
    pub fn new_meta(
        status: StatusCode,
        error_type: &str,
        msg: &str,
        meta: Option<serde_json::Value>,
    ) -> TwirpError {
        let error_type = error_type.to_owned();
        let msg = msg.to_owned();
        TwirpError {
            status,
            error_type,
            msg,
            meta,
        }
    }

    /// Renders this error as a hyper response.
    ///
    /// The body is the JSON from [`TwirpError::to_json_bytes`], or `{}` if serialization
    /// fails. The response sets `Content-Type: application/json`, `Content-Length`
    /// and `Allow: POST`.
    pub fn to_hyper_response(&self) -> Response<Body> {
        let body_bytes = match self.to_json_bytes() {
            Ok(bytes) => bytes,
            Err(_) => b"{}".to_vec(),
        };
        let content_length = HeaderValue::from(body_bytes.len() as u64);
        Response::builder()
            .status(self.status)
            .header(CONTENT_TYPE, JSON_CONTENT_TYPE)
            .header(CONTENT_LENGTH, content_length)
            .header(ALLOW, HeaderValue::from_static("POST"))
            .body(Body::from(body_bytes))
            .expect("failed to serialize twirp error")
    }

    /// Builds an error from a parsed JSON body.
    ///
    /// Missing or non-string `"error_type"` / `"msg"` fall back to `"<no code>"` /
    /// `"<no message>"`. If `"error_type"` is present, `meta` is the body's `"meta"`
    /// value (if any). Otherwise the whole body is kept as `meta`.
    pub fn from_json(status: StatusCode, json: serde_json::Value) -> TwirpError {
        let msg = json["msg"].as_str().unwrap_or("<no message>").to_owned();
        let code = json["error_type"].as_str().map(str::to_owned);
        let meta = if code.is_some() {
            json.get("meta").cloned()
        } else {
            Some(json)
        };
        TwirpError {
            status,
            error_type: code.unwrap_or_else(|| "<no code>".to_owned()),
            msg,
            meta,
        }
    }

    /// Parses `json` as a JSON value and converts it with [`TwirpError::from_json`].
    pub fn from_json_bytes(status: StatusCode, json: &[u8]) -> serde_json::Result<TwirpError> {
        let value: serde_json::Value = serde_json::from_slice(json)?;
        Ok(Self::from_json(status, value))
    }

    /// Converts this error into a JSON object with `error_type`, `msg` and, when set,
    /// `meta` keys.
    pub fn to_json(&self) -> serde_json::Value {
        let mut props = serde_json::Map::new();
        props.insert("error_type".to_owned(), self.error_type.clone().into());
        props.insert("msg".to_owned(), self.msg.clone().into());
        if let Some(meta) = &self.meta {
            props.insert("meta".to_owned(), meta.clone());
        }
        props.into()
    }

    /// Serializes [`TwirpError::to_json`] to bytes.
    pub fn to_json_bytes(&self) -> serde_json::Result<Vec<u8>> {
        let value = self.to_json();
        serde_json::to_vec(&value)
    }
}
impl Error for TwirpError {}
impl Display for TwirpError {
fn fmt(&self, f: &mut Formatter) -> fmt::Result {
write!(f, "{:?} {}: {}", self.status, self.error_type, self.msg)
}
}
impl From<TwirpError> for ProstTwirpError {
fn from(v: TwirpError) -> ProstTwirpError {
ProstTwirpError::TwirpError(v)
}
}
/// Every error this module produces, on either the client or the server side.
#[derive(Debug)]
#[non_exhaustive]
pub enum ProstTwirpError {
    /// A Twirp-level error, parsed from a server's JSON error body or meant to be sent
    /// to a client.
    TwirpError(TwirpError),
    /// A JSON Twirp error body could not be parsed.
    JsonDecodeError(serde_json::Error),
    /// A protobuf message failed to encode.
    ProstEncodeError(EncodeError),
    /// A body was not a valid protobuf message.
    ProstDecodeError(DecodeError),
    /// An error reported by hyper.
    HyperError(hyper::Error),
    /// An HTTP request or response could not be built.
    HttpError(http::Error),
    /// The client target URL could not be parsed.
    InvalidUri(http::uri::InvalidUri),
    /// An incoming request did not use `POST`.
    InvalidMethod,
    /// An incoming request did not have `Content-Type: application/protobuf`.
    InvalidContentType,
    /// No handler matched the request; converted to a 404 `not_found` response.
    NotFound,
    /// An error detected after an HTTP body was read. It keeps the raw body and HTTP
    /// metadata for diagnostics.
    AfterBodyError {
        /// The raw body bytes.
        body: Vec<u8>,
        /// Request method; `None` when the body came from a response.
        method: Option<Method>,
        /// HTTP version of the message.
        version: Version,
        /// Headers of the message.
        headers: HeaderMap,
        /// Response status; `None` when the body came from a request.
        status: Option<StatusCode>,
        /// The underlying error.
        err: Box<ProstTwirpError>,
    },
}

impl ProstTwirpError {
    /// Strips any number of [`ProstTwirpError::AfterBodyError`] wrappers and returns the
    /// innermost error.
    pub fn root_err(self) -> ProstTwirpError {
        match self {
            ProstTwirpError::AfterBodyError { err, .. } => err.root_err(),
            other => other,
        }
    }

    /// Converts this error into the HTTP response a Twirp server should send.
    ///
    /// [`ProstTwirpError::AfterBodyError`] wrappers are removed first (see
    /// [`ProstTwirpError::root_err`]), so the response reflects the underlying cause.
    /// Twirp errors are sent as-is. This includes a Twirp error wrapped in
    /// `AfterBodyError`, such as a downstream client error propagated by a handler.
    /// Bad method, undecodable body, bad content type and not-found map to
    /// 405/400/415/404. Every other error becomes a generic 500 `internal_err`.
    ///
    /// # Errors
    ///
    /// A [`ProstTwirpError::HyperError`] is returned as `Err` instead of being turned
    /// into a response.
    pub fn into_hyper_response(self) -> Result<Response<Body>, hyper::Error> {
        // `ServiceRequest::from_hyper_request` reports undecodable bodies as
        // `AfterBodyError { err: ProstDecodeError, .. }`. Matching on the wrapper would
        // turn a client's malformed body into a 500 instead of the intended 400.
        let external_err = match self.root_err() {
            ProstTwirpError::TwirpError(err) => err,
            ProstTwirpError::HyperError(err) => return Err(err),
            ProstTwirpError::InvalidMethod => TwirpError::new(
                StatusCode::METHOD_NOT_ALLOWED,
                "bad_method",
                "Method must be POST",
            ),
            ProstTwirpError::ProstDecodeError(_) => TwirpError::new(
                StatusCode::BAD_REQUEST,
                "protobuf_decode_err",
                "Invalid protobuf body",
            ),
            ProstTwirpError::InvalidContentType => TwirpError::new(
                StatusCode::UNSUPPORTED_MEDIA_TYPE,
                "bad_content_type",
                "Content type must be application/protobuf",
            ),
            ProstTwirpError::NotFound => TwirpError::new(
                StatusCode::NOT_FOUND,
                "not_found",
                "The requested method was not found",
            ),
            // Listed explicitly so adding a variant forces a decision here.
            // `AfterBodyError` cannot appear after `root_err`, but is kept for exhaustiveness.
            ProstTwirpError::JsonDecodeError(_)
            | ProstTwirpError::ProstEncodeError(_)
            | ProstTwirpError::HttpError(_)
            | ProstTwirpError::InvalidUri(_)
            | ProstTwirpError::AfterBodyError { .. } => TwirpError::new(
                StatusCode::INTERNAL_SERVER_ERROR,
                "internal_err",
                "Internal error",
            ),
        };
        Ok(external_err.to_hyper_response())
    }
}
impl From<hyper::Error> for ProstTwirpError {
    /// Wraps the error in [`ProstTwirpError::HyperError`].
    fn from(v: hyper::Error) -> ProstTwirpError {
        ProstTwirpError::HyperError(v)
    }
}

impl From<http::Error> for ProstTwirpError {
    /// Wraps the error in [`ProstTwirpError::HttpError`].
    fn from(v: http::Error) -> ProstTwirpError {
        ProstTwirpError::HttpError(v)
    }
}

impl Display for ProstTwirpError {
    /// Uses the `Debug` representation.
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        write!(f, "{:?}", self)
    }
}

impl Error for ProstTwirpError {
    /// Returns the wrapped error for wrapper variants, and `None` for the unit variants.
    fn source(&self) -> Option<&(dyn Error + 'static)> {
        match self {
            ProstTwirpError::TwirpError(err) => Some(err),
            ProstTwirpError::JsonDecodeError(err) => Some(err),
            ProstTwirpError::ProstEncodeError(err) => Some(err),
            ProstTwirpError::ProstDecodeError(err) => Some(err),
            ProstTwirpError::HyperError(err) => Some(err),
            ProstTwirpError::HttpError(err) => Some(err),
            ProstTwirpError::InvalidUri(err) => Some(err),
            ProstTwirpError::InvalidMethod => None,
            ProstTwirpError::InvalidContentType => None,
            ProstTwirpError::NotFound => None,
            // Return the boxed error itself, not the `Box`. Otherwise the `dyn Error`'s
            // concrete type is `Box<ProstTwirpError>` and
            // `source().downcast_ref::<ProstTwirpError>()` would return `None`.
            ProstTwirpError::AfterBodyError { err, .. } => Some(&**err),
        }
    }
}
#[derive(Debug)]
pub struct HyperClient {
pub client: Client<HttpConnector>,
pub root_url: String,
}
impl HyperClient {
pub fn new(client: Client<HttpConnector>, root_url: &str) -> HyperClient {
HyperClient {
client,
root_url: root_url.trim_end_matches('/').to_string(),
}
}
pub fn go<I, O>(&self, path: &str, req: ServiceRequest<I>) -> PTRes<O>
where
I: Message + Default + 'static,
O: Message + Default + 'static,
{
let uri = match format!("{}/{}", self.root_url, path.trim_start_matches('/')).parse() {
Err(err) => return Box::pin(ready(Err(ProstTwirpError::InvalidUri(err)))),
Ok(v) => v,
};
let mut hyper_req = match req.to_hyper_request() {
Err(err) => return Box::pin(ready(Err(err))),
Ok(v) => v,
};
*hyper_req.uri_mut() = uri;
Box::pin(
self.client
.request(hyper_req)
.map_err(ProstTwirpError::HyperError)
.and_then(ServiceResponse::from_hyper_response),
)
}
}
pub trait HyperService {
fn handle(
&self,
req: Request<Body>,
) -> Pin<Box<dyn Future<Output = Result<Response<Body>, ProstTwirpError>> + Send>>;
}
pub struct HyperServer<T: HyperService + Send + Sync + 'static> {
pub service: Arc<T>,
}
impl<T: HyperService + Send + Sync + 'static> HyperServer<T> {
pub fn new(service: T) -> HyperServer<T> {
HyperServer {
service: Arc::new(service),
}
}
}
impl<T: 'static + HyperService + Send + Sync> Service<Request<Body>> for HyperServer<T> {
type Response = Response<Body>;
type Error = hyper::Error;
type Future = Pin<Box<dyn (Future<Output = Result<Self::Response, Self::Error>>) + Send>>;
fn poll_ready(&mut self, _context: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
Poll::Ready(Ok(()))
}
fn call(&mut self, req: Request<Body>) -> Self::Future {
let service = self.service.clone();
Box::pin(
service
.handle(req)
.or_else(|err| future::ready(err.into_hyper_response())),
)
}
}