1use crate::headers::HeaderMap;
2use crate::service_protocol::messages::{
3 attach_invocation_command_message, complete_awakeable_command_message,
4 complete_promise_command_message, get_invocation_output_command_message,
5 output_command_message, send_signal_command_message, AttachInvocationCommandMessage,
6 CallCommandMessage, ClearAllStateCommandMessage, ClearStateCommandMessage,
7 CompleteAwakeableCommandMessage, CompletePromiseCommandMessage,
8 GetInvocationOutputCommandMessage, GetPromiseCommandMessage, IdempotentRequestTarget,
9 OneWayCallCommandMessage, OutputCommandMessage, PeekPromiseCommandMessage,
10 SendSignalCommandMessage, SetStateCommandMessage, SleepCommandMessage, WorkflowTarget,
11};
12use crate::service_protocol::{Decoder, NotificationId, RawMessage, Version, CANCEL_SIGNAL_ID};
13use crate::vm::errors::{
14 UnexpectedStateError, UnsupportedFeatureForNegotiatedVersion, EMPTY_IDEMPOTENCY_KEY,
15};
16use crate::vm::transitions::*;
17use crate::{
18 AttachInvocationTarget, CallHandle, DoProgressResponse, Error, Header,
19 ImplicitCancellationOption, Input, NonEmptyValue, NotificationHandle, ResponseHead,
20 RetryPolicy, RunExitResult, SendHandle, SuspendedOrVMError, TakeOutputResult, Target,
21 TerminalFailure, VMOptions, VMResult, Value, CANCEL_NOTIFICATION_HANDLE,
22};
23use base64::engine::{DecodePaddingMode, GeneralPurpose, GeneralPurposeConfig};
24use base64::{alphabet, Engine};
25use bytes::{Buf, BufMut, Bytes, BytesMut};
26use context::{AsyncResultsState, Context, Output, RunState};
27use std::borrow::Cow;
28use std::collections::{HashMap, VecDeque};
29use std::mem::size_of;
30use std::time::Duration;
31use std::{fmt, mem};
32use strum::IntoStaticStr;
33use tracing::{debug, enabled, instrument, Level};
34
35mod context;
36pub(crate) mod errors;
37mod transitions;
38
/// Name of the header carrying the protocol version: `CoreVM::new` reads it from the request
/// headers and parses it as a `Version`, and `get_response_head` writes it back in the response.
const CONTENT_TYPE: &str = "content-type";
40
/// State of the VM state machine, as stored in `CoreVM::last_transition`.
///
/// The `IntoStaticStr` derive is what allows a `&State` to be converted into the `&'static str`
/// used in logs, in the `Debug` output of `CoreVM` and in unexpected-state errors.
#[derive(Debug, IntoStaticStr)]
pub(crate) enum State {
    /// Initial state set by `CoreVM::new`. The VM is not ready to execute in this state.
    WaitingStart,
    /// The VM is not ready to execute in this state.
    // NOTE(review): presumably the VM is still receiving the journal entries to replay here —
    // confirm against the transitions module.
    WaitingReplayEntries {
        received_entries: u32,
        commands: VecDeque<RawMessage>,
        async_results: AsyncResultsState,
    },
    /// The VM is ready to execute; handle completion can be queried through `async_results`.
    Replaying {
        commands: VecDeque<RawMessage>,
        run_state: RunState,
        async_results: AsyncResultsState,
    },
    /// The VM is ready to execute; handle completion can be queried through `async_results`.
    Processing {
        processing_first_entry: bool,
        run_state: RunState,
        async_results: AsyncResultsState,
    },
    /// Asking whether the VM is ready to execute in this state is an unexpected-state error.
    Ended,
    /// Asking whether the VM is ready to execute in this state is an unexpected-state error.
    Suspended,
}
62
impl State {
    /// Builds the error reported when `event` is not valid in the current state.
    ///
    /// The state is passed to `UnexpectedStateError::new` through its `Into` conversion
    /// (the variant name, via the `IntoStaticStr` derive) together with `event`.
    fn as_unexpected_state(&self, event: &'static str) -> Error {
        UnexpectedStateError::new(self.into(), event).into()
    }
}
68
/// An invocation started by `sys_call`/`sys_send` that is tracked so that it can be cancelled
/// when the cancel notification completes (see `do_progress`).
struct TrackedInvocationId {
    /// Notification handle whose value is the invocation id of the tracked invocation.
    handle: NotificationHandle,
    /// The invocation id, once known; `None` until the handle has been resolved.
    invocation_id: Option<String>,
}
73
impl TrackedInvocationId {
    /// Returns `true` when the invocation id has already been recorded.
    fn is_resolved(&self) -> bool {
        self.invocation_id.is_some()
    }
}
79
/// The VM implementation: it decodes incoming protocol messages, drives the [`State`] machine
/// through transitions and buffers the output to send back.
pub struct CoreVM {
    /// Protocol version parsed from the request `content-type` header in `new`.
    version: Version,
    /// Options provided at construction time (e.g. the implicit cancellation behavior).
    options: VMOptions,

    /// Decoder fed by `notify_input`, producing the messages handed to the state machine.
    decoder: Decoder,

    /// Data shared across transitions: output buffer, journal indexes, eager state, start info.
    context: Context,
    /// Current state, or the error the state machine ended up in.
    last_transition: Result<State, Error>,

    /// Invocations to cancel when the cancel notification completes. `take_notification` looks
    /// entries up with a binary search on `handle`, so this relies on entries being pushed in
    /// increasing handle order.
    // NOTE(review): the ordering of handles is not visible in this file — confirm.
    tracked_invocation_ids: Vec<TrackedInvocationId>,

    /// Names of the runs created by `sys_run`, recorded only when DEBUG logging is enabled and
    /// read back by `propose_run_completion` to build log messages.
    sys_run_names: HashMap<NotificationHandle, String>,
}
97
98impl CoreVM {
99 fn debug_invocation_id(&self) -> &str {
101 if let Some(start_info) = self.context.start_info() {
102 &start_info.debug_id
103 } else {
104 ""
105 }
106 }
107
108 fn debug_state(&self) -> &'static str {
109 match &self.last_transition {
110 Ok(s) => s.into(),
111 Err(_) => "Failed",
112 }
113 }
114
115 fn verify_feature_support(
116 &mut self,
117 feature: &'static str,
118 minimum_required_protocol: Version,
119 ) -> VMResult<()> {
120 if self.version < minimum_required_protocol {
121 return self.do_transition(HitError {
122 error: UnsupportedFeatureForNegotiatedVersion::new(
123 feature,
124 self.version,
125 minimum_required_protocol,
126 )
127 .into(),
128 next_retry_delay: None,
129 });
130 }
131 Ok(())
132 }
133
134 fn _is_completed(&self, handle: NotificationHandle) -> bool {
135 match &self.last_transition {
136 Ok(State::Replaying { async_results, .. })
137 | Ok(State::Processing { async_results, .. }) => {
138 async_results.is_handle_completed(handle)
139 }
140 _ => false,
141 }
142 }
143
144 fn _do_progress(
145 &mut self,
146 any_handle: Vec<NotificationHandle>,
147 ) -> Result<DoProgressResponse, SuspendedOrVMError> {
148 match self.do_transition(DoProgress(any_handle)) {
149 Ok(Ok(do_progress_response)) => Ok(do_progress_response),
150 Ok(Err(suspended)) => Err(SuspendedOrVMError::Suspended(suspended)),
151 Err(e) => Err(SuspendedOrVMError::VM(e)),
152 }
153 }
154
155 fn is_implicit_cancellation_enabled(&self) -> bool {
156 matches!(
157 self.options.implicit_cancellation,
158 ImplicitCancellationOption::Enabled { .. }
159 )
160 }
161}
162
163impl fmt::Debug for CoreVM {
164 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
165 let mut s = f.debug_struct("CoreVM");
166 s.field("version", &self.version);
167
168 if let Some(start_info) = self.context.start_info() {
169 s.field("invocation_id", &start_info.debug_id);
170 }
171
172 match &self.last_transition {
173 Ok(state) => s.field("last_transition", &<&'static str>::from(state)),
174 Err(_) => s.field("last_transition", &"Errored"),
175 };
176
177 s.field("command_index", &self.context.journal.command_index())
178 .field(
179 "notification_index",
180 &self.context.journal.notification_index(),
181 )
182 .finish()
183 }
184}
185
// Compile-time assertion that `CoreVM` is `Send`: the constant below only type-checks when the
// `T: Send` bound holds. The function is never called at runtime, hence the `allow(unused)`.
#[allow(unused)]
const fn is_send<T: Send>() {}
const _: () = is_send::<CoreVM>();
190
// Emits a `tracing::debug!` event with the given format arguments, but only when
// `$this.is_processing()` returns true.
// NOTE(review): presumably this avoids emitting the same log again while replaying — confirm.
macro_rules! invocation_debug_logs {
    ($this:expr, $($arg:tt)*) => {
        if ($this.is_processing()) {
            tracing::debug!($($arg)*)
        }
    };
}
199
200impl super::VM for CoreVM {
201 #[instrument(level = "trace", skip(request_headers), ret)]
202 fn new(request_headers: impl HeaderMap, options: VMOptions) -> Result<Self, Error> {
203 let version = request_headers
204 .extract(CONTENT_TYPE)
205 .map_err(|e| {
206 Error::new(
207 errors::codes::BAD_REQUEST,
208 format!("cannot read '{CONTENT_TYPE}' header: {e:?}"),
209 )
210 })?
211 .ok_or(errors::MISSING_CONTENT_TYPE)?
212 .parse::<Version>()?;
213
214 if version < Version::minimum_supported_version()
215 || version > Version::maximum_supported_version()
216 {
217 return Err(Error::new(
218 errors::codes::UNSUPPORTED_MEDIA_TYPE,
219 format!(
220 "Unsupported protocol version {:?}, not within [{:?} to {:?}]. \
221 You might need to rediscover the service, check https://docs.restate.dev/references/errors/#RT0015",
222 version,
223 Version::minimum_supported_version(),
224 Version::maximum_supported_version()
225 ),
226 ));
227 }
228
229 Ok(Self {
230 version,
231 options,
232 decoder: Decoder::new(version),
233 context: Context {
234 input_is_closed: false,
235 output: Output::new(version),
236 start_info: None,
237 journal: Default::default(),
238 eager_state: Default::default(),
239 next_retry_delay: None,
240 },
241 last_transition: Ok(State::WaitingStart),
242 tracked_invocation_ids: vec![],
243 sys_run_names: HashMap::with_capacity(0),
244 })
245 }
246
    /// Returns the response head: status code 200, a single `content-type` header carrying the
    /// content type of the negotiated version, and the version itself.
    #[instrument(
        level = "trace",
        skip(self),
        fields(
            restate.invocation.id = self.debug_invocation_id(),
            restate.protocol.state = self.debug_state(),
            restate.journal.command_index = self.context.journal.command_index(),
            restate.protocol.version = %self.version
        ),
        ret
    )]
    fn get_response_head(&self) -> ResponseHead {
        ResponseHead {
            status_code: 200,
            headers: vec![Header {
                key: Cow::Borrowed(CONTENT_TYPE),
                value: Cow::Borrowed(self.version.content_type()),
            }],
            version: self.version,
        }
    }
268
269 #[instrument(
270 level = "trace",
271 skip(self),
272 fields(
273 restate.invocation.id = self.debug_invocation_id(),
274 restate.protocol.state = self.debug_state(),
275 restate.journal.command_index = self.context.journal.command_index(),
276 restate.protocol.version = %self.version
277 ),
278 ret
279 )]
280 fn notify_input(&mut self, buffer: Bytes) {
281 self.decoder.push(buffer);
282 loop {
283 match self.decoder.consume_next() {
284 Ok(Some(msg)) => {
285 if self.do_transition(NewMessage(msg)).is_err() {
286 return;
287 }
288 }
289 Ok(None) => {
290 return;
291 }
292 Err(e) => {
293 if self
294 .do_transition(HitError {
295 error: e.into(),
296 next_retry_delay: None,
297 })
298 .is_err()
299 {
300 return;
301 }
302 }
303 }
304 }
305 }
306
    /// Records that the input stream is closed and notifies the state machine through a
    /// `NotifyInputClosed` transition.
    #[instrument(
        level = "trace",
        skip(self),
        fields(
            restate.invocation.id = self.debug_invocation_id(),
            restate.protocol.state = self.debug_state(),
            restate.journal.command_index = self.context.journal.command_index(),
            restate.protocol.version = %self.version
        ),
        ret
    )]
    fn notify_input_closed(&mut self) {
        self.context.input_is_closed = true;
        // The result is deliberately discarded: this method has no way to report it.
        let _ = self.do_transition(NotifyInputClosed);
    }
322
    /// Notifies the state machine of `error` through a `HitError` transition, forwarding the
    /// optional `next_retry_delay`.
    #[instrument(
        level = "trace",
        skip(self),
        fields(
            restate.invocation.id = self.debug_invocation_id(),
            restate.protocol.state = self.debug_state(),
            restate.journal.command_index = self.context.journal.command_index(),
            restate.protocol.version = %self.version
        ),
        ret
    )]
    fn notify_error(&mut self, error: Error, next_retry_delay: Option<Duration>) {
        // The result is deliberately discarded: this method has no way to report it.
        let _ = self.do_transition(HitError {
            error,
            next_retry_delay,
        });
    }
340
341 #[instrument(
342 level = "trace",
343 skip(self),
344 fields(
345 restate.invocation.id = self.debug_invocation_id(),
346 restate.protocol.state = self.debug_state(),
347 restate.journal.command_index = self.context.journal.command_index(),
348 restate.protocol.version = %self.version
349 ),
350 ret
351 )]
352 fn take_output(&mut self) -> TakeOutputResult {
353 if self.context.output.buffer.has_remaining() {
354 TakeOutputResult::Buffer(
355 self.context
356 .output
357 .buffer
358 .copy_to_bytes(self.context.output.buffer.remaining()),
359 )
360 } else if !self.context.output.is_closed() {
361 TakeOutputResult::Buffer(Bytes::default())
362 } else {
363 TakeOutputResult::EOF
364 }
365 }
366
367 #[instrument(
368 level = "trace",
369 skip(self),
370 fields(
371 restate.invocation.id = self.debug_invocation_id(),
372 restate.protocol.state = self.debug_state(),
373 restate.journal.command_index = self.context.journal.command_index(),
374 restate.protocol.version = %self.version
375 ),
376 ret
377 )]
378 fn is_ready_to_execute(&self) -> Result<bool, Error> {
379 match &self.last_transition {
380 Ok(State::WaitingStart) | Ok(State::WaitingReplayEntries { .. }) => Ok(false),
381 Ok(State::Processing { .. }) | Ok(State::Replaying { .. }) => Ok(true),
382 Ok(s) => Err(s.as_unexpected_state("IsReadyToExecute")),
383 Err(e) => Err(e.clone()),
384 }
385 }
386
    /// Returns whether `handle` is completed. Delegates to `_is_completed`, which answers
    /// `false` whenever the VM is neither replaying nor processing.
    #[instrument(
        level = "trace",
        skip(self),
        fields(
            restate.invocation.id = self.debug_invocation_id(),
            restate.protocol.state = self.debug_state(),
            restate.journal.command_index = self.context.journal.command_index(),
            restate.protocol.version = %self.version
        ),
        ret
    )]
    fn is_completed(&self, handle: NotificationHandle) -> bool {
        self._is_completed(handle)
    }
401
402 #[instrument(
403 level = "trace",
404 skip(self),
405 fields(
406 restate.invocation.id = self.debug_invocation_id(),
407 restate.protocol.state = self.debug_state(),
408 restate.journal.command_index = self.context.journal.command_index(),
409 restate.protocol.version = %self.version
410 ),
411 ret
412 )]
413 fn do_progress(
414 &mut self,
415 mut any_handle: Vec<NotificationHandle>,
416 ) -> Result<DoProgressResponse, SuspendedOrVMError> {
417 if self.is_implicit_cancellation_enabled() {
418 any_handle.insert(0, CANCEL_NOTIFICATION_HANDLE);
420
421 match self._do_progress(any_handle) {
422 Ok(DoProgressResponse::AnyCompleted) => {
423 if self._is_completed(CANCEL_NOTIFICATION_HANDLE) {
425 for i in 0..self.tracked_invocation_ids.len() {
427 if self.tracked_invocation_ids[i].is_resolved() {
428 continue;
429 }
430
431 let handle = self.tracked_invocation_ids[i].handle;
432
433 match self._do_progress(vec![handle]) {
435 Ok(DoProgressResponse::AnyCompleted) => {
436 let invocation_id = match self.do_transition(CopyNotification(handle)) {
437 Ok(Ok(Some(Value::InvocationId(invocation_id)))) => Ok(invocation_id),
438 _ => panic!("Unexpected variant! If the id handle is completed, it must be an invocation id handle!")
439 }?;
440
441 self.tracked_invocation_ids[i].invocation_id =
443 Some(invocation_id);
444 }
445 res => return res,
446 }
447 }
448
449 for tracked_invocation_id in mem::take(&mut self.tracked_invocation_ids) {
451 self.sys_cancel_invocation(
452 tracked_invocation_id
453 .invocation_id
454 .expect("We resolved before all the invocation ids"),
455 )
456 .map_err(SuspendedOrVMError::VM)?;
457 }
458
459 let _ = self.take_notification(CANCEL_NOTIFICATION_HANDLE);
461
462 Ok(DoProgressResponse::CancelSignalReceived)
464 } else {
465 Ok(DoProgressResponse::AnyCompleted)
466 }
467 }
468 res => res,
469 }
470 } else {
471 self._do_progress(any_handle)
472 }
473 }
474
475 #[instrument(
476 level = "trace",
477 skip(self),
478 fields(
479 restate.invocation.id = self.debug_invocation_id(),
480 restate.protocol.state = self.debug_state(),
481 restate.journal.command_index = self.context.journal.command_index(),
482 restate.protocol.version = %self.version
483 ),
484 ret
485 )]
486 fn take_notification(
487 &mut self,
488 handle: NotificationHandle,
489 ) -> Result<Option<Value>, SuspendedOrVMError> {
490 match self.do_transition(TakeNotification(handle)) {
491 Ok(Ok(Some(value))) => {
492 if self.is_implicit_cancellation_enabled() {
493 if let Ok(found) = self
496 .tracked_invocation_ids
497 .binary_search_by(|tracked| tracked.handle.cmp(&handle))
498 {
499 let Value::InvocationId(invocation_id) = &value else {
500 panic!("Expecting an invocation id here, but got {value:?}");
501 };
502 self.tracked_invocation_ids
504 .get_mut(found)
505 .unwrap()
506 .invocation_id = Some(invocation_id.clone());
507 }
508 }
509
510 Ok(Some(value))
511 }
512 Ok(Ok(None)) => Ok(None),
513 Ok(Err(suspended)) => Err(SuspendedOrVMError::Suspended(suspended)),
514 Err(e) => Err(SuspendedOrVMError::VM(e)),
515 }
516 }
517
    /// Returns the invocation input by running the `SysInput` transition.
    #[instrument(
        level = "trace",
        skip(self),
        fields(
            restate.invocation.id = self.debug_invocation_id(),
            restate.protocol.state = self.debug_state(),
            restate.journal.command_index = self.context.journal.command_index(),
            restate.protocol.version = %self.version
        ),
        ret
    )]
    fn sys_input(&mut self) -> Result<Input, Error> {
        self.do_transition(SysInput)
    }
532
    /// Requests the state entry `key` through the `SysStateGet` transition and returns the
    /// notification handle it produces.
    #[instrument(
        level = "trace",
        skip(self),
        fields(
            restate.invocation.id = self.debug_invocation_id(),
            restate.protocol.state = self.debug_state(),
            restate.journal.command_index = self.context.journal.command_index(),
            restate.protocol.version = %self.version
        ),
        ret
    )]
    fn sys_state_get(&mut self, key: String) -> Result<NotificationHandle, Error> {
        invocation_debug_logs!(self, "Executing 'Get state {key}'");
        self.do_transition(SysStateGet(key))
    }
548
    /// Requests the state keys through the `SysStateGetKeys` transition and returns the
    /// notification handle it produces.
    #[instrument(
        level = "trace",
        skip(self),
        fields(
            restate.invocation.id = self.debug_invocation_id(),
            restate.protocol.state = self.debug_state(),
            restate.journal.command_index = self.context.journal.command_index(),
            restate.protocol.version = %self.version
        ),
        ret
    )]
    fn sys_state_get_keys(&mut self) -> VMResult<NotificationHandle> {
        invocation_debug_logs!(self, "Executing 'Get state keys'");
        self.do_transition(SysStateGetKeys)
    }
564
    /// Sets the state entry `key` to `value`.
    ///
    /// The local eager state is updated first, then a `SetStateCommandMessage` is written
    /// through a non-completable entry transition. `value` is skipped from the tracing span.
    #[instrument(
        level = "trace",
        skip(self, value),
        fields(
            restate.invocation.id = self.debug_invocation_id(),
            restate.protocol.state = self.debug_state(),
            restate.journal.command_index = self.context.journal.command_index(),
            restate.protocol.version = %self.version
        ),
        ret
    )]
    fn sys_state_set(&mut self, key: String, value: Bytes) -> Result<(), Error> {
        invocation_debug_logs!(self, "Executing 'Set state {key}'");
        // Clones are needed because both the eager state and the command take ownership.
        self.context.eager_state.set(key.clone(), value.clone());
        self.do_transition(SysNonCompletableEntry(
            "SysStateSet",
            SetStateCommandMessage {
                key: Bytes::from(key.into_bytes()),
                value: Some(value.into()),
                ..SetStateCommandMessage::default()
            },
        ))
    }
588
    /// Clears the state entry `key`.
    ///
    /// The local eager state is updated first, then a `ClearStateCommandMessage` is written
    /// through a non-completable entry transition.
    #[instrument(
        level = "trace",
        skip(self),
        fields(
            restate.invocation.id = self.debug_invocation_id(),
            restate.protocol.state = self.debug_state(),
            restate.journal.command_index = self.context.journal.command_index(),
            restate.protocol.version = %self.version
        ),
        ret
    )]
    fn sys_state_clear(&mut self, key: String) -> Result<(), Error> {
        invocation_debug_logs!(self, "Executing 'Clear state {key}'");
        // The key is cloned because both the eager state and the command take ownership.
        self.context.eager_state.clear(key.clone());
        self.do_transition(SysNonCompletableEntry(
            "SysStateClear",
            ClearStateCommandMessage {
                key: Bytes::from(key.into_bytes()),
                ..ClearStateCommandMessage::default()
            },
        ))
    }
611
    /// Clears all the state entries.
    ///
    /// The local eager state is cleared first, then a `ClearAllStateCommandMessage` is written
    /// through a non-completable entry transition.
    #[instrument(
        level = "trace",
        skip(self),
        fields(
            restate.invocation.id = self.debug_invocation_id(),
            restate.protocol.state = self.debug_state(),
            restate.journal.command_index = self.context.journal.command_index(),
            restate.protocol.version = %self.version
        ),
        ret
    )]
    fn sys_state_clear_all(&mut self) -> Result<(), Error> {
        invocation_debug_logs!(self, "Executing 'Clear all state'");
        self.context.eager_state.clear_all();
        self.do_transition(SysNonCompletableEntry(
            "SysStateClearAll",
            ClearAllStateCommandMessage::default(),
        ))
    }
631
632 #[instrument(
633 level = "trace",
634 skip(self),
635 fields(
636 restate.invocation.id = self.debug_invocation_id(),
637 restate.protocol.state = self.debug_state(),
638 restate.journal.command_index = self.context.journal.command_index(),
639 restate.protocol.version = %self.version
640 ),
641 ret
642 )]
643 fn sys_sleep(
644 &mut self,
645 name: String,
646 wake_up_time_since_unix_epoch: Duration,
647 now_since_unix_epoch: Option<Duration>,
648 ) -> VMResult<NotificationHandle> {
649 if self.is_processing() {
650 match (&name, now_since_unix_epoch) {
651 (name, Some(now_since_unix_epoch)) if name.is_empty() => {
652 debug!(
653 "Executing 'Timer with duration {:?}'",
654 wake_up_time_since_unix_epoch - now_since_unix_epoch
655 );
656 }
657 (name, Some(now_since_unix_epoch)) => {
658 debug!(
659 "Executing 'Timer {name} with duration {:?}'",
660 wake_up_time_since_unix_epoch - now_since_unix_epoch
661 );
662 }
663 (name, None) if name.is_empty() => {
664 debug!("Executing 'Timer'");
665 }
666 (name, None) => {
667 debug!("Executing 'Timer named {name}'");
668 }
669 }
670 }
671
672 let completion_id = self.context.journal.next_completion_notification_id();
673
674 self.do_transition(SysSimpleCompletableEntry(
675 "SysSleep",
676 SleepCommandMessage {
677 wake_up_time: u64::try_from(wake_up_time_since_unix_epoch.as_millis())
678 .expect("millis since Unix epoch should fit in u64"),
679 result_completion_id: completion_id,
680 name,
681 },
682 completion_id,
683 ))
684 }
685
    /// Calls `target` with `input` and returns the handles of the two notifications of the
    /// call: the one carrying the invocation id and the one carrying the result.
    ///
    /// When implicit cancellation is enabled with `cancel_children_calls`, the invocation id
    /// handle is also added to the tracked invocations.
    ///
    /// # Errors
    ///
    /// Fails when an idempotency key is provided but the negotiated protocol version is older
    /// than V3, when the idempotency key is empty, or when the transition fails.
    #[instrument(
        level = "trace",
        skip(self, input),
        fields(
            restate.invocation.id = self.debug_invocation_id(),
            restate.protocol.state = self.debug_state(),
            restate.journal.command_index = self.context.journal.command_index(),
            restate.protocol.version = %self.version
        ),
        ret
    )]
    fn sys_call(&mut self, target: Target, input: Bytes) -> VMResult<CallHandle> {
        invocation_debug_logs!(
            self,
            "Executing 'Call {}/{}'",
            target.service,
            target.handler
        );
        if let Some(idempotency_key) = &target.idempotency_key {
            self.verify_feature_support("attach idempotency key to call", Version::V3)?;
            if idempotency_key.is_empty() {
                // This code assumes the `HitError` transition always returns `Err`, so `?`
                // returns from here and the `unreachable!` below is never hit.
                self.do_transition(HitError {
                    error: EMPTY_IDEMPOTENCY_KEY,
                    next_retry_delay: None,
                })?;
                unreachable!();
            }
        }

        // Two completion ids: the first for the invocation id, the second for the result.
        let call_invocation_id_completion_id =
            self.context.journal.next_completion_notification_id();
        let result_completion_id = self.context.journal.next_completion_notification_id();

        let handles = self.do_transition(SysCompletableEntryWithMultipleCompletions(
            "SysCall",
            CallCommandMessage {
                service_name: target.service,
                handler_name: target.handler,
                key: target.key.unwrap_or_default(),
                idempotency_key: target.idempotency_key,
                headers: target
                    .headers
                    .into_iter()
                    .map(crate::service_protocol::messages::Header::from)
                    .collect(),
                parameter: input,
                invocation_id_notification_idx: call_invocation_id_completion_id,
                result_completion_id,
                ..Default::default()
            },
            vec![call_invocation_id_completion_id, result_completion_id],
        ))?;

        // `handles` is indexed in the same order as the completion ids passed above.
        if matches!(
            self.options.implicit_cancellation,
            ImplicitCancellationOption::Enabled {
                cancel_children_calls: true,
                ..
            }
        ) {
            self.tracked_invocation_ids.push(TrackedInvocationId {
                handle: handles[0],
                invocation_id: None,
            })
        }

        Ok(CallHandle {
            invocation_id_notification_handle: handles[0],
            call_notification_handle: handles[1],
        })
    }
757
    /// Sends a one-way call to `target` with `input` and returns the handle of the notification
    /// carrying the invocation id.
    ///
    /// `delay`, when provided, is converted to milliseconds and sent as the command
    /// `invoke_time`; when absent the field keeps its default value.
    /// When implicit cancellation is enabled with `cancel_children_one_way_calls`, the
    /// invocation id handle is also added to the tracked invocations.
    ///
    /// # Errors
    ///
    /// Fails when an idempotency key is provided but the negotiated protocol version is older
    /// than V3, when the idempotency key is empty, or when the transition fails.
    ///
    /// # Panics
    ///
    /// Panics if `delay` in milliseconds does not fit in a `u64`.
    #[instrument(
        level = "trace",
        skip(self, input),
        fields(
            restate.invocation.id = self.debug_invocation_id(),
            restate.protocol.state = self.debug_state(),
            restate.journal.command_index = self.context.journal.command_index(),
            restate.protocol.version = %self.version
        ),
        ret
    )]
    fn sys_send(
        &mut self,
        target: Target,
        input: Bytes,
        delay: Option<Duration>,
    ) -> VMResult<SendHandle> {
        invocation_debug_logs!(
            self,
            "Executing 'Send to {}/{}'",
            target.service,
            target.handler
        );
        if let Some(idempotency_key) = &target.idempotency_key {
            self.verify_feature_support("attach idempotency key to one way call", Version::V3)?;
            if idempotency_key.is_empty() {
                // This code assumes the `HitError` transition always returns `Err`, so `?`
                // returns from here and the `unreachable!` below is never hit.
                self.do_transition(HitError {
                    error: EMPTY_IDEMPOTENCY_KEY,
                    next_retry_delay: None,
                })?;
                unreachable!();
            }
        }
        let call_invocation_id_completion_id =
            self.context.journal.next_completion_notification_id();
        let invocation_id_notification_handle = self.do_transition(SysSimpleCompletableEntry(
            "SysOneWayCall",
            OneWayCallCommandMessage {
                service_name: target.service,
                handler_name: target.handler,
                key: target.key.unwrap_or_default(),
                idempotency_key: target.idempotency_key,
                headers: target
                    .headers
                    .into_iter()
                    .map(crate::service_protocol::messages::Header::from)
                    .collect(),
                parameter: input,
                invoke_time: delay
                    .map(|d| {
                        u64::try_from(d.as_millis())
                            .expect("millis since Unix epoch should fit in u64")
                    })
                    .unwrap_or_default(),
                invocation_id_notification_idx: call_invocation_id_completion_id,
                ..Default::default()
            },
            call_invocation_id_completion_id,
        ))?;

        if matches!(
            self.options.implicit_cancellation,
            ImplicitCancellationOption::Enabled {
                cancel_children_one_way_calls: true,
                ..
            }
        ) {
            self.tracked_invocation_ids.push(TrackedInvocationId {
                handle: invocation_id_notification_handle,
                invocation_id: None,
            })
        }

        Ok(SendHandle {
            invocation_id_notification_handle,
        })
    }
835
    /// Creates an awakeable and returns its string identifier together with the notification
    /// handle to await.
    ///
    /// The identifier is built by `awakeable_id_str` from the invocation id found in the start
    /// info and the signal id reserved for this awakeable.
    #[instrument(
        level = "trace",
        skip(self),
        fields(
            restate.invocation.id = self.debug_invocation_id(),
            restate.protocol.state = self.debug_state(),
            restate.journal.command_index = self.context.journal.command_index(),
            restate.protocol.version = %self.version
        ),
        ret
    )]
    fn sys_awakeable(&mut self) -> VMResult<(String, NotificationHandle)> {
        invocation_debug_logs!(self, "Executing 'Create awakeable'");

        let signal_id = self.context.journal.next_signal_notification_id();

        let handle = self.do_transition(CreateSignalHandle(
            "SysAwakeable",
            NotificationId::SignalId(signal_id),
        ))?;

        // NOTE(review): `expect_start_info` presumably panics when the start info is missing;
        // the successful transition above should guarantee it is present — confirm.
        Ok((
            awakeable_id_str(&self.context.expect_start_info().id, signal_id),
            handle,
        ))
    }
862
863 #[instrument(
864 level = "trace",
865 skip(self, value),
866 fields(
867 restate.invocation.id = self.debug_invocation_id(),
868 restate.protocol.state = self.debug_state(),
869 restate.journal.command_index = self.context.journal.command_index(),
870 restate.protocol.version = %self.version
871 ),
872 ret
873 )]
874 fn sys_complete_awakeable(&mut self, id: String, value: NonEmptyValue) -> VMResult<()> {
875 invocation_debug_logs!(self, "Executing 'Complete awakeable {id}'");
876 self.do_transition(SysNonCompletableEntry(
877 "SysCompleteAwakeable",
878 CompleteAwakeableCommandMessage {
879 awakeable_id: id,
880 result: Some(match value {
881 NonEmptyValue::Success(s) => {
882 complete_awakeable_command_message::Result::Value(s.into())
883 }
884 NonEmptyValue::Failure(f) => {
885 complete_awakeable_command_message::Result::Failure(f.into())
886 }
887 }),
888 ..Default::default()
889 },
890 ))
891 }
892
    /// Creates the notification handle for the signal named `signal_name` through a
    /// `CreateSignalHandle` transition.
    #[instrument(
        level = "trace",
        skip(self),
        fields(
            restate.invocation.id = self.debug_invocation_id(),
            restate.protocol.state = self.debug_state(),
            restate.journal.command_index = self.context.journal.command_index(),
            restate.protocol.version = %self.version
        ),
        ret
    )]
    fn create_signal_handle(&mut self, signal_name: String) -> VMResult<NotificationHandle> {
        invocation_debug_logs!(self, "Executing 'Create named signal'");

        self.do_transition(CreateSignalHandle(
            "SysCreateNamedSignal",
            NotificationId::SignalName(signal_name),
        ))
    }
912
913 #[instrument(
914 level = "trace",
915 skip(self, value),
916 fields(
917 restate.invocation.id = self.debug_invocation_id(),
918 restate.protocol.state = self.debug_state(),
919 restate.journal.command_index = self.context.journal.command_index(),
920 restate.protocol.version = %self.version
921 ),
922 ret
923 )]
924 fn sys_complete_signal(
925 &mut self,
926 target_invocation_id: String,
927 signal_name: String,
928 value: NonEmptyValue,
929 ) -> VMResult<()> {
930 invocation_debug_logs!(self, "Executing 'Complete named signal {signal_name}'");
931 self.do_transition(SysNonCompletableEntry(
932 "SysCompleteAwakeable",
933 SendSignalCommandMessage {
934 target_invocation_id,
935 signal_id: Some(send_signal_command_message::SignalId::Name(signal_name)),
936 result: Some(match value {
937 NonEmptyValue::Success(s) => {
938 send_signal_command_message::Result::Value(s.into())
939 }
940 NonEmptyValue::Failure(f) => {
941 send_signal_command_message::Result::Failure(f.into())
942 }
943 }),
944 ..Default::default()
945 },
946 ))
947 }
948
    /// Awaits the promise `key`: writes a `GetPromiseCommandMessage` through a completable entry
    /// transition and returns the notification handle of its result.
    #[instrument(
        level = "trace",
        skip(self),
        fields(
            restate.invocation.id = self.debug_invocation_id(),
            restate.protocol.state = self.debug_state(),
            restate.journal.command_index = self.context.journal.command_index(),
            restate.protocol.version = %self.version
        ),
        ret
    )]
    fn sys_get_promise(&mut self, key: String) -> VMResult<NotificationHandle> {
        invocation_debug_logs!(self, "Executing 'Await promise {key}'");

        // The same completion id is stored in the command and handed to the transition.
        let result_completion_id = self.context.journal.next_completion_notification_id();
        self.do_transition(SysSimpleCompletableEntry(
            "SysGetPromise",
            GetPromiseCommandMessage {
                key,
                result_completion_id,
                ..Default::default()
            },
            result_completion_id,
        ))
    }
974
    /// Peeks the promise `key`: writes a `PeekPromiseCommandMessage` through a completable entry
    /// transition and returns the notification handle of its result.
    #[instrument(
        level = "trace",
        skip(self),
        fields(
            restate.invocation.id = self.debug_invocation_id(),
            restate.protocol.state = self.debug_state(),
            restate.journal.command_index = self.context.journal.command_index(),
            restate.protocol.version = %self.version
        ),
        ret
    )]
    fn sys_peek_promise(&mut self, key: String) -> VMResult<NotificationHandle> {
        invocation_debug_logs!(self, "Executing 'Peek promise {key}'");

        // The same completion id is stored in the command and handed to the transition.
        let result_completion_id = self.context.journal.next_completion_notification_id();
        self.do_transition(SysSimpleCompletableEntry(
            "SysPeekPromise",
            PeekPromiseCommandMessage {
                key,
                result_completion_id,
                ..Default::default()
            },
            result_completion_id,
        ))
    }
1000
1001 #[instrument(
1002 level = "trace",
1003 skip(self, value),
1004 fields(
1005 restate.invocation.id = self.debug_invocation_id(),
1006 restate.protocol.state = self.debug_state(),
1007 restate.journal.command_index = self.context.journal.command_index(),
1008 restate.protocol.version = %self.version
1009 ),
1010 ret
1011 )]
1012 fn sys_complete_promise(
1013 &mut self,
1014 key: String,
1015 value: NonEmptyValue,
1016 ) -> VMResult<NotificationHandle> {
1017 invocation_debug_logs!(self, "Executing 'Complete promise {key}'");
1018
1019 let result_completion_id = self.context.journal.next_completion_notification_id();
1020 self.do_transition(SysSimpleCompletableEntry(
1021 "SysCompletePromise",
1022 CompletePromiseCommandMessage {
1023 key,
1024 completion: Some(match value {
1025 NonEmptyValue::Success(s) => {
1026 complete_promise_command_message::Completion::CompletionValue(s.into())
1027 }
1028 NonEmptyValue::Failure(f) => {
1029 complete_promise_command_message::Completion::CompletionFailure(f.into())
1030 }
1031 }),
1032 result_completion_id,
1033 ..Default::default()
1034 },
1035 result_completion_id,
1036 ))
1037 }
1038
1039 #[instrument(
1040 level = "trace",
1041 skip(self),
1042 fields(
1043 restate.invocation.id = self.debug_invocation_id(),
1044 restate.protocol.state = self.debug_state(),
1045 restate.journal.command_index = self.context.journal.command_index(),
1046 restate.protocol.version = %self.version
1047 ),
1048 ret
1049 )]
1050 fn sys_run(&mut self, name: String) -> VMResult<NotificationHandle> {
1051 match self.do_transition(SysRun(name.clone())) {
1052 Ok(handle) => {
1053 if enabled!(Level::DEBUG) {
1054 self.sys_run_names.insert(handle, name);
1056 }
1057 Ok(handle)
1058 }
1059 Err(e) => Err(e),
1060 }
1061 }
1062
1063 #[instrument(
1064 level = "trace",
1065 skip(self, value, retry_policy),
1066 fields(
1067 restate.invocation.id = self.debug_invocation_id(),
1068 restate.protocol.state = self.debug_state(),
1069 restate.journal.command_index = self.context.journal.command_index(),
1070 restate.protocol.version = %self.version
1071 ),
1072 ret
1073 )]
1074 fn propose_run_completion(
1075 &mut self,
1076 notification_handle: NotificationHandle,
1077 value: RunExitResult,
1078 retry_policy: RetryPolicy,
1079 ) -> VMResult<()> {
1080 if enabled!(Level::DEBUG) {
1081 let name: &str = self
1082 .sys_run_names
1083 .get(¬ification_handle)
1084 .map(String::as_str)
1085 .unwrap_or_default();
1086 match &value {
1087 RunExitResult::Success(_) => {
1088 invocation_debug_logs!(self, "Journaling run '{name}' success result");
1089 }
1090 RunExitResult::TerminalFailure(TerminalFailure { code, .. }) => {
1091 invocation_debug_logs!(
1092 self,
1093 "Journaling run '{name}' terminal failure {code} result"
1094 );
1095 }
1096 RunExitResult::RetryableFailure { .. } => {
1097 invocation_debug_logs!(self, "Propagating run '{name}' retryable failure");
1098 }
1099 }
1100 }
1101
1102 self.do_transition(ProposeRunCompletion(
1103 notification_handle,
1104 value,
1105 retry_policy,
1106 ))
1107 }
1108
    /// Cancels the invocation `target_invocation_id` by sending it the cancel signal
    /// (`CANCEL_SIGNAL_ID`) with a void result.
    ///
    /// # Errors
    ///
    /// Fails when the negotiated protocol version is older than V3 or when the transition fails.
    #[instrument(
        level = "trace",
        skip(self),
        fields(
            restate.invocation.id = self.debug_invocation_id(),
            restate.protocol.state = self.debug_state(),
            restate.journal.command_index = self.context.journal.command_index(),
            restate.protocol.version = %self.version
        ),
        ret
    )]
    fn sys_cancel_invocation(&mut self, target_invocation_id: String) -> VMResult<()> {
        invocation_debug_logs!(
            self,
            "Executing 'Cancel invocation' of {target_invocation_id}"
        );
        self.verify_feature_support("cancel invocation", Version::V3)?;
        self.do_transition(SysNonCompletableEntry(
            "SysCancelInvocation",
            SendSignalCommandMessage {
                target_invocation_id,
                signal_id: Some(send_signal_command_message::SignalId::Idx(CANCEL_SIGNAL_ID)),
                result: Some(send_signal_command_message::Result::Void(Default::default())),
                ..Default::default()
            },
        ))
    }
1136
1137 #[instrument(
1138 level = "trace",
1139 skip(self),
1140 fields(
1141 restate.invocation.id = self.debug_invocation_id(),
1142 restate.protocol.state = self.debug_state(),
1143 restate.journal.command_index = self.context.journal.command_index(),
1144 restate.protocol.version = %self.version
1145 ),
1146 ret
1147 )]
1148 fn sys_attach_invocation(
1149 &mut self,
1150 target: AttachInvocationTarget,
1151 ) -> VMResult<NotificationHandle> {
1152 invocation_debug_logs!(self, "Executing 'Attach invocation'");
1153 self.verify_feature_support("attach invocation", Version::V3)?;
1154
1155 let result_completion_id = self.context.journal.next_completion_notification_id();
1156 self.do_transition(SysSimpleCompletableEntry(
1157 "SysAttachInvocation",
1158 AttachInvocationCommandMessage {
1159 target: Some(match target {
1160 AttachInvocationTarget::InvocationId(id) => {
1161 attach_invocation_command_message::Target::InvocationId(id)
1162 }
1163 AttachInvocationTarget::WorkflowId { name, key } => {
1164 attach_invocation_command_message::Target::WorkflowTarget(WorkflowTarget {
1165 workflow_name: name,
1166 workflow_key: key,
1167 })
1168 }
1169 AttachInvocationTarget::IdempotencyId {
1170 service_name,
1171 service_key,
1172 handler_name,
1173 idempotency_key,
1174 } => attach_invocation_command_message::Target::IdempotentRequestTarget(
1175 IdempotentRequestTarget {
1176 service_name,
1177 service_key,
1178 handler_name,
1179 idempotency_key,
1180 },
1181 ),
1182 }),
1183 result_completion_id,
1184 ..Default::default()
1185 },
1186 result_completion_id,
1187 ))
1188 }
1189
1190 #[instrument(
1191 level = "trace",
1192 skip(self),
1193 fields(
1194 restate.invocation.id = self.debug_invocation_id(),
1195 restate.protocol.state = self.debug_state(),
1196 restate.journal.command_index = self.context.journal.command_index(),
1197 restate.protocol.version = %self.version
1198 ),
1199 ret
1200 )]
1201 fn sys_get_invocation_output(
1202 &mut self,
1203 target: AttachInvocationTarget,
1204 ) -> VMResult<NotificationHandle> {
1205 invocation_debug_logs!(self, "Executing 'Get invocation output'");
1206 self.verify_feature_support("get invocation output", Version::V3)?;
1207
1208 let result_completion_id = self.context.journal.next_completion_notification_id();
1209 self.do_transition(SysSimpleCompletableEntry(
1210 "SysGetInvocationOutput",
1211 GetInvocationOutputCommandMessage {
1212 target: Some(match target {
1213 AttachInvocationTarget::InvocationId(id) => {
1214 get_invocation_output_command_message::Target::InvocationId(id)
1215 }
1216 AttachInvocationTarget::WorkflowId { name, key } => {
1217 get_invocation_output_command_message::Target::WorkflowTarget(
1218 WorkflowTarget {
1219 workflow_name: name,
1220 workflow_key: key,
1221 },
1222 )
1223 }
1224 AttachInvocationTarget::IdempotencyId {
1225 service_name,
1226 service_key,
1227 handler_name,
1228 idempotency_key,
1229 } => get_invocation_output_command_message::Target::IdempotentRequestTarget(
1230 IdempotentRequestTarget {
1231 service_name,
1232 service_key,
1233 handler_name,
1234 idempotency_key,
1235 },
1236 ),
1237 }),
1238 result_completion_id,
1239 ..Default::default()
1240 },
1241 result_completion_id,
1242 ))
1243 }
1244
1245 #[instrument(
1246 level = "trace",
1247 skip(self, value),
1248 fields(
1249 restate.invocation.id = self.debug_invocation_id(),
1250 restate.protocol.state = self.debug_state(),
1251 restate.journal.command_index = self.context.journal.command_index(),
1252 restate.protocol.version = %self.version
1253 ),
1254 ret
1255 )]
1256 fn sys_write_output(&mut self, value: NonEmptyValue) -> Result<(), Error> {
1257 match &value {
1258 NonEmptyValue::Success(_) => {
1259 invocation_debug_logs!(self, "Writing invocation result success value");
1260 }
1261 NonEmptyValue::Failure(_) => {
1262 invocation_debug_logs!(self, "Writing invocation result failure value");
1263 }
1264 }
1265 self.do_transition(SysNonCompletableEntry(
1266 "SysWriteOutput",
1267 OutputCommandMessage {
1268 result: Some(match value {
1269 NonEmptyValue::Success(b) => output_command_message::Result::Value(b.into()),
1270 NonEmptyValue::Failure(f) => output_command_message::Result::Failure(f.into()),
1271 }),
1272 ..OutputCommandMessage::default()
1273 },
1274 ))
1275 }
1276
1277 #[instrument(
1278 level = "trace",
1279 skip(self),
1280 fields(
1281 restate.invocation.id = self.debug_invocation_id(),
1282 restate.protocol.state = self.debug_state(),
1283 restate.journal.command_index = self.context.journal.command_index(),
1284 restate.protocol.version = %self.version
1285 ),
1286 ret
1287 )]
1288 fn sys_end(&mut self) -> Result<(), Error> {
1289 invocation_debug_logs!(self, "End of the invocation");
1290 self.do_transition(SysEnd)
1291 }
1292
1293 fn is_waiting_preflight(&self) -> bool {
1294 matches!(
1295 &self.last_transition,
1296 Ok(State::WaitingStart) | Ok(State::WaitingReplayEntries { .. })
1297 )
1298 }
1299
1300 fn is_replaying(&self) -> bool {
1301 matches!(&self.last_transition, Ok(State::Replaying { .. }))
1302 }
1303
1304 fn is_processing(&self) -> bool {
1305 matches!(&self.last_transition, Ok(State::Processing { .. }))
1306 }
1307}
1308
1309const INDIFFERENT_PAD: GeneralPurposeConfig = GeneralPurposeConfig::new()
1310 .with_decode_padding_mode(DecodePaddingMode::Indifferent)
1311 .with_encode_padding(false);
1312const URL_SAFE: GeneralPurpose = GeneralPurpose::new(&alphabet::URL_SAFE, INDIFFERENT_PAD);
1313
1314const AWAKEABLE_PREFIX: &str = "sign_1";
1315
1316fn awakeable_id_str(id: &[u8], completion_index: u32) -> String {
1317 let mut input_buf = BytesMut::with_capacity(id.len() + size_of::<u32>());
1318 input_buf.put_slice(id);
1319 input_buf.put_u32(completion_index);
1320 format!("{AWAKEABLE_PREFIX}{}", URL_SAFE.encode(input_buf.freeze()))
1321}