Skip to main content

restate_sdk/endpoint/
context.rs

1use crate::context::{
2    CallFuture, DurableFuture, InvocationHandle, Request, RequestTarget, RunClosure, RunFuture,
3    RunRetryPolicy,
4};
5use crate::endpoint::futures::async_result_poll::VmAsyncResultPollFuture;
6use crate::endpoint::futures::durable_future_impl::DurableFutureImpl;
7use crate::endpoint::futures::intercept_error::InterceptErrorFuture;
8use crate::endpoint::futures::select_poll::VmSelectAsyncResultPollFuture;
9use crate::endpoint::futures::trap::TrapFuture;
10use crate::endpoint::handler_state::HandlerStateNotifier;
11use crate::endpoint::{Error, ErrorInner, InputReceiver, OutputSender};
12use crate::errors::{HandlerErrorInner, HandlerResult, TerminalError};
13use crate::serde::{Deserialize, Serialize};
14use futures::future::{BoxFuture, Either, Shared};
15use futures::{FutureExt, TryFutureExt};
16use pin_project_lite::pin_project;
17use restate_sdk_shared_core::{
18    CoreVM, DoProgressResponse, Error as CoreError, Header, NonEmptyValue, NotificationHandle,
19    RetryPolicy, RunExitResult, TakeOutputResult, Target, TerminalFailure, VM, Value,
20};
21use std::borrow::Cow;
22use std::collections::HashMap;
23use std::future::{Future, ready};
24use std::marker::PhantomData;
25use std::mem;
26use std::pin::Pin;
27use std::sync::{Arc, Mutex};
28use std::task::{Context, Poll, ready};
29use std::time::{Duration, Instant, SystemTime};
30
31pub struct ContextInternalInner {
32    pub(crate) vm: CoreVM,
33    pub(crate) read: InputReceiver,
34    pub(crate) write: OutputSender,
35    pub(super) handler_state: HandlerStateNotifier,
36
37    /// We remember here the state of the span replaying field state, because setting it might be expensive (it's guarded behind locks and other stuff).
38    /// For details, see [ContextInternalInner::maybe_flip_span_replaying_field]
39    pub(super) span_replaying_field_state: bool,
40}
41
42impl ContextInternalInner {
43    fn new(
44        vm: CoreVM,
45        read: InputReceiver,
46        write: OutputSender,
47        handler_state: HandlerStateNotifier,
48    ) -> Self {
49        Self {
50            vm,
51            read,
52            write,
53            handler_state,
54            span_replaying_field_state: false,
55        }
56    }
57
58    pub(super) fn fail(&mut self, e: Error) {
59        self.maybe_flip_span_replaying_field();
60        self.vm.notify_error(
61            CoreError::new(500u16, e.0.to_string())
62                .with_stacktrace(Cow::Owned(format!("{:#}", e.0))),
63            None,
64        );
65        self.handler_state.mark_error(e);
66    }
67
68    pub(super) fn maybe_flip_span_replaying_field(&mut self) {
69        if !self.span_replaying_field_state && self.vm.is_replaying() {
70            tracing::Span::current().record("restate.sdk.is_replaying", true);
71            self.span_replaying_field_state = true;
72        } else if self.span_replaying_field_state && !self.vm.is_replaying() {
73            tracing::Span::current().record("restate.sdk.is_replaying", false);
74            self.span_replaying_field_state = false;
75        }
76    }
77}
78
79#[allow(unused)]
80const fn is_send_sync<T: Send + Sync>() {}
81const _: () = is_send_sync::<ContextInternal>();
82
/// Locks `$mutex` without blocking and yields the guard.
///
/// # Panics
///
/// Panics with an explanatory message when the lock is already held. This is the case when the
/// user awaits two context futures concurrently, or calls into the context while a future holds it.
macro_rules! must_lock {
    ($mutex:expr) => {{
        let attempt = $mutex.try_lock();
        attempt.expect("You're trying to await two futures at the same time and/or trying to perform some operation on the restate context while awaiting a future. This is not supported!")
    }};
}
88
/// Unwraps `$res`, or fails the invocation through `$inner_lock` and returns early from the
/// enclosing function with `Either::Right(TrapFuture::default())`.
macro_rules! unwrap_or_trap {
    ($inner_lock:expr, $res:expr) => {
        match $res {
            Err(err) => {
                $inner_lock.fail(err.into());
                return Either::Right(TrapFuture::default());
            }
            Ok(value) => value,
        }
    };
}
100
/// Like `unwrap_or_trap!`, but for functions returning a [`DurableFutureImpl`].
///
/// On error it fails the invocation and returns a `DurableFutureImpl` that wraps a trap future.
/// The notification handle is the placeholder `u32::MAX`.
// NOTE(review): the placeholder handle presumably never reaches the VM because the trap
// future is what gets polled — confirm if DurableFutureImpl exposes the handle elsewhere.
macro_rules! unwrap_or_trap_durable_future {
    ($ctx:expr, $inner_lock:expr, $res:expr) => {
        match $res {
            Err(err) => {
                $inner_lock.fail(err.into());
                let placeholder_handle = NotificationHandle::from(u32::MAX);
                return DurableFutureImpl::new(
                    $ctx.clone(),
                    placeholder_handle,
                    Either::Right(TrapFuture::default()),
                );
            }
            Ok(value) => value,
        }
    };
}
116
117#[derive(Debug, Eq, PartialEq)]
118pub struct InputMetadata {
119    pub invocation_id: String,
120    pub random_seed: u64,
121    pub key: String,
122    pub headers: http::HeaderMap<String>,
123}
124
125impl From<RequestTarget> for Target {
126    fn from(value: RequestTarget) -> Self {
127        match value {
128            RequestTarget::Service { name, handler } => Target {
129                service: name,
130                handler,
131                key: None,
132                idempotency_key: None,
133                headers: vec![],
134            },
135            RequestTarget::Object { name, key, handler } => Target {
136                service: name,
137                handler,
138                key: Some(key),
139                idempotency_key: None,
140                headers: vec![],
141            },
142            RequestTarget::Workflow { name, key, handler } => Target {
143                service: name,
144                handler,
145                key: Some(key),
146                idempotency_key: None,
147                headers: vec![],
148            },
149        }
150    }
151}
152
153/// Internal context interface.
154///
155/// For the high level interfaces, look at [`crate::context`].
156#[derive(Clone)]
157pub struct ContextInternal {
158    svc_name: String,
159    handler_name: String,
160    inner: Arc<Mutex<ContextInternalInner>>,
161}
162
163impl ContextInternal {
164    pub(super) fn new(
165        vm: CoreVM,
166        svc_name: String,
167        handler_name: String,
168        read: InputReceiver,
169        write: OutputSender,
170        handler_state: HandlerStateNotifier,
171    ) -> Self {
172        Self {
173            svc_name,
174            handler_name,
175            inner: Arc::new(Mutex::new(ContextInternalInner::new(
176                vm,
177                read,
178                write,
179                handler_state,
180            ))),
181        }
182    }
183
184    pub fn service_name(&self) -> &str {
185        &self.svc_name
186    }
187
188    pub fn handler_name(&self) -> &str {
189        &self.handler_name
190    }
191
192    pub fn input<T: Deserialize>(&self) -> impl Future<Output = (T, InputMetadata)> {
193        let mut inner_lock = must_lock!(self.inner);
194        let input_result =
195            inner_lock
196                .vm
197                .sys_input()
198                .map_err(ErrorInner::VM)
199                .map(|mut raw_input| {
200                    let headers = http::HeaderMap::<String>::try_from(
201                        &raw_input
202                            .headers
203                            .into_iter()
204                            .map(|h| (h.key.to_string(), h.value.to_string()))
205                            .collect::<HashMap<String, String>>(),
206                    )
207                    .map_err(|e| {
208                        TerminalError::new_with_code(400, format!("Cannot decode headers: {e:?}"))
209                    })?;
210
211                    Ok::<_, TerminalError>((
212                        T::deserialize(&mut (raw_input.input)).map_err(|e| {
213                            TerminalError::new_with_code(
214                                400,
215                                format!("Cannot decode input payload: {e:?}"),
216                            )
217                        })?,
218                        InputMetadata {
219                            invocation_id: raw_input.invocation_id,
220                            random_seed: raw_input.random_seed,
221                            key: raw_input.key,
222                            headers,
223                        },
224                    ))
225                });
226        inner_lock.maybe_flip_span_replaying_field();
227
228        match input_result {
229            Ok(Ok(i)) => {
230                drop(inner_lock);
231                return Either::Left(ready(i));
232            }
233            Ok(Err(err)) => {
234                let error_inner = ErrorInner::Deserialization {
235                    syscall: "input",
236                    err: err.0.clone().into(),
237                };
238                let _ = inner_lock
239                    .vm
240                    .sys_write_output(NonEmptyValue::Failure(err.into()));
241                let _ = inner_lock.vm.sys_end();
242                // This causes the trap, plus logs the error
243                inner_lock.handler_state.mark_error(error_inner.into());
244                drop(inner_lock);
245            }
246            Err(e) => {
247                inner_lock.fail(e.into());
248                drop(inner_lock);
249            }
250        }
251        Either::Right(TrapFuture::default())
252    }
253
254    pub fn get<T: Deserialize>(
255        &self,
256        key: &str,
257    ) -> impl Future<Output = Result<Option<T>, TerminalError>> + Send {
258        let mut inner_lock = must_lock!(self.inner);
259        let handle = unwrap_or_trap!(inner_lock, inner_lock.vm.sys_state_get(key.to_owned()));
260        inner_lock.maybe_flip_span_replaying_field();
261
262        let poll_future = get_async_result(Arc::clone(&self.inner), handle).map(|res| match res {
263            Ok(Value::Void) => Ok(Ok(None)),
264            Ok(Value::Success(mut s)) => {
265                let t =
266                    T::deserialize(&mut s).map_err(|e| Error::deserialization("get_state", e))?;
267                Ok(Ok(Some(t)))
268            }
269            Ok(Value::Failure(f)) => Ok(Err(f.into())),
270            Ok(v) => Err(ErrorInner::UnexpectedValueVariantForSyscall {
271                variant: <&'static str>::from(v),
272                syscall: "get_state",
273            }
274            .into()),
275            Err(e) => Err(e),
276        });
277
278        Either::Left(InterceptErrorFuture::new(self.clone(), poll_future))
279    }
280
281    pub fn get_keys(&self) -> impl Future<Output = Result<Vec<String>, TerminalError>> + Send {
282        let mut inner_lock = must_lock!(self.inner);
283        let handle = unwrap_or_trap!(inner_lock, inner_lock.vm.sys_state_get_keys());
284        inner_lock.maybe_flip_span_replaying_field();
285
286        let poll_future = get_async_result(Arc::clone(&self.inner), handle).map(|res| match res {
287            Ok(Value::Failure(f)) => Ok(Err(f.into())),
288            Ok(Value::StateKeys(s)) => Ok(Ok(s)),
289            Ok(v) => Err(ErrorInner::UnexpectedValueVariantForSyscall {
290                variant: <&'static str>::from(v),
291                syscall: "get_keys",
292            }
293            .into()),
294            Err(e) => Err(e),
295        });
296
297        Either::Left(InterceptErrorFuture::new(self.clone(), poll_future))
298    }
299
300    pub fn set<T: Serialize>(&self, key: &str, t: T) {
301        let mut inner_lock = must_lock!(self.inner);
302        match t.serialize() {
303            Ok(b) => {
304                let _ = inner_lock.vm.sys_state_set(key.to_owned(), b);
305                inner_lock.maybe_flip_span_replaying_field();
306            }
307            Err(e) => {
308                inner_lock.fail(Error::serialization("set_state", e));
309            }
310        }
311    }
312
313    pub fn clear(&self, key: &str) {
314        let mut inner_lock = must_lock!(self.inner);
315        let _ = inner_lock.vm.sys_state_clear(key.to_string());
316        inner_lock.maybe_flip_span_replaying_field();
317    }
318
319    pub fn clear_all(&self) {
320        let mut inner_lock = must_lock!(self.inner);
321        let _ = inner_lock.vm.sys_state_clear_all();
322        inner_lock.maybe_flip_span_replaying_field();
323    }
324
325    pub fn select(
326        &self,
327        handles: Vec<NotificationHandle>,
328    ) -> impl Future<Output = Result<usize, TerminalError>> + Send {
329        InterceptErrorFuture::new(
330            self.clone(),
331            VmSelectAsyncResultPollFuture::new(self.inner.clone(), handles).map_err(Error::from),
332        )
333    }
334
335    pub fn sleep(
336        &self,
337        sleep_duration: Duration,
338    ) -> impl DurableFuture<Output = Result<(), TerminalError>> + Send {
339        let now = SystemTime::now()
340            .duration_since(SystemTime::UNIX_EPOCH)
341            .expect("Duration since unix epoch cannot fail");
342        let mut inner_lock = must_lock!(self.inner);
343        let handle = unwrap_or_trap_durable_future!(
344            self,
345            inner_lock,
346            inner_lock
347                .vm
348                .sys_sleep(String::default(), now + sleep_duration, Some(now))
349        );
350        inner_lock.maybe_flip_span_replaying_field();
351
352        let poll_future = get_async_result(Arc::clone(&self.inner), handle).map(|res| match res {
353            Ok(Value::Void) => Ok(Ok(())),
354            Ok(Value::Failure(f)) => Ok(Err(f.into())),
355            Ok(v) => Err(ErrorInner::UnexpectedValueVariantForSyscall {
356                variant: <&'static str>::from(v),
357                syscall: "sleep",
358            }
359            .into()),
360            Err(e) => Err(e),
361        });
362
363        DurableFutureImpl::new(self.clone(), handle, Either::Left(poll_future))
364    }
365
366    pub fn request<Req, Res>(
367        &self,
368        request_target: RequestTarget,
369        req: Req,
370    ) -> Request<'_, Req, Res> {
371        Request::new(self, request_target, req)
372    }
373
374    pub fn call<Req: Serialize, Res: Deserialize>(
375        &self,
376        request_target: RequestTarget,
377        idempotency_key: Option<String>,
378        headers: Vec<(String, String)>,
379        req: Req,
380    ) -> impl CallFuture<Response = Res> + Send {
381        let mut inner_lock = must_lock!(self.inner);
382
383        let mut target: Target = request_target.into();
384        target.idempotency_key = idempotency_key;
385        target.headers = headers
386            .into_iter()
387            .map(|(k, v)| Header {
388                key: k.into(),
389                value: v.into(),
390            })
391            .collect();
392        let call_result = Req::serialize(&req)
393            .map_err(|e| Error::serialization("call", e))
394            .and_then(|input| inner_lock.vm.sys_call(target, input).map_err(Into::into));
395
396        let call_handle = match call_result {
397            Ok(t) => t,
398            Err(e) => {
399                inner_lock.fail(e);
400                return CallFutureImpl {
401                    invocation_id_future: Either::Right(TrapFuture::default()).shared(),
402                    result_future: Either::Right(TrapFuture::default()),
403                    call_notification_handle: NotificationHandle::from(u32::MAX),
404                    ctx: self.clone(),
405                };
406            }
407        };
408        inner_lock.maybe_flip_span_replaying_field();
409        drop(inner_lock);
410
411        // Let's prepare the two futures here
412        let invocation_id_fut = InterceptErrorFuture::new(
413            self.clone(),
414            get_async_result(
415                Arc::clone(&self.inner),
416                call_handle.invocation_id_notification_handle,
417            )
418            .map(|res| match res {
419                Ok(Value::Failure(f)) => Ok(Err(f.into())),
420                Ok(Value::InvocationId(s)) => Ok(Ok(s)),
421                Ok(v) => Err(ErrorInner::UnexpectedValueVariantForSyscall {
422                    variant: <&'static str>::from(v),
423                    syscall: "call",
424                }
425                .into()),
426                Err(e) => Err(e),
427            }),
428        );
429        let result_future = get_async_result(
430            Arc::clone(&self.inner),
431            call_handle.call_notification_handle,
432        )
433        .map(|res| match res {
434            Ok(Value::Success(mut s)) => Ok(Ok(
435                Res::deserialize(&mut s).map_err(|e| Error::deserialization("call", e))?
436            )),
437            Ok(Value::Failure(f)) => Ok(Err(TerminalError::from(f))),
438            Ok(v) => Err(ErrorInner::UnexpectedValueVariantForSyscall {
439                variant: <&'static str>::from(v),
440                syscall: "call",
441            }
442            .into()),
443            Err(e) => Err(e),
444        });
445
446        CallFutureImpl {
447            invocation_id_future: Either::Left(invocation_id_fut).shared(),
448            result_future: Either::Left(result_future),
449            call_notification_handle: call_handle.call_notification_handle,
450            ctx: self.clone(),
451        }
452    }
453
454    pub fn send<Req: Serialize>(
455        &self,
456        request_target: RequestTarget,
457        idempotency_key: Option<String>,
458        headers: Vec<(String, String)>,
459        req: Req,
460        delay: Option<Duration>,
461    ) -> impl InvocationHandle {
462        let mut inner_lock = must_lock!(self.inner);
463
464        let mut target: Target = request_target.into();
465        target.idempotency_key = idempotency_key;
466        target.headers = headers
467            .into_iter()
468            .map(|(k, v)| Header {
469                key: k.into(),
470                value: v.into(),
471            })
472            .collect();
473        let input = match Req::serialize(&req) {
474            Ok(b) => b,
475            Err(e) => {
476                inner_lock.fail(Error::serialization("call", e));
477                return Either::Right(TrapFuture::<()>::default());
478            }
479        };
480
481        let send_handle = match inner_lock.vm.sys_send(
482            target,
483            input,
484            delay.map(|delay| {
485                SystemTime::now()
486                    .duration_since(SystemTime::UNIX_EPOCH)
487                    .expect("Duration since unix epoch cannot fail")
488                    + delay
489            }),
490        ) {
491            Ok(h) => h,
492            Err(e) => {
493                inner_lock.fail(e.into());
494                return Either::Right(TrapFuture::<()>::default());
495            }
496        };
497        inner_lock.maybe_flip_span_replaying_field();
498        drop(inner_lock);
499
500        let invocation_id_fut = InterceptErrorFuture::new(
501            self.clone(),
502            get_async_result(
503                Arc::clone(&self.inner),
504                send_handle.invocation_id_notification_handle,
505            )
506            .map(|res| match res {
507                Ok(Value::Failure(f)) => Ok(Err(f.into())),
508                Ok(Value::InvocationId(s)) => Ok(Ok(s)),
509                Ok(v) => Err(ErrorInner::UnexpectedValueVariantForSyscall {
510                    variant: <&'static str>::from(v),
511                    syscall: "call",
512                }
513                .into()),
514                Err(e) => Err(e),
515            }),
516        );
517
518        Either::Left(SendRequestHandle {
519            invocation_id_future: invocation_id_fut.shared(),
520            ctx: self.clone(),
521        })
522    }
523
524    pub fn invocation_handle(&self, invocation_id: String) -> impl InvocationHandle {
525        InvocationIdBackedInvocationHandle {
526            ctx: self.clone(),
527            invocation_id,
528        }
529    }
530
531    pub fn awakeable<T: Deserialize>(
532        &self,
533    ) -> (
534        String,
535        impl DurableFuture<Output = Result<T, TerminalError>> + Send,
536    ) {
537        let mut inner_lock = must_lock!(self.inner);
538        let maybe_awakeable_id_and_handle = inner_lock.vm.sys_awakeable();
539        inner_lock.maybe_flip_span_replaying_field();
540
541        let (awakeable_id, handle) = match maybe_awakeable_id_and_handle {
542            Ok((s, handle)) => (s, handle),
543            Err(e) => {
544                inner_lock.fail(e.into());
545                return (
546                    // TODO NOW this is REALLY BAD. The reason for this is that we would need to return a future of a future instead, which is not nice.
547                    //  we assume for the time being this works because no user should use the awakeable without doing any other syscall first, which will prevent this invalid awakeable id to work in the first place.
548                    "invalid".to_owned(),
549                    DurableFutureImpl::new(
550                        self.clone(),
551                        NotificationHandle::from(u32::MAX),
552                        Either::Right(TrapFuture::default()),
553                    ),
554                );
555            }
556        };
557        drop(inner_lock);
558
559        let poll_future = get_async_result(Arc::clone(&self.inner), handle).map(|res| match res {
560            Ok(Value::Success(mut s)) => Ok(Ok(
561                T::deserialize(&mut s).map_err(|e| Error::deserialization("awakeable", e))?
562            )),
563            Ok(Value::Failure(f)) => Ok(Err(f.into())),
564            Ok(v) => Err(ErrorInner::UnexpectedValueVariantForSyscall {
565                variant: <&'static str>::from(v),
566                syscall: "awakeable",
567            }
568            .into()),
569            Err(e) => Err(e),
570        });
571
572        (
573            awakeable_id,
574            DurableFutureImpl::new(self.clone(), handle, Either::Left(poll_future)),
575        )
576    }
577
578    pub fn resolve_awakeable<T: Serialize>(&self, id: &str, t: T) {
579        let mut inner_lock = must_lock!(self.inner);
580        match t.serialize() {
581            Ok(b) => {
582                let _ = inner_lock
583                    .vm
584                    .sys_complete_awakeable(id.to_owned(), NonEmptyValue::Success(b));
585            }
586            Err(e) => {
587                inner_lock.fail(Error::serialization("resolve_awakeable", e));
588            }
589        }
590    }
591
592    pub fn reject_awakeable(&self, id: &str, failure: TerminalError) {
593        let _ = must_lock!(self.inner)
594            .vm
595            .sys_complete_awakeable(id.to_owned(), NonEmptyValue::Failure(failure.into()));
596    }
597
598    pub fn promise<T: Deserialize>(
599        &self,
600        name: &str,
601    ) -> impl DurableFuture<Output = Result<T, TerminalError>> + Send {
602        let mut inner_lock = must_lock!(self.inner);
603        let handle = unwrap_or_trap_durable_future!(
604            self,
605            inner_lock,
606            inner_lock.vm.sys_get_promise(name.to_owned())
607        );
608        inner_lock.maybe_flip_span_replaying_field();
609        drop(inner_lock);
610
611        let poll_future = get_async_result(Arc::clone(&self.inner), handle).map(|res| match res {
612            Ok(Value::Success(mut s)) => {
613                let t = T::deserialize(&mut s).map_err(|e| Error::deserialization("promise", e))?;
614                Ok(Ok(t))
615            }
616            Ok(Value::Failure(f)) => Ok(Err(f.into())),
617            Ok(v) => Err(ErrorInner::UnexpectedValueVariantForSyscall {
618                variant: <&'static str>::from(v),
619                syscall: "promise",
620            }
621            .into()),
622            Err(e) => Err(e),
623        });
624
625        DurableFutureImpl::new(self.clone(), handle, Either::Left(poll_future))
626    }
627
628    pub fn peek_promise<T: Deserialize>(
629        &self,
630        name: &str,
631    ) -> impl Future<Output = Result<Option<T>, TerminalError>> + Send {
632        let mut inner_lock = must_lock!(self.inner);
633        let handle = unwrap_or_trap!(inner_lock, inner_lock.vm.sys_peek_promise(name.to_owned()));
634        inner_lock.maybe_flip_span_replaying_field();
635        drop(inner_lock);
636
637        let poll_future = get_async_result(Arc::clone(&self.inner), handle).map(|res| match res {
638            Ok(Value::Void) => Ok(Ok(None)),
639            Ok(Value::Success(mut s)) => {
640                let t = T::deserialize(&mut s)
641                    .map_err(|e| Error::deserialization("peek_promise", e))?;
642                Ok(Ok(Some(t)))
643            }
644            Ok(Value::Failure(f)) => Ok(Err(f.into())),
645            Ok(v) => Err(ErrorInner::UnexpectedValueVariantForSyscall {
646                variant: <&'static str>::from(v),
647                syscall: "peek_promise",
648            }
649            .into()),
650            Err(e) => Err(e),
651        });
652
653        Either::Left(InterceptErrorFuture::new(self.clone(), poll_future))
654    }
655
656    pub fn resolve_promise<T: Serialize>(&self, name: &str, t: T) {
657        let mut inner_lock = must_lock!(self.inner);
658        match t.serialize() {
659            Ok(b) => {
660                let _ = inner_lock
661                    .vm
662                    .sys_complete_promise(name.to_owned(), NonEmptyValue::Success(b));
663            }
664            Err(e) => {
665                inner_lock.fail(
666                    ErrorInner::Serialization {
667                        syscall: "resolve_promise",
668                        err: Box::new(e),
669                    }
670                    .into(),
671                );
672            }
673        }
674    }
675
676    pub fn reject_promise(&self, id: &str, failure: TerminalError) {
677        let _ = must_lock!(self.inner)
678            .vm
679            .sys_complete_promise(id.to_owned(), NonEmptyValue::Failure(failure.into()));
680    }
681
682    pub fn run<'a, Run, Fut, Out>(
683        &'a self,
684        run_closure: Run,
685    ) -> impl RunFuture<Result<Out, TerminalError>> + Send + 'a
686    where
687        Run: RunClosure<Fut = Fut, Output = Out> + Send + 'a,
688        Fut: Future<Output = HandlerResult<Out>> + Send + 'a,
689        Out: Serialize + Deserialize + 'static,
690    {
691        let this = Arc::clone(&self.inner);
692        InterceptErrorFuture::new(self.clone(), RunFutureImpl::new(this, run_closure))
693    }
694
695    // Used by codegen
696    pub fn handle_handler_result<T: Serialize>(&self, res: HandlerResult<T>) {
697        let mut inner_lock = must_lock!(self.inner);
698
699        let res_to_write = match res {
700            Ok(success) => match T::serialize(&success) {
701                Ok(t) => NonEmptyValue::Success(t),
702                Err(e) => {
703                    inner_lock.fail(
704                        ErrorInner::Serialization {
705                            syscall: "output",
706                            err: Box::new(e),
707                        }
708                        .into(),
709                    );
710                    return;
711                }
712            },
713            Err(e) => match e.0 {
714                HandlerErrorInner::Retryable(err) => {
715                    inner_lock.fail(ErrorInner::HandlerResult { err }.into());
716                    return;
717                }
718                HandlerErrorInner::Terminal(t) => NonEmptyValue::Failure(TerminalError(t).into()),
719            },
720        };
721
722        let _ = inner_lock.vm.sys_write_output(res_to_write);
723        inner_lock.maybe_flip_span_replaying_field();
724    }
725
726    pub fn end(&self) {
727        let _ = must_lock!(self.inner).vm.sys_end();
728    }
729
730    pub(crate) fn consume_to_end(&self) {
731        let mut inner_lock = must_lock!(self.inner);
732
733        let out = inner_lock.vm.take_output();
734        if let TakeOutputResult::Buffer(b) = out {
735            if !inner_lock.write.send(b) {
736                // Nothing we can do anymore here
737            }
738        }
739    }
740
741    pub(super) fn fail(&self, e: Error) {
742        must_lock!(self.inner).fail(e)
743    }
744}
745
// Future returned by `ContextInternal::run`. On first poll it registers the run with `sys_run`.
// When the VM answers `ExecuteRun`, it executes the user closure.
pin_project! {
    struct RunFutureImpl<Run, Ret, RunFnFut> {
        // Name passed to `sys_run`; empty unless set through `RunFuture::name`.
        name: String,
        // Retry policy; `RetryPolicy::Infinite` in `new`, overridable via `RunFuture::retry_policy`.
        retry_policy: RetryPolicy,
        // Mentions `Ret` without storing one; a `fn() -> Ret` pointer is always `Send + Sync`.
        phantom_data: PhantomData<fn() -> Ret>,
        // Current step of the run state machine; pinned because it may own the closure future.
        #[pin]
        state: RunState<Run, RunFnFut, Ret>,
    }
}
755
// States of `RunFutureImpl`:
// - `New`: not yet registered with the VM; context and closure are taken on first poll.
// - `ClosureRunning`: the VM asked for the closure to be executed and its future is in flight.
// - `WaitingResultFut`: waiting on a boxed future that yields the final result (e.g. on cancel).
pin_project! {
    #[project = RunStateProj]
    enum RunState<Run, RunFnFut, Ret> {
        New {
            // `Option`s so they can be moved out (`take`) while the state is pinned.
            ctx: Option<Arc<Mutex<ContextInternalInner>>>,
            closure: Option<Run>,
        },
        ClosureRunning {
            ctx: Option<Arc<Mutex<ContextInternalInner>>>,
            // Handle returned by `sys_run` for this run.
            handle: NotificationHandle,
            // When the closure started; its use is outside the visible code — TODO confirm.
            start_time: Instant,
            #[pin]
            closure_fut: RunFnFut,
        },
        WaitingResultFut {
            result_fut: BoxFuture<'static, Result<Result<Ret, TerminalError>, Error>>
        }
    }
}
775
776impl<Run, Ret, RunFnFut> RunFutureImpl<Run, Ret, RunFnFut> {
777    fn new(ctx: Arc<Mutex<ContextInternalInner>>, closure: Run) -> Self {
778        Self {
779            name: "".to_string(),
780            retry_policy: RetryPolicy::Infinite,
781            phantom_data: PhantomData,
782            state: RunState::New {
783                ctx: Some(ctx),
784                closure: Some(closure),
785            },
786        }
787    }
788
789    fn boxed_result_fut(
790        ctx: Arc<Mutex<ContextInternalInner>>,
791        handle: NotificationHandle,
792    ) -> BoxFuture<'static, Result<Result<Ret, TerminalError>, Error>>
793    where
794        Ret: Deserialize,
795    {
796        get_async_result(Arc::clone(&ctx), handle)
797            .map(|res| match res {
798                Ok(Value::Success(mut s)) => {
799                    let t =
800                        Ret::deserialize(&mut s).map_err(|e| Error::deserialization("run", e))?;
801                    Ok(Ok(t))
802                }
803                Ok(Value::Failure(f)) => Ok(Err(f.into())),
804                Ok(v) => Err(ErrorInner::UnexpectedValueVariantForSyscall {
805                    variant: <&'static str>::from(v),
806                    syscall: "run",
807                }
808                .into()),
809                Err(e) => Err(e),
810            })
811            .boxed()
812    }
813}
814
815impl<Run, Ret, RunFnFut> RunFuture<Result<Result<Ret, TerminalError>, Error>>
816    for RunFutureImpl<Run, Ret, RunFnFut>
817where
818    Run: RunClosure<Fut = RunFnFut, Output = Ret> + Send,
819    Ret: Serialize + Deserialize,
820    RunFnFut: Future<Output = HandlerResult<Ret>> + Send,
821{
822    fn retry_policy(mut self, retry_policy: RunRetryPolicy) -> Self {
823        self.retry_policy = RetryPolicy::Exponential {
824            initial_interval: retry_policy.initial_delay,
825            factor: retry_policy.factor,
826            max_interval: retry_policy.max_delay,
827            max_attempts: retry_policy.max_attempts,
828            max_duration: retry_policy.max_duration,
829        };
830        self
831    }
832
833    fn name(mut self, name: impl Into<String>) -> Self {
834        self.name = name.into();
835        self
836    }
837}
838
impl<Run, Ret, RunFnFut> Future for RunFutureImpl<Run, Ret, RunFnFut>
where
    Run: RunClosure<Fut = RunFnFut, Output = Ret> + Send,
    Ret: Serialize + Deserialize,
    RunFnFut: Future<Output = HandlerResult<Ret>> + Send,
{
    type Output = Result<Result<Ret, TerminalError>, Error>;

    /// Drives a `run` block through its state machine.
    ///
    /// - `New`: registers the run with the VM (`sys_run`) and calls `do_progress` once to
    ///   learn whether the closure must be executed, whether cancellation was received, or
    ///   whether to simply wait for the result.
    /// - `ClosureRunning`: polls the user closure, converts its outcome into a
    ///   `RunExitResult` and proposes it to the VM together with the configured retry policy.
    /// - `WaitingResultFut`: delegates to a boxed future that yields the final value.
    ///
    /// The `loop` lets one `poll` call fall through several state transitions without
    /// returning `Poll::Pending` in between.
    ///
    /// # Errors
    /// Resolves to `Err(Error)` immediately if `sys_run` fails or the closure's success value
    /// cannot be serialized; any other outcome comes from the result future.
    ///
    /// # Panics
    /// Panics if polled again after the `New` state has already taken its context and
    /// closure (e.g. after `sys_run` failed), or if `do_progress` asks to execute a run
    /// handle other than the one just registered.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let mut this = self.project();

        loop {
            match this.state.as_mut().project() {
                RunStateProj::New { ctx, closure, .. } => {
                    // Take ownership of the context and closure so they can be moved into
                    // the next state; `None` means this state was already consumed.
                    let ctx = ctx
                        .take()
                        .expect("Future should not be polled after returning Poll::Ready");
                    let closure = closure
                        .take()
                        .expect("Future should not be polled after returning Poll::Ready");
                    let mut inner_ctx = must_lock!(ctx);

                    // Register the run with the VM. The returned handle is what is passed to
                    // `do_progress`, `propose_run_completion` and the result future below.
                    let handle = inner_ctx
                        .vm
                        .sys_run(this.name.to_owned())
                        .map_err(ErrorInner::from)?;

                    // Now we do progress once to check whether this closure should be executed or not.
                    match inner_ctx.vm.do_progress(vec![handle]) {
                        Ok(DoProgressResponse::ExecuteRun(handle_to_run)) => {
                            // In case it returns ExecuteRun, it must be the handle we just gave it,
                            // and it means we need to execute the closure
                            assert_eq!(handle, handle_to_run);

                            // Release the VM lock: the guard borrows `ctx`, which is moved
                            // into the next state.
                            drop(inner_ctx);
                            // `start_time` is used later to report the attempt duration of
                            // retryable failures.
                            this.state.set(RunState::ClosureRunning {
                                ctx: Some(ctx),
                                handle,
                                start_time: Instant::now(),
                                closure_fut: closure.run(),
                            });
                        }
                        Ok(DoProgressResponse::CancelSignalReceived) => {
                            drop(inner_ctx);
                            // Got cancellation!
                            // The closure is not executed; the run resolves to a terminal
                            // error with code 409 and message "cancelled".
                            this.state.set(RunState::WaitingResultFut {
                                result_fut: async {
                                    Ok(Err(TerminalError::from(TerminalFailure {
                                        code: 409,
                                        message: "cancelled".to_string(),
                                    })))
                                }
                                .boxed(),
                            })
                        }
                        _ => {
                            drop(inner_ctx);
                            // In all the other cases, just move on waiting the result,
                            // the poll future state will take care of doing whatever needs to be done here,
                            // that is propagating state machine error, or result, or whatever
                            this.state.set(RunState::WaitingResultFut {
                                result_fut: Self::boxed_result_fut(Arc::clone(&ctx), handle),
                            })
                        }
                    }
                }
                RunStateProj::ClosureRunning {
                    ctx,
                    handle,
                    start_time,
                    closure_fut,
                } => {
                    // Wait for the user closure, then map its outcome to a `RunExitResult`:
                    // - success: the value is serialized (a serialization error is returned
                    //   right away via `?`);
                    // - retryable error: reported with code 500 and the elapsed attempt time;
                    // - terminal error: reported as a terminal failure.
                    let res = match ready!(closure_fut.poll(cx)) {
                        Ok(t) => RunExitResult::Success(Ret::serialize(&t).map_err(|e| {
                            ErrorInner::Serialization {
                                syscall: "run",
                                err: Box::new(e),
                            }
                        })?),
                        Err(e) => match e.0 {
                            HandlerErrorInner::Retryable(err) => RunExitResult::RetryableFailure {
                                attempt_duration: start_time.elapsed(),
                                error: CoreError::new(500u16, err.to_string()),
                            },
                            HandlerErrorInner::Terminal(t) => {
                                RunExitResult::TerminalFailure(TerminalError(t).into())
                            }
                        },
                    };

                    let ctx = ctx
                        .take()
                        .expect("Future should not be polled after returning Poll::Ready");
                    let handle = *handle;

                    // Propose the outcome to the VM. `mem::take` moves the retry policy out,
                    // leaving its `Default` value behind.
                    // NOTE(review): the result of `propose_run_completion` is discarded here;
                    // presumably a VM failure surfaces when the result future below is
                    // polled — confirm.
                    let _ = {
                        must_lock!(ctx).vm.propose_run_completion(
                            handle,
                            res,
                            mem::take(this.retry_policy),
                        )
                    };

                    // The final value is always read back from the VM for this handle.
                    this.state.set(RunState::WaitingResultFut {
                        result_fut: Self::boxed_result_fut(Arc::clone(&ctx), handle),
                    });
                }
                RunStateProj::WaitingResultFut { result_fut } => return result_fut.poll_unpin(cx),
            }
        }
    }
}
951
pin_project! {
    // Future backing a call (`CallFuture`).
    //
    // Polling it yields the call result from `result_future`. It also implements
    // `InvocationHandle`: the callee invocation id is exposed through a `Shared` future,
    // so it can be cloned and awaited by both `invocation_id` and `cancel`.
    struct CallFutureImpl<InvIdFut: Future, ResultFut> {
        // Resolves to the callee invocation id; `Shared` so clones can be handed out.
        #[pin]
        invocation_id_future: Shared<InvIdFut>,
        // Resolves to the call result, or to an SDK `Error` that `poll` reports via `ctx.fail`.
        #[pin]
        result_future: ResultFut,
        // Returned by `SealedDurableFuture::handle`.
        call_notification_handle: NotificationHandle,
        // Context used to report SDK errors and to issue cancellations.
        ctx: ContextInternal,
    }
}
962
963impl<InvIdFut, ResultFut, Res> Future for CallFutureImpl<InvIdFut, ResultFut>
964where
965    InvIdFut: Future<Output = Result<String, TerminalError>> + Send,
966    ResultFut: Future<Output = Result<Result<Res, TerminalError>, Error>> + Send,
967{
968    type Output = Result<Res, TerminalError>;
969
970    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
971        let this = self.project();
972        let result = ready!(this.result_future.poll(cx));
973
974        match result {
975            Ok(r) => Poll::Ready(r),
976            Err(e) => {
977                this.ctx.fail(e);
978
979                // Here is the secret sauce. This will immediately cause the whole future chain to be polled,
980                //  but the poll here will be intercepted by HandlerStateAwareFuture
981                cx.waker().wake_by_ref();
982                Poll::Pending
983            }
984        }
985    }
986}
987
988impl<InvIdFut, ResultFut> InvocationHandle for CallFutureImpl<InvIdFut, ResultFut>
989where
990    InvIdFut: Future<Output = Result<String, TerminalError>> + Send,
991{
992    fn invocation_id(&self) -> impl Future<Output = Result<String, TerminalError>> + Send {
993        Shared::clone(&self.invocation_id_future)
994    }
995
996    fn cancel(&self) -> impl Future<Output = Result<(), TerminalError>> + Send {
997        let cloned_invocation_id_fut = Shared::clone(&self.invocation_id_future);
998        let cloned_ctx = Arc::clone(&self.ctx.inner);
999        async move {
1000            let inv_id = cloned_invocation_id_fut.await?;
1001            let mut inner_lock = must_lock!(cloned_ctx);
1002            let _ = inner_lock.vm.sys_cancel_invocation(inv_id);
1003            inner_lock.maybe_flip_span_replaying_field();
1004            drop(inner_lock);
1005            Ok(())
1006        }
1007    }
1008}
1009
/// Exposes `CallFutureImpl` as a [`CallFuture`]. The response type is `Res`, the success
/// value carried by the inner result future.
impl<InvIdFut, ResultFut, Res> CallFuture for CallFutureImpl<InvIdFut, ResultFut>
where
    InvIdFut: Future<Output = Result<String, TerminalError>> + Send,
    ResultFut: Future<Output = Result<Result<Res, TerminalError>, Error>> + Send,
{
    type Response = Res;
}
1017
1018impl<InvIdFut, ResultFut> crate::context::macro_support::SealedDurableFuture
1019    for CallFutureImpl<InvIdFut, ResultFut>
1020where
1021    InvIdFut: Future,
1022{
1023    fn inner_context(&self) -> ContextInternal {
1024        self.ctx.clone()
1025    }
1026
1027    fn handle(&self) -> NotificationHandle {
1028        self.call_notification_handle
1029    }
1030}
1031
/// Marker impl: `DurableFuture` needs no items here. The context and notification-handle
/// accessors for this type are implemented in its `SealedDurableFuture` impl.
impl<InvIdFut, ResultFut, Res> DurableFuture for CallFutureImpl<InvIdFut, ResultFut>
where
    InvIdFut: Future<Output = Result<String, TerminalError>> + Send,
    ResultFut: Future<Output = Result<Result<Res, TerminalError>, Error>> + Send,
{
}
1038
/// Invocation handle for a sent request.
///
/// Unlike `CallFutureImpl`, it carries no result future. It only exposes the invocation id
/// and cancellation.
struct SendRequestHandle<InvIdFut: Future> {
    /// Resolves to the invocation id; `Shared` so it can be cloned for each
    /// `invocation_id` / `cancel` call.
    invocation_id_future: Shared<InvIdFut>,
    /// Context whose VM is used to issue the cancellation.
    ctx: ContextInternal,
}
1043
1044impl<InvIdFut: Future<Output = Result<String, TerminalError>> + Send> InvocationHandle
1045    for SendRequestHandle<InvIdFut>
1046{
1047    fn invocation_id(&self) -> impl Future<Output = Result<String, TerminalError>> + Send {
1048        Shared::clone(&self.invocation_id_future)
1049    }
1050
1051    fn cancel(&self) -> impl Future<Output = Result<(), TerminalError>> + Send {
1052        let cloned_invocation_id_fut = Shared::clone(&self.invocation_id_future);
1053        let cloned_ctx = Arc::clone(&self.ctx.inner);
1054        async move {
1055            let inv_id = cloned_invocation_id_fut.await?;
1056            let mut inner_lock = must_lock!(cloned_ctx);
1057            let _ = inner_lock.vm.sys_cancel_invocation(inv_id);
1058            inner_lock.maybe_flip_span_replaying_field();
1059            drop(inner_lock);
1060            Ok(())
1061        }
1062    }
1063}
1064
/// Invocation handle built from an invocation id that is already known, so no future has
/// to be awaited to obtain it.
struct InvocationIdBackedInvocationHandle {
    /// Context whose VM is used to issue the cancellation.
    ctx: ContextInternal,
    /// Target invocation id, returned as-is by `invocation_id`.
    invocation_id: String,
}
1069
1070impl InvocationHandle for InvocationIdBackedInvocationHandle {
1071    fn invocation_id(&self) -> impl Future<Output = Result<String, TerminalError>> + Send {
1072        ready(Ok(self.invocation_id.clone()))
1073    }
1074
1075    fn cancel(&self) -> impl Future<Output = Result<(), TerminalError>> + Send {
1076        let mut inner_lock = must_lock!(self.ctx.inner);
1077        let _ = inner_lock
1078            .vm
1079            .sys_cancel_invocation(self.invocation_id.clone());
1080        ready(Ok(()))
1081    }
1082}
1083
1084impl<A, B> InvocationHandle for Either<A, B>
1085where
1086    A: InvocationHandle,
1087    B: InvocationHandle,
1088{
1089    fn invocation_id(&self) -> impl Future<Output = Result<String, TerminalError>> + Send {
1090        match self {
1091            Either::Left(l) => Either::Left(l.invocation_id()),
1092            Either::Right(r) => Either::Right(r.invocation_id()),
1093        }
1094    }
1095
1096    fn cancel(&self) -> impl Future<Output = Result<(), TerminalError>> + Send {
1097        match self {
1098            Either::Left(l) => Either::Left(l.cancel()),
1099            Either::Right(r) => Either::Right(r.cancel()),
1100        }
1101    }
1102}
1103
1104impl Error {
1105    fn serialization<E: std::error::Error + Send + Sync + 'static>(
1106        syscall: &'static str,
1107        e: E,
1108    ) -> Self {
1109        ErrorInner::Serialization {
1110            syscall,
1111            err: Box::new(e),
1112        }
1113        .into()
1114    }
1115
1116    fn deserialization<E: std::error::Error + Send + Sync + 'static>(
1117        syscall: &'static str,
1118        e: E,
1119    ) -> Self {
1120        ErrorInner::Deserialization {
1121            syscall,
1122            err: Box::new(e),
1123        }
1124        .into()
1125    }
1126}
1127
1128fn get_async_result(
1129    ctx: Arc<Mutex<ContextInternalInner>>,
1130    handle: NotificationHandle,
1131) -> impl Future<Output = Result<Value, Error>> + Send {
1132    VmAsyncResultPollFuture::new(ctx, handle).map_err(Error::from)
1133}