1use crate::headers::HeaderMap;
2use crate::service_protocol::messages::{
3 attach_invocation_command_message, complete_awakeable_command_message,
4 complete_promise_command_message, get_invocation_output_command_message,
5 output_command_message, send_signal_command_message, AttachInvocationCommandMessage,
6 CallCommandMessage, ClearAllStateCommandMessage, ClearStateCommandMessage,
7 CompleteAwakeableCommandMessage, CompletePromiseCommandMessage,
8 GetInvocationOutputCommandMessage, GetPromiseCommandMessage, IdempotentRequestTarget,
9 OneWayCallCommandMessage, OutputCommandMessage, PeekPromiseCommandMessage,
10 SendSignalCommandMessage, SetStateCommandMessage, SleepCommandMessage, WorkflowTarget,
11};
12use crate::service_protocol::{Decoder, NotificationId, RawMessage, Version, CANCEL_SIGNAL_ID};
13use crate::vm::errors::{
14 ClosedError, UnexpectedStateError, UnsupportedFeatureForNegotiatedVersion,
15 EMPTY_IDEMPOTENCY_KEY, SUSPENDED,
16};
17use crate::vm::transitions::*;
18use crate::{
19 AttachInvocationTarget, CallHandle, CommandRelationship, DoProgressResponse, Error, Header,
20 ImplicitCancellationOption, Input, NonDeterministicChecksOption, NonEmptyValue,
21 NotificationHandle, ResponseHead, RetryPolicy, RunExitResult, SendHandle, TakeOutputResult,
22 Target, TerminalFailure, VMOptions, VMResult, Value, CANCEL_NOTIFICATION_HANDLE,
23};
24use base64::engine::{DecodePaddingMode, GeneralPurpose, GeneralPurposeConfig};
25use base64::{alphabet, Engine};
26use bytes::{Buf, BufMut, Bytes, BytesMut};
27use context::{AsyncResultsState, Context, Output, RunState};
28use std::borrow::Cow;
29use std::collections::{HashMap, VecDeque};
30use std::mem::size_of;
31use std::time::Duration;
32use std::{fmt, mem};
33use strum::IntoStaticStr;
34use tracing::{debug, enabled, instrument, Level};
35
36mod context;
37pub(crate) mod errors;
38mod transitions;
39
/// Name of the header that `CoreVM::new` reads to obtain the protocol version and that
/// `get_response_head` writes back in the response.
const CONTENT_TYPE: &str = "content-type";
41
/// States stored in `CoreVM::last_transition`.
///
/// `IntoStaticStr` exposes each variant name as a `&'static str`; that name is used in debug
/// output and when building an `UnexpectedStateError`.
#[derive(Debug, IntoStaticStr)]
pub(crate) enum State {
    /// State a freshly constructed `CoreVM` starts in; `is_ready_to_execute` reports `false`.
    WaitingStart,
    /// `is_ready_to_execute` reports `false` in this state.
    // NOTE(review): presumably the VM is still receiving the journal entries to replay here;
    // the transition logic lives in the `transitions` module — confirm there.
    WaitingReplayEntries {
        received_entries: u32,
        commands: VecDeque<RawMessage>,
        async_results: AsyncResultsState,
    },
    /// `is_ready_to_execute` reports `true`; still carries a queue of `commands`.
    Replaying {
        commands: VecDeque<RawMessage>,
        run_state: RunState,
        async_results: AsyncResultsState,
    },
    /// `is_ready_to_execute` reports `true`; no queued commands are carried anymore.
    Processing {
        processing_first_entry: bool,
        run_state: RunState,
        async_results: AsyncResultsState,
    },
    /// Closed VM; `as_unexpected_state` maps this state to a `ClosedError`.
    Closed,
}
62
63impl State {
64 fn as_unexpected_state(&self, event: &'static str) -> Error {
65 if matches!(self, State::Closed) {
66 return ClosedError::new(event).into();
67 }
68 UnexpectedStateError::new(self.into(), event).into()
69 }
70}
71
/// Bookkeeping entry for a child call/send tracked for implicit cancellation.
struct TrackedInvocationId {
    /// Notification handle under which the child's invocation id is delivered.
    handle: NotificationHandle,
    /// The invocation id, once it has been observed; `None` until then.
    invocation_id: Option<String>,
}
76
impl TrackedInvocationId {
    /// Returns `true` once the invocation id has been recorded for this entry.
    fn is_resolved(&self) -> bool {
        self.invocation_id.is_some()
    }
}
82
/// Implementation of the `VM` trait, driven by a state machine kept in `last_transition`.
pub struct CoreVM {
    /// Options passed to `new`; consulted for the implicit cancellation settings.
    options: VMOptions,

    /// Receives the bytes pushed through `notify_input` and yields decoded messages.
    decoder: Decoder,

    /// Shared invocation context (output buffer, journal indexes, eager state, version).
    context: Context,
    /// Outcome of the most recent transition: the current state, or the error that was hit.
    last_transition: Result<State, Error>,

    /// Invocation-id handles of calls/sends registered for implicit cancellation.
    /// `take_notification` looks entries up with a binary search on `handle`.
    // NOTE(review): this relies on handles being pushed in ascending order — confirm.
    tracked_invocation_ids: Vec<TrackedInvocationId>,

    /// Names of `sys_run` blocks by handle; only filled when DEBUG logging is enabled and
    /// only read to build log messages in `propose_run_completion`.
    sys_run_names: HashMap<NotificationHandle, String>,
}
99
100impl CoreVM {
101 fn debug_invocation_id(&self) -> &str {
103 if let Some(start_info) = self.context.start_info() {
104 &start_info.debug_id
105 } else {
106 ""
107 }
108 }
109
110 fn debug_state(&self) -> &'static str {
111 match &self.last_transition {
112 Ok(s) => s.into(),
113 Err(_) => "Failed",
114 }
115 }
116
117 #[allow(dead_code)]
118 fn verify_feature_support(
119 &mut self,
120 feature: &'static str,
121 minimum_required_protocol: Version,
122 ) -> VMResult<()> {
123 if self.context.negotiated_protocol_version < minimum_required_protocol {
124 return self.do_transition(HitError(
125 UnsupportedFeatureForNegotiatedVersion::new(
126 feature,
127 self.context.negotiated_protocol_version,
128 minimum_required_protocol,
129 )
130 .into(),
131 ));
132 }
133 Ok(())
134 }
135
136 fn _is_completed(&self, handle: NotificationHandle) -> bool {
137 match &self.last_transition {
138 Ok(State::Replaying { async_results, .. })
139 | Ok(State::Processing { async_results, .. }) => {
140 async_results.is_handle_completed(handle)
141 }
142 _ => false,
143 }
144 }
145
146 fn _do_progress(
147 &mut self,
148 any_handle: Vec<NotificationHandle>,
149 ) -> Result<DoProgressResponse, Error> {
150 match self.do_transition(DoProgress(any_handle)) {
151 Ok(Ok(do_progress_response)) => Ok(do_progress_response),
152 Ok(Err(_)) => Err(SUSPENDED),
153 Err(e) => Err(e),
154 }
155 }
156
157 fn is_implicit_cancellation_enabled(&self) -> bool {
158 matches!(
159 self.options.implicit_cancellation,
160 ImplicitCancellationOption::Enabled { .. }
161 )
162 }
163}
164
165impl fmt::Debug for CoreVM {
166 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
167 let mut s = f.debug_struct("CoreVM");
168 s.field("version", &self.context.negotiated_protocol_version);
169
170 if let Some(start_info) = self.context.start_info() {
171 s.field("invocation_id", &start_info.debug_id);
172 }
173
174 match &self.last_transition {
175 Ok(state) => s.field("last_transition", &<&'static str>::from(state)),
176 Err(_) => s.field("last_transition", &"Errored"),
177 };
178
179 s.field("command_index", &self.context.journal.command_index())
180 .field(
181 "notification_index",
182 &self.context.journal.notification_index(),
183 )
184 .finish()
185 }
186}
187
// Compile-time assertion that `CoreVM` is `Send`: instantiating `is_send::<CoreVM>` in a const
// context fails to compile if the bound is not satisfied.
// The helper is only referenced by the const item below, hence the `allow`.
#[allow(unused)]
const fn is_send<T: Send>() {}
const _: () = is_send::<CoreVM>();
192
// Emits a `tracing::debug!` event with the given format arguments, but only when
// `$this.is_processing()` returns true.
macro_rules! invocation_debug_logs {
    ($this:expr, $($arg:tt)*) => {
        if $this.is_processing() {
            tracing::debug!($($arg)*)
        }
    };
}
201
202impl super::VM for CoreVM {
203 #[instrument(level = "trace", skip(request_headers), ret)]
204 fn new(request_headers: impl HeaderMap, options: VMOptions) -> Result<Self, Error> {
205 let version = request_headers
206 .extract(CONTENT_TYPE)
207 .map_err(|e| {
208 Error::new(
209 errors::codes::BAD_REQUEST,
210 format!("cannot read '{CONTENT_TYPE}' header: {e:?}"),
211 )
212 })?
213 .ok_or(errors::MISSING_CONTENT_TYPE)?
214 .parse::<Version>()?;
215
216 if version < Version::minimum_supported_version()
217 || version > Version::maximum_supported_version()
218 {
219 return Err(Error::new(
220 errors::codes::UNSUPPORTED_MEDIA_TYPE,
221 format!(
222 "Unsupported protocol version {:?}, not within [{:?} to {:?}]. \
223 You might need to rediscover the service, check https://docs.restate.dev/references/errors/#RT0015",
224 version,
225 Version::minimum_supported_version(),
226 Version::maximum_supported_version()
227 ),
228 ));
229 }
230 let non_deterministic_checks_ignore_payload_equality = matches!(
231 options.non_determinism_checks,
232 NonDeterministicChecksOption::PayloadChecksDisabled
233 );
234
235 Ok(Self {
236 options,
237 decoder: Decoder::new(version),
238 context: Context {
239 input_is_closed: false,
240 output: Output::new(version),
241 start_info: None,
242 journal: Default::default(),
243 eager_state: Default::default(),
244 non_deterministic_checks_ignore_payload_equality,
245 negotiated_protocol_version: version,
246 },
247 last_transition: Ok(State::WaitingStart),
248 tracked_invocation_ids: vec![],
249 sys_run_names: HashMap::with_capacity(0),
250 })
251 }
252
253 #[instrument(
254 level = "trace",
255 skip(self),
256 fields(
257 restate.invocation.id = self.debug_invocation_id(),
258 restate.protocol.state = self.debug_state(),
259 restate.journal.command_index = self.context.journal.command_index(),
260 restate.protocol.version = %self.context.negotiated_protocol_version
261 ),
262 ret
263 )]
264 fn get_response_head(&self) -> ResponseHead {
265 ResponseHead {
266 status_code: 200,
267 headers: vec![Header {
268 key: Cow::Borrowed(CONTENT_TYPE),
269 value: Cow::Borrowed(self.context.negotiated_protocol_version.content_type()),
270 }],
271 version: self.context.negotiated_protocol_version,
272 }
273 }
274
275 #[instrument(
276 level = "trace",
277 skip(self),
278 fields(
279 restate.invocation.id = self.debug_invocation_id(),
280 restate.protocol.state = self.debug_state(),
281 restate.journal.command_index = self.context.journal.command_index(),
282 restate.protocol.version = %self.context.negotiated_protocol_version
283 ),
284 ret
285 )]
286 fn notify_input(&mut self, buffer: Bytes) {
287 self.decoder.push(buffer);
288 loop {
289 match self.decoder.consume_next() {
290 Ok(Some(msg)) => {
291 if self.do_transition(NewMessage(msg)).is_err() {
292 return;
293 }
294 }
295 Ok(None) => {
296 return;
297 }
298 Err(e) => {
299 if self.do_transition(HitError(e.into())).is_err() {
300 return;
301 }
302 }
303 }
304 }
305 }
306
    /// Marks the input as closed in the context and runs the `NotifyInputClosed` transition.
    #[instrument(
        level = "trace",
        skip(self),
        fields(
            restate.invocation.id = self.debug_invocation_id(),
            restate.protocol.state = self.debug_state(),
            restate.journal.command_index = self.context.journal.command_index(),
            restate.protocol.version = %self.context.negotiated_protocol_version
        ),
        ret
    )]
    fn notify_input_closed(&mut self) {
        self.context.input_is_closed = true;
        // The transition result is deliberately discarded: this method has no way to report it.
        let _ = self.do_transition(NotifyInputClosed);
    }
322
323 #[instrument(
324 level = "trace",
325 skip(self),
326 fields(
327 restate.invocation.id = self.debug_invocation_id(),
328 restate.protocol.state = self.debug_state(),
329 restate.journal.command_index = self.context.journal.command_index(),
330 restate.protocol.version = %self.context.negotiated_protocol_version
331 ),
332 ret
333 )]
334 fn notify_error(
335 &mut self,
336 mut error: Error,
337 command_relationship: Option<CommandRelationship>,
338 ) {
339 if let Some(command_relationship) = command_relationship {
340 error = error.with_related_command_metadata(
341 self.context
342 .journal
343 .resolve_related_command(command_relationship),
344 );
345 }
346
347 let _ = self.do_transition(HitError(error));
348 }
349
350 #[instrument(
351 level = "trace",
352 skip(self),
353 fields(
354 restate.invocation.id = self.debug_invocation_id(),
355 restate.protocol.state = self.debug_state(),
356 restate.journal.command_index = self.context.journal.command_index(),
357 restate.protocol.version = %self.context.negotiated_protocol_version
358 ),
359 ret
360 )]
361 fn take_output(&mut self) -> TakeOutputResult {
362 if self.context.output.buffer.has_remaining() {
363 TakeOutputResult::Buffer(
364 self.context
365 .output
366 .buffer
367 .copy_to_bytes(self.context.output.buffer.remaining()),
368 )
369 } else if !self.context.output.is_closed() {
370 TakeOutputResult::Buffer(Bytes::default())
371 } else {
372 TakeOutputResult::EOF
373 }
374 }
375
376 #[instrument(
377 level = "trace",
378 skip(self),
379 fields(
380 restate.invocation.id = self.debug_invocation_id(),
381 restate.protocol.state = self.debug_state(),
382 restate.journal.command_index = self.context.journal.command_index(),
383 restate.protocol.version = %self.context.negotiated_protocol_version
384 ),
385 ret
386 )]
387 fn is_ready_to_execute(&self) -> Result<bool, Error> {
388 match &self.last_transition {
389 Ok(State::WaitingStart) | Ok(State::WaitingReplayEntries { .. }) => Ok(false),
390 Ok(State::Processing { .. }) | Ok(State::Replaying { .. }) => Ok(true),
391 Ok(s) => Err(s.as_unexpected_state("IsReadyToExecute")),
392 Err(e) => Err(e.clone()),
393 }
394 }
395
    /// Returns whether `handle` is completed; delegates to `CoreVM::_is_completed`, which
    /// answers `false` outside of the `Replaying` and `Processing` states.
    #[instrument(
        level = "trace",
        skip(self),
        fields(
            restate.invocation.id = self.debug_invocation_id(),
            restate.protocol.state = self.debug_state(),
            restate.journal.command_index = self.context.journal.command_index(),
            restate.protocol.version = %self.context.negotiated_protocol_version
        ),
        ret
    )]
    fn is_completed(&self, handle: NotificationHandle) -> bool {
        self._is_completed(handle)
    }
410
411 #[instrument(
412 level = "trace",
413 skip(self),
414 fields(
415 restate.invocation.id = self.debug_invocation_id(),
416 restate.protocol.state = self.debug_state(),
417 restate.journal.command_index = self.context.journal.command_index(),
418 restate.protocol.version = %self.context.negotiated_protocol_version
419 ),
420 ret
421 )]
422 fn do_progress(
423 &mut self,
424 mut any_handle: Vec<NotificationHandle>,
425 ) -> VMResult<DoProgressResponse> {
426 if self.is_implicit_cancellation_enabled() {
427 any_handle.insert(0, CANCEL_NOTIFICATION_HANDLE);
429
430 match self._do_progress(any_handle) {
431 Ok(DoProgressResponse::AnyCompleted) => {
432 if self._is_completed(CANCEL_NOTIFICATION_HANDLE) {
434 for i in 0..self.tracked_invocation_ids.len() {
436 if self.tracked_invocation_ids[i].is_resolved() {
437 continue;
438 }
439
440 let handle = self.tracked_invocation_ids[i].handle;
441
442 match self._do_progress(vec![handle]) {
444 Ok(DoProgressResponse::AnyCompleted) => {
445 let invocation_id = match self.do_transition(CopyNotification(handle)) {
446 Ok(Ok(Some(Value::InvocationId(invocation_id)))) => Ok(invocation_id),
447 Ok(Err(_)) => Err(SUSPENDED),
448 _ => panic!("Unexpected variant! If the id handle is completed, it must be an invocation id handle!")
449 }?;
450
451 self.tracked_invocation_ids[i].invocation_id =
453 Some(invocation_id);
454 }
455 res => return res,
456 }
457 }
458
459 for tracked_invocation_id in mem::take(&mut self.tracked_invocation_ids) {
461 self.sys_cancel_invocation(
462 tracked_invocation_id
463 .invocation_id
464 .expect("We resolved before all the invocation ids"),
465 )?;
466 }
467
468 let _ = self.take_notification(CANCEL_NOTIFICATION_HANDLE);
470
471 Ok(DoProgressResponse::CancelSignalReceived)
473 } else {
474 Ok(DoProgressResponse::AnyCompleted)
475 }
476 }
477 res => res,
478 }
479 } else {
480 self._do_progress(any_handle)
481 }
482 }
483
484 #[instrument(
485 level = "trace",
486 skip(self),
487 fields(
488 restate.invocation.id = self.debug_invocation_id(),
489 restate.protocol.state = self.debug_state(),
490 restate.journal.command_index = self.context.journal.command_index(),
491 restate.protocol.version = %self.context.negotiated_protocol_version
492 ),
493 ret
494 )]
495 fn take_notification(&mut self, handle: NotificationHandle) -> VMResult<Option<Value>> {
496 match self.do_transition(TakeNotification(handle)) {
497 Ok(Ok(Some(value))) => {
498 if self.is_implicit_cancellation_enabled() {
499 if let Ok(found) = self
502 .tracked_invocation_ids
503 .binary_search_by(|tracked| tracked.handle.cmp(&handle))
504 {
505 let Value::InvocationId(invocation_id) = &value else {
506 panic!("Expecting an invocation id here, but got {value:?}");
507 };
508 self.tracked_invocation_ids
510 .get_mut(found)
511 .unwrap()
512 .invocation_id = Some(invocation_id.clone());
513 }
514 }
515
516 Ok(Some(value))
517 }
518 Ok(Ok(None)) => Ok(None),
519 Ok(Err(_)) => Err(SUSPENDED),
520 Err(e) => Err(e),
521 }
522 }
523
    /// Runs the `SysInput` transition and returns the `Input` it produces.
    ///
    /// # Errors
    /// Propagates the error returned by the transition.
    #[instrument(
        level = "trace",
        skip(self),
        fields(
            restate.invocation.id = self.debug_invocation_id(),
            restate.protocol.state = self.debug_state(),
            restate.journal.command_index = self.context.journal.command_index(),
            restate.protocol.version = %self.context.negotiated_protocol_version
        ),
        ret
    )]
    fn sys_input(&mut self) -> Result<Input, Error> {
        self.do_transition(SysInput)
    }
538
    /// Requests the state value stored under `key` through the `SysStateGet` transition.
    ///
    /// Returns the notification handle produced by the transition.
    ///
    /// # Errors
    /// Propagates the error returned by the transition.
    #[instrument(
        level = "trace",
        skip(self),
        fields(
            restate.invocation.id = self.debug_invocation_id(),
            restate.protocol.state = self.debug_state(),
            restate.journal.command_index = self.context.journal.command_index(),
            restate.protocol.version = %self.context.negotiated_protocol_version
        ),
        ret
    )]
    fn sys_state_get(&mut self, key: String) -> Result<NotificationHandle, Error> {
        invocation_debug_logs!(self, "Executing 'Get state {key}'");
        self.do_transition(SysStateGet(key))
    }
554
    /// Requests the state keys through the `SysStateGetKeys` transition.
    ///
    /// Returns the notification handle produced by the transition.
    ///
    /// # Errors
    /// Propagates the error returned by the transition.
    #[instrument(
        level = "trace",
        skip(self),
        fields(
            restate.invocation.id = self.debug_invocation_id(),
            restate.protocol.state = self.debug_state(),
            restate.journal.command_index = self.context.journal.command_index(),
            restate.protocol.version = %self.context.negotiated_protocol_version
        ),
        ret
    )]
    fn sys_state_get_keys(&mut self) -> VMResult<NotificationHandle> {
        invocation_debug_logs!(self, "Executing 'Get state keys'");
        self.do_transition(SysStateGetKeys)
    }
570
571 #[instrument(
572 level = "trace",
573 skip(self, value),
574 fields(
575 restate.invocation.id = self.debug_invocation_id(),
576 restate.protocol.state = self.debug_state(),
577 restate.journal.command_index = self.context.journal.command_index(),
578 restate.protocol.version = %self.context.negotiated_protocol_version
579 ),
580 ret
581 )]
582 fn sys_state_set(&mut self, key: String, value: Bytes) -> Result<(), Error> {
583 invocation_debug_logs!(self, "Executing 'Set state {key}'");
584 self.context.eager_state.set(key.clone(), value.clone());
585 self.do_transition(SysNonCompletableEntry(
586 "SysStateSet",
587 SetStateCommandMessage {
588 key: Bytes::from(key.into_bytes()),
589 value: Some(value.into()),
590 ..SetStateCommandMessage::default()
591 },
592 ))
593 }
594
595 #[instrument(
596 level = "trace",
597 skip(self),
598 fields(
599 restate.invocation.id = self.debug_invocation_id(),
600 restate.protocol.state = self.debug_state(),
601 restate.journal.command_index = self.context.journal.command_index(),
602 restate.protocol.version = %self.context.negotiated_protocol_version
603 ),
604 ret
605 )]
606 fn sys_state_clear(&mut self, key: String) -> Result<(), Error> {
607 invocation_debug_logs!(self, "Executing 'Clear state {key}'");
608 self.context.eager_state.clear(key.clone());
609 self.do_transition(SysNonCompletableEntry(
610 "SysStateClear",
611 ClearStateCommandMessage {
612 key: Bytes::from(key.into_bytes()),
613 ..ClearStateCommandMessage::default()
614 },
615 ))
616 }
617
618 #[instrument(
619 level = "trace",
620 skip(self),
621 fields(
622 restate.invocation.id = self.debug_invocation_id(),
623 restate.protocol.state = self.debug_state(),
624 restate.journal.command_index = self.context.journal.command_index(),
625 restate.protocol.version = %self.context.negotiated_protocol_version
626 ),
627 ret
628 )]
629 fn sys_state_clear_all(&mut self) -> Result<(), Error> {
630 invocation_debug_logs!(self, "Executing 'Clear all state'");
631 self.context.eager_state.clear_all();
632 self.do_transition(SysNonCompletableEntry(
633 "SysStateClearAll",
634 ClearAllStateCommandMessage::default(),
635 ))
636 }
637
638 #[instrument(
639 level = "trace",
640 skip(self),
641 fields(
642 restate.invocation.id = self.debug_invocation_id(),
643 restate.protocol.state = self.debug_state(),
644 restate.journal.command_index = self.context.journal.command_index(),
645 restate.protocol.version = %self.context.negotiated_protocol_version
646 ),
647 ret
648 )]
649 fn sys_sleep(
650 &mut self,
651 name: String,
652 wake_up_time_since_unix_epoch: Duration,
653 now_since_unix_epoch: Option<Duration>,
654 ) -> VMResult<NotificationHandle> {
655 if self.is_processing() {
656 match (&name, now_since_unix_epoch) {
657 (name, Some(now_since_unix_epoch)) if name.is_empty() => {
658 debug!(
659 "Executing 'Timer with duration {:?}'",
660 wake_up_time_since_unix_epoch - now_since_unix_epoch
661 );
662 }
663 (name, Some(now_since_unix_epoch)) => {
664 debug!(
665 "Executing 'Timer {name} with duration {:?}'",
666 wake_up_time_since_unix_epoch - now_since_unix_epoch
667 );
668 }
669 (name, None) if name.is_empty() => {
670 debug!("Executing 'Timer'");
671 }
672 (name, None) => {
673 debug!("Executing 'Timer named {name}'");
674 }
675 }
676 }
677
678 let completion_id = self.context.journal.next_completion_notification_id();
679
680 self.do_transition(SysSimpleCompletableEntry(
681 "SysSleep",
682 SleepCommandMessage {
683 wake_up_time: u64::try_from(wake_up_time_since_unix_epoch.as_millis())
684 .expect("millis since Unix epoch should fit in u64"),
685 result_completion_id: completion_id,
686 name,
687 },
688 completion_id,
689 ))
690 }
691
692 #[instrument(
693 level = "trace",
694 skip(self, input),
695 fields(
696 restate.invocation.id = self.debug_invocation_id(),
697 restate.protocol.state = self.debug_state(),
698 restate.journal.command_index = self.context.journal.command_index(),
699 restate.protocol.version = %self.context.negotiated_protocol_version
700 ),
701 ret
702 )]
703 fn sys_call(&mut self, target: Target, input: Bytes) -> VMResult<CallHandle> {
704 invocation_debug_logs!(
705 self,
706 "Executing 'Call {}/{}'",
707 target.service,
708 target.handler
709 );
710 if let Some(idempotency_key) = &target.idempotency_key {
711 if idempotency_key.is_empty() {
712 self.do_transition(HitError(EMPTY_IDEMPOTENCY_KEY))?;
713 unreachable!();
714 }
715 }
716
717 let call_invocation_id_completion_id =
718 self.context.journal.next_completion_notification_id();
719 let result_completion_id = self.context.journal.next_completion_notification_id();
720
721 let handles = self.do_transition(SysCompletableEntryWithMultipleCompletions(
722 "SysCall",
723 CallCommandMessage {
724 service_name: target.service,
725 handler_name: target.handler,
726 key: target.key.unwrap_or_default(),
727 idempotency_key: target.idempotency_key,
728 headers: target
729 .headers
730 .into_iter()
731 .map(crate::service_protocol::messages::Header::from)
732 .collect(),
733 parameter: input,
734 invocation_id_notification_idx: call_invocation_id_completion_id,
735 result_completion_id,
736 ..Default::default()
737 },
738 vec![call_invocation_id_completion_id, result_completion_id],
739 ))?;
740
741 if matches!(
742 self.options.implicit_cancellation,
743 ImplicitCancellationOption::Enabled {
744 cancel_children_calls: true,
745 ..
746 }
747 ) {
748 self.tracked_invocation_ids.push(TrackedInvocationId {
749 handle: handles[0],
750 invocation_id: None,
751 })
752 }
753
754 Ok(CallHandle {
755 invocation_id_notification_handle: handles[0],
756 call_notification_handle: handles[1],
757 })
758 }
759
760 #[instrument(
761 level = "trace",
762 skip(self, input),
763 fields(
764 restate.invocation.id = self.debug_invocation_id(),
765 restate.protocol.state = self.debug_state(),
766 restate.journal.command_index = self.context.journal.command_index(),
767 restate.protocol.version = %self.context.negotiated_protocol_version
768 ),
769 ret
770 )]
771 fn sys_send(
772 &mut self,
773 target: Target,
774 input: Bytes,
775 delay: Option<Duration>,
776 ) -> VMResult<SendHandle> {
777 invocation_debug_logs!(
778 self,
779 "Executing 'Send to {}/{}'",
780 target.service,
781 target.handler
782 );
783 if let Some(idempotency_key) = &target.idempotency_key {
784 if idempotency_key.is_empty() {
785 self.do_transition(HitError(EMPTY_IDEMPOTENCY_KEY))?;
786 unreachable!();
787 }
788 }
789 let call_invocation_id_completion_id =
790 self.context.journal.next_completion_notification_id();
791 let invocation_id_notification_handle = self.do_transition(SysSimpleCompletableEntry(
792 "SysOneWayCall",
793 OneWayCallCommandMessage {
794 service_name: target.service,
795 handler_name: target.handler,
796 key: target.key.unwrap_or_default(),
797 idempotency_key: target.idempotency_key,
798 headers: target
799 .headers
800 .into_iter()
801 .map(crate::service_protocol::messages::Header::from)
802 .collect(),
803 parameter: input,
804 invoke_time: delay
805 .map(|d| {
806 u64::try_from(d.as_millis())
807 .expect("millis since Unix epoch should fit in u64")
808 })
809 .unwrap_or_default(),
810 invocation_id_notification_idx: call_invocation_id_completion_id,
811 ..Default::default()
812 },
813 call_invocation_id_completion_id,
814 ))?;
815
816 if matches!(
817 self.options.implicit_cancellation,
818 ImplicitCancellationOption::Enabled {
819 cancel_children_one_way_calls: true,
820 ..
821 }
822 ) {
823 self.tracked_invocation_ids.push(TrackedInvocationId {
824 handle: invocation_id_notification_handle,
825 invocation_id: None,
826 })
827 }
828
829 Ok(SendHandle {
830 invocation_id_notification_handle,
831 })
832 }
833
834 #[instrument(
835 level = "trace",
836 skip(self),
837 fields(
838 restate.invocation.id = self.debug_invocation_id(),
839 restate.protocol.state = self.debug_state(),
840 restate.journal.command_index = self.context.journal.command_index(),
841 restate.protocol.version = %self.context.negotiated_protocol_version
842 ),
843 ret
844 )]
845 fn sys_awakeable(&mut self) -> VMResult<(String, NotificationHandle)> {
846 invocation_debug_logs!(self, "Executing 'Create awakeable'");
847
848 let signal_id = self.context.journal.next_signal_notification_id();
849
850 let handle = self.do_transition(CreateSignalHandle(
851 "SysAwakeable",
852 NotificationId::SignalId(signal_id),
853 ))?;
854
855 Ok((
856 awakeable_id_str(&self.context.expect_start_info().id, signal_id),
857 handle,
858 ))
859 }
860
861 #[instrument(
862 level = "trace",
863 skip(self, value),
864 fields(
865 restate.invocation.id = self.debug_invocation_id(),
866 restate.protocol.state = self.debug_state(),
867 restate.journal.command_index = self.context.journal.command_index(),
868 restate.protocol.version = %self.context.negotiated_protocol_version
869 ),
870 ret
871 )]
872 fn sys_complete_awakeable(&mut self, id: String, value: NonEmptyValue) -> VMResult<()> {
873 invocation_debug_logs!(self, "Executing 'Complete awakeable {id}'");
874 self.do_transition(SysNonCompletableEntry(
875 "SysCompleteAwakeable",
876 CompleteAwakeableCommandMessage {
877 awakeable_id: id,
878 result: Some(match value {
879 NonEmptyValue::Success(s) => {
880 complete_awakeable_command_message::Result::Value(s.into())
881 }
882 NonEmptyValue::Failure(f) => {
883 complete_awakeable_command_message::Result::Failure(f.into())
884 }
885 }),
886 ..Default::default()
887 },
888 ))
889 }
890
891 #[instrument(
892 level = "trace",
893 skip(self),
894 fields(
895 restate.invocation.id = self.debug_invocation_id(),
896 restate.protocol.state = self.debug_state(),
897 restate.journal.command_index = self.context.journal.command_index(),
898 restate.protocol.version = %self.context.negotiated_protocol_version
899 ),
900 ret
901 )]
902 fn create_signal_handle(&mut self, signal_name: String) -> VMResult<NotificationHandle> {
903 invocation_debug_logs!(self, "Executing 'Create named signal'");
904
905 self.do_transition(CreateSignalHandle(
906 "SysCreateNamedSignal",
907 NotificationId::SignalName(signal_name),
908 ))
909 }
910
911 #[instrument(
912 level = "trace",
913 skip(self, value),
914 fields(
915 restate.invocation.id = self.debug_invocation_id(),
916 restate.protocol.state = self.debug_state(),
917 restate.journal.command_index = self.context.journal.command_index(),
918 restate.protocol.version = %self.context.negotiated_protocol_version
919 ),
920 ret
921 )]
922 fn sys_complete_signal(
923 &mut self,
924 target_invocation_id: String,
925 signal_name: String,
926 value: NonEmptyValue,
927 ) -> VMResult<()> {
928 invocation_debug_logs!(self, "Executing 'Complete named signal {signal_name}'");
929 self.do_transition(SysNonCompletableEntry(
930 "SysCompleteAwakeable",
931 SendSignalCommandMessage {
932 target_invocation_id,
933 signal_id: Some(send_signal_command_message::SignalId::Name(signal_name)),
934 result: Some(match value {
935 NonEmptyValue::Success(s) => {
936 send_signal_command_message::Result::Value(s.into())
937 }
938 NonEmptyValue::Failure(f) => {
939 send_signal_command_message::Result::Failure(f.into())
940 }
941 }),
942 ..Default::default()
943 },
944 ))
945 }
946
947 #[instrument(
948 level = "trace",
949 skip(self),
950 fields(
951 restate.invocation.id = self.debug_invocation_id(),
952 restate.protocol.state = self.debug_state(),
953 restate.journal.command_index = self.context.journal.command_index(),
954 restate.protocol.version = %self.context.negotiated_protocol_version
955 ),
956 ret
957 )]
958 fn sys_get_promise(&mut self, key: String) -> VMResult<NotificationHandle> {
959 invocation_debug_logs!(self, "Executing 'Await promise {key}'");
960
961 let result_completion_id = self.context.journal.next_completion_notification_id();
962 self.do_transition(SysSimpleCompletableEntry(
963 "SysGetPromise",
964 GetPromiseCommandMessage {
965 key,
966 result_completion_id,
967 ..Default::default()
968 },
969 result_completion_id,
970 ))
971 }
972
973 #[instrument(
974 level = "trace",
975 skip(self),
976 fields(
977 restate.invocation.id = self.debug_invocation_id(),
978 restate.protocol.state = self.debug_state(),
979 restate.journal.command_index = self.context.journal.command_index(),
980 restate.protocol.version = %self.context.negotiated_protocol_version
981 ),
982 ret
983 )]
984 fn sys_peek_promise(&mut self, key: String) -> VMResult<NotificationHandle> {
985 invocation_debug_logs!(self, "Executing 'Peek promise {key}'");
986
987 let result_completion_id = self.context.journal.next_completion_notification_id();
988 self.do_transition(SysSimpleCompletableEntry(
989 "SysPeekPromise",
990 PeekPromiseCommandMessage {
991 key,
992 result_completion_id,
993 ..Default::default()
994 },
995 result_completion_id,
996 ))
997 }
998
999 #[instrument(
1000 level = "trace",
1001 skip(self, value),
1002 fields(
1003 restate.invocation.id = self.debug_invocation_id(),
1004 restate.protocol.state = self.debug_state(),
1005 restate.journal.command_index = self.context.journal.command_index(),
1006 restate.protocol.version = %self.context.negotiated_protocol_version
1007 ),
1008 ret
1009 )]
1010 fn sys_complete_promise(
1011 &mut self,
1012 key: String,
1013 value: NonEmptyValue,
1014 ) -> VMResult<NotificationHandle> {
1015 invocation_debug_logs!(self, "Executing 'Complete promise {key}'");
1016
1017 let result_completion_id = self.context.journal.next_completion_notification_id();
1018 self.do_transition(SysSimpleCompletableEntry(
1019 "SysCompletePromise",
1020 CompletePromiseCommandMessage {
1021 key,
1022 completion: Some(match value {
1023 NonEmptyValue::Success(s) => {
1024 complete_promise_command_message::Completion::CompletionValue(s.into())
1025 }
1026 NonEmptyValue::Failure(f) => {
1027 complete_promise_command_message::Completion::CompletionFailure(f.into())
1028 }
1029 }),
1030 result_completion_id,
1031 ..Default::default()
1032 },
1033 result_completion_id,
1034 ))
1035 }
1036
1037 #[instrument(
1038 level = "trace",
1039 skip(self),
1040 fields(
1041 restate.invocation.id = self.debug_invocation_id(),
1042 restate.protocol.state = self.debug_state(),
1043 restate.journal.command_index = self.context.journal.command_index(),
1044 restate.protocol.version = %self.context.negotiated_protocol_version
1045 ),
1046 ret
1047 )]
1048 fn sys_run(&mut self, name: String) -> VMResult<NotificationHandle> {
1049 match self.do_transition(SysRun(name.clone())) {
1050 Ok(handle) => {
1051 if enabled!(Level::DEBUG) {
1052 self.sys_run_names.insert(handle, name);
1054 }
1055 Ok(handle)
1056 }
1057 Err(e) => Err(e),
1058 }
1059 }
1060
1061 #[instrument(
1062 level = "trace",
1063 skip(self, value, retry_policy),
1064 fields(
1065 restate.invocation.id = self.debug_invocation_id(),
1066 restate.protocol.state = self.debug_state(),
1067 restate.journal.command_index = self.context.journal.command_index(),
1068 restate.protocol.version = %self.context.negotiated_protocol_version
1069 ),
1070 ret
1071 )]
1072 fn propose_run_completion(
1073 &mut self,
1074 notification_handle: NotificationHandle,
1075 value: RunExitResult,
1076 retry_policy: RetryPolicy,
1077 ) -> VMResult<()> {
1078 if enabled!(Level::DEBUG) {
1079 let name: &str = self
1080 .sys_run_names
1081 .get(¬ification_handle)
1082 .map(String::as_str)
1083 .unwrap_or_default();
1084 match &value {
1085 RunExitResult::Success(_) => {
1086 invocation_debug_logs!(self, "Journaling run '{name}' success result");
1087 }
1088 RunExitResult::TerminalFailure(TerminalFailure { code, .. }) => {
1089 invocation_debug_logs!(
1090 self,
1091 "Journaling run '{name}' terminal failure {code} result"
1092 );
1093 }
1094 RunExitResult::RetryableFailure { .. } => {
1095 invocation_debug_logs!(self, "Propagating run '{name}' retryable failure");
1096 }
1097 }
1098 }
1099
1100 self.do_transition(ProposeRunCompletion(
1101 notification_handle,
1102 value,
1103 retry_policy,
1104 ))
1105 }
1106
1107 #[instrument(
1108 level = "trace",
1109 skip(self),
1110 fields(
1111 restate.invocation.id = self.debug_invocation_id(),
1112 restate.protocol.state = self.debug_state(),
1113 restate.journal.command_index = self.context.journal.command_index(),
1114 restate.protocol.version = %self.context.negotiated_protocol_version
1115 ),
1116 ret
1117 )]
1118 fn sys_cancel_invocation(&mut self, target_invocation_id: String) -> VMResult<()> {
1119 invocation_debug_logs!(
1120 self,
1121 "Executing 'Cancel invocation' of {target_invocation_id}"
1122 );
1123 self.do_transition(SysNonCompletableEntry(
1124 "SysCancelInvocation",
1125 SendSignalCommandMessage {
1126 target_invocation_id,
1127 signal_id: Some(send_signal_command_message::SignalId::Idx(CANCEL_SIGNAL_ID)),
1128 result: Some(send_signal_command_message::Result::Void(Default::default())),
1129 ..Default::default()
1130 },
1131 ))
1132 }
1133
1134 #[instrument(
1135 level = "trace",
1136 skip(self),
1137 fields(
1138 restate.invocation.id = self.debug_invocation_id(),
1139 restate.protocol.state = self.debug_state(),
1140 restate.journal.command_index = self.context.journal.command_index(),
1141 restate.protocol.version = %self.context.negotiated_protocol_version
1142 ),
1143 ret
1144 )]
1145 fn sys_attach_invocation(
1146 &mut self,
1147 target: AttachInvocationTarget,
1148 ) -> VMResult<NotificationHandle> {
1149 invocation_debug_logs!(self, "Executing 'Attach invocation'");
1150
1151 let result_completion_id = self.context.journal.next_completion_notification_id();
1152 self.do_transition(SysSimpleCompletableEntry(
1153 "SysAttachInvocation",
1154 AttachInvocationCommandMessage {
1155 target: Some(match target {
1156 AttachInvocationTarget::InvocationId(id) => {
1157 attach_invocation_command_message::Target::InvocationId(id)
1158 }
1159 AttachInvocationTarget::WorkflowId { name, key } => {
1160 attach_invocation_command_message::Target::WorkflowTarget(WorkflowTarget {
1161 workflow_name: name,
1162 workflow_key: key,
1163 })
1164 }
1165 AttachInvocationTarget::IdempotencyId {
1166 service_name,
1167 service_key,
1168 handler_name,
1169 idempotency_key,
1170 } => attach_invocation_command_message::Target::IdempotentRequestTarget(
1171 IdempotentRequestTarget {
1172 service_name,
1173 service_key,
1174 handler_name,
1175 idempotency_key,
1176 },
1177 ),
1178 }),
1179 result_completion_id,
1180 ..Default::default()
1181 },
1182 result_completion_id,
1183 ))
1184 }
1185
1186 #[instrument(
1187 level = "trace",
1188 skip(self),
1189 fields(
1190 restate.invocation.id = self.debug_invocation_id(),
1191 restate.protocol.state = self.debug_state(),
1192 restate.journal.command_index = self.context.journal.command_index(),
1193 restate.protocol.version = %self.context.negotiated_protocol_version
1194 ),
1195 ret
1196 )]
1197 fn sys_get_invocation_output(
1198 &mut self,
1199 target: AttachInvocationTarget,
1200 ) -> VMResult<NotificationHandle> {
1201 invocation_debug_logs!(self, "Executing 'Get invocation output'");
1202
1203 let result_completion_id = self.context.journal.next_completion_notification_id();
1204 self.do_transition(SysSimpleCompletableEntry(
1205 "SysGetInvocationOutput",
1206 GetInvocationOutputCommandMessage {
1207 target: Some(match target {
1208 AttachInvocationTarget::InvocationId(id) => {
1209 get_invocation_output_command_message::Target::InvocationId(id)
1210 }
1211 AttachInvocationTarget::WorkflowId { name, key } => {
1212 get_invocation_output_command_message::Target::WorkflowTarget(
1213 WorkflowTarget {
1214 workflow_name: name,
1215 workflow_key: key,
1216 },
1217 )
1218 }
1219 AttachInvocationTarget::IdempotencyId {
1220 service_name,
1221 service_key,
1222 handler_name,
1223 idempotency_key,
1224 } => get_invocation_output_command_message::Target::IdempotentRequestTarget(
1225 IdempotentRequestTarget {
1226 service_name,
1227 service_key,
1228 handler_name,
1229 idempotency_key,
1230 },
1231 ),
1232 }),
1233 result_completion_id,
1234 ..Default::default()
1235 },
1236 result_completion_id,
1237 ))
1238 }
1239
1240 #[instrument(
1241 level = "trace",
1242 skip(self, value),
1243 fields(
1244 restate.invocation.id = self.debug_invocation_id(),
1245 restate.protocol.state = self.debug_state(),
1246 restate.journal.command_index = self.context.journal.command_index(),
1247 restate.protocol.version = %self.context.negotiated_protocol_version
1248 ),
1249 ret
1250 )]
1251 fn sys_write_output(&mut self, value: NonEmptyValue) -> Result<(), Error> {
1252 match &value {
1253 NonEmptyValue::Success(_) => {
1254 invocation_debug_logs!(self, "Writing invocation result success value");
1255 }
1256 NonEmptyValue::Failure(_) => {
1257 invocation_debug_logs!(self, "Writing invocation result failure value");
1258 }
1259 }
1260 self.do_transition(SysNonCompletableEntry(
1261 "SysWriteOutput",
1262 OutputCommandMessage {
1263 result: Some(match value {
1264 NonEmptyValue::Success(b) => output_command_message::Result::Value(b.into()),
1265 NonEmptyValue::Failure(f) => output_command_message::Result::Failure(f.into()),
1266 }),
1267 ..OutputCommandMessage::default()
1268 },
1269 ))
1270 }
1271
1272 #[instrument(
1273 level = "trace",
1274 skip(self),
1275 fields(
1276 restate.invocation.id = self.debug_invocation_id(),
1277 restate.protocol.state = self.debug_state(),
1278 restate.journal.command_index = self.context.journal.command_index(),
1279 restate.protocol.version = %self.context.negotiated_protocol_version
1280 ),
1281 ret
1282 )]
1283 fn sys_end(&mut self) -> Result<(), Error> {
1284 invocation_debug_logs!(self, "End of the invocation");
1285 self.do_transition(SysEnd)
1286 }
1287
1288 fn is_waiting_preflight(&self) -> bool {
1289 matches!(
1290 &self.last_transition,
1291 Ok(State::WaitingStart) | Ok(State::WaitingReplayEntries { .. })
1292 )
1293 }
1294
1295 fn is_replaying(&self) -> bool {
1296 matches!(&self.last_transition, Ok(State::Replaying { .. }))
1297 }
1298
1299 fn is_processing(&self) -> bool {
1300 matches!(&self.last_transition, Ok(State::Processing { .. }))
1301 }
1302
1303 fn last_command_index(&self) -> i64 {
1304 self.context.journal.command_index()
1305 }
1306}
1307
1308const INDIFFERENT_PAD: GeneralPurposeConfig = GeneralPurposeConfig::new()
1309 .with_decode_padding_mode(DecodePaddingMode::Indifferent)
1310 .with_encode_padding(false);
1311const URL_SAFE: GeneralPurpose = GeneralPurpose::new(&alphabet::URL_SAFE, INDIFFERENT_PAD);
1312
1313const AWAKEABLE_PREFIX: &str = "sign_1";
1314
1315fn awakeable_id_str(id: &[u8], completion_index: u32) -> String {
1316 let mut input_buf = BytesMut::with_capacity(id.len() + size_of::<u32>());
1317 input_buf.put_slice(id);
1318 input_buf.put_u32(completion_index);
1319 format!("{AWAKEABLE_PREFIX}{}", URL_SAFE.encode(input_buf.freeze()))
1320}