use core::fmt;
use std::future::Future;
use std::pin::Pin;
use std::task::{ready, Context, Poll};
use http::{header, HeaderMap, HeaderValue, Method, Request, Response, StatusCode, Version};
use pin_project::pin_project;
use tonic::metadata::GRPC_CONTENT_TYPE;
use tonic::{body::Body, server::NamedService};
use tower_service::Service;
use tracing::{debug, trace};
use crate::call::content_types::is_grpc_web;
use crate::call::{Encoding, GrpcWebCall};
/// Tower service that rewrites gRPC-web requests into gRPC requests for the
/// wrapped service and rewrites the resulting responses back into gRPC-web.
///
/// Non-gRPC-web requests are forwarded unchanged when they use HTTP/2 and are
/// answered with `400 Bad Request` otherwise (see the `Service::call` impl).
#[derive(Debug, Clone)]
pub struct GrpcWebService<S> {
/// The wrapped service that receives the (possibly rewritten) requests.
inner: S,
}
/// Classification of an incoming request, produced by `RequestKind::new` from
/// the request's headers, method and HTTP version.
#[derive(Debug, PartialEq)]
enum RequestKind<'a> {
/// The headers identify a gRPC-web request (as decided by `is_grpc_web`).
GrpcWeb {
/// The request's HTTP method; `call` only forwards POST.
method: &'a Method,
/// Request body encoding, as computed by `Encoding::from_content_type`.
encoding: Encoding,
/// Response encoding, as computed by `Encoding::from_accept`.
accept: Encoding,
},
/// Any other request, tagged with its HTTP version; `call` forwards only HTTP/2.
Other(http::Version),
}
impl<S> GrpcWebService<S> {
pub(crate) fn new(inner: S) -> Self {
GrpcWebService { inner }
}
}
impl<S, ReqBody, ResBody> Service<Request<ReqBody>> for GrpcWebService<S>
where
S: Service<Request<Body>, Response = Response<ResBody>>,
ReqBody: http_body::Body<Data = bytes::Bytes> + Send + 'static,
ReqBody::Error: Into<crate::BoxError> + fmt::Display,
ResBody: http_body::Body<Data = bytes::Bytes> + Send + 'static,
ResBody::Error: Into<crate::BoxError> + fmt::Display,
{
type Response = Response<Body>;
type Error = S::Error;
type Future = ResponseFuture<S::Future>;
fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.inner.poll_ready(cx)
}
fn call(&mut self, req: Request<ReqBody>) -> Self::Future {
match RequestKind::new(req.headers(), req.method(), req.version()) {
RequestKind::GrpcWeb {
method: &Method::POST,
encoding,
accept,
} => {
trace!(kind = "simple", path = ?req.uri().path(), ?encoding, ?accept);
ResponseFuture {
case: Case::GrpcWeb {
future: self.inner.call(coerce_request(req, encoding)),
accept,
},
}
}
RequestKind::GrpcWeb { .. } => {
debug!(kind = "simple", error="method not allowed", method = ?req.method());
ResponseFuture {
case: Case::immediate(StatusCode::METHOD_NOT_ALLOWED),
}
}
RequestKind::Other(Version::HTTP_2) => {
debug!(kind = "other h2", content_type = ?req.headers().get(header::CONTENT_TYPE));
ResponseFuture {
case: Case::Other {
future: self.inner.call(req.map(Body::new)),
},
}
}
RequestKind::Other(_) => {
debug!(kind = "other h1", content_type = ?req.headers().get(header::CONTENT_TYPE));
ResponseFuture {
case: Case::immediate(StatusCode::BAD_REQUEST),
}
}
}
}
}
/// Future returned by `GrpcWebService::call`; resolves to the (possibly
/// rewritten) response for one request.
#[pin_project]
#[must_use = "futures do nothing unless polled"]
pub struct ResponseFuture<F> {
/// Which response path this request took; structurally pinned so the inner
/// future can be polled in place.
#[pin]
case: Case<F>,
}
/// The response paths a `ResponseFuture` can follow.
#[pin_project(project = CaseProj)]
enum Case<F> {
/// A gRPC-web call forwarded as gRPC; once `future` resolves its response is
/// rewritten with `coerce_response` using the `accept` encoding.
GrpcWeb {
#[pin]
future: F,
accept: Encoding,
},
/// A request forwarded untouched; only the response body is re-boxed.
Other {
#[pin]
future: F,
},
/// A response produced without calling the inner service. `res` holds the
/// response head and is taken (left as `None`) on the first poll.
ImmediateResponse {
res: Option<http::response::Parts>,
},
}
impl<F> Case<F> {
fn immediate(status: StatusCode) -> Self {
let (res, ()) = Response::builder()
.status(status)
.body(())
.unwrap()
.into_parts();
Self::ImmediateResponse { res: Some(res) }
}
}
/// Drives whichever response path `call` selected and yields a boxed-body
/// response.
impl<F, B, E> Future for ResponseFuture<F>
where
    F: Future<Output = Result<Response<B>, E>>,
    B: http_body::Body<Data = bytes::Bytes> + Send + 'static,
    B::Error: Into<crate::BoxError> + fmt::Display,
{
    type Output = Result<Response<Body>, E>;

    /// Polls the inner future (if any) and converts its response.
    ///
    /// Errors from the inner future are propagated unchanged.
    ///
    /// # Panics
    ///
    /// Panics if an immediate response is polled again after it has already
    /// completed, which violates the `Future` contract.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        match self.project().case.project() {
            CaseProj::GrpcWeb { future, accept } => {
                let res = ready!(future.poll(cx))?;
                Poll::Ready(Ok(coerce_response(res, *accept)))
            }
            CaseProj::Other { future } => future.poll(cx).map_ok(|res| res.map(Body::new)),
            CaseProj::ImmediateResponse { res } => {
                // The head is moved out on the first (and only valid) poll.
                let head = res
                    .take()
                    .expect("ResponseFuture polled after it already completed");
                Poll::Ready(Ok(Response::from_parts(head, Body::empty())))
            }
        }
    }
}
/// Reports the wrapped service's `NamedService::NAME` as this wrapper's name.
impl<S> NamedService for GrpcWebService<S>
where
    S: NamedService,
{
    const NAME: &'static str = <S as NamedService>::NAME;
}
impl<F> fmt::Debug for ResponseFuture<F> {
    /// Writes only the type name, so `F` does not need to implement `Debug`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Same output as a field-less `debug_struct("ResponseFuture").finish()`.
        f.write_str("ResponseFuture")
    }
}
impl<'a> RequestKind<'a> {
    /// Classifies a request.
    ///
    /// If `is_grpc_web` accepts the headers, the result records the method and
    /// the request/response encodings derived from the headers. Otherwise only
    /// the HTTP `version` is kept.
    fn new(headers: &'a HeaderMap, method: &'a Method, version: Version) -> Self {
        if !is_grpc_web(headers) {
            return Self::Other(version);
        }

        let encoding = Encoding::from_content_type(headers);
        let accept = Encoding::from_accept(headers);
        Self::GrpcWeb {
            method,
            encoding,
            accept,
        }
    }
}
/// Rewrites a gRPC-web request into a gRPC request for the inner service.
///
/// Drops `content-length` (the wrapped body's length may differ), and sets
/// `content-type` to the gRPC content type, `te: trailers` and
/// `accept-encoding: identity,deflate,gzip`. The body is wrapped with
/// `GrpcWebCall::request` using `encoding`.
fn coerce_request<B>(mut req: Request<B>, encoding: Encoding) -> Request<Body>
where
    B: http_body::Body<Data = bytes::Bytes> + Send + 'static,
    B::Error: Into<crate::BoxError> + fmt::Display,
{
    let headers = req.headers_mut();
    headers.remove(header::CONTENT_LENGTH);
    headers.insert(header::CONTENT_TYPE, GRPC_CONTENT_TYPE);
    headers.insert(header::TE, HeaderValue::from_static("trailers"));
    headers.insert(
        header::ACCEPT_ENCODING,
        HeaderValue::from_static("identity,deflate,gzip"),
    );

    req.map(|body| Body::new(GrpcWebCall::request(body, encoding)))
}
fn coerce_response<B>(res: Response<B>, encoding: Encoding) -> Response<Body>
where
B: http_body::Body<Data = bytes::Bytes> + Send + 'static,
B::Error: Into<crate::BoxError> + fmt::Display,
{
let mut res = res
.map(|b| GrpcWebCall::response(b, encoding))
.map(Body::new);
res.headers_mut().insert(
header::CONTENT_TYPE,
HeaderValue::from_static(encoding.to_content_type()),
);
res
}
// Behavioral tests for `GrpcWebService`. Most go through `enable`, which stacks
// a default `CorsLayer` on top of the gRPC-web layer.
#[cfg(test)]
mod tests {
use super::*;
use crate::call::content_types::*;
use http::header::{
ACCESS_CONTROL_REQUEST_HEADERS, ACCESS_CONTROL_REQUEST_METHOD, CONTENT_TYPE, ORIGIN,
};
use tower_layer::Layer as _;
// Boxed `Send` future type returned by the stub service `Svc`.
type BoxFuture<T, E> = Pin<Box<dyn Future<Output = Result<T, E>> + Send>>;
// Stub inner service: ignores the request and always answers with
// `Response::new(Body::default())`, i.e. status 200 OK.
#[derive(Debug, Clone)]
struct Svc;
impl<B> tower_service::Service<Request<B>> for Svc {
type Response = Response<Body>;
type Error = std::convert::Infallible;
type Future = BoxFuture<Self::Response, Self::Error>;
fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
Poll::Ready(Ok(()))
}
fn call(&mut self, _: Request<B>) -> Self::Future {
Box::pin(async { Ok(Response::new(Body::default())) })
}
}
impl NamedService for Svc {
const NAME: &'static str = "test";
}
// Wraps `service` with `GrpcWebLayer` as the inner layer and a default
// `CorsLayer` as the outer one, which gives the return type
// `Cors<GrpcWebService<S>>`.
fn enable<S>(service: S) -> tower_http::cors::Cors<GrpcWebService<S>>
where
S: Service<http::Request<Body>, Response = http::Response<Body>>,
{
tower_layer::Stack::new(
crate::GrpcWebLayer::new(),
tower_http::cors::CorsLayer::new(),
)
.layer(service)
}
// Requests that carry a gRPC-web content type.
mod grpc_web {
use super::*;
use tower_layer::Layer;
// Baseline gRPC-web request: POST with `content-type: GRPC_WEB` and an
// `Origin` header.
fn request() -> Request<Body> {
Request::builder()
.method(Method::POST)
.header(CONTENT_TYPE, GRPC_WEB)
.header(ORIGIN, "http://example.com")
.body(Body::default())
.unwrap()
}
// The baseline request passes through the CORS + gRPC-web stack.
#[tokio::test]
async fn default_cors_config() {
let mut svc = enable(Svc);
let res = svc.call(request()).await.unwrap();
assert_eq!(res.status(), StatusCode::OK);
}
// The gRPC-web layer on its own (no CORS) also accepts the baseline request.
#[tokio::test]
async fn web_layer() {
let mut svc = crate::GrpcWebLayer::new().layer(Svc);
let res = svc.call(request()).await.unwrap();
assert_eq!(res.status(), StatusCode::OK);
}
// The layer applied to an axum `Router` that routes POST `/` to `Svc`
// (the builder's default URI is `/`).
#[tokio::test]
async fn web_layer_with_axum() {
let mut svc = axum::routing::Router::new()
.route("/", axum::routing::post_service(Svc))
.layer(crate::GrpcWebLayer::new());
let res = svc.call(request()).await.unwrap();
assert_eq!(res.status(), StatusCode::OK);
}
// A missing `Origin` header does not cause the request to be rejected.
#[tokio::test]
async fn without_origin() {
let mut svc = enable(Svc);
let mut req = request();
req.headers_mut().remove(ORIGIN);
let res = svc.call(req).await.unwrap();
assert_eq!(res.status(), StatusCode::OK);
}
// Each listed method (none of them POST or OPTIONS) yields 405 for a
// gRPC-web request. OPTIONS preflight is covered in `options` below.
#[tokio::test]
async fn only_post_and_options_allowed() {
let mut svc = enable(Svc);
for method in &[
Method::GET,
Method::PUT,
Method::DELETE,
Method::HEAD,
Method::PATCH,
] {
let mut req = request();
*req.method_mut() = method.clone();
let res = svc.call(req).await.unwrap();
assert_eq!(
res.status(),
StatusCode::METHOD_NOT_ALLOWED,
"{method} should not be allowed"
);
}
}
// Every gRPC-web content-type variant is accepted.
#[tokio::test]
async fn grpc_web_content_types() {
let mut svc = enable(Svc);
for ct in &[GRPC_WEB_TEXT, GRPC_WEB_PROTO, GRPC_WEB_TEXT_PROTO, GRPC_WEB] {
let mut req = request();
req.headers_mut()
.insert(CONTENT_TYPE, HeaderValue::from_static(ct));
let res = svc.call(req).await.unwrap();
assert_eq!(res.status(), StatusCode::OK);
}
}
}
// CORS preflight requests (OPTIONS with no content-type header).
mod options {
use super::*;
// Preflight asking permission for POST with an `x-grpc-web` header.
fn request() -> Request<Body> {
Request::builder()
.method(Method::OPTIONS)
.header(ORIGIN, "http://example.com")
.header(ACCESS_CONTROL_REQUEST_HEADERS, "x-grpc-web")
.header(ACCESS_CONTROL_REQUEST_METHOD, "POST")
.body(Body::default())
.unwrap()
}
// The preflight is answered with 200. Presumably the outer `CorsLayer`
// answers it; the request has no gRPC-web content type.
#[tokio::test]
async fn valid_grpc_web_preflight() {
let mut svc = enable(Svc);
let res = svc.call(request()).await.unwrap();
assert_eq!(res.status(), StatusCode::OK);
}
}
// Native gRPC requests (`content-type: application/grpc` and variants).
mod grpc {
use super::*;
// HTTP/2 request with the native gRPC content type.
fn request() -> Request<Body> {
Request::builder()
.version(Version::HTTP_2)
.header(CONTENT_TYPE, GRPC_CONTENT_TYPE)
.body(Body::default())
.unwrap()
}
// Native gRPC over HTTP/2 is forwarded to the inner service.
#[tokio::test]
async fn h2_is_ok() {
let mut svc = enable(Svc);
let req = request();
let res = svc.call(req).await.unwrap();
assert_eq!(res.status(), StatusCode::OK)
}
// Native gRPC over the builder's default version (HTTP/1.1) is rejected
// with 400.
#[tokio::test]
async fn h1_is_err() {
let mut svc = enable(Svc);
let req = Request::builder()
.header(CONTENT_TYPE, GRPC_CONTENT_TYPE)
.body(Body::default())
.unwrap();
let res = svc.call(req).await.unwrap();
assert_eq!(res.status(), StatusCode::BAD_REQUEST)
}
// `application/grpc[+subtype]` variants over HTTP/2 are all forwarded.
#[tokio::test]
async fn content_type_variants() {
let mut svc = enable(Svc);
for variant in &["grpc", "grpc+proto", "grpc+thrift", "grpc+foo"] {
let mut req = request();
req.headers_mut().insert(
CONTENT_TYPE,
HeaderValue::from_maybe_shared(format!("application/{variant}")).unwrap(),
);
let res = svc.call(req).await.unwrap();
assert_eq!(res.status(), StatusCode::OK)
}
}
}
// Requests with an unrelated (non-gRPC) content type.
mod other {
use super::*;
// `application/text` request with the builder's default version (HTTP/1.1).
fn request() -> Request<Body> {
Request::builder()
.header(CONTENT_TYPE, "application/text")
.body(Body::default())
.unwrap()
}
// Over HTTP/1.1 such requests are rejected with 400.
#[tokio::test]
async fn h1_is_err() {
let mut svc = enable(Svc);
let res = svc.call(request()).await.unwrap();
assert_eq!(res.status(), StatusCode::BAD_REQUEST)
}
// Over HTTP/2 they are forwarded to the inner service unchanged.
#[tokio::test]
async fn h2_is_ok() {
let mut svc = enable(Svc);
let mut req = request();
*req.version_mut() = Version::HTTP_2;
let res = svc.call(req).await.unwrap();
assert_eq!(res.status(), StatusCode::OK)
}
}
}