use bytes::Bytes;
use http_body::Frame;
use serde::Serialize;
use typeway_grpc::framing::encode_grpc_frame;
use typeway_grpc::status::{GrpcCode, GrpcStatus};
use crate::body::{body_from_stream, BoxBody, BoxBodyError};
use crate::response::IntoResponse;
/// Zero-sized response extension tagging an HTTP response whose body is a
/// streaming gRPC body.
///
/// Inserted by `GrpcStream`'s `IntoResponse` implementation; presumably read
/// by the serving layer to handle the body/trailers specially — TODO confirm.
#[derive(Clone, Copy, Debug)]
pub(crate) struct GrpcStreamMarker;
/// Producer half of a [`GrpcStream`], obtained from [`GrpcStream::channel`].
///
/// Each message is forwarded over a bounded tokio mpsc channel to the paired
/// `GrpcStream`, which turns it into part of the HTTP response body.
pub struct GrpcStreamSender<T> {
    tx: tokio::sync::mpsc::Sender<Result<T, GrpcStatus>>,
}

impl<T> GrpcStreamSender<T> {
    /// Queues `item` to be encoded and sent to the client.
    ///
    /// Waits for channel capacity if the buffer is currently full.
    ///
    /// # Errors
    ///
    /// Returns [`typeway_grpc::StreamSendError`] when the channel's receiving
    /// side has been closed or dropped (e.g. the response body went away).
    pub async fn send(&self, item: T) -> Result<(), typeway_grpc::StreamSendError> {
        match self.tx.send(Ok(item)).await {
            Ok(()) => Ok(()),
            Err(_closed) => Err(typeway_grpc::StreamSendError),
        }
    }

    /// Queues a terminal error `status` for the client.
    ///
    /// The paired `GrpcStream` ends the response body with trailers built
    /// from this status.
    ///
    /// # Errors
    ///
    /// Returns [`typeway_grpc::StreamSendError`] when the channel's receiving
    /// side has been closed or dropped.
    pub async fn send_error(
        &self,
        status: GrpcStatus,
    ) -> Result<(), typeway_grpc::StreamSendError> {
        match self.tx.send(Err(status)).await {
            Ok(()) => Ok(()),
            Err(_closed) => Err(typeway_grpc::StreamSendError),
        }
    }
}
/// A streaming gRPC response fed by a [`GrpcStreamSender`].
///
/// Implements `IntoResponse`: every message delivered through the paired
/// sender becomes part of the response body.
pub struct GrpcStream<T> {
    rx: tokio::sync::mpsc::Receiver<Result<T, GrpcStatus>>,
}

impl<T> GrpcStream<T> {
    /// Creates a connected sender/stream pair backed by a bounded channel that
    /// holds at most `buffer` pending messages.
    ///
    /// # Panics
    ///
    /// Panics if `buffer` is 0, as required by `tokio::sync::mpsc::channel`.
    pub fn channel(buffer: usize) -> (GrpcStreamSender<T>, GrpcStream<T>) {
        let (sender_half, receiver_half) = tokio::sync::mpsc::channel(buffer);
        let sender = GrpcStreamSender { tx: sender_half };
        (sender, Self { rx: receiver_half })
    }
}
/// State threaded through the `futures::stream::unfold` that produces the
/// body frames of a `GrpcStream` response.
struct StreamState<T> {
    /// Set once the trailers frame has been produced; the next poll ends the
    /// body.
    done: bool,
    /// Receiving end of the channel fed by `GrpcStreamSender`.
    rx: tokio::sync::mpsc::Receiver<Result<T, GrpcStatus>>,
}
/// Turns a [`GrpcStream`] into a streaming HTTP response.
///
/// Each `Ok(item)` received from the channel is serialized to JSON and sent as
/// one data frame produced by `encode_grpc_frame`. The body always finishes
/// with a single trailers frame:
///
/// - `Err(status)` from the sender ends the body with that status;
/// - the sender side closing ends the body with `GrpcCode::Ok`;
/// - an item that fails to serialize ends the body with `GrpcCode::Internal`
///   rather than emitting an empty (undecodable) JSON frame.
///
/// The response has status 200, `content-type: application/grpc+json`, and a
/// `GrpcStreamMarker` extension.
impl<T: Serialize + Send + 'static> IntoResponse for GrpcStream<T> {
    fn into_response(self) -> http::Response<BoxBody> {
        let state = StreamState {
            rx: self.rx,
            done: false,
        };
        let stream = futures::stream::unfold(state, |mut state| async move {
            // Trailers have already been emitted; terminate the body.
            if state.done {
                return None;
            }
            let frame: Result<Frame<Bytes>, BoxBodyError> = match state.rx.recv().await {
                Some(Ok(item)) => match serde_json::to_vec(&item) {
                    Ok(json_bytes) => {
                        let framed = encode_grpc_frame(&json_bytes);
                        Ok(Frame::data(Bytes::from(framed)))
                    }
                    Err(err) => {
                        // Headers are already sent, so the only way to report a
                        // server-side encoding failure is to end the call with
                        // an error status in the trailers.
                        state.done = true;
                        let status = GrpcStatus {
                            code: GrpcCode::Internal,
                            message: format!("failed to serialize response message: {}", err),
                        };
                        Ok(Frame::trailers(build_trailers(&status)))
                    }
                },
                Some(Err(status)) => {
                    state.done = true;
                    Ok(Frame::trailers(build_trailers(&status)))
                }
                None => {
                    // All senders dropped: the stream completed successfully.
                    state.done = true;
                    let ok_status = GrpcStatus {
                        code: GrpcCode::Ok,
                        message: String::new(),
                    };
                    Ok(Frame::trailers(build_trailers(&ok_status)))
                }
            };
            Some((frame, state))
        });
        let body = body_from_stream(stream);
        let mut res = http::Response::new(body);
        *res.status_mut() = http::StatusCode::OK;
        res.headers_mut().insert(
            http::header::CONTENT_TYPE,
            http::HeaderValue::from_static("application/grpc+json"),
        );
        res.extensions_mut().insert(GrpcStreamMarker);
        res
    }
}
/// Builds the gRPC trailer map for `status`.
///
/// Always sets `grpc-status` to the numeric status code. When the status
/// message is non-empty it is also sent as `grpc-message`, percent-encoded as
/// the gRPC-over-HTTP/2 protocol requires. Without encoding, messages
/// containing non-ASCII text or control characters would fail header
/// validation and be dropped, and a raw `%` would be mis-decoded by clients.
fn build_trailers(status: &GrpcStatus) -> http::HeaderMap {
    let mut trailers = http::HeaderMap::new();
    // Integer -> HeaderValue conversion is infallible (digits and '-' only).
    trailers.insert("grpc-status", http::HeaderValue::from(status.code.as_i32()));
    if !status.message.is_empty() {
        let encoded = percent_encode_grpc_message(&status.message);
        // The encoded form is printable ASCII only, so this cannot fail; the
        // `if let` keeps trailer building panic-free regardless.
        if let Ok(val) = http::HeaderValue::from_str(&encoded) {
            trailers.insert("grpc-message", val);
        }
    }
    trailers
}

/// Percent-encodes `message` for the `grpc-message` trailer.
///
/// Bytes in the printable ASCII range `0x20..=0x7E` pass through unchanged,
/// except `%` itself. Every other byte of the UTF-8 encoding, including `%`,
/// is written as `%XX` with uppercase hex digits.
fn percent_encode_grpc_message(message: &str) -> String {
    use std::fmt::Write as _;

    let mut out = String::with_capacity(message.len());
    for byte in message.bytes() {
        if (0x20..=0x7E).contains(&byte) && byte != b'%' {
            out.push(char::from(byte));
        } else {
            // Writing to a String never fails.
            let _ = write!(out, "%{:02X}", byte);
        }
    }
    out
}