use crate::grpc::framing::encode_grpc_message;
use crate::grpc::handler::{GrpcHandler, GrpcHandlerResult, GrpcRequestData, GrpcResponseData};
use crate::grpc::streaming::MessageStream;
use axum::http::{HeaderMap, HeaderValue};
use bytes::Bytes;
use futures_util::StreamExt;
use http_body::Frame;
use http_body_util::StreamBody;
use std::convert::Infallible;
use std::sync::Arc;
use tonic::{Request, Response, Status};
/// Adapter that forwards raw gRPC calls to a dynamically dispatched
/// [`GrpcHandler`].
pub struct GenericGrpcService {
    /// Shared handler that receives every call made through this service.
    handler: Arc<dyn GrpcHandler>,
}
impl GenericGrpcService {
    /// Creates a service that forwards every call to `handler`.
    pub fn new(handler: Arc<dyn GrpcHandler>) -> Self {
        Self { handler }
    }

    /// Handles a unary RPC.
    ///
    /// The request payload and metadata are passed to the handler's `call`
    /// method; request extensions are discarded. On success the handler's
    /// payload becomes the response body and its metadata is copied onto the
    /// response.
    ///
    /// # Errors
    ///
    /// Returns the handler's [`Status`] unchanged when the call fails.
    pub async fn handle_unary(
        &self,
        service_name: String,
        method_name: String,
        request: Request<Bytes>,
    ) -> Result<Response<Bytes>, Status> {
        let call_data = Self::build_request_data(service_name, method_name, request);
        let outcome: GrpcHandlerResult = self.handler.call(call_data).await;
        Self::into_unary_response(outcome)
    }

    /// Handles a server-streaming RPC.
    ///
    /// The single request message is passed to the handler's
    /// `call_server_stream`, and the resulting message stream is wrapped into
    /// a streaming body by `grpc_stream_body`.
    ///
    /// # Errors
    ///
    /// Returns the handler's [`Status`] when the stream cannot be started.
    pub async fn handle_server_stream(
        &self,
        service_name: String,
        method_name: String,
        request: Request<Bytes>,
    ) -> Result<Response<axum::body::Body>, Status> {
        let call_data = Self::build_request_data(service_name, method_name, request);
        let outgoing: MessageStream = self.handler.call_server_stream(call_data).await?;
        Ok(Response::new(grpc_stream_body(outgoing)))
    }

    /// Handles a client-streaming RPC.
    ///
    /// The request body is parsed into a message stream (see
    /// `parse_grpc_client_stream`; `max_message_size` and
    /// `compression_enabled` are forwarded to it unchanged) and handed to the
    /// handler's `call_client_stream`. The single reply is converted exactly
    /// like a unary response.
    ///
    /// # Errors
    ///
    /// Returns a [`Status`] when the body cannot be parsed or when the
    /// handler fails.
    pub async fn handle_client_stream(
        &self,
        service_name: String,
        method_name: String,
        request: Request<axum::body::Body>,
        max_message_size: usize,
        compression_enabled: bool,
    ) -> Result<Response<Bytes>, Status> {
        let streaming_request = Self::build_streaming_request(
            service_name,
            method_name,
            request,
            max_message_size,
            compression_enabled,
        )
        .await?;
        let outcome: crate::grpc::handler::GrpcHandlerResult =
            self.handler.call_client_stream(streaming_request).await;
        Self::into_unary_response(outcome)
    }

    /// Handles a bidirectional-streaming RPC.
    ///
    /// The request body is parsed into a message stream, passed to the
    /// handler's `call_bidi_stream`, and the returned message stream is
    /// wrapped into a streaming body by `grpc_stream_body`.
    ///
    /// # Errors
    ///
    /// Returns a [`Status`] when the body cannot be parsed or when the
    /// handler fails to start the response stream.
    pub async fn handle_bidi_stream(
        &self,
        service_name: String,
        method_name: String,
        request: Request<axum::body::Body>,
        max_message_size: usize,
        compression_enabled: bool,
    ) -> Result<Response<axum::body::Body>, Status> {
        let streaming_request = Self::build_streaming_request(
            service_name,
            method_name,
            request,
            max_message_size,
            compression_enabled,
        )
        .await?;
        let outgoing: MessageStream = self.handler.call_bidi_stream(streaming_request).await?;
        Ok(Response::new(grpc_stream_body(outgoing)))
    }

    /// Returns the service name reported by the wrapped handler.
    pub fn service_name(&self) -> &str {
        self.handler.service_name()
    }

    /// Splits a single-message request into the data handed to the handler.
    /// Request extensions are dropped.
    fn build_request_data(
        service_name: String,
        method_name: String,
        request: Request<Bytes>,
    ) -> GrpcRequestData {
        let (metadata, _extensions, payload) = request.into_parts();
        GrpcRequestData {
            service_name,
            method_name,
            payload,
            metadata,
        }
    }

    /// Parses a streaming request body into a `StreamingRequest`.
    ///
    /// The `grpc-encoding` metadata value, when present and readable as a
    /// string, is forwarded to the body parser.
    async fn build_streaming_request(
        service_name: String,
        method_name: String,
        request: Request<axum::body::Body>,
        max_message_size: usize,
        compression_enabled: bool,
    ) -> Result<crate::grpc::streaming::StreamingRequest, Status> {
        let (metadata, _extensions, body) = request.into_parts();
        // Copy the encoding out so `metadata` can be moved into the request below.
        let request_encoding: Option<String> = metadata
            .get("grpc-encoding")
            .and_then(|value| value.to_str().ok())
            .map(str::to_owned);
        let message_stream = crate::grpc::framing::parse_grpc_client_stream(
            body,
            max_message_size,
            request_encoding.as_deref(),
            compression_enabled,
        )
        .await?;
        Ok(crate::grpc::streaming::StreamingRequest {
            service_name,
            method_name,
            message_stream,
            metadata,
        })
    }

    /// Converts a handler result into a single-message tonic response,
    /// copying the handler's metadata onto it.
    fn into_unary_response(outcome: GrpcHandlerResult) -> Result<Response<Bytes>, Status> {
        let reply = outcome?;
        let mut response = Response::new(reply.payload);
        copy_metadata(&reply.metadata, response.metadata_mut());
        Ok(response)
    }
}
/// Turns a stream of response messages into a streaming HTTP body.
///
/// Every message is framed with `encode_grpc_message` and emitted as a data
/// frame. The body always ends with exactly one trailers frame: success
/// trailers when the stream is exhausted, or status trailers when either the
/// stream yields an error or framing a message fails. Nothing is emitted
/// after the trailers.
fn grpc_stream_body(message_stream: MessageStream) -> axum::body::Body {
    let initial = GrpcFrameStreamState {
        stream: message_stream,
        finished: false,
    };
    let frames = futures_util::stream::unfold(initial, |mut state| async move {
        // The trailers have already been sent; end the body.
        if state.finished {
            return None;
        }
        let frame = match state.stream.next().await {
            Some(Ok(message)) => match encode_grpc_message(message) {
                Ok(framed) => Frame::data(framed),
                Err(status) => {
                    state.finished = true;
                    Frame::trailers(grpc_status_trailers(&status))
                }
            },
            Some(Err(status)) => {
                state.finished = true;
                Frame::trailers(grpc_status_trailers(&status))
            }
            None => {
                state.finished = true;
                Frame::trailers(grpc_success_trailers())
            }
        };
        Some((Ok::<Frame<Bytes>, Infallible>(frame), state))
    });
    axum::body::Body::new(StreamBody::new(frames))
}
/// State threaded through the `unfold` in `grpc_stream_body`.
struct GrpcFrameStreamState {
    /// Remaining response messages to frame.
    stream: MessageStream,
    /// Set once the trailers frame has been produced; the body ends on the
    /// next poll.
    finished: bool,
}
/// Builds the trailers sent when a response stream completes normally:
/// `grpc-status: 0` and `grpc-message: OK`.
fn grpc_success_trailers() -> HeaderMap {
    const SUCCESS_TRAILERS: [(&str, &str); 2] = [("grpc-status", "0"), ("grpc-message", "OK")];

    let mut trailers = HeaderMap::new();
    for (name, value) in SUCCESS_TRAILERS {
        trailers.insert(name, HeaderValue::from_static(value));
    }
    trailers
}
/// Builds the trailers describing a failed call.
///
/// `grpc-status` carries the numeric status code (falling back to `2` if the
/// value cannot be used as a header). `grpc-message` carries the URL-encoded
/// status message, or the literal `unknown` when the message is empty or the
/// encoded text cannot be used as a header value.
fn grpc_status_trailers(status: &Status) -> HeaderMap {
    let status_value = HeaderValue::from_str(grpc_code_number(status.code()))
        .unwrap_or_else(|_| HeaderValue::from_static("2"));

    let message: std::borrow::Cow<'_, str> = match status.message() {
        "" => std::borrow::Cow::Borrowed("unknown"),
        text => urlencoding::encode(text),
    };
    let message_value =
        HeaderValue::from_str(&message).unwrap_or_else(|_| HeaderValue::from_static("unknown"));

    let mut trailers = HeaderMap::new();
    trailers.insert("grpc-status", status_value);
    trailers.insert("grpc-message", message_value);
    trailers
}
fn grpc_code_number(code: tonic::Code) -> &'static str {
match code {
tonic::Code::Ok => "0",
tonic::Code::Cancelled => "1",
tonic::Code::Unknown => "2",
tonic::Code::InvalidArgument => "3",
tonic::Code::DeadlineExceeded => "4",
tonic::Code::NotFound => "5",
tonic::Code::AlreadyExists => "6",
tonic::Code::PermissionDenied => "7",
tonic::Code::ResourceExhausted => "8",
tonic::Code::FailedPrecondition => "9",
tonic::Code::Aborted => "10",
tonic::Code::OutOfRange => "11",
tonic::Code::Unimplemented => "12",
tonic::Code::Internal => "13",
tonic::Code::Unavailable => "14",
tonic::Code::DataLoss => "15",
tonic::Code::Unauthenticated => "16",
}
}
/// Splits a gRPC request path of the form `/<service>/<method>` into its
/// service and method names.
///
/// A single leading `/` is optional and is removed before parsing. Only one
/// slash is stripped: additional leading slashes are treated as empty path
/// segments and rejected, so `///pkg.Service/Method` is not silently
/// accepted as an alias of `/pkg.Service/Method`.
///
/// # Errors
///
/// Returns [`Status::invalid_argument`] when the path does not consist of
/// exactly two `/`-separated segments, or when either segment is empty.
pub fn parse_grpc_path(path: &str) -> Result<(String, String), Status> {
    let path = path.strip_prefix('/').unwrap_or(path);

    // Exactly one separator is allowed: `split_once` finds the first one and
    // the guard rejects any further `/` in the method segment.
    let (service_name, method_name) = match path.split_once('/') {
        Some((service, method)) if !method.contains('/') => (service, method),
        _ => return Err(Status::invalid_argument(format!("Invalid gRPC path: {}", path))),
    };

    if service_name.is_empty() || method_name.is_empty() {
        return Err(Status::invalid_argument("Service or method name is empty"));
    }
    Ok((service_name.to_owned(), method_name.to_owned()))
}
/// Reports whether the request headers declare a gRPC content type.
///
/// Returns `true` when the `Content-Type` header is present, readable as a
/// string, and starts with `application/grpc`. This is a plain prefix test,
/// so values such as `application/grpc+proto` and `application/grpc-web`
/// also match.
pub fn is_grpc_request(headers: &axum::http::HeaderMap) -> bool {
    let content_type = headers
        .get(axum::http::header::CONTENT_TYPE)
        .and_then(|value| value.to_str().ok());
    matches!(content_type, Some(value) if value.starts_with("application/grpc"))
}
/// Copies every entry of `source` into `dest`.
///
/// Keys present in `source` replace any values `dest` already holds for the
/// same key; keys that only exist in `dest` are left untouched. All values of
/// a repeated key in `source` are preserved, for both ASCII and binary
/// (`-bin`) entries.
pub fn copy_metadata(source: &tonic::metadata::MetadataMap, dest: &mut tonic::metadata::MetadataMap) {
    // First pass: drop the destination's current values for every key we are
    // about to copy, so the copy replaces rather than accumulates.
    for key_value in source.iter() {
        match key_value {
            tonic::metadata::KeyAndValueRef::Ascii(key, _) => {
                dest.remove(key);
            }
            tonic::metadata::KeyAndValueRef::Binary(key, _) => {
                dest.remove_bin(key);
            }
        }
    }
    // Second pass: append instead of insert. `insert` overwrites, which would
    // keep only the last value of a key that appears several times in `source`.
    for key_value in source.iter() {
        match key_value {
            tonic::metadata::KeyAndValueRef::Ascii(key, value) => {
                dest.append(key, value.clone());
            }
            tonic::metadata::KeyAndValueRef::Binary(key, value) => {
                dest.append_bin(key, value.clone());
            }
        }
    }
}
pub fn grpc_response_to_tonic(response: GrpcResponseData) -> Response<Bytes> {
let mut tonic_response = Response::new(response.payload);
copy_metadata(&response.metadata, tonic_response.metadata_mut());
tonic_response
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::grpc::handler::GrpcHandler;
    use std::future::Future;
    use std::pin::Pin;
    use tonic::metadata::MetadataMap;

    /// Boxed future returned by the handlers defined in these tests.
    type HandlerFuture = Pin<Box<dyn Future<Output = GrpcHandlerResult> + Send>>;

    /// Handler that echoes the request payload back with empty metadata.
    struct TestHandler;

    impl GrpcHandler for TestHandler {
        fn call(&self, request: GrpcRequestData) -> HandlerFuture {
            Box::pin(async move {
                let echoed = GrpcResponseData {
                    payload: request.payload,
                    metadata: MetadataMap::new(),
                };
                Ok(echoed)
            })
        }

        fn service_name(&self) -> &str {
            "test.TestService"
        }
    }

    /// Builds a service backed by the echoing `TestHandler`.
    fn echo_service() -> GenericGrpcService {
        GenericGrpcService::new(Arc::new(TestHandler))
    }

    /// Builds a header map containing only the given `Content-Type` value.
    fn headers_with_content_type(content_type: &str) -> axum::http::HeaderMap {
        let mut headers = axum::http::HeaderMap::new();
        headers.insert(axum::http::header::CONTENT_TYPE, content_type.parse().unwrap());
        headers
    }

    #[tokio::test]
    async fn test_generic_grpc_service_handle_unary() {
        let service = echo_service();
        let request = Request::new(Bytes::from("test payload"));

        let result = service
            .handle_unary("test.TestService".to_string(), "TestMethod".to_string(), request)
            .await;

        let response = result.expect("echo handler should succeed");
        assert_eq!(response.into_inner(), Bytes::from("test payload"));
    }

    #[tokio::test]
    async fn test_generic_grpc_service_with_metadata() {
        let service = echo_service();
        let mut request = Request::new(Bytes::from("payload"));
        request
            .metadata_mut()
            .insert("custom-header", "custom-value".parse().unwrap());

        let result = service
            .handle_unary("test.TestService".to_string(), "TestMethod".to_string(), request)
            .await;

        assert!(result.is_ok());
    }

    #[test]
    fn test_parse_grpc_path_valid() {
        let (service, method) = parse_grpc_path("/mypackage.UserService/GetUser").unwrap();
        assert_eq!(service, "mypackage.UserService");
        assert_eq!(method, "GetUser");
    }

    #[test]
    fn test_parse_grpc_path_with_nested_package() {
        let (service, method) = parse_grpc_path("/com.example.api.v1.UserService/GetUser").unwrap();
        assert_eq!(service, "com.example.api.v1.UserService");
        assert_eq!(method, "GetUser");
    }

    #[test]
    fn test_parse_grpc_path_invalid_format() {
        let status = parse_grpc_path("/invalid").expect_err("a single segment must be rejected");
        assert_eq!(status.code(), tonic::Code::InvalidArgument);
    }

    #[test]
    fn test_parse_grpc_path_empty_service() {
        assert!(parse_grpc_path("//Method").is_err());
    }

    #[test]
    fn test_parse_grpc_path_empty_method() {
        assert!(parse_grpc_path("/Service/").is_err());
    }

    #[test]
    fn test_parse_grpc_path_no_leading_slash() {
        let (service, method) = parse_grpc_path("package.Service/Method").unwrap();
        assert_eq!(service, "package.Service");
        assert_eq!(method, "Method");
    }

    #[test]
    fn test_is_grpc_request_valid() {
        let headers = headers_with_content_type("application/grpc");
        assert!(is_grpc_request(&headers));
    }

    #[test]
    fn test_is_grpc_request_with_subtype() {
        let headers = headers_with_content_type("application/grpc+proto");
        assert!(is_grpc_request(&headers));
    }

    #[test]
    fn test_is_grpc_request_not_grpc() {
        let headers = headers_with_content_type("application/json");
        assert!(!is_grpc_request(&headers));
    }

    #[test]
    fn test_is_grpc_request_no_content_type() {
        let headers = axum::http::HeaderMap::new();
        assert!(!is_grpc_request(&headers));
    }

    #[test]
    fn test_grpc_response_to_tonic_basic() {
        let tonic_response = grpc_response_to_tonic(GrpcResponseData {
            payload: Bytes::from("response"),
            metadata: MetadataMap::new(),
        });
        assert_eq!(tonic_response.into_inner(), Bytes::from("response"));
    }

    #[test]
    fn test_grpc_response_to_tonic_with_metadata() {
        let mut metadata = MetadataMap::new();
        metadata.insert("custom-header", "value".parse().unwrap());

        let tonic_response = grpc_response_to_tonic(GrpcResponseData {
            payload: Bytes::from("data"),
            metadata,
        });

        assert_eq!(tonic_response.get_ref(), &Bytes::from("data"));
        assert!(tonic_response.metadata().get("custom-header").is_some());
    }

    #[test]
    fn test_generic_grpc_service_service_name() {
        let service = echo_service();
        assert_eq!(service.service_name(), "test.TestService");
    }

    #[test]
    fn test_copy_metadata() {
        let mut source = MetadataMap::new();
        source.insert("key1", "value1".parse().unwrap());
        source.insert("key2", "value2".parse().unwrap());

        let mut dest = MetadataMap::new();
        copy_metadata(&source, &mut dest);

        assert_eq!(dest.get("key1").unwrap(), "value1");
        assert_eq!(dest.get("key2").unwrap(), "value2");
    }

    #[test]
    fn test_copy_metadata_empty() {
        let mut dest = MetadataMap::new();
        copy_metadata(&MetadataMap::new(), &mut dest);
        assert!(dest.is_empty());
    }

    #[test]
    fn test_copy_metadata_binary() {
        let mut source = MetadataMap::new();
        source.insert_bin("binary-key-bin", tonic::metadata::MetadataValue::from_bytes(b"binary"));

        let mut dest = MetadataMap::new();
        copy_metadata(&source, &mut dest);

        assert!(dest.get_bin("binary-key-bin").is_some());
    }

    #[tokio::test]
    async fn test_generic_grpc_service_error_handling() {
        /// Handler whose unary call always fails with `NotFound`.
        struct ErrorHandler;

        impl GrpcHandler for ErrorHandler {
            fn call(&self, _request: GrpcRequestData) -> HandlerFuture {
                Box::pin(async { Err(Status::not_found("Resource not found")) })
            }

            fn service_name(&self) -> &str {
                "test.ErrorService"
            }
        }

        let service = GenericGrpcService::new(Arc::new(ErrorHandler));
        let request = Request::new(Bytes::new());

        let result = service
            .handle_unary("test.ErrorService".to_string(), "ErrorMethod".to_string(), request)
            .await;

        let status = result.expect_err("error handler should fail");
        assert_eq!(status.code(), tonic::Code::NotFound);
        assert_eq!(status.message(), "Resource not found");
    }
}