use crate::grpc::framing::encode_grpc_message;
use crate::grpc::handler::{GrpcHandler, GrpcHandlerResult, GrpcRequestData};
use crate::grpc::streaming::MessageStream;
use axum::http::{HeaderMap, HeaderValue};
use bytes::Bytes;
use futures_util::StreamExt;
use http_body::Frame;
use http_body_util::StreamBody;
use std::convert::Infallible;
use std::sync::Arc;
use tonic::{Request, Response, Status};
/// Adapter that routes unary, server-streaming, client-streaming and bidirectional
/// RPCs to one dynamically-dispatched [`GrpcHandler`].
pub struct GenericGrpcService {
    /// Shared handler invoked for every call this service receives.
    handler: Arc<dyn GrpcHandler>,
}
impl GenericGrpcService {
    /// Creates a service that forwards every RPC to `handler`.
    pub fn new(handler: Arc<dyn GrpcHandler>) -> Self {
        Self { handler }
    }

    /// Handles a unary RPC: one request message in, one response message out.
    ///
    /// The request metadata and payload are passed to `GrpcHandler::call`. The handler's
    /// payload becomes the response body, and its metadata is copied onto the response
    /// with [`copy_metadata`]. Request extensions are discarded.
    ///
    /// # Errors
    ///
    /// Returns the handler's [`Status`] unchanged when the handler fails.
    pub async fn handle_unary(
        &self,
        service_name: String,
        method_name: String,
        request: Request<Bytes>,
    ) -> Result<Response<Bytes>, Status> {
        let grpc_request = Self::grpc_request_data(service_name, method_name, request);
        let result: GrpcHandlerResult = self.handler.call(grpc_request).await;
        Self::into_unary_response(result)
    }

    /// Handles a server-streaming RPC: one request message in, a stream of messages out.
    ///
    /// The handler's stream is passed through `limit_message_stream` with
    /// `max_stream_response_bytes`. Its exact limiting semantics are defined in
    /// `crate::grpc::streaming`, presumably a cap on total response bytes. The stream is
    /// then framed into an HTTP body by [`grpc_stream_body`].
    ///
    /// # Errors
    ///
    /// Returns the handler's [`Status`] if it fails before producing a stream. Errors
    /// yielded later by the stream are reported in the body's trailers instead.
    pub async fn handle_server_stream(
        &self,
        service_name: String,
        method_name: String,
        request: Request<Bytes>,
        max_stream_response_bytes: Option<usize>,
    ) -> Result<Response<axum::body::Body>, Status> {
        let grpc_request = Self::grpc_request_data(service_name, method_name, request);
        let message_stream: MessageStream = self.handler.call_server_stream(grpc_request).await?;
        Ok(Self::into_stream_response(message_stream, max_stream_response_bytes))
    }

    /// Handles a client-streaming RPC: a stream of messages in, one response message out.
    ///
    /// The request body is parsed into a message stream (see
    /// `parse_grpc_client_stream`) before the handler is invoked. The handler's response
    /// is converted exactly as in [`Self::handle_unary`].
    ///
    /// # Errors
    ///
    /// Returns the [`Status`] from `parse_grpc_client_stream` when the request body cannot
    /// be set up as a stream, or the handler's [`Status`] when it fails.
    pub async fn handle_client_stream(
        &self,
        service_name: String,
        method_name: String,
        request: Request<axum::body::Body>,
        max_message_size: usize,
        compression_enabled: bool,
    ) -> Result<Response<Bytes>, Status> {
        let streaming_request = Self::parse_streaming_request(
            service_name,
            method_name,
            request,
            max_message_size,
            compression_enabled,
        )
        .await?;
        let result: GrpcHandlerResult = self.handler.call_client_stream(streaming_request).await;
        Self::into_unary_response(result)
    }

    /// Handles a bidirectional-streaming RPC: a stream of messages in, a stream out.
    ///
    /// Request parsing matches [`Self::handle_client_stream`]. The response stream is
    /// limited and framed exactly as in [`Self::handle_server_stream`].
    ///
    /// # Errors
    ///
    /// Returns the [`Status`] from request parsing, or the handler's [`Status`] if it fails
    /// before producing a stream.
    pub async fn handle_bidi_stream(
        &self,
        service_name: String,
        method_name: String,
        request: Request<axum::body::Body>,
        max_message_size: usize,
        compression_enabled: bool,
        max_stream_response_bytes: Option<usize>,
    ) -> Result<Response<axum::body::Body>, Status> {
        let streaming_request = Self::parse_streaming_request(
            service_name,
            method_name,
            request,
            max_message_size,
            compression_enabled,
        )
        .await?;
        let response_stream: MessageStream = self.handler.call_bidi_stream(streaming_request).await?;
        Ok(Self::into_stream_response(response_stream, max_stream_response_bytes))
    }

    /// Splits a single-message request into the [`GrpcRequestData`] passed to the handler.
    /// Request extensions are dropped.
    fn grpc_request_data(
        service_name: String,
        method_name: String,
        request: Request<Bytes>,
    ) -> GrpcRequestData {
        let (metadata, _extensions, payload) = request.into_parts();
        GrpcRequestData {
            service_name,
            method_name,
            payload,
            metadata,
        }
    }

    /// Parses a streamed request body into the `StreamingRequest` passed to the handler.
    ///
    /// The `grpc-encoding` request metadata is forwarded to the parser. It is treated as
    /// absent when it is missing or its value fails `to_str()`.
    async fn parse_streaming_request(
        service_name: String,
        method_name: String,
        request: Request<axum::body::Body>,
        max_message_size: usize,
        compression_enabled: bool,
    ) -> Result<crate::grpc::streaming::StreamingRequest, Status> {
        let (metadata, _extensions, body) = request.into_parts();
        // Take an owned copy so no borrow of `metadata` is held across the `.await` below;
        // `metadata` is moved into the request afterwards.
        let request_encoding = metadata
            .get("grpc-encoding")
            .and_then(|value| value.to_str().ok())
            .map(str::to_owned);
        let message_stream = crate::grpc::framing::parse_grpc_client_stream(
            body,
            max_message_size,
            request_encoding.as_deref(),
            compression_enabled,
        )
        .await?;
        Ok(crate::grpc::streaming::StreamingRequest {
            service_name,
            method_name,
            message_stream,
            metadata,
        })
    }

    /// Converts a handler result into a tonic response.
    ///
    /// On success the handler's payload becomes the body and its metadata is copied onto
    /// the response. An error status is propagated unchanged.
    fn into_unary_response(result: GrpcHandlerResult) -> Result<Response<Bytes>, Status> {
        let grpc_response = result?;
        let mut response = Response::new(grpc_response.payload);
        copy_metadata(&grpc_response.metadata, response.metadata_mut());
        Ok(response)
    }

    /// Applies the response-size limit to a handler stream and frames it as an HTTP body.
    fn into_stream_response(
        message_stream: MessageStream,
        max_stream_response_bytes: Option<usize>,
    ) -> Response<axum::body::Body> {
        let limited = crate::grpc::streaming::limit_message_stream(
            message_stream,
            max_stream_response_bytes,
        );
        Response::new(grpc_stream_body(limited))
    }
}
/// Turns a stream of response messages into an HTTP body of gRPC frames.
///
/// Each `Ok` message is framed with `encode_grpc_message` and emitted as a data frame. The
/// body ends with exactly one trailers frame, after which nothing more is emitted:
/// - success trailers when the message stream is exhausted;
/// - status trailers for the first error, whether it comes from the stream itself or from
///   framing a message.
fn grpc_stream_body(message_stream: MessageStream) -> axum::body::Body {
    let seed = GrpcFrameStreamState {
        stream: message_stream,
        finished: false,
    };
    let frames = futures_util::stream::unfold(seed, |mut state| async move {
        if state.finished {
            return None;
        }
        // Only a successfully framed message keeps the body open; every other outcome
        // produces the final trailers frame.
        let (frame, terminal) = match state.stream.next().await {
            Some(Ok(message)) => match encode_grpc_message(message) {
                Ok(framed) => (Frame::data(framed), false),
                Err(status) => (Frame::trailers(grpc_status_trailers(&status)), true),
            },
            Some(Err(status)) => (Frame::trailers(grpc_status_trailers(&status)), true),
            None => (Frame::trailers(grpc_success_trailers()), true),
        };
        state.finished = terminal;
        Some((Ok::<Frame<Bytes>, Infallible>(frame), state))
    });
    axum::body::Body::new(StreamBody::new(frames))
}
/// Unfold state for [`grpc_stream_body`].
struct GrpcFrameStreamState {
    /// Set once the trailers frame has been emitted; the body then yields nothing more.
    finished: bool,
    /// Source of response messages still to be framed.
    stream: MessageStream,
}
/// Trailers for a stream that completed normally: `grpc-status: 0`, `grpc-message: OK`.
fn grpc_success_trailers() -> HeaderMap {
    let mut map = HeaderMap::with_capacity(2);
    for (name, value) in [("grpc-status", "0"), ("grpc-message", "OK")] {
        map.insert(name, HeaderValue::from_static(value));
    }
    map
}
/// Builds the `grpc-status` / `grpc-message` trailers that report `status`.
///
/// The message is percent-encoded with `urlencoding::encode`. An empty message, or one
/// that cannot form a header value, is sent as the literal `unknown`.
fn grpc_status_trailers(status: &Status) -> HeaderMap {
    let message = status.message();
    let grpc_message = if message.is_empty() {
        HeaderValue::from_static("unknown")
    } else {
        HeaderValue::from_str(&urlencoding::encode(message))
            .unwrap_or_else(|_| HeaderValue::from_static("unknown"))
    };
    // `grpc_code_number` only returns decimal digit strings, so `from_static` cannot panic.
    let grpc_status = HeaderValue::from_static(grpc_code_number(status.code()));

    let mut map = HeaderMap::with_capacity(2);
    map.insert("grpc-status", grpc_status);
    map.insert("grpc-message", grpc_message);
    map
}
fn grpc_code_number(code: tonic::Code) -> &'static str {
match code {
tonic::Code::Ok => "0",
tonic::Code::Cancelled => "1",
tonic::Code::Unknown => "2",
tonic::Code::InvalidArgument => "3",
tonic::Code::DeadlineExceeded => "4",
tonic::Code::NotFound => "5",
tonic::Code::AlreadyExists => "6",
tonic::Code::PermissionDenied => "7",
tonic::Code::ResourceExhausted => "8",
tonic::Code::FailedPrecondition => "9",
tonic::Code::Aborted => "10",
tonic::Code::OutOfRange => "11",
tonic::Code::Unimplemented => "12",
tonic::Code::Internal => "13",
tonic::Code::Unavailable => "14",
tonic::Code::DataLoss => "15",
tonic::Code::Unauthenticated => "16",
}
}
/// Splits a gRPC request path such as `/pkg.Service/Method` into `(service, method)`.
///
/// Any number of leading `/` characters are stripped. The remainder must contain exactly
/// one `/`, with non-empty text on both sides.
///
/// # Errors
///
/// Returns `Status::invalid_argument` in two cases:
/// - the stripped path does not have exactly two `/`-separated segments;
/// - either segment is empty.
pub fn parse_grpc_path(path: &str) -> Result<(String, String), Status> {
    let path = path.trim_start_matches('/');
    let mut segments = path.split('/');
    match (segments.next(), segments.next(), segments.next()) {
        (Some(service), Some(method), None) => {
            if service.is_empty() || method.is_empty() {
                Err(Status::invalid_argument("Service or method name is empty"))
            } else {
                Ok((service.to_owned(), method.to_owned()))
            }
        }
        _ => Err(Status::invalid_argument(format!("Invalid gRPC path: {}", path))),
    }
}
/// Copies every ASCII and binary metadata entry from `source` into `dest`.
///
/// For each key present in `source`, any values `dest` already holds under that key are
/// replaced by all of `source`'s values for it, in `source`'s iteration order. Keys that
/// appear only in `dest` are left untouched.
///
/// A single `insert` per entry would not work: `insert` replaces every value for a key,
/// so a multi-valued key would keep only its last value. Instead each copied key is
/// cleared once up front, and then every value is appended.
pub fn copy_metadata(source: &tonic::metadata::MetadataMap, dest: &mut tonic::metadata::MetadataMap) {
    // Pass 1: drop whatever `dest` holds for keys `source` is about to provide. Removing a
    // key more than once, once per value, is harmless.
    for entry in source.iter() {
        match entry {
            tonic::metadata::KeyAndValueRef::Ascii(key, _) => {
                dest.remove(key);
            }
            tonic::metadata::KeyAndValueRef::Binary(key, _) => {
                dest.remove_bin(key);
            }
        }
    }
    // Pass 2: append every value so multi-valued keys survive intact.
    for entry in source.iter() {
        match entry {
            tonic::metadata::KeyAndValueRef::Ascii(key, value) => {
                dest.append(key, value.clone());
            }
            tonic::metadata::KeyAndValueRef::Binary(key, value) => {
                dest.append_bin(key, value.clone());
            }
        }
    }
}