1use crate::context::{
2 CallFuture, DurableFuture, InvocationHandle, Request, RequestTarget, RunClosure, RunFuture,
3 RunRetryPolicy,
4};
5use crate::endpoint::futures::async_result_poll::VmAsyncResultPollFuture;
6use crate::endpoint::futures::durable_future_impl::DurableFutureImpl;
7use crate::endpoint::futures::intercept_error::InterceptErrorFuture;
8use crate::endpoint::futures::select_poll::VmSelectAsyncResultPollFuture;
9use crate::endpoint::futures::trap::TrapFuture;
10use crate::endpoint::handler_state::HandlerStateNotifier;
11use crate::endpoint::{Error, ErrorInner, InputReceiver, OutputSender};
12use crate::errors::{HandlerErrorInner, HandlerResult, TerminalError};
13use crate::serde::{Deserialize, Serialize};
14use futures::future::{BoxFuture, Either, Shared};
15use futures::{FutureExt, TryFutureExt};
16use pin_project_lite::pin_project;
17use restate_sdk_shared_core::{
18 CoreVM, DoProgressResponse, Error as CoreError, Header, NonEmptyValue, NotificationHandle,
19 RetryPolicy, RunExitResult, TakeOutputResult, Target, TerminalFailure, VM, Value,
20};
21use std::borrow::Cow;
22use std::collections::HashMap;
23use std::future::{Future, ready};
24use std::marker::PhantomData;
25use std::mem;
26use std::pin::Pin;
27use std::sync::{Arc, Mutex};
28use std::task::{Context, Poll, ready};
29use std::time::{Duration, Instant, SystemTime};
30
31pub struct ContextInternalInner {
32 pub(crate) vm: CoreVM,
33 pub(crate) read: InputReceiver,
34 pub(crate) write: OutputSender,
35 pub(super) handler_state: HandlerStateNotifier,
36
37 pub(super) span_replaying_field_state: bool,
40}
41
42impl ContextInternalInner {
43 fn new(
44 vm: CoreVM,
45 read: InputReceiver,
46 write: OutputSender,
47 handler_state: HandlerStateNotifier,
48 ) -> Self {
49 Self {
50 vm,
51 read,
52 write,
53 handler_state,
54 span_replaying_field_state: false,
55 }
56 }
57
58 pub(super) fn fail(&mut self, e: Error) {
59 self.maybe_flip_span_replaying_field();
60 self.vm.notify_error(
61 CoreError::new(500u16, e.0.to_string())
62 .with_stacktrace(Cow::Owned(format!("{:#}", e.0))),
63 None,
64 );
65 self.handler_state.mark_error(e);
66 }
67
68 pub(super) fn maybe_flip_span_replaying_field(&mut self) {
69 if !self.span_replaying_field_state && self.vm.is_replaying() {
70 tracing::Span::current().record("restate.sdk.is_replaying", true);
71 self.span_replaying_field_state = true;
72 } else if self.span_replaying_field_state && !self.vm.is_replaying() {
73 tracing::Span::current().record("restate.sdk.is_replaying", false);
74 self.span_replaying_field_state = false;
75 }
76 }
77}
78
79#[allow(unused)]
80const fn is_send_sync<T: Send + Sync>() {}
81const _: () = is_send_sync::<ContextInternal>();
82
/// Acquires `$mutex` with `try_lock`, never blocking.
///
/// # Panics
///
/// Panics if the mutex is already locked (or poisoned). Per the panic message, this happens
/// when two context futures are awaited at the same time, or when the context is used while
/// one of its futures is being awaited.
macro_rules! must_lock {
    ($mutex:expr) => {
        $mutex
            .try_lock()
            .expect("You're trying to await two futures at the same time and/or trying to perform some operation on the restate context while awaiting a future. This is not supported!")
    };
}
88
/// Evaluates `$res`. On `Ok`, yields the value. On `Err`, fails the context through
/// `$inner_lock.fail(..)` and returns `Either::Right(TrapFuture::default())` from the
/// enclosing function.
macro_rules! unwrap_or_trap {
    ($inner_lock:expr, $res:expr) => {
        match $res {
            Err(err) => {
                $inner_lock.fail(err.into());
                return Either::Right(TrapFuture::default());
            }
            Ok(value) => value,
        }
    };
}
100
/// Like `unwrap_or_trap!`, but for functions returning a `DurableFutureImpl`.
///
/// On `Err`, it fails the context and returns a `DurableFutureImpl` bound to `$ctx`. That
/// future wraps a `TrapFuture` and uses the placeholder handle `u32::MAX`.
macro_rules! unwrap_or_trap_durable_future {
    ($ctx:expr, $inner_lock:expr, $res:expr) => {
        match $res {
            Err(err) => {
                $inner_lock.fail(err.into());
                let placeholder_handle = NotificationHandle::from(u32::MAX);
                return DurableFutureImpl::new(
                    $ctx.clone(),
                    placeholder_handle,
                    Either::Right(TrapFuture::default()),
                );
            }
            Ok(value) => value,
        }
    };
}
116
/// Metadata of the current invocation, returned together with the decoded input by
/// `ContextInternal::input`.
#[derive(Debug, Eq, PartialEq)]
pub struct InputMetadata {
    /// Invocation id, as provided by the VM's input.
    pub invocation_id: String,
    /// Random seed, as provided by the VM's input.
    pub random_seed: u64,
    /// Target key, as provided by the VM's input.
    pub key: String,
    /// Request headers converted into an `http::HeaderMap`.
    pub headers: http::HeaderMap<String>,
}
124
125impl From<RequestTarget> for Target {
126 fn from(value: RequestTarget) -> Self {
127 match value {
128 RequestTarget::Service { name, handler } => Target {
129 service: name,
130 handler,
131 key: None,
132 idempotency_key: None,
133 headers: vec![],
134 },
135 RequestTarget::Object { name, key, handler } => Target {
136 service: name,
137 handler,
138 key: Some(key),
139 idempotency_key: None,
140 headers: vec![],
141 },
142 RequestTarget::Workflow { name, key, handler } => Target {
143 service: name,
144 handler,
145 key: Some(key),
146 idempotency_key: None,
147 headers: vec![],
148 },
149 }
150 }
151}
152
/// Per-invocation context handle used by the SDK's user-facing context types.
///
/// All clones share the same [`ContextInternalInner`] through an `Arc<Mutex<_>>`. Cloning
/// copies the service/handler name strings.
#[derive(Clone)]
pub struct ContextInternal {
    svc_name: String,
    handler_name: String,
    inner: Arc<Mutex<ContextInternalInner>>,
}
162
163impl ContextInternal {
164 pub(super) fn new(
165 vm: CoreVM,
166 svc_name: String,
167 handler_name: String,
168 read: InputReceiver,
169 write: OutputSender,
170 handler_state: HandlerStateNotifier,
171 ) -> Self {
172 Self {
173 svc_name,
174 handler_name,
175 inner: Arc::new(Mutex::new(ContextInternalInner::new(
176 vm,
177 read,
178 write,
179 handler_state,
180 ))),
181 }
182 }
183
184 pub fn service_name(&self) -> &str {
185 &self.svc_name
186 }
187
188 pub fn handler_name(&self) -> &str {
189 &self.handler_name
190 }
191
192 pub fn input<T: Deserialize>(&self) -> impl Future<Output = (T, InputMetadata)> {
193 let mut inner_lock = must_lock!(self.inner);
194 let input_result =
195 inner_lock
196 .vm
197 .sys_input()
198 .map_err(ErrorInner::VM)
199 .map(|mut raw_input| {
200 let headers = http::HeaderMap::<String>::try_from(
201 &raw_input
202 .headers
203 .into_iter()
204 .map(|h| (h.key.to_string(), h.value.to_string()))
205 .collect::<HashMap<String, String>>(),
206 )
207 .map_err(|e| {
208 TerminalError::new_with_code(400, format!("Cannot decode headers: {e:?}"))
209 })?;
210
211 Ok::<_, TerminalError>((
212 T::deserialize(&mut (raw_input.input)).map_err(|e| {
213 TerminalError::new_with_code(
214 400,
215 format!("Cannot decode input payload: {e:?}"),
216 )
217 })?,
218 InputMetadata {
219 invocation_id: raw_input.invocation_id,
220 random_seed: raw_input.random_seed,
221 key: raw_input.key,
222 headers,
223 },
224 ))
225 });
226 inner_lock.maybe_flip_span_replaying_field();
227
228 match input_result {
229 Ok(Ok(i)) => {
230 drop(inner_lock);
231 return Either::Left(ready(i));
232 }
233 Ok(Err(err)) => {
234 let error_inner = ErrorInner::Deserialization {
235 syscall: "input",
236 err: err.0.clone().into(),
237 };
238 let _ = inner_lock
239 .vm
240 .sys_write_output(NonEmptyValue::Failure(err.into()));
241 let _ = inner_lock.vm.sys_end();
242 inner_lock.handler_state.mark_error(error_inner.into());
244 drop(inner_lock);
245 }
246 Err(e) => {
247 inner_lock.fail(e.into());
248 drop(inner_lock);
249 }
250 }
251 Either::Right(TrapFuture::default())
252 }
253
254 pub fn get<T: Deserialize>(
255 &self,
256 key: &str,
257 ) -> impl Future<Output = Result<Option<T>, TerminalError>> + Send {
258 let mut inner_lock = must_lock!(self.inner);
259 let handle = unwrap_or_trap!(inner_lock, inner_lock.vm.sys_state_get(key.to_owned()));
260 inner_lock.maybe_flip_span_replaying_field();
261
262 let poll_future = get_async_result(Arc::clone(&self.inner), handle).map(|res| match res {
263 Ok(Value::Void) => Ok(Ok(None)),
264 Ok(Value::Success(mut s)) => {
265 let t =
266 T::deserialize(&mut s).map_err(|e| Error::deserialization("get_state", e))?;
267 Ok(Ok(Some(t)))
268 }
269 Ok(Value::Failure(f)) => Ok(Err(f.into())),
270 Ok(v) => Err(ErrorInner::UnexpectedValueVariantForSyscall {
271 variant: <&'static str>::from(v),
272 syscall: "get_state",
273 }
274 .into()),
275 Err(e) => Err(e),
276 });
277
278 Either::Left(InterceptErrorFuture::new(self.clone(), poll_future))
279 }
280
281 pub fn get_keys(&self) -> impl Future<Output = Result<Vec<String>, TerminalError>> + Send {
282 let mut inner_lock = must_lock!(self.inner);
283 let handle = unwrap_or_trap!(inner_lock, inner_lock.vm.sys_state_get_keys());
284 inner_lock.maybe_flip_span_replaying_field();
285
286 let poll_future = get_async_result(Arc::clone(&self.inner), handle).map(|res| match res {
287 Ok(Value::Failure(f)) => Ok(Err(f.into())),
288 Ok(Value::StateKeys(s)) => Ok(Ok(s)),
289 Ok(v) => Err(ErrorInner::UnexpectedValueVariantForSyscall {
290 variant: <&'static str>::from(v),
291 syscall: "get_keys",
292 }
293 .into()),
294 Err(e) => Err(e),
295 });
296
297 Either::Left(InterceptErrorFuture::new(self.clone(), poll_future))
298 }
299
300 pub fn set<T: Serialize>(&self, key: &str, t: T) {
301 let mut inner_lock = must_lock!(self.inner);
302 match t.serialize() {
303 Ok(b) => {
304 let _ = inner_lock.vm.sys_state_set(key.to_owned(), b);
305 inner_lock.maybe_flip_span_replaying_field();
306 }
307 Err(e) => {
308 inner_lock.fail(Error::serialization("set_state", e));
309 }
310 }
311 }
312
313 pub fn clear(&self, key: &str) {
314 let mut inner_lock = must_lock!(self.inner);
315 let _ = inner_lock.vm.sys_state_clear(key.to_string());
316 inner_lock.maybe_flip_span_replaying_field();
317 }
318
319 pub fn clear_all(&self) {
320 let mut inner_lock = must_lock!(self.inner);
321 let _ = inner_lock.vm.sys_state_clear_all();
322 inner_lock.maybe_flip_span_replaying_field();
323 }
324
325 pub fn select(
326 &self,
327 handles: Vec<NotificationHandle>,
328 ) -> impl Future<Output = Result<usize, TerminalError>> + Send {
329 InterceptErrorFuture::new(
330 self.clone(),
331 VmSelectAsyncResultPollFuture::new(self.inner.clone(), handles).map_err(Error::from),
332 )
333 }
334
335 pub fn sleep(
336 &self,
337 sleep_duration: Duration,
338 ) -> impl DurableFuture<Output = Result<(), TerminalError>> + Send {
339 let now = SystemTime::now()
340 .duration_since(SystemTime::UNIX_EPOCH)
341 .expect("Duration since unix epoch cannot fail");
342 let mut inner_lock = must_lock!(self.inner);
343 let handle = unwrap_or_trap_durable_future!(
344 self,
345 inner_lock,
346 inner_lock
347 .vm
348 .sys_sleep(String::default(), now + sleep_duration, Some(now))
349 );
350 inner_lock.maybe_flip_span_replaying_field();
351
352 let poll_future = get_async_result(Arc::clone(&self.inner), handle).map(|res| match res {
353 Ok(Value::Void) => Ok(Ok(())),
354 Ok(Value::Failure(f)) => Ok(Err(f.into())),
355 Ok(v) => Err(ErrorInner::UnexpectedValueVariantForSyscall {
356 variant: <&'static str>::from(v),
357 syscall: "sleep",
358 }
359 .into()),
360 Err(e) => Err(e),
361 });
362
363 DurableFutureImpl::new(self.clone(), handle, Either::Left(poll_future))
364 }
365
366 pub fn request<Req, Res>(
367 &self,
368 request_target: RequestTarget,
369 req: Req,
370 ) -> Request<'_, Req, Res> {
371 Request::new(self, request_target, req)
372 }
373
374 pub fn call<Req: Serialize, Res: Deserialize>(
375 &self,
376 request_target: RequestTarget,
377 idempotency_key: Option<String>,
378 headers: Vec<(String, String)>,
379 req: Req,
380 ) -> impl CallFuture<Response = Res> + Send {
381 let mut inner_lock = must_lock!(self.inner);
382
383 let mut target: Target = request_target.into();
384 target.idempotency_key = idempotency_key;
385 target.headers = headers
386 .into_iter()
387 .map(|(k, v)| Header {
388 key: k.into(),
389 value: v.into(),
390 })
391 .collect();
392 let call_result = Req::serialize(&req)
393 .map_err(|e| Error::serialization("call", e))
394 .and_then(|input| inner_lock.vm.sys_call(target, input).map_err(Into::into));
395
396 let call_handle = match call_result {
397 Ok(t) => t,
398 Err(e) => {
399 inner_lock.fail(e);
400 return CallFutureImpl {
401 invocation_id_future: Either::Right(TrapFuture::default()).shared(),
402 result_future: Either::Right(TrapFuture::default()),
403 call_notification_handle: NotificationHandle::from(u32::MAX),
404 ctx: self.clone(),
405 };
406 }
407 };
408 inner_lock.maybe_flip_span_replaying_field();
409 drop(inner_lock);
410
411 let invocation_id_fut = InterceptErrorFuture::new(
413 self.clone(),
414 get_async_result(
415 Arc::clone(&self.inner),
416 call_handle.invocation_id_notification_handle,
417 )
418 .map(|res| match res {
419 Ok(Value::Failure(f)) => Ok(Err(f.into())),
420 Ok(Value::InvocationId(s)) => Ok(Ok(s)),
421 Ok(v) => Err(ErrorInner::UnexpectedValueVariantForSyscall {
422 variant: <&'static str>::from(v),
423 syscall: "call",
424 }
425 .into()),
426 Err(e) => Err(e),
427 }),
428 );
429 let result_future = get_async_result(
430 Arc::clone(&self.inner),
431 call_handle.call_notification_handle,
432 )
433 .map(|res| match res {
434 Ok(Value::Success(mut s)) => Ok(Ok(
435 Res::deserialize(&mut s).map_err(|e| Error::deserialization("call", e))?
436 )),
437 Ok(Value::Failure(f)) => Ok(Err(TerminalError::from(f))),
438 Ok(v) => Err(ErrorInner::UnexpectedValueVariantForSyscall {
439 variant: <&'static str>::from(v),
440 syscall: "call",
441 }
442 .into()),
443 Err(e) => Err(e),
444 });
445
446 CallFutureImpl {
447 invocation_id_future: Either::Left(invocation_id_fut).shared(),
448 result_future: Either::Left(result_future),
449 call_notification_handle: call_handle.call_notification_handle,
450 ctx: self.clone(),
451 }
452 }
453
454 pub fn send<Req: Serialize>(
455 &self,
456 request_target: RequestTarget,
457 idempotency_key: Option<String>,
458 headers: Vec<(String, String)>,
459 req: Req,
460 delay: Option<Duration>,
461 ) -> impl InvocationHandle {
462 let mut inner_lock = must_lock!(self.inner);
463
464 let mut target: Target = request_target.into();
465 target.idempotency_key = idempotency_key;
466 target.headers = headers
467 .into_iter()
468 .map(|(k, v)| Header {
469 key: k.into(),
470 value: v.into(),
471 })
472 .collect();
473 let input = match Req::serialize(&req) {
474 Ok(b) => b,
475 Err(e) => {
476 inner_lock.fail(Error::serialization("call", e));
477 return Either::Right(TrapFuture::<()>::default());
478 }
479 };
480
481 let send_handle = match inner_lock.vm.sys_send(
482 target,
483 input,
484 delay.map(|delay| {
485 SystemTime::now()
486 .duration_since(SystemTime::UNIX_EPOCH)
487 .expect("Duration since unix epoch cannot fail")
488 + delay
489 }),
490 ) {
491 Ok(h) => h,
492 Err(e) => {
493 inner_lock.fail(e.into());
494 return Either::Right(TrapFuture::<()>::default());
495 }
496 };
497 inner_lock.maybe_flip_span_replaying_field();
498 drop(inner_lock);
499
500 let invocation_id_fut = InterceptErrorFuture::new(
501 self.clone(),
502 get_async_result(
503 Arc::clone(&self.inner),
504 send_handle.invocation_id_notification_handle,
505 )
506 .map(|res| match res {
507 Ok(Value::Failure(f)) => Ok(Err(f.into())),
508 Ok(Value::InvocationId(s)) => Ok(Ok(s)),
509 Ok(v) => Err(ErrorInner::UnexpectedValueVariantForSyscall {
510 variant: <&'static str>::from(v),
511 syscall: "call",
512 }
513 .into()),
514 Err(e) => Err(e),
515 }),
516 );
517
518 Either::Left(SendRequestHandle {
519 invocation_id_future: invocation_id_fut.shared(),
520 ctx: self.clone(),
521 })
522 }
523
524 pub fn invocation_handle(&self, invocation_id: String) -> impl InvocationHandle {
525 InvocationIdBackedInvocationHandle {
526 ctx: self.clone(),
527 invocation_id,
528 }
529 }
530
531 pub fn awakeable<T: Deserialize>(
532 &self,
533 ) -> (
534 String,
535 impl DurableFuture<Output = Result<T, TerminalError>> + Send,
536 ) {
537 let mut inner_lock = must_lock!(self.inner);
538 let maybe_awakeable_id_and_handle = inner_lock.vm.sys_awakeable();
539 inner_lock.maybe_flip_span_replaying_field();
540
541 let (awakeable_id, handle) = match maybe_awakeable_id_and_handle {
542 Ok((s, handle)) => (s, handle),
543 Err(e) => {
544 inner_lock.fail(e.into());
545 return (
546 "invalid".to_owned(),
549 DurableFutureImpl::new(
550 self.clone(),
551 NotificationHandle::from(u32::MAX),
552 Either::Right(TrapFuture::default()),
553 ),
554 );
555 }
556 };
557 drop(inner_lock);
558
559 let poll_future = get_async_result(Arc::clone(&self.inner), handle).map(|res| match res {
560 Ok(Value::Success(mut s)) => Ok(Ok(
561 T::deserialize(&mut s).map_err(|e| Error::deserialization("awakeable", e))?
562 )),
563 Ok(Value::Failure(f)) => Ok(Err(f.into())),
564 Ok(v) => Err(ErrorInner::UnexpectedValueVariantForSyscall {
565 variant: <&'static str>::from(v),
566 syscall: "awakeable",
567 }
568 .into()),
569 Err(e) => Err(e),
570 });
571
572 (
573 awakeable_id,
574 DurableFutureImpl::new(self.clone(), handle, Either::Left(poll_future)),
575 )
576 }
577
578 pub fn resolve_awakeable<T: Serialize>(&self, id: &str, t: T) {
579 let mut inner_lock = must_lock!(self.inner);
580 match t.serialize() {
581 Ok(b) => {
582 let _ = inner_lock
583 .vm
584 .sys_complete_awakeable(id.to_owned(), NonEmptyValue::Success(b));
585 }
586 Err(e) => {
587 inner_lock.fail(Error::serialization("resolve_awakeable", e));
588 }
589 }
590 }
591
592 pub fn reject_awakeable(&self, id: &str, failure: TerminalError) {
593 let _ = must_lock!(self.inner)
594 .vm
595 .sys_complete_awakeable(id.to_owned(), NonEmptyValue::Failure(failure.into()));
596 }
597
598 pub fn promise<T: Deserialize>(
599 &self,
600 name: &str,
601 ) -> impl DurableFuture<Output = Result<T, TerminalError>> + Send {
602 let mut inner_lock = must_lock!(self.inner);
603 let handle = unwrap_or_trap_durable_future!(
604 self,
605 inner_lock,
606 inner_lock.vm.sys_get_promise(name.to_owned())
607 );
608 inner_lock.maybe_flip_span_replaying_field();
609 drop(inner_lock);
610
611 let poll_future = get_async_result(Arc::clone(&self.inner), handle).map(|res| match res {
612 Ok(Value::Success(mut s)) => {
613 let t = T::deserialize(&mut s).map_err(|e| Error::deserialization("promise", e))?;
614 Ok(Ok(t))
615 }
616 Ok(Value::Failure(f)) => Ok(Err(f.into())),
617 Ok(v) => Err(ErrorInner::UnexpectedValueVariantForSyscall {
618 variant: <&'static str>::from(v),
619 syscall: "promise",
620 }
621 .into()),
622 Err(e) => Err(e),
623 });
624
625 DurableFutureImpl::new(self.clone(), handle, Either::Left(poll_future))
626 }
627
628 pub fn peek_promise<T: Deserialize>(
629 &self,
630 name: &str,
631 ) -> impl Future<Output = Result<Option<T>, TerminalError>> + Send {
632 let mut inner_lock = must_lock!(self.inner);
633 let handle = unwrap_or_trap!(inner_lock, inner_lock.vm.sys_peek_promise(name.to_owned()));
634 inner_lock.maybe_flip_span_replaying_field();
635 drop(inner_lock);
636
637 let poll_future = get_async_result(Arc::clone(&self.inner), handle).map(|res| match res {
638 Ok(Value::Void) => Ok(Ok(None)),
639 Ok(Value::Success(mut s)) => {
640 let t = T::deserialize(&mut s)
641 .map_err(|e| Error::deserialization("peek_promise", e))?;
642 Ok(Ok(Some(t)))
643 }
644 Ok(Value::Failure(f)) => Ok(Err(f.into())),
645 Ok(v) => Err(ErrorInner::UnexpectedValueVariantForSyscall {
646 variant: <&'static str>::from(v),
647 syscall: "peek_promise",
648 }
649 .into()),
650 Err(e) => Err(e),
651 });
652
653 Either::Left(InterceptErrorFuture::new(self.clone(), poll_future))
654 }
655
656 pub fn resolve_promise<T: Serialize>(&self, name: &str, t: T) {
657 let mut inner_lock = must_lock!(self.inner);
658 match t.serialize() {
659 Ok(b) => {
660 let _ = inner_lock
661 .vm
662 .sys_complete_promise(name.to_owned(), NonEmptyValue::Success(b));
663 }
664 Err(e) => {
665 inner_lock.fail(
666 ErrorInner::Serialization {
667 syscall: "resolve_promise",
668 err: Box::new(e),
669 }
670 .into(),
671 );
672 }
673 }
674 }
675
676 pub fn reject_promise(&self, id: &str, failure: TerminalError) {
677 let _ = must_lock!(self.inner)
678 .vm
679 .sys_complete_promise(id.to_owned(), NonEmptyValue::Failure(failure.into()));
680 }
681
682 pub fn run<'a, Run, Fut, Out>(
683 &'a self,
684 run_closure: Run,
685 ) -> impl RunFuture<Result<Out, TerminalError>> + Send + 'a
686 where
687 Run: RunClosure<Fut = Fut, Output = Out> + Send + 'a,
688 Fut: Future<Output = HandlerResult<Out>> + Send + 'a,
689 Out: Serialize + Deserialize + 'static,
690 {
691 let this = Arc::clone(&self.inner);
692 InterceptErrorFuture::new(self.clone(), RunFutureImpl::new(this, run_closure))
693 }
694
695 pub fn handle_handler_result<T: Serialize>(&self, res: HandlerResult<T>) {
697 let mut inner_lock = must_lock!(self.inner);
698
699 let res_to_write = match res {
700 Ok(success) => match T::serialize(&success) {
701 Ok(t) => NonEmptyValue::Success(t),
702 Err(e) => {
703 inner_lock.fail(
704 ErrorInner::Serialization {
705 syscall: "output",
706 err: Box::new(e),
707 }
708 .into(),
709 );
710 return;
711 }
712 },
713 Err(e) => match e.0 {
714 HandlerErrorInner::Retryable(err) => {
715 inner_lock.fail(ErrorInner::HandlerResult { err }.into());
716 return;
717 }
718 HandlerErrorInner::Terminal(t) => NonEmptyValue::Failure(TerminalError(t).into()),
719 },
720 };
721
722 let _ = inner_lock.vm.sys_write_output(res_to_write);
723 inner_lock.maybe_flip_span_replaying_field();
724 }
725
726 pub fn end(&self) {
727 let _ = must_lock!(self.inner).vm.sys_end();
728 }
729
730 pub(crate) fn consume_to_end(&self) {
731 let mut inner_lock = must_lock!(self.inner);
732
733 let out = inner_lock.vm.take_output();
734 if let TakeOutputResult::Buffer(b) = out {
735 if !inner_lock.write.send(b) {
736 }
738 }
739 }
740
741 pub(super) fn fail(&self, e: Error) {
742 must_lock!(self.inner).fail(e)
743 }
744}
745
pin_project! {
    // Future behind `ContextInternal::run`: a state machine (see `RunState`) that registers a
    // run with the VM, executes the user closure when the VM asks for it, and then waits for
    // the stored result.
    struct RunFutureImpl<Run, Ret, RunFnFut> {
        // Name passed to `sys_run`; empty unless overridden via `RunFuture::name`.
        name: String,
        // Policy proposed with the closure's result; `RetryPolicy::Infinite` unless overridden
        // via `RunFuture::retry_policy`.
        retry_policy: RetryPolicy,
        // Records `Ret` without storing one; `fn() -> Ret` keeps auto traits independent of `Ret`.
        phantom_data: PhantomData<fn() -> Ret>,
        #[pin]
        state: RunState<Run, RunFnFut, Ret>,
    }
}
754}
755
pin_project! {
    // States of `RunFutureImpl`.
    #[project = RunStateProj]
    enum RunState<Run, RunFnFut, Ret> {
        // Not polled yet. Both options are taken on the first poll.
        New {
            ctx: Option<Arc<Mutex<ContextInternalInner>>>,
            closure: Option<Run>,
        },
        // The user closure is executing; `start_time` feeds the retryable-failure attempt
        // duration and `ctx` is taken once the closure completes.
        ClosureRunning {
            ctx: Option<Arc<Mutex<ContextInternalInner>>>,
            handle: NotificationHandle,
            start_time: Instant,
            #[pin]
            closure_fut: RunFnFut,
        },
        // Waiting for the final result (VM-provided, or an immediate cancellation error).
        WaitingResultFut {
            result_fut: BoxFuture<'static, Result<Result<Ret, TerminalError>, Error>>
        }
    }
}
774}
775
776impl<Run, Ret, RunFnFut> RunFutureImpl<Run, Ret, RunFnFut> {
777 fn new(ctx: Arc<Mutex<ContextInternalInner>>, closure: Run) -> Self {
778 Self {
779 name: "".to_string(),
780 retry_policy: RetryPolicy::Infinite,
781 phantom_data: PhantomData,
782 state: RunState::New {
783 ctx: Some(ctx),
784 closure: Some(closure),
785 },
786 }
787 }
788
789 fn boxed_result_fut(
790 ctx: Arc<Mutex<ContextInternalInner>>,
791 handle: NotificationHandle,
792 ) -> BoxFuture<'static, Result<Result<Ret, TerminalError>, Error>>
793 where
794 Ret: Deserialize,
795 {
796 get_async_result(Arc::clone(&ctx), handle)
797 .map(|res| match res {
798 Ok(Value::Success(mut s)) => {
799 let t =
800 Ret::deserialize(&mut s).map_err(|e| Error::deserialization("run", e))?;
801 Ok(Ok(t))
802 }
803 Ok(Value::Failure(f)) => Ok(Err(f.into())),
804 Ok(v) => Err(ErrorInner::UnexpectedValueVariantForSyscall {
805 variant: <&'static str>::from(v),
806 syscall: "run",
807 }
808 .into()),
809 Err(e) => Err(e),
810 })
811 .boxed()
812 }
813}
814
815impl<Run, Ret, RunFnFut> RunFuture<Result<Result<Ret, TerminalError>, Error>>
816 for RunFutureImpl<Run, Ret, RunFnFut>
817where
818 Run: RunClosure<Fut = RunFnFut, Output = Ret> + Send,
819 Ret: Serialize + Deserialize,
820 RunFnFut: Future<Output = HandlerResult<Ret>> + Send,
821{
822 fn retry_policy(mut self, retry_policy: RunRetryPolicy) -> Self {
823 self.retry_policy = RetryPolicy::Exponential {
824 initial_interval: retry_policy.initial_delay,
825 factor: retry_policy.factor,
826 max_interval: retry_policy.max_delay,
827 max_attempts: retry_policy.max_attempts,
828 max_duration: retry_policy.max_duration,
829 };
830 self
831 }
832
833 fn name(mut self, name: impl Into<String>) -> Self {
834 self.name = name.into();
835 self
836 }
837}
838
impl<Run, Ret, RunFnFut> Future for RunFutureImpl<Run, Ret, RunFnFut>
where
    Run: RunClosure<Fut = RunFnFut, Output = Ret> + Send,
    Ret: Serialize + Deserialize,
    RunFnFut: Future<Output = HandlerResult<Ret>> + Send,
{
    type Output = Result<Result<Ret, TerminalError>, Error>;

    /// Drives the run state machine.
    ///
    /// * `New`: registers the run (`sys_run`) and asks `do_progress` what to do next. It then
    ///   either starts the closure, resolves to a 409 "cancelled" terminal error, or waits
    ///   for the VM-provided result.
    /// * `ClosureRunning`: polls the closure. On completion it proposes the outcome (success,
    ///   retryable or terminal failure) via `propose_run_completion`, then waits on `handle`.
    /// * `WaitingResultFut`: polls the boxed result future.
    ///
    /// Errors from `sys_run` or from serializing the closure's output are returned
    /// immediately as `Err`.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let mut this = self.project();

        // Loop so that every state transition is immediately followed by polling the new state.
        loop {
            match this.state.as_mut().project() {
                RunStateProj::New { ctx, closure, .. } => {
                    let ctx = ctx
                        .take()
                        .expect("Future should not be polled after returning Poll::Ready");
                    let closure = closure
                        .take()
                        .expect("Future should not be polled after returning Poll::Ready");
                    let mut inner_ctx = must_lock!(ctx);

                    let handle = inner_ctx
                        .vm
                        .sys_run(this.name.to_owned())
                        .map_err(ErrorInner::from)?;

                    // Ask the VM whether this run must actually be executed now.
                    match inner_ctx.vm.do_progress(vec![handle]) {
                        Ok(DoProgressResponse::ExecuteRun(handle_to_run)) => {
                            // ExecuteRun must refer to the handle just registered; it means the
                            // closure has to be executed.
                            assert_eq!(handle, handle_to_run);

                            // Release the lock before the closure starts using the context.
                            drop(inner_ctx);
                            this.state.set(RunState::ClosureRunning {
                                ctx: Some(ctx),
                                handle,
                                start_time: Instant::now(),
                                closure_fut: closure.run(),
                            });
                        }
                        Ok(DoProgressResponse::CancelSignalReceived) => {
                            drop(inner_ctx);
                            // Cancellation received: resolve to a 409 terminal error without
                            // running the closure.
                            this.state.set(RunState::WaitingResultFut {
                                result_fut: async {
                                    Ok(Err(TerminalError::from(TerminalFailure {
                                        code: 409,
                                        message: "cancelled".to_string(),
                                    })))
                                }
                                .boxed(),
                            })
                        }
                        _ => {
                            drop(inner_ctx);
                            // Every other outcome (including `do_progress` errors): just wait
                            // on the result future, which is presumably responsible for
                            // surfacing suspension or errors.
                            this.state.set(RunState::WaitingResultFut {
                                result_fut: Self::boxed_result_fut(Arc::clone(&ctx), handle),
                            })
                        }
                    }
                }
                RunStateProj::ClosureRunning {
                    ctx,
                    handle,
                    start_time,
                    closure_fut,
                } => {
                    // Map the closure's outcome to what gets proposed to the VM.
                    let res = match ready!(closure_fut.poll(cx)) {
                        Ok(t) => RunExitResult::Success(Ret::serialize(&t).map_err(|e| {
                            ErrorInner::Serialization {
                                syscall: "run",
                                err: Box::new(e),
                            }
                        })?),
                        Err(e) => match e.0 {
                            HandlerErrorInner::Retryable(err) => RunExitResult::RetryableFailure {
                                attempt_duration: start_time.elapsed(),
                                error: CoreError::new(500u16, err.to_string()),
                            },
                            HandlerErrorInner::Terminal(t) => {
                                RunExitResult::TerminalFailure(TerminalError(t).into())
                            }
                        },
                    };

                    let ctx = ctx
                        .take()
                        .expect("Future should not be polled after returning Poll::Ready");
                    let handle = *handle;

                    // The retry policy is moved out; it is only needed for this proposal.
                    let _ = {
                        must_lock!(ctx).vm.propose_run_completion(
                            handle,
                            res,
                            mem::take(this.retry_policy),
                        )
                    };

                    this.state.set(RunState::WaitingResultFut {
                        result_fut: Self::boxed_result_fut(Arc::clone(&ctx), handle),
                    });
                }
                RunStateProj::WaitingResultFut { result_fut } => return result_fut.poll_unpin(cx),
            }
        }
    }
}
951
pin_project! {
    // Future returned by `ContextInternal::call`: resolves to the call result and can also
    // provide the callee's invocation id.
    struct CallFutureImpl<InvIdFut: Future, ResultFut> {
        // Shared so the invocation id can be awaited by several consumers (id and cancel).
        #[pin]
        invocation_id_future: Shared<InvIdFut>,
        #[pin]
        result_future: ResultFut,
        // Handle of the call result; `u32::MAX` placeholder when the call failed to start.
        call_notification_handle: NotificationHandle,
        ctx: ContextInternal,
    }
}
962
963impl<InvIdFut, ResultFut, Res> Future for CallFutureImpl<InvIdFut, ResultFut>
964where
965 InvIdFut: Future<Output = Result<String, TerminalError>> + Send,
966 ResultFut: Future<Output = Result<Result<Res, TerminalError>, Error>> + Send,
967{
968 type Output = Result<Res, TerminalError>;
969
970 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
971 let this = self.project();
972 let result = ready!(this.result_future.poll(cx));
973
974 match result {
975 Ok(r) => Poll::Ready(r),
976 Err(e) => {
977 this.ctx.fail(e);
978
979 cx.waker().wake_by_ref();
982 Poll::Pending
983 }
984 }
985 }
986}
987
988impl<InvIdFut, ResultFut> InvocationHandle for CallFutureImpl<InvIdFut, ResultFut>
989where
990 InvIdFut: Future<Output = Result<String, TerminalError>> + Send,
991{
992 fn invocation_id(&self) -> impl Future<Output = Result<String, TerminalError>> + Send {
993 Shared::clone(&self.invocation_id_future)
994 }
995
996 fn cancel(&self) -> impl Future<Output = Result<(), TerminalError>> + Send {
997 let cloned_invocation_id_fut = Shared::clone(&self.invocation_id_future);
998 let cloned_ctx = Arc::clone(&self.ctx.inner);
999 async move {
1000 let inv_id = cloned_invocation_id_fut.await?;
1001 let mut inner_lock = must_lock!(cloned_ctx);
1002 let _ = inner_lock.vm.sys_cancel_invocation(inv_id);
1003 inner_lock.maybe_flip_span_replaying_field();
1004 drop(inner_lock);
1005 Ok(())
1006 }
1007 }
1008}
1009
1010impl<InvIdFut, ResultFut, Res> CallFuture for CallFutureImpl<InvIdFut, ResultFut>
1011where
1012 InvIdFut: Future<Output = Result<String, TerminalError>> + Send,
1013 ResultFut: Future<Output = Result<Result<Res, TerminalError>, Error>> + Send,
1014{
1015 type Response = Res;
1016}
1017
1018impl<InvIdFut, ResultFut> crate::context::macro_support::SealedDurableFuture
1019 for CallFutureImpl<InvIdFut, ResultFut>
1020where
1021 InvIdFut: Future,
1022{
1023 fn inner_context(&self) -> ContextInternal {
1024 self.ctx.clone()
1025 }
1026
1027 fn handle(&self) -> NotificationHandle {
1028 self.call_notification_handle
1029 }
1030}
1031
1032impl<InvIdFut, ResultFut, Res> DurableFuture for CallFutureImpl<InvIdFut, ResultFut>
1033where
1034 InvIdFut: Future<Output = Result<String, TerminalError>> + Send,
1035 ResultFut: Future<Output = Result<Result<Res, TerminalError>, Error>> + Send,
1036{
1037}
1038
1039struct SendRequestHandle<InvIdFut: Future> {
1040 invocation_id_future: Shared<InvIdFut>,
1041 ctx: ContextInternal,
1042}
1043
1044impl<InvIdFut: Future<Output = Result<String, TerminalError>> + Send> InvocationHandle
1045 for SendRequestHandle<InvIdFut>
1046{
1047 fn invocation_id(&self) -> impl Future<Output = Result<String, TerminalError>> + Send {
1048 Shared::clone(&self.invocation_id_future)
1049 }
1050
1051 fn cancel(&self) -> impl Future<Output = Result<(), TerminalError>> + Send {
1052 let cloned_invocation_id_fut = Shared::clone(&self.invocation_id_future);
1053 let cloned_ctx = Arc::clone(&self.ctx.inner);
1054 async move {
1055 let inv_id = cloned_invocation_id_fut.await?;
1056 let mut inner_lock = must_lock!(cloned_ctx);
1057 let _ = inner_lock.vm.sys_cancel_invocation(inv_id);
1058 inner_lock.maybe_flip_span_replaying_field();
1059 drop(inner_lock);
1060 Ok(())
1061 }
1062 }
1063}
1064
1065struct InvocationIdBackedInvocationHandle {
1066 ctx: ContextInternal,
1067 invocation_id: String,
1068}
1069
1070impl InvocationHandle for InvocationIdBackedInvocationHandle {
1071 fn invocation_id(&self) -> impl Future<Output = Result<String, TerminalError>> + Send {
1072 ready(Ok(self.invocation_id.clone()))
1073 }
1074
1075 fn cancel(&self) -> impl Future<Output = Result<(), TerminalError>> + Send {
1076 let mut inner_lock = must_lock!(self.ctx.inner);
1077 let _ = inner_lock
1078 .vm
1079 .sys_cancel_invocation(self.invocation_id.clone());
1080 ready(Ok(()))
1081 }
1082}
1083
1084impl<A, B> InvocationHandle for Either<A, B>
1085where
1086 A: InvocationHandle,
1087 B: InvocationHandle,
1088{
1089 fn invocation_id(&self) -> impl Future<Output = Result<String, TerminalError>> + Send {
1090 match self {
1091 Either::Left(l) => Either::Left(l.invocation_id()),
1092 Either::Right(r) => Either::Right(r.invocation_id()),
1093 }
1094 }
1095
1096 fn cancel(&self) -> impl Future<Output = Result<(), TerminalError>> + Send {
1097 match self {
1098 Either::Left(l) => Either::Left(l.cancel()),
1099 Either::Right(r) => Either::Right(r.cancel()),
1100 }
1101 }
1102}
1103
1104impl Error {
1105 fn serialization<E: std::error::Error + Send + Sync + 'static>(
1106 syscall: &'static str,
1107 e: E,
1108 ) -> Self {
1109 ErrorInner::Serialization {
1110 syscall,
1111 err: Box::new(e),
1112 }
1113 .into()
1114 }
1115
1116 fn deserialization<E: std::error::Error + Send + Sync + 'static>(
1117 syscall: &'static str,
1118 e: E,
1119 ) -> Self {
1120 ErrorInner::Deserialization {
1121 syscall,
1122 err: Box::new(e),
1123 }
1124 .into()
1125 }
1126}
1127
1128fn get_async_result(
1129 ctx: Arc<Mutex<ContextInternalInner>>,
1130 handle: NotificationHandle,
1131) -> impl Future<Output = Result<Value, Error>> + Send {
1132 VmAsyncResultPollFuture::new(ctx, handle).map_err(Error::from)
1133}