restate_sdk_shared_core/vm/mod.rs

1use crate::headers::HeaderMap;
2use crate::service_protocol::messages::{
3    attach_invocation_command_message, complete_awakeable_command_message,
4    complete_promise_command_message, get_invocation_output_command_message,
5    output_command_message, send_signal_command_message, AttachInvocationCommandMessage,
6    CallCommandMessage, ClearAllStateCommandMessage, ClearStateCommandMessage,
7    CompleteAwakeableCommandMessage, CompletePromiseCommandMessage,
8    GetInvocationOutputCommandMessage, GetPromiseCommandMessage, IdempotentRequestTarget,
9    OneWayCallCommandMessage, OutputCommandMessage, PeekPromiseCommandMessage,
10    SendSignalCommandMessage, SetStateCommandMessage, SleepCommandMessage, WorkflowTarget,
11};
12use crate::service_protocol::{Decoder, NotificationId, RawMessage, Version, CANCEL_SIGNAL_ID};
13use crate::vm::errors::{
14    ClosedError, UnexpectedStateError, UnsupportedFeatureForNegotiatedVersion,
15    EMPTY_IDEMPOTENCY_KEY, SUSPENDED,
16};
17use crate::vm::transitions::*;
18use crate::{
19    AttachInvocationTarget, CallHandle, CommandRelationship, DoProgressResponse, Error, Header,
20    ImplicitCancellationOption, Input, NonEmptyValue, NotificationHandle, ResponseHead,
21    RetryPolicy, RunExitResult, SendHandle, TakeOutputResult, Target, TerminalFailure, VMOptions,
22    VMResult, Value, CANCEL_NOTIFICATION_HANDLE,
23};
24use base64::engine::{DecodePaddingMode, GeneralPurpose, GeneralPurposeConfig};
25use base64::{alphabet, Engine};
26use bytes::{Buf, BufMut, Bytes, BytesMut};
27use context::{AsyncResultsState, Context, Output, RunState};
28use std::borrow::Cow;
29use std::collections::{HashMap, VecDeque};
30use std::mem::size_of;
31use std::time::Duration;
32use std::{fmt, mem};
33use strum::IntoStaticStr;
34use tracing::{debug, enabled, instrument, Level};
35
36mod context;
37pub(crate) mod errors;
38mod transitions;
39
/// Header through which the service protocol version is negotiated: it is parsed
/// into a `Version` from the request and echoed back in the response head.
const CONTENT_TYPE: &str = "content-type";
41
/// States of the VM's protocol state machine.
///
/// The variant name (obtained through `IntoStaticStr`) is what `CoreVM::debug_state`
/// and `State::as_unexpected_state` report as the current state.
#[derive(Debug, IntoStaticStr)]
pub(crate) enum State {
    /// Initial state: `CoreVM::new` starts every VM here.
    WaitingStart,
    /// Not yet ready to execute (`is_ready_to_execute` returns `false`).
    /// Presumably still receiving the journal entries to replay — TODO confirm
    /// against the transitions module.
    WaitingReplayEntries {
        // NOTE(review): presumably the number of journal entries received so far.
        received_entries: u32,
        commands: VecDeque<RawMessage>,
        async_results: AsyncResultsState,
    },
    /// Ready to execute. Presumably re-executing the commands already recorded
    /// in the journal (`commands`) — confirm.
    Replaying {
        commands: VecDeque<RawMessage>,
        run_state: RunState,
        async_results: AsyncResultsState,
    },
    /// Ready to execute and producing new journal entries.
    Processing {
        processing_first_entry: bool,
        run_state: RunState,
        async_results: AsyncResultsState,
    },
    /// Terminal state; events hitting it are reported as `ClosedError`.
    Closed,
}
62
63impl State {
64    fn as_unexpected_state(&self, event: &'static str) -> Error {
65        if matches!(self, State::Closed) {
66            return ClosedError::new(event).into();
67        }
68        UnexpectedStateError::new(self.into(), event).into()
69    }
70}
71
72struct TrackedInvocationId {
73    handle: NotificationHandle,
74    invocation_id: Option<String>,
75}
76
77impl TrackedInvocationId {
78    fn is_resolved(&self) -> bool {
79        self.invocation_id.is_some()
80    }
81}
82
/// Core implementation of the service protocol state machine (`super::VM`).
pub struct CoreVM {
    /// Protocol version negotiated from the `content-type` request header.
    version: Version,
    /// Options given at construction (e.g. implicit cancellation settings).
    options: VMOptions,

    // Input decoder
    decoder: Decoder,

    // State machine
    context: Context,
    // Current state, or the error that made the VM fail (reported as "Failed").
    last_transition: Result<State, Error>,

    // Implicit cancellation tracking: invocation-id handles of child calls/sends,
    // pushed in handle order (`take_notification` binary-searches this list).
    tracked_invocation_ids: Vec<TrackedInvocationId>,

    // Run names, useful for debugging
    sys_run_names: HashMap<NotificationHandle, String>,
}
100
101impl CoreVM {
102    // Returns empty string if the invocation id is not present
103    fn debug_invocation_id(&self) -> &str {
104        if let Some(start_info) = self.context.start_info() {
105            &start_info.debug_id
106        } else {
107            ""
108        }
109    }
110
111    fn debug_state(&self) -> &'static str {
112        match &self.last_transition {
113            Ok(s) => s.into(),
114            Err(_) => "Failed",
115        }
116    }
117
118    fn verify_feature_support(
119        &mut self,
120        feature: &'static str,
121        minimum_required_protocol: Version,
122    ) -> VMResult<()> {
123        if self.version < minimum_required_protocol {
124            return self.do_transition(HitError(
125                UnsupportedFeatureForNegotiatedVersion::new(
126                    feature,
127                    self.version,
128                    minimum_required_protocol,
129                )
130                .into(),
131            ));
132        }
133        Ok(())
134    }
135
136    fn _is_completed(&self, handle: NotificationHandle) -> bool {
137        match &self.last_transition {
138            Ok(State::Replaying { async_results, .. })
139            | Ok(State::Processing { async_results, .. }) => {
140                async_results.is_handle_completed(handle)
141            }
142            _ => false,
143        }
144    }
145
146    fn _do_progress(
147        &mut self,
148        any_handle: Vec<NotificationHandle>,
149    ) -> Result<DoProgressResponse, Error> {
150        match self.do_transition(DoProgress(any_handle)) {
151            Ok(Ok(do_progress_response)) => Ok(do_progress_response),
152            Ok(Err(_)) => Err(SUSPENDED),
153            Err(e) => Err(e),
154        }
155    }
156
157    fn is_implicit_cancellation_enabled(&self) -> bool {
158        matches!(
159            self.options.implicit_cancellation,
160            ImplicitCancellationOption::Enabled { .. }
161        )
162    }
163}
164
165impl fmt::Debug for CoreVM {
166    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
167        let mut s = f.debug_struct("CoreVM");
168        s.field("version", &self.version);
169
170        if let Some(start_info) = self.context.start_info() {
171            s.field("invocation_id", &start_info.debug_id);
172        }
173
174        match &self.last_transition {
175            Ok(state) => s.field("last_transition", &<&'static str>::from(state)),
176            Err(_) => s.field("last_transition", &"Errored"),
177        };
178
179        s.field("command_index", &self.context.journal.command_index())
180            .field(
181                "notification_index",
182                &self.context.journal.notification_index(),
183            )
184            .finish()
185    }
186}
187
// --- Bound checks
// Compile-time assertion that `CoreVM` is `Send`. The anonymous const below only
// type-checks if `CoreVM` satisfies `is_send`'s `T: Send` bound, so adding a
// non-`Send` field breaks the build right here.
#[allow(unused)]
const fn is_send<T: Send>() {}
const _: () = is_send::<CoreVM>();
192
// Macro used for informative debug logs: forwards the format arguments to
// `tracing::debug!`, but only while the VM reports `is_processing()`, so commands
// re-executed during replay are not logged again.
macro_rules! invocation_debug_logs {
    ($vm:expr, $($fmt_args:tt)*) => {
        if $vm.is_processing() {
            tracing::debug!($($fmt_args)*)
        }
    };
}
201
202impl super::VM for CoreVM {
203    #[instrument(level = "trace", skip(request_headers), ret)]
204    fn new(request_headers: impl HeaderMap, options: VMOptions) -> Result<Self, Error> {
205        let version = request_headers
206            .extract(CONTENT_TYPE)
207            .map_err(|e| {
208                Error::new(
209                    errors::codes::BAD_REQUEST,
210                    format!("cannot read '{CONTENT_TYPE}' header: {e:?}"),
211                )
212            })?
213            .ok_or(errors::MISSING_CONTENT_TYPE)?
214            .parse::<Version>()?;
215
216        if version < Version::minimum_supported_version()
217            || version > Version::maximum_supported_version()
218        {
219            return Err(Error::new(
220                errors::codes::UNSUPPORTED_MEDIA_TYPE,
221                format!(
222                    "Unsupported protocol version {:?}, not within [{:?} to {:?}]. \
223                    You might need to rediscover the service, check https://docs.restate.dev/references/errors/#RT0015",
224                    version,
225                    Version::minimum_supported_version(),
226                    Version::maximum_supported_version()
227                ),
228            ));
229        }
230
231        Ok(Self {
232            version,
233            options,
234            decoder: Decoder::new(version),
235            context: Context {
236                input_is_closed: false,
237                output: Output::new(version),
238                start_info: None,
239                journal: Default::default(),
240                eager_state: Default::default(),
241            },
242            last_transition: Ok(State::WaitingStart),
243            tracked_invocation_ids: vec![],
244            sys_run_names: HashMap::with_capacity(0),
245        })
246    }
247
248    #[instrument(
249        level = "trace",
250        skip(self),
251        fields(
252            restate.invocation.id = self.debug_invocation_id(),
253            restate.protocol.state = self.debug_state(),
254            restate.journal.command_index = self.context.journal.command_index(),
255            restate.protocol.version = %self.version
256        ),
257        ret
258    )]
259    fn get_response_head(&self) -> ResponseHead {
260        ResponseHead {
261            status_code: 200,
262            headers: vec![Header {
263                key: Cow::Borrowed(CONTENT_TYPE),
264                value: Cow::Borrowed(self.version.content_type()),
265            }],
266            version: self.version,
267        }
268    }
269
270    #[instrument(
271        level = "trace",
272        skip(self),
273        fields(
274            restate.invocation.id = self.debug_invocation_id(),
275            restate.protocol.state = self.debug_state(),
276            restate.journal.command_index = self.context.journal.command_index(),
277            restate.protocol.version = %self.version
278        ),
279        ret
280    )]
281    fn notify_input(&mut self, buffer: Bytes) {
282        self.decoder.push(buffer);
283        loop {
284            match self.decoder.consume_next() {
285                Ok(Some(msg)) => {
286                    if self.do_transition(NewMessage(msg)).is_err() {
287                        return;
288                    }
289                }
290                Ok(None) => {
291                    return;
292                }
293                Err(e) => {
294                    if self.do_transition(HitError(e.into())).is_err() {
295                        return;
296                    }
297                }
298            }
299        }
300    }
301
302    #[instrument(
303        level = "trace",
304        skip(self),
305        fields(
306            restate.invocation.id = self.debug_invocation_id(),
307            restate.protocol.state = self.debug_state(),
308            restate.journal.command_index = self.context.journal.command_index(),
309            restate.protocol.version = %self.version
310        ),
311        ret
312    )]
313    fn notify_input_closed(&mut self) {
314        self.context.input_is_closed = true;
315        let _ = self.do_transition(NotifyInputClosed);
316    }
317
318    #[instrument(
319        level = "trace",
320        skip(self),
321        fields(
322            restate.invocation.id = self.debug_invocation_id(),
323            restate.protocol.state = self.debug_state(),
324            restate.journal.command_index = self.context.journal.command_index(),
325            restate.protocol.version = %self.version
326        ),
327        ret
328    )]
329    fn notify_error(
330        &mut self,
331        mut error: Error,
332        command_relationship: Option<CommandRelationship>,
333    ) {
334        if let Some(command_relationship) = command_relationship {
335            error = error.with_related_command_metadata(
336                self.context
337                    .journal
338                    .resolve_related_command(command_relationship),
339            );
340        }
341
342        let _ = self.do_transition(HitError(error));
343    }
344
345    #[instrument(
346        level = "trace",
347        skip(self),
348        fields(
349            restate.invocation.id = self.debug_invocation_id(),
350            restate.protocol.state = self.debug_state(),
351            restate.journal.command_index = self.context.journal.command_index(),
352            restate.protocol.version = %self.version
353        ),
354        ret
355    )]
356    fn take_output(&mut self) -> TakeOutputResult {
357        if self.context.output.buffer.has_remaining() {
358            TakeOutputResult::Buffer(
359                self.context
360                    .output
361                    .buffer
362                    .copy_to_bytes(self.context.output.buffer.remaining()),
363            )
364        } else if !self.context.output.is_closed() {
365            TakeOutputResult::Buffer(Bytes::default())
366        } else {
367            TakeOutputResult::EOF
368        }
369    }
370
371    #[instrument(
372        level = "trace",
373        skip(self),
374        fields(
375            restate.invocation.id = self.debug_invocation_id(),
376            restate.protocol.state = self.debug_state(),
377            restate.journal.command_index = self.context.journal.command_index(),
378            restate.protocol.version = %self.version
379        ),
380        ret
381    )]
382    fn is_ready_to_execute(&self) -> Result<bool, Error> {
383        match &self.last_transition {
384            Ok(State::WaitingStart) | Ok(State::WaitingReplayEntries { .. }) => Ok(false),
385            Ok(State::Processing { .. }) | Ok(State::Replaying { .. }) => Ok(true),
386            Ok(s) => Err(s.as_unexpected_state("IsReadyToExecute")),
387            Err(e) => Err(e.clone()),
388        }
389    }
390
391    #[instrument(
392        level = "trace",
393        skip(self),
394        fields(
395            restate.invocation.id = self.debug_invocation_id(),
396            restate.protocol.state = self.debug_state(),
397            restate.journal.command_index = self.context.journal.command_index(),
398            restate.protocol.version = %self.version
399        ),
400        ret
401    )]
402    fn is_completed(&self, handle: NotificationHandle) -> bool {
403        self._is_completed(handle)
404    }
405
406    #[instrument(
407        level = "trace",
408        skip(self),
409        fields(
410            restate.invocation.id = self.debug_invocation_id(),
411            restate.protocol.state = self.debug_state(),
412            restate.journal.command_index = self.context.journal.command_index(),
413            restate.protocol.version = %self.version
414        ),
415        ret
416    )]
417    fn do_progress(
418        &mut self,
419        mut any_handle: Vec<NotificationHandle>,
420    ) -> VMResult<DoProgressResponse> {
421        if self.is_implicit_cancellation_enabled() {
422            // We want the runtime to wake us up in case cancel notification comes in.
423            any_handle.insert(0, CANCEL_NOTIFICATION_HANDLE);
424
425            match self._do_progress(any_handle) {
426                Ok(DoProgressResponse::AnyCompleted) => {
427                    // If it's cancel signal, then let's go on with the cancellation logic
428                    if self._is_completed(CANCEL_NOTIFICATION_HANDLE) {
429                        // Loop once over the tracked invocation ids to resolve the unresolved ones
430                        for i in 0..self.tracked_invocation_ids.len() {
431                            if self.tracked_invocation_ids[i].is_resolved() {
432                                continue;
433                            }
434
435                            let handle = self.tracked_invocation_ids[i].handle;
436
437                            // Try to resolve it
438                            match self._do_progress(vec![handle]) {
439                                Ok(DoProgressResponse::AnyCompleted) => {
440                                    let invocation_id = match self.do_transition(CopyNotification(handle)) {
441                                        Ok(Ok(Some(Value::InvocationId(invocation_id)))) => Ok(invocation_id),
442                                        Ok(Err(_)) => Err(SUSPENDED),
443                                        _ => panic!("Unexpected variant! If the id handle is completed, it must be an invocation id handle!")
444                                    }?;
445
446                                    // This handle is resolved
447                                    self.tracked_invocation_ids[i].invocation_id =
448                                        Some(invocation_id);
449                                }
450                                res => return res,
451                            }
452                        }
453
454                        // Now we got all the invocation IDs, let's cancel!
455                        for tracked_invocation_id in mem::take(&mut self.tracked_invocation_ids) {
456                            self.sys_cancel_invocation(
457                                tracked_invocation_id
458                                    .invocation_id
459                                    .expect("We resolved before all the invocation ids"),
460                            )?;
461                        }
462
463                        // Flip the cancellation
464                        let _ = self.take_notification(CANCEL_NOTIFICATION_HANDLE);
465
466                        // Done
467                        Ok(DoProgressResponse::CancelSignalReceived)
468                    } else {
469                        Ok(DoProgressResponse::AnyCompleted)
470                    }
471                }
472                res => res,
473            }
474        } else {
475            self._do_progress(any_handle)
476        }
477    }
478
479    #[instrument(
480        level = "trace",
481        skip(self),
482        fields(
483            restate.invocation.id = self.debug_invocation_id(),
484            restate.protocol.state = self.debug_state(),
485            restate.journal.command_index = self.context.journal.command_index(),
486            restate.protocol.version = %self.version
487        ),
488        ret
489    )]
490    fn take_notification(&mut self, handle: NotificationHandle) -> VMResult<Option<Value>> {
491        match self.do_transition(TakeNotification(handle)) {
492            Ok(Ok(Some(value))) => {
493                if self.is_implicit_cancellation_enabled() {
494                    // Let's check if that's one of the tracked invocation ids
495                    // We can do binary search here because we assume tracked_invocation_ids is ordered, as handles are incremental numbers
496                    if let Ok(found) = self
497                        .tracked_invocation_ids
498                        .binary_search_by(|tracked| tracked.handle.cmp(&handle))
499                    {
500                        let Value::InvocationId(invocation_id) = &value else {
501                            panic!("Expecting an invocation id here, but got {value:?}");
502                        };
503                        // Keep track of this invocation id
504                        self.tracked_invocation_ids
505                            .get_mut(found)
506                            .unwrap()
507                            .invocation_id = Some(invocation_id.clone());
508                    }
509                }
510
511                Ok(Some(value))
512            }
513            Ok(Ok(None)) => Ok(None),
514            Ok(Err(_)) => Err(SUSPENDED),
515            Err(e) => Err(e),
516        }
517    }
518
519    #[instrument(
520        level = "trace",
521        skip(self),
522        fields(
523            restate.invocation.id = self.debug_invocation_id(),
524            restate.protocol.state = self.debug_state(),
525            restate.journal.command_index = self.context.journal.command_index(),
526            restate.protocol.version = %self.version
527        ),
528        ret
529    )]
530    fn sys_input(&mut self) -> Result<Input, Error> {
531        self.do_transition(SysInput)
532    }
533
534    #[instrument(
535        level = "trace",
536        skip(self),
537        fields(
538            restate.invocation.id = self.debug_invocation_id(),
539            restate.protocol.state = self.debug_state(),
540            restate.journal.command_index = self.context.journal.command_index(),
541            restate.protocol.version = %self.version
542        ),
543        ret
544    )]
545    fn sys_state_get(&mut self, key: String) -> Result<NotificationHandle, Error> {
546        invocation_debug_logs!(self, "Executing 'Get state {key}'");
547        self.do_transition(SysStateGet(key))
548    }
549
550    #[instrument(
551        level = "trace",
552        skip(self),
553        fields(
554            restate.invocation.id = self.debug_invocation_id(),
555            restate.protocol.state = self.debug_state(),
556            restate.journal.command_index = self.context.journal.command_index(),
557            restate.protocol.version = %self.version
558        ),
559        ret
560    )]
561    fn sys_state_get_keys(&mut self) -> VMResult<NotificationHandle> {
562        invocation_debug_logs!(self, "Executing 'Get state keys'");
563        self.do_transition(SysStateGetKeys)
564    }
565
566    #[instrument(
567        level = "trace",
568        skip(self, value),
569        fields(
570            restate.invocation.id = self.debug_invocation_id(),
571            restate.protocol.state = self.debug_state(),
572            restate.journal.command_index = self.context.journal.command_index(),
573            restate.protocol.version = %self.version
574        ),
575        ret
576    )]
577    fn sys_state_set(&mut self, key: String, value: Bytes) -> Result<(), Error> {
578        invocation_debug_logs!(self, "Executing 'Set state {key}'");
579        self.context.eager_state.set(key.clone(), value.clone());
580        self.do_transition(SysNonCompletableEntry(
581            "SysStateSet",
582            SetStateCommandMessage {
583                key: Bytes::from(key.into_bytes()),
584                value: Some(value.into()),
585                ..SetStateCommandMessage::default()
586            },
587        ))
588    }
589
590    #[instrument(
591        level = "trace",
592        skip(self),
593        fields(
594            restate.invocation.id = self.debug_invocation_id(),
595            restate.protocol.state = self.debug_state(),
596            restate.journal.command_index = self.context.journal.command_index(),
597            restate.protocol.version = %self.version
598        ),
599        ret
600    )]
601    fn sys_state_clear(&mut self, key: String) -> Result<(), Error> {
602        invocation_debug_logs!(self, "Executing 'Clear state {key}'");
603        self.context.eager_state.clear(key.clone());
604        self.do_transition(SysNonCompletableEntry(
605            "SysStateClear",
606            ClearStateCommandMessage {
607                key: Bytes::from(key.into_bytes()),
608                ..ClearStateCommandMessage::default()
609            },
610        ))
611    }
612
613    #[instrument(
614        level = "trace",
615        skip(self),
616        fields(
617            restate.invocation.id = self.debug_invocation_id(),
618            restate.protocol.state = self.debug_state(),
619            restate.journal.command_index = self.context.journal.command_index(),
620            restate.protocol.version = %self.version
621        ),
622        ret
623    )]
624    fn sys_state_clear_all(&mut self) -> Result<(), Error> {
625        invocation_debug_logs!(self, "Executing 'Clear all state'");
626        self.context.eager_state.clear_all();
627        self.do_transition(SysNonCompletableEntry(
628            "SysStateClearAll",
629            ClearAllStateCommandMessage::default(),
630        ))
631    }
632
633    #[instrument(
634        level = "trace",
635        skip(self),
636        fields(
637            restate.invocation.id = self.debug_invocation_id(),
638            restate.protocol.state = self.debug_state(),
639            restate.journal.command_index = self.context.journal.command_index(),
640            restate.protocol.version = %self.version
641        ),
642        ret
643    )]
644    fn sys_sleep(
645        &mut self,
646        name: String,
647        wake_up_time_since_unix_epoch: Duration,
648        now_since_unix_epoch: Option<Duration>,
649    ) -> VMResult<NotificationHandle> {
650        if self.is_processing() {
651            match (&name, now_since_unix_epoch) {
652                (name, Some(now_since_unix_epoch)) if name.is_empty() => {
653                    debug!(
654                        "Executing 'Timer with duration {:?}'",
655                        wake_up_time_since_unix_epoch - now_since_unix_epoch
656                    );
657                }
658                (name, Some(now_since_unix_epoch)) => {
659                    debug!(
660                        "Executing 'Timer {name} with duration {:?}'",
661                        wake_up_time_since_unix_epoch - now_since_unix_epoch
662                    );
663                }
664                (name, None) if name.is_empty() => {
665                    debug!("Executing 'Timer'");
666                }
667                (name, None) => {
668                    debug!("Executing 'Timer named {name}'");
669                }
670            }
671        }
672
673        let completion_id = self.context.journal.next_completion_notification_id();
674
675        self.do_transition(SysSimpleCompletableEntry(
676            "SysSleep",
677            SleepCommandMessage {
678                wake_up_time: u64::try_from(wake_up_time_since_unix_epoch.as_millis())
679                    .expect("millis since Unix epoch should fit in u64"),
680                result_completion_id: completion_id,
681                name,
682            },
683            completion_id,
684        ))
685    }
686
687    #[instrument(
688        level = "trace",
689        skip(self, input),
690        fields(
691            restate.invocation.id = self.debug_invocation_id(),
692            restate.protocol.state = self.debug_state(),
693            restate.journal.command_index = self.context.journal.command_index(),
694            restate.protocol.version = %self.version
695        ),
696        ret
697    )]
698    fn sys_call(&mut self, target: Target, input: Bytes) -> VMResult<CallHandle> {
699        invocation_debug_logs!(
700            self,
701            "Executing 'Call {}/{}'",
702            target.service,
703            target.handler
704        );
705        if let Some(idempotency_key) = &target.idempotency_key {
706            self.verify_feature_support("attach idempotency key to call", Version::V3)?;
707            if idempotency_key.is_empty() {
708                self.do_transition(HitError(EMPTY_IDEMPOTENCY_KEY))?;
709                unreachable!();
710            }
711        }
712
713        let call_invocation_id_completion_id =
714            self.context.journal.next_completion_notification_id();
715        let result_completion_id = self.context.journal.next_completion_notification_id();
716
717        let handles = self.do_transition(SysCompletableEntryWithMultipleCompletions(
718            "SysCall",
719            CallCommandMessage {
720                service_name: target.service,
721                handler_name: target.handler,
722                key: target.key.unwrap_or_default(),
723                idempotency_key: target.idempotency_key,
724                headers: target
725                    .headers
726                    .into_iter()
727                    .map(crate::service_protocol::messages::Header::from)
728                    .collect(),
729                parameter: input,
730                invocation_id_notification_idx: call_invocation_id_completion_id,
731                result_completion_id,
732                ..Default::default()
733            },
734            vec![call_invocation_id_completion_id, result_completion_id],
735        ))?;
736
737        if matches!(
738            self.options.implicit_cancellation,
739            ImplicitCancellationOption::Enabled {
740                cancel_children_calls: true,
741                ..
742            }
743        ) {
744            self.tracked_invocation_ids.push(TrackedInvocationId {
745                handle: handles[0],
746                invocation_id: None,
747            })
748        }
749
750        Ok(CallHandle {
751            invocation_id_notification_handle: handles[0],
752            call_notification_handle: handles[1],
753        })
754    }
755
756    #[instrument(
757        level = "trace",
758        skip(self, input),
759        fields(
760            restate.invocation.id = self.debug_invocation_id(),
761            restate.protocol.state = self.debug_state(),
762            restate.journal.command_index = self.context.journal.command_index(),
763            restate.protocol.version = %self.version
764        ),
765        ret
766    )]
767    fn sys_send(
768        &mut self,
769        target: Target,
770        input: Bytes,
771        delay: Option<Duration>,
772    ) -> VMResult<SendHandle> {
773        invocation_debug_logs!(
774            self,
775            "Executing 'Send to {}/{}'",
776            target.service,
777            target.handler
778        );
779        if let Some(idempotency_key) = &target.idempotency_key {
780            self.verify_feature_support("attach idempotency key to one way call", Version::V3)?;
781            if idempotency_key.is_empty() {
782                self.do_transition(HitError(EMPTY_IDEMPOTENCY_KEY))?;
783                unreachable!();
784            }
785        }
786        let call_invocation_id_completion_id =
787            self.context.journal.next_completion_notification_id();
788        let invocation_id_notification_handle = self.do_transition(SysSimpleCompletableEntry(
789            "SysOneWayCall",
790            OneWayCallCommandMessage {
791                service_name: target.service,
792                handler_name: target.handler,
793                key: target.key.unwrap_or_default(),
794                idempotency_key: target.idempotency_key,
795                headers: target
796                    .headers
797                    .into_iter()
798                    .map(crate::service_protocol::messages::Header::from)
799                    .collect(),
800                parameter: input,
801                invoke_time: delay
802                    .map(|d| {
803                        u64::try_from(d.as_millis())
804                            .expect("millis since Unix epoch should fit in u64")
805                    })
806                    .unwrap_or_default(),
807                invocation_id_notification_idx: call_invocation_id_completion_id,
808                ..Default::default()
809            },
810            call_invocation_id_completion_id,
811        ))?;
812
813        if matches!(
814            self.options.implicit_cancellation,
815            ImplicitCancellationOption::Enabled {
816                cancel_children_one_way_calls: true,
817                ..
818            }
819        ) {
820            self.tracked_invocation_ids.push(TrackedInvocationId {
821                handle: invocation_id_notification_handle,
822                invocation_id: None,
823            })
824        }
825
826        Ok(SendHandle {
827            invocation_id_notification_handle,
828        })
829    }
830
831    #[instrument(
832        level = "trace",
833        skip(self),
834        fields(
835            restate.invocation.id = self.debug_invocation_id(),
836            restate.protocol.state = self.debug_state(),
837            restate.journal.command_index = self.context.journal.command_index(),
838            restate.protocol.version = %self.version
839        ),
840        ret
841    )]
842    fn sys_awakeable(&mut self) -> VMResult<(String, NotificationHandle)> {
843        invocation_debug_logs!(self, "Executing 'Create awakeable'");
844
845        let signal_id = self.context.journal.next_signal_notification_id();
846
847        let handle = self.do_transition(CreateSignalHandle(
848            "SysAwakeable",
849            NotificationId::SignalId(signal_id),
850        ))?;
851
852        Ok((
853            awakeable_id_str(&self.context.expect_start_info().id, signal_id),
854            handle,
855        ))
856    }
857
858    #[instrument(
859        level = "trace",
860        skip(self, value),
861        fields(
862            restate.invocation.id = self.debug_invocation_id(),
863            restate.protocol.state = self.debug_state(),
864            restate.journal.command_index = self.context.journal.command_index(),
865            restate.protocol.version = %self.version
866        ),
867        ret
868    )]
869    fn sys_complete_awakeable(&mut self, id: String, value: NonEmptyValue) -> VMResult<()> {
870        invocation_debug_logs!(self, "Executing 'Complete awakeable {id}'");
871        self.do_transition(SysNonCompletableEntry(
872            "SysCompleteAwakeable",
873            CompleteAwakeableCommandMessage {
874                awakeable_id: id,
875                result: Some(match value {
876                    NonEmptyValue::Success(s) => {
877                        complete_awakeable_command_message::Result::Value(s.into())
878                    }
879                    NonEmptyValue::Failure(f) => {
880                        complete_awakeable_command_message::Result::Failure(f.into())
881                    }
882                }),
883                ..Default::default()
884            },
885        ))
886    }
887
888    #[instrument(
889        level = "trace",
890        skip(self),
891        fields(
892            restate.invocation.id = self.debug_invocation_id(),
893            restate.protocol.state = self.debug_state(),
894            restate.journal.command_index = self.context.journal.command_index(),
895            restate.protocol.version = %self.version
896        ),
897        ret
898    )]
899    fn create_signal_handle(&mut self, signal_name: String) -> VMResult<NotificationHandle> {
900        invocation_debug_logs!(self, "Executing 'Create named signal'");
901
902        self.do_transition(CreateSignalHandle(
903            "SysCreateNamedSignal",
904            NotificationId::SignalName(signal_name),
905        ))
906    }
907
908    #[instrument(
909        level = "trace",
910        skip(self, value),
911        fields(
912            restate.invocation.id = self.debug_invocation_id(),
913            restate.protocol.state = self.debug_state(),
914            restate.journal.command_index = self.context.journal.command_index(),
915            restate.protocol.version = %self.version
916        ),
917        ret
918    )]
919    fn sys_complete_signal(
920        &mut self,
921        target_invocation_id: String,
922        signal_name: String,
923        value: NonEmptyValue,
924    ) -> VMResult<()> {
925        invocation_debug_logs!(self, "Executing 'Complete named signal {signal_name}'");
926        self.do_transition(SysNonCompletableEntry(
927            "SysCompleteAwakeable",
928            SendSignalCommandMessage {
929                target_invocation_id,
930                signal_id: Some(send_signal_command_message::SignalId::Name(signal_name)),
931                result: Some(match value {
932                    NonEmptyValue::Success(s) => {
933                        send_signal_command_message::Result::Value(s.into())
934                    }
935                    NonEmptyValue::Failure(f) => {
936                        send_signal_command_message::Result::Failure(f.into())
937                    }
938                }),
939                ..Default::default()
940            },
941        ))
942    }
943
944    #[instrument(
945        level = "trace",
946        skip(self),
947        fields(
948            restate.invocation.id = self.debug_invocation_id(),
949            restate.protocol.state = self.debug_state(),
950            restate.journal.command_index = self.context.journal.command_index(),
951            restate.protocol.version = %self.version
952        ),
953        ret
954    )]
955    fn sys_get_promise(&mut self, key: String) -> VMResult<NotificationHandle> {
956        invocation_debug_logs!(self, "Executing 'Await promise {key}'");
957
958        let result_completion_id = self.context.journal.next_completion_notification_id();
959        self.do_transition(SysSimpleCompletableEntry(
960            "SysGetPromise",
961            GetPromiseCommandMessage {
962                key,
963                result_completion_id,
964                ..Default::default()
965            },
966            result_completion_id,
967        ))
968    }
969
970    #[instrument(
971        level = "trace",
972        skip(self),
973        fields(
974            restate.invocation.id = self.debug_invocation_id(),
975            restate.protocol.state = self.debug_state(),
976            restate.journal.command_index = self.context.journal.command_index(),
977            restate.protocol.version = %self.version
978        ),
979        ret
980    )]
981    fn sys_peek_promise(&mut self, key: String) -> VMResult<NotificationHandle> {
982        invocation_debug_logs!(self, "Executing 'Peek promise {key}'");
983
984        let result_completion_id = self.context.journal.next_completion_notification_id();
985        self.do_transition(SysSimpleCompletableEntry(
986            "SysPeekPromise",
987            PeekPromiseCommandMessage {
988                key,
989                result_completion_id,
990                ..Default::default()
991            },
992            result_completion_id,
993        ))
994    }
995
996    #[instrument(
997        level = "trace",
998        skip(self, value),
999        fields(
1000            restate.invocation.id = self.debug_invocation_id(),
1001            restate.protocol.state = self.debug_state(),
1002            restate.journal.command_index = self.context.journal.command_index(),
1003            restate.protocol.version = %self.version
1004        ),
1005        ret
1006    )]
1007    fn sys_complete_promise(
1008        &mut self,
1009        key: String,
1010        value: NonEmptyValue,
1011    ) -> VMResult<NotificationHandle> {
1012        invocation_debug_logs!(self, "Executing 'Complete promise {key}'");
1013
1014        let result_completion_id = self.context.journal.next_completion_notification_id();
1015        self.do_transition(SysSimpleCompletableEntry(
1016            "SysCompletePromise",
1017            CompletePromiseCommandMessage {
1018                key,
1019                completion: Some(match value {
1020                    NonEmptyValue::Success(s) => {
1021                        complete_promise_command_message::Completion::CompletionValue(s.into())
1022                    }
1023                    NonEmptyValue::Failure(f) => {
1024                        complete_promise_command_message::Completion::CompletionFailure(f.into())
1025                    }
1026                }),
1027                result_completion_id,
1028                ..Default::default()
1029            },
1030            result_completion_id,
1031        ))
1032    }
1033
1034    #[instrument(
1035        level = "trace",
1036        skip(self),
1037        fields(
1038            restate.invocation.id = self.debug_invocation_id(),
1039            restate.protocol.state = self.debug_state(),
1040            restate.journal.command_index = self.context.journal.command_index(),
1041            restate.protocol.version = %self.version
1042        ),
1043        ret
1044    )]
1045    fn sys_run(&mut self, name: String) -> VMResult<NotificationHandle> {
1046        match self.do_transition(SysRun(name.clone())) {
1047            Ok(handle) => {
1048                if enabled!(Level::DEBUG) {
1049                    // Store the name, we need it later when completing
1050                    self.sys_run_names.insert(handle, name);
1051                }
1052                Ok(handle)
1053            }
1054            Err(e) => Err(e),
1055        }
1056    }
1057
1058    #[instrument(
1059        level = "trace",
1060        skip(self, value, retry_policy),
1061        fields(
1062            restate.invocation.id = self.debug_invocation_id(),
1063            restate.protocol.state = self.debug_state(),
1064            restate.journal.command_index = self.context.journal.command_index(),
1065            restate.protocol.version = %self.version
1066        ),
1067        ret
1068    )]
1069    fn propose_run_completion(
1070        &mut self,
1071        notification_handle: NotificationHandle,
1072        value: RunExitResult,
1073        retry_policy: RetryPolicy,
1074    ) -> VMResult<()> {
1075        if enabled!(Level::DEBUG) {
1076            let name: &str = self
1077                .sys_run_names
1078                .get(&notification_handle)
1079                .map(String::as_str)
1080                .unwrap_or_default();
1081            match &value {
1082                RunExitResult::Success(_) => {
1083                    invocation_debug_logs!(self, "Journaling run '{name}' success result");
1084                }
1085                RunExitResult::TerminalFailure(TerminalFailure { code, .. }) => {
1086                    invocation_debug_logs!(
1087                        self,
1088                        "Journaling run '{name}' terminal failure {code} result"
1089                    );
1090                }
1091                RunExitResult::RetryableFailure { .. } => {
1092                    invocation_debug_logs!(self, "Propagating run '{name}' retryable failure");
1093                }
1094            }
1095        }
1096
1097        self.do_transition(ProposeRunCompletion(
1098            notification_handle,
1099            value,
1100            retry_policy,
1101        ))
1102    }
1103
1104    #[instrument(
1105        level = "trace",
1106        skip(self),
1107        fields(
1108            restate.invocation.id = self.debug_invocation_id(),
1109            restate.protocol.state = self.debug_state(),
1110            restate.journal.command_index = self.context.journal.command_index(),
1111            restate.protocol.version = %self.version
1112        ),
1113        ret
1114    )]
1115    fn sys_cancel_invocation(&mut self, target_invocation_id: String) -> VMResult<()> {
1116        invocation_debug_logs!(
1117            self,
1118            "Executing 'Cancel invocation' of {target_invocation_id}"
1119        );
1120        self.verify_feature_support("cancel invocation", Version::V3)?;
1121        self.do_transition(SysNonCompletableEntry(
1122            "SysCancelInvocation",
1123            SendSignalCommandMessage {
1124                target_invocation_id,
1125                signal_id: Some(send_signal_command_message::SignalId::Idx(CANCEL_SIGNAL_ID)),
1126                result: Some(send_signal_command_message::Result::Void(Default::default())),
1127                ..Default::default()
1128            },
1129        ))
1130    }
1131
1132    #[instrument(
1133        level = "trace",
1134        skip(self),
1135        fields(
1136            restate.invocation.id = self.debug_invocation_id(),
1137            restate.protocol.state = self.debug_state(),
1138            restate.journal.command_index = self.context.journal.command_index(),
1139            restate.protocol.version = %self.version
1140        ),
1141        ret
1142    )]
1143    fn sys_attach_invocation(
1144        &mut self,
1145        target: AttachInvocationTarget,
1146    ) -> VMResult<NotificationHandle> {
1147        invocation_debug_logs!(self, "Executing 'Attach invocation'");
1148        self.verify_feature_support("attach invocation", Version::V3)?;
1149
1150        let result_completion_id = self.context.journal.next_completion_notification_id();
1151        self.do_transition(SysSimpleCompletableEntry(
1152            "SysAttachInvocation",
1153            AttachInvocationCommandMessage {
1154                target: Some(match target {
1155                    AttachInvocationTarget::InvocationId(id) => {
1156                        attach_invocation_command_message::Target::InvocationId(id)
1157                    }
1158                    AttachInvocationTarget::WorkflowId { name, key } => {
1159                        attach_invocation_command_message::Target::WorkflowTarget(WorkflowTarget {
1160                            workflow_name: name,
1161                            workflow_key: key,
1162                        })
1163                    }
1164                    AttachInvocationTarget::IdempotencyId {
1165                        service_name,
1166                        service_key,
1167                        handler_name,
1168                        idempotency_key,
1169                    } => attach_invocation_command_message::Target::IdempotentRequestTarget(
1170                        IdempotentRequestTarget {
1171                            service_name,
1172                            service_key,
1173                            handler_name,
1174                            idempotency_key,
1175                        },
1176                    ),
1177                }),
1178                result_completion_id,
1179                ..Default::default()
1180            },
1181            result_completion_id,
1182        ))
1183    }
1184
1185    #[instrument(
1186        level = "trace",
1187        skip(self),
1188        fields(
1189            restate.invocation.id = self.debug_invocation_id(),
1190            restate.protocol.state = self.debug_state(),
1191            restate.journal.command_index = self.context.journal.command_index(),
1192            restate.protocol.version = %self.version
1193        ),
1194        ret
1195    )]
1196    fn sys_get_invocation_output(
1197        &mut self,
1198        target: AttachInvocationTarget,
1199    ) -> VMResult<NotificationHandle> {
1200        invocation_debug_logs!(self, "Executing 'Get invocation output'");
1201        self.verify_feature_support("get invocation output", Version::V3)?;
1202
1203        let result_completion_id = self.context.journal.next_completion_notification_id();
1204        self.do_transition(SysSimpleCompletableEntry(
1205            "SysGetInvocationOutput",
1206            GetInvocationOutputCommandMessage {
1207                target: Some(match target {
1208                    AttachInvocationTarget::InvocationId(id) => {
1209                        get_invocation_output_command_message::Target::InvocationId(id)
1210                    }
1211                    AttachInvocationTarget::WorkflowId { name, key } => {
1212                        get_invocation_output_command_message::Target::WorkflowTarget(
1213                            WorkflowTarget {
1214                                workflow_name: name,
1215                                workflow_key: key,
1216                            },
1217                        )
1218                    }
1219                    AttachInvocationTarget::IdempotencyId {
1220                        service_name,
1221                        service_key,
1222                        handler_name,
1223                        idempotency_key,
1224                    } => get_invocation_output_command_message::Target::IdempotentRequestTarget(
1225                        IdempotentRequestTarget {
1226                            service_name,
1227                            service_key,
1228                            handler_name,
1229                            idempotency_key,
1230                        },
1231                    ),
1232                }),
1233                result_completion_id,
1234                ..Default::default()
1235            },
1236            result_completion_id,
1237        ))
1238    }
1239
1240    #[instrument(
1241        level = "trace",
1242        skip(self, value),
1243        fields(
1244            restate.invocation.id = self.debug_invocation_id(),
1245            restate.protocol.state = self.debug_state(),
1246            restate.journal.command_index = self.context.journal.command_index(),
1247            restate.protocol.version = %self.version
1248        ),
1249        ret
1250    )]
1251    fn sys_write_output(&mut self, value: NonEmptyValue) -> Result<(), Error> {
1252        match &value {
1253            NonEmptyValue::Success(_) => {
1254                invocation_debug_logs!(self, "Writing invocation result success value");
1255            }
1256            NonEmptyValue::Failure(_) => {
1257                invocation_debug_logs!(self, "Writing invocation result failure value");
1258            }
1259        }
1260        self.do_transition(SysNonCompletableEntry(
1261            "SysWriteOutput",
1262            OutputCommandMessage {
1263                result: Some(match value {
1264                    NonEmptyValue::Success(b) => output_command_message::Result::Value(b.into()),
1265                    NonEmptyValue::Failure(f) => output_command_message::Result::Failure(f.into()),
1266                }),
1267                ..OutputCommandMessage::default()
1268            },
1269        ))
1270    }
1271
1272    #[instrument(
1273        level = "trace",
1274        skip(self),
1275        fields(
1276            restate.invocation.id = self.debug_invocation_id(),
1277            restate.protocol.state = self.debug_state(),
1278            restate.journal.command_index = self.context.journal.command_index(),
1279            restate.protocol.version = %self.version
1280        ),
1281        ret
1282    )]
1283    fn sys_end(&mut self) -> Result<(), Error> {
1284        invocation_debug_logs!(self, "End of the invocation");
1285        self.do_transition(SysEnd)
1286    }
1287
1288    fn is_waiting_preflight(&self) -> bool {
1289        matches!(
1290            &self.last_transition,
1291            Ok(State::WaitingStart) | Ok(State::WaitingReplayEntries { .. })
1292        )
1293    }
1294
1295    fn is_replaying(&self) -> bool {
1296        matches!(&self.last_transition, Ok(State::Replaying { .. }))
1297    }
1298
1299    fn is_processing(&self) -> bool {
1300        matches!(&self.last_transition, Ok(State::Processing { .. }))
1301    }
1302
1303    fn last_command_index(&self) -> i64 {
1304        self.context.journal.command_index()
1305    }
1306}
1307
1308const INDIFFERENT_PAD: GeneralPurposeConfig = GeneralPurposeConfig::new()
1309    .with_decode_padding_mode(DecodePaddingMode::Indifferent)
1310    .with_encode_padding(false);
1311const URL_SAFE: GeneralPurpose = GeneralPurpose::new(&alphabet::URL_SAFE, INDIFFERENT_PAD);
1312
1313const AWAKEABLE_PREFIX: &str = "sign_1";
1314
1315fn awakeable_id_str(id: &[u8], completion_index: u32) -> String {
1316    let mut input_buf = BytesMut::with_capacity(id.len() + size_of::<u32>());
1317    input_buf.put_slice(id);
1318    input_buf.put_u32(completion_index);
1319    format!("{AWAKEABLE_PREFIX}{}", URL_SAFE.encode(input_buf.freeze()))
1320}