1use crate::headers::HeaderMap;
2use crate::service_protocol::messages::{
3 attach_invocation_command_message, complete_awakeable_command_message,
4 complete_promise_command_message, get_invocation_output_command_message,
5 output_command_message, send_signal_command_message, AttachInvocationCommandMessage,
6 CallCommandMessage, ClearAllStateCommandMessage, ClearStateCommandMessage,
7 CompleteAwakeableCommandMessage, CompletePromiseCommandMessage,
8 GetInvocationOutputCommandMessage, GetPromiseCommandMessage, IdempotentRequestTarget,
9 OneWayCallCommandMessage, OutputCommandMessage, PeekPromiseCommandMessage,
10 SendSignalCommandMessage, SetStateCommandMessage, SleepCommandMessage, WorkflowTarget,
11};
12use crate::service_protocol::{Decoder, NotificationId, RawMessage, Version, CANCEL_SIGNAL_ID};
13use crate::vm::errors::{
14 ClosedError, UnexpectedStateError, UnsupportedFeatureForNegotiatedVersion,
15 EMPTY_IDEMPOTENCY_KEY, SUSPENDED,
16};
17use crate::vm::transitions::*;
18use crate::{
19 AttachInvocationTarget, CallHandle, CommandRelationship, DoProgressResponse, Error, Header,
20 ImplicitCancellationOption, Input, NonDeterministicChecksOption, NonEmptyValue,
21 NotificationHandle, PayloadOptions, ResponseHead, RetryPolicy, RunExitResult, SendHandle,
22 TakeOutputResult, Target, TerminalFailure, VMOptions, VMResult, Value,
23 CANCEL_NOTIFICATION_HANDLE,
24};
25use base64::engine::{DecodePaddingMode, GeneralPurpose, GeneralPurposeConfig};
26use base64::{alphabet, Engine};
27use bytes::{Buf, BufMut, Bytes, BytesMut};
28use context::{AsyncResultsState, Context, Output, RunState};
29use std::borrow::Cow;
30use std::collections::{HashMap, VecDeque};
31use std::mem::size_of;
32use std::time::Duration;
33use std::{fmt, mem};
34use strum::IntoStaticStr;
35use tracing::{debug, enabled, instrument, Level};
36
// Defines `Context`, `Output`, `RunState` and `AsyncResultsState` used by the VM below.
mod context;
// Error types and constants (`ClosedError`, `UnexpectedStateError`, `SUSPENDED`, codes, ...).
pub(crate) mod errors;
// State-machine transitions; glob-imported at the top of this file.
mod transitions;

/// Name of the header carrying the service-protocol content type.
///
/// Read in `CoreVM::new` to negotiate the protocol [`Version`], and sent back in
/// the response head built by `get_response_head`.
const CONTENT_TYPE: &str = "content-type";
42
/// States of the protocol state machine driven by `CoreVM`.
///
/// The variant name (obtained through the `IntoStaticStr` derive) is used as a
/// tracing field and inside unexpected-state errors.
#[derive(Debug, IntoStaticStr)]
pub(crate) enum State {
    /// Initial state assigned by `CoreVM::new`.
    /// NOTE(review): presumably waits for the start message — confirm in `transitions`.
    WaitingStart,
    /// Not yet ready to execute (`is_ready_to_execute` returns `false`).
    WaitingReplayEntries {
        /// Presumably the number of journal entries received so far — confirm.
        received_entries: u32,
        /// Presumably the commands buffered for the upcoming replay — confirm.
        commands: VecDeque<RawMessage>,
        /// Tracking of async results/notifications.
        async_results: AsyncResultsState,
    },
    /// Ready to execute while commands are still being replayed.
    Replaying {
        /// Presumably the commands still left to replay — confirm.
        commands: VecDeque<RawMessage>,
        /// State of in-flight `run` blocks.
        run_state: RunState,
        /// Tracking of async results; queried by `CoreVM::_is_completed`.
        async_results: AsyncResultsState,
    },
    /// Ready to execute; new commands are being produced.
    Processing {
        /// NOTE(review): flag semantics not visible here — confirm in `transitions`.
        processing_first_entry: bool,
        /// State of in-flight `run` blocks.
        run_state: RunState,
        /// Tracking of async results; queried by `CoreVM::_is_completed`.
        async_results: AsyncResultsState,
    },
    /// Terminal state; events reaching it are reported as a `ClosedError`
    /// by `State::as_unexpected_state`.
    Closed,
}
63
64impl State {
65 fn as_unexpected_state(&self, event: impl ToString) -> Error {
66 if matches!(self, State::Closed) {
67 return ClosedError::new(event.to_string()).into();
68 }
69 UnexpectedStateError::new(self.into(), event.to_string()).into()
70 }
71}
72
73struct TrackedInvocationId {
74 handle: NotificationHandle,
75 invocation_id: Option<String>,
76}
77
78impl TrackedInvocationId {
79 fn is_resolved(&self) -> bool {
80 self.invocation_id.is_some()
81 }
82}
83
/// Core implementation of the service-protocol state machine (`super::VM`).
pub struct CoreVM {
    /// Options passed to `new`; consulted for implicit-cancellation behavior.
    options: VMOptions,

    /// Decoder for incoming protocol messages pushed through `notify_input`.
    decoder: Decoder,

    /// Protocol context: output buffer, journal, eager state, negotiated version.
    context: Context,
    /// Outcome of the most recent state transition; `Err` means the VM failed.
    last_transition: Result<State, Error>,

    /// Child invocations to cancel when implicit cancellation is triggered.
    tracked_invocation_ids: Vec<TrackedInvocationId>,

    /// Names of `run` blocks keyed by their handle; filled only when DEBUG
    /// logging is enabled, and used for log messages.
    sys_run_names: HashMap<NotificationHandle, String>,
}
100
101impl CoreVM {
102 fn debug_invocation_id(&self) -> &str {
104 if let Some(start_info) = self.context.start_info() {
105 &start_info.debug_id
106 } else {
107 ""
108 }
109 }
110
111 fn debug_state(&self) -> &'static str {
112 match &self.last_transition {
113 Ok(s) => s.into(),
114 Err(_) => "Failed",
115 }
116 }
117
118 fn verify_error_metadata_feature_support(&mut self, value: &NonEmptyValue) -> VMResult<()> {
119 if let NonEmptyValue::Failure(f) = value {
120 if !f.metadata.is_empty() {
121 self.verify_feature_support("terminal error metadata", Version::V6)?;
122 }
123 }
124 Ok(())
125 }
126
127 #[allow(dead_code)]
128 fn verify_feature_support(
129 &mut self,
130 feature: &'static str,
131 minimum_required_protocol: Version,
132 ) -> VMResult<()> {
133 if self.context.negotiated_protocol_version < minimum_required_protocol {
134 return self.do_transition(HitError(
135 UnsupportedFeatureForNegotiatedVersion::new(
136 feature,
137 self.context.negotiated_protocol_version,
138 minimum_required_protocol,
139 )
140 .into(),
141 ));
142 }
143 Ok(())
144 }
145
146 fn _is_completed(&self, handle: NotificationHandle) -> bool {
147 match &self.last_transition {
148 Ok(State::Replaying { async_results, .. })
149 | Ok(State::Processing { async_results, .. }) => {
150 async_results.is_handle_completed(handle)
151 }
152 _ => false,
153 }
154 }
155
156 fn _do_progress(
157 &mut self,
158 any_handle: Vec<NotificationHandle>,
159 ) -> Result<DoProgressResponse, Error> {
160 match self.do_transition(DoProgress(any_handle)) {
161 Ok(Ok(do_progress_response)) => Ok(do_progress_response),
162 Ok(Err(_)) => Err(SUSPENDED),
163 Err(e) => Err(e),
164 }
165 }
166
167 fn is_implicit_cancellation_enabled(&self) -> bool {
168 matches!(
169 self.options.implicit_cancellation,
170 ImplicitCancellationOption::Enabled { .. }
171 )
172 }
173}
174
175impl fmt::Debug for CoreVM {
176 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
177 let mut s = f.debug_struct("CoreVM");
178 s.field("version", &self.context.negotiated_protocol_version);
179
180 if let Some(start_info) = self.context.start_info() {
181 s.field("invocation_id", &start_info.debug_id);
182 }
183
184 match &self.last_transition {
185 Ok(state) => s.field("last_transition", &<&'static str>::from(state)),
186 Err(_) => s.field("last_transition", &"Errored"),
187 };
188
189 s.field("command_index", &self.context.journal.command_index())
190 .field(
191 "notification_index",
192 &self.context.journal.notification_index(),
193 )
194 .finish()
195 }
196}
197
198#[allow(unused)]
200const fn is_send<T: Send>() {}
201const _: () = is_send::<CoreVM>();
202
// Emits `tracing::debug!($($arg)*)` only while `$this.is_processing()` is true.
// NOTE(review): presumably this keeps replayed commands from being logged again
// on every replay — confirm against `is_processing`.
macro_rules! invocation_debug_logs {
    ($this:expr, $($arg:tt)*) => {
        // No parentheses around the condition (idiomatic Rust; avoids `unused_parens`).
        if $this.is_processing() {
            tracing::debug!($($arg)*)
        }
    };
}
211
212impl super::VM for CoreVM {
213 #[instrument(level = "trace", skip(request_headers), ret)]
214 fn new(request_headers: impl HeaderMap, options: VMOptions) -> Result<Self, Error> {
215 let version = request_headers
216 .extract(CONTENT_TYPE)
217 .map_err(|e| {
218 Error::new(
219 errors::codes::BAD_REQUEST,
220 format!("cannot read '{CONTENT_TYPE}' header: {e:?}"),
221 )
222 })?
223 .ok_or(errors::MISSING_CONTENT_TYPE)?
224 .parse::<Version>()?;
225
226 if version < Version::minimum_supported_version()
227 || version > Version::maximum_supported_version()
228 {
229 return Err(Error::new(
230 errors::codes::UNSUPPORTED_MEDIA_TYPE,
231 format!(
232 "Unsupported protocol version {:?}, not within [{:?} to {:?}]. \
233 You might need to rediscover the service, check https://docs.restate.dev/references/errors/#RT0015",
234 version,
235 Version::minimum_supported_version(),
236 Version::maximum_supported_version()
237 ),
238 ));
239 }
240 let non_deterministic_checks_ignore_payload_equality = matches!(
241 options.non_determinism_checks,
242 NonDeterministicChecksOption::PayloadChecksDisabled
243 );
244
245 Ok(Self {
246 options,
247 decoder: Decoder::new(version),
248 context: Context {
249 input_is_closed: false,
250 output: Output::new(version),
251 start_info: None,
252 journal: Default::default(),
253 eager_state: Default::default(),
254 non_deterministic_checks_ignore_payload_equality,
255 negotiated_protocol_version: version,
256 },
257 last_transition: Ok(State::WaitingStart),
258 tracked_invocation_ids: vec![],
259 sys_run_names: HashMap::with_capacity(0),
260 })
261 }
262
263 #[instrument(
264 level = "trace",
265 skip(self),
266 fields(
267 restate.invocation.id = self.debug_invocation_id(),
268 restate.protocol.state = self.debug_state(),
269 restate.journal.command_index = self.context.journal.command_index(),
270 restate.protocol.version = %self.context.negotiated_protocol_version
271 ),
272 ret
273 )]
274 fn get_response_head(&self) -> ResponseHead {
275 ResponseHead {
276 status_code: 200,
277 headers: vec![Header {
278 key: Cow::Borrowed(CONTENT_TYPE),
279 value: Cow::Borrowed(self.context.negotiated_protocol_version.content_type()),
280 }],
281 version: self.context.negotiated_protocol_version,
282 }
283 }
284
285 #[instrument(
286 level = "trace",
287 skip(self),
288 fields(
289 restate.invocation.id = self.debug_invocation_id(),
290 restate.protocol.state = self.debug_state(),
291 restate.journal.command_index = self.context.journal.command_index(),
292 restate.protocol.version = %self.context.negotiated_protocol_version
293 ),
294 ret
295 )]
296 fn notify_input(&mut self, buffer: Bytes) {
297 self.decoder.push(buffer);
298 loop {
299 match self.decoder.consume_next() {
300 Ok(Some(msg)) => {
301 if self.do_transition(NewMessage(msg)).is_err() {
302 return;
303 }
304 }
305 Ok(None) => {
306 return;
307 }
308 Err(e) => {
309 if self.do_transition(HitError(e.into())).is_err() {
310 return;
311 }
312 }
313 }
314 }
315 }
316
317 #[instrument(
318 level = "trace",
319 skip(self),
320 fields(
321 restate.invocation.id = self.debug_invocation_id(),
322 restate.protocol.state = self.debug_state(),
323 restate.journal.command_index = self.context.journal.command_index(),
324 restate.protocol.version = %self.context.negotiated_protocol_version
325 ),
326 ret
327 )]
328 fn notify_input_closed(&mut self) {
329 self.context.input_is_closed = true;
330 let _ = self.do_transition(NotifyInputClosed);
331 }
332
333 #[instrument(
334 level = "trace",
335 skip(self),
336 fields(
337 restate.invocation.id = self.debug_invocation_id(),
338 restate.protocol.state = self.debug_state(),
339 restate.journal.command_index = self.context.journal.command_index(),
340 restate.protocol.version = %self.context.negotiated_protocol_version
341 ),
342 ret
343 )]
344 fn notify_error(
345 &mut self,
346 mut error: Error,
347 command_relationship: Option<CommandRelationship>,
348 ) {
349 if let Some(command_relationship) = command_relationship {
350 error = error.with_related_command_metadata(
351 self.context
352 .journal
353 .resolve_related_command(command_relationship),
354 );
355 }
356
357 let _ = self.do_transition(HitError(error));
358 }
359
360 #[instrument(
361 level = "trace",
362 skip(self),
363 fields(
364 restate.invocation.id = self.debug_invocation_id(),
365 restate.protocol.state = self.debug_state(),
366 restate.journal.command_index = self.context.journal.command_index(),
367 restate.protocol.version = %self.context.negotiated_protocol_version
368 ),
369 ret
370 )]
371 fn take_output(&mut self) -> TakeOutputResult {
372 if self.context.output.buffer.has_remaining() {
373 TakeOutputResult::Buffer(
374 self.context
375 .output
376 .buffer
377 .copy_to_bytes(self.context.output.buffer.remaining()),
378 )
379 } else if !self.context.output.is_closed() {
380 TakeOutputResult::Buffer(Bytes::default())
381 } else {
382 TakeOutputResult::EOF
383 }
384 }
385
386 #[instrument(
387 level = "trace",
388 skip(self),
389 fields(
390 restate.invocation.id = self.debug_invocation_id(),
391 restate.protocol.state = self.debug_state(),
392 restate.journal.command_index = self.context.journal.command_index(),
393 restate.protocol.version = %self.context.negotiated_protocol_version
394 ),
395 ret
396 )]
397 fn is_ready_to_execute(&self) -> Result<bool, Error> {
398 match &self.last_transition {
399 Ok(State::WaitingStart) | Ok(State::WaitingReplayEntries { .. }) => Ok(false),
400 Ok(State::Processing { .. }) | Ok(State::Replaying { .. }) => Ok(true),
401 Ok(s) => Err(s.as_unexpected_state("IsReadyToExecute")),
402 Err(e) => Err(e.clone()),
403 }
404 }
405
406 #[instrument(
407 level = "trace",
408 skip(self),
409 fields(
410 restate.invocation.id = self.debug_invocation_id(),
411 restate.protocol.state = self.debug_state(),
412 restate.journal.command_index = self.context.journal.command_index(),
413 restate.protocol.version = %self.context.negotiated_protocol_version
414 ),
415 ret
416 )]
417 fn is_completed(&self, handle: NotificationHandle) -> bool {
418 self._is_completed(handle)
419 }
420
421 #[instrument(
422 level = "trace",
423 skip(self),
424 fields(
425 restate.invocation.id = self.debug_invocation_id(),
426 restate.protocol.state = self.debug_state(),
427 restate.journal.command_index = self.context.journal.command_index(),
428 restate.protocol.version = %self.context.negotiated_protocol_version
429 ),
430 ret
431 )]
432 fn do_progress(
433 &mut self,
434 mut any_handle: Vec<NotificationHandle>,
435 ) -> VMResult<DoProgressResponse> {
436 if self.is_implicit_cancellation_enabled() {
437 any_handle.insert(0, CANCEL_NOTIFICATION_HANDLE);
439
440 match self._do_progress(any_handle) {
441 Ok(DoProgressResponse::AnyCompleted) => {
442 if self._is_completed(CANCEL_NOTIFICATION_HANDLE) {
444 for i in 0..self.tracked_invocation_ids.len() {
446 if self.tracked_invocation_ids[i].is_resolved() {
447 continue;
448 }
449
450 let handle = self.tracked_invocation_ids[i].handle;
451
452 match self._do_progress(vec![handle]) {
454 Ok(DoProgressResponse::AnyCompleted) => {
455 let invocation_id = match self.do_transition(CopyNotification(handle)) {
456 Ok(Ok(Some(Value::InvocationId(invocation_id)))) => Ok(invocation_id),
457 Ok(Err(_)) => Err(SUSPENDED),
458 _ => panic!("Unexpected variant! If the id handle is completed, it must be an invocation id handle!")
459 }?;
460
461 self.tracked_invocation_ids[i].invocation_id =
463 Some(invocation_id);
464 }
465 res => return res,
466 }
467 }
468
469 for tracked_invocation_id in mem::take(&mut self.tracked_invocation_ids) {
471 self.sys_cancel_invocation(
472 tracked_invocation_id
473 .invocation_id
474 .expect("We resolved before all the invocation ids"),
475 )?;
476 }
477
478 let _ = self.take_notification(CANCEL_NOTIFICATION_HANDLE);
480
481 Ok(DoProgressResponse::CancelSignalReceived)
483 } else {
484 Ok(DoProgressResponse::AnyCompleted)
485 }
486 }
487 res => res,
488 }
489 } else {
490 self._do_progress(any_handle)
491 }
492 }
493
494 #[instrument(
495 level = "trace",
496 skip(self),
497 fields(
498 restate.invocation.id = self.debug_invocation_id(),
499 restate.protocol.state = self.debug_state(),
500 restate.journal.command_index = self.context.journal.command_index(),
501 restate.protocol.version = %self.context.negotiated_protocol_version
502 ),
503 ret
504 )]
505 fn take_notification(&mut self, handle: NotificationHandle) -> VMResult<Option<Value>> {
506 match self.do_transition(TakeNotification(handle)) {
507 Ok(Ok(Some(value))) => {
508 if self.is_implicit_cancellation_enabled() {
509 if let Ok(found) = self
512 .tracked_invocation_ids
513 .binary_search_by(|tracked| tracked.handle.cmp(&handle))
514 {
515 let Value::InvocationId(invocation_id) = &value else {
516 panic!("Expecting an invocation id here, but got {value:?}");
517 };
518 self.tracked_invocation_ids
520 .get_mut(found)
521 .unwrap()
522 .invocation_id = Some(invocation_id.clone());
523 }
524 }
525
526 Ok(Some(value))
527 }
528 Ok(Ok(None)) => Ok(None),
529 Ok(Err(_)) => Err(SUSPENDED),
530 Err(e) => Err(e),
531 }
532 }
533
534 #[instrument(
535 level = "trace",
536 skip(self),
537 fields(
538 restate.invocation.id = self.debug_invocation_id(),
539 restate.protocol.state = self.debug_state(),
540 restate.journal.command_index = self.context.journal.command_index(),
541 restate.protocol.version = %self.context.negotiated_protocol_version
542 ),
543 ret
544 )]
545 fn sys_input(&mut self) -> Result<Input, Error> {
546 self.do_transition(SysInput)
547 }
548
549 #[instrument(
550 level = "trace",
551 skip(self),
552 fields(
553 restate.invocation.id = self.debug_invocation_id(),
554 restate.protocol.state = self.debug_state(),
555 restate.journal.command_index = self.context.journal.command_index(),
556 restate.protocol.version = %self.context.negotiated_protocol_version
557 ),
558 ret
559 )]
560 fn sys_state_get(
561 &mut self,
562 key: String,
563 options: PayloadOptions,
564 ) -> Result<NotificationHandle, Error> {
565 invocation_debug_logs!(self, "Executing 'Get state {key}'");
566 self.do_transition(SysStateGet(key, options))
567 }
568
569 #[instrument(
570 level = "trace",
571 skip(self),
572 fields(
573 restate.invocation.id = self.debug_invocation_id(),
574 restate.protocol.state = self.debug_state(),
575 restate.journal.command_index = self.context.journal.command_index(),
576 restate.protocol.version = %self.context.negotiated_protocol_version
577 ),
578 ret
579 )]
580 fn sys_state_get_keys(&mut self) -> VMResult<NotificationHandle> {
581 invocation_debug_logs!(self, "Executing 'Get state keys'");
582 self.do_transition(SysStateGetKeys)
583 }
584
585 #[instrument(
586 level = "trace",
587 skip(self, value),
588 fields(
589 restate.invocation.id = self.debug_invocation_id(),
590 restate.protocol.state = self.debug_state(),
591 restate.journal.command_index = self.context.journal.command_index(),
592 restate.protocol.version = %self.context.negotiated_protocol_version
593 ),
594 ret
595 )]
596 fn sys_state_set(
597 &mut self,
598 key: String,
599 value: Bytes,
600 options: PayloadOptions,
601 ) -> VMResult<()> {
602 invocation_debug_logs!(self, "Executing 'Set state {key}'");
603 self.context.eager_state.set(key.clone(), value.clone());
604 self.do_transition(SysNonCompletableEntry(
605 SetStateCommandMessage {
606 key: Bytes::from(key.into_bytes()),
607 value: Some(value.into()),
608 ..SetStateCommandMessage::default()
609 },
610 options,
611 ))
612 }
613
614 #[instrument(
615 level = "trace",
616 skip(self),
617 fields(
618 restate.invocation.id = self.debug_invocation_id(),
619 restate.protocol.state = self.debug_state(),
620 restate.journal.command_index = self.context.journal.command_index(),
621 restate.protocol.version = %self.context.negotiated_protocol_version
622 ),
623 ret
624 )]
625 fn sys_state_clear(&mut self, key: String) -> Result<(), Error> {
626 invocation_debug_logs!(self, "Executing 'Clear state {key}'");
627 self.context.eager_state.clear(key.clone());
628 self.do_transition(SysNonCompletableEntry(
629 ClearStateCommandMessage {
630 key: Bytes::from(key.into_bytes()),
631 ..ClearStateCommandMessage::default()
632 },
633 PayloadOptions::default(),
634 ))
635 }
636
637 #[instrument(
638 level = "trace",
639 skip(self),
640 fields(
641 restate.invocation.id = self.debug_invocation_id(),
642 restate.protocol.state = self.debug_state(),
643 restate.journal.command_index = self.context.journal.command_index(),
644 restate.protocol.version = %self.context.negotiated_protocol_version
645 ),
646 ret
647 )]
648 fn sys_state_clear_all(&mut self) -> Result<(), Error> {
649 invocation_debug_logs!(self, "Executing 'Clear all state'");
650 self.context.eager_state.clear_all();
651 self.do_transition(SysNonCompletableEntry(
652 ClearAllStateCommandMessage::default(),
653 PayloadOptions::default(),
654 ))
655 }
656
657 #[instrument(
658 level = "trace",
659 skip(self),
660 fields(
661 restate.invocation.id = self.debug_invocation_id(),
662 restate.protocol.state = self.debug_state(),
663 restate.journal.command_index = self.context.journal.command_index(),
664 restate.protocol.version = %self.context.negotiated_protocol_version
665 ),
666 ret
667 )]
668 fn sys_sleep(
669 &mut self,
670 name: String,
671 wake_up_time_since_unix_epoch: Duration,
672 now_since_unix_epoch: Option<Duration>,
673 ) -> VMResult<NotificationHandle> {
674 if self.is_processing() {
675 match (&name, now_since_unix_epoch) {
676 (name, Some(now_since_unix_epoch)) if name.is_empty() => {
677 debug!(
678 "Executing 'Timer with duration {:?}'",
679 wake_up_time_since_unix_epoch - now_since_unix_epoch
680 );
681 }
682 (name, Some(now_since_unix_epoch)) => {
683 debug!(
684 "Executing 'Timer {name} with duration {:?}'",
685 wake_up_time_since_unix_epoch - now_since_unix_epoch
686 );
687 }
688 (name, None) if name.is_empty() => {
689 debug!("Executing 'Timer'");
690 }
691 (name, None) => {
692 debug!("Executing 'Timer named {name}'");
693 }
694 }
695 }
696
697 let completion_id = self.context.journal.next_completion_notification_id();
698
699 self.do_transition(SysSimpleCompletableEntry(
700 SleepCommandMessage {
701 wake_up_time: u64::try_from(wake_up_time_since_unix_epoch.as_millis())
702 .expect("millis since Unix epoch should fit in u64"),
703 result_completion_id: completion_id,
704 name,
705 },
706 completion_id,
707 PayloadOptions::default(),
708 ))
709 }
710
711 #[instrument(
712 level = "trace",
713 skip(self, input),
714 fields(
715 restate.invocation.id = self.debug_invocation_id(),
716 restate.protocol.state = self.debug_state(),
717 restate.journal.command_index = self.context.journal.command_index(),
718 restate.protocol.version = %self.context.negotiated_protocol_version
719 ),
720 ret
721 )]
722 fn sys_call(
723 &mut self,
724 target: Target,
725 input: Bytes,
726 options: PayloadOptions,
727 ) -> VMResult<CallHandle> {
728 invocation_debug_logs!(
729 self,
730 "Executing 'Call {}/{}'",
731 target.service,
732 target.handler
733 );
734 if let Some(idempotency_key) = &target.idempotency_key {
735 if idempotency_key.is_empty() {
736 self.do_transition(HitError(EMPTY_IDEMPOTENCY_KEY))?;
737 unreachable!();
738 }
739 }
740
741 let call_invocation_id_completion_id =
742 self.context.journal.next_completion_notification_id();
743 let result_completion_id = self.context.journal.next_completion_notification_id();
744
745 let handles = self.do_transition(SysCompletableEntryWithMultipleCompletions(
746 CallCommandMessage {
747 service_name: target.service,
748 handler_name: target.handler,
749 key: target.key.unwrap_or_default(),
750 idempotency_key: target.idempotency_key,
751 headers: target
752 .headers
753 .into_iter()
754 .map(crate::service_protocol::messages::Header::from)
755 .collect(),
756 parameter: input,
757 invocation_id_notification_idx: call_invocation_id_completion_id,
758 result_completion_id,
759 ..Default::default()
760 },
761 vec![call_invocation_id_completion_id, result_completion_id],
762 options,
763 ))?;
764
765 if matches!(
766 self.options.implicit_cancellation,
767 ImplicitCancellationOption::Enabled {
768 cancel_children_calls: true,
769 ..
770 }
771 ) {
772 self.tracked_invocation_ids.push(TrackedInvocationId {
773 handle: handles[0],
774 invocation_id: None,
775 })
776 }
777
778 Ok(CallHandle {
779 invocation_id_notification_handle: handles[0],
780 call_notification_handle: handles[1],
781 })
782 }
783
784 #[instrument(
785 level = "trace",
786 skip(self, input),
787 fields(
788 restate.invocation.id = self.debug_invocation_id(),
789 restate.protocol.state = self.debug_state(),
790 restate.journal.command_index = self.context.journal.command_index(),
791 restate.protocol.version = %self.context.negotiated_protocol_version
792 ),
793 ret
794 )]
795 fn sys_send(
796 &mut self,
797 target: Target,
798 input: Bytes,
799 delay: Option<Duration>,
800 options: PayloadOptions,
801 ) -> VMResult<SendHandle> {
802 invocation_debug_logs!(
803 self,
804 "Executing 'Send to {}/{}'",
805 target.service,
806 target.handler
807 );
808 if let Some(idempotency_key) = &target.idempotency_key {
809 if idempotency_key.is_empty() {
810 self.do_transition(HitError(EMPTY_IDEMPOTENCY_KEY))?;
811 unreachable!();
812 }
813 }
814 let call_invocation_id_completion_id =
815 self.context.journal.next_completion_notification_id();
816 let invocation_id_notification_handle = self.do_transition(SysSimpleCompletableEntry(
817 OneWayCallCommandMessage {
818 service_name: target.service,
819 handler_name: target.handler,
820 key: target.key.unwrap_or_default(),
821 idempotency_key: target.idempotency_key,
822 headers: target
823 .headers
824 .into_iter()
825 .map(crate::service_protocol::messages::Header::from)
826 .collect(),
827 parameter: input,
828 invoke_time: delay
829 .map(|d| {
830 u64::try_from(d.as_millis())
831 .expect("millis since Unix epoch should fit in u64")
832 })
833 .unwrap_or_default(),
834 invocation_id_notification_idx: call_invocation_id_completion_id,
835 ..Default::default()
836 },
837 call_invocation_id_completion_id,
838 options,
839 ))?;
840
841 if matches!(
842 self.options.implicit_cancellation,
843 ImplicitCancellationOption::Enabled {
844 cancel_children_one_way_calls: true,
845 ..
846 }
847 ) {
848 self.tracked_invocation_ids.push(TrackedInvocationId {
849 handle: invocation_id_notification_handle,
850 invocation_id: None,
851 })
852 }
853
854 Ok(SendHandle {
855 invocation_id_notification_handle,
856 })
857 }
858
859 #[instrument(
860 level = "trace",
861 skip(self),
862 fields(
863 restate.invocation.id = self.debug_invocation_id(),
864 restate.protocol.state = self.debug_state(),
865 restate.journal.command_index = self.context.journal.command_index(),
866 restate.protocol.version = %self.context.negotiated_protocol_version
867 ),
868 ret
869 )]
870 fn sys_awakeable(&mut self) -> VMResult<(String, NotificationHandle)> {
871 invocation_debug_logs!(self, "Executing 'Create awakeable'");
872
873 let signal_id = self.context.journal.next_signal_notification_id();
874
875 let handle = self.do_transition(CreateSignalHandle(
876 "awakeable",
877 NotificationId::SignalId(signal_id),
878 ))?;
879
880 Ok((
881 awakeable_id_str(&self.context.expect_start_info().id, signal_id),
882 handle,
883 ))
884 }
885
886 #[instrument(
887 level = "trace",
888 skip(self, value),
889 fields(
890 restate.invocation.id = self.debug_invocation_id(),
891 restate.protocol.state = self.debug_state(),
892 restate.journal.command_index = self.context.journal.command_index(),
893 restate.protocol.version = %self.context.negotiated_protocol_version
894 ),
895 ret
896 )]
897 fn sys_complete_awakeable(
898 &mut self,
899 id: String,
900 value: NonEmptyValue,
901 options: PayloadOptions,
902 ) -> VMResult<()> {
903 invocation_debug_logs!(self, "Executing 'Complete awakeable {id}'");
904 self.verify_error_metadata_feature_support(&value)?;
905 self.do_transition(SysNonCompletableEntry(
906 CompleteAwakeableCommandMessage {
907 awakeable_id: id,
908 result: Some(match value {
909 NonEmptyValue::Success(s) => {
910 complete_awakeable_command_message::Result::Value(s.into())
911 }
912 NonEmptyValue::Failure(f) => {
913 complete_awakeable_command_message::Result::Failure(f.into())
914 }
915 }),
916 ..Default::default()
917 },
918 options,
919 ))
920 }
921
922 #[instrument(
923 level = "trace",
924 skip(self),
925 fields(
926 restate.invocation.id = self.debug_invocation_id(),
927 restate.protocol.state = self.debug_state(),
928 restate.journal.command_index = self.context.journal.command_index(),
929 restate.protocol.version = %self.context.negotiated_protocol_version
930 ),
931 ret
932 )]
933 fn create_signal_handle(&mut self, signal_name: String) -> VMResult<NotificationHandle> {
934 invocation_debug_logs!(self, "Executing 'Create named signal'");
935
936 self.do_transition(CreateSignalHandle(
937 "named awakeable",
938 NotificationId::SignalName(signal_name),
939 ))
940 }
941
942 #[instrument(
943 level = "trace",
944 skip(self, value),
945 fields(
946 restate.invocation.id = self.debug_invocation_id(),
947 restate.protocol.state = self.debug_state(),
948 restate.journal.command_index = self.context.journal.command_index(),
949 restate.protocol.version = %self.context.negotiated_protocol_version
950 ),
951 ret
952 )]
953 fn sys_complete_signal(
954 &mut self,
955 target_invocation_id: String,
956 signal_name: String,
957 value: NonEmptyValue,
958 ) -> VMResult<()> {
959 invocation_debug_logs!(self, "Executing 'Complete named signal {signal_name}'");
960 self.verify_error_metadata_feature_support(&value)?;
961 self.do_transition(SysNonCompletableEntry(
962 SendSignalCommandMessage {
963 target_invocation_id,
964 signal_id: Some(send_signal_command_message::SignalId::Name(signal_name)),
965 result: Some(match value {
966 NonEmptyValue::Success(s) => {
967 send_signal_command_message::Result::Value(s.into())
968 }
969 NonEmptyValue::Failure(f) => {
970 send_signal_command_message::Result::Failure(f.into())
971 }
972 }),
973 ..Default::default()
974 },
975 PayloadOptions::default(),
976 ))
977 }
978
979 #[instrument(
980 level = "trace",
981 skip(self),
982 fields(
983 restate.invocation.id = self.debug_invocation_id(),
984 restate.protocol.state = self.debug_state(),
985 restate.journal.command_index = self.context.journal.command_index(),
986 restate.protocol.version = %self.context.negotiated_protocol_version
987 ),
988 ret
989 )]
990 fn sys_get_promise(&mut self, key: String) -> VMResult<NotificationHandle> {
991 invocation_debug_logs!(self, "Executing 'Await promise {key}'");
992
993 let result_completion_id = self.context.journal.next_completion_notification_id();
994 self.do_transition(SysSimpleCompletableEntry(
995 GetPromiseCommandMessage {
996 key,
997 result_completion_id,
998 ..Default::default()
999 },
1000 result_completion_id,
1001 PayloadOptions::default(),
1002 ))
1003 }
1004
1005 #[instrument(
1006 level = "trace",
1007 skip(self),
1008 fields(
1009 restate.invocation.id = self.debug_invocation_id(),
1010 restate.protocol.state = self.debug_state(),
1011 restate.journal.command_index = self.context.journal.command_index(),
1012 restate.protocol.version = %self.context.negotiated_protocol_version
1013 ),
1014 ret
1015 )]
1016 fn sys_peek_promise(&mut self, key: String) -> VMResult<NotificationHandle> {
1017 invocation_debug_logs!(self, "Executing 'Peek promise {key}'");
1018
1019 let result_completion_id = self.context.journal.next_completion_notification_id();
1020 self.do_transition(SysSimpleCompletableEntry(
1021 PeekPromiseCommandMessage {
1022 key,
1023 result_completion_id,
1024 ..Default::default()
1025 },
1026 result_completion_id,
1027 PayloadOptions::default(),
1028 ))
1029 }
1030
1031 #[instrument(
1032 level = "trace",
1033 skip(self, value),
1034 fields(
1035 restate.invocation.id = self.debug_invocation_id(),
1036 restate.protocol.state = self.debug_state(),
1037 restate.journal.command_index = self.context.journal.command_index(),
1038 restate.protocol.version = %self.context.negotiated_protocol_version
1039 ),
1040 ret
1041 )]
1042 fn sys_complete_promise(
1043 &mut self,
1044 key: String,
1045 value: NonEmptyValue,
1046 options: PayloadOptions,
1047 ) -> VMResult<NotificationHandle> {
1048 invocation_debug_logs!(self, "Executing 'Complete promise {key}'");
1049 self.verify_error_metadata_feature_support(&value)?;
1050
1051 let result_completion_id = self.context.journal.next_completion_notification_id();
1052 self.do_transition(SysSimpleCompletableEntry(
1053 CompletePromiseCommandMessage {
1054 key,
1055 completion: Some(match value {
1056 NonEmptyValue::Success(s) => {
1057 complete_promise_command_message::Completion::CompletionValue(s.into())
1058 }
1059 NonEmptyValue::Failure(f) => {
1060 complete_promise_command_message::Completion::CompletionFailure(f.into())
1061 }
1062 }),
1063 result_completion_id,
1064 ..Default::default()
1065 },
1066 result_completion_id,
1067 options,
1068 ))
1069 }
1070
1071 #[instrument(
1072 level = "trace",
1073 skip(self),
1074 fields(
1075 restate.invocation.id = self.debug_invocation_id(),
1076 restate.protocol.state = self.debug_state(),
1077 restate.journal.command_index = self.context.journal.command_index(),
1078 restate.protocol.version = %self.context.negotiated_protocol_version
1079 ),
1080 ret
1081 )]
1082 fn sys_run(&mut self, name: String) -> VMResult<NotificationHandle> {
1083 match self.do_transition(SysRun(name.clone())) {
1084 Ok(handle) => {
1085 if enabled!(Level::DEBUG) {
1086 self.sys_run_names.insert(handle, name);
1088 }
1089 Ok(handle)
1090 }
1091 Err(e) => Err(e),
1092 }
1093 }
1094
1095 #[instrument(
1096 level = "trace",
1097 skip(self, value, retry_policy),
1098 fields(
1099 restate.invocation.id = self.debug_invocation_id(),
1100 restate.protocol.state = self.debug_state(),
1101 restate.journal.command_index = self.context.journal.command_index(),
1102 restate.protocol.version = %self.context.negotiated_protocol_version
1103 ),
1104 ret
1105 )]
1106 fn propose_run_completion(
1107 &mut self,
1108 notification_handle: NotificationHandle,
1109 value: RunExitResult,
1110 retry_policy: RetryPolicy,
1111 ) -> VMResult<()> {
1112 if enabled!(Level::DEBUG) {
1113 let name: &str = self
1114 .sys_run_names
1115 .get(¬ification_handle)
1116 .map(String::as_str)
1117 .unwrap_or_default();
1118 match &value {
1119 RunExitResult::Success(_) => {
1120 invocation_debug_logs!(self, "Journaling run '{name}' success result");
1121 }
1122 RunExitResult::TerminalFailure(TerminalFailure { code, .. }) => {
1123 invocation_debug_logs!(
1124 self,
1125 "Journaling run '{name}' terminal failure {code} result"
1126 );
1127 }
1128 RunExitResult::RetryableFailure { .. } => {
1129 invocation_debug_logs!(self, "Propagating run '{name}' retryable failure");
1130 }
1131 }
1132 }
1133 if let RunExitResult::TerminalFailure(f) = &value {
1134 if !f.metadata.is_empty() {
1135 self.verify_feature_support("terminal error metadata", Version::V6)?;
1136 }
1137 }
1138
1139 self.do_transition(ProposeRunCompletion(
1140 notification_handle,
1141 value,
1142 retry_policy,
1143 ))
1144 }
1145
1146 #[instrument(
1147 level = "trace",
1148 skip(self),
1149 fields(
1150 restate.invocation.id = self.debug_invocation_id(),
1151 restate.protocol.state = self.debug_state(),
1152 restate.journal.command_index = self.context.journal.command_index(),
1153 restate.protocol.version = %self.context.negotiated_protocol_version
1154 ),
1155 ret
1156 )]
1157 fn sys_cancel_invocation(&mut self, target_invocation_id: String) -> VMResult<()> {
1158 invocation_debug_logs!(
1159 self,
1160 "Executing 'Cancel invocation' of {target_invocation_id}"
1161 );
1162 self.do_transition(SysNonCompletableEntry(
1163 SendSignalCommandMessage {
1164 target_invocation_id,
1165 signal_id: Some(send_signal_command_message::SignalId::Idx(CANCEL_SIGNAL_ID)),
1166 result: Some(send_signal_command_message::Result::Void(Default::default())),
1167 ..Default::default()
1168 },
1169 PayloadOptions::default(),
1170 ))
1171 }
1172
1173 #[instrument(
1174 level = "trace",
1175 skip(self),
1176 fields(
1177 restate.invocation.id = self.debug_invocation_id(),
1178 restate.protocol.state = self.debug_state(),
1179 restate.journal.command_index = self.context.journal.command_index(),
1180 restate.protocol.version = %self.context.negotiated_protocol_version
1181 ),
1182 ret
1183 )]
1184 fn sys_attach_invocation(
1185 &mut self,
1186 target: AttachInvocationTarget,
1187 ) -> VMResult<NotificationHandle> {
1188 invocation_debug_logs!(self, "Executing 'Attach invocation'");
1189
1190 let result_completion_id = self.context.journal.next_completion_notification_id();
1191 self.do_transition(SysSimpleCompletableEntry(
1192 AttachInvocationCommandMessage {
1193 target: Some(match target {
1194 AttachInvocationTarget::InvocationId(id) => {
1195 attach_invocation_command_message::Target::InvocationId(id)
1196 }
1197 AttachInvocationTarget::WorkflowId { name, key } => {
1198 attach_invocation_command_message::Target::WorkflowTarget(WorkflowTarget {
1199 workflow_name: name,
1200 workflow_key: key,
1201 })
1202 }
1203 AttachInvocationTarget::IdempotencyId {
1204 service_name,
1205 service_key,
1206 handler_name,
1207 idempotency_key,
1208 } => attach_invocation_command_message::Target::IdempotentRequestTarget(
1209 IdempotentRequestTarget {
1210 service_name,
1211 service_key,
1212 handler_name,
1213 idempotency_key,
1214 },
1215 ),
1216 }),
1217 result_completion_id,
1218 ..Default::default()
1219 },
1220 result_completion_id,
1221 PayloadOptions::default(),
1222 ))
1223 }
1224
1225 #[instrument(
1226 level = "trace",
1227 skip(self),
1228 fields(
1229 restate.invocation.id = self.debug_invocation_id(),
1230 restate.protocol.state = self.debug_state(),
1231 restate.journal.command_index = self.context.journal.command_index(),
1232 restate.protocol.version = %self.context.negotiated_protocol_version
1233 ),
1234 ret
1235 )]
1236 fn sys_get_invocation_output(
1237 &mut self,
1238 target: AttachInvocationTarget,
1239 ) -> VMResult<NotificationHandle> {
1240 invocation_debug_logs!(self, "Executing 'Get invocation output'");
1241
1242 let result_completion_id = self.context.journal.next_completion_notification_id();
1243 self.do_transition(SysSimpleCompletableEntry(
1244 GetInvocationOutputCommandMessage {
1245 target: Some(match target {
1246 AttachInvocationTarget::InvocationId(id) => {
1247 get_invocation_output_command_message::Target::InvocationId(id)
1248 }
1249 AttachInvocationTarget::WorkflowId { name, key } => {
1250 get_invocation_output_command_message::Target::WorkflowTarget(
1251 WorkflowTarget {
1252 workflow_name: name,
1253 workflow_key: key,
1254 },
1255 )
1256 }
1257 AttachInvocationTarget::IdempotencyId {
1258 service_name,
1259 service_key,
1260 handler_name,
1261 idempotency_key,
1262 } => get_invocation_output_command_message::Target::IdempotentRequestTarget(
1263 IdempotentRequestTarget {
1264 service_name,
1265 service_key,
1266 handler_name,
1267 idempotency_key,
1268 },
1269 ),
1270 }),
1271 result_completion_id,
1272 ..Default::default()
1273 },
1274 result_completion_id,
1275 PayloadOptions::default(),
1276 ))
1277 }
1278
1279 #[instrument(
1280 level = "trace",
1281 skip(self, value),
1282 fields(
1283 restate.invocation.id = self.debug_invocation_id(),
1284 restate.protocol.state = self.debug_state(),
1285 restate.journal.command_index = self.context.journal.command_index(),
1286 restate.protocol.version = %self.context.negotiated_protocol_version
1287 ),
1288 ret
1289 )]
1290 fn sys_write_output(&mut self, value: NonEmptyValue, options: PayloadOptions) -> VMResult<()> {
1291 match &value {
1292 NonEmptyValue::Success(_) => {
1293 invocation_debug_logs!(self, "Writing invocation result success value");
1294 }
1295 NonEmptyValue::Failure(_) => {
1296 invocation_debug_logs!(self, "Writing invocation result failure value");
1297 }
1298 }
1299 self.verify_error_metadata_feature_support(&value)?;
1300 self.do_transition(SysNonCompletableEntry(
1301 OutputCommandMessage {
1302 result: Some(match value {
1303 NonEmptyValue::Success(b) => output_command_message::Result::Value(b.into()),
1304 NonEmptyValue::Failure(f) => output_command_message::Result::Failure(f.into()),
1305 }),
1306 ..OutputCommandMessage::default()
1307 },
1308 options,
1309 ))
1310 }
1311
1312 #[instrument(
1313 level = "trace",
1314 skip(self),
1315 fields(
1316 restate.invocation.id = self.debug_invocation_id(),
1317 restate.protocol.state = self.debug_state(),
1318 restate.journal.command_index = self.context.journal.command_index(),
1319 restate.protocol.version = %self.context.negotiated_protocol_version
1320 ),
1321 ret
1322 )]
1323 fn sys_end(&mut self) -> Result<(), Error> {
1324 invocation_debug_logs!(self, "End of the invocation");
1325 self.do_transition(SysEnd)
1326 }
1327
1328 fn is_waiting_preflight(&self) -> bool {
1329 matches!(
1330 &self.last_transition,
1331 Ok(State::WaitingStart) | Ok(State::WaitingReplayEntries { .. })
1332 )
1333 }
1334
1335 fn is_replaying(&self) -> bool {
1336 matches!(&self.last_transition, Ok(State::Replaying { .. }))
1337 }
1338
1339 fn is_processing(&self) -> bool {
1340 matches!(&self.last_transition, Ok(State::Processing { .. }))
1341 }
1342
1343 fn last_command_index(&self) -> i64 {
1344 self.context.journal.command_index()
1345 }
1346}
1347
1348const INDIFFERENT_PAD: GeneralPurposeConfig = GeneralPurposeConfig::new()
1349 .with_decode_padding_mode(DecodePaddingMode::Indifferent)
1350 .with_encode_padding(false);
1351const URL_SAFE: GeneralPurpose = GeneralPurpose::new(&alphabet::URL_SAFE, INDIFFERENT_PAD);
1352
1353const AWAKEABLE_PREFIX: &str = "sign_1";
1354
1355fn awakeable_id_str(id: &[u8], completion_index: u32) -> String {
1356 let mut input_buf = BytesMut::with_capacity(id.len() + size_of::<u32>());
1357 input_buf.put_slice(id);
1358 input_buf.put_u32(completion_index);
1359 format!("{AWAKEABLE_PREFIX}{}", URL_SAFE.encode(input_buf.freeze()))
1360}