#[cfg(feature = "compression")]
use crate::codec::compression::{CompressionEncoding, EnabledCompressionEncodings};
use crate::{
body::BoxBody,
client::GrpcService,
codec::{encode_client, Codec, Streaming},
request::SanitizeHeaders,
Code, Request, Response, Status,
};
use futures_core::Stream;
use futures_util::{future, stream, TryStreamExt};
use http::{
header::{HeaderValue, CONTENT_TYPE, TE},
uri::{Parts, PathAndQuery, Uri},
};
use http_body::Body;
use std::fmt;
/// A gRPC client that sends requests through an inner service `T`.
///
/// Calls are issued with [`Grpc::unary`], [`Grpc::client_streaming`],
/// [`Grpc::server_streaming`] and [`Grpc::streaming`].
pub struct Grpc<T> {
    /// Service through which every request is dispatched.
    inner: T,
    /// Encodings this client accepts for response messages; advertised to the
    /// server on each request.
    #[cfg(feature = "compression")]
    accept_compression_encodings: EnabledCompressionEncodings,
    /// Encoding applied to outgoing request messages; `None` means no
    /// compression encoding has been selected.
    #[cfg(feature = "compression")]
    send_compression_encodings: Option<CompressionEncoding>,
}
impl<T> Grpc<T> {
    /// Creates a new gRPC client that issues requests through `inner`.
    ///
    /// With the `compression` feature enabled, no send encoding is selected and
    /// the accepted response encodings start from
    /// `EnabledCompressionEncodings::default()`.
    pub fn new(inner: T) -> Self {
        Self {
            inner,
            #[cfg(feature = "compression")]
            send_compression_encodings: None,
            #[cfg(feature = "compression")]
            accept_compression_encodings: EnabledCompressionEncodings::default(),
        }
    }

    /// Compresses outgoing request messages with gzip.
    ///
    /// The selected encoding is passed to the message encoder and announced
    /// in the encoding header of every request built by [`Grpc::streaming`].
    #[cfg(feature = "compression")]
    #[cfg_attr(docsrs, doc(cfg(feature = "compression")))]
    pub fn send_gzip(mut self) -> Self {
        self.send_compression_encodings = Some(CompressionEncoding::Gzip);
        self
    }

    /// Stand-in for `send_gzip` when the `compression` feature is disabled.
    ///
    /// # Panics
    ///
    /// Always panics, because compression support was not compiled in.
    #[doc(hidden)]
    #[cfg(not(feature = "compression"))]
    pub fn send_gzip(self) -> Self {
        panic!(
            "`send_gzip` called on a client but the `compression` feature is not enabled on tonic"
        );
    }

    /// Accepts gzip-compressed responses.
    ///
    /// Gzip is added to the encodings advertised in the accept-encoding header
    /// of each request. The same set is consulted when interpreting the
    /// encoding header of responses.
    #[cfg(feature = "compression")]
    #[cfg_attr(docsrs, doc(cfg(feature = "compression")))]
    pub fn accept_gzip(mut self) -> Self {
        self.accept_compression_encodings.enable_gzip();
        self
    }

    /// Stand-in for `accept_gzip` when the `compression` feature is disabled.
    ///
    /// # Panics
    ///
    /// Always panics, because compression support was not compiled in.
    #[doc(hidden)]
    #[cfg(not(feature = "compression"))]
    pub fn accept_gzip(self) -> Self {
        panic!("`accept_gzip` called on a client but the `compression` feature is not enabled on tonic");
    }

    /// Waits until the inner service reports that it can accept a request.
    ///
    /// # Errors
    ///
    /// Returns the inner service's error if `poll_ready` fails.
    pub async fn ready(&mut self) -> Result<(), T::Error>
    where
        T: GrpcService<BoxBody>,
    {
        future::poll_fn(|cx| self.inner.poll_ready(cx)).await
    }

    /// Sends a single request message and waits for a single response message.
    ///
    /// The message is wrapped in a one-element stream and handed to
    /// [`Grpc::client_streaming`], so errors are reported exactly as there.
    pub async fn unary<M1, M2, C>(
        &mut self,
        request: Request<M1>,
        path: PathAndQuery,
        codec: C,
    ) -> Result<Response<M2>, Status>
    where
        T: GrpcService<BoxBody>,
        T::ResponseBody: Body + Send + 'static,
        <T::ResponseBody as Body>::Error: Into<crate::Error>,
        C: Codec<Encode = M1, Decode = M2>,
        M1: Send + Sync + 'static,
        M2: Send + Sync + 'static,
    {
        let request = request.map(|m| stream::once(future::ready(m)));
        self.client_streaming(request, path, codec).await
    }

    /// Sends a stream of request messages and waits for a single response message.
    ///
    /// Any trailers received after the message are merged into the metadata of
    /// the returned response.
    ///
    /// # Errors
    ///
    /// - Any error from [`Grpc::streaming`].
    /// - A `Status` with `Code::Internal` if the response stream ends without a
    ///   message.
    /// - The `Status` produced while reading the message or the trailers. In
    ///   both of these cases the response header metadata is merged into the
    ///   status.
    pub async fn client_streaming<S, M1, M2, C>(
        &mut self,
        request: Request<S>,
        path: PathAndQuery,
        codec: C,
    ) -> Result<Response<M2>, Status>
    where
        T: GrpcService<BoxBody>,
        T::ResponseBody: Body + Send + 'static,
        <T::ResponseBody as Body>::Error: Into<crate::Error>,
        S: Stream<Item = M1> + Send + 'static,
        C: Codec<Encode = M1, Decode = M2>,
        M1: Send + Sync + 'static,
        M2: Send + Sync + 'static,
    {
        let (mut parts, body, extensions) =
            self.streaming(request, path, codec).await?.into_parts();

        futures_util::pin_mut!(body);

        let message = body
            .try_next()
            .await
            .map_err(|mut status| {
                status.metadata_mut().merge(parts.clone());
                status
            })?
            .ok_or_else(|| Status::new(Code::Internal, "Missing response message."))?;

        // Errors returned while reading the trailers also carry the response
        // header metadata, matching the message-read error path above.
        let trailers = body.trailers().await.map_err(|mut status| {
            status.metadata_mut().merge(parts.clone());
            status
        })?;
        if let Some(trailers) = trailers {
            parts.merge(trailers);
        }

        Ok(Response::from_parts(parts, message, extensions))
    }

    /// Sends a single request message and returns the streaming response.
    ///
    /// The message is wrapped in a one-element stream and handed to
    /// [`Grpc::streaming`], so errors are reported exactly as there.
    pub async fn server_streaming<M1, M2, C>(
        &mut self,
        request: Request<M1>,
        path: PathAndQuery,
        codec: C,
    ) -> Result<Response<Streaming<M2>>, Status>
    where
        T: GrpcService<BoxBody>,
        T::ResponseBody: Body + Send + 'static,
        <T::ResponseBody as Body>::Error: Into<crate::Error>,
        C: Codec<Encode = M1, Decode = M2>,
        M1: Send + Sync + 'static,
        M2: Send + Sync + 'static,
    {
        let request = request.map(|m| stream::once(future::ready(m)));
        self.streaming(request, path, codec).await
    }

    /// Sends a stream of request messages and returns the streaming response.
    ///
    /// The request URI consists only of `path`. Every request carries
    /// `te: trailers` and `content-type: application/grpc`. With the
    /// `compression` feature enabled, it also carries the encoding and
    /// accept-encoding headers.
    ///
    /// # Errors
    ///
    /// Returns a `Status` if:
    /// - the inner service call fails;
    /// - the response's encoding header cannot be handled (with the
    ///   `compression` feature); or
    /// - the response headers already carry a non-OK status (a trailers-only
    ///   response).
    pub async fn streaming<S, M1, M2, C>(
        &mut self,
        request: Request<S>,
        path: PathAndQuery,
        mut codec: C,
    ) -> Result<Response<Streaming<M2>>, Status>
    where
        T: GrpcService<BoxBody>,
        T::ResponseBody: Body + Send + 'static,
        <T::ResponseBody as Body>::Error: Into<crate::Error>,
        S: Stream<Item = M1> + Send + 'static,
        C: Codec<Encode = M1, Decode = M2>,
        M1: Send + Sync + 'static,
        M2: Send + Sync + 'static,
    {
        // The request URI carries only the method path.
        let mut parts = Parts::default();
        parts.path_and_query = Some(path);

        let uri = Uri::from_parts(parts).expect("path_and_query only is valid Uri");

        let request = request
            .map(|s| {
                encode_client(
                    codec.encoder(),
                    s,
                    #[cfg(feature = "compression")]
                    self.send_compression_encodings,
                )
            })
            .map(BoxBody::new);

        let mut request = request.into_http(uri, SanitizeHeaders::Yes);

        request
            .headers_mut()
            .insert(TE, HeaderValue::from_static("trailers"));

        request
            .headers_mut()
            .insert(CONTENT_TYPE, HeaderValue::from_static("application/grpc"));

        #[cfg(feature = "compression")]
        {
            if let Some(encoding) = self.send_compression_encodings {
                request.headers_mut().insert(
                    crate::codec::compression::ENCODING_HEADER,
                    encoding.into_header_value(),
                );
            }

            if let Some(header_value) = self
                .accept_compression_encodings
                .into_accept_encoding_header_value()
            {
                request.headers_mut().insert(
                    crate::codec::compression::ACCEPT_ENCODING_HEADER,
                    header_value,
                );
            }
        }

        let response = self
            .inner
            .call(request)
            .await
            .map_err(|err| Status::from_error(err.into()))?;

        #[cfg(feature = "compression")]
        let encoding = CompressionEncoding::from_encoding_header(
            response.headers(),
            self.accept_compression_encodings,
        )?;

        let status_code = response.status();
        let trailers_only_status = Status::from_header_map(response.headers());

        // If the status already arrived in the headers (trailers-only
        // response), a non-OK code is returned immediately and no further
        // trailers are expected.
        let expect_additional_trailers = if let Some(status) = trailers_only_status {
            if status.code() != Code::Ok {
                return Err(status);
            }

            false
        } else {
            true
        };

        let response = response.map(|body| {
            if expect_additional_trailers {
                Streaming::new_response(
                    codec.decoder(),
                    body,
                    status_code,
                    #[cfg(feature = "compression")]
                    encoding,
                )
            } else {
                Streaming::new_empty(codec.decoder(), body)
            }
        });

        Ok(Response::from_http(response))
    }
}
impl<T: Clone> Clone for Grpc<T> {
fn clone(&self) -> Self {
Self {
inner: self.inner.clone(),
#[cfg(feature = "compression")]
send_compression_encodings: self.send_compression_encodings,
#[cfg(feature = "compression")]
accept_compression_encodings: self.accept_compression_encodings,
}
}
}
impl<T: fmt::Debug> fmt::Debug for Grpc<T> {
    /// Formats the client as a `Grpc` struct showing the inner service and,
    /// with the `compression` feature, both compression settings.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut builder = f.debug_struct("Grpc");

        builder.field("inner", &self.inner);

        // The label mirrors the field name, so it is clear that this is the
        // send-side setting (and it pairs with the accept-side label below).
        #[cfg(feature = "compression")]
        builder.field(
            "send_compression_encodings",
            &self.send_compression_encodings,
        );

        #[cfg(feature = "compression")]
        builder.field(
            "accept_compression_encodings",
            &self.accept_compression_encodings,
        );

        builder.finish()
    }
}