use std::collections::HashMap;
use std::marker::PhantomData;
use std::pin::Pin;
use std::time::Duration;
use bytes::Bytes;
use bytes::BytesMut;
use http::Request;
use http::Response;
use http::Uri;
use http_body::Body;
use http_body_util::BodyExt;
use http_body_util::Full;
use http_body_util::combinators::BoxBody;
use buffa::view::MessageView;
use buffa::view::OwnedView;
use crate::codec::CodecFormat;
use crate::codec::content_type;
use crate::codec::header as connect_header;
use crate::compression::CompressionPolicy;
use crate::compression::CompressionRegistry;
use crate::envelope::Envelope;
use crate::error::ConnectError;
use crate::error::ErrorCode;
use crate::error::ErrorDetail;
use crate::protocol::Protocol;
/// Boxed, sendable future type used by transport trait methods.
pub type BoxFuture<'a, T> = Pin<Box<dyn Future<Output = T> + Send + 'a>>;
/// Type-erased request body: frames of `Bytes`, errors of `ConnectError`.
pub type ClientBody = BoxBody<Bytes, ConnectError>;
/// Wraps a complete in-memory buffer as a `ClientBody`.
/// `Full`'s error type is uninhabited, so the `map_err` match is empty.
#[inline]
pub fn full_body(b: Bytes) -> ClientBody {
Full::new(b).map_err(|never| match never {}).boxed()
}
/// Extra buffering headroom beyond the message-size limit, reserved for
/// trailer frames that may follow the message data.
const RESPONSE_BUFFER_TRAILER_SLACK: usize = 64 * 1024;
/// Abstraction over the HTTP layer used to send Connect/gRPC requests.
///
/// Implementations must be cheaply cloneable and shareable across tasks.
pub trait ClientTransport: Clone + Send + Sync + 'static {
/// Body type of responses produced by this transport (frames of `Bytes`).
type ResponseBody: Body<Data = Bytes> + Send + 'static;
/// Transport-level error type (connection failures, etc.).
type Error: std::error::Error + Send + Sync + 'static;
/// Sends one fully built HTTP request, resolving to the raw HTTP response
/// or a transport error.
fn send(
&self,
request: Request<ClientBody>,
) -> BoxFuture<'static, Result<Response<Self::ResponseBody>, Self::Error>>;
}
/// Adapter that lets any `tower::Service` over HTTP requests act as a
/// [`ClientTransport`].
#[derive(Clone)]
pub struct ServiceTransport<S> {
service: S,
}
impl<S> ServiceTransport<S> {
/// Wraps `service` in a transport adapter.
pub fn new(service: S) -> Self {
Self { service }
}
/// Borrows the wrapped service.
pub fn inner(&self) -> &S {
&self.service
}
/// Mutably borrows the wrapped service.
pub fn inner_mut(&mut self) -> &mut S {
&mut self.service
}
/// Consumes the adapter, returning the wrapped service.
pub fn into_inner(self) -> S {
self.service
}
}
/// Any cloneable `tower::Service` over `Request<ClientBody>` is a transport:
/// each `send` clones the service and drives the request to completion with
/// `oneshot` (which handles readiness internally).
impl<S, ResBody> ClientTransport for ServiceTransport<S>
where
    S: tower::Service<Request<ClientBody>, Response = Response<ResBody>>
        + Clone
        + Send
        + Sync
        + 'static,
    S::Error: std::error::Error + Send + Sync + 'static,
    S::Future: Send + 'static,
    ResBody: Body<Data = Bytes> + Send + 'static,
    ResBody::Error: std::error::Error + Send + Sync + 'static,
{
    type ResponseBody = ResBody;
    type Error = S::Error;

    fn send(
        &self,
        request: Request<ClientBody>,
    ) -> BoxFuture<'static, Result<Response<Self::ResponseBody>, Self::Error>> {
        use tower::ServiceExt;
        // Clone so the returned future is 'static and independent of `self`.
        Box::pin(self.service.clone().oneshot(request))
    }
}
#[cfg(feature = "client")]
mod http2;
#[cfg(feature = "client")]
pub use http2::Http2Connection;
#[cfg(feature = "client")]
pub use http2::SharedHttp2Connection;
#[cfg(feature = "client")]
#[derive(Clone)]
/// Hyper-based HTTP client supporting cleartext connections and, with the
/// `client-tls` feature, rustls-backed HTTPS.
pub struct HttpClient {
inner: HttpClientInner,
}
#[cfg(feature = "client")]
impl std::fmt::Debug for HttpClient {
// Manual impl: the inner hyper clients are not usefully Debug, so only the
// transport mode is reported.
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
let mode = match self.inner {
HttpClientInner::Plain(_) => "plaintext",
#[cfg(feature = "client-tls")]
HttpClientInner::Tls(_) => "tls",
};
f.debug_struct("HttpClient").field("mode", &mode).finish()
}
}
#[cfg(feature = "client")]
#[derive(Clone)]
/// Internal client variants: plain TCP, or rustls-backed HTTPS.
enum HttpClientInner {
Plain(
hyper_util::client::legacy::Client<
hyper_util::client::legacy::connect::HttpConnector,
ClientBody,
>,
),
#[cfg(feature = "client-tls")]
Tls(
hyper_util::client::legacy::Client<
hyper_rustls::HttpsConnector<hyper_util::client::legacy::connect::HttpConnector>,
ClientBody,
>,
),
}
#[cfg(feature = "client")]
impl HttpClient {
    /// Shared builder for cleartext clients; `http2_only` selects HTTP/2 with
    /// prior knowledge (h2c) instead of starting with HTTP/1.1.
    fn plaintext_client(http2_only: bool) -> Self {
        use hyper_util::client::legacy::Client;
        use hyper_util::client::legacy::connect::HttpConnector;
        use hyper_util::rt::TokioExecutor;
        let mut connector = HttpConnector::new();
        // Disable Nagle's algorithm to reduce per-RPC latency.
        connector.set_nodelay(true);
        let client = Client::builder(TokioExecutor::new())
            .http2_only(http2_only)
            .build(connector);
        Self {
            inner: HttpClientInner::Plain(client),
        }
    }
    /// Creates a cleartext HTTP client (protocol version negotiated normally).
    pub fn plaintext() -> Self {
        // `.http2_only(false)` is the builder default, so this matches the
        // previous plain construction exactly.
        Self::plaintext_client(false)
    }
    /// Creates a cleartext client that speaks HTTP/2 only (h2c prior
    /// knowledge), as required for plaintext gRPC.
    pub fn plaintext_http2_only() -> Self {
        Self::plaintext_client(true)
    }
    /// Creates an HTTPS-only client from the given rustls configuration.
    ///
    /// Any ALPN protocols on `tls_config` are cleared first — presumably so
    /// the connector controls ALPN via `enable_all_versions`; confirm against
    /// hyper-rustls semantics before changing.
    #[cfg(feature = "client-tls")]
    pub fn with_tls(tls_config: std::sync::Arc<rustls::ClientConfig>) -> Self {
        use hyper_util::client::legacy::Client;
        use hyper_util::client::legacy::connect::HttpConnector;
        use hyper_util::rt::TokioExecutor;
        let mut http = HttpConnector::new();
        http.set_nodelay(true);
        // Allow https:// URIs through the inner TCP connector.
        http.enforce_http(false);
        let mut cfg = (*tls_config).clone();
        cfg.alpn_protocols.clear();
        let https = hyper_rustls::HttpsConnectorBuilder::new()
            .with_tls_config(cfg)
            .https_only()
            .enable_all_versions()
            .wrap_connector(http);
        let client = Client::builder(TokioExecutor::new()).build(https);
        Self {
            inner: HttpClientInner::Tls(client),
        }
    }
}
#[cfg(feature = "client")]
impl ClientTransport for HttpClient {
type ResponseBody = hyper::body::Incoming;
type Error = ConnectError;
/// Dispatches on the inner client, first rejecting URIs whose scheme does
/// not match the client mode (https on a plaintext client, http on TLS).
fn send(
&self,
request: Request<ClientBody>,
) -> BoxFuture<'static, Result<Response<Self::ResponseBody>, Self::Error>> {
let scheme = request.uri().scheme_str();
match &self.inner {
HttpClientInner::Plain(client) => {
if scheme == Some("https") {
return Box::pin(async {
Err(ConnectError::invalid_argument(
"HttpClient::plaintext() received https:// URI; \
use HttpClient::with_tls for TLS",
))
});
}
// Clone the (cheaply cloneable) hyper client so the future is 'static.
let client = client.clone();
Box::pin(async move {
client
.request(request)
.await
.map_err(|e| ConnectError::unavailable(format!("HTTP request failed: {e}")))
})
}
#[cfg(feature = "client-tls")]
HttpClientInner::Tls(client) => {
if scheme == Some("http") {
return Box::pin(async {
Err(ConnectError::invalid_argument(
"HttpClient::with_tls() received http:// URI; \
use HttpClient::plaintext for cleartext",
))
});
}
let client = client.clone();
Box::pin(async move {
client.request(request).await.map_err(|e| {
ConnectError::unavailable(format!("HTTPS request failed: {e}"))
})
})
}
}
}
}
/// Per-client configuration shared by every call made through it.
#[derive(Clone, Debug)]
#[non_exhaustive]
pub struct ClientConfig {
// Server base URI; call paths are appended as `/{service}/{method}`.
pub base_uri: Uri,
// Wire protocol: Connect, gRPC, or gRPC-Web.
pub protocol: Protocol,
// Message serialization format (proto or JSON).
pub codec_format: CodecFormat,
// Available (de)compression codecs.
pub compression: CompressionRegistry,
// Encoding name used to compress request payloads, if any.
pub request_compression: Option<String>,
// Policy deciding when a payload is worth compressing.
pub compression_policy: CompressionPolicy,
// Default per-call timeout used when a call supplies none.
pub default_timeout: Option<Duration>,
// Default cap on message sizes used when a call supplies none.
pub default_max_message_size: Option<usize>,
// Headers added to every request; per-call headers override same-named ones.
pub default_headers: http::HeaderMap,
}
impl ClientConfig {
    /// Creates a configuration for `base_uri` with Connect/proto defaults and
    /// no compression, timeout, size limit, or extra headers.
    pub fn new(base_uri: Uri) -> Self {
        Self {
            base_uri,
            protocol: Protocol::Connect,
            codec_format: CodecFormat::Proto,
            compression: CompressionRegistry::default(),
            request_compression: None,
            compression_policy: CompressionPolicy::default(),
            default_timeout: None,
            default_max_message_size: None,
            default_headers: http::HeaderMap::new(),
        }
    }
    /// Selects the wire protocol.
    #[must_use]
    pub fn protocol(mut self, protocol: Protocol) -> Self {
        self.protocol = protocol;
        self
    }
    /// Selects the message serialization format.
    #[must_use]
    pub fn codec_format(mut self, format: CodecFormat) -> Self {
        self.codec_format = format;
        self
    }
    /// Shorthand for `codec_format(CodecFormat::Json)`.
    #[must_use]
    pub fn json(self) -> Self {
        self.codec_format(CodecFormat::Json)
    }
    /// Shorthand for `codec_format(CodecFormat::Proto)`.
    #[must_use]
    pub fn proto(self) -> Self {
        self.codec_format(CodecFormat::Proto)
    }
    /// Replaces the compression registry.
    #[must_use]
    pub fn compression(mut self, registry: CompressionRegistry) -> Self {
        self.compression = registry;
        self
    }
    /// Enables request compression using the named encoding.
    #[must_use]
    pub fn compress_requests(mut self, encoding: impl Into<String>) -> Self {
        self.request_compression = Some(encoding.into());
        self
    }
    /// Sets the policy deciding when request payloads are compressed.
    #[must_use]
    pub fn compression_policy(mut self, policy: CompressionPolicy) -> Self {
        self.compression_policy = policy;
        self
    }
    /// Sets the default per-call timeout.
    #[must_use]
    pub fn default_timeout(mut self, timeout: Duration) -> Self {
        self.default_timeout = Some(timeout);
        self
    }
    /// Sets the default maximum message size.
    #[must_use]
    pub fn default_max_message_size(mut self, size: usize) -> Self {
        self.default_max_message_size = Some(size);
        self
    }
    /// Appends a default header. Names or values that fail conversion are
    /// silently dropped, mirroring [`CallOptions::with_header`].
    #[must_use]
    pub fn default_header(
        mut self,
        name: impl TryInto<http::header::HeaderName>,
        value: impl TryInto<http::header::HeaderValue>,
    ) -> Self {
        if let Ok(header_name) = name.try_into() {
            if let Ok(header_value) = value.try_into() {
                self.default_headers.append(header_name, header_value);
            }
        }
        self
    }
    /// Replaces the entire default header map.
    #[must_use]
    pub fn default_headers(mut self, headers: http::HeaderMap) -> Self {
        self.default_headers = headers;
        self
    }
}
/// Per-call overrides layered on top of [`ClientConfig`] defaults.
#[derive(Debug, Clone, Default)]
#[non_exhaustive]
pub struct CallOptions {
// Extra request headers; override same-named config defaults.
pub headers: http::HeaderMap,
// Overrides the config's default timeout.
pub timeout: Option<Duration>,
// Overrides the config's default maximum message size.
pub max_message_size: Option<usize>,
// Per-call compression override, fed to `CompressionPolicy::with_override`.
pub compress: Option<bool>,
}
impl CallOptions {
    /// Sets a client-side timeout for this call.
    #[must_use]
    pub fn with_timeout(mut self, timeout: Duration) -> Self {
        self.timeout = Some(timeout);
        self
    }
    /// Appends a header. Names or values that fail conversion are silently
    /// ignored; use [`Self::try_with_header`] to surface those errors.
    #[must_use]
    pub fn with_header(
        mut self,
        name: impl TryInto<http::header::HeaderName>,
        value: impl TryInto<http::header::HeaderValue>,
    ) -> Self {
        if let Ok(header_name) = name.try_into() {
            if let Ok(header_value) = value.try_into() {
                self.headers.append(header_name, header_value);
            }
        }
        self
    }
    /// Appends a header, reporting conversion failures as errors.
    pub fn try_with_header(
        mut self,
        name: impl TryInto<http::header::HeaderName>,
        value: impl TryInto<http::header::HeaderValue>,
    ) -> Result<Self, ConnectError> {
        let Ok(name) = name.try_into() else {
            return Err(ConnectError::internal("invalid header name"));
        };
        let Ok(value) = value.try_into() else {
            return Err(ConnectError::internal("invalid header value"));
        };
        self.headers.append(name, value);
        Ok(self)
    }
    /// Appends every `(name, value)` pair from `headers`.
    #[must_use]
    pub fn with_headers(
        mut self,
        headers: impl IntoIterator<Item = (http::header::HeaderName, http::header::HeaderValue)>,
    ) -> Self {
        headers.into_iter().for_each(|(name, value)| {
            self.headers.append(name, value);
        });
        self
    }
    /// Caps the size of a single message for this call.
    #[must_use]
    pub fn with_max_message_size(mut self, size: usize) -> Self {
        self.max_message_size = Some(size);
        self
    }
    /// Explicitly enables or disables request compression for this call.
    #[must_use]
    pub fn with_compression(mut self, enabled: bool) -> Self {
        self.compress = Some(enabled);
        self
    }
}
/// Resolves per-call options against config defaults: an explicit call value
/// wins, otherwise the config default applies. Headers are merged, with call
/// headers replacing same-named config defaults.
fn effective_options(config: &ClientConfig, options: CallOptions) -> CallOptions {
    let CallOptions {
        headers,
        timeout,
        max_message_size,
        compress,
    } = options;
    CallOptions {
        headers: merge_headers(&config.default_headers, headers),
        timeout: timeout.or(config.default_timeout),
        max_message_size: max_message_size.or(config.default_max_message_size),
        compress,
    }
}
/// Merges config default headers with per-call headers. A name present in
/// `options` completely replaces the defaults under that name; multi-valued
/// call headers are preserved via `append`.
fn merge_headers(config_defaults: &http::HeaderMap, options: http::HeaderMap) -> http::HeaderMap {
    // Fast paths: nothing to merge when either side is empty.
    match (config_defaults.is_empty(), options.is_empty()) {
        (true, _) => options,
        (_, true) => config_defaults.clone(),
        _ => {
            let mut merged = config_defaults.clone();
            // Drop every default the call overrides...
            options.keys().for_each(|name| {
                merged.remove(name);
            });
            // ...then append all of the call's values.
            for (name, value) in &options {
                merged.append(name.clone(), value.clone());
            }
            merged
        }
    }
}
/// Runs `fut`, aborting with a `deadline_exceeded` error if `deadline`
/// (when present) passes first. Without a deadline the future runs unbounded.
async fn with_deadline<F, T>(
    deadline: Option<std::time::Instant>,
    fut: F,
) -> Result<T, ConnectError>
where
    F: Future<Output = Result<T, ConnectError>>,
{
    let Some(instant) = deadline else {
        return fut.await;
    };
    match tokio::time::timeout_at(tokio::time::Instant::from_std(instant), fut).await {
        Ok(result) => result,
        Err(_elapsed) => Err(ConnectError::deadline_exceeded(
            "client-side deadline exceeded",
        )),
    }
}
/// A decoded unary response: headers, message, and trailers.
#[derive(Debug)]
pub struct UnaryResponse<Resp> {
headers: http::HeaderMap,
body: Resp,
trailers: http::HeaderMap,
}
impl<Resp> UnaryResponse<Resp> {
/// Response headers (without any trailer-derived entries).
#[must_use]
pub fn headers(&self) -> &http::HeaderMap {
&self.headers
}
/// Borrows the decoded message.
#[must_use]
pub fn view(&self) -> &Resp {
&self.body
}
/// Consumes the response, returning the decoded message.
#[must_use]
pub fn into_view(self) -> Resp {
self.body
}
/// Response trailers.
#[must_use]
pub fn trailers(&self) -> &http::HeaderMap {
&self.trailers
}
/// Splits into `(headers, message, trailers)`.
#[must_use]
pub fn into_parts(self) -> (http::HeaderMap, Resp, http::HeaderMap) {
(self.headers, self.body, self.trailers)
}
}
impl<V> UnaryResponse<OwnedView<V>>
where
V: MessageView<'static>,
{
/// Converts the zero-copy view into the fully owned message type.
#[must_use]
pub fn into_owned(self) -> V::Owned {
self.body.to_owned_message()
}
}
/// Decodes `data` into an owned view of `RespView` using `format`.
///
/// Proto decodes directly into the view; JSON first deserializes into the
/// owned message type and then builds a view over it. Every failure maps to
/// `ConnectError::internal`.
fn decode_response_view<RespView>(
    data: Bytes,
    format: CodecFormat,
) -> Result<OwnedView<RespView>, ConnectError>
where
    RespView: MessageView<'static> + Send,
    RespView::Owned: buffa::Message + serde::de::DeserializeOwned,
{
    if let CodecFormat::Proto = format {
        return OwnedView::<RespView>::decode(data)
            .map_err(|e| ConnectError::internal(format!("failed to decode response: {e}")));
    }
    // JSON path.
    let owned: RespView::Owned = serde_json::from_slice(&data)
        .map_err(|e| ConnectError::internal(format!("failed to decode JSON response: {e}")))?;
    OwnedView::<RespView>::from_owned(&owned)
        .map_err(|e| ConnectError::internal(format!("failed to re-encode for view: {e}")))
}
/// Performs a unary (single request / single response) call.
///
/// Builds the URI as `{base_uri}/{service}/{method}`, encodes and optionally
/// compresses the request per the configured protocol, sends it via
/// `transport`, and parses the response. The optional client-side timeout
/// bounds the entire send/receive/parse sequence.
pub async fn call_unary<T, Req, RespView>(
transport: &T,
config: &ClientConfig,
service: &str,
method: &str,
request: Req,
options: CallOptions,
) -> Result<UnaryResponse<OwnedView<RespView>>, ConnectError>
where
T: ClientTransport,
<T::ResponseBody as Body>::Error: std::fmt::Display,
Req: buffa::Message + serde::Serialize,
RespView: MessageView<'static> + Send,
RespView::Owned: buffa::Message + serde::de::DeserializeOwned,
{
let options = effective_options(config, options);
// Strip trailing slashes so the joined path never contains `//`.
let base_str = config.base_uri.to_string();
let base_str = base_str.trim_end_matches('/');
let full_uri = format!("{base_str}/{service}/{method}");
let uri: Uri = full_uri
.parse()
.map_err(|e| ConnectError::internal(format!("invalid URI: {e}")))?;
// Serialize the request message in the configured codec.
let body = match config.codec_format {
CodecFormat::Proto => request.encode_to_bytes(),
CodecFormat::Json => {
let buf = serde_json::to_vec(&request).map_err(|e| {
ConnectError::internal(format!("failed to encode JSON request: {e}"))
})?;
Bytes::from(buf)
}
};
// Framing differs by protocol: gRPC/gRPC-Web wrap the payload in an
// envelope (compression signaled inside the envelope, so no
// Content-Encoding header); Connect sends a bare body and reports
// compression via Content-Encoding.
let (body, applied_content_encoding) = match config.protocol {
Protocol::Grpc | Protocol::GrpcWeb => {
let compression_for_encoder = config.request_compression.as_ref().map(|enc| {
(
std::sync::Arc::new(config.compression.clone()),
enc.as_str(),
)
});
let mut encoder = crate::envelope::EnvelopeEncoder::new(
compression_for_encoder,
config.compression_policy.with_override(options.compress),
);
let mut buf = bytes::BytesMut::new();
tokio_util::codec::Encoder::encode(&mut encoder, body, &mut buf)
.map_err(|e| ConnectError::internal(format!("envelope encode failed: {e}")))?;
(buf.freeze(), None)
}
Protocol::Connect => {
if let Some(ref encoding) = config.request_compression {
let effective_policy = config.compression_policy.with_override(options.compress);
if effective_policy.should_compress(body.len()) {
let compressed = config.compression.compress(encoding, &body)?;
(compressed, Some(encoding.as_str()))
} else {
(body, None)
}
} else {
(body, None)
}
}
};
// Compute the deadline before sending so it bounds the whole round trip.
let deadline = options.timeout.map(|t| std::time::Instant::now() + t);
let mut builder = Request::builder().method(http::Method::POST).uri(uri);
builder = add_unary_request_headers(builder, config, options.timeout, applied_content_encoding);
// NOTE(review): headers_mut() is None only when the builder is already in
// an error state; assumed unreachable here since the URI parsed above —
// confirm that add_unary_request_headers never installs invalid values.
let headers = builder.headers_mut().unwrap();
for (name, value) in &options.headers {
headers.append(name.clone(), value.clone());
}
let http_request = builder
.body(full_body(body))
.map_err(|e| ConnectError::internal(format!("failed to build request: {e}")))?;
with_deadline(deadline, async {
let response = transport
.send(http_request)
.await
.map_err(|e| ConnectError::unavailable(format!("request failed: {e}")))?;
match config.protocol {
Protocol::Connect => parse_connect_unary_response(response, config, &options).await,
Protocol::Grpc | Protocol::GrpcWeb => {
parse_grpc_unary_response(response, config, &options, deadline).await
}
}
})
.await
}
/// Performs a unary call as an HTTP GET (Connect protocol only).
///
/// The request payload travels in the query string: base64url (unpadded) for
/// binary or compressed payloads, percent-encoding for plain JSON. gRPC and
/// gRPC-Web are POST-only and are rejected up front.
pub async fn call_unary_get<T, Req, RespView>(
transport: &T,
config: &ClientConfig,
service: &str,
method: &str,
request: Req,
options: CallOptions,
) -> Result<UnaryResponse<OwnedView<RespView>>, ConnectError>
where
T: ClientTransport,
<T::ResponseBody as Body>::Error: std::fmt::Display,
Req: buffa::Message + serde::Serialize,
RespView: MessageView<'static> + Send,
RespView::Owned: buffa::Message + serde::de::DeserializeOwned,
{
if !matches!(config.protocol, Protocol::Connect) {
return Err(ConnectError::invalid_argument(
"call_unary_get requires Protocol::Connect (gRPC/gRPC-Web are POST-only)",
));
}
let options = effective_options(config, options);
let base_str = config.base_uri.to_string();
let base_str = base_str.trim_end_matches('/');
// Serialize the request message in the configured codec.
let body = match config.codec_format {
CodecFormat::Proto => request.encode_to_bytes(),
CodecFormat::Json => {
let buf = serde_json::to_vec(&request).map_err(|e| {
ConnectError::internal(format!("failed to encode JSON request: {e}"))
})?;
Bytes::from(buf)
}
};
// Optionally compress the payload per the effective policy.
let (payload, compressed_with) = if let Some(ref encoding) = config.request_compression {
let effective_policy = config.compression_policy.with_override(options.compress);
if effective_policy.should_compress(body.len()) {
let compressed = config.compression.compress(encoding, &body)?;
(compressed, Some(encoding.as_str()))
} else {
(body, None)
}
} else {
(body, None)
};
// Binary (proto) or compressed payloads can't go in a query string raw.
let is_binary_codec = matches!(config.codec_format, CodecFormat::Proto);
let use_base64 = is_binary_codec || compressed_with.is_some();
let encoded_message = if use_base64 {
use base64::Engine;
base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(&payload)
} else {
percent_encoding::percent_encode(&payload, percent_encoding::NON_ALPHANUMERIC).to_string()
};
let encoding_name = match config.codec_format {
CodecFormat::Proto => "proto",
CodecFormat::Json => "json",
};
let mut query = format!("connect=v1&encoding={encoding_name}&message={encoded_message}");
if use_base64 {
query.push_str("&base64=1");
}
if let Some(enc) = compressed_with {
query.push_str("&compression=");
query.push_str(enc);
}
let full_uri = format!("{base_str}/{service}/{method}?{query}");
let uri: Uri = full_uri
.parse()
.map_err(|e| ConnectError::internal(format!("invalid GET URI: {e}")))?;
let deadline = options.timeout.map(|t| std::time::Instant::now() + t);
let mut builder = Request::builder().method(http::Method::GET).uri(uri);
if let Some(timeout) = options.timeout {
builder = builder.header(
crate::codec::header::TIMEOUT_MS,
format_timeout(timeout, Protocol::Connect),
);
}
// Advertise which response encodings we can decompress.
let accept = config.compression.accept_encoding_header();
if !accept.is_empty() {
builder = builder.header(http::header::ACCEPT_ENCODING, accept);
}
// NOTE(review): headers_mut() is None only when the builder is already in
// an error state; assumed unreachable since the URI parsed above.
let headers = builder.headers_mut().unwrap();
for (name, value) in &options.headers {
headers.append(name.clone(), value.clone());
}
let http_request = builder
.body(full_body(Bytes::new()))
.map_err(|e| ConnectError::internal(format!("failed to build GET request: {e}")))?;
with_deadline(deadline, async {
let response = transport
.send(http_request)
.await
.map_err(|e| ConnectError::unavailable(format!("GET request failed: {e}")))?;
parse_connect_unary_response(response, config, &options).await
})
.await
}
/// Parses a Connect-protocol unary response.
///
/// Non-2xx statuses are decoded as Connect JSON error bodies when possible;
/// success bodies are bounded, decompressed per Content-Encoding, and decoded
/// with the configured codec. Headers prefixed `trailer-` are split out as
/// trailers (the Connect unary convention).
async fn parse_connect_unary_response<B, RespView>(
response: Response<B>,
config: &ClientConfig,
options: &CallOptions,
) -> Result<UnaryResponse<OwnedView<RespView>>, ConnectError>
where
B: Body<Data = Bytes> + Send,
B::Error: std::fmt::Display,
RespView: MessageView<'static> + Send,
RespView::Owned: buffa::Message + serde::de::DeserializeOwned,
{
let status = response.status();
if !status.is_success() {
let response_headers = response.headers().clone();
// Split `trailer-*` headers out so errors still carry both maps.
let mut trailers = http::HeaderMap::new();
let mut headers = http::HeaderMap::new();
for (name, value) in &response_headers {
if let Some(trailer_name) = name.as_str().strip_prefix("trailer-") {
if let Ok(name) = http::header::HeaderName::from_bytes(trailer_name.as_bytes()) {
trailers.append(name, value.clone());
}
} else {
headers.append(name.clone(), value.clone());
}
}
let error_encoding = response_headers
.get(http::header::CONTENT_ENCODING)
.and_then(|v| v.to_str().ok())
.map(|s| s.to_owned());
// Error bodies are bounded by the same limit as messages so a hostile
// server can't force unbounded buffering.
let max_err_body_size = options
.max_message_size
.unwrap_or(crate::service::DEFAULT_MAX_MESSAGE_SIZE);
let body = collect_body_bounded(response.into_body(), max_err_body_size)
.await
.map_err(|mut e| {
e.response_headers = headers.clone();
e.trailers = trailers.clone();
e
})?;
let body = match error_encoding {
Some(encoding) => {
match config
.compression
.decompress_with_limit(&encoding, body, max_err_body_size)
{
Ok(decompressed) => decompressed,
Err(e) => {
// Undecodable error body: fall back to a generic error
// derived from the HTTP status alone.
tracing::debug!(
"failed to decompress Connect error response ({encoding}): {e}"
);
let mut err = ConnectError::new(
http_status_to_error_code(status),
format!("HTTP error {}", status.as_u16()),
);
err.response_headers = headers;
err.trailers = trailers;
return Err(err);
}
}
}
None => body,
};
// Prefer the structured Connect error JSON; its code wins over the
// HTTP-status mapping when it parses.
if let Ok(error) = serde_json::from_slice::<ConnectErrorResponse>(&body) {
let code = error
.code
.as_deref()
.and_then(|s| s.parse::<ErrorCode>().ok())
.unwrap_or_else(|| http_status_to_error_code(status));
let mut err = ConnectError::new(code, error.message.unwrap_or_default());
err.details = error.details;
err.response_headers = headers;
err.trailers = trailers;
return Err(err);
}
// Unstructured error body: report the status plus the raw body text.
let code = http_status_to_error_code(status);
let mut err = ConnectError::new(
code,
format!(
"HTTP error {}: {}",
status.as_u16(),
String::from_utf8_lossy(&body)
),
);
err.response_headers = headers;
err.trailers = trailers;
return Err(err);
}
// Success path: split trailer-prefixed headers, as above.
let mut resp_headers = http::HeaderMap::new();
let mut resp_trailers = http::HeaderMap::new();
for (name, value) in response.headers() {
if let Some(trailer_name) = name.as_str().strip_prefix("trailer-") {
if let Ok(name) = http::header::HeaderName::from_bytes(trailer_name.as_bytes()) {
resp_trailers.append(name, value.clone());
}
} else {
resp_headers.append(name.clone(), value.clone());
}
}
// A response in the *other* known codec means a broken server (Internal);
// anything else is Unknown.
let expected_content_type = config.codec_format.content_type();
if let Some(resp_content_type) = response.headers().get(http::header::CONTENT_TYPE) {
let ct = resp_content_type.to_str().unwrap_or("");
if !ct.starts_with(expected_content_type) {
let code = if ct.starts_with(content_type::PROTO) || ct.starts_with(content_type::JSON)
{
ErrorCode::Internal
} else {
ErrorCode::Unknown
};
return Err(ConnectError::new(
code,
format!("unexpected content-type: {ct}"),
));
}
}
let response_encoding = response
.headers()
.get(http::header::CONTENT_ENCODING)
.and_then(|v| v.to_str().ok())
.map(|s| s.to_owned());
let max_message_size = options
.max_message_size
.unwrap_or(crate::service::DEFAULT_MAX_MESSAGE_SIZE);
let body = collect_body_bounded(response.into_body(), max_message_size).await?;
let body = if let Some(encoding) = response_encoding {
config
.compression
.decompress_with_limit(&encoding, body, max_message_size)
.map_err(|mut e| {
// An encoding we can't handle is the server's fault here, not a
// missing client feature.
if e.code == ErrorCode::Unimplemented {
e.code = ErrorCode::Internal;
}
e
})?
} else {
body
};
// Defensive re-check; decompress_with_limit should already enforce this.
if body.len() > max_message_size {
return Err(ConnectError::new(
ErrorCode::ResourceExhausted,
format!(
"message size {} exceeds limit {}",
body.len(),
max_message_size
),
));
}
let message = decode_response_view::<RespView>(body, config.codec_format)?;
Ok(UnaryResponse {
headers: resp_headers,
body: message,
trailers: resp_trailers,
})
}
/// Parses a gRPC or gRPC-Web unary response.
///
/// Buffers the whole body, decodes length-prefixed envelopes, gathers
/// trailers (HTTP trailer frames, gRPC-Web trailer frames flagged with the
/// 0x80 bit, or the headers of a trailers-only response), and maps
/// `grpc-status` to a `ConnectError` when it indicates failure.
async fn parse_grpc_unary_response<B, RespView>(
response: Response<B>,
config: &ClientConfig,
options: &CallOptions,
deadline: Option<std::time::Instant>,
) -> Result<UnaryResponse<OwnedView<RespView>>, ConnectError>
where
B: Body<Data = Bytes> + Send,
B::Error: std::fmt::Display,
RespView: MessageView<'static> + Send,
RespView::Owned: buffa::Message + serde::de::DeserializeOwned,
{
let status = response.status();
let resp_headers = response.headers().clone();
if !status.is_success() {
let code = http_status_to_error_code(status);
let mut err = ConnectError::new(code, format!("HTTP error {}", status.as_u16()));
err.response_headers = resp_headers;
return Err(err);
}
// Any application/grpc* content type is acceptable (grpc, grpc-web, +proto...).
if let Some(ct) = resp_headers.get(http::header::CONTENT_TYPE) {
let ct_str = ct.to_str().unwrap_or("");
if !ct_str.starts_with("application/grpc") {
let mut err = ConnectError::new(
ErrorCode::Unknown,
format!("unexpected content-type: {ct_str}"),
);
err.response_headers = resp_headers;
return Err(err);
}
}
let response_encoding = resp_headers
.get("grpc-encoding")
.and_then(|v| v.to_str().ok())
.map(|s| s.to_owned());
// Reject encodings we cannot decompress before reading the body.
if let Some(ref enc) = response_encoding
&& enc != "identity"
&& !config.compression.supports(enc)
{
let mut err = ConnectError::internal(format!("unsupported response compression: {enc}"));
err.response_headers = resp_headers;
return Err(err);
}
let mut body = std::pin::pin!(response.into_body());
let mut buf = BytesMut::new();
let mut grpc_trailers = http::HeaderMap::new();
let mut has_body_data = false;
// Allow headroom beyond the message limit for the envelope header and a
// possible trailer frame.
let max_buf_size = options
.max_message_size
.unwrap_or(crate::service::DEFAULT_MAX_MESSAGE_SIZE)
.saturating_add(crate::envelope::HEADER_SIZE)
.saturating_add(RESPONSE_BUFFER_TRAILER_SLACK);
loop {
match std::future::poll_fn(|cx| body.as_mut().poll_frame(cx)).await {
Some(Ok(frame)) => {
if frame.is_data() {
if let Ok(data) = frame.into_data() {
if !data.is_empty() {
has_body_data = true;
}
if buf.len().saturating_add(data.len()) > max_buf_size {
return Err(ConnectError::resource_exhausted(format!(
"response body size exceeds limit {max_buf_size}"
)));
}
buf.extend_from_slice(&data);
}
} else if frame.is_trailers()
&& let Ok(trailers) = frame.into_trailers()
{
grpc_trailers = trailers;
}
}
Some(Err(e)) => {
return Err(ConnectError::internal(format!(
"failed to read response body: {e}"
)));
}
None => break,
}
}
let mut message_data: Option<Bytes> = None;
let mut message_count = 0u32;
while !buf.is_empty() {
// High bit set on the flag byte marks a gRPC-Web trailer frame.
if buf[0] & 0x80 != 0 {
let decompression = response_encoding
.as_deref()
.map(|enc| (&config.compression, enc));
if let Some(trailers) =
parse_grpc_web_trailer_frame_with_compression(&buf, decompression)
{
grpc_trailers = trailers;
}
break;
}
let grpc_max_msg = options
.max_message_size
.unwrap_or(crate::service::DEFAULT_MAX_MESSAGE_SIZE);
let envelope = match Envelope::decode_with_limit(&mut buf, grpc_max_msg) {
Ok(Some(env)) => env,
Ok(None) => break,
Err(e) => {
return Err(ConnectError::internal(format!(
"envelope decode failed: {e}"
)));
}
};
let data = if envelope.is_compressed() {
let enc = response_encoding.as_deref().ok_or_else(|| {
ConnectError::internal("received compressed message without grpc-encoding header")
})?;
if enc == "identity" {
return Err(ConnectError::internal(
"received compressed message with identity encoding",
));
}
config
.compression
.decompress_with_limit(enc, envelope.data, grpc_max_msg)?
} else {
envelope.data
};
// Keep only the last message; more than one is rejected below.
message_count += 1;
message_data = Some(data);
}
// Trailers-only responses (no body frames at all) carry the status in the
// response headers instead of trailers.
let effective_trailers = if !grpc_trailers.is_empty() {
&grpc_trailers
} else if !has_body_data {
&resp_headers
} else {
&grpc_trailers };
if let Some(mut err) = parse_grpc_error_from_trailers(effective_trailers) {
err.response_headers = resp_headers;
return Err(err);
}
if message_count > 1 {
let mut err = ConnectError::unimplemented("received multiple messages for unary response");
err.response_headers = resp_headers;
return Err(err);
}
// A gRPC response must end with a grpc-status trailer; if it is missing
// and the deadline has passed, report a timeout instead.
if effective_trailers.get("grpc-status").is_none() {
let is_deadline_exceeded = deadline.is_some_and(|d| std::time::Instant::now() >= d);
let mut err = if is_deadline_exceeded {
ConnectError::deadline_exceeded("request timeout")
} else {
ConnectError::internal("gRPC response missing grpc-status trailer")
};
err.response_headers = resp_headers;
return Err(err);
}
let data = match message_data {
Some(data) => data,
None => {
let mut err = ConnectError::unimplemented("gRPC response contained no message data");
err.response_headers = resp_headers;
return Err(err);
}
};
if let Some(max_size) = options.max_message_size
&& data.len() > max_size
{
return Err(ConnectError::new(
ErrorCode::ResourceExhausted,
format!("message size {} exceeds limit {}", data.len(), max_size),
));
}
let message = decode_response_view::<RespView>(data, config.codec_format)?;
Ok(UnaryResponse {
headers: resp_headers,
body: message,
trailers: grpc_trailers,
})
}
/// Reader for a server-streaming response: decodes enveloped messages from
/// the HTTP body while tracking trailers, a terminal error, and a deadline.
pub struct ServerStream<B, RespView> {
// Response headers captured when the stream was opened.
headers: http::HeaderMap,
// Underlying HTTP response body.
body: B,
// Bytes received but not yet decoded into an envelope.
buf: BytesMut,
// Content/grpc encoding advertised by the server, if any.
encoding: Option<String>,
compression: CompressionRegistry,
codec_format: CodecFormat,
protocol: Protocol,
max_message_size: Option<usize>,
// Client-side deadline applied to each `message()` call.
deadline: Option<std::time::Instant>,
// Trailers, populated once the end of the stream has been seen.
trailers: Option<http::HeaderMap>,
// Terminal error decoded from trailers or an end-stream frame.
error: Option<ConnectError>,
// Set once the stream has terminated (normally or with an error).
done: bool,
_phantom: PhantomData<RespView>,
}
impl<B, RespView> std::fmt::Debug for ServerStream<B, RespView> {
// Manual impl: `B` and `RespView` need not be Debug, so only the stream's
// own state is reported.
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("ServerStream")
.field("protocol", &self.protocol)
.field("codec_format", &self.codec_format)
.field("encoding", &self.encoding)
.field("done", &self.done)
.field("error", &self.error)
.field("has_trailers", &self.trailers.is_some())
.field("buffered_bytes", &self.buf.len())
.finish_non_exhaustive()
}
}
impl<B, RespView> ServerStream<B, RespView>
where
B: Body<Data = Bytes> + Unpin,
B::Error: std::fmt::Display,
RespView: MessageView<'static> + Send,
RespView::Owned: buffa::Message + serde::de::DeserializeOwned,
{
#[must_use]
pub fn headers(&self) -> &http::HeaderMap {
&self.headers
}
pub async fn message(&mut self) -> Result<Option<OwnedView<RespView>>, ConnectError> {
let deadline = self.deadline;
with_deadline(deadline, self.message_inner()).await
}
async fn message_inner(&mut self) -> Result<Option<OwnedView<RespView>>, ConnectError> {
if self.done {
return Ok(None);
}
loop {
if matches!(self.protocol, Protocol::GrpcWeb)
&& self.buf.len() >= 5
&& self.buf[0] & 0x80 != 0
{
let trailer_len =
u32::from_be_bytes([self.buf[1], self.buf[2], self.buf[3], self.buf[4]])
as usize;
if self.buf.len() >= 5 + trailer_len {
self.done = true;
let decompression =
self.encoding.as_deref().map(|enc| (&self.compression, enc));
if let Some(trailers) =
parse_grpc_web_trailer_frame_with_compression(&self.buf, decompression)
{
if let Some(err) = parse_grpc_error_from_trailers(&trailers) {
self.error = Some(err);
}
self.trailers = Some(trailers);
}
return Ok(None);
}
}
let envelope_result = if matches!(self.protocol, Protocol::GrpcWeb)
&& !self.buf.is_empty()
&& self.buf[0] & 0x80 != 0
{
None
} else {
Envelope::decode_with_limit(
&mut self.buf,
self.max_message_size
.unwrap_or(crate::service::DEFAULT_MAX_MESSAGE_SIZE),
)?
};
match envelope_result {
Some(envelope) => {
if envelope.is_end_stream() {
self.done = true;
self.process_end_stream(envelope)?;
return Ok(None);
}
let data = self.decompress_envelope(envelope)?;
if let Some(max_size) = self.max_message_size
&& data.len() > max_size
{
return Err(ConnectError::new(
ErrorCode::ResourceExhausted,
format!("message size {} exceeds limit {}", data.len(), max_size),
));
}
let msg = decode_response_view::<RespView>(data, self.codec_format)?;
return Ok(Some(msg));
}
None => {
if !self.poll_body().await? {
self.done = true;
if matches!(self.protocol, Protocol::GrpcWeb)
&& !self.buf.is_empty()
&& self.buf[0] & 0x80 != 0
{
let decompression =
self.encoding.as_deref().map(|enc| (&self.compression, enc));
if let Some(trailers) = parse_grpc_web_trailer_frame_with_compression(
&self.buf,
decompression,
) {
if let Some(err) = parse_grpc_error_from_trailers(&trailers) {
self.error = Some(err);
}
self.trailers = Some(trailers);
}
}
if self.error.is_none()
&& self.trailers.is_none()
&& matches!(self.protocol, Protocol::Grpc | Protocol::GrpcWeb)
&& self
.deadline
.is_some_and(|d| std::time::Instant::now() >= d)
{
self.error = Some(ConnectError::deadline_exceeded("request timeout"));
}
return Ok(None);
}
}
}
}
}
#[must_use]
pub fn trailers(&self) -> Option<&http::HeaderMap> {
self.trailers.as_ref()
}
#[must_use]
pub fn error(&self) -> Option<&ConnectError> {
self.error.as_ref()
}
async fn poll_body(&mut self) -> Result<bool, ConnectError> {
let max_buf_size = self
.max_message_size
.unwrap_or(crate::service::DEFAULT_MAX_MESSAGE_SIZE)
.saturating_add(2 * crate::envelope::HEADER_SIZE)
.saturating_add(RESPONSE_BUFFER_TRAILER_SLACK);
loop {
let frame = Pin::new(&mut self.body).frame().await;
match frame {
None => return Ok(false), Some(Ok(frame)) => {
if frame.is_data() {
if let Ok(data) = frame.into_data() {
if self.buf.len().saturating_add(data.len()) > max_buf_size {
return Err(ConnectError::resource_exhausted(format!(
"response buffer exceeds limit {max_buf_size}"
)));
}
self.buf.extend_from_slice(&data);
return Ok(true);
}
} else if frame.is_trailers()
&& let Ok(trailers) = frame.into_trailers()
&& matches!(self.protocol, Protocol::Grpc | Protocol::GrpcWeb)
{
if let Some(err) = parse_grpc_error_from_trailers(&trailers) {
self.error = Some(err);
}
self.trailers = Some(trailers);
self.done = true;
return Ok(false);
}
}
Some(Err(e)) => {
return Err(ConnectError::internal(format!(
"error reading response body: {e}"
)));
}
}
}
}
/// Returns the envelope payload, decompressing it first when the envelope's
/// compressed flag is set.
///
/// A compressed envelope without a content-encoding header is an internal
/// error. An `Unimplemented` decompression error is reported as `Internal`.
fn decompress_envelope(&self, envelope: Envelope) -> Result<Bytes, ConnectError> {
    if !envelope.is_compressed() {
        return Ok(envelope.data);
    }
    let Some(encoding) = self.encoding.as_deref() else {
        return Err(ConnectError::internal(
            "received compressed message without content-encoding header",
        ));
    };
    let limit = self
        .max_message_size
        .unwrap_or(crate::service::DEFAULT_MAX_MESSAGE_SIZE);
    match self
        .compression
        .decompress_with_limit(encoding, envelope.data, limit)
    {
        Ok(data) => Ok(data),
        Err(mut e) => {
            if e.code == ErrorCode::Unimplemented {
                e.code = ErrorCode::Internal;
            }
            Err(e)
        }
    }
}
/// Decodes a Connect EndStreamResponse envelope, recording its metadata as
/// trailers and any error on `self`.
fn process_end_stream(&mut self, envelope: Envelope) -> Result<(), ConnectError> {
    let payload = self.decompress_envelope(envelope)?;
    // Malformed end-stream JSON is tolerated and treated as empty.
    let end_stream: ClientEndStreamResponse =
        serde_json::from_slice(&payload).unwrap_or_default();
    if let Some(metadata) = end_stream.metadata {
        let mut map = http::HeaderMap::new();
        for (key, values) in metadata {
            for raw in values {
                let parsed_name = http::header::HeaderName::from_bytes(key.as_bytes());
                let parsed_value = http::header::HeaderValue::from_str(&raw);
                // Entries that are not valid HTTP names/values are skipped.
                if let (Ok(name), Ok(value)) = (parsed_name, parsed_value) {
                    map.append(name, value);
                }
            }
        }
        self.trailers = Some(map);
    }
    if let Some(err) = end_stream.error {
        let code = err
            .code
            .as_deref()
            .and_then(|c| c.parse().ok())
            .unwrap_or(ErrorCode::Unknown);
        let mut connect_error = ConnectError::new(code, err.message.unwrap_or_default());
        connect_error.details = err.details;
        self.error = Some(connect_error);
    }
    Ok(())
}
}
/// Performs a server-streaming call: sends one request message, returns a
/// `ServerStream` for reading responses.
///
/// The request is envelope-encoded (optionally compressed per the configured
/// policy). The send and response-header phase is bounded by the effective
/// timeout; the resulting stream carries the same deadline.
pub async fn call_server_stream<T, Req, RespView>(
    transport: &T,
    config: &ClientConfig,
    service: &str,
    method: &str,
    request: Req,
    options: CallOptions,
) -> Result<ServerStream<T::ResponseBody, RespView>, ConnectError>
where
    T: ClientTransport,
    <T::ResponseBody as Body>::Error: std::fmt::Display,
    Req: buffa::Message + serde::Serialize,
    RespView: MessageView<'static> + Send,
    RespView::Owned: buffa::Message + serde::de::DeserializeOwned,
{
    let options = effective_options(config, options);
    // RPC URL is <base>/<service>/<method>; trim avoids a double slash.
    let base_str = config.base_uri.to_string();
    let base_str = base_str.trim_end_matches('/');
    let full_uri = format!("{base_str}/{service}/{method}");
    let uri: Uri = full_uri
        .parse()
        .map_err(|e| ConnectError::internal(format!("invalid URI: {e}")))?;
    // Serialize the single request message in the configured codec.
    let body = match config.codec_format {
        CodecFormat::Proto => request.encode_to_bytes(),
        CodecFormat::Json => {
            let buf = serde_json::to_vec(&request).map_err(|e| {
                ConnectError::internal(format!("failed to encode JSON request: {e}"))
            })?;
            Bytes::from(buf)
        }
    };
    let compression_for_encoder = config.request_compression.as_ref().map(|enc| {
        (
            std::sync::Arc::new(config.compression.clone()),
            enc.as_str(),
        )
    });
    let mut encoder = crate::envelope::EnvelopeEncoder::new(
        compression_for_encoder,
        config.compression_policy.with_override(options.compress),
    );
    let mut request_buf = bytes::BytesMut::new();
    tokio_util::codec::Encoder::encode(&mut encoder, body, &mut request_buf)?;
    let request_body = request_buf.freeze();
    // Deadline covers the initial exchange and later stream reads.
    let deadline = options.timeout.map(|t| std::time::Instant::now() + t);
    let mut builder = Request::builder().method(http::Method::POST).uri(uri);
    builder = add_streaming_request_headers(builder, config, options.timeout);
    // NOTE(review): unwrap assumes no earlier builder step stored an error
    // (headers_mut() returns None in that case) — confirm header values from
    // config are always valid.
    let headers = builder.headers_mut().unwrap();
    for (name, value) in &options.headers {
        headers.append(name.clone(), value.clone());
    }
    let http_request = builder
        .body(full_body(request_body))
        .map_err(|e| ConnectError::internal(format!("failed to build request: {e}")))?;
    with_deadline(deadline, async {
        let response = transport
            .send(http_request)
            .await
            .map_err(|e| ConnectError::unavailable(format!("request failed: {e}")))?;
        make_server_stream(
            response,
            config.protocol,
            &config.compression,
            config.codec_format,
            options.max_message_size,
            deadline,
        )
        .await
    })
    .await
}
/// Validates a streaming HTTP response and wraps its body in a `ServerStream`.
///
/// Three error surfaces are handled before constructing the stream:
/// - gRPC/gRPC-Web "trailers-only" responses (status carried in the initial
///   headers),
/// - non-2xx Connect responses with a JSON error body (possibly compressed),
/// - any other non-2xx status, mapped via `http_status_to_error_code`.
async fn make_server_stream<B, RespView>(
    response: Response<B>,
    protocol: Protocol,
    compression: &CompressionRegistry,
    codec_format: CodecFormat,
    max_message_size: Option<usize>,
    deadline: Option<std::time::Instant>,
) -> Result<ServerStream<B, RespView>, ConnectError>
where
    B: Body<Data = Bytes> + Send,
    B::Error: std::fmt::Display,
    RespView: MessageView<'static> + Send,
    RespView::Owned: buffa::Message + serde::de::DeserializeOwned,
{
    let response_headers = response.headers().clone();
    let status = response.status();
    // Trailers-only gRPC response: the error status arrives in the headers.
    if matches!(protocol, Protocol::Grpc | Protocol::GrpcWeb)
        && let Some(mut err) = parse_grpc_error_from_trailers(&response_headers)
    {
        err.response_headers = response_headers;
        return Err(err);
    }
    if !status.is_success() {
        if matches!(protocol, Protocol::Connect) {
            // Connect errors are JSON bodies, possibly compressed with the
            // plain HTTP content-encoding.
            let error_encoding = response_headers
                .get(http::header::CONTENT_ENCODING)
                .and_then(|v| v.to_str().ok())
                .map(|s| s.to_owned());
            let stream_max_err_size =
                max_message_size.unwrap_or(crate::service::DEFAULT_MAX_MESSAGE_SIZE);
            let body = collect_body_bounded(response.into_body(), stream_max_err_size).await?;
            let body = match error_encoding {
                Some(encoding) => {
                    match compression.decompress_with_limit(&encoding, body, stream_max_err_size) {
                        Ok(decompressed) => Some(decompressed),
                        Err(e) => {
                            // Decompression failure: fall through to the
                            // generic HTTP error below.
                            tracing::debug!(
                                "failed to decompress Connect error response ({encoding}): {e}"
                            );
                            None
                        }
                    }
                }
                None => Some(body),
            };
            if let Some(body) = body
                && let Ok(error) = serde_json::from_slice::<ConnectErrorResponse>(&body)
            {
                let code = error
                    .code
                    .as_deref()
                    .and_then(|s| s.parse::<ErrorCode>().ok())
                    .unwrap_or_else(|| http_status_to_error_code(status));
                let mut err = ConnectError::new(code, error.message.unwrap_or_default());
                err.details = error.details;
                err.response_headers = response_headers;
                return Err(err);
            }
        }
        let code = http_status_to_error_code(status);
        let mut err = ConnectError::new(code, format!("HTTP error {}", status.as_u16()));
        err.response_headers = response_headers;
        return Err(err);
    }
    // Per-message encoding header (protocol-specific name).
    let encoding = response_headers
        .get(protocol.content_encoding_header())
        .and_then(|v| v.to_str().ok())
        .map(|s| s.to_owned());
    Ok(ServerStream {
        headers: response_headers,
        body: response.into_body(),
        buf: BytesMut::new(),
        encoding,
        compression: compression.clone(),
        codec_format,
        protocol,
        max_message_size,
        deadline,
        trailers: None,
        error: None,
        done: false,
        _phantom: PhantomData,
    })
}
/// `http_body::Body` adapter over an mpsc receiver; used as the streaming
/// request body for bidi calls, where each received chunk becomes one data
/// frame.
struct ChannelBody {
    rx: tokio::sync::mpsc::Receiver<Result<Bytes, ConnectError>>,
}
impl Body for ChannelBody {
    type Data = Bytes;
    type Error = ConnectError;
    /// Forwards channel receives as HTTP data frames; channel closure ends
    /// the body stream.
    fn poll_frame(
        mut self: Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Option<Result<http_body::Frame<Bytes>, ConnectError>>> {
        match self.rx.poll_recv(cx) {
            std::task::Poll::Pending => std::task::Poll::Pending,
            std::task::Poll::Ready(None) => std::task::Poll::Ready(None),
            std::task::Poll::Ready(Some(chunk)) => {
                std::task::Poll::Ready(Some(chunk.map(http_body::Frame::data)))
            }
        }
    }
}
/// Receive-side state of a `BidiStream`.
enum RecvState<B, RespView> {
    /// Response not yet received; holds the spawned transport task.
    Pending(tokio::task::JoinHandle<Result<Response<B>, ConnectError>>),
    /// Response accepted; messages are read from the inner `ServerStream`.
    Ready(Box<ServerStream<B, RespView>>),
    /// The response task was consumed but stream construction failed.
    Failed,
}
/// Client handle for a bidirectional streaming call: request messages are
/// fed through an mpsc channel backing the request body, while responses are
/// read lazily via `message()`.
pub struct BidiStream<B, Req, RespView> {
    // `None` once `close_send` has been called (half-closes the request).
    tx: Option<tokio::sync::mpsc::Sender<Result<Bytes, ConnectError>>>,
    encoder: crate::envelope::EnvelopeEncoder,
    codec_format: CodecFormat,
    recv: RecvState<B, RespView>,
    stream_config: StreamConfig,
    // Sticky error from a failed response/stream construction, replayed on
    // every subsequent `message()` call.
    construct_err: Option<ConnectError>,
    _req: PhantomData<Req>,
}
impl<B, Req, RespView> std::fmt::Debug for BidiStream<B, Req, RespView> {
    /// Summarizes the stream's state without requiring `Debug` bounds on the
    /// generic parameters.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let state_label = match &self.recv {
            RecvState::Pending(_) => "Pending",
            RecvState::Ready(_) => "Ready",
            RecvState::Failed => "Failed",
        };
        let mut dbg = f.debug_struct("BidiStream");
        dbg.field("send_closed", &self.tx.is_none());
        dbg.field("recv_state", &state_label);
        dbg.field("protocol", &self.stream_config.protocol);
        dbg.field("codec_format", &self.stream_config.codec_format);
        dbg.field("construct_err", &self.construct_err);
        dbg.finish_non_exhaustive()
    }
}
/// Parameters needed to construct the receive-side `ServerStream` once the
/// HTTP response arrives.
#[derive(Debug)]
struct StreamConfig {
    protocol: Protocol,
    codec_format: CodecFormat,
    compression: CompressionRegistry,
    max_message_size: Option<usize>,
    deadline: Option<std::time::Instant>,
}
impl<B, Req, RespView> BidiStream<B, Req, RespView>
where
    B: Body<Data = Bytes> + Send + Unpin,
    B::Error: std::fmt::Display,
    Req: buffa::Message + serde::Serialize,
    RespView: MessageView<'static> + Send,
    RespView::Owned: buffa::Message + serde::de::DeserializeOwned,
{
    /// Sends one request message, envelope-encoding (and possibly
    /// compressing) it onto the request body channel.
    ///
    /// Errors if the client-side deadline has passed, after `close_send`, or
    /// when the receiver is gone (call `message()` to learn why).
    pub async fn send(&mut self, msg: Req) -> Result<(), ConnectError> {
        if let Some(d) = self.stream_config.deadline
            && std::time::Instant::now() >= d
        {
            return Err(ConnectError::deadline_exceeded(
                "client-side deadline exceeded",
            ));
        }
        let Some(tx) = &self.tx else {
            return Err(ConnectError::internal("send after close_send"));
        };
        let msg_bytes = match self.codec_format {
            CodecFormat::Proto => msg.encode_to_bytes(),
            CodecFormat::Json => {
                let buf = serde_json::to_vec(&msg).map_err(|e| {
                    ConnectError::internal(format!("failed to encode JSON request: {e}"))
                })?;
                Bytes::from(buf)
            }
        };
        let mut envelope_buf = BytesMut::new();
        tokio_util::codec::Encoder::encode(&mut self.encoder, msg_bytes, &mut envelope_buf)?;
        tx.send(Ok(envelope_buf.freeze())).await.map_err(|_| {
            ConnectError::unavailable("stream closed by server (call message() for error)")
        })
    }
    /// Half-closes the request side: drops the sender so the body stream
    /// ends.
    pub fn close_send(&mut self) {
        self.tx = None;
    }
    /// Receives the next response message, or `None` at end of stream.
    ///
    /// On first use this awaits the HTTP response (bounded by the deadline)
    /// and builds the receive stream; construction failures are cached in
    /// `construct_err` and returned from every later call.
    pub async fn message(&mut self) -> Result<Option<OwnedView<RespView>>, ConnectError> {
        if let Some(ref err) = self.construct_err {
            return Err(err.clone());
        }
        if matches!(self.recv, RecvState::Pending(_)) {
            // Move to Failed temporarily so the JoinHandle can be awaited by
            // value; replaced with Ready on success.
            let RecvState::Pending(handle) = std::mem::replace(&mut self.recv, RecvState::Failed)
            else {
                unreachable!()
            };
            let awaited = async move {
                handle.await.map_err(|e| {
                    ConnectError::internal(format!("transport send task panicked: {e}"))
                })?
            };
            match with_deadline(self.stream_config.deadline, awaited).await {
                Ok(response) => {
                    let cfg = &self.stream_config;
                    match make_server_stream(
                        response,
                        cfg.protocol,
                        &cfg.compression,
                        cfg.codec_format,
                        cfg.max_message_size,
                        cfg.deadline,
                    )
                    .await
                    {
                        Ok(stream) => self.recv = RecvState::Ready(Box::new(stream)),
                        Err(e) => {
                            self.construct_err = Some(e.clone());
                            return Err(e);
                        }
                    }
                }
                Err(e) => {
                    self.construct_err = Some(e.clone());
                    return Err(e);
                }
            }
        }
        match &mut self.recv {
            RecvState::Ready(stream) => stream.message().await,
            RecvState::Failed => {
                Err(ConnectError::internal("stream in failed state"))
            }
            RecvState::Pending(_) => unreachable!("transitioned above"),
        }
    }
    /// Response headers, available once the response has been received.
    #[must_use]
    pub fn headers(&self) -> Option<&http::HeaderMap> {
        match &self.recv {
            RecvState::Ready(s) => Some(s.headers()),
            _ => None,
        }
    }
    /// Trailers from the receive stream, if it has produced any.
    #[must_use]
    pub fn trailers(&self) -> Option<&http::HeaderMap> {
        match &self.recv {
            RecvState::Ready(s) => s.trailers(),
            _ => None,
        }
    }
    /// Terminal error from the receive stream, if any.
    #[must_use]
    pub fn error(&self) -> Option<&ConnectError> {
        match &self.recv {
            RecvState::Ready(s) => s.error(),
            _ => None,
        }
    }
}
/// Opens a bidirectional streaming call.
///
/// The request body is backed by an mpsc channel (capacity 32) and the HTTP
/// exchange is spawned immediately, so sends can begin before `message()` is
/// first called. The deadline is enforced when awaiting the response and on
/// each send/read.
pub async fn call_bidi_stream<T, Req, RespView>(
    transport: &T,
    config: &ClientConfig,
    service: &str,
    method: &str,
    options: CallOptions,
) -> Result<BidiStream<T::ResponseBody, Req, RespView>, ConnectError>
where
    T: ClientTransport,
    <T::ResponseBody as Body>::Error: std::fmt::Display,
    Req: buffa::Message + serde::Serialize,
    RespView: MessageView<'static> + Send,
    RespView::Owned: buffa::Message + serde::de::DeserializeOwned,
{
    let options = effective_options(config, options);
    // RPC URL is <base>/<service>/<method>; trim avoids a double slash.
    let base_str = config.base_uri.to_string();
    let base_str = base_str.trim_end_matches('/');
    let full_uri = format!("{base_str}/{service}/{method}");
    let uri: Uri = full_uri
        .parse()
        .map_err(|e| ConnectError::internal(format!("invalid URI: {e}")))?;
    let (tx, rx) = tokio::sync::mpsc::channel::<Result<Bytes, ConnectError>>(32);
    let body: ClientBody = ChannelBody { rx }.boxed();
    let compression_for_encoder = config.request_compression.as_ref().map(|enc| {
        (
            std::sync::Arc::new(config.compression.clone()),
            enc.as_str(),
        )
    });
    let encoder = crate::envelope::EnvelopeEncoder::new(
        compression_for_encoder,
        config.compression_policy.with_override(options.compress),
    );
    let deadline = options.timeout.map(|t| std::time::Instant::now() + t);
    let mut builder = Request::builder().method(http::Method::POST).uri(uri);
    builder = add_streaming_request_headers(builder, config, options.timeout);
    // NOTE(review): unwrap assumes no earlier builder step stored an error
    // (headers_mut() returns None in that case).
    let headers = builder.headers_mut().unwrap();
    for (name, value) in &options.headers {
        headers.append(name.clone(), value.clone());
    }
    let http_request = builder
        .body(body)
        .map_err(|e| ConnectError::internal(format!("failed to build request: {e}")))?;
    // Spawn the exchange so the server can read while the client sends.
    let response_fut = transport.send(http_request);
    let response_task = tokio::spawn(async move {
        response_fut
            .await
            .map_err(|e| ConnectError::unavailable(format!("request failed: {e}")))
    });
    Ok(BidiStream {
        tx: Some(tx),
        encoder,
        codec_format: config.codec_format,
        recv: RecvState::Pending(response_task),
        stream_config: StreamConfig {
            protocol: config.protocol,
            codec_format: config.codec_format,
            compression: config.compression.clone(),
            max_message_size: options.max_message_size,
            deadline,
        },
        construct_err: None,
        _req: PhantomData,
    })
}
/// Performs a client-streaming call: all `requests` are envelope-encoded up
/// front into one buffered body, and exactly one response message is
/// expected.
///
/// gRPC/gRPC-Web responses are parsed as unary; Connect responses use the
/// streaming wire format handled by `parse_connect_client_stream_response`.
pub async fn call_client_stream<T, Req, RespView>(
    transport: &T,
    config: &ClientConfig,
    service: &str,
    method: &str,
    requests: impl IntoIterator<Item = Req>,
    options: CallOptions,
) -> Result<UnaryResponse<OwnedView<RespView>>, ConnectError>
where
    T: ClientTransport,
    <T::ResponseBody as Body>::Error: std::fmt::Display,
    Req: buffa::Message + serde::Serialize,
    RespView: MessageView<'static> + Send,
    RespView::Owned: buffa::Message + serde::de::DeserializeOwned,
{
    let options = effective_options(config, options);
    // RPC URL is <base>/<service>/<method>; trim avoids a double slash.
    let base_str = config.base_uri.to_string();
    let base_str = base_str.trim_end_matches('/');
    let full_uri = format!("{base_str}/{service}/{method}");
    let uri: Uri = full_uri
        .parse()
        .map_err(|e| ConnectError::internal(format!("invalid URI: {e}")))?;
    let compression_for_encoder = config.request_compression.as_ref().map(|enc| {
        (
            std::sync::Arc::new(config.compression.clone()),
            enc.as_str(),
        )
    });
    let mut encoder = crate::envelope::EnvelopeEncoder::new(
        compression_for_encoder,
        config.compression_policy.with_override(options.compress),
    );
    // Envelope-encode every request message into one contiguous body.
    let mut body_buf = BytesMut::new();
    for request in requests {
        let msg_bytes = match config.codec_format {
            CodecFormat::Proto => request.encode_to_bytes(),
            CodecFormat::Json => {
                let buf = serde_json::to_vec(&request).map_err(|e| {
                    ConnectError::internal(format!("failed to encode JSON request: {e}"))
                })?;
                Bytes::from(buf)
            }
        };
        tokio_util::codec::Encoder::encode(&mut encoder, msg_bytes, &mut body_buf)?;
    }
    let request_body = body_buf.freeze();
    let deadline = options.timeout.map(|t| std::time::Instant::now() + t);
    let mut builder = Request::builder().method(http::Method::POST).uri(uri);
    builder = add_streaming_request_headers(builder, config, options.timeout);
    // NOTE(review): unwrap assumes no earlier builder step stored an error
    // (headers_mut() returns None in that case).
    let headers = builder.headers_mut().unwrap();
    for (name, value) in &options.headers {
        headers.append(name.clone(), value.clone());
    }
    let http_request = builder
        .body(full_body(request_body))
        .map_err(|e| ConnectError::internal(format!("failed to build request: {e}")))?;
    with_deadline(deadline, async {
        let response = transport
            .send(http_request)
            .await
            .map_err(|e| ConnectError::unavailable(format!("request failed: {e}")))?;
        match config.protocol {
            Protocol::Grpc | Protocol::GrpcWeb => {
                parse_grpc_unary_response(response, config, &options, deadline).await
            }
            Protocol::Connect => {
                parse_connect_client_stream_response(response, config, &options).await
            }
        }
    })
    .await
}
/// Parses a Connect-protocol response to a client-streaming call.
///
/// The body uses the Connect streaming wire format: enveloped data frames
/// followed by an END_STREAM frame carrying JSON metadata and an optional
/// error. Exactly one data message is required; zero or multiple messages
/// are rejected.
async fn parse_connect_client_stream_response<B, RespView>(
    response: Response<B>,
    config: &ClientConfig,
    options: &CallOptions,
) -> Result<UnaryResponse<OwnedView<RespView>>, ConnectError>
where
    B: Body<Data = Bytes> + Send,
    B::Error: std::fmt::Display,
    RespView: MessageView<'static> + Send,
    RespView::Owned: buffa::Message + serde::de::DeserializeOwned,
{
    let status = response.status();
    if !status.is_success() {
        // Non-2xx: try to decode a Connect JSON error body (possibly
        // compressed); otherwise map the bare HTTP status.
        let response_headers = response.headers().clone();
        let error_encoding = response_headers
            .get(http::header::CONTENT_ENCODING)
            .and_then(|v| v.to_str().ok())
            .map(|s| s.to_owned());
        let max_err_size = options
            .max_message_size
            .unwrap_or(crate::service::DEFAULT_MAX_MESSAGE_SIZE);
        let body = collect_body_bounded(response.into_body(), max_err_size).await?;
        let body = match error_encoding {
            Some(encoding) => {
                match config
                    .compression
                    .decompress_with_limit(&encoding, body, max_err_size)
                {
                    Ok(decompressed) => Some(decompressed),
                    Err(e) => {
                        // Fall through to the generic HTTP error below.
                        tracing::debug!(
                            "failed to decompress Connect error response ({encoding}): {e}"
                        );
                        None
                    }
                }
            }
            None => Some(body),
        };
        if let Some(body) = body
            && let Ok(error) = serde_json::from_slice::<ConnectErrorResponse>(&body)
        {
            let code = error
                .code
                .as_deref()
                .and_then(|s| s.parse::<ErrorCode>().ok())
                .unwrap_or_else(|| http_status_to_error_code(status));
            let mut err = ConnectError::new(code, error.message.unwrap_or_default());
            err.details = error.details;
            err.response_headers = response_headers;
            return Err(err);
        }
        let code = http_status_to_error_code(status);
        let mut err = ConnectError::new(code, format!("HTTP error {}", status.as_u16()));
        err.response_headers = response_headers;
        return Err(err);
    }
    let encoding = response
        .headers()
        .get(config.protocol.content_encoding_header())
        .and_then(|v| v.to_str().ok())
        .map(|s| s.to_owned());
    let resp_headers = response.headers().clone();
    let max_msg_size = options
        .max_message_size
        .unwrap_or(crate::service::DEFAULT_MAX_MESSAGE_SIZE);
    // Allow envelope headers plus trailer slack beyond the message limit.
    let body_limit = max_msg_size
        .saturating_add(2 * crate::envelope::HEADER_SIZE)
        .saturating_add(RESPONSE_BUFFER_TRAILER_SLACK);
    let body = collect_body_bounded(response.into_body(), body_limit).await?;
    let mut buf = BytesMut::from(body.as_ref());
    let mut data_envelopes: Vec<Bytes> = Vec::new();
    let mut trailers = http::HeaderMap::new();
    while !buf.is_empty() {
        // NOTE(review): a trailing partial envelope silently ends the loop —
        // presumably tolerated; confirm against the Connect spec.
        let envelope = match Envelope::decode_with_limit(&mut buf, max_msg_size)? {
            Some(env) => env,
            None => break,
        };
        if envelope.is_end_stream() {
            let end_stream_data = if envelope.is_compressed() {
                let enc = encoding.as_deref().ok_or_else(|| {
                    ConnectError::internal("received compressed END_STREAM without encoding header")
                })?;
                config
                    .compression
                    .decompress_with_limit(enc, envelope.data, max_msg_size)?
            } else {
                envelope.data
            };
            // Malformed end-stream JSON is tolerated and treated as empty.
            let end_stream: ClientEndStreamResponse =
                serde_json::from_slice(&end_stream_data).unwrap_or_default();
            if let Some(metadata) = end_stream.metadata {
                for (name, values) in metadata {
                    for value in values {
                        // Entries that are not valid HTTP names/values are
                        // skipped.
                        if let (Ok(name), Ok(value)) = (
                            http::header::HeaderName::from_bytes(name.as_bytes()),
                            http::header::HeaderValue::from_str(&value),
                        ) {
                            trailers.append(name, value);
                        }
                    }
                }
            }
            if let Some(err) = end_stream.error {
                let mut connect_error = ConnectError::new(
                    err.code
                        .as_deref()
                        .and_then(|c| c.parse().ok())
                        .unwrap_or(ErrorCode::Unknown),
                    err.message.unwrap_or_default(),
                );
                connect_error.details = err.details;
                connect_error.response_headers = resp_headers;
                connect_error.trailers = trailers;
                return Err(connect_error);
            }
        } else {
            let data = if envelope.is_compressed() {
                let enc = encoding.as_deref().ok_or_else(|| {
                    ConnectError::internal("received compressed message without encoding header")
                })?;
                config
                    .compression
                    .decompress_with_limit(enc, envelope.data, max_msg_size)
                    .map_err(|mut e| {
                        // Unimplemented from the decompressor is reported as
                        // Internal here.
                        if e.code == ErrorCode::Unimplemented {
                            e.code = ErrorCode::Internal;
                        }
                        e
                    })?
            } else {
                envelope.data
            };
            if data.len() > max_msg_size {
                return Err(ConnectError::new(
                    ErrorCode::ResourceExhausted,
                    format!("message size {} exceeds limit {}", data.len(), max_msg_size),
                ));
            }
            data_envelopes.push(data);
        }
    }
    // A client-streaming call must yield exactly one response message.
    if data_envelopes.is_empty() {
        return Err(ConnectError::unimplemented(
            "client streaming response contains no data messages",
        ));
    }
    if data_envelopes.len() > 1 {
        return Err(ConnectError::unimplemented(
            "client streaming response contains multiple data messages",
        ));
    }
    let data = data_envelopes.into_iter().next().unwrap();
    let message = decode_response_view::<RespView>(data, config.codec_format)?;
    Ok(UnaryResponse {
        headers: resp_headers,
        body: message,
        trailers,
    })
}
/// JSON wire shape of a Connect EndStreamResponse: optional error plus
/// optional trailer metadata. `Default` gives the empty form used for
/// tolerant parsing.
#[derive(serde::Deserialize, Default)]
struct ClientEndStreamResponse {
    error: Option<ClientEndStreamError>,
    metadata: Option<HashMap<String, Vec<String>>>,
}
/// Error object embedded in a Connect EndStreamResponse.
#[derive(serde::Deserialize)]
struct ClientEndStreamError {
    code: Option<String>,
    message: Option<String>,
    #[serde(default)]
    details: Vec<ErrorDetail>,
}
/// JSON body of a non-2xx Connect error response; all fields optional.
#[derive(serde::Deserialize)]
struct ConnectErrorResponse {
    #[serde(default)]
    code: Option<String>,
    #[serde(default)]
    message: Option<String>,
    #[serde(default)]
    details: Vec<ErrorDetail>,
}
/// Maps an HTTP status code to an error code, following the gRPC
/// HTTP-to-status translation table; unmapped statuses become `Unknown`.
fn http_status_to_error_code(status: http::StatusCode) -> ErrorCode {
    match status.as_u16() {
        400 => ErrorCode::Internal,
        401 => ErrorCode::Unauthenticated,
        403 => ErrorCode::PermissionDenied,
        404 => ErrorCode::Unimplemented,
        408 => ErrorCode::DeadlineExceeded,
        429 | 502 | 503 | 504 => ErrorCode::Unavailable,
        _ => ErrorCode::Unknown,
    }
}
/// Content-type for a unary request: the plain codec media type for Connect,
/// the protocol's non-streaming content type for gRPC/gRPC-Web.
fn unary_request_content_type(config: &ClientConfig) -> &'static str {
    match config.protocol {
        Protocol::Connect => config.codec_format.content_type(),
        Protocol::Grpc | Protocol::GrpcWeb => config
            .protocol
            .response_content_type(config.codec_format, false),
    }
}
/// Content-type for a streaming request (the streaming variant for every
/// protocol).
fn streaming_request_content_type(config: &ClientConfig) -> &'static str {
    config
        .protocol
        .response_content_type(config.codec_format, true)
}
/// Formats a request timeout for the protocol's timeout header.
///
/// Connect uses a plain millisecond count capped at ten digits. gRPC and
/// gRPC-Web use the `grpc-timeout` format (at most eight digits plus a unit
/// suffix): the finest unit that represents the duration exactly is chosen
/// first, then coarser units (losing precision) when the value would not fit
/// in eight digits.
#[allow(clippy::manual_is_multiple_of)]
fn format_timeout(timeout: Duration, protocol: Protocol) -> String {
    match protocol {
        Protocol::Connect => {
            // Connect timeouts are milliseconds, at most 10 digits.
            const MAX_MILLIS: u128 = 9_999_999_999;
            timeout.as_millis().min(MAX_MILLIS).to_string()
        }
        Protocol::Grpc | Protocol::GrpcWeb => {
            const MAX_DIGITS: u128 = 99_999_999;
            let total_nanos = timeout.as_nanos();
            let total_secs = timeout.as_secs() as u128;
            let total_millis = timeout.as_millis();
            let total_micros = timeout.as_micros();
            if total_nanos == 0 {
                "0n".to_owned()
            } else if total_nanos % 1_000_000_000 == 0 && total_secs <= MAX_DIGITS {
                // Exact in whole seconds.
                format!("{total_secs}S")
            } else if total_nanos % 1_000_000 == 0 && total_millis <= MAX_DIGITS {
                // Exact in whole milliseconds.
                format!("{total_millis}m")
            } else if total_nanos % 1_000 == 0 && total_micros <= MAX_DIGITS {
                // Exact in whole microseconds.
                format!("{total_micros}u")
            } else if total_nanos <= MAX_DIGITS {
                format!("{total_nanos}n")
            } else if total_micros <= MAX_DIGITS {
                // Lossy fallbacks: truncate toward the coarser unit.
                format!("{total_micros}u")
            } else if total_millis <= MAX_DIGITS {
                format!("{total_millis}m")
            } else if total_secs <= MAX_DIGITS {
                format!("{total_secs}S")
            } else {
                format!("{MAX_DIGITS}S")
            }
        }
    }
}
/// Adds content-type, protocol, compression, and timeout headers for a unary
/// request.
///
/// `applied_content_encoding` is the encoding actually applied to the request
/// body; Connect unary signals it via the plain `content-encoding` header,
/// while gRPC/gRPC-Web instead advertise `config.request_compression` via
/// `grpc-encoding`.
fn add_unary_request_headers(
    mut builder: http::request::Builder,
    config: &ClientConfig,
    timeout: Option<Duration>,
    applied_content_encoding: Option<&str>,
) -> http::request::Builder {
    builder = builder.header(
        http::header::CONTENT_TYPE,
        unary_request_content_type(config),
    );
    match config.protocol {
        Protocol::Connect => {
            builder = builder.header(connect_header::PROTOCOL_VERSION, "1");
            if let Some(encoding) = applied_content_encoding {
                builder = builder.header(http::header::CONTENT_ENCODING, encoding);
            }
            let accept = config.compression.accept_encoding_header();
            if !accept.is_empty() {
                builder = builder.header(http::header::ACCEPT_ENCODING, accept);
            }
        }
        // gRPC and gRPC-Web share the same compression headers; only gRPC
        // (HTTP/2 with trailers) additionally needs `te: trailers`.
        Protocol::Grpc | Protocol::GrpcWeb => {
            if matches!(config.protocol, Protocol::Grpc) {
                builder = builder.header("te", "trailers");
            }
            if let Some(ref encoding) = config.request_compression {
                builder = builder.header("grpc-encoding", encoding.as_str());
            }
            let accept = config.compression.accept_encoding_header();
            if !accept.is_empty() {
                builder = builder.header("grpc-accept-encoding", accept);
            }
        }
    }
    if let Some(timeout) = timeout {
        builder = builder.header(
            config.protocol.timeout_header(),
            format_timeout(timeout, config.protocol),
        );
    }
    builder
}
/// Adds content-type, protocol preamble, compression, and timeout headers
/// for a streaming request.
fn add_streaming_request_headers(
    mut builder: http::request::Builder,
    config: &ClientConfig,
    timeout: Option<Duration>,
) -> http::request::Builder {
    builder = builder.header(
        http::header::CONTENT_TYPE,
        streaming_request_content_type(config),
    );
    // Protocol-specific preamble headers.
    builder = match config.protocol {
        Protocol::Connect => builder.header(connect_header::PROTOCOL_VERSION, "1"),
        Protocol::Grpc => builder.header("te", "trailers"),
        Protocol::GrpcWeb => builder,
    };
    // Compression negotiation: advertise what we send and what we accept,
    // using the protocol's header names.
    if let Some(ref requested) = config.request_compression {
        builder = builder.header(
            config.protocol.content_encoding_header(),
            requested.as_str(),
        );
    }
    let accepted = config.compression.accept_encoding_header();
    if !accepted.is_empty() {
        builder = builder.header(config.protocol.accept_encoding_header(), accepted);
    }
    if let Some(t) = timeout {
        builder = builder.header(
            config.protocol.timeout_header(),
            format_timeout(t, config.protocol),
        );
    }
    builder
}
/// Builds a `ConnectError` from gRPC trailers; returns `None` when
/// `grpc-status` is absent, unparseable, or reports OK (0).
///
/// Percent-decodes `grpc-message`, decodes `grpc-status-details-bin`
/// (standard base64, padded or unpadded), and copies every other trailer
/// onto the error.
fn parse_grpc_error_from_trailers(trailers: &http::HeaderMap) -> Option<ConnectError> {
    let raw_status = trailers.get("grpc-status")?.to_str().ok()?;
    let status: u32 = raw_status.parse().ok()?;
    if status == 0 {
        // OK: not an error.
        return None;
    }
    let code = ErrorCode::from_grpc_code(status).unwrap_or(ErrorCode::Unknown);
    let message = trailers
        .get("grpc-message")
        .and_then(|v| v.to_str().ok())
        .map(grpc_percent_decode)
        .unwrap_or_default();
    let mut err = ConnectError::new(code, message);
    if let Some(b64) = trailers
        .get("grpc-status-details-bin")
        .and_then(|v| v.to_str().ok())
    {
        use base64::Engine;
        let decoded = base64::engine::general_purpose::STANDARD
            .decode(b64)
            .or_else(|_| base64::engine::general_purpose::STANDARD_NO_PAD.decode(b64));
        if let Ok(bytes) = decoded {
            err.details = crate::grpc_status::decode_details(&bytes);
        }
    }
    for (key, value) in trailers.iter() {
        if !matches!(
            key.as_str(),
            "grpc-status" | "grpc-message" | "grpc-status-details-bin"
        ) {
            err.trailers.append(key, value.clone());
        }
    }
    Some(err)
}
/// Collects an entire response body into memory, erroring as soon as the
/// accumulated size would exceed `max_size`. Non-data frames are ignored.
async fn collect_body_bounded<B>(body: B, max_size: usize) -> Result<Bytes, ConnectError>
where
    B: Body<Data = Bytes>,
    B::Error: std::fmt::Display,
{
    let mut body = std::pin::pin!(body);
    let mut collected = BytesMut::new();
    while let Some(frame) = std::future::poll_fn(|cx| body.as_mut().poll_frame(cx)).await {
        let frame = frame.map_err(|e| {
            ConnectError::internal(format!("failed to read response body: {e}",))
        })?;
        let Ok(chunk) = frame.into_data() else {
            // Trailers frame: not part of the body payload.
            continue;
        };
        if collected.len().saturating_add(chunk.len()) > max_size {
            return Err(ConnectError::new(
                ErrorCode::ResourceExhausted,
                format!("response body size exceeds limit {max_size}"),
            ));
        }
        collected.extend_from_slice(&chunk);
    }
    Ok(collected.freeze())
}
/// Decodes a percent-encoded `grpc-message` value; invalid escape sequences
/// pass through and invalid UTF-8 is replaced lossily.
fn grpc_percent_decode(s: &str) -> String {
    let decoded = percent_encoding::percent_decode_str(s);
    decoded.decode_utf8_lossy().into_owned()
}
/// Parses a gRPC-Web trailer frame into a header map.
///
/// Frame layout: one flag byte (0x80 = trailer frame, 0x01 = compressed), a
/// 4-byte big-endian payload length, then `key: value` lines separated by
/// `\n` (optionally `\r\n`). Returns `None` for anything malformed,
/// oversized, or compressed without a decompressor — callers treat that as
/// "no trailers".
fn parse_grpc_web_trailer_frame_with_compression(
    data: &[u8],
    decompression: Option<(&CompressionRegistry, &str)>,
) -> Option<http::HeaderMap> {
    // Need at least the 5-byte prefix, and the trailer flag bit must be set.
    if data.len() < 5 || data[0] & 0x80 == 0 {
        return None;
    }
    let is_compressed = data[0] & 0x01 != 0;
    let len = u32::from_be_bytes([data[1], data[2], data[3], data[4]]) as usize;
    // Cap trailer payloads to guard against oversized frames.
    const MAX_TRAILER_SIZE: usize = 1024 * 1024;
    if len > MAX_TRAILER_SIZE || data.len() < 5 + len {
        return None;
    }
    let raw_payload = &data[5..5 + len];
    // Owns decompressed bytes so `payload` can borrow from either source.
    let payload_bytes;
    let payload = if is_compressed {
        if let Some((registry, encoding)) = decompression {
            payload_bytes = registry
                .decompress_with_limit(
                    encoding,
                    Bytes::copy_from_slice(raw_payload),
                    MAX_TRAILER_SIZE,
                )
                .ok()?;
            std::str::from_utf8(&payload_bytes).ok()?
        } else {
            // Compressed trailer but no decompressor available.
            return None;
        }
    } else {
        std::str::from_utf8(raw_payload).ok()?
    };
    let mut headers = http::HeaderMap::new();
    for line in payload.split('\n') {
        let line = line.trim_end_matches('\r');
        if line.is_empty() {
            continue;
        }
        // Lines with invalid header names/values are skipped, not fatal.
        if let Some((key, value)) = line.split_once(':')
            && let (Ok(name), Ok(val)) = (
                http::header::HeaderName::from_bytes(key.trim().as_bytes()),
                http::HeaderValue::from_str(value.trim()),
            )
        {
            headers.append(name, val);
        }
    }
    Some(headers)
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_client_config() {
    // Builder-style config: JSON codec plus gzip request compression.
    let config = ClientConfig::new("http://localhost:8080".parse().unwrap())
        .json()
        .compress_requests("gzip");
    assert_eq!(config.codec_format, CodecFormat::Json);
    assert_eq!(config.request_compression, Some("gzip".to_string()));
}
#[cfg(feature = "client")]
#[test]
fn test_http_client_plaintext_creation() {
    // Constructors should not panic.
    let _client = HttpClient::plaintext();
    let _client = HttpClient::plaintext_http2_only();
}
#[test]
fn client_types_are_debug() {
    // Compile-time check that the public client types implement Debug.
    fn assert_debug<T: std::fmt::Debug>() {}
    assert_debug::<UnaryResponse<()>>();
    assert_debug::<ServerStream<http_body_util::Empty<Bytes>, ()>>();
    assert_debug::<BidiStream<http_body_util::Empty<Bytes>, (), ()>>();
    #[cfg(feature = "client")]
    assert_debug::<HttpClient>();
}
#[cfg(feature = "client")]
#[tokio::test]
async fn http_client_plaintext_rejects_https() {
    // A plaintext client must refuse https URIs with a hint to use with_tls.
    let client = HttpClient::plaintext();
    let req = Request::builder()
        .uri("https://localhost:8080/foo")
        .body(full_body(Bytes::new()))
        .unwrap();
    let err = client.send(req).await.unwrap_err();
    assert_eq!(err.code, ErrorCode::InvalidArgument);
    assert!(err.message.as_deref().unwrap().contains("with_tls"));
}
#[cfg(all(feature = "client", feature = "client-tls"))]
#[tokio::test]
async fn http_client_with_tls_rejects_http() {
    // Conversely, a TLS client must refuse plain http URIs.
    let tls_config = std::sync::Arc::new(
        rustls::ClientConfig::builder()
            .with_root_certificates(rustls::RootCertStore::empty())
            .with_no_client_auth(),
    );
    let client = HttpClient::with_tls(tls_config);
    let req = Request::builder()
        .uri("http://localhost:8080/foo")
        .body(full_body(Bytes::new()))
        .unwrap();
    let err = client.send(req).await.unwrap_err();
    assert_eq!(err.code, ErrorCode::InvalidArgument);
    assert!(err.message.as_deref().unwrap().contains("plaintext"));
}
#[cfg(all(feature = "client", feature = "client-tls"))]
#[test]
fn http_client_with_tls_construction() {
    // Construction with an empty root store should not panic.
    let tls_config = std::sync::Arc::new(
        rustls::ClientConfig::builder()
            .with_root_certificates(rustls::RootCertStore::empty())
            .with_no_client_auth(),
    );
    let _client = HttpClient::with_tls(tls_config);
}
#[test]
fn test_format_timeout_connect() {
    // Connect: plain millisecond count, no unit suffix.
    assert_eq!(
        format_timeout(Duration::from_millis(5000), Protocol::Connect),
        "5000"
    );
    assert_eq!(
        format_timeout(Duration::from_secs(0), Protocol::Connect),
        "0"
    );
}
#[test]
fn test_format_timeout_connect_caps_at_10_digits() {
    // Connect timeout values are capped at ten digits of milliseconds.
    assert_eq!(
        format_timeout(Duration::from_millis(9_999_999_999), Protocol::Connect),
        "9999999999"
    );
    assert_eq!(
        format_timeout(Duration::from_secs(365 * 86400), Protocol::Connect),
        "9999999999"
    );
    assert_eq!(
        format_timeout(Duration::MAX, Protocol::Connect),
        "9999999999"
    );
}
#[test]
fn test_format_timeout_grpc_seconds() {
    assert_eq!(
        format_timeout(Duration::from_secs(30), Protocol::Grpc),
        "30S"
    );
}
#[test]
fn test_format_timeout_grpc_milliseconds() {
    assert_eq!(
        format_timeout(Duration::from_millis(500), Protocol::Grpc),
        "500m"
    );
}
#[test]
fn test_format_timeout_grpc_microseconds() {
    assert_eq!(
        format_timeout(Duration::from_micros(100), Protocol::Grpc),
        "100u"
    );
}
#[test]
fn test_format_timeout_grpc_nanoseconds() {
    assert_eq!(
        format_timeout(Duration::from_nanos(999), Protocol::Grpc),
        "999n"
    );
}
#[test]
fn test_format_timeout_grpc_zero() {
    assert_eq!(format_timeout(Duration::from_secs(0), Protocol::Grpc), "0n");
}
#[test]
fn test_format_timeout_grpc_8_digit_limit() {
    // grpc-timeout values are capped at eight digits.
    assert_eq!(
        format_timeout(Duration::from_secs(99_999_999), Protocol::Grpc),
        "99999999S"
    );
    assert_eq!(
        format_timeout(Duration::from_secs(100_000_000), Protocol::Grpc),
        "99999999S"
    );
}
#[test]
fn test_format_timeout_grpc_web_same_as_grpc() {
    assert_eq!(
        format_timeout(Duration::from_millis(500), Protocol::GrpcWeb),
        "500m"
    );
}
#[test]
fn test_format_timeout_grpc_subsecond_nanosecond_residue() {
    // Durations with residue fall back to the coarser (lossy) unit.
    assert_eq!(
        format_timeout(Duration::from_nanos(100_000_001), Protocol::Grpc),
        "100000u"
    );
    assert_eq!(
        format_timeout(Duration::from_nanos(100_000_000), Protocol::Grpc),
        "100m"
    );
    assert_eq!(
        format_timeout(Duration::from_nanos(200_000_000_001), Protocol::Grpc),
        "200000m"
    );
}
#[test]
fn test_grpc_percent_decode_passthrough() {
    // Unescaped text is returned unchanged.
    assert_eq!(grpc_percent_decode("hello world"), "hello world");
}
#[test]
fn test_grpc_percent_decode_percent() {
    assert_eq!(grpc_percent_decode("100%25"), "100%");
}
#[test]
fn test_grpc_percent_decode_newlines() {
    assert_eq!(grpc_percent_decode("a%0Ab"), "a\nb");
    assert_eq!(grpc_percent_decode("a%0D%0Ab"), "a\r\nb");
}
#[test]
fn test_grpc_percent_decode_utf8_multibyte() {
    // Multi-byte UTF-8 sequences reassemble correctly.
    assert_eq!(grpc_percent_decode("caf%C3%A9"), "café");
    assert_eq!(grpc_percent_decode("%E2%98%BA"), "☺");
    assert_eq!(grpc_percent_decode("%F0%9F%98%88"), "😈");
}
#[test]
fn test_grpc_percent_decode_partial_percent() {
    // Incomplete escapes pass through unchanged.
    assert_eq!(grpc_percent_decode("100%"), "100%");
    assert_eq!(grpc_percent_decode("a%2"), "a%2");
}
#[test]
fn test_grpc_percent_decode_invalid_hex() {
    assert_eq!(grpc_percent_decode("a%ZZb"), "a%ZZb");
}
#[test]
fn test_parse_grpc_error_ok_returns_none() {
    // grpc-status 0 (OK) is not an error, so no error value is produced.
    let mut trailers = http::HeaderMap::new();
    trailers.insert("grpc-status", http::HeaderValue::from_static("0"));
    let parsed = parse_grpc_error_from_trailers(&trailers);
    assert!(parsed.is_none());
}
#[test]
fn test_parse_grpc_error_missing_status_returns_none() {
    // Without a grpc-status trailer there is nothing to interpret as an error.
    let empty = http::HeaderMap::new();
    assert!(parse_grpc_error_from_trailers(&empty).is_none());
}
#[test]
fn test_parse_grpc_error_with_code_and_message() {
    // Status 5 maps to NotFound, and grpc-message is percent-decoded.
    let mut trailers = http::HeaderMap::new();
    trailers.insert("grpc-status", http::HeaderValue::from_static("5"));
    trailers.insert("grpc-message", http::HeaderValue::from_static("not%20found"));
    let err = parse_grpc_error_from_trailers(&trailers).expect("status 5 must yield an error");
    assert_eq!(err.code, ErrorCode::NotFound);
    assert_eq!(err.message.as_deref(), Some("not found"));
}
#[test]
fn test_parse_grpc_error_unknown_code() {
    // Out-of-range status numbers collapse to ErrorCode::Unknown.
    let mut trailers = http::HeaderMap::new();
    trailers.insert("grpc-status", http::HeaderValue::from_static("99"));
    let err = parse_grpc_error_from_trailers(&trailers).expect("non-zero status yields an error");
    assert_eq!(err.code, ErrorCode::Unknown);
}
#[test]
fn test_parse_grpc_error_custom_trailers() {
    // Non-grpc trailers are preserved on the resulting error for callers.
    let mut trailers = http::HeaderMap::new();
    trailers.insert("grpc-status", http::HeaderValue::from_static("13"));
    trailers.insert("x-custom", http::HeaderValue::from_static("value"));
    let err = parse_grpc_error_from_trailers(&trailers).expect("status 13 yields an error");
    assert_eq!(err.code, ErrorCode::Internal);
    let custom = err.trailers.get("x-custom").expect("custom trailer kept");
    assert_eq!(custom.to_str().unwrap(), "value");
}
#[test]
fn test_parse_grpc_web_trailer_uncompressed() {
    // An uncompressed trailer frame (flag 0x80 + u32 BE length) parses
    // into a header map.
    let payload: &[u8] = b"grpc-status: 0\r\n";
    let len = (payload.len() as u32).to_be_bytes();
    let frame = [&[0x80u8][..], &len[..], payload].concat();
    let headers = parse_grpc_web_trailer_frame_with_compression(&frame, None).unwrap();
    assert_eq!(headers.get("grpc-status").unwrap().to_str().unwrap(), "0");
}
#[test]
fn test_parse_grpc_web_trailer_with_error() {
    // Both grpc-status and grpc-message survive trailer-frame parsing.
    let payload: &[u8] = b"grpc-status: 13\r\ngrpc-message: internal error\r\n";
    let len = (payload.len() as u32).to_be_bytes();
    let frame = [&[0x80u8][..], &len[..], payload].concat();
    let headers = parse_grpc_web_trailer_frame_with_compression(&frame, None).unwrap();
    assert_eq!(headers.get("grpc-status").unwrap().to_str().unwrap(), "13");
    let message = headers.get("grpc-message").unwrap();
    assert_eq!(message.to_str().unwrap(), "internal error");
}
#[test]
fn test_parse_grpc_web_trailer_truncated() {
    // A frame shorter than the 5-byte prefix cannot be parsed.
    let truncated = [0x80u8, 0, 0];
    assert!(parse_grpc_web_trailer_frame_with_compression(&truncated, None).is_none());
}
#[test]
fn test_parse_grpc_web_trailer_not_trailer() {
    // A data frame (flag 0x00) is not a trailer frame and must be rejected.
    let data_frame = [0x00u8, 0, 0, 0, 5, b'h', b'e', b'l', b'l', b'o'];
    assert!(parse_grpc_web_trailer_frame_with_compression(&data_frame, None).is_none());
}
#[test]
fn test_parse_grpc_web_trailer_compressed_no_registry() {
    // A compressed trailer (flag 0x81) cannot be decoded without a
    // compression registry, so parsing yields None.
    let payload: &[u8] = b"grpc-status: 0\r\n";
    let len = (payload.len() as u32).to_be_bytes();
    let frame = [&[0x81u8][..], &len[..], payload].concat();
    assert!(parse_grpc_web_trailer_frame_with_compression(&frame, None).is_none());
}
#[test]
fn test_parse_grpc_web_trailer_newline_only() {
    // Bare "\n" line endings are accepted as well as "\r\n".
    let payload: &[u8] = b"grpc-status: 0\n";
    let len = (payload.len() as u32).to_be_bytes();
    let frame = [&[0x80u8][..], &len[..], payload].concat();
    let headers = parse_grpc_web_trailer_frame_with_compression(&frame, None).unwrap();
    assert_eq!(headers.get("grpc-status").unwrap().to_str().unwrap(), "0");
}
#[test]
fn test_unary_request_content_type_connect() {
    // Connect unary requests use bare application/{proto,json} content types.
    let proto_config = ClientConfig::new("http://localhost".parse().unwrap());
    assert_eq!(unary_request_content_type(&proto_config), "application/proto");
    let json_config = proto_config.codec_format(CodecFormat::Json);
    assert_eq!(unary_request_content_type(&json_config), "application/json");
}
#[test]
fn test_unary_request_content_type_grpc() {
    // gRPC unary requests carry the application/grpc+<codec> content type.
    let proto_config =
        ClientConfig::new("http://localhost".parse().unwrap()).protocol(Protocol::Grpc);
    assert_eq!(unary_request_content_type(&proto_config), "application/grpc+proto");
    let json_config = proto_config.codec_format(CodecFormat::Json);
    assert_eq!(unary_request_content_type(&json_config), "application/grpc+json");
}
#[test]
fn test_streaming_request_content_type() {
    // Streaming content types are protocol-specific: connect+, grpc+, grpc-web+.
    let mut config = ClientConfig::new("http://localhost".parse().unwrap());
    assert_eq!(streaming_request_content_type(&config), "application/connect+proto");
    config = config.protocol(Protocol::Grpc);
    assert_eq!(streaming_request_content_type(&config), "application/grpc+proto");
    config = config.protocol(Protocol::GrpcWeb);
    assert_eq!(streaming_request_content_type(&config), "application/grpc-web+proto");
}
#[test]
fn test_http_status_to_error_code() {
    // Spot-check the HTTP-status -> error-code mapping table.
    let expectations = [
        (http::StatusCode::BAD_REQUEST, ErrorCode::Internal),
        (http::StatusCode::UNAUTHORIZED, ErrorCode::Unauthenticated),
        (http::StatusCode::FORBIDDEN, ErrorCode::PermissionDenied),
        (http::StatusCode::NOT_FOUND, ErrorCode::Unimplemented),
        (http::StatusCode::SERVICE_UNAVAILABLE, ErrorCode::Unavailable),
        (http::StatusCode::INTERNAL_SERVER_ERROR, ErrorCode::Unknown),
    ];
    for (status, expected) in expectations {
        assert_eq!(http_status_to_error_code(status), expected);
    }
}
#[test]
fn test_add_unary_request_headers_connect() {
    // Connect unary: content-type + connect-protocol-version, and no TE header.
    let config = ClientConfig::new("http://localhost".parse().unwrap());
    let req = add_unary_request_headers(http::Request::builder(), &config, None, None)
        .body(())
        .unwrap();
    let headers = req.headers();
    assert_eq!(headers.get("content-type").unwrap(), "application/proto");
    assert_eq!(headers.get("connect-protocol-version").unwrap(), "1");
    assert!(headers.get("te").is_none());
}
#[test]
fn test_add_unary_request_headers_grpc() {
    // gRPC unary: grpc content type plus "te: trailers", and no Connect
    // protocol-version header.
    let config =
        ClientConfig::new("http://localhost".parse().unwrap()).protocol(Protocol::Grpc);
    let req = add_unary_request_headers(http::Request::builder(), &config, None, None)
        .body(())
        .unwrap();
    let headers = req.headers();
    assert_eq!(headers.get("content-type").unwrap(), "application/grpc+proto");
    assert_eq!(headers.get("te").unwrap(), "trailers");
    assert!(headers.get("connect-protocol-version").is_none());
}
#[test]
fn test_add_unary_request_headers_grpc_web() {
    // gRPC-Web rides plain HTTP/1.1-compatible transports: no TE header,
    // and no Connect protocol-version header either.
    let config =
        ClientConfig::new("http://localhost".parse().unwrap()).protocol(Protocol::GrpcWeb);
    let req = add_unary_request_headers(http::Request::builder(), &config, None, None)
        .body(())
        .unwrap();
    let headers = req.headers();
    assert_eq!(headers.get("content-type").unwrap(), "application/grpc-web+proto");
    assert!(headers.get("te").is_none());
    assert!(headers.get("connect-protocol-version").is_none());
}
#[test]
fn test_add_unary_request_headers_with_timeout() {
    // A per-call timeout surfaces as a grpc-timeout header under gRPC.
    let config =
        ClientConfig::new("http://localhost".parse().unwrap()).protocol(Protocol::Grpc);
    let timeout = Some(Duration::from_millis(500));
    let req = add_unary_request_headers(http::Request::builder(), &config, timeout, None)
        .body(())
        .unwrap();
    assert_eq!(req.headers().get("grpc-timeout").unwrap(), "500m");
}
#[tokio::test]
async fn with_deadline_none_passes_through() {
    // With no deadline configured, the inner future's result is returned
    // unchanged.
    let out: Result<i32, ConnectError> = with_deadline(None, async { Ok(42) }).await;
    assert_eq!(out.unwrap(), 42);
}
#[tokio::test]
async fn with_deadline_completes_before_deadline() {
    // A future that finishes well before the deadline resolves normally.
    let deadline = std::time::Instant::now() + Duration::from_secs(10);
    let out: Result<i32, ConnectError> =
        with_deadline(Some(deadline), async { Ok(42) }).await;
    assert_eq!(out.unwrap(), 42);
}
#[tokio::test(start_paused = true)]
async fn with_deadline_fires_on_slow_future() {
    // With tokio time paused, the 10s sleep outlasts the 100ms deadline,
    // so the call must resolve to DeadlineExceeded.
    let deadline = std::time::Instant::now() + Duration::from_millis(100);
    let slow = async {
        tokio::time::sleep(Duration::from_secs(10)).await;
        Ok::<i32, ConnectError>(42)
    };
    let err = with_deadline(Some(deadline), slow).await.unwrap_err();
    assert_eq!(err.code, ErrorCode::DeadlineExceeded);
}
#[tokio::test(start_paused = true)]
async fn with_deadline_already_passed_returns_immediately() {
    // A deadline in the past must short-circuit even a never-completing
    // future.
    let deadline = std::time::Instant::now() - Duration::from_secs(1);
    let out: Result<i32, ConnectError> =
        with_deadline(Some(deadline), std::future::pending()).await;
    assert_eq!(out.unwrap_err().code, ErrorCode::DeadlineExceeded);
}
#[tokio::test]
async fn with_deadline_propagates_inner_error() {
    // An error from the inner future wins over an unexpired deadline.
    let deadline = std::time::Instant::now() + Duration::from_secs(10);
    let failing = async { Err::<i32, _>(ConnectError::internal("inner")) };
    let err = with_deadline(Some(deadline), failing).await.unwrap_err();
    assert_eq!(err.code, ErrorCode::Internal);
}
#[tokio::test]
async fn channel_body_delivers_frames_then_eof() {
    // Frames sent on the channel stream out in order; dropping the sender
    // signals end-of-stream.
    let (sender, receiver) = tokio::sync::mpsc::channel(4);
    let body = ChannelBody { rx: receiver };
    for chunk in [&b"hello"[..], &b"world"[..]] {
        sender.send(Ok(Bytes::from_static(chunk))).await.unwrap();
    }
    drop(sender);
    let collected = body.collect().await.unwrap().to_bytes();
    assert_eq!(&collected[..], b"helloworld");
}
#[tokio::test]
async fn channel_body_surfaces_error() {
    // An Err frame queued on the channel must surface through poll_frame.
    let (sender, receiver) = tokio::sync::mpsc::channel(4);
    let mut body = ChannelBody { rx: receiver };
    sender.send(Err(ConnectError::internal("boom"))).await.unwrap();
    drop(sender);
    let polled = std::future::poll_fn(|cx| Pin::new(&mut body).poll_frame(cx)).await;
    assert!(matches!(polled, Some(Err(_))));
}
#[tokio::test]
async fn collect_body_bounded_within_limit() {
    // A body smaller than the limit is collected in full.
    let body = Full::new(Bytes::from_static(b"hello"));
    let collected = collect_body_bounded(body, 10).await.unwrap();
    assert_eq!(&collected[..], b"hello");
}
#[tokio::test]
async fn collect_body_bounded_at_exact_limit() {
    // A body exactly at the limit still collects (the limit is inclusive).
    let body = Full::new(Bytes::from_static(b"hello"));
    let collected = collect_body_bounded(body, 5).await.unwrap();
    assert_eq!(&collected[..], b"hello");
}
#[tokio::test]
async fn collect_body_bounded_exceeds_limit() {
    // Exceeding the limit yields ResourceExhausted rather than truncation.
    let oversized = Full::new(Bytes::from_static(b"hello world"));
    let err = collect_body_bounded(oversized, 5).await.unwrap_err();
    assert_eq!(err.code, ErrorCode::ResourceExhausted);
}
#[tokio::test]
async fn collect_body_bounded_empty() {
    // An empty body collects successfully even under a zero-byte limit.
    let empty = Full::new(Bytes::new());
    let collected = collect_body_bounded(empty, 0).await.unwrap();
    assert!(collected.is_empty());
}
#[tokio::test]
async fn collect_body_bounded_multi_frame_exceeds_mid_stream() {
    // The limit is enforced across frames: 9 bytes total trips a 7-byte cap.
    let (sender, receiver) = tokio::sync::mpsc::channel(4);
    let body = ChannelBody { rx: receiver };
    for chunk in [&b"aaa"[..], &b"bbb"[..], &b"ccc"[..]] {
        sender.send(Ok(Bytes::from_static(chunk))).await.unwrap();
    }
    drop(sender);
    let err = collect_body_bounded(body, 7).await.unwrap_err();
    assert_eq!(err.code, ErrorCode::ResourceExhausted);
}
#[tokio::test]
async fn collect_body_bounded_multi_frame_within_limit() {
    // Multiple frames under the limit are concatenated in arrival order.
    let (sender, receiver) = tokio::sync::mpsc::channel(4);
    let body = ChannelBody { rx: receiver };
    for chunk in [&b"foo"[..], &b"bar"[..]] {
        sender.send(Ok(Bytes::from_static(chunk))).await.unwrap();
    }
    drop(sender);
    let collected = collect_body_bounded(body, 10).await.unwrap();
    assert_eq!(&collected[..], b"foobar");
}
#[tokio::test]
async fn collect_body_bounded_propagates_body_error() {
    // A body-level error aborts collection and is returned unchanged.
    let (sender, receiver) = tokio::sync::mpsc::channel(4);
    let body = ChannelBody { rx: receiver };
    sender.send(Err(ConnectError::internal("io"))).await.unwrap();
    drop(sender);
    let err = collect_body_bounded(body, 1024).await.unwrap_err();
    assert_eq!(err.code, ErrorCode::Internal);
}
#[test]
fn test_add_streaming_request_headers_grpc() {
    // Streaming gRPC requests get the grpc content type and "te: trailers".
    let config =
        ClientConfig::new("http://localhost".parse().unwrap()).protocol(Protocol::Grpc);
    let req = add_streaming_request_headers(http::Request::builder(), &config, None)
        .body(())
        .unwrap();
    let headers = req.headers();
    assert_eq!(headers.get("content-type").unwrap(), "application/grpc+proto");
    assert_eq!(headers.get("te").unwrap(), "trailers");
}
#[test]
fn test_client_config_protocol() {
    // protocol() overrides the configured wire protocol.
    let grpc_config = ClientConfig::new("http://localhost".parse().unwrap())
        .protocol(Protocol::Grpc);
    assert_eq!(grpc_config.protocol, Protocol::Grpc);
}
#[test]
fn test_client_config_default_protocol() {
    // A fresh config defaults to the Connect protocol.
    let fresh = ClientConfig::new("http://localhost".parse().unwrap());
    assert_eq!(fresh.protocol, Protocol::Connect);
}
/// Builds the unary request headers for `protocol` with request compression
/// configured as gzip, passing `applied_encoding` as the encoding that was
/// actually applied to the body, and returns the resulting header map.
fn headers_for(protocol: Protocol, applied_encoding: Option<&str>) -> http::HeaderMap {
    let config = ClientConfig::new("http://localhost".parse().unwrap())
        .protocol(protocol)
        .compress_requests("gzip");
    let request = add_unary_request_headers(http::Request::builder(), &config, None, applied_encoding)
        .body(())
        .unwrap();
    request.headers().clone()
}
#[test]
fn connect_unary_no_content_encoding_when_compression_skipped() {
    // If the policy skipped compression, Content-Encoding must stay absent;
    // advertising gzip for an uncompressed body would break decoding.
    let headers = headers_for(Protocol::Connect, None);
    assert!(
        !headers.contains_key(http::header::CONTENT_ENCODING),
        "Content-Encoding must not be set when compression policy skipped the body"
    );
}
#[test]
fn connect_unary_content_encoding_when_compressed() {
    // When gzip was actually applied, Content-Encoding advertises it.
    let headers = headers_for(Protocol::Connect, Some("gzip"));
    let encoding = headers.get(http::header::CONTENT_ENCODING).unwrap();
    assert_eq!(encoding, "gzip");
}
#[test]
fn grpc_unary_encoding_header_independent_of_applied() {
    // gRPC's grpc-encoding header reflects configuration, not whether the
    // body was actually compressed.
    let headers = headers_for(Protocol::Grpc, None);
    assert_eq!(headers.get("grpc-encoding").unwrap(), "gzip");
}
/// Shared fixture: a default `ClientConfig` pointing at a local endpoint.
fn test_config() -> ClientConfig {
    ClientConfig::new("http://localhost:8080".parse().unwrap())
}
#[test]
fn effective_options_uses_config_defaults_when_options_unset() {
    // Unset call options fall back to the config-level defaults.
    let config = test_config()
        .default_timeout(Duration::from_secs(30))
        .default_max_message_size(1024)
        .default_header("x-trace-id", "cfg-trace");
    let effective = effective_options(&config, CallOptions::default());
    assert_eq!(effective.timeout, Some(Duration::from_secs(30)));
    assert_eq!(effective.max_message_size, Some(1024));
    assert_eq!(effective.headers.get("x-trace-id").unwrap(), "cfg-trace");
}
#[test]
fn effective_options_options_override_config_defaults() {
    // Per-call options take precedence over the config defaults.
    let config = test_config()
        .default_timeout(Duration::from_secs(30))
        .default_max_message_size(1024);
    let per_call = CallOptions::default()
        .with_timeout(Duration::from_secs(5))
        .with_max_message_size(512);
    let effective = effective_options(&config, per_call);
    assert_eq!(effective.timeout, Some(Duration::from_secs(5)));
    assert_eq!(effective.max_message_size, Some(512));
}
#[test]
fn effective_options_compress_has_no_config_default() {
    // compress is option-only: it comes straight from the call options.
    let per_call = CallOptions::default().with_compression(true);
    let effective = effective_options(&test_config(), per_call);
    assert_eq!(effective.compress, Some(true));
}
#[test]
fn merge_headers_options_override_config_same_name() {
    // A name present in both maps keeps only the option value.
    let mut cfg = http::HeaderMap::new();
    cfg.insert("x-token", "cfg-token".parse().unwrap());
    let mut opts = http::HeaderMap::new();
    opts.insert("x-token", "opt-token".parse().unwrap());
    let merged = merge_headers(&cfg, opts);
    let values: Vec<_> = merged.get_all("x-token").iter().collect();
    assert_eq!(values.len(), 1);
    assert_eq!(values[0], "opt-token");
}
#[test]
fn merge_headers_config_only_names_preserved() {
    // Names unique to either side survive the merge.
    let mut cfg = http::HeaderMap::new();
    cfg.insert("x-cfg-only", "kept".parse().unwrap());
    let mut opts = http::HeaderMap::new();
    opts.insert("x-opt-only", "also-kept".parse().unwrap());
    let merged = merge_headers(&cfg, opts);
    assert_eq!(merged.get("x-cfg-only").unwrap(), "kept");
    assert_eq!(merged.get("x-opt-only").unwrap(), "also-kept");
}
#[test]
fn merge_headers_options_multivalue_replaces_config() {
    // Every config value for a name is dropped once options supply their own.
    let mut cfg = http::HeaderMap::new();
    for value in ["cfg-a", "cfg-b"] {
        cfg.append("x-thing", value.parse().unwrap());
    }
    let mut opts = http::HeaderMap::new();
    for value in ["opt-1", "opt-2"] {
        opts.append("x-thing", value.parse().unwrap());
    }
    let merged = merge_headers(&cfg, opts);
    let values: Vec<_> = merged
        .get_all("x-thing")
        .iter()
        .map(|v| v.to_str().unwrap())
        .collect();
    assert_eq!(values, vec!["opt-1", "opt-2"]);
}
#[test]
fn merge_headers_empty_config_fast_path() {
    // Empty config side: the merge result is just the options map.
    let mut opts = http::HeaderMap::new();
    opts.insert("x", "y".parse().unwrap());
    let merged = merge_headers(&http::HeaderMap::new(), opts);
    assert_eq!(merged.get("x").unwrap(), "y");
}
#[test]
fn merge_headers_empty_options_fast_path() {
    // Empty options side: config headers pass through unchanged.
    let mut cfg = http::HeaderMap::new();
    cfg.insert("x", "y".parse().unwrap());
    let merged = merge_headers(&cfg, http::HeaderMap::new());
    assert_eq!(merged.get("x").unwrap(), "y");
}
#[test]
fn get_base64_encoding_matches_rfc4648_urlsafe_no_pad() {
    // Bytes chosen so the standard alphabet would emit '+'; the URL-safe
    // no-pad engine must avoid '+', '/', and '=' and still round-trip.
    use base64::Engine;
    let engine = &base64::engine::general_purpose::URL_SAFE_NO_PAD;
    let encoded = engine.encode(b"\xfa\xfb\xfc");
    assert!(!encoded.contains('+'), "URL-safe must not contain +");
    assert!(!encoded.contains('/'), "URL-safe must not contain /");
    assert!(!encoded.contains('='), "no-pad must not contain =");
    let decoded = engine.decode(&encoded).unwrap();
    assert_eq!(decoded, b"\xfa\xfb\xfc");
}
}