restate_sdk_shared_core/vm/mod.rs

1use crate::headers::HeaderMap;
2use crate::service_protocol::messages::{
3    attach_invocation_command_message, complete_awakeable_command_message,
4    complete_promise_command_message, get_invocation_output_command_message,
5    output_command_message, send_signal_command_message, AttachInvocationCommandMessage,
6    CallCommandMessage, ClearAllStateCommandMessage, ClearStateCommandMessage,
7    CompleteAwakeableCommandMessage, CompletePromiseCommandMessage,
8    GetInvocationOutputCommandMessage, GetPromiseCommandMessage, IdempotentRequestTarget,
9    OneWayCallCommandMessage, OutputCommandMessage, PeekPromiseCommandMessage,
10    SendSignalCommandMessage, SetStateCommandMessage, SleepCommandMessage, WorkflowTarget,
11};
12use crate::service_protocol::{Decoder, NotificationId, RawMessage, Version, CANCEL_SIGNAL_ID};
13use crate::vm::errors::{
14    ClosedError, UnexpectedStateError, UnsupportedFeatureForNegotiatedVersion,
15    EMPTY_IDEMPOTENCY_KEY, SUSPENDED,
16};
17use crate::vm::transitions::*;
18use crate::{
19    AttachInvocationTarget, CallHandle, CommandRelationship, DoProgressResponse, Error, Header,
20    ImplicitCancellationOption, Input, NonDeterministicChecksOption, NonEmptyValue,
21    NotificationHandle, ResponseHead, RetryPolicy, RunExitResult, SendHandle, TakeOutputResult,
22    Target, TerminalFailure, VMOptions, VMResult, Value, CANCEL_NOTIFICATION_HANDLE,
23};
24use base64::engine::{DecodePaddingMode, GeneralPurpose, GeneralPurposeConfig};
25use base64::{alphabet, Engine};
26use bytes::{Buf, BufMut, Bytes, BytesMut};
27use context::{AsyncResultsState, Context, Output, RunState};
28use std::borrow::Cow;
29use std::collections::{HashMap, VecDeque};
30use std::mem::size_of;
31use std::time::Duration;
32use std::{fmt, mem};
33use strum::IntoStaticStr;
34use tracing::{debug, enabled, instrument, Level};
35
36mod context;
37pub(crate) mod errors;
38mod transitions;
39
/// Name of the header whose value `new` parses into the protocol [`Version`] and which
/// `get_response_head` sends back with the negotiated version's content type.
const CONTENT_TYPE: &str = "content-type";
41
42#[derive(Debug, IntoStaticStr)]
43pub(crate) enum State {
44    WaitingStart,
45    WaitingReplayEntries {
46        received_entries: u32,
47        commands: VecDeque<RawMessage>,
48        async_results: AsyncResultsState,
49    },
50    Replaying {
51        commands: VecDeque<RawMessage>,
52        run_state: RunState,
53        async_results: AsyncResultsState,
54    },
55    Processing {
56        processing_first_entry: bool,
57        run_state: RunState,
58        async_results: AsyncResultsState,
59    },
60    Closed,
61}
62
63impl State {
64    fn as_unexpected_state(&self, event: &'static str) -> Error {
65        if matches!(self, State::Closed) {
66            return ClosedError::new(event).into();
67        }
68        UnexpectedStateError::new(self.into(), event).into()
69    }
70}
71
72struct TrackedInvocationId {
73    handle: NotificationHandle,
74    invocation_id: Option<String>,
75}
76
77impl TrackedInvocationId {
78    fn is_resolved(&self) -> bool {
79        self.invocation_id.is_some()
80    }
81}
82
83pub struct CoreVM {
84    options: VMOptions,
85
86    // Input decoder
87    decoder: Decoder,
88
89    // State machine
90    context: Context,
91    last_transition: Result<State, Error>,
92
93    // Implicit cancellation tracking
94    tracked_invocation_ids: Vec<TrackedInvocationId>,
95
96    // Run names, useful for debugging
97    sys_run_names: HashMap<NotificationHandle, String>,
98}
99
100impl CoreVM {
101    // Returns empty string if the invocation id is not present
102    fn debug_invocation_id(&self) -> &str {
103        if let Some(start_info) = self.context.start_info() {
104            &start_info.debug_id
105        } else {
106            ""
107        }
108    }
109
110    fn debug_state(&self) -> &'static str {
111        match &self.last_transition {
112            Ok(s) => s.into(),
113            Err(_) => "Failed",
114        }
115    }
116
117    #[allow(dead_code)]
118    fn verify_feature_support(
119        &mut self,
120        feature: &'static str,
121        minimum_required_protocol: Version,
122    ) -> VMResult<()> {
123        if self.context.negotiated_protocol_version < minimum_required_protocol {
124            return self.do_transition(HitError(
125                UnsupportedFeatureForNegotiatedVersion::new(
126                    feature,
127                    self.context.negotiated_protocol_version,
128                    minimum_required_protocol,
129                )
130                .into(),
131            ));
132        }
133        Ok(())
134    }
135
136    fn _is_completed(&self, handle: NotificationHandle) -> bool {
137        match &self.last_transition {
138            Ok(State::Replaying { async_results, .. })
139            | Ok(State::Processing { async_results, .. }) => {
140                async_results.is_handle_completed(handle)
141            }
142            _ => false,
143        }
144    }
145
146    fn _do_progress(
147        &mut self,
148        any_handle: Vec<NotificationHandle>,
149    ) -> Result<DoProgressResponse, Error> {
150        match self.do_transition(DoProgress(any_handle)) {
151            Ok(Ok(do_progress_response)) => Ok(do_progress_response),
152            Ok(Err(_)) => Err(SUSPENDED),
153            Err(e) => Err(e),
154        }
155    }
156
157    fn is_implicit_cancellation_enabled(&self) -> bool {
158        matches!(
159            self.options.implicit_cancellation,
160            ImplicitCancellationOption::Enabled { .. }
161        )
162    }
163}
164
165impl fmt::Debug for CoreVM {
166    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
167        let mut s = f.debug_struct("CoreVM");
168        s.field("version", &self.context.negotiated_protocol_version);
169
170        if let Some(start_info) = self.context.start_info() {
171            s.field("invocation_id", &start_info.debug_id);
172        }
173
174        match &self.last_transition {
175            Ok(state) => s.field("last_transition", &<&'static str>::from(state)),
176            Err(_) => s.field("last_transition", &"Errored"),
177        };
178
179        s.field("command_index", &self.context.journal.command_index())
180            .field(
181                "notification_index",
182                &self.context.journal.notification_index(),
183            )
184            .finish()
185    }
186}
187
188// --- Bound checks
189#[allow(unused)]
190const fn is_send<T: Send>() {}
191const _: () = is_send::<CoreVM>();
192
// Macro used for informative debug logs.
//
// `invocation_debug_logs!(vm, fmt, args...)` forwards `fmt, args...` to `tracing::debug!`, but
// only when `vm.is_processing()` returns true (presumably so that replayed commands are not
// logged again — confirm against `is_processing`).
macro_rules! invocation_debug_logs {
    ($this:expr, $($arg:tt)*) => {
        // No parentheses around the condition: they are redundant and trigger `unused_parens`.
        if $this.is_processing() {
            tracing::debug!($($arg)*)
        }
    };
}
201
202impl super::VM for CoreVM {
203    #[instrument(level = "trace", skip(request_headers), ret)]
204    fn new(request_headers: impl HeaderMap, options: VMOptions) -> Result<Self, Error> {
205        let version = request_headers
206            .extract(CONTENT_TYPE)
207            .map_err(|e| {
208                Error::new(
209                    errors::codes::BAD_REQUEST,
210                    format!("cannot read '{CONTENT_TYPE}' header: {e:?}"),
211                )
212            })?
213            .ok_or(errors::MISSING_CONTENT_TYPE)?
214            .parse::<Version>()?;
215
216        if version < Version::minimum_supported_version()
217            || version > Version::maximum_supported_version()
218        {
219            return Err(Error::new(
220                errors::codes::UNSUPPORTED_MEDIA_TYPE,
221                format!(
222                    "Unsupported protocol version {:?}, not within [{:?} to {:?}]. \
223                    You might need to rediscover the service, check https://docs.restate.dev/references/errors/#RT0015",
224                    version,
225                    Version::minimum_supported_version(),
226                    Version::maximum_supported_version()
227                ),
228            ));
229        }
230        let non_deterministic_checks_ignore_payload_equality = matches!(
231            options.non_determinism_checks,
232            NonDeterministicChecksOption::PayloadChecksDisabled
233        );
234
235        Ok(Self {
236            options,
237            decoder: Decoder::new(version),
238            context: Context {
239                input_is_closed: false,
240                output: Output::new(version),
241                start_info: None,
242                journal: Default::default(),
243                eager_state: Default::default(),
244                non_deterministic_checks_ignore_payload_equality,
245                negotiated_protocol_version: version,
246            },
247            last_transition: Ok(State::WaitingStart),
248            tracked_invocation_ids: vec![],
249            sys_run_names: HashMap::with_capacity(0),
250        })
251    }
252
253    #[instrument(
254        level = "trace",
255        skip(self),
256        fields(
257            restate.invocation.id = self.debug_invocation_id(),
258            restate.protocol.state = self.debug_state(),
259            restate.journal.command_index = self.context.journal.command_index(),
260            restate.protocol.version = %self.context.negotiated_protocol_version
261        ),
262        ret
263    )]
264    fn get_response_head(&self) -> ResponseHead {
265        ResponseHead {
266            status_code: 200,
267            headers: vec![Header {
268                key: Cow::Borrowed(CONTENT_TYPE),
269                value: Cow::Borrowed(self.context.negotiated_protocol_version.content_type()),
270            }],
271            version: self.context.negotiated_protocol_version,
272        }
273    }
274
275    #[instrument(
276        level = "trace",
277        skip(self),
278        fields(
279            restate.invocation.id = self.debug_invocation_id(),
280            restate.protocol.state = self.debug_state(),
281            restate.journal.command_index = self.context.journal.command_index(),
282            restate.protocol.version = %self.context.negotiated_protocol_version
283        ),
284        ret
285    )]
286    fn notify_input(&mut self, buffer: Bytes) {
287        self.decoder.push(buffer);
288        loop {
289            match self.decoder.consume_next() {
290                Ok(Some(msg)) => {
291                    if self.do_transition(NewMessage(msg)).is_err() {
292                        return;
293                    }
294                }
295                Ok(None) => {
296                    return;
297                }
298                Err(e) => {
299                    if self.do_transition(HitError(e.into())).is_err() {
300                        return;
301                    }
302                }
303            }
304        }
305    }
306
307    #[instrument(
308        level = "trace",
309        skip(self),
310        fields(
311            restate.invocation.id = self.debug_invocation_id(),
312            restate.protocol.state = self.debug_state(),
313            restate.journal.command_index = self.context.journal.command_index(),
314            restate.protocol.version = %self.context.negotiated_protocol_version
315        ),
316        ret
317    )]
318    fn notify_input_closed(&mut self) {
319        self.context.input_is_closed = true;
320        let _ = self.do_transition(NotifyInputClosed);
321    }
322
323    #[instrument(
324        level = "trace",
325        skip(self),
326        fields(
327            restate.invocation.id = self.debug_invocation_id(),
328            restate.protocol.state = self.debug_state(),
329            restate.journal.command_index = self.context.journal.command_index(),
330            restate.protocol.version = %self.context.negotiated_protocol_version
331        ),
332        ret
333    )]
334    fn notify_error(
335        &mut self,
336        mut error: Error,
337        command_relationship: Option<CommandRelationship>,
338    ) {
339        if let Some(command_relationship) = command_relationship {
340            error = error.with_related_command_metadata(
341                self.context
342                    .journal
343                    .resolve_related_command(command_relationship),
344            );
345        }
346
347        let _ = self.do_transition(HitError(error));
348    }
349
350    #[instrument(
351        level = "trace",
352        skip(self),
353        fields(
354            restate.invocation.id = self.debug_invocation_id(),
355            restate.protocol.state = self.debug_state(),
356            restate.journal.command_index = self.context.journal.command_index(),
357            restate.protocol.version = %self.context.negotiated_protocol_version
358        ),
359        ret
360    )]
361    fn take_output(&mut self) -> TakeOutputResult {
362        if self.context.output.buffer.has_remaining() {
363            TakeOutputResult::Buffer(
364                self.context
365                    .output
366                    .buffer
367                    .copy_to_bytes(self.context.output.buffer.remaining()),
368            )
369        } else if !self.context.output.is_closed() {
370            TakeOutputResult::Buffer(Bytes::default())
371        } else {
372            TakeOutputResult::EOF
373        }
374    }
375
376    #[instrument(
377        level = "trace",
378        skip(self),
379        fields(
380            restate.invocation.id = self.debug_invocation_id(),
381            restate.protocol.state = self.debug_state(),
382            restate.journal.command_index = self.context.journal.command_index(),
383            restate.protocol.version = %self.context.negotiated_protocol_version
384        ),
385        ret
386    )]
387    fn is_ready_to_execute(&self) -> Result<bool, Error> {
388        match &self.last_transition {
389            Ok(State::WaitingStart) | Ok(State::WaitingReplayEntries { .. }) => Ok(false),
390            Ok(State::Processing { .. }) | Ok(State::Replaying { .. }) => Ok(true),
391            Ok(s) => Err(s.as_unexpected_state("IsReadyToExecute")),
392            Err(e) => Err(e.clone()),
393        }
394    }
395
396    #[instrument(
397        level = "trace",
398        skip(self),
399        fields(
400            restate.invocation.id = self.debug_invocation_id(),
401            restate.protocol.state = self.debug_state(),
402            restate.journal.command_index = self.context.journal.command_index(),
403            restate.protocol.version = %self.context.negotiated_protocol_version
404        ),
405        ret
406    )]
407    fn is_completed(&self, handle: NotificationHandle) -> bool {
408        self._is_completed(handle)
409    }
410
411    #[instrument(
412        level = "trace",
413        skip(self),
414        fields(
415            restate.invocation.id = self.debug_invocation_id(),
416            restate.protocol.state = self.debug_state(),
417            restate.journal.command_index = self.context.journal.command_index(),
418            restate.protocol.version = %self.context.negotiated_protocol_version
419        ),
420        ret
421    )]
422    fn do_progress(
423        &mut self,
424        mut any_handle: Vec<NotificationHandle>,
425    ) -> VMResult<DoProgressResponse> {
426        if self.is_implicit_cancellation_enabled() {
427            // We want the runtime to wake us up in case cancel notification comes in.
428            any_handle.insert(0, CANCEL_NOTIFICATION_HANDLE);
429
430            match self._do_progress(any_handle) {
431                Ok(DoProgressResponse::AnyCompleted) => {
432                    // If it's cancel signal, then let's go on with the cancellation logic
433                    if self._is_completed(CANCEL_NOTIFICATION_HANDLE) {
434                        // Loop once over the tracked invocation ids to resolve the unresolved ones
435                        for i in 0..self.tracked_invocation_ids.len() {
436                            if self.tracked_invocation_ids[i].is_resolved() {
437                                continue;
438                            }
439
440                            let handle = self.tracked_invocation_ids[i].handle;
441
442                            // Try to resolve it
443                            match self._do_progress(vec![handle]) {
444                                Ok(DoProgressResponse::AnyCompleted) => {
445                                    let invocation_id = match self.do_transition(CopyNotification(handle)) {
446                                        Ok(Ok(Some(Value::InvocationId(invocation_id)))) => Ok(invocation_id),
447                                        Ok(Err(_)) => Err(SUSPENDED),
448                                        _ => panic!("Unexpected variant! If the id handle is completed, it must be an invocation id handle!")
449                                    }?;
450
451                                    // This handle is resolved
452                                    self.tracked_invocation_ids[i].invocation_id =
453                                        Some(invocation_id);
454                                }
455                                res => return res,
456                            }
457                        }
458
459                        // Now we got all the invocation IDs, let's cancel!
460                        for tracked_invocation_id in mem::take(&mut self.tracked_invocation_ids) {
461                            self.sys_cancel_invocation(
462                                tracked_invocation_id
463                                    .invocation_id
464                                    .expect("We resolved before all the invocation ids"),
465                            )?;
466                        }
467
468                        // Flip the cancellation
469                        let _ = self.take_notification(CANCEL_NOTIFICATION_HANDLE);
470
471                        // Done
472                        Ok(DoProgressResponse::CancelSignalReceived)
473                    } else {
474                        Ok(DoProgressResponse::AnyCompleted)
475                    }
476                }
477                res => res,
478            }
479        } else {
480            self._do_progress(any_handle)
481        }
482    }
483
484    #[instrument(
485        level = "trace",
486        skip(self),
487        fields(
488            restate.invocation.id = self.debug_invocation_id(),
489            restate.protocol.state = self.debug_state(),
490            restate.journal.command_index = self.context.journal.command_index(),
491            restate.protocol.version = %self.context.negotiated_protocol_version
492        ),
493        ret
494    )]
495    fn take_notification(&mut self, handle: NotificationHandle) -> VMResult<Option<Value>> {
496        match self.do_transition(TakeNotification(handle)) {
497            Ok(Ok(Some(value))) => {
498                if self.is_implicit_cancellation_enabled() {
499                    // Let's check if that's one of the tracked invocation ids
500                    // We can do binary search here because we assume tracked_invocation_ids is ordered, as handles are incremental numbers
501                    if let Ok(found) = self
502                        .tracked_invocation_ids
503                        .binary_search_by(|tracked| tracked.handle.cmp(&handle))
504                    {
505                        let Value::InvocationId(invocation_id) = &value else {
506                            panic!("Expecting an invocation id here, but got {value:?}");
507                        };
508                        // Keep track of this invocation id
509                        self.tracked_invocation_ids
510                            .get_mut(found)
511                            .unwrap()
512                            .invocation_id = Some(invocation_id.clone());
513                    }
514                }
515
516                Ok(Some(value))
517            }
518            Ok(Ok(None)) => Ok(None),
519            Ok(Err(_)) => Err(SUSPENDED),
520            Err(e) => Err(e),
521        }
522    }
523
524    #[instrument(
525        level = "trace",
526        skip(self),
527        fields(
528            restate.invocation.id = self.debug_invocation_id(),
529            restate.protocol.state = self.debug_state(),
530            restate.journal.command_index = self.context.journal.command_index(),
531            restate.protocol.version = %self.context.negotiated_protocol_version
532        ),
533        ret
534    )]
535    fn sys_input(&mut self) -> Result<Input, Error> {
536        self.do_transition(SysInput)
537    }
538
539    #[instrument(
540        level = "trace",
541        skip(self),
542        fields(
543            restate.invocation.id = self.debug_invocation_id(),
544            restate.protocol.state = self.debug_state(),
545            restate.journal.command_index = self.context.journal.command_index(),
546            restate.protocol.version = %self.context.negotiated_protocol_version
547        ),
548        ret
549    )]
550    fn sys_state_get(&mut self, key: String) -> Result<NotificationHandle, Error> {
551        invocation_debug_logs!(self, "Executing 'Get state {key}'");
552        self.do_transition(SysStateGet(key))
553    }
554
555    #[instrument(
556        level = "trace",
557        skip(self),
558        fields(
559            restate.invocation.id = self.debug_invocation_id(),
560            restate.protocol.state = self.debug_state(),
561            restate.journal.command_index = self.context.journal.command_index(),
562            restate.protocol.version = %self.context.negotiated_protocol_version
563        ),
564        ret
565    )]
566    fn sys_state_get_keys(&mut self) -> VMResult<NotificationHandle> {
567        invocation_debug_logs!(self, "Executing 'Get state keys'");
568        self.do_transition(SysStateGetKeys)
569    }
570
571    #[instrument(
572        level = "trace",
573        skip(self, value),
574        fields(
575            restate.invocation.id = self.debug_invocation_id(),
576            restate.protocol.state = self.debug_state(),
577            restate.journal.command_index = self.context.journal.command_index(),
578            restate.protocol.version = %self.context.negotiated_protocol_version
579        ),
580        ret
581    )]
582    fn sys_state_set(&mut self, key: String, value: Bytes) -> Result<(), Error> {
583        invocation_debug_logs!(self, "Executing 'Set state {key}'");
584        self.context.eager_state.set(key.clone(), value.clone());
585        self.do_transition(SysNonCompletableEntry(
586            "SysStateSet",
587            SetStateCommandMessage {
588                key: Bytes::from(key.into_bytes()),
589                value: Some(value.into()),
590                ..SetStateCommandMessage::default()
591            },
592        ))
593    }
594
595    #[instrument(
596        level = "trace",
597        skip(self),
598        fields(
599            restate.invocation.id = self.debug_invocation_id(),
600            restate.protocol.state = self.debug_state(),
601            restate.journal.command_index = self.context.journal.command_index(),
602            restate.protocol.version = %self.context.negotiated_protocol_version
603        ),
604        ret
605    )]
606    fn sys_state_clear(&mut self, key: String) -> Result<(), Error> {
607        invocation_debug_logs!(self, "Executing 'Clear state {key}'");
608        self.context.eager_state.clear(key.clone());
609        self.do_transition(SysNonCompletableEntry(
610            "SysStateClear",
611            ClearStateCommandMessage {
612                key: Bytes::from(key.into_bytes()),
613                ..ClearStateCommandMessage::default()
614            },
615        ))
616    }
617
618    #[instrument(
619        level = "trace",
620        skip(self),
621        fields(
622            restate.invocation.id = self.debug_invocation_id(),
623            restate.protocol.state = self.debug_state(),
624            restate.journal.command_index = self.context.journal.command_index(),
625            restate.protocol.version = %self.context.negotiated_protocol_version
626        ),
627        ret
628    )]
629    fn sys_state_clear_all(&mut self) -> Result<(), Error> {
630        invocation_debug_logs!(self, "Executing 'Clear all state'");
631        self.context.eager_state.clear_all();
632        self.do_transition(SysNonCompletableEntry(
633            "SysStateClearAll",
634            ClearAllStateCommandMessage::default(),
635        ))
636    }
637
638    #[instrument(
639        level = "trace",
640        skip(self),
641        fields(
642            restate.invocation.id = self.debug_invocation_id(),
643            restate.protocol.state = self.debug_state(),
644            restate.journal.command_index = self.context.journal.command_index(),
645            restate.protocol.version = %self.context.negotiated_protocol_version
646        ),
647        ret
648    )]
649    fn sys_sleep(
650        &mut self,
651        name: String,
652        wake_up_time_since_unix_epoch: Duration,
653        now_since_unix_epoch: Option<Duration>,
654    ) -> VMResult<NotificationHandle> {
655        if self.is_processing() {
656            match (&name, now_since_unix_epoch) {
657                (name, Some(now_since_unix_epoch)) if name.is_empty() => {
658                    debug!(
659                        "Executing 'Timer with duration {:?}'",
660                        wake_up_time_since_unix_epoch - now_since_unix_epoch
661                    );
662                }
663                (name, Some(now_since_unix_epoch)) => {
664                    debug!(
665                        "Executing 'Timer {name} with duration {:?}'",
666                        wake_up_time_since_unix_epoch - now_since_unix_epoch
667                    );
668                }
669                (name, None) if name.is_empty() => {
670                    debug!("Executing 'Timer'");
671                }
672                (name, None) => {
673                    debug!("Executing 'Timer named {name}'");
674                }
675            }
676        }
677
678        let completion_id = self.context.journal.next_completion_notification_id();
679
680        self.do_transition(SysSimpleCompletableEntry(
681            "SysSleep",
682            SleepCommandMessage {
683                wake_up_time: u64::try_from(wake_up_time_since_unix_epoch.as_millis())
684                    .expect("millis since Unix epoch should fit in u64"),
685                result_completion_id: completion_id,
686                name,
687            },
688            completion_id,
689        ))
690    }
691
692    #[instrument(
693        level = "trace",
694        skip(self, input),
695        fields(
696            restate.invocation.id = self.debug_invocation_id(),
697            restate.protocol.state = self.debug_state(),
698            restate.journal.command_index = self.context.journal.command_index(),
699            restate.protocol.version = %self.context.negotiated_protocol_version
700        ),
701        ret
702    )]
703    fn sys_call(&mut self, target: Target, input: Bytes) -> VMResult<CallHandle> {
704        invocation_debug_logs!(
705            self,
706            "Executing 'Call {}/{}'",
707            target.service,
708            target.handler
709        );
710        if let Some(idempotency_key) = &target.idempotency_key {
711            if idempotency_key.is_empty() {
712                self.do_transition(HitError(EMPTY_IDEMPOTENCY_KEY))?;
713                unreachable!();
714            }
715        }
716
717        let call_invocation_id_completion_id =
718            self.context.journal.next_completion_notification_id();
719        let result_completion_id = self.context.journal.next_completion_notification_id();
720
721        let handles = self.do_transition(SysCompletableEntryWithMultipleCompletions(
722            "SysCall",
723            CallCommandMessage {
724                service_name: target.service,
725                handler_name: target.handler,
726                key: target.key.unwrap_or_default(),
727                idempotency_key: target.idempotency_key,
728                headers: target
729                    .headers
730                    .into_iter()
731                    .map(crate::service_protocol::messages::Header::from)
732                    .collect(),
733                parameter: input,
734                invocation_id_notification_idx: call_invocation_id_completion_id,
735                result_completion_id,
736                ..Default::default()
737            },
738            vec![call_invocation_id_completion_id, result_completion_id],
739        ))?;
740
741        if matches!(
742            self.options.implicit_cancellation,
743            ImplicitCancellationOption::Enabled {
744                cancel_children_calls: true,
745                ..
746            }
747        ) {
748            self.tracked_invocation_ids.push(TrackedInvocationId {
749                handle: handles[0],
750                invocation_id: None,
751            })
752        }
753
754        Ok(CallHandle {
755            invocation_id_notification_handle: handles[0],
756            call_notification_handle: handles[1],
757        })
758    }
759
760    #[instrument(
761        level = "trace",
762        skip(self, input),
763        fields(
764            restate.invocation.id = self.debug_invocation_id(),
765            restate.protocol.state = self.debug_state(),
766            restate.journal.command_index = self.context.journal.command_index(),
767            restate.protocol.version = %self.context.negotiated_protocol_version
768        ),
769        ret
770    )]
771    fn sys_send(
772        &mut self,
773        target: Target,
774        input: Bytes,
775        delay: Option<Duration>,
776    ) -> VMResult<SendHandle> {
777        invocation_debug_logs!(
778            self,
779            "Executing 'Send to {}/{}'",
780            target.service,
781            target.handler
782        );
783        if let Some(idempotency_key) = &target.idempotency_key {
784            if idempotency_key.is_empty() {
785                self.do_transition(HitError(EMPTY_IDEMPOTENCY_KEY))?;
786                unreachable!();
787            }
788        }
789        let call_invocation_id_completion_id =
790            self.context.journal.next_completion_notification_id();
791        let invocation_id_notification_handle = self.do_transition(SysSimpleCompletableEntry(
792            "SysOneWayCall",
793            OneWayCallCommandMessage {
794                service_name: target.service,
795                handler_name: target.handler,
796                key: target.key.unwrap_or_default(),
797                idempotency_key: target.idempotency_key,
798                headers: target
799                    .headers
800                    .into_iter()
801                    .map(crate::service_protocol::messages::Header::from)
802                    .collect(),
803                parameter: input,
804                invoke_time: delay
805                    .map(|d| {
806                        u64::try_from(d.as_millis())
807                            .expect("millis since Unix epoch should fit in u64")
808                    })
809                    .unwrap_or_default(),
810                invocation_id_notification_idx: call_invocation_id_completion_id,
811                ..Default::default()
812            },
813            call_invocation_id_completion_id,
814        ))?;
815
816        if matches!(
817            self.options.implicit_cancellation,
818            ImplicitCancellationOption::Enabled {
819                cancel_children_one_way_calls: true,
820                ..
821            }
822        ) {
823            self.tracked_invocation_ids.push(TrackedInvocationId {
824                handle: invocation_id_notification_handle,
825                invocation_id: None,
826            })
827        }
828
829        Ok(SendHandle {
830            invocation_id_notification_handle,
831        })
832    }
833
834    #[instrument(
835        level = "trace",
836        skip(self),
837        fields(
838            restate.invocation.id = self.debug_invocation_id(),
839            restate.protocol.state = self.debug_state(),
840            restate.journal.command_index = self.context.journal.command_index(),
841            restate.protocol.version = %self.context.negotiated_protocol_version
842        ),
843        ret
844    )]
845    fn sys_awakeable(&mut self) -> VMResult<(String, NotificationHandle)> {
846        invocation_debug_logs!(self, "Executing 'Create awakeable'");
847
848        let signal_id = self.context.journal.next_signal_notification_id();
849
850        let handle = self.do_transition(CreateSignalHandle(
851            "SysAwakeable",
852            NotificationId::SignalId(signal_id),
853        ))?;
854
855        Ok((
856            awakeable_id_str(&self.context.expect_start_info().id, signal_id),
857            handle,
858        ))
859    }
860
861    #[instrument(
862        level = "trace",
863        skip(self, value),
864        fields(
865            restate.invocation.id = self.debug_invocation_id(),
866            restate.protocol.state = self.debug_state(),
867            restate.journal.command_index = self.context.journal.command_index(),
868            restate.protocol.version = %self.context.negotiated_protocol_version
869        ),
870        ret
871    )]
872    fn sys_complete_awakeable(&mut self, id: String, value: NonEmptyValue) -> VMResult<()> {
873        invocation_debug_logs!(self, "Executing 'Complete awakeable {id}'");
874        self.do_transition(SysNonCompletableEntry(
875            "SysCompleteAwakeable",
876            CompleteAwakeableCommandMessage {
877                awakeable_id: id,
878                result: Some(match value {
879                    NonEmptyValue::Success(s) => {
880                        complete_awakeable_command_message::Result::Value(s.into())
881                    }
882                    NonEmptyValue::Failure(f) => {
883                        complete_awakeable_command_message::Result::Failure(f.into())
884                    }
885                }),
886                ..Default::default()
887            },
888        ))
889    }
890
891    #[instrument(
892        level = "trace",
893        skip(self),
894        fields(
895            restate.invocation.id = self.debug_invocation_id(),
896            restate.protocol.state = self.debug_state(),
897            restate.journal.command_index = self.context.journal.command_index(),
898            restate.protocol.version = %self.context.negotiated_protocol_version
899        ),
900        ret
901    )]
902    fn create_signal_handle(&mut self, signal_name: String) -> VMResult<NotificationHandle> {
903        invocation_debug_logs!(self, "Executing 'Create named signal'");
904
905        self.do_transition(CreateSignalHandle(
906            "SysCreateNamedSignal",
907            NotificationId::SignalName(signal_name),
908        ))
909    }
910
911    #[instrument(
912        level = "trace",
913        skip(self, value),
914        fields(
915            restate.invocation.id = self.debug_invocation_id(),
916            restate.protocol.state = self.debug_state(),
917            restate.journal.command_index = self.context.journal.command_index(),
918            restate.protocol.version = %self.context.negotiated_protocol_version
919        ),
920        ret
921    )]
922    fn sys_complete_signal(
923        &mut self,
924        target_invocation_id: String,
925        signal_name: String,
926        value: NonEmptyValue,
927    ) -> VMResult<()> {
928        invocation_debug_logs!(self, "Executing 'Complete named signal {signal_name}'");
929        self.do_transition(SysNonCompletableEntry(
930            "SysCompleteAwakeable",
931            SendSignalCommandMessage {
932                target_invocation_id,
933                signal_id: Some(send_signal_command_message::SignalId::Name(signal_name)),
934                result: Some(match value {
935                    NonEmptyValue::Success(s) => {
936                        send_signal_command_message::Result::Value(s.into())
937                    }
938                    NonEmptyValue::Failure(f) => {
939                        send_signal_command_message::Result::Failure(f.into())
940                    }
941                }),
942                ..Default::default()
943            },
944        ))
945    }
946
947    #[instrument(
948        level = "trace",
949        skip(self),
950        fields(
951            restate.invocation.id = self.debug_invocation_id(),
952            restate.protocol.state = self.debug_state(),
953            restate.journal.command_index = self.context.journal.command_index(),
954            restate.protocol.version = %self.context.negotiated_protocol_version
955        ),
956        ret
957    )]
958    fn sys_get_promise(&mut self, key: String) -> VMResult<NotificationHandle> {
959        invocation_debug_logs!(self, "Executing 'Await promise {key}'");
960
961        let result_completion_id = self.context.journal.next_completion_notification_id();
962        self.do_transition(SysSimpleCompletableEntry(
963            "SysGetPromise",
964            GetPromiseCommandMessage {
965                key,
966                result_completion_id,
967                ..Default::default()
968            },
969            result_completion_id,
970        ))
971    }
972
973    #[instrument(
974        level = "trace",
975        skip(self),
976        fields(
977            restate.invocation.id = self.debug_invocation_id(),
978            restate.protocol.state = self.debug_state(),
979            restate.journal.command_index = self.context.journal.command_index(),
980            restate.protocol.version = %self.context.negotiated_protocol_version
981        ),
982        ret
983    )]
984    fn sys_peek_promise(&mut self, key: String) -> VMResult<NotificationHandle> {
985        invocation_debug_logs!(self, "Executing 'Peek promise {key}'");
986
987        let result_completion_id = self.context.journal.next_completion_notification_id();
988        self.do_transition(SysSimpleCompletableEntry(
989            "SysPeekPromise",
990            PeekPromiseCommandMessage {
991                key,
992                result_completion_id,
993                ..Default::default()
994            },
995            result_completion_id,
996        ))
997    }
998
999    #[instrument(
1000        level = "trace",
1001        skip(self, value),
1002        fields(
1003            restate.invocation.id = self.debug_invocation_id(),
1004            restate.protocol.state = self.debug_state(),
1005            restate.journal.command_index = self.context.journal.command_index(),
1006            restate.protocol.version = %self.context.negotiated_protocol_version
1007        ),
1008        ret
1009    )]
1010    fn sys_complete_promise(
1011        &mut self,
1012        key: String,
1013        value: NonEmptyValue,
1014    ) -> VMResult<NotificationHandle> {
1015        invocation_debug_logs!(self, "Executing 'Complete promise {key}'");
1016
1017        let result_completion_id = self.context.journal.next_completion_notification_id();
1018        self.do_transition(SysSimpleCompletableEntry(
1019            "SysCompletePromise",
1020            CompletePromiseCommandMessage {
1021                key,
1022                completion: Some(match value {
1023                    NonEmptyValue::Success(s) => {
1024                        complete_promise_command_message::Completion::CompletionValue(s.into())
1025                    }
1026                    NonEmptyValue::Failure(f) => {
1027                        complete_promise_command_message::Completion::CompletionFailure(f.into())
1028                    }
1029                }),
1030                result_completion_id,
1031                ..Default::default()
1032            },
1033            result_completion_id,
1034        ))
1035    }
1036
1037    #[instrument(
1038        level = "trace",
1039        skip(self),
1040        fields(
1041            restate.invocation.id = self.debug_invocation_id(),
1042            restate.protocol.state = self.debug_state(),
1043            restate.journal.command_index = self.context.journal.command_index(),
1044            restate.protocol.version = %self.context.negotiated_protocol_version
1045        ),
1046        ret
1047    )]
1048    fn sys_run(&mut self, name: String) -> VMResult<NotificationHandle> {
1049        match self.do_transition(SysRun(name.clone())) {
1050            Ok(handle) => {
1051                if enabled!(Level::DEBUG) {
1052                    // Store the name, we need it later when completing
1053                    self.sys_run_names.insert(handle, name);
1054                }
1055                Ok(handle)
1056            }
1057            Err(e) => Err(e),
1058        }
1059    }
1060
1061    #[instrument(
1062        level = "trace",
1063        skip(self, value, retry_policy),
1064        fields(
1065            restate.invocation.id = self.debug_invocation_id(),
1066            restate.protocol.state = self.debug_state(),
1067            restate.journal.command_index = self.context.journal.command_index(),
1068            restate.protocol.version = %self.context.negotiated_protocol_version
1069        ),
1070        ret
1071    )]
1072    fn propose_run_completion(
1073        &mut self,
1074        notification_handle: NotificationHandle,
1075        value: RunExitResult,
1076        retry_policy: RetryPolicy,
1077    ) -> VMResult<()> {
1078        if enabled!(Level::DEBUG) {
1079            let name: &str = self
1080                .sys_run_names
1081                .get(&notification_handle)
1082                .map(String::as_str)
1083                .unwrap_or_default();
1084            match &value {
1085                RunExitResult::Success(_) => {
1086                    invocation_debug_logs!(self, "Journaling run '{name}' success result");
1087                }
1088                RunExitResult::TerminalFailure(TerminalFailure { code, .. }) => {
1089                    invocation_debug_logs!(
1090                        self,
1091                        "Journaling run '{name}' terminal failure {code} result"
1092                    );
1093                }
1094                RunExitResult::RetryableFailure { .. } => {
1095                    invocation_debug_logs!(self, "Propagating run '{name}' retryable failure");
1096                }
1097            }
1098        }
1099
1100        self.do_transition(ProposeRunCompletion(
1101            notification_handle,
1102            value,
1103            retry_policy,
1104        ))
1105    }
1106
1107    #[instrument(
1108        level = "trace",
1109        skip(self),
1110        fields(
1111            restate.invocation.id = self.debug_invocation_id(),
1112            restate.protocol.state = self.debug_state(),
1113            restate.journal.command_index = self.context.journal.command_index(),
1114            restate.protocol.version = %self.context.negotiated_protocol_version
1115        ),
1116        ret
1117    )]
1118    fn sys_cancel_invocation(&mut self, target_invocation_id: String) -> VMResult<()> {
1119        invocation_debug_logs!(
1120            self,
1121            "Executing 'Cancel invocation' of {target_invocation_id}"
1122        );
1123        self.do_transition(SysNonCompletableEntry(
1124            "SysCancelInvocation",
1125            SendSignalCommandMessage {
1126                target_invocation_id,
1127                signal_id: Some(send_signal_command_message::SignalId::Idx(CANCEL_SIGNAL_ID)),
1128                result: Some(send_signal_command_message::Result::Void(Default::default())),
1129                ..Default::default()
1130            },
1131        ))
1132    }
1133
1134    #[instrument(
1135        level = "trace",
1136        skip(self),
1137        fields(
1138            restate.invocation.id = self.debug_invocation_id(),
1139            restate.protocol.state = self.debug_state(),
1140            restate.journal.command_index = self.context.journal.command_index(),
1141            restate.protocol.version = %self.context.negotiated_protocol_version
1142        ),
1143        ret
1144    )]
1145    fn sys_attach_invocation(
1146        &mut self,
1147        target: AttachInvocationTarget,
1148    ) -> VMResult<NotificationHandle> {
1149        invocation_debug_logs!(self, "Executing 'Attach invocation'");
1150
1151        let result_completion_id = self.context.journal.next_completion_notification_id();
1152        self.do_transition(SysSimpleCompletableEntry(
1153            "SysAttachInvocation",
1154            AttachInvocationCommandMessage {
1155                target: Some(match target {
1156                    AttachInvocationTarget::InvocationId(id) => {
1157                        attach_invocation_command_message::Target::InvocationId(id)
1158                    }
1159                    AttachInvocationTarget::WorkflowId { name, key } => {
1160                        attach_invocation_command_message::Target::WorkflowTarget(WorkflowTarget {
1161                            workflow_name: name,
1162                            workflow_key: key,
1163                        })
1164                    }
1165                    AttachInvocationTarget::IdempotencyId {
1166                        service_name,
1167                        service_key,
1168                        handler_name,
1169                        idempotency_key,
1170                    } => attach_invocation_command_message::Target::IdempotentRequestTarget(
1171                        IdempotentRequestTarget {
1172                            service_name,
1173                            service_key,
1174                            handler_name,
1175                            idempotency_key,
1176                        },
1177                    ),
1178                }),
1179                result_completion_id,
1180                ..Default::default()
1181            },
1182            result_completion_id,
1183        ))
1184    }
1185
1186    #[instrument(
1187        level = "trace",
1188        skip(self),
1189        fields(
1190            restate.invocation.id = self.debug_invocation_id(),
1191            restate.protocol.state = self.debug_state(),
1192            restate.journal.command_index = self.context.journal.command_index(),
1193            restate.protocol.version = %self.context.negotiated_protocol_version
1194        ),
1195        ret
1196    )]
1197    fn sys_get_invocation_output(
1198        &mut self,
1199        target: AttachInvocationTarget,
1200    ) -> VMResult<NotificationHandle> {
1201        invocation_debug_logs!(self, "Executing 'Get invocation output'");
1202
1203        let result_completion_id = self.context.journal.next_completion_notification_id();
1204        self.do_transition(SysSimpleCompletableEntry(
1205            "SysGetInvocationOutput",
1206            GetInvocationOutputCommandMessage {
1207                target: Some(match target {
1208                    AttachInvocationTarget::InvocationId(id) => {
1209                        get_invocation_output_command_message::Target::InvocationId(id)
1210                    }
1211                    AttachInvocationTarget::WorkflowId { name, key } => {
1212                        get_invocation_output_command_message::Target::WorkflowTarget(
1213                            WorkflowTarget {
1214                                workflow_name: name,
1215                                workflow_key: key,
1216                            },
1217                        )
1218                    }
1219                    AttachInvocationTarget::IdempotencyId {
1220                        service_name,
1221                        service_key,
1222                        handler_name,
1223                        idempotency_key,
1224                    } => get_invocation_output_command_message::Target::IdempotentRequestTarget(
1225                        IdempotentRequestTarget {
1226                            service_name,
1227                            service_key,
1228                            handler_name,
1229                            idempotency_key,
1230                        },
1231                    ),
1232                }),
1233                result_completion_id,
1234                ..Default::default()
1235            },
1236            result_completion_id,
1237        ))
1238    }
1239
1240    #[instrument(
1241        level = "trace",
1242        skip(self, value),
1243        fields(
1244            restate.invocation.id = self.debug_invocation_id(),
1245            restate.protocol.state = self.debug_state(),
1246            restate.journal.command_index = self.context.journal.command_index(),
1247            restate.protocol.version = %self.context.negotiated_protocol_version
1248        ),
1249        ret
1250    )]
1251    fn sys_write_output(&mut self, value: NonEmptyValue) -> Result<(), Error> {
1252        match &value {
1253            NonEmptyValue::Success(_) => {
1254                invocation_debug_logs!(self, "Writing invocation result success value");
1255            }
1256            NonEmptyValue::Failure(_) => {
1257                invocation_debug_logs!(self, "Writing invocation result failure value");
1258            }
1259        }
1260        self.do_transition(SysNonCompletableEntry(
1261            "SysWriteOutput",
1262            OutputCommandMessage {
1263                result: Some(match value {
1264                    NonEmptyValue::Success(b) => output_command_message::Result::Value(b.into()),
1265                    NonEmptyValue::Failure(f) => output_command_message::Result::Failure(f.into()),
1266                }),
1267                ..OutputCommandMessage::default()
1268            },
1269        ))
1270    }
1271
1272    #[instrument(
1273        level = "trace",
1274        skip(self),
1275        fields(
1276            restate.invocation.id = self.debug_invocation_id(),
1277            restate.protocol.state = self.debug_state(),
1278            restate.journal.command_index = self.context.journal.command_index(),
1279            restate.protocol.version = %self.context.negotiated_protocol_version
1280        ),
1281        ret
1282    )]
1283    fn sys_end(&mut self) -> Result<(), Error> {
1284        invocation_debug_logs!(self, "End of the invocation");
1285        self.do_transition(SysEnd)
1286    }
1287
1288    fn is_waiting_preflight(&self) -> bool {
1289        matches!(
1290            &self.last_transition,
1291            Ok(State::WaitingStart) | Ok(State::WaitingReplayEntries { .. })
1292        )
1293    }
1294
1295    fn is_replaying(&self) -> bool {
1296        matches!(&self.last_transition, Ok(State::Replaying { .. }))
1297    }
1298
1299    fn is_processing(&self) -> bool {
1300        matches!(&self.last_transition, Ok(State::Processing { .. }))
1301    }
1302
1303    fn last_command_index(&self) -> i64 {
1304        self.context.journal.command_index()
1305    }
1306}
1307
/// Base64 engine configuration: encoding emits no `=` padding, while decoding accepts
/// input regardless of whether it is padded (`DecodePaddingMode::Indifferent`).
const INDIFFERENT_PAD: GeneralPurposeConfig = GeneralPurposeConfig::new()
    .with_decode_padding_mode(DecodePaddingMode::Indifferent)
    .with_encode_padding(false);
/// Base64 engine using the URL-safe alphabet with the `INDIFFERENT_PAD` padding rules.
const URL_SAFE: GeneralPurpose = GeneralPurpose::new(&alphabet::URL_SAFE, INDIFFERENT_PAD);

/// Prefix placed in front of the base64 payload by `awakeable_id_str`.
const AWAKEABLE_PREFIX: &str = "sign_1";
1314
1315fn awakeable_id_str(id: &[u8], completion_index: u32) -> String {
1316    let mut input_buf = BytesMut::with_capacity(id.len() + size_of::<u32>());
1317    input_buf.put_slice(id);
1318    input_buf.put_u32(completion_index);
1319    format!("{AWAKEABLE_PREFIX}{}", URL_SAFE.encode(input_buf.freeze()))
1320}