use std::collections::HashMap;
use std::convert::Infallible;
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
use bytes::Bytes;
use http_body_util::BodyExt;
use typeway_grpc::codec::{GrpcCodec, JsonCodec};
use typeway_grpc::framing;
use typeway_grpc::health::HealthService;
use typeway_grpc::reflection::ReflectionService;
use typeway_grpc::service::{GrpcMethodDescriptor, GrpcServiceDescriptor};
use typeway_grpc::status::{http_to_grpc_code, GrpcCode, GrpcStatus};
use typeway_grpc::trailer_body::GrpcBody;
use crate::body::{body_from_bytes, BoxBody};
use crate::handler::BoxedHandler;
use crate::router::{Router, RouterService};
/// Shared callback that mutates a request's `http::Extensions` in place.
///
/// In this file it is obtained from `Router::state_injector()` and invoked on
/// the extensions of every synthetic request before it reaches a handler.
type StateInjector = Arc<dyn Fn(&mut http::Extensions) + Send + Sync>;
/// Lookup table from gRPC method paths to the handlers that serve them.
pub struct GrpcRouter {
/// Route entries keyed by the gRPC path (`GrpcMethodDescriptor::full_path`
/// for entries built by `from_router`, or the caller-supplied path for
/// direct handlers).
handlers: HashMap<String, GrpcRouteEntry>,
/// Optional callback applied to the extensions of each synthetic request;
/// taken from the REST `Router` in `from_router`.
state_injector: Option<StateInjector>,
}
/// A single route registered in a [`GrpcRouter`].
enum GrpcRouteEntry {
/// Route served by an existing REST handler through a synthetic request.
Standard {
/// The REST handler invoked with the synthetic request parts and body.
handler: BoxedHandler,
/// Descriptor of the gRPC method this entry was built from.
method_descriptor: GrpcMethodDescriptor,
/// Per-route middleware, run in insertion order before the handler.
middleware: Vec<GrpcMiddleware>,
},
/// Route served by a direct handler, dispatched through
/// `crate::grpc_direct::dispatch_direct` without a synthetic REST request.
/// Only available with the `protobuf` feature.
#[cfg(feature = "protobuf")]
Direct(crate::grpc_direct::DirectHandler),
}
pub type GrpcMiddleware = Arc<
dyn Fn(
http::request::Parts,
Bytes,
) -> std::pin::Pin<
Box<
dyn std::future::Future<
Output = Result<(http::request::Parts, Bytes), http::Response<BoxBody>>,
> + Send,
>,
> + Send
+ Sync,
>;
impl GrpcRouter {
    /// Builds a gRPC router by pairing every method in `descriptor` with the
    /// REST handler registered in `router` for the same HTTP method and path
    /// pattern.
    ///
    /// Methods without a matching REST handler are skipped and reported with a
    /// `tracing` warning. The router's state injector is carried over so that
    /// synthetic requests receive the same state as REST requests.
    pub fn from_router(router: &Router, descriptor: &GrpcServiceDescriptor) -> Self {
        let mut handlers = HashMap::new();

        for method in &descriptor.methods {
            match router.find_handler_by_pattern(&method.http_method, &method.rest_path) {
                Some(handler) => {
                    let entry = GrpcRouteEntry::Standard {
                        handler,
                        method_descriptor: method.clone(),
                        middleware: Vec::new(),
                    };
                    handlers.insert(method.full_path.clone(), entry);
                }
                None => {
                    tracing::warn!(
                        "gRPC method {} has no matching REST handler for {} {}",
                        method.full_path,
                        method.http_method,
                        method.rest_path,
                    );
                }
            }
        }

        Self {
            handlers,
            state_injector: router.state_injector(),
        }
    }

    /// Registers `handler` as a direct handler for `grpc_path`, replacing any
    /// entry already stored under that path.
    #[cfg(feature = "protobuf")]
    pub fn add_direct_handler(
        &mut self,
        grpc_path: String,
        handler: crate::grpc_direct::DirectHandler,
    ) {
        let entry = GrpcRouteEntry::Direct(handler);
        self.handlers.insert(grpc_path, entry);
    }

    /// Appends `middleware` to the chain of the standard route at `grpc_path`.
    ///
    /// Does nothing when the path is unknown or refers to a direct handler.
    pub fn add_middleware(&mut self, grpc_path: &str, middleware: GrpcMiddleware) {
        let route = self.handlers.get_mut(grpc_path);
        if let Some(GrpcRouteEntry::Standard {
            middleware: chain, ..
        }) = route
        {
            chain.push(middleware);
        }
    }

    /// Returns the route entry registered for `grpc_path`, if any.
    fn lookup(&self, grpc_path: &str) -> Option<&GrpcRouteEntry> {
        self.handlers.get(grpc_path)
    }
}
/// Prepares the original request parts for a handler without rewriting them.
///
/// The parts are returned unchanged apart from running the optional
/// `state_injector` over their extensions.
fn build_synthetic_request_raw(
    original_parts: http::request::Parts,
    state_injector: Option<&StateInjector>,
) -> http::request::Parts {
    let mut synthetic = original_parts;
    match state_injector {
        Some(inject) => inject(&mut synthetic.extensions),
        None => {}
    }
    synthetic
}
/// Builds the REST-style request (parts + JSON body) that stands in for a gRPC
/// call.
///
/// - The HTTP method comes from `method_desc.http_method`.
/// - The URI is `method_desc.rest_path` with each `{}` placeholder replaced, in
///   order, by the values returned from `extract_capture_values`.
/// - Headers and extensions are cloned from `original_parts`; the content type
///   is overwritten with `application/json`.
/// - `state_injector`, when present, is run over the new extensions.
///
/// Returns the synthetic parts together with `message_json` serialized as the
/// body (an empty body if serialization fails).
fn build_synthetic_request(
    original_parts: &http::request::Parts,
    method_desc: &GrpcMethodDescriptor,
    message_json: &serde_json::Value,
    state_injector: Option<&StateInjector>,
) -> (http::request::Parts, Bytes) {
    let mut builder = http::Request::builder().method(method_desc.http_method.clone());
    let rest_path = &method_desc.rest_path;
    let uri = if rest_path.contains("{}") {
        let capture_values = extract_capture_values(message_json, rest_path);
        // Fill the template in a single pass over its literal segments.
        // Repeated `replacen("{}", ..)` would rescan text that was already
        // substituted, so a capture value containing "{}" would swallow the
        // next value and misplace every later one.
        let mut values = capture_values.iter();
        let mut segments = rest_path.split("{}");
        let mut path = String::with_capacity(rest_path.len());
        if let Some(prefix) = segments.next() {
            path.push_str(prefix);
        }
        for segment in segments {
            match values.next() {
                Some(val) => path.push_str(val),
                // No value left for this placeholder: keep it verbatim.
                None => path.push_str("{}"),
            }
            path.push_str(segment);
        }
        path
    } else {
        rest_path.clone()
    };
    // If the substituted path is not a valid URI, fall back to the raw
    // template, and to the default URI if that does not parse either.
    builder = builder.uri(
        uri.parse::<http::Uri>()
            .unwrap_or_else(|_| method_desc.rest_path.parse().unwrap_or_default()),
    );
    let (synthetic_parts, _) = builder.body(()).unwrap().into_parts();
    let mut parts = synthetic_parts;
    parts.headers = original_parts.headers.clone();
    // The handler receives JSON regardless of the content type the client sent.
    parts.headers.insert(
        http::header::CONTENT_TYPE,
        http::HeaderValue::from_static("application/json"),
    );
    parts.extensions = original_parts.extensions.clone();
    if let Some(injector) = state_injector {
        injector(&mut parts.extensions);
    }
    let body_bytes = serde_json::to_vec(message_json).unwrap_or_default();
    (parts, Bytes::from(body_bytes))
}
fn extract_capture_values(message: &serde_json::Value, rest_path: &str) -> Vec<String> {
let placeholder_count = rest_path.matches("{}").count();
if placeholder_count == 0 {
return Vec::new();
}
let obj = match message.as_object() {
Some(o) => o,
None => return vec!["".to_string(); placeholder_count],
};
let mut values = Vec::with_capacity(placeholder_count);
for i in 1..=placeholder_count {
let key = format!("param{i}");
if let Some(val) = obj.get(&key) {
values.push(json_value_to_string(val));
}
}
if values.len() == placeholder_count {
return values;
}
values.clear();
let id_fields = ["id", "user_id", "item_id", "name", "slug"];
for field in &id_fields {
if values.len() >= placeholder_count {
break;
}
if let Some(val) = obj.get(*field) {
values.push(json_value_to_string(val));
}
}
if values.len() < placeholder_count {
values.clear();
for (_, val) in obj.iter().take(placeholder_count) {
values.push(json_value_to_string(val));
}
}
while values.len() < placeholder_count {
values.push(String::new());
}
values
}
/// Renders a JSON value for use inside a URL path.
///
/// Strings are returned without surrounding quotes; numbers and booleans use
/// their own `to_string`; every other value uses the `to_string` of the JSON
/// value itself.
fn json_value_to_string(val: &serde_json::Value) -> String {
    if let serde_json::Value::String(text) = val {
        return text.clone();
    }
    match val {
        serde_json::Value::Number(number) => number.to_string(),
        serde_json::Value::Bool(flag) => flag.to_string(),
        other => other.to_string(),
    }
}
/// Service that multiplexes REST and gRPC traffic.
///
/// Requests recognised as gRPC are handled through `grpc_router` (plus the
/// health and reflection services); everything else is forwarded to `rest`.
#[derive(Clone)]
pub struct GrpcMultiplexer {
/// REST service that receives every non-gRPC request.
pub(crate) rest: RouterService,
/// Route table used to resolve gRPC method paths.
pub(crate) grpc_router: Arc<GrpcRouter>,
/// Reflection service; consulted only when `reflection_enabled` is true.
pub(crate) reflection: Arc<ReflectionService>,
/// Health service answering requests on the health-check path.
pub(crate) health: HealthService,
/// Whether requests on the reflection path are served.
pub(crate) reflection_enabled: bool,
/// Optional document served as JSON on `GET /grpc-spec`.
pub(crate) grpc_spec_json: Option<Arc<String>>,
/// Optional document served as HTML on `GET /grpc-docs`.
pub(crate) grpc_docs_html: Option<Arc<String>>,
/// Optional transcoder for binary protobuf requests and responses.
/// `GrpcMultiplexer::new` leaves it as `None`.
#[cfg(feature = "grpc-proto-binary")]
pub(crate) transcoder: Option<Arc<typeway_grpc::transcode::ProtoTranscoder>>,
}
impl GrpcMultiplexer {
pub fn new(
rest: RouterService,
grpc_router: Arc<GrpcRouter>,
reflection: Arc<ReflectionService>,
health: HealthService,
reflection_enabled: bool,
grpc_spec_json: Option<Arc<String>>,
grpc_docs_html: Option<Arc<String>>,
) -> Self {
Self {
rest,
grpc_router,
reflection,
health,
reflection_enabled,
grpc_spec_json,
grpc_docs_html,
#[cfg(feature = "grpc-proto-binary")]
transcoder: None,
}
}
}
/// Wraps a JSON string in a single gRPC frame and returns it as a successful
/// (`grpc-status: 0`) `application/grpc+json` response.
fn grpc_json_response(json_body: &str) -> http::Response<BoxBody> {
    let payload = Bytes::from(framing::encode_grpc_frame(json_body.as_bytes()));
    let mut response = http::Response::new(body_from_bytes(payload));
    *response.status_mut() = http::StatusCode::OK;

    let headers = response.headers_mut();
    headers.insert("grpc-status", http::HeaderValue::from_static("0"));
    headers.insert(
        "content-type",
        http::HeaderValue::from_static("application/grpc+json"),
    );

    response
}
/// Encodes one response message and reports the content type that matches the
/// encoding.
///
/// With the `grpc-proto-binary` feature, when `_use_proto_binary` is set and a
/// transcoder is available, the JSON payload is transcoded to binary protobuf
/// (`application/grpc+proto`). If transcoding fails a warning is logged and the
/// JSON bytes are returned instead. In every other case the JSON bytes are
/// returned as-is with `application/grpc+json`.
fn encode_response_bytes(
    json_bytes: &[u8],
    _grpc_path: &str,
    _use_proto_binary: bool,
    #[cfg(feature = "grpc-proto-binary")] transcoder: &Option<
        Arc<typeway_grpc::transcode::ProtoTranscoder>,
    >,
) -> (Vec<u8>, &'static str) {
    #[cfg(feature = "grpc-proto-binary")]
    {
        if let (true, Some(tc)) = (_use_proto_binary, transcoder.as_ref()) {
            // Unparseable JSON is transcoded as the default (null) value.
            let parsed =
                serde_json::from_slice::<serde_json::Value>(json_bytes).unwrap_or_default();
            match tc.encode_response(_grpc_path, &parsed) {
                Ok(proto_bytes) => return (proto_bytes, "application/grpc+proto"),
                Err(e) => {
                    tracing::warn!(
                        "proto-binary response encode failed for {}: {}",
                        _grpc_path,
                        e
                    );
                }
            }
        }
    }

    (json_bytes.to_vec(), "application/grpc+json")
}
/// Converts a REST handler response into a gRPC response.
///
/// - Responses carrying a `GrpcStreamMarker` extension are passed through with
///   their body untouched; only the status (200) and content type
///   (`application/grpc+json`) are set.
/// - Otherwise the HTTP status is mapped to a gRPC code, the body is collected
///   (a collection error is treated as an empty body) and framed:
///   - for a server-streaming method with an OK code whose body is a JSON
///     array, every element becomes its own frame;
///   - in all other cases the whole body becomes a single frame.
///
/// The resulting response always has HTTP status 200, carries the gRPC code in
/// the `grpc-status` header, and is built from a `GrpcBody` created with the
/// same status (empty message).
async fn wrap_response_as_grpc(
    rest_response: http::Response<BoxBody>,
    method_desc: &GrpcMethodDescriptor,
    grpc_path: &str,
    use_proto_binary: bool,
    #[cfg(feature = "grpc-proto-binary")] transcoder: &Option<
        Arc<typeway_grpc::transcode::ProtoTranscoder>,
    >,
) -> http::Response<BoxBody> {
    let (res_parts, res_body) = rest_response.into_parts();

    // Handlers that already produced a gRPC stream are passed through.
    let already_streaming = res_parts
        .extensions
        .get::<crate::grpc_stream::GrpcStreamMarker>()
        .is_some();
    if already_streaming {
        let mut passthrough = http::Response::from_parts(res_parts, res_body);
        *passthrough.status_mut() = http::StatusCode::OK;
        passthrough.headers_mut().insert(
            "content-type",
            http::HeaderValue::from_static("application/grpc+json"),
        );
        return passthrough;
    }

    let grpc_code = http_to_grpc_code(res_parts.status);
    let res_bytes = res_body
        .collect()
        .await
        .map(|collected| collected.to_bytes())
        .unwrap_or_else(|_| Bytes::new());
    let grpc_status = GrpcStatus {
        code: grpc_code,
        message: String::new(),
    };

    // Encodes one message payload and wraps it in a gRPC frame.
    let encode_frame = |payload: &[u8]| {
        let (encoded, content_type) = encode_response_bytes(
            payload,
            grpc_path,
            use_proto_binary,
            #[cfg(feature = "grpc-proto-binary")]
            transcoder,
        );
        (framing::encode_grpc_frame(&encoded), content_type)
    };

    // Only successful server-streaming responses with a JSON array body are
    // split into one frame per element.
    let stream_items = if method_desc.server_streaming && grpc_code == GrpcCode::Ok {
        match serde_json::from_slice::<serde_json::Value>(&res_bytes) {
            Ok(serde_json::Value::Array(items)) => Some(items),
            _ => None,
        }
    } else {
        None
    };

    let (framed, response_content_type) = match stream_items {
        Some(items) => {
            let mut frames = Vec::new();
            // The content type of the last encoded item wins.
            let mut last_content_type = "application/grpc+json";
            for item in &items {
                let item_bytes = serde_json::to_vec(item).unwrap_or_default();
                let (frame, content_type) = encode_frame(&item_bytes);
                last_content_type = content_type;
                frames.extend_from_slice(&frame);
            }
            (frames, last_content_type)
        }
        None => encode_frame(&res_bytes),
    };

    let grpc_body = GrpcBody::with_status(Bytes::from(framed), grpc_status);
    let infallible_body = http_body_util::BodyExt::map_err(grpc_body, |e| match e {});
    let boxed_body: BoxBody = http_body_util::BodyExt::boxed_unsync(infallible_body);

    let mut response = http::Response::new(boxed_body);
    *response.status_mut() = http::StatusCode::OK;
    let headers = response.headers_mut();
    headers.insert(
        "content-type",
        response_content_type.parse().expect("valid content-type"),
    );
    headers.insert(
        "grpc-status",
        grpc_code
            .as_i32()
            .to_string()
            .parse()
            .expect("valid grpc-status"),
    );
    response
}
/// Builds a gRPC error response for `status`.
///
/// The response has HTTP status 200 and content type `application/grpc`, and
/// carries the gRPC code in the `grpc-status` header. A non-empty message is
/// sent in the `grpc-message` header, percent-encoded as the gRPC-over-HTTP/2
/// protocol requires. Messages containing non-ASCII text, control characters
/// or `%` therefore reach the client intact instead of being dropped or
/// mis-decoded. The body is a `GrpcBody::error` built from the same status.
fn grpc_error_response(status: GrpcStatus) -> http::Response<BoxBody> {
    let code = status.code;
    // `status` is consumed by the body below, so keep a copy of the message
    // for the header.
    let message = status.message.clone();
    let grpc_body = GrpcBody::error(status);
    let boxed_body: BoxBody =
        http_body_util::BodyExt::boxed_unsync(http_body_util::BodyExt::map_err(grpc_body, |e| {
            match e {}
        }));
    let mut res = http::Response::new(boxed_body);
    *res.status_mut() = http::StatusCode::OK;
    res.headers_mut().insert(
        "content-type",
        http::HeaderValue::from_static("application/grpc"),
    );
    res.headers_mut().insert(
        "grpc-status",
        code.as_i32()
            .to_string()
            .parse()
            .expect("valid grpc-status"),
    );
    if !message.is_empty() {
        let encoded = percent_encode_grpc_message(&message);
        // After encoding only bytes 0x20..=0x7E remain, so this is expected to
        // succeed; the check is kept so a failure can never panic.
        if let Ok(val) = http::HeaderValue::from_str(&encoded) {
            res.headers_mut().insert("grpc-message", val);
        }
    }
    res
}

/// Percent-encodes a `grpc-message` value: bytes in `0x20..=0x7E` other than
/// `%` are kept, every other byte of the UTF-8 text becomes `%XX` (uppercase
/// hexadecimal).
fn percent_encode_grpc_message(message: &str) -> String {
    const HEX: &[u8; 16] = b"0123456789ABCDEF";
    let mut encoded = String::with_capacity(message.len());
    for &byte in message.as_bytes() {
        match byte {
            b' '..=b'$' | b'&'..=b'~' => encoded.push(char::from(byte)),
            _ => {
                encoded.push('%');
                encoded.push(char::from(HEX[usize::from(byte >> 4)]));
                encoded.push(char::from(HEX[usize::from(byte & 0x0F)]));
            }
        }
    }
    encoded
}
// Tower service entry point for the multiplexer.
//
// Request handling order:
//   1. `GET /grpc-spec` and `GET /grpc-docs`, when the matching document is set.
//   2. Requests that `typeway_grpc::is_grpc_request` recognises: health check,
//      reflection (if enabled), then routes registered in the `GrpcRouter`.
//   3. Everything else is forwarded to the REST service.
impl tower_service::Service<http::Request<hyper::body::Incoming>> for GrpcMultiplexer {
type Response = http::Response<BoxBody>;
type Error = Infallible;
type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
// This service reports itself as always ready.
fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
Poll::Ready(Ok(()))
}
// Dispatches a single request; see the overview above the `impl` for the order.
fn call(&mut self, req: http::Request<hyper::body::Incoming>) -> Self::Future {
let path = req.uri().path();
// Serve the spec document as JSON if one is configured; otherwise fall through.
if req.method() == http::Method::GET && path == "/grpc-spec" {
if let Some(spec_json) = self.grpc_spec_json.clone() {
return Box::pin(async move {
let mut res = http::Response::new(body_from_bytes(Bytes::from(
spec_json.as_bytes().to_vec(),
)));
res.headers_mut().insert(
http::header::CONTENT_TYPE,
http::HeaderValue::from_static("application/json; charset=utf-8"),
);
Ok(res)
});
}
}
// Serve the docs page as HTML if one is configured; otherwise fall through.
if req.method() == http::Method::GET && path == "/grpc-docs" {
if let Some(docs_html) = self.grpc_docs_html.clone() {
return Box::pin(async move {
let mut res = http::Response::new(body_from_bytes(Bytes::from(
docs_html.as_bytes().to_vec(),
)));
res.headers_mut().insert(
http::header::CONTENT_TYPE,
http::HeaderValue::from_static("text/html; charset=utf-8"),
);
Ok(res)
});
}
}
if typeway_grpc::is_grpc_request(&req) {
// Clone the shared state the returned future needs, since it cannot borrow `self`.
let grpc_router = self.grpc_router.clone();
let reflection = self.reflection.clone();
let health = self.health.clone();
let reflection_enabled = self.reflection_enabled;
#[cfg(feature = "grpc-proto-binary")]
let transcoder = self.transcoder.clone();
Box::pin(async move {
let grpc_path = req.uri().path().to_string();
// Health checks are answered without reading the request body.
if HealthService::is_health_path(&grpc_path) {
let response_json = health.handle_request();
return Ok(grpc_json_response(&response_json));
}
// Reflection: the body is unframed when possible (otherwise used as-is)
// and passed to the reflection service as lossily decoded UTF-8 text.
if reflection_enabled && ReflectionService::is_reflection_path(&grpc_path) {
let (parts, body) = req.into_parts();
// A body that cannot be collected is treated as empty.
let body_bytes = match body.collect().await {
Ok(collected) => collected.to_bytes(),
Err(_) => Bytes::new(),
};
let unframed = framing::decode_grpc_frame(&body_bytes).unwrap_or(&body_bytes);
let body_str = String::from_utf8_lossy(unframed);
let _ = parts;
let response_json = reflection.handle_request(&body_str);
return Ok(grpc_json_response(&response_json));
}
// Unknown method paths are answered with an UNIMPLEMENTED status.
let entry = grpc_router.lookup(&grpc_path);
let entry = match entry {
Some(e) => e,
None => {
let status = GrpcStatus::unimplemented(&format!(
"method '{}' not found in service",
grpc_path
));
return Ok(grpc_error_response(status));
}
};
// Direct handlers receive the collected request body and bypass the
// synthetic-request path, middleware and timeout handling below.
#[cfg(feature = "protobuf")]
if let GrpcRouteEntry::Direct(direct_handler) = entry {
let direct_handler = direct_handler.clone();
let body_bytes = match http_body_util::BodyExt::collect(req.into_body()).await {
Ok(collected) => collected.to_bytes(),
Err(_) => Bytes::new(),
};
return Ok(
crate::grpc_direct::dispatch_direct(&direct_handler, body_bytes).await,
);
}
// Only standard entries can reach this point: direct entries returned above.
let (method_desc, handler, rpc_middleware) = match entry {
GrpcRouteEntry::Standard {
handler,
method_descriptor,
middleware,
} => (
method_descriptor.clone(),
handler.clone(),
middleware.clone(),
),
#[cfg(feature = "protobuf")]
GrpcRouteEntry::Direct(_) => unreachable!(),
};
// Optional deadline from the `grpc-timeout` header; a missing or
// unparseable header means no deadline.
let grpc_timeout = req
.headers()
.get("grpc-timeout")
.and_then(|v| v.to_str().ok())
.and_then(typeway_grpc::parse_grpc_timeout);
#[cfg(feature = "grpc-proto-binary")]
let incoming_content_type =
typeway_grpc::transcode::grpc_content_type(req.headers()).to_string();
// Binary protobuf is used only when a transcoder is configured and the
// request's content type is a proto-binary one.
#[cfg(feature = "grpc-proto-binary")]
let use_proto_binary = transcoder.is_some()
&& typeway_grpc::transcode::is_proto_binary_content_type(
&incoming_content_type,
);
let (parts, body) = req.into_parts();
// A body that cannot be collected is treated as empty.
let body_bytes = match body.collect().await {
Ok(collected) => collected.to_bytes(),
Err(_) => Bytes::new(),
};
// Strip the gRPC frame; if the body is not a valid frame, use it unchanged.
let unframed = framing::decode_grpc_frame(&body_bytes)
.map(|b| b.to_vec())
.unwrap_or_else(|_| body_bytes.to_vec());
// Fast path: proto-binary requests whose REST path has no `{}` captures
// are handed to the handler without being decoded to JSON here.
#[cfg(feature = "grpc-proto-binary")]
let binary_fast_path = use_proto_binary && !method_desc.rest_path.contains("{}");
#[cfg(not(feature = "grpc-proto-binary"))]
let binary_fast_path = false;
let (synthetic_parts, body_bytes) = if binary_fast_path {
// Keep the original request parts; only the content type is replaced.
let mut synthetic =
build_synthetic_request_raw(parts, grpc_router.state_injector.as_ref());
synthetic.headers.insert(
http::header::CONTENT_TYPE,
http::HeaderValue::from_static("application/grpc+proto"),
);
(synthetic, Bytes::from(unframed))
} else {
// Slow path: decode the message to JSON (from protobuf or JSON) and
// build a REST-style request from it. Decode failures are reported
// to the client as INVALID_ARGUMENT.
#[cfg(feature = "grpc-proto-binary")]
let message = if use_proto_binary {
// `use_proto_binary` is only true when a transcoder is set, so this
// unwrap cannot fail.
let tc = transcoder.as_ref().unwrap();
match tc.decode_request(&grpc_path, &unframed) {
Ok(json) => json,
Err(e) => {
let status = GrpcStatus {
code: GrpcCode::InvalidArgument,
message: format!("failed to decode binary protobuf: {e}"),
};
return Ok(grpc_error_response(status));
}
}
} else {
match JsonCodec.decode(&unframed) {
Ok(msg) => msg,
Err(e) => {
let status = GrpcStatus {
code: GrpcCode::InvalidArgument,
message: format!("failed to decode request: {e}"),
};
return Ok(grpc_error_response(status));
}
}
};
#[cfg(not(feature = "grpc-proto-binary"))]
let message = match JsonCodec.decode(&unframed) {
Ok(msg) => msg,
Err(e) => {
let status = GrpcStatus {
code: GrpcCode::InvalidArgument,
message: format!("failed to decode request: {e}"),
};
return Ok(grpc_error_response(status));
}
};
build_synthetic_request(
&parts,
&method_desc,
&message,
grpc_router.state_injector.as_ref(),
)
};
// Run the per-route middleware in order; the first one that returns
// `Err(response)` short-circuits and its response is sent to the client.
let (synthetic_parts, body_bytes) = {
let mut parts = synthetic_parts;
let mut body = body_bytes;
for mw in &rpc_middleware {
match mw(parts, body).await {
Ok((p, b)) => {
parts = p;
body = b;
}
Err(resp) => return Ok(resp),
}
}
(parts, body)
};
// The deadline, if any, bounds only the handler call (middleware has
// already run). On expiry the client receives DEADLINE_EXCEEDED.
let rest_response = if let Some(timeout_duration) = grpc_timeout {
match tokio::time::timeout(
timeout_duration,
handler(synthetic_parts, body_bytes),
)
.await
{
Ok(res) => res,
Err(_) => {
let status = GrpcStatus {
code: GrpcCode::DeadlineExceeded,
message: "deadline exceeded".to_string(),
};
return Ok(grpc_error_response(status));
}
}
} else {
handler(synthetic_parts, body_bytes).await
};
#[cfg(feature = "grpc-proto-binary")]
let use_binary = use_proto_binary;
#[cfg(not(feature = "grpc-proto-binary"))]
let use_binary = false;
// Convert the REST response into a framed gRPC response.
Ok(wrap_response_as_grpc(
rest_response,
&method_desc,
&grpc_path,
use_binary,
#[cfg(feature = "grpc-proto-binary")]
&transcoder,
)
.await)
})
} else {
// Not a gRPC request: forward it to a clone of the REST service.
let mut rest = self.rest.clone();
Box::pin(async move { tower_service::Service::call(&mut rest, req).await })
}
}
}