1use crate::context::{
2 CallFuture, DurableFuture, InvocationHandle, Request, RequestTarget, RunClosure, RunFuture,
3 RunRetryPolicy,
4};
5use crate::endpoint::futures::async_result_poll::VmAsyncResultPollFuture;
6use crate::endpoint::futures::durable_future_impl::DurableFutureImpl;
7use crate::endpoint::futures::intercept_error::InterceptErrorFuture;
8use crate::endpoint::futures::select_poll::VmSelectAsyncResultPollFuture;
9use crate::endpoint::futures::trap::TrapFuture;
10use crate::endpoint::handler_state::HandlerStateNotifier;
11use crate::endpoint::{Error, ErrorInner, InputReceiver, OutputSender};
12use crate::errors::{HandlerErrorInner, HandlerResult, TerminalError};
13use crate::serde::{Deserialize, Serialize};
14use futures::future::{BoxFuture, Either, Shared};
15use futures::{FutureExt, TryFutureExt};
16use pin_project_lite::pin_project;
17use restate_sdk_shared_core::{
18 CoreVM, DoProgressResponse, Error as CoreError, Header, NonEmptyValue, NotificationHandle,
19 PayloadOptions, RetryPolicy, RunExitResult, TakeOutputResult, Target, TerminalFailure, VM,
20 Value,
21};
22use std::borrow::Cow;
23use std::collections::HashMap;
24use std::future::{Future, ready};
25use std::marker::PhantomData;
26use std::mem;
27use std::pin::Pin;
28use std::sync::{Arc, Mutex};
29use std::task::{Context, Poll, ready};
30use std::time::{Duration, Instant, SystemTime};
31
/// Mutable per-invocation state shared, behind a mutex, by [`ContextInternal`] and the futures it
/// creates.
pub struct ContextInternalInner {
    /// Core VM instance; every `sys_*` syscall issued from this module goes through it.
    pub(crate) vm: CoreVM,
    /// Input channel handed in at construction. Not read in this file — presumably consumed by
    /// the VM polling futures (TODO confirm).
    pub(crate) read: InputReceiver,
    /// Output channel; `ContextInternal::consume_to_end` sends the VM's pending output on it.
    pub(crate) write: OutputSender,
    /// Notified with the error whenever the handler is failed (see `ContextInternalInner::fail`).
    pub(super) handler_state: HandlerStateNotifier,

    /// Last value recorded for the `restate.sdk.is_replaying` span field. It is cached so that the
    /// span is only touched when the VM's replaying flag actually changes (see
    /// `ContextInternalInner::maybe_flip_span_replaying_field`).
    pub(super) span_replaying_field_state: bool,
}
42
43impl ContextInternalInner {
44 fn new(
45 vm: CoreVM,
46 read: InputReceiver,
47 write: OutputSender,
48 handler_state: HandlerStateNotifier,
49 ) -> Self {
50 Self {
51 vm,
52 read,
53 write,
54 handler_state,
55 span_replaying_field_state: false,
56 }
57 }
58
59 pub(super) fn fail(&mut self, e: Error) {
60 self.maybe_flip_span_replaying_field();
61 self.vm.notify_error(
62 CoreError::new(500u16, e.0.to_string())
63 .with_stacktrace(Cow::<str>::Owned(format!("{:#}", e.0))),
64 None,
65 );
66 self.handler_state.mark_error(e);
67 }
68
69 pub(super) fn maybe_flip_span_replaying_field(&mut self) {
70 if !self.span_replaying_field_state && self.vm.is_replaying() {
71 tracing::Span::current().record("restate.sdk.is_replaying", true);
72 self.span_replaying_field_state = true;
73 } else if self.span_replaying_field_state && !self.vm.is_replaying() {
74 tracing::Span::current().record("restate.sdk.is_replaying", false);
75 self.span_replaying_field_state = false;
76 }
77 }
78}
79
80#[allow(unused)]
81const fn is_send_sync<T: Send + Sync>() {}
82const _: () = is_send_sync::<ContextInternal>();
83
// Locks `$mutex` with `try_lock` and never blocks.
//
// If the lock is already held (or poisoned), this panics with the message below. Per that
// message, a held lock here means the user is awaiting two context futures at once, or is using
// the context while one of its futures is being polled. Neither is supported.
macro_rules! must_lock {
    ($mutex:expr) => {
        $mutex.try_lock().expect("You're trying to await two futures at the same time and/or trying to perform some operation on the restate context while awaiting a future. This is not supported!")
    };
}
89
// Evaluates `$res` and yields the `Ok` value.
//
// On `Err(e)` it records the failure with `$inner_lock.fail(e.into())`. It then makes the
// *enclosing function* return `Either::Right(TrapFuture::default())`. It can therefore only be
// used in functions whose return value is an `Either<_, TrapFuture<_>>`.
macro_rules! unwrap_or_trap {
    ($inner_lock:expr, $res:expr) => {
        match $res {
            Ok(t) => t,
            Err(e) => {
                $inner_lock.fail(e.into());
                return Either::Right(TrapFuture::default());
            }
        }
    };
}
101
// Variant of `unwrap_or_trap!` for functions that return a `DurableFutureImpl`.
//
// On `Err(e)` it records the failure on `$inner_lock`. It then returns, from the enclosing
// function, a `DurableFutureImpl` that wraps a `TrapFuture`. The notification handle is the
// placeholder `u32::MAX`, because no real handle was obtained.
macro_rules! unwrap_or_trap_durable_future {
    ($ctx:expr, $inner_lock:expr, $res:expr) => {
        match $res {
            Ok(t) => t,
            Err(e) => {
                $inner_lock.fail(e.into());
                return DurableFutureImpl::new(
                    $ctx.clone(),
                    NotificationHandle::from(u32::MAX),
                    Either::Right(TrapFuture::default()),
                );
            }
        }
    };
}
117
/// Metadata of the current invocation, returned alongside the decoded input by
/// `ContextInternal::input`.
#[derive(Debug, Eq, PartialEq)]
pub struct InputMetadata {
    /// Invocation id, as reported by the VM's input syscall.
    pub invocation_id: String,
    /// Random seed, as reported by the VM's input syscall.
    pub random_seed: u64,
    /// Invocation key, as reported by the VM's input syscall.
    pub key: String,
    /// Request headers. `ContextInternal::input` goes through a `HashMap`, so only one value per
    /// header name is kept.
    pub headers: http::HeaderMap<String>,
}
125
126impl From<RequestTarget> for Target {
127 fn from(value: RequestTarget) -> Self {
128 match value {
129 RequestTarget::Service { name, handler } => Target {
130 service: name,
131 handler,
132 key: None,
133 idempotency_key: None,
134 headers: vec![],
135 },
136 RequestTarget::Object { name, key, handler } => Target {
137 service: name,
138 handler,
139 key: Some(key),
140 idempotency_key: None,
141 headers: vec![],
142 },
143 RequestTarget::Workflow { name, key, handler } => Target {
144 service: name,
145 handler,
146 key: Some(key),
147 idempotency_key: None,
148 headers: vec![],
149 },
150 }
151 }
152}
153
/// Handle to the state of a single handler invocation.
///
/// Clones copy the service and handler names and share the same [`ContextInternalInner`] through
/// an `Arc<Mutex<_>>`.
#[derive(Clone)]
pub struct ContextInternal {
    // Service name of the current invocation, exposed through `service_name()`.
    svc_name: String,
    // Handler name of the current invocation, exposed through `handler_name()`.
    handler_name: String,
    // Shared mutable state. It is always locked through `must_lock!`, which panics instead of
    // blocking.
    inner: Arc<Mutex<ContextInternalInner>>,
}
163
164impl ContextInternal {
165 pub(super) fn new(
166 vm: CoreVM,
167 svc_name: String,
168 handler_name: String,
169 read: InputReceiver,
170 write: OutputSender,
171 handler_state: HandlerStateNotifier,
172 ) -> Self {
173 Self {
174 svc_name,
175 handler_name,
176 inner: Arc::new(Mutex::new(ContextInternalInner::new(
177 vm,
178 read,
179 write,
180 handler_state,
181 ))),
182 }
183 }
184
185 pub fn service_name(&self) -> &str {
186 &self.svc_name
187 }
188
189 pub fn handler_name(&self) -> &str {
190 &self.handler_name
191 }
192
193 pub fn input<T: Deserialize>(&self) -> impl Future<Output = (T, InputMetadata)> {
194 let mut inner_lock = must_lock!(self.inner);
195 let input_result =
196 inner_lock
197 .vm
198 .sys_input()
199 .map_err(ErrorInner::VM)
200 .map(|mut raw_input| {
201 let headers = http::HeaderMap::<String>::try_from(
202 &raw_input
203 .headers
204 .into_iter()
205 .map(|h| (h.key.to_string(), h.value.to_string()))
206 .collect::<HashMap<String, String>>(),
207 )
208 .map_err(|e| {
209 TerminalError::new_with_code(400, format!("Cannot decode headers: {e:?}"))
210 })?;
211
212 Ok::<_, TerminalError>((
213 T::deserialize(&mut (raw_input.input)).map_err(|e| {
214 TerminalError::new_with_code(
215 400,
216 format!("Cannot decode input payload: {e:?}"),
217 )
218 })?,
219 InputMetadata {
220 invocation_id: raw_input.invocation_id,
221 random_seed: raw_input.random_seed,
222 key: raw_input.key,
223 headers,
224 },
225 ))
226 });
227 inner_lock.maybe_flip_span_replaying_field();
228
229 match input_result {
230 Ok(Ok(i)) => {
231 drop(inner_lock);
232 return Either::Left(ready(i));
233 }
234 Ok(Err(err)) => {
235 let error_inner = ErrorInner::Deserialization {
236 syscall: "input",
237 err: err.0.clone().into(),
238 };
239 let _ = inner_lock
240 .vm
241 .sys_write_output(NonEmptyValue::Failure(err.into()), PayloadOptions::stable());
242 let _ = inner_lock.vm.sys_end();
243 inner_lock.handler_state.mark_error(error_inner.into());
245 drop(inner_lock);
246 }
247 Err(e) => {
248 inner_lock.fail(e.into());
249 drop(inner_lock);
250 }
251 }
252 Either::Right(TrapFuture::default())
253 }
254
255 pub fn get<T: Deserialize>(
256 &self,
257 key: &str,
258 ) -> impl Future<Output = Result<Option<T>, TerminalError>> + Send {
259 let mut inner_lock = must_lock!(self.inner);
260 let handle = unwrap_or_trap!(
261 inner_lock,
262 inner_lock
263 .vm
264 .sys_state_get(key.to_owned(), PayloadOptions::stable())
265 );
266 inner_lock.maybe_flip_span_replaying_field();
267
268 let poll_future = get_async_result(Arc::clone(&self.inner), handle).map(|res| match res {
269 Ok(Value::Void) => Ok(Ok(None)),
270 Ok(Value::Success(mut s)) => {
271 let t =
272 T::deserialize(&mut s).map_err(|e| Error::deserialization("get_state", e))?;
273 Ok(Ok(Some(t)))
274 }
275 Ok(Value::Failure(f)) => Ok(Err(f.into())),
276 Ok(v) => Err(ErrorInner::UnexpectedValueVariantForSyscall {
277 variant: <&'static str>::from(v),
278 syscall: "get_state",
279 }
280 .into()),
281 Err(e) => Err(e),
282 });
283
284 Either::Left(InterceptErrorFuture::new(self.clone(), poll_future))
285 }
286
287 pub fn get_keys(&self) -> impl Future<Output = Result<Vec<String>, TerminalError>> + Send {
288 let mut inner_lock = must_lock!(self.inner);
289 let handle = unwrap_or_trap!(inner_lock, inner_lock.vm.sys_state_get_keys());
290 inner_lock.maybe_flip_span_replaying_field();
291
292 let poll_future = get_async_result(Arc::clone(&self.inner), handle).map(|res| match res {
293 Ok(Value::Failure(f)) => Ok(Err(f.into())),
294 Ok(Value::StateKeys(s)) => Ok(Ok(s)),
295 Ok(v) => Err(ErrorInner::UnexpectedValueVariantForSyscall {
296 variant: <&'static str>::from(v),
297 syscall: "get_keys",
298 }
299 .into()),
300 Err(e) => Err(e),
301 });
302
303 Either::Left(InterceptErrorFuture::new(self.clone(), poll_future))
304 }
305
306 pub fn set<T: Serialize>(&self, key: &str, t: T) {
307 let mut inner_lock = must_lock!(self.inner);
308 match t.serialize() {
309 Ok(b) => {
310 let _ = inner_lock
311 .vm
312 .sys_state_set(key.to_owned(), b, PayloadOptions::stable());
313 inner_lock.maybe_flip_span_replaying_field();
314 }
315 Err(e) => {
316 inner_lock.fail(Error::serialization("set_state", e));
317 }
318 }
319 }
320
321 pub fn clear(&self, key: &str) {
322 let mut inner_lock = must_lock!(self.inner);
323 let _ = inner_lock.vm.sys_state_clear(key.to_string());
324 inner_lock.maybe_flip_span_replaying_field();
325 }
326
327 pub fn clear_all(&self) {
328 let mut inner_lock = must_lock!(self.inner);
329 let _ = inner_lock.vm.sys_state_clear_all();
330 inner_lock.maybe_flip_span_replaying_field();
331 }
332
333 pub fn select(
334 &self,
335 handles: Vec<NotificationHandle>,
336 ) -> impl Future<Output = Result<usize, TerminalError>> + Send {
337 InterceptErrorFuture::new(
338 self.clone(),
339 VmSelectAsyncResultPollFuture::new(self.inner.clone(), handles).map_err(Error::from),
340 )
341 }
342
343 pub fn sleep(
344 &self,
345 sleep_duration: Duration,
346 ) -> impl DurableFuture<Output = Result<(), TerminalError>> + Send {
347 let now = SystemTime::now()
348 .duration_since(SystemTime::UNIX_EPOCH)
349 .expect("Duration since unix epoch cannot fail");
350 let mut inner_lock = must_lock!(self.inner);
351 let handle = unwrap_or_trap_durable_future!(
352 self,
353 inner_lock,
354 inner_lock
355 .vm
356 .sys_sleep(String::default(), now + sleep_duration, Some(now))
357 );
358 inner_lock.maybe_flip_span_replaying_field();
359
360 let poll_future = get_async_result(Arc::clone(&self.inner), handle).map(|res| match res {
361 Ok(Value::Void) => Ok(Ok(())),
362 Ok(Value::Failure(f)) => Ok(Err(f.into())),
363 Ok(v) => Err(ErrorInner::UnexpectedValueVariantForSyscall {
364 variant: <&'static str>::from(v),
365 syscall: "sleep",
366 }
367 .into()),
368 Err(e) => Err(e),
369 });
370
371 DurableFutureImpl::new(self.clone(), handle, Either::Left(poll_future))
372 }
373
374 pub fn request<Req, Res>(
375 &self,
376 request_target: RequestTarget,
377 req: Req,
378 ) -> Request<'_, Req, Res> {
379 Request::new(self, request_target, req)
380 }
381
382 pub fn call<Req: Serialize, Res: Deserialize>(
383 &self,
384 request_target: RequestTarget,
385 idempotency_key: Option<String>,
386 headers: Vec<(String, String)>,
387 req: Req,
388 ) -> impl CallFuture<Response = Res> + Send {
389 let mut inner_lock = must_lock!(self.inner);
390
391 let mut target: Target = request_target.into();
392 target.idempotency_key = idempotency_key;
393 target.headers = headers
394 .into_iter()
395 .map(|(k, v)| Header {
396 key: k.into(),
397 value: v.into(),
398 })
399 .collect();
400 let call_result = Req::serialize(&req)
401 .map_err(|e| Error::serialization("call", e))
402 .and_then(|input| {
403 inner_lock
404 .vm
405 .sys_call(target, input, None, PayloadOptions::stable())
406 .map_err(Into::into)
407 });
408
409 let call_handle = match call_result {
410 Ok(t) => t,
411 Err(e) => {
412 inner_lock.fail(e);
413 return CallFutureImpl {
414 invocation_id_future: Either::Right(TrapFuture::default()).shared(),
415 result_future: Either::Right(TrapFuture::default()),
416 call_notification_handle: NotificationHandle::from(u32::MAX),
417 ctx: self.clone(),
418 };
419 }
420 };
421 inner_lock.maybe_flip_span_replaying_field();
422 drop(inner_lock);
423
424 let invocation_id_fut = InterceptErrorFuture::new(
426 self.clone(),
427 get_async_result(
428 Arc::clone(&self.inner),
429 call_handle.invocation_id_notification_handle,
430 )
431 .map(|res| match res {
432 Ok(Value::Failure(f)) => Ok(Err(f.into())),
433 Ok(Value::InvocationId(s)) => Ok(Ok(s)),
434 Ok(v) => Err(ErrorInner::UnexpectedValueVariantForSyscall {
435 variant: <&'static str>::from(v),
436 syscall: "call",
437 }
438 .into()),
439 Err(e) => Err(e),
440 }),
441 );
442 let result_future = get_async_result(
443 Arc::clone(&self.inner),
444 call_handle.call_notification_handle,
445 )
446 .map(|res| match res {
447 Ok(Value::Success(mut s)) => Ok(Ok(
448 Res::deserialize(&mut s).map_err(|e| Error::deserialization("call", e))?
449 )),
450 Ok(Value::Failure(f)) => Ok(Err(TerminalError::from(f))),
451 Ok(v) => Err(ErrorInner::UnexpectedValueVariantForSyscall {
452 variant: <&'static str>::from(v),
453 syscall: "call",
454 }
455 .into()),
456 Err(e) => Err(e),
457 });
458
459 CallFutureImpl {
460 invocation_id_future: Either::Left(invocation_id_fut).shared(),
461 result_future: Either::Left(result_future),
462 call_notification_handle: call_handle.call_notification_handle,
463 ctx: self.clone(),
464 }
465 }
466
467 pub fn send<Req: Serialize>(
468 &self,
469 request_target: RequestTarget,
470 idempotency_key: Option<String>,
471 headers: Vec<(String, String)>,
472 req: Req,
473 delay: Option<Duration>,
474 ) -> impl InvocationHandle {
475 let mut inner_lock = must_lock!(self.inner);
476
477 let mut target: Target = request_target.into();
478 target.idempotency_key = idempotency_key;
479 target.headers = headers
480 .into_iter()
481 .map(|(k, v)| Header {
482 key: k.into(),
483 value: v.into(),
484 })
485 .collect();
486 let input = match Req::serialize(&req) {
487 Ok(b) => b,
488 Err(e) => {
489 inner_lock.fail(Error::serialization("call", e));
490 return Either::Right(TrapFuture::<()>::default());
491 }
492 };
493
494 let send_handle = match inner_lock.vm.sys_send(
495 target,
496 input,
497 delay.map(|delay| {
498 SystemTime::now()
499 .duration_since(SystemTime::UNIX_EPOCH)
500 .expect("Duration since unix epoch cannot fail")
501 + delay
502 }),
503 None,
504 PayloadOptions::stable(),
505 ) {
506 Ok(h) => h,
507 Err(e) => {
508 inner_lock.fail(e.into());
509 return Either::Right(TrapFuture::<()>::default());
510 }
511 };
512 inner_lock.maybe_flip_span_replaying_field();
513 drop(inner_lock);
514
515 let invocation_id_fut = InterceptErrorFuture::new(
516 self.clone(),
517 get_async_result(
518 Arc::clone(&self.inner),
519 send_handle.invocation_id_notification_handle,
520 )
521 .map(|res| match res {
522 Ok(Value::Failure(f)) => Ok(Err(f.into())),
523 Ok(Value::InvocationId(s)) => Ok(Ok(s)),
524 Ok(v) => Err(ErrorInner::UnexpectedValueVariantForSyscall {
525 variant: <&'static str>::from(v),
526 syscall: "call",
527 }
528 .into()),
529 Err(e) => Err(e),
530 }),
531 );
532
533 Either::Left(SendRequestHandle {
534 invocation_id_future: invocation_id_fut.shared(),
535 ctx: self.clone(),
536 })
537 }
538
539 pub fn invocation_handle(&self, invocation_id: String) -> impl InvocationHandle {
540 InvocationIdBackedInvocationHandle {
541 ctx: self.clone(),
542 invocation_id,
543 }
544 }
545
546 pub fn awakeable<T: Deserialize>(
547 &self,
548 ) -> (
549 String,
550 impl DurableFuture<Output = Result<T, TerminalError>> + Send,
551 ) {
552 let mut inner_lock = must_lock!(self.inner);
553 let maybe_awakeable_id_and_handle = inner_lock.vm.sys_awakeable();
554 inner_lock.maybe_flip_span_replaying_field();
555
556 let (awakeable_id, handle) = match maybe_awakeable_id_and_handle {
557 Ok((s, handle)) => (s, handle),
558 Err(e) => {
559 inner_lock.fail(e.into());
560 return (
561 "invalid".to_owned(),
564 DurableFutureImpl::new(
565 self.clone(),
566 NotificationHandle::from(u32::MAX),
567 Either::Right(TrapFuture::default()),
568 ),
569 );
570 }
571 };
572 drop(inner_lock);
573
574 let poll_future = get_async_result(Arc::clone(&self.inner), handle).map(|res| match res {
575 Ok(Value::Success(mut s)) => Ok(Ok(
576 T::deserialize(&mut s).map_err(|e| Error::deserialization("awakeable", e))?
577 )),
578 Ok(Value::Failure(f)) => Ok(Err(f.into())),
579 Ok(v) => Err(ErrorInner::UnexpectedValueVariantForSyscall {
580 variant: <&'static str>::from(v),
581 syscall: "awakeable",
582 }
583 .into()),
584 Err(e) => Err(e),
585 });
586
587 (
588 awakeable_id,
589 DurableFutureImpl::new(self.clone(), handle, Either::Left(poll_future)),
590 )
591 }
592
593 pub fn resolve_awakeable<T: Serialize>(&self, id: &str, t: T) {
594 let mut inner_lock = must_lock!(self.inner);
595 match t.serialize() {
596 Ok(b) => {
597 let _ = inner_lock.vm.sys_complete_awakeable(
598 id.to_owned(),
599 NonEmptyValue::Success(b),
600 PayloadOptions::stable(),
601 );
602 }
603 Err(e) => {
604 inner_lock.fail(Error::serialization("resolve_awakeable", e));
605 }
606 }
607 }
608
609 pub fn reject_awakeable(&self, id: &str, failure: TerminalError) {
610 let _ = must_lock!(self.inner).vm.sys_complete_awakeable(
611 id.to_owned(),
612 NonEmptyValue::Failure(failure.into()),
613 PayloadOptions::stable(),
614 );
615 }
616
617 pub fn promise<T: Deserialize>(
618 &self,
619 name: &str,
620 ) -> impl DurableFuture<Output = Result<T, TerminalError>> + Send {
621 let mut inner_lock = must_lock!(self.inner);
622 let handle = unwrap_or_trap_durable_future!(
623 self,
624 inner_lock,
625 inner_lock.vm.sys_get_promise(name.to_owned())
626 );
627 inner_lock.maybe_flip_span_replaying_field();
628 drop(inner_lock);
629
630 let poll_future = get_async_result(Arc::clone(&self.inner), handle).map(|res| match res {
631 Ok(Value::Success(mut s)) => {
632 let t = T::deserialize(&mut s).map_err(|e| Error::deserialization("promise", e))?;
633 Ok(Ok(t))
634 }
635 Ok(Value::Failure(f)) => Ok(Err(f.into())),
636 Ok(v) => Err(ErrorInner::UnexpectedValueVariantForSyscall {
637 variant: <&'static str>::from(v),
638 syscall: "promise",
639 }
640 .into()),
641 Err(e) => Err(e),
642 });
643
644 DurableFutureImpl::new(self.clone(), handle, Either::Left(poll_future))
645 }
646
647 pub fn peek_promise<T: Deserialize>(
648 &self,
649 name: &str,
650 ) -> impl Future<Output = Result<Option<T>, TerminalError>> + Send {
651 let mut inner_lock = must_lock!(self.inner);
652 let handle = unwrap_or_trap!(inner_lock, inner_lock.vm.sys_peek_promise(name.to_owned()));
653 inner_lock.maybe_flip_span_replaying_field();
654 drop(inner_lock);
655
656 let poll_future = get_async_result(Arc::clone(&self.inner), handle).map(|res| match res {
657 Ok(Value::Void) => Ok(Ok(None)),
658 Ok(Value::Success(mut s)) => {
659 let t = T::deserialize(&mut s)
660 .map_err(|e| Error::deserialization("peek_promise", e))?;
661 Ok(Ok(Some(t)))
662 }
663 Ok(Value::Failure(f)) => Ok(Err(f.into())),
664 Ok(v) => Err(ErrorInner::UnexpectedValueVariantForSyscall {
665 variant: <&'static str>::from(v),
666 syscall: "peek_promise",
667 }
668 .into()),
669 Err(e) => Err(e),
670 });
671
672 Either::Left(InterceptErrorFuture::new(self.clone(), poll_future))
673 }
674
675 pub fn resolve_promise<T: Serialize>(&self, name: &str, t: T) {
676 let mut inner_lock = must_lock!(self.inner);
677 match t.serialize() {
678 Ok(b) => {
679 let _ = inner_lock.vm.sys_complete_promise(
680 name.to_owned(),
681 NonEmptyValue::Success(b),
682 PayloadOptions::stable(),
683 );
684 }
685 Err(e) => {
686 inner_lock.fail(
687 ErrorInner::Serialization {
688 syscall: "resolve_promise",
689 err: Box::new(e),
690 }
691 .into(),
692 );
693 }
694 }
695 }
696
697 pub fn reject_promise(&self, id: &str, failure: TerminalError) {
698 let _ = must_lock!(self.inner).vm.sys_complete_promise(
699 id.to_owned(),
700 NonEmptyValue::Failure(failure.into()),
701 PayloadOptions::stable(),
702 );
703 }
704
705 pub fn run<'a, Run, Fut, Out>(
706 &'a self,
707 run_closure: Run,
708 ) -> impl RunFuture<Result<Out, TerminalError>> + Send + 'a
709 where
710 Run: RunClosure<Fut = Fut, Output = Out> + Send + 'a,
711 Fut: Future<Output = HandlerResult<Out>> + Send + 'a,
712 Out: Serialize + Deserialize + 'static,
713 {
714 let this = Arc::clone(&self.inner);
715 InterceptErrorFuture::new(self.clone(), RunFutureImpl::new(this, run_closure))
716 }
717
718 pub fn handle_handler_result<T: Serialize>(&self, res: HandlerResult<T>) {
720 let mut inner_lock = must_lock!(self.inner);
721
722 let res_to_write = match res {
723 Ok(success) => match T::serialize(&success) {
724 Ok(t) => NonEmptyValue::Success(t),
725 Err(e) => {
726 inner_lock.fail(
727 ErrorInner::Serialization {
728 syscall: "output",
729 err: Box::new(e),
730 }
731 .into(),
732 );
733 return;
734 }
735 },
736 Err(e) => match e.0 {
737 HandlerErrorInner::Retryable(err) => {
738 inner_lock.fail(ErrorInner::HandlerResult { err }.into());
739 return;
740 }
741 HandlerErrorInner::Terminal(t) => NonEmptyValue::Failure(TerminalError(t).into()),
742 },
743 };
744
745 let _ = inner_lock
746 .vm
747 .sys_write_output(res_to_write, PayloadOptions::stable());
748 inner_lock.maybe_flip_span_replaying_field();
749 }
750
751 pub fn end(&self) {
752 let _ = must_lock!(self.inner).vm.sys_end();
753 }
754
755 pub(crate) fn consume_to_end(&self) {
756 let mut inner_lock = must_lock!(self.inner);
757
758 let out = inner_lock.vm.take_output();
759 if let TakeOutputResult::Buffer(b) = out
760 && !inner_lock.write.send(b)
761 {
762 }
764 }
765
766 pub(super) fn fail(&self, e: Error) {
767 must_lock!(self.inner).fail(e)
768 }
769}
770
pin_project! {
    // Future backing `ContextInternal::run`: executes a user closure as a journaled run step.
    struct RunFutureImpl<Run, Ret, RunFnFut> {
        // Name passed to `sys_run`; empty unless set through `RunFuture::name`.
        name: String,
        // Policy handed to `propose_run_completion`; `RetryPolicy::Infinite` unless overridden
        // through `RunFuture::retry_policy`.
        retry_policy: RetryPolicy,
        // Ties `Ret` to the struct without storing a value. The `fn() -> Ret` form keeps the
        // struct's auto traits independent of `Ret`.
        phantom_data: PhantomData<fn() -> Ret>,
        // Current step of the state machine (see `RunState`).
        #[pin]
        state: RunState<Run, RunFnFut, Ret>,
    }
}
780
pin_project! {
    // States of `RunFutureImpl`. The `Option` fields let `poll` move values out of the pinned
    // state with `take()` during a transition.
    #[project = RunStateProj]
    enum RunState<Run, RunFnFut, Ret> {
        // Not polled yet: holds the context and the closure to execute.
        New {
            ctx: Option<Arc<Mutex<ContextInternalInner>>>,
            closure: Option<Run>,
        },
        // The VM asked for the closure to be executed. `closure_fut` is the in-flight user
        // future, and `start_time` is used to report the attempt duration on retryable failures.
        ClosureRunning {
            ctx: Option<Arc<Mutex<ContextInternalInner>>>,
            handle: NotificationHandle,
            start_time: Instant,
            #[pin]
            closure_fut: RunFnFut,
        },
        // Waiting for the run's result, either from the VM or a locally built cancellation
        // result.
        WaitingResultFut {
            result_fut: BoxFuture<'static, Result<Result<Ret, TerminalError>, Error>>
        }
    }
}
800
801impl<Run, Ret, RunFnFut> RunFutureImpl<Run, Ret, RunFnFut> {
802 fn new(ctx: Arc<Mutex<ContextInternalInner>>, closure: Run) -> Self {
803 Self {
804 name: "".to_string(),
805 retry_policy: RetryPolicy::Infinite,
806 phantom_data: PhantomData,
807 state: RunState::New {
808 ctx: Some(ctx),
809 closure: Some(closure),
810 },
811 }
812 }
813
814 fn boxed_result_fut(
815 ctx: Arc<Mutex<ContextInternalInner>>,
816 handle: NotificationHandle,
817 ) -> BoxFuture<'static, Result<Result<Ret, TerminalError>, Error>>
818 where
819 Ret: Deserialize,
820 {
821 get_async_result(Arc::clone(&ctx), handle)
822 .map(|res| match res {
823 Ok(Value::Success(mut s)) => {
824 let t =
825 Ret::deserialize(&mut s).map_err(|e| Error::deserialization("run", e))?;
826 Ok(Ok(t))
827 }
828 Ok(Value::Failure(f)) => Ok(Err(f.into())),
829 Ok(v) => Err(ErrorInner::UnexpectedValueVariantForSyscall {
830 variant: <&'static str>::from(v),
831 syscall: "run",
832 }
833 .into()),
834 Err(e) => Err(e),
835 })
836 .boxed()
837 }
838}
839
840impl<Run, Ret, RunFnFut> RunFuture<Result<Result<Ret, TerminalError>, Error>>
841 for RunFutureImpl<Run, Ret, RunFnFut>
842where
843 Run: RunClosure<Fut = RunFnFut, Output = Ret> + Send,
844 Ret: Serialize + Deserialize,
845 RunFnFut: Future<Output = HandlerResult<Ret>> + Send,
846{
847 fn retry_policy(mut self, retry_policy: RunRetryPolicy) -> Self {
848 self.retry_policy = RetryPolicy::Exponential {
849 initial_interval: retry_policy.initial_delay,
850 factor: retry_policy.factor,
851 max_interval: retry_policy.max_delay,
852 max_attempts: retry_policy.max_attempts,
853 max_duration: retry_policy.max_duration,
854 };
855 self
856 }
857
858 fn name(mut self, name: impl Into<String>) -> Self {
859 self.name = name.into();
860 self
861 }
862}
863
impl<Run, Ret, RunFnFut> Future for RunFutureImpl<Run, Ret, RunFnFut>
where
    Run: RunClosure<Fut = RunFnFut, Output = Ret> + Send,
    Ret: Serialize + Deserialize,
    RunFnFut: Future<Output = HandlerResult<Ret>> + Send,
{
    type Output = Result<Result<Ret, TerminalError>, Error>;

    /// Drives the run state machine.
    ///
    /// `New` registers the run with the VM and asks it to make progress. That leads to one of:
    /// - executing the closure (`ClosureRunning`);
    /// - a 409 "cancelled" result;
    /// - waiting on the journaled result.
    ///
    /// `ClosureRunning` polls the user future, proposes its outcome to the VM and then waits for
    /// the result. `WaitingResultFut` forwards to the result future. Errors from `sys_run` or from
    /// serializing the closure output are returned immediately through `?`.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let mut this = self.project();

        // Each transition loops back so the new state is polled right away.
        loop {
            match this.state.as_mut().project() {
                RunStateProj::New { ctx, closure, .. } => {
                    let ctx = ctx
                        .take()
                        .expect("Future should not be polled after returning Poll::Ready");
                    let closure = closure
                        .take()
                        .expect("Future should not be polled after returning Poll::Ready");
                    let mut inner_ctx = must_lock!(ctx);

                    let handle = inner_ctx
                        .vm
                        .sys_run(this.name.to_owned())
                        .map_err(ErrorInner::from)?;

                    match inner_ctx.vm.do_progress(vec![handle]) {
                        // The VM asks to execute the closure for this very handle.
                        Ok(DoProgressResponse::ExecuteRun(handle_to_run)) => {
                            assert_eq!(handle, handle_to_run);

                            // Release the lock before moving `ctx` into the next state.
                            drop(inner_ctx);
                            this.state.set(RunState::ClosureRunning {
                                ctx: Some(ctx),
                                handle,
                                start_time: Instant::now(),
                                closure_fut: closure.run(),
                            });
                        }
                        // Cancellation: resolve to a 409 "cancelled" terminal error without
                        // running the closure.
                        Ok(DoProgressResponse::CancelSignalReceived) => {
                            drop(inner_ctx);
                            this.state.set(RunState::WaitingResultFut {
                                result_fut: async {
                                    Ok(Err(TerminalError::from(TerminalFailure {
                                        code: 409,
                                        message: "cancelled".to_string(),
                                        metadata: vec![],
                                    })))
                                }
                                .boxed(),
                            })
                        }
                        // Any other response (including an `Err` from `do_progress`, which is
                        // not inspected here) moves on to waiting for the run's notification.
                        _ => {
                            drop(inner_ctx);
                            this.state.set(RunState::WaitingResultFut {
                                result_fut: Self::boxed_result_fut(Arc::clone(&ctx), handle),
                            })
                        }
                    }
                }
                RunStateProj::ClosureRunning {
                    ctx,
                    handle,
                    start_time,
                    closure_fut,
                } => {
                    // Map the closure outcome to what gets proposed to the VM.
                    let res = match ready!(closure_fut.poll(cx)) {
                        Ok(t) => RunExitResult::Success(Ret::serialize(&t).map_err(|e| {
                            ErrorInner::Serialization {
                                syscall: "run",
                                err: Box::new(e),
                            }
                        })?),
                        Err(e) => match e.0 {
                            HandlerErrorInner::Retryable(err) => RunExitResult::RetryableFailure {
                                attempt_duration: start_time.elapsed(),
                                error: CoreError::new(500u16, err.to_string()),
                            },
                            HandlerErrorInner::Terminal(t) => {
                                RunExitResult::TerminalFailure(TerminalError(t).into())
                            }
                        },
                    };

                    let ctx = ctx
                        .take()
                        .expect("Future should not be polled after returning Poll::Ready");
                    let handle = *handle;

                    // The retry policy is moved out; it is only needed for this one proposal.
                    let _ = {
                        must_lock!(ctx).vm.propose_run_completion(
                            handle,
                            res,
                            mem::take(this.retry_policy),
                        )
                    };

                    this.state.set(RunState::WaitingResultFut {
                        result_fut: Self::boxed_result_fut(Arc::clone(&ctx), handle),
                    });
                }
                RunStateProj::WaitingResultFut { result_fut } => return result_fut.poll_unpin(cx),
            }
        }
    }
}
977
pin_project! {
    // Future returned by `ContextInternal::call`. It resolves to the call's response and also
    // exposes the callee invocation id through `InvocationHandle`.
    struct CallFutureImpl<InvIdFut: Future, ResultFut> {
        // Shared so that `invocation_id()` and `cancel()` can each await their own clone.
        #[pin]
        invocation_id_future: Shared<InvIdFut>,
        // Resolves to the call result; its errors are reported to `ctx` in `poll`.
        #[pin]
        result_future: ResultFut,
        // Notification handle of the call result, returned by `SealedDurableFuture::handle`.
        call_notification_handle: NotificationHandle,
        // Context used to report errors and to issue cancellations.
        ctx: ContextInternal,
    }
}
988
989impl<InvIdFut, ResultFut, Res> Future for CallFutureImpl<InvIdFut, ResultFut>
990where
991 InvIdFut: Future<Output = Result<String, TerminalError>> + Send,
992 ResultFut: Future<Output = Result<Result<Res, TerminalError>, Error>> + Send,
993{
994 type Output = Result<Res, TerminalError>;
995
996 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
997 let this = self.project();
998 let result = ready!(this.result_future.poll(cx));
999
1000 match result {
1001 Ok(r) => Poll::Ready(r),
1002 Err(e) => {
1003 this.ctx.fail(e);
1004
1005 cx.waker().wake_by_ref();
1008 Poll::Pending
1009 }
1010 }
1011 }
1012}
1013
1014impl<InvIdFut, ResultFut> InvocationHandle for CallFutureImpl<InvIdFut, ResultFut>
1015where
1016 InvIdFut: Future<Output = Result<String, TerminalError>> + Send,
1017{
1018 fn invocation_id(&self) -> impl Future<Output = Result<String, TerminalError>> + Send {
1019 Shared::clone(&self.invocation_id_future)
1020 }
1021
1022 fn cancel(&self) -> impl Future<Output = Result<(), TerminalError>> + Send {
1023 let cloned_invocation_id_fut = Shared::clone(&self.invocation_id_future);
1024 let cloned_ctx = Arc::clone(&self.ctx.inner);
1025 async move {
1026 let inv_id = cloned_invocation_id_fut.await?;
1027 let mut inner_lock = must_lock!(cloned_ctx);
1028 let _ = inner_lock.vm.sys_cancel_invocation(inv_id);
1029 inner_lock.maybe_flip_span_replaying_field();
1030 drop(inner_lock);
1031 Ok(())
1032 }
1033 }
1034}
1035
// `CallFutureImpl` is a `CallFuture` whose response type is the `Res` produced by its result
// future.
impl<InvIdFut, ResultFut, Res> CallFuture for CallFutureImpl<InvIdFut, ResultFut>
where
    InvIdFut: Future<Output = Result<String, TerminalError>> + Send,
    ResultFut: Future<Output = Result<Result<Res, TerminalError>, Error>> + Send,
{
    type Response = Res;
}
1043
1044impl<InvIdFut, ResultFut> crate::context::macro_support::SealedDurableFuture
1045 for CallFutureImpl<InvIdFut, ResultFut>
1046where
1047 InvIdFut: Future,
1048{
1049 fn inner_context(&self) -> ContextInternal {
1050 self.ctx.clone()
1051 }
1052
1053 fn handle(&self) -> NotificationHandle {
1054 self.call_notification_handle
1055 }
1056}
1057
// Marker impl: with `Future` and `SealedDurableFuture` implemented above, `CallFutureImpl` can be
// used wherever a `DurableFuture` is expected.
impl<InvIdFut, ResultFut, Res> DurableFuture for CallFutureImpl<InvIdFut, ResultFut>
where
    InvIdFut: Future<Output = Result<String, TerminalError>> + Send,
    ResultFut: Future<Output = Result<Result<Res, TerminalError>, Error>> + Send,
{
}
1064
/// Handle returned by `ContextInternal::send` for a one-way request.
struct SendRequestHandle<InvIdFut: Future> {
    // Shared future of the callee invocation id, cloned by `invocation_id()` and `cancel()`.
    invocation_id_future: Shared<InvIdFut>,
    // Context used to issue cancellations.
    ctx: ContextInternal,
}
1069
1070impl<InvIdFut: Future<Output = Result<String, TerminalError>> + Send> InvocationHandle
1071 for SendRequestHandle<InvIdFut>
1072{
1073 fn invocation_id(&self) -> impl Future<Output = Result<String, TerminalError>> + Send {
1074 Shared::clone(&self.invocation_id_future)
1075 }
1076
1077 fn cancel(&self) -> impl Future<Output = Result<(), TerminalError>> + Send {
1078 let cloned_invocation_id_fut = Shared::clone(&self.invocation_id_future);
1079 let cloned_ctx = Arc::clone(&self.ctx.inner);
1080 async move {
1081 let inv_id = cloned_invocation_id_fut.await?;
1082 let mut inner_lock = must_lock!(cloned_ctx);
1083 let _ = inner_lock.vm.sys_cancel_invocation(inv_id);
1084 inner_lock.maybe_flip_span_replaying_field();
1085 drop(inner_lock);
1086 Ok(())
1087 }
1088 }
1089}
1090
/// Invocation handle for an invocation whose id is already known, as created by
/// `ContextInternal::invocation_handle`.
struct InvocationIdBackedInvocationHandle {
    // Context used to issue cancellations.
    ctx: ContextInternal,
    // Id of the target invocation.
    invocation_id: String,
}
1095
1096impl InvocationHandle for InvocationIdBackedInvocationHandle {
1097 fn invocation_id(&self) -> impl Future<Output = Result<String, TerminalError>> + Send {
1098 ready(Ok(self.invocation_id.clone()))
1099 }
1100
1101 fn cancel(&self) -> impl Future<Output = Result<(), TerminalError>> + Send {
1102 let mut inner_lock = must_lock!(self.ctx.inner);
1103 let _ = inner_lock
1104 .vm
1105 .sys_cancel_invocation(self.invocation_id.clone());
1106 ready(Ok(()))
1107 }
1108}
1109
1110impl<A, B> InvocationHandle for Either<A, B>
1111where
1112 A: InvocationHandle,
1113 B: InvocationHandle,
1114{
1115 fn invocation_id(&self) -> impl Future<Output = Result<String, TerminalError>> + Send {
1116 match self {
1117 Either::Left(l) => Either::Left(l.invocation_id()),
1118 Either::Right(r) => Either::Right(r.invocation_id()),
1119 }
1120 }
1121
1122 fn cancel(&self) -> impl Future<Output = Result<(), TerminalError>> + Send {
1123 match self {
1124 Either::Left(l) => Either::Left(l.cancel()),
1125 Either::Right(r) => Either::Right(r.cancel()),
1126 }
1127 }
1128}
1129
1130impl Error {
1131 fn serialization<E: std::error::Error + Send + Sync + 'static>(
1132 syscall: &'static str,
1133 e: E,
1134 ) -> Self {
1135 ErrorInner::Serialization {
1136 syscall,
1137 err: Box::new(e),
1138 }
1139 .into()
1140 }
1141
1142 fn deserialization<E: std::error::Error + Send + Sync + 'static>(
1143 syscall: &'static str,
1144 e: E,
1145 ) -> Self {
1146 ErrorInner::Deserialization {
1147 syscall,
1148 err: Box::new(e),
1149 }
1150 .into()
1151 }
1152}
1153
1154fn get_async_result(
1155 ctx: Arc<Mutex<ContextInternalInner>>,
1156 handle: NotificationHandle,
1157) -> impl Future<Output = Result<Value, Error>> + Send {
1158 VmAsyncResultPollFuture::new(ctx, handle).map_err(Error::from)
1159}