use crate::codec::compression::{CompressionEncoding, EnabledCompressionEncodings};
use crate::codec::EncodeBody;
use crate::metadata::GRPC_CONTENT_TYPE;
use crate::{
body::Body,
client::GrpcService,
codec::{Codec, Decoder, Streaming},
request::SanitizeHeaders,
Code, Request, Response, Status,
};
use http::{
header::{HeaderValue, CONTENT_TYPE, TE},
uri::{PathAndQuery, Uri},
};
use http_body::Body as HttpBody;
use std::{fmt, future, pin::pin};
use tokio_stream::{Stream, StreamExt};
/// A gRPC client dispatcher that encodes requests, sends them through an
/// inner service, and decodes the responses.
///
/// All per-client settings (origin, compression, message-size limits) live in
/// the private `config` field and are set through the builder-style methods.
pub struct Grpc<T> {
    /// Service used to send the prepared HTTP requests; the request methods
    /// require it to implement `GrpcService<Body>`.
    inner: T,
    /// Settings applied to every request built and response decoded by this client.
    config: GrpcConfig,
}
/// Per-client settings used when building requests and decoding responses.
struct GrpcConfig {
    /// Base URI for requests. Its scheme and authority are kept, and its path
    /// (when not `/`) is prepended to each method path.
    origin: Uri,
    /// Encodings this client accepts for responses. They are advertised in the
    /// accept-encoding request header and used to validate the response's
    /// encoding header.
    accept_compression_encodings: EnabledCompressionEncodings,
    /// Encoding used to compress outgoing request messages, if any.
    send_compression_encodings: Option<CompressionEncoding>,
    /// Optional maximum decoded response message size, forwarded to the
    /// response stream decoder.
    max_decoding_message_size: Option<usize>,
    /// Optional maximum encoded request message size, forwarded to the
    /// request body encoder.
    max_encoding_message_size: Option<usize>,
}
impl<T> Grpc<T> {
    /// Creates a client that wraps `inner` and uses [`Uri::default()`] as its origin.
    pub fn new(inner: T) -> Self {
        Self::with_origin(inner, Uri::default())
    }

    /// Creates a client that wraps `inner` and builds request URIs from `origin`.
    ///
    /// When the origin has a path other than `/`, that path is used as a prefix
    /// for every gRPC method path. No send compression is configured, accepted
    /// encodings start from `EnabledCompressionEncodings::default()`, and no
    /// message-size limits are set.
    pub fn with_origin(inner: T, origin: Uri) -> Self {
        Self {
            inner,
            config: GrpcConfig {
                origin,
                send_compression_encodings: None,
                accept_compression_encodings: EnabledCompressionEncodings::default(),
                max_decoding_message_size: None,
                max_encoding_message_size: None,
            },
        }
    }

    /// Compresses outgoing request messages with `encoding`.
    ///
    /// The encoding is handed to the request body encoder. The matching
    /// encoding request header is only added when one of the `gzip`,
    /// `deflate` or `zstd` features is enabled.
    pub fn send_compressed(mut self, encoding: CompressionEncoding) -> Self {
        self.config.send_compression_encodings = Some(encoding);
        self
    }

    /// Adds `encoding` to the set of encodings accepted for responses.
    ///
    /// Accepted encodings are advertised in the accept-encoding request header
    /// and used to validate the encoding header of each response.
    pub fn accept_compressed(mut self, encoding: CompressionEncoding) -> Self {
        self.config.accept_compression_encodings.enable(encoding);
        self
    }

    /// Sets the maximum decoded message size passed to the response decoder.
    pub fn max_decoding_message_size(mut self, limit: usize) -> Self {
        self.config.max_decoding_message_size = Some(limit);
        self
    }

    /// Sets the maximum encoded message size passed to the request encoder.
    pub fn max_encoding_message_size(mut self, limit: usize) -> Self {
        self.config.max_encoding_message_size = Some(limit);
        self
    }

    /// Waits until the inner service reports that it is ready for a request.
    ///
    /// # Errors
    ///
    /// Returns the inner service's error if its `poll_ready` fails.
    pub async fn ready(&mut self) -> Result<(), T::Error>
    where
        T: GrpcService<Body>,
    {
        future::poll_fn(|cx| self.inner.poll_ready(cx)).await
    }

    /// Sends one request message and waits for one response message.
    ///
    /// The message is wrapped in a single-element stream and sent through
    /// [`Grpc::client_streaming`], so the same errors apply.
    pub async fn unary<M1, M2, C>(
        &mut self,
        request: Request<M1>,
        path: PathAndQuery,
        codec: C,
    ) -> Result<Response<M2>, Status>
    where
        T: GrpcService<Body>,
        T::ResponseBody: HttpBody + Send + 'static,
        <T::ResponseBody as HttpBody>::Error: Into<crate::BoxError>,
        C: Codec<Encode = M1, Decode = M2>,
        M1: Send + Sync + 'static,
        M2: Send + Sync + 'static,
    {
        let request = request.map(tokio_stream::once);
        self.client_streaming(request, path, codec).await
    }

    /// Sends a stream of request messages and waits for one response message.
    ///
    /// The first response message is returned. Its metadata is the response
    /// headers, with any trailers merged in.
    ///
    /// # Errors
    ///
    /// - Any error from [`Grpc::streaming`].
    /// - An error status produced while reading the response message or the
    ///   trailers. The response headers are merged into its metadata.
    /// - `Status::internal` if the response stream ended without a message.
    pub async fn client_streaming<S, M1, M2, C>(
        &mut self,
        request: Request<S>,
        path: PathAndQuery,
        codec: C,
    ) -> Result<Response<M2>, Status>
    where
        T: GrpcService<Body>,
        T::ResponseBody: HttpBody + Send + 'static,
        <T::ResponseBody as HttpBody>::Error: Into<crate::BoxError>,
        S: Stream<Item = M1> + Send + 'static,
        C: Codec<Encode = M1, Decode = M2>,
        M1: Send + Sync + 'static,
        M2: Send + Sync + 'static,
    {
        let (mut parts, body, extensions) =
            self.streaming(request, path, codec).await?.into_parts();
        let mut body = pin!(body);

        // Errors from the body and from the trailers both get the response
        // headers attached, so callers see the same metadata either way.
        // The closure only borrows `parts`, so it is `Copy` and can be reused.
        let with_headers = |mut status: Status| {
            status.metadata_mut().merge(parts.clone());
            status
        };

        let message = body
            .try_next()
            .await
            .map_err(with_headers)?
            .ok_or_else(|| Status::internal("Missing response message."))?;

        let trailers = body.trailers().await.map_err(with_headers)?;
        if let Some(trailers) = trailers {
            parts.merge(trailers);
        }

        Ok(Response::from_parts(parts, message, extensions))
    }

    /// Sends one request message and returns the response as a message stream.
    ///
    /// The message is wrapped in a single-element stream and sent through
    /// [`Grpc::streaming`], so the same errors apply.
    pub async fn server_streaming<M1, M2, C>(
        &mut self,
        request: Request<M1>,
        path: PathAndQuery,
        codec: C,
    ) -> Result<Response<Streaming<M2>>, Status>
    where
        T: GrpcService<Body>,
        T::ResponseBody: HttpBody + Send + 'static,
        <T::ResponseBody as HttpBody>::Error: Into<crate::BoxError>,
        C: Codec<Encode = M1, Decode = M2>,
        M1: Send + Sync + 'static,
        M2: Send + Sync + 'static,
    {
        let request = request.map(tokio_stream::once);
        self.streaming(request, path, codec).await
    }

    /// Sends a stream of request messages and returns the response as a message stream.
    ///
    /// Request messages are encoded with `codec`'s encoder, using the configured
    /// send compression and maximum encoding size. The HTTP request is built for
    /// `path`, dispatched through the inner service, and the response body is
    /// decoded with `codec`'s decoder.
    ///
    /// # Errors
    ///
    /// - A status converted from the inner service's error.
    /// - Any error reported while inspecting the response headers (see
    ///   `create_response`).
    pub async fn streaming<S, M1, M2, C>(
        &mut self,
        request: Request<S>,
        path: PathAndQuery,
        mut codec: C,
    ) -> Result<Response<Streaming<M2>>, Status>
    where
        T: GrpcService<Body>,
        T::ResponseBody: HttpBody + Send + 'static,
        <T::ResponseBody as HttpBody>::Error: Into<crate::BoxError>,
        S: Stream<Item = M1> + Send + 'static,
        C: Codec<Encode = M1, Decode = M2>,
        M1: Send + Sync + 'static,
        M2: Send + Sync + 'static,
    {
        let request = request
            .map(|s| {
                EncodeBody::new_client(
                    codec.encoder(),
                    s.map(Ok),
                    self.config.send_compression_encodings,
                    self.config.max_encoding_message_size,
                )
            })
            .map(Body::new);
        let request = self.config.prepare_request(request, path);
        let response = self
            .inner
            .call(request)
            .await
            .map_err(Status::from_error_generic)?;
        let decoder = codec.decoder();
        self.create_response(decoder, response)
    }

    /// Converts an HTTP response into a gRPC response whose body is a message stream.
    ///
    /// If the response headers already carry a status (as parsed by
    /// `Status::from_header_map`), the response is treated as trailers-only:
    /// - a non-OK status is returned as the error;
    /// - an OK status yields an empty stream.
    ///
    /// Otherwise the body is decoded as a message stream that expects trailers
    /// at its end.
    ///
    /// # Errors
    ///
    /// - The trailers-only status, when it is not OK.
    /// - An error from validating the response's encoding header against the
    ///   accepted encodings.
    fn create_response<M2>(
        &self,
        decoder: impl Decoder<Item = M2, Error = Status> + Send + 'static,
        response: http::Response<T::ResponseBody>,
    ) -> Result<Response<Streaming<M2>>, Status>
    where
        T: GrpcService<Body>,
        T::ResponseBody: HttpBody + Send + 'static,
        <T::ResponseBody as HttpBody>::Error: Into<crate::BoxError>,
    {
        let status_code = response.status();

        // A status in the headers means a trailers-only response, so no
        // further trailers will arrive. Check it before the encoding header so
        // that a server error is not masked by an encoding this client rejects.
        // A trailers-only response has no body to decompress anyway.
        let expect_additional_trailers = match Status::from_header_map(response.headers()) {
            Some(status) if status.code() != Code::Ok => return Err(status),
            Some(_) => false,
            None => true,
        };

        let encoding = CompressionEncoding::from_encoding_header(
            response.headers(),
            self.config.accept_compression_encodings,
        )?;

        let response = response.map(|body| {
            if expect_additional_trailers {
                Streaming::new_response(
                    decoder,
                    body,
                    status_code,
                    encoding,
                    self.config.max_decoding_message_size,
                )
            } else {
                Streaming::new_empty(decoder, body)
            }
        });
        Ok(Response::from_http(response))
    }
}
impl GrpcConfig {
    /// Converts a tonic request into the HTTP/2 `POST` request sent for `path`.
    ///
    /// The URI keeps the origin's scheme and authority. The path is built as
    /// follows:
    /// - If the origin has a path other than `/`, that path is prepended to
    ///   `path`. The origin's query and any trailing slashes are dropped.
    /// - Otherwise `path` is used unchanged.
    ///
    /// The request gets these headers:
    /// - `te: trailers` and the gRPC content type;
    /// - the send-compression encoding header, when a compression feature is
    ///   enabled and an encoding is configured;
    /// - the accept-encoding header, when any accepted encodings produce one.
    ///
    /// # Panics
    ///
    /// Panics if the joined path or the rebuilt URI is invalid. Both are
    /// assembled from components that were already valid.
    fn prepare_request(&self, request: Request<Body>, path: PathAndQuery) -> http::Request<Body> {
        let mut parts = self.origin.clone().into_parts();

        let path_and_query: PathAndQuery = match &parts.path_and_query {
            Some(prefix) if prefix != "/" => {
                // `path` always starts with `/`, so trim trailing slashes from
                // the origin path. Otherwise `http://host/api/` would produce
                // `/api//pkg.Service/Method`.
                format!("{}{}", prefix.path().trim_end_matches('/'), path)
                    .parse()
                    .expect("must form valid path_and_query")
            }
            _ => path,
        };
        parts.path_and_query = Some(path_and_query);

        let uri = Uri::from_parts(parts).expect("path_and_query only is valid Uri");
        let mut request = request.into_http(
            uri,
            http::Method::POST,
            http::Version::HTTP_2,
            SanitizeHeaders::Yes,
        );

        // gRPC-over-HTTP/2 request headers.
        request
            .headers_mut()
            .insert(TE, HeaderValue::from_static("trailers"));
        request
            .headers_mut()
            .insert(CONTENT_TYPE, GRPC_CONTENT_TYPE);

        // Compression of the request body is only possible with a codec feature enabled.
        #[cfg(any(feature = "gzip", feature = "deflate", feature = "zstd"))]
        if let Some(encoding) = self.send_compression_encodings {
            request.headers_mut().insert(
                crate::codec::compression::ENCODING_HEADER,
                encoding.into_header_value(),
            );
        }

        if let Some(header_value) = self
            .accept_compression_encodings
            .into_accept_encoding_header_value()
        {
            request.headers_mut().insert(
                crate::codec::compression::ACCEPT_ENCODING_HEADER,
                header_value,
            );
        }

        request
    }
}
impl<T: Clone> Clone for Grpc<T> {
    /// Returns a client that wraps a clone of the inner service and copies the
    /// origin, compression settings and message-size limits.
    fn clone(&self) -> Self {
        let inner = self.inner.clone();
        let cfg = &self.config;
        let config = GrpcConfig {
            origin: cfg.origin.clone(),
            accept_compression_encodings: cfg.accept_compression_encodings,
            send_compression_encodings: cfg.send_compression_encodings,
            max_decoding_message_size: cfg.max_decoding_message_size,
            max_encoding_message_size: cfg.max_encoding_message_size,
        };
        Self { inner, config }
    }
}
impl<T: fmt::Debug> fmt::Debug for Grpc<T> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("Grpc")
.field("inner", &self.inner)
.field("origin", &self.config.origin)
.field(
"compression_encoding",
&self.config.send_compression_encodings,
)
.field(
"accept_compression_encodings",
&self.config.accept_compression_encodings,
)
.field(
"max_decoding_message_size",
&self.config.max_decoding_message_size,
)
.field(
"max_encoding_message_size",
&self.config.max_encoding_message_size,
)
.finish()
}
}