restate_sdk/
hyper.rs

1//! Hyper integration.
2
3use crate::endpoint;
4use crate::endpoint::{Endpoint, InputReceiver, OutputSender};
5use bytes::Bytes;
6use futures::future::BoxFuture;
7use futures::{FutureExt, TryStreamExt};
8use http::header::CONTENT_TYPE;
9use http::{response, HeaderName, HeaderValue, Request, Response};
10use http_body_util::{BodyExt, Either, Full};
11use hyper::body::{Body, Frame, Incoming};
12use hyper::service::Service;
13use restate_sdk_shared_core::Header;
14use std::convert::Infallible;
15use std::future::{ready, Ready};
16use std::ops::Deref;
17use std::pin::Pin;
18use std::task::{ready, Context, Poll};
19use tokio::sync::mpsc;
20use tracing::{debug, warn};
21
22#[allow(clippy::declare_interior_mutable_const)]
23const X_RESTATE_SERVER: HeaderName = HeaderName::from_static("x-restate-server");
24const X_RESTATE_SERVER_VALUE: HeaderValue =
25    HeaderValue::from_static(concat!("restate-sdk-rust/", env!("CARGO_PKG_VERSION")));
26
27/// Wraps [`Endpoint`] to implement hyper [`Service`].
28#[derive(Clone)]
29pub struct HyperEndpoint(Endpoint);
30
31impl HyperEndpoint {
32    pub fn new(endpoint: Endpoint) -> Self {
33        Self(endpoint)
34    }
35}
36
37impl Service<Request<Incoming>> for HyperEndpoint {
38    type Response = Response<Either<Full<Bytes>, BidiStreamRunner>>;
39    type Error = endpoint::Error;
40    type Future = Ready<Result<Self::Response, Self::Error>>;
41
42    fn call(&self, req: Request<Incoming>) -> Self::Future {
43        let (parts, body) = req.into_parts();
44        let endpoint_response = match self.0.resolve(parts.uri.path(), parts.headers) {
45            Ok(res) => res,
46            Err(err) => {
47                debug!("Error when trying to handle incoming request: {err}");
48                return ready(Ok(Response::builder()
49                    .status(err.status_code())
50                    .header(CONTENT_TYPE, "text/plain")
51                    .header(X_RESTATE_SERVER, X_RESTATE_SERVER_VALUE)
52                    .body(Either::Left(Full::new(Bytes::from(err.to_string()))))
53                    .expect("Headers should be valid")));
54            }
55        };
56
57        match endpoint_response {
58            endpoint::Response::ReplyNow {
59                status_code,
60                headers,
61                body,
62            } => ready(Ok(response_builder_from_response_head(
63                status_code,
64                headers,
65            )
66            .body(Either::Left(Full::new(body)))
67            .expect("Headers should be valid"))),
68            endpoint::Response::BidiStream {
69                status_code,
70                headers,
71                handler,
72            } => {
73                let input_receiver =
74                    InputReceiver::from_stream(body.into_data_stream().map_err(|e| e.into()));
75
76                let (output_tx, output_rx) = mpsc::unbounded_channel();
77                let output_sender = OutputSender::from_channel(output_tx);
78
79                let handler_fut = Box::pin(handler.handle(input_receiver, output_sender));
80
81                ready(Ok(response_builder_from_response_head(
82                    status_code,
83                    headers,
84                )
85                .body(Either::Right(BidiStreamRunner {
86                    fut: Some(handler_fut),
87                    output_rx,
88                    end_stream: false,
89                }))
90                .expect("Headers should be valid")))
91            }
92        }
93    }
94}
95
96fn response_builder_from_response_head(
97    status_code: u16,
98    headers: Vec<Header>,
99) -> response::Builder {
100    let mut response_builder = Response::builder()
101        .status(status_code)
102        .header(X_RESTATE_SERVER, X_RESTATE_SERVER_VALUE);
103
104    for header in headers {
105        response_builder = response_builder.header(header.key.deref(), header.value.deref());
106    }
107
108    response_builder
109}
110
111pub struct BidiStreamRunner {
112    fut: Option<BoxFuture<'static, Result<(), endpoint::Error>>>,
113    output_rx: mpsc::UnboundedReceiver<Bytes>,
114    end_stream: bool,
115}
116
117impl Body for BidiStreamRunner {
118    type Data = Bytes;
119    type Error = Infallible;
120
121    fn poll_frame(
122        mut self: Pin<&mut Self>,
123        cx: &mut Context<'_>,
124    ) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
125        // First try to consume the runner future
126        if let Some(mut fut) = self.fut.take() {
127            match fut.poll_unpin(cx) {
128                Poll::Ready(res) => {
129                    if let Err(e) = res {
130                        warn!("Handler failure: {e:?}")
131                    }
132                    self.output_rx.close();
133                }
134                Poll::Pending => {
135                    self.fut = Some(fut);
136                }
137            }
138        }
139
140        if let Some(out) = ready!(self.output_rx.poll_recv(cx)) {
141            Poll::Ready(Some(Ok(Frame::data(out))))
142        } else {
143            self.end_stream = true;
144            Poll::Ready(None)
145        }
146    }
147
148    fn is_end_stream(&self) -> bool {
149        self.end_stream
150    }
151}