Skip to main content

restate_sdk/endpoint/
context.rs

1use crate::context::{
2    CallFuture, DurableFuture, InvocationHandle, Request, RequestTarget, RunClosure, RunFuture,
3    RunRetryPolicy,
4};
5use crate::endpoint::futures::async_result_poll::VmAsyncResultPollFuture;
6use crate::endpoint::futures::durable_future_impl::DurableFutureImpl;
7use crate::endpoint::futures::intercept_error::InterceptErrorFuture;
8use crate::endpoint::futures::select_poll::VmSelectAsyncResultPollFuture;
9use crate::endpoint::futures::trap::TrapFuture;
10use crate::endpoint::handler_state::HandlerStateNotifier;
11use crate::endpoint::{Error, ErrorInner, InputReceiver, OutputSender};
12use crate::errors::{HandlerErrorInner, HandlerResult, TerminalError};
13use crate::serde::{Deserialize, Serialize};
14use futures::future::{BoxFuture, Either, Shared};
15use futures::{FutureExt, TryFutureExt};
16use pin_project_lite::pin_project;
17use restate_sdk_shared_core::{
18    CoreVM, DoProgressResponse, Error as CoreError, Header, NonEmptyValue, NotificationHandle,
19    PayloadOptions, RetryPolicy, RunExitResult, TakeOutputResult, Target, TerminalFailure, VM,
20    Value,
21};
22use std::borrow::Cow;
23use std::collections::HashMap;
24use std::future::{Future, poll_fn, ready};
25use std::marker::PhantomData;
26use std::mem;
27use std::pin::Pin;
28use std::sync::{Arc, Mutex};
29use std::task::{Context, Poll, ready};
30use std::time::{Duration, Instant, SystemTime};
31
32pub struct ContextInternalInner {
33    pub(crate) vm: CoreVM,
34    pub(crate) read: InputReceiver,
35    pub(crate) write: OutputSender,
36    pub(super) handler_state: HandlerStateNotifier,
37
38    /// We remember here the state of the span replaying field state, because setting it might be expensive (it's guarded behind locks and other stuff).
39    /// For details, see [ContextInternalInner::maybe_flip_span_replaying_field]
40    pub(super) span_replaying_field_state: bool,
41}
42
43impl ContextInternalInner {
44    fn new(
45        vm: CoreVM,
46        read: InputReceiver,
47        write: OutputSender,
48        handler_state: HandlerStateNotifier,
49    ) -> Self {
50        Self {
51            vm,
52            read,
53            write,
54            handler_state,
55            span_replaying_field_state: false,
56        }
57    }
58
59    pub(super) fn fail(&mut self, e: Error) {
60        self.maybe_flip_span_replaying_field();
61        self.vm.notify_error(
62            CoreError::new(500u16, e.0.to_string())
63                .with_stacktrace(Cow::<str>::Owned(format!("{:#}", e.0))),
64            None,
65        );
66        self.handler_state.mark_error(e);
67    }
68
69    pub(super) fn maybe_flip_span_replaying_field(&mut self) {
70        if !self.span_replaying_field_state && self.vm.is_replaying() {
71            tracing::Span::current().record("restate.sdk.is_replaying", true);
72            self.span_replaying_field_state = true;
73        } else if self.span_replaying_field_state && !self.vm.is_replaying() {
74            tracing::Span::current().record("restate.sdk.is_replaying", false);
75            self.span_replaying_field_state = false;
76        }
77    }
78}
79
80#[allow(unused)]
81const fn is_send_sync<T: Send + Sync>() {}
82const _: () = is_send_sync::<ContextInternal>();
83
/// Locks a `std::sync::Mutex` without blocking and yields the guard.
///
/// # Panics
///
/// Panics with a user-facing explanation when `try_lock` fails. That happens when the
/// lock is already held (for example, two context futures awaited concurrently) or when
/// it is poisoned.
macro_rules! must_lock {
    ($shared:expr) => {
        $shared
            .try_lock()
            .expect("You're trying to await two futures at the same time and/or trying to perform some operation on the restate context while awaiting a future. This is not supported!")
    };
}
89
/// Yields the `Ok` value of `$res`.
///
/// On `Err`, the error is converted with `.into()` and passed to `$inner_lock.fail(..)`.
/// The *enclosing function* then returns `Either::Right(TrapFuture::default())`, so the
/// macro is only usable where that value fits the return type.
macro_rules! unwrap_or_trap {
    ($lock:expr, $result:expr) => {
        match $result {
            Err(err) => {
                $lock.fail(err.into());
                return Either::Right(TrapFuture::default());
            }
            Ok(value) => value,
        }
    };
}
101
/// Like `unwrap_or_trap!`, but for functions returning a durable future.
///
/// On `Err`, the error is converted and passed to `$inner_lock.fail(..)`. The enclosing
/// function then returns a `DurableFutureImpl` built from a clone of `$ctx`, the
/// placeholder handle `u32::MAX` and a `TrapFuture`.
macro_rules! unwrap_or_trap_durable_future {
    ($ctx:expr, $lock:expr, $result:expr) => {
        match $result {
            Err(err) => {
                $lock.fail(err.into());
                let placeholder_handle = NotificationHandle::from(u32::MAX);
                return DurableFutureImpl::new(
                    $ctx.clone(),
                    placeholder_handle,
                    Either::Right(TrapFuture::default()),
                );
            }
            Ok(value) => value,
        }
    };
}
117
118#[derive(Debug, Eq, PartialEq)]
119pub struct InputMetadata {
120    pub invocation_id: String,
121    pub random_seed: u64,
122    pub key: String,
123    pub headers: http::HeaderMap<String>,
124}
125
126impl From<RequestTarget> for Target {
127    fn from(value: RequestTarget) -> Self {
128        match value {
129            RequestTarget::Service { name, handler } => Target {
130                service: name,
131                handler,
132                key: None,
133                idempotency_key: None,
134                headers: vec![],
135            },
136            RequestTarget::Object { name, key, handler } => Target {
137                service: name,
138                handler,
139                key: Some(key),
140                idempotency_key: None,
141                headers: vec![],
142            },
143            RequestTarget::Workflow { name, key, handler } => Target {
144                service: name,
145                handler,
146                key: Some(key),
147                idempotency_key: None,
148                headers: vec![],
149            },
150        }
151    }
152}
153
154/// Internal context interface.
155///
156/// For the high level interfaces, look at [`crate::context`].
157#[derive(Clone)]
158pub struct ContextInternal {
159    svc_name: String,
160    handler_name: String,
161    inner: Arc<Mutex<ContextInternalInner>>,
162}
163
164impl ContextInternal {
165    pub(super) fn new(
166        vm: CoreVM,
167        svc_name: String,
168        handler_name: String,
169        read: InputReceiver,
170        write: OutputSender,
171        handler_state: HandlerStateNotifier,
172    ) -> Self {
173        Self {
174            svc_name,
175            handler_name,
176            inner: Arc::new(Mutex::new(ContextInternalInner::new(
177                vm,
178                read,
179                write,
180                handler_state,
181            ))),
182        }
183    }
184
185    pub fn service_name(&self) -> &str {
186        &self.svc_name
187    }
188
189    pub fn handler_name(&self) -> &str {
190        &self.handler_name
191    }
192
193    pub fn input<T: Deserialize>(&self) -> impl Future<Output = (T, InputMetadata)> {
194        let mut inner_lock = must_lock!(self.inner);
195        let input_result =
196            inner_lock
197                .vm
198                .sys_input()
199                .map_err(ErrorInner::VM)
200                .map(|mut raw_input| {
201                    let headers = http::HeaderMap::<String>::try_from(
202                        &raw_input
203                            .headers
204                            .into_iter()
205                            .map(|h| (h.key.to_string(), h.value.to_string()))
206                            .collect::<HashMap<String, String>>(),
207                    )
208                    .map_err(|e| {
209                        TerminalError::new_with_code(400, format!("Cannot decode headers: {e:?}"))
210                    })?;
211
212                    Ok::<_, TerminalError>((
213                        T::deserialize(&mut (raw_input.input)).map_err(|e| {
214                            TerminalError::new_with_code(
215                                400,
216                                format!("Cannot decode input payload: {e:?}"),
217                            )
218                        })?,
219                        InputMetadata {
220                            invocation_id: raw_input.invocation_id,
221                            random_seed: raw_input.random_seed,
222                            key: raw_input.key,
223                            headers,
224                        },
225                    ))
226                });
227        inner_lock.maybe_flip_span_replaying_field();
228
229        match input_result {
230            Ok(Ok(i)) => {
231                drop(inner_lock);
232                return Either::Left(ready(i));
233            }
234            Ok(Err(err)) => {
235                let error_inner = ErrorInner::Deserialization {
236                    syscall: "input",
237                    err: err.0.clone().into(),
238                };
239                let _ = inner_lock
240                    .vm
241                    .sys_write_output(NonEmptyValue::Failure(err.into()), PayloadOptions::stable());
242                let _ = inner_lock.vm.sys_end();
243                // This causes the trap, plus logs the error
244                inner_lock.handler_state.mark_error(error_inner.into());
245                drop(inner_lock);
246            }
247            Err(e) => {
248                inner_lock.fail(e.into());
249                drop(inner_lock);
250            }
251        }
252        Either::Right(TrapFuture::default())
253    }
254
255    pub fn get<T: Deserialize>(
256        &self,
257        key: &str,
258    ) -> impl Future<Output = Result<Option<T>, TerminalError>> + Send {
259        let mut inner_lock = must_lock!(self.inner);
260        let handle = unwrap_or_trap!(
261            inner_lock,
262            inner_lock
263                .vm
264                .sys_state_get(key.to_owned(), PayloadOptions::stable())
265        );
266        inner_lock.maybe_flip_span_replaying_field();
267
268        let poll_future = get_async_result(Arc::clone(&self.inner), handle).map(|res| match res {
269            Ok(Value::Void) => Ok(Ok(None)),
270            Ok(Value::Success(mut s)) => {
271                let t =
272                    T::deserialize(&mut s).map_err(|e| Error::deserialization("get_state", e))?;
273                Ok(Ok(Some(t)))
274            }
275            Ok(Value::Failure(f)) => Ok(Err(f.into())),
276            Ok(v) => Err(ErrorInner::UnexpectedValueVariantForSyscall {
277                variant: <&'static str>::from(v),
278                syscall: "get_state",
279            }
280            .into()),
281            Err(e) => Err(e),
282        });
283
284        Either::Left(InterceptErrorFuture::new(self.clone(), poll_future))
285    }
286
287    pub fn get_keys(&self) -> impl Future<Output = Result<Vec<String>, TerminalError>> + Send {
288        let mut inner_lock = must_lock!(self.inner);
289        let handle = unwrap_or_trap!(inner_lock, inner_lock.vm.sys_state_get_keys());
290        inner_lock.maybe_flip_span_replaying_field();
291
292        let poll_future = get_async_result(Arc::clone(&self.inner), handle).map(|res| match res {
293            Ok(Value::Failure(f)) => Ok(Err(f.into())),
294            Ok(Value::StateKeys(s)) => Ok(Ok(s)),
295            Ok(v) => Err(ErrorInner::UnexpectedValueVariantForSyscall {
296                variant: <&'static str>::from(v),
297                syscall: "get_keys",
298            }
299            .into()),
300            Err(e) => Err(e),
301        });
302
303        Either::Left(InterceptErrorFuture::new(self.clone(), poll_future))
304    }
305
306    pub fn set<T: Serialize>(&self, key: &str, t: T) {
307        let mut inner_lock = must_lock!(self.inner);
308        match t.serialize() {
309            Ok(b) => {
310                let _ = inner_lock
311                    .vm
312                    .sys_state_set(key.to_owned(), b, PayloadOptions::stable());
313                inner_lock.maybe_flip_span_replaying_field();
314            }
315            Err(e) => {
316                inner_lock.fail(Error::serialization("set_state", e));
317            }
318        }
319    }
320
321    pub fn clear(&self, key: &str) {
322        let mut inner_lock = must_lock!(self.inner);
323        let _ = inner_lock.vm.sys_state_clear(key.to_string());
324        inner_lock.maybe_flip_span_replaying_field();
325    }
326
327    pub fn clear_all(&self) {
328        let mut inner_lock = must_lock!(self.inner);
329        let _ = inner_lock.vm.sys_state_clear_all();
330        inner_lock.maybe_flip_span_replaying_field();
331    }
332
333    pub fn select(
334        &self,
335        handles: Vec<NotificationHandle>,
336    ) -> impl Future<Output = Result<usize, TerminalError>> + Send {
337        InterceptErrorFuture::new(
338            self.clone(),
339            VmSelectAsyncResultPollFuture::new(self.inner.clone(), handles).map_err(Error::from),
340        )
341    }
342
343    pub fn sleep(
344        &self,
345        sleep_duration: Duration,
346    ) -> impl DurableFuture<Output = Result<(), TerminalError>> + Send {
347        let now = SystemTime::now()
348            .duration_since(SystemTime::UNIX_EPOCH)
349            .expect("Duration since unix epoch cannot fail");
350        let mut inner_lock = must_lock!(self.inner);
351        let handle = unwrap_or_trap_durable_future!(
352            self,
353            inner_lock,
354            inner_lock
355                .vm
356                .sys_sleep(String::default(), now + sleep_duration, Some(now))
357        );
358        inner_lock.maybe_flip_span_replaying_field();
359
360        let poll_future = get_async_result(Arc::clone(&self.inner), handle).map(|res| match res {
361            Ok(Value::Void) => Ok(Ok(())),
362            Ok(Value::Failure(f)) => Ok(Err(f.into())),
363            Ok(v) => Err(ErrorInner::UnexpectedValueVariantForSyscall {
364                variant: <&'static str>::from(v),
365                syscall: "sleep",
366            }
367            .into()),
368            Err(e) => Err(e),
369        });
370
371        DurableFutureImpl::new(self.clone(), handle, Either::Left(poll_future))
372    }
373
374    pub fn request<Req, Res>(
375        &self,
376        request_target: RequestTarget,
377        req: Req,
378    ) -> Request<'_, Req, Res> {
379        Request::new(self, request_target, req)
380    }
381
382    pub fn call<Req: Serialize, Res: Deserialize>(
383        &self,
384        request_target: RequestTarget,
385        idempotency_key: Option<String>,
386        headers: Vec<(String, String)>,
387        req: Req,
388    ) -> impl CallFuture<Response = Res> + Send {
389        let mut inner_lock = must_lock!(self.inner);
390
391        let mut target: Target = request_target.into();
392        target.idempotency_key = idempotency_key;
393        target.headers = headers
394            .into_iter()
395            .map(|(k, v)| Header {
396                key: k.into(),
397                value: v.into(),
398            })
399            .collect();
400        let call_result = Req::serialize(&req)
401            .map_err(|e| Error::serialization("call", e))
402            .and_then(|input| {
403                inner_lock
404                    .vm
405                    .sys_call(target, input, None, PayloadOptions::stable())
406                    .map_err(Into::into)
407            });
408
409        let call_handle = match call_result {
410            Ok(t) => t,
411            Err(e) => {
412                inner_lock.fail(e);
413                return CallFutureImpl {
414                    invocation_id_future: Either::Right(TrapFuture::default()).shared(),
415                    result_future: Either::Right(TrapFuture::default()),
416                    call_notification_handle: NotificationHandle::from(u32::MAX),
417                    ctx: self.clone(),
418                };
419            }
420        };
421        inner_lock.maybe_flip_span_replaying_field();
422        drop(inner_lock);
423
424        // Let's prepare the two futures here
425        let invocation_id_fut = InterceptErrorFuture::new(
426            self.clone(),
427            get_async_result(
428                Arc::clone(&self.inner),
429                call_handle.invocation_id_notification_handle,
430            )
431            .map(|res| match res {
432                Ok(Value::Failure(f)) => Ok(Err(f.into())),
433                Ok(Value::InvocationId(s)) => Ok(Ok(s)),
434                Ok(v) => Err(ErrorInner::UnexpectedValueVariantForSyscall {
435                    variant: <&'static str>::from(v),
436                    syscall: "call",
437                }
438                .into()),
439                Err(e) => Err(e),
440            }),
441        );
442        let result_future = get_async_result(
443            Arc::clone(&self.inner),
444            call_handle.call_notification_handle,
445        )
446        .map(|res| match res {
447            Ok(Value::Success(mut s)) => Ok(Ok(
448                Res::deserialize(&mut s).map_err(|e| Error::deserialization("call", e))?
449            )),
450            Ok(Value::Failure(f)) => Ok(Err(TerminalError::from(f))),
451            Ok(v) => Err(ErrorInner::UnexpectedValueVariantForSyscall {
452                variant: <&'static str>::from(v),
453                syscall: "call",
454            }
455            .into()),
456            Err(e) => Err(e),
457        });
458
459        CallFutureImpl {
460            invocation_id_future: Either::Left(invocation_id_fut).shared(),
461            result_future: Either::Left(result_future),
462            call_notification_handle: call_handle.call_notification_handle,
463            ctx: self.clone(),
464        }
465    }
466
467    pub fn send<Req: Serialize>(
468        &self,
469        request_target: RequestTarget,
470        idempotency_key: Option<String>,
471        headers: Vec<(String, String)>,
472        req: Req,
473        delay: Option<Duration>,
474    ) -> impl InvocationHandle {
475        let mut inner_lock = must_lock!(self.inner);
476
477        let mut target: Target = request_target.into();
478        target.idempotency_key = idempotency_key;
479        target.headers = headers
480            .into_iter()
481            .map(|(k, v)| Header {
482                key: k.into(),
483                value: v.into(),
484            })
485            .collect();
486        let input = match Req::serialize(&req) {
487            Ok(b) => b,
488            Err(e) => {
489                inner_lock.fail(Error::serialization("call", e));
490                return Either::Right(TrapFuture::<()>::default());
491            }
492        };
493
494        let send_handle = match inner_lock.vm.sys_send(
495            target,
496            input,
497            delay.map(|delay| {
498                SystemTime::now()
499                    .duration_since(SystemTime::UNIX_EPOCH)
500                    .expect("Duration since unix epoch cannot fail")
501                    + delay
502            }),
503            None,
504            PayloadOptions::stable(),
505        ) {
506            Ok(h) => h,
507            Err(e) => {
508                inner_lock.fail(e.into());
509                return Either::Right(TrapFuture::<()>::default());
510            }
511        };
512        inner_lock.maybe_flip_span_replaying_field();
513        drop(inner_lock);
514
515        let invocation_id_fut = InterceptErrorFuture::new(
516            self.clone(),
517            get_async_result(
518                Arc::clone(&self.inner),
519                send_handle.invocation_id_notification_handle,
520            )
521            .map(|res| match res {
522                Ok(Value::Failure(f)) => Ok(Err(f.into())),
523                Ok(Value::InvocationId(s)) => Ok(Ok(s)),
524                Ok(v) => Err(ErrorInner::UnexpectedValueVariantForSyscall {
525                    variant: <&'static str>::from(v),
526                    syscall: "call",
527                }
528                .into()),
529                Err(e) => Err(e),
530            }),
531        );
532
533        Either::Left(SendRequestHandle {
534            invocation_id_future: invocation_id_fut.shared(),
535            ctx: self.clone(),
536        })
537    }
538
539    pub fn invocation_handle(&self, invocation_id: String) -> impl InvocationHandle {
540        InvocationIdBackedInvocationHandle {
541            ctx: self.clone(),
542            invocation_id,
543        }
544    }
545
546    pub fn awakeable<T: Deserialize>(
547        &self,
548    ) -> (
549        String,
550        impl DurableFuture<Output = Result<T, TerminalError>> + Send,
551    ) {
552        let mut inner_lock = must_lock!(self.inner);
553        let maybe_awakeable_id_and_handle = inner_lock.vm.sys_awakeable();
554        inner_lock.maybe_flip_span_replaying_field();
555
556        let (awakeable_id, handle) = match maybe_awakeable_id_and_handle {
557            Ok((s, handle)) => (s, handle),
558            Err(e) => {
559                inner_lock.fail(e.into());
560                return (
561                    // TODO NOW this is REALLY BAD. The reason for this is that we would need to return a future of a future instead, which is not nice.
562                    //  we assume for the time being this works because no user should use the awakeable without doing any other syscall first, which will prevent this invalid awakeable id to work in the first place.
563                    "invalid".to_owned(),
564                    DurableFutureImpl::new(
565                        self.clone(),
566                        NotificationHandle::from(u32::MAX),
567                        Either::Right(TrapFuture::default()),
568                    ),
569                );
570            }
571        };
572        drop(inner_lock);
573
574        let poll_future = get_async_result(Arc::clone(&self.inner), handle).map(|res| match res {
575            Ok(Value::Success(mut s)) => Ok(Ok(
576                T::deserialize(&mut s).map_err(|e| Error::deserialization("awakeable", e))?
577            )),
578            Ok(Value::Failure(f)) => Ok(Err(f.into())),
579            Ok(v) => Err(ErrorInner::UnexpectedValueVariantForSyscall {
580                variant: <&'static str>::from(v),
581                syscall: "awakeable",
582            }
583            .into()),
584            Err(e) => Err(e),
585        });
586
587        (
588            awakeable_id,
589            DurableFutureImpl::new(self.clone(), handle, Either::Left(poll_future)),
590        )
591    }
592
593    pub fn resolve_awakeable<T: Serialize>(&self, id: &str, t: T) {
594        let mut inner_lock = must_lock!(self.inner);
595        match t.serialize() {
596            Ok(b) => {
597                let _ = inner_lock.vm.sys_complete_awakeable(
598                    id.to_owned(),
599                    NonEmptyValue::Success(b),
600                    PayloadOptions::stable(),
601                );
602            }
603            Err(e) => {
604                inner_lock.fail(Error::serialization("resolve_awakeable", e));
605            }
606        }
607    }
608
609    pub fn reject_awakeable(&self, id: &str, failure: TerminalError) {
610        let _ = must_lock!(self.inner).vm.sys_complete_awakeable(
611            id.to_owned(),
612            NonEmptyValue::Failure(failure.into()),
613            PayloadOptions::stable(),
614        );
615    }
616
617    pub fn promise<T: Deserialize>(
618        &self,
619        name: &str,
620    ) -> impl DurableFuture<Output = Result<T, TerminalError>> + Send {
621        let mut inner_lock = must_lock!(self.inner);
622        let handle = unwrap_or_trap_durable_future!(
623            self,
624            inner_lock,
625            inner_lock.vm.sys_get_promise(name.to_owned())
626        );
627        inner_lock.maybe_flip_span_replaying_field();
628        drop(inner_lock);
629
630        let poll_future = get_async_result(Arc::clone(&self.inner), handle).map(|res| match res {
631            Ok(Value::Success(mut s)) => {
632                let t = T::deserialize(&mut s).map_err(|e| Error::deserialization("promise", e))?;
633                Ok(Ok(t))
634            }
635            Ok(Value::Failure(f)) => Ok(Err(f.into())),
636            Ok(v) => Err(ErrorInner::UnexpectedValueVariantForSyscall {
637                variant: <&'static str>::from(v),
638                syscall: "promise",
639            }
640            .into()),
641            Err(e) => Err(e),
642        });
643
644        DurableFutureImpl::new(self.clone(), handle, Either::Left(poll_future))
645    }
646
647    pub fn peek_promise<T: Deserialize>(
648        &self,
649        name: &str,
650    ) -> impl Future<Output = Result<Option<T>, TerminalError>> + Send {
651        let mut inner_lock = must_lock!(self.inner);
652        let handle = unwrap_or_trap!(inner_lock, inner_lock.vm.sys_peek_promise(name.to_owned()));
653        inner_lock.maybe_flip_span_replaying_field();
654        drop(inner_lock);
655
656        let poll_future = get_async_result(Arc::clone(&self.inner), handle).map(|res| match res {
657            Ok(Value::Void) => Ok(Ok(None)),
658            Ok(Value::Success(mut s)) => {
659                let t = T::deserialize(&mut s)
660                    .map_err(|e| Error::deserialization("peek_promise", e))?;
661                Ok(Ok(Some(t)))
662            }
663            Ok(Value::Failure(f)) => Ok(Err(f.into())),
664            Ok(v) => Err(ErrorInner::UnexpectedValueVariantForSyscall {
665                variant: <&'static str>::from(v),
666                syscall: "peek_promise",
667            }
668            .into()),
669            Err(e) => Err(e),
670        });
671
672        Either::Left(InterceptErrorFuture::new(self.clone(), poll_future))
673    }
674
675    pub fn resolve_promise<T: Serialize>(&self, name: &str, t: T) {
676        let mut inner_lock = must_lock!(self.inner);
677        match t.serialize() {
678            Ok(b) => {
679                let _ = inner_lock.vm.sys_complete_promise(
680                    name.to_owned(),
681                    NonEmptyValue::Success(b),
682                    PayloadOptions::stable(),
683                );
684            }
685            Err(e) => {
686                inner_lock.fail(
687                    ErrorInner::Serialization {
688                        syscall: "resolve_promise",
689                        err: Box::new(e),
690                    }
691                    .into(),
692                );
693            }
694        }
695    }
696
697    pub fn reject_promise(&self, id: &str, failure: TerminalError) {
698        let _ = must_lock!(self.inner).vm.sys_complete_promise(
699            id.to_owned(),
700            NonEmptyValue::Failure(failure.into()),
701            PayloadOptions::stable(),
702        );
703    }
704
705    pub fn run<'a, Run, Fut, Out>(
706        &'a self,
707        run_closure: Run,
708    ) -> impl RunFuture<Result<Out, TerminalError>> + Send + 'a
709    where
710        Run: RunClosure<Fut = Fut, Output = Out> + Send + 'a,
711        Fut: Future<Output = HandlerResult<Out>> + Send + 'a,
712        Out: Serialize + Deserialize + 'static,
713    {
714        let this = Arc::clone(&self.inner);
715        InterceptErrorFuture::new(self.clone(), RunFutureImpl::new(this, run_closure))
716    }
717
718    // Used by codegen
719    pub fn handle_handler_result<T: Serialize>(&self, res: HandlerResult<T>) {
720        let mut inner_lock = must_lock!(self.inner);
721
722        let res_to_write = match res {
723            Ok(success) => match T::serialize(&success) {
724                Ok(t) => NonEmptyValue::Success(t),
725                Err(e) => {
726                    inner_lock.fail(
727                        ErrorInner::Serialization {
728                            syscall: "output",
729                            err: Box::new(e),
730                        }
731                        .into(),
732                    );
733                    return;
734                }
735            },
736            Err(e) => match e.0 {
737                HandlerErrorInner::Retryable(err) => {
738                    inner_lock.fail(ErrorInner::HandlerResult { err }.into());
739                    return;
740                }
741                HandlerErrorInner::Terminal(t) => NonEmptyValue::Failure(TerminalError(t).into()),
742            },
743        };
744
745        let _ = inner_lock
746            .vm
747            .sys_write_output(res_to_write, PayloadOptions::stable());
748        inner_lock.maybe_flip_span_replaying_field();
749    }
750
751    pub fn end(&self) {
752        let _ = must_lock!(self.inner).vm.sys_end();
753    }
754
755    pub(crate) fn consume_to_end(&self) {
756        let mut inner_lock = must_lock!(self.inner);
757
758        let out = inner_lock.vm.take_output();
759        if let TakeOutputResult::Buffer(b) = out
760            && !inner_lock.write.send(b)
761        {
762            // Nothing we can do anymore here
763        }
764    }
765
766    /// Drain the request input stream to completion.
767    ///
768    /// This ensures we don't close the HTTP/2 response stream before the request
769    /// stream is done, which causes connection errors on proxies like Google Cloud Run.
770    pub(crate) async fn drain_input(&self) -> Result<(), ErrorInner> {
771        tokio::time::timeout(Duration::from_secs(60), async {
772            loop {
773                let result = poll_fn(|cx| {
774                    let mut inner = must_lock!(self.inner);
775                    inner.read.poll_recv(cx)
776                })
777                .await;
778                match result {
779                    None => return Ok(()),
780                    Some(Ok(_)) => continue,
781                    Some(Err(e)) => return Err(ErrorInner::InputDrain(e)),
782                }
783            }
784        })
785        .await
786        .unwrap_or_else(|_| {
787            Err(ErrorInner::InputDrain(
788                "Timed out draining input stream after 60s".into(),
789            ))
790        })
791    }
792
793    pub(super) fn fail(&self, e: Error) {
794        must_lock!(self.inner).fail(e)
795    }
796}
797
798pin_project! {
799    struct RunFutureImpl<Run, Ret, RunFnFut> {
800        name: String,
801        retry_policy: RetryPolicy,
802        phantom_data: PhantomData<fn() -> Ret>,
803        #[pin]
804        state: RunState<Run, RunFnFut, Ret>,
805    }
806}
807
808pin_project! {
809    #[project = RunStateProj]
810    enum RunState<Run, RunFnFut, Ret> {
811        New {
812            ctx: Option<Arc<Mutex<ContextInternalInner>>>,
813            closure: Option<Run>,
814        },
815        ClosureRunning {
816            ctx: Option<Arc<Mutex<ContextInternalInner>>>,
817            handle: NotificationHandle,
818            start_time: Instant,
819            #[pin]
820            closure_fut: RunFnFut,
821        },
822        WaitingResultFut {
823            result_fut: BoxFuture<'static, Result<Result<Ret, TerminalError>, Error>>
824        }
825    }
826}
827
828impl<Run, Ret, RunFnFut> RunFutureImpl<Run, Ret, RunFnFut> {
829    fn new(ctx: Arc<Mutex<ContextInternalInner>>, closure: Run) -> Self {
830        Self {
831            name: "".to_string(),
832            retry_policy: RetryPolicy::Infinite,
833            phantom_data: PhantomData,
834            state: RunState::New {
835                ctx: Some(ctx),
836                closure: Some(closure),
837            },
838        }
839    }
840
841    fn boxed_result_fut(
842        ctx: Arc<Mutex<ContextInternalInner>>,
843        handle: NotificationHandle,
844    ) -> BoxFuture<'static, Result<Result<Ret, TerminalError>, Error>>
845    where
846        Ret: Deserialize,
847    {
848        get_async_result(Arc::clone(&ctx), handle)
849            .map(|res| match res {
850                Ok(Value::Success(mut s)) => {
851                    let t =
852                        Ret::deserialize(&mut s).map_err(|e| Error::deserialization("run", e))?;
853                    Ok(Ok(t))
854                }
855                Ok(Value::Failure(f)) => Ok(Err(f.into())),
856                Ok(v) => Err(ErrorInner::UnexpectedValueVariantForSyscall {
857                    variant: <&'static str>::from(v),
858                    syscall: "run",
859                }
860                .into()),
861                Err(e) => Err(e),
862            })
863            .boxed()
864    }
865}
866
867impl<Run, Ret, RunFnFut> RunFuture<Result<Result<Ret, TerminalError>, Error>>
868    for RunFutureImpl<Run, Ret, RunFnFut>
869where
870    Run: RunClosure<Fut = RunFnFut, Output = Ret> + Send,
871    Ret: Serialize + Deserialize,
872    RunFnFut: Future<Output = HandlerResult<Ret>> + Send,
873{
874    fn retry_policy(mut self, retry_policy: RunRetryPolicy) -> Self {
875        self.retry_policy = RetryPolicy::Exponential {
876            initial_interval: retry_policy.initial_delay,
877            factor: retry_policy.factor,
878            max_interval: retry_policy.max_delay,
879            max_attempts: retry_policy.max_attempts,
880            max_duration: retry_policy.max_duration,
881        };
882        self
883    }
884
885    fn name(mut self, name: impl Into<String>) -> Self {
886        self.name = name.into();
887        self
888    }
889}
890
891impl<Run, Ret, RunFnFut> Future for RunFutureImpl<Run, Ret, RunFnFut>
892where
893    Run: RunClosure<Fut = RunFnFut, Output = Ret> + Send,
894    Ret: Serialize + Deserialize,
895    RunFnFut: Future<Output = HandlerResult<Ret>> + Send,
896{
897    type Output = Result<Result<Ret, TerminalError>, Error>;
898
899    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
900        let mut this = self.project();
901
902        loop {
903            match this.state.as_mut().project() {
904                RunStateProj::New { ctx, closure, .. } => {
905                    let ctx = ctx
906                        .take()
907                        .expect("Future should not be polled after returning Poll::Ready");
908                    let closure = closure
909                        .take()
910                        .expect("Future should not be polled after returning Poll::Ready");
911                    let mut inner_ctx = must_lock!(ctx);
912
913                    let handle = inner_ctx
914                        .vm
915                        .sys_run(this.name.to_owned())
916                        .map_err(ErrorInner::from)?;
917
918                    // Now we do progress once to check whether this closure should be executed or not.
919                    match inner_ctx.vm.do_progress(vec![handle]) {
920                        Ok(DoProgressResponse::ExecuteRun(handle_to_run)) => {
921                            // In case it returns ExecuteRun, it must be the handle we just gave it,
922                            // and it means we need to execute the closure
923                            assert_eq!(handle, handle_to_run);
924
925                            drop(inner_ctx);
926                            this.state.set(RunState::ClosureRunning {
927                                ctx: Some(ctx),
928                                handle,
929                                start_time: Instant::now(),
930                                closure_fut: closure.run(),
931                            });
932                        }
933                        Ok(DoProgressResponse::CancelSignalReceived) => {
934                            drop(inner_ctx);
935                            // Got cancellation!
936                            this.state.set(RunState::WaitingResultFut {
937                                result_fut: async {
938                                    Ok(Err(TerminalError::from(TerminalFailure {
939                                        code: 409,
940                                        message: "cancelled".to_string(),
941                                        metadata: vec![],
942                                    })))
943                                }
944                                .boxed(),
945                            })
946                        }
947                        _ => {
948                            drop(inner_ctx);
949                            // In all the other cases, just move on waiting the result,
950                            // the poll future state will take care of doing whatever needs to be done here,
951                            // that is propagating state machine error, or result, or whatever
952                            this.state.set(RunState::WaitingResultFut {
953                                result_fut: Self::boxed_result_fut(Arc::clone(&ctx), handle),
954                            })
955                        }
956                    }
957                }
958                RunStateProj::ClosureRunning {
959                    ctx,
960                    handle,
961                    start_time,
962                    closure_fut,
963                } => {
964                    let res = match ready!(closure_fut.poll(cx)) {
965                        Ok(t) => RunExitResult::Success(Ret::serialize(&t).map_err(|e| {
966                            ErrorInner::Serialization {
967                                syscall: "run",
968                                err: Box::new(e),
969                            }
970                        })?),
971                        Err(e) => match e.0 {
972                            HandlerErrorInner::Retryable(err) => RunExitResult::RetryableFailure {
973                                attempt_duration: start_time.elapsed(),
974                                error: CoreError::new(500u16, err.to_string()),
975                            },
976                            HandlerErrorInner::Terminal(t) => {
977                                RunExitResult::TerminalFailure(TerminalError(t).into())
978                            }
979                        },
980                    };
981
982                    let ctx = ctx
983                        .take()
984                        .expect("Future should not be polled after returning Poll::Ready");
985                    let handle = *handle;
986
987                    let _ = {
988                        must_lock!(ctx).vm.propose_run_completion(
989                            handle,
990                            res,
991                            mem::take(this.retry_policy),
992                        )
993                    };
994
995                    this.state.set(RunState::WaitingResultFut {
996                        result_fut: Self::boxed_result_fut(Arc::clone(&ctx), handle),
997                    });
998                }
999                RunStateProj::WaitingResultFut { result_fut } => return result_fut.poll_unpin(cx),
1000            }
1001        }
1002    }
1003}
1004
pin_project! {
    // Future for a call to another handler. It resolves with the callee's outcome; SDK-level
    // errors produced by `result_future` are routed to `ctx.fail` instead of being returned.
    struct CallFutureImpl<InvIdFut: Future, ResultFut> {
        // Shared so that `invocation_id()` and `cancel()` can each await their own clone.
        #[pin]
        invocation_id_future: Shared<InvIdFut>,
        // Produces the call result polled by `Future::poll`.
        #[pin]
        result_future: ResultFut,
        // Returned by `SealedDurableFuture::handle`.
        call_notification_handle: NotificationHandle,
        // Context used to report errors and to issue cancellations.
        ctx: ContextInternal,
    }
}
1015
1016impl<InvIdFut, ResultFut, Res> Future for CallFutureImpl<InvIdFut, ResultFut>
1017where
1018    InvIdFut: Future<Output = Result<String, TerminalError>> + Send,
1019    ResultFut: Future<Output = Result<Result<Res, TerminalError>, Error>> + Send,
1020{
1021    type Output = Result<Res, TerminalError>;
1022
1023    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
1024        let this = self.project();
1025        let result = ready!(this.result_future.poll(cx));
1026
1027        match result {
1028            Ok(r) => Poll::Ready(r),
1029            Err(e) => {
1030                this.ctx.fail(e);
1031
1032                // Here is the secret sauce. This will immediately cause the whole future chain to be polled,
1033                //  but the poll here will be intercepted by HandlerStateAwareFuture
1034                cx.waker().wake_by_ref();
1035                Poll::Pending
1036            }
1037        }
1038    }
1039}
1040
1041impl<InvIdFut, ResultFut> InvocationHandle for CallFutureImpl<InvIdFut, ResultFut>
1042where
1043    InvIdFut: Future<Output = Result<String, TerminalError>> + Send,
1044{
1045    fn invocation_id(&self) -> impl Future<Output = Result<String, TerminalError>> + Send {
1046        Shared::clone(&self.invocation_id_future)
1047    }
1048
1049    fn cancel(&self) -> impl Future<Output = Result<(), TerminalError>> + Send {
1050        let cloned_invocation_id_fut = Shared::clone(&self.invocation_id_future);
1051        let cloned_ctx = Arc::clone(&self.ctx.inner);
1052        async move {
1053            let inv_id = cloned_invocation_id_fut.await?;
1054            let mut inner_lock = must_lock!(cloned_ctx);
1055            let _ = inner_lock.vm.sys_cancel_invocation(inv_id);
1056            inner_lock.maybe_flip_span_replaying_field();
1057            drop(inner_lock);
1058            Ok(())
1059        }
1060    }
1061}
1062
// `CallFutureImpl` is a `CallFuture` whose response type is `Res`, the success type of the
// `Result<Res, TerminalError>` it resolves to.
impl<InvIdFut, ResultFut, Res> CallFuture for CallFutureImpl<InvIdFut, ResultFut>
where
    InvIdFut: Future<Output = Result<String, TerminalError>> + Send,
    ResultFut: Future<Output = Result<Result<Res, TerminalError>, Error>> + Send,
{
    type Response = Res;
}
1070
1071impl<InvIdFut, ResultFut> crate::context::macro_support::SealedDurableFuture
1072    for CallFutureImpl<InvIdFut, ResultFut>
1073where
1074    InvIdFut: Future,
1075{
1076    fn inner_context(&self) -> ContextInternal {
1077        self.ctx.clone()
1078    }
1079
1080    fn handle(&self) -> NotificationHandle {
1081        self.call_notification_handle
1082    }
1083}
1084
// Marker impl: with the `SealedDurableFuture` impl for the same type in place, this exposes
// `CallFutureImpl` as a `DurableFuture`.
impl<InvIdFut, ResultFut, Res> DurableFuture for CallFutureImpl<InvIdFut, ResultFut>
where
    InvIdFut: Future<Output = Result<String, TerminalError>> + Send,
    ResultFut: Future<Output = Result<Result<Res, TerminalError>, Error>> + Send,
{
}
1091
/// Invocation handle backed only by a shared future resolving to the target invocation id.
///
/// Unlike `CallFutureImpl`, it carries no result future.
struct SendRequestHandle<InvIdFut: Future> {
    // Shared so that every `invocation_id()` / `cancel()` call can await its own clone.
    invocation_id_future: Shared<InvIdFut>,
    // Context used to issue the cancellation syscall.
    ctx: ContextInternal,
}
1096
1097impl<InvIdFut: Future<Output = Result<String, TerminalError>> + Send> InvocationHandle
1098    for SendRequestHandle<InvIdFut>
1099{
1100    fn invocation_id(&self) -> impl Future<Output = Result<String, TerminalError>> + Send {
1101        Shared::clone(&self.invocation_id_future)
1102    }
1103
1104    fn cancel(&self) -> impl Future<Output = Result<(), TerminalError>> + Send {
1105        let cloned_invocation_id_fut = Shared::clone(&self.invocation_id_future);
1106        let cloned_ctx = Arc::clone(&self.ctx.inner);
1107        async move {
1108            let inv_id = cloned_invocation_id_fut.await?;
1109            let mut inner_lock = must_lock!(cloned_ctx);
1110            let _ = inner_lock.vm.sys_cancel_invocation(inv_id);
1111            inner_lock.maybe_flip_span_replaying_field();
1112            drop(inner_lock);
1113            Ok(())
1114        }
1115    }
1116}
1117
/// Invocation handle for an invocation whose id is already known.
struct InvocationIdBackedInvocationHandle {
    // Context used to issue the cancellation syscall.
    ctx: ContextInternal,
    // Id of the target invocation.
    invocation_id: String,
}
1122
1123impl InvocationHandle for InvocationIdBackedInvocationHandle {
1124    fn invocation_id(&self) -> impl Future<Output = Result<String, TerminalError>> + Send {
1125        ready(Ok(self.invocation_id.clone()))
1126    }
1127
1128    fn cancel(&self) -> impl Future<Output = Result<(), TerminalError>> + Send {
1129        let mut inner_lock = must_lock!(self.ctx.inner);
1130        let _ = inner_lock
1131            .vm
1132            .sys_cancel_invocation(self.invocation_id.clone());
1133        ready(Ok(()))
1134    }
1135}
1136
1137impl<A, B> InvocationHandle for Either<A, B>
1138where
1139    A: InvocationHandle,
1140    B: InvocationHandle,
1141{
1142    fn invocation_id(&self) -> impl Future<Output = Result<String, TerminalError>> + Send {
1143        match self {
1144            Either::Left(l) => Either::Left(l.invocation_id()),
1145            Either::Right(r) => Either::Right(r.invocation_id()),
1146        }
1147    }
1148
1149    fn cancel(&self) -> impl Future<Output = Result<(), TerminalError>> + Send {
1150        match self {
1151            Either::Left(l) => Either::Left(l.cancel()),
1152            Either::Right(r) => Either::Right(r.cancel()),
1153        }
1154    }
1155}
1156
1157impl Error {
1158    fn serialization<E: std::error::Error + Send + Sync + 'static>(
1159        syscall: &'static str,
1160        e: E,
1161    ) -> Self {
1162        ErrorInner::Serialization {
1163            syscall,
1164            err: Box::new(e),
1165        }
1166        .into()
1167    }
1168
1169    fn deserialization<E: std::error::Error + Send + Sync + 'static>(
1170        syscall: &'static str,
1171        e: E,
1172    ) -> Self {
1173        ErrorInner::Deserialization {
1174            syscall,
1175            err: Box::new(e),
1176        }
1177        .into()
1178    }
1179}
1180
1181fn get_async_result(
1182    ctx: Arc<Mutex<ContextInternalInner>>,
1183    handle: NotificationHandle,
1184) -> impl Future<Output = Result<Value, Error>> + Send {
1185    VmAsyncResultPollFuture::new(ctx, handle).map_err(Error::from)
1186}