Skip to main content

restate_sdk/endpoint/
context.rs

1use crate::context::{
2    CallFuture, DurableFuture, InvocationHandle, Request, RequestTarget, RunClosure, RunFuture,
3    RunRetryPolicy,
4};
5use crate::endpoint::futures::async_result_poll::VmAsyncResultPollFuture;
6use crate::endpoint::futures::durable_future_impl::DurableFutureImpl;
7use crate::endpoint::futures::intercept_error::InterceptErrorFuture;
8use crate::endpoint::futures::select_poll::VmSelectAsyncResultPollFuture;
9use crate::endpoint::futures::trap::TrapFuture;
10use crate::endpoint::handler_state::HandlerStateNotifier;
11use crate::endpoint::{Error, ErrorInner, InputReceiver, OutputSender};
12use crate::errors::{HandlerErrorInner, HandlerResult, TerminalError};
13use crate::serde::{Deserialize, Serialize};
14use futures::future::{BoxFuture, Either, Shared};
15use futures::{FutureExt, TryFutureExt};
16use pin_project_lite::pin_project;
17use restate_sdk_shared_core::{
18    CoreVM, DoProgressResponse, Error as CoreError, Header, NonEmptyValue, NotificationHandle,
19    PayloadOptions, RetryPolicy, RunExitResult, TakeOutputResult, Target, TerminalFailure, VM,
20    Value,
21};
22use std::borrow::Cow;
23use std::collections::HashMap;
24use std::future::{Future, ready};
25use std::marker::PhantomData;
26use std::mem;
27use std::pin::Pin;
28use std::sync::{Arc, Mutex};
29use std::task::{Context, Poll, ready};
30use std::time::{Duration, Instant, SystemTime};
31
/// Mutable state stored behind the `Mutex` owned by [`ContextInternal`].
///
/// `ContextInternal` methods lock it through `must_lock!` before driving the VM.
pub struct ContextInternalInner {
    /// The shared-core state machine that every `sys_*` call goes through.
    pub(crate) vm: CoreVM,
    /// Input side handed over at construction (not read in this part of the file).
    pub(crate) read: InputReceiver,
    /// Output side: `ContextInternal::consume_to_end` sends the VM's buffered output here.
    pub(crate) write: OutputSender,
    /// Used to mark the handler as errored (see `fail`).
    pub(super) handler_state: HandlerStateNotifier,

    /// Last value recorded for the `restate.sdk.is_replaying` span field.
    /// It is cached because recording a span field can be expensive (it is guarded behind locks and other machinery),
    /// so the field is only re-recorded when the value changes.
    /// For details, see [ContextInternalInner::maybe_flip_span_replaying_field]
    pub(super) span_replaying_field_state: bool,
}
42
43impl ContextInternalInner {
44    fn new(
45        vm: CoreVM,
46        read: InputReceiver,
47        write: OutputSender,
48        handler_state: HandlerStateNotifier,
49    ) -> Self {
50        Self {
51            vm,
52            read,
53            write,
54            handler_state,
55            span_replaying_field_state: false,
56        }
57    }
58
59    pub(super) fn fail(&mut self, e: Error) {
60        self.maybe_flip_span_replaying_field();
61        self.vm.notify_error(
62            CoreError::new(500u16, e.0.to_string())
63                .with_stacktrace(Cow::<str>::Owned(format!("{:#}", e.0))),
64            None,
65        );
66        self.handler_state.mark_error(e);
67    }
68
69    pub(super) fn maybe_flip_span_replaying_field(&mut self) {
70        if !self.span_replaying_field_state && self.vm.is_replaying() {
71            tracing::Span::current().record("restate.sdk.is_replaying", true);
72            self.span_replaying_field_state = true;
73        } else if self.span_replaying_field_state && !self.vm.is_replaying() {
74            tracing::Span::current().record("restate.sdk.is_replaying", false);
75            self.span_replaying_field_state = false;
76        }
77    }
78}
79
// Compile-time assertion that `ContextInternal` is `Send + Sync`: the `const _` item below
// only type-checks if `ContextInternal` satisfies the bound on `is_send_sync`. Several
// methods return `+ Send` futures that capture clones of `ContextInternal`.
#[allow(unused)]
const fn is_send_sync<T: Send + Sync>() {}
const _: () = is_send_sync::<ContextInternal>();
83
/// Locks a context mutex without blocking, panicking if the lock cannot be taken.
///
/// Uses `try_lock`, so a lock that is already held (or poisoned) produces an immediate,
/// descriptive panic instead of blocking. Holding the lock twice means user code is awaiting
/// two context futures concurrently or using the context while a future holds it, which is
/// unsupported.
macro_rules! must_lock {
    ($ctx_mutex:expr) => {
        $ctx_mutex.try_lock().expect("You're trying to await two futures at the same time and/or trying to perform some operation on the restate context while awaiting a future. This is not supported!")
    };
}
89
/// Unwraps a `Result` inside a method whose return type is an `Either<_, TrapFuture<_>>`.
///
/// On `Ok(value)` the macro evaluates to `value`. On `Err(err)` it records the error through
/// `$guard.fail(err.into())` and makes the *calling function* return
/// `Either::Right(TrapFuture::default())`.
macro_rules! unwrap_or_trap {
    ($guard:expr, $result:expr) => {
        match $result {
            Err(err) => {
                $guard.fail(err.into());
                return Either::Right(TrapFuture::default());
            }
            Ok(value) => value,
        }
    };
}
101
/// Like `unwrap_or_trap!`, but for methods returning a `DurableFutureImpl`.
///
/// On `Ok(value)` the macro evaluates to `value`. On `Err(err)` it records the error through
/// `$guard.fail(err.into())` and makes the *calling function* return a `DurableFutureImpl`
/// wrapping a `TrapFuture`, built with a clone of `$ctx`.
macro_rules! unwrap_or_trap_durable_future {
    ($ctx:expr, $guard:expr, $result:expr) => {
        match $result {
            Err(err) => {
                $guard.fail(err.into());
                // `u32::MAX` is a sentinel handle for this failed path.
                // NOTE(review): assumed never to be looked up by the VM — confirm.
                return DurableFutureImpl::new(
                    $ctx.clone(),
                    NotificationHandle::from(u32::MAX),
                    Either::Right(TrapFuture::default()),
                );
            }
            Ok(value) => value,
        }
    };
}
117
/// Metadata returned alongside the decoded input by `ContextInternal::input`.
#[derive(Debug, Eq, PartialEq)]
pub struct InputMetadata {
    /// Invocation id, as reported by the VM's `sys_input`.
    pub invocation_id: String,
    /// Random seed delivered with the input.
    pub random_seed: u64,
    /// Key delivered with the input (presumably empty for keyless services — TODO confirm).
    pub key: String,
    /// Request headers. `input` builds them from an intermediate `HashMap`, so if a header
    /// name appears more than once only one of its values is kept.
    pub headers: http::HeaderMap<String>,
}
125
126impl From<RequestTarget> for Target {
127    fn from(value: RequestTarget) -> Self {
128        match value {
129            RequestTarget::Service { name, handler } => Target {
130                service: name,
131                handler,
132                key: None,
133                idempotency_key: None,
134                headers: vec![],
135            },
136            RequestTarget::Object { name, key, handler } => Target {
137                service: name,
138                handler,
139                key: Some(key),
140                idempotency_key: None,
141                headers: vec![],
142            },
143            RequestTarget::Workflow { name, key, handler } => Target {
144                service: name,
145                handler,
146                key: Some(key),
147                idempotency_key: None,
148                headers: vec![],
149            },
150        }
151    }
152}
153
/// Internal context interface.
///
/// For the high level interfaces, look at [`crate::context`].
///
/// Cloning copies the service/handler names and shares the same locked inner state
/// (`inner` is an `Arc<Mutex<_>>`).
#[derive(Clone)]
pub struct ContextInternal {
    /// Service name given at construction; returned by `service_name`.
    svc_name: String,
    /// Handler name given at construction; returned by `handler_name`.
    handler_name: String,
    /// State shared by all clones, accessed through `must_lock!` (which panics instead of
    /// blocking when the lock is already held).
    inner: Arc<Mutex<ContextInternalInner>>,
}
163
164impl ContextInternal {
165    pub(super) fn new(
166        vm: CoreVM,
167        svc_name: String,
168        handler_name: String,
169        read: InputReceiver,
170        write: OutputSender,
171        handler_state: HandlerStateNotifier,
172    ) -> Self {
173        Self {
174            svc_name,
175            handler_name,
176            inner: Arc::new(Mutex::new(ContextInternalInner::new(
177                vm,
178                read,
179                write,
180                handler_state,
181            ))),
182        }
183    }
184
185    pub fn service_name(&self) -> &str {
186        &self.svc_name
187    }
188
189    pub fn handler_name(&self) -> &str {
190        &self.handler_name
191    }
192
193    pub fn input<T: Deserialize>(&self) -> impl Future<Output = (T, InputMetadata)> {
194        let mut inner_lock = must_lock!(self.inner);
195        let input_result =
196            inner_lock
197                .vm
198                .sys_input()
199                .map_err(ErrorInner::VM)
200                .map(|mut raw_input| {
201                    let headers = http::HeaderMap::<String>::try_from(
202                        &raw_input
203                            .headers
204                            .into_iter()
205                            .map(|h| (h.key.to_string(), h.value.to_string()))
206                            .collect::<HashMap<String, String>>(),
207                    )
208                    .map_err(|e| {
209                        TerminalError::new_with_code(400, format!("Cannot decode headers: {e:?}"))
210                    })?;
211
212                    Ok::<_, TerminalError>((
213                        T::deserialize(&mut (raw_input.input)).map_err(|e| {
214                            TerminalError::new_with_code(
215                                400,
216                                format!("Cannot decode input payload: {e:?}"),
217                            )
218                        })?,
219                        InputMetadata {
220                            invocation_id: raw_input.invocation_id,
221                            random_seed: raw_input.random_seed,
222                            key: raw_input.key,
223                            headers,
224                        },
225                    ))
226                });
227        inner_lock.maybe_flip_span_replaying_field();
228
229        match input_result {
230            Ok(Ok(i)) => {
231                drop(inner_lock);
232                return Either::Left(ready(i));
233            }
234            Ok(Err(err)) => {
235                let error_inner = ErrorInner::Deserialization {
236                    syscall: "input",
237                    err: err.0.clone().into(),
238                };
239                let _ = inner_lock
240                    .vm
241                    .sys_write_output(NonEmptyValue::Failure(err.into()), PayloadOptions::stable());
242                let _ = inner_lock.vm.sys_end();
243                // This causes the trap, plus logs the error
244                inner_lock.handler_state.mark_error(error_inner.into());
245                drop(inner_lock);
246            }
247            Err(e) => {
248                inner_lock.fail(e.into());
249                drop(inner_lock);
250            }
251        }
252        Either::Right(TrapFuture::default())
253    }
254
255    pub fn get<T: Deserialize>(
256        &self,
257        key: &str,
258    ) -> impl Future<Output = Result<Option<T>, TerminalError>> + Send {
259        let mut inner_lock = must_lock!(self.inner);
260        let handle = unwrap_or_trap!(
261            inner_lock,
262            inner_lock
263                .vm
264                .sys_state_get(key.to_owned(), PayloadOptions::stable())
265        );
266        inner_lock.maybe_flip_span_replaying_field();
267
268        let poll_future = get_async_result(Arc::clone(&self.inner), handle).map(|res| match res {
269            Ok(Value::Void) => Ok(Ok(None)),
270            Ok(Value::Success(mut s)) => {
271                let t =
272                    T::deserialize(&mut s).map_err(|e| Error::deserialization("get_state", e))?;
273                Ok(Ok(Some(t)))
274            }
275            Ok(Value::Failure(f)) => Ok(Err(f.into())),
276            Ok(v) => Err(ErrorInner::UnexpectedValueVariantForSyscall {
277                variant: <&'static str>::from(v),
278                syscall: "get_state",
279            }
280            .into()),
281            Err(e) => Err(e),
282        });
283
284        Either::Left(InterceptErrorFuture::new(self.clone(), poll_future))
285    }
286
287    pub fn get_keys(&self) -> impl Future<Output = Result<Vec<String>, TerminalError>> + Send {
288        let mut inner_lock = must_lock!(self.inner);
289        let handle = unwrap_or_trap!(inner_lock, inner_lock.vm.sys_state_get_keys());
290        inner_lock.maybe_flip_span_replaying_field();
291
292        let poll_future = get_async_result(Arc::clone(&self.inner), handle).map(|res| match res {
293            Ok(Value::Failure(f)) => Ok(Err(f.into())),
294            Ok(Value::StateKeys(s)) => Ok(Ok(s)),
295            Ok(v) => Err(ErrorInner::UnexpectedValueVariantForSyscall {
296                variant: <&'static str>::from(v),
297                syscall: "get_keys",
298            }
299            .into()),
300            Err(e) => Err(e),
301        });
302
303        Either::Left(InterceptErrorFuture::new(self.clone(), poll_future))
304    }
305
306    pub fn set<T: Serialize>(&self, key: &str, t: T) {
307        let mut inner_lock = must_lock!(self.inner);
308        match t.serialize() {
309            Ok(b) => {
310                let _ = inner_lock
311                    .vm
312                    .sys_state_set(key.to_owned(), b, PayloadOptions::stable());
313                inner_lock.maybe_flip_span_replaying_field();
314            }
315            Err(e) => {
316                inner_lock.fail(Error::serialization("set_state", e));
317            }
318        }
319    }
320
321    pub fn clear(&self, key: &str) {
322        let mut inner_lock = must_lock!(self.inner);
323        let _ = inner_lock.vm.sys_state_clear(key.to_string());
324        inner_lock.maybe_flip_span_replaying_field();
325    }
326
327    pub fn clear_all(&self) {
328        let mut inner_lock = must_lock!(self.inner);
329        let _ = inner_lock.vm.sys_state_clear_all();
330        inner_lock.maybe_flip_span_replaying_field();
331    }
332
333    pub fn select(
334        &self,
335        handles: Vec<NotificationHandle>,
336    ) -> impl Future<Output = Result<usize, TerminalError>> + Send {
337        InterceptErrorFuture::new(
338            self.clone(),
339            VmSelectAsyncResultPollFuture::new(self.inner.clone(), handles).map_err(Error::from),
340        )
341    }
342
343    pub fn sleep(
344        &self,
345        sleep_duration: Duration,
346    ) -> impl DurableFuture<Output = Result<(), TerminalError>> + Send {
347        let now = SystemTime::now()
348            .duration_since(SystemTime::UNIX_EPOCH)
349            .expect("Duration since unix epoch cannot fail");
350        let mut inner_lock = must_lock!(self.inner);
351        let handle = unwrap_or_trap_durable_future!(
352            self,
353            inner_lock,
354            inner_lock
355                .vm
356                .sys_sleep(String::default(), now + sleep_duration, Some(now))
357        );
358        inner_lock.maybe_flip_span_replaying_field();
359
360        let poll_future = get_async_result(Arc::clone(&self.inner), handle).map(|res| match res {
361            Ok(Value::Void) => Ok(Ok(())),
362            Ok(Value::Failure(f)) => Ok(Err(f.into())),
363            Ok(v) => Err(ErrorInner::UnexpectedValueVariantForSyscall {
364                variant: <&'static str>::from(v),
365                syscall: "sleep",
366            }
367            .into()),
368            Err(e) => Err(e),
369        });
370
371        DurableFutureImpl::new(self.clone(), handle, Either::Left(poll_future))
372    }
373
374    pub fn request<Req, Res>(
375        &self,
376        request_target: RequestTarget,
377        req: Req,
378    ) -> Request<'_, Req, Res> {
379        Request::new(self, request_target, req)
380    }
381
382    pub fn call<Req: Serialize, Res: Deserialize>(
383        &self,
384        request_target: RequestTarget,
385        idempotency_key: Option<String>,
386        headers: Vec<(String, String)>,
387        req: Req,
388    ) -> impl CallFuture<Response = Res> + Send {
389        let mut inner_lock = must_lock!(self.inner);
390
391        let mut target: Target = request_target.into();
392        target.idempotency_key = idempotency_key;
393        target.headers = headers
394            .into_iter()
395            .map(|(k, v)| Header {
396                key: k.into(),
397                value: v.into(),
398            })
399            .collect();
400        let call_result = Req::serialize(&req)
401            .map_err(|e| Error::serialization("call", e))
402            .and_then(|input| {
403                inner_lock
404                    .vm
405                    .sys_call(target, input, None, PayloadOptions::stable())
406                    .map_err(Into::into)
407            });
408
409        let call_handle = match call_result {
410            Ok(t) => t,
411            Err(e) => {
412                inner_lock.fail(e);
413                return CallFutureImpl {
414                    invocation_id_future: Either::Right(TrapFuture::default()).shared(),
415                    result_future: Either::Right(TrapFuture::default()),
416                    call_notification_handle: NotificationHandle::from(u32::MAX),
417                    ctx: self.clone(),
418                };
419            }
420        };
421        inner_lock.maybe_flip_span_replaying_field();
422        drop(inner_lock);
423
424        // Let's prepare the two futures here
425        let invocation_id_fut = InterceptErrorFuture::new(
426            self.clone(),
427            get_async_result(
428                Arc::clone(&self.inner),
429                call_handle.invocation_id_notification_handle,
430            )
431            .map(|res| match res {
432                Ok(Value::Failure(f)) => Ok(Err(f.into())),
433                Ok(Value::InvocationId(s)) => Ok(Ok(s)),
434                Ok(v) => Err(ErrorInner::UnexpectedValueVariantForSyscall {
435                    variant: <&'static str>::from(v),
436                    syscall: "call",
437                }
438                .into()),
439                Err(e) => Err(e),
440            }),
441        );
442        let result_future = get_async_result(
443            Arc::clone(&self.inner),
444            call_handle.call_notification_handle,
445        )
446        .map(|res| match res {
447            Ok(Value::Success(mut s)) => Ok(Ok(
448                Res::deserialize(&mut s).map_err(|e| Error::deserialization("call", e))?
449            )),
450            Ok(Value::Failure(f)) => Ok(Err(TerminalError::from(f))),
451            Ok(v) => Err(ErrorInner::UnexpectedValueVariantForSyscall {
452                variant: <&'static str>::from(v),
453                syscall: "call",
454            }
455            .into()),
456            Err(e) => Err(e),
457        });
458
459        CallFutureImpl {
460            invocation_id_future: Either::Left(invocation_id_fut).shared(),
461            result_future: Either::Left(result_future),
462            call_notification_handle: call_handle.call_notification_handle,
463            ctx: self.clone(),
464        }
465    }
466
467    pub fn send<Req: Serialize>(
468        &self,
469        request_target: RequestTarget,
470        idempotency_key: Option<String>,
471        headers: Vec<(String, String)>,
472        req: Req,
473        delay: Option<Duration>,
474    ) -> impl InvocationHandle {
475        let mut inner_lock = must_lock!(self.inner);
476
477        let mut target: Target = request_target.into();
478        target.idempotency_key = idempotency_key;
479        target.headers = headers
480            .into_iter()
481            .map(|(k, v)| Header {
482                key: k.into(),
483                value: v.into(),
484            })
485            .collect();
486        let input = match Req::serialize(&req) {
487            Ok(b) => b,
488            Err(e) => {
489                inner_lock.fail(Error::serialization("call", e));
490                return Either::Right(TrapFuture::<()>::default());
491            }
492        };
493
494        let send_handle = match inner_lock.vm.sys_send(
495            target,
496            input,
497            delay.map(|delay| {
498                SystemTime::now()
499                    .duration_since(SystemTime::UNIX_EPOCH)
500                    .expect("Duration since unix epoch cannot fail")
501                    + delay
502            }),
503            None,
504            PayloadOptions::stable(),
505        ) {
506            Ok(h) => h,
507            Err(e) => {
508                inner_lock.fail(e.into());
509                return Either::Right(TrapFuture::<()>::default());
510            }
511        };
512        inner_lock.maybe_flip_span_replaying_field();
513        drop(inner_lock);
514
515        let invocation_id_fut = InterceptErrorFuture::new(
516            self.clone(),
517            get_async_result(
518                Arc::clone(&self.inner),
519                send_handle.invocation_id_notification_handle,
520            )
521            .map(|res| match res {
522                Ok(Value::Failure(f)) => Ok(Err(f.into())),
523                Ok(Value::InvocationId(s)) => Ok(Ok(s)),
524                Ok(v) => Err(ErrorInner::UnexpectedValueVariantForSyscall {
525                    variant: <&'static str>::from(v),
526                    syscall: "call",
527                }
528                .into()),
529                Err(e) => Err(e),
530            }),
531        );
532
533        Either::Left(SendRequestHandle {
534            invocation_id_future: invocation_id_fut.shared(),
535            ctx: self.clone(),
536        })
537    }
538
539    pub fn invocation_handle(&self, invocation_id: String) -> impl InvocationHandle {
540        InvocationIdBackedInvocationHandle {
541            ctx: self.clone(),
542            invocation_id,
543        }
544    }
545
546    pub fn awakeable<T: Deserialize>(
547        &self,
548    ) -> (
549        String,
550        impl DurableFuture<Output = Result<T, TerminalError>> + Send,
551    ) {
552        let mut inner_lock = must_lock!(self.inner);
553        let maybe_awakeable_id_and_handle = inner_lock.vm.sys_awakeable();
554        inner_lock.maybe_flip_span_replaying_field();
555
556        let (awakeable_id, handle) = match maybe_awakeable_id_and_handle {
557            Ok((s, handle)) => (s, handle),
558            Err(e) => {
559                inner_lock.fail(e.into());
560                return (
561                    // TODO NOW this is REALLY BAD. The reason for this is that we would need to return a future of a future instead, which is not nice.
562                    //  we assume for the time being this works because no user should use the awakeable without doing any other syscall first, which will prevent this invalid awakeable id to work in the first place.
563                    "invalid".to_owned(),
564                    DurableFutureImpl::new(
565                        self.clone(),
566                        NotificationHandle::from(u32::MAX),
567                        Either::Right(TrapFuture::default()),
568                    ),
569                );
570            }
571        };
572        drop(inner_lock);
573
574        let poll_future = get_async_result(Arc::clone(&self.inner), handle).map(|res| match res {
575            Ok(Value::Success(mut s)) => Ok(Ok(
576                T::deserialize(&mut s).map_err(|e| Error::deserialization("awakeable", e))?
577            )),
578            Ok(Value::Failure(f)) => Ok(Err(f.into())),
579            Ok(v) => Err(ErrorInner::UnexpectedValueVariantForSyscall {
580                variant: <&'static str>::from(v),
581                syscall: "awakeable",
582            }
583            .into()),
584            Err(e) => Err(e),
585        });
586
587        (
588            awakeable_id,
589            DurableFutureImpl::new(self.clone(), handle, Either::Left(poll_future)),
590        )
591    }
592
593    pub fn resolve_awakeable<T: Serialize>(&self, id: &str, t: T) {
594        let mut inner_lock = must_lock!(self.inner);
595        match t.serialize() {
596            Ok(b) => {
597                let _ = inner_lock.vm.sys_complete_awakeable(
598                    id.to_owned(),
599                    NonEmptyValue::Success(b),
600                    PayloadOptions::stable(),
601                );
602            }
603            Err(e) => {
604                inner_lock.fail(Error::serialization("resolve_awakeable", e));
605            }
606        }
607    }
608
609    pub fn reject_awakeable(&self, id: &str, failure: TerminalError) {
610        let _ = must_lock!(self.inner).vm.sys_complete_awakeable(
611            id.to_owned(),
612            NonEmptyValue::Failure(failure.into()),
613            PayloadOptions::stable(),
614        );
615    }
616
617    pub fn promise<T: Deserialize>(
618        &self,
619        name: &str,
620    ) -> impl DurableFuture<Output = Result<T, TerminalError>> + Send {
621        let mut inner_lock = must_lock!(self.inner);
622        let handle = unwrap_or_trap_durable_future!(
623            self,
624            inner_lock,
625            inner_lock.vm.sys_get_promise(name.to_owned())
626        );
627        inner_lock.maybe_flip_span_replaying_field();
628        drop(inner_lock);
629
630        let poll_future = get_async_result(Arc::clone(&self.inner), handle).map(|res| match res {
631            Ok(Value::Success(mut s)) => {
632                let t = T::deserialize(&mut s).map_err(|e| Error::deserialization("promise", e))?;
633                Ok(Ok(t))
634            }
635            Ok(Value::Failure(f)) => Ok(Err(f.into())),
636            Ok(v) => Err(ErrorInner::UnexpectedValueVariantForSyscall {
637                variant: <&'static str>::from(v),
638                syscall: "promise",
639            }
640            .into()),
641            Err(e) => Err(e),
642        });
643
644        DurableFutureImpl::new(self.clone(), handle, Either::Left(poll_future))
645    }
646
647    pub fn peek_promise<T: Deserialize>(
648        &self,
649        name: &str,
650    ) -> impl Future<Output = Result<Option<T>, TerminalError>> + Send {
651        let mut inner_lock = must_lock!(self.inner);
652        let handle = unwrap_or_trap!(inner_lock, inner_lock.vm.sys_peek_promise(name.to_owned()));
653        inner_lock.maybe_flip_span_replaying_field();
654        drop(inner_lock);
655
656        let poll_future = get_async_result(Arc::clone(&self.inner), handle).map(|res| match res {
657            Ok(Value::Void) => Ok(Ok(None)),
658            Ok(Value::Success(mut s)) => {
659                let t = T::deserialize(&mut s)
660                    .map_err(|e| Error::deserialization("peek_promise", e))?;
661                Ok(Ok(Some(t)))
662            }
663            Ok(Value::Failure(f)) => Ok(Err(f.into())),
664            Ok(v) => Err(ErrorInner::UnexpectedValueVariantForSyscall {
665                variant: <&'static str>::from(v),
666                syscall: "peek_promise",
667            }
668            .into()),
669            Err(e) => Err(e),
670        });
671
672        Either::Left(InterceptErrorFuture::new(self.clone(), poll_future))
673    }
674
675    pub fn resolve_promise<T: Serialize>(&self, name: &str, t: T) {
676        let mut inner_lock = must_lock!(self.inner);
677        match t.serialize() {
678            Ok(b) => {
679                let _ = inner_lock.vm.sys_complete_promise(
680                    name.to_owned(),
681                    NonEmptyValue::Success(b),
682                    PayloadOptions::stable(),
683                );
684            }
685            Err(e) => {
686                inner_lock.fail(
687                    ErrorInner::Serialization {
688                        syscall: "resolve_promise",
689                        err: Box::new(e),
690                    }
691                    .into(),
692                );
693            }
694        }
695    }
696
697    pub fn reject_promise(&self, id: &str, failure: TerminalError) {
698        let _ = must_lock!(self.inner).vm.sys_complete_promise(
699            id.to_owned(),
700            NonEmptyValue::Failure(failure.into()),
701            PayloadOptions::stable(),
702        );
703    }
704
705    pub fn run<'a, Run, Fut, Out>(
706        &'a self,
707        run_closure: Run,
708    ) -> impl RunFuture<Result<Out, TerminalError>> + Send + 'a
709    where
710        Run: RunClosure<Fut = Fut, Output = Out> + Send + 'a,
711        Fut: Future<Output = HandlerResult<Out>> + Send + 'a,
712        Out: Serialize + Deserialize + 'static,
713    {
714        let this = Arc::clone(&self.inner);
715        InterceptErrorFuture::new(self.clone(), RunFutureImpl::new(this, run_closure))
716    }
717
718    // Used by codegen
719    pub fn handle_handler_result<T: Serialize>(&self, res: HandlerResult<T>) {
720        let mut inner_lock = must_lock!(self.inner);
721
722        let res_to_write = match res {
723            Ok(success) => match T::serialize(&success) {
724                Ok(t) => NonEmptyValue::Success(t),
725                Err(e) => {
726                    inner_lock.fail(
727                        ErrorInner::Serialization {
728                            syscall: "output",
729                            err: Box::new(e),
730                        }
731                        .into(),
732                    );
733                    return;
734                }
735            },
736            Err(e) => match e.0 {
737                HandlerErrorInner::Retryable(err) => {
738                    inner_lock.fail(ErrorInner::HandlerResult { err }.into());
739                    return;
740                }
741                HandlerErrorInner::Terminal(t) => NonEmptyValue::Failure(TerminalError(t).into()),
742            },
743        };
744
745        let _ = inner_lock
746            .vm
747            .sys_write_output(res_to_write, PayloadOptions::stable());
748        inner_lock.maybe_flip_span_replaying_field();
749    }
750
751    pub fn end(&self) {
752        let _ = must_lock!(self.inner).vm.sys_end();
753    }
754
755    pub(crate) fn consume_to_end(&self) {
756        let mut inner_lock = must_lock!(self.inner);
757
758        let out = inner_lock.vm.take_output();
759        if let TakeOutputResult::Buffer(b) = out
760            && !inner_lock.write.send(b)
761        {
762            // Nothing we can do anymore here
763        }
764    }
765
766    pub(super) fn fail(&self, e: Error) {
767        must_lock!(self.inner).fail(e)
768    }
769}
770
// Future returned by `ContextInternal::run`. It registers a named run with the VM
// (`sys_run`) and then moves through the states of `RunState`.
pin_project! {
    struct RunFutureImpl<Run, Ret, RunFnFut> {
        // Name passed to `sys_run`; empty unless set through `RunFuture::name`.
        name: String,
        // `RetryPolicy::Infinite` unless overridden through `RunFuture::retry_policy`.
        retry_policy: RetryPolicy,
        // Ties the `Ret` type parameter to the struct without storing a value. The
        // `fn() -> Ret` form keeps auto traits (Send/Sync) independent of `Ret`.
        phantom_data: PhantomData<fn() -> Ret>,
        #[pin]
        state: RunState<Run, RunFnFut, Ret>,
    }
}
780
// State machine driven by `RunFutureImpl::poll`.
pin_project! {
    #[project = RunStateProj]
    enum RunState<Run, RunFnFut, Ret> {
        // Not polled yet. Holds the context and the closure; both are `take()`n on the first
        // poll, which registers the run with the VM.
        New {
            ctx: Option<Arc<Mutex<ContextInternalInner>>>,
            closure: Option<Run>,
        },
        // The user closure is being executed. `start_time` is presumably the instant
        // execution started — TODO confirm against the rest of `poll`.
        ClosureRunning {
            ctx: Option<Arc<Mutex<ContextInternalInner>>>,
            handle: NotificationHandle,
            start_time: Instant,
            #[pin]
            closure_fut: RunFnFut,
        },
        // Waiting for the run's result from the VM (built by `boxed_result_fut`).
        WaitingResultFut {
            result_fut: BoxFuture<'static, Result<Result<Ret, TerminalError>, Error>>
        }
    }
}
800
801impl<Run, Ret, RunFnFut> RunFutureImpl<Run, Ret, RunFnFut> {
802    fn new(ctx: Arc<Mutex<ContextInternalInner>>, closure: Run) -> Self {
803        Self {
804            name: "".to_string(),
805            retry_policy: RetryPolicy::Infinite,
806            phantom_data: PhantomData,
807            state: RunState::New {
808                ctx: Some(ctx),
809                closure: Some(closure),
810            },
811        }
812    }
813
814    fn boxed_result_fut(
815        ctx: Arc<Mutex<ContextInternalInner>>,
816        handle: NotificationHandle,
817    ) -> BoxFuture<'static, Result<Result<Ret, TerminalError>, Error>>
818    where
819        Ret: Deserialize,
820    {
821        get_async_result(Arc::clone(&ctx), handle)
822            .map(|res| match res {
823                Ok(Value::Success(mut s)) => {
824                    let t =
825                        Ret::deserialize(&mut s).map_err(|e| Error::deserialization("run", e))?;
826                    Ok(Ok(t))
827                }
828                Ok(Value::Failure(f)) => Ok(Err(f.into())),
829                Ok(v) => Err(ErrorInner::UnexpectedValueVariantForSyscall {
830                    variant: <&'static str>::from(v),
831                    syscall: "run",
832                }
833                .into()),
834                Err(e) => Err(e),
835            })
836            .boxed()
837    }
838}
839
840impl<Run, Ret, RunFnFut> RunFuture<Result<Result<Ret, TerminalError>, Error>>
841    for RunFutureImpl<Run, Ret, RunFnFut>
842where
843    Run: RunClosure<Fut = RunFnFut, Output = Ret> + Send,
844    Ret: Serialize + Deserialize,
845    RunFnFut: Future<Output = HandlerResult<Ret>> + Send,
846{
847    fn retry_policy(mut self, retry_policy: RunRetryPolicy) -> Self {
848        self.retry_policy = RetryPolicy::Exponential {
849            initial_interval: retry_policy.initial_delay,
850            factor: retry_policy.factor,
851            max_interval: retry_policy.max_delay,
852            max_attempts: retry_policy.max_attempts,
853            max_duration: retry_policy.max_duration,
854        };
855        self
856    }
857
858    fn name(mut self, name: impl Into<String>) -> Self {
859        self.name = name.into();
860        self
861    }
862}
863
864impl<Run, Ret, RunFnFut> Future for RunFutureImpl<Run, Ret, RunFnFut>
865where
866    Run: RunClosure<Fut = RunFnFut, Output = Ret> + Send,
867    Ret: Serialize + Deserialize,
868    RunFnFut: Future<Output = HandlerResult<Ret>> + Send,
869{
870    type Output = Result<Result<Ret, TerminalError>, Error>;
871
872    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
873        let mut this = self.project();
874
875        loop {
876            match this.state.as_mut().project() {
877                RunStateProj::New { ctx, closure, .. } => {
878                    let ctx = ctx
879                        .take()
880                        .expect("Future should not be polled after returning Poll::Ready");
881                    let closure = closure
882                        .take()
883                        .expect("Future should not be polled after returning Poll::Ready");
884                    let mut inner_ctx = must_lock!(ctx);
885
886                    let handle = inner_ctx
887                        .vm
888                        .sys_run(this.name.to_owned())
889                        .map_err(ErrorInner::from)?;
890
891                    // Now we do progress once to check whether this closure should be executed or not.
892                    match inner_ctx.vm.do_progress(vec![handle]) {
893                        Ok(DoProgressResponse::ExecuteRun(handle_to_run)) => {
894                            // In case it returns ExecuteRun, it must be the handle we just gave it,
895                            // and it means we need to execute the closure
896                            assert_eq!(handle, handle_to_run);
897
898                            drop(inner_ctx);
899                            this.state.set(RunState::ClosureRunning {
900                                ctx: Some(ctx),
901                                handle,
902                                start_time: Instant::now(),
903                                closure_fut: closure.run(),
904                            });
905                        }
906                        Ok(DoProgressResponse::CancelSignalReceived) => {
907                            drop(inner_ctx);
908                            // Got cancellation!
909                            this.state.set(RunState::WaitingResultFut {
910                                result_fut: async {
911                                    Ok(Err(TerminalError::from(TerminalFailure {
912                                        code: 409,
913                                        message: "cancelled".to_string(),
914                                        metadata: vec![],
915                                    })))
916                                }
917                                .boxed(),
918                            })
919                        }
920                        _ => {
921                            drop(inner_ctx);
922                            // In all the other cases, just move on waiting the result,
923                            // the poll future state will take care of doing whatever needs to be done here,
924                            // that is propagating state machine error, or result, or whatever
925                            this.state.set(RunState::WaitingResultFut {
926                                result_fut: Self::boxed_result_fut(Arc::clone(&ctx), handle),
927                            })
928                        }
929                    }
930                }
931                RunStateProj::ClosureRunning {
932                    ctx,
933                    handle,
934                    start_time,
935                    closure_fut,
936                } => {
937                    let res = match ready!(closure_fut.poll(cx)) {
938                        Ok(t) => RunExitResult::Success(Ret::serialize(&t).map_err(|e| {
939                            ErrorInner::Serialization {
940                                syscall: "run",
941                                err: Box::new(e),
942                            }
943                        })?),
944                        Err(e) => match e.0 {
945                            HandlerErrorInner::Retryable(err) => RunExitResult::RetryableFailure {
946                                attempt_duration: start_time.elapsed(),
947                                error: CoreError::new(500u16, err.to_string()),
948                            },
949                            HandlerErrorInner::Terminal(t) => {
950                                RunExitResult::TerminalFailure(TerminalError(t).into())
951                            }
952                        },
953                    };
954
955                    let ctx = ctx
956                        .take()
957                        .expect("Future should not be polled after returning Poll::Ready");
958                    let handle = *handle;
959
960                    let _ = {
961                        must_lock!(ctx).vm.propose_run_completion(
962                            handle,
963                            res,
964                            mem::take(this.retry_policy),
965                        )
966                    };
967
968                    this.state.set(RunState::WaitingResultFut {
969                        result_fut: Self::boxed_result_fut(Arc::clone(&ctx), handle),
970                    });
971                }
972                RunStateProj::WaitingResultFut { result_fut } => return result_fut.poll_unpin(cx),
973            }
974        }
975    }
976}
977
pin_project! {
    // Implementation of `CallFuture`.
    //
    // It resolves to the call's result (see its `Future` impl) and also implements
    // `InvocationHandle`, exposing the callee's invocation id and allowing cancellation.
    struct CallFutureImpl<InvIdFut: Future, ResultFut> {
        // Shared so that `invocation_id()` and `cancel()` can each obtain their own clone.
        #[pin]
        invocation_id_future: Shared<InvIdFut>,
        // Future producing the call result; polled by `Future::poll`.
        #[pin]
        result_future: ResultFut,
        // Handle returned by `SealedDurableFuture::handle`.
        call_notification_handle: NotificationHandle,
        // Context used to fail the invocation on SDK errors and to issue cancellations.
        ctx: ContextInternal,
    }
}
988
989impl<InvIdFut, ResultFut, Res> Future for CallFutureImpl<InvIdFut, ResultFut>
990where
991    InvIdFut: Future<Output = Result<String, TerminalError>> + Send,
992    ResultFut: Future<Output = Result<Result<Res, TerminalError>, Error>> + Send,
993{
994    type Output = Result<Res, TerminalError>;
995
996    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
997        let this = self.project();
998        let result = ready!(this.result_future.poll(cx));
999
1000        match result {
1001            Ok(r) => Poll::Ready(r),
1002            Err(e) => {
1003                this.ctx.fail(e);
1004
1005                // Here is the secret sauce. This will immediately cause the whole future chain to be polled,
1006                //  but the poll here will be intercepted by HandlerStateAwareFuture
1007                cx.waker().wake_by_ref();
1008                Poll::Pending
1009            }
1010        }
1011    }
1012}
1013
1014impl<InvIdFut, ResultFut> InvocationHandle for CallFutureImpl<InvIdFut, ResultFut>
1015where
1016    InvIdFut: Future<Output = Result<String, TerminalError>> + Send,
1017{
1018    fn invocation_id(&self) -> impl Future<Output = Result<String, TerminalError>> + Send {
1019        Shared::clone(&self.invocation_id_future)
1020    }
1021
1022    fn cancel(&self) -> impl Future<Output = Result<(), TerminalError>> + Send {
1023        let cloned_invocation_id_fut = Shared::clone(&self.invocation_id_future);
1024        let cloned_ctx = Arc::clone(&self.ctx.inner);
1025        async move {
1026            let inv_id = cloned_invocation_id_fut.await?;
1027            let mut inner_lock = must_lock!(cloned_ctx);
1028            let _ = inner_lock.vm.sys_cancel_invocation(inv_id);
1029            inner_lock.maybe_flip_span_replaying_field();
1030            drop(inner_lock);
1031            Ok(())
1032        }
1033    }
1034}
1035
/// [`CallFuture`] impl whose response type is the success type of the result future.
impl<InvIdFut, ResultFut, Res> CallFuture for CallFutureImpl<InvIdFut, ResultFut>
where
    InvIdFut: Future<Output = Result<String, TerminalError>> + Send,
    ResultFut: Future<Output = Result<Result<Res, TerminalError>, Error>> + Send,
{
    /// `Res` from the result future's `Ok(Ok(Res))` output.
    type Response = Res;
}
1043
1044impl<InvIdFut, ResultFut> crate::context::macro_support::SealedDurableFuture
1045    for CallFutureImpl<InvIdFut, ResultFut>
1046where
1047    InvIdFut: Future,
1048{
1049    fn inner_context(&self) -> ContextInternal {
1050        self.ctx.clone()
1051    }
1052
1053    fn handle(&self) -> NotificationHandle {
1054        self.call_notification_handle
1055    }
1056}
1057
/// Marker impl making `CallFutureImpl` usable as a [`DurableFuture`].
///
/// The body is empty. NOTE(review): presumably everything required is provided by the
/// `Future` / `SealedDurableFuture` impls of this type — confirm against the trait definition.
impl<InvIdFut, ResultFut, Res> DurableFuture for CallFutureImpl<InvIdFut, ResultFut>
where
    InvIdFut: Future<Output = Result<String, TerminalError>> + Send,
    ResultFut: Future<Output = Result<Result<Res, TerminalError>, Error>> + Send,
{
}
1064
/// [`InvocationHandle`] whose invocation id is obtained from a shared future.
///
/// NOTE(review): presumably returned for one-way sends — confirm against callers.
struct SendRequestHandle<InvIdFut: Future> {
    /// Shared future resolving to the invocation id; cloned by `invocation_id()` and `cancel()`.
    invocation_id_future: Shared<InvIdFut>,
    /// Context used to issue the cancellation syscall.
    ctx: ContextInternal,
}
1069
1070impl<InvIdFut: Future<Output = Result<String, TerminalError>> + Send> InvocationHandle
1071    for SendRequestHandle<InvIdFut>
1072{
1073    fn invocation_id(&self) -> impl Future<Output = Result<String, TerminalError>> + Send {
1074        Shared::clone(&self.invocation_id_future)
1075    }
1076
1077    fn cancel(&self) -> impl Future<Output = Result<(), TerminalError>> + Send {
1078        let cloned_invocation_id_fut = Shared::clone(&self.invocation_id_future);
1079        let cloned_ctx = Arc::clone(&self.ctx.inner);
1080        async move {
1081            let inv_id = cloned_invocation_id_fut.await?;
1082            let mut inner_lock = must_lock!(cloned_ctx);
1083            let _ = inner_lock.vm.sys_cancel_invocation(inv_id);
1084            inner_lock.maybe_flip_span_replaying_field();
1085            drop(inner_lock);
1086            Ok(())
1087        }
1088    }
1089}
1090
/// [`InvocationHandle`] for an invocation whose id is already known.
struct InvocationIdBackedInvocationHandle {
    /// Context used to issue the cancellation syscall.
    ctx: ContextInternal,
    /// The invocation id, returned by `invocation_id()` and targeted by `cancel()`.
    invocation_id: String,
}
1095
1096impl InvocationHandle for InvocationIdBackedInvocationHandle {
1097    fn invocation_id(&self) -> impl Future<Output = Result<String, TerminalError>> + Send {
1098        ready(Ok(self.invocation_id.clone()))
1099    }
1100
1101    fn cancel(&self) -> impl Future<Output = Result<(), TerminalError>> + Send {
1102        let mut inner_lock = must_lock!(self.ctx.inner);
1103        let _ = inner_lock
1104            .vm
1105            .sys_cancel_invocation(self.invocation_id.clone());
1106        ready(Ok(()))
1107    }
1108}
1109
1110impl<A, B> InvocationHandle for Either<A, B>
1111where
1112    A: InvocationHandle,
1113    B: InvocationHandle,
1114{
1115    fn invocation_id(&self) -> impl Future<Output = Result<String, TerminalError>> + Send {
1116        match self {
1117            Either::Left(l) => Either::Left(l.invocation_id()),
1118            Either::Right(r) => Either::Right(r.invocation_id()),
1119        }
1120    }
1121
1122    fn cancel(&self) -> impl Future<Output = Result<(), TerminalError>> + Send {
1123        match self {
1124            Either::Left(l) => Either::Left(l.cancel()),
1125            Either::Right(r) => Either::Right(r.cancel()),
1126        }
1127    }
1128}
1129
1130impl Error {
1131    fn serialization<E: std::error::Error + Send + Sync + 'static>(
1132        syscall: &'static str,
1133        e: E,
1134    ) -> Self {
1135        ErrorInner::Serialization {
1136            syscall,
1137            err: Box::new(e),
1138        }
1139        .into()
1140    }
1141
1142    fn deserialization<E: std::error::Error + Send + Sync + 'static>(
1143        syscall: &'static str,
1144        e: E,
1145    ) -> Self {
1146        ErrorInner::Deserialization {
1147            syscall,
1148            err: Box::new(e),
1149        }
1150        .into()
1151    }
1152}
1153
1154fn get_async_result(
1155    ctx: Arc<Mutex<ContextInternalInner>>,
1156    handle: NotificationHandle,
1157) -> impl Future<Output = Result<Value, Error>> + Send {
1158    VmAsyncResultPollFuture::new(ctx, handle).map_err(Error::from)
1159}