restate_sdk_shared_core/vm/
mod.rs

1use crate::headers::HeaderMap;
2use crate::service_protocol::messages::{
3    attach_invocation_command_message, complete_awakeable_command_message,
4    complete_promise_command_message, get_invocation_output_command_message,
5    output_command_message, send_signal_command_message, AttachInvocationCommandMessage,
6    CallCommandMessage, ClearAllStateCommandMessage, ClearStateCommandMessage,
7    CompleteAwakeableCommandMessage, CompletePromiseCommandMessage,
8    GetInvocationOutputCommandMessage, GetPromiseCommandMessage, IdempotentRequestTarget,
9    OneWayCallCommandMessage, OutputCommandMessage, PeekPromiseCommandMessage,
10    SendSignalCommandMessage, SetStateCommandMessage, SleepCommandMessage, WorkflowTarget,
11};
12use crate::service_protocol::{Decoder, NotificationId, RawMessage, Version, CANCEL_SIGNAL_ID};
13use crate::vm::errors::{
14    UnexpectedStateError, UnsupportedFeatureForNegotiatedVersion, EMPTY_IDEMPOTENCY_KEY,
15};
16use crate::vm::transitions::*;
17use crate::{
18    AttachInvocationTarget, CallHandle, DoProgressResponse, Error, Header,
19    ImplicitCancellationOption, Input, NonEmptyValue, NotificationHandle, ResponseHead,
20    RetryPolicy, RunExitResult, SendHandle, SuspendedOrVMError, TakeOutputResult, Target,
21    TerminalFailure, VMOptions, VMResult, Value, CANCEL_NOTIFICATION_HANDLE,
22};
23use base64::engine::{DecodePaddingMode, GeneralPurpose, GeneralPurposeConfig};
24use base64::{alphabet, Engine};
25use bytes::{Buf, BufMut, Bytes, BytesMut};
26use context::{AsyncResultsState, Context, Output, RunState};
27use std::borrow::Cow;
28use std::collections::{HashMap, VecDeque};
29use std::mem::size_of;
30use std::time::Duration;
31use std::{fmt, mem};
32use strum::IntoStaticStr;
33use tracing::{debug, enabled, instrument, Level};
34
// Submodules of the VM. `context` provides the per-invocation `Context` plus the
// state payloads (`AsyncResultsState`, `RunState`) embedded in `State`; `errors`
// holds the VM error definitions (crate-visible); `transitions` presumably defines
// the transition types (`HitError`, `NewMessage`, `DoProgress`, ...) that are
// driven through `do_transition` — confirm.
mod context;
pub(crate) mod errors;
mod transitions;

/// Name of the HTTP header used twice: to negotiate the service protocol version
/// from the request (`VM::new`) and to advertise that version in the response
/// head (`get_response_head`).
const CONTENT_TYPE: &str = "content-type";
40
/// State of the invocation state machine.
///
/// `CoreVM` keeps the current state in `last_transition`. The variant name
/// (derived through `IntoStaticStr`) is what appears in the trace fields and in
/// unexpected-state errors.
#[derive(Debug, IntoStaticStr)]
pub(crate) enum State {
    /// Initial state set by `CoreVM::new`; `is_ready_to_execute` reports `false`.
    WaitingStart,
    /// `is_ready_to_execute` reports `false`. Presumably buffers the journal
    /// entries received for replay until all of them have arrived — TODO confirm
    /// against `transitions`.
    WaitingReplayEntries {
        received_entries: u32,
        commands: VecDeque<RawMessage>,
        async_results: AsyncResultsState,
    },
    /// Ready to execute; presumably still consuming the queued `commands` from the
    /// journal being replayed.
    Replaying {
        commands: VecDeque<RawMessage>,
        run_state: RunState,
        // Queried by `CoreVM::_is_completed` to check notification completion.
        async_results: AsyncResultsState,
    },
    /// Ready to execute; `invocation_debug_logs!` only logs in this state
    /// (through `is_processing`) — NOTE(review): assumes `is_processing` checks
    /// for this variant.
    Processing {
        processing_first_entry: bool,
        run_state: RunState,
        // Queried by `CoreVM::_is_completed` to check notification completion.
        async_results: AsyncResultsState,
    },
    /// `is_ready_to_execute` treats this state as unexpected.
    Ended,
    /// `is_ready_to_execute` treats this state as unexpected.
    Suspended,
}
62
63impl State {
64    fn as_unexpected_state(&self, event: &'static str) -> Error {
65        UnexpectedStateError::new(self.into(), event).into()
66    }
67}
68
/// Invocation-id notification of a child call/send, tracked so the child can be
/// cancelled when implicit cancellation is enabled and the cancel signal arrives.
struct TrackedInvocationId {
    /// Handle of the notification that carries the child's invocation id.
    handle: NotificationHandle,
    /// `None` until the invocation id has been taken or copied from `handle`.
    invocation_id: Option<String>,
}
73
74impl TrackedInvocationId {
75    fn is_resolved(&self) -> bool {
76        self.invocation_id.is_some()
77    }
78}
79
/// Core state machine that implements the service protocol for one invocation.
///
/// It is built by `VM::new` from the request headers. It decodes input bytes
/// pushed through `notify_input`, drives `State` transitions, and buffers output
/// bytes for `take_output`.
pub struct CoreVM {
    /// Protocol version negotiated from the `content-type` request header.
    version: Version,
    options: VMOptions,

    // Input decoder
    decoder: Decoder,

    // State machine
    context: Context,
    /// Outcome of the most recent transition: the current `State`, or the error
    /// that failed the VM (reported as "Failed" by `debug_state`).
    last_transition: Result<State, Error>,

    // Implicit cancellation tracking
    /// Child invocation-id handles to cancel when the cancel signal arrives.
    /// Entries are appended in handle order, so `take_notification` can binary
    /// search this list.
    tracked_invocation_ids: Vec<TrackedInvocationId>,

    // Run names, useful for debugging
    sys_run_names: HashMap<NotificationHandle, String>,
}
97
98impl CoreVM {
99    // Returns empty string if the invocation id is not present
100    fn debug_invocation_id(&self) -> &str {
101        if let Some(start_info) = self.context.start_info() {
102            &start_info.debug_id
103        } else {
104            ""
105        }
106    }
107
108    fn debug_state(&self) -> &'static str {
109        match &self.last_transition {
110            Ok(s) => s.into(),
111            Err(_) => "Failed",
112        }
113    }
114
115    fn verify_feature_support(
116        &mut self,
117        feature: &'static str,
118        minimum_required_protocol: Version,
119    ) -> VMResult<()> {
120        if self.version < minimum_required_protocol {
121            return self.do_transition(HitError {
122                error: UnsupportedFeatureForNegotiatedVersion::new(
123                    feature,
124                    self.version,
125                    minimum_required_protocol,
126                )
127                .into(),
128                next_retry_delay: None,
129            });
130        }
131        Ok(())
132    }
133
134    fn _is_completed(&self, handle: NotificationHandle) -> bool {
135        match &self.last_transition {
136            Ok(State::Replaying { async_results, .. })
137            | Ok(State::Processing { async_results, .. }) => {
138                async_results.is_handle_completed(handle)
139            }
140            _ => false,
141        }
142    }
143
144    fn _do_progress(
145        &mut self,
146        any_handle: Vec<NotificationHandle>,
147    ) -> Result<DoProgressResponse, SuspendedOrVMError> {
148        match self.do_transition(DoProgress(any_handle)) {
149            Ok(Ok(do_progress_response)) => Ok(do_progress_response),
150            Ok(Err(suspended)) => Err(SuspendedOrVMError::Suspended(suspended)),
151            Err(e) => Err(SuspendedOrVMError::VM(e)),
152        }
153    }
154
155    fn is_implicit_cancellation_enabled(&self) -> bool {
156        matches!(
157            self.options.implicit_cancellation,
158            ImplicitCancellationOption::Enabled { .. }
159        )
160    }
161}
162
163impl fmt::Debug for CoreVM {
164    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
165        let mut s = f.debug_struct("CoreVM");
166        s.field("version", &self.version);
167
168        if let Some(start_info) = self.context.start_info() {
169            s.field("invocation_id", &start_info.debug_id);
170        }
171
172        match &self.last_transition {
173            Ok(state) => s.field("last_transition", &<&'static str>::from(state)),
174            Err(_) => s.field("last_transition", &"Errored"),
175        };
176
177        s.field("command_index", &self.context.journal.command_index())
178            .field(
179                "notification_index",
180                &self.context.journal.notification_index(),
181            )
182            .finish()
183    }
184}
185
// --- Bound checks
// Compile-time assertion that `CoreVM: Send`. The constant below only type-checks
// if `CoreVM` satisfies the `T: Send` bound of `is_send`, so adding a non-`Send`
// field breaks the build instead of failing at runtime.
// NOTE(review): `#[allow(unused)]` presumably silences dead-code lints on the
// helper, which is only ever evaluated in const context — confirm.
#[allow(unused)]
const fn is_send<T: Send>() {}
const _: () = is_send::<CoreVM>();
190
// Macro used for informative debug logs.
//
// Emits `tracing::debug!` with the given format arguments, but only when
// `$this.is_processing()` returns `true`. Presumably this keeps the messages from
// being repeated while the journal is replayed — confirm against `is_processing`.
//
// The condition is not wrapped in parentheses, so expansion sites do not trigger
// the `unused_parens` lint.
macro_rules! invocation_debug_logs {
    ($this:expr, $($arg:tt)*) => {
        if $this.is_processing() {
            tracing::debug!($($arg)*)
        }
    };
}
199
200impl super::VM for CoreVM {
201    #[instrument(level = "trace", skip(request_headers), ret)]
202    fn new(request_headers: impl HeaderMap, options: VMOptions) -> Result<Self, Error> {
203        let version = request_headers
204            .extract(CONTENT_TYPE)
205            .map_err(|e| {
206                Error::new(
207                    errors::codes::BAD_REQUEST,
208                    format!("cannot read '{CONTENT_TYPE}' header: {e:?}"),
209                )
210            })?
211            .ok_or(errors::MISSING_CONTENT_TYPE)?
212            .parse::<Version>()?;
213
214        if version < Version::minimum_supported_version()
215            || version > Version::maximum_supported_version()
216        {
217            return Err(Error::new(
218                errors::codes::UNSUPPORTED_MEDIA_TYPE,
219                format!(
220                    "Unsupported protocol version {:?}, not within [{:?} to {:?}]. \
221                    You might need to rediscover the service, check https://docs.restate.dev/references/errors/#RT0015",
222                    version,
223                    Version::minimum_supported_version(),
224                    Version::maximum_supported_version()
225                ),
226            ));
227        }
228
229        Ok(Self {
230            version,
231            options,
232            decoder: Decoder::new(version),
233            context: Context {
234                input_is_closed: false,
235                output: Output::new(version),
236                start_info: None,
237                journal: Default::default(),
238                eager_state: Default::default(),
239                next_retry_delay: None,
240            },
241            last_transition: Ok(State::WaitingStart),
242            tracked_invocation_ids: vec![],
243            sys_run_names: HashMap::with_capacity(0),
244        })
245    }
246
247    #[instrument(
248        level = "trace",
249        skip(self),
250        fields(
251            restate.invocation.id = self.debug_invocation_id(),
252            restate.protocol.state = self.debug_state(),
253            restate.journal.command_index = self.context.journal.command_index(),
254            restate.protocol.version = %self.version
255        ),
256        ret
257    )]
258    fn get_response_head(&self) -> ResponseHead {
259        ResponseHead {
260            status_code: 200,
261            headers: vec![Header {
262                key: Cow::Borrowed(CONTENT_TYPE),
263                value: Cow::Borrowed(self.version.content_type()),
264            }],
265            version: self.version,
266        }
267    }
268
269    #[instrument(
270        level = "trace",
271        skip(self),
272        fields(
273            restate.invocation.id = self.debug_invocation_id(),
274            restate.protocol.state = self.debug_state(),
275            restate.journal.command_index = self.context.journal.command_index(),
276            restate.protocol.version = %self.version
277        ),
278        ret
279    )]
280    fn notify_input(&mut self, buffer: Bytes) {
281        self.decoder.push(buffer);
282        loop {
283            match self.decoder.consume_next() {
284                Ok(Some(msg)) => {
285                    if self.do_transition(NewMessage(msg)).is_err() {
286                        return;
287                    }
288                }
289                Ok(None) => {
290                    return;
291                }
292                Err(e) => {
293                    if self
294                        .do_transition(HitError {
295                            error: e.into(),
296                            next_retry_delay: None,
297                        })
298                        .is_err()
299                    {
300                        return;
301                    }
302                }
303            }
304        }
305    }
306
307    #[instrument(
308        level = "trace",
309        skip(self),
310        fields(
311            restate.invocation.id = self.debug_invocation_id(),
312            restate.protocol.state = self.debug_state(),
313            restate.journal.command_index = self.context.journal.command_index(),
314            restate.protocol.version = %self.version
315        ),
316        ret
317    )]
318    fn notify_input_closed(&mut self) {
319        self.context.input_is_closed = true;
320        let _ = self.do_transition(NotifyInputClosed);
321    }
322
323    #[instrument(
324        level = "trace",
325        skip(self),
326        fields(
327            restate.invocation.id = self.debug_invocation_id(),
328            restate.protocol.state = self.debug_state(),
329            restate.journal.command_index = self.context.journal.command_index(),
330            restate.protocol.version = %self.version
331        ),
332        ret
333    )]
334    fn notify_error(&mut self, error: Error, next_retry_delay: Option<Duration>) {
335        let _ = self.do_transition(HitError {
336            error,
337            next_retry_delay,
338        });
339    }
340
341    #[instrument(
342        level = "trace",
343        skip(self),
344        fields(
345            restate.invocation.id = self.debug_invocation_id(),
346            restate.protocol.state = self.debug_state(),
347            restate.journal.command_index = self.context.journal.command_index(),
348            restate.protocol.version = %self.version
349        ),
350        ret
351    )]
352    fn take_output(&mut self) -> TakeOutputResult {
353        if self.context.output.buffer.has_remaining() {
354            TakeOutputResult::Buffer(
355                self.context
356                    .output
357                    .buffer
358                    .copy_to_bytes(self.context.output.buffer.remaining()),
359            )
360        } else if !self.context.output.is_closed() {
361            TakeOutputResult::Buffer(Bytes::default())
362        } else {
363            TakeOutputResult::EOF
364        }
365    }
366
367    #[instrument(
368        level = "trace",
369        skip(self),
370        fields(
371            restate.invocation.id = self.debug_invocation_id(),
372            restate.protocol.state = self.debug_state(),
373            restate.journal.command_index = self.context.journal.command_index(),
374            restate.protocol.version = %self.version
375        ),
376        ret
377    )]
378    fn is_ready_to_execute(&self) -> Result<bool, Error> {
379        match &self.last_transition {
380            Ok(State::WaitingStart) | Ok(State::WaitingReplayEntries { .. }) => Ok(false),
381            Ok(State::Processing { .. }) | Ok(State::Replaying { .. }) => Ok(true),
382            Ok(s) => Err(s.as_unexpected_state("IsReadyToExecute")),
383            Err(e) => Err(e.clone()),
384        }
385    }
386
387    #[instrument(
388        level = "trace",
389        skip(self),
390        fields(
391            restate.invocation.id = self.debug_invocation_id(),
392            restate.protocol.state = self.debug_state(),
393            restate.journal.command_index = self.context.journal.command_index(),
394            restate.protocol.version = %self.version
395        ),
396        ret
397    )]
398    fn is_completed(&self, handle: NotificationHandle) -> bool {
399        self._is_completed(handle)
400    }
401
402    #[instrument(
403        level = "trace",
404        skip(self),
405        fields(
406            restate.invocation.id = self.debug_invocation_id(),
407            restate.protocol.state = self.debug_state(),
408            restate.journal.command_index = self.context.journal.command_index(),
409            restate.protocol.version = %self.version
410        ),
411        ret
412    )]
413    fn do_progress(
414        &mut self,
415        mut any_handle: Vec<NotificationHandle>,
416    ) -> Result<DoProgressResponse, SuspendedOrVMError> {
417        if self.is_implicit_cancellation_enabled() {
418            // We want the runtime to wake us up in case cancel notification comes in.
419            any_handle.insert(0, CANCEL_NOTIFICATION_HANDLE);
420
421            match self._do_progress(any_handle) {
422                Ok(DoProgressResponse::AnyCompleted) => {
423                    // If it's cancel signal, then let's go on with the cancellation logic
424                    if self._is_completed(CANCEL_NOTIFICATION_HANDLE) {
425                        // Loop once over the tracked invocation ids to resolve the unresolved ones
426                        for i in 0..self.tracked_invocation_ids.len() {
427                            if self.tracked_invocation_ids[i].is_resolved() {
428                                continue;
429                            }
430
431                            let handle = self.tracked_invocation_ids[i].handle;
432
433                            // Try to resolve it
434                            match self._do_progress(vec![handle]) {
435                                Ok(DoProgressResponse::AnyCompleted) => {
436                                    let invocation_id = match self.do_transition(CopyNotification(handle)) {
437                                        Ok(Ok(Some(Value::InvocationId(invocation_id)))) => Ok(invocation_id),
438                                        _ => panic!("Unexpected variant! If the id handle is completed, it must be an invocation id handle!")
439                                    }?;
440
441                                    // This handle is resolved
442                                    self.tracked_invocation_ids[i].invocation_id =
443                                        Some(invocation_id);
444                                }
445                                res => return res,
446                            }
447                        }
448
449                        // Now we got all the invocation IDs, let's cancel!
450                        for tracked_invocation_id in mem::take(&mut self.tracked_invocation_ids) {
451                            self.sys_cancel_invocation(
452                                tracked_invocation_id
453                                    .invocation_id
454                                    .expect("We resolved before all the invocation ids"),
455                            )
456                            .map_err(SuspendedOrVMError::VM)?;
457                        }
458
459                        // Flip the cancellation
460                        let _ = self.take_notification(CANCEL_NOTIFICATION_HANDLE);
461
462                        // Done
463                        Ok(DoProgressResponse::CancelSignalReceived)
464                    } else {
465                        Ok(DoProgressResponse::AnyCompleted)
466                    }
467                }
468                res => res,
469            }
470        } else {
471            self._do_progress(any_handle)
472        }
473    }
474
475    #[instrument(
476        level = "trace",
477        skip(self),
478        fields(
479            restate.invocation.id = self.debug_invocation_id(),
480            restate.protocol.state = self.debug_state(),
481            restate.journal.command_index = self.context.journal.command_index(),
482            restate.protocol.version = %self.version
483        ),
484        ret
485    )]
486    fn take_notification(
487        &mut self,
488        handle: NotificationHandle,
489    ) -> Result<Option<Value>, SuspendedOrVMError> {
490        match self.do_transition(TakeNotification(handle)) {
491            Ok(Ok(Some(value))) => {
492                if self.is_implicit_cancellation_enabled() {
493                    // Let's check if that's one of the tracked invocation ids
494                    // We can do binary search here because we assume tracked_invocation_ids is ordered, as handles are incremental numbers
495                    if let Ok(found) = self
496                        .tracked_invocation_ids
497                        .binary_search_by(|tracked| tracked.handle.cmp(&handle))
498                    {
499                        let Value::InvocationId(invocation_id) = &value else {
500                            panic!("Expecting an invocation id here, but got {value:?}");
501                        };
502                        // Keep track of this invocation id
503                        self.tracked_invocation_ids
504                            .get_mut(found)
505                            .unwrap()
506                            .invocation_id = Some(invocation_id.clone());
507                    }
508                }
509
510                Ok(Some(value))
511            }
512            Ok(Ok(None)) => Ok(None),
513            Ok(Err(suspended)) => Err(SuspendedOrVMError::Suspended(suspended)),
514            Err(e) => Err(SuspendedOrVMError::VM(e)),
515        }
516    }
517
518    #[instrument(
519        level = "trace",
520        skip(self),
521        fields(
522            restate.invocation.id = self.debug_invocation_id(),
523            restate.protocol.state = self.debug_state(),
524            restate.journal.command_index = self.context.journal.command_index(),
525            restate.protocol.version = %self.version
526        ),
527        ret
528    )]
529    fn sys_input(&mut self) -> Result<Input, Error> {
530        self.do_transition(SysInput)
531    }
532
533    #[instrument(
534        level = "trace",
535        skip(self),
536        fields(
537            restate.invocation.id = self.debug_invocation_id(),
538            restate.protocol.state = self.debug_state(),
539            restate.journal.command_index = self.context.journal.command_index(),
540            restate.protocol.version = %self.version
541        ),
542        ret
543    )]
544    fn sys_state_get(&mut self, key: String) -> Result<NotificationHandle, Error> {
545        invocation_debug_logs!(self, "Executing 'Get state {key}'");
546        self.do_transition(SysStateGet(key))
547    }
548
549    #[instrument(
550        level = "trace",
551        skip(self),
552        fields(
553            restate.invocation.id = self.debug_invocation_id(),
554            restate.protocol.state = self.debug_state(),
555            restate.journal.command_index = self.context.journal.command_index(),
556            restate.protocol.version = %self.version
557        ),
558        ret
559    )]
560    fn sys_state_get_keys(&mut self) -> VMResult<NotificationHandle> {
561        invocation_debug_logs!(self, "Executing 'Get state keys'");
562        self.do_transition(SysStateGetKeys)
563    }
564
565    #[instrument(
566        level = "trace",
567        skip(self, value),
568        fields(
569            restate.invocation.id = self.debug_invocation_id(),
570            restate.protocol.state = self.debug_state(),
571            restate.journal.command_index = self.context.journal.command_index(),
572            restate.protocol.version = %self.version
573        ),
574        ret
575    )]
576    fn sys_state_set(&mut self, key: String, value: Bytes) -> Result<(), Error> {
577        invocation_debug_logs!(self, "Executing 'Set state {key}'");
578        self.context.eager_state.set(key.clone(), value.clone());
579        self.do_transition(SysNonCompletableEntry(
580            "SysStateSet",
581            SetStateCommandMessage {
582                key: Bytes::from(key.into_bytes()),
583                value: Some(value.into()),
584                ..SetStateCommandMessage::default()
585            },
586        ))
587    }
588
589    #[instrument(
590        level = "trace",
591        skip(self),
592        fields(
593            restate.invocation.id = self.debug_invocation_id(),
594            restate.protocol.state = self.debug_state(),
595            restate.journal.command_index = self.context.journal.command_index(),
596            restate.protocol.version = %self.version
597        ),
598        ret
599    )]
600    fn sys_state_clear(&mut self, key: String) -> Result<(), Error> {
601        invocation_debug_logs!(self, "Executing 'Clear state {key}'");
602        self.context.eager_state.clear(key.clone());
603        self.do_transition(SysNonCompletableEntry(
604            "SysStateClear",
605            ClearStateCommandMessage {
606                key: Bytes::from(key.into_bytes()),
607                ..ClearStateCommandMessage::default()
608            },
609        ))
610    }
611
612    #[instrument(
613        level = "trace",
614        skip(self),
615        fields(
616            restate.invocation.id = self.debug_invocation_id(),
617            restate.protocol.state = self.debug_state(),
618            restate.journal.command_index = self.context.journal.command_index(),
619            restate.protocol.version = %self.version
620        ),
621        ret
622    )]
623    fn sys_state_clear_all(&mut self) -> Result<(), Error> {
624        invocation_debug_logs!(self, "Executing 'Clear all state'");
625        self.context.eager_state.clear_all();
626        self.do_transition(SysNonCompletableEntry(
627            "SysStateClearAll",
628            ClearAllStateCommandMessage::default(),
629        ))
630    }
631
632    #[instrument(
633        level = "trace",
634        skip(self),
635        fields(
636            restate.invocation.id = self.debug_invocation_id(),
637            restate.protocol.state = self.debug_state(),
638            restate.journal.command_index = self.context.journal.command_index(),
639            restate.protocol.version = %self.version
640        ),
641        ret
642    )]
643    fn sys_sleep(
644        &mut self,
645        name: String,
646        wake_up_time_since_unix_epoch: Duration,
647        now_since_unix_epoch: Option<Duration>,
648    ) -> VMResult<NotificationHandle> {
649        if self.is_processing() {
650            match (&name, now_since_unix_epoch) {
651                (name, Some(now_since_unix_epoch)) if name.is_empty() => {
652                    debug!(
653                        "Executing 'Timer with duration {:?}'",
654                        wake_up_time_since_unix_epoch - now_since_unix_epoch
655                    );
656                }
657                (name, Some(now_since_unix_epoch)) => {
658                    debug!(
659                        "Executing 'Timer {name} with duration {:?}'",
660                        wake_up_time_since_unix_epoch - now_since_unix_epoch
661                    );
662                }
663                (name, None) if name.is_empty() => {
664                    debug!("Executing 'Timer'");
665                }
666                (name, None) => {
667                    debug!("Executing 'Timer named {name}'");
668                }
669            }
670        }
671
672        let completion_id = self.context.journal.next_completion_notification_id();
673
674        self.do_transition(SysSimpleCompletableEntry(
675            "SysSleep",
676            SleepCommandMessage {
677                wake_up_time: u64::try_from(wake_up_time_since_unix_epoch.as_millis())
678                    .expect("millis since Unix epoch should fit in u64"),
679                result_completion_id: completion_id,
680                name,
681            },
682            completion_id,
683        ))
684    }
685
686    #[instrument(
687        level = "trace",
688        skip(self, input),
689        fields(
690            restate.invocation.id = self.debug_invocation_id(),
691            restate.protocol.state = self.debug_state(),
692            restate.journal.command_index = self.context.journal.command_index(),
693            restate.protocol.version = %self.version
694        ),
695        ret
696    )]
697    fn sys_call(&mut self, target: Target, input: Bytes) -> VMResult<CallHandle> {
698        invocation_debug_logs!(
699            self,
700            "Executing 'Call {}/{}'",
701            target.service,
702            target.handler
703        );
704        if let Some(idempotency_key) = &target.idempotency_key {
705            self.verify_feature_support("attach idempotency key to call", Version::V3)?;
706            if idempotency_key.is_empty() {
707                self.do_transition(HitError {
708                    error: EMPTY_IDEMPOTENCY_KEY,
709                    next_retry_delay: None,
710                })?;
711                unreachable!();
712            }
713        }
714
715        let call_invocation_id_completion_id =
716            self.context.journal.next_completion_notification_id();
717        let result_completion_id = self.context.journal.next_completion_notification_id();
718
719        let handles = self.do_transition(SysCompletableEntryWithMultipleCompletions(
720            "SysCall",
721            CallCommandMessage {
722                service_name: target.service,
723                handler_name: target.handler,
724                key: target.key.unwrap_or_default(),
725                idempotency_key: target.idempotency_key,
726                headers: target
727                    .headers
728                    .into_iter()
729                    .map(crate::service_protocol::messages::Header::from)
730                    .collect(),
731                parameter: input,
732                invocation_id_notification_idx: call_invocation_id_completion_id,
733                result_completion_id,
734                ..Default::default()
735            },
736            vec![call_invocation_id_completion_id, result_completion_id],
737        ))?;
738
739        if matches!(
740            self.options.implicit_cancellation,
741            ImplicitCancellationOption::Enabled {
742                cancel_children_calls: true,
743                ..
744            }
745        ) {
746            self.tracked_invocation_ids.push(TrackedInvocationId {
747                handle: handles[0],
748                invocation_id: None,
749            })
750        }
751
752        Ok(CallHandle {
753            invocation_id_notification_handle: handles[0],
754            call_notification_handle: handles[1],
755        })
756    }
757
758    #[instrument(
759        level = "trace",
760        skip(self, input),
761        fields(
762            restate.invocation.id = self.debug_invocation_id(),
763            restate.protocol.state = self.debug_state(),
764            restate.journal.command_index = self.context.journal.command_index(),
765            restate.protocol.version = %self.version
766        ),
767        ret
768    )]
769    fn sys_send(
770        &mut self,
771        target: Target,
772        input: Bytes,
773        delay: Option<Duration>,
774    ) -> VMResult<SendHandle> {
775        invocation_debug_logs!(
776            self,
777            "Executing 'Send to {}/{}'",
778            target.service,
779            target.handler
780        );
781        if let Some(idempotency_key) = &target.idempotency_key {
782            self.verify_feature_support("attach idempotency key to one way call", Version::V3)?;
783            if idempotency_key.is_empty() {
784                self.do_transition(HitError {
785                    error: EMPTY_IDEMPOTENCY_KEY,
786                    next_retry_delay: None,
787                })?;
788                unreachable!();
789            }
790        }
791        let call_invocation_id_completion_id =
792            self.context.journal.next_completion_notification_id();
793        let invocation_id_notification_handle = self.do_transition(SysSimpleCompletableEntry(
794            "SysOneWayCall",
795            OneWayCallCommandMessage {
796                service_name: target.service,
797                handler_name: target.handler,
798                key: target.key.unwrap_or_default(),
799                idempotency_key: target.idempotency_key,
800                headers: target
801                    .headers
802                    .into_iter()
803                    .map(crate::service_protocol::messages::Header::from)
804                    .collect(),
805                parameter: input,
806                invoke_time: delay
807                    .map(|d| {
808                        u64::try_from(d.as_millis())
809                            .expect("millis since Unix epoch should fit in u64")
810                    })
811                    .unwrap_or_default(),
812                invocation_id_notification_idx: call_invocation_id_completion_id,
813                ..Default::default()
814            },
815            call_invocation_id_completion_id,
816        ))?;
817
818        if matches!(
819            self.options.implicit_cancellation,
820            ImplicitCancellationOption::Enabled {
821                cancel_children_one_way_calls: true,
822                ..
823            }
824        ) {
825            self.tracked_invocation_ids.push(TrackedInvocationId {
826                handle: invocation_id_notification_handle,
827                invocation_id: None,
828            })
829        }
830
831        Ok(SendHandle {
832            invocation_id_notification_handle,
833        })
834    }
835
836    #[instrument(
837        level = "trace",
838        skip(self),
839        fields(
840            restate.invocation.id = self.debug_invocation_id(),
841            restate.protocol.state = self.debug_state(),
842            restate.journal.command_index = self.context.journal.command_index(),
843            restate.protocol.version = %self.version
844        ),
845        ret
846    )]
847    fn sys_awakeable(&mut self) -> VMResult<(String, NotificationHandle)> {
848        invocation_debug_logs!(self, "Executing 'Create awakeable'");
849
850        let signal_id = self.context.journal.next_signal_notification_id();
851
852        let handle = self.do_transition(CreateSignalHandle(
853            "SysAwakeable",
854            NotificationId::SignalId(signal_id),
855        ))?;
856
857        Ok((
858            awakeable_id_str(&self.context.expect_start_info().id, signal_id),
859            handle,
860        ))
861    }
862
863    #[instrument(
864        level = "trace",
865        skip(self, value),
866        fields(
867            restate.invocation.id = self.debug_invocation_id(),
868            restate.protocol.state = self.debug_state(),
869            restate.journal.command_index = self.context.journal.command_index(),
870            restate.protocol.version = %self.version
871        ),
872        ret
873    )]
874    fn sys_complete_awakeable(&mut self, id: String, value: NonEmptyValue) -> VMResult<()> {
875        invocation_debug_logs!(self, "Executing 'Complete awakeable {id}'");
876        self.do_transition(SysNonCompletableEntry(
877            "SysCompleteAwakeable",
878            CompleteAwakeableCommandMessage {
879                awakeable_id: id,
880                result: Some(match value {
881                    NonEmptyValue::Success(s) => {
882                        complete_awakeable_command_message::Result::Value(s.into())
883                    }
884                    NonEmptyValue::Failure(f) => {
885                        complete_awakeable_command_message::Result::Failure(f.into())
886                    }
887                }),
888                ..Default::default()
889            },
890        ))
891    }
892
893    #[instrument(
894        level = "trace",
895        skip(self),
896        fields(
897            restate.invocation.id = self.debug_invocation_id(),
898            restate.protocol.state = self.debug_state(),
899            restate.journal.command_index = self.context.journal.command_index(),
900            restate.protocol.version = %self.version
901        ),
902        ret
903    )]
904    fn create_signal_handle(&mut self, signal_name: String) -> VMResult<NotificationHandle> {
905        invocation_debug_logs!(self, "Executing 'Create named signal'");
906
907        self.do_transition(CreateSignalHandle(
908            "SysCreateNamedSignal",
909            NotificationId::SignalName(signal_name),
910        ))
911    }
912
913    #[instrument(
914        level = "trace",
915        skip(self, value),
916        fields(
917            restate.invocation.id = self.debug_invocation_id(),
918            restate.protocol.state = self.debug_state(),
919            restate.journal.command_index = self.context.journal.command_index(),
920            restate.protocol.version = %self.version
921        ),
922        ret
923    )]
924    fn sys_complete_signal(
925        &mut self,
926        target_invocation_id: String,
927        signal_name: String,
928        value: NonEmptyValue,
929    ) -> VMResult<()> {
930        invocation_debug_logs!(self, "Executing 'Complete named signal {signal_name}'");
931        self.do_transition(SysNonCompletableEntry(
932            "SysCompleteAwakeable",
933            SendSignalCommandMessage {
934                target_invocation_id,
935                signal_id: Some(send_signal_command_message::SignalId::Name(signal_name)),
936                result: Some(match value {
937                    NonEmptyValue::Success(s) => {
938                        send_signal_command_message::Result::Value(s.into())
939                    }
940                    NonEmptyValue::Failure(f) => {
941                        send_signal_command_message::Result::Failure(f.into())
942                    }
943                }),
944                ..Default::default()
945            },
946        ))
947    }
948
949    #[instrument(
950        level = "trace",
951        skip(self),
952        fields(
953            restate.invocation.id = self.debug_invocation_id(),
954            restate.protocol.state = self.debug_state(),
955            restate.journal.command_index = self.context.journal.command_index(),
956            restate.protocol.version = %self.version
957        ),
958        ret
959    )]
960    fn sys_get_promise(&mut self, key: String) -> VMResult<NotificationHandle> {
961        invocation_debug_logs!(self, "Executing 'Await promise {key}'");
962
963        let result_completion_id = self.context.journal.next_completion_notification_id();
964        self.do_transition(SysSimpleCompletableEntry(
965            "SysGetPromise",
966            GetPromiseCommandMessage {
967                key,
968                result_completion_id,
969                ..Default::default()
970            },
971            result_completion_id,
972        ))
973    }
974
975    #[instrument(
976        level = "trace",
977        skip(self),
978        fields(
979            restate.invocation.id = self.debug_invocation_id(),
980            restate.protocol.state = self.debug_state(),
981            restate.journal.command_index = self.context.journal.command_index(),
982            restate.protocol.version = %self.version
983        ),
984        ret
985    )]
986    fn sys_peek_promise(&mut self, key: String) -> VMResult<NotificationHandle> {
987        invocation_debug_logs!(self, "Executing 'Peek promise {key}'");
988
989        let result_completion_id = self.context.journal.next_completion_notification_id();
990        self.do_transition(SysSimpleCompletableEntry(
991            "SysPeekPromise",
992            PeekPromiseCommandMessage {
993                key,
994                result_completion_id,
995                ..Default::default()
996            },
997            result_completion_id,
998        ))
999    }
1000
1001    #[instrument(
1002        level = "trace",
1003        skip(self, value),
1004        fields(
1005            restate.invocation.id = self.debug_invocation_id(),
1006            restate.protocol.state = self.debug_state(),
1007            restate.journal.command_index = self.context.journal.command_index(),
1008            restate.protocol.version = %self.version
1009        ),
1010        ret
1011    )]
1012    fn sys_complete_promise(
1013        &mut self,
1014        key: String,
1015        value: NonEmptyValue,
1016    ) -> VMResult<NotificationHandle> {
1017        invocation_debug_logs!(self, "Executing 'Complete promise {key}'");
1018
1019        let result_completion_id = self.context.journal.next_completion_notification_id();
1020        self.do_transition(SysSimpleCompletableEntry(
1021            "SysCompletePromise",
1022            CompletePromiseCommandMessage {
1023                key,
1024                completion: Some(match value {
1025                    NonEmptyValue::Success(s) => {
1026                        complete_promise_command_message::Completion::CompletionValue(s.into())
1027                    }
1028                    NonEmptyValue::Failure(f) => {
1029                        complete_promise_command_message::Completion::CompletionFailure(f.into())
1030                    }
1031                }),
1032                result_completion_id,
1033                ..Default::default()
1034            },
1035            result_completion_id,
1036        ))
1037    }
1038
1039    #[instrument(
1040        level = "trace",
1041        skip(self),
1042        fields(
1043            restate.invocation.id = self.debug_invocation_id(),
1044            restate.protocol.state = self.debug_state(),
1045            restate.journal.command_index = self.context.journal.command_index(),
1046            restate.protocol.version = %self.version
1047        ),
1048        ret
1049    )]
1050    fn sys_run(&mut self, name: String) -> VMResult<NotificationHandle> {
1051        match self.do_transition(SysRun(name.clone())) {
1052            Ok(handle) => {
1053                if enabled!(Level::DEBUG) {
1054                    // Store the name, we need it later when completing
1055                    self.sys_run_names.insert(handle, name);
1056                }
1057                Ok(handle)
1058            }
1059            Err(e) => Err(e),
1060        }
1061    }
1062
1063    #[instrument(
1064        level = "trace",
1065        skip(self, value, retry_policy),
1066        fields(
1067            restate.invocation.id = self.debug_invocation_id(),
1068            restate.protocol.state = self.debug_state(),
1069            restate.journal.command_index = self.context.journal.command_index(),
1070            restate.protocol.version = %self.version
1071        ),
1072        ret
1073    )]
1074    fn propose_run_completion(
1075        &mut self,
1076        notification_handle: NotificationHandle,
1077        value: RunExitResult,
1078        retry_policy: RetryPolicy,
1079    ) -> VMResult<()> {
1080        if enabled!(Level::DEBUG) {
1081            let name: &str = self
1082                .sys_run_names
1083                .get(&notification_handle)
1084                .map(String::as_str)
1085                .unwrap_or_default();
1086            match &value {
1087                RunExitResult::Success(_) => {
1088                    invocation_debug_logs!(self, "Journaling run '{name}' success result");
1089                }
1090                RunExitResult::TerminalFailure(TerminalFailure { code, .. }) => {
1091                    invocation_debug_logs!(
1092                        self,
1093                        "Journaling run '{name}' terminal failure {code} result"
1094                    );
1095                }
1096                RunExitResult::RetryableFailure { .. } => {
1097                    invocation_debug_logs!(self, "Propagating run '{name}' retryable failure");
1098                }
1099            }
1100        }
1101
1102        self.do_transition(ProposeRunCompletion(
1103            notification_handle,
1104            value,
1105            retry_policy,
1106        ))
1107    }
1108
1109    #[instrument(
1110        level = "trace",
1111        skip(self),
1112        fields(
1113            restate.invocation.id = self.debug_invocation_id(),
1114            restate.protocol.state = self.debug_state(),
1115            restate.journal.command_index = self.context.journal.command_index(),
1116            restate.protocol.version = %self.version
1117        ),
1118        ret
1119    )]
1120    fn sys_cancel_invocation(&mut self, target_invocation_id: String) -> VMResult<()> {
1121        invocation_debug_logs!(
1122            self,
1123            "Executing 'Cancel invocation' of {target_invocation_id}"
1124        );
1125        self.verify_feature_support("cancel invocation", Version::V3)?;
1126        self.do_transition(SysNonCompletableEntry(
1127            "SysCancelInvocation",
1128            SendSignalCommandMessage {
1129                target_invocation_id,
1130                signal_id: Some(send_signal_command_message::SignalId::Idx(CANCEL_SIGNAL_ID)),
1131                result: Some(send_signal_command_message::Result::Void(Default::default())),
1132                ..Default::default()
1133            },
1134        ))
1135    }
1136
1137    #[instrument(
1138        level = "trace",
1139        skip(self),
1140        fields(
1141            restate.invocation.id = self.debug_invocation_id(),
1142            restate.protocol.state = self.debug_state(),
1143            restate.journal.command_index = self.context.journal.command_index(),
1144            restate.protocol.version = %self.version
1145        ),
1146        ret
1147    )]
1148    fn sys_attach_invocation(
1149        &mut self,
1150        target: AttachInvocationTarget,
1151    ) -> VMResult<NotificationHandle> {
1152        invocation_debug_logs!(self, "Executing 'Attach invocation'");
1153        self.verify_feature_support("attach invocation", Version::V3)?;
1154
1155        let result_completion_id = self.context.journal.next_completion_notification_id();
1156        self.do_transition(SysSimpleCompletableEntry(
1157            "SysAttachInvocation",
1158            AttachInvocationCommandMessage {
1159                target: Some(match target {
1160                    AttachInvocationTarget::InvocationId(id) => {
1161                        attach_invocation_command_message::Target::InvocationId(id)
1162                    }
1163                    AttachInvocationTarget::WorkflowId { name, key } => {
1164                        attach_invocation_command_message::Target::WorkflowTarget(WorkflowTarget {
1165                            workflow_name: name,
1166                            workflow_key: key,
1167                        })
1168                    }
1169                    AttachInvocationTarget::IdempotencyId {
1170                        service_name,
1171                        service_key,
1172                        handler_name,
1173                        idempotency_key,
1174                    } => attach_invocation_command_message::Target::IdempotentRequestTarget(
1175                        IdempotentRequestTarget {
1176                            service_name,
1177                            service_key,
1178                            handler_name,
1179                            idempotency_key,
1180                        },
1181                    ),
1182                }),
1183                result_completion_id,
1184                ..Default::default()
1185            },
1186            result_completion_id,
1187        ))
1188    }
1189
1190    #[instrument(
1191        level = "trace",
1192        skip(self),
1193        fields(
1194            restate.invocation.id = self.debug_invocation_id(),
1195            restate.protocol.state = self.debug_state(),
1196            restate.journal.command_index = self.context.journal.command_index(),
1197            restate.protocol.version = %self.version
1198        ),
1199        ret
1200    )]
1201    fn sys_get_invocation_output(
1202        &mut self,
1203        target: AttachInvocationTarget,
1204    ) -> VMResult<NotificationHandle> {
1205        invocation_debug_logs!(self, "Executing 'Get invocation output'");
1206        self.verify_feature_support("get invocation output", Version::V3)?;
1207
1208        let result_completion_id = self.context.journal.next_completion_notification_id();
1209        self.do_transition(SysSimpleCompletableEntry(
1210            "SysGetInvocationOutput",
1211            GetInvocationOutputCommandMessage {
1212                target: Some(match target {
1213                    AttachInvocationTarget::InvocationId(id) => {
1214                        get_invocation_output_command_message::Target::InvocationId(id)
1215                    }
1216                    AttachInvocationTarget::WorkflowId { name, key } => {
1217                        get_invocation_output_command_message::Target::WorkflowTarget(
1218                            WorkflowTarget {
1219                                workflow_name: name,
1220                                workflow_key: key,
1221                            },
1222                        )
1223                    }
1224                    AttachInvocationTarget::IdempotencyId {
1225                        service_name,
1226                        service_key,
1227                        handler_name,
1228                        idempotency_key,
1229                    } => get_invocation_output_command_message::Target::IdempotentRequestTarget(
1230                        IdempotentRequestTarget {
1231                            service_name,
1232                            service_key,
1233                            handler_name,
1234                            idempotency_key,
1235                        },
1236                    ),
1237                }),
1238                result_completion_id,
1239                ..Default::default()
1240            },
1241            result_completion_id,
1242        ))
1243    }
1244
1245    #[instrument(
1246        level = "trace",
1247        skip(self, value),
1248        fields(
1249            restate.invocation.id = self.debug_invocation_id(),
1250            restate.protocol.state = self.debug_state(),
1251            restate.journal.command_index = self.context.journal.command_index(),
1252            restate.protocol.version = %self.version
1253        ),
1254        ret
1255    )]
1256    fn sys_write_output(&mut self, value: NonEmptyValue) -> Result<(), Error> {
1257        match &value {
1258            NonEmptyValue::Success(_) => {
1259                invocation_debug_logs!(self, "Writing invocation result success value");
1260            }
1261            NonEmptyValue::Failure(_) => {
1262                invocation_debug_logs!(self, "Writing invocation result failure value");
1263            }
1264        }
1265        self.do_transition(SysNonCompletableEntry(
1266            "SysWriteOutput",
1267            OutputCommandMessage {
1268                result: Some(match value {
1269                    NonEmptyValue::Success(b) => output_command_message::Result::Value(b.into()),
1270                    NonEmptyValue::Failure(f) => output_command_message::Result::Failure(f.into()),
1271                }),
1272                ..OutputCommandMessage::default()
1273            },
1274        ))
1275    }
1276
    /// Signals the end of the invocation by logging it and running the `SysEnd` transition.
    #[instrument(
        level = "trace",
        skip(self),
        fields(
            restate.invocation.id = self.debug_invocation_id(),
            restate.protocol.state = self.debug_state(),
            restate.journal.command_index = self.context.journal.command_index(),
            restate.protocol.version = %self.version
        ),
        ret
    )]
    fn sys_end(&mut self) -> Result<(), Error> {
        invocation_debug_logs!(self, "End of the invocation");
        self.do_transition(SysEnd)
    }
1292
1293    fn is_waiting_preflight(&self) -> bool {
1294        matches!(
1295            &self.last_transition,
1296            Ok(State::WaitingStart) | Ok(State::WaitingReplayEntries { .. })
1297        )
1298    }
1299
1300    fn is_replaying(&self) -> bool {
1301        matches!(&self.last_transition, Ok(State::Replaying { .. }))
1302    }
1303
1304    fn is_processing(&self) -> bool {
1305        matches!(&self.last_transition, Ok(State::Processing { .. }))
1306    }
1307}
1308
1309const INDIFFERENT_PAD: GeneralPurposeConfig = GeneralPurposeConfig::new()
1310    .with_decode_padding_mode(DecodePaddingMode::Indifferent)
1311    .with_encode_padding(false);
1312const URL_SAFE: GeneralPurpose = GeneralPurpose::new(&alphabet::URL_SAFE, INDIFFERENT_PAD);
1313
1314const AWAKEABLE_PREFIX: &str = "sign_1";
1315
1316fn awakeable_id_str(id: &[u8], completion_index: u32) -> String {
1317    let mut input_buf = BytesMut::with_capacity(id.len() + size_of::<u32>());
1318    input_buf.put_slice(id);
1319    input_buf.put_u32(completion_index);
1320    format!("{AWAKEABLE_PREFIX}{}", URL_SAFE.encode(input_buf.freeze()))
1321}