use std::convert::Infallible;
use std::pin::Pin;
use std::sync::Arc;
use std::task::Context as TaskContext;
use std::task::Poll;
use std::time::Duration;
use bytes::Bytes;
use futures::{Stream, StreamExt};
use http::Method;
use http::Request;
use http::Response;
use http::StatusCode;
use http::header;
use http_body::Body;
use http_body::Frame;
use http_body_util::BodyExt;
use http_body_util::Full;
use serde::Serialize;
use tracing::Instrument;
use crate::codec::CodecFormat;
use crate::codec::content_type;
use crate::codec::header as connect_header;
use crate::compression::CompressionPolicy;
use crate::compression::CompressionRegistry;
use crate::dispatcher::Dispatcher;
use crate::envelope::Envelope;
use crate::error::ConnectError;
use crate::handler::BoxStream;
use crate::handler::Context;
use crate::protocol::Protocol;
use crate::router::MethodKind;
use crate::router::Router;
/// Query parameters recognized on a Connect unary GET request.
///
/// Populated by `parse_get_query_params`; unrecognized keys are ignored.
#[derive(Debug, Default)]
struct GetQueryParams {
    /// Encoded request message (`message=`), if present.
    message: Option<String>,
    /// Codec name (`encoding=`); mandatory for GET requests.
    encoding: Option<String>,
    /// True when `base64=1`, meaning `message` is base64-encoded.
    base64: bool,
    /// Compression algorithm applied to the message (`compression=`).
    compression: Option<String>,
    /// Declared Connect protocol version (`connect=`).
    connect_version: Option<String>,
}
/// Parses the query string of a Connect GET request into its known
/// parameters.
///
/// Unknown keys are silently ignored. Fails when the query string is absent
/// entirely, or when the mandatory `encoding` parameter is missing.
fn parse_get_query_params(query: Option<&str>) -> Result<GetQueryParams, ConnectError> {
    let query = query.ok_or_else(|| {
        ConnectError::invalid_argument("GET request requires query parameters")
    })?;
    let mut parsed = GetQueryParams::default();
    for pair in query.split('&') {
        // Split at the first '='; a pair without '=' yields an empty value.
        let (key, value) = pair.split_once('=').unwrap_or((pair, ""));
        match key {
            "message" => parsed.message = Some(value.to_owned()),
            "encoding" => parsed.encoding = Some(value.to_owned()),
            "base64" => parsed.base64 = value == "1",
            "compression" => parsed.compression = Some(value.to_owned()),
            "connect" => parsed.connect_version = Some(value.to_owned()),
            _ => {}
        }
    }
    if parsed.encoding.is_none() {
        return Err(ConnectError::invalid_argument(
            "GET request requires 'encoding' query parameter",
        ));
    }
    Ok(parsed)
}
/// Request headers pre-extracted into owned values for use throughout
/// request handling.
#[derive(Debug)]
struct RequestMetadata {
    /// Raw `Content-Type` value, if present and valid UTF-8.
    content_type: Option<String>,
    #[allow(dead_code)]
    protocol: Protocol,
    /// Timeout parsed from the protocol-specific timeout header.
    timeout: Option<Duration>,
    /// `Content-Encoding` (unary request compression).
    unary_encoding: Option<String>,
    /// Protocol-specific message-encoding header (streaming request compression).
    streaming_encoding: Option<String>,
    /// `Accept-Encoding` (unary response compression negotiation).
    unary_accept_encoding: Option<String>,
    /// Protocol-specific accept-encoding header (streaming responses).
    streaming_accept_encoding: Option<String>,
    /// Connect protocol version header value.
    protocol_version: Option<String>,
    /// Full clone of the request headers, later handed to the handler Context.
    headers: http::HeaderMap,
}
impl RequestMetadata {
    /// Snapshots the headers relevant to `protocol` into owned strings.
    ///
    /// Header values that are not valid UTF-8 are treated as absent. The
    /// entire header map is cloned so it can outlive the request parts.
    fn from_headers(headers: &http::HeaderMap, protocol: Protocol) -> Self {
        // Reads one header as an owned string, skipping non-UTF-8 values.
        fn text<K: http::header::AsHeaderName>(
            headers: &http::HeaderMap,
            name: K,
        ) -> Option<String> {
            headers
                .get(name)
                .and_then(|v| v.to_str().ok())
                .map(str::to_owned)
        }
        let timeout = headers
            .get(protocol.timeout_header())
            .and_then(|v| v.to_str().ok())
            .and_then(|s| parse_timeout(s, protocol));
        Self {
            content_type: text(headers, header::CONTENT_TYPE),
            timeout,
            unary_encoding: text(headers, header::CONTENT_ENCODING),
            streaming_encoding: text(headers, protocol.content_encoding_header()),
            unary_accept_encoding: text(headers, header::ACCEPT_ENCODING),
            streaming_accept_encoding: text(headers, protocol.accept_encoding_header()),
            protocol_version: text(headers, connect_header::PROTOCOL_VERSION),
            headers: headers.clone(),
            protocol,
        }
    }
}
/// Parses a protocol-specific timeout header value.
///
/// * Connect: a plain millisecond count of at most 10 digits.
/// * gRPC / gRPC-Web: at most 8 digits followed by a one-character unit
///   suffix (`H`, `M`, `S`, `m`, `u`, `n`).
///
/// Any malformed value (or multiplication overflow) yields `None`, which
/// callers treat as "no timeout".
fn parse_timeout(s: &str, protocol: Protocol) -> Option<Duration> {
    match protocol {
        Protocol::Connect => {
            if s.is_empty() || s.len() > 10 {
                return None;
            }
            s.parse::<u64>().ok().map(Duration::from_millis)
        }
        Protocol::Grpc | Protocol::GrpcWeb => {
            if s.is_empty() || !s.is_ascii() {
                return None;
            }
            // The ASCII check above guarantees the split lands on a char
            // boundary.
            let (digits, unit) = s.split_at(s.len() - 1);
            if digits.is_empty() || digits.len() > 8 {
                return None;
            }
            let value: u64 = digits.parse().ok()?;
            Some(match unit {
                "H" => Duration::from_secs(value.checked_mul(3600)?),
                "M" => Duration::from_secs(value.checked_mul(60)?),
                "S" => Duration::from_secs(value),
                "m" => Duration::from_millis(value),
                "u" => Duration::from_micros(value),
                "n" => Duration::from_nanos(value),
                _ => return None,
            })
        }
    }
}
/// Buffers an entire request body into memory, enforcing a byte-size cap.
///
/// Exceeding `limit` maps to `resource_exhausted`; any other read failure is
/// surfaced as `internal`.
async fn collect_body_limited<B>(body: B, limit: usize) -> Result<Bytes, ConnectError>
where
    B: Body<Data = Bytes> + Send + 'static,
    B::Error: std::error::Error + Send + Sync + 'static,
{
    http_body_util::Limited::new(body, limit)
        .collect()
        .await
        .map(|collected| collected.to_bytes())
        .map_err(|err| {
            // Distinguish the limiter's own error from transport failures.
            if err
                .downcast_ref::<http_body_util::LengthLimitError>()
                .is_some()
            {
                ConnectError::resource_exhausted(format!(
                    "request body size exceeds limit {limit}"
                ))
            } else {
                ConnectError::internal(format!("failed to read request body: {err}"))
            }
        })
}
/// Recovers the request message carried in Connect GET query parameters.
///
/// Undoes the layered encodings in order — percent-encoding, base64
/// (URL-safe, unpadded or padded), then optional compression — and finally
/// enforces the configured message-size limit on the result.
fn decode_get_message(
    params: &GetQueryParams,
    compression: &CompressionRegistry,
    max_message_size: usize,
) -> Result<Bytes, ConnectError> {
    let Some(ref raw) = params.message else {
        // No message parameter at all: treat as an empty request message.
        return Ok(Bytes::new());
    };
    let decoded: Vec<u8> = if params.base64 {
        use base64::Engine;
        use base64::engine::general_purpose::{URL_SAFE, URL_SAFE_NO_PAD};
        // Some clients percent-encode the base64 text itself; undo that first.
        let text = if raw.contains('%') {
            percent_decode(raw)?
        } else {
            raw.as_bytes().to_vec()
        };
        // Accept both unpadded (canonical) and padded URL-safe base64.
        URL_SAFE_NO_PAD
            .decode(&text)
            .or_else(|_| URL_SAFE.decode(&text))
            .map_err(|e| ConnectError::invalid_argument(format!("invalid base64 encoding: {e}")))?
    } else {
        percent_decode(raw)?
    };
    // "identity" means no compression was applied.
    let applied_compression = params
        .compression
        .as_deref()
        .filter(|enc| *enc != "identity");
    let body = match applied_compression {
        Some(encoding) => {
            compression.decompress_with_limit(encoding, Bytes::from(decoded), max_message_size)?
        }
        None => Bytes::from(decoded),
    };
    if body.len() > max_message_size {
        return Err(ConnectError::resource_exhausted(format!(
            "message size {} exceeds limit {}",
            body.len(),
            max_message_size
        )));
    }
    Ok(body)
}
/// Percent-decodes a query-string value, treating `+` as a space per the
/// form-encoding convention.
///
/// NOTE(review): this currently never fails; the `Result` signature leaves
/// room for stricter validation later.
fn percent_decode(input: &str) -> Result<Vec<u8>, ConnectError> {
    let normalized = input.replace('+', " ");
    let decoded = percent_encoding::percent_decode_str(&normalized).collect();
    Ok(decoded)
}
/// Default cap on the raw (possibly compressed) HTTP request body: 4 MiB.
pub const DEFAULT_MAX_REQUEST_BODY_SIZE: usize = 4 * 1024 * 1024;
/// Default cap on a single decoded/decompressed message: 4 MiB.
pub const DEFAULT_MAX_MESSAGE_SIZE: usize = 4 * 1024 * 1024;
/// Size limits applied while reading and decoding requests.
#[derive(Debug, Clone)]
pub struct Limits {
    /// Maximum bytes read from an incoming HTTP body.
    pub max_request_body_size: usize,
    /// Maximum bytes of a single decompressed message.
    pub max_message_size: usize,
}
impl Default for Limits {
    fn default() -> Self {
        Self {
            max_request_body_size: DEFAULT_MAX_REQUEST_BODY_SIZE,
            max_message_size: DEFAULT_MAX_MESSAGE_SIZE,
        }
    }
}
impl Limits {
    /// Disables both size limits entirely.
    pub fn unlimited() -> Self {
        Self {
            max_request_body_size: usize::MAX,
            max_message_size: usize::MAX,
        }
    }
    /// Builder-style setter for the maximum raw request-body size in bytes.
    #[must_use]
    pub fn max_request_body_size(self, size: usize) -> Self {
        Self {
            max_request_body_size: size,
            ..self
        }
    }
    /// Builder-style setter for the maximum decoded message size in bytes.
    #[must_use]
    pub fn max_message_size(self, size: usize) -> Self {
        Self {
            max_message_size: size,
            ..self
        }
    }
}
/// JSON payload of the Connect end-of-stream frame.
#[derive(Debug, Clone, Serialize)]
struct EndStreamResponse {
    /// Present only when the stream ended in an error.
    #[serde(skip_serializing_if = "Option::is_none")]
    error: Option<EndStreamError>,
    /// Trailer metadata; omitted entirely when there are no trailers.
    #[serde(skip_serializing_if = "Option::is_none")]
    metadata: Option<std::collections::HashMap<String, Vec<String>>>,
}
/// Connect wire representation of an error inside the end-stream frame.
#[derive(Debug, Clone, Serialize)]
struct EndStreamError {
    /// Error code string taken from `ConnectError::code`.
    code: String,
    #[serde(skip_serializing_if = "Option::is_none")]
    message: Option<String>,
    /// Error details serialized to JSON; omitted when empty.
    #[serde(skip_serializing_if = "Option::is_none")]
    details: Option<Vec<serde_json::Value>>,
}
impl EndStreamResponse {
    /// Converts trailers into the Connect metadata map, returning `None`
    /// when there is nothing to report so the field is omitted from JSON.
    fn metadata_from_trailers(
        trailers: &http::HeaderMap,
    ) -> Option<std::collections::HashMap<String, Vec<String>>> {
        (!trailers.is_empty()).then(|| headers_to_metadata(trailers))
    }
    /// Builds the end-of-stream message for a successfully completed stream.
    fn success(trailers: &http::HeaderMap) -> Self {
        Self {
            error: None,
            metadata: Self::metadata_from_trailers(trailers),
        }
    }
    /// Builds the end-of-stream message for a failed stream.
    ///
    /// Trailers attached to the error take precedence over the context
    /// trailers; details that fail to serialize are dropped individually.
    fn error(err: &ConnectError, context_trailers: &http::HeaderMap) -> Self {
        let trailers_source = if err.trailers.is_empty() {
            context_trailers
        } else {
            &err.trailers
        };
        let details = (!err.details.is_empty()).then(|| {
            err.details
                .iter()
                .filter_map(|d| serde_json::to_value(d).ok())
                .collect()
        });
        Self {
            error: Some(EndStreamError {
                code: err.code.as_str().to_owned(),
                message: err.message.clone(),
                details,
            }),
            metadata: Self::metadata_from_trailers(trailers_source),
        }
    }
    /// Serializes to JSON, degrading to `{}` if serialization ever fails.
    fn to_json(&self) -> Bytes {
        match serde_json::to_vec(self) {
            Ok(json) => Bytes::from(json),
            Err(_) => Bytes::from_static(b"{}"),
        }
    }
}
/// Produces a protocol-appropriate streaming error response.
///
/// Connect reports stream errors via an end-stream frame; gRPC and gRPC-Web
/// report them via trailers.
fn streaming_error_response(
    err: &ConnectError,
    protocol: Protocol,
    codec_format: CodecFormat,
) -> Response<StreamingResponseBody> {
    if protocol == Protocol::Connect {
        connect_streaming_error_response(err, codec_format)
    } else {
        grpc_error_response(err, protocol, codec_format)
    }
}
fn connect_streaming_error_response(
err: &ConnectError,
codec_format: CodecFormat,
) -> Response<StreamingResponseBody> {
use futures::stream::StreamExt as _;
let end_stream = EndStreamResponse::error(err, &err.trailers);
let mut encoder = crate::envelope::EnvelopeEncoder::uncompressed();
let mut buf = bytes::BytesMut::new();
let _ = encoder.encode_end_stream(end_stream.to_json(), &mut buf);
let encoded = buf.freeze();
let body_stream = futures::stream::unfold(Some(encoded), async |data| {
data.map(|bytes| (Ok(Frame::data(bytes)), None))
})
.fuse();
let body = StreamingResponseBody {
inner: Box::pin(body_stream),
_reader_task: None,
};
let mut response = Response::builder().status(StatusCode::OK).header(
header::CONTENT_TYPE,
Protocol::Connect.response_content_type(codec_format, true),
);
for (key, value) in err.response_headers.iter() {
response = response.header(key, value);
}
response.body(body).unwrap_or_else(|_| {
Response::new(StreamingResponseBody {
inner: Box::pin(futures::stream::empty()),
_reader_task: None,
})
})
}
/// Builds an HTTP 200 gRPC / gRPC-Web error response with no message frames:
/// the error travels in HTTP/2 trailers (gRPC) or an in-body trailer frame
/// (gRPC-Web).
fn grpc_error_response(
    err: &ConnectError,
    protocol: Protocol,
    codec_format: CodecFormat,
) -> Response<StreamingResponseBody> {
    let grpc_trailers = build_grpc_trailers(Some(err), &err.trailers);
    // The whole body is a single trailer-bearing frame.
    let final_frame = match protocol {
        Protocol::Grpc => Frame::trailers(grpc_trailers),
        Protocol::GrpcWeb => Frame::data(encode_grpc_web_trailers(&grpc_trailers)),
        Protocol::Connect => unreachable!("Connect handled separately"),
    };
    let body_stream: Pin<Box<dyn Stream<Item = Result<Frame<Bytes>, Infallible>> + Send>> =
        Box::pin(futures::stream::once(async move { Ok(final_frame) }).fuse());
    let body = StreamingResponseBody {
        inner: body_stream,
        _reader_task: None,
    };
    let mut builder = Response::builder().status(StatusCode::OK).header(
        header::CONTENT_TYPE,
        protocol.response_content_type(codec_format, true),
    );
    // For plain gRPC the status/message are also mirrored into the response
    // headers.
    if protocol == Protocol::Grpc {
        builder = builder.header(&GRPC_STATUS, err.code.grpc_code());
        if let Some(message_value) = err
            .message
            .as_deref()
            .and_then(|m| http::HeaderValue::from_str(&grpc_percent_encode(m)).ok())
        {
            builder = builder.header(&GRPC_MESSAGE, message_value);
        }
    }
    for (key, value) in err.response_headers.iter() {
        builder = builder.header(key, value);
    }
    // An invalid header can fail the builder; degrade to an empty body.
    builder.body(body).unwrap_or_else(|_| {
        Response::new(StreamingResponseBody {
            inner: Box::pin(futures::stream::empty()),
            _reader_task: None,
        })
    })
}
/// Encodes a trailer map as one gRPC-Web trailer frame: a 5-byte envelope
/// (flag byte 0x80 marking a trailer frame, then a big-endian u32 length)
/// followed by `name: value\r\n` pairs.
fn encode_grpc_web_trailers(trailers: &http::HeaderMap) -> Bytes {
    let mut payload = Vec::new();
    for (name, value) in trailers {
        payload.extend_from_slice(name.as_str().as_bytes());
        payload.extend_from_slice(b": ");
        payload.extend_from_slice(value.as_bytes());
        payload.extend_from_slice(b"\r\n");
    }
    // NOTE(review): the cast silently truncates for payloads over u32::MAX
    // bytes; trailer blocks are assumed to stay far below that.
    let length = payload.len() as u32;
    let mut frame = Vec::with_capacity(5 + payload.len());
    frame.push(0x80);
    frame.extend_from_slice(&length.to_be_bytes());
    frame.extend_from_slice(&payload);
    Bytes::from(frame)
}
/// Groups headers into a name -> values map, the shape the Connect
/// end-of-stream metadata uses.
///
/// NOTE(review): values that are not valid UTF-8 become empty strings rather
/// than being dropped.
fn headers_to_metadata(
    headers: &http::HeaderMap,
) -> std::collections::HashMap<String, Vec<String>> {
    let mut metadata: std::collections::HashMap<String, Vec<String>> =
        std::collections::HashMap::new();
    for (name, value) in headers {
        let text = value.to_str().unwrap_or_default().to_owned();
        metadata
            .entry(name.as_str().to_owned())
            .or_default()
            .push(text);
    }
    metadata
}
/// HTTP response body that yields the frames produced by one of the
/// envelope streams.
pub struct StreamingResponseBody {
    /// The frame source; upstream errors are already folded into final
    /// frames, so the body itself is infallible.
    inner: Pin<Box<dyn Stream<Item = Result<Frame<Bytes>, Infallible>> + Send>>,
    /// Optional handle of a task that pumps the request stream, stored so it
    /// lives alongside this body. NOTE(review): dropping a `JoinHandle`
    /// detaches (does not abort) the task — confirm intended lifecycle.
    _reader_task: Option<tokio::task::JoinHandle<()>>,
}
impl StreamingResponseBody {
    /// Wraps a handler's byte stream in the protocol-specific envelope
    /// framing (Connect end-stream frame, gRPC trailers, or gRPC-Web trailer
    /// frames), carrying the context's trailers into the final frame.
    fn new(
        response_stream: BoxStream<Result<Bytes, ConnectError>>,
        ctx: Context,
        protocol: Protocol,
        compression: Option<(Arc<CompressionRegistry>, &'static str)>,
        compression_policy: CompressionPolicy,
    ) -> Self {
        let trailers = ctx.trailers.clone();
        let inner: Pin<Box<dyn Stream<Item = Result<Frame<Bytes>, Infallible>> + Send>> =
            match protocol {
                Protocol::Connect => Box::pin(create_envelope_stream(
                    response_stream,
                    trailers,
                    compression,
                    compression_policy,
                )),
                Protocol::Grpc => Box::pin(create_grpc_envelope_stream(
                    response_stream,
                    trailers,
                    compression,
                    compression_policy,
                )),
                Protocol::GrpcWeb => Box::pin(create_grpc_web_envelope_stream(
                    response_stream,
                    trailers,
                    compression,
                    compression_policy,
                )),
            };
        Self {
            inner,
            _reader_task: None,
        }
    }
    /// Associates the request-reader task's handle with this body so the two
    /// share a lifetime.
    fn with_reader_task(mut self, task: tokio::task::JoinHandle<()>) -> Self {
        self._reader_task = Some(task);
        self
    }
}
impl Body for StreamingResponseBody {
    type Data = Bytes;
    // Upstream errors were already converted into end-stream/trailer frames,
    // so the HTTP body itself can never fail.
    type Error = Infallible;
    fn poll_frame(
        mut self: Pin<&mut Self>,
        cx: &mut TaskContext<'_>,
    ) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
        // Delegate directly to the boxed frame stream.
        self.inner.as_mut().poll_next(cx)
    }
}
/// Flush the accumulated envelope buffer as one data frame once it reaches
/// this many bytes.
const STREAM_BATCH_THRESHOLD: usize = 16 * 1024;
/// How a response stream is terminated, per protocol.
enum StreamFinalizer {
    /// Connect: a JSON EndStreamResponse in an end-stream envelope.
    ConnectEndStream,
    /// gRPC: real HTTP/2 trailers.
    GrpcTrailers,
    /// gRPC-Web: trailers encoded as a 0x80-flagged frame in the body.
    GrpcWebTrailers,
}
impl StreamFinalizer {
fn success(&self, trailers: &http::HeaderMap) -> Frame<Bytes> {
match self {
StreamFinalizer::ConnectEndStream => {
let end_stream = EndStreamResponse::success(trailers);
let mut buf = bytes::BytesMut::new();
let mut enc = crate::envelope::EnvelopeEncoder::uncompressed();
let _ = enc.encode_end_stream(end_stream.to_json(), &mut buf);
Frame::data(buf.freeze())
}
StreamFinalizer::GrpcTrailers => Frame::trailers(build_grpc_trailers(None, trailers)),
StreamFinalizer::GrpcWebTrailers => {
let t = build_grpc_trailers(None, trailers);
Frame::data(encode_grpc_web_trailers(&t))
}
}
}
fn error(&self, err: &ConnectError, trailers: &http::HeaderMap) -> Frame<Bytes> {
match self {
StreamFinalizer::ConnectEndStream => {
let end_stream = EndStreamResponse::error(err, trailers);
let mut buf = bytes::BytesMut::new();
let mut enc = crate::envelope::EnvelopeEncoder::uncompressed();
let _ = enc.encode_end_stream(end_stream.to_json(), &mut buf);
Frame::data(buf.freeze())
}
StreamFinalizer::GrpcTrailers => {
Frame::trailers(build_grpc_trailers(Some(err), trailers))
}
StreamFinalizer::GrpcWebTrailers => {
let t = build_grpc_trailers(Some(err), trailers);
Frame::data(encode_grpc_web_trailers(&t))
}
}
}
}
/// Adapts a handler's encoded-message stream into HTTP frames, batching
/// small envelopes into larger data frames before handing them to the
/// transport.
struct BatchingEnvelopeStream {
    /// Fused upstream of codec-encoded response messages.
    source: futures::stream::Fuse<BoxStream<Result<Bytes, ConnectError>>>,
    /// Accumulates encoded envelopes until a flush.
    buf: bytes::BytesMut,
    encoder: crate::envelope::EnvelopeEncoder,
    /// Trailers captured from the handler context, emitted in the final frame.
    trailers: http::HeaderMap,
    finalizer: StreamFinalizer,
    /// Final frame held back while remaining buffered data is flushed first.
    pending_final: Option<Frame<Bytes>>,
    /// True once the final frame has been emitted.
    done: bool,
}
impl BatchingEnvelopeStream {
    /// Creates a batching stream over `source`, remembering the trailers and
    /// finalizer that will terminate it.
    fn new(
        source: BoxStream<Result<Bytes, ConnectError>>,
        trailers: http::HeaderMap,
        compression: Option<(Arc<CompressionRegistry>, &'static str)>,
        compression_policy: CompressionPolicy,
        finalizer: StreamFinalizer,
    ) -> Self {
        let encoder = crate::envelope::EnvelopeEncoder::new(compression, compression_policy);
        Self {
            source: source.fuse(),
            buf: bytes::BytesMut::new(),
            encoder,
            trailers,
            finalizer,
            pending_final: None,
            done: false,
        }
    }
    /// Takes everything batched so far as one data frame; `BytesMut::split`
    /// empties the buffer while keeping its allocation for reuse.
    #[inline]
    fn flush_buf(&mut self) -> Frame<Bytes> {
        let batched = self.buf.split();
        Frame::data(batched.freeze())
    }
}
impl Stream for BatchingEnvelopeStream {
    type Item = Result<Frame<Bytes>, Infallible>;
    // State machine: drain the source into `buf`, flushing batches as data
    // frames; when the source ends or errors, emit exactly one final frame
    // (possibly deferred behind one last data flush), then terminate.
    fn poll_next(mut self: Pin<&mut Self>, cx: &mut TaskContext<'_>) -> Poll<Option<Self::Item>> {
        use tokio_util::codec::Encoder;
        if self.done {
            // Final frame already emitted: the stream is over.
            return Poll::Ready(None);
        }
        if let Some(frame) = self.pending_final.take() {
            // A deferred final frame goes out now, right after the last flush.
            self.done = true;
            return Poll::Ready(Some(Ok(frame)));
        }
        loop {
            match self.source.poll_next_unpin(cx) {
                Poll::Pending if self.buf.is_empty() => {
                    return Poll::Pending;
                }
                Poll::Pending => {
                    // Source stalled: hand over what we have batched so far.
                    return Poll::Ready(Some(Ok(self.flush_buf())));
                }
                Poll::Ready(None) => {
                    // Clean end of stream: emit the success finalizer, after
                    // flushing any batched data first.
                    let final_frame = self.finalizer.success(&self.trailers);
                    if self.buf.is_empty() {
                        self.done = true;
                        return Poll::Ready(Some(Ok(final_frame)));
                    } else {
                        self.pending_final = Some(final_frame);
                        return Poll::Ready(Some(Ok(self.flush_buf())));
                    }
                }
                Poll::Ready(Some(Err(err))) => {
                    tracing::debug!(
                        error = %err,
                        "streaming response: source error, emitting error trailers"
                    );
                    // A source error terminates the stream via the error
                    // finalizer; already-batched data is still delivered.
                    let final_frame = self.finalizer.error(&err, &self.trailers);
                    if self.buf.is_empty() {
                        self.done = true;
                        return Poll::Ready(Some(Ok(final_frame)));
                    } else {
                        self.pending_final = Some(final_frame);
                        return Poll::Ready(Some(Ok(self.flush_buf())));
                    }
                }
                Poll::Ready(Some(Ok(data))) => {
                    // Reborrow so encoder and buf can be borrowed disjointly.
                    let me = &mut *self;
                    if let Err(err) = me.encoder.encode(data, &mut me.buf) {
                        tracing::debug!(
                            error = %err,
                            "streaming response: envelope encoding failed"
                        );
                        // Encoding failures end the stream like source errors.
                        let final_frame = self.finalizer.error(&err, &self.trailers);
                        if self.buf.is_empty() {
                            self.done = true;
                            return Poll::Ready(Some(Ok(final_frame)));
                        } else {
                            self.pending_final = Some(final_frame);
                            return Poll::Ready(Some(Ok(self.flush_buf())));
                        }
                    }
                    if self.buf.len() >= STREAM_BATCH_THRESHOLD {
                        // Batch is large enough: flush it to the transport.
                        return Poll::Ready(Some(Ok(self.flush_buf())));
                    }
                }
            }
        }
    }
}
fn create_envelope_stream(
response_stream: BoxStream<Result<Bytes, ConnectError>>,
trailers: http::HeaderMap,
compression: Option<(Arc<CompressionRegistry>, &'static str)>,
compression_policy: CompressionPolicy,
) -> impl Stream<Item = Result<Frame<Bytes>, Infallible>> + Send {
BatchingEnvelopeStream::new(
response_stream,
trailers,
compression,
compression_policy,
StreamFinalizer::ConnectEndStream,
)
}
fn create_grpc_envelope_stream(
response_stream: BoxStream<Result<Bytes, ConnectError>>,
trailers: http::HeaderMap,
compression: Option<(Arc<CompressionRegistry>, &'static str)>,
compression_policy: CompressionPolicy,
) -> impl Stream<Item = Result<Frame<Bytes>, Infallible>> + Send {
BatchingEnvelopeStream::new(
response_stream,
trailers,
compression,
compression_policy,
StreamFinalizer::GrpcTrailers,
)
}
fn create_grpc_web_envelope_stream(
response_stream: BoxStream<Result<Bytes, ConnectError>>,
trailers: http::HeaderMap,
compression: Option<(Arc<CompressionRegistry>, &'static str)>,
compression_policy: CompressionPolicy,
) -> impl Stream<Item = Result<Frame<Bytes>, Infallible>> + Send {
BatchingEnvelopeStream::new(
response_stream,
trailers,
compression,
compression_policy,
StreamFinalizer::GrpcWebTrailers,
)
}
/// Tower service implementing the Connect, gRPC, and gRPC-Web protocols on
/// top of a method `Dispatcher` (a `Router` by default).
pub struct ConnectRpcService<D = Router> {
    /// Shared dispatcher that resolves paths and invokes handlers.
    dispatcher: Arc<D>,
    /// Request/message size limits.
    limits: Limits,
    /// Available compression algorithms for requests and responses.
    compression: Arc<CompressionRegistry>,
    /// When responses should be compressed.
    compression_policy: CompressionPolicy,
}
// Manual Clone: `D` itself need not be `Clone`, since the dispatcher and
// compression registry are shared behind `Arc`s.
impl<D> Clone for ConnectRpcService<D> {
    fn clone(&self) -> Self {
        Self {
            dispatcher: Arc::clone(&self.dispatcher),
            limits: self.limits.clone(),
            compression: Arc::clone(&self.compression),
            compression_policy: self.compression_policy,
        }
    }
}
impl<D: Dispatcher> ConnectRpcService<D> {
    /// Creates a service with default limits, compression registry, and
    /// compression policy.
    pub fn new(dispatcher: D) -> Self {
        Self::from_arc(Arc::new(dispatcher))
    }
    /// Creates a service around an already-shared dispatcher.
    pub fn from_arc(dispatcher: Arc<D>) -> Self {
        Self {
            dispatcher,
            limits: Limits::default(),
            compression: Arc::new(CompressionRegistry::default()),
            compression_policy: CompressionPolicy::default(),
        }
    }
    /// Replaces the request/message size limits.
    #[must_use]
    pub fn with_limits(mut self, limits: Limits) -> Self {
        self.limits = limits;
        self
    }
    /// Replaces the compression registry.
    #[must_use]
    pub fn with_compression(mut self, compression: CompressionRegistry) -> Self {
        self.compression = Arc::new(compression);
        self
    }
    /// Replaces the response-compression policy.
    #[must_use]
    pub fn with_compression_policy(mut self, policy: CompressionPolicy) -> Self {
        self.compression_policy = policy;
        self
    }
    /// The currently configured limits.
    pub fn limits(&self) -> &Limits {
        &self.limits
    }
    /// The wrapped dispatcher.
    pub fn dispatcher(&self) -> &D {
        &self.dispatcher
    }
}
/// Buffered body for the gRPC unary fast path: at most one data frame,
/// followed by trailers.
pub struct GrpcUnaryBody {
    /// The encoded response payload, if any.
    data: Option<Bytes>,
    /// Trailers still pending emission after the data frame.
    trailers: Option<GrpcUnaryTrailers>,
}
/// How unary trailers are delivered, depending on the protocol.
enum GrpcUnaryTrailers {
    /// Real HTTP/2 trailers (gRPC).
    Http2(http::HeaderMap),
    /// Trailers pre-encoded into a gRPC-Web body frame.
    WebBody(Bytes),
}
impl Body for GrpcUnaryBody {
    type Data = Bytes;
    type Error = Infallible;
    /// Emits the optional data frame first, then the trailers (as HTTP/2
    /// trailers or a gRPC-Web data frame), then end-of-body.
    fn poll_frame(
        self: Pin<&mut Self>,
        _cx: &mut TaskContext<'_>,
    ) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
        let this = self.get_mut();
        let frame = if let Some(data) = this.data.take() {
            Some(Frame::data(data))
        } else {
            this.trailers.take().map(|trailers| match trailers {
                GrpcUnaryTrailers::Http2(map) => Frame::trailers(map),
                GrpcUnaryTrailers::WebBody(bytes) => Frame::data(bytes),
            })
        };
        Poll::Ready(frame.map(Ok))
    }
}
/// Unified response body covering the three response shapes this service
/// produces.
#[non_exhaustive]
pub enum ConnectRpcBody {
    /// Fully buffered body (Connect unary).
    Full(Full<Bytes>),
    /// Buffered gRPC unary body (data frame plus trailers).
    GrpcUnary(GrpcUnaryBody),
    /// Frame-by-frame streaming body.
    Streaming(StreamingResponseBody),
}
impl Body for ConnectRpcBody {
    type Data = Bytes;
    type Error = Infallible;
    fn poll_frame(
        self: Pin<&mut Self>,
        cx: &mut TaskContext<'_>,
    ) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
        // Plain delegation; `Pin::new` is usable here because each variant's
        // body type is `Unpin`.
        match self.get_mut() {
            ConnectRpcBody::Full(inner) => Pin::new(inner).poll_frame(cx),
            ConnectRpcBody::GrpcUnary(inner) => Pin::new(inner).poll_frame(cx),
            ConnectRpcBody::Streaming(inner) => Pin::new(inner).poll_frame(cx),
        }
    }
}
impl<D, B> tower::Service<Request<B>> for ConnectRpcService<D>
where
    D: Dispatcher,
    B: Body<Data = Bytes> + Send + 'static,
    B::Error: std::error::Error + Send + Sync + 'static,
{
    type Response = Response<ConnectRpcBody>;
    // Every failure is rendered as an HTTP response, so the service itself
    // never errors.
    type Error = Infallible;
    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
    fn poll_ready(&mut self, _cx: &mut TaskContext<'_>) -> Poll<Result<(), Self::Error>> {
        // The service applies no backpressure of its own.
        Poll::Ready(Ok(()))
    }
    fn call(&mut self, req: Request<B>) -> Self::Future {
        // Clone shared state into the future so it can be 'static.
        let dispatcher = Arc::clone(&self.dispatcher);
        let limits = self.limits.clone();
        let compression = Arc::clone(&self.compression);
        let compression_policy = self.compression_policy;
        // Only pay for span creation when DEBUG logging is enabled. The
        // protocol/codec fields start empty and are recorded later inside
        // handle_request.
        let span = if tracing::enabled!(tracing::Level::DEBUG) {
            Some(tracing::debug_span!(
                "connectrpc_request",
                path = %req.uri().path(),
                method = %req.method(),
                protocol = tracing::field::Empty,
                codec = tracing::field::Empty,
            ))
        } else {
            None
        };
        let fut = async move {
            // Map handler-level errors to a generic error response.
            let response =
                match handle_request(dispatcher, req, limits, compression, &compression_policy)
                    .await
                {
                    Ok(response) => response,
                    Err(err) => error_response_either(err),
                };
            Ok(response)
        };
        match span {
            Some(span) => Box::pin(fut.instrument(span)),
            None => Box::pin(fut),
        }
    }
}
/// Top-level request router: detects the protocol from the headers, then
/// delegates to the unary pipeline, the gRPC unary fast path, or the
/// streaming pipeline.
///
/// Errors returned here are rendered by the caller's generic error path;
/// protocol-shaped errors (gRPC trailers etc.) are already mapped to `Ok`
/// responses inside this function.
async fn handle_request<D, B>(
    dispatcher: Arc<D>,
    req: Request<B>,
    limits: Limits,
    compression: Arc<CompressionRegistry>,
    compression_policy: &CompressionPolicy,
) -> Result<Response<ConnectRpcBody>, ConnectError>
where
    D: Dispatcher,
    B: Body<Data = Bytes> + Send + 'static,
    B::Error: std::error::Error + Send + Sync + 'static,
{
    let request_protocol = Protocol::detect(req.headers());
    // Fill in the span fields created empty in `call`.
    if let Some(ref rp) = request_protocol {
        let span = tracing::Span::current();
        span.record("protocol", tracing::field::display(rp.protocol));
        span.record("codec", tracing::field::display(rp.codec_format));
    }
    // GET requests go straight to the (Connect) unary pipeline.
    if req.method() == Method::GET {
        return handle_unary_request(&*dispatcher, req, limits, compression, compression_policy)
            .await
            .map(|r| r.map(ConnectRpcBody::Full));
    }
    if request_protocol.is_none() {
        let ct = req
            .headers()
            .get(header::CONTENT_TYPE)
            .and_then(|v| v.to_str().ok())
            .unwrap_or("")
            .to_owned();
        // Drain the request body (bounded) before responding.
        let (_parts, body) = req.into_parts();
        let limited = http_body_util::Limited::new(body, limits.max_request_body_size);
        let _ = limited.collect().await;
        // Content types that look like gRPC variants still get a
        // protocol-shaped error; anything else gets a plain 415.
        if ct.starts_with("application/grpc-web") {
            let err = ConnectError::internal("unsupported content type");
            let response = grpc_error_response(&err, Protocol::GrpcWeb, CodecFormat::Proto);
            return Ok(response.map(ConnectRpcBody::Streaming));
        } else if ct.starts_with("application/grpc") {
            let err = ConnectError::internal("unsupported content type");
            let response = grpc_error_response(&err, Protocol::Grpc, CodecFormat::Proto);
            return Ok(response.map(ConnectRpcBody::Streaming));
        } else {
            let response = Response::builder()
                .status(StatusCode::UNSUPPORTED_MEDIA_TYPE)
                .body(Full::new(Bytes::new()))
                .unwrap();
            return Ok(response.map(ConnectRpcBody::Full));
        }
    }
    match request_protocol {
        Some(rp) if rp.is_streaming => {
            // gRPC-Web text mode is explicitly unsupported; drain and reply.
            if rp.is_text_mode {
                let err = ConnectError::unimplemented(
                    "gRPC-Web text mode (application/grpc-web-text) is not supported",
                );
                let (_parts, body) = req.into_parts();
                let limited = http_body_util::Limited::new(body, limits.max_request_body_size);
                let _ = limited.collect().await;
                let response = grpc_error_response(&err, Protocol::GrpcWeb, rp.codec_format);
                return Ok(response.map(ConnectRpcBody::Streaming));
            }
            // Unary methods over gRPC/gRPC-Web take a buffered fast path
            // instead of the generic streaming pipeline.
            if matches!(rp.protocol, Protocol::Grpc | Protocol::GrpcWeb) {
                let path = req.uri().path();
                let path = path.strip_prefix('/').unwrap_or(path);
                if let Some(desc) = dispatcher.lookup(path)
                    && desc.kind == MethodKind::Unary
                {
                    let path = path.to_owned();
                    let response = handle_grpc_unary_request(
                        &*dispatcher,
                        &path,
                        req,
                        rp.protocol,
                        rp.codec_format,
                        limits,
                        compression,
                        compression_policy,
                    )
                    .await;
                    return Ok(response.map(ConnectRpcBody::GrpcUnary));
                }
            }
            let response = handle_streaming_request(
                &*dispatcher,
                req,
                rp.protocol,
                rp.codec_format,
                limits,
                compression,
                compression_policy,
            )
            .await;
            Ok(response.map(ConnectRpcBody::Streaming))
        }
        // Non-streaming POST: Connect unary.
        Some(_) | None => {
            handle_unary_request(&*dispatcher, req, limits, compression, compression_policy)
                .await
                .map(|r| r.map(ConnectRpcBody::Full))
        }
    }
}
/// Handles a Connect unary call: POST with the message in the body, or GET
/// with a query-parameter-encoded message (idempotent methods only).
///
/// The body is fully buffered (bounded by `limits`), decompressed as needed,
/// dispatched, and the response optionally compressed according to the
/// negotiated encoding and the effective policy.
async fn handle_unary_request<D, B>(
    dispatcher: &D,
    req: Request<B>,
    limits: Limits,
    compression: Arc<CompressionRegistry>,
    compression_policy: &CompressionPolicy,
) -> Result<Response<Full<Bytes>>, ConnectError>
where
    D: Dispatcher,
    B: Body<Data = Bytes> + Send + 'static,
    B::Error: std::error::Error + Send + Sync + 'static,
{
    let path = req.uri().path();
    let path = path.strip_prefix('/').unwrap_or(path).to_owned();
    let query_string = req.uri().query().map(|s| s.to_owned());
    let method = req.method().clone();
    let metadata = RequestMetadata::from_headers(req.headers(), Protocol::Connect);
    let (parts, body) = req.into_parts();
    let extensions = parts.extensions;
    // The body is collected before the method check, so it is drained for
    // GET requests too (where the actual message lives in the query string).
    let post_body = collect_body_limited(body, limits.max_request_body_size).await?;
    let desc = dispatcher.lookup(&path).ok_or_else(|| {
        ConnectError::unimplemented(format!("method not found: {path}"))
            .with_http_status(StatusCode::NOT_FOUND)
    })?;
    let is_idempotent = desc.idempotent;
    let (body, codec_format) = if method == Method::GET {
        // GET is only permitted for methods declared idempotent.
        if !is_idempotent {
            return Err(ConnectError::method_not_allowed(
                "GET requests are only supported for idempotent methods",
            ));
        }
        let params = parse_get_query_params(query_string.as_deref())?;
        // The `connect` query parameter, when present, must be "v1".
        if let Some(ref version) = params.connect_version
            && version != "v1"
        {
            return Err(ConnectError::invalid_argument(
                "unsupported protocol version",
            ));
        }
        let message = decode_get_message(&params, &compression, limits.max_message_size)?;
        // Missing encoding cannot happen (parse enforces it); the default is
        // kept as a safety net.
        let encoding = params.encoding.as_deref().unwrap_or("proto");
        let codec_format = CodecFormat::from_codec(encoding).ok_or_else(|| {
            ConnectError::unsupported_media_type(format!("unsupported encoding: {encoding}"))
        })?;
        (message, codec_format)
    } else if method == Method::POST {
        // The connect-protocol-version header, when present, must be "1".
        if let Some(ref version) = metadata.protocol_version
            && version != "1"
        {
            return Err(ConnectError::invalid_argument(
                "unsupported protocol version",
            ));
        }
        // A missing content type defaults to proto.
        let content_type_str = metadata
            .content_type
            .as_deref()
            .unwrap_or(content_type::PROTO);
        let codec_format = CodecFormat::from_content_type(content_type_str).ok_or_else(|| {
            ConnectError::unsupported_media_type(format!(
                "unsupported content type: {content_type_str}"
            ))
        })?;
        // Unary Connect compression uses the standard Content-Encoding header.
        let body = if let Some(ref encoding) = metadata.unary_encoding {
            compression.decompress_with_limit(encoding, post_body, limits.max_message_size)?
        } else {
            post_body
        };
        if body.len() > limits.max_message_size {
            return Err(ConnectError::resource_exhausted(format!(
                "message size {} exceeds limit {}",
                body.len(),
                limits.max_message_size
            )));
        }
        (body, codec_format)
    } else {
        return Err(ConnectError::method_not_allowed(
            "only GET and POST methods are supported",
        ));
    };
    // Convert the relative timeout into an absolute deadline for the handler.
    let deadline = metadata
        .timeout
        .and_then(|t| std::time::Instant::now().checked_add(t));
    let ctx = Context::new(metadata.headers)
        .with_deadline(deadline)
        .with_extensions(extensions);
    // The timeout is additionally enforced around the dispatch itself.
    let (response_body, ctx) = if let Some(timeout) = metadata.timeout {
        tokio::time::timeout(
            timeout,
            dispatcher.call_unary(&path, ctx, body, codec_format),
        )
        .await
        .map_err(|_| ConnectError::deadline_exceeded("request timeout"))?
    } else {
        dispatcher.call_unary(&path, ctx, body, codec_format).await
    }?;
    // Choose a response encoding from Accept-Encoding, with the request's
    // own encoding as fallback input.
    let response_encoding = compression.negotiate_encoding(
        metadata.unary_accept_encoding.as_deref(),
        metadata.unary_encoding.as_deref(),
    );
    let effective_policy = compression_policy.with_override(ctx.compress_response);
    // Compress only when policy allows; a failed compression silently falls
    // back to the uncompressed payload.
    let (final_body, content_encoding) = if let Some(encoding) = response_encoding {
        if !effective_policy.should_compress(response_body.len()) {
            (response_body, None)
        } else {
            match compression.compress(encoding, &response_body) {
                Ok(compressed) => (compressed, Some(encoding)),
                Err(_) => (response_body, None),
            }
        }
    } else {
        (response_body, None)
    };
    let mut response = Response::builder()
        .status(StatusCode::OK)
        .header(header::CONTENT_TYPE, codec_format.content_type());
    if let Some(encoding) = content_encoding {
        response = response.header(header::CONTENT_ENCODING, encoding);
    }
    // Advertise the encodings this server accepts.
    let accept = compression.accept_encoding_header();
    if !accept.is_empty() {
        response = response.header(header::ACCEPT_ENCODING, accept);
    }
    for (key, value) in ctx.response_headers.iter() {
        response = response.header(key, value);
    }
    let response = add_trailers(response, &ctx.trailers);
    response
        .body(Full::new(final_body))
        .map_err(|e| ConnectError::internal(format!("failed to build response: {e}")))
}
#[allow(clippy::too_many_arguments)]
async fn handle_grpc_unary_request<D, B>(
dispatcher: &D,
path: &str,
req: Request<B>,
protocol: Protocol,
codec_format: CodecFormat,
limits: Limits,
compression: Arc<CompressionRegistry>,
compression_policy: &CompressionPolicy,
) -> Response<GrpcUnaryBody>
where
D: Dispatcher,
B: Body<Data = Bytes> + Send + 'static,
B::Error: std::error::Error + Send + Sync + 'static,
{
let grpc_unary_error = |err: &ConnectError| -> Response<GrpcUnaryBody> {
let grpc_trailers = build_grpc_trailers(Some(err), &err.trailers);
let trailers = match protocol {
Protocol::Grpc => GrpcUnaryTrailers::Http2(grpc_trailers),
Protocol::GrpcWeb => {
GrpcUnaryTrailers::WebBody(encode_grpc_web_trailers(&grpc_trailers))
}
Protocol::Connect => unreachable!("gRPC unary fast path is gRPC/gRPC-Web only"),
};
let mut response = Response::builder().status(StatusCode::OK).header(
header::CONTENT_TYPE,
protocol.response_content_type(codec_format, true),
);
if protocol == Protocol::Grpc {
response = response.header(&GRPC_STATUS, err.code.grpc_code());
if let Some(val) = err
.message
.as_deref()
.and_then(|m| http::HeaderValue::from_str(&grpc_percent_encode(m)).ok())
{
response = response.header(&GRPC_MESSAGE, val);
}
}
for (key, value) in err.response_headers.iter() {
response = response.header(key, value);
}
let body = GrpcUnaryBody {
data: None,
trailers: Some(trailers),
};
response.body(body).unwrap_or_else(|_| {
Response::new(GrpcUnaryBody {
data: None,
trailers: None,
})
})
};
if req.method() != Method::POST {
let err = ConnectError::internal(format!("invalid method for gRPC: {}", req.method()));
let (_parts, body) = req.into_parts();
let limited = http_body_util::Limited::new(body, limits.max_request_body_size);
let _ = limited.collect().await;
return grpc_unary_error(&err);
}
let metadata = RequestMetadata::from_headers(req.headers(), protocol);
if let Some(ref encoding) = metadata.streaming_encoding
&& encoding != "identity"
&& !compression.supports(encoding)
{
let err = ConnectError::unimplemented(format!("unsupported compression: {encoding}"));
let (_parts, body) = req.into_parts();
let limited = http_body_util::Limited::new(body, limits.max_request_body_size);
let _ = limited.collect().await;
return grpc_unary_error(&err);
}
let (parts, body) = req.into_parts();
let extensions = parts.extensions;
let post_body = match collect_body_limited(body, limits.max_request_body_size).await {
Ok(bytes) => bytes,
Err(err) => return grpc_unary_error(&err),
};
let request_body = if post_body.is_empty() {
let err = ConnectError::unimplemented("request body is empty: expected a message");
return grpc_unary_error(&err);
} else {
let mut buf = bytes::BytesMut::from(&post_body[..]);
let envelope = match Envelope::decode_with_limit(&mut buf, limits.max_message_size) {
Ok(Some(env)) => env,
Ok(None) => {
let err = ConnectError::invalid_argument("incomplete request envelope");
return grpc_unary_error(&err);
}
Err(e) => return grpc_unary_error(&e),
};
if !buf.is_empty() {
let err = ConnectError::unimplemented("unary request must have exactly one message");
return grpc_unary_error(&err);
}
if envelope.is_compressed() {
let encoding = match metadata.streaming_encoding.as_deref() {
Some(enc) if enc != "identity" => enc,
_ => {
let err = ConnectError::internal(format!(
"received compressed message without {} header",
protocol.content_encoding_header()
));
return grpc_unary_error(&err);
}
};
match compression.decompress_with_limit(
encoding,
envelope.data,
limits.max_message_size,
) {
Ok(data) => data,
Err(e) => return grpc_unary_error(&e),
}
} else {
envelope.data
}
};
let deadline = metadata
.timeout
.and_then(|t| std::time::Instant::now().checked_add(t));
let ctx = Context::new(metadata.headers)
.with_deadline(deadline)
.with_extensions(extensions);
let handler_result = if let Some(timeout) = metadata.timeout {
match tokio::time::timeout(
timeout,
dispatcher.call_unary(path, ctx, request_body, codec_format),
)
.await
{
Ok(result) => result,
Err(_) => {
let err = ConnectError::deadline_exceeded("request timeout");
return grpc_unary_error(&err);
}
}
} else {
dispatcher
.call_unary(path, ctx, request_body, codec_format)
.await
};
let (response_bytes, ctx) = match handler_result {
Ok(result) => result,
Err(e) => return grpc_unary_error(&e),
};
let response_encoding = compression.negotiate_encoding(
metadata.streaming_accept_encoding.as_deref(),
metadata.streaming_encoding.as_deref(),
);
let effective_policy = compression_policy.with_override(ctx.compress_response);
let encoded_data = if let Some(encoding) = response_encoding
&& effective_policy.should_compress(response_bytes.len())
{
match compression.compress(encoding, &response_bytes) {
Ok(compressed) => Envelope::compressed(compressed).encode(),
Err(_) => Envelope::data(response_bytes).encode(),
}
} else {
Envelope::data(response_bytes).encode()
};
let grpc_trailers = build_grpc_trailers(None, &ctx.trailers);
let trailers = match protocol {
Protocol::Grpc => GrpcUnaryTrailers::Http2(grpc_trailers),
Protocol::GrpcWeb => GrpcUnaryTrailers::WebBody(encode_grpc_web_trailers(&grpc_trailers)),
Protocol::Connect => unreachable!("Connect unary uses handle_unary_request"),
};
let mut response = Response::builder().status(StatusCode::OK).header(
header::CONTENT_TYPE,
protocol.response_content_type(codec_format, true),
);
if let Some(encoding) = response_encoding {
response = response.header(protocol.content_encoding_header(), encoding);
}
let accept = compression.accept_encoding_header();
if !accept.is_empty() {
response = response.header(protocol.accept_encoding_header(), accept);
}
for (key, value) in ctx.response_headers.iter() {
response = response.header(key, value);
}
let body = GrpcUnaryBody {
data: Some(encoded_data),
trailers: Some(trailers),
};
response.body(body).unwrap_or_else(|_| {
let err = ConnectError::internal("failed to build response");
grpc_unary_error(&err)
})
}
/// How a streaming-wire-format request is dispatched once the method
/// descriptor is known: to a true server-streaming handler, or to a unary
/// handler whose single response is adapted into a one-item stream
/// (gRPC/gRPC-Web carry unary calls over the enveloped wire format).
#[derive(Clone, Copy, PartialEq, Eq)]
enum StreamingDispatchKind {
    /// Dispatch via `call_server_streaming`.
    ServerStreaming,
    /// Dispatch via `call_unary`; the response is wrapped in a one-item stream.
    Unary,
}
/// Handles a request arriving on the streaming (enveloped) wire format.
///
/// Dispatch order:
/// 1. Reject non-POST for gRPC/gRPC-Web (draining the body, bounded, so the
///    connection stays reusable).
/// 2. Reject unsupported request compression.
/// 3. Bidi and client streaming hand the *unbuffered* body to their
///    dedicated handlers so messages can be consumed incrementally.
/// 4. Server streaming (and, for gRPC/gRPC-Web, unary) buffer exactly one
///    enveloped request message, dispatch, and stream the response.
#[allow(clippy::too_many_arguments)]
async fn handle_streaming_request<D, B>(
    dispatcher: &D,
    req: Request<B>,
    protocol: Protocol,
    codec_format: CodecFormat,
    limits: Limits,
    compression: Arc<CompressionRegistry>,
    compression_policy: &CompressionPolicy,
) -> Response<StreamingResponseBody>
where
    D: Dispatcher,
    B: Body<Data = Bytes> + Send + 'static,
    B::Error: std::error::Error + Send + Sync + 'static,
{
    let path = req.uri().path();
    let path = path.strip_prefix('/').unwrap_or(path).to_owned();
    // gRPC and gRPC-Web only accept POST. Drain the body (size-limited)
    // before replying so the connection can be reused.
    if matches!(protocol, Protocol::Grpc | Protocol::GrpcWeb) && req.method() != Method::POST {
        let err = ConnectError::internal(format!("invalid method for gRPC: {}", req.method()));
        let (_parts, body) = req.into_parts();
        let limited = http_body_util::Limited::new(body, limits.max_request_body_size);
        let _ = limited.collect().await;
        return streaming_error_response(&err, protocol, codec_format);
    }
    let metadata = RequestMetadata::from_headers(req.headers(), protocol);
    // Unknown (non-identity) request compression: fail early, again draining
    // the body first.
    if let Some(ref encoding) = metadata.streaming_encoding
        && encoding != "identity"
        && !compression.supports(encoding)
    {
        let err = ConnectError::unimplemented(format!("unsupported compression: {encoding}"));
        let (_parts, body) = req.into_parts();
        let limited = http_body_util::Limited::new(body, limits.max_request_body_size);
        let _ = limited.collect().await;
        return streaming_error_response(&err, protocol, codec_format);
    }
    let method_desc = dispatcher.lookup(&path);
    let (parts, body) = req.into_parts();
    let extensions = parts.extensions;
    // Bidi/client streaming must NOT buffer the whole body: hand it off raw.
    if matches!(method_desc, Some(d) if d.kind == MethodKind::BidiStreaming) {
        return handle_bidi_streaming_request(
            dispatcher,
            &path,
            metadata,
            body,
            extensions,
            protocol,
            codec_format,
            limits,
            compression,
            compression_policy,
        )
        .await;
    }
    if matches!(method_desc, Some(d) if d.kind == MethodKind::ClientStreaming) {
        return handle_client_streaming_request(
            dispatcher,
            &path,
            metadata,
            body,
            extensions,
            protocol,
            codec_format,
            limits,
            compression,
            compression_policy,
        )
        .await;
    }
    // Remaining kinds carry exactly one request message: buffer it.
    let post_body = match collect_body_limited(body, limits.max_request_body_size).await {
        Ok(bytes) => bytes,
        Err(err) => return streaming_error_response(&err, protocol, codec_format),
    };
    let dispatch_kind = match method_desc.map(|d| d.kind) {
        Some(MethodKind::ServerStreaming) => StreamingDispatchKind::ServerStreaming,
        Some(MethodKind::Unary) => match protocol {
            // gRPC carries unary over the streaming format; Connect does not.
            Protocol::Grpc | Protocol::GrpcWeb => StreamingDispatchKind::Unary,
            Protocol::Connect => {
                let err =
                    ConnectError::invalid_argument("streaming content type used for unary method");
                return streaming_error_response(&err, protocol, codec_format);
            }
        },
        None => {
            let err = ConnectError::unimplemented(format!("method not found: {path}"));
            return streaming_error_response(&err, protocol, codec_format);
        }
        Some(MethodKind::BidiStreaming | MethodKind::ClientStreaming) => {
            unreachable!("bidi and client streaming handled before body buffering")
        }
    };
    // Decode exactly one envelope; trailing bytes or a missing/partial
    // envelope are protocol errors.
    let request_body = if post_body.is_empty() {
        let err = ConnectError::unimplemented("server streaming request requires a message");
        return streaming_error_response(&err, protocol, codec_format);
    } else {
        let mut buf = bytes::BytesMut::from(&post_body[..]);
        let envelope = match Envelope::decode_with_limit(&mut buf, limits.max_message_size) {
            Ok(Some(env)) => env,
            Ok(None) => {
                let err = ConnectError::invalid_argument("incomplete request envelope");
                return streaming_error_response(&err, protocol, codec_format);
            }
            Err(e) => {
                return streaming_error_response(&e, protocol, codec_format);
            }
        };
        if !buf.is_empty() {
            let err = ConnectError::unimplemented(
                "server streaming request must have exactly one message",
            );
            return streaming_error_response(&err, protocol, codec_format);
        }
        if envelope.is_compressed() {
            // Compressed flag requires a declared (non-identity) encoding.
            let encoding = match metadata.streaming_encoding.as_deref() {
                Some(enc) if enc != "identity" => enc,
                _ => {
                    let err = ConnectError::internal(format!(
                        "received compressed message without {} header",
                        protocol.content_encoding_header()
                    ));
                    return streaming_error_response(&err, protocol, codec_format);
                }
            };
            match compression.decompress_with_limit(
                encoding,
                envelope.data,
                limits.max_message_size,
            ) {
                Ok(data) => data,
                Err(e) => {
                    return streaming_error_response(&e, protocol, codec_format);
                }
            }
        } else {
            envelope.data
        }
    };
    // checked_add: a huge timeout must not panic, just means "no deadline".
    let deadline = metadata
        .timeout
        .and_then(|t| std::time::Instant::now().checked_add(t));
    let ctx = Context::new(metadata.headers)
        .with_deadline(deadline)
        .with_extensions(extensions);
    // Dispatch, enforcing the request timeout around the handler call.
    let (response_stream, ctx): (BoxStream<Result<Bytes, ConnectError>>, Context) =
        match dispatch_kind {
            StreamingDispatchKind::ServerStreaming => {
                let fut = dispatcher.call_server_streaming(&path, ctx, request_body, codec_format);
                let handler_result = if let Some(timeout) = metadata.timeout {
                    match tokio::time::timeout(timeout, fut).await {
                        Ok(result) => result,
                        Err(_) => {
                            let err = ConnectError::deadline_exceeded("request timeout");
                            return streaming_error_response(&err, protocol, codec_format);
                        }
                    }
                } else {
                    fut.await
                };
                match handler_result {
                    Ok(result) => result,
                    Err(e) => return streaming_error_response(&e, protocol, codec_format),
                }
            }
            StreamingDispatchKind::Unary => {
                let fut = dispatcher.call_unary(&path, ctx, request_body, codec_format);
                let handler_result = if let Some(timeout) = metadata.timeout {
                    match tokio::time::timeout(timeout, fut).await {
                        Ok(result) => result,
                        Err(_) => {
                            let err = ConnectError::deadline_exceeded("request timeout");
                            return streaming_error_response(&err, protocol, codec_format);
                        }
                    }
                } else {
                    fut.await
                };
                match handler_result {
                    Ok((response_bytes, ctx)) => {
                        // Adapt the single unary response into a stream.
                        let stream: BoxStream<Result<Bytes, ConnectError>> =
                            Box::pin(futures::stream::once(async move { Ok(response_bytes) }));
                        (stream, ctx)
                    }
                    Err(e) => return streaming_error_response(&e, protocol, codec_format),
                }
            }
        };
    // Response compression: honor the client's accept list, falling back to
    // the request encoding.
    let response_encoding = compression.negotiate_encoding(
        metadata.streaming_accept_encoding.as_deref(),
        metadata.streaming_encoding.as_deref(),
    );
    let mut response = Response::builder().status(StatusCode::OK).header(
        header::CONTENT_TYPE,
        protocol.response_content_type(codec_format, true),
    );
    if let Some(encoding) = response_encoding {
        response = response.header(protocol.content_encoding_header(), encoding);
    }
    let accept = compression.accept_encoding_header();
    if !accept.is_empty() {
        response = response.header(protocol.accept_encoding_header(), accept);
    }
    for (key, value) in ctx.response_headers.iter() {
        response = response.header(key, value);
    }
    let stream_compression = response_encoding.map(|encoding| (compression, encoding));
    // Per-call override (from the handler context) beats the global policy.
    let effective_policy = compression_policy.with_override(ctx.compress_response);
    let body = StreamingResponseBody::new(
        response_stream,
        ctx,
        protocol,
        stream_compression,
        effective_policy,
    );
    response.body(body).unwrap_or_else(|_| {
        let err = ConnectError::internal("failed to build streaming response");
        streaming_error_response(&err, protocol, codec_format)
    })
}
/// Handles a client-streaming RPC: a background reader task decodes request
/// envelopes incrementally while the handler consumes them from a
/// channel-backed stream; the single response message is sent back on the
/// streaming wire format.
#[allow(clippy::too_many_arguments)]
async fn handle_client_streaming_request<D, B>(
    dispatcher: &D,
    path: &str,
    metadata: RequestMetadata,
    body: B,
    extensions: http::Extensions,
    protocol: Protocol,
    codec_format: CodecFormat,
    limits: Limits,
    compression: Arc<CompressionRegistry>,
    compression_policy: &CompressionPolicy,
) -> Response<StreamingResponseBody>
where
    D: Dispatcher,
    B: Body<Data = Bytes> + Send + 'static,
    B::Error: std::fmt::Display + Send,
{
    // Decode envelopes off the wire on a separate task so the handler can
    // pull messages as they arrive instead of buffering the whole body.
    let (request_stream, reader_task) = spawn_body_reader(
        body,
        limits.max_message_size,
        metadata.streaming_encoding.clone(),
        Arc::clone(&compression),
    );
    let deadline = metadata
        .timeout
        .and_then(|t| std::time::Instant::now().checked_add(t));
    let ctx = Context::new(metadata.headers)
        .with_deadline(deadline)
        .with_extensions(extensions);
    let handler_result = if let Some(timeout) = metadata.timeout {
        match tokio::time::timeout(
            timeout,
            dispatcher.call_client_streaming(path, ctx, request_stream, codec_format),
        )
        .await
        {
            Ok(result) => result,
            Err(_) => {
                // Detach the reader task; once the receiver is dropped it
                // switches to (bounded) drain mode and exits on its own.
                drop(reader_task);
                let err = ConnectError::deadline_exceeded("request timeout");
                return streaming_error_response(&err, protocol, codec_format);
            }
        }
    } else {
        dispatcher
            .call_client_streaming(path, ctx, request_stream, codec_format)
            .await
    };
    let (response_bytes, ctx) = match handler_result {
        Ok(result) => result,
        Err(e) => {
            drop(reader_task);
            return streaming_error_response(&e, protocol, codec_format);
        }
    };
    let response_encoding = compression.negotiate_encoding(
        metadata.streaming_accept_encoding.as_deref(),
        metadata.streaming_encoding.as_deref(),
    );
    let mut response = Response::builder().status(StatusCode::OK).header(
        header::CONTENT_TYPE,
        protocol.response_content_type(codec_format, true),
    );
    if let Some(encoding) = response_encoding {
        response = response.header(protocol.content_encoding_header(), encoding);
    }
    let accept = compression.accept_encoding_header();
    if !accept.is_empty() {
        response = response.header(protocol.accept_encoding_header(), accept);
    }
    for (key, value) in ctx.response_headers.iter() {
        response = response.header(key, value);
    }
    let stream_compression = response_encoding.map(|encoding| (compression, encoding));
    // A client-streaming call yields exactly one response message.
    let response_stream: BoxStream<Result<Bytes, ConnectError>> =
        Box::pin(futures::stream::once(async { Ok(response_bytes) }));
    let effective_policy = compression_policy.with_override(ctx.compress_response);
    // Tie the reader task's lifetime to the response body.
    let body = StreamingResponseBody::new(
        response_stream,
        ctx,
        protocol,
        stream_compression,
        effective_policy,
    )
    .with_reader_task(reader_task);
    response.body(body).unwrap_or_else(|_| {
        let err = ConnectError::internal("failed to build client streaming response");
        streaming_error_response(&err, protocol, codec_format)
    })
}
const MAX_DRAIN_BYTES: usize = 1024 * 1024; fn spawn_body_reader<B>(
body: B,
max_message_size: usize,
streaming_encoding: Option<String>,
compression: Arc<CompressionRegistry>,
) -> (
BoxStream<Result<Bytes, ConnectError>>,
tokio::task::JoinHandle<()>,
)
where
B: Body<Data = Bytes> + Send + 'static,
B::Error: std::fmt::Display + Send,
{
let (tx, rx) = tokio::sync::mpsc::channel::<Result<Bytes, ConnectError>>(1);
let reader_task = tokio::spawn(async move {
use tokio_util::codec::Decoder as _;
let mut body = std::pin::pin!(body);
let mut decoder = crate::envelope::EnvelopeDecoder::new(
max_message_size,
streaming_encoding,
compression,
);
let mut buf = bytes::BytesMut::new();
let mut decoder_done = false;
let mut drained_bytes: usize = 0;
loop {
if !decoder_done {
match decoder.decode(&mut buf) {
Ok(Some(data)) => {
if tx.send(Ok(data)).await.is_err() {
decoder_done = true;
}
continue;
}
Ok(None) => {} Err(e) => {
let _ = tx.send(Err(e)).await;
decoder_done = true;
continue;
}
}
}
match std::future::poll_fn(|cx| body.as_mut().poll_frame(cx)).await {
Some(Ok(frame)) => {
if let Ok(data) = frame.into_data() {
if decoder_done {
drained_bytes = drained_bytes.saturating_add(data.len());
if drained_bytes > MAX_DRAIN_BYTES {
tracing::debug!(
drained_bytes,
"body drain limit reached, stopping"
);
break;
}
} else {
buf.extend_from_slice(&data);
}
}
}
Some(Err(_)) => break, None => {
if !decoder_done {
loop {
match decoder.decode_eof(&mut buf) {
Ok(Some(data)) => {
if tx.send(Ok(data)).await.is_err() {
break;
}
}
Ok(None) => break,
Err(e) => {
let _ = tx.send(Err(e)).await;
break;
}
}
}
}
break;
}
}
}
});
let request_stream: BoxStream<Result<Bytes, ConnectError>> =
Box::pin(futures::stream::unfold(rx, |mut rx| async {
rx.recv().await.map(|item| (item, rx))
}));
(request_stream, reader_task)
}
/// Handles a bidi-streaming RPC: request envelopes are decoded incrementally
/// by a background reader task while the handler produces its own response
/// stream; both directions flow concurrently through the response body.
#[allow(clippy::too_many_arguments)]
async fn handle_bidi_streaming_request<D, B>(
    dispatcher: &D,
    path: &str,
    metadata: RequestMetadata,
    body: B,
    extensions: http::Extensions,
    protocol: Protocol,
    codec_format: CodecFormat,
    limits: Limits,
    compression: Arc<CompressionRegistry>,
    compression_policy: &CompressionPolicy,
) -> Response<StreamingResponseBody>
where
    D: Dispatcher,
    B: Body<Data = Bytes> + Send + 'static,
    B::Error: std::fmt::Display + Send,
{
    // Incremental envelope decoding on a separate task.
    let (request_stream, reader_task) = spawn_body_reader(
        body,
        limits.max_message_size,
        metadata.streaming_encoding.clone(),
        Arc::clone(&compression),
    );
    let deadline = metadata
        .timeout
        .and_then(|t| std::time::Instant::now().checked_add(t));
    let ctx = Context::new(metadata.headers)
        .with_deadline(deadline)
        .with_extensions(extensions);
    // NOTE: the timeout only bounds obtaining the response *stream* from the
    // handler, not the subsequent streaming of responses.
    let handler_result = if let Some(timeout) = metadata.timeout {
        match tokio::time::timeout(
            timeout,
            dispatcher.call_bidi_streaming(path, ctx, request_stream, codec_format),
        )
        .await
        {
            Ok(result) => result,
            Err(_) => {
                let err = ConnectError::deadline_exceeded("request timeout");
                // Detach the reader task; it drains (bounded) once the
                // receiver is gone.
                drop(reader_task);
                return streaming_error_response(&err, protocol, codec_format);
            }
        }
    } else {
        dispatcher
            .call_bidi_streaming(path, ctx, request_stream, codec_format)
            .await
    };
    let (response_stream, ctx) = match handler_result {
        Ok(result) => result,
        Err(e) => {
            drop(reader_task);
            return streaming_error_response(&e, protocol, codec_format);
        }
    };
    let response_encoding = compression.negotiate_encoding(
        metadata.streaming_accept_encoding.as_deref(),
        metadata.streaming_encoding.as_deref(),
    );
    let mut response = Response::builder().status(StatusCode::OK).header(
        header::CONTENT_TYPE,
        protocol.response_content_type(codec_format, true),
    );
    if let Some(encoding) = response_encoding {
        response = response.header(protocol.content_encoding_header(), encoding);
    }
    let accept = compression.accept_encoding_header();
    if !accept.is_empty() {
        response = response.header(protocol.accept_encoding_header(), accept);
    }
    for (key, value) in ctx.response_headers.iter() {
        response = response.header(key, value);
    }
    let stream_compression = response_encoding.map(|encoding| (compression, encoding));
    let effective_policy = compression_policy.with_override(ctx.compress_response);
    // Tie the reader task's lifetime to the response body so request
    // decoding continues while responses stream out.
    let body = StreamingResponseBody::new(
        response_stream,
        ctx,
        protocol,
        stream_compression,
        effective_policy,
    )
    .with_reader_task(reader_task);
    response.body(body).unwrap_or_else(|_| {
        let err = ConnectError::internal("failed to build bidi streaming response");
        streaming_error_response(&err, protocol, codec_format)
    })
}
/// Copies `trailers` onto the response builder as `trailer-`-prefixed
/// headers — the Connect unary convention for conveying trailer metadata
/// via ordinary HTTP headers.
fn add_trailers(
    response: http::response::Builder,
    trailers: &http::HeaderMap,
) -> http::response::Builder {
    trailers.iter().fold(response, |builder, (name, value)| {
        builder.header(format!("trailer-{}", name.as_str()), value)
    })
}
use crate::protocol::hdr::GRPC_MESSAGE;
use crate::protocol::hdr::GRPC_STATUS;
use crate::protocol::hdr::GRPC_STATUS_DETAILS_BIN;
fn build_grpc_trailers(
error: Option<&ConnectError>,
custom_trailers: &http::HeaderMap,
) -> http::HeaderMap {
let mut trailers = http::HeaderMap::new();
match error {
Some(err) => {
trailers.insert(&GRPC_STATUS, http::HeaderValue::from(err.code.grpc_code()));
if let Some(val) = err
.message
.as_deref()
.and_then(|m| http::HeaderValue::from_str(&grpc_percent_encode(m)).ok())
{
trailers.insert(&GRPC_MESSAGE, val);
}
{
use base64::Engine;
let status_bytes = crate::grpc_status::encode(err);
let b64 = base64::engine::general_purpose::STANDARD_NO_PAD.encode(&status_bytes);
if let Ok(val) = http::HeaderValue::from_str(&b64) {
trailers.insert(&GRPC_STATUS_DETAILS_BIN, val);
}
}
for (key, value) in err.trailers.iter() {
trailers.append(key, value.clone());
}
}
None => {
trailers.insert(&GRPC_STATUS, http::HeaderValue::from_static("0"));
}
}
let error_has_own_trailers = error.is_some_and(|e| !e.trailers.is_empty());
if !error_has_own_trailers {
for (key, value) in custom_trailers.iter() {
trailers.append(key, value.clone());
}
}
trailers
}
/// ASCII characters percent-encoded in `grpc-message`: the control
/// characters plus '%' itself (so the encoding is reversible). Non-ASCII
/// input is also percent-encoded by `utf8_percent_encode` — see the
/// `test_grpc_percent_encode_non_ascii` test ("café" -> "caf%C3%A9").
const GRPC_MESSAGE_ENCODE_SET: &percent_encoding::AsciiSet = &percent_encoding::CONTROLS.add(b'%');
/// Percent-encodes a status message for the `grpc-message` trailer.
fn grpc_percent_encode(message: &str) -> String {
    percent_encoding::utf8_percent_encode(message, GRPC_MESSAGE_ENCODE_SET).to_string()
}
/// Builds a Connect unary error response and wraps its body in the
/// `ConnectRpcBody::Full` variant.
fn error_response_either(err: ConnectError) -> Response<ConnectRpcBody> {
    let response = error_response(err);
    response.map(|body| ConnectRpcBody::Full(body))
}
/// Builds a Connect-protocol unary error response: JSON body, the error's
/// mapped HTTP status, its response headers, and its trailers conveyed as
/// `trailer-`-prefixed headers. Falls back to a bare 500 if the builder
/// rejects any header.
fn error_response(err: ConnectError) -> Response<Full<Bytes>> {
    let mut builder = Response::builder()
        .status(err.http_status())
        .header(header::CONTENT_TYPE, content_type::JSON);
    for (name, value) in err.response_headers.iter() {
        builder = builder.header(name, value);
    }
    builder = add_trailers(builder, &err.trailers);
    let payload = err.to_json();
    match builder.body(Full::new(payload)) {
        Ok(response) => response,
        Err(_) => Response::builder()
            .status(StatusCode::INTERNAL_SERVER_ERROR)
            .body(Full::new(Bytes::new()))
            .unwrap(),
    }
}
/// Optional axum glue (behind the `axum` feature): adapters for mounting a
/// `Router` as an axum service and for converting `ConnectError` into an
/// axum response.
#[cfg(feature = "axum")]
pub mod axum_integration {
    use super::*;
    use axum::body::Body;
    use axum::response::IntoResponse;
    impl Router {
        /// Wraps this router in a `ConnectRpcService` usable as a tower
        /// service within axum.
        pub fn into_axum_service(self) -> ConnectRpcService {
            ConnectRpcService::new(self)
        }
        /// Builds a standalone `axum::Router` that routes every unmatched
        /// request to this RPC service.
        pub fn into_axum_router(self) -> axum::Router {
            let service = ConnectRpcService::new(self);
            axum::Router::new().fallback_service(service)
        }
    }
    impl IntoResponse for ConnectError {
        /// Renders the error as a JSON response with its mapped HTTP status.
        /// Unlike `error_response`, this does not attach response headers or
        /// `trailer-` headers from the error.
        fn into_response(self) -> axum::response::Response {
            let status = self.http_status();
            let body = self.to_json();
            let mut response = axum::response::Response::new(Body::from(body));
            *response.status_mut() = status;
            response.headers_mut().insert(
                header::CONTENT_TYPE,
                http::HeaderValue::from_static(content_type::JSON),
            );
            response
        }
    }
}
#[cfg(test)]
mod tests {
use super::*;
    // Construction / clone smoke tests for the service wrapper.
    #[test]
    fn test_service_creation() {
        let router = Router::new();
        let _service = ConnectRpcService::new(router);
    }
    #[test]
    fn test_service_clone() {
        let router = Router::new();
        let service = ConnectRpcService::new(router);
        let _cloned = service.clone();
    }
    // collect_body_limited: under, exactly at, and over the byte limit.
    #[tokio::test]
    async fn test_collect_body_limited_under_limit() {
        let body = Full::new(Bytes::from_static(b"hello"));
        let result = collect_body_limited(body, 1024).await.unwrap();
        assert_eq!(&result[..], b"hello");
    }
    #[tokio::test]
    async fn test_collect_body_limited_exact_limit() {
        // A body exactly at the limit must be accepted (inclusive bound).
        let body = Full::new(Bytes::from_static(b"hello"));
        let result = collect_body_limited(body, 5).await.unwrap();
        assert_eq!(&result[..], b"hello");
    }
    #[tokio::test]
    async fn test_collect_body_limited_over_limit() {
        let body = Full::new(Bytes::from_static(b"hello world"));
        let err = collect_body_limited(body, 5).await.unwrap_err();
        assert_eq!(err.code, crate::error::ErrorCode::ResourceExhausted);
        assert!(err.message.as_deref().unwrap().contains("limit 5"));
    }
#[test]
fn test_parse_get_query_params_basic() {
let params = parse_get_query_params(Some("message=%7B%7D&encoding=json&connect=v1"))
.expect("should parse");
assert_eq!(params.message, Some("%7B%7D".to_string()));
assert_eq!(params.encoding, Some("json".to_string()));
assert_eq!(params.connect_version, Some("v1".to_string()));
assert!(!params.base64);
assert!(params.compression.is_none());
}
#[test]
fn test_parse_get_query_params_with_base64() {
let params = parse_get_query_params(Some("message=e30&encoding=proto&base64=1&connect=v1"))
.expect("should parse");
assert_eq!(params.message, Some("e30".to_string()));
assert_eq!(params.encoding, Some("proto".to_string()));
assert!(params.base64);
}
#[test]
fn test_parse_get_query_params_with_compression() {
let params =
parse_get_query_params(Some("message=abc&encoding=json&compression=gzip&base64=1"))
.expect("should parse");
assert_eq!(params.compression, Some("gzip".to_string()));
}
#[test]
fn test_parse_get_query_params_missing_encoding() {
let result = parse_get_query_params(Some("message=test&connect=v1"));
assert!(result.is_err());
}
#[test]
fn test_parse_get_query_params_no_query() {
let result = parse_get_query_params(None);
assert!(result.is_err());
}
#[test]
fn test_percent_decode_basic() {
let decoded = percent_decode("%7B%22name%22%3A%22test%22%7D").expect("should decode");
assert_eq!(decoded, b"{\"name\":\"test\"}");
}
#[test]
fn test_percent_decode_plus_as_space() {
let decoded = percent_decode("hello+world").expect("should decode");
assert_eq!(decoded, b"hello world");
}
#[test]
fn test_percent_decode_passthrough() {
let decoded = percent_decode("hello").expect("should decode");
assert_eq!(decoded, b"hello");
}
#[test]
fn test_decode_get_message_json() {
let params = GetQueryParams {
message: Some("%7B%7D".to_string()),
encoding: Some("json".to_string()),
base64: false,
compression: None,
connect_version: Some("v1".to_string()),
};
let compression = CompressionRegistry::default();
let result = decode_get_message(¶ms, &compression, 1024 * 1024).expect("should decode");
assert_eq!(result.as_ref(), b"{}");
}
#[test]
fn test_decode_get_message_base64() {
let params = GetQueryParams {
message: Some("e30".to_string()), encoding: Some("json".to_string()),
base64: true,
compression: None,
connect_version: Some("v1".to_string()),
};
let compression = CompressionRegistry::default();
let result = decode_get_message(¶ms, &compression, 1024 * 1024).expect("should decode");
assert_eq!(result.as_ref(), b"{}");
}
#[test]
fn test_decode_get_message_empty() {
let params = GetQueryParams {
message: None,
encoding: Some("json".to_string()),
base64: false,
compression: None,
connect_version: Some("v1".to_string()),
};
let compression = CompressionRegistry::default();
let result = decode_get_message(¶ms, &compression, 1024 * 1024).expect("should decode");
assert!(result.is_empty());
}
#[test]
fn test_parse_timeout_connect_milliseconds() {
assert_eq!(
parse_timeout("5000", Protocol::Connect),
Some(Duration::from_millis(5000))
);
}
#[test]
fn test_parse_timeout_connect_zero() {
assert_eq!(
parse_timeout("0", Protocol::Connect),
Some(Duration::from_millis(0))
);
}
#[test]
fn test_parse_timeout_connect_invalid() {
assert_eq!(parse_timeout("abc", Protocol::Connect), None);
assert_eq!(parse_timeout("", Protocol::Connect), None);
}
#[test]
fn test_parse_timeout_grpc_hours() {
assert_eq!(
parse_timeout("1H", Protocol::Grpc),
Some(Duration::from_secs(3600))
);
}
#[test]
fn test_parse_timeout_grpc_minutes() {
assert_eq!(
parse_timeout("5M", Protocol::Grpc),
Some(Duration::from_secs(300))
);
}
#[test]
fn test_parse_timeout_grpc_seconds() {
assert_eq!(
parse_timeout("30S", Protocol::Grpc),
Some(Duration::from_secs(30))
);
}
#[test]
fn test_parse_timeout_grpc_milliseconds() {
assert_eq!(
parse_timeout("500m", Protocol::Grpc),
Some(Duration::from_millis(500))
);
}
#[test]
fn test_parse_timeout_grpc_microseconds() {
assert_eq!(
parse_timeout("100u", Protocol::Grpc),
Some(Duration::from_micros(100))
);
}
#[test]
fn test_parse_timeout_grpc_nanoseconds() {
assert_eq!(
parse_timeout("999n", Protocol::Grpc),
Some(Duration::from_nanos(999))
);
}
#[test]
fn test_parse_timeout_grpc_zero() {
assert_eq!(
parse_timeout("0S", Protocol::Grpc),
Some(Duration::from_secs(0))
);
}
#[test]
fn test_parse_timeout_grpc_invalid_unit() {
assert_eq!(parse_timeout("5X", Protocol::Grpc), None);
}
#[test]
fn test_parse_timeout_grpc_no_digits() {
assert_eq!(parse_timeout("H", Protocol::Grpc), None);
}
#[test]
fn test_parse_timeout_grpc_empty() {
assert_eq!(parse_timeout("", Protocol::Grpc), None);
}
#[test]
fn test_parse_timeout_grpc_over_8_digits_rejected() {
assert_eq!(parse_timeout("123456789S", Protocol::Grpc), None);
let huge = format!("{}S", u64::MAX);
assert_eq!(parse_timeout(&huge, Protocol::Grpc), None);
assert_eq!(
parse_timeout("99999999S", Protocol::Grpc),
Some(Duration::from_secs(99_999_999))
);
let d = parse_timeout("99999999H", Protocol::Grpc).unwrap();
assert!(std::time::Instant::now().checked_add(d).is_some());
}
#[test]
fn test_parse_timeout_connect_over_10_digits_rejected() {
assert_eq!(parse_timeout("12345678901", Protocol::Connect), None);
let huge = format!("{}", u64::MAX);
assert_eq!(parse_timeout(&huge, Protocol::Connect), None);
assert_eq!(
parse_timeout("9999999999", Protocol::Connect),
Some(Duration::from_millis(9_999_999_999))
);
let d = parse_timeout("9999999999", Protocol::Connect).unwrap();
assert!(std::time::Instant::now().checked_add(d).is_some());
}
#[test]
fn test_parse_timeout_grpc_non_ascii_rejected() {
assert_eq!(parse_timeout("5é", Protocol::Grpc), None);
assert_eq!(parse_timeout("é5m", Protocol::Grpc), None);
assert_eq!(parse_timeout("5☺", Protocol::Grpc), None);
assert_eq!(parse_timeout("5é", Protocol::Connect), None);
}
#[test]
fn test_parse_timeout_grpc_web_same_as_grpc() {
assert_eq!(
parse_timeout("500m", Protocol::GrpcWeb),
Some(Duration::from_millis(500))
);
}
#[test]
fn test_grpc_percent_encode_passthrough() {
assert_eq!(grpc_percent_encode("hello world"), "hello world");
assert_eq!(grpc_percent_encode("a-b_c.d"), "a-b_c.d");
}
#[test]
fn test_grpc_percent_encode_percent() {
assert_eq!(grpc_percent_encode("100%"), "100%25");
}
#[test]
fn test_grpc_percent_encode_non_ascii() {
assert_eq!(grpc_percent_encode("café"), "caf%C3%A9");
}
#[test]
fn test_grpc_percent_encode_control_chars() {
assert_eq!(grpc_percent_encode("a\nb"), "a%0Ab");
assert_eq!(grpc_percent_encode("a\tb"), "a%09b");
}
#[test]
fn test_build_grpc_trailers_success() {
let custom = http::HeaderMap::new();
let trailers = build_grpc_trailers(None, &custom);
assert_eq!(trailers.get(&GRPC_STATUS).unwrap().to_str().unwrap(), "0");
assert!(!trailers.contains_key(&GRPC_MESSAGE));
}
#[test]
fn test_build_grpc_trailers_error() {
let err = ConnectError::not_found("thing not found");
let custom = http::HeaderMap::new();
let trailers = build_grpc_trailers(Some(&err), &custom);
assert_eq!(
trailers.get(&GRPC_STATUS).unwrap().to_str().unwrap(),
"5" );
assert_eq!(
trailers.get(&GRPC_MESSAGE).unwrap().to_str().unwrap(),
"thing not found"
);
assert!(trailers.contains_key("grpc-status-details-bin"));
}
#[test]
fn test_build_grpc_trailers_custom_metadata() {
let mut custom = http::HeaderMap::new();
custom.insert("x-custom", http::HeaderValue::from_static("value1"));
custom.append("x-custom", http::HeaderValue::from_static("value2"));
let trailers = build_grpc_trailers(None, &custom);
assert_eq!(trailers.get(&GRPC_STATUS).unwrap().to_str().unwrap(), "0");
let values: Vec<_> = trailers
.get_all("x-custom")
.iter()
.map(|v| v.to_str().unwrap())
.collect();
assert_eq!(values, vec!["value1", "value2"]);
}
#[test]
fn test_build_grpc_trailers_dedup_error_trailers() {
let mut err = ConnectError::internal("error");
err.trailers
.insert("x-trailer", http::HeaderValue::from_static("from-error"));
let mut custom = http::HeaderMap::new();
custom.insert("x-trailer", http::HeaderValue::from_static("from-context"));
let trailers = build_grpc_trailers(Some(&err), &custom);
let values: Vec<_> = trailers
.get_all("x-trailer")
.iter()
.map(|v| v.to_str().unwrap())
.collect();
assert_eq!(values, vec!["from-error"]);
}
#[test]
fn test_build_grpc_trailers_no_dedup_when_error_has_no_trailers() {
let err = ConnectError::internal("error");
let mut custom = http::HeaderMap::new();
custom.insert("x-trailer", http::HeaderValue::from_static("from-context"));
let trailers = build_grpc_trailers(Some(&err), &custom);
assert_eq!(
trailers.get("x-trailer").unwrap().to_str().unwrap(),
"from-context"
);
}
#[test]
fn test_encode_grpc_web_trailers() {
let mut headers = http::HeaderMap::new();
headers.insert(&GRPC_STATUS, http::HeaderValue::from_static("0"));
let frame = encode_grpc_web_trailers(&headers);
assert_eq!(frame[0], 0x80);
let len = u32::from_be_bytes([frame[1], frame[2], frame[3], frame[4]]) as usize;
assert_eq!(frame.len(), 5 + len);
let payload = std::str::from_utf8(&frame[5..]).unwrap();
assert!(payload.contains("grpc-status: 0\r\n"));
}
#[test]
fn test_encode_grpc_web_trailers_multi_header() {
let mut headers = http::HeaderMap::new();
headers.insert(&GRPC_STATUS, http::HeaderValue::from_static("13"));
headers.insert(&GRPC_MESSAGE, http::HeaderValue::from_static("internal"));
headers.insert(
"grpc-status-details-bin",
http::HeaderValue::from_static("abc123"),
);
let frame = encode_grpc_web_trailers(&headers);
assert_eq!(frame[0], 0x80);
let len = u32::from_be_bytes([frame[1], frame[2], frame[3], frame[4]]) as usize;
let payload = std::str::from_utf8(&frame[5..5 + len]).unwrap();
assert!(payload.contains("grpc-status: 13\r\n"));
assert!(payload.contains("grpc-message: internal\r\n"));
assert!(payload.contains("grpc-status-details-bin: abc123\r\n"));
}
#[test]
fn test_encode_grpc_status_details_basic() {
let err = ConnectError::internal("test error");
let bytes = crate::grpc_status::encode(&err);
assert!(bytes.len() > 2);
assert_eq!(bytes[0], 8); assert_eq!(bytes[1], 13); }
fn collect_frames(mut stream: BatchingEnvelopeStream) -> (Vec<Bytes>, Option<Frame<Bytes>>) {
use futures::task::noop_waker_ref;
let mut cx = std::task::Context::from_waker(noop_waker_ref());
let mut data_frames = Vec::new();
let mut terminal = None;
loop {
match Pin::new(&mut stream).poll_next(&mut cx) {
Poll::Ready(Some(Ok(f))) if f.is_data() => {
data_frames.push(f.into_data().unwrap());
}
Poll::Ready(Some(Ok(f))) => {
terminal = Some(f);
}
Poll::Ready(None) => break,
Poll::Pending => panic!("synchronous source should never be Pending"),
}
}
(data_frames, terminal)
}
#[test]
fn batching_sync_source_one_data_frame() {
    // Ten immediately-ready small messages should coalesce into a single
    // batched data frame rather than ten tiny ones.
    let messages: Vec<Result<Bytes, ConnectError>> = std::iter::repeat_with(|| Ok(Bytes::from_static(b"msg")))
        .take(10)
        .collect();
    let source: BoxStream<_> = Box::pin(futures::stream::iter(messages));

    let stream = BatchingEnvelopeStream::new(
        source,
        http::HeaderMap::new(),
        None,
        CompressionPolicy::default(),
        StreamFinalizer::GrpcTrailers,
    );
    let (data_frames, terminal) = collect_frames(stream);

    assert_eq!(
        data_frames.len(),
        1,
        "10 synchronous items should produce 1 batched data frame, got {}",
        data_frames.len()
    );
    // 10 messages x (5-byte envelope prefix + 3-byte payload) = 80 bytes.
    assert_eq!(data_frames[0].len(), 80);
    assert!(terminal.is_some());
    assert!(terminal.unwrap().is_trailers());
}
#[test]
fn batching_threshold_splits_frames() {
    // Four 9 KiB payloads overflow the batching threshold, so the output
    // must be split across two data frames instead of one.
    let chunk = Bytes::from(vec![b'x'; 9 * 1024]);
    let items: Vec<Result<Bytes, ConnectError>> =
        std::iter::repeat_with(|| Ok(chunk.clone())).take(4).collect();
    let source: BoxStream<_> = Box::pin(futures::stream::iter(items));

    let stream = BatchingEnvelopeStream::new(
        source,
        http::HeaderMap::new(),
        None,
        CompressionPolicy::default(),
        StreamFinalizer::GrpcTrailers,
    );
    let (data_frames, terminal) = collect_frames(stream);

    assert_eq!(data_frames.len(), 2);
    assert!(terminal.unwrap().is_trailers());
}
#[test]
fn batching_empty_source_just_finalizer() {
    // A source with no messages at all should still produce the finalizer
    // (trailers frame) and nothing else.
    let empty_source: BoxStream<_> = Box::pin(futures::stream::empty());
    let stream = BatchingEnvelopeStream::new(
        empty_source,
        http::HeaderMap::new(),
        None,
        CompressionPolicy::default(),
        StreamFinalizer::GrpcTrailers,
    );

    let (data_frames, terminal) = collect_frames(stream);
    assert!(data_frames.is_empty());
    assert!(terminal.unwrap().is_trailers());
}
#[test]
fn batching_connect_finalizer_is_data_frame() {
    // Connect streaming finishes with an EndStreamResponse carried in a
    // DATA frame (flagged END_STREAM) — not in HTTP trailers.
    let single_msg: BoxStream<_> =
        Box::pin(futures::stream::once(async { Ok(Bytes::from_static(b"x")) }));
    let stream = BatchingEnvelopeStream::new(
        single_msg,
        http::HeaderMap::new(),
        None,
        CompressionPolicy::default(),
        StreamFinalizer::ConnectEndStream,
    );

    let (data_frames, terminal) = collect_frames(stream);

    // One payload frame plus the end-of-stream frame; no trailers frame.
    assert_eq!(data_frames.len(), 2);
    assert!(terminal.is_none());
    let end_frame = &data_frames[1];
    assert_eq!(end_frame[0], crate::envelope::flags::END_STREAM);
}
#[test]
fn batching_error_after_items_stages_final() {
    // Three good messages followed by an error: the good ones are batched
    // into one data frame, and the error surfaces via a non-OK grpc-status
    // in the trailers.
    let items: Vec<Result<Bytes, ConnectError>> = vec![
        Ok(Bytes::from_static(b"a")),
        Ok(Bytes::from_static(b"b")),
        Ok(Bytes::from_static(b"c")),
        Err(ConnectError::internal("boom")),
    ];
    let source: BoxStream<_> = Box::pin(futures::stream::iter(items));
    let stream = BatchingEnvelopeStream::new(
        source,
        http::HeaderMap::new(),
        None,
        CompressionPolicy::default(),
        StreamFinalizer::GrpcTrailers,
    );

    let (data_frames, terminal) = collect_frames(stream);

    assert_eq!(data_frames.len(), 1);
    // 3 messages x (5-byte envelope prefix + 1-byte payload).
    assert_eq!(data_frames[0].len(), 3 * (5 + 1));
    let trailers = terminal.unwrap().into_trailers().unwrap();
    let status = trailers.get("grpc-status").unwrap().to_str().unwrap();
    assert_ne!(status, "0");
}
#[test]
fn end_stream_error_includes_error_trailers() {
    // When the error carries its own trailers, those win outright: the
    // context-level trailers must NOT be merged into the metadata.
    let mut own_trailers = http::HeaderMap::new();
    own_trailers.insert("x-error-info", "from-err".parse().unwrap());
    let err = ConnectError::internal("boom").with_trailers(own_trailers);

    let mut ctx_trailers = http::HeaderMap::new();
    ctx_trailers.insert("x-ctx", "from-context".parse().unwrap());

    let end = EndStreamResponse::error(&err, &ctx_trailers);
    let metadata = end.metadata.expect("metadata should be Some");

    assert!(
        metadata.contains_key("x-error-info"),
        "error-level trailer should be in metadata: {metadata:?}"
    );
    assert!(
        !metadata.contains_key("x-ctx"),
        "context trailer should NOT be in metadata when err has own trailers: {metadata:?}"
    );
}
#[test]
fn end_stream_error_falls_back_to_context_trailers() {
    // An error without trailers of its own inherits the context trailers.
    let plain_err = ConnectError::internal("boom");
    let mut ctx_trailers = http::HeaderMap::new();
    ctx_trailers.insert("x-ctx", "from-context".parse().unwrap());

    let end = EndStreamResponse::error(&plain_err, &ctx_trailers);
    let metadata = end.metadata.expect("metadata should be Some");
    assert!(metadata.contains_key("x-ctx"));
}
#[test]
fn end_stream_error_details_include_debug_field() {
    // The serialized end-stream error must round-trip the detail's type
    // URL, base64 value, and the free-form `debug` JSON payload.
    let detail = crate::error::ErrorDetail {
        type_url: "test.Detail".into(),
        value: Some("YmFzZTY0".into()),
        debug: Some(serde_json::json!({"hint": "turn it off and on again"})),
    };
    let err = ConnectError::internal("boom").with_detail(detail);
    let end = EndStreamResponse::error(&err, &http::HeaderMap::new());
    let json = serde_json::to_string(&end).unwrap();

    for (needle, what) in [
        ("\"type\":\"test.Detail\"", "type"),
        ("\"value\":\"YmFzZTY0\"", "value"),
        ("\"debug\":", "debug field"),
        ("turn it off and on again", "debug content"),
    ] {
        assert!(json.contains(needle), "{what} missing: {json}");
    }
}
#[tokio::test]
async fn extensions_flow_to_handler_context() {
    use std::sync::Mutex;

    // Marker type smuggled through http::Request extensions, standing in
    // for connection metadata such as a peer address.
    #[derive(Clone, Debug, PartialEq)]
    struct PeerTag(&'static str);

    // Slot the handler writes into so the test can observe what it saw.
    let captured = Arc::new(Mutex::new(None::<PeerTag>));
    let handler_captured = Arc::clone(&captured);

    let router = Router::new().route(
        "svc",
        "Method",
        crate::handler_fn(move |ctx: Context, _req: buffa_types::Empty| {
            let slot = Arc::clone(&handler_captured);
            async move {
                // Record whatever PeerTag made it into the handler context.
                *slot.lock().unwrap() = ctx.extensions.get::<PeerTag>().cloned();
                Ok((buffa_types::Empty::default(), ctx))
            }
        }),
    );

    let mut request = Request::builder()
        .method(Method::POST)
        .uri("/svc/Method")
        .header(header::CONTENT_TYPE, "application/proto")
        .body(Full::new(Bytes::new()))
        .unwrap();
    request.extensions_mut().insert(PeerTag("10.0.0.1:54321"));

    handle_unary_request(
        &router,
        request,
        Limits::default(),
        Arc::new(CompressionRegistry::new()),
        &CompressionPolicy::default(),
    )
    .await
    .expect("dispatch should succeed");

    assert_eq!(
        captured.lock().unwrap().take(),
        Some(PeerTag("10.0.0.1:54321")),
        "extension inserted on the http::Request must reach Context.extensions"
    );
}
}