use crate::{
data_converters::{PayloadCodec, SerializationContextData},
protos::temporal::api::common::v1::{Payload, Payloads},
};
use futures::future::BoxFuture;
pub struct PayloadField<'a> {
pub path: &'static str,
pub data: PayloadFieldData<'a>,
}
pub enum PayloadFieldData<'a> {
Single(&'a mut Payload),
Repeated(&'a mut Vec<Payload>),
Payloads(&'a mut Payloads),
}
pub trait AsyncPayloadVisitor {
fn visit<'a>(&'a mut self, field: PayloadField<'a>) -> BoxFuture<'a, ()>;
}
pub trait PayloadVisitable: Send {
fn visit_payloads_mut<'a>(
&'a mut self,
visitor: &'a mut (dyn AsyncPayloadVisitor + Send),
) -> BoxFuture<'a, ()>;
}
fn is_search_attributes_path(path: &str) -> bool {
path.contains("SearchAttributes.indexed_fields")
}
fn should_encode(path: &str) -> bool {
!is_search_attributes_path(path)
}
pub struct EncodeVisitor<'a> {
codec: &'a (dyn PayloadCodec + Send + Sync),
context: &'a SerializationContextData,
}
impl AsyncPayloadVisitor for EncodeVisitor<'_> {
fn visit<'a>(&'a mut self, field: PayloadField<'a>) -> BoxFuture<'a, ()> {
Box::pin(async move {
if !should_encode(field.path) {
return;
}
match field.data {
PayloadFieldData::Single(payload) => {
let encoded = self
.codec
.encode(self.context, vec![std::mem::take(payload)])
.await;
if let Some(p) = encoded.into_iter().next() {
*payload = p;
}
}
PayloadFieldData::Repeated(payloads) => {
*payloads = self
.codec
.encode(self.context, std::mem::take(payloads))
.await;
}
PayloadFieldData::Payloads(payloads_msg) => {
payloads_msg.payloads = self
.codec
.encode(self.context, std::mem::take(&mut payloads_msg.payloads))
.await;
}
}
})
}
}
pub struct DecodeVisitor<'a> {
codec: &'a (dyn PayloadCodec + Send + Sync),
context: &'a SerializationContextData,
}
impl AsyncPayloadVisitor for DecodeVisitor<'_> {
fn visit<'a>(&'a mut self, field: PayloadField<'a>) -> BoxFuture<'a, ()> {
Box::pin(async move {
match field.data {
PayloadFieldData::Single(payload) => {
let decoded = self
.codec
.decode(self.context, vec![std::mem::take(payload)])
.await;
if let Some(p) = decoded.into_iter().next() {
*payload = p;
}
}
PayloadFieldData::Repeated(payloads) => {
*payloads = self
.codec
.decode(self.context, std::mem::take(payloads))
.await;
}
PayloadFieldData::Payloads(payloads_msg) => {
payloads_msg.payloads = self
.codec
.decode(self.context, std::mem::take(&mut payloads_msg.payloads))
.await;
}
}
})
}
}
pub async fn encode_payloads<M: PayloadVisitable + Send>(
msg: &mut M,
codec: &(dyn PayloadCodec + Send + Sync),
context: &SerializationContextData,
) {
let mut visitor = EncodeVisitor { codec, context };
msg.visit_payloads_mut(&mut visitor).await;
}
pub async fn decode_payloads<M: PayloadVisitable + Send>(
msg: &mut M,
codec: &(dyn PayloadCodec + Send + Sync),
context: &SerializationContextData,
) {
let mut visitor = DecodeVisitor { codec, context };
msg.visit_payloads_mut(&mut visitor).await;
}
impl PayloadVisitable for Payload {
fn visit_payloads_mut<'a>(
&'a mut self,
visitor: &'a mut (dyn AsyncPayloadVisitor + Send),
) -> BoxFuture<'a, ()> {
Box::pin(async move {
visitor
.visit(PayloadField {
path: "temporal.api.common.v1.Payload",
data: PayloadFieldData::Single(self),
})
.await;
})
}
}
impl PayloadVisitable for Payloads {
fn visit_payloads_mut<'a>(
&'a mut self,
visitor: &'a mut (dyn AsyncPayloadVisitor + Send),
) -> BoxFuture<'a, ()> {
Box::pin(async move {
visitor
.visit(PayloadField {
path: "temporal.api.common.v1.Payloads",
data: PayloadFieldData::Payloads(self),
})
.await;
})
}
}
include!(concat!(env!("OUT_DIR"), "/payload_visitor_impl.rs"));
#[cfg(test)]
mod tests {
use super::*;
use crate::protos::{
coresdk::{
activity_result::{
ActivityResolution, Success, activity_resolution::Status as ActivityStatus,
},
workflow_activation::{
InitializeWorkflow, ResolveActivity, WorkflowActivation, WorkflowActivationJob,
workflow_activation_job::Variant,
},
workflow_commands::{
ContinueAsNewWorkflowExecution, ScheduleActivity, StartChildWorkflowExecution,
UpsertWorkflowSearchAttributes, WorkflowCommand,
workflow_command::Variant as CmdVariant,
},
workflow_completion::{
WorkflowActivationCompletion, workflow_activation_completion::Status,
},
},
temporal::api::common::v1::{Memo, SearchAttributes},
};
use futures::FutureExt;
use std::collections::HashMap;
struct MarkingCodec;
impl PayloadCodec for MarkingCodec {
fn encode(
&self,
_: &SerializationContextData,
payloads: Vec<Payload>,
) -> BoxFuture<'static, Vec<Payload>> {
async move {
payloads
.into_iter()
.map(|mut p| {
p.metadata.insert("encoded".to_string(), b"true".to_vec());
p
})
.collect()
}
.boxed()
}
fn decode(
&self,
_: &SerializationContextData,
payloads: Vec<Payload>,
) -> BoxFuture<'static, Vec<Payload>> {
async move {
payloads
.into_iter()
.map(|mut p| {
p.metadata.insert("decoded".to_string(), b"true".to_vec());
p
})
.collect()
}
.boxed()
}
}
struct PathRecordingVisitor {
visited_paths: Vec<String>,
}
impl PathRecordingVisitor {
fn new() -> Self {
Self {
visited_paths: Vec::new(),
}
}
fn paths(&self) -> Vec<String> {
self.visited_paths.clone()
}
}
impl AsyncPayloadVisitor for PathRecordingVisitor {
fn visit<'a>(&'a mut self, field: PayloadField<'a>) -> BoxFuture<'a, ()> {
let path = field.path.to_string();
self.visited_paths.push(path);
async move {}.boxed()
}
}
fn make_payload(data: &str) -> Payload {
Payload {
metadata: HashMap::new(),
data: data.as_bytes().to_vec(),
external_payloads: vec![],
}
}
fn is_encoded(p: &Payload) -> bool {
p.metadata.contains_key("encoded")
}
fn is_decoded(p: &Payload) -> bool {
p.metadata.contains_key("decoded")
}
#[tokio::test]
async fn test_direct_visitor_records_paths() {
let mut activation = WorkflowActivation {
run_id: "test-run".to_string(),
jobs: vec![WorkflowActivationJob {
variant: Some(Variant::InitializeWorkflow(InitializeWorkflow {
workflow_type: "test-workflow".to_string(),
arguments: vec![make_payload("input1")],
headers: {
let mut h = HashMap::new();
h.insert("header-key".to_string(), make_payload("header-value"));
h
},
memo: Some(Memo {
fields: {
let mut m = HashMap::new();
m.insert("memo-key".to_string(), make_payload("memo-value"));
m
},
}),
..Default::default()
})),
}],
..Default::default()
};
let mut visitor = PathRecordingVisitor::new();
activation.visit_payloads_mut(&mut visitor).await;
let paths = visitor.paths();
assert!(
paths
.iter()
.any(|p| p.contains("InitializeWorkflow.arguments")),
"should visit arguments, got: {:?}",
paths
);
assert!(
paths
.iter()
.any(|p| p.contains("InitializeWorkflow.headers")),
"should visit headers, got: {:?}",
paths
);
assert!(
paths.iter().any(|p| p.contains("Memo.fields")),
"should visit memo fields, got: {:?}",
paths
);
}
#[tokio::test]
async fn test_encode_workflow_activation_completion_with_schedule_activity() {
let mut completion = WorkflowActivationCompletion {
run_id: "test-run".to_string(),
status: Some(Status::Successful(
crate::protos::coresdk::workflow_completion::Success {
commands: vec![WorkflowCommand {
variant: Some(CmdVariant::ScheduleActivity(ScheduleActivity {
activity_id: "act-1".to_string(),
activity_type: "test-activity".to_string(),
arguments: vec![make_payload("arg1"), make_payload("arg2")],
headers: {
let mut h = HashMap::new();
h.insert("header-key".to_string(), make_payload("header-value"));
h
},
..Default::default()
})),
user_metadata: None,
}],
..Default::default()
},
)),
};
encode_payloads(
&mut completion,
&MarkingCodec,
&SerializationContextData::Workflow,
)
.await;
let status = completion.status.as_ref().unwrap();
let Status::Successful(success) = status else {
panic!("Expected successful status")
};
let cmd = &success.commands[0];
let CmdVariant::ScheduleActivity(schedule) = cmd.variant.as_ref().unwrap() else {
panic!("Expected ScheduleActivity")
};
assert!(is_encoded(&schedule.arguments[0]), "arg1 should be encoded");
assert!(is_encoded(&schedule.arguments[1]), "arg2 should be encoded");
assert!(
is_encoded(schedule.headers.get("header-key").unwrap()),
"header should be encoded"
);
}
#[tokio::test]
async fn test_decode_workflow_activation_with_initialize() {
let mut activation = WorkflowActivation {
run_id: "test-run".to_string(),
jobs: vec![WorkflowActivationJob {
variant: Some(Variant::InitializeWorkflow(InitializeWorkflow {
workflow_type: "test-workflow".to_string(),
arguments: vec![make_payload("input1"), make_payload("input2")],
headers: {
let mut h = HashMap::new();
h.insert("header-key".to_string(), make_payload("header-value"));
h
},
..Default::default()
})),
}],
..Default::default()
};
decode_payloads(
&mut activation,
&MarkingCodec,
&SerializationContextData::Workflow,
)
.await;
let job = &activation.jobs[0];
let Variant::InitializeWorkflow(init) = job.variant.as_ref().unwrap() else {
panic!("Expected InitializeWorkflow")
};
assert!(is_decoded(&init.arguments[0]), "arg1 should be decoded");
assert!(is_decoded(&init.arguments[1]), "arg2 should be decoded");
assert!(
is_decoded(init.headers.get("header-key").unwrap()),
"header should be decoded"
);
}
#[tokio::test]
async fn test_decode_workflow_activation_with_resolve_activity() {
let mut activation = WorkflowActivation {
run_id: "test-run".to_string(),
jobs: vec![WorkflowActivationJob {
variant: Some(Variant::ResolveActivity(ResolveActivity {
seq: 1,
result: Some(ActivityResolution {
status: Some(ActivityStatus::Completed(Success {
result: Some(make_payload("activity-result")),
})),
}),
..Default::default()
})),
}],
..Default::default()
};
decode_payloads(
&mut activation,
&MarkingCodec,
&SerializationContextData::Workflow,
)
.await;
let job = &activation.jobs[0];
let Variant::ResolveActivity(resolve) = job.variant.as_ref().unwrap() else {
panic!("Expected ResolveActivity")
};
let ActivityStatus::Completed(success) =
resolve.result.as_ref().unwrap().status.as_ref().unwrap()
else {
panic!("Expected Completed status")
};
assert!(
is_decoded(success.result.as_ref().unwrap()),
"activity result should be decoded"
);
}
#[tokio::test]
async fn test_search_attributes_skipped_on_encode() {
let mut completion = WorkflowActivationCompletion {
run_id: "test-run".to_string(),
status: Some(Status::Successful(
crate::protos::coresdk::workflow_completion::Success {
commands: vec![
WorkflowCommand {
variant: Some(CmdVariant::UpsertWorkflowSearchAttributes(
UpsertWorkflowSearchAttributes {
search_attributes: Some(SearchAttributes {
indexed_fields: {
let mut sa = HashMap::new();
sa.insert(
"CustomField".to_string(),
make_payload("search-value"),
);
sa
},
}),
},
)),
user_metadata: None,
},
WorkflowCommand {
variant: Some(CmdVariant::ContinueAsNewWorkflowExecution(
ContinueAsNewWorkflowExecution {
arguments: vec![make_payload("continue-arg")],
search_attributes: Some(SearchAttributes {
indexed_fields: {
let mut sa = HashMap::new();
sa.insert(
"CustomField".to_string(),
make_payload("continue-search-value"),
);
sa
},
}),
..Default::default()
},
)),
user_metadata: None,
},
WorkflowCommand {
variant: Some(CmdVariant::StartChildWorkflowExecution(
StartChildWorkflowExecution {
seq: 1,
workflow_type: "child-workflow".to_string(),
input: vec![make_payload("child-arg")],
search_attributes: Some(SearchAttributes {
indexed_fields: {
let mut sa = HashMap::new();
sa.insert(
"CustomField".to_string(),
make_payload("child-search-value"),
);
sa
},
}),
..Default::default()
},
)),
user_metadata: None,
},
],
..Default::default()
},
)),
};
encode_payloads(
&mut completion,
&MarkingCodec,
&SerializationContextData::Workflow,
)
.await;
let status = completion.status.as_ref().unwrap();
let Status::Successful(success) = status else {
panic!("Expected successful status")
};
let CmdVariant::UpsertWorkflowSearchAttributes(upsert) =
success.commands[0].variant.as_ref().unwrap()
else {
panic!("Expected UpsertWorkflowSearchAttributes")
};
let sa = upsert.search_attributes.as_ref().unwrap();
assert!(
!is_encoded(sa.indexed_fields.get("CustomField").unwrap()),
"search attributes should NOT be encoded"
);
let CmdVariant::ContinueAsNewWorkflowExecution(continue_as_new) =
success.commands[1].variant.as_ref().unwrap()
else {
panic!("Expected ContinueAsNewWorkflowExecution")
};
assert!(
is_encoded(&continue_as_new.arguments[0]),
"arguments should be encoded"
);
let sa = continue_as_new.search_attributes.as_ref().unwrap();
assert!(
!is_encoded(sa.indexed_fields.get("CustomField").unwrap()),
"search attributes should NOT be encoded"
);
let CmdVariant::StartChildWorkflowExecution(start_child) =
success.commands[2].variant.as_ref().unwrap()
else {
panic!("Expected StartChildWorkflowExecution")
};
assert!(is_encoded(&start_child.input[0]), "input should be encoded");
let sa = start_child.search_attributes.as_ref().unwrap();
assert!(
!is_encoded(sa.indexed_fields.get("CustomField").unwrap()),
"search attributes should NOT be encoded"
);
}
#[tokio::test]
async fn test_encode_single_payload() {
let mut payload = make_payload("test-data");
encode_payloads(
&mut payload,
&MarkingCodec,
&SerializationContextData::Workflow,
)
.await;
assert!(is_encoded(&payload), "single payload should be encoded");
}
#[tokio::test]
async fn test_decode_single_payload() {
let mut payload = make_payload("test-data");
decode_payloads(
&mut payload,
&MarkingCodec,
&SerializationContextData::Workflow,
)
.await;
assert!(is_decoded(&payload), "single payload should be decoded");
}
#[tokio::test]
async fn test_encode_payloads_message() {
let mut payloads = Payloads {
payloads: vec![make_payload("p1"), make_payload("p2"), make_payload("p3")],
};
encode_payloads(
&mut payloads,
&MarkingCodec,
&SerializationContextData::Workflow,
)
.await;
for (i, p) in payloads.payloads.iter().enumerate() {
assert!(is_encoded(p), "payload {} should be encoded", i);
}
}
}