1use crate::context::{
2 CallFuture, DurableFuture, InvocationHandle, Request, RequestTarget, RunClosure, RunFuture,
3 RunRetryPolicy,
4};
5use crate::endpoint::futures::async_result_poll::VmAsyncResultPollFuture;
6use crate::endpoint::futures::durable_future_impl::DurableFutureImpl;
7use crate::endpoint::futures::intercept_error::InterceptErrorFuture;
8use crate::endpoint::futures::select_poll::VmSelectAsyncResultPollFuture;
9use crate::endpoint::futures::trap::TrapFuture;
10use crate::endpoint::handler_state::HandlerStateNotifier;
11use crate::endpoint::{Error, ErrorInner, InputReceiver, OutputSender};
12use crate::errors::{HandlerErrorInner, HandlerResult, TerminalError};
13use crate::serde::{Deserialize, Serialize};
14use futures::future::{BoxFuture, Either, Shared};
15use futures::{FutureExt, TryFutureExt};
16use pin_project_lite::pin_project;
17use restate_sdk_shared_core::{
18 CoreVM, DoProgressResponse, Error as CoreError, Header, NonEmptyValue, NotificationHandle,
19 PayloadOptions, RetryPolicy, RunExitResult, TakeOutputResult, Target, TerminalFailure, VM,
20 Value,
21};
22use std::borrow::Cow;
23use std::collections::HashMap;
24use std::future::{Future, poll_fn, ready};
25use std::marker::PhantomData;
26use std::mem;
27use std::pin::Pin;
28use std::sync::{Arc, Mutex};
29use std::task::{Context, Poll, ready};
30use std::time::{Duration, Instant, SystemTime};
31
/// State shared (behind `Arc<Mutex<..>>`) between [`ContextInternal`] and the futures it
/// creates.
pub struct ContextInternalInner {
    /// Protocol state machine from `restate_sdk_shared_core`; every syscall goes through it.
    pub(crate) vm: CoreVM,
    /// Incoming message stream; polled by `ContextInternal::drain_input`.
    pub(crate) read: InputReceiver,
    /// Outgoing sink; receives the VM's buffered output in `ContextInternal::consume_to_end`.
    pub(crate) write: OutputSender,
    /// Told about handler failures through `mark_error`.
    pub(super) handler_state: HandlerStateNotifier,

    /// Last value recorded on the `restate.sdk.is_replaying` span field, so the field is only
    /// re-recorded when the VM's replay state actually changes.
    pub(super) span_replaying_field_state: bool,
}
42
43impl ContextInternalInner {
44 fn new(
45 vm: CoreVM,
46 read: InputReceiver,
47 write: OutputSender,
48 handler_state: HandlerStateNotifier,
49 ) -> Self {
50 Self {
51 vm,
52 read,
53 write,
54 handler_state,
55 span_replaying_field_state: false,
56 }
57 }
58
59 pub(super) fn fail(&mut self, e: Error) {
60 self.maybe_flip_span_replaying_field();
61 self.vm.notify_error(
62 CoreError::new(500u16, e.0.to_string())
63 .with_stacktrace(Cow::<str>::Owned(format!("{:#}", e.0))),
64 None,
65 );
66 self.handler_state.mark_error(e);
67 }
68
69 pub(super) fn maybe_flip_span_replaying_field(&mut self) {
70 if !self.span_replaying_field_state && self.vm.is_replaying() {
71 tracing::Span::current().record("restate.sdk.is_replaying", true);
72 self.span_replaying_field_state = true;
73 } else if self.span_replaying_field_state && !self.vm.is_replaying() {
74 tracing::Span::current().record("restate.sdk.is_replaying", false);
75 self.span_replaying_field_state = false;
76 }
77 }
78}
79
80#[allow(unused)]
81const fn is_send_sync<T: Send + Sync>() {}
82const _: () = is_send_sync::<ContextInternal>();
83
// Locks `$mutex` without blocking (`try_lock`) and panics with an explanatory message when the
// lock cannot be taken. The message attributes the failure to awaiting two context futures at
// once, or to using the context while a future is being awaited.
// NOTE(review): `try_lock` also fails on a poisoned mutex, which surfaces with this same message.
macro_rules! must_lock {
    ($mutex:expr) => {
        $mutex.try_lock().expect("You're trying to await two futures at the same time and/or trying to perform some operation on the restate context while awaiting a future. This is not supported!")
    };
}
89
// Unwraps `$res`. On `Err`, records the error through `$inner_lock.fail(..)` and returns early
// from the *enclosing function* with `Either::Right(TrapFuture::default())`, so the enclosing
// function must return an `Either<_, TrapFuture<_>>`-compatible type.
macro_rules! unwrap_or_trap {
    ($inner_lock:expr, $res:expr) => {
        match $res {
            Ok(t) => t,
            Err(e) => {
                $inner_lock.fail(e.into());
                return Either::Right(TrapFuture::default());
            }
        }
    };
}
101
// Like `unwrap_or_trap!`, but for functions returning a `DurableFutureImpl`. On `Err` it records
// the error through `$inner_lock.fail(..)` and returns early with a `DurableFutureImpl` that
// wraps a `TrapFuture`. That future is bound to `$ctx` and the placeholder handle `u32::MAX`,
// since no real notification handle exists.
macro_rules! unwrap_or_trap_durable_future {
    ($ctx:expr, $inner_lock:expr, $res:expr) => {
        match $res {
            Ok(t) => t,
            Err(e) => {
                $inner_lock.fail(e.into());
                return DurableFutureImpl::new(
                    $ctx.clone(),
                    NotificationHandle::from(u32::MAX),
                    Either::Right(TrapFuture::default()),
                );
            }
        }
    };
}
117
/// Metadata returned next to the decoded input payload by `ContextInternal::input`.
#[derive(Debug, Eq, PartialEq)]
pub struct InputMetadata {
    /// Invocation id as reported by the VM's `sys_input`.
    pub invocation_id: String,
    /// Random seed as reported by the VM's `sys_input`.
    pub random_seed: u64,
    /// Key as reported by `sys_input` (presumably the object/workflow key — verify for services).
    pub key: String,
    /// Request headers. They are built through a `HashMap`, so a repeated header name keeps
    /// only one value.
    pub headers: http::HeaderMap<String>,
}
125
126impl From<RequestTarget> for Target {
127 fn from(value: RequestTarget) -> Self {
128 match value {
129 RequestTarget::Service { name, handler } => Target {
130 service: name,
131 handler,
132 key: None,
133 idempotency_key: None,
134 headers: vec![],
135 },
136 RequestTarget::Object { name, key, handler } => Target {
137 service: name,
138 handler,
139 key: Some(key),
140 idempotency_key: None,
141 headers: vec![],
142 },
143 RequestTarget::Workflow { name, key, handler } => Target {
144 service: name,
145 handler,
146 key: Some(key),
147 idempotency_key: None,
148 headers: vec![],
149 },
150 }
151 }
152}
153
/// Cloneable handle bundling the service/handler names with the shared, mutex-protected
/// [`ContextInternalInner`]. Clones share the same inner state through the `Arc`.
#[derive(Clone)]
pub struct ContextInternal {
    svc_name: String,
    handler_name: String,
    inner: Arc<Mutex<ContextInternalInner>>,
}
163
164impl ContextInternal {
165 pub(super) fn new(
166 vm: CoreVM,
167 svc_name: String,
168 handler_name: String,
169 read: InputReceiver,
170 write: OutputSender,
171 handler_state: HandlerStateNotifier,
172 ) -> Self {
173 Self {
174 svc_name,
175 handler_name,
176 inner: Arc::new(Mutex::new(ContextInternalInner::new(
177 vm,
178 read,
179 write,
180 handler_state,
181 ))),
182 }
183 }
184
185 pub fn service_name(&self) -> &str {
186 &self.svc_name
187 }
188
189 pub fn handler_name(&self) -> &str {
190 &self.handler_name
191 }
192
193 pub fn input<T: Deserialize>(&self) -> impl Future<Output = (T, InputMetadata)> {
194 let mut inner_lock = must_lock!(self.inner);
195 let input_result =
196 inner_lock
197 .vm
198 .sys_input()
199 .map_err(ErrorInner::VM)
200 .map(|mut raw_input| {
201 let headers = http::HeaderMap::<String>::try_from(
202 &raw_input
203 .headers
204 .into_iter()
205 .map(|h| (h.key.to_string(), h.value.to_string()))
206 .collect::<HashMap<String, String>>(),
207 )
208 .map_err(|e| {
209 TerminalError::new_with_code(400, format!("Cannot decode headers: {e:?}"))
210 })?;
211
212 Ok::<_, TerminalError>((
213 T::deserialize(&mut (raw_input.input)).map_err(|e| {
214 TerminalError::new_with_code(
215 400,
216 format!("Cannot decode input payload: {e:?}"),
217 )
218 })?,
219 InputMetadata {
220 invocation_id: raw_input.invocation_id,
221 random_seed: raw_input.random_seed,
222 key: raw_input.key,
223 headers,
224 },
225 ))
226 });
227 inner_lock.maybe_flip_span_replaying_field();
228
229 match input_result {
230 Ok(Ok(i)) => {
231 drop(inner_lock);
232 return Either::Left(ready(i));
233 }
234 Ok(Err(err)) => {
235 let error_inner = ErrorInner::Deserialization {
236 syscall: "input",
237 err: err.0.clone().into(),
238 };
239 let _ = inner_lock
240 .vm
241 .sys_write_output(NonEmptyValue::Failure(err.into()), PayloadOptions::stable());
242 let _ = inner_lock.vm.sys_end();
243 inner_lock.handler_state.mark_error(error_inner.into());
245 drop(inner_lock);
246 }
247 Err(e) => {
248 inner_lock.fail(e.into());
249 drop(inner_lock);
250 }
251 }
252 Either::Right(TrapFuture::default())
253 }
254
255 pub fn get<T: Deserialize>(
256 &self,
257 key: &str,
258 ) -> impl Future<Output = Result<Option<T>, TerminalError>> + Send {
259 let mut inner_lock = must_lock!(self.inner);
260 let handle = unwrap_or_trap!(
261 inner_lock,
262 inner_lock
263 .vm
264 .sys_state_get(key.to_owned(), PayloadOptions::stable())
265 );
266 inner_lock.maybe_flip_span_replaying_field();
267
268 let poll_future = get_async_result(Arc::clone(&self.inner), handle).map(|res| match res {
269 Ok(Value::Void) => Ok(Ok(None)),
270 Ok(Value::Success(mut s)) => {
271 let t =
272 T::deserialize(&mut s).map_err(|e| Error::deserialization("get_state", e))?;
273 Ok(Ok(Some(t)))
274 }
275 Ok(Value::Failure(f)) => Ok(Err(f.into())),
276 Ok(v) => Err(ErrorInner::UnexpectedValueVariantForSyscall {
277 variant: <&'static str>::from(v),
278 syscall: "get_state",
279 }
280 .into()),
281 Err(e) => Err(e),
282 });
283
284 Either::Left(InterceptErrorFuture::new(self.clone(), poll_future))
285 }
286
287 pub fn get_keys(&self) -> impl Future<Output = Result<Vec<String>, TerminalError>> + Send {
288 let mut inner_lock = must_lock!(self.inner);
289 let handle = unwrap_or_trap!(inner_lock, inner_lock.vm.sys_state_get_keys());
290 inner_lock.maybe_flip_span_replaying_field();
291
292 let poll_future = get_async_result(Arc::clone(&self.inner), handle).map(|res| match res {
293 Ok(Value::Failure(f)) => Ok(Err(f.into())),
294 Ok(Value::StateKeys(s)) => Ok(Ok(s)),
295 Ok(v) => Err(ErrorInner::UnexpectedValueVariantForSyscall {
296 variant: <&'static str>::from(v),
297 syscall: "get_keys",
298 }
299 .into()),
300 Err(e) => Err(e),
301 });
302
303 Either::Left(InterceptErrorFuture::new(self.clone(), poll_future))
304 }
305
306 pub fn set<T: Serialize>(&self, key: &str, t: T) {
307 let mut inner_lock = must_lock!(self.inner);
308 match t.serialize() {
309 Ok(b) => {
310 let _ = inner_lock
311 .vm
312 .sys_state_set(key.to_owned(), b, PayloadOptions::stable());
313 inner_lock.maybe_flip_span_replaying_field();
314 }
315 Err(e) => {
316 inner_lock.fail(Error::serialization("set_state", e));
317 }
318 }
319 }
320
321 pub fn clear(&self, key: &str) {
322 let mut inner_lock = must_lock!(self.inner);
323 let _ = inner_lock.vm.sys_state_clear(key.to_string());
324 inner_lock.maybe_flip_span_replaying_field();
325 }
326
327 pub fn clear_all(&self) {
328 let mut inner_lock = must_lock!(self.inner);
329 let _ = inner_lock.vm.sys_state_clear_all();
330 inner_lock.maybe_flip_span_replaying_field();
331 }
332
333 pub fn select(
334 &self,
335 handles: Vec<NotificationHandle>,
336 ) -> impl Future<Output = Result<usize, TerminalError>> + Send {
337 InterceptErrorFuture::new(
338 self.clone(),
339 VmSelectAsyncResultPollFuture::new(self.inner.clone(), handles).map_err(Error::from),
340 )
341 }
342
343 pub fn sleep(
344 &self,
345 sleep_duration: Duration,
346 ) -> impl DurableFuture<Output = Result<(), TerminalError>> + Send {
347 let now = SystemTime::now()
348 .duration_since(SystemTime::UNIX_EPOCH)
349 .expect("Duration since unix epoch cannot fail");
350 let mut inner_lock = must_lock!(self.inner);
351 let handle = unwrap_or_trap_durable_future!(
352 self,
353 inner_lock,
354 inner_lock
355 .vm
356 .sys_sleep(String::default(), now + sleep_duration, Some(now))
357 );
358 inner_lock.maybe_flip_span_replaying_field();
359
360 let poll_future = get_async_result(Arc::clone(&self.inner), handle).map(|res| match res {
361 Ok(Value::Void) => Ok(Ok(())),
362 Ok(Value::Failure(f)) => Ok(Err(f.into())),
363 Ok(v) => Err(ErrorInner::UnexpectedValueVariantForSyscall {
364 variant: <&'static str>::from(v),
365 syscall: "sleep",
366 }
367 .into()),
368 Err(e) => Err(e),
369 });
370
371 DurableFutureImpl::new(self.clone(), handle, Either::Left(poll_future))
372 }
373
374 pub fn request<Req, Res>(
375 &self,
376 request_target: RequestTarget,
377 req: Req,
378 ) -> Request<'_, Req, Res> {
379 Request::new(self, request_target, req)
380 }
381
382 pub fn call<Req: Serialize, Res: Deserialize>(
383 &self,
384 request_target: RequestTarget,
385 idempotency_key: Option<String>,
386 headers: Vec<(String, String)>,
387 req: Req,
388 ) -> impl CallFuture<Response = Res> + Send {
389 let mut inner_lock = must_lock!(self.inner);
390
391 let mut target: Target = request_target.into();
392 target.idempotency_key = idempotency_key;
393 target.headers = headers
394 .into_iter()
395 .map(|(k, v)| Header {
396 key: k.into(),
397 value: v.into(),
398 })
399 .collect();
400 let call_result = Req::serialize(&req)
401 .map_err(|e| Error::serialization("call", e))
402 .and_then(|input| {
403 inner_lock
404 .vm
405 .sys_call(target, input, None, PayloadOptions::stable())
406 .map_err(Into::into)
407 });
408
409 let call_handle = match call_result {
410 Ok(t) => t,
411 Err(e) => {
412 inner_lock.fail(e);
413 return CallFutureImpl {
414 invocation_id_future: Either::Right(TrapFuture::default()).shared(),
415 result_future: Either::Right(TrapFuture::default()),
416 call_notification_handle: NotificationHandle::from(u32::MAX),
417 ctx: self.clone(),
418 };
419 }
420 };
421 inner_lock.maybe_flip_span_replaying_field();
422 drop(inner_lock);
423
424 let invocation_id_fut = InterceptErrorFuture::new(
426 self.clone(),
427 get_async_result(
428 Arc::clone(&self.inner),
429 call_handle.invocation_id_notification_handle,
430 )
431 .map(|res| match res {
432 Ok(Value::Failure(f)) => Ok(Err(f.into())),
433 Ok(Value::InvocationId(s)) => Ok(Ok(s)),
434 Ok(v) => Err(ErrorInner::UnexpectedValueVariantForSyscall {
435 variant: <&'static str>::from(v),
436 syscall: "call",
437 }
438 .into()),
439 Err(e) => Err(e),
440 }),
441 );
442 let result_future = get_async_result(
443 Arc::clone(&self.inner),
444 call_handle.call_notification_handle,
445 )
446 .map(|res| match res {
447 Ok(Value::Success(mut s)) => Ok(Ok(
448 Res::deserialize(&mut s).map_err(|e| Error::deserialization("call", e))?
449 )),
450 Ok(Value::Failure(f)) => Ok(Err(TerminalError::from(f))),
451 Ok(v) => Err(ErrorInner::UnexpectedValueVariantForSyscall {
452 variant: <&'static str>::from(v),
453 syscall: "call",
454 }
455 .into()),
456 Err(e) => Err(e),
457 });
458
459 CallFutureImpl {
460 invocation_id_future: Either::Left(invocation_id_fut).shared(),
461 result_future: Either::Left(result_future),
462 call_notification_handle: call_handle.call_notification_handle,
463 ctx: self.clone(),
464 }
465 }
466
467 pub fn send<Req: Serialize>(
468 &self,
469 request_target: RequestTarget,
470 idempotency_key: Option<String>,
471 headers: Vec<(String, String)>,
472 req: Req,
473 delay: Option<Duration>,
474 ) -> impl InvocationHandle {
475 let mut inner_lock = must_lock!(self.inner);
476
477 let mut target: Target = request_target.into();
478 target.idempotency_key = idempotency_key;
479 target.headers = headers
480 .into_iter()
481 .map(|(k, v)| Header {
482 key: k.into(),
483 value: v.into(),
484 })
485 .collect();
486 let input = match Req::serialize(&req) {
487 Ok(b) => b,
488 Err(e) => {
489 inner_lock.fail(Error::serialization("call", e));
490 return Either::Right(TrapFuture::<()>::default());
491 }
492 };
493
494 let send_handle = match inner_lock.vm.sys_send(
495 target,
496 input,
497 delay.map(|delay| {
498 SystemTime::now()
499 .duration_since(SystemTime::UNIX_EPOCH)
500 .expect("Duration since unix epoch cannot fail")
501 + delay
502 }),
503 None,
504 PayloadOptions::stable(),
505 ) {
506 Ok(h) => h,
507 Err(e) => {
508 inner_lock.fail(e.into());
509 return Either::Right(TrapFuture::<()>::default());
510 }
511 };
512 inner_lock.maybe_flip_span_replaying_field();
513 drop(inner_lock);
514
515 let invocation_id_fut = InterceptErrorFuture::new(
516 self.clone(),
517 get_async_result(
518 Arc::clone(&self.inner),
519 send_handle.invocation_id_notification_handle,
520 )
521 .map(|res| match res {
522 Ok(Value::Failure(f)) => Ok(Err(f.into())),
523 Ok(Value::InvocationId(s)) => Ok(Ok(s)),
524 Ok(v) => Err(ErrorInner::UnexpectedValueVariantForSyscall {
525 variant: <&'static str>::from(v),
526 syscall: "call",
527 }
528 .into()),
529 Err(e) => Err(e),
530 }),
531 );
532
533 Either::Left(SendRequestHandle {
534 invocation_id_future: invocation_id_fut.shared(),
535 ctx: self.clone(),
536 })
537 }
538
539 pub fn invocation_handle(&self, invocation_id: String) -> impl InvocationHandle {
540 InvocationIdBackedInvocationHandle {
541 ctx: self.clone(),
542 invocation_id,
543 }
544 }
545
546 pub fn awakeable<T: Deserialize>(
547 &self,
548 ) -> (
549 String,
550 impl DurableFuture<Output = Result<T, TerminalError>> + Send,
551 ) {
552 let mut inner_lock = must_lock!(self.inner);
553 let maybe_awakeable_id_and_handle = inner_lock.vm.sys_awakeable();
554 inner_lock.maybe_flip_span_replaying_field();
555
556 let (awakeable_id, handle) = match maybe_awakeable_id_and_handle {
557 Ok((s, handle)) => (s, handle),
558 Err(e) => {
559 inner_lock.fail(e.into());
560 return (
561 "invalid".to_owned(),
564 DurableFutureImpl::new(
565 self.clone(),
566 NotificationHandle::from(u32::MAX),
567 Either::Right(TrapFuture::default()),
568 ),
569 );
570 }
571 };
572 drop(inner_lock);
573
574 let poll_future = get_async_result(Arc::clone(&self.inner), handle).map(|res| match res {
575 Ok(Value::Success(mut s)) => Ok(Ok(
576 T::deserialize(&mut s).map_err(|e| Error::deserialization("awakeable", e))?
577 )),
578 Ok(Value::Failure(f)) => Ok(Err(f.into())),
579 Ok(v) => Err(ErrorInner::UnexpectedValueVariantForSyscall {
580 variant: <&'static str>::from(v),
581 syscall: "awakeable",
582 }
583 .into()),
584 Err(e) => Err(e),
585 });
586
587 (
588 awakeable_id,
589 DurableFutureImpl::new(self.clone(), handle, Either::Left(poll_future)),
590 )
591 }
592
593 pub fn resolve_awakeable<T: Serialize>(&self, id: &str, t: T) {
594 let mut inner_lock = must_lock!(self.inner);
595 match t.serialize() {
596 Ok(b) => {
597 let _ = inner_lock.vm.sys_complete_awakeable(
598 id.to_owned(),
599 NonEmptyValue::Success(b),
600 PayloadOptions::stable(),
601 );
602 }
603 Err(e) => {
604 inner_lock.fail(Error::serialization("resolve_awakeable", e));
605 }
606 }
607 }
608
609 pub fn reject_awakeable(&self, id: &str, failure: TerminalError) {
610 let _ = must_lock!(self.inner).vm.sys_complete_awakeable(
611 id.to_owned(),
612 NonEmptyValue::Failure(failure.into()),
613 PayloadOptions::stable(),
614 );
615 }
616
617 pub fn promise<T: Deserialize>(
618 &self,
619 name: &str,
620 ) -> impl DurableFuture<Output = Result<T, TerminalError>> + Send {
621 let mut inner_lock = must_lock!(self.inner);
622 let handle = unwrap_or_trap_durable_future!(
623 self,
624 inner_lock,
625 inner_lock.vm.sys_get_promise(name.to_owned())
626 );
627 inner_lock.maybe_flip_span_replaying_field();
628 drop(inner_lock);
629
630 let poll_future = get_async_result(Arc::clone(&self.inner), handle).map(|res| match res {
631 Ok(Value::Success(mut s)) => {
632 let t = T::deserialize(&mut s).map_err(|e| Error::deserialization("promise", e))?;
633 Ok(Ok(t))
634 }
635 Ok(Value::Failure(f)) => Ok(Err(f.into())),
636 Ok(v) => Err(ErrorInner::UnexpectedValueVariantForSyscall {
637 variant: <&'static str>::from(v),
638 syscall: "promise",
639 }
640 .into()),
641 Err(e) => Err(e),
642 });
643
644 DurableFutureImpl::new(self.clone(), handle, Either::Left(poll_future))
645 }
646
647 pub fn peek_promise<T: Deserialize>(
648 &self,
649 name: &str,
650 ) -> impl Future<Output = Result<Option<T>, TerminalError>> + Send {
651 let mut inner_lock = must_lock!(self.inner);
652 let handle = unwrap_or_trap!(inner_lock, inner_lock.vm.sys_peek_promise(name.to_owned()));
653 inner_lock.maybe_flip_span_replaying_field();
654 drop(inner_lock);
655
656 let poll_future = get_async_result(Arc::clone(&self.inner), handle).map(|res| match res {
657 Ok(Value::Void) => Ok(Ok(None)),
658 Ok(Value::Success(mut s)) => {
659 let t = T::deserialize(&mut s)
660 .map_err(|e| Error::deserialization("peek_promise", e))?;
661 Ok(Ok(Some(t)))
662 }
663 Ok(Value::Failure(f)) => Ok(Err(f.into())),
664 Ok(v) => Err(ErrorInner::UnexpectedValueVariantForSyscall {
665 variant: <&'static str>::from(v),
666 syscall: "peek_promise",
667 }
668 .into()),
669 Err(e) => Err(e),
670 });
671
672 Either::Left(InterceptErrorFuture::new(self.clone(), poll_future))
673 }
674
675 pub fn resolve_promise<T: Serialize>(&self, name: &str, t: T) {
676 let mut inner_lock = must_lock!(self.inner);
677 match t.serialize() {
678 Ok(b) => {
679 let _ = inner_lock.vm.sys_complete_promise(
680 name.to_owned(),
681 NonEmptyValue::Success(b),
682 PayloadOptions::stable(),
683 );
684 }
685 Err(e) => {
686 inner_lock.fail(
687 ErrorInner::Serialization {
688 syscall: "resolve_promise",
689 err: Box::new(e),
690 }
691 .into(),
692 );
693 }
694 }
695 }
696
697 pub fn reject_promise(&self, id: &str, failure: TerminalError) {
698 let _ = must_lock!(self.inner).vm.sys_complete_promise(
699 id.to_owned(),
700 NonEmptyValue::Failure(failure.into()),
701 PayloadOptions::stable(),
702 );
703 }
704
705 pub fn run<'a, Run, Fut, Out>(
706 &'a self,
707 run_closure: Run,
708 ) -> impl RunFuture<Result<Out, TerminalError>> + Send + 'a
709 where
710 Run: RunClosure<Fut = Fut, Output = Out> + Send + 'a,
711 Fut: Future<Output = HandlerResult<Out>> + Send + 'a,
712 Out: Serialize + Deserialize + 'static,
713 {
714 let this = Arc::clone(&self.inner);
715 InterceptErrorFuture::new(self.clone(), RunFutureImpl::new(this, run_closure))
716 }
717
718 pub fn handle_handler_result<T: Serialize>(&self, res: HandlerResult<T>) {
720 let mut inner_lock = must_lock!(self.inner);
721
722 let res_to_write = match res {
723 Ok(success) => match T::serialize(&success) {
724 Ok(t) => NonEmptyValue::Success(t),
725 Err(e) => {
726 inner_lock.fail(
727 ErrorInner::Serialization {
728 syscall: "output",
729 err: Box::new(e),
730 }
731 .into(),
732 );
733 return;
734 }
735 },
736 Err(e) => match e.0 {
737 HandlerErrorInner::Retryable(err) => {
738 inner_lock.fail(ErrorInner::HandlerResult { err }.into());
739 return;
740 }
741 HandlerErrorInner::Terminal(t) => NonEmptyValue::Failure(TerminalError(t).into()),
742 },
743 };
744
745 let _ = inner_lock
746 .vm
747 .sys_write_output(res_to_write, PayloadOptions::stable());
748 inner_lock.maybe_flip_span_replaying_field();
749 }
750
751 pub fn end(&self) {
752 let _ = must_lock!(self.inner).vm.sys_end();
753 }
754
755 pub(crate) fn consume_to_end(&self) {
756 let mut inner_lock = must_lock!(self.inner);
757
758 let out = inner_lock.vm.take_output();
759 if let TakeOutputResult::Buffer(b) = out
760 && !inner_lock.write.send(b)
761 {
762 }
764 }
765
766 pub(crate) async fn drain_input(&self) -> Result<(), ErrorInner> {
771 tokio::time::timeout(Duration::from_secs(60), async {
772 loop {
773 let result = poll_fn(|cx| {
774 let mut inner = must_lock!(self.inner);
775 inner.read.poll_recv(cx)
776 })
777 .await;
778 match result {
779 None => return Ok(()),
780 Some(Ok(_)) => continue,
781 Some(Err(e)) => return Err(ErrorInner::InputDrain(e)),
782 }
783 }
784 })
785 .await
786 .unwrap_or_else(|_| {
787 Err(ErrorInner::InputDrain(
788 "Timed out draining input stream after 60s".into(),
789 ))
790 })
791 }
792
793 pub(super) fn fail(&self, e: Error) {
794 must_lock!(self.inner).fail(e)
795 }
796}
797
pin_project! {
    // Future behind `ContextInternal::run`: registers a run with the VM (`sys_run`), executes
    // the user closure when the VM asks for it, and resolves to the VM-provided result.
    struct RunFutureImpl<Run, Ret, RunFnFut> {
        // Name passed to `sys_run`; empty unless set through `RunFuture::name`.
        name: String,
        // Passed to `propose_run_completion`; `Infinite` unless overridden through
        // `RunFuture::retry_policy`.
        retry_policy: RetryPolicy,
        // Ties the `Ret` type parameter to the struct without storing a value of it.
        phantom_data: PhantomData<fn() -> Ret>,
        #[pin]
        state: RunState<Run, RunFnFut, Ret>,
    }
}
807
pin_project! {
    // States of `RunFutureImpl`: `New` -> (`ClosureRunning` ->) `WaitingResultFut`.
    #[project = RunStateProj]
    enum RunState<Run, RunFnFut, Ret> {
        // Not yet polled; both options are taken on the first poll.
        New {
            ctx: Option<Arc<Mutex<ContextInternalInner>>>,
            closure: Option<Run>,
        },
        // The user closure is executing. `start_time` feeds the attempt duration reported for
        // retryable failures.
        ClosureRunning {
            ctx: Option<Arc<Mutex<ContextInternalInner>>>,
            handle: NotificationHandle,
            start_time: Instant,
            #[pin]
            closure_fut: RunFnFut,
        },
        // Waiting for the final result: either read back from the VM or a synthesized
        // cancellation failure.
        WaitingResultFut {
            result_fut: BoxFuture<'static, Result<Result<Ret, TerminalError>, Error>>
        }
    }
}
827
828impl<Run, Ret, RunFnFut> RunFutureImpl<Run, Ret, RunFnFut> {
829 fn new(ctx: Arc<Mutex<ContextInternalInner>>, closure: Run) -> Self {
830 Self {
831 name: "".to_string(),
832 retry_policy: RetryPolicy::Infinite,
833 phantom_data: PhantomData,
834 state: RunState::New {
835 ctx: Some(ctx),
836 closure: Some(closure),
837 },
838 }
839 }
840
841 fn boxed_result_fut(
842 ctx: Arc<Mutex<ContextInternalInner>>,
843 handle: NotificationHandle,
844 ) -> BoxFuture<'static, Result<Result<Ret, TerminalError>, Error>>
845 where
846 Ret: Deserialize,
847 {
848 get_async_result(Arc::clone(&ctx), handle)
849 .map(|res| match res {
850 Ok(Value::Success(mut s)) => {
851 let t =
852 Ret::deserialize(&mut s).map_err(|e| Error::deserialization("run", e))?;
853 Ok(Ok(t))
854 }
855 Ok(Value::Failure(f)) => Ok(Err(f.into())),
856 Ok(v) => Err(ErrorInner::UnexpectedValueVariantForSyscall {
857 variant: <&'static str>::from(v),
858 syscall: "run",
859 }
860 .into()),
861 Err(e) => Err(e),
862 })
863 .boxed()
864 }
865}
866
867impl<Run, Ret, RunFnFut> RunFuture<Result<Result<Ret, TerminalError>, Error>>
868 for RunFutureImpl<Run, Ret, RunFnFut>
869where
870 Run: RunClosure<Fut = RunFnFut, Output = Ret> + Send,
871 Ret: Serialize + Deserialize,
872 RunFnFut: Future<Output = HandlerResult<Ret>> + Send,
873{
874 fn retry_policy(mut self, retry_policy: RunRetryPolicy) -> Self {
875 self.retry_policy = RetryPolicy::Exponential {
876 initial_interval: retry_policy.initial_delay,
877 factor: retry_policy.factor,
878 max_interval: retry_policy.max_delay,
879 max_attempts: retry_policy.max_attempts,
880 max_duration: retry_policy.max_duration,
881 };
882 self
883 }
884
885 fn name(mut self, name: impl Into<String>) -> Self {
886 self.name = name.into();
887 self
888 }
889}
890
impl<Run, Ret, RunFnFut> Future for RunFutureImpl<Run, Ret, RunFnFut>
where
    Run: RunClosure<Fut = RunFnFut, Output = Ret> + Send,
    Ret: Serialize + Deserialize,
    RunFnFut: Future<Output = HandlerResult<Ret>> + Send,
{
    type Output = Result<Result<Ret, TerminalError>, Error>;

    /// Drives the run state machine.
    ///
    /// * `New`: registers the run (`sys_run`) and asks the VM via `do_progress` what to do.
    ///   - `ExecuteRun` starts the closure.
    ///   - `CancelSignalReceived` resolves to a terminal 409 "cancelled" failure.
    ///   - Any other response, including an `Err`, falls through to awaiting the VM's result
    ///     for the run handle.
    /// * `ClosureRunning`: when the closure finishes, its outcome becomes a `RunExitResult`.
    ///   That result is proposed to the VM together with the configured retry policy, and
    ///   then the VM's result is awaited.
    /// * `WaitingResultFut`: polls the boxed result future.
    ///
    /// Re-polling after `Poll::Ready` is not supported; the `expect`s fire if already-taken
    /// state is revisited.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let mut this = self.project();

        loop {
            match this.state.as_mut().project() {
                RunStateProj::New { ctx, closure, .. } => {
                    let ctx = ctx
                        .take()
                        .expect("Future should not be polled after returning Poll::Ready");
                    let closure = closure
                        .take()
                        .expect("Future should not be polled after returning Poll::Ready");
                    let mut inner_ctx = must_lock!(ctx);

                    let handle = inner_ctx
                        .vm
                        .sys_run(this.name.to_owned())
                        .map_err(ErrorInner::from)?;

                    // Ask the VM whether the closure has to be executed now.
                    match inner_ctx.vm.do_progress(vec![handle]) {
                        Ok(DoProgressResponse::ExecuteRun(handle_to_run)) => {
                            // `do_progress` was given only `handle`, so the run to execute is
                            // expected to be this one.
                            assert_eq!(handle, handle_to_run);

                            // Release the VM lock before the closure future is created.
                            drop(inner_ctx);
                            this.state.set(RunState::ClosureRunning {
                                ctx: Some(ctx),
                                handle,
                                start_time: Instant::now(),
                                closure_fut: closure.run(),
                            });
                        }
                        Ok(DoProgressResponse::CancelSignalReceived) => {
                            drop(inner_ctx);
                            // Resolve right away with a terminal "cancelled" (409) failure.
                            this.state.set(RunState::WaitingResultFut {
                                result_fut: async {
                                    Ok(Err(TerminalError::from(TerminalFailure {
                                        code: 409,
                                        message: "cancelled".to_string(),
                                        metadata: vec![],
                                    })))
                                }
                                .boxed(),
                            })
                        }
                        _ => {
                            drop(inner_ctx);
                            // Any other outcome: the result will come from the VM for this
                            // handle.
                            this.state.set(RunState::WaitingResultFut {
                                result_fut: Self::boxed_result_fut(Arc::clone(&ctx), handle),
                            })
                        }
                    }
                }
                RunStateProj::ClosureRunning {
                    ctx,
                    handle,
                    start_time,
                    closure_fut,
                } => {
                    // Map the closure outcome onto what the VM expects for a run completion.
                    let res = match ready!(closure_fut.poll(cx)) {
                        Ok(t) => RunExitResult::Success(Ret::serialize(&t).map_err(|e| {
                            ErrorInner::Serialization {
                                syscall: "run",
                                err: Box::new(e),
                            }
                        })?),
                        Err(e) => match e.0 {
                            HandlerErrorInner::Retryable(err) => RunExitResult::RetryableFailure {
                                attempt_duration: start_time.elapsed(),
                                error: CoreError::new(500u16, err.to_string()),
                            },
                            HandlerErrorInner::Terminal(t) => {
                                RunExitResult::TerminalFailure(TerminalError(t).into())
                            }
                        },
                    };

                    let ctx = ctx
                        .take()
                        .expect("Future should not be polled after returning Poll::Ready");
                    let handle = *handle;

                    // The proposal's own result is ignored; the final outcome is read back
                    // through the result future below.
                    let _ = {
                        must_lock!(ctx).vm.propose_run_completion(
                            handle,
                            res,
                            mem::take(this.retry_policy),
                        )
                    };

                    this.state.set(RunState::WaitingResultFut {
                        result_fut: Self::boxed_result_fut(Arc::clone(&ctx), handle),
                    });
                }
                RunStateProj::WaitingResultFut { result_fut } => return result_fut.poll_unpin(cx),
            }
        }
    }
}
1004
pin_project! {
    // Future returned by `ContextInternal::call`. It resolves to the call's response. It also
    // exposes the callee's invocation id as a `Shared` future, which can be cloned out
    // repeatedly, and the call's notification handle for `DurableFuture` support.
    struct CallFutureImpl<InvIdFut: Future, ResultFut> {
        #[pin]
        invocation_id_future: Shared<InvIdFut>,
        #[pin]
        result_future: ResultFut,
        call_notification_handle: NotificationHandle,
        ctx: ContextInternal,
    }
}
1015
1016impl<InvIdFut, ResultFut, Res> Future for CallFutureImpl<InvIdFut, ResultFut>
1017where
1018 InvIdFut: Future<Output = Result<String, TerminalError>> + Send,
1019 ResultFut: Future<Output = Result<Result<Res, TerminalError>, Error>> + Send,
1020{
1021 type Output = Result<Res, TerminalError>;
1022
1023 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
1024 let this = self.project();
1025 let result = ready!(this.result_future.poll(cx));
1026
1027 match result {
1028 Ok(r) => Poll::Ready(r),
1029 Err(e) => {
1030 this.ctx.fail(e);
1031
1032 cx.waker().wake_by_ref();
1035 Poll::Pending
1036 }
1037 }
1038 }
1039}
1040
1041impl<InvIdFut, ResultFut> InvocationHandle for CallFutureImpl<InvIdFut, ResultFut>
1042where
1043 InvIdFut: Future<Output = Result<String, TerminalError>> + Send,
1044{
1045 fn invocation_id(&self) -> impl Future<Output = Result<String, TerminalError>> + Send {
1046 Shared::clone(&self.invocation_id_future)
1047 }
1048
1049 fn cancel(&self) -> impl Future<Output = Result<(), TerminalError>> + Send {
1050 let cloned_invocation_id_fut = Shared::clone(&self.invocation_id_future);
1051 let cloned_ctx = Arc::clone(&self.ctx.inner);
1052 async move {
1053 let inv_id = cloned_invocation_id_fut.await?;
1054 let mut inner_lock = must_lock!(cloned_ctx);
1055 let _ = inner_lock.vm.sys_cancel_invocation(inv_id);
1056 inner_lock.maybe_flip_span_replaying_field();
1057 drop(inner_lock);
1058 Ok(())
1059 }
1060 }
1061}
1062
1063impl<InvIdFut, ResultFut, Res> CallFuture for CallFutureImpl<InvIdFut, ResultFut>
1064where
1065 InvIdFut: Future<Output = Result<String, TerminalError>> + Send,
1066 ResultFut: Future<Output = Result<Result<Res, TerminalError>, Error>> + Send,
1067{
1068 type Response = Res;
1069}
1070
1071impl<InvIdFut, ResultFut> crate::context::macro_support::SealedDurableFuture
1072 for CallFutureImpl<InvIdFut, ResultFut>
1073where
1074 InvIdFut: Future,
1075{
1076 fn inner_context(&self) -> ContextInternal {
1077 self.ctx.clone()
1078 }
1079
1080 fn handle(&self) -> NotificationHandle {
1081 self.call_notification_handle
1082 }
1083}
1084
1085impl<InvIdFut, ResultFut, Res> DurableFuture for CallFutureImpl<InvIdFut, ResultFut>
1086where
1087 InvIdFut: Future<Output = Result<String, TerminalError>> + Send,
1088 ResultFut: Future<Output = Result<Result<Res, TerminalError>, Error>> + Send,
1089{
1090}
1091
/// Handle returned by `ContextInternal::send`. It exposes the sent invocation's id as a shared
/// future and can cancel that invocation.
struct SendRequestHandle<InvIdFut: Future> {
    invocation_id_future: Shared<InvIdFut>,
    ctx: ContextInternal,
}
1096
1097impl<InvIdFut: Future<Output = Result<String, TerminalError>> + Send> InvocationHandle
1098 for SendRequestHandle<InvIdFut>
1099{
1100 fn invocation_id(&self) -> impl Future<Output = Result<String, TerminalError>> + Send {
1101 Shared::clone(&self.invocation_id_future)
1102 }
1103
1104 fn cancel(&self) -> impl Future<Output = Result<(), TerminalError>> + Send {
1105 let cloned_invocation_id_fut = Shared::clone(&self.invocation_id_future);
1106 let cloned_ctx = Arc::clone(&self.ctx.inner);
1107 async move {
1108 let inv_id = cloned_invocation_id_fut.await?;
1109 let mut inner_lock = must_lock!(cloned_ctx);
1110 let _ = inner_lock.vm.sys_cancel_invocation(inv_id);
1111 inner_lock.maybe_flip_span_replaying_field();
1112 drop(inner_lock);
1113 Ok(())
1114 }
1115 }
1116}
1117
/// Handle for an invocation whose id is already known up front (created by
/// `ContextInternal::invocation_handle`).
struct InvocationIdBackedInvocationHandle {
    ctx: ContextInternal,
    invocation_id: String,
}
1122
1123impl InvocationHandle for InvocationIdBackedInvocationHandle {
1124 fn invocation_id(&self) -> impl Future<Output = Result<String, TerminalError>> + Send {
1125 ready(Ok(self.invocation_id.clone()))
1126 }
1127
1128 fn cancel(&self) -> impl Future<Output = Result<(), TerminalError>> + Send {
1129 let mut inner_lock = must_lock!(self.ctx.inner);
1130 let _ = inner_lock
1131 .vm
1132 .sys_cancel_invocation(self.invocation_id.clone());
1133 ready(Ok(()))
1134 }
1135}
1136
1137impl<A, B> InvocationHandle for Either<A, B>
1138where
1139 A: InvocationHandle,
1140 B: InvocationHandle,
1141{
1142 fn invocation_id(&self) -> impl Future<Output = Result<String, TerminalError>> + Send {
1143 match self {
1144 Either::Left(l) => Either::Left(l.invocation_id()),
1145 Either::Right(r) => Either::Right(r.invocation_id()),
1146 }
1147 }
1148
1149 fn cancel(&self) -> impl Future<Output = Result<(), TerminalError>> + Send {
1150 match self {
1151 Either::Left(l) => Either::Left(l.cancel()),
1152 Either::Right(r) => Either::Right(r.cancel()),
1153 }
1154 }
1155}
1156
1157impl Error {
1158 fn serialization<E: std::error::Error + Send + Sync + 'static>(
1159 syscall: &'static str,
1160 e: E,
1161 ) -> Self {
1162 ErrorInner::Serialization {
1163 syscall,
1164 err: Box::new(e),
1165 }
1166 .into()
1167 }
1168
1169 fn deserialization<E: std::error::Error + Send + Sync + 'static>(
1170 syscall: &'static str,
1171 e: E,
1172 ) -> Self {
1173 ErrorInner::Deserialization {
1174 syscall,
1175 err: Box::new(e),
1176 }
1177 .into()
1178 }
1179}
1180
1181fn get_async_result(
1182 ctx: Arc<Mutex<ContextInternalInner>>,
1183 handle: NotificationHandle,
1184) -> impl Future<Output = Result<Value, Error>> + Send {
1185 VmAsyncResultPollFuture::new(ctx, handle).map_err(Error::from)
1186}