// restate_sdk/endpoint/mod.rs

1mod builder;
2mod context;
3mod futures;
4mod handler_state;
5
6pub use builder::{Builder, HandlerOptions, ServiceOptions};
7
8use crate::endpoint::futures::handler_state_aware::HandlerStateAwareFuture;
9use crate::endpoint::futures::intercept_error::InterceptErrorFuture;
10use crate::endpoint::handler_state::HandlerStateNotifier;
11use crate::service::Service;
12use ::futures::future::BoxFuture;
13use ::futures::{FutureExt, Stream, StreamExt, TryStreamExt};
14use bytes::Bytes;
15pub use context::{ContextInternal, InputMetadata};
16use http::header::CONTENT_TYPE;
17use http::{HeaderName, HeaderValue};
18use http_body::{Body, Frame, SizeHint};
19use http_body_util::{BodyExt, Either, Full};
20use pin_project_lite::pin_project;
21use restate_sdk_shared_core::{
22    CoreVM, Error as CoreError, Header, HeaderMap, IdentityVerifier, ResponseHead, VM, VerifyError,
23};
24use std::collections::HashMap;
25use std::convert::Infallible;
26use std::future::poll_fn;
27use std::ops::Deref;
28use std::pin::Pin;
29use std::sync::Arc;
30use std::task::{Context, Poll, ready};
31use tokio::sync::mpsc;
32use tracing::{Instrument, info_span, warn};
33
34#[allow(clippy::declare_interior_mutable_const)]
35const X_RESTATE_SERVER: HeaderName = HeaderName::from_static("x-restate-server");
36const X_RESTATE_SERVER_VALUE: HeaderValue =
37    HeaderValue::from_static(concat!("restate-sdk-rust/", env!("CARGO_PKG_VERSION")));
38const DISCOVERY_CONTENT_TYPE_V2: &str = "application/vnd.restate.endpointmanifest.v2+json";
39const DISCOVERY_CONTENT_TYPE_V3: &str = "application/vnd.restate.endpointmanifest.v3+json";
40const DISCOVERY_CONTENT_TYPE_V4: &str = "application/vnd.restate.endpointmanifest.v4+json";
41
42type BoxError = Box<dyn std::error::Error + Send + Sync + 'static>;
43
44// TODO can we have the backtrace here?
45/// Endpoint error. This encapsulates any error that happens within the SDK while processing a request.
46#[derive(Debug, thiserror::Error)]
47#[error(transparent)]
48pub struct Error(#[from] ErrorInner);
49
50impl Error {
51    /// New error for unknown handler
52    pub fn unknown_handler(service_name: &str, handler_name: &str) -> Self {
53        Self(ErrorInner::UnknownServiceHandler(
54            service_name.to_owned(),
55            handler_name.to_owned(),
56        ))
57    }
58
59    /// Returns the HTTP status code for this error.
60    pub fn status_code(&self) -> u16 {
61        match &self.0 {
62            ErrorInner::VM(e) => e.code(),
63            ErrorInner::UnknownService(_) | ErrorInner::UnknownServiceHandler(_, _) => 404,
64            ErrorInner::Suspended
65            | ErrorInner::UnexpectedOutputClosed
66            | ErrorInner::UnexpectedValueVariantForSyscall { .. }
67            | ErrorInner::Deserialization { .. }
68            | ErrorInner::Serialization { .. }
69            | ErrorInner::HandlerResult { .. }
70            | ErrorInner::InputDrain(_) => 500,
71            ErrorInner::FieldRequiresMinimumVersion { .. } => 500,
72            ErrorInner::BadDiscoveryVersion(_) => 415,
73            ErrorInner::Header { .. } | ErrorInner::BadPath { .. } => 400,
74            ErrorInner::IdentityVerification(_) => 401,
75        }
76    }
77}
78
79#[derive(Debug, thiserror::Error)]
80pub(crate) enum ErrorInner {
81    #[error("Received a request for unknown service '{0}'")]
82    UnknownService(String),
83    #[error("Received a request for unknown service handler '{0}/{1}'")]
84    UnknownServiceHandler(String, String),
85    #[error("Error when processing the request: {0:?}")]
86    VM(#[from] CoreError),
87    #[error("Error when verifying identity: {0:?}")]
88    IdentityVerification(#[from] VerifyError),
89    #[error("Cannot convert header '{0}', reason: {1}")]
90    Header(String, #[source] BoxError),
91    #[error(
92        "Cannot reply to discovery, got accept header '{0}' but currently supported discovery versions are v2 and v3"
93    )]
94    BadDiscoveryVersion(String),
95    #[error(
96        "The field '{0}' was set in the service/handler options, but it requires minimum discovery protocol version {1}"
97    )]
98    FieldRequiresMinimumVersion(&'static str, u32),
99    #[error("Bad path '{0}', expected either '/discover' or '/invoke/service/handler'")]
100    BadPath(String),
101    #[error("Suspended")]
102    Suspended,
103    #[error("Unexpected output closed")]
104    UnexpectedOutputClosed,
105    #[error("Unexpected value variant {variant} for syscall '{syscall}'")]
106    UnexpectedValueVariantForSyscall {
107        variant: &'static str,
108        syscall: &'static str,
109    },
110    #[error("Failed to deserialize with '{syscall}': {err:?}'")]
111    Deserialization {
112        syscall: &'static str,
113        #[source]
114        err: BoxError,
115    },
116    #[error("Failed to serialize with '{syscall}': {err:?}'")]
117    Serialization {
118        syscall: &'static str,
119        #[source]
120        err: BoxError,
121    },
122    #[error("Handler failed with retryable error: {err:?}'")]
123    HandlerResult {
124        #[source]
125        err: BoxError,
126    },
127    #[error("Error while draining the input stream: {0}")]
128    InputDrain(BoxError),
129}
130
131impl From<CoreError> for Error {
132    fn from(e: CoreError) -> Self {
133        if e.is_suspended_error() {
134            return ErrorInner::Suspended.into();
135        }
136        ErrorInner::from(e).into()
137    }
138}
139
140struct BoxedService(
141    Box<dyn Service<Future = BoxFuture<'static, Result<(), Error>>> + Send + Sync + 'static>,
142);
143
144impl BoxedService {
145    pub fn new<
146        S: Service<Future = BoxFuture<'static, Result<(), Error>>> + Send + Sync + 'static,
147    >(
148        service: S,
149    ) -> Self {
150        Self(Box::new(service))
151    }
152}
153
154impl Service for BoxedService {
155    type Future = BoxFuture<'static, Result<(), Error>>;
156
157    fn handle(&self, req: ContextInternal) -> Self::Future {
158        self.0.handle(req)
159    }
160}
161
162/// This struct encapsulates all the business logic to handle incoming requests to the SDK,
163/// including service discovery, invocations and identity verification.
164///
165/// It internally wraps the provided services. This structure is cheaply cloneable.
166#[derive(Clone)]
167pub struct Endpoint(Arc<EndpointInner>);
168
169impl Endpoint {
170    /// Create a new builder for [`Endpoint`].
171    pub fn builder() -> Builder {
172        Builder::new()
173    }
174}
175
176struct EndpointInner {
177    svcs: HashMap<String, BoxedService>,
178    discovery_services: Vec<crate::discovery::Service>,
179    identity_verifier: IdentityVerifier,
180}
181
/// Protocol mode advertised in the discovery manifest.
///
/// Defaults to [`ProtocolMode::BidiStream`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum ProtocolMode {
    /// Advertise the request/response protocol mode.
    #[allow(dead_code)]
    RequestResponse,
    /// Advertise the bidirectional streaming protocol mode.
    #[default]
    BidiStream,
}

/// Options for [`Endpoint::handle`].
#[derive(Debug, Clone, Default)]
pub struct HandleOptions {
    /// Protocol mode reported when answering discovery requests.
    pub protocol_mode: ProtocolMode,
}
195
196impl Endpoint {
197    /// Handle an [`http::Request`], producing an [`http::Response`].
198    pub fn handle<B: Body<Data = Bytes, Error: Into<BoxError> + Send> + Send + 'static>(
199        &self,
200        req: http::Request<B>,
201    ) -> http::Response<ResponseBody> {
202        self.handle_with_options(req, HandleOptions::default())
203    }
204
205    /// Handle an [`http::Request`], producing an [`http::Response`].
206    pub fn handle_with_options<
207        B: Body<Data = Bytes, Error: Into<BoxError> + Send> + Send + 'static,
208    >(
209        &self,
210        req: http::Request<B>,
211        options: HandleOptions,
212    ) -> http::Response<ResponseBody> {
213        let (parts, body) = req.into_parts();
214        let path = parts.uri.path();
215        let headers = parts.headers;
216
217        if let Err(e) = self.0.identity_verifier.verify_identity(&headers, path) {
218            return error_response(ErrorInner::IdentityVerification(e));
219        }
220
221        let parts: Vec<&str> = path.split('/').collect();
222
223        if parts.last() == Some(&"health") {
224            return self.handle_health();
225        }
226        if parts.last() == Some(&"discover") {
227            return self.handle_discovery(headers, options.protocol_mode);
228        }
229
230        // Parse service name/handler name
231        let (svc_name, handler_name) = match parts.get(parts.len() - 3..) {
232            None => return error_response(ErrorInner::BadPath(path.to_owned())),
233            Some(last_elements) if last_elements[0] != "invoke" => {
234                return error_response(ErrorInner::BadPath(path.to_owned()));
235            }
236            Some(last_elements) => (last_elements[1].to_owned(), last_elements[2].to_owned()),
237        };
238
239        // Prepare vm
240        let vm = match CoreVM::new(headers, Default::default()) {
241            Ok(vm) => vm,
242            Err(e) => return error_response(e),
243        };
244        let ResponseHead {
245            status_code,
246            headers,
247            ..
248        } = vm.get_response_head();
249
250        // Resolve service
251        if !self.0.svcs.contains_key(&svc_name) {
252            return error_response(ErrorInner::UnknownService(svc_name.to_owned()));
253        }
254
255        // Prepare handle_invocation future
256        let input_receiver =
257            InputReceiver::from_stream(body.into_data_stream().map_err(|e| e.into()));
258        let (output_tx, output_rx) = mpsc::unbounded_channel();
259        let output_sender = OutputSender::from_channel(output_tx);
260        let handle_invocation_fut = Box::pin(handle_invocation(
261            svc_name,
262            handler_name,
263            vm,
264            Arc::clone(&self.0),
265            input_receiver,
266            output_sender,
267        ));
268
269        // Wrap the invocation runner in the response
270        // When the body is pulled, the invocation gets processed.
271        let mut invocation_response_builder = http::Response::builder()
272            .status(status_code)
273            .header(X_RESTATE_SERVER, X_RESTATE_SERVER_VALUE);
274        for Header { key, value } in headers {
275            invocation_response_builder =
276                invocation_response_builder.header(key.deref(), value.deref());
277        }
278        invocation_response_builder
279            .body(
280                Either::Right(InvocationRunnerBody {
281                    fut: Some(handle_invocation_fut),
282                    output_rx,
283                    end_stream: false,
284                })
285                .into(),
286            )
287            .expect("Headers should be valid")
288    }
289
290    fn handle_health(&self) -> http::Response<ResponseBody> {
291        simple_response(200, vec![], Bytes::default())
292    }
293
294    fn handle_discovery(
295        &self,
296        headers: http::HeaderMap,
297        protocol_mode: ProtocolMode,
298    ) -> http::Response<ResponseBody> {
299        // Extract Accept header from request
300        let accept_header = match headers
301            .extract("accept")
302            .map_err(|e| ErrorInner::Header("accept".to_owned(), Box::new(e)))
303        {
304            Ok(h) => h,
305            Err(e) => return error_response(e),
306        };
307
308        // Negotiate discovery protocol version
309        let mut version = 2;
310        let mut content_type = DISCOVERY_CONTENT_TYPE_V2;
311        if let Some(accept) = accept_header {
312            if accept.contains(DISCOVERY_CONTENT_TYPE_V4) {
313                version = 4;
314                content_type = DISCOVERY_CONTENT_TYPE_V4;
315            } else if accept.contains(DISCOVERY_CONTENT_TYPE_V3) {
316                version = 3;
317                content_type = DISCOVERY_CONTENT_TYPE_V3;
318            } else if accept.contains(DISCOVERY_CONTENT_TYPE_V2) {
319                version = 2;
320                content_type = DISCOVERY_CONTENT_TYPE_V2;
321            } else {
322                return error_response(ErrorInner::BadDiscoveryVersion(accept.to_owned()));
323            }
324        }
325
326        if let Err(e) = self.validate_discovery_request(version) {
327            return error_response(e);
328        }
329
330        simple_response(
331            200,
332            vec![Header {
333                key: "content-type".into(),
334                value: content_type.into(),
335            }],
336            Bytes::from(
337                serde_json::to_string(&crate::discovery::Endpoint {
338                    lambda_compression: None,
339                    max_protocol_version: std::num::NonZero::new(5).unwrap(),
340                    min_protocol_version: std::num::NonZero::new(5).unwrap(),
341                    protocol_mode: Some(match protocol_mode {
342                        ProtocolMode::RequestResponse => {
343                            crate::discovery::ProtocolMode::RequestResponse
344                        }
345                        ProtocolMode::BidiStream => crate::discovery::ProtocolMode::BidiStream,
346                    }),
347                    services: self.0.discovery_services.clone(),
348                })
349                .expect("Discovery should be serializable"),
350            ),
351        )
352    }
353
354    fn validate_discovery_request(&self, version: usize) -> Result<(), ErrorInner> {
355        // Validate that new discovery fields aren't used with older protocol versions
356        if version <= 3 {
357            // Check for new discovery fields in version 3 that shouldn't be used in version 2 or lower
358            for service in &self.0.discovery_services {
359                if service.retry_policy_initial_interval.is_some()
360                    || service.retry_policy_exponentiation_factor.is_some()
361                    || service.retry_policy_max_interval.is_some()
362                    || service.retry_policy_max_attempts.is_some()
363                    || service.retry_policy_on_max_attempts.is_some()
364                {
365                    Err(ErrorInner::FieldRequiresMinimumVersion("retry_policy", 4))?;
366                }
367
368                for handler in &service.handlers {
369                    if handler.retry_policy_initial_interval.is_some()
370                        || handler.retry_policy_exponentiation_factor.is_some()
371                        || handler.retry_policy_max_interval.is_some()
372                        || handler.retry_policy_max_attempts.is_some()
373                        || handler.retry_policy_on_max_attempts.is_some()
374                    {
375                        Err(ErrorInner::FieldRequiresMinimumVersion("retry_policy", 4))?;
376                    }
377                }
378            }
379        }
380        if version <= 2 {
381            // Check for new discovery fields in version 3 that shouldn't be used in version 2 or lower
382            for service in &self.0.discovery_services {
383                if service.inactivity_timeout.is_some() {
384                    Err(ErrorInner::FieldRequiresMinimumVersion(
385                        "inactivity_timeout",
386                        3,
387                    ))?;
388                }
389                if service.abort_timeout.is_some() {
390                    Err(ErrorInner::FieldRequiresMinimumVersion("abort_timeout", 3))?;
391                }
392                if service.idempotency_retention.is_some() {
393                    Err(ErrorInner::FieldRequiresMinimumVersion(
394                        "idempotency_retention",
395                        3,
396                    ))?;
397                }
398                if service.journal_retention.is_some() {
399                    Err(ErrorInner::FieldRequiresMinimumVersion(
400                        "journal_retention",
401                        3,
402                    ))?;
403                }
404                if service.enable_lazy_state.is_some() {
405                    Err(ErrorInner::FieldRequiresMinimumVersion(
406                        "enable_lazy_state",
407                        3,
408                    ))?;
409                }
410                if service.ingress_private.is_some() {
411                    Err(ErrorInner::FieldRequiresMinimumVersion(
412                        "ingress_private",
413                        3,
414                    ))?;
415                }
416
417                for handler in &service.handlers {
418                    if handler.inactivity_timeout.is_some() {
419                        Err(ErrorInner::FieldRequiresMinimumVersion(
420                            "inactivity_timeout",
421                            3,
422                        ))?;
423                    }
424                    if handler.abort_timeout.is_some() {
425                        Err(ErrorInner::FieldRequiresMinimumVersion("abort_timeout", 3))?;
426                    }
427                    if handler.idempotency_retention.is_some() {
428                        Err(ErrorInner::FieldRequiresMinimumVersion(
429                            "idempotency_retention",
430                            3,
431                        ))?;
432                    }
433                    if handler.journal_retention.is_some() {
434                        Err(ErrorInner::FieldRequiresMinimumVersion(
435                            "journal_retention",
436                            3,
437                        ))?;
438                    }
439                    if handler.workflow_completion_retention.is_some() {
440                        Err(ErrorInner::FieldRequiresMinimumVersion(
441                            "workflow_retention",
442                            3,
443                        ))?;
444                    }
445                    if handler.enable_lazy_state.is_some() {
446                        Err(ErrorInner::FieldRequiresMinimumVersion(
447                            "enable_lazy_state",
448                            3,
449                        ))?;
450                    }
451                    if handler.ingress_private.is_some() {
452                        Err(ErrorInner::FieldRequiresMinimumVersion(
453                            "ingress_private",
454                            3,
455                        ))?;
456                    }
457                }
458            }
459        }
460        Ok(())
461    }
462}
463
464type ResponseBodyInner = Either<Full<Bytes>, InvocationRunnerBody>;
465pin_project! {
466    pub struct ResponseBody {
467        #[pin]
468        inner: ResponseBodyInner
469    }
470}
471
472impl From<ResponseBodyInner> for ResponseBody {
473    fn from(e: ResponseBodyInner) -> Self {
474        ResponseBody { inner: e }
475    }
476}
477
478impl Body for ResponseBody {
479    type Data = <ResponseBodyInner as Body>::Data;
480    type Error = <ResponseBodyInner as Body>::Error;
481
482    fn poll_frame(
483        self: Pin<&mut Self>,
484        cx: &mut Context<'_>,
485    ) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
486        self.project().inner.poll_frame(cx)
487    }
488
489    fn is_end_stream(&self) -> bool {
490        self.inner.is_end_stream()
491    }
492
493    fn size_hint(&self) -> SizeHint {
494        self.inner.size_hint()
495    }
496}
497
498fn simple_response(
499    status_code: u16,
500    headers: Vec<Header>,
501    body: Bytes,
502) -> http::Response<ResponseBody> {
503    let mut response_builder = http::Response::builder()
504        .status(status_code)
505        .header(X_RESTATE_SERVER, X_RESTATE_SERVER_VALUE);
506
507    for header in headers {
508        response_builder = response_builder.header(header.key.deref(), header.value.deref());
509    }
510
511    response_builder
512        .body(Either::Left(Full::new(body)).into())
513        .expect("headers must be valid")
514}
515
516fn error_response(e: impl Into<Error>) -> http::Response<ResponseBody> {
517    let error = e.into();
518    http::Response::builder()
519        .status(error.status_code())
520        .header(X_RESTATE_SERVER, X_RESTATE_SERVER_VALUE)
521        .header(CONTENT_TYPE, "text/plain")
522        .body(Either::Left(Full::new(error.to_string().into())).into())
523        .expect("headers must be valid")
524}
525
526// --- Handle invocation future
527
528struct OutputSender(mpsc::UnboundedSender<Bytes>);
529
530impl OutputSender {
531    fn from_channel(tx: mpsc::UnboundedSender<Bytes>) -> Self {
532        Self(tx)
533    }
534
535    fn send(&self, b: Bytes) -> bool {
536        self.0.send(b).is_ok()
537    }
538}
539
540struct InputReceiver(Pin<Box<dyn Stream<Item = Result<Bytes, BoxError>> + Send + 'static>>);
541
542impl InputReceiver {
543    fn from_stream<S: Stream<Item = Result<Bytes, BoxError>> + Send + 'static>(s: S) -> Self {
544        Self(Box::pin(s))
545    }
546
547    async fn recv(&mut self) -> Option<Result<Bytes, BoxError>> {
548        poll_fn(|cx| self.poll_recv(cx)).await
549    }
550
551    fn poll_recv(&mut self, cx: &mut Context<'_>) -> Poll<Option<Result<Bytes, BoxError>>> {
552        self.0.poll_next_unpin(cx)
553    }
554}
555
556async fn handle_invocation(
557    svc_name: String,
558    handler_name: String,
559    mut vm: CoreVM,
560    endpoint: Arc<EndpointInner>,
561    mut input_rx: InputReceiver,
562    output_tx: OutputSender,
563) -> Result<(), Error> {
564    // Retrieve the service from the Arc
565    let svc = endpoint
566        .svcs
567        .get(&svc_name)
568        .expect("service must exist at this point");
569
570    let span = info_span!(
571        "restate_sdk_endpoint_handle",
572        "rpc.system" = "restate",
573        "rpc.service" = svc_name,
574        "rpc.method" = handler_name,
575        "restate.sdk.is_replaying" = false
576    );
577    async move {
578        init_loop_vm(&mut vm, &mut input_rx).await?;
579
580        // Initialize handler context
581        let (handler_state_tx, handler_state_rx) = HandlerStateNotifier::new();
582        let ctx = ContextInternal::new(
583            vm,
584            svc_name,
585            handler_name,
586            input_rx,
587            output_tx,
588            handler_state_tx,
589        );
590
591        // Start user code
592        let user_code_fut = InterceptErrorFuture::new(ctx.clone(), svc.handle(ctx.clone()));
593
594        // Wrap it in handler state aware future
595        let result =
596            HandlerStateAwareFuture::new(ctx.clone(), handler_state_rx, user_code_fut).await;
597
598        // Drain the request input stream before returning. This ensures we don't
599        // close the HTTP/2 response stream before the request stream is done,
600        // which causes connection errors on proxies like Google Cloud Run.
601        ctx.drain_input().await?;
602
603        result
604    }
605    .instrument(span)
606    .await
607}
608
609async fn init_loop_vm(vm: &mut CoreVM, input_rx: &mut InputReceiver) -> Result<(), ErrorInner> {
610    while !vm.is_ready_to_execute().map_err(ErrorInner::VM)? {
611        match input_rx.recv().await {
612            Some(Ok(b)) => vm.notify_input(b),
613            Some(Err(e)) => vm.notify_error(
614                CoreError::new(500u16, format!("Error when reading the body: {e}")),
615                None,
616            ),
617            None => vm.notify_input_closed(),
618        }
619    }
620    Ok(())
621}
622
623// --- Invocation runner body
624
625pub struct InvocationRunnerBody {
626    fut: Option<BoxFuture<'static, Result<(), Error>>>,
627    output_rx: mpsc::UnboundedReceiver<Bytes>,
628    end_stream: bool,
629}
630
631impl Body for InvocationRunnerBody {
632    type Data = Bytes;
633    type Error = Infallible;
634
635    fn poll_frame(
636        mut self: Pin<&mut Self>,
637        cx: &mut Context<'_>,
638    ) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
639        // First try to consume the runner future
640        if let Some(mut fut) = self.fut.take() {
641            match fut.poll_unpin(cx) {
642                Poll::Ready(res) => {
643                    if let Err(e) = res {
644                        warn!("Handler failure: {e:?}")
645                    }
646                    self.output_rx.close();
647                }
648                Poll::Pending => {
649                    self.fut = Some(fut);
650                }
651            }
652        }
653
654        if let Some(out) = ready!(self.output_rx.poll_recv(cx)) {
655            Poll::Ready(Some(Ok(Frame::data(out))))
656        } else {
657            self.end_stream = true;
658            Poll::Ready(None)
659        }
660    }
661
662    fn is_end_stream(&self) -> bool {
663        self.end_stream
664    }
665}