1use crate::endpoint;
4use crate::endpoint::{Endpoint, InputReceiver, OutputSender};
5use bytes::Bytes;
6use futures::future::BoxFuture;
7use futures::{FutureExt, TryStreamExt};
8use http::header::CONTENT_TYPE;
9use http::{response, HeaderName, HeaderValue, Request, Response};
10use http_body_util::{BodyExt, Either, Full};
11use hyper::body::{Body, Frame, Incoming};
12use hyper::service::Service;
13use restate_sdk_shared_core::Header;
14use std::convert::Infallible;
15use std::future::{ready, Ready};
16use std::ops::Deref;
17use std::pin::Pin;
18use std::task::{ready, Context, Poll};
19use tokio::sync::mpsc;
20use tracing::{debug, warn};
21
22#[allow(clippy::declare_interior_mutable_const)]
23const X_RESTATE_SERVER: HeaderName = HeaderName::from_static("x-restate-server");
24const X_RESTATE_SERVER_VALUE: HeaderValue =
25 HeaderValue::from_static(concat!("restate-sdk-rust/", env!("CARGO_PKG_VERSION")));
26
27#[derive(Clone)]
29pub struct HyperEndpoint(Endpoint);
30
31impl HyperEndpoint {
32 pub fn new(endpoint: Endpoint) -> Self {
33 Self(endpoint)
34 }
35}
36
37impl Service<Request<Incoming>> for HyperEndpoint {
38 type Response = Response<Either<Full<Bytes>, BidiStreamRunner>>;
39 type Error = endpoint::Error;
40 type Future = Ready<Result<Self::Response, Self::Error>>;
41
42 fn call(&self, req: Request<Incoming>) -> Self::Future {
43 let (parts, body) = req.into_parts();
44 let endpoint_response = match self.0.resolve(parts.uri.path(), parts.headers) {
45 Ok(res) => res,
46 Err(err) => {
47 debug!("Error when trying to handle incoming request: {err}");
48 return ready(Ok(Response::builder()
49 .status(err.status_code())
50 .header(CONTENT_TYPE, "text/plain")
51 .header(X_RESTATE_SERVER, X_RESTATE_SERVER_VALUE)
52 .body(Either::Left(Full::new(Bytes::from(err.to_string()))))
53 .expect("Headers should be valid")));
54 }
55 };
56
57 match endpoint_response {
58 endpoint::Response::ReplyNow {
59 status_code,
60 headers,
61 body,
62 } => ready(Ok(response_builder_from_response_head(
63 status_code,
64 headers,
65 )
66 .body(Either::Left(Full::new(body)))
67 .expect("Headers should be valid"))),
68 endpoint::Response::BidiStream {
69 status_code,
70 headers,
71 handler,
72 } => {
73 let input_receiver =
74 InputReceiver::from_stream(body.into_data_stream().map_err(|e| e.into()));
75
76 let (output_tx, output_rx) = mpsc::unbounded_channel();
77 let output_sender = OutputSender::from_channel(output_tx);
78
79 let handler_fut = Box::pin(handler.handle(input_receiver, output_sender));
80
81 ready(Ok(response_builder_from_response_head(
82 status_code,
83 headers,
84 )
85 .body(Either::Right(BidiStreamRunner {
86 fut: Some(handler_fut),
87 output_rx,
88 end_stream: false,
89 }))
90 .expect("Headers should be valid")))
91 }
92 }
93 }
94}
95
96fn response_builder_from_response_head(
97 status_code: u16,
98 headers: Vec<Header>,
99) -> response::Builder {
100 let mut response_builder = Response::builder()
101 .status(status_code)
102 .header(X_RESTATE_SERVER, X_RESTATE_SERVER_VALUE);
103
104 for header in headers {
105 response_builder = response_builder.header(header.key.deref(), header.value.deref());
106 }
107
108 response_builder
109}
110
111pub struct BidiStreamRunner {
112 fut: Option<BoxFuture<'static, Result<(), endpoint::Error>>>,
113 output_rx: mpsc::UnboundedReceiver<Bytes>,
114 end_stream: bool,
115}
116
117impl Body for BidiStreamRunner {
118 type Data = Bytes;
119 type Error = Infallible;
120
121 fn poll_frame(
122 mut self: Pin<&mut Self>,
123 cx: &mut Context<'_>,
124 ) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
125 if let Some(mut fut) = self.fut.take() {
127 match fut.poll_unpin(cx) {
128 Poll::Ready(res) => {
129 if let Err(e) = res {
130 warn!("Handler failure: {e:?}")
131 }
132 self.output_rx.close();
133 }
134 Poll::Pending => {
135 self.fut = Some(fut);
136 }
137 }
138 }
139
140 if let Some(out) = ready!(self.output_rx.poll_recv(cx)) {
141 Poll::Ready(Some(Ok(Frame::data(out))))
142 } else {
143 self.end_stream = true;
144 Poll::Ready(None)
145 }
146 }
147
148 fn is_end_stream(&self) -> bool {
149 self.end_stream
150 }
151}