Skip to main content

restate_sdk_shared_core/vm/
mod.rs

1use crate::headers::HeaderMap;
2use crate::service_protocol::messages::{
3    attach_invocation_command_message, complete_awakeable_command_message,
4    complete_promise_command_message, get_invocation_output_command_message,
5    output_command_message, send_signal_command_message, AttachInvocationCommandMessage,
6    CallCommandMessage, ClearAllStateCommandMessage, ClearStateCommandMessage,
7    CompleteAwakeableCommandMessage, CompletePromiseCommandMessage,
8    GetInvocationOutputCommandMessage, GetPromiseCommandMessage, IdempotentRequestTarget,
9    OneWayCallCommandMessage, OutputCommandMessage, PeekPromiseCommandMessage,
10    SendSignalCommandMessage, SetStateCommandMessage, SleepCommandMessage, WorkflowTarget,
11};
12use crate::service_protocol::{Decoder, NotificationId, RawMessage, Version, CANCEL_SIGNAL_ID};
13use crate::vm::errors::{
14    ClosedError, UnexpectedStateError, UnsupportedFeatureForNegotiatedVersion,
15    EMPTY_IDEMPOTENCY_KEY, SUSPENDED,
16};
17use crate::vm::transitions::*;
18use crate::{
19    AttachInvocationTarget, CallHandle, CommandRelationship, DoProgressResponse, Error, Header,
20    ImplicitCancellationOption, Input, NonDeterministicChecksOption, NonEmptyValue,
21    NotificationHandle, PayloadOptions, ResponseHead, RetryPolicy, RunExitResult, SendHandle,
22    TakeOutputResult, Target, TerminalFailure, VMOptions, VMResult, Value,
23    CANCEL_NOTIFICATION_HANDLE,
24};
25use base64::engine::{DecodePaddingMode, GeneralPurpose, GeneralPurposeConfig};
26use base64::{alphabet, Engine};
27use bytes::{Buf, BufMut, Bytes, BytesMut};
28use context::{AsyncResultsState, Context, Output, RunState};
29use std::borrow::Cow;
30use std::collections::{HashMap, VecDeque};
31use std::mem::size_of;
32use std::time::Duration;
33use std::{fmt, mem};
34use strum::IntoStaticStr;
35use tracing::{debug, enabled, instrument, Level};
36
37mod context;
38pub(crate) mod errors;
39mod transitions;
40
/// Header name read by `CoreVM::new` to negotiate the protocol version and
/// written back by `get_response_head` to advertise it.
const CONTENT_TYPE: &str = "content-type";
42
43#[derive(Debug, IntoStaticStr)]
44pub(crate) enum State {
45    WaitingStart,
46    WaitingReplayEntries {
47        received_entries: u32,
48        commands: VecDeque<RawMessage>,
49        async_results: AsyncResultsState,
50    },
51    Replaying {
52        commands: VecDeque<RawMessage>,
53        run_state: RunState,
54        async_results: AsyncResultsState,
55    },
56    Processing {
57        processing_first_entry: bool,
58        run_state: RunState,
59        async_results: AsyncResultsState,
60    },
61    Closed,
62}
63
64impl State {
65    fn as_unexpected_state(&self, event: impl ToString) -> Error {
66        if matches!(self, State::Closed) {
67            return ClosedError::new(event.to_string()).into();
68        }
69        UnexpectedStateError::new(self.into(), event.to_string()).into()
70    }
71}
72
73struct TrackedInvocationId {
74    handle: NotificationHandle,
75    invocation_id: Option<String>,
76}
77
78impl TrackedInvocationId {
79    fn is_resolved(&self) -> bool {
80        self.invocation_id.is_some()
81    }
82}
83
84pub struct CoreVM {
85    options: VMOptions,
86
87    // Input decoder
88    decoder: Decoder,
89
90    // State machine
91    context: Context,
92    last_transition: Result<State, Error>,
93
94    // Implicit cancellation tracking
95    tracked_invocation_ids: Vec<TrackedInvocationId>,
96
97    // Run names, useful for debugging
98    sys_run_names: HashMap<NotificationHandle, String>,
99}
100
101impl CoreVM {
102    // Returns empty string if the invocation id is not present
103    fn debug_invocation_id(&self) -> &str {
104        if let Some(start_info) = self.context.start_info() {
105            &start_info.debug_id
106        } else {
107            ""
108        }
109    }
110
111    fn debug_state(&self) -> &'static str {
112        match &self.last_transition {
113            Ok(s) => s.into(),
114            Err(_) => "Failed",
115        }
116    }
117
118    fn verify_error_metadata_feature_support(&mut self, value: &NonEmptyValue) -> VMResult<()> {
119        if let NonEmptyValue::Failure(f) = value {
120            if !f.metadata.is_empty() {
121                self.verify_feature_support("terminal error metadata", Version::V6)?;
122            }
123        }
124        Ok(())
125    }
126
127    #[allow(dead_code)]
128    fn verify_feature_support(
129        &mut self,
130        feature: &'static str,
131        minimum_required_protocol: Version,
132    ) -> VMResult<()> {
133        if self.context.negotiated_protocol_version < minimum_required_protocol {
134            return self.do_transition(HitError(
135                UnsupportedFeatureForNegotiatedVersion::new(
136                    feature,
137                    self.context.negotiated_protocol_version,
138                    minimum_required_protocol,
139                )
140                .into(),
141            ));
142        }
143        Ok(())
144    }
145
146    fn _is_completed(&self, handle: NotificationHandle) -> bool {
147        match &self.last_transition {
148            Ok(State::Replaying { async_results, .. })
149            | Ok(State::Processing { async_results, .. }) => {
150                async_results.is_handle_completed(handle)
151            }
152            _ => false,
153        }
154    }
155
156    fn _do_progress(
157        &mut self,
158        any_handle: Vec<NotificationHandle>,
159    ) -> Result<DoProgressResponse, Error> {
160        match self.do_transition(DoProgress(any_handle)) {
161            Ok(Ok(do_progress_response)) => Ok(do_progress_response),
162            Ok(Err(_)) => Err(SUSPENDED),
163            Err(e) => Err(e),
164        }
165    }
166
167    fn is_implicit_cancellation_enabled(&self) -> bool {
168        matches!(
169            self.options.implicit_cancellation,
170            ImplicitCancellationOption::Enabled { .. }
171        )
172    }
173}
174
175impl fmt::Debug for CoreVM {
176    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
177        let mut s = f.debug_struct("CoreVM");
178        s.field("version", &self.context.negotiated_protocol_version);
179
180        if let Some(start_info) = self.context.start_info() {
181            s.field("invocation_id", &start_info.debug_id);
182        }
183
184        match &self.last_transition {
185            Ok(state) => s.field("last_transition", &<&'static str>::from(state)),
186            Err(_) => s.field("last_transition", &"Errored"),
187        };
188
189        s.field("command_index", &self.context.journal.command_index())
190            .field(
191                "notification_index",
192                &self.context.journal.notification_index(),
193            )
194            .finish()
195    }
196}
197
198// --- Bound checks
199#[allow(unused)]
200const fn is_send<T: Send>() {}
201const _: () = is_send::<CoreVM>();
202
// Macro used for informative debug logs.
// Expands to a `tracing::debug!` event that is emitted only while
// `$this.is_processing()` returns true; the remaining tokens are forwarded
// unchanged as the `debug!` arguments.
macro_rules! invocation_debug_logs {
    ($this:expr, $($arg:tt)*) => {
        if $this.is_processing() {
            tracing::debug!($($arg)*)
        }
    };
}
211
212impl super::VM for CoreVM {
213    #[instrument(level = "trace", skip(request_headers), ret)]
214    fn new(request_headers: impl HeaderMap, options: VMOptions) -> Result<Self, Error> {
215        let version = request_headers
216            .extract(CONTENT_TYPE)
217            .map_err(|e| {
218                Error::new(
219                    errors::codes::BAD_REQUEST,
220                    format!("cannot read '{CONTENT_TYPE}' header: {e:?}"),
221                )
222            })?
223            .ok_or(errors::MISSING_CONTENT_TYPE)?
224            .parse::<Version>()?;
225
226        if version < Version::minimum_supported_version()
227            || version > Version::maximum_supported_version()
228        {
229            return Err(Error::new(
230                errors::codes::UNSUPPORTED_MEDIA_TYPE,
231                format!(
232                    "Unsupported protocol version {:?}, not within [{:?} to {:?}]. \
233                    You might need to rediscover the service, check https://docs.restate.dev/references/errors/#RT0015",
234                    version,
235                    Version::minimum_supported_version(),
236                    Version::maximum_supported_version()
237                ),
238            ));
239        }
240        let non_deterministic_checks_ignore_payload_equality = matches!(
241            options.non_determinism_checks,
242            NonDeterministicChecksOption::PayloadChecksDisabled
243        );
244
245        Ok(Self {
246            options,
247            decoder: Decoder::new(version),
248            context: Context {
249                input_is_closed: false,
250                output: Output::new(version),
251                start_info: None,
252                journal: Default::default(),
253                eager_state: Default::default(),
254                non_deterministic_checks_ignore_payload_equality,
255                negotiated_protocol_version: version,
256            },
257            last_transition: Ok(State::WaitingStart),
258            tracked_invocation_ids: vec![],
259            sys_run_names: HashMap::with_capacity(0),
260        })
261    }
262
263    #[instrument(
264        level = "trace",
265        skip(self),
266        fields(
267            restate.invocation.id = self.debug_invocation_id(),
268            restate.protocol.state = self.debug_state(),
269            restate.journal.command_index = self.context.journal.command_index(),
270            restate.protocol.version = %self.context.negotiated_protocol_version
271        ),
272        ret
273    )]
274    fn get_response_head(&self) -> ResponseHead {
275        ResponseHead {
276            status_code: 200,
277            headers: vec![Header {
278                key: Cow::Borrowed(CONTENT_TYPE),
279                value: Cow::Borrowed(self.context.negotiated_protocol_version.content_type()),
280            }],
281            version: self.context.negotiated_protocol_version,
282        }
283    }
284
285    #[instrument(
286        level = "trace",
287        skip(self),
288        fields(
289            restate.invocation.id = self.debug_invocation_id(),
290            restate.protocol.state = self.debug_state(),
291            restate.journal.command_index = self.context.journal.command_index(),
292            restate.protocol.version = %self.context.negotiated_protocol_version
293        ),
294        ret
295    )]
296    fn notify_input(&mut self, buffer: Bytes) {
297        self.decoder.push(buffer);
298        loop {
299            match self.decoder.consume_next() {
300                Ok(Some(msg)) => {
301                    if self.do_transition(NewMessage(msg)).is_err() {
302                        return;
303                    }
304                }
305                Ok(None) => {
306                    return;
307                }
308                Err(e) => {
309                    if self.do_transition(HitError(e.into())).is_err() {
310                        return;
311                    }
312                }
313            }
314        }
315    }
316
317    #[instrument(
318        level = "trace",
319        skip(self),
320        fields(
321            restate.invocation.id = self.debug_invocation_id(),
322            restate.protocol.state = self.debug_state(),
323            restate.journal.command_index = self.context.journal.command_index(),
324            restate.protocol.version = %self.context.negotiated_protocol_version
325        ),
326        ret
327    )]
328    fn notify_input_closed(&mut self) {
329        self.context.input_is_closed = true;
330        let _ = self.do_transition(NotifyInputClosed);
331    }
332
333    #[instrument(
334        level = "trace",
335        skip(self),
336        fields(
337            restate.invocation.id = self.debug_invocation_id(),
338            restate.protocol.state = self.debug_state(),
339            restate.journal.command_index = self.context.journal.command_index(),
340            restate.protocol.version = %self.context.negotiated_protocol_version
341        ),
342        ret
343    )]
344    fn notify_error(
345        &mut self,
346        mut error: Error,
347        command_relationship: Option<CommandRelationship>,
348    ) {
349        if let Some(command_relationship) = command_relationship {
350            error = error.with_related_command_metadata(
351                self.context
352                    .journal
353                    .resolve_related_command(command_relationship),
354            );
355        }
356
357        let _ = self.do_transition(HitError(error));
358    }
359
360    #[instrument(
361        level = "trace",
362        skip(self),
363        fields(
364            restate.invocation.id = self.debug_invocation_id(),
365            restate.protocol.state = self.debug_state(),
366            restate.journal.command_index = self.context.journal.command_index(),
367            restate.protocol.version = %self.context.negotiated_protocol_version
368        ),
369        ret
370    )]
371    fn take_output(&mut self) -> TakeOutputResult {
372        if self.context.output.buffer.has_remaining() {
373            TakeOutputResult::Buffer(
374                self.context
375                    .output
376                    .buffer
377                    .copy_to_bytes(self.context.output.buffer.remaining()),
378            )
379        } else if !self.context.output.is_closed() {
380            TakeOutputResult::Buffer(Bytes::default())
381        } else {
382            TakeOutputResult::EOF
383        }
384    }
385
386    #[instrument(
387        level = "trace",
388        skip(self),
389        fields(
390            restate.invocation.id = self.debug_invocation_id(),
391            restate.protocol.state = self.debug_state(),
392            restate.journal.command_index = self.context.journal.command_index(),
393            restate.protocol.version = %self.context.negotiated_protocol_version
394        ),
395        ret
396    )]
397    fn is_ready_to_execute(&self) -> Result<bool, Error> {
398        match &self.last_transition {
399            Ok(State::WaitingStart) | Ok(State::WaitingReplayEntries { .. }) => Ok(false),
400            Ok(State::Processing { .. }) | Ok(State::Replaying { .. }) => Ok(true),
401            Ok(s) => Err(s.as_unexpected_state("IsReadyToExecute")),
402            Err(e) => Err(e.clone()),
403        }
404    }
405
406    #[instrument(
407        level = "trace",
408        skip(self),
409        fields(
410            restate.invocation.id = self.debug_invocation_id(),
411            restate.protocol.state = self.debug_state(),
412            restate.journal.command_index = self.context.journal.command_index(),
413            restate.protocol.version = %self.context.negotiated_protocol_version
414        ),
415        ret
416    )]
417    fn is_completed(&self, handle: NotificationHandle) -> bool {
418        self._is_completed(handle)
419    }
420
421    #[instrument(
422        level = "trace",
423        skip(self),
424        fields(
425            restate.invocation.id = self.debug_invocation_id(),
426            restate.protocol.state = self.debug_state(),
427            restate.journal.command_index = self.context.journal.command_index(),
428            restate.protocol.version = %self.context.negotiated_protocol_version
429        ),
430        ret
431    )]
432    fn do_progress(
433        &mut self,
434        mut any_handle: Vec<NotificationHandle>,
435    ) -> VMResult<DoProgressResponse> {
436        if self.is_implicit_cancellation_enabled() {
437            // We want the runtime to wake us up in case cancel notification comes in.
438            any_handle.insert(0, CANCEL_NOTIFICATION_HANDLE);
439
440            match self._do_progress(any_handle) {
441                Ok(DoProgressResponse::AnyCompleted) => {
442                    // If it's cancel signal, then let's go on with the cancellation logic
443                    if self._is_completed(CANCEL_NOTIFICATION_HANDLE) {
444                        // Loop once over the tracked invocation ids to resolve the unresolved ones
445                        for i in 0..self.tracked_invocation_ids.len() {
446                            if self.tracked_invocation_ids[i].is_resolved() {
447                                continue;
448                            }
449
450                            let handle = self.tracked_invocation_ids[i].handle;
451
452                            // Try to resolve it
453                            match self._do_progress(vec![handle]) {
454                                Ok(DoProgressResponse::AnyCompleted) => {
455                                    let invocation_id = match self.do_transition(CopyNotification(handle)) {
456                                        Ok(Ok(Some(Value::InvocationId(invocation_id)))) => Ok(invocation_id),
457                                        Ok(Err(_)) => Err(SUSPENDED),
458                                        _ => panic!("Unexpected variant! If the id handle is completed, it must be an invocation id handle!")
459                                    }?;
460
461                                    // This handle is resolved
462                                    self.tracked_invocation_ids[i].invocation_id =
463                                        Some(invocation_id);
464                                }
465                                res => return res,
466                            }
467                        }
468
469                        // Now we got all the invocation IDs, let's cancel!
470                        for tracked_invocation_id in mem::take(&mut self.tracked_invocation_ids) {
471                            self.sys_cancel_invocation(
472                                tracked_invocation_id
473                                    .invocation_id
474                                    .expect("We resolved before all the invocation ids"),
475                            )?;
476                        }
477
478                        // Flip the cancellation
479                        let _ = self.take_notification(CANCEL_NOTIFICATION_HANDLE);
480
481                        // Done
482                        Ok(DoProgressResponse::CancelSignalReceived)
483                    } else {
484                        Ok(DoProgressResponse::AnyCompleted)
485                    }
486                }
487                res => res,
488            }
489        } else {
490            self._do_progress(any_handle)
491        }
492    }
493
494    #[instrument(
495        level = "trace",
496        skip(self),
497        fields(
498            restate.invocation.id = self.debug_invocation_id(),
499            restate.protocol.state = self.debug_state(),
500            restate.journal.command_index = self.context.journal.command_index(),
501            restate.protocol.version = %self.context.negotiated_protocol_version
502        ),
503        ret
504    )]
505    fn take_notification(&mut self, handle: NotificationHandle) -> VMResult<Option<Value>> {
506        match self.do_transition(TakeNotification(handle)) {
507            Ok(Ok(Some(value))) => {
508                if self.is_implicit_cancellation_enabled() {
509                    // Let's check if that's one of the tracked invocation ids
510                    // We can do binary search here because we assume tracked_invocation_ids is ordered, as handles are incremental numbers
511                    if let Ok(found) = self
512                        .tracked_invocation_ids
513                        .binary_search_by(|tracked| tracked.handle.cmp(&handle))
514                    {
515                        let Value::InvocationId(invocation_id) = &value else {
516                            panic!("Expecting an invocation id here, but got {value:?}");
517                        };
518                        // Keep track of this invocation id
519                        self.tracked_invocation_ids
520                            .get_mut(found)
521                            .unwrap()
522                            .invocation_id = Some(invocation_id.clone());
523                    }
524                }
525
526                Ok(Some(value))
527            }
528            Ok(Ok(None)) => Ok(None),
529            Ok(Err(_)) => Err(SUSPENDED),
530            Err(e) => Err(e),
531        }
532    }
533
534    #[instrument(
535        level = "trace",
536        skip(self),
537        fields(
538            restate.invocation.id = self.debug_invocation_id(),
539            restate.protocol.state = self.debug_state(),
540            restate.journal.command_index = self.context.journal.command_index(),
541            restate.protocol.version = %self.context.negotiated_protocol_version
542        ),
543        ret
544    )]
545    fn sys_input(&mut self) -> Result<Input, Error> {
546        self.do_transition(SysInput)
547    }
548
549    #[instrument(
550        level = "trace",
551        skip(self),
552        fields(
553            restate.invocation.id = self.debug_invocation_id(),
554            restate.protocol.state = self.debug_state(),
555            restate.journal.command_index = self.context.journal.command_index(),
556            restate.protocol.version = %self.context.negotiated_protocol_version
557        ),
558        ret
559    )]
560    fn sys_state_get(
561        &mut self,
562        key: String,
563        options: PayloadOptions,
564    ) -> Result<NotificationHandle, Error> {
565        invocation_debug_logs!(self, "Executing 'Get state {key}'");
566        self.do_transition(SysStateGet(key, options))
567    }
568
569    #[instrument(
570        level = "trace",
571        skip(self),
572        fields(
573            restate.invocation.id = self.debug_invocation_id(),
574            restate.protocol.state = self.debug_state(),
575            restate.journal.command_index = self.context.journal.command_index(),
576            restate.protocol.version = %self.context.negotiated_protocol_version
577        ),
578        ret
579    )]
580    fn sys_state_get_keys(&mut self) -> VMResult<NotificationHandle> {
581        invocation_debug_logs!(self, "Executing 'Get state keys'");
582        self.do_transition(SysStateGetKeys)
583    }
584
585    #[instrument(
586        level = "trace",
587        skip(self, value),
588        fields(
589            restate.invocation.id = self.debug_invocation_id(),
590            restate.protocol.state = self.debug_state(),
591            restate.journal.command_index = self.context.journal.command_index(),
592            restate.protocol.version = %self.context.negotiated_protocol_version
593        ),
594        ret
595    )]
596    fn sys_state_set(
597        &mut self,
598        key: String,
599        value: Bytes,
600        options: PayloadOptions,
601    ) -> VMResult<()> {
602        invocation_debug_logs!(self, "Executing 'Set state {key}'");
603        self.context.eager_state.set(key.clone(), value.clone());
604        self.do_transition(SysNonCompletableEntry(
605            SetStateCommandMessage {
606                key: Bytes::from(key.into_bytes()),
607                value: Some(value.into()),
608                ..SetStateCommandMessage::default()
609            },
610            options,
611        ))
612    }
613
614    #[instrument(
615        level = "trace",
616        skip(self),
617        fields(
618            restate.invocation.id = self.debug_invocation_id(),
619            restate.protocol.state = self.debug_state(),
620            restate.journal.command_index = self.context.journal.command_index(),
621            restate.protocol.version = %self.context.negotiated_protocol_version
622        ),
623        ret
624    )]
625    fn sys_state_clear(&mut self, key: String) -> Result<(), Error> {
626        invocation_debug_logs!(self, "Executing 'Clear state {key}'");
627        self.context.eager_state.clear(key.clone());
628        self.do_transition(SysNonCompletableEntry(
629            ClearStateCommandMessage {
630                key: Bytes::from(key.into_bytes()),
631                ..ClearStateCommandMessage::default()
632            },
633            PayloadOptions::default(),
634        ))
635    }
636
637    #[instrument(
638        level = "trace",
639        skip(self),
640        fields(
641            restate.invocation.id = self.debug_invocation_id(),
642            restate.protocol.state = self.debug_state(),
643            restate.journal.command_index = self.context.journal.command_index(),
644            restate.protocol.version = %self.context.negotiated_protocol_version
645        ),
646        ret
647    )]
648    fn sys_state_clear_all(&mut self) -> Result<(), Error> {
649        invocation_debug_logs!(self, "Executing 'Clear all state'");
650        self.context.eager_state.clear_all();
651        self.do_transition(SysNonCompletableEntry(
652            ClearAllStateCommandMessage::default(),
653            PayloadOptions::default(),
654        ))
655    }
656
657    #[instrument(
658        level = "trace",
659        skip(self),
660        fields(
661            restate.invocation.id = self.debug_invocation_id(),
662            restate.protocol.state = self.debug_state(),
663            restate.journal.command_index = self.context.journal.command_index(),
664            restate.protocol.version = %self.context.negotiated_protocol_version
665        ),
666        ret
667    )]
668    fn sys_sleep(
669        &mut self,
670        name: String,
671        wake_up_time_since_unix_epoch: Duration,
672        now_since_unix_epoch: Option<Duration>,
673    ) -> VMResult<NotificationHandle> {
674        if self.is_processing() {
675            match (&name, now_since_unix_epoch) {
676                (name, Some(now_since_unix_epoch)) if name.is_empty() => {
677                    debug!(
678                        "Executing 'Timer with duration {:?}'",
679                        wake_up_time_since_unix_epoch - now_since_unix_epoch
680                    );
681                }
682                (name, Some(now_since_unix_epoch)) => {
683                    debug!(
684                        "Executing 'Timer {name} with duration {:?}'",
685                        wake_up_time_since_unix_epoch - now_since_unix_epoch
686                    );
687                }
688                (name, None) if name.is_empty() => {
689                    debug!("Executing 'Timer'");
690                }
691                (name, None) => {
692                    debug!("Executing 'Timer named {name}'");
693                }
694            }
695        }
696
697        let completion_id = self.context.journal.next_completion_notification_id();
698
699        self.do_transition(SysSimpleCompletableEntry(
700            SleepCommandMessage {
701                wake_up_time: u64::try_from(wake_up_time_since_unix_epoch.as_millis())
702                    .expect("millis since Unix epoch should fit in u64"),
703                result_completion_id: completion_id,
704                name,
705            },
706            completion_id,
707            PayloadOptions::default(),
708        ))
709    }
710
711    #[instrument(
712        level = "trace",
713        skip(self, input),
714        fields(
715            restate.invocation.id = self.debug_invocation_id(),
716            restate.protocol.state = self.debug_state(),
717            restate.journal.command_index = self.context.journal.command_index(),
718            restate.protocol.version = %self.context.negotiated_protocol_version
719        ),
720        ret
721    )]
722    fn sys_call(
723        &mut self,
724        target: Target,
725        input: Bytes,
726        name: Option<String>,
727        options: PayloadOptions,
728    ) -> VMResult<CallHandle> {
729        invocation_debug_logs!(
730            self,
731            "Executing 'Call {}/{}'",
732            target.service,
733            target.handler
734        );
735        if let Some(idempotency_key) = &target.idempotency_key {
736            if idempotency_key.is_empty() {
737                self.do_transition(HitError(EMPTY_IDEMPOTENCY_KEY))?;
738                unreachable!();
739            }
740        }
741
742        let call_invocation_id_completion_id =
743            self.context.journal.next_completion_notification_id();
744        let result_completion_id = self.context.journal.next_completion_notification_id();
745
746        let handles = self.do_transition(SysCompletableEntryWithMultipleCompletions(
747            CallCommandMessage {
748                service_name: target.service,
749                handler_name: target.handler,
750                key: target.key.unwrap_or_default(),
751                idempotency_key: target.idempotency_key,
752                headers: target
753                    .headers
754                    .into_iter()
755                    .map(crate::service_protocol::messages::Header::from)
756                    .collect(),
757                parameter: input,
758                invocation_id_notification_idx: call_invocation_id_completion_id,
759                name: name.unwrap_or_default(),
760                result_completion_id,
761            },
762            vec![call_invocation_id_completion_id, result_completion_id],
763            options,
764        ))?;
765
766        if matches!(
767            self.options.implicit_cancellation,
768            ImplicitCancellationOption::Enabled {
769                cancel_children_calls: true,
770                ..
771            }
772        ) {
773            self.tracked_invocation_ids.push(TrackedInvocationId {
774                handle: handles[0],
775                invocation_id: None,
776            })
777        }
778
779        Ok(CallHandle {
780            invocation_id_notification_handle: handles[0],
781            call_notification_handle: handles[1],
782        })
783    }
784
785    #[instrument(
786        level = "trace",
787        skip(self, input),
788        fields(
789            restate.invocation.id = self.debug_invocation_id(),
790            restate.protocol.state = self.debug_state(),
791            restate.journal.command_index = self.context.journal.command_index(),
792            restate.protocol.version = %self.context.negotiated_protocol_version
793        ),
794        ret
795    )]
796    fn sys_send(
797        &mut self,
798        target: Target,
799        input: Bytes,
800        delay: Option<Duration>,
801        name: Option<String>,
802        options: PayloadOptions,
803    ) -> VMResult<SendHandle> {
804        invocation_debug_logs!(
805            self,
806            "Executing 'Send to {}/{}'",
807            target.service,
808            target.handler
809        );
810        if let Some(idempotency_key) = &target.idempotency_key {
811            if idempotency_key.is_empty() {
812                self.do_transition(HitError(EMPTY_IDEMPOTENCY_KEY))?;
813                unreachable!();
814            }
815        }
816        let call_invocation_id_completion_id =
817            self.context.journal.next_completion_notification_id();
818        let invocation_id_notification_handle = self.do_transition(SysSimpleCompletableEntry(
819            OneWayCallCommandMessage {
820                service_name: target.service,
821                handler_name: target.handler,
822                key: target.key.unwrap_or_default(),
823                idempotency_key: target.idempotency_key,
824                headers: target
825                    .headers
826                    .into_iter()
827                    .map(crate::service_protocol::messages::Header::from)
828                    .collect(),
829                parameter: input,
830                invoke_time: delay
831                    .map(|d| {
832                        u64::try_from(d.as_millis())
833                            .expect("millis since Unix epoch should fit in u64")
834                    })
835                    .unwrap_or_default(),
836                invocation_id_notification_idx: call_invocation_id_completion_id,
837                name: name.unwrap_or_default(),
838            },
839            call_invocation_id_completion_id,
840            options,
841        ))?;
842
843        if matches!(
844            self.options.implicit_cancellation,
845            ImplicitCancellationOption::Enabled {
846                cancel_children_one_way_calls: true,
847                ..
848            }
849        ) {
850            self.tracked_invocation_ids.push(TrackedInvocationId {
851                handle: invocation_id_notification_handle,
852                invocation_id: None,
853            })
854        }
855
856        Ok(SendHandle {
857            invocation_id_notification_handle,
858        })
859    }
860
861    #[instrument(
862        level = "trace",
863        skip(self),
864        fields(
865            restate.invocation.id = self.debug_invocation_id(),
866            restate.protocol.state = self.debug_state(),
867            restate.journal.command_index = self.context.journal.command_index(),
868            restate.protocol.version = %self.context.negotiated_protocol_version
869        ),
870        ret
871    )]
872    fn sys_awakeable(&mut self) -> VMResult<(String, NotificationHandle)> {
873        invocation_debug_logs!(self, "Executing 'Create awakeable'");
874
875        let signal_id = self.context.journal.next_signal_notification_id();
876
877        let handle = self.do_transition(CreateSignalHandle(
878            "awakeable",
879            NotificationId::SignalId(signal_id),
880        ))?;
881
882        Ok((
883            awakeable_id_str(&self.context.expect_start_info().id, signal_id),
884            handle,
885        ))
886    }
887
888    #[instrument(
889        level = "trace",
890        skip(self, value),
891        fields(
892            restate.invocation.id = self.debug_invocation_id(),
893            restate.protocol.state = self.debug_state(),
894            restate.journal.command_index = self.context.journal.command_index(),
895            restate.protocol.version = %self.context.negotiated_protocol_version
896        ),
897        ret
898    )]
899    fn sys_complete_awakeable(
900        &mut self,
901        id: String,
902        value: NonEmptyValue,
903        options: PayloadOptions,
904    ) -> VMResult<()> {
905        invocation_debug_logs!(self, "Executing 'Complete awakeable {id}'");
906        self.verify_error_metadata_feature_support(&value)?;
907        self.do_transition(SysNonCompletableEntry(
908            CompleteAwakeableCommandMessage {
909                awakeable_id: id,
910                result: Some(match value {
911                    NonEmptyValue::Success(s) => {
912                        complete_awakeable_command_message::Result::Value(s.into())
913                    }
914                    NonEmptyValue::Failure(f) => {
915                        complete_awakeable_command_message::Result::Failure(f.into())
916                    }
917                }),
918                ..Default::default()
919            },
920            options,
921        ))
922    }
923
924    #[instrument(
925        level = "trace",
926        skip(self),
927        fields(
928            restate.invocation.id = self.debug_invocation_id(),
929            restate.protocol.state = self.debug_state(),
930            restate.journal.command_index = self.context.journal.command_index(),
931            restate.protocol.version = %self.context.negotiated_protocol_version
932        ),
933        ret
934    )]
935    fn create_signal_handle(&mut self, signal_name: String) -> VMResult<NotificationHandle> {
936        invocation_debug_logs!(self, "Executing 'Create named signal'");
937
938        self.do_transition(CreateSignalHandle(
939            "named awakeable",
940            NotificationId::SignalName(signal_name),
941        ))
942    }
943
944    #[instrument(
945        level = "trace",
946        skip(self, value),
947        fields(
948            restate.invocation.id = self.debug_invocation_id(),
949            restate.protocol.state = self.debug_state(),
950            restate.journal.command_index = self.context.journal.command_index(),
951            restate.protocol.version = %self.context.negotiated_protocol_version
952        ),
953        ret
954    )]
955    fn sys_complete_signal(
956        &mut self,
957        target_invocation_id: String,
958        signal_name: String,
959        value: NonEmptyValue,
960    ) -> VMResult<()> {
961        invocation_debug_logs!(self, "Executing 'Complete named signal {signal_name}'");
962        self.verify_error_metadata_feature_support(&value)?;
963        self.do_transition(SysNonCompletableEntry(
964            SendSignalCommandMessage {
965                target_invocation_id,
966                signal_id: Some(send_signal_command_message::SignalId::Name(signal_name)),
967                result: Some(match value {
968                    NonEmptyValue::Success(s) => {
969                        send_signal_command_message::Result::Value(s.into())
970                    }
971                    NonEmptyValue::Failure(f) => {
972                        send_signal_command_message::Result::Failure(f.into())
973                    }
974                }),
975                ..Default::default()
976            },
977            PayloadOptions::default(),
978        ))
979    }
980
981    #[instrument(
982        level = "trace",
983        skip(self),
984        fields(
985            restate.invocation.id = self.debug_invocation_id(),
986            restate.protocol.state = self.debug_state(),
987            restate.journal.command_index = self.context.journal.command_index(),
988            restate.protocol.version = %self.context.negotiated_protocol_version
989        ),
990        ret
991    )]
992    fn sys_get_promise(&mut self, key: String) -> VMResult<NotificationHandle> {
993        invocation_debug_logs!(self, "Executing 'Await promise {key}'");
994
995        let result_completion_id = self.context.journal.next_completion_notification_id();
996        self.do_transition(SysSimpleCompletableEntry(
997            GetPromiseCommandMessage {
998                key,
999                result_completion_id,
1000                ..Default::default()
1001            },
1002            result_completion_id,
1003            PayloadOptions::default(),
1004        ))
1005    }
1006
1007    #[instrument(
1008        level = "trace",
1009        skip(self),
1010        fields(
1011            restate.invocation.id = self.debug_invocation_id(),
1012            restate.protocol.state = self.debug_state(),
1013            restate.journal.command_index = self.context.journal.command_index(),
1014            restate.protocol.version = %self.context.negotiated_protocol_version
1015        ),
1016        ret
1017    )]
1018    fn sys_peek_promise(&mut self, key: String) -> VMResult<NotificationHandle> {
1019        invocation_debug_logs!(self, "Executing 'Peek promise {key}'");
1020
1021        let result_completion_id = self.context.journal.next_completion_notification_id();
1022        self.do_transition(SysSimpleCompletableEntry(
1023            PeekPromiseCommandMessage {
1024                key,
1025                result_completion_id,
1026                ..Default::default()
1027            },
1028            result_completion_id,
1029            PayloadOptions::default(),
1030        ))
1031    }
1032
1033    #[instrument(
1034        level = "trace",
1035        skip(self, value),
1036        fields(
1037            restate.invocation.id = self.debug_invocation_id(),
1038            restate.protocol.state = self.debug_state(),
1039            restate.journal.command_index = self.context.journal.command_index(),
1040            restate.protocol.version = %self.context.negotiated_protocol_version
1041        ),
1042        ret
1043    )]
1044    fn sys_complete_promise(
1045        &mut self,
1046        key: String,
1047        value: NonEmptyValue,
1048        options: PayloadOptions,
1049    ) -> VMResult<NotificationHandle> {
1050        invocation_debug_logs!(self, "Executing 'Complete promise {key}'");
1051        self.verify_error_metadata_feature_support(&value)?;
1052
1053        let result_completion_id = self.context.journal.next_completion_notification_id();
1054        self.do_transition(SysSimpleCompletableEntry(
1055            CompletePromiseCommandMessage {
1056                key,
1057                completion: Some(match value {
1058                    NonEmptyValue::Success(s) => {
1059                        complete_promise_command_message::Completion::CompletionValue(s.into())
1060                    }
1061                    NonEmptyValue::Failure(f) => {
1062                        complete_promise_command_message::Completion::CompletionFailure(f.into())
1063                    }
1064                }),
1065                result_completion_id,
1066                ..Default::default()
1067            },
1068            result_completion_id,
1069            options,
1070        ))
1071    }
1072
1073    #[instrument(
1074        level = "trace",
1075        skip(self),
1076        fields(
1077            restate.invocation.id = self.debug_invocation_id(),
1078            restate.protocol.state = self.debug_state(),
1079            restate.journal.command_index = self.context.journal.command_index(),
1080            restate.protocol.version = %self.context.negotiated_protocol_version
1081        ),
1082        ret
1083    )]
1084    fn sys_run(&mut self, name: String) -> VMResult<NotificationHandle> {
1085        match self.do_transition(SysRun(name.clone())) {
1086            Ok(handle) => {
1087                if enabled!(Level::DEBUG) {
1088                    // Store the name, we need it later when completing
1089                    self.sys_run_names.insert(handle, name);
1090                }
1091                Ok(handle)
1092            }
1093            Err(e) => Err(e),
1094        }
1095    }
1096
1097    #[instrument(
1098        level = "trace",
1099        skip(self, value, retry_policy),
1100        fields(
1101            restate.invocation.id = self.debug_invocation_id(),
1102            restate.protocol.state = self.debug_state(),
1103            restate.journal.command_index = self.context.journal.command_index(),
1104            restate.protocol.version = %self.context.negotiated_protocol_version
1105        ),
1106        ret
1107    )]
1108    fn propose_run_completion(
1109        &mut self,
1110        notification_handle: NotificationHandle,
1111        value: RunExitResult,
1112        retry_policy: RetryPolicy,
1113    ) -> VMResult<()> {
1114        if enabled!(Level::DEBUG) {
1115            let name: &str = self
1116                .sys_run_names
1117                .get(&notification_handle)
1118                .map(String::as_str)
1119                .unwrap_or_default();
1120            match &value {
1121                RunExitResult::Success(_) => {
1122                    invocation_debug_logs!(self, "Journaling run '{name}' success result");
1123                }
1124                RunExitResult::TerminalFailure(TerminalFailure { code, .. }) => {
1125                    invocation_debug_logs!(
1126                        self,
1127                        "Journaling run '{name}' terminal failure {code} result"
1128                    );
1129                }
1130                RunExitResult::RetryableFailure { .. } => {
1131                    invocation_debug_logs!(self, "Propagating run '{name}' retryable failure");
1132                }
1133            }
1134        }
1135        if let RunExitResult::TerminalFailure(f) = &value {
1136            if !f.metadata.is_empty() {
1137                self.verify_feature_support("terminal error metadata", Version::V6)?;
1138            }
1139        }
1140
1141        self.do_transition(ProposeRunCompletion(
1142            notification_handle,
1143            value,
1144            retry_policy,
1145        ))
1146    }
1147
1148    #[instrument(
1149        level = "trace",
1150        skip(self),
1151        fields(
1152            restate.invocation.id = self.debug_invocation_id(),
1153            restate.protocol.state = self.debug_state(),
1154            restate.journal.command_index = self.context.journal.command_index(),
1155            restate.protocol.version = %self.context.negotiated_protocol_version
1156        ),
1157        ret
1158    )]
1159    fn sys_cancel_invocation(&mut self, target_invocation_id: String) -> VMResult<()> {
1160        invocation_debug_logs!(
1161            self,
1162            "Executing 'Cancel invocation' of {target_invocation_id}"
1163        );
1164        self.do_transition(SysNonCompletableEntry(
1165            SendSignalCommandMessage {
1166                target_invocation_id,
1167                signal_id: Some(send_signal_command_message::SignalId::Idx(CANCEL_SIGNAL_ID)),
1168                result: Some(send_signal_command_message::Result::Void(Default::default())),
1169                ..Default::default()
1170            },
1171            PayloadOptions::default(),
1172        ))
1173    }
1174
1175    #[instrument(
1176        level = "trace",
1177        skip(self),
1178        fields(
1179            restate.invocation.id = self.debug_invocation_id(),
1180            restate.protocol.state = self.debug_state(),
1181            restate.journal.command_index = self.context.journal.command_index(),
1182            restate.protocol.version = %self.context.negotiated_protocol_version
1183        ),
1184        ret
1185    )]
1186    fn sys_attach_invocation(
1187        &mut self,
1188        target: AttachInvocationTarget,
1189    ) -> VMResult<NotificationHandle> {
1190        invocation_debug_logs!(self, "Executing 'Attach invocation'");
1191
1192        let result_completion_id = self.context.journal.next_completion_notification_id();
1193        self.do_transition(SysSimpleCompletableEntry(
1194            AttachInvocationCommandMessage {
1195                target: Some(match target {
1196                    AttachInvocationTarget::InvocationId(id) => {
1197                        attach_invocation_command_message::Target::InvocationId(id)
1198                    }
1199                    AttachInvocationTarget::WorkflowId { name, key } => {
1200                        attach_invocation_command_message::Target::WorkflowTarget(WorkflowTarget {
1201                            workflow_name: name,
1202                            workflow_key: key,
1203                        })
1204                    }
1205                    AttachInvocationTarget::IdempotencyId {
1206                        service_name,
1207                        service_key,
1208                        handler_name,
1209                        idempotency_key,
1210                    } => attach_invocation_command_message::Target::IdempotentRequestTarget(
1211                        IdempotentRequestTarget {
1212                            service_name,
1213                            service_key,
1214                            handler_name,
1215                            idempotency_key,
1216                        },
1217                    ),
1218                }),
1219                result_completion_id,
1220                ..Default::default()
1221            },
1222            result_completion_id,
1223            PayloadOptions::default(),
1224        ))
1225    }
1226
1227    #[instrument(
1228        level = "trace",
1229        skip(self),
1230        fields(
1231            restate.invocation.id = self.debug_invocation_id(),
1232            restate.protocol.state = self.debug_state(),
1233            restate.journal.command_index = self.context.journal.command_index(),
1234            restate.protocol.version = %self.context.negotiated_protocol_version
1235        ),
1236        ret
1237    )]
1238    fn sys_get_invocation_output(
1239        &mut self,
1240        target: AttachInvocationTarget,
1241    ) -> VMResult<NotificationHandle> {
1242        invocation_debug_logs!(self, "Executing 'Get invocation output'");
1243
1244        let result_completion_id = self.context.journal.next_completion_notification_id();
1245        self.do_transition(SysSimpleCompletableEntry(
1246            GetInvocationOutputCommandMessage {
1247                target: Some(match target {
1248                    AttachInvocationTarget::InvocationId(id) => {
1249                        get_invocation_output_command_message::Target::InvocationId(id)
1250                    }
1251                    AttachInvocationTarget::WorkflowId { name, key } => {
1252                        get_invocation_output_command_message::Target::WorkflowTarget(
1253                            WorkflowTarget {
1254                                workflow_name: name,
1255                                workflow_key: key,
1256                            },
1257                        )
1258                    }
1259                    AttachInvocationTarget::IdempotencyId {
1260                        service_name,
1261                        service_key,
1262                        handler_name,
1263                        idempotency_key,
1264                    } => get_invocation_output_command_message::Target::IdempotentRequestTarget(
1265                        IdempotentRequestTarget {
1266                            service_name,
1267                            service_key,
1268                            handler_name,
1269                            idempotency_key,
1270                        },
1271                    ),
1272                }),
1273                result_completion_id,
1274                ..Default::default()
1275            },
1276            result_completion_id,
1277            PayloadOptions::default(),
1278        ))
1279    }
1280
1281    #[instrument(
1282        level = "trace",
1283        skip(self, value),
1284        fields(
1285            restate.invocation.id = self.debug_invocation_id(),
1286            restate.protocol.state = self.debug_state(),
1287            restate.journal.command_index = self.context.journal.command_index(),
1288            restate.protocol.version = %self.context.negotiated_protocol_version
1289        ),
1290        ret
1291    )]
1292    fn sys_write_output(&mut self, value: NonEmptyValue, options: PayloadOptions) -> VMResult<()> {
1293        match &value {
1294            NonEmptyValue::Success(_) => {
1295                invocation_debug_logs!(self, "Writing invocation result success value");
1296            }
1297            NonEmptyValue::Failure(_) => {
1298                invocation_debug_logs!(self, "Writing invocation result failure value");
1299            }
1300        }
1301        self.verify_error_metadata_feature_support(&value)?;
1302        self.do_transition(SysNonCompletableEntry(
1303            OutputCommandMessage {
1304                result: Some(match value {
1305                    NonEmptyValue::Success(b) => output_command_message::Result::Value(b.into()),
1306                    NonEmptyValue::Failure(f) => output_command_message::Result::Failure(f.into()),
1307                }),
1308                ..OutputCommandMessage::default()
1309            },
1310            options,
1311        ))
1312    }
1313
1314    #[instrument(
1315        level = "trace",
1316        skip(self),
1317        fields(
1318            restate.invocation.id = self.debug_invocation_id(),
1319            restate.protocol.state = self.debug_state(),
1320            restate.journal.command_index = self.context.journal.command_index(),
1321            restate.protocol.version = %self.context.negotiated_protocol_version
1322        ),
1323        ret
1324    )]
1325    fn sys_end(&mut self) -> Result<(), Error> {
1326        invocation_debug_logs!(self, "End of the invocation");
1327        self.do_transition(SysEnd)
1328    }
1329
1330    fn is_waiting_preflight(&self) -> bool {
1331        matches!(
1332            &self.last_transition,
1333            Ok(State::WaitingStart) | Ok(State::WaitingReplayEntries { .. })
1334        )
1335    }
1336
1337    fn is_replaying(&self) -> bool {
1338        matches!(&self.last_transition, Ok(State::Replaying { .. }))
1339    }
1340
1341    fn is_processing(&self) -> bool {
1342        matches!(&self.last_transition, Ok(State::Processing { .. }))
1343    }
1344
1345    fn last_command_index(&self) -> i64 {
1346        self.context.journal.command_index()
1347    }
1348}
1349
/// Base64 configuration that accepts input with or without padding when decoding, and
/// never emits padding when encoding.
const INDIFFERENT_PAD: GeneralPurposeConfig = GeneralPurposeConfig::new()
    .with_decode_padding_mode(DecodePaddingMode::Indifferent)
    .with_encode_padding(false);
/// Base64 engine using the URL-safe alphabet with the [`INDIFFERENT_PAD`] configuration;
/// `awakeable_id_str` uses it to encode awakeable ids.
const URL_SAFE: GeneralPurpose = GeneralPurpose::new(&alphabet::URL_SAFE, INDIFFERENT_PAD);

/// Prefix placed in front of every id string produced by `awakeable_id_str`.
const AWAKEABLE_PREFIX: &str = "sign_1";
1356
1357pub(super) fn awakeable_id_str(id: &[u8], completion_index: u32) -> String {
1358    let mut input_buf = BytesMut::with_capacity(id.len() + size_of::<u32>());
1359    input_buf.put_slice(id);
1360    input_buf.put_u32(completion_index);
1361    format!("{AWAKEABLE_PREFIX}{}", URL_SAFE.encode(input_buf.freeze()))
1362}