use std::marker::PhantomData;
use std::pin::Pin;
use std::time::{Duration, Instant};
use buffa::Message;
use buffa::view::{MessageView, ViewEncode};
use bytes::Bytes;
use futures::Stream;
use http::HeaderMap;
use http::header::{HeaderName, HeaderValue};
use serde::Serialize;
use crate::codec::CodecFormat;
use crate::error::ConnectError;
/// Per-request metadata made available to a handler.
///
/// Carries the incoming request headers plus optional deadline, extension,
/// spec, protocol and path information. Fields are crate-private: read them
/// through the accessor methods and populate them with the `with_*` builders.
#[derive(Debug, Clone, Default)]
#[non_exhaustive]
pub struct RequestContext {
    pub(crate) headers: HeaderMap,
    pub(crate) deadline: Option<Instant>,
    pub(crate) extensions: http::Extensions,
    pub(crate) spec: Option<crate::spec::Spec>,
    pub(crate) protocol: Option<crate::Protocol>,
    pub(crate) path: Option<String>,
}

impl RequestContext {
    /// Builds a context holding `headers`; every other field starts unset
    /// (no deadline, empty extensions, no spec, protocol or path).
    pub fn new(headers: HeaderMap) -> Self {
        Self {
            headers,
            ..Default::default()
        }
    }

    /// Replaces the deadline (`None` clears it).
    #[must_use]
    pub fn with_deadline(self, deadline: Option<Instant>) -> Self {
        Self { deadline, ..self }
    }

    /// Replaces the whole extensions map.
    #[must_use]
    pub fn with_extensions(self, extensions: http::Extensions) -> Self {
        Self { extensions, ..self }
    }

    /// Replaces the method spec (`None` clears it).
    #[must_use]
    pub fn with_spec(self, spec: Option<crate::spec::Spec>) -> Self {
        Self { spec, ..self }
    }

    /// Replaces the wire protocol (`None` clears it).
    #[must_use]
    pub fn with_protocol(self, protocol: Option<crate::Protocol>) -> Self {
        Self { protocol, ..self }
    }

    /// Sets the request path. Any string-like value is accepted, including "".
    #[must_use]
    pub fn with_path(self, path: impl Into<String>) -> Self {
        Self {
            path: Some(path.into()),
            ..self
        }
    }

    /// All request headers.
    pub fn headers(&self) -> &HeaderMap {
        &self.headers
    }

    /// First value of the header `key`, if present.
    pub fn header(&self, key: impl http::header::AsHeaderName) -> Option<&HeaderValue> {
        self.headers.get(key)
    }

    /// The absolute deadline, if one was set.
    pub fn deadline(&self) -> Option<Instant> {
        self.deadline
    }

    /// Time left until the deadline, measured now. It is clamped at zero once the
    /// deadline has passed. Returns `None` when no deadline is set.
    pub fn time_remaining(&self) -> Option<Duration> {
        let deadline = self.deadline?;
        Some(deadline.saturating_duration_since(Instant::now()))
    }

    /// Shared access to the request extensions.
    pub fn extensions(&self) -> &http::Extensions {
        &self.extensions
    }

    /// Mutable access to the request extensions.
    pub fn extensions_mut(&mut self) -> &mut http::Extensions {
        &mut self.extensions
    }

    /// The method spec, if one was attached.
    pub fn spec(&self) -> Option<crate::spec::Spec> {
        self.spec
    }

    /// The wire protocol, if one was attached.
    pub fn protocol(&self) -> Option<crate::Protocol> {
        self.protocol
    }

    /// The request path, if one was attached.
    pub fn path(&self) -> Option<&str> {
        self.path.as_deref()
    }

    /// Remote socket address, read from a `PeerAddr` stored in the extensions.
    #[cfg(feature = "server")]
    #[cfg_attr(docsrs, doc(cfg(feature = "server")))]
    pub fn peer_addr(&self) -> Option<std::net::SocketAddr> {
        let peer = self.extensions.get::<crate::server::PeerAddr>()?;
        Some(peer.0)
    }

    /// Client certificates, read from a `PeerCerts` stored in the extensions.
    #[cfg(feature = "server-tls")]
    #[cfg_attr(docsrs, doc(cfg(feature = "server-tls")))]
    pub fn peer_certs(&self) -> Option<&[rustls::pki_types::CertificateDer<'static>]> {
        let certs = self.extensions.get::<crate::server::PeerCerts>()?;
        Some(&certs.0[..])
    }
}
/// A handler response: a body plus the metadata sent alongside it.
///
/// Build one with [`Response::new`] (or `From<B>`), then chain
/// [`with_header`](Response::with_header), [`with_trailer`](Response::with_trailer)
/// and [`compress`](Response::compress).
#[derive(Debug, Clone)]
pub struct Response<B> {
    /// The response body.
    pub body: B,
    /// Response headers.
    pub headers: HeaderMap,
    /// Response trailers.
    pub trailers: HeaderMap,
    /// Per-response compression override. `Some(true)` forces compression on and
    /// `Some(false)` forces it off. `None` states no preference for this response.
    pub compress: Option<bool>,
}

impl<B> Response<B> {
    /// Shorthand for `Ok(Response::new(body))`.
    pub fn ok(body: B) -> ServiceResult<B> {
        Ok(Self::from(body))
    }

    /// Creates a response with `body`, empty headers and trailers, and no
    /// compression override.
    pub fn new(body: B) -> Self {
        Self {
            body,
            headers: HeaderMap::new(),
            trailers: HeaderMap::new(),
            compress: None,
        }
    }

    /// Appends a response header. A repeated name adds another value rather than
    /// replacing the existing one.
    ///
    /// Prefer [`try_with_header`](Response::try_with_header) when `name` or
    /// `value` are not known to be valid, for example when derived from input.
    ///
    /// # Panics
    ///
    /// Panics if `name` is not a valid header name or `value` is not a valid
    /// header value.
    #[must_use]
    pub fn with_header<K, V>(mut self, name: K, value: V) -> Self
    where
        K: TryInto<HeaderName>,
        K::Error: std::fmt::Debug,
        V: TryInto<HeaderValue>,
        V::Error: std::fmt::Debug,
    {
        // Name is converted before value, matching `try_with_header`.
        let name = name
            .try_into()
            .expect("Response::with_header: invalid header name");
        let value = value
            .try_into()
            .expect("Response::with_header: invalid header value");
        self.headers.append(name, value);
        self
    }

    /// Fallible variant of [`with_header`](Response::with_header).
    ///
    /// # Errors
    ///
    /// Returns the conversion error as an [`http::Error`] if `name` or `value` is
    /// invalid. The response is consumed in that case.
    pub fn try_with_header<K, V>(mut self, name: K, value: V) -> Result<Self, http::Error>
    where
        K: TryInto<HeaderName>,
        K::Error: Into<http::Error>,
        V: TryInto<HeaderValue>,
        V::Error: Into<http::Error>,
    {
        self.headers.append(
            name.try_into().map_err(Into::into)?,
            value.try_into().map_err(Into::into)?,
        );
        Ok(self)
    }

    /// Appends a response trailer. A repeated name adds another value rather
    /// than replacing the existing one.
    ///
    /// Prefer [`try_with_trailer`](Response::try_with_trailer) when `name` or
    /// `value` are not known to be valid.
    ///
    /// # Panics
    ///
    /// Panics if `name` is not a valid header name or `value` is not a valid
    /// header value.
    #[must_use]
    pub fn with_trailer<K, V>(mut self, name: K, value: V) -> Self
    where
        K: TryInto<HeaderName>,
        K::Error: std::fmt::Debug,
        V: TryInto<HeaderValue>,
        V::Error: std::fmt::Debug,
    {
        let name = name
            .try_into()
            .expect("Response::with_trailer: invalid trailer name");
        let value = value
            .try_into()
            .expect("Response::with_trailer: invalid trailer value");
        self.trailers.append(name, value);
        self
    }

    /// Fallible variant of [`with_trailer`](Response::with_trailer).
    ///
    /// # Errors
    ///
    /// Returns the conversion error as an [`http::Error`] if `name` or `value` is
    /// invalid. The response is consumed in that case.
    pub fn try_with_trailer<K, V>(mut self, name: K, value: V) -> Result<Self, http::Error>
    where
        K: TryInto<HeaderName>,
        K::Error: Into<http::Error>,
        V: TryInto<HeaderValue>,
        V::Error: Into<http::Error>,
    {
        self.trailers.append(
            name.try_into().map_err(Into::into)?,
            value.try_into().map_err(Into::into)?,
        );
        Ok(self)
    }

    /// Sets the compression override. Accepts a `bool`, or an `Option<bool>`
    /// where `None` clears the override.
    #[must_use]
    pub fn compress(mut self, enabled: impl Into<Option<bool>>) -> Self {
        self.compress = enabled.into();
        self
    }

    /// Replaces the body with `f(body)`. Headers, trailers and the compression
    /// override are preserved.
    pub fn map_body<C>(self, f: impl FnOnce(B) -> C) -> Response<C> {
        Response {
            body: f(self.body),
            headers: self.headers,
            trailers: self.trailers,
            compress: self.compress,
        }
    }
}

/// Wraps a bare body in a [`Response`] with no headers, trailers or override.
impl<B> From<B> for Response<B> {
    fn from(body: B) -> Self {
        Self::new(body)
    }
}
impl<T> Response<ServiceStream<T>> {
pub fn stream(s: impl Stream<Item = Result<T, ConnectError>> + Send + 'static) -> Self {
Self::new(Box::pin(s))
}
pub fn stream_ok(
s: impl Stream<Item = Result<T, ConnectError>> + Send + 'static,
) -> ServiceResult<ServiceStream<T>> {
Ok(Self::stream(s))
}
}
pub type ServiceResult<B> = Result<Response<B>, ConnectError>;
pub type ServiceStream<T> = Pin<Box<dyn Stream<Item = Result<T, ConnectError>> + Send>>;
/// A response body that can be serialized as the wire message type `M`.
///
/// The implementing type may be `M` itself or another representation of it,
/// such as a view or pre-encoded bytes.
pub trait Encodable<M> {
    /// Serializes `self` for the given `codec`.
    ///
    /// # Errors
    ///
    /// Returns a [`ConnectError`] when the body cannot be encoded with `codec`.
    fn encode(&self, codec: CodecFormat) -> Result<Bytes, ConnectError>;
}

/// Owned messages encode directly. Binary uses the protobuf encoder and JSON
/// uses the message's `Serialize` impl.
impl<M: Message + Serialize> Encodable<M> for M {
    fn encode(&self, codec: CodecFormat) -> Result<Bytes, ConnectError> {
        match codec {
            CodecFormat::Proto => Ok(self.encode_to_bytes()),
            CodecFormat::Json => {
                let json = serde_json::to_vec(self).map_err(|err| {
                    ConnectError::internal(format!("failed to encode JSON response: {err}"))
                })?;
                Ok(Bytes::from(json))
            }
        }
    }
}
/// Encodes a message view for `codec`.
///
/// Only the binary protobuf codec is supported. Views have no JSON
/// serialization here, so `CodecFormat::Json` yields an `unimplemented` error.
#[doc(hidden)]
pub fn encode_view_body<'a, V: ViewEncode<'a>>(
    view: &V,
    codec: CodecFormat,
) -> Result<Bytes, ConnectError> {
    match codec {
        CodecFormat::Json => Err(ConnectError::unimplemented(
            "view-body responses do not support the JSON codec; return the owned message type for JSON-serving handlers",
        )),
        CodecFormat::Proto => Ok(view.encode_to_bytes()),
    }
}
/// A response body that is either an owned message `M` or a borrowed form `V`
/// of it. Encoding delegates to whichever variant is held.
#[derive(Debug, Clone)]
pub enum MaybeBorrowed<M, V> {
    /// An owned message.
    Owned(M),
    /// A borrowed representation of the message.
    Borrowed(V),
}

impl<M, V> Encodable<M> for MaybeBorrowed<M, V>
where
    M: Encodable<M>,
    V: Encodable<M>,
{
    fn encode(&self, codec: CodecFormat) -> Result<Bytes, ConnectError> {
        match self {
            MaybeBorrowed::Borrowed(borrowed) => borrowed.encode(codec),
            MaybeBorrowed::Owned(owned) => owned.encode(codec),
        }
    }
}
/// Protobuf bytes already encoded as message type `M`.
///
/// With the proto codec the stored bytes are emitted unchanged. With JSON the
/// bytes are decoded back into `M` and then serialized. Fields that `M` does
/// not know about therefore appear in the binary output but not in the JSON.
#[must_use = "PreEncoded must be returned from a handler to take effect"]
pub struct PreEncoded<M> {
    bytes: Bytes,
    // `fn() -> M` ties the wrapper to `M` without storing an `M`, so no
    // `M: Clone`/`M: Debug` bounds are needed (see the manual impls below).
    _marker: PhantomData<fn() -> M>,
}

impl<M> std::fmt::Debug for PreEncoded<M> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut tuple = f.debug_tuple("PreEncoded");
        tuple.field(&self.bytes);
        tuple.finish()
    }
}

impl<M> Clone for PreEncoded<M> {
    fn clone(&self) -> Self {
        Self::from_raw(self.bytes.clone())
    }
}

impl<M> PreEncoded<M> {
    /// Wraps `bytes` with no validation at all.
    fn from_raw(bytes: Bytes) -> Self {
        Self {
            bytes,
            _marker: PhantomData,
        }
    }
}

impl<M: Message> PreEncoded<M> {
    /// Encodes `msg` now and keeps the resulting bytes.
    pub fn from_message(msg: &M) -> Self {
        Self::from_raw(msg.encode_to_bytes())
    }

    /// Encodes a view whose owned type is `M` and keeps the resulting bytes.
    pub fn from_view<'a, V>(view: &V) -> Self
    where
        V: ViewEncode<'a> + MessageView<'a, Owned = M>,
    {
        Self::from_raw(view.encode_to_bytes())
    }

    /// Wraps caller-supplied bytes, trusting that they encode an `M`.
    ///
    /// # Panics
    ///
    /// In debug builds only, panics if the bytes do not decode as `M`.
    /// Release builds perform no check.
    pub fn from_bytes_unchecked(bytes: impl Into<Bytes>) -> Self {
        let encoded = Self::from_raw(bytes.into());
        debug_assert!(
            M::decode_from_slice(&encoded.bytes).is_ok(),
            "PreEncoded::from_bytes_unchecked: bytes do not decode as {}",
            std::any::type_name::<M>(),
        );
        encoded
    }
}

impl<M: Message> From<&M> for PreEncoded<M> {
    fn from(msg: &M) -> Self {
        PreEncoded::from_message(msg)
    }
}

impl<M: Message + Serialize> Encodable<M> for PreEncoded<M> {
    fn encode(&self, codec: CodecFormat) -> Result<Bytes, ConnectError> {
        match codec {
            // Cloning `Bytes` shares the buffer instead of copying the payload.
            CodecFormat::Proto => Ok(self.bytes.clone()),
            CodecFormat::Json => {
                let decoded = M::decode_from_slice(&self.bytes).map_err(|err| {
                    ConnectError::internal(format!(
                        "pre-encoded bytes did not decode as {}: {err}",
                        std::any::type_name::<M>(),
                    ))
                })?;
                let json = serde_json::to_vec(&decoded).map_err(|err| {
                    ConnectError::internal(format!("failed to encode JSON response: {err}"))
                })?;
                Ok(Bytes::from(json))
            }
        }
    }
}
/// A response whose body has already been serialized to wire bytes.
pub type EncodedResponse = Response<Bytes>;

impl<B> Response<B> {
    /// Serializes the body with `codec`. Headers, trailers and the compression
    /// override carry over unchanged.
    ///
    /// # Errors
    ///
    /// Propagates any error from the body's [`Encodable::encode`].
    #[doc(hidden)]
    pub fn encode<M>(self, codec: CodecFormat) -> Result<EncodedResponse, ConnectError>
    where
        B: Encodable<M>,
    {
        let Response {
            body,
            headers,
            trailers,
            compress,
        } = self;
        let body = body.encode(codec)?;
        Ok(Response {
            body,
            headers,
            trailers,
            compress,
        })
    }
}
// Unit tests for RequestContext, the Response builder, and the Encodable
// implementations (owned messages, views, MaybeBorrowed, PreEncoded).
#[cfg(test)]
mod tests {
use super::*;
use buffa_types::google::protobuf::StringValue;
#[tokio::test]
async fn response_stream_ok_shorthand() {
use futures::StreamExt;
let r: ServiceResult<ServiceStream<i32>> =
Response::stream_ok(futures::stream::iter([Ok(7)]));
let collected: Vec<_> = r.unwrap().body.map(|x| x.unwrap()).collect().await;
assert_eq!(collected, vec![7]);
}
// `compress` accepts both `bool` and `Option<bool>` (three states).
#[test]
fn compress_tristate() {
assert_eq!(Response::new(()).compress(true).compress, Some(true));
assert_eq!(Response::new(()).compress(false).compress, Some(false));
assert_eq!(Response::new(()).compress(None).compress, None);
}
#[test]
fn header_accepts_str() {
let mut h = HeaderMap::new();
h.insert("x-custom", HeaderValue::from_static("v"));
let ctx = RequestContext::new(h);
assert_eq!(ctx.header("x-custom").unwrap(), "v");
}
#[test]
fn response_ok_shorthand() {
let r: ServiceResult<u32> = Response::ok(42);
let r = r.unwrap();
assert_eq!(r.body, 42);
assert!(r.headers.is_empty());
}
#[test]
fn response_from_body() {
let r: Response<StringValue> = StringValue::from("hi").into();
assert_eq!(r.body.value, "hi");
assert!(r.headers.is_empty());
assert!(r.trailers.is_empty());
assert_eq!(r.compress, None);
}
#[test]
fn response_builder() {
let r = Response::new(StringValue::from("hi"))
.with_header("x-a", "1")
.with_trailer("x-b", "2")
.compress(true);
assert_eq!(r.headers.get("x-a").unwrap(), "1");
assert_eq!(r.trailers.get("x-b").unwrap(), "2");
assert_eq!(r.compress, Some(true));
}
#[test]
fn encodable_owned_proto() {
let m = StringValue::from("hello");
let bytes = Encodable::<StringValue>::encode(&m, CodecFormat::Proto).unwrap();
assert_eq!(
StringValue::decode_from_slice(&bytes).unwrap().value,
"hello"
);
}
// StringValue's JSON form is the bare string.
#[test]
fn encodable_owned_json() {
let m = StringValue::from("hello");
let bytes = Encodable::<StringValue>::encode(&m, CodecFormat::Json).unwrap();
assert_eq!(&bytes[..], b"\"hello\"");
}
#[test]
fn response_encode() {
let r = Response::new(StringValue::from("hi")).with_header("x-a", "1");
let enc = r.encode::<StringValue>(CodecFormat::Proto).unwrap();
assert_eq!(enc.headers.get("x-a").unwrap(), "1");
assert_eq!(
StringValue::decode_from_slice(&enc.body).unwrap().value,
"hi"
);
}
#[test]
fn request_context_new() {
let mut h = HeaderMap::new();
h.insert("x-custom", HeaderValue::from_static("v"));
let ctx = RequestContext::new(h);
assert_eq!(
ctx.header(HeaderName::from_static("x-custom")).unwrap(),
"v"
);
assert_eq!(ctx.headers().get("x-custom").unwrap(), "v");
assert!(ctx.deadline().is_none());
assert!(ctx.time_remaining().is_none());
assert!(ctx.extensions().is_empty());
}
#[test]
fn request_context_with_deadline() {
let d = Instant::now();
let ctx = RequestContext::new(HeaderMap::new()).with_deadline(Some(d));
assert_eq!(ctx.deadline(), Some(d));
}
// NOTE(review): `Instant::now() - 60s` panics if the monotonic clock has not yet
// advanced 60s (e.g. a freshly booted machine); `checked_sub` would avoid this.
#[test]
fn request_context_time_remaining_saturates_at_zero() {
let past = Instant::now() - Duration::from_secs(60);
let ctx = RequestContext::new(HeaderMap::new()).with_deadline(Some(past));
assert_eq!(ctx.time_remaining(), Some(Duration::ZERO));
}
// Lower bound leaves 5s of slack for slow test machines.
#[test]
fn request_context_time_remaining_future() {
let future = Instant::now() + Duration::from_secs(60);
let ctx = RequestContext::new(HeaderMap::new()).with_deadline(Some(future));
let remaining = ctx.time_remaining().unwrap();
assert!(remaining > Duration::from_secs(55));
assert!(remaining <= Duration::from_secs(60));
}
#[test]
fn request_context_extensions_mut() {
#[derive(Clone, Debug, PartialEq)]
struct Tag(u8);
let mut ctx = RequestContext::new(HeaderMap::new());
ctx.extensions_mut().insert(Tag(1));
assert_eq!(ctx.extensions().get::<Tag>(), Some(&Tag(1)));
}
#[cfg(feature = "server")]
#[test]
fn request_context_peer_addr_absent() {
let ctx = RequestContext::new(HeaderMap::new());
assert_eq!(ctx.peer_addr(), None);
}
#[cfg(feature = "server")]
#[test]
fn request_context_peer_addr_present() {
use std::net::{IpAddr, Ipv4Addr, SocketAddr};
let addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 8080);
let mut ext = http::Extensions::new();
ext.insert(crate::server::PeerAddr(addr));
let ctx = RequestContext::new(HeaderMap::new()).with_extensions(ext);
assert_eq!(ctx.peer_addr(), Some(addr));
}
#[cfg(feature = "server-tls")]
#[test]
fn request_context_peer_certs_absent() {
let ctx = RequestContext::new(HeaderMap::new());
assert!(ctx.peer_certs().is_none());
}
#[test]
fn response_map_body_preserves_metadata() {
let r = Response::new(2u32)
.with_header("x-h", "1")
.with_trailer("x-t", "2")
.compress(true);
let r = r.map_body(|n| n.to_string());
assert_eq!(r.body, "2");
assert_eq!(r.headers.get("x-h").unwrap(), "1");
assert_eq!(r.trailers.get("x-t").unwrap(), "2");
assert_eq!(r.compress, Some(true));
}
#[tokio::test]
async fn response_stream_yields_items() {
use futures::StreamExt;
let r: Response<ServiceStream<i32>> =
Response::stream(futures::stream::iter([Ok(1), Ok(2), Ok(3)]));
let collected: Vec<_> = r.body.map(|x| x.unwrap()).collect().await;
assert_eq!(collected, vec![1, 2, 3]);
}
// Spaces are not allowed in header names.
#[test]
#[should_panic]
fn with_header_panics_on_invalid_name() {
let _ = Response::new(()).with_header("invalid header name", "v");
}
#[test]
fn try_with_header_errors_on_invalid_name() {
let err = Response::new(())
.try_with_header("invalid header name", "v")
.unwrap_err();
assert!(err.is::<http::header::InvalidHeaderName>());
}
// Repeated names append rather than overwrite.
#[test]
fn try_with_header_ok_appends() {
let r = Response::new(())
.try_with_header("x-a", "1")
.unwrap()
.try_with_header("x-a", "2")
.unwrap();
let vals: Vec<_> = r.headers.get_all("x-a").iter().collect();
assert_eq!(vals.len(), 2);
}
// A newline is not allowed in a header value.
#[test]
fn try_with_trailer_errors_on_invalid_value() {
let err = Response::new(())
.try_with_trailer("x-t", "bad\nvalue")
.unwrap_err();
assert!(err.is::<http::header::InvalidHeaderValue>());
}
#[test]
fn encode_view_body_proto() {
use buffa_types::google::protobuf::__buffa::view::StringValueView;
let v = StringValueView {
value: "hi",
..Default::default()
};
let bytes = encode_view_body(&v, CodecFormat::Proto).unwrap();
assert_eq!(StringValue::decode_from_slice(&bytes).unwrap().value, "hi");
}
#[test]
fn encode_view_body_json_errors() {
use buffa_types::google::protobuf::__buffa::view::StringValueView;
let v = StringValueView::default();
let err = encode_view_body(&v, CodecFormat::Json).unwrap_err();
assert_eq!(err.code, crate::ErrorCode::Unimplemented);
assert!(err.message.as_deref().unwrap().contains("JSON codec"));
}
// Test-only wrapper that lets a StringValue view act as an Encodable body.
struct V<'a>(buffa_types::google::protobuf::__buffa::view::StringValueView<'a>);
impl Encodable<StringValue> for V<'_> {
fn encode(&self, c: CodecFormat) -> Result<Bytes, ConnectError> {
encode_view_body(&self.0, c)
}
}
#[test]
fn maybe_borrowed_dispatch() {
use buffa_types::google::protobuf::__buffa::view::StringValueView;
let owned: MaybeBorrowed<StringValue, V<'_>> =
MaybeBorrowed::Owned(StringValue::from("owned"));
let borrowed = MaybeBorrowed::Borrowed(V(StringValueView {
value: "view",
..Default::default()
}));
assert_eq!(
StringValue::decode_from_slice(&owned.encode(CodecFormat::Proto).unwrap())
.unwrap()
.value,
"owned"
);
assert_eq!(
StringValue::decode_from_slice(&borrowed.encode(CodecFormat::Proto).unwrap())
.unwrap()
.value,
"view"
);
}
#[test]
fn maybe_borrowed_borrowed_json_unimplemented() {
use buffa_types::google::protobuf::__buffa::view::StringValueView;
let borrowed: MaybeBorrowed<StringValue, V<'_>> =
MaybeBorrowed::Borrowed(V(StringValueView::default()));
let err = borrowed.encode(CodecFormat::Json).unwrap_err();
assert_eq!(err.code, crate::ErrorCode::Unimplemented);
}
// Proto output must be the exact stored bytes, not a re-encoding.
#[test]
fn pre_encoded_proto_round_trip() {
let m = StringValue::from("pre-encoded");
let bytes = m.encode_to_bytes();
let body = PreEncoded::<StringValue>::from_bytes_unchecked(bytes.clone());
let out = Encodable::<StringValue>::encode(&body, CodecFormat::Proto).unwrap();
assert_eq!(out, bytes);
assert_eq!(
StringValue::decode_from_slice(&out).unwrap().value,
"pre-encoded"
);
}
#[test]
fn pre_encoded_json_decodes_then_serializes() {
let m = StringValue::from("hi");
let body = PreEncoded::<StringValue>::from_bytes_unchecked(m.encode_to_bytes());
let out = Encodable::<StringValue>::encode(&body, CodecFormat::Json).unwrap();
assert_eq!(out, Bytes::from(serde_json::to_vec(&m).unwrap()));
}
// Protobuf wire bytes: tag 0x0a (field 1, length-delimited) declares a length of
// 0x63 (99) but only 2 bytes follow, so decoding fails. The struct is built
// directly to bypass the debug_assert in `from_bytes_unchecked`.
#[test]
fn pre_encoded_json_decode_failure_is_internal_error() {
let body = PreEncoded::<StringValue> {
bytes: Bytes::from_static(&[0x0a, 0x63, b'h', b'i']),
_marker: std::marker::PhantomData,
};
let err = Encodable::<StringValue>::encode(&body, CodecFormat::Json).unwrap_err();
assert_eq!(err.code, crate::ErrorCode::Internal);
assert!(err.message.as_deref().unwrap().contains("did not decode"));
}
#[test]
fn pre_encoded_from_view() {
use buffa::view::ViewEncode;
use buffa_types::google::protobuf::__buffa::view::StringValueView;
let v = StringValueView {
value: "from-view",
..Default::default()
};
let body = PreEncoded::from_view(&v);
let out = Encodable::<StringValue>::encode(&body, CodecFormat::Proto).unwrap();
assert_eq!(out, v.encode_to_bytes());
assert_eq!(
StringValue::decode_from_slice(&out).unwrap().value,
"from-view"
);
}
#[test]
fn pre_encoded_from_message() {
let m = StringValue::from("from-message");
let body = PreEncoded::from_message(&m);
let out = Encodable::<StringValue>::encode(&body, CodecFormat::Proto).unwrap();
assert_eq!(out, m.encode_to_bytes());
let body2: PreEncoded<StringValue> = (&m).into();
let out2 = Encodable::<StringValue>::encode(&body2, CodecFormat::Proto).unwrap();
assert_eq!(out2, out);
}
// Bytes are field 1 = "hi", followed by 0x10 42 (field 2, varint 42), a field
// StringValue does not define. The proto codec keeps it verbatim; the JSON
// path decodes and re-serializes, so the extra field is absent.
#[test]
fn pre_encoded_codec_fidelity_diverges_on_unknown_fields() {
let bytes_with_unknown =
Bytes::from_static(&[0x0a, 0x02, b'h', b'i', 0x10, 42]);
let body = PreEncoded::<StringValue> {
bytes: bytes_with_unknown.clone(),
_marker: std::marker::PhantomData,
};
let proto = Encodable::<StringValue>::encode(&body, CodecFormat::Proto).unwrap();
assert_eq!(proto, bytes_with_unknown);
let json = Encodable::<StringValue>::encode(&body, CodecFormat::Json).unwrap();
assert_eq!(
json,
Bytes::from(serde_json::to_vec(&StringValue::from("hi")).unwrap())
);
}
#[test]
fn pre_encoded_is_typed() {
use buffa_types::google::protobuf::Int32Value;
let s = PreEncoded::<StringValue>::from_bytes_unchecked(
StringValue::from("a").encode_to_bytes(),
);
let i =
PreEncoded::<Int32Value>::from_bytes_unchecked(Int32Value::from(1).encode_to_bytes());
Encodable::<StringValue>::encode(&s, CodecFormat::Proto).unwrap();
Encodable::<Int32Value>::encode(&i, CodecFormat::Proto).unwrap();
}
// Only meaningful with debug assertions on; the check is compiled out otherwise.
#[test]
#[cfg(debug_assertions)]
#[should_panic(expected = "do not decode as")]
fn pre_encoded_from_bytes_unchecked_debug_asserts() {
let _ = PreEncoded::<StringValue>::from_bytes_unchecked(Bytes::from_static(&[
0x0a, 0x63, b'h', b'i',
]));
}
#[test]
fn request_context_with_extensions() {
#[derive(Clone, Debug, PartialEq)]
struct Peer(u32);
let mut ext = http::Extensions::new();
ext.insert(Peer(7));
let ctx = RequestContext::new(HeaderMap::new()).with_extensions(ext);
assert_eq!(ctx.extensions().get::<Peer>(), Some(&Peer(7)));
}
#[test]
fn request_context_with_spec_and_protocol() {
use crate::spec::{Spec, StreamType};
let ctx = RequestContext::new(HeaderMap::new());
assert_eq!(ctx.spec(), None);
assert_eq!(ctx.protocol(), None);
const SPEC: Spec = Spec::server("/pkg.Svc/M", StreamType::Unary);
let ctx = RequestContext::new(HeaderMap::new())
.with_spec(Some(SPEC))
.with_protocol(Some(crate::Protocol::Grpc));
assert_eq!(ctx.spec(), Some(SPEC));
assert_eq!(ctx.protocol(), Some(crate::Protocol::Grpc));
let ctx = ctx.with_spec(None).with_protocol(None);
assert_eq!(ctx.spec(), None);
assert_eq!(ctx.protocol(), None);
}
// `with_path` takes both &str and String, and stores "" as Some("").
#[test]
fn request_context_with_path() {
let ctx = RequestContext::new(HeaderMap::new());
assert_eq!(ctx.path(), None);
let ctx = RequestContext::new(HeaderMap::new()).with_path("/pkg.Svc/M");
assert_eq!(ctx.path(), Some("/pkg.Svc/M"));
let owned = String::from("/pkg.Svc/Other");
let ctx = RequestContext::new(HeaderMap::new()).with_path(owned);
assert_eq!(ctx.path(), Some("/pkg.Svc/Other"));
let ctx = RequestContext::new(HeaderMap::new()).with_path("");
assert_eq!(ctx.path(), Some(""));
}
}