use std::pin::Pin;
use std::time::Instant;
use buffa::Message;
use buffa::view::ViewEncode;
use bytes::Bytes;
use futures::Stream;
use http::HeaderMap;
use http::header::{HeaderName, HeaderValue};
use serde::Serialize;
use crate::codec::CodecFormat;
use crate::error::ConnectError;
/// Per-request metadata made available to a service handler.
///
/// Bundles the request headers, an optional deadline, and a type-keyed map of
/// extension values.
#[derive(Clone, Debug, Default)]
pub struct RequestContext {
    /// Request headers.
    pub headers: HeaderMap,
    /// Optional deadline for the request; `None` when no deadline was set.
    pub deadline: Option<Instant>,
    /// Arbitrary typed values attached to the request.
    pub extensions: http::Extensions,
}
impl RequestContext {
    /// Builds a context around `headers`, with no deadline and empty extensions.
    pub fn new(headers: HeaderMap) -> Self {
        let extensions = http::Extensions::new();
        Self {
            deadline: None,
            extensions,
            headers,
        }
    }

    /// Returns this context with its deadline replaced by `deadline`
    /// (`None` clears it).
    #[must_use]
    pub fn with_deadline(self, deadline: Option<Instant>) -> Self {
        Self { deadline, ..self }
    }

    /// Returns this context with its extensions replaced wholesale by `extensions`.
    #[must_use]
    pub fn with_extensions(self, extensions: http::Extensions) -> Self {
        Self { extensions, ..self }
    }

    /// Looks up header `key` in the request headers, returning `None` if absent.
    ///
    /// `key` may be a `&str` or a `HeaderName`.
    pub fn header(&self, key: impl http::header::AsHeaderName) -> Option<&HeaderValue> {
        let headers = &self.headers;
        headers.get(key)
    }
}
/// A handler response: a body of type `B` plus response headers, trailers,
/// and an optional compression preference.
#[derive(Clone, Debug)]
pub struct Response<B> {
    /// The response payload.
    pub body: B,
    /// Response headers.
    pub headers: HeaderMap,
    /// Response trailers.
    pub trailers: HeaderMap,
    /// Compression preference: `Some(true)` / `Some(false)` when explicitly
    /// requested, `None` when no preference was set.
    pub compress: Option<bool>,
}
impl<B> Response<B> {
    /// Wraps `body` in a default [`Response`] and returns it as `Ok`.
    ///
    /// Shorthand for handlers whose return type is [`ServiceResult`].
    pub fn ok(body: B) -> ServiceResult<B> {
        Ok(Self::from(body))
    }

    /// Creates a response carrying `body`, with empty headers and trailers and
    /// no compression preference.
    pub fn new(body: B) -> Self {
        Self {
            body,
            headers: HeaderMap::new(),
            trailers: HeaderMap::new(),
            compress: None,
        }
    }

    /// Appends a response header.
    ///
    /// Values already present under the same name are kept; this appends
    /// rather than replaces.
    ///
    /// # Panics
    ///
    /// Panics if `name` cannot be converted to a [`HeaderName`] or `value` to a
    /// [`HeaderValue`]. Use [`Response::try_with_header`] for input that may be
    /// invalid (e.g. user-supplied).
    #[must_use]
    pub fn with_header<K, V>(mut self, name: K, value: V) -> Self
    where
        K: TryInto<HeaderName>,
        K::Error: std::fmt::Debug,
        V: TryInto<HeaderValue>,
        V::Error: std::fmt::Debug,
    {
        // Convert the name before the value, matching the argument evaluation
        // order of the fallible variant.
        let name: HeaderName = name
            .try_into()
            .expect("Response::with_header: invalid header name");
        let value: HeaderValue = value
            .try_into()
            .expect("Response::with_header: invalid header value");
        self.headers.append(name, value);
        self
    }

    /// Appends a response header, reporting invalid input as an error instead
    /// of panicking.
    ///
    /// # Errors
    ///
    /// Returns the name's conversion error if `name` is invalid (checked
    /// first), otherwise the value's conversion error if `value` is invalid.
    pub fn try_with_header<K, V>(mut self, name: K, value: V) -> Result<Self, http::Error>
    where
        K: TryInto<HeaderName>,
        K::Error: Into<http::Error>,
        V: TryInto<HeaderValue>,
        V::Error: Into<http::Error>,
    {
        self.headers.append(
            name.try_into().map_err(Into::into)?,
            value.try_into().map_err(Into::into)?,
        );
        Ok(self)
    }

    /// Appends a response trailer.
    ///
    /// Values already present under the same name are kept; this appends
    /// rather than replaces.
    ///
    /// # Panics
    ///
    /// Panics if `name` cannot be converted to a [`HeaderName`] or `value` to a
    /// [`HeaderValue`]. Use [`Response::try_with_trailer`] for input that may be
    /// invalid (e.g. user-supplied).
    #[must_use]
    pub fn with_trailer<K, V>(mut self, name: K, value: V) -> Self
    where
        K: TryInto<HeaderName>,
        K::Error: std::fmt::Debug,
        V: TryInto<HeaderValue>,
        V::Error: std::fmt::Debug,
    {
        // Convert the name before the value, matching the argument evaluation
        // order of the fallible variant.
        let name: HeaderName = name
            .try_into()
            .expect("Response::with_trailer: invalid trailer name");
        let value: HeaderValue = value
            .try_into()
            .expect("Response::with_trailer: invalid trailer value");
        self.trailers.append(name, value);
        self
    }

    /// Appends a response trailer, reporting invalid input as an error instead
    /// of panicking.
    ///
    /// # Errors
    ///
    /// Returns the name's conversion error if `name` is invalid (checked
    /// first), otherwise the value's conversion error if `value` is invalid.
    pub fn try_with_trailer<K, V>(mut self, name: K, value: V) -> Result<Self, http::Error>
    where
        K: TryInto<HeaderName>,
        K::Error: Into<http::Error>,
        V: TryInto<HeaderValue>,
        V::Error: Into<http::Error>,
    {
        self.trailers.append(
            name.try_into().map_err(Into::into)?,
            value.try_into().map_err(Into::into)?,
        );
        Ok(self)
    }

    /// Sets the compression preference.
    ///
    /// Accepts a plain `bool` (stored as `Some`) or an `Option<bool>`; passing
    /// `None` clears any previous preference.
    #[must_use]
    pub fn compress(mut self, enabled: impl Into<Option<bool>>) -> Self {
        self.compress = enabled.into();
        self
    }

    /// Transforms the body with `f`, keeping headers, trailers, and the
    /// compression preference unchanged.
    pub fn map_body<C>(self, f: impl FnOnce(B) -> C) -> Response<C> {
        Response {
            body: f(self.body),
            headers: self.headers,
            trailers: self.trailers,
            compress: self.compress,
        }
    }
}
impl<B> From<B> for Response<B> {
fn from(body: B) -> Self {
Self::new(body)
}
}
impl<T> Response<ServiceStream<T>> {
pub fn stream(s: impl Stream<Item = Result<T, ConnectError>> + Send + 'static) -> Self {
Self::new(Box::pin(s))
}
pub fn stream_ok(
s: impl Stream<Item = Result<T, ConnectError>> + Send + 'static,
) -> ServiceResult<ServiceStream<T>> {
Ok(Self::stream(s))
}
}
/// A boxed, pinned, `Send` stream of items or [`ConnectError`]s, used as a
/// streaming response body.
pub type ServiceStream<T> = Pin<Box<dyn Stream<Item = Result<T, ConnectError>> + Send>>;
/// The result type of a service handler: a [`Response`] or a [`ConnectError`].
pub type ServiceResult<B> = Result<Response<B>, ConnectError>;
/// A response body that can be serialized with a given [`CodecFormat`].
///
/// `M` names the message type the encoded bytes represent, so an owned message
/// and a borrowed view of it can both implement `Encodable<M>` for the same `M`.
pub trait Encodable<M> {
    /// Serializes `self` in the wire format selected by `codec`.
    ///
    /// # Errors
    ///
    /// Returns a [`ConnectError`] if `self` cannot be encoded in that format.
    fn encode(&self, codec: CodecFormat) -> Result<Bytes, ConnectError>;
}
/// Owned messages encode as protobuf binary, or as JSON via `serde_json`.
impl<M: Message + Serialize> Encodable<M> for M {
    fn encode(&self, codec: CodecFormat) -> Result<Bytes, ConnectError> {
        match codec {
            CodecFormat::Proto => Ok(self.encode_to_bytes()),
            CodecFormat::Json => {
                // A serde failure is reported as an internal error.
                let json = serde_json::to_vec(self).map_err(|e| {
                    ConnectError::internal(format!("failed to encode JSON response: {e}"))
                })?;
                Ok(Bytes::from(json))
            }
        }
    }
}
/// Encodes a borrowed message view as a response body.
///
/// Only [`CodecFormat::Proto`] is supported.
///
/// # Errors
///
/// Returns an `unimplemented` [`ConnectError`] when `codec` is
/// [`CodecFormat::Json`]; handlers that serve JSON should return the owned
/// message type instead.
#[doc(hidden)]
pub fn encode_view_body<'a, V>(view: &V, codec: CodecFormat) -> Result<Bytes, ConnectError>
where
    V: ViewEncode<'a>,
{
    match codec {
        CodecFormat::Json => Err(ConnectError::unimplemented(
            "view-body responses do not support the JSON codec; return the owned message type for JSON-serving handlers",
        )),
        CodecFormat::Proto => Ok(view.encode_to_bytes()),
    }
}
/// A response body that is either an owned message `M` or a borrowed
/// representation `V` of it.
#[derive(Clone, Debug)]
pub enum MaybeBorrowed<M, V> {
    /// An owned message.
    Owned(M),
    /// A borrowed representation (e.g. a view) of the message.
    Borrowed(V),
}
impl<M, V> Encodable<M> for MaybeBorrowed<M, V>
where
M: Encodable<M>,
V: Encodable<M>,
{
fn encode(&self, codec: CodecFormat) -> Result<Bytes, ConnectError> {
match self {
Self::Owned(m) => m.encode(codec),
Self::Borrowed(v) => v.encode(codec),
}
}
}
/// A [`Response`] whose body has already been serialized to bytes.
pub type EncodedResponse = Response<Bytes>;
impl<B> Response<B> {
#[doc(hidden)] pub fn encode<M>(self, codec: CodecFormat) -> Result<EncodedResponse, ConnectError>
where
B: Encodable<M>,
{
let bytes = self.body.encode(codec)?;
Ok(Response {
body: bytes,
headers: self.headers,
trailers: self.trailers,
compress: self.compress,
})
}
}
// Unit tests for RequestContext, the Response builders/stream helpers, and
// body encoding (owned messages, views, and MaybeBorrowed).
#[cfg(test)]
mod tests {
use super::*;
use buffa_types::google::protobuf::StringValue;
// `stream_ok` wraps a streaming response in `Ok`, and the stream yields its items.
#[tokio::test]
async fn response_stream_ok_shorthand() {
use futures::StreamExt;
let r: ServiceResult<ServiceStream<i32>> =
Response::stream_ok(futures::stream::iter([Ok(7)]));
let collected: Vec<_> = r.unwrap().body.map(|x| x.unwrap()).collect().await;
assert_eq!(collected, vec![7]);
}
// `compress` accepts `bool` or `Option<bool>`; `None` leaves it unset.
#[test]
fn compress_tristate() {
assert_eq!(Response::new(()).compress(true).compress, Some(true));
assert_eq!(Response::new(()).compress(false).compress, Some(false));
assert_eq!(Response::new(()).compress(None).compress, None);
}
// `RequestContext::header` accepts a plain `&str` key.
#[test]
fn header_accepts_str() {
let mut h = HeaderMap::new();
h.insert("x-custom", HeaderValue::from_static("v"));
let ctx = RequestContext::new(h);
assert_eq!(ctx.header("x-custom").unwrap(), "v");
}
// `Response::ok` produces `Ok` with the body and no headers.
#[test]
fn response_ok_shorthand() {
let r: ServiceResult<u32> = Response::ok(42);
let r = r.unwrap();
assert_eq!(r.body, 42);
assert!(r.headers.is_empty());
}
// `From<B>` builds a response with empty metadata and no compression preference.
#[test]
fn response_from_body() {
let r: Response<StringValue> = StringValue::from("hi").into();
assert_eq!(r.body.value, "hi");
assert!(r.headers.is_empty());
assert!(r.trailers.is_empty());
assert_eq!(r.compress, None);
}
// The chained builder methods set headers, trailers, and compression.
#[test]
fn response_builder() {
let r = Response::new(StringValue::from("hi"))
.with_header("x-a", "1")
.with_trailer("x-b", "2")
.compress(true);
assert_eq!(r.headers.get("x-a").unwrap(), "1");
assert_eq!(r.trailers.get("x-b").unwrap(), "2");
assert_eq!(r.compress, Some(true));
}
// Owned messages round-trip through the protobuf codec.
#[test]
fn encodable_owned_proto() {
let m = StringValue::from("hello");
let bytes = Encodable::<StringValue>::encode(&m, CodecFormat::Proto).unwrap();
assert_eq!(
StringValue::decode_from_slice(&bytes).unwrap().value,
"hello"
);
}
// Owned messages encode to JSON via serde (StringValue serializes as a bare string).
#[test]
fn encodable_owned_json() {
let m = StringValue::from("hello");
let bytes = Encodable::<StringValue>::encode(&m, CodecFormat::Json).unwrap();
assert_eq!(&bytes[..], b"\"hello\"");
}
// `Response::encode` serializes the body and keeps the headers.
#[test]
fn response_encode() {
let r = Response::new(StringValue::from("hi")).with_header("x-a", "1");
let enc = r.encode::<StringValue>(CodecFormat::Proto).unwrap();
assert_eq!(enc.headers.get("x-a").unwrap(), "1");
assert_eq!(
StringValue::decode_from_slice(&enc.body).unwrap().value,
"hi"
);
}
// `RequestContext::new` stores the headers and starts with no deadline.
#[test]
fn request_context_new() {
let mut h = HeaderMap::new();
h.insert("x-custom", HeaderValue::from_static("v"));
let ctx = RequestContext::new(h);
assert_eq!(
ctx.header(HeaderName::from_static("x-custom")).unwrap(),
"v"
);
assert!(ctx.deadline.is_none());
}
// `with_deadline` stores the given deadline.
#[test]
fn request_context_with_deadline() {
let d = Instant::now();
let ctx = RequestContext::new(HeaderMap::new()).with_deadline(Some(d));
assert_eq!(ctx.deadline, Some(d));
}
// `map_body` changes only the body and preserves headers, trailers, and compression.
#[test]
fn response_map_body_preserves_metadata() {
let r = Response::new(2u32)
.with_header("x-h", "1")
.with_trailer("x-t", "2")
.compress(true);
let r = r.map_body(|n| n.to_string());
assert_eq!(r.body, "2");
assert_eq!(r.headers.get("x-h").unwrap(), "1");
assert_eq!(r.trailers.get("x-t").unwrap(), "2");
assert_eq!(r.compress, Some(true));
}
// `Response::stream` yields every item of the wrapped stream, in order.
#[tokio::test]
async fn response_stream_yields_items() {
use futures::StreamExt;
let r: Response<ServiceStream<i32>> =
Response::stream(futures::stream::iter([Ok(1), Ok(2), Ok(3)]));
let collected: Vec<_> = r.body.map(|x| x.unwrap()).collect().await;
assert_eq!(collected, vec![1, 2, 3]);
}
// The infallible builder panics on an invalid header name (space is not allowed).
#[test]
#[should_panic]
fn with_header_panics_on_invalid_name() {
let _ = Response::new(()).with_header("invalid header name", "v");
}
// The fallible builder reports an invalid name as `InvalidHeaderName`.
#[test]
fn try_with_header_errors_on_invalid_name() {
let err = Response::new(())
.try_with_header("invalid header name", "v")
.unwrap_err();
assert!(err.is::<http::header::InvalidHeaderName>());
}
// Repeated `try_with_header` calls append rather than replace.
#[test]
fn try_with_header_ok_appends() {
let r = Response::new(())
.try_with_header("x-a", "1")
.unwrap()
.try_with_header("x-a", "2")
.unwrap();
let vals: Vec<_> = r.headers.get_all("x-a").iter().collect();
assert_eq!(vals.len(), 2);
}
// A newline in a trailer value is reported as `InvalidHeaderValue`.
#[test]
fn try_with_trailer_errors_on_invalid_value() {
let err = Response::new(())
.try_with_trailer("x-t", "bad\nvalue")
.unwrap_err();
assert!(err.is::<http::header::InvalidHeaderValue>());
}
// A borrowed view encodes to protobuf bytes that decode as the owned message.
#[test]
fn encode_view_body_proto() {
use buffa_types::google::protobuf::__buffa::view::StringValueView;
let v = StringValueView {
value: "hi",
..Default::default()
};
let bytes = encode_view_body(&v, CodecFormat::Proto).unwrap();
assert_eq!(StringValue::decode_from_slice(&bytes).unwrap().value, "hi");
}
// Views reject the JSON codec with an Unimplemented error mentioning "JSON codec".
#[test]
fn encode_view_body_json_errors() {
use buffa_types::google::protobuf::__buffa::view::StringValueView;
let v = StringValueView::default();
let err = encode_view_body(&v, CodecFormat::Json).unwrap_err();
assert_eq!(err.code, crate::ErrorCode::Unimplemented);
assert!(err.message.as_deref().unwrap().contains("JSON codec"));
}
// Test-only wrapper that makes a StringValue view `Encodable<StringValue>`,
// so it can serve as the `Borrowed` arm of `MaybeBorrowed`.
struct V<'a>(buffa_types::google::protobuf::__buffa::view::StringValueView<'a>);
impl Encodable<StringValue> for V<'_> {
fn encode(&self, c: CodecFormat) -> Result<Bytes, ConnectError> {
encode_view_body(&self.0, c)
}
}
// `MaybeBorrowed` delegates encoding to whichever variant it holds.
#[test]
fn maybe_borrowed_dispatch() {
use buffa_types::google::protobuf::__buffa::view::StringValueView;
let owned: MaybeBorrowed<StringValue, V<'_>> =
MaybeBorrowed::Owned(StringValue::from("owned"));
let borrowed = MaybeBorrowed::Borrowed(V(StringValueView {
value: "view",
..Default::default()
}));
assert_eq!(
StringValue::decode_from_slice(&owned.encode(CodecFormat::Proto).unwrap())
.unwrap()
.value,
"owned"
);
assert_eq!(
StringValue::decode_from_slice(&borrowed.encode(CodecFormat::Proto).unwrap())
.unwrap()
.value,
"view"
);
}
// The borrowed arm surfaces the view's Unimplemented error for JSON.
#[test]
fn maybe_borrowed_borrowed_json_unimplemented() {
use buffa_types::google::protobuf::__buffa::view::StringValueView;
let borrowed: MaybeBorrowed<StringValue, V<'_>> =
MaybeBorrowed::Borrowed(V(StringValueView::default()));
let err = borrowed.encode(CodecFormat::Json).unwrap_err();
assert_eq!(err.code, crate::ErrorCode::Unimplemented);
}
// `with_extensions` installs the given extension map on the context.
#[test]
fn request_context_with_extensions() {
#[derive(Clone, Debug, PartialEq)]
struct Peer(u32);
let mut ext = http::Extensions::new();
ext.insert(Peer(7));
let ctx = RequestContext::new(HeaderMap::new()).with_extensions(ext);
assert_eq!(ctx.extensions.get::<Peer>(), Some(&Peer(7)));
}
}