1use crate::headers::HeaderMap;
2use crate::service_protocol::messages::{
3 attach_invocation_command_message, complete_awakeable_command_message,
4 complete_promise_command_message, get_invocation_output_command_message,
5 output_command_message, send_signal_command_message, AttachInvocationCommandMessage,
6 CallCommandMessage, ClearAllStateCommandMessage, ClearStateCommandMessage,
7 CompleteAwakeableCommandMessage, CompletePromiseCommandMessage,
8 GetInvocationOutputCommandMessage, GetPromiseCommandMessage, IdempotentRequestTarget,
9 OneWayCallCommandMessage, OutputCommandMessage, PeekPromiseCommandMessage,
10 SendSignalCommandMessage, SetStateCommandMessage, SleepCommandMessage, WorkflowTarget,
11};
12use crate::service_protocol::{Decoder, NotificationId, RawMessage, Version, CANCEL_SIGNAL_ID};
13use crate::vm::errors::{
14 ClosedError, UnexpectedStateError, UnsupportedFeatureForNegotiatedVersion,
15 EMPTY_IDEMPOTENCY_KEY, SUSPENDED,
16};
17use crate::vm::transitions::*;
18use crate::{
19 AttachInvocationTarget, CallHandle, CommandRelationship, DoProgressResponse, Error, Header,
20 ImplicitCancellationOption, Input, NonEmptyValue, NotificationHandle, ResponseHead,
21 RetryPolicy, RunExitResult, SendHandle, TakeOutputResult, Target, TerminalFailure, VMOptions,
22 VMResult, Value, CANCEL_NOTIFICATION_HANDLE,
23};
24use base64::engine::{DecodePaddingMode, GeneralPurpose, GeneralPurposeConfig};
25use base64::{alphabet, Engine};
26use bytes::{Buf, BufMut, Bytes, BytesMut};
27use context::{AsyncResultsState, Context, Output, RunState};
28use std::borrow::Cow;
29use std::collections::{HashMap, VecDeque};
30use std::mem::size_of;
31use std::time::Duration;
32use std::{fmt, mem};
33use strum::IntoStaticStr;
34use tracing::{debug, enabled, instrument, Level};
35
36mod context;
37pub(crate) mod errors;
38mod transitions;
39
40const CONTENT_TYPE: &str = "content-type";
41
/// Lifecycle state of the `CoreVM` protocol state machine.
///
/// `IntoStaticStr` turns a `&State` into its variant name; that name is used for tracing span
/// fields, the `Debug` output of `CoreVM`, and unexpected-state errors.
#[derive(Debug, IntoStaticStr)]
pub(crate) enum State {
    /// Initial state set by `CoreVM::new`; `is_ready_to_execute` reports `false` here.
    WaitingStart,
    /// Not yet ready to execute (`is_ready_to_execute` reports `false`).
    WaitingReplayEntries {
        // NOTE(review): presumably the number of journal entries received so far — confirm in
        // the transitions module.
        received_entries: u32,
        // Raw command messages buffered before execution starts.
        commands: VecDeque<RawMessage>,
        async_results: AsyncResultsState,
    },
    /// Ready to execute (`is_ready_to_execute` reports `true`); the name and the buffered
    /// `commands` suggest journaled commands are being replayed — confirm in transitions.
    Replaying {
        commands: VecDeque<RawMessage>,
        run_state: RunState,
        // Completion tracking consulted by `CoreVM::_is_completed`.
        async_results: AsyncResultsState,
    },
    /// Ready to execute (`is_ready_to_execute` reports `true`).
    Processing {
        // NOTE(review): looks like a flag marking the first newly processed entry — confirm.
        processing_first_entry: bool,
        run_state: RunState,
        // Completion tracking consulted by `CoreVM::_is_completed`.
        async_results: AsyncResultsState,
    },
    /// Terminal state; `as_unexpected_state` reports a `ClosedError` for any event here.
    Closed,
}
62
63impl State {
64 fn as_unexpected_state(&self, event: &'static str) -> Error {
65 if matches!(self, State::Closed) {
66 return ClosedError::new(event).into();
67 }
68 UnexpectedStateError::new(self.into(), event).into()
69 }
70}
71
72struct TrackedInvocationId {
73 handle: NotificationHandle,
74 invocation_id: Option<String>,
75}
76
77impl TrackedInvocationId {
78 fn is_resolved(&self) -> bool {
79 self.invocation_id.is_some()
80 }
81}
82
/// State machine behind the service-protocol VM; implements the `super::VM` trait.
pub struct CoreVM {
    // Protocol version negotiated from the request `content-type` header in `new`.
    version: Version,
    options: VMOptions,

    // Decoder fed by `notify_input`; yields raw protocol messages.
    decoder: Decoder,

    // Invocation context: input-closed flag, output buffer, start info, journal, eager state.
    context: Context,
    // Outcome of the most recent state transition; `Err` means the VM has failed.
    last_transition: Result<State, Error>,

    // Child calls/sends tracked for implicit cancellation. When the cancel signal completes,
    // `do_progress` resolves their invocation ids and cancels each of them.
    tracked_invocation_ids: Vec<TrackedInvocationId>,

    // Names passed to `sys_run`, keyed by handle. Only populated while DEBUG logging is
    // enabled; read by `propose_run_completion` for its log lines.
    sys_run_names: HashMap<NotificationHandle, String>,
}
100
101impl CoreVM {
102 fn debug_invocation_id(&self) -> &str {
104 if let Some(start_info) = self.context.start_info() {
105 &start_info.debug_id
106 } else {
107 ""
108 }
109 }
110
111 fn debug_state(&self) -> &'static str {
112 match &self.last_transition {
113 Ok(s) => s.into(),
114 Err(_) => "Failed",
115 }
116 }
117
118 fn verify_feature_support(
119 &mut self,
120 feature: &'static str,
121 minimum_required_protocol: Version,
122 ) -> VMResult<()> {
123 if self.version < minimum_required_protocol {
124 return self.do_transition(HitError(
125 UnsupportedFeatureForNegotiatedVersion::new(
126 feature,
127 self.version,
128 minimum_required_protocol,
129 )
130 .into(),
131 ));
132 }
133 Ok(())
134 }
135
136 fn _is_completed(&self, handle: NotificationHandle) -> bool {
137 match &self.last_transition {
138 Ok(State::Replaying { async_results, .. })
139 | Ok(State::Processing { async_results, .. }) => {
140 async_results.is_handle_completed(handle)
141 }
142 _ => false,
143 }
144 }
145
146 fn _do_progress(
147 &mut self,
148 any_handle: Vec<NotificationHandle>,
149 ) -> Result<DoProgressResponse, Error> {
150 match self.do_transition(DoProgress(any_handle)) {
151 Ok(Ok(do_progress_response)) => Ok(do_progress_response),
152 Ok(Err(_)) => Err(SUSPENDED),
153 Err(e) => Err(e),
154 }
155 }
156
157 fn is_implicit_cancellation_enabled(&self) -> bool {
158 matches!(
159 self.options.implicit_cancellation,
160 ImplicitCancellationOption::Enabled { .. }
161 )
162 }
163}
164
165impl fmt::Debug for CoreVM {
166 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
167 let mut s = f.debug_struct("CoreVM");
168 s.field("version", &self.version);
169
170 if let Some(start_info) = self.context.start_info() {
171 s.field("invocation_id", &start_info.debug_id);
172 }
173
174 match &self.last_transition {
175 Ok(state) => s.field("last_transition", &<&'static str>::from(state)),
176 Err(_) => s.field("last_transition", &"Errored"),
177 };
178
179 s.field("command_index", &self.context.journal.command_index())
180 .field(
181 "notification_index",
182 &self.context.journal.notification_index(),
183 )
184 .finish()
185 }
186}
187
// Compile-time assertion that `CoreVM` is `Send`: the `const _` below only type-checks when
// `CoreVM: Send`, so adding a non-`Send` field breaks the build right here.
// NOTE(review): `#[allow(unused)]` looks redundant since `is_send` is used below — confirm
// before removing.
#[allow(unused)]
const fn is_send<T: Send>() {}
const _: () = is_send::<CoreVM>();
192
// Emits `tracing::debug!($($arg)*)` only while `$this.is_processing()` returns `true`.
// Presumably this keeps the per-operation "Executing ..." logs quiet during replay — confirm
// against `is_processing`.
//
// Usage: `invocation_debug_logs!(self, "Executing 'Get state {key}'");`
macro_rules! invocation_debug_logs {
    ($this:expr, $($arg:tt)*) => {
        // `$this` is an `expr` fragment, so it is already a single expression: no parentheses
        // are needed (they only trigger `unused_parens`).
        if $this.is_processing() {
            tracing::debug!($($arg)*)
        }
    };
}
201
202impl super::VM for CoreVM {
203 #[instrument(level = "trace", skip(request_headers), ret)]
204 fn new(request_headers: impl HeaderMap, options: VMOptions) -> Result<Self, Error> {
205 let version = request_headers
206 .extract(CONTENT_TYPE)
207 .map_err(|e| {
208 Error::new(
209 errors::codes::BAD_REQUEST,
210 format!("cannot read '{CONTENT_TYPE}' header: {e:?}"),
211 )
212 })?
213 .ok_or(errors::MISSING_CONTENT_TYPE)?
214 .parse::<Version>()?;
215
216 if version < Version::minimum_supported_version()
217 || version > Version::maximum_supported_version()
218 {
219 return Err(Error::new(
220 errors::codes::UNSUPPORTED_MEDIA_TYPE,
221 format!(
222 "Unsupported protocol version {:?}, not within [{:?} to {:?}]. \
223 You might need to rediscover the service, check https://docs.restate.dev/references/errors/#RT0015",
224 version,
225 Version::minimum_supported_version(),
226 Version::maximum_supported_version()
227 ),
228 ));
229 }
230
231 Ok(Self {
232 version,
233 options,
234 decoder: Decoder::new(version),
235 context: Context {
236 input_is_closed: false,
237 output: Output::new(version),
238 start_info: None,
239 journal: Default::default(),
240 eager_state: Default::default(),
241 },
242 last_transition: Ok(State::WaitingStart),
243 tracked_invocation_ids: vec![],
244 sys_run_names: HashMap::with_capacity(0),
245 })
246 }
247
248 #[instrument(
249 level = "trace",
250 skip(self),
251 fields(
252 restate.invocation.id = self.debug_invocation_id(),
253 restate.protocol.state = self.debug_state(),
254 restate.journal.command_index = self.context.journal.command_index(),
255 restate.protocol.version = %self.version
256 ),
257 ret
258 )]
259 fn get_response_head(&self) -> ResponseHead {
260 ResponseHead {
261 status_code: 200,
262 headers: vec![Header {
263 key: Cow::Borrowed(CONTENT_TYPE),
264 value: Cow::Borrowed(self.version.content_type()),
265 }],
266 version: self.version,
267 }
268 }
269
270 #[instrument(
271 level = "trace",
272 skip(self),
273 fields(
274 restate.invocation.id = self.debug_invocation_id(),
275 restate.protocol.state = self.debug_state(),
276 restate.journal.command_index = self.context.journal.command_index(),
277 restate.protocol.version = %self.version
278 ),
279 ret
280 )]
281 fn notify_input(&mut self, buffer: Bytes) {
282 self.decoder.push(buffer);
283 loop {
284 match self.decoder.consume_next() {
285 Ok(Some(msg)) => {
286 if self.do_transition(NewMessage(msg)).is_err() {
287 return;
288 }
289 }
290 Ok(None) => {
291 return;
292 }
293 Err(e) => {
294 if self.do_transition(HitError(e.into())).is_err() {
295 return;
296 }
297 }
298 }
299 }
300 }
301
302 #[instrument(
303 level = "trace",
304 skip(self),
305 fields(
306 restate.invocation.id = self.debug_invocation_id(),
307 restate.protocol.state = self.debug_state(),
308 restate.journal.command_index = self.context.journal.command_index(),
309 restate.protocol.version = %self.version
310 ),
311 ret
312 )]
313 fn notify_input_closed(&mut self) {
314 self.context.input_is_closed = true;
315 let _ = self.do_transition(NotifyInputClosed);
316 }
317
318 #[instrument(
319 level = "trace",
320 skip(self),
321 fields(
322 restate.invocation.id = self.debug_invocation_id(),
323 restate.protocol.state = self.debug_state(),
324 restate.journal.command_index = self.context.journal.command_index(),
325 restate.protocol.version = %self.version
326 ),
327 ret
328 )]
329 fn notify_error(
330 &mut self,
331 mut error: Error,
332 command_relationship: Option<CommandRelationship>,
333 ) {
334 if let Some(command_relationship) = command_relationship {
335 error = error.with_related_command_metadata(
336 self.context
337 .journal
338 .resolve_related_command(command_relationship),
339 );
340 }
341
342 let _ = self.do_transition(HitError(error));
343 }
344
345 #[instrument(
346 level = "trace",
347 skip(self),
348 fields(
349 restate.invocation.id = self.debug_invocation_id(),
350 restate.protocol.state = self.debug_state(),
351 restate.journal.command_index = self.context.journal.command_index(),
352 restate.protocol.version = %self.version
353 ),
354 ret
355 )]
356 fn take_output(&mut self) -> TakeOutputResult {
357 if self.context.output.buffer.has_remaining() {
358 TakeOutputResult::Buffer(
359 self.context
360 .output
361 .buffer
362 .copy_to_bytes(self.context.output.buffer.remaining()),
363 )
364 } else if !self.context.output.is_closed() {
365 TakeOutputResult::Buffer(Bytes::default())
366 } else {
367 TakeOutputResult::EOF
368 }
369 }
370
371 #[instrument(
372 level = "trace",
373 skip(self),
374 fields(
375 restate.invocation.id = self.debug_invocation_id(),
376 restate.protocol.state = self.debug_state(),
377 restate.journal.command_index = self.context.journal.command_index(),
378 restate.protocol.version = %self.version
379 ),
380 ret
381 )]
382 fn is_ready_to_execute(&self) -> Result<bool, Error> {
383 match &self.last_transition {
384 Ok(State::WaitingStart) | Ok(State::WaitingReplayEntries { .. }) => Ok(false),
385 Ok(State::Processing { .. }) | Ok(State::Replaying { .. }) => Ok(true),
386 Ok(s) => Err(s.as_unexpected_state("IsReadyToExecute")),
387 Err(e) => Err(e.clone()),
388 }
389 }
390
391 #[instrument(
392 level = "trace",
393 skip(self),
394 fields(
395 restate.invocation.id = self.debug_invocation_id(),
396 restate.protocol.state = self.debug_state(),
397 restate.journal.command_index = self.context.journal.command_index(),
398 restate.protocol.version = %self.version
399 ),
400 ret
401 )]
402 fn is_completed(&self, handle: NotificationHandle) -> bool {
403 self._is_completed(handle)
404 }
405
406 #[instrument(
407 level = "trace",
408 skip(self),
409 fields(
410 restate.invocation.id = self.debug_invocation_id(),
411 restate.protocol.state = self.debug_state(),
412 restate.journal.command_index = self.context.journal.command_index(),
413 restate.protocol.version = %self.version
414 ),
415 ret
416 )]
417 fn do_progress(
418 &mut self,
419 mut any_handle: Vec<NotificationHandle>,
420 ) -> VMResult<DoProgressResponse> {
421 if self.is_implicit_cancellation_enabled() {
422 any_handle.insert(0, CANCEL_NOTIFICATION_HANDLE);
424
425 match self._do_progress(any_handle) {
426 Ok(DoProgressResponse::AnyCompleted) => {
427 if self._is_completed(CANCEL_NOTIFICATION_HANDLE) {
429 for i in 0..self.tracked_invocation_ids.len() {
431 if self.tracked_invocation_ids[i].is_resolved() {
432 continue;
433 }
434
435 let handle = self.tracked_invocation_ids[i].handle;
436
437 match self._do_progress(vec![handle]) {
439 Ok(DoProgressResponse::AnyCompleted) => {
440 let invocation_id = match self.do_transition(CopyNotification(handle)) {
441 Ok(Ok(Some(Value::InvocationId(invocation_id)))) => Ok(invocation_id),
442 Ok(Err(_)) => Err(SUSPENDED),
443 _ => panic!("Unexpected variant! If the id handle is completed, it must be an invocation id handle!")
444 }?;
445
446 self.tracked_invocation_ids[i].invocation_id =
448 Some(invocation_id);
449 }
450 res => return res,
451 }
452 }
453
454 for tracked_invocation_id in mem::take(&mut self.tracked_invocation_ids) {
456 self.sys_cancel_invocation(
457 tracked_invocation_id
458 .invocation_id
459 .expect("We resolved before all the invocation ids"),
460 )?;
461 }
462
463 let _ = self.take_notification(CANCEL_NOTIFICATION_HANDLE);
465
466 Ok(DoProgressResponse::CancelSignalReceived)
468 } else {
469 Ok(DoProgressResponse::AnyCompleted)
470 }
471 }
472 res => res,
473 }
474 } else {
475 self._do_progress(any_handle)
476 }
477 }
478
479 #[instrument(
480 level = "trace",
481 skip(self),
482 fields(
483 restate.invocation.id = self.debug_invocation_id(),
484 restate.protocol.state = self.debug_state(),
485 restate.journal.command_index = self.context.journal.command_index(),
486 restate.protocol.version = %self.version
487 ),
488 ret
489 )]
490 fn take_notification(&mut self, handle: NotificationHandle) -> VMResult<Option<Value>> {
491 match self.do_transition(TakeNotification(handle)) {
492 Ok(Ok(Some(value))) => {
493 if self.is_implicit_cancellation_enabled() {
494 if let Ok(found) = self
497 .tracked_invocation_ids
498 .binary_search_by(|tracked| tracked.handle.cmp(&handle))
499 {
500 let Value::InvocationId(invocation_id) = &value else {
501 panic!("Expecting an invocation id here, but got {value:?}");
502 };
503 self.tracked_invocation_ids
505 .get_mut(found)
506 .unwrap()
507 .invocation_id = Some(invocation_id.clone());
508 }
509 }
510
511 Ok(Some(value))
512 }
513 Ok(Ok(None)) => Ok(None),
514 Ok(Err(_)) => Err(SUSPENDED),
515 Err(e) => Err(e),
516 }
517 }
518
519 #[instrument(
520 level = "trace",
521 skip(self),
522 fields(
523 restate.invocation.id = self.debug_invocation_id(),
524 restate.protocol.state = self.debug_state(),
525 restate.journal.command_index = self.context.journal.command_index(),
526 restate.protocol.version = %self.version
527 ),
528 ret
529 )]
530 fn sys_input(&mut self) -> Result<Input, Error> {
531 self.do_transition(SysInput)
532 }
533
534 #[instrument(
535 level = "trace",
536 skip(self),
537 fields(
538 restate.invocation.id = self.debug_invocation_id(),
539 restate.protocol.state = self.debug_state(),
540 restate.journal.command_index = self.context.journal.command_index(),
541 restate.protocol.version = %self.version
542 ),
543 ret
544 )]
545 fn sys_state_get(&mut self, key: String) -> Result<NotificationHandle, Error> {
546 invocation_debug_logs!(self, "Executing 'Get state {key}'");
547 self.do_transition(SysStateGet(key))
548 }
549
550 #[instrument(
551 level = "trace",
552 skip(self),
553 fields(
554 restate.invocation.id = self.debug_invocation_id(),
555 restate.protocol.state = self.debug_state(),
556 restate.journal.command_index = self.context.journal.command_index(),
557 restate.protocol.version = %self.version
558 ),
559 ret
560 )]
561 fn sys_state_get_keys(&mut self) -> VMResult<NotificationHandle> {
562 invocation_debug_logs!(self, "Executing 'Get state keys'");
563 self.do_transition(SysStateGetKeys)
564 }
565
566 #[instrument(
567 level = "trace",
568 skip(self, value),
569 fields(
570 restate.invocation.id = self.debug_invocation_id(),
571 restate.protocol.state = self.debug_state(),
572 restate.journal.command_index = self.context.journal.command_index(),
573 restate.protocol.version = %self.version
574 ),
575 ret
576 )]
577 fn sys_state_set(&mut self, key: String, value: Bytes) -> Result<(), Error> {
578 invocation_debug_logs!(self, "Executing 'Set state {key}'");
579 self.context.eager_state.set(key.clone(), value.clone());
580 self.do_transition(SysNonCompletableEntry(
581 "SysStateSet",
582 SetStateCommandMessage {
583 key: Bytes::from(key.into_bytes()),
584 value: Some(value.into()),
585 ..SetStateCommandMessage::default()
586 },
587 ))
588 }
589
590 #[instrument(
591 level = "trace",
592 skip(self),
593 fields(
594 restate.invocation.id = self.debug_invocation_id(),
595 restate.protocol.state = self.debug_state(),
596 restate.journal.command_index = self.context.journal.command_index(),
597 restate.protocol.version = %self.version
598 ),
599 ret
600 )]
601 fn sys_state_clear(&mut self, key: String) -> Result<(), Error> {
602 invocation_debug_logs!(self, "Executing 'Clear state {key}'");
603 self.context.eager_state.clear(key.clone());
604 self.do_transition(SysNonCompletableEntry(
605 "SysStateClear",
606 ClearStateCommandMessage {
607 key: Bytes::from(key.into_bytes()),
608 ..ClearStateCommandMessage::default()
609 },
610 ))
611 }
612
613 #[instrument(
614 level = "trace",
615 skip(self),
616 fields(
617 restate.invocation.id = self.debug_invocation_id(),
618 restate.protocol.state = self.debug_state(),
619 restate.journal.command_index = self.context.journal.command_index(),
620 restate.protocol.version = %self.version
621 ),
622 ret
623 )]
624 fn sys_state_clear_all(&mut self) -> Result<(), Error> {
625 invocation_debug_logs!(self, "Executing 'Clear all state'");
626 self.context.eager_state.clear_all();
627 self.do_transition(SysNonCompletableEntry(
628 "SysStateClearAll",
629 ClearAllStateCommandMessage::default(),
630 ))
631 }
632
633 #[instrument(
634 level = "trace",
635 skip(self),
636 fields(
637 restate.invocation.id = self.debug_invocation_id(),
638 restate.protocol.state = self.debug_state(),
639 restate.journal.command_index = self.context.journal.command_index(),
640 restate.protocol.version = %self.version
641 ),
642 ret
643 )]
644 fn sys_sleep(
645 &mut self,
646 name: String,
647 wake_up_time_since_unix_epoch: Duration,
648 now_since_unix_epoch: Option<Duration>,
649 ) -> VMResult<NotificationHandle> {
650 if self.is_processing() {
651 match (&name, now_since_unix_epoch) {
652 (name, Some(now_since_unix_epoch)) if name.is_empty() => {
653 debug!(
654 "Executing 'Timer with duration {:?}'",
655 wake_up_time_since_unix_epoch - now_since_unix_epoch
656 );
657 }
658 (name, Some(now_since_unix_epoch)) => {
659 debug!(
660 "Executing 'Timer {name} with duration {:?}'",
661 wake_up_time_since_unix_epoch - now_since_unix_epoch
662 );
663 }
664 (name, None) if name.is_empty() => {
665 debug!("Executing 'Timer'");
666 }
667 (name, None) => {
668 debug!("Executing 'Timer named {name}'");
669 }
670 }
671 }
672
673 let completion_id = self.context.journal.next_completion_notification_id();
674
675 self.do_transition(SysSimpleCompletableEntry(
676 "SysSleep",
677 SleepCommandMessage {
678 wake_up_time: u64::try_from(wake_up_time_since_unix_epoch.as_millis())
679 .expect("millis since Unix epoch should fit in u64"),
680 result_completion_id: completion_id,
681 name,
682 },
683 completion_id,
684 ))
685 }
686
687 #[instrument(
688 level = "trace",
689 skip(self, input),
690 fields(
691 restate.invocation.id = self.debug_invocation_id(),
692 restate.protocol.state = self.debug_state(),
693 restate.journal.command_index = self.context.journal.command_index(),
694 restate.protocol.version = %self.version
695 ),
696 ret
697 )]
698 fn sys_call(&mut self, target: Target, input: Bytes) -> VMResult<CallHandle> {
699 invocation_debug_logs!(
700 self,
701 "Executing 'Call {}/{}'",
702 target.service,
703 target.handler
704 );
705 if let Some(idempotency_key) = &target.idempotency_key {
706 self.verify_feature_support("attach idempotency key to call", Version::V3)?;
707 if idempotency_key.is_empty() {
708 self.do_transition(HitError(EMPTY_IDEMPOTENCY_KEY))?;
709 unreachable!();
710 }
711 }
712
713 let call_invocation_id_completion_id =
714 self.context.journal.next_completion_notification_id();
715 let result_completion_id = self.context.journal.next_completion_notification_id();
716
717 let handles = self.do_transition(SysCompletableEntryWithMultipleCompletions(
718 "SysCall",
719 CallCommandMessage {
720 service_name: target.service,
721 handler_name: target.handler,
722 key: target.key.unwrap_or_default(),
723 idempotency_key: target.idempotency_key,
724 headers: target
725 .headers
726 .into_iter()
727 .map(crate::service_protocol::messages::Header::from)
728 .collect(),
729 parameter: input,
730 invocation_id_notification_idx: call_invocation_id_completion_id,
731 result_completion_id,
732 ..Default::default()
733 },
734 vec![call_invocation_id_completion_id, result_completion_id],
735 ))?;
736
737 if matches!(
738 self.options.implicit_cancellation,
739 ImplicitCancellationOption::Enabled {
740 cancel_children_calls: true,
741 ..
742 }
743 ) {
744 self.tracked_invocation_ids.push(TrackedInvocationId {
745 handle: handles[0],
746 invocation_id: None,
747 })
748 }
749
750 Ok(CallHandle {
751 invocation_id_notification_handle: handles[0],
752 call_notification_handle: handles[1],
753 })
754 }
755
756 #[instrument(
757 level = "trace",
758 skip(self, input),
759 fields(
760 restate.invocation.id = self.debug_invocation_id(),
761 restate.protocol.state = self.debug_state(),
762 restate.journal.command_index = self.context.journal.command_index(),
763 restate.protocol.version = %self.version
764 ),
765 ret
766 )]
767 fn sys_send(
768 &mut self,
769 target: Target,
770 input: Bytes,
771 delay: Option<Duration>,
772 ) -> VMResult<SendHandle> {
773 invocation_debug_logs!(
774 self,
775 "Executing 'Send to {}/{}'",
776 target.service,
777 target.handler
778 );
779 if let Some(idempotency_key) = &target.idempotency_key {
780 self.verify_feature_support("attach idempotency key to one way call", Version::V3)?;
781 if idempotency_key.is_empty() {
782 self.do_transition(HitError(EMPTY_IDEMPOTENCY_KEY))?;
783 unreachable!();
784 }
785 }
786 let call_invocation_id_completion_id =
787 self.context.journal.next_completion_notification_id();
788 let invocation_id_notification_handle = self.do_transition(SysSimpleCompletableEntry(
789 "SysOneWayCall",
790 OneWayCallCommandMessage {
791 service_name: target.service,
792 handler_name: target.handler,
793 key: target.key.unwrap_or_default(),
794 idempotency_key: target.idempotency_key,
795 headers: target
796 .headers
797 .into_iter()
798 .map(crate::service_protocol::messages::Header::from)
799 .collect(),
800 parameter: input,
801 invoke_time: delay
802 .map(|d| {
803 u64::try_from(d.as_millis())
804 .expect("millis since Unix epoch should fit in u64")
805 })
806 .unwrap_or_default(),
807 invocation_id_notification_idx: call_invocation_id_completion_id,
808 ..Default::default()
809 },
810 call_invocation_id_completion_id,
811 ))?;
812
813 if matches!(
814 self.options.implicit_cancellation,
815 ImplicitCancellationOption::Enabled {
816 cancel_children_one_way_calls: true,
817 ..
818 }
819 ) {
820 self.tracked_invocation_ids.push(TrackedInvocationId {
821 handle: invocation_id_notification_handle,
822 invocation_id: None,
823 })
824 }
825
826 Ok(SendHandle {
827 invocation_id_notification_handle,
828 })
829 }
830
831 #[instrument(
832 level = "trace",
833 skip(self),
834 fields(
835 restate.invocation.id = self.debug_invocation_id(),
836 restate.protocol.state = self.debug_state(),
837 restate.journal.command_index = self.context.journal.command_index(),
838 restate.protocol.version = %self.version
839 ),
840 ret
841 )]
842 fn sys_awakeable(&mut self) -> VMResult<(String, NotificationHandle)> {
843 invocation_debug_logs!(self, "Executing 'Create awakeable'");
844
845 let signal_id = self.context.journal.next_signal_notification_id();
846
847 let handle = self.do_transition(CreateSignalHandle(
848 "SysAwakeable",
849 NotificationId::SignalId(signal_id),
850 ))?;
851
852 Ok((
853 awakeable_id_str(&self.context.expect_start_info().id, signal_id),
854 handle,
855 ))
856 }
857
858 #[instrument(
859 level = "trace",
860 skip(self, value),
861 fields(
862 restate.invocation.id = self.debug_invocation_id(),
863 restate.protocol.state = self.debug_state(),
864 restate.journal.command_index = self.context.journal.command_index(),
865 restate.protocol.version = %self.version
866 ),
867 ret
868 )]
869 fn sys_complete_awakeable(&mut self, id: String, value: NonEmptyValue) -> VMResult<()> {
870 invocation_debug_logs!(self, "Executing 'Complete awakeable {id}'");
871 self.do_transition(SysNonCompletableEntry(
872 "SysCompleteAwakeable",
873 CompleteAwakeableCommandMessage {
874 awakeable_id: id,
875 result: Some(match value {
876 NonEmptyValue::Success(s) => {
877 complete_awakeable_command_message::Result::Value(s.into())
878 }
879 NonEmptyValue::Failure(f) => {
880 complete_awakeable_command_message::Result::Failure(f.into())
881 }
882 }),
883 ..Default::default()
884 },
885 ))
886 }
887
888 #[instrument(
889 level = "trace",
890 skip(self),
891 fields(
892 restate.invocation.id = self.debug_invocation_id(),
893 restate.protocol.state = self.debug_state(),
894 restate.journal.command_index = self.context.journal.command_index(),
895 restate.protocol.version = %self.version
896 ),
897 ret
898 )]
899 fn create_signal_handle(&mut self, signal_name: String) -> VMResult<NotificationHandle> {
900 invocation_debug_logs!(self, "Executing 'Create named signal'");
901
902 self.do_transition(CreateSignalHandle(
903 "SysCreateNamedSignal",
904 NotificationId::SignalName(signal_name),
905 ))
906 }
907
908 #[instrument(
909 level = "trace",
910 skip(self, value),
911 fields(
912 restate.invocation.id = self.debug_invocation_id(),
913 restate.protocol.state = self.debug_state(),
914 restate.journal.command_index = self.context.journal.command_index(),
915 restate.protocol.version = %self.version
916 ),
917 ret
918 )]
919 fn sys_complete_signal(
920 &mut self,
921 target_invocation_id: String,
922 signal_name: String,
923 value: NonEmptyValue,
924 ) -> VMResult<()> {
925 invocation_debug_logs!(self, "Executing 'Complete named signal {signal_name}'");
926 self.do_transition(SysNonCompletableEntry(
927 "SysCompleteAwakeable",
928 SendSignalCommandMessage {
929 target_invocation_id,
930 signal_id: Some(send_signal_command_message::SignalId::Name(signal_name)),
931 result: Some(match value {
932 NonEmptyValue::Success(s) => {
933 send_signal_command_message::Result::Value(s.into())
934 }
935 NonEmptyValue::Failure(f) => {
936 send_signal_command_message::Result::Failure(f.into())
937 }
938 }),
939 ..Default::default()
940 },
941 ))
942 }
943
944 #[instrument(
945 level = "trace",
946 skip(self),
947 fields(
948 restate.invocation.id = self.debug_invocation_id(),
949 restate.protocol.state = self.debug_state(),
950 restate.journal.command_index = self.context.journal.command_index(),
951 restate.protocol.version = %self.version
952 ),
953 ret
954 )]
955 fn sys_get_promise(&mut self, key: String) -> VMResult<NotificationHandle> {
956 invocation_debug_logs!(self, "Executing 'Await promise {key}'");
957
958 let result_completion_id = self.context.journal.next_completion_notification_id();
959 self.do_transition(SysSimpleCompletableEntry(
960 "SysGetPromise",
961 GetPromiseCommandMessage {
962 key,
963 result_completion_id,
964 ..Default::default()
965 },
966 result_completion_id,
967 ))
968 }
969
970 #[instrument(
971 level = "trace",
972 skip(self),
973 fields(
974 restate.invocation.id = self.debug_invocation_id(),
975 restate.protocol.state = self.debug_state(),
976 restate.journal.command_index = self.context.journal.command_index(),
977 restate.protocol.version = %self.version
978 ),
979 ret
980 )]
981 fn sys_peek_promise(&mut self, key: String) -> VMResult<NotificationHandle> {
982 invocation_debug_logs!(self, "Executing 'Peek promise {key}'");
983
984 let result_completion_id = self.context.journal.next_completion_notification_id();
985 self.do_transition(SysSimpleCompletableEntry(
986 "SysPeekPromise",
987 PeekPromiseCommandMessage {
988 key,
989 result_completion_id,
990 ..Default::default()
991 },
992 result_completion_id,
993 ))
994 }
995
996 #[instrument(
997 level = "trace",
998 skip(self, value),
999 fields(
1000 restate.invocation.id = self.debug_invocation_id(),
1001 restate.protocol.state = self.debug_state(),
1002 restate.journal.command_index = self.context.journal.command_index(),
1003 restate.protocol.version = %self.version
1004 ),
1005 ret
1006 )]
1007 fn sys_complete_promise(
1008 &mut self,
1009 key: String,
1010 value: NonEmptyValue,
1011 ) -> VMResult<NotificationHandle> {
1012 invocation_debug_logs!(self, "Executing 'Complete promise {key}'");
1013
1014 let result_completion_id = self.context.journal.next_completion_notification_id();
1015 self.do_transition(SysSimpleCompletableEntry(
1016 "SysCompletePromise",
1017 CompletePromiseCommandMessage {
1018 key,
1019 completion: Some(match value {
1020 NonEmptyValue::Success(s) => {
1021 complete_promise_command_message::Completion::CompletionValue(s.into())
1022 }
1023 NonEmptyValue::Failure(f) => {
1024 complete_promise_command_message::Completion::CompletionFailure(f.into())
1025 }
1026 }),
1027 result_completion_id,
1028 ..Default::default()
1029 },
1030 result_completion_id,
1031 ))
1032 }
1033
1034 #[instrument(
1035 level = "trace",
1036 skip(self),
1037 fields(
1038 restate.invocation.id = self.debug_invocation_id(),
1039 restate.protocol.state = self.debug_state(),
1040 restate.journal.command_index = self.context.journal.command_index(),
1041 restate.protocol.version = %self.version
1042 ),
1043 ret
1044 )]
1045 fn sys_run(&mut self, name: String) -> VMResult<NotificationHandle> {
1046 match self.do_transition(SysRun(name.clone())) {
1047 Ok(handle) => {
1048 if enabled!(Level::DEBUG) {
1049 self.sys_run_names.insert(handle, name);
1051 }
1052 Ok(handle)
1053 }
1054 Err(e) => Err(e),
1055 }
1056 }
1057
1058 #[instrument(
1059 level = "trace",
1060 skip(self, value, retry_policy),
1061 fields(
1062 restate.invocation.id = self.debug_invocation_id(),
1063 restate.protocol.state = self.debug_state(),
1064 restate.journal.command_index = self.context.journal.command_index(),
1065 restate.protocol.version = %self.version
1066 ),
1067 ret
1068 )]
1069 fn propose_run_completion(
1070 &mut self,
1071 notification_handle: NotificationHandle,
1072 value: RunExitResult,
1073 retry_policy: RetryPolicy,
1074 ) -> VMResult<()> {
1075 if enabled!(Level::DEBUG) {
1076 let name: &str = self
1077 .sys_run_names
1078 .get(¬ification_handle)
1079 .map(String::as_str)
1080 .unwrap_or_default();
1081 match &value {
1082 RunExitResult::Success(_) => {
1083 invocation_debug_logs!(self, "Journaling run '{name}' success result");
1084 }
1085 RunExitResult::TerminalFailure(TerminalFailure { code, .. }) => {
1086 invocation_debug_logs!(
1087 self,
1088 "Journaling run '{name}' terminal failure {code} result"
1089 );
1090 }
1091 RunExitResult::RetryableFailure { .. } => {
1092 invocation_debug_logs!(self, "Propagating run '{name}' retryable failure");
1093 }
1094 }
1095 }
1096
1097 self.do_transition(ProposeRunCompletion(
1098 notification_handle,
1099 value,
1100 retry_policy,
1101 ))
1102 }
1103
1104 #[instrument(
1105 level = "trace",
1106 skip(self),
1107 fields(
1108 restate.invocation.id = self.debug_invocation_id(),
1109 restate.protocol.state = self.debug_state(),
1110 restate.journal.command_index = self.context.journal.command_index(),
1111 restate.protocol.version = %self.version
1112 ),
1113 ret
1114 )]
1115 fn sys_cancel_invocation(&mut self, target_invocation_id: String) -> VMResult<()> {
1116 invocation_debug_logs!(
1117 self,
1118 "Executing 'Cancel invocation' of {target_invocation_id}"
1119 );
1120 self.verify_feature_support("cancel invocation", Version::V3)?;
1121 self.do_transition(SysNonCompletableEntry(
1122 "SysCancelInvocation",
1123 SendSignalCommandMessage {
1124 target_invocation_id,
1125 signal_id: Some(send_signal_command_message::SignalId::Idx(CANCEL_SIGNAL_ID)),
1126 result: Some(send_signal_command_message::Result::Void(Default::default())),
1127 ..Default::default()
1128 },
1129 ))
1130 }
1131
1132 #[instrument(
1133 level = "trace",
1134 skip(self),
1135 fields(
1136 restate.invocation.id = self.debug_invocation_id(),
1137 restate.protocol.state = self.debug_state(),
1138 restate.journal.command_index = self.context.journal.command_index(),
1139 restate.protocol.version = %self.version
1140 ),
1141 ret
1142 )]
1143 fn sys_attach_invocation(
1144 &mut self,
1145 target: AttachInvocationTarget,
1146 ) -> VMResult<NotificationHandle> {
1147 invocation_debug_logs!(self, "Executing 'Attach invocation'");
1148 self.verify_feature_support("attach invocation", Version::V3)?;
1149
1150 let result_completion_id = self.context.journal.next_completion_notification_id();
1151 self.do_transition(SysSimpleCompletableEntry(
1152 "SysAttachInvocation",
1153 AttachInvocationCommandMessage {
1154 target: Some(match target {
1155 AttachInvocationTarget::InvocationId(id) => {
1156 attach_invocation_command_message::Target::InvocationId(id)
1157 }
1158 AttachInvocationTarget::WorkflowId { name, key } => {
1159 attach_invocation_command_message::Target::WorkflowTarget(WorkflowTarget {
1160 workflow_name: name,
1161 workflow_key: key,
1162 })
1163 }
1164 AttachInvocationTarget::IdempotencyId {
1165 service_name,
1166 service_key,
1167 handler_name,
1168 idempotency_key,
1169 } => attach_invocation_command_message::Target::IdempotentRequestTarget(
1170 IdempotentRequestTarget {
1171 service_name,
1172 service_key,
1173 handler_name,
1174 idempotency_key,
1175 },
1176 ),
1177 }),
1178 result_completion_id,
1179 ..Default::default()
1180 },
1181 result_completion_id,
1182 ))
1183 }
1184
1185 #[instrument(
1186 level = "trace",
1187 skip(self),
1188 fields(
1189 restate.invocation.id = self.debug_invocation_id(),
1190 restate.protocol.state = self.debug_state(),
1191 restate.journal.command_index = self.context.journal.command_index(),
1192 restate.protocol.version = %self.version
1193 ),
1194 ret
1195 )]
1196 fn sys_get_invocation_output(
1197 &mut self,
1198 target: AttachInvocationTarget,
1199 ) -> VMResult<NotificationHandle> {
1200 invocation_debug_logs!(self, "Executing 'Get invocation output'");
1201 self.verify_feature_support("get invocation output", Version::V3)?;
1202
1203 let result_completion_id = self.context.journal.next_completion_notification_id();
1204 self.do_transition(SysSimpleCompletableEntry(
1205 "SysGetInvocationOutput",
1206 GetInvocationOutputCommandMessage {
1207 target: Some(match target {
1208 AttachInvocationTarget::InvocationId(id) => {
1209 get_invocation_output_command_message::Target::InvocationId(id)
1210 }
1211 AttachInvocationTarget::WorkflowId { name, key } => {
1212 get_invocation_output_command_message::Target::WorkflowTarget(
1213 WorkflowTarget {
1214 workflow_name: name,
1215 workflow_key: key,
1216 },
1217 )
1218 }
1219 AttachInvocationTarget::IdempotencyId {
1220 service_name,
1221 service_key,
1222 handler_name,
1223 idempotency_key,
1224 } => get_invocation_output_command_message::Target::IdempotentRequestTarget(
1225 IdempotentRequestTarget {
1226 service_name,
1227 service_key,
1228 handler_name,
1229 idempotency_key,
1230 },
1231 ),
1232 }),
1233 result_completion_id,
1234 ..Default::default()
1235 },
1236 result_completion_id,
1237 ))
1238 }
1239
1240 #[instrument(
1241 level = "trace",
1242 skip(self, value),
1243 fields(
1244 restate.invocation.id = self.debug_invocation_id(),
1245 restate.protocol.state = self.debug_state(),
1246 restate.journal.command_index = self.context.journal.command_index(),
1247 restate.protocol.version = %self.version
1248 ),
1249 ret
1250 )]
1251 fn sys_write_output(&mut self, value: NonEmptyValue) -> Result<(), Error> {
1252 match &value {
1253 NonEmptyValue::Success(_) => {
1254 invocation_debug_logs!(self, "Writing invocation result success value");
1255 }
1256 NonEmptyValue::Failure(_) => {
1257 invocation_debug_logs!(self, "Writing invocation result failure value");
1258 }
1259 }
1260 self.do_transition(SysNonCompletableEntry(
1261 "SysWriteOutput",
1262 OutputCommandMessage {
1263 result: Some(match value {
1264 NonEmptyValue::Success(b) => output_command_message::Result::Value(b.into()),
1265 NonEmptyValue::Failure(f) => output_command_message::Result::Failure(f.into()),
1266 }),
1267 ..OutputCommandMessage::default()
1268 },
1269 ))
1270 }
1271
1272 #[instrument(
1273 level = "trace",
1274 skip(self),
1275 fields(
1276 restate.invocation.id = self.debug_invocation_id(),
1277 restate.protocol.state = self.debug_state(),
1278 restate.journal.command_index = self.context.journal.command_index(),
1279 restate.protocol.version = %self.version
1280 ),
1281 ret
1282 )]
1283 fn sys_end(&mut self) -> Result<(), Error> {
1284 invocation_debug_logs!(self, "End of the invocation");
1285 self.do_transition(SysEnd)
1286 }
1287
1288 fn is_waiting_preflight(&self) -> bool {
1289 matches!(
1290 &self.last_transition,
1291 Ok(State::WaitingStart) | Ok(State::WaitingReplayEntries { .. })
1292 )
1293 }
1294
1295 fn is_replaying(&self) -> bool {
1296 matches!(&self.last_transition, Ok(State::Replaying { .. }))
1297 }
1298
1299 fn is_processing(&self) -> bool {
1300 matches!(&self.last_transition, Ok(State::Processing { .. }))
1301 }
1302
1303 fn last_command_index(&self) -> i64 {
1304 self.context.journal.command_index()
1305 }
1306}
1307
1308const INDIFFERENT_PAD: GeneralPurposeConfig = GeneralPurposeConfig::new()
1309 .with_decode_padding_mode(DecodePaddingMode::Indifferent)
1310 .with_encode_padding(false);
1311const URL_SAFE: GeneralPurpose = GeneralPurpose::new(&alphabet::URL_SAFE, INDIFFERENT_PAD);
1312
1313const AWAKEABLE_PREFIX: &str = "sign_1";
1314
1315fn awakeable_id_str(id: &[u8], completion_index: u32) -> String {
1316 let mut input_buf = BytesMut::with_capacity(id.len() + size_of::<u32>());
1317 input_buf.put_slice(id);
1318 input_buf.put_u32(completion_index);
1319 format!("{AWAKEABLE_PREFIX}{}", URL_SAFE.encode(input_buf.freeze()))
1320}