use crate::protos::temporal::api::common::v1::{Memo, Payload, Payloads};
use prost::Message;
/// Default warning threshold, in bytes, for [`LimitClass::Blob`] fields (512 KiB).
pub const DEFAULT_BLOB_SIZE_WARN: usize = 1024 * 512;
/// Default warning threshold, in bytes, for [`LimitClass::Memo`] fields (2 KiB).
pub const DEFAULT_MEMO_SIZE_WARN: usize = 1024 * 2;
/// Which size budget a payload field is measured against.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum LimitClass {
    /// General payload data, judged by the blob thresholds.
    Blob,
    /// Memo data, judged by the memo thresholds.
    Memo,
}
/// How a nested field is addressed when a validator descends into it.
#[derive(Clone, Copy, Debug)]
pub enum FieldIndexer<'k> {
    /// Plain field, identified by its name alone.
    None,
    /// Element of a repeated field, identified by position.
    Index(usize),
    /// Entry of a map field, identified by its key.
    Key(&'k str),
}
/// Stack of path segments describing where the walk currently is inside a
/// message; rendered as a dot-separated path when a violation is recorded.
#[derive(Debug, Clone, Default)]
struct PayloadPath {
    segments: Vec<String>,
}

impl PayloadPath {
    /// Appends a segment for field `name`, suffixed with `[index]` or `[key]`
    /// when the field is addressed by an indexer.
    fn push(&mut self, name: &str, indexer: FieldIndexer) {
        let segment = match indexer {
            FieldIndexer::None => String::from(name),
            FieldIndexer::Index(position) => format!("{name}[{position}]"),
            FieldIndexer::Key(key) => format!("{name}[{key}]"),
        };
        self.segments.push(segment);
    }

    /// Drops the innermost segment; a no-op on an empty path.
    fn pop(&mut self) {
        let _ = self.segments.pop();
    }

    /// Renders the full path to `field_name` under the current segments,
    /// e.g. `commands[0].attrs.input`, or just `field_name` at the top level.
    fn leaf(&self, field_name: &str) -> String {
        // Each segment contributes its text plus one '.' separator.
        let capacity = self.segments.iter().map(|s| s.len() + 1).sum::<usize>() + field_name.len();
        let mut rendered = String::with_capacity(capacity);
        for segment in &self.segments {
            rendered.push_str(segment);
            rendered.push('.');
        }
        rendered.push_str(field_name);
        rendered
    }
}
/// Receiver for payload size measurements reported by a message under validation.
///
/// The message reports each measured field through [`check`](Self::check) and
/// brackets nested fields with [`enter`](Self::enter) / [`exit`](Self::exit), so
/// the sink can tell where in the message each measurement came from.
pub trait PayloadLimitSink {
    /// Reports that `field_name` (relative to the current nesting) measured
    /// `size` bytes under the `class` budget. `enforce_error` says whether the
    /// class's error threshold should apply to this field.
    fn check(&mut self, field_name: &'static str, class: LimitClass, size: usize, enforce_error: bool);

    /// Descends into the nested field `name`, addressed by `indexer`.
    fn enter(&mut self, name: &'static str, indexer: FieldIndexer);

    /// Returns from the most recently entered nested field.
    fn exit(&mut self);
}

/// Implemented by messages that can report their payload-bearing fields to a
/// [`PayloadLimitSink`].
pub trait PayloadLimitsValidatable {
    /// Walks this message's payload fields, feeding each measurement to `sink`.
    fn validate_payload_limits(&self, sink: &mut dyn PayloadLimitSink);
}
pub fn payloads_size(payloads: &Payloads) -> usize {
payloads.encoded_len()
}
pub fn payload_size(payload: &Payload) -> usize {
payload.encoded_len()
}
pub fn memo_size(memo: &Memo) -> usize {
memo.encoded_len()
}
pub fn message_size<M: Message>(message: &M) -> usize {
message.encoded_len()
}
/// Sums, over every `(key, payloads)` entry, the key's byte length plus the
/// encoded size of its [`Payloads`] value. Returns 0 for no entries.
pub fn map_payloads_sum<'a, K>(entries: impl IntoIterator<Item = (&'a K, &'a Payloads)>) -> usize
where
    K: AsRef<str> + 'a,
{
    let mut total = 0;
    for (key, value) in entries {
        total += key.as_ref().len() + value.encoded_len();
    }
    total
}

/// Sums, over every `(key, payload)` entry, the key's byte length plus the
/// length of the payload's raw `data` bytes (metadata is not counted).
/// Returns 0 for no entries.
pub fn map_payload_data_sum<'a, K>(entries: impl IntoIterator<Item = (&'a K, &'a Payload)>) -> usize
where
    K: AsRef<str> + 'a,
{
    entries
        .into_iter()
        .fold(0, |acc, (key, payload)| acc + (key.as_ref().len() + payload.data.len()))
}
/// Per-class size thresholds, in bytes, used when validating payload fields.
///
/// Each class has a warning threshold and an optional error threshold. When no
/// error threshold is configured, the class can only produce warnings.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct PayloadLimits {
    /// Warning threshold for [`LimitClass::Blob`] fields.
    pub blob_warn: usize,
    /// Optional error threshold for [`LimitClass::Blob`] fields.
    pub blob_error: Option<usize>,
    /// Warning threshold for [`LimitClass::Memo`] fields.
    pub memo_warn: usize,
    /// Optional error threshold for [`LimitClass::Memo`] fields.
    pub memo_error: Option<usize>,
}

/// Warning-only defaults: [`DEFAULT_BLOB_SIZE_WARN`] / [`DEFAULT_MEMO_SIZE_WARN`]
/// as warning thresholds and no error thresholds.
impl Default for PayloadLimits {
    fn default() -> Self {
        Self {
            blob_warn: DEFAULT_BLOB_SIZE_WARN,
            memo_warn: DEFAULT_MEMO_SIZE_WARN,
            blob_error: None,
            memo_error: None,
        }
    }
}

impl PayloadLimits {
    /// Limits with no error thresholds; identical to [`PayloadLimits::default`].
    pub fn warn_only() -> Self {
        Default::default()
    }

    /// Returns the `(warning, error)` thresholds that apply to `class`.
    fn thresholds(&self, class: LimitClass) -> (usize, Option<usize>) {
        let Self {
            blob_warn,
            blob_error,
            memo_warn,
            memo_error,
        } = *self;
        match class {
            LimitClass::Memo => (memo_warn, memo_error),
            LimitClass::Blob => (blob_warn, blob_error),
        }
    }
}
/// A single payload field whose measured size went over a threshold.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct PayloadLimitViolation {
    /// Location of the offending field, e.g.
    /// `commands[0].schedule_activity_task_command_attributes.input`.
    pub path: String,
    /// The budget the field was measured against.
    pub class: LimitClass,
    /// Measured size, in bytes.
    pub size: usize,
    /// The threshold, in bytes, that `size` exceeded.
    pub limit: usize,
}

impl std::fmt::Display for PayloadLimitViolation {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self {
            path,
            class,
            size,
            limit,
        } = self;
        write!(f, "payload field `{path}` size {size} bytes exceeds the {class:?} limit of {limit} bytes")
    }
}

/// Lets a violation be propagated or boxed as a standard error.
impl std::error::Error for PayloadLimitViolation {}
#[derive(Debug, Clone, Default)]
pub struct CollectingSink {
limits: PayloadLimits,
path: PayloadPath,
pub warnings: Vec<PayloadLimitViolation>,
pub errors: Vec<PayloadLimitViolation>,
}
impl CollectingSink {
pub fn new(limits: PayloadLimits) -> Self {
Self {
limits,
..Default::default()
}
}
}
impl PayloadLimitSink for CollectingSink {
fn check(
&mut self,
field_name: &'static str,
class: LimitClass,
size: usize,
enforce_error: bool,
) {
let (warn, error) = self.limits.thresholds(class);
if enforce_error
&& let Some(error) = error
&& size > error
{
self.errors.push(PayloadLimitViolation {
path: self.path.leaf(field_name),
class,
size,
limit: error,
});
} else if size > warn {
self.warnings.push(PayloadLimitViolation {
path: self.path.leaf(field_name),
class,
size,
limit: warn,
});
}
}
fn enter(&mut self, name: &'static str, indexer: FieldIndexer) {
self.path.push(name, indexer);
}
fn exit(&mut self) {
self.path.pop();
}
}
pub fn validate_payload_limits<M: PayloadLimitsValidatable + ?Sized>(
msg: &M,
limits: &PayloadLimits,
) -> Option<PayloadLimitViolation> {
let mut sink = CollectingSink::new(*limits);
msg.validate_payload_limits(&mut sink);
if !sink.errors.is_empty() {
for error in &sink.errors {
error!(
payload_path = error.path.as_str(),
payload_size = error.size,
error_limit = error.limit,
?error.class,
"Payload size exceeds the error limit"
);
}
return sink.errors.into_iter().next();
}
for warning in &sink.warnings {
warn!(
payload_path = warning.path.as_str(),
payload_size = warning.size,
warn_limit = warning.limit,
?warning.class,
"Payload size exceeds the warning limit"
);
}
None
}
// Splices in code generated at build time and written to Cargo's `OUT_DIR`.
// Going by the file name, and by the tests calling `validate_payload_limits` on
// protocol request/command types, this presumably supplies the
// `PayloadLimitsValidatable` impls for those messages. NOTE(review): confirm
// against the build script.
include!(concat!(env!("OUT_DIR"), "/payload_limits_impl.rs"));
// Unit tests covering the collecting sink, the map-sum helper, and the
// `validate_payload_limits` behavior of protocol messages (whose impls are
// presumably the generated code pulled in by the `include!` above).
#[cfg(test)]
mod tests {
use super::*;
use std::collections::HashMap;
use crate::protos::temporal::api::{
command::v1::{
Command, CompleteWorkflowExecutionCommandAttributes,
FailWorkflowExecutionCommandAttributes, ModifyWorkflowPropertiesCommandAttributes,
RecordMarkerCommandAttributes, ScheduleActivityTaskCommandAttributes,
ScheduleNexusOperationCommandAttributes, command::Attributes,
},
failure::v1::Failure,
// The protocol `Message` type. This explicit import takes precedence over any
// `Message` name reachable through the `super::*` glob.
protocol::v1::Message,
sdk::v1::UserMetadata,
workflowservice::v1::{
RespondActivityTaskFailedRequest, RespondWorkflowTaskCompletedRequest,
StartWorkflowExecutionRequest,
},
};
// Builds a `Payload` carrying `data`, with no metadata and no external payloads.
fn payload(data: &[u8]) -> Payload {
Payload {
metadata: HashMap::new(),
data: data.to_vec(),
external_payloads: vec![],
}
}
// Only key lengths and raw `data` bytes count: (2 + 10) + (3 + 20) = 35.
#[test]
fn map_payload_data_sum_counts_key_and_raw_data() {
let mut m: HashMap<String, Payload> = HashMap::new();
m.insert("ab".to_string(), payload(&[0u8; 10]));
m.insert("cde".to_string(), payload(&[0u8; 20]));
assert_eq!(map_payload_data_sum(m.iter()), 35);
}
// Test sink that records the full path of every `check` call, ignoring class,
// size and enforcement. Used to assert which fields a message reports.
#[derive(Default)]
struct RecordingSink {
path: PayloadPath,
visited: Vec<String>,
}
impl PayloadLimitSink for RecordingSink {
fn check(&mut self, field_name: &'static str, _: LimitClass, _: usize, _: bool) {
self.visited.push(self.path.leaf(field_name));
}
fn enter(&mut self, name: &'static str, indexer: FieldIndexer) {
self.path.push(name, indexer);
}
fn exit(&mut self) {
self.path.pop();
}
}
impl RecordingSink {
// Visited paths in sorted order, so assertions do not depend on visit order.
fn sorted(&self) -> Vec<String> {
let mut v = self.visited.clone();
v.sort();
v
}
}
// Memo with a single field `key` whose payload holds `data_len` zero bytes.
fn memo_with_key_value(key: &str, data_len: usize) -> Memo {
let mut fields = HashMap::new();
fields.insert(key.to_string(), payload(&vec![0u8; data_len]));
Memo { fields }
}
// `Payloads` wrapping one payload of `total_data` zero bytes.
fn payloads(total_data: usize) -> Payloads {
Payloads {
payloads: vec![payload(&vec![0u8; total_data])],
}
}
// Limits with 10-byte warning thresholds for both classes and the given
// error thresholds.
fn worker_limits(blob_error: usize, memo_error: usize) -> PayloadLimits {
PayloadLimits {
blob_warn: 10,
blob_error: Some(blob_error),
memo_warn: 10,
memo_error: Some(memo_error),
}
}
// A ~1000-byte `input` exceeds the 100-byte blob error limit.
#[test]
fn blob_field_over_error_limit_is_reported() {
let req = StartWorkflowExecutionRequest {
input: Some(payloads(1000)),
..Default::default()
};
let violation =
validate_payload_limits(&req, &worker_limits(100, 100)).expect("should error");
assert_eq!(violation.class, LimitClass::Blob);
assert_eq!(violation.path, "input");
assert!(violation.size > 100);
}
// The blob error limit is set very high, so only the 20-byte memo limit can
// trip. The resulting violation must be classed as Memo.
#[test]
fn memo_field_uses_memo_limit() {
let req = StartWorkflowExecutionRequest {
memo: Some(memo_with_key_value("k", 50)),
..Default::default()
};
let limits = PayloadLimits {
blob_warn: 10,
blob_error: Some(1_000_000),
memo_warn: 10,
memo_error: Some(20),
};
let violation = validate_payload_limits(&req, &limits).expect("memo should error");
assert_eq!(violation.class, LimitClass::Memo);
assert_eq!(violation.path, "memo");
}
// A 10 KB failure message is far over the 100-byte error limit, yet no error
// is returned. The impl presumably checks this field with
// `enforce_error = false`.
#[test]
fn warn_only_classified_field_never_errors() {
let req = RespondActivityTaskFailedRequest {
failure: Some(Failure {
message: "x".repeat(10_000),
..Default::default()
}),
..Default::default()
};
assert!(validate_payload_limits(&req, &worker_limits(100, 100)).is_none());
}
// A tiny input under generous error limits yields no violation.
#[test]
fn under_limit_is_ok() {
let req = StartWorkflowExecutionRequest {
input: Some(payloads(5)),
..Default::default()
};
assert!(validate_payload_limits(&req, &worker_limits(100_000, 100_000)).is_none());
}
// `upserted_memo` is checked under the Blob class. Its size is keys plus raw
// data, (2 + 10) + (3 + 20) = 35, rather than the encoded memo size.
#[test]
fn blob_classed_memo_is_measured_as_fields_data_sum() {
let mut fields = HashMap::new();
fields.insert("ab".to_string(), payload(&[0u8; 10]));
fields.insert("cde".to_string(), payload(&[0u8; 20]));
let attr = ModifyWorkflowPropertiesCommandAttributes {
upserted_memo: Some(Memo { fields }),
};
let violation = validate_payload_limits(&attr, &worker_limits(30, 1_000_000))
.expect("blob fields-data-sum should error");
assert_eq!(violation.class, LimitClass::Blob);
assert_eq!(violation.path, "upserted_memo");
assert_eq!(violation.size, 35);
}
// The `details` map is reported as a single `details` field, not per key.
#[test]
fn marker_details_map_is_measured_as_payloads_sum() {
let mut details = HashMap::new();
details.insert("marker".to_string(), payloads(1000));
let attr = RecordMarkerCommandAttributes {
details,
..Default::default()
};
let violation =
validate_payload_limits(&attr, &worker_limits(100, 100)).expect("map-sum should error");
assert_eq!(violation.class, LimitClass::Blob);
assert_eq!(violation.path, "details");
}
// A single-`Payload` field (not `Payloads`) is still measured and reported.
#[test]
fn single_payload_field_is_measured_as_payload_size() {
let attr = ScheduleNexusOperationCommandAttributes {
input: Some(payload(&[0u8; 1000])),
..Default::default()
};
let violation = validate_payload_limits(&attr, &worker_limits(100, 100))
.expect("single payload should error");
assert_eq!(violation.class, LimitClass::Blob);
assert_eq!(violation.path, "input");
}
// Unlike the activity-failure request above, this command's `failure` field
// does raise an error when oversized.
#[test]
fn whole_failure_is_measured_as_message_size() {
let attr = FailWorkflowExecutionCommandAttributes {
failure: Some(Failure {
message: "x".repeat(1000),
..Default::default()
}),
};
let violation = validate_payload_limits(&attr, &worker_limits(100, 100))
.expect("whole-failure should error");
assert_eq!(violation.class, LimitClass::Blob);
assert_eq!(violation.path, "failure");
}
// Over the error limit gives an error (limit 100). Over only the warning limit
// gives a warning (limit 10). Under both gives nothing.
#[test]
fn collecting_sink_classifies_error_vs_warning() {
let mut sink = CollectingSink::new(worker_limits(100, 100));
sink.check("over_error", LimitClass::Blob, 200, true);
sink.check("over_warn", LimitClass::Blob, 50, true);
sink.check("under_warn", LimitClass::Blob, 5, true);
assert_eq!(sink.errors.len(), 1);
assert_eq!(sink.errors[0].path, "over_error");
assert_eq!(sink.errors[0].limit, 100);
assert_eq!(sink.warnings.len(), 1);
assert_eq!(sink.warnings[0].path, "over_warn");
assert_eq!(sink.warnings[0].limit, 10);
}
// With `enforce_error = false`, even a huge size only produces a warning.
#[test]
fn collecting_sink_warn_only_field_never_errors() {
let mut sink = CollectingSink::new(worker_limits(100, 100));
sink.check("warn_only", LimitClass::Blob, 5000, false);
assert!(sink.errors.is_empty());
assert_eq!(sink.warnings.len(), 1);
assert_eq!(sink.warnings[0].path, "warn_only");
}
// With no error threshold configured (the defaults), sizes can only warn.
#[test]
fn collecting_sink_no_error_limit_only_warns() {
let mut sink = CollectingSink::new(PayloadLimits::warn_only());
sink.check("big", LimitClass::Blob, DEFAULT_BLOB_SIZE_WARN + 1, true);
assert!(sink.errors.is_empty());
assert_eq!(sink.warnings.len(), 1);
assert_eq!(sink.warnings[0].path, "big");
}
// The same size (100) stays under the blob error limit but exceeds the memo
// error limit, so only the Memo check errors.
#[test]
fn collecting_sink_routes_memo_to_memo_limit() {
let mut sink = CollectingSink::new(worker_limits(1_000_000, 20));
sink.check("blob_field", LimitClass::Blob, 100, true);
sink.check("memo_field", LimitClass::Memo, 100, true);
assert_eq!(sink.errors.len(), 1);
assert_eq!(sink.errors[0].class, LimitClass::Memo);
assert_eq!(sink.errors[0].path, "memo_field");
}
// `header` and `user_metadata` are populated but never reported; only
// `input` and `memo` are checked on this request.
#[test]
fn visits_validated_fields_and_skips_not_validated() {
let req = StartWorkflowExecutionRequest {
input: Some(payloads(1)),
memo: Some(memo_with_key_value("k", 1)),
header: Some(crate::protos::temporal::api::common::v1::Header {
fields: {
let mut fields = HashMap::new();
fields.insert("h".to_string(), payload(&[1]));
fields
},
}),
user_metadata: Some(UserMetadata {
summary: Some(payload(&[1])),
details: Some(payload(&[1])),
}),
..Default::default()
};
let mut sink = RecordingSink::default();
req.validate_payload_limits(&mut sink);
assert_eq!(sink.sorted(), vec!["input", "memo"]);
}
// Unset optional fields (e.g. `memo` here) are not reported at all.
#[test]
fn visits_only_present_fields() {
let req = StartWorkflowExecutionRequest {
input: Some(payloads(1)),
..Default::default()
};
let mut sink = RecordingSink::default();
req.validate_payload_limits(&mut sink);
assert_eq!(sink.sorted(), vec!["input"]);
}
// Nested paths include the command's index in `commands` and the name of the
// oneof attributes variant.
#[test]
fn visits_payload_fields_of_each_command() {
let req = RespondWorkflowTaskCompletedRequest {
commands: vec![
Command {
attributes: Some(Attributes::ScheduleActivityTaskCommandAttributes(
ScheduleActivityTaskCommandAttributes {
input: Some(payloads(1)),
..Default::default()
},
)),
..Default::default()
},
Command {
attributes: Some(Attributes::CompleteWorkflowExecutionCommandAttributes(
CompleteWorkflowExecutionCommandAttributes {
result: Some(payloads(1)),
},
)),
..Default::default()
},
],
..Default::default()
};
let mut sink = RecordingSink::default();
req.validate_payload_limits(&mut sink);
assert_eq!(
sink.sorted(),
vec![
"commands[0].schedule_activity_task_command_attributes.input",
"commands[1].complete_workflow_execution_command_attributes.result",
]
);
}
// A protocol message's `body` is reported even when it is a default (empty)
// value.
#[test]
fn visits_protocol_message_body() {
let req = RespondWorkflowTaskCompletedRequest {
messages: vec![Message {
body: Some(Default::default()),
..Default::default()
}],
..Default::default()
};
let mut sink = RecordingSink::default();
req.validate_payload_limits(&mut sink);
assert_eq!(sink.sorted(), vec!["messages[0].body"]);
}
}