1mod builder;
2mod context;
3mod futures;
4mod handler_state;
5
6pub use builder::{Builder, HandlerOptions, ServiceOptions};
7
8use crate::endpoint::futures::handler_state_aware::HandlerStateAwareFuture;
9use crate::endpoint::futures::intercept_error::InterceptErrorFuture;
10use crate::endpoint::handler_state::HandlerStateNotifier;
11use crate::service::Service;
12use ::futures::future::BoxFuture;
13use ::futures::{FutureExt, Stream, StreamExt, TryStreamExt};
14use bytes::Bytes;
15pub use context::{ContextInternal, InputMetadata};
16use http::header::CONTENT_TYPE;
17use http::{HeaderName, HeaderValue};
18use http_body::{Body, Frame, SizeHint};
19use http_body_util::{BodyExt, Either, Full};
20use pin_project_lite::pin_project;
21use restate_sdk_shared_core::{
22 CoreVM, Error as CoreError, Header, HeaderMap, IdentityVerifier, ResponseHead, VM, VerifyError,
23};
24use std::collections::HashMap;
25use std::convert::Infallible;
26use std::future::poll_fn;
27use std::ops::Deref;
28use std::pin::Pin;
29use std::sync::Arc;
30use std::task::{Context, Poll, ready};
31use tokio::sync::mpsc;
32use tracing::{Instrument, info_span, warn};
33
/// Header name added to every response built in this module (simple, error and invocation responses).
// The allow silences clippy's `declare_interior_mutable_const` lint for this `HeaderName` const.
#[allow(clippy::declare_interior_mutable_const)]
const X_RESTATE_SERVER: HeaderName = HeaderName::from_static("x-restate-server");
/// Value of the `x-restate-server` header: `restate-sdk-rust/<crate version>`.
const X_RESTATE_SERVER_VALUE: HeaderValue =
    HeaderValue::from_static(concat!("restate-sdk-rust/", env!("CARGO_PKG_VERSION")));
/// Content type of the v2 discovery manifest; also the default when no `accept` header is sent.
const DISCOVERY_CONTENT_TYPE_V2: &str = "application/vnd.restate.endpointmanifest.v2+json";
/// Content type of the v3 discovery manifest.
const DISCOVERY_CONTENT_TYPE_V3: &str = "application/vnd.restate.endpointmanifest.v3+json";
/// Content type of the v4 discovery manifest.
const DISCOVERY_CONTENT_TYPE_V4: &str = "application/vnd.restate.endpointmanifest.v4+json";

/// Type-erased, thread-safe error used for body/stream and (de)serialization failures.
type BoxError = Box<dyn std::error::Error + Send + Sync + 'static>;
43
44#[derive(Debug, thiserror::Error)]
47#[error(transparent)]
48pub struct Error(#[from] ErrorInner);
49
50impl Error {
51 pub fn unknown_handler(service_name: &str, handler_name: &str) -> Self {
53 Self(ErrorInner::UnknownServiceHandler(
54 service_name.to_owned(),
55 handler_name.to_owned(),
56 ))
57 }
58
59 pub fn status_code(&self) -> u16 {
61 match &self.0 {
62 ErrorInner::VM(e) => e.code(),
63 ErrorInner::UnknownService(_) | ErrorInner::UnknownServiceHandler(_, _) => 404,
64 ErrorInner::Suspended
65 | ErrorInner::UnexpectedOutputClosed
66 | ErrorInner::UnexpectedValueVariantForSyscall { .. }
67 | ErrorInner::Deserialization { .. }
68 | ErrorInner::Serialization { .. }
69 | ErrorInner::HandlerResult { .. }
70 | ErrorInner::InputDrain(_) => 500,
71 ErrorInner::FieldRequiresMinimumVersion { .. } => 500,
72 ErrorInner::BadDiscoveryVersion(_) => 415,
73 ErrorInner::Header { .. } | ErrorInner::BadPath { .. } => 400,
74 ErrorInner::IdentityVerification(_) => 401,
75 }
76 }
77}
78
79#[derive(Debug, thiserror::Error)]
80pub(crate) enum ErrorInner {
81 #[error("Received a request for unknown service '{0}'")]
82 UnknownService(String),
83 #[error("Received a request for unknown service handler '{0}/{1}'")]
84 UnknownServiceHandler(String, String),
85 #[error("Error when processing the request: {0:?}")]
86 VM(#[from] CoreError),
87 #[error("Error when verifying identity: {0:?}")]
88 IdentityVerification(#[from] VerifyError),
89 #[error("Cannot convert header '{0}', reason: {1}")]
90 Header(String, #[source] BoxError),
91 #[error(
92 "Cannot reply to discovery, got accept header '{0}' but currently supported discovery versions are v2 and v3"
93 )]
94 BadDiscoveryVersion(String),
95 #[error(
96 "The field '{0}' was set in the service/handler options, but it requires minimum discovery protocol version {1}"
97 )]
98 FieldRequiresMinimumVersion(&'static str, u32),
99 #[error("Bad path '{0}', expected either '/discover' or '/invoke/service/handler'")]
100 BadPath(String),
101 #[error("Suspended")]
102 Suspended,
103 #[error("Unexpected output closed")]
104 UnexpectedOutputClosed,
105 #[error("Unexpected value variant {variant} for syscall '{syscall}'")]
106 UnexpectedValueVariantForSyscall {
107 variant: &'static str,
108 syscall: &'static str,
109 },
110 #[error("Failed to deserialize with '{syscall}': {err:?}'")]
111 Deserialization {
112 syscall: &'static str,
113 #[source]
114 err: BoxError,
115 },
116 #[error("Failed to serialize with '{syscall}': {err:?}'")]
117 Serialization {
118 syscall: &'static str,
119 #[source]
120 err: BoxError,
121 },
122 #[error("Handler failed with retryable error: {err:?}'")]
123 HandlerResult {
124 #[source]
125 err: BoxError,
126 },
127 #[error("Error while draining the input stream: {0}")]
128 InputDrain(BoxError),
129}
130
131impl From<CoreError> for Error {
132 fn from(e: CoreError) -> Self {
133 if e.is_suspended_error() {
134 return ErrorInner::Suspended.into();
135 }
136 ErrorInner::from(e).into()
137 }
138}
139
140struct BoxedService(
141 Box<dyn Service<Future = BoxFuture<'static, Result<(), Error>>> + Send + Sync + 'static>,
142);
143
144impl BoxedService {
145 pub fn new<
146 S: Service<Future = BoxFuture<'static, Result<(), Error>>> + Send + Sync + 'static,
147 >(
148 service: S,
149 ) -> Self {
150 Self(Box::new(service))
151 }
152}
153
154impl Service for BoxedService {
155 type Future = BoxFuture<'static, Result<(), Error>>;
156
157 fn handle(&self, req: ContextInternal) -> Self::Future {
158 self.0.handle(req)
159 }
160}
161
/// HTTP entry point for a set of Restate services.
///
/// Cloning is cheap: clones share the same inner state through an [`Arc`].
#[derive(Clone)]
pub struct Endpoint(Arc<EndpointInner>);

impl Endpoint {
    /// Starts a new [`Builder`] (equivalent to `Builder::new()`).
    pub fn builder() -> Builder {
        Builder::new()
    }
}
175
/// Shared state behind an [`Endpoint`].
struct EndpointInner {
    /// Registered services, keyed by the `{service}` segment of `/invoke/{service}/{handler}`.
    svcs: HashMap<String, BoxedService>,
    /// Service descriptions serialized into the discovery manifest.
    discovery_services: Vec<crate::discovery::Service>,
    /// Checked on every request before any routing happens.
    identity_verifier: IdentityVerifier,
}
181
/// Protocol mode of the endpoint.
///
/// In this file it is only used to fill the `protocol_mode` field of the discovery manifest.
#[derive(Default)]
pub enum ProtocolMode {
    /// Advertised as `RequestResponse` in the discovery manifest.
    #[allow(dead_code)]
    RequestResponse,
    /// Default; advertised as `BidiStream` in the discovery manifest.
    #[default]
    BidiStream,
}

/// Per-request options for [`Endpoint::handle_with_options`].
#[derive(Default)]
pub struct HandleOptions {
    /// Protocol mode reported in discovery responses.
    pub protocol_mode: ProtocolMode,
}
195
196impl Endpoint {
197 pub fn handle<B: Body<Data = Bytes, Error: Into<BoxError> + Send> + Send + 'static>(
199 &self,
200 req: http::Request<B>,
201 ) -> http::Response<ResponseBody> {
202 self.handle_with_options(req, HandleOptions::default())
203 }
204
205 pub fn handle_with_options<
207 B: Body<Data = Bytes, Error: Into<BoxError> + Send> + Send + 'static,
208 >(
209 &self,
210 req: http::Request<B>,
211 options: HandleOptions,
212 ) -> http::Response<ResponseBody> {
213 let (parts, body) = req.into_parts();
214 let path = parts.uri.path();
215 let headers = parts.headers;
216
217 if let Err(e) = self.0.identity_verifier.verify_identity(&headers, path) {
218 return error_response(ErrorInner::IdentityVerification(e));
219 }
220
221 let parts: Vec<&str> = path.split('/').collect();
222
223 if parts.last() == Some(&"health") {
224 return self.handle_health();
225 }
226 if parts.last() == Some(&"discover") {
227 return self.handle_discovery(headers, options.protocol_mode);
228 }
229
230 let (svc_name, handler_name) = match parts.get(parts.len() - 3..) {
232 None => return error_response(ErrorInner::BadPath(path.to_owned())),
233 Some(last_elements) if last_elements[0] != "invoke" => {
234 return error_response(ErrorInner::BadPath(path.to_owned()));
235 }
236 Some(last_elements) => (last_elements[1].to_owned(), last_elements[2].to_owned()),
237 };
238
239 let vm = match CoreVM::new(headers, Default::default()) {
241 Ok(vm) => vm,
242 Err(e) => return error_response(e),
243 };
244 let ResponseHead {
245 status_code,
246 headers,
247 ..
248 } = vm.get_response_head();
249
250 if !self.0.svcs.contains_key(&svc_name) {
252 return error_response(ErrorInner::UnknownService(svc_name.to_owned()));
253 }
254
255 let input_receiver =
257 InputReceiver::from_stream(body.into_data_stream().map_err(|e| e.into()));
258 let (output_tx, output_rx) = mpsc::unbounded_channel();
259 let output_sender = OutputSender::from_channel(output_tx);
260 let handle_invocation_fut = Box::pin(handle_invocation(
261 svc_name,
262 handler_name,
263 vm,
264 Arc::clone(&self.0),
265 input_receiver,
266 output_sender,
267 ));
268
269 let mut invocation_response_builder = http::Response::builder()
272 .status(status_code)
273 .header(X_RESTATE_SERVER, X_RESTATE_SERVER_VALUE);
274 for Header { key, value } in headers {
275 invocation_response_builder =
276 invocation_response_builder.header(key.deref(), value.deref());
277 }
278 invocation_response_builder
279 .body(
280 Either::Right(InvocationRunnerBody {
281 fut: Some(handle_invocation_fut),
282 output_rx,
283 end_stream: false,
284 })
285 .into(),
286 )
287 .expect("Headers should be valid")
288 }
289
290 fn handle_health(&self) -> http::Response<ResponseBody> {
291 simple_response(200, vec![], Bytes::default())
292 }
293
294 fn handle_discovery(
295 &self,
296 headers: http::HeaderMap,
297 protocol_mode: ProtocolMode,
298 ) -> http::Response<ResponseBody> {
299 let accept_header = match headers
301 .extract("accept")
302 .map_err(|e| ErrorInner::Header("accept".to_owned(), Box::new(e)))
303 {
304 Ok(h) => h,
305 Err(e) => return error_response(e),
306 };
307
308 let mut version = 2;
310 let mut content_type = DISCOVERY_CONTENT_TYPE_V2;
311 if let Some(accept) = accept_header {
312 if accept.contains(DISCOVERY_CONTENT_TYPE_V4) {
313 version = 4;
314 content_type = DISCOVERY_CONTENT_TYPE_V4;
315 } else if accept.contains(DISCOVERY_CONTENT_TYPE_V3) {
316 version = 3;
317 content_type = DISCOVERY_CONTENT_TYPE_V3;
318 } else if accept.contains(DISCOVERY_CONTENT_TYPE_V2) {
319 version = 2;
320 content_type = DISCOVERY_CONTENT_TYPE_V2;
321 } else {
322 return error_response(ErrorInner::BadDiscoveryVersion(accept.to_owned()));
323 }
324 }
325
326 if let Err(e) = self.validate_discovery_request(version) {
327 return error_response(e);
328 }
329
330 simple_response(
331 200,
332 vec![Header {
333 key: "content-type".into(),
334 value: content_type.into(),
335 }],
336 Bytes::from(
337 serde_json::to_string(&crate::discovery::Endpoint {
338 lambda_compression: None,
339 max_protocol_version: std::num::NonZero::new(5).unwrap(),
340 min_protocol_version: std::num::NonZero::new(5).unwrap(),
341 protocol_mode: Some(match protocol_mode {
342 ProtocolMode::RequestResponse => {
343 crate::discovery::ProtocolMode::RequestResponse
344 }
345 ProtocolMode::BidiStream => crate::discovery::ProtocolMode::BidiStream,
346 }),
347 services: self.0.discovery_services.clone(),
348 })
349 .expect("Discovery should be serializable"),
350 ),
351 )
352 }
353
354 fn validate_discovery_request(&self, version: usize) -> Result<(), ErrorInner> {
355 if version <= 3 {
357 for service in &self.0.discovery_services {
359 if service.retry_policy_initial_interval.is_some()
360 || service.retry_policy_exponentiation_factor.is_some()
361 || service.retry_policy_max_interval.is_some()
362 || service.retry_policy_max_attempts.is_some()
363 || service.retry_policy_on_max_attempts.is_some()
364 {
365 Err(ErrorInner::FieldRequiresMinimumVersion("retry_policy", 4))?;
366 }
367
368 for handler in &service.handlers {
369 if handler.retry_policy_initial_interval.is_some()
370 || handler.retry_policy_exponentiation_factor.is_some()
371 || handler.retry_policy_max_interval.is_some()
372 || handler.retry_policy_max_attempts.is_some()
373 || handler.retry_policy_on_max_attempts.is_some()
374 {
375 Err(ErrorInner::FieldRequiresMinimumVersion("retry_policy", 4))?;
376 }
377 }
378 }
379 }
380 if version <= 2 {
381 for service in &self.0.discovery_services {
383 if service.inactivity_timeout.is_some() {
384 Err(ErrorInner::FieldRequiresMinimumVersion(
385 "inactivity_timeout",
386 3,
387 ))?;
388 }
389 if service.abort_timeout.is_some() {
390 Err(ErrorInner::FieldRequiresMinimumVersion("abort_timeout", 3))?;
391 }
392 if service.idempotency_retention.is_some() {
393 Err(ErrorInner::FieldRequiresMinimumVersion(
394 "idempotency_retention",
395 3,
396 ))?;
397 }
398 if service.journal_retention.is_some() {
399 Err(ErrorInner::FieldRequiresMinimumVersion(
400 "journal_retention",
401 3,
402 ))?;
403 }
404 if service.enable_lazy_state.is_some() {
405 Err(ErrorInner::FieldRequiresMinimumVersion(
406 "enable_lazy_state",
407 3,
408 ))?;
409 }
410 if service.ingress_private.is_some() {
411 Err(ErrorInner::FieldRequiresMinimumVersion(
412 "ingress_private",
413 3,
414 ))?;
415 }
416
417 for handler in &service.handlers {
418 if handler.inactivity_timeout.is_some() {
419 Err(ErrorInner::FieldRequiresMinimumVersion(
420 "inactivity_timeout",
421 3,
422 ))?;
423 }
424 if handler.abort_timeout.is_some() {
425 Err(ErrorInner::FieldRequiresMinimumVersion("abort_timeout", 3))?;
426 }
427 if handler.idempotency_retention.is_some() {
428 Err(ErrorInner::FieldRequiresMinimumVersion(
429 "idempotency_retention",
430 3,
431 ))?;
432 }
433 if handler.journal_retention.is_some() {
434 Err(ErrorInner::FieldRequiresMinimumVersion(
435 "journal_retention",
436 3,
437 ))?;
438 }
439 if handler.workflow_completion_retention.is_some() {
440 Err(ErrorInner::FieldRequiresMinimumVersion(
441 "workflow_retention",
442 3,
443 ))?;
444 }
445 if handler.enable_lazy_state.is_some() {
446 Err(ErrorInner::FieldRequiresMinimumVersion(
447 "enable_lazy_state",
448 3,
449 ))?;
450 }
451 if handler.ingress_private.is_some() {
452 Err(ErrorInner::FieldRequiresMinimumVersion(
453 "ingress_private",
454 3,
455 ))?;
456 }
457 }
458 }
459 }
460 Ok(())
461 }
462}
463
464type ResponseBodyInner = Either<Full<Bytes>, InvocationRunnerBody>;
465pin_project! {
466 pub struct ResponseBody {
467 #[pin]
468 inner: ResponseBodyInner
469 }
470}
471
472impl From<ResponseBodyInner> for ResponseBody {
473 fn from(e: ResponseBodyInner) -> Self {
474 ResponseBody { inner: e }
475 }
476}
477
478impl Body for ResponseBody {
479 type Data = <ResponseBodyInner as Body>::Data;
480 type Error = <ResponseBodyInner as Body>::Error;
481
482 fn poll_frame(
483 self: Pin<&mut Self>,
484 cx: &mut Context<'_>,
485 ) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
486 self.project().inner.poll_frame(cx)
487 }
488
489 fn is_end_stream(&self) -> bool {
490 self.inner.is_end_stream()
491 }
492
493 fn size_hint(&self) -> SizeHint {
494 self.inner.size_hint()
495 }
496}
497
498fn simple_response(
499 status_code: u16,
500 headers: Vec<Header>,
501 body: Bytes,
502) -> http::Response<ResponseBody> {
503 let mut response_builder = http::Response::builder()
504 .status(status_code)
505 .header(X_RESTATE_SERVER, X_RESTATE_SERVER_VALUE);
506
507 for header in headers {
508 response_builder = response_builder.header(header.key.deref(), header.value.deref());
509 }
510
511 response_builder
512 .body(Either::Left(Full::new(body)).into())
513 .expect("headers must be valid")
514}
515
516fn error_response(e: impl Into<Error>) -> http::Response<ResponseBody> {
517 let error = e.into();
518 http::Response::builder()
519 .status(error.status_code())
520 .header(X_RESTATE_SERVER, X_RESTATE_SERVER_VALUE)
521 .header(CONTENT_TYPE, "text/plain")
522 .body(Either::Left(Full::new(error.to_string().into())).into())
523 .expect("headers must be valid")
524}
525
526struct OutputSender(mpsc::UnboundedSender<Bytes>);
529
530impl OutputSender {
531 fn from_channel(tx: mpsc::UnboundedSender<Bytes>) -> Self {
532 Self(tx)
533 }
534
535 fn send(&self, b: Bytes) -> bool {
536 self.0.send(b).is_ok()
537 }
538}
539
540struct InputReceiver(Pin<Box<dyn Stream<Item = Result<Bytes, BoxError>> + Send + 'static>>);
541
542impl InputReceiver {
543 fn from_stream<S: Stream<Item = Result<Bytes, BoxError>> + Send + 'static>(s: S) -> Self {
544 Self(Box::pin(s))
545 }
546
547 async fn recv(&mut self) -> Option<Result<Bytes, BoxError>> {
548 poll_fn(|cx| self.poll_recv(cx)).await
549 }
550
551 fn poll_recv(&mut self, cx: &mut Context<'_>) -> Poll<Option<Result<Bytes, BoxError>>> {
552 self.0.poll_next_unpin(cx)
553 }
554}
555
/// Runs one invocation of `svc_name/handler_name` inside a `restate_sdk_endpoint_handle` span.
///
/// Steps: feed request input into `vm` until it is ready to execute, build the
/// [`ContextInternal`], run the service handler wrapped in the error-intercepting and
/// handler-state-aware futures, then drain the remaining input.
///
/// Returns the handler's result. A failure from `init_loop_vm` or from draining the
/// input is returned instead; a drain error takes precedence over the handler's result.
async fn handle_invocation(
    svc_name: String,
    handler_name: String,
    mut vm: CoreVM,
    endpoint: Arc<EndpointInner>,
    mut input_rx: InputReceiver,
    output_tx: OutputSender,
) -> Result<(), Error> {
    // `handle_with_options` checks `svcs.contains_key(&svc_name)` before spawning this.
    let svc = endpoint
        .svcs
        .get(&svc_name)
        .expect("service must exist at this point");

    let span = info_span!(
        "restate_sdk_endpoint_handle",
        "rpc.system" = "restate",
        "rpc.service" = svc_name,
        "rpc.method" = handler_name,
        "restate.sdk.is_replaying" = false
    );
    async move {
        init_loop_vm(&mut vm, &mut input_rx).await?;

        let (handler_state_tx, handler_state_rx) = HandlerStateNotifier::new();
        let ctx = ContextInternal::new(
            vm,
            svc_name,
            handler_name,
            input_rx,
            output_tx,
            handler_state_tx,
        );

        let user_code_fut = InterceptErrorFuture::new(ctx.clone(), svc.handle(ctx.clone()));

        let result =
            HandlerStateAwareFuture::new(ctx.clone(), handler_state_rx, user_code_fut).await;

        // Drain input even after the handler finished; `?` means a drain error wins over `result`.
        ctx.drain_input().await?;

        result
    }
    .instrument(span)
    .await
}
608
609async fn init_loop_vm(vm: &mut CoreVM, input_rx: &mut InputReceiver) -> Result<(), ErrorInner> {
610 while !vm.is_ready_to_execute().map_err(ErrorInner::VM)? {
611 match input_rx.recv().await {
612 Some(Ok(b)) => vm.notify_input(b),
613 Some(Err(e)) => vm.notify_error(
614 CoreError::new(500u16, format!("Error when reading the body: {e}")),
615 None,
616 ),
617 None => vm.notify_input_closed(),
618 }
619 }
620 Ok(())
621}
622
623pub struct InvocationRunnerBody {
626 fut: Option<BoxFuture<'static, Result<(), Error>>>,
627 output_rx: mpsc::UnboundedReceiver<Bytes>,
628 end_stream: bool,
629}
630
631impl Body for InvocationRunnerBody {
632 type Data = Bytes;
633 type Error = Infallible;
634
635 fn poll_frame(
636 mut self: Pin<&mut Self>,
637 cx: &mut Context<'_>,
638 ) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
639 if let Some(mut fut) = self.fut.take() {
641 match fut.poll_unpin(cx) {
642 Poll::Ready(res) => {
643 if let Err(e) = res {
644 warn!("Handler failure: {e:?}")
645 }
646 self.output_rx.close();
647 }
648 Poll::Pending => {
649 self.fut = Some(fut);
650 }
651 }
652 }
653
654 if let Some(out) = ready!(self.output_rx.poll_recv(cx)) {
655 Poll::Ready(Some(Ok(Frame::data(out))))
656 } else {
657 self.end_stream = true;
658 Poll::Ready(None)
659 }
660 }
661
662 fn is_end_stream(&self) -> bool {
663 self.end_stream
664 }
665}