use std::convert::Infallible;
use std::pin::Pin;
use std::task::Context;
use std::task::Poll;
use bytes::Bytes;
use bytes::BytesMut;
use futures_util::Stream;
use futures_util::StreamExt;
use http::HeaderMap;
use http::StatusCode;
use http_body::Frame;
use http_body_util::StreamBody;
use prost::Message;
use super::GrpcError;
use super::framing::MAX_GRPC_MESSAGE_SIZE;
use super::framing::grpc_encode;
use super::status::GrpcStatus;
use crate::body::TakoBody;
use crate::extractors::FromRequest;
use crate::responder::Responder;
use crate::types::Request;
use crate::types::Response;
/// Server-streaming gRPC response.
///
/// Wraps a stream whose items are either a message of type `T` or a
/// [`GrpcStatus`], together with a set of headers stored alongside it.
pub struct GrpcServerStream<S, T>
where
    T: Message + Send + 'static,
    S: Stream<Item = Result<T, GrpcStatus>> + Send + 'static,
{
    /// Source of the outgoing items.
    pub stream: S,
    /// Headers associated with this response; empty unless replaced through
    /// [`GrpcServerStream::with_metadata`].
    pub initial_metadata: HeaderMap,
}

impl<S, T> GrpcServerStream<S, T>
where
    T: Message + Send + 'static,
    S: Stream<Item = Result<T, GrpcStatus>> + Send + 'static,
{
    /// Wraps `stream` with an empty metadata map.
    pub fn new(stream: S) -> Self {
        let initial_metadata = HeaderMap::new();
        Self {
            stream,
            initial_metadata,
        }
    }

    /// Replaces the stored metadata with `headers` and returns the updated value.
    pub fn with_metadata(self, headers: HeaderMap) -> Self {
        Self {
            stream: self.stream,
            initial_metadata: headers,
        }
    }
}
impl<S, T> Responder for GrpcServerStream<S, T>
where
S: Stream<Item = Result<T, GrpcStatus>> + Send + 'static,
T: Message + Send + 'static,
{
fn into_response(self) -> Response {
use std::sync::Arc;
use std::sync::atomic::AtomicBool;
use std::sync::atomic::Ordering;
let error_emitted = Arc::new(AtomicBool::new(false));
let mark_err = error_emitted.clone();
let stream = self.stream.map(move |item| match item {
Ok(msg) => {
let bytes = grpc_encode(&msg);
Ok::<_, Infallible>(Frame::data(Bytes::from(bytes)))
}
Err(status) => {
mark_err.store(true, Ordering::Release);
Ok(Frame::trailers(status.write_trailers()))
}
});
let check_err = error_emitted.clone();
let mut once = false;
let trailer = futures_util::stream::iter(std::iter::from_fn(move || {
if once {
None
} else {
once = true;
if check_err.load(Ordering::Acquire) {
None
} else {
Some(Ok::<_, Infallible>(Frame::trailers(
GrpcStatus::ok().write_trailers(),
)))
}
}
}));
let combined = stream.chain(trailer);
let mut resp = http::Response::builder()
.status(StatusCode::OK)
.header(
http::header::CONTENT_TYPE,
http::HeaderValue::from_static("application/grpc"),
)
.body(TakoBody::new(StreamBody::new(combined)))
.expect("static headers + body construction is infallible");
let headers = resp.headers_mut();
for (k, v) in &self.initial_metadata {
headers.insert(k.clone(), v.clone());
}
resp
}
}
/// Inbound stream of messages of type `T` sent by the client.
pub struct GrpcClientStream<T>
where
    T: Message + Default + Send + 'static,
{
    /// Boxed, pinned stream yielding either a message or a [`GrpcError`].
    pub stream: Pin<Box<dyn Stream<Item = Result<T, GrpcError>> + Send>>,
}
impl<'a, T> FromRequest<'a> for GrpcClientStream<T>
where
    T: Message + Default + Send + 'static,
{
    type Error = GrpcError;

    /// Extracts the request body as a stream of `T` messages.
    ///
    /// Fails with [`GrpcError::InvalidContentType`] unless the `content-type`
    /// header is present, readable as a string, and starts with
    /// `application/grpc`. On success the request's body is taken out of `req`
    /// (leaving the default body behind) and wrapped in a `GrpcFrameStream`.
    fn from_request(
        req: &'a mut Request,
    ) -> impl core::future::Future<Output = core::result::Result<Self, Self::Error>> + Send + 'a {
        async move {
            let is_grpc = req
                .headers()
                .get(http::header::CONTENT_TYPE)
                .and_then(|value| value.to_str().ok())
                .is_some_and(|content_type| content_type.starts_with("application/grpc"));
            if !is_grpc {
                return Err(GrpcError::InvalidContentType);
            }

            let frames = GrpcFrameStream::<T>::new(std::mem::take(req.body_mut()));
            Ok(Self {
                stream: Box::pin(frames),
            })
        }
    }
}
/// Stream adapter that reads a request body and yields messages of type `T`.
struct GrpcFrameStream<T> {
    /// Body the raw bytes are read from.
    body: TakoBody,
    /// Bytes received from `body` that have not been consumed yet.
    buffer: BytesMut,
    /// When `true`, `body` is no longer polled for more data.
    finished: bool,
    /// Records the message type without storing a `T`.
    _marker: std::marker::PhantomData<fn() -> T>,
}

impl<T> GrpcFrameStream<T> {
    /// Creates a stream over `body` with an empty buffer, not yet finished.
    fn new(body: TakoBody) -> Self {
        let buffer = BytesMut::new();
        Self {
            body,
            buffer,
            finished: false,
            _marker: std::marker::PhantomData,
        }
    }
}
impl<T> GrpcFrameStream<T> {
    /// Ends the stream with `err`.
    ///
    /// Marks the stream finished and discards buffered bytes so that the next
    /// poll returns `None` instead of repeating the same error or polling the
    /// body again.
    fn fail_terminal(&mut self, err: GrpcError) -> Poll<Option<Result<T, GrpcError>>> {
        self.finished = true;
        self.buffer.clear();
        Poll::Ready(Some(Err(err)))
    }
}

impl<T> Stream for GrpcFrameStream<T>
where
    T: Message + Default,
{
    type Item = Result<T, GrpcError>;

    /// Yields the next decoded message from the body.
    ///
    /// Each frame is a 5-byte prefix (one flag byte followed by a big-endian
    /// `u32` payload length) and then the payload, which is decoded with
    /// `T::decode`.
    ///
    /// Errors:
    /// * `MessageTooLarge` when the declared length exceeds
    ///   `MAX_GRPC_MESSAGE_SIZE` (terminal);
    /// * `CompressionUnsupported` when the flag byte is non-zero (terminal,
    ///   reported as soon as the prefix is available);
    /// * `BodyReadError` when the body reports an error (terminal);
    /// * `DecodeError` when a complete payload fails to decode (the frame is
    ///   consumed and the stream continues with the next one), or when the body
    ///   ends with leftover bytes that do not form a complete frame (terminal).
    ///
    /// After a terminal error the stream yields `None`.
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        /// Length of the frame prefix: 1 flag byte + 4 length bytes.
        const HEADER_LEN: usize = 5;

        let this = self.get_mut();
        loop {
            if this.buffer.len() >= HEADER_LEN {
                let msg_len = u32::from_be_bytes([
                    this.buffer[1],
                    this.buffer[2],
                    this.buffer[3],
                    this.buffer[4],
                ]) as usize;
                if msg_len > MAX_GRPC_MESSAGE_SIZE {
                    return this.fail_terminal(GrpcError::MessageTooLarge);
                }
                // The flag byte is already buffered; reject without waiting
                // for (and buffering) the rest of the payload.
                if this.buffer[0] != 0 {
                    return this.fail_terminal(GrpcError::CompressionUnsupported);
                }
                let frame_len = HEADER_LEN + msg_len;
                if this.buffer.len() >= frame_len {
                    let frame = this.buffer.split_to(frame_len);
                    let decoded = T::decode(&frame[HEADER_LEN..])
                        .map_err(|e| GrpcError::DecodeError(e.to_string()));
                    return Poll::Ready(Some(decoded));
                }
            }

            if this.finished {
                if this.buffer.is_empty() {
                    return Poll::Ready(None);
                }
                // The body ended in the middle of a frame: surface it instead
                // of silently dropping the partial message.
                return this.fail_terminal(GrpcError::DecodeError(
                    "unexpected end of stream inside a gRPC frame".to_string(),
                ));
            }

            let polled = http_body::Body::poll_frame(Pin::new(&mut this.body), cx);
            match polled {
                Poll::Ready(Some(Ok(frame))) => {
                    // Non-data frames (e.g. trailers) carry no message bytes.
                    if let Some(data) = frame.data_ref() {
                        this.buffer.extend_from_slice(data);
                    }
                }
                Poll::Ready(Some(Err(e))) => {
                    return this.fail_terminal(GrpcError::BodyReadError(e.to_string()));
                }
                Poll::Ready(None) => {
                    this.finished = true;
                }
                Poll::Pending => return Poll::Pending,
            }
        }
    }
}
/// Bidirectional-streaming extractor: carries the inbound `Req` stream and
/// records the response message type `Resp` at the type level.
pub struct GrpcBidi<
    Req: Message + Default + Send + 'static,
    Resp: Message + Send + 'static,
> {
    /// Stream of messages received from the client.
    pub inbound: GrpcClientStream<Req>,
    /// Marker for the response message type; holds no data.
    pub _phantom: std::marker::PhantomData<Resp>,
}
impl<'a, Req, Resp> FromRequest<'a> for GrpcBidi<Req, Resp>
where
    Req: Message + Default + Send + 'static,
    Resp: Message + Send + 'static,
{
    type Error = GrpcError;

    /// Extracts the inbound stream by delegating to
    /// `GrpcClientStream::<Req>::from_request`; any error from that extractor
    /// is returned unchanged.
    fn from_request(
        req: &'a mut Request,
    ) -> impl core::future::Future<Output = core::result::Result<Self, Self::Error>> + Send + 'a {
        async move {
            let inbound = GrpcClientStream::<Req>::from_request(req).await?;
            Ok(Self {
                inbound,
                _phantom: std::marker::PhantomData,
            })
        }
    }
}