use std::fmt;
use std::future::IntoFuture;
use std::io::{Read, Write};
use std::pin::Pin;
use std::time::Duration;
use async_channel::{Receiver, Sender};
use bytes::{Bytes, BytesMut};
use flate2::Compression;
use flate2::read::GzDecoder;
use flate2::write::GzEncoder;
use futures_lite::Stream;
use futures_lite::StreamExt;
use futures_lite::stream;
use serde::Serialize;
use serde::de::DeserializeOwned;
use crate::BodyStream;
use crate::error::{Error, ErrorKind, Result};
use crate::header::HeaderMap;
use crate::request::{ProtocolPolicy, RequestBuilder};
use crate::response::{Response, TrailerState, Version};
use crate::tls::{TlsBackend, TlsConfig};
/// Payload codec for a gRPC call.
///
/// The codec decides the request `content-type` and whether the typed
/// (serde-based) message helpers may be used.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum GrpcCodec {
    /// Protobuf payloads, sent as `application/grpc+proto`; handled as raw bytes.
    Protobuf,
    /// JSON payloads, sent as `application/grpc+json`; enables the typed helpers.
    Json,
}
/// Boxed `Send` stream of request message payloads before gRPC length-prefix
/// framing is applied (see `frame_message_stream`).
type GrpcMessageStream = Pin<Box<dyn Stream<Item = Result<Bytes>> + Send + 'static>>;
/// Message compression algorithms this module can both encode and decode.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum GrpcCompression {
    /// gzip, identified as `gzip` in `grpc-encoding` / `grpc-accept-encoding`.
    Gzip,
}

impl GrpcCompression {
    /// Token naming this algorithm in `grpc-encoding` style headers.
    fn as_str(self) -> &'static str {
        match self {
            GrpcCompression::Gzip => "gzip",
        }
    }
}
/// Per-call settings resolved from a `GrpcRequestBuilder` by `validate_configuration`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
struct GrpcRequestConfig {
    /// Value sent as the request `content-type` header.
    content_type: &'static str,
    /// Compression applied to outgoing frames; `None` sends frames uncompressed.
    compression: Option<GrpcCompression>,
}
/// Request body configured on a `GrpcRequestBuilder`.
enum GrpcRequestPayload {
    /// No message configured; `build_request` rejects this, `open_duplex` requires it.
    Empty,
    /// One message already serialized to JSON; only accepted with `GrpcCodec::Json`.
    JsonMessage(Bytes),
    /// A stream of JSON-serialized messages; only accepted with `GrpcCodec::Json`.
    JsonStream(GrpcMessageStream),
    /// One pre-encoded message, sent as-is with any codec.
    RawMessage(Bytes),
    /// A stream of pre-encoded messages, sent as-is with any codec.
    RawStream(GrpcMessageStream),
}
/// Builder for a gRPC call layered on an HTTP `RequestBuilder`.
///
/// Awaiting the builder (via `IntoFuture`) sends the call and buffers the whole
/// response; `send_streaming` and `open_duplex` give incremental access instead.
pub struct GrpcRequestBuilder {
    /// Underlying HTTP request; made HTTP/2-only in `from_request_builder`.
    request: RequestBuilder,
    /// Payload codec; defaults to `GrpcCodec::Json`.
    codec: GrpcCodec,
    /// Requested compression name; validated only when the call is sent.
    compression: Option<String>,
    /// Deadline sent as `grpc-timeout` (also applied as the HTTP request timeout).
    timeout: Option<Duration>,
    /// Request message(s) to send.
    payload: GrpcRequestPayload,
}
impl GrpcRequestBuilder {
    /// Wraps an HTTP request builder as a gRPC call.
    ///
    /// The request is restricted to HTTP/2 (`http2_only`). The codec defaults to
    /// [`GrpcCodec::Json`], and no compression, deadline or payload is set.
    pub(crate) fn from_request_builder(request: RequestBuilder) -> Self {
        Self {
            request: request.http2_only(),
            codec: GrpcCodec::Json,
            compression: None,
            timeout: None,
            payload: GrpcRequestPayload::Empty,
        }
    }

    /// Adds one ASCII metadata entry, sent as a request header.
    ///
    /// # Errors
    /// Returns the request builder's error if the header is rejected.
    pub fn metadata(mut self, name: impl AsRef<str>, value: impl AsRef<str>) -> Result<Self> {
        self.request = self.request.headers([(name.as_ref(), value.as_ref())])?;
        Ok(self)
    }

    /// Adds every `(name, value)` pair via [`Self::metadata`], stopping at the first error.
    pub fn metadata_map<I, K, V>(mut self, values: I) -> Result<Self>
    where
        I: IntoIterator<Item = (K, V)>,
        K: AsRef<str>,
        V: AsRef<str>,
    {
        for (name, value) in values {
            self = self.metadata(name, value)?;
        }
        Ok(self)
    }

    /// Adds one binary metadata entry.
    ///
    /// The key is trimmed, lowercased and given a `-bin` suffix if it lacks one.
    /// The value is base64-encoded.
    ///
    /// # Errors
    /// Fails for an empty key, or if the request builder rejects the header.
    pub fn metadata_bin(mut self, name: impl AsRef<str>, value: impl AsRef<[u8]>) -> Result<Self> {
        let name = normalize_grpc_binary_metadata_name(name.as_ref())?;
        let value = encode_grpc_binary_header(value.as_ref());
        self.request = self.request.header(name, value)?;
        Ok(self)
    }

    /// Serializes `value` as JSON and uses it as the single request message.
    ///
    /// Encoding happens immediately. Sending later fails unless the codec is
    /// [`GrpcCodec::Json`].
    pub fn message<T: Serialize>(mut self, value: &T) -> Result<Self> {
        self.payload = GrpcRequestPayload::JsonMessage(encode_json_message(value)?);
        Ok(self)
    }

    /// Uses a stream of values as the request messages.
    ///
    /// Each value is serialized to JSON lazily, as the body is pulled. Item and
    /// serialization errors are propagated through the body stream. Requires
    /// [`GrpcCodec::Json`] at send time.
    pub fn messages<S, T>(mut self, stream: S) -> Result<Self>
    where
        S: Stream<Item = Result<T>> + Send + 'static,
        T: Serialize + Send,
    {
        let stream = async_stream::stream! {
            let mut stream = Box::pin(stream);
            while let Some(item) = stream.next().await {
                let item = item?;
                yield Ok(encode_json_message(&item)?);
            }
        };
        self.payload = GrpcRequestPayload::JsonStream(Box::pin(stream));
        Ok(self)
    }

    /// Uses `value` as the single, already-encoded request message (any codec).
    pub fn message_bytes(mut self, value: impl Into<Bytes>) -> Result<Self> {
        self.payload = GrpcRequestPayload::RawMessage(value.into());
        Ok(self)
    }

    /// Uses a stream of already-encoded messages as the request body (any codec).
    pub fn messages_bytes<S, B>(mut self, stream: S) -> Result<Self>
    where
        S: Stream<Item = Result<B>> + Send + 'static,
        B: Into<Bytes> + Send + 'static,
    {
        let stream = async_stream::stream! {
            let mut stream = Box::pin(stream);
            while let Some(item) = stream.next().await {
                let item = item?;
                yield Ok(item.into());
            }
        };
        self.payload = GrpcRequestPayload::RawStream(Box::pin(stream));
        Ok(self)
    }

    /// Selects the payload codec, which also determines the request `content-type`.
    pub fn codec(mut self, codec: GrpcCodec) -> Self {
        self.codec = codec;
        self
    }

    /// Requests message compression by name.
    ///
    /// An empty or whitespace-only name clears the setting. The name is only
    /// validated when the call is sent: `gzip` and `identity` are accepted
    /// (case-insensitively), and anything else is rejected.
    pub fn compression(mut self, algo: impl AsRef<str>) -> Self {
        let algo = algo.as_ref().trim();
        self.compression = if algo.is_empty() {
            None
        } else {
            Some(algo.to_owned())
        };
        self
    }

    /// Delegates to `RequestBuilder::protocol_policy`.
    pub fn protocol_policy(mut self, policy: ProtocolPolicy) -> Self {
        self.request = self.request.protocol_policy(policy);
        self
    }

    /// Delegates to `RequestBuilder::prefer_http3`.
    pub fn prefer_http3(mut self) -> Self {
        self.request = self.request.prefer_http3();
        self
    }

    /// Delegates to `RequestBuilder::prefer_http2`.
    pub fn prefer_http2(mut self) -> Self {
        self.request = self.request.prefer_http2();
        self
    }

    /// Delegates to `RequestBuilder::http2_only`.
    pub fn http2_only(mut self) -> Self {
        self.request = self.request.http2_only();
        self
    }

    /// Delegates to `RequestBuilder::http3_only`.
    pub fn http3_only(mut self) -> Self {
        self.request = self.request.http3_only();
        self
    }

    /// Delegates to `RequestBuilder::prior_knowledge_h2c`.
    pub fn prior_knowledge_h2c(mut self, enabled: bool) -> Self {
        self.request = self.request.prior_knowledge_h2c(enabled);
        self
    }

    /// Sets the call deadline.
    ///
    /// The deadline is applied as the HTTP request timeout and also sent to the
    /// server as the `grpc-timeout` header.
    pub fn timeout(mut self, duration: Duration) -> Self {
        self.timeout = Some(duration);
        self.request = self.request.timeout(duration);
        self
    }

    /// Delegates to `RequestBuilder::connect_timeout`.
    pub fn connect_timeout(mut self, duration: Duration) -> Self {
        self.request = self.request.connect_timeout(duration);
        self
    }

    /// Delegates to `RequestBuilder::read_timeout`.
    pub fn read_timeout(mut self, duration: Duration) -> Self {
        self.request = self.request.read_timeout(duration);
        self
    }

    /// Delegates to `RequestBuilder::write_timeout`.
    pub fn write_timeout(mut self, duration: Duration) -> Self {
        self.request = self.request.write_timeout(duration);
        self
    }

    /// Delegates to `RequestBuilder::tls_config`.
    pub fn tls_config(mut self, tls_config: TlsConfig) -> Self {
        self.request = self.request.tls_config(tls_config);
        self
    }

    /// Delegates to `RequestBuilder::danger_accept_invalid_certs`.
    pub fn danger_accept_invalid_certs(mut self, enabled: bool) -> Self {
        self.request = self.request.danger_accept_invalid_certs(enabled);
        self
    }

    /// Delegates to `RequestBuilder::tls_backend`.
    pub fn tls_backend(mut self, backend: TlsBackend) -> Self {
        self.request = self.request.tls_backend(backend);
        self
    }

    /// Sends the call and returns a response whose messages are decoded as the body arrives.
    ///
    /// # Errors
    /// Fails on:
    /// - invalid configuration (unsupported compression, no payload, or a JSON
    ///   payload with a non-JSON codec);
    /// - transport errors;
    /// - an HTTP response that is not a valid gRPC response.
    pub async fn send_streaming(self) -> Result<GrpcStreamingResponse> {
        let codec = self.codec;
        let config = self.validate_configuration()?;
        let request = self.build_request(config)?;
        let response = request.await?;
        GrpcStreamingResponse::from_http_response(response, codec)
    }

    /// Opens a full-duplex call. Request messages are supplied later through the
    /// returned [`GrpcDuplexCall`].
    ///
    /// The HTTP request is driven on a dedicated thread. Its outcome (response or
    /// error) is handed back over a one-slot channel and picked up by the first
    /// response accessor on the call handle.
    ///
    /// # Errors
    /// Fails if a payload was already configured, if the configuration is
    /// invalid, or if the background thread cannot be spawned.
    pub async fn open_duplex(self) -> Result<GrpcDuplexCall> {
        if !matches!(self.payload, GrpcRequestPayload::Empty) {
            return Err(duplex_payload_conflict_error());
        }
        let codec = self.codec;
        let config = self.validate_configuration()?;
        let request = Self::apply_grpc_headers(self.request, config, self.timeout)?;
        // NOTE(review): the request channel is unbounded, so `send_message*` never
        // waits for the transport; a fast producer can buffer without limit.
        let (request_tx, request_rx) = async_channel::unbounded::<Bytes>();
        let (response_tx, response_rx) = async_channel::bounded(1);
        std::thread::Builder::new()
            .name("request-grpc-duplex".to_owned())
            .spawn(move || {
                async_io::block_on(async move {
                    // The body yields frames until `recv` fails, i.e. the channel has
                    // been closed (see `GrpcDuplexCall::finish_sending`) and drained.
                    let body_stream: BodyStream = Box::pin(async_stream::stream! {
                        while let Ok(item) = request_rx.recv().await {
                            yield Ok(item);
                        }
                    });
                    let response = request.body_stream(body_stream).await.and_then(|response| {
                        GrpcStreamingResponse::from_http_response(response, codec)
                    });
                    // The receiver is gone if the call handle was dropped; nothing to report then.
                    let _ = response_tx.send(response).await;
                });
            })
            .map_err(|err| {
                Error::with_source(
                    ErrorKind::Transport,
                    "failed to spawn grpc duplex task",
                    err,
                )
            })?;
        Ok(GrpcDuplexCall {
            codec,
            compression: config.compression,
            request_tx,
            response_rx,
            response: None,
        })
    }

    /// Resolves the codec's content type and parses the requested compression.
    ///
    /// # Errors
    /// Fails when the compression name is not supported.
    fn validate_configuration(&self) -> Result<GrpcRequestConfig> {
        let compression = parse_grpc_compression(self.compression.as_deref())?;
        let content_type = match self.codec {
            GrpcCodec::Json => "application/grpc+json",
            GrpcCodec::Protobuf => "application/grpc+proto",
        };
        Ok(GrpcRequestConfig {
            content_type,
            compression,
        })
    }

    /// Adds the headers every gRPC request carries.
    ///
    /// Always adds `te: trailers`, the codec content type and
    /// `grpc-accept-encoding`. Adds `grpc-timeout` and `grpc-encoding` when a
    /// deadline or compression is configured. Shared by `build_request` and
    /// `open_duplex` so both paths send identical headers.
    fn apply_grpc_headers(
        request: RequestBuilder,
        config: GrpcRequestConfig,
        timeout: Option<Duration>,
    ) -> Result<RequestBuilder> {
        let mut request = request
            .header("te", "trailers")?
            .header("content-type", config.content_type)?
            .header("grpc-accept-encoding", GrpcCompression::Gzip.as_str())?;
        if let Some(timeout) = timeout {
            request = request.header("grpc-timeout", &encode_grpc_timeout(timeout))?;
        }
        if let Some(compression) = config.compression {
            request = request.header("grpc-encoding", compression.as_str())?;
        }
        Ok(request)
    }

    /// Applies the gRPC headers and frames the configured payload as the request body.
    ///
    /// # Errors
    /// Fails when no payload is set, when a JSON payload is combined with a
    /// non-JSON codec, or when framing/compressing a single message fails.
    fn build_request(self, config: GrpcRequestConfig) -> Result<RequestBuilder> {
        let request = Self::apply_grpc_headers(self.request, config, self.timeout)?;
        match self.payload {
            GrpcRequestPayload::Empty => Err(Error::new(
                ErrorKind::Transport,
                "grpc request requires at least one message",
            )),
            GrpcRequestPayload::JsonMessage(message) => {
                if self.codec != GrpcCodec::Json {
                    return Err(typed_request_codec_error());
                }
                Ok(request.body(encode_grpc_frame(&message, config.compression)?))
            }
            GrpcRequestPayload::JsonStream(stream) => {
                if self.codec != GrpcCodec::Json {
                    return Err(typed_request_codec_error());
                }
                Ok(request.body_stream(frame_message_stream(stream, config.compression)))
            }
            GrpcRequestPayload::RawMessage(message) => {
                Ok(request.body(encode_grpc_frame(&message, config.compression)?))
            }
            GrpcRequestPayload::RawStream(stream) => {
                Ok(request.body_stream(frame_message_stream(stream, config.compression)))
            }
        }
    }
}
impl IntoFuture for GrpcRequestBuilder {
    type Output = Result<GrpcResponse>;
    type IntoFuture = Pin<Box<dyn std::future::Future<Output = Self::Output> + Send + 'static>>;

    /// Sends the call, then buffers every response message together with the
    /// trailers and final status.
    fn into_future(self) -> Self::IntoFuture {
        let call = async move {
            let streaming = self.send_streaming().await?;
            streaming.into_buffered_response().await
        };
        Box::pin(call)
    }
}
/// Fully buffered result of a gRPC call, produced by awaiting a `GrpcRequestBuilder`.
///
/// Message accessors return the call's status error when the status code is non-zero.
#[derive(Debug)]
pub struct GrpcResponse {
    /// Codec the call was made with; typed accessors require `GrpcCodec::Json`.
    codec: GrpcCodec,
    /// Response headers.
    headers: HeaderMap,
    /// All response messages, in arrival order.
    messages: Vec<Bytes>,
    /// Index of the first message not yet handed out.
    next_message: usize,
    /// Trailers resolved when the body ended, if any.
    trailers: Option<HeaderMap>,
    /// Final gRPC status.
    status: GrpcStatus,
}
/// Handle to a full-duplex gRPC call opened with `GrpcRequestBuilder::open_duplex`.
///
/// Outgoing messages are framed and pushed into a channel that feeds the
/// request body. The response is delivered once by the background task.
pub struct GrpcDuplexCall {
    /// Codec the call was opened with.
    codec: GrpcCodec,
    /// Compression applied to outgoing frames.
    compression: Option<GrpcCompression>,
    /// Sender feeding the request body; closing it ends the request stream.
    request_tx: Sender<Bytes>,
    /// Receives the HTTP response (or error) from the background task.
    response_rx: Receiver<Result<GrpcStreamingResponse>>,
    /// Response, stored after it is first received.
    response: Option<GrpcStreamingResponse>,
}
impl fmt::Debug for GrpcDuplexCall {
    /// Shows the call settings and whether the response has been received;
    /// the channels themselves are omitted.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let response_ready = self.response.is_some();
        let mut out = f.debug_struct("GrpcDuplexCall");
        out.field("codec", &self.codec);
        out.field("compression", &self.compression);
        out.field("response_ready", &response_ready);
        out.finish()
    }
}
impl GrpcDuplexCall {
pub async fn send_message<T: Serialize>(&self, value: &T) -> Result<()> {
if self.codec != GrpcCodec::Json {
return Err(typed_request_codec_error());
}
let message = encode_json_message(value)?;
self.send_message_bytes(message).await
}
pub async fn send_message_bytes(&self, value: impl Into<Bytes>) -> Result<()> {
let frame = encode_grpc_frame(&value.into(), self.compression)?;
self.request_tx.send(frame).await.map_err(|_| {
Error::new(
ErrorKind::Transport,
"grpc duplex request stream is already closed",
)
})
}
pub fn finish_sending(&self) {
self.request_tx.close();
}
pub async fn next_message<T: DeserializeOwned>(&mut self) -> Result<Option<T>> {
self.ensure_response_ready().await?;
self.response
.as_mut()
.expect("response ready")
.next_message()
.await
}
pub async fn next_message_bytes(&mut self) -> Result<Option<Bytes>> {
self.ensure_response_ready().await?;
self.response
.as_mut()
.expect("response ready")
.next_message_bytes()
.await
}
pub fn messages<'a, T: DeserializeOwned + 'a>(
&'a mut self,
) -> Result<Pin<Box<dyn Stream<Item = Result<T>> + 'a>>> {
if self.codec != GrpcCodec::Json {
return Err(typed_response_codec_error());
}
Ok(Box::pin(async_stream::stream! {
while let Some(item) = self.next_message_bytes().await? {
let value = serde_json::from_slice(&item).map_err(|err| {
Error::with_source(ErrorKind::Decode, "failed to decode grpc json message", err)
})?;
yield Ok(value);
}
}))
}
pub fn messages_bytes<'a>(
&'a mut self,
) -> Result<Pin<Box<dyn Stream<Item = Result<Bytes>> + 'a>>> {
Ok(Box::pin(async_stream::stream! {
while let Some(item) = self.next_message_bytes().await? {
yield Ok(item);
}
}))
}
pub async fn finish(&mut self) -> Result<GrpcStatus> {
self.finish_sending();
self.ensure_response_ready().await?;
self.response
.as_mut()
.expect("response ready")
.finish()
.await
}
pub fn trailers(&self) -> Result<Option<HeaderMap>> {
match self.response.as_ref() {
Some(response) => response.trailers(),
None => Ok(None),
}
}
pub fn metadata(&self, name: &str) -> Vec<String> {
self.response
.as_ref()
.map(|response| response.metadata(name))
.unwrap_or_default()
}
pub fn metadata_bin(&self, name: &str) -> Result<Vec<Bytes>> {
match self.response.as_ref() {
Some(response) => response.metadata_bin(name),
None => Ok(Vec::new()),
}
}
pub fn trailer_metadata(&self, name: &str) -> Vec<String> {
self.response
.as_ref()
.map(|response| response.trailer_metadata(name))
.unwrap_or_default()
}
pub fn trailer_metadata_bin(&self, name: &str) -> Result<Vec<Bytes>> {
match self.response.as_ref() {
Some(response) => response.trailer_metadata_bin(name),
None => Ok(Vec::new()),
}
}
pub fn status(&self) -> Result<Option<GrpcStatus>> {
match self.response.as_ref() {
Some(response) => response.status(),
None => Ok(None),
}
}
pub fn is_complete(&self) -> bool {
self.response
.as_ref()
.map(GrpcStreamingResponse::is_complete)
.unwrap_or(false)
}
async fn ensure_response_ready(&mut self) -> Result<()> {
if self.response.is_none() {
let response = self.response_rx.recv().await.map_err(|_| {
Error::new(
ErrorKind::Transport,
"grpc duplex task stopped before response headers arrived",
)
})??;
self.response = Some(response);
}
Ok(())
}
}
impl GrpcResponse {
pub fn metadata(&self, name: &str) -> Vec<String> {
grpc_metadata_values(&self.headers, name)
}
pub fn metadata_bin(&self, name: &str) -> Result<Vec<Bytes>> {
grpc_binary_metadata_values(&self.headers, name)
}
pub async fn message<T: DeserializeOwned>(&mut self) -> Result<T> {
if self.codec != GrpcCodec::Json {
return Err(typed_response_codec_error());
}
let message = self.message_bytes().await?;
serde_json::from_slice(&message).map_err(|err| {
Error::with_source(ErrorKind::Decode, "failed to decode grpc json message", err)
})
}
pub fn messages<'a, T: DeserializeOwned + 'a>(
&'a mut self,
) -> Result<Pin<Box<dyn Stream<Item = Result<T>> + 'a>>> {
if self.codec != GrpcCodec::Json {
return Err(typed_response_codec_error());
}
let stream = self.messages_bytes()?;
Ok(Box::pin(async_stream::stream! {
let mut stream = stream;
while let Some(item) = stream.next().await {
let item = item?;
let value = serde_json::from_slice(&item).map_err(|err| {
Error::with_source(ErrorKind::Decode, "failed to decode grpc json message", err)
})?;
yield Ok(value);
}
}))
}
pub async fn message_bytes(&mut self) -> Result<Bytes> {
self.ensure_ok_status()?;
let remaining = self.messages.len().saturating_sub(self.next_message);
if remaining != 1 {
return Err(Error::new(
ErrorKind::Transport,
format!("expected exactly one grpc response message, found {remaining}"),
));
}
let message = self.messages[self.next_message].clone();
self.next_message += 1;
Ok(message)
}
pub fn messages_bytes<'a>(
&'a mut self,
) -> Result<Pin<Box<dyn Stream<Item = Result<Bytes>> + 'a>>> {
self.ensure_ok_status()?;
let messages = self.messages[self.next_message..].to_vec();
self.next_message = self.messages.len();
Ok(Box::pin(stream::iter(messages.into_iter().map(Ok))))
}
pub fn trailers(&self) -> Result<Option<HeaderMap>> {
Ok(self.trailers.clone())
}
pub fn trailer_metadata(&self, name: &str) -> Vec<String> {
self.trailers
.as_ref()
.map(|trailers| grpc_metadata_values(trailers, name))
.unwrap_or_default()
}
pub fn trailer_metadata_bin(&self, name: &str) -> Result<Vec<Bytes>> {
match self.trailers.as_ref() {
Some(trailers) => grpc_binary_metadata_values(trailers, name),
None => Ok(Vec::new()),
}
}
pub fn status(&self) -> Result<GrpcStatus> {
Ok(self.status.clone())
}
fn ensure_ok_status(&self) -> Result<()> {
if self.status.code != 0 {
return Err(grpc_status_error(&self.status));
}
Ok(())
}
}
/// gRPC response whose messages are decoded incrementally as the HTTP body arrives.
pub struct GrpcStreamingResponse {
    /// Codec the call was made with; typed accessors require `GrpcCodec::Json`.
    codec: GrpcCodec,
    /// Body, frame, trailer and status decoding state.
    decoder: GrpcResponseDecoder,
}
impl fmt::Debug for GrpcStreamingResponse {
    /// Shows the codec, whether the body has ended, and the final status (if known).
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let decoder = &self.decoder;
        let mut out = f.debug_struct("GrpcStreamingResponse");
        out.field("codec", &self.codec);
        out.field("complete", &decoder.complete);
        out.field("status", &decoder.final_status);
        out.finish()
    }
}
impl GrpcStreamingResponse {
fn from_http_response(response: Response, codec: GrpcCodec) -> Result<Self> {
validate_grpc_http_response(&response)?;
let headers = response.headers().clone();
let compression = parse_grpc_compression(headers.get("grpc-encoding"))?;
let (body, trailers) = response.into_body_stream_and_trailer_state();
Ok(Self {
codec,
decoder: GrpcResponseDecoder {
body,
headers,
trailers,
compression,
frame_decoder: GrpcFrameDecoder::default(),
final_trailers: None,
final_status: None,
complete: false,
},
})
}
pub async fn next_message<T: DeserializeOwned>(&mut self) -> Result<Option<T>> {
if self.codec != GrpcCodec::Json {
return Err(typed_response_codec_error());
}
let Some(message) = self.next_message_bytes().await? else {
return Ok(None);
};
let value = serde_json::from_slice(&message).map_err(|err| {
Error::with_source(ErrorKind::Decode, "failed to decode grpc json message", err)
})?;
Ok(Some(value))
}
pub async fn next_message_bytes(&mut self) -> Result<Option<Bytes>> {
let next = self.decoder.next_message().await?;
if next.is_none() {
self.ensure_ok_status()?;
}
Ok(next)
}
pub fn messages<'a, T: DeserializeOwned + 'a>(
&'a mut self,
) -> Result<Pin<Box<dyn Stream<Item = Result<T>> + 'a>>> {
if self.codec != GrpcCodec::Json {
return Err(typed_response_codec_error());
}
Ok(Box::pin(async_stream::stream! {
while let Some(item) = self.next_message_bytes().await? {
let value = serde_json::from_slice(&item).map_err(|err| {
Error::with_source(ErrorKind::Decode, "failed to decode grpc json message", err)
})?;
yield Ok(value);
}
}))
}
pub fn messages_bytes<'a>(
&'a mut self,
) -> Result<Pin<Box<dyn Stream<Item = Result<Bytes>> + 'a>>> {
Ok(Box::pin(async_stream::stream! {
while let Some(item) = self.next_message_bytes().await? {
yield Ok(item);
}
}))
}
pub async fn finish(&mut self) -> Result<GrpcStatus> {
while self.decoder.next_message().await?.is_some() {}
self.decoder
.final_status
.clone()
.ok_or_else(|| Error::new(ErrorKind::Transport, "grpc response did not complete"))
}
pub fn metadata(&self, name: &str) -> Vec<String> {
grpc_metadata_values(&self.decoder.headers, name)
}
pub fn metadata_bin(&self, name: &str) -> Result<Vec<Bytes>> {
grpc_binary_metadata_values(&self.decoder.headers, name)
}
pub fn trailer_metadata(&self, name: &str) -> Vec<String> {
self.decoder
.final_trailers
.as_ref()
.map(|trailers| grpc_metadata_values(trailers, name))
.unwrap_or_default()
}
pub fn trailer_metadata_bin(&self, name: &str) -> Result<Vec<Bytes>> {
match self.decoder.final_trailers.as_ref() {
Some(trailers) => grpc_binary_metadata_values(trailers, name),
None => Ok(Vec::new()),
}
}
pub fn trailers(&self) -> Result<Option<HeaderMap>> {
Ok(self.decoder.final_trailers.clone())
}
pub fn status(&self) -> Result<Option<GrpcStatus>> {
Ok(self.decoder.final_status.clone())
}
pub fn is_complete(&self) -> bool {
self.decoder.complete
}
async fn into_buffered_response(mut self) -> Result<GrpcResponse> {
let mut messages = Vec::new();
while let Some(message) = self.decoder.next_message().await? {
messages.push(message);
}
let status =
self.decoder.final_status.clone().ok_or_else(|| {
Error::new(ErrorKind::Transport, "grpc response did not complete")
})?;
Ok(GrpcResponse {
codec: self.codec,
headers: self.decoder.headers.clone(),
messages,
next_message: 0,
trailers: self.decoder.final_trailers.clone(),
status,
})
}
fn ensure_ok_status(&self) -> Result<()> {
if let Some(status) = self.decoder.final_status.as_ref() {
if status.code != 0 {
return Err(grpc_status_error(status));
}
}
Ok(())
}
}
/// Turns an HTTP response body into gRPC messages.
///
/// Once the body ends, it also resolves the final trailers and status.
struct GrpcResponseDecoder {
    /// Remaining response body chunks.
    body: BodyStream,
    /// Response headers; used for metadata and as a fallback source of status fields.
    headers: HeaderMap,
    /// Pending HTTP trailers; taken once the body stream ends.
    trailers: TrailerState,
    /// Compression declared by the response `grpc-encoding` header.
    compression: Option<GrpcCompression>,
    /// Buffers body bytes until whole frames are available.
    frame_decoder: GrpcFrameDecoder,
    /// Trailers resolved at end of body. These are the HTTP trailers, or a copy
    /// of the headers when no trailers arrived but the headers carry `grpc-status`
    /// or `grpc-message`.
    final_trailers: Option<HeaderMap>,
    /// Status parsed at end of body.
    final_status: Option<GrpcStatus>,
    /// Set once the body stream has returned `None`.
    complete: bool,
}
impl GrpcResponseDecoder {
    /// Returns the next decoded message.
    ///
    /// Returns `Ok(None)` once the body has ended and every buffered frame has
    /// been consumed.
    ///
    /// # Errors
    /// Propagates body-chunk errors, frame/decompression errors and status
    /// parsing errors. Also reports leftover bytes that do not form a whole
    /// frame at the end of the body.
    async fn next_message(&mut self) -> Result<Option<Bytes>> {
        loop {
            // Hand out already-buffered frames before pulling more body data.
            if let Some(message) = self.frame_decoder.next_message(self.compression)? {
                return Ok(Some(message));
            }
            if self.complete {
                // Body has ended: any leftover bytes are a truncated frame.
                if self.frame_decoder.has_buffered_data() {
                    return Err(self.frame_decoder.incomplete_frame_error());
                }
                return Ok(None);
            }
            match self.body.next().await {
                Some(chunk) => self.frame_decoder.push(chunk?),
                None => self.finish_stream()?,
            }
        }
    }
    /// Marks the body as ended, then resolves the final trailers and status (once only).
    ///
    /// Note: `complete` is set before the status is parsed. If parsing fails,
    /// later calls to `next_message` therefore return `Ok(None)` instead of
    /// repeating the error.
    fn finish_stream(&mut self) -> Result<()> {
        if self.complete {
            return Ok(());
        }
        self.complete = true;
        // Take the trailer state out, leaving an empty `Ready(None)` behind.
        let trailer_state = std::mem::replace(&mut self.trailers, TrailerState::Ready(None));
        let trailers = match trailer_state.take() {
            Some(trailers) => Some(trailers),
            // No HTTP trailers, but the headers carry status fields: treat the
            // headers as the trailers.
            None if self.headers.get("grpc-status").is_some()
                || self.headers.get("grpc-message").is_some() =>
            {
                Some(self.headers.clone())
            }
            None => None,
        };
        self.final_status = Some(parse_grpc_status(trailers.as_ref(), &self.headers)?);
        self.final_trailers = trailers;
        Ok(())
    }
}
/// Incremental decoder for gRPC length-prefixed message frames.
///
/// Body chunks are appended with [`push`](Self::push). Complete frames are
/// split off the front of the buffer by [`next_message`](Self::next_message).
#[derive(Default)]
struct GrpcFrameDecoder {
    /// Received bytes not yet consumed as complete frames.
    buffer: BytesMut,
}
impl GrpcFrameDecoder {
    /// Size of the frame prefix: a 1-byte compressed flag followed by a 4-byte
    /// big-endian message length.
    const HEADER_LEN: usize = 5;

    /// Appends a body chunk to the pending buffer.
    fn push(&mut self, chunk: Bytes) {
        self.buffer.extend_from_slice(&chunk);
    }

    /// Removes and returns the next complete message.
    ///
    /// Returns `Ok(None)` when the buffer does not yet hold a whole frame.
    /// Compressed frames are decompressed with `compression`.
    ///
    /// # Errors
    /// Fails for any of:
    /// - a compressed flag other than 0 or 1;
    /// - a declared length that cannot be represented on this platform;
    /// - a compressed frame without a `grpc-encoding`;
    /// - a decompression failure.
    fn next_message(&mut self, compression: Option<GrpcCompression>) -> Result<Option<Bytes>> {
        let Some(header) = self.buffer.get(..Self::HEADER_LEN) else {
            return Ok(None);
        };
        let compressed_flag = header[0];
        if compressed_flag > 1 {
            return Err(Error::new(
                ErrorKind::Transport,
                format!("invalid grpc compression flag: {compressed_flag}"),
            ));
        }
        let declared_len = u32::from_be_bytes([header[1], header[2], header[3], header[4]]);
        // On 32-bit targets `HEADER_LEN + u32::MAX` overflows `usize`; such a frame
        // could never be buffered, so report it instead of wrapping or panicking.
        let frame_len = usize::try_from(declared_len)
            .ok()
            .and_then(|len| len.checked_add(Self::HEADER_LEN))
            .ok_or_else(|| {
                Error::new(
                    ErrorKind::Transport,
                    format!("grpc frame length {declared_len} is too large for this platform"),
                )
            })?;
        if self.buffer.len() < frame_len {
            return Ok(None);
        }
        // Split the frame off the shared buffer and drop the prefix without copying.
        let mut frame = self.buffer.split_to(frame_len);
        let payload = frame.split_off(Self::HEADER_LEN).freeze();
        let message = if compressed_flag == 1 {
            let compression = compression.ok_or_else(|| {
                Error::new(
                    ErrorKind::Transport,
                    "grpc frame is compressed but grpc-encoding is missing",
                )
            })?;
            decompress_grpc_message(&payload, compression)?
        } else {
            payload
        };
        Ok(Some(message))
    }

    /// Whether unconsumed bytes remain (i.e. a partial frame when the body has ended).
    fn has_buffered_data(&self) -> bool {
        !self.buffer.is_empty()
    }

    /// Describes why the leftover bytes do not form a whole frame.
    fn incomplete_frame_error(&self) -> Error {
        if self.buffer.len() < Self::HEADER_LEN {
            Error::new(ErrorKind::Transport, "incomplete grpc frame header")
        } else {
            Error::new(
                ErrorKind::Transport,
                "grpc frame length exceeds remaining response body",
            )
        }
    }
}
/// Final status of a gRPC call.
///
/// Built from `grpc-status`, `grpc-message` and `grpc-status-details-bin`. A
/// `code` of `0` is treated as success; any other code makes the message
/// accessors fail.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct GrpcStatus {
    /// Numeric status code.
    code: i32,
    /// Percent-decoded status message; empty when none was sent.
    message: String,
    /// Decoded `grpc-status-details-bin` payload, when present.
    details_bin: Option<Bytes>,
}
impl GrpcStatus {
    /// Creates a status with the given code and message and no binary details.
    pub fn new(code: i32, message: impl Into<String>) -> Self {
        Self::with_details(code, message, None)
    }

    /// Creates a status with optional binary details.
    fn with_details(code: i32, message: impl Into<String>, details_bin: Option<Bytes>) -> Self {
        let message = message.into();
        Self {
            code,
            message,
            details_bin,
        }
    }

    /// Numeric status code (`0` means OK).
    pub fn code(&self) -> i32 {
        self.code
    }

    /// Status message text (may be empty).
    pub fn message(&self) -> &str {
        self.message.as_str()
    }

    /// Raw bytes of `grpc-status-details-bin`, if the server sent it.
    pub fn details_bin(&self) -> Option<&[u8]> {
        match &self.details_bin {
            Some(details) => Some(&details[..]),
            None => None,
        }
    }
}
/// Checks that an HTTP response can carry a gRPC call.
///
/// Requirements:
/// - the response uses HTTP/2 or HTTP/3;
/// - the HTTP status is 200;
/// - the `content-type` media type is `application/grpc` or an
///   `application/grpc+<codec>` subtype. Parameters after `;` are ignored and
///   the comparison is case-insensitive.
///
/// # Errors
/// Returns a transport error describing the first unmet requirement.
fn validate_grpc_http_response(response: &Response) -> Result<()> {
    if !matches!(response.version(), Version::Http2 | Version::Http3) {
        return Err(Error::new(
            ErrorKind::Transport,
            "grpc requires an HTTP/2 or HTTP/3 response",
        ));
    }
    let status = response.status().as_u16();
    if status != 200 {
        return Err(Error::new(
            ErrorKind::Transport,
            format!("unexpected grpc http status: {status}"),
        ));
    }
    let Some(content_type) = response.headers().get("content-type") else {
        return Err(Error::new(
            ErrorKind::Transport,
            "grpc response did not include content-type",
        ));
    };
    // Compare the media type itself. A bare prefix check would also accept
    // unrelated types such as `application/grpc-web`, which use different framing.
    let media_type = content_type
        .split(';')
        .next()
        .unwrap_or("")
        .trim()
        .to_ascii_lowercase();
    let is_grpc =
        media_type == "application/grpc" || media_type.starts_with("application/grpc+");
    if !is_grpc {
        return Err(Error::new(
            ErrorKind::Transport,
            format!("unexpected grpc content-type: {content_type}"),
        ));
    }
    Ok(())
}
/// Formats `duration` as a `grpc-timeout` header value: at most 8 digits plus a unit suffix.
///
/// The coarsest unit that represents the duration exactly is preferred.
/// Otherwise the finest unit whose value fits is used, rounding up so the
/// deadline is never shortened. A zero duration becomes `1n`. Durations too
/// large for any unit saturate to `99999999H`.
fn encode_grpc_timeout(duration: Duration) -> String {
    const MAX_VALUE: u128 = 99_999_999;
    // (unit length in nanoseconds, suffix), coarsest first.
    const UNITS: [(u128, char); 6] = [
        (3_600_000_000_000, 'H'),
        (60_000_000_000, 'M'),
        (1_000_000_000, 'S'),
        (1_000_000, 'm'),
        (1_000, 'u'),
        (1, 'n'),
    ];

    let nanos = duration.as_nanos();
    if nanos == 0 {
        return String::from("1n");
    }

    let exact = UNITS.iter().find_map(|&(unit, suffix)| {
        let value = nanos / unit;
        let fits = (1..=MAX_VALUE).contains(&value) && nanos % unit == 0;
        fits.then(|| format!("{value}{suffix}"))
    });

    exact
        .or_else(|| {
            UNITS.iter().rev().find_map(|&(unit, suffix)| {
                let value = nanos.div_ceil(unit);
                (value <= MAX_VALUE).then(|| format!("{value}{suffix}"))
            })
        })
        .unwrap_or_else(|| String::from("99999999H"))
}
/// Serializes `value` to JSON bytes for use as a gRPC message payload.
fn encode_json_message<T: Serialize>(value: &T) -> Result<Bytes> {
    serde_json::to_vec(value).map(Bytes::from).map_err(|err| {
        Error::with_source(ErrorKind::Decode, "failed to encode grpc json message", err)
    })
}
/// Turns a stream of message payloads into a request body of gRPC frames.
///
/// Each payload is framed (and compressed when `compression` is set) as it is pulled.
fn frame_message_stream(
    stream: GrpcMessageStream,
    compression: Option<GrpcCompression>,
) -> BodyStream {
    let framed = async_stream::stream! {
        let mut pending = stream;
        while let Some(payload) = pending.next().await {
            let payload = payload?;
            yield Ok(encode_grpc_frame(&payload, compression)?);
        }
    };
    Box::pin(framed)
}
fn encode_grpc_frame(payload: &[u8], compression: Option<GrpcCompression>) -> Result<Bytes> {
let (compressed_flag, payload) = match compression {
Some(GrpcCompression::Gzip) => (1_u8, gzip_compress(payload)?),
None => (0_u8, payload.to_vec()),
};
let mut framed = Vec::with_capacity(5 + payload.len());
framed.push(compressed_flag);
framed.extend_from_slice(&(payload.len() as u32).to_be_bytes());
framed.extend_from_slice(&payload);
Ok(Bytes::from(framed))
}
/// Test helper: decodes every frame in `body`.
///
/// Fails if bytes remain that do not form a whole frame.
#[cfg(test)]
fn decode_grpc_frames(body: Bytes, compression: Option<GrpcCompression>) -> Result<Vec<Bytes>> {
    let mut decoder = GrpcFrameDecoder::default();
    decoder.push(body);
    let mut messages = Vec::new();
    loop {
        match decoder.next_message(compression)? {
            Some(message) => messages.push(message),
            None if decoder.has_buffered_data() => return Err(decoder.incomplete_frame_error()),
            None => return Ok(messages),
        }
    }
}
fn parse_grpc_status(trailers: Option<&HeaderMap>, headers: &HeaderMap) -> Result<GrpcStatus> {
let code = trailers
.and_then(|map| map.get("grpc-status"))
.or_else(|| headers.get("grpc-status"))
.ok_or_else(|| {
Error::new(
ErrorKind::Transport,
"grpc response did not include grpc-status",
)
})?;
let code = code.parse::<i32>().map_err(|err| {
Error::with_source(ErrorKind::Transport, "invalid grpc-status value", err)
})?;
let message = trailers
.and_then(|map| map.get("grpc-message"))
.or_else(|| headers.get("grpc-message"))
.map(decode_grpc_message_header)
.transpose()?
.unwrap_or_default();
let details_bin = trailers
.and_then(|map| map.get("grpc-status-details-bin"))
.or_else(|| headers.get("grpc-status-details-bin"))
.map(decode_grpc_binary_header)
.transpose()?;
Ok(GrpcStatus::with_details(code, message, details_bin))
}
/// Collects every value of the header `name` as owned strings.
fn grpc_metadata_values(headers: &HeaderMap, name: &str) -> Vec<String> {
    let mut values = Vec::new();
    for value in headers.get_all(name) {
        values.push(value.to_owned());
    }
    values
}
fn grpc_binary_metadata_values(headers: &HeaderMap, name: &str) -> Result<Vec<Bytes>> {
let name = normalize_grpc_binary_metadata_name(name)?;
headers
.get_all(&name)
.into_iter()
.map(decode_grpc_binary_header)
.collect()
}
/// Normalizes a binary metadata key.
///
/// The key is trimmed and lowercased, and the `-bin` suffix that marks binary
/// metadata is appended when missing.
///
/// # Errors
/// Returns `ErrorKind::InvalidHeaderName` when the trimmed key is empty.
fn normalize_grpc_binary_metadata_name(name: &str) -> Result<String> {
    let name = name.trim();
    if name.is_empty() {
        return Err(Error::new(
            ErrorKind::InvalidHeaderName,
            "grpc binary metadata name cannot be empty",
        ));
    }
    // Lowercase before checking the suffix, so `Trace-BIN` becomes `trace-bin`
    // rather than `trace-bin-bin`.
    let mut normalized = name.to_ascii_lowercase();
    if !normalized.ends_with("-bin") {
        normalized.push_str("-bin");
    }
    Ok(normalized)
}
/// Base64-encodes `bytes` (standard alphabet, with `=` padding) for a `-bin` metadata value.
fn encode_grpc_binary_header(bytes: &[u8]) -> String {
    const ALPHABET: &[u8; 64] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";
    let sextet = |bits: u32, shift: u32| ALPHABET[((bits >> shift) & 0x3F) as usize] as char;

    let mut encoded = String::with_capacity(bytes.len().div_ceil(3) * 4);
    for group in bytes.chunks(3) {
        let second = group.get(1).copied().unwrap_or(0);
        let third = group.get(2).copied().unwrap_or(0);
        let bits = u32::from_be_bytes([0, group[0], second, third]);
        encoded.push(sextet(bits, 18));
        encoded.push(sextet(bits, 12));
        encoded.push(if group.len() >= 2 { sextet(bits, 6) } else { '=' });
        encoded.push(if group.len() == 3 { sextet(bits, 0) } else { '=' });
    }
    encoded
}
fn decode_grpc_binary_header(value: &str) -> Result<Bytes> {
let bytes = value.trim().as_bytes();
if bytes.is_empty() {
return Ok(Bytes::new());
}
let mut output = Vec::with_capacity((bytes.len() / 4) * 3);
let mut chunk = [0_u8; 4];
let mut chunk_len = 0usize;
let mut padding = 0usize;
for &byte in bytes {
if byte == b'=' {
chunk[chunk_len] = 0;
chunk_len += 1;
padding += 1;
} else {
chunk[chunk_len] = decode_base64_value(byte)?;
chunk_len += 1;
}
if chunk_len == 4 {
output.push((chunk[0] << 2) | (chunk[1] >> 4));
if padding < 2 {
output.push((chunk[1] << 4) | (chunk[2] >> 2));
}
if padding == 0 {
output.push((chunk[2] << 6) | chunk[3]);
}
chunk_len = 0;
padding = 0;
}
}
match chunk_len {
0 => {}
2 => {
output.push((chunk[0] << 2) | (chunk[1] >> 4));
}
3 => {
output.push((chunk[0] << 2) | (chunk[1] >> 4));
output.push((chunk[1] << 4) | (chunk[2] >> 2));
}
_ => {
return Err(Error::new(
ErrorKind::Transport,
"invalid grpc binary metadata encoding",
));
}
}
Ok(Bytes::from(output))
}
fn decode_base64_value(byte: u8) -> Result<u8> {
match byte {
b'A'..=b'Z' => Ok(byte - b'A'),
b'a'..=b'z' => Ok(byte - b'a' + 26),
b'0'..=b'9' => Ok(byte - b'0' + 52),
b'+' => Ok(62),
b'/' => Ok(63),
_ => Err(Error::new(
ErrorKind::Transport,
"invalid grpc binary metadata encoding",
)),
}
}
fn parse_grpc_compression(value: Option<&str>) -> Result<Option<GrpcCompression>> {
match value.map(str::trim) {
None | Some("") => Ok(None),
Some(value) if value.eq_ignore_ascii_case("identity") => Ok(None),
Some(value) if value.eq_ignore_ascii_case("gzip") => Ok(Some(GrpcCompression::Gzip)),
Some(value) => Err(Error::new(
ErrorKind::Transport,
format!("unsupported grpc compression algorithm: {value}"),
)),
}
}
fn gzip_compress(payload: &[u8]) -> Result<Vec<u8>> {
let mut encoder = GzEncoder::new(Vec::new(), Compression::default());
encoder.write_all(payload).map_err(|err| {
Error::with_source(ErrorKind::Transport, "failed to gzip grpc message", err)
})?;
encoder.finish().map_err(|err| {
Error::with_source(
ErrorKind::Transport,
"failed to finish grpc gzip encoding",
err,
)
})
}
/// Decompresses one message payload with the given algorithm.
///
/// NOTE(review): the decompressed size is not bounded (`read_to_end`), so a
/// highly compressible frame can expand without limit.
///
/// # Errors
/// Fails when the payload is not valid compressed data.
fn decompress_grpc_message(payload: &[u8], compression: GrpcCompression) -> Result<Bytes> {
    let mut decoded = Vec::new();
    match compression {
        GrpcCompression::Gzip => {
            GzDecoder::new(payload)
                .read_to_end(&mut decoded)
                .map_err(|err| {
                    Error::with_source(ErrorKind::Transport, "failed to gunzip grpc message", err)
                })?;
        }
    }
    Ok(Bytes::from(decoded))
}
fn decode_grpc_message_header(value: &str) -> Result<String> {
let bytes = value.as_bytes();
let mut decoded = Vec::with_capacity(bytes.len());
let mut index = 0usize;
while index < bytes.len() {
if bytes[index] == b'%' {
if index + 2 >= bytes.len() {
return Err(Error::new(
ErrorKind::Transport,
"invalid grpc-message percent encoding",
));
}
let hex = std::str::from_utf8(&bytes[index + 1..index + 3]).map_err(|err| {
Error::with_source(
ErrorKind::Transport,
"invalid grpc-message percent encoding",
err,
)
})?;
let byte = u8::from_str_radix(hex, 16).map_err(|err| {
Error::with_source(
ErrorKind::Transport,
"invalid grpc-message percent encoding",
err,
)
})?;
decoded.push(byte);
index += 3;
continue;
}
decoded.push(bytes[index]);
index += 1;
}
String::from_utf8(decoded).map_err(|err| {
Error::with_source(ErrorKind::Transport, "grpc-message is not valid utf-8", err)
})
}
/// Converts a non-OK status into a transport error.
///
/// The error text includes the status code, and the status message when it is non-empty.
fn grpc_status_error(status: &GrpcStatus) -> Error {
    let mut message = format!("grpc request failed with status {}", status.code);
    if !status.message.is_empty() {
        message.push_str(": ");
        message.push_str(&status.message);
    }
    Error::new(ErrorKind::Transport, message)
}
fn duplex_payload_conflict_error() -> Error {
Error::new(
ErrorKind::Transport,
"grpc duplex call manages request messages itself; do not combine open_duplex() with message/messages/message_bytes/messages_bytes",
)
}
fn typed_request_codec_error() -> Error {
Error::new(
ErrorKind::Transport,
"typed grpc request APIs only support GrpcCodec::Json; use message_bytes/messages_bytes for protobuf payloads",
)
}
fn typed_response_codec_error() -> Error {
Error::new(
ErrorKind::Transport,
"typed grpc response APIs only support GrpcCodec::Json; use message_bytes/messages_bytes for protobuf payloads",
)
}
#[cfg(test)]
mod tests {
    //! Unit tests for gRPC framing, compression, metadata encoding, timeout
    //! formatting, response validation, status parsing, and the checks that
    //! request builders perform before touching the network.

    use super::{
        GrpcCodec, GrpcCompression, GrpcFrameDecoder, decode_grpc_binary_header,
        decode_grpc_frames, decode_grpc_message_header, encode_grpc_binary_header,
        encode_grpc_frame, encode_grpc_timeout, normalize_grpc_binary_metadata_name,
        parse_grpc_status, validate_grpc_http_response,
    };
    use bytes::Bytes;
    use std::time::Duration;
    use crate::{Body, HeaderMap, StatusCode, Url, Version};

    /// Asserts that `err` is a transport error whose rendered text contains
    /// `fragment`.
    fn assert_transport_error(err: &super::Error, fragment: &str) {
        assert_eq!(err.kind(), &crate::ErrorKind::Transport);
        assert!(err.to_string().contains(fragment));
    }

    /// Builds an HTTP/2 `200 OK` response with the given `headers` and a
    /// `grpc-status: 0` trailer, then asserts that
    /// `validate_grpc_http_response` rejects it with a transport error whose
    /// text contains `expected`.
    fn assert_validation_rejects(headers: HeaderMap, expected: &str) {
        let mut trailers = HeaderMap::new();
        trailers.insert("grpc-status", "0").unwrap();
        let response = crate::Response::new(
            StatusCode::OK,
            Version::Http2,
            Url::parse("https://example.com/grpc").unwrap(),
            headers,
            Some(trailers),
            Body::default(),
        );
        let err = validate_grpc_http_response(&response).unwrap_err();
        assert_transport_error(&err, expected);
    }

    #[test]
    fn grpc_frame_round_trip_keeps_message_boundaries() {
        let messages: [&'static [u8]; 2] = [br#"{"name":"one"}"#, br#"{"name":"two"}"#];
        // Concatenate both encoded frames into one buffer, as they would
        // arrive back-to-back on the wire.
        let mut wire = Vec::new();
        for message in messages {
            wire.extend_from_slice(&encode_grpc_frame(message, None).unwrap());
        }
        let frames = decode_grpc_frames(Bytes::from(wire), None).unwrap();
        let expected: Vec<Bytes> = messages.iter().copied().map(Bytes::from_static).collect();
        assert_eq!(frames, expected);
    }

    #[test]
    fn grpc_gzip_frame_round_trip_restores_original_payload() {
        let gzip = Some(GrpcCompression::Gzip);
        let encoded = encode_grpc_frame(br#"{"name":"gzip"}"#, gzip).unwrap();
        let decoded = decode_grpc_frames(encoded, gzip).unwrap();
        assert_eq!(decoded, vec![Bytes::from_static(br#"{"name":"gzip"}"#)]);
    }

    #[test]
    fn grpc_frame_decoder_handles_split_frame_boundaries() {
        let frame = encode_grpc_frame(br#"{"name":"split"}"#, None).unwrap();
        let mut decoder = GrpcFrameDecoder::default();
        // Neither partial chunk completes the frame, so the decoder must keep
        // buffering and report no message yet.
        for (start, end) in [(0, 2), (2, 7)] {
            decoder.push(frame.slice(start..end));
            assert!(decoder.next_message(None).unwrap().is_none());
        }
        decoder.push(frame.slice(7..));
        let message = decoder.next_message(None).unwrap();
        assert_eq!(message, Some(Bytes::from_static(br#"{"name":"split"}"#)));
    }

    #[test]
    fn grpc_message_header_percent_decodes() {
        let decoded = decode_grpc_message_header("user%20not%20found").unwrap();
        assert_eq!(decoded, "user not found");
    }

    #[test]
    fn open_duplex_rejects_preconfigured_payload_before_network() {
        let outcome = futures_lite::future::block_on(async {
            crate::grpc("https://example.com/chat.Service/Talk")
                .message_bytes(Bytes::from_static(b"hello"))?
                .open_duplex()
                .await
        });
        let err = outcome.unwrap_err();
        assert_transport_error(&err, "grpc duplex call manages request messages itself");
    }

    #[test]
    fn protobuf_typed_request_api_returns_explicit_error_before_network() {
        let outcome = futures_lite::future::block_on(async {
            crate::grpc("https://example.com/greeter.SayHello/Call")
                .codec(GrpcCodec::Protobuf)
                .message(&serde_json::json!({ "name": "Ada" }))?
                .await
        });
        let err = outcome.unwrap_err();
        assert_transport_error(&err, "typed grpc request APIs only support GrpcCodec::Json");
    }

    #[test]
    fn grpc_timeout_header_uses_smallest_fitting_unit() {
        let cases = [
            (Duration::from_millis(1500), "1500m"),
            (Duration::from_secs(12), "12S"),
            (Duration::from_nanos(1), "1n"),
        ];
        for (timeout, wire) in cases {
            assert_eq!(encode_grpc_timeout(timeout), wire);
        }
    }

    #[test]
    fn grpc_response_requires_content_type() {
        assert_validation_rejects(HeaderMap::new(), "did not include content-type");
    }

    #[test]
    fn grpc_response_rejects_non_grpc_content_type() {
        let mut headers = HeaderMap::new();
        headers.insert("content-type", "application/json").unwrap();
        assert_validation_rejects(headers, "unexpected grpc content-type");
    }

    #[test]
    fn grpc_binary_metadata_name_normalizes_suffix() {
        // Both the bare name and the already-suffixed name map to the same key.
        for name in ["trace", "trace-bin"] {
            assert_eq!(normalize_grpc_binary_metadata_name(name).unwrap(), "trace-bin");
        }
    }

    #[test]
    fn grpc_binary_metadata_round_trips() {
        let raw: &'static [u8] = b"\x00\x01grpc";
        let encoded = encode_grpc_binary_header(raw);
        let restored = decode_grpc_binary_header(&encoded).unwrap();
        assert_eq!(restored, Bytes::from_static(raw));
    }

    #[test]
    fn grpc_binary_header_decodes_unpadded_base64() {
        let cases: [(&str, &[u8]); 3] = [
            ("AA", b"\x00"),
            ("AAA", b"\x00\x00"),
            ("AAAA", b"\x00\x00\x00"),
        ];
        for (encoded, expected) in cases {
            let decoded = decode_grpc_binary_header(encoded).unwrap();
            assert_eq!(&decoded[..], expected);
        }
        // Stripping the padding from a real encoding must still decode back
        // to the original bytes.
        let raw = b"\xde\xad\xbe\xef";
        let padded = encode_grpc_binary_header(raw);
        let decoded = decode_grpc_binary_header(padded.trim_end_matches('=')).unwrap();
        assert_eq!(&decoded[..], raw);
    }

    #[test]
    fn grpc_binary_header_rejects_invalid_chunk_length_of_one() {
        let err = decode_grpc_binary_header("A").unwrap_err();
        assert_eq!(err.kind(), &crate::ErrorKind::Transport);
    }

    #[test]
    fn grpc_status_exposes_status_details_bin() {
        let details = encode_grpc_binary_header(b"details");
        let mut trailers = HeaderMap::new();
        for (name, value) in [
            ("grpc-status", "7"),
            ("grpc-message", "denied"),
            ("grpc-status-details-bin", details.as_str()),
        ] {
            trailers.insert(name, value).unwrap();
        }
        let status = parse_grpc_status(Some(&trailers), &HeaderMap::new()).unwrap();
        assert_eq!(status.code(), 7);
        assert_eq!(status.message(), "denied");
        assert_eq!(status.details_bin(), Some(&b"details"[..]));
    }
}