1use crate::headers::HeaderMap;
2use crate::service_protocol::messages::{
3 attach_invocation_command_message, complete_awakeable_command_message,
4 complete_promise_command_message, get_invocation_output_command_message,
5 output_command_message, send_signal_command_message, AttachInvocationCommandMessage,
6 CallCommandMessage, ClearAllStateCommandMessage, ClearStateCommandMessage,
7 CompleteAwakeableCommandMessage, CompletePromiseCommandMessage,
8 GetInvocationOutputCommandMessage, GetPromiseCommandMessage, IdempotentRequestTarget,
9 OneWayCallCommandMessage, OutputCommandMessage, PeekPromiseCommandMessage,
10 SendSignalCommandMessage, SetStateCommandMessage, SleepCommandMessage, WorkflowTarget,
11};
12use crate::service_protocol::{Decoder, NotificationId, RawMessage, Version, CANCEL_SIGNAL_ID};
13use crate::vm::errors::{
14 ClosedError, UnexpectedStateError, UnsupportedFeatureForNegotiatedVersion,
15 EMPTY_IDEMPOTENCY_KEY, SUSPENDED,
16};
17use crate::vm::transitions::*;
18use crate::{
19 AttachInvocationTarget, CallHandle, CommandRelationship, DoProgressResponse, Error, Header,
20 ImplicitCancellationOption, Input, NonDeterministicChecksOption, NonEmptyValue,
21 NotificationHandle, PayloadOptions, ResponseHead, RetryPolicy, RunExitResult, SendHandle,
22 TakeOutputResult, Target, TerminalFailure, VMOptions, VMResult, Value,
23 CANCEL_NOTIFICATION_HANDLE,
24};
25use base64::engine::{DecodePaddingMode, GeneralPurpose, GeneralPurposeConfig};
26use base64::{alphabet, Engine};
27use bytes::{Buf, BufMut, Bytes, BytesMut};
28use context::{AsyncResultsState, Context, Output, RunState};
29use std::borrow::Cow;
30use std::collections::{HashMap, VecDeque};
31use std::mem::size_of;
32use std::time::Duration;
33use std::{fmt, mem};
34use strum::IntoStaticStr;
35use tracing::{debug, enabled, instrument, Level};
36
mod context;
pub(crate) mod errors;
mod transitions;

/// Name of the header carrying the service protocol content type.
///
/// `CoreVM::new` reads it from the request headers to negotiate the protocol version, and
/// `get_response_head` echoes the negotiated content type back under the same key.
const CONTENT_TYPE: &str = "content-type";
42
/// Lifecycle state of the protocol state machine.
///
/// The variant name (obtained through `IntoStaticStr`) is used in diagnostics, e.g. the
/// `restate.protocol.state` tracing field and unexpected-state errors.
#[derive(Debug, IntoStaticStr)]
pub(crate) enum State {
    /// Initial state set by `CoreVM::new`; the VM is not yet ready to execute.
    WaitingStart,
    /// Still collecting journal entries; `is_ready_to_execute` reports `false` here.
    WaitingReplayEntries {
        // presumably the number of journal entries received so far — TODO confirm
        received_entries: u32,
        // commands collected so far, presumably to be replayed in order
        commands: VecDeque<RawMessage>,
        async_results: AsyncResultsState,
    },
    /// Ready to execute while journal commands are (presumably) still being replayed.
    Replaying {
        commands: VecDeque<RawMessage>,
        run_state: RunState,
        async_results: AsyncResultsState,
    },
    /// Ready to execute; presumably past the replayed prefix of the journal.
    Processing {
        // NOTE(review): looks like a marker for the first newly-written entry — confirm
        processing_first_entry: bool,
        run_state: RunState,
        async_results: AsyncResultsState,
    },
    /// Terminal state; events arriving here are reported as `ClosedError`.
    Closed,
}
63
64impl State {
65 fn as_unexpected_state(&self, event: impl ToString) -> Error {
66 if matches!(self, State::Closed) {
67 return ClosedError::new(event.to_string()).into();
68 }
69 UnexpectedStateError::new(self.into(), event.to_string()).into()
70 }
71}
72
73struct TrackedInvocationId {
74 handle: NotificationHandle,
75 invocation_id: Option<String>,
76}
77
78impl TrackedInvocationId {
79 fn is_resolved(&self) -> bool {
80 self.invocation_id.is_some()
81 }
82}
83
/// Protocol state machine driving a single invocation.
///
/// Input bytes go through `decoder` and are turned into state transitions; produced output
/// is buffered in `context.output` and drained with `take_output`.
pub struct CoreVM {
    /// Options the VM was created with (implicit cancellation, non-determinism checks, ...).
    options: VMOptions,

    /// Decoder for incoming protocol messages, built for the negotiated version.
    decoder: Decoder,

    /// Journal, output buffer, eager state, start info and negotiated protocol version.
    context: Context,
    /// Outcome of the most recent state transition; `Err` once a transition has failed.
    last_transition: Result<State, Error>,

    /// Child calls/sends to cancel when implicit cancellation fires.
    // NOTE(review): `take_notification` binary-searches this by handle, so it assumes
    // entries are pushed in ascending handle order — confirm.
    tracked_invocation_ids: Vec<TrackedInvocationId>,

    /// `run` names by handle, recorded only when DEBUG logging is enabled so
    /// `propose_run_completion` can mention them in log messages.
    sys_run_names: HashMap<NotificationHandle, String>,
}
100
101impl CoreVM {
102 fn debug_invocation_id(&self) -> &str {
104 if let Some(start_info) = self.context.start_info() {
105 &start_info.debug_id
106 } else {
107 ""
108 }
109 }
110
111 fn debug_state(&self) -> &'static str {
112 match &self.last_transition {
113 Ok(s) => s.into(),
114 Err(_) => "Failed",
115 }
116 }
117
118 fn verify_error_metadata_feature_support(&mut self, value: &NonEmptyValue) -> VMResult<()> {
119 if let NonEmptyValue::Failure(f) = value {
120 if !f.metadata.is_empty() {
121 self.verify_feature_support("terminal error metadata", Version::V6)?;
122 }
123 }
124 Ok(())
125 }
126
127 #[allow(dead_code)]
128 fn verify_feature_support(
129 &mut self,
130 feature: &'static str,
131 minimum_required_protocol: Version,
132 ) -> VMResult<()> {
133 if self.context.negotiated_protocol_version < minimum_required_protocol {
134 return self.do_transition(HitError(
135 UnsupportedFeatureForNegotiatedVersion::new(
136 feature,
137 self.context.negotiated_protocol_version,
138 minimum_required_protocol,
139 )
140 .into(),
141 ));
142 }
143 Ok(())
144 }
145
146 fn _is_completed(&self, handle: NotificationHandle) -> bool {
147 match &self.last_transition {
148 Ok(State::Replaying { async_results, .. })
149 | Ok(State::Processing { async_results, .. }) => {
150 async_results.is_handle_completed(handle)
151 }
152 _ => false,
153 }
154 }
155
156 fn _do_progress(
157 &mut self,
158 any_handle: Vec<NotificationHandle>,
159 ) -> Result<DoProgressResponse, Error> {
160 match self.do_transition(DoProgress(any_handle)) {
161 Ok(Ok(do_progress_response)) => Ok(do_progress_response),
162 Ok(Err(_)) => Err(SUSPENDED),
163 Err(e) => Err(e),
164 }
165 }
166
167 fn is_implicit_cancellation_enabled(&self) -> bool {
168 matches!(
169 self.options.implicit_cancellation,
170 ImplicitCancellationOption::Enabled { .. }
171 )
172 }
173}
174
175impl fmt::Debug for CoreVM {
176 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
177 let mut s = f.debug_struct("CoreVM");
178 s.field("version", &self.context.negotiated_protocol_version);
179
180 if let Some(start_info) = self.context.start_info() {
181 s.field("invocation_id", &start_info.debug_id);
182 }
183
184 match &self.last_transition {
185 Ok(state) => s.field("last_transition", &<&'static str>::from(state)),
186 Err(_) => s.field("last_transition", &"Errored"),
187 };
188
189 s.field("command_index", &self.context.journal.command_index())
190 .field(
191 "notification_index",
192 &self.context.journal.notification_index(),
193 )
194 .finish()
195 }
196}
197
// Compile-time assertion that `CoreVM` is `Send`: the `const _` below only compiles
// if every field of `CoreVM` is `Send`.
// NOTE(review): `#[allow(unused)]` presumably silences a dead-code lint on the helper.
#[allow(unused)]
const fn is_send<T: Send>() {}
const _: () = is_send::<CoreVM>();
202
/// Emits a `tracing::debug!` event only while `$this.is_processing()` is true.
///
/// The remaining arguments are forwarded verbatim to `tracing::debug!`. Presumably this
/// keeps commands that are being replayed from being logged again.
macro_rules! invocation_debug_logs {
    ($this:expr, $($arg:tt)*) => {
        if $this.is_processing() {
            tracing::debug!($($arg)*)
        }
    };
}
211
212impl super::VM for CoreVM {
213 #[instrument(level = "trace", skip(request_headers), ret)]
214 fn new(request_headers: impl HeaderMap, options: VMOptions) -> Result<Self, Error> {
215 let version = request_headers
216 .extract(CONTENT_TYPE)
217 .map_err(|e| {
218 Error::new(
219 errors::codes::BAD_REQUEST,
220 format!("cannot read '{CONTENT_TYPE}' header: {e:?}"),
221 )
222 })?
223 .ok_or(errors::MISSING_CONTENT_TYPE)?
224 .parse::<Version>()?;
225
226 if version < Version::minimum_supported_version()
227 || version > Version::maximum_supported_version()
228 {
229 return Err(Error::new(
230 errors::codes::UNSUPPORTED_MEDIA_TYPE,
231 format!(
232 "Unsupported protocol version {:?}, not within [{:?} to {:?}]. \
233 You might need to rediscover the service, check https://docs.restate.dev/references/errors/#RT0015",
234 version,
235 Version::minimum_supported_version(),
236 Version::maximum_supported_version()
237 ),
238 ));
239 }
240 let non_deterministic_checks_ignore_payload_equality = matches!(
241 options.non_determinism_checks,
242 NonDeterministicChecksOption::PayloadChecksDisabled
243 );
244
245 Ok(Self {
246 options,
247 decoder: Decoder::new(version),
248 context: Context {
249 input_is_closed: false,
250 output: Output::new(version),
251 start_info: None,
252 journal: Default::default(),
253 eager_state: Default::default(),
254 non_deterministic_checks_ignore_payload_equality,
255 negotiated_protocol_version: version,
256 },
257 last_transition: Ok(State::WaitingStart),
258 tracked_invocation_ids: vec![],
259 sys_run_names: HashMap::with_capacity(0),
260 })
261 }
262
263 #[instrument(
264 level = "trace",
265 skip(self),
266 fields(
267 restate.invocation.id = self.debug_invocation_id(),
268 restate.protocol.state = self.debug_state(),
269 restate.journal.command_index = self.context.journal.command_index(),
270 restate.protocol.version = %self.context.negotiated_protocol_version
271 ),
272 ret
273 )]
274 fn get_response_head(&self) -> ResponseHead {
275 ResponseHead {
276 status_code: 200,
277 headers: vec![Header {
278 key: Cow::Borrowed(CONTENT_TYPE),
279 value: Cow::Borrowed(self.context.negotiated_protocol_version.content_type()),
280 }],
281 version: self.context.negotiated_protocol_version,
282 }
283 }
284
285 #[instrument(
286 level = "trace",
287 skip(self),
288 fields(
289 restate.invocation.id = self.debug_invocation_id(),
290 restate.protocol.state = self.debug_state(),
291 restate.journal.command_index = self.context.journal.command_index(),
292 restate.protocol.version = %self.context.negotiated_protocol_version
293 ),
294 ret
295 )]
296 fn notify_input(&mut self, buffer: Bytes) {
297 self.decoder.push(buffer);
298 loop {
299 match self.decoder.consume_next() {
300 Ok(Some(msg)) => {
301 if self.do_transition(NewMessage(msg)).is_err() {
302 return;
303 }
304 }
305 Ok(None) => {
306 return;
307 }
308 Err(e) => {
309 if self.do_transition(HitError(e.into())).is_err() {
310 return;
311 }
312 }
313 }
314 }
315 }
316
317 #[instrument(
318 level = "trace",
319 skip(self),
320 fields(
321 restate.invocation.id = self.debug_invocation_id(),
322 restate.protocol.state = self.debug_state(),
323 restate.journal.command_index = self.context.journal.command_index(),
324 restate.protocol.version = %self.context.negotiated_protocol_version
325 ),
326 ret
327 )]
328 fn notify_input_closed(&mut self) {
329 self.context.input_is_closed = true;
330 let _ = self.do_transition(NotifyInputClosed);
331 }
332
333 #[instrument(
334 level = "trace",
335 skip(self),
336 fields(
337 restate.invocation.id = self.debug_invocation_id(),
338 restate.protocol.state = self.debug_state(),
339 restate.journal.command_index = self.context.journal.command_index(),
340 restate.protocol.version = %self.context.negotiated_protocol_version
341 ),
342 ret
343 )]
344 fn notify_error(
345 &mut self,
346 mut error: Error,
347 command_relationship: Option<CommandRelationship>,
348 ) {
349 if let Some(command_relationship) = command_relationship {
350 error = error.with_related_command_metadata(
351 self.context
352 .journal
353 .resolve_related_command(command_relationship),
354 );
355 }
356
357 let _ = self.do_transition(HitError(error));
358 }
359
360 #[instrument(
361 level = "trace",
362 skip(self),
363 fields(
364 restate.invocation.id = self.debug_invocation_id(),
365 restate.protocol.state = self.debug_state(),
366 restate.journal.command_index = self.context.journal.command_index(),
367 restate.protocol.version = %self.context.negotiated_protocol_version
368 ),
369 ret
370 )]
371 fn take_output(&mut self) -> TakeOutputResult {
372 if self.context.output.buffer.has_remaining() {
373 TakeOutputResult::Buffer(
374 self.context
375 .output
376 .buffer
377 .copy_to_bytes(self.context.output.buffer.remaining()),
378 )
379 } else if !self.context.output.is_closed() {
380 TakeOutputResult::Buffer(Bytes::default())
381 } else {
382 TakeOutputResult::EOF
383 }
384 }
385
386 #[instrument(
387 level = "trace",
388 skip(self),
389 fields(
390 restate.invocation.id = self.debug_invocation_id(),
391 restate.protocol.state = self.debug_state(),
392 restate.journal.command_index = self.context.journal.command_index(),
393 restate.protocol.version = %self.context.negotiated_protocol_version
394 ),
395 ret
396 )]
397 fn is_ready_to_execute(&self) -> Result<bool, Error> {
398 match &self.last_transition {
399 Ok(State::WaitingStart) | Ok(State::WaitingReplayEntries { .. }) => Ok(false),
400 Ok(State::Processing { .. }) | Ok(State::Replaying { .. }) => Ok(true),
401 Ok(s) => Err(s.as_unexpected_state("IsReadyToExecute")),
402 Err(e) => Err(e.clone()),
403 }
404 }
405
406 #[instrument(
407 level = "trace",
408 skip(self),
409 fields(
410 restate.invocation.id = self.debug_invocation_id(),
411 restate.protocol.state = self.debug_state(),
412 restate.journal.command_index = self.context.journal.command_index(),
413 restate.protocol.version = %self.context.negotiated_protocol_version
414 ),
415 ret
416 )]
417 fn is_completed(&self, handle: NotificationHandle) -> bool {
418 self._is_completed(handle)
419 }
420
421 #[instrument(
422 level = "trace",
423 skip(self),
424 fields(
425 restate.invocation.id = self.debug_invocation_id(),
426 restate.protocol.state = self.debug_state(),
427 restate.journal.command_index = self.context.journal.command_index(),
428 restate.protocol.version = %self.context.negotiated_protocol_version
429 ),
430 ret
431 )]
432 fn do_progress(
433 &mut self,
434 mut any_handle: Vec<NotificationHandle>,
435 ) -> VMResult<DoProgressResponse> {
436 if self.is_implicit_cancellation_enabled() {
437 any_handle.insert(0, CANCEL_NOTIFICATION_HANDLE);
439
440 match self._do_progress(any_handle) {
441 Ok(DoProgressResponse::AnyCompleted) => {
442 if self._is_completed(CANCEL_NOTIFICATION_HANDLE) {
444 for i in 0..self.tracked_invocation_ids.len() {
446 if self.tracked_invocation_ids[i].is_resolved() {
447 continue;
448 }
449
450 let handle = self.tracked_invocation_ids[i].handle;
451
452 match self._do_progress(vec![handle]) {
454 Ok(DoProgressResponse::AnyCompleted) => {
455 let invocation_id = match self.do_transition(CopyNotification(handle)) {
456 Ok(Ok(Some(Value::InvocationId(invocation_id)))) => Ok(invocation_id),
457 Ok(Err(_)) => Err(SUSPENDED),
458 _ => panic!("Unexpected variant! If the id handle is completed, it must be an invocation id handle!")
459 }?;
460
461 self.tracked_invocation_ids[i].invocation_id =
463 Some(invocation_id);
464 }
465 res => return res,
466 }
467 }
468
469 for tracked_invocation_id in mem::take(&mut self.tracked_invocation_ids) {
471 self.sys_cancel_invocation(
472 tracked_invocation_id
473 .invocation_id
474 .expect("We resolved before all the invocation ids"),
475 )?;
476 }
477
478 let _ = self.take_notification(CANCEL_NOTIFICATION_HANDLE);
480
481 Ok(DoProgressResponse::CancelSignalReceived)
483 } else {
484 Ok(DoProgressResponse::AnyCompleted)
485 }
486 }
487 res => res,
488 }
489 } else {
490 self._do_progress(any_handle)
491 }
492 }
493
494 #[instrument(
495 level = "trace",
496 skip(self),
497 fields(
498 restate.invocation.id = self.debug_invocation_id(),
499 restate.protocol.state = self.debug_state(),
500 restate.journal.command_index = self.context.journal.command_index(),
501 restate.protocol.version = %self.context.negotiated_protocol_version
502 ),
503 ret
504 )]
505 fn take_notification(&mut self, handle: NotificationHandle) -> VMResult<Option<Value>> {
506 match self.do_transition(TakeNotification(handle)) {
507 Ok(Ok(Some(value))) => {
508 if self.is_implicit_cancellation_enabled() {
509 if let Ok(found) = self
512 .tracked_invocation_ids
513 .binary_search_by(|tracked| tracked.handle.cmp(&handle))
514 {
515 let Value::InvocationId(invocation_id) = &value else {
516 panic!("Expecting an invocation id here, but got {value:?}");
517 };
518 self.tracked_invocation_ids
520 .get_mut(found)
521 .unwrap()
522 .invocation_id = Some(invocation_id.clone());
523 }
524 }
525
526 Ok(Some(value))
527 }
528 Ok(Ok(None)) => Ok(None),
529 Ok(Err(_)) => Err(SUSPENDED),
530 Err(e) => Err(e),
531 }
532 }
533
534 #[instrument(
535 level = "trace",
536 skip(self),
537 fields(
538 restate.invocation.id = self.debug_invocation_id(),
539 restate.protocol.state = self.debug_state(),
540 restate.journal.command_index = self.context.journal.command_index(),
541 restate.protocol.version = %self.context.negotiated_protocol_version
542 ),
543 ret
544 )]
545 fn sys_input(&mut self) -> Result<Input, Error> {
546 self.do_transition(SysInput)
547 }
548
549 #[instrument(
550 level = "trace",
551 skip(self),
552 fields(
553 restate.invocation.id = self.debug_invocation_id(),
554 restate.protocol.state = self.debug_state(),
555 restate.journal.command_index = self.context.journal.command_index(),
556 restate.protocol.version = %self.context.negotiated_protocol_version
557 ),
558 ret
559 )]
560 fn sys_state_get(
561 &mut self,
562 key: String,
563 options: PayloadOptions,
564 ) -> Result<NotificationHandle, Error> {
565 invocation_debug_logs!(self, "Executing 'Get state {key}'");
566 self.do_transition(SysStateGet(key, options))
567 }
568
569 #[instrument(
570 level = "trace",
571 skip(self),
572 fields(
573 restate.invocation.id = self.debug_invocation_id(),
574 restate.protocol.state = self.debug_state(),
575 restate.journal.command_index = self.context.journal.command_index(),
576 restate.protocol.version = %self.context.negotiated_protocol_version
577 ),
578 ret
579 )]
580 fn sys_state_get_keys(&mut self) -> VMResult<NotificationHandle> {
581 invocation_debug_logs!(self, "Executing 'Get state keys'");
582 self.do_transition(SysStateGetKeys)
583 }
584
585 #[instrument(
586 level = "trace",
587 skip(self, value),
588 fields(
589 restate.invocation.id = self.debug_invocation_id(),
590 restate.protocol.state = self.debug_state(),
591 restate.journal.command_index = self.context.journal.command_index(),
592 restate.protocol.version = %self.context.negotiated_protocol_version
593 ),
594 ret
595 )]
596 fn sys_state_set(
597 &mut self,
598 key: String,
599 value: Bytes,
600 options: PayloadOptions,
601 ) -> VMResult<()> {
602 invocation_debug_logs!(self, "Executing 'Set state {key}'");
603 self.context.eager_state.set(key.clone(), value.clone());
604 self.do_transition(SysNonCompletableEntry(
605 SetStateCommandMessage {
606 key: Bytes::from(key.into_bytes()),
607 value: Some(value.into()),
608 ..SetStateCommandMessage::default()
609 },
610 options,
611 ))
612 }
613
614 #[instrument(
615 level = "trace",
616 skip(self),
617 fields(
618 restate.invocation.id = self.debug_invocation_id(),
619 restate.protocol.state = self.debug_state(),
620 restate.journal.command_index = self.context.journal.command_index(),
621 restate.protocol.version = %self.context.negotiated_protocol_version
622 ),
623 ret
624 )]
625 fn sys_state_clear(&mut self, key: String) -> Result<(), Error> {
626 invocation_debug_logs!(self, "Executing 'Clear state {key}'");
627 self.context.eager_state.clear(key.clone());
628 self.do_transition(SysNonCompletableEntry(
629 ClearStateCommandMessage {
630 key: Bytes::from(key.into_bytes()),
631 ..ClearStateCommandMessage::default()
632 },
633 PayloadOptions::default(),
634 ))
635 }
636
637 #[instrument(
638 level = "trace",
639 skip(self),
640 fields(
641 restate.invocation.id = self.debug_invocation_id(),
642 restate.protocol.state = self.debug_state(),
643 restate.journal.command_index = self.context.journal.command_index(),
644 restate.protocol.version = %self.context.negotiated_protocol_version
645 ),
646 ret
647 )]
648 fn sys_state_clear_all(&mut self) -> Result<(), Error> {
649 invocation_debug_logs!(self, "Executing 'Clear all state'");
650 self.context.eager_state.clear_all();
651 self.do_transition(SysNonCompletableEntry(
652 ClearAllStateCommandMessage::default(),
653 PayloadOptions::default(),
654 ))
655 }
656
657 #[instrument(
658 level = "trace",
659 skip(self),
660 fields(
661 restate.invocation.id = self.debug_invocation_id(),
662 restate.protocol.state = self.debug_state(),
663 restate.journal.command_index = self.context.journal.command_index(),
664 restate.protocol.version = %self.context.negotiated_protocol_version
665 ),
666 ret
667 )]
668 fn sys_sleep(
669 &mut self,
670 name: String,
671 wake_up_time_since_unix_epoch: Duration,
672 now_since_unix_epoch: Option<Duration>,
673 ) -> VMResult<NotificationHandle> {
674 if self.is_processing() {
675 match (&name, now_since_unix_epoch) {
676 (name, Some(now_since_unix_epoch)) if name.is_empty() => {
677 debug!(
678 "Executing 'Timer with duration {:?}'",
679 wake_up_time_since_unix_epoch - now_since_unix_epoch
680 );
681 }
682 (name, Some(now_since_unix_epoch)) => {
683 debug!(
684 "Executing 'Timer {name} with duration {:?}'",
685 wake_up_time_since_unix_epoch - now_since_unix_epoch
686 );
687 }
688 (name, None) if name.is_empty() => {
689 debug!("Executing 'Timer'");
690 }
691 (name, None) => {
692 debug!("Executing 'Timer named {name}'");
693 }
694 }
695 }
696
697 let completion_id = self.context.journal.next_completion_notification_id();
698
699 self.do_transition(SysSimpleCompletableEntry(
700 SleepCommandMessage {
701 wake_up_time: u64::try_from(wake_up_time_since_unix_epoch.as_millis())
702 .expect("millis since Unix epoch should fit in u64"),
703 result_completion_id: completion_id,
704 name,
705 },
706 completion_id,
707 PayloadOptions::default(),
708 ))
709 }
710
711 #[instrument(
712 level = "trace",
713 skip(self, input),
714 fields(
715 restate.invocation.id = self.debug_invocation_id(),
716 restate.protocol.state = self.debug_state(),
717 restate.journal.command_index = self.context.journal.command_index(),
718 restate.protocol.version = %self.context.negotiated_protocol_version
719 ),
720 ret
721 )]
722 fn sys_call(
723 &mut self,
724 target: Target,
725 input: Bytes,
726 name: Option<String>,
727 options: PayloadOptions,
728 ) -> VMResult<CallHandle> {
729 invocation_debug_logs!(
730 self,
731 "Executing 'Call {}/{}'",
732 target.service,
733 target.handler
734 );
735 if let Some(idempotency_key) = &target.idempotency_key {
736 if idempotency_key.is_empty() {
737 self.do_transition(HitError(EMPTY_IDEMPOTENCY_KEY))?;
738 unreachable!();
739 }
740 }
741
742 let call_invocation_id_completion_id =
743 self.context.journal.next_completion_notification_id();
744 let result_completion_id = self.context.journal.next_completion_notification_id();
745
746 let handles = self.do_transition(SysCompletableEntryWithMultipleCompletions(
747 CallCommandMessage {
748 service_name: target.service,
749 handler_name: target.handler,
750 key: target.key.unwrap_or_default(),
751 idempotency_key: target.idempotency_key,
752 headers: target
753 .headers
754 .into_iter()
755 .map(crate::service_protocol::messages::Header::from)
756 .collect(),
757 parameter: input,
758 invocation_id_notification_idx: call_invocation_id_completion_id,
759 name: name.unwrap_or_default(),
760 result_completion_id,
761 },
762 vec![call_invocation_id_completion_id, result_completion_id],
763 options,
764 ))?;
765
766 if matches!(
767 self.options.implicit_cancellation,
768 ImplicitCancellationOption::Enabled {
769 cancel_children_calls: true,
770 ..
771 }
772 ) {
773 self.tracked_invocation_ids.push(TrackedInvocationId {
774 handle: handles[0],
775 invocation_id: None,
776 })
777 }
778
779 Ok(CallHandle {
780 invocation_id_notification_handle: handles[0],
781 call_notification_handle: handles[1],
782 })
783 }
784
785 #[instrument(
786 level = "trace",
787 skip(self, input),
788 fields(
789 restate.invocation.id = self.debug_invocation_id(),
790 restate.protocol.state = self.debug_state(),
791 restate.journal.command_index = self.context.journal.command_index(),
792 restate.protocol.version = %self.context.negotiated_protocol_version
793 ),
794 ret
795 )]
796 fn sys_send(
797 &mut self,
798 target: Target,
799 input: Bytes,
800 delay: Option<Duration>,
801 name: Option<String>,
802 options: PayloadOptions,
803 ) -> VMResult<SendHandle> {
804 invocation_debug_logs!(
805 self,
806 "Executing 'Send to {}/{}'",
807 target.service,
808 target.handler
809 );
810 if let Some(idempotency_key) = &target.idempotency_key {
811 if idempotency_key.is_empty() {
812 self.do_transition(HitError(EMPTY_IDEMPOTENCY_KEY))?;
813 unreachable!();
814 }
815 }
816 let call_invocation_id_completion_id =
817 self.context.journal.next_completion_notification_id();
818 let invocation_id_notification_handle = self.do_transition(SysSimpleCompletableEntry(
819 OneWayCallCommandMessage {
820 service_name: target.service,
821 handler_name: target.handler,
822 key: target.key.unwrap_or_default(),
823 idempotency_key: target.idempotency_key,
824 headers: target
825 .headers
826 .into_iter()
827 .map(crate::service_protocol::messages::Header::from)
828 .collect(),
829 parameter: input,
830 invoke_time: delay
831 .map(|d| {
832 u64::try_from(d.as_millis())
833 .expect("millis since Unix epoch should fit in u64")
834 })
835 .unwrap_or_default(),
836 invocation_id_notification_idx: call_invocation_id_completion_id,
837 name: name.unwrap_or_default(),
838 },
839 call_invocation_id_completion_id,
840 options,
841 ))?;
842
843 if matches!(
844 self.options.implicit_cancellation,
845 ImplicitCancellationOption::Enabled {
846 cancel_children_one_way_calls: true,
847 ..
848 }
849 ) {
850 self.tracked_invocation_ids.push(TrackedInvocationId {
851 handle: invocation_id_notification_handle,
852 invocation_id: None,
853 })
854 }
855
856 Ok(SendHandle {
857 invocation_id_notification_handle,
858 })
859 }
860
861 #[instrument(
862 level = "trace",
863 skip(self),
864 fields(
865 restate.invocation.id = self.debug_invocation_id(),
866 restate.protocol.state = self.debug_state(),
867 restate.journal.command_index = self.context.journal.command_index(),
868 restate.protocol.version = %self.context.negotiated_protocol_version
869 ),
870 ret
871 )]
872 fn sys_awakeable(&mut self) -> VMResult<(String, NotificationHandle)> {
873 invocation_debug_logs!(self, "Executing 'Create awakeable'");
874
875 let signal_id = self.context.journal.next_signal_notification_id();
876
877 let handle = self.do_transition(CreateSignalHandle(
878 "awakeable",
879 NotificationId::SignalId(signal_id),
880 ))?;
881
882 Ok((
883 awakeable_id_str(&self.context.expect_start_info().id, signal_id),
884 handle,
885 ))
886 }
887
888 #[instrument(
889 level = "trace",
890 skip(self, value),
891 fields(
892 restate.invocation.id = self.debug_invocation_id(),
893 restate.protocol.state = self.debug_state(),
894 restate.journal.command_index = self.context.journal.command_index(),
895 restate.protocol.version = %self.context.negotiated_protocol_version
896 ),
897 ret
898 )]
899 fn sys_complete_awakeable(
900 &mut self,
901 id: String,
902 value: NonEmptyValue,
903 options: PayloadOptions,
904 ) -> VMResult<()> {
905 invocation_debug_logs!(self, "Executing 'Complete awakeable {id}'");
906 self.verify_error_metadata_feature_support(&value)?;
907 self.do_transition(SysNonCompletableEntry(
908 CompleteAwakeableCommandMessage {
909 awakeable_id: id,
910 result: Some(match value {
911 NonEmptyValue::Success(s) => {
912 complete_awakeable_command_message::Result::Value(s.into())
913 }
914 NonEmptyValue::Failure(f) => {
915 complete_awakeable_command_message::Result::Failure(f.into())
916 }
917 }),
918 ..Default::default()
919 },
920 options,
921 ))
922 }
923
924 #[instrument(
925 level = "trace",
926 skip(self),
927 fields(
928 restate.invocation.id = self.debug_invocation_id(),
929 restate.protocol.state = self.debug_state(),
930 restate.journal.command_index = self.context.journal.command_index(),
931 restate.protocol.version = %self.context.negotiated_protocol_version
932 ),
933 ret
934 )]
935 fn create_signal_handle(&mut self, signal_name: String) -> VMResult<NotificationHandle> {
936 invocation_debug_logs!(self, "Executing 'Create named signal'");
937
938 self.do_transition(CreateSignalHandle(
939 "named awakeable",
940 NotificationId::SignalName(signal_name),
941 ))
942 }
943
944 #[instrument(
945 level = "trace",
946 skip(self, value),
947 fields(
948 restate.invocation.id = self.debug_invocation_id(),
949 restate.protocol.state = self.debug_state(),
950 restate.journal.command_index = self.context.journal.command_index(),
951 restate.protocol.version = %self.context.negotiated_protocol_version
952 ),
953 ret
954 )]
955 fn sys_complete_signal(
956 &mut self,
957 target_invocation_id: String,
958 signal_name: String,
959 value: NonEmptyValue,
960 ) -> VMResult<()> {
961 invocation_debug_logs!(self, "Executing 'Complete named signal {signal_name}'");
962 self.verify_error_metadata_feature_support(&value)?;
963 self.do_transition(SysNonCompletableEntry(
964 SendSignalCommandMessage {
965 target_invocation_id,
966 signal_id: Some(send_signal_command_message::SignalId::Name(signal_name)),
967 result: Some(match value {
968 NonEmptyValue::Success(s) => {
969 send_signal_command_message::Result::Value(s.into())
970 }
971 NonEmptyValue::Failure(f) => {
972 send_signal_command_message::Result::Failure(f.into())
973 }
974 }),
975 ..Default::default()
976 },
977 PayloadOptions::default(),
978 ))
979 }
980
981 #[instrument(
982 level = "trace",
983 skip(self),
984 fields(
985 restate.invocation.id = self.debug_invocation_id(),
986 restate.protocol.state = self.debug_state(),
987 restate.journal.command_index = self.context.journal.command_index(),
988 restate.protocol.version = %self.context.negotiated_protocol_version
989 ),
990 ret
991 )]
992 fn sys_get_promise(&mut self, key: String) -> VMResult<NotificationHandle> {
993 invocation_debug_logs!(self, "Executing 'Await promise {key}'");
994
995 let result_completion_id = self.context.journal.next_completion_notification_id();
996 self.do_transition(SysSimpleCompletableEntry(
997 GetPromiseCommandMessage {
998 key,
999 result_completion_id,
1000 ..Default::default()
1001 },
1002 result_completion_id,
1003 PayloadOptions::default(),
1004 ))
1005 }
1006
1007 #[instrument(
1008 level = "trace",
1009 skip(self),
1010 fields(
1011 restate.invocation.id = self.debug_invocation_id(),
1012 restate.protocol.state = self.debug_state(),
1013 restate.journal.command_index = self.context.journal.command_index(),
1014 restate.protocol.version = %self.context.negotiated_protocol_version
1015 ),
1016 ret
1017 )]
1018 fn sys_peek_promise(&mut self, key: String) -> VMResult<NotificationHandle> {
1019 invocation_debug_logs!(self, "Executing 'Peek promise {key}'");
1020
1021 let result_completion_id = self.context.journal.next_completion_notification_id();
1022 self.do_transition(SysSimpleCompletableEntry(
1023 PeekPromiseCommandMessage {
1024 key,
1025 result_completion_id,
1026 ..Default::default()
1027 },
1028 result_completion_id,
1029 PayloadOptions::default(),
1030 ))
1031 }
1032
1033 #[instrument(
1034 level = "trace",
1035 skip(self, value),
1036 fields(
1037 restate.invocation.id = self.debug_invocation_id(),
1038 restate.protocol.state = self.debug_state(),
1039 restate.journal.command_index = self.context.journal.command_index(),
1040 restate.protocol.version = %self.context.negotiated_protocol_version
1041 ),
1042 ret
1043 )]
1044 fn sys_complete_promise(
1045 &mut self,
1046 key: String,
1047 value: NonEmptyValue,
1048 options: PayloadOptions,
1049 ) -> VMResult<NotificationHandle> {
1050 invocation_debug_logs!(self, "Executing 'Complete promise {key}'");
1051 self.verify_error_metadata_feature_support(&value)?;
1052
1053 let result_completion_id = self.context.journal.next_completion_notification_id();
1054 self.do_transition(SysSimpleCompletableEntry(
1055 CompletePromiseCommandMessage {
1056 key,
1057 completion: Some(match value {
1058 NonEmptyValue::Success(s) => {
1059 complete_promise_command_message::Completion::CompletionValue(s.into())
1060 }
1061 NonEmptyValue::Failure(f) => {
1062 complete_promise_command_message::Completion::CompletionFailure(f.into())
1063 }
1064 }),
1065 result_completion_id,
1066 ..Default::default()
1067 },
1068 result_completion_id,
1069 options,
1070 ))
1071 }
1072
1073 #[instrument(
1074 level = "trace",
1075 skip(self),
1076 fields(
1077 restate.invocation.id = self.debug_invocation_id(),
1078 restate.protocol.state = self.debug_state(),
1079 restate.journal.command_index = self.context.journal.command_index(),
1080 restate.protocol.version = %self.context.negotiated_protocol_version
1081 ),
1082 ret
1083 )]
1084 fn sys_run(&mut self, name: String) -> VMResult<NotificationHandle> {
1085 match self.do_transition(SysRun(name.clone())) {
1086 Ok(handle) => {
1087 if enabled!(Level::DEBUG) {
1088 self.sys_run_names.insert(handle, name);
1090 }
1091 Ok(handle)
1092 }
1093 Err(e) => Err(e),
1094 }
1095 }
1096
1097 #[instrument(
1098 level = "trace",
1099 skip(self, value, retry_policy),
1100 fields(
1101 restate.invocation.id = self.debug_invocation_id(),
1102 restate.protocol.state = self.debug_state(),
1103 restate.journal.command_index = self.context.journal.command_index(),
1104 restate.protocol.version = %self.context.negotiated_protocol_version
1105 ),
1106 ret
1107 )]
1108 fn propose_run_completion(
1109 &mut self,
1110 notification_handle: NotificationHandle,
1111 value: RunExitResult,
1112 retry_policy: RetryPolicy,
1113 ) -> VMResult<()> {
1114 if enabled!(Level::DEBUG) {
1115 let name: &str = self
1116 .sys_run_names
1117 .get(¬ification_handle)
1118 .map(String::as_str)
1119 .unwrap_or_default();
1120 match &value {
1121 RunExitResult::Success(_) => {
1122 invocation_debug_logs!(self, "Journaling run '{name}' success result");
1123 }
1124 RunExitResult::TerminalFailure(TerminalFailure { code, .. }) => {
1125 invocation_debug_logs!(
1126 self,
1127 "Journaling run '{name}' terminal failure {code} result"
1128 );
1129 }
1130 RunExitResult::RetryableFailure { .. } => {
1131 invocation_debug_logs!(self, "Propagating run '{name}' retryable failure");
1132 }
1133 }
1134 }
1135 if let RunExitResult::TerminalFailure(f) = &value {
1136 if !f.metadata.is_empty() {
1137 self.verify_feature_support("terminal error metadata", Version::V6)?;
1138 }
1139 }
1140
1141 self.do_transition(ProposeRunCompletion(
1142 notification_handle,
1143 value,
1144 retry_policy,
1145 ))
1146 }
1147
1148 #[instrument(
1149 level = "trace",
1150 skip(self),
1151 fields(
1152 restate.invocation.id = self.debug_invocation_id(),
1153 restate.protocol.state = self.debug_state(),
1154 restate.journal.command_index = self.context.journal.command_index(),
1155 restate.protocol.version = %self.context.negotiated_protocol_version
1156 ),
1157 ret
1158 )]
1159 fn sys_cancel_invocation(&mut self, target_invocation_id: String) -> VMResult<()> {
1160 invocation_debug_logs!(
1161 self,
1162 "Executing 'Cancel invocation' of {target_invocation_id}"
1163 );
1164 self.do_transition(SysNonCompletableEntry(
1165 SendSignalCommandMessage {
1166 target_invocation_id,
1167 signal_id: Some(send_signal_command_message::SignalId::Idx(CANCEL_SIGNAL_ID)),
1168 result: Some(send_signal_command_message::Result::Void(Default::default())),
1169 ..Default::default()
1170 },
1171 PayloadOptions::default(),
1172 ))
1173 }
1174
1175 #[instrument(
1176 level = "trace",
1177 skip(self),
1178 fields(
1179 restate.invocation.id = self.debug_invocation_id(),
1180 restate.protocol.state = self.debug_state(),
1181 restate.journal.command_index = self.context.journal.command_index(),
1182 restate.protocol.version = %self.context.negotiated_protocol_version
1183 ),
1184 ret
1185 )]
1186 fn sys_attach_invocation(
1187 &mut self,
1188 target: AttachInvocationTarget,
1189 ) -> VMResult<NotificationHandle> {
1190 invocation_debug_logs!(self, "Executing 'Attach invocation'");
1191
1192 let result_completion_id = self.context.journal.next_completion_notification_id();
1193 self.do_transition(SysSimpleCompletableEntry(
1194 AttachInvocationCommandMessage {
1195 target: Some(match target {
1196 AttachInvocationTarget::InvocationId(id) => {
1197 attach_invocation_command_message::Target::InvocationId(id)
1198 }
1199 AttachInvocationTarget::WorkflowId { name, key } => {
1200 attach_invocation_command_message::Target::WorkflowTarget(WorkflowTarget {
1201 workflow_name: name,
1202 workflow_key: key,
1203 })
1204 }
1205 AttachInvocationTarget::IdempotencyId {
1206 service_name,
1207 service_key,
1208 handler_name,
1209 idempotency_key,
1210 } => attach_invocation_command_message::Target::IdempotentRequestTarget(
1211 IdempotentRequestTarget {
1212 service_name,
1213 service_key,
1214 handler_name,
1215 idempotency_key,
1216 },
1217 ),
1218 }),
1219 result_completion_id,
1220 ..Default::default()
1221 },
1222 result_completion_id,
1223 PayloadOptions::default(),
1224 ))
1225 }
1226
1227 #[instrument(
1228 level = "trace",
1229 skip(self),
1230 fields(
1231 restate.invocation.id = self.debug_invocation_id(),
1232 restate.protocol.state = self.debug_state(),
1233 restate.journal.command_index = self.context.journal.command_index(),
1234 restate.protocol.version = %self.context.negotiated_protocol_version
1235 ),
1236 ret
1237 )]
1238 fn sys_get_invocation_output(
1239 &mut self,
1240 target: AttachInvocationTarget,
1241 ) -> VMResult<NotificationHandle> {
1242 invocation_debug_logs!(self, "Executing 'Get invocation output'");
1243
1244 let result_completion_id = self.context.journal.next_completion_notification_id();
1245 self.do_transition(SysSimpleCompletableEntry(
1246 GetInvocationOutputCommandMessage {
1247 target: Some(match target {
1248 AttachInvocationTarget::InvocationId(id) => {
1249 get_invocation_output_command_message::Target::InvocationId(id)
1250 }
1251 AttachInvocationTarget::WorkflowId { name, key } => {
1252 get_invocation_output_command_message::Target::WorkflowTarget(
1253 WorkflowTarget {
1254 workflow_name: name,
1255 workflow_key: key,
1256 },
1257 )
1258 }
1259 AttachInvocationTarget::IdempotencyId {
1260 service_name,
1261 service_key,
1262 handler_name,
1263 idempotency_key,
1264 } => get_invocation_output_command_message::Target::IdempotentRequestTarget(
1265 IdempotentRequestTarget {
1266 service_name,
1267 service_key,
1268 handler_name,
1269 idempotency_key,
1270 },
1271 ),
1272 }),
1273 result_completion_id,
1274 ..Default::default()
1275 },
1276 result_completion_id,
1277 PayloadOptions::default(),
1278 ))
1279 }
1280
1281 #[instrument(
1282 level = "trace",
1283 skip(self, value),
1284 fields(
1285 restate.invocation.id = self.debug_invocation_id(),
1286 restate.protocol.state = self.debug_state(),
1287 restate.journal.command_index = self.context.journal.command_index(),
1288 restate.protocol.version = %self.context.negotiated_protocol_version
1289 ),
1290 ret
1291 )]
1292 fn sys_write_output(&mut self, value: NonEmptyValue, options: PayloadOptions) -> VMResult<()> {
1293 match &value {
1294 NonEmptyValue::Success(_) => {
1295 invocation_debug_logs!(self, "Writing invocation result success value");
1296 }
1297 NonEmptyValue::Failure(_) => {
1298 invocation_debug_logs!(self, "Writing invocation result failure value");
1299 }
1300 }
1301 self.verify_error_metadata_feature_support(&value)?;
1302 self.do_transition(SysNonCompletableEntry(
1303 OutputCommandMessage {
1304 result: Some(match value {
1305 NonEmptyValue::Success(b) => output_command_message::Result::Value(b.into()),
1306 NonEmptyValue::Failure(f) => output_command_message::Result::Failure(f.into()),
1307 }),
1308 ..OutputCommandMessage::default()
1309 },
1310 options,
1311 ))
1312 }
1313
1314 #[instrument(
1315 level = "trace",
1316 skip(self),
1317 fields(
1318 restate.invocation.id = self.debug_invocation_id(),
1319 restate.protocol.state = self.debug_state(),
1320 restate.journal.command_index = self.context.journal.command_index(),
1321 restate.protocol.version = %self.context.negotiated_protocol_version
1322 ),
1323 ret
1324 )]
1325 fn sys_end(&mut self) -> Result<(), Error> {
1326 invocation_debug_logs!(self, "End of the invocation");
1327 self.do_transition(SysEnd)
1328 }
1329
1330 fn is_waiting_preflight(&self) -> bool {
1331 matches!(
1332 &self.last_transition,
1333 Ok(State::WaitingStart) | Ok(State::WaitingReplayEntries { .. })
1334 )
1335 }
1336
1337 fn is_replaying(&self) -> bool {
1338 matches!(&self.last_transition, Ok(State::Replaying { .. }))
1339 }
1340
1341 fn is_processing(&self) -> bool {
1342 matches!(&self.last_transition, Ok(State::Processing { .. }))
1343 }
1344
1345 fn last_command_index(&self) -> i64 {
1346 self.context.journal.command_index()
1347 }
1348}
1349
1350const INDIFFERENT_PAD: GeneralPurposeConfig = GeneralPurposeConfig::new()
1351 .with_decode_padding_mode(DecodePaddingMode::Indifferent)
1352 .with_encode_padding(false);
1353const URL_SAFE: GeneralPurpose = GeneralPurpose::new(&alphabet::URL_SAFE, INDIFFERENT_PAD);
1354
1355const AWAKEABLE_PREFIX: &str = "sign_1";
1356
1357pub(super) fn awakeable_id_str(id: &[u8], completion_index: u32) -> String {
1358 let mut input_buf = BytesMut::with_capacity(id.len() + size_of::<u32>());
1359 input_buf.put_slice(id);
1360 input_buf.put_u32(completion_index);
1361 format!("{AWAKEABLE_PREFIX}{}", URL_SAFE.encode(input_buf.freeze()))
1362}